├── .gitignore
├── .vscode
│   └── settings.json
├── LICENSE.md
├── README.md
├── assets
│   ├── basic_usage_out.jpg
│   ├── camera.jpg
│   ├── class.jpg
│   ├── frog.jpg
│   ├── jetson_person_2x.gif
│   ├── nanoowl_jetbot.gif
│   ├── owl_glove.jpg
│   ├── owl_glove_out.jpg
│   ├── owl_glove_small.jpg
│   ├── owl_gradio_demo.jpg
│   ├── owl_predict_out.jpg
│   └── tree_predict_out.jpg
├── docker
│   └── 23-01
│       ├── Dockerfile
│       ├── build.sh
│       └── run.sh
├── examples
│   ├── owl_predict.py
│   ├── tree_demo
│   │   ├── index.html
│   │   └── tree_demo.py
│   └── tree_predict.py
├── nanoowl
│   ├── __init__.py
│   ├── build_image_encoder_engine.py
│   ├── clip_predictor.py
│   ├── image_preprocessor.py
│   ├── owl_drawing.py
│   ├── owl_predictor.py
│   ├── sync_timer.py
│   ├── tree.py
│   ├── tree_drawing.py
│   └── tree_predictor.py
├── setup.py
└── test
    ├── __init__.py
    ├── test_clip_predictor.py
    ├── test_image_preprocessor.py
    ├── test_owl_predictor.py
    ├── test_tree.py
    └── test_tree_predictor.py
/.gitignore:
--------------------------------------------------------------------------------
1 | data
2 | *.egg-info
3 | *.pyc
4 | charmap.json.gz
5 | sandbox
--------------------------------------------------------------------------------
/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | {
2 | "python.testing.pytestArgs": [
3 | "test"
4 | ],
5 | "python.testing.unittestEnabled": false,
6 | "python.testing.pytestEnabled": true
7 | }
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # NanoOWL
4 |
5 | NanoOWL is a project that optimizes [OWL-ViT](https://huggingface.co/docs/transformers/model_doc/owlvit) to run 🔥 ***real-time*** 🔥 on [NVIDIA Jetson Orin Platforms](https://store.nvidia.com/en-us/jetson/store) with [NVIDIA TensorRT](https://developer.nvidia.com/tensorrt). NanoOWL also introduces a new "tree detection" pipeline that combines OWL-ViT and CLIP to enable nested detection and classification of anything, at any level, simply by providing text.
6 |
7 |
8 |
9 |
10 | > Interested in detecting object masks as well? Try combining NanoOWL with
11 | > [NanoSAM](https://github.com/NVIDIA-AI-IOT/nanosam) for zero-shot open-vocabulary
12 | > instance segmentation.
13 |
14 |
15 | ## 👍 Usage
16 |
17 | You can use NanoOWL in Python like this
18 |
19 | ```python3
20 | from nanoowl.owl_predictor import OwlPredictor
21 | import PIL.Image
22 | predictor = OwlPredictor(
23 | "google/owlvit-base-patch32",
24 | image_encoder_engine="data/owlvit-base-patch32-image-encoder.engine"
25 | )
26 |
27 | image = PIL.Image.open("assets/owl_glove_small.jpg")
28 |
29 | output = predictor.predict(image=image, text=["an owl", "a glove"], threshold=0.1)
30 |
31 | print(output)
32 | ```
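
If you want to visualize the detections, you can reuse the drawing helper that the example scripts use. A minimal sketch that continues the snippet above (the output path is arbitrary):

```python3
from nanoowl.owl_drawing import draw_owl_output

# output contains labels, scores and boxes for each detection
image = draw_owl_output(image, output, text=["an owl", "a glove"], draw_text=True)
image.save("owl_predict_out.jpg")  # any output path works here
```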
33 |
34 | Or better yet, to use OWL-ViT in conjunction with CLIP to detect and classify anything,
35 | at any level, check out the tree predictor example below!
36 |
37 | > See [Setup](#setup) for instructions on how to build the image encoder engine.
38 |
39 |
40 | ## ⏱️ Performance
41 |
42 | NanoOWL runs real-time on Jetson Orin Nano.
43 |
44 |
45 | | Model † | Image Size | Patch Size | ⏱️ Jetson Orin Nano (FPS) | ⏱️ Jetson AGX Orin (FPS) | 🎯 Accuracy (mAP) |
46 | |---|---|---|---|---|---|
47 | | OWL-ViT (ViT-B/32) | 768 | 32 | TBD | 95 | 28 |
48 | | OWL-ViT (ViT-B/16) | 768 | 16 | TBD | 25 | 31.7 |
49 |
76 | ## 🛠️ Setup
77 |
78 | 1. Install the dependencies
79 |
80 | 1. Install PyTorch
81 |
82 | 2. Install [torch2trt](https://github.com/NVIDIA-AI-IOT/torch2trt)
83 | 3. Install NVIDIA TensorRT
84 | 4. Install the Transformers library
85 |
86 | ```bash
87 | python3 -m pip install transformers
88 | ```
89 | 5. (optional) Install NanoSAM (for the instance segmentation example)
90 |
91 | 2. Install the NanoOWL package.
92 |
93 | ```bash
94 | git clone https://github.com/NVIDIA-AI-IOT/nanoowl
95 | cd nanoowl
96 | python3 setup.py develop --user
97 | ```
98 |
99 | 3. Build the TensorRT engine for the OWL-ViT vision encoder
100 |
101 | ```bash
102 | mkdir -p data
103 | python3 -m nanoowl.build_image_encoder_engine \
104 | data/owl_image_encoder_patch32.engine
105 | ```
106 |
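   This module is a thin wrapper around ``OwlPredictor.build_image_encoder_engine`` (see ``nanoowl/build_image_encoder_engine.py``), so you can also build the engine from Python. A rough equivalent using the wrapper's default arguments:

```python3
from nanoowl.owl_predictor import OwlPredictor

predictor = OwlPredictor(model_name="google/owlvit-base-patch32")
predictor.build_image_encoder_engine(
    "data/owl_image_encoder_patch32.engine",  # output path; the data/ folder must exist
    fp16_mode=True,
    onnx_opset=16
)
```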
107 |
108 | 4. Run an example prediction to ensure everything is working
109 |
110 | ```bash
111 | cd examples
112 | python3 owl_predict.py \
113 | --prompt="[an owl, a glove]" \
114 | --threshold=0.1 \
115 | --image_encoder_engine=../data/owl_image_encoder_patch32.engine
116 | ```
117 |
118 | That's it! If everything is working properly, you should see a visualization saved to ``data/owl_predict_out.jpg``.
119 |
120 |
121 | ## 🤸 Examples
122 |
123 | ### Example 1 - Basic prediction
124 |
125 |
126 |
127 | This example demonstrates how to use the TensorRT-optimized OWL-ViT model to
128 | detect objects by providing text descriptions of the object labels.
129 |
130 | To run the example, first navigate to the examples folder
131 |
132 | ```bash
133 | cd examples
134 | ```
135 |
136 | Then run the example
137 |
138 | ```bash
139 | python3 owl_predict.py \
140 | --prompt="[an owl, a glove]" \
141 | --threshold=0.1 \
142 | --image_encoder_engine=../data/owl_image_encoder_patch32.engine
143 | ```
144 |
145 | By default the output will be saved to ``data/owl_predict_out.jpg``.
146 |
147 | You can also use this example to profile inference. Simply set the flag ``--profile``.
148 |
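Under the hood, ``--profile`` simply times repeated ``predict`` calls after synchronizing the CUDA stream (see ``examples/owl_predict.py``). A condensed sketch, assuming ``predictor``, ``image``, ``text``, and ``text_encodings`` are set up as in the script:

```python3
import time
import torch

torch.cuda.current_stream().synchronize()
t0 = time.perf_counter_ns()
for _ in range(30):  # --num_profiling_runs defaults to 30
    output = predictor.predict(
        image=image, text=text, text_encodings=text_encodings, threshold=0.1
    )
torch.cuda.current_stream().synchronize()
t1 = time.perf_counter_ns()
print(f"PROFILING FPS: {30 / ((t1 - t0) / 1e9)}")
```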
149 | ### Example 2 - Tree prediction
150 |
151 |
152 |
153 | This example demonstrates how to use the tree predictor class to detect and
154 | classify objects at any level.
155 |
156 | To run the example, first navigate to the examples folder
157 |
158 | ```bash
159 | cd examples
160 | ```
161 |
162 | To detect all owls, and then detect all wings and eyes within each detected owl
163 | region of interest, type
164 |
165 | ```bash
166 | python3 tree_predict.py \
167 | --prompt="[an owl [a wing, an eye]]" \
168 | --threshold=0.15 \
169 | --image_encoder_engine=../data/owl_image_encoder_patch32.engine
170 | ```
171 |
172 | By default the output will be saved to ``data/tree_predict_out.jpg``.
173 |
174 | To classify the image as indoors or outdoors, type
175 |
176 | ```bash
177 | python3 tree_predict.py \
178 | --prompt="(indoors, outdoors)" \
179 | --threshold=0.15 \
180 | --image_encoder_engine=../data/owl_image_encoder_patch32.engine
181 | ```
182 |
183 | To classify the image as indoors or outdoors, and if it's outdoors then detect
184 | all owls, type
185 |
186 | ```bash
187 | python3 tree_predict.py \
188 | --prompt="(indoors, outdoors [an owl])" \
189 | --threshold=0.15 \
190 | --image_encoder_engine=../data/owl_image_encoder_patch32.engine
191 | ```
192 |
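The same tree prediction can also be run from Python. The sketch below mirrors ``examples/tree_predict.py``; paths assume you run from the repository root and built the engine during setup:

```python3
import PIL.Image
from nanoowl.owl_predictor import OwlPredictor
from nanoowl.tree_predictor import TreePredictor, Tree
from nanoowl.tree_drawing import draw_tree_output

predictor = TreePredictor(
    owl_predictor=OwlPredictor(
        "google/owlvit-base-patch32",
        image_encoder_engine="data/owl_image_encoder_patch32.engine"
    )
)

image = PIL.Image.open("assets/owl_glove_small.jpg")
tree = Tree.from_prompt("[an owl [a wing, an eye]]")

output = predictor.predict(
    image=image,
    tree=tree,
    clip_text_encodings=predictor.encode_clip_text(tree),
    owl_text_encodings=predictor.encode_owl_text(tree),
    threshold=0.15
)

image = draw_tree_output(image, output, tree=tree, draw_text=True)
image.save("tree_predict_out.jpg")
```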
193 |
194 | ### Example 3 - Tree prediction (Live Camera)
195 |
196 |
197 |
198 | This example demonstrates the tree predictor running on a live camera feed with
199 | live-edited text prompts (a sketch of the demo's per-frame loop follows these steps). To run the example:
200 |
201 | 1. Ensure you have a camera device connected
202 |
203 | 2. Launch the demo
204 | ```bash
205 | cd examples/tree_demo
206 | python3 tree_demo.py ../../data/owl_image_encoder_patch32.engine
207 | ```
208 | 3. Open your browser to ``http://<ip address>:7860``
209 | 4. Type whatever prompt you like to see what works! Here are some examples
210 | - Example: [a face [a nose, an eye, a mouth]]
211 | - Example: [a face (interested, yawning / bored)]
212 | - Example: (indoors, outdoors)
213 |
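Internally, the demo grabs frames with OpenCV, converts them to PIL, runs the tree predictor on each frame, and streams the annotated JPEGs to the browser over a websocket. A simplified per-frame sketch (following ``examples/tree_demo/tree_demo.py``, and assuming ``predictor`` and ``tree`` are set up as in Example 2):

```python3
import cv2
import PIL.Image
from nanoowl.tree_drawing import draw_tree_output

camera = cv2.VideoCapture(0)  # --camera selects the device index

ret, frame = camera.read()
if ret:
    # OpenCV frames are BGR; the predictor expects an RGB PIL image
    image_pil = PIL.Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    detections = predictor.predict(
        image_pil,
        tree=tree,
        clip_text_encodings=predictor.encode_clip_text(tree),
        owl_text_encodings=predictor.encode_owl_text(tree)
    )
    frame = draw_tree_output(frame, detections, tree)

camera.release()
```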
214 |
215 |
216 |
217 | ## 👏 Acknowledgement
218 |
219 | Thanks to the authors of [OWL-ViT](https://huggingface.co/docs/transformers/model_doc/owlvit) for the great open-vocabulary detection work.
220 |
221 |
222 | ## 🔗 See also
223 |
224 | - [NanoSAM](https://github.com/NVIDIA-AI-IOT/nanosam) - A real-time Segment Anything (SAM) model variant for NVIDIA Jetson Orin platforms.
225 | - [Jetson Introduction to Knowledge Distillation Tutorial](https://github.com/NVIDIA-AI-IOT/jetson-intro-to-distillation) - For an introduction to knowledge distillation as a model optimization technique.
226 | - [Jetson Generative AI Playground](https://nvidia-ai-iot.github.io/jetson-generative-ai-playground/) - For instructions and tips for using a variety of LLMs and transformers on Jetson.
227 | - [Jetson Containers](https://github.com/dusty-nv/jetson-containers) - For a variety of easily deployable and modular Jetson Containers.
228 |
--------------------------------------------------------------------------------
/assets/basic_usage_out.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVIDIA-AI-IOT/nanoowl/fb553dee4c58f8c53ec2bf01f355a54648e67dbd/assets/basic_usage_out.jpg
--------------------------------------------------------------------------------
/assets/camera.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVIDIA-AI-IOT/nanoowl/fb553dee4c58f8c53ec2bf01f355a54648e67dbd/assets/camera.jpg
--------------------------------------------------------------------------------
/assets/class.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVIDIA-AI-IOT/nanoowl/fb553dee4c58f8c53ec2bf01f355a54648e67dbd/assets/class.jpg
--------------------------------------------------------------------------------
/assets/frog.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVIDIA-AI-IOT/nanoowl/fb553dee4c58f8c53ec2bf01f355a54648e67dbd/assets/frog.jpg
--------------------------------------------------------------------------------
/assets/jetson_person_2x.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVIDIA-AI-IOT/nanoowl/fb553dee4c58f8c53ec2bf01f355a54648e67dbd/assets/jetson_person_2x.gif
--------------------------------------------------------------------------------
/assets/nanoowl_jetbot.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVIDIA-AI-IOT/nanoowl/fb553dee4c58f8c53ec2bf01f355a54648e67dbd/assets/nanoowl_jetbot.gif
--------------------------------------------------------------------------------
/assets/owl_glove.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVIDIA-AI-IOT/nanoowl/fb553dee4c58f8c53ec2bf01f355a54648e67dbd/assets/owl_glove.jpg
--------------------------------------------------------------------------------
/assets/owl_glove_out.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVIDIA-AI-IOT/nanoowl/fb553dee4c58f8c53ec2bf01f355a54648e67dbd/assets/owl_glove_out.jpg
--------------------------------------------------------------------------------
/assets/owl_glove_small.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVIDIA-AI-IOT/nanoowl/fb553dee4c58f8c53ec2bf01f355a54648e67dbd/assets/owl_glove_small.jpg
--------------------------------------------------------------------------------
/assets/owl_gradio_demo.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVIDIA-AI-IOT/nanoowl/fb553dee4c58f8c53ec2bf01f355a54648e67dbd/assets/owl_gradio_demo.jpg
--------------------------------------------------------------------------------
/assets/owl_predict_out.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVIDIA-AI-IOT/nanoowl/fb553dee4c58f8c53ec2bf01f355a54648e67dbd/assets/owl_predict_out.jpg
--------------------------------------------------------------------------------
/assets/tree_predict_out.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVIDIA-AI-IOT/nanoowl/fb553dee4c58f8c53ec2bf01f355a54648e67dbd/assets/tree_predict_out.jpg
--------------------------------------------------------------------------------
/docker/23-01/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM nvcr.io/nvidia/pytorch:23.01-py3
2 |
3 | # upgrade pillow to fix "UnidentifiedImageError"
4 | RUN pip install pillow --upgrade
5 |
6 | RUN pip install git+https://github.com/NVIDIA-AI-IOT/torch2trt.git
7 | RUN pip install transformers timm accelerate
8 | RUN pip install git+https://github.com/openai/CLIP.git
--------------------------------------------------------------------------------
/docker/23-01/build.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | docker build -t nanoowl:23-01 -f $(pwd)/docker/23-01/Dockerfile $(pwd)/docker/23-01
--------------------------------------------------------------------------------
/docker/23-01/run.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 |
4 | docker run \
5 | -it \
6 | -d \
7 | --rm \
8 | --ipc host \
9 | --gpus all \
10 | --shm-size 14G \
11 | --device /dev/video0:/dev/video0 \
12 | -v /tmp/.X11-unix:/tmp/.X11-unix \
13 | -e DISPLAY=$DISPLAY \
14 | -p 7860:7860 \
15 | -v $(pwd):/nanoowl \
16 | nanoowl:23-01
--------------------------------------------------------------------------------
/examples/owl_predict.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
17 | import argparse
18 | import PIL.Image
19 | import time
20 | import torch
21 | from nanoowl.owl_predictor import (
22 | OwlPredictor
23 | )
24 | from nanoowl.owl_drawing import (
25 | draw_owl_output
26 | )
27 |
28 |
29 | if __name__ == "__main__":
30 |
31 | parser = argparse.ArgumentParser()
32 | parser.add_argument("--image", type=str, default="../assets/owl_glove_small.jpg")
33 | parser.add_argument("--prompt", type=str, default="[an owl, a glove]")
34 | parser.add_argument("--threshold", type=str, default="0.1,0.1")
35 | parser.add_argument("--output", type=str, default="../data/owl_predict_out.jpg")
36 | parser.add_argument("--model", type=str, default="google/owlvit-base-patch32")
37 | parser.add_argument("--image_encoder_engine", type=str, default="../data/owl_image_encoder_patch32.engine")
38 | parser.add_argument("--profile", action="store_true")
39 | parser.add_argument("--num_profiling_runs", type=int, default=30)
40 | args = parser.parse_args()
41 |
42 | prompt = args.prompt.strip("][()")
43 | text = prompt.split(',')
44 | print(text)
45 |
46 | thresholds = args.threshold.strip("][()")
47 | thresholds = thresholds.split(',')
48 | if len(thresholds) == 1:
49 | thresholds = float(thresholds[0])
50 | else:
51 | thresholds = [float(x) for x in thresholds]
52 | print(thresholds)
53 |
54 |
55 | predictor = OwlPredictor(
56 | args.model,
57 | image_encoder_engine=args.image_encoder_engine
58 | )
59 |
60 | image = PIL.Image.open(args.image)
61 |
62 | text_encodings = predictor.encode_text(text)
63 |
64 | output = predictor.predict(
65 | image=image,
66 | text=text,
67 | text_encodings=text_encodings,
68 | threshold=thresholds,
69 | pad_square=False
70 | )
71 |
72 | if args.profile:
73 | torch.cuda.current_stream().synchronize()
74 | t0 = time.perf_counter_ns()
75 | for i in range(args.num_profiling_runs):
76 | output = predictor.predict(
77 | image=image,
78 | text=text,
79 | text_encodings=text_encodings,
80 | threshold=thresholds,
81 | pad_square=False
82 | )
83 | torch.cuda.current_stream().synchronize()
84 | t1 = time.perf_counter_ns()
85 | dt = (t1 - t0) / 1e9
86 | print(f"PROFILING FPS: {args.num_profiling_runs/dt}")
87 |
88 | image = draw_owl_output(image, output, text=text, draw_text=True)
89 |
90 | image.save(args.output)
--------------------------------------------------------------------------------
/examples/tree_demo/index.html:
--------------------------------------------------------------------------------
1 |
17 |
18 |
19 |
20 |
21 |
54 |
55 |
103 |
104 |
105 |
106 |
NanoOWL
107 |
108 |
109 |
110 |
111 |
112 |
--------------------------------------------------------------------------------
/examples/tree_demo/tree_demo.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
17 | import asyncio
18 | import argparse
19 | from aiohttp import web, WSCloseCode
20 | import logging
21 | import weakref
22 | import cv2
23 | import time
24 | import PIL.Image
25 | import matplotlib.pyplot as plt
26 | from typing import List
27 | from nanoowl.tree import Tree
28 | from nanoowl.tree_predictor import (
29 | TreePredictor
30 | )
31 | from nanoowl.tree_drawing import draw_tree_output
32 | from nanoowl.owl_predictor import OwlPredictor
33 |
34 |
35 | if __name__ == "__main__":
36 | parser = argparse.ArgumentParser()
37 | parser.add_argument("image_encode_engine", type=str)
38 | parser.add_argument("--image_quality", type=int, default=50)
39 | parser.add_argument("--port", type=int, default=7860)
40 | parser.add_argument("--host", type=str, default="0.0.0.0")
41 | parser.add_argument("--camera", type=int, default=0)
42 | parser.add_argument("--resolution", type=str, default="640x480", help="Camera resolution as WIDTHxHEIGHT")
43 | args = parser.parse_args()
44 | width, height = map(int, args.resolution.split("x"))
45 |
46 | CAMERA_DEVICE = args.camera
47 | IMAGE_QUALITY = args.image_quality
48 |
49 | predictor = TreePredictor(
50 | owl_predictor=OwlPredictor(
51 | image_encoder_engine=args.image_encode_engine
52 | )
53 | )
54 |
55 | prompt_data = None
56 |
57 | def get_colors(count: int):
58 | cmap = plt.cm.get_cmap("rainbow", count)
59 | colors = []
60 | for i in range(count):
61 | color = cmap(i)
62 | color = [int(255 * value) for value in color]
63 | colors.append(tuple(color))
64 | return colors
65 |
66 |
67 | def cv2_to_pil(image):
68 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
69 | return PIL.Image.fromarray(image)
70 |
71 |
72 | async def handle_index_get(request: web.Request):
73 | logging.info("handle_index_get")
74 | return web.FileResponse("./index.html")
75 |
76 |
77 | async def websocket_handler(request):
78 |
79 | global prompt_data
80 |
81 | ws = web.WebSocketResponse()
82 |
83 | await ws.prepare(request)
84 |
85 | logging.info("Websocket connected.")
86 |
87 | request.app['websockets'].add(ws)
88 |
89 | try:
90 | async for msg in ws:
91 | logging.info(f"Received message from websocket.")
92 | if "prompt" in msg.data:
93 | header, prompt = msg.data.split(":")
94 | logging.info("Received prompt: " + prompt)
95 | try:
96 | tree = Tree.from_prompt(prompt)
97 | clip_encodings = predictor.encode_clip_text(tree)
98 | owl_encodings = predictor.encode_owl_text(tree)
99 | prompt_data = {
100 | "tree": tree,
101 | "clip_encodings": clip_encodings,
102 | "owl_encodings": owl_encodings
103 | }
104 | logging.info("Set prompt: " + prompt)
105 | except Exception as e:
106 | print(e)
107 | finally:
108 | request.app['websockets'].discard(ws)
109 |
110 | return ws
111 |
112 |
113 | async def on_shutdown(app: web.Application):
114 | for ws in set(app['websockets']):
115 | await ws.close(code=WSCloseCode.GOING_AWAY,
116 | message='Server shutdown')
117 |
118 |
119 | async def detection_loop(app: web.Application):
120 |
121 | loop = asyncio.get_running_loop()
122 |
123 | logging.info("Opening camera.")
124 |
125 | camera = cv2.VideoCapture(CAMERA_DEVICE)
126 | camera.set(cv2.CAP_PROP_FRAME_WIDTH, width)
127 | camera.set(cv2.CAP_PROP_FRAME_HEIGHT, height)
128 |
129 | logging.info("Loading predictor.")
130 |
131 | def _read_and_encode_image():
132 |
133 | re, image = camera.read()
134 |
135 | if not re:
136 | return re, None
137 |
138 | image_pil = cv2_to_pil(image)
139 |
140 | if prompt_data is not None:
141 | prompt_data_local = prompt_data
142 | t0 = time.perf_counter_ns()
143 | detections = predictor.predict(
144 | image_pil,
145 | tree=prompt_data_local['tree'],
146 | clip_text_encodings=prompt_data_local['clip_encodings'],
147 | owl_text_encodings=prompt_data_local['owl_encodings']
148 | )
149 | t1 = time.perf_counter_ns()
150 | dt = (t1 - t0) / 1e9
151 | tree = prompt_data_local['tree']
152 | image = draw_tree_output(image, detections, prompt_data_local['tree'])
153 |
154 | image_jpeg = bytes(
155 | cv2.imencode(".jpg", image, [cv2.IMWRITE_JPEG_QUALITY, IMAGE_QUALITY])[1]
156 | )
157 |
158 | return re, image_jpeg
159 |
160 | while True:
161 |
162 | re, image = await loop.run_in_executor(None, _read_and_encode_image)
163 |
164 | if not re:
165 | break
166 |
167 | for ws in app["websockets"]:
168 | await ws.send_bytes(image)
169 |
170 | camera.release()
171 |
172 |
173 | async def run_detection_loop(app):
174 | try:
175 | task = asyncio.create_task(detection_loop(app))
176 | yield
177 | task.cancel()
178 | except asyncio.CancelledError:
179 | pass
180 | finally:
181 | await task
182 |
183 |
184 | logging.basicConfig(level=logging.INFO)
185 | app = web.Application()
186 | app['websockets'] = weakref.WeakSet()
187 | app.router.add_get("/", handle_index_get)
188 | app.router.add_route("GET", "/ws", websocket_handler)
189 | app.on_shutdown.append(on_shutdown)
190 | app.cleanup_ctx.append(run_detection_loop)
191 | web.run_app(app, host=args.host, port=args.port)
192 |
--------------------------------------------------------------------------------
/examples/tree_predict.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
17 | import argparse
18 | import PIL.Image
19 | from nanoowl.owl_predictor import OwlPredictor
20 | from nanoowl.tree_predictor import (
21 | TreePredictor, Tree
22 | )
23 | from nanoowl.tree_drawing import (
24 | draw_tree_output
25 | )
26 |
27 |
28 | if __name__ == "__main__":
29 |
30 | parser = argparse.ArgumentParser()
31 | parser.add_argument("--image", type=str, default="../assets/owl_glove_small.jpg")
32 | parser.add_argument("--prompt", type=str, default="")
33 | parser.add_argument("--threshold", type=float, default=0.1)
34 | parser.add_argument("--output", type=str, default="../data/tree_predict_out.jpg")
35 | parser.add_argument("--model", type=str, default="google/owlvit-base-patch32")
36 | parser.add_argument("--image_encoder_engine", type=str, default="../data/owl_image_encoder_patch32.engine")
37 | args = parser.parse_args()
38 |
39 | predictor = TreePredictor(
40 | owl_predictor=OwlPredictor(
41 | args.model,
42 | image_encoder_engine=args.image_encoder_engine
43 | )
44 | )
45 |
46 | image = PIL.Image.open(args.image)
47 | tree = Tree.from_prompt(args.prompt)
48 | clip_text_encodings = predictor.encode_clip_text(tree)
49 | owl_text_encodings = predictor.encode_owl_text(tree)
50 |
51 | output = predictor.predict(
52 | image=image,
53 | tree=tree,
54 | clip_text_encodings=clip_text_encodings,
55 | owl_text_encodings=owl_text_encodings,
56 | threshold=args.threshold
57 | )
58 |
59 | image = draw_tree_output(image, output, tree=tree, draw_text=True)
60 |
61 | image.save(args.output)
--------------------------------------------------------------------------------
/nanoowl/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
--------------------------------------------------------------------------------
/nanoowl/build_image_encoder_engine.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
17 | import argparse
18 | from .owl_predictor import OwlPredictor
19 |
20 |
21 | if __name__ == "__main__":
22 |
23 | parser = argparse.ArgumentParser()
24 | parser.add_argument("output_path", type=str)
25 | parser.add_argument("--model_name", type=str, default="google/owlvit-base-patch32")
26 | parser.add_argument("--fp16_mode", type=bool, default=True)
27 | parser.add_argument("--onnx_opset", type=int, default=16)
28 | args = parser.parse_args()
29 |
30 | predictor = OwlPredictor(
31 | model_name=args.model_name
32 | )
33 |
34 | predictor.build_image_encoder_engine(
35 | args.output_path,
36 | fp16_mode=args.fp16_mode,
37 | onnx_opset=args.onnx_opset
38 | )
--------------------------------------------------------------------------------
/nanoowl/clip_predictor.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
17 | import torch
18 | import clip
19 | import PIL.Image
20 | from torchvision.ops import roi_align
21 | from typing import List, Tuple, Optional
22 | from dataclasses import dataclass
23 | from .image_preprocessor import ImagePreprocessor
24 |
25 |
26 | __all__ = [
27 | "ClipPredictor",
28 | "ClipEncodeTextOutput",
29 | "ClipEncodeImageOutput",
30 | "ClipDecodeOutput"
31 | ]
32 |
33 |
34 | @dataclass
35 | class ClipEncodeTextOutput:
36 | text_embeds: torch.Tensor
37 |
38 | def slice(self, start_index, end_index):
39 | return ClipEncodeTextOutput(
40 | text_embeds=self.text_embeds[start_index:end_index]
41 | )
42 |
43 |
44 | @dataclass
45 | class ClipEncodeImageOutput:
46 | image_embeds: torch.Tensor
47 |
48 |
49 | @dataclass
50 | class ClipDecodeOutput:
51 | labels: torch.Tensor
52 | scores: torch.Tensor
53 |
54 |
55 | class ClipPredictor(torch.nn.Module):
56 |
57 | def __init__(self,
58 | model_name: str = "ViT-B/32",
59 | image_size: Tuple[int, int] = (224, 224),
60 | device: str = "cuda",
61 | image_preprocessor: Optional[ImagePreprocessor] = None
62 | ):
63 | super().__init__()
64 | self.device = device
65 | self.clip_model, _ = clip.load(model_name, device)
66 | self.image_size = image_size
67 | self.mesh_grid = torch.stack(
68 | torch.meshgrid(
69 | torch.linspace(0., 1., self.image_size[1]),
70 | torch.linspace(0., 1., self.image_size[0])
71 | )
72 | ).to(self.device).float()
73 | self.image_preprocessor = image_preprocessor.to(self.device).eval() if image_preprocessor else ImagePreprocessor().to(self.device).eval()
74 |
75 | def get_device(self):
76 | return self.device
77 |
78 | def get_image_size(self):
79 | return self.image_size
80 |
81 | def encode_text(self, text: List[str]) -> ClipEncodeTextOutput:
82 | text_tokens = clip.tokenize(text).to(self.device)
83 | text_embeds = self.clip_model.encode_text(text_tokens)
84 | return ClipEncodeTextOutput(text_embeds=text_embeds)
85 |
86 | def encode_image(self, image: torch.Tensor) -> ClipEncodeImageOutput:
87 | image_embeds = self.clip_model.encode_image(image)
88 | return ClipEncodeImageOutput(image_embeds=image_embeds)
89 |
90 | def extract_rois(self, image: torch.Tensor, rois: torch.Tensor, pad_square: bool = True, padding_scale: float=1.0):
91 | if len(rois) == 0:
92 | return torch.empty(
93 | (0, image.shape[1], self.image_size[0], self.image_size[1]),
94 | dtype=image.dtype,
95 | device=image.device
96 | ), rois  # return a tuple so callers can unpack (roi_images, rois) even when empty
97 |
98 | if pad_square:
99 | # pad square
100 | w = padding_scale * (rois[..., 2] - rois[..., 0]) / 2
101 | h = padding_scale * (rois[..., 3] - rois[..., 1]) / 2
102 | cx = (rois[..., 0] + rois[..., 2]) / 2
103 | cy = (rois[..., 1] + rois[..., 3]) / 2
104 | s = torch.max(w, h)
105 | rois = torch.stack([cx-s, cy-s, cx+s, cy+s], dim=-1)
106 |
107 | # compute mask
108 | pad_x = (s - w) / (2 * s)
109 | pad_y = (s - h) / (2 * s)
110 | mask_x = (self.mesh_grid[1][None, ...] > pad_x[..., None, None]) & (self.mesh_grid[1][None, ...] < (1. - pad_x[..., None, None]))
111 | mask_y = (self.mesh_grid[0][None, ...] > pad_y[..., None, None]) & (self.mesh_grid[0][None, ...] < (1. - pad_y[..., None, None]))
112 | mask = (mask_x & mask_y)
113 |
114 | roi_images = roi_align(image, [rois], output_size=self.get_image_size())
115 |
116 | if pad_square:
117 | roi_images = roi_images * mask[:, None, :, :]
118 |
119 | return roi_images, rois
120 |
121 | def encode_rois(self, image: torch.Tensor, rois: torch.Tensor, pad_square: bool = True, padding_scale: float = 1.0):
122 | roi_images, rois = self.extract_rois(image, rois, pad_square, padding_scale)
123 | return self.encode_image(roi_images)
124 |
125 | def decode(self,
126 | image_output: ClipEncodeImageOutput,
127 | text_output: ClipEncodeTextOutput
128 | ) -> ClipDecodeOutput:
129 |
130 | image_embeds = image_output.image_embeds
131 | text_embeds = text_output.text_embeds
132 |
133 | image_embeds = image_embeds / image_embeds.norm(dim=1, keepdim=True)
134 | text_embeds = text_embeds / text_embeds.norm(dim=1, keepdim=True)
135 | logit_scale = self.clip_model.logit_scale.exp()
136 | logits_per_image = logit_scale * image_embeds @ text_embeds.t()
137 | probs = torch.softmax(logits_per_image, dim=-1)
138 | prob_max = probs.max(dim=-1)
139 |
140 | return ClipDecodeOutput(
141 | labels=prob_max.indices,
142 | scores=prob_max.values
143 | )
144 |
145 | def predict(self,
146 | image: PIL.Image,
147 | text: List[str],
148 | text_encodings: Optional[ClipEncodeTextOutput],
149 | pad_square: bool = True,
150 | threshold: float = 0.1
151 | ) -> ClipDecodeOutput:
152 |
153 | image_tensor = self.image_preprocessor.preprocess_pil_image(image)
154 |
155 | if text_encodings is None:
156 | text_encodings = self.encode_text(text)
157 |
158 | rois = torch.tensor([[0, 0, image.width, image.height]], dtype=image_tensor.dtype, device=image_tensor.device)  # full-image ROI in (x0, y0, x1, y1) order
159 |
160 | image_encodings = self.encode_rois(image_tensor, rois, pad_square=pad_square)
161 |
162 | return self.decode(image_encodings, text_encodings)
--------------------------------------------------------------------------------
/nanoowl/image_preprocessor.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
17 | import torch
18 | import PIL.Image
19 | import numpy as np
20 | from typing import Tuple
21 |
22 |
23 | __all__ = [
24 | "ImagePreprocessor",
25 | "DEFAULT_IMAGE_PREPROCESSOR_MEAN",
26 | "DEFAULT_IMAGE_PREPROCESSOR_STD"
27 | ]
28 |
29 |
30 | DEFAULT_IMAGE_PREPROCESSOR_MEAN = [
31 | 0.48145466 * 255.,
32 | 0.4578275 * 255.,
33 | 0.40821073 * 255.
34 | ]
35 |
36 |
37 | DEFAULT_IMAGE_PREPROCESSOR_STD = [
38 | 0.26862954 * 255.,
39 | 0.26130258 * 255.,
40 | 0.27577711 * 255.
41 | ]
42 |
43 |
44 | class ImagePreprocessor(torch.nn.Module):
45 | def __init__(self,
46 | mean: Tuple[float, float, float] = DEFAULT_IMAGE_PREPROCESSOR_MEAN,
47 | std: Tuple[float, float, float] = DEFAULT_IMAGE_PREPROCESSOR_STD
48 | ):
49 | super().__init__()
50 |
51 | self.register_buffer(
52 | "mean",
53 | torch.tensor(mean)[None, :, None, None]
54 | )
55 | self.register_buffer(
56 | "std",
57 | torch.tensor(std)[None, :, None, None]
58 | )
59 |
60 | def forward(self, image: torch.Tensor, inplace: bool = False):
61 |
62 | if inplace:
63 | image = image.sub_(self.mean).div_(self.std)
64 | else:
65 | image = (image - self.mean) / self.std
66 |
67 | return image
68 |
69 | @torch.no_grad()
70 | def preprocess_pil_image(self, image: PIL.Image.Image):
71 | image = torch.from_numpy(np.asarray(image))
72 | image = image.permute(2, 0, 1)[None, ...]
73 | image = image.to(self.mean.device)
74 | image = image.type(self.mean.dtype)
75 | return self.forward(image, inplace=True)
--------------------------------------------------------------------------------
/nanoowl/owl_drawing.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
17 | import PIL.Image
18 | import PIL.ImageDraw
19 | import cv2
20 | from .owl_predictor import OwlDecodeOutput
21 | import matplotlib.pyplot as plt
22 | import numpy as np
23 | from typing import List
24 |
25 |
26 | def get_colors(count: int):
27 | cmap = plt.cm.get_cmap("rainbow", count)
28 | colors = []
29 | for i in range(count):
30 | color = cmap(i)
31 | color = [int(255 * value) for value in color]
32 | colors.append(tuple(color))
33 | return colors
34 |
35 |
36 | def draw_owl_output(image, output: OwlDecodeOutput, text: List[str], draw_text=True):
37 | is_pil = not isinstance(image, np.ndarray)
38 | if is_pil:
39 | image = np.array(image)  # copy so the array is writeable for cv2 drawing
40 | font = cv2.FONT_HERSHEY_SIMPLEX
41 | font_scale = 0.75
42 | colors = get_colors(len(text))
43 | num_detections = len(output.labels)
44 |
45 | for i in range(num_detections):
46 | box = output.boxes[i]
47 | label_index = int(output.labels[i])
48 | box = [int(x) for x in box]
49 | pt0 = (box[0], box[1])
50 | pt1 = (box[2], box[3])
51 | cv2.rectangle(
52 | image,
53 | pt0,
54 | pt1,
55 | colors[label_index],
56 | 4
57 | )
58 | if draw_text:
59 | offset_y = 12
60 | offset_x = 0
61 | label_text = text[label_index]
62 | cv2.putText(
63 | image,
64 | label_text,
65 | (box[0] + offset_x, box[1] + offset_y),
66 | font,
67 | font_scale,
68 | colors[label_index],
69 | 2,# thickness
70 | cv2.LINE_AA
71 | )
72 | if is_pil:
73 | image = PIL.Image.fromarray(image)
74 | return image
--------------------------------------------------------------------------------
/nanoowl/owl_predictor.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
17 | import torch
18 | import numpy as np
19 | import PIL.Image
20 | import subprocess
21 | import tempfile
22 | import os
23 | from torchvision.ops import roi_align
24 | from transformers.models.owlvit.modeling_owlvit import OwlViTForObjectDetection
25 | from transformers.models.owlvit.processing_owlvit import OwlViTProcessor
26 | from dataclasses import dataclass
27 | from typing import List, Optional, Union, Tuple
28 | from .image_preprocessor import ImagePreprocessor
29 |
30 | __all__ = [
31 | "OwlPredictor",
32 | "OwlEncodeTextOutput",
33 | "OwlEncodeImageOutput",
34 | "OwlDecodeOutput"
35 | ]
36 |
37 |
38 | def _owl_center_to_corners_format_torch(bboxes_center):
39 | center_x, center_y, width, height = bboxes_center.unbind(-1)
40 | bbox_corners = torch.stack(
41 | [
42 | (center_x - 0.5 * width),
43 | (center_y - 0.5 * height),
44 | (center_x + 0.5 * width),
45 | (center_y + 0.5 * height)
46 | ],
47 | dim=-1,
48 | )
49 | return bbox_corners
50 |
51 |
52 | def _owl_get_image_size(hf_name: str):
53 |
54 | image_sizes = {
55 | "google/owlvit-base-patch32": 768,
56 | "google/owlvit-base-patch16": 768,
57 | "google/owlvit-large-patch14": 840,
58 | }
59 |
60 | return image_sizes[hf_name]
61 |
62 |
63 | def _owl_get_patch_size(hf_name: str):
64 |
65 | patch_sizes = {
66 | "google/owlvit-base-patch32": 32,
67 | "google/owlvit-base-patch16": 16,
68 | "google/owlvit-large-patch14": 14,
69 | }
70 |
71 | return patch_sizes[hf_name]
72 |
73 |
74 | # This function is modified from https://github.com/huggingface/transformers/blob/e8fdd7875def7be59e2c9b823705fbf003163ea0/src/transformers/models/owlvit/modeling_owlvit.py#L1333
75 | # Copyright 2022 Google AI and The HuggingFace Team. All rights reserved.
76 | # SPDX-License-Identifier: Apache-2.0
77 | def _owl_normalize_grid_corner_coordinates(num_patches_per_side):
78 | box_coordinates = np.stack(
79 | np.meshgrid(np.arange(1, num_patches_per_side + 1), np.arange(1, num_patches_per_side + 1)), axis=-1
80 | ).astype(np.float32)
81 | box_coordinates /= np.array([num_patches_per_side, num_patches_per_side], np.float32)
82 |
83 | box_coordinates = box_coordinates.reshape(
84 | box_coordinates.shape[0] * box_coordinates.shape[1], box_coordinates.shape[2]
85 | )
86 | box_coordinates = torch.from_numpy(box_coordinates)
87 |
88 | return box_coordinates
89 |
90 |
91 | # This function is modified from https://github.com/huggingface/transformers/blob/e8fdd7875def7be59e2c9b823705fbf003163ea0/src/transformers/models/owlvit/modeling_owlvit.py#L1354
92 | # Copyright 2022 Google AI and The HuggingFace Team. All rights reserved.
93 | # SPDX-License-Identifier: Apache-2.0
94 | def _owl_compute_box_bias(num_patches_per_side):
95 | box_coordinates = _owl_normalize_grid_corner_coordinates(num_patches_per_side)
96 | box_coordinates = torch.clip(box_coordinates, 0.0, 1.0)
97 |
98 | box_coord_bias = torch.log(box_coordinates + 1e-4) - torch.log1p(-box_coordinates + 1e-4)
99 |
100 | box_size = torch.full_like(box_coord_bias, 1.0 / num_patches_per_side)
101 | box_size_bias = torch.log(box_size + 1e-4) - torch.log1p(-box_size + 1e-4)
102 |
103 | box_bias = torch.cat([box_coord_bias, box_size_bias], dim=-1)
104 |
105 | return box_bias
106 |
107 |
108 | def _owl_box_roi_to_box_global(boxes, rois):
109 | x0y0 = rois[..., :2]
110 | x1y1 = rois[..., 2:]
111 | wh = (x1y1 - x0y0).repeat(1, 1, 2)
112 | x0y0 = x0y0.repeat(1, 1, 2)
113 | return (boxes * wh) + x0y0
114 |
115 |
116 | @dataclass
117 | class OwlEncodeTextOutput:
118 | text_embeds: torch.Tensor
119 |
120 | def slice(self, start_index, end_index):
121 | return OwlEncodeTextOutput(
122 | text_embeds=self.text_embeds[start_index:end_index]
123 | )
124 |
125 |
126 | @dataclass
127 | class OwlEncodeImageOutput:
128 | image_embeds: torch.Tensor
129 | image_class_embeds: torch.Tensor
130 | logit_shift: torch.Tensor
131 | logit_scale: torch.Tensor
132 | pred_boxes: torch.Tensor
133 |
134 |
135 | @dataclass
136 | class OwlDecodeOutput:
137 | labels: torch.Tensor
138 | scores: torch.Tensor
139 | boxes: torch.Tensor
140 | input_indices: torch.Tensor
141 |
142 |
143 | class OwlPredictor(torch.nn.Module):
144 |
145 | def __init__(self,
146 | model_name: str = "google/owlvit-base-patch32",
147 | device: str = "cuda",
148 | image_encoder_engine: Optional[str] = None,
149 | image_encoder_engine_max_batch_size: int = 1,
150 | image_preprocessor: Optional[ImagePreprocessor] = None
151 | ):
152 |
153 | super().__init__()
154 |
155 | self.image_size = _owl_get_image_size(model_name)
156 | self.device = device
157 | self.model = OwlViTForObjectDetection.from_pretrained(model_name).to(self.device).eval()
158 | self.processor = OwlViTProcessor.from_pretrained(model_name)
159 | self.patch_size = _owl_get_patch_size(model_name)
160 | self.num_patches_per_side = self.image_size // self.patch_size
161 | self.box_bias = _owl_compute_box_bias(self.num_patches_per_side).to(self.device)
162 | self.num_patches = (self.num_patches_per_side)**2
163 | self.mesh_grid = torch.stack(
164 | torch.meshgrid(
165 | torch.linspace(0., 1., self.image_size),
166 | torch.linspace(0., 1., self.image_size)
167 | )
168 | ).to(self.device).float()
169 | self.image_encoder_engine = None
170 | if image_encoder_engine is not None:
171 | image_encoder_engine = OwlPredictor.load_image_encoder_engine(image_encoder_engine, image_encoder_engine_max_batch_size)
172 | self.image_encoder_engine = image_encoder_engine
173 | self.image_preprocessor = image_preprocessor.to(self.device).eval() if image_preprocessor else ImagePreprocessor().to(self.device).eval()
174 |
175 | def get_num_patches(self):
176 | return self.num_patches
177 |
178 | def get_device(self):
179 | return self.device
180 |
181 | def get_image_size(self):
182 | return (self.image_size, self.image_size)
183 |
184 | def encode_text(self, text: List[str]) -> OwlEncodeTextOutput:
185 | text_input = self.processor(text=text, return_tensors="pt")
186 | input_ids = text_input['input_ids'].to(self.device)
187 | attention_mask = text_input['attention_mask'].to(self.device)
188 | text_outputs = self.model.owlvit.text_model(input_ids, attention_mask)
189 | text_embeds = text_outputs[1]
190 | text_embeds = self.model.owlvit.text_projection(text_embeds)
191 | return OwlEncodeTextOutput(text_embeds=text_embeds)
192 |
193 | def encode_image_torch(self, image: torch.Tensor) -> OwlEncodeImageOutput:
194 |
195 | vision_outputs = self.model.owlvit.vision_model(image)
196 | last_hidden_state = vision_outputs[0]
197 | image_embeds = self.model.owlvit.vision_model.post_layernorm(last_hidden_state)
198 | class_token_out = image_embeds[:, :1, :]
199 | image_embeds = image_embeds[:, 1:, :] * class_token_out
200 | image_embeds = self.model.layer_norm(image_embeds) # 768 dim
201 |
202 | # Box Head
203 | pred_boxes = self.model.box_head(image_embeds)
204 | pred_boxes += self.box_bias
205 | pred_boxes = torch.sigmoid(pred_boxes)
206 | pred_boxes = _owl_center_to_corners_format_torch(pred_boxes)
207 |
208 | # Class Head
209 | image_class_embeds = self.model.class_head.dense0(image_embeds)
210 | logit_shift = self.model.class_head.logit_shift(image_embeds)
211 | logit_scale = self.model.class_head.logit_scale(image_embeds)
212 | logit_scale = self.model.class_head.elu(logit_scale) + 1
213 |
214 | output = OwlEncodeImageOutput(
215 | image_embeds=image_embeds,
216 | image_class_embeds=image_class_embeds,
217 | logit_shift=logit_shift,
218 | logit_scale=logit_scale,
219 | pred_boxes=pred_boxes
220 | )
221 |
222 | return output
223 |
224 | def encode_image_trt(self, image: torch.Tensor) -> OwlEncodeImageOutput:
225 | return self.image_encoder_engine(image)
226 |
227 | def encode_image(self, image: torch.Tensor) -> OwlEncodeImageOutput:
228 | if self.image_encoder_engine is not None:
229 | return self.encode_image_trt(image)
230 | else:
231 | return self.encode_image_torch(image)
232 |
233 | def extract_rois(self, image: torch.Tensor, rois: torch.Tensor, pad_square: bool = True, padding_scale: float = 1.0):
234 | if len(rois) == 0:
235 | return torch.empty(
236 | (0, image.shape[1], self.image_size, self.image_size),
237 | dtype=image.dtype,
238 | device=image.device
239 | )
240 | if pad_square:
241 | # pad square
242 | w = padding_scale * (rois[..., 2] - rois[..., 0]) / 2
243 | h = padding_scale * (rois[..., 3] - rois[..., 1]) / 2
244 | cx = (rois[..., 0] + rois[..., 2]) / 2
245 | cy = (rois[..., 1] + rois[..., 3]) / 2
246 | s = torch.max(w, h)
247 | rois = torch.stack([cx-s, cy-s, cx+s, cy+s], dim=-1)
248 |
249 | # compute mask
250 | pad_x = (s - w) / (2 * s)
251 | pad_y = (s - h) / (2 * s)
252 | mask_x = (self.mesh_grid[1][None, ...] > pad_x[..., None, None]) & (self.mesh_grid[1][None, ...] < (1. - pad_x[..., None, None]))
253 | mask_y = (self.mesh_grid[0][None, ...] > pad_y[..., None, None]) & (self.mesh_grid[0][None, ...] < (1. - pad_y[..., None, None]))
254 | mask = (mask_x & mask_y)
255 |
256 | # extract rois
257 | roi_images = roi_align(image, [rois], output_size=self.get_image_size())
258 |
259 | # mask rois
260 | if pad_square:
261 | roi_images = (roi_images * mask[:, None, :, :])
262 |
263 | return roi_images, rois
264 |
265 | def encode_rois(self, image: torch.Tensor, rois: torch.Tensor, pad_square: bool = True, padding_scale: float=1.0):
266 | # with torch_timeit_sync("extract rois"):
267 | roi_images, rois = self.extract_rois(image, rois, pad_square, padding_scale)
268 | # with torch_timeit_sync("encode images"):
269 | output = self.encode_image(roi_images)
270 | pred_boxes = _owl_box_roi_to_box_global(output.pred_boxes, rois[:, None, :])
271 | output.pred_boxes = pred_boxes
272 | return output
273 |
274 | def decode(self,
275 | image_output: OwlEncodeImageOutput,
276 | text_output: OwlEncodeTextOutput,
277 | threshold: Union[int, float, List[Union[int, float]]] = 0.1,
278 | ) -> OwlDecodeOutput:
279 |
280 | if isinstance(threshold, (int, float)):
281 |             threshold = [threshold] * len(text_output.text_embeds)  # apply a single threshold to all labels
282 |
283 | num_input_images = image_output.image_class_embeds.shape[0]
284 |
285 | image_class_embeds = image_output.image_class_embeds
286 | image_class_embeds = image_class_embeds / (torch.linalg.norm(image_class_embeds, dim=-1, keepdim=True) + 1e-6)
287 | query_embeds = text_output.text_embeds
288 | query_embeds = query_embeds / (torch.linalg.norm(query_embeds, dim=-1, keepdim=True) + 1e-6)
289 | logits = torch.einsum("...pd,...qd->...pq", image_class_embeds, query_embeds)
290 | logits = (logits + image_output.logit_shift) * image_output.logit_scale
291 |
292 | scores_sigmoid = torch.sigmoid(logits)
293 | scores_max = scores_sigmoid.max(dim=-1)
294 | labels = scores_max.indices
295 | scores = scores_max.values
296 | masks = []
297 | for i, thresh in enumerate(threshold):
298 | label_mask = labels == i
299 | score_mask = scores > thresh
300 |             obj_mask = torch.logical_and(label_mask, score_mask)
301 | masks.append(obj_mask)
302 |
303 | mask = masks[0]
304 | for mask_t in masks[1:]:
305 | mask = torch.logical_or(mask, mask_t)
306 |
307 | input_indices = torch.arange(0, num_input_images, dtype=labels.dtype, device=labels.device)
308 | input_indices = input_indices[:, None].repeat(1, self.num_patches)
309 |
310 | return OwlDecodeOutput(
311 | labels=labels[mask],
312 | scores=scores[mask],
313 | boxes=image_output.pred_boxes[mask],
314 | input_indices=input_indices[mask]
315 | )
316 |
317 | @staticmethod
318 | def get_image_encoder_input_names():
319 | return ["image"]
320 |
321 | @staticmethod
322 | def get_image_encoder_output_names():
323 | names = [
324 | "image_embeds",
325 | "image_class_embeds",
326 | "logit_shift",
327 | "logit_scale",
328 | "pred_boxes"
329 | ]
330 | return names
331 |
332 |
333 | def export_image_encoder_onnx(self,
334 | output_path: str,
335 | use_dynamic_axes: bool = True,
336 | batch_size: int = 1,
337 | onnx_opset=17
338 | ):
339 |
340 | class TempModule(torch.nn.Module):
341 | def __init__(self, parent):
342 | super().__init__()
343 | self.parent = parent
344 | def forward(self, image):
345 | output = self.parent.encode_image_torch(image)
346 | return (
347 | output.image_embeds,
348 | output.image_class_embeds,
349 | output.logit_shift,
350 | output.logit_scale,
351 | output.pred_boxes
352 | )
353 |
354 | data = torch.randn(batch_size, 3, self.image_size, self.image_size).to(self.device)
355 |
356 | if use_dynamic_axes:
357 | dynamic_axes = {
358 | "image": {0: "batch"},
359 | "image_embeds": {0: "batch"},
360 | "image_class_embeds": {0: "batch"},
361 | "logit_shift": {0: "batch"},
362 | "logit_scale": {0: "batch"},
363 | "pred_boxes": {0: "batch"}
364 | }
365 | else:
366 | dynamic_axes = {}
367 |
368 | model = TempModule(self)
369 |
370 | torch.onnx.export(
371 | model,
372 | data,
373 | output_path,
374 | input_names=self.get_image_encoder_input_names(),
375 | output_names=self.get_image_encoder_output_names(),
376 | dynamic_axes=dynamic_axes,
377 | opset_version=onnx_opset
378 | )
379 |
380 | @staticmethod
381 | def load_image_encoder_engine(engine_path: str, max_batch_size: int = 1):
382 | import tensorrt as trt
383 | from torch2trt import TRTModule
384 |
385 | with trt.Logger() as logger, trt.Runtime(logger) as runtime:
386 | with open(engine_path, 'rb') as f:
387 | engine_bytes = f.read()
388 | engine = runtime.deserialize_cuda_engine(engine_bytes)
389 |
390 | base_module = TRTModule(
391 | engine,
392 | input_names=OwlPredictor.get_image_encoder_input_names(),
393 | output_names=OwlPredictor.get_image_encoder_output_names()
394 | )
395 |
396 | class Wrapper(torch.nn.Module):
397 | def __init__(self, base_module: TRTModule, max_batch_size: int):
398 | super().__init__()
399 | self.base_module = base_module
400 | self.max_batch_size = max_batch_size
401 |
402 | @torch.no_grad()
403 | def forward(self, image):
404 |
405 | b = image.shape[0]
406 |
407 | results = []
408 |
409 | for start_index in range(0, b, self.max_batch_size):
410 | end_index = min(b, start_index + self.max_batch_size)
411 | image_slice = image[start_index:end_index]
412 | # with torch_timeit_sync("run_engine"):
413 | output = self.base_module(image_slice)
414 | results.append(
415 | output
416 | )
417 |
418 | return OwlEncodeImageOutput(
419 | image_embeds=torch.cat([r[0] for r in results], dim=0),
420 | image_class_embeds=torch.cat([r[1] for r in results], dim=0),
421 | logit_shift=torch.cat([r[2] for r in results], dim=0),
422 | logit_scale=torch.cat([r[3] for r in results], dim=0),
423 | pred_boxes=torch.cat([r[4] for r in results], dim=0)
424 | )
425 |
426 | image_encoder = Wrapper(base_module, max_batch_size)
427 |
428 | return image_encoder
429 |
430 | def build_image_encoder_engine(self,
431 | engine_path: str,
432 | max_batch_size: int = 1,
433 | fp16_mode = True,
434 | onnx_path: Optional[str] = None,
435 | onnx_opset: int = 17
436 | ):
437 |
438 | if onnx_path is None:
439 | onnx_dir = tempfile.mkdtemp()
440 | onnx_path = os.path.join(onnx_dir, "image_encoder.onnx")
441 | self.export_image_encoder_onnx(onnx_path, onnx_opset=onnx_opset)
442 |
443 | args = ["/usr/src/tensorrt/bin/trtexec"]
444 |
445 | args.append(f"--onnx={onnx_path}")
446 | args.append(f"--saveEngine={engine_path}")
447 |
448 | if fp16_mode:
449 | args += ["--fp16"]
450 |
451 | args += [f"--shapes=image:1x3x{self.image_size}x{self.image_size}"]
452 |
453 | subprocess.call(args)
454 |
455 | return self.load_image_encoder_engine(engine_path, max_batch_size)
456 |
457 | def predict(self,
458 |         image: PIL.Image.Image,
459 | text: List[str],
460 |         text_encodings: Optional[OwlEncodeTextOutput] = None,
461 | threshold: Union[int, float, List[Union[int, float]]] = 0.1,
462 | pad_square: bool = True,
463 |
464 | ) -> OwlDecodeOutput:
465 |
466 | image_tensor = self.image_preprocessor.preprocess_pil_image(image)
467 |
468 | if text_encodings is None:
469 | text_encodings = self.encode_text(text)
470 |
471 | rois = torch.tensor([[0, 0, image.width, image.height]], dtype=image_tensor.dtype, device=image_tensor.device)
472 |
473 | image_encodings = self.encode_rois(image_tensor, rois, pad_square=pad_square)
474 |
475 | return self.decode(image_encodings, text_encodings, threshold)
476 |
477 |
--------------------------------------------------------------------------------
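
A minimal usage sketch of `OwlPredictor` (illustrative only: the prompt "a glove" and the 0.1 threshold are placeholder choices, the image path points at this repository's `assets` folder, and a CUDA device is assumed):

import PIL.Image
from nanoowl.owl_predictor import OwlPredictor

# Load the default OWL-ViT checkpoint; optionally pass image_encoder_engine=<path to a
# prebuilt TensorRT engine> to accelerate encode_image().
predictor = OwlPredictor("google/owlvit-base-patch32", device="cuda")

image = PIL.Image.open("assets/owl_glove_small.jpg")
text = ["an owl", "a glove"]

# Text embeddings can be computed once and reused across frames.
text_encodings = predictor.encode_text(text)

output = predictor.predict(
    image=image,
    text=text,
    text_encodings=text_encodings,
    threshold=0.1
)
print(output.labels, output.scores, output.boxes)
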
/nanoowl/sync_timer.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
17 | import time
18 | import torch
19 |
20 |
21 | __all__ = [
22 | "SyncTimer"
23 | ]
24 |
25 |
26 | class SyncTimer():
27 |
28 | def __init__(self, name: str):
29 | self.name = name
30 | self.t0 = None
31 |
32 | def __enter__(self, *args, **kwargs):
33 | self.t0 = time.perf_counter_ns()
34 |
35 | def __exit__(self, *args, **kwargs):
36 | torch.cuda.current_stream().synchronize()
37 | t1 = time.perf_counter_ns()
38 | dt = (t1 - self.t0) / 1e9
39 | print(f"{self.name} FPS: {round(1./dt, 3)}")
--------------------------------------------------------------------------------
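
A hypothetical `SyncTimer` usage sketch (the "matmul" label and tensor sizes are arbitrary; a CUDA device is assumed, since `__exit__` synchronizes the current CUDA stream before reading the clock):

import torch
from nanoowl.sync_timer import SyncTimer

x = torch.randn(1024, 1024, device="cuda")

# Prints "matmul FPS: ...", where the elapsed time covers GPU completion of the work.
with SyncTimer("matmul"):
    y = x @ x
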
/nanoowl/tree.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
17 | import json
18 | from enum import Enum
19 | from typing import List, Optional, Mapping
20 | from .clip_predictor import ClipEncodeTextOutput
21 | from .owl_predictor import OwlEncodeTextOutput
22 |
23 |
24 | __all__ = [
25 | "TreeOp",
26 | "TreeNode",
27 | "Tree"
28 | ]
29 |
30 |
31 | class TreeOp(Enum):
32 | DETECT = "detect"
33 | CLASSIFY = "classify"
34 |
35 | def __str__(self) -> str:
36 | return str(self.value)
37 |
38 |
39 | class TreeNode:
40 | op: TreeOp
41 | input: int
42 | outputs: List[int]
43 |
44 | def __init__(self, op: TreeOp, input: int, outputs: Optional[List[int]] = None):
45 | self.op = op
46 | self.input = input
47 | self.outputs = [] if outputs is None else outputs
48 |
49 | def to_dict(self):
50 | return {
51 | "op": str(self.op),
52 | "input": self.input,
53 | "outputs": self.outputs
54 | }
55 |
56 | @staticmethod
57 | def from_dict(node_dict: dict):
58 |
59 | if "op" not in node_dict:
60 | raise RuntimeError("Missing 'op' field.")
61 |
62 | if "input" not in node_dict:
63 | raise RuntimeError("Missing 'input' field.")
64 |
65 | if "outputs" not in node_dict:
66 |             raise RuntimeError("Missing 'outputs' field.")
67 |
68 | return TreeNode(
69 |             op=TreeOp(node_dict["op"]),
70 | input=node_dict["input"],
71 | outputs=node_dict["outputs"]
72 | )
73 |
74 |
75 | class Tree:
76 | nodes: List[TreeNode]
77 | labels: List[str]
78 |
79 | def __init__(self, nodes, labels):
80 | self.nodes = nodes
81 | self.labels = labels
82 | self._label_index_to_node_map = self._build_label_index_to_node_map()
83 |
84 | def _build_label_index_to_node_map(self) -> Mapping[int, "TreeNode"]:
85 | label_to_node_map = {}
86 | for node in self.nodes:
87 | for label_index in node.outputs:
88 | if label_index in label_to_node_map:
89 | raise RuntimeError("Duplicate output label.")
90 | label_to_node_map[label_index] = node
91 | return label_to_node_map
92 |
93 | def to_dict(self):
94 | return {
95 | "nodes": [node.to_dict() for node in self.nodes],
96 | "labels": self.labels
97 | }
98 |
99 | @staticmethod
100 | def from_prompt(prompt: str):
101 |
102 | nodes = []
103 | node_stack = []
104 | label_index_stack = [0]
105 | labels = ["image"]
106 | label_index = 0
107 |
108 | for ch in prompt:
109 |
110 | if ch == "[":
111 | label_index += 1
112 | node = TreeNode(op=TreeOp.DETECT, input=label_index_stack[-1])
113 | node.outputs.append(label_index)
114 | node_stack.append(node)
115 | label_index_stack.append(label_index)
116 | labels.append("")
117 | nodes.append(node)
118 | elif ch == "]":
119 | if len(node_stack) == 0:
120 | raise RuntimeError("Unexpected ']'.")
121 | node = node_stack.pop()
122 | if node.op != TreeOp.DETECT:
123 | raise RuntimeError("Unexpected ']'.")
124 | label_index_stack.pop()
125 | elif ch == "(":
126 | label_index = label_index + 1
127 | node = TreeNode(op=TreeOp.CLASSIFY, input=label_index_stack[-1])
128 | node.outputs.append(label_index)
129 | node_stack.append(node)
130 | label_index_stack.append(label_index)
131 | labels.append("")
132 | nodes.append(node)
133 | elif ch == ")":
134 | if len(node_stack) == 0:
135 | raise RuntimeError("Unexpected ')'.")
136 | node = node_stack.pop()
137 | if node.op != TreeOp.CLASSIFY:
138 | raise RuntimeError("Unexpected ')'.")
139 | label_index_stack.pop()
140 | elif ch == ",":
141 | label_index_stack.pop()
142 | label_index = label_index + 1
143 | label_index_stack.append(label_index)
144 | node_stack[-1].outputs.append(label_index)
145 | labels.append("")
146 | else:
147 | labels[label_index_stack[-1]] += ch
148 |
149 | if len(node_stack) > 0:
150 | if node_stack[-1].op == TreeOp.DETECT:
151 | raise RuntimeError("Missing ']'.")
152 | if node_stack[-1].op == TreeOp.CLASSIFY:
153 | raise RuntimeError("Missing ')'.")
154 |
155 | labels = [label.strip() for label in labels]
156 |
157 | graph = Tree(nodes=nodes, labels=labels)
158 |
159 | return graph
160 |
161 | def to_json(self, indent: Optional[int] = None) -> str:
162 | return json.dumps(self.to_dict(), indent=indent)
163 |
164 | @staticmethod
165 | def from_dict(tree_dict: dict) -> "Tree":
166 |
167 | if "nodes" not in tree_dict:
168 | raise RuntimeError("Missing 'nodes' field.")
169 |
170 | if "labels" not in tree_dict:
171 | raise RuntimeError("Missing 'labels' field.")
172 |
173 | nodes = [TreeNode.from_dict(node_dict) for node_dict in tree_dict["nodes"]]
174 | labels = tree_dict["labels"]
175 |
176 | return Tree(nodes=nodes, labels=labels)
177 |
178 | @staticmethod
179 | def from_json(tree_json: str) -> "Tree":
180 | tree_dict = json.loads(tree_json)
181 | return Tree.from_dict(tree_dict)
182 |
183 | def get_op_for_label_index(self, label_index: int):
184 | if label_index not in self._label_index_to_node_map:
185 | return None
186 | return self._label_index_to_node_map[label_index].op
187 |
188 | def get_label_indices_with_op(self, op: TreeOp):
189 | return [
190 | index for index in range(len(self.labels))
191 | if self.get_op_for_label_index(index) == op
192 | ]
193 |
194 | def get_classify_label_indices(self):
195 | return self.get_label_indices_with_op(TreeOp.CLASSIFY)
196 |
197 | def get_detect_label_indices(self):
198 | return self.get_label_indices_with_op(TreeOp.DETECT)
199 |
200 | def find_nodes_with_input(self, input_index: int):
201 | return [n for n in self.nodes if n.input == input_index]
202 |
203 | def find_detect_nodes_with_input(self, input_index: int):
204 | return [n for n in self.find_nodes_with_input(input_index) if n.op == TreeOp.DETECT]
205 |
206 | def find_classify_nodes_with_input(self, input_index: int):
207 | return [n for n in self.find_nodes_with_input(input_index) if n.op == TreeOp.CLASSIFY]
208 |
209 | def get_label_depth(self, index):
210 | depth = 0
211 | while index in self._label_index_to_node_map:
212 | depth += 1
213 | node = self._label_index_to_node_map[index]
214 | index = node.input
215 | return depth
216 |
217 | def get_label_depth_map(self):
218 | depths = {}
219 | for i in range(len(self.labels)):
220 | depths[i] = self.get_label_depth(i)
221 | return depths
222 |
223 | def get_label_map(self):
224 | label_map = {}
225 | for i in range(len(self.labels)):
226 | label_map[i] = self.labels[i]
227 | return label_map
--------------------------------------------------------------------------------
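
To make the prompt grammar in `Tree.from_prompt` concrete: `[...]` opens a detect node, `(...)` opens a classify node, `,` separates sibling labels, and label index 0 is reserved for the input image. A small sketch using the same prompt string as the tests:

from nanoowl.tree import Tree

tree = Tree.from_prompt("[a face [an eye, a nose](happy, sad)]")

print(tree.labels)
# ['image', 'a face', 'an eye', 'a nose', 'happy', 'sad']

print(tree.get_detect_label_indices())    # label indices produced by DETECT nodes
print(tree.get_classify_label_indices())  # label indices produced by CLASSIFY nodes

print(tree.to_json(indent=2))             # serialize; Tree.from_json() parses this back
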
/nanoowl/tree_drawing.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
17 | import PIL.Image
18 | import PIL.ImageDraw
19 | import cv2
20 | from .tree import Tree
21 | from .tree_predictor import TreeOutput
22 | import matplotlib.pyplot as plt
23 | import numpy as np
24 | from typing import List
25 |
26 |
27 | def get_colors(count: int):
28 | cmap = plt.cm.get_cmap("rainbow", count)
29 | colors = []
30 | for i in range(count):
31 | color = cmap(i)
32 | color = [int(255 * value) for value in color]
33 | colors.append(tuple(color))
34 | return colors
35 |
36 |
37 | def draw_tree_output(image, output: TreeOutput, tree: Tree, draw_text=True, num_colors=8):
38 | detections = output.detections
39 | is_pil = not isinstance(image, np.ndarray)
40 | if is_pil:
41 | image = np.asarray(image)
42 | font = cv2.FONT_HERSHEY_SIMPLEX
43 | font_scale = 0.75
44 | colors = get_colors(num_colors)
45 | label_map = tree.get_label_map()
46 | label_depths = tree.get_label_depth_map()
47 | for detection in detections:
48 | box = [int(x) for x in detection.box]
49 | pt0 = (box[0], box[1])
50 | pt1 = (box[2], box[3])
51 | box_depth = min(label_depths[i] for i in detection.labels)
52 | cv2.rectangle(
53 | image,
54 | pt0,
55 | pt1,
56 | colors[box_depth % num_colors],
57 | 3
58 | )
59 | if draw_text:
60 | offset_y = 30
61 | offset_x = 8
62 | for label in detection.labels:
63 | label_text = label_map[label]
64 | cv2.putText(
65 | image,
66 | label_text,
67 | (box[0] + offset_x, box[1] + offset_y),
68 | font,
69 | font_scale,
70 | colors[label % num_colors],
71 |                 2,  # thickness
72 | cv2.LINE_AA
73 | )
74 | offset_y += 18
75 | if is_pil:
76 | image = PIL.Image.fromarray(image)
77 | return image
--------------------------------------------------------------------------------
/nanoowl/tree_predictor.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
17 | from .tree import Tree, TreeOp
18 | from .owl_predictor import OwlPredictor, OwlEncodeTextOutput, OwlEncodeImageOutput
19 | from .clip_predictor import ClipPredictor, ClipEncodeTextOutput, ClipEncodeImageOutput
20 | from .image_preprocessor import ImagePreprocessor
21 |
22 | import torch
23 | import PIL.Image
24 | from typing import Optional, Tuple, List, Mapping, Dict
25 | from dataclasses import dataclass
26 |
27 |
28 | @dataclass
29 | class TreeDetection:
30 | id: int
31 | parent_id: int
32 | box: Tuple[float, float, float, float]
33 | labels: List[int]
34 |     scores: List[float]
35 |
36 |
37 | @dataclass
38 | class TreeOutput:
39 | detections: List[TreeDetection]
40 |
41 |
42 | class TreePredictor(torch.nn.Module):
43 |
44 | def __init__(self,
45 | owl_predictor: Optional[OwlPredictor] = None,
46 | clip_predictor: Optional[ClipPredictor] = None,
47 | image_preprocessor: Optional[ImagePreprocessor] = None,
48 | device: str = "cuda"
49 | ):
50 | super().__init__()
51 | self.owl_predictor = OwlPredictor() if owl_predictor is None else owl_predictor
52 | self.clip_predictor = ClipPredictor() if clip_predictor is None else clip_predictor
53 | self.image_preprocessor = ImagePreprocessor().to(device).eval() if image_preprocessor is None else image_preprocessor
54 |
55 | def encode_clip_text(self, tree: Tree) -> Dict[int, ClipEncodeTextOutput]:
56 | label_indices = tree.get_classify_label_indices()
57 | if len(label_indices) == 0:
58 | return {}
59 | labels = [tree.labels[index] for index in label_indices]
60 | text_encodings = self.clip_predictor.encode_text(labels)
61 | label_encodings = {}
62 | for i in range(len(labels)):
63 | label_encodings[label_indices[i]] = text_encodings.slice(i, i+1)
64 | return label_encodings
65 |
66 | def encode_owl_text(self, tree: Tree) -> Dict[int, OwlEncodeTextOutput]:
67 | label_indices = tree.get_detect_label_indices()
68 | if len(label_indices) == 0:
69 | return {}
70 | labels = [tree.labels[index] for index in label_indices]
71 | text_encodings = self.owl_predictor.encode_text(labels)
72 | label_encodings = {}
73 | for i in range(len(labels)):
74 | label_encodings[label_indices[i]] = text_encodings.slice(i, i+1)
75 | return label_encodings
76 |
77 | @torch.no_grad()
78 | def predict(self,
79 | image: PIL.Image.Image,
80 | tree: Tree,
81 | threshold: float = 0.1,
82 | clip_text_encodings: Optional[Dict[int, ClipEncodeTextOutput]] = None,
83 | owl_text_encodings: Optional[Dict[int, OwlEncodeTextOutput]] = None
84 | ):
85 |
86 | if clip_text_encodings is None:
87 | clip_text_encodings = self.encode_clip_text(tree)
88 |
89 | if owl_text_encodings is None:
90 | owl_text_encodings = self.encode_owl_text(tree)
91 |
92 | image_tensor = self.image_preprocessor.preprocess_pil_image(image)
93 | boxes = {
94 | 0: torch.tensor([[0, 0, image.width, image.height]], dtype=image_tensor.dtype, device=image_tensor.device)
95 | }
96 | scores = {
97 | 0: torch.tensor([1.], dtype=torch.float, device=image_tensor.device)
98 | }
99 | instance_ids = {
100 | 0: torch.tensor([0], dtype=torch.int64, device=image_tensor.device)
101 | }
102 | parent_instance_ids = {
103 | 0: torch.tensor([-1], dtype=torch.int64, device=image_tensor.device)
104 | }
105 |
106 | owl_image_encodings: Dict[int, OwlEncodeImageOutput] = {}
107 | clip_image_encodings: Dict[int, ClipEncodeImageOutput] = {}
108 |
109 | global_instance_id = 1
110 |
111 | queue = [0]
112 |
113 | while queue:
114 | label_index = queue.pop(0)
115 |
116 | detect_nodes = tree.find_detect_nodes_with_input(label_index)
117 | classify_nodes = tree.find_classify_nodes_with_input(label_index)
118 |
119 | # Run OWL image encode if required
120 | if len(detect_nodes) > 0 and label_index not in owl_image_encodings:
121 | owl_image_encodings[label_index] = self.owl_predictor.encode_rois(image_tensor, boxes[label_index])
122 |
123 |
124 | # Run CLIP image encode if required
125 | if len(classify_nodes) > 0 and label_index not in clip_image_encodings:
126 | clip_image_encodings[label_index] = self.clip_predictor.encode_rois(image_tensor, boxes[label_index])
127 |
128 | # Decode detect nodes
129 | for node in detect_nodes:
130 |
131 | if node.input not in owl_image_encodings:
132 | raise RuntimeError("Missing owl image encodings for node.")
133 |
134 | # gather encodings
135 | owl_text_encodings_for_node = OwlEncodeTextOutput(
136 | text_embeds=torch.cat([
137 | owl_text_encodings[i].text_embeds for i in node.outputs
138 | ], dim=0)
139 | )
140 |
141 | owl_node_output = self.owl_predictor.decode(
142 | owl_image_encodings[node.input],
143 | owl_text_encodings_for_node,
144 | threshold=threshold
145 | )
146 |
147 | num_detections = len(owl_node_output.labels)
148 | instance_ids_for_node = torch.arange(global_instance_id, global_instance_id + num_detections, dtype=torch.int64, device=owl_node_output.labels.device)
149 | parent_instance_ids_for_node = instance_ids[node.input][owl_node_output.input_indices]
150 | global_instance_id += num_detections
151 |
152 | for i in range(len(node.outputs)):
153 | mask = owl_node_output.labels == i
154 | out_idx = node.outputs[i]
155 | boxes[out_idx] = owl_node_output.boxes[mask]
156 | scores[out_idx] = owl_node_output.scores[mask]
157 | instance_ids[out_idx] = instance_ids_for_node[mask]
158 | parent_instance_ids[out_idx] = parent_instance_ids_for_node[mask]
159 |
160 | for node in classify_nodes:
161 |
162 | if node.input not in clip_image_encodings:
163 | raise RuntimeError("Missing clip image encodings for node.")
164 |
165 | clip_text_encodings_for_node = ClipEncodeTextOutput(
166 | text_embeds=torch.cat([
167 | clip_text_encodings[i].text_embeds for i in node.outputs
168 | ], dim=0)
169 | )
170 |
171 | clip_node_output = self.clip_predictor.decode(
172 | clip_image_encodings[node.input],
173 | clip_text_encodings_for_node
174 | )
175 |
176 | parent_instance_ids_for_node = instance_ids[node.input]
177 |
178 | for i in range(len(node.outputs)):
179 | mask = clip_node_output.labels == i
180 | output_buffer = node.outputs[i]
181 | scores[output_buffer] = clip_node_output.scores[mask].float()
182 | boxes[output_buffer] = boxes[label_index][mask].float()
183 | instance_ids[output_buffer] = instance_ids[node.input][mask]
184 | parent_instance_ids[output_buffer] = parent_instance_ids[node.input][mask]
185 |
186 | for node in detect_nodes:
187 | for buf in node.outputs:
188 | if buf in scores and len(scores[buf]) > 0:
189 | queue.append(buf)
190 |
191 | for node in classify_nodes:
192 | for buf in node.outputs:
193 | if buf in scores and len(scores[buf]) > 0:
194 | queue.append(buf)
195 |
196 | # Fill outputs
197 | detections: Dict[int, TreeDetection] = {}
198 | for i in boxes.keys():
199 | for box, score, instance_id, parent_instance_id in zip(boxes[i], scores[i], instance_ids[i], parent_instance_ids[i]):
200 | instance_id = int(instance_id)
201 | score = float(score)
202 | box = box.tolist()
203 | parent_instance_id = int(parent_instance_id)
204 | if instance_id in detections:
205 | detections[instance_id].labels.append(i)
206 | detections[instance_id].scores.append(score)
207 | else:
208 | detections[instance_id] = TreeDetection(
209 | id=instance_id,
210 | parent_id=parent_instance_id,
211 | box=box,
212 | labels=[i],
213 | scores=[score]
214 | )
215 |
216 |         return TreeOutput(detections=list(detections.values()))
--------------------------------------------------------------------------------
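
An end-to-end sketch combining `TreePredictor` with `draw_tree_output` (the prompt, the 0.15 threshold, and the output filename are illustrative; the input image comes from this repository's `assets` folder):

import numpy as np
import PIL.Image
from nanoowl.tree import Tree
from nanoowl.tree_predictor import TreePredictor
from nanoowl.tree_drawing import draw_tree_output

predictor = TreePredictor()

# Detect faces, then classify each detected face as happy or sad.
tree = Tree.from_prompt("[a face (happy, sad)]")

image = PIL.Image.open("assets/owl_glove.jpg")
output = predictor.predict(image, tree, threshold=0.15)

# draw_tree_output accepts a PIL image or a numpy array; a writable array copy is
# passed here because the boxes and labels are drawn in place with OpenCV.
drawn = draw_tree_output(np.array(image), output, tree)
PIL.Image.fromarray(drawn).save("tree_predict_out.jpg")
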
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import find_packages, setup
2 |
3 |
4 | setup(
5 | name="nanoowl",
6 | version="0.0.0",
7 | packages=find_packages()
8 | )
--------------------------------------------------------------------------------
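
For local development, the package above can typically be installed in editable mode from the repository root (assuming pip and setuptools are available):

python3 -m pip install -e .
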
/test/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
--------------------------------------------------------------------------------
/test/test_clip_predictor.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
17 | import pytest
18 | import torch
19 | import PIL.Image
20 | from nanoowl.clip_predictor import ClipPredictor
21 | from nanoowl.image_preprocessor import ImagePreprocessor
22 |
23 |
24 | def test_get_image_size():
25 | clip_predictor = ClipPredictor()
26 | assert clip_predictor.get_image_size() == (224, 224)
27 |
28 |
29 | def test_clip_encode_text():
30 |
31 | clip_predictor = ClipPredictor()
32 |
33 | text_encode_output = clip_predictor.encode_text(["a frog", "a dog"])
34 |
35 | assert text_encode_output.text_embeds.shape == (2, 512)
36 |
37 |
38 | def test_clip_encode_image():
39 |
40 | clip_predictor = ClipPredictor()
41 |
42 | image = PIL.Image.open("assets/owl_glove_small.jpg")
43 |
44 | image = image.resize((224, 224))
45 |
46 | image_preprocessor = ImagePreprocessor().to(clip_predictor.device).eval()
47 |
48 | image_tensor = image_preprocessor.preprocess_pil_image(image)
49 |
50 | image_encode_output = clip_predictor.encode_image(image_tensor)
51 |
52 | assert image_encode_output.image_embeds.shape == (1, 512)
53 |
54 |
55 | def test_clip_classify():
56 |
57 | clip_predictor = ClipPredictor()
58 |
59 | image = PIL.Image.open("assets/frog.jpg")
60 |
61 | image = image.resize((224, 224))
62 | image_preprocessor = ImagePreprocessor().to(clip_predictor.device).eval()
63 |
64 | image_tensor = image_preprocessor.preprocess_pil_image(image)
65 |
66 | text_output = clip_predictor.encode_text(["a frog", "an owl"])
67 | image_output = clip_predictor.encode_image(image_tensor)
68 |
69 | classify_output = clip_predictor.decode(image_output, text_output)
70 |
71 | assert classify_output.labels[0] == 0
--------------------------------------------------------------------------------
/test/test_image_preprocessor.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
17 | import pytest
18 | import torch
19 | import PIL.Image
20 | from nanoowl.image_preprocessor import ImagePreprocessor
21 |
22 |
23 | def test_image_preprocessor_preprocess_pil_image():
24 |
25 | image_preproc = ImagePreprocessor().to("cuda").eval()
26 |
27 | image = PIL.Image.open("assets/owl_glove_small.jpg")
28 |
29 | image_tensor = image_preproc.preprocess_pil_image(image)
30 |
31 | assert image_tensor.shape == (1, 3, 499, 756)
32 | assert torch.allclose(image_tensor.mean(), torch.zeros_like(image_tensor), atol=1, rtol=1)
33 |
--------------------------------------------------------------------------------
/test/test_owl_predictor.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
17 | import pytest
18 | import torch
19 | import PIL.Image
20 | from nanoowl.owl_predictor import OwlPredictor
21 | from nanoowl.image_preprocessor import ImagePreprocessor
22 |
23 |
24 | def test_owl_predictor_get_image_size():
25 | owl_predictor = OwlPredictor()
26 | assert owl_predictor.get_image_size() == (768, 768)
27 |
28 |
29 | def test_owl_predictor_encode_text():
30 |
31 | owl_predictor = OwlPredictor()
32 |
33 | text_encode_output = owl_predictor.encode_text(["a frog", "a dog"])
34 |
35 | assert text_encode_output.text_embeds.shape == (2, 512)
36 |
37 |
38 | def test_owl_predictor_encode_image():
39 |
40 | owl_predictor = OwlPredictor()
41 |
42 | image = PIL.Image.open("assets/owl_glove_small.jpg")
43 |
44 | image = image.resize(owl_predictor.get_image_size())
45 |
46 | image_preprocessor = ImagePreprocessor().to(owl_predictor.device).eval()
47 |
48 | image_tensor = image_preprocessor.preprocess_pil_image(image)
49 |
50 | image_encode_output = owl_predictor.encode_image(image_tensor)
51 |
52 | assert image_encode_output.image_class_embeds.shape == (1, owl_predictor.get_num_patches(), 512)
53 |
54 |
55 | def test_owl_predictor_decode():
56 |
57 | owl_predictor = OwlPredictor()
58 |
59 | image = PIL.Image.open("assets/owl_glove_small.jpg")
60 |
61 | image = image.resize(owl_predictor.get_image_size())
62 | image_preprocessor = ImagePreprocessor().to(owl_predictor.device).eval()
63 |
64 | image_tensor = image_preprocessor.preprocess_pil_image(image)
65 |
66 | text_output = owl_predictor.encode_text(["an owl"])
67 | image_output = owl_predictor.encode_image(image_tensor)
68 |
69 | classify_output = owl_predictor.decode(image_output, text_output)
70 |
71 |     assert len(classify_output.labels) == 1
72 | assert classify_output.labels[0] == 0
73 | assert classify_output.boxes.shape == (1, 4)
74 | assert classify_output.input_indices == 0
75 |
76 |
77 | def test_owl_predictor_decode_multiple_images():
78 |
79 | owl_predictor = OwlPredictor()
80 | image_preprocessor = ImagePreprocessor().to(owl_predictor.device).eval()
81 |
82 | image_paths = [
83 | "assets/owl_glove_small.jpg",
84 | "assets/frog.jpg"
85 | ]
86 |
87 | images = []
88 | for image_path in image_paths:
89 | image = PIL.Image.open(image_path)
90 | image = image.resize(owl_predictor.get_image_size())
91 | image = image_preprocessor.preprocess_pil_image(image)
92 | images.append(image)
93 |
94 | images = torch.cat(images, dim=0)
95 |
96 | text_output = owl_predictor.encode_text(["an owl", "a frog"])
97 | image_output = owl_predictor.encode_image(images)
98 |
99 | decode_output = owl_predictor.decode(image_output, text_output)
100 |
101 | # check number of detections
102 |     assert len(decode_output.labels) == 2
103 |
104 | # check owl detection
105 | assert decode_output.labels[0] == 0
106 | assert decode_output.input_indices[0] == 0
107 |
108 | # check frog detection
109 | assert decode_output.labels[1] == 1
110 | assert decode_output.input_indices[1] == 1
111 |
--------------------------------------------------------------------------------
/test/test_tree.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
17 | import pytest
18 | from nanoowl.tree import (
19 | Tree,
20 | TreeOp
21 | )
22 |
23 |
24 | def test_tree_from_prompt():
25 |
26 | graph = Tree.from_prompt("[a face]")
27 |
28 | assert len(graph.nodes) == 1
29 | assert len(graph.labels) == 2
30 | assert graph.labels[0] == "image"
31 | assert graph.labels[1] == "a face"
32 | assert graph.nodes[0].op == TreeOp.DETECT
33 |
34 | graph = Tree.from_prompt("[a face](a dog, a cat)")
35 |
36 | assert len(graph.nodes) == 2
37 | assert len(graph.labels) == 4
38 | assert graph.labels[0] == "image"
39 | assert graph.labels[1] == "a face"
40 | assert graph.labels[2] == "a dog"
41 | assert graph.labels[3] == "a cat"
42 | assert graph.nodes[0].op == TreeOp.DETECT
43 | assert graph.nodes[1].op == TreeOp.CLASSIFY
44 |
45 | with pytest.raises(RuntimeError):
46 | Tree.from_prompt("]a face]")
47 |
48 | with pytest.raises(RuntimeError):
49 | Tree.from_prompt("[a face")
50 |
51 | with pytest.raises(RuntimeError):
52 | Tree.from_prompt("[a face)")
53 |
54 | with pytest.raises(RuntimeError):
55 | Tree.from_prompt("[a face]]")
56 |
57 |
58 | def test_tree_to_dict():
59 |
60 | tree = Tree.from_prompt("[a[b,c(d,e)]]")
61 | tree_dict = tree.to_dict()
62 | assert "nodes" in tree_dict
63 | assert "labels" in tree_dict
64 | assert len(tree_dict["nodes"]) == 3
65 | assert len(tree_dict["labels"]) == 6
66 |
67 |
68 |
69 | def test_tree_from_prompt_classify_only():
70 |
71 | tree = Tree.from_prompt("(office, home, outdoors, gym)")
72 |
73 | print(tree)
--------------------------------------------------------------------------------
/test/test_tree_predictor.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2 | # SPDX-License-Identifier: Apache-2.0
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
17 | import pytest
18 | import PIL.Image
19 | from nanoowl.tree_predictor import TreePredictor
20 | from nanoowl.tree import Tree
21 |
22 |
23 | def test_encode_clip_labels():
24 |
25 | predictor = TreePredictor()
26 | tree = Tree.from_prompt("(sunny, rainy)")
27 |
28 | text_encodings = predictor.encode_clip_text(tree)
29 |
30 | assert len(text_encodings) == 2
31 | assert 1 in text_encodings
32 | assert 2 in text_encodings
33 | assert text_encodings[1].text_embeds.shape == (1, 512)
34 |
35 |
36 | def test_encode_owl_labels():
37 |
38 | predictor = TreePredictor()
39 | tree = Tree.from_prompt("[a face [an eye, a nose]]")
40 |
41 | text_encodings = predictor.encode_owl_text(tree)
42 |
43 | assert len(text_encodings) == 3
44 | assert 1 in text_encodings
45 | assert 2 in text_encodings
46 | assert 3 in text_encodings
47 | assert text_encodings[1].text_embeds.shape == (1, 512)
48 |
49 |
50 | def test_encode_clip_owl_labels_mixed():
51 |
52 | predictor = TreePredictor()
53 | tree = Tree.from_prompt("[a face [an eye, a nose](happy, sad)]")
54 |
55 | owl_text_encodings = predictor.encode_owl_text(tree)
56 | clip_text_encodings = predictor.encode_clip_text(tree)
57 |
58 | assert len(owl_text_encodings) == 3
59 | assert len(clip_text_encodings) == 2
60 |
61 |
62 | def test_tree_predictor_predict():
63 |
64 | predictor = TreePredictor()
65 | tree = Tree.from_prompt("[an owl]")
66 |
67 |
68 | image = PIL.Image.open("assets/owl_glove.jpg")
69 |
70 | detections = predictor.predict(image, tree)
71 |
72 |
73 | def test_tree_predictor_predict_classify():
74 |
75 | predictor = TreePredictor()
76 | tree = Tree.from_prompt("(outdoors, indoors)")
77 |
78 |
79 | image = PIL.Image.open("assets/owl_glove.jpg")
80 |
81 | detections = predictor.predict(image, tree)
82 |
83 | print(detections)
84 |
--------------------------------------------------------------------------------
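
The tests above can be run with pytest from the repository root, assuming a CUDA device is available and the referenced `assets/*.jpg` images are present (the OWL-ViT and CLIP weights are downloaded on first use):

python3 -m pytest test
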