├── .gitignore ├── .vscode └── settings.json ├── LICENSE.md ├── README.md ├── assets ├── basic_usage_out.jpg ├── camera.jpg ├── class.jpg ├── frog.jpg ├── jetson_person_2x.gif ├── nanoowl_jetbot.gif ├── owl_glove.jpg ├── owl_glove_out.jpg ├── owl_glove_small.jpg ├── owl_gradio_demo.jpg ├── owl_predict_out.jpg └── tree_predict_out.jpg ├── docker └── 23-01 │ ├── Dockerfile │ ├── build.sh │ └── run.sh ├── examples ├── owl_predict.py ├── tree_demo │ ├── index.html │ └── tree_demo.py └── tree_predict.py ├── nanoowl ├── __init__.py ├── build_image_encoder_engine.py ├── clip_predictor.py ├── image_preprocessor.py ├── owl_drawing.py ├── owl_predictor.py ├── sync_timer.py ├── tree.py ├── tree_drawing.py └── tree_predictor.py ├── setup.py └── test ├── __init__.py ├── test_clip_predictor.py ├── test_image_preprocessor.py ├── test_owl_predictor.py ├── test_tree.py └── test_tree_predictor.py /.gitignore: -------------------------------------------------------------------------------- 1 | data 2 | *.egg-info 3 | *.pyc 4 | charmap.json.gz 5 | sandbox -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "python.testing.pytestArgs": [ 3 | "test" 4 | ], 5 | "python.testing.unittestEnabled": false, 6 | "python.testing.pytestEnabled": true 7 | } -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

NanoOWL

2 | 3 |

👍 Usage - ⏱️ Performance - 🛠️ Setup - 🤸 Examples - 👏 Acknowledgement - 🔗 See also

4 | 5 | NanoOWL is a project that optimizes [OWL-ViT](https://huggingface.co/docs/transformers/model_doc/owlvit) to run 🔥 ***real-time*** 🔥 on [NVIDIA Jetson Orin Platforms](https://store.nvidia.com/en-us/jetson/store) with [NVIDIA TensorRT](https://developer.nvidia.com/tensorrt). NanoOWL also introduces a new "tree detection" pipeline that combines OWL-ViT and CLIP to enable nested detection and classification of anything, at any level, simply by providing text. 6 | 7 |

8 |

9 | 10 | > Interested in detecting object masks as well? Try combining NanoOWL with 11 | > [NanoSAM](https://github.com/NVIDIA-AI-IOT/nanosam) for zero-shot open-vocabulary 12 | > instance segmentation. 13 | 14 | 15 | ## 👍 Usage 16 | 17 | You can use NanoOWL in Python like this: 18 | 19 | ```python3 20 | import PIL.Image 21 | from nanoowl.owl_predictor import OwlPredictor 22 | predictor = OwlPredictor( 23 | "google/owlvit-base-patch32", 24 | image_encoder_engine="data/owlvit-base-patch32-image-encoder.engine" 25 | ) 26 | 27 | image = PIL.Image.open("assets/owl_glove_small.jpg") 28 | 29 | output = predictor.predict(image=image, text=["an owl", "a glove"], threshold=0.1) 30 | 31 | print(output) 32 | ``` 33 | 34 | Or better yet, to use OWL-ViT in conjunction with CLIP to detect and classify anything, 35 | at any level, check out the tree predictor example below! 36 | 37 | > See [Setup](#setup) for instructions on how to build the image encoder engine. 38 | 39 | 40 | ## ⏱️ Performance 41 | 42 | NanoOWL runs real-time on Jetson Orin Nano. 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 |
| Model † | Image Size | Patch Size | ⏱️ Jetson Orin Nano (FPS) | ⏱️ Jetson AGX Orin (FPS) | 🎯 Accuracy (mAP) |
|---|---|---|---|---|---|
| OWL-ViT (ViT-B/32) | 768 | 32 | TBD | 95 | 28 |
| OWL-ViT (ViT-B/16) | 768 | 16 | TBD | 25 | 31.7 |
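The figures above can be checked on your own device with the profiling flag of the basic prediction example (Example 1 below); this simply re-runs the documented command with ``--profile`` added, assuming the image encoder engine has already been built as described in Setup:

```bash
cd examples
python3 owl_predict.py \
    --prompt="[an owl, a glove]" \
    --threshold=0.1 \
    --image_encoder_engine=../data/owl_image_encoder_patch32.engine \
    --profile
```

This times repeated calls to ``predictor.predict()`` and prints the measured FPS; absolute numbers will vary with the device's power mode and clock settings.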
74 | 75 | 76 | ## 🛠️ Setup 77 | 78 | 1. Install the dependencies 79 | 80 | 1. Install PyTorch 81 | 82 | 2. Install [torch2trt](https://github.com/NVIDIA-AI-IOT/torch2trt) 83 | 3. Install NVIDIA TensorRT 84 | 4. Install the Transformers library 85 | 86 | ```bash 87 | python3 -m pip install transformers 88 | ``` 89 | 5. (optional) Install NanoSAM (for the instance segmentation example) 90 | 91 | 2. Install the NanoOWL package. 92 | 93 | ```bash 94 | git clone https://github.com/NVIDIA-AI-IOT/nanoowl 95 | cd nanoowl 96 | python3 setup.py develop --user 97 | ``` 98 | 99 | 3. Build the TensorRT engine for the OWL-ViT vision encoder 100 | 101 | ```bash 102 | mkdir -p data 103 | python3 -m nanoowl.build_image_encoder_engine \ 104 | data/owl_image_encoder_patch32.engine 105 | ``` 106 | 107 | 108 | 4. Run an example prediction to ensure everything is working 109 | 110 | ```bash 111 | cd examples 112 | python3 owl_predict.py \ 113 | --prompt="[an owl, a glove]" \ 114 | --threshold=0.1 \ 115 | --image_encoder_engine=../data/owl_image_encoder_patch32.engine 116 | ``` 117 | 118 | That's it! If everything is working properly, you should see a visualization saved to ``data/owl_predict_out.jpg``. 119 | 120 | 121 | ## 🤸 Examples 122 | 123 | ### Example 1 - Basic prediction 124 | 125 | 126 | 127 | This example demonstrates how to use the TensorRT optimized OWL-ViT model to 128 | detect objects by providing text descriptions of the object labels. 129 | 130 | To run the example, first navigate to the examples folder 131 | 132 | ```bash 133 | cd examples 134 | ``` 135 | 136 | Then run the example 137 | 138 | ```bash 139 | python3 owl_predict.py \ 140 | --prompt="[an owl, a glove]" \ 141 | --threshold=0.1 \ 142 | --image_encoder_engine=../data/owl_image_encoder_patch32.engine 143 | ``` 144 | 145 | By default the output will be saved to ``data/owl_predict_out.jpg``. 146 | 147 | You can also use this example to profile inference. Simply set the flag ``--profile``. 148 | 149 | ### Example 2 - Tree prediction 150 | 151 | 152 | 153 | This example demonstrates how to use the tree predictor class to detect and 154 | classify objects at any level. 155 | 156 | To run the example, first navigate to the examples folder 157 | 158 | ```bash 159 | cd examples 160 | ``` 161 | 162 | To detect all owls, and the detect all wings and eyes in each detect owl region 163 | of interest, type 164 | 165 | ```bash 166 | python3 tree_predict.py \ 167 | --prompt="[an owl [a wing, an eye]]" \ 168 | --threshold=0.15 \ 169 | --image_encoder_engine=../data/owl_image_encoder_patch32.engine 170 | ``` 171 | 172 | By default the output will be saved to ``data/tree_predict_out.jpg``. 173 | 174 | To classify the image as indoors or outdoors, type 175 | 176 | ```bash 177 | python3 tree_predict.py \ 178 | --prompt="(indoors, outdoors)" \ 179 | --threshold=0.15 \ 180 | --image_encoder_engine=../data/owl_image_encoder_patch32.engine 181 | ``` 182 | 183 | To classify the image as indoors or outdoors, and if it's outdoors then detect 184 | all owls, type 185 | 186 | ```bash 187 | python3 tree_predict.py \ 188 | --prompt="(indoors, outdoors [an owl])" \ 189 | --threshold=0.15 \ 190 | --image_encoder_engine=../data/owl_image_encoder_patch32.engine 191 | ``` 192 | 193 | 194 | ### Example 3 - Tree prediction (Live Camera) 195 | 196 | 197 | 198 | This example demonstrates the tree predictor running on a live camera feed with 199 | live-edited text prompts. To run the example 200 | 201 | 1. Ensure you have a camera device connected 202 | 203 | 2. 
Launch the demo 204 | ```bash 205 | cd examples/tree_demo 206 | python3 tree_demo.py ../../data/owl_image_encoder_patch32.engine 207 | ``` 208 | 3. Open your browser to ``http://<ip address>:7860`` 209 | 4. Type whatever prompt you like to see what works! Here are some examples: 210 | - Example: [a face [a nose, an eye, a mouth]] 211 | - Example: [a face (interested, yawning / bored)] 212 | - Example: (indoors, outdoors) 213 | 214 | 215 | 216 | 217 | ## 👏 Acknowledgement 218 | 219 | Thanks to the authors of [OWL-ViT](https://huggingface.co/docs/transformers/model_doc/owlvit) for the great open-vocabulary detection work. 220 | 221 | 222 | ## 🔗 See also 223 | 224 | - [NanoSAM](https://github.com/NVIDIA-AI-IOT/nanosam) - A real-time Segment Anything (SAM) model variant for NVIDIA Jetson Orin platforms. 225 | - [Jetson Introduction to Knowledge Distillation Tutorial](https://github.com/NVIDIA-AI-IOT/jetson-intro-to-distillation) - For an introduction to knowledge distillation as a model optimization technique. 226 | - [Jetson Generative AI Playground](https://nvidia-ai-iot.github.io/jetson-generative-ai-playground/) - For instructions and tips for using a variety of LLMs and transformers on Jetson. 227 | - [Jetson Containers](https://github.com/dusty-nv/jetson-containers) - For a variety of easily deployable and modular Jetson Containers. 228 | -------------------------------------------------------------------------------- /assets/basic_usage_out.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA-AI-IOT/nanoowl/fb553dee4c58f8c53ec2bf01f355a54648e67dbd/assets/basic_usage_out.jpg -------------------------------------------------------------------------------- /assets/camera.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA-AI-IOT/nanoowl/fb553dee4c58f8c53ec2bf01f355a54648e67dbd/assets/camera.jpg -------------------------------------------------------------------------------- /assets/class.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA-AI-IOT/nanoowl/fb553dee4c58f8c53ec2bf01f355a54648e67dbd/assets/class.jpg -------------------------------------------------------------------------------- /assets/frog.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA-AI-IOT/nanoowl/fb553dee4c58f8c53ec2bf01f355a54648e67dbd/assets/frog.jpg -------------------------------------------------------------------------------- /assets/jetson_person_2x.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA-AI-IOT/nanoowl/fb553dee4c58f8c53ec2bf01f355a54648e67dbd/assets/jetson_person_2x.gif -------------------------------------------------------------------------------- /assets/nanoowl_jetbot.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA-AI-IOT/nanoowl/fb553dee4c58f8c53ec2bf01f355a54648e67dbd/assets/nanoowl_jetbot.gif -------------------------------------------------------------------------------- /assets/owl_glove.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA-AI-IOT/nanoowl/fb553dee4c58f8c53ec2bf01f355a54648e67dbd/assets/owl_glove.jpg
-------------------------------------------------------------------------------- /assets/owl_glove_out.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA-AI-IOT/nanoowl/fb553dee4c58f8c53ec2bf01f355a54648e67dbd/assets/owl_glove_out.jpg -------------------------------------------------------------------------------- /assets/owl_glove_small.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA-AI-IOT/nanoowl/fb553dee4c58f8c53ec2bf01f355a54648e67dbd/assets/owl_glove_small.jpg -------------------------------------------------------------------------------- /assets/owl_gradio_demo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA-AI-IOT/nanoowl/fb553dee4c58f8c53ec2bf01f355a54648e67dbd/assets/owl_gradio_demo.jpg -------------------------------------------------------------------------------- /assets/owl_predict_out.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA-AI-IOT/nanoowl/fb553dee4c58f8c53ec2bf01f355a54648e67dbd/assets/owl_predict_out.jpg -------------------------------------------------------------------------------- /assets/tree_predict_out.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA-AI-IOT/nanoowl/fb553dee4c58f8c53ec2bf01f355a54648e67dbd/assets/tree_predict_out.jpg -------------------------------------------------------------------------------- /docker/23-01/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvcr.io/nvidia/pytorch:23.01-py3 2 | 3 | # upgrade pillow to fix "UnidentifiedImageError" 4 | RUN pip install pillow --upgrade 5 | 6 | RUN pip install git+https://github.com/NVIDIA-AI-IOT/torch2trt.git 7 | RUN pip install transformers timm accelerate 8 | RUN pip install git+https://github.com/openai/CLIP.git -------------------------------------------------------------------------------- /docker/23-01/build.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | docker build -t nanoowl:23-01 -f $(pwd)/docker/23-01/Dockerfile $(pwd)/docker/23-01 -------------------------------------------------------------------------------- /docker/23-01/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | docker run \ 5 | -it \ 6 | -d \ 7 | --rm \ 8 | --ipc host \ 9 | --gpus all \ 10 | --shm-size 14G \ 11 | --device /dev/video0:/dev/video0 \ 12 | -v /tmp/.X11-unix:/tmp/.X11-unix \ 13 | -e DISPLAY=$DISPLAY \ 14 | -p 7860:7860 \ 15 | -v $(pwd):/nanoowl \ 16 | nanoowl:23-01 -------------------------------------------------------------------------------- /examples/owl_predict.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | import argparse 18 | import PIL.Image 19 | import time 20 | import torch 21 | from nanoowl.owl_predictor import ( 22 | OwlPredictor 23 | ) 24 | from nanoowl.owl_drawing import ( 25 | draw_owl_output 26 | ) 27 | 28 | 29 | if __name__ == "__main__": 30 | 31 | parser = argparse.ArgumentParser() 32 | parser.add_argument("--image", type=str, default="../assets/owl_glove_small.jpg") 33 | parser.add_argument("--prompt", type=str, default="[an owl, a glove]") 34 | parser.add_argument("--threshold", type=str, default="0.1,0.1") 35 | parser.add_argument("--output", type=str, default="../data/owl_predict_out.jpg") 36 | parser.add_argument("--model", type=str, default="google/owlvit-base-patch32") 37 | parser.add_argument("--image_encoder_engine", type=str, default="../data/owl_image_encoder_patch32.engine") 38 | parser.add_argument("--profile", action="store_true") 39 | parser.add_argument("--num_profiling_runs", type=int, default=30) 40 | args = parser.parse_args() 41 | 42 | prompt = args.prompt.strip("][()") 43 | text = prompt.split(',') 44 | print(text) 45 | 46 | thresholds = args.threshold.strip("][()") 47 | thresholds = thresholds.split(',') 48 | if len(thresholds) == 1: 49 | thresholds = float(thresholds[0]) 50 | else: 51 | thresholds = [float(x) for x in thresholds] 52 | print(thresholds) 53 | 54 | 55 | predictor = OwlPredictor( 56 | args.model, 57 | image_encoder_engine=args.image_encoder_engine 58 | ) 59 | 60 | image = PIL.Image.open(args.image) 61 | 62 | text_encodings = predictor.encode_text(text) 63 | 64 | output = predictor.predict( 65 | image=image, 66 | text=text, 67 | text_encodings=text_encodings, 68 | threshold=thresholds, 69 | pad_square=False 70 | ) 71 | 72 | if args.profile: 73 | torch.cuda.current_stream().synchronize() 74 | t0 = time.perf_counter_ns() 75 | for i in range(args.num_profiling_runs): 76 | output = predictor.predict( 77 | image=image, 78 | text=text, 79 | text_encodings=text_encodings, 80 | threshold=thresholds, 81 | pad_square=False 82 | ) 83 | torch.cuda.current_stream().synchronize() 84 | t1 = time.perf_counter_ns() 85 | dt = (t1 - t0) / 1e9 86 | print(f"PROFILING FPS: {args.num_profiling_runs/dt}") 87 | 88 | image = draw_owl_output(image, output, text=text, draw_text=True) 89 | 90 | image.save(args.output) -------------------------------------------------------------------------------- /examples/tree_demo/index.html: -------------------------------------------------------------------------------- 1 | 17 | 18 | 19 | 20 | 21 | 54 | 55 | 103 | 104 | 105 |
106 | NanoOWL 107 | Camera Image 108 | 109 | 110 |
111 | 112 | -------------------------------------------------------------------------------- /examples/tree_demo/tree_demo.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | import asyncio 18 | import argparse 19 | from aiohttp import web, WSCloseCode 20 | import logging 21 | import weakref 22 | import cv2 23 | import time 24 | import PIL.Image 25 | import matplotlib.pyplot as plt 26 | from typing import List 27 | from nanoowl.tree import Tree 28 | from nanoowl.tree_predictor import ( 29 | TreePredictor 30 | ) 31 | from nanoowl.tree_drawing import draw_tree_output 32 | from nanoowl.owl_predictor import OwlPredictor 33 | 34 | 35 | if __name__ == "__main__": 36 | parser = argparse.ArgumentParser() 37 | parser.add_argument("image_encode_engine", type=str) 38 | parser.add_argument("--image_quality", type=int, default=50) 39 | parser.add_argument("--port", type=int, default=7860) 40 | parser.add_argument("--host", type=str, default="0.0.0.0") 41 | parser.add_argument("--camera", type=int, default=0) 42 | parser.add_argument("--resolution", type=str, default="640x480", help="Camera resolution as WIDTHxHEIGHT") 43 | args = parser.parse_args() 44 | width, height = map(int, args.resolution.split("x")) 45 | 46 | CAMERA_DEVICE = args.camera 47 | IMAGE_QUALITY = args.image_quality 48 | 49 | predictor = TreePredictor( 50 | owl_predictor=OwlPredictor( 51 | image_encoder_engine=args.image_encode_engine 52 | ) 53 | ) 54 | 55 | prompt_data = None 56 | 57 | def get_colors(count: int): 58 | cmap = plt.cm.get_cmap("rainbow", count) 59 | colors = [] 60 | for i in range(count): 61 | color = cmap(i) 62 | color = [int(255 * value) for value in color] 63 | colors.append(tuple(color)) 64 | return colors 65 | 66 | 67 | def cv2_to_pil(image): 68 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 69 | return PIL.Image.fromarray(image) 70 | 71 | 72 | async def handle_index_get(request: web.Request): 73 | logging.info("handle_index_get") 74 | return web.FileResponse("./index.html") 75 | 76 | 77 | async def websocket_handler(request): 78 | 79 | global prompt_data 80 | 81 | ws = web.WebSocketResponse() 82 | 83 | await ws.prepare(request) 84 | 85 | logging.info("Websocket connected.") 86 | 87 | request.app['websockets'].add(ws) 88 | 89 | try: 90 | async for msg in ws: 91 | logging.info(f"Received message from websocket.") 92 | if "prompt" in msg.data: 93 | header, prompt = msg.data.split(":") 94 | logging.info("Received prompt: " + prompt) 95 | try: 96 | tree = Tree.from_prompt(prompt) 97 | clip_encodings = predictor.encode_clip_text(tree) 98 | owl_encodings = predictor.encode_owl_text(tree) 99 | prompt_data = { 100 | "tree": tree, 101 | "clip_encodings": clip_encodings, 102 | "owl_encodings": owl_encodings 103 | } 104 | logging.info("Set prompt: " + 
prompt) 105 | except Exception as e: 106 | print(e) 107 | finally: 108 | request.app['websockets'].discard(ws) 109 | 110 | return ws 111 | 112 | 113 | async def on_shutdown(app: web.Application): 114 | for ws in set(app['websockets']): 115 | await ws.close(code=WSCloseCode.GOING_AWAY, 116 | message='Server shutdown') 117 | 118 | 119 | async def detection_loop(app: web.Application): 120 | 121 | loop = asyncio.get_running_loop() 122 | 123 | logging.info("Opening camera.") 124 | 125 | camera = cv2.VideoCapture(CAMERA_DEVICE) 126 | camera.set(cv2.CAP_PROP_FRAME_WIDTH, width) 127 | camera.set(cv2.CAP_PROP_FRAME_HEIGHT, height) 128 | 129 | logging.info("Loading predictor.") 130 | 131 | def _read_and_encode_image(): 132 | 133 | re, image = camera.read() 134 | 135 | if not re: 136 | return re, None 137 | 138 | image_pil = cv2_to_pil(image) 139 | 140 | if prompt_data is not None: 141 | prompt_data_local = prompt_data 142 | t0 = time.perf_counter_ns() 143 | detections = predictor.predict( 144 | image_pil, 145 | tree=prompt_data_local['tree'], 146 | clip_text_encodings=prompt_data_local['clip_encodings'], 147 | owl_text_encodings=prompt_data_local['owl_encodings'] 148 | ) 149 | t1 = time.perf_counter_ns() 150 | dt = (t1 - t0) / 1e9 151 | tree = prompt_data_local['tree'] 152 | image = draw_tree_output(image, detections, prompt_data_local['tree']) 153 | 154 | image_jpeg = bytes( 155 | cv2.imencode(".jpg", image, [cv2.IMWRITE_JPEG_QUALITY, IMAGE_QUALITY])[1] 156 | ) 157 | 158 | return re, image_jpeg 159 | 160 | while True: 161 | 162 | re, image = await loop.run_in_executor(None, _read_and_encode_image) 163 | 164 | if not re: 165 | break 166 | 167 | for ws in app["websockets"]: 168 | await ws.send_bytes(image) 169 | 170 | camera.release() 171 | 172 | 173 | async def run_detection_loop(app): 174 | try: 175 | task = asyncio.create_task(detection_loop(app)) 176 | yield 177 | task.cancel() 178 | except asyncio.CancelledError: 179 | pass 180 | finally: 181 | await task 182 | 183 | 184 | logging.basicConfig(level=logging.INFO) 185 | app = web.Application() 186 | app['websockets'] = weakref.WeakSet() 187 | app.router.add_get("/", handle_index_get) 188 | app.router.add_route("GET", "/ws", websocket_handler) 189 | app.on_shutdown.append(on_shutdown) 190 | app.cleanup_ctx.append(run_detection_loop) 191 | web.run_app(app, host=args.host, port=args.port) 192 | -------------------------------------------------------------------------------- /examples/tree_predict.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
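# Command-line example for the tree predictor: parses a nested tree prompt,
# runs detection/classification over the input image, and saves an annotated
# copy. Example invocation (from the examples/ directory, after building the
# image encoder engine described in the README's Setup section):
#
#   python3 tree_predict.py \
#       --prompt="[an owl [a wing, an eye]]" \
#       --threshold=0.15 \
#       --image_encoder_engine=../data/owl_image_encoder_patch32.engine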
15 | 16 | 17 | import argparse 18 | import PIL.Image 19 | from nanoowl.owl_predictor import OwlPredictor 20 | from nanoowl.tree_predictor import ( 21 | TreePredictor, Tree 22 | ) 23 | from nanoowl.tree_drawing import ( 24 | draw_tree_output 25 | ) 26 | 27 | 28 | if __name__ == "__main__": 29 | 30 | parser = argparse.ArgumentParser() 31 | parser.add_argument("--image", type=str, default="../assets/owl_glove_small.jpg") 32 | parser.add_argument("--prompt", type=str, default="") 33 | parser.add_argument("--threshold", type=float, default=0.1) 34 | parser.add_argument("--output", type=str, default="../data/tree_predict_out.jpg") 35 | parser.add_argument("--model", type=str, default="google/owlvit-base-patch32") 36 | parser.add_argument("--image_encoder_engine", type=str, default="../data/owl_image_encoder_patch32.engine") 37 | args = parser.parse_args() 38 | 39 | predictor = TreePredictor( 40 | owl_predictor=OwlPredictor( 41 | args.model, 42 | image_encoder_engine=args.image_encoder_engine 43 | ) 44 | ) 45 | 46 | image = PIL.Image.open(args.image) 47 | tree = Tree.from_prompt(args.prompt) 48 | clip_text_encodings = predictor.encode_clip_text(tree) 49 | owl_text_encodings = predictor.encode_owl_text(tree) 50 | 51 | output = predictor.predict( 52 | image=image, 53 | tree=tree, 54 | clip_text_encodings=clip_text_encodings, 55 | owl_text_encodings=owl_text_encodings, 56 | threshold=args.threshold 57 | ) 58 | 59 | image = draw_tree_output(image, output, tree=tree, draw_text=True) 60 | 61 | image.save(args.output) -------------------------------------------------------------------------------- /nanoowl/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | -------------------------------------------------------------------------------- /nanoowl/build_image_encoder_engine.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
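# Exports the OWL-ViT vision encoder and builds a TensorRT engine for it via
# OwlPredictor.build_image_encoder_engine. Example invocation from the
# repository root, matching the README's Setup instructions:
#
#   mkdir -p data
#   python3 -m nanoowl.build_image_encoder_engine \
#       data/owl_image_encoder_patch32.engine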
15 | 16 | 17 | import argparse 18 | from .owl_predictor import OwlPredictor 19 | 20 | 21 | if __name__ == "__main__": 22 | 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument("output_path", type=str) 25 | parser.add_argument("--model_name", type=str, default="google/owlvit-base-patch32") 26 | parser.add_argument("--fp16_mode", type=bool, default=True) 27 | parser.add_argument("--onnx_opset", type=int, default=16) 28 | args = parser.parse_args() 29 | 30 | predictor = OwlPredictor( 31 | model_name=args.model_name 32 | ) 33 | 34 | predictor.build_image_encoder_engine( 35 | args.output_path, 36 | fp16_mode=args.fp16_mode, 37 | onnx_opset=args.onnx_opset 38 | ) -------------------------------------------------------------------------------- /nanoowl/clip_predictor.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | import torch 18 | import clip 19 | import PIL.Image 20 | from torchvision.ops import roi_align 21 | from typing import List, Tuple, Optional 22 | from dataclasses import dataclass 23 | from .image_preprocessor import ImagePreprocessor 24 | 25 | 26 | __all__ = [ 27 | "ClipPredictor", 28 | "ClipEncodeTextOutput", 29 | "ClipEncodeImageOutput", 30 | "ClipDecodeOutput" 31 | ] 32 | 33 | 34 | @dataclass 35 | class ClipEncodeTextOutput: 36 | text_embeds: torch.Tensor 37 | 38 | def slice(self, start_index, end_index): 39 | return ClipEncodeTextOutput( 40 | text_embeds=self.text_embeds[start_index:end_index] 41 | ) 42 | 43 | 44 | @dataclass 45 | class ClipEncodeImageOutput: 46 | image_embeds: torch.Tensor 47 | 48 | 49 | @dataclass 50 | class ClipDecodeOutput: 51 | labels: torch.Tensor 52 | scores: torch.Tensor 53 | 54 | 55 | class ClipPredictor(torch.nn.Module): 56 | 57 | def __init__(self, 58 | model_name: str = "ViT-B/32", 59 | image_size: Tuple[int, int] = (224, 224), 60 | device: str = "cuda", 61 | image_preprocessor: Optional[ImagePreprocessor] = None 62 | ): 63 | super().__init__() 64 | self.device = device 65 | self.clip_model, _ = clip.load(model_name, device) 66 | self.image_size = image_size 67 | self.mesh_grid = torch.stack( 68 | torch.meshgrid( 69 | torch.linspace(0., 1., self.image_size[1]), 70 | torch.linspace(0., 1., self.image_size[0]) 71 | ) 72 | ).to(self.device).float() 73 | self.image_preprocessor = image_preprocessor.to(self.device).eval() if image_preprocessor else ImagePreprocessor().to(self.device).eval() 74 | 75 | def get_device(self): 76 | return self.device 77 | 78 | def get_image_size(self): 79 | return self.image_size 80 | 81 | def encode_text(self, text: List[str]) -> ClipEncodeTextOutput: 82 | text_tokens = clip.tokenize(text).to(self.device) 83 | text_embeds = self.clip_model.encode_text(text_tokens) 84 | return ClipEncodeTextOutput(text_embeds=text_embeds) 85 | 86 | def 
encode_image(self, image: torch.Tensor) -> ClipEncodeImageOutput: 87 | image_embeds = self.clip_model.encode_image(image) 88 | return ClipEncodeImageOutput(image_embeds=image_embeds) 89 | 90 | def extract_rois(self, image: torch.Tensor, rois: torch.Tensor, pad_square: bool = True, padding_scale: float=1.0): 91 | if len(rois) == 0: 92 | return torch.empty( 93 | (0, image.shape[1], self.image_size[0], self.image_size[1]), 94 | dtype=image.dtype, 95 | device=image.device 96 | ) 97 | 98 | if pad_square: 99 | # pad square 100 | w = padding_scale * (rois[..., 2] - rois[..., 0]) / 2 101 | h = padding_scale * (rois[..., 3] - rois[..., 1]) / 2 102 | cx = (rois[..., 0] + rois[..., 2]) / 2 103 | cy = (rois[..., 1] + rois[..., 3]) / 2 104 | s = torch.max(w, h) 105 | rois = torch.stack([cx-s, cy-s, cx+s, cy+s], dim=-1) 106 | 107 | # compute mask 108 | pad_x = (s - w) / (2 * s) 109 | pad_y = (s - h) / (2 * s) 110 | mask_x = (self.mesh_grid[1][None, ...] > pad_x[..., None, None]) & (self.mesh_grid[1][None, ...] < (1. - pad_x[..., None, None])) 111 | mask_y = (self.mesh_grid[0][None, ...] > pad_y[..., None, None]) & (self.mesh_grid[0][None, ...] < (1. - pad_y[..., None, None])) 112 | mask = (mask_x & mask_y) 113 | 114 | roi_images = roi_align(image, [rois], output_size=self.get_image_size()) 115 | 116 | if pad_square: 117 | roi_images = roi_images * mask[:, None, :, :] 118 | 119 | return roi_images, rois 120 | 121 | def encode_rois(self, image: torch.Tensor, rois: torch.Tensor, pad_square: bool = True, padding_scale: float = 1.0): 122 | roi_images, rois = self.extract_rois(image, rois, pad_square, padding_scale) 123 | return self.encode_image(roi_images) 124 | 125 | def decode(self, 126 | image_output: ClipEncodeImageOutput, 127 | text_output: ClipEncodeTextOutput 128 | ) -> ClipDecodeOutput: 129 | 130 | image_embeds = image_output.image_embeds 131 | text_embeds = text_output.text_embeds 132 | 133 | image_embeds = image_embeds / image_embeds.norm(dim=1, keepdim=True) 134 | text_embeds = text_embeds / text_embeds.norm(dim=1, keepdim=True) 135 | logit_scale = self.clip_model.logit_scale.exp() 136 | logits_per_image = logit_scale * image_embeds @ text_embeds.t() 137 | probs = torch.softmax(logits_per_image, dim=-1) 138 | prob_max = probs.max(dim=-1) 139 | 140 | return ClipDecodeOutput( 141 | labels=prob_max.indices, 142 | scores=prob_max.values 143 | ) 144 | 145 | def predict(self, 146 | image: PIL.Image, 147 | text: List[str], 148 | text_encodings: Optional[ClipEncodeTextOutput], 149 | pad_square: bool = True, 150 | threshold: float = 0.1 151 | ) -> ClipDecodeOutput: 152 | 153 | image_tensor = self.image_preprocessor.preprocess_pil_image(image) 154 | 155 | if text_encodings is None: 156 | text_encodings = self.encode_text(text) 157 | 158 | rois = torch.tensor([[0, 0, image.height, image.width]], dtype=image_tensor.dtype, device=image_tensor.device) 159 | 160 | image_encodings = self.encode_rois(image_tensor, rois, pad_square=pad_square) 161 | 162 | return self.decode(image_encodings, text_encodings) -------------------------------------------------------------------------------- /nanoowl/image_preprocessor.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | import torch 18 | import PIL.Image 19 | import numpy as np 20 | from typing import Tuple 21 | 22 | 23 | __all__ = [ 24 | "ImagePreprocessor", 25 | "DEFAULT_IMAGE_PREPROCESSOR_MEAN", 26 | "DEFAULT_IMAGE_PREPROCESSOR_STD" 27 | ] 28 | 29 | 30 | DEFAULT_IMAGE_PREPROCESSOR_MEAN = [ 31 | 0.48145466 * 255., 32 | 0.4578275 * 255., 33 | 0.40821073 * 255. 34 | ] 35 | 36 | 37 | DEFAULT_IMAGE_PREPROCESSOR_STD = [ 38 | 0.26862954 * 255., 39 | 0.26130258 * 255., 40 | 0.27577711 * 255. 41 | ] 42 | 43 | 44 | class ImagePreprocessor(torch.nn.Module): 45 | def __init__(self, 46 | mean: Tuple[float, float, float] = DEFAULT_IMAGE_PREPROCESSOR_MEAN, 47 | std: Tuple[float, float, float] = DEFAULT_IMAGE_PREPROCESSOR_STD 48 | ): 49 | super().__init__() 50 | 51 | self.register_buffer( 52 | "mean", 53 | torch.tensor(mean)[None, :, None, None] 54 | ) 55 | self.register_buffer( 56 | "std", 57 | torch.tensor(std)[None, :, None, None] 58 | ) 59 | 60 | def forward(self, image: torch.Tensor, inplace: bool = False): 61 | 62 | if inplace: 63 | image = image.sub_(self.mean).div_(self.std) 64 | else: 65 | image = (image - self.mean) / self.std 66 | 67 | return image 68 | 69 | @torch.no_grad() 70 | def preprocess_pil_image(self, image: PIL.Image.Image): 71 | image = torch.from_numpy(np.asarray(image)) 72 | image = image.permute(2, 0, 1)[None, ...] 73 | image = image.to(self.mean.device) 74 | image = image.type(self.mean.dtype) 75 | return self.forward(image, inplace=True) -------------------------------------------------------------------------------- /nanoowl/owl_drawing.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
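# Drawing helpers for OwlPredictor results: draw_owl_output() overlays the
# detection boxes from an OwlDecodeOutput (and, optionally, their label text)
# onto an image with OpenCV, accepting either a PIL image or a numpy array and
# returning the same type it was given.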
15 | 16 | 17 | import PIL.Image 18 | import PIL.ImageDraw 19 | import cv2 20 | from .owl_predictor import OwlDecodeOutput 21 | import matplotlib.pyplot as plt 22 | import numpy as np 23 | from typing import List 24 | 25 | 26 | def get_colors(count: int): 27 | cmap = plt.cm.get_cmap("rainbow", count) 28 | colors = [] 29 | for i in range(count): 30 | color = cmap(i) 31 | color = [int(255 * value) for value in color] 32 | colors.append(tuple(color)) 33 | return colors 34 | 35 | 36 | def draw_owl_output(image, output: OwlDecodeOutput, text: List[str], draw_text=True): 37 | is_pil = not isinstance(image, np.ndarray) 38 | if is_pil: 39 | image = np.asarray(image) 40 | font = cv2.FONT_HERSHEY_SIMPLEX 41 | font_scale = 0.75 42 | colors = get_colors(len(text)) 43 | num_detections = len(output.labels) 44 | 45 | for i in range(num_detections): 46 | box = output.boxes[i] 47 | label_index = int(output.labels[i]) 48 | box = [int(x) for x in box] 49 | pt0 = (box[0], box[1]) 50 | pt1 = (box[2], box[3]) 51 | cv2.rectangle( 52 | image, 53 | pt0, 54 | pt1, 55 | colors[label_index], 56 | 4 57 | ) 58 | if draw_text: 59 | offset_y = 12 60 | offset_x = 0 61 | label_text = text[label_index] 62 | cv2.putText( 63 | image, 64 | label_text, 65 | (box[0] + offset_x, box[1] + offset_y), 66 | font, 67 | font_scale, 68 | colors[label_index], 69 | 2,# thickness 70 | cv2.LINE_AA 71 | ) 72 | if is_pil: 73 | image = PIL.Image.fromarray(image) 74 | return image -------------------------------------------------------------------------------- /nanoowl/owl_predictor.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
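# Core OWL-ViT predictor. OwlPredictor wraps the Hugging Face
# OwlViTForObjectDetection model, encodes text prompts and (square-padded)
# image regions of interest, and decodes thresholded boxes, scores, and labels.
# When image_encoder_engine is provided, the vision encoder runs through a
# prebuilt TensorRT engine instead of the PyTorch module.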
15 | 16 | 17 | import torch 18 | import numpy as np 19 | import PIL.Image 20 | import subprocess 21 | import tempfile 22 | import os 23 | from torchvision.ops import roi_align 24 | from transformers.models.owlvit.modeling_owlvit import OwlViTForObjectDetection 25 | from transformers.models.owlvit.processing_owlvit import OwlViTProcessor 26 | from dataclasses import dataclass 27 | from typing import List, Optional, Union, Tuple 28 | from .image_preprocessor import ImagePreprocessor 29 | 30 | __all__ = [ 31 | "OwlPredictor", 32 | "OwlEncodeTextOutput", 33 | "OwlEncodeImageOutput", 34 | "OwlDecodeOutput" 35 | ] 36 | 37 | 38 | def _owl_center_to_corners_format_torch(bboxes_center): 39 | center_x, center_y, width, height = bboxes_center.unbind(-1) 40 | bbox_corners = torch.stack( 41 | [ 42 | (center_x - 0.5 * width), 43 | (center_y - 0.5 * height), 44 | (center_x + 0.5 * width), 45 | (center_y + 0.5 * height) 46 | ], 47 | dim=-1, 48 | ) 49 | return bbox_corners 50 | 51 | 52 | def _owl_get_image_size(hf_name: str): 53 | 54 | image_sizes = { 55 | "google/owlvit-base-patch32": 768, 56 | "google/owlvit-base-patch16": 768, 57 | "google/owlvit-large-patch14": 840, 58 | } 59 | 60 | return image_sizes[hf_name] 61 | 62 | 63 | def _owl_get_patch_size(hf_name: str): 64 | 65 | patch_sizes = { 66 | "google/owlvit-base-patch32": 32, 67 | "google/owlvit-base-patch16": 16, 68 | "google/owlvit-large-patch14": 14, 69 | } 70 | 71 | return patch_sizes[hf_name] 72 | 73 | 74 | # This function is modified from https://github.com/huggingface/transformers/blob/e8fdd7875def7be59e2c9b823705fbf003163ea0/src/transformers/models/owlvit/modeling_owlvit.py#L1333 75 | # Copyright 2022 Google AI and The HuggingFace Team. All rights reserved. 76 | # SPDX-License-Identifier: Apache-2.0 77 | def _owl_normalize_grid_corner_coordinates(num_patches_per_side): 78 | box_coordinates = np.stack( 79 | np.meshgrid(np.arange(1, num_patches_per_side + 1), np.arange(1, num_patches_per_side + 1)), axis=-1 80 | ).astype(np.float32) 81 | box_coordinates /= np.array([num_patches_per_side, num_patches_per_side], np.float32) 82 | 83 | box_coordinates = box_coordinates.reshape( 84 | box_coordinates.shape[0] * box_coordinates.shape[1], box_coordinates.shape[2] 85 | ) 86 | box_coordinates = torch.from_numpy(box_coordinates) 87 | 88 | return box_coordinates 89 | 90 | 91 | # This function is modified from https://github.com/huggingface/transformers/blob/e8fdd7875def7be59e2c9b823705fbf003163ea0/src/transformers/models/owlvit/modeling_owlvit.py#L1354 92 | # Copyright 2022 Google AI and The HuggingFace Team. All rights reserved. 
93 | # SPDX-License-Identifier: Apache-2.0 94 | def _owl_compute_box_bias(num_patches_per_side): 95 | box_coordinates = _owl_normalize_grid_corner_coordinates(num_patches_per_side) 96 | box_coordinates = torch.clip(box_coordinates, 0.0, 1.0) 97 | 98 | box_coord_bias = torch.log(box_coordinates + 1e-4) - torch.log1p(-box_coordinates + 1e-4) 99 | 100 | box_size = torch.full_like(box_coord_bias, 1.0 / num_patches_per_side) 101 | box_size_bias = torch.log(box_size + 1e-4) - torch.log1p(-box_size + 1e-4) 102 | 103 | box_bias = torch.cat([box_coord_bias, box_size_bias], dim=-1) 104 | 105 | return box_bias 106 | 107 | 108 | def _owl_box_roi_to_box_global(boxes, rois): 109 | x0y0 = rois[..., :2] 110 | x1y1 = rois[..., 2:] 111 | wh = (x1y1 - x0y0).repeat(1, 1, 2) 112 | x0y0 = x0y0.repeat(1, 1, 2) 113 | return (boxes * wh) + x0y0 114 | 115 | 116 | @dataclass 117 | class OwlEncodeTextOutput: 118 | text_embeds: torch.Tensor 119 | 120 | def slice(self, start_index, end_index): 121 | return OwlEncodeTextOutput( 122 | text_embeds=self.text_embeds[start_index:end_index] 123 | ) 124 | 125 | 126 | @dataclass 127 | class OwlEncodeImageOutput: 128 | image_embeds: torch.Tensor 129 | image_class_embeds: torch.Tensor 130 | logit_shift: torch.Tensor 131 | logit_scale: torch.Tensor 132 | pred_boxes: torch.Tensor 133 | 134 | 135 | @dataclass 136 | class OwlDecodeOutput: 137 | labels: torch.Tensor 138 | scores: torch.Tensor 139 | boxes: torch.Tensor 140 | input_indices: torch.Tensor 141 | 142 | 143 | class OwlPredictor(torch.nn.Module): 144 | 145 | def __init__(self, 146 | model_name: str = "google/owlvit-base-patch32", 147 | device: str = "cuda", 148 | image_encoder_engine: Optional[str] = None, 149 | image_encoder_engine_max_batch_size: int = 1, 150 | image_preprocessor: Optional[ImagePreprocessor] = None 151 | ): 152 | 153 | super().__init__() 154 | 155 | self.image_size = _owl_get_image_size(model_name) 156 | self.device = device 157 | self.model = OwlViTForObjectDetection.from_pretrained(model_name).to(self.device).eval() 158 | self.processor = OwlViTProcessor.from_pretrained(model_name) 159 | self.patch_size = _owl_get_patch_size(model_name) 160 | self.num_patches_per_side = self.image_size // self.patch_size 161 | self.box_bias = _owl_compute_box_bias(self.num_patches_per_side).to(self.device) 162 | self.num_patches = (self.num_patches_per_side)**2 163 | self.mesh_grid = torch.stack( 164 | torch.meshgrid( 165 | torch.linspace(0., 1., self.image_size), 166 | torch.linspace(0., 1., self.image_size) 167 | ) 168 | ).to(self.device).float() 169 | self.image_encoder_engine = None 170 | if image_encoder_engine is not None: 171 | image_encoder_engine = OwlPredictor.load_image_encoder_engine(image_encoder_engine, image_encoder_engine_max_batch_size) 172 | self.image_encoder_engine = image_encoder_engine 173 | self.image_preprocessor = image_preprocessor.to(self.device).eval() if image_preprocessor else ImagePreprocessor().to(self.device).eval() 174 | 175 | def get_num_patches(self): 176 | return self.num_patches 177 | 178 | def get_device(self): 179 | return self.device 180 | 181 | def get_image_size(self): 182 | return (self.image_size, self.image_size) 183 | 184 | def encode_text(self, text: List[str]) -> OwlEncodeTextOutput: 185 | text_input = self.processor(text=text, return_tensors="pt") 186 | input_ids = text_input['input_ids'].to(self.device) 187 | attention_mask = text_input['attention_mask'].to(self.device) 188 | text_outputs = self.model.owlvit.text_model(input_ids, attention_mask) 189 | text_embeds = 
text_outputs[1] 190 | text_embeds = self.model.owlvit.text_projection(text_embeds) 191 | return OwlEncodeTextOutput(text_embeds=text_embeds) 192 | 193 | def encode_image_torch(self, image: torch.Tensor) -> OwlEncodeImageOutput: 194 | 195 | vision_outputs = self.model.owlvit.vision_model(image) 196 | last_hidden_state = vision_outputs[0] 197 | image_embeds = self.model.owlvit.vision_model.post_layernorm(last_hidden_state) 198 | class_token_out = image_embeds[:, :1, :] 199 | image_embeds = image_embeds[:, 1:, :] * class_token_out 200 | image_embeds = self.model.layer_norm(image_embeds) # 768 dim 201 | 202 | # Box Head 203 | pred_boxes = self.model.box_head(image_embeds) 204 | pred_boxes += self.box_bias 205 | pred_boxes = torch.sigmoid(pred_boxes) 206 | pred_boxes = _owl_center_to_corners_format_torch(pred_boxes) 207 | 208 | # Class Head 209 | image_class_embeds = self.model.class_head.dense0(image_embeds) 210 | logit_shift = self.model.class_head.logit_shift(image_embeds) 211 | logit_scale = self.model.class_head.logit_scale(image_embeds) 212 | logit_scale = self.model.class_head.elu(logit_scale) + 1 213 | 214 | output = OwlEncodeImageOutput( 215 | image_embeds=image_embeds, 216 | image_class_embeds=image_class_embeds, 217 | logit_shift=logit_shift, 218 | logit_scale=logit_scale, 219 | pred_boxes=pred_boxes 220 | ) 221 | 222 | return output 223 | 224 | def encode_image_trt(self, image: torch.Tensor) -> OwlEncodeImageOutput: 225 | return self.image_encoder_engine(image) 226 | 227 | def encode_image(self, image: torch.Tensor) -> OwlEncodeImageOutput: 228 | if self.image_encoder_engine is not None: 229 | return self.encode_image_trt(image) 230 | else: 231 | return self.encode_image_torch(image) 232 | 233 | def extract_rois(self, image: torch.Tensor, rois: torch.Tensor, pad_square: bool = True, padding_scale: float = 1.0): 234 | if len(rois) == 0: 235 | return torch.empty( 236 | (0, image.shape[1], self.image_size, self.image_size), 237 | dtype=image.dtype, 238 | device=image.device 239 | ) 240 | if pad_square: 241 | # pad square 242 | w = padding_scale * (rois[..., 2] - rois[..., 0]) / 2 243 | h = padding_scale * (rois[..., 3] - rois[..., 1]) / 2 244 | cx = (rois[..., 0] + rois[..., 2]) / 2 245 | cy = (rois[..., 1] + rois[..., 3]) / 2 246 | s = torch.max(w, h) 247 | rois = torch.stack([cx-s, cy-s, cx+s, cy+s], dim=-1) 248 | 249 | # compute mask 250 | pad_x = (s - w) / (2 * s) 251 | pad_y = (s - h) / (2 * s) 252 | mask_x = (self.mesh_grid[1][None, ...] > pad_x[..., None, None]) & (self.mesh_grid[1][None, ...] < (1. - pad_x[..., None, None])) 253 | mask_y = (self.mesh_grid[0][None, ...] > pad_y[..., None, None]) & (self.mesh_grid[0][None, ...] < (1. 
- pad_y[..., None, None])) 254 | mask = (mask_x & mask_y) 255 | 256 | # extract rois 257 | roi_images = roi_align(image, [rois], output_size=self.get_image_size()) 258 | 259 | # mask rois 260 | if pad_square: 261 | roi_images = (roi_images * mask[:, None, :, :]) 262 | 263 | return roi_images, rois 264 | 265 | def encode_rois(self, image: torch.Tensor, rois: torch.Tensor, pad_square: bool = True, padding_scale: float=1.0): 266 | # with torch_timeit_sync("extract rois"): 267 | roi_images, rois = self.extract_rois(image, rois, pad_square, padding_scale) 268 | # with torch_timeit_sync("encode images"): 269 | output = self.encode_image(roi_images) 270 | pred_boxes = _owl_box_roi_to_box_global(output.pred_boxes, rois[:, None, :]) 271 | output.pred_boxes = pred_boxes 272 | return output 273 | 274 | def decode(self, 275 | image_output: OwlEncodeImageOutput, 276 | text_output: OwlEncodeTextOutput, 277 | threshold: Union[int, float, List[Union[int, float]]] = 0.1, 278 | ) -> OwlDecodeOutput: 279 | 280 | if isinstance(threshold, (int, float)): 281 | threshold = [threshold] * len(text_output.text_embeds) #apply single threshold to all labels 282 | 283 | num_input_images = image_output.image_class_embeds.shape[0] 284 | 285 | image_class_embeds = image_output.image_class_embeds 286 | image_class_embeds = image_class_embeds / (torch.linalg.norm(image_class_embeds, dim=-1, keepdim=True) + 1e-6) 287 | query_embeds = text_output.text_embeds 288 | query_embeds = query_embeds / (torch.linalg.norm(query_embeds, dim=-1, keepdim=True) + 1e-6) 289 | logits = torch.einsum("...pd,...qd->...pq", image_class_embeds, query_embeds) 290 | logits = (logits + image_output.logit_shift) * image_output.logit_scale 291 | 292 | scores_sigmoid = torch.sigmoid(logits) 293 | scores_max = scores_sigmoid.max(dim=-1) 294 | labels = scores_max.indices 295 | scores = scores_max.values 296 | masks = [] 297 | for i, thresh in enumerate(threshold): 298 | label_mask = labels == i 299 | score_mask = scores > thresh 300 | obj_mask = torch.logical_and(label_mask,score_mask) 301 | masks.append(obj_mask) 302 | 303 | mask = masks[0] 304 | for mask_t in masks[1:]: 305 | mask = torch.logical_or(mask, mask_t) 306 | 307 | input_indices = torch.arange(0, num_input_images, dtype=labels.dtype, device=labels.device) 308 | input_indices = input_indices[:, None].repeat(1, self.num_patches) 309 | 310 | return OwlDecodeOutput( 311 | labels=labels[mask], 312 | scores=scores[mask], 313 | boxes=image_output.pred_boxes[mask], 314 | input_indices=input_indices[mask] 315 | ) 316 | 317 | @staticmethod 318 | def get_image_encoder_input_names(): 319 | return ["image"] 320 | 321 | @staticmethod 322 | def get_image_encoder_output_names(): 323 | names = [ 324 | "image_embeds", 325 | "image_class_embeds", 326 | "logit_shift", 327 | "logit_scale", 328 | "pred_boxes" 329 | ] 330 | return names 331 | 332 | 333 | def export_image_encoder_onnx(self, 334 | output_path: str, 335 | use_dynamic_axes: bool = True, 336 | batch_size: int = 1, 337 | onnx_opset=17 338 | ): 339 | 340 | class TempModule(torch.nn.Module): 341 | def __init__(self, parent): 342 | super().__init__() 343 | self.parent = parent 344 | def forward(self, image): 345 | output = self.parent.encode_image_torch(image) 346 | return ( 347 | output.image_embeds, 348 | output.image_class_embeds, 349 | output.logit_shift, 350 | output.logit_scale, 351 | output.pred_boxes 352 | ) 353 | 354 | data = torch.randn(batch_size, 3, self.image_size, self.image_size).to(self.device) 355 | 356 | if use_dynamic_axes: 357 | 
dynamic_axes = { 358 | "image": {0: "batch"}, 359 | "image_embeds": {0: "batch"}, 360 | "image_class_embeds": {0: "batch"}, 361 | "logit_shift": {0: "batch"}, 362 | "logit_scale": {0: "batch"}, 363 | "pred_boxes": {0: "batch"} 364 | } 365 | else: 366 | dynamic_axes = {} 367 | 368 | model = TempModule(self) 369 | 370 | torch.onnx.export( 371 | model, 372 | data, 373 | output_path, 374 | input_names=self.get_image_encoder_input_names(), 375 | output_names=self.get_image_encoder_output_names(), 376 | dynamic_axes=dynamic_axes, 377 | opset_version=onnx_opset 378 | ) 379 | 380 | @staticmethod 381 | def load_image_encoder_engine(engine_path: str, max_batch_size: int = 1): 382 | import tensorrt as trt 383 | from torch2trt import TRTModule 384 | 385 | with trt.Logger() as logger, trt.Runtime(logger) as runtime: 386 | with open(engine_path, 'rb') as f: 387 | engine_bytes = f.read() 388 | engine = runtime.deserialize_cuda_engine(engine_bytes) 389 | 390 | base_module = TRTModule( 391 | engine, 392 | input_names=OwlPredictor.get_image_encoder_input_names(), 393 | output_names=OwlPredictor.get_image_encoder_output_names() 394 | ) 395 | 396 | class Wrapper(torch.nn.Module): 397 | def __init__(self, base_module: TRTModule, max_batch_size: int): 398 | super().__init__() 399 | self.base_module = base_module 400 | self.max_batch_size = max_batch_size 401 | 402 | @torch.no_grad() 403 | def forward(self, image): 404 | 405 | b = image.shape[0] 406 | 407 | results = [] 408 | 409 | for start_index in range(0, b, self.max_batch_size): 410 | end_index = min(b, start_index + self.max_batch_size) 411 | image_slice = image[start_index:end_index] 412 | # with torch_timeit_sync("run_engine"): 413 | output = self.base_module(image_slice) 414 | results.append( 415 | output 416 | ) 417 | 418 | return OwlEncodeImageOutput( 419 | image_embeds=torch.cat([r[0] for r in results], dim=0), 420 | image_class_embeds=torch.cat([r[1] for r in results], dim=0), 421 | logit_shift=torch.cat([r[2] for r in results], dim=0), 422 | logit_scale=torch.cat([r[3] for r in results], dim=0), 423 | pred_boxes=torch.cat([r[4] for r in results], dim=0) 424 | ) 425 | 426 | image_encoder = Wrapper(base_module, max_batch_size) 427 | 428 | return image_encoder 429 | 430 | def build_image_encoder_engine(self, 431 | engine_path: str, 432 | max_batch_size: int = 1, 433 | fp16_mode = True, 434 | onnx_path: Optional[str] = None, 435 | onnx_opset: int = 17 436 | ): 437 | 438 | if onnx_path is None: 439 | onnx_dir = tempfile.mkdtemp() 440 | onnx_path = os.path.join(onnx_dir, "image_encoder.onnx") 441 | self.export_image_encoder_onnx(onnx_path, onnx_opset=onnx_opset) 442 | 443 | args = ["/usr/src/tensorrt/bin/trtexec"] 444 | 445 | args.append(f"--onnx={onnx_path}") 446 | args.append(f"--saveEngine={engine_path}") 447 | 448 | if fp16_mode: 449 | args += ["--fp16"] 450 | 451 | args += [f"--shapes=image:1x3x{self.image_size}x{self.image_size}"] 452 | 453 | subprocess.call(args) 454 | 455 | return self.load_image_encoder_engine(engine_path, max_batch_size) 456 | 457 | def predict(self, 458 | image: PIL.Image, 459 | text: List[str], 460 | text_encodings: Optional[OwlEncodeTextOutput], 461 | threshold: Union[int, float, List[Union[int, float]]] = 0.1, 462 | pad_square: bool = True, 463 | 464 | ) -> OwlDecodeOutput: 465 | 466 | image_tensor = self.image_preprocessor.preprocess_pil_image(image) 467 | 468 | if text_encodings is None: 469 | text_encodings = self.encode_text(text) 470 | 471 | rois = torch.tensor([[0, 0, image.width, image.height]], 
dtype=image_tensor.dtype, device=image_tensor.device) 472 | 473 | image_encodings = self.encode_rois(image_tensor, rois, pad_square=pad_square) 474 | 475 | return self.decode(image_encodings, text_encodings, threshold) 476 | 477 | -------------------------------------------------------------------------------- /nanoowl/sync_timer.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | import time 18 | import torch 19 | 20 | 21 | __all__ = [ 22 | "SyncTimer" 23 | ] 24 | 25 | 26 | class SyncTimer(): 27 | 28 | def __init__(self, name: str): 29 | self.name = name 30 | self.t0 = None 31 | 32 | def __enter__(self, *args, **kwargs): 33 | self.t0 = time.perf_counter_ns() 34 | 35 | def __exit__(self, *args, **kwargs): 36 | torch.cuda.current_stream().synchronize() 37 | t1 = time.perf_counter_ns() 38 | dt = (t1 - self.t0) / 1e9 39 | print(f"{self.name} FPS: {round(1./dt, 3)}") -------------------------------------------------------------------------------- /nanoowl/tree.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License.
15 | 16 | 17 | import json 18 | from enum import Enum 19 | from typing import List, Optional, Mapping 20 | from .clip_predictor import ClipEncodeTextOutput 21 | from .owl_predictor import OwlEncodeTextOutput 22 | 23 | 24 | __all__ = [ 25 | "TreeOp", 26 | "TreeNode", 27 | "Tree" 28 | ] 29 | 30 | 31 | class TreeOp(Enum): 32 | DETECT = "detect" 33 | CLASSIFY = "classify" 34 | 35 | def __str__(self) -> str: 36 | return str(self.value) 37 | 38 | 39 | class TreeNode: 40 | op: TreeOp 41 | input: int 42 | outputs: List[int] 43 | 44 | def __init__(self, op: TreeOp, input: int, outputs: Optional[List[int]] = None): 45 | self.op = op 46 | self.input = input 47 | self.outputs = [] if outputs is None else outputs 48 | 49 | def to_dict(self): 50 | return { 51 | "op": str(self.op), 52 | "input": self.input, 53 | "outputs": self.outputs 54 | } 55 | 56 | @staticmethod 57 | def from_dict(node_dict: dict): 58 | 59 | if "op" not in node_dict: 60 | raise RuntimeError("Missing 'op' field.") 61 | 62 | if "input" not in node_dict: 63 | raise RuntimeError("Missing 'input' field.") 64 | 65 | if "outputs" not in node_dict: 66 | raise RuntimeError("Missing 'input' field.") 67 | 68 | return TreeNode( 69 | op=node_dict["op"], 70 | input=node_dict["input"], 71 | outputs=node_dict["outputs"] 72 | ) 73 | 74 | 75 | class Tree: 76 | nodes: List[TreeNode] 77 | labels: List[str] 78 | 79 | def __init__(self, nodes, labels): 80 | self.nodes = nodes 81 | self.labels = labels 82 | self._label_index_to_node_map = self._build_label_index_to_node_map() 83 | 84 | def _build_label_index_to_node_map(self) -> Mapping[int, "TreeNode"]: 85 | label_to_node_map = {} 86 | for node in self.nodes: 87 | for label_index in node.outputs: 88 | if label_index in label_to_node_map: 89 | raise RuntimeError("Duplicate output label.") 90 | label_to_node_map[label_index] = node 91 | return label_to_node_map 92 | 93 | def to_dict(self): 94 | return { 95 | "nodes": [node.to_dict() for node in self.nodes], 96 | "labels": self.labels 97 | } 98 | 99 | @staticmethod 100 | def from_prompt(prompt: str): 101 | 102 | nodes = [] 103 | node_stack = [] 104 | label_index_stack = [0] 105 | labels = ["image"] 106 | label_index = 0 107 | 108 | for ch in prompt: 109 | 110 | if ch == "[": 111 | label_index += 1 112 | node = TreeNode(op=TreeOp.DETECT, input=label_index_stack[-1]) 113 | node.outputs.append(label_index) 114 | node_stack.append(node) 115 | label_index_stack.append(label_index) 116 | labels.append("") 117 | nodes.append(node) 118 | elif ch == "]": 119 | if len(node_stack) == 0: 120 | raise RuntimeError("Unexpected ']'.") 121 | node = node_stack.pop() 122 | if node.op != TreeOp.DETECT: 123 | raise RuntimeError("Unexpected ']'.") 124 | label_index_stack.pop() 125 | elif ch == "(": 126 | label_index = label_index + 1 127 | node = TreeNode(op=TreeOp.CLASSIFY, input=label_index_stack[-1]) 128 | node.outputs.append(label_index) 129 | node_stack.append(node) 130 | label_index_stack.append(label_index) 131 | labels.append("") 132 | nodes.append(node) 133 | elif ch == ")": 134 | if len(node_stack) == 0: 135 | raise RuntimeError("Unexpected ')'.") 136 | node = node_stack.pop() 137 | if node.op != TreeOp.CLASSIFY: 138 | raise RuntimeError("Unexpected ')'.") 139 | label_index_stack.pop() 140 | elif ch == ",": 141 | label_index_stack.pop() 142 | label_index = label_index + 1 143 | label_index_stack.append(label_index) 144 | node_stack[-1].outputs.append(label_index) 145 | labels.append("") 146 | else: 147 | labels[label_index_stack[-1]] += ch 148 | 149 | if len(node_stack) 
> 0: 150 | if node_stack[-1].op == TreeOp.DETECT: 151 | raise RuntimeError("Missing ']'.") 152 | if node_stack[-1].op == TreeOp.CLASSIFY: 153 | raise RuntimeError("Missing ')'.") 154 | 155 | labels = [label.strip() for label in labels] 156 | 157 | graph = Tree(nodes=nodes, labels=labels) 158 | 159 | return graph 160 | 161 | def to_json(self, indent: Optional[int] = None) -> str: 162 | return json.dumps(self.to_dict(), indent=indent) 163 | 164 | @staticmethod 165 | def from_dict(tree_dict: dict) -> "Tree": 166 | 167 | if "nodes" not in tree_dict: 168 | raise RuntimeError("Missing 'nodes' field.") 169 | 170 | if "labels" not in tree_dict: 171 | raise RuntimeError("Missing 'labels' field.") 172 | 173 | nodes = [TreeNode.from_dict(node_dict) for node_dict in tree_dict["nodes"]] 174 | labels = tree_dict["labels"] 175 | 176 | return Tree(nodes=nodes, labels=labels) 177 | 178 | @staticmethod 179 | def from_json(tree_json: str) -> "Tree": 180 | tree_dict = json.loads(tree_json) 181 | return Tree.from_dict(tree_dict) 182 | 183 | def get_op_for_label_index(self, label_index: int): 184 | if label_index not in self._label_index_to_node_map: 185 | return None 186 | return self._label_index_to_node_map[label_index].op 187 | 188 | def get_label_indices_with_op(self, op: TreeOp): 189 | return [ 190 | index for index in range(len(self.labels)) 191 | if self.get_op_for_label_index(index) == op 192 | ] 193 | 194 | def get_classify_label_indices(self): 195 | return self.get_label_indices_with_op(TreeOp.CLASSIFY) 196 | 197 | def get_detect_label_indices(self): 198 | return self.get_label_indices_with_op(TreeOp.DETECT) 199 | 200 | def find_nodes_with_input(self, input_index: int): 201 | return [n for n in self.nodes if n.input == input_index] 202 | 203 | def find_detect_nodes_with_input(self, input_index: int): 204 | return [n for n in self.find_nodes_with_input(input_index) if n.op == TreeOp.DETECT] 205 | 206 | def find_classify_nodes_with_input(self, input_index: int): 207 | return [n for n in self.find_nodes_with_input(input_index) if n.op == TreeOp.CLASSIFY] 208 | 209 | def get_label_depth(self, index): 210 | depth = 0 211 | while index in self._label_index_to_node_map: 212 | depth += 1 213 | node = self._label_index_to_node_map[index] 214 | index = node.input 215 | return depth 216 | 217 | def get_label_depth_map(self): 218 | depths = {} 219 | for i in range(len(self.labels)): 220 | depths[i] = self.get_label_depth(i) 221 | return depths 222 | 223 | def get_label_map(self): 224 | label_map = {} 225 | for i in range(len(self.labels)): 226 | label_map[i] = self.labels[i] 227 | return label_map -------------------------------------------------------------------------------- /nanoowl/tree_drawing.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | 17 | import PIL.Image 18 | import PIL.ImageDraw 19 | import cv2 20 | from .tree import Tree 21 | from .tree_predictor import TreeOutput 22 | import matplotlib.pyplot as plt 23 | import numpy as np 24 | from typing import List 25 | 26 | 27 | def get_colors(count: int): 28 | cmap = plt.cm.get_cmap("rainbow", count) 29 | colors = [] 30 | for i in range(count): 31 | color = cmap(i) 32 | color = [int(255 * value) for value in color] 33 | colors.append(tuple(color)) 34 | return colors 35 | 36 | 37 | def draw_tree_output(image, output: TreeOutput, tree: Tree, draw_text=True, num_colors=8): 38 | detections = output.detections 39 | is_pil = not isinstance(image, np.ndarray) 40 | if is_pil: 41 | image = np.asarray(image) 42 | font = cv2.FONT_HERSHEY_SIMPLEX 43 | font_scale = 0.75 44 | colors = get_colors(num_colors) 45 | label_map = tree.get_label_map() 46 | label_depths = tree.get_label_depth_map() 47 | for detection in detections: 48 | box = [int(x) for x in detection.box] 49 | pt0 = (box[0], box[1]) 50 | pt1 = (box[2], box[3]) 51 | box_depth = min(label_depths[i] for i in detection.labels) 52 | cv2.rectangle( 53 | image, 54 | pt0, 55 | pt1, 56 | colors[box_depth % num_colors], 57 | 3 58 | ) 59 | if draw_text: 60 | offset_y = 30 61 | offset_x = 8 62 | for label in detection.labels: 63 | label_text = label_map[label] 64 | cv2.putText( 65 | image, 66 | label_text, 67 | (box[0] + offset_x, box[1] + offset_y), 68 | font, 69 | font_scale, 70 | colors[label % num_colors], 71 | 2,# thickness 72 | cv2.LINE_AA 73 | ) 74 | offset_y += 18 75 | if is_pil: 76 | image = PIL.Image.fromarray(image) 77 | return image -------------------------------------------------------------------------------- /nanoowl/tree_predictor.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | 17 | from .tree import Tree, TreeOp 18 | from .owl_predictor import OwlPredictor, OwlEncodeTextOutput, OwlEncodeImageOutput 19 | from .clip_predictor import ClipPredictor, ClipEncodeTextOutput, ClipEncodeImageOutput 20 | from .image_preprocessor import ImagePreprocessor 21 | 22 | import torch 23 | import PIL.Image 24 | from typing import Optional, Tuple, List, Mapping, Dict 25 | from dataclasses import dataclass 26 | 27 | 28 | @dataclass 29 | class TreeDetection: 30 | id: int 31 | parent_id: int 32 | box: Tuple[float, float, float, float] 33 | labels: List[int] 34 | scores: List[int] 35 | 36 | 37 | @dataclass 38 | class TreeOutput: 39 | detections: List[TreeDetection] 40 | 41 | 42 | class TreePredictor(torch.nn.Module): 43 | 44 | def __init__(self, 45 | owl_predictor: Optional[OwlPredictor] = None, 46 | clip_predictor: Optional[ClipPredictor] = None, 47 | image_preprocessor: Optional[ImagePreprocessor] = None, 48 | device: str = "cuda" 49 | ): 50 | super().__init__() 51 | self.owl_predictor = OwlPredictor() if owl_predictor is None else owl_predictor 52 | self.clip_predictor = ClipPredictor() if clip_predictor is None else clip_predictor 53 | self.image_preprocessor = ImagePreprocessor().to(device).eval() if image_preprocessor is None else image_preprocessor 54 | 55 | def encode_clip_text(self, tree: Tree) -> Dict[int, ClipEncodeTextOutput]: 56 | label_indices = tree.get_classify_label_indices() 57 | if len(label_indices) == 0: 58 | return {} 59 | labels = [tree.labels[index] for index in label_indices] 60 | text_encodings = self.clip_predictor.encode_text(labels) 61 | label_encodings = {} 62 | for i in range(len(labels)): 63 | label_encodings[label_indices[i]] = text_encodings.slice(i, i+1) 64 | return label_encodings 65 | 66 | def encode_owl_text(self, tree: Tree) -> Dict[int, OwlEncodeTextOutput]: 67 | label_indices = tree.get_detect_label_indices() 68 | if len(label_indices) == 0: 69 | return {} 70 | labels = [tree.labels[index] for index in label_indices] 71 | text_encodings = self.owl_predictor.encode_text(labels) 72 | label_encodings = {} 73 | for i in range(len(labels)): 74 | label_encodings[label_indices[i]] = text_encodings.slice(i, i+1) 75 | return label_encodings 76 | 77 | @torch.no_grad() 78 | def predict(self, 79 | image: PIL.Image.Image, 80 | tree: Tree, 81 | threshold: float = 0.1, 82 | clip_text_encodings: Optional[Dict[int, ClipEncodeTextOutput]] = None, 83 | owl_text_encodings: Optional[Dict[int, OwlEncodeTextOutput]] = None 84 | ): 85 | 86 | if clip_text_encodings is None: 87 | clip_text_encodings = self.encode_clip_text(tree) 88 | 89 | if owl_text_encodings is None: 90 | owl_text_encodings = self.encode_owl_text(tree) 91 | 92 | image_tensor = self.image_preprocessor.preprocess_pil_image(image) 93 | boxes = { 94 | 0: torch.tensor([[0, 0, image.width, image.height]], dtype=image_tensor.dtype, device=image_tensor.device) 95 | } 96 | scores = { 97 | 0: torch.tensor([1.], dtype=torch.float, device=image_tensor.device) 98 | } 99 | instance_ids = { 100 | 0: torch.tensor([0], dtype=torch.int64, device=image_tensor.device) 101 | } 102 | parent_instance_ids = { 103 | 0: torch.tensor([-1], dtype=torch.int64, device=image_tensor.device) 104 | } 105 | 106 | owl_image_encodings: Dict[int, OwlEncodeImageOutput] = {} 107 | clip_image_encodings: Dict[int, ClipEncodeImageOutput] = {} 108 | 109 | global_instance_id = 1 110 | 111 | queue = [0] 112 | 113 | while queue: 114 | label_index = queue.pop(0) 115 | 116 | detect_nodes = tree.find_detect_nodes_with_input(label_index) 117 | 
classify_nodes = tree.find_classify_nodes_with_input(label_index) 118 | 119 | # Run OWL image encode if required 120 | if len(detect_nodes) > 0 and label_index not in owl_image_encodings: 121 | owl_image_encodings[label_index] = self.owl_predictor.encode_rois(image_tensor, boxes[label_index]) 122 | 123 | 124 | # Run CLIP image encode if required 125 | if len(classify_nodes) > 0 and label_index not in clip_image_encodings: 126 | clip_image_encodings[label_index] = self.clip_predictor.encode_rois(image_tensor, boxes[label_index]) 127 | 128 | # Decode detect nodes 129 | for node in detect_nodes: 130 | 131 | if node.input not in owl_image_encodings: 132 | raise RuntimeError("Missing owl image encodings for node.") 133 | 134 | # gather encodings 135 | owl_text_encodings_for_node = OwlEncodeTextOutput( 136 | text_embeds=torch.cat([ 137 | owl_text_encodings[i].text_embeds for i in node.outputs 138 | ], dim=0) 139 | ) 140 | 141 | owl_node_output = self.owl_predictor.decode( 142 | owl_image_encodings[node.input], 143 | owl_text_encodings_for_node, 144 | threshold=threshold 145 | ) 146 | 147 | num_detections = len(owl_node_output.labels) 148 | instance_ids_for_node = torch.arange(global_instance_id, global_instance_id + num_detections, dtype=torch.int64, device=owl_node_output.labels.device) 149 | parent_instance_ids_for_node = instance_ids[node.input][owl_node_output.input_indices] 150 | global_instance_id += num_detections 151 | 152 | for i in range(len(node.outputs)): 153 | mask = owl_node_output.labels == i 154 | out_idx = node.outputs[i] 155 | boxes[out_idx] = owl_node_output.boxes[mask] 156 | scores[out_idx] = owl_node_output.scores[mask] 157 | instance_ids[out_idx] = instance_ids_for_node[mask] 158 | parent_instance_ids[out_idx] = parent_instance_ids_for_node[mask] 159 | 160 | for node in classify_nodes: 161 | 162 | if node.input not in clip_image_encodings: 163 | raise RuntimeError("Missing clip image encodings for node.") 164 | 165 | clip_text_encodings_for_node = ClipEncodeTextOutput( 166 | text_embeds=torch.cat([ 167 | clip_text_encodings[i].text_embeds for i in node.outputs 168 | ], dim=0) 169 | ) 170 | 171 | clip_node_output = self.clip_predictor.decode( 172 | clip_image_encodings[node.input], 173 | clip_text_encodings_for_node 174 | ) 175 | 176 | parent_instance_ids_for_node = instance_ids[node.input] 177 | 178 | for i in range(len(node.outputs)): 179 | mask = clip_node_output.labels == i 180 | output_buffer = node.outputs[i] 181 | scores[output_buffer] = clip_node_output.scores[mask].float() 182 | boxes[output_buffer] = boxes[label_index][mask].float() 183 | instance_ids[output_buffer] = instance_ids[node.input][mask] 184 | parent_instance_ids[output_buffer] = parent_instance_ids[node.input][mask] 185 | 186 | for node in detect_nodes: 187 | for buf in node.outputs: 188 | if buf in scores and len(scores[buf]) > 0: 189 | queue.append(buf) 190 | 191 | for node in classify_nodes: 192 | for buf in node.outputs: 193 | if buf in scores and len(scores[buf]) > 0: 194 | queue.append(buf) 195 | 196 | # Fill outputs 197 | detections: Dict[int, TreeDetection] = {} 198 | for i in boxes.keys(): 199 | for box, score, instance_id, parent_instance_id in zip(boxes[i], scores[i], instance_ids[i], parent_instance_ids[i]): 200 | instance_id = int(instance_id) 201 | score = float(score) 202 | box = box.tolist() 203 | parent_instance_id = int(parent_instance_id) 204 | if instance_id in detections: 205 | detections[instance_id].labels.append(i) 206 | detections[instance_id].scores.append(score) 207 | else: 
208 | detections[instance_id] = TreeDetection( 209 | id=instance_id, 210 | parent_id=parent_instance_id, 211 | box=box, 212 | labels=[i], 213 | scores=[score] 214 | ) 215 | 216 | return TreeOutput(detections=detections.values()) -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | 4 | setup( 5 | name="nanoowl", 6 | version="0.0.0", 7 | packages=find_packages() 8 | ) -------------------------------------------------------------------------------- /test/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | -------------------------------------------------------------------------------- /test/test_clip_predictor.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | 17 | import pytest 18 | import torch 19 | import PIL.Image 20 | from nanoowl.clip_predictor import ClipPredictor 21 | from nanoowl.image_preprocessor import ImagePreprocessor 22 | 23 | 24 | def test_get_image_size(): 25 | clip_predictor = ClipPredictor() 26 | assert clip_predictor.get_image_size() == (224, 224) 27 | 28 | 29 | def test_clip_encode_text(): 30 | 31 | clip_predictor = ClipPredictor() 32 | 33 | text_encode_output = clip_predictor.encode_text(["a frog", "a dog"]) 34 | 35 | assert text_encode_output.text_embeds.shape == (2, 512) 36 | 37 | 38 | def test_clip_encode_image(): 39 | 40 | clip_predictor = ClipPredictor() 41 | 42 | image = PIL.Image.open("assets/owl_glove_small.jpg") 43 | 44 | image = image.resize((224, 224)) 45 | 46 | image_preprocessor = ImagePreprocessor().to(clip_predictor.device).eval() 47 | 48 | image_tensor = image_preprocessor.preprocess_pil_image(image) 49 | 50 | image_encode_output = clip_predictor.encode_image(image_tensor) 51 | 52 | assert image_encode_output.image_embeds.shape == (1, 512) 53 | 54 | 55 | def test_clip_classify(): 56 | 57 | clip_predictor = ClipPredictor() 58 | 59 | image = PIL.Image.open("assets/frog.jpg") 60 | 61 | image = image.resize((224, 224)) 62 | image_preprocessor = ImagePreprocessor().to(clip_predictor.device).eval() 63 | 64 | image_tensor = image_preprocessor.preprocess_pil_image(image) 65 | 66 | text_output = clip_predictor.encode_text(["a frog", "an owl"]) 67 | image_output = clip_predictor.encode_image(image_tensor) 68 | 69 | classify_output = clip_predictor.decode(image_output, text_output) 70 | 71 | assert classify_output.labels[0] == 0 -------------------------------------------------------------------------------- /test/test_image_preprocessor.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | import pytest 18 | import torch 19 | import PIL.Image 20 | from nanoowl.image_preprocessor import ImagePreprocessor 21 | 22 | 23 | def test_image_preprocessor_preprocess_pil_image(): 24 | 25 | image_preproc = ImagePreprocessor().to("cuda").eval() 26 | 27 | image = PIL.Image.open("assets/owl_glove_small.jpg") 28 | 29 | image_tensor = image_preproc.preprocess_pil_image(image) 30 | 31 | assert image_tensor.shape == (1, 3, 499, 756) 32 | assert torch.allclose(image_tensor.mean(), torch.zeros_like(image_tensor), atol=1, rtol=1) 33 | -------------------------------------------------------------------------------- /test/test_owl_predictor.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | import pytest 18 | import torch 19 | import PIL.Image 20 | from nanoowl.owl_predictor import OwlPredictor 21 | from nanoowl.image_preprocessor import ImagePreprocessor 22 | 23 | 24 | def test_owl_predictor_get_image_size(): 25 | owl_predictor = OwlPredictor() 26 | assert owl_predictor.get_image_size() == (768, 768) 27 | 28 | 29 | def test_owl_predictor_encode_text(): 30 | 31 | owl_predictor = OwlPredictor() 32 | 33 | text_encode_output = owl_predictor.encode_text(["a frog", "a dog"]) 34 | 35 | assert text_encode_output.text_embeds.shape == (2, 512) 36 | 37 | 38 | def test_owl_predictor_encode_image(): 39 | 40 | owl_predictor = OwlPredictor() 41 | 42 | image = PIL.Image.open("assets/owl_glove_small.jpg") 43 | 44 | image = image.resize(owl_predictor.get_image_size()) 45 | 46 | image_preprocessor = ImagePreprocessor().to(owl_predictor.device).eval() 47 | 48 | image_tensor = image_preprocessor.preprocess_pil_image(image) 49 | 50 | image_encode_output = owl_predictor.encode_image(image_tensor) 51 | 52 | assert image_encode_output.image_class_embeds.shape == (1, owl_predictor.get_num_patches(), 512) 53 | 54 | 55 | def test_owl_predictor_decode(): 56 | 57 | owl_predictor = OwlPredictor() 58 | 59 | image = PIL.Image.open("assets/owl_glove_small.jpg") 60 | 61 | image = image.resize(owl_predictor.get_image_size()) 62 | image_preprocessor = ImagePreprocessor().to(owl_predictor.device).eval() 63 | 64 | image_tensor = image_preprocessor.preprocess_pil_image(image) 65 | 66 | text_output = owl_predictor.encode_text(["an owl"]) 67 | image_output = owl_predictor.encode_image(image_tensor) 68 | 69 | classify_output = owl_predictor.decode(image_output, text_output) 70 | 71 | assert len(classify_output.labels) == 1 72 | assert classify_output.labels[0] == 0 73 | assert classify_output.boxes.shape == (1, 4) 74 | assert classify_output.input_indices == 0 75 | 76 | 77 | def test_owl_predictor_decode_multiple_images(): 78 | 79 | owl_predictor = OwlPredictor() 80 | image_preprocessor = ImagePreprocessor().to(owl_predictor.device).eval() 81 | 82 | image_paths = [ 83 | "assets/owl_glove_small.jpg", 84 | "assets/frog.jpg" 85 | ] 86 | 87 | images = [] 88 | for image_path in image_paths: 89 | image = PIL.Image.open(image_path) 90 | image = image.resize(owl_predictor.get_image_size()) 91 | image = image_preprocessor.preprocess_pil_image(image) 92 | images.append(image) 93 | 94 | images = torch.cat(images, dim=0) 95 | 96 | text_output = owl_predictor.encode_text(["an owl", "a frog"]) 97 | image_output = owl_predictor.encode_image(images) 98 | 99 | decode_output = owl_predictor.decode(image_output, text_output) 100 | 101 | # check number of detections 102 | assert len(decode_output.labels) == 2 103 | 104 | # check owl detection 105 | assert decode_output.labels[0] == 0 106 | assert decode_output.input_indices[0] == 0 107 | 108 | # check frog detection 109 | assert
decode_output.labels[1] == 1 110 | assert decode_output.input_indices[1] == 1 111 | -------------------------------------------------------------------------------- /test/test_tree.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | import pytest 18 | from nanoowl.tree import ( 19 | Tree, 20 | TreeOp 21 | ) 22 | 23 | 24 | def test_tree_from_prompt(): 25 | 26 | graph = Tree.from_prompt("[a face]") 27 | 28 | assert len(graph.nodes) == 1 29 | assert len(graph.labels) == 2 30 | assert graph.labels[0] == "image" 31 | assert graph.labels[1] == "a face" 32 | assert graph.nodes[0].op == TreeOp.DETECT 33 | 34 | graph = Tree.from_prompt("[a face](a dog, a cat)") 35 | 36 | assert len(graph.nodes) == 2 37 | assert len(graph.labels) == 4 38 | assert graph.labels[0] == "image" 39 | assert graph.labels[1] == "a face" 40 | assert graph.labels[2] == "a dog" 41 | assert graph.labels[3] == "a cat" 42 | assert graph.nodes[0].op == TreeOp.DETECT 43 | assert graph.nodes[1].op == TreeOp.CLASSIFY 44 | 45 | with pytest.raises(RuntimeError): 46 | Tree.from_prompt("]a face]") 47 | 48 | with pytest.raises(RuntimeError): 49 | Tree.from_prompt("[a face") 50 | 51 | with pytest.raises(RuntimeError): 52 | Tree.from_prompt("[a face)") 53 | 54 | with pytest.raises(RuntimeError): 55 | Tree.from_prompt("[a face]]") 56 | 57 | 58 | def test_tree_to_dict(): 59 | 60 | tree = Tree.from_prompt("[a[b,c(d,e)]]") 61 | tree_dict = tree.to_dict() 62 | assert "nodes" in tree_dict 63 | assert "labels" in tree_dict 64 | assert len(tree_dict["nodes"]) == 3 65 | assert len(tree_dict["labels"]) == 6 66 | 67 | 68 | 69 | def test_tree_from_prompt_classify_only(): 70 | 71 | tree = Tree.from_prompt("(office, home, outdoors, gym)") 72 | 73 | print(tree) -------------------------------------------------------------------------------- /test/test_tree_predictor.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License.
15 | 16 | 17 | import pytest 18 | import PIL.Image 19 | from nanoowl.tree_predictor import TreePredictor 20 | from nanoowl.tree import Tree 21 | 22 | 23 | def test_encode_clip_labels(): 24 | 25 | predictor = TreePredictor() 26 | tree = Tree.from_prompt("(sunny, rainy)") 27 | 28 | text_encodings = predictor.encode_clip_text(tree) 29 | 30 | assert len(text_encodings) == 2 31 | assert 1 in text_encodings 32 | assert 2 in text_encodings 33 | assert text_encodings[1].text_embeds.shape == (1, 512) 34 | 35 | 36 | def test_encode_owl_labels(): 37 | 38 | predictor = TreePredictor() 39 | tree = Tree.from_prompt("[a face [an eye, a nose]]") 40 | 41 | text_encodings = predictor.encode_owl_text(tree) 42 | 43 | assert len(text_encodings) == 3 44 | assert 1 in text_encodings 45 | assert 2 in text_encodings 46 | assert 3 in text_encodings 47 | assert text_encodings[1].text_embeds.shape == (1, 512) 48 | 49 | 50 | def test_encode_clip_owl_labels_mixed(): 51 | 52 | predictor = TreePredictor() 53 | tree = Tree.from_prompt("[a face [an eye, a nose](happy, sad)]") 54 | 55 | owl_text_encodings = predictor.encode_owl_text(tree) 56 | clip_text_encodings = predictor.encode_clip_text(tree) 57 | 58 | assert len(owl_text_encodings) == 3 59 | assert len(clip_text_encodings) == 2 60 | 61 | 62 | def test_tree_predictor_predict(): 63 | 64 | predictor = TreePredictor() 65 | tree = Tree.from_prompt("[an owl]") 66 | 67 | 68 | image = PIL.Image.open("assets/owl_glove.jpg") 69 | 70 | detections = predictor.predict(image, tree) 71 | 72 | 73 | def test_tree_predictor_predict_classify(): 74 | 75 | predictor = TreePredictor() 76 | tree = Tree.from_prompt("(outdoors, indoors)") 77 | 78 | 79 | image = PIL.Image.open("assets/owl_glove.jpg") 80 | 81 | detections = predictor.predict(image, tree) 82 | 83 | print(detections) 84 | --------------------------------------------------------------------------------
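A minimal usage sketch (not one of the repository files above): it shows how the pieces fit together, assuming the package is installed on a CUDA-capable machine, that the bundled test image assets/owl_glove_small.jpg is available as listed in the repository tree, and that the prompt strings and threshold below are only illustrative.

import PIL.Image
from nanoowl.owl_predictor import OwlPredictor

# Loads the OWL-ViT detector onto the default "cuda" device (assumes a GPU is present).
predictor = OwlPredictor("google/owlvit-base-patch32")

image = PIL.Image.open("assets/owl_glove_small.jpg")

# Text embeddings can be computed once and reused across many images or video frames.
text = ["an owl", "a glove"]
text_encodings = predictor.encode_text(text)

output = predictor.predict(
    image=image,
    text=text,
    text_encodings=text_encodings,
    threshold=0.1
)

# OwlDecodeOutput holds flat tensors: labels (indices into `text`), scores,
# boxes in (x0, y0, x1, y1) image coordinates, and input_indices.
for label, score, box in zip(output.labels, output.scores, output.boxes):
    print(text[int(label)], float(score), [round(x, 1) for x in box.tolist()])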