├── imgs
│   ├── bus.jpg
│   ├── test.jpg
│   └── zidane.jpg
├── show
│   ├── bus.jpg
│   ├── test.jpg
│   └── zidane.jpg
├── README.md
└── demo_trt.py
/imgs/bus.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xuanandsix/yolov9-segmentation-tensorrt/HEAD/imgs/bus.jpg
--------------------------------------------------------------------------------
/imgs/test.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xuanandsix/yolov9-segmentation-tensorrt/HEAD/imgs/test.jpg
--------------------------------------------------------------------------------
/show/bus.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xuanandsix/yolov9-segmentation-tensorrt/HEAD/show/bus.jpg
--------------------------------------------------------------------------------
/show/test.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xuanandsix/yolov9-segmentation-tensorrt/HEAD/show/test.jpg
--------------------------------------------------------------------------------
/imgs/zidane.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xuanandsix/yolov9-segmentation-tensorrt/HEAD/imgs/zidane.jpg
--------------------------------------------------------------------------------
/show/zidane.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xuanandsix/yolov9-segmentation-tensorrt/HEAD/show/zidane.jpg
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# yolov9-segmentation-tensorrt
This is the TensorRT inference code for YOLOv9 instance segmentation.

---

*(Demo results for the sample images are in the `show/` directory: `show/bus.jpg`, `show/test.jpg`, `show/zidane.jpg`.)*

---

Download [gelan-c-pan.pt](https://github.com/WongKinYiu/yolov9/releases/download/v0.1/gelan-c-pan.pt).

### Prepare an ONNX model
```
git clone https://github.com/WongKinYiu/yolov9
cd yolov9
pip install -r requirements.txt
python export.py --weights gelan-c-pan.pt --include onnx
```
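
Optionally, sanity-check the exported model before converting it. A minimal sketch, assuming `onnxruntime` is installed (the 640x640 NCHW preprocessing mirrors what `demo_trt.py` does):

```python
# Hypothetical sanity check; not part of this repo.
import cv2
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession('gelan-c-pan.onnx', providers=['CPUExecutionProvider'])
img = cv2.cvtColor(cv2.imread('imgs/bus.jpg'), cv2.COLOR_BGR2RGB)
img = cv2.resize(img, (640, 640)).astype(np.float32) / 255.0
blob = img.transpose(2, 0, 1)[np.newaxis]  # 1x3x640x640, matching demo_trt.py preprocessing
outputs = session.run(None, {session.get_inputs()[0].name: blob})
print([o.shape for o in outputs])  # the detection tensor should be (1, 116, 8400)
```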

### Test TensorRT

1. Use the `trtexec` tool to convert the ONNX model into a TensorRT engine. Other conversion routes also work; just make sure you end up with a valid engine.
```
/path/to/trtexec --onnx=gelan-c-pan.onnx --saveEngine=gelan-c-pan.engine --fp16
```
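
As an alternative to `trtexec`, the engine can be built with the TensorRT Python API. A minimal sketch, assuming TensorRT 8.x (the same API generation `demo_trt.py` targets):

```python
# Hypothetical alternative to trtexec; adjust paths as needed.
import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
parser = trt.OnnxParser(network, logger)

with open('gelan-c-pan.onnx', 'rb') as f:
    if not parser.parse(f.read()):
        for i in range(parser.num_errors):
            print(parser.get_error(i))
        raise RuntimeError('failed to parse the ONNX model')

config = builder.create_builder_config()
config.set_flag(trt.BuilderFlag.FP16)  # same effect as trtexec --fp16
engine = builder.build_serialized_network(network, config)
assert engine is not None

with open('gelan-c-pan.engine', 'wb') as f:
    f.write(engine)
```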

2. Run `demo_trt.py` to produce the annotated output images:

```
python demo_trt.py --engine gelan-c-pan.engine --imgs imgs --out-dir outputs
```
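
The class can also be used directly from Python; a minimal sketch based on the constructor and `forward` in `demo_trt.py` (BGR image in, drawn BGR image out):

```python
import cv2
from demo_trt import Yolov8SegTrt

segmenter = Yolov8SegTrt('gelan-c-pan.engine', conf_threshold=0.25, iou_threshold=0.65)
result = segmenter.forward(cv2.imread('imgs/bus.jpg'))
cv2.imwrite('bus_result.jpg', result)
```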

---
### Acknowledgement

This project is based on the following projects:

[YOLOv9](https://github.com/WongKinYiu/yolov9)

[YOLOv8-TensorRT](https://github.com/triple-Mu/YOLOv8-TensorRT)

[YOLOv9-ONNX-Segmentation](https://github.com/spacewalk01/yolov9-onnx-segmentation)

--------------------------------------------------------------------------------
/demo_trt.py:
--------------------------------------------------------------------------------
import argparse
import math
import os
import random
import warnings
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Tuple, Union

import cv2
import numpy as np
import tensorrt as trt
from cuda import cudart
from numpy import ndarray

os.environ['CUDA_MODULE_LOADING'] = 'LAZY'
warnings.filterwarnings(action='ignore', category=DeprecationWarning)

random.seed(42)

@dataclass
class Tensor:
    name: str
    dtype: np.dtype
    shape: Tuple
    cpu: ndarray
    gpu: int

# detection model classes (the 80 COCO categories)
CLASSES_DET = ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
               'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
               'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
               'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe',
               'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
               'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat',
               'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
               'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
               'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot',
               'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
               'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop',
               'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven',
               'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase',
               'scissors', 'teddy bear', 'hair drier', 'toothbrush')

# one random color per class
COLORS = {
    cls: [random.randint(0, 255) for _ in range(3)]
    for cls in CLASSES_DET
}

# Thin TensorRT engine wrapper: deserializes the engine, allocates I/O
# buffers, and runs inference on a CUDA stream.
class TRTEngine:
    def __init__(self, weight: Union[str, Path]) -> None:
        self.weight = Path(weight) if isinstance(weight, str) else weight
        status, self.stream = cudart.cudaStreamCreate()
        assert status.value == 0
        self.__init_engine()
        self.__init_bindings()
        self.__warm_up()

    def __init_engine(self) -> None:
        logger = trt.Logger(trt.Logger.WARNING)
        trt.init_libnvinfer_plugins(logger, namespace='')
        with trt.Runtime(logger) as runtime:
            model = runtime.deserialize_cuda_engine(self.weight.read_bytes())

        context = model.create_execution_context()

        names = [model.get_binding_name(i) for i in range(model.num_bindings)]
        self.num_bindings = model.num_bindings
        self.bindings: List[int] = [0] * self.num_bindings
        num_inputs, num_outputs = 0, 0

        for i in range(model.num_bindings):
            if model.binding_is_input(i):
                num_inputs += 1
            else:
                num_outputs += 1

        self.num_inputs = num_inputs
        self.num_outputs = num_outputs
        self.model = model
        self.context = context
        self.input_names = names[:num_inputs]
        self.output_names = names[num_inputs:]

    def __init_bindings(self) -> None:
        # Allocate host/device buffers up front for static shapes;
        # dynamic shapes are allocated lazily in run().
        dynamic = False
        inp_info = []
        out_info = []
        out_ptrs = []
        for i, name in enumerate(self.input_names):
            assert self.model.get_binding_name(i) == name
            dtype = trt.nptype(self.model.get_binding_dtype(i))
            shape = tuple(self.model.get_binding_shape(i))
            if -1 in shape:
                dynamic |= True
            if not dynamic:
                cpu = np.empty(shape, dtype)
                status, gpu = cudart.cudaMallocAsync(cpu.nbytes, self.stream)
                assert status.value == 0
                cudart.cudaMemcpyAsync(
                    gpu, cpu.ctypes.data, cpu.nbytes,
                    cudart.cudaMemcpyKind.cudaMemcpyHostToDevice, self.stream)
            else:
                cpu, gpu = np.empty(0), 0
            inp_info.append(Tensor(name, dtype, shape, cpu, gpu))
        for i, name in enumerate(self.output_names):
            i += self.num_inputs
            assert self.model.get_binding_name(i) == name
            dtype = trt.nptype(self.model.get_binding_dtype(i))
            shape = tuple(self.model.get_binding_shape(i))
            if not dynamic:
                cpu = np.empty(shape, dtype=dtype)
                status, gpu = cudart.cudaMallocAsync(cpu.nbytes, self.stream)
                assert status.value == 0
                cudart.cudaMemcpyAsync(
                    gpu, cpu.ctypes.data, cpu.nbytes,
                    cudart.cudaMemcpyKind.cudaMemcpyHostToDevice, self.stream)
                out_ptrs.append(gpu)
            else:
                cpu, gpu = np.empty(0), 0
            out_info.append(Tensor(name, dtype, shape, cpu, gpu))

        self.is_dynamic = dynamic
        self.inp_info = inp_info
        self.out_info = out_info
        self.out_ptrs = out_ptrs

    def __warm_up(self) -> None:
        if self.is_dynamic:
            print('Your engine has dynamic axes, please warm it up yourself!')
            return
        for _ in range(10):
            inputs = [i.cpu for i in self.inp_info]
            self.run(*inputs)

    def set_profiler(self, profiler: Optional[trt.IProfiler]) -> None:
        self.context.profiler = profiler \
            if profiler is not None else trt.Profiler()

    def run(self, *inputs) -> Union[Tuple, ndarray]:

        assert len(inputs) == self.num_inputs
        contiguous_inputs: List[ndarray] = [
            np.ascontiguousarray(i) for i in inputs
        ]

        for i in range(self.num_inputs):

            if self.is_dynamic:
                self.context.set_binding_shape(
                    i, tuple(contiguous_inputs[i].shape))
                status, self.inp_info[i].gpu = cudart.cudaMallocAsync(
                    contiguous_inputs[i].nbytes, self.stream)
                assert status.value == 0
            cudart.cudaMemcpyAsync(
                self.inp_info[i].gpu, contiguous_inputs[i].ctypes.data,
                contiguous_inputs[i].nbytes,
                cudart.cudaMemcpyKind.cudaMemcpyHostToDevice, self.stream)
            self.bindings[i] = self.inp_info[i].gpu

        output_gpu_ptrs: List[int] = []
        outputs: List[ndarray] = []

        for i in range(self.num_outputs):
            j = i + self.num_inputs
            if self.is_dynamic:
                shape = tuple(self.context.get_binding_shape(j))
                dtype = self.out_info[i].dtype
                cpu = np.empty(shape, dtype=dtype)
                status, gpu = cudart.cudaMallocAsync(cpu.nbytes, self.stream)
                assert status.value == 0
                cudart.cudaMemcpyAsync(
                    gpu, cpu.ctypes.data, cpu.nbytes,
                    cudart.cudaMemcpyKind.cudaMemcpyHostToDevice, self.stream)
            else:
                cpu = self.out_info[i].cpu
                gpu = self.out_info[i].gpu
            outputs.append(cpu)
            output_gpu_ptrs.append(gpu)
            self.bindings[j] = gpu

        self.context.execute_async_v2(self.bindings, self.stream)

        for i, o in enumerate(output_gpu_ptrs):
            cudart.cudaMemcpyAsync(
                outputs[i].ctypes.data, o, outputs[i].nbytes,
                cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost, self.stream)
        # Synchronize after the device-to-host copies so the returned host
        # buffers are guaranteed to be populated.
        cudart.cudaStreamSynchronize(self.stream)

        return tuple(outputs) if len(outputs) > 1 else outputs[0]

class Yolov8SegTrt(TRTEngine):
    def __init__(self, engine_path, conf_threshold, iou_threshold):
        super().__init__(engine_path)
        # The detection head predicts 32 mask coefficients per box; together
        # with 4 box values and 80 class scores this gives 116 channels.
        self.num_masks = 32
        self.conf_threshold = conf_threshold
        self.iou_threshold = iou_threshold
        self.input_height = 640
        self.input_width = 640
        self.mask_alpha = 0.5

    def sigmoid(self, x: ndarray) -> ndarray:
        return 1. / (1. + np.exp(-x))

    def process_box_output(self, box_output):
        predictions = np.squeeze(box_output).T  # (8400, 116)
        num_classes = box_output.shape[1] - self.num_masks - 4
        # Filter out object confidence scores below threshold
        scores = np.max(predictions[:, 4:4 + num_classes], axis=1)
        predictions = predictions[scores > self.conf_threshold, :]
        scores = scores[scores > self.conf_threshold]
        if len(scores) == 0:
            return [], [], [], np.array([])
        box_predictions = predictions[..., :num_classes + 4]
        mask_predictions = predictions[..., num_classes + 4:]
        # Get the class with the highest confidence
        class_ids = np.argmax(box_predictions[:, 4:], axis=1)
        # Get bounding boxes for each object
        boxes = self.extract_boxes(box_predictions)
        # Apply non-maximum suppression to remove weak, overlapping boxes
        indices = self.nms(boxes, scores, self.iou_threshold)

        return boxes[indices], scores[indices], class_ids[indices], mask_predictions[indices]

    def process_mask_output(self, mask_predictions, mask_output):

        if mask_predictions.shape[0] == 0:
            return []

        mask_output = np.squeeze(mask_output)

        # Calculate the mask maps for each box
        num_mask, mask_height, mask_width = mask_output.shape  # CHW
        masks = self.sigmoid(mask_predictions @ mask_output.reshape((num_mask, -1)))
        masks = masks.reshape((-1, mask_height, mask_width))

        # Downscale the boxes to match the mask size
        scale_boxes = self.rescale_boxes(self.boxes,
                                         (self.img_height, self.img_width),
                                         (mask_height, mask_width))

        # For every box/mask pair, get the mask map
        mask_maps = np.zeros((len(scale_boxes), self.img_height, self.img_width))
        blur_size = (int(self.img_width / mask_width), int(self.img_height / mask_height))
        for i in range(len(scale_boxes)):

            scale_x1 = int(math.floor(scale_boxes[i][0]))
            scale_y1 = int(math.floor(scale_boxes[i][1]))
            scale_x2 = int(math.ceil(scale_boxes[i][2]))
            scale_y2 = int(math.ceil(scale_boxes[i][3]))

            x1 = int(math.floor(self.boxes[i][0]))
            y1 = int(math.floor(self.boxes[i][1]))
            x2 = int(math.ceil(self.boxes[i][2]))
            y2 = int(math.ceil(self.boxes[i][3]))

            scale_crop_mask = masks[i][scale_y1:scale_y2, scale_x1:scale_x2]
            crop_mask = cv2.resize(scale_crop_mask,
                                   (x2 - x1, y2 - y1),
                                   interpolation=cv2.INTER_CUBIC)

            crop_mask = cv2.blur(crop_mask, blur_size)

            crop_mask = (crop_mask > 0.5).astype(np.uint8)
            mask_maps[i, y1:y2, x1:x2] = crop_mask

        return mask_maps

    def extract_boxes(self, box_predictions):
        # Extract boxes from predictions
        boxes = box_predictions[:, :4]

        # Scale boxes from the model input size to the original image dimensions
        boxes = self.rescale_boxes(boxes,
                                   (self.input_height, self.input_width),
                                   (self.img_height, self.img_width))

        # Convert boxes to xyxy format
        boxes = self.xywh2xyxy(boxes)

        # Clip the boxes to the image boundaries
        boxes[:, 0] = np.clip(boxes[:, 0], 0, self.img_width)
        boxes[:, 1] = np.clip(boxes[:, 1], 0, self.img_height)
        boxes[:, 2] = np.clip(boxes[:, 2], 0, self.img_width)
        boxes[:, 3] = np.clip(boxes[:, 3], 0, self.img_height)

        return boxes

    @staticmethod
    def rescale_boxes(boxes, input_shape, image_shape):
        # Rescale boxes from input_shape to image_shape (both given as (h, w))
        input_shape = np.array([input_shape[1], input_shape[0], input_shape[1], input_shape[0]])
        boxes = np.divide(boxes, input_shape, dtype=np.float32)
        boxes *= np.array([image_shape[1], image_shape[0], image_shape[1], image_shape[0]])

        return boxes

    def nms(self, boxes, scores, iou_threshold):
        # Sort by score
        sorted_indices = np.argsort(scores)[::-1]

        keep_boxes = []
        while sorted_indices.size > 0:
            # Pick the box with the highest remaining score
            box_id = sorted_indices[0]
            keep_boxes.append(box_id)

            # Compute IoU of the picked box with the rest
            ious = self.compute_iou(boxes[box_id, :], boxes[sorted_indices[1:], :])

            # Remove boxes with IoU over the threshold
            keep_indices = np.where(ious < iou_threshold)[0]
            sorted_indices = sorted_indices[keep_indices + 1]

        return keep_boxes

    def compute_iou(self, box, boxes):
        # Compute xmin, ymin, xmax, ymax of the intersection
        xmin = np.maximum(box[0], boxes[:, 0])
        ymin = np.maximum(box[1], boxes[:, 1])
        xmax = np.minimum(box[2], boxes[:, 2])
        ymax = np.minimum(box[3], boxes[:, 3])
        # Compute intersection area
        intersection_area = np.maximum(0, xmax - xmin) * np.maximum(0, ymax - ymin)
        # Compute union area
        box_area = (box[2] - box[0]) * (box[3] - box[1])
        boxes_area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
        union_area = box_area + boxes_area - intersection_area

        # Compute IoU
        iou = intersection_area / union_area

        return iou

    def xywh2xyxy(self, x):
        # Convert bounding box (cx, cy, w, h) to (x1, y1, x2, y2)
        y = np.copy(x)
        y[..., 0] = x[..., 0] - x[..., 2] / 2
        y[..., 1] = x[..., 1] - x[..., 3] / 2
        y[..., 2] = x[..., 0] + x[..., 2] / 2
        y[..., 3] = x[..., 1] + x[..., 3] / 2
        return y

    def prepare_input(self, image):
        self.img_height, self.img_width = image.shape[:2]
        input_img = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # Resize input image
        input_img = cv2.resize(input_img, (self.input_width, self.input_height))
        # Scale input pixel values to 0 to 1
        input_img = input_img / 255.0
        input_img = input_img.transpose(2, 0, 1)
        input_tensor = input_img[np.newaxis, :, :, :].astype(np.float32)
        return input_tensor

    def draw_masks(self, image, boxes, class_ids, mask_alpha=0.3, mask_maps=None):
        mask_img = image.copy()

        # Draw filled masks (or filled boxes when no mask maps are given)
        for i, (box, class_id) in enumerate(zip(boxes, class_ids)):
            color = COLORS[CLASSES_DET[class_id]]

            x1, y1, x2, y2 = box.astype(int)

            if mask_maps is None:
                cv2.rectangle(mask_img, (x1, y1), (x2, y2), color, -1)
            else:
                crop_mask = mask_maps[i][y1:y2, x1:x2, np.newaxis]
                crop_mask_img = mask_img[y1:y2, x1:x2]
                crop_mask_img = crop_mask_img * (1 - crop_mask) + crop_mask * color
                mask_img[y1:y2, x1:x2] = crop_mask_img

        return cv2.addWeighted(mask_img, mask_alpha, image, 1 - mask_alpha, 0)

    def draw(self, image, boxes, scores, class_ids, mask_alpha=0.3, mask_maps=None):
        img_height, img_width = image.shape[:2]
        size = min([img_height, img_width]) * 0.0006
        text_thickness = int(min([img_height, img_width]) * 0.001)
        mask_img = self.draw_masks(image, boxes, class_ids, mask_alpha, mask_maps)
        # Draw bounding boxes and labels of detections
        for box, score, class_id in zip(boxes, scores, class_ids):
            color = COLORS[CLASSES_DET[class_id]]
            x1, y1, x2, y2 = box.astype(int)
            # Draw rectangle
            cv2.rectangle(mask_img, (x1, y1), (x2, y2), color, 2)
            label = CLASSES_DET[class_id]
            caption = f'{label} {int(score * 100)}%'
            (tw, th), _ = cv2.getTextSize(text=caption, fontFace=cv2.FONT_HERSHEY_SIMPLEX,
                                          fontScale=size, thickness=text_thickness)
            th = int(th * 1.2)
            cv2.rectangle(mask_img, (x1, y1),
                          (x1 + tw, y1 - th), color, -1)
            cv2.putText(mask_img, caption, (x1, y1),
                        cv2.FONT_HERSHEY_SIMPLEX, size, (255, 255, 255), text_thickness, cv2.LINE_AA)
        return mask_img

    def forward(self, image):
        input_tensor = self.prepare_input(image)
        outputs = self.run(input_tensor)
        # outputs[2] holds the detection predictions and outputs[0] the mask
        # prototypes for this engine
        self.boxes, self.scores, self.class_ids, mask_pred = self.process_box_output(outputs[2])
        self.mask_maps = self.process_mask_output(mask_pred, outputs[0])

        combined_img = self.draw(image, self.boxes, self.scores,
                                 self.class_ids, self.mask_alpha, mask_maps=self.mask_maps)

        return combined_img

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--engine', type=str, help='Engine file')
    parser.add_argument('--imgs', type=str, help='Directory of input images')
    parser.add_argument('--out-dir',
                        type=str,
                        default='./output',
                        help='Path to output directory')
    parser.add_argument('--conf_threshold',
                        type=float,
                        default=0.25,
                        help='Confidence threshold')
    parser.add_argument('--iou_threshold',
                        type=float,
                        default=0.65,
                        help='IoU threshold for NMS')
    args = parser.parse_args()
    return args

if __name__ == '__main__':
    args = parse_args()
    segment = Yolov8SegTrt(args.engine, args.conf_threshold, args.iou_threshold)
    image_files = os.listdir(args.imgs)
    save_path = args.out_dir
    os.makedirs(save_path, exist_ok=True)
    for file in image_files:
        bgr = cv2.imread(os.path.join(args.imgs, file))
        if bgr is None:  # skip files that are not readable images
            continue
        # inference
        draw = segment.forward(bgr)
        print(draw.shape)
        cv2.imwrite(os.path.join(save_path, file), draw)

--------------------------------------------------------------------------------