"""Generate portrait sketches with U^2-Net (``u2net_portrait.onnx``) via ONNX Runtime.

Repository: hpc203/u2net-onnxruntime. Besides this script (main.py) the
repository contains:

* main.cpp   -- C++ version:
  https://raw.githubusercontent.com/hpc203/u2net-onnxruntime/HEAD/main.cpp
* sample.jpg -- example input:
  https://raw.githubusercontent.com/hpc203/u2net-onnxruntime/HEAD/sample.jpg
* README.md

README (translated from Chinese):
    u2net-onnxruntime: deploys U-2-Net with ONNX Runtime to generate face
    sketch drawings; the repository contains both a C++ and a Python version
    of the program.

    The .onnx model file is hosted on Baidu Netdisk:
    https://pan.baidu.com/s/1QCFa1nWJklfqMMeLa-woSg
    Extraction code: e8fr

    The author originally intended to deploy with OpenCV, but OpenCV's dnn
    module failed to read the ONNX file, so ONNX Runtime is used instead.
"""
import argparse

import cv2
import numpy as np
import onnxruntime


class u2net:
    """Wrap an ONNX Runtime session for the U^2-Net portrait model.

    The lowercase class name is kept for compatibility with existing callers.
    """

    def __init__(self, modelpath='u2net_portrait.onnx', providers=None):
        """Create the inference session.

        Args:
            modelpath: Path to the ONNX model file. Defaults to
                ``'u2net_portrait.onnx'`` in the current working directory.
            providers: Execution providers forwarded to
                ``onnxruntime.InferenceSession``. ``None`` (the default) uses
                every provider this onnxruntime build reports as available.
        """
        so = onnxruntime.SessionOptions()
        # 3 = ERROR: suppress ORT's verbose/info/warning console output.
        so.log_severity_level = 3
        if providers is None:
            # Since ORT 1.9, builds with several execution providers (e.g.
            # onnxruntime-gpu) raise ValueError unless `providers` is given.
            providers = onnxruntime.get_available_providers()
        self.net = onnxruntime.InferenceSession(
            modelpath, so, providers=providers)
        self.input_size = 512
        # Per-channel (R, G, B) mean/std applied after scaling pixels to
        # [0, 1]; the values equal the usual ImageNet statistics.
        self.mean = [0.485, 0.456, 0.406]
        self.std = [0.229, 0.224, 0.225]
        self.input_name = self.net.get_inputs()[0].name
        self.output_name = self.net.get_outputs()[0].name

    def detect(self, srcimg):
        """Run the model on one image and return the sketch.

        Args:
            srcimg: BGR image as returned by ``cv2.imread`` (assumed HxWx3
                uint8 -- confirm for other sources). It is resized to
                ``input_size`` x ``input_size`` before inference.

        Returns:
            ``uint8`` array at the network resolution, taken from the model's
            first output, squeezed, inverted (``1 - x``) and min-max stretched
            to 0..255. A constant output map yields all zeros. The caller is
            responsible for resizing it back to the source resolution.
        """
        img = cv2.resize(srcimg, dsize=(self.input_size, self.input_size))
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB).astype(np.float32)
        # Keep the arithmetic in float32 instead of promoting via the lists.
        mean = np.asarray(self.mean, dtype=np.float32)
        std = np.asarray(self.std, dtype=np.float32)
        img = (img / 255.0 - mean) / std
        # HWC -> CHW, add a batch axis of 1, and hand ORT a contiguous buffer.
        blob = np.ascontiguousarray(
            np.expand_dims(np.transpose(img, (2, 0, 1)), axis=0),
            dtype=np.float32)
        outs = self.net.run([self.output_name], {self.input_name: blob})
        return self._to_uint8_sketch(outs[0])

    @staticmethod
    def _to_uint8_sketch(output):
        """Invert a raw output map and min-max stretch it into uint8 0..255."""
        result = 1 - np.asarray(output, dtype=np.float32).squeeze()
        min_value = result.min()
        value_range = result.max() - min_value
        if value_range > 0:
            result = (result - min_value) / value_range
        else:
            # A constant map would divide by zero (NaN, then an undefined
            # uint8 cast); emit a blank image instead.
            result = np.zeros_like(result)
        return (result * 255).astype(np.uint8)


def main():
    """Command-line entry point: sketch one image and show input and output."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--imgpath", type=str, default='sample.jpg')
    parser.add_argument("--modelpath", type=str, default='u2net_portrait.onnx')
    args = parser.parse_args()

    srcimg = cv2.imread(args.imgpath)
    if srcimg is None:
        # cv2.imread reports a missing/unreadable file by returning None
        # rather than raising; fail fast, before loading the model.
        parser.error(f"could not read image: {args.imgpath}")

    mynet = u2net(args.modelpath)
    result = mynet.detect(srcimg)
    # cv2.resize takes (width, height); restore the source resolution.
    result = cv2.resize(result, (srcimg.shape[1], srcimg.shape[0]))

    cv2.namedWindow('srcimg', cv2.WINDOW_NORMAL)
    cv2.imshow('srcimg', srcimg)
    winName = 'Deep learning object detection in onnxruntime'
    cv2.namedWindow(winName, cv2.WINDOW_NORMAL)
    cv2.imshow(winName, result)
    cv2.waitKey(0)
    cv2.destroyAllWindows()


if __name__ == '__main__':
    main()