├── README.md ├── onnxruntime ├── images │ ├── 1.jpg │ ├── 2.jpg │ ├── bike.jpg │ ├── cam_image44.jpg │ └── cam_image47.jpg ├── main.cpp └── main.py └── opencv ├── images ├── 1.jpg ├── 2.jpg ├── bike.jpg ├── cam_image44.jpg └── cam_image47.jpg ├── main.cpp └── main.py /README.md: -------------------------------------------------------------------------------- 1 | # DIS-opencv-onnxrun 2 | 分别使用OpenCV、ONNXRuntime部署DIS高精度图像二类分割,包含C++和Python两种版本的程序。 3 | 本套程序对应的paper是ECCV2022的一篇文章《Highly Accurate Dichotomous Image Segmentation》, 4 | 跟BASNet和U2-Net都是出自同一个作者写的。 5 | 6 | 本套程序提供了50个onnx文件,占用磁盘空间8.2G,onnx文件在百度云盘,下载链接:https://pan.baidu.com/s/19jENx2Ul8oJn-iLBFK8sLg 7 | 提取码:uphj 8 | 9 | 需要注意的是opencv不能加载['isnet_general_use_HxW.onnx', 'isnet_HxW.onnx', 'isnet_Nx3xHxW.onnx']这3个文件做推理, 10 | onnxruntime不能加载['isnet_general_use_HxW.onnx', 'isnet_HxW.onnx', 'isnet_Nx3xHxW.onnx']这3个文件做推理 11 | -------------------------------------------------------------------------------- /onnxruntime/images/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hpc203/DIS-opencv-onnxrun/080e98939b7312edfb56d97fd8c4f4b2051f7e9d/onnxruntime/images/1.jpg -------------------------------------------------------------------------------- /onnxruntime/images/2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hpc203/DIS-opencv-onnxrun/080e98939b7312edfb56d97fd8c4f4b2051f7e9d/onnxruntime/images/2.jpg -------------------------------------------------------------------------------- /onnxruntime/images/bike.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hpc203/DIS-opencv-onnxrun/080e98939b7312edfb56d97fd8c4f4b2051f7e9d/onnxruntime/images/bike.jpg -------------------------------------------------------------------------------- /onnxruntime/images/cam_image44.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hpc203/DIS-opencv-onnxrun/080e98939b7312edfb56d97fd8c4f4b2051f7e9d/onnxruntime/images/cam_image44.jpg -------------------------------------------------------------------------------- /onnxruntime/images/cam_image47.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hpc203/DIS-opencv-onnxrun/080e98939b7312edfb56d97fd8c4f4b2051f7e9d/onnxruntime/images/cam_image47.jpg -------------------------------------------------------------------------------- /onnxruntime/main.cpp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hpc203/DIS-opencv-onnxrun/080e98939b7312edfb56d97fd8c4f4b2051f7e9d/onnxruntime/main.cpp -------------------------------------------------------------------------------- /onnxruntime/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import cv2 3 | import numpy as np 4 | import onnxruntime 5 | ### onnxruntime load ['isnet_general_use_HxW.onnx', 'isnet_HxW.onnx', 'isnet_Nx3xHxW.onnx'] inference failed 6 | class DIS(): 7 | def __init__(self, modelpath, score_th=None): 8 | so = onnxruntime.SessionOptions() 9 | so.log_severity_level = 3 10 | self.net = onnxruntime.InferenceSession(modelpath, so) 11 | self.input_height = self.net.get_inputs()[0].shape[2] 12 | self.input_width = self.net.get_inputs()[0].shape[3] 13 | self.input_name = self.net.get_inputs()[0].name 14 | self.output_name = self.net.get_outputs()[0].name 15 | self.score_th = score_th 16 | 17 | def detect(self, srcimg): 18 | img = cv2.resize(srcimg, dsize=(self.input_width, self.input_height)) 19 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 20 | img = img.astype(np.float32) / 255.0 - 0.5 21 | blob = np.expand_dims(np.transpose(img, (2, 0, 1)), axis=0).astype(np.float32) 22 | outs = self.net.run([self.output_name], {self.input_name: blob}) 23 | 24 | mask = np.array(outs[0]).squeeze() 25 | min_value = np.min(mask) 26 | max_value = np.max(mask) 27 | mask = (mask - min_value) / (max_value - min_value) 28 | if self.score_th is not None: 29 | mask = np.where(mask < self.score_th, 0, 1) 30 | mask *= 255 31 | mask = mask.astype('uint8') 32 | 33 | mask = cv2.resize(mask, dsize=(srcimg.shape[1], srcimg.shape[0]), interpolation=cv2.INTER_LINEAR) 34 | return mask 35 | 36 | def generate_overlay_image(srcimg, mask): 37 | overlay_image = np.zeros(srcimg.shape, dtype=np.uint8) 38 | overlay_image[:] = (255, 255, 255) 39 | mask = np.stack((mask,) * 3, axis=-1).astype('uint8') ###沿着通道方向复制3次 40 | mask_image = np.where(mask, srcimg, overlay_image) 41 | return mask, mask_image 42 | 43 | if __name__ == '__main__': 44 | parser = argparse.ArgumentParser() 45 | parser.add_argument("--imgpath", type=str, default='images/cam_image47.jpg') 46 | parser.add_argument("--modelpath", type=str, default='weights/isnet_general_use_480x640.onnx') 47 | args = parser.parse_args() 48 | 49 | mynet = DIS(args.modelpath) 50 | srcimg = cv2.imread(args.imgpath) 51 | mask = mynet.detect(srcimg) 52 | mask, overlay_image = generate_overlay_image(srcimg, mask) 53 | 54 | winName = 'Deep learning object detection in onnxruntime' 55 | cv2.namedWindow(winName, cv2.WINDOW_NORMAL) 56 | cv2.imshow(winName, np.hstack((srcimg, mask))) 57 | cv2.waitKey(0) 58 | cv2.destroyAllWindows() -------------------------------------------------------------------------------- /opencv/images/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hpc203/DIS-opencv-onnxrun/080e98939b7312edfb56d97fd8c4f4b2051f7e9d/opencv/images/1.jpg -------------------------------------------------------------------------------- /opencv/images/2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hpc203/DIS-opencv-onnxrun/080e98939b7312edfb56d97fd8c4f4b2051f7e9d/opencv/images/2.jpg -------------------------------------------------------------------------------- /opencv/images/bike.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hpc203/DIS-opencv-onnxrun/080e98939b7312edfb56d97fd8c4f4b2051f7e9d/opencv/images/bike.jpg -------------------------------------------------------------------------------- /opencv/images/cam_image44.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hpc203/DIS-opencv-onnxrun/080e98939b7312edfb56d97fd8c4f4b2051f7e9d/opencv/images/cam_image44.jpg -------------------------------------------------------------------------------- /opencv/images/cam_image47.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hpc203/DIS-opencv-onnxrun/080e98939b7312edfb56d97fd8c4f4b2051f7e9d/opencv/images/cam_image47.jpg -------------------------------------------------------------------------------- /opencv/main.cpp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hpc203/DIS-opencv-onnxrun/080e98939b7312edfb56d97fd8c4f4b2051f7e9d/opencv/main.cpp -------------------------------------------------------------------------------- /opencv/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import cv2 3 | import numpy as np 4 | ###opencv dnn load ['isnet_general_use_HxW.onnx', 'isnet_HxW.onnx', 'isnet_Nx3xHxW.onnx'] inference failed 5 | class DIS(): 6 | def __init__(self, modelpath, score_th=None): 7 | self.net = cv2.dnn.readNet(modelpath) 8 | hxw = modelpath.split('_')[-1].split('.')[0].split('x') 9 | self.input_height = int(hxw[0]) 10 | self.input_width = int(hxw[1]) 11 | self.score_th = score_th 12 | self.output_names = self.net.getUnconnectedOutLayersNames() 13 | 14 | def detect(self, srcimg): 15 | img = cv2.resize(srcimg, dsize=(self.input_width, self.input_height)) 16 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 17 | img = img.astype(np.float32) / 255.0 - 0.5 18 | blob = cv2.dnn.blobFromImage(img) 19 | self.net.setInput(blob) 20 | outs = self.net.forward(self.output_names) 21 | 22 | mask = np.array(outs[0]).squeeze() 23 | min_value = np.min(mask) 24 | max_value = np.max(mask) 25 | mask = (mask - min_value) / (max_value - min_value) 26 | if self.score_th is not None: 27 | mask = np.where(mask < self.score_th, 0, 1) 28 | mask *= 255 29 | mask = mask.astype('uint8') 30 | 31 | mask = cv2.resize(mask, dsize=(srcimg.shape[1], srcimg.shape[0]), interpolation=cv2.INTER_LINEAR) 32 | return mask 33 | 34 | def generate_overlay_image(srcimg, mask): 35 | overlay_image = np.zeros(srcimg.shape, dtype=np.uint8) 36 | overlay_image[:] = (255, 255, 255) 37 | mask = np.stack((mask,) * 3, axis=-1).astype('uint8') ###沿着通道方向复制3次 38 | mask_image = np.where(mask, srcimg, overlay_image) 39 | return mask, mask_image 40 | 41 | if __name__ == '__main__': 42 | parser = argparse.ArgumentParser() 43 | parser.add_argument("--imgpath", type=str, default='images/bike.jpg') 44 | parser.add_argument("--modelpath", type=str, default='weights/isnet_general_use_480x640.onnx') 45 | args = parser.parse_args() 46 | 47 | mynet = DIS(args.modelpath) 48 | srcimg = cv2.imread(args.imgpath) 49 | mask = mynet.detect(srcimg) 50 | mask, overlay_image = generate_overlay_image(srcimg, mask) 51 | 52 | winName = 'Deep learning object detection in OpenCV' 53 | cv2.namedWindow(winName, cv2.WINDOW_NORMAL) 54 | cv2.imshow(winName, np.hstack((srcimg, mask))) 55 | cv2.waitKey(0) 56 | cv2.destroyAllWindows() --------------------------------------------------------------------------------