├── .gitignore
├── .vscode
│   ├── c_cpp_properties.json
│   └── settings.json
├── CMakeLists.txt
├── Dockerfile.dev
├── LICENSE
├── README.md
├── README.md.bak
├── README_zh_windows.md
├── baseModel.h
├── commons
│   ├── ThreadPool.h
│   ├── buffers.h
│   └── general.h
├── data
│   ├── 2023-09-14 20-42-44.png
│   └── truck.jpg
├── export.h
├── how_to_export_vim_h_model.ipynb
├── main.cpp
├── main_vim_h.cpp
├── sam.h
├── sam_utils.h
├── truck.gif
├── tutorials.ipynb
└── tutorials_vim_h.ipynb
/.gitignore:
--------------------------------------------------------------------------------
1 | # Prerequisites
2 | *.d
3 |
4 | # Compiled Object files
5 | *.slo
6 | *.lo
7 | *.o
8 | *.obj
9 |
10 | # Precompiled Headers
11 | *.gch
12 | *.pch
13 |
14 | # Compiled Dynamic libraries
15 | *.so
16 | *.dylib
17 | *.dll
18 |
19 | # Fortran module files
20 | *.mod
21 | *.smod
22 |
23 | # Compiled Static libraries
24 | *.lai
25 | *.la
26 | *.a
27 | *.lib
28 |
29 | # Executables
30 | *.exe
31 | *.out
32 | *.app
33 |
34 | build
35 | segment-anything
36 | data/*.onnx
--------------------------------------------------------------------------------
/.vscode/c_cpp_properties.json:
--------------------------------------------------------------------------------
1 | {
2 | "configurations": [
3 | {
4 | "name": "linux",
5 | "includePath": [
6 | "/usr/local/libtorch/include",
7 | "/usr/local/libtorch/include/torch/csrc/api/include",
8 | "/usr/include/opencv4",
9 | "/usr/include/opencv4/opencv2"
10 | ],
11 | "browse": {
12 | "limitSymbolsToIncludedHeaders": true,
13 | "databaseFilename": ""
14 | },
15 | "intelliSenseMode": "linux-gcc-x64",
16 | "compilerPath": "/usr/bin/gcc",
17 | "cStandard": "c17",
18 | "cppStandard": "gnu++14"
19 | }
20 | ],
21 | "version": 4
22 | }
--------------------------------------------------------------------------------
/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | {
2 | "cmake.configureSettings": {
3 | // "CMAKE_TOOLCHAIN_FILE": "C:/Users/77274/projects/vcpkg/scripts/buildsystems/vcpkg.cmake",
4 | // "yaml-cpp_DIR": "C:/Users/77274/projects/Dev/yaml-cpp/lib/cmake/yaml-cpp",
5 | // "Torch_DIR":"C:/Users/77274/projects/Dev/libtorch/share/cmake/Torch"
6 | "Torch_DIR":"/usr/local/libtorch/share/cmake/Torch"
7 | },
8 | "files.associations": {
9 | "iostream": "cpp",
10 | "algorithm": "cpp",
11 | "array": "cpp",
12 | "atomic": "cpp",
13 | "bit": "cpp",
14 | "bitset": "cpp",
15 | "cctype": "cpp",
16 | "cfenv": "cpp",
17 | "charconv": "cpp",
18 | "chrono": "cpp",
19 | "clocale": "cpp",
20 | "cmath": "cpp",
21 | "compare": "cpp",
22 | "complex": "cpp",
23 | "concepts": "cpp",
24 | "condition_variable": "cpp",
25 | "csignal": "cpp",
26 | "cstdarg": "cpp",
27 | "cstddef": "cpp",
28 | "cstdint": "cpp",
29 | "cstdio": "cpp",
30 | "cstdlib": "cpp",
31 | "cstring": "cpp",
32 | "ctime": "cpp",
33 | "cwchar": "cpp",
34 | "cwctype": "cpp",
35 | "deque": "cpp",
36 | "exception": "cpp",
37 | "format": "cpp",
38 | "forward_list": "cpp",
39 | "fstream": "cpp",
40 | "functional": "cpp",
41 | "initializer_list": "cpp",
42 | "iomanip": "cpp",
43 | "ios": "cpp",
44 | "iosfwd": "cpp",
45 | "istream": "cpp",
46 | "iterator": "cpp",
47 | "limits": "cpp",
48 | "list": "cpp",
49 | "locale": "cpp",
50 | "map": "cpp",
51 | "memory": "cpp",
52 | "mutex": "cpp",
53 | "new": "cpp",
54 | "numeric": "cpp",
55 | "optional": "cpp",
56 | "ostream": "cpp",
57 | "queue": "cpp",
58 | "random": "cpp",
59 | "ratio": "cpp",
60 | "set": "cpp",
61 | "source_location": "cpp",
62 | "span": "cpp",
63 | "sstream": "cpp",
64 | "stdexcept": "cpp",
65 | "stop_token": "cpp",
66 | "streambuf": "cpp",
67 | "string": "cpp",
68 | "strstream": "cpp",
69 | "system_error": "cpp",
70 | "thread": "cpp",
71 | "tuple": "cpp",
72 | "type_traits": "cpp",
73 | "typeindex": "cpp",
74 | "typeinfo": "cpp",
75 | "unordered_map": "cpp",
76 | "unordered_set": "cpp",
77 | "utility": "cpp",
78 | "valarray": "cpp",
79 | "variant": "cpp",
80 | "vector": "cpp",
81 | "xfacet": "cpp",
82 | "xhash": "cpp",
83 | "xiosbase": "cpp",
84 | "xlocale": "cpp",
85 | "xlocbuf": "cpp",
86 | "xlocinfo": "cpp",
87 | "xlocmes": "cpp",
88 | "xlocmon": "cpp",
89 | "xlocnum": "cpp",
90 | "xloctime": "cpp",
91 | "xmemory": "cpp",
92 | "xstddef": "cpp",
93 | "xstring": "cpp",
94 | "xtr1common": "cpp",
95 | "xtree": "cpp",
96 | "xutility": "cpp",
97 | "cinttypes": "cpp",
98 | "filesystem": "cpp",
99 | "shared_mutex": "cpp",
100 | "stack": "cpp",
101 | "__nullptr": "cpp",
102 | "__config": "cpp",
103 | "*.tcc": "cpp",
104 | "codecvt": "cpp",
105 | "memory_resource": "cpp",
106 | "string_view": "cpp",
107 | "future": "cpp"
108 | },
109 | "cmake.configureOnOpen": true,
110 | "C_Cpp.errorSquiggles": "disabled"
111 | }
--------------------------------------------------------------------------------
/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | cmake_minimum_required(VERSION 3.14.0 FATAL_ERROR)
2 | project(Detections)
3 |
4 | # add_compile_options(-fno-elide-constructors)
5 | # set(CMAKE_CXX_FLAGS "-fno-elide-constructors ${CMAKE_CXX_FLAGS}")
6 | # set(ROOT_DIRS "C:/Users/77274/projects")
7 | # set(ROOT_DIRS "D:/projects")
8 | # set( CMAKE_CXX_COMPILER /usr/bin/g++ )
9 | # set(CMAKE_CXX_FLAGS "-fno-elide-constructors ${CMAKE_CXX_FLAGS}")
10 | set(Torch_DIR "/usr/local/libtorch/share/cmake/Torch")
11 | find_package(Torch REQUIRED)
12 | # include_directories(${TORCH_INCLUDE_DIRS}) # Not needed for CMake >= 2.8.11
13 |
14 |
15 | # set(TorchVision_DIR "${ROOT_DIRS}/3rdpty/torchvision/share/cmake/TorchVision")
16 | # find_package(TorchVision REQUIRED)
17 | # include_directories(${TorchVision_INCLUDE_DIR}) # Not needed for CMake >= 2.8.11
18 | # link_directories("../3rdpty/torchvision/lib")
19 |
20 | find_package(OpenCV REQUIRED)
21 |
22 | set(SAMPLE_SOURCES main.cpp)
23 | set(TARGET_NAME sampleSAM)
24 |
25 | set(SAMPLE_DEP_LIBS
26 | nvinfer
27 | nvonnxparser
28 | )
29 |
30 | # commons
31 | include_directories("./commons")
32 |
33 | add_executable(${TARGET_NAME} ${SAMPLE_SOURCES})
34 | target_link_libraries(${TARGET_NAME} ${TORCH_LIBRARIES})
35 | target_link_libraries(${TARGET_NAME} ${OpenCV_LIBS})
36 | # target_link_libraries(${TARGET_NAME} TorchVision::TorchVision)
37 | target_link_libraries(${TARGET_NAME} ${SAMPLE_DEP_LIBS})
38 | set_property(TARGET ${TARGET_NAME} PROPERTY CXX_STANDARD 17)
39 |
40 |
41 | set(SAMPLE_SOURCES main_vim_h.cpp)
42 | set(TARGET_NAME sampleSAM2)
43 | add_executable(${TARGET_NAME} ${SAMPLE_SOURCES})
44 | target_link_libraries(${TARGET_NAME} ${TORCH_LIBRARIES})
45 | target_link_libraries(${TARGET_NAME} ${OpenCV_LIBS})
46 | # target_link_libraries(${TARGET_NAME} TorchVision::TorchVision)
47 | target_link_libraries(${TARGET_NAME} ${SAMPLE_DEP_LIBS})
48 | set_property(TARGET ${TARGET_NAME} PROPERTY CXX_STANDARD 17)
--------------------------------------------------------------------------------
/Dockerfile.dev:
--------------------------------------------------------------------------------
1 | #ARG os=ubuntu2204 tag=8.5.0-cuda-11.7
2 | #FROM nvidia/cuda:11.7.0-cudnn8-devel-ubuntu22.04
3 | #RUN dpkg -i nv-tensorrt-local-repo-${os}-${tag}_1.0-1_amd64.deb \
4 | # && cp /var/nv-tensorrt-local-repo-${os}-${tag}/*-keyring.gpg /usr/share/keyrings/ \
5 | # && apt-get update
6 |
7 | FROM nvcr.io/nvidia/tensorrt:22.10-py3
8 |
9 | # install libtorch
10 | RUN cd ~/ && wget https://download.pytorch.org/libtorch/cu118/libtorch-cxx11-abi-shared-with-deps-2.0.0%2Bcu118.zip \
11 | && unzip libtorch-cxx11-abi-shared-with-deps-2.0.0+cu118.zip -d /usr/local/ \
12 | && rm -f libtorch-cxx11-abi-shared-with-deps-2.0.0+cu118.zip
13 |
14 | # install opencv
15 | RUN apt update && apt install libopencv-dev -y
16 |
17 | # install python essential dependencies
18 | RUN pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
19 | RUN pip3 install numpy pillow matplotlib pycocotools opencv-python onnx onnxruntime -i https://pypi.tuna.tsinghua.edu.cn/simple
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 mingj2021
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Overview
2 | [中文-windows](README_zh_windows.md)
3 |
4 |
5 |
6 |
7 | Download the [translated ONNX models](https://drive.google.com/drive/folders/1xZ7XpyKGx0Fg-t81SLND56vC2TEQ4otN?usp=sharing).
8 | # tutorials
9 | ```
10 | The vim_h TensorRT engine was tested on a GeForce RTX 3060 Ti GPU and passes its tests.
11 | tutorials.ipynb
12 | tutorials_vim_h.ipynb
13 |
14 | ```
15 |
16 | # how to export vim_h model
17 | how_to_export_vim_h_model.ipynb
18 | ```
19 | # Divide the embedding model into 2 parts, named part1 and part2
20 | # The decoder model remains unchanged
21 | # export_embedding_model_part_1()
22 | # export_embedding_model_part_2()
23 | # export_sam_h_model()
24 | # onnx_model_example2()
25 | ```
26 | # FP32 or FP16
27 | ```
28 | If your GPU has enough memory, you do not have to set the FP16 flag in export.h; without it the engine is built in FP32.
29 | ```
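
A minimal sketch of that precision switch, using TensorRT's public C++ builder API; the helper function and its useFp16 parameter are illustrative additions, not code taken from export.h:
```
#include <NvInfer.h>

// Illustrative sketch: enable FP16 only when requested and supported.
// If the flag is never set, TensorRT builds the engine in FP32, which
// needs more GPU memory but avoids any half-precision accuracy loss.
void configurePrecision(nvinfer1::IBuilder &builder,
                        nvinfer1::IBuilderConfig &config,
                        bool useFp16)
{
    if (useFp16 && builder.platformHasFastFp16())
        config.setFlag(nvinfer1::BuilderFlag::kFP16);
}
```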
--------------------------------------------------------------------------------
/README.md.bak:
--------------------------------------------------------------------------------
1 | # Overview
2 | [中文-windows](README_zh_CN.md)
3 |
4 |
5 |
6 |
7 | This repository helps you quickly deploy segment-anything to real-world applications such as auto annotation. The original model is split into two sub-models: one computes the image embedding, the other runs the prompt encoder and mask decoder.
8 | # Table of Contents
9 | - [Overview](#Overview)
10 | - [Table of Contents](#Table-of-Contents)
11 | - [Getting Started](#getting-started)
12 | - [Quick Start: Windows](#quick-start-windows)
13 | - [Quick Start: Ubuntu](#quick-start-ubuntu)
14 | - [Onnx Export](#onnx-export)
15 | - [Image Encoder](#export-embedding-onnx-model)
16 | - [Prompt Encoder and mask decoder](#export-prompt-encoder-mask-decoder-onnx-model)
17 | - [Test Exported-onnx models](#test-exported-onnx-models)
18 | - [Engine Export](#engine-export)
19 | - [Image Encoder](#convert-image-encoder-onnx-to-engine-model)
20 | - [Prompt Encoder and mask decoder](#convert-prompt-encoder-and-mask-decoder-onnx-to-engine-model)
21 | - [Quantification]()
22 | - [TensorRT Inferring]()
23 | - [Preprocess]()
24 | - [Postprocess]()
25 | - [build]()
26 | - [Examples]()
27 |
28 | # Getting Started
29 | Prerequisites:
30 | - [OpenCV](https://github.com/opencv/opencv)
31 | - [Libtorch](https://pytorch.org/)
32 | - [Torchvision](https://github.com/pytorch/vision)
33 | - [Tensorrt](https://developer.nvidia.com/tensorrt)
34 | ## Quick Start: Windows
35 | ```
36 | # create conda virtual env
37 | conda create -n segment-anything python=3.8
38 | # activate this environment
39 | conda activate segment-anything
40 | pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu117
41 | pip install opencv-python pycocotools matplotlib onnxruntime onnx
42 | # download TensorRT-8.5.1.7.Windows10.x86_64.cuda-x.x.zip && unzip
43 | pip install ./tensorrt-8.5.1.7-cp38-none-win_amd64.whl
44 | # tensorrt tool: PyTorch-Quantization
45 | pip install pytorch-quantization --extra-index-url https://pypi.ngc.nvidia.com
46 | ```
47 |
48 | ## Quick Start: Ubuntu
49 | ```
50 | # create virtual env
51 | git clone https://github.com/mingj2021/segment-anything-tensorrt.git
52 | cd segment-anything-tensorrt
53 | docker build -t dev:ml -f ./Dockerfile.dev .
54 | git clone https://github.com/facebookresearch/segment-anything.git
55 | cd segment-anything
56 | mkdir weights && cd weights
57 | wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth
58 | cd ../
59 | sudo mv ../demo.py .
60 | # run container
61 | docker run -it --rm --gpus all -v $(yourdirs)/segment-anything-tensorrt:/workspace/segment-anything-tensorrt dev:ml
62 | ```
63 |
64 | # Onnx Export
65 | Definitions and imported dependencies:
66 | ```
67 | import torch
68 | import torch.nn as nn
69 | from torch.nn import functional as F
70 | from segment_anything.modeling import Sam
71 | import numpy as np
72 | from torchvision.transforms.functional import resize, to_pil_image
73 | from typing import Tuple
74 | from segment_anything import sam_model_registry, SamPredictor
75 | import cv2
76 | import matplotlib.pyplot as plt
77 | import warnings
78 | import os
79 | os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
80 | import onnxruntime
81 |
82 | def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
83 | """
84 | Compute the output size given input size and target long side length.
85 | """
86 | scale = long_side_length * 1.0 / max(oldh, oldw)
87 | newh, neww = oldh * scale, oldw * scale
88 | neww = int(neww + 0.5)
89 | newh = int(newh + 0.5)
90 | return (newh, neww)
91 |
92 | # @torch.no_grad()
93 | def pre_processing(image: np.ndarray, target_length: int, device,pixel_mean,pixel_std,img_size):
94 | target_size = get_preprocess_shape(image.shape[0], image.shape[1], target_length)
95 | input_image = np.array(resize(to_pil_image(image), target_size))
96 | input_image_torch = torch.as_tensor(input_image, device=device)
97 | input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :]
98 |
99 | # Normalize colors
100 | input_image_torch = (input_image_torch - pixel_mean) / pixel_std
101 |
102 | # Pad
103 | h, w = input_image_torch.shape[-2:]
104 | padh = img_size - h
105 | padw = img_size - w
106 | input_image_torch = F.pad(input_image_torch, (0, padw, 0, padh))
107 | return input_image_torch
108 |
109 | ```
110 | ## export embedding-onnx model
111 | ```
112 | def export_embedding_model():
113 | sam_checkpoint = "weights/sam_vit_l_0b3195.pth"
114 | model_type = "vit_l"
115 |
116 | device = "cpu"
117 |
118 | sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
119 | sam.to(device=device)
120 |
121 | # image_encoder = EmbeddingOnnxModel(sam)
122 | image = cv2.imread('notebooks/images/truck.jpg')
123 | target_length = sam.image_encoder.img_size
124 | pixel_mean = sam.pixel_mean
125 | pixel_std = sam.pixel_std
126 | img_size = sam.image_encoder.img_size
127 | inputs = pre_processing(image, target_length, device,pixel_mean,pixel_std,img_size)
128 | onnx_model_path = model_type+"_"+"embedding.onnx"
129 | dummy_inputs = {
130 | "images": inputs
131 | }
132 | output_names = ["image_embeddings"]
133 | image_embeddings = sam.image_encoder(inputs).cpu().numpy()
134 | print('image_embeddings', image_embeddings.shape)
135 | with warnings.catch_warnings():
136 | warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
137 | warnings.filterwarnings("ignore", category=UserWarning)
138 | with open(onnx_model_path, "wb") as f:
139 | torch.onnx.export(
140 | sam.image_encoder,
141 | tuple(dummy_inputs.values()),
142 | f,
143 | export_params=True,
144 | verbose=False,
145 | opset_version=13,
146 | do_constant_folding=True,
147 | input_names=list(dummy_inputs.keys()),
148 | output_names=output_names,
149 | # dynamic_axes=dynamic_axes,
150 | )
151 | with torch.no_grad():
152 | export_embedding_model()
153 | ```
154 | ## export prompt-encoder-mask-decoder-onnx model
155 | Change the "forward" function in "segment_anything/utils/onnx.py" as follows:
156 | ```
157 | def forward(
158 | self,
159 | image_embeddings: torch.Tensor,
160 | point_coords: torch.Tensor,
161 | point_labels: torch.Tensor,
162 | mask_input: torch.Tensor,
163 | has_mask_input: torch.Tensor
164 | # orig_im_size: torch.Tensor,
165 | ):
166 | sparse_embedding = self._embed_points(point_coords, point_labels)
167 | dense_embedding = self._embed_masks(mask_input, has_mask_input)
168 |
169 | masks, scores = self.model.mask_decoder.predict_masks(
170 | image_embeddings=image_embeddings,
171 | image_pe=self.model.prompt_encoder.get_dense_pe(),
172 | sparse_prompt_embeddings=sparse_embedding,
173 | dense_prompt_embeddings=dense_embedding,
174 | )
175 |
176 | if self.use_stability_score:
177 | scores = calculate_stability_score(
178 | masks, self.model.mask_threshold, self.stability_score_offset
179 | )
180 |
181 | if self.return_single_mask:
182 | masks, scores = self.select_masks(masks, scores, point_coords.shape[1])
183 |
184 | return masks, scores
185 | # upscaled_masks = self.mask_postprocessing(masks, orig_im_size)
186 |
187 | # if self.return_extra_metrics:
188 | # stability_scores = calculate_stability_score(
189 | # upscaled_masks, self.model.mask_threshold, self.stability_score_offset
190 | # )
191 | # areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1)
192 | # return upscaled_masks, scores, stability_scores, areas, masks
193 |
194 | # return upscaled_masks, scores, masks
195 | ```
196 |
197 | ```
198 | def export_sam_model():
199 | from segment_anything.utils.onnx import SamOnnxModel
200 | checkpoint = "weights/sam_vit_l_0b3195.pth"
201 | model_type = "vit_l"
202 | sam = sam_model_registry[model_type](checkpoint=checkpoint)
203 | onnx_model_path = "sam_onnx_example.onnx"
204 |
205 | onnx_model = SamOnnxModel(sam, return_single_mask=True)
206 |
207 | dynamic_axes = {
208 | "point_coords": {1: "num_points"},
209 | "point_labels": {1: "num_points"},
210 | }
211 |
212 | embed_dim = sam.prompt_encoder.embed_dim
213 | embed_size = sam.prompt_encoder.image_embedding_size
214 | mask_input_size = [4 * x for x in embed_size]
215 | dummy_inputs = {
216 | "image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float),
217 | "point_coords": torch.randint(low=0, high=1024, size=(1, 5, 2), dtype=torch.float),
218 | "point_labels": torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float),
219 | "mask_input": torch.randn(1, 1, *mask_input_size, dtype=torch.float),
220 | "has_mask_input": torch.tensor([1], dtype=torch.float),
221 | # "orig_im_size": torch.tensor([1500, 2250], dtype=torch.float),
222 | }
223 | # output_names = ["masks", "iou_predictions", "low_res_masks"]
224 | output_names = ["masks", "scores"]
225 |
226 | with warnings.catch_warnings():
227 | warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
228 | warnings.filterwarnings("ignore", category=UserWarning)
229 | with open(onnx_model_path, "wb") as f:
230 | torch.onnx.export(
231 | onnx_model,
232 | tuple(dummy_inputs.values()),
233 | f,
234 | export_params=True,
235 | verbose=False,
236 | opset_version=13,
237 | do_constant_folding=True,
238 | input_names=list(dummy_inputs.keys()),
239 | output_names=output_names,
240 | dynamic_axes=dynamic_axes,
241 | )
242 |
243 | with torch.no_grad():
244 | export_sam_model()
245 | ```
246 | ## test exported-onnx models
247 | ```
248 | def show_mask(mask, ax):
249 | color = np.array([30/255, 144/255, 255/255, 0.6])
250 | h, w = mask.shape[-2:]
251 | mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
252 | ax.imshow(mask_image)
253 |
254 | def show_points(coords, labels, ax, marker_size=375):
255 | pos_points = coords[labels==1]
256 | neg_points = coords[labels==0]
257 | ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
258 | ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
259 |
260 | def show_box(box, ax):
261 | x0, y0 = box[0], box[1]
262 | w, h = box[2] - box[0], box[3] - box[1]
263 | ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))
264 |
265 | def onnx_model_example():
266 | import os
267 | ort_session_embedding = onnxruntime.InferenceSession('vit_l_embedding.onnx',providers=['CPUExecutionProvider'])
268 | ort_session_sam = onnxruntime.InferenceSession('sam_onnx_example.onnx',providers=['CPUExecutionProvider'])
269 |
270 | image = cv2.imread('notebooks/images/truck.jpg')
271 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
272 | image2 = image.copy()
273 | prompt_embed_dim = 256
274 | image_size = 1024
275 | vit_patch_size = 16
276 | image_embedding_size = image_size // vit_patch_size
277 | target_length = image_size
278 | pixel_mean=[123.675, 116.28, 103.53]
279 | pixel_std=[58.395, 57.12, 57.375]
280 | pixel_mean = torch.Tensor(pixel_mean).view(-1, 1, 1)
281 | pixel_std = torch.Tensor(pixel_std).view(-1, 1, 1)
282 | device = "cpu"
283 | inputs = pre_processing(image, target_length, device,pixel_mean,pixel_std,image_size)
284 | ort_inputs = {
285 | "images": inputs.cpu().numpy()
286 | }
287 | image_embeddings = ort_session_embedding.run(None, ort_inputs)[0]
288 |
289 | input_point = np.array([[500, 375]])
290 | input_label = np.array([1])
291 |
292 | onnx_coord = np.concatenate([input_point, np.array([[0.0, 0.0]])], axis=0)[None, :, :]
293 | onnx_label = np.concatenate([input_label, np.array([-1])], axis=0)[None, :].astype(np.float32)
294 | from segment_anything.utils.transforms import ResizeLongestSide
295 | transf = ResizeLongestSide(image_size)
296 | onnx_coord = transf.apply_coords(onnx_coord, image.shape[:2]).astype(np.float32)
297 | onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
298 | onnx_has_mask_input = np.zeros(1, dtype=np.float32)
299 |
300 | ort_inputs = {
301 | "image_embeddings": image_embeddings,
302 | "point_coords": onnx_coord,
303 | "point_labels": onnx_label,
304 | "mask_input": onnx_mask_input,
305 | "has_mask_input": onnx_has_mask_input,
306 | # "orig_im_size": np.array(image.shape[:2], dtype=np.float32)
307 | }
308 |
309 | masks, _ = ort_session_sam.run(None, ort_inputs)
310 |
311 | from segment_anything.utils.onnx import SamOnnxModel
312 | checkpoint = "weights/sam_vit_l_0b3195.pth"
313 | model_type = "vit_l"
314 | sam = sam_model_registry[model_type](checkpoint=checkpoint)
315 | onnx_model_path = "sam_onnx_example.onnx"
316 |
317 | onnx_model = SamOnnxModel(sam, return_single_mask=True)
318 | masks = onnx_model.mask_postprocessing(torch.as_tensor(masks), torch.as_tensor(image.shape[:2]))
319 | masks = masks > 0.0
320 | plt.figure(figsize=(10, 10))
321 | plt.imshow(image)
322 | show_mask(masks, plt.gca())
323 | # show_box(input_box, plt.gca())
324 | show_points(input_point, input_label, plt.gca())
325 | plt.axis('off')
326 | plt.savefig('demo.png')
327 |
328 | with torch.no_grad():
329 | onnx_model_example()
330 | ```
331 | # Engine Export
332 | ## convert image-encoder-onnx to engine model
333 | ```
334 | def export_engine_image_encoder(f='vit_l_embedding.onnx'):
335 | import tensorrt as trt
336 | from pathlib import Path
337 | file = Path(f)
338 | f = file.with_suffix('.engine') # TensorRT engine file
339 | onnx = file.with_suffix('.onnx')
340 | logger = trt.Logger(trt.Logger.INFO)
341 | builder = trt.Builder(logger)
342 | config = builder.create_builder_config()
343 | workspace = 6
344 | print("workspace: ", workspace)
345 | config.max_workspace_size = workspace * 1 << 30
346 | flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
347 | network = builder.create_network(flag)
348 | parser = trt.OnnxParser(network, logger)
349 | if not parser.parse_from_file(str(onnx)):
350 | raise RuntimeError(f'failed to load ONNX file: {onnx}')
351 |
352 | inputs = [network.get_input(i) for i in range(network.num_inputs)]
353 | outputs = [network.get_output(i) for i in range(network.num_outputs)]
354 | for inp in inputs:
355 | print(f'input "{inp.name}" with shape{inp.shape} {inp.dtype}')
356 | for out in outputs:
357 | print(f'output "{out.name}" with shape{out.shape} {out.dtype}')
358 |
359 | half = True
360 | print(f'building FP{16 if builder.platform_has_fast_fp16 and half else 32} engine as {f}')
361 | if builder.platform_has_fast_fp16 and half:
362 | config.set_flag(trt.BuilderFlag.FP16)
363 | with builder.build_engine(network, config) as engine, open(f, 'wb') as t:
364 | t.write(engine.serialize())
365 | with torch.no_grad():
366 | export_engine_image_encoder('vit_l_embedding.onnx')
367 | ```
368 |
369 | ## convert prompt-encoder-and-mask-decoder-onnx to engine model
370 | ```
371 | def export_engine_prompt_encoder_and_mask_decoder(f='sam_onnx_example.onnx'):
372 | import tensorrt as trt
373 | from pathlib import Path
374 | file = Path(f)
375 | f = file.with_suffix('.engine') # TensorRT engine file
376 | onnx = file.with_suffix('.onnx')
377 | logger = trt.Logger(trt.Logger.INFO)
378 | builder = trt.Builder(logger)
379 | config = builder.create_builder_config()
380 | workspace = 10
381 | print("workspace: ", workspace)
382 | config.max_workspace_size = workspace * 1 << 30
383 | flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
384 | network = builder.create_network(flag)
385 | parser = trt.OnnxParser(network, logger)
386 | if not parser.parse_from_file(str(onnx)):
387 | raise RuntimeError(f'failed to load ONNX file: {onnx}')
388 |
389 | inputs = [network.get_input(i) for i in range(network.num_inputs)]
390 | outputs = [network.get_output(i) for i in range(network.num_outputs)]
391 | for inp in inputs:
392 | print(f'input "{inp.name}" with shape{inp.shape} {inp.dtype}')
393 | for out in outputs:
394 | print(f'output "{out.name}" with shape{out.shape} {out.dtype}')
395 |
396 | profile = builder.create_optimization_profile()
397 | profile.set_shape('image_embeddings', (1, 256, 64, 64), (1, 256, 64, 64), (1, 256, 64, 64))
398 | profile.set_shape('point_coords', (1, 2,2), (1, 5,2), (1,10,2))
399 | profile.set_shape('point_labels', (1, 2), (1, 5), (1,10))
400 | profile.set_shape('mask_input', (1, 1, 256, 256), (1, 1, 256, 256), (1, 1, 256, 256))
401 | profile.set_shape('has_mask_input', (1,), (1, ), (1, ))
402 | # # profile.set_shape_input('orig_im_size', (416,416), (1024,1024), (1500, 2250))
403 | # profile.set_shape_input('orig_im_size', (2,), (2,), (2, ))
404 | config.add_optimization_profile(profile)
405 |
406 | half = True
407 | print(f'building FP{16 if builder.platform_has_fast_fp16 and half else 32} engine as {f}')
408 | if builder.platform_has_fast_fp16 and half:
409 | config.set_flag(trt.BuilderFlag.FP16)
410 | with builder.build_engine(network, config) as engine, open(f, 'wb') as t:
411 | t.write(engine.serialize())
412 | with torch.no_grad():
413 | export_engine_prompt_encoder_and_mask_decoder('sam_onnx_example.onnx')
414 | ```
415 | # TensorRT Inferring
416 | - export the 2 engine models as described above
417 | - open main.cpp and choose the action (show a result window, or generate a file containing the embedding features) by defining SAMPROMPTENCODERANDMASKDECODER or EMBEDDING; see the sketch below
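
A minimal sketch of what that compile-time switch could look like; the two macro names come from this README, while the branch bodies are illustrative placeholders rather than the actual contents of main.cpp:
```
// Illustrative only: pick the action at compile time, mirroring the
// SAMPROMPTENCODERANDMASKDECODER / EMBEDDING switch described above.
#define EMBEDDING
// #define SAMPROMPTENCODERANDMASKDECODER

#include <iostream>

int main()
{
#if defined(EMBEDDING)
    // hypothetical branch: run the image-encoder engine and write the
    // embedding features to a file
    std::cout << "generate embeddings file" << std::endl;
#elif defined(SAMPROMPTENCODERANDMASKDECODER)
    // hypothetical branch: run the prompt-encoder + mask-decoder engine
    // and display the segmentation result in a window
    std::cout << "show result window" << std::endl;
#endif
    return 0;
}
```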
418 |
419 | ## Preprocess
420 | Resize the input image so its longer side matches the model input size, normalize it with SAM's pixel mean/std, and pad it to a square; a minimal sketch follows.
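
A minimal sketch of this step, assuming OpenCV and libtorch as configured in CMakeLists.txt; the function name, default size, and exact tensor layout are illustrative and not copied from main.cpp:
```
#include <algorithm>
#include <opencv2/opencv.hpp>
#include <torch/torch.h>

// Illustrative preprocessing: resize the long side to imgSize, normalize
// with SAM's pixel mean/std, and pad to an imgSize x imgSize NCHW tensor.
torch::Tensor preprocess(const cv::Mat &bgr, int imgSize = 1024)
{
    cv::Mat rgb;
    cv::cvtColor(bgr, rgb, cv::COLOR_BGR2RGB);

    const float scale = imgSize / static_cast<float>(std::max(rgb.cols, rgb.rows));
    const int newW = static_cast<int>(rgb.cols * scale + 0.5f);
    const int newH = static_cast<int>(rgb.rows * scale + 0.5f);
    cv::resize(rgb, rgb, cv::Size(newW, newH));

    torch::Tensor t = torch::from_blob(rgb.data, {newH, newW, 3}, torch::kUInt8).clone();
    t = t.permute({2, 0, 1}).unsqueeze(0).to(torch::kFloat); // 1x3xHxW

    auto mean = torch::tensor({123.675f, 116.28f, 103.53f}).view({1, 3, 1, 1});
    auto std_ = torch::tensor({58.395f, 57.12f, 57.375f}).view({1, 3, 1, 1});
    t = (t - mean) / std_;

    // pad right/bottom so the tensor becomes 1 x 3 x imgSize x imgSize
    namespace F = torch::nn::functional;
    t = F::pad(t, F::PadFuncOptions({0, imgSize - newW, 0, imgSize - newH}));
    return t;
}
```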
423 | ## Postprocess
424 | Upscale the low-resolution mask returned by the decoder to the original image size, threshold it, and overlay it on the input image; a minimal sketch follows.
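
A minimal sketch of this step with OpenCV and libtorch; the 256x256 logit shape is assumed from the decoder output described above, and the overlay color and blend weights are illustrative:
```
#include <opencv2/opencv.hpp>
#include <torch/torch.h>

// Illustrative postprocessing: upscale the decoder's low-res mask logits to
// the original image size, threshold at 0, and blend them over the image.
cv::Mat overlayMask(const cv::Mat &bgr, const torch::Tensor &lowResMask)
{
    namespace F = torch::nn::functional;
    // lowResMask: 1x1x256x256 float logits (assumed)
    torch::Tensor m = F::interpolate(
        lowResMask,
        F::InterpolateFuncOptions()
            .size(std::vector<int64_t>{bgr.rows, bgr.cols})
            .mode(torch::kBilinear)
            .align_corners(false));
    m = (m > 0.0f).squeeze().to(torch::kUInt8).mul(255).contiguous(); // HxW, 0/255

    cv::Mat mask(bgr.rows, bgr.cols, CV_8UC1, m.data_ptr<uint8_t>());
    cv::Mat highlight(bgr.size(), CV_8UC3, cv::Scalar(255, 144, 30)); // BGR tint
    cv::Mat blended, out = bgr.clone();
    cv::addWeighted(bgr, 0.4, highlight, 0.6, 0.0, blended);
    blended.copyTo(out, mask); // copy the tinted pixels only where the mask is set
    return out;
}
```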
427 | ## build
428 | ```
429 | mkdir build && cd build
430 | # modify main.cpp
431 | cmake ..
432 | make -j10
433 | ```
--------------------------------------------------------------------------------
/README_zh_windows.md:
--------------------------------------------------------------------------------
1 | # Overview
2 | [english](README.md)
3 |
4 | # cudatoolkit installation
5 | Run cuda_11.7.0_516.01_windows.exe and accept the default options.
6 |
7 | # cudnn installation
8 | Unzip cudnn-windows-x86_64-8.6.0.163_cuda11-archive.zip and copy the extracted folders [lib, bin, include] into the corresponding CUDA install path [C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.7].
9 |
10 | # libtorch installation
11 | Download [https://download.pytorch.org/libtorch/cu117/libtorch-win-shared-with-deps-2.0.1%2Bcu117.zip] and unzip it.
12 |
13 | # torchvision installation
14 | ```
15 | git clone https://github.com/pytorch/vision.git
16 | mkdir build
17 | cd build
18 | cmake .. -DWITH_CUDA=on
19 | cmake --build . --config release --target install
20 | ```
21 |
22 | # tensorrt installation
23 | Download [TensorRT-8.5.1.7.Windows10.x86_64.cuda-x.x.zip] and unzip it.
24 |
25 | # opencv installation
26 | Download [https://github.com/opencv/opencv/releases/download/4.7.0/opencv-4.7.0-windows.exe] and install with the default options.
27 |
28 |
--------------------------------------------------------------------------------
/baseModel.h:
--------------------------------------------------------------------------------
1 | #ifndef BASEMODEL_H
2 | #define BASEMODEL_H
3 |
4 | #include "sam_utils.h"
5 |
6 | class BaseModel
7 | {
8 | public:
9 | cudaStream_t stream;
10 | std::shared_ptr<nvinfer1::ICudaEngine> mEngine;
11 | std::unique_ptr<nvinfer1::IExecutionContext> context;
12 |
13 | std::vector<void *> mDeviceBindings;
14 | std::map<std::string, std::unique_ptr<algorithms::DeviceBuffer>> mInOut;
15 | std::vector<std::string> mInputsName, mOutputsName;
16 |
17 | public:
18 | BaseModel(std::string modelFile);
19 | ~BaseModel();
20 | void read_engine_file(std::string modelFile);
21 | // std::vector<std::string> get_inputs_name();
22 | // std::vector<std::string> get_outputs_name();
23 | // std::vector<void *> get_device_buffer();
24 | };
25 |
26 | BaseModel::BaseModel(std::string modelFile)
27 | {
28 | read_engine_file(modelFile);
29 | context = std::unique_ptr<nvinfer1::IExecutionContext>(mEngine->createExecutionContext());
30 | if (!context)
31 | {
32 | std::cerr << "create context error" << std::endl;
33 | }
34 |
35 | CHECK(cudaStreamCreate(&stream));
36 |
37 | for (int i = 0; i < mEngine->getNbBindings(); i++)
38 | {
39 | auto dims = mEngine->getBindingDimensions(i);
40 | auto tensor_name = mEngine->getBindingName(i);
41 | bool isInput = mEngine->bindingIsInput(i);
42 | if (isInput)
43 | mInputsName.emplace_back(tensor_name);
44 | else
45 | mOutputsName.emplace_back(tensor_name);
46 | std::cout << "tensor_name: " << tensor_name << std::endl;
47 | dims2str(dims);
48 | nvinfer1::DataType type = mEngine->getBindingDataType(i);
49 | index2srt(type);
50 | auto vol = std::accumulate(dims.d, dims.d + dims.nbDims, int64_t{1}, std::multiplies{});
51 | std::unique_ptr<algorithms::DeviceBuffer> device_buffer{new algorithms::DeviceBuffer(vol, type)};
52 | mDeviceBindings.emplace_back(device_buffer->data());
53 | mInOut[tensor_name] = std::move(device_buffer);
54 | }
55 | }
56 |
57 | BaseModel::~BaseModel()
58 | {
59 | CHECK(cudaStreamDestroy(stream));
60 | }
61 |
62 | void BaseModel::read_engine_file(std::string modelFile)
63 | {
64 | std::ifstream engineFile(modelFile.c_str(), std::ifstream::binary);
65 | assert(engineFile);
66 |
67 | int fsize;
68 | engineFile.seekg(0, engineFile.end);
69 | fsize = engineFile.tellg();
70 | engineFile.seekg(0, engineFile.beg);
71 | std::vector<char> engineData(fsize);
72 | engineFile.read(engineData.data(), fsize);
73 |
74 | if (engineFile)
75 | std::cout << "all characters read successfully." << std::endl;
76 | else
77 | std::cout << "error: only " << engineFile.gcount() << " could be read" << std::endl;
78 | engineFile.close();
79 |
80 | std::unique_ptr<nvinfer1::IRuntime> runtime(nvinfer1::createInferRuntime(logger));
81 | mEngine = std::shared_ptr<nvinfer1::ICudaEngine>(runtime->deserializeCudaEngine(engineData.data(), fsize, nullptr));
82 | }
83 |
84 | #endif
--------------------------------------------------------------------------------
/commons/ThreadPool.h:
--------------------------------------------------------------------------------
1 | /*
2 | Copyright (c) 2012 Jakob Progsch, Václav Zeman
3 |
4 | This software is provided 'as-is', without any express or implied
5 | warranty. In no event will the authors be held liable for any damages
6 | arising from the use of this software.
7 |
8 | Permission is granted to anyone to use this software for any purpose,
9 | including commercial applications, and to alter it and redistribute it
10 | freely, subject to the following restrictions:
11 |
12 | 1. The origin of this software must not be misrepresented; you must not
13 | claim that you wrote the original software. If you use this software
14 | in a product, an acknowledgment in the product documentation would be
15 | appreciated but is not required.
16 |
17 | 2. Altered source versions must be plainly marked as such, and must not be
18 | misrepresented as being the original software.
19 |
20 | 3. This notice may not be removed or altered from any source
21 | distribution.
22 | */
23 |
24 | #ifndef THREAD_POOL_H
25 | #define THREAD_POOL_H
26 |
27 | #include <vector>
28 | #include <queue>
29 | #include <memory>
30 | #include <thread>
31 | #include <mutex>
32 | #include <condition_variable>
33 | #include <future>
34 | #include <functional>
35 | #include <stdexcept>
36 |
37 | class ThreadPool {
38 | public:
39 | ThreadPool(size_t);
40 | template<class F, class... Args>
41 | auto enqueue(F&& f, Args&&... args)
42 | -> std::future<typename std::result_of<F(Args...)>::type>;
43 | ~ThreadPool();
44 | private:
45 | // need to keep track of threads so we can join them
46 | std::vector< std::thread > workers;
47 | // the task queue
48 | std::queue< std::function<void()> > tasks;
49 |
50 | // synchronization
51 | std::mutex queue_mutex;
52 | std::condition_variable condition;
53 | bool stop;
54 | };
55 |
56 | // the constructor just launches some amount of workers
57 | inline ThreadPool::ThreadPool(size_t threads)
58 | : stop(false)
59 | {
60 | for(size_t i = 0;i<threads;++i)
61 | workers.emplace_back(
62 | [this]
63 | {
64 | for(;;)
65 | {
66 | std::function<void()> task;
67 |
68 | {
69 | std::unique_lock<std::mutex> lock(this->queue_mutex);
70 | this->condition.wait(lock,
71 | [this]{ return this->stop || !this->tasks.empty(); });
72 | if(this->stop && this->tasks.empty())
73 | return;
74 | task = std::move(this->tasks.front());
75 | this->tasks.pop();
76 | }
77 |
78 | task();
79 | }
80 | }
81 | );
82 | }
83 |
84 | // add new work item to the pool
85 | template<class F, class... Args>
86 | auto ThreadPool::enqueue(F&& f, Args&&... args)
87 | -> std::future<typename std::result_of<F(Args...)>::type>
88 | {
89 | using return_type = typename std::result_of<F(Args...)>::type;
90 |
91 | auto task = std::make_shared< std::packaged_task<return_type()> >(
92 | std::bind(std::forward<F>(f), std::forward<Args>(args)...)
93 | );
94 |
95 | std::future<return_type> res = task->get_future();
96 | {
97 | std::unique_lock<std::mutex> lock(queue_mutex);
98 |
99 | // don't allow enqueueing after stopping the pool
100 | if(stop)
101 | throw std::runtime_error("enqueue on stopped ThreadPool");
102 |
103 | tasks.emplace([task](){ (*task)(); });
104 | }
105 | condition.notify_one();
106 | return res;
107 | }
108 |
109 | // the destructor joins all threads
110 | inline ThreadPool::~ThreadPool()
111 | {
112 | {
113 | std::unique_lock<std::mutex> lock(queue_mutex);
114 | stop = true;
115 | }
116 | condition.notify_all();
117 | for(std::thread &worker: workers)
118 | worker.join();
119 | }
120 |
121 | #endif
122 |
--------------------------------------------------------------------------------
/commons/buffers.h:
--------------------------------------------------------------------------------
1 | #ifndef BUFFERS_H
2 | #define BUFFERS_H
3 |
4 | #include "general.h"
5 | #include <NvInfer.h>
6 | #include <cuda_runtime_api.h>
7 | #include <cassert>
8 | #include <functional>
9 | #include <iostream>
10 | #include <iterator>
11 | #include <memory>
12 | #include <new>
13 | #include <numeric>
14 | #include <string>
15 | #include <vector>
16 |
17 | namespace algorithms
18 | {
19 |
20 | template <typename AllocFunc, typename FreeFunc>
21 | class GenericBuffer
22 | {
23 | public:
24 | //!
25 | //! \brief Construct an empty buffer.
26 | //!
27 | GenericBuffer(nvinfer1::DataType type = nvinfer1::DataType::kFLOAT)
28 | : mSize(0), mCapacity(0), mType(type), mBuffer(nullptr)
29 | {
30 | }
31 |
32 | //!
33 | //! \brief Construct a buffer with the specified allocation size in bytes.
34 | //!
35 | GenericBuffer(size_t size, nvinfer1::DataType type)
36 | : mSize(size), mCapacity(size), mType(type)
37 | {
38 | if (!allocFn(&mBuffer, this->nbBytes()))
39 | {
40 | throw std::bad_alloc();
41 | }
42 | }
43 |
44 | GenericBuffer(GenericBuffer &&buf)
45 | : mSize(buf.mSize), mCapacity(buf.mCapacity), mType(buf.mType), mBuffer(buf.mBuffer)
46 | {
47 | buf.mSize = 0;
48 | buf.mCapacity = 0;
49 | buf.mType = nvinfer1::DataType::kFLOAT;
50 | buf.mBuffer = nullptr;
51 | }
52 |
53 | GenericBuffer &operator=(GenericBuffer &&buf)
54 | {
55 | if (this != &buf)
56 | {
57 | freeFn(mBuffer);
58 | mSize = buf.mSize;
59 | mCapacity = buf.mCapacity;
60 | mType = buf.mType;
61 | mBuffer = buf.mBuffer;
62 | // Reset buf.
63 | buf.mSize = 0;
64 | buf.mCapacity = 0;
65 | buf.mBuffer = nullptr;
66 | }
67 | return *this;
68 | }
69 |
70 | //!
71 | //! \brief Returns pointer to underlying array.
72 | //!
73 | void *data()
74 | {
75 | return mBuffer;
76 | }
77 |
78 | //!
79 | //! \brief Returns pointer to underlying array.
80 | //!
81 | const void *data() const
82 | {
83 | return mBuffer;
84 | }
85 |
86 | //!
87 | //! \brief Returns the size (in number of elements) of the buffer.
88 | //!
89 | size_t size() const
90 | {
91 | return mSize;
92 | }
93 |
94 | //!
95 | //! \brief Returns the size (in bytes) of the buffer.
96 | //!
97 | size_t nbBytes() const
98 | {
99 | return this->size() * getElementSize(mType);
100 | }
101 |
102 | //!
103 | //! \brief Resizes the buffer. This is a no-op if the new size is smaller than or equal to the current capacity.
104 | //!
105 | void resize(size_t newSize)
106 | {
107 | mSize = newSize;
108 | if (mCapacity < newSize)
109 | {
110 | freeFn(mBuffer);
111 | if (!allocFn(&mBuffer, this->nbBytes()))
112 | {
113 | throw std::bad_alloc{};
114 | }
115 | mCapacity = newSize;
116 | }
117 | }
118 |
119 | //!
120 | //! \brief Overload of resize that accepts Dims
121 | //!
122 | void resize(const nvinfer1::Dims &dims)
123 | {
124 | auto vol = std::accumulate(dims.d, dims.d + dims.nbDims, int64_t{1}, std::multiplies{});
125 | return this->resize(vol);
126 | }
127 |
128 | //!
129 | //! \brief copy data from host to device
130 | //!
131 | int host2device(void* data, bool async, const cudaStream_t& stream = 0)
132 | {
133 | int ret = 0;
134 | if(async)
135 | ret = cudaMemcpyAsync(mBuffer, data, nbBytes(), cudaMemcpyHostToDevice, stream);
136 | else
137 | ret = cudaMemcpy(mBuffer, data, nbBytes(), cudaMemcpyHostToDevice);
138 | return ret;
139 | }
140 |
141 | //!
142 | //! \brief copy data from device to host
143 | //!
144 | int device2host(void* data, bool async, const cudaStream_t& stream = 0)
145 | {
146 | int ret = 0;
147 | if(async)
148 | ret = cudaMemcpyAsync(data, mBuffer, nbBytes(), cudaMemcpyDeviceToHost, stream);
149 | else
150 | ret = cudaMemcpy(data, mBuffer, nbBytes(), cudaMemcpyDeviceToHost);
151 | return ret;
152 | }
153 |
154 | //!
155 | //! \brief copy data from device to device
156 | //!
157 | int device2device(void* data, bool async, const cudaStream_t& stream = 0)
158 | {
159 | int ret = 0;
160 | if (async)
161 | ret = cudaMemcpyAsync(data, mBuffer, nbBytes(), cudaMemcpyDeviceToDevice, stream);
162 | else
163 | ret = cudaMemcpy(data, mBuffer, nbBytes(), cudaMemcpyDeviceToDevice);
164 | return ret;
165 | }
166 |
167 | nvinfer1::DataType getDataType()
168 | {
169 | return mType;
170 | }
171 |
172 | ~GenericBuffer()
173 | {
174 | freeFn(mBuffer);
175 | }
176 |
177 | private:
178 | size_t mSize{0}, mCapacity{0};
179 | nvinfer1::DataType mType;
180 | void *mBuffer;
181 | AllocFunc allocFn;
182 | FreeFunc freeFn;
183 | };
184 |
185 | class DeviceAllocator
186 | {
187 | public:
188 | bool operator()(void **ptr, size_t size) const
189 | {
190 | return cudaMalloc(ptr, size) == cudaSuccess;
191 | }
192 | };
193 |
194 | class DeviceFree
195 | {
196 | public:
197 | void operator()(void *ptr) const
198 | {
199 | cudaFree(ptr);
200 | }
201 | };
202 | using DeviceBuffer = GenericBuffer<DeviceAllocator, DeviceFree>;
203 | }
204 |
205 | #endif
206 |
--------------------------------------------------------------------------------
/commons/general.h:
--------------------------------------------------------------------------------
1 | #ifndef GENERAL_H
2 | #define GENERAL_H
3 |
4 | #include <NvInfer.h>
5 | #include <cuda_runtime_api.h>
6 | #include <algorithm>
7 | #include <cassert>
8 | #include <cstdint>
9 | #include <fstream>
10 | #include <iostream>
11 | #include <map>
12 | #include <numeric>
13 | #include <sstream>
14 | #include <string>
15 | #include <vector>
16 |
17 | namespace algorithms
18 | {
19 | uint32_t getElementSize(nvinfer1::DataType t) noexcept
20 | {
21 | switch (t)
22 | {
23 | case nvinfer1::DataType::kINT32:
24 | return 4;
25 | case nvinfer1::DataType::kFLOAT:
26 | return 4;
27 | case nvinfer1::DataType::kHALF:
28 | return 2;
29 | case nvinfer1::DataType::kBOOL:
30 | case nvinfer1::DataType::kUINT8:
31 | case nvinfer1::DataType::kINT8:
32 | return 1;
33 | }
34 | return 0;
35 | }
36 |
37 | template <typename Type>
38 | Type string2Num(const std::string &str)
39 | {
40 | std::istringstream iss(str);
41 | Type num;
42 | iss >> std::hex >> num;
43 | return num;
44 | }
45 |
46 | std::vector<std::string> read_names(const std::string filename)
47 | {
48 | std::vector<std::string> names;
49 | std::ifstream infile(filename);
50 | //assert(stream.is_open());
51 |
52 | std::string line;
53 | while (std::getline(infile, line))
54 | {
55 | names.emplace_back(line);
56 | }
57 | return names;
58 | }
59 | }
60 | #endif
61 |
--------------------------------------------------------------------------------
/data/2023-09-14 20-42-44.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mingj2021/segment-anything-tensorrt/bd27bc114fc6f6881c905d56f802b3d7ab3f2eb9/data/2023-09-14 20-42-44.png
--------------------------------------------------------------------------------
/data/truck.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mingj2021/segment-anything-tensorrt/bd27bc114fc6f6881c905d56f802b3d7ab3f2eb9/data/truck.jpg
--------------------------------------------------------------------------------
/export.h:
--------------------------------------------------------------------------------
1 | #ifndef EXPORT_H
2 | #define EXPORT_H
3 |
4 | #include <NvOnnxParser.h>
5 | #include <fstream>
6 | #include "sam_utils.h"
7 |
8 |
9 | void export_engine_image_encoder(std::string f="vit_l_embedding.onnx",std::string output="vit_l_embedding.engine")
10 | {
11 | // create an instance of the builder
12 | std::unique_ptr<nvinfer1::IBuilder> builder(createInferBuilder(logger));
13 | // create a network definition
14 | // The kEXPLICIT_BATCH flag is required in order to import models using the ONNX parser.
15 | uint32_t flag = 1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
16 |
17 | //auto network = std::make_unique<nvinfer1::INetworkDefinition>(builder->createNetworkV2(flag));
18 | std::unique_ptr<nvinfer1::INetworkDefinition> network(builder->createNetworkV2(flag));
19 |
20 | // Importing a Model Using the ONNX Parser
21 | //auto parser = std::make_unique<nvonnxparser::IParser>(createParser(*network, logger));
22 | std::unique_ptr<nvonnxparser::IParser> parser(createParser(*network, logger));
23 |
24 | // read the model file and process any errors
25 | parser->parseFromFile(f.c_str(),
26 | static_cast<int32_t>(nvinfer1::ILogger::Severity::kWARNING));
27 | for (int32_t i = 0; i < parser->getNbErrors(); ++i)
28 | {
29 | std::cout << parser->getError(i)->desc() << std::endl;
30 | }
31 |
32 | // create a build configuration specifying how TensorRT should optimize the model
33 | std::unique_ptr<nvinfer1::IBuilderConfig> config(builder->createBuilderConfig());
34 |
35 | // maximum workspace size
36 | // int workspace = 4; // GB
37 | // config->setMaxWorkspaceSize(workspace * 1U << 30);
38 | config->setFlag(BuilderFlag::kGPU_FALLBACK);
39 |
40 | config->setFlag(BuilderFlag::kFP16);
41 |
42 | // create an engine
43 | // auto serializedModel = std::make_unique<nvinfer1::IHostMemory>(builder->buildSerializedNetwork(*network, *config));
44 | std::unique_ptr<nvinfer1::IHostMemory> serializedModel(builder->buildSerializedNetwork(*network, *config));
45 | std::cout << "serializedModel->size()" << serializedModel->size() << std::endl;
46 | std::ofstream outfile(output, std::ofstream::out | std::ofstream::binary);
47 | outfile.write((char*)serializedModel->data(), serializedModel->size());
48 | }
49 |
50 | void export_engine_prompt_encoder_and_mask_decoder(std::string f="sam_onnx_example.onnx",std::string output="sam_onnx_example.engine")
51 | {
52 | // create an instance of the builder
53 | std::unique_ptr<nvinfer1::IBuilder> builder(createInferBuilder(logger));
54 | // create a network definition
55 | // The kEXPLICIT_BATCH flag is required in order to import models using the ONNX parser.
56 | uint32_t flag = 1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
57 |
58 | //auto network = std::make_unique<nvinfer1::INetworkDefinition>(builder->createNetworkV2(flag));
59 | std::unique_ptr<nvinfer1::INetworkDefinition> network(builder->createNetworkV2(flag));
60 |
61 | // Importing a Model Using the ONNX Parser
62 | //auto parser = std::make_unique<nvonnxparser::IParser>(createParser(*network, logger));
63 | std::unique_ptr<nvonnxparser::IParser> parser(createParser(*network, logger));
64 |
65 | // read the model file and process any errors
66 | parser->parseFromFile(f.c_str(),
67 | static_cast<int32_t>(nvinfer1::ILogger::Severity::kWARNING));
68 | for (int32_t i = 0; i < parser->getNbErrors(); ++i)
69 | {
70 | std::cout << parser->getError(i)->desc() << std::endl;
71 | }
72 |
73 | // create a build configuration specifying how TensorRT should optimize the model
74 | std::unique_ptr<nvinfer1::IBuilderConfig> config(builder->createBuilderConfig());
75 |
76 | // maximum workspace size
77 | // int workspace = 8; // GB
78 | // config->setMaxWorkspaceSize(workspace * 1U << 30);
79 | config->setFlag(BuilderFlag::kGPU_FALLBACK);
80 |
81 | config->setFlag(BuilderFlag::kFP16);
82 |
83 | nvinfer1::IOptimizationProfile* profile = builder->createOptimizationProfile();
84 | // profile->setDimensions("image_embeddings", nvinfer1::OptProfileSelector::kMIN, {1, 256, 64, 64 });
85 | // profile->setDimensions("image_embeddings", nvinfer1::OptProfileSelector::kOPT, {1, 256, 64, 64 });
86 | // profile->setDimensions("image_embeddings", nvinfer1::OptProfileSelector::kMAX, {1, 256, 64, 64 });
87 |
88 | profile->setDimensions("point_coords", nvinfer1::OptProfileSelector::kMIN, {3, 1, 2,2 });
89 | profile->setDimensions("point_coords", nvinfer1::OptProfileSelector::kOPT, { 3,1, 5,2 });
90 | profile->setDimensions("point_coords", nvinfer1::OptProfileSelector::kMAX, { 3,1,10,2 });
91 |
92 | profile->setDimensions("point_labels", nvinfer1::OptProfileSelector::kMIN, { 2,1, 2});
93 | profile->setDimensions("point_labels", nvinfer1::OptProfileSelector::kOPT, { 2,1, 5 });
94 | profile->setDimensions("point_labels", nvinfer1::OptProfileSelector::kMAX, { 2,1,10 });
95 |
96 | // profile->setDimensions("mask_input", nvinfer1::OptProfileSelector::kMIN, { 1, 1, 256, 256});
97 | // profile->setDimensions("mask_input", nvinfer1::OptProfileSelector::kOPT, { 1, 1, 256, 256 });
98 | // profile->setDimensions("mask_input", nvinfer1::OptProfileSelector::kMAX, { 1, 1, 256, 256 });
99 |
100 | // profile->setDimensions("has_mask_input", nvinfer1::OptProfileSelector::kMIN, { 1,});
101 | // profile->setDimensions("has_mask_input", nvinfer1::OptProfileSelector::kOPT, { 1, });
102 | // profile->setDimensions("has_mask_input", nvinfer1::OptProfileSelector::kMAX, { 1, });
103 |
104 | config->addOptimizationProfile(profile);
105 |
106 | // create an engine
107 | // auto serializedModel = std::make_unique<nvinfer1::IHostMemory>(builder->buildSerializedNetwork(*network, *config));
108 | std::unique_ptr<nvinfer1::IHostMemory> serializedModel(builder->buildSerializedNetwork(*network, *config));
109 | std::cout << "serializedModel->size()" << serializedModel->size() << std::endl;
110 | std::ofstream outfile(output, std::ofstream::out | std::ofstream::binary);
111 | outfile.write((char*)serializedModel->data(), serializedModel->size());
112 | }
113 | #endif
--------------------------------------------------------------------------------
/how_to_export_vim_h_model.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "# Divide the embeding model into 2 parts, named as part1 and part2 \n",
10 | "# The decoder model remains unchanged\n",
11 | "# export_embedding_model_part_1()\n",
12 | "# export_embedding_model_part_2()\n",
13 | "# export_sam_h_model()\n",
14 | "# onnx_model_example2()"
15 | ]
16 | },
17 | {
18 | "cell_type": "code",
19 | "execution_count": null,
20 | "metadata": {},
21 | "outputs": [],
22 | "source": [
23 | "import torch\n",
24 | "import torch.nn as nn\n",
25 | "from torch.nn import functional as F\n",
26 | "from segment_anything.modeling import Sam\n",
27 | "import numpy as np\n",
28 | "from torchvision.transforms.functional import resize, to_pil_image\n",
29 | "from typing import Tuple\n",
30 | "from segment_anything import sam_model_registry, SamPredictor\n",
31 | "import cv2\n",
32 | "import matplotlib.pyplot as plt\n",
33 | "import warnings\n",
34 | "import os\n",
35 | "os.environ[\"KMP_DUPLICATE_LIB_OK\"] = \"TRUE\"\n",
36 | "import onnxruntime\n",
37 | "\n",
38 | "\n",
39 | "def show_mask(mask, ax, random_color=False):\n",
40 | " if random_color:\n",
41 | " color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)\n",
42 | " else:\n",
43 | " color = np.array([30/255, 144/255, 255/255, 0.6])\n",
44 | " h, w = mask.shape[-2:]\n",
45 | " mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)\n",
46 | " ax.imshow(mask_image)\n",
47 | " \n",
48 | "def show_points(coords, labels, ax, marker_size=375):\n",
49 | " pos_points = coords[labels==1]\n",
50 | " neg_points = coords[labels==0]\n",
51 | " ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)\n",
52 | " ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25) \n",
53 | " \n",
54 | "def show_box(box, ax):\n",
55 | " x0, y0 = box[0], box[1]\n",
56 | " w, h = box[2] - box[0], box[3] - box[1]\n",
57 | " ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2)) \n",
58 | "\n",
59 | "\n",
60 | "def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:\n",
61 | " \"\"\"\n",
62 | " Compute the output size given input size and target long side length.\n",
63 | " \"\"\"\n",
64 | " scale = long_side_length * 1.0 / max(oldh, oldw)\n",
65 | " newh, neww = oldh * scale, oldw * scale\n",
66 | " neww = int(neww + 0.5)\n",
67 | " newh = int(newh + 0.5)\n",
68 | " return (newh, neww)\n",
69 | "\n",
70 | "# @torch.no_grad()\n",
71 | "def pre_processing(image: np.ndarray, target_length: int, device,pixel_mean,pixel_std,img_size):\n",
72 | " target_size = get_preprocess_shape(image.shape[0], image.shape[1], target_length)\n",
73 | " input_image = np.array(resize(to_pil_image(image), target_size))\n",
74 | " input_image_torch = torch.as_tensor(input_image, device=device)\n",
75 | " input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :]\n",
76 | "\n",
77 | " # Normalize colors\n",
78 | " input_image_torch = (input_image_torch - pixel_mean) / pixel_std\n",
79 | "\n",
80 | " # Pad\n",
81 | " h, w = input_image_torch.shape[-2:]\n",
82 | " padh = img_size - h\n",
83 | " padw = img_size - w\n",
84 | " input_image_torch = F.pad(input_image_torch, (0, padw, 0, padh))\n",
85 | " return input_image_torch\n",
86 | "\n",
87 | "def export_embedding_model():\n",
88 | " sam_checkpoint = \"weights/sam_vit_l_0b3195.pth\"\n",
89 | " model_type = \"vit_l\"\n",
90 | "\n",
91 | " device = \"cpu\"\n",
92 | "\n",
93 | " sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)\n",
94 | " sam.to(device=device)\n",
95 | "\n",
96 | " # image_encoder = EmbeddingOnnxModel(sam)\n",
97 | " image = cv2.imread('notebooks/images/truck.jpg')\n",
98 | " target_length = sam.image_encoder.img_size\n",
99 | " pixel_mean = sam.pixel_mean \n",
100 | " pixel_std = sam.pixel_std\n",
101 | " img_size = sam.image_encoder.img_size\n",
102 | " inputs = pre_processing(image, target_length, device,pixel_mean,pixel_std,img_size)\n",
103 | " onnx_model_path = model_type+\"_\"+\"embedding.onnx\"\n",
104 | " dummy_inputs = {\n",
105 | " \"images\": inputs\n",
106 | "}\n",
107 | " output_names = [\"image_embeddings\"]\n",
108 | " image_embeddings = sam.image_encoder(inputs).cpu().numpy()\n",
109 | " print('image_embeddings', image_embeddings.shape)\n",
110 | " with warnings.catch_warnings():\n",
111 | " warnings.filterwarnings(\"ignore\", category=torch.jit.TracerWarning)\n",
112 | " warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
113 | " with open(onnx_model_path, \"wb\") as f:\n",
114 | " torch.onnx.export(\n",
115 | " sam.image_encoder,\n",
116 | " tuple(dummy_inputs.values()),\n",
117 | " f,\n",
118 | " export_params=True,\n",
119 | " verbose=False,\n",
120 | " opset_version=13,\n",
121 | " do_constant_folding=True,\n",
122 | " input_names=list(dummy_inputs.keys()),\n",
123 | " output_names=output_names,\n",
124 | " # dynamic_axes=dynamic_axes,\n",
125 | " ) \n",
126 | "\n",
127 | "def export_sam_model():\n",
128 | " from segment_anything.utils.onnx import SamOnnxModel\n",
129 | " checkpoint = \"weights/sam_vit_l_0b3195.pth\"\n",
130 | " model_type = \"vit_l\"\n",
131 | " sam = sam_model_registry[model_type](checkpoint=checkpoint)\n",
132 | " onnx_model_path = \"sam_onnx_example.onnx\"\n",
133 | "\n",
134 | " onnx_model = SamOnnxModel(sam, return_single_mask=True)\n",
135 | "\n",
136 | " dynamic_axes = {\n",
137 | " \"point_coords\": {1: \"num_points\"},\n",
138 | " \"point_labels\": {1: \"num_points\"},\n",
139 | " }\n",
140 | "\n",
141 | " embed_dim = sam.prompt_encoder.embed_dim\n",
142 | " embed_size = sam.prompt_encoder.image_embedding_size\n",
143 | " mask_input_size = [4 * x for x in embed_size]\n",
144 | " dummy_inputs = {\n",
145 | " \"image_embeddings\": torch.randn(1, embed_dim, *embed_size, dtype=torch.float),\n",
146 | " \"point_coords\": torch.randint(low=0, high=1024, size=(1, 5, 2), dtype=torch.float),\n",
147 | " \"point_labels\": torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float),\n",
148 | " \"mask_input\": torch.randn(1, 1, *mask_input_size, dtype=torch.float),\n",
149 | " \"has_mask_input\": torch.tensor([1], dtype=torch.float),\n",
150 | " # \"orig_im_size\": torch.tensor([1500, 2250], dtype=torch.float),\n",
151 | " }\n",
152 | " # output_names = [\"masks\", \"iou_predictions\", \"low_res_masks\"]\n",
153 | " output_names = [\"masks\", \"scores\"]\n",
154 | "\n",
155 | " with warnings.catch_warnings():\n",
156 | " warnings.filterwarnings(\"ignore\", category=torch.jit.TracerWarning)\n",
157 | " warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
158 | " with open(onnx_model_path, \"wb\") as f:\n",
159 | " torch.onnx.export(\n",
160 | " onnx_model,\n",
161 | " tuple(dummy_inputs.values()),\n",
162 | " f,\n",
163 | " export_params=True,\n",
164 | " verbose=False,\n",
165 | " opset_version=13,\n",
166 | " do_constant_folding=True,\n",
167 | " input_names=list(dummy_inputs.keys()),\n",
168 | " output_names=output_names,\n",
169 | " dynamic_axes=dynamic_axes,\n",
170 | " ) \n",
171 | "\n",
172 | "def onnx_model_example():\n",
173 | " import os\n",
174 | " ort_session_embedding = onnxruntime.InferenceSession('vit_l_embedding.onnx',providers=['CPUExecutionProvider'])\n",
175 | " ort_session_sam = onnxruntime.InferenceSession('sam_onnx_example.onnx',providers=['CPUExecutionProvider'])\n",
176 | "\n",
177 | " image = cv2.imread('notebooks/images/truck.jpg')\n",
178 | " image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n",
179 | " image2 = image.copy()\n",
180 | " prompt_embed_dim = 256\n",
181 | " image_size = 1024\n",
182 | " vit_patch_size = 16\n",
183 | " image_embedding_size = image_size // vit_patch_size\n",
184 | " target_length = image_size\n",
185 | " pixel_mean=[123.675, 116.28, 103.53],\n",
186 | " pixel_std=[58.395, 57.12, 57.375]\n",
187 | " pixel_mean = torch.Tensor(pixel_mean).view(-1, 1, 1)\n",
188 | " pixel_std = torch.Tensor(pixel_std).view(-1, 1, 1)\n",
189 | " device = \"cpu\"\n",
190 | " inputs = pre_processing(image, target_length, device,pixel_mean,pixel_std,image_size)\n",
191 | " ort_inputs = {\n",
192 | " \"images\": inputs.cpu().numpy()\n",
193 | " }\n",
194 | " image_embeddings = ort_session_embedding.run(None, ort_inputs)[0]\n",
195 | "\n",
196 | " input_point = np.array([[500, 375]])\n",
197 | " input_label = np.array([1])\n",
198 | "\n",
199 | " onnx_coord = np.concatenate([input_point, np.array([[0.0, 0.0]])], axis=0)[None, :, :]\n",
200 | " onnx_label = np.concatenate([input_label, np.array([-1])], axis=0)[None, :].astype(np.float32)\n",
201 | " from segment_anything.utils.transforms import ResizeLongestSide\n",
202 | " transf = ResizeLongestSide(image_size)\n",
203 | " onnx_coord = transf.apply_coords(onnx_coord, image.shape[:2]).astype(np.float32)\n",
204 | " onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)\n",
205 | " onnx_has_mask_input = np.zeros(1, dtype=np.float32)\n",
206 | "\n",
207 | " ort_inputs = {\n",
208 | " \"image_embeddings\": image_embeddings,\n",
209 | " \"point_coords\": onnx_coord,\n",
210 | " \"point_labels\": onnx_label,\n",
211 | " \"mask_input\": onnx_mask_input,\n",
212 | " \"has_mask_input\": onnx_has_mask_input,\n",
213 | " # \"orig_im_size\": np.array(image.shape[:2], dtype=np.float32)\n",
214 | " }\n",
215 | "\n",
216 | " masks, _ = ort_session_sam.run(None, ort_inputs)\n",
217 | "\n",
218 | " from segment_anything.utils.onnx import SamOnnxModel\n",
219 | " checkpoint = \"weights/sam_vit_l_0b3195.pth\"\n",
220 | " model_type = \"vit_l\"\n",
221 | " sam = sam_model_registry[model_type](checkpoint=checkpoint)\n",
222 | " onnx_model_path = \"sam_onnx_example.onnx\"\n",
223 | "\n",
224 | " onnx_model = SamOnnxModel(sam, return_single_mask=True)\n",
225 | " masks = onnx_model.mask_postprocessing(torch.as_tensor(masks), torch.as_tensor(image.shape[:2]))\n",
226 | " masks = masks > 0.0\n",
227 | " plt.figure(figsize=(10, 10))\n",
228 | " plt.imshow(image)\n",
229 | " show_mask(masks, plt.gca())\n",
230 | " # show_box(input_box, plt.gca())\n",
231 | " show_points(input_point, input_label, plt.gca())\n",
232 | " plt.axis('off')\n",
233 | " plt.savefig('demo.png')\n",
234 | "\n",
235 | "def export_engine_image_encoder(f='vit_l_embedding.onnx'):\n",
236 | " import tensorrt as trt\n",
237 | " from pathlib import Path\n",
238 | " file = Path(f)\n",
239 | " f = file.with_suffix('.engine') # TensorRT engine file\n",
240 | " onnx = file.with_suffix('.onnx')\n",
241 | " logger = trt.Logger(trt.Logger.INFO)\n",
242 | " builder = trt.Builder(logger)\n",
243 | " config = builder.create_builder_config()\n",
244 | " workspace = 6\n",
245 | " print(\"workspace: \", workspace)\n",
246 | " config.max_workspace_size = workspace * 1 << 30\n",
247 | " flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))\n",
248 | " network = builder.create_network(flag)\n",
249 | " parser = trt.OnnxParser(network, logger)\n",
250 | " if not parser.parse_from_file(str(onnx)):\n",
251 | " raise RuntimeError(f'failed to load ONNX file: {onnx}')\n",
252 | "\n",
253 | " inputs = [network.get_input(i) for i in range(network.num_inputs)]\n",
254 | " outputs = [network.get_output(i) for i in range(network.num_outputs)]\n",
255 | " for inp in inputs:\n",
256 | " print(f'input \"{inp.name}\" with shape{inp.shape} {inp.dtype}')\n",
257 | " for out in outputs:\n",
258 | " print(f'output \"{out.name}\" with shape{out.shape} {out.dtype}')\n",
259 | "\n",
260 | " half = True\n",
261 | " print(f'building FP{16 if builder.platform_has_fast_fp16 and half else 32} engine as {f}')\n",
262 | " if builder.platform_has_fast_fp16 and half:\n",
263 | " config.set_flag(trt.BuilderFlag.FP16)\n",
264 | " with builder.build_engine(network, config) as engine, open(f, 'wb') as t:\n",
265 | " t.write(engine.serialize())\n",
266 | "\n",
267 | "\n",
268 | "def export_engine_prompt_encoder_and_mask_decoder(f='sam_onnx_example.onnx'):\n",
269 | " import tensorrt as trt\n",
270 | " from pathlib import Path\n",
271 | " file = Path(f)\n",
272 | " f = file.with_suffix('.engine') # TensorRT engine file\n",
273 | " onnx = file.with_suffix('.onnx')\n",
274 | " logger = trt.Logger(trt.Logger.INFO)\n",
275 | " builder = trt.Builder(logger)\n",
276 | " config = builder.create_builder_config()\n",
277 | " workspace = 10\n",
278 | " print(\"workspace: \", workspace)\n",
279 | " config.max_workspace_size = workspace * 1 << 30\n",
280 | " flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))\n",
281 | " network = builder.create_network(flag)\n",
282 | " parser = trt.OnnxParser(network, logger)\n",
283 | " if not parser.parse_from_file(str(onnx)):\n",
284 | " raise RuntimeError(f'failed to load ONNX file: {onnx}')\n",
285 | "\n",
286 | " inputs = [network.get_input(i) for i in range(network.num_inputs)]\n",
287 | " outputs = [network.get_output(i) for i in range(network.num_outputs)]\n",
288 | " for inp in inputs:\n",
289 | " print(f'input \"{inp.name}\" with shape{inp.shape} {inp.dtype}')\n",
290 | " for out in outputs:\n",
291 | " print(f'output \"{out.name}\" with shape{out.shape} {out.dtype}')\n",
292 | "\n",
293 | " profile = builder.create_optimization_profile()\n",
294 | " profile.set_shape('image_embeddings', (1, 256, 64, 64), (1, 256, 64, 64), (1, 256, 64, 64))\n",
295 | " profile.set_shape('point_coords', (1, 2,2), (1, 5,2), (1,10,2))\n",
296 | " profile.set_shape('point_labels', (1, 2), (1, 5), (1,10))\n",
297 | " profile.set_shape('mask_input', (1, 1, 256, 256), (1, 1, 256, 256), (1, 1, 256, 256))\n",
298 | " profile.set_shape('has_mask_input', (1,), (1, ), (1, ))\n",
299 | " # # profile.set_shape_input('orig_im_size', (416,416), (1024,1024), (1500, 2250))\n",
300 | " # profile.set_shape_input('orig_im_size', (2,), (2,), (2, ))\n",
301 | " config.add_optimization_profile(profile)\n",
302 | "\n",
303 | " half = True\n",
304 | " print(f'building FP{16 if builder.platform_has_fast_fp16 and half else 32} engine as {f}')\n",
305 | " if builder.platform_has_fast_fp16 and half:\n",
306 | " config.set_flag(trt.BuilderFlag.FP16)\n",
307 | " with builder.build_engine(network, config) as engine, open(f, 'wb') as t:\n",
308 | " t.write(engine.serialize())\n",
309 | "\n",
310 | "\n",
311 | "def collect_stats(model, device, data_loader, num_batches):\n",
312 | " \"\"\"Feed data to the network and collect statistic\"\"\"\n",
313 | "\n",
314 | " # Enable calibrators\n",
315 | " for name, module in model.named_modules():\n",
316 | " if isinstance(module, quant_nn.TensorQuantizer):\n",
317 | " if module._calibrator is not None:\n",
318 | " module.disable_quant()\n",
319 | " module.enable_calib()\n",
320 | " else:\n",
321 | " module.disable()\n",
322 | "\n",
323 | " for i, (path, im, im0s, vid_cap, s) in tqdm(enumerate(data_loader), total=num_batches):\n",
324 | " im = torch.from_numpy(im).to(device)\n",
325 | " im = im.float()\n",
326 | " im /= 255 # 0 - 255 to 0.0 - 1.0\n",
327 | " if len(im.shape) == 3:\n",
328 | " im = im[None] # expand for batch dim\n",
329 | " model(im)\n",
330 | " if i >= num_batches:\n",
331 | " break\n",
332 | "\n",
333 | " # Disable calibrators\n",
334 | " for name, module in model.named_modules():\n",
335 | " if isinstance(module, quant_nn.TensorQuantizer):\n",
336 | " if module._calibrator is not None:\n",
337 | " module.enable_quant()\n",
338 | " module.disable_calib()\n",
339 | " else:\n",
340 | " module.enable()\n",
341 | "\n",
342 | "def compute_amax(model, **kwargs):\n",
343 | " # Load calib result\n",
344 | " for name, module in model.named_modules():\n",
345 | " if isinstance(module, quant_nn.TensorQuantizer):\n",
346 | " if module._calibrator is not None:\n",
347 | " if isinstance(module._calibrator, calib.MaxCalibrator):\n",
348 | " module.load_calib_amax()\n",
349 | " else:\n",
350 | " module.load_calib_amax(**kwargs)\n",
351 | "\n",
352 | "class embedding_model_part_1(nn.Module):\n",
353 | " def __init__(self) :\n",
354 | " super().__init__()\n",
355 | " sam_checkpoint = \"weights/sam_vit_h_4b8939.pth\"\n",
356 | " model_type = \"vit_h\"\n",
357 | " device = \"cpu\"\n",
358 | " self.sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)\n",
359 | " self.sam.to(device=device)\n",
360 | "\n",
361 | " def forward(self, x):\n",
362 | " x = self.sam.image_encoder.patch_embed(x)\n",
363 | " if self.sam.image_encoder.pos_embed is not None:\n",
364 | " x = x + self.sam.image_encoder.pos_embed\n",
365 | "\n",
366 | " val_1 = len(self.sam.image_encoder.blocks) // 2\n",
367 | " for blk in self.sam.image_encoder.blocks[0:val_1]:\n",
368 | " x = blk(x)\n",
369 | "\n",
370 | " return x\n",
371 | " \n",
372 | "class embedding_model_part_2(nn.Module):\n",
373 | " def __init__(self) :\n",
374 | " super().__init__()\n",
375 | " sam_checkpoint = \"weights/sam_vit_h_4b8939.pth\"\n",
376 | " model_type = \"vit_h\"\n",
377 | " device = \"cpu\"\n",
378 | " self.sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)\n",
379 | " self.sam.to(device=device)\n",
380 | "\n",
381 | " def forward(self, x):\n",
382 | " val_1 = len(self.sam.image_encoder.blocks) // 2\n",
383 | " for blk in self.sam.image_encoder.blocks[val_1:]:\n",
384 | " x = blk(x)\n",
385 | " x = self.sam.image_encoder.neck(x.permute(0, 3, 1, 2))\n",
386 | " return x\n",
387 | "\n",
388 | "def export_embedding_model_part_1():\n",
389 | " device = \"cpu\"\n",
390 | " model_type = \"vit_h\"\n",
391 | " model = embedding_model_part_1()\n",
392 | " # image_encoder = EmbeddingOnnxModel(sam)\n",
393 | " image = cv2.imread('notebooks/images/truck.jpg')\n",
394 | " target_length = model.sam.image_encoder.img_size\n",
395 | " pixel_mean = model.sam.pixel_mean \n",
396 | " pixel_std = model.sam.pixel_std\n",
397 | " img_size = model.sam.image_encoder.img_size\n",
398 | " inputs = pre_processing(image, target_length, device,pixel_mean,pixel_std,img_size)\n",
399 | " onnx_model_path = model_type+\"_\"+\"part_1_embedding.onnx\"\n",
400 | " dummy_inputs = {\n",
401 | " \"images\": inputs\n",
402 | "}\n",
403 | " output_names = [\"image_embeddings_part_1\"]\n",
404 | " with warnings.catch_warnings():\n",
405 | " warnings.filterwarnings(\"ignore\", category=torch.jit.TracerWarning)\n",
406 | " warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
407 | " with open(onnx_model_path, \"wb\") as f:\n",
408 | " torch.onnx.export(\n",
409 | " model,\n",
410 | " tuple(dummy_inputs.values()),\n",
411 | " f,\n",
412 | " export_params=True,\n",
413 | " verbose=False,\n",
414 | " opset_version=13,\n",
415 | " do_constant_folding=True,\n",
416 | " input_names=list(dummy_inputs.keys()),\n",
417 | " output_names=output_names,\n",
418 | " # dynamic_axes=dynamic_axes,\n",
419 | " ) \n",
420 | "\n",
421 | "def export_embedding_model_part_2():\n",
422 | " device = \"cpu\"\n",
423 | " model_type = \"vit_h\"\n",
424 | " model = embedding_model_part_2()\n",
425 | " \n",
426 | " inputs = torch.randn(1, 64, 64, 1280)\n",
427 | " onnx_model_path = model_type+\"_\"+\"part_2_embedding.onnx\"\n",
428 | " dummy_inputs = {\n",
429 | " \"image_embeddings_part_1\": inputs\n",
430 | "}\n",
431 | " output_names = [\"image_embeddings_part_2\"]\n",
432 | " with warnings.catch_warnings():\n",
433 | " warnings.filterwarnings(\"ignore\", category=torch.jit.TracerWarning)\n",
434 | " warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
435 | " with open(onnx_model_path, \"wb\") as f:\n",
436 | " torch.onnx.export(\n",
437 | " model,\n",
438 | " tuple(dummy_inputs.values()),\n",
439 | " f,\n",
440 | " export_params=True,\n",
441 | " verbose=False,\n",
442 | " opset_version=13,\n",
443 | " do_constant_folding=True,\n",
444 | " input_names=list(dummy_inputs.keys()),\n",
445 | " output_names=output_names,\n",
446 | " # dynamic_axes=dynamic_axes,\n",
447 | " ) \n",
448 | " \n",
449 | "def export_sam_h_model():\n",
450 | " from segment_anything.utils.onnx import SamOnnxModel\n",
451 | " checkpoint = \"weights/sam_vit_h_4b8939.pth\"\n",
452 | " model_type = \"vit_h\"\n",
453 | " sam = sam_model_registry[model_type](checkpoint=checkpoint)\n",
454 | " onnx_model_path = \"sam_h_decoder_onnx.onnx\"\n",
455 | "\n",
456 | " onnx_model = SamOnnxModel(sam, return_single_mask=True)\n",
457 | "\n",
458 | " dynamic_axes = {\n",
459 | " \"point_coords\": {1: \"num_points\"},\n",
460 | " \"point_labels\": {1: \"num_points\"},\n",
461 | " }\n",
462 | "\n",
463 | " embed_dim = sam.prompt_encoder.embed_dim\n",
464 | " embed_size = sam.prompt_encoder.image_embedding_size\n",
465 | " mask_input_size = [4 * x for x in embed_size]\n",
466 | " dummy_inputs = {\n",
467 | " \"image_embeddings\": torch.randn(1, embed_dim, *embed_size, dtype=torch.float),\n",
468 | " \"point_coords\": torch.randint(low=0, high=1024, size=(1, 5, 2), dtype=torch.float),\n",
469 | " \"point_labels\": torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float),\n",
470 | " \"mask_input\": torch.randn(1, 1, *mask_input_size, dtype=torch.float),\n",
471 | " \"has_mask_input\": torch.tensor([1], dtype=torch.float),\n",
472 | " # \"orig_im_size\": torch.tensor([1500, 2250], dtype=torch.float),\n",
473 | " }\n",
474 | " # output_names = [\"masks\", \"iou_predictions\", \"low_res_masks\"]\n",
475 | " output_names = [\"masks\", \"scores\"]\n",
476 | "\n",
477 | " with warnings.catch_warnings():\n",
478 | " warnings.filterwarnings(\"ignore\", category=torch.jit.TracerWarning)\n",
479 | " warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
480 | " with open(onnx_model_path, \"wb\") as f:\n",
481 | " torch.onnx.export(\n",
482 | " onnx_model,\n",
483 | " tuple(dummy_inputs.values()),\n",
484 | " f,\n",
485 | " export_params=True,\n",
486 | " verbose=False,\n",
487 | " opset_version=13,\n",
488 | " do_constant_folding=True,\n",
489 | " input_names=list(dummy_inputs.keys()),\n",
490 | " output_names=output_names,\n",
491 | " dynamic_axes=dynamic_axes,\n",
492 | " ) \n",
493 | "\n",
494 | "def onnx_model_example2():\n",
495 | " import os\n",
496 | " ort_session_embedding_part_1 = onnxruntime.InferenceSession('vit_h_part_1_embedding.onnx',providers=['CPUExecutionProvider'])\n",
497 | " ort_session_embedding_part_2 = onnxruntime.InferenceSession('vit_h_part_2_embedding.onnx',providers=['CPUExecutionProvider'])\n",
498 | " ort_session_sam = onnxruntime.InferenceSession('sam_h_decoder_onnx.onnx',providers=['CPUExecutionProvider'])\n",
499 | "\n",
500 | " image = cv2.imread('notebooks/images/truck.jpg')\n",
501 | " image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n",
502 | " image2 = image.copy()\n",
503 | " prompt_embed_dim = 256\n",
504 | " image_size = 1024\n",
505 | " vit_patch_size = 16\n",
506 | " image_embedding_size = image_size // vit_patch_size\n",
507 | " target_length = image_size\n",
508 | " pixel_mean=[123.675, 116.28, 103.53],\n",
509 | " pixel_std=[58.395, 57.12, 57.375]\n",
510 | " pixel_mean = torch.Tensor(pixel_mean).view(-1, 1, 1)\n",
511 | " pixel_std = torch.Tensor(pixel_std).view(-1, 1, 1)\n",
512 | " device = \"cpu\"\n",
513 | " inputs = pre_processing(image, target_length, device,pixel_mean,pixel_std,image_size)\n",
514 | " ort_inputs = {\n",
515 | " \"images\": inputs.cpu().numpy()\n",
516 | " }\n",
517 | " image_embeddings = ort_session_embedding_part_1.run(None, ort_inputs)[0]\n",
518 | " ort_inputs = {\n",
519 | " \"image_embeddings_part_1\": image_embeddings\n",
520 | " }\n",
521 | " image_embeddings = ort_session_embedding_part_2.run(None, ort_inputs)[0]\n",
522 | "\n",
523 | " input_point = np.array([[784, 379]])\n",
524 | " input_label = np.array([1])\n",
525 | "\n",
526 | " onnx_coord = np.concatenate([input_point, np.array([[0.0, 0.0]])], axis=0)[None, :, :]\n",
527 | " onnx_label = np.concatenate([input_label, np.array([-1])], axis=0)[None, :].astype(np.float32)\n",
528 | " from segment_anything.utils.transforms import ResizeLongestSide\n",
529 | " transf = ResizeLongestSide(image_size)\n",
530 | " onnx_coord = transf.apply_coords(onnx_coord, image.shape[:2]).astype(np.float32)\n",
531 | " onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)\n",
532 | " onnx_has_mask_input = np.zeros(1, dtype=np.float32)\n",
533 | "\n",
534 | " ort_inputs = {\n",
535 | " \"image_embeddings\": image_embeddings,\n",
536 | " \"point_coords\": onnx_coord,\n",
537 | " \"point_labels\": onnx_label,\n",
538 | " \"mask_input\": onnx_mask_input,\n",
539 | " \"has_mask_input\": onnx_has_mask_input,\n",
540 | " # \"orig_im_size\": np.array(image.shape[:2], dtype=np.float32)\n",
541 | " }\n",
542 | "\n",
543 | " masks, _ = ort_session_sam.run(None, ort_inputs)\n",
544 | "\n",
545 | " from segment_anything.utils.onnx import SamOnnxModel\n",
546 | " checkpoint = \"weights/sam_vit_h_4b8939.pth\"\n",
547 | " model_type = \"vit_h\"\n",
548 | " sam = sam_model_registry[model_type](checkpoint=checkpoint)\n",
549 | "\n",
550 | " onnx_model = SamOnnxModel(sam, return_single_mask=True)\n",
551 | " masks = onnx_model.mask_postprocessing(torch.as_tensor(masks), torch.as_tensor(image.shape[:2]))\n",
552 | " masks = masks > 0.0\n",
553 | " plt.figure(figsize=(10, 10))\n",
554 | " plt.imshow(image)\n",
555 | " show_mask(masks, plt.gca())\n",
556 | " # show_box(input_box, plt.gca())\n",
557 | " show_points(input_point, input_label, plt.gca())\n",
558 | " plt.axis('off')\n",
559 | " plt.savefig('demo.png')\n",
560 | " plt.show()\n",
561 | "\n",
562 | "\n",
563 | "if __name__ == '__main__':\n",
564 | " with torch.no_grad():\n",
565 | " # export_embedding_model()\n",
566 | " # export_sam_model()\n",
567 | " # onnx_model_example()\n",
568 | " # export_engine_image_encoder()\n",
569 | " # export_engine_prompt_encoder_and_mask_decoder()\n",
570 | " export_embedding_model_part_1()\n",
571 | " export_embedding_model_part_2()\n",
572 | " export_sam_h_model()\n",
573 | " onnx_model_example2()\n",
574 | " "
575 | ]
576 | }
577 | ],
578 | "metadata": {
579 | "kernelspec": {
580 | "display_name": "Python 3",
581 | "language": "python",
582 | "name": "python3"
583 | },
584 | "language_info": {
585 | "codemirror_mode": {
586 | "name": "ipython",
587 | "version": 3
588 | },
589 | "file_extension": ".py",
590 | "mimetype": "text/x-python",
591 | "name": "python",
592 | "nbconvert_exporter": "python",
593 | "pygments_lexer": "ipython3",
594 | "version": "3.8.10"
595 | },
596 | "orig_nbformat": 4
597 | },
598 | "nbformat": 4,
599 | "nbformat_minor": 2
600 | }
601 |
--------------------------------------------------------------------------------
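
Note: main.cpp and main_vim_h.cpp below repeat the same engine-deserialization boilerplate (open the .engine file, measure it, read it into a byte vector) once per engine. A minimal consolidation sketch using only the standard library; readEngineBlob is a hypothetical helper, not part of this repository:

#include <fstream>
#include <stdexcept>
#include <string>
#include <vector>

// Read a serialized TensorRT engine from disk into a byte buffer.
static std::vector<char> readEngineBlob(const std::string &path)
{
    std::ifstream file(path, std::ios::binary | std::ios::ate); // open at the end to get the size
    if (!file)
        throw std::runtime_error("failed to open engine file: " + path);
    const std::streamsize size = file.tellg();
    file.seekg(0, std::ios::beg);
    std::vector<char> blob(static_cast<size_t>(size));
    if (!file.read(blob.data(), size))
        throw std::runtime_error("failed to read engine file: " + path);
    return blob;
}
// usage sketch: auto blob = readEngineBlob("vit_l_embedding.engine");
//               runtime->deserializeCudaEngine(blob.data(), blob.size());
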
/main.cpp:
--------------------------------------------------------------------------------
1 | #include "sam.h"
2 | #include "export.h"
3 | ///////////////////////////////////////////////////////////////////////////////
4 | using namespace std;
5 | using namespace cv;
6 |
7 | std::shared_ptr<SamEmbedding> eng_0;
8 | std::shared_ptr<SamPromptEncoderAndMaskDecoder> eng_1;
9 | at::Tensor image_embeddings;
10 |
11 | void locator(int event, int x, int y, int flags, void *userdata)
12 | { // function to track mouse movement and click//
13 | if (event == EVENT_LBUTTONDOWN)
14 | { // when left button clicked//
15 | cout << "Left click has been made, Position:(" << x << "," << y << ")" << endl;
16 | // auto res = eng_1->prepareInput(x, y, x - 100, y - 100, x + 100, y + 100, image_embeddings);
17 | auto res = eng_1->prepareInput(x, y, image_embeddings);
18 |         // std::vector<int> mult_pts = {x,y,x-5,y-5,x+5,y+5};
19 | // auto res = eng_1->prepareInput(mult_pts, image_embeddings);
20 | std::cout << "------------------prepareInput: " << res << std::endl;
21 | res = eng_1->infer();
22 | std::cout << "------------------infer: " << res << std::endl;
23 | eng_1->verifyOutput();
24 | // cv::Mat roi;
25 | // eng_1->verifyOutput(roi);
26 | // roi *= 255;
27 | // cv::imshow("img2_", roi);
28 | // cv::waitKey();
29 | std::cout << "------------------verifyOutput: " << std::endl;
30 | }
31 | else if (event == EVENT_RBUTTONDOWN)
32 | { // when right button clicked//
33 | // cout << "Rightclick has been made, Position:(" << x << "," << y << ")" << endl;
34 | }
35 | else if (event == EVENT_MBUTTONDOWN)
36 | { // when middle button clicked//
37 | // cout << "Middleclick has been made, Position:(" << x << "," << y << ")" << endl;
38 | }
39 | else if (event == EVENT_MOUSEMOVE)
40 | { // when mouse pointer moves//
41 | // cout << "Current mouse position:(" << x << "," << y << ")" << endl;
42 | }
43 | }
44 |
45 | #define EMBEDDING
46 | #define SAMPROMPTENCODERANDMASKDECODER
47 | int main(int argc, char const *argv[])
48 | {
49 | ifstream f1("vit_l_embedding.engine");
50 | if (!f1.good())
51 | export_engine_image_encoder("/workspace/segment-anything-tensorrt/data/vit_l_embedding.onnx");
52 |
53 | ifstream f2("sam_onnx_example.engine");
54 | if (!f2.good())
55 | export_engine_prompt_encoder_and_mask_decoder("/workspace/segment-anything-tensorrt/data/sam_onnx_example.onnx");
56 |
57 | #ifdef EMBEDDING
58 | {
59 | // const std::string modelFile = "D:/projects/detections/data/vit_l_embedding.engine";
60 | const std::string modelFile = "vit_l_embedding.engine";
61 | std::ifstream engineFile(modelFile.c_str(), std::ifstream::binary);
62 | assert(engineFile);
63 | // if (!engineFile)
64 | // return;
65 |
66 | int fsize;
67 | engineFile.seekg(0, engineFile.end);
68 | fsize = engineFile.tellg();
69 | engineFile.seekg(0, engineFile.beg);
70 |         std::vector<char> engineData(fsize);
71 | engineFile.read(engineData.data(), fsize);
72 |
73 | if (engineFile)
74 | std::cout << "all characters read successfully." << std::endl;
75 | else
76 | std::cout << "error: only " << engineFile.gcount() << " could be read" << std::endl;
77 | engineFile.close();
78 |
79 |         std::unique_ptr<nvinfer1::IRuntime> runtime(nvinfer1::createInferRuntime(logger));
80 |         std::shared_ptr<nvinfer1::ICudaEngine> mEngine(runtime->deserializeCudaEngine(engineData.data(), fsize, nullptr));
81 | cv::Mat frame = cv::imread("/workspace/segment-anything-tensorrt/data/truck.jpg");
82 | std::cout << frame.size << std::endl;
83 |         eng_0 = std::shared_ptr<SamEmbedding>(new SamEmbedding(std::to_string(1), mEngine, frame));
84 | auto res = eng_0->prepareInput();
85 | std::cout << "------------------prepareInput: " << res << std::endl;
86 | res = eng_0->infer();
87 | std::cout << "------------------infer: " << res << std::endl;
88 | image_embeddings = eng_0->verifyOutput();
89 | std::cout << "------------------verifyOutput: " << std::endl;
90 | }
91 |
92 | #endif
93 |
94 | #ifdef SAMPROMPTENCODERANDMASKDECODER
95 | {
96 | // const std::string modelFile = "D:/projects/detections/data/sam_onnx_example.engine";
97 | const std::string modelFile = "sam_onnx_example.engine";
98 | std::ifstream engineFile(modelFile.c_str(), std::ifstream::binary);
99 | assert(engineFile);
100 | // if (!engineFile)
101 | // return;
102 |
103 | int fsize;
104 | engineFile.seekg(0, engineFile.end);
105 | fsize = engineFile.tellg();
106 | engineFile.seekg(0, engineFile.beg);
107 |         std::vector<char> engineData(fsize);
108 | engineFile.read(engineData.data(), fsize);
109 |
110 | if (engineFile)
111 | std::cout << "all characters read successfully." << std::endl;
112 | else
113 | std::cout << "error: only " << engineFile.gcount() << " could be read" << std::endl;
114 | engineFile.close();
115 |
116 |         std::unique_ptr<nvinfer1::IRuntime> runtime(nvinfer1::createInferRuntime(logger));
117 |         std::shared_ptr<nvinfer1::ICudaEngine> mEngine(runtime->deserializeCudaEngine(engineData.data(), fsize, nullptr));
118 | cv::Mat frame = cv::imread("/workspace/segment-anything-tensorrt/data/truck.jpg");
119 |         eng_1 = std::shared_ptr<SamPromptEncoderAndMaskDecoder>(new SamPromptEncoderAndMaskDecoder(std::to_string(1), mEngine, frame));
120 | // Mat image = imread("D:/projects/detections/data/2.png");//loading image in the matrix//
121 | // namedWindow("img_", 0); // declaring window to show image//
122 | // setMouseCallback("img_", locator, NULL); // Mouse callback function on define window//
123 | // imshow("img_", frame); // showing image on the window//
124 | // waitKey(0); // wait for keystroke//
125 |
126 | auto res = eng_1->prepareInput(100, 100, image_embeddings);
127 | // std::vector mult_pts = {x,y,x-5,y-5,x+5,y+5};
128 | // auto res = eng_1->prepareInput(mult_pts, image_embeddings);
129 | std::cout << "------------------prepareInput: " << res << std::endl;
130 | res = eng_1->infer();
131 | std::cout << "------------------infer: " << res << std::endl;
132 | eng_1->verifyOutput();
133 | std::cout << "-----------------done" << std::endl;
134 | }
135 | #endif
136 | }
--------------------------------------------------------------------------------
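
Note: main.cpp above and sam.h below drive TensorRT through binding indices (getNbBindings, getBindingName, enqueueV2), an API that is deprecated from TensorRT 8.5 onward. A minimal sketch of the name-based replacement, assuming TensorRT >= 8.5 and that one device buffer has already been allocated per tensor name; bindByNameAndRun is a hypothetical helper, not part of this repository:

#include <NvInfer.h>
#include <cuda_runtime_api.h>
#include <map>
#include <string>

// Attach each named device buffer to the execution context and launch with enqueueV3.
static bool bindByNameAndRun(nvinfer1::ICudaEngine &engine,
                             nvinfer1::IExecutionContext &context,
                             const std::map<std::string, void *> &buffers,
                             cudaStream_t stream)
{
    for (int i = 0; i < engine.getNbIOTensors(); ++i)
    {
        const char *name = engine.getIOTensorName(i);
        auto it = buffers.find(name);
        if (it == buffers.end() || !context.setTensorAddress(name, it->second))
            return false;
    }
    return context.enqueueV3(stream); // replaces enqueueV2(bindings, stream, nullptr)
}
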
/main_vim_h.cpp:
--------------------------------------------------------------------------------
1 | #include "sam.h"
2 | #include "export.h"
3 | #include "baseModel.h"
4 | ///////////////////////////////////////////////////////////////////////////////
5 | using namespace std;
6 | using namespace cv;
7 |
8 | std::shared_ptr<SamEmbedding> eng_0;
9 | std::shared_ptr<SamEmbedding2> eng_2;
10 | std::shared_ptr<SamPromptEncoderAndMaskDecoder> eng_1;
11 | at::Tensor image_embeddings;
12 |
13 | void locator(int event, int x, int y, int flags, void *userdata)
14 | { // function to track mouse movement and click//
15 | if (event == EVENT_LBUTTONDOWN)
16 | { // when left button clicked//
17 | cout << "Left click has been made, Position:(" << x << "," << y << ")" << endl;
18 | // auto res = eng_1->prepareInput(x, y, x - 100, y - 100, x + 100, y + 100, image_embeddings);
19 | auto res = eng_1->prepareInput(x, y, image_embeddings);
20 |         // std::vector<int> mult_pts = {x,y,x-5,y-5,x+5,y+5};
21 | // auto res = eng_1->prepareInput(mult_pts, image_embeddings);
22 | std::cout << "------------------prepareInput: " << res << std::endl;
23 | res = eng_1->infer();
24 | std::cout << "------------------infer: " << res << std::endl;
25 | eng_1->verifyOutput();
26 | std::cout << "------------------verifyOutput: " << std::endl;
27 | }
28 | else if (event == EVENT_RBUTTONDOWN)
29 | { // when right button clicked//
30 | // cout << "Rightclick has been made, Position:(" << x << "," << y << ")" << endl;
31 | }
32 | else if (event == EVENT_MBUTTONDOWN)
33 | { // when middle button clicked//
34 | // cout << "Middleclick has been made, Position:(" << x << "," << y << ")" << endl;
35 | }
36 | else if (event == EVENT_MOUSEMOVE)
37 | { // when mouse pointer moves//
38 | // cout << "Current mouse position:(" << x << "," << y << ")" << endl;
39 | }
40 | }
41 |
42 | #define EMBEDDING
43 | #define SAMPROMPTENCODERANDMASKDECODER
44 | int main(int argc, char const *argv[])
45 | {
46 | // BaseModel tmp_model("vit_h_embedding_part_1.engine");
47 |
48 | ifstream f1("vit_h_embedding_part_1.engine");
49 | if (!f1.good())
50 | export_engine_image_encoder("/workspace/segment-anything-tensorrt/data/vit_h_part_1_embedding.onnx", "vit_h_embedding_part_1.engine");
51 |
52 | ifstream f2("vit_h_embedding_part_2.engine");
53 | if (!f2.good())
54 | export_engine_image_encoder("/workspace/segment-anything-tensorrt/data/vit_h_part_2_embedding.onnx", "vit_h_embedding_part_2.engine");
55 |
56 | ifstream f3("sam_onnx_decoder.engine");
57 | if (!f3.good())
58 | export_engine_prompt_encoder_and_mask_decoder("/workspace/segment-anything-tensorrt/data/sam_h_decoder_onnx.onnx", "sam_onnx_decoder.engine");
59 |
60 | #ifdef EMBEDDING
61 | {
62 | // const std::string modelFile = "D:/projects/detections/data/vit_l_embedding.engine";
63 | const std::string modelFile = "vit_h_embedding_part_1.engine";
64 | std::ifstream engineFile(modelFile.c_str(), std::ifstream::binary);
65 | assert(engineFile);
66 | // if (!engineFile)
67 | // return;
68 |
69 | int fsize;
70 | engineFile.seekg(0, engineFile.end);
71 | fsize = engineFile.tellg();
72 | engineFile.seekg(0, engineFile.beg);
73 |         std::vector<char> engineData(fsize);
74 | engineFile.read(engineData.data(), fsize);
75 |
76 | if (engineFile)
77 | std::cout << "all characters read successfully." << std::endl;
78 | else
79 | std::cout << "error: only " << engineFile.gcount() << " could be read" << std::endl;
80 | engineFile.close();
81 |
82 |         std::unique_ptr<nvinfer1::IRuntime> runtime(nvinfer1::createInferRuntime(logger));
83 |         std::shared_ptr<nvinfer1::ICudaEngine> mEngine(runtime->deserializeCudaEngine(engineData.data(), fsize, nullptr));
84 | cv::Mat frame = cv::imread("/workspace/segment-anything-tensorrt/data/truck.jpg");
85 | std::cout << frame.size << std::endl;
86 |         eng_0 = std::shared_ptr<SamEmbedding>(new SamEmbedding(std::to_string(1), mEngine, frame));
87 | auto res = eng_0->prepareInput();
88 | std::cout << "------------------prepareInput: " << res << std::endl;
89 | res = eng_0->infer();
90 | std::cout << "------------------infer: " << res << std::endl;
91 | image_embeddings = eng_0->verifyOutput("image_embeddings_part_1");
92 | std::cout << "------------------verifyOutput: " << std::endl;
93 | {
94 | // const std::string modelFile = "D:/projects/detections/data/vit_l_embedding.engine";
95 | const std::string modelFile = "vit_h_embedding_part_2.engine";
96 | std::ifstream engineFile(modelFile.c_str(), std::ifstream::binary);
97 | assert(engineFile);
98 | // if (!engineFile)
99 | // return;
100 |
101 | int fsize;
102 | engineFile.seekg(0, engineFile.end);
103 | fsize = engineFile.tellg();
104 | engineFile.seekg(0, engineFile.beg);
105 |             std::vector<char> engineData(fsize);
106 | engineFile.read(engineData.data(), fsize);
107 |
108 | if (engineFile)
109 | std::cout << "all characters read successfully." << std::endl;
110 | else
111 | std::cout << "error: only " << engineFile.gcount() << " could be read" << std::endl;
112 | engineFile.close();
113 |
114 |             std::unique_ptr<nvinfer1::IRuntime> runtime(nvinfer1::createInferRuntime(logger));
115 |             std::shared_ptr<nvinfer1::ICudaEngine> mEngine(runtime->deserializeCudaEngine(engineData.data(), fsize, nullptr));
116 |             eng_2 = std::shared_ptr<SamEmbedding2>(new SamEmbedding2(std::to_string(1), mEngine));
117 | auto res = eng_2->prepareInput(image_embeddings);
118 | std::cout << "------------------prepareInput: " << res << std::endl;
119 | res = eng_2->infer();
120 | std::cout << "------------------infer: " << res << std::endl;
121 | image_embeddings = eng_2->verifyOutput();
122 | std::cout << "------------------verifyOutput: " << std::endl;
123 | }
124 | }
125 |
126 | #endif
127 |
128 | #ifdef SAMPROMPTENCODERANDMASKDECODER
129 | {
130 | // const std::string modelFile = "D:/projects/detections/data/sam_onnx_example.engine";
131 | const std::string modelFile = "sam_onnx_decoder.engine";
132 | std::ifstream engineFile(modelFile.c_str(), std::ifstream::binary);
133 | assert(engineFile);
134 | // if (!engineFile)
135 | // return;
136 |
137 | int fsize;
138 | engineFile.seekg(0, engineFile.end);
139 | fsize = engineFile.tellg();
140 | engineFile.seekg(0, engineFile.beg);
141 |         std::vector<char> engineData(fsize);
142 | engineFile.read(engineData.data(), fsize);
143 |
144 | if (engineFile)
145 | std::cout << "all characters read successfully." << std::endl;
146 | else
147 | std::cout << "error: only " << engineFile.gcount() << " could be read" << std::endl;
148 | engineFile.close();
149 |
150 |         std::unique_ptr<nvinfer1::IRuntime> runtime(nvinfer1::createInferRuntime(logger));
151 |         std::shared_ptr<nvinfer1::ICudaEngine> mEngine(runtime->deserializeCudaEngine(engineData.data(), fsize, nullptr));
152 | cv::Mat frame = cv::imread("/workspace/segment-anything-tensorrt/data/truck.jpg");
153 |         eng_1 = std::shared_ptr<SamPromptEncoderAndMaskDecoder>(new SamPromptEncoderAndMaskDecoder(std::to_string(1), mEngine, frame));
154 | // namedWindow("img_", 0); // declaring window to show image//
155 | // setMouseCallback("img_", locator, NULL); // Mouse callback function on define window//
156 | // imshow("img_", frame); // showing image on the window//
157 | // waitKey(0); // wait for keystroke//
158 |
159 | auto res = eng_1->prepareInput(760, 373, image_embeddings);
160 |         // // std::vector<int> mult_pts = {x,y,x-5,y-5,x+5,y+5};
161 | // // auto res = eng_1->prepareInput(mult_pts, image_embeddings);
162 | // std::cout << "------------------prepareInput: " << res << std::endl;
163 | res = eng_1->infer();
164 | // std::cout << "------------------infer: " << res << std::endl;
165 | eng_1->verifyOutput();
166 | // std::cout << "-----------------done" << std::endl;
167 | }
168 | #endif
169 | }
--------------------------------------------------------------------------------
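
Note: the prompt packing done by SamPromptEncoderAndMaskDecoder::prepareInput in sam.h below follows the convention of SAM's exported ONNX decoder: foreground clicks carry label 1, background clicks label 0, a box is passed as two extra points with labels 2 (top-left corner) and 3 (bottom-right corner), and when no box is given a padding point (0, 0) with label -1 is appended; all coordinates are first mapped into the 1024-long-side space by ResizeLongestSide. A small libtorch sketch of a click-plus-box prompt in that layout; packClickAndBox is a hypothetical helper, not part of this repository:

#include <torch/torch.h>
#include <utility>

// Pack one foreground click plus one box into the decoder's point-prompt layout.
// Coordinates are assumed to already be mapped into the 1024-long-side space.
std::pair<at::Tensor, at::Tensor> packClickAndBox(float x, float y,
                                                  float x1, float y1, float x2, float y2)
{
    auto coords = at::tensor({x, y, x1, y1, x2, y2}, at::kFloat).reshape({1, -1, 2}); // [1, 3, 2]
    auto labels = at::tensor({1.0f, 2.0f, 3.0f}, at::kFloat).reshape({1, -1});        // [1, 3]
    return {coords, labels};
}
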
/sam.h:
--------------------------------------------------------------------------------
1 | #ifndef SAM_H
2 | #define SAM_H
3 |
4 | #include "buffers.h"
5 | #include <opencv2/opencv.hpp>
6 | #include <torch/torch.h>
7 | // #include <torch/script.h>
8 | #include <sstream> // note: angle-bracketed header names were stripped in this dump; restored from the usage below
9 | #include "sam_utils.h"
10 |
11 | using namespace torch::indexing;
12 |
13 | class ResizeLongestSide
14 | {
15 | public:
16 | ResizeLongestSide(int target_length);
17 | ~ResizeLongestSide();
18 |
19 |     std::vector<int> get_preprocess_shape(int oldh, int oldw);
20 | at::Tensor apply_coords(at::Tensor boxes, at::IntArrayRef sz);
21 |
22 | public:
23 | int m_target_length;
24 | };
25 |
26 | ResizeLongestSide::ResizeLongestSide(int target_length) : m_target_length(target_length)
27 | {
28 | }
29 |
30 | ResizeLongestSide::~ResizeLongestSide()
31 | {
32 | }
33 |
34 | std::vector<int> ResizeLongestSide::get_preprocess_shape(int oldh, int oldw)
35 | {
36 | float scale = m_target_length * 1.0 / std::max(oldh, oldw);
37 |     int newh = static_cast<int>(oldh * scale + 0.5);
38 |     int neww = static_cast<int>(oldw * scale + 0.5);
39 | std::cout << " newh " << newh << " neww " << neww << std::endl;
40 | std::cout << "at::IntArrayRef{newh, neww}" << at::IntArrayRef{newh, neww} << std::endl;
41 |     return std::vector<int>{newh, neww};
42 | }
43 |
44 | at::Tensor ResizeLongestSide::apply_coords(at::Tensor coords, at::IntArrayRef sz)
45 | {
46 | int old_h = sz[0], old_w = sz[1];
47 | auto new_sz = get_preprocess_shape(old_h, old_w);
48 | int new_h = new_sz[0], new_w = new_sz[1];
49 | coords.index_put_({"...", 0}, coords.index({"...", 0}) * (1.0 * new_w / old_w));
50 | coords.index_put_({"...", 1}, coords.index({"...", 1}) * (1.0 * new_h / old_h));
51 | return coords;
52 | }
53 |
54 | ////////////////////////////////////////////////////////////////////////////////////
55 |
56 | class SamEmbedding
57 | {
58 | public:
59 |     SamEmbedding(std::string bufferName, std::shared_ptr<nvinfer1::ICudaEngine> &engine, cv::Mat im, int width = 640, int height = 640);
60 | ~SamEmbedding();
61 |
62 | int prepareInput();
63 | bool infer();
64 | at::Tensor verifyOutput();
65 | at::Tensor verifyOutput(std::string output_name);
66 |
67 | public:
68 | std::shared_ptr mEngine;
69 | std::unique_ptr context;
70 |
71 | cudaStream_t stream;
72 | cudaEvent_t start, end;
73 |
74 |     std::vector<void *> mDeviceBindings;
75 |     std::map<std::string, std::unique_ptr<algorithms::DeviceBuffer>> mInOut;
76 |     std::vector<float> pad_info;
77 |     std::vector<std::string> names;
78 | cv::Mat frame;
79 | cv::Mat img;
80 | int inp_width = 640;
81 | int inp_height = 640;
82 | std::string mBufferName;
83 | };
84 |
85 | SamEmbedding::SamEmbedding(std::string bufferName, std::shared_ptr<nvinfer1::ICudaEngine> &engine, cv::Mat im, int width, int height) : mBufferName(bufferName), mEngine(engine), frame(im), inp_width(width), inp_height(height)
86 | {
87 |     context = std::unique_ptr<nvinfer1::IExecutionContext>(mEngine->createExecutionContext());
88 | if (!context)
89 | {
90 | std::cerr << "create context error" << std::endl;
91 | }
92 |
93 | CHECK(cudaStreamCreate(&stream));
94 | CHECK(cudaEventCreateWithFlags(&start, cudaEventBlockingSync));
95 | CHECK(cudaEventCreateWithFlags(&end, cudaEventBlockingSync));
96 |
97 | for (int i = 0; i < mEngine->getNbBindings(); i++)
98 | {
99 | auto dims = mEngine->getBindingDimensions(i);
100 | auto tensor_name = mEngine->getBindingName(i);
101 | std::cout << "tensor_name: " << tensor_name << std::endl;
102 | // dims2str(dims);
103 | nvinfer1::DataType type = mEngine->getBindingDataType(i);
104 | // index2srt(type);
105 | int vecDim = mEngine->getBindingVectorizedDim(i);
106 | // std::cout << "vecDim:" << vecDim << std::endl;
107 | if (-1 != vecDim) // i.e., 0 != lgScalarsPerVector
108 | {
109 | int scalarsPerVec = mEngine->getBindingComponentsPerElement(i);
110 | std::cout << "scalarsPerVec" << scalarsPerVec << std::endl;
111 | }
112 |         auto vol = std::accumulate(dims.d, dims.d + dims.nbDims, int64_t{1}, std::multiplies<int64_t>{});
113 |         std::unique_ptr<algorithms::DeviceBuffer> device_buffer{new algorithms::DeviceBuffer(vol, type)};
114 | mDeviceBindings.emplace_back(device_buffer->data());
115 | mInOut[tensor_name] = std::move(device_buffer);
116 | }
117 | }
118 |
119 | SamEmbedding::~SamEmbedding()
120 | {
121 | CHECK(cudaEventDestroy(start));
122 | CHECK(cudaEventDestroy(end));
123 | CHECK(cudaStreamDestroy(stream));
124 | }
125 |
126 | int SamEmbedding::prepareInput()
127 | {
128 | int prompt_embed_dim = 256;
129 | int image_size = 1024;
130 | int vit_patch_size = 16;
131 | int target_length = image_size;
132 | auto pixel_mean = at::tensor({123.675, 116.28, 103.53}, torch::kFloat).view({-1, 1, 1});
133 | auto pixel_std = at::tensor({58.395, 57.12, 57.375}, torch::kFloat).view({-1, 1, 1});
134 | ResizeLongestSide transf(image_size);
135 | int newh, neww;
136 | auto target_size = transf.get_preprocess_shape(frame.rows, frame.cols);
137 | // std::cout << " " << torch::IntArrayRef{newh,neww} << std::endl;
138 | std::cout << "target_size = " << target_size << std::endl;
139 | cv::Mat im_sz;
140 | std::cout << frame.size << std::endl;
141 | cv::resize(frame, im_sz, cv::Size(target_size[1], target_size[0]));
142 | im_sz.convertTo(im_sz, CV_32F, 1.0);
143 | at::Tensor input_image_torch =
144 | at::from_blob(im_sz.data, {im_sz.rows, im_sz.cols, im_sz.channels()})
145 | .permute({2, 0, 1})
146 | .contiguous()
147 | .unsqueeze(0);
148 | input_image_torch = (input_image_torch - pixel_mean) / pixel_std;
149 | int h = input_image_torch.size(2);
150 | int w = input_image_torch.size(3);
151 | int padh = image_size - h;
152 | int padw = image_size - w;
153 | input_image_torch = at::pad(input_image_torch, {0, padw, 0, padh});
154 | auto ret = mInOut["images"]->host2device((void *)(input_image_torch.data_ptr()), true, stream);
155 | return ret;
156 | }
157 |
158 | bool SamEmbedding::infer()
159 | {
160 | CHECK(cudaEventRecord(start, stream));
161 | auto ret = context->enqueueV2(mDeviceBindings.data(), stream, nullptr);
162 | return ret;
163 | }
164 |
165 | at::Tensor SamEmbedding::verifyOutput()
166 | {
167 | float ms{0.0f};
168 | CHECK(cudaEventRecord(end, stream));
169 | CHECK(cudaEventSynchronize(end));
170 | CHECK(cudaEventElapsedTime(&ms, start, end));
171 |
172 | auto dim0 = mEngine->getTensorShape("image_embeddings");
173 |
174 | // dims2str(dim0);
175 | // dims2str(dim1);
176 | at::Tensor preds;
177 | preds = at::zeros({dim0.d[0], dim0.d[1], dim0.d[2], dim0.d[3]}, at::kFloat);
178 | mInOut["image_embeddings"]->device2host((void *)(preds.data_ptr()), stream);
179 |
180 | // Wait for the work in the stream to complete
181 | CHECK(cudaStreamSynchronize(stream));
182 | // torch::save({preds}, "preds.pt");
183 | // cv::FileStorage storage("1.yaml", cv::FileStorage::WRITE);
184 | // storage << "image_embeddings" << points3dmatrix;
185 | return preds;
186 | }
187 |
188 | at::Tensor SamEmbedding::verifyOutput(std::string output_name)
189 | {
190 | float ms{0.0f};
191 | CHECK(cudaEventRecord(end, stream));
192 | CHECK(cudaEventSynchronize(end));
193 | CHECK(cudaEventElapsedTime(&ms, start, end));
194 |
195 | auto dim0 = mEngine->getTensorShape(output_name.c_str());
196 |
197 | // dims2str(dim0);
198 | // dims2str(dim1);
199 | at::Tensor preds;
200 | preds = at::zeros({dim0.d[0], dim0.d[1], dim0.d[2], dim0.d[3]}, at::kFloat);
201 | mInOut[output_name]->device2host((void *)(preds.data_ptr()), stream);
202 |
203 | // Wait for the work in the stream to complete
204 | CHECK(cudaStreamSynchronize(stream));
205 | // torch::save({preds}, "preds.pt");
206 | // cv::FileStorage storage("1.yaml", cv::FileStorage::WRITE);
207 | // storage << "image_embeddings" << points3dmatrix;
208 | return preds;
209 | }
210 |
211 | ///////////////////////////////////////////////////
212 |
213 | class SamEmbedding2
214 | {
215 | public:
216 |     SamEmbedding2(std::string bufferName, std::shared_ptr<nvinfer1::ICudaEngine> &engine);
217 | ~SamEmbedding2();
218 |
219 | int prepareInput(at::Tensor input_image_torch);
220 | bool infer();
221 | at::Tensor verifyOutput();
222 |
223 | public:
224 |     std::shared_ptr<nvinfer1::ICudaEngine> mEngine;
225 |     std::unique_ptr<nvinfer1::IExecutionContext> context;
226 |
227 | cudaStream_t stream;
228 | cudaEvent_t start, end;
229 |
230 |     std::vector<void *> mDeviceBindings;
231 |     std::map<std::string, std::unique_ptr<algorithms::DeviceBuffer>> mInOut;
232 |     std::vector<float> pad_info;
233 |     std::vector<std::string> names;
234 | std::string mBufferName;
235 | };
236 |
237 | SamEmbedding2::SamEmbedding2(std::string bufferName, std::shared_ptr<nvinfer1::ICudaEngine> &engine) :
238 | mBufferName(bufferName), mEngine(engine)
239 | {
240 |     context = std::unique_ptr<nvinfer1::IExecutionContext>(mEngine->createExecutionContext());
241 | if (!context)
242 | {
243 | std::cerr << "create context error" << std::endl;
244 | }
245 |
246 | CHECK(cudaStreamCreate(&stream));
247 | CHECK(cudaEventCreateWithFlags(&start, cudaEventBlockingSync));
248 | CHECK(cudaEventCreateWithFlags(&end, cudaEventBlockingSync));
249 |
250 | for (int i = 0; i < mEngine->getNbBindings(); i++)
251 | {
252 | auto dims = mEngine->getBindingDimensions(i);
253 | auto tensor_name = mEngine->getBindingName(i);
254 | std::cout << "tensor_name: " << tensor_name << std::endl;
255 | // dims2str(dims);
256 | nvinfer1::DataType type = mEngine->getBindingDataType(i);
257 | // index2srt(type);
258 | int vecDim = mEngine->getBindingVectorizedDim(i);
259 | // std::cout << "vecDim:" << vecDim << std::endl;
260 | if (-1 != vecDim) // i.e., 0 != lgScalarsPerVector
261 | {
262 | int scalarsPerVec = mEngine->getBindingComponentsPerElement(i);
263 | std::cout << "scalarsPerVec" << scalarsPerVec << std::endl;
264 | }
265 |         auto vol = std::accumulate(dims.d, dims.d + dims.nbDims, int64_t{1}, std::multiplies<int64_t>{});
266 |         std::unique_ptr<algorithms::DeviceBuffer> device_buffer{new algorithms::DeviceBuffer(vol, type)};
267 | mDeviceBindings.emplace_back(device_buffer->data());
268 | mInOut[tensor_name] = std::move(device_buffer);
269 | }
270 | }
271 |
272 | SamEmbedding2::~SamEmbedding2()
273 | {
274 | CHECK(cudaEventDestroy(start));
275 | CHECK(cudaEventDestroy(end));
276 | CHECK(cudaStreamDestroy(stream));
277 | }
278 |
279 | int SamEmbedding2::prepareInput(at::Tensor input_image_torch)
280 | {
281 | auto ret = mInOut["image_embeddings_part_1"]->host2device((void *)(input_image_torch.data_ptr()), false, stream);
282 | return ret;
283 | }
284 |
285 | bool SamEmbedding2::infer()
286 | {
287 | CHECK(cudaEventRecord(start, stream));
288 | auto ret = context->enqueueV2(mDeviceBindings.data(), stream, nullptr);
289 | return ret;
290 | }
291 |
292 | at::Tensor SamEmbedding2::verifyOutput()
293 | {
294 | float ms{0.0f};
295 | CHECK(cudaEventRecord(end, stream));
296 | CHECK(cudaEventSynchronize(end));
297 | CHECK(cudaEventElapsedTime(&ms, start, end));
298 |
299 | auto dim0 = mEngine->getTensorShape("image_embeddings_part_2");
300 |
301 | // dims2str(dim0);
302 | // dims2str(dim1);
303 | at::Tensor preds;
304 | preds = at::zeros({dim0.d[0], dim0.d[1], dim0.d[2], dim0.d[3]}, at::kFloat);
305 | mInOut["image_embeddings_part_2"]->device2host((void *)(preds.data_ptr()), stream);
306 |
307 | // Wait for the work in the stream to complete
308 | CHECK(cudaStreamSynchronize(stream));
309 | // torch::save({preds}, "preds.pt");
310 | // cv::FileStorage storage("1.yaml", cv::FileStorage::WRITE);
311 | // storage << "image_embeddings" << points3dmatrix;
312 | return preds;
313 | }
314 |
315 | ///////////////////////////////////////////////////
316 |
317 |
318 |
319 | class SamPromptEncoderAndMaskDecoder
320 | {
321 | public:
322 |     SamPromptEncoderAndMaskDecoder(std::string bufferName, std::shared_ptr<nvinfer1::ICudaEngine> &engine, cv::Mat im, int width = 640, int height = 640);
323 | ~SamPromptEncoderAndMaskDecoder();
324 |
325 | int prepareInput(int x, int y, at::Tensor image_embeddings);
326 | int prepareInput(int x, int y, int x1, int y1, int x2, int y2, at::Tensor image_embeddings);
327 |     int prepareInput(std::vector<int> mult_pts, at::Tensor image_embeddings);
328 | bool infer();
329 | int verifyOutput();
330 | int verifyOutput(cv::Mat& roi);
331 | at::Tensor generator_colors(int num);
332 |
333 |     template <typename Type>
334 | Type string2Num(const std::string &str);
335 |
336 | at::Tensor plot_masks(at::Tensor masks, at::Tensor im_gpu, float alpha);
337 |
338 | public:
339 |     std::shared_ptr<nvinfer1::ICudaEngine> mEngine;
340 |     std::unique_ptr<nvinfer1::IExecutionContext> context;
341 |
342 | cudaStream_t stream;
343 | cudaEvent_t start, end;
344 |
345 |     std::vector<void *> mDeviceBindings;
346 |     std::map<std::string, std::unique_ptr<algorithms::DeviceBuffer>> mInOut;
347 |     std::vector<float> pad_info;
348 |     std::vector<std::string> names;
349 | cv::Mat frame;
350 | cv::Mat img;
351 | int inp_width = 640;
352 | int inp_height = 640;
353 | std::string mBufferName;
354 | };
355 |
356 | SamPromptEncoderAndMaskDecoder::SamPromptEncoderAndMaskDecoder(std::string bufferName, std::shared_ptr<nvinfer1::ICudaEngine> &engine, cv::Mat im, int width, int height) : mBufferName(bufferName), mEngine(engine), frame(im), inp_width(width), inp_height(height)
357 | {
358 |     context = std::unique_ptr<nvinfer1::IExecutionContext>(mEngine->createExecutionContext());
359 | if (!context)
360 | {
361 | std::cerr << "create context error" << std::endl;
362 | }
363 |     // set dims for the dynamic input named "point_coords"
364 |     context->setBindingDimensions(1, nvinfer1::Dims3(1, 2, 2));
365 |     // set dims for the dynamic input named "point_labels"
366 |     context->setBindingDimensions(2, nvinfer1::Dims2(1, 2));
367 |     // set dims for the "orig_im_size" input (not used by this export)
368 |     // context->setBindingDimensions(5, nvinfer1::Dims2(frame.rows,frame.cols));
369 | CHECK(cudaStreamCreate(&stream));
370 | CHECK(cudaEventCreateWithFlags(&start, cudaEventBlockingSync));
371 | CHECK(cudaEventCreateWithFlags(&end, cudaEventBlockingSync));
372 |
373 | int nbopts = mEngine->getNbOptimizationProfiles();
374 | // std::cout << "nboopts: " << nbopts << std::endl;
375 | for (int i = 0; i < mEngine->getNbBindings(); i++)
376 | {
377 | // auto dims = mEngine->getBindingDimensions(i);
378 | auto tensor_name = mEngine->getBindingName(i);
379 | // std::cout << "tensor_name: " << tensor_name << std::endl;
380 | auto dims = context->getBindingDimensions(i);
381 | // dims2str(dims);
382 | nvinfer1::DataType type = mEngine->getBindingDataType(i);
383 | // index2srt(type);
384 |         auto vol = std::accumulate(dims.d, dims.d + dims.nbDims, int64_t{1}, std::multiplies<int64_t>{});
385 |         std::unique_ptr<algorithms::DeviceBuffer> device_buffer{new algorithms::DeviceBuffer(vol, type)};
386 | mDeviceBindings.emplace_back(device_buffer->data());
387 | mInOut[tensor_name] = std::move(device_buffer);
388 | }
389 | }
390 |
391 | SamPromptEncoderAndMaskDecoder::~SamPromptEncoderAndMaskDecoder()
392 | {
393 | CHECK(cudaEventDestroy(start));
394 | CHECK(cudaEventDestroy(end));
395 | CHECK(cudaStreamDestroy(stream));
396 | }
397 |
398 | int SamPromptEncoderAndMaskDecoder::prepareInput(int x, int y, at::Tensor image_embeddings)
399 | {
400 | // at::Tensor image_embeddings;
401 |
402 | // torch::load(image_embeddings, "preds.pt");
403 | // std::cout << image_embeddings.sizes() << std::endl;
404 | int image_size = 1024;
405 | ResizeLongestSide transf(image_size);
406 |
407 | auto input_point = at::tensor({x, y}, at::kFloat).reshape({-1,2});
408 | auto input_label = at::tensor({1}, at::kFloat);
409 |
410 | auto trt_coord = at::concatenate({input_point, at::tensor({0, 0}, at::kFloat).unsqueeze(0)}, 0).unsqueeze(0);
411 | auto trt_label = at::concatenate({input_label, at::tensor({-1}, at::kFloat)}, 0).unsqueeze(0);
412 | // auto trt_coord = at::concatenate({input_point, at::tensor({x-100, y-100, x+100, y+100}, at::kFloat).reshape({-1,2})}, 0).unsqueeze(0);
413 | // auto trt_label = at::concatenate({input_label, at::tensor({2,3}, at::kFloat)}, 0).unsqueeze(0);
414 | trt_coord = transf.apply_coords(trt_coord, {frame.rows, frame.cols});
415 | // std::cout << "trt_coord " << trt_coord.sizes() << std::endl;
416 | auto trt_mask_input = at::zeros({1, 1, 256, 256}, at::kFloat);
417 | auto trt_has_mask_input = at::zeros(1, at::kFloat);
418 |
419 | context->setBindingDimensions(1, nvinfer1::Dims3(trt_coord.size(0), trt_coord.size(1), trt_coord.size(2)));
420 |     // set dims for the dynamic input named "point_labels"
421 | context->setBindingDimensions(2, nvinfer1::Dims2(trt_coord.size(0), trt_coord.size(1)));
422 | int nbopts = mEngine->getNbOptimizationProfiles();
423 |     std::cout << "nbopts: " << nbopts << std::endl;
424 | for (int i = 0; i < mEngine->getNbBindings(); i++)
425 | {
426 | // auto dims = mEngine->getBindingDimensions(i);
427 | auto tensor_name = mEngine->getBindingName(i);
428 | std::cout << "tensor_name: " << tensor_name << std::endl;
429 | auto dims = context->getBindingDimensions(i);
430 | dims2str(dims);
431 | nvinfer1::DataType type = mEngine->getBindingDataType(i);
432 | index2srt(type);
433 |         auto vol = std::accumulate(dims.d, dims.d + dims.nbDims, int64_t{1}, std::multiplies<int64_t>{});
434 |
435 | mInOut[tensor_name]->resize(dims);
436 | }
437 |
438 | CHECK(mInOut["image_embeddings"]->host2device((void *)(image_embeddings.data_ptr()), true, stream));
439 | CHECK(cudaStreamSynchronize(stream));
440 | CHECK(mInOut["point_coords"]->host2device((void *)(trt_coord.data_ptr()), true, stream));
441 | CHECK(cudaStreamSynchronize(stream));
442 | CHECK(mInOut["point_labels"]->host2device((void *)(trt_label.data_ptr()), true, stream));
443 | CHECK(cudaStreamSynchronize(stream));
444 | CHECK(mInOut["mask_input"]->host2device((void *)(trt_mask_input.data_ptr()), true, stream));
445 | CHECK(cudaStreamSynchronize(stream));
446 | CHECK(mInOut["has_mask_input"]->host2device((void *)(trt_has_mask_input.data_ptr()), true, stream));
447 | CHECK(cudaStreamSynchronize(stream));
448 | return 0;
449 | }
450 |
451 | int SamPromptEncoderAndMaskDecoder::prepareInput(int x, int y, int x1, int y1, int x2, int y2, at::Tensor image_embeddings)
452 | {
453 | // at::Tensor image_embeddings;
454 |
455 | // torch::load(image_embeddings, "preds.pt");
456 | // std::cout << image_embeddings.sizes() << std::endl;
457 | int image_size = 1024;
458 | ResizeLongestSide transf(image_size);
459 |
460 | auto input_point = at::tensor({x, y}, at::kFloat).reshape({-1,2});
461 | auto input_label = at::tensor({1}, at::kFloat);
462 |
463 | auto trt_coord = at::concatenate({input_point, at::tensor({x1, y1, x2, y2}, at::kFloat).reshape({-1,2})}, 0).unsqueeze(0);
464 | auto trt_label = at::concatenate({input_label, at::tensor({2,3}, at::kFloat)}, 0).unsqueeze(0);
465 | trt_coord = transf.apply_coords(trt_coord, {frame.rows, frame.cols});
466 | // std::cout << "trt_coord " << trt_coord.sizes() << std::endl;
467 | auto trt_mask_input = at::zeros({1, 1, 256, 256}, at::kFloat);
468 | auto trt_has_mask_input = at::zeros(1, at::kFloat);
469 |
470 | context->setBindingDimensions(1, nvinfer1::Dims3(trt_coord.size(0), trt_coord.size(1), trt_coord.size(2)));
471 |     // set dims for the dynamic input named "point_labels"
472 | context->setBindingDimensions(2, nvinfer1::Dims2(trt_coord.size(0), trt_coord.size(1)));
473 |
474 | for (int i = 0; i < mEngine->getNbBindings(); i++)
475 | {
476 | // auto dims = mEngine->getBindingDimensions(i);
477 | auto tensor_name = mEngine->getBindingName(i);
478 | std::cout << "tensor_name: " << tensor_name << std::endl;
479 | auto dims = context->getBindingDimensions(i);
480 | dims2str(dims);
481 | nvinfer1::DataType type = mEngine->getBindingDataType(i);
482 | index2srt(type);
483 |         auto vol = std::accumulate(dims.d, dims.d + dims.nbDims, int64_t{1}, std::multiplies<int64_t>{});
484 |
485 | mInOut[tensor_name]->resize(dims);
486 | }
487 |
488 | CHECK(mInOut["image_embeddings"]->host2device((void *)(image_embeddings.data_ptr()), true, stream));
489 | CHECK(cudaStreamSynchronize(stream));
490 | CHECK(mInOut["point_coords"]->host2device((void *)(trt_coord.data_ptr()), true, stream));
491 | CHECK(cudaStreamSynchronize(stream));
492 | CHECK(mInOut["point_labels"]->host2device((void *)(trt_label.data_ptr()), true, stream));
493 | CHECK(cudaStreamSynchronize(stream));
494 | CHECK(mInOut["mask_input"]->host2device((void *)(trt_mask_input.data_ptr()), true, stream));
495 | CHECK(cudaStreamSynchronize(stream));
496 | CHECK(mInOut["has_mask_input"]->host2device((void *)(trt_has_mask_input.data_ptr()), true, stream));
497 | CHECK(cudaStreamSynchronize(stream));
498 | return 0;
499 | }
500 |
501 | int SamPromptEncoderAndMaskDecoder::prepareInput(std::vector<int> mult_pts, at::Tensor image_embeddings)
502 | {
503 | // at::Tensor image_embeddings;
504 |
505 | // torch::load(image_embeddings, "preds.pt");
506 | // std::cout << image_embeddings.sizes() << std::endl;
507 | int image_size = 1024;
508 | ResizeLongestSide transf(image_size);
509 |
510 | auto input_point = at::tensor(mult_pts, at::kFloat).reshape({-1,2});
511 | std::cout << input_point << std::endl;
512 | auto input_label = at::ones({int(mult_pts.size() / 2)}, at::kFloat);
513 | std::cout << input_label << std::endl;
514 |
515 | auto trt_coord = at::concatenate({input_point, at::tensor({0, 0}, at::kFloat).unsqueeze(0)}, 0).unsqueeze(0);
516 | auto trt_label = at::concatenate({input_label, at::tensor({-1}, at::kFloat)}, 0).unsqueeze(0);
517 | trt_coord = transf.apply_coords(trt_coord, {frame.rows, frame.cols});
518 | // std::cout << "trt_coord " << trt_coord.sizes() << std::endl;
519 | auto trt_mask_input = at::zeros({1, 1, 256, 256}, at::kFloat);
520 | auto trt_has_mask_input = at::zeros(1, at::kFloat);
521 |
522 | context->setBindingDimensions(1, nvinfer1::Dims3(trt_coord.size(0), trt_coord.size(1), trt_coord.size(2)));
523 |     // set dims for the dynamic input named "point_labels"
524 | context->setBindingDimensions(2, nvinfer1::Dims2(trt_coord.size(0), trt_coord.size(1)));
525 |
526 | for (int i = 0; i < mEngine->getNbBindings(); i++)
527 | {
528 | // auto dims = mEngine->getBindingDimensions(i);
529 | auto tensor_name = mEngine->getBindingName(i);
530 | std::cout << "tensor_name: " << tensor_name << std::endl;
531 | auto dims = context->getBindingDimensions(i);
532 | dims2str(dims);
533 | nvinfer1::DataType type = mEngine->getBindingDataType(i);
534 | index2srt(type);
535 |         auto vol = std::accumulate(dims.d, dims.d + dims.nbDims, int64_t{1}, std::multiplies<int64_t>{});
536 |
537 | mInOut[tensor_name]->resize(dims);
538 | }
539 |
540 | CHECK(mInOut["image_embeddings"]->host2device((void *)(image_embeddings.data_ptr()), true, stream));
541 | CHECK(cudaStreamSynchronize(stream));
542 | CHECK(mInOut["point_coords"]->host2device((void *)(trt_coord.data_ptr()), true, stream));
543 | CHECK(cudaStreamSynchronize(stream));
544 | CHECK(mInOut["point_labels"]->host2device((void *)(trt_label.data_ptr()), true, stream));
545 | CHECK(cudaStreamSynchronize(stream));
546 | CHECK(mInOut["mask_input"]->host2device((void *)(trt_mask_input.data_ptr()), true, stream));
547 | CHECK(cudaStreamSynchronize(stream));
548 | CHECK(mInOut["has_mask_input"]->host2device((void *)(trt_has_mask_input.data_ptr()), true, stream));
549 | CHECK(cudaStreamSynchronize(stream));
550 | return 0;
551 | }
552 |
553 | bool SamPromptEncoderAndMaskDecoder::infer()
554 | {
555 | CHECK(cudaEventRecord(start, stream));
556 | auto ret = context->enqueueV2(mDeviceBindings.data(), stream, nullptr);
557 | return ret;
558 | }
559 |
560 | int SamPromptEncoderAndMaskDecoder::verifyOutput()
561 | {
562 | float ms{0.0f};
563 | CHECK(cudaEventRecord(end, stream));
564 | CHECK(cudaEventSynchronize(end));
565 | CHECK(cudaEventElapsedTime(&ms, start, end));
566 |
567 | auto dim0 = mEngine->getTensorShape("masks");
568 | auto dim1 = mEngine->getTensorShape("scores");
569 | // dims2str(dim0);
570 | // dims2str(dim1);
571 | at::Tensor masks;
572 | masks = at::zeros({dim0.d[0], dim0.d[1], dim0.d[2], dim0.d[3]}, at::kFloat);
573 | mInOut["masks"]->device2host((void *)(masks.data_ptr()), stream);
574 | // Wait for the work in the stream to complete
575 | CHECK(cudaStreamSynchronize(stream));
576 |
577 | int longest_side = 1024;
578 |
579 | namespace F = torch::nn::functional;
580 |     masks = F::interpolate(masks, F::InterpolateFuncOptions().size(std::vector<int64_t>({longest_side, longest_side})).mode(torch::kBilinear).align_corners(false));
581 | // at::IntArrayRef input_image_size{frame.rows, frame.cols};
582 | ResizeLongestSide transf(longest_side);
583 | auto target_size = transf.get_preprocess_shape(frame.rows, frame.cols);
584 | masks = masks.index({"...", Slice(None, target_size[0]), Slice(None, target_size[1])});
585 |
586 |     masks = F::interpolate(masks, F::InterpolateFuncOptions().size(std::vector<int64_t>({frame.rows, frame.cols})).mode(torch::kBilinear).align_corners(false));
587 | std::cout << "masks: " << masks.sizes() << std::endl;
588 |
589 | at::Tensor iou_predictions;
590 | iou_predictions = at::zeros({dim0.d[0], dim0.d[1]}, at::kFloat);
591 | mInOut["scores"]->device2host((void *)(iou_predictions.data_ptr()), stream);
592 | // Wait for the work in the stream to complete
593 | CHECK(cudaStreamSynchronize(stream));
594 |
595 | torch::DeviceType device_type;
596 | if (torch::cuda::is_available())
597 | {
598 |         std::cout << "CUDA available! Running on GPU." << std::endl;
599 | device_type = torch::kCUDA;
600 | }
601 | else
602 | {
603 |         std::cout << "Running on CPU." << std::endl;
604 | device_type = torch::kCPU;
605 | }
606 |
607 | torch::Device device(device_type);
608 | masks = masks.gt(0.) * 1.0;
609 | std::cout << "max " << masks.max() << std::endl;
610 | // masks = masks.sigmoid();
611 | std::cout << "masks: " << masks.sizes() << std::endl;
612 | masks = masks.to(device);
613 | std::cout << "iou_predictions: " << iou_predictions << std::endl;
614 | cv::Mat img;
615 | // cv::Mat frame = cv::imread("D:/projects/detections/data/truck.jpg");
616 | frame.convertTo(img, CV_32F, 1.0 / 255);
617 | at::Tensor im_gpu =
618 | at::from_blob(img.data, {img.rows, img.cols, img.channels()})
619 | .permute({2, 0, 1})
620 | .contiguous()
621 | .to(device);
622 | auto results = plot_masks(masks, im_gpu, 0.5);
623 | auto t_img = results.to(torch::kCPU).clamp(0, 255).to(torch::kU8);
624 |
625 | auto img_ = cv::Mat(t_img.size(0), t_img.size(1), CV_8UC3, t_img.data_ptr());
626 |     std::cout << "blending masks onto the frame" << std::endl;
627 | cv::cvtColor(img_, img_, cv::COLOR_RGB2BGR);
628 | cv::imwrite("img1111.jpg",img_);
629 | // cv::imshow("img_", img_);
630 | return 0;
631 | }
632 |
633 | int SamPromptEncoderAndMaskDecoder::verifyOutput(cv::Mat& roi)
634 | {
635 | float ms{0.0f};
636 | CHECK(cudaEventRecord(end, stream));
637 | CHECK(cudaEventSynchronize(end));
638 | CHECK(cudaEventElapsedTime(&ms, start, end));
639 |
640 | auto dim0 = mEngine->getTensorShape("masks");
641 | auto dim1 = mEngine->getTensorShape("scores");
642 | // dims2str(dim0);
643 | // dims2str(dim1);
644 | at::Tensor masks;
645 | masks = at::zeros({dim0.d[0], dim0.d[1], dim0.d[2], dim0.d[3]}, at::kFloat);
646 | mInOut["masks"]->device2host((void *)(masks.data_ptr()), stream);
647 | // Wait for the work in the stream to complete
648 | CHECK(cudaStreamSynchronize(stream));
649 |
650 | int longest_side = 1024;
651 |
652 | namespace F = torch::nn::functional;
653 | masks = F::interpolate(masks, F::InterpolateFuncOptions().size(std::vector<int64_t>{longest_side, longest_side}).mode(torch::kBilinear).align_corners(false));
654 | // at::IntArrayRef input_image_size{frame.rows, frame.cols};
655 | ResizeLongestSide transf(longest_side);
656 | auto target_size = transf.get_preprocess_shape(frame.rows, frame.cols);
657 | masks = masks.index({"...", Slice(None, target_size[0]), Slice(None, target_size[1])});
658 |
659 | masks = F::interpolate(masks, F::InterpolateFuncOptions().size(std::vector<int64_t>{frame.rows, frame.cols}).mode(torch::kBilinear).align_corners(false));
660 | // std::cout << "masks: " << masks.sizes() << std::endl;
661 |
662 | at::Tensor iou_predictions;
663 | iou_predictions = at::zeros({dim0.d[0], dim0.d[1]}, at::kFloat);
664 | mInOut["scores"]->device2host((void *)(iou_predictions.data_ptr()), stream);
665 | // Wait for the work in the stream to complete
666 | CHECK(cudaStreamSynchronize(stream));
667 |
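// This overload skips the visualisation step: it binarises the mask and
// copies it into `roi` as a single-channel CV_8U image at the original
// frame resolution, leaving any drawing or cropping to the caller.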
668 | masks = masks.gt(0.) * 1.0;
669 | masks = masks.squeeze(0).squeeze(0);
670 | masks = masks.to(torch::kCPU).to(torch::kU8);
671 | std::cout << "masks: " << masks.sizes() << std::endl;
672 | auto roi_ = cv::Mat(masks.size(0), masks.size(1), CV_8U, masks.data_ptr());
673 | roi_.copyTo(roi);
674 |
675 | return 0;
676 | }
677 |
678 | /*
679 | return [r g b] * n
680 | */
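// e.g. generator_colors(2) yields a flat int32 tensor
// { r0, g0, b0, r1, g1, b1 } parsed from the first two hex strings below.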
681 | at::Tensor SamPromptEncoderAndMaskDecoder::generator_colors(int num)
682 | {
683 |
684 | std::vector<std::string> hexs = {"FF37C7", "FF9D97", "FF701F", "FFB21D", "CFD231", "48F90A", "92CC17", "3DDB86", "1A9334", "00D4BB",
685 | "2C99A8", "00C2FF", "344593", "6473FF", "0018EC", "", "520085", "CB38FF", "FF95C8", "FF3838"}; // note: one entry is empty and would fail to parse if ever selected
686 |
687 | std::vector<int> tmp;
688 | for (int i = 0; i < num; ++i)
689 | {
690 | int r = string2Num<int>(hexs[i].substr(0, 2));
691 | // std::cout << r << std::endl;
692 | int g = string2Num<int>(hexs[i].substr(2, 2));
693 | // std::cout << g << std::endl;
694 | int b = string2Num<int>(hexs[i].substr(4, 2));
695 | // std::cout << b << std::endl;
696 | tmp.emplace_back(r);
697 | tmp.emplace_back(g);
698 | tmp.emplace_back(b);
699 | }
700 | return at::from_blob(tmp.data(), {(int)tmp.size()}, at::TensorOptions(at::kInt)).clone(); // clone so the tensor owns its data after tmp is destroyed
701 | }
702 |
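// Parses a hexadecimal string into a number, e.g. string2Num<int>("FF") == 255.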
703 | template <typename Type>
704 | Type SamPromptEncoderAndMaskDecoder::string2Num(const std::string &str)
705 | {
706 | std::istringstream iss(str);
707 | Type num{}; // value-initialise so a failed parse (e.g. an empty string) yields 0
708 | iss >> std::hex >> num;
709 | return num;
710 | }
711 |
712 | /*
713 | Plot all masks onto the image in a single pass.
714 | Args:
715 | masks (tensor): predicted masks on the GPU, shape: [n, 1, h, w]
716 | im_gpu (tensor): image on the GPU, shape: [3, h, w], range: [0, 1]
717 | alpha (float): mask transparency: 0.0 fully transparent, 1.0 opaque
718 | Colors are generated internally via generator_colors(n).
719 | */
720 |
721 | at::Tensor SamPromptEncoderAndMaskDecoder::plot_masks(at::Tensor masks, at::Tensor im_gpu, float alpha)
722 | {
723 | int n = masks.size(0);
724 | auto colors = generator_colors(n);
725 | colors = colors.to(masks.device()).to(at::kFloat).div(255).reshape({-1, 3}).unsqueeze(1).unsqueeze(2);
726 | // std::cout << "colors: " << colors.sizes() << std::endl;
727 | masks = masks.permute({0, 2, 3, 1}).contiguous();
728 | // std::cout << "masks: " << masks.sizes() << std::endl;
729 | auto masks_color = masks * (colors * alpha);
730 | // std::cout << "masks_color: " << masks_color.sizes() << std::endl;
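// The cumulative product over the mask dimension implements front-to-back
// alpha compositing: each mask is attenuated by the transparency of the
// masks drawn before it.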
731 | auto inv_alph_masks = (1 - masks * alpha);
732 | inv_alph_masks = inv_alph_masks.cumprod(0);
733 | // std::cout << "inv_alph_masks: " << inv_alph_masks.sizes() << std::endl;
734 |
735 | auto mcs = masks_color * inv_alph_masks;
736 | mcs = mcs.sum(0) * 2;
737 | // std::cout << "mcs: " << mcs.sizes() << std::endl;
738 | im_gpu = im_gpu.flip({0});
739 | // std::cout << "im_gpu: " << im_gpu.sizes() << std::endl;
740 | im_gpu = im_gpu.permute({1, 2, 0}).contiguous();
741 | // std::cout << "im_gpu: " << im_gpu.sizes() << std::endl;
742 | im_gpu = im_gpu * inv_alph_masks[-1] + mcs;
743 | // std::cout << "im_gpu: " << im_gpu.sizes() << std::endl;
744 | auto im_mask = (im_gpu * 255);
745 | return im_mask;
746 | }
747 | #endif
--------------------------------------------------------------------------------
/sam_utils.h:
--------------------------------------------------------------------------------
1 | #ifndef SAM_UTILS_H
2 | #define SAM_UTILS_H
3 |
4 | #include <iostream>
5 | #include <NvInfer.h>
6 | #include <NvOnnxParser.h>
7 |
8 |
9 | using namespace nvinfer1;
10 | using namespace nvonnxparser;
11 |
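// CHECK wraps a status-returning CUDA call and aborts on any non-zero code, e.g.
//   CHECK(cudaStreamCreate(&stream));
//   CHECK(cudaMemcpyAsync(dst, src, bytes, cudaMemcpyHostToDevice, stream));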
12 | #undef CHECK
13 | #define CHECK(status) \
14 | do \
15 | { \
16 | auto ret = (status); \
17 | if (ret != 0) \
18 | { \
19 | std::cerr << "Cuda failure: " << ret << std::endl; \
20 | abort(); \
21 | } \
22 | } while (0)
23 |
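// Prints the human-readable name of an nvinfer1::DataType.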
24 | void index2srt(nvinfer1::DataType dataType)
25 | {
26 | switch (dataType)
27 | {
28 | case nvinfer1::DataType::kFLOAT:
29 | std::cout << "nvinfer1::DataType::kFLOAT" << std::endl;
30 | break;
31 | case nvinfer1::DataType::kHALF:
32 | std::cout << "nvinfer1::DataType::kHALF" << std::endl;
33 | break;
34 | case nvinfer1::DataType::kINT8:
35 | std::cout << "nvinfer1::DataType::kINT8" << std::endl;
36 | break;
37 | case nvinfer1::DataType::kINT32:
38 | std::cout << "nvinfer1::DataType::kINT32" << std::endl;
39 | break;
40 | case nvinfer1::DataType::kBOOL:
41 | std::cout << "nvinfer1::DataType::kBOOL" << std::endl;
42 | break;
43 | case nvinfer1::DataType::kUINT8:
44 | std::cout << "nvinfer1::DataType::kUINT8" << std::endl;
45 | break;
46 |
47 | default:
48 | break;
49 | }
50 | }
51 |
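// Prints an nvinfer1::Dims as e.g. "[1, 3, 1024, 1024]".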
52 | void dims2str(nvinfer1::Dims dims)
53 | {
54 | std::string o_s("[");
55 | for (int32_t i = 0; i < dims.nbDims; i++)
56 | {
57 | if (i > 0)
58 | o_s += ", ";
59 | o_s += std::to_string(dims.d[i]);
60 | }
61 | o_s += "]";
62 | std::cout << o_s << std::endl;
63 | }
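// Minimal nvinfer1::ILogger that prints warnings and errors and suppresses
// info-level messages; a single global instance `logger` is defined below.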
64 | class Logger : public nvinfer1::ILogger
65 | {
66 | void log(Severity severity, const char *msg) noexcept override
67 | {
68 | // suppress info-level messages
69 | if (severity <= Severity::kWARNING)
70 | std::cout << msg << std::endl;
71 | }
72 | } logger;
73 |
74 | #endif
--------------------------------------------------------------------------------
/truck.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mingj2021/segment-anything-tensorrt/bd27bc114fc6f6881c905d56f802b3d7ab3f2eb9/truck.gif
--------------------------------------------------------------------------------