├── .gitignore ├── .vscode ├── c_cpp_properties.json └── settings.json ├── CMakeLists.txt ├── Dockerfile.dev ├── LICENSE ├── README.md ├── README.md.bak ├── README_zh_windows.md ├── baseModel.h ├── commons ├── ThreadPool.h ├── buffers.h └── general.h ├── data ├── 2023-09-14 20-42-44.png └── truck.jpg ├── export.h ├── how_to_export_vim_h_model.ipynb ├── main.cpp ├── main_vim_h.cpp ├── sam.h ├── sam_utils.h ├── truck.gif ├── tutorials.ipynb └── tutorials_vim_h.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | # Prerequisites 2 | *.d 3 | 4 | # Compiled Object files 5 | *.slo 6 | *.lo 7 | *.o 8 | *.obj 9 | 10 | # Precompiled Headers 11 | *.gch 12 | *.pch 13 | 14 | # Compiled Dynamic libraries 15 | *.so 16 | *.dylib 17 | *.dll 18 | 19 | # Fortran module files 20 | *.mod 21 | *.smod 22 | 23 | # Compiled Static libraries 24 | *.lai 25 | *.la 26 | *.a 27 | *.lib 28 | 29 | # Executables 30 | *.exe 31 | *.out 32 | *.app 33 | 34 | build 35 | segment-anything 36 | data/*.onnx -------------------------------------------------------------------------------- /.vscode/c_cpp_properties.json: -------------------------------------------------------------------------------- 1 | { 2 | "configurations": [ 3 | { 4 | "name": "linux", 5 | "includePath": [ 6 | "/usr/local/libtorch/include", 7 | "/usr/local/libtorch/include/torch/csrc/api/include", 8 | "/usr/include/opencv4", 9 | "/usr/include/opencv4/opencv2" 10 | ], 11 | "browse": { 12 | "limitSymbolsToIncludedHeaders": true, 13 | "databaseFilename": "" 14 | }, 15 | "intelliSenseMode": "linux-gcc-x64", 16 | "compilerPath": "/usr/bin/gcc", 17 | "cStandard": "c17", 18 | "cppStandard": "gnu++14" 19 | } 20 | ], 21 | "version": 4 22 | } -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "cmake.configureSettings": { 3 | // "CMAKE_TOOLCHAIN_FILE": "C:/Users/77274/projects/vcpkg/scripts/buildsystems/vcpkg.cmake", 4 | // "yaml-cpp_DIR": "C:/Users/77274/projects/Dev/yaml-cpp/lib/cmake/yaml-cpp", 5 | // "Torch_DIR":"C:/Users/77274/projects/Dev/libtorch/share/cmake/Torch" 6 | "Torch_DIR":"/usr/local/libtorch/share/cmake/Torch" 7 | }, 8 | "files.associations": { 9 | "iostream": "cpp", 10 | "algorithm": "cpp", 11 | "array": "cpp", 12 | "atomic": "cpp", 13 | "bit": "cpp", 14 | "bitset": "cpp", 15 | "cctype": "cpp", 16 | "cfenv": "cpp", 17 | "charconv": "cpp", 18 | "chrono": "cpp", 19 | "clocale": "cpp", 20 | "cmath": "cpp", 21 | "compare": "cpp", 22 | "complex": "cpp", 23 | "concepts": "cpp", 24 | "condition_variable": "cpp", 25 | "csignal": "cpp", 26 | "cstdarg": "cpp", 27 | "cstddef": "cpp", 28 | "cstdint": "cpp", 29 | "cstdio": "cpp", 30 | "cstdlib": "cpp", 31 | "cstring": "cpp", 32 | "ctime": "cpp", 33 | "cwchar": "cpp", 34 | "cwctype": "cpp", 35 | "deque": "cpp", 36 | "exception": "cpp", 37 | "format": "cpp", 38 | "forward_list": "cpp", 39 | "fstream": "cpp", 40 | "functional": "cpp", 41 | "initializer_list": "cpp", 42 | "iomanip": "cpp", 43 | "ios": "cpp", 44 | "iosfwd": "cpp", 45 | "istream": "cpp", 46 | "iterator": "cpp", 47 | "limits": "cpp", 48 | "list": "cpp", 49 | "locale": "cpp", 50 | "map": "cpp", 51 | "memory": "cpp", 52 | "mutex": "cpp", 53 | "new": "cpp", 54 | "numeric": "cpp", 55 | "optional": "cpp", 56 | "ostream": "cpp", 57 | "queue": "cpp", 58 | "random": "cpp", 59 | "ratio": "cpp", 60 | "set": "cpp", 61 | "source_location": "cpp", 62 | "span": "cpp", 63 | 
"sstream": "cpp", 64 | "stdexcept": "cpp", 65 | "stop_token": "cpp", 66 | "streambuf": "cpp", 67 | "string": "cpp", 68 | "strstream": "cpp", 69 | "system_error": "cpp", 70 | "thread": "cpp", 71 | "tuple": "cpp", 72 | "type_traits": "cpp", 73 | "typeindex": "cpp", 74 | "typeinfo": "cpp", 75 | "unordered_map": "cpp", 76 | "unordered_set": "cpp", 77 | "utility": "cpp", 78 | "valarray": "cpp", 79 | "variant": "cpp", 80 | "vector": "cpp", 81 | "xfacet": "cpp", 82 | "xhash": "cpp", 83 | "xiosbase": "cpp", 84 | "xlocale": "cpp", 85 | "xlocbuf": "cpp", 86 | "xlocinfo": "cpp", 87 | "xlocmes": "cpp", 88 | "xlocmon": "cpp", 89 | "xlocnum": "cpp", 90 | "xloctime": "cpp", 91 | "xmemory": "cpp", 92 | "xstddef": "cpp", 93 | "xstring": "cpp", 94 | "xtr1common": "cpp", 95 | "xtree": "cpp", 96 | "xutility": "cpp", 97 | "cinttypes": "cpp", 98 | "filesystem": "cpp", 99 | "shared_mutex": "cpp", 100 | "stack": "cpp", 101 | "__nullptr": "cpp", 102 | "__config": "cpp", 103 | "*.tcc": "cpp", 104 | "codecvt": "cpp", 105 | "memory_resource": "cpp", 106 | "string_view": "cpp", 107 | "future": "cpp" 108 | }, 109 | "cmake.configureOnOpen": true, 110 | "C_Cpp.errorSquiggles": "disabled" 111 | } -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.14.0 FATAL_ERROR) 2 | project(Detections) 3 | 4 | # add_compile_options(-fno-elide-constructors) 5 | # set(CMAKE_CXX_FLAGS "-fno-elide-constructors ${CMAKE_CXX_FLAGS}") 6 | # set(ROOT_DIRS "C:/Users/77274/projects") 7 | # set(ROOT_DIRS "D:/projects") 8 | # set( CMAKE_CXX_COMPILER /usr/bin/g++ ) 9 | # set(CMAKE_CXX_FLAGS "-fno-elide-constructors ${CMAKE_CXX_FLAGS}") 10 | set(Torch_DIR "/usr/local/libtorch/share/cmake/Torch") 11 | find_package(Torch REQUIRED) 12 | # include_directories(${TORCH_INCLUDE_DIRS}) # Not needed for CMake >= 2.8.11 13 | 14 | 15 | # set(TorchVision_DIR "${ROOT_DIRS}/3rdpty/torchvision/share/cmake/TorchVision") 16 | # find_package(TorchVision REQUIRED) 17 | # include_directories(${TorchVision_INCLUDE_DIR}) # Not needed for CMake >= 2.8.11 18 | # link_directories("../3rdpty/torchvision/lib") 19 | 20 | find_package(OpenCV REQUIRED) 21 | 22 | set(SAMPLE_SOURCES main.cpp) 23 | set(TARGET_NAME sampleSAM) 24 | 25 | set(SAMPLE_DEP_LIBS 26 | nvinfer 27 | nvonnxparser 28 | ) 29 | 30 | # commons 31 | include_directories("./commons") 32 | 33 | add_executable(${TARGET_NAME} ${SAMPLE_SOURCES}) 34 | target_link_libraries(${TARGET_NAME} ${TORCH_LIBRARIES}) 35 | target_link_libraries(${TARGET_NAME} ${OpenCV_LIBS}) 36 | # target_link_libraries(${TARGET_NAME} TorchVision::TorchVision) 37 | target_link_libraries(${TARGET_NAME} ${SAMPLE_DEP_LIBS}) 38 | set_property(TARGET ${TARGET_NAME} PROPERTY CXX_STANDARD 17) 39 | 40 | 41 | set(SAMPLE_SOURCES main_vim_h.cpp) 42 | set(TARGET_NAME sampleSAM2) 43 | add_executable(${TARGET_NAME} ${SAMPLE_SOURCES}) 44 | target_link_libraries(${TARGET_NAME} ${TORCH_LIBRARIES}) 45 | target_link_libraries(${TARGET_NAME} ${OpenCV_LIBS}) 46 | # target_link_libraries(${TARGET_NAME} TorchVision::TorchVision) 47 | target_link_libraries(${TARGET_NAME} ${SAMPLE_DEP_LIBS}) 48 | set_property(TARGET ${TARGET_NAME} PROPERTY CXX_STANDARD 17) -------------------------------------------------------------------------------- /Dockerfile.dev: -------------------------------------------------------------------------------- 1 | #ARG os=ubuntu2204 tag=8.5.0-cuda-11.7 2 | #FROM 
nvidia/cuda:11.7.0-cudnn8-devel-ubuntu22.04 3 | #RUN dpkg -i nv-tensorrt-local-repo-${os}-${tag}_1.0-1_amd64.deb \ 4 | # && cp /var/nv-tensorrt-local-repo-${os}-${tag}/*-keyring.gpg /usr/share/keyrings/ \ 5 | # && apt-get update 6 | 7 | FROM nvcr.io/nvidia/tensorrt:22.10-py3 8 | 9 | # install libtorch 10 | RUN cd ~/ && wget https://download.pytorch.org/libtorch/cu118/libtorch-cxx11-abi-shared-with-deps-2.0.0%2Bcu118.zip \ 11 | && unzip libtorch-cxx11-abi-shared-with-deps-2.0.0+cu118.zip -d /usr/local/ \ 12 | && rm -f libtorch-cxx11-abi-shared-with-deps-2.0.0+cu118.zip 13 | 14 | # install opencv 15 | RUN apt update && apt install libopencv-dev -y 16 | 17 | # install python essential dependencies 18 | RUN pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 19 | RUN pip3 install numpy pillow matplotlib pycocotools opencv-python onnx onnxruntime -i https://pypi.tuna.tsinghua.edu.cn/simple -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 mingj2021 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Overview 2 | [中文-windows](README_zh_windows.md) 3 |

4 | 5 |

6 | 
7 | Download the [converted ONNX models](https://drive.google.com/drive/folders/1xZ7XpyKGx0Fg-t81SLND56vC2TEQ4otN?usp=sharing).
8 | # tutorials
9 | ```
10 | The vim_h TensorRT model was tested on a GeForce RTX 3060 Ti; tests pass.
11 | tutorials.ipynb
12 | tutorials_vim_h.ipynb
13 | 
14 | ```
15 | 
16 | # how to export vim_h model
17 | how_to_export_vim_h_model.ipynb
18 | ```
19 | # Divide the embedding model into 2 parts, named part1 and part2
20 | # The decoder model remains unchanged
21 | # export_embedding_model_part_1()
22 | # export_embedding_model_part_2()
23 | # export_sam_h_model()
24 | # onnx_model_example2()
25 | ```
26 | # FP32 or FP16
27 | ```
28 | You can leave the FP16 flag in export.h unset if the GPU has enough memory; the engines are then built in FP32.
29 | ```
--------------------------------------------------------------------------------
/README.md.bak:
--------------------------------------------------------------------------------
1 | # Overview
2 | [中文-windows](README_zh_CN.md)
3 | 

4 | 5 |

6 | 7 | The repository helps you quickly deploy segment-anything to real-world applications, such as auto annotation,etc.The original model is divided into two submodels, one for embedding and the other for prompting and mask decoder. 8 | # Table of Contents 9 | - [Overview](#Overview) 10 | - [Table of Contents](#Table-of-Contents) 11 | - [Getting Started](#getting-started) 12 | - [Quick Start: Windows](#quick-start-windows) 13 | - [Quick Start: Ubuntu](#quick-start-ubuntu) 14 | - [Onnx Export](#onnx-export) 15 | - [Image Encoder](#export-embedding-onnx-model) 16 | - [Prompt Encoder and mask decoder](#export-prompt-encoder-mask-decoer-onnx-model) 17 | - [Test Exported-onnx models](#test-exported-onnx-models) 18 | - [Engine Export](#engine-export) 19 | - [Image Encoder](#convert-image-encoder-onnx-to-engine-model) 20 | - [Prompt Encoder and mask decoder](#convert-prompt-encoder-and-mask-decoder-onnx-to-engine-model) 21 | - [Quantification]() 22 | - [TensorRT Inferring]() 23 | - [Preprocess]() 24 | - [Postprocess]() 25 | - [build]() 26 | - [Examples]() 27 | 28 | # Getting Started 29 | Prerequisites: 30 | - [OpenCV](https://github.com/opencv/opencv) 31 | - [Libtorch](https://pytorch.org/) 32 | - [Torchvision](https://github.com/pytorch/vision) 33 | - [Tensorrt](https://developer.nvidia.com/tensorrt) 34 | ## Quick Start: Windows 35 | ``` 36 | # create conda virtual env 37 | conda create -n segment-anything python=3.8 38 | # activate this environment 39 | conda activate segment-anything 40 | pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu117 41 | pip install opencv-python pycocotools matplotlib onnxruntime onnx 42 | # download TensorRT-8.5.1.7.Windows10.x86_64.cuda-x.x.zip && unzip 43 | pip install ./tensorrt-8.5.1.7-cp38-none-win_amd64.whl 44 | # tensorrt tool: PyTorch-Quantization 45 | pip install pytorch-quantization --extra-index-url https://pypi.ngc.nvidia.com 46 | ``` 47 | 48 | ## Quick Start: Ubuntu 49 | ``` 50 | # create virtual env 51 | git clone https://github.com/mingj2021/segment-anything-tensorrt.git 52 | cd segment-anything-tensorrt 53 | docker build -t dev:ml -f ./Dockerfile.dev . 54 | git clone https://github.com/facebookresearch/segment-anything.git 55 | cd segment-anything 56 | mkdir weights && cd weights 57 | wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth 58 | cd ../ 59 | sudo mv ../demo.py . 60 | # run container 61 | docker run -it --rm --gpus all -v $(yourdirs)/segment-anything-tensorrt:/workspace/segment-anything-tensorrt dev:ml 62 | ``` 63 | 64 | # Onnx Export 65 | definition && import dependencies 66 | ``` 67 | import torch 68 | import torch.nn as nn 69 | from torch.nn import functional as F 70 | from segment_anything.modeling import Sam 71 | import numpy as np 72 | from torchvision.transforms.functional import resize, to_pil_image 73 | from typing import Tuple 74 | from segment_anything import sam_model_registry, SamPredictor 75 | import cv2 76 | import matplotlib.pyplot as plt 77 | import warnings 78 | import os 79 | os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" 80 | import onnxruntime 81 | 82 | def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]: 83 | """ 84 | Compute the output size given input size and target long side length. 
85 | """ 86 | scale = long_side_length * 1.0 / max(oldh, oldw) 87 | newh, neww = oldh * scale, oldw * scale 88 | neww = int(neww + 0.5) 89 | newh = int(newh + 0.5) 90 | return (newh, neww) 91 | 92 | # @torch.no_grad() 93 | def pre_processing(image: np.ndarray, target_length: int, device,pixel_mean,pixel_std,img_size): 94 | target_size = get_preprocess_shape(image.shape[0], image.shape[1], target_length) 95 | input_image = np.array(resize(to_pil_image(image), target_size)) 96 | input_image_torch = torch.as_tensor(input_image, device=device) 97 | input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :] 98 | 99 | # Normalize colors 100 | input_image_torch = (input_image_torch - pixel_mean) / pixel_std 101 | 102 | # Pad 103 | h, w = input_image_torch.shape[-2:] 104 | padh = img_size - h 105 | padw = img_size - w 106 | input_image_torch = F.pad(input_image_torch, (0, padw, 0, padh)) 107 | return input_image_torch 108 | 109 | ``` 110 | ## export embedding-onnx model 111 | ``` 112 | def export_embedding_model(): 113 | sam_checkpoint = "weights/sam_vit_l_0b3195.pth" 114 | model_type = "vit_l" 115 | 116 | device = "cpu" 117 | 118 | sam = sam_model_registry[model_type](checkpoint=sam_checkpoint) 119 | sam.to(device=device) 120 | 121 | # image_encoder = EmbeddingOnnxModel(sam) 122 | image = cv2.imread('notebooks/images/truck.jpg') 123 | target_length = sam.image_encoder.img_size 124 | pixel_mean = sam.pixel_mean 125 | pixel_std = sam.pixel_std 126 | img_size = sam.image_encoder.img_size 127 | inputs = pre_processing(image, target_length, device,pixel_mean,pixel_std,img_size) 128 | onnx_model_path = model_type+"_"+"embedding.onnx" 129 | dummy_inputs = { 130 | "images": inputs 131 | } 132 | output_names = ["image_embeddings"] 133 | image_embeddings = sam.image_encoder(inputs).cpu().numpy() 134 | print('image_embeddings', image_embeddings.shape) 135 | with warnings.catch_warnings(): 136 | warnings.filterwarnings("ignore", category=torch.jit.TracerWarning) 137 | warnings.filterwarnings("ignore", category=UserWarning) 138 | with open(onnx_model_path, "wb") as f: 139 | torch.onnx.export( 140 | sam.image_encoder, 141 | tuple(dummy_inputs.values()), 142 | f, 143 | export_params=True, 144 | verbose=False, 145 | opset_version=13, 146 | do_constant_folding=True, 147 | input_names=list(dummy_inputs.keys()), 148 | output_names=output_names, 149 | # dynamic_axes=dynamic_axes, 150 | ) 151 | with torch.no_grad(): 152 | export_embedding_model() 153 | ``` 154 | ## export prompt-encoder-mask-decoer-onnx model 155 | change "forward" function in the file which is "segment_anything/utils/onnx.py",as follows: 156 | ``` 157 | def forward( 158 | self, 159 | image_embeddings: torch.Tensor, 160 | point_coords: torch.Tensor, 161 | point_labels: torch.Tensor, 162 | mask_input: torch.Tensor, 163 | has_mask_input: torch.Tensor 164 | # orig_im_size: torch.Tensor, 165 | ): 166 | sparse_embedding = self._embed_points(point_coords, point_labels) 167 | dense_embedding = self._embed_masks(mask_input, has_mask_input) 168 | 169 | masks, scores = self.model.mask_decoder.predict_masks( 170 | image_embeddings=image_embeddings, 171 | image_pe=self.model.prompt_encoder.get_dense_pe(), 172 | sparse_prompt_embeddings=sparse_embedding, 173 | dense_prompt_embeddings=dense_embedding, 174 | ) 175 | 176 | if self.use_stability_score: 177 | scores = calculate_stability_score( 178 | masks, self.model.mask_threshold, self.stability_score_offset 179 | ) 180 | 181 | if self.return_single_mask: 182 | masks, scores = 
self.select_masks(masks, scores, point_coords.shape[1]) 183 | 184 | return masks, scores 185 | # upscaled_masks = self.mask_postprocessing(masks, orig_im_size) 186 | 187 | # if self.return_extra_metrics: 188 | # stability_scores = calculate_stability_score( 189 | # upscaled_masks, self.model.mask_threshold, self.stability_score_offset 190 | # ) 191 | # areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1) 192 | # return upscaled_masks, scores, stability_scores, areas, masks 193 | 194 | # return upscaled_masks, scores, masks 195 | ``` 196 | 197 | ``` 198 | def export_sam_model(): 199 | from segment_anything.utils.onnx import SamOnnxModel 200 | checkpoint = "weights/sam_vit_l_0b3195.pth" 201 | model_type = "vit_l" 202 | sam = sam_model_registry[model_type](checkpoint=checkpoint) 203 | onnx_model_path = "sam_onnx_example.onnx" 204 | 205 | onnx_model = SamOnnxModel(sam, return_single_mask=True) 206 | 207 | dynamic_axes = { 208 | "point_coords": {1: "num_points"}, 209 | "point_labels": {1: "num_points"}, 210 | } 211 | 212 | embed_dim = sam.prompt_encoder.embed_dim 213 | embed_size = sam.prompt_encoder.image_embedding_size 214 | mask_input_size = [4 * x for x in embed_size] 215 | dummy_inputs = { 216 | "image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float), 217 | "point_coords": torch.randint(low=0, high=1024, size=(1, 5, 2), dtype=torch.float), 218 | "point_labels": torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float), 219 | "mask_input": torch.randn(1, 1, *mask_input_size, dtype=torch.float), 220 | "has_mask_input": torch.tensor([1], dtype=torch.float), 221 | # "orig_im_size": torch.tensor([1500, 2250], dtype=torch.float), 222 | } 223 | # output_names = ["masks", "iou_predictions", "low_res_masks"] 224 | output_names = ["masks", "scores"] 225 | 226 | with warnings.catch_warnings(): 227 | warnings.filterwarnings("ignore", category=torch.jit.TracerWarning) 228 | warnings.filterwarnings("ignore", category=UserWarning) 229 | with open(onnx_model_path, "wb") as f: 230 | torch.onnx.export( 231 | onnx_model, 232 | tuple(dummy_inputs.values()), 233 | f, 234 | export_params=True, 235 | verbose=False, 236 | opset_version=13, 237 | do_constant_folding=True, 238 | input_names=list(dummy_inputs.keys()), 239 | output_names=output_names, 240 | dynamic_axes=dynamic_axes, 241 | ) 242 | 243 | with torch.no_grad(): 244 | export_sam_model() 245 | ``` 246 | ## test exported-onnx models 247 | ``` 248 | def show_mask(mask, ax): 249 | color = np.array([30/255, 144/255, 255/255, 0.6]) 250 | h, w = mask.shape[-2:] 251 | mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1) 252 | ax.imshow(mask_image) 253 | 254 | def show_points(coords, labels, ax, marker_size=375): 255 | pos_points = coords[labels==1] 256 | neg_points = coords[labels==0] 257 | ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25) 258 | ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25) 259 | 260 | def show_box(box, ax): 261 | x0, y0 = box[0], box[1] 262 | w, h = box[2] - box[0], box[3] - box[1] 263 | ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2)) 264 | 265 | def onnx_model_example(): 266 | import os 267 | ort_session_embedding = onnxruntime.InferenceSession('vit_l_embedding.onnx',providers=['CPUExecutionProvider']) 268 | ort_session_sam = 
onnxruntime.InferenceSession('sam_onnx_example.onnx',providers=['CPUExecutionProvider']) 269 | 270 | image = cv2.imread('notebooks/images/truck.jpg') 271 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 272 | image2 = image.copy() 273 | prompt_embed_dim = 256 274 | image_size = 1024 275 | vit_patch_size = 16 276 | image_embedding_size = image_size // vit_patch_size 277 | target_length = image_size 278 | pixel_mean=[123.675, 116.28, 103.53], 279 | pixel_std=[58.395, 57.12, 57.375] 280 | pixel_mean = torch.Tensor(pixel_mean).view(-1, 1, 1) 281 | pixel_std = torch.Tensor(pixel_std).view(-1, 1, 1) 282 | device = "cpu" 283 | inputs = pre_processing(image, target_length, device,pixel_mean,pixel_std,image_size) 284 | ort_inputs = { 285 | "images": inputs.cpu().numpy() 286 | } 287 | image_embeddings = ort_session_embedding.run(None, ort_inputs)[0] 288 | 289 | input_point = np.array([[500, 375]]) 290 | input_label = np.array([1]) 291 | 292 | onnx_coord = np.concatenate([input_point, np.array([[0.0, 0.0]])], axis=0)[None, :, :] 293 | onnx_label = np.concatenate([input_label, np.array([-1])], axis=0)[None, :].astype(np.float32) 294 | from segment_anything.utils.transforms import ResizeLongestSide 295 | transf = ResizeLongestSide(image_size) 296 | onnx_coord = transf.apply_coords(onnx_coord, image.shape[:2]).astype(np.float32) 297 | onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32) 298 | onnx_has_mask_input = np.zeros(1, dtype=np.float32) 299 | 300 | ort_inputs = { 301 | "image_embeddings": image_embeddings, 302 | "point_coords": onnx_coord, 303 | "point_labels": onnx_label, 304 | "mask_input": onnx_mask_input, 305 | "has_mask_input": onnx_has_mask_input, 306 | # "orig_im_size": np.array(image.shape[:2], dtype=np.float32) 307 | } 308 | 309 | masks, _ = ort_session_sam.run(None, ort_inputs) 310 | 311 | from segment_anything.utils.onnx import SamOnnxModel 312 | checkpoint = "weights/sam_vit_l_0b3195.pth" 313 | model_type = "vit_l" 314 | sam = sam_model_registry[model_type](checkpoint=checkpoint) 315 | onnx_model_path = "sam_onnx_example.onnx" 316 | 317 | onnx_model = SamOnnxModel(sam, return_single_mask=True) 318 | masks = onnx_model.mask_postprocessing(torch.as_tensor(masks), torch.as_tensor(image.shape[:2])) 319 | masks = masks > 0.0 320 | plt.figure(figsize=(10, 10)) 321 | plt.imshow(image) 322 | show_mask(masks, plt.gca()) 323 | # show_box(input_box, plt.gca()) 324 | show_points(input_point, input_label, plt.gca()) 325 | plt.axis('off') 326 | plt.savefig('demo.png') 327 | 328 | with torch.no_grad(): 329 | onnx_model_example() 330 | ``` 331 | # Engine Export 332 | ## convert image-encoder-onnx to engine model 333 | ``` 334 | def export_engine_image_encoder(f='vit_l_embedding.onnx'): 335 | import tensorrt as trt 336 | from pathlib import Path 337 | file = Path(f) 338 | f = file.with_suffix('.engine') # TensorRT engine file 339 | onnx = file.with_suffix('.onnx') 340 | logger = trt.Logger(trt.Logger.INFO) 341 | builder = trt.Builder(logger) 342 | config = builder.create_builder_config() 343 | workspace = 6 344 | print("workspace: ", workspace) 345 | config.max_workspace_size = workspace * 1 << 30 346 | flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)) 347 | network = builder.create_network(flag) 348 | parser = trt.OnnxParser(network, logger) 349 | if not parser.parse_from_file(str(onnx)): 350 | raise RuntimeError(f'failed to load ONNX file: {onnx}') 351 | 352 | inputs = [network.get_input(i) for i in range(network.num_inputs)] 353 | outputs = [network.get_output(i) for i 
in range(network.num_outputs)]
354 |     for inp in inputs:
355 |         print(f'input "{inp.name}" with shape{inp.shape} {inp.dtype}')
356 |     for out in outputs:
357 |         print(f'output "{out.name}" with shape{out.shape} {out.dtype}')
358 | 
359 |     half = True
360 |     print(f'building FP{16 if builder.platform_has_fast_fp16 and half else 32} engine as {f}')
361 |     if builder.platform_has_fast_fp16 and half:
362 |         config.set_flag(trt.BuilderFlag.FP16)
363 |     with builder.build_engine(network, config) as engine, open(f, 'wb') as t:
364 |         t.write(engine.serialize())
365 | with torch.no_grad():
366 |     export_engine_image_encoder('vit_l_embedding.onnx')
367 | ```
368 | 
369 | ## convert prompt-encoder-and-mask-decoder-onnx to engine model
370 | ```
371 | def export_engine_prompt_encoder_and_mask_decoder(f='sam_onnx_example.onnx'):
372 |     import tensorrt as trt
373 |     from pathlib import Path
374 |     file = Path(f)
375 |     f = file.with_suffix('.engine')  # TensorRT engine file
376 |     onnx = file.with_suffix('.onnx')
377 |     logger = trt.Logger(trt.Logger.INFO)
378 |     builder = trt.Builder(logger)
379 |     config = builder.create_builder_config()
380 |     workspace = 10
381 |     print("workspace: ", workspace)
382 |     config.max_workspace_size = workspace * 1 << 30
383 |     flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
384 |     network = builder.create_network(flag)
385 |     parser = trt.OnnxParser(network, logger)
386 |     if not parser.parse_from_file(str(onnx)):
387 |         raise RuntimeError(f'failed to load ONNX file: {onnx}')
388 | 
389 |     inputs = [network.get_input(i) for i in range(network.num_inputs)]
390 |     outputs = [network.get_output(i) for i in range(network.num_outputs)]
391 |     for inp in inputs:
392 |         print(f'input "{inp.name}" with shape{inp.shape} {inp.dtype}')
393 |     for out in outputs:
394 |         print(f'output "{out.name}" with shape{out.shape} {out.dtype}')
395 | 
396 |     profile = builder.create_optimization_profile()
397 |     profile.set_shape('image_embeddings', (1, 256, 64, 64), (1, 256, 64, 64), (1, 256, 64, 64))
398 |     profile.set_shape('point_coords', (1, 2, 2), (1, 5, 2), (1, 10, 2))
399 |     profile.set_shape('point_labels', (1, 2), (1, 5), (1, 10))
400 |     profile.set_shape('mask_input', (1, 1, 256, 256), (1, 1, 256, 256), (1, 1, 256, 256))
401 |     profile.set_shape('has_mask_input', (1,), (1,), (1,))
402 |     # # profile.set_shape_input('orig_im_size', (416,416), (1024,1024), (1500, 2250))
403 |     # profile.set_shape_input('orig_im_size', (2,), (2,), (2,))
404 |     config.add_optimization_profile(profile)
405 | 
406 |     half = True
407 |     print(f'building FP{16 if builder.platform_has_fast_fp16 and half else 32} engine as {f}')
408 |     if builder.platform_has_fast_fp16 and half:
409 |         config.set_flag(trt.BuilderFlag.FP16)
410 |     with builder.build_engine(network, config) as engine, open(f, 'wb') as t:
411 |         t.write(engine.serialize())
412 | with torch.no_grad():
413 |     export_engine_prompt_encoder_and_mask_decoder('sam_onnx_example.onnx')
414 | ```
415 | # TensorRT Inferring
416 | - export the two engine models as described above
417 | - open main.cpp and choose the action (show a window, or generate a file containing the embedding features) by defining one of the variables [SAMPROMPTENCODERANDMASKDECODER or EMBEDDING]
418 | 
419 | ## Preprocess
420 | Resizes, pads and normalizes the input image.
421 | ```
422 | ```
423 | ## Postprocess
424 | Processes and plots the generated mask.
425 | ```
426 | ```
427 | ## build
428 | ```
429 | mkdir build && cd build
430 | # modify main.cpp
431 | cmake ..
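# note: Torch_DIR is hard-coded in CMakeLists.txt as /usr/local/libtorch/share/cmake/Torch
# (the path used by Dockerfile.dev); if libtorch is installed elsewhere, edit that
# set(Torch_DIR ...) line before running cmake.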
432 | make -j10
433 | ```
--------------------------------------------------------------------------------
/README_zh_windows.md:
--------------------------------------------------------------------------------
1 | # Overview
2 | [english](README.md)
3 | 
4 | # cudatoolkit installation
5 | Run cuda_11.7.0_516.01_windows.exe and accept the default options throughout.
6 | 
7 | # cudnn installation
8 | Unzip cudnn-windows-x86_64-8.6.0.163_cuda11-archive.zip and copy the extracted folders [lib, bin, include] into the corresponding CUDA installation path [C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7].
9 | 
10 | # libtorch installation
11 | Download [https://download.pytorch.org/libtorch/cu117/libtorch-win-shared-with-deps-2.0.1%2Bcu117.zip] and unzip it.
12 | 
13 | # torchvision installation
14 | ```
15 | git clone https://github.com/pytorch/vision.git
16 | mkdir build
17 | cd build
18 | cmake .. -DWITH_CUDA=on
19 | cmake --build . --config release --target install
20 | ```
21 | 
22 | # tensorrt installation
23 | Download [TensorRT-8.5.1.7.Windows10.x86_64.cuda-x.x.zip] and unzip it.
24 | 
25 | # opencv installation
26 | Download [https://github.com/opencv/opencv/releases/download/4.7.0/opencv-4.7.0-windows.exe] and install it with the default options.
27 | 
28 | 
--------------------------------------------------------------------------------
/baseModel.h:
--------------------------------------------------------------------------------
1 | #ifndef BASEMODEL_H
2 | #define BASEMODEL_H
3 | 
4 | #include "sam_utils.h"
5 | 
6 | class BaseModel
7 | {
8 | public:
9 |     cudaStream_t stream;
10 |     std::shared_ptr<nvinfer1::ICudaEngine> mEngine;
11 |     std::unique_ptr<nvinfer1::IExecutionContext> context;
12 | 
13 |     std::vector<void *> mDeviceBindings;
14 |     std::map<std::string, std::unique_ptr<algorithms::DeviceBuffer>> mInOut;
15 |     std::vector<std::string> mInputsName, mOutputsName;
16 | 
17 | public:
18 |     BaseModel(std::string modelFile);
19 |     ~BaseModel();
20 |     void read_engine_file(std::string modelFile);
21 |     // std::vector<std::string> get_inputs_name();
22 |     // std::vector<std::string> get_outputs_name();
23 |     // std::vector<void *> get_device_buffer();
24 | };
25 | 
26 | BaseModel::BaseModel(std::string modelFile)
27 | {
28 |     read_engine_file(modelFile);
29 |     context = std::unique_ptr<nvinfer1::IExecutionContext>(mEngine->createExecutionContext());
30 |     if (!context)
31 |     {
32 |         std::cerr << "create context error" << std::endl;
33 |     }
34 | 
35 |     CHECK(cudaStreamCreate(&stream));
36 | 
37 |     for (int i = 0; i < mEngine->getNbBindings(); i++)
38 |     {
39 |         auto dims = mEngine->getBindingDimensions(i);
40 |         auto tensor_name = mEngine->getBindingName(i);
41 |         bool isInput = mEngine->bindingIsInput(i);
42 |         if (isInput)
43 |             mInputsName.emplace_back(tensor_name);
44 |         else
45 |             mOutputsName.emplace_back(tensor_name);
46 |         std::cout << "tensor_name: " << tensor_name << std::endl;
47 |         dims2str(dims);
48 |         nvinfer1::DataType type = mEngine->getBindingDataType(i);
49 |         index2srt(type);
50 |         auto vol = std::accumulate(dims.d, dims.d + dims.nbDims, int64_t{1}, std::multiplies{});
51 |         std::unique_ptr<algorithms::DeviceBuffer> device_buffer{new algorithms::DeviceBuffer(vol, type)};
52 |         mDeviceBindings.emplace_back(device_buffer->data());
53 |         mInOut[tensor_name] = std::move(device_buffer);
54 |     }
55 | }
56 | 
57 | BaseModel::~BaseModel()
58 | {
59 |     CHECK(cudaStreamDestroy(stream));
60 | }
61 | 
62 | void BaseModel::read_engine_file(std::string modelFile)
63 | {
64 |     std::ifstream engineFile(modelFile.c_str(), std::ifstream::binary);
65 |     assert(engineFile);
66 | 
67 |     int fsize;
68 |     engineFile.seekg(0, engineFile.end);
69 |     fsize = engineFile.tellg();
70 |     engineFile.seekg(0, engineFile.beg);
71 |     std::vector<char> engineData(fsize);
72 |     engineFile.read(engineData.data(), fsize);
73 | 
74 |     if (engineFile)
75 |         std::cout << "all characters read successfully." << std::endl;
76 |     else
77 |         std::cout << "error: only " << engineFile.gcount() << " could be read" << std::endl;
78 |     engineFile.close();
79 | 
80 |     std::unique_ptr<nvinfer1::IRuntime> runtime(nvinfer1::createInferRuntime(logger));
81 |     mEngine = std::shared_ptr<nvinfer1::ICudaEngine>(runtime->deserializeCudaEngine(engineData.data(), fsize, nullptr));
82 | }
83 | 
84 | #endif
--------------------------------------------------------------------------------
/commons/ThreadPool.h:
--------------------------------------------------------------------------------
1 | /*
2 | Copyright (c) 2012 Jakob Progsch, Václav Zeman
3 | 
4 | This software is provided 'as-is', without any express or implied
5 | warranty. In no event will the authors be held liable for any damages
6 | arising from the use of this software.
7 | 
8 | Permission is granted to anyone to use this software for any purpose,
9 | including commercial applications, and to alter it and redistribute it
10 | freely, subject to the following restrictions:
11 | 
12 | 1. The origin of this software must not be misrepresented; you must not
13 | claim that you wrote the original software. If you use this software
14 | in a product, an acknowledgment in the product documentation would be
15 | appreciated but is not required.
16 | 
17 | 2. Altered source versions must be plainly marked as such, and must not be
18 | misrepresented as being the original software.
19 | 
20 | 3. This notice may not be removed or altered from any source
21 | distribution.
22 | */
23 | 
24 | #ifndef THREAD_POOL_H
25 | #define THREAD_POOL_H
26 | 
27 | #include <vector>
28 | #include <queue>
29 | #include <memory>
30 | #include <thread>
31 | #include <mutex>
32 | #include <condition_variable>
33 | #include <future>
34 | #include <functional>
35 | #include <stdexcept>
36 | 
37 | class ThreadPool {
38 | public:
39 |     ThreadPool(size_t);
40 |     template<class F, class... Args>
41 |     auto enqueue(F&& f, Args&&... args)
42 |         -> std::future<typename std::result_of<F(Args...)>::type>;
43 |     ~ThreadPool();
44 | private:
45 |     // need to keep track of threads so we can join them
46 |     std::vector< std::thread > workers;
47 |     // the task queue
48 |     std::queue< std::function<void()> > tasks;
49 | 
50 |     // synchronization
51 |     std::mutex queue_mutex;
52 |     std::condition_variable condition;
53 |     bool stop;
54 | };
55 | 
56 | // the constructor just launches some amount of workers
57 | inline ThreadPool::ThreadPool(size_t threads)
58 |     : stop(false)
59 | {
60 |     for(size_t i = 0;i<threads;++i)
61 |         workers.emplace_back(
62 |             [this]
63 |             {
64 |                 for(;;)
65 |                 {
66 |                     std::function<void()> task;
67 | 
68 |                     {
69 |                         std::unique_lock<std::mutex> lock(this->queue_mutex);
70 |                         this->condition.wait(lock,
71 |                             [this]{ return this->stop || !this->tasks.empty(); });
72 |                         if(this->stop && this->tasks.empty())
73 |                             return;
74 |                         task = std::move(this->tasks.front());
75 |                         this->tasks.pop();
76 |                     }
77 | 
78 |                     task();
79 |                 }
80 |             }
81 |         );
82 | }
83 | 
84 | // add new work item to the pool
85 | template<class F, class... Args>
86 | auto ThreadPool::enqueue(F&& f, Args&&... args)
87 |     -> std::future<typename std::result_of<F(Args...)>::type>
88 | {
89 |     using return_type = typename std::result_of<F(Args...)>::type;
90 | 
91 |     auto task = std::make_shared< std::packaged_task<return_type()> >(
92 |         std::bind(std::forward<F>(f), std::forward<Args>(args)...)
93 | ); 94 | 95 | std::future res = task->get_future(); 96 | { 97 | std::unique_lock lock(queue_mutex); 98 | 99 | // don't allow enqueueing after stopping the pool 100 | if(stop) 101 | throw std::runtime_error("enqueue on stopped ThreadPool"); 102 | 103 | tasks.emplace([task](){ (*task)(); }); 104 | } 105 | condition.notify_one(); 106 | return res; 107 | } 108 | 109 | // the destructor joins all threads 110 | inline ThreadPool::~ThreadPool() 111 | { 112 | { 113 | std::unique_lock lock(queue_mutex); 114 | stop = true; 115 | } 116 | condition.notify_all(); 117 | for(std::thread &worker: workers) 118 | worker.join(); 119 | } 120 | 121 | #endif 122 | -------------------------------------------------------------------------------- /commons/buffers.h: -------------------------------------------------------------------------------- 1 | #ifndef BUFFERS_H 2 | #define BUFFERS_H 3 | 4 | #include "general.h" 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | 17 | namespace algorithms 18 | { 19 | 20 | template 21 | class GenericBuffer 22 | { 23 | public: 24 | //! 25 | //! \brief Construct an empty buffer. 26 | //! 27 | GenericBuffer(nvinfer1::DataType type = nvinfer1::DataType::kFLOAT) 28 | : mSize(0), mCapacity(0), mType(type), mBuffer(nullptr) 29 | { 30 | } 31 | 32 | //! 33 | //! \brief Construct a buffer with the specified allocation size in bytes. 34 | //! 35 | GenericBuffer(size_t size, nvinfer1::DataType type) 36 | : mSize(size), mCapacity(size), mType(type) 37 | { 38 | if (!allocFn(&mBuffer, this->nbBytes())) 39 | { 40 | throw std::bad_alloc(); 41 | } 42 | } 43 | 44 | GenericBuffer(GenericBuffer &&buf) 45 | : mSize(buf.mSize), mCapacity(buf.mCapacity), mType(buf.mType), mBuffer(buf.mBuffer) 46 | { 47 | buf.mSize = 0; 48 | buf.mCapacity = 0; 49 | buf.mType = nvinfer1::DataType::kFLOAT; 50 | buf.mBuffer = nullptr; 51 | } 52 | 53 | GenericBuffer &operator=(GenericBuffer &&buf) 54 | { 55 | if (this != &buf) 56 | { 57 | freeFn(mBuffer); 58 | mSize = buf.mSize; 59 | mCapacity = buf.mCapacity; 60 | mType = buf.mType; 61 | mBuffer = buf.mBuffer; 62 | // Reset buf. 63 | buf.mSize = 0; 64 | buf.mCapacity = 0; 65 | buf.mBuffer = nullptr; 66 | } 67 | return *this; 68 | } 69 | 70 | //! 71 | //! \brief Returns pointer to underlying array. 72 | //! 73 | void *data() 74 | { 75 | return mBuffer; 76 | } 77 | 78 | //! 79 | //! \brief Returns pointer to underlying array. 80 | //! 81 | const void *data() const 82 | { 83 | return mBuffer; 84 | } 85 | 86 | //! 87 | //! \brief Returns the size (in number of elements) of the buffer. 88 | //! 89 | size_t size() const 90 | { 91 | return mSize; 92 | } 93 | 94 | //! 95 | //! \brief Returns the size (in bytes) of the buffer. 96 | //! 97 | size_t nbBytes() const 98 | { 99 | return this->size() * getElementSize(mType); 100 | } 101 | 102 | //! 103 | //! \brief Resizes the buffer. This is a no-op if the new size is smaller than or equal to the current capacity. 104 | //! 105 | void resize(size_t newSize) 106 | { 107 | mSize = newSize; 108 | if (mCapacity < newSize) 109 | { 110 | freeFn(mBuffer); 111 | if (!allocFn(&mBuffer, this->nbBytes())) 112 | { 113 | throw std::bad_alloc{}; 114 | } 115 | mCapacity = newSize; 116 | } 117 | } 118 | 119 | //! 120 | //! \brief Overload of resize that accepts Dims 121 | //! 
122 | void resize(const nvinfer1::Dims &dims) 123 | { 124 | auto vol = std::accumulate(dims.d, dims.d + dims.nbDims, int64_t{1}, std::multiplies{}); 125 | return this->resize(vol); 126 | } 127 | 128 | //! 129 | //! \brief copy data from host to device 130 | //! 131 | int host2device(void* data, bool async, const cudaStream_t& stream = 0) 132 | { 133 | int ret = 0; 134 | if(async) 135 | ret = cudaMemcpyAsync(mBuffer, data, nbBytes(), cudaMemcpyHostToDevice, stream); 136 | else 137 | ret = cudaMemcpy(mBuffer, data, nbBytes(), cudaMemcpyHostToDevice); 138 | return ret; 139 | } 140 | 141 | //! 142 | //! \brief copy data from device to host 143 | //! 144 | int device2host(void* data, bool async, const cudaStream_t& stream = 0) 145 | { 146 | int ret = 0; 147 | if(async) 148 | ret = cudaMemcpyAsync(data, mBuffer, nbBytes(), cudaMemcpyDeviceToHost, stream); 149 | else 150 | ret = cudaMemcpy(data, mBuffer, nbBytes(), cudaMemcpyDeviceToHost); 151 | return ret; 152 | } 153 | 154 | //! 155 | //! \brief copy data from device to host 156 | //! 157 | int device2device(void* data, bool async, const cudaStream_t& stream = 0) 158 | { 159 | int ret = 0; 160 | if (async) 161 | ret = cudaMemcpyAsync(data, mBuffer, nbBytes(), cudaMemcpyDeviceToDevice, stream); 162 | else 163 | ret = cudaMemcpy(data, mBuffer, nbBytes(), cudaMemcpyDeviceToDevice); 164 | return ret; 165 | } 166 | 167 | nvinfer1::DataType getDataType() 168 | { 169 | return mType; 170 | } 171 | 172 | ~GenericBuffer() 173 | { 174 | freeFn(mBuffer); 175 | } 176 | 177 | private: 178 | size_t mSize{0}, mCapacity{0}; 179 | nvinfer1::DataType mType; 180 | void *mBuffer; 181 | AllocFunc allocFn; 182 | FreeFunc freeFn; 183 | }; 184 | 185 | class DeviceAllocator 186 | { 187 | public: 188 | bool operator()(void **ptr, size_t size) const 189 | { 190 | return cudaMalloc(ptr, size) == cudaSuccess; 191 | } 192 | }; 193 | 194 | class DeviceFree 195 | { 196 | public: 197 | void operator()(void *ptr) const 198 | { 199 | cudaFree(ptr); 200 | } 201 | }; 202 | using DeviceBuffer = GenericBuffer; 203 | } 204 | 205 | #endif 206 | -------------------------------------------------------------------------------- /commons/general.h: -------------------------------------------------------------------------------- 1 | #ifndef GENERAL_H 2 | #define GENERAL_H 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | 17 | namespace algorithms 18 | { 19 | uint32_t getElementSize(nvinfer1::DataType t) noexcept 20 | { 21 | switch (t) 22 | { 23 | case nvinfer1::DataType::kINT32: 24 | return 4; 25 | case nvinfer1::DataType::kFLOAT: 26 | return 4; 27 | case nvinfer1::DataType::kHALF: 28 | return 2; 29 | case nvinfer1::DataType::kBOOL: 30 | case nvinfer1::DataType::kUINT8: 31 | case nvinfer1::DataType::kINT8: 32 | return 1; 33 | } 34 | return 0; 35 | } 36 | 37 | template 38 | Type string2Num(const std::string &str) 39 | { 40 | std::istringstream iss(str); 41 | Type num; 42 | iss >> std::hex >> num; 43 | return num; 44 | } 45 | 46 | std::vector read_names(const std::string filename) 47 | { 48 | std::vector names; 49 | std::ifstream infile(filename); 50 | //assert(stream.is_open()); 51 | 52 | std::string line; 53 | while (std::getline(infile, line)) 54 | { 55 | names.emplace_back(line); 56 | } 57 | return names; 58 | } 59 | } 60 | #endif 61 | -------------------------------------------------------------------------------- /data/2023-09-14 20-42-44.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingj2021/segment-anything-tensorrt/bd27bc114fc6f6881c905d56f802b3d7ab3f2eb9/data/2023-09-14 20-42-44.png -------------------------------------------------------------------------------- /data/truck.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mingj2021/segment-anything-tensorrt/bd27bc114fc6f6881c905d56f802b3d7ab3f2eb9/data/truck.jpg -------------------------------------------------------------------------------- /export.h: -------------------------------------------------------------------------------- 1 | #ifndef EXPORT_H 2 | #define EXPORT_H 3 | 4 | #include 5 | #include 6 | #include "sam_utils.h" 7 | 8 | 9 | void export_engine_image_encoder(std::string f="vit_l_embedding.onnx",std::string output="vit_l_embedding.engine") 10 | { 11 | // create an instance of the builder 12 | std::unique_ptr builder(createInferBuilder(logger)); 13 | // create a network definition 14 | // The kEXPLICIT_BATCH flag is required in order to import models using the ONNX parser. 15 | uint32_t flag = 1U << static_cast(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH); 16 | 17 | //auto network = std::make_unique(builder->createNetworkV2(flag)); 18 | std::unique_ptr network(builder->createNetworkV2(flag)); 19 | 20 | // Importing a Model Using the ONNX Parser 21 | //auto parser = std::make_unique(createParser(*network, logger)); 22 | std::unique_ptr parser(createParser(*network, logger)); 23 | 24 | // read the model file and process any errors 25 | parser->parseFromFile(f.c_str(), 26 | static_cast(nvinfer1::ILogger::Severity::kWARNING)); 27 | for (int32_t i = 0; i < parser->getNbErrors(); ++i) 28 | { 29 | std::cout << parser->getError(i)->desc() << std::endl; 30 | } 31 | 32 | // create a build configuration specifying how TensorRT should optimize the model 33 | std::unique_ptr config(builder->createBuilderConfig()); 34 | 35 | // maximum workspace size 36 | // int workspace = 4; // GB 37 | // config->setMaxWorkspaceSize(workspace * 1U << 30); 38 | config->setFlag(BuilderFlag::kGPU_FALLBACK); 39 | 40 | config->setFlag(BuilderFlag::kFP16); 41 | 42 | // create an engine 43 | // auto serializedModel = std::make_unique(builder->buildSerializedNetwork(*network, *config)); 44 | std::unique_ptr serializedModel(builder->buildSerializedNetwork(*network, *config)); 45 | std::cout << "serializedModel->size()" << serializedModel->size() << std::endl; 46 | std::ofstream outfile(output, std::ofstream::out | std::ofstream::binary); 47 | outfile.write((char*)serializedModel->data(), serializedModel->size()); 48 | } 49 | 50 | void export_engine_prompt_encoder_and_mask_decoder(std::string f="sam_onnx_example.onnx",std::string output="sam_onnx_example.engine") 51 | { 52 | // create an instance of the builder 53 | std::unique_ptr builder(createInferBuilder(logger)); 54 | // create a network definition 55 | // The kEXPLICIT_BATCH flag is required in order to import models using the ONNX parser. 
56 | uint32_t flag = 1U << static_cast(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH); 57 | 58 | //auto network = std::make_unique(builder->createNetworkV2(flag)); 59 | std::unique_ptr network(builder->createNetworkV2(flag)); 60 | 61 | // Importing a Model Using the ONNX Parser 62 | //auto parser = std::make_unique(createParser(*network, logger)); 63 | std::unique_ptr parser(createParser(*network, logger)); 64 | 65 | // read the model file and process any errors 66 | parser->parseFromFile(f.c_str(), 67 | static_cast(nvinfer1::ILogger::Severity::kWARNING)); 68 | for (int32_t i = 0; i < parser->getNbErrors(); ++i) 69 | { 70 | std::cout << parser->getError(i)->desc() << std::endl; 71 | } 72 | 73 | // create a build configuration specifying how TensorRT should optimize the model 74 | std::unique_ptr config(builder->createBuilderConfig()); 75 | 76 | // maximum workspace size 77 | // int workspace = 8; // GB 78 | // config->setMaxWorkspaceSize(workspace * 1U << 30); 79 | config->setFlag(BuilderFlag::kGPU_FALLBACK); 80 | 81 | config->setFlag(BuilderFlag::kFP16); 82 | 83 | nvinfer1::IOptimizationProfile* profile = builder->createOptimizationProfile(); 84 | // profile->setDimensions("image_embeddings", nvinfer1::OptProfileSelector::kMIN, {1, 256, 64, 64 }); 85 | // profile->setDimensions("image_embeddings", nvinfer1::OptProfileSelector::kOPT, {1, 256, 64, 64 }); 86 | // profile->setDimensions("image_embeddings", nvinfer1::OptProfileSelector::kMAX, {1, 256, 64, 64 }); 87 | 88 | profile->setDimensions("point_coords", nvinfer1::OptProfileSelector::kMIN, {3, 1, 2,2 }); 89 | profile->setDimensions("point_coords", nvinfer1::OptProfileSelector::kOPT, { 3,1, 5,2 }); 90 | profile->setDimensions("point_coords", nvinfer1::OptProfileSelector::kMAX, { 3,1,10,2 }); 91 | 92 | profile->setDimensions("point_labels", nvinfer1::OptProfileSelector::kMIN, { 2,1, 2}); 93 | profile->setDimensions("point_labels", nvinfer1::OptProfileSelector::kOPT, { 2,1, 5 }); 94 | profile->setDimensions("point_labels", nvinfer1::OptProfileSelector::kMAX, { 2,1,10 }); 95 | 96 | // profile->setDimensions("mask_input", nvinfer1::OptProfileSelector::kMIN, { 1, 1, 256, 256}); 97 | // profile->setDimensions("mask_input", nvinfer1::OptProfileSelector::kOPT, { 1, 1, 256, 256 }); 98 | // profile->setDimensions("mask_input", nvinfer1::OptProfileSelector::kMAX, { 1, 1, 256, 256 }); 99 | 100 | // profile->setDimensions("has_mask_input", nvinfer1::OptProfileSelector::kMIN, { 1,}); 101 | // profile->setDimensions("has_mask_input", nvinfer1::OptProfileSelector::kOPT, { 1, }); 102 | // profile->setDimensions("has_mask_input", nvinfer1::OptProfileSelector::kMAX, { 1, }); 103 | 104 | config->addOptimizationProfile(profile); 105 | 106 | // create an engine 107 | // auto serializedModel = std::make_unique(builder->buildSerializedNetwork(*network, *config)); 108 | std::unique_ptr serializedModel(builder->buildSerializedNetwork(*network, *config)); 109 | std::cout << "serializedModel->size()" << serializedModel->size() << std::endl; 110 | std::ofstream outfile(output, std::ofstream::out | std::ofstream::binary); 111 | outfile.write((char*)serializedModel->data(), serializedModel->size()); 112 | } 113 | #endif -------------------------------------------------------------------------------- /how_to_export_vim_h_model.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "# 
Divide the embeding model into 2 parts, named as part1 and part2 \n", 10 | "# The decoder model remains unchanged\n", 11 | "# export_embedding_model_part_1()\n", 12 | "# export_embedding_model_part_2()\n", 13 | "# export_sam_h_model()\n", 14 | "# onnx_model_example2()" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "import torch\n", 24 | "import torch.nn as nn\n", 25 | "from torch.nn import functional as F\n", 26 | "from segment_anything.modeling import Sam\n", 27 | "import numpy as np\n", 28 | "from torchvision.transforms.functional import resize, to_pil_image\n", 29 | "from typing import Tuple\n", 30 | "from segment_anything import sam_model_registry, SamPredictor\n", 31 | "import cv2\n", 32 | "import matplotlib.pyplot as plt\n", 33 | "import warnings\n", 34 | "import os\n", 35 | "os.environ[\"KMP_DUPLICATE_LIB_OK\"] = \"TRUE\"\n", 36 | "import onnxruntime\n", 37 | "\n", 38 | "\n", 39 | "def show_mask(mask, ax, random_color=False):\n", 40 | " if random_color:\n", 41 | " color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)\n", 42 | " else:\n", 43 | " color = np.array([30/255, 144/255, 255/255, 0.6])\n", 44 | " h, w = mask.shape[-2:]\n", 45 | " mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)\n", 46 | " ax.imshow(mask_image)\n", 47 | " \n", 48 | "def show_points(coords, labels, ax, marker_size=375):\n", 49 | " pos_points = coords[labels==1]\n", 50 | " neg_points = coords[labels==0]\n", 51 | " ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)\n", 52 | " ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25) \n", 53 | " \n", 54 | "def show_box(box, ax):\n", 55 | " x0, y0 = box[0], box[1]\n", 56 | " w, h = box[2] - box[0], box[3] - box[1]\n", 57 | " ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2)) \n", 58 | "\n", 59 | "\n", 60 | "def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:\n", 61 | " \"\"\"\n", 62 | " Compute the output size given input size and target long side length.\n", 63 | " \"\"\"\n", 64 | " scale = long_side_length * 1.0 / max(oldh, oldw)\n", 65 | " newh, neww = oldh * scale, oldw * scale\n", 66 | " neww = int(neww + 0.5)\n", 67 | " newh = int(newh + 0.5)\n", 68 | " return (newh, neww)\n", 69 | "\n", 70 | "# @torch.no_grad()\n", 71 | "def pre_processing(image: np.ndarray, target_length: int, device,pixel_mean,pixel_std,img_size):\n", 72 | " target_size = get_preprocess_shape(image.shape[0], image.shape[1], target_length)\n", 73 | " input_image = np.array(resize(to_pil_image(image), target_size))\n", 74 | " input_image_torch = torch.as_tensor(input_image, device=device)\n", 75 | " input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :]\n", 76 | "\n", 77 | " # Normalize colors\n", 78 | " input_image_torch = (input_image_torch - pixel_mean) / pixel_std\n", 79 | "\n", 80 | " # Pad\n", 81 | " h, w = input_image_torch.shape[-2:]\n", 82 | " padh = img_size - h\n", 83 | " padw = img_size - w\n", 84 | " input_image_torch = F.pad(input_image_torch, (0, padw, 0, padh))\n", 85 | " return input_image_torch\n", 86 | "\n", 87 | "def export_embedding_model():\n", 88 | " sam_checkpoint = \"weights/sam_vit_l_0b3195.pth\"\n", 89 | " model_type = \"vit_l\"\n", 90 | "\n", 91 | " device = \"cpu\"\n", 92 | "\n", 93 | " sam = 
sam_model_registry[model_type](checkpoint=sam_checkpoint)\n", 94 | " sam.to(device=device)\n", 95 | "\n", 96 | " # image_encoder = EmbeddingOnnxModel(sam)\n", 97 | " image = cv2.imread('notebooks/images/truck.jpg')\n", 98 | " target_length = sam.image_encoder.img_size\n", 99 | " pixel_mean = sam.pixel_mean \n", 100 | " pixel_std = sam.pixel_std\n", 101 | " img_size = sam.image_encoder.img_size\n", 102 | " inputs = pre_processing(image, target_length, device,pixel_mean,pixel_std,img_size)\n", 103 | " onnx_model_path = model_type+\"_\"+\"embedding.onnx\"\n", 104 | " dummy_inputs = {\n", 105 | " \"images\": inputs\n", 106 | "}\n", 107 | " output_names = [\"image_embeddings\"]\n", 108 | " image_embeddings = sam.image_encoder(inputs).cpu().numpy()\n", 109 | " print('image_embeddings', image_embeddings.shape)\n", 110 | " with warnings.catch_warnings():\n", 111 | " warnings.filterwarnings(\"ignore\", category=torch.jit.TracerWarning)\n", 112 | " warnings.filterwarnings(\"ignore\", category=UserWarning)\n", 113 | " with open(onnx_model_path, \"wb\") as f:\n", 114 | " torch.onnx.export(\n", 115 | " sam.image_encoder,\n", 116 | " tuple(dummy_inputs.values()),\n", 117 | " f,\n", 118 | " export_params=True,\n", 119 | " verbose=False,\n", 120 | " opset_version=13,\n", 121 | " do_constant_folding=True,\n", 122 | " input_names=list(dummy_inputs.keys()),\n", 123 | " output_names=output_names,\n", 124 | " # dynamic_axes=dynamic_axes,\n", 125 | " ) \n", 126 | "\n", 127 | "def export_sam_model():\n", 128 | " from segment_anything.utils.onnx import SamOnnxModel\n", 129 | " checkpoint = \"weights/sam_vit_l_0b3195.pth\"\n", 130 | " model_type = \"vit_l\"\n", 131 | " sam = sam_model_registry[model_type](checkpoint=checkpoint)\n", 132 | " onnx_model_path = \"sam_onnx_example.onnx\"\n", 133 | "\n", 134 | " onnx_model = SamOnnxModel(sam, return_single_mask=True)\n", 135 | "\n", 136 | " dynamic_axes = {\n", 137 | " \"point_coords\": {1: \"num_points\"},\n", 138 | " \"point_labels\": {1: \"num_points\"},\n", 139 | " }\n", 140 | "\n", 141 | " embed_dim = sam.prompt_encoder.embed_dim\n", 142 | " embed_size = sam.prompt_encoder.image_embedding_size\n", 143 | " mask_input_size = [4 * x for x in embed_size]\n", 144 | " dummy_inputs = {\n", 145 | " \"image_embeddings\": torch.randn(1, embed_dim, *embed_size, dtype=torch.float),\n", 146 | " \"point_coords\": torch.randint(low=0, high=1024, size=(1, 5, 2), dtype=torch.float),\n", 147 | " \"point_labels\": torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float),\n", 148 | " \"mask_input\": torch.randn(1, 1, *mask_input_size, dtype=torch.float),\n", 149 | " \"has_mask_input\": torch.tensor([1], dtype=torch.float),\n", 150 | " # \"orig_im_size\": torch.tensor([1500, 2250], dtype=torch.float),\n", 151 | " }\n", 152 | " # output_names = [\"masks\", \"iou_predictions\", \"low_res_masks\"]\n", 153 | " output_names = [\"masks\", \"scores\"]\n", 154 | "\n", 155 | " with warnings.catch_warnings():\n", 156 | " warnings.filterwarnings(\"ignore\", category=torch.jit.TracerWarning)\n", 157 | " warnings.filterwarnings(\"ignore\", category=UserWarning)\n", 158 | " with open(onnx_model_path, \"wb\") as f:\n", 159 | " torch.onnx.export(\n", 160 | " onnx_model,\n", 161 | " tuple(dummy_inputs.values()),\n", 162 | " f,\n", 163 | " export_params=True,\n", 164 | " verbose=False,\n", 165 | " opset_version=13,\n", 166 | " do_constant_folding=True,\n", 167 | " input_names=list(dummy_inputs.keys()),\n", 168 | " output_names=output_names,\n", 169 | " dynamic_axes=dynamic_axes,\n", 170 | " ) \n", 
171 | "\n", 172 | "def onnx_model_example():\n", 173 | " import os\n", 174 | " ort_session_embedding = onnxruntime.InferenceSession('vit_l_embedding.onnx',providers=['CPUExecutionProvider'])\n", 175 | " ort_session_sam = onnxruntime.InferenceSession('sam_onnx_example.onnx',providers=['CPUExecutionProvider'])\n", 176 | "\n", 177 | " image = cv2.imread('notebooks/images/truck.jpg')\n", 178 | " image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n", 179 | " image2 = image.copy()\n", 180 | " prompt_embed_dim = 256\n", 181 | " image_size = 1024\n", 182 | " vit_patch_size = 16\n", 183 | " image_embedding_size = image_size // vit_patch_size\n", 184 | " target_length = image_size\n", 185 | " pixel_mean=[123.675, 116.28, 103.53],\n", 186 | " pixel_std=[58.395, 57.12, 57.375]\n", 187 | " pixel_mean = torch.Tensor(pixel_mean).view(-1, 1, 1)\n", 188 | " pixel_std = torch.Tensor(pixel_std).view(-1, 1, 1)\n", 189 | " device = \"cpu\"\n", 190 | " inputs = pre_processing(image, target_length, device,pixel_mean,pixel_std,image_size)\n", 191 | " ort_inputs = {\n", 192 | " \"images\": inputs.cpu().numpy()\n", 193 | " }\n", 194 | " image_embeddings = ort_session_embedding.run(None, ort_inputs)[0]\n", 195 | "\n", 196 | " input_point = np.array([[500, 375]])\n", 197 | " input_label = np.array([1])\n", 198 | "\n", 199 | " onnx_coord = np.concatenate([input_point, np.array([[0.0, 0.0]])], axis=0)[None, :, :]\n", 200 | " onnx_label = np.concatenate([input_label, np.array([-1])], axis=0)[None, :].astype(np.float32)\n", 201 | " from segment_anything.utils.transforms import ResizeLongestSide\n", 202 | " transf = ResizeLongestSide(image_size)\n", 203 | " onnx_coord = transf.apply_coords(onnx_coord, image.shape[:2]).astype(np.float32)\n", 204 | " onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)\n", 205 | " onnx_has_mask_input = np.zeros(1, dtype=np.float32)\n", 206 | "\n", 207 | " ort_inputs = {\n", 208 | " \"image_embeddings\": image_embeddings,\n", 209 | " \"point_coords\": onnx_coord,\n", 210 | " \"point_labels\": onnx_label,\n", 211 | " \"mask_input\": onnx_mask_input,\n", 212 | " \"has_mask_input\": onnx_has_mask_input,\n", 213 | " # \"orig_im_size\": np.array(image.shape[:2], dtype=np.float32)\n", 214 | " }\n", 215 | "\n", 216 | " masks, _ = ort_session_sam.run(None, ort_inputs)\n", 217 | "\n", 218 | " from segment_anything.utils.onnx import SamOnnxModel\n", 219 | " checkpoint = \"weights/sam_vit_l_0b3195.pth\"\n", 220 | " model_type = \"vit_l\"\n", 221 | " sam = sam_model_registry[model_type](checkpoint=checkpoint)\n", 222 | " onnx_model_path = \"sam_onnx_example.onnx\"\n", 223 | "\n", 224 | " onnx_model = SamOnnxModel(sam, return_single_mask=True)\n", 225 | " masks = onnx_model.mask_postprocessing(torch.as_tensor(masks), torch.as_tensor(image.shape[:2]))\n", 226 | " masks = masks > 0.0\n", 227 | " plt.figure(figsize=(10, 10))\n", 228 | " plt.imshow(image)\n", 229 | " show_mask(masks, plt.gca())\n", 230 | " # show_box(input_box, plt.gca())\n", 231 | " show_points(input_point, input_label, plt.gca())\n", 232 | " plt.axis('off')\n", 233 | " plt.savefig('demo.png')\n", 234 | "\n", 235 | "def export_engine_image_encoder(f='vit_l_embedding.onnx'):\n", 236 | " import tensorrt as trt\n", 237 | " from pathlib import Path\n", 238 | " file = Path(f)\n", 239 | " f = file.with_suffix('.engine') # TensorRT engine file\n", 240 | " onnx = file.with_suffix('.onnx')\n", 241 | " logger = trt.Logger(trt.Logger.INFO)\n", 242 | " builder = trt.Builder(logger)\n", 243 | " config = builder.create_builder_config()\n", 244 | " 
workspace = 6\n", 245 | " print(\"workspace: \", workspace)\n", 246 | " config.max_workspace_size = workspace * 1 << 30\n", 247 | " flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))\n", 248 | " network = builder.create_network(flag)\n", 249 | " parser = trt.OnnxParser(network, logger)\n", 250 | " if not parser.parse_from_file(str(onnx)):\n", 251 | " raise RuntimeError(f'failed to load ONNX file: {onnx}')\n", 252 | "\n", 253 | " inputs = [network.get_input(i) for i in range(network.num_inputs)]\n", 254 | " outputs = [network.get_output(i) for i in range(network.num_outputs)]\n", 255 | " for inp in inputs:\n", 256 | " print(f'input \"{inp.name}\" with shape{inp.shape} {inp.dtype}')\n", 257 | " for out in outputs:\n", 258 | " print(f'output \"{out.name}\" with shape{out.shape} {out.dtype}')\n", 259 | "\n", 260 | " half = True\n", 261 | " print(f'building FP{16 if builder.platform_has_fast_fp16 and half else 32} engine as {f}')\n", 262 | " if builder.platform_has_fast_fp16 and half:\n", 263 | " config.set_flag(trt.BuilderFlag.FP16)\n", 264 | " with builder.build_engine(network, config) as engine, open(f, 'wb') as t:\n", 265 | " t.write(engine.serialize())\n", 266 | "\n", 267 | "\n", 268 | "def export_engine_prompt_encoder_and_mask_decoder(f='sam_onnx_example.onnx'):\n", 269 | " import tensorrt as trt\n", 270 | " from pathlib import Path\n", 271 | " file = Path(f)\n", 272 | " f = file.with_suffix('.engine') # TensorRT engine file\n", 273 | " onnx = file.with_suffix('.onnx')\n", 274 | " logger = trt.Logger(trt.Logger.INFO)\n", 275 | " builder = trt.Builder(logger)\n", 276 | " config = builder.create_builder_config()\n", 277 | " workspace = 10\n", 278 | " print(\"workspace: \", workspace)\n", 279 | " config.max_workspace_size = workspace * 1 << 30\n", 280 | " flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))\n", 281 | " network = builder.create_network(flag)\n", 282 | " parser = trt.OnnxParser(network, logger)\n", 283 | " if not parser.parse_from_file(str(onnx)):\n", 284 | " raise RuntimeError(f'failed to load ONNX file: {onnx}')\n", 285 | "\n", 286 | " inputs = [network.get_input(i) for i in range(network.num_inputs)]\n", 287 | " outputs = [network.get_output(i) for i in range(network.num_outputs)]\n", 288 | " for inp in inputs:\n", 289 | " print(f'input \"{inp.name}\" with shape{inp.shape} {inp.dtype}')\n", 290 | " for out in outputs:\n", 291 | " print(f'output \"{out.name}\" with shape{out.shape} {out.dtype}')\n", 292 | "\n", 293 | " profile = builder.create_optimization_profile()\n", 294 | " profile.set_shape('image_embeddings', (1, 256, 64, 64), (1, 256, 64, 64), (1, 256, 64, 64))\n", 295 | " profile.set_shape('point_coords', (1, 2,2), (1, 5,2), (1,10,2))\n", 296 | " profile.set_shape('point_labels', (1, 2), (1, 5), (1,10))\n", 297 | " profile.set_shape('mask_input', (1, 1, 256, 256), (1, 1, 256, 256), (1, 1, 256, 256))\n", 298 | " profile.set_shape('has_mask_input', (1,), (1, ), (1, ))\n", 299 | " # # profile.set_shape_input('orig_im_size', (416,416), (1024,1024), (1500, 2250))\n", 300 | " # profile.set_shape_input('orig_im_size', (2,), (2,), (2, ))\n", 301 | " config.add_optimization_profile(profile)\n", 302 | "\n", 303 | " half = True\n", 304 | " print(f'building FP{16 if builder.platform_has_fast_fp16 and half else 32} engine as {f}')\n", 305 | " if builder.platform_has_fast_fp16 and half:\n", 306 | " config.set_flag(trt.BuilderFlag.FP16)\n", 307 | " with builder.build_engine(network, config) as engine, open(f, 'wb') as t:\n", 308 | " 
t.write(engine.serialize())\n", 309 | "\n", 310 | "\n", 311 | "def collect_stats(model, device, data_loader, num_batches):\n", 312 | " \"\"\"Feed data to the network and collect statistic\"\"\"\n", 313 | "\n", 314 | " # Enable calibrators\n", 315 | " for name, module in model.named_modules():\n", 316 | " if isinstance(module, quant_nn.TensorQuantizer):\n", 317 | " if module._calibrator is not None:\n", 318 | " module.disable_quant()\n", 319 | " module.enable_calib()\n", 320 | " else:\n", 321 | " module.disable()\n", 322 | "\n", 323 | " for i, (path, im, im0s, vid_cap, s) in tqdm(enumerate(data_loader), total=num_batches):\n", 324 | " im = torch.from_numpy(im).to(device)\n", 325 | " im = im.float()\n", 326 | " im /= 255 # 0 - 255 to 0.0 - 1.0\n", 327 | " if len(im.shape) == 3:\n", 328 | " im = im[None] # expand for batch dim\n", 329 | " model(im)\n", 330 | " if i >= num_batches:\n", 331 | " break\n", 332 | "\n", 333 | " # Disable calibrators\n", 334 | " for name, module in model.named_modules():\n", 335 | " if isinstance(module, quant_nn.TensorQuantizer):\n", 336 | " if module._calibrator is not None:\n", 337 | " module.enable_quant()\n", 338 | " module.disable_calib()\n", 339 | " else:\n", 340 | " module.enable()\n", 341 | "\n", 342 | "def compute_amax(model, **kwargs):\n", 343 | " # Load calib result\n", 344 | " for name, module in model.named_modules():\n", 345 | " if isinstance(module, quant_nn.TensorQuantizer):\n", 346 | " if module._calibrator is not None:\n", 347 | " if isinstance(module._calibrator, calib.MaxCalibrator):\n", 348 | " module.load_calib_amax()\n", 349 | " else:\n", 350 | " module.load_calib_amax(**kwargs)\n", 351 | "\n", 352 | "class embedding_model_part_1(nn.Module):\n", 353 | " def __init__(self) :\n", 354 | " super().__init__()\n", 355 | " sam_checkpoint = \"weights/sam_vit_h_4b8939.pth\"\n", 356 | " model_type = \"vit_h\"\n", 357 | " device = \"cpu\"\n", 358 | " self.sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)\n", 359 | " self.sam.to(device=device)\n", 360 | "\n", 361 | " def forward(self, x):\n", 362 | " x = self.sam.image_encoder.patch_embed(x)\n", 363 | " if self.sam.image_encoder.pos_embed is not None:\n", 364 | " x = x + self.sam.image_encoder.pos_embed\n", 365 | "\n", 366 | " val_1 = len(self.sam.image_encoder.blocks) // 2\n", 367 | " for blk in self.sam.image_encoder.blocks[0:val_1]:\n", 368 | " x = blk(x)\n", 369 | "\n", 370 | " return x\n", 371 | " \n", 372 | "class embedding_model_part_2(nn.Module):\n", 373 | " def __init__(self) :\n", 374 | " super().__init__()\n", 375 | " sam_checkpoint = \"weights/sam_vit_h_4b8939.pth\"\n", 376 | " model_type = \"vit_h\"\n", 377 | " device = \"cpu\"\n", 378 | " self.sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)\n", 379 | " self.sam.to(device=device)\n", 380 | "\n", 381 | " def forward(self, x):\n", 382 | " val_1 = len(self.sam.image_encoder.blocks) // 2\n", 383 | " for blk in self.sam.image_encoder.blocks[val_1:]:\n", 384 | " x = blk(x)\n", 385 | " x = self.sam.image_encoder.neck(x.permute(0, 3, 1, 2))\n", 386 | " return x\n", 387 | "\n", 388 | "def export_embedding_model_part_1():\n", 389 | " device = \"cpu\"\n", 390 | " model_type = \"vit_h\"\n", 391 | " model = embedding_model_part_1()\n", 392 | " # image_encoder = EmbeddingOnnxModel(sam)\n", 393 | " image = cv2.imread('notebooks/images/truck.jpg')\n", 394 | " target_length = model.sam.image_encoder.img_size\n", 395 | " pixel_mean = model.sam.pixel_mean \n", 396 | " pixel_std = model.sam.pixel_std\n", 397 | " img_size = 
model.sam.image_encoder.img_size\n", 398 | " inputs = pre_processing(image, target_length, device,pixel_mean,pixel_std,img_size)\n", 399 | " onnx_model_path = model_type+\"_\"+\"part_1_embedding.onnx\"\n", 400 | " dummy_inputs = {\n", 401 | " \"images\": inputs\n", 402 | "}\n", 403 | " output_names = [\"image_embeddings_part_1\"]\n", 404 | " with warnings.catch_warnings():\n", 405 | " warnings.filterwarnings(\"ignore\", category=torch.jit.TracerWarning)\n", 406 | " warnings.filterwarnings(\"ignore\", category=UserWarning)\n", 407 | " with open(onnx_model_path, \"wb\") as f:\n", 408 | " torch.onnx.export(\n", 409 | " model,\n", 410 | " tuple(dummy_inputs.values()),\n", 411 | " f,\n", 412 | " export_params=True,\n", 413 | " verbose=False,\n", 414 | " opset_version=13,\n", 415 | " do_constant_folding=True,\n", 416 | " input_names=list(dummy_inputs.keys()),\n", 417 | " output_names=output_names,\n", 418 | " # dynamic_axes=dynamic_axes,\n", 419 | " ) \n", 420 | "\n", 421 | "def export_embedding_model_part_2():\n", 422 | " device = \"cpu\"\n", 423 | " model_type = \"vit_h\"\n", 424 | " model = embedding_model_part_2()\n", 425 | " \n", 426 | " inputs = torch.randn(1, 64, 64, 1280)\n", 427 | " onnx_model_path = model_type+\"_\"+\"part_2_embedding.onnx\"\n", 428 | " dummy_inputs = {\n", 429 | " \"image_embeddings_part_1\": inputs\n", 430 | "}\n", 431 | " output_names = [\"image_embeddings_part_2\"]\n", 432 | " with warnings.catch_warnings():\n", 433 | " warnings.filterwarnings(\"ignore\", category=torch.jit.TracerWarning)\n", 434 | " warnings.filterwarnings(\"ignore\", category=UserWarning)\n", 435 | " with open(onnx_model_path, \"wb\") as f:\n", 436 | " torch.onnx.export(\n", 437 | " model,\n", 438 | " tuple(dummy_inputs.values()),\n", 439 | " f,\n", 440 | " export_params=True,\n", 441 | " verbose=False,\n", 442 | " opset_version=13,\n", 443 | " do_constant_folding=True,\n", 444 | " input_names=list(dummy_inputs.keys()),\n", 445 | " output_names=output_names,\n", 446 | " # dynamic_axes=dynamic_axes,\n", 447 | " ) \n", 448 | " \n", 449 | "def export_sam_h_model():\n", 450 | " from segment_anything.utils.onnx import SamOnnxModel\n", 451 | " checkpoint = \"weights/sam_vit_h_4b8939.pth\"\n", 452 | " model_type = \"vit_h\"\n", 453 | " sam = sam_model_registry[model_type](checkpoint=checkpoint)\n", 454 | " onnx_model_path = \"sam_h_decoder_onnx.onnx\"\n", 455 | "\n", 456 | " onnx_model = SamOnnxModel(sam, return_single_mask=True)\n", 457 | "\n", 458 | " dynamic_axes = {\n", 459 | " \"point_coords\": {1: \"num_points\"},\n", 460 | " \"point_labels\": {1: \"num_points\"},\n", 461 | " }\n", 462 | "\n", 463 | " embed_dim = sam.prompt_encoder.embed_dim\n", 464 | " embed_size = sam.prompt_encoder.image_embedding_size\n", 465 | " mask_input_size = [4 * x for x in embed_size]\n", 466 | " dummy_inputs = {\n", 467 | " \"image_embeddings\": torch.randn(1, embed_dim, *embed_size, dtype=torch.float),\n", 468 | " \"point_coords\": torch.randint(low=0, high=1024, size=(1, 5, 2), dtype=torch.float),\n", 469 | " \"point_labels\": torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float),\n", 470 | " \"mask_input\": torch.randn(1, 1, *mask_input_size, dtype=torch.float),\n", 471 | " \"has_mask_input\": torch.tensor([1], dtype=torch.float),\n", 472 | " # \"orig_im_size\": torch.tensor([1500, 2250], dtype=torch.float),\n", 473 | " }\n", 474 | " # output_names = [\"masks\", \"iou_predictions\", \"low_res_masks\"]\n", 475 | " output_names = [\"masks\", \"scores\"]\n", 476 | "\n", 477 | " with 
warnings.catch_warnings():\n", 478 | " warnings.filterwarnings(\"ignore\", category=torch.jit.TracerWarning)\n", 479 | " warnings.filterwarnings(\"ignore\", category=UserWarning)\n", 480 | " with open(onnx_model_path, \"wb\") as f:\n", 481 | " torch.onnx.export(\n", 482 | " onnx_model,\n", 483 | " tuple(dummy_inputs.values()),\n", 484 | " f,\n", 485 | " export_params=True,\n", 486 | " verbose=False,\n", 487 | " opset_version=13,\n", 488 | " do_constant_folding=True,\n", 489 | " input_names=list(dummy_inputs.keys()),\n", 490 | " output_names=output_names,\n", 491 | " dynamic_axes=dynamic_axes,\n", 492 | " ) \n", 493 | "\n", 494 | "def onnx_model_example2():\n", 495 | " import os\n", 496 | " ort_session_embedding_part_1 = onnxruntime.InferenceSession('vit_h_part_1_embedding.onnx',providers=['CPUExecutionProvider'])\n", 497 | " ort_session_embedding_part_2 = onnxruntime.InferenceSession('vit_h_part_2_embedding.onnx',providers=['CPUExecutionProvider'])\n", 498 | " ort_session_sam = onnxruntime.InferenceSession('sam_h_decoder_onnx.onnx',providers=['CPUExecutionProvider'])\n", 499 | "\n", 500 | " image = cv2.imread('notebooks/images/truck.jpg')\n", 501 | " image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n", 502 | " image2 = image.copy()\n", 503 | " prompt_embed_dim = 256\n", 504 | " image_size = 1024\n", 505 | " vit_patch_size = 16\n", 506 | " image_embedding_size = image_size // vit_patch_size\n", 507 | " target_length = image_size\n", 508 | " pixel_mean=[123.675, 116.28, 103.53],\n", 509 | " pixel_std=[58.395, 57.12, 57.375]\n", 510 | " pixel_mean = torch.Tensor(pixel_mean).view(-1, 1, 1)\n", 511 | " pixel_std = torch.Tensor(pixel_std).view(-1, 1, 1)\n", 512 | " device = \"cpu\"\n", 513 | " inputs = pre_processing(image, target_length, device,pixel_mean,pixel_std,image_size)\n", 514 | " ort_inputs = {\n", 515 | " \"images\": inputs.cpu().numpy()\n", 516 | " }\n", 517 | " image_embeddings = ort_session_embedding_part_1.run(None, ort_inputs)[0]\n", 518 | " ort_inputs = {\n", 519 | " \"image_embeddings_part_1\": image_embeddings\n", 520 | " }\n", 521 | " image_embeddings = ort_session_embedding_part_2.run(None, ort_inputs)[0]\n", 522 | "\n", 523 | " input_point = np.array([[784, 379]])\n", 524 | " input_label = np.array([1])\n", 525 | "\n", 526 | " onnx_coord = np.concatenate([input_point, np.array([[0.0, 0.0]])], axis=0)[None, :, :]\n", 527 | " onnx_label = np.concatenate([input_label, np.array([-1])], axis=0)[None, :].astype(np.float32)\n", 528 | " from segment_anything.utils.transforms import ResizeLongestSide\n", 529 | " transf = ResizeLongestSide(image_size)\n", 530 | " onnx_coord = transf.apply_coords(onnx_coord, image.shape[:2]).astype(np.float32)\n", 531 | " onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)\n", 532 | " onnx_has_mask_input = np.zeros(1, dtype=np.float32)\n", 533 | "\n", 534 | " ort_inputs = {\n", 535 | " \"image_embeddings\": image_embeddings,\n", 536 | " \"point_coords\": onnx_coord,\n", 537 | " \"point_labels\": onnx_label,\n", 538 | " \"mask_input\": onnx_mask_input,\n", 539 | " \"has_mask_input\": onnx_has_mask_input,\n", 540 | " # \"orig_im_size\": np.array(image.shape[:2], dtype=np.float32)\n", 541 | " }\n", 542 | "\n", 543 | " masks, _ = ort_session_sam.run(None, ort_inputs)\n", 544 | "\n", 545 | " from segment_anything.utils.onnx import SamOnnxModel\n", 546 | " checkpoint = \"weights/sam_vit_h_4b8939.pth\"\n", 547 | " model_type = \"vit_h\"\n", 548 | " sam = sam_model_registry[model_type](checkpoint=checkpoint)\n", 549 | "\n", 550 | " onnx_model = 
SamOnnxModel(sam, return_single_mask=True)\n", 551 | " masks = onnx_model.mask_postprocessing(torch.as_tensor(masks), torch.as_tensor(image.shape[:2]))\n", 552 | " masks = masks > 0.0\n", 553 | " plt.figure(figsize=(10, 10))\n", 554 | " plt.imshow(image)\n", 555 | " show_mask(masks, plt.gca())\n", 556 | " # show_box(input_box, plt.gca())\n", 557 | " show_points(input_point, input_label, plt.gca())\n", 558 | " plt.axis('off')\n", 559 | " plt.savefig('demo.png')\n", 560 | " plt.show()\n", 561 | "\n", 562 | "\n", 563 | "if __name__ == '__main__':\n", 564 | " with torch.no_grad():\n", 565 | " # export_embedding_model()\n", 566 | " # export_sam_model()\n", 567 | " # onnx_model_example()\n", 568 | " # export_engine_image_encoder()\n", 569 | " # export_engine_prompt_encoder_and_mask_decoder()\n", 570 | " export_embedding_model_part_1()\n", 571 | " export_embedding_model_part_2()\n", 572 | " export_sam_h_model()\n", 573 | " onnx_model_example2()\n", 574 | " " 575 | ] 576 | } 577 | ], 578 | "metadata": { 579 | "kernelspec": { 580 | "display_name": "Python 3", 581 | "language": "python", 582 | "name": "python3" 583 | }, 584 | "language_info": { 585 | "codemirror_mode": { 586 | "name": "ipython", 587 | "version": 3 588 | }, 589 | "file_extension": ".py", 590 | "mimetype": "text/x-python", 591 | "name": "python", 592 | "nbconvert_exporter": "python", 593 | "pygments_lexer": "ipython3", 594 | "version": "3.8.10" 595 | }, 596 | "orig_nbformat": 4 597 | }, 598 | "nbformat": 4, 599 | "nbformat_minor": 2 600 | } 601 | -------------------------------------------------------------------------------- /main.cpp: -------------------------------------------------------------------------------- 1 | #include "sam.h" 2 | #include "export.h" 3 | /////////////////////////////////////////////////////////////////////////////// 4 | using namespace std; 5 | using namespace cv; 6 | 7 | std::shared_ptr eng_0; 8 | std::shared_ptr eng_1; 9 | at::Tensor image_embeddings; 10 | 11 | void locator(int event, int x, int y, int flags, void *userdata) 12 | { // function to track mouse movement and click// 13 | if (event == EVENT_LBUTTONDOWN) 14 | { // when left button clicked// 15 | cout << "Left click has been made, Position:(" << x << "," << y << ")" << endl; 16 | // auto res = eng_1->prepareInput(x, y, x - 100, y - 100, x + 100, y + 100, image_embeddings); 17 | auto res = eng_1->prepareInput(x, y, image_embeddings); 18 | // std::vector mult_pts = {x,y,x-5,y-5,x+5,y+5}; 19 | // auto res = eng_1->prepareInput(mult_pts, image_embeddings); 20 | std::cout << "------------------prepareInput: " << res << std::endl; 21 | res = eng_1->infer(); 22 | std::cout << "------------------infer: " << res << std::endl; 23 | eng_1->verifyOutput(); 24 | // cv::Mat roi; 25 | // eng_1->verifyOutput(roi); 26 | // roi *= 255; 27 | // cv::imshow("img2_", roi); 28 | // cv::waitKey(); 29 | std::cout << "------------------verifyOutput: " << std::endl; 30 | } 31 | else if (event == EVENT_RBUTTONDOWN) 32 | { // when right button clicked// 33 | // cout << "Rightclick has been made, Position:(" << x << "," << y << ")" << endl; 34 | } 35 | else if (event == EVENT_MBUTTONDOWN) 36 | { // when middle button clicked// 37 | // cout << "Middleclick has been made, Position:(" << x << "," << y << ")" << endl; 38 | } 39 | else if (event == EVENT_MOUSEMOVE) 40 | { // when mouse pointer moves// 41 | // cout << "Current mouse position:(" << x << "," << y << ")" << endl; 42 | } 43 | } 44 | 45 | #define EMBEDDING 46 | #define SAMPROMPTENCODERANDMASKDECODER 47 | int main(int 
argc, char const *argv[]) 48 | { 49 | ifstream f1("vit_l_embedding.engine"); 50 | if (!f1.good()) 51 | export_engine_image_encoder("/workspace/segment-anything-tensorrt/data/vit_l_embedding.onnx"); 52 | 53 | ifstream f2("sam_onnx_example.engine"); 54 | if (!f2.good()) 55 | export_engine_prompt_encoder_and_mask_decoder("/workspace/segment-anything-tensorrt/data/sam_onnx_example.onnx"); 56 | 57 | #ifdef EMBEDDING 58 | { 59 | // const std::string modelFile = "D:/projects/detections/data/vit_l_embedding.engine"; 60 | const std::string modelFile = "vit_l_embedding.engine"; 61 | std::ifstream engineFile(modelFile.c_str(), std::ifstream::binary); 62 | assert(engineFile); 63 | // if (!engineFile) 64 | // return; 65 | 66 | int fsize; 67 | engineFile.seekg(0, engineFile.end); 68 | fsize = engineFile.tellg(); 69 | engineFile.seekg(0, engineFile.beg); 70 | std::vector engineData(fsize); 71 | engineFile.read(engineData.data(), fsize); 72 | 73 | if (engineFile) 74 | std::cout << "all characters read successfully." << std::endl; 75 | else 76 | std::cout << "error: only " << engineFile.gcount() << " could be read" << std::endl; 77 | engineFile.close(); 78 | 79 | std::unique_ptr runtime(nvinfer1::createInferRuntime(logger)); 80 | std::shared_ptr mEngine(runtime->deserializeCudaEngine(engineData.data(), fsize, nullptr)); 81 | cv::Mat frame = cv::imread("/workspace/segment-anything-tensorrt/data/truck.jpg"); 82 | std::cout << frame.size << std::endl; 83 | eng_0 = std::shared_ptr(new SamEmbedding(std::to_string(1), mEngine, frame)); 84 | auto res = eng_0->prepareInput(); 85 | std::cout << "------------------prepareInput: " << res << std::endl; 86 | res = eng_0->infer(); 87 | std::cout << "------------------infer: " << res << std::endl; 88 | image_embeddings = eng_0->verifyOutput(); 89 | std::cout << "------------------verifyOutput: " << std::endl; 90 | } 91 | 92 | #endif 93 | 94 | #ifdef SAMPROMPTENCODERANDMASKDECODER 95 | { 96 | // const std::string modelFile = "D:/projects/detections/data/sam_onnx_example.engine"; 97 | const std::string modelFile = "sam_onnx_example.engine"; 98 | std::ifstream engineFile(modelFile.c_str(), std::ifstream::binary); 99 | assert(engineFile); 100 | // if (!engineFile) 101 | // return; 102 | 103 | int fsize; 104 | engineFile.seekg(0, engineFile.end); 105 | fsize = engineFile.tellg(); 106 | engineFile.seekg(0, engineFile.beg); 107 | std::vector engineData(fsize); 108 | engineFile.read(engineData.data(), fsize); 109 | 110 | if (engineFile) 111 | std::cout << "all characters read successfully." 
<< std::endl; 112 | else 113 | std::cout << "error: only " << engineFile.gcount() << " could be read" << std::endl; 114 | engineFile.close(); 115 | 116 | std::unique_ptr runtime(nvinfer1::createInferRuntime(logger)); 117 | std::shared_ptr mEngine(runtime->deserializeCudaEngine(engineData.data(), fsize, nullptr)); 118 | cv::Mat frame = cv::imread("/workspace/segment-anything-tensorrt/data/truck.jpg"); 119 | eng_1 = std::shared_ptr(new SamPromptEncoderAndMaskDecoder(std::to_string(1), mEngine, frame)); 120 | // Mat image = imread("D:/projects/detections/data/2.png");//loading image in the matrix// 121 | // namedWindow("img_", 0); // declaring window to show image// 122 | // setMouseCallback("img_", locator, NULL); // Mouse callback function on define window// 123 | // imshow("img_", frame); // showing image on the window// 124 | // waitKey(0); // wait for keystroke// 125 | 126 | auto res = eng_1->prepareInput(100, 100, image_embeddings); 127 | // std::vector mult_pts = {x,y,x-5,y-5,x+5,y+5}; 128 | // auto res = eng_1->prepareInput(mult_pts, image_embeddings); 129 | std::cout << "------------------prepareInput: " << res << std::endl; 130 | res = eng_1->infer(); 131 | std::cout << "------------------infer: " << res << std::endl; 132 | eng_1->verifyOutput(); 133 | std::cout << "-----------------done" << std::endl; 134 | } 135 | #endif 136 | } -------------------------------------------------------------------------------- /main_vim_h.cpp: -------------------------------------------------------------------------------- 1 | #include "sam.h" 2 | #include "export.h" 3 | #include "baseModel.h" 4 | /////////////////////////////////////////////////////////////////////////////// 5 | using namespace std; 6 | using namespace cv; 7 | 8 | std::shared_ptr eng_0; 9 | std::shared_ptr eng_2; 10 | std::shared_ptr eng_1; 11 | at::Tensor image_embeddings; 12 | 13 | void locator(int event, int x, int y, int flags, void *userdata) 14 | { // function to track mouse movement and click// 15 | if (event == EVENT_LBUTTONDOWN) 16 | { // when left button clicked// 17 | cout << "Left click has been made, Position:(" << x << "," << y << ")" << endl; 18 | // auto res = eng_1->prepareInput(x, y, x - 100, y - 100, x + 100, y + 100, image_embeddings); 19 | auto res = eng_1->prepareInput(x, y, image_embeddings); 20 | // std::vector mult_pts = {x,y,x-5,y-5,x+5,y+5}; 21 | // auto res = eng_1->prepareInput(mult_pts, image_embeddings); 22 | std::cout << "------------------prepareInput: " << res << std::endl; 23 | res = eng_1->infer(); 24 | std::cout << "------------------infer: " << res << std::endl; 25 | eng_1->verifyOutput(); 26 | std::cout << "------------------verifyOutput: " << std::endl; 27 | } 28 | else if (event == EVENT_RBUTTONDOWN) 29 | { // when right button clicked// 30 | // cout << "Rightclick has been made, Position:(" << x << "," << y << ")" << endl; 31 | } 32 | else if (event == EVENT_MBUTTONDOWN) 33 | { // when middle button clicked// 34 | // cout << "Middleclick has been made, Position:(" << x << "," << y << ")" << endl; 35 | } 36 | else if (event == EVENT_MOUSEMOVE) 37 | { // when mouse pointer moves// 38 | // cout << "Current mouse position:(" << x << "," << y << ")" << endl; 39 | } 40 | } 41 | 42 | #define EMBEDDING 43 | #define SAMPROMPTENCODERANDMASKDECODER 44 | int main(int argc, char const *argv[]) 45 | { 46 | // BaseModel tmp_model("vit_h_embedding_part_1.engine"); 47 | 48 | ifstream f1("vit_h_embedding_part_1.engine"); 49 | if (!f1.good()) 50 | 
export_engine_image_encoder("/workspace/segment-anything-tensorrt/data/vit_h_part_1_embedding.onnx", "vit_h_embedding_part_1.engine"); 51 | 52 | ifstream f2("vit_h_embedding_part_2.engine"); 53 | if (!f2.good()) 54 | export_engine_image_encoder("/workspace/segment-anything-tensorrt/data/vit_h_part_2_embedding.onnx", "vit_h_embedding_part_2.engine"); 55 | 56 | ifstream f3("sam_onnx_decoder.engine"); 57 | if (!f3.good()) 58 | export_engine_prompt_encoder_and_mask_decoder("/workspace/segment-anything-tensorrt/data/sam_h_decoder_onnx.onnx", "sam_onnx_decoder.engine"); 59 | 60 | #ifdef EMBEDDING 61 | { 62 | // const std::string modelFile = "D:/projects/detections/data/vit_l_embedding.engine"; 63 | const std::string modelFile = "vit_h_embedding_part_1.engine"; 64 | std::ifstream engineFile(modelFile.c_str(), std::ifstream::binary); 65 | assert(engineFile); 66 | // if (!engineFile) 67 | // return; 68 | 69 | int fsize; 70 | engineFile.seekg(0, engineFile.end); 71 | fsize = engineFile.tellg(); 72 | engineFile.seekg(0, engineFile.beg); 73 | std::vector engineData(fsize); 74 | engineFile.read(engineData.data(), fsize); 75 | 76 | if (engineFile) 77 | std::cout << "all characters read successfully." << std::endl; 78 | else 79 | std::cout << "error: only " << engineFile.gcount() << " could be read" << std::endl; 80 | engineFile.close(); 81 | 82 | std::unique_ptr runtime(nvinfer1::createInferRuntime(logger)); 83 | std::shared_ptr mEngine(runtime->deserializeCudaEngine(engineData.data(), fsize, nullptr)); 84 | cv::Mat frame = cv::imread("/workspace/segment-anything-tensorrt/data/truck.jpg"); 85 | std::cout << frame.size << std::endl; 86 | eng_0 = std::shared_ptr(new SamEmbedding(std::to_string(1), mEngine, frame)); 87 | auto res = eng_0->prepareInput(); 88 | std::cout << "------------------prepareInput: " << res << std::endl; 89 | res = eng_0->infer(); 90 | std::cout << "------------------infer: " << res << std::endl; 91 | image_embeddings = eng_0->verifyOutput("image_embeddings_part_1"); 92 | std::cout << "------------------verifyOutput: " << std::endl; 93 | { 94 | // const std::string modelFile = "D:/projects/detections/data/vit_l_embedding.engine"; 95 | const std::string modelFile = "vit_h_embedding_part_2.engine"; 96 | std::ifstream engineFile(modelFile.c_str(), std::ifstream::binary); 97 | assert(engineFile); 98 | // if (!engineFile) 99 | // return; 100 | 101 | int fsize; 102 | engineFile.seekg(0, engineFile.end); 103 | fsize = engineFile.tellg(); 104 | engineFile.seekg(0, engineFile.beg); 105 | std::vector engineData(fsize); 106 | engineFile.read(engineData.data(), fsize); 107 | 108 | if (engineFile) 109 | std::cout << "all characters read successfully." 
<< std::endl; 110 | else 111 | std::cout << "error: only " << engineFile.gcount() << " could be read" << std::endl; 112 | engineFile.close(); 113 | 114 | std::unique_ptr runtime(nvinfer1::createInferRuntime(logger)); 115 | std::shared_ptr mEngine(runtime->deserializeCudaEngine(engineData.data(), fsize, nullptr)); 116 | eng_2 = std::shared_ptr(new SamEmbedding2(std::to_string(1), mEngine)); 117 | auto res = eng_2->prepareInput(image_embeddings); 118 | std::cout << "------------------prepareInput: " << res << std::endl; 119 | res = eng_2->infer(); 120 | std::cout << "------------------infer: " << res << std::endl; 121 | image_embeddings = eng_2->verifyOutput(); 122 | std::cout << "------------------verifyOutput: " << std::endl; 123 | } 124 | } 125 | 126 | #endif 127 | 128 | #ifdef SAMPROMPTENCODERANDMASKDECODER 129 | { 130 | // const std::string modelFile = "D:/projects/detections/data/sam_onnx_example.engine"; 131 | const std::string modelFile = "sam_onnx_decoder.engine"; 132 | std::ifstream engineFile(modelFile.c_str(), std::ifstream::binary); 133 | assert(engineFile); 134 | // if (!engineFile) 135 | // return; 136 | 137 | int fsize; 138 | engineFile.seekg(0, engineFile.end); 139 | fsize = engineFile.tellg(); 140 | engineFile.seekg(0, engineFile.beg); 141 | std::vector engineData(fsize); 142 | engineFile.read(engineData.data(), fsize); 143 | 144 | if (engineFile) 145 | std::cout << "all characters read successfully." << std::endl; 146 | else 147 | std::cout << "error: only " << engineFile.gcount() << " could be read" << std::endl; 148 | engineFile.close(); 149 | 150 | std::unique_ptr runtime(nvinfer1::createInferRuntime(logger)); 151 | std::shared_ptr mEngine(runtime->deserializeCudaEngine(engineData.data(), fsize, nullptr)); 152 | cv::Mat frame = cv::imread("/workspace/segment-anything-tensorrt/data/truck.jpg"); 153 | eng_1 = std::shared_ptr(new SamPromptEncoderAndMaskDecoder(std::to_string(1), mEngine, frame)); 154 | // namedWindow("img_", 0); // declaring window to show image// 155 | // setMouseCallback("img_", locator, NULL); // Mouse callback function on define window// 156 | // imshow("img_", frame); // showing image on the window// 157 | // waitKey(0); // wait for keystroke// 158 | 159 | auto res = eng_1->prepareInput(760, 373, image_embeddings); 160 | // // std::vector mult_pts = {x,y,x-5,y-5,x+5,y+5}; 161 | // // auto res = eng_1->prepareInput(mult_pts, image_embeddings); 162 | // std::cout << "------------------prepareInput: " << res << std::endl; 163 | res = eng_1->infer(); 164 | // std::cout << "------------------infer: " << res << std::endl; 165 | eng_1->verifyOutput(); 166 | // std::cout << "-----------------done" << std::endl; 167 | } 168 | #endif 169 | } -------------------------------------------------------------------------------- /sam.h: -------------------------------------------------------------------------------- 1 | #ifndef SAM_H 2 | #define SAM_H 3 | 4 | #include "buffers.h" 5 | #include 6 | #include 7 | // #include 8 | #include 9 | #include "sam_utils.h" 10 | 11 | using namespace torch::indexing; 12 | 13 | class ResizeLongestSide 14 | { 15 | public: 16 | ResizeLongestSide(int target_length); 17 | ~ResizeLongestSide(); 18 | 19 | std::vector get_preprocess_shape(int oldh, int oldw); 20 | at::Tensor apply_coords(at::Tensor boxes, at::IntArrayRef sz); 21 | 22 | public: 23 | int m_target_length; 24 | }; 25 | 26 | ResizeLongestSide::ResizeLongestSide(int target_length) : m_target_length(target_length) 27 | { 28 | } 29 | 30 | ResizeLongestSide::~ResizeLongestSide() 31 | 
{ 32 | } 33 | 34 | std::vector ResizeLongestSide::get_preprocess_shape(int oldh, int oldw) 35 | { 36 | float scale = m_target_length * 1.0 / std::max(oldh, oldw); 37 | int newh = static_cast(oldh * scale + 0.5); 38 | int neww = static_cast(oldw * scale + 0.5); 39 | std::cout << " newh " << newh << " neww " << neww << std::endl; 40 | std::cout << "at::IntArrayRef{newh, neww}" << at::IntArrayRef{newh, neww} << std::endl; 41 | return std::vector{newh, neww}; 42 | } 43 | 44 | at::Tensor ResizeLongestSide::apply_coords(at::Tensor coords, at::IntArrayRef sz) 45 | { 46 | int old_h = sz[0], old_w = sz[1]; 47 | auto new_sz = get_preprocess_shape(old_h, old_w); 48 | int new_h = new_sz[0], new_w = new_sz[1]; 49 | coords.index_put_({"...", 0}, coords.index({"...", 0}) * (1.0 * new_w / old_w)); 50 | coords.index_put_({"...", 1}, coords.index({"...", 1}) * (1.0 * new_h / old_h)); 51 | return coords; 52 | } 53 | 54 | //////////////////////////////////////////////////////////////////////////////////// 55 | 56 | class SamEmbedding 57 | { 58 | public: 59 | SamEmbedding(std::string bufferName, std::shared_ptr &engine, cv::Mat im, int width = 640, int height = 640); 60 | ~SamEmbedding(); 61 | 62 | int prepareInput(); 63 | bool infer(); 64 | at::Tensor verifyOutput(); 65 | at::Tensor verifyOutput(std::string output_name); 66 | 67 | public: 68 | std::shared_ptr mEngine; 69 | std::unique_ptr context; 70 | 71 | cudaStream_t stream; 72 | cudaEvent_t start, end; 73 | 74 | std::vector mDeviceBindings; 75 | std::map> mInOut; 76 | std::vector pad_info; 77 | std::vector names; 78 | cv::Mat frame; 79 | cv::Mat img; 80 | int inp_width = 640; 81 | int inp_height = 640; 82 | std::string mBufferName; 83 | }; 84 | 85 | SamEmbedding::SamEmbedding(std::string bufferName, std::shared_ptr &engine, cv::Mat im, int width, int height) : mBufferName(bufferName), mEngine(engine), frame(im), inp_width(width), inp_height(height) 86 | { 87 | context = std::unique_ptr(mEngine->createExecutionContext()); 88 | if (!context) 89 | { 90 | std::cerr << "create context error" << std::endl; 91 | } 92 | 93 | CHECK(cudaStreamCreate(&stream)); 94 | CHECK(cudaEventCreateWithFlags(&start, cudaEventBlockingSync)); 95 | CHECK(cudaEventCreateWithFlags(&end, cudaEventBlockingSync)); 96 | 97 | for (int i = 0; i < mEngine->getNbBindings(); i++) 98 | { 99 | auto dims = mEngine->getBindingDimensions(i); 100 | auto tensor_name = mEngine->getBindingName(i); 101 | std::cout << "tensor_name: " << tensor_name << std::endl; 102 | // dims2str(dims); 103 | nvinfer1::DataType type = mEngine->getBindingDataType(i); 104 | // index2srt(type); 105 | int vecDim = mEngine->getBindingVectorizedDim(i); 106 | // std::cout << "vecDim:" << vecDim << std::endl; 107 | if (-1 != vecDim) // i.e., 0 != lgScalarsPerVector 108 | { 109 | int scalarsPerVec = mEngine->getBindingComponentsPerElement(i); 110 | std::cout << "scalarsPerVec" << scalarsPerVec << std::endl; 111 | } 112 | auto vol = std::accumulate(dims.d, dims.d + dims.nbDims, int64_t{1}, std::multiplies{}); 113 | std::unique_ptr device_buffer{new algorithms::DeviceBuffer(vol, type)}; 114 | mDeviceBindings.emplace_back(device_buffer->data()); 115 | mInOut[tensor_name] = std::move(device_buffer); 116 | } 117 | } 118 | 119 | SamEmbedding::~SamEmbedding() 120 | { 121 | CHECK(cudaEventDestroy(start)); 122 | CHECK(cudaEventDestroy(end)); 123 | CHECK(cudaStreamDestroy(stream)); 124 | } 125 | 126 | int SamEmbedding::prepareInput() 127 | { 128 | int prompt_embed_dim = 256; 129 | int image_size = 1024; 130 | int vit_patch_size = 16; 131 
| int target_length = image_size; 132 | auto pixel_mean = at::tensor({123.675, 116.28, 103.53}, torch::kFloat).view({-1, 1, 1}); 133 | auto pixel_std = at::tensor({58.395, 57.12, 57.375}, torch::kFloat).view({-1, 1, 1}); 134 | ResizeLongestSide transf(image_size); 135 | int newh, neww; 136 | auto target_size = transf.get_preprocess_shape(frame.rows, frame.cols); 137 | // std::cout << " " << torch::IntArrayRef{newh,neww} << std::endl; 138 | std::cout << "target_size = " << target_size << std::endl; 139 | cv::Mat im_sz; 140 | std::cout << frame.size << std::endl; 141 | cv::resize(frame, im_sz, cv::Size(target_size[1], target_size[0])); 142 | im_sz.convertTo(im_sz, CV_32F, 1.0); 143 | at::Tensor input_image_torch = 144 | at::from_blob(im_sz.data, {im_sz.rows, im_sz.cols, im_sz.channels()}) 145 | .permute({2, 0, 1}) 146 | .contiguous() 147 | .unsqueeze(0); 148 | input_image_torch = (input_image_torch - pixel_mean) / pixel_std; 149 | int h = input_image_torch.size(2); 150 | int w = input_image_torch.size(3); 151 | int padh = image_size - h; 152 | int padw = image_size - w; 153 | input_image_torch = at::pad(input_image_torch, {0, padw, 0, padh}); 154 | auto ret = mInOut["images"]->host2device((void *)(input_image_torch.data_ptr()), true, stream); 155 | return ret; 156 | } 157 | 158 | bool SamEmbedding::infer() 159 | { 160 | CHECK(cudaEventRecord(start, stream)); 161 | auto ret = context->enqueueV2(mDeviceBindings.data(), stream, nullptr); 162 | return ret; 163 | } 164 | 165 | at::Tensor SamEmbedding::verifyOutput() 166 | { 167 | float ms{0.0f}; 168 | CHECK(cudaEventRecord(end, stream)); 169 | CHECK(cudaEventSynchronize(end)); 170 | CHECK(cudaEventElapsedTime(&ms, start, end)); 171 | 172 | auto dim0 = mEngine->getTensorShape("image_embeddings"); 173 | 174 | // dims2str(dim0); 175 | // dims2str(dim1); 176 | at::Tensor preds; 177 | preds = at::zeros({dim0.d[0], dim0.d[1], dim0.d[2], dim0.d[3]}, at::kFloat); 178 | mInOut["image_embeddings"]->device2host((void *)(preds.data_ptr()), stream); 179 | 180 | // Wait for the work in the stream to complete 181 | CHECK(cudaStreamSynchronize(stream)); 182 | // torch::save({preds}, "preds.pt"); 183 | // cv::FileStorage storage("1.yaml", cv::FileStorage::WRITE); 184 | // storage << "image_embeddings" << points3dmatrix; 185 | return preds; 186 | } 187 | 188 | at::Tensor SamEmbedding::verifyOutput(std::string output_name) 189 | { 190 | float ms{0.0f}; 191 | CHECK(cudaEventRecord(end, stream)); 192 | CHECK(cudaEventSynchronize(end)); 193 | CHECK(cudaEventElapsedTime(&ms, start, end)); 194 | 195 | auto dim0 = mEngine->getTensorShape(output_name.c_str()); 196 | 197 | // dims2str(dim0); 198 | // dims2str(dim1); 199 | at::Tensor preds; 200 | preds = at::zeros({dim0.d[0], dim0.d[1], dim0.d[2], dim0.d[3]}, at::kFloat); 201 | mInOut[output_name]->device2host((void *)(preds.data_ptr()), stream); 202 | 203 | // Wait for the work in the stream to complete 204 | CHECK(cudaStreamSynchronize(stream)); 205 | // torch::save({preds}, "preds.pt"); 206 | // cv::FileStorage storage("1.yaml", cv::FileStorage::WRITE); 207 | // storage << "image_embeddings" << points3dmatrix; 208 | return preds; 209 | } 210 | 211 | /////////////////////////////////////////////////// 212 | 213 | class SamEmbedding2 214 | { 215 | public: 216 | SamEmbedding2(std::string bufferName, std::shared_ptr &engine); 217 | ~SamEmbedding2(); 218 | 219 | int prepareInput(at::Tensor input_image_torch); 220 | bool infer(); 221 | at::Tensor verifyOutput(); 222 | 223 | public: 224 | std::shared_ptr mEngine; 225 | 
std::unique_ptr context; 226 | 227 | cudaStream_t stream; 228 | cudaEvent_t start, end; 229 | 230 | std::vector mDeviceBindings; 231 | std::map> mInOut; 232 | std::vector pad_info; 233 | std::vector names; 234 | std::string mBufferName; 235 | }; 236 | 237 | SamEmbedding2::SamEmbedding2(std::string bufferName, std::shared_ptr &engine) : 238 | mBufferName(bufferName), mEngine(engine) 239 | { 240 | context = std::unique_ptr(mEngine->createExecutionContext()); 241 | if (!context) 242 | { 243 | std::cerr << "create context error" << std::endl; 244 | } 245 | 246 | CHECK(cudaStreamCreate(&stream)); 247 | CHECK(cudaEventCreateWithFlags(&start, cudaEventBlockingSync)); 248 | CHECK(cudaEventCreateWithFlags(&end, cudaEventBlockingSync)); 249 | 250 | for (int i = 0; i < mEngine->getNbBindings(); i++) 251 | { 252 | auto dims = mEngine->getBindingDimensions(i); 253 | auto tensor_name = mEngine->getBindingName(i); 254 | std::cout << "tensor_name: " << tensor_name << std::endl; 255 | // dims2str(dims); 256 | nvinfer1::DataType type = mEngine->getBindingDataType(i); 257 | // index2srt(type); 258 | int vecDim = mEngine->getBindingVectorizedDim(i); 259 | // std::cout << "vecDim:" << vecDim << std::endl; 260 | if (-1 != vecDim) // i.e., 0 != lgScalarsPerVector 261 | { 262 | int scalarsPerVec = mEngine->getBindingComponentsPerElement(i); 263 | std::cout << "scalarsPerVec" << scalarsPerVec << std::endl; 264 | } 265 | auto vol = std::accumulate(dims.d, dims.d + dims.nbDims, int64_t{1}, std::multiplies{}); 266 | std::unique_ptr device_buffer{new algorithms::DeviceBuffer(vol, type)}; 267 | mDeviceBindings.emplace_back(device_buffer->data()); 268 | mInOut[tensor_name] = std::move(device_buffer); 269 | } 270 | } 271 | 272 | SamEmbedding2::~SamEmbedding2() 273 | { 274 | CHECK(cudaEventDestroy(start)); 275 | CHECK(cudaEventDestroy(end)); 276 | CHECK(cudaStreamDestroy(stream)); 277 | } 278 | 279 | int SamEmbedding2::prepareInput(at::Tensor input_image_torch) 280 | { 281 | auto ret = mInOut["image_embeddings_part_1"]->host2device((void *)(input_image_torch.data_ptr()), false, stream); 282 | return ret; 283 | } 284 | 285 | bool SamEmbedding2::infer() 286 | { 287 | CHECK(cudaEventRecord(start, stream)); 288 | auto ret = context->enqueueV2(mDeviceBindings.data(), stream, nullptr); 289 | return ret; 290 | } 291 | 292 | at::Tensor SamEmbedding2::verifyOutput() 293 | { 294 | float ms{0.0f}; 295 | CHECK(cudaEventRecord(end, stream)); 296 | CHECK(cudaEventSynchronize(end)); 297 | CHECK(cudaEventElapsedTime(&ms, start, end)); 298 | 299 | auto dim0 = mEngine->getTensorShape("image_embeddings_part_2"); 300 | 301 | // dims2str(dim0); 302 | // dims2str(dim1); 303 | at::Tensor preds; 304 | preds = at::zeros({dim0.d[0], dim0.d[1], dim0.d[2], dim0.d[3]}, at::kFloat); 305 | mInOut["image_embeddings_part_2"]->device2host((void *)(preds.data_ptr()), stream); 306 | 307 | // Wait for the work in the stream to complete 308 | CHECK(cudaStreamSynchronize(stream)); 309 | // torch::save({preds}, "preds.pt"); 310 | // cv::FileStorage storage("1.yaml", cv::FileStorage::WRITE); 311 | // storage << "image_embeddings" << points3dmatrix; 312 | return preds; 313 | } 314 | 315 | /////////////////////////////////////////////////// 316 | 317 | 318 | 319 | class SamPromptEncoderAndMaskDecoder 320 | { 321 | public: 322 | SamPromptEncoderAndMaskDecoder(std::string bufferName, std::shared_ptr &engine, cv::Mat im, int width = 640, int height = 640); 323 | ~SamPromptEncoderAndMaskDecoder(); 324 | 325 | int prepareInput(int x, int y, at::Tensor 
image_embeddings); 326 | int prepareInput(int x, int y, int x1, int y1, int x2, int y2, at::Tensor image_embeddings); 327 | int prepareInput(std::vector mult_pts, at::Tensor image_embeddings); 328 | bool infer(); 329 | int verifyOutput(); 330 | int verifyOutput(cv::Mat& roi); 331 | at::Tensor generator_colors(int num); 332 | 333 | template 334 | Type string2Num(const std::string &str); 335 | 336 | at::Tensor plot_masks(at::Tensor masks, at::Tensor im_gpu, float alpha); 337 | 338 | public: 339 | std::shared_ptr mEngine; 340 | std::unique_ptr context; 341 | 342 | cudaStream_t stream; 343 | cudaEvent_t start, end; 344 | 345 | std::vector mDeviceBindings; 346 | std::map> mInOut; 347 | std::vector pad_info; 348 | std::vector names; 349 | cv::Mat frame; 350 | cv::Mat img; 351 | int inp_width = 640; 352 | int inp_height = 640; 353 | std::string mBufferName; 354 | }; 355 | 356 | SamPromptEncoderAndMaskDecoder::SamPromptEncoderAndMaskDecoder(std::string bufferName, std::shared_ptr &engine, cv::Mat im, int width, int height) : mBufferName(bufferName), mEngine(engine), frame(im), inp_width(width), inp_height(height) 357 | { 358 | context = std::unique_ptr(mEngine->createExecutionContext()); 359 | if (!context) 360 | { 361 | std::cerr << "create context error" << std::endl; 362 | } 363 | // set input dims whichs name "point_coords " 364 | context->setBindingDimensions(1, nvinfer1::Dims3(1, 2, 2)); 365 | // set input dims whichs name "point_label " 366 | context->setBindingDimensions(2, nvinfer1::Dims2(1, 2)); 367 | // set input dims whichs name "point_label " 368 | // context->setBindingDimensions(5, nvinfer1::Dims2(frame.rows,frame.cols)); 369 | CHECK(cudaStreamCreate(&stream)); 370 | CHECK(cudaEventCreateWithFlags(&start, cudaEventBlockingSync)); 371 | CHECK(cudaEventCreateWithFlags(&end, cudaEventBlockingSync)); 372 | 373 | int nbopts = mEngine->getNbOptimizationProfiles(); 374 | // std::cout << "nboopts: " << nbopts << std::endl; 375 | for (int i = 0; i < mEngine->getNbBindings(); i++) 376 | { 377 | // auto dims = mEngine->getBindingDimensions(i); 378 | auto tensor_name = mEngine->getBindingName(i); 379 | // std::cout << "tensor_name: " << tensor_name << std::endl; 380 | auto dims = context->getBindingDimensions(i); 381 | // dims2str(dims); 382 | nvinfer1::DataType type = mEngine->getBindingDataType(i); 383 | // index2srt(type); 384 | auto vol = std::accumulate(dims.d, dims.d + dims.nbDims, int64_t{1}, std::multiplies{}); 385 | std::unique_ptr device_buffer{new algorithms::DeviceBuffer(vol, type)}; 386 | mDeviceBindings.emplace_back(device_buffer->data()); 387 | mInOut[tensor_name] = std::move(device_buffer); 388 | } 389 | } 390 | 391 | SamPromptEncoderAndMaskDecoder::~SamPromptEncoderAndMaskDecoder() 392 | { 393 | CHECK(cudaEventDestroy(start)); 394 | CHECK(cudaEventDestroy(end)); 395 | CHECK(cudaStreamDestroy(stream)); 396 | } 397 | 398 | int SamPromptEncoderAndMaskDecoder::prepareInput(int x, int y, at::Tensor image_embeddings) 399 | { 400 | // at::Tensor image_embeddings; 401 | 402 | // torch::load(image_embeddings, "preds.pt"); 403 | // std::cout << image_embeddings.sizes() << std::endl; 404 | int image_size = 1024; 405 | ResizeLongestSide transf(image_size); 406 | 407 | auto input_point = at::tensor({x, y}, at::kFloat).reshape({-1,2}); 408 | auto input_label = at::tensor({1}, at::kFloat); 409 | 410 | auto trt_coord = at::concatenate({input_point, at::tensor({0, 0}, at::kFloat).unsqueeze(0)}, 0).unsqueeze(0); 411 | auto trt_label = at::concatenate({input_label, at::tensor({-1}, at::kFloat)}, 
0).unsqueeze(0); 412 | // auto trt_coord = at::concatenate({input_point, at::tensor({x-100, y-100, x+100, y+100}, at::kFloat).reshape({-1,2})}, 0).unsqueeze(0); 413 | // auto trt_label = at::concatenate({input_label, at::tensor({2,3}, at::kFloat)}, 0).unsqueeze(0); 414 | trt_coord = transf.apply_coords(trt_coord, {frame.rows, frame.cols}); 415 | // std::cout << "trt_coord " << trt_coord.sizes() << std::endl; 416 | auto trt_mask_input = at::zeros({1, 1, 256, 256}, at::kFloat); 417 | auto trt_has_mask_input = at::zeros(1, at::kFloat); 418 | 419 | context->setBindingDimensions(1, nvinfer1::Dims3(trt_coord.size(0), trt_coord.size(1), trt_coord.size(2))); 420 | // set input dims whichs name "point_label " 421 | context->setBindingDimensions(2, nvinfer1::Dims2(trt_coord.size(0), trt_coord.size(1))); 422 | int nbopts = mEngine->getNbOptimizationProfiles(); 423 | std::cout << "nboopts: " << nbopts << std::endl; 424 | for (int i = 0; i < mEngine->getNbBindings(); i++) 425 | { 426 | // auto dims = mEngine->getBindingDimensions(i); 427 | auto tensor_name = mEngine->getBindingName(i); 428 | std::cout << "tensor_name: " << tensor_name << std::endl; 429 | auto dims = context->getBindingDimensions(i); 430 | dims2str(dims); 431 | nvinfer1::DataType type = mEngine->getBindingDataType(i); 432 | index2srt(type); 433 | auto vol = std::accumulate(dims.d, dims.d + dims.nbDims, int64_t{1}, std::multiplies{}); 434 | 435 | mInOut[tensor_name]->resize(dims); 436 | } 437 | 438 | CHECK(mInOut["image_embeddings"]->host2device((void *)(image_embeddings.data_ptr()), true, stream)); 439 | CHECK(cudaStreamSynchronize(stream)); 440 | CHECK(mInOut["point_coords"]->host2device((void *)(trt_coord.data_ptr()), true, stream)); 441 | CHECK(cudaStreamSynchronize(stream)); 442 | CHECK(mInOut["point_labels"]->host2device((void *)(trt_label.data_ptr()), true, stream)); 443 | CHECK(cudaStreamSynchronize(stream)); 444 | CHECK(mInOut["mask_input"]->host2device((void *)(trt_mask_input.data_ptr()), true, stream)); 445 | CHECK(cudaStreamSynchronize(stream)); 446 | CHECK(mInOut["has_mask_input"]->host2device((void *)(trt_has_mask_input.data_ptr()), true, stream)); 447 | CHECK(cudaStreamSynchronize(stream)); 448 | return 0; 449 | } 450 | 451 | int SamPromptEncoderAndMaskDecoder::prepareInput(int x, int y, int x1, int y1, int x2, int y2, at::Tensor image_embeddings) 452 | { 453 | // at::Tensor image_embeddings; 454 | 455 | // torch::load(image_embeddings, "preds.pt"); 456 | // std::cout << image_embeddings.sizes() << std::endl; 457 | int image_size = 1024; 458 | ResizeLongestSide transf(image_size); 459 | 460 | auto input_point = at::tensor({x, y}, at::kFloat).reshape({-1,2}); 461 | auto input_label = at::tensor({1}, at::kFloat); 462 | 463 | auto trt_coord = at::concatenate({input_point, at::tensor({x1, y1, x2, y2}, at::kFloat).reshape({-1,2})}, 0).unsqueeze(0); 464 | auto trt_label = at::concatenate({input_label, at::tensor({2,3}, at::kFloat)}, 0).unsqueeze(0); 465 | trt_coord = transf.apply_coords(trt_coord, {frame.rows, frame.cols}); 466 | // std::cout << "trt_coord " << trt_coord.sizes() << std::endl; 467 | auto trt_mask_input = at::zeros({1, 1, 256, 256}, at::kFloat); 468 | auto trt_has_mask_input = at::zeros(1, at::kFloat); 469 | 470 | context->setBindingDimensions(1, nvinfer1::Dims3(trt_coord.size(0), trt_coord.size(1), trt_coord.size(2))); 471 | // set input dims whichs name "point_label " 472 | context->setBindingDimensions(2, nvinfer1::Dims2(trt_coord.size(0), trt_coord.size(1))); 473 | 474 | for (int i = 0; i < 
mEngine->getNbBindings(); i++) 475 | { 476 | // auto dims = mEngine->getBindingDimensions(i); 477 | auto tensor_name = mEngine->getBindingName(i); 478 | std::cout << "tensor_name: " << tensor_name << std::endl; 479 | auto dims = context->getBindingDimensions(i); 480 | dims2str(dims); 481 | nvinfer1::DataType type = mEngine->getBindingDataType(i); 482 | index2srt(type); 483 | auto vol = std::accumulate(dims.d, dims.d + dims.nbDims, int64_t{1}, std::multiplies{}); 484 | 485 | mInOut[tensor_name]->resize(dims); 486 | } 487 | 488 | CHECK(mInOut["image_embeddings"]->host2device((void *)(image_embeddings.data_ptr()), true, stream)); 489 | CHECK(cudaStreamSynchronize(stream)); 490 | CHECK(mInOut["point_coords"]->host2device((void *)(trt_coord.data_ptr()), true, stream)); 491 | CHECK(cudaStreamSynchronize(stream)); 492 | CHECK(mInOut["point_labels"]->host2device((void *)(trt_label.data_ptr()), true, stream)); 493 | CHECK(cudaStreamSynchronize(stream)); 494 | CHECK(mInOut["mask_input"]->host2device((void *)(trt_mask_input.data_ptr()), true, stream)); 495 | CHECK(cudaStreamSynchronize(stream)); 496 | CHECK(mInOut["has_mask_input"]->host2device((void *)(trt_has_mask_input.data_ptr()), true, stream)); 497 | CHECK(cudaStreamSynchronize(stream)); 498 | return 0; 499 | } 500 | 501 | int SamPromptEncoderAndMaskDecoder::prepareInput(std::vector mult_pts, at::Tensor image_embeddings) 502 | { 503 | // at::Tensor image_embeddings; 504 | 505 | // torch::load(image_embeddings, "preds.pt"); 506 | // std::cout << image_embeddings.sizes() << std::endl; 507 | int image_size = 1024; 508 | ResizeLongestSide transf(image_size); 509 | 510 | auto input_point = at::tensor(mult_pts, at::kFloat).reshape({-1,2}); 511 | std::cout << input_point << std::endl; 512 | auto input_label = at::ones({int(mult_pts.size() / 2)}, at::kFloat); 513 | std::cout << input_label << std::endl; 514 | 515 | auto trt_coord = at::concatenate({input_point, at::tensor({0, 0}, at::kFloat).unsqueeze(0)}, 0).unsqueeze(0); 516 | auto trt_label = at::concatenate({input_label, at::tensor({-1}, at::kFloat)}, 0).unsqueeze(0); 517 | trt_coord = transf.apply_coords(trt_coord, {frame.rows, frame.cols}); 518 | // std::cout << "trt_coord " << trt_coord.sizes() << std::endl; 519 | auto trt_mask_input = at::zeros({1, 1, 256, 256}, at::kFloat); 520 | auto trt_has_mask_input = at::zeros(1, at::kFloat); 521 | 522 | context->setBindingDimensions(1, nvinfer1::Dims3(trt_coord.size(0), trt_coord.size(1), trt_coord.size(2))); 523 | // set input dims whichs name "point_label " 524 | context->setBindingDimensions(2, nvinfer1::Dims2(trt_coord.size(0), trt_coord.size(1))); 525 | 526 | for (int i = 0; i < mEngine->getNbBindings(); i++) 527 | { 528 | // auto dims = mEngine->getBindingDimensions(i); 529 | auto tensor_name = mEngine->getBindingName(i); 530 | std::cout << "tensor_name: " << tensor_name << std::endl; 531 | auto dims = context->getBindingDimensions(i); 532 | dims2str(dims); 533 | nvinfer1::DataType type = mEngine->getBindingDataType(i); 534 | index2srt(type); 535 | auto vol = std::accumulate(dims.d, dims.d + dims.nbDims, int64_t{1}, std::multiplies{}); 536 | 537 | mInOut[tensor_name]->resize(dims); 538 | } 539 | 540 | CHECK(mInOut["image_embeddings"]->host2device((void *)(image_embeddings.data_ptr()), true, stream)); 541 | CHECK(cudaStreamSynchronize(stream)); 542 | CHECK(mInOut["point_coords"]->host2device((void *)(trt_coord.data_ptr()), true, stream)); 543 | CHECK(cudaStreamSynchronize(stream)); 544 | CHECK(mInOut["point_labels"]->host2device((void 
*)(trt_label.data_ptr()), true, stream)); 545 | CHECK(cudaStreamSynchronize(stream)); 546 | CHECK(mInOut["mask_input"]->host2device((void *)(trt_mask_input.data_ptr()), true, stream)); 547 | CHECK(cudaStreamSynchronize(stream)); 548 | CHECK(mInOut["has_mask_input"]->host2device((void *)(trt_has_mask_input.data_ptr()), true, stream)); 549 | CHECK(cudaStreamSynchronize(stream)); 550 | return 0; 551 | } 552 | 553 | bool SamPromptEncoderAndMaskDecoder::infer() 554 | { 555 | CHECK(cudaEventRecord(start, stream)); 556 | auto ret = context->enqueueV2(mDeviceBindings.data(), stream, nullptr); 557 | return ret; 558 | } 559 | 560 | int SamPromptEncoderAndMaskDecoder::verifyOutput() 561 | { 562 | float ms{0.0f}; 563 | CHECK(cudaEventRecord(end, stream)); 564 | CHECK(cudaEventSynchronize(end)); 565 | CHECK(cudaEventElapsedTime(&ms, start, end)); 566 | 567 | auto dim0 = mEngine->getTensorShape("masks"); 568 | auto dim1 = mEngine->getTensorShape("scores"); 569 | // dims2str(dim0); 570 | // dims2str(dim1); 571 | at::Tensor masks; 572 | masks = at::zeros({dim0.d[0], dim0.d[1], dim0.d[2], dim0.d[3]}, at::kFloat); 573 | mInOut["masks"]->device2host((void *)(masks.data_ptr()), stream); 574 | // Wait for the work in the stream to complete 575 | CHECK(cudaStreamSynchronize(stream)); 576 | 577 | int longest_side = 1024; 578 | 579 | namespace F = torch::nn::functional; 580 | masks = F::interpolate(masks, F::InterpolateFuncOptions().size(std::vector({longest_side, longest_side})).mode(torch::kBilinear).align_corners(false)); 581 | // at::IntArrayRef input_image_size{frame.rows, frame.cols}; 582 | ResizeLongestSide transf(longest_side); 583 | auto target_size = transf.get_preprocess_shape(frame.rows, frame.cols); 584 | masks = masks.index({"...", Slice(None, target_size[0]), Slice(None, target_size[1])}); 585 | 586 | masks = F::interpolate(masks, F::InterpolateFuncOptions().size(std::vector({frame.rows, frame.cols})).mode(torch::kBilinear).align_corners(false)); 587 | std::cout << "masks: " << masks.sizes() << std::endl; 588 | 589 | at::Tensor iou_predictions; 590 | iou_predictions = at::zeros({dim0.d[0], dim0.d[1]}, at::kFloat); 591 | mInOut["scores"]->device2host((void *)(iou_predictions.data_ptr()), stream); 592 | // Wait for the work in the stream to complete 593 | CHECK(cudaStreamSynchronize(stream)); 594 | 595 | torch::DeviceType device_type; 596 | if (torch::cuda::is_available()) 597 | { 598 | std::cout << "CUDA available! Training on GPU." << std::endl; 599 | device_type = torch::kCUDA; 600 | } 601 | else 602 | { 603 | std::cout << "Training on CPU." << std::endl; 604 | device_type = torch::kCPU; 605 | } 606 | 607 | torch::Device device(device_type); 608 | masks = masks.gt(0.) 
* 1.0; 609 | std::cout << "max " << masks.max() << std::endl; 610 | // masks = masks.sigmoid(); 611 | std::cout << "masks: " << masks.sizes() << std::endl; 612 | masks = masks.to(device); 613 | std::cout << "iou_predictions: " << iou_predictions << std::endl; 614 | cv::Mat img; 615 | // cv::Mat frame = cv::imread("D:/projects/detections/data/truck.jpg"); 616 | frame.convertTo(img, CV_32F, 1.0 / 255); 617 | at::Tensor im_gpu = 618 | at::from_blob(img.data, {img.rows, img.cols, img.channels()}) 619 | .permute({2, 0, 1}) 620 | .contiguous() 621 | .to(device); 622 | auto results = plot_masks(masks, im_gpu, 0.5); 623 | auto t_img = results.to(torch::kCPU).clamp(0, 255).to(torch::kU8); 624 | 625 | auto img_ = cv::Mat(t_img.size(0), t_img.size(1), CV_8UC3, t_img.data_ptr()); 626 | std::cout << "1111111111111111" << std::endl; 627 | cv::cvtColor(img_, img_, cv::COLOR_RGB2BGR); 628 | cv::imwrite("img1111.jpg",img_); 629 | // cv::imshow("img_", img_); 630 | return 0; 631 | } 632 | 633 | int SamPromptEncoderAndMaskDecoder::verifyOutput(cv::Mat& roi) 634 | { 635 | float ms{0.0f}; 636 | CHECK(cudaEventRecord(end, stream)); 637 | CHECK(cudaEventSynchronize(end)); 638 | CHECK(cudaEventElapsedTime(&ms, start, end)); 639 | 640 | auto dim0 = mEngine->getTensorShape("masks"); 641 | auto dim1 = mEngine->getTensorShape("scores"); 642 | // dims2str(dim0); 643 | // dims2str(dim1); 644 | at::Tensor masks; 645 | masks = at::zeros({dim0.d[0], dim0.d[1], dim0.d[2], dim0.d[3]}, at::kFloat); 646 | mInOut["masks"]->device2host((void *)(masks.data_ptr()), stream); 647 | // Wait for the work in the stream to complete 648 | CHECK(cudaStreamSynchronize(stream)); 649 | 650 | int longest_side = 1024; 651 | 652 | namespace F = torch::nn::functional; 653 | masks = F::interpolate(masks, F::InterpolateFuncOptions().size(std::vector({longest_side, longest_side})).mode(torch::kBilinear).align_corners(false)); 654 | // at::IntArrayRef input_image_size{frame.rows, frame.cols}; 655 | ResizeLongestSide transf(longest_side); 656 | auto target_size = transf.get_preprocess_shape(frame.rows, frame.cols); 657 | masks = masks.index({"...", Slice(None, target_size[0]), Slice(None, target_size[1])}); 658 | 659 | masks = F::interpolate(masks, F::InterpolateFuncOptions().size(std::vector({frame.rows, frame.cols})).mode(torch::kBilinear).align_corners(false)); 660 | // std::cout << "masks: " << masks.sizes() << std::endl; 661 | 662 | at::Tensor iou_predictions; 663 | iou_predictions = at::zeros({dim0.d[0], dim0.d[1]}, at::kFloat); 664 | mInOut["scores"]->device2host((void *)(iou_predictions.data_ptr()), stream); 665 | // Wait for the work in the stream to complete 666 | CHECK(cudaStreamSynchronize(stream)); 667 | 668 | masks = masks.gt(0.) 
* 1.0;
669 | masks = masks.squeeze(0).squeeze(0);
670 | masks = masks.to(torch::kCPU).to(torch::kU8);
671 | std::cout << "masks: " << masks.sizes() << std::endl;
672 | auto roi_ = cv::Mat(masks.size(0), masks.size(1), CV_8U, masks.data_ptr());
673 | roi_.copyTo(roi);
674 |
675 | return 0;
676 | }
677 |
678 | /*
679 | return [r g b] * n
680 | */
681 | at::Tensor SamPromptEncoderAndMaskDecoder::generator_colors(int num)
682 | {
683 |
684 | std::vector hexs = {"FF37C7", "FF9D97", "FF701F", "FFB21D", "CFD231", "48F90A", "92CC17", "3DDB86", "1A9334", "00D4BB",
685 | "2C99A8", "00C2FF", "344593", "6473FF", "0018EC", "8438FF", "520085", "CB38FF", "FF95C8", "FF3838"};
686 |
687 | std::vector tmp;
688 | for (int i = 0; i < num; ++i)
689 | {
690 | int r = string2Num(hexs[i].substr(0, 2));
691 | // std::cout << r << std::endl;
692 | int g = string2Num(hexs[i].substr(2, 2));
693 | // std::cout << g << std::endl;
694 | int b = string2Num(hexs[i].substr(4, 2));
695 | // std::cout << b << std::endl;
696 | tmp.emplace_back(r);
697 | tmp.emplace_back(g);
698 | tmp.emplace_back(b);
699 | }
700 | return at::from_blob(tmp.data(), {(int)tmp.size()}, at::TensorOptions(at::kInt));
701 | }
702 |
703 | template
704 | Type SamPromptEncoderAndMaskDecoder::string2Num(const std::string &str)
705 | {
706 | std::istringstream iss(str);
707 | Type num;
708 | iss >> std::hex >> num;
709 | return num;
710 | }
711 |
712 | /*
713 | Plot masks at once.
714 | Args:
715 | masks (tensor): predicted masks on cuda, shape: [n, h, w]
716 | colors (List[List[Int]]): colors for predicted masks, [[r, g, b] * n]
717 | im_gpu (tensor): img is in cuda, shape: [3, h, w], range: [0, 1]
718 | alpha (float): mask transparency: 0.0 fully transparent, 1.0 opaque
719 | */
720 |
721 | at::Tensor SamPromptEncoderAndMaskDecoder::plot_masks(at::Tensor masks, at::Tensor im_gpu, float alpha)
722 | {
723 | int n = masks.size(0);
724 | auto colors = generator_colors(n);
725 | colors = colors.to(masks.device()).to(at::kFloat).div(255).reshape({-1, 3}).unsqueeze(1).unsqueeze(2);
726 | // std::cout << "colors: " << colors.sizes() << std::endl;
727 | masks = masks.permute({0, 2, 3, 1}).contiguous();
728 | // std::cout << "masks: " << masks.sizes() << std::endl;
729 | auto masks_color = masks * (colors * alpha);
730 | // std::cout << "masks_color: " << masks_color.sizes() << std::endl;
731 | auto inv_alph_masks = (1 - masks * alpha);
732 | inv_alph_masks = inv_alph_masks.cumprod(0);
733 | // std::cout << "inv_alph_masks: " << inv_alph_masks.sizes() << std::endl;
734 |
735 | auto mcs = masks_color * inv_alph_masks;
736 | mcs = mcs.sum(0) * 2;
737 | // std::cout << "mcs: " << mcs.sizes() << std::endl;
738 | im_gpu = im_gpu.flip({0});
739 | // std::cout << "im_gpu: " << im_gpu.sizes() << std::endl;
740 | im_gpu = im_gpu.permute({1, 2, 0}).contiguous();
741 | // std::cout << "im_gpu: " << im_gpu.sizes() << std::endl;
742 | im_gpu = im_gpu * inv_alph_masks[-1] + mcs;
743 | // std::cout << "im_gpu: " << im_gpu.sizes() << std::endl;
744 | auto im_mask = (im_gpu * 255);
745 | return im_mask;
746 | }
747 | #endif
--------------------------------------------------------------------------------
/sam_utils.h:
--------------------------------------------------------------------------------
1 | #ifndef SAM_UTILS_H
2 | #define SAM_UTILS_H
3 |
4 | #include
5 | #include
6 | #include
7 |
8 |
9 | using namespace nvinfer1;
10 | using namespace nvonnxparser;
11 |
12 | #undef CHECK
13 | #define CHECK(status) \
14 | do \
15 | { \
16 | auto ret = (status); \
17 | if (ret != 0) \
18 | { \
19 | std::cerr << "Cuda failure: " << ret << std::endl; \
20 | abort(); \
21 | } \
22 | } while (0)
23 |
24 | void index2srt(nvinfer1::DataType dataType)
25 | {
26 | switch (dataType)
27 | {
28 | case nvinfer1::DataType::kFLOAT:
29 | std::cout << "nvinfer1::DataType::kFLOAT" << std::endl;
30 | break;
31 | case nvinfer1::DataType::kHALF:
32 | std::cout << "nvinfer1::DataType::kHALF" << std::endl;
33 | break;
34 | case nvinfer1::DataType::kINT8:
35 | std::cout << "nvinfer1::DataType::kINT8" << std::endl;
36 | break;
37 | case nvinfer1::DataType::kINT32:
38 | std::cout << "nvinfer1::DataType::kINT32" << std::endl;
39 | break;
40 | case nvinfer1::DataType::kBOOL:
41 | std::cout << "nvinfer1::DataType::kBOOL" << std::endl;
42 | break;
43 | case nvinfer1::DataType::kUINT8:
44 | std::cout << "nvinfer1::DataType::kUINT8" << std::endl;
45 | break;
46 |
47 | default:
48 | break;
49 | }
50 | }
51 |
52 | void dims2str(nvinfer1::Dims dims)
53 | {
54 | std::string o_s("[");
55 | for (int i = 0; i < dims.nbDims; i++)
56 | {
57 | if (i > 0)
58 | o_s += ", ";
59 | o_s += std::to_string(dims.d[i]);
60 | }
61 | o_s += "]";
62 | std::cout << o_s << std::endl;
63 | }
64 | class Logger : public nvinfer1::ILogger
65 | {
66 | void log(Severity severity, const char *msg) noexcept override
67 | {
68 | // suppress info-level messages
69 | if (severity <= Severity::kWARNING)
70 | std::cout << msg << std::endl;
71 | }
72 | } logger;
73 |
74 | #endif
--------------------------------------------------------------------------------
/truck.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mingj2021/segment-anything-tensorrt/bd27bc114fc6f6881c905d56f802b3d7ab3f2eb9/truck.gif
--------------------------------------------------------------------------------
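
The export notebook above builds vit_l_embedding.engine and sam_onnx_example.engine with the TensorRT builder but only demonstrates inference through ONNX Runtime, while TensorRT inference is shown in C++ (main.cpp, sam.h). As a rough bridge between the two, the following is a minimal sketch, not part of this repository, of running the exported image-encoder engine from Python with the TensorRT 8.x binding-index API; the helper name run_embedding_engine and the use of pycuda for device buffers are assumptions, and the input must already be the (1, 3, 1024, 1024) float32 tensor produced by pre_processing() in the notebook.

import numpy as np
import pycuda.autoinit  # noqa: F401  (creates a CUDA context)
import pycuda.driver as cuda
import tensorrt as trt

def run_embedding_engine(images, engine_path="vit_l_embedding.engine"):
    """Hypothetical helper: run the exported SAM image-encoder engine once.

    images: (1, 3, 1024, 1024) float32 array, i.e. the output of pre_processing().
    """
    logger = trt.Logger(trt.Logger.INFO)
    with open(engine_path, "rb") as f, trt.Runtime(logger) as runtime:
        engine = runtime.deserialize_cuda_engine(f.read())
    context = engine.create_execution_context()

    # One device buffer per binding; the encoder engine has static shapes.
    bindings, host_output, output_dev = [], None, None
    for i in range(engine.num_bindings):
        shape = engine.get_binding_shape(i)
        dtype = trt.nptype(engine.get_binding_dtype(i))
        nbytes = trt.volume(shape) * np.dtype(dtype).itemsize
        dev = cuda.mem_alloc(nbytes)
        bindings.append(int(dev))
        if engine.binding_is_input(i):
            cuda.memcpy_htod(dev, np.ascontiguousarray(images, dtype=dtype))
        else:
            host_output = np.empty(tuple(shape), dtype=dtype)
            output_dev = dev

    context.execute_v2(bindings)               # synchronous inference
    cuda.memcpy_dtoh(host_output, output_dev)  # (1, 256, 64, 64) image embeddings
    return host_output

The same pattern would apply to the prompt-encoder/mask-decoder engine, except that its optimization profile allows a variable number of points, so the point_coords and point_labels shapes must be set on the execution context before execution, as sam.h does with setBindingDimensions.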