├── images └── truck.jpg ├── utils ├── __init__.py ├── common.py └── sam_function.py ├── sam_onnx_inference.py ├── sam_torch_inference.py ├── sam_trt_inference.py ├── README.md ├── requirements.txt └── scripts └── onnx2trt.py /images/truck.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BooHwang/segment_anything_tensorrt/HEAD/images/truck.jpg -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | ''' 4 | @File : __init__.py 5 | @Time : 2023/06/01 13:49:27 6 | @Author : Huang Bo 7 | @Contact : cenahwang0304@gmail.com 8 | @Desc : None 9 | ''' 10 | 11 | from .sam_function import pre_processing, apply_coords, show_mask, show_points, show_box, mask_postprocessing -------------------------------------------------------------------------------- /sam_onnx_inference.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | ''' 4 | @File : sam_onnx_inference.py 5 | @Time : 2023/05/31 11:16:23 6 | @Author : Huang Bo 7 | @Contact : cenahwang0304@gmail.com 8 | @Desc : None 9 | ''' 10 | 11 | 12 | import argparse 13 | import numpy as np 14 | import cv2 15 | import matplotlib.pyplot as plt 16 | import onnxruntime 17 | from utils import pre_processing, apply_coords, show_mask, show_points 18 | 19 | 20 | if __name__ == "__main__": 21 | parser = argparse.ArgumentParser("use sam onnx model inference") 22 | parser.add_argument("--img_path", type=str, default="images/truck.jpg", help="path of the image you want to segment") 23 | parser.add_argument("--img_onnx_model_path", type=str, default="embedding_onnx/sam_default_embedding.onnx") 24 | parser.add_argument("--sam_onnx_model_path", type=str, default="weights/sam_vit_h_4b8939.onnx", help="sam onnx model") 25 | parser.add_argument("--gpu_id", type=int, default=0, help="which gpu to use for inference") 26 | args = parser.parse_args() 27 | 28 | device = f"cuda:{args.gpu_id}" 29 | 30 | ort_embedding_session = onnxruntime.InferenceSession(args.img_onnx_model_path, providers=['CUDAExecutionProvider']) 31 | ort_sam_session = onnxruntime.InferenceSession(args.sam_onnx_model_path, providers=['CUDAExecutionProvider']) 32 | ort_embedding_session.set_providers(['CUDAExecutionProvider'], provider_options=[{'device_id': args.gpu_id}]) 33 | ort_sam_session.set_providers(['CUDAExecutionProvider'], provider_options=[{'device_id': args.gpu_id}]) 34 | 35 | image = cv2.imread(args.img_path) 36 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 37 | 38 | img_inputs = pre_processing(image) 39 | ort_inputs = {"images": img_inputs} 40 | # Get the image embedding; it only needs to be computed once per image 41 | image_embeddings = ort_embedding_session.run(None, ort_inputs)[0] 42 | 43 | # Point prompt 44 | input_point = np.array([[500, 375]]) 45 | input_label = np.array([1]) 46 | onnx_coord = np.concatenate([input_point, np.array([[0.0, 0.0]])], axis=0)[None, :, :] 47 | onnx_label = np.concatenate([input_label, np.array([-1])], axis=0)[None, :].astype(np.float32) 48 | onnx_coord = apply_coords(onnx_coord, image.shape[:2]).astype(np.float32) 49 | onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32) 50 | onnx_has_mask_input = np.zeros(1, dtype=np.float32) 51 | 52 | ort_inputs = { 53 | "image_embeddings": image_embeddings, 54 | "point_coords": onnx_coord, 55 | "point_labels": 
onnx_label, 56 | "mask_input": onnx_mask_input, 57 | "has_mask_input": onnx_has_mask_input, 58 | "orig_im_size": np.array(image.shape[:2], dtype=np.int32) 59 | } 60 | masks, scores, low_res_logits = ort_sam_session.run(None, ort_inputs) 61 | masks = masks > 0.0 62 | 63 | for i, (mask, score) in enumerate(zip(masks[0], scores[0])): 64 | plt.figure(figsize=(10,10)) 65 | plt.imshow(image) 66 | show_mask(mask, plt.gca()) 67 | show_points(input_point, input_label, plt.gca()) 68 | plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18) 69 | plt.axis('off') 70 | plt.savefig(f"results/onnx_mask{i}.png", bbox_inches='tight', pad_inches=0) 71 | print(f"generate: results/onnx_mask{i}.png") 72 | # plt.show() -------------------------------------------------------------------------------- /sam_torch_inference.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | ''' 4 | @File : sam_torch_inference.py 5 | @Time : 2023/05/31 20:21:39 6 | @Author : Huang Bo 7 | @Contact : cenahwang0304@gmail.com 8 | @Desc : None 9 | ''' 10 | 11 | import numpy as np 12 | import matplotlib.pyplot as plt 13 | import cv2 14 | import argparse 15 | from segment_anything import sam_model_registry, SamPredictor 16 | from utils import show_mask, show_points, show_box 17 | 18 | 19 | if __name__ == "__main__": 20 | parser = argparse.ArgumentParser("use sam torch model inference") 21 | parser.add_argument("--img_path", type=str, default="images/truck.jpg") 22 | parser.add_argument("--sam_checkpoint", type=str, default="weights/sam_vit_h_4b8939.pth") 23 | parser.add_argument("--model_type", type=str, default="vit_h") 24 | parser.add_argument("--gpu_id", type=int, default=0, help="use which gpu to inference") 25 | 26 | args = parser.parse_args() 27 | 28 | device = f"cuda:{args.gpu_id}" 29 | 30 | image = cv2.imread(args.img_path) 31 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 32 | 33 | sam = sam_model_registry[args.model_type](checkpoint=args.sam_checkpoint) 34 | sam.to(device=device) 35 | 36 | predictor = SamPredictor(sam) 37 | predictor.set_image(image) 38 | 39 | print("Use point prompt to segment ...") 40 | # Point prompt 41 | input_point = np.array([[500, 375]]) 42 | input_label = np.array([1]) 43 | 44 | masks, scores, logits = predictor.predict( 45 | point_coords=input_point, 46 | point_labels=input_label, 47 | multimask_output=True, 48 | ) 49 | 50 | for i, (mask, score) in enumerate(zip(masks, scores)): 51 | plt.figure(figsize=(10,10)) 52 | plt.imshow(image) 53 | show_mask(mask, plt.gca()) 54 | show_points(input_point, input_label, plt.gca()) 55 | plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18) 56 | plt.axis('off') 57 | plt.savefig(f"results/torch_mask{i}.png", bbox_inches='tight', pad_inches=0) 58 | print(f"generate: results/torch_mask{i}.png") 59 | # plt.show() 60 | 61 | 62 | print("Use point and last segment mask prompt to segment ...") 63 | # use last inference mask as input 64 | input_point = np.array([[500, 375], [1125, 625]]) 65 | input_label = np.array([1, 1]) 66 | mask_input = logits[np.argmax(scores), :, :] # Choose the model's best mask 67 | 68 | masks, _, _ = predictor.predict( 69 | point_coords=input_point, 70 | point_labels=input_label, 71 | mask_input=mask_input[None, :, :], 72 | multimask_output=False, 73 | ) 74 | 75 | print("Use point and boxes prompt to segment ...") 76 | # use box method to segment 77 | input_box = np.array([425, 600, 700, 875]) 78 | masks, _, _ = predictor.predict( 79 | point_coords=None, 
80 | point_labels=None, 81 | box=input_box[None, :], 82 | multimask_output=False, 83 | ) 84 | 85 | 86 | input_box = np.array([425, 600, 700, 875]) 87 | input_point = np.array([[575, 750]]) 88 | input_label = np.array([0]) 89 | 90 | masks, _, _ = predictor.predict( 91 | point_coords=input_point, 92 | point_labels=input_label, 93 | box=input_box, 94 | multimask_output=False, 95 | ) -------------------------------------------------------------------------------- /utils/common.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | ''' 4 | @File : common.py 5 | @Time : 2023/04/20 16:30:58 6 | @Author : Huang Bo 7 | @Contact : cenahwang0304@gmail.com 8 | @Desc : None 9 | ''' 10 | 11 | ''' 12 | TensorRT common file 13 | ''' 14 | 15 | import tensorrt as trt 16 | import numpy as np 17 | import os 18 | import pycuda.driver as cuda 19 | # import pycuda.autoinit # if you want use other diff gpu rather gpu 0, you should commit this line 20 | 21 | class HostDeviceMem(object): 22 | def __init__(self, host_mem, device_mem): 23 | self.host = host_mem 24 | self.device = device_mem 25 | 26 | def __str__(self): 27 | return "Host:\n" + str(self.host) + "\nDevice:\n" + str(self.device) 28 | 29 | def __repr__(self): 30 | return self.__str__() 31 | 32 | class TrtModel(): 33 | def __init__(self, engine_path, gpu_id=0, max_batch_size=1): 34 | print("TensorRT version: %s" % trt.__version__) 35 | cuda.init() 36 | self.cfx = cuda.Device(gpu_id).make_context() 37 | self.engine_path = engine_path 38 | self.logger = trt.Logger(trt.Logger.WARNING) 39 | self.runtime = trt.Runtime(self.logger) 40 | self.engine = self.load_engine(self.runtime, self.engine_path) 41 | self.context = self.engine.create_execution_context() 42 | self.max_batch_size = max_batch_size 43 | self.inputs, self.outputs, self.bindings, self.stream = self.allocate_buffers() 44 | 45 | @staticmethod 46 | def load_engine(trt_runtime, engine_path): 47 | trt.init_libnvinfer_plugins(None, "") 48 | with open(engine_path, 'rb') as f: 49 | engine_data = f.read() 50 | engine = trt_runtime.deserialize_cuda_engine(engine_data) 51 | return engine 52 | 53 | def allocate_buffers(self): 54 | inputs = [] 55 | outputs = [] 56 | bindings = [] 57 | stream = cuda.Stream() 58 | 59 | for binding in self.engine: 60 | if binding in ["point_coords", "point_labels"]: 61 | size = abs(trt.volume(self.engine.get_binding_shape(binding))) * self.max_batch_size 62 | else: 63 | size = abs(trt.volume(self.engine.get_binding_shape(binding))) 64 | # print(f"binding: {binding}, size: {size}") 65 | dtype = trt.nptype(self.engine.get_binding_dtype(binding)) 66 | host_mem = cuda.pagelocked_empty(size, dtype) 67 | device_mem = cuda.mem_alloc(host_mem.nbytes) 68 | bindings.append(int(device_mem)) 69 | if self.engine.binding_is_input(binding): 70 | inputs.append(HostDeviceMem(host_mem, device_mem)) 71 | else: 72 | outputs.append(HostDeviceMem(host_mem, device_mem)) 73 | return inputs, outputs, bindings, stream 74 | 75 | def __call__(self, inf_in_list, binding_shape_map=None): 76 | self.cfx.push() 77 | 78 | if binding_shape_map: 79 | self.context.set_optimization_profile_async 80 | for binding_name, shape in binding_shape_map.items(): 81 | binding_idx = self.engine[binding_name] 82 | # print(f"binding_name: {binding_name}, binding_idx: {binding_idx}, shape: {shape}") 83 | self.context.set_binding_shape(binding_idx, shape) 84 | 85 | for i in range(len(self.inputs)): 86 | self.inputs[i].host = inf_in_list[i] 87 | 
cuda.memcpy_htod_async(self.inputs[i].device, self.inputs[i].host, self.stream) 88 | 89 | self.context.execute_async_v2(bindings=self.bindings, stream_handle=self.stream.handle) 90 | 91 | for i in range(len(self.outputs)): 92 | cuda.memcpy_dtoh_async(self.outputs[i].host, self.outputs[i].device, self.stream) 93 | 94 | self.stream.synchronize() 95 | self.cfx.pop() 96 | return [out.host.copy() for out in self.outputs] 97 | 98 | def __del__(self): 99 | self.cfx.pop() 100 | del self.cfx -------------------------------------------------------------------------------- /sam_trt_inference.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | ''' 4 | @File : sam_trt_inference.py 5 | @Time : 2023/05/31 11:13:08 6 | @Author : Huang Bo 7 | @Contact : cenahwang0304@gmail.com 8 | @Desc : None 9 | ''' 10 | 11 | ''' 12 | Use tensorrt accerate segment anything model(SAM) inference 13 | ''' 14 | 15 | import numpy as np 16 | import cv2 17 | from utils.common import TrtModel 18 | import os 19 | import argparse 20 | import matplotlib.pyplot as plt 21 | from utils import apply_coords, pre_processing, mask_postprocessing, show_mask, show_points 22 | 23 | 24 | if __name__ == "__main__": 25 | parser = argparse.ArgumentParser("use tensorrt to inference segment anything model") 26 | parser.add_argument("--img_path", type=str, default="images/truck.jpg", help="you want segment image") 27 | parser.add_argument("--sam_engine_file", type=str, default="weights/sam_default_prompt_mask.engine") 28 | parser.add_argument("--embedding_engine_file", type=str, default="embedding_onnx/sam_default_embedding.engine") 29 | parser.add_argument("--gpu_id", type=int, default=0, help="use which gpu to inference") 30 | parser.add_argument("--batch_size", type=int, default=1, help="use batch size img to inference") 31 | args = parser.parse_args() 32 | 33 | image = cv2.imread(args.img_path) 34 | orig_im_size = image.shape[:2] 35 | img_inputs = pre_processing(image) 36 | print(f'img input: {img_inputs.shape}') 37 | 38 | # embedding init 39 | embedding_inference = TrtModel(engine_path=args.embedding_engine_file, gpu_id=args.gpu_id, max_batch_size=args.batch_size) 40 | # sam init 41 | sam_inference = TrtModel(engine_path=args.sam_engine_file, gpu_id=args.gpu_id, max_batch_size=20) 42 | 43 | image_embedding = embedding_inference([img_inputs])[0].reshape(1, 256, 64, 64) 44 | print(f"img embedding: {image_embedding.shape}") 45 | 46 | # Point prompt 47 | input_point = np.array([[500, 375]]) 48 | input_label = np.array([1]) 49 | onnx_coord = np.concatenate([input_point, np.array([[0.0, 0.0]])], axis=0)[None, :, :] 50 | onnx_label = np.concatenate([input_label, np.array([-1])], axis=0)[None, :].astype(np.float32) 51 | onnx_coord = apply_coords(onnx_coord, image.shape[:2]).astype(np.float32) 52 | onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32) 53 | onnx_has_mask_input = np.zeros(1, dtype=np.float32) 54 | 55 | # print(image_embedding.shape) 56 | # print(onnx_coord.shape) 57 | # print(onnx_label.shape) 58 | # print(onnx_mask_input.shape) 59 | # print(onnx_has_mask_input.shape) 60 | 61 | input = [image_embedding, onnx_coord, onnx_label, onnx_mask_input, onnx_has_mask_input] 62 | shape_map = {'image_embeddings': image_embedding.shape, 63 | 'point_coords': onnx_coord.shape, 64 | 'point_labels': onnx_label.shape, 65 | 'mask_input': onnx_mask_input.shape, 66 | 'has_mask_input': onnx_has_mask_input.shape} 67 | 68 | output = sam_inference(input, 
binding_shape_map=shape_map) 69 | 70 | # print(output[0].shape) 71 | # print(output[1].shape) 72 | 73 | low_res_logits = output[0].reshape(args.batch_size, -1).reshape(4, 256, 256) 74 | scores = output[1].reshape(args.batch_size, -1).squeeze(0) 75 | 76 | masks = mask_postprocessing(low_res_logits, orig_im_size, img_inputs.shape[2]) 77 | masks = masks.numpy().squeeze(0) 78 | os.makedirs("results", exist_ok=True) 79 | # for i in range(masks.shape[0]): 80 | # # mask_image = show_mask(masks[i]*255) 81 | # cv2.imwrite(f"results/trt_mask{i}.png", masks[i]*255) 82 | # print(f"Generate results/trt_mask{i}.png") 83 | 84 | for i, (mask, score) in enumerate(zip(masks, scores)): 85 | mask = mask > 0.0 86 | plt.figure(figsize=(10,10)) 87 | plt.imshow(image[:, :, ::-1]) 88 | show_mask(mask, plt.gca()) 89 | show_points(input_point, input_label, plt.gca()) 90 | plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18) 91 | plt.axis('off') 92 | plt.savefig(f"results/trt_mask{i}.png", bbox_inches='tight', pad_inches=0) 93 | print(f"generate: results/trt_mask{i}.png") 94 | # plt.show() 95 | -------------------------------------------------------------------------------- /utils/sam_function.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | ''' 4 | @File : sam_function.py 5 | @Time : 2023/06/01 13:50:45 6 | @Author : Huang Bo 7 | @Contact : cenahwang0304@gmail.com 8 | @Desc : None 9 | ''' 10 | 11 | ''' 12 | all the function are from sam source code 13 | ''' 14 | 15 | import torch 16 | import numpy as np 17 | from typing import Tuple, List 18 | import matplotlib.pyplot as plt 19 | from torchvision.transforms.functional import resize, to_pil_image 20 | from torch.nn import functional as F 21 | 22 | def show_mask(mask, ax, random_color=False): 23 | if random_color: 24 | color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0) 25 | else: 26 | color = np.array([30/255, 144/255, 255/255, 0.6]) 27 | h, w = mask.shape[-2:] 28 | mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1) 29 | ax.imshow(mask_image) 30 | 31 | def show_points(coords, labels, ax, marker_size=375): 32 | pos_points = coords[labels==1] 33 | neg_points = coords[labels==0] 34 | ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25) 35 | ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25) 36 | 37 | def show_box(box, ax): 38 | x0, y0 = box[0], box[1] 39 | w, h = box[2] - box[0], box[3] - box[1] 40 | ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2)) 41 | 42 | def apply_coords(coords: np.ndarray, original_size: Tuple[int, ...], target_length: int = 1024) -> np.ndarray: 43 | """ 44 | Expects a numpy array of length 2 in the final dimension. Requires the 45 | original image size in (H, W) format. 46 | """ 47 | old_h, old_w = original_size 48 | new_h, new_w = get_preprocess_shape( 49 | original_size[0], original_size[1], target_length 50 | ) 51 | coords = coords.copy().astype(float) 52 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 53 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 54 | return coords 55 | 56 | def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]: 57 | """ 58 | Compute the output size given input size and target long side length. 
59 | """ 60 | scale = long_side_length * 1.0 / max(oldh, oldw) 61 | newh, neww = oldh * scale, oldw * scale 62 | neww = int(neww + 0.5) 63 | newh = int(newh + 0.5) 64 | return (newh, neww) 65 | 66 | def pre_processing(image: np.ndarray, 67 | img_size: int = 1024, 68 | target_length: int = 1024, 69 | pixel_mean: List[float] = [123.675, 116.28, 103.53], 70 | pixel_std: List[float] = [58.395, 57.12, 57.375]): 71 | pixel_mean = torch.Tensor(pixel_mean).view(-1, 1, 1) 72 | pixel_std = torch.Tensor(pixel_std).view(-1, 1, 1) 73 | target_size = get_preprocess_shape(image.shape[0], image.shape[1], target_length) 74 | input_image = np.array(resize(to_pil_image(image), target_size)) 75 | input_image_torch = torch.as_tensor(input_image) 76 | input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :] 77 | 78 | # Normalize colors 79 | input_image_torch = (input_image_torch - pixel_mean) / pixel_std 80 | 81 | # Pad 82 | h, w = input_image_torch.shape[-2:] 83 | padh = img_size - h 84 | padw = img_size - w 85 | input_image_torch = F.pad(input_image_torch, (0, padw, 0, padh)) 86 | return input_image_torch.numpy() 87 | 88 | def resize_longest_image_size( 89 | input_image_size: torch.Tensor, longest_side: int 90 | ) -> torch.Tensor: 91 | input_image_size = input_image_size.to(torch.float32) 92 | scale = longest_side / torch.max(input_image_size) 93 | transformed_size = scale * input_image_size 94 | transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64) 95 | return transformed_size 96 | 97 | def mask_postprocessing(masks: np.array, orig_im_size: Tuple, img_size: int) -> torch.Tensor: 98 | masks = torch.from_numpy(masks[None, :, :, :]) # (4, 256, 256) -> (1, 4, 256, 256) 99 | orig_im_size = torch.tensor([orig_im_size[0], orig_im_size[1]], dtype=torch.int32) 100 | 101 | masks = F.interpolate( 102 | masks, 103 | size=(img_size, img_size), 104 | mode="bilinear", 105 | align_corners=False, 106 | ) 107 | 108 | prepadded_size = resize_longest_image_size(orig_im_size, img_size).to(torch.int64) 109 | masks = masks[..., : prepadded_size[0], : prepadded_size[1]] # type: ignore 110 | 111 | orig_im_size = orig_im_size.to(torch.int64) 112 | h, w = orig_im_size[0], orig_im_size[1] 113 | masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False) 114 | return masks -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Segment anything tensorrt 2 | 3 | Use tensorrt accerate segment anything model ([SAM](https://github.com/facebookresearch/segment-anything)), which design by facebook research. In this repo, we divide SAM into two parts for model transformation, one is `ImageEncoderViT` (also named img embedding in this repo), and other one is `MaskDecoder`, `PromptEncoder` (also named sam model in this repo). while image encoder just inference once, and the most process time waste in image embedding, so you can save image embedding, and input different point or boxes to segment as your wish. 4 | 5 | 6 | 7 | ## Installation 8 | 9 | The code requires `python>=3.8`, as well as `pytorch>=1.7` and `torchvision>=0.8`. Installing both PyTorch and TorchVision with CUDA support is strongly recommended. 
10 | 11 | ```shell 12 | # Install Segment Anything 13 | pip install git+https://github.com/facebookresearch/segment-anything.git 14 | 15 | # Install the requirements (the list may include more packages than strictly necessary) 16 | pip install -r requirements.txt 17 | 18 | # Clone this repo 19 | git clone https://github.com/BooHwang/segment_anything_tensorrt.git 20 | ``` 21 | 22 | After cloning the code, download a SAM checkpoint from [SAM](https://github.com/facebookresearch/segment-anything) and put it in `weights`. 23 | 24 | 25 | 26 | ## Model Transform 27 | 28 | ### Image embedding transform 29 | 30 | - Convert the SAM image embedding module from a PyTorch checkpoint to an ONNX model 31 | 32 | ```shell 33 | python scripts/onnx2trt.py --img_pt2onnx --sam_checkpoint weights/sam_vit_h_4b8939.pth --model_type default 34 | ``` 35 | 36 | 37 | 38 | - Convert the image embedding ONNX model to a TensorRT engine 39 | 40 | ```shell 41 | trtexec --onnx=embedding_onnx/sam_default_embedding.onnx --workspace=4096 --saveEngine=weights/sam_default_embedding.engine 42 | ``` 43 | 44 | 45 | 46 | - Or convert the image embedding ONNX model to a TensorRT engine from code 47 | 48 | ```shell 49 | python scripts/onnx2trt.py --img_onnx2trt --img_onnx_model_path embedding_onnx/sam_default_embedding.onnx 50 | ``` 51 | 52 | 53 | 54 | ### SAM model transform 55 | 56 | **Notice:** An unsuitable opset version will cause errors when converting the ONNX model to a TensorRT engine; opset 16 or 17 works with the docker image "nvidia/cuda:11.4.2-cudnn8-devel-ubuntu18.04". 57 | 58 | 59 | 60 | - Convert the SAM pth model to an ONNX model 61 | 62 | ```shell 63 | git clone https://github.com/facebookresearch/segment-anything.git 64 | 65 | cd segment-anything 66 | 67 | # To avoid fixing the original image size when exporting the model, some of the code must be modified 68 | # Change the "forward" function in "segment_anything/utils/onnx.py" as follows: 69 | def forward( 70 | self, 71 | image_embeddings: torch.Tensor, 72 | point_coords: torch.Tensor, 73 | point_labels: torch.Tensor, 74 | mask_input: torch.Tensor, 75 | has_mask_input: torch.Tensor 76 | # orig_im_size: torch.Tensor, 77 | ): 78 | sparse_embedding = self._embed_points(point_coords, point_labels) 79 | dense_embedding = self._embed_masks(mask_input, has_mask_input) 80 | 81 | masks, scores = self.model.mask_decoder.predict_masks( 82 | image_embeddings=image_embeddings, 83 | image_pe=self.model.prompt_encoder.get_dense_pe(), 84 | sparse_prompt_embeddings=sparse_embedding, 85 | dense_prompt_embeddings=dense_embedding, 86 | ) 87 | 88 | if self.use_stability_score: 89 | scores = calculate_stability_score( 90 | masks, self.model.mask_threshold, self.stability_score_offset 91 | ) 92 | 93 | if self.return_single_mask: 94 | masks, scores = self.select_masks(masks, scores, point_coords.shape[1]) 95 | 96 | return masks, scores 97 | # upscaled_masks = self.mask_postprocessing(masks, orig_im_size) 98 | 99 | # if self.return_extra_metrics: 100 | # stability_scores = calculate_stability_score( 101 | # upscaled_masks, self.model.mask_threshold, self.stability_score_offset 102 | # ) 103 | # areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1) 104 | # return upscaled_masks, scores, stability_scores, areas, masks 105 | 106 | # return upscaled_masks, scores, masks 107 | 108 | # Download weights of sam_vit_h_4b8939.pth 109 | python scripts/onnx2trt.py --prompt_masks_pt2onnx 110 | ``` 111 | 112 | 113 | 114 | - Export the ONNX model to a TensorRT engine file 115 | 116 | ```shell 117 | # Download the TensorRT package from NVIDIA and unzip it to /root; here we use 
version: TensorRT-8.6.1.6 118 | pip install ~/TensorRT-8.6.1.6/python/tensorrt-8.6.1-cp38-none-linux_x86_64.whl 119 | export PATH=$HOME/TensorRT-8.6.1.6/targets/x86_64-linux-gnu/bin:$PATH 120 | export TENSORRT_DIR=$HOME/TensorRT-8.6.1.6:$TENSORRT_DIR 121 | export LD_LIBRARY_PATH=$HOME/TensorRT-8.6.1.6/lib:$LD_LIBRARY_PATH 122 | 123 | # transform prompt encoder and mask decoder onnx model to tensorrt engine 124 | trtexec --onnx=weights/sam_default_prompt_mask.onnx --workspace=4096 --shapes=image_embeddings:1x256x64x64,point_coords:1x1x2,point_labels:1x1,mask_input:1x1x256x256,has_mask_input:1 --minShapes=image_embeddings:1x256x64x64,point_coords:1x1x2,point_labels:1x1,mask_input:1x1x256x256,has_mask_input:1 --optShapes=image_embeddings:1x256x64x64,point_coords:1x10x2,point_labels:1x10,mask_input:1x1x256x256,has_mask_input:1 --maxShapes=image_embeddings:1x256x64x64,point_coords:1x20x2,point_labels:1x20,mask_input:1x1x256x256,has_mask_input:1 --saveEngine=weights/sam_default_prompt_mask.engine 125 | ``` 126 | 127 | 128 | 129 | - Or export onnx model to tensorrt engine file by code 130 | 131 | ```shell 132 | python scripts/onnx2trt.py --sam_onnx2trt --sam_onnx_path ./weights/sam_vit_h_4b8939.onnx 133 | ``` 134 | 135 | 136 | 137 | ## Inference 138 | 139 | - Use **Pytorch** model inference 140 | 141 | ```shell 142 | python sam_torch_inference.py 143 | ``` 144 | 145 | 146 | 147 | - Use **ONNX** model inference 148 | 149 | ```shell 150 | python sam_onnx_inference.py 151 | ``` 152 | 153 | 154 | 155 | - Use **TensorRT** inference 156 | 157 | ```shell 158 | python sam_trt_inference.py 159 | ``` 160 | 161 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.4.0 2 | accelerate==0.19.0 3 | addict==2.4.0 4 | aenum==3.1.12 5 | aiofiles==23.1.0 6 | aiohttp==3.8.4 7 | aiosignal==1.3.1 8 | albumentations==0.5.2 9 | altair==5.0.0 10 | antlr4-python3-runtime==4.9.3 11 | anyio==3.6.2 12 | appdirs==1.4.4 13 | asttokens @ file:///home/conda/feedstock_root/build_artifacts/asttokens_1670263926556/work 14 | async-timeout==4.0.2 15 | attrs==23.1.0 16 | backcall @ file:///home/conda/feedstock_root/build_artifacts/backcall_1592338393461/work 17 | backports.functools-lru-cache @ file:///home/conda/feedstock_root/build_artifacts/backports.functools_lru_cache_1618230623929/work 18 | basicsr==1.4.2 19 | beautifulsoup4==4.12.2 20 | bidict==0.22.1 21 | black==23.3.0 22 | blendmodes==2022 23 | blis==0.7.9 24 | boltons==23.0.0 25 | braceexpand==0.1.7 26 | cachetools==5.3.0 27 | catalogue==2.0.8 28 | certifi==2023.5.7 29 | cffi==1.15.1 30 | charset-normalizer==3.1.0 31 | clean-fid==0.1.35 32 | click==8.1.3 33 | clip @ git+https://github.com/openai/CLIP.git@a9b1bf5920416aaeaec965c25dd9e8f98c864f16 34 | cloudpickle==2.2.1 35 | cmake==3.26.3 36 | coloredlogs==15.0.1 37 | confection==0.0.4 38 | contourpy==1.0.7 39 | cycler==0.11.0 40 | cymem==2.0.7 41 | Cython==0.29.34 42 | debugpy @ file:///tmp/build/80754af9/debugpy_1637091796427/work 43 | decorator @ file:///home/conda/feedstock_root/build_artifacts/decorator_1641555617451/work 44 | deprecation==2.1.0 45 | detectron2 @ git+https://github.com/facebookresearch/detectron2.git@2c6c380f94a27bd8455a39506c9105f652b9f760 46 | diffusers==0.16.1 47 | distlib==0.3.6 48 | dlib==19.24.1 49 | easydict==1.10 50 | einops==0.6.1 51 | en-core-web-sm @ 
https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.5.0/en_core_web_sm-3.5.0-py3-none-any.whl 52 | entrypoints @ file:///home/conda/feedstock_root/build_artifacts/entrypoints_1643888246732/work 53 | executing @ file:///home/conda/feedstock_root/build_artifacts/executing_1667317341051/work 54 | facexlib==0.3.0 55 | fairscale==0.4.13 56 | fastapi==0.95.1 57 | ffmpy==0.3.0 58 | filelock==3.12.0 59 | filterpy==1.4.5 60 | flatbuffers==23.5.8 61 | font-roboto==0.0.1 62 | fonts==0.0.3 63 | fonttools==4.39.4 64 | frozenlist==1.3.3 65 | fsspec==2023.5.0 66 | ftfy==6.1.1 67 | future==0.18.3 68 | fvcore==0.1.5.post20221221 69 | gdown==4.7.1 70 | gfpgan==1.3.8 71 | gitdb==4.0.10 72 | GitPython==3.1.31 73 | google-auth==2.18.0 74 | google-auth-oauthlib==1.0.0 75 | gradio==3.16.2 76 | gradio_client==0.2.4 77 | groundingdino @ git+https://github.com/IDEA-Research/GroundingDINO.git@39b1472457b8264adc8581d354bb1d1956ec7ee7 78 | grpcio==1.51.3 79 | h11==0.14.0 80 | h5py==3.8.0 81 | hdf5storage==0.1.19 82 | httpcore==0.17.0 83 | httpx==0.24.0 84 | huggingface-hub==0.14.1 85 | humanfriendly==10.0 86 | hydra-core==1.3.2 87 | idna==3.4 88 | imageio==2.28.1 89 | imgaug==0.4.0 90 | importlib-metadata==6.6.0 91 | importlib-resources==5.12.0 92 | imutils==0.5.4 93 | inflection==0.5.1 94 | invisible-watermark==0.1.5 95 | iopath==0.1.9 96 | ipykernel @ file:///home/conda/feedstock_root/build_artifacts/ipykernel_1655369107642/work 97 | ipython @ file:///home/conda/feedstock_root/build_artifacts/ipython_1653754926575/work 98 | jedi @ file:///home/conda/feedstock_root/build_artifacts/jedi_1669134318875/work 99 | Jinja2==3.1.2 100 | joblib==1.2.0 101 | jsonmerge==1.9.0 102 | jsonschema==4.17.3 103 | jupyter-client @ file:///home/conda/feedstock_root/build_artifacts/jupyter_client_1633454794268/work 104 | jupyter_core @ file:///home/conda/feedstock_root/build_artifacts/jupyter_core_1678994158927/work 105 | kiwisolver==1.4.4 106 | kornia==0.5.0 107 | langcodes==3.3.0 108 | lark==1.1.5 109 | lazy_loader==0.2 110 | lightning-utilities==0.8.0 111 | linkify-it-py==2.0.2 112 | lit==16.0.3 113 | llvmlite==0.40.0 114 | lmdb==1.4.1 115 | lpips==0.1.4 116 | Mako==1.2.4 117 | Markdown==3.4.3 118 | markdown-it-py==2.2.0 119 | MarkupSafe==2.1.2 120 | matplotlib==3.7.1 121 | matplotlib-inline @ file:///home/conda/feedstock_root/build_artifacts/matplotlib-inline_1660814786464/work 122 | mdit-py-plugins==0.3.3 123 | mdurl==0.1.2 124 | mediapipe==0.10.0 125 | mmcv==2.0.0 126 | mmengine==0.7.3 127 | mpmath==1.3.0 128 | msgpack==1.0.5 129 | multidict==6.0.4 130 | murmurhash==1.0.9 131 | mypy-extensions==1.0.0 132 | nest-asyncio @ file:///home/conda/feedstock_root/build_artifacts/nest-asyncio_1664684991461/work 133 | netifaces==0.10.6 134 | networkx==3.1 135 | numba==0.57.0 136 | numpy==1.23.5 137 | nvidia-cublas-cu11==11.10.3.66 138 | nvidia-cuda-cupti-cu11==11.7.101 139 | nvidia-cuda-nvrtc-cu11==11.7.99 140 | nvidia-cuda-runtime-cu11==11.7.99 141 | nvidia-cudnn-cu11==8.5.0.96 142 | nvidia-cufft-cu11==10.9.0.58 143 | nvidia-curand-cu11==10.2.10.91 144 | nvidia-cusolver-cu11==11.4.0.1 145 | nvidia-cusparse-cu11==11.7.4.91 146 | nvidia-nccl-cu11==2.14.3 147 | nvidia-nvtx-cu11==11.7.91 148 | oauthlib==3.2.2 149 | omegaconf==2.3.0 150 | onnx==1.14.0 151 | onnx-graphsurgeon @ file:///root/TensorRT-8.5.3.1/onnx_graphsurgeon/onnx_graphsurgeon-0.3.12-py2.py3-none-any.whl 152 | onnxruntime==1.14.1 153 | onnxruntime-gpu==1.14.1 154 | open-clip-torch==2.20.0 155 | opencv-contrib-python==4.7.0.72 156 | opencv-python==4.7.0.72 
157 | opencv-python-headless==4.7.0.72 158 | orjson==3.8.12 159 | packaging==23.1 160 | pandas==2.0.1 161 | parso @ file:///home/conda/feedstock_root/build_artifacts/parso_1638334955874/work 162 | pathspec==0.11.1 163 | pathy==0.10.1 164 | pexpect @ file:///home/conda/feedstock_root/build_artifacts/pexpect_1667297516076/work 165 | pickleshare @ file:///home/conda/feedstock_root/build_artifacts/pickleshare_1602536217715/work 166 | piexif==1.1.3 167 | Pillow==9.5.0 168 | pkgutil_resolve_name==1.3.10 169 | platformdirs @ file:///home/conda/feedstock_root/build_artifacts/platformdirs_1683850015520/work 170 | portalocker==2.7.0 171 | preshed==3.0.8 172 | prompt-toolkit @ file:///home/conda/feedstock_root/build_artifacts/prompt-toolkit_1677600924538/work 173 | protobuf==3.20.3 174 | psutil==5.9.5 175 | ptyprocess @ file:///home/conda/feedstock_root/build_artifacts/ptyprocess_1609419310487/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl 176 | pure-eval @ file:///home/conda/feedstock_root/build_artifacts/pure_eval_1642875951954/work 177 | pyasn1==0.5.0 178 | pyasn1-modules==0.3.0 179 | pycocotools==2.0.6 180 | pycparser==2.21 181 | pycryptodome==3.17 182 | pycuda==2022.2.2 183 | pydantic==1.10.7 184 | pyDeprecate==0.3.2 185 | pydot==1.4.2 186 | pydub==0.25.1 187 | Pygments @ file:///home/conda/feedstock_root/build_artifacts/pygments_1681904169130/work 188 | pyparsing==3.0.9 189 | pyrsistent==0.19.3 190 | PySocks==1.7.1 191 | python-dateutil @ file:///home/conda/feedstock_root/build_artifacts/python-dateutil_1626286286081/work 192 | python-engineio==4.4.1 193 | python-multipart==0.0.6 194 | python-socketio==5.8.0 195 | pytools==2022.1.14 196 | pytorch-lightning==1.7.7 197 | pytz==2023.3 198 | PyWavelets==1.4.1 199 | PyYAML==6.0 200 | pyzmq==19.0.2 201 | qudida==0.0.4 202 | ray==2.4.0 203 | realesrgan==0.3.0 204 | regex==2023.5.5 205 | requests==2.30.0 206 | requests-oauthlib==1.3.1 207 | resize-right==0.0.2 208 | rich==13.3.5 209 | rsa==4.9 210 | safetensors==0.3.1 211 | scikit-image==0.20.0 212 | scikit-learn==1.2.2 213 | scipy==1.9.1 214 | segment-anything @ git+https://github.com/facebookresearch/segment-anything.git@6fdee8f2727f4506cfbbe553e23b895e27956588 215 | semantic-version==2.10.0 216 | Send2Trash==1.8.2 217 | sentencepiece==0.1.99 218 | shapely==2.0.1 219 | six @ file:///home/conda/feedstock_root/build_artifacts/six_1620240208055/work 220 | smart-open==6.3.0 221 | smmap==5.0.0 222 | sniffio==1.3.0 223 | sounddevice==0.4.6 224 | soupsieve==2.4.1 225 | spacy==3.5.3 226 | spacy-legacy==3.0.12 227 | spacy-loggers==1.0.4 228 | srsly==2.4.6 229 | stack-data @ file:///home/conda/feedstock_root/build_artifacts/stack_data_1669632077133/work 230 | starlette==0.26.1 231 | supervision==0.6.0 232 | sympy==1.12 233 | tabulate==0.9.0 234 | tb-nightly==2.14.0a20230514 235 | tensorboard==2.13.0 236 | tensorboard-data-server==0.7.0 237 | tensorrt @ file:///root/TensorRT-8.6.1.6/python/tensorrt-8.6.1-cp38-none-linux_x86_64.whl 238 | termcolor==2.3.0 239 | thinc==8.1.10 240 | threadpoolctl==3.1.0 241 | tifffile==2023.4.12 242 | timm==0.4.12 243 | tokenizers==0.13.3 244 | tomli==2.0.1 245 | toolz==0.12.0 246 | torch==2.0.1 247 | torchdiffeq==0.2.3 248 | torchlm==0.1.6.10 249 | torchmetrics==0.11.4 250 | torchsde==0.2.5 251 | torchvision==0.15.2 252 | tornado @ file:///home/conda/feedstock_root/build_artifacts/tornado_1648827257044/work 253 | tqdm==4.65.0 254 | traitlets @ file:///home/conda/feedstock_root/build_artifacts/traitlets_1675110562325/work 255 | trampoline==0.1.2 256 | transformers==4.29.0 257 | 
triton==2.0.0 258 | typer==0.7.0 259 | typing_extensions @ file:///home/conda/feedstock_root/build_artifacts/typing_extensions_1678559861143/work 260 | tzdata==2023.3 261 | uc-micro-py==1.0.2 262 | urllib3==1.26.15 263 | uvicorn==0.22.0 264 | virtualenv==20.21.0 265 | wasabi==1.1.1 266 | wcwidth @ file:///home/conda/feedstock_root/build_artifacts/wcwidth_1673864653149/work 267 | webdataset==0.2.48 268 | websockets==11.0.3 269 | Werkzeug==2.3.4 270 | wldhx.yadisk-direct==0.0.6 271 | yacs==0.1.8 272 | yapf==0.33.0 273 | yarl==1.9.2 274 | zipp==3.15.0 275 | -------------------------------------------------------------------------------- /scripts/onnx2trt.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | ''' 4 | @File : onnx2trt.py 5 | @Time : 2023/05/12 14:41:23 6 | @Author : Huang Bo 7 | @Contact : cenahwang0304@gmail.com 8 | @Desc : None 9 | ''' 10 | 11 | 12 | import torch 13 | from torch.nn import functional as F 14 | import numpy as np 15 | from torchvision.transforms.functional import resize, to_pil_image 16 | from typing import Tuple 17 | import cv2 18 | import matplotlib.pyplot as plt 19 | import warnings 20 | import os 21 | os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" 22 | from pathlib import Path 23 | import tensorrt as trt 24 | import argparse 25 | from segment_anything import sam_model_registry 26 | from segment_anything.utils.onnx import SamOnnxModel 27 | 28 | def show_mask(mask, ax, random_color=False): 29 | if random_color: 30 | color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0) 31 | else: 32 | color = np.array([30/255, 144/255, 255/255, 0.6]) 33 | h, w = mask.shape[-2:] 34 | mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1) 35 | ax.imshow(mask_image) 36 | 37 | def show_points(coords, labels, ax, marker_size=375): 38 | pos_points = coords[labels==1] 39 | neg_points = coords[labels==0] 40 | ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25) 41 | ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25) 42 | 43 | def show_box(box, ax): 44 | x0, y0 = box[0], box[1] 45 | w, h = box[2] - box[0], box[3] - box[1] 46 | ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2)) 47 | 48 | def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]: 49 | """ 50 | Compute the output size given input size and target long side length. 
51 | """ 52 | scale = long_side_length * 1.0 / max(oldh, oldw) 53 | newh, neww = oldh * scale, oldw * scale 54 | neww = int(neww + 0.5) 55 | newh = int(newh + 0.5) 56 | return (newh, neww) 57 | 58 | def pre_processing(image: np.ndarray, target_length: int, device,pixel_mean,pixel_std,img_size): 59 | target_size = get_preprocess_shape(image.shape[0], image.shape[1], target_length) 60 | input_image = np.array(resize(to_pil_image(image), target_size)) 61 | input_image_torch = torch.as_tensor(input_image, device=device) 62 | input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :] 63 | 64 | # Normalize colors 65 | input_image_torch = (input_image_torch - pixel_mean) / pixel_std 66 | 67 | # Pad 68 | h, w = input_image_torch.shape[-2:] 69 | padh = img_size - h 70 | padw = img_size - w 71 | input_image_torch = F.pad(input_image_torch, (0, padw, 0, padh)) 72 | return input_image_torch 73 | 74 | def export_embedding_model(gpu_id, model_type, sam_checkpoint, opset): 75 | device = f"cuda:{gpu_id}" 76 | sam = sam_model_registry[model_type](checkpoint=sam_checkpoint) 77 | sam.to(device=device) 78 | 79 | image = cv2.imread('./images/truck.jpg') 80 | target_length = sam.image_encoder.img_size 81 | pixel_mean = sam.pixel_mean 82 | pixel_std = sam.pixel_std 83 | img_size = sam.image_encoder.img_size 84 | inputs = pre_processing(image, target_length, device, pixel_mean, pixel_std, img_size) 85 | os.makedirs("embedding_onnx", exist_ok=True) 86 | onnx_model_path = os.path.join("embedding_onnx", "sam_" + model_type+"_"+"embedding.onnx") 87 | dummy_inputs = {"images": inputs} 88 | 89 | output_names = ["image_embeddings"] 90 | # image_embeddings = sam.image_encoder(inputs).cpu().numpy() 91 | # print('image_embeddings', image_embeddings.shape) 92 | with warnings.catch_warnings(): 93 | warnings.filterwarnings("ignore", category=torch.jit.TracerWarning) 94 | warnings.filterwarnings("ignore", category=UserWarning) 95 | # with open(onnx_model_path, "wb") as f: 96 | torch.onnx.export( 97 | sam.image_encoder, 98 | tuple(dummy_inputs.values()), 99 | onnx_model_path, 100 | export_params=True, 101 | verbose=False, 102 | opset_version=opset, 103 | do_constant_folding=True, 104 | input_names=list(dummy_inputs.keys()), 105 | output_names=output_names, 106 | ) 107 | print(f"Generate image onnx model, and save in: {onnx_model_path}") 108 | 109 | def export_prompt_masks_model(model_type: str, checkpoint: str, opset: int): 110 | print("Loading model...") 111 | sam = sam_model_registry[model_type](checkpoint=checkpoint) 112 | 113 | onnx_model = SamOnnxModel( 114 | model=sam, 115 | return_single_mask=False, 116 | use_stability_score=False, 117 | return_extra_metrics=False, 118 | ) 119 | onnx_model_path = os.path.join("weights", "sam_" + model_type+"_"+"prompt_mask.onnx") 120 | 121 | dynamic_axes = { 122 | "point_coords": {1: "num_points"}, 123 | "point_labels": {1: "num_points"}, 124 | } 125 | 126 | embed_dim = sam.prompt_encoder.embed_dim 127 | embed_size = sam.prompt_encoder.image_embedding_size 128 | mask_input_size = [4 * x for x in embed_size] 129 | dummy_inputs = { 130 | "image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float), 131 | "point_coords": torch.randint(low=0, high=1024, size=(1, 5, 2), dtype=torch.float), 132 | "point_labels": torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float), 133 | "mask_input": torch.randn(1, 1, *mask_input_size, dtype=torch.float), 134 | "has_mask_input": torch.tensor([1], dtype=torch.float), 135 | # "orig_im_size": torch.tensor([1500, 
2250], dtype=torch.int32), 136 | } 137 | 138 | _ = onnx_model(**dummy_inputs) 139 | 140 | output_names = ["low_res_masks", "iou_predictions"] 141 | 142 | with warnings.catch_warnings(): 143 | warnings.filterwarnings("ignore", category=torch.jit.TracerWarning) 144 | warnings.filterwarnings("ignore", category=UserWarning) 145 | with open(onnx_model_path, "wb") as f: 146 | print(f"Exporting onnx model to {onnx_model_path}...") 147 | torch.onnx.export( 148 | onnx_model, 149 | tuple(dummy_inputs.values()), 150 | f, 151 | export_params=True, 152 | verbose=False, 153 | opset_version=opset, 154 | do_constant_folding=True, 155 | input_names=list(dummy_inputs.keys()), 156 | output_names=output_names, 157 | dynamic_axes=dynamic_axes, 158 | ) 159 | print(f"Generate prompt and masks onnx model, and save in: {onnx_model_path}") 160 | 161 | def export_prompt_model(gpu_id=1, model_type="default", sam_checkpoint="weights/sam_vit_h_4b8939.pth"): 162 | device = f"cuda:{gpu_id}" 163 | sam = sam_model_registry[model_type](checkpoint=sam_checkpoint) 164 | sam.to(device=device) 165 | 166 | os.makedirs("prompt_onnx", exist_ok=True) 167 | 168 | embed_size = sam.prompt_encoder.image_embedding_size 169 | mask_input_size = [4 * x for x in embed_size] 170 | onnx_model_path = os.path.join("prompt_onnx", "sam_" + model_type+"_"+"prompt.onnx") 171 | dynamic_axes = { 172 | "point_coords": {0: "num_points"}, 173 | "point_labels": {0: "num_points"}, 174 | "boxes": {0: "num_boxes"}, 175 | } 176 | points_coord = torch.randint(low=0, high=1024, size=(1, 1, 2), dtype=torch.float).to(device) 177 | points_label = torch.randint(low=0, high=4, size=(1, 1), dtype=torch.float).to(device) 178 | points = (points_coord, points_label) 179 | boxes = torch.randint(low=0, high=1024, size=(1, 1, 4), dtype=torch.int32).to(device) 180 | 181 | dummy_inputs = { 182 | "points": points, 183 | "boxes": boxes, 184 | "masks": torch.randn(1, 1, *mask_input_size, dtype=torch.float).to(device), 185 | } 186 | input_names = ["point_coords", "point_labels", "boxes", "mask_input"] 187 | 188 | output_names = ["sparse_embeddings", "dense_embeddings"] 189 | with warnings.catch_warnings(): 190 | warnings.filterwarnings("ignore", category=torch.jit.TracerWarning) 191 | warnings.filterwarnings("ignore", category=UserWarning) 192 | # with open(onnx_model_path, "wb") as f: 193 | torch.onnx.export( 194 | sam.prompt_encoder, 195 | tuple(dummy_inputs.values()), 196 | onnx_model_path, 197 | export_params=True, 198 | verbose=False, 199 | opset_version=17, 200 | do_constant_folding=True, 201 | input_names=input_names, 202 | output_names=output_names, 203 | dynamic_axes=dynamic_axes, 204 | ) 205 | print(f"Generate image onnx model, and save in: {onnx_model_path}") 206 | 207 | def export_masks_model(gpu_id=2, model_type="default", sam_checkpoint="weights/sam_vit_h_4b8939.pth"): 208 | device = f"cuda:{gpu_id}" 209 | sam = sam_model_registry[model_type](checkpoint=sam_checkpoint) 210 | sam.to(device=device) 211 | 212 | os.makedirs("masks_onnx", exist_ok=True) 213 | 214 | onnx_model_path = os.path.join("masks_onnx", "sam_" + model_type+"_"+"masks.onnx") 215 | dynamic_axes = { 216 | "sparse_embeddings": {1: "num_embedding"}, 217 | } 218 | sparse_embeddings = torch.randint(low=0, high=1024, size=(1, 2, 256), dtype=torch.float).to(device) 219 | dense_embeddings = torch.randint(low=0, high=1024, size=(1, 256, 64, 64), dtype=torch.float).to(device) 220 | image_embeddings = torch.randint(low=0, high=1024, size=(1, 256, 64, 64), dtype=torch.int32).to(device) 221 | image_pe = 
torch.randint(low=0, high=1024, size=(1, 256, 64, 64), dtype=torch.int32).to(device) 222 | multimask_output = torch.tensor([0], dtype=torch.float).to(device) 223 | 224 | dummy_inputs = { 225 | "image_embeddings": image_embeddings, 226 | "image_pe": image_pe, 227 | "sparse_embeddings": sparse_embeddings, 228 | "dense_embeddings": dense_embeddings, 229 | "multimask_output": multimask_output, 230 | } 231 | 232 | output_names = ["low_res_masks", "iou_predictions"] 233 | with warnings.catch_warnings(): 234 | warnings.filterwarnings("ignore", category=torch.jit.TracerWarning) 235 | warnings.filterwarnings("ignore", category=UserWarning) 236 | # with open(onnx_model_path, "wb") as f: 237 | torch.onnx.export( 238 | sam.mask_decoder, 239 | tuple(dummy_inputs.values()), 240 | onnx_model_path, 241 | export_params=True, 242 | verbose=False, 243 | opset_version=17, 244 | do_constant_folding=True, 245 | input_names=list(dummy_inputs.keys()), 246 | output_names=output_names, 247 | dynamic_axes=dynamic_axes, 248 | ) 249 | print(f"Generate image onnx model, and save in: {onnx_model_path}") 250 | 251 | def export_engine_image_encoder(f='vit_l_embedding.onnx', half=True): 252 | file = Path(f) 253 | f = file.with_suffix('.engine') # TensorRT engine file 254 | onnx = file.with_suffix('.onnx') 255 | logger = trt.Logger(trt.Logger.INFO) 256 | builder = trt.Builder(logger) 257 | config = builder.create_builder_config() 258 | workspace = 6 259 | print("workspace: ", workspace) 260 | config.max_workspace_size = workspace * 1 << 30 261 | flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)) 262 | network = builder.create_network(flag) 263 | parser = trt.OnnxParser(network, logger) 264 | if not parser.parse_from_file(str(onnx)): 265 | raise RuntimeError(f'failed to load ONNX file: {onnx}') 266 | 267 | inputs = [network.get_input(i) for i in range(network.num_inputs)] 268 | outputs = [network.get_output(i) for i in range(network.num_outputs)] 269 | for inp in inputs: 270 | print(f'input "{inp.name}" with shape{inp.shape} {inp.dtype}') 271 | for out in outputs: 272 | print(f'output "{out.name}" with shape{out.shape} {out.dtype}') 273 | 274 | print(f'building FP{16 if builder.platform_has_fast_fp16 and half else 32} engine as {f}') 275 | if builder.platform_has_fast_fp16 and half: 276 | config.set_flag(trt.BuilderFlag.FP16) 277 | 278 | with builder.build_engine(network, config) as engine, open(f, 'wb') as t: 279 | t.write(engine.serialize()) 280 | print(f"Generate image embedding trt model, save in: {f}") 281 | 282 | def export_engine_prompt_encoder_and_mask_decoder(f='sam_onnx_example.onnx', half=True): 283 | import tensorrt as trt 284 | from pathlib import Path 285 | file = Path(f) 286 | f = file.with_suffix('.engine') # TensorRT engine file 287 | onnx = file.with_suffix('.onnx') 288 | logger = trt.Logger(trt.Logger.INFO) 289 | builder = trt.Builder(logger) 290 | config = builder.create_builder_config() 291 | workspace = 6 292 | print("workspace: ", workspace) 293 | config.max_workspace_size = workspace * 1 << 30 294 | flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)) 295 | network = builder.create_network(flag) 296 | parser = trt.OnnxParser(network, logger) 297 | print(str(onnx)) 298 | if not parser.parse_from_file(str(onnx)): 299 | raise RuntimeError(f'failed to load ONNX file: {onnx}') 300 | 301 | inputs = [network.get_input(i) for i in range(network.num_inputs)] 302 | outputs = [network.get_output(i) for i in range(network.num_outputs)] 303 | for inp in inputs: 304 | print(f'input 
"{inp.name}" with shape{inp.shape} {inp.dtype}') 305 | for out in outputs: 306 | print(f'output "{out.name}" with shape{out.shape} {out.dtype}') 307 | 308 | profile = builder.create_optimization_profile() 309 | profile.set_shape('image_embeddings', (1, 256, 64, 64), (1, 256, 64, 64), (1, 256, 64, 64)) 310 | profile.set_shape('point_coords', (1, 2,2), (1, 5,2), (1,10,2)) 311 | profile.set_shape('point_labels', (1, 2), (1, 5), (1,10)) 312 | profile.set_shape('mask_input', (1, 1, 256, 256), (1, 1, 256, 256), (1, 1, 256, 256)) 313 | profile.set_shape('has_mask_input', (1,), (1, ), (1, )) 314 | # profile.set_shape_input('orig_im_size', (1200, 1800), (1200, 1800), (1200, 1800)) # Must be consistent with input 315 | config.add_optimization_profile(profile) 316 | 317 | print(f'building FP{16 if builder.platform_has_fast_fp16 and half else 32} engine as {f}') 318 | if builder.platform_has_fast_fp16 and half: 319 | config.set_flag(trt.BuilderFlag.FP16) 320 | with builder.build_engine(network, config) as engine, open(f, 'wb') as t: 321 | t.write(engine.serialize()) 322 | 323 | 324 | if __name__ == '__main__': 325 | parser = argparse.ArgumentParser("transform pth model to onnx, or transform onnx to tensorrt") 326 | parser.add_argument("--img_pt2onnx", action="store_true", help="transform image embedding pth from sam model to onnx") 327 | parser.add_argument("--sam_checkpoint", type=str, default="weights/sam_vit_h_4b8939.pth") 328 | parser.add_argument("--model_type", type=str, default="default") 329 | parser.add_argument("--prompt_masks_pt2onnx", action="store_true", help="whether export prompt encoder and masks decoder module") 330 | parser.add_argument("--img_onnx2trt", action="store_true", help="only transform image embedding onnx model to tensorrt engine") 331 | parser.add_argument("--img_onnx_model_path", type=str, default="embedding_onnx/sam_default_embedding.onnx") 332 | parser.add_argument("--sam_onnx2trt", action="store_true", help="only transform sam prompt and mask decoder onnx model to tensorrt engine") 333 | parser.add_argument("--sam_onnx_path", type=str, default="./weights/sam_vit_h_4b8939.onnx") 334 | parser.add_argument("--gpu_id", type=int, default=0, help="use which gpu to transform model") 335 | parser.add_argument("--opset", type=int, default=17, help="onnx opset version") 336 | args = parser.parse_args() 337 | 338 | with torch.no_grad(): 339 | if args.img_pt2onnx: 340 | export_embedding_model(args.gpu_id, args.model_type, args.sam_checkpoint, args.opset) 341 | if args.prompt_masks_pt2onnx: 342 | export_prompt_masks_model(args.model_type, args.sam_checkpoint, args.opset) 343 | if args.img_onnx2trt: 344 | export_engine_image_encoder(args.img_onnx_model_path, False) 345 | if args.sam_onnx2trt: 346 | export_engine_prompt_encoder_and_mask_decoder(args.sam_onnx_path) 347 | 348 | # just test split prompt encoder and masks decoder module 349 | # export_prompt_model() 350 | # export_masks_model() --------------------------------------------------------------------------------