├── testimgs ├── 1.png ├── 1_1.JPG ├── 22.png ├── 23.png ├── 3_1.JPG ├── 4_1.JPG ├── Cave.png ├── Madison.png └── Farmhouse.png ├── README.md ├── main.py └── main.cpp /testimgs/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hpc203/Diffusion-Low-Light-onnxrun/HEAD/testimgs/1.png -------------------------------------------------------------------------------- /testimgs/1_1.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hpc203/Diffusion-Low-Light-onnxrun/HEAD/testimgs/1_1.JPG -------------------------------------------------------------------------------- /testimgs/22.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hpc203/Diffusion-Low-Light-onnxrun/HEAD/testimgs/22.png -------------------------------------------------------------------------------- /testimgs/23.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hpc203/Diffusion-Low-Light-onnxrun/HEAD/testimgs/23.png -------------------------------------------------------------------------------- /testimgs/3_1.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hpc203/Diffusion-Low-Light-onnxrun/HEAD/testimgs/3_1.JPG -------------------------------------------------------------------------------- /testimgs/4_1.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hpc203/Diffusion-Low-Light-onnxrun/HEAD/testimgs/4_1.JPG -------------------------------------------------------------------------------- /testimgs/Cave.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hpc203/Diffusion-Low-Light-onnxrun/HEAD/testimgs/Cave.png -------------------------------------------------------------------------------- /testimgs/Madison.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hpc203/Diffusion-Low-Light-onnxrun/HEAD/testimgs/Madison.png -------------------------------------------------------------------------------- /testimgs/Farmhouse.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hpc203/Diffusion-Low-Light-onnxrun/HEAD/testimgs/Farmhouse.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 训练源码在https://github.com/JianghaiSCU/Diffusion-Low-Light 2 | ,看效果不错,我就编写了推理部署程序。opencv-dnn加载onnx文件报错,因此使用onnxruntime做推理引擎。 3 | 4 | onnx文件在百度云盘,链接: https://pan.baidu.com/s/1Uj15fnQaREw0SlfA6YeIPw 提取码: bbtx 5 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | import argparse 4 | import cv2 5 | import numpy as np 6 | import onnxruntime 7 | import matplotlib.pyplot as plt 8 | 9 | 10 | class Diffusion_Low_Light: 11 | def __init__(self, modelpath): 12 | so = onnxruntime.SessionOptions() 13 | so.log_severity_level = 3 14 | # Initialize model 15 | # net = cv2.dnn.readNet(modelpath) ###读取失败 16 | self.onnx_session = onnxruntime.InferenceSession(modelpath, so) 17 | self.input_name = self.onnx_session.get_inputs()[0].name 18 | 19 | input_shape = self.onnx_session.get_inputs()[0].shape 20 | self.input_height = input_shape[2] 21 | self.input_width = input_shape[3] 22 | 23 | def prepare_input(self, image): 24 | input_image = cv2.resize(cv2.cvtColor(image, cv2.COLOR_BGR2RGB), dsize=( 25 | self.input_width, self.input_height)) 26 | input_image = input_image.astype(np.float32) / 255.0 27 | input_image = input_image.transpose(2, 0, 1) 28 | input_image = np.expand_dims(input_image, axis=0) 29 | return input_image 30 | 31 | def detect(self, image): 32 | input_image = self.prepare_input(image) 33 | 34 | # Perform inference on the image 35 | result = self.onnx_session.run(None, {self.input_name: input_image}) 36 | 37 | # Post process:squeeze, RGB->BGR, Transpose, uint8 cast 38 | output_image = np.squeeze(result[0]) 39 | output_image = output_image.transpose(1, 2, 0) 40 | output_image = output_image * 255 41 | output_image = output_image.astype(np.uint8) 42 | output_image = cv2.cvtColor(output_image, cv2.COLOR_RGB2BGR) 43 | output_image = cv2.resize(output_image, (image.shape[1], image.shape[0])) 44 | return output_image 45 | 46 | 47 | if __name__ == '__main__': 48 | parser = argparse.ArgumentParser() 49 | parser.add_argument('--imgpath', type=str, 50 | default='testimgs/1.png', help="image path") 51 | parser.add_argument('--modelpath', type=str, 52 | default='weights/diffusion_low_light_1x3x192x320.onnx', help="image path") 53 | args = parser.parse_args() 54 | 55 | mynet = Diffusion_Low_Light(args.modelpath) 56 | srcimg = cv2.imread(args.imgpath) 57 | dstimg = mynet.detect(srcimg) 58 | 59 | # cv2.namedWindow('srcimg', cv2.WINDOW_NORMAL) 60 | # cv2.imshow('srcimg', srcimg) 61 | # cv2.namedWindow('dstimg', cv2.WINDOW_NORMAL) 62 | # cv2.imshow('dstimg', dstimg) 63 | # cv2.waitKey(0) 64 | # cv2.destroyAllWindows() 65 | 66 | plt.subplot(1, 2, 1) 67 | plt.imshow(cv2.cvtColor(srcimg, cv2.COLOR_BGR2RGB)) 68 | plt.axis('off') 69 | plt.title('srcimg', color='red') 70 | 71 | plt.subplot(1, 2, 2) 72 | plt.imshow(cv2.cvtColor(dstimg, cv2.COLOR_BGR2RGB)) 73 | plt.axis('off') 74 | plt.title('dstimg', color='red') 75 | 76 | 77 | plt.show() 78 | # plt.savefig('result.jpg', dpi=700, bbox_inches='tight') ###保存高清图 79 | plt.close() -------------------------------------------------------------------------------- /main.cpp: -------------------------------------------------------------------------------- 1 | #define _CRT_SECURE_NO_WARNINGS 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | //#include 9 | #include 10 | 11 | using namespace cv; 12 | using namespace std; 13 | using namespace Ort; 14 | 15 | 16 | class Diffusion_Low_Light 17 | { 18 | public: 19 | Diffusion_Low_Light(string modelpath); 20 | Mat detect(const Mat& frame); 21 | private: 22 | vector input_image; 23 | 24 | Env env = Env(ORT_LOGGING_LEVEL_ERROR, "Low-light Image Enhancement with Wavelet-based Diffusion Models"); 25 | Ort::Session *ort_session = nullptr; 26 | SessionOptions sessionOptions = SessionOptions(); 27 | const vector input_names = {"input"}; 28 | const vector output_names = {"output"}; 29 | int inpWidth; 30 | int inpHeight; 31 | void preprocess(const Mat& frame); 32 | }; 33 | 34 | Diffusion_Low_Light::Diffusion_Low_Light(string model_path) 35 | { 36 | //OrtStatus* status = OrtSessionOptionsAppendExecutionProvider_CUDA(sessionOptions, 0); 37 | sessionOptions.SetGraphOptimizationLevel(ORT_ENABLE_BASIC); 38 | 39 | // std::wstring widestr = std::wstring(model_path.begin(), model_path.end()); ////windows写法 40 | // ort_session = new Session(env, widestr.c_str(), sessionOptions); ////windows写法 41 | ort_session = new Session(env, model_path.c_str(), sessionOptions); ////linux写法 42 | 43 | size_t numInputNodes = ort_session->GetInputCount(); 44 | AllocatorWithDefaultOptions allocator; 45 | vector> input_node_dims; 46 | for (int i = 0; i < numInputNodes; i++) 47 | { 48 | Ort::TypeInfo input_type_info = ort_session->GetInputTypeInfo(i); 49 | auto input_tensor_info = input_type_info.GetTensorTypeAndShapeInfo(); 50 | auto input_dims = input_tensor_info.GetShape(); 51 | input_node_dims.push_back(input_dims); 52 | } 53 | 54 | this->inpHeight = input_node_dims[0][2]; 55 | this->inpWidth = input_node_dims[0][3]; 56 | } 57 | 58 | void Diffusion_Low_Light::preprocess(const Mat& frame) 59 | { 60 | Mat dstimg; 61 | cvtColor(frame, dstimg, COLOR_BGR2RGB); 62 | resize(dstimg, dstimg, Size(this->inpWidth, this->inpHeight)); 63 | dstimg.convertTo(dstimg, CV_32FC3, 1 / 255.f); 64 | 65 | vector rgbChannels(3); 66 | split(dstimg, rgbChannels); 67 | const int image_area = dstimg.rows * dstimg.cols; 68 | this->input_image.clear(); 69 | this->input_image.resize(1 * 3 * image_area); 70 | int single_chn_size = image_area * sizeof(float); 71 | memcpy(this->input_image.data(), (float *)rgbChannels[0].data, single_chn_size); 72 | memcpy(this->input_image.data() + image_area, (float *)rgbChannels[1].data, single_chn_size); 73 | memcpy(this->input_image.data() + image_area * 2, (float *)rgbChannels[2].data, single_chn_size); 74 | } 75 | 76 | Mat Diffusion_Low_Light::detect(const Mat& frame) 77 | { 78 | this->preprocess(frame); 79 | 80 | array input_shape_{ 1, 3, this->inpHeight, this->inpWidth }; 81 | auto allocator_info = MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); 82 | Value input_tensor_ = Value::CreateTensor(allocator_info, this->input_image.data(), this->input_image.size(), input_shape_.data(), input_shape_.size()); 83 | 84 | // 开始推理 85 | vector ort_outputs = ort_session->Run(RunOptions{ nullptr }, &input_names[0], &input_tensor_, 1, output_names.data(), output_names.size()); // 开始推理 86 | 87 | std::vector out_shape = ort_outputs[0].GetTensorTypeAndShapeInfo().GetShape(); 88 | const int out_h = out_shape[2]; 89 | const int out_w = out_shape[3]; 90 | float* pred = ort_outputs[0].GetTensorMutableData(); 91 | const int channel_step = out_h * out_w; 92 | Mat bmat(out_h, out_w, CV_32FC1, pred); 93 | Mat gmat(out_h, out_w, CV_32FC1, pred + channel_step); 94 | Mat rmat(out_h, out_w, CV_32FC1, pred + 2 * channel_step); 95 | bmat *= 255.f; 96 | gmat *= 255.f; 97 | rmat *= 255.f; 98 | 99 | vector channel_mats = {rmat, gmat, bmat}; 100 | Mat dstimg; 101 | merge(channel_mats, dstimg); 102 | dstimg.convertTo(dstimg, CV_8UC3); 103 | resize(dstimg, dstimg, Size(frame.cols, frame.rows)); 104 | return dstimg; 105 | } 106 | 107 | 108 | int main() 109 | { 110 | Diffusion_Low_Light mynet("weights/diffusion_low_light_1x3x192x320.onnx"); 111 | string imgpath = "testimgs/1.png"; ///文件路径写正确,程序才能正常运行的 112 | Mat srcimg = imread(imgpath); 113 | 114 | Mat dstimg = mynet.detect(srcimg); 115 | 116 | imwrite("result.jpg", dstimg); 117 | 118 | // namedWindow("srcimg", WINDOW_NORMAL); 119 | // imshow("srcimg", srcimg); 120 | // static const string kWinName = "Deep learning use onnxruntime"; 121 | // namedWindow(kWinName, WINDOW_NORMAL); 122 | // imshow(kWinName, dstimg); 123 | // waitKey(0); 124 | // destroyAllWindows(); 125 | } 126 | --------------------------------------------------------------------------------