├── README.md
└── onnxruntime
├── main.py
├── result.jpg
└── util.py
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | YOLOv9 ONNX Segmentation
4 | ===========================
5 |
6 | [Python 3.10](https://www.python.org/downloads/release/python-31012/)
7 | [License](https://github.com/spacewalk01/yolov9-onnx-segmentation/blob/main/LICENSE)
8 |
9 |
10 |
11 | Instance and panoptic segmentation using YOLOv9 with ONNX Runtime.
12 |
13 | ![Segmentation result](onnxruntime/result.jpg)
14 |
15 |
16 |
17 |
18 | ## 🚀 Quick Start
19 |
20 | Download [gelan-c-pan.pt](https://github.com/WongKinYiu/yolov9/releases/download/v0.1/gelan-c-pan.pt) and place it in the yolov9 repository cloned below.
21 |
22 | Prepare an ONNX model:
23 | ``` shell
24 | git clone https://github.com/WongKinYiu/yolov9
25 | cd yolov9 && pip install -r requirements.txt
26 | python export.py --weights gelan-c-pan.pt --include onnx
27 | ```
28 |
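To sanity-check the export, you can print the model's input and output shapes before running inference. A minimal sketch (it assumes the export produced `gelan-c-pan.onnx` in the current folder):

``` python
import onnxruntime

# Open the exported model on CPU and list its I/O signatures
session = onnxruntime.InferenceSession("gelan-c-pan.onnx", providers=["CPUExecutionProvider"])
for inp in session.get_inputs():
    print("input:", inp.name, inp.shape)
for out in session.get_outputs():
    print("output:", out.name, out.shape)
```
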
29 | Perform inference:
30 | ``` shell
31 | git clone https://github.com/spacewalk01/yolov9-onnx-segmentation.git
32 | cd yolov9-onnx-segmentation/onnxruntime
33 | python main.py --model <model.onnx> --input <image/video/folder>
34 | ```
35 |
36 | Example:
37 | ``` shell
38 | # infer an image
39 | python main.py --model gelan-c-pan.onnx --input test.jpg
40 | # infer a folder of images
41 | python main.py --model gelan-c-pan.onnx --input folder
42 | # infer a video
43 | python main.py --model gelan-c-pan.onnx --input test.mp4
44 | ```
45 |
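The detector can also be driven from Python. A minimal sketch using the `YOLOSeg` class from `onnxruntime/main.py` (the model and image paths are placeholders):

``` python
import cv2
from main import YOLOSeg

# Load the ONNX model; CUDA is used when available, CPU otherwise
yoloseg = YOLOSeg("gelan-c-pan.onnx", conf_thres=0.3, iou_thres=0.5)

img = cv2.imread("test.jpg")
boxes, scores, class_ids, mask_maps = yoloseg(img)

# Overlay masks and labeled boxes on the input image
cv2.imwrite("result.jpg", yoloseg.draw_masks(img))
```
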
46 | ## 👏 Acknowledgement
47 |
48 | This project is based on the following projects:
49 | - [YOLOv9](https://github.com/WongKinYiu/yolov9) - YOLOv9: Learning What You Want to Learn Using Programmable Gradient Information.
50 | - [ONNX-YOLOv8-Instance-Segmentation](https://github.com/ibaiGorordo/ONNX-YOLOv8-Instance-Segmentation) - Python scripts performing Instance Segmentation using the YOLOv8 model in ONNX.
51 |
--------------------------------------------------------------------------------
/onnxruntime/main.py:
--------------------------------------------------------------------------------
1 | import math
2 | import time
3 | import cv2
4 | import numpy as np
5 | import onnxruntime
6 | import argparse
7 | import os
8 | from util import xywh2xyxy, nms, draw_detections, sigmoid
9 |
10 |
11 | class YOLOSeg:
12 |
13 | def __init__(self, path, conf_thres=0.7, iou_thres=0.5, num_masks=32):
14 | self.conf_threshold = conf_thres
15 | self.iou_threshold = iou_thres
16 | self.num_masks = num_masks
17 |
18 | # Initialize model
19 | self.initialize_model(path)
20 |
21 | def __call__(self, image):
22 | return self.segment_objects(image)
23 |
24 | def initialize_model(self, path):
25 | self.session = onnxruntime.InferenceSession(path,
26 | providers=['CUDAExecutionProvider',
27 | 'CPUExecutionProvider'])
28 | # Get model info
29 | self.get_input_details()
30 | self.get_output_details()
31 |
32 | def segment_objects(self, image):
33 | input_tensor = self.prepare_input(image)
34 |
35 | # Perform inference on the image
36 | outputs = self.inference(input_tensor)
37 |
38 | self.boxes, self.scores, self.class_ids, mask_pred = self.process_box_output(outputs[0])
39 | self.mask_maps = self.process_mask_output(mask_pred, outputs[1])
40 |
41 | return self.boxes, self.scores, self.class_ids, self.mask_maps
42 |
43 | def prepare_input(self, image):
44 | self.img_height, self.img_width = image.shape[:2]
45 |
46 | input_img = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
47 |
48 | # Resize input image
49 | input_img = cv2.resize(input_img, (self.input_width, self.input_height))
50 |
51 | # Scale input pixel values to 0 to 1
52 | input_img = input_img / 255.0
53 | input_img = input_img.transpose(2, 0, 1)
54 | input_tensor = input_img[np.newaxis, :, :, :].astype(np.float32)
55 |
56 | return input_tensor
57 |
58 | def inference(self, input_tensor):
59 | start = time.perf_counter()
60 | outputs = self.session.run(self.output_names, {self.input_names[0]: input_tensor})
61 |
62 | # print(f"Inference time: {(time.perf_counter() - start)*1000:.2f} ms")
63 | return outputs
64 |
65 | def process_box_output(self, box_output):
66 |
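        # box_output comes in as (1, 4 + num_classes + num_masks, num_anchors); squeeze the
        # batch dim and transpose so each row is one anchor's (box, class scores, mask coeffs)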
67 | predictions = np.squeeze(box_output).T
68 | num_classes = box_output.shape[1] - self.num_masks - 4
69 |
70 | # Filter out object confidence scores below threshold
71 | scores = np.max(predictions[:, 4:4+num_classes], axis=1)
72 | predictions = predictions[scores > self.conf_threshold, :]
73 | scores = scores[scores > self.conf_threshold]
74 |
75 | if len(scores) == 0:
76 | return [], [], [], np.array([])
77 |
78 | box_predictions = predictions[..., :num_classes+4]
79 | mask_predictions = predictions[..., num_classes+4:]
80 |
81 | # Get the class with the highest confidence
82 | class_ids = np.argmax(box_predictions[:, 4:], axis=1)
83 |
84 | # Get bounding boxes for each object
85 | boxes = self.extract_boxes(box_predictions)
86 |
87 | # Apply non-maxima suppression to suppress weak, overlapping bounding boxes
88 | indices = nms(boxes, scores, self.iou_threshold)
89 |
90 | return boxes[indices], scores[indices], class_ids[indices], mask_predictions[indices]
91 |
92 | def process_mask_output(self, mask_predictions, mask_output):
93 |
94 | if mask_predictions.shape[0] == 0:
95 | return []
96 |
97 | mask_output = np.squeeze(mask_output)
98 |
99 | # Calculate the mask maps for each box
100 | num_mask, mask_height, mask_width = mask_output.shape # CHW
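        # Each detection carries num_masks coefficients; multiplying them with the flattened
        # prototype masks (num_mask, H*W) yields one low-resolution mask per detection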
101 | masks = sigmoid(mask_predictions @ mask_output.reshape((num_mask, -1)))
102 | masks = masks.reshape((-1, mask_height, mask_width))
103 |
104 | # Downscale the boxes to match the mask size
105 | scale_boxes = self.rescale_boxes(self.boxes,
106 | (self.img_height, self.img_width),
107 | (mask_height, mask_width))
108 |
109 | # For every box/mask pair, get the mask map
110 | mask_maps = np.zeros((len(scale_boxes), self.img_height, self.img_width))
111 | blur_size = (int(self.img_width / mask_width), int(self.img_height / mask_height))
112 | for i in range(len(scale_boxes)):
113 |
114 | scale_x1 = int(math.floor(scale_boxes[i][0]))
115 | scale_y1 = int(math.floor(scale_boxes[i][1]))
116 | scale_x2 = int(math.ceil(scale_boxes[i][2]))
117 | scale_y2 = int(math.ceil(scale_boxes[i][3]))
118 |
119 | x1 = int(math.floor(self.boxes[i][0]))
120 | y1 = int(math.floor(self.boxes[i][1]))
121 | x2 = int(math.ceil(self.boxes[i][2]))
122 | y2 = int(math.ceil(self.boxes[i][3]))
123 |
124 | scale_crop_mask = masks[i][scale_y1:scale_y2, scale_x1:scale_x2]
125 | crop_mask = cv2.resize(scale_crop_mask,
126 | (x2 - x1, y2 - y1),
127 | interpolation=cv2.INTER_CUBIC)
128 |
129 | crop_mask = cv2.blur(crop_mask, blur_size)
130 |
131 | crop_mask = (crop_mask > 0.5).astype(np.uint8)
132 | mask_maps[i, y1:y2, x1:x2] = crop_mask
133 |
134 | return mask_maps
135 |
136 | def extract_boxes(self, box_predictions):
137 | # Extract boxes from predictions
138 | boxes = box_predictions[:, :4]
139 |
140 | # Scale boxes to original image dimensions
141 | boxes = self.rescale_boxes(boxes,
142 | (self.input_height, self.input_width),
143 | (self.img_height, self.img_width))
144 |
145 | # Convert boxes to xyxy format
146 | boxes = xywh2xyxy(boxes)
147 |
148 | # Check the boxes are within the image
149 | boxes[:, 0] = np.clip(boxes[:, 0], 0, self.img_width)
150 | boxes[:, 1] = np.clip(boxes[:, 1], 0, self.img_height)
151 | boxes[:, 2] = np.clip(boxes[:, 2], 0, self.img_width)
152 | boxes[:, 3] = np.clip(boxes[:, 3], 0, self.img_height)
153 |
154 | return boxes
155 |
156 | def draw_detections(self, image, draw_scores=True, mask_alpha=0.4):
157 | return draw_detections(image, self.boxes, self.scores,
158 | self.class_ids, mask_alpha)
159 |
160 | def draw_masks(self, image, draw_scores=True, mask_alpha=0.5):
161 | return draw_detections(image, self.boxes, self.scores,
162 | self.class_ids, mask_alpha, mask_maps=self.mask_maps)
163 |
164 | def get_input_details(self):
165 | model_inputs = self.session.get_inputs()
166 | self.input_names = [model_inputs[i].name for i in range(len(model_inputs))]
167 |
168 | self.input_shape = model_inputs[0].shape
169 | self.input_height = self.input_shape[2]
170 | self.input_width = self.input_shape[3]
171 |
172 | def get_output_details(self):
173 | model_outputs = self.session.get_outputs()
174 | self.output_names = [model_outputs[i].name for i in range(len(model_outputs))]
175 |
176 | @staticmethod
177 | def rescale_boxes(boxes, input_shape, image_shape):
178 | # Rescale boxes to original image dimensions
179 | input_shape = np.array([input_shape[1], input_shape[0], input_shape[1], input_shape[0]])
180 | boxes = np.divide(boxes, input_shape, dtype=np.float32)
181 | boxes *= np.array([image_shape[1], image_shape[0], image_shape[1], image_shape[0]])
182 |
183 | return boxes
184 |
185 |
186 |
187 | def is_image(file_path):
188 | """Check if the given path points to an image."""
189 | image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tiff']
190 | return any(file_path.lower().endswith(ext) for ext in image_extensions)
191 |
192 | def is_video(file_path):
193 | """Check if the given path points to a video."""
194 | video_extensions = ['.mp4', '.avi', '.mkv', '.mov', '.wmv']
195 | return any(file_path.lower().endswith(ext) for ext in video_extensions)
196 |
197 | def detect_input_type(input_path):
198 | """Detect the type of input based on the provided path."""
199 | if os.path.isfile(input_path):
200 | if is_image(input_path):
201 | return 'image'
202 | elif is_video(input_path):
203 | return 'video'
204 | else:
205 | return None # Not an image or video
206 | elif os.path.isdir(input_path):
207 | return 'folder'
208 | else:
209 | return None # Not a valid file or directory
210 |
211 | if __name__ == '__main__':
212 | # Initialize argument parser
213 |     parser = argparse.ArgumentParser(description='Instance segmentation using YOLOv9 with ONNX Runtime')
214 | parser.add_argument('--model', type=str, default="gelan-c-pan.onnx", help='Path to the ONNX model')
215 |     parser.add_argument('--input', type=str, required=True, help='Path to an input image, video file, or folder of images')
216 | args = parser.parse_args()
217 |
218 | # Initialize YOLOv9 Instance Segmentator
219 | yoloseg = YOLOSeg(args.model, conf_thres=0.3, iou_thres=0.5)
220 |
221 | # Create an output folder
222 | output_folder = "results"
223 | if not os.path.exists(output_folder):
224 | os.makedirs(output_folder)
225 |
226 | input_type = detect_input_type(args.input)
227 | if input_type == 'image':
228 | img = cv2.imread(args.input)
229 | print("image: ", args.input)
230 | # Detect objects in the image
231 | yoloseg(img)
232 | # Draw detections
233 | combined_img = yoloseg.draw_masks(img)
234 | output_path = os.path.join(output_folder, "result.jpg")
235 | cv2.imwrite(output_path, combined_img)
236 | cv2.imshow("Output", combined_img)
237 | cv2.waitKey(0)
238 | elif input_type == 'folder':
239 | # Loop through image files in the given folder
240 | for filename in os.listdir(args.input):
241 |             if is_image(filename):
242 | img_path = os.path.join(args.input, filename)
243 | print("folder:", img_path)
244 | img = cv2.imread(img_path)
245 | # Detect objects in the image
246 | yoloseg(img)
247 | # Draw detections
248 | combined_img = yoloseg.draw_masks(img)
249 | output_path = os.path.join(output_folder, filename)
250 | cv2.imwrite(output_path, combined_img)
251 | cv2.imshow("Output", combined_img)
252 | cv2.waitKey(0)
253 | elif input_type == 'video':
254 |         cap = cv2.VideoCapture(args.input)
255 | frame_width = round(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) # Frame width
256 | frame_height = round(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) # Frame height
257 | fps = int(cap.get(cv2.CAP_PROP_FPS)) # Frames per second
258 |         out = cv2.VideoWriter(os.path.join(output_folder, "result.mp4"), cv2.VideoWriter_fourcc(*'mp4v'), fps, (frame_width, frame_height))  # 'mp4v' matches the .mp4 container
259 | while True:
260 | ret, frame = cap.read()
261 | if not ret:
262 | break
263 | # Detect objects in the frame
264 | yoloseg(frame)
265 | # Draw detections
266 | combined_frame = yoloseg.draw_masks(frame)
267 | out.write(combined_frame)
268 | cv2.imshow("Output", combined_frame)
269 | if cv2.waitKey(1) & 0xFF == ord('q'):
270 | break
271 | cap.release()
272 | out.release()
273 |         cv2.destroyAllWindows()
274 |     else:
275 |         print("Invalid input:", args.input, "- expected an image, a video, or a folder")
276 |
--------------------------------------------------------------------------------
/onnxruntime/result.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spacewalk01/yolov9-onnx-segmentation/3a8916f312b3086ed72ce232850aa9d738ea2c3b/onnxruntime/result.jpg
--------------------------------------------------------------------------------
/onnxruntime/util.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | import requests
4 | from io import BytesIO
5 |
6 | class_names = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
7 | 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
8 | 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
9 | 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
10 | 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
11 | 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
12 | 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard',
13 | 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase',
14 | 'scissors', 'teddy bear', 'hair drier', 'toothbrush']
15 |
16 | # Create a list of colors for each class where each color is a tuple of 3 integer values
17 | rng = np.random.default_rng(3)
18 | colors = rng.uniform(0, 255, size=(len(class_names), 3))
19 |
20 | def imread_from_url(url):
21 | try:
22 | # Send a GET request to the URL to download the image
23 | response = requests.get(url)
24 |
25 | # Check if the request was successful
26 | if response.status_code == 200:
27 | # Read the image from the response content
28 | image_bytes = BytesIO(response.content)
29 | image_array = np.asarray(bytearray(image_bytes.read()), dtype=np.uint8)
30 | image = cv2.imdecode(image_array, cv2.IMREAD_COLOR)
31 |
32 | return image
33 | else:
34 | print("Failed to download image. Status code:", response.status_code)
35 | return None
36 | except Exception as e:
37 | print("An error occurred:", e)
38 | return None
39 |
40 | def nms(boxes, scores, iou_threshold):
41 | # Sort by score
42 | sorted_indices = np.argsort(scores)[::-1]
43 |
44 | keep_boxes = []
45 | while sorted_indices.size > 0:
46 |         # Pick the box with the highest remaining score
47 | box_id = sorted_indices[0]
48 | keep_boxes.append(box_id)
49 |
50 | # Compute IoU of the picked box with the rest
51 | ious = compute_iou(boxes[box_id, :], boxes[sorted_indices[1:], :])
52 |
53 | # Remove boxes with IoU over the threshold
54 | keep_indices = np.where(ious < iou_threshold)[0]
55 |
57 | sorted_indices = sorted_indices[keep_indices + 1]
58 |
59 | return keep_boxes
60 |
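# A tiny worked example of nms() on hypothetical boxes (for illustration only):
#   boxes  = np.array([[0, 0, 10, 10], [1, 1, 11, 11], [20, 20, 30, 30]], dtype=float)
#   scores = np.array([0.9, 0.8, 0.7])
#   nms(boxes, scores, iou_threshold=0.5)  # -> [0, 2]
# Box 1 overlaps box 0 with IoU ~0.68 (> 0.5), so it is suppressed.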
61 |
62 | def compute_iou(box, boxes):
63 | # Compute xmin, ymin, xmax, ymax for both boxes
64 | xmin = np.maximum(box[0], boxes[:, 0])
65 | ymin = np.maximum(box[1], boxes[:, 1])
66 | xmax = np.minimum(box[2], boxes[:, 2])
67 | ymax = np.minimum(box[3], boxes[:, 3])
68 |
69 | # Compute intersection area
70 | intersection_area = np.maximum(0, xmax - xmin) * np.maximum(0, ymax - ymin)
71 |
72 | # Compute union area
73 | box_area = (box[2] - box[0]) * (box[3] - box[1])
74 | boxes_area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
75 | union_area = box_area + boxes_area - intersection_area
76 |
77 | # Compute IoU
78 | iou = intersection_area / union_area
79 |
80 | return iou
81 |
82 |
83 | def xywh2xyxy(x):
84 | # Convert bounding box (x, y, w, h) to bounding box (x1, y1, x2, y2)
85 | y = np.copy(x)
86 | y[..., 0] = x[..., 0] - x[..., 2] / 2
87 | y[..., 1] = x[..., 1] - x[..., 3] / 2
88 | y[..., 2] = x[..., 0] + x[..., 2] / 2
89 | y[..., 3] = x[..., 1] + x[..., 3] / 2
90 | return y
91 |
92 |
93 | def sigmoid(x):
94 | return 1 / (1 + np.exp(-x))
95 |
96 |
97 | def draw_detections(image, boxes, scores, class_ids, mask_alpha=0.3, mask_maps=None):
98 | img_height, img_width = image.shape[:2]
99 | size = min([img_height, img_width]) * 0.0006
100 | text_thickness = int(min([img_height, img_width]) * 0.001)
101 |
102 | mask_img = draw_masks(image, boxes, class_ids, mask_alpha, mask_maps)
103 |
104 | # Draw bounding boxes and labels of detections
105 | for box, score, class_id in zip(boxes, scores, class_ids):
106 | color = colors[class_id]
107 |
108 | x1, y1, x2, y2 = box.astype(int)
109 |
110 | # Draw rectangle
111 | cv2.rectangle(mask_img, (x1, y1), (x2, y2), color, 2)
112 |
113 | label = class_names[class_id]
114 | caption = f'{label} {int(score * 100)}%'
115 | (tw, th), _ = cv2.getTextSize(text=caption, fontFace=cv2.FONT_HERSHEY_SIMPLEX,
116 | fontScale=size, thickness=text_thickness)
117 | th = int(th * 1.2)
118 |
119 | cv2.rectangle(mask_img, (x1, y1),
120 | (x1 + tw, y1 - th), color, -1)
121 |
122 | cv2.putText(mask_img, caption, (x1, y1),
123 | cv2.FONT_HERSHEY_SIMPLEX, size, (255, 255, 255), text_thickness, cv2.LINE_AA)
124 |
125 | return mask_img
126 |
127 |
128 | def draw_masks(image, boxes, class_ids, mask_alpha=0.3, mask_maps=None):
129 | mask_img = image.copy()
130 |
131 | # Draw bounding boxes and labels of detections
132 | for i, (box, class_id) in enumerate(zip(boxes, class_ids)):
133 | color = colors[class_id]
134 |
135 | x1, y1, x2, y2 = box.astype(int)
136 |
137 | # Draw fill mask image
138 | if mask_maps is None:
139 | cv2.rectangle(mask_img, (x1, y1), (x2, y2), color, -1)
140 | else:
141 | crop_mask = mask_maps[i][y1:y2, x1:x2, np.newaxis]
142 | crop_mask_img = mask_img[y1:y2, x1:x2]
143 | crop_mask_img = crop_mask_img * (1 - crop_mask) + crop_mask * color
144 | mask_img[y1:y2, x1:x2] = crop_mask_img
145 |
146 | return cv2.addWeighted(mask_img, mask_alpha, image, 1 - mask_alpha, 0)
147 |
148 |
149 | def draw_comparison(img1, img2, name1, name2, fontsize=2.6, text_thickness=3):
150 | (tw, th), _ = cv2.getTextSize(text=name1, fontFace=cv2.FONT_HERSHEY_DUPLEX,
151 | fontScale=fontsize, thickness=text_thickness)
152 | x1 = img1.shape[1] // 3
153 | y1 = th
154 | offset = th // 5
155 | cv2.rectangle(img1, (x1 - offset * 2, y1 + offset),
156 | (x1 + tw + offset * 2, y1 - th - offset), (0, 115, 255), -1)
157 | cv2.putText(img1, name1,
158 | (x1, y1),
159 | cv2.FONT_HERSHEY_DUPLEX, fontsize,
160 | (255, 255, 255), text_thickness)
161 |
162 | (tw, th), _ = cv2.getTextSize(text=name2, fontFace=cv2.FONT_HERSHEY_DUPLEX,
163 | fontScale=fontsize, thickness=text_thickness)
164 | x1 = img2.shape[1] // 3
165 | y1 = th
166 | offset = th // 5
167 | cv2.rectangle(img2, (x1 - offset * 2, y1 + offset),
168 | (x1 + tw + offset * 2, y1 - th - offset), (94, 23, 235), -1)
169 |
170 | cv2.putText(img2, name2,
171 | (x1, y1),
172 | cv2.FONT_HERSHEY_DUPLEX, fontsize,
173 | (255, 255, 255), text_thickness)
174 |
175 | combined_img = cv2.hconcat([img1, img2])
176 | if combined_img.shape[1] > 3840:
177 | combined_img = cv2.resize(combined_img, (3840, 2160))
178 |
179 | return combined_img
180 |
--------------------------------------------------------------------------------