├── LICENSE.md
├── README.md
├── assets
│   ├── dog.jpg
│   ├── dog_mask.jpg
│   ├── dogs.jpg
│   ├── dogs_mask.jpg
│   ├── masked_dogs.jpg
│   ├── segment_with_single_click.gif
│   └── video.gif
├── main.cpp
├── nanosam.sln
├── nanosam.vcxproj
├── nanosam.vcxproj.filters
├── nanosam.vcxproj.user
├── nanosam
│   ├── config.h
│   ├── cuda_utils.h
│   ├── logging.h
│   ├── macros.h
│   ├── nanosam.cpp
│   ├── nanosam.h
│   ├── trt_module.cpp
│   └── trt_module.h
└── utils.h
/LICENSE.md:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # NanoSAM C++
2 |
3 |
4 | This repo provides a TensorRT C++ implementation of Nvidia's [NanoSAM](https://github.com/NVIDIA-AI-IOT/nanosam), a distilled segment-anything model, for real-time inference on GPU.
5 |
6 |
7 |
8 |
9 |
10 |
11 | ## Getting Started
12 | 1. There are two ways to load engines:
13 |
14 | 1. Load engines built by trtexec:
15 |
16 | ```cpp
17 | #include "nanosam/nanosam.h"
18 |
19 | NanoSam nanosam(
20 | "resnet18_image_encoder.engine",
21 | "mobile_sam_mask_decoder.engine"
22 | );
23 | ```
24 | 2. Build engines directly from onnx files:
25 |
26 | ```cpp
27 | NanoSam nanosam(
28 | "resnet18_image_encoder.onnx",
29 | "mobile_sam_mask_decoder.onnx"
30 | );
31 | ```
32 |
33 | 2. Segment an object using a prompt point:
34 |
35 | ```cpp
36 | Mat image = imread("assets/dog.jpg");
37 | // Foreground point
38 | vector<Point> points = { Point(1300, 900) };
39 | vector<float> labels = { 1 };
40 |
41 | Mat mask = nanosam.predict(image, points, labels);
42 | ```
43 |
44 |
45 | *(Input: assets/dog.jpg, output: predicted mask overlay)*
46 |
55 | 3. Create masks from bounding boxes:
56 |
57 | ```cpp
58 | Mat image = imread("assets/dogs.jpg");
59 | // Bounding box top-left and bottom-right points
60 | vector<Point> points = { Point(100, 100), Point(750, 759) };
61 | vector<float> labels = { 2, 3 };
62 |
63 | Mat mask = nanosam.predict(image, points, labels);
64 | ```
65 |
66 |
67 | *(Input: assets/dogs.jpg, output: predicted mask overlay)*
68 |
77 |
78 | ## Notes
79 | The point labels may be:
80 |
81 | | Point Label | Description |
82 | |:--------------------:|-------------|
83 | | 0 | Background point |
84 | | 1 | Foreground point |
85 | | 2 | Bounding box top-left |
86 | | 3 | Bounding box bottom-right |
87 |
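Since a prompt is just a pair of parallel `points`/`labels` vectors, labels can in principle be mixed. The sketch below has not been verified against this implementation, and `Point(400, 400)` is an arbitrary illustrative coordinate; it combines a bounding box with a background point that should be excluded from the mask:

```cpp
Mat image = imread("assets/dogs.jpg");
// Box prompt (labels 2 and 3) plus one background point (label 0)
vector<Point> points = { Point(100, 100), Point(750, 759), Point(400, 400) };
vector<float> labels = { 2, 3, 0 };

Mat mask = nanosam.predict(image, points, labels);
```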
88 |
89 |
90 | ## Performance
91 | The inference time includes the pre-processing and post-processing time:
92 | | Device | Image Shape (WxH) | Model Shape (WxH) | Inference Time (ms) |
93 | |:---------------:|:------------:|:------------:|:------------:|
94 | | RTX 4090 | 2048x1365 | 1024x1024 | 14 |
95 |
96 | ## Installation
97 |
98 | 1. Download the image encoder: [resnet18_image_encoder.onnx](https://drive.google.com/file/d/14-SsvoaTl-esC3JOzomHDnI9OGgdO2OR/view?usp=drive_link)
99 | 2. Download the mask decoder: [mobile_sam_mask_decoder.onnx](https://drive.google.com/file/d/1jYNvnseTL49SNRx9PDcbkZ9DwsY8up7n/view?usp=drive_link)
100 | 3. Download the [TensorRT](https://developer.nvidia.com/tensorrt) zip file that matches the Windows version you are using.
101 | 4. Choose where you want to install TensorRT. The zip file will install everything into a subdirectory called `TensorRT-8.x.x.x`. This new subdirectory will be referred to as `<installpath>` in the steps below.
102 | 5. Unzip the `TensorRT-8.x.x.x.Windows10.x86_64.cuda-x.x.zip` file to the location that you chose. Where:
103 | - `8.x.x.x` is your TensorRT version
104 | - `cuda-x.x` is CUDA version `11.8` or `12.0`
105 | 6. Add the TensorRT library files to your system `PATH`. To do so, copy the DLL files from `<installpath>/lib` to your CUDA installation directory, for example, `C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\vX.Y\bin`, where `vX.Y` is your CUDA version. The CUDA installer should have already added the CUDA path to your system PATH.
106 | 7. Ensure that the following is present in your Visual Studio Solution project properties:
107 | - `<installpath>/lib` has been added to your PATH variable and is present under **VC++ Directories > Executable Directories**.
108 | - `<installpath>/include` is present under **C/C++ > General > Additional Include Directories**.
109 | - `nvinfer.lib` and any other LIB files that your project requires are present under **Linker > Input > Additional Dependencies**.
110 | 8. Download and install any recent [OpenCV](https://opencv.org/releases/) for Windows.
111 |
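Optionally, the `.engine` files mentioned in Getting Started can be pre-built with the `trtexec` tool that ships in `<installpath>/bin`, instead of letting the program build them from the ONNX files at startup. A typical invocation for the image encoder (a sketch, not verified here) looks like:

```
trtexec --onnx=resnet18_image_encoder.onnx --saveEngine=resnet18_image_encoder.engine --fp16
```

The mask decoder has dynamic prompt shapes, so it may additionally require `--minShapes`/`--optShapes`/`--maxShapes` arguments.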
112 | ## Acknowledgement
113 | This project is based on the following projects:
114 | - [NanoSAM](https://github.com/NVIDIA-AI-IOT/nanosam) - The distilled Segment Anything (SAM).
115 | - [TensorRTx](https://github.com/wang-xinyu/tensorrtx) - Implementation of popular deep learning networks with TensorRT network definition API.
116 | - [TensorRT](https://github.com/NVIDIA/TensorRT/tree/release/8.6/samples) - TensorRT samples and api documentation.
117 | - ChatGPT - some of the simple functions were generated by ChatGPT :D
118 |
--------------------------------------------------------------------------------
/assets/dog.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spacewalk01/nanosam-cpp/1f1807b273fe9abcd5b685a747d91e42f0713ab2/assets/dog.jpg
--------------------------------------------------------------------------------
/assets/dog_mask.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spacewalk01/nanosam-cpp/1f1807b273fe9abcd5b685a747d91e42f0713ab2/assets/dog_mask.jpg
--------------------------------------------------------------------------------
/assets/dogs.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spacewalk01/nanosam-cpp/1f1807b273fe9abcd5b685a747d91e42f0713ab2/assets/dogs.jpg
--------------------------------------------------------------------------------
/assets/dogs_mask.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spacewalk01/nanosam-cpp/1f1807b273fe9abcd5b685a747d91e42f0713ab2/assets/dogs_mask.jpg
--------------------------------------------------------------------------------
/assets/masked_dogs.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spacewalk01/nanosam-cpp/1f1807b273fe9abcd5b685a747d91e42f0713ab2/assets/masked_dogs.jpg
--------------------------------------------------------------------------------
/assets/segment_with_single_click.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spacewalk01/nanosam-cpp/1f1807b273fe9abcd5b685a747d91e42f0713ab2/assets/segment_with_single_click.gif
--------------------------------------------------------------------------------
/assets/video.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/spacewalk01/nanosam-cpp/1f1807b273fe9abcd5b685a747d91e42f0713ab2/assets/video.gif
--------------------------------------------------------------------------------
/main.cpp:
--------------------------------------------------------------------------------
1 | #include "nanosam/nanosam.h"
2 | #include "utils.h"
3 |
4 | void segmentClickedPoint(NanoSam& nanosam, string imagePath) {
5 |
6 | auto image = imread(imagePath);
7 |
8 | // Create a window to display the image
9 | cv::namedWindow("Image");
10 |
11 | // Data structure to hold clicked point
12 | PointData pointData;
13 | pointData.clicked = false;
14 | cv::Mat clonedImage = image.clone();
15 | int clickCount = 0;
16 |
17 | // Loop until Esc key is pressed
18 | while (true)
19 | {
20 | // Display the original image
21 | cv::imshow("Image", clonedImage);
22 |
23 | if (pointData.clicked)
24 | {
25 | pointData.clicked = false; // Reset clicked flag
26 |
27 | auto mask = nanosam.predict(image, { pointData.point }, { 1.0f });
28 |
29 | cv::circle(clonedImage, pointData.point, 5, cv::Scalar(0, 0, 255), -1);
30 |
31 | if (clickCount * 9 >= CITYSCAPES_COLORS.size()) clickCount = 0;
32 | overlay(image, mask, CITYSCAPES_COLORS[clickCount * 9]);
33 | clickCount++;
34 | }
35 |
36 | // Set the callback function for mouse events on the displayed cloned image
37 | cv::setMouseCallback("Image", onMouse, &pointData);
38 |
39 | // Check for Esc key press
40 | char key = cv::waitKey(1);
41 | if (key == 27) // ASCII code for Esc key
42 | {
43 | clonedImage = image.clone();
44 | }
45 | }
46 | cv::destroyAllWindows();
47 | }
48 |
49 | void segmentBbox(NanoSam& nanosam, string imagePath, string outputPath, vector<Point> bbox)
50 | {
51 | auto image = imread(imagePath);
52 |
53 | // 2 : Bounding box top-left, 3 : Bounding box bottom-right
54 | vector<float> labels = { 2, 3 };
55 |
56 | auto mask = nanosam.predict(image, bbox, labels);
57 |
58 | overlay(image, mask);
59 |
60 | rectangle(image, bbox[0], bbox[1], cv::Scalar(255, 255, 0), 3);
61 |
62 | imwrite(outputPath, image);
63 | }
64 |
65 | void segmentWithPoint(NanoSam& nanosam, string imagePath, string outputPath, Point promptPoint)
66 | {
67 | auto image = imread(imagePath);
68 |
69 | // 1 : Foreground
70 | vector<float> labels = { 1.0f };
71 |
72 | auto mask = nanosam.predict(image, { promptPoint }, labels);
73 |
74 | overlay(image, mask);
75 |
76 | imwrite(outputPath, image);
77 | }
78 |
79 | int main()
80 | {
81 | /* 1. Load engine examples */
82 |
83 | // Option 1: Load the engines
84 | //NanoSam nanosam("data/resnet18_image_encoder.engine", "data/mobile_sam_mask_decoder.engine");
85 |
86 | // Option 2: Build the engines from onnx files
87 | NanoSam nanosam("data/resnet18_image_encoder.onnx", "data/mobile_sam_mask_decoder.onnx");
88 |
89 | /* 2. Segmentation examples */
90 |
91 | // Demo 1: Segment using a point
92 | segmentWithPoint(nanosam, "assets/dog.jpg", "assets/dog_mask.jpg", Point(1300, 900));
93 |
94 | // Demo 2: Segment using a bounding box
95 | segmentBbox(nanosam, "assets/dogs.jpg", "assets/dogs_mask.jpg", { Point(100, 100), Point(750, 759) });
96 |
97 | // Demo 3: Segment the clicked object
98 | segmentClickedPoint(nanosam, "assets/dogs.jpg");
99 |
100 | return 0;
101 | }
102 |
--------------------------------------------------------------------------------
/nanosam.sln:
--------------------------------------------------------------------------------
1 |
2 | Microsoft Visual Studio Solution File, Format Version 12.00
3 | # Visual Studio Version 17
4 | VisualStudioVersion = 17.7.34202.233
5 | MinimumVisualStudioVersion = 10.0.40219.1
6 | Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "nanosam", "nanosam.vcxproj", "{A6DD5CA9-1C7E-46B4-B0CA-F2D17D21DDCC}"
7 | EndProject
8 | Global
9 | GlobalSection(SolutionConfigurationPlatforms) = preSolution
10 | Debug|x64 = Debug|x64
11 | Debug|x86 = Debug|x86
12 | Release|x64 = Release|x64
13 | Release|x86 = Release|x86
14 | EndGlobalSection
15 | GlobalSection(ProjectConfigurationPlatforms) = postSolution
16 | {A6DD5CA9-1C7E-46B4-B0CA-F2D17D21DDCC}.Debug|x64.ActiveCfg = Debug|x64
17 | {A6DD5CA9-1C7E-46B4-B0CA-F2D17D21DDCC}.Debug|x64.Build.0 = Debug|x64
18 | {A6DD5CA9-1C7E-46B4-B0CA-F2D17D21DDCC}.Debug|x86.ActiveCfg = Debug|Win32
19 | {A6DD5CA9-1C7E-46B4-B0CA-F2D17D21DDCC}.Debug|x86.Build.0 = Debug|Win32
20 | {A6DD5CA9-1C7E-46B4-B0CA-F2D17D21DDCC}.Release|x64.ActiveCfg = Release|x64
21 | {A6DD5CA9-1C7E-46B4-B0CA-F2D17D21DDCC}.Release|x64.Build.0 = Release|x64
22 | {A6DD5CA9-1C7E-46B4-B0CA-F2D17D21DDCC}.Release|x86.ActiveCfg = Release|Win32
23 | {A6DD5CA9-1C7E-46B4-B0CA-F2D17D21DDCC}.Release|x86.Build.0 = Release|Win32
24 | EndGlobalSection
25 | GlobalSection(SolutionProperties) = preSolution
26 | HideSolutionNode = FALSE
27 | EndGlobalSection
28 | GlobalSection(ExtensibilityGlobals) = postSolution
29 | SolutionGuid = {C73DAFCD-6916-427D-9DCB-F06A0259B9CC}
30 | EndGlobalSection
31 | EndGlobal
32 |
--------------------------------------------------------------------------------
/nanosam.vcxproj:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | Debug
6 | Win32
7 |
8 |
9 | Release
10 | Win32
11 |
12 |
13 | Debug
14 | x64
15 |
16 |
17 | Release
18 | x64
19 |
20 |
21 |
22 | 17.0
23 | Win32Proj
24 | {a6dd5ca9-1c7e-46b4-b0ca-f2d17d21ddcc}
25 | nanosam
26 | 10.0
27 |
28 |
29 |
30 | Application
31 | true
32 | v143
33 | Unicode
34 |
35 |
36 | Application
37 | false
38 | v143
39 | true
40 | Unicode
41 |
42 |
43 | Application
44 | true
45 | v143
46 | Unicode
47 |
48 |
49 | Application
50 | false
51 | v143
52 | true
53 | Unicode
54 |
55 |
56 |
57 |
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 |
69 |
70 |
71 |
72 |
73 |
74 |
75 |
76 | Level3
77 | true
78 | WIN32;_DEBUG;_CONSOLE;%(PreprocessorDefinitions)
79 | true
80 |
81 |
82 | Console
83 | true
84 |
85 |
86 |
87 |
88 | Level3
89 | true
90 | true
91 | true
92 | WIN32;NDEBUG;_CONSOLE;%(PreprocessorDefinitions)
93 | true
94 |
95 |
96 | Console
97 | true
98 | true
99 | true
100 |
101 |
102 |
103 |
104 | Level3
105 | true
106 | _DEBUG;_CONSOLE;%(PreprocessorDefinitions)
107 | true
108 |
109 |
110 | Console
111 | true
112 |
113 |
114 |
115 |
116 | Level3
117 | true
118 | true
119 | true
120 | NDEBUG;_CONSOLE;_CRT_SECURE_NO_WARNINGS;_CRT_NONSTDC_NO_DEPRECATE;%(PreprocessorDefinitions)
121 | true
122 | C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.4\include;C:\TensorRT-8.6.0.12\include;%(AdditionalIncludeDirectories)
123 | stdcpp17
124 |
125 |
126 | Console
127 | true
128 | true
129 | true
130 | C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.4\lib\x64;C:\TensorRT-8.6.0.12\lib;%(AdditionalLibraryDirectories)
131 | nvinfer.lib;nvinfer_plugin.lib;nvonnxparser.lib;nvparsers.lib;cublas.lib;cuda.lib;cudart.lib;cudnn.lib;%(AdditionalDependencies)
132 |
133 |
134 |
135 |
136 |
137 |
138 |
139 |
140 |
141 |
142 |
143 |
144 |
145 |
146 |
147 |
148 |
149 |
150 |
151 |
152 |
153 |
154 |
155 |
156 |
--------------------------------------------------------------------------------
/nanosam.vcxproj.filters:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | {4FC737F1-C7A5-4376-A066-2A32D752A2FF}
6 | cpp;c;cc;cxx;c++;cppm;ixx;def;odl;idl;hpj;bat;asm;asmx
7 |
8 |
9 | {93995380-89BD-4b04-88EB-625FBE52EBFB}
10 | h;hh;hpp;hxx;h++;hm;inl;inc;ipp;xsd
11 |
12 |
13 | {67DA6AB6-F800-4c08-8B7A-83BB121AAD01}
14 | rc;ico;cur;bmp;dlg;rc2;rct;bin;rgs;gif;jpg;jpeg;jpe;resx;tiff;tif;png;wav;mfcribbon-ms
15 |
16 |
17 | {9485532c-d9ca-4d0a-af13-e349120532da}
18 |
19 |
20 |
21 |
22 | Source Files
23 |
24 |
25 | nanosam
26 |
27 |
28 | nanosam
29 |
30 |
31 |
32 |
33 | Header Files
34 |
35 |
36 | nanosam
37 |
38 |
39 | nanosam
40 |
41 |
42 | nanosam
43 |
44 |
45 | nanosam
46 |
47 |
48 | nanosam
49 |
50 |
51 | nanosam
52 |
53 |
54 |
--------------------------------------------------------------------------------
/nanosam.vcxproj.user:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/nanosam/config.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #define USE_FP16 // set USE_FP16 or USE_FP32
4 |
5 | #define MAX_NUM_PROMPTS 1
6 |
7 | // Model Params
8 | #define MODEL_INPUT_WIDTH 1024.0f
9 | #define MODEL_INPUT_HEIGHT 1024.0f
10 | #define HIDDEN_DIM 256
11 | #define NUM_LABELS 4
12 | #define FEATURE_WIDTH 64
13 | #define FEATURE_HEIGHT 64
14 |
--------------------------------------------------------------------------------
/nanosam/cuda_utils.h:
--------------------------------------------------------------------------------
1 | #ifndef TRTX_CUDA_UTILS_H_
2 | #define TRTX_CUDA_UTILS_H_
3 |
4 | #include <cuda_runtime_api.h>
5 |
6 | #ifndef CUDA_CHECK
7 | #define CUDA_CHECK(callstr)\
8 | {\
9 | cudaError_t error_code = callstr;\
10 | if (error_code != cudaSuccess) {\
11 | std::cerr << "CUDA error " << error_code << " at " << __FILE__ << ":" << __LINE__;\
12 | assert(0);\
13 | }\
14 | }
15 | #endif // CUDA_CHECK
16 |
17 | #define CHECK_RETURN_W_MSG(status, val, errMsg) \
18 | do \
19 | { \
20 | if (!(status)) \
21 | { \
22 | sample::gLogError << errMsg << " Error in " << __FILE__ << ", function " << FN_NAME << "(), line " << __LINE__ \
23 | << std::endl; \
24 | return val; \
25 | } \
26 | } while (0)
27 |
28 |
29 | #endif // TRTX_CUDA_UTILS_H_
30 |
31 |
--------------------------------------------------------------------------------
/nanosam/logging.h:
--------------------------------------------------------------------------------
1 | /*
2 | * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
3 | *
4 | * Licensed under the Apache License, Version 2.0 (the "License");
5 | * you may not use this file except in compliance with the License.
6 | * You may obtain a copy of the License at
7 | *
8 | * http://www.apache.org/licenses/LICENSE-2.0
9 | *
10 | * Unless required by applicable law or agreed to in writing, software
11 | * distributed under the License is distributed on an "AS IS" BASIS,
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | * See the License for the specific language governing permissions and
14 | * limitations under the License.
15 | */
16 |
17 | #ifndef TENSORRT_LOGGING_H
18 | #define TENSORRT_LOGGING_H
19 |
20 | #include "NvInferRuntimeCommon.h"
21 | #include <cassert>
22 | #include <ctime>
23 | #include <iomanip>
24 | #include <iostream>
25 | #include <ostream>
26 | #include <sstream>
27 | #include <string>
28 | #include "macros.h"
29 |
30 | using Severity = nvinfer1::ILogger::Severity;
31 |
32 | class LogStreamConsumerBuffer : public std::stringbuf
33 | {
34 | public:
35 | LogStreamConsumerBuffer(std::ostream& stream, const std::string& prefix, bool shouldLog)
36 | : mOutput(stream)
37 | , mPrefix(prefix)
38 | , mShouldLog(shouldLog)
39 | {
40 | }
41 |
42 | LogStreamConsumerBuffer(LogStreamConsumerBuffer&& other)
43 | : mOutput(other.mOutput)
44 | {
45 | }
46 |
47 | ~LogStreamConsumerBuffer()
48 | {
49 | // std::streambuf::pbase() gives a pointer to the beginning of the buffered part of the output sequence
50 | // std::streambuf::pptr() gives a pointer to the current position of the output sequence
51 | // if the pointer to the beginning is not equal to the pointer to the current position,
52 | // call putOutput() to log the output to the stream
53 | if (pbase() != pptr())
54 | {
55 | putOutput();
56 | }
57 | }
58 |
59 | // synchronizes the stream buffer and returns 0 on success
60 | // synchronizing the stream buffer consists of inserting the buffer contents into the stream,
61 | // resetting the buffer and flushing the stream
62 | virtual int sync()
63 | {
64 | putOutput();
65 | return 0;
66 | }
67 |
68 | void putOutput()
69 | {
70 | if (mShouldLog)
71 | {
72 | // prepend timestamp
73 | std::time_t timestamp = std::time(nullptr);
74 | tm* tm_local = std::localtime(&timestamp);
75 | std::cout << "[";
76 | std::cout << std::setw(2) << std::setfill('0') << 1 + tm_local->tm_mon << "/";
77 | std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_mday << "/";
78 | std::cout << std::setw(4) << std::setfill('0') << 1900 + tm_local->tm_year << "-";
79 | std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_hour << ":";
80 | std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_min << ":";
81 | std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_sec << "] ";
82 | // std::stringbuf::str() gets the string contents of the buffer
83 | // insert the buffer contents pre-appended by the appropriate prefix into the stream
84 | mOutput << mPrefix << str();
85 | // set the buffer to empty
86 | str("");
87 | // flush the stream
88 | mOutput.flush();
89 | }
90 | }
91 |
92 | void setShouldLog(bool shouldLog)
93 | {
94 | mShouldLog = shouldLog;
95 | }
96 |
97 | private:
98 | std::ostream& mOutput;
99 | std::string mPrefix;
100 | bool mShouldLog;
101 | };
102 |
103 | //!
104 | //! \class LogStreamConsumerBase
105 | //! \brief Convenience object used to initialize LogStreamConsumerBuffer before std::ostream in LogStreamConsumer
106 | //!
107 | class LogStreamConsumerBase
108 | {
109 | public:
110 | LogStreamConsumerBase(std::ostream& stream, const std::string& prefix, bool shouldLog)
111 | : mBuffer(stream, prefix, shouldLog)
112 | {
113 | }
114 |
115 | protected:
116 | LogStreamConsumerBuffer mBuffer;
117 | };
118 |
119 | //!
120 | //! \class LogStreamConsumer
121 | //! \brief Convenience object used to facilitate use of C++ stream syntax when logging messages.
122 | //! Order of base classes is LogStreamConsumerBase and then std::ostream.
123 | //! This is because the LogStreamConsumerBase class is used to initialize the LogStreamConsumerBuffer member field
124 | //! in LogStreamConsumer and then the address of the buffer is passed to std::ostream.
125 | //! This is necessary to prevent the address of an uninitialized buffer from being passed to std::ostream.
126 | //! Please do not change the order of the parent classes.
127 | //!
128 | class LogStreamConsumer : protected LogStreamConsumerBase, public std::ostream
129 | {
130 | public:
131 | //! \brief Creates a LogStreamConsumer which logs messages with level severity.
132 | //! Reportable severity determines if the messages are severe enough to be logged.
133 | LogStreamConsumer(Severity reportableSeverity, Severity severity)
134 | : LogStreamConsumerBase(severityOstream(severity), severityPrefix(severity), severity <= reportableSeverity)
135 | , std::ostream(&mBuffer) // links the stream buffer with the stream
136 | , mShouldLog(severity <= reportableSeverity)
137 | , mSeverity(severity)
138 | {
139 | }
140 |
141 | LogStreamConsumer(LogStreamConsumer&& other)
142 | : LogStreamConsumerBase(severityOstream(other.mSeverity), severityPrefix(other.mSeverity), other.mShouldLog)
143 | , std::ostream(&mBuffer) // links the stream buffer with the stream
144 | , mShouldLog(other.mShouldLog)
145 | , mSeverity(other.mSeverity)
146 | {
147 | }
148 |
149 | void setReportableSeverity(Severity reportableSeverity)
150 | {
151 | mShouldLog = mSeverity <= reportableSeverity;
152 | mBuffer.setShouldLog(mShouldLog);
153 | }
154 |
155 | private:
156 | static std::ostream& severityOstream(Severity severity)
157 | {
158 | return severity >= Severity::kINFO ? std::cout : std::cerr;
159 | }
160 |
161 | static std::string severityPrefix(Severity severity)
162 | {
163 | switch (severity)
164 | {
165 | case Severity::kINTERNAL_ERROR: return "[F] ";
166 | case Severity::kERROR: return "[E] ";
167 | case Severity::kWARNING: return "[W] ";
168 | case Severity::kINFO: return "[I] ";
169 | case Severity::kVERBOSE: return "[V] ";
170 | default: assert(0); return "";
171 | }
172 | }
173 |
174 | bool mShouldLog;
175 | Severity mSeverity;
176 | };
177 |
178 | //! \class Logger
179 | //!
180 | //! \brief Class which manages logging of TensorRT tools and samples
181 | //!
182 | //! \details This class provides a common interface for TensorRT tools and samples to log information to the console,
183 | //! and supports logging two types of messages:
184 | //!
185 | //! - Debugging messages with an associated severity (info, warning, error, or internal error/fatal)
186 | //! - Test pass/fail messages
187 | //!
188 | //! The advantage of having all samples use this class for logging as opposed to emitting directly to stdout/stderr is
189 | //! that the logic for controlling the verbosity and formatting of sample output is centralized in one location.
190 | //!
191 | //! In the future, this class could be extended to support dumping test results to a file in some standard format
192 | //! (for example, JUnit XML), and providing additional metadata (e.g. timing the duration of a test run).
193 | //!
194 | //! TODO: For backwards compatibility with existing samples, this class inherits directly from the nvinfer1::ILogger
195 | //! interface, which is problematic since there isn't a clean separation between messages coming from the TensorRT
196 | //! library and messages coming from the sample.
197 | //!
198 | //! In the future (once all samples are updated to use Logger::getTRTLogger() to access the ILogger) we can refactor the
199 | //! class to eliminate the inheritance and instead make the nvinfer1::ILogger implementation a member of the Logger
200 | //! object.
201 |
202 | class Logger : public nvinfer1::ILogger
203 | {
204 | public:
205 | Logger(Severity severity = Severity::kWARNING)
206 | : mReportableSeverity(severity)
207 | {
208 | }
209 |
210 | //!
211 | //! \enum TestResult
212 | //! \brief Represents the state of a given test
213 | //!
214 | enum class TestResult
215 | {
216 | kRUNNING, //!< The test is running
217 | kPASSED, //!< The test passed
218 | kFAILED, //!< The test failed
219 | kWAIVED //!< The test was waived
220 | };
221 |
222 | //!
223 | //! \brief Forward-compatible method for retrieving the nvinfer::ILogger associated with this Logger
224 | //! \return The nvinfer1::ILogger associated with this Logger
225 | //!
226 | //! TODO Once all samples are updated to use this method to register the logger with TensorRT,
227 | //! we can eliminate the inheritance of Logger from ILogger
228 | //!
229 | nvinfer1::ILogger& getTRTLogger()
230 | {
231 | return *this;
232 | }
233 |
234 | //!
235 | //! \brief Implementation of the nvinfer1::ILogger::log() virtual method
236 | //!
237 | //! Note samples should not be calling this function directly; it will eventually go away once we eliminate the
238 | //! inheritance from nvinfer1::ILogger
239 | //!
240 | void log(Severity severity, const char* msg) TRT_NOEXCEPT override
241 | {
242 | LogStreamConsumer(mReportableSeverity, severity) << "[TRT] " << std::string(msg) << std::endl;
243 | }
244 |
245 | //!
246 | //! \brief Method for controlling the verbosity of logging output
247 | //!
248 | //! \param severity The logger will only emit messages that have severity of this level or higher.
249 | //!
250 | void setReportableSeverity(Severity severity)
251 | {
252 | mReportableSeverity = severity;
253 | }
254 |
255 | //!
256 | //! \brief Opaque handle that holds logging information for a particular test
257 | //!
258 | //! This object is an opaque handle to information used by the Logger to print test results.
259 | //! The sample must call Logger::defineTest() in order to obtain a TestAtom that can be used
260 | //! with Logger::reportTest{Start,End}().
261 | //!
262 | class TestAtom
263 | {
264 | public:
265 | TestAtom(TestAtom&&) = default;
266 |
267 | private:
268 | friend class Logger;
269 |
270 | TestAtom(bool started, const std::string& name, const std::string& cmdline)
271 | : mStarted(started)
272 | , mName(name)
273 | , mCmdline(cmdline)
274 | {
275 | }
276 |
277 | bool mStarted;
278 | std::string mName;
279 | std::string mCmdline;
280 | };
281 |
282 | //!
283 | //! \brief Define a test for logging
284 | //!
285 | //! \param[in] name The name of the test. This should be a string starting with
286 | //! "TensorRT" and containing dot-separated strings containing
287 | //! the characters [A-Za-z0-9_].
288 | //! For example, "TensorRT.sample_googlenet"
289 | //! \param[in] cmdline The command line used to reproduce the test
290 | //
291 | //! \return a TestAtom that can be used in Logger::reportTest{Start,End}().
292 | //!
293 | static TestAtom defineTest(const std::string& name, const std::string& cmdline)
294 | {
295 | return TestAtom(false, name, cmdline);
296 | }
297 |
298 | //!
299 | //! \brief A convenience overloaded version of defineTest() that accepts an array of command-line arguments
300 | //! as input
301 | //!
302 | //! \param[in] name The name of the test
303 | //! \param[in] argc The number of command-line arguments
304 | //! \param[in] argv The array of command-line arguments (given as C strings)
305 | //!
306 | //! \return a TestAtom that can be used in Logger::reportTest{Start,End}().
307 | static TestAtom defineTest(const std::string& name, int argc, char const* const* argv)
308 | {
309 | auto cmdline = genCmdlineString(argc, argv);
310 | return defineTest(name, cmdline);
311 | }
312 |
313 | //!
314 | //! \brief Report that a test has started.
315 | //!
316 | //! \pre reportTestStart() has not been called yet for the given testAtom
317 | //!
318 | //! \param[in] testAtom The handle to the test that has started
319 | //!
320 | static void reportTestStart(TestAtom& testAtom)
321 | {
322 | reportTestResult(testAtom, TestResult::kRUNNING);
323 | assert(!testAtom.mStarted);
324 | testAtom.mStarted = true;
325 | }
326 |
327 | //!
328 | //! \brief Report that a test has ended.
329 | //!
330 | //! \pre reportTestStart() has been called for the given testAtom
331 | //!
332 | //! \param[in] testAtom The handle to the test that has ended
333 | //! \param[in] result The result of the test. Should be one of TestResult::kPASSED,
334 | //! TestResult::kFAILED, TestResult::kWAIVED
335 | //!
336 | static void reportTestEnd(const TestAtom& testAtom, TestResult result)
337 | {
338 | assert(result != TestResult::kRUNNING);
339 | assert(testAtom.mStarted);
340 | reportTestResult(testAtom, result);
341 | }
342 |
343 | static int reportPass(const TestAtom& testAtom)
344 | {
345 | reportTestEnd(testAtom, TestResult::kPASSED);
346 | return EXIT_SUCCESS;
347 | }
348 |
349 | static int reportFail(const TestAtom& testAtom)
350 | {
351 | reportTestEnd(testAtom, TestResult::kFAILED);
352 | return EXIT_FAILURE;
353 | }
354 |
355 | static int reportWaive(const TestAtom& testAtom)
356 | {
357 | reportTestEnd(testAtom, TestResult::kWAIVED);
358 | return EXIT_SUCCESS;
359 | }
360 |
361 | static int reportTest(const TestAtom& testAtom, bool pass)
362 | {
363 | return pass ? reportPass(testAtom) : reportFail(testAtom);
364 | }
365 |
366 | Severity getReportableSeverity() const
367 | {
368 | return mReportableSeverity;
369 | }
370 |
371 | private:
372 | //!
373 | //! \brief returns an appropriate string for prefixing a log message with the given severity
374 | //!
375 | static const char* severityPrefix(Severity severity)
376 | {
377 | switch (severity)
378 | {
379 | case Severity::kINTERNAL_ERROR: return "[F] ";
380 | case Severity::kERROR: return "[E] ";
381 | case Severity::kWARNING: return "[W] ";
382 | case Severity::kINFO: return "[I] ";
383 | case Severity::kVERBOSE: return "[V] ";
384 | default: assert(0); return "";
385 | }
386 | }
387 |
388 | //!
389 | //! \brief returns an appropriate string for prefixing a test result message with the given result
390 | //!
391 | static const char* testResultString(TestResult result)
392 | {
393 | switch (result)
394 | {
395 | case TestResult::kRUNNING: return "RUNNING";
396 | case TestResult::kPASSED: return "PASSED";
397 | case TestResult::kFAILED: return "FAILED";
398 | case TestResult::kWAIVED: return "WAIVED";
399 | default: assert(0); return "";
400 | }
401 | }
402 |
403 | //!
404 | //! \brief returns an appropriate output stream (cout or cerr) to use with the given severity
405 | //!
406 | static std::ostream& severityOstream(Severity severity)
407 | {
408 | return severity >= Severity::kINFO ? std::cout : std::cerr;
409 | }
410 |
411 | //!
412 | //! \brief method that implements logging test results
413 | //!
414 | static void reportTestResult(const TestAtom& testAtom, TestResult result)
415 | {
416 | severityOstream(Severity::kINFO) << "&&&& " << testResultString(result) << " " << testAtom.mName << " # "
417 | << testAtom.mCmdline << std::endl;
418 | }
419 |
420 | //!
421 | //! \brief generate a command line string from the given (argc, argv) values
422 | //!
423 | static std::string genCmdlineString(int argc, char const* const* argv)
424 | {
425 | std::stringstream ss;
426 | for (int i = 0; i < argc; i++)
427 | {
428 | if (i > 0)
429 | ss << " ";
430 | ss << argv[i];
431 | }
432 | return ss.str();
433 | }
434 |
435 | Severity mReportableSeverity;
436 | };
437 |
438 | namespace
439 | {
440 |
441 | //!
442 | //! \brief produces a LogStreamConsumer object that can be used to log messages of severity kVERBOSE
443 | //!
444 | //! Example usage:
445 | //!
446 | //! LOG_VERBOSE(logger) << "hello world" << std::endl;
447 | //!
448 | inline LogStreamConsumer LOG_VERBOSE(const Logger& logger)
449 | {
450 | return LogStreamConsumer(logger.getReportableSeverity(), Severity::kVERBOSE);
451 | }
452 |
453 | //!
454 | //! \brief produces a LogStreamConsumer object that can be used to log messages of severity kINFO
455 | //!
456 | //! Example usage:
457 | //!
458 | //! LOG_INFO(logger) << "hello world" << std::endl;
459 | //!
460 | inline LogStreamConsumer LOG_INFO(const Logger& logger)
461 | {
462 | return LogStreamConsumer(logger.getReportableSeverity(), Severity::kINFO);
463 | }
464 |
465 | //!
466 | //! \brief produces a LogStreamConsumer object that can be used to log messages of severity kWARNING
467 | //!
468 | //! Example usage:
469 | //!
470 | //! LOG_WARN(logger) << "hello world" << std::endl;
471 | //!
472 | inline LogStreamConsumer LOG_WARN(const Logger& logger)
473 | {
474 | return LogStreamConsumer(logger.getReportableSeverity(), Severity::kWARNING);
475 | }
476 |
477 | //!
478 | //! \brief produces a LogStreamConsumer object that can be used to log messages of severity kERROR
479 | //!
480 | //! Example usage:
481 | //!
482 | //! LOG_ERROR(logger) << "hello world" << std::endl;
483 | //!
484 | inline LogStreamConsumer LOG_ERROR(const Logger& logger)
485 | {
486 | return LogStreamConsumer(logger.getReportableSeverity(), Severity::kERROR);
487 | }
488 |
489 | //!
490 | //! \brief produces a LogStreamConsumer object that can be used to log messages of severity kINTERNAL_ERROR
491 | // ("fatal" severity)
492 | //!
493 | //! Example usage:
494 | //!
495 | //! LOG_FATAL(logger) << "hello world" << std::endl;
496 | //!
497 | inline LogStreamConsumer LOG_FATAL(const Logger& logger)
498 | {
499 | return LogStreamConsumer(logger.getReportableSeverity(), Severity::kINTERNAL_ERROR);
500 | }
501 |
502 | } // anonymous namespace
503 |
504 | #endif // TENSORRT_LOGGING_H
505 |
--------------------------------------------------------------------------------
/nanosam/macros.h:
--------------------------------------------------------------------------------
1 | #ifndef __MACROS_H
2 | #define __MACROS_H
3 |
4 | #ifdef API_EXPORTS
5 | #if defined(_MSC_VER)
6 | #define API __declspec(dllexport)
7 | #else
8 | #define API __attribute__((visibility("default")))
9 | #endif
10 | #else
11 |
12 | #if defined(_MSC_VER)
13 | #define API __declspec(dllimport)
14 | #else
15 | #define API
16 | #endif
17 | #endif // API_EXPORTS
18 |
19 | #if NV_TENSORRT_MAJOR >= 8
20 | #define TRT_NOEXCEPT noexcept
21 | #define TRT_CONST_ENQUEUE const
22 | #else
23 | #define TRT_NOEXCEPT
24 | #define TRT_CONST_ENQUEUE
25 | #endif
26 |
27 | #endif // __MACROS_H
28 |
--------------------------------------------------------------------------------
/nanosam/nanosam.cpp:
--------------------------------------------------------------------------------
1 | #include "nanosam.h"
2 | #include "config.h"
3 |
4 | using namespace std;
5 |
6 | // Constructor
7 | NanoSam::NanoSam(string encoderPath, string decoderPath)
8 | {
9 | mImageEncoder = new TRTModule(encoderPath,
10 | { "image" },
11 | { "image_embeddings" }, false, true);
12 |
13 | mMaskDecoder = new TRTModule(decoderPath,
14 | { "image_embeddings", "point_coords", "point_labels", "mask_input", "has_mask_input" },
15 | { "iou_predictions", "low_res_masks" }, true, false);
16 |
17 | mFeatures = new float[HIDDEN_DIM * FEATURE_WIDTH * FEATURE_HEIGHT];
18 | mMaskInput = new float[HIDDEN_DIM * HIDDEN_DIM];
19 | mHasMaskInput = new float;
20 | mIouPrediction = new float[NUM_LABELS];
21 | mLowResMasks = new float[NUM_LABELS * HIDDEN_DIM * HIDDEN_DIM];
22 | }
23 |
24 | // Deconstructor
25 | NanoSam::~NanoSam()
26 | {
27 | if (mFeatures) delete[] mFeatures;
28 | if (mMaskInput) delete[] mMaskInput;
29 | if (mIouPrediction) delete[] mIouPrediction;
30 | if (mLowResMasks) delete[] mLowResMasks;
31 | if (mHasMaskInput) delete mHasMaskInput;
32 | if (mImageEncoder) delete mImageEncoder;
33 | if (mMaskDecoder) delete mMaskDecoder;
34 | }
35 |
36 | // Perform inference using NanoSam models
37 | Mat NanoSam::predict(Mat& image, vector<Point> points, vector<float> labels)
38 | {
39 | if (points.size() == 0) return cv::Mat(image.rows, image.cols, CV_32FC1);
40 |
41 | // Preprocess encoder input
42 | auto resizedImage = resizeImage(image, MODEL_INPUT_WIDTH, MODEL_INPUT_HEIGHT);
43 |
44 | // Encoder Inference
45 | mImageEncoder->setInput(resizedImage);
46 | mImageEncoder->infer();
47 | mImageEncoder->getOutput(mFeatures);
48 |
49 | // Preprocess decoder input
50 | auto pointData = new float[2 * points.size()];
51 | prepareDecoderInput(points, pointData, points.size(), image.cols, image.rows);
52 |
53 | // Decoder Inference
54 | mMaskDecoder->setInput(mFeatures, pointData, labels.data(), mMaskInput, mHasMaskInput, points.size());
55 | mMaskDecoder->infer();
56 | mMaskDecoder->getOutput(mIouPrediction, mLowResMasks);
57 |
58 | // Postprocessing
59 | Mat imgMask(HIDDEN_DIM, HIDDEN_DIM, CV_32FC1, mLowResMasks);
60 | upscaleMask(imgMask, image.cols, image.rows);
61 |
62 | delete[] pointData;
63 |
64 | return imgMask;
65 | }
66 |
67 | void NanoSam::prepareDecoderInput(vector<Point>& points, float* pointData, int numPoints, int imageWidth, int imageHeight)
68 | {
69 | float scale = MODEL_INPUT_WIDTH / max(imageWidth, imageHeight);
70 |
71 | for (int i = 0; i < numPoints; i++)
72 | {
73 | pointData[i * 2] = (float)points[i].x * scale;
74 | pointData[i * 2 + 1] = (float)points[i].y * scale;
75 | }
76 |
77 | for (int i = 0; i < HIDDEN_DIM * HIDDEN_DIM; i++)
78 | {
79 | mMaskInput[i] = 0;
80 | }
81 | *mHasMaskInput = 0;
82 | }
83 |
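// resizeImage keeps the aspect ratio: the longer image side is scaled to the model input size and the
// result is copied into the top-left corner of a zero-filled inputWidth x inputHeight canvas (letterboxing).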
84 | Mat NanoSam::resizeImage(Mat& img, int inputWidth, int inputHeight)
85 | {
86 | int w, h;
87 | float aspectRatio = (float)img.cols / (float)img.rows;
88 |
89 | if (aspectRatio >= 1)
90 | {
91 | w = inputWidth;
92 | h = int(inputHeight / aspectRatio);
93 | }
94 | else
95 | {
96 | w = int(inputWidth * aspectRatio);
97 | h = inputHeight;
98 | }
99 |
100 | Mat re(h, w, CV_8UC3);
101 | cv::resize(img, re, re.size(), 0, 0, INTER_LINEAR);
102 | Mat out(inputHeight, inputWidth, CV_8UC3, 0.0);
103 | re.copyTo(out(Rect(0, 0, re.cols, re.rows)));
104 |
105 | return out;
106 | }
107 |
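// upscaleMask undoes the letterboxing: only the top-left limX x limY region of the low-res (size x size)
// mask corresponds to actual image content, so that region is cropped and resized back to the original
// image resolution.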
108 | void NanoSam::upscaleMask(Mat& mask, int targetWidth, int targetHeight, int size)
109 | {
110 | int limX, limY;
111 | if (targetWidth > targetHeight)
112 | {
113 | limX = size;
114 | limY = size * targetHeight / targetWidth;
115 | }
116 | else
117 | {
118 | limX = size * targetWidth / targetHeight;
119 | limY = size;
120 | }
121 |
122 | cv::resize(mask(Rect(0, 0, limX, limY)), mask, Size(targetWidth, targetHeight));
123 | }
124 |
--------------------------------------------------------------------------------
/nanosam/nanosam.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include <opencv2/opencv.hpp>
4 | #include "trt_module.h"
5 |
6 | class NanoSam
7 | {
8 |
9 | public:
10 |
11 | NanoSam(string encoderPath, string decoderPath);
12 |
13 | ~NanoSam();
14 |
15 | Mat predict(Mat& image, vector<Point> points, vector<float> labels);
16 |
17 | private:
18 |
19 | // Variables
20 | float* mFeatures;
21 | float* mMaskInput;
22 | float* mHasMaskInput;
23 | float* mIouPrediction;
24 | float* mLowResMasks;
25 |
26 | TRTModule* mImageEncoder;
27 | TRTModule* mMaskDecoder;
28 |
29 | void upscaleMask(Mat& mask, int targetWidth, int targetHeight, int size = 256);
30 | Mat resizeImage(Mat& img, int modelWidth, int modelHeight);
31 | void prepareDecoderInput(vector<Point>& points, float* pointData, int numPoints, int imageWidth, int imageHeight);
32 |
33 | };
34 |
35 |
--------------------------------------------------------------------------------
/nanosam/trt_module.cpp:
--------------------------------------------------------------------------------
1 | #include "trt_module.h"
2 | #include "logging.h"
3 | #include "cuda_utils.h"
4 | #include "config.h"
5 | #include "macros.h"
6 |
7 | #include <fstream>
8 | #include <iostream>
9 | #include <cassert>
10 | #include <NvInfer.h>
11 | #include <NvOnnxParser.h>
12 | #include <cuda_runtime_api.h>
13 |
14 | static Logger gLogger;
15 |
16 |
17 | std::string getFileExtension(const std::string& filePath) {
18 | size_t dotPos = filePath.find_last_of(".");
19 | if (dotPos != std::string::npos) {
20 | return filePath.substr(dotPos + 1);
21 | }
22 | return ""; // No extension found
23 | }
24 |
25 | TRTModule::TRTModule(string modelPath, vector<string> inputNames, vector<string> outputNames, bool isDynamicShape, bool isFP16)
26 | {
27 | if (getFileExtension(modelPath) == "onnx")
28 | {
29 | cout << "Building Engine from " << modelPath << endl;
30 | build(modelPath, inputNames, outputNames, isDynamicShape, isFP16);
31 | }
32 | else
33 | {
34 | cout << "Deserializing Engine." << endl;
35 | deserializeEngine(modelPath, inputNames, outputNames);
36 | }
37 | }
38 |
39 | TRTModule::~TRTModule()
40 | {
41 | // Release stream and buffers
42 | cudaStreamDestroy(mCudaStream);
43 | for (int i = 0; i < mGpuBuffers.size(); i++)
44 | CUDA_CHECK(cudaFree(mGpuBuffers[i]));
45 | for (int i = 0; i < mCpuBuffers.size(); i++)
46 | delete[] mCpuBuffers[i];
47 |
48 | // Destroy the engine
49 | delete mContext;
50 | delete mEngine;
51 | delete mRuntime;
52 | }
53 |
54 | void TRTModule::build(string onnxPath, vector<string> inputNames, vector<string> outputNames, bool isDynamicShape, bool isFP16)
55 | {
56 | auto builder = createInferBuilder(gLogger);
57 | assert(builder != nullptr);
58 |
59 | const auto explicitBatch = 1U << static_cast<uint32_t>(NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
60 | INetworkDefinition* network = builder->createNetworkV2(explicitBatch);
61 | assert(network != nullptr);
62 |
63 | IBuilderConfig* config = builder->createBuilderConfig();
64 | assert(config != nullptr);
65 |
66 | if (isDynamicShape) // Only designed for NanoSAM mask decoder
67 | {
68 | auto profile = builder->createOptimizationProfile();
69 |
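// Prompt tensors are dynamic: point_coords has shape [1, N, 2] and point_labels [1, N]; allow 1..10 points.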
70 | profile->setDimensions(inputNames[1].c_str(), OptProfileSelector::kMIN, Dims3{ 1, 1, 2 });
71 | profile->setDimensions(inputNames[1].c_str(), OptProfileSelector::kOPT, Dims3{ 1, 1, 2 });
72 | profile->setDimensions(inputNames[1].c_str(), OptProfileSelector::kMAX, Dims3{ 1, 10, 2 });
73 |
74 | profile->setDimensions(inputNames[2].c_str(), OptProfileSelector::kMIN, Dims2{ 1, 1 });
75 | profile->setDimensions(inputNames[2].c_str(), OptProfileSelector::kOPT, Dims2{ 1, 1 });
76 | profile->setDimensions(inputNames[2].c_str(), OptProfileSelector::kMAX, Dims2{ 1, 10 });
77 |
78 | config->addOptimizationProfile(profile);
79 | }
80 |
81 | if (isFP16)
82 | {
83 | config->setFlag(BuilderFlag::kFP16);
84 | }
85 |
86 | nvonnxparser::IParser* parser = nvonnxparser::createParser(*network, gLogger);
87 | assert(parser != nullptr);
88 |
89 | bool parsed = parser->parseFromFile(onnxPath.c_str(), static_cast<int>(gLogger.getReportableSeverity()));
90 | assert(parsed);
91 |
92 |
93 | // CUDA stream used for profiling by the builder.
94 | assert(mCudaStream != nullptr);
95 |
96 | IHostMemory* plan{ builder->buildSerializedNetwork(*network, *config) };
97 | assert(plan != nullptr);
98 |
99 | mRuntime = createInferRuntime(gLogger);
100 | assert(mRuntime != nullptr);
101 |
102 | mEngine = mRuntime->deserializeCudaEngine(plan->data(), plan->size());
103 | assert(mEngine != nullptr);
104 |
105 | mContext = mEngine->createExecutionContext();
106 | assert(mContext != nullptr);
107 |
108 | delete network;
109 | delete config;
110 | delete parser;
111 | delete plan;
112 |
113 | initialize(inputNames, outputNames);
114 | }
115 |
116 | void TRTModule::deserializeEngine(string engine_name, vector<string> inputNames, vector<string> outputNames)
117 | {
118 | std::ifstream file(engine_name, std::ios::binary);
119 | if (!file.good()) {
120 | std::cerr << "read " << engine_name << " error!" << std::endl;
121 | assert(false);
122 | }
123 | size_t size = 0;
124 | file.seekg(0, file.end);
125 | size = file.tellg();
126 | file.seekg(0, file.beg);
127 | char* serializedEngine = new char[size];
128 | assert(serializedEngine);
129 | file.read(serializedEngine, size);
130 | file.close();
131 |
132 | mRuntime = createInferRuntime(gLogger);
133 | assert(mRuntime);
134 | mEngine = mRuntime->deserializeCudaEngine(serializedEngine, size);
135 | assert(mEngine != nullptr);
136 | mContext = mEngine->createExecutionContext();
137 | assert(mContext != nullptr);
138 | delete[] serializedEngine;
139 |
140 | assert(mEngine->getNbBindings() == inputNames.size() + outputNames.size());
141 |
142 | initialize(inputNames, outputNames);
143 | }
144 |
145 | void TRTModule::initialize(vector<string> inputNames, vector<string> outputNames)
146 | {
147 | for (int i = 0; i < inputNames.size(); i++)
148 | {
149 | const int inputIndex = mEngine->getBindingIndex(inputNames[i].c_str());
150 | }
151 |
152 | for (int i = 0; i < outputNames.size(); i++)
153 | {
154 | const int outputIndex = mEngine->getBindingIndex(outputNames[i].c_str());
155 | }
156 |
157 | mGpuBuffers.resize(mEngine->getNbBindings());
158 | mCpuBuffers.resize(mEngine->getNbBindings());
159 |
160 | for (size_t i = 0; i < mEngine->getNbBindings(); ++i)
161 | {
162 | size_t binding_size = getSizeByDim(mEngine->getBindingDimensions(i));
163 | mBufferBindingSizes.push_back(binding_size);
164 | mBufferBindingBytes.push_back(binding_size * sizeof(float));
165 |
166 | mCpuBuffers[i] = new float[binding_size];
167 |
168 | cudaMalloc(&mGpuBuffers[i], mBufferBindingBytes[i]);
169 |
170 | if (mEngine->bindingIsInput(i))
171 | {
172 | mInputDims.push_back(mEngine->getBindingDimensions(i));
173 | }
174 | else
175 | {
176 | mOutputDims.push_back(mEngine->getBindingDimensions(i));
177 | }
178 | }
179 |
180 | CUDA_CHECK(cudaStreamCreate(&mCudaStream));
181 | }
182 |
183 | //!
184 | //! \brief Runs the TensorRT inference engine for this module
185 | //!
186 | //! \details Copies the staged host inputs to the device, executes the engine,
187 | //! and copies the outputs back to the host buffers.
188 | //!
189 | bool TRTModule::infer()
190 | {
191 | // Memcpy from host input buffers to device input buffers
192 | copyInputToDeviceAsync(mCudaStream);
193 |
194 | bool status = mContext->executeV2(mGpuBuffers.data());
195 |
196 | if (!status)
197 | {
198 | cout << "inference error!" << endl;
199 | return false;
200 | }
201 |
202 | // Memcpy from device output buffers to host output buffers
203 | copyOutputToHostAsync(mCudaStream);
204 | CUDA_CHECK(cudaStreamSynchronize(mCudaStream)); // make sure the async copies have finished before the host buffers are read
205 | return true;
206 | }
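// Typical call sequence for the image encoder module (a sketch; the surrounding NanoSam
// wrapper drives these calls, and the actual tensor names come from the exported ONNX model):
//
//   encoder.setInput(resizedImage);   // stage the preprocessed frame
//   if (encoder.infer())
//       encoder.getOutput(features);  // copy out the image embeddings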
207 |
208 | //!
209 | //! \brief Copy the contents of input host buffers to input device buffers asynchronously.
210 | //!
211 | void TRTModule::copyInputToDeviceAsync(const cudaStream_t& stream)
212 | {
213 | memcpyBuffers(true, false, true, stream);
214 | }
215 |
216 | //!
217 | //! \brief Copy the contents of output device buffers to output host buffers asynchronously.
218 | //!
219 | void TRTModule::copyOutputToHostAsync(const cudaStream_t& stream)
220 | {
221 | memcpyBuffers(false, true, true, stream);
222 | }
223 |
224 | void TRTModule::memcpyBuffers(const bool copyInput, const bool deviceToHost, const bool async, const cudaStream_t& stream)
225 | {
226 | for (int i = 0; i < mEngine->getNbBindings(); i++)
227 | {
228 | void* dstPtr = deviceToHost ? mCpuBuffers[i] : mGpuBuffers[i];
229 | const void* srcPtr = deviceToHost ? mGpuBuffers[i] : mCpuBuffers[i];
230 | const size_t byteSize = mBufferBindingBytes[i];
231 | const cudaMemcpyKind memcpyType = deviceToHost ? cudaMemcpyDeviceToHost : cudaMemcpyHostToDevice;
232 |
233 | if ((copyInput && mEngine->bindingIsInput(i)) || (!copyInput && !mEngine->bindingIsInput(i)))
234 | {
235 | if (async)
236 | {
237 | CUDA_CHECK(cudaMemcpyAsync(dstPtr, srcPtr, byteSize, memcpyType, stream));
238 | }
239 | else
240 | {
241 | CUDA_CHECK(cudaMemcpy(dstPtr, srcPtr, byteSize, memcpyType));
242 | }
243 | }
244 | }
245 | }
246 |
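//!
//! \brief Returns the number of elements described by a binding's dimensions.
//!        Dynamic dimensions (-1) are sized for the worst case, MAX_NUM_PROMPTS.
//!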
247 | size_t TRTModule::getSizeByDim(const Dims& dims)
248 | {
249 | size_t size = 1;
250 |
251 | for (int i = 0; i < dims.nbDims; ++i)
252 | {
253 | if (dims.d[i] == -1)
254 | size *= MAX_NUM_PROMPTS;
255 | else
256 | size *= dims.d[i];
257 | }
258 |
259 | return size;
260 | }
261 |
262 | void TRTModule::setInput(Mat& image)
263 | {
264 | const int inputH = mInputDims[0].d[2];
265 | const int inputW = mInputDims[0].d[3];
266 |
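// Convert the packed HWC BGR uchar image to planar CHW RGB floats, normalizing each
// channel with the ImageNet mean/std (0.485/0.229, 0.456/0.224, 0.406/0.225).
// Assumes the caller has already resized the image to the network input (inputW x inputH).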
267 | int i = 0;
268 | for (int row = 0; row < image.rows; ++row)
269 | {
270 | uchar* uc_pixel = image.data + row * image.step;
271 | for (int col = 0; col < image.cols; ++col)
272 | {
273 | mCpuBuffers[0][i] = ((float)uc_pixel[2] / 255.0f - 0.485f) / 0.229f;
274 | mCpuBuffers[0][i + image.rows * image.cols] = ((float)uc_pixel[1] / 255.0f - 0.456f) / 0.224f;
275 | mCpuBuffers[0][i + 2 * image.rows * image.cols] = ((float)uc_pixel[0] / 255.0f - 0.406f) / 0.225f;
276 | uc_pixel += 3;
277 | ++i;
278 | }
279 | }
280 | }
281 |
282 | // Set dynamic input
283 | void TRTModule::setInput(float* features, float* imagePointCoords, float* imagePointLabels, float* maskInput, float* hasMaskInput, int numPoints)
284 | {
285 | delete[] mCpuBuffers[1];
286 | delete[] mCpuBuffers[2];
287 | mCpuBuffers[1] = new float[numPoints * 2];
288 | mCpuBuffers[2] = new float[numPoints];
289 | // Free the previous point buffers on the device before re-allocating for this prompt count
290 | cudaFree(mGpuBuffers[1]); cudaMalloc(&mGpuBuffers[1], sizeof(float) * numPoints * 2);
291 | cudaFree(mGpuBuffers[2]); cudaMalloc(&mGpuBuffers[2], sizeof(float) * numPoints);
292 |
293 | mBufferBindingBytes[1] = sizeof(float) * numPoints * 2;
294 | mBufferBindingBytes[2] = sizeof(float) * numPoints;
295 |
296 | memcpy(mCpuBuffers[0], features, mBufferBindingBytes[0]);
297 | memcpy(mCpuBuffers[1], imagePointCoords, sizeof(float) * numPoints * 2);
298 | memcpy(mCpuBuffers[2], imagePointLabels, sizeof(float) * numPoints);
299 | memcpy(mCpuBuffers[3], maskInput, mBufferBindingBytes[3]);
300 | memcpy(mCpuBuffers[4], hasMaskInput, mBufferBindingBytes[4]);
301 |
302 | // Setting Dynamic Input Shape in TensorRT
303 | mContext->setOptimizationProfileAsync(0, mCudaStream);
304 | mContext->setBindingDimensions(1, Dims3{ 1, numPoints, 2 });
305 | mContext->setBindingDimensions(2, Dims2{ 1, numPoints });
306 | }
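// Example prompt layout (a sketch): a single foreground click at pixel (x, y), using the
// SAM convention that label 1.0 marks a foreground point and 0.0 a background point:
//
//   float coords[] = { x, y };
//   float labels[] = { 1.0f };
//   decoder.setInput(features, coords, labels, maskInput, hasMaskInput, 1);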
307 |
308 | void TRTModule::getOutput(float* features)
309 | {
310 | memcpy(features, mCpuBuffers[1], mBufferBindingBytes[1]);
311 | }
312 |
313 | void TRTModule::getOutput(float* iouPrediction, float* lowResolutionMasks)
314 | {
315 | memcpy(lowResolutionMasks, mCpuBuffers[5], mBufferBindingBytes[5]);
316 | memcpy(iouPrediction, mCpuBuffers[6], mBufferBindingBytes[6]);
317 | }
318 |
--------------------------------------------------------------------------------
/nanosam/trt_module.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include "NvInfer.h"
4 | #include <opencv2/opencv.hpp>
5 |
6 | using namespace nvinfer1;
7 | using namespace std;
8 | using namespace cv;
9 |
10 | class TRTModule
11 | {
12 |
13 | public:
14 |
15 | TRTModule(string modelPath, vector<string> inputNames, vector<string> outputNames, bool isDynamicShape, bool isFP16);
16 |
17 | bool infer();
18 |
19 | void setInput(Mat& image);
20 |
21 | void setInput(float* features, float* imagePointCoords, float* imagePointLabels, float* maskInput, float* hasMaskInput, int numPoints);
22 |
23 | void getOutput(float* iouPrediction, float* lowResolutionMasks);
24 |
25 | void getOutput(float* features);
26 |
27 | ~TRTModule();
28 |
29 | private:
30 |
31 | void build(string onnxPath, vector<string> inputNames, vector<string> outputNames, bool isDynamicShape = false, bool isFP16 = false);
32 |
33 | void deserializeEngine(string engineName, vector<string> inputNames, vector<string> outputNames);
34 |
35 | void initialize(vector<string> inputNames, vector<string> outputNames);
36 |
37 | size_t getSizeByDim(const Dims& dims);
38 |
39 | void memcpyBuffers(const bool copyInput, const bool deviceToHost, const bool async, const cudaStream_t& stream = 0);
40 |
41 | void copyInputToDeviceAsync(const cudaStream_t& stream = 0);
42 |
43 | void copyOutputToHostAsync(const cudaStream_t& stream = 0);
44 |
45 |
46 | vector<Dims> mInputDims;             //!< The dimensions of the inputs to the network.
47 | vector<Dims> mOutputDims;            //!< The dimensions of the outputs of the network.
48 | vector<void*> mGpuBuffers;           //!< The vector of device buffers needed for engine execution
49 | vector<float*> mCpuBuffers;          //!< The host staging buffers mirroring each binding
50 | vector<size_t> mBufferBindingBytes;  //!< Size of each binding buffer in bytes
51 | vector<size_t> mBufferBindingSizes;  //!< Number of elements in each binding buffer
52 | cudaStream_t mCudaStream;
53 |
54 | IRuntime* mRuntime; //!< The TensorRT runtime used to deserialize the engine
55 | ICudaEngine* mEngine; //!< The TensorRT engine used to run the network
56 | IExecutionContext* mContext; //!< The context for executing inference using an ICudaEngine
57 | };
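// Construction sketch (tensor names below are illustrative only; the real names are defined
// by the exported ONNX models and passed in from the NanoSam wrapper):
//
//   TRTModule encoder("image_encoder.onnx", { "image" }, { "image_embeddings" },
//                     /*isDynamicShape=*/false, /*isFP16=*/true);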
58 |
--------------------------------------------------------------------------------
/utils.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include <opencv2/opencv.hpp>
3 |
4 | // Colors
5 | const std::vector<cv::Scalar> CITYSCAPES_COLORS = {
6 | cv::Scalar(128, 64, 128),
7 | cv::Scalar(232, 35, 244),
8 | cv::Scalar(70, 70, 70),
9 | cv::Scalar(156, 102, 102),
10 | cv::Scalar(153, 153, 190),
11 | cv::Scalar(153, 153, 153),
12 | cv::Scalar(30, 170, 250),
13 | cv::Scalar(0, 220, 220),
14 | cv::Scalar(35, 142, 107),
15 | cv::Scalar(152, 251, 152),
16 | cv::Scalar(180, 130, 70),
17 | cv::Scalar(60, 20, 220),
18 | cv::Scalar(0, 0, 255),
19 | cv::Scalar(142, 0, 0),
20 | cv::Scalar(70, 0, 0),
21 | cv::Scalar(100, 60, 0),
22 | cv::Scalar(90, 0, 0),
23 | cv::Scalar(230, 0, 0),
24 | cv::Scalar(32, 11, 119),
25 | cv::Scalar(0, 74, 111),
26 | cv::Scalar(81, 0, 81)
27 | };
28 |
29 | // Structure to hold clicked point coordinates
30 | struct PointData {
31 | cv::Point point;
32 | bool clicked;
33 | };
34 |
35 | // Overlay mask on the image
36 | void overlay(Mat& image, Mat& mask, Scalar color = Scalar(128, 64, 128), float alpha = 0.8f, bool showEdge = true)
37 | {
38 | // Draw mask
39 | Mat ucharMask(image.rows, image.cols, CV_8UC3, color);
40 | image.copyTo(ucharMask, mask <= 0);
41 | addWeighted(ucharMask, alpha, image, 1.0 - alpha, 0.0f, image);
42 |
43 | // Draw contour edge
44 | if (showEdge)
45 | {
46 | vector<vector<Point>> contours;
47 | vector<Vec4i> hierarchy;
48 | findContours(mask <= 0, contours, hierarchy, RETR_TREE, CHAIN_APPROX_NONE);
49 | drawContours(image, contours, -1, Scalar(255, 255, 255), 2);
50 | }
51 | }
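// Example (a sketch, with frame and mask as cv::Mat variables from the caller): tint the
// masked region with the first Cityscapes color and outline it in white:
//
//   overlay(frame, mask, CITYSCAPES_COLORS[0], 0.8f, true);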
52 |
53 | // Function to handle mouse events
54 | void onMouse(int event, int x, int y, int flags, void* userdata) {
55 | PointData* pd = (PointData*)userdata;
56 | if (event == cv::EVENT_LBUTTONDOWN) {
57 | // Save the clicked coordinates
58 | pd->point = cv::Point(x, y);
59 | pd->clicked = true;
60 | }
61 | }
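// Example (a sketch; the window name is arbitrary): register the callback on a named window
// so clicks update a PointData instance that the main loop can poll:
//
//   PointData pd{ cv::Point(0, 0), false };
//   cv::namedWindow("nanosam");
//   cv::setMouseCallback("nanosam", onMouse, &pd);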
62 |
--------------------------------------------------------------------------------