├── .gitignore ├── Florence.py ├── LICENSE ├── README.md ├── __init__.py ├── workflow.png ├── workflow_bbox.png └── workflow_seg_crop.png /.gitignore: -------------------------------------------------------------------------------- 1 | models/ 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | *.py,cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | cover/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | .pybuilder/ 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | # For a library or package, you might want to ignore these files since the code is 89 | # intended to run in multiple environments; otherwise, check them in: 90 | # .python-version 91 | 92 | # pipenv 93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 96 | # install all needed dependencies. 97 | #Pipfile.lock 98 | 99 | # poetry 100 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 101 | # This is especially recommended for binary packages to ensure reproducibility, and is more 102 | # commonly ignored for libraries. 103 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 104 | #poetry.lock 105 | 106 | # pdm 107 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 108 | #pdm.lock 109 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 110 | # in version control. 111 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 112 | .pdm.toml 113 | .pdm-python 114 | .pdm-build/ 115 | 116 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 117 | __pypackages__/ 118 | 119 | # Celery stuff 120 | celerybeat-schedule 121 | celerybeat.pid 122 | 123 | # SageMath parsed files 124 | *.sage.py 125 | 126 | # Environments 127 | .env 128 | .venv 129 | env/ 130 | venv/ 131 | ENV/ 132 | env.bak/ 133 | venv.bak/ 134 | 135 | # Spyder project settings 136 | .spyderproject 137 | .spyproject 138 | 139 | # Rope project settings 140 | .ropeproject 141 | 142 | # mkdocs documentation 143 | /site 144 | 145 | # mypy 146 | .mypy_cache/ 147 | .dmypy.json 148 | dmypy.json 149 | 150 | # Pyre type checker 151 | .pyre/ 152 | 153 | # pytype static type analyzer 154 | .pytype/ 155 | 156 | # Cython debug symbols 157 | cython_debug/ 158 | 159 | # PyCharm 160 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 161 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 162 | # and can be added to the global gitignore or merged into this file. For a more nuclear 163 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
164 | #.idea/ 165 | -------------------------------------------------------------------------------- /Florence.py: -------------------------------------------------------------------------------- 1 | import os 2 | import io 3 | import copy 4 | import gc 5 | from unittest.mock import patch 6 | import random 7 | 8 | from PIL import Image, ImageDraw, ImageFont 9 | import matplotlib.pyplot as plt 10 | import matplotlib.patches as patches 11 | 12 | import numpy as np 13 | import torch 14 | from transformers import AutoProcessor, AutoModelForCausalLM 15 | from transformers.dynamic_module_utils import get_imports 16 | 17 | # Comfy Utils 18 | import folder_paths 19 | import comfy.model_management 20 | 21 | colormap = ['blue','orange','green','purple','brown','pink','gray','olive','cyan','red', 22 | 'lime','indigo','violet','aqua','magenta','coral','gold','tan','skyblue'] 23 | 24 | def fixed_get_imports(filename: str | os.PathLike) -> list[str]: 25 | """Workaround for FlashAttention""" 26 | if os.path.basename(filename) != "modeling_florence2.py": 27 | return get_imports(filename) 28 | imports = get_imports(filename) 29 | imports.remove("flash_attn") 30 | return imports 31 | 32 | def load_model(version, device): 33 | comfy_model_dir = os.path.join(folder_paths.models_dir, "LLM") 34 | if not os.path.exists(comfy_model_dir): 35 | os.mkdir(comfy_model_dir) 36 | 37 | identifier = "microsoft/Florence-2-" + version 38 | 39 | with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports): 40 | model = AutoModelForCausalLM.from_pretrained(identifier, cache_dir=comfy_model_dir, trust_remote_code=True) 41 | processor = AutoProcessor.from_pretrained(identifier, cache_dir=comfy_model_dir, trust_remote_code=True) 42 | 43 | model = model.to(device) 44 | return (model, processor) 45 | 46 | def fig_to_pil(fig): 47 | buf = io.BytesIO() 48 | fig.savefig(buf, format='png', dpi=100, bbox_inches='tight', pad_inches=0) 49 | buf.seek(0) 50 | pil = Image.open(buf) 51 | plt.close() 
def plot_bbox(image, data):
    """Draw labelled bounding boxes over ``image`` on a new matplotlib figure.

    Args:
        image: Image shown with ``ax.imshow``; its ``width``/``height`` set
            the figure size (in inches, divided by 100).
        data: Mapping with ``'bboxes'`` (each ``x1, y1, x2, y2``) and
            ``'labels'``, iterated in lockstep.

    Returns:
        The matplotlib figure (axes hidden).
    """
    fig, ax = plt.subplots()
    fig.set_size_inches(image.width / 100, image.height / 100)
    ax.imshow(image)
    for i, (bbox, label) in enumerate(zip(data['bboxes'], data['labels'])):
        x1, y1, x2, y2 = bbox
        rect = patches.Rectangle((x1, y1), x2 - x1, y2 - y1, linewidth=1, edgecolor='r', facecolor='none')
        ax.add_patch(rect)
        enum_label = f"{i}: {label}"
        # Use the explicit axes rather than plt.text, which targets whatever
        # axes happen to be current.
        ax.text(x1 + 7, y1 + 17, enum_label, color='white', fontsize=8, bbox=dict(facecolor='red', alpha=0.5))
    ax.axis('off')
    return fig


def draw_polygons(image, prediction, fill_mask=False):
    """Return a copy of ``image`` with the predicted polygons outlined and labelled.

    Args:
        image: Source image; it is deep-copied, the original is not drawn on.
        prediction: Mapping with ``'polygons'`` (one list of polygons per
            label, each polygon reshaped to ``(-1, 2)`` points) and ``'labels'``.
        fill_mask: When true, polygons are filled with their outline color.

    Returns:
        The annotated copy of the image.
    """
    output_image = copy.deepcopy(image)
    draw = ImageDraw.Draw(output_image)
    for polygons, label in zip(prediction['polygons'], prediction['labels']):
        color = random.choice(colormap)
        # fill=None means "outline only", so one polygon() call covers both modes.
        fill_color = color if fill_mask else None
        for polygon in polygons:
            points = np.array(polygon).reshape(-1, 2)
            if len(points) < 3:
                # Fewer than three points cannot form a polygon; skip it.
                print('Invalid polygon:', points)
                continue
            flat = points.reshape(-1).tolist()
            draw.polygon(flat, outline=color, fill=fill_color)
            # Label is anchored just inside the polygon's first vertex.
            draw.text((flat[0] + 8, flat[1] + 2), label, fill=color)
    return output_image


def convert_to_od_format(data):
    """Map a result using ``'bboxes_labels'`` onto the ``'bboxes'``/``'labels'`` layout.

    Args:
        data: Mapping that may contain ``'bboxes'`` and ``'bboxes_labels'``.

    Returns:
        ``{'bboxes': ..., 'labels': ...}``; a missing key yields an empty list.
    """
    return {
        'bboxes': data.get('bboxes', []),
        'labels': data.get('bboxes_labels', []),
    }


def draw_ocr_bboxes(image, prediction):
    """Return a copy of ``image`` with OCR boxes outlined and their text drawn.

    Args:
        image: Source image; it is deep-copied, the original is not drawn on.
        prediction: Mapping with ``'quad_boxes'`` and ``'labels'``, iterated
            in lockstep. The first two numbers of each box are used as the
            text anchor.

    Returns:
        The annotated copy of the image.
    """
    output_image = copy.deepcopy(image)
    draw = ImageDraw.Draw(output_image)
    for box, label in zip(prediction['quad_boxes'], prediction['labels']):
        color = random.choice(colormap)
        new_box = np.array(box).tolist()
        draw.polygon(new_box, width=3, outline=color)
        draw.text((new_box[0] + 8, new_box[1] + 2),
                  "{}".format(label),
                  align="right",
                  fill=color)
    return output_image
# Task names offered as choices for the "task" input of the Florence2 node.
TASK_OPTIONS = [
    "caption",
    "detailed caption",
    "more detailed caption",
    "object detection",
    "dense region caption",
    "region proposal",
    "caption to phrase grounding",
    "referring expression segmentation",
    "region to segmentation",
    "open vocabulary detection",
    "region to category",
    "region to description",
    "OCR",
    "OCR with region",
]


class LoadFlorence2Model:
    """Node that loads a Florence-2 model and hands it on as a ``FLORENCE2`` bundle.

    The most recently loaded model/processor pair is kept on the instance and
    reused while the requested version stays the same.
    """

    RETURN_TYPES = ("FLORENCE2", )
    FUNCTION = "load"
    CATEGORY = "Florence2"

    def __init__(self):
        # Nothing is loaded until load() is called for the first time.
        self.model = None
        self.processor = None
        self.version = None
        self.device = comfy.model_management.get_torch_device()

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the single required input: the model version to load."""
        versions = ["base", "base-ft", "large", "large-ft"]
        return {"required": {"version": (versions,)}}

    def load(self, version):
        """Return the bundle for ``version``, loading the model only on a version change.

        Args:
            version: One of the version names listed by ``INPUT_TYPES``.

        Returns:
            A one-element tuple holding a dict with the keys ``'model'``,
            ``'processor'``, ``'version'`` and ``'device'``.
        """
        version_changed = self.version != version
        if version_changed:
            loaded_model, loaded_processor = load_model(version, self.device)
            self.model = loaded_model
            self.processor = loaded_processor
            self.version = version

        bundle = {
            'model': self.model,
            'processor': self.processor,
            'version': self.version,
            'device': self.device,
        }
        return (bundle, )
("preview", "string", "F_BBOXES",) 177 | FUNCTION = "apply" 178 | CATEGORY = "Florence2" 179 | 180 | def apply(self, FLORENCE2, image, task, text_input, max_new_tokens, num_beams, do_sample, fill_mask): 181 | img = 255. * image[0].cpu().numpy() 182 | img = Image.fromarray(np.clip(img, 0, 255).astype(np.uint8)) 183 | 184 | self.model = FLORENCE2['model'] 185 | self.processor = FLORENCE2['processor'] 186 | self.version = FLORENCE2['version'] 187 | self.device = FLORENCE2['device'] 188 | 189 | results, output_image = self.process_image(img, task, max_new_tokens, num_beams, do_sample, fill_mask, text_input) 190 | # bboxes, labels OR polygons, labels 191 | if isinstance(results, dict): 192 | results["width"] = img.width 193 | results["height"] = img.height 194 | 195 | if output_image == None: 196 | output_image = image[0].detach().clone().unsqueeze(0) 197 | else: 198 | output_image = np.asarray(output_image).astype(np.float32) / 255 199 | output_image = torch.from_numpy(output_image).unsqueeze(0) 200 | 201 | return (output_image, str(results), results) 202 | 203 | def run_example(self, task_prompt, image, max_new_tokens, num_beams, do_sample, text_input=None): 204 | if text_input is None: 205 | prompt = task_prompt 206 | else: 207 | prompt = task_prompt + text_input 208 | inputs = self.processor(text=prompt, images=image, return_tensors="pt").to(self.device) 209 | generated_ids = self.model.generate( 210 | input_ids=inputs["input_ids"], 211 | pixel_values=inputs["pixel_values"], 212 | max_new_tokens=max_new_tokens, 213 | early_stopping=False, 214 | do_sample=do_sample, 215 | num_beams=num_beams, 216 | ) 217 | generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=False)[0] 218 | parsed_answer = self.processor.post_process_generation( 219 | generated_text, 220 | task=task_prompt, 221 | image_size=(image.width, image.height) 222 | ) 223 | return parsed_answer 224 | 225 | def process_image(self, image, task_prompt, max_new_tokens, num_beams, 
do_sample, fill_mask, text_input=None): 226 | if task_prompt == 'caption': 227 | task_prompt = '