├── test.py
├── utils
│   ├── __init__.py
│   ├── modes.py
│   ├── video.py
│   ├── florence.py
│   └── sam.py
├── test
│   ├── 01_source.png
│   ├── 02_source.png
│   ├── 01_output_mask.png
│   ├── 02_output_mask.png
│   ├── 01_bounding_boxes.jpeg
│   ├── 02_bounding_boxes.jpeg
│   ├── 01_output_masked_image.png
│   ├── 02_output_masked_image.png
│   └── README.md
├── workflows
│   ├── workflow.png
│   └── workflow.json
├── videos
│   ├── clip-07-camera-1.mp4
│   ├── clip-07-camera-2.mp4
│   └── clip-07-camera-3.mp4
├── models
│   └── sam2
│       ├── sam2_hiera_large.pt
│       ├── sam2_hiera_small.pt
│       ├── sam2_hiera_tiny.pt
│       └── sam2_hiera_base_plus.pt
├── .gitignore
├── configs
│   ├── __init__.py
│   ├── sam2_hiera_b+.yaml
│   ├── sam2_hiera_s.yaml
│   ├── sam2_hiera_l.yaml
│   └── sam2_hiera_t.yaml
├── requirements.txt
├── README.md
├── .gitattributes
├── __init__.py
├── LICENSE
└── app.py
/test.py:
--------------------------------------------------------------------------------
1 | __init__.py
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/test/01_source.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rdancer/ComfyUI_Florence2SAM2/HEAD/test/01_source.png
--------------------------------------------------------------------------------
/test/02_source.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rdancer/ComfyUI_Florence2SAM2/HEAD/test/02_source.png
--------------------------------------------------------------------------------
/test/01_output_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rdancer/ComfyUI_Florence2SAM2/HEAD/test/01_output_mask.png
--------------------------------------------------------------------------------
/test/02_output_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rdancer/ComfyUI_Florence2SAM2/HEAD/test/02_output_mask.png
--------------------------------------------------------------------------------
/workflows/workflow.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rdancer/ComfyUI_Florence2SAM2/HEAD/workflows/workflow.png
--------------------------------------------------------------------------------
/test/01_bounding_boxes.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rdancer/ComfyUI_Florence2SAM2/HEAD/test/01_bounding_boxes.jpeg
--------------------------------------------------------------------------------
/test/02_bounding_boxes.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rdancer/ComfyUI_Florence2SAM2/HEAD/test/02_bounding_boxes.jpeg
--------------------------------------------------------------------------------
/test/01_output_masked_image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rdancer/ComfyUI_Florence2SAM2/HEAD/test/01_output_masked_image.png
--------------------------------------------------------------------------------
/test/02_output_masked_image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rdancer/ComfyUI_Florence2SAM2/HEAD/test/02_output_masked_image.png
--------------------------------------------------------------------------------
/videos/clip-07-camera-1.mp4:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:7944c1a5e9be241ebf11eb39f6302c3ce9d8482ca9f12e4268b252aeda6baee9
3 | size 5500081
4 |
--------------------------------------------------------------------------------
/videos/clip-07-camera-2.mp4:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:abbfef6d422c9aa3968d14de6b78aecaf544c85423d401387e3d5e75ffee3497
3 | size 5467189
4 |
--------------------------------------------------------------------------------
/videos/clip-07-camera-3.mp4:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:e518f2ee6761d559bc864be2fec70ddc41244fbf3fea404c3158129a434ce879
3 | size 5397505
4 |
--------------------------------------------------------------------------------
/models/sam2/sam2_hiera_large.pt:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:7442e4e9b732a508f80e141e7c2913437a3610ee0c77381a66658c3a445df87b
3 | size 897952466
4 |
--------------------------------------------------------------------------------
/models/sam2/sam2_hiera_small.pt:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:95949964d4e548409021d47b22712d5f1abf2564cc0c3c765ba599a24ac7dce3
3 | size 184309650
4 |
--------------------------------------------------------------------------------
/models/sam2/sam2_hiera_tiny.pt:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:65b50056e05bcb13694174f51bb6da89c894b57b75ccdf0ba6352c597c5d1125
3 | size 155906050
4 |
--------------------------------------------------------------------------------
/models/sam2/sam2_hiera_base_plus.pt:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:d0bb7f236400a49669ffdd1be617959a8b1d1065081789d7bbff88eded3a8071
3 | size 323493298
4 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | /venv
2 | /.idea
3 | /tmp
4 |
5 | __pycache__/
6 |
7 | .*.sw?
8 | *~
9 |
10 | # something on my MacOS keeps saving these resource forks for every single file --rdancer
11 | .*
12 |
--------------------------------------------------------------------------------
/configs/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | tqdm
2 | einops
3 | spaces
4 | timm
5 | transformers
6 | gradio
7 | supervision
8 | opencv-python
9 | pytest
10 |
11 | # Build dependency; getting errors so try putting it here
12 | flit_core>=3.2,<4
13 | #samv2
14 | # Rdancer's version of samv2 includes a fix that makes it not crash on Windows
15 | git+https://github.com/rdancer/samv2.git
16 |
--------------------------------------------------------------------------------
/utils/modes.py:
--------------------------------------------------------------------------------
1 | IMAGE_OPEN_VOCABULARY_DETECTION_MODE = "open vocabulary detection + image masks"
2 | IMAGE_CAPTION_GROUNDING_MASKS_MODE = "caption + grounding + image masks"
3 |
4 | IMAGE_INFERENCE_MODES = [
5 | IMAGE_OPEN_VOCABULARY_DETECTION_MODE,
6 | IMAGE_CAPTION_GROUNDING_MASKS_MODE
7 | ]
8 |
9 | VIDEO_OPEN_VOCABULARY_DETECTION_MODE = "open vocabulary detection + video masks"
10 |
11 | VIDEO_INFERENCE_MODES = [
12 | VIDEO_OPEN_VOCABULARY_DETECTION_MODE
13 | ]
14 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | A ComfyUI custom node implementing Florence-2 + Segment Anything Model 2 (SAM2), based on [SkalskiP's HuggingFace space](https://huggingface.co/spaces/SkalskiP/florence-sam).
2 |
3 | 
4 |
5 | *RdancerFlorence2SAM2GenerateMask* - the node is self-contained and does not require separate model loaders. Models are lazy-loaded and cached; model unloading, if required, must be done manually.
6 |
7 | ## Testing
8 |
9 | Run `python test.py test/*_source.png "products"` 👉 the resulting images must pixel-match the reference outputs in `test/*output*`.
10 |
--------------------------------------------------------------------------------
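A minimal sketch of the pixel-match check described in the Testing section above (assuming PIL and numpy are available; the file names follow the repository's test/ naming convention and are illustrative):

```python
import numpy as np
from PIL import Image

def images_match(path_a: str, path_b: str) -> bool:
    """Return True if the two images are pixel-identical."""
    a = np.asarray(Image.open(path_a).convert("RGBA"))
    b = np.asarray(Image.open(path_b).convert("RGBA"))
    return a.shape == b.shape and np.array_equal(a, b)

# e.g. compare a mask produced by test.py against the committed reference
print(images_match("output_mask_01_source.png", "test/01_output_mask.png"))
```
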
/utils/video.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import os
3 | import shutil
4 | import uuid
5 |
6 |
7 | def create_directory(directory_path: str) -> None:
8 | if not os.path.exists(directory_path):
9 | os.makedirs(directory_path)
10 |
11 |
12 | def delete_directory(directory_path: str) -> None:
13 | if not os.path.exists(directory_path):
14 | raise FileNotFoundError(f"Directory '{directory_path}' does not exist.")
15 |
16 | try:
17 | shutil.rmtree(directory_path)
18 | except PermissionError:
19 | raise PermissionError(
20 | f"Permission denied: Unable to delete '{directory_path}'.")
21 |
22 |
23 | def generate_unique_name():
24 | current_datetime = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
25 | unique_id = uuid.uuid4()
26 | return f"{current_datetime}_{unique_id}"
27 |
--------------------------------------------------------------------------------
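A short, illustrative use of the helpers above for staging temporary files in a uniquely named scratch directory (the directory layout here is hypothetical, not something the repository prescribes):

```python
import os

from utils.video import create_directory, delete_directory, generate_unique_name

# Stage work in a uniquely named scratch directory, then clean it up.
work_dir = os.path.join("tmp", generate_unique_name())  # e.g. tmp/20240916221105_<uuid4>
create_directory(work_dir)
try:
    pass  # extract frames, run inference, write results, ...
finally:
    delete_directory(work_dir)
```
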
/test/README.md:
--------------------------------------------------------------------------------
1 |
2 | - One image goes in with the prompt "products", and out comes a mask covering all the segmented products in the image.
3 | - It could also output a batch or list of masks, which would then be combined in the next node (see the sketch after this file).
4 |
5 | Note: instead of bounding boxes with labels, we **just need the masks** - that's it.
6 |
7 |
8 |
9 | | source image | screenshot (left: source; right: result) |
10 | |---|---|
11 | |  |  |
12 | |  |  |
13 |
14 | The test images have been confirmed to work with this version of the code, and the screenshots above were taken from the [HuggingFace space](https://huggingface.co/spaces/SkalskiP/florence-sam).
--------------------------------------------------------------------------------
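The note above anticipates that the node may return a batch of masks that gets combined in a downstream node. A minimal sketch of that combine step, assuming the MASK output is a [B, H, W] float tensor as ComfyUI conventionally uses:

```python
import torch

def combine_masks(masks: torch.Tensor) -> torch.Tensor:
    """Union a [B, H, W] batch of masks into a single [1, H, W] mask."""
    # A pixel is set in the result if it is set in any mask of the batch.
    return masks.amax(dim=0, keepdim=True)

# usage: combined = combine_masks(mask_batch)
```
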
/.gitattributes:
--------------------------------------------------------------------------------
1 | *.7z filter=lfs diff=lfs merge=lfs -text
2 | *.arrow filter=lfs diff=lfs merge=lfs -text
3 | *.bin filter=lfs diff=lfs merge=lfs -text
4 | *.bz2 filter=lfs diff=lfs merge=lfs -text
5 | *.ckpt filter=lfs diff=lfs merge=lfs -text
6 | *.ftz filter=lfs diff=lfs merge=lfs -text
7 | *.gz filter=lfs diff=lfs merge=lfs -text
8 | *.h5 filter=lfs diff=lfs merge=lfs -text
9 | *.joblib filter=lfs diff=lfs merge=lfs -text
10 | *.lfs.* filter=lfs diff=lfs merge=lfs -text
11 | *.mlmodel filter=lfs diff=lfs merge=lfs -text
12 | *.model filter=lfs diff=lfs merge=lfs -text
13 | *.msgpack filter=lfs diff=lfs merge=lfs -text
14 | *.npy filter=lfs diff=lfs merge=lfs -text
15 | *.npz filter=lfs diff=lfs merge=lfs -text
16 | *.onnx filter=lfs diff=lfs merge=lfs -text
17 | *.ot filter=lfs diff=lfs merge=lfs -text
18 | *.parquet filter=lfs diff=lfs merge=lfs -text
19 | *.pb filter=lfs diff=lfs merge=lfs -text
20 | *.pickle filter=lfs diff=lfs merge=lfs -text
21 | *.pkl filter=lfs diff=lfs merge=lfs -text
22 | *.pt filter=lfs diff=lfs merge=lfs -text
23 | *.pth filter=lfs diff=lfs merge=lfs -text
24 | *.rar filter=lfs diff=lfs merge=lfs -text
25 | *.safetensors filter=lfs diff=lfs merge=lfs -text
26 | saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27 | *.tar.* filter=lfs diff=lfs merge=lfs -text
28 | *.tar filter=lfs diff=lfs merge=lfs -text
29 | *.tflite filter=lfs diff=lfs merge=lfs -text
30 | *.tgz filter=lfs diff=lfs merge=lfs -text
31 | *.wasm filter=lfs diff=lfs merge=lfs -text
32 | *.xz filter=lfs diff=lfs merge=lfs -text
33 | *.zip filter=lfs diff=lfs merge=lfs -text
34 | *.zst filter=lfs diff=lfs merge=lfs -text
35 | *tfevents* filter=lfs diff=lfs merge=lfs -text
36 | *.mp4 filter=lfs diff=lfs merge=lfs -text
37 |
--------------------------------------------------------------------------------
/utils/florence.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Union, Any, Tuple, Dict
3 | from unittest.mock import patch
4 |
5 | import torch
6 | from PIL import Image
7 | from transformers import AutoModelForCausalLM, AutoProcessor
8 | from transformers.dynamic_module_utils import get_imports
9 |
10 | FLORENCE_CHECKPOINT = "microsoft/Florence-2-base"
11 | FLORENCE_OBJECT_DETECTION_TASK = '<OD>'
12 | FLORENCE_DETAILED_CAPTION_TASK = '<MORE_DETAILED_CAPTION>'
13 | FLORENCE_CAPTION_TO_PHRASE_GROUNDING_TASK = '<CAPTION_TO_PHRASE_GROUNDING>'
14 | FLORENCE_OPEN_VOCABULARY_DETECTION_TASK = '<OPEN_VOCABULARY_DETECTION>'
15 | FLORENCE_DENSE_REGION_CAPTION_TASK = '<DENSE_REGION_CAPTION>'
16 |
17 |
18 | def fixed_get_imports(filename: Union[str, os.PathLike]) -> list[str]:
19 | """Work around for https://huggingface.co/microsoft/phi-1_5/discussions/72."""
20 | if not str(filename).endswith("/modeling_florence2.py"):
21 | return get_imports(filename)
22 | imports = get_imports(filename)
23 | try:
24 | imports.remove("flash_attn")
25 | except ValueError:
26 | pass
27 | return imports
28 |
29 |
30 | def load_florence_model(
31 | device: torch.device, checkpoint: str = FLORENCE_CHECKPOINT
32 | ) -> Tuple[Any, Any]:
33 | with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
34 | model = AutoModelForCausalLM.from_pretrained(
35 | checkpoint, trust_remote_code=True).to(device).eval()
36 | processor = AutoProcessor.from_pretrained(
37 | checkpoint, trust_remote_code=True)
38 | return model, processor
39 |
40 |
41 | def run_florence_inference(
42 | model: Any,
43 | processor: Any,
44 | device: torch.device,
45 | image: Image.Image,
46 | task: str,
47 | text: str = ""
48 | ) -> Tuple[str, Dict]:
49 | prompt = task + text
50 | inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
51 | generated_ids = model.generate(
52 | input_ids=inputs["input_ids"],
53 | pixel_values=inputs["pixel_values"],
54 | max_new_tokens=1024,
55 | num_beams=3
56 | )
57 | generated_text = processor.batch_decode(
58 | generated_ids, skip_special_tokens=False)[0]
59 | response = processor.post_process_generation(
60 | generated_text, task=task, image_size=image.size)
61 | return generated_text, response
62 |
--------------------------------------------------------------------------------
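A hedged sketch of how the helpers above are chained for open-vocabulary detection, mirroring what app.py does (it assumes a supervision version that provides `Detections.from_lmm` with Florence-2 support; the image path and prompt are illustrative):

```python
import torch
from PIL import Image
import supervision as sv

from utils.florence import (
    FLORENCE_OPEN_VOCABULARY_DETECTION_TASK,
    load_florence_model,
    run_florence_inference,
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model, processor = load_florence_model(device=device)

image = Image.open("test/01_source.png").convert("RGB")
_, response = run_florence_inference(
    model=model,
    processor=processor,
    device=device,
    image=image,
    task=FLORENCE_OPEN_VOCABULARY_DETECTION_TASK,
    text="products",
)
# Convert the Florence-2 response into bounding boxes that SAM2 can consume.
detections = sv.Detections.from_lmm(
    lmm=sv.LMM.FLORENCE_2, result=response, resolution_wh=image.size
)
print(detections.xyxy)  # (N, 4) array of boxes
```
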
/utils/sam.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 | import os
3 |
4 | import folder_paths
5 | import numpy as np
6 | import supervision as sv
7 | import torch
8 | from PIL import Image
9 | from sam2.build_sam import build_sam2, build_sam2_video_predictor
10 | from sam2.sam2_image_predictor import SAM2ImagePredictor
11 |
12 | model_to_config_map = {
13 | # models: sam2_hiera_base_plus.pt sam2_hiera_large.pt sam2_hiera_small.pt sam2_hiera_tiny.pt
14 | # configs: sam2_hiera_b+.yaml sam2_hiera_l.yaml sam2_hiera_s.yaml sam2_hiera_t.yaml
15 | "sam2_hiera_base_plus.pt": "sam2_hiera_b+.yaml",
16 | "sam2_hiera_large.pt": "sam2_hiera_l.yaml",
17 | "sam2_hiera_small.pt": "sam2_hiera_s.yaml",
18 | "sam2_hiera_tiny.pt": "sam2_hiera_t.yaml",
19 | }
20 | SAM_CHECKPOINT = "sam2_hiera_small.pt"
21 | SAM_CONFIG = "sam2_hiera_s.yaml" # from /usr/local/lib/python3.10/dist-packages/sam2/configs, *not* from either the models directory, or this package's directory
22 |
23 | def load_sam_image_model(
24 | device: torch.device,
25 | checkpoint: str = SAM_CHECKPOINT,
26 | config: str = None
27 | ) -> SAM2ImagePredictor:
28 | if config is None:
29 | config = model_to_config_map[checkpoint]
30 | import os
31 |
32 | # 1. Print the current working directory with flush=True
33 | current_working_directory = os.getcwd()
34 | print(f"Current working directory: {current_working_directory}", flush=True)
35 |
36 | # 2. Check if the "models" and "models/sam2" directories exist
37 | models_dir = folder_paths.models_dir
38 | sam2_dir = os.path.join(models_dir, "sam2")
39 |
40 | if os.path.exists(models_dir):
41 | print(f"'models' directory exists: {models_dir}", flush=True)
42 | else:
43 | print(f"'models' directory does not exist: {models_dir}", flush=True)
44 |
45 | if os.path.exists(sam2_dir):
46 | print(f"'models/sam2' directory exists: {sam2_dir}", flush=True)
47 | else:
48 | print(f"'models/sam2' directory does not exist: {sam2_dir}", flush=True)
49 |
50 | model_path = os.path.join(sam2_dir, checkpoint)
51 | if os.path.exists(model_path):
52 | print(f"'models/sam2/{checkpoint}' file exists: {model_path}", flush=True)
53 | else:
54 | print(f"'models/sam2/{checkpoint}' file does not exist: {model_path}", flush=True)
55 |
56 | model = build_sam2(config, model_path, device=device)
57 | return SAM2ImagePredictor(sam_model=model)
58 |
59 |
60 | def load_sam_video_model(
61 | device: torch.device,
62 | config: str = SAM_CONFIG,
63 | checkpoint: str = SAM_CHECKPOINT
64 | ) -> Any:
65 | return build_sam2_video_predictor(config, checkpoint, device=device)
66 |
67 |
68 | def run_sam_inference(
69 | model: Any,
70 | image: Image.Image,
71 | detections: sv.Detections
72 | ) -> sv.Detections:
73 | image = np.array(image.convert("RGB"))
74 | model.set_image(image)
75 | mask, score, _ = model.predict(box=detections.xyxy, multimask_output=False)
76 |
77 | # dirty fix; remove this later
78 | if len(mask.shape) == 4:
79 | mask = np.squeeze(mask)
80 |
81 | detections.mask = mask.astype(bool)
82 | return detections
83 |
--------------------------------------------------------------------------------
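Continuing the sketch, this is roughly how the SAM2 helpers above would be driven. It is a sketch, not the node's exact code path: it assumes the script runs inside ComfyUI (since `utils.sam` imports `folder_paths`) and that the chosen checkpoint exists under `models/sam2`; the placeholder box stands in for real Florence-2 detections.

```python
import numpy as np
import supervision as sv
import torch
from PIL import Image

from utils.sam import SAM_CHECKPOINT, load_sam_image_model, run_sam_inference

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
predictor = load_sam_image_model(device=device, checkpoint=SAM_CHECKPOINT)

image = Image.open("test/01_source.png").convert("RGB")
# Placeholder detection; in the real pipeline this comes from Florence-2.
detections = sv.Detections(xyxy=np.array([[10.0, 10.0, 200.0, 200.0]], dtype=np.float32))

detections = run_sam_inference(predictor, image, detections)
print(detections.mask.shape)  # (num_boxes, H, W) boolean masks
```
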
/configs/sam2_hiera_b+.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # Model
4 | model:
5 | _target_: sam2.modeling.sam2_base.SAM2Base
6 | image_encoder:
7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8 | scalp: 1
9 | trunk:
10 | _target_: sam2.modeling.backbones.hieradet.Hiera
11 | embed_dim: 112
12 | num_heads: 2
13 | neck:
14 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck
15 | position_encoding:
16 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
17 | num_pos_feats: 256
18 | normalize: true
19 | scale: null
20 | temperature: 10000
21 | d_model: 256
22 | backbone_channel_list: [896, 448, 224, 112]
23 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
24 | fpn_interp_model: nearest
25 |
26 | memory_attention:
27 | _target_: sam2.modeling.memory_attention.MemoryAttention
28 | d_model: 256
29 | pos_enc_at_input: true
30 | layer:
31 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
32 | activation: relu
33 | dim_feedforward: 2048
34 | dropout: 0.1
35 | pos_enc_at_attn: false
36 | self_attention:
37 | _target_: sam2.modeling.sam.transformer.RoPEAttention
38 | rope_theta: 10000.0
39 | feat_sizes: [32, 32]
40 | embedding_dim: 256
41 | num_heads: 1
42 | downsample_rate: 1
43 | dropout: 0.1
44 | d_model: 256
45 | pos_enc_at_cross_attn_keys: true
46 | pos_enc_at_cross_attn_queries: false
47 | cross_attention:
48 | _target_: sam2.modeling.sam.transformer.RoPEAttention
49 | rope_theta: 10000.0
50 | feat_sizes: [32, 32]
51 | rope_k_repeat: True
52 | embedding_dim: 256
53 | num_heads: 1
54 | downsample_rate: 1
55 | dropout: 0.1
56 | kv_in_dim: 64
57 | num_layers: 4
58 |
59 | memory_encoder:
60 | _target_: sam2.modeling.memory_encoder.MemoryEncoder
61 | out_dim: 64
62 | position_encoding:
63 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
64 | num_pos_feats: 64
65 | normalize: true
66 | scale: null
67 | temperature: 10000
68 | mask_downsampler:
69 | _target_: sam2.modeling.memory_encoder.MaskDownSampler
70 | kernel_size: 3
71 | stride: 2
72 | padding: 1
73 | fuser:
74 | _target_: sam2.modeling.memory_encoder.Fuser
75 | layer:
76 | _target_: sam2.modeling.memory_encoder.CXBlock
77 | dim: 256
78 | kernel_size: 7
79 | padding: 3
80 | layer_scale_init_value: 1e-6
81 | use_dwconv: True # depth-wise convs
82 | num_layers: 2
83 |
84 | num_maskmem: 7
85 | image_size: 1024
86 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
87 | sigmoid_scale_for_mem_enc: 20.0
88 | sigmoid_bias_for_mem_enc: -10.0
89 | use_mask_input_as_output_without_sam: true
90 | # Memory
91 | directly_add_no_mem_embed: true
92 | # use high-resolution feature map in the SAM mask decoder
93 | use_high_res_features_in_sam: true
94 | # output 3 masks on the first click on initial conditioning frames
95 | multimask_output_in_sam: true
96 | # SAM heads
97 | iou_prediction_use_sigmoid: True
98 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
99 | use_obj_ptrs_in_encoder: true
100 | add_tpos_enc_to_obj_ptrs: false
101 | only_obj_ptrs_in_the_past_for_eval: true
102 | # object occlusion prediction
103 | pred_obj_scores: true
104 | pred_obj_scores_mlp: true
105 | fixed_no_obj_ptr: true
106 | # multimask tracking settings
107 | multimask_output_for_tracking: true
108 | use_multimask_token_for_obj_ptr: true
109 | multimask_min_pt_num: 0
110 | multimask_max_pt_num: 1
111 | use_mlp_for_obj_ptr_proj: true
112 | # Compilation flag
113 | compile_image_encoder: False
114 |
--------------------------------------------------------------------------------
/configs/sam2_hiera_s.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # Model
4 | model:
5 | _target_: sam2.modeling.sam2_base.SAM2Base
6 | image_encoder:
7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8 | scalp: 1
9 | trunk:
10 | _target_: sam2.modeling.backbones.hieradet.Hiera
11 | embed_dim: 96
12 | num_heads: 1
13 | stages: [1, 2, 11, 2]
14 | global_att_blocks: [7, 10, 13]
15 | window_pos_embed_bkg_spatial_size: [7, 7]
16 | neck:
17 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck
18 | position_encoding:
19 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
20 | num_pos_feats: 256
21 | normalize: true
22 | scale: null
23 | temperature: 10000
24 | d_model: 256
25 | backbone_channel_list: [768, 384, 192, 96]
26 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
27 | fpn_interp_model: nearest
28 |
29 | memory_attention:
30 | _target_: sam2.modeling.memory_attention.MemoryAttention
31 | d_model: 256
32 | pos_enc_at_input: true
33 | layer:
34 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
35 | activation: relu
36 | dim_feedforward: 2048
37 | dropout: 0.1
38 | pos_enc_at_attn: false
39 | self_attention:
40 | _target_: sam2.modeling.sam.transformer.RoPEAttention
41 | rope_theta: 10000.0
42 | feat_sizes: [32, 32]
43 | embedding_dim: 256
44 | num_heads: 1
45 | downsample_rate: 1
46 | dropout: 0.1
47 | d_model: 256
48 | pos_enc_at_cross_attn_keys: true
49 | pos_enc_at_cross_attn_queries: false
50 | cross_attention:
51 | _target_: sam2.modeling.sam.transformer.RoPEAttention
52 | rope_theta: 10000.0
53 | feat_sizes: [32, 32]
54 | rope_k_repeat: True
55 | embedding_dim: 256
56 | num_heads: 1
57 | downsample_rate: 1
58 | dropout: 0.1
59 | kv_in_dim: 64
60 | num_layers: 4
61 |
62 | memory_encoder:
63 | _target_: sam2.modeling.memory_encoder.MemoryEncoder
64 | out_dim: 64
65 | position_encoding:
66 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
67 | num_pos_feats: 64
68 | normalize: true
69 | scale: null
70 | temperature: 10000
71 | mask_downsampler:
72 | _target_: sam2.modeling.memory_encoder.MaskDownSampler
73 | kernel_size: 3
74 | stride: 2
75 | padding: 1
76 | fuser:
77 | _target_: sam2.modeling.memory_encoder.Fuser
78 | layer:
79 | _target_: sam2.modeling.memory_encoder.CXBlock
80 | dim: 256
81 | kernel_size: 7
82 | padding: 3
83 | layer_scale_init_value: 1e-6
84 | use_dwconv: True # depth-wise convs
85 | num_layers: 2
86 |
87 | num_maskmem: 7
88 | image_size: 1024
89 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
90 | sigmoid_scale_for_mem_enc: 20.0
91 | sigmoid_bias_for_mem_enc: -10.0
92 | use_mask_input_as_output_without_sam: true
93 | # Memory
94 | directly_add_no_mem_embed: true
95 | # use high-resolution feature map in the SAM mask decoder
96 | use_high_res_features_in_sam: true
97 | # output 3 masks on the first click on initial conditioning frames
98 | multimask_output_in_sam: true
99 | # SAM heads
100 | iou_prediction_use_sigmoid: True
101 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
102 | use_obj_ptrs_in_encoder: true
103 | add_tpos_enc_to_obj_ptrs: false
104 | only_obj_ptrs_in_the_past_for_eval: true
105 | # object occlusion prediction
106 | pred_obj_scores: true
107 | pred_obj_scores_mlp: true
108 | fixed_no_obj_ptr: true
109 | # multimask tracking settings
110 | multimask_output_for_tracking: true
111 | use_multimask_token_for_obj_ptr: true
112 | multimask_min_pt_num: 0
113 | multimask_max_pt_num: 1
114 | use_mlp_for_obj_ptr_proj: true
115 | # Compilation flag
116 | compile_image_encoder: False
117 |
--------------------------------------------------------------------------------
/configs/sam2_hiera_l.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # Model
4 | model:
5 | _target_: sam2.modeling.sam2_base.SAM2Base
6 | image_encoder:
7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8 | scalp: 1
9 | trunk:
10 | _target_: sam2.modeling.backbones.hieradet.Hiera
11 | embed_dim: 144
12 | num_heads: 2
13 | stages: [2, 6, 36, 4]
14 | global_att_blocks: [23, 33, 43]
15 | window_pos_embed_bkg_spatial_size: [7, 7]
16 | window_spec: [8, 4, 16, 8]
17 | neck:
18 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck
19 | position_encoding:
20 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
21 | num_pos_feats: 256
22 | normalize: true
23 | scale: null
24 | temperature: 10000
25 | d_model: 256
26 | backbone_channel_list: [1152, 576, 288, 144]
27 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
28 | fpn_interp_model: nearest
29 |
30 | memory_attention:
31 | _target_: sam2.modeling.memory_attention.MemoryAttention
32 | d_model: 256
33 | pos_enc_at_input: true
34 | layer:
35 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
36 | activation: relu
37 | dim_feedforward: 2048
38 | dropout: 0.1
39 | pos_enc_at_attn: false
40 | self_attention:
41 | _target_: sam2.modeling.sam.transformer.RoPEAttention
42 | rope_theta: 10000.0
43 | feat_sizes: [32, 32]
44 | embedding_dim: 256
45 | num_heads: 1
46 | downsample_rate: 1
47 | dropout: 0.1
48 | d_model: 256
49 | pos_enc_at_cross_attn_keys: true
50 | pos_enc_at_cross_attn_queries: false
51 | cross_attention:
52 | _target_: sam2.modeling.sam.transformer.RoPEAttention
53 | rope_theta: 10000.0
54 | feat_sizes: [32, 32]
55 | rope_k_repeat: True
56 | embedding_dim: 256
57 | num_heads: 1
58 | downsample_rate: 1
59 | dropout: 0.1
60 | kv_in_dim: 64
61 | num_layers: 4
62 |
63 | memory_encoder:
64 | _target_: sam2.modeling.memory_encoder.MemoryEncoder
65 | out_dim: 64
66 | position_encoding:
67 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
68 | num_pos_feats: 64
69 | normalize: true
70 | scale: null
71 | temperature: 10000
72 | mask_downsampler:
73 | _target_: sam2.modeling.memory_encoder.MaskDownSampler
74 | kernel_size: 3
75 | stride: 2
76 | padding: 1
77 | fuser:
78 | _target_: sam2.modeling.memory_encoder.Fuser
79 | layer:
80 | _target_: sam2.modeling.memory_encoder.CXBlock
81 | dim: 256
82 | kernel_size: 7
83 | padding: 3
84 | layer_scale_init_value: 1e-6
85 | use_dwconv: True # depth-wise convs
86 | num_layers: 2
87 |
88 | num_maskmem: 7
89 | image_size: 1024
90 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
91 | sigmoid_scale_for_mem_enc: 20.0
92 | sigmoid_bias_for_mem_enc: -10.0
93 | use_mask_input_as_output_without_sam: true
94 | # Memory
95 | directly_add_no_mem_embed: true
96 | # use high-resolution feature map in the SAM mask decoder
97 | use_high_res_features_in_sam: true
98 | # output 3 masks on the first click on initial conditioning frames
99 | multimask_output_in_sam: true
100 | # SAM heads
101 | iou_prediction_use_sigmoid: True
102 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
103 | use_obj_ptrs_in_encoder: true
104 | add_tpos_enc_to_obj_ptrs: false
105 | only_obj_ptrs_in_the_past_for_eval: true
106 | # object occlusion prediction
107 | pred_obj_scores: true
108 | pred_obj_scores_mlp: true
109 | fixed_no_obj_ptr: true
110 | # multimask tracking settings
111 | multimask_output_for_tracking: true
112 | use_multimask_token_for_obj_ptr: true
113 | multimask_min_pt_num: 0
114 | multimask_max_pt_num: 1
115 | use_mlp_for_obj_ptr_proj: true
116 | # Compilation flag
117 | compile_image_encoder: False
118 |
--------------------------------------------------------------------------------
/configs/sam2_hiera_t.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # Model
4 | model:
5 | _target_: sam2.modeling.sam2_base.SAM2Base
6 | image_encoder:
7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8 | scalp: 1
9 | trunk:
10 | _target_: sam2.modeling.backbones.hieradet.Hiera
11 | embed_dim: 96
12 | num_heads: 1
13 | stages: [1, 2, 7, 2]
14 | global_att_blocks: [5, 7, 9]
15 | window_pos_embed_bkg_spatial_size: [7, 7]
16 | neck:
17 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck
18 | position_encoding:
19 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
20 | num_pos_feats: 256
21 | normalize: true
22 | scale: null
23 | temperature: 10000
24 | d_model: 256
25 | backbone_channel_list: [768, 384, 192, 96]
26 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
27 | fpn_interp_model: nearest
28 |
29 | memory_attention:
30 | _target_: sam2.modeling.memory_attention.MemoryAttention
31 | d_model: 256
32 | pos_enc_at_input: true
33 | layer:
34 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
35 | activation: relu
36 | dim_feedforward: 2048
37 | dropout: 0.1
38 | pos_enc_at_attn: false
39 | self_attention:
40 | _target_: sam2.modeling.sam.transformer.RoPEAttention
41 | rope_theta: 10000.0
42 | feat_sizes: [32, 32]
43 | embedding_dim: 256
44 | num_heads: 1
45 | downsample_rate: 1
46 | dropout: 0.1
47 | d_model: 256
48 | pos_enc_at_cross_attn_keys: true
49 | pos_enc_at_cross_attn_queries: false
50 | cross_attention:
51 | _target_: sam2.modeling.sam.transformer.RoPEAttention
52 | rope_theta: 10000.0
53 | feat_sizes: [32, 32]
54 | rope_k_repeat: True
55 | embedding_dim: 256
56 | num_heads: 1
57 | downsample_rate: 1
58 | dropout: 0.1
59 | kv_in_dim: 64
60 | num_layers: 4
61 |
62 | memory_encoder:
63 | _target_: sam2.modeling.memory_encoder.MemoryEncoder
64 | out_dim: 64
65 | position_encoding:
66 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
67 | num_pos_feats: 64
68 | normalize: true
69 | scale: null
70 | temperature: 10000
71 | mask_downsampler:
72 | _target_: sam2.modeling.memory_encoder.MaskDownSampler
73 | kernel_size: 3
74 | stride: 2
75 | padding: 1
76 | fuser:
77 | _target_: sam2.modeling.memory_encoder.Fuser
78 | layer:
79 | _target_: sam2.modeling.memory_encoder.CXBlock
80 | dim: 256
81 | kernel_size: 7
82 | padding: 3
83 | layer_scale_init_value: 1e-6
84 | use_dwconv: True # depth-wise convs
85 | num_layers: 2
86 |
87 | num_maskmem: 7
88 | image_size: 1024
89 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
90 | # SAM decoder
91 | sigmoid_scale_for_mem_enc: 20.0
92 | sigmoid_bias_for_mem_enc: -10.0
93 | use_mask_input_as_output_without_sam: true
94 | # Memory
95 | directly_add_no_mem_embed: true
96 | # use high-resolution feature map in the SAM mask decoder
97 | use_high_res_features_in_sam: true
98 | # output 3 masks on the first click on initial conditioning frames
99 | multimask_output_in_sam: true
100 | # SAM heads
101 | iou_prediction_use_sigmoid: True
102 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
103 | use_obj_ptrs_in_encoder: true
104 | add_tpos_enc_to_obj_ptrs: false
105 | only_obj_ptrs_in_the_past_for_eval: true
106 | # object occlusion prediction
107 | pred_obj_scores: true
108 | pred_obj_scores_mlp: true
109 | fixed_no_obj_ptr: true
110 | # multimask tracking settings
111 | multimask_output_for_tracking: true
112 | use_multimask_token_for_obj_ptr: true
113 | multimask_min_pt_num: 0
114 | multimask_max_pt_num: 1
115 | use_mlp_for_obj_ptr_proj: true
116 | # Compilation flag
117 | # HieraT does not currently support compilation, should always be set to False
118 | compile_image_encoder: False
119 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from PIL import Image
3 | import numpy as np
4 | import os
5 |
6 | try:
7 | from app import process_image
8 | except ImportError:
9 | # We're running as a module
10 | from .app import process_image
11 | from .utils.sam import model_to_config_map as sam_model_to_config_map
12 |
13 |
14 | # Format conversion helpers adapted from LayerStyle -- but LayerStyle has them
15 | # wrong: this is not the place to squeeze/unsqueeze.
16 | #
17 | # - [tensor2pil](https://github.com/chflame163/ComfyUI_LayerStyle/blob/28c1a4f3082d0af5067a7bc4b72951a8dd47b9b8/py/imagefunc.py#L131)
18 | # - [pil2tensor](https://github.com/chflame163/ComfyUI_LayerStyle/blob/28c1a4f3082d0af5067a7bc4b72951a8dd47b9b8/py/imagefunc.py#L111)
19 | #
20 | # LayerStyle wrongly, misguidedly, and confusingly, un/squeezes the batch
21 | # dimension in the helpers, and then in the main code, they have to reverse
22 | # that. The batch dimension is there for a reason, people, it's not something
23 | # to be abstracted away! So our version leaves that out.
24 | def tensor2pil(t_image: torch.Tensor) -> Image.Image:
25 | return Image.fromarray(np.clip(255.0 * t_image.cpu().numpy(), 0, 255).astype(np.uint8))
26 | def pil2tensor(image: Image.Image) -> torch.Tensor:
27 | return torch.from_numpy(np.array(image).astype(np.float32) / 255.0)
28 |
29 |
30 |
31 |
32 | class F2S2GenerateMask:
33 | def __init__(self):
34 | if os.name == "nt":
35 | self._fix_problems()
36 |
37 | def _fix_problems(self):
38 | print(f"torch version: {torch.__version__}")
39 | print(f"torch CUDA available: {torch.cuda.is_available()}")
40 | print("disabling gradients and optimisations")
41 | torch.backends.cudnn.enabled = False
42 | # print("setting CUDA_LAUNCH_BLOCKING=1 TORCH_USE_RTLD_GLOBAL=1")
43 | # os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
44 | # os.environ['TORCH_USE_RTLD_GLOBAL'] = '1'
45 | # Check if the environment variable is already set
46 | if os.getenv('TORCH_CUDNN_SDPA_ENABLED') != '1':
47 | os.environ['TORCH_CUDNN_SDPA_ENABLED'] = '1'
48 | print("CuDNN SDPA enabled.")
49 | else:
50 | print("CuDNN SDPA was already enabled.")
51 |
52 | @classmethod
53 | def INPUT_TYPES(cls):
54 | model_list = list(sam_model_to_config_map.keys())
55 | model_list.sort()
56 | device_list = ["cuda", "cpu"]
57 | return {
58 | "required": {
59 | "sam2_model": (model_list, {"default": "sam2_hiera_small.pt"}),
60 | "device": (device_list,),
61 | "image": ("IMAGE",),
62 | "prompt": ("STRING", {"default": "subject"}),
63 | },
64 | "optional": {
65 | "keep_model_loaded": ("BOOLEAN", {"default": False}),
66 | }
67 | }
68 |
69 | RETURN_TYPES = ("IMAGE", "MASK", "IMAGE",)
70 | RETURN_NAMES = ("annotated_image", "mask", "masked_image",)
71 | FUNCTION = "_process_image"
72 | CATEGORY = "💃rDancer"
73 |
74 | def _process_image(self, sam2_model: str, device: str, image: torch.Tensor, prompt: str = None, keep_model_loaded: bool = False):
75 | torch_device = torch.device(device)
76 | prompt = prompt.strip() if prompt else ""
77 | annotated_images, masks, masked_images = [], [], []
78 | # Convert image from tensor to PIL
79 | # the image has an extra batch dimension, despite the variable name
80 | for i, img in enumerate(image):
81 | img = tensor2pil(img).convert("RGB")
82 | keep = keep_model_loaded if i == (image.size(0) - 1) else True  # only honour the caller's flag on the last image; don't clobber it
83 | annotated_image, mask, masked_image = process_image(torch_device, sam2_model, img, prompt, keep)
84 | annotated_images.append(pil2tensor(annotated_image))
85 | masks.append(pil2tensor(mask))
86 | masked_images.append(pil2tensor(masked_image))
87 | annotated_images = torch.stack(annotated_images)
88 | masks = torch.stack(masks)
89 | masked_images = torch.stack(masked_images)
90 | return (annotated_images, masks, masked_images, )
91 |
92 |
93 | NODE_CLASS_MAPPINGS = {
94 | "RdancerFlorence2SAM2GenerateMask": F2S2GenerateMask
95 | }
96 |
97 | __all__ = ["NODE_CLASS_MAPPINGS"]
98 |
99 | if __name__ == "__main__":
100 | # detect which parameters are filenames -- those are images
101 | # the rest are prompts
102 | # call process_image with the images and prompts
103 | # save the output images
104 | # return the output images' filenames
105 | import sys
106 | import os
107 | import argparse
108 | from app import process_image
109 |
110 | # import rdancer_debug # will listen for debugger to attach
111 |
112 | def my_process_image(image_path, prompt):
113 | from utils.sam import SAM_CHECKPOINT
114 | image = Image.open(image_path).convert("RGB")
115 | annotated_image, mask, masked_image = process_image(SAM_CHECKPOINT, image, prompt)
116 | output_image_path, output_mask_path, output_masked_image_path = f"output_image_{os.path.basename(image_path)}", f"output_mask_{os.path.basename(image_path)}", f"output_masked_image_{os.path.basename(image_path)}"
117 | annotated_image.save(output_image_path)
118 | mask.save(output_mask_path)
119 | masked_image.save(output_masked_image_path)
120 | return output_image_path, output_mask_path, output_masked_image_path
121 |
122 | if len(sys.argv) < 2:
123 | print(f"Usage: python {os.path.basename(sys.argv[0])} <image> [<image> ...] [<prompt>]")
124 | sys.exit(1)
125 |
126 | # test which exist as filenames
127 | images = []
128 | prompts = []
129 |
130 | for arg in sys.argv[1:]:
131 | if not os.path.exists(arg):
132 | prompts.append(arg)
133 | else:
134 | images.append(arg)
135 |
136 | if len(prompts) > 1:
137 | raise ValueError("At most one prompt is required")
138 | if len(images) < 1:
139 | raise ValueError("At least one image is required")
140 |
141 | prompt = prompts[0].strip() if prompts else None
142 |
143 | print(f"Processing {len(images)} image{'' if len(images) == 1 else 's'} with prompt: {prompt}")
144 |
145 | from app import process_image
146 |
147 | for image in images:
148 | output_image_path, output_mask_path, output_masked_image_path = my_process_image(image, prompt)
149 | print(f"Saved output image to {output_image_path} and mask to {output_mask_path} and masked image to {output_masked_image_path}")
150 |
151 |
--------------------------------------------------------------------------------
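As a small illustration of the batch-dimension point made in the comments above (ComfyUI IMAGE tensors are [B, H, W, C] float32 in [0, 1]): the helpers convert one batch element at a time and leave stacking to the caller. The helpers are copied here so the snippet stands alone.

```python
import numpy as np
import torch
from PIL import Image

# Copies of the helpers defined in __init__.py above.
def tensor2pil(t_image: torch.Tensor) -> Image.Image:
    return Image.fromarray(np.clip(255.0 * t_image.cpu().numpy(), 0, 255).astype(np.uint8))

def pil2tensor(image: Image.Image) -> torch.Tensor:
    return torch.from_numpy(np.array(image).astype(np.float32) / 255.0)

batch = torch.rand(2, 64, 64, 3)                         # fake [B, H, W, C] image batch
pils = [tensor2pil(img) for img in batch]                # iterate over the batch dimension
round_trip = torch.stack([pil2tensor(p) for p in pils])  # re-stack into [B, H, W, C]
assert round_trip.shape == batch.shape
```
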
/workflows/workflow.json:
--------------------------------------------------------------------------------
1 | {
2 | "last_node_id": 61,
3 | "last_link_id": 68,
4 | "nodes": [
5 | {
6 | "id": 30,
7 | "type": "LoadImage",
8 | "pos": {
9 | "0": 360,
10 | "1": 520
11 | },
12 | "size": {
13 | "0": 278.3598327636719,
14 | "1": 400.1376647949219
15 | },
16 | "flags": {},
17 | "order": 0,
18 | "mode": 0,
19 | "inputs": [],
20 | "outputs": [
21 | {
22 | "name": "IMAGE",
23 | "type": "IMAGE",
24 | "links": [
25 | 28
26 | ],
27 | "slot_index": 0,
28 | "shape": 3
29 | },
30 | {
31 | "name": "MASK",
32 | "type": "MASK",
33 | "links": null,
34 | "shape": 3
35 | }
36 | ],
37 | "properties": {
38 | "Node name for S&R": "LoadImage"
39 | },
40 | "widgets_values": [
41 | "Screenshot 2024-09-16 221105.png",
42 | "image"
43 | ]
44 | },
45 | {
46 | "id": 36,
47 | "type": "LoadImage",
48 | "pos": {
49 | "0": 650,
50 | "1": 520
51 | },
52 | "size": {
53 | "0": 278.3598327636719,
54 | "1": 400.1376647949219
55 | },
56 | "flags": {},
57 | "order": 1,
58 | "mode": 0,
59 | "inputs": [],
60 | "outputs": [
61 | {
62 | "name": "IMAGE",
63 | "type": "IMAGE",
64 | "links": [
65 | 29
66 | ],
67 | "slot_index": 0,
68 | "shape": 3
69 | },
70 | {
71 | "name": "MASK",
72 | "type": "MASK",
73 | "links": null,
74 | "shape": 3
75 | }
76 | ],
77 | "properties": {
78 | "Node name for S&R": "LoadImage"
79 | },
80 | "widgets_values": [
81 | "Screenshot 2024-09-16 221234.png",
82 | "image"
83 | ]
84 | },
85 | {
86 | "id": 31,
87 | "type": "PreviewImage",
88 | "pos": {
89 | "0": 360,
90 | "1": 970
91 | },
92 | "size": {
93 | "0": 471.4238586425781,
94 | "1": 380.51617431640625
95 | },
96 | "flags": {},
97 | "order": 4,
98 | "mode": 0,
99 | "inputs": [
100 | {
101 | "name": "images",
102 | "type": "IMAGE",
103 | "link": 66
104 | }
105 | ],
106 | "outputs": [],
107 | "properties": {
108 | "Node name for S&R": "PreviewImage"
109 | },
110 | "widgets_values": []
111 | },
112 | {
113 | "id": 51,
114 | "type": "PreviewImage",
115 | "pos": {
116 | "0": 1160,
117 | "1": 970
118 | },
119 | "size": {
120 | "0": 470.9900207519531,
121 | "1": 380.91436767578125
122 | },
123 | "flags": {},
124 | "order": 6,
125 | "mode": 0,
126 | "inputs": [
127 | {
128 | "name": "images",
129 | "type": "IMAGE",
130 | "link": 68
131 | }
132 | ],
133 | "outputs": [],
134 | "properties": {
135 | "Node name for S&R": "PreviewImage"
136 | },
137 | "widgets_values": []
138 | },
139 | {
140 | "id": 58,
141 | "type": "PreviewImage",
142 | "pos": {
143 | "0": 1160,
144 | "1": 525
145 | },
146 | "size": {
147 | "0": 470.9900207519531,
148 | "1": 380.91436767578125
149 | },
150 | "flags": {},
151 | "order": 7,
152 | "mode": 0,
153 | "inputs": [
154 | {
155 | "name": "images",
156 | "type": "IMAGE",
157 | "link": 60
158 | }
159 | ],
160 | "outputs": [],
161 | "properties": {
162 | "Node name for S&R": "PreviewImage"
163 | },
164 | "widgets_values": []
165 | },
166 | {
167 | "id": 37,
168 | "type": "ImageBatch",
169 | "pos": {
170 | "0": 880,
171 | "1": 1010
172 | },
173 | "size": {
174 | "0": 210,
175 | "1": 46
176 | },
177 | "flags": {},
178 | "order": 2,
179 | "mode": 0,
180 | "inputs": [
181 | {
182 | "name": "image1",
183 | "type": "IMAGE",
184 | "link": 28
185 | },
186 | {
187 | "name": "image2",
188 | "type": "IMAGE",
189 | "link": 29
190 | }
191 | ],
192 | "outputs": [
193 | {
194 | "name": "IMAGE",
195 | "type": "IMAGE",
196 | "links": [
197 | 65
198 | ],
199 | "slot_index": 0,
200 | "shape": 3
201 | }
202 | ],
203 | "properties": {
204 | "Node name for S&R": "ImageBatch"
205 | },
206 | "widgets_values": []
207 | },
208 | {
209 | "id": 57,
210 | "type": "MaskToImage",
211 | "pos": {
212 | "0": 856,
213 | "1": 1335
214 | },
215 | "size": {
216 | "0": 264.5999755859375,
217 | "1": 26
218 | },
219 | "flags": {},
220 | "order": 5,
221 | "mode": 0,
222 | "inputs": [
223 | {
224 | "name": "mask",
225 | "type": "MASK",
226 | "link": 67
227 | }
228 | ],
229 | "outputs": [
230 | {
231 | "name": "IMAGE",
232 | "type": "IMAGE",
233 | "links": [
234 | 60
235 | ],
236 | "slot_index": 0,
237 | "shape": 3
238 | }
239 | ],
240 | "properties": {
241 | "Node name for S&R": "MaskToImage"
242 | },
243 | "widgets_values": []
244 | },
245 | {
246 | "id": 60,
247 | "type": "RdancerFlorence2SAM2GenerateMask",
248 | "pos": {
249 | "0": 840,
250 | "1": 1120
251 | },
252 | "size": {
253 | "0": 303.60308837890625,
254 | "1": 170
255 | },
256 | "flags": {},
257 | "order": 3,
258 | "mode": 0,
259 | "inputs": [
260 | {
261 | "name": "image",
262 | "type": "IMAGE",
263 | "link": 65
264 | }
265 | ],
266 | "outputs": [
267 | {
268 | "name": "annotated_image",
269 | "type": "IMAGE",
270 | "links": [
271 | 66
272 | ],
273 | "shape": 3
274 | },
275 | {
276 | "name": "mask",
277 | "type": "MASK",
278 | "links": [
279 | 67
280 | ],
281 | "shape": 3
282 | },
283 | {
284 | "name": "masked_image",
285 | "type": "IMAGE",
286 | "links": [
287 | 68
288 | ],
289 | "shape": 3
290 | }
291 | ],
292 | "properties": {
293 | "Node name for S&R": "RdancerFlorence2SAM2GenerateMask"
294 | },
295 | "widgets_values": [
296 | "sam2_hiera_small.pt",
297 | "cuda",
298 | "products",
299 | false
300 | ]
301 | }
302 | ],
303 | "links": [
304 | [
305 | 28,
306 | 30,
307 | 0,
308 | 37,
309 | 0,
310 | "IMAGE"
311 | ],
312 | [
313 | 29,
314 | 36,
315 | 0,
316 | 37,
317 | 1,
318 | "IMAGE"
319 | ],
320 | [
321 | 60,
322 | 57,
323 | 0,
324 | 58,
325 | 0,
326 | "IMAGE"
327 | ],
328 | [
329 | 65,
330 | 37,
331 | 0,
332 | 60,
333 | 0,
334 | "IMAGE"
335 | ],
336 | [
337 | 66,
338 | 60,
339 | 0,
340 | 31,
341 | 0,
342 | "IMAGE"
343 | ],
344 | [
345 | 67,
346 | 60,
347 | 1,
348 | 57,
349 | 0,
350 | "MASK"
351 | ],
352 | [
353 | 68,
354 | 60,
355 | 2,
356 | 51,
357 | 0,
358 | "IMAGE"
359 | ]
360 | ],
361 | "groups": [],
362 | "config": {},
363 | "extra": {
364 | "ds": {
365 | "scale": 2.0677378438607503,
366 | "offset": [
367 | -675.6252784124076,
368 | -892.8859744309092
369 | ]
370 | }
371 | },
372 | "version": 0.4
373 | }
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/app.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Tuple, Optional
3 |
4 | import cv2
5 | # import gradio as gr # Gradio phones home, we don't want that
6 | import numpy as np
7 | # import spaces
8 | import supervision as sv
9 | import torch
10 | from PIL import Image
11 | from tqdm import tqdm
12 | import gc
13 |
14 | import comfy.model_management as mm
15 |
16 | try:
17 | from utils.video import generate_unique_name, create_directory, delete_directory
18 |
19 | from utils.florence import load_florence_model, run_florence_inference, \
20 | FLORENCE_OPEN_VOCABULARY_DETECTION_TASK #,
21 | # FLORENCE_DETAILED_CAPTION_TASK, FLORENCE_CAPTION_TO_PHRASE_GROUNDING_TASK
22 | from utils.modes import IMAGE_INFERENCE_MODES, IMAGE_OPEN_VOCABULARY_DETECTION_MODE #, \
23 | # IMAGE_CAPTION_GROUNDING_MASKS_MODE, VIDEO_INFERENCE_MODES
24 | from utils.sam import load_sam_image_model, run_sam_inference #, load_sam_video_model
25 | except ImportError:
26 |     # We're being imported as a package (e.g. as a ComfyUI custom node), so fall back to relative imports
27 | from .utils.video import generate_unique_name, create_directory, delete_directory
28 |
29 | from .utils.florence import load_florence_model, run_florence_inference, \
30 | FLORENCE_OPEN_VOCABULARY_DETECTION_TASK #,
31 | # FLORENCE_DETAILED_CAPTION_TASK, FLORENCE_CAPTION_TO_PHRASE_GROUNDING_TASK
32 | from .utils.modes import IMAGE_INFERENCE_MODES, IMAGE_OPEN_VOCABULARY_DETECTION_MODE #, \
33 | # IMAGE_CAPTION_GROUNDING_MASKS_MODE, VIDEO_INFERENCE_MODES
34 | from .utils.sam import load_sam_image_model, run_sam_inference #, load_sam_video_model
35 |
36 | # MARKDOWN = """
37 | # # Florence2 + SAM2 🔥
38 |
39 | #
53 |
54 | # This demo integrates Florence2 and SAM2 by creating a two-stage inference pipeline. In
55 | # the first stage, Florence2 performs tasks such as object detection, open-vocabulary
56 | # object detection, image captioning, or phrase grounding. In the second stage, SAM2
57 | # performs object segmentation on the image.
58 | # """
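# In outline, the two-stage flow implemented by _process_image() below (an informal,
# commented sketch for orientation only; the real call sites further down carry the
# exact arguments, and `image`/`prompt` here are placeholder names):
#
#   _, result = run_florence_inference(model=FLORENCE_MODEL, processor=FLORENCE_PROCESSOR,
#                                      device=DEVICE, image=image, text=prompt,
#                                      task=FLORENCE_OPEN_VOCABULARY_DETECTION_TASK)   # stage 1: open-vocabulary detection
#   detections = sv.Detections.from_lmm(lmm=sv.LMM.FLORENCE_2, result=result,
#                                       resolution_wh=image.size)                      # Florence2 boxes -> supervision Detections
#   detections = run_sam_inference(SAM_IMAGE_MODEL, image, detections)                 # stage 2: SAM2 masks for those boxes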
59 |
60 | # IMAGE_PROCESSING_EXAMPLES = [
61 | # [IMAGE_OPEN_VOCABULARY_DETECTION_MODE, "https://media.roboflow.com/notebooks/examples/dog-2.jpeg", 'straw, white napkin, black napkin, hair'],
62 | # [IMAGE_OPEN_VOCABULARY_DETECTION_MODE, "https://media.roboflow.com/notebooks/examples/dog-3.jpeg", 'tail'],
63 | # [IMAGE_CAPTION_GROUNDING_MASKS_MODE, "https://media.roboflow.com/notebooks/examples/dog-2.jpeg", None],
64 | # [IMAGE_CAPTION_GROUNDING_MASKS_MODE, "https://media.roboflow.com/notebooks/examples/dog-3.jpeg", None],
65 | # ]
66 | # VIDEO_PROCESSING_EXAMPLES = [
67 | # ["videos/clip-07-camera-1.mp4", "player in white outfit, player in black outfit, ball, rim"],
68 | # ["videos/clip-07-camera-2.mp4", "player in white outfit, player in black outfit, ball, rim"],
69 | # ["videos/clip-07-camera-3.mp4", "player in white outfit, player in black outfit, ball, rim"]
70 | # ]
71 |
72 | # VIDEO_SCALE_FACTOR = 0.5
73 | # VIDEO_TARGET_DIRECTORY = "tmp"
74 | # create_directory(directory_path=VIDEO_TARGET_DIRECTORY)
75 |
76 | DEVICE = None  # set below for MPS/CPU; on CUDA it is set by the first call to lazy_load_models()
77 | # DEVICE = torch.device("cpu")
78 |
79 | if torch.cuda.is_available():
80 | torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
81 | if torch.cuda.get_device_properties(0).major >= 8:
82 | torch.backends.cuda.matmul.allow_tf32 = True
83 | torch.backends.cudnn.allow_tf32 = True
84 | elif torch.backends.mps.is_available():
85 | DEVICE = torch.device("mps")
86 | else:
87 | DEVICE = torch.device("cpu")
88 |
89 | FLORENCE_MODEL, FLORENCE_PROCESSOR = None, None
90 | SAM_IMAGE_MODEL, loaded_sam_image_model = None, None  # model instance and the checkpoint name it was loaded from
91 | # SAM_VIDEO_MODEL = load_sam_video_model(device=DEVICE)
92 | COLORS = ['#FF1493', '#00BFFF', '#FF6347', '#FFD700', '#32CD32', '#8A2BE2']
93 | COLOR_PALETTE = sv.ColorPalette.from_hex(COLORS)
94 | BOX_ANNOTATOR = sv.BoxAnnotator(color=COLOR_PALETTE, color_lookup=sv.ColorLookup.INDEX)
95 | LABEL_ANNOTATOR = sv.LabelAnnotator(
96 | color=COLOR_PALETTE,
97 | color_lookup=sv.ColorLookup.INDEX,
98 | text_position=sv.Position.CENTER_OF_MASS,
99 | text_color=sv.Color.from_hex("#000000"),
100 | border_radius=5
101 | )
102 | MASK_ANNOTATOR = sv.MaskAnnotator(
103 | color=COLOR_PALETTE,
104 | color_lookup=sv.ColorLookup.INDEX
105 | )
106 |
107 |
108 | def annotate_image(image, detections):
109 | output_image = image.copy()
110 | output_image = MASK_ANNOTATOR.annotate(output_image, detections)
111 | output_image = BOX_ANNOTATOR.annotate(output_image, detections)
112 | output_image = LABEL_ANNOTATOR.annotate(output_image, detections)
113 | return output_image
114 |
115 |
116 | # def on_mode_dropdown_change(text):
117 | # return [
118 | # gr.Textbox(visible=text == IMAGE_OPEN_VOCABULARY_DETECTION_MODE),
119 | # gr.Textbox(visible=text == IMAGE_CAPTION_GROUNDING_MASKS_MODE),
120 | # ]
121 |
122 | def lazy_load_models(device: torch.device, sam_image_model: str):  # load Florence2 and SAM2 on first use; reload SAM2 if the checkpoint changed
123 | global SAM_IMAGE_MODEL
124 | global loaded_sam_image_model
125 | global FLORENCE_MODEL
126 | global FLORENCE_PROCESSOR
127 | global DEVICE
128 | if device != DEVICE:
129 | offload_models(delete=True)
130 | DEVICE = device
131 | if SAM_IMAGE_MODEL is None:
132 | SAM_IMAGE_MODEL = load_sam_image_model(device=DEVICE, checkpoint=sam_image_model)
133 | loaded_sam_image_model = sam_image_model
134 | elif loaded_sam_image_model != sam_image_model:
135 | print(f"DEBUG [ComfyUI_Florence2SAM2::lazy_load_models] Old model {loaded_sam_image_model} != new model {sam_image_model} => releasing memory")
136 | SAM_IMAGE_MODEL.model.cpu()
137 | del SAM_IMAGE_MODEL
138 | gc.collect()
139 | torch.cuda.empty_cache()
140 | SAM_IMAGE_MODEL = load_sam_image_model(device=DEVICE, checkpoint=sam_image_model)
141 | loaded_sam_image_model = sam_image_model
142 | if FLORENCE_MODEL is None or FLORENCE_PROCESSOR is None:
143 | assert FLORENCE_MODEL is None and FLORENCE_PROCESSOR is None
144 | FLORENCE_MODEL, FLORENCE_PROCESSOR = load_florence_model(device=DEVICE)
145 |
146 | # The models could have been offloaded to RAM by offload_models(); if they're already on `device`, this is a no-op
147 | SAM_IMAGE_MODEL.model.to(device)
148 | FLORENCE_MODEL.to(device)
149 |     # FLORENCE_PROCESSOR.to(device)  # not a model, nothing to move
150 |
151 | def offload_models(delete=False):  # move the models to ComfyUI's offload device, or free them entirely if delete=True
152 | global SAM_IMAGE_MODEL
153 | global FLORENCE_MODEL
154 | global FLORENCE_PROCESSOR
155 | offload_device = mm.unet_offload_device()
156 | do_gc = False
157 | if SAM_IMAGE_MODEL is not None:
158 | if delete:
159 | SAM_IMAGE_MODEL.model.cpu()
160 | del SAM_IMAGE_MODEL
161 | SAM_IMAGE_MODEL = None
162 | do_gc = True
163 | else:
164 | SAM_IMAGE_MODEL.model.to(offload_device)
165 | if FLORENCE_MODEL is not None:
166 | if delete:
167 | FLORENCE_MODEL.cpu()
168 | del FLORENCE_MODEL
169 | FLORENCE_MODEL = None
170 | do_gc = True
171 | else:
172 | FLORENCE_MODEL.to(offload_device)
173 | if FLORENCE_PROCESSOR is not None:
174 | if delete:
175 | # FLORENCE_PROCESSOR.cpu()
176 | del FLORENCE_PROCESSOR
177 | FLORENCE_PROCESSOR = None
178 | do_gc = True
179 | else:
180 | # FLORENCE_PROCESSOR.cpu()
181 | pass
182 | mm.soft_empty_cache()
183 | if do_gc:
184 | gc.collect()
185 |
186 | def process_image(device: torch.device, sam_image_model: str, image: Image.Image, promt: str, keep_model_loaded: bool) -> Tuple[Optional[Image.Image], Optional[Image.Image], Optional[Image.Image]]:
187 | lazy_load_models(device, sam_image_model)
188 | annotated_image, mask_list = _process_image(IMAGE_OPEN_VOCABULARY_DETECTION_MODE, image, promt)
189 | if mask_list is not None and len(mask_list) > 0:
190 | mask = np.any(mask_list, axis=0) # Merge masks into a single mask
191 | mask = (mask * 255).astype(np.uint8)
192 | else:
193 | print(f"Florence2SAM2: No objects of class {promt} found in the image.")
194 | mask = np.zeros((image.height, image.width), dtype=np.uint8)
195 | mask = Image.fromarray(mask).convert("L") # Convert to 8-bit grayscale
196 | masked_image = Image.new("RGB", image.size, (0, 0, 0))
197 | masked_image.paste(image, mask=mask)
198 | if not keep_model_loaded:
199 | offload_models()
200 | return annotated_image, mask, masked_image
201 |
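# Hypothetical usage sketch, kept commented out like the demo code below: one way
# process_image() could be driven directly from Python. The checkpoint identifier
# "sam2_hiera_small.pt" and the test image path are assumptions based on the files
# shipped in this repository (and a working directory at the repo root); in normal
# use the ComfyUI node wrapper supplies these values.
#
# if __name__ == "__main__":
#     image = Image.open("test/01_source.png").convert("RGB")
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#     annotated, mask, masked = process_image(
#         device=device,
#         sam_image_model="sam2_hiera_small.pt",  # assumed checkpoint name; see utils/sam.py for the accepted format
#         image=image,
#         promt="person",                         # parameter name spelled as in this file's signature
#         keep_model_loaded=False,                # offload the models again after this call
#     )
#     annotated.save("annotated.png")
#     mask.save("mask.png")
#     masked.save("masked.png")
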
202 | @torch.inference_mode()
203 | @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
204 | def _process_image(
205 | mode_dropdown=IMAGE_OPEN_VOCABULARY_DETECTION_MODE, image_input=None, text_input=None
206 | ) -> Tuple[Optional[Image.Image], Optional[np.ndarray]]:
207 | """
208 | Process an image with Florence2 and SAM2.
209 |
210 |     Note that the models are lazy-loaded (via lazy_load_models, called from process_image): nothing is loaded at startup, and nothing is loaded at all if inference is never run.
211 |
212 | @param mode_dropdown: The mode of the Florence2 model. Must be IMAGE_OPEN_VOCABULARY_DETECTION_MODE.
213 | @param image_input: The image to process.
214 | @param text_input: The text prompt to use for the Florence2 model.
215 |
216 |     @return: Tuple[Optional[Image.Image], Optional[np.ndarray]]: the annotated image and the per-detection Boolean masks (stacked along axis 0); both are None if the image or text prompt is missing
217 | """
218 | global SAM_IMAGE_MODEL
219 | global FLORENCE_MODEL
220 | global FLORENCE_PROCESSOR
221 |
222 | if not image_input:
223 | # gr.Info("Please upload an image.")
224 | return None, None
225 |
226 | if mode_dropdown == IMAGE_OPEN_VOCABULARY_DETECTION_MODE:
227 | if not text_input:
228 | # gr.Info("Please enter a text prompt.")
229 | return None, None
230 |
231 | texts = [prompt.strip() for prompt in text_input.split(",")]
232 | detections_list = []
233 | for text in texts:
234 | _, result = run_florence_inference(
235 | model=FLORENCE_MODEL,
236 | processor=FLORENCE_PROCESSOR,
237 | device=DEVICE,
238 | image=image_input,
239 | task=FLORENCE_OPEN_VOCABULARY_DETECTION_TASK,
240 | text=text
241 | )
242 | detections = sv.Detections.from_lmm(
243 | lmm=sv.LMM.FLORENCE_2,
244 | result=result,
245 | resolution_wh=image_input.size
246 | )
247 | detections = run_sam_inference(SAM_IMAGE_MODEL, image_input, detections)
248 | detections_list.append(detections)
249 |
250 | detections = sv.Detections.merge(detections_list)
251 | detections = run_sam_inference(SAM_IMAGE_MODEL, image_input, detections)
252 | return annotate_image(image_input, detections), detections.mask
253 |
254 | # if mode_dropdown == IMAGE_CAPTION_GROUNDING_MASKS_MODE:
255 | # _, result = run_florence_inference(
256 | # model=FLORENCE_MODEL,
257 | # processor=FLORENCE_PROCESSOR,
258 | # device=DEVICE,
259 | # image=image_input,
260 | # task=FLORENCE_DETAILED_CAPTION_TASK
261 | # )
262 | # caption = result[FLORENCE_DETAILED_CAPTION_TASK]
263 | # _, result = run_florence_inference(
264 | # model=FLORENCE_MODEL,
265 | # processor=FLORENCE_PROCESSOR,
266 | # device=DEVICE,
267 | # image=image_input,
268 | # task=FLORENCE_CAPTION_TO_PHRASE_GROUNDING_TASK,
269 | # text=caption
270 | # )
271 | # detections = sv.Detections.from_lmm(
272 | # lmm=sv.LMM.FLORENCE_2,
273 | # result=result,
274 | # resolution_wh=image_input.size
275 | # )
276 | # detections = run_sam_inference(SAM_IMAGE_MODEL, image_input, detections)
277 | # return annotate_image(image_input, detections), caption
278 |
279 |
280 | # @spaces.GPU(duration=300)
281 | # @torch.inference_mode()
282 | # @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
283 | # def process_video(
284 | # video_input, text_input, progress=gr.Progress(track_tqdm=True)
285 | # ) -> Optional[str]:
286 | # if not video_input:
287 | # gr.Info("Please upload a video.")
288 | # return None
289 |
290 | # if not text_input:
291 | # gr.Info("Please enter a text prompt.")
292 | # return None
293 |
294 | # frame_generator = sv.get_video_frames_generator(video_input)
295 | # frame = next(frame_generator)
296 | # frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
297 |
298 | # texts = [prompt.strip() for prompt in text_input.split(",")]
299 | # detections_list = []
300 | # for text in texts:
301 | # _, result = run_florence_inference(
302 | # model=FLORENCE_MODEL,
303 | # processor=FLORENCE_PROCESSOR,
304 | # device=DEVICE,
305 | # image=frame,
306 | # task=FLORENCE_OPEN_VOCABULARY_DETECTION_TASK,
307 | # text=text
308 | # )
309 | # detections = sv.Detections.from_lmm(
310 | # lmm=sv.LMM.FLORENCE_2,
311 | # result=result,
312 | # resolution_wh=frame.size
313 | # )
314 | # detections = run_sam_inference(SAM_IMAGE_MODEL, frame, detections)
315 | # detections_list.append(detections)
316 |
317 | # detections = sv.Detections.merge(detections_list)
318 | # detections = run_sam_inference(SAM_IMAGE_MODEL, frame, detections)
319 |
320 | # if len(detections.mask) == 0:
321 | # gr.Info(
322 | # "No objects of class {text_input} found in the first frame of the video. "
323 | # "Trim the video to make the object appear in the first frame or try a "
324 | # "different text prompt."
325 | # )
326 | # return None
327 |
328 | # name = generate_unique_name()
329 | # frame_directory_path = os.path.join(VIDEO_TARGET_DIRECTORY, name)
330 | # frames_sink = sv.ImageSink(
331 | # target_dir_path=frame_directory_path,
332 | # image_name_pattern="{:05d}.jpeg"
333 | # )
334 |
335 | # video_info = sv.VideoInfo.from_video_path(video_input)
336 | # video_info.width = int(video_info.width * VIDEO_SCALE_FACTOR)
337 | # video_info.height = int(video_info.height * VIDEO_SCALE_FACTOR)
338 |
339 | # frames_generator = sv.get_video_frames_generator(video_input)
340 | # with frames_sink:
341 | # for frame in tqdm(
342 | # frames_generator,
343 | # total=video_info.total_frames,
344 | # desc="splitting video into frames"
345 | # ):
346 | # frame = sv.scale_image(frame, VIDEO_SCALE_FACTOR)
347 | # frames_sink.save_image(frame)
348 |
349 | # inference_state = SAM_VIDEO_MODEL.init_state(
350 | # video_path=frame_directory_path,
351 | # device=DEVICE
352 | # )
353 |
354 | # for mask_index, mask in enumerate(detections.mask):
355 | # _, object_ids, mask_logits = SAM_VIDEO_MODEL.add_new_mask(
356 | # inference_state=inference_state,
357 | # frame_idx=0,
358 | # obj_id=mask_index,
359 | # mask=mask
360 | # )
361 |
362 | # video_path = os.path.join(VIDEO_TARGET_DIRECTORY, f"{name}.mp4")
363 | # frames_generator = sv.get_video_frames_generator(video_input)
364 | # masks_generator = SAM_VIDEO_MODEL.propagate_in_video(inference_state)
365 | # with sv.VideoSink(video_path, video_info=video_info) as sink:
366 | # for frame, (_, tracker_ids, mask_logits) in zip(frames_generator, masks_generator):
367 | # frame = sv.scale_image(frame, VIDEO_SCALE_FACTOR)
368 | # masks = (mask_logits > 0.0).cpu().numpy().astype(bool)
369 | # if len(masks.shape) == 4:
370 | # masks = np.squeeze(masks, axis=1)
371 |
372 | # detections = sv.Detections(
373 | # xyxy=sv.mask_to_xyxy(masks=masks),
374 | # mask=masks,
375 | # class_id=np.array(tracker_ids)
376 | # )
377 | # annotated_frame = frame.copy()
378 | # annotated_frame = MASK_ANNOTATOR.annotate(
379 | # scene=annotated_frame, detections=detections)
380 | # annotated_frame = BOX_ANNOTATOR.annotate(
381 | # scene=annotated_frame, detections=detections)
382 | # sink.write_frame(annotated_frame)
383 |
384 | # delete_directory(frame_directory_path)
385 | # return video_path
386 |
387 |
388 | # with gr.Blocks() as demo:
389 | # gr.Markdown(MARKDOWN)
390 | # with gr.Tab("Image"):
391 | # image_processing_mode_dropdown_component = gr.Dropdown(
392 | # choices=IMAGE_INFERENCE_MODES,
393 | # value=IMAGE_INFERENCE_MODES[0],
394 | # label="Mode",
395 | # info="Select a mode to use.",
396 | # interactive=True
397 | # )
398 | # with gr.Row():
399 | # with gr.Column():
400 | # image_processing_image_input_component = gr.Image(
401 | # type='pil', label='Upload image')
402 | # image_processing_text_input_component = gr.Textbox(
403 | # label='Text prompt',
404 | # placeholder='Enter comma separated text prompts')
405 | # image_processing_submit_button_component = gr.Button(
406 | # value='Submit', variant='primary')
407 | # with gr.Column():
408 | # image_processing_image_output_component = gr.Image(
409 | # type='pil', label='Image output')
410 | # image_processing_text_output_component = gr.Textbox(
411 | # label='Caption output', visible=False)
412 |
413 | # with gr.Row():
414 | # gr.Examples(
415 | # fn=process_image,
416 | # examples=IMAGE_PROCESSING_EXAMPLES,
417 | # inputs=[
418 | # image_processing_mode_dropdown_component,
419 | # image_processing_image_input_component,
420 | # image_processing_text_input_component
421 | # ],
422 | # outputs=[
423 | # image_processing_image_output_component,
424 | # image_processing_text_output_component
425 | # ],
426 | # run_on_click=True
427 | # )
428 | # with gr.Tab("Video"):
429 | # video_processing_mode_dropdown_component = gr.Dropdown(
430 | # choices=VIDEO_INFERENCE_MODES,
431 | # value=VIDEO_INFERENCE_MODES[0],
432 | # label="Mode",
433 | # info="Select a mode to use.",
434 | # interactive=True
435 | # )
436 | # with gr.Row():
437 | # with gr.Column():
438 | # video_processing_video_input_component = gr.Video(
439 | # label='Upload video')
440 | # video_processing_text_input_component = gr.Textbox(
441 | # label='Text prompt',
442 | # placeholder='Enter comma separated text prompts')
443 | # video_processing_submit_button_component = gr.Button(
444 | # value='Submit', variant='primary')
445 | # with gr.Column():
446 | # video_processing_video_output_component = gr.Video(
447 | # label='Video output')
448 | # with gr.Row():
449 | # gr.Examples(
450 | # fn=process_video,
451 | # examples=VIDEO_PROCESSING_EXAMPLES,
452 | # inputs=[
453 | # video_processing_video_input_component,
454 | # video_processing_text_input_component
455 | # ],
456 | # outputs=video_processing_video_output_component,
457 | # run_on_click=True
458 | # )
459 |
460 | # image_processing_submit_button_component.click(
461 | # fn=process_image,
462 | # inputs=[
463 | # image_processing_mode_dropdown_component,
464 | # image_processing_image_input_component,
465 | # image_processing_text_input_component
466 | # ],
467 | # outputs=[
468 | # image_processing_image_output_component,
469 | # image_processing_text_output_component
470 | # ]
471 | # )
472 | # image_processing_text_input_component.submit(
473 | # fn=process_image,
474 | # inputs=[
475 | # image_processing_mode_dropdown_component,
476 | # image_processing_image_input_component,
477 | # image_processing_text_input_component
478 | # ],
479 | # outputs=[
480 | # image_processing_image_output_component,
481 | # image_processing_text_output_component
482 | # ]
483 | # )
484 | # image_processing_mode_dropdown_component.change(
485 | # on_mode_dropdown_change,
486 | # inputs=[image_processing_mode_dropdown_component],
487 | # outputs=[
488 | # image_processing_text_input_component,
489 | # image_processing_text_output_component
490 | # ]
491 | # )
492 | # video_processing_submit_button_component.click(
493 | # fn=process_video,
494 | # inputs=[
495 | # video_processing_video_input_component,
496 | # video_processing_text_input_component
497 | # ],
498 | # outputs=video_processing_video_output_component
499 | # )
500 | # video_processing_text_input_component.submit(
501 | # fn=process_video,
502 | # inputs=[
503 | # video_processing_video_input_component,
504 | # video_processing_text_input_component
505 | # ],
506 | # outputs=video_processing_video_output_component
507 | # )
508 |
509 | # demo.launch(debug=False, show_error=True)
510 |
--------------------------------------------------------------------------------