├── README.md └── onnxruntime ├── main.py ├── result.jpg └── util.py /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | YOLOv9 ONNX Segmentation 4 | =========================== 5 | 6 | [![python](https://img.shields.io/badge/python-3.10.12-green)](https://www.python.org/downloads/release/python-31012/) 7 | [![mit](https://img.shields.io/badge/license-MIT-blue)](https://github.com/spacewalk01/depth-anything-tensorrt/blob/main/LICENSE) 8 | 9 |
10 | 11 | Instance and panoptic segmentation using YOLOv9 with ONNX Runtime. 12 | 13 |

14 | 15 |

16 | 17 | 18 | ## 🚀 Quick Start 19 | 20 | Download [gelan-c-pan.pt](https://github.com/WongKinYiu/yolov9/releases/download/v0.1/gelan-c-pan.pt) 21 | 22 | Prepare an ONNX model (place gelan-c-pan.pt inside the cloned yolov9 directory): 23 | ``` shell 24 | git clone https://github.com/WongKinYiu/yolov9 && cd yolov9 25 | pip install -r requirements.txt 26 | python export.py --weights gelan-c-pan.pt --include onnx 27 | ``` 28 | 29 | Perform inference: 30 | ``` shell 31 | git clone https://github.com/spacewalk01/yolov9-onnx-segmentation.git 32 | cd yolov9-onnx-segmentation/onnxruntime 33 | python main.py --model <path/to/model.onnx> --input <path to an image, a folder of images, or a video> 34 | ``` 35 | 36 | Examples: 37 | ``` shell 38 | # infer an image 39 | python main.py --model gelan-c-pan.onnx --input test.jpg 40 | # infer a folder of images 41 | python main.py --model gelan-c-pan.onnx --input folder 42 | # infer a video 43 | python main.py --model gelan-c-pan.onnx --input test.mp4 # the video path 44 | ``` 45 | 46 | ## 👏 Acknowledgement 47 | 48 | This project is based on the following projects: 49 | - [YOLOv9](https://github.com/WongKinYiu/yolov9) - YOLOv9: Learning What You Want to Learn Using Programmable Gradient Information. 50 | - [ONNX-YOLOv8-Instance-Segmentation](https://github.com/ibaiGorordo/ONNX-YOLOv8-Instance-Segmentation) - Python scripts performing Instance Segmentation using the YOLOv8 model in ONNX. 
# 51 |
# --------------------------------------------------------------------------------
# /onnxruntime/main.py:
# --------------------------------------------------------------------------------
import math
import time
import cv2
import numpy as np
import onnxruntime
import argparse
import os
from util import xywh2xyxy, nms, draw_detections, sigmoid, imread_from_url


class YOLOSeg:
    """YOLOv9 instance segmentation on top of an ONNX Runtime session.

    Results of the most recent call are cached on the instance as
    ``boxes``, ``scores``, ``class_ids`` and ``mask_maps`` so that the
    ``draw_*`` helpers can render them afterwards.
    """

    def __init__(self, path, conf_thres=0.7, iou_thres=0.5, num_masks=32):
        """Load the ONNX model at *path*.

        Args:
            path: Path to the exported ONNX segmentation model.
            conf_thres: Minimum best-class score for a prediction to be kept.
            iou_thres: IoU above which a lower-scoring box is suppressed by NMS.
            num_masks: Number of mask coefficients per prediction; must match
                the number of prototype channels in the model's second output.
        """
        self.conf_threshold = conf_thres
        self.iou_threshold = iou_thres
        self.num_masks = num_masks

        # Initialize model
        self.initialize_model(path)

    def __call__(self, image):
        return self.segment_objects(image)

    def initialize_model(self, path):
        # Providers are listed in priority order: CUDA first, then CPU.
        self.session = onnxruntime.InferenceSession(path,
                                                    providers=['CUDAExecutionProvider',
                                                               'CPUExecutionProvider'])
        # Get model info
        self.get_input_details()
        self.get_output_details()

    def segment_objects(self, image):
        """Run detection and mask decoding on a BGR image.

        Returns:
            (boxes, scores, class_ids, mask_maps): boxes are xyxy in
            original-image pixels; mask_maps holds one binary
            (img_height, img_width) map per box.
        """
        input_tensor = self.prepare_input(image)

        # Perform inference on the image
        outputs = self.inference(input_tensor)

        self.boxes, self.scores, self.class_ids, mask_pred = self.process_box_output(outputs[0])
        self.mask_maps = self.process_mask_output(mask_pred, outputs[1])

        return self.boxes, self.scores, self.class_ids, self.mask_maps

    def prepare_input(self, image):
        """Convert a BGR image into a 1xCxHxW float32 tensor in [0, 1]."""
        self.img_height, self.img_width = image.shape[:2]

        input_img = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Plain resize (no letterboxing), so boxes can be rescaled per axis later.
        input_img = cv2.resize(input_img, (self.input_width, self.input_height))

        # Scale input pixel values to 0 to 1 and switch HWC -> CHW
        input_img = input_img / 255.0
        input_img = input_img.transpose(2, 0, 1)
        input_tensor = input_img[np.newaxis, :, :, :].astype(np.float32)

        return input_tensor

    def inference(self, input_tensor):
        """Run the session; the wall-clock time is kept in ``last_inference_ms``."""
        start = time.perf_counter()
        outputs = self.session.run(self.output_names, {self.input_names[0]: input_tensor})
        self.last_inference_ms = (time.perf_counter() - start) * 1000
        return outputs

    def process_box_output(self, box_output):
        """Decode the detection head into filtered, NMS-suppressed detections.

        ``box_output`` is laid out as (1, 4 + num_classes + num_masks, N):
        for each of N predictions, a centre-x/centre-y/width/height box in
        input-tensor pixels, then per-class scores, then mask coefficients.

        Returns:
            (boxes, scores, class_ids, mask_predictions); four empty values
            when nothing clears the confidence threshold.
        """
        # Drop only the batch axis: a bare squeeze would also collapse the
        # prediction axis when N == 1.
        predictions = np.squeeze(box_output, axis=0).T
        num_classes = box_output.shape[1] - self.num_masks - 4

        # Filter out object confidence scores below threshold
        scores = np.max(predictions[:, 4:4 + num_classes], axis=1)
        keep = scores > self.conf_threshold
        predictions = predictions[keep, :]
        scores = scores[keep]

        if len(scores) == 0:
            return [], [], [], np.array([])

        box_predictions = predictions[..., :num_classes + 4]
        mask_predictions = predictions[..., num_classes + 4:]

        # Get the class with the highest confidence
        class_ids = np.argmax(box_predictions[:, 4:], axis=1)

        # Get bounding boxes for each object
        boxes = self.extract_boxes(box_predictions)

        # Class-agnostic non-maximum suppression of overlapping boxes
        indices = nms(boxes, scores, self.iou_threshold)

        return boxes[indices], scores[indices], class_ids[indices], mask_predictions[indices]

    def process_mask_output(self, mask_predictions, mask_output):
        """Build one full-resolution binary mask per kept box.

        Each mask is the sigmoid of (coefficients @ prototypes), cropped to
        its box in prototype space, resized to the box in image space,
        box-blurred and thresholded at 0.5. Reads ``self.boxes``, so it must
        run after :meth:`process_box_output`.
        """
        if mask_predictions.shape[0] == 0:
            return []

        mask_output = np.squeeze(mask_output, axis=0)

        # Calculate the mask maps for each box
        num_mask, mask_height, mask_width = mask_output.shape  # CHW
        masks = sigmoid(mask_predictions @ mask_output.reshape((num_mask, -1)))
        masks = masks.reshape((-1, mask_height, mask_width))

        # Downscale the boxes to match the mask size
        scale_boxes = self.rescale_boxes(self.boxes,
                                         (self.img_height, self.img_width),
                                         (mask_height, mask_width))

        mask_maps = np.zeros((len(scale_boxes), self.img_height, self.img_width))
        # Kernel about one prototype cell wide; clamped to >= 1 so cv2.blur never
        # gets an empty kernel when the image is smaller than the prototype grid.
        blur_size = (max(1, int(self.img_width / mask_width)),
                     max(1, int(self.img_height / mask_height)))
        for i, (scale_box, box) in enumerate(zip(scale_boxes, self.boxes)):
            scale_x1 = int(math.floor(scale_box[0]))
            scale_y1 = int(math.floor(scale_box[1]))
            scale_x2 = int(math.ceil(scale_box[2]))
            scale_y2 = int(math.ceil(scale_box[3]))

            x1 = int(math.floor(box[0]))
            y1 = int(math.floor(box[1]))
            x2 = int(math.ceil(box[2]))
            y2 = int(math.ceil(box[3]))

            scale_crop_mask = masks[i][scale_y1:scale_y2, scale_x1:scale_x2]
            # Boxes clipped to zero width/height cannot be resized; leave them empty.
            if scale_crop_mask.size == 0 or x2 <= x1 or y2 <= y1:
                continue

            crop_mask = cv2.resize(scale_crop_mask,
                                   (x2 - x1, y2 - y1),
                                   interpolation=cv2.INTER_CUBIC)
            crop_mask = cv2.blur(crop_mask, blur_size)
            mask_maps[i, y1:y2, x1:x2] = (crop_mask > 0.5).astype(np.uint8)

        return mask_maps

    def extract_boxes(self, box_predictions):
        """Convert centre-xywh boxes in input pixels to clipped xyxy image pixels."""
        boxes = box_predictions[:, :4]

        # Scale boxes to original image dimensions
        boxes = self.rescale_boxes(boxes,
                                   (self.input_height, self.input_width),
                                   (self.img_height, self.img_width))

        # Convert boxes to xyxy format
        boxes = xywh2xyxy(boxes)

        # Keep the boxes within the image
        boxes[:, [0, 2]] = np.clip(boxes[:, [0, 2]], 0, self.img_width)
        boxes[:, [1, 3]] = np.clip(boxes[:, [1, 3]], 0, self.img_height)

        return boxes

    def draw_detections(self, image, draw_scores=True, mask_alpha=0.4):
        """Draw boxes with filled rectangles. ``draw_scores`` is unused (API compat)."""
        return draw_detections(image, self.boxes, self.scores,
                               self.class_ids, mask_alpha)

    def draw_masks(self, image, draw_scores=True, mask_alpha=0.5):
        """Draw boxes and segmentation masks. ``draw_scores`` is unused (API compat)."""
        return draw_detections(image, self.boxes, self.scores,
                               self.class_ids, mask_alpha, mask_maps=self.mask_maps)

    def get_input_details(self):
        model_inputs = self.session.get_inputs()
        self.input_names = [model_input.name for model_input in model_inputs]

        # Indexing assumes an NCHW input with static H and W.
        # NOTE(review): dynamic axes would come back as strings here; confirm the
        # model was exported with a fixed input size.
        self.input_shape = model_inputs[0].shape
        self.input_height = self.input_shape[2]
        self.input_width = self.input_shape[3]

    def get_output_details(self):
        model_outputs = self.session.get_outputs()
        self.output_names = [model_output.name for model_output in model_outputs]

    @staticmethod
    def rescale_boxes(boxes, input_shape, image_shape):
        """Rescale (x, y, x/w, y/h) columns from an (h, w) frame to another (h, w) frame."""
        input_shape = np.array([input_shape[1], input_shape[0], input_shape[1], input_shape[0]])
        boxes = np.divide(boxes, input_shape, dtype=np.float32)
        boxes *= np.array([image_shape[1], image_shape[0], image_shape[1], image_shape[0]])

        return boxes


def is_image(file_path):
    """Check if the given path points to an image."""
    image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tiff')
    return file_path.lower().endswith(image_extensions)


def is_video(file_path):
    """Check if the given path points to a video."""
    video_extensions = ('.mp4', '.avi', '.mkv', '.mov', '.wmv')
    return file_path.lower().endswith(video_extensions)


def detect_input_type(input_path):
    """Return 'image', 'video', 'folder', or None for an unsupported/missing path."""
    if os.path.isfile(input_path):
        if is_image(input_path):
            return 'image'
        if is_video(input_path):
            return 'video'
        return None  # Not an image or video
    if os.path.isdir(input_path):
        return 'folder'
    return None  # Not a valid file or directory


def _segment_and_draw(yoloseg, img):
    """Run *yoloseg* on *img* and return a copy with masks and boxes drawn."""
    yoloseg(img)
    return yoloseg.draw_masks(img)


def _run_image(yoloseg, input_path, output_folder):
    """Segment a single image, save it as result.jpg and display it."""
    img = cv2.imread(input_path)
    if img is None:  # cv2.imread signals failure by returning None
        print("Failed to read image:", input_path)
        return
    print("image: ", input_path)
    combined_img = _segment_and_draw(yoloseg, img)
    cv2.imwrite(os.path.join(output_folder, "result.jpg"), combined_img)
    cv2.imshow("Output", combined_img)
    cv2.waitKey(0)


def _run_folder(yoloseg, input_path, output_folder):
    """Segment every readable image file in a folder, saving each under its own name."""
    for filename in sorted(os.listdir(input_path)):
        img_path = os.path.join(input_path, filename)
        if not (is_image(filename) and os.path.isfile(img_path)):
            continue
        print("folder:", img_path)
        img = cv2.imread(img_path)
        if img is None:
            print("Failed to read image:", img_path)
            continue
        combined_img = _segment_and_draw(yoloseg, img)
        cv2.imwrite(os.path.join(output_folder, filename), combined_img)
        cv2.imshow("Output", combined_img)
        cv2.waitKey(0)


def _run_video(yoloseg, input_path, output_folder):
    """Segment a video frame by frame into result.mp4; press 'q' to stop early."""
    cap = cv2.VideoCapture(input_path)
    if not cap.isOpened():
        print("Failed to open video:", input_path)
        return
    frame_width = round(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = round(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    # Keep fractional rates such as 29.97; some streams report 0 (or NaN),
    # which VideoWriter cannot use, so fall back to 30 FPS.
    fps = cap.get(cv2.CAP_PROP_FPS)
    if not fps > 0:
        fps = 30.0
    # 'mp4v' is an MPEG-4 tag accepted by the .mp4 container ('DIVX' is not).
    out = cv2.VideoWriter(os.path.join(output_folder, "result.mp4"),
                          cv2.VideoWriter_fourcc(*'mp4v'), fps, (frame_width, frame_height))
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            combined_frame = _segment_and_draw(yoloseg, frame)
            out.write(combined_frame)
            cv2.imshow("Output", combined_frame)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        cap.release()
        out.release()


def main():
    """Command-line entry point: segment an image, a folder of images, or a video."""
    parser = argparse.ArgumentParser(description='Object detection using YOLOv9')
    parser.add_argument('--model', type=str, default="gelan-c-pan.onnx", help='Path to the ONNX model')
    parser.add_argument('--input', type=str, default="video",
                        help='Path to an image file, a folder of images, or a video file')
    args = parser.parse_args()

    # Validate before loading the model so a bad path fails fast and loudly.
    input_type = detect_input_type(args.input)
    if input_type is None:
        parser.error(f"--input must be an existing image, video, or folder: {args.input!r}")

    # Initialize YOLOv9 Instance Segmentator
    yoloseg = YOLOSeg(args.model, conf_thres=0.3, iou_thres=0.5)

    # Create an output folder
    output_folder = "results"
    os.makedirs(output_folder, exist_ok=True)

    runners = {'image': _run_image, 'folder': _run_folder, 'video': _run_video}
    runners[input_type](yoloseg, args.input, output_folder)
    cv2.destroyAllWindows()


if __name__ == '__main__':
    main()

# --------------------------------------------------------------------------------
# /onnxruntime/result.jpg:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/spacewalk01/yolov9-onnx-segmentation/3a8916f312b3086ed72ce232850aa9d738ea2c3b/onnxruntime/result.jpg
# --------------------------------------------------------------------------------
# /onnxruntime/util.py:
# --------------------------------------------------------------------------------
import numpy as np
import cv2
import requests
from io import BytesIO

class_names = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
               'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
               'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
               'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
               'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
               'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
               'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard',
               'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase',
               'scissors', 'teddy bear', 'hair drier', 'toothbrush']

# One pseudo-random colour per class (3 floats in [0, 255)); the fixed seed
# keeps colours stable between runs.
rng = np.random.default_rng(3)
colors = rng.uniform(0, 255, size=(len(class_names), 3))


def _class_color(class_id):
    """Colour for *class_id*, cycling through the palette for ids beyond it."""
    return colors[class_id % len(colors)]


def _class_label(class_id):
    """Name for *class_id*, or a generic label when the model has extra classes."""
    if 0 <= class_id < len(class_names):
        return class_names[class_id]
    return f'class {class_id}'


def imread_from_url(url, timeout=10):
    """Download an image and decode it with OpenCV.

    Args:
        url: Image URL.
        timeout: Seconds to wait for the HTTP request, so an unresponsive
            server cannot block the caller forever.

    Returns:
        The decoded image, or None if the download or decoding fails.
    """
    try:
        # Send a GET request to the URL to download the image
        response = requests.get(url, timeout=timeout)

        if response.status_code != 200:
            print("Failed to download image. Status code:", response.status_code)
            return None

        # Read the image from the response content
        image_bytes = BytesIO(response.content)
        image_array = np.asarray(bytearray(image_bytes.read()), dtype=np.uint8)
        return cv2.imdecode(image_array, cv2.IMREAD_COLOR)
    except Exception as e:  # deliberate best-effort: any failure -> None
        print("An error occurred:", e)
        return None


def nms(boxes, scores, iou_threshold):
    """Greedy non-maximum suppression over xyxy *boxes*.

    Returns the indices of kept boxes, ordered by descending score.
    """
    # Sort by score, highest first
    sorted_indices = np.argsort(scores)[::-1]

    keep_boxes = []
    while sorted_indices.size > 0:
        # Pick the highest-scoring remaining box
        box_id = sorted_indices[0]
        keep_boxes.append(box_id)

        # Compute IoU of the picked box with the rest
        ious = compute_iou(boxes[box_id, :], boxes[sorted_indices[1:], :])

        # Keep only boxes whose IoU is below the threshold (+1 skips the picked box)
        keep_indices = np.where(ious < iou_threshold)[0]
        sorted_indices = sorted_indices[keep_indices + 1]

    return keep_boxes


def compute_iou(box, boxes):
    """IoU between one xyxy *box* and each row of the (N, 4) xyxy *boxes*."""
    # Coordinates of the intersection rectangles
    xmin = np.maximum(box[0], boxes[:, 0])
    ymin = np.maximum(box[1], boxes[:, 1])
    xmax = np.minimum(box[2], boxes[:, 2])
    ymax = np.minimum(box[3], boxes[:, 3])

    # Compute intersection area
    intersection_area = np.maximum(0, xmax - xmin) * np.maximum(0, ymax - ymin)

    # Compute union area
    box_area = (box[2] - box[0]) * (box[3] - box[1])
    boxes_area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    union_area = box_area + boxes_area - intersection_area

    # Zero-area box pairs have a zero union; report IoU 0 instead of 0/0 = NaN.
    return np.divide(intersection_area, union_area,
                     out=np.zeros_like(intersection_area, dtype=float),
                     where=union_area > 0)


def xywh2xyxy(x):
    """Convert (centre_x, centre_y, w, h) boxes to (x1, y1, x2, y2); returns a copy."""
    y = np.copy(x)
    y[..., 0] = x[..., 0] - x[..., 2] / 2
    y[..., 1] = x[..., 1] - x[..., 3] / 2
    y[..., 2] = x[..., 0] + x[..., 2] / 2
    y[..., 3] = x[..., 1] + x[..., 3] / 2
    return y


def sigmoid(x):
    return 1 / (1 + np.exp(-x))


def draw_detections(image, boxes, scores, class_ids, mask_alpha=0.3, mask_maps=None):
    """Return a copy of *image* with masks (or filled boxes), outlines and labels drawn."""
    img_height, img_width = image.shape[:2]
    size = min([img_height, img_width]) * 0.0006
    text_thickness = int(min([img_height, img_width]) * 0.001)

    mask_img = draw_masks(image, boxes, class_ids, mask_alpha, mask_maps)

    # Draw bounding boxes and labels of detections
    for box, score, class_id in zip(boxes, scores, class_ids):
        color = _class_color(class_id)

        x1, y1, x2, y2 = box.astype(int)

        # Draw rectangle
        cv2.rectangle(mask_img, (x1, y1), (x2, y2), color, 2)

        caption = f'{_class_label(class_id)} {int(score * 100)}%'
        (tw, th), _ = cv2.getTextSize(text=caption, fontFace=cv2.FONT_HERSHEY_SIMPLEX,
                                      fontScale=size, thickness=text_thickness)
        th = int(th * 1.2)

        # Filled label background just above the box's top-left corner
        cv2.rectangle(mask_img, (x1, y1),
                      (x1 + tw, y1 - th), color, -1)

        cv2.putText(mask_img, caption, (x1, y1),
                    cv2.FONT_HERSHEY_SIMPLEX, size, (255, 255, 255), text_thickness, cv2.LINE_AA)

    return mask_img


def draw_masks(image, boxes, class_ids, mask_alpha=0.3, mask_maps=None):
    """Blend per-detection colour fills over *image*.

    Fills the whole box when *mask_maps* is None, otherwise only the mask
    pixels inside the box; the result is alpha-blended with the original.
    """
    mask_img = image.copy()

    for i, (box, class_id) in enumerate(zip(boxes, class_ids)):
        color = _class_color(class_id)

        x1, y1, x2, y2 = box.astype(int)

        # Draw fill mask image
        if mask_maps is None:
            cv2.rectangle(mask_img, (x1, y1), (x2, y2), color, -1)
        else:
            crop_mask = mask_maps[i][y1:y2, x1:x2, np.newaxis]
            crop_mask_img = mask_img[y1:y2, x1:x2]
            crop_mask_img = crop_mask_img * (1 - crop_mask) + crop_mask * color
            mask_img[y1:y2, x1:x2] = crop_mask_img

    return cv2.addWeighted(mask_img, mask_alpha, image, 1 - mask_alpha, 0)


def draw_comparison(img1, img2, name1, name2, fontsize=2.6, text_thickness=3):
    """Label two images (in place) and concatenate them side by side.

    NOTE(review): results wider than 3840 px are resized to exactly
    3840x2160, which distorts inputs that are not 16:9 overall.
    """
    (tw, th), _ = cv2.getTextSize(text=name1, fontFace=cv2.FONT_HERSHEY_DUPLEX,
                                  fontScale=fontsize, thickness=text_thickness)
    x1 = img1.shape[1] // 3
    y1 = th
    offset = th // 5
    cv2.rectangle(img1, (x1 - offset * 2, y1 + offset),
                  (x1 + tw + offset * 2, y1 - th - offset), (0, 115, 255), -1)
    cv2.putText(img1, name1,
                (x1, y1),
                cv2.FONT_HERSHEY_DUPLEX, fontsize,
                (255, 255, 255), text_thickness)

    (tw, th), _ = cv2.getTextSize(text=name2, fontFace=cv2.FONT_HERSHEY_DUPLEX,
                                  fontScale=fontsize, thickness=text_thickness)
    x1 = img2.shape[1] // 3
    y1 = th
    offset = th // 5
    cv2.rectangle(img2, (x1 - offset * 2, y1 + offset),
                  (x1 + tw + offset * 2, y1 - th - offset), (94, 23, 235), -1)

    cv2.putText(img2, name2,
                (x1, y1),
                cv2.FONT_HERSHEY_DUPLEX, fontsize,
                (255, 255, 255), text_thickness)

    combined_img = cv2.hconcat([img1, img2])
    if combined_img.shape[1] > 3840:
        combined_img = cv2.resize(combined_img, (3840, 2160))

    return combined_img
# --------------------------------------------------------------------------------