├── .gitattributes ├── .gitignore ├── requirements.txt ├── __init__.py ├── pyproject.toml ├── .github └── workflows │ └── publish.yaml ├── LICENSE ├── README.md ├── configuration_florence2.py ├── nodes.py └── modeling_florence2.py /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | *pyc 3 | .vscode 4 | __pycache__ 5 | *.egg-info 6 | *.bak 7 | checkpoints 8 | results 9 | backup -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | transformers>=4.39.0,!=4.50.* 2 | matplotlib 3 | timm 4 | pillow>=10.2.0 5 | peft 6 | accelerate>=0.26.0 7 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | from .nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS 2 | 3 | __all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"] -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "comfyui-florence2" 3 | description = "Nodes to use Florence2 VLM for image vision tasks: object detection, captioning, segmentation and ocr" 4 | version = "1.0.7" 5 | license = "MIT" 6 | dependencies = ["transformers>=4.39.0,!=4.50.*"] 7 | 8 | [project.urls] 9 | Repository = "https://github.com/kijai/ComfyUI-Florence2" 10 | # Used by Comfy Registry https://comfyregistry.org 11 | 12 | [tool.comfy] 13 | PublisherId = "kijai" 14 | DisplayName = "ComfyUI-Florence2" 15 | Icon = "" 16 | -------------------------------------------------------------------------------- /.github/workflows/publish.yaml: -------------------------------------------------------------------------------- 1 | name: Publish to Comfy registry 2 | on: 3 | workflow_dispatch: 4 | push: 5 | branches: 6 | - main 7 | paths: 8 | - "pyproject.toml" 9 | 10 | permissions: 11 | issues: write 12 | 13 | jobs: 14 | publish-node: 15 | name: Publish Custom Node to registry 16 | runs-on: ubuntu-latest 17 | if: ${{ github.repository_owner == 'kijai' }} 18 | steps: 19 | - name: Check out code 20 | uses: actions/checkout@v4 21 | - name: Publish Custom Node 22 | uses: Comfy-Org/publish-node-action@v1 23 | with: 24 | ## Add your own personal access token to your Github Repository secrets and reference it here. 
25 | personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }} 26 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Jukka Seppänen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Florence2 in ComfyUI 2 | 3 | > Florence-2 is an advanced vision foundation model that uses a prompt-based approach to handle a wide range of vision and vision-language tasks. 4 | Florence-2 can interpret simple text prompts to perform tasks like captioning, object detection, and segmentation. 5 | It leverages our FLD-5B dataset, containing 5.4 billion annotations across 126 million images, to master multi-task learning. 6 | The model's sequence-to-sequence architecture enables it to excel in both zero-shot and fine-tuned settings, proving to be a competitive vision foundation model. 7 | 8 | ## New Feature: Document Visual Question Answering (DocVQA) 9 | 10 | This fork includes support for Document Visual Question Answering (DocVQA) using the Florence2 model. DocVQA allows you to ask questions about the content of document images, and the model will provide answers based on the visual and textual information in the document. This feature is particularly useful for extracting information from scanned documents, forms, receipts, and other text-heavy images. 11 | 12 | ## Installation: 13 | 14 | Clone this repository into the `ComfyUI/custom_nodes` folder.
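For example, from inside the `ComfyUI/custom_nodes` folder: `git clone https://github.com/kijai/ComfyUI-Florence2.git`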
15 | 16 | Install the dependencies from requirements.txt; transformers version 4.39.0 minimum is required: 17 | 18 | `pip install -r requirements.txt` 19 | 20 | or, if you use the portable version (run this in the ComfyUI_windows_portable folder): 21 | 22 | `python_embeded\python.exe -m pip install -r ComfyUI\custom_nodes\ComfyUI-Florence2\requirements.txt` 23 | 24 | ![image](https://github.com/kijai/ComfyUI-Florence2/assets/40791699/4d537ac7-5490-470f-92f5-3007da7b9cc7) 25 | ![image](https://github.com/kijai/ComfyUI-Florence2/assets/40791699/512357b7-39ee-43ee-bb63-7347b0a8d07d) 26 | 27 | Supports most Florence2 models, which can be automatically downloaded with the `DownloadAndLoadFlorence2Model` node to `ComfyUI/models/LLM`: 28 | 29 | Official: 30 | 31 | https://huggingface.co/microsoft/Florence-2-base 32 | 33 | https://huggingface.co/microsoft/Florence-2-base-ft 34 | 35 | https://huggingface.co/microsoft/Florence-2-large 36 | 37 | https://huggingface.co/microsoft/Florence-2-large-ft 38 | 39 | https://huggingface.co/HuggingFaceM4/Florence-2-DocVQA 40 | 41 | Tested finetunes: 42 | 43 | https://huggingface.co/MiaoshouAI/Florence-2-base-PromptGen-v1.5 44 | 45 | https://huggingface.co/MiaoshouAI/Florence-2-large-PromptGen-v1.5 46 | 47 | https://huggingface.co/thwri/CogFlorence-2.2-Large 48 | 49 | https://huggingface.co/HuggingFaceM4/Florence-2-DocVQA 50 | 51 | https://huggingface.co/gokaygokay/Florence-2-SD3-Captioner 52 | 53 | https://huggingface.co/gokaygokay/Florence-2-Flux-Large 54 | 55 | https://huggingface.co/NikshepShetty/Florence-2-pixelprose 56 | 57 | ## Using DocVQA 58 | 59 | To use the DocVQA feature: 60 | 1. Load a document image into ComfyUI. 61 | 2. Connect the image to the Florence2 DocVQA node. 62 | 3. Input your question about the document. 63 | 4. The node will output the answer based on the document's content. 64 | 65 | Example questions: 66 | - "What is the total amount on this receipt?" 67 | - "What is the date mentioned in this form?" 68 | - "Who is the sender of this letter?" 69 | 70 | Note: The accuracy of answers depends on the quality of the input image and the complexity of the question. 71 | -------------------------------------------------------------------------------- /configuration_florence2.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved. 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import warnings 15 | """ Florence-2 configuration""" 16 | 17 | from typing import Optional 18 | 19 | from transformers import AutoConfig 20 | from transformers.configuration_utils import PretrainedConfig 21 | from transformers.utils import logging 22 | 23 | logger = logging.get_logger(__name__) 24 | 25 | class Florence2VisionConfig(PretrainedConfig): 26 | r""" 27 | This is the configuration class to store the configuration of a [`Florence2VisionModel`].
It is used to instantiate a Florence2VisionModel 28 | according to the specified arguments, defining the model architecture. Instantiating a configuration with the 29 | defaults will yield a similar configuration to that of the Florence2VisionModel architecture. 30 | 31 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the 32 | documentation from [`PretrainedConfig`] for more information. 33 | 34 | Args: 35 | drop_path_rate (`float`, *optional*, defaults to 0.1): 36 | The dropout rate of the drop path layer. 37 | patch_size (`List[int]`, *optional*, defaults to [7, 3, 3, 3]): 38 | The patch size of the image. 39 | patch_stride (`List[int]`, *optional*, defaults to [4, 2, 2, 2]): 40 | The patch stride of the image. 41 | patch_padding (`List[int]`, *optional*, defaults to [3, 1, 1, 1]): 42 | The patch padding of the image. 43 | patch_prenorm (`List[bool]`, *optional*, defaults to [false, true, true, true]): 44 | Whether to apply layer normalization before the patch embedding layer. 45 | enable_checkpoint (`bool`, *optional*, defaults to False): 46 | Whether to enable checkpointing. 47 | dim_embed (`List[int]`, *optional*, defaults to [256, 512, 1024, 2048]): 48 | The dimension of the embedding layer. 49 | num_heads (`List[int]`, *optional*, defaults to [8, 16, 32, 64]): 50 | The number of attention heads. 51 | num_groups (`List[int]`, *optional*, defaults to [8, 16, 32, 64]): 52 | The number of groups. 53 | depths (`List[int]`, *optional*, defaults to [1, 1, 9, 1]): 54 | The depth of the model. 55 | window_size (`int`, *optional*, defaults to 12): 56 | The window size of the model. 57 | projection_dim (`int`, *optional*, defaults to 1024): 58 | The dimension of the projection layer. 59 | visual_temporal_embedding (`dict`, *optional*): 60 | The configuration of the visual temporal embedding. 61 | image_pos_embed (`dict`, *optional*): 62 | The configuration of the image position embedding. 63 | image_feature_source (`List[str]`, *optional*, defaults to ["spatial_avg_pool", "temporal_avg_pool"]): 64 | The source of the image feature. 
65 | Example: 66 | 67 | ```python 68 | >>> from transformers import Florence2VisionConfig, Florence2VisionModel 69 | 70 | >>> # Initializing a Florence2 Vision style configuration 71 | >>> configuration = Florence2VisionConfig() 72 | 73 | >>> # Initializing a model (with random weights) 74 | >>> model = Florence2VisionModel(configuration) 75 | 76 | >>> # Accessing the model configuration 77 | >>> configuration = model.config 78 | ```""" 79 | 80 | model_type = "florence2_vision" 81 | keys_to_ignore_at_inference = ["past_key_values"] 82 | 83 | def __init__( 84 | self, 85 | drop_path_rate=0.1, 86 | patch_size=[7, 3, 3, 3], 87 | patch_stride=[4, 2, 2, 2], 88 | patch_padding=[3, 1, 1, 1], 89 | patch_prenorm=[False, True, True, True], 90 | enable_checkpoint=False, 91 | dim_embed=[256, 512, 1024, 2048], 92 | num_heads=[8, 16, 32, 64], 93 | num_groups=[8, 16, 32, 64], 94 | depths=[1, 1, 9, 1], 95 | window_size=12, 96 | projection_dim=1024, 97 | visual_temporal_embedding=None, 98 | image_pos_embed=None, 99 | image_feature_source=["spatial_avg_pool", "temporal_avg_pool"], 100 | **kwargs, 101 | ): 102 | self.drop_path_rate = drop_path_rate 103 | self.patch_size = patch_size 104 | self.patch_stride = patch_stride 105 | self.patch_padding = patch_padding 106 | self.patch_prenorm = patch_prenorm 107 | self.enable_checkpoint = enable_checkpoint 108 | self.dim_embed = dim_embed 109 | self.num_heads = num_heads 110 | self.num_groups = num_groups 111 | self.depths = depths 112 | self.window_size = window_size 113 | self.projection_dim = projection_dim 114 | self.visual_temporal_embedding = visual_temporal_embedding 115 | self.image_pos_embed = image_pos_embed 116 | self.image_feature_source = image_feature_source 117 | 118 | super().__init__(**kwargs) 119 | 120 | 121 | 122 | class Florence2LanguageConfig(PretrainedConfig): 123 | r""" 124 | This is the configuration class to store the configuration of a [`Florence2LanguagePreTrainedModel`]. It is used to instantiate a BART 125 | model according to the specified arguments, defining the model architecture. Instantiating a configuration with the 126 | defaults will yield a similar configuration to that of the BART 127 | [facebook/bart-large](https://huggingface.co/facebook/bart-large) architecture. 128 | 129 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the 130 | documentation from [`PretrainedConfig`] for more information. 131 | 132 | 133 | Args: 134 | vocab_size (`int`, *optional*, defaults to 51289): 135 | Vocabulary size of the Florence2Language model. Defines the number of different tokens that can be represented by the 136 | `inputs_ids` passed when calling [`Florence2LanguageModel`]. 137 | d_model (`int`, *optional*, defaults to 1024): 138 | Dimensionality of the layers and the pooler layer. 139 | encoder_layers (`int`, *optional*, defaults to 12): 140 | Number of encoder layers. 141 | decoder_layers (`int`, *optional*, defaults to 12): 142 | Number of decoder layers. 143 | encoder_attention_heads (`int`, *optional*, defaults to 16): 144 | Number of attention heads for each attention layer in the Transformer encoder. 145 | decoder_attention_heads (`int`, *optional*, defaults to 16): 146 | Number of attention heads for each attention layer in the Transformer decoder. 147 | decoder_ffn_dim (`int`, *optional*, defaults to 4096): 148 | Dimensionality of the "intermediate" (often named feed-forward) layer in decoder. 
149 | encoder_ffn_dim (`int`, *optional*, defaults to 4096): 150 | Dimensionality of the "intermediate" (often named feed-forward) layer in encoder. 151 | activation_function (`str` or `function`, *optional*, defaults to `"gelu"`): 152 | The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, 153 | `"relu"`, `"silu"` and `"gelu_new"` are supported. 154 | dropout (`float`, *optional*, defaults to 0.1): 155 | The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. 156 | attention_dropout (`float`, *optional*, defaults to 0.0): 157 | The dropout ratio for the attention probabilities. 158 | activation_dropout (`float`, *optional*, defaults to 0.0): 159 | The dropout ratio for activations inside the fully connected layer. 160 | classifier_dropout (`float`, *optional*, defaults to 0.0): 161 | The dropout ratio for the classifier. 162 | max_position_embeddings (`int`, *optional*, defaults to 1024): 163 | The maximum sequence length that this model might ever be used with. Typically set this to something large 164 | just in case (e.g., 512 or 1024 or 2048). 165 | init_std (`float`, *optional*, defaults to 0.02): 166 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 167 | encoder_layerdrop (`float`, *optional*, defaults to 0.0): 168 | The LayerDrop probability for the encoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556) 169 | for more details. 170 | decoder_layerdrop (`float`, *optional*, defaults to 0.0): 171 | The LayerDrop probability for the decoder. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556) 172 | for more details. 173 | scale_embedding (`bool`, *optional*, defaults to `False`): 174 | Scale embeddings by dividing by sqrt(d_model). 175 | use_cache (`bool`, *optional*, defaults to `True`): 176 | Whether or not the model should return the last key/values attentions (not used by all models). 177 | num_labels (`int`, *optional*, defaults to 3): 178 | The number of labels to use in [`Florence2LanguageForSequenceClassification`]. 179 | forced_eos_token_id (`int`, *optional*, defaults to 2): 180 | The id of the token to force as the last generated token when `max_length` is reached. Usually set to 181 | `eos_token_id`.
182 | 183 | Example: 184 | 185 | ```python 186 | >>> from transformers import Florence2LanguageConfig, Florence2LanguageModel 187 | 188 | >>> # Initializing a Florence2 Language style configuration 189 | >>> configuration = Florence2LanguageConfig() 190 | 191 | >>> # Initializing a model (with random weights) 192 | >>> model = Florence2LangaugeModel(configuration) 193 | 194 | >>> # Accessing the model configuration 195 | >>> configuration = model.config 196 | ```""" 197 | 198 | model_type = "florence2_language" 199 | keys_to_ignore_at_inference = ["past_key_values"] 200 | attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"} 201 | 202 | def __init__( 203 | self, 204 | vocab_size=51289, 205 | max_position_embeddings=1024, 206 | encoder_layers=12, 207 | encoder_ffn_dim=4096, 208 | encoder_attention_heads=16, 209 | decoder_layers=12, 210 | decoder_ffn_dim=4096, 211 | decoder_attention_heads=16, 212 | encoder_layerdrop=0.0, 213 | decoder_layerdrop=0.0, 214 | activation_function="gelu", 215 | d_model=1024, 216 | dropout=0.1, 217 | attention_dropout=0.0, 218 | activation_dropout=0.0, 219 | init_std=0.02, 220 | classifier_dropout=0.0, 221 | scale_embedding=False, 222 | use_cache=True, 223 | num_labels=3, 224 | pad_token_id=1, 225 | bos_token_id=0, 226 | eos_token_id=2, 227 | is_encoder_decoder=True, 228 | decoder_start_token_id=2, 229 | forced_eos_token_id=2, 230 | **kwargs, 231 | ): 232 | self.vocab_size = vocab_size 233 | self.max_position_embeddings = max_position_embeddings 234 | self.d_model = d_model 235 | self.encoder_ffn_dim = encoder_ffn_dim 236 | self.encoder_layers = encoder_layers 237 | self.encoder_attention_heads = encoder_attention_heads 238 | self.decoder_ffn_dim = decoder_ffn_dim 239 | self.decoder_layers = decoder_layers 240 | self.decoder_attention_heads = decoder_attention_heads 241 | self.dropout = dropout 242 | self.attention_dropout = attention_dropout 243 | self.activation_dropout = activation_dropout 244 | self.activation_function = activation_function 245 | self.init_std = init_std 246 | self.encoder_layerdrop = encoder_layerdrop 247 | self.decoder_layerdrop = decoder_layerdrop 248 | self.classifier_dropout = classifier_dropout 249 | self.use_cache = use_cache 250 | self.num_hidden_layers = encoder_layers 251 | self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True 252 | 253 | super().__init__( 254 | num_labels=num_labels, 255 | pad_token_id=pad_token_id, 256 | bos_token_id=bos_token_id, 257 | eos_token_id=eos_token_id, 258 | is_encoder_decoder=is_encoder_decoder, 259 | decoder_start_token_id=decoder_start_token_id, 260 | forced_eos_token_id=forced_eos_token_id, 261 | **kwargs, 262 | ) 263 | 264 | # ensure backward compatibility for BART CNN models 265 | if self.forced_bos_token_id is None and kwargs.get("force_bos_token_to_be_generated", False): 266 | self.forced_bos_token_id = self.bos_token_id 267 | warnings.warn( 268 | f"Please make sure the config includes `forced_bos_token_id={self.bos_token_id}` in future versions. " 269 | "The config can simply be saved and uploaded again to be fixed." 270 | ) 271 | 272 | class Florence2Config(PretrainedConfig): 273 | r""" 274 | This is the configuration class to store the configuration of a [`Florence2ForConditionalGeneration`]. It is used to instantiate an 275 | Florence-2 model according to the specified arguments, defining the model architecture. 276 | 277 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. 
Read the 278 | documentation from [`PretrainedConfig`] for more information. 279 | 280 | Args: 281 | vision_config (`Florence2VisionConfig`, *optional*): 282 | Custom vision config or dict 283 | text_config (`Union[AutoConfig, dict]`, *optional*): 284 | The config object of the text backbone. 285 | ignore_index (`int`, *optional*, defaults to -100): 286 | The ignore index for the loss function. 287 | vocab_size (`int`, *optional*, defaults to 51289): 288 | Vocabulary size of the Florence2model. Defines the number of different tokens that can be represented by the 289 | `inputs_ids` passed when calling [`~Florence2ForConditionalGeneration`] 290 | projection_dim (`int`, *optional*, defaults to 1024): 291 | Dimension of the multimodal projection space. 292 | 293 | Example: 294 | 295 | ```python 296 | >>> from transformers import Florence2ForConditionalGeneration, Florence2Config, CLIPVisionConfig, BartConfig 297 | 298 | >>> # Initializing a clip-like vision config 299 | >>> vision_config = CLIPVisionConfig() 300 | 301 | >>> # Initializing a Bart config 302 | >>> text_config = BartConfig() 303 | 304 | >>> # Initializing a Florence-2 configuration 305 | >>> configuration = Florence2Config(vision_config, text_config) 306 | 307 | >>> # Initializing a model from the florence-2 configuration 308 | >>> model = Florence2ForConditionalGeneration(configuration) 309 | 310 | >>> # Accessing the model configuration 311 | >>> configuration = model.config 312 | ```""" 313 | 314 | model_type = "florence2" 315 | is_composition = False 316 | 317 | def __init__( 318 | self, 319 | vision_config=None, 320 | text_config=None, 321 | ignore_index=-100, 322 | vocab_size=51289, 323 | projection_dim=1024, 324 | **kwargs, 325 | ): 326 | self.ignore_index = ignore_index 327 | self.vocab_size = vocab_size 328 | self.projection_dim = projection_dim 329 | if vision_config is not None: 330 | vision_config = PretrainedConfig(**vision_config) 331 | self.vision_config = vision_config 332 | self.vocab_size = self.vocab_size 333 | 334 | self.text_config = text_config 335 | if text_config is not None: 336 | self.text_config = Florence2LanguageConfig(**text_config) 337 | 338 | 339 | super().__init__(**kwargs) 340 | 341 | -------------------------------------------------------------------------------- /nodes.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Callable 2 | import torch 3 | import torchvision.transforms.functional as F 4 | import io 5 | import os 6 | import matplotlib 7 | matplotlib.use('Agg') 8 | import matplotlib.pyplot as plt 9 | import matplotlib.patches as patches 10 | from PIL import Image, ImageDraw, ImageColor, ImageFont 11 | import random 12 | import numpy as np 13 | import re 14 | from pathlib import Path 15 | 16 | #workaround for unnecessary flash_attn requirement 17 | from unittest.mock import patch 18 | from transformers.dynamic_module_utils import get_imports 19 | 20 | import transformers 21 | 22 | from safetensors.torch import save_file 23 | 24 | def fixed_get_imports(filename: str | os.PathLike) -> list[str]: 25 | try: 26 | if not str(filename).endswith("modeling_florence2.py"): 27 | return get_imports(filename) 28 | imports = get_imports(filename) 29 | imports.remove("flash_attn") 30 | except: 31 | print(f"No flash_attn import to remove") 32 | pass 33 | return imports 34 | 35 | 36 | def create_path_dict(paths: list[str], predicate: Callable[[Path], bool] = lambda _: True) -> dict[str, str]: 37 | """ 38 | Creates a flat dictionary of the 
contents of all given paths: ``{name: absolute_path}``. 39 | 40 | Non-recursive. Optionally takes a predicate to filter items. Duplicate names overwrite (the last one wins). 41 | 42 | Args: 43 | paths (list[str]): 44 | The paths to search for items. 45 | predicate (Callable[[Path], bool]): 46 | (Optional) If provided, each path is tested against this filter. 47 | Returns ``True`` to include a path. 48 | 49 | Default: Include everything 50 | """ 51 | 52 | flattened_paths = [item for path in paths if Path(path).exists() for item in Path(path).iterdir() if predicate(item)] 53 | 54 | return {item.name: str(item.absolute()) for item in flattened_paths} 55 | 56 | 57 | import comfy.model_management as mm 58 | from comfy.utils import ProgressBar 59 | import folder_paths 60 | 61 | script_directory = os.path.dirname(os.path.abspath(__file__)) 62 | model_directory = os.path.join(folder_paths.models_dir, "LLM") 63 | os.makedirs(model_directory, exist_ok=True) 64 | 65 | # Ensure ComfyUI knows about the LLM model path 66 | folder_paths.add_model_folder_path("LLM", model_directory) 67 | 68 | from transformers import AutoModelForCausalLM, AutoProcessor, set_seed 69 | 70 | model_list = [ 71 | 'microsoft/Florence-2-base', 72 | 'microsoft/Florence-2-base-ft', 73 | 'microsoft/Florence-2-large', 74 | 'microsoft/Florence-2-large-ft', 75 | 'HuggingFaceM4/Florence-2-DocVQA', 76 | 'thwri/CogFlorence-2.1-Large', 77 | 'thwri/CogFlorence-2.2-Large', 78 | 'gokaygokay/Florence-2-SD3-Captioner', 79 | 'gokaygokay/Florence-2-Flux-Large', 80 | 'MiaoshouAI/Florence-2-base-PromptGen-v1.5', 81 | 'MiaoshouAI/Florence-2-large-PromptGen-v1.5', 82 | 'MiaoshouAI/Florence-2-base-PromptGen-v2.0', 83 | 'MiaoshouAI/Florence-2-large-PromptGen-v2.0', 84 | 'PJMixers-Images/Florence-2-base-Castollux-v0.5' 85 | ] 86 | 87 | class DownloadAndLoadFlorence2Model: 88 | @classmethod 89 | def INPUT_TYPES(s): 90 | return {"required": { 91 | "model": (model_list, {"default": 'microsoft/Florence-2-base'}), 92 | "precision": ([ 'fp16','bf16','fp32'], 93 | { 94 | "default": 'fp16' 95 | }), 96 | "attention": ( 97 | [ 'flash_attention_2', 'sdpa', 'eager'], 98 | { 99 | "default": 'sdpa' 100 | }), 101 | }, 102 | "optional": { 103 | "lora": ("PEFTLORA",), 104 | "convert_to_safetensors": ("BOOLEAN", {"default": False, "tooltip": "Some of the older model weights are not saved in .safetensors format, which seem to cause longer loading times, this option converts the .bin weights to .safetensors"}), 105 | } 106 | } 107 | 108 | RETURN_TYPES = ("FL2MODEL",) 109 | RETURN_NAMES = ("florence2_model",) 110 | FUNCTION = "loadmodel" 111 | CATEGORY = "Florence2" 112 | 113 | def loadmodel(self, model, precision, attention, lora=None, convert_to_safetensors=False): 114 | if model not in model_list: 115 | raise ValueError(f"Model {model} is not in the supported model list.") 116 | device = mm.get_torch_device() 117 | offload_device = mm.unet_offload_device() 118 | dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision] 119 | 120 | model_name = model.rsplit('/', 1)[-1] 121 | model_path = os.path.join(model_directory, model_name) 122 | 123 | if not os.path.exists(model_path): 124 | print(f"Downloading Florence2 model to: {model_path}") 125 | from huggingface_hub import snapshot_download 126 | snapshot_download(repo_id=model, 127 | local_dir=model_path, 128 | local_dir_use_symlinks=False) 129 | 130 | print(f"Florence2 using {attention} for attention") 131 | 132 | if convert_to_safetensors: 133 | model_weight_path = os.path.join(model_path, 
'pytorch_model.bin') 134 | if os.path.exists(model_weight_path): 135 | safetensors_weight_path = os.path.join(model_path, 'model.safetensors') 136 | print(f"Converting {model_weight_path} to {safetensors_weight_path}") 137 | if not os.path.exists(safetensors_weight_path): 138 | sd = torch.load(model_weight_path, map_location=offload_device) 139 | sd_new = {} 140 | for k, v in sd.items(): 141 | sd_new[k] = v.clone() 142 | save_file(sd_new, safetensors_weight_path) 143 | if os.path.exists(safetensors_weight_path): 144 | print(f"Conversion successful. Deleting original file: {model_weight_path}") 145 | os.remove(model_weight_path) 146 | print(f"Original {model_weight_path} file deleted.") 147 | 148 | if transformers.__version__ < '4.51.0': 149 | with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports): #workaround for unnecessary flash_attn requirement 150 | model = AutoModelForCausalLM.from_pretrained(model_path, attn_implementation=attention, torch_dtype=dtype,trust_remote_code=True).to(offload_device) 151 | else: 152 | from .modeling_florence2 import Florence2ForConditionalGeneration 153 | model = Florence2ForConditionalGeneration.from_pretrained(model_path, attn_implementation=attention, torch_dtype=dtype).to(offload_device) 154 | 155 | processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True) 156 | 157 | if lora is not None: 158 | from peft import PeftModel 159 | adapter_name = lora 160 | model = PeftModel.from_pretrained(model, adapter_name, trust_remote_code=True) 161 | 162 | florence2_model = { 163 | 'model': model, 164 | 'processor': processor, 165 | 'dtype': dtype 166 | } 167 | 168 | return (florence2_model,) 169 | 170 | class DownloadAndLoadFlorence2Lora: 171 | @classmethod 172 | def INPUT_TYPES(s): 173 | return {"required": { 174 | "model": ( 175 | [ 176 | 'NikshepShetty/Florence-2-pixelprose', 177 | ], 178 | ), 179 | }, 180 | 181 | } 182 | 183 | RETURN_TYPES = ("PEFTLORA",) 184 | RETURN_NAMES = ("lora",) 185 | FUNCTION = "loadmodel" 186 | CATEGORY = "Florence2" 187 | 188 | def loadmodel(self, model): 189 | if model not in ['NikshepShetty/Florence-2-pixelprose']: 190 | raise ValueError(f"Lora Model {model} is not in the supported lora model list.") 191 | model_name = model.rsplit('/', 1)[-1] 192 | model_path = os.path.join(model_directory, model_name) 193 | 194 | if not os.path.exists(model_path): 195 | print(f"Downloading Florence2 lora model to: {model_path}") 196 | from huggingface_hub import snapshot_download 197 | snapshot_download(repo_id=model, 198 | local_dir=model_path, 199 | local_dir_use_symlinks=False) 200 | return (model_path,) 201 | 202 | class Florence2ModelLoader: 203 | 204 | @classmethod 205 | def INPUT_TYPES(s): 206 | all_llm_paths = folder_paths.get_folder_paths("LLM") 207 | s.model_paths = create_path_dict(all_llm_paths, lambda x: x.is_dir()) 208 | 209 | return {"required": { 210 | "model": ([*s.model_paths], {"tooltip": "models are expected to be in Comfyui/models/LLM folder"}), 211 | "precision": (['fp16','bf16','fp32'],), 212 | "attention": ( 213 | [ 'flash_attention_2', 'sdpa', 'eager'], 214 | { 215 | "default": 'sdpa' 216 | }), 217 | }, 218 | "optional": { 219 | "lora": ("PEFTLORA",), 220 | "convert_to_safetensors": ("BOOLEAN", {"default": False, "tooltip": "Some of the older model weights are not saved in .safetensors format, which seem to cause longer loading times, this option converts the .bin weights to .safetensors"}), 221 | } 222 | } 223 | 224 | RETURN_TYPES = ("FL2MODEL",) 225 | RETURN_NAMES = 
("florence2_model",) 226 | FUNCTION = "loadmodel" 227 | CATEGORY = "Florence2" 228 | 229 | def loadmodel(self, model, precision, attention, lora=None, convert_to_safetensors=False): 230 | device = mm.get_torch_device() 231 | offload_device = mm.unet_offload_device() 232 | dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision] 233 | model_path = Florence2ModelLoader.model_paths.get(model) 234 | print(f"Loading model from {model_path}") 235 | print(f"Florence2 using {attention} for attention") 236 | if convert_to_safetensors: 237 | model_weight_path = os.path.join(model_path, 'pytorch_model.bin') 238 | if os.path.exists(model_weight_path): 239 | safetensors_weight_path = os.path.join(model_path, 'model.safetensors') 240 | print(f"Converting {model_weight_path} to {safetensors_weight_path}") 241 | if not os.path.exists(safetensors_weight_path): 242 | sd = torch.load(model_weight_path, map_location=offload_device) 243 | sd_new = {} 244 | for k, v in sd.items(): 245 | sd_new[k] = v.clone() 246 | save_file(sd_new, safetensors_weight_path) 247 | if os.path.exists(safetensors_weight_path): 248 | print(f"Conversion successful. Deleting original file: {model_weight_path}") 249 | os.remove(model_weight_path) 250 | print(f"Original {model_weight_path} file deleted.") 251 | 252 | if transformers.__version__ < '4.51.0': 253 | with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports): #workaround for unnecessary flash_attn requirement 254 | model = AutoModelForCausalLM.from_pretrained(model_path, attn_implementation=attention, torch_dtype=dtype,trust_remote_code=True).to(offload_device) 255 | else: 256 | from .modeling_florence2 import Florence2ForConditionalGeneration 257 | model = Florence2ForConditionalGeneration.from_pretrained(model_path, attn_implementation=attention, torch_dtype=dtype).to(offload_device) 258 | processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True) 259 | 260 | if lora is not None: 261 | from peft import PeftModel 262 | adapter_name = lora 263 | model = PeftModel.from_pretrained(model, adapter_name, trust_remote_code=True) 264 | 265 | florence2_model = { 266 | 'model': model, 267 | 'processor': processor, 268 | 'dtype': dtype 269 | } 270 | 271 | return (florence2_model,) 272 | 273 | class Florence2Run: 274 | @classmethod 275 | def INPUT_TYPES(s): 276 | return { 277 | "required": { 278 | "image": ("IMAGE", ), 279 | "florence2_model": ("FL2MODEL", ), 280 | "text_input": ("STRING", {"default": "", "multiline": True}), 281 | "task": ( 282 | [ 283 | 'region_caption', 284 | 'dense_region_caption', 285 | 'region_proposal', 286 | 'caption', 287 | 'detailed_caption', 288 | 'more_detailed_caption', 289 | 'caption_to_phrase_grounding', 290 | 'referring_expression_segmentation', 291 | 'ocr', 292 | 'ocr_with_region', 293 | 'docvqa', 294 | 'prompt_gen_tags', 295 | 'prompt_gen_mixed_caption', 296 | 'prompt_gen_analyze', 297 | 'prompt_gen_mixed_caption_plus', 298 | ], 299 | ), 300 | "fill_mask": ("BOOLEAN", {"default": True}), 301 | }, 302 | "optional": { 303 | "keep_model_loaded": ("BOOLEAN", {"default": False}), 304 | "max_new_tokens": ("INT", {"default": 1024, "min": 1, "max": 4096}), 305 | "num_beams": ("INT", {"default": 3, "min": 1, "max": 64}), 306 | "do_sample": ("BOOLEAN", {"default": True}), 307 | "output_mask_select": ("STRING", {"default": ""}), 308 | "seed": ("INT", {"default": 1, "min": 1, "max": 0xffffffffffffffff}), 309 | } 310 | } 311 | 312 | RETURN_TYPES = ("IMAGE", "MASK", "STRING", "JSON") 313 | 
RETURN_NAMES =("image", "mask", "caption", "data") 314 | FUNCTION = "encode" 315 | CATEGORY = "Florence2" 316 | 317 | def hash_seed(self, seed): 318 | import hashlib 319 | # Convert the seed to a string and then to bytes 320 | seed_bytes = str(seed).encode('utf-8') 321 | # Create a SHA-256 hash of the seed bytes 322 | hash_object = hashlib.sha256(seed_bytes) 323 | # Convert the hash to an integer 324 | hashed_seed = int(hash_object.hexdigest(), 16) 325 | # Ensure the hashed seed is within the acceptable range for set_seed 326 | return hashed_seed % (2**32) 327 | 328 | def encode(self, image, text_input, florence2_model, task, fill_mask, keep_model_loaded=False, 329 | num_beams=3, max_new_tokens=1024, do_sample=True, output_mask_select="", seed=None): 330 | device = mm.get_torch_device() 331 | _, height, width, _ = image.shape 332 | offload_device = mm.unet_offload_device() 333 | annotated_image_tensor = None 334 | mask_tensor = None 335 | processor = florence2_model['processor'] 336 | model = florence2_model['model'] 337 | dtype = florence2_model['dtype'] 338 | model.to(device) 339 | 340 | if seed: 341 | set_seed(self.hash_seed(seed)) 342 | 343 | colormap = ['blue','orange','green','purple','brown','pink','olive','cyan','red', 344 | 'lime','indigo','violet','aqua','magenta','gold','tan','skyblue'] 345 | 346 | prompts = { 347 | 'region_caption': '', 348 | 'dense_region_caption': '', 349 | 'region_proposal': '', 350 | 'caption': '', 351 | 'detailed_caption': '', 352 | 'more_detailed_caption': '', 353 | 'caption_to_phrase_grounding': '', 354 | 'referring_expression_segmentation': '', 355 | 'ocr': '', 356 | 'ocr_with_region': '', 357 | 'docvqa': '', 358 | 'prompt_gen_tags': '', 359 | 'prompt_gen_mixed_caption': '', 360 | 'prompt_gen_analyze': '', 361 | 'prompt_gen_mixed_caption_plus': '', 362 | } 363 | task_prompt = prompts.get(task, '') 364 | 365 | if (task not in ['referring_expression_segmentation', 'caption_to_phrase_grounding', 'docvqa']) and text_input: 366 | raise ValueError("Text input (prompt) is only supported for 'referring_expression_segmentation', 'caption_to_phrase_grounding', and 'docvqa'") 367 | 368 | if text_input != "": 369 | prompt = task_prompt + " " + text_input 370 | else: 371 | prompt = task_prompt 372 | 373 | image = image.permute(0, 3, 1, 2) 374 | 375 | out = [] 376 | out_masks = [] 377 | out_results = [] 378 | out_data = [] 379 | pbar = ProgressBar(len(image)) 380 | for img in image: 381 | image_pil = F.to_pil_image(img) 382 | inputs = processor(text=prompt, images=image_pil, return_tensors="pt", do_rescale=False).to(dtype).to(device) 383 | 384 | generated_ids = model.generate( 385 | input_ids=inputs["input_ids"], 386 | pixel_values=inputs["pixel_values"], 387 | max_new_tokens=max_new_tokens, 388 | do_sample=do_sample, 389 | num_beams=num_beams, 390 | use_cache=False, 391 | ) 392 | 393 | results = processor.batch_decode(generated_ids, skip_special_tokens=False)[0] 394 | print(results) 395 | # cleanup the special tokens from the final list 396 | if task == 'ocr_with_region': 397 | clean_results = str(results) 398 | cleaned_string = re.sub(r'|<[^>]*>', '\n', clean_results) 399 | clean_results = re.sub(r'\n+', '\n', cleaned_string) 400 | else: 401 | clean_results = str(results) 402 | clean_results = clean_results.replace('', '') 403 | clean_results = clean_results.replace('', '') 404 | 405 | #return single string if only one image for compatibility with nodes that can't handle string lists 406 | if len(image) == 1: 407 | out_results = clean_results 408 | else: 409 | 
out_results.append(clean_results) 410 | 411 | W, H = image_pil.size 412 | 413 | parsed_answer = processor.post_process_generation(results, task=task_prompt, image_size=(W, H)) 414 | 415 | if task == 'region_caption' or task == 'dense_region_caption' or task == 'caption_to_phrase_grounding' or task == 'region_proposal': 416 | fig, ax = plt.subplots(figsize=(W / 100, H / 100), dpi=100) 417 | fig.subplots_adjust(left=0, right=1, top=1, bottom=0) 418 | ax.imshow(image_pil) 419 | bboxes = parsed_answer[task_prompt]['bboxes'] 420 | labels = parsed_answer[task_prompt]['labels'] 421 | 422 | mask_indexes = [] 423 | # Determine mask indexes outside the loop 424 | if output_mask_select != "": 425 | mask_indexes = [n for n in output_mask_select.split(",")] 426 | print(mask_indexes) 427 | else: 428 | mask_indexes = [str(i) for i in range(len(bboxes))] 429 | 430 | # Initialize mask_layer only if needed 431 | if fill_mask: 432 | mask_layer = Image.new('RGB', image_pil.size, (0, 0, 0)) 433 | mask_draw = ImageDraw.Draw(mask_layer) 434 | 435 | for index, (bbox, label) in enumerate(zip(bboxes, labels)): 436 | # Modify the label to include the index 437 | indexed_label = f"{index}.{label}" 438 | 439 | if fill_mask: 440 | # Ensure y1 is greater than or equal to y0 for mask drawing 441 | x0, y0, x1, y1 = bbox[0], bbox[1], bbox[2], bbox[3] 442 | if y1 < y0: 443 | y0, y1 = y1, y0 444 | if x1 < x0: 445 | x0, x1 = x1, x0 446 | 447 | if str(index) in mask_indexes: 448 | print("match index:", str(index), "in mask_indexes:", mask_indexes) 449 | mask_draw.rectangle([x0, y0, x1, y1], fill=(255, 255, 255)) 450 | if label in mask_indexes: 451 | print("match label") 452 | mask_draw.rectangle([x0, y0, x1, y1], fill=(255, 255, 255)) 453 | 454 | # Create a Rectangle patch 455 | # Ensure y1 is greater than or equal to y0 456 | y0, y1 = bbox[1], bbox[3] 457 | if y1 < y0: 458 | y0, y1 = y1, y0 459 | 460 | rect = patches.Rectangle( 461 | (bbox[0], y0), # (x,y) - lower left corner 462 | bbox[2] - bbox[0], # Width 463 | y1 - y0, # Height 464 | linewidth=1, 465 | edgecolor='r', 466 | facecolor='none', 467 | label=indexed_label 468 | ) 469 | # Calculate text width with a rough estimation 470 | text_width = len(label) * 6 # Adjust multiplier based on your font size 471 | text_height = 12 # Adjust based on your font size 472 | 473 | # Get corrected coordinates 474 | x0, y0, x1, y1 = bbox[0], bbox[1], bbox[2], bbox[3] 475 | if y1 < y0: 476 | y0, y1 = y1, y0 477 | if x1 < x0: 478 | x0, x1 = x1, x0 479 | 480 | # Initial text position 481 | text_x = x0 482 | text_y = y0 - text_height # Position text above the top-left of the bbox 483 | 484 | # Adjust text_x if text is going off the left or right edge 485 | if text_x < 0: 486 | text_x = 0 487 | elif text_x + text_width > W: 488 | text_x = W - text_width 489 | 490 | # Adjust text_y if text is going off the top edge 491 | if text_y < 0: 492 | text_y = y1 # Move text below the bottom-left of the bbox if it doesn't overlap with bbox 493 | 494 | # Add the rectangle to the plot 495 | ax.add_patch(rect) 496 | facecolor = random.choice(colormap) if len(image) == 1 else 'red' 497 | # Add the label 498 | plt.text( 499 | text_x, 500 | text_y, 501 | indexed_label, 502 | color='white', 503 | fontsize=12, 504 | bbox=dict(facecolor=facecolor, alpha=0.5) 505 | ) 506 | if fill_mask: 507 | mask_tensor = F.to_tensor(mask_layer) 508 | mask_tensor = mask_tensor.unsqueeze(0).permute(0, 2, 3, 1).cpu().float() 509 | mask_tensor = mask_tensor.mean(dim=0, keepdim=True) 510 | mask_tensor = mask_tensor.repeat(1, 1, 1, 
3) 511 | mask_tensor = mask_tensor[:, :, :, 0] 512 | out_masks.append(mask_tensor) 513 | 514 | # Remove axis and padding around the image 515 | ax.axis('off') 516 | ax.margins(0,0) 517 | ax.get_xaxis().set_major_locator(plt.NullLocator()) 518 | ax.get_yaxis().set_major_locator(plt.NullLocator()) 519 | fig.canvas.draw() 520 | buf = io.BytesIO() 521 | plt.savefig(buf, format='png', pad_inches=0) 522 | buf.seek(0) 523 | annotated_image_pil = Image.open(buf) 524 | 525 | annotated_image_tensor = F.to_tensor(annotated_image_pil) 526 | out_tensor = annotated_image_tensor[:3, :, :].unsqueeze(0).permute(0, 2, 3, 1).cpu().float() 527 | out.append(out_tensor) 528 | 529 | if task == 'caption_to_phrase_grounding': 530 | out_data.append(parsed_answer[task_prompt]) 531 | else: 532 | out_data.append(bboxes) 533 | 534 | 535 | pbar.update(1) 536 | 537 | plt.close(fig) 538 | 539 | elif task == 'referring_expression_segmentation': 540 | # Create a new black image 541 | mask_image = Image.new('RGB', (W, H), 'black') 542 | mask_draw = ImageDraw.Draw(mask_image) 543 | 544 | predictions = parsed_answer[task_prompt] 545 | 546 | # Iterate over polygons and labels 547 | for polygons, label in zip(predictions['polygons'], predictions['labels']): 548 | color = random.choice(colormap) 549 | for _polygon in polygons: 550 | _polygon = np.array(_polygon).reshape(-1, 2) 551 | # Clamp polygon points to image boundaries 552 | _polygon = np.clip(_polygon, [0, 0], [W - 1, H - 1]) 553 | if len(_polygon) < 3: 554 | print('Invalid polygon:', _polygon) 555 | continue 556 | 557 | _polygon = _polygon.reshape(-1).tolist() 558 | 559 | # Draw the polygon 560 | if fill_mask: 561 | overlay = Image.new('RGBA', image_pil.size, (255, 255, 255, 0)) 562 | image_pil = image_pil.convert('RGBA') 563 | draw = ImageDraw.Draw(overlay) 564 | color_with_opacity = ImageColor.getrgb(color) + (180,) 565 | draw.polygon(_polygon, outline=color, fill=color_with_opacity, width=3) 566 | image_pil = Image.alpha_composite(image_pil, overlay) 567 | else: 568 | draw = ImageDraw.Draw(image_pil) 569 | draw.polygon(_polygon, outline=color, width=3) 570 | 571 | #draw mask 572 | mask_draw.polygon(_polygon, outline="white", fill="white") 573 | 574 | image_tensor = F.to_tensor(image_pil) 575 | image_tensor = image_tensor[:3, :, :].unsqueeze(0).permute(0, 2, 3, 1).cpu().float() 576 | out.append(image_tensor) 577 | 578 | mask_tensor = F.to_tensor(mask_image) 579 | mask_tensor = mask_tensor.unsqueeze(0).permute(0, 2, 3, 1).cpu().float() 580 | mask_tensor = mask_tensor.mean(dim=0, keepdim=True) 581 | mask_tensor = mask_tensor.repeat(1, 1, 1, 3) 582 | mask_tensor = mask_tensor[:, :, :, 0] 583 | out_masks.append(mask_tensor) 584 | pbar.update(1) 585 | 586 | elif task == 'ocr_with_region': 587 | try: 588 | font = ImageFont.load_default().font_variant(size=24) 589 | except: 590 | font = ImageFont.load_default() 591 | predictions = parsed_answer[task_prompt] 592 | scale = 1 593 | image_pil = image_pil.convert('RGBA') 594 | overlay = Image.new('RGBA', image_pil.size, (255, 255, 255, 0)) 595 | draw = ImageDraw.Draw(overlay) 596 | bboxes, labels = predictions['quad_boxes'], predictions['labels'] 597 | 598 | # Create a new black image for the mask 599 | mask_image = Image.new('RGB', (W, H), 'black') 600 | mask_draw = ImageDraw.Draw(mask_image) 601 | 602 | for box, label in zip(bboxes, labels): 603 | scaled_box = [v / (width if idx % 2 == 0 else height) for idx, v in enumerate(box)] 604 | out_data.append({"label": label, "box": scaled_box}) 605 | 606 | color = 
random.choice(colormap) 607 | new_box = (np.array(box) * scale).tolist() 608 | 609 | # Ensure polygon coordinates are valid 610 | # For polygons, we need to make sure the points form a valid shape 611 | # This is a simple check to ensure the polygon has at least 3 points 612 | if len(new_box) >= 6: # At least 3 points (x,y pairs) 613 | if fill_mask: 614 | color_with_opacity = ImageColor.getrgb(color) + (180,) 615 | draw.polygon(new_box, outline=color, fill=color_with_opacity, width=3) 616 | else: 617 | draw.polygon(new_box, outline=color, width=3) 618 | 619 | # Get the first point for text positioning 620 | text_x, text_y = new_box[0]+8, new_box[1]+2 621 | 622 | draw.text((text_x, text_y), 623 | "{}".format(label), 624 | align="right", 625 | font=font, 626 | fill=color) 627 | 628 | # Draw the mask 629 | mask_draw.polygon(new_box, outline="white", fill="white") 630 | 631 | image_pil = Image.alpha_composite(image_pil, overlay) 632 | image_pil = image_pil.convert('RGB') 633 | 634 | image_tensor = F.to_tensor(image_pil) 635 | image_tensor = image_tensor[:3, :, :].unsqueeze(0).permute(0, 2, 3, 1).cpu().float() 636 | out.append(image_tensor) 637 | 638 | # Process the mask 639 | mask_tensor = F.to_tensor(mask_image) 640 | mask_tensor = mask_tensor.unsqueeze(0).permute(0, 2, 3, 1).cpu().float() 641 | mask_tensor = mask_tensor.mean(dim=0, keepdim=True) 642 | mask_tensor = mask_tensor.repeat(1, 1, 1, 3) 643 | mask_tensor = mask_tensor[:, :, :, 0] 644 | out_masks.append(mask_tensor) 645 | 646 | pbar.update(1) 647 | 648 | elif task == 'docvqa': 649 | if text_input == "": 650 | raise ValueError("Text input (prompt) is required for 'docvqa'") 651 | prompt = "<DocVQA> " + text_input 652 | 653 | inputs = processor(text=prompt, images=image_pil, return_tensors="pt", do_rescale=False).to(dtype).to(device) 654 | generated_ids = model.generate( 655 | input_ids=inputs["input_ids"], 656 | pixel_values=inputs["pixel_values"], 657 | max_new_tokens=max_new_tokens, 658 | do_sample=do_sample, 659 | num_beams=num_beams, 660 | use_cache=False, 661 | ) 662 | 663 | results = processor.batch_decode(generated_ids, skip_special_tokens=False)[0] 664 | clean_results = results.replace('</s>', '').replace('<s>', '') 665 | 666 | if len(image) == 1: 667 | out_results = clean_results 668 | else: 669 | out_results.append(clean_results) 670 | 671 | out.append(F.to_tensor(image_pil).unsqueeze(0).permute(0, 2, 3, 1).cpu().float()) 672 | 673 | pbar.update(1) 674 | 675 | if len(out) > 0: 676 | out_tensor = torch.cat(out, dim=0) 677 | else: 678 | out_tensor = torch.zeros((1, 64,64, 3), dtype=torch.float32, device="cpu") 679 | if len(out_masks) > 0: 680 | out_mask_tensor = torch.cat(out_masks, dim=0) 681 | else: 682 | out_mask_tensor = torch.zeros((1,64,64), dtype=torch.float32, device="cpu") 683 | 684 | if not keep_model_loaded: 685 | print("Offloading model...") 686 | model.to(offload_device) 687 | mm.soft_empty_cache() 688 | 689 | return (out_tensor, out_mask_tensor, out_results, out_data) 690 | 691 | NODE_CLASS_MAPPINGS = { 692 | "DownloadAndLoadFlorence2Model": DownloadAndLoadFlorence2Model, 693 | "DownloadAndLoadFlorence2Lora": DownloadAndLoadFlorence2Lora, 694 | "Florence2ModelLoader": Florence2ModelLoader, 695 | "Florence2Run": Florence2Run, 696 | } 697 | NODE_DISPLAY_NAME_MAPPINGS = { 698 | "DownloadAndLoadFlorence2Model": "DownloadAndLoadFlorence2Model", 699 | "DownloadAndLoadFlorence2Lora": "DownloadAndLoadFlorence2Lora", 700 | "Florence2ModelLoader": "Florence2ModelLoader", 701 | "Florence2Run": "Florence2Run", 702 | } 703 |
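For reference, a minimal standalone sketch of the Florence-2 flow that `Florence2Run.encode` wraps: build the task-token prompt, call `generate`, then parse the output with `post_process_generation`. This is an illustrative sketch, not part of the node pack; the checkpoint name, image path, and chosen task token are assumptions, and plain `transformers` loading with `trust_remote_code=True` is used instead of the ComfyUI model management above.

```python
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor

# Assumptions: any Florence-2 checkpoint from the README list and any local RGB image.
model_id = "microsoft/Florence-2-base"
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=dtype, trust_remote_code=True
).to(device)
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)

# Same pattern as the node: task token plus optional text input for grounding-style tasks.
task_prompt = "<CAPTION_TO_PHRASE_GROUNDING>"
prompt = task_prompt + " " + "a dog"

image = Image.open("example.jpg").convert("RGB")
inputs = processor(text=prompt, images=image, return_tensors="pt").to(device, dtype)

generated_ids = model.generate(
    input_ids=inputs["input_ids"],
    pixel_values=inputs["pixel_values"],
    max_new_tokens=1024,
    num_beams=3,
)
raw = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]

# post_process_generation turns the <loc_...> tokens into pixel-space bboxes/labels keyed by the task token.
parsed = processor.post_process_generation(raw, task=task_prompt, image_size=(image.width, image.height))
print(parsed[task_prompt])
```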
-------------------------------------------------------------------------------- /modeling_florence2.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """ PyTorch Florence-2 model.""" 17 | from dataclasses import dataclass 18 | from typing import List, Optional, Tuple, Union 19 | 20 | import math 21 | import torch 22 | import torch.utils.checkpoint 23 | from torch import nn 24 | import torch.nn.functional as F 25 | import torch.utils.checkpoint as checkpoint 26 | from torch.nn import CrossEntropyLoss 27 | from collections import OrderedDict 28 | from einops import rearrange 29 | try: 30 | from timm.models.layers import DropPath, trunc_normal_ 31 | except: 32 | from timm.layers import DropPath, trunc_normal_ 33 | 34 | from transformers.modeling_utils import PreTrainedModel 35 | from transformers.generation.utils import GenerationMixin 36 | from transformers.utils import ( 37 | ModelOutput, 38 | add_start_docstrings, 39 | add_start_docstrings_to_model_forward, 40 | is_flash_attn_2_available, 41 | logging, 42 | replace_return_docstrings, 43 | is_flash_attn_2_available, 44 | is_flash_attn_greater_or_equal_2_10, 45 | ) 46 | from .configuration_florence2 import Florence2Config 47 | from .configuration_florence2 import Florence2LanguageConfig 48 | from .configuration_florence2 import Florence2VisionConfig 49 | 50 | 51 | from transformers.activations import ACT2FN 52 | from transformers.modeling_attn_mask_utils import ( 53 | _prepare_4d_attention_mask, 54 | _prepare_4d_attention_mask_for_sdpa, 55 | _prepare_4d_causal_attention_mask, 56 | _prepare_4d_causal_attention_mask_for_sdpa, 57 | ) 58 | from transformers.modeling_outputs import ( 59 | BaseModelOutput, 60 | BaseModelOutputWithPastAndCrossAttentions, 61 | Seq2SeqLMOutput, 62 | Seq2SeqModelOutput, 63 | ) 64 | 65 | 66 | if is_flash_attn_2_available(): 67 | from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa 68 | 69 | logger = logging.get_logger(__name__) 70 | 71 | _CONFIG_FOR_DOC = "Florence2Config" 72 | 73 | class LearnedAbsolutePositionEmbedding2D(nn.Module): 74 | """ 75 | This module learns positional embeddings up to a fixed maximum size. 
76 | """ 77 | 78 | def __init__(self, embedding_dim=256, num_pos=50): 79 | super().__init__() 80 | self.row_embeddings = nn.Embedding(num_pos, embedding_dim // 2) 81 | self.column_embeddings = nn.Embedding(num_pos, embedding_dim - (embedding_dim // 2)) 82 | 83 | def forward(self, pixel_values): 84 | """ 85 | pixel_values: (batch_size, height, width, num_channels) 86 | returns: (batch_size, height, width, embedding_dim * 2) 87 | """ 88 | if len(pixel_values.shape) != 4: 89 | raise ValueError('pixel_values must be a 4D tensor') 90 | height, width = pixel_values.shape[1:3] 91 | width_values = torch.arange(width, device=pixel_values.device) 92 | height_values = torch.arange(height, device=pixel_values.device) 93 | x_emb = self.column_embeddings(width_values) 94 | y_emb = self.row_embeddings(height_values) 95 | # (height, width, embedding_dim * 2) 96 | pos = torch.cat([x_emb.unsqueeze(0).repeat(height, 1, 1), y_emb.unsqueeze(1).repeat(1, width, 1)], dim=-1) 97 | # (embedding_dim * 2, height, width) 98 | pos = pos.permute(2, 0, 1) 99 | pos = pos.unsqueeze(0) 100 | # (batch_size, embedding_dim * 2, height, width) 101 | pos = pos.repeat(pixel_values.shape[0], 1, 1, 1) 102 | # (batch_size, height, width, embedding_dim * 2) 103 | pos = pos.permute(0, 2, 3, 1) 104 | return pos 105 | 106 | class PositionalEmbeddingCosine1D(nn.Module): 107 | """ 108 | This class implements a very simple positional encoding. It follows closely 109 | the encoder from the link below: 110 | https://pytorch.org/tutorials/beginner/translation_transformer.html 111 | 112 | Args: 113 | embed_dim: The dimension of the embeddings. 114 | dropout_prob: The dropout probability. 115 | max_seq_len: The maximum length to precompute the positional encodings. 116 | """ 117 | def __init__( 118 | self, 119 | embed_dim: int = 512, 120 | max_seq_len: int = 1024) -> None: 121 | super(PositionalEmbeddingCosine1D, self).__init__() 122 | self.embed_dim = embed_dim 123 | self.max_seq_len = max_seq_len 124 | # Generate the sinusoidal arrays. 125 | factor = math.log(10000) 126 | denominator = torch.exp( 127 | -factor * torch.arange(0, self.embed_dim, 2) / self.embed_dim) 128 | # Matrix where rows correspond to a positional embedding as a function 129 | # of the position index (i.e., the row index). 130 | frequencies = \ 131 | torch.arange(0, self.max_seq_len) \ 132 | .reshape(self.max_seq_len, 1) * denominator 133 | pos_idx_to_embed = torch.zeros((self.max_seq_len, self.embed_dim)) 134 | # Populate uneven entries. 135 | pos_idx_to_embed[:, 0::2] = torch.sin(frequencies) 136 | pos_idx_to_embed[:, 1::2] = torch.cos(frequencies) 137 | # Save the positional embeddings in a constant buffer. 138 | self.register_buffer("pos_idx_to_embed", pos_idx_to_embed) 139 | 140 | def forward(self, seq_embeds: torch.Tensor) -> torch.Tensor: 141 | """ 142 | Args: 143 | seq_embeds: The sequence embeddings in order. Allowed size: 144 | 1. [T, D], where T is the length of the sequence, and D is the 145 | frame embedding dimension. 146 | 2. [B, T, D], where B is the batch size and T and D are the 147 | same as above. 148 | 149 | Returns a tensor of with the same dimensions as the input: i.e., 150 | [1, T, D] or [T, D]. 151 | """ 152 | shape_len = len(seq_embeds.shape) 153 | assert 2 <= shape_len <= 3 154 | len_seq = seq_embeds.size(-2) 155 | assert len_seq <= self.max_seq_len 156 | pos_embeds = self.pos_idx_to_embed[0:seq_embeds.size(-2), :] 157 | # Adapt pre-computed positional embeddings to the input. 
158 | if shape_len == 3: 159 | pos_embeds = pos_embeds.view( 160 | (1, pos_embeds.size(0), pos_embeds.size(1))) 161 | return pos_embeds 162 | 163 | 164 | class LearnedAbsolutePositionEmbedding1D(nn.Module): 165 | """ 166 | Learnable absolute positional embeddings for 1D sequences. 167 | 168 | Args: 169 | embed_dim: The dimension of the embeddings. 170 | max_seq_len: The maximum length to precompute the positional encodings. 171 | """ 172 | def __init__( 173 | self, 174 | embedding_dim: int = 512, 175 | num_pos: int = 1024) -> None: 176 | super(LearnedAbsolutePositionEmbedding1D, self).__init__() 177 | self.embeddings = nn.Embedding(num_pos, embedding_dim) 178 | self.num_pos = num_pos 179 | 180 | def forward(self, seq_embeds: torch.Tensor) -> torch.Tensor: 181 | """ 182 | Args: 183 | seq_embeds: The sequence embeddings in order. Allowed size: 184 | 1. [T, D], where T is the length of the sequence, and D is the 185 | frame embedding dimension. 186 | 2. [B, T, D], where B is the batch size and T and D are the 187 | same as above. 188 | 189 | Returns a tensor of with the same dimensions as the input: i.e., 190 | [1, T, D] or [T, D]. 191 | """ 192 | shape_len = len(seq_embeds.shape) 193 | assert 2 <= shape_len <= 3 194 | len_seq = seq_embeds.size(-2) 195 | assert len_seq <= self.num_pos 196 | # [T, D] 197 | pos_embeds = self.embeddings(torch.arange(len_seq).to(seq_embeds.device)) 198 | # Adapt pre-computed positional embeddings to the input. 199 | if shape_len == 3: 200 | pos_embeds = pos_embeds.view( 201 | (1, pos_embeds.size(0), pos_embeds.size(1))) 202 | return pos_embeds 203 | 204 | 205 | 206 | class MySequential(nn.Sequential): 207 | def forward(self, *inputs): 208 | for module in self._modules.values(): 209 | if type(inputs) == tuple: 210 | inputs = module(*inputs) 211 | else: 212 | inputs = module(inputs) 213 | return inputs 214 | 215 | 216 | class PreNorm(nn.Module): 217 | def __init__(self, norm, fn, drop_path=None): 218 | super().__init__() 219 | self.norm = norm 220 | self.fn = fn 221 | self.drop_path = drop_path 222 | 223 | def forward(self, x, *args, **kwargs): 224 | shortcut = x 225 | if self.norm != None: 226 | x, size = self.fn(self.norm(x), *args, **kwargs) 227 | else: 228 | x, size = self.fn(x, *args, **kwargs) 229 | 230 | if self.drop_path: 231 | x = self.drop_path(x) 232 | 233 | x = shortcut + x 234 | 235 | return x, size 236 | 237 | 238 | class Mlp(nn.Module): 239 | def __init__( 240 | self, 241 | in_features, 242 | hidden_features=None, 243 | out_features=None, 244 | act_layer=nn.GELU, 245 | ): 246 | super().__init__() 247 | out_features = out_features or in_features 248 | hidden_features = hidden_features or in_features 249 | self.net = nn.Sequential(OrderedDict([ 250 | ("fc1", nn.Linear(in_features, hidden_features)), 251 | ("act", act_layer()), 252 | ("fc2", nn.Linear(hidden_features, out_features)) 253 | ])) 254 | 255 | def forward(self, x, size): 256 | return self.net(x), size 257 | 258 | 259 | class DepthWiseConv2d(nn.Module): 260 | def __init__( 261 | self, 262 | dim_in, 263 | kernel_size, 264 | padding, 265 | stride, 266 | bias=True, 267 | ): 268 | super().__init__() 269 | self.dw = nn.Conv2d( 270 | dim_in, dim_in, 271 | kernel_size=kernel_size, 272 | padding=padding, 273 | groups=dim_in, 274 | stride=stride, 275 | bias=bias 276 | ) 277 | 278 | def forward(self, x, size): 279 | B, N, C = x.shape 280 | H, W = size 281 | assert N == H * W 282 | 283 | x = self.dw(x.transpose(1, 2).view(B, C, H, W)) 284 | size = (x.size(-2), x.size(-1)) 285 | x = 
x.flatten(2).transpose(1, 2) 286 | return x, size 287 | 288 | 289 | class ConvEmbed(nn.Module): 290 | """ Image to Patch Embedding 291 | """ 292 | 293 | def __init__( 294 | self, 295 | patch_size=7, 296 | in_chans=3, 297 | embed_dim=64, 298 | stride=4, 299 | padding=2, 300 | norm_layer=None, 301 | pre_norm=True 302 | ): 303 | super().__init__() 304 | self.patch_size = patch_size 305 | 306 | self.proj = nn.Conv2d( 307 | in_chans, embed_dim, 308 | kernel_size=patch_size, 309 | stride=stride, 310 | padding=padding 311 | ) 312 | 313 | dim_norm = in_chans if pre_norm else embed_dim 314 | self.norm = norm_layer(dim_norm) if norm_layer else None 315 | 316 | self.pre_norm = pre_norm 317 | 318 | def forward(self, x, size): 319 | H, W = size 320 | if len(x.size()) == 3: 321 | if self.norm and self.pre_norm: 322 | x = self.norm(x) 323 | x = rearrange( 324 | x, 'b (h w) c -> b c h w', 325 | h=H, w=W 326 | ) 327 | 328 | x = self.proj(x) 329 | 330 | _, _, H, W = x.shape 331 | x = rearrange(x, 'b c h w -> b (h w) c') 332 | if self.norm and not self.pre_norm: 333 | x = self.norm(x) 334 | 335 | return x, (H, W) 336 | 337 | 338 | class ChannelAttention(nn.Module): 339 | 340 | def __init__(self, dim, groups=8, qkv_bias=True): 341 | super().__init__() 342 | 343 | self.groups = groups 344 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 345 | self.proj = nn.Linear(dim, dim) 346 | 347 | def forward(self, x, size): 348 | B, N, C = x.shape 349 | 350 | qkv = self.qkv(x).reshape(B, N, 3, self.groups, C // self.groups).permute(2, 0, 3, 1, 4) 351 | q, k, v = qkv[0], qkv[1], qkv[2] 352 | 353 | q = q * (float(N) ** -0.5) 354 | attention = q.transpose(-1, -2) @ k 355 | attention = attention.softmax(dim=-1) 356 | x = (attention @ v.transpose(-1, -2)).transpose(-1, -2) 357 | x = x.transpose(1, 2).reshape(B, N, C) 358 | x = self.proj(x) 359 | return x, size 360 | 361 | 362 | class ChannelBlock(nn.Module): 363 | 364 | def __init__(self, dim, groups, mlp_ratio=4., qkv_bias=True, 365 | drop_path_rate=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, 366 | conv_at_attn=True, conv_at_ffn=True): 367 | super().__init__() 368 | 369 | drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. 
else nn.Identity() 370 | 371 | self.conv1 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1)) if conv_at_attn else None 372 | self.channel_attn = PreNorm( 373 | norm_layer(dim), 374 | ChannelAttention(dim, groups=groups, qkv_bias=qkv_bias), 375 | drop_path 376 | ) 377 | self.conv2 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1)) if conv_at_ffn else None 378 | self.ffn = PreNorm( 379 | norm_layer(dim), 380 | Mlp(in_features=dim, hidden_features=int(dim*mlp_ratio), act_layer=act_layer), 381 | drop_path 382 | ) 383 | 384 | def forward(self, x, size): 385 | if self.conv1: 386 | x, size = self.conv1(x, size) 387 | x, size = self.channel_attn(x, size) 388 | 389 | if self.conv2: 390 | x, size = self.conv2(x, size) 391 | x, size = self.ffn(x, size) 392 | 393 | return x, size 394 | 395 | 396 | def window_partition(x, window_size: int): 397 | B, H, W, C = x.shape 398 | x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) 399 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 400 | return windows 401 | 402 | 403 | def window_reverse(windows, batch_size: int, window_size: int, H: int, W: int): 404 | B = batch_size 405 | # this will cause onnx conversion failed for dynamic axis, because treated as constant 406 | # int(windows.shape[0] / (H * W / window_size / window_size)) 407 | x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) 408 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) 409 | return x 410 | 411 | 412 | class WindowAttention(nn.Module): 413 | def __init__(self, dim, num_heads, window_size, qkv_bias=True): 414 | 415 | super().__init__() 416 | self.dim = dim 417 | self.window_size = window_size 418 | self.num_heads = num_heads 419 | head_dim = dim // num_heads 420 | self.scale = float(head_dim) ** -0.5 421 | 422 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 423 | self.proj = nn.Linear(dim, dim) 424 | 425 | self.softmax = nn.Softmax(dim=-1) 426 | 427 | def forward(self, x, size): 428 | 429 | H, W = size 430 | B, L, C = x.shape 431 | assert L == H * W, "input feature has wrong size" 432 | 433 | x = x.view(B, H, W, C) 434 | 435 | pad_l = pad_t = 0 436 | pad_r = (self.window_size - W % self.window_size) % self.window_size 437 | pad_b = (self.window_size - H % self.window_size) % self.window_size 438 | x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) 439 | _, Hp, Wp, _ = x.shape 440 | 441 | x = window_partition(x, self.window_size) 442 | x = x.view(-1, self.window_size * self.window_size, C) 443 | 444 | # W-MSA/SW-MSA 445 | # attn_windows = self.attn(x_windows) 446 | 447 | B_, N, C = x.shape 448 | qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 449 | q, k, v = qkv[0], qkv[1], qkv[2] 450 | 451 | q = q * self.scale 452 | attn = (q @ k.transpose(-2, -1)) 453 | attn = self.softmax(attn) 454 | 455 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C) 456 | x = self.proj(x) 457 | 458 | # merge windows 459 | x = x.view( 460 | -1, self.window_size, self.window_size, C 461 | ) 462 | x = window_reverse(x, B, self.window_size, Hp, Wp) 463 | 464 | if pad_r > 0 or pad_b > 0: 465 | x = x[:, :H, :W, :].contiguous() 466 | 467 | x = x.view(B, H * W, C) 468 | 469 | return x, size 470 | 471 | 472 | class SpatialBlock(nn.Module): 473 | 474 | def __init__(self, dim, num_heads, window_size, 475 | mlp_ratio=4., qkv_bias=True, drop_path_rate=0., act_layer=nn.GELU, 476 | norm_layer=nn.LayerNorm, conv_at_attn=True, conv_at_ffn=True): 477 | super().__init__() 478 | 
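        # Note: SpatialBlock mirrors ChannelBlock above, but swaps grouped channel
        # attention for windowed multi-head self-attention. window_partition /
        # window_reverse (defined above) round-trip between (B, H, W, C) and
        # (num_windows * B, window_size, window_size, C); e.g. with H = W = 56 and
        # window_size = 7, each image is split into 8 * 8 = 64 windows.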
479 |         drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
480 | 
481 |         self.conv1 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1)) if conv_at_attn else None
482 |         self.window_attn = PreNorm(
483 |             norm_layer(dim),
484 |             WindowAttention(dim, num_heads, window_size, qkv_bias=qkv_bias),
485 |             drop_path
486 |         )
487 |         self.conv2 = PreNorm(None, DepthWiseConv2d(dim, 3, 1, 1)) if conv_at_ffn else None
488 |         self.ffn = PreNorm(
489 |             norm_layer(dim),
490 |             Mlp(in_features=dim, hidden_features=int(dim*mlp_ratio), act_layer=act_layer),
491 |             drop_path
492 |         )
493 | 
494 |     def forward(self, x, size):
495 |         if self.conv1:
496 |             x, size = self.conv1(x, size)
497 |         x, size = self.window_attn(x, size)
498 | 
499 |         if self.conv2:
500 |             x, size = self.conv2(x, size)
501 |         x, size = self.ffn(x, size)
502 |         return x, size
503 | 
504 | 
505 | class DaViT(nn.Module):
506 |     """ DaViT: Dual-Attention Transformer
507 | 
508 |     Args:
509 |         in_chans (int): Number of input image channels. Default: 3.
510 |         num_classes (int): Number of classes for classification head. Default: 1000.
511 |         patch_size (tuple(int)): Patch size of convolution in different stages. Default: (7, 2, 2, 2).
512 |         patch_stride (tuple(int)): Patch stride of convolution in different stages. Default: (4, 2, 2, 2).
513 |         patch_padding (tuple(int)): Patch padding of convolution in different stages. Default: (3, 0, 0, 0).
514 |         patch_prenorm (tuple(bool)): If True, perform norm before the convolution layer. Default: (False, False, False, False).
515 |         embed_dims (tuple(int)): Patch embedding dimension in different stages. Default: (64, 128, 192, 256).
516 |         num_heads (tuple(int)): Number of spatial attention heads in different stages. Default: (3, 6, 12, 24).
517 |         num_groups (tuple(int)): Number of channel groups in different stages. Default: (3, 6, 12, 24).
518 |         window_size (int): Window size. Default: 7.
519 |         mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
520 |         qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True.
521 |         drop_path_rate (float): Stochastic depth rate. Default: 0.1.
522 |         norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
523 |         enable_checkpoint (bool): If True, enable checkpointing. Default: False.
524 |         conv_at_attn (bool): If True, perform depthwise convolution before the attention layer. Default: True.
525 |         conv_at_ffn (bool): If True, perform depthwise convolution before the ffn layer. Default: True.
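
    Example (illustrative sketch; assumes the default constructor arguments below
    and an arbitrary 224x224 input)::

        >>> import torch
        >>> model = DaViT(num_classes=0)            # num_classes=0 -> Identity head
        >>> x = torch.randn(1, 3, 224, 224)
        >>> model.forward_features_unpool(x).shape  # (B, H/32 * W/32, embed_dims[-1])
        torch.Size([1, 49, 256])
        >>> model(x).shape                          # pooled and LayerNorm'd features
        torch.Size([1, 256])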
526 | """ 527 | 528 | def __init__( 529 | self, 530 | in_chans=3, 531 | num_classes=1000, 532 | depths=(1, 1, 3, 1), 533 | patch_size=(7, 2, 2, 2), 534 | patch_stride=(4, 2, 2, 2), 535 | patch_padding=(3, 0, 0, 0), 536 | patch_prenorm=(False, False, False, False), 537 | embed_dims=(64, 128, 192, 256), 538 | num_heads=(3, 6, 12, 24), 539 | num_groups=(3, 6, 12, 24), 540 | window_size=7, 541 | mlp_ratio=4., 542 | qkv_bias=True, 543 | drop_path_rate=0.1, 544 | norm_layer=nn.LayerNorm, 545 | enable_checkpoint=False, 546 | conv_at_attn=True, 547 | conv_at_ffn=True, 548 | ): 549 | super().__init__() 550 | 551 | self.num_classes = num_classes 552 | self.embed_dims = embed_dims 553 | self.num_heads = num_heads 554 | self.num_groups = num_groups 555 | self.num_stages = len(self.embed_dims) 556 | self.enable_checkpoint = enable_checkpoint 557 | assert self.num_stages == len(self.num_heads) == len(self.num_groups) 558 | 559 | num_stages = len(embed_dims) 560 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)*2)] 561 | 562 | depth_offset = 0 563 | convs = [] 564 | blocks = [] 565 | for i in range(num_stages): 566 | conv_embed = ConvEmbed( 567 | patch_size=patch_size[i], 568 | stride=patch_stride[i], 569 | padding=patch_padding[i], 570 | in_chans=in_chans if i == 0 else self.embed_dims[i - 1], 571 | embed_dim=self.embed_dims[i], 572 | norm_layer=norm_layer, 573 | pre_norm=patch_prenorm[i] 574 | ) 575 | convs.append(conv_embed) 576 | 577 | block = MySequential( 578 | *[ 579 | MySequential(OrderedDict([ 580 | ( 581 | 'spatial_block', SpatialBlock( 582 | embed_dims[i], 583 | num_heads[i], 584 | window_size, 585 | drop_path_rate=dpr[depth_offset+j*2], 586 | qkv_bias=qkv_bias, 587 | mlp_ratio=mlp_ratio, 588 | conv_at_attn=conv_at_attn, 589 | conv_at_ffn=conv_at_ffn, 590 | ) 591 | ), 592 | ( 593 | 'channel_block', ChannelBlock( 594 | embed_dims[i], 595 | num_groups[i], 596 | drop_path_rate=dpr[depth_offset+j*2+1], 597 | qkv_bias=qkv_bias, 598 | mlp_ratio=mlp_ratio, 599 | conv_at_attn=conv_at_attn, 600 | conv_at_ffn=conv_at_ffn, 601 | ) 602 | ) 603 | ])) for j in range(depths[i]) 604 | ] 605 | ) 606 | blocks.append(block) 607 | depth_offset += depths[i]*2 608 | 609 | self.convs = nn.ModuleList(convs) 610 | self.blocks = nn.ModuleList(blocks) 611 | 612 | self.norms = norm_layer(self.embed_dims[-1]) 613 | self.avgpool = nn.AdaptiveAvgPool1d(1) 614 | self.head = nn.Linear(self.embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity() 615 | 616 | @property 617 | def dim_out(self): 618 | return self.embed_dims[-1] 619 | 620 | def forward_features_unpool(self, x): 621 | """ 622 | forward until avg pooling 623 | Args: 624 | x (_type_): input image tensor 625 | """ 626 | input_size = (x.size(2), x.size(3)) 627 | for conv, block in zip(self.convs, self.blocks): 628 | x, input_size = conv(x, input_size) 629 | if self.enable_checkpoint: 630 | x, input_size = checkpoint.checkpoint(block, x, input_size) 631 | else: 632 | x, input_size = block(x, input_size) 633 | return x 634 | 635 | def forward_features(self, x): 636 | x = self.forward_features_unpool(x) 637 | 638 | # (batch_size, num_tokens, token_dim) 639 | x = self.avgpool(x.transpose(1, 2)) 640 | # (batch_size, 1, num_tokens) 641 | x = torch.flatten(x, 1) 642 | x = self.norms(x) 643 | 644 | return x 645 | 646 | def forward(self, x): 647 | x = self.forward_features(x) 648 | x = self.head(x) 649 | return x 650 | 651 | @classmethod 652 | def from_config(cls, config): 653 | return cls( 654 | depths=config.depths, 655 | 
embed_dims=config.dim_embed, 656 | num_heads=config.num_heads, 657 | num_groups=config.num_groups, 658 | patch_size=config.patch_size, 659 | patch_stride=config.patch_stride, 660 | patch_padding=config.patch_padding, 661 | patch_prenorm=config.patch_prenorm, 662 | drop_path_rate=config.drop_path_rate, 663 | window_size=config.window_size, 664 | ) 665 | 666 | 667 | 668 | 669 | if is_flash_attn_2_available(): 670 | from flash_attn import flash_attn_func, flash_attn_varlen_func 671 | from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa 672 | 673 | # Copied from transformers.models.llama.modeling_llama._get_unpad_data 674 | def _get_unpad_data(attention_mask): 675 | seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) 676 | indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() 677 | max_seqlen_in_batch = seqlens_in_batch.max().item() 678 | cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) 679 | return ( 680 | indices, 681 | cu_seqlens, 682 | max_seqlen_in_batch, 683 | ) 684 | 685 | 686 | def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): 687 | """ 688 | Shift input ids one token to the right. 689 | """ 690 | shifted_input_ids = input_ids.new_zeros(input_ids.shape) 691 | shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() 692 | shifted_input_ids[:, 0] = decoder_start_token_id 693 | 694 | if pad_token_id is None: 695 | raise ValueError("self.model.config.pad_token_id has to be defined.") 696 | # replace possible -100 values in labels by `pad_token_id` 697 | shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) 698 | 699 | return shifted_input_ids 700 | 701 | 702 | class Florence2LearnedPositionalEmbedding(nn.Embedding): 703 | """ 704 | This module learns positional embeddings up to a fixed maximum size. 705 | """ 706 | 707 | def __init__(self, num_embeddings: int, embedding_dim: int): 708 | # Florence2 is set up so that if padding_idx is specified then offset the embedding ids by 2 709 | # and adjust num_embeddings appropriately. Other models don't have this hack 710 | self.offset = 2 711 | super().__init__(num_embeddings + self.offset, embedding_dim) 712 | 713 | def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0): 714 | """`input_ids' shape is expected to be [bsz x seqlen].""" 715 | 716 | bsz, seq_len = input_ids.shape[:2] 717 | positions = torch.arange( 718 | past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device 719 | ).expand(bsz, -1) 720 | 721 | return super().forward(positions + self.offset) 722 | 723 | 724 | class Florence2ScaledWordEmbedding(nn.Embedding): 725 | """ 726 | This module overrides nn.Embeddings' forward by multiplying with embeddings scale. 
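
    Example (illustrative; the sizes below are arbitrary)::

        >>> import torch
        >>> emb = Florence2ScaledWordEmbedding(100, 8, padding_idx=1, embed_scale=2.0)
        >>> emb(torch.tensor([[5, 7, 1]])).shape  # same shape as nn.Embedding, values scaled by 2.0
        torch.Size([1, 3, 8])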
727 | """ 728 | 729 | def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0): 730 | super().__init__(num_embeddings, embedding_dim, padding_idx) 731 | self.embed_scale = embed_scale 732 | 733 | def forward(self, input_ids: torch.Tensor): 734 | return super().forward(input_ids) * self.embed_scale 735 | 736 | 737 | class Florence2Attention(nn.Module): 738 | """Multi-headed attention from 'Attention Is All You Need' paper""" 739 | 740 | def __init__( 741 | self, 742 | embed_dim: int, 743 | num_heads: int, 744 | dropout: float = 0.0, 745 | is_decoder: bool = False, 746 | bias: bool = True, 747 | is_causal: bool = False, 748 | config: Optional[Florence2LanguageConfig] = None, 749 | ): 750 | super().__init__() 751 | self.embed_dim = embed_dim 752 | self.num_heads = num_heads 753 | self.dropout = dropout 754 | self.head_dim = embed_dim // num_heads 755 | self.config = config 756 | 757 | if (self.head_dim * num_heads) != self.embed_dim: 758 | raise ValueError( 759 | f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" 760 | f" and `num_heads`: {num_heads})." 761 | ) 762 | self.scaling = self.head_dim**-0.5 763 | self.is_decoder = is_decoder 764 | self.is_causal = is_causal 765 | 766 | self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) 767 | self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) 768 | self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) 769 | self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) 770 | 771 | def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): 772 | return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() 773 | 774 | def forward( 775 | self, 776 | hidden_states: torch.Tensor, 777 | key_value_states: Optional[torch.Tensor] = None, 778 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 779 | attention_mask: Optional[torch.Tensor] = None, 780 | layer_head_mask: Optional[torch.Tensor] = None, 781 | output_attentions: bool = False, 782 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 783 | """Input shape: Batch x Time x Channel""" 784 | 785 | # if key_value_states are provided this layer is used as a cross-attention layer 786 | # for the decoder 787 | is_cross_attention = key_value_states is not None 788 | 789 | bsz, tgt_len, _ = hidden_states.size() 790 | 791 | # get query proj 792 | query_states = self.q_proj(hidden_states) * self.scaling 793 | # get key, value proj 794 | # `past_key_value[0].shape[2] == key_value_states.shape[1]` 795 | # is checking that the `sequence_length` of the `past_key_value` is the same as 796 | # the provided `key_value_states` to support prefix tuning 797 | if ( 798 | is_cross_attention 799 | and past_key_value is not None 800 | and past_key_value[0].shape[2] == key_value_states.shape[1] 801 | ): 802 | # reuse k,v, cross_attentions 803 | key_states = past_key_value[0] 804 | value_states = past_key_value[1] 805 | elif is_cross_attention: 806 | # cross_attentions 807 | key_states = self._shape(self.k_proj(key_value_states), -1, bsz) 808 | value_states = self._shape(self.v_proj(key_value_states), -1, bsz) 809 | elif past_key_value is not None: 810 | # reuse k, v, self_attention 811 | key_states = self._shape(self.k_proj(hidden_states), -1, bsz) 812 | value_states = self._shape(self.v_proj(hidden_states), -1, bsz) 813 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 814 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 815 | 
else: 816 | # self_attention 817 | key_states = self._shape(self.k_proj(hidden_states), -1, bsz) 818 | value_states = self._shape(self.v_proj(hidden_states), -1, bsz) 819 | 820 | if self.is_decoder: 821 | # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. 822 | # Further calls to cross_attention layer can then reuse all cross-attention 823 | # key/value_states (first "if" case) 824 | # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of 825 | # all previous decoder key/value_states. Further calls to uni-directional self-attention 826 | # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) 827 | # if encoder bi-directional self-attention `past_key_value` is always `None` 828 | past_key_value = (key_states, value_states) 829 | 830 | proj_shape = (bsz * self.num_heads, -1, self.head_dim) 831 | query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) 832 | key_states = key_states.reshape(*proj_shape) 833 | value_states = value_states.reshape(*proj_shape) 834 | 835 | src_len = key_states.size(1) 836 | attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) 837 | 838 | if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): 839 | raise ValueError( 840 | f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" 841 | f" {attn_weights.size()}" 842 | ) 843 | 844 | if attention_mask is not None: 845 | if attention_mask.size() != (bsz, 1, tgt_len, src_len): 846 | raise ValueError( 847 | f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" 848 | ) 849 | attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask 850 | attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) 851 | 852 | attn_weights = nn.functional.softmax(attn_weights, dim=-1) 853 | 854 | if layer_head_mask is not None: 855 | if layer_head_mask.size() != (self.num_heads,): 856 | raise ValueError( 857 | f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" 858 | f" {layer_head_mask.size()}" 859 | ) 860 | attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) 861 | attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) 862 | 863 | if output_attentions: 864 | # this operation is a bit awkward, but it's required to 865 | # make sure that attn_weights keeps its gradient. 
866 | # In order to do so, attn_weights have to be reshaped 867 | # twice and have to be reused in the following 868 | attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) 869 | attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) 870 | else: 871 | attn_weights_reshaped = None 872 | 873 | attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) 874 | 875 | attn_output = torch.bmm(attn_probs, value_states) 876 | 877 | if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): 878 | raise ValueError( 879 | f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" 880 | f" {attn_output.size()}" 881 | ) 882 | 883 | attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) 884 | attn_output = attn_output.transpose(1, 2) 885 | 886 | # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be 887 | # partitioned across GPUs when using tensor-parallelism. 888 | attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) 889 | 890 | attn_output = self.out_proj(attn_output) 891 | 892 | return attn_output, attn_weights_reshaped, past_key_value 893 | 894 | 895 | class Florence2FlashAttention2(Florence2Attention): 896 | """ 897 | Florence2 flash attention module. This module inherits from `Florence2Attention` as the weights of the module stays 898 | untouched. The only required change would be on the forward pass where it needs to correctly call the public API of 899 | flash attention and deal with padding tokens in case the input contains any of them. 900 | """ 901 | 902 | # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ 903 | def __init__(self, *args, **kwargs): 904 | super().__init__(*args, **kwargs) 905 | 906 | # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. 907 | # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. 908 | # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
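        # In short: flash_attn >= 2.1 uses a bottom-right aligned causal mask, so
        # `is_causal` can be forwarded as-is; for older versions the top-left
        # alignment is only safe when q_len == kv_len or q_len == 1, which
        # `_flash_attention_forward` below handles via the `query_length != 1` check.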
909 | self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() 910 | 911 | def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): 912 | return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) 913 | 914 | def forward( 915 | self, 916 | hidden_states: torch.Tensor, 917 | key_value_states: Optional[torch.Tensor] = None, 918 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 919 | attention_mask: Optional[torch.Tensor] = None, 920 | layer_head_mask: Optional[torch.Tensor] = None, 921 | output_attentions: bool = False, 922 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 923 | # Florence2FlashAttention2 attention does not support output_attentions 924 | if output_attentions: 925 | raise ValueError("Florence2FlashAttention2 attention does not support output_attentions") 926 | 927 | # if key_value_states are provided this layer is used as a cross-attention layer 928 | # for the decoder 929 | is_cross_attention = key_value_states is not None 930 | 931 | bsz, q_len, _ = hidden_states.size() 932 | 933 | # get query proj 934 | query_states = self._reshape(self.q_proj(hidden_states), -1, bsz) 935 | # get key, value proj 936 | # `past_key_value[0].shape[2] == key_value_states.shape[1]` 937 | # is checking that the `sequence_length` of the `past_key_value` is the same as 938 | # the provided `key_value_states` to support prefix tuning 939 | if ( 940 | is_cross_attention 941 | and past_key_value is not None 942 | and past_key_value[0].shape[2] == key_value_states.shape[1] 943 | ): 944 | # reuse k,v, cross_attentions 945 | key_states = past_key_value[0].transpose(1, 2) 946 | value_states = past_key_value[1].transpose(1, 2) 947 | elif is_cross_attention: 948 | # cross_attentions 949 | key_states = self._reshape(self.k_proj(key_value_states), -1, bsz) 950 | value_states = self._reshape(self.v_proj(key_value_states), -1, bsz) 951 | elif past_key_value is not None: 952 | # reuse k, v, self_attention 953 | key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) 954 | value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) 955 | key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1) 956 | value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1) 957 | else: 958 | # self_attention 959 | key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) 960 | value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) 961 | 962 | if self.is_decoder: 963 | # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. 964 | # Further calls to cross_attention layer can then reuse all cross-attention 965 | # key/value_states (first "if" case) 966 | # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of 967 | # all previous decoder key/value_states. Further calls to uni-directional self-attention 968 | # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) 969 | # if encoder bi-directional self-attention `past_key_value` is always `None` 970 | past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2)) 971 | 972 | kv_seq_len = key_states.shape[-2] 973 | if past_key_value is not None: 974 | kv_seq_len += past_key_value[0].shape[-2] 975 | 976 | # In PEFT, usually we cast the layer norms in float32 for training stability reasons 977 | # therefore the input hidden states gets silently casted in float32. 
Hence, we need 978 | # cast them back in the correct dtype just to be sure everything works as expected. 979 | # This might slowdown training & inference so it is recommended to not cast the LayerNorms 980 | # in fp32. (LlamaRMSNorm handles it correctly) 981 | 982 | input_dtype = query_states.dtype 983 | if input_dtype == torch.float32: 984 | if torch.is_autocast_enabled(): 985 | target_dtype = torch.get_autocast_gpu_dtype() 986 | # Handle the case where the model is quantized 987 | elif hasattr(self.config, "_pre_quantization_dtype"): 988 | target_dtype = self.config._pre_quantization_dtype 989 | else: 990 | target_dtype = self.q_proj.weight.dtype 991 | 992 | logger.warning_once( 993 | f"The input hidden states seems to be silently casted in float32, this might be related to" 994 | f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" 995 | f" {target_dtype}." 996 | ) 997 | 998 | query_states = query_states.to(target_dtype) 999 | key_states = key_states.to(target_dtype) 1000 | value_states = value_states.to(target_dtype) 1001 | 1002 | attn_output = self._flash_attention_forward( 1003 | query_states, key_states, value_states, attention_mask, q_len, dropout=self.dropout 1004 | ) 1005 | 1006 | attn_output = attn_output.reshape(bsz, q_len, -1) 1007 | attn_output = self.out_proj(attn_output) 1008 | 1009 | if not output_attentions: 1010 | attn_weights = None 1011 | 1012 | return attn_output, attn_weights, past_key_value 1013 | 1014 | # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward 1015 | def _flash_attention_forward( 1016 | self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None 1017 | ): 1018 | """ 1019 | Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token 1020 | first unpad the input, then computes the attention scores and pad the final attention scores. 1021 | 1022 | Args: 1023 | query_states (`torch.Tensor`): 1024 | Input query states to be passed to Flash Attention API 1025 | key_states (`torch.Tensor`): 1026 | Input key states to be passed to Flash Attention API 1027 | value_states (`torch.Tensor`): 1028 | Input value states to be passed to Flash Attention API 1029 | attention_mask (`torch.Tensor`): 1030 | The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the 1031 | position of padding tokens and 1 for the position of non-padding tokens. 1032 | dropout (`float`): 1033 | Attention dropout 1034 | softmax_scale (`float`, *optional*): 1035 | The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) 1036 | """ 1037 | if not self._flash_attn_uses_top_left_mask: 1038 | causal = self.is_causal 1039 | else: 1040 | # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. 
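            # Note: when query_length == 1 (single-token decoding) all keys are at or
            # before the query position, so dropping the causal flag is equivalent and
            # sidesteps the top-left-aligned mask produced by flash_attn < 2.1.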
1041 | causal = self.is_causal and query_length != 1 1042 | 1043 | # Contains at least one padding token in the sequence 1044 | if attention_mask is not None: 1045 | batch_size = query_states.shape[0] 1046 | query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( 1047 | query_states, key_states, value_states, attention_mask, query_length 1048 | ) 1049 | 1050 | cu_seqlens_q, cu_seqlens_k = cu_seq_lens 1051 | max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens 1052 | 1053 | attn_output_unpad = flash_attn_varlen_func( 1054 | query_states, 1055 | key_states, 1056 | value_states, 1057 | cu_seqlens_q=cu_seqlens_q, 1058 | cu_seqlens_k=cu_seqlens_k, 1059 | max_seqlen_q=max_seqlen_in_batch_q, 1060 | max_seqlen_k=max_seqlen_in_batch_k, 1061 | dropout_p=dropout, 1062 | softmax_scale=softmax_scale, 1063 | causal=causal, 1064 | ) 1065 | 1066 | attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) 1067 | else: 1068 | attn_output = flash_attn_func( 1069 | query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal 1070 | ) 1071 | 1072 | return attn_output 1073 | 1074 | # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input 1075 | def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): 1076 | indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) 1077 | batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape 1078 | 1079 | key_layer = index_first_axis( 1080 | key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k 1081 | ) 1082 | value_layer = index_first_axis( 1083 | value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k 1084 | ) 1085 | if query_length == kv_seq_len: 1086 | query_layer = index_first_axis( 1087 | query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k 1088 | ) 1089 | cu_seqlens_q = cu_seqlens_k 1090 | max_seqlen_in_batch_q = max_seqlen_in_batch_k 1091 | indices_q = indices_k 1092 | elif query_length == 1: 1093 | max_seqlen_in_batch_q = 1 1094 | cu_seqlens_q = torch.arange( 1095 | batch_size + 1, dtype=torch.int32, device=query_layer.device 1096 | ) # There is a memcpy here, that is very bad. 1097 | indices_q = cu_seqlens_q[:-1] 1098 | query_layer = query_layer.squeeze(1) 1099 | else: 1100 | # The -q_len: slice assumes left padding. 1101 | attention_mask = attention_mask[:, -query_length:] 1102 | query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) 1103 | 1104 | return ( 1105 | query_layer, 1106 | key_layer, 1107 | value_layer, 1108 | indices_q, 1109 | (cu_seqlens_q, cu_seqlens_k), 1110 | (max_seqlen_in_batch_q, max_seqlen_in_batch_k), 1111 | ) 1112 | 1113 | 1114 | class Florence2SdpaAttention(Florence2Attention): 1115 | def forward( 1116 | self, 1117 | hidden_states: torch.Tensor, 1118 | key_value_states: Optional[torch.Tensor] = None, 1119 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 1120 | attention_mask: Optional[torch.Tensor] = None, 1121 | layer_head_mask: Optional[torch.Tensor] = None, 1122 | output_attentions: bool = False, 1123 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 1124 | """Input shape: Batch x Time x Channel""" 1125 | if output_attentions or layer_head_mask is not None: 1126 | # TODO: Improve this warning with e.g. 
`model.config._attn_implementation = "manual"` once this is implemented. 1127 | logger.warning_once( 1128 | "Florence2Model is using Florence2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention" 1129 | ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 1130 | ) 1131 | return super().forward( 1132 | hidden_states, 1133 | key_value_states=key_value_states, 1134 | past_key_value=past_key_value, 1135 | attention_mask=attention_mask, 1136 | layer_head_mask=layer_head_mask, 1137 | output_attentions=output_attentions, 1138 | ) 1139 | 1140 | # if key_value_states are provided this layer is used as a cross-attention layer 1141 | # for the decoder 1142 | is_cross_attention = key_value_states is not None 1143 | 1144 | bsz, tgt_len, _ = hidden_states.size() 1145 | 1146 | # get query proj 1147 | query_states = self.q_proj(hidden_states) 1148 | # get key, value proj 1149 | # `past_key_value[0].shape[2] == key_value_states.shape[1]` 1150 | # is checking that the `sequence_length` of the `past_key_value` is the same as 1151 | # the provided `key_value_states` to support prefix tuning 1152 | if ( 1153 | is_cross_attention 1154 | and past_key_value is not None 1155 | and past_key_value[0].shape[2] == key_value_states.shape[1] 1156 | ): 1157 | # reuse k,v, cross_attentions 1158 | key_states = past_key_value[0] 1159 | value_states = past_key_value[1] 1160 | elif is_cross_attention: 1161 | # cross_attentions 1162 | key_states = self._shape(self.k_proj(key_value_states), -1, bsz) 1163 | value_states = self._shape(self.v_proj(key_value_states), -1, bsz) 1164 | elif past_key_value is not None: 1165 | # reuse k, v, self_attention 1166 | key_states = self._shape(self.k_proj(hidden_states), -1, bsz) 1167 | value_states = self._shape(self.v_proj(hidden_states), -1, bsz) 1168 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 1169 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 1170 | else: 1171 | # self_attention 1172 | key_states = self._shape(self.k_proj(hidden_states), -1, bsz) 1173 | value_states = self._shape(self.v_proj(hidden_states), -1, bsz) 1174 | 1175 | if self.is_decoder: 1176 | # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. 1177 | # Further calls to cross_attention layer can then reuse all cross-attention 1178 | # key/value_states (first "if" case) 1179 | # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of 1180 | # all previous decoder key/value_states. Further calls to uni-directional self-attention 1181 | # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) 1182 | # if encoder bi-directional self-attention `past_key_value` is always `None` 1183 | past_key_value = (key_states, value_states) 1184 | 1185 | query_states = self._shape(query_states, tgt_len, bsz) 1186 | 1187 | # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment 1188 | # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. 
1189 | # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. 1190 | is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False 1191 | 1192 | # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask, 1193 | # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577 1194 | attn_output = torch.nn.functional.scaled_dot_product_attention( 1195 | query_states, 1196 | key_states, 1197 | value_states, 1198 | attn_mask=attention_mask, 1199 | dropout_p=self.dropout if self.training else 0.0, 1200 | is_causal=is_causal, 1201 | ) 1202 | 1203 | if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim): 1204 | raise ValueError( 1205 | f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" 1206 | f" {attn_output.size()}" 1207 | ) 1208 | 1209 | attn_output = attn_output.transpose(1, 2) 1210 | 1211 | # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be 1212 | # partitioned across GPUs when using tensor-parallelism. 1213 | attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) 1214 | 1215 | attn_output = self.out_proj(attn_output) 1216 | 1217 | return attn_output, None, past_key_value 1218 | 1219 | 1220 | FLORENCE2_ATTENTION_CLASSES = { 1221 | "eager": Florence2Attention, 1222 | "sdpa": Florence2SdpaAttention, 1223 | "flash_attention_2": Florence2FlashAttention2, 1224 | } 1225 | 1226 | 1227 | class Florence2EncoderLayer(nn.Module): 1228 | def __init__(self, config: Florence2LanguageConfig): 1229 | super().__init__() 1230 | self.embed_dim = config.d_model 1231 | 1232 | self.self_attn = FLORENCE2_ATTENTION_CLASSES[config._attn_implementation]( 1233 | embed_dim=self.embed_dim, 1234 | num_heads=config.encoder_attention_heads, 1235 | dropout=config.attention_dropout, 1236 | config=config, 1237 | ) 1238 | self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) 1239 | self.dropout = config.dropout 1240 | self.activation_fn = ACT2FN[config.activation_function] 1241 | self.activation_dropout = config.activation_dropout 1242 | self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) 1243 | self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) 1244 | self.final_layer_norm = nn.LayerNorm(self.embed_dim) 1245 | 1246 | def forward( 1247 | self, 1248 | hidden_states: torch.FloatTensor, 1249 | attention_mask: torch.FloatTensor, 1250 | layer_head_mask: torch.FloatTensor, 1251 | output_attentions: Optional[bool] = False, 1252 | ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]: 1253 | """ 1254 | Args: 1255 | hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` 1256 | attention_mask (`torch.FloatTensor`): attention mask of size 1257 | `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. 1258 | layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size 1259 | `(encoder_attention_heads,)`. 1260 | output_attentions (`bool`, *optional*): 1261 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under 1262 | returned tensors for more detail. 
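
        Returns:
            A tuple whose first element is the updated `hidden_states` of shape `(batch, seq_len, embed_dim)`;
            when `output_attentions=True`, the self-attention weights are appended as a second element.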
1263 | """ 1264 | residual = hidden_states 1265 | hidden_states, attn_weights, _ = self.self_attn( 1266 | hidden_states=hidden_states, 1267 | attention_mask=attention_mask, 1268 | layer_head_mask=layer_head_mask, 1269 | output_attentions=output_attentions, 1270 | ) 1271 | hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) 1272 | hidden_states = residual + hidden_states 1273 | hidden_states = self.self_attn_layer_norm(hidden_states) 1274 | 1275 | residual = hidden_states 1276 | hidden_states = self.activation_fn(self.fc1(hidden_states)) 1277 | hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) 1278 | hidden_states = self.fc2(hidden_states) 1279 | hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) 1280 | hidden_states = residual + hidden_states 1281 | hidden_states = self.final_layer_norm(hidden_states) 1282 | 1283 | if hidden_states.dtype == torch.float16 and ( 1284 | torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any() 1285 | ): 1286 | clamp_value = torch.finfo(hidden_states.dtype).max - 1000 1287 | hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) 1288 | 1289 | outputs = (hidden_states,) 1290 | 1291 | if output_attentions: 1292 | outputs += (attn_weights,) 1293 | 1294 | return outputs 1295 | 1296 | 1297 | class Florence2DecoderLayer(nn.Module): 1298 | def __init__(self, config: Florence2LanguageConfig): 1299 | super().__init__() 1300 | self.embed_dim = config.d_model 1301 | 1302 | self.self_attn = FLORENCE2_ATTENTION_CLASSES[config._attn_implementation]( 1303 | embed_dim=self.embed_dim, 1304 | num_heads=config.decoder_attention_heads, 1305 | dropout=config.attention_dropout, 1306 | is_decoder=True, 1307 | is_causal=True, 1308 | config=config, 1309 | ) 1310 | self.dropout = config.dropout 1311 | self.activation_fn = ACT2FN[config.activation_function] 1312 | self.activation_dropout = config.activation_dropout 1313 | 1314 | self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) 1315 | self.encoder_attn = FLORENCE2_ATTENTION_CLASSES[config._attn_implementation]( 1316 | self.embed_dim, 1317 | config.decoder_attention_heads, 1318 | dropout=config.attention_dropout, 1319 | is_decoder=True, 1320 | config=config, 1321 | ) 1322 | self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) 1323 | self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) 1324 | self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) 1325 | self.final_layer_norm = nn.LayerNorm(self.embed_dim) 1326 | 1327 | def forward( 1328 | self, 1329 | hidden_states: torch.Tensor, 1330 | attention_mask: Optional[torch.Tensor] = None, 1331 | encoder_hidden_states: Optional[torch.Tensor] = None, 1332 | encoder_attention_mask: Optional[torch.Tensor] = None, 1333 | layer_head_mask: Optional[torch.Tensor] = None, 1334 | cross_attn_layer_head_mask: Optional[torch.Tensor] = None, 1335 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 1336 | output_attentions: Optional[bool] = False, 1337 | use_cache: Optional[bool] = True, 1338 | ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: 1339 | """ 1340 | Args: 1341 | hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` 1342 | attention_mask (`torch.FloatTensor`): attention mask of size 1343 | `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. 
1344 | encoder_hidden_states (`torch.FloatTensor`): 1345 | cross attention input to the layer of shape `(batch, seq_len, embed_dim)` 1346 | encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size 1347 | `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. 1348 | layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size 1349 | `(encoder_attention_heads,)`. 1350 | cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of 1351 | size `(decoder_attention_heads,)`. 1352 | past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states 1353 | output_attentions (`bool`, *optional*): 1354 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under 1355 | returned tensors for more detail. 1356 | """ 1357 | residual = hidden_states 1358 | 1359 | # Self Attention 1360 | # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 1361 | self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None 1362 | # add present self-attn cache to positions 1,2 of present_key_value tuple 1363 | hidden_states, self_attn_weights, present_key_value = self.self_attn( 1364 | hidden_states=hidden_states, 1365 | past_key_value=self_attn_past_key_value, 1366 | attention_mask=attention_mask, 1367 | layer_head_mask=layer_head_mask, 1368 | output_attentions=output_attentions, 1369 | ) 1370 | hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) 1371 | hidden_states = residual + hidden_states 1372 | hidden_states = self.self_attn_layer_norm(hidden_states) 1373 | 1374 | # Cross-Attention Block 1375 | cross_attn_present_key_value = None 1376 | cross_attn_weights = None 1377 | if encoder_hidden_states is not None: 1378 | residual = hidden_states 1379 | 1380 | # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple 1381 | cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None 1382 | hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( 1383 | hidden_states=hidden_states, 1384 | key_value_states=encoder_hidden_states, 1385 | attention_mask=encoder_attention_mask, 1386 | layer_head_mask=cross_attn_layer_head_mask, 1387 | past_key_value=cross_attn_past_key_value, 1388 | output_attentions=output_attentions, 1389 | ) 1390 | hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) 1391 | hidden_states = residual + hidden_states 1392 | hidden_states = self.encoder_attn_layer_norm(hidden_states) 1393 | 1394 | # add cross-attn to positions 3,4 of present_key_value tuple 1395 | present_key_value = present_key_value + cross_attn_present_key_value 1396 | 1397 | # Fully Connected 1398 | residual = hidden_states 1399 | hidden_states = self.activation_fn(self.fc1(hidden_states)) 1400 | hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) 1401 | hidden_states = self.fc2(hidden_states) 1402 | hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) 1403 | hidden_states = residual + hidden_states 1404 | hidden_states = self.final_layer_norm(hidden_states) 1405 | 1406 | outputs = (hidden_states,) 1407 | 1408 | if output_attentions: 1409 | outputs += (self_attn_weights, cross_attn_weights) 1410 | 1411 | if use_cache: 1412 | outputs += 
(present_key_value,) 1413 | 1414 | return outputs 1415 | 1416 | 1417 | 1418 | class Florence2LanguagePreTrainedModel(PreTrainedModel): 1419 | config_class = Florence2LanguageConfig 1420 | base_model_prefix = "model" 1421 | supports_gradient_checkpointing = True 1422 | _keys_to_ignore_on_load_unexpected = ["encoder.version", "decoder.version"] 1423 | _no_split_modules = [r"Florence2EncoderLayer", r"Florence2DecoderLayer"] 1424 | _skip_keys_device_placement = "past_key_values" 1425 | _supports_flash_attn_2 = True 1426 | _supports_sdpa = True 1427 | 1428 | def _init_weights(self, module): 1429 | std = self.config.init_std 1430 | if isinstance(module, nn.Linear): 1431 | module.weight.data.normal_(mean=0.0, std=std) 1432 | if module.bias is not None: 1433 | module.bias.data.zero_() 1434 | elif isinstance(module, nn.Embedding): 1435 | module.weight.data.normal_(mean=0.0, std=std) 1436 | if module.padding_idx is not None: 1437 | module.weight.data[module.padding_idx].zero_() 1438 | elif isinstance(module, nn.Conv2d): 1439 | nn.init.normal_(module.weight, std=0.02) 1440 | for name, _ in module.named_parameters(): 1441 | if name == "bias": 1442 | nn.init.constant_(module.bias, 0) 1443 | elif isinstance(module, nn.LayerNorm): 1444 | nn.init.constant_(module.weight, 1.0) 1445 | nn.init.constant_(module.bias, 0) 1446 | elif isinstance(module, nn.BatchNorm2d): 1447 | nn.init.constant_(module.weight, 1.0) 1448 | nn.init.constant_(module.bias, 0) 1449 | 1450 | @property 1451 | def dummy_inputs(self): 1452 | pad_token = self.config.pad_token_id 1453 | input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device) 1454 | dummy_inputs = { 1455 | "attention_mask": input_ids.ne(pad_token), 1456 | "input_ids": input_ids, 1457 | } 1458 | return dummy_inputs 1459 | 1460 | 1461 | class Florence2Encoder(Florence2LanguagePreTrainedModel): 1462 | """ 1463 | Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a 1464 | [`Florence2EncoderLayer`]. 
1465 | 1466 | Args: 1467 | config: Florence2LanguageConfig 1468 | embed_tokens (nn.Embedding): output embedding 1469 | """ 1470 | 1471 | def __init__(self, config: Florence2LanguageConfig, embed_tokens: Optional[nn.Embedding] = None): 1472 | super().__init__(config) 1473 | 1474 | self.dropout = config.dropout 1475 | self.layerdrop = config.encoder_layerdrop 1476 | 1477 | embed_dim = config.d_model 1478 | self.padding_idx = config.pad_token_id 1479 | self.max_source_positions = config.max_position_embeddings 1480 | embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 1481 | 1482 | self.embed_tokens = Florence2ScaledWordEmbedding( 1483 | config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale 1484 | ) 1485 | 1486 | if embed_tokens is not None: 1487 | self.embed_tokens.weight = embed_tokens.weight 1488 | 1489 | self.embed_positions = Florence2LearnedPositionalEmbedding( 1490 | config.max_position_embeddings, 1491 | embed_dim, 1492 | ) 1493 | self.layers = nn.ModuleList([Florence2EncoderLayer(config) for _ in range(config.encoder_layers)]) 1494 | self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" 1495 | self._use_sdpa = config._attn_implementation == "sdpa" 1496 | self.layernorm_embedding = nn.LayerNorm(embed_dim) 1497 | 1498 | self.gradient_checkpointing = False 1499 | # Initialize weights and apply final processing 1500 | self.post_init() 1501 | 1502 | def get_input_embeddings(self): 1503 | return self.embed_tokens 1504 | 1505 | def set_input_embeddings(self, value): 1506 | self.embed_tokens = value 1507 | 1508 | def forward( 1509 | self, 1510 | input_ids: torch.LongTensor = None, 1511 | attention_mask: Optional[torch.Tensor] = None, 1512 | head_mask: Optional[torch.Tensor] = None, 1513 | inputs_embeds: Optional[torch.FloatTensor] = None, 1514 | output_attentions: Optional[bool] = None, 1515 | output_hidden_states: Optional[bool] = None, 1516 | return_dict: Optional[bool] = None, 1517 | ) -> Union[Tuple, BaseModelOutput]: 1518 | r""" 1519 | Args: 1520 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): 1521 | Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you 1522 | provide it. 1523 | 1524 | Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and 1525 | [`PreTrainedTokenizer.__call__`] for details. 1526 | 1527 | [What are input IDs?](../glossary#input-ids) 1528 | attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): 1529 | Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: 1530 | 1531 | - 1 for tokens that are **not masked**, 1532 | - 0 for tokens that are **masked**. 1533 | 1534 | [What are attention masks?](../glossary#attention-mask) 1535 | head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): 1536 | Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: 1537 | 1538 | - 1 indicates the head is **not masked**, 1539 | - 0 indicates the head is **masked**. 1540 | 1541 | inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): 1542 | Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. 1543 | This is useful if you want more control over how to convert `input_ids` indices into associated vectors 1544 | than the model's internal embedding lookup matrix. 
1545 | output_attentions (`bool`, *optional*): 1546 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under 1547 | returned tensors for more detail. 1548 | output_hidden_states (`bool`, *optional*): 1549 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors 1550 | for more detail. 1551 | return_dict (`bool`, *optional*): 1552 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 1553 | """ 1554 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 1555 | output_hidden_states = ( 1556 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 1557 | ) 1558 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1559 | 1560 | # retrieve input_ids and inputs_embeds 1561 | if input_ids is not None and inputs_embeds is not None: 1562 | raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") 1563 | elif input_ids is not None: 1564 | input = input_ids 1565 | input_ids = input_ids.view(-1, input_ids.shape[-1]) 1566 | elif inputs_embeds is not None: 1567 | input = inputs_embeds[:, :, -1] 1568 | else: 1569 | raise ValueError("You have to specify either input_ids or inputs_embeds") 1570 | 1571 | if inputs_embeds is None: 1572 | inputs_embeds = self.embed_tokens(input_ids) 1573 | 1574 | embed_pos = self.embed_positions(input) 1575 | embed_pos = embed_pos.to(inputs_embeds.device) 1576 | 1577 | hidden_states = inputs_embeds + embed_pos 1578 | hidden_states = self.layernorm_embedding(hidden_states) 1579 | hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) 1580 | 1581 | # expand attention_mask 1582 | if attention_mask is not None: 1583 | if self._use_flash_attention_2: 1584 | attention_mask = attention_mask if 0 in attention_mask else None 1585 | elif self._use_sdpa and head_mask is None and not output_attentions: 1586 | # output_attentions=True & head_mask can not be supported when using SDPA, fall back to 1587 | # the manual implementation that requires a 4D causal mask in all cases. 1588 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] 1589 | attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) 1590 | else: 1591 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] 1592 | attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) 1593 | 1594 | encoder_states = () if output_hidden_states else None 1595 | all_attentions = () if output_attentions else None 1596 | 1597 | # check if head_mask has a correct number of layers specified if desired 1598 | if head_mask is not None: 1599 | if head_mask.size()[0] != (len(self.layers)): 1600 | raise ValueError( 1601 | f"The head_mask should be specified for {len(self.layers)} layers, but it is for" 1602 | f" {head_mask.size()[0]}." 
1603 | ) 1604 | 1605 | for idx, encoder_layer in enumerate(self.layers): 1606 | if output_hidden_states: 1607 | encoder_states = encoder_states + (hidden_states,) 1608 | # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) 1609 | to_drop = False 1610 | if self.training: 1611 | dropout_probability = torch.rand([]) 1612 | if dropout_probability < self.layerdrop: # skip the layer 1613 | to_drop = True 1614 | 1615 | if to_drop: 1616 | layer_outputs = (None, None) 1617 | else: 1618 | if self.gradient_checkpointing and self.training: 1619 | layer_outputs = self._gradient_checkpointing_func( 1620 | encoder_layer.__call__, 1621 | hidden_states, 1622 | attention_mask, 1623 | (head_mask[idx] if head_mask is not None else None), 1624 | output_attentions, 1625 | ) 1626 | else: 1627 | layer_outputs = encoder_layer( 1628 | hidden_states, 1629 | attention_mask, 1630 | layer_head_mask=(head_mask[idx] if head_mask is not None else None), 1631 | output_attentions=output_attentions, 1632 | ) 1633 | 1634 | hidden_states = layer_outputs[0] 1635 | 1636 | if output_attentions: 1637 | all_attentions = all_attentions + (layer_outputs[1],) 1638 | 1639 | if output_hidden_states: 1640 | encoder_states = encoder_states + (hidden_states,) 1641 | 1642 | if not return_dict: 1643 | return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) 1644 | return BaseModelOutput( 1645 | last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions 1646 | ) 1647 | 1648 | 1649 | class Florence2Decoder(Florence2LanguagePreTrainedModel): 1650 | """ 1651 | Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`Florence2DecoderLayer`] 1652 | 1653 | Args: 1654 | config: Florence2LanguageConfig 1655 | embed_tokens (nn.Embedding): output embedding 1656 | """ 1657 | 1658 | def __init__(self, config: Florence2LanguageConfig, embed_tokens: Optional[nn.Embedding] = None): 1659 | super().__init__(config) 1660 | self.dropout = config.dropout 1661 | self.layerdrop = config.decoder_layerdrop 1662 | self.padding_idx = config.pad_token_id 1663 | self.max_target_positions = config.max_position_embeddings 1664 | embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 1665 | 1666 | self.embed_tokens = Florence2ScaledWordEmbedding( 1667 | config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale 1668 | ) 1669 | 1670 | if embed_tokens is not None: 1671 | self.embed_tokens.weight = embed_tokens.weight 1672 | 1673 | self.embed_positions = Florence2LearnedPositionalEmbedding( 1674 | config.max_position_embeddings, 1675 | config.d_model, 1676 | ) 1677 | self.layers = nn.ModuleList([Florence2DecoderLayer(config) for _ in range(config.decoder_layers)]) 1678 | self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" 1679 | self._use_sdpa = config._attn_implementation == "sdpa" 1680 | 1681 | self.layernorm_embedding = nn.LayerNorm(config.d_model) 1682 | 1683 | self.gradient_checkpointing = False 1684 | # Initialize weights and apply final processing 1685 | self.post_init() 1686 | 1687 | def get_input_embeddings(self): 1688 | return self.embed_tokens 1689 | 1690 | def set_input_embeddings(self, value): 1691 | self.embed_tokens = value 1692 | 1693 | def forward( 1694 | self, 1695 | input_ids: torch.LongTensor = None, 1696 | attention_mask: Optional[torch.Tensor] = None, 1697 | encoder_hidden_states: Optional[torch.FloatTensor] = None, 1698 | encoder_attention_mask: Optional[torch.LongTensor] = None, 
1699 | head_mask: Optional[torch.Tensor] = None, 1700 | cross_attn_head_mask: Optional[torch.Tensor] = None, 1701 | past_key_values: Optional[List[torch.FloatTensor]] = None, 1702 | inputs_embeds: Optional[torch.FloatTensor] = None, 1703 | use_cache: Optional[bool] = None, 1704 | output_attentions: Optional[bool] = None, 1705 | output_hidden_states: Optional[bool] = None, 1706 | return_dict: Optional[bool] = None, 1707 | ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: 1708 | r""" 1709 | Args: 1710 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): 1711 | Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you 1712 | provide it. 1713 | 1714 | Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and 1715 | [`PreTrainedTokenizer.__call__`] for details. 1716 | 1717 | [What are input IDs?](../glossary#input-ids) 1718 | attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): 1719 | Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: 1720 | 1721 | - 1 for tokens that are **not masked**, 1722 | - 0 for tokens that are **masked**. 1723 | 1724 | [What are attention masks?](../glossary#attention-mask) 1725 | encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*): 1726 | Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention 1727 | of the decoder. 1728 | encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*): 1729 | Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values 1730 | selected in `[0, 1]`: 1731 | 1732 | - 1 for tokens that are **not masked**, 1733 | - 0 for tokens that are **masked**. 1734 | 1735 | [What are attention masks?](../glossary#attention-mask) 1736 | head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): 1737 | Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: 1738 | 1739 | - 1 indicates the head is **not masked**, 1740 | - 0 indicates the head is **masked**. 1741 | 1742 | cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): 1743 | Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing 1744 | cross-attention on hidden heads. Mask values selected in `[0, 1]`: 1745 | 1746 | - 1 indicates the head is **not masked**, 1747 | - 0 indicates the head is **masked**. 1748 | 1749 | past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): 1750 | Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of 1751 | shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of 1752 | shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. 1753 | 1754 | Contains pre-computed hidden-states (key and values in the self-attention blocks and in the 1755 | cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. 
1756 | 1757 | If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those 1758 | that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of 1759 | all `decoder_input_ids` of shape `(batch_size, sequence_length)`. 1760 | inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): 1761 | Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. 1762 | This is useful if you want more control over how to convert `input_ids` indices into associated vectors 1763 | than the model's internal embedding lookup matrix. 1764 | output_attentions (`bool`, *optional*): 1765 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under 1766 | returned tensors for more detail. 1767 | output_hidden_states (`bool`, *optional*): 1768 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors 1769 | for more detail. 1770 | return_dict (`bool`, *optional*): 1771 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 1772 | """ 1773 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 1774 | output_hidden_states = ( 1775 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 1776 | ) 1777 | use_cache = use_cache if use_cache is not None else self.config.use_cache 1778 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1779 | 1780 | # retrieve input_ids and inputs_embeds 1781 | if input_ids is not None and inputs_embeds is not None: 1782 | raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") 1783 | elif input_ids is not None: 1784 | input = input_ids 1785 | input_shape = input.shape 1786 | input_ids = input_ids.view(-1, input_shape[-1]) 1787 | elif inputs_embeds is not None: 1788 | input_shape = inputs_embeds.size()[:-1] 1789 | input = inputs_embeds[:, :, -1] 1790 | else: 1791 | raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") 1792 | 1793 | # past_key_values_length 1794 | past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 1795 | 1796 | if inputs_embeds is None: 1797 | inputs_embeds = self.embed_tokens(input) 1798 | 1799 | if self._use_flash_attention_2: 1800 | # 2d mask is passed through the layers 1801 | attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None 1802 | elif self._use_sdpa and not output_attentions and cross_attn_head_mask is None: 1803 | # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on 1804 | # the manual implementation that requires a 4D causal mask in all cases. 
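# Shape sketch for the two mask helpers below (roughly, per the transformers mask utils):
#   attention_mask in:  (batch, past_key_values_length + tgt_len), 1 = attend, 0 = padding
#   attention_mask out: (batch, 1, tgt_len, past_key_values_length + tgt_len),
#                       0.0 where attention is allowed, dtype-min where masked,
#                       with the causal (lower-triangular) constraint folded in.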
1805 | attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( 1806 | attention_mask, 1807 | input_shape, 1808 | inputs_embeds, 1809 | past_key_values_length, 1810 | ) 1811 | else: 1812 | # 4d mask is passed through the layers 1813 | attention_mask = _prepare_4d_causal_attention_mask( 1814 | attention_mask, input_shape, inputs_embeds, past_key_values_length 1815 | ) 1816 | 1817 | # expand encoder attention mask 1818 | if encoder_hidden_states is not None and encoder_attention_mask is not None: 1819 | if self._use_flash_attention_2: 1820 | encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None 1821 | elif self._use_sdpa and cross_attn_head_mask is None and not output_attentions: 1822 | # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on 1823 | # the manual implementation that requires a 4D causal mask in all cases. 1824 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] 1825 | encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( 1826 | encoder_attention_mask, 1827 | inputs_embeds.dtype, 1828 | tgt_len=input_shape[-1], 1829 | ) 1830 | else: 1831 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] 1832 | encoder_attention_mask = _prepare_4d_attention_mask( 1833 | encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] 1834 | ) 1835 | 1836 | # embed positions 1837 | positions = self.embed_positions(input, past_key_values_length) 1838 | positions = positions.to(inputs_embeds.device) 1839 | 1840 | hidden_states = inputs_embeds + positions 1841 | hidden_states = self.layernorm_embedding(hidden_states) 1842 | 1843 | hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) 1844 | 1845 | if self.gradient_checkpointing and self.training: 1846 | if use_cache: 1847 | logger.warning_once( 1848 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 1849 | ) 1850 | use_cache = False 1851 | 1852 | # decoder layers 1853 | all_hidden_states = () if output_hidden_states else None 1854 | all_self_attns = () if output_attentions else None 1855 | all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None 1856 | next_decoder_cache = () if use_cache else None 1857 | 1858 | # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired 1859 | for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): 1860 | if attn_mask is not None: 1861 | if attn_mask.size()[0] != (len(self.layers)): 1862 | raise ValueError( 1863 | f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for" 1864 | f" {attn_mask.size()[0]}."
1865 | ) 1866 | 1867 | for idx, decoder_layer in enumerate(self.layers): 1868 | # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) 1869 | if output_hidden_states: 1870 | all_hidden_states += (hidden_states,) 1871 | if self.training: 1872 | dropout_probability = torch.rand([]) 1873 | if dropout_probability < self.layerdrop: 1874 | continue 1875 | 1876 | past_key_value = past_key_values[idx] if past_key_values is not None else None 1877 | 1878 | if self.gradient_checkpointing and self.training: 1879 | layer_outputs = self._gradient_checkpointing_func( 1880 | decoder_layer.__call__, 1881 | hidden_states, 1882 | attention_mask, 1883 | encoder_hidden_states, 1884 | encoder_attention_mask, 1885 | head_mask[idx] if head_mask is not None else None, 1886 | cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, 1887 | None, 1888 | output_attentions, 1889 | use_cache, 1890 | ) 1891 | else: 1892 | layer_outputs = decoder_layer( 1893 | hidden_states, 1894 | attention_mask=attention_mask, 1895 | encoder_hidden_states=encoder_hidden_states, 1896 | encoder_attention_mask=encoder_attention_mask, 1897 | layer_head_mask=(head_mask[idx] if head_mask is not None else None), 1898 | cross_attn_layer_head_mask=( 1899 | cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None 1900 | ), 1901 | past_key_value=past_key_value, 1902 | output_attentions=output_attentions, 1903 | use_cache=use_cache, 1904 | ) 1905 | hidden_states = layer_outputs[0] 1906 | 1907 | if use_cache: 1908 | next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) 1909 | 1910 | if output_attentions: 1911 | all_self_attns += (layer_outputs[1],) 1912 | 1913 | if encoder_hidden_states is not None: 1914 | all_cross_attentions += (layer_outputs[2],) 1915 | 1916 | # add hidden states from the last decoder layer 1917 | if output_hidden_states: 1918 | all_hidden_states += (hidden_states,) 1919 | 1920 | next_cache = next_decoder_cache if use_cache else None 1921 | if not return_dict: 1922 | return tuple( 1923 | v 1924 | for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] 1925 | if v is not None 1926 | ) 1927 | return BaseModelOutputWithPastAndCrossAttentions( 1928 | last_hidden_state=hidden_states, 1929 | past_key_values=next_cache, 1930 | hidden_states=all_hidden_states, 1931 | attentions=all_self_attns, 1932 | cross_attentions=all_cross_attentions, 1933 | ) 1934 | 1935 | 1936 | class Florence2LanguageModel(Florence2LanguagePreTrainedModel): 1937 | _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] 1938 | 1939 | def __init__(self, config: Florence2LanguageConfig): 1940 | super().__init__(config) 1941 | 1942 | padding_idx, vocab_size = config.pad_token_id, config.vocab_size 1943 | self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx) 1944 | 1945 | self.encoder = Florence2Encoder(config, self.shared) 1946 | self.decoder = Florence2Decoder(config, self.shared) 1947 | 1948 | # Initialize weights and apply final processing 1949 | self.post_init() 1950 | 1951 | def _tie_weights(self): 1952 | if self.config.tie_word_embeddings: 1953 | self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared) 1954 | self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared) 1955 | 1956 | def get_input_embeddings(self): 1957 | return self.shared 1958 | 1959 | def set_input_embeddings(self, value): 1960 | self.shared = value 1961 | self.encoder.embed_tokens = self.shared 1962 | self.decoder.embed_tokens = 
self.shared 1963 | 1964 | def get_encoder(self): 1965 | return self.encoder 1966 | 1967 | def get_decoder(self): 1968 | return self.decoder 1969 | 1970 | def forward( 1971 | self, 1972 | input_ids: torch.LongTensor = None, 1973 | attention_mask: Optional[torch.Tensor] = None, 1974 | decoder_input_ids: Optional[torch.LongTensor] = None, 1975 | decoder_attention_mask: Optional[torch.LongTensor] = None, 1976 | head_mask: Optional[torch.Tensor] = None, 1977 | decoder_head_mask: Optional[torch.Tensor] = None, 1978 | cross_attn_head_mask: Optional[torch.Tensor] = None, 1979 | encoder_outputs: Optional[List[torch.FloatTensor]] = None, 1980 | past_key_values: Optional[List[torch.FloatTensor]] = None, 1981 | inputs_embeds: Optional[torch.FloatTensor] = None, 1982 | decoder_inputs_embeds: Optional[torch.FloatTensor] = None, 1983 | use_cache: Optional[bool] = None, 1984 | output_attentions: Optional[bool] = None, 1985 | output_hidden_states: Optional[bool] = None, 1986 | return_dict: Optional[bool] = None, 1987 | ) -> Union[Tuple, Seq2SeqModelOutput]: 1988 | # different to other models, Florence2 automatically creates decoder_input_ids from 1989 | # input_ids if no decoder_input_ids are provided 1990 | if decoder_input_ids is None and decoder_inputs_embeds is None: 1991 | if input_ids is None: 1992 | raise ValueError( 1993 | "If no `decoder_input_ids` or `decoder_inputs_embeds` are " 1994 | "passed, `input_ids` cannot be `None`. Please pass either " 1995 | "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`." 1996 | ) 1997 | 1998 | decoder_input_ids = shift_tokens_right( 1999 | input_ids, self.config.pad_token_id, self.config.decoder_start_token_id 2000 | ) 2001 | 2002 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 2003 | output_hidden_states = ( 2004 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 2005 | ) 2006 | use_cache = use_cache if use_cache is not None else self.config.use_cache 2007 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 2008 | 2009 | if encoder_outputs is None: 2010 | encoder_outputs = self.encoder( 2011 | input_ids=input_ids, 2012 | attention_mask=attention_mask, 2013 | head_mask=head_mask, 2014 | inputs_embeds=inputs_embeds, 2015 | output_attentions=output_attentions, 2016 | output_hidden_states=output_hidden_states, 2017 | return_dict=return_dict, 2018 | ) 2019 | # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True 2020 | elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): 2021 | encoder_outputs = BaseModelOutput( 2022 | last_hidden_state=encoder_outputs[0], 2023 | hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, 2024 | attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, 2025 | ) 2026 | 2027 | # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) 2028 | decoder_outputs = self.decoder( 2029 | input_ids=decoder_input_ids, 2030 | attention_mask=decoder_attention_mask, 2031 | encoder_hidden_states=encoder_outputs[0], 2032 | encoder_attention_mask=attention_mask, 2033 | head_mask=decoder_head_mask, 2034 | cross_attn_head_mask=cross_attn_head_mask, 2035 | past_key_values=past_key_values, 2036 | inputs_embeds=decoder_inputs_embeds, 2037 | use_cache=use_cache, 2038 | output_attentions=output_attentions, 2039 | output_hidden_states=output_hidden_states, 2040 | return_dict=return_dict, 2041 | ) 
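# The decoder call above consumes encoder_outputs[0] as encoder_hidden_states for
# cross-attention and reuses the encoder-side `attention_mask` as its
# encoder_attention_mask. In the non-return_dict branch below, the decoder output
# tuple is simply concatenated with the encoder output tuple instead of being
# repackaged into a Seq2SeqModelOutput.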
2042 | 2043 | if not return_dict: 2044 | return decoder_outputs + encoder_outputs 2045 | 2046 | return Seq2SeqModelOutput( 2047 | last_hidden_state=decoder_outputs.last_hidden_state, 2048 | past_key_values=decoder_outputs.past_key_values, 2049 | decoder_hidden_states=decoder_outputs.hidden_states, 2050 | decoder_attentions=decoder_outputs.attentions, 2051 | cross_attentions=decoder_outputs.cross_attentions, 2052 | encoder_last_hidden_state=encoder_outputs.last_hidden_state, 2053 | encoder_hidden_states=encoder_outputs.hidden_states, 2054 | encoder_attentions=encoder_outputs.attentions, 2055 | ) 2056 | 2057 | 2058 | class Florence2LanguageForConditionalGeneration(Florence2LanguagePreTrainedModel, GenerationMixin): 2059 | base_model_prefix = "model" 2060 | _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"] 2061 | _keys_to_ignore_on_load_missing = ["final_logits_bias"] 2062 | 2063 | def __init__(self, config: Florence2LanguageConfig): 2064 | super().__init__(config) 2065 | self.model = Florence2LanguageModel(config) 2066 | self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings))) 2067 | self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False) 2068 | 2069 | # Initialize weights and apply final processing 2070 | self.post_init() 2071 | 2072 | def _tie_weights(self): 2073 | if self.config.tie_word_embeddings: 2074 | self._tie_or_clone_weights(self.model.encoder.embed_tokens, self.model.shared) 2075 | self._tie_or_clone_weights(self.model.decoder.embed_tokens, self.model.shared) 2076 | self._tie_or_clone_weights(self.lm_head, self.model.shared) 2077 | 2078 | def get_encoder(self): 2079 | return self.model.get_encoder() 2080 | 2081 | def get_decoder(self): 2082 | return self.model.get_decoder() 2083 | 2084 | def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding: 2085 | new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of) 2086 | self._resize_final_logits_bias(new_embeddings.weight.shape[0]) 2087 | return new_embeddings 2088 | 2089 | def _resize_final_logits_bias(self, new_num_tokens: int) -> None: 2090 | old_num_tokens = self.final_logits_bias.shape[-1] 2091 | if new_num_tokens <= old_num_tokens: 2092 | new_bias = self.final_logits_bias[:, :new_num_tokens] 2093 | else: 2094 | extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device) 2095 | new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1) 2096 | self.register_buffer("final_logits_bias", new_bias) 2097 | 2098 | def get_output_embeddings(self): 2099 | return self.lm_head 2100 | 2101 | def set_output_embeddings(self, new_embeddings): 2102 | self.lm_head = new_embeddings 2103 | 2104 | def forward( 2105 | self, 2106 | input_ids: torch.LongTensor = None, 2107 | attention_mask: Optional[torch.Tensor] = None, 2108 | decoder_input_ids: Optional[torch.LongTensor] = None, 2109 | decoder_attention_mask: Optional[torch.LongTensor] = None, 2110 | head_mask: Optional[torch.Tensor] = None, 2111 | decoder_head_mask: Optional[torch.Tensor] = None, 2112 | cross_attn_head_mask: Optional[torch.Tensor] = None, 2113 | encoder_outputs: Optional[List[torch.FloatTensor]] = None, 2114 | past_key_values: Optional[List[torch.FloatTensor]] = None, 2115 | inputs_embeds: Optional[torch.FloatTensor] = None, 2116 | decoder_inputs_embeds: Optional[torch.FloatTensor] = None, 2117 | labels: Optional[torch.LongTensor] = 
None, 2118 | use_cache: Optional[bool] = None, 2119 | output_attentions: Optional[bool] = None, 2120 | output_hidden_states: Optional[bool] = None, 2121 | return_dict: Optional[bool] = None, 2122 | ) -> Union[Tuple, Seq2SeqLMOutput]: 2123 | r""" 2124 | labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 2125 | Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., 2126 | config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored 2127 | (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 2128 | 2129 | Returns: 2130 | """ 2131 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 2132 | 2133 | if labels is not None: 2134 | if use_cache: 2135 | logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.") 2136 | use_cache = False 2137 | if decoder_input_ids is None and decoder_inputs_embeds is None: 2138 | decoder_input_ids = shift_tokens_right( 2139 | labels, self.config.pad_token_id, self.config.decoder_start_token_id 2140 | ) 2141 | 2142 | outputs = self.model( 2143 | input_ids, 2144 | attention_mask=attention_mask, 2145 | decoder_input_ids=decoder_input_ids, 2146 | encoder_outputs=encoder_outputs, 2147 | decoder_attention_mask=decoder_attention_mask, 2148 | head_mask=head_mask, 2149 | decoder_head_mask=decoder_head_mask, 2150 | cross_attn_head_mask=cross_attn_head_mask, 2151 | past_key_values=past_key_values, 2152 | inputs_embeds=inputs_embeds, 2153 | decoder_inputs_embeds=decoder_inputs_embeds, 2154 | use_cache=use_cache, 2155 | output_attentions=output_attentions, 2156 | output_hidden_states=output_hidden_states, 2157 | return_dict=return_dict, 2158 | ) 2159 | 2160 | lm_logits = self.lm_head(outputs[0]) 2161 | lm_logits = lm_logits + self.final_logits_bias.to(lm_logits.device) 2162 | 2163 | masked_lm_loss = None 2164 | if labels is not None: 2165 | labels = labels.to(lm_logits.device) 2166 | loss_fct = CrossEntropyLoss() 2167 | masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1)) 2168 | 2169 | if not return_dict: 2170 | output = (lm_logits,) + outputs[1:] 2171 | return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output 2172 | 2173 | return Seq2SeqLMOutput( 2174 | loss=masked_lm_loss, 2175 | logits=lm_logits, 2176 | past_key_values=outputs.past_key_values, 2177 | decoder_hidden_states=outputs.decoder_hidden_states, 2178 | decoder_attentions=outputs.decoder_attentions, 2179 | cross_attentions=outputs.cross_attentions, 2180 | encoder_last_hidden_state=outputs.encoder_last_hidden_state, 2181 | encoder_hidden_states=outputs.encoder_hidden_states, 2182 | encoder_attentions=outputs.encoder_attentions, 2183 | ) 2184 | 2185 | def prepare_inputs_for_generation( 2186 | self, 2187 | decoder_input_ids, 2188 | past_key_values=None, 2189 | attention_mask=None, 2190 | decoder_attention_mask=None, 2191 | head_mask=None, 2192 | decoder_head_mask=None, 2193 | cross_attn_head_mask=None, 2194 | use_cache=None, 2195 | encoder_outputs=None, 2196 | **kwargs, 2197 | ): 2198 | # cut decoder_input_ids if past_key_values is used 2199 | if past_key_values is not None: 2200 | past_length = past_key_values[0][0].shape[2] 2201 | 2202 | # Some generation methods already pass only the last input ID 2203 | if decoder_input_ids.shape[1] > past_length: 2204 | remove_prefix_length = past_length 2205 | else: 2206 | # Default to old behavior: keep only 
final ID 2207 | remove_prefix_length = decoder_input_ids.shape[1] - 1 2208 | 2209 | decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] 2210 | 2211 | return { 2212 | "input_ids": None, # encoder_outputs is defined. input_ids not needed 2213 | "encoder_outputs": encoder_outputs, 2214 | "past_key_values": past_key_values, 2215 | "decoder_input_ids": decoder_input_ids, 2216 | "attention_mask": attention_mask, 2217 | "decoder_attention_mask": decoder_attention_mask, 2218 | "head_mask": head_mask, 2219 | "decoder_head_mask": decoder_head_mask, 2220 | "cross_attn_head_mask": cross_attn_head_mask, 2221 | "use_cache": use_cache, # change this to avoid caching (presumably for debugging) 2222 | } 2223 | 2224 | def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): 2225 | return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) 2226 | 2227 | @staticmethod 2228 | def _reorder_cache(past_key_values, beam_idx): 2229 | reordered_past = () 2230 | for layer_past in past_key_values: 2231 | # cached cross_attention states don't have to be reordered -> they are always the same 2232 | reordered_past += ( 2233 | tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2]) 2234 | + layer_past[2:], 2235 | ) 2236 | return reordered_past 2237 | 2238 | @dataclass 2239 | class Florence2Seq2SeqLMOutput(ModelOutput): 2240 | """ 2241 | Base class for Florence-2 model's outputs that also contains : pre-computed hidden states that can speed up sequential 2242 | decoding. 2243 | 2244 | Args: 2245 | loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): 2246 | Language modeling loss. 2247 | logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): 2248 | Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). 2249 | last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): 2250 | Sequence of hidden-states at the output of the last layer of the decoder of the model. 2251 | 2252 | If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1, 2253 | hidden_size)` is output. 2254 | past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): 2255 | Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape 2256 | `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape 2257 | `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. 2258 | 2259 | Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention 2260 | blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. 2261 | decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): 2262 | Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + 2263 | one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. 2264 | 2265 | Hidden-states of the decoder at the output of each layer plus the optional initial embedding outputs. 
2266 | decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): 2267 | Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, 2268 | sequence_length)`. 2269 | 2270 | Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the 2271 | self-attention heads. 2272 | cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): 2273 | Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, 2274 | sequence_length)`. 2275 | 2276 | Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the 2277 | weighted average in the cross-attention heads. 2278 | encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): 2279 | Sequence of hidden-states at the output of the last layer of the encoder of the model. 2280 | encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): 2281 | Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + 2282 | one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. 2283 | 2284 | Hidden-states of the encoder at the output of each layer plus the optional initial embedding outputs. 2285 | encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): 2286 | Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, 2287 | sequence_length)`. 2288 | 2289 | Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the 2290 | self-attention heads. 2291 | image_hidden_states (`tuple(torch.FloatTensor)`, *optional*): 2292 | Tuple of `torch.FloatTensor` (one for the output of the image embeddings, `(batch_size, 2293 | num_image_tokens, hidden_size)`. 2294 | 2295 | image_hidden_states of the model produced by the vision encoder 2296 | """ 2297 | loss: Optional[torch.FloatTensor] = None 2298 | logits: torch.FloatTensor = None 2299 | last_hidden_state: torch.FloatTensor = None 2300 | past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None 2301 | decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None 2302 | decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None 2303 | cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None 2304 | encoder_last_hidden_state: Optional[torch.FloatTensor] = None 2305 | encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None 2306 | encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None 2307 | image_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None 2308 | 2309 | 2310 | FLORENCE2_START_DOCSTRING = r""" 2311 | This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the 2312 | library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads 2313 | etc.) 2314 | 2315 | This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. 
2316 | Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage 2317 | and behavior. 2318 | 2319 | Parameters: 2320 | config ([`Florence2Config`] or [`Florence2VisionConfig`]): 2321 | Model configuration class with all the parameters of the model. Initializing with a config file does not 2322 | load the weights associated with the model, only the configuration. Check out the 2323 | [`~PreTrainedModel.from_pretrained`] method to load the model weights. 2324 | """ 2325 | 2326 | 2327 | @add_start_docstrings( 2328 | "The bare Florence-2 Model outputting raw hidden-states without any specific head on top.", 2329 | FLORENCE2_START_DOCSTRING, 2330 | ) 2331 | class Florence2PreTrainedModel(PreTrainedModel): 2332 | config_class = Florence2Config 2333 | base_model_prefix = "model" 2334 | supports_gradient_checkpointing = True 2335 | _skip_keys_device_placement = "past_key_values" 2336 | 2337 | @property 2338 | def _supports_flash_attn_2(self): 2339 | """ 2340 | Retrieve language_model's attribute to check whether the model supports 2341 | Flash Attention 2 or not. 2342 | """ 2343 | if hasattr(self, 'language_model') and self.language_model is not None: 2344 | return self.language_model._supports_flash_attn_2 2345 | return True # Default to True during initialization 2346 | 2347 | @property 2348 | def _supports_sdpa(self): 2349 | """ 2350 | Retrieve language_model's attribute to check whether the model supports 2351 | SDPA or not. 2352 | """ 2353 | if hasattr(self, 'language_model') and self.language_model is not None: 2354 | return self.language_model._supports_sdpa 2355 | return True # Default to True during initialization 2356 | 2357 | 2358 | FLORENCE2_INPUTS_DOCSTRING = r""" 2359 | Args: 2360 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): 2361 | Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide 2362 | it. 2363 | 2364 | Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and 2365 | [`PreTrainedTokenizer.__call__`] for details. 2366 | 2367 | [What are input IDs?](../glossary#input-ids) 2368 | pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)): 2369 | The tensors corresponding to the input images. Pixel values can be obtained using 2370 | [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details ([]`Florence2Processor`] uses 2371 | [`CLIPImageProcessor`] for processing images). 2372 | attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): 2373 | Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: 2374 | 2375 | - 1 for tokens that are **not masked**, 2376 | - 0 for tokens that are **masked**. 2377 | 2378 | [What are attention masks?](../glossary#attention-mask) 2379 | 2380 | Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and 2381 | [`PreTrainedTokenizer.__call__`] for details. 2382 | 2383 | If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see 2384 | `past_key_values`). 2385 | 2386 | If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] 2387 | and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more 2388 | information on the default strategy. 
2389 | 2390 | - 1 indicates the head is **not masked**, 2391 | - 0 indicates the head is **masked**. 2392 | position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 2393 | Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, 2394 | config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) 2395 | past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): 2396 | Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape 2397 | `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape 2398 | `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. 2399 | 2400 | Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention 2401 | blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. 2402 | 2403 | If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that 2404 | don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all 2405 | `decoder_input_ids` of shape `(batch_size, sequence_length)`. 2406 | inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): 2407 | Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This 2408 | is useful if you want more control over how to convert `input_ids` indices into associated vectors than the 2409 | model's internal embedding lookup matrix. 2410 | use_cache (`bool`, *optional*): 2411 | If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see 2412 | `past_key_values`). 2413 | output_attentions (`bool`, *optional*): 2414 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned 2415 | tensors for more detail. 2416 | output_hidden_states (`bool`, *optional*): 2417 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for 2418 | more detail. 2419 | return_dict (`bool`, *optional*): 2420 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
2421 | """ 2422 | 2423 | @add_start_docstrings( 2424 | """The FLORENCE2 vision model without any head""", 2425 | FLORENCE2_START_DOCSTRING, 2426 | ) 2427 | class Florence2VisionModel(Florence2PreTrainedModel): 2428 | def __init__(self, config: Florence2VisionConfig): 2429 | super().__init__(config) 2430 | assert config.model_type == 'davit', 'only DaViT is supported for now' 2431 | self.vision_tower = DaViT.from_config(config=config) 2432 | 2433 | self.post_init() 2434 | 2435 | def forward(self, pixel_values): 2436 | if len(pixel_values.shape) == 4: 2437 | x = self.vision_tower.forward_features_unpool(pixel_values) 2438 | else: 2439 | raise ValueError(f'invalid image shape {pixel_values.shape}') 2440 | return x 2441 | 2442 | 2443 | @add_start_docstrings( 2444 | """The FLORENCE2 vision model with projection layer""", 2445 | FLORENCE2_START_DOCSTRING, 2446 | ) 2447 | class Florence2VisionModelWithProjection(Florence2PreTrainedModel): 2448 | def __init__(self, config: Florence2VisionConfig): 2449 | super().__init__(config) 2450 | assert config.model_type == 'davit', 'only DaViT is supported for now' 2451 | self.vision_tower = DaViT.from_config(config=config) 2452 | 2453 | self._build_image_projection_layers(config) 2454 | 2455 | self.post_init() 2456 | 2457 | def _build_image_projection_layers(self, config): 2458 | image_dim_out = config.dim_embed[-1] 2459 | dim_projection = config.projection_dim 2460 | self.image_projection = nn.Parameter( 2461 | torch.empty(image_dim_out, dim_projection) 2462 | ) 2463 | self.image_proj_norm = nn.LayerNorm(dim_projection) 2464 | image_pos_embed_config = config.image_pos_embed 2465 | if image_pos_embed_config['type'] == 'learned_abs_2d': 2466 | self.image_pos_embed = LearnedAbsolutePositionEmbedding2D( 2467 | embedding_dim=image_dim_out, 2468 | num_pos=image_pos_embed_config['max_pos_embeddings'] 2469 | ) 2470 | else: 2471 | raise NotImplementedError('Not implemented yet') 2472 | 2473 | self.image_feature_source = config.image_feature_source 2474 | 2475 | # temporal embedding 2476 | visual_temporal_embedding_config = config.visual_temporal_embedding 2477 | if visual_temporal_embedding_config['type'] == 'COSINE': 2478 | self.visual_temporal_embed = PositionalEmbeddingCosine1D( 2479 | embed_dim=image_dim_out, 2480 | max_seq_len=visual_temporal_embedding_config['max_temporal_embeddings'] 2481 | ) 2482 | else: 2483 | raise NotImplementedError('Not implemented yet') 2484 | 2485 | def forward(self, pixel_values): 2486 | if len(pixel_values.shape) == 4: 2487 | batch_size, C, H, W = pixel_values.shape 2488 | T = 1 2489 | x = self.vision_tower.forward_features_unpool(pixel_values) 2490 | else: 2491 | raise ValueError(f'invalid image shape {pixel_values.shape}') 2492 | 2493 | if self.image_pos_embed is not None: 2494 | x = x.view(batch_size * T, -1, x.shape[-1]) 2495 | num_tokens = x.shape[-2] 2496 | h, w = int(num_tokens ** 0.5), int(num_tokens ** 0.5) 2497 | assert h * w == num_tokens, 'only support square feature maps for now' 2498 | x = x.view(batch_size * T, h, w, x.shape[-1]) 2499 | pos_embed = self.image_pos_embed(x) 2500 | x = x + pos_embed 2501 | x = x.view(batch_size, T * h*w, x.shape[-1]) 2502 | 2503 | if self.visual_temporal_embed is not None: 2504 | visual_temporal_embed = self.visual_temporal_embed(x.view(batch_size, T, -1, x.shape[-1])[:, :, 0]) 2505 | x = x.view(batch_size, T, -1, x.shape[-1]) + visual_temporal_embed.view(1, T, 1, x.shape[-1]) 2506 | 2507 | x_feat_dict = {} 2508 | 2509 | spatial_avg_pool_x = x.view(batch_size, T, -1, 
x.shape[-1]).mean(dim=2) 2510 | x_feat_dict['spatial_avg_pool'] = spatial_avg_pool_x 2511 | 2512 | temporal_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=1) 2513 | x_feat_dict['temporal_avg_pool'] = temporal_avg_pool_x 2514 | 2515 | x = x.view(batch_size, T, -1, x.shape[-1])[:, -1] 2516 | x_feat_dict['last_frame'] = x 2517 | 2518 | new_x = [] 2519 | for _image_feature_source in self.image_feature_source: 2520 | if _image_feature_source not in x_feat_dict: 2521 | raise ValueError('invalid image feature source: {}'.format(_image_feature_source)) 2522 | new_x.append(x_feat_dict[_image_feature_source]) 2523 | 2524 | x = torch.cat(new_x, dim=1) 2525 | 2526 | x = x @ self.image_projection 2527 | x = self.image_proj_norm(x) 2528 | 2529 | 2530 | return x 2531 | 2532 | 2533 | 2534 | @add_start_docstrings( 2535 | """The FLORENCE2 model which consists of a vision backbone and a language model.""", 2536 | FLORENCE2_START_DOCSTRING, 2537 | ) 2538 | class Florence2ForConditionalGeneration(Florence2PreTrainedModel, GenerationMixin): 2539 | _tied_weights_keys = ["language_model.encoder.embed_tokens.weight", "language_model.decoder.embed_tokens.weight", "language_model.lm_head.weight"] 2540 | def __init__(self, config: Florence2Config): 2541 | super().__init__(config) 2542 | assert config.vision_config.model_type == 'davit', 'only DaViT is supported for now' 2543 | self.vision_tower = DaViT.from_config(config=config.vision_config) 2544 | # remove unused layers 2545 | del self.vision_tower.head 2546 | del self.vision_tower.norms 2547 | 2548 | self.vocab_size = config.vocab_size 2549 | self._attn_implementation = config._attn_implementation 2550 | self._build_image_projection_layers(config) 2551 | 2552 | language_model = Florence2LanguageForConditionalGeneration(config=config.text_config) 2553 | 2554 | self.language_model = language_model 2555 | 2556 | self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1 2557 | self.post_init() 2558 | 2559 | def _build_image_projection_layers(self, config): 2560 | image_dim_out = config.vision_config.dim_embed[-1] 2561 | dim_projection = config.vision_config.projection_dim 2562 | self.image_projection = nn.Parameter( 2563 | torch.empty(image_dim_out, dim_projection) 2564 | ) 2565 | self.image_proj_norm = nn.LayerNorm(dim_projection) 2566 | image_pos_embed_config = config.vision_config.image_pos_embed 2567 | if image_pos_embed_config['type'] == 'learned_abs_2d': 2568 | self.image_pos_embed = LearnedAbsolutePositionEmbedding2D( 2569 | embedding_dim=image_dim_out, 2570 | num_pos=image_pos_embed_config['max_pos_embeddings'] 2571 | ) 2572 | else: 2573 | raise NotImplementedError('Not implemented yet') 2574 | 2575 | self.image_feature_source = config.vision_config.image_feature_source 2576 | 2577 | # temporal embedding 2578 | visual_temporal_embedding_config = config.vision_config.visual_temporal_embedding 2579 | if visual_temporal_embedding_config['type'] == 'COSINE': 2580 | self.visual_temporal_embed = PositionalEmbeddingCosine1D( 2581 | embed_dim=image_dim_out, 2582 | max_seq_len=visual_temporal_embedding_config['max_temporal_embeddings'] 2583 | ) 2584 | else: 2585 | raise NotImplementedError('Not implemented yet') 2586 | 2587 | def get_encoder(self): 2588 | return self.language_model.get_encoder() 2589 | 2590 | def get_decoder(self): 2591 | return self.language_model.get_decoder() 2592 | 2593 | def get_input_embeddings(self): 2594 | return self.language_model.get_input_embeddings() 2595 | 2596 | def 
resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding: 2597 | model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of) 2598 | # update vocab size 2599 | self.config.text_config.vocab_size = model_embeds.num_embeddings 2600 | self.config.vocab_size = model_embeds.num_embeddings 2601 | self.vocab_size = model_embeds.num_embeddings 2602 | return model_embeds 2603 | 2604 | def _encode_image(self, pixel_values): 2605 | if len(pixel_values.shape) == 4: 2606 | batch_size, C, H, W = pixel_values.shape 2607 | T = 1 2608 | x = self.vision_tower.forward_features_unpool(pixel_values) 2609 | else: 2610 | raise ValueError(f'invalid image shape {pixel_values.shape}') 2611 | 2612 | if self.image_pos_embed is not None: 2613 | x = x.view(batch_size * T, -1, x.shape[-1]) 2614 | num_tokens = x.shape[-2] 2615 | h, w = int(num_tokens ** 0.5), int(num_tokens ** 0.5) 2616 | assert h * w == num_tokens, 'only support square feature maps for now' 2617 | x = x.view(batch_size * T, h, w, x.shape[-1]) 2618 | pos_embed = self.image_pos_embed(x) 2619 | x = x + pos_embed 2620 | x = x.view(batch_size, T * h*w, x.shape[-1]) 2621 | 2622 | if self.visual_temporal_embed is not None: 2623 | visual_temporal_embed = self.visual_temporal_embed(x.view(batch_size, T, -1, x.shape[-1])[:, :, 0]) 2624 | x = x.view(batch_size, T, -1, x.shape[-1]) + visual_temporal_embed.view(1, T, 1, x.shape[-1]) 2625 | 2626 | x_feat_dict = {} 2627 | 2628 | spatial_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=2) 2629 | x_feat_dict['spatial_avg_pool'] = spatial_avg_pool_x 2630 | 2631 | temporal_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=1) 2632 | x_feat_dict['temporal_avg_pool'] = temporal_avg_pool_x 2633 | 2634 | x = x.view(batch_size, T, -1, x.shape[-1])[:, -1] 2635 | x_feat_dict['last_frame'] = x 2636 | 2637 | new_x = [] 2638 | for _image_feature_source in self.image_feature_source: 2639 | if _image_feature_source not in x_feat_dict: 2640 | raise ValueError('invalid image feature source: {}'.format(_image_feature_source)) 2641 | new_x.append(x_feat_dict[_image_feature_source]) 2642 | 2643 | x = torch.cat(new_x, dim=1) 2644 | 2645 | x = x @ self.image_projection 2646 | x = self.image_proj_norm(x) 2647 | 2648 | return x 2649 | 2650 | def _merge_input_ids_with_image_features( 2651 | self, image_features, inputs_embeds 2652 | ): 2653 | batch_size, image_token_length = image_features.size()[:-1] 2654 | device = image_features.device 2655 | image_attention_mask = torch.ones(batch_size, image_token_length, device=device) 2656 | 2657 | # task_prefix_embeds: [batch_size, padded_context_length, hidden_size] 2658 | # task_prefix_attention_mask: [batch_size, context_length] 2659 | if inputs_embeds is None: 2660 | return image_features, image_attention_mask 2661 | 2662 | task_prefix_embeds = inputs_embeds 2663 | task_prefix_attention_mask = torch.ones(batch_size, task_prefix_embeds.size(1), device=device) 2664 | 2665 | if len(task_prefix_attention_mask.shape) == 3: 2666 | task_prefix_attention_mask = task_prefix_attention_mask[:, 0] 2667 | 2668 | # concat [image embeds, task prefix embeds] 2669 | inputs_embeds = torch.cat([image_features, task_prefix_embeds], dim=1) 2670 | attention_mask = torch.cat([image_attention_mask, task_prefix_attention_mask], dim=1) 2671 | 2672 | return inputs_embeds, attention_mask 2673 | 2674 | 2675 | @add_start_docstrings_to_model_forward(FLORENCE2_INPUTS_DOCSTRING) 2676 | 
@replace_return_docstrings(output_type=Florence2Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) 2677 | def forward( 2678 | self, 2679 | input_ids: torch.LongTensor = None, 2680 | pixel_values: torch.FloatTensor = None, 2681 | attention_mask: Optional[torch.Tensor] = None, 2682 | decoder_input_ids: Optional[torch.LongTensor] = None, 2683 | decoder_attention_mask: Optional[torch.LongTensor] = None, 2684 | head_mask: Optional[torch.Tensor] = None, 2685 | decoder_head_mask: Optional[torch.Tensor] = None, 2686 | cross_attn_head_mask: Optional[torch.Tensor] = None, 2687 | encoder_outputs: Optional[List[torch.FloatTensor]] = None, 2688 | past_key_values: Optional[List[torch.FloatTensor]] = None, 2689 | inputs_embeds: Optional[torch.FloatTensor] = None, 2690 | decoder_inputs_embeds: Optional[torch.FloatTensor] = None, 2691 | labels: Optional[torch.LongTensor] = None, 2692 | use_cache: Optional[bool] = None, 2693 | output_attentions: Optional[bool] = None, 2694 | output_hidden_states: Optional[bool] = None, 2695 | return_dict: Optional[bool] = None, 2696 | ) -> Union[Tuple, Florence2Seq2SeqLMOutput]: 2697 | r""" 2698 | Args: 2699 | labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 2700 | Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., 2701 | config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored 2702 | (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 2703 | 2704 | Returns: 2705 | 2706 | Example: 2707 | 2708 | ```python 2709 | >>> from PIL import Image 2710 | >>> import requests 2711 | >>> from transformers import AutoProcessor, Florence2ForConditionalGeneration 2712 | 2713 | >>> model = Florence2ForConditionalGeneration.from_pretrained("microsoft/Florence-2-large") 2714 | >>> processor = AutoProcessor.from_pretrained("microsoft/Florence-2-large") 2715 | 2716 | >>> prompt = "" 2717 | >>> url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg" 2718 | >>> image = Image.open(requests.get(url, stream=True).raw) 2719 | 2720 | >>> inputs = processor(text=prompt, images=image, return_tensors="pt") 2721 | 2722 | >>> # Generate 2723 | >>> generate_ids = model.generate(**inputs, max_length=100) 2724 | >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] 2725 | "A green car parked in front of a yellow building." 2726 | ```""" 2727 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 2728 | output_hidden_states = ( 2729 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 2730 | ) 2731 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 2732 | 2733 | image_features = None 2734 | if inputs_embeds is None: 2735 | # 1. Extra the input embeddings 2736 | if input_ids is not None: 2737 | inputs_embeds = self.get_input_embeddings()(input_ids) 2738 | # 2. 
Merge text and images 2739 | if pixel_values is not None: 2740 | # (batch_size, num_image_tokens, hidden_size) 2741 | image_features = self._encode_image(pixel_values) 2742 | inputs_embeds, attention_mask = self._merge_input_ids_with_image_features(image_features, inputs_embeds) 2743 | 2744 | if inputs_embeds is not None: 2745 | attention_mask = attention_mask.to(inputs_embeds.dtype) 2746 | outputs = self.language_model( 2747 | attention_mask=attention_mask, 2748 | labels=labels, 2749 | inputs_embeds=inputs_embeds, 2750 | decoder_input_ids=decoder_input_ids, 2751 | encoder_outputs=encoder_outputs, 2752 | decoder_attention_mask=decoder_attention_mask, 2753 | head_mask=head_mask, 2754 | decoder_head_mask=decoder_head_mask, 2755 | cross_attn_head_mask=cross_attn_head_mask, 2756 | past_key_values=past_key_values, 2757 | decoder_inputs_embeds=decoder_inputs_embeds, 2758 | use_cache=use_cache, 2759 | output_attentions=output_attentions, 2760 | output_hidden_states=output_hidden_states, 2761 | return_dict=return_dict, 2762 | ) 2763 | 2764 | logits = outputs.logits 2765 | logits = logits.float() 2766 | loss = outputs.loss 2767 | if not return_dict: 2768 | output = (logits,) + outputs[1:] 2769 | return (loss,) + output if loss is not None else output 2770 | 2771 | return Florence2Seq2SeqLMOutput( 2772 | loss=loss, 2773 | logits=logits, 2774 | past_key_values=outputs.past_key_values, 2775 | decoder_hidden_states=outputs.decoder_hidden_states, 2776 | decoder_attentions=outputs.decoder_attentions, 2777 | cross_attentions=outputs.cross_attentions, 2778 | encoder_last_hidden_state=outputs.encoder_last_hidden_state, 2779 | encoder_hidden_states=outputs.encoder_hidden_states, 2780 | encoder_attentions=outputs.encoder_attentions, 2781 | image_hidden_states=image_features 2782 | ) 2783 | 2784 | def generate( 2785 | self, 2786 | input_ids, 2787 | inputs_embeds=None, 2788 | pixel_values=None, 2789 | **kwargs 2790 | ): 2791 | 2792 | if inputs_embeds is None: 2793 | # 1. Extra the input embeddings 2794 | if input_ids is not None: 2795 | inputs_embeds = self.get_input_embeddings()(input_ids) 2796 | # 2. Merge text and images 2797 | if pixel_values is not None: 2798 | image_features = self._encode_image(pixel_values) 2799 | inputs_embeds, attention_mask = self._merge_input_ids_with_image_features(image_features, inputs_embeds) 2800 | 2801 | return self.language_model.generate( 2802 | input_ids=None, 2803 | inputs_embeds=inputs_embeds, 2804 | **kwargs 2805 | ) 2806 | 2807 | def prepare_inputs_for_generation( 2808 | self, 2809 | decoder_input_ids, 2810 | past_key_values=None, 2811 | attention_mask=None, 2812 | pixel_values=None, 2813 | decoder_attention_mask=None, 2814 | head_mask=None, 2815 | decoder_head_mask=None, 2816 | cross_attn_head_mask=None, 2817 | use_cache=None, 2818 | encoder_outputs=None, 2819 | **kwargs, 2820 | ): 2821 | # cut decoder_input_ids if past_key_values is used 2822 | if past_key_values is not None: 2823 | past_length = past_key_values[0][0].shape[2] 2824 | 2825 | # Some generation methods already pass only the last input ID 2826 | if decoder_input_ids.shape[1] > past_length: 2827 | remove_prefix_length = past_length 2828 | else: 2829 | # Default to old behavior: keep only final ID 2830 | remove_prefix_length = decoder_input_ids.shape[1] - 1 2831 | 2832 | decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] 2833 | 2834 | return { 2835 | "input_ids": None, # encoder_outputs is defined. 
input_ids not needed 2836 | "encoder_outputs": encoder_outputs, 2837 | "past_key_values": past_key_values, 2838 | "decoder_input_ids": decoder_input_ids, 2839 | "attention_mask": attention_mask, 2840 | "pixel_values": pixel_values, 2841 | "decoder_attention_mask": decoder_attention_mask, 2842 | "head_mask": head_mask, 2843 | "decoder_head_mask": decoder_head_mask, 2844 | "cross_attn_head_mask": cross_attn_head_mask, 2845 | "use_cache": use_cache, # change this to avoid caching (presumably for debugging) 2846 | } 2847 | 2848 | def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): 2849 | return self.language_model.prepare_decoder_input_ids_from_labels(labels) 2850 | 2851 | def _reorder_cache(self, *args, **kwargs): 2852 | return self.language_model._reorder_cache(*args, **kwargs) --------------------------------------------------------------------------------
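For reference, a minimal usage sketch of this vendored modeling code outside ComfyUI, assuming the standard `microsoft/Florence-2-base` checkpoint; the `<CAPTION>` task token and `post_process_generation` come from that checkpoint's processor rather than from this file:

```python
import requests
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor

model_id = "microsoft/Florence-2-base"  # assumed checkpoint; any repo shipping this modeling code should behave the same
device = "cuda" if torch.cuda.is_available() else "cpu"

# trust_remote_code=True makes transformers load Florence2ForConditionalGeneration
# and the matching processor from the checkpoint's repository.
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True).to(device)
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)

url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg"
image = Image.open(requests.get(url, stream=True).raw)

task = "<CAPTION>"  # task token interpreted by the checkpoint's processor
inputs = processor(text=task, images=image, return_tensors="pt").to(device)

# generate() embeds the text prompt, prepends the projected image tokens via
# _merge_input_ids_with_image_features, then delegates to the language model.
generated_ids = model.generate(
    input_ids=inputs["input_ids"],
    pixel_values=inputs["pixel_values"],
    max_new_tokens=128,
    num_beams=3,
)
text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
print(processor.post_process_generation(text, task=task, image_size=image.size))
```

Decoding with `skip_special_tokens=False` keeps the task and location tokens in the raw string so the processor's post-processing step can parse them.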