├── .gitignore ├── .pre-commit-config.yaml ├── README.md ├── __init__.py ├── assets └── icon.svg ├── classification.py ├── detection.py ├── fiftyone.yml ├── instance_segmentation.py └── semantic_segmentation.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/asottile/blacken-docs 3 | rev: v1.12.0 4 | hooks: 5 | - id: blacken-docs 6 | additional_dependencies: [black==21.12b0] 7 | args: ["-l 79"] 8 | exclude: index.umd.js 9 | - repo: https://github.com/ambv/black 10 | rev: 22.3.0 11 | hooks: 12 | - id: black 13 | language_version: python3 14 | args: ["-l 79"] 15 | exclude: index.umd.js 16 | - repo: local 17 | hooks: 18 | - id: pylint 19 | name: pylint 20 | language: system 21 | files: \.py$ 22 | entry: pylint 23 | args: ["--errors-only"] 24 | exclude: index.umd.js 25 | - repo: local 26 | hooks: 27 | - id: ipynb-strip 28 | name: ipynb-strip 29 | language: system 30 | files: \.ipynb$ 31 | entry: jupyter nbconvert --clear-output --ClearOutputPreprocessor.enabled=True 32 | args: ["--log-level=ERROR"] 33 | - repo: https://github.com/pre-commit/mirrors-prettier 34 | rev: v2.6.2 35 | hooks: 36 | - id: prettier 37 | exclude: index.umd.js 38 | language_version: system 39 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Zero Shot Prediction Plugin 2 | 3 | ![zero_shot_owlvit_example](https://github.com/jacobmarks/zero-shot-prediction-plugin/assets/12500356/6aca099a-17b3-4f85-955d-26c3951f0646) 4 | 5 | This plugin allows you to perform zero-shot prediction on your dataset for the following tasks: 6 | 7 | - Image Classification 8 | - Object Detection 9 | - Instance Segmentation 10 | - Semantic Segmentation 11 | 12 | Given a list of label classes, which you can input either manually, separated by commas, or by uploading a text file, the plugin will perform zero-shot prediction on your dataset for the specified task and add the results to the dataset under a new field, which you can specify. 13 | 14 | ### Updates 15 | - 🆕 **2024-12-03**: Added support for Apple AIMv2 Zero Shot Model (courtesy of [@harpreetsahota204](https://github.com/harpreetsahota204)) 16 | - 🆕 **2024-12-16**: Added MPS and GPU support for ALIGN, AltCLIP, Apple AIMv2 (courtesy of [@harpreetsahota204](https://github.com/harpreetsahota204)) 17 | - **2024-06-22**: Updated interface for Python operator execution 18 | - **2024-05-30**: Added 19 | - support for Grounding DINO for object detection and instance segmentation 20 | - confidence thresholding for object detection and instance segmentation 21 | - **2024-03-06**: Added support for YOLO-World for object detection and instance segmentation! 22 | - **2024-01-10**: Removing LAION CLIP models. 23 | - **2024-01-05**: Added support for EVA-CLIP, SigLIP, and DFN CLIP for image classification! 24 | - **2023-11-28**: Version 1.1.1 supports OpenCLIP for image classification! 25 | - **2023-11-13**: Version 1.1.0 supports [calling operators from the Python SDK](#python-sdk)! 
26 | - **2023-10-27**: Added support for MetaCLIP for image classification
27 | - **2023-10-20**: Added support for AltCLIP and ALIGN for image classification and GroupViT for semantic segmentation
28 | 
29 | ### Requirements
30 | 
31 | - To use YOLO-World models, you must have `ultralytics>=8.1.42` installed.
32 | 
33 | ## Models
34 | 
35 | ### Built-in Models
36 | 
37 | As a starting point, this plugin comes with at least one zero-shot model per task. These are:
38 | 
39 | #### Image Classification
40 | 
41 | - [ALIGN](https://huggingface.co/docs/transformers/model_doc/align)
42 | - [AltCLIP](https://huggingface.co/docs/transformers/model_doc/altclip)
43 | - 🆕 [Apple AIMv2](https://huggingface.co/apple/aimv2-large-patch14-224-lit)
44 | - [CLIP](https://github.com/openai/CLIP) (OpenAI)
45 | - [CLIPA](https://github.com/UCSC-VLAA/CLIPA)
46 | - [DFN CLIP](https://huggingface.co/apple/DFN5B-CLIP-ViT-H-14-378): Data Filtering Networks
47 | - [EVA-CLIP](https://huggingface.co/QuanSun/EVA-CLIP)
48 | - [MetaCLIP](https://github.com/facebookresearch/metaclip)
49 | - [SigLIP](https://huggingface.co/timm/ViT-SO400M-14-SigLIP-384)
50 | 
51 | #### Object Detection
52 | 
53 | - [YOLO-World](https://docs.ultralytics.com/models/yolo-world/)
54 | - [Owl-ViT](https://huggingface.co/docs/transformers/model_doc/owlvit)
55 | - [Grounding DINO](https://huggingface.co/docs/transformers/main/en/model_doc/grounding-dino)
56 | 
57 | #### Instance Segmentation
58 | 
59 | - [Owl-ViT](https://huggingface.co/docs/transformers/model_doc/owlvit) + [Segment Anything (SAM)](https://github.com/facebookresearch/segment-anything)
60 | - [YOLO-World](https://docs.ultralytics.com/models/yolo-world/) + [Segment Anything (SAM)](https://github.com/facebookresearch/segment-anything)
61 | - [Grounding DINO](https://huggingface.co/docs/transformers/main/en/model_doc/grounding-dino) + [Segment Anything (SAM)](https://github.com/facebookresearch/segment-anything)
62 | 
63 | #### Semantic Segmentation
64 | 
65 | - [CLIPSeg](https://huggingface.co/blog/clipseg-zero-shot)
66 | - [GroupViT](https://huggingface.co/docs/transformers/model_doc/groupvit)
67 | 
68 | Most of the models used are from the [Hugging Face Transformers](https://huggingface.co/transformers/) library; the CLIP and SAM models are from the [FiftyOne Model Zoo](https://docs.voxel51.com/user_guide/model_zoo/index.html).
69 | 
70 | _Note_: To use SAM, you will need Facebook's `segment-anything` library installed.
71 | 
72 | ### Adding Your Own Models
73 | 
74 | You can see the implementations for all of these models in the following files:
75 | 
76 | - `classification.py`
77 | - `detection.py`
78 | - `instance_segmentation.py`
79 | - `semantic_segmentation.py`
80 | 
81 | These models are "registered" via dictionaries in each file. In `semantic_segmentation.py`, for example, the dictionary is:
82 | 
83 | ```py
84 | SEMANTIC_SEGMENTATION_MODELS = {
85 |     "CLIPSeg": {
86 |         "activator": CLIPSeg_activator,
87 |         "model": CLIPSegZeroShotModel,
88 |         "name": "CLIPSeg",
89 |     },
90 |     "GroupViT": {
91 |         "activator": GroupViT_activator,
92 |         "model": GroupViTZeroShotModel,
93 |         "name": "GroupViT",
94 |     },
95 | }
96 | ```
97 | 
98 | The `activator` checks the environment to see if the model is available, and the `model` is a `fiftyone.core.models.Model` object that is instantiated with the model name and the task, or a function that instantiates such a model. The `name` is the name displayed in the plugin's model dropdown menu.
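As a concrete reference, a minimal (hypothetical) entry could pair an activator that checks for a dependency with a factory that builds the model. The names `my_model_activator`, `my_model`, and `MyZeroShotModel` below are placeholders, not part of the plugin:

```py
from importlib.util import find_spec


def my_model_activator():
    # Offer the model in the dropdown only if its dependency is importable
    return find_spec("transformers") is not None


def my_model(config):
    # Placeholder factory: `MyZeroShotModel` stands in for your own
    # fiftyone.core.models.Model subclass, built from the user-selected classes
    categories = config.get("categories", None)
    return MyZeroShotModel({"categories": categories})
```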
99 | 
100 | If you want to add your own model, you can add it to the dictionary in the corresponding file. For example, to add a new semantic segmentation model, add it to the `SEMANTIC_SEGMENTATION_MODELS` dictionary in `semantic_segmentation.py`:
101 | 
102 | ```py
103 | SEMANTIC_SEGMENTATION_MODELS = {
104 |     "CLIPSeg": {
105 |         "activator": CLIPSeg_activator,
106 |         "model": CLIPSegZeroShotModel,
107 |         "name": "CLIPSeg",
108 |     },
109 |     "GroupViT": {
110 |         "activator": GroupViT_activator,
111 |         "model": GroupViTZeroShotModel,
112 |         "name": "GroupViT",
113 |     },
114 |     # ... other models
115 |     "My Model": {
116 |         "activator": my_model_activator,
117 |         "model": my_model,
118 |         "name": "My Model",
119 |     },
120 | }
121 | ```
122 | 
123 | 💡 You need to implement the `activator` and `model` functions for your model. The `activator` should check the environment to see if the model is available, and the `model` should be a `fiftyone.core.models.Model` object that is instantiated with the model name and the task.
124 | 
125 | ## Watch on YouTube
126 | 
127 | [![Video Thumbnail](https://img.youtube.com/vi/GlwyFHbTklw/0.jpg)](https://www.youtube.com/watch?v=GlwyFHbTklw&list=PLuREAXoPgT0RZrUaT0UpX_HzwKkoB-S9j&index=7)
128 | 
129 | ## Installation
130 | 
131 | ```shell
132 | fiftyone plugins download https://github.com/jacobmarks/zero-shot-prediction-plugin
133 | ```
134 | 
135 | If you want to use AltCLIP, ALIGN, Owl-ViT, CLIPSeg, or GroupViT, you will also need to install the `transformers` library:
136 | 
137 | ```shell
138 | pip install transformers
139 | ```
140 | 
141 | If you want to use SAM, you will also need to install the `segment-anything` library:
142 | 
143 | ```shell
144 | pip install git+https://github.com/facebookresearch/segment-anything.git
145 | ```
146 | 
147 | If you want to use OpenCLIP, you will also need to install the `open_clip` library from PyPI:
148 | 
149 | ```shell
150 | pip install open-clip-torch
151 | ```
152 | 
153 | Or from source:
154 | 
155 | ```shell
156 | pip install git+https://github.com/mlfoundations/open_clip.git
157 | ```
158 | 
159 | If you want to use YOLO-World, you will also need to install the `ultralytics` library:
160 | 
161 | ```shell
162 | pip install -U ultralytics
163 | ```
164 | 
165 | ## Usage
166 | 
167 | All of the operators in this plugin can be run in _delegated_ execution mode. This means that instead of waiting for the operator to finish, you _schedule_
168 | the operation to be performed separately. This is useful for long-running operations, such as performing inference on a large dataset.
169 | 
170 | Once you have pressed the `Schedule` button for the operator, you can see the job from the command line using FiftyOne's [command line interface](https://docs.voxel51.com/cli/index.html#fiftyone-delegated-operations):
171 | 
172 | ```shell
173 | fiftyone delegated list
174 | ```
175 | 
176 | will show you the status of all delegated operations.
177 | 
178 | To launch a service that runs the operation, as well as any other delegated operations that have been scheduled, run:
179 | 
180 | ```shell
181 | fiftyone delegated launch
182 | ```
183 | 
184 | Once the operation has completed, you can view the results in the App (upon refresh).
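You can also verify the new predictions programmatically. A quick sketch, assuming a classification run wrote its results to a field named `predictions` on the `quickstart` dataset (adjust both names to your setup):

```py
import fiftyone as fo

dataset = fo.load_dataset("quickstart")

# Count how many samples received each predicted label
print(dataset.count_values("predictions.label"))
```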
185 | 186 | After the operation completes, you can also clean up your list of delegated operations by running: 187 | 188 | ```shell 189 | fiftyone delegated cleanup -s COMPLETED 190 | ``` 191 | 192 | ## Operators 193 | 194 | ### `zero_shot_predict` 195 | 196 | - Select the task you want to perform zero-shot prediction on (image classification, object detection, instance segmentation, or semantic segmentation), and the field you want to add the results to. 197 | 198 | ### `zero_shot_classify` 199 | 200 | - Perform zero-shot image classification on your dataset 201 | 202 | ### `zero_shot_detect` 203 | 204 | - Perform zero-shot object detection on your dataset 205 | 206 | ### `zero_shot_instance_segment` 207 | 208 | - Perform zero-shot instance segmentation on your dataset 209 | 210 | ### `zero_shot_semantic_segment` 211 | 212 | - Perform zero-shot semantic segmentation on your dataset 213 | 214 | ## Python SDK 215 | 216 | You can also use the compute operators from the Python SDK! 217 | 218 | ```python 219 | import fiftyone as fo 220 | import fiftyone.operators as foo 221 | import fiftyone.zoo as foz 222 | 223 | dataset = fo.load_dataset("quickstart") 224 | 225 | ## Access the operator via its URI (plugin name + operator name) 226 | zsc = foo.get_operator("@jacobmarks/zero_shot_prediction/zero_shot_classify") 227 | 228 | ## Run zero-shot classification on all images in the dataset, specifying the labels with the `labels` argument 229 | zsc(dataset, labels=["cat", "dog", "bird"]) 230 | 231 | ## Run zero-shot classification on all images in the dataset, specifying the labels with a text file 232 | zsc(dataset, labels_file="/path/to/labels.txt") 233 | 234 | ## Specify the model to use, and the field to add the results to 235 | zsc(dataset, labels=["cat", "dog", "bird"], model_name="CLIP", label_field="predictions") 236 | 237 | ## Run zero-shot detection on a view 238 | zsd = foo.get_operator("@jacobmarks/zero_shot_prediction/zero_shot_detect") 239 | view = dataset.take(10) 240 | await zsd( 241 | view, 242 | labels=["license plate"], 243 | model_name="OwlViT", 244 | label_field="owlvit_license_plate", 245 | ) 246 | ``` 247 | 248 | All four of the task-specific zero-shot prediction operators also expose a `list_models()` method, which returns a list of the available models for that task. 249 | 250 | ```python 251 | zsss = foo.get_operator( 252 | "@jacobmarks/zero_shot_prediction/zero_shot_semantic_segment" 253 | ) 254 | 255 | zsss.list_models() 256 | 257 | ## ['CLIPSeg', 'GroupViT'] 258 | ``` 259 | 260 | **Note**: The `zero_shot_predict` operator is not yet supported in the Python SDK. 261 | 262 | **Note**: With earlier versions of FiftyOne, you may have trouble running these 263 | operator executions within a Jupyter notebook. If so, try running them in a 264 | Python script, or upgrading to the latest version of FiftyOne! 265 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | """Zero Shot Prediction plugin. 2 | 3 | | Copyright 2017-2023, Voxel51, Inc. 
4 | | `voxel51.com `_ 5 | | 6 | """ 7 | 8 | import os 9 | import base64 10 | 11 | from fiftyone.core.utils import add_sys_path 12 | import fiftyone.operators as foo 13 | from fiftyone.operators import types 14 | 15 | with add_sys_path(os.path.dirname(os.path.abspath(__file__))): 16 | # pylint: disable=no-name-in-module,import-error 17 | from classification import ( 18 | run_zero_shot_classification, 19 | CLASSIFICATION_MODELS, 20 | ) 21 | from detection import run_zero_shot_detection, DETECTION_MODELS 22 | from instance_segmentation import ( 23 | run_zero_shot_instance_segmentation, 24 | INSTANCE_SEGMENTATION_MODELS, 25 | ) 26 | from semantic_segmentation import ( 27 | run_zero_shot_semantic_segmentation, 28 | SEMANTIC_SEGMENTATION_MODELS, 29 | ) 30 | 31 | 32 | ZERO_SHOT_TASKS = ( 33 | "classification", 34 | "detection", 35 | "instance_segmentation", 36 | "semantic_segmentation", 37 | ) 38 | 39 | 40 | MODEL_LISTS = { 41 | "classification": CLASSIFICATION_MODELS, 42 | "detection": DETECTION_MODELS, 43 | "instance_segmentation": INSTANCE_SEGMENTATION_MODELS, 44 | "semantic_segmentation": SEMANTIC_SEGMENTATION_MODELS, 45 | } 46 | 47 | 48 | def _format_submodel_name(submodel): 49 | if type(submodel) == str: 50 | return submodel 51 | return f"{submodel[0]}|{submodel[1]}" 52 | 53 | 54 | def _format_submodel_label(submodel): 55 | if type(submodel) == str: 56 | return submodel.split(".")[0] 57 | pretrain = submodel[0].split("/")[-1] 58 | arch = submodel[1] 59 | 60 | arch_string = f"Architecture: {arch}" 61 | _pt = pretrain and pretrain != "openai" 62 | pretrain_string = f" | Pretrained: {pretrain}" if _pt else "" 63 | return f"{arch_string}{pretrain_string}" 64 | 65 | 66 | def _execution_mode(ctx, inputs): 67 | delegate = ctx.params.get("delegate", False) 68 | 69 | if delegate: 70 | description = "Uncheck this box to execute the operation immediately" 71 | else: 72 | description = "Check this box to delegate execution of this task" 73 | 74 | inputs.bool( 75 | "delegate", 76 | default=False, 77 | required=True, 78 | label="Delegate execution?", 79 | description=description, 80 | view=types.CheckboxView(), 81 | ) 82 | 83 | if delegate: 84 | inputs.view( 85 | "notice", 86 | types.Notice( 87 | label=( 88 | "You've chosen delegated execution. Note that you must " 89 | "have a delegated operation service running in order for " 90 | "this task to be processed. 
See " 91 | "https://docs.voxel51.com/plugins/index.html#operators " 92 | "for more information" 93 | ) 94 | ), 95 | ) 96 | 97 | 98 | def _get_active_models(task): 99 | ams = [] 100 | for element in MODEL_LISTS[task].values(): 101 | if element["activator"](): 102 | ams.append(element["name"]) 103 | return ams 104 | 105 | 106 | def _get_labels(ctx): 107 | if ctx.params.get("label_input_choices", False) == "direct": 108 | labels = ctx.params.get("labels", "") 109 | return [label.strip() for label in labels.split(",")] 110 | else: 111 | labels_file = ctx.params.get("labels_file", None).strip() 112 | if "," in labels_file: 113 | lf = labels_file.split(",")[1] 114 | else: 115 | lf = labels_file 116 | decoded_bytes = base64.b64decode(lf) 117 | labels = decoded_bytes.decode("utf-8") 118 | return [label.strip() for label in labels.split("\n")] 119 | 120 | 121 | TASK_TO_FUNCTION = { 122 | "classification": run_zero_shot_classification, 123 | "detection": run_zero_shot_detection, 124 | "instance_segmentation": run_zero_shot_instance_segmentation, 125 | "semantic_segmentation": run_zero_shot_semantic_segmentation, 126 | } 127 | 128 | 129 | def run_zero_shot_task( 130 | dataset, 131 | task, 132 | model_name, 133 | label_field, 134 | categories, 135 | architecture, 136 | pretrained, 137 | confidence=0.2, 138 | ): 139 | return TASK_TO_FUNCTION[task]( 140 | dataset, 141 | model_name, 142 | label_field, 143 | categories, 144 | architecture=architecture, 145 | pretrained=pretrained, 146 | confidence=confidence, 147 | ) 148 | 149 | 150 | def _model_name_to_field_name(model_name): 151 | fn = ( 152 | model_name.lower() 153 | .replace(" ", "_") 154 | .replace("_+", "") 155 | .replace("-", "") 156 | .split("(")[0] 157 | .strip() 158 | ) 159 | if fn[-1] == "_": 160 | fn = fn[:-1] 161 | return fn 162 | 163 | 164 | def _handle_model_choice_inputs(ctx, inputs, chosen_task): 165 | active_models = _get_active_models(chosen_task) 166 | 167 | if len(active_models) == 0: 168 | inputs.str( 169 | "no_models_warning", 170 | view=types.Warning( 171 | label=f"No Models Found", 172 | description="No models were found for the selected task. 
Please install the required libraries.", 173 | ), 174 | ) 175 | return types.Property(inputs) 176 | 177 | ct_label = ( 178 | "Segmentation" 179 | if "segment" in chosen_task 180 | else chosen_task.capitalize() 181 | ) 182 | model_dropdown_label = f"{ct_label} Model" 183 | model_dropdown = types.Dropdown(label=model_dropdown_label) 184 | for model in active_models: 185 | model_dropdown.add_choice(model, label=model) 186 | inputs.enum( 187 | f"model_choice_{chosen_task}", 188 | model_dropdown.values(), 189 | default=model_dropdown.choices[0].value, 190 | view=model_dropdown, 191 | ) 192 | 193 | model_choice = ctx.params.get( 194 | f"model_choice_{chosen_task}", model_dropdown.choices[0].value 195 | ) 196 | mc = model_choice.split("(")[0].strip().lower() 197 | 198 | submodels = MODEL_LISTS[chosen_task][model_choice].get("submodels", None) 199 | if submodels is not None: 200 | if len(submodels) == 1: 201 | ctx.params["pretrained"] = submodels[0][0] 202 | ctx.params["architecture"] = submodels[0][1] 203 | else: 204 | submodel_dropdown = types.Dropdown( 205 | label=f"{chosen_task.capitalize()} Submodel" 206 | ) 207 | for submodel in submodels: 208 | submodel_dropdown.add_choice( 209 | _format_submodel_name(submodel), 210 | label=_format_submodel_label(submodel), 211 | ) 212 | inputs.enum( 213 | f"submodel_choice_{chosen_task}_{mc}", 214 | submodel_dropdown.values(), 215 | default=submodel_dropdown.choices[0].value, 216 | view=submodel_dropdown, 217 | ) 218 | 219 | submodel_choice = ctx.params.get( 220 | f"submodel_choice_{chosen_task}_{model_choice}", 221 | submodel_dropdown.choices[0].value, 222 | ) 223 | 224 | if "|" in submodel_choice: 225 | if chosen_task == "instance_segmentation": 226 | ctx.params["pretrained"] += submodel_choice.split("|")[0] 227 | ctx.params["architecture"] += submodel_choice.split("|")[1] 228 | else: 229 | ctx.params["pretrained"] = submodel_choice.split("|")[0] 230 | ctx.params["architecture"] = submodel_choice.split("|")[1] 231 | else: 232 | if chosen_task == "instance_segmentation": 233 | ctx.params["pretrained"] += submodel_choice 234 | else: 235 | ctx.params["pretrained"] = submodel_choice 236 | ctx.params["architecture"] = None 237 | 238 | 239 | def handle_model_choice_inputs(ctx, inputs, chosen_task): 240 | if chosen_task == "instance_segmentation": 241 | _handle_model_choice_inputs(ctx, inputs, "detection") 242 | if ctx.params.get("pretrained", None) is not None: 243 | ctx.params["pretrained"] = ctx.params["pretrained"] + " + " 244 | else: 245 | ctx.params["pretrained"] = " + " 246 | if ctx.params.get("architecture", None) is not None: 247 | ctx.params["architecture"] = ctx.params["architecture"] + " + " 248 | else: 249 | ctx.params["architecture"] = " + " 250 | _handle_model_choice_inputs(ctx, inputs, chosen_task) 251 | 252 | 253 | class ZeroShotTasks(foo.Operator): 254 | @property 255 | def config(self): 256 | _config = foo.OperatorConfig( 257 | name="zero_shot_predict", 258 | label="Perform Zero Shot Prediction", 259 | dynamic=True, 260 | ) 261 | _config.icon = "/assets/icon.svg" 262 | return _config 263 | 264 | def resolve_delegation(self, ctx): 265 | return ctx.params.get("delegate", False) 266 | 267 | def resolve_input(self, ctx): 268 | inputs = types.Object() 269 | 270 | radio_choices = types.RadioGroup() 271 | radio_choices.add_choice("classification", label="Classification") 272 | radio_choices.add_choice("detection", label="Detection") 273 | radio_choices.add_choice( 274 | "instance_segmentation", label="Instance Segmentation" 275 | ) 276 | 
radio_choices.add_choice( 277 | "semantic_segmentation", label="Semantic Segmentation" 278 | ) 279 | inputs.enum( 280 | "task_choices", 281 | radio_choices.values(), 282 | default=radio_choices.choices[0].value, 283 | label="Zero Shot Task", 284 | view=radio_choices, 285 | ) 286 | 287 | chosen_task = ctx.params.get("task_choices", "classification") 288 | handle_model_choice_inputs(ctx, inputs, chosen_task) 289 | active_models = _get_active_models(chosen_task) 290 | 291 | if chosen_task in ["detection", "instance_segmentation"]: 292 | inputs.float( 293 | "confidence", 294 | label="Confidence Threshold", 295 | default=0.2, 296 | description="The minimum confidence required for a prediction to be included", 297 | ) 298 | 299 | label_input_choices = types.RadioGroup() 300 | label_input_choices.add_choice("direct", label="Input directly") 301 | label_input_choices.add_choice("file", label="Input from file") 302 | inputs.enum( 303 | "label_input_choices", 304 | label_input_choices.values(), 305 | default=label_input_choices.choices[0].value, 306 | label="Labels", 307 | view=label_input_choices, 308 | ) 309 | 310 | if ctx.params.get("label_input_choices", "direct") == "direct": 311 | inputs.str( 312 | "labels", 313 | label="Labels", 314 | description="Enter the names of the classes you wish to generate predictions for, separated by commas", 315 | required=True, 316 | ) 317 | else: 318 | labels_file = types.FileView(label="Labels File") 319 | inputs.str( 320 | "labels_file", 321 | label="Labels File", 322 | required=True, 323 | view=labels_file, 324 | ) 325 | 326 | model_name = ctx.params.get( 327 | f"model_choice_{chosen_task}", active_models[0] 328 | ) 329 | model_name = model_name.split("(")[0].strip().replace("-", "").lower() 330 | inputs.str( 331 | f"label_field_{chosen_task}_{model_name}", 332 | label="Label Field", 333 | default=_model_name_to_field_name(model_name), 334 | description="The field to store the predicted labels in", 335 | required=True, 336 | ) 337 | _execution_mode(ctx, inputs) 338 | inputs.view_target(ctx) 339 | return types.Property(inputs) 340 | 341 | def execute(self, ctx): 342 | view = ctx.target_view() 343 | task = ctx.params.get("task_choices", "classification") 344 | active_models = _get_active_models(task) 345 | model_name = ctx.params.get(f"model_choice_{task}", active_models[0]) 346 | mn = model_name.split("(")[0].strip().lower().replace("-", "") 347 | if task == "instance_segmentation": 348 | model_name = ( 349 | ctx.params[f"model_choice_detection"] + " + " + model_name 350 | ) 351 | categories = _get_labels(ctx) 352 | label_field = ctx.params.get(f"label_field_{task}_{mn}", mn) 353 | architecture = ctx.params.get("architecture", None) 354 | pretrained = ctx.params.get("pretrained", None) 355 | confidence = ctx.params.get("confidence", 0.2) 356 | 357 | run_zero_shot_task( 358 | view, 359 | task, 360 | model_name, 361 | label_field, 362 | categories, 363 | architecture, 364 | pretrained, 365 | confidence=confidence, 366 | ) 367 | ctx.ops.reload_dataset() 368 | 369 | 370 | ### Common input control flow for all tasks 371 | def _input_control_flow(ctx, task): 372 | inputs = types.Object() 373 | active_models = _get_active_models(task) 374 | if len(active_models) == 0: 375 | inputs.str( 376 | "no_models_warning", 377 | view=types.Warning( 378 | label=f"No Models Found", 379 | description="No models were found for the selected task. 
Please install the required libraries.", 380 | ), 381 | ) 382 | return types.Property(inputs) 383 | 384 | handle_model_choice_inputs(ctx, inputs, task) 385 | 386 | label_input_choices = types.RadioGroup() 387 | label_input_choices.add_choice("direct", label="Input directly") 388 | label_input_choices.add_choice("file", label="Input from file") 389 | inputs.enum( 390 | "label_input_choices", 391 | label_input_choices.values(), 392 | default=label_input_choices.choices[0].value, 393 | label="Labels", 394 | view=label_input_choices, 395 | ) 396 | 397 | if task in ["detection", "instance_segmentation"]: 398 | inputs.float( 399 | "confidence", 400 | label="Confidence Threshold", 401 | default=0.2, 402 | description="The minimum confidence required for a prediction to be included", 403 | ) 404 | 405 | if ctx.params.get("label_input_choices", False) == "direct": 406 | inputs.str( 407 | "labels", 408 | label="Labels", 409 | description="Enter the names of the classes you wish to generate predictions for, separated by commas", 410 | required=True, 411 | ) 412 | else: 413 | labels_file = types.FileView(label="Labels File") 414 | inputs.str( 415 | "labels_file", 416 | label="Labels File", 417 | required=True, 418 | view=labels_file, 419 | ) 420 | 421 | model_name = ctx.params.get(f"model_choice_{task}", active_models[0]) 422 | mn = model_name.split("(")[0].strip().lower().replace("-", "") 423 | inputs.str( 424 | f"label_field_{mn}", 425 | label="Label Field", 426 | default=_model_name_to_field_name(model_name), 427 | description="The field to store the predicted labels in", 428 | required=True, 429 | ) 430 | _execution_mode(ctx, inputs) 431 | inputs.view_target(ctx) 432 | return inputs 433 | 434 | 435 | def _execute_control_flow(ctx, task): 436 | view = ctx.target_view() 437 | model_name = ctx.params.get(f"model_choice_{task}", "CLIP") 438 | mn = _model_name_to_field_name(model_name).split("(")[0].strip().lower() 439 | label_field = ctx.params.get(f"label_field_{mn}", mn) 440 | if task == "instance_segmentation": 441 | model_name = ctx.params[f"model_choice_detection"] + " + " + model_name 442 | 443 | kwargs = {} 444 | if task in ["detection", "instance_segmentation"]: 445 | kwargs["confidence"] = ctx.params.get("confidence", 0.2) 446 | categories = _get_labels(ctx) 447 | 448 | architecture = ctx.params.get("architecture", None) 449 | pretrained = ctx.params.get("pretrained", None) 450 | 451 | run_zero_shot_task( 452 | view, 453 | task, 454 | model_name, 455 | label_field, 456 | categories, 457 | architecture, 458 | pretrained, 459 | **kwargs, 460 | ) 461 | ctx.ops.reload_dataset() 462 | 463 | 464 | NAME_TO_TASK = { 465 | "zero_shot_classify": "classification", 466 | "zero_shot_detect": "detection", 467 | "zero_shot_instance_segment": "instance_segmentation", 468 | "zero_shot_semantic_segment": "semantic_segmentation", 469 | } 470 | 471 | 472 | def _format_model_name(model_name): 473 | return ( 474 | model_name.lower().replace(" ", "").replace("_", "").replace("-", "") 475 | ) 476 | 477 | 478 | def _match_model_name(model_name, model_names): 479 | for name in model_names: 480 | if _format_model_name(name) == _format_model_name(model_name): 481 | return name 482 | raise ValueError( 483 | f"Model name {model_name} not found. 
Use one of {model_names}" 484 | ) 485 | 486 | 487 | def _resolve_model_name(task, model_name): 488 | if model_name is None: 489 | return list(MODEL_LISTS[task].keys())[0] 490 | elif model_name not in MODEL_LISTS[task]: 491 | return _match_model_name(model_name, list(MODEL_LISTS[task].keys())) 492 | return model_name 493 | 494 | 495 | def _resolve_labels(labels, labels_file): 496 | if labels is None and labels_file is None: 497 | raise ValueError("Must provide either labels or labels_file") 498 | 499 | if labels is not None and labels_file is not None: 500 | raise ValueError("Cannot provide both labels and labels_file") 501 | 502 | if labels is not None and type(labels) == list: 503 | labels = ", ".join(labels) 504 | else: 505 | with open(labels_file, "r") as f: 506 | labels = [label.strip() for label in f.readlines()] 507 | labels = ", ".join(labels) 508 | 509 | return labels 510 | 511 | 512 | def _resolve_label_field(model_name, label_field): 513 | if label_field is None: 514 | label_field = _model_name_to_field_name(model_name) 515 | label_field_name = f"label_field_{_model_name_to_field_name(model_name)}" 516 | 517 | return label_field, label_field_name 518 | 519 | 520 | def _handle_calling( 521 | uri, 522 | sample_collection, 523 | model_name=None, 524 | labels=None, 525 | labels_file=None, 526 | label_field=None, 527 | delegate=False, 528 | confidence=None, 529 | ): 530 | ctx = dict(view=sample_collection.view()) 531 | 532 | task = NAME_TO_TASK[uri.split("/")[-1]] 533 | 534 | model_name = _resolve_model_name(task, model_name) 535 | labels = _resolve_labels(labels, labels_file) 536 | 537 | label_field, label_field_name = _resolve_label_field( 538 | model_name, label_field 539 | ) 540 | 541 | params = dict( 542 | label_input_choices="direct", 543 | delegate=delegate, 544 | labels=labels, 545 | ) 546 | if confidence is not None: 547 | params["confidence"] = confidence 548 | params[label_field_name] = label_field 549 | params[f"model_choice_{task}"] = model_name 550 | 551 | return foo.execute_operator(uri, ctx, params=params) 552 | 553 | 554 | class ZeroShotClassify(foo.Operator): 555 | @property 556 | def config(self): 557 | _config = foo.OperatorConfig( 558 | name="zero_shot_classify", 559 | label="Perform Zero Shot Classification", 560 | dynamic=True, 561 | ) 562 | _config.icon = "/assets/icon.svg" 563 | return _config 564 | 565 | def resolve_delegation(self, ctx): 566 | return ctx.params.get("delegate", False) 567 | 568 | def resolve_input(self, ctx): 569 | inputs = _input_control_flow(ctx, "classification") 570 | return types.Property(inputs) 571 | 572 | def execute(self, ctx): 573 | _execute_control_flow(ctx, "classification") 574 | 575 | def __call__( 576 | self, 577 | sample_collection, 578 | model_name=None, 579 | labels=None, 580 | labels_file=None, 581 | label_field=None, 582 | delegate=False, 583 | ): 584 | return _handle_calling( 585 | self.uri, 586 | sample_collection, 587 | model_name=model_name, 588 | labels=labels, 589 | labels_file=labels_file, 590 | label_field=label_field, 591 | delegate=delegate, 592 | ) 593 | 594 | def list_models(self): 595 | return list(MODEL_LISTS["classification"].keys()) 596 | 597 | 598 | class ZeroShotDetect(foo.Operator): 599 | @property 600 | def config(self): 601 | _config = foo.OperatorConfig( 602 | name="zero_shot_detect", 603 | label="Perform Zero Shot Detection", 604 | dynamic=True, 605 | ) 606 | _config.icon = "/assets/icon.svg" 607 | return _config 608 | 609 | def resolve_delegation(self, ctx): 610 | return ctx.params.get("delegate", 
False) 611 | 612 | def resolve_input(self, ctx): 613 | inputs = _input_control_flow(ctx, "detection") 614 | inputs.float( 615 | "confidence", 616 | label="Confidence Threshold", 617 | default=0.2, 618 | description="The minimum confidence required for a prediction to be included", 619 | ) 620 | return types.Property(inputs) 621 | 622 | def execute(self, ctx): 623 | _execute_control_flow(ctx, "detection") 624 | 625 | def __call__( 626 | self, 627 | sample_collection, 628 | model_name=None, 629 | labels=None, 630 | labels_file=None, 631 | label_field=None, 632 | delegate=False, 633 | confidence=0.2, 634 | ): 635 | return _handle_calling( 636 | self.uri, 637 | sample_collection, 638 | model_name=model_name, 639 | labels=labels, 640 | labels_file=labels_file, 641 | label_field=label_field, 642 | delegate=delegate, 643 | confidence=confidence, 644 | ) 645 | 646 | def list_models(self): 647 | return list(MODEL_LISTS["detection"].keys()) 648 | 649 | 650 | class ZeroShotInstanceSegment(foo.Operator): 651 | @property 652 | def config(self): 653 | _config = foo.OperatorConfig( 654 | name="zero_shot_instance_segment", 655 | label="Perform Zero Shot Instance Segmentation", 656 | dynamic=True, 657 | ) 658 | _config.icon = "/assets/icon.svg" 659 | return _config 660 | 661 | def resolve_delegation(self, ctx): 662 | return ctx.params.get("delegate", False) 663 | 664 | def resolve_input(self, ctx): 665 | inputs = _input_control_flow(ctx, "instance_segmentation") 666 | return types.Property(inputs) 667 | 668 | def execute(self, ctx): 669 | _execute_control_flow(ctx, "instance_segmentation") 670 | 671 | def __call__( 672 | self, 673 | sample_collection, 674 | model_name=None, 675 | labels=None, 676 | labels_file=None, 677 | label_field=None, 678 | delegate=False, 679 | confidence=0.2, 680 | ): 681 | return _handle_calling( 682 | self.uri, 683 | sample_collection, 684 | model_name=model_name, 685 | labels=labels, 686 | labels_file=labels_file, 687 | label_field=label_field, 688 | delegate=delegate, 689 | confidence=confidence, 690 | ) 691 | 692 | def list_models(self): 693 | return list(MODEL_LISTS["instance_segmentation"].keys()) 694 | 695 | 696 | class ZeroShotSemanticSegment(foo.Operator): 697 | @property 698 | def config(self): 699 | _config = foo.OperatorConfig( 700 | name="zero_shot_semantic_segment", 701 | label="Perform Zero Shot Semantic Segmentation", 702 | dynamic=True, 703 | ) 704 | _config.icon = "/assets/icon.svg" 705 | return _config 706 | 707 | def resolve_delegation(self, ctx): 708 | return ctx.params.get("delegate", False) 709 | 710 | def resolve_input(self, ctx): 711 | inputs = _input_control_flow(ctx, "semantic_segmentation") 712 | return types.Property(inputs) 713 | 714 | def execute(self, ctx): 715 | _execute_control_flow(ctx, "semantic_segmentation") 716 | 717 | def __call__( 718 | self, 719 | sample_collection, 720 | model_name=None, 721 | labels=None, 722 | labels_file=None, 723 | label_field=None, 724 | delegate=False, 725 | ): 726 | return _handle_calling( 727 | self.uri, 728 | sample_collection, 729 | model_name=model_name, 730 | labels=labels, 731 | labels_file=labels_file, 732 | label_field=label_field, 733 | delegate=delegate, 734 | ) 735 | 736 | def list_models(self): 737 | return list(MODEL_LISTS["semantic_segmentation"].keys()) 738 | 739 | 740 | def register(plugin): 741 | plugin.register(ZeroShotTasks) 742 | plugin.register(ZeroShotClassify) 743 | plugin.register(ZeroShotDetect) 744 | plugin.register(ZeroShotInstanceSegment) 745 | plugin.register(ZeroShotSemanticSegment) 746 
| -------------------------------------------------------------------------------- /assets/icon.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /classification.py: -------------------------------------------------------------------------------- 1 | """Zero Shot Classification. 2 | 3 | | Copyright 2017-2023, Voxel51, Inc. 4 | | `voxel51.com `_ 5 | | 6 | """ 7 | from importlib.util import find_spec 8 | import numpy as np 9 | from PIL import Image 10 | import torch 11 | 12 | import fiftyone as fo 13 | from fiftyone.core.models import Model 14 | import fiftyone.zoo as foz 15 | 16 | ### Make tuples in form ("pretrained", "clip_model") 17 | 18 | OPENAI_ARCHS = ["ViT-B-32", "ViT-B-16", "ViT-L-14"] 19 | OPENAI_CLIP_MODELS = [("openai", model) for model in OPENAI_ARCHS] 20 | 21 | DFN_CLIP_MODELS = [("dfn2b", "ViT-B-16")] 22 | 23 | META_ARCHS = ("ViT-B-16-quickgelu", "ViT-B-32-quickgelu", "ViT-L-14-quickgelu") 24 | META_PRETRAINS = ("metaclip_400m", "metaclip_fullcc") 25 | META_CLIP_MODELS = [ 26 | (pretrain, arch) for pretrain in META_PRETRAINS for arch in META_ARCHS 27 | ] 28 | META_CLIP_MODELS.append(("metaclip_fullcc", "ViT-H-14-quickgelu")) 29 | 30 | CLIPA_MODELS = [("", "hf-hub:UCSC-VLAA/ViT-L-14-CLIPA-datacomp1B")] 31 | 32 | SIGLIP_ARCHS = ( 33 | "ViT-B-16-SigLIP", 34 | "ViT-B-16-SigLIP-256", 35 | "ViT-B-16-SigLIP-384", 36 | "ViT-L-16-SigLIP-256", 37 | "ViT-L-16-SigLIP-384", 38 | "ViT-SO400M-14-SigLIP", 39 | "ViT-SO400M-14-SigLIP-384", 40 | ) 41 | SIGLIP_MODELS = [("", arch) for arch in SIGLIP_ARCHS] 42 | 43 | EVA_CLIP_MODELS = [ 44 | ("merged2b_s8b_b131k", "EVA02-B-16"), 45 | ("merged2b_s6b_b61k", "EVA02-L-14-336"), 46 | ("merged2b_s4b_b131k", "EVA02-L-14"), 47 | ] 48 | 49 | AIMV2_MODELS = ["aimv2-large-patch14-224-lit"] 50 | 51 | def get_device(): 52 | """Helper function to determine the best available device.""" 53 | if torch.cuda.is_available(): 54 | device = "cuda" 55 | print(f"Using CUDA device: {torch.cuda.get_device_name(0)}") 56 | elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): 57 | device = "mps" 58 | print("Using Apple Silicon (MPS) device") 59 | else: 60 | device = "cpu" 61 | print("Using CPU device") 62 | return device 63 | 64 | 65 | def CLIPZeroShotModel(config): 66 | """ 67 | This function loads a zero-shot classification model using the CLIP architecture. 68 | It utilizes the FiftyOne Zoo to load a pre-trained model based on the provided 69 | configuration. 70 | 71 | Args: 72 | config (dict): A dictionary containing configuration parameters for the model. 73 | - categories (list, optional): A list of categories for classification. 74 | - clip_model (str, optional): The architecture of the CLIP model to use. 75 | - pretrained (str, optional): The pre-trained weights to use. 76 | 77 | Returns: 78 | Model: A loaded CLIP zero-shot classification model ready for inference. 79 | """ 80 | cats = config.get("categories", None) 81 | clip_model = config.get("clip_model", "ViT-B-32") 82 | pretrained = config.get("pretrained", "openai") 83 | 84 | model = foz.load_zoo_model( 85 | "clip-vit-base32-torch", 86 | clip_model=clip_model, 87 | pretrained=pretrained, 88 | text_prompt="A photo of a", 89 | classes=cats, 90 | ) 91 | 92 | return model 93 | 94 | 95 | def CLIP_activator(): 96 | """ 97 | Determines if the CLIP model can be activated. 
98 | 99 | This function checks for the availability of the necessary 100 | components to activate the CLIP model. It returns True if 101 | the model can be activated, otherwise False. 102 | 103 | Returns: 104 | bool: True if the CLIP model can be activated, False otherwise. 105 | """ 106 | return True 107 | 108 | class AltCLIPZeroShotModel(Model): 109 | """ 110 | This class implements a zero-shot classification model using the AltCLIP architecture. 111 | It leverages the AltCLIP model from the Hugging Face Transformers library to perform 112 | image classification without requiring task-specific training data. 113 | 114 | Args: 115 | config (dict): A dictionary containing configuration parameters for the model. 116 | - categories (list, optional): A list of categories for classification. 117 | 118 | Attributes: 119 | categories (list): The list of categories for classification. 120 | candidate_labels (list): A list of text prompts for each category. 121 | model (AltCLIPModel): The pre-trained AltCLIP model. 122 | processor (AltCLIPProcessor): The processor for preparing inputs for the model. 123 | 124 | Methods: 125 | media_type: Returns the type of media the model is designed to process. 126 | _predict(image): Performs prediction on a single image. 127 | predict(args): Converts input data to an image and performs prediction. 128 | _predict_all(images): Performs prediction on a list of images. 129 | """ 130 | def __init__(self, config): 131 | """ 132 | Initializes the AltCLIPZeroShotModel with the given configuration. 133 | 134 | This constructor sets up the model by initializing the categories 135 | and candidate labels for classification. It also loads the pre-trained 136 | AltCLIP model and processor from the Hugging Face Transformers library. 137 | 138 | Args: 139 | config (dict): A dictionary containing configuration parameters for the model. 140 | - categories (list, optional): A list of categories for classification. 141 | """ 142 | self.categories = config.get("categories", None) 143 | self.candidate_labels = [ 144 | f"a photo of a {cat}" for cat in self.categories 145 | ] 146 | 147 | from transformers import AltCLIPModel, AltCLIPProcessor 148 | 149 | self.model = AltCLIPModel.from_pretrained("BAAI/AltCLIP") 150 | self.processor = AltCLIPProcessor.from_pretrained("BAAI/AltCLIP") 151 | 152 | # Set up device 153 | self.device = get_device() 154 | 155 | # Move model to appropriate device and set to eval mode 156 | self.model = self.model.to(self.device) 157 | self.model.eval() 158 | 159 | @property 160 | def media_type(self): 161 | return "image" 162 | 163 | def _predict(self, image): 164 | """ 165 | Performs prediction on a single image. 166 | 167 | This method processes the input image using the AltCLIPProcessor 168 | and performs a forward pass through the AltCLIPModel to obtain 169 | classification probabilities for each category. It returns a 170 | FiftyOne Classification object containing the predicted label, 171 | logits, and confidence score. 172 | 173 | Args: 174 | image (PIL.Image): The input image to classify. 175 | 176 | Returns: 177 | fo.Classification: The classification result containing the 178 | predicted label, logits, and confidence score. 
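        Example (illustrative sketch; the categories and image path below
        are arbitrary placeholders)::

            model = AltCLIPZeroShotModel({"categories": ["cat", "dog", "bird"]})
            result = model._predict(Image.open("/path/to/image.jpg"))
            print(result.label, result.confidence)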
179 | """ 180 | inputs = self.processor( 181 | text=self.candidate_labels, 182 | images=image, 183 | return_tensors="pt", 184 | padding=True 185 | ) 186 | 187 | inputs = inputs.to(self.device) 188 | 189 | with torch.no_grad(): 190 | outputs = self.model(**inputs) 191 | 192 | logits_per_image = outputs.logits_per_image 193 | 194 | # Move to CPU only if necessary 195 | if logits_per_image.device.type != 'cpu': 196 | logits_per_image = logits_per_image.cpu() 197 | 198 | probs = logits_per_image.softmax(dim=1).numpy() 199 | 200 | return fo.Classification( 201 | label=self.categories[probs.argmax()], 202 | logits=logits_per_image.squeeze().numpy(), 203 | confidence=np.amax(probs[0]), 204 | ) 205 | 206 | def predict(self, args): 207 | """ 208 | Converts input data to an image and performs prediction. 209 | 210 | This method takes input data, converts it into a PIL image, 211 | and then uses the `_predict` method to perform classification. 212 | It returns the prediction results as a FiftyOne Classification object. 213 | 214 | Args: 215 | args (numpy.ndarray): The input data to be converted into an image. 216 | 217 | Returns: 218 | fo.Classification: The classification result containing the 219 | predicted label, logits, and confidence score. 220 | """ 221 | image = Image.fromarray(args) 222 | predictions = self._predict(image) 223 | return predictions 224 | 225 | def _predict_all(self, images): 226 | return [self._predict(image) for image in images] 227 | 228 | 229 | def AltCLIP_activator(): 230 | return find_spec("transformers") is not None 231 | 232 | 233 | class AlignZeroShotModel(Model): 234 | """ 235 | AlignZeroShotModel is a class for zero-shot image classification using the Align model. 236 | 237 | This class leverages the Align model from the `transformers` library to perform 238 | zero-shot classification on images. It initializes with a configuration that 239 | specifies the categories for classification. The model processes input images 240 | and predicts the most likely category from the provided list. 241 | 242 | Attributes: 243 | categories (list): A list of category labels for classification. 244 | candidate_labels (list): A list of formatted labels for the Align model. 245 | processor (AlignProcessor): The processor for preparing inputs for the model. 246 | model (AlignModel): The pre-trained Align model for classification. 247 | 248 | Methods: 249 | media_type: Returns the type of media the model works with, which is "image". 250 | _predict(image): Performs prediction on a single image. 251 | predict(args): Converts input data to an image and performs prediction. 252 | _predict_all(images): Performs prediction on a list of images. 253 | """ 254 | def __init__(self, config): 255 | self.categories = config.get("categories", None) 256 | self.candidate_labels = [ 257 | f"a photo of a {cat}" for cat in self.categories 258 | ] 259 | 260 | from transformers import AlignProcessor, AlignModel 261 | 262 | self.processor = AlignProcessor.from_pretrained("kakaobrain/align-base") 263 | self.model = AlignModel.from_pretrained("kakaobrain/align-base") 264 | 265 | # Set up device 266 | self.device = get_device() 267 | 268 | # Move model to appropriate device and set to eval mode 269 | self.model = self.model.to(self.device) 270 | self.model.eval() 271 | 272 | @property 273 | def media_type(self): 274 | return "image" 275 | 276 | def _predict(self, image): 277 | """ 278 | Performs prediction on a single image. 
279 | 280 | This method takes an image as input and processes it using the Align model 281 | to predict the most likely category from the pre-defined list of categories. 282 | It uses the processor to prepare the input and the model to generate predictions. 283 | The method returns a `fiftyone.core.labels.Classification` object containing 284 | the predicted label, logits, and confidence score. 285 | 286 | Args: 287 | image (PIL.Image.Image): The input image for classification. 288 | 289 | Returns: 290 | fiftyone.core.labels.Classification: The classification result with the 291 | predicted label, logits, and confidence score. 292 | """ 293 | inputs = self.processor( 294 | text=self.candidate_labels, 295 | images=image, 296 | return_tensors="pt" 297 | ) 298 | 299 | inputs = inputs.to(self.device) 300 | 301 | with torch.no_grad(): 302 | outputs = self.model(**inputs) 303 | 304 | logits_per_image = outputs.logits_per_image 305 | 306 | # Move to CPU only if necessary 307 | if logits_per_image.device.type != 'cpu': 308 | logits_per_image = logits_per_image.cpu() 309 | 310 | probs = logits_per_image.softmax(dim=1).numpy() 311 | 312 | return fo.Classification( 313 | label=self.categories[probs.argmax()], 314 | logits=logits_per_image.squeeze().numpy(), 315 | confidence=np.amax(probs[0]), 316 | ) 317 | 318 | def predict(self, args): 319 | """ 320 | Predicts the category of the given image. 321 | 322 | This method takes an image in the form of a numpy array, converts it 323 | to a PIL Image, and then uses the `_predict` method to classify the 324 | image. The classification result is returned as a 325 | `fiftyone.core.labels.Classification` object. 326 | 327 | Args: 328 | args (np.ndarray): The input image as a numpy array. 329 | 330 | Returns: 331 | fiftyone.core.labels.Classification: The classification result with 332 | the predicted label, logits, and confidence score. 333 | """ 334 | image = Image.fromarray(args) 335 | predictions = self._predict(image) 336 | return predictions 337 | 338 | def _predict_all(self, images): 339 | return [self._predict(image) for image in images] 340 | 341 | 342 | def Align_activator(): 343 | return find_spec("transformers") is not None 344 | 345 | 346 | def OpenCLIPZeroShotModel(config): 347 | """ 348 | Initializes and returns an OpenCLIP zero-shot model based on the provided configuration. 349 | 350 | This function loads a pre-trained OpenCLIP model using the specified configuration 351 | parameters. The model is initialized with a text prompt and a set of categories 352 | for zero-shot classification tasks. 353 | 354 | Args: 355 | config (dict): A dictionary containing configuration parameters for the model. 356 | - "categories" (list, optional): A list of category names for classification. 357 | - "clip_model" (str, optional): The name of the CLIP model architecture to use. 358 | Defaults to "ViT-B-32". 359 | - "pretrained" (str, optional): The name of the pre-trained weights to load. 360 | Defaults to "openai". 361 | 362 | Returns: 363 | An instance of the OpenCLIP model configured for zero-shot classification. 
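    Example (illustrative sketch; assumes ``open_clip_torch`` is installed and
    that ``dataset`` is an existing FiftyOne dataset)::

        model = OpenCLIPZeroShotModel(
            {
                "categories": ["cat", "dog"],
                "clip_model": "ViT-B-32",
                "pretrained": "openai",
            }
        )
        dataset.apply_model(model, label_field="open_clip_predictions")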
364 | """ 365 | cats = config.get("categories", None) 366 | clip_model = config.get("clip_model", "ViT-B-32") 367 | pretrained = config.get("pretrained", "openai") 368 | 369 | model = foz.load_zoo_model( 370 | "open-clip-torch", 371 | clip_model=clip_model, 372 | pretrained=pretrained, 373 | text_prompt="A photo of a", 374 | classes=cats, 375 | ) 376 | 377 | return model 378 | 379 | def OpenCLIP_activator(): 380 | return find_spec("open_clip") is not None 381 | 382 | class AIMV2ZeroShotModel(Model): 383 | """Zero-shot image classification model using Apple's AIM-V2. 384 | 385 | AIM-V2 (Apple Image Models V2) are vision-language models from Apple that achieve 386 | state-of-the-art performance on various vision tasks. 387 | 388 | Available models: 389 | - apple/aimv2-large-patch14-224-lit: LiT-tuned variant 390 | 391 | Args: 392 | config (dict): Configuration dictionary containing: 393 | - categories (list): List of category labels for classification 394 | - model_name (str, optional): Full model name including organization. 395 | Defaults to "apple/aimv2-large-patch14-224-lit" 396 | 397 | Attributes: 398 | categories (list): Available classification categories 399 | candidate_labels (list): Text prompts generated from categories 400 | model (AutoModel): The underlying AIM-V2 model 401 | processor (AutoProcessor): Processor for preparing inputs 402 | """ 403 | 404 | def __init__(self, config): 405 | self.categories = config.get("categories", None) 406 | if self.categories is None: 407 | raise ValueError("Categories must be provided in config") 408 | 409 | self.candidate_labels = [ 410 | f"Picture of a {cat}." for cat in self.categories 411 | ] 412 | 413 | model_name = config.get( 414 | "model_name", 415 | "apple/aimv2-large-patch14-224-lit" 416 | ) 417 | 418 | from transformers import AutoProcessor, AutoModel 419 | 420 | # Set up device 421 | self.device = "cuda" if torch.cuda.is_available() else "cpu" 422 | 423 | # Initialize model and processor 424 | self.processor = AutoProcessor.from_pretrained( 425 | model_name, 426 | trust_remote_code=True 427 | ) 428 | 429 | self.model = AutoModel.from_pretrained( 430 | model_name, 431 | trust_remote_code=True 432 | ) 433 | 434 | # Move model to appropriate device and set to eval mode 435 | self.model = self.model.to(self.device) 436 | self.model.eval() 437 | 438 | @property 439 | def media_type(self): 440 | """The type of media handled by this model. 441 | 442 | Returns: 443 | str: Always returns 'image' 444 | """ 445 | return "image" 446 | 447 | def _predict(self, image): 448 | """Internal prediction method for a single image. 
449 | 450 | Args: 451 | image (PIL.Image): Input image to classify 452 | 453 | Returns: 454 | fo.Classification: Classification result containing: 455 | - label: Predicted category 456 | - logits: Raw model outputs 457 | - confidence: Prediction confidence score 458 | """ 459 | inputs = self.processor( 460 | text=self.candidate_labels, 461 | images=image, 462 | add_special_tokens=True, 463 | truncation=True, 464 | padding=True, 465 | return_tensors="pt" 466 | ) 467 | 468 | inputs = inputs.to(self.device) 469 | 470 | with torch.no_grad(): 471 | outputs = self.model(**inputs) 472 | 473 | logits_per_image = outputs.logits_per_image 474 | 475 | # Move to CPU only if necessary 476 | if logits_per_image.device.type != 'cpu': 477 | logits_per_image = logits_per_image.cpu() 478 | 479 | probs = logits_per_image.softmax(dim=-1).numpy() 480 | 481 | return fo.Classification( 482 | label=self.categories[probs.argmax()], 483 | logits=logits_per_image.squeeze().numpy(), 484 | confidence=np.amax(probs[0]), 485 | ) 486 | 487 | def predict(self, args): 488 | """Public prediction interface for numpy array input. 489 | 490 | Args: 491 | args (np.ndarray): Input image as numpy array 492 | 493 | Returns: 494 | fo.Classification: Classification result 495 | """ 496 | image = Image.fromarray(args) 497 | predictions = self._predict(image) 498 | return predictions 499 | 500 | def _predict_all(self, images): 501 | """Batch prediction for multiple images. 502 | 503 | Args: 504 | images (list): List of images to classify 505 | 506 | Returns: 507 | list: List of fo.Classification results 508 | """ 509 | return [self._predict(image) for image in images] 510 | 511 | 512 | def AIMV2_activator(): 513 | """Check if required dependencies for AIM-V2 are available.""" 514 | try: 515 | from transformers import AutoProcessor, AutoModel 516 | return True 517 | except ImportError: 518 | return False 519 | 520 | 521 | OPEN_CLIP_MODEL_TYPES = { 522 | "CLIPA": CLIPA_MODELS, 523 | "DFN CLIP": DFN_CLIP_MODELS, 524 | "EVA-CLIP": EVA_CLIP_MODELS, 525 | "MetaCLIP": META_CLIP_MODELS, 526 | "SigLIP": SIGLIP_MODELS, 527 | } 528 | 529 | 530 | def build_classification_models_dict(): 531 | """ 532 | Builds a dictionary of classification models available for use. 533 | 534 | This function constructs a dictionary where each key is a string representing 535 | the name of a classification model type, and the value is a dictionary containing 536 | the following keys: 537 | - "activator": A function that checks if the model's dependencies are available. 538 | - "model": A function that initializes and returns the model. 539 | - "submodels": Additional model configurations or variants, if any. 540 | - "name": The display name of the model. 541 | 542 | Returns: 543 | dict: A dictionary of available classification models. 
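    Example (illustrative sketch)::

        models = build_classification_models_dict()
        names = [entry["name"] for entry in models.values()]
        with_submodels = [
            name for name, entry in models.items()
            if entry["submodels"] is not None
        ]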
544 | """ 545 | cms = {} 546 | 547 | # Add CLIP (OpenAI) if available 548 | if OpenCLIP_activator(): 549 | cms["CLIP (OpenAI)"] = { 550 | "activator": CLIP_activator, 551 | "model": CLIPZeroShotModel, 552 | "submodels": None, 553 | "name": "CLIP (OpenAI)", 554 | } 555 | 556 | # Add ALIGN if available 557 | if Align_activator(): 558 | cms["ALIGN"] = { 559 | "activator": Align_activator, 560 | "model": AlignZeroShotModel, 561 | "submodels": None, 562 | "name": "ALIGN", 563 | } 564 | 565 | # Add AltCLIP if available 566 | if AltCLIP_activator(): 567 | cms["AltCLIP"] = { 568 | "activator": AltCLIP_activator, 569 | "model": AltCLIPZeroShotModel, 570 | "submodels": None, 571 | "name": "AltCLIP", 572 | } 573 | 574 | # Add Apple AIMv2 if available 575 | if AIMV2_activator(): 576 | cms["AIMv2"] = { 577 | "activator": AIMV2_activator, 578 | "model": AIMV2ZeroShotModel, 579 | "submodels": None, 580 | "name": "AIMv2", 581 | } 582 | 583 | # Add OpenCLIP models if available 584 | for key, value in OPEN_CLIP_MODEL_TYPES.items(): 585 | cms[key] = { 586 | "activator": OpenCLIP_activator, 587 | "model": OpenCLIPZeroShotModel, 588 | "submodels": value, 589 | "name": key, 590 | } 591 | 592 | return cms 593 | 594 | CLASSIFICATION_MODELS = build_classification_models_dict() 595 | 596 | 597 | def _get_model(model_name, config): 598 | return CLASSIFICATION_MODELS[model_name]["model"](config) 599 | 600 | 601 | def run_zero_shot_classification( 602 | dataset, 603 | model_name, 604 | label_field, 605 | categories, 606 | architecture=None, 607 | pretrained=None, 608 | **kwargs, 609 | ): 610 | config = { 611 | "categories": categories, 612 | "clip_model": architecture, 613 | "pretrained": pretrained, 614 | } 615 | 616 | model = _get_model(model_name, config) 617 | 618 | dataset.apply_model(model, label_field=label_field) -------------------------------------------------------------------------------- /detection.py: -------------------------------------------------------------------------------- 1 | """Zero Shot Detection. 2 | 3 | | Copyright 2017-2023, Voxel51, Inc. 
4 | | `voxel51.com `_ 5 | | 6 | """ 7 | 8 | from importlib.util import find_spec 9 | import pkg_resources 10 | 11 | from PIL import Image 12 | 13 | import fiftyone as fo 14 | import fiftyone.zoo as foz 15 | from fiftyone.core.models import Model 16 | 17 | YOLO_WORLD_PRETRAINS = ( 18 | "yolov8s-world", 19 | "yolov8s-worldv2", 20 | "yolov8m-world", 21 | "yolov8m-worldv2", 22 | "yolov8l-world", 23 | "yolov8l-worldv2", 24 | "yolov8x-world", 25 | "yolov8x-worldv2", 26 | ) 27 | 28 | 29 | class OwlViTZeroShotModel(Model): 30 | def __init__(self, config): 31 | self.checkpoint = "google/owlvit-base-patch32" 32 | 33 | self.candidate_labels = config.get("categories", None) 34 | 35 | from transformers import pipeline 36 | 37 | self.model = pipeline( 38 | model=self.checkpoint, task="zero-shot-object-detection" 39 | ) 40 | 41 | @property 42 | def media_type(self): 43 | return "image" 44 | 45 | def predict(self, args): 46 | image = Image.fromarray(args) 47 | predictions = self._predict(image) 48 | return predictions 49 | 50 | def _predict(self, image): 51 | raw_predictions = self.model( 52 | image, candidate_labels=self.candidate_labels 53 | ) 54 | 55 | size = image.size 56 | w, h = size[0], size[1] 57 | 58 | detections = [] 59 | for prediction in raw_predictions: 60 | score, box = prediction["score"], prediction["box"] 61 | bounding_box = [ 62 | box["xmin"] / w, 63 | box["ymin"] / h, 64 | box["xmax"] / w, 65 | box["ymax"] / h, 66 | ] 67 | ### constrain bounding box to [0, 1] 68 | bounding_box[0] = max(0, bounding_box[0]) 69 | bounding_box[1] = max(0, bounding_box[1]) 70 | bounding_box[2] = min(1, bounding_box[2]) 71 | bounding_box[3] = min(1, bounding_box[3]) 72 | 73 | ### convert to (x, y, w, h) 74 | bounding_box[2] = bounding_box[2] - bounding_box[0] 75 | bounding_box[3] = bounding_box[3] - bounding_box[1] 76 | 77 | label = prediction["label"] 78 | 79 | detection = fo.Detection( 80 | label=label, 81 | bounding_box=bounding_box, 82 | confidence=score, 83 | ) 84 | detections.append(detection) 85 | 86 | return fo.Detections(detections=detections) 87 | 88 | def predict_all(self, samples, args): 89 | pass 90 | 91 | 92 | def OwlViT_activator(): 93 | return find_spec("transformers") is not None 94 | 95 | 96 | def GroundingDINO(config): 97 | classes = config.get("categories", None) 98 | model = foz.load_zoo_model( 99 | "zero-shot-detection-transformer-torch", 100 | name_or_path="IDEA-Research/grounding-dino-tiny", 101 | classes=classes, 102 | ) 103 | return model 104 | 105 | 106 | def GroundingDINO_activator(): 107 | if find_spec("transformers") is None: 108 | return False 109 | required_version = "4.40.0" 110 | installed_version = pkg_resources.get_distribution("transformers").version 111 | if installed_version < required_version: 112 | return False 113 | 114 | required_fiftyone_version = "0.24.0" 115 | installed_fiftyone_version = pkg_resources.get_distribution( 116 | "fiftyone" 117 | ).version 118 | if installed_fiftyone_version < required_fiftyone_version: 119 | return False 120 | 121 | return True 122 | 123 | 124 | def YOLOWorldModel(config): 125 | classes = config.get("categories", None) 126 | pretrained = config.get("pretrained", "yolov8l-world") 127 | if "v2" in pretrained: 128 | from ultralytics import YOLO 129 | 130 | model = YOLO(pretrained + ".pt") 131 | model.set_classes(classes) 132 | import fiftyone.utils.ultralytics as fouu 133 | 134 | model = fouu.convert_ultralytics_model(model) 135 | else: 136 | model = foz.load_zoo_model(pretrained + "-torch", classes=classes) 137 | return model 138 | 139 
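# Illustrative usage sketch (not part of the plugin's control flow). It assumes
# `ultralytics>=8.1.42` is installed and that `dataset` is an existing FiftyOne
# dataset; the classes, field name, and threshold below are arbitrary placeholders:
#
#   model = YOLOWorldModel(
#       {"categories": ["person", "bicycle"], "pretrained": "yolov8l-worldv2"}
#   )
#   dataset.apply_model(model, label_field="yolo_world", confidence_thresh=0.2)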
def YOLOWorld_activator():
    if find_spec("ultralytics") is None:
        return False

    # Compare parsed versions rather than raw strings
    required_version = pkg_resources.parse_version("8.1.42")
    installed_version = pkg_resources.parse_version(
        pkg_resources.get_distribution("ultralytics").version
    )
    return installed_version >= required_version


def build_detection_models_dict():
    dms = {}

    if YOLOWorld_activator():
        dms["YOLO-World"] = {
            "activator": YOLOWorld_activator,
            "model": YOLOWorldModel,
            "submodels": YOLO_WORLD_PRETRAINS,
            "name": "YOLO-World",
        }

    if OwlViT_activator():
        dms["OwlViT"] = {
            "activator": OwlViT_activator,
            "model": OwlViTZeroShotModel,
            "submodels": None,
            "name": "OwlViT",
        }

    if GroundingDINO_activator():
        dms["GroundingDINO"] = {
            "activator": GroundingDINO_activator,
            "model": GroundingDINO,
            "submodels": None,
            "name": "GroundingDINO",
        }

    return dms


DETECTION_MODELS = build_detection_models_dict()


def _get_model(model_name, config):
    return DETECTION_MODELS[model_name]["model"](config)


def run_zero_shot_detection(
    dataset, model_name, label_field, categories, pretrained=None, **kwargs
):
    confidence = kwargs.get("confidence", 0.2)
    config = {"categories": categories, "pretrained": pretrained}
    model = _get_model(model_name, config)
    dataset.apply_model(
        model, label_field=label_field, confidence_thresh=confidence
    )
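
# Illustrative usage of the entry point above, assuming a loaded FiftyOne
# dataset; the field name and class list are arbitrary:
#
#     run_zero_shot_detection(
#         dataset,
#         "YOLO-World",
#         "zero_shot_predictions",
#         ["person", "car"],
#         pretrained="yolov8l-world",
#         confidence=0.3,
#     )
#
# The optional `confidence` kwarg (default 0.2) is forwarded to
# `dataset.apply_model()` as `confidence_thresh`.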
--------------------------------------------------------------------------------
/fiftyone.yml:
--------------------------------------------------------------------------------
fiftyone:
  version: ">=0.24.0"
name: "@jacobmarks/zero_shot_prediction"
version: "1.3.3"
description: "Run zero-shot (open vocabulary) prediction on your data!"
url: "https://github.com/jacobmarks/zero-shot-prediction-plugin/"
operators:
  - zero_shot_predict
  - zero_shot_classify
  - zero_shot_detect
  - zero_shot_instance_segment
  - zero_shot_semantic_segment

--------------------------------------------------------------------------------
/instance_segmentation.py:
--------------------------------------------------------------------------------
"""Zero Shot Instance Segmentation.

| Copyright 2017-2023, Voxel51, Inc.
| `voxel51.com <https://voxel51.com/>`_
|
"""

from importlib.util import find_spec
import os


from fiftyone.core.utils import add_sys_path
import fiftyone.zoo as foz

SAM_ARCHS = ("ViT-B", "ViT-H", "ViT-L")
SAM_MODELS = [("", SA) for SA in SAM_ARCHS]


def SAM_activator():
    return True


def build_instance_segmentation_models_dict():
    sms = {}

    if SAM_activator():
        sms["SAM"] = {
            "activator": SAM_activator,
            "model": "N/A",
            "name": "SAM",
            "submodels": SAM_MODELS,
        }

    return sms


INSTANCE_SEGMENTATION_MODELS = build_instance_segmentation_models_dict()


def _get_segmentation_model(architecture):
    zoo_model_name = (
        "segment-anything-" + architecture.lower().replace("-", "") + "-torch"
    )
    return foz.load_zoo_model(zoo_model_name)


def run_zero_shot_instance_segmentation(
    dataset,
    model_name,
    label_field,
    categories,
    pretrained=None,
    architecture=None,
    **kwargs
):
    with add_sys_path(os.path.dirname(os.path.abspath(__file__))):
        # pylint: disable=no-name-in-module,import-error
        from detection import run_zero_shot_detection

    det_model_name, _ = model_name.split(" + ")
    det_pretrained, _ = pretrained.split(" + ")
    if det_pretrained == "":
        det_pretrained = None
    _, seg_architecture = architecture.split(" + ")

    run_zero_shot_detection(
        dataset,
        det_model_name,
        label_field,
        categories,
        pretrained=det_pretrained,
        **kwargs
    )

    seg_model = _get_segmentation_model(seg_architecture)
    dataset.apply_model(
        seg_model, label_field=label_field, prompt_field=label_field
    )
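
# Note on the composite strings: the split(" + ") calls above imply that
# `model_name`, `pretrained`, and `architecture` arrive as single
# "<detector value> + <SAM value>" strings, detector first. The call below is
# an illustrative sketch; the exact composite values are assumptions, not the
# plugin's canonical strings:
#
#     run_zero_shot_instance_segmentation(
#         dataset,
#         "OwlViT + SAM",
#         "zero_shot_instances",
#         ["person", "car"],
#         pretrained=" + ViT-B",
#         architecture=" + ViT-B",
#     )
#
# The detection stage writes boxes to `label_field`; SAM is then applied with
# `prompt_field=label_field` to convert those boxes into instance masks.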
--------------------------------------------------------------------------------
/semantic_segmentation.py:
--------------------------------------------------------------------------------
"""Zero Shot Semantic Segmentation.

| Copyright 2017-2023, Voxel51, Inc.
| `voxel51.com <https://voxel51.com/>`_
|
"""

from importlib.util import find_spec
from PIL import Image
import torch

import fiftyone as fo
from fiftyone.core.models import Model


class CLIPSegZeroShotModel(Model):
    def __init__(self, config):
        self.candidate_labels = config.get("categories", None)

        from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation

        self.processor = CLIPSegProcessor.from_pretrained(
            "CIDAS/clipseg-rd64-refined"
        )
        self.model = CLIPSegForImageSegmentation.from_pretrained(
            "CIDAS/clipseg-rd64-refined"
        )

    @property
    def media_type(self):
        return "image"

    def _predict(self, image):
        inputs = self.processor(
            text=self.candidate_labels,
            images=[image] * len(self.candidate_labels),
            padding="max_length",
            return_tensors="pt",
        )
        with torch.no_grad():
            outputs = self.model(**inputs)
        preds = outputs.logits.unsqueeze(1)
        # pylint: disable=no-member
        mask = torch.argmax(preds, dim=0).squeeze().numpy()
        return fo.Segmentation(mask=mask)

    def predict(self, args):
        image = Image.fromarray(args)
        predictions = self._predict(image)
        return predictions

    def predict_all(self, samples, args):
        pass


def CLIPSeg_activator():
    return find_spec("transformers") is not None


class GroupViTZeroShotModel(Model):
    def __init__(self, config):
        cats = config.get("categories", None)
        self.candidate_labels = [f"a photo of a {cat}" for cat in cats]

        from transformers import AutoProcessor, GroupViTModel

        self.processor = AutoProcessor.from_pretrained(
            "nvidia/groupvit-gcc-yfcc"
        )
        self.model = GroupViTModel.from_pretrained("nvidia/groupvit-gcc-yfcc")

    @property
    def media_type(self):
        return "image"

    def _predict(self, image):
        inputs = self.processor(
            text=self.candidate_labels,
            images=image,
            padding="max_length",
            return_tensors="pt",
        )
        with torch.no_grad():
            outputs = self.model(**inputs, output_segmentation=True)
        preds = outputs.segmentation_logits.squeeze()
        # pylint: disable=no-member
        mask = torch.argmax(preds, dim=0).numpy()
        return fo.Segmentation(mask=mask)

    def predict(self, args):
        image = Image.fromarray(args)
        image = image.resize((224, 224))
        predictions = self._predict(image)
        return predictions

    def predict_all(self, samples, args):
        pass


def GroupViT_activator():
    return find_spec("transformers") is not None


SEMANTIC_SEGMENTATION_MODELS = {
    "CLIPSeg": {
        "activator": CLIPSeg_activator,
        "model": CLIPSegZeroShotModel,
        "name": "CLIPSeg",
    },
    "GroupViT": {
        "activator": GroupViT_activator,
        "model": GroupViTZeroShotModel,
        "name": "GroupViT",
    },
}


def _get_model(model_name, config):
    return SEMANTIC_SEGMENTATION_MODELS[model_name]["model"](config)


def run_zero_shot_semantic_segmentation(
    dataset, model_name, label_field, categories, **kwargs
):
    if "other" not in categories:
        categories.append("other")
    config = {"categories": categories}
    model = _get_model(model_name, config)
    dataset.apply_model(model, label_field=label_field)

    dataset.mask_targets[label_field] = {
        i: label for i, label in enumerate(categories)
    }
    dataset.save()
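
# Illustrative usage of the entry point above, assuming a loaded FiftyOne
# dataset; the field name and class list are arbitrary:
#
#     run_zero_shot_semantic_segmentation(
#         dataset,
#         "CLIPSeg",
#         "zero_shot_semantic",
#         ["road", "building", "sky"],
#     )
#
# An "other" class is appended automatically, and `dataset.mask_targets` is
# updated so that mask value i maps back to `categories[i]` in the App.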
--------------------------------------------------------------------------------