├── test.py ├── utils ├── __init__.py ├── modes.py ├── video.py ├── florence.py └── sam.py ├── test ├── 01_source.png ├── 02_source.png ├── 01_output_mask.png ├── 02_output_mask.png ├── 01_bounding_boxes.jpeg ├── 02_bounding_boxes.jpeg ├── 01_output_masked_image.png ├── 02_output_masked_image.png └── README.md ├── workflows ├── workflow.png └── workflow.json ├── videos ├── clip-07-camera-1.mp4 ├── clip-07-camera-2.mp4 └── clip-07-camera-3.mp4 ├── models └── sam2 │ ├── sam2_hiera_large.pt │ ├── sam2_hiera_small.pt │ ├── sam2_hiera_tiny.pt │ └── sam2_hiera_base_plus.pt ├── .gitignore ├── configs ├── __init__.py ├── sam2_hiera_b+.yaml ├── sam2_hiera_s.yaml ├── sam2_hiera_l.yaml └── sam2_hiera_t.yaml ├── requirements.txt ├── README.md ├── .gitattributes ├── __init__.py ├── LICENSE └── app.py /test.py: -------------------------------------------------------------------------------- 1 | __init__.py -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /test/01_source.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rdancer/ComfyUI_Florence2SAM2/HEAD/test/01_source.png -------------------------------------------------------------------------------- /test/02_source.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rdancer/ComfyUI_Florence2SAM2/HEAD/test/02_source.png -------------------------------------------------------------------------------- /test/01_output_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rdancer/ComfyUI_Florence2SAM2/HEAD/test/01_output_mask.png -------------------------------------------------------------------------------- /test/02_output_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rdancer/ComfyUI_Florence2SAM2/HEAD/test/02_output_mask.png -------------------------------------------------------------------------------- /workflows/workflow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rdancer/ComfyUI_Florence2SAM2/HEAD/workflows/workflow.png -------------------------------------------------------------------------------- /test/01_bounding_boxes.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rdancer/ComfyUI_Florence2SAM2/HEAD/test/01_bounding_boxes.jpeg -------------------------------------------------------------------------------- /test/02_bounding_boxes.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rdancer/ComfyUI_Florence2SAM2/HEAD/test/02_bounding_boxes.jpeg -------------------------------------------------------------------------------- /test/01_output_masked_image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rdancer/ComfyUI_Florence2SAM2/HEAD/test/01_output_masked_image.png -------------------------------------------------------------------------------- /test/02_output_masked_image.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/rdancer/ComfyUI_Florence2SAM2/HEAD/test/02_output_masked_image.png -------------------------------------------------------------------------------- /videos/clip-07-camera-1.mp4: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:7944c1a5e9be241ebf11eb39f6302c3ce9d8482ca9f12e4268b252aeda6baee9 3 | size 5500081 4 | -------------------------------------------------------------------------------- /videos/clip-07-camera-2.mp4: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:abbfef6d422c9aa3968d14de6b78aecaf544c85423d401387e3d5e75ffee3497 3 | size 5467189 4 | -------------------------------------------------------------------------------- /videos/clip-07-camera-3.mp4: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:e518f2ee6761d559bc864be2fec70ddc41244fbf3fea404c3158129a434ce879 3 | size 5397505 4 | -------------------------------------------------------------------------------- /models/sam2/sam2_hiera_large.pt: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:7442e4e9b732a508f80e141e7c2913437a3610ee0c77381a66658c3a445df87b 3 | size 897952466 4 | -------------------------------------------------------------------------------- /models/sam2/sam2_hiera_small.pt: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:95949964d4e548409021d47b22712d5f1abf2564cc0c3c765ba599a24ac7dce3 3 | size 184309650 4 | -------------------------------------------------------------------------------- /models/sam2/sam2_hiera_tiny.pt: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:65b50056e05bcb13694174f51bb6da89c894b57b75ccdf0ba6352c597c5d1125 3 | size 155906050 4 | -------------------------------------------------------------------------------- /models/sam2/sam2_hiera_base_plus.pt: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:d0bb7f236400a49669ffdd1be617959a8b1d1065081789d7bbff88eded3a8071 3 | size 323493298 4 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /venv 2 | /.idea 3 | /tmp 4 | 5 | __pycache__/ 6 | 7 | .*.sw? 8 | *~ 9 | 10 | # something on my MacOS keeps saving these resource forks for every single file --rdancer 11 | .* 12 | -------------------------------------------------------------------------------- /configs/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tqdm 2 | einops 3 | spaces 4 | timm 5 | transformers 6 | gradio 7 | supervision 8 | opencv-python 9 | pytest 10 | 11 | # Build dependency; getting errors so try putting it here 12 | flit_core>=3.2,<4 13 | #samv2 14 | # Rdancer's version of samv2 includes a fix that makes it not crash on Windows 15 | git+https://github.com/rdancer/samv2.git 16 | -------------------------------------------------------------------------------- /utils/modes.py: -------------------------------------------------------------------------------- 1 | IMAGE_OPEN_VOCABULARY_DETECTION_MODE = "open vocabulary detection + image masks" 2 | IMAGE_CAPTION_GROUNDING_MASKS_MODE = "caption + grounding + image masks" 3 | 4 | IMAGE_INFERENCE_MODES = [ 5 | IMAGE_OPEN_VOCABULARY_DETECTION_MODE, 6 | IMAGE_CAPTION_GROUNDING_MASKS_MODE 7 | ] 8 | 9 | VIDEO_OPEN_VOCABULARY_DETECTION_MODE = "open vocabulary detection + video masks" 10 | 11 | VIDEO_INFERENCE_MODES = [ 12 | VIDEO_OPEN_VOCABULARY_DETECTION_MODE 13 | ] 14 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ComfyUI custom node implementing Florence 2 + Segment Anything Model 2, based on [SkalskiP's HuggingFace space](https://huggingface.co/spaces/SkalskiP/florence-sam) 2 | 3 | ![sample workflow](workflows/workflow.png) 4 | 5 | *RdancerFlorence2SAM2GenerateMask* - the node is self-contained, and does not require separate model loaders. Models are lazy-loaded, and cached. Model unloading, if required, must be done manually. 6 | 7 | ## Testing 8 | 9 | Run `python test.py test/*_source.png "products"` 👉 the resulting images must pixel-match `test/*output*`. 10 | -------------------------------------------------------------------------------- /utils/video.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import os 3 | import shutil 4 | import uuid 5 | 6 | 7 | def create_directory(directory_path: str) -> None: 8 | if not os.path.exists(directory_path): 9 | os.makedirs(directory_path) 10 | 11 | 12 | def delete_directory(directory_path: str) -> None: 13 | if not os.path.exists(directory_path): 14 | raise FileNotFoundError(f"Directory '{directory_path}' does not exist.") 15 | 16 | try: 17 | shutil.rmtree(directory_path) 18 | except PermissionError: 19 | raise PermissionError( 20 | f"Permission denied: Unable to delete '{directory_path}'.") 21 | 22 | 23 | def generate_unique_name(): 24 | current_datetime = datetime.datetime.now().strftime("%Y%m%d%H%M%S") 25 | unique_id = uuid.uuid4() 26 | return f"{current_datetime}_{unique_id}" 27 | -------------------------------------------------------------------------------- /test/README.md: -------------------------------------------------------------------------------- 1 | 2 | - 1 image goes in and has a prompt "products" and then out comes a mask with all the segmented products in the image 3 | - But it could output a batch or list (I will probably just end up combining them all in the next node) 4 | 5 | Note: Instead of bounding boxes with labels we **just need the masks** that's it 6 | 7 | 8 | 9 | 10 | 19 | 20 | 29 | 30 |
| source image | screenshot (left: source; right: result) |
| --- | --- |
| ![01](01_source.png) | ![01](01_bounding_boxes.jpeg) |
| ![02](02_source.png) | ![02](02_bounding_boxes.jpeg) |
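If the node ends up returning a batch of per-object masks rather than one combined mask, the union can be taken in the next node. A minimal sketch, assuming the masks arrive as an `[N, H, W]` torch tensor with values in 0..1:

```python
import torch

def combine_masks(masks: torch.Tensor) -> torch.Tensor:
    """Union of N per-object masks [N, H, W] into a single [H, W] mask."""
    return masks.amax(dim=0)
```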
31 | 
32 | The test images have been confirmed to work with this version of the code, and the screenshots above were taken from the [HuggingFace space](https://huggingface.co/spaces/SkalskiP/florence-sam).
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | *.7z filter=lfs diff=lfs merge=lfs -text
2 | *.arrow filter=lfs diff=lfs merge=lfs -text
3 | *.bin filter=lfs diff=lfs merge=lfs -text
4 | *.bz2 filter=lfs diff=lfs merge=lfs -text
5 | *.ckpt filter=lfs diff=lfs merge=lfs -text
6 | *.ftz filter=lfs diff=lfs merge=lfs -text
7 | *.gz filter=lfs diff=lfs merge=lfs -text
8 | *.h5 filter=lfs diff=lfs merge=lfs -text
9 | *.joblib filter=lfs diff=lfs merge=lfs -text
10 | *.lfs.* filter=lfs diff=lfs merge=lfs -text
11 | *.mlmodel filter=lfs diff=lfs merge=lfs -text
12 | *.model filter=lfs diff=lfs merge=lfs -text
13 | *.msgpack filter=lfs diff=lfs merge=lfs -text
14 | *.npy filter=lfs diff=lfs merge=lfs -text
15 | *.npz filter=lfs diff=lfs merge=lfs -text
16 | *.onnx filter=lfs diff=lfs merge=lfs -text
17 | *.ot filter=lfs diff=lfs merge=lfs -text
18 | *.parquet filter=lfs diff=lfs merge=lfs -text
19 | *.pb filter=lfs diff=lfs merge=lfs -text
20 | *.pickle filter=lfs diff=lfs merge=lfs -text
21 | *.pkl filter=lfs diff=lfs merge=lfs -text
22 | *.pt filter=lfs diff=lfs merge=lfs -text
23 | *.pth filter=lfs diff=lfs merge=lfs -text
24 | *.rar filter=lfs diff=lfs merge=lfs -text
25 | *.safetensors filter=lfs diff=lfs merge=lfs -text
26 | saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27 | *.tar.* filter=lfs diff=lfs merge=lfs -text
28 | *.tar filter=lfs diff=lfs merge=lfs -text
29 | *.tflite filter=lfs diff=lfs merge=lfs -text
30 | *.tgz filter=lfs diff=lfs merge=lfs -text
31 | *.wasm filter=lfs diff=lfs merge=lfs -text
32 | *.xz filter=lfs diff=lfs merge=lfs -text
33 | *.zip filter=lfs diff=lfs merge=lfs -text
34 | *.zst filter=lfs diff=lfs merge=lfs -text
35 | *tfevents* filter=lfs diff=lfs merge=lfs -text
36 | *.mp4 filter=lfs diff=lfs merge=lfs -text
--------------------------------------------------------------------------------
/utils/florence.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Union, Any, Tuple, Dict
3 | from unittest.mock import patch
4 | 
5 | import torch
6 | from PIL import Image
7 | from transformers import AutoModelForCausalLM, AutoProcessor
8 | from transformers.dynamic_module_utils import get_imports
9 | 
10 | FLORENCE_CHECKPOINT = "microsoft/Florence-2-base"
11 | FLORENCE_OBJECT_DETECTION_TASK = '<OD>'
12 | FLORENCE_DETAILED_CAPTION_TASK = '<MORE_DETAILED_CAPTION>'
13 | FLORENCE_CAPTION_TO_PHRASE_GROUNDING_TASK = '<CAPTION_TO_PHRASE_GROUNDING>'
14 | FLORENCE_OPEN_VOCABULARY_DETECTION_TASK = '<OPEN_VOCABULARY_DETECTION>'
15 | FLORENCE_DENSE_REGION_CAPTION_TASK = '<DENSE_REGION_CAPTION>'
16 | 
17 | 
18 | def fixed_get_imports(filename: Union[str, os.PathLike]) -> list[str]:
19 |     """Work around for https://huggingface.co/microsoft/phi-1_5/discussions/72."""
20 |     if not str(filename).endswith("/modeling_florence2.py"):
21 |         return get_imports(filename)
22 |     imports = get_imports(filename)
23 |     try:
24 |         imports.remove("flash_attn")
25 |     except ValueError:
26 |         pass
27 |     return imports
28 | 
29 | 
30 | def load_florence_model(
31 |     device: torch.device, checkpoint: str = FLORENCE_CHECKPOINT
32 | ) -> Tuple[Any, Any]:
33 |     with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
34 |         model = AutoModelForCausalLM.from_pretrained(
35 |             checkpoint,
trust_remote_code=True).to(device).eval() 36 | processor = AutoProcessor.from_pretrained( 37 | checkpoint, trust_remote_code=True) 38 | return model, processor 39 | 40 | 41 | def run_florence_inference( 42 | model: Any, 43 | processor: Any, 44 | device: torch.device, 45 | image: Image, 46 | task: str, 47 | text: str = "" 48 | ) -> Tuple[str, Dict]: 49 | prompt = task + text 50 | inputs = processor(text=prompt, images=image, return_tensors="pt").to(device) 51 | generated_ids = model.generate( 52 | input_ids=inputs["input_ids"], 53 | pixel_values=inputs["pixel_values"], 54 | max_new_tokens=1024, 55 | num_beams=3 56 | ) 57 | generated_text = processor.batch_decode( 58 | generated_ids, skip_special_tokens=False)[0] 59 | response = processor.post_process_generation( 60 | generated_text, task=task, image_size=image.size) 61 | return generated_text, response 62 | -------------------------------------------------------------------------------- /utils/sam.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | import os 3 | 4 | import folder_paths 5 | import numpy as np 6 | import supervision as sv 7 | import torch 8 | from PIL import Image 9 | from sam2.build_sam import build_sam2, build_sam2_video_predictor 10 | from sam2.sam2_image_predictor import SAM2ImagePredictor 11 | 12 | model_to_config_map = { 13 | # models: sam2_hiera_base_plus.pt sam2_hiera_large.pt sam2_hiera_small.pt sam2_hiera_tiny.pt 14 | # configs: sam2_hiera_b+.yaml sam2_hiera_l.yaml sam2_hiera_s.yaml sam2_hiera_t.yaml 15 | "sam2_hiera_base_plus.pt": "sam2_hiera_b+.yaml", 16 | "sam2_hiera_large.pt": "sam2_hiera_l.yaml", 17 | "sam2_hiera_small.pt": "sam2_hiera_s.yaml", 18 | "sam2_hiera_tiny.pt": "sam2_hiera_t.yaml", 19 | } 20 | SAM_CHECKPOINT = "sam2_hiera_small.pt" 21 | SAM_CONFIG = "sam2_hiera_s.yaml" # from /usr/local/lib/python3.10/dist-packages/sam2/configs, *not* from either the models directory, or this package's directory 22 | 23 | def load_sam_image_model( 24 | device: torch.device, 25 | checkpoint: str = SAM_CHECKPOINT, 26 | config: str = None 27 | ) -> SAM2ImagePredictor: 28 | if config is None: 29 | config = model_to_config_map[checkpoint] 30 | import os 31 | 32 | # 1. Print the current working directory with flush=True 33 | current_working_directory = os.getcwd() 34 | print(f"Current working directory: {current_working_directory}", flush=True) 35 | 36 | # 2. 
Check if the "models" and "models/sam2" directories exist 37 | models_dir = folder_paths.models_dir 38 | sam2_dir = os.path.join(models_dir, "sam2") 39 | 40 | if os.path.exists(models_dir): 41 | print(f"'models' directory exists: {models_dir}", flush=True) 42 | else: 43 | print(f"'models' directory does not exist: {models_dir}", flush=True) 44 | 45 | if os.path.exists(sam2_dir): 46 | print(f"'models/sam2' directory exists: {sam2_dir}", flush=True) 47 | else: 48 | print(f"'models/sam2' directory does not exist: {sam2_dir}", flush=True) 49 | 50 | model_path = os.path.join(sam2_dir, checkpoint) 51 | if os.path.exists(model_path): 52 | print(f"'models/sam2/{checkpoint}' file exists: {model_path}", flush=True) 53 | else: 54 | print(f"'models/sam2/{checkpoint}' file does not exist: {model_path}", flush=True) 55 | 56 | model = build_sam2(config, model_path, device=device) 57 | return SAM2ImagePredictor(sam_model=model) 58 | 59 | 60 | def load_sam_video_model( 61 | device: torch.device, 62 | config: str = SAM_CONFIG, 63 | checkpoint: str = SAM_CHECKPOINT 64 | ) -> Any: 65 | return build_sam2_video_predictor(config, checkpoint, device=device) 66 | 67 | 68 | def run_sam_inference( 69 | model: Any, 70 | image: Image, 71 | detections: sv.Detections 72 | ) -> sv.Detections: 73 | image = np.array(image.convert("RGB")) 74 | model.set_image(image) 75 | mask, score, _ = model.predict(box=detections.xyxy, multimask_output=False) 76 | 77 | # dirty fix; remove this later 78 | if len(mask.shape) == 4: 79 | mask = np.squeeze(mask) 80 | 81 | detections.mask = mask.astype(bool) 82 | return detections 83 | -------------------------------------------------------------------------------- /configs/sam2_hiera_b+.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Model 4 | model: 5 | _target_: sam2.modeling.sam2_base.SAM2Base 6 | image_encoder: 7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder 8 | scalp: 1 9 | trunk: 10 | _target_: sam2.modeling.backbones.hieradet.Hiera 11 | embed_dim: 112 12 | num_heads: 2 13 | neck: 14 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck 15 | position_encoding: 16 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 17 | num_pos_feats: 256 18 | normalize: true 19 | scale: null 20 | temperature: 10000 21 | d_model: 256 22 | backbone_channel_list: [896, 448, 224, 112] 23 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features 24 | fpn_interp_model: nearest 25 | 26 | memory_attention: 27 | _target_: sam2.modeling.memory_attention.MemoryAttention 28 | d_model: 256 29 | pos_enc_at_input: true 30 | layer: 31 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer 32 | activation: relu 33 | dim_feedforward: 2048 34 | dropout: 0.1 35 | pos_enc_at_attn: false 36 | self_attention: 37 | _target_: sam2.modeling.sam.transformer.RoPEAttention 38 | rope_theta: 10000.0 39 | feat_sizes: [32, 32] 40 | embedding_dim: 256 41 | num_heads: 1 42 | downsample_rate: 1 43 | dropout: 0.1 44 | d_model: 256 45 | pos_enc_at_cross_attn_keys: true 46 | pos_enc_at_cross_attn_queries: false 47 | cross_attention: 48 | _target_: sam2.modeling.sam.transformer.RoPEAttention 49 | rope_theta: 10000.0 50 | feat_sizes: [32, 32] 51 | rope_k_repeat: True 52 | embedding_dim: 256 53 | num_heads: 1 54 | downsample_rate: 1 55 | dropout: 0.1 56 | kv_in_dim: 64 57 | num_layers: 4 58 | 59 | memory_encoder: 60 | _target_: sam2.modeling.memory_encoder.MemoryEncoder 61 | out_dim: 64 62 | 
position_encoding: 63 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 64 | num_pos_feats: 64 65 | normalize: true 66 | scale: null 67 | temperature: 10000 68 | mask_downsampler: 69 | _target_: sam2.modeling.memory_encoder.MaskDownSampler 70 | kernel_size: 3 71 | stride: 2 72 | padding: 1 73 | fuser: 74 | _target_: sam2.modeling.memory_encoder.Fuser 75 | layer: 76 | _target_: sam2.modeling.memory_encoder.CXBlock 77 | dim: 256 78 | kernel_size: 7 79 | padding: 3 80 | layer_scale_init_value: 1e-6 81 | use_dwconv: True # depth-wise convs 82 | num_layers: 2 83 | 84 | num_maskmem: 7 85 | image_size: 1024 86 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask 87 | sigmoid_scale_for_mem_enc: 20.0 88 | sigmoid_bias_for_mem_enc: -10.0 89 | use_mask_input_as_output_without_sam: true 90 | # Memory 91 | directly_add_no_mem_embed: true 92 | # use high-resolution feature map in the SAM mask decoder 93 | use_high_res_features_in_sam: true 94 | # output 3 masks on the first click on initial conditioning frames 95 | multimask_output_in_sam: true 96 | # SAM heads 97 | iou_prediction_use_sigmoid: True 98 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder 99 | use_obj_ptrs_in_encoder: true 100 | add_tpos_enc_to_obj_ptrs: false 101 | only_obj_ptrs_in_the_past_for_eval: true 102 | # object occlusion prediction 103 | pred_obj_scores: true 104 | pred_obj_scores_mlp: true 105 | fixed_no_obj_ptr: true 106 | # multimask tracking settings 107 | multimask_output_for_tracking: true 108 | use_multimask_token_for_obj_ptr: true 109 | multimask_min_pt_num: 0 110 | multimask_max_pt_num: 1 111 | use_mlp_for_obj_ptr_proj: true 112 | # Compilation flag 113 | compile_image_encoder: False 114 | -------------------------------------------------------------------------------- /configs/sam2_hiera_s.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Model 4 | model: 5 | _target_: sam2.modeling.sam2_base.SAM2Base 6 | image_encoder: 7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder 8 | scalp: 1 9 | trunk: 10 | _target_: sam2.modeling.backbones.hieradet.Hiera 11 | embed_dim: 96 12 | num_heads: 1 13 | stages: [1, 2, 11, 2] 14 | global_att_blocks: [7, 10, 13] 15 | window_pos_embed_bkg_spatial_size: [7, 7] 16 | neck: 17 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck 18 | position_encoding: 19 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 20 | num_pos_feats: 256 21 | normalize: true 22 | scale: null 23 | temperature: 10000 24 | d_model: 256 25 | backbone_channel_list: [768, 384, 192, 96] 26 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features 27 | fpn_interp_model: nearest 28 | 29 | memory_attention: 30 | _target_: sam2.modeling.memory_attention.MemoryAttention 31 | d_model: 256 32 | pos_enc_at_input: true 33 | layer: 34 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer 35 | activation: relu 36 | dim_feedforward: 2048 37 | dropout: 0.1 38 | pos_enc_at_attn: false 39 | self_attention: 40 | _target_: sam2.modeling.sam.transformer.RoPEAttention 41 | rope_theta: 10000.0 42 | feat_sizes: [32, 32] 43 | embedding_dim: 256 44 | num_heads: 1 45 | downsample_rate: 1 46 | dropout: 0.1 47 | d_model: 256 48 | pos_enc_at_cross_attn_keys: true 49 | pos_enc_at_cross_attn_queries: false 50 | cross_attention: 51 | _target_: sam2.modeling.sam.transformer.RoPEAttention 52 | 
rope_theta: 10000.0 53 | feat_sizes: [32, 32] 54 | rope_k_repeat: True 55 | embedding_dim: 256 56 | num_heads: 1 57 | downsample_rate: 1 58 | dropout: 0.1 59 | kv_in_dim: 64 60 | num_layers: 4 61 | 62 | memory_encoder: 63 | _target_: sam2.modeling.memory_encoder.MemoryEncoder 64 | out_dim: 64 65 | position_encoding: 66 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 67 | num_pos_feats: 64 68 | normalize: true 69 | scale: null 70 | temperature: 10000 71 | mask_downsampler: 72 | _target_: sam2.modeling.memory_encoder.MaskDownSampler 73 | kernel_size: 3 74 | stride: 2 75 | padding: 1 76 | fuser: 77 | _target_: sam2.modeling.memory_encoder.Fuser 78 | layer: 79 | _target_: sam2.modeling.memory_encoder.CXBlock 80 | dim: 256 81 | kernel_size: 7 82 | padding: 3 83 | layer_scale_init_value: 1e-6 84 | use_dwconv: True # depth-wise convs 85 | num_layers: 2 86 | 87 | num_maskmem: 7 88 | image_size: 1024 89 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask 90 | sigmoid_scale_for_mem_enc: 20.0 91 | sigmoid_bias_for_mem_enc: -10.0 92 | use_mask_input_as_output_without_sam: true 93 | # Memory 94 | directly_add_no_mem_embed: true 95 | # use high-resolution feature map in the SAM mask decoder 96 | use_high_res_features_in_sam: true 97 | # output 3 masks on the first click on initial conditioning frames 98 | multimask_output_in_sam: true 99 | # SAM heads 100 | iou_prediction_use_sigmoid: True 101 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder 102 | use_obj_ptrs_in_encoder: true 103 | add_tpos_enc_to_obj_ptrs: false 104 | only_obj_ptrs_in_the_past_for_eval: true 105 | # object occlusion prediction 106 | pred_obj_scores: true 107 | pred_obj_scores_mlp: true 108 | fixed_no_obj_ptr: true 109 | # multimask tracking settings 110 | multimask_output_for_tracking: true 111 | use_multimask_token_for_obj_ptr: true 112 | multimask_min_pt_num: 0 113 | multimask_max_pt_num: 1 114 | use_mlp_for_obj_ptr_proj: true 115 | # Compilation flag 116 | compile_image_encoder: False 117 | -------------------------------------------------------------------------------- /configs/sam2_hiera_l.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Model 4 | model: 5 | _target_: sam2.modeling.sam2_base.SAM2Base 6 | image_encoder: 7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder 8 | scalp: 1 9 | trunk: 10 | _target_: sam2.modeling.backbones.hieradet.Hiera 11 | embed_dim: 144 12 | num_heads: 2 13 | stages: [2, 6, 36, 4] 14 | global_att_blocks: [23, 33, 43] 15 | window_pos_embed_bkg_spatial_size: [7, 7] 16 | window_spec: [8, 4, 16, 8] 17 | neck: 18 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck 19 | position_encoding: 20 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 21 | num_pos_feats: 256 22 | normalize: true 23 | scale: null 24 | temperature: 10000 25 | d_model: 256 26 | backbone_channel_list: [1152, 576, 288, 144] 27 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features 28 | fpn_interp_model: nearest 29 | 30 | memory_attention: 31 | _target_: sam2.modeling.memory_attention.MemoryAttention 32 | d_model: 256 33 | pos_enc_at_input: true 34 | layer: 35 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer 36 | activation: relu 37 | dim_feedforward: 2048 38 | dropout: 0.1 39 | pos_enc_at_attn: false 40 | self_attention: 41 | _target_: 
sam2.modeling.sam.transformer.RoPEAttention 42 | rope_theta: 10000.0 43 | feat_sizes: [32, 32] 44 | embedding_dim: 256 45 | num_heads: 1 46 | downsample_rate: 1 47 | dropout: 0.1 48 | d_model: 256 49 | pos_enc_at_cross_attn_keys: true 50 | pos_enc_at_cross_attn_queries: false 51 | cross_attention: 52 | _target_: sam2.modeling.sam.transformer.RoPEAttention 53 | rope_theta: 10000.0 54 | feat_sizes: [32, 32] 55 | rope_k_repeat: True 56 | embedding_dim: 256 57 | num_heads: 1 58 | downsample_rate: 1 59 | dropout: 0.1 60 | kv_in_dim: 64 61 | num_layers: 4 62 | 63 | memory_encoder: 64 | _target_: sam2.modeling.memory_encoder.MemoryEncoder 65 | out_dim: 64 66 | position_encoding: 67 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 68 | num_pos_feats: 64 69 | normalize: true 70 | scale: null 71 | temperature: 10000 72 | mask_downsampler: 73 | _target_: sam2.modeling.memory_encoder.MaskDownSampler 74 | kernel_size: 3 75 | stride: 2 76 | padding: 1 77 | fuser: 78 | _target_: sam2.modeling.memory_encoder.Fuser 79 | layer: 80 | _target_: sam2.modeling.memory_encoder.CXBlock 81 | dim: 256 82 | kernel_size: 7 83 | padding: 3 84 | layer_scale_init_value: 1e-6 85 | use_dwconv: True # depth-wise convs 86 | num_layers: 2 87 | 88 | num_maskmem: 7 89 | image_size: 1024 90 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask 91 | sigmoid_scale_for_mem_enc: 20.0 92 | sigmoid_bias_for_mem_enc: -10.0 93 | use_mask_input_as_output_without_sam: true 94 | # Memory 95 | directly_add_no_mem_embed: true 96 | # use high-resolution feature map in the SAM mask decoder 97 | use_high_res_features_in_sam: true 98 | # output 3 masks on the first click on initial conditioning frames 99 | multimask_output_in_sam: true 100 | # SAM heads 101 | iou_prediction_use_sigmoid: True 102 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder 103 | use_obj_ptrs_in_encoder: true 104 | add_tpos_enc_to_obj_ptrs: false 105 | only_obj_ptrs_in_the_past_for_eval: true 106 | # object occlusion prediction 107 | pred_obj_scores: true 108 | pred_obj_scores_mlp: true 109 | fixed_no_obj_ptr: true 110 | # multimask tracking settings 111 | multimask_output_for_tracking: true 112 | use_multimask_token_for_obj_ptr: true 113 | multimask_min_pt_num: 0 114 | multimask_max_pt_num: 1 115 | use_mlp_for_obj_ptr_proj: true 116 | # Compilation flag 117 | compile_image_encoder: False 118 | -------------------------------------------------------------------------------- /configs/sam2_hiera_t.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Model 4 | model: 5 | _target_: sam2.modeling.sam2_base.SAM2Base 6 | image_encoder: 7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder 8 | scalp: 1 9 | trunk: 10 | _target_: sam2.modeling.backbones.hieradet.Hiera 11 | embed_dim: 96 12 | num_heads: 1 13 | stages: [1, 2, 7, 2] 14 | global_att_blocks: [5, 7, 9] 15 | window_pos_embed_bkg_spatial_size: [7, 7] 16 | neck: 17 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck 18 | position_encoding: 19 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 20 | num_pos_feats: 256 21 | normalize: true 22 | scale: null 23 | temperature: 10000 24 | d_model: 256 25 | backbone_channel_list: [768, 384, 192, 96] 26 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features 27 | fpn_interp_model: nearest 28 | 29 | memory_attention: 30 | _target_: 
sam2.modeling.memory_attention.MemoryAttention 31 | d_model: 256 32 | pos_enc_at_input: true 33 | layer: 34 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer 35 | activation: relu 36 | dim_feedforward: 2048 37 | dropout: 0.1 38 | pos_enc_at_attn: false 39 | self_attention: 40 | _target_: sam2.modeling.sam.transformer.RoPEAttention 41 | rope_theta: 10000.0 42 | feat_sizes: [32, 32] 43 | embedding_dim: 256 44 | num_heads: 1 45 | downsample_rate: 1 46 | dropout: 0.1 47 | d_model: 256 48 | pos_enc_at_cross_attn_keys: true 49 | pos_enc_at_cross_attn_queries: false 50 | cross_attention: 51 | _target_: sam2.modeling.sam.transformer.RoPEAttention 52 | rope_theta: 10000.0 53 | feat_sizes: [32, 32] 54 | rope_k_repeat: True 55 | embedding_dim: 256 56 | num_heads: 1 57 | downsample_rate: 1 58 | dropout: 0.1 59 | kv_in_dim: 64 60 | num_layers: 4 61 | 62 | memory_encoder: 63 | _target_: sam2.modeling.memory_encoder.MemoryEncoder 64 | out_dim: 64 65 | position_encoding: 66 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 67 | num_pos_feats: 64 68 | normalize: true 69 | scale: null 70 | temperature: 10000 71 | mask_downsampler: 72 | _target_: sam2.modeling.memory_encoder.MaskDownSampler 73 | kernel_size: 3 74 | stride: 2 75 | padding: 1 76 | fuser: 77 | _target_: sam2.modeling.memory_encoder.Fuser 78 | layer: 79 | _target_: sam2.modeling.memory_encoder.CXBlock 80 | dim: 256 81 | kernel_size: 7 82 | padding: 3 83 | layer_scale_init_value: 1e-6 84 | use_dwconv: True # depth-wise convs 85 | num_layers: 2 86 | 87 | num_maskmem: 7 88 | image_size: 1024 89 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask 90 | # SAM decoder 91 | sigmoid_scale_for_mem_enc: 20.0 92 | sigmoid_bias_for_mem_enc: -10.0 93 | use_mask_input_as_output_without_sam: true 94 | # Memory 95 | directly_add_no_mem_embed: true 96 | # use high-resolution feature map in the SAM mask decoder 97 | use_high_res_features_in_sam: true 98 | # output 3 masks on the first click on initial conditioning frames 99 | multimask_output_in_sam: true 100 | # SAM heads 101 | iou_prediction_use_sigmoid: True 102 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder 103 | use_obj_ptrs_in_encoder: true 104 | add_tpos_enc_to_obj_ptrs: false 105 | only_obj_ptrs_in_the_past_for_eval: true 106 | # object occlusion prediction 107 | pred_obj_scores: true 108 | pred_obj_scores_mlp: true 109 | fixed_no_obj_ptr: true 110 | # multimask tracking settings 111 | multimask_output_for_tracking: true 112 | use_multimask_token_for_obj_ptr: true 113 | multimask_min_pt_num: 0 114 | multimask_max_pt_num: 1 115 | use_mlp_for_obj_ptr_proj: true 116 | # Compilation flag 117 | # HieraT does not currently support compilation, should always be set to False 118 | compile_image_encoder: False 119 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from PIL import Image 3 | import numpy as np 4 | import os 5 | 6 | try: 7 | from app import process_image 8 | except ImportError: 9 | # We're running as a module 10 | from .app import process_image 11 | from .utils.sam import model_to_config_map as sam_model_to_config_map 12 | 13 | 14 | # Format conversion helpers adapted from LayerStyle -- but LayerStyle has them 15 | # wrong: this is not the place to squeeze/unsqueeze. 
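# A minimal shape sketch (assuming the ComfyUI convention that IMAGE tensors are
# float32 [batch, height, width, channel] in 0..1 -- hence these helpers take a
# single [H, W, C] element and the caller loops over the batch, as below):
#
#     batch = torch.rand(2, 512, 512, 3)                     # IMAGE-style batch
#     pils = [tensor2pil(t) for t in batch]                   # one PIL.Image per element
#     restored = torch.stack([pil2tensor(p) for p in pils])   # back to [B, H, W, C]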
16 | # 17 | # - [tensor2pil](https://github.com/chflame163/ComfyUI_LayerStyle/blob/28c1a4f3082d0af5067a7bc4b72951a8dd47b9b8/py/imagefunc.py#L131) 18 | # - [pil2tensor](https://github.com/chflame163/ComfyUI_LayerStyle/blob/28c1a4f3082d0af5067a7bc4b72951a8dd47b9b8/py/imagefunc.py#L111) 19 | # 20 | # LayerStyle wrongly, misguidedly, and confusingly, un/squeezes the batch 21 | # dimension in the helpers, and then in the main code, they have to reverse 22 | # that. The batch dimension is there for a reason, people, it's not something 23 | # to be abstracted away! So our version leaves that out. 24 | def tensor2pil(t_image: torch.Tensor) -> Image.Image: 25 | return Image.fromarray(np.clip(255.0 * t_image.cpu().numpy(), 0, 255).astype(np.uint8)) 26 | def pil2tensor(image: Image.Image) -> torch.Tensor: 27 | return torch.from_numpy(np.array(image).astype(np.float32) / 255.0) 28 | 29 | 30 | 31 | 32 | class F2S2GenerateMask: 33 | def __init__(self): 34 | if os.name == "nt": 35 | self._fix_problems() 36 | 37 | def _fix_problems(self): 38 | print(f"torch version: {torch.__version__}") 39 | print(f"torch CUDA available: {torch.cuda.is_available()}") 40 | print("disabling gradients and optimisations") 41 | torch.backends.cudnn.enabled = False 42 | # print("setting CUDA_LAUNCH_BLOCKING=1 TORCH_USE_RTLD_GLOBAL=1") 43 | # os.environ['CUDA_LAUNCH_BLOCKING'] = '1' 44 | # os.environ['TORCH_USE_RTLD_GLOBAL'] = '1' 45 | # Check if the environment variable is already set 46 | if os.getenv('TORCH_CUDNN_SDPA_ENABLED') != '1': 47 | os.environ['TORCH_CUDNN_SDPA_ENABLED'] = '1' 48 | print("CuDNN SDPA enabled.") 49 | else: 50 | print("CuDNN SDPA was already enabled.") 51 | 52 | @classmethod 53 | def INPUT_TYPES(cls): 54 | model_list = list(sam_model_to_config_map.keys()) 55 | model_list.sort() 56 | device_list = ["cuda", "cpu"] 57 | return { 58 | "required": { 59 | "sam2_model": (model_list, {"default": "sam2_hiera_small.pt"}), 60 | "device": (device_list,), 61 | "image": ("IMAGE",), 62 | "prompt": ("STRING", {"default": "subject"}), 63 | }, 64 | "optional": { 65 | "keep_model_loaded": ("BOOLEAN", {"default": False}), 66 | } 67 | } 68 | 69 | RETURN_TYPES = ("IMAGE", "MASK", "IMAGE",) 70 | RETURN_NAMES = ("annotated_image", "mask", "masked_image",) 71 | FUNCTION = "_process_image" 72 | CATEGORY = "💃rDancer" 73 | 74 | def _process_image(self, sam2_model: str, device: str, image: torch.Tensor, prompt: str = None, keep_model_loaded: bool = False): 75 | torch_device = torch.device(device) 76 | prompt = prompt.strip() if prompt else "" 77 | annotated_images, masks, masked_images = [], [], [] 78 | # Convert image from tensor to PIL 79 | # the image has an extra batch dimension, despite the variable name 80 | for i, img in enumerate(image): 81 | img = tensor2pil(img).convert("RGB") 82 | keep_model_loaded = keep_model_loaded if i == (image.size(0) - 1) else True 83 | annotated_image, mask, masked_image = process_image(torch_device, sam2_model, img, prompt, keep_model_loaded) 84 | annotated_images.append(pil2tensor(annotated_image)) 85 | masks.append(pil2tensor(mask)) 86 | masked_images.append(pil2tensor(masked_image)) 87 | annotated_images = torch.stack(annotated_images) 88 | masks = torch.stack(masks) 89 | masked_images = torch.stack(masked_images) 90 | return (annotated_images, masks, masked_images, ) 91 | 92 | 93 | NODE_CLASS_MAPPINGS = { 94 | "RdancerFlorence2SAM2GenerateMask": F2S2GenerateMask 95 | } 96 | 97 | __all__ = ["NODE_CLASS_MAPPINGS"] 98 | 99 | if __name__ == "__main__": 100 | # detect which parameters are 
filenames -- those are images 101 | # the rest are prompts 102 | # call process_image with the images and prompts 103 | # save the output images 104 | # return the output images' filenames 105 | import sys 106 | import os 107 | import argparse 108 | from app import process_image 109 | 110 | # import rdancer_debug # will listen for debugger to attach 111 | 112 | def my_process_image(image_path, prompt): 113 | from utils.sam import SAM_CHECKPOINT 114 | image = Image.open(image_path).convert("RGB") 115 | annotated_image, mask, masked_image = process_image(SAM_CHECKPOINT, image, prompt) 116 | output_image_path, output_mask_path, output_masked_image_path = f"output_image_{os.path.basename(image_path)}", f"output_mask_{os.path.basename(image_path)}", f"output_masked_image_{os.path.basename(image_path)}" 117 | annotated_image.save(output_image_path) 118 | mask.save(output_mask_path) 119 | masked_image.save(output_masked_image_path) 120 | return output_image_path, output_mask_path, output_masked_image_path 121 | 122 | if len(sys.argv) < 2: 123 | print(f"Usage: python {os.path.basename(sys.argv[0])} [ ...] []") 124 | sys.exit(1) 125 | 126 | # test which exist as filenames 127 | images = [] 128 | prompts = [] 129 | 130 | for arg in sys.argv[1:]: 131 | if not os.path.exists(arg): 132 | prompts.append(arg) 133 | else: 134 | images.append(arg) 135 | 136 | if len(prompts) > 1: 137 | raise ValueError("At most one prompt is required") 138 | if len(images) < 1: 139 | raise ValueError("At least one image is required") 140 | 141 | prompt = prompts[0].strip() if prompts else None 142 | 143 | print(f"Processing {len(images)} image{'' if len(images) == 1 else 's'} with prompt: {prompt}") 144 | 145 | from app import process_image 146 | 147 | for image in images: 148 | output_image_path, output_mask_path, output_masked_image_path = my_process_image(image, prompt) 149 | print(f"Saved output image to {output_image_path} and mask to {output_mask_path} and masked image to {output_masked_image_path}") 150 | 151 | -------------------------------------------------------------------------------- /workflows/workflow.json: -------------------------------------------------------------------------------- 1 | { 2 | "last_node_id": 61, 3 | "last_link_id": 68, 4 | "nodes": [ 5 | { 6 | "id": 30, 7 | "type": "LoadImage", 8 | "pos": { 9 | "0": 360, 10 | "1": 520 11 | }, 12 | "size": { 13 | "0": 278.3598327636719, 14 | "1": 400.1376647949219 15 | }, 16 | "flags": {}, 17 | "order": 0, 18 | "mode": 0, 19 | "inputs": [], 20 | "outputs": [ 21 | { 22 | "name": "IMAGE", 23 | "type": "IMAGE", 24 | "links": [ 25 | 28 26 | ], 27 | "slot_index": 0, 28 | "shape": 3 29 | }, 30 | { 31 | "name": "MASK", 32 | "type": "MASK", 33 | "links": null, 34 | "shape": 3 35 | } 36 | ], 37 | "properties": { 38 | "Node name for S&R": "LoadImage" 39 | }, 40 | "widgets_values": [ 41 | "Screenshot 2024-09-16 221105.png", 42 | "image" 43 | ] 44 | }, 45 | { 46 | "id": 36, 47 | "type": "LoadImage", 48 | "pos": { 49 | "0": 650, 50 | "1": 520 51 | }, 52 | "size": { 53 | "0": 278.3598327636719, 54 | "1": 400.1376647949219 55 | }, 56 | "flags": {}, 57 | "order": 1, 58 | "mode": 0, 59 | "inputs": [], 60 | "outputs": [ 61 | { 62 | "name": "IMAGE", 63 | "type": "IMAGE", 64 | "links": [ 65 | 29 66 | ], 67 | "slot_index": 0, 68 | "shape": 3 69 | }, 70 | { 71 | "name": "MASK", 72 | "type": "MASK", 73 | "links": null, 74 | "shape": 3 75 | } 76 | ], 77 | "properties": { 78 | "Node name for S&R": "LoadImage" 79 | }, 80 | "widgets_values": [ 81 | "Screenshot 2024-09-16 
221234.png", 82 | "image" 83 | ] 84 | }, 85 | { 86 | "id": 31, 87 | "type": "PreviewImage", 88 | "pos": { 89 | "0": 360, 90 | "1": 970 91 | }, 92 | "size": { 93 | "0": 471.4238586425781, 94 | "1": 380.51617431640625 95 | }, 96 | "flags": {}, 97 | "order": 4, 98 | "mode": 0, 99 | "inputs": [ 100 | { 101 | "name": "images", 102 | "type": "IMAGE", 103 | "link": 66 104 | } 105 | ], 106 | "outputs": [], 107 | "properties": { 108 | "Node name for S&R": "PreviewImage" 109 | }, 110 | "widgets_values": [] 111 | }, 112 | { 113 | "id": 51, 114 | "type": "PreviewImage", 115 | "pos": { 116 | "0": 1160, 117 | "1": 970 118 | }, 119 | "size": { 120 | "0": 470.9900207519531, 121 | "1": 380.91436767578125 122 | }, 123 | "flags": {}, 124 | "order": 6, 125 | "mode": 0, 126 | "inputs": [ 127 | { 128 | "name": "images", 129 | "type": "IMAGE", 130 | "link": 68 131 | } 132 | ], 133 | "outputs": [], 134 | "properties": { 135 | "Node name for S&R": "PreviewImage" 136 | }, 137 | "widgets_values": [] 138 | }, 139 | { 140 | "id": 58, 141 | "type": "PreviewImage", 142 | "pos": { 143 | "0": 1160, 144 | "1": 525 145 | }, 146 | "size": { 147 | "0": 470.9900207519531, 148 | "1": 380.91436767578125 149 | }, 150 | "flags": {}, 151 | "order": 7, 152 | "mode": 0, 153 | "inputs": [ 154 | { 155 | "name": "images", 156 | "type": "IMAGE", 157 | "link": 60 158 | } 159 | ], 160 | "outputs": [], 161 | "properties": { 162 | "Node name for S&R": "PreviewImage" 163 | }, 164 | "widgets_values": [] 165 | }, 166 | { 167 | "id": 37, 168 | "type": "ImageBatch", 169 | "pos": { 170 | "0": 880, 171 | "1": 1010 172 | }, 173 | "size": { 174 | "0": 210, 175 | "1": 46 176 | }, 177 | "flags": {}, 178 | "order": 2, 179 | "mode": 0, 180 | "inputs": [ 181 | { 182 | "name": "image1", 183 | "type": "IMAGE", 184 | "link": 28 185 | }, 186 | { 187 | "name": "image2", 188 | "type": "IMAGE", 189 | "link": 29 190 | } 191 | ], 192 | "outputs": [ 193 | { 194 | "name": "IMAGE", 195 | "type": "IMAGE", 196 | "links": [ 197 | 65 198 | ], 199 | "slot_index": 0, 200 | "shape": 3 201 | } 202 | ], 203 | "properties": { 204 | "Node name for S&R": "ImageBatch" 205 | }, 206 | "widgets_values": [] 207 | }, 208 | { 209 | "id": 57, 210 | "type": "MaskToImage", 211 | "pos": { 212 | "0": 856, 213 | "1": 1335 214 | }, 215 | "size": { 216 | "0": 264.5999755859375, 217 | "1": 26 218 | }, 219 | "flags": {}, 220 | "order": 5, 221 | "mode": 0, 222 | "inputs": [ 223 | { 224 | "name": "mask", 225 | "type": "MASK", 226 | "link": 67 227 | } 228 | ], 229 | "outputs": [ 230 | { 231 | "name": "IMAGE", 232 | "type": "IMAGE", 233 | "links": [ 234 | 60 235 | ], 236 | "slot_index": 0, 237 | "shape": 3 238 | } 239 | ], 240 | "properties": { 241 | "Node name for S&R": "MaskToImage" 242 | }, 243 | "widgets_values": [] 244 | }, 245 | { 246 | "id": 60, 247 | "type": "RdancerFlorence2SAM2GenerateMask", 248 | "pos": { 249 | "0": 840, 250 | "1": 1120 251 | }, 252 | "size": { 253 | "0": 303.60308837890625, 254 | "1": 170 255 | }, 256 | "flags": {}, 257 | "order": 3, 258 | "mode": 0, 259 | "inputs": [ 260 | { 261 | "name": "image", 262 | "type": "IMAGE", 263 | "link": 65 264 | } 265 | ], 266 | "outputs": [ 267 | { 268 | "name": "annotated_image", 269 | "type": "IMAGE", 270 | "links": [ 271 | 66 272 | ], 273 | "shape": 3 274 | }, 275 | { 276 | "name": "mask", 277 | "type": "MASK", 278 | "links": [ 279 | 67 280 | ], 281 | "shape": 3 282 | }, 283 | { 284 | "name": "masked_image", 285 | "type": "IMAGE", 286 | "links": [ 287 | 68 288 | ], 289 | "shape": 3 290 | } 291 | ], 292 | "properties": { 293 | "Node 
name for S&R": "RdancerFlorence2SAM2GenerateMask" 294 | }, 295 | "widgets_values": [ 296 | "sam2_hiera_small.pt", 297 | "cuda", 298 | "products", 299 | false 300 | ] 301 | } 302 | ], 303 | "links": [ 304 | [ 305 | 28, 306 | 30, 307 | 0, 308 | 37, 309 | 0, 310 | "IMAGE" 311 | ], 312 | [ 313 | 29, 314 | 36, 315 | 0, 316 | 37, 317 | 1, 318 | "IMAGE" 319 | ], 320 | [ 321 | 60, 322 | 57, 323 | 0, 324 | 58, 325 | 0, 326 | "IMAGE" 327 | ], 328 | [ 329 | 65, 330 | 37, 331 | 0, 332 | 60, 333 | 0, 334 | "IMAGE" 335 | ], 336 | [ 337 | 66, 338 | 60, 339 | 0, 340 | 31, 341 | 0, 342 | "IMAGE" 343 | ], 344 | [ 345 | 67, 346 | 60, 347 | 1, 348 | 57, 349 | 0, 350 | "MASK" 351 | ], 352 | [ 353 | 68, 354 | 60, 355 | 2, 356 | 51, 357 | 0, 358 | "IMAGE" 359 | ] 360 | ], 361 | "groups": [], 362 | "config": {}, 363 | "extra": { 364 | "ds": { 365 | "scale": 2.0677378438607503, 366 | "offset": [ 367 | -675.6252784124076, 368 | -892.8859744309092 369 | ] 370 | } 371 | }, 372 | "version": 0.4 373 | } -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Tuple, Optional 3 | 4 | import cv2 5 | # import gradio as gr # Gradio phones home, we don't want that 6 | import numpy as np 7 | # import spaces 8 | import supervision as sv 9 | import torch 10 | from PIL import Image 11 | from tqdm import tqdm 12 | import gc 13 | 14 | import comfy.model_management as mm 15 | 16 | try: 17 | from utils.video import generate_unique_name, create_directory, delete_directory 18 | 19 | from utils.florence import load_florence_model, run_florence_inference, \ 20 | FLORENCE_OPEN_VOCABULARY_DETECTION_TASK #, 21 | # FLORENCE_DETAILED_CAPTION_TASK, FLORENCE_CAPTION_TO_PHRASE_GROUNDING_TASK 22 | from utils.modes import IMAGE_INFERENCE_MODES, IMAGE_OPEN_VOCABULARY_DETECTION_MODE #, \ 23 | # IMAGE_CAPTION_GROUNDING_MASKS_MODE, VIDEO_INFERENCE_MODES 24 | from utils.sam import load_sam_image_model, run_sam_inference #, load_sam_video_model 25 | except ImportError: 26 | # We're running as a module 27 | from .utils.video import generate_unique_name, create_directory, delete_directory 28 | 29 | from .utils.florence import load_florence_model, run_florence_inference, \ 30 | FLORENCE_OPEN_VOCABULARY_DETECTION_TASK #, 31 | # FLORENCE_DETAILED_CAPTION_TASK, FLORENCE_CAPTION_TO_PHRASE_GROUNDING_TASK 32 | from .utils.modes import IMAGE_INFERENCE_MODES, IMAGE_OPEN_VOCABULARY_DETECTION_MODE #, \ 33 | # IMAGE_CAPTION_GROUNDING_MASKS_MODE, VIDEO_INFERENCE_MODES 34 | from .utils.sam import load_sam_image_model, run_sam_inference #, load_sam_video_model 35 | 36 | # MARKDOWN = """ 37 | # # Florence2 + SAM2 🔥 38 | 39 | # 53 | 54 | # This demo integrates Florence2 and SAM2 by creating a two-stage inference pipeline. In 55 | # the first stage, Florence2 performs tasks such as object detection, open-vocabulary 56 | # object detection, image captioning, or phrase grounding. In the second stage, SAM2 57 | # performs object segmentation on the image. 
58 | # """ 59 | 60 | # IMAGE_PROCESSING_EXAMPLES = [ 61 | # [IMAGE_OPEN_VOCABULARY_DETECTION_MODE, "https://media.roboflow.com/notebooks/examples/dog-2.jpeg", 'straw, white napkin, black napkin, hair'], 62 | # [IMAGE_OPEN_VOCABULARY_DETECTION_MODE, "https://media.roboflow.com/notebooks/examples/dog-3.jpeg", 'tail'], 63 | # [IMAGE_CAPTION_GROUNDING_MASKS_MODE, "https://media.roboflow.com/notebooks/examples/dog-2.jpeg", None], 64 | # [IMAGE_CAPTION_GROUNDING_MASKS_MODE, "https://media.roboflow.com/notebooks/examples/dog-3.jpeg", None], 65 | # ] 66 | # VIDEO_PROCESSING_EXAMPLES = [ 67 | # ["videos/clip-07-camera-1.mp4", "player in white outfit, player in black outfit, ball, rim"], 68 | # ["videos/clip-07-camera-2.mp4", "player in white outfit, player in black outfit, ball, rim"], 69 | # ["videos/clip-07-camera-3.mp4", "player in white outfit, player in black outfit, ball, rim"] 70 | # ] 71 | 72 | # VIDEO_SCALE_FACTOR = 0.5 73 | # VIDEO_TARGET_DIRECTORY = "tmp" 74 | # create_directory(directory_path=VIDEO_TARGET_DIRECTORY) 75 | 76 | DEVICE = None #torch.device("cuda") 77 | # DEVICE = torch.device("cpu") 78 | 79 | if torch.cuda.is_available(): 80 | torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__() 81 | if torch.cuda.get_device_properties(0).major >= 8: 82 | torch.backends.cuda.matmul.allow_tf32 = True 83 | torch.backends.cudnn.allow_tf32 = True 84 | elif torch.backends.mps.is_available(): 85 | DEVICE = torch.device("mps") 86 | else: 87 | DEVICE = torch.device("cpu") 88 | 89 | FLORENCE_MODEL, FLORENCE_PROCESSOR = None, None 90 | SAM_IMAGE_MODEL = None 91 | # SAM_VIDEO_MODEL = load_sam_video_model(device=DEVICE) 92 | COLORS = ['#FF1493', '#00BFFF', '#FF6347', '#FFD700', '#32CD32', '#8A2BE2'] 93 | COLOR_PALETTE = sv.ColorPalette.from_hex(COLORS) 94 | BOX_ANNOTATOR = sv.BoxAnnotator(color=COLOR_PALETTE, color_lookup=sv.ColorLookup.INDEX) 95 | LABEL_ANNOTATOR = sv.LabelAnnotator( 96 | color=COLOR_PALETTE, 97 | color_lookup=sv.ColorLookup.INDEX, 98 | text_position=sv.Position.CENTER_OF_MASS, 99 | text_color=sv.Color.from_hex("#000000"), 100 | border_radius=5 101 | ) 102 | MASK_ANNOTATOR = sv.MaskAnnotator( 103 | color=COLOR_PALETTE, 104 | color_lookup=sv.ColorLookup.INDEX 105 | ) 106 | 107 | 108 | def annotate_image(image, detections): 109 | output_image = image.copy() 110 | output_image = MASK_ANNOTATOR.annotate(output_image, detections) 111 | output_image = BOX_ANNOTATOR.annotate(output_image, detections) 112 | output_image = LABEL_ANNOTATOR.annotate(output_image, detections) 113 | return output_image 114 | 115 | 116 | # def on_mode_dropdown_change(text): 117 | # return [ 118 | # gr.Textbox(visible=text == IMAGE_OPEN_VOCABULARY_DETECTION_MODE), 119 | # gr.Textbox(visible=text == IMAGE_CAPTION_GROUNDING_MASKS_MODE), 120 | # ] 121 | 122 | def lazy_load_models(device: torch.device, sam_image_model: str): 123 | global SAM_IMAGE_MODEL 124 | global loaded_sam_image_model 125 | global FLORENCE_MODEL 126 | global FLORENCE_PROCESSOR 127 | global DEVICE 128 | if device != DEVICE: 129 | offload_models(delete=True) 130 | DEVICE = device 131 | if SAM_IMAGE_MODEL is None: 132 | SAM_IMAGE_MODEL = load_sam_image_model(device=DEVICE, checkpoint=sam_image_model) 133 | loaded_sam_image_model = sam_image_model 134 | elif loaded_sam_image_model != sam_image_model: 135 | print(f"DEBUG [ComfyUI_Florence2SAM2::lazy_load_models] Old model {loaded_sam_image_model} != new model {sam_image_model} => releasing memory") 136 | SAM_IMAGE_MODEL.model.cpu() 137 | del SAM_IMAGE_MODEL 138 | gc.collect() 139 | 
        torch.cuda.empty_cache()
140 |         SAM_IMAGE_MODEL = load_sam_image_model(device=DEVICE, checkpoint=sam_image_model)
141 |         loaded_sam_image_model = sam_image_model
142 |     if FLORENCE_MODEL is None or FLORENCE_PROCESSOR is None:
143 |         assert FLORENCE_MODEL is None and FLORENCE_PROCESSOR is None
144 |         FLORENCE_MODEL, FLORENCE_PROCESSOR = load_florence_model(device=DEVICE)
145 | 
146 |     # The models could have been offloaded to RAM by offload_models(); if they're already on `device`, this is a no-op
147 |     SAM_IMAGE_MODEL.model.to(device)
148 |     FLORENCE_MODEL.to(device)
149 |     # FLORENCE_PROCESSOR.to(device) # not a model
150 | 
151 | def offload_models(delete=False):  # delete=True frees the models entirely; otherwise they are moved to ComfyUI's offload device
152 |     global SAM_IMAGE_MODEL
153 |     global FLORENCE_MODEL
154 |     global FLORENCE_PROCESSOR
155 |     offload_device = mm.unet_offload_device()
156 |     do_gc = False
157 |     if SAM_IMAGE_MODEL is not None:
158 |         if delete:
159 |             SAM_IMAGE_MODEL.model.cpu()
160 |             del SAM_IMAGE_MODEL
161 |             SAM_IMAGE_MODEL = None
162 |             do_gc = True
163 |         else:
164 |             SAM_IMAGE_MODEL.model.to(offload_device)
165 |     if FLORENCE_MODEL is not None:
166 |         if delete:
167 |             FLORENCE_MODEL.cpu()
168 |             del FLORENCE_MODEL
169 |             FLORENCE_MODEL = None
170 |             do_gc = True
171 |         else:
172 |             FLORENCE_MODEL.to(offload_device)
173 |     if FLORENCE_PROCESSOR is not None:
174 |         if delete:
175 |             # FLORENCE_PROCESSOR.cpu()
176 |             del FLORENCE_PROCESSOR
177 |             FLORENCE_PROCESSOR = None
178 |             do_gc = True
179 |         else:
180 |             # FLORENCE_PROCESSOR.cpu()
181 |             pass
182 |     mm.soft_empty_cache()
183 |     if do_gc:
184 |         gc.collect()
185 | 
186 | def process_image(device: torch.device, sam_image_model: str, image: Image.Image, promt: str, keep_model_loaded: bool) -> Tuple[Optional[Image.Image], Optional[Image.Image], Optional[Image.Image]]:
187 |     lazy_load_models(device, sam_image_model)
188 |     annotated_image, mask_list = _process_image(IMAGE_OPEN_VOCABULARY_DETECTION_MODE, image, promt)
189 |     if mask_list is not None and len(mask_list) > 0:
190 |         mask = np.any(mask_list, axis=0) # Merge masks into a single mask
191 |         mask = (mask * 255).astype(np.uint8)
192 |     else:
193 |         print(f"Florence2SAM2: No objects of class {promt} found in the image.")
194 |         mask = np.zeros((image.height, image.width), dtype=np.uint8)
195 |     mask = Image.fromarray(mask).convert("L") # Convert to 8-bit grayscale
196 |     masked_image = Image.new("RGB", image.size, (0, 0, 0))
197 |     masked_image.paste(image, mask=mask)
198 |     if not keep_model_loaded:
199 |         offload_models()
200 |     return annotated_image, mask, masked_image
201 | 
202 | @torch.inference_mode()
203 | @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
204 | def _process_image(
205 |     mode_dropdown=IMAGE_OPEN_VOCABULARY_DETECTION_MODE, image_input=None, text_input=None
206 | ) -> Tuple[Optional[Image.Image], Optional[np.ndarray]]:
207 |     """
208 |     Process an image with Florence2 and SAM2.
209 | 
210 |     Note that the models are lazy loaded, so they will not waste time loading during startup, or at all if this method is not called.
211 | 
212 |     @param mode_dropdown: The mode of the Florence2 model. Must be IMAGE_OPEN_VOCABULARY_DETECTION_MODE.
213 |     @param image_input: The image to process.
214 |     @param text_input: The text prompt to use for the Florence2 model.
215 | 216 | @return: Tuple[Image.Image, Image.Image]: The annotated image, merged mask (Boolean array) of the detected objects 217 | """ 218 | global SAM_IMAGE_MODEL 219 | global FLORENCE_MODEL 220 | global FLORENCE_PROCESSOR 221 | 222 | if not image_input: 223 | # gr.Info("Please upload an image.") 224 | return None, None 225 | 226 | if mode_dropdown == IMAGE_OPEN_VOCABULARY_DETECTION_MODE: 227 | if not text_input: 228 | # gr.Info("Please enter a text prompt.") 229 | return None, None 230 | 231 | texts = [prompt.strip() for prompt in text_input.split(",")] 232 | detections_list = [] 233 | for text in texts: 234 | _, result = run_florence_inference( 235 | model=FLORENCE_MODEL, 236 | processor=FLORENCE_PROCESSOR, 237 | device=DEVICE, 238 | image=image_input, 239 | task=FLORENCE_OPEN_VOCABULARY_DETECTION_TASK, 240 | text=text 241 | ) 242 | detections = sv.Detections.from_lmm( 243 | lmm=sv.LMM.FLORENCE_2, 244 | result=result, 245 | resolution_wh=image_input.size 246 | ) 247 | detections = run_sam_inference(SAM_IMAGE_MODEL, image_input, detections) 248 | detections_list.append(detections) 249 | 250 | detections = sv.Detections.merge(detections_list) 251 | detections = run_sam_inference(SAM_IMAGE_MODEL, image_input, detections) 252 | return annotate_image(image_input, detections), detections.mask 253 | 254 | # if mode_dropdown == IMAGE_CAPTION_GROUNDING_MASKS_MODE: 255 | # _, result = run_florence_inference( 256 | # model=FLORENCE_MODEL, 257 | # processor=FLORENCE_PROCESSOR, 258 | # device=DEVICE, 259 | # image=image_input, 260 | # task=FLORENCE_DETAILED_CAPTION_TASK 261 | # ) 262 | # caption = result[FLORENCE_DETAILED_CAPTION_TASK] 263 | # _, result = run_florence_inference( 264 | # model=FLORENCE_MODEL, 265 | # processor=FLORENCE_PROCESSOR, 266 | # device=DEVICE, 267 | # image=image_input, 268 | # task=FLORENCE_CAPTION_TO_PHRASE_GROUNDING_TASK, 269 | # text=caption 270 | # ) 271 | # detections = sv.Detections.from_lmm( 272 | # lmm=sv.LMM.FLORENCE_2, 273 | # result=result, 274 | # resolution_wh=image_input.size 275 | # ) 276 | # detections = run_sam_inference(SAM_IMAGE_MODEL, image_input, detections) 277 | # return annotate_image(image_input, detections), caption 278 | 279 | 280 | # @spaces.GPU(duration=300) 281 | # @torch.inference_mode() 282 | # @torch.autocast(device_type="cuda", dtype=torch.bfloat16) 283 | # def process_video( 284 | # video_input, text_input, progress=gr.Progress(track_tqdm=True) 285 | # ) -> Optional[str]: 286 | # if not video_input: 287 | # gr.Info("Please upload a video.") 288 | # return None 289 | 290 | # if not text_input: 291 | # gr.Info("Please enter a text prompt.") 292 | # return None 293 | 294 | # frame_generator = sv.get_video_frames_generator(video_input) 295 | # frame = next(frame_generator) 296 | # frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) 297 | 298 | # texts = [prompt.strip() for prompt in text_input.split(",")] 299 | # detections_list = [] 300 | # for text in texts: 301 | # _, result = run_florence_inference( 302 | # model=FLORENCE_MODEL, 303 | # processor=FLORENCE_PROCESSOR, 304 | # device=DEVICE, 305 | # image=frame, 306 | # task=FLORENCE_OPEN_VOCABULARY_DETECTION_TASK, 307 | # text=text 308 | # ) 309 | # detections = sv.Detections.from_lmm( 310 | # lmm=sv.LMM.FLORENCE_2, 311 | # result=result, 312 | # resolution_wh=frame.size 313 | # ) 314 | # detections = run_sam_inference(SAM_IMAGE_MODEL, frame, detections) 315 | # detections_list.append(detections) 316 | 317 | # detections = sv.Detections.merge(detections_list) 318 | # detections 
= run_sam_inference(SAM_IMAGE_MODEL, frame, detections) 319 | 320 | # if len(detections.mask) == 0: 321 | # gr.Info( 322 | # "No objects of class {text_input} found in the first frame of the video. " 323 | # "Trim the video to make the object appear in the first frame or try a " 324 | # "different text prompt." 325 | # ) 326 | # return None 327 | 328 | # name = generate_unique_name() 329 | # frame_directory_path = os.path.join(VIDEO_TARGET_DIRECTORY, name) 330 | # frames_sink = sv.ImageSink( 331 | # target_dir_path=frame_directory_path, 332 | # image_name_pattern="{:05d}.jpeg" 333 | # ) 334 | 335 | # video_info = sv.VideoInfo.from_video_path(video_input) 336 | # video_info.width = int(video_info.width * VIDEO_SCALE_FACTOR) 337 | # video_info.height = int(video_info.height * VIDEO_SCALE_FACTOR) 338 | 339 | # frames_generator = sv.get_video_frames_generator(video_input) 340 | # with frames_sink: 341 | # for frame in tqdm( 342 | # frames_generator, 343 | # total=video_info.total_frames, 344 | # desc="splitting video into frames" 345 | # ): 346 | # frame = sv.scale_image(frame, VIDEO_SCALE_FACTOR) 347 | # frames_sink.save_image(frame) 348 | 349 | # inference_state = SAM_VIDEO_MODEL.init_state( 350 | # video_path=frame_directory_path, 351 | # device=DEVICE 352 | # ) 353 | 354 | # for mask_index, mask in enumerate(detections.mask): 355 | # _, object_ids, mask_logits = SAM_VIDEO_MODEL.add_new_mask( 356 | # inference_state=inference_state, 357 | # frame_idx=0, 358 | # obj_id=mask_index, 359 | # mask=mask 360 | # ) 361 | 362 | # video_path = os.path.join(VIDEO_TARGET_DIRECTORY, f"{name}.mp4") 363 | # frames_generator = sv.get_video_frames_generator(video_input) 364 | # masks_generator = SAM_VIDEO_MODEL.propagate_in_video(inference_state) 365 | # with sv.VideoSink(video_path, video_info=video_info) as sink: 366 | # for frame, (_, tracker_ids, mask_logits) in zip(frames_generator, masks_generator): 367 | # frame = sv.scale_image(frame, VIDEO_SCALE_FACTOR) 368 | # masks = (mask_logits > 0.0).cpu().numpy().astype(bool) 369 | # if len(masks.shape) == 4: 370 | # masks = np.squeeze(masks, axis=1) 371 | 372 | # detections = sv.Detections( 373 | # xyxy=sv.mask_to_xyxy(masks=masks), 374 | # mask=masks, 375 | # class_id=np.array(tracker_ids) 376 | # ) 377 | # annotated_frame = frame.copy() 378 | # annotated_frame = MASK_ANNOTATOR.annotate( 379 | # scene=annotated_frame, detections=detections) 380 | # annotated_frame = BOX_ANNOTATOR.annotate( 381 | # scene=annotated_frame, detections=detections) 382 | # sink.write_frame(annotated_frame) 383 | 384 | # delete_directory(frame_directory_path) 385 | # return video_path 386 | 387 | 388 | # with gr.Blocks() as demo: 389 | # gr.Markdown(MARKDOWN) 390 | # with gr.Tab("Image"): 391 | # image_processing_mode_dropdown_component = gr.Dropdown( 392 | # choices=IMAGE_INFERENCE_MODES, 393 | # value=IMAGE_INFERENCE_MODES[0], 394 | # label="Mode", 395 | # info="Select a mode to use.", 396 | # interactive=True 397 | # ) 398 | # with gr.Row(): 399 | # with gr.Column(): 400 | # image_processing_image_input_component = gr.Image( 401 | # type='pil', label='Upload image') 402 | # image_processing_text_input_component = gr.Textbox( 403 | # label='Text prompt', 404 | # placeholder='Enter comma separated text prompts') 405 | # image_processing_submit_button_component = gr.Button( 406 | # value='Submit', variant='primary') 407 | # with gr.Column(): 408 | # image_processing_image_output_component = gr.Image( 409 | # type='pil', label='Image output') 410 | # 
image_processing_text_output_component = gr.Textbox( 411 | # label='Caption output', visible=False) 412 | 413 | # with gr.Row(): 414 | # gr.Examples( 415 | # fn=process_image, 416 | # examples=IMAGE_PROCESSING_EXAMPLES, 417 | # inputs=[ 418 | # image_processing_mode_dropdown_component, 419 | # image_processing_image_input_component, 420 | # image_processing_text_input_component 421 | # ], 422 | # outputs=[ 423 | # image_processing_image_output_component, 424 | # image_processing_text_output_component 425 | # ], 426 | # run_on_click=True 427 | # ) 428 | # with gr.Tab("Video"): 429 | # video_processing_mode_dropdown_component = gr.Dropdown( 430 | # choices=VIDEO_INFERENCE_MODES, 431 | # value=VIDEO_INFERENCE_MODES[0], 432 | # label="Mode", 433 | # info="Select a mode to use.", 434 | # interactive=True 435 | # ) 436 | # with gr.Row(): 437 | # with gr.Column(): 438 | # video_processing_video_input_component = gr.Video( 439 | # label='Upload video') 440 | # video_processing_text_input_component = gr.Textbox( 441 | # label='Text prompt', 442 | # placeholder='Enter comma separated text prompts') 443 | # video_processing_submit_button_component = gr.Button( 444 | # value='Submit', variant='primary') 445 | # with gr.Column(): 446 | # video_processing_video_output_component = gr.Video( 447 | # label='Video output') 448 | # with gr.Row(): 449 | # gr.Examples( 450 | # fn=process_video, 451 | # examples=VIDEO_PROCESSING_EXAMPLES, 452 | # inputs=[ 453 | # video_processing_video_input_component, 454 | # video_processing_text_input_component 455 | # ], 456 | # outputs=video_processing_video_output_component, 457 | # run_on_click=True 458 | # ) 459 | 460 | # image_processing_submit_button_component.click( 461 | # fn=process_image, 462 | # inputs=[ 463 | # image_processing_mode_dropdown_component, 464 | # image_processing_image_input_component, 465 | # image_processing_text_input_component 466 | # ], 467 | # outputs=[ 468 | # image_processing_image_output_component, 469 | # image_processing_text_output_component 470 | # ] 471 | # ) 472 | # image_processing_text_input_component.submit( 473 | # fn=process_image, 474 | # inputs=[ 475 | # image_processing_mode_dropdown_component, 476 | # image_processing_image_input_component, 477 | # image_processing_text_input_component 478 | # ], 479 | # outputs=[ 480 | # image_processing_image_output_component, 481 | # image_processing_text_output_component 482 | # ] 483 | # ) 484 | # image_processing_mode_dropdown_component.change( 485 | # on_mode_dropdown_change, 486 | # inputs=[image_processing_mode_dropdown_component], 487 | # outputs=[ 488 | # image_processing_text_input_component, 489 | # image_processing_text_output_component 490 | # ] 491 | # ) 492 | # video_processing_submit_button_component.click( 493 | # fn=process_video, 494 | # inputs=[ 495 | # video_processing_video_input_component, 496 | # video_processing_text_input_component 497 | # ], 498 | # outputs=video_processing_video_output_component 499 | # ) 500 | # video_processing_text_input_component.submit( 501 | # fn=process_video, 502 | # inputs=[ 503 | # video_processing_video_input_component, 504 | # video_processing_text_input_component 505 | # ], 506 | # outputs=video_processing_video_output_component 507 | # ) 508 | 509 | # demo.launch(debug=False, show_error=True) 510 | --------------------------------------------------------------------------------
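
A minimal driver sketch for app.process_image, assuming ComfyUI is on the Python path (app.py imports comfy.model_management at module level) and that the working directory is the repository root. The script name, image path, prompt, and the checkpoint string "sam2_hiera_small.pt" are illustrative assumptions; the exact checkpoint identifier expected by load_sam_image_model is defined in utils/sam.py, which is not shown here.

# sketch_driver.py -- illustrative sketch only (assumed file name, not part of the repository)
import torch
from PIL import Image

import app  # the module above; requires ComfyUI's comfy.model_management to be importable

def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    image = Image.open("test/01_source.png").convert("RGB")

    # Argument order per the definition in app.py:
    # process_image(device, sam_image_model, image, promt, keep_model_loaded)
    annotated, mask, masked = app.process_image(
        device,
        "sam2_hiera_small.pt",  # assumed checkpoint identifier; see utils/sam.py for the accepted values
        image,
        "person, dog",          # comma-separated open-vocabulary prompt
        False,                  # offload the models once the image has been processed
    )

    if annotated is not None:
        annotated.save("annotated.png")
    mask.save("mask.png")      # single merged 8-bit grayscale mask
    masked.save("masked.png")  # input image with everything outside the mask blacked out

if __name__ == "__main__":
    main()

When processing many images in a row, passing keep_model_loaded=True skips the per-call offload, keeping Florence2 and SAM2 on the inference device instead of moving them to ComfyUI's offload device after every call.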