├── LICENSE.md ├── README.md ├── assets ├── dog.jpg ├── dog_mask.jpg ├── dogs.jpg ├── dogs_mask.jpg ├── masked_dogs.jpg ├── segment_with_single_click.gif └── video.gif ├── main.cpp ├── nanosam.sln ├── nanosam.vcxproj ├── nanosam.vcxproj.filters ├── nanosam.vcxproj.user ├── nanosam ├── config.h ├── cuda_utils.h ├── logging.h ├── macros.h ├── nanosam.cpp ├── nanosam.h ├── trt_module.cpp └── trt_module.h └── utils.h /LICENSE.md: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 |

NanoSAM C++

3 | 4 | This repo provides a TensorRT C++ implementation of NVIDIA's [NanoSAM](https://github.com/NVIDIA-AI-IOT/nanosam), a distilled Segment Anything model, for real-time inference on the GPU. 5 | 6 |

7 | 8 | 9 |

10 | 11 | ## Getting Started 12 | 1. There are two ways to load engines: 13 | 14 | 1. Load engines built by trtexec: 15 | 16 | ```cpp 17 | #include "nanosam/nanosam.h" 18 | 19 | NanoSam nanosam( 20 | "resnet18_image_encoder.engine", 21 | "mobile_sam_mask_decoder.engine" 22 | ); 23 | ``` 24 | 2. Build engines directly from onnx files: 25 | 26 | ```cpp 27 | NanoSam nanosam( 28 | "resnet18_image_encoder.onnx", 29 | "mobile_sam_mask_decoder.onnx" 30 | ); 31 | ``` 32 | 33 | 2. Segment an object using a prompt point: 34 | 35 | ```cpp 36 | Mat image = imread("assets/dog.jpg"); 37 | // Foreground point 38 | vector points = { Point(1300, 900) }; 39 | vector labels = { 1 }; 40 | 41 | Mat mask = nanosam.predict(image, points, labels); 42 | ``` 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 |
Input | Output
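The mask returned by `predict` is a single-channel `CV_32FC1` `Mat` upscaled to the input image size, with positive values marking the selected object (see `NanoSam::predict` and the `overlay` helper in `utils.h`). A minimal follow-up sketch — the output file names here are only examples:

```cpp
// Positive mask values correspond to the segmented object.
cv::Mat binaryMask = mask > 0;                      // CV_8UC1, 0 or 255
cv::imwrite("assets/dog_binary_mask.png", binaryMask);

// Blend the mask onto the image with the overlay() helper from utils.h.
overlay(image, mask);
cv::imwrite("assets/dog_overlay.jpg", image);
```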
54 | 55 | 3. Create masks from bounding boxes: 56 | 57 | ```cpp 58 | Mat image = imread("assets/dogs.jpg"); 59 | // Bounding box top-left and bottom-right points 60 | vector points = { Point(100, 100), Point(750, 759) }; 61 | vector labels = { 2, 3 }; 62 | 63 | Mat mask = nanosam.predict(image, points, labels); 64 | ``` 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 |
Input | Output
76 | 77 |
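If the region of interest already comes as a `cv::Rect` (for example from an object detector), it can be converted into the two-point box prompt shown above; the `2`/`3` labels follow the table in the Notes below. A small sketch under that assumption — `segmentRect` is just an illustrative helper name:

```cpp
// Convert a cv::Rect into the two-point box prompt used by NanoSam::predict().
// Label 2 marks the box top-left corner, label 3 the bottom-right corner.
cv::Mat segmentRect(NanoSam& nanosam, cv::Mat& image, const cv::Rect& roi)
{
    std::vector<cv::Point> points = { roi.tl(), roi.br() };
    std::vector<float> labels = { 2.0f, 3.0f };
    return nanosam.predict(image, points, labels);
}
```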
78 | Notes 79 | The point labels may be one of the following: 80 | 81 | | Point Label | Description | 82 | |:--------------------:|-------------| 83 | | 0 | Background point | 84 | | 1 | Foreground point | 85 | | 2 | Bounding box top-left | 86 | | 3 | Bounding box bottom-right | 87 |
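Since the decoder consumes a flat list of points and labels, prompt types can be combined in a single call, for example a box plus a background point to suppress part of the boxed region. The snippet below is only a sketch: the coordinates are made up, and the mask-decoder profile built in `trt_module.cpp` accepts at most 10 prompt points per call.

```cpp
cv::Mat image = cv::imread("assets/dogs.jpg");

// Box prompt plus one background point (illustrative coordinates).
std::vector<cv::Point> points = { cv::Point(100, 100),    // label 2: box top-left
                                  cv::Point(750, 759),    // label 3: box bottom-right
                                  cv::Point(420, 500) };  // label 0: background point
std::vector<float> labels = { 2.0f, 3.0f, 0.0f };

cv::Mat mask = nanosam.predict(image, points, labels);
```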
88 | 89 | 90 | ## Performance 91 | The inference time includes the pre-preprocessing time and the post-processing time: 92 | | Device | Image Shape(WxH) | Model Shape(WxH) | Inference Time(ms) | 93 | |:---------------:|:------------:|:------------:|:------------:| 94 | | RTX4090 |2048x1365 |1024x1024 |14 | 95 | 96 | ## Installation 97 | 98 | 1. Download the image encoder: [resnet18_image_encoder.onnx](https://drive.google.com/file/d/14-SsvoaTl-esC3JOzomHDnI9OGgdO2OR/view?usp=drive_link) 99 | 2. Download the mask decoder: [mobile_sam_mask_decoder.onnx](https://drive.google.com/file/d/1jYNvnseTL49SNRx9PDcbkZ9DwsY8up7n/view?usp=drive_link) 100 | 3. Download the [TensorRT](https://developer.nvidia.com/tensorrt) zip file that matches the Windows version you are using. 101 | 4. Choose where you want to install TensorRT. The zip file will install everything into a subdirectory called `TensorRT-8.x.x.x`. This new subdirectory will be referred to as `` in the steps below. 102 | 5. Unzip the `TensorRT-8.x.x.x.Windows10.x86_64.cuda-x.x.zip` file to the location that you chose. Where: 103 | - `8.x.x.x` is your TensorRT version 104 | - `cuda-x.x` is CUDA version `11.8` or `12.0` 105 | 6. Add the TensorRT library files to your system `PATH`. To do so, copy the DLL files from `/lib` to your CUDA installation directory, for example, `C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\vX.Y\bin`, where `vX.Y` is your CUDA version. The CUDA installer should have already added the CUDA path to your system PATH. 106 | 7. Ensure that the following is present in your Visual Studio Solution project properties: 107 | - `/lib` has been added to your PATH variable and is present under **VC++ Directories > Executable Directories**. 108 | - `/include` is present under **C/C++ > General > Additional Directories**. 109 | - nvinfer.lib and any other LIB files that your project requires are present under **Linker > Input > Additional Dependencies**. 110 | 8. Download and install any recent [OpenCV](https://opencv.org/releases/) for Windows. 111 | 112 | ## Acknowledgement 113 | This project is based on the following projects: 114 | - [NanoSAM](https://github.com/NVIDIA-AI-IOT/nanosam) - The distilled Segment Anything (SAM). 115 | - [TensorRTx](https://github.com/wang-xinyu/tensorrtx) - Implementation of popular deep learning networks with TensorRT network definition API. 116 | - [TensorRT](https://github.com/NVIDIA/TensorRT/tree/release/8.6/samples) - TensorRT samples and api documentation. 
117 | - ChatGPT - some of the simple functions were generated by ChatGPT :D 118 | -------------------------------------------------------------------------------- /assets/dog.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/spacewalk01/nanosam-cpp/1f1807b273fe9abcd5b685a747d91e42f0713ab2/assets/dog.jpg -------------------------------------------------------------------------------- /assets/dog_mask.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/spacewalk01/nanosam-cpp/1f1807b273fe9abcd5b685a747d91e42f0713ab2/assets/dog_mask.jpg -------------------------------------------------------------------------------- /assets/dogs.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/spacewalk01/nanosam-cpp/1f1807b273fe9abcd5b685a747d91e42f0713ab2/assets/dogs.jpg -------------------------------------------------------------------------------- /assets/dogs_mask.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/spacewalk01/nanosam-cpp/1f1807b273fe9abcd5b685a747d91e42f0713ab2/assets/dogs_mask.jpg -------------------------------------------------------------------------------- /assets/masked_dogs.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/spacewalk01/nanosam-cpp/1f1807b273fe9abcd5b685a747d91e42f0713ab2/assets/masked_dogs.jpg -------------------------------------------------------------------------------- /assets/segment_with_single_click.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/spacewalk01/nanosam-cpp/1f1807b273fe9abcd5b685a747d91e42f0713ab2/assets/segment_with_single_click.gif -------------------------------------------------------------------------------- /assets/video.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/spacewalk01/nanosam-cpp/1f1807b273fe9abcd5b685a747d91e42f0713ab2/assets/video.gif -------------------------------------------------------------------------------- /main.cpp: -------------------------------------------------------------------------------- 1 | #include "nanosam/nanosam.h" 2 | #include "utils.h" 3 | 4 | void segmentClickedPoint(NanoSam& nanosam, string imagePath) { 5 | 6 | auto image = imread(imagePath); 7 | 8 | // Create a window to display the image 9 | cv::namedWindow("Image"); 10 | 11 | // Data structure to hold clicked point 12 | PointData pointData; 13 | pointData.clicked = false; 14 | cv::Mat clonedImage = image.clone(); 15 | int clickCount = 0; 16 | 17 | // Loop until Esc key is pressed 18 | while (true) 19 | { 20 | // Display the original image 21 | cv::imshow("Image", clonedImage); 22 | 23 | if (pointData.clicked) 24 | { 25 | pointData.clicked = false; // Reset clicked flag 26 | 27 | auto mask = nanosam.predict(image, { pointData.point }, { 1.0f }); 28 | 29 | cv::circle(clonedImage, pointData.point, 5, cv::Scalar(0, 0, 255), -1); 30 | 31 | if (clickCount >= CITYSCAPES_COLORS.size()) clickCount = 0; 32 | overlay(image, mask, CITYSCAPES_COLORS[clickCount * 9]); 33 | clickCount++; 34 | } 35 | 36 | // Set the callback function for mouse events on the displayed cloned image 37 | cv::setMouseCallback("Image", onMouse, &pointData); 38 | 39 | // Check for Esc key press 40 | 
char key = cv::waitKey(1); 41 | if (key == 27) // ASCII code for Esc key 42 | { 43 | clonedImage = image.clone(); 44 | } 45 | } 46 | cv::destroyAllWindows(); 47 | } 48 | 49 | void segmentBbox(NanoSam& nanosam, string imagePath, string outputPath, vector bbox) 50 | { 51 | auto image = imread(imagePath); 52 | 53 | // 2 : Bounding box top-left, 3 : Bounding box bottom-right 54 | vector labels = { 2, 3 }; 55 | 56 | auto mask = nanosam.predict(image, bbox, labels); 57 | 58 | overlay(image, mask); 59 | 60 | rectangle(image, bbox[0], bbox[1], cv::Scalar(255, 255, 0), 3); 61 | 62 | imwrite(outputPath, image); 63 | } 64 | 65 | void segmentWithPoint(NanoSam& nanosam, string imagePath, string outputPath, Point promptPoint) 66 | { 67 | auto image = imread(imagePath); 68 | 69 | // 1 : Foreground 70 | vector labels = { 1.0f }; 71 | 72 | auto mask = nanosam.predict(image, { promptPoint }, labels); 73 | 74 | overlay(image, mask); 75 | 76 | imwrite(outputPath, image); 77 | } 78 | 79 | int main() 80 | { 81 | /* 1. Load engine examples */ 82 | 83 | // Option 1: Load the engines 84 | //NanoSam nanosam("data/resnet18_image_encoder.engine", "data/mobile_sam_mask_decoder.engine"); 85 | 86 | // Option 2: Build the engines from onnx files 87 | NanoSam nanosam("data/resnet18_image_encoder.onnx", "data/mobile_sam_mask_decoder.onnx"); 88 | 89 | /* 2. Segmentation examples */ 90 | 91 | // Demo 1: Segment using a point 92 | segmentWithPoint(nanosam, "assets/dog.jpg", "assets/dog_mask.jpg", Point(1300, 900)); 93 | 94 | // Demo 2: Segment using a bounding box 95 | segmentBbox(nanosam, "assets/dogs.jpg", "assets/dogs_mask.jpg", { Point(100, 100), Point(750, 759) }); 96 | 97 | // Demo 3: Segment the clicked object 98 | segmentClickedPoint(nanosam, "assets/dogs.jpg"); 99 | 100 | return 0; 101 | } 102 | -------------------------------------------------------------------------------- /nanosam.sln: -------------------------------------------------------------------------------- 1 |  2 | Microsoft Visual Studio Solution File, Format Version 12.00 3 | # Visual Studio Version 17 4 | VisualStudioVersion = 17.7.34202.233 5 | MinimumVisualStudioVersion = 10.0.40219.1 6 | Project("{8BC9CEB8-8B4A-11D0-8D11-00A0C91BC942}") = "nanosam", "nanosam.vcxproj", "{A6DD5CA9-1C7E-46B4-B0CA-F2D17D21DDCC}" 7 | EndProject 8 | Global 9 | GlobalSection(SolutionConfigurationPlatforms) = preSolution 10 | Debug|x64 = Debug|x64 11 | Debug|x86 = Debug|x86 12 | Release|x64 = Release|x64 13 | Release|x86 = Release|x86 14 | EndGlobalSection 15 | GlobalSection(ProjectConfigurationPlatforms) = postSolution 16 | {A6DD5CA9-1C7E-46B4-B0CA-F2D17D21DDCC}.Debug|x64.ActiveCfg = Debug|x64 17 | {A6DD5CA9-1C7E-46B4-B0CA-F2D17D21DDCC}.Debug|x64.Build.0 = Debug|x64 18 | {A6DD5CA9-1C7E-46B4-B0CA-F2D17D21DDCC}.Debug|x86.ActiveCfg = Debug|Win32 19 | {A6DD5CA9-1C7E-46B4-B0CA-F2D17D21DDCC}.Debug|x86.Build.0 = Debug|Win32 20 | {A6DD5CA9-1C7E-46B4-B0CA-F2D17D21DDCC}.Release|x64.ActiveCfg = Release|x64 21 | {A6DD5CA9-1C7E-46B4-B0CA-F2D17D21DDCC}.Release|x64.Build.0 = Release|x64 22 | {A6DD5CA9-1C7E-46B4-B0CA-F2D17D21DDCC}.Release|x86.ActiveCfg = Release|Win32 23 | {A6DD5CA9-1C7E-46B4-B0CA-F2D17D21DDCC}.Release|x86.Build.0 = Release|Win32 24 | EndGlobalSection 25 | GlobalSection(SolutionProperties) = preSolution 26 | HideSolutionNode = FALSE 27 | EndGlobalSection 28 | GlobalSection(ExtensibilityGlobals) = postSolution 29 | SolutionGuid = {C73DAFCD-6916-427D-9DCB-F06A0259B9CC} 30 | EndGlobalSection 31 | EndGlobal 32 | 
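A note on the interactive demo in `main.cpp` above: `CITYSCAPES_COLORS` in `utils.h` holds 21 colors, so indexing it with `clickCount * 9` walks past the end of the vector after a few clicks even though `clickCount` itself is reset against the palette size. A safer variant wraps the index instead; a sketch:

```cpp
// Pick a different palette entry per click while staying inside the 21-color table.
cv::Scalar color = CITYSCAPES_COLORS[(clickCount * 9) % CITYSCAPES_COLORS.size()];
overlay(image, mask, color);
clickCount++;
```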
-------------------------------------------------------------------------------- /nanosam.vcxproj: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | Debug 6 | Win32 7 | 8 | 9 | Release 10 | Win32 11 | 12 | 13 | Debug 14 | x64 15 | 16 | 17 | Release 18 | x64 19 | 20 | 21 | 22 | 17.0 23 | Win32Proj 24 | {a6dd5ca9-1c7e-46b4-b0ca-f2d17d21ddcc} 25 | nanosam 26 | 10.0 27 | 28 | 29 | 30 | Application 31 | true 32 | v143 33 | Unicode 34 | 35 | 36 | Application 37 | false 38 | v143 39 | true 40 | Unicode 41 | 42 | 43 | Application 44 | true 45 | v143 46 | Unicode 47 | 48 | 49 | Application 50 | false 51 | v143 52 | true 53 | Unicode 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | Level3 77 | true 78 | WIN32;_DEBUG;_CONSOLE;%(PreprocessorDefinitions) 79 | true 80 | 81 | 82 | Console 83 | true 84 | 85 | 86 | 87 | 88 | Level3 89 | true 90 | true 91 | true 92 | WIN32;NDEBUG;_CONSOLE;%(PreprocessorDefinitions) 93 | true 94 | 95 | 96 | Console 97 | true 98 | true 99 | true 100 | 101 | 102 | 103 | 104 | Level3 105 | true 106 | _DEBUG;_CONSOLE;%(PreprocessorDefinitions) 107 | true 108 | 109 | 110 | Console 111 | true 112 | 113 | 114 | 115 | 116 | Level3 117 | true 118 | true 119 | true 120 | NDEBUG;_CONSOLE;_CRT_SECURE_NO_WARNINGS;_CRT_NONSTDC_NO_DEPRECATE;%(PreprocessorDefinitions) 121 | true 122 | C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.4\include;C:\TensorRT-8.6.0.12\include;%(AdditionalIncludeDirectories) 123 | stdcpp17 124 | 125 | 126 | Console 127 | true 128 | true 129 | true 130 | C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.4\lib\x64;C:\TensorRT-8.6.0.12\lib;%(AdditionalLibraryDirectories) 131 | nvinfer.lib;nvinfer_plugin.lib;nvonnxparser.lib;nvparsers.lib;cublas.lib;cuda.lib;cudart.lib;cudnn.lib;%(AdditionalDependencies) 132 | 133 | 134 | 135 | 136 | 137 | 138 | 139 | 140 | 141 | 142 | 143 | 144 | 145 | 146 | 147 | 148 | 149 | 150 | 151 | 152 | 153 | 154 | 155 | 156 | -------------------------------------------------------------------------------- /nanosam.vcxproj.filters: -------------------------------------------------------------------------------- 1 |  2 | 3 | 4 | 5 | {4FC737F1-C7A5-4376-A066-2A32D752A2FF} 6 | cpp;c;cc;cxx;c++;cppm;ixx;def;odl;idl;hpj;bat;asm;asmx 7 | 8 | 9 | {93995380-89BD-4b04-88EB-625FBE52EBFB} 10 | h;hh;hpp;hxx;h++;hm;inl;inc;ipp;xsd 11 | 12 | 13 | {67DA6AB6-F800-4c08-8B7A-83BB121AAD01} 14 | rc;ico;cur;bmp;dlg;rc2;rct;bin;rgs;gif;jpg;jpeg;jpe;resx;tiff;tif;png;wav;mfcribbon-ms 15 | 16 | 17 | {9485532c-d9ca-4d0a-af13-e349120532da} 18 | 19 | 20 | 21 | 22 | Source Files 23 | 24 | 25 | nanosam 26 | 27 | 28 | nanosam 29 | 30 | 31 | 32 | 33 | Header Files 34 | 35 | 36 | nanosam 37 | 38 | 39 | nanosam 40 | 41 | 42 | nanosam 43 | 44 | 45 | nanosam 46 | 47 | 48 | nanosam 49 | 50 | 51 | nanosam 52 | 53 | 54 | -------------------------------------------------------------------------------- /nanosam.vcxproj.user: -------------------------------------------------------------------------------- 1 |  2 | 3 | 4 | -------------------------------------------------------------------------------- /nanosam/config.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #define USE_FP16 // set USE_FP16 or USE_FP32 4 | 5 | #define MAX_NUM_PROMPTS 1 6 | 7 | // Model Params 8 | #define MODEL_INPUT_WIDTH 1024.0f 9 | #define MODEL_INPUT_HEIGHT 1024.0f 10 | #define HIDDEN_DIM 256 11 | #define NUM_LABELS 4 12 
| #define FEATURE_WIDTH 64 13 | #define FEATURE_HEIGHT 64 14 | -------------------------------------------------------------------------------- /nanosam/cuda_utils.h: -------------------------------------------------------------------------------- 1 | #ifndef TRTX_CUDA_UTILS_H_ 2 | #define TRTX_CUDA_UTILS_H_ 3 | 4 | #include 5 | 6 | #ifndef CUDA_CHECK 7 | #define CUDA_CHECK(callstr)\ 8 | {\ 9 | cudaError_t error_code = callstr;\ 10 | if (error_code != cudaSuccess) {\ 11 | std::cerr << "CUDA error " << error_code << " at " << __FILE__ << ":" << __LINE__;\ 12 | assert(0);\ 13 | }\ 14 | } 15 | #endif // CUDA_CHECK 16 | 17 | #define CHECK_RETURN_W_MSG(status, val, errMsg) \ 18 | do \ 19 | { \ 20 | if (!(status)) \ 21 | { \ 22 | sample::gLogError << errMsg << " Error in " << __FILE__ << ", function " << FN_NAME << "(), line " << __LINE__ \ 23 | << std::endl; \ 24 | return val; \ 25 | } \ 26 | } while (0) 27 | 28 | 29 | #endif // TRTX_CUDA_UTILS_H_ 30 | 31 | -------------------------------------------------------------------------------- /nanosam/logging.h: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 
15 | */ 16 | 17 | #ifndef TENSORRT_LOGGING_H 18 | #define TENSORRT_LOGGING_H 19 | 20 | #include "NvInferRuntimeCommon.h" 21 | #include 22 | #include 23 | #include 24 | #include 25 | #include 26 | #include 27 | #include 28 | #include "macros.h" 29 | 30 | using Severity = nvinfer1::ILogger::Severity; 31 | 32 | class LogStreamConsumerBuffer : public std::stringbuf 33 | { 34 | public: 35 | LogStreamConsumerBuffer(std::ostream& stream, const std::string& prefix, bool shouldLog) 36 | : mOutput(stream) 37 | , mPrefix(prefix) 38 | , mShouldLog(shouldLog) 39 | { 40 | } 41 | 42 | LogStreamConsumerBuffer(LogStreamConsumerBuffer&& other) 43 | : mOutput(other.mOutput) 44 | { 45 | } 46 | 47 | ~LogStreamConsumerBuffer() 48 | { 49 | // std::streambuf::pbase() gives a pointer to the beginning of the buffered part of the output sequence 50 | // std::streambuf::pptr() gives a pointer to the current position of the output sequence 51 | // if the pointer to the beginning is not equal to the pointer to the current position, 52 | // call putOutput() to log the output to the stream 53 | if (pbase() != pptr()) 54 | { 55 | putOutput(); 56 | } 57 | } 58 | 59 | // synchronizes the stream buffer and returns 0 on success 60 | // synchronizing the stream buffer consists of inserting the buffer contents into the stream, 61 | // resetting the buffer and flushing the stream 62 | virtual int sync() 63 | { 64 | putOutput(); 65 | return 0; 66 | } 67 | 68 | void putOutput() 69 | { 70 | if (mShouldLog) 71 | { 72 | // prepend timestamp 73 | std::time_t timestamp = std::time(nullptr); 74 | tm* tm_local = std::localtime(×tamp); 75 | std::cout << "["; 76 | std::cout << std::setw(2) << std::setfill('0') << 1 + tm_local->tm_mon << "/"; 77 | std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_mday << "/"; 78 | std::cout << std::setw(4) << std::setfill('0') << 1900 + tm_local->tm_year << "-"; 79 | std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_hour << ":"; 80 | std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_min << ":"; 81 | std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_sec << "] "; 82 | // std::stringbuf::str() gets the string contents of the buffer 83 | // insert the buffer contents pre-appended by the appropriate prefix into the stream 84 | mOutput << mPrefix << str(); 85 | // set the buffer to empty 86 | str(""); 87 | // flush the stream 88 | mOutput.flush(); 89 | } 90 | } 91 | 92 | void setShouldLog(bool shouldLog) 93 | { 94 | mShouldLog = shouldLog; 95 | } 96 | 97 | private: 98 | std::ostream& mOutput; 99 | std::string mPrefix; 100 | bool mShouldLog; 101 | }; 102 | 103 | //! 104 | //! \class LogStreamConsumerBase 105 | //! \brief Convenience object used to initialize LogStreamConsumerBuffer before std::ostream in LogStreamConsumer 106 | //! 107 | class LogStreamConsumerBase 108 | { 109 | public: 110 | LogStreamConsumerBase(std::ostream& stream, const std::string& prefix, bool shouldLog) 111 | : mBuffer(stream, prefix, shouldLog) 112 | { 113 | } 114 | 115 | protected: 116 | LogStreamConsumerBuffer mBuffer; 117 | }; 118 | 119 | //! 120 | //! \class LogStreamConsumer 121 | //! \brief Convenience object used to facilitate use of C++ stream syntax when logging messages. 122 | //! Order of base classes is LogStreamConsumerBase and then std::ostream. 123 | //! This is because the LogStreamConsumerBase class is used to initialize the LogStreamConsumerBuffer member field 124 | //! in LogStreamConsumer and then the address of the buffer is passed to std::ostream. 125 | //! 
This is necessary to prevent the address of an uninitialized buffer from being passed to std::ostream. 126 | //! Please do not change the order of the parent classes. 127 | //! 128 | class LogStreamConsumer : protected LogStreamConsumerBase, public std::ostream 129 | { 130 | public: 131 | //! \brief Creates a LogStreamConsumer which logs messages with level severity. 132 | //! Reportable severity determines if the messages are severe enough to be logged. 133 | LogStreamConsumer(Severity reportableSeverity, Severity severity) 134 | : LogStreamConsumerBase(severityOstream(severity), severityPrefix(severity), severity <= reportableSeverity) 135 | , std::ostream(&mBuffer) // links the stream buffer with the stream 136 | , mShouldLog(severity <= reportableSeverity) 137 | , mSeverity(severity) 138 | { 139 | } 140 | 141 | LogStreamConsumer(LogStreamConsumer&& other) 142 | : LogStreamConsumerBase(severityOstream(other.mSeverity), severityPrefix(other.mSeverity), other.mShouldLog) 143 | , std::ostream(&mBuffer) // links the stream buffer with the stream 144 | , mShouldLog(other.mShouldLog) 145 | , mSeverity(other.mSeverity) 146 | { 147 | } 148 | 149 | void setReportableSeverity(Severity reportableSeverity) 150 | { 151 | mShouldLog = mSeverity <= reportableSeverity; 152 | mBuffer.setShouldLog(mShouldLog); 153 | } 154 | 155 | private: 156 | static std::ostream& severityOstream(Severity severity) 157 | { 158 | return severity >= Severity::kINFO ? std::cout : std::cerr; 159 | } 160 | 161 | static std::string severityPrefix(Severity severity) 162 | { 163 | switch (severity) 164 | { 165 | case Severity::kINTERNAL_ERROR: return "[F] "; 166 | case Severity::kERROR: return "[E] "; 167 | case Severity::kWARNING: return "[W] "; 168 | case Severity::kINFO: return "[I] "; 169 | case Severity::kVERBOSE: return "[V] "; 170 | default: assert(0); return ""; 171 | } 172 | } 173 | 174 | bool mShouldLog; 175 | Severity mSeverity; 176 | }; 177 | 178 | //! \class Logger 179 | //! 180 | //! \brief Class which manages logging of TensorRT tools and samples 181 | //! 182 | //! \details This class provides a common interface for TensorRT tools and samples to log information to the console, 183 | //! and supports logging two types of messages: 184 | //! 185 | //! - Debugging messages with an associated severity (info, warning, error, or internal error/fatal) 186 | //! - Test pass/fail messages 187 | //! 188 | //! The advantage of having all samples use this class for logging as opposed to emitting directly to stdout/stderr is 189 | //! that the logic for controlling the verbosity and formatting of sample output is centralized in one location. 190 | //! 191 | //! In the future, this class could be extended to support dumping test results to a file in some standard format 192 | //! (for example, JUnit XML), and providing additional metadata (e.g. timing the duration of a test run). 193 | //! 194 | //! TODO: For backwards compatibility with existing samples, this class inherits directly from the nvinfer1::ILogger 195 | //! interface, which is problematic since there isn't a clean separation between messages coming from the TensorRT 196 | //! library and messages coming from the sample. 197 | //! 198 | //! In the future (once all samples are updated to use Logger::getTRTLogger() to access the ILogger) we can refactor the 199 | //! class to eliminate the inheritance and instead make the nvinfer1::ILogger implementation a member of the Logger 200 | //! object. 
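//
// Usage note (added remark, not part of the original NVIDIA header): this project
// creates a single instance of the Logger defined below in trt_module.cpp and hands
// it to the TensorRT factory functions, roughly like this:
//
//     static Logger gLogger;                           // defaults to Severity::kWARNING
//     auto builder = nvinfer1::createInferBuilder(gLogger);
//     auto runtime = nvinfer1::createInferRuntime(gLogger);
//     gLogger.setReportableSeverity(Severity::kINFO);  // also show INFO messages
//
// Messages below the reportable severity are filtered out by LogStreamConsumer.
//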
201 | 202 | class Logger : public nvinfer1::ILogger 203 | { 204 | public: 205 | Logger(Severity severity = Severity::kWARNING) 206 | : mReportableSeverity(severity) 207 | { 208 | } 209 | 210 | //! 211 | //! \enum TestResult 212 | //! \brief Represents the state of a given test 213 | //! 214 | enum class TestResult 215 | { 216 | kRUNNING, //!< The test is running 217 | kPASSED, //!< The test passed 218 | kFAILED, //!< The test failed 219 | kWAIVED //!< The test was waived 220 | }; 221 | 222 | //! 223 | //! \brief Forward-compatible method for retrieving the nvinfer::ILogger associated with this Logger 224 | //! \return The nvinfer1::ILogger associated with this Logger 225 | //! 226 | //! TODO Once all samples are updated to use this method to register the logger with TensorRT, 227 | //! we can eliminate the inheritance of Logger from ILogger 228 | //! 229 | nvinfer1::ILogger& getTRTLogger() 230 | { 231 | return *this; 232 | } 233 | 234 | //! 235 | //! \brief Implementation of the nvinfer1::ILogger::log() virtual method 236 | //! 237 | //! Note samples should not be calling this function directly; it will eventually go away once we eliminate the 238 | //! inheritance from nvinfer1::ILogger 239 | //! 240 | void log(Severity severity, const char* msg) TRT_NOEXCEPT override 241 | { 242 | LogStreamConsumer(mReportableSeverity, severity) << "[TRT] " << std::string(msg) << std::endl; 243 | } 244 | 245 | //! 246 | //! \brief Method for controlling the verbosity of logging output 247 | //! 248 | //! \param severity The logger will only emit messages that have severity of this level or higher. 249 | //! 250 | void setReportableSeverity(Severity severity) 251 | { 252 | mReportableSeverity = severity; 253 | } 254 | 255 | //! 256 | //! \brief Opaque handle that holds logging information for a particular test 257 | //! 258 | //! This object is an opaque handle to information used by the Logger to print test results. 259 | //! The sample must call Logger::defineTest() in order to obtain a TestAtom that can be used 260 | //! with Logger::reportTest{Start,End}(). 261 | //! 262 | class TestAtom 263 | { 264 | public: 265 | TestAtom(TestAtom&&) = default; 266 | 267 | private: 268 | friend class Logger; 269 | 270 | TestAtom(bool started, const std::string& name, const std::string& cmdline) 271 | : mStarted(started) 272 | , mName(name) 273 | , mCmdline(cmdline) 274 | { 275 | } 276 | 277 | bool mStarted; 278 | std::string mName; 279 | std::string mCmdline; 280 | }; 281 | 282 | //! 283 | //! \brief Define a test for logging 284 | //! 285 | //! \param[in] name The name of the test. This should be a string starting with 286 | //! "TensorRT" and containing dot-separated strings containing 287 | //! the characters [A-Za-z0-9_]. 288 | //! For example, "TensorRT.sample_googlenet" 289 | //! \param[in] cmdline The command line used to reproduce the test 290 | // 291 | //! \return a TestAtom that can be used in Logger::reportTest{Start,End}(). 292 | //! 293 | static TestAtom defineTest(const std::string& name, const std::string& cmdline) 294 | { 295 | return TestAtom(false, name, cmdline); 296 | } 297 | 298 | //! 299 | //! \brief A convenience overloaded version of defineTest() that accepts an array of command-line arguments 300 | //! as input 301 | //! 302 | //! \param[in] name The name of the test 303 | //! \param[in] argc The number of command-line arguments 304 | //! \param[in] argv The array of command-line arguments (given as C strings) 305 | //! 306 | //! 
\return a TestAtom that can be used in Logger::reportTest{Start,End}(). 307 | static TestAtom defineTest(const std::string& name, int argc, char const* const* argv) 308 | { 309 | auto cmdline = genCmdlineString(argc, argv); 310 | return defineTest(name, cmdline); 311 | } 312 | 313 | //! 314 | //! \brief Report that a test has started. 315 | //! 316 | //! \pre reportTestStart() has not been called yet for the given testAtom 317 | //! 318 | //! \param[in] testAtom The handle to the test that has started 319 | //! 320 | static void reportTestStart(TestAtom& testAtom) 321 | { 322 | reportTestResult(testAtom, TestResult::kRUNNING); 323 | assert(!testAtom.mStarted); 324 | testAtom.mStarted = true; 325 | } 326 | 327 | //! 328 | //! \brief Report that a test has ended. 329 | //! 330 | //! \pre reportTestStart() has been called for the given testAtom 331 | //! 332 | //! \param[in] testAtom The handle to the test that has ended 333 | //! \param[in] result The result of the test. Should be one of TestResult::kPASSED, 334 | //! TestResult::kFAILED, TestResult::kWAIVED 335 | //! 336 | static void reportTestEnd(const TestAtom& testAtom, TestResult result) 337 | { 338 | assert(result != TestResult::kRUNNING); 339 | assert(testAtom.mStarted); 340 | reportTestResult(testAtom, result); 341 | } 342 | 343 | static int reportPass(const TestAtom& testAtom) 344 | { 345 | reportTestEnd(testAtom, TestResult::kPASSED); 346 | return EXIT_SUCCESS; 347 | } 348 | 349 | static int reportFail(const TestAtom& testAtom) 350 | { 351 | reportTestEnd(testAtom, TestResult::kFAILED); 352 | return EXIT_FAILURE; 353 | } 354 | 355 | static int reportWaive(const TestAtom& testAtom) 356 | { 357 | reportTestEnd(testAtom, TestResult::kWAIVED); 358 | return EXIT_SUCCESS; 359 | } 360 | 361 | static int reportTest(const TestAtom& testAtom, bool pass) 362 | { 363 | return pass ? reportPass(testAtom) : reportFail(testAtom); 364 | } 365 | 366 | Severity getReportableSeverity() const 367 | { 368 | return mReportableSeverity; 369 | } 370 | 371 | private: 372 | //! 373 | //! \brief returns an appropriate string for prefixing a log message with the given severity 374 | //! 375 | static const char* severityPrefix(Severity severity) 376 | { 377 | switch (severity) 378 | { 379 | case Severity::kINTERNAL_ERROR: return "[F] "; 380 | case Severity::kERROR: return "[E] "; 381 | case Severity::kWARNING: return "[W] "; 382 | case Severity::kINFO: return "[I] "; 383 | case Severity::kVERBOSE: return "[V] "; 384 | default: assert(0); return ""; 385 | } 386 | } 387 | 388 | //! 389 | //! \brief returns an appropriate string for prefixing a test result message with the given result 390 | //! 391 | static const char* testResultString(TestResult result) 392 | { 393 | switch (result) 394 | { 395 | case TestResult::kRUNNING: return "RUNNING"; 396 | case TestResult::kPASSED: return "PASSED"; 397 | case TestResult::kFAILED: return "FAILED"; 398 | case TestResult::kWAIVED: return "WAIVED"; 399 | default: assert(0); return ""; 400 | } 401 | } 402 | 403 | //! 404 | //! \brief returns an appropriate output stream (cout or cerr) to use with the given severity 405 | //! 406 | static std::ostream& severityOstream(Severity severity) 407 | { 408 | return severity >= Severity::kINFO ? std::cout : std::cerr; 409 | } 410 | 411 | //! 412 | //! \brief method that implements logging test results 413 | //! 
414 | static void reportTestResult(const TestAtom& testAtom, TestResult result) 415 | { 416 | severityOstream(Severity::kINFO) << "&&&& " << testResultString(result) << " " << testAtom.mName << " # " 417 | << testAtom.mCmdline << std::endl; 418 | } 419 | 420 | //! 421 | //! \brief generate a command line string from the given (argc, argv) values 422 | //! 423 | static std::string genCmdlineString(int argc, char const* const* argv) 424 | { 425 | std::stringstream ss; 426 | for (int i = 0; i < argc; i++) 427 | { 428 | if (i > 0) 429 | ss << " "; 430 | ss << argv[i]; 431 | } 432 | return ss.str(); 433 | } 434 | 435 | Severity mReportableSeverity; 436 | }; 437 | 438 | namespace 439 | { 440 | 441 | //! 442 | //! \brief produces a LogStreamConsumer object that can be used to log messages of severity kVERBOSE 443 | //! 444 | //! Example usage: 445 | //! 446 | //! LOG_VERBOSE(logger) << "hello world" << std::endl; 447 | //! 448 | inline LogStreamConsumer LOG_VERBOSE(const Logger& logger) 449 | { 450 | return LogStreamConsumer(logger.getReportableSeverity(), Severity::kVERBOSE); 451 | } 452 | 453 | //! 454 | //! \brief produces a LogStreamConsumer object that can be used to log messages of severity kINFO 455 | //! 456 | //! Example usage: 457 | //! 458 | //! LOG_INFO(logger) << "hello world" << std::endl; 459 | //! 460 | inline LogStreamConsumer LOG_INFO(const Logger& logger) 461 | { 462 | return LogStreamConsumer(logger.getReportableSeverity(), Severity::kINFO); 463 | } 464 | 465 | //! 466 | //! \brief produces a LogStreamConsumer object that can be used to log messages of severity kWARNING 467 | //! 468 | //! Example usage: 469 | //! 470 | //! LOG_WARN(logger) << "hello world" << std::endl; 471 | //! 472 | inline LogStreamConsumer LOG_WARN(const Logger& logger) 473 | { 474 | return LogStreamConsumer(logger.getReportableSeverity(), Severity::kWARNING); 475 | } 476 | 477 | //! 478 | //! \brief produces a LogStreamConsumer object that can be used to log messages of severity kERROR 479 | //! 480 | //! Example usage: 481 | //! 482 | //! LOG_ERROR(logger) << "hello world" << std::endl; 483 | //! 484 | inline LogStreamConsumer LOG_ERROR(const Logger& logger) 485 | { 486 | return LogStreamConsumer(logger.getReportableSeverity(), Severity::kERROR); 487 | } 488 | 489 | //! 490 | //! \brief produces a LogStreamConsumer object that can be used to log messages of severity kINTERNAL_ERROR 491 | // ("fatal" severity) 492 | //! 493 | //! Example usage: 494 | //! 495 | //! LOG_FATAL(logger) << "hello world" << std::endl; 496 | //! 
497 | inline LogStreamConsumer LOG_FATAL(const Logger& logger) 498 | { 499 | return LogStreamConsumer(logger.getReportableSeverity(), Severity::kINTERNAL_ERROR); 500 | } 501 | 502 | } // anonymous namespace 503 | 504 | #endif // TENSORRT_LOGGING_H 505 | -------------------------------------------------------------------------------- /nanosam/macros.h: -------------------------------------------------------------------------------- 1 | #ifndef __MACROS_H 2 | #define __MACROS_H 3 | 4 | #ifdef API_EXPORTS 5 | #if defined(_MSC_VER) 6 | #define API __declspec(dllexport) 7 | #else 8 | #define API __attribute__((visibility("default"))) 9 | #endif 10 | #else 11 | 12 | #if defined(_MSC_VER) 13 | #define API __declspec(dllimport) 14 | #else 15 | #define API 16 | #endif 17 | #endif // API_EXPORTS 18 | 19 | #if NV_TENSORRT_MAJOR >= 8 20 | #define TRT_NOEXCEPT noexcept 21 | #define TRT_CONST_ENQUEUE const 22 | #else 23 | #define TRT_NOEXCEPT 24 | #define TRT_CONST_ENQUEUE 25 | #endif 26 | 27 | #endif // __MACROS_H 28 | -------------------------------------------------------------------------------- /nanosam/nanosam.cpp: -------------------------------------------------------------------------------- 1 | #include "nanosam.h" 2 | #include "config.h" 3 | 4 | using namespace std; 5 | 6 | // Constructor 7 | NanoSam::NanoSam(string encoderPath, string decoderPath) 8 | { 9 | mImageEncoder = new TRTModule(encoderPath, 10 | { "image" }, 11 | { "image_embeddings" }, false, true); 12 | 13 | mMaskDecoder = new TRTModule(decoderPath, 14 | { "image_embeddings", "point_coords", "point_labels", "mask_input", "has_mask_input" }, 15 | { "iou_predictions", "low_res_masks" }, true, false); 16 | 17 | mFeatures = new float[HIDDEN_DIM * FEATURE_WIDTH * FEATURE_HEIGHT]; 18 | mMaskInput = new float[HIDDEN_DIM * HIDDEN_DIM]; 19 | mHasMaskInput = new float; 20 | mIouPrediction = new float[NUM_LABELS]; 21 | mLowResMasks = new float[NUM_LABELS * HIDDEN_DIM * HIDDEN_DIM]; 22 | } 23 | 24 | // Deconstructor 25 | NanoSam::~NanoSam() 26 | { 27 | if (mFeatures) delete[] mFeatures; 28 | if (mMaskInput) delete[] mMaskInput; 29 | if (mIouPrediction) delete[] mIouPrediction; 30 | if (mLowResMasks) delete[] mLowResMasks; 31 | 32 | if (mImageEncoder) delete mImageEncoder; 33 | if (mMaskDecoder) delete mMaskDecoder; 34 | } 35 | 36 | // Perform inference using NanoSam models 37 | Mat NanoSam::predict(Mat& image, vector points, vector labels) 38 | { 39 | if (points.size() == 0) return cv::Mat(image.rows, image.cols, CV_32FC1); 40 | 41 | // Preprocess encoder input 42 | auto resizedImage = resizeImage(image, MODEL_INPUT_WIDTH, MODEL_INPUT_HEIGHT); 43 | 44 | // Encoder Inference 45 | mImageEncoder->setInput(resizedImage); 46 | mImageEncoder->infer(); 47 | mImageEncoder->getOutput(mFeatures); 48 | 49 | // Preprocess decoder input 50 | auto pointData = new float[2 * points.size()]; 51 | prepareDecoderInput(points, pointData, points.size(), image.cols, image.rows); 52 | 53 | // Decoder Inference 54 | mMaskDecoder->setInput(mFeatures, pointData, labels.data(), mMaskInput, mHasMaskInput, points.size()); 55 | mMaskDecoder->infer(); 56 | mMaskDecoder->getOutput(mIouPrediction, mLowResMasks); 57 | 58 | // Postprocessing 59 | Mat imgMask(HIDDEN_DIM, HIDDEN_DIM, CV_32FC1, mLowResMasks); 60 | upscaleMask(imgMask, image.cols, image.rows); 61 | 62 | delete[] pointData; 63 | 64 | return imgMask; 65 | } 66 | 67 | void NanoSam::prepareDecoderInput(vector& points, float* pointData, int numPoints, int imageWidth, int imageHeight) 68 | { 69 | float scale = 
MODEL_INPUT_WIDTH / max(imageWidth, imageHeight); 70 | 71 | for (int i = 0; i < numPoints; i++) 72 | { 73 | pointData[i * 2] = (float)points[i].x * scale; 74 | pointData[i * 2 + 1] = (float)points[i].y * scale; 75 | } 76 | 77 | for (int i = 0; i < HIDDEN_DIM * HIDDEN_DIM; i++) 78 | { 79 | mMaskInput[i] = 0; 80 | } 81 | *mHasMaskInput = 0; 82 | } 83 | 84 | Mat NanoSam::resizeImage(Mat& img, int inputWidth, int inputHeight) 85 | { 86 | int w, h; 87 | float aspectRatio = (float)img.cols / (float)img.rows; 88 | 89 | if (aspectRatio >= 1) 90 | { 91 | w = inputWidth; 92 | h = int(inputHeight / aspectRatio); 93 | } 94 | else 95 | { 96 | w = int(inputWidth * aspectRatio); 97 | h = inputHeight; 98 | } 99 | 100 | Mat re(h, w, CV_8UC3); 101 | cv::resize(img, re, re.size(), 0, 0, INTER_LINEAR); 102 | Mat out(inputHeight, inputWidth, CV_8UC3, 0.0); 103 | re.copyTo(out(Rect(0, 0, re.cols, re.rows))); 104 | 105 | return out; 106 | } 107 | 108 | void NanoSam::upscaleMask(Mat& mask, int targetWidth, int targetHeight, int size) 109 | { 110 | int limX, limY; 111 | if (targetWidth > targetHeight) 112 | { 113 | limX = size; 114 | limY = size * targetHeight / targetWidth; 115 | } 116 | else 117 | { 118 | limX = size * targetWidth / targetHeight; 119 | limY = size; 120 | } 121 | 122 | cv::resize(mask(Rect(0, 0, limX, limY)), mask, Size(targetWidth, targetHeight)); 123 | } 124 | -------------------------------------------------------------------------------- /nanosam/nanosam.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include "trt_module.h" 5 | 6 | class NanoSam 7 | { 8 | 9 | public: 10 | 11 | NanoSam(string encoderPath, string decoderPath); 12 | 13 | ~NanoSam(); 14 | 15 | Mat predict(Mat& image, vector points, vector labels); 16 | 17 | private: 18 | 19 | // Variables 20 | float* mFeatures; 21 | float* mMaskInput; 22 | float* mHasMaskInput; 23 | float* mIouPrediction; 24 | float* mLowResMasks; 25 | 26 | TRTModule* mImageEncoder; 27 | TRTModule* mMaskDecoder; 28 | 29 | void upscaleMask(Mat& mask, int targetWidth, int targetHeight, int size = 256); 30 | Mat resizeImage(Mat& img, int modelWidth, int modelHeight); 31 | void prepareDecoderInput(vector& points, float* pointData, int numPoints, int imageWidth, int imageHeight); 32 | 33 | }; 34 | 35 | -------------------------------------------------------------------------------- /nanosam/trt_module.cpp: -------------------------------------------------------------------------------- 1 | #include "trt_module.h" 2 | #include "logging.h" 3 | #include "cuda_utils.h" 4 | #include "config.h" 5 | #include "macros.h" 6 | 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | static Logger gLogger; 15 | 16 | 17 | std::string getFileExtension(const std::string& filePath) { 18 | size_t dotPos = filePath.find_last_of("."); 19 | if (dotPos != std::string::npos) { 20 | return filePath.substr(dotPos + 1); 21 | } 22 | return ""; // No extension found 23 | } 24 | 25 | TRTModule::TRTModule(string modelPath, vector inputNames, vector outputNames, bool isDynamicShape, bool isFP16) 26 | { 27 | if (getFileExtension(modelPath) == "onnx") 28 | { 29 | cout << "Building Engine from " << modelPath << endl; 30 | build(modelPath, inputNames, outputNames, isDynamicShape, isFP16); 31 | } 32 | else 33 | { 34 | cout << "Deserializing Engine." 
<< endl; 35 | deserializeEngine(modelPath, inputNames, outputNames); 36 | } 37 | } 38 | 39 | TRTModule::~TRTModule() 40 | { 41 | // Release stream and buffers 42 | cudaStreamDestroy(mCudaStream); 43 | for (int i = 0; i < mGpuBuffers.size(); i++) 44 | CUDA_CHECK(cudaFree(mGpuBuffers[i])); 45 | for (int i = 0; i < mCpuBuffers.size(); i++) 46 | delete[] mCpuBuffers[i]; 47 | 48 | // Destroy the engine 49 | delete mContext; 50 | delete mEngine; 51 | delete mRuntime; 52 | } 53 | 54 | void TRTModule::build(string onnxPath, vector inputNames, vector outputNames, bool isDynamicShape, bool isFP16) 55 | { 56 | auto builder = createInferBuilder(gLogger); 57 | assert(builder != nullptr); 58 | 59 | const auto explicitBatch = 1U << static_cast(NetworkDefinitionCreationFlag::kEXPLICIT_BATCH); 60 | INetworkDefinition* network = builder->createNetworkV2(explicitBatch); 61 | assert(network != nullptr); 62 | 63 | IBuilderConfig* config = builder->createBuilderConfig(); 64 | assert(config != nullptr); 65 | 66 | if (isDynamicShape) // Only designed for NanoSAM mask decoder 67 | { 68 | auto profile = builder->createOptimizationProfile(); 69 | 70 | profile->setDimensions(inputNames[1].c_str(), OptProfileSelector::kMIN, Dims3{ 1, 1, 2 }); 71 | profile->setDimensions(inputNames[1].c_str(), OptProfileSelector::kOPT, Dims3{ 1, 1, 2 }); 72 | profile->setDimensions(inputNames[1].c_str(), OptProfileSelector::kMAX, Dims3{ 1, 10, 2 }); 73 | 74 | profile->setDimensions(inputNames[2].c_str(), OptProfileSelector::kMIN, Dims2{ 1, 1 }); 75 | profile->setDimensions(inputNames[2].c_str(), OptProfileSelector::kOPT, Dims2{ 1, 1 }); 76 | profile->setDimensions(inputNames[2].c_str(), OptProfileSelector::kMAX, Dims2{ 1, 10 }); 77 | 78 | config->addOptimizationProfile(profile); 79 | } 80 | 81 | if (isFP16) 82 | { 83 | config->setFlag(BuilderFlag::kFP16); 84 | } 85 | 86 | nvonnxparser::IParser* parser = nvonnxparser::createParser(*network, gLogger); 87 | assert(parser != nullptr); 88 | 89 | bool parsed = parser->parseFromFile(onnxPath.c_str(), static_cast(gLogger.getReportableSeverity())); 90 | assert(parsed != nullptrt); 91 | 92 | 93 | // CUDA stream used for profiling by the builder. 94 | assert(mCudaStream != nullptr); 95 | 96 | IHostMemory* plan{ builder->buildSerializedNetwork(*network, *config) }; 97 | assert(plan != nullptr); 98 | 99 | mRuntime = createInferRuntime(gLogger); 100 | assert(mRuntime != nullptr); 101 | 102 | mEngine = mRuntime->deserializeCudaEngine(plan->data(), plan->size(), nullptr); 103 | assert(mEngine != nullptr); 104 | 105 | mContext = mEngine->createExecutionContext(); 106 | assert(mContext != nullptr); 107 | 108 | delete network; 109 | delete config; 110 | delete parser; 111 | delete plan; 112 | 113 | initialize(inputNames, outputNames); 114 | } 115 | 116 | void TRTModule::deserializeEngine(string engine_name, vector inputNames, vector outputNames) 117 | { 118 | std::ifstream file(engine_name, std::ios::binary); 119 | if (!file.good()) { 120 | std::cerr << "read " << engine_name << " error!" 
<< std::endl; 121 | assert(false); 122 | } 123 | size_t size = 0; 124 | file.seekg(0, file.end); 125 | size = file.tellg(); 126 | file.seekg(0, file.beg); 127 | char* serializedEngine = new char[size]; 128 | assert(serializedEngine); 129 | file.read(serializedEngine, size); 130 | file.close(); 131 | 132 | mRuntime = createInferRuntime(gLogger); 133 | assert(mRuntime); 134 | mEngine = mRuntime->deserializeCudaEngine(serializedEngine, size); 135 | assert(*mEngine); 136 | mContext = mEngine->createExecutionContext(); 137 | assert(*mContext); 138 | delete[] serializedEngine; 139 | 140 | assert(mEngine->getNbBindings() != inputNames.size() + outputNames.size()); 141 | 142 | initialize(inputNames, outputNames); 143 | } 144 | 145 | void TRTModule::initialize(vector inputNames, vector outputNames) 146 | { 147 | for (int i = 0; i < inputNames.size(); i++) 148 | { 149 | const int inputIndex = mEngine->getBindingIndex(inputNames[i].c_str()); 150 | } 151 | 152 | for (int i = 0; i < outputNames.size(); i++) 153 | { 154 | const int outputIndex = mEngine->getBindingIndex(outputNames[i].c_str()); 155 | } 156 | 157 | mGpuBuffers.resize(mEngine->getNbBindings()); 158 | mCpuBuffers.resize(mEngine->getNbBindings()); 159 | 160 | for (size_t i = 0; i < mEngine->getNbBindings(); ++i) 161 | { 162 | size_t binding_size = getSizeByDim(mEngine->getBindingDimensions(i)); 163 | mBufferBindingSizes.push_back(binding_size); 164 | mBufferBindingBytes.push_back(binding_size * sizeof(float)); 165 | 166 | mCpuBuffers[i] = new float[binding_size]; 167 | 168 | cudaMalloc(&mGpuBuffers[i], mBufferBindingBytes[i]); 169 | 170 | if (mEngine->bindingIsInput(i)) 171 | { 172 | mInputDims.push_back(mEngine->getBindingDimensions(i)); 173 | } 174 | else 175 | { 176 | mOutputDims.push_back(mEngine->getBindingDimensions(i)); 177 | } 178 | } 179 | 180 | CUDA_CHECK(cudaStreamCreate(&mCudaStream)); 181 | } 182 | 183 | //! 184 | //! \brief Runs the TensorRT inference engine for this sample 185 | //! 186 | //! \details This function is the main execution function of the sample. It allocates the buffer, 187 | //! sets inputs and executes the engine. 188 | //! 189 | bool TRTModule::infer() 190 | { 191 | // Memcpy from host input buffers to device input buffers 192 | copyInputToDeviceAsync(mCudaStream); 193 | 194 | bool status = mContext->executeV2(mGpuBuffers.data()); 195 | 196 | if (!status) 197 | { 198 | cout << "inference error!" << endl; 199 | return false; 200 | } 201 | 202 | // Memcpy from device output buffers to host output buffers 203 | copyOutputToHostAsync(mCudaStream); 204 | 205 | return true; 206 | } 207 | 208 | //! 209 | //! \brief Copy the contents of input host buffers to input device buffers asynchronously. 210 | //! 211 | void TRTModule::copyInputToDeviceAsync(const cudaStream_t& stream) 212 | { 213 | memcpyBuffers(true, false, true, stream); 214 | } 215 | 216 | //! 217 | //! \brief Copy the contents of output device buffers to output host buffers asynchronously. 218 | //! 219 | void TRTModule::copyOutputToHostAsync(const cudaStream_t& stream) 220 | { 221 | memcpyBuffers(false, true, true, stream); 222 | } 223 | 224 | void TRTModule::memcpyBuffers(const bool copyInput, const bool deviceToHost, const bool async, const cudaStream_t& stream) 225 | { 226 | for (int i = 0; i < mEngine->getNbBindings(); i++) 227 | { 228 | void* dstPtr = deviceToHost ? mCpuBuffers[i] : mGpuBuffers[i]; 229 | const void* srcPtr = deviceToHost ? 
230 |         const size_t byteSize = mBufferBindingBytes[i];
231 |         const cudaMemcpyKind memcpyType = deviceToHost ? cudaMemcpyDeviceToHost : cudaMemcpyHostToDevice;
232 | 
233 |         if ((copyInput && mEngine->bindingIsInput(i)) || (!copyInput && !mEngine->bindingIsInput(i)))
234 |         {
235 |             if (async)
236 |             {
237 |                 CUDA_CHECK(cudaMemcpyAsync(dstPtr, srcPtr, byteSize, memcpyType, stream));
238 |             }
239 |             else
240 |             {
241 |                 CUDA_CHECK(cudaMemcpy(dstPtr, srcPtr, byteSize, memcpyType));
242 |             }
243 |         }
244 |     }
245 | }
246 | 
247 | size_t TRTModule::getSizeByDim(const Dims& dims)
248 | {
249 |     size_t size = 1;
250 | 
251 |     for (int i = 0; i < dims.nbDims; ++i)
252 |     {
253 |         if (dims.d[i] == -1)
254 |             size *= MAX_NUM_PROMPTS; // dynamic dimension: reserve space for the maximum number of point prompts
255 |         else
256 |             size *= dims.d[i];
257 |     }
258 | 
259 |     return size;
260 | }
261 | 
262 | void TRTModule::setInput(Mat& image)
263 | {
264 |     const int inputH = mInputDims[0].d[2];
265 |     const int inputW = mInputDims[0].d[3];
266 |     assert(image.rows == inputH && image.cols == inputW); // the caller is expected to resize the image to the network input size
267 |     int i = 0; // fill the CHW float buffer with ImageNet-normalized RGB values (the input Mat is BGR)
268 |     for (int row = 0; row < image.rows; ++row)
269 |     {
270 |         uchar* uc_pixel = image.data + row * image.step;
271 |         for (int col = 0; col < image.cols; ++col)
272 |         {
273 |             mCpuBuffers[0][i] = ((float)uc_pixel[2] / 255.0f - 0.485f) / 0.229f;
274 |             mCpuBuffers[0][i + image.rows * image.cols] = ((float)uc_pixel[1] / 255.0f - 0.456f) / 0.224f;
275 |             mCpuBuffers[0][i + 2 * image.rows * image.cols] = ((float)uc_pixel[0] / 255.0f - 0.406f) / 0.225f;
276 |             uc_pixel += 3;
277 |             ++i;
278 |         }
279 |     }
280 | }
281 | 
282 | // Set dynamic input (mask decoder)
283 | void TRTModule::setInput(float* features, float* imagePointCoords, float* imagePointLabels, float* maskInput, float* hasMaskInput, int numPoints)
284 | {
285 |     delete[] mCpuBuffers[1];
286 |     delete[] mCpuBuffers[2];
287 |     mCpuBuffers[1] = new float[numPoints * 2];
288 |     mCpuBuffers[2] = new float[numPoints];
289 |     CUDA_CHECK(cudaFree(mGpuBuffers[1])); CUDA_CHECK(cudaFree(mGpuBuffers[2])); // release the previous prompt buffers before reallocating
290 |     CUDA_CHECK(cudaMalloc(&mGpuBuffers[1], sizeof(float) * numPoints * 2));
291 |     CUDA_CHECK(cudaMalloc(&mGpuBuffers[2], sizeof(float) * numPoints));
292 | 
293 |     mBufferBindingBytes[1] = sizeof(float) * numPoints * 2;
294 |     mBufferBindingBytes[2] = sizeof(float) * numPoints;
295 | 
296 |     memcpy(mCpuBuffers[0], features, mBufferBindingBytes[0]);
297 |     memcpy(mCpuBuffers[1], imagePointCoords, sizeof(float) * numPoints * 2);
298 |     memcpy(mCpuBuffers[2], imagePointLabels, sizeof(float) * numPoints);
299 |     memcpy(mCpuBuffers[3], maskInput, mBufferBindingBytes[3]);
300 |     memcpy(mCpuBuffers[4], hasMaskInput, mBufferBindingBytes[4]);
301 | 
302 |     // Set the dynamic input shapes in TensorRT before the next inference
303 |     mContext->setOptimizationProfileAsync(0, mCudaStream);
304 |     mContext->setBindingDimensions(1, Dims3{ 1, numPoints, 2 });
305 |     mContext->setBindingDimensions(2, Dims2{ 1, numPoints });
306 | }
307 | 
308 | void TRTModule::getOutput(float* features)
309 | {
310 |     memcpy(features, mCpuBuffers[1], mBufferBindingBytes[1]); // image encoder: binding 1 holds the image embeddings
311 | }
312 | 
313 | void TRTModule::getOutput(float* iouPrediction, float* lowResolutionMasks)
314 | {
315 |     memcpy(lowResolutionMasks, mCpuBuffers[5], mBufferBindingBytes[5]);
316 |     memcpy(iouPrediction, mCpuBuffers[6], mBufferBindingBytes[6]);
317 | }
318 | 
--------------------------------------------------------------------------------
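
For orientation, the sketch below shows how two TRTModule instances like the ones implemented above are typically combined into the NanoSAM two-stage pipeline: the image encoder turns a preprocessed frame into an embedding, and the mask decoder combines that embedding with point prompts. This is an illustrative sketch only, not a file from this repository; the engine paths, tensor names, and buffer sizes are assumptions (the real values would normally live in config.h and main.cpp, which are not shown here).

#include "nanosam/trt_module.h"

// Hypothetical constants for illustration; the real sizes belong in config.h.
constexpr int FEATURE_SIZE = 256 * 64 * 64; // assumed image-embedding size (1 x 256 x 64 x 64)
constexpr int MASK_AREA    = 256 * 256;     // assumed low-resolution mask resolution

void segmentWithOneClick(Mat& resizedImage, float clickX, float clickY)
{
    // Assumed engine paths and tensor names; the constructor above chooses between
    // building from ONNX and deserializing a prebuilt engine based on the model path.
    TRTModule encoder("resnet18_image_encoder.engine", { "image" }, { "image_embeddings" }, false, true);
    TRTModule decoder("mobile_sam_mask_decoder.engine",
                      { "image_embeddings", "point_coords", "point_labels", "mask_input", "has_mask_input" },
                      { "iou_predictions", "low_res_masks" }, true, true);

    // Stage 1: image -> embedding (the frame must already match the encoder input size).
    vector<float> features(FEATURE_SIZE);
    encoder.setInput(resizedImage);
    encoder.infer();
    encoder.getOutput(features.data());

    // Stage 2: embedding + one foreground click -> IoU scores and low-resolution masks.
    float pointCoords[2] = { clickX, clickY }; // x, y in encoder-input pixel coordinates
    float pointLabels[1] = { 1.0f };           // 1 = foreground point
    vector<float> maskInput(MASK_AREA, 0.0f);  // no previous mask
    float hasMaskInput[1] = { 0.0f };

    vector<float> lowResMasks(4 * MASK_AREA);  // assumed 4 candidate masks
    vector<float> iou(4);
    decoder.setInput(features.data(), pointCoords, pointLabels, maskInput.data(), hasMaskInput, 1);
    decoder.infer();
    decoder.getOutput(iou.data(), lowResMasks.data());
}

In the repository itself this flow is presumably wrapped by the NanoSam class declared in nanosam/nanosam.h; the sketch stays at the raw TRTModule level only to mirror the implementation above.
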
/nanosam/trt_module.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | 
3 | #include "NvInfer.h"
4 | #include <opencv2/opencv.hpp>
5 | 
6 | using namespace nvinfer1;
7 | using namespace std;
8 | using namespace cv;
9 | 
10 | class TRTModule
11 | {
12 | 
13 | public:
14 | 
15 |     TRTModule(string modelPath, vector<string> inputNames, vector<string> outputNames, bool isDynamicShape, bool isFP16);
16 | 
17 |     bool infer();
18 | 
19 |     void setInput(Mat& image);
20 | 
21 |     void setInput(float* features, float* imagePointCoords, float* imagePointLabels, float* maskInput, float* hasMaskInput, int numPoints);
22 | 
23 |     void getOutput(float* iouPrediction, float* lowResolutionMasks);
24 | 
25 |     void getOutput(float* features);
26 | 
27 |     ~TRTModule();
28 | 
29 | private:
30 | 
31 |     void build(string onnxPath, vector<string> inputNames, vector<string> outputNames, bool isDynamicShape = false, bool isFP16 = false);
32 | 
33 |     void deserializeEngine(string engineName, vector<string> inputNames, vector<string> outputNames);
34 | 
35 |     void initialize(vector<string> inputNames, vector<string> outputNames);
36 | 
37 |     size_t getSizeByDim(const Dims& dims);
38 | 
39 |     void memcpyBuffers(const bool copyInput, const bool deviceToHost, const bool async, const cudaStream_t& stream = 0);
40 | 
41 |     void copyInputToDeviceAsync(const cudaStream_t& stream = 0);
42 | 
43 |     void copyOutputToHostAsync(const cudaStream_t& stream = 0);
44 | 
45 | 
46 |     vector<Dims> mInputDims; //!< The dimensions of the inputs to the network.
47 |     vector<Dims> mOutputDims; //!< The dimensions of the outputs of the network.
48 |     vector<void*> mGpuBuffers; //!< The vector of device buffers needed for engine execution
49 |     vector<float*> mCpuBuffers; //!< The host-side staging buffers mirroring the device buffers
50 |     vector<size_t> mBufferBindingBytes; //!< Size of each binding in bytes
51 |     vector<size_t> mBufferBindingSizes; //!< Size of each binding in number of elements
52 |     cudaStream_t mCudaStream; //!< The CUDA stream used for asynchronous host/device copies
53 | 
54 |     IRuntime* mRuntime; //!< The TensorRT runtime used to deserialize the engine
55 |     ICudaEngine* mEngine; //!< The TensorRT engine used to run the network
56 |     IExecutionContext* mContext; //!< The context for executing inference using an ICudaEngine
57 | };
58 | 
--------------------------------------------------------------------------------
/utils.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include <opencv2/opencv.hpp>
3 | using namespace cv; using namespace std; // overlay() and onMouse() below use unqualified Mat, Scalar and vector
4 | // Colors
5 | const std::vector<cv::Scalar> CITYSCAPES_COLORS = {
6 |     cv::Scalar(128, 64, 128),
7 |     cv::Scalar(232, 35, 244),
8 |     cv::Scalar(70, 70, 70),
9 |     cv::Scalar(156, 102, 102),
10 |     cv::Scalar(153, 153, 190),
11 |     cv::Scalar(153, 153, 153),
12 |     cv::Scalar(30, 170, 250),
13 |     cv::Scalar(0, 220, 220),
14 |     cv::Scalar(35, 142, 107),
15 |     cv::Scalar(152, 251, 152),
16 |     cv::Scalar(180, 130, 70),
17 |     cv::Scalar(60, 20, 220),
18 |     cv::Scalar(0, 0, 255),
19 |     cv::Scalar(142, 0, 0),
20 |     cv::Scalar(70, 0, 0),
21 |     cv::Scalar(100, 60, 0),
22 |     cv::Scalar(90, 0, 0),
23 |     cv::Scalar(230, 0, 0),
24 |     cv::Scalar(32, 11, 119),
25 |     cv::Scalar(0, 74, 111),
26 |     cv::Scalar(81, 0, 81)
27 | };
28 | 
29 | // Structure to hold clicked point coordinates
30 | struct PointData {
31 |     cv::Point point;
32 |     bool clicked;
33 | };
34 | 
35 | // Overlay mask on the image
36 | inline void overlay(Mat& image, Mat& mask, Scalar color = Scalar(128, 64, 128), float alpha = 0.8f, bool showEdge = true)
37 | {
38 |     // Draw mask
39 |     Mat ucharMask(image.rows, image.cols, CV_8UC3, color);
40 |     image.copyTo(ucharMask, mask <= 0);
41 |     addWeighted(ucharMask, alpha, image, 1.0 - alpha, 0.0f, image);
42 | 
43 |     // Draw contour edge
44 |     if (showEdge)
45 |     {
46 |         vector<vector<Point>> contours;
47 |         vector<Vec4i> hierarchy;
48 |         findContours(mask <= 0, contours, hierarchy, RETR_TREE, CHAIN_APPROX_NONE);
49 |         drawContours(image, contours, -1, Scalar(255, 255, 255), 2);
50 |     }
51 | }
52 | 
53 | // Function to handle mouse events
54 | inline void onMouse(int event, int x, int y, int flags, void* userdata) {
55 |     PointData* pd = (PointData*)userdata;
56 |     if (event == cv::EVENT_LBUTTONDOWN) {
57 |         // Save the clicked coordinates
58 |         pd->point = cv::Point(x, y);
59 |         pd->clicked = true;
60 |     }
61 | }
62 | 
--------------------------------------------------------------------------------
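
Finally, to show how the helpers in utils.h fit together, here is a small, self-contained sketch of a click-to-segment loop: onMouse() records a left click into a PointData, a stand-in segment() function produces a mask (a dummy circle here, in place of the real NanoSAM call), and overlay() tints the selected region and outlines it. The window name, image path, and segment() stub are illustrative assumptions, not repository code.

#include <opencv2/opencv.hpp>
#include "utils.h"

// Hypothetical stand-in for the real NanoSAM pipeline: returns a dummy circular
// mask around the click so this sketch runs without TensorRT.
static cv::Mat segment(const cv::Mat& image, const cv::Point& click)
{
    cv::Mat mask = cv::Mat::zeros(image.size(), CV_32F);
    cv::circle(mask, click, 60, cv::Scalar(1.0f), -1);
    return mask;
}

int main()
{
    cv::Mat image = cv::imread("assets/dog.jpg"); // sample image shipped in the assets folder
    if (image.empty())
        return 1;

    PointData pd{ cv::Point(0, 0), false };
    cv::namedWindow("nanosam");
    cv::setMouseCallback("nanosam", onMouse, &pd); // onMouse() stores the click and sets pd.clicked

    while (true)
    {
        cv::Mat frame = image.clone();
        if (pd.clicked)
        {
            cv::Mat mask = segment(frame, pd.point); // values > 0 mark the selected object
            overlay(frame, mask);                    // tint the object and draw its white contour
        }
        cv::imshow("nanosam", frame);
        if (cv::waitKey(30) == 27) // Esc quits
            break;
    }
    return 0;
}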