├── .gitignore
├── .pre-commit-config.yaml
├── README.md
├── __init__.py
├── assets
│   └── icon.svg
├── classification.py
├── detection.py
├── fiftyone.yml
├── instance_segmentation.py
└── semantic_segmentation.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 |   - repo: https://github.com/asottile/blacken-docs
3 |     rev: v1.12.0
4 |     hooks:
5 |       - id: blacken-docs
6 |         additional_dependencies: [black==21.12b0]
7 |         args: ["-l 79"]
8 |         exclude: index.umd.js
9 |   - repo: https://github.com/ambv/black
10 |     rev: 22.3.0
11 |     hooks:
12 |       - id: black
13 |         language_version: python3
14 |         args: ["-l 79"]
15 |         exclude: index.umd.js
16 |   - repo: local
17 |     hooks:
18 |       - id: pylint
19 |         name: pylint
20 |         language: system
21 |         files: \.py$
22 |         entry: pylint
23 |         args: ["--errors-only"]
24 |         exclude: index.umd.js
25 |   - repo: local
26 |     hooks:
27 |       - id: ipynb-strip
28 |         name: ipynb-strip
29 |         language: system
30 |         files: \.ipynb$
31 |         entry: jupyter nbconvert --clear-output --ClearOutputPreprocessor.enabled=True
32 |         args: ["--log-level=ERROR"]
33 |   - repo: https://github.com/pre-commit/mirrors-prettier
34 |     rev: v2.6.2
35 |     hooks:
36 |       - id: prettier
37 |         exclude: index.umd.js
38 |         language_version: system
39 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Zero Shot Prediction Plugin
2 |
3 | 
4 |
5 | This plugin allows you to perform zero-shot prediction on your dataset for the following tasks:
6 |
7 | - Image Classification
8 | - Object Detection
9 | - Instance Segmentation
10 | - Semantic Segmentation
11 |
12 | Given a list of label classes, which you can input either manually, separated by commas, or by uploading a text file, the plugin will perform zero-shot prediction on your dataset for the specified task and add the results to the dataset under a new field, which you can specify.
13 |
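For example, entering `cat, dog, bird` directly is equivalent to uploading a text file containing `cat`, `dog`, and `bird` on separate lines: the plugin splits direct input on commas and file input on newlines (see `_get_labels()` in `__init__.py`).
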
14 | ### Updates
15 | - 🆕 **2024-12-03**: Added support for Apple AIMv2 Zero Shot Model (courtesy of [@harpreetsahota204](https://github.com/harpreetsahota204))
16 | - 🆕 **2024-12-16**: Added MPS and GPU support for ALIGN, AltCLIP, Apple AIMv2 (courtesy of [@harpreetsahota204](https://github.com/harpreetsahota204))
17 | - **2024-06-22**: Updated interface for Python operator execution
18 | - **2024-05-30**: Added
19 | - support for Grounding DINO for object detection and instance segmentation
20 | - confidence thresholding for object detection and instance segmentation
21 | - **2024-03-06**: Added support for YOLO-World for object detection and instance segmentation!
22 | - **2024-01-10**: Removed LAION CLIP models.
23 | - **2024-01-05**: Added support for EVA-CLIP, SigLIP, and DFN CLIP for image classification!
24 | - **2023-11-28**: Version 1.1.1 supports OpenCLIP for image classification!
25 | - **2023-11-13**: Version 1.1.0 supports [calling operators from the Python SDK](#python-sdk)!
26 | - **2023-10-27**: Added support for MetaCLIP for image classification
27 | - **2023-10-20**: Added support for AltCLIP and Align for image classification and GroupViT for semantic segmentation
28 |
29 | ### Requirements
30 |
31 | - To use YOLO-World models, you must have `"ultralytics>=8.1.42"` installed.
32 |
33 | ## Models
34 |
35 | ### Built-in Models
36 |
37 | As a starting point, this plugin comes with at least one zero-shot model per task. These are:
38 |
39 | #### Image Classification
40 |
41 | - [ALIGN](https://huggingface.co/docs/transformers/model_doc/align)
42 | - [AltCLIP](https://huggingface.co/docs/transformers/model_doc/altclip)
43 | - 🆕 [Apple AIMv2](https://huggingface.co/apple/aimv2-large-patch14-224-lit)
44 | - [CLIP](https://github.com/openai/CLIP) (OpenAI)
45 | - [CLIPA](https://github.com/UCSC-VLAA/CLIPA)
46 | - [DFN CLIP](https://huggingface.co/apple/DFN5B-CLIP-ViT-H-14-378): Data Filtering Networks
47 | - [EVA-CLIP](https://huggingface.co/QuanSun/EVA-CLIP)
48 | - [MetaCLIP](https://github.com/facebookresearch/metaclip)
49 | - [SigLIP](https://huggingface.co/timm/ViT-SO400M-14-SigLIP-384)
50 |
51 | #### Object Detection
52 |
53 | - [YOLO-World](https://docs.ultralytics.com/models/yolo-world/)
54 | - [Owl-ViT](https://huggingface.co/docs/transformers/model_doc/owlvit)
55 | - [Grounding DINO](https://huggingface.co/docs/transformers/main/en/model_doc/grounding-dino)
56 |
57 | #### Instance Segmentation
58 |
59 | - [Owl-ViT](https://huggingface.co/docs/transformers/model_doc/owlvit) + [Segment Anything (SAM)](https://github.com/facebookresearch/segment-anything)
60 | - [YOLO-World](https://docs.ultralytics.com/models/yolo-world/) + [Segment Anything (SAM)](https://github.com/facebookresearch/segment-anything)
61 | - [Grounding DINO](https://huggingface.co/docs/transformers/main/en/model_doc/grounding-dino) + [Segment Anything (SAM)](https://github.com/facebookresearch/segment-anything)
62 |
63 | #### Semantic Segmentation
64 |
65 | - [CLIPSeg](https://huggingface.co/blog/clipseg-zero-shot)
66 | - [GroupViT](https://huggingface.co/docs/transformers/model_doc/groupvit)
67 |
68 | Most of the models used are from the [HuggingFace Transformers](https://huggingface.co/transformers/) library, while the CLIP and SAM models are from the [FiftyOne Model Zoo](https://docs.voxel51.com/user_guide/model_zoo/index.html).
69 |
70 | _Note_: To use SAM, you will need to have Facebook's `segment-anything` library installed.
71 |
72 | ### Adding Your Own Models
73 |
74 | You can see the implementations for all of these models in the following files:
75 |
76 | - `classification.py`
77 | - `detection.py`
78 | - `instance_segmentation.py`
79 | - `semantic_segmentation.py`
80 |
81 | These models are "registered" via dictionaries in each file. In `semantic_segmentation.py`, for example, the dictionary is:
82 |
83 | ```py
84 | SEMANTIC_SEGMENTATION_MODELS = {
85 |     "CLIPSeg": {
86 |         "activator": CLIPSeg_activator,
87 |         "model": CLIPSegZeroShotModel,
88 |         "name": "CLIPSeg",
89 |     },
90 |     "GroupViT": {
91 |         "activator": GroupViT_activator,
92 |         "model": GroupViTZeroShotModel,
93 |         "name": "GroupViT",
94 |     },
95 | }
96 | ```
97 |
98 | The `activator` checks the environment to see if the model is available, and the `model` is a `fiftyone.core.models.Model` object that is instantiated with the model name and the task — or a function that instantiates such a model. The `name` is the name of the model that will be displayed in the dropdown menu in the plugin.
99 |
100 | If you want to add your own model, you can add it to the dictionary in the corresponding file. For example, if you want to add a new semantic segmentation model, you can add it to the `SEMANTIC_SEGMENTATION_MODELS` dictionary in `semantic_segmentation.py`:
101 |
102 | ```py
103 | SEMANTIC_SEGMENTATION_MODELS = {
104 |     "CLIPSeg": {
105 |         "activator": CLIPSeg_activator,
106 |         "model": CLIPSegZeroShotModel,
107 |         "name": "CLIPSeg",
108 |     },
109 |     "GroupViT": {
110 |         "activator": GroupViT_activator,
111 |         "model": GroupViTZeroShotModel,
112 |         "name": "GroupViT",
113 |     },
114 |     ...,  # other models
115 |     "My Model": {
116 |         "activator": my_model_activator,
117 |         "model": my_model,
118 |         "name": "My Model",
119 |     },
120 | }
121 | ```
122 |
123 | 💡 You need to implement the `activator` and `model` functions for your model. The `activator` should check the environment to see if the model is available, and the `model` should be a `fiftyone.core.models.Model` object that is instantiated with the model name and the task.
124 |
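As a rough sketch (not part of the plugin; all names here are hypothetical), the `my_model_activator` and `my_model` entries for a `transformers`-based model could be implemented along these lines, mirroring the activators and model wrappers in `classification.py`:

```py
from importlib.util import find_spec

import fiftyone as fo
from fiftyone.core.models import Model


def my_model_activator():
    # Only surface the model in the dropdown if its dependency is installed
    return find_spec("transformers") is not None


class MyZeroShotModel(Model):
    """Hypothetical zero-shot model wrapper."""

    def __init__(self, config):
        # The plugin passes the user-specified classes via config["categories"]
        self.categories = config.get("categories", None)

    @property
    def media_type(self):
        return "image"

    def predict(self, args):
        # Return a FiftyOne label (e.g., fo.Segmentation) for the input image array
        raise NotImplementedError("Add your model's inference logic here")


def my_model(config):
    # The "model" entry may be a factory that returns a fiftyone.core.models.Model
    return MyZeroShotModel(config)
```
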
125 | ## Watch On Youtube
126 |
127 | [Watch on YouTube](https://www.youtube.com/watch?v=GlwyFHbTklw&list=PLuREAXoPgT0RZrUaT0UpX_HzwKkoB-S9j&index=7)
128 |
129 | ## Installation
130 |
131 | ```shell
132 | fiftyone plugins download https://github.com/jacobmarks/zero-shot-prediction-plugin
133 | ```
134 |
135 | If you want to use AltCLIP, Align, Owl-ViT, CLIPSeg, or GroupViT, you will also need to install the `transformers` library:
136 |
137 | ```shell
138 | pip install transformers
139 | ```
140 |
141 | If you want to use SAM, you will also need to install the `segment-anything` library:
142 |
143 | ```shell
144 | pip install git+https://github.com/facebookresearch/segment-anything.git
145 | ```
146 |
147 | If you want to use OpenCLIP, you will also need to install the `open_clip` library from PyPI:
148 |
149 | ```shell
150 | pip install open-clip-torch
151 | ```
152 |
153 | Or from source:
154 |
155 | ```shell
156 | pip install git+https://github.com/mlfoundations/open_clip.git
157 | ```
158 |
159 | If you want to use YOLO-World, you will also need to install the `ultralytics` library:
160 |
161 | ```shell
162 | pip install -U ultralytics
163 | ```
164 |
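After downloading, you can verify that the plugin is available via FiftyOne's CLI (assuming a recent FiftyOne version):

```shell
fiftyone plugins list
```
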
165 | ## Usage
166 |
167 | All of the operators in this plugin can be run in _delegated_ execution mode. This means that instead of waiting for the operator to finish, you _schedule_
168 | the operation to be performed separately. This is useful for long-running operations, such as performing inference on a large dataset.
169 |
170 | Once you have pressed the `Schedule` button for the operator, you will be able to see the job from the command line using FiftyOne's [command line interface](https://docs.voxel51.com/cli/index.html#fiftyone-delegated-operations):
171 |
172 | ```shell
173 | fiftyone delegated list
174 | ```
175 |
176 | will show you the status of all delegated operations.
177 |
178 | To launch a service which runs the operation, as well as any other delegated operations that have been scheduled, run:
179 |
180 | ```shell
181 | fiftyone delegated launch
182 | ```
183 |
184 | Once the operation has completed, you can view the results in the App (upon refresh).
185 |
186 | After the operation completes, you can also clean up your list of delegated operations by running:
187 |
188 | ```shell
189 | fiftyone delegated cleanup -s COMPLETED
190 | ```
191 |
192 | ## Operators
193 |
194 | ### `zero_shot_predict`
195 |
196 | - Select the task you want to perform zero-shot prediction on (image classification, object detection, instance segmentation, or semantic segmentation), and the field you want to add the results to.
197 |
198 | ### `zero_shot_classify`
199 |
200 | - Perform zero-shot image classification on your dataset
201 |
202 | ### `zero_shot_detect`
203 |
204 | - Perform zero-shot object detection on your dataset
205 |
206 | ### `zero_shot_instance_segment`
207 |
208 | - Perform zero-shot instance segmentation on your dataset
209 |
210 | ### `zero_shot_semantic_segment`
211 |
212 | - Perform zero-shot semantic segmentation on your dataset
213 |
214 | ## Python SDK
215 |
216 | You can also use the compute operators from the Python SDK!
217 |
218 | ```python
219 | import fiftyone as fo
220 | import fiftyone.operators as foo
221 | import fiftyone.zoo as foz
222 |
223 | dataset = fo.load_dataset("quickstart")
224 |
225 | ## Access the operator via its URI (plugin name + operator name)
226 | zsc = foo.get_operator("@jacobmarks/zero_shot_prediction/zero_shot_classify")
227 |
228 | ## Run zero-shot classification on all images in the dataset, specifying the labels with the `labels` argument
229 | zsc(dataset, labels=["cat", "dog", "bird"])
230 |
231 | ## Run zero-shot classification on all images in the dataset, specifying the labels with a text file
232 | zsc(dataset, labels_file="/path/to/labels.txt")
233 |
234 | ## Specify the model to use, and the field to add the results to
235 | zsc(dataset, labels=["cat", "dog", "bird"], model_name="CLIP", label_field="predictions")
236 |
237 | ## Run zero-shot detection on a view
238 | zsd = foo.get_operator("@jacobmarks/zero_shot_prediction/zero_shot_detect")
239 | view = dataset.take(10)
240 | await zsd(
241 |     view,
242 |     labels=["license plate"],
243 |     model_name="OwlViT",
244 |     label_field="owlvit_license_plate",
245 | )
246 | ```
247 |
248 | All four of the task-specific zero-shot prediction operators also expose a `list_models()` method, which returns a list of the available models for that task.
249 |
250 | ```python
251 | zsss = foo.get_operator(
252 |     "@jacobmarks/zero_shot_prediction/zero_shot_semantic_segment"
253 | )
254 |
255 | zsss.list_models()
256 |
257 | ## ['CLIPSeg', 'GroupViT']
258 | ```
259 |
260 | **Note**: The `zero_shot_predict` operator is not yet supported in the Python SDK.
261 |
262 | **Note**: With earlier versions of FiftyOne, you may have trouble running these
263 | operator executions within a Jupyter notebook. If so, try running them in a
264 | Python script, or upgrading to the latest version of FiftyOne!
265 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | """Zero Shot Prediction plugin.
2 |
3 | | Copyright 2017-2023, Voxel51, Inc.
4 | | `voxel51.com <https://voxel51.com/>`_
5 | |
6 | """
7 |
8 | import os
9 | import base64
10 |
11 | from fiftyone.core.utils import add_sys_path
12 | import fiftyone.operators as foo
13 | from fiftyone.operators import types
14 |
15 | with add_sys_path(os.path.dirname(os.path.abspath(__file__))):
16 | # pylint: disable=no-name-in-module,import-error
17 | from classification import (
18 | run_zero_shot_classification,
19 | CLASSIFICATION_MODELS,
20 | )
21 | from detection import run_zero_shot_detection, DETECTION_MODELS
22 | from instance_segmentation import (
23 | run_zero_shot_instance_segmentation,
24 | INSTANCE_SEGMENTATION_MODELS,
25 | )
26 | from semantic_segmentation import (
27 | run_zero_shot_semantic_segmentation,
28 | SEMANTIC_SEGMENTATION_MODELS,
29 | )
30 |
31 |
32 | ZERO_SHOT_TASKS = (
33 | "classification",
34 | "detection",
35 | "instance_segmentation",
36 | "semantic_segmentation",
37 | )
38 |
39 |
40 | MODEL_LISTS = {
41 | "classification": CLASSIFICATION_MODELS,
42 | "detection": DETECTION_MODELS,
43 | "instance_segmentation": INSTANCE_SEGMENTATION_MODELS,
44 | "semantic_segmentation": SEMANTIC_SEGMENTATION_MODELS,
45 | }
46 |
47 |
48 | def _format_submodel_name(submodel):
49 | if type(submodel) == str:
50 | return submodel
51 | return f"{submodel[0]}|{submodel[1]}"
52 |
53 |
54 | def _format_submodel_label(submodel):
55 | if type(submodel) == str:
56 | return submodel.split(".")[0]
57 | pretrain = submodel[0].split("/")[-1]
58 | arch = submodel[1]
59 |
60 | arch_string = f"Architecture: {arch}"
61 | _pt = pretrain and pretrain != "openai"
62 | pretrain_string = f" | Pretrained: {pretrain}" if _pt else ""
63 | return f"{arch_string}{pretrain_string}"
64 |
65 |
66 | def _execution_mode(ctx, inputs):
67 | delegate = ctx.params.get("delegate", False)
68 |
69 | if delegate:
70 | description = "Uncheck this box to execute the operation immediately"
71 | else:
72 | description = "Check this box to delegate execution of this task"
73 |
74 | inputs.bool(
75 | "delegate",
76 | default=False,
77 | required=True,
78 | label="Delegate execution?",
79 | description=description,
80 | view=types.CheckboxView(),
81 | )
82 |
83 | if delegate:
84 | inputs.view(
85 | "notice",
86 | types.Notice(
87 | label=(
88 | "You've chosen delegated execution. Note that you must "
89 | "have a delegated operation service running in order for "
90 | "this task to be processed. See "
91 | "https://docs.voxel51.com/plugins/index.html#operators "
92 | "for more information"
93 | )
94 | ),
95 | )
96 |
97 |
98 | def _get_active_models(task):
99 | ams = []
100 | for element in MODEL_LISTS[task].values():
101 | if element["activator"]():
102 | ams.append(element["name"])
103 | return ams
104 |
105 |
106 | def _get_labels(ctx):
107 | if ctx.params.get("label_input_choices", False) == "direct":
108 | labels = ctx.params.get("labels", "")
109 | return [label.strip() for label in labels.split(",")]
110 | else:
111 | labels_file = ctx.params.get("labels_file", None).strip()
112 | if "," in labels_file:
113 | lf = labels_file.split(",")[1]  # strip the "data:...;base64," prefix from the uploaded file
114 | else:
115 | lf = labels_file
116 | decoded_bytes = base64.b64decode(lf)  # uploaded file contents arrive base64-encoded
117 | labels = decoded_bytes.decode("utf-8")
118 | return [label.strip() for label in labels.split("\n")]  # one label per line
119 |
120 |
121 | TASK_TO_FUNCTION = {
122 | "classification": run_zero_shot_classification,
123 | "detection": run_zero_shot_detection,
124 | "instance_segmentation": run_zero_shot_instance_segmentation,
125 | "semantic_segmentation": run_zero_shot_semantic_segmentation,
126 | }
127 |
128 |
129 | def run_zero_shot_task(
130 | dataset,
131 | task,
132 | model_name,
133 | label_field,
134 | categories,
135 | architecture,
136 | pretrained,
137 | confidence=0.2,
138 | ):
139 | return TASK_TO_FUNCTION[task](
140 | dataset,
141 | model_name,
142 | label_field,
143 | categories,
144 | architecture=architecture,
145 | pretrained=pretrained,
146 | confidence=confidence,
147 | )
148 |
149 |
150 | def _model_name_to_field_name(model_name):
151 | fn = (
152 | model_name.lower()
153 | .replace(" ", "_")
154 | .replace("_+", "")
155 | .replace("-", "")
156 | .split("(")[0]
157 | .strip()
158 | )
159 | if fn[-1] == "_":
160 | fn = fn[:-1]
161 | return fn
162 |
163 |
164 | def _handle_model_choice_inputs(ctx, inputs, chosen_task):
165 | active_models = _get_active_models(chosen_task)
166 |
167 | if len(active_models) == 0:
168 | inputs.str(
169 | "no_models_warning",
170 | view=types.Warning(
171 | label=f"No Models Found",
172 | description="No models were found for the selected task. Please install the required libraries.",
173 | ),
174 | )
175 | return types.Property(inputs)
176 |
177 | ct_label = (
178 | "Segmentation"
179 | if "segment" in chosen_task
180 | else chosen_task.capitalize()
181 | )
182 | model_dropdown_label = f"{ct_label} Model"
183 | model_dropdown = types.Dropdown(label=model_dropdown_label)
184 | for model in active_models:
185 | model_dropdown.add_choice(model, label=model)
186 | inputs.enum(
187 | f"model_choice_{chosen_task}",
188 | model_dropdown.values(),
189 | default=model_dropdown.choices[0].value,
190 | view=model_dropdown,
191 | )
192 |
193 | model_choice = ctx.params.get(
194 | f"model_choice_{chosen_task}", model_dropdown.choices[0].value
195 | )
196 | mc = model_choice.split("(")[0].strip().lower()
197 |
198 | submodels = MODEL_LISTS[chosen_task][model_choice].get("submodels", None)
199 | if submodels is not None:
200 | if len(submodels) == 1:
201 | ctx.params["pretrained"] = submodels[0][0]
202 | ctx.params["architecture"] = submodels[0][1]
203 | else:
204 | submodel_dropdown = types.Dropdown(
205 | label=f"{chosen_task.capitalize()} Submodel"
206 | )
207 | for submodel in submodels:
208 | submodel_dropdown.add_choice(
209 | _format_submodel_name(submodel),
210 | label=_format_submodel_label(submodel),
211 | )
212 | inputs.enum(
213 | f"submodel_choice_{chosen_task}_{mc}",
214 | submodel_dropdown.values(),
215 | default=submodel_dropdown.choices[0].value,
216 | view=submodel_dropdown,
217 | )
218 |
219 | submodel_choice = ctx.params.get(
220 | f"submodel_choice_{chosen_task}_{model_choice}",
221 | submodel_dropdown.choices[0].value,
222 | )
223 |
224 | if "|" in submodel_choice:
225 | if chosen_task == "instance_segmentation":
226 | ctx.params["pretrained"] += submodel_choice.split("|")[0]
227 | ctx.params["architecture"] += submodel_choice.split("|")[1]
228 | else:
229 | ctx.params["pretrained"] = submodel_choice.split("|")[0]
230 | ctx.params["architecture"] = submodel_choice.split("|")[1]
231 | else:
232 | if chosen_task == "instance_segmentation":
233 | ctx.params["pretrained"] += submodel_choice
234 | else:
235 | ctx.params["pretrained"] = submodel_choice
236 | ctx.params["architecture"] = None
237 |
238 |
239 | def handle_model_choice_inputs(ctx, inputs, chosen_task):
240 | if chosen_task == "instance_segmentation":
241 | _handle_model_choice_inputs(ctx, inputs, "detection")
242 | if ctx.params.get("pretrained", None) is not None:
243 | ctx.params["pretrained"] = ctx.params["pretrained"] + " + "
244 | else:
245 | ctx.params["pretrained"] = " + "
246 | if ctx.params.get("architecture", None) is not None:
247 | ctx.params["architecture"] = ctx.params["architecture"] + " + "
248 | else:
249 | ctx.params["architecture"] = " + "
250 | _handle_model_choice_inputs(ctx, inputs, chosen_task)
251 |
252 |
253 | class ZeroShotTasks(foo.Operator):
254 | @property
255 | def config(self):
256 | _config = foo.OperatorConfig(
257 | name="zero_shot_predict",
258 | label="Perform Zero Shot Prediction",
259 | dynamic=True,
260 | )
261 | _config.icon = "/assets/icon.svg"
262 | return _config
263 |
264 | def resolve_delegation(self, ctx):
265 | return ctx.params.get("delegate", False)
266 |
267 | def resolve_input(self, ctx):
268 | inputs = types.Object()
269 |
270 | radio_choices = types.RadioGroup()
271 | radio_choices.add_choice("classification", label="Classification")
272 | radio_choices.add_choice("detection", label="Detection")
273 | radio_choices.add_choice(
274 | "instance_segmentation", label="Instance Segmentation"
275 | )
276 | radio_choices.add_choice(
277 | "semantic_segmentation", label="Semantic Segmentation"
278 | )
279 | inputs.enum(
280 | "task_choices",
281 | radio_choices.values(),
282 | default=radio_choices.choices[0].value,
283 | label="Zero Shot Task",
284 | view=radio_choices,
285 | )
286 |
287 | chosen_task = ctx.params.get("task_choices", "classification")
288 | handle_model_choice_inputs(ctx, inputs, chosen_task)
289 | active_models = _get_active_models(chosen_task)
290 |
291 | if chosen_task in ["detection", "instance_segmentation"]:
292 | inputs.float(
293 | "confidence",
294 | label="Confidence Threshold",
295 | default=0.2,
296 | description="The minimum confidence required for a prediction to be included",
297 | )
298 |
299 | label_input_choices = types.RadioGroup()
300 | label_input_choices.add_choice("direct", label="Input directly")
301 | label_input_choices.add_choice("file", label="Input from file")
302 | inputs.enum(
303 | "label_input_choices",
304 | label_input_choices.values(),
305 | default=label_input_choices.choices[0].value,
306 | label="Labels",
307 | view=label_input_choices,
308 | )
309 |
310 | if ctx.params.get("label_input_choices", "direct") == "direct":
311 | inputs.str(
312 | "labels",
313 | label="Labels",
314 | description="Enter the names of the classes you wish to generate predictions for, separated by commas",
315 | required=True,
316 | )
317 | else:
318 | labels_file = types.FileView(label="Labels File")
319 | inputs.str(
320 | "labels_file",
321 | label="Labels File",
322 | required=True,
323 | view=labels_file,
324 | )
325 |
326 | model_name = ctx.params.get(
327 | f"model_choice_{chosen_task}", active_models[0]
328 | )
329 | model_name = model_name.split("(")[0].strip().replace("-", "").lower()
330 | inputs.str(
331 | f"label_field_{chosen_task}_{model_name}",
332 | label="Label Field",
333 | default=_model_name_to_field_name(model_name),
334 | description="The field to store the predicted labels in",
335 | required=True,
336 | )
337 | _execution_mode(ctx, inputs)
338 | inputs.view_target(ctx)
339 | return types.Property(inputs)
340 |
341 | def execute(self, ctx):
342 | view = ctx.target_view()
343 | task = ctx.params.get("task_choices", "classification")
344 | active_models = _get_active_models(task)
345 | model_name = ctx.params.get(f"model_choice_{task}", active_models[0])
346 | mn = model_name.split("(")[0].strip().lower().replace("-", "")
347 | if task == "instance_segmentation":
348 | model_name = (
349 | ctx.params[f"model_choice_detection"] + " + " + model_name
350 | )
351 | categories = _get_labels(ctx)
352 | label_field = ctx.params.get(f"label_field_{task}_{mn}", mn)
353 | architecture = ctx.params.get("architecture", None)
354 | pretrained = ctx.params.get("pretrained", None)
355 | confidence = ctx.params.get("confidence", 0.2)
356 |
357 | run_zero_shot_task(
358 | view,
359 | task,
360 | model_name,
361 | label_field,
362 | categories,
363 | architecture,
364 | pretrained,
365 | confidence=confidence,
366 | )
367 | ctx.ops.reload_dataset()
368 |
369 |
370 | ### Common input control flow for all tasks
371 | def _input_control_flow(ctx, task):
372 | inputs = types.Object()
373 | active_models = _get_active_models(task)
374 | if len(active_models) == 0:
375 | inputs.str(
376 | "no_models_warning",
377 | view=types.Warning(
378 | label=f"No Models Found",
379 | description="No models were found for the selected task. Please install the required libraries.",
380 | ),
381 | )
382 | return types.Property(inputs)
383 |
384 | handle_model_choice_inputs(ctx, inputs, task)
385 |
386 | label_input_choices = types.RadioGroup()
387 | label_input_choices.add_choice("direct", label="Input directly")
388 | label_input_choices.add_choice("file", label="Input from file")
389 | inputs.enum(
390 | "label_input_choices",
391 | label_input_choices.values(),
392 | default=label_input_choices.choices[0].value,
393 | label="Labels",
394 | view=label_input_choices,
395 | )
396 |
397 | if task in ["detection", "instance_segmentation"]:
398 | inputs.float(
399 | "confidence",
400 | label="Confidence Threshold",
401 | default=0.2,
402 | description="The minimum confidence required for a prediction to be included",
403 | )
404 |
405 | if ctx.params.get("label_input_choices", False) == "direct":
406 | inputs.str(
407 | "labels",
408 | label="Labels",
409 | description="Enter the names of the classes you wish to generate predictions for, separated by commas",
410 | required=True,
411 | )
412 | else:
413 | labels_file = types.FileView(label="Labels File")
414 | inputs.str(
415 | "labels_file",
416 | label="Labels File",
417 | required=True,
418 | view=labels_file,
419 | )
420 |
421 | model_name = ctx.params.get(f"model_choice_{task}", active_models[0])
422 | mn = model_name.split("(")[0].strip().lower().replace("-", "")
423 | inputs.str(
424 | f"label_field_{mn}",
425 | label="Label Field",
426 | default=_model_name_to_field_name(model_name),
427 | description="The field to store the predicted labels in",
428 | required=True,
429 | )
430 | _execution_mode(ctx, inputs)
431 | inputs.view_target(ctx)
432 | return inputs
433 |
434 |
435 | def _execute_control_flow(ctx, task):
436 | view = ctx.target_view()
437 | model_name = ctx.params.get(f"model_choice_{task}", "CLIP")
438 | mn = _model_name_to_field_name(model_name).split("(")[0].strip().lower()
439 | label_field = ctx.params.get(f"label_field_{mn}", mn)
440 | if task == "instance_segmentation":
441 | model_name = ctx.params[f"model_choice_detection"] + " + " + model_name
442 |
443 | kwargs = {}
444 | if task in ["detection", "instance_segmentation"]:
445 | kwargs["confidence"] = ctx.params.get("confidence", 0.2)
446 | categories = _get_labels(ctx)
447 |
448 | architecture = ctx.params.get("architecture", None)
449 | pretrained = ctx.params.get("pretrained", None)
450 |
451 | run_zero_shot_task(
452 | view,
453 | task,
454 | model_name,
455 | label_field,
456 | categories,
457 | architecture,
458 | pretrained,
459 | **kwargs,
460 | )
461 | ctx.ops.reload_dataset()
462 |
463 |
464 | NAME_TO_TASK = {
465 | "zero_shot_classify": "classification",
466 | "zero_shot_detect": "detection",
467 | "zero_shot_instance_segment": "instance_segmentation",
468 | "zero_shot_semantic_segment": "semantic_segmentation",
469 | }
470 |
471 |
472 | def _format_model_name(model_name):
473 | return (
474 | model_name.lower().replace(" ", "").replace("_", "").replace("-", "")
475 | )
476 |
477 |
478 | def _match_model_name(model_name, model_names):
479 | for name in model_names:
480 | if _format_model_name(name) == _format_model_name(model_name):
481 | return name
482 | raise ValueError(
483 | f"Model name {model_name} not found. Use one of {model_names}"
484 | )
485 |
486 |
487 | def _resolve_model_name(task, model_name):
488 | if model_name is None:
489 | return list(MODEL_LISTS[task].keys())[0]
490 | elif model_name not in MODEL_LISTS[task]:
491 | return _match_model_name(model_name, list(MODEL_LISTS[task].keys()))
492 | return model_name
493 |
494 |
495 | def _resolve_labels(labels, labels_file):
496 | if labels is None and labels_file is None:
497 | raise ValueError("Must provide either labels or labels_file")
498 |
499 | if labels is not None and labels_file is not None:
500 | raise ValueError("Cannot provide both labels and labels_file")
501 |
502 | if labels is not None and type(labels) == list:
503 | labels = ", ".join(labels)
504 | else:
505 | with open(labels_file, "r") as f:
506 | labels = [label.strip() for label in f.readlines()]
507 | labels = ", ".join(labels)
508 |
509 | return labels
510 |
511 |
512 | def _resolve_label_field(model_name, label_field):
513 | if label_field is None:
514 | label_field = _model_name_to_field_name(model_name)
515 | label_field_name = f"label_field_{_model_name_to_field_name(model_name)}"
516 |
517 | return label_field, label_field_name
518 |
519 |
520 | def _handle_calling(
521 | uri,
522 | sample_collection,
523 | model_name=None,
524 | labels=None,
525 | labels_file=None,
526 | label_field=None,
527 | delegate=False,
528 | confidence=None,
529 | ):
530 | ctx = dict(view=sample_collection.view())
531 |
532 | task = NAME_TO_TASK[uri.split("/")[-1]]
533 |
534 | model_name = _resolve_model_name(task, model_name)
535 | labels = _resolve_labels(labels, labels_file)
536 |
537 | label_field, label_field_name = _resolve_label_field(
538 | model_name, label_field
539 | )
540 |
541 | params = dict(
542 | label_input_choices="direct",
543 | delegate=delegate,
544 | labels=labels,
545 | )
546 | if confidence is not None:
547 | params["confidence"] = confidence
548 | params[label_field_name] = label_field
549 | params[f"model_choice_{task}"] = model_name
550 |
551 | return foo.execute_operator(uri, ctx, params=params)
552 |
553 |
554 | class ZeroShotClassify(foo.Operator):
555 | @property
556 | def config(self):
557 | _config = foo.OperatorConfig(
558 | name="zero_shot_classify",
559 | label="Perform Zero Shot Classification",
560 | dynamic=True,
561 | )
562 | _config.icon = "/assets/icon.svg"
563 | return _config
564 |
565 | def resolve_delegation(self, ctx):
566 | return ctx.params.get("delegate", False)
567 |
568 | def resolve_input(self, ctx):
569 | inputs = _input_control_flow(ctx, "classification")
570 | return types.Property(inputs)
571 |
572 | def execute(self, ctx):
573 | _execute_control_flow(ctx, "classification")
574 |
575 | def __call__(
576 | self,
577 | sample_collection,
578 | model_name=None,
579 | labels=None,
580 | labels_file=None,
581 | label_field=None,
582 | delegate=False,
583 | ):
584 | return _handle_calling(
585 | self.uri,
586 | sample_collection,
587 | model_name=model_name,
588 | labels=labels,
589 | labels_file=labels_file,
590 | label_field=label_field,
591 | delegate=delegate,
592 | )
593 |
594 | def list_models(self):
595 | return list(MODEL_LISTS["classification"].keys())
596 |
597 |
598 | class ZeroShotDetect(foo.Operator):
599 | @property
600 | def config(self):
601 | _config = foo.OperatorConfig(
602 | name="zero_shot_detect",
603 | label="Perform Zero Shot Detection",
604 | dynamic=True,
605 | )
606 | _config.icon = "/assets/icon.svg"
607 | return _config
608 |
609 | def resolve_delegation(self, ctx):
610 | return ctx.params.get("delegate", False)
611 |
612 | def resolve_input(self, ctx):
613 | inputs = _input_control_flow(ctx, "detection")
614 | inputs.float(
615 | "confidence",
616 | label="Confidence Threshold",
617 | default=0.2,
618 | description="The minimum confidence required for a prediction to be included",
619 | )
620 | return types.Property(inputs)
621 |
622 | def execute(self, ctx):
623 | _execute_control_flow(ctx, "detection")
624 |
625 | def __call__(
626 | self,
627 | sample_collection,
628 | model_name=None,
629 | labels=None,
630 | labels_file=None,
631 | label_field=None,
632 | delegate=False,
633 | confidence=0.2,
634 | ):
635 | return _handle_calling(
636 | self.uri,
637 | sample_collection,
638 | model_name=model_name,
639 | labels=labels,
640 | labels_file=labels_file,
641 | label_field=label_field,
642 | delegate=delegate,
643 | confidence=confidence,
644 | )
645 |
646 | def list_models(self):
647 | return list(MODEL_LISTS["detection"].keys())
648 |
649 |
650 | class ZeroShotInstanceSegment(foo.Operator):
651 | @property
652 | def config(self):
653 | _config = foo.OperatorConfig(
654 | name="zero_shot_instance_segment",
655 | label="Perform Zero Shot Instance Segmentation",
656 | dynamic=True,
657 | )
658 | _config.icon = "/assets/icon.svg"
659 | return _config
660 |
661 | def resolve_delegation(self, ctx):
662 | return ctx.params.get("delegate", False)
663 |
664 | def resolve_input(self, ctx):
665 | inputs = _input_control_flow(ctx, "instance_segmentation")
666 | return types.Property(inputs)
667 |
668 | def execute(self, ctx):
669 | _execute_control_flow(ctx, "instance_segmentation")
670 |
671 | def __call__(
672 | self,
673 | sample_collection,
674 | model_name=None,
675 | labels=None,
676 | labels_file=None,
677 | label_field=None,
678 | delegate=False,
679 | confidence=0.2,
680 | ):
681 | return _handle_calling(
682 | self.uri,
683 | sample_collection,
684 | model_name=model_name,
685 | labels=labels,
686 | labels_file=labels_file,
687 | label_field=label_field,
688 | delegate=delegate,
689 | confidence=confidence,
690 | )
691 |
692 | def list_models(self):
693 | return list(MODEL_LISTS["instance_segmentation"].keys())
694 |
695 |
696 | class ZeroShotSemanticSegment(foo.Operator):
697 | @property
698 | def config(self):
699 | _config = foo.OperatorConfig(
700 | name="zero_shot_semantic_segment",
701 | label="Perform Zero Shot Semantic Segmentation",
702 | dynamic=True,
703 | )
704 | _config.icon = "/assets/icon.svg"
705 | return _config
706 |
707 | def resolve_delegation(self, ctx):
708 | return ctx.params.get("delegate", False)
709 |
710 | def resolve_input(self, ctx):
711 | inputs = _input_control_flow(ctx, "semantic_segmentation")
712 | return types.Property(inputs)
713 |
714 | def execute(self, ctx):
715 | _execute_control_flow(ctx, "semantic_segmentation")
716 |
717 | def __call__(
718 | self,
719 | sample_collection,
720 | model_name=None,
721 | labels=None,
722 | labels_file=None,
723 | label_field=None,
724 | delegate=False,
725 | ):
726 | return _handle_calling(
727 | self.uri,
728 | sample_collection,
729 | model_name=model_name,
730 | labels=labels,
731 | labels_file=labels_file,
732 | label_field=label_field,
733 | delegate=delegate,
734 | )
735 |
736 | def list_models(self):
737 | return list(MODEL_LISTS["semantic_segmentation"].keys())
738 |
739 |
740 | def register(plugin):
741 | plugin.register(ZeroShotTasks)
742 | plugin.register(ZeroShotClassify)
743 | plugin.register(ZeroShotDetect)
744 | plugin.register(ZeroShotInstanceSegment)
745 | plugin.register(ZeroShotSemanticSegment)
746 |
--------------------------------------------------------------------------------
/assets/icon.svg:
--------------------------------------------------------------------------------
1 |
6 |
--------------------------------------------------------------------------------
/classification.py:
--------------------------------------------------------------------------------
1 | """Zero Shot Classification.
2 |
3 | | Copyright 2017-2023, Voxel51, Inc.
4 | | `voxel51.com <https://voxel51.com/>`_
5 | |
6 | """
7 | from importlib.util import find_spec
8 | import numpy as np
9 | from PIL import Image
10 | import torch
11 |
12 | import fiftyone as fo
13 | from fiftyone.core.models import Model
14 | import fiftyone.zoo as foz
15 |
16 | ### Make tuples in form ("pretrained", "clip_model")
17 |
18 | OPENAI_ARCHS = ["ViT-B-32", "ViT-B-16", "ViT-L-14"]
19 | OPENAI_CLIP_MODELS = [("openai", model) for model in OPENAI_ARCHS]
20 |
21 | DFN_CLIP_MODELS = [("dfn2b", "ViT-B-16")]
22 |
23 | META_ARCHS = ("ViT-B-16-quickgelu", "ViT-B-32-quickgelu", "ViT-L-14-quickgelu")
24 | META_PRETRAINS = ("metaclip_400m", "metaclip_fullcc")
25 | META_CLIP_MODELS = [
26 | (pretrain, arch) for pretrain in META_PRETRAINS for arch in META_ARCHS
27 | ]
28 | META_CLIP_MODELS.append(("metaclip_fullcc", "ViT-H-14-quickgelu"))
29 |
30 | CLIPA_MODELS = [("", "hf-hub:UCSC-VLAA/ViT-L-14-CLIPA-datacomp1B")]
31 |
32 | SIGLIP_ARCHS = (
33 | "ViT-B-16-SigLIP",
34 | "ViT-B-16-SigLIP-256",
35 | "ViT-B-16-SigLIP-384",
36 | "ViT-L-16-SigLIP-256",
37 | "ViT-L-16-SigLIP-384",
38 | "ViT-SO400M-14-SigLIP",
39 | "ViT-SO400M-14-SigLIP-384",
40 | )
41 | SIGLIP_MODELS = [("", arch) for arch in SIGLIP_ARCHS]
42 |
43 | EVA_CLIP_MODELS = [
44 | ("merged2b_s8b_b131k", "EVA02-B-16"),
45 | ("merged2b_s6b_b61k", "EVA02-L-14-336"),
46 | ("merged2b_s4b_b131k", "EVA02-L-14"),
47 | ]
48 |
49 | AIMV2_MODELS = ["aimv2-large-patch14-224-lit"]
50 |
51 | def get_device():
52 | """Helper function to determine the best available device."""
53 | if torch.cuda.is_available():
54 | device = "cuda"
55 | print(f"Using CUDA device: {torch.cuda.get_device_name(0)}")
56 | elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
57 | device = "mps"
58 | print("Using Apple Silicon (MPS) device")
59 | else:
60 | device = "cpu"
61 | print("Using CPU device")
62 | return device
63 |
64 |
65 | def CLIPZeroShotModel(config):
66 | """
67 | This function loads a zero-shot classification model using the CLIP architecture.
68 | It utilizes the FiftyOne Zoo to load a pre-trained model based on the provided
69 | configuration.
70 |
71 | Args:
72 | config (dict): A dictionary containing configuration parameters for the model.
73 | - categories (list, optional): A list of categories for classification.
74 | - clip_model (str, optional): The architecture of the CLIP model to use.
75 | - pretrained (str, optional): The pre-trained weights to use.
76 |
77 | Returns:
78 | Model: A loaded CLIP zero-shot classification model ready for inference.
79 | """
80 | cats = config.get("categories", None)
81 | clip_model = config.get("clip_model", "ViT-B-32")
82 | pretrained = config.get("pretrained", "openai")
83 |
84 | model = foz.load_zoo_model(
85 | "clip-vit-base32-torch",
86 | clip_model=clip_model,
87 | pretrained=pretrained,
88 | text_prompt="A photo of a",
89 | classes=cats,
90 | )
91 |
92 | return model
93 |
94 |
95 | def CLIP_activator():
96 | """
97 | Determines if the CLIP model can be activated.
98 |
99 | This function checks for the availability of the necessary
100 | components to activate the CLIP model. It returns True if
101 | the model can be activated, otherwise False.
102 |
103 | Returns:
104 | bool: True if the CLIP model can be activated, False otherwise.
105 | """
106 | return True
107 |
108 | class AltCLIPZeroShotModel(Model):
109 | """
110 | This class implements a zero-shot classification model using the AltCLIP architecture.
111 | It leverages the AltCLIP model from the Hugging Face Transformers library to perform
112 | image classification without requiring task-specific training data.
113 |
114 | Args:
115 | config (dict): A dictionary containing configuration parameters for the model.
116 | - categories (list, optional): A list of categories for classification.
117 |
118 | Attributes:
119 | categories (list): The list of categories for classification.
120 | candidate_labels (list): A list of text prompts for each category.
121 | model (AltCLIPModel): The pre-trained AltCLIP model.
122 | processor (AltCLIPProcessor): The processor for preparing inputs for the model.
123 |
124 | Methods:
125 | media_type: Returns the type of media the model is designed to process.
126 | _predict(image): Performs prediction on a single image.
127 | predict(args): Converts input data to an image and performs prediction.
128 | _predict_all(images): Performs prediction on a list of images.
129 | """
130 | def __init__(self, config):
131 | """
132 | Initializes the AltCLIPZeroShotModel with the given configuration.
133 |
134 | This constructor sets up the model by initializing the categories
135 | and candidate labels for classification. It also loads the pre-trained
136 | AltCLIP model and processor from the Hugging Face Transformers library.
137 |
138 | Args:
139 | config (dict): A dictionary containing configuration parameters for the model.
140 | - categories (list, optional): A list of categories for classification.
141 | """
142 | self.categories = config.get("categories", None)
143 | self.candidate_labels = [
144 | f"a photo of a {cat}" for cat in self.categories
145 | ]
146 |
147 | from transformers import AltCLIPModel, AltCLIPProcessor
148 |
149 | self.model = AltCLIPModel.from_pretrained("BAAI/AltCLIP")
150 | self.processor = AltCLIPProcessor.from_pretrained("BAAI/AltCLIP")
151 |
152 | # Set up device
153 | self.device = get_device()
154 |
155 | # Move model to appropriate device and set to eval mode
156 | self.model = self.model.to(self.device)
157 | self.model.eval()
158 |
159 | @property
160 | def media_type(self):
161 | return "image"
162 |
163 | def _predict(self, image):
164 | """
165 | Performs prediction on a single image.
166 |
167 | This method processes the input image using the AltCLIPProcessor
168 | and performs a forward pass through the AltCLIPModel to obtain
169 | classification probabilities for each category. It returns a
170 | FiftyOne Classification object containing the predicted label,
171 | logits, and confidence score.
172 |
173 | Args:
174 | image (PIL.Image): The input image to classify.
175 |
176 | Returns:
177 | fo.Classification: The classification result containing the
178 | predicted label, logits, and confidence score.
179 | """
180 | inputs = self.processor(
181 | text=self.candidate_labels,
182 | images=image,
183 | return_tensors="pt",
184 | padding=True
185 | )
186 |
187 | inputs = inputs.to(self.device)
188 |
189 | with torch.no_grad():
190 | outputs = self.model(**inputs)
191 |
192 | logits_per_image = outputs.logits_per_image
193 |
194 | # Move to CPU only if necessary
195 | if logits_per_image.device.type != 'cpu':
196 | logits_per_image = logits_per_image.cpu()
197 |
198 | probs = logits_per_image.softmax(dim=1).numpy()
199 |
200 | return fo.Classification(
201 | label=self.categories[probs.argmax()],
202 | logits=logits_per_image.squeeze().numpy(),
203 | confidence=np.amax(probs[0]),
204 | )
205 |
206 | def predict(self, args):
207 | """
208 | Converts input data to an image and performs prediction.
209 |
210 | This method takes input data, converts it into a PIL image,
211 | and then uses the `_predict` method to perform classification.
212 | It returns the prediction results as a FiftyOne Classification object.
213 |
214 | Args:
215 | args (numpy.ndarray): The input data to be converted into an image.
216 |
217 | Returns:
218 | fo.Classification: The classification result containing the
219 | predicted label, logits, and confidence score.
220 | """
221 | image = Image.fromarray(args)
222 | predictions = self._predict(image)
223 | return predictions
224 |
225 | def _predict_all(self, images):
226 | return [self._predict(image) for image in images]
227 |
228 |
229 | def AltCLIP_activator():
230 | return find_spec("transformers") is not None
231 |
232 |
233 | class AlignZeroShotModel(Model):
234 | """
235 | AlignZeroShotModel is a class for zero-shot image classification using the Align model.
236 |
237 | This class leverages the Align model from the `transformers` library to perform
238 | zero-shot classification on images. It initializes with a configuration that
239 | specifies the categories for classification. The model processes input images
240 | and predicts the most likely category from the provided list.
241 |
242 | Attributes:
243 | categories (list): A list of category labels for classification.
244 | candidate_labels (list): A list of formatted labels for the Align model.
245 | processor (AlignProcessor): The processor for preparing inputs for the model.
246 | model (AlignModel): The pre-trained Align model for classification.
247 |
248 | Methods:
249 | media_type: Returns the type of media the model works with, which is "image".
250 | _predict(image): Performs prediction on a single image.
251 | predict(args): Converts input data to an image and performs prediction.
252 | _predict_all(images): Performs prediction on a list of images.
253 | """
254 | def __init__(self, config):
255 | self.categories = config.get("categories", None)
256 | self.candidate_labels = [
257 | f"a photo of a {cat}" for cat in self.categories
258 | ]
259 |
260 | from transformers import AlignProcessor, AlignModel
261 |
262 | self.processor = AlignProcessor.from_pretrained("kakaobrain/align-base")
263 | self.model = AlignModel.from_pretrained("kakaobrain/align-base")
264 |
265 | # Set up device
266 | self.device = get_device()
267 |
268 | # Move model to appropriate device and set to eval mode
269 | self.model = self.model.to(self.device)
270 | self.model.eval()
271 |
272 | @property
273 | def media_type(self):
274 | return "image"
275 |
276 | def _predict(self, image):
277 | """
278 | Performs prediction on a single image.
279 |
280 | This method takes an image as input and processes it using the Align model
281 | to predict the most likely category from the pre-defined list of categories.
282 | It uses the processor to prepare the input and the model to generate predictions.
283 | The method returns a `fiftyone.core.labels.Classification` object containing
284 | the predicted label, logits, and confidence score.
285 |
286 | Args:
287 | image (PIL.Image.Image): The input image for classification.
288 |
289 | Returns:
290 | fiftyone.core.labels.Classification: The classification result with the
291 | predicted label, logits, and confidence score.
292 | """
293 | inputs = self.processor(
294 | text=self.candidate_labels,
295 | images=image,
296 | return_tensors="pt"
297 | )
298 |
299 | inputs = inputs.to(self.device)
300 |
301 | with torch.no_grad():
302 | outputs = self.model(**inputs)
303 |
304 | logits_per_image = outputs.logits_per_image
305 |
306 | # Move to CPU only if necessary
307 | if logits_per_image.device.type != 'cpu':
308 | logits_per_image = logits_per_image.cpu()
309 |
310 | probs = logits_per_image.softmax(dim=1).numpy()
311 |
312 | return fo.Classification(
313 | label=self.categories[probs.argmax()],
314 | logits=logits_per_image.squeeze().numpy(),
315 | confidence=np.amax(probs[0]),
316 | )
317 |
318 | def predict(self, args):
319 | """
320 | Predicts the category of the given image.
321 |
322 | This method takes an image in the form of a numpy array, converts it
323 | to a PIL Image, and then uses the `_predict` method to classify the
324 | image. The classification result is returned as a
325 | `fiftyone.core.labels.Classification` object.
326 |
327 | Args:
328 | args (np.ndarray): The input image as a numpy array.
329 |
330 | Returns:
331 | fiftyone.core.labels.Classification: The classification result with
332 | the predicted label, logits, and confidence score.
333 | """
334 | image = Image.fromarray(args)
335 | predictions = self._predict(image)
336 | return predictions
337 |
338 | def _predict_all(self, images):
339 | return [self._predict(image) for image in images]
340 |
341 |
342 | def Align_activator():
343 | return find_spec("transformers") is not None
344 |
345 |
346 | def OpenCLIPZeroShotModel(config):
347 | """
348 | Initializes and returns an OpenCLIP zero-shot model based on the provided configuration.
349 |
350 | This function loads a pre-trained OpenCLIP model using the specified configuration
351 | parameters. The model is initialized with a text prompt and a set of categories
352 | for zero-shot classification tasks.
353 |
354 | Args:
355 | config (dict): A dictionary containing configuration parameters for the model.
356 | - "categories" (list, optional): A list of category names for classification.
357 | - "clip_model" (str, optional): The name of the CLIP model architecture to use.
358 | Defaults to "ViT-B-32".
359 | - "pretrained" (str, optional): The name of the pre-trained weights to load.
360 | Defaults to "openai".
361 |
362 | Returns:
363 | An instance of the OpenCLIP model configured for zero-shot classification.
364 | """
365 | cats = config.get("categories", None)
366 | clip_model = config.get("clip_model", "ViT-B-32")
367 | pretrained = config.get("pretrained", "openai")
368 |
369 | model = foz.load_zoo_model(
370 | "open-clip-torch",
371 | clip_model=clip_model,
372 | pretrained=pretrained,
373 | text_prompt="A photo of a",
374 | classes=cats,
375 | )
376 |
377 | return model
378 |
379 | def OpenCLIP_activator():
380 | return find_spec("open_clip") is not None
381 |
382 | class AIMV2ZeroShotModel(Model):
383 | """Zero-shot image classification model using Apple's AIM-V2.
384 |
385 | AIM-V2 (Apple Image Models V2) are vision-language models from Apple that achieve
386 | state-of-the-art performance on various vision tasks.
387 |
388 | Available models:
389 | - apple/aimv2-large-patch14-224-lit: LiT-tuned variant
390 |
391 | Args:
392 | config (dict): Configuration dictionary containing:
393 | - categories (list): List of category labels for classification
394 | - model_name (str, optional): Full model name including organization.
395 | Defaults to "apple/aimv2-large-patch14-224-lit"
396 |
397 | Attributes:
398 | categories (list): Available classification categories
399 | candidate_labels (list): Text prompts generated from categories
400 | model (AutoModel): The underlying AIM-V2 model
401 | processor (AutoProcessor): Processor for preparing inputs
402 | """
403 |
404 | def __init__(self, config):
405 | self.categories = config.get("categories", None)
406 | if self.categories is None:
407 | raise ValueError("Categories must be provided in config")
408 |
409 | self.candidate_labels = [
410 | f"Picture of a {cat}." for cat in self.categories
411 | ]
412 |
413 | model_name = config.get(
414 | "model_name",
415 | "apple/aimv2-large-patch14-224-lit"
416 | )
417 |
418 | from transformers import AutoProcessor, AutoModel
419 |
420 | # Set up device
421 | self.device = "cuda" if torch.cuda.is_available() else "cpu"
422 |
423 | # Initialize model and processor
424 | self.processor = AutoProcessor.from_pretrained(
425 | model_name,
426 | trust_remote_code=True
427 | )
428 |
429 | self.model = AutoModel.from_pretrained(
430 | model_name,
431 | trust_remote_code=True
432 | )
433 |
434 | # Move model to appropriate device and set to eval mode
435 | self.model = self.model.to(self.device)
436 | self.model.eval()
437 |
438 | @property
439 | def media_type(self):
440 | """The type of media handled by this model.
441 |
442 | Returns:
443 | str: Always returns 'image'
444 | """
445 | return "image"
446 |
447 | def _predict(self, image):
448 | """Internal prediction method for a single image.
449 |
450 | Args:
451 | image (PIL.Image): Input image to classify
452 |
453 | Returns:
454 | fo.Classification: Classification result containing:
455 | - label: Predicted category
456 | - logits: Raw model outputs
457 | - confidence: Prediction confidence score
458 | """
459 | inputs = self.processor(
460 | text=self.candidate_labels,
461 | images=image,
462 | add_special_tokens=True,
463 | truncation=True,
464 | padding=True,
465 | return_tensors="pt"
466 | )
467 |
468 | inputs = inputs.to(self.device)
469 |
470 | with torch.no_grad():
471 | outputs = self.model(**inputs)
472 |
473 | logits_per_image = outputs.logits_per_image
474 |
475 | # Move to CPU only if necessary
476 | if logits_per_image.device.type != 'cpu':
477 | logits_per_image = logits_per_image.cpu()
478 |
479 | probs = logits_per_image.softmax(dim=-1).numpy()
480 |
481 | return fo.Classification(
482 | label=self.categories[probs.argmax()],
483 | logits=logits_per_image.squeeze().numpy(),
484 | confidence=np.amax(probs[0]),
485 | )
486 |
487 | def predict(self, args):
488 | """Public prediction interface for numpy array input.
489 |
490 | Args:
491 | args (np.ndarray): Input image as numpy array
492 |
493 | Returns:
494 | fo.Classification: Classification result
495 | """
496 | image = Image.fromarray(args)
497 | predictions = self._predict(image)
498 | return predictions
499 |
500 | def _predict_all(self, images):
501 | """Batch prediction for multiple images.
502 |
503 | Args:
504 | images (list): List of images to classify
505 |
506 | Returns:
507 | list: List of fo.Classification results
508 | """
509 | return [self._predict(image) for image in images]
510 |
511 |
512 | def AIMV2_activator():
513 | """Check if required dependencies for AIM-V2 are available."""
514 | try:
515 | from transformers import AutoProcessor, AutoModel
516 | return True
517 | except ImportError:
518 | return False
519 |
520 |
521 | OPEN_CLIP_MODEL_TYPES = {
522 | "CLIPA": CLIPA_MODELS,
523 | "DFN CLIP": DFN_CLIP_MODELS,
524 | "EVA-CLIP": EVA_CLIP_MODELS,
525 | "MetaCLIP": META_CLIP_MODELS,
526 | "SigLIP": SIGLIP_MODELS,
527 | }
528 |
529 |
530 | def build_classification_models_dict():
531 | """
532 | Builds a dictionary of classification models available for use.
533 |
534 | This function constructs a dictionary where each key is a string representing
535 | the name of a classification model type, and the value is a dictionary containing
536 | the following keys:
537 | - "activator": A function that checks if the model's dependencies are available.
538 | - "model": A function that initializes and returns the model.
539 | - "submodels": Additional model configurations or variants, if any.
540 | - "name": The display name of the model.
541 |
542 | Returns:
543 | dict: A dictionary of available classification models.
544 | """
545 | cms = {}
546 |
547 | # Add CLIP (OpenAI) if available
548 | if OpenCLIP_activator():
549 | cms["CLIP (OpenAI)"] = {
550 | "activator": CLIP_activator,
551 | "model": CLIPZeroShotModel,
552 | "submodels": None,
553 | "name": "CLIP (OpenAI)",
554 | }
555 |
556 | # Add ALIGN if available
557 | if Align_activator():
558 | cms["ALIGN"] = {
559 | "activator": Align_activator,
560 | "model": AlignZeroShotModel,
561 | "submodels": None,
562 | "name": "ALIGN",
563 | }
564 |
565 | # Add AltCLIP if available
566 | if AltCLIP_activator():
567 | cms["AltCLIP"] = {
568 | "activator": AltCLIP_activator,
569 | "model": AltCLIPZeroShotModel,
570 | "submodels": None,
571 | "name": "AltCLIP",
572 | }
573 |
574 | # Add Apple AIMv2 if available
575 | if AIMV2_activator():
576 | cms["AIMv2"] = {
577 | "activator": AIMV2_activator,
578 | "model": AIMV2ZeroShotModel,
579 | "submodels": None,
580 | "name": "AIMv2",
581 | }
582 |
583 | # Add OpenCLIP models if available
584 | for key, value in OPEN_CLIP_MODEL_TYPES.items():
585 | cms[key] = {
586 | "activator": OpenCLIP_activator,
587 | "model": OpenCLIPZeroShotModel,
588 | "submodels": value,
589 | "name": key,
590 | }
591 |
592 | return cms
593 |
594 | CLASSIFICATION_MODELS = build_classification_models_dict()
595 |
596 |
597 | def _get_model(model_name, config):
598 | return CLASSIFICATION_MODELS[model_name]["model"](config)
599 |
600 |
601 | def run_zero_shot_classification(
602 | dataset,
603 | model_name,
604 | label_field,
605 | categories,
606 | architecture=None,
607 | pretrained=None,
608 | **kwargs,
609 | ):
610 | config = {
611 | "categories": categories,
612 | "clip_model": architecture,
613 | "pretrained": pretrained,
614 | }
615 |
616 | model = _get_model(model_name, config)
617 |
618 | dataset.apply_model(model, label_field=label_field)
--------------------------------------------------------------------------------
/detection.py:
--------------------------------------------------------------------------------
1 | """Zero Shot Detection.
2 |
3 | | Copyright 2017-2023, Voxel51, Inc.
4 | | `voxel51.com <https://voxel51.com/>`_
5 | |
6 | """
7 |
8 | from importlib.util import find_spec
9 | import pkg_resources
10 |
11 | from PIL import Image
12 |
13 | import fiftyone as fo
14 | import fiftyone.zoo as foz
15 | from fiftyone.core.models import Model
16 |
17 | YOLO_WORLD_PRETRAINS = (
18 | "yolov8s-world",
19 | "yolov8s-worldv2",
20 | "yolov8m-world",
21 | "yolov8m-worldv2",
22 | "yolov8l-world",
23 | "yolov8l-worldv2",
24 | "yolov8x-world",
25 | "yolov8x-worldv2",
26 | )
27 |
28 |
29 | class OwlViTZeroShotModel(Model):
30 | def __init__(self, config):
31 | self.checkpoint = "google/owlvit-base-patch32"
32 |
33 | self.candidate_labels = config.get("categories", None)
34 |
35 | from transformers import pipeline
36 |
37 | self.model = pipeline(
38 | model=self.checkpoint, task="zero-shot-object-detection"
39 | )
40 |
41 | @property
42 | def media_type(self):
43 | return "image"
44 |
45 | def predict(self, args):
46 | image = Image.fromarray(args)
47 | predictions = self._predict(image)
48 | return predictions
49 |
50 | def _predict(self, image):
51 | raw_predictions = self.model(
52 | image, candidate_labels=self.candidate_labels
53 | )
54 |
55 | size = image.size
56 | w, h = size[0], size[1]
57 |
58 | detections = []
59 | for prediction in raw_predictions:
60 | score, box = prediction["score"], prediction["box"]
61 | bounding_box = [
62 | box["xmin"] / w,
63 | box["ymin"] / h,
64 | box["xmax"] / w,
65 | box["ymax"] / h,
66 | ]
67 | ### constrain bounding box to [0, 1]
68 | bounding_box[0] = max(0, bounding_box[0])
69 | bounding_box[1] = max(0, bounding_box[1])
70 | bounding_box[2] = min(1, bounding_box[2])
71 | bounding_box[3] = min(1, bounding_box[3])
72 |
73 | ### convert to (x, y, w, h)
74 | bounding_box[2] = bounding_box[2] - bounding_box[0]
75 | bounding_box[3] = bounding_box[3] - bounding_box[1]
76 |
77 | label = prediction["label"]
78 |
79 | detection = fo.Detection(
80 | label=label,
81 | bounding_box=bounding_box,
82 | confidence=score,
83 | )
84 | detections.append(detection)
85 |
86 | return fo.Detections(detections=detections)
87 |
88 | def predict_all(self, samples, args):
89 | pass
90 |
91 |
92 | def OwlViT_activator():
93 | return find_spec("transformers") is not None
94 |
95 |
96 | def GroundingDINO(config):
97 | classes = config.get("categories", None)
98 | model = foz.load_zoo_model(
99 | "zero-shot-detection-transformer-torch",
100 | name_or_path="IDEA-Research/grounding-dino-tiny",
101 | classes=classes,
102 | )
103 | return model
104 |
105 |
106 | def GroundingDINO_activator():
107 | if find_spec("transformers") is None:
108 | return False
109 |     required_version = pkg_resources.parse_version("4.40.0")
110 |     installed_version = pkg_resources.get_distribution("transformers").version
111 |     if pkg_resources.parse_version(installed_version) < required_version:
112 |         return False
113 | 
114 |     required_fiftyone_version = pkg_resources.parse_version("0.24.0")
115 |     installed_fiftyone_version = pkg_resources.parse_version(
116 |         pkg_resources.get_distribution("fiftyone").version
117 |     )
118 |     if installed_fiftyone_version < required_fiftyone_version:
119 |         return False
120 |
121 | return True
122 |
123 |
124 | def YOLOWorldModel(config):
125 | classes = config.get("categories", None)
126 | pretrained = config.get("pretrained", "yolov8l-world")
127 | if "v2" in pretrained:
128 | from ultralytics import YOLO
129 |
130 | model = YOLO(pretrained + ".pt")
131 | model.set_classes(classes)
132 | import fiftyone.utils.ultralytics as fouu
133 |
134 | model = fouu.convert_ultralytics_model(model)
135 | else:
136 | model = foz.load_zoo_model(pretrained + "-torch", classes=classes)
137 | return model
138 |
139 |
140 | def YOLOWorld_activator():
141 | if find_spec("ultralytics") is None:
142 | return False
143 |     required_version = pkg_resources.parse_version("8.1.42")
144 |     installed_version = pkg_resources.get_distribution("ultralytics").version
145 |     return pkg_resources.parse_version(installed_version) >= required_version
146 |
147 |
148 | def build_detection_models_dict():
149 | dms = {}
150 |
151 | if YOLOWorld_activator():
152 | dms["YOLO-World"] = {
153 | "activator": YOLOWorld_activator,
154 | "model": YOLOWorldModel,
155 | "submodels": YOLO_WORLD_PRETRAINS,
156 | "name": "YOLO-World",
157 | }
158 |
159 | if OwlViT_activator():
160 | dms["OwlViT"] = {
161 | "activator": OwlViT_activator,
162 | "model": OwlViTZeroShotModel,
163 | "submodels": None,
164 | "name": "OwlViT",
165 | }
166 |
167 | if GroundingDINO_activator():
168 | dms["GroundingDINO"] = {
169 | "activator": GroundingDINO_activator,
170 | "model": GroundingDINO,
171 | "submodels": None,
172 | "name": "GroundingDINO",
173 | }
174 |
175 | return dms
176 |
177 |
178 | DETECTION_MODELS = build_detection_models_dict()
179 |
180 |
181 | def _get_model(model_name, config):
182 | return DETECTION_MODELS[model_name]["model"](config)
183 |
184 |
185 | def run_zero_shot_detection(
186 | dataset, model_name, label_field, categories, pretrained=None, **kwargs
187 | ):
188 | confidence = kwargs.get("confidence", 0.2)
189 | config = {"categories": categories, "pretrained": pretrained}
190 | model = _get_model(model_name, config)
191 | dataset.apply_model(
192 | model, label_field=label_field, confidence_thresh=confidence
193 | )
194 |
--------------------------------------------------------------------------------
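
A minimal usage sketch for `run_zero_shot_detection`, assuming the module is importable as `detection`; the dataset, field name, checkpoint, and confidence value are illustrative:

```python
import fiftyone.zoo as foz

# pylint: disable=import-error
from detection import run_zero_shot_detection

dataset = foz.load_zoo_dataset("quickstart")

run_zero_shot_detection(
    dataset,
    model_name="YOLO-World",              # or "OwlViT" / "GroundingDINO" if available
    label_field="zero_shot_detections",
    categories=["person", "car", "dog"],
    pretrained="yolov8l-world",           # one of YOLO_WORLD_PRETRAINS; ignored by OwlViT
    confidence=0.2,                       # forwarded as confidence_thresh via **kwargs
)
```
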
/fiftyone.yml:
--------------------------------------------------------------------------------
1 | fiftyone:
2 | version: ">=0.24.0"
3 | name: "@jacobmarks/zero_shot_prediction"
4 | version: "1.3.3"
5 | description: "Run zero-shot (open vocabulary) prediction on your data!"
6 | url: "https://github.com/jacobmarks/zero-shot-predictions-plugin/"
7 | operators:
8 | - zero_shot_predict
9 | - zero_shot_classify
10 | - zero_shot_detect
11 | - zero_shot_instance_segment
12 | - zero_shot_semantic_segment
13 |
--------------------------------------------------------------------------------
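
The operator names listed in `fiftyone.yml` combine with the plugin name to form URIs of the form `@jacobmarks/zero_shot_prediction/<operator>`, which can be resolved and executed from the Python SDK. The sketch below is illustrative only: the keyword arguments shown are assumptions about the operator's input schema rather than its documented interface, and execution details vary by FiftyOne version.

```python
import fiftyone.operators as foo
import fiftyone.zoo as foz

dataset = foz.load_zoo_dataset("quickstart")

# Resolve the operator by its URI (plugin name + operator name from fiftyone.yml)
zsc = foo.get_operator("@jacobmarks/zero_shot_prediction/zero_shot_classify")

# Illustrative parameter names; consult the plugin README for the exact inputs
zsc(
    dataset,
    labels=["cat", "dog", "bird"],
    model_name="CLIP (OpenAI)",
    label_field="predictions",
)
```
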
/instance_segmentation.py:
--------------------------------------------------------------------------------
1 | """Zero Shot Instance Segmentation.
2 |
3 | | Copyright 2017-2023, Voxel51, Inc.
4 | | `voxel51.com <https://voxel51.com/>`_
5 | |
6 | """
7 |
8 | from importlib.util import find_spec
9 | import os
10 |
11 |
12 | from fiftyone.core.utils import add_sys_path
13 | import fiftyone.zoo as foz
14 |
15 | SAM_ARCHS = ("ViT-B", "ViT-H", "ViT-L")
16 | SAM_MODELS = [("", SA) for SA in SAM_ARCHS]
17 |
18 |
19 | def SAM_activator():
20 | return True
21 |
22 |
23 | def build_instance_segmentation_models_dict():
24 | sms = {}
25 |
26 | if SAM_activator():
27 | sms["SAM"] = {
28 | "activator": SAM_activator,
29 | "model": "N/A",
30 | "name": "SAM",
31 | "submodels": SAM_MODELS,
32 | }
33 |
34 | return sms
35 |
36 |
37 | INSTANCE_SEGMENTATION_MODELS = build_instance_segmentation_models_dict()
38 |
39 |
40 | def _get_segmentation_model(architecture):
41 | zoo_model_name = (
42 | "segment-anything-" + architecture.lower().replace("-", "") + "-torch"
43 | )
44 | return foz.load_zoo_model(zoo_model_name)
45 |
46 |
47 | def run_zero_shot_instance_segmentation(
48 | dataset,
49 | model_name,
50 | label_field,
51 | categories,
52 | pretrained=None,
53 | architecture=None,
54 | **kwargs
55 | ):
56 | with add_sys_path(os.path.dirname(os.path.abspath(__file__))):
57 | # pylint: disable=no-name-in-module,import-error
58 | from detection import run_zero_shot_detection
59 |
60 | det_model_name, _ = model_name.split(" + ")
61 | det_pretrained, _ = pretrained.split(" + ")
62 | if det_pretrained == "":
63 | det_pretrained = None
64 | _, seg_architecture = architecture.split(" + ")
65 |
66 | run_zero_shot_detection(
67 | dataset,
68 | det_model_name,
69 | label_field,
70 | categories,
71 | pretrained=det_pretrained,
72 | **kwargs
73 | )
74 |
75 | seg_model = _get_segmentation_model(seg_architecture)
76 | dataset.apply_model(
77 | seg_model, label_field=label_field, prompt_field=label_field
78 | )
79 |
--------------------------------------------------------------------------------
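
A minimal usage sketch for `run_zero_shot_instance_segmentation`, assuming the module is importable as `instance_segmentation`. The combined `"<detector> + <segmenter>"` strings are normally assembled by the plugin's operator layer (not shown here); the values below are illustrative of the format that the `split(" + ")` calls above expect:

```python
import fiftyone.zoo as foz

# pylint: disable=import-error
from instance_segmentation import run_zero_shot_instance_segmentation

dataset = foz.load_zoo_dataset("quickstart")

run_zero_shot_instance_segmentation(
    dataset,
    model_name="YOLO-World + SAM",       # "<detector> + <segmenter>"
    label_field="zero_shot_instances",
    categories=["person", "car"],
    pretrained="yolov8l-world + ViT-B",  # detector checkpoint left of " + "
    architecture=" + ViT-B",             # SAM architecture right of " + "
)
```
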
/semantic_segmentation.py:
--------------------------------------------------------------------------------
1 | """Zero Shot Semantic Segmentation.
2 |
3 | | Copyright 2017-2023, Voxel51, Inc.
4 | | `voxel51.com <https://voxel51.com/>`_
5 | |
6 | """
7 |
8 | from importlib.util import find_spec
9 | from PIL import Image
10 | import torch
11 |
12 | import fiftyone as fo
13 | from fiftyone.core.models import Model
14 |
15 |
16 | class CLIPSegZeroShotModel(Model):
17 | def __init__(self, config):
18 | self.candidate_labels = config.get("categories", None)
19 |
20 | from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
21 |
22 | self.processor = CLIPSegProcessor.from_pretrained(
23 | "CIDAS/clipseg-rd64-refined"
24 | )
25 | self.model = CLIPSegForImageSegmentation.from_pretrained(
26 | "CIDAS/clipseg-rd64-refined"
27 | )
28 |
29 | @property
30 | def media_type(self):
31 | return "image"
32 |
33 | def _predict(self, image):
34 | inputs = self.processor(
35 | text=self.candidate_labels,
36 | images=[image] * len(self.candidate_labels),
37 | padding="max_length",
38 | return_tensors="pt",
39 | )
40 | with torch.no_grad():
41 | outputs = self.model(**inputs)
42 | preds = outputs.logits.unsqueeze(1)
43 | # pylint: disable=no-member
44 | mask = torch.argmax(preds, dim=0).squeeze().numpy()
45 | return fo.Segmentation(mask=mask)
46 |
47 | def predict(self, args):
48 | image = Image.fromarray(args)
49 | predictions = self._predict(image)
50 | return predictions
51 |
52 | def predict_all(self, samples, args):
53 | pass
54 |
55 |
56 | def CLIPSeg_activator():
57 | return find_spec("transformers") is not None
58 |
59 |
60 | class GroupViTZeroShotModel(Model):
61 | def __init__(self, config):
62 | cats = config.get("categories", None)
63 | self.candidate_labels = [f"a photo of a {cat}" for cat in cats]
64 |
65 | from transformers import AutoProcessor, GroupViTModel
66 |
67 | self.processor = AutoProcessor.from_pretrained(
68 |             "nvidia/groupvit-gcc-yfcc"
69 |         )
70 |         self.model = GroupViTModel.from_pretrained("nvidia/groupvit-gcc-yfcc")
71 |
72 | @property
73 | def media_type(self):
74 | return "image"
75 |
76 | def _predict(self, image):
77 | inputs = self.processor(
78 | text=self.candidate_labels,
79 | images=image,
80 | padding="max_length",
81 | return_tensors="pt",
82 | )
83 | with torch.no_grad():
84 | outputs = self.model(**inputs, output_segmentation=True)
85 | preds = outputs.segmentation_logits.squeeze()
86 | # pylint: disable=no-member
87 | mask = torch.argmax(preds, dim=0).numpy()
88 | return fo.Segmentation(mask=mask)
89 |
90 | def predict(self, args):
91 | image = Image.fromarray(args)
92 | image = image.resize((224, 224))
93 | predictions = self._predict(image)
94 | return predictions
95 |
96 | def predict_all(self, samples, args):
97 | pass
98 |
99 |
100 | def GroupViT_activator():
101 | return find_spec("transformers") is not None
102 |
103 |
104 | SEMANTIC_SEGMENTATION_MODELS = {
105 | "CLIPSeg": {
106 | "activator": CLIPSeg_activator,
107 | "model": CLIPSegZeroShotModel,
108 | "name": "CLIPSeg",
109 | },
110 | "GroupViT": {
111 | "activator": GroupViT_activator,
112 | "model": GroupViTZeroShotModel,
113 | "name": "GroupViT",
114 | },
115 | }
116 |
117 |
118 | def _get_model(model_name, config):
119 | return SEMANTIC_SEGMENTATION_MODELS[model_name]["model"](config)
120 |
121 |
122 | def run_zero_shot_semantic_segmentation(
123 | dataset, model_name, label_field, categories, **kwargs
124 | ):
125 | if "other" not in categories:
126 | categories.append("other")
127 | config = {"categories": categories}
128 | model = _get_model(model_name, config)
129 | dataset.apply_model(model, label_field=label_field)
130 |
131 | dataset.mask_targets[label_field] = {
132 | i: label for i, label in enumerate(categories)
133 | }
134 | dataset.save()
135 |
--------------------------------------------------------------------------------
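
A minimal usage sketch for `run_zero_shot_semantic_segmentation`, assuming the module is importable as `semantic_segmentation`; note that an `"other"` class is appended automatically and the integer-to-label mapping is written to `dataset.mask_targets`. The dataset, field name, and categories are illustrative:

```python
import fiftyone.zoo as foz

# pylint: disable=import-error
from semantic_segmentation import run_zero_shot_semantic_segmentation

dataset = foz.load_zoo_dataset("quickstart")

run_zero_shot_semantic_segmentation(
    dataset,
    model_name="CLIPSeg",                  # or "GroupViT"
    label_field="zero_shot_semantic",
    categories=["sky", "road", "person"],  # "other" is appended automatically
)

# Mask integer -> class-name mapping written by the function above
print(dataset.mask_targets["zero_shot_semantic"])
```
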