├── mgm ├── __init__.py ├── model │ ├── __init__.py │ ├── consolidate.py │ ├── multimodal_projector │ │ └── builder.py │ ├── multimodal_encoder │ │ ├── builder.py │ │ ├── clip_encoder.py │ │ └── openclip_encoder.py │ ├── processor │ │ └── video_processor.py │ ├── language_model │ │ ├── mgm_mistral.py │ │ ├── mgm_mixtral.py │ │ └── mgm_gemma.py │ └── builder.py ├── constants.py ├── utils.py └── mm_utils.py ├── images ├── flmm_chat_vis.jpg ├── flmm_pipeline.jpg └── flmm_visual_cot.jpg ├── segment_anything ├── utils │ ├── __init__.py │ ├── transforms.py │ └── onnx.py ├── modeling │ ├── __init__.py │ ├── common.py │ ├── mask_decoder.py │ ├── sam.py │ └── transformer.py ├── __init__.py └── build_sam.py ├── scripts ├── deepspeed2torch_state_dict.py ├── demo │ ├── visualize_ranks.py │ ├── utils.py │ ├── grounded_conversation.py │ └── multiprocess_infer_png.py ├── visual_cot │ ├── gpt_eval_cot_score_single.py │ ├── gpt_eval_cot_score.py │ └── visual_cot_inference.py ├── multiprocess_eval_png.py └── multiprocess_eval_refcoco.py ├── flmm ├── utils.py ├── datasets │ ├── pad2square_processor.py │ ├── transforms.py │ └── llava_processors.py ├── models │ └── mask_head │ │ ├── mask_decoder.py │ │ └── mask_refiner.py └── runner.py ├── deepseek_vl ├── utils │ ├── __init__.py │ └── io.py ├── models │ ├── __init__.py │ ├── projector.py │ ├── modeling_vlm.py │ ├── image_processing_vlm.py │ └── clip_encoder.py └── __init__.py ├── LICENSE └── .gitignore /mgm/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import MGMLlamaForCausalLM 2 | -------------------------------------------------------------------------------- /images/flmm_chat_vis.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wusize/F-LMM/HEAD/images/flmm_chat_vis.jpg -------------------------------------------------------------------------------- /images/flmm_pipeline.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wusize/F-LMM/HEAD/images/flmm_pipeline.jpg -------------------------------------------------------------------------------- /images/flmm_visual_cot.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wusize/F-LMM/HEAD/images/flmm_visual_cot.jpg -------------------------------------------------------------------------------- /segment_anything/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /mgm/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .language_model.mgm_llama import MGMLlamaForCausalLM 2 | try: 3 | from .language_model.mgm_mistral import MGMMistralForCausalLM 4 | from .language_model.mgm_mixtral import MGMMixtralForCausalLM 5 | from .language_model.mgm_gemma import MGMGemmaForCausalLM 6 | except: 7 | ImportWarning("New model not imported. 
Try to update Transformers.") -------------------------------------------------------------------------------- /segment_anything/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .sam import Sam 8 | from .image_encoder import ImageEncoderViT 9 | from .mask_decoder import MaskDecoder 10 | from .prompt_encoder import PromptEncoder 11 | from .transformer import TwoWayTransformer 12 | -------------------------------------------------------------------------------- /segment_anything/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .build_sam import ( 8 | build_sam, 9 | build_sam_vit_h, 10 | build_sam_vit_l, 11 | build_sam_vit_b, 12 | sam_model_registry, 13 | ) 14 | from .predictor import SamPredictor 15 | from .automatic_mask_generator import SamAutomaticMaskGenerator 16 | -------------------------------------------------------------------------------- /scripts/deepspeed2torch_state_dict.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | from xtuner.model.utils import guess_load_checkpoint 4 | 5 | if __name__ == '__main__': 6 | parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter) 7 | parser.add_argument("--deepspeed_path", default='', type=str) 8 | parser.add_argument("--torch_path", default='', type=str) 9 | 10 | args = parser.parse_args() 11 | state_dict = guess_load_checkpoint(args.deepspeed_path) 12 | torch.save(state_dict, args.torch_path) 13 | -------------------------------------------------------------------------------- /mgm/model/consolidate.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m mgm.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate 4 | """ 5 | import argparse 6 | 7 | import torch 8 | from transformers import AutoTokenizer, AutoModelForCausalLM 9 | from mgm.model import * 10 | from mgm.model.utils import auto_upgrade 11 | 12 | 13 | def consolidate_ckpt(src_path, dst_path): 14 | print("Loading model") 15 | auto_upgrade(src_path) 16 | src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False) 18 | src_model.save_pretrained(dst_path) 19 | src_tokenizer.save_pretrained(dst_path) 20 | 21 | 22 | if __name__ == "__main__": 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument("--src", type=str, required=True) 25 | parser.add_argument("--dst", type=str, required=True) 26 | 27 | args = parser.parse_args() 28 | 29 | consolidate_ckpt(args.src, args.dst) 30 | -------------------------------------------------------------------------------- /mgm/constants.py: -------------------------------------------------------------------------------- 1 | CONTROLLER_HEART_BEAT_EXPIRATION = 30 2 | WORKER_HEART_BEAT_INTERVAL = 15 3 | 4 | LOGDIR = "." 
5 | 6 | # Model Constants 7 | IGNORE_INDEX = -100 8 | IMAGE_TOKEN_INDEX = -200 9 | PREDICT_TOKEN_INDEX = -300 10 | DEFAULT_IMAGE_TOKEN = "" 11 | DEFAULT_IMAGE_PATCH_TOKEN = "" 12 | DEFAULT_IM_START_TOKEN = "" 13 | DEFAULT_IM_END_TOKEN = "" 14 | IMAGE_PLACEHOLDER = "" 15 | DEFAULT_PREDICT_TOKEN = "" 16 | 17 | DESCRIPT_PROMPT = [ 18 | "Describe this image thoroughly.", 19 | "Provide a detailed description in this picture.", 20 | "Detail every aspect of what's in this picture.", 21 | "Explain this image with precision and detail.", 22 | "Give a comprehensive description of this visual.", 23 | "Elaborate on the specifics within this image.", 24 | "Offer a detailed account of this picture's contents.", 25 | "Describe in detail what this image portrays.", 26 | "Break down this image into detailed descriptions.", 27 | "Provide a thorough description of the elements in this image."] -------------------------------------------------------------------------------- /flmm/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from functools import partial 3 | from six.moves import map, zip 4 | 5 | 6 | @torch.no_grad() 7 | def compute_mask_IoU(masks, target): 8 | temp = masks * target 9 | intersection = temp.sum(dim=-1) 10 | union = ((masks + target) - temp).sum(dim=-1) 11 | return intersection / (union + 1e-12) 12 | 13 | 14 | def multi_apply(func, *args, **kwargs): 15 | """Apply function to a list of arguments. 16 | 17 | Note: 18 | This function applies the ``func`` to multiple inputs and 19 | map the multiple outputs of the ``func`` into different 20 | list. Each list contains the same type of outputs corresponding 21 | to different inputs. 22 | 23 | Args: 24 | func (Function): A function that will be applied to a list of 25 | arguments 26 | 27 | Returns: 28 | tuple(list): A tuple containing multiple list, each list contains \ 29 | a kind of returned results by the function 30 | """ 31 | pfunc = partial(func, **kwargs) if kwargs else func 32 | map_results = map(pfunc, *args) 33 | return tuple(map(list, zip(*map_results))) 34 | -------------------------------------------------------------------------------- /deepseek_vl/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023-2024 DeepSeek. 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to 6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 7 | # the Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all 11 | # copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 15 | # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 16 | # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 17 | # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 18 | # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
19 | -------------------------------------------------------------------------------- /deepseek_vl/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023-2024 DeepSeek. 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to 6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 7 | # the Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all 11 | # copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 15 | # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 16 | # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 17 | # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 18 | # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 19 | 20 | from .image_processing_vlm import VLMImageProcessor 21 | from .modeling_vlm import MultiModalityCausalLM 22 | from .processing_vlm import VLChatProcessor 23 | 24 | __all__ = [ 25 | "VLMImageProcessor", 26 | "VLChatProcessor", 27 | "MultiModalityCausalLM", 28 | ] 29 | -------------------------------------------------------------------------------- /deepseek_vl/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023-2024 DeepSeek. 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to 6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 7 | # the Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all 11 | # copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 15 | # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 16 | # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 17 | # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 18 | # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
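# Older dependencies (e.g. attrdict, imported in deepseek_vl/models/projector.py) still pull
# ABCs such as Mapping directly from `collections`; Python 3.10 removed those aliases, so the
# patch below re-exports them from collections.abc.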
19 | 20 | 21 | # check if python version is above 3.10 22 | import sys 23 | 24 | if sys.version_info >= (3, 10): 25 | print("Python version is above 3.10, patching the collections module.") 26 | # Monkey patch collections 27 | import collections 28 | import collections.abc 29 | 30 | for type_name in collections.abc.__all__: 31 | setattr(collections, type_name, getattr(collections.abc, type_name)) 32 | -------------------------------------------------------------------------------- /scripts/demo/visualize_ranks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | deepseek_vl_1_3b = torch.tensor([1307+225, 64.6, 34.8, 51.1, 75.0, 62.8, 68.2, 64.9, 63.4, 68.3]) 4 | mgm_2b = torch.tensor([1341+312, 59.8, 31.1, 65.9, 75.0, 63.7, 67.3, 65.6, 64.4, 68.4]) 5 | llava_1_5_7b = torch.tensor([1511+348, 64.3, 30.5, 69.0, 75.2, 63.7, 67.1, 64.8, 63.4, 68.2]) 6 | hpt_air_7b = torch.tensor([1010+258, 69.8, 31.3, 59.2, 74.3, 64.0, 67.5, 65.5, 64.0, 68.8]) 7 | hpt_air_1_5_8b = torch.tensor([1476+308, 75.2, 36.3, 62.1, 76.3, 64.5, 68.5, 65.4, 64.1, 68.5]) 8 | mgm_7b = torch.tensor([1523+316, 69.3, 40.8, 75.8, 75.7, 64.8, 68.3, 66.3, 65.3, 68.6]) 9 | deepseek_vl_7b = torch.tensor([1468+298, 73.2, 41.5, 77.8, 76.1, 66.4, 70.1, 65.7, 64.5, 68.5]) 10 | llava_1_6_7b = torch.tensor([1519/322, 68.1, 44.1, 72.3, 75.8, 65.8, 70.1, 66.3, 65.1, 69.0]) 11 | llava_1_6_m_7b = torch.tensor([1501+324, 69.5, 47.8, 71.7, 75.7, 66.5, 70.1, 66.5, 65.4, 69.1]) 12 | mgm_hd_7b = torch.tensor([1546+319, 65.8, 41.3, 74.0, 76.1, 65.2, 68.5, 66.7, 65.6, 69.1]) 13 | 14 | all_scores = torch.stack([deepseek_vl_1_3b, mgm_2b, 15 | llava_1_5_7b, hpt_air_7b, hpt_air_1_5_8b, mgm_7b, deepseek_vl_7b, 16 | llava_1_6_7b, llava_1_6_m_7b, mgm_hd_7b]) 17 | 18 | all_ranks = torch.sort(-all_scores, dim=0).indices 19 | 20 | 21 | qa_ranks = all_ranks[:, :4] 22 | seg_ranks = all_ranks[:, 4:] 23 | 24 | 25 | qa_ave_ranks = qa_ranks.float().mean(dim=-1) 26 | seg_ave_ranks = seg_ranks.float().mean(dim=-1) 27 | 28 | 29 | import matplotlib.pyplot as plt 30 | plt.scatter((10 - qa_ave_ranks).tolist(), (10 - seg_ave_ranks).tolist()) 31 | plt.savefig('ave_ranks.jpg') 32 | 33 | -------------------------------------------------------------------------------- /mgm/model/multimodal_projector/builder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import re 4 | 5 | class IdentityMap(nn.Module): 6 | def __init__(self): 7 | super().__init__() 8 | 9 | def forward(self, x, *args, **kwargs): 10 | return x 11 | 12 | @property 13 | def config(self): 14 | return {"mm_projector_type": 'identity'} 15 | 16 | 17 | class SimpleResBlock(nn.Module): 18 | def __init__(self, channels): 19 | super().__init__() 20 | self.pre_norm = nn.LayerNorm(channels) 21 | 22 | self.proj = nn.Sequential( 23 | nn.Linear(channels, channels), 24 | nn.GELU(), 25 | nn.Linear(channels, channels) 26 | ) 27 | def forward(self, x): 28 | x = self.pre_norm(x) 29 | return x + self.proj(x) 30 | 31 | 32 | def build_vision_projector(config, delay_load=False, **kwargs): 33 | projector_type = getattr(config, 'mm_projector_type', 'linear') 34 | 35 | if projector_type == 'linear': 36 | return nn.Linear(config.mm_hidden_size, config.hidden_size) 37 | 38 | mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type) 39 | if mlp_gelu_match: 40 | mlp_depth = int(mlp_gelu_match.group(1)) 41 | modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)] 42 | for _ in range(1, 
mlp_depth): 43 | modules.append(nn.GELU()) 44 | modules.append(nn.Linear(config.hidden_size, config.hidden_size)) 45 | return nn.Sequential(*modules) 46 | 47 | if projector_type == 'identity': 48 | return IdentityMap() 49 | 50 | raise ValueError(f'Unknown projector type: {projector_type}') -------------------------------------------------------------------------------- /segment_anything/modeling/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | from typing import Type 11 | 12 | 13 | class MLPBlock(nn.Module): 14 | def __init__( 15 | self, 16 | embedding_dim: int, 17 | mlp_dim: int, 18 | act: Type[nn.Module] = nn.GELU, 19 | ) -> None: 20 | super().__init__() 21 | self.lin1 = nn.Linear(embedding_dim, mlp_dim) 22 | self.lin2 = nn.Linear(mlp_dim, embedding_dim) 23 | self.act = act() 24 | 25 | def forward(self, x: torch.Tensor) -> torch.Tensor: 26 | x = self.lin1(x) 27 | x = self.act(x) 28 | x = self.lin2(x) 29 | return x 30 | # return self.lin2(self.act(self.lin1(x))) 31 | 32 | 33 | # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa 34 | # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa 35 | class LayerNorm2d(nn.Module): 36 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None: 37 | super().__init__() 38 | self.weight = nn.Parameter(torch.ones(num_channels)) 39 | self.bias = nn.Parameter(torch.zeros(num_channels)) 40 | self.eps = eps 41 | 42 | def forward(self, x: torch.Tensor) -> torch.Tensor: 43 | u = x.mean(1, keepdim=True) 44 | s = (x - u).pow(2).mean(1, keepdim=True) 45 | x = (x - u) / torch.sqrt(s + self.eps) 46 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 47 | return x 48 | -------------------------------------------------------------------------------- /mgm/model/multimodal_encoder/builder.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .clip_encoder import CLIPVisionTower 3 | from .eva_encoder import EVAVisionTower 4 | from .openclip_encoder import OpenCLIPVisionTower 5 | 6 | 7 | def build_vision_tower(vision_tower_cfg, **kwargs): 8 | vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None)) 9 | image_processor = getattr(vision_tower_cfg, 'image_processor', getattr(vision_tower_cfg, 'image_processor', "../processor/clip-patch14-224")) 10 | 11 | if not os.path.exists(vision_tower): 12 | raise ValueError(f'Not find vision tower: {vision_tower}') 13 | 14 | if "openai" in vision_tower.lower() or "ShareGPT4V" in vision_tower: 15 | return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) 16 | elif "lavis" in vision_tower.lower() or "eva" in vision_tower.lower(): 17 | return EVAVisionTower(vision_tower, image_processor, args=vision_tower_cfg, **kwargs) 18 | else: 19 | raise ValueError(f'Unknown vision tower: {vision_tower}') 20 | 21 | 22 | def build_vision_tower_aux(vision_tower_cfg, **kwargs): 23 | vision_tower_aux = getattr(vision_tower_cfg, 'mm_vision_tower_aux', getattr(vision_tower_cfg, 'vision_tower_aux', None)) 24 | 25 | if not os.path.exists(vision_tower_aux): 26 | raise ValueError(f'Not 
find vision tower: {vision_tower_aux}') 27 | 28 | if "openclip" in vision_tower_aux.lower() or 'convnext' in vision_tower_aux.lower(): 29 | return OpenCLIPVisionTower(vision_tower_aux, args=vision_tower_cfg, **kwargs) 30 | elif "openai" in vision_tower_aux.lower(): 31 | return CLIPVisionTower(vision_tower_aux, args=vision_tower_cfg, **kwargs) 32 | else: 33 | raise ValueError(f'Unknown vision tower: {vision_tower_aux}') -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | S-Lab License 1.0 2 | 3 | Copyright 2022 S-Lab 4 | 5 | Redistribution and use for non-commercial purpose in source and 6 | binary forms, with or without modification, are permitted provided 7 | that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright 10 | notice, this list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright 13 | notice, this list of conditions and the following disclaimer in 14 | the documentation and/or other materials provided with the 15 | distribution. 16 | 17 | 3. Neither the name of the copyright holder nor the names of its 18 | contributors may be used to endorse or promote products derived 19 | from this software without specific prior written permission. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 22 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 23 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 24 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 25 | HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 26 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 27 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 28 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 29 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 30 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 31 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 32 | 33 | In the event that redistribution and/or use for commercial purpose in 34 | source or binary forms, with or without modification is required, 35 | please contact the contributor(s) of the work. 
36 | -------------------------------------------------------------------------------- /flmm/datasets/pad2square_processor.py: -------------------------------------------------------------------------------- 1 | from transformers.image_processing_utils import BatchFeature 2 | import numpy as np 3 | from PIL import Image 4 | from mmengine.logging import print_log 5 | 6 | 7 | class Pad2Square: 8 | def __init__(self, image_mean=(0.48145466, 0.4578275, 0.40821073)): 9 | if not isinstance(image_mean[0], int): 10 | image_mean = tuple(int(x * 255) for x in image_mean) 11 | print_log(f"image_mean: {image_mean}") 12 | self.image_mean = image_mean 13 | 14 | def preprocess(self, image, return_tensors=None): 15 | image = image.convert('RGB') 16 | 17 | width, height = image.size 18 | if width == height: 19 | result = image 20 | before_height = after_height = before_width = after_width = 0 21 | elif width > height: 22 | result = Image.new(image.mode, (width, width), self.image_mean) 23 | result.paste(image, (0, (width - height) // 2)) 24 | before_height = (width - height) // 2 25 | after_height = (width - height) - before_height 26 | before_width = after_width = 0 27 | else: 28 | result = Image.new(image.mode, (height, height), self.image_mean) 29 | result.paste(image, ((height - width) // 2, 0)) 30 | # return result 31 | before_width = (height - width) // 2 32 | after_width = (height - width) - before_width 33 | before_height = after_height = 0 34 | 35 | meta = dict(padding=dict(before_height=before_height, after_height=after_height, 36 | before_width=before_width, after_width=after_width), 37 | image_shape=dict(height=height, width=width), 38 | padded_shape=dict(height=max(height, width), width=max(height, width))) 39 | 40 | data = {"pixel_values": [result], "image_sizes": [(height, width)], "meta_datas": [meta]} 41 | 42 | return BatchFeature(data=data, tensor_type=return_tensors) 43 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/*/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | # custom 107 | data/ 108 | data 109 | .vscode 110 | .idea 111 | .DS_Store 112 | *.pkl 113 | *.pkl.json 114 | *.log.json 115 | work_dirs/ 116 | 117 | # Pytorch 118 | *.pth 119 | *.py~ 120 | *.sh~ 121 | 122 | # srun 123 | *.out 124 | batchscript-* 125 | .idea/*.xml 126 | .idea/ 127 | *.pdf 128 | checkpoints/ 129 | -------------------------------------------------------------------------------- /flmm/models/mask_head/mask_decoder.py: -------------------------------------------------------------------------------- 1 | import math 2 | import types 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from mmseg.models import UNet 6 | from mmseg.models.utils.wrappers import Upsample, resize 7 | from mmengine.logging import print_log 8 | 9 | 10 | def upsample_forward_func(self, x): 11 | dtype = x.dtype 12 | x = x.float() 13 | if not self.size: 14 | size = [int(t * self.scale_factor) for t in x.shape[-2:]] 15 | else: 16 | size = self.size 17 | return resize(x, size, None, self.mode, self.align_corners).to(dtype) 18 | 19 | 20 | class UNetHead(UNet): 21 | def __init__(self, upsample_input=None, 22 | normalize_input=False, 23 | *args, **kwargs): 24 | super().__init__(*args, **kwargs) 25 | self.conv_seg = nn.Conv2d(self.base_channels, 1, kernel_size=1) 26 | 27 | for module in self.modules(): 28 | if isinstance(module, Upsample): 29 | print_log("Replace upsample forward function") 30 | module.forward = types.MethodType(upsample_forward_func, module) 31 | 32 | self.init_weights() 33 | self.upsample_input = upsample_input 34 | self.normalize_input = normalize_input 35 | 36 | @property 37 | def dtype(self): 38 | return self.conv_seg.weight.dtype 39 | 40 | def forward(self, x): 41 | h, w = x.shape[-2:] 42 | if self.normalize_input: 43 | assert x.min() >= 0.0 and x.max() <= 1.0 44 | x_sum = x.sum((-2, -1), keepdims=True).clamp(min=1e-12) 45 | x = x / x_sum 46 | 47 | if self.upsample_input is not None: 48 | scale_factor = max(1.0, self.upsample_input / max(h, w)) 49 | x = F.interpolate(x.float(), scale_factor=scale_factor, mode='bilinear').to(x) 50 | h, w = x.shape[-2:] # upsample the low-res input to get better results 51 | 52 | dividend = 2**(self.num_stages - 1) 53 | padded_h = math.ceil(h / dividend) * dividend 54 | padded_w = math.ceil(w / dividend) * dividend 55 | 56 | padded_x = x.new_zeros(*x.shape[:2], padded_h, padded_w) 57 | padded_x[..., :h, :w] = x 58 | x = 
super().forward(padded_x)[-1][..., :h, :w] 59 | return self.conv_seg(x) 60 | -------------------------------------------------------------------------------- /scripts/demo/utils.py: -------------------------------------------------------------------------------- 1 | colors = [(220, 20, 60), (119, 11, 32), (0, 0, 142), (0, 0, 230), (106, 0, 228), 2 | (0, 60, 100), (0, 80, 100), (0, 0, 70), (0, 0, 192), (250, 170, 30), 3 | (100, 170, 30), (220, 220, 0), (175, 116, 175), (250, 0, 30), 4 | (165, 42, 42), (255, 77, 255), (0, 226, 252), (182, 182, 255), 5 | (0, 82, 0), (120, 166, 157), (110, 76, 0), (174, 57, 255), 6 | (199, 100, 0), (72, 0, 118), (255, 179, 240), (0, 125, 92), 7 | (209, 0, 151), (188, 208, 182), (0, 220, 176), (255, 99, 164), 8 | (92, 0, 73), (133, 129, 255), (78, 180, 255), (0, 228, 0), 9 | (174, 255, 243), (45, 89, 255), (134, 134, 103), (145, 148, 174), 10 | (255, 208, 186), (197, 226, 255), (171, 134, 1), (109, 63, 54), 11 | (207, 138, 255), (151, 0, 95), (9, 80, 61), (84, 105, 51), 12 | (74, 65, 105), (166, 196, 102), (208, 195, 210), (255, 109, 65), 13 | (0, 143, 149), (179, 0, 194), (209, 99, 106), (5, 121, 0), 14 | (227, 255, 205), (147, 186, 208), (153, 69, 1), (3, 95, 161), 15 | (163, 255, 0), (119, 0, 170), (0, 182, 199), (0, 165, 120), 16 | (183, 130, 88), (95, 32, 0), (130, 114, 135), (110, 129, 133), 17 | (166, 74, 118), (219, 142, 185), (79, 210, 114), (178, 90, 62), 18 | (65, 70, 15), (127, 167, 115), (59, 105, 106), (142, 108, 45), 19 | (196, 172, 0), (95, 54, 80), (128, 76, 255), (201, 57, 1), 20 | (246, 0, 122), (191, 162, 208), (255, 255, 128), (147, 211, 203), 21 | (150, 100, 100), (168, 171, 172), (146, 112, 198), (210, 170, 100), 22 | (92, 136, 89), (218, 88, 184), (241, 129, 0), (217, 17, 255), 23 | (124, 74, 181), (70, 70, 70), (255, 228, 255), (154, 208, 0), 24 | (193, 0, 92), (76, 91, 113), (255, 180, 195), (106, 154, 176), 25 | (230, 150, 140), (60, 143, 255), (128, 64, 128), (92, 82, 55), 26 | (254, 212, 124), (73, 77, 174), (255, 160, 98), (255, 255, 255), 27 | (104, 84, 109), (169, 164, 131), (225, 199, 255), (137, 54, 74), 28 | (135, 158, 223), (7, 246, 231), (107, 255, 200), (58, 41, 149), 29 | (183, 121, 142), (255, 73, 97), (107, 142, 35), (190, 153, 153), 30 | (146, 139, 141), (70, 130, 180), (134, 199, 156), (209, 226, 140), 31 | (96, 36, 108), (96, 96, 96), (64, 170, 64), (152, 251, 152), 32 | (208, 229, 228), (206, 186, 171), (152, 161, 64), (116, 112, 0), 33 | (0, 114, 143), (102, 102, 156), (250, 141, 255)] 34 | -------------------------------------------------------------------------------- /mgm/model/processor/video_processor.py: -------------------------------------------------------------------------------- 1 | from transformers import CLIPImageProcessor 2 | from transformers.image_processing_utils import BatchFeature, get_size_dict 3 | from transformers.image_transforms import get_resize_output_image_size 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | 8 | import numpy as np 9 | 10 | 11 | class VideoFramesProcessor(CLIPImageProcessor): 12 | 13 | def __init__(self, **kwargs): 14 | super().__init__(**kwargs) 15 | 16 | def preprocess(self, images, **kwargs): 17 | if not isinstance(images, np.ndarray): 18 | return super().preprocess(images=images, **kwargs) 19 | 20 | do_resize = kwargs.get('do_resize', self.do_resize) 21 | size = kwargs.get('size', self.size) 22 | size = get_size_dict(size, param_name="size", default_to_square=False) 23 | do_center_crop = kwargs.get('do_center_crop', self.do_center_crop) 24 | crop_size = 
kwargs.get('crop_size', self.crop_size) 25 | crop_size = get_size_dict(crop_size, param_name="crop_size", default_to_square=True) 26 | do_rescale = kwargs.get('do_rescale', self.do_rescale) 27 | rescale_factor = kwargs.get('rescale_factor', self.rescale_factor) 28 | do_normalize = kwargs.get('do_normalize', self.do_normalize) 29 | image_mean = kwargs.get('image_mean', self.image_mean) 30 | image_std = kwargs.get('image_std', self.image_std) 31 | return_tensors = kwargs.get('return_tensors', None) 32 | 33 | def resize(images, output_size): 34 | images = images.permute((0, 3, 1, 2)) 35 | images = F.interpolate(images, size=output_size, mode='bicubic') 36 | images = images.permute((0, 2, 3, 1)) 37 | return images 38 | 39 | def center_crop(images, crop_size): 40 | crop_width, crop_height = crop_size["width"], crop_size["height"] 41 | img_width, img_height = images.shape[1:3] 42 | x = (img_width - crop_width) // 2 43 | y = (img_height - crop_height) // 2 44 | images = images[:, x:x+crop_width, y:y+crop_height] 45 | return images 46 | 47 | def rescale(images, rescale_factor): 48 | images = images * rescale_factor 49 | return images 50 | 51 | def normalize(images, mean, std): 52 | mean = torch.tensor(mean) 53 | std = torch.tensor(std) 54 | images = (images - mean) / std 55 | return images 56 | 57 | images = torch.from_numpy(images).float() 58 | 59 | if do_resize: 60 | output_size = get_resize_output_image_size(images[0], size=size["shortest_edge"], default_to_square=False) 61 | images = resize(images, output_size) 62 | 63 | if do_center_crop: 64 | images = center_crop(images, crop_size) 65 | 66 | if do_rescale: 67 | images = rescale(images, rescale_factor) 68 | 69 | if do_normalize: 70 | images = normalize(images, image_mean, image_std) 71 | 72 | images = images.permute((0, 3, 1, 2)) 73 | data = {"pixel_values": images} 74 | return BatchFeature(data=data, tensor_type=return_tensors) 75 | -------------------------------------------------------------------------------- /segment_anything/build_sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
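# Factory helpers for assembling SAM models with ViT-H/L/B image encoders; sam_model_registry
# below maps size names to the corresponding builder functions.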
6 | 7 | import torch 8 | 9 | from functools import partial 10 | 11 | from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer 12 | 13 | 14 | def build_sam_vit_h(checkpoint=None): 15 | return _build_sam( 16 | encoder_embed_dim=1280, 17 | encoder_depth=32, 18 | encoder_num_heads=16, 19 | encoder_global_attn_indexes=[7, 15, 23, 31], 20 | checkpoint=checkpoint, 21 | ) 22 | 23 | 24 | build_sam = build_sam_vit_h 25 | 26 | 27 | def build_sam_vit_l(checkpoint=None): 28 | return _build_sam( 29 | encoder_embed_dim=1024, 30 | encoder_depth=24, 31 | encoder_num_heads=16, 32 | encoder_global_attn_indexes=[5, 11, 17, 23], 33 | checkpoint=checkpoint, 34 | ) 35 | 36 | 37 | def build_sam_vit_b(checkpoint=None): 38 | return _build_sam( 39 | encoder_embed_dim=768, 40 | encoder_depth=12, 41 | encoder_num_heads=12, 42 | encoder_global_attn_indexes=[2, 5, 8, 11], 43 | checkpoint=checkpoint, 44 | ) 45 | 46 | 47 | sam_model_registry = { 48 | "default": build_sam_vit_h, 49 | "vit_h": build_sam_vit_h, 50 | "vit_l": build_sam_vit_l, 51 | "vit_b": build_sam_vit_b, 52 | } 53 | 54 | 55 | def _build_sam( 56 | encoder_embed_dim, 57 | encoder_depth, 58 | encoder_num_heads, 59 | encoder_global_attn_indexes, 60 | checkpoint=None, 61 | ): 62 | prompt_embed_dim = 256 63 | image_size = 1024 64 | vit_patch_size = 16 65 | image_embedding_size = image_size // vit_patch_size 66 | sam = Sam( 67 | image_encoder=ImageEncoderViT( 68 | depth=encoder_depth, 69 | embed_dim=encoder_embed_dim, 70 | img_size=image_size, 71 | mlp_ratio=4, 72 | norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), 73 | num_heads=encoder_num_heads, 74 | patch_size=vit_patch_size, 75 | qkv_bias=True, 76 | use_rel_pos=True, 77 | global_attn_indexes=encoder_global_attn_indexes, 78 | window_size=14, 79 | out_chans=prompt_embed_dim, 80 | ), 81 | prompt_encoder=PromptEncoder( 82 | embed_dim=prompt_embed_dim, 83 | image_embedding_size=(image_embedding_size, image_embedding_size), 84 | input_image_size=(image_size, image_size), 85 | mask_in_chans=16, 86 | ), 87 | mask_decoder=MaskDecoder( 88 | num_multimask_outputs=3, 89 | transformer=TwoWayTransformer( 90 | depth=2, 91 | embedding_dim=prompt_embed_dim, 92 | mlp_dim=2048, 93 | num_heads=8, 94 | ), 95 | transformer_dim=prompt_embed_dim, 96 | iou_head_depth=3, 97 | iou_head_hidden_dim=256, 98 | ), 99 | pixel_mean=[123.675, 116.28, 103.53], 100 | pixel_std=[58.395, 57.12, 57.375], 101 | ) 102 | sam.eval() 103 | if checkpoint is not None: 104 | with open(checkpoint, "rb") as f: 105 | state_dict = torch.load(f) 106 | sam.load_state_dict(state_dict) 107 | return sam 108 | -------------------------------------------------------------------------------- /mgm/model/multimodal_encoder/clip_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig 5 | from ..processor.video_processor import VideoFramesProcessor 6 | 7 | class CLIPVisionTower(nn.Module): 8 | def __init__(self, vision_tower, args, delay_load=False): 9 | super().__init__() 10 | 11 | self.is_loaded = False 12 | 13 | self.vision_tower_name = vision_tower 14 | self.select_layer = args.mm_vision_select_layer 15 | self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch') 16 | self.is_optimize = getattr(args, 'optimize_vision_tower', False) 17 | 18 | if not delay_load: 19 | self.load_model() 20 | elif getattr(args, 'unfreeze_mm_vision_tower', False): 21 | 
self.load_model() 22 | else: 23 | self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name) 24 | 25 | def load_model(self): 26 | self.image_processor = VideoFramesProcessor.from_pretrained(self.vision_tower_name) 27 | self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name) 28 | self.vision_tower.requires_grad_(False) 29 | 30 | self.is_loaded = True 31 | 32 | def feature_select(self, image_forward_outs): 33 | image_features = image_forward_outs.hidden_states[self.select_layer] 34 | if self.select_feature == 'patch': 35 | image_features = image_features[:, 1:] 36 | elif self.select_feature == 'cls_patch': 37 | image_features = image_features 38 | else: 39 | raise ValueError(f'Unexpected select feature: {self.select_feature}') 40 | return image_features 41 | 42 | def image_forward(self, images): 43 | if type(images) is list: 44 | image_features = [] 45 | for image in images: 46 | image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True) 47 | image_feature = self.feature_select(image_forward_out).to(image.dtype) 48 | image_features.append(image_feature) 49 | else: 50 | image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True) 51 | image_features = self.feature_select(image_forward_outs).to(images.dtype) 52 | 53 | return image_features 54 | 55 | def forward(self, images): 56 | if not self.is_optimize: 57 | with torch.no_grad(): 58 | image_features = self.image_forward(images) 59 | else: 60 | image_features = self.image_forward(images) 61 | 62 | return image_features 63 | 64 | @property 65 | def dummy_feature(self): 66 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) 67 | 68 | @property 69 | def dtype(self): 70 | return self.vision_tower.dtype 71 | 72 | @property 73 | def device(self): 74 | return self.vision_tower.device 75 | 76 | @property 77 | def config(self): 78 | if self.is_loaded: 79 | return self.vision_tower.config 80 | else: 81 | return self.cfg_only 82 | 83 | @property 84 | def hidden_size(self): 85 | return self.config.hidden_size 86 | 87 | @property 88 | def num_patches(self): 89 | return (self.config.image_size // self.config.patch_size) ** 2 -------------------------------------------------------------------------------- /deepseek_vl/utils/io.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023-2024 DeepSeek. 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to 6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 7 | # the Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all 11 | # copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 15 | # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR 16 | # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 17 | # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 18 | # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 19 | 20 | import json 21 | from typing import Dict, List 22 | 23 | import PIL.Image 24 | import torch 25 | import base64 26 | import io 27 | from transformers import AutoModelForCausalLM 28 | 29 | from deepseek_vl.models import MultiModalityCausalLM, VLChatProcessor 30 | 31 | 32 | def load_pretrained_model(model_path: str): 33 | vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path) 34 | tokenizer = vl_chat_processor.tokenizer 35 | 36 | vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained( 37 | model_path, trust_remote_code=True 38 | ) 39 | vl_gpt = vl_gpt.to(torch.bfloat16).cuda().eval() 40 | 41 | return tokenizer, vl_chat_processor, vl_gpt 42 | 43 | 44 | def load_pil_images(conversations: List[Dict[str, str]]) -> List[PIL.Image.Image]: 45 | """ 46 | 47 | Support file path or base64 images. 48 | 49 | Args: 50 | conversations (List[Dict[str, str]]): the conversations with a list of messages. An example is : 51 | [ 52 | { 53 | "role": "User", 54 | "content": "\nExtract all information from this image and convert them into markdown format.", 55 | "images": ["./examples/table_datasets.png"] 56 | }, 57 | {"role": "Assistant", "content": ""}, 58 | ] 59 | 60 | Returns: 61 | pil_images (List[PIL.Image.Image]): the list of PIL images. 62 | 63 | """ 64 | 65 | pil_images = [] 66 | 67 | for message in conversations: 68 | if "images" not in message: 69 | continue 70 | 71 | for image_data in message["images"]: 72 | if image_data.startswith("data:image"): 73 | # Image data is in base64 format 74 | _, image_data = image_data.split(",", 1) 75 | image_bytes = base64.b64decode(image_data) 76 | pil_img = PIL.Image.open(io.BytesIO(image_bytes)) 77 | else: 78 | # Image data is a file path 79 | pil_img = PIL.Image.open(image_data) 80 | pil_img = pil_img.convert("RGB") 81 | pil_images.append(pil_img) 82 | 83 | return pil_images 84 | 85 | 86 | def load_json(filepath): 87 | with open(filepath, "r") as f: 88 | data = json.load(f) 89 | return data 90 | -------------------------------------------------------------------------------- /scripts/visual_cot/gpt_eval_cot_score_single.py: -------------------------------------------------------------------------------- 1 | import os 2 | import openai 3 | import time 4 | from tqdm import tqdm 5 | import json 6 | import argparse 7 | import re 8 | import requests 9 | 10 | 11 | BASE_PROMPT = """ 12 | You are responsible for proofreading the answers, you need to give a score to the model's answer by referring to the standard answer, based on the given question. The full score is 1 point and the minimum score is 0 points. Please output the score in the form "score: ". The evaluation criteria require that the closer the model's answer is to the standard answer, the higher the score. 
13 | """ 14 | 15 | PROMPT = """ 16 | question: %s 17 | standard answer: %s 18 | model's answer: %s 19 | """ 20 | 21 | API_KEY = os.environ['OPENAI_API_KEY'] 22 | GPT_EVAL_MODEL_NAME = "gpt-3.5-turbo-1106" 23 | API_TYPE = os.getenv("API_TYPE", "openai") 24 | 25 | API_URL = os.getenv("OPENAI_API_URL", "https://api.openai.com/v1/chat/completions") 26 | # API_KEY = os.getenv("OPENAI_API_KEY", "YOUR_API_KEY") 27 | headers = { 28 | "Authorization": f"Bearer {API_KEY}", 29 | "Content-Type": "application/json", 30 | } 31 | 32 | 33 | def get_eval(content: str, max_tokens=100, retries: int = 5): 34 | global headers 35 | messages = [ 36 | { 37 | "role": "system", 38 | "content": BASE_PROMPT, 39 | }, 40 | {"role": "user", "content": content}, 41 | ] 42 | 43 | payload = { 44 | "model": GPT_EVAL_MODEL_NAME, 45 | "messages": messages, 46 | "temperature": 0.2, 47 | "max_tokens": max_tokens, 48 | } 49 | 50 | for attempt in range(retries): 51 | try: 52 | response = requests.post(API_URL, headers=headers, json=payload, timeout=60) 53 | response.raise_for_status() 54 | response_data = response.json() 55 | 56 | content = response_data["choices"][0]["message"]["content"] 57 | if content != "": 58 | return content 59 | break # If successful, break out of the loop 60 | 61 | except Exception as e: 62 | print(f"Attempt {attempt + 1} failed with error: {e}") 63 | if attempt < retries: # If we have retries left, sleep and then continue to next attempt 64 | time.sleep(5) 65 | else: # If this was the last attempt, log and return empty 66 | print(f"All {retries} attempts failed. Last error message: {e}") 67 | return "" 68 | return "" 69 | 70 | 71 | def get_score(question_text, gt_answer_text, pred_answer_text): 72 | content = PROMPT % (question_text, gt_answer_text, pred_answer_text) 73 | ret = get_eval(content) 74 | ret = ret.lower() 75 | if 'score' not in ret: 76 | return 0.0 77 | res = re.findall(r'score: ([\d\.]+)', ret) 78 | if len(res) != 1: 79 | return 0.0 80 | res = float(res[0]) 81 | if res > 1.0: 82 | res = 1 83 | if res < 0.0: 84 | res = 0 85 | time.sleep(1) # sleep for 1 second after a successful request to avoid high frequency 86 | return res 87 | 88 | 89 | if __name__ == "__main__": 90 | trial = get_eval('Who are you?') 91 | parser = argparse.ArgumentParser() 92 | parser.add_argument("--result_file", type=str) 93 | args = parser.parse_args() 94 | result_file = args.result_file 95 | print(f"Processing {result_file}", flush=True) 96 | scores = [] 97 | with open(result_file, 'r') as f: 98 | data = json.load(f) 99 | for data_sample in tqdm(data): 100 | score = get_score(data_sample['question'], data_sample['gt'], data_sample['answer']) 101 | scores.append(score) 102 | data_sample['score'] = score 103 | 104 | print(f'The avg score on {result_file} is: {sum(scores) / len(scores)}', flush=True) 105 | 106 | with open(result_file, 'w') as f: 107 | json.dump(data, f, indent=4) 108 | -------------------------------------------------------------------------------- /deepseek_vl/models/projector.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023-2024 DeepSeek. 
2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to 6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 7 | # the Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all 11 | # copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 15 | # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 16 | # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 17 | # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 18 | # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 19 | 20 | from typing import Tuple, Union 21 | 22 | import torch 23 | import torch.nn as nn 24 | from attrdict import AttrDict 25 | 26 | 27 | class MlpProjector(nn.Module): 28 | def __init__(self, cfg): 29 | super().__init__() 30 | 31 | self.cfg = cfg 32 | 33 | if cfg.projector_type == "identity": 34 | modules = nn.Identity() 35 | 36 | elif cfg.projector_type == "linear": 37 | modules = nn.Linear(cfg.input_dim, cfg.n_embed) 38 | 39 | elif cfg.projector_type == "mlp_gelu": 40 | mlp_depth = cfg.get("depth", 1) 41 | modules = [nn.Linear(cfg.input_dim, cfg.n_embed)] 42 | for _ in range(1, mlp_depth): 43 | modules.append(nn.GELU()) 44 | modules.append(nn.Linear(cfg.n_embed, cfg.n_embed)) 45 | modules = nn.Sequential(*modules) 46 | 47 | elif cfg.projector_type == "low_high_hybrid_split_mlp_gelu": 48 | mlp_depth = cfg.get("depth", 1) 49 | self.high_up_proj = nn.Linear(cfg.input_dim, cfg.n_embed // 2) 50 | self.low_up_proj = nn.Linear(cfg.input_dim, cfg.n_embed // 2) 51 | 52 | modules = [] 53 | for _ in range(1, mlp_depth): 54 | modules.append(nn.GELU()) 55 | modules.append(nn.Linear(cfg.n_embed, cfg.n_embed)) 56 | modules = nn.Sequential(*modules) 57 | 58 | else: 59 | raise ValueError(f"Unknown projector type: {cfg.projector_type}") 60 | 61 | self.layers = modules 62 | 63 | def forward( 64 | self, x_or_tuple: Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor] 65 | ): 66 | """ 67 | 68 | Args: 69 | x_or_tuple (Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]: if it is a tuple of torch.Tensor, 70 | then it comes from the hybrid vision encoder, and x = high_res_x, low_res_x); 71 | otherwise it is the feature from the single vision encoder. 
72 | 73 | Returns: 74 | x (torch.Tensor): [b, s, c] 75 | """ 76 | 77 | if isinstance(x_or_tuple, tuple): 78 | # self.cfg.projector_type == "low_high_hybrid_split_mlp_gelu": 79 | high_x, low_x = x_or_tuple 80 | high_x = self.high_up_proj(high_x) 81 | low_x = self.low_up_proj(low_x) 82 | x = torch.concat([high_x, low_x], dim=-1) 83 | else: 84 | x = x_or_tuple 85 | 86 | return self.layers(x) 87 | 88 | 89 | if __name__ == "__main__": 90 | cfg = AttrDict( 91 | input_dim=1024, 92 | n_embed=2048, 93 | depth=2, 94 | projector_type="low_high_hybrid_split_mlp_gelu", 95 | ) 96 | inputs = (torch.rand(4, 576, 1024), torch.rand(4, 576, 1024)) 97 | 98 | m = MlpProjector(cfg) 99 | out = m(inputs) 100 | print(out.shape) 101 | -------------------------------------------------------------------------------- /scripts/visual_cot/gpt_eval_cot_score.py: -------------------------------------------------------------------------------- 1 | import os 2 | import openai 3 | import time 4 | from tqdm import tqdm 5 | import json 6 | import argparse 7 | import re 8 | import requests 9 | from glob import glob 10 | 11 | 12 | BASE_PROMPT = """ 13 | You are responsible for proofreading the answers, you need to give a score to the model's answer by referring to the standard answer, based on the given question. The full score is 1 point and the minimum score is 0 points. Please output the score in the form "score: ". The evaluation criteria require that the closer the model's answer is to the standard answer, the higher the score. 14 | """ 15 | 16 | PROMPT = """ 17 | question: %s 18 | standard answer: %s 19 | model's answer: %s 20 | """ 21 | 22 | API_KEY = os.environ['OPENAI_API_KEY'] 23 | GPT_EVAL_MODEL_NAME = "gpt-3.5-turbo-1106" 24 | API_TYPE = os.getenv("API_TYPE", "openai") 25 | 26 | API_URL = os.getenv("OPENAI_API_URL", "https://api.openai.com/v1/chat/completions") 27 | # API_KEY = os.getenv("OPENAI_API_KEY", "YOUR_API_KEY") 28 | headers = { 29 | "Authorization": f"Bearer {API_KEY}", 30 | "Content-Type": "application/json", 31 | } 32 | 33 | 34 | def get_eval(content: str, max_tokens=100, retries: int = 5): 35 | global headers 36 | messages = [ 37 | { 38 | "role": "system", 39 | "content": BASE_PROMPT, 40 | }, 41 | {"role": "user", "content": content}, 42 | ] 43 | 44 | payload = { 45 | "model": GPT_EVAL_MODEL_NAME, 46 | "messages": messages, 47 | "temperature": 0.2, 48 | "max_tokens": max_tokens, 49 | } 50 | 51 | for attempt in range(retries): 52 | try: 53 | response = requests.post(API_URL, headers=headers, json=payload, timeout=60) 54 | response.raise_for_status() 55 | response_data = response.json() 56 | 57 | content = response_data["choices"][0]["message"]["content"] 58 | if content != "": 59 | return content 60 | break # If successful, break out of the loop 61 | 62 | except Exception as e: 63 | print(f"Attempt {attempt + 1} failed with error: {e}") 64 | if attempt < retries: # If we have retries left, sleep and then continue to next attempt 65 | time.sleep(5) 66 | else: # If this was the last attempt, log and return empty 67 | print(f"All {retries} attempts failed. 
Last error message: {e}") 68 | return "" 69 | return "" 70 | 71 | 72 | def get_score(question_text, gt_answer_text, pred_answer_text): 73 | content = PROMPT % (question_text, gt_answer_text, pred_answer_text) 74 | ret = get_eval(content) 75 | ret = ret.lower() 76 | if 'score' not in ret: 77 | return 0.0 78 | res = re.findall(r'score: ([\d\.]+)', ret) 79 | if len(res) != 1: 80 | return 0.0 81 | res = float(res[0]) 82 | if res > 1.0: 83 | res = 1 84 | if res < 0.0: 85 | res = 0 86 | time.sleep(1) # sleep for 1 second after a successful request to avoid high frequency 87 | return res 88 | 89 | 90 | if __name__ == "__main__": 91 | trial = get_eval('Who are you?') 92 | parser = argparse.ArgumentParser() 93 | parser.add_argument("--result_dir", type=str) 94 | args = parser.parse_args() 95 | result_files = glob(f'{args.result_dir}/*.json') 96 | 97 | for result_file in result_files: 98 | print(f"Processing {result_file}", flush=True) 99 | scores = [] 100 | with open(result_file, 'r') as f: 101 | data = json.load(f) 102 | for data_sample in tqdm(data): 103 | score = get_score(data_sample['question'], data_sample['gt'], data_sample['answer']) 104 | scores.append(score) 105 | data_sample['score'] = score 106 | 107 | print(f'The avg score on {result_file} is: {sum(scores) / len(scores)}', flush=True) 108 | 109 | with open(result_file, 'w') as f: 110 | json.dump(data, f, indent=4) 111 | -------------------------------------------------------------------------------- /segment_anything/utils/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | from torch.nn import functional as F 10 | from torchvision.transforms.functional import resize, to_pil_image # type: ignore 11 | 12 | from copy import deepcopy 13 | from typing import Tuple 14 | 15 | 16 | class ResizeLongestSide: 17 | """ 18 | Resizes images to the longest side 'target_length', as well as provides 19 | methods for resizing coordinates and boxes. Provides methods for 20 | transforming both numpy array and batched torch tensors. 21 | """ 22 | 23 | def __init__(self, target_length: int) -> None: 24 | self.target_length = target_length 25 | 26 | def apply_image(self, image: np.ndarray) -> np.ndarray: 27 | """ 28 | Expects a numpy array with shape HxWxC in uint8 format. 29 | """ 30 | target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length) 31 | return np.array(resize(to_pil_image(image), target_size)) 32 | 33 | def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 34 | """ 35 | Expects a numpy array of length 2 in the final dimension. Requires the 36 | original image size in (H, W) format. 37 | """ 38 | old_h, old_w = original_size 39 | new_h, new_w = self.get_preprocess_shape( 40 | original_size[0], original_size[1], self.target_length 41 | ) 42 | coords = deepcopy(coords).astype(float) 43 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 44 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 45 | return coords 46 | 47 | def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: 48 | """ 49 | Expects a numpy array shape Bx4. Requires the original image size 50 | in (H, W) format. 
51 | """ 52 | boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size) 53 | return boxes.reshape(-1, 4) 54 | 55 | def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor: 56 | """ 57 | Expects batched images with shape BxCxHxW and float format. This 58 | transformation may not exactly match apply_image. apply_image is 59 | the transformation expected by the model. 60 | """ 61 | # Expects an image in BCHW format. May not exactly match apply_image. 62 | target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length) 63 | return F.interpolate( 64 | image, target_size, mode="bilinear", align_corners=False, antialias=True 65 | ) 66 | 67 | def apply_coords_torch( 68 | self, coords: torch.Tensor, original_size: Tuple[int, ...] 69 | ) -> torch.Tensor: 70 | """ 71 | Expects a torch tensor with length 2 in the last dimension. Requires the 72 | original image size in (H, W) format. 73 | """ 74 | old_h, old_w = original_size 75 | new_h, new_w = self.get_preprocess_shape( 76 | original_size[0], original_size[1], self.target_length 77 | ) 78 | coords = deepcopy(coords).to(torch.float) 79 | coords[..., 0] = coords[..., 0] * (new_w / old_w) 80 | coords[..., 1] = coords[..., 1] * (new_h / old_h) 81 | return coords 82 | 83 | def apply_boxes_torch( 84 | self, boxes: torch.Tensor, original_size: Tuple[int, ...] 85 | ) -> torch.Tensor: 86 | """ 87 | Expects a torch tensor with shape Bx4. Requires the original image 88 | size in (H, W) format. 89 | """ 90 | boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size) 91 | return boxes.reshape(-1, 4) 92 | 93 | @staticmethod 94 | def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]: 95 | """ 96 | Compute the output size given input size and target long side length. 97 | """ 98 | scale = long_side_length * 1.0 / max(oldh, oldw) 99 | newh, neww = oldh * scale, oldw * scale 100 | neww = int(neww + 0.5) 101 | newh = int(newh + 0.5) 102 | return (newh, neww) 103 | -------------------------------------------------------------------------------- /mgm/utils.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import logging.handlers 4 | import os 5 | import sys 6 | 7 | import requests 8 | 9 | from mgm.constants import LOGDIR 10 | 11 | server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**" 12 | moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN." 
13 | 14 | handler = None 15 | 16 | 17 | def build_logger(logger_name, logger_filename): 18 | global handler 19 | 20 | formatter = logging.Formatter( 21 | fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s", 22 | datefmt="%Y-%m-%d %H:%M:%S", 23 | ) 24 | 25 | # Set the format of root handlers 26 | if not logging.getLogger().handlers: 27 | logging.basicConfig(level=logging.INFO) 28 | logging.getLogger().handlers[0].setFormatter(formatter) 29 | 30 | # Redirect stdout and stderr to loggers 31 | stdout_logger = logging.getLogger("stdout") 32 | stdout_logger.setLevel(logging.INFO) 33 | sl = StreamToLogger(stdout_logger, logging.INFO) 34 | sys.stdout = sl 35 | 36 | stderr_logger = logging.getLogger("stderr") 37 | stderr_logger.setLevel(logging.ERROR) 38 | sl = StreamToLogger(stderr_logger, logging.ERROR) 39 | sys.stderr = sl 40 | 41 | # Get logger 42 | logger = logging.getLogger(logger_name) 43 | logger.setLevel(logging.INFO) 44 | 45 | # Add a file handler for all loggers 46 | if handler is None: 47 | os.makedirs(LOGDIR, exist_ok=True) 48 | filename = os.path.join(LOGDIR, logger_filename) 49 | handler = logging.handlers.TimedRotatingFileHandler( 50 | filename, when='D', utc=True, encoding='UTF-8') 51 | handler.setFormatter(formatter) 52 | 53 | for name, item in logging.root.manager.loggerDict.items(): 54 | if isinstance(item, logging.Logger): 55 | item.addHandler(handler) 56 | 57 | return logger 58 | 59 | 60 | class StreamToLogger(object): 61 | """ 62 | Fake file-like stream object that redirects writes to a logger instance. 63 | """ 64 | def __init__(self, logger, log_level=logging.INFO): 65 | self.terminal = sys.stdout 66 | self.logger = logger 67 | self.log_level = log_level 68 | self.linebuf = '' 69 | 70 | def __getattr__(self, attr): 71 | return getattr(self.terminal, attr) 72 | 73 | def write(self, buf): 74 | temp_linebuf = self.linebuf + buf 75 | self.linebuf = '' 76 | for line in temp_linebuf.splitlines(True): 77 | # From the io.TextIOWrapper docs: 78 | # On output, if newline is None, any '\n' characters written 79 | # are translated to the system default line separator. 80 | # By default sys.stdout.write() expects '\n' newlines and then 81 | # translates them so this is still cross platform. 82 | if line[-1] == '\n': 83 | self.logger.log(self.log_level, line.rstrip()) 84 | else: 85 | self.linebuf += line 86 | 87 | def flush(self): 88 | if self.linebuf != '': 89 | self.logger.log(self.log_level, self.linebuf.rstrip()) 90 | self.linebuf = '' 91 | 92 | 93 | def disable_torch_init(): 94 | """ 95 | Disable the redundant torch default initialization to accelerate model creation. 96 | """ 97 | import torch 98 | setattr(torch.nn.Linear, "reset_parameters", lambda self: None) 99 | setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) 100 | 101 | 102 | def violates_moderation(text): 103 | """ 104 | Check whether the text violates OpenAI moderation API. 
105 | """ 106 | url = "https://api.openai.com/v1/moderations" 107 | headers = {"Content-Type": "application/json", 108 | "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]} 109 | text = text.replace("\n", "") 110 | data = "{" + '"input": ' + f'"{text}"' + "}" 111 | data = data.encode("utf-8") 112 | try: 113 | ret = requests.post(url, headers=headers, data=data, timeout=5) 114 | flagged = ret.json()["results"][0]["flagged"] 115 | except requests.exceptions.RequestException as e: 116 | flagged = False 117 | except KeyError as e: 118 | flagged = False 119 | 120 | return flagged 121 | 122 | 123 | def pretty_print_semaphore(semaphore): 124 | if semaphore is None: 125 | return "None" 126 | return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})" 127 | -------------------------------------------------------------------------------- /mgm/mm_utils.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | from io import BytesIO 3 | import base64 4 | 5 | import torch 6 | from transformers import StoppingCriteria 7 | from mgm.constants import IMAGE_TOKEN_INDEX 8 | 9 | 10 | def load_image_from_base64(image): 11 | return Image.open(BytesIO(base64.b64decode(image))) 12 | 13 | 14 | def expand2square(pil_img, background_color): 15 | width, height = pil_img.size 16 | if width == height: 17 | return pil_img 18 | elif width > height: 19 | result = Image.new(pil_img.mode, (width, width), background_color) 20 | result.paste(pil_img, (0, (width - height) // 2)) 21 | return result 22 | else: 23 | result = Image.new(pil_img.mode, (height, height), background_color) 24 | result.paste(pil_img, ((height - width) // 2, 0)) 25 | return result 26 | 27 | 28 | def process_images(images, image_processor, model_cfg): 29 | image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None) 30 | new_images = [] 31 | if image_aspect_ratio == 'pad': 32 | for image in images: 33 | image = expand2square(image.convert('RGB'), tuple(int(x*255) for x in image_processor.image_mean)) 34 | image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0] 35 | new_images.append(image) 36 | else: 37 | return image_processor(images, return_tensors='pt')['pixel_values'] 38 | if all(x.shape == new_images[0].shape for x in new_images): 39 | new_images = torch.stack(new_images, dim=0) 40 | return new_images 41 | 42 | 43 | def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None): 44 | prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('')] 45 | 46 | def insert_separator(X, sep): 47 | return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1] 48 | 49 | input_ids = [] 50 | offset = 0 51 | if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id: 52 | offset = 1 53 | input_ids.append(prompt_chunks[0][0]) 54 | 55 | for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)): 56 | input_ids.extend(x[offset:]) 57 | 58 | if return_tensors is not None: 59 | if return_tensors == 'pt': 60 | return torch.tensor(input_ids, dtype=torch.long) 61 | raise ValueError(f'Unsupported tensor type: {return_tensors}') 62 | return input_ids 63 | 64 | 65 | def get_model_name_from_path(model_path): 66 | model_path = model_path.strip("/") 67 | model_paths = model_path.split("/") 68 | if model_paths[-1].startswith('checkpoint-'): 69 | return model_paths[-2] + "_" + model_paths[-1] 70 | else: 71 | return model_paths[-1] 72 | 73 | 
class KeywordsStoppingCriteria(StoppingCriteria): 74 | def __init__(self, keywords, tokenizer, input_ids): 75 | self.keywords = keywords 76 | self.keyword_ids = [] 77 | self.max_keyword_len = 0 78 | for keyword in keywords: 79 | cur_keyword_ids = tokenizer(keyword).input_ids 80 | if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id: 81 | cur_keyword_ids = cur_keyword_ids[1:] 82 | if len(cur_keyword_ids) > self.max_keyword_len: 83 | self.max_keyword_len = len(cur_keyword_ids) 84 | self.keyword_ids.append(torch.tensor(cur_keyword_ids)) 85 | self.tokenizer = tokenizer 86 | self.start_len = input_ids.shape[1] 87 | 88 | def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 89 | offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len) 90 | self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids] 91 | for keyword_id in self.keyword_ids: 92 | truncated_output_ids = output_ids[0, -keyword_id.shape[0]:] 93 | if torch.equal(truncated_output_ids, keyword_id): 94 | return True 95 | outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0] 96 | for keyword in self.keywords: 97 | if keyword in outputs: 98 | return True 99 | return False 100 | 101 | def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 102 | outputs = [] 103 | for i in range(output_ids.shape[0]): 104 | outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores)) 105 | return all(outputs) -------------------------------------------------------------------------------- /scripts/demo/grounded_conversation.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import numpy as np 4 | from mmengine.config import Config 5 | from xtuner.registry import BUILDER 6 | from PIL import Image 7 | from xtuner.model.utils import guess_load_checkpoint 8 | from scripts.demo.utils import colors 9 | 10 | import spacy 11 | nlp = spacy.load("en_core_web_sm") 12 | import random 13 | random.shuffle(colors) 14 | 15 | 16 | def process_noun_chunks(noun_chunks): 17 | new_noun_chunks = [] 18 | for i in range(len(noun_chunks)): 19 | noun_chunk = noun_chunks[i] 20 | if 'image' in noun_chunk.lower(): 21 | continue 22 | if noun_chunk.lower() in ['it', 'this', 'that', 'those', 'these', 'them', 23 | 'he', 'she', 'you', 'i', 'they', 'me', 'her', 24 | 'him', 'a', 'what', 'which', 'whose', 'who']: 25 | continue 26 | keep = True 27 | for j in range(len(noun_chunks)): # de-duplicate 28 | if i != j and noun_chunk in noun_chunks[j]: 29 | if len(noun_chunk) < len(noun_chunks[j]) or i > j: 30 | keep = False 31 | break 32 | if keep: 33 | new_noun_chunks.append(noun_chunk) 34 | 35 | return new_noun_chunks 36 | 37 | 38 | def extract_noun_phrases(output_text): 39 | doc = nlp(output_text) 40 | noun_chunks = list(set(chunk.text for chunk in doc.noun_chunks)) 41 | if len(noun_chunks) == 0: 42 | noun_chunks = [output_text] 43 | last_end = 0 44 | noun_chunks = process_noun_chunks(noun_chunks) 45 | noun_chunks = sorted(noun_chunks, key=lambda x: output_text.find(x)) 46 | 47 | noun_chunks = [noun_chunk for noun_chunk in noun_chunks 48 | if int(input(f'Ground {noun_chunk}?')) == 1] 49 | 50 | positive_ids = [] 51 | phrases = [] 52 | for noun_chunk in noun_chunks: 53 | obj_start = output_text.find(noun_chunk) 54 | if obj_start < last_end: 55 | continue 56 | obj_end = obj_start + len(noun_chunk) 57 | last_end = obj_end 58 | 
positive_ids.append((obj_start, obj_end)) 59 | phrases.append(noun_chunk) 60 | 61 | return positive_ids, phrases 62 | 63 | 64 | def find_interval(intervals, idx): 65 | for interval_id, (start_id, end_id) in enumerate(intervals): 66 | if (idx >= start_id) and (idx < end_id): 67 | return interval_id 68 | return len(intervals) 69 | 70 | 71 | if __name__ == '__main__': 72 | parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter) 73 | parser.add_argument('config', help='config file path.') 74 | parser.add_argument('--image', 75 | default='data/coco/val2017/000000000632.jpg', type=str) 76 | parser.add_argument('--text', 77 | default='Where is the shampoo?', type=str) 78 | parser.add_argument('--checkpoint', 79 | default='checkpoints/frozen_deepseek_vl_1_3b_unet_sam_l_iter_95080.pth', type=str) 80 | parser.add_argument('--use_sam', action='store_true') 81 | 82 | args = parser.parse_args() 83 | 84 | cfg = Config.fromfile(args.config) 85 | prompt_template = cfg.prompt_template 86 | tokenizer = cfg.tokenizer 87 | image_processor = cfg.image_processor 88 | prompt = cfg.get('prompt', None) 89 | 90 | model = BUILDER.build(cfg.model) 91 | state_dict = guess_load_checkpoint(args.checkpoint) 92 | missing, unexpected = model.load_state_dict(state_dict, strict=False) 93 | model._prepare_for_generation(image_processor=image_processor, 94 | prompt_template=prompt_template, 95 | max_thought_tokens=16, 96 | max_new_tokens=512, 97 | lmm_name=cfg.lmm_name, 98 | additional_prompt='') 99 | model = model.cuda().eval() 100 | 101 | image = Image.open(args.image) 102 | output = model.answer(image, args.text) 103 | output_ids = output.pop('output_ids').cpu() 104 | output_text = output.pop('output_text') 105 | encoded = model.tokenizer(output_text, add_special_tokens=False, return_tensors='pt') 106 | assert (encoded.input_ids[0] == output_ids).all() 107 | offsets = encoded.encodings[0].offsets 108 | str_places, phrases = extract_noun_phrases(output_text) 109 | positive_ids = [] 110 | for start_id, end_id in str_places: 111 | start_token_place = find_interval(offsets, start_id) 112 | end_token_place = max(start_token_place+1, find_interval(offsets, end_id)) 113 | positive_ids.append((start_token_place, end_token_place)) 114 | with torch.no_grad(): 115 | pred_masks, sam_pred_masks = model.ground(image=image, positive_ids=positive_ids, **output) 116 | if args.use_sam: 117 | masks = sam_pred_masks.cpu().numpy() > 0 118 | else: 119 | masks = pred_masks.cpu().numpy() > 0 120 | 121 | image_np = np.array(image).astype(np.float32) 122 | for color_id, mask in enumerate(masks): 123 | image_np[mask] = image_np[mask] * 0.2 + np.array(colors[color_id]).reshape((1, 1, 3)) * 0.8 124 | 125 | image = Image.fromarray(image_np.astype(np.uint8)) 126 | print(output_text, flush=True) 127 | print(phrases, flush=True) 128 | image.save('example.jpg') 129 | -------------------------------------------------------------------------------- /flmm/models/mask_head/mask_refiner.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from segment_anything import sam_model_registry 6 | from segment_anything.utils.transforms import ResizeLongestSide 7 | 8 | 9 | def mask2box(mask): 10 | ys, xs = np.where(mask > 0) 11 | y0, y1 = ys.min(), ys.max() 12 | x0, x1 = xs.min(), xs.max() 13 | 14 | return np.array([x0, y0, x1+1, y1+1]) # avoid x0==x1 15 | 16 | 17 | def compute_mask_IoU(masks, target): 18 | temp 
= masks * target 19 | intersection = temp.sum(dim=-1) 20 | union = ((masks + target) - temp).sum(dim=-1) 21 | return intersection, union, intersection / (union + 1e-12) 22 | 23 | 24 | class SAMWrapper(nn.Module): 25 | def __init__(self, model_name, checkpoint, 26 | use_text=True, use_mask=True, use_box=True, 27 | multimask_output=False): 28 | super(SAMWrapper, self).__init__() 29 | self.model = sam_model_registry[model_name](checkpoint=checkpoint) 30 | self.model.image_encoder.requires_grad_(False) 31 | self.transform = ResizeLongestSide(self.model.image_encoder.img_size) 32 | self.use_text = use_text 33 | self.use_mask = use_mask 34 | self.use_box = use_box 35 | self.multimask_output = multimask_output 36 | 37 | def train(self, mode=True): 38 | super().train(mode=mode) 39 | self.model.image_encoder.eval() 40 | self.training = mode 41 | return self 42 | 43 | @property 44 | def dtype(self): 45 | return self.model.dtype 46 | 47 | @torch.no_grad() 48 | def encode_image(self, image): 49 | image = np.array(image.convert(self.model.image_format)) 50 | input_image = self.transform.apply_image(image) 51 | input_image_torch = torch.as_tensor(input_image, device=self.model.device) 52 | transformed_image = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :] 53 | 54 | original_image_size = image.shape[:2] 55 | input_size = transformed_image.shape[-2:] 56 | 57 | features = self.model.image_encoder(self.model.preprocess(transformed_image)) 58 | 59 | return features, original_image_size, input_size 60 | 61 | def generate_prompt_masks(self, masks, input_size): 62 | pad_value = min(-1.0, masks.min().item()) 63 | masks = F.interpolate(masks[:, None].float(), size=input_size, mode='bilinear').to(masks) 64 | h, w = masks.shape[-2:] 65 | masks = F.pad(masks, (0, self.model.image_encoder.img_size - w, 66 | 0, self.model.image_encoder.img_size - h), value=pad_value) 67 | prompt_masks = F.interpolate(masks.float(), size=(256, 256), mode='bilinear').to(masks) 68 | 69 | return prompt_masks 70 | 71 | def forward(self, image, pred_masks, text_embeds): 72 | # masks are in logits 73 | image_embedding, original_image_size, input_size = self.encode_image(image) 74 | if self.training: 75 | image_embedding.requires_grad = True 76 | prompt_masks = self.generate_prompt_masks(pred_masks, input_size) 77 | 78 | pred_masks = F.interpolate(pred_masks.detach()[None].float().sigmoid(), 79 | size=original_image_size, mode='bilinear')[0] 80 | pred_masks = (pred_masks > 0.5).to(pred_masks) 81 | 82 | sam_masks = [] 83 | for prompt_mask, pred_mask, text_embed in zip(prompt_masks, pred_masks, text_embeds): 84 | if self.use_box: 85 | if pred_mask.sum() > 0: 86 | box = mask2box(pred_mask.float().cpu().numpy()) 87 | else: 88 | h, w = original_image_size 89 | box = np.array([0.0, 0.0, w, h]) 90 | box = self.transform.apply_boxes(box, original_image_size) 91 | box_torch = torch.as_tensor(box, dtype=pred_mask.dtype, device=self.model.device) 92 | box_torch = box_torch[None, :] # 1, 1, 4 93 | else: 94 | box_torch = None 95 | sparse_embeddings, dense_embeddings = self.model.prompt_encoder( 96 | points=None, 97 | boxes=box_torch, 98 | masks=prompt_mask.view(1, 1, 256, 256) if self.use_mask else None, 99 | ) 100 | if self.use_text: 101 | sparse_embeddings = torch.cat([sparse_embeddings.to(dense_embeddings), 102 | text_embed[None].to(dense_embeddings)], dim=1) 103 | else: 104 | sparse_embeddings = sparse_embeddings.to(dense_embeddings) 105 | low_res_masks, iou_predictions = self.model.mask_decoder( 106 | image_embeddings=image_embedding, 
107 | image_pe=self.model.prompt_encoder.get_dense_pe(), 108 | sparse_prompt_embeddings=sparse_embeddings, 109 | dense_prompt_embeddings=dense_embeddings, 110 | multimask_output=self.multimask_output, 111 | ) 112 | sam_mask = self.model.postprocess_masks(low_res_masks, input_size, original_image_size) 113 | 114 | if self.multimask_output: 115 | candidate_masks = (sam_mask[0] > 0.0).float() 116 | candidate_ious = compute_mask_IoU(candidate_masks.view(3, -1), 117 | pred_mask.float().view(1, -1))[-1] 118 | sam_mask = sam_mask[0, candidate_ious.argmax()] 119 | else: 120 | assert sam_mask.shape[1] == 1 121 | sam_mask = sam_mask[0, 0] 122 | sam_masks.append(sam_mask) 123 | 124 | return torch.stack(sam_masks) 125 | 126 | def state_dict(self, *args, **kwargs): 127 | state_dict = super().state_dict(*args, **kwargs) 128 | return {k: v for k, v in state_dict.items() if 'image_encoder' not in k} 129 | -------------------------------------------------------------------------------- /mgm/model/language_model/mgm_mistral.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ------------------------------------------------------------------------ 15 | # Modified from LLaVA (https://github.com/haotian-liu/LLaVA) 16 | # Copyright 2024 Yanwei Li 17 | # ------------------------------------------------------------------------ 18 | 19 | from typing import List, Optional, Tuple, Union 20 | 21 | import torch 22 | import torch.nn as nn 23 | 24 | from transformers import AutoConfig, AutoModelForCausalLM, \ 25 | MistralConfig, MistralModel, MistralForCausalLM 26 | 27 | from transformers.modeling_outputs import CausalLMOutputWithPast 28 | from transformers.generation.utils import GenerateOutput 29 | from transformers.generation.utils import logging 30 | 31 | from ..mgm_arch import MGMMetaModel, MGMMetaForCausalLM 32 | 33 | logger = logging.get_logger(__name__) 34 | 35 | class MGMConfig(MistralConfig): 36 | model_type = "mgm_mistral" 37 | 38 | 39 | class MGMMistralModel(MGMMetaModel, MistralModel): 40 | config_class = MGMConfig 41 | 42 | def __init__(self, config: MistralConfig): 43 | super(MGMMistralModel, self).__init__(config) 44 | # self.max_pos_idx = 0 45 | 46 | class MGMMistralForCausalLM(MistralForCausalLM, MGMMetaForCausalLM): 47 | config_class = MGMConfig 48 | 49 | def __init__(self, config): 50 | super(MistralForCausalLM, self).__init__(config) 51 | self.model = MGMMistralModel(config) 52 | # self.pretraining_tp = config.pretraining_tp 53 | self.vocab_size = config.vocab_size 54 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 55 | 56 | # Initialize weights and apply final processing 57 | self.post_init() 58 | 59 | def get_model(self): 60 | return self.model 61 | 62 | def forward( 63 | self, 64 | input_ids: torch.LongTensor = None, 65 | attention_mask: Optional[torch.Tensor] = None, 66 | position_ids: Optional[torch.LongTensor] 
= None, 67 | past_key_values: Optional[List[torch.FloatTensor]] = None, 68 | inputs_embeds: Optional[torch.FloatTensor] = None, 69 | labels: Optional[torch.LongTensor] = None, 70 | use_cache: Optional[bool] = None, 71 | output_attentions: Optional[bool] = None, 72 | output_hidden_states: Optional[bool] = None, 73 | images: Optional[torch.FloatTensor] = None, 74 | images_aux: Optional[torch.FloatTensor] = None, 75 | return_dict: Optional[bool] = None, 76 | ) -> Union[Tuple, CausalLMOutputWithPast]: 77 | 78 | if inputs_embeds is None: 79 | ( 80 | input_ids, 81 | position_ids, 82 | attention_mask, 83 | past_key_values, 84 | inputs_embeds, 85 | labels 86 | ) = self.prepare_inputs_labels_for_multimodal( 87 | input_ids, 88 | position_ids, 89 | attention_mask, 90 | past_key_values, 91 | labels, 92 | images, 93 | images_aux 94 | ) 95 | 96 | return super().forward( 97 | input_ids=input_ids, 98 | attention_mask=attention_mask, 99 | position_ids=position_ids, 100 | past_key_values=past_key_values, 101 | inputs_embeds=inputs_embeds, 102 | labels=labels, 103 | use_cache=use_cache, 104 | output_attentions=output_attentions, 105 | output_hidden_states=output_hidden_states, 106 | return_dict=return_dict 107 | ) 108 | 109 | @torch.no_grad() 110 | def generate( 111 | self, 112 | inputs: Optional[torch.Tensor] = None, 113 | images: Optional[torch.Tensor] = None, 114 | images_aux: Optional[torch.FloatTensor] = None, 115 | **kwargs, 116 | ) -> Union[GenerateOutput, torch.LongTensor]: 117 | position_ids = kwargs.pop("position_ids", None) 118 | attention_mask = kwargs.pop("attention_mask", None) 119 | if "inputs_embeds" in kwargs: 120 | raise NotImplementedError("`inputs_embeds` is not supported") 121 | 122 | if images is not None: 123 | ( 124 | inputs, 125 | position_ids, 126 | attention_mask, 127 | _, 128 | inputs_embeds, 129 | _ 130 | ) = self.prepare_inputs_labels_for_multimodal( 131 | inputs, 132 | position_ids, 133 | attention_mask, 134 | None, 135 | None, 136 | images, 137 | images_aux 138 | ) 139 | else: 140 | inputs_embeds = self.get_model().embed_tokens(inputs) 141 | 142 | return super().generate( 143 | position_ids=position_ids, 144 | attention_mask=attention_mask, 145 | inputs_embeds=inputs_embeds, 146 | **kwargs 147 | ) 148 | 149 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): 150 | images = kwargs.pop("images", None) 151 | images_aux = kwargs.pop("images_aux", None) 152 | _inputs = super().prepare_inputs_for_generation( 153 | input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs 154 | ) 155 | if images is not None: 156 | _inputs['images'] = images 157 | if images_aux is not None: 158 | _inputs['images_aux'] = images_aux 159 | return _inputs 160 | 161 | AutoConfig.register("mgm_mistral", MGMConfig) 162 | AutoModelForCausalLM.register(MGMConfig, MGMMistralForCausalLM) -------------------------------------------------------------------------------- /mgm/model/language_model/mgm_mixtral.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ------------------------------------------------------------------------ 15 | # Modified from LLaVA (https://github.com/haotian-liu/LLaVA) 16 | # Copyright 2024 Yanwei Li 17 | # ------------------------------------------------------------------------ 18 | 19 | from typing import List, Optional, Tuple, Union 20 | 21 | import torch 22 | import torch.nn as nn 23 | 24 | from transformers import AutoConfig, AutoModelForCausalLM, \ 25 | MixtralConfig, MixtralModel, MixtralForCausalLM 26 | 27 | from transformers.modeling_outputs import CausalLMOutputWithPast 28 | from transformers.generation.utils import GenerateOutput 29 | from transformers.generation.utils import logging 30 | 31 | from ..mgm_arch import MGMMetaModel, MGMMetaForCausalLM 32 | 33 | logger = logging.get_logger(__name__) 34 | 35 | class MGMConfig(MixtralConfig): 36 | model_type = "mgm_mixtral" 37 | 38 | 39 | class MGMMixtralModel(MGMMetaModel, MixtralModel): 40 | config_class = MGMConfig 41 | 42 | def __init__(self, config: MixtralConfig): 43 | super(MGMMixtralModel, self).__init__(config) 44 | # self.max_pos_idx = 0 45 | 46 | class MGMMixtralForCausalLM(MixtralForCausalLM, MGMMetaForCausalLM): 47 | config_class = MGMConfig 48 | 49 | def __init__(self, config): 50 | super(MixtralForCausalLM, self).__init__(config) 51 | self.model = MGMMixtralModel(config) 52 | # self.pretraining_tp = config.pretraining_tp 53 | self.vocab_size = config.vocab_size 54 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 55 | 56 | # Initialize weights and apply final processing 57 | self.post_init() 58 | 59 | def get_model(self): 60 | return self.model 61 | 62 | def forward( 63 | self, 64 | input_ids: torch.LongTensor = None, 65 | attention_mask: Optional[torch.Tensor] = None, 66 | position_ids: Optional[torch.LongTensor] = None, 67 | past_key_values: Optional[List[torch.FloatTensor]] = None, 68 | inputs_embeds: Optional[torch.FloatTensor] = None, 69 | labels: Optional[torch.LongTensor] = None, 70 | use_cache: Optional[bool] = None, 71 | output_attentions: Optional[bool] = None, 72 | output_hidden_states: Optional[bool] = None, 73 | images: Optional[torch.FloatTensor] = None, 74 | images_aux: Optional[torch.FloatTensor] = None, 75 | return_dict: Optional[bool] = None, 76 | ) -> Union[Tuple, CausalLMOutputWithPast]: 77 | 78 | if inputs_embeds is None: 79 | ( 80 | input_ids, 81 | position_ids, 82 | attention_mask, 83 | past_key_values, 84 | inputs_embeds, 85 | labels 86 | ) = self.prepare_inputs_labels_for_multimodal( 87 | input_ids, 88 | position_ids, 89 | attention_mask, 90 | past_key_values, 91 | labels, 92 | images, 93 | images_aux 94 | ) 95 | 96 | return super().forward( 97 | input_ids=input_ids, 98 | attention_mask=attention_mask, 99 | position_ids=position_ids, 100 | past_key_values=past_key_values, 101 | inputs_embeds=inputs_embeds, 102 | labels=labels, 103 | use_cache=use_cache, 104 | output_attentions=output_attentions, 105 | output_hidden_states=output_hidden_states, 106 | return_dict=return_dict 107 | ) 108 | 109 | @torch.no_grad() 110 | def generate( 111 | self, 112 | 
inputs: Optional[torch.Tensor] = None, 113 | images: Optional[torch.Tensor] = None, 114 | images_aux: Optional[torch.FloatTensor] = None, 115 | **kwargs, 116 | ) -> Union[GenerateOutput, torch.LongTensor]: 117 | position_ids = kwargs.pop("position_ids", None) 118 | attention_mask = kwargs.pop("attention_mask", None) 119 | if "inputs_embeds" in kwargs: 120 | raise NotImplementedError("`inputs_embeds` is not supported") 121 | 122 | if images is not None: 123 | ( 124 | inputs, 125 | position_ids, 126 | attention_mask, 127 | _, 128 | inputs_embeds, 129 | _ 130 | ) = self.prepare_inputs_labels_for_multimodal( 131 | inputs, 132 | position_ids, 133 | attention_mask, 134 | None, 135 | None, 136 | images, 137 | images_aux 138 | ) 139 | else: 140 | inputs_embeds = self.get_model().embed_tokens(inputs) 141 | 142 | return super().generate( 143 | position_ids=position_ids, 144 | attention_mask=attention_mask, 145 | inputs_embeds=inputs_embeds, 146 | **kwargs 147 | ) 148 | 149 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): 150 | images = kwargs.pop("images", None) 151 | images_aux = kwargs.pop("images_aux", None) 152 | _inputs = super().prepare_inputs_for_generation( 153 | input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs 154 | ) 155 | if images is not None: 156 | _inputs['images'] = images 157 | if images_aux is not None: 158 | _inputs['images_aux'] = images_aux 159 | return _inputs 160 | 161 | AutoConfig.register("mgm_mixtral", MGMConfig) 162 | AutoModelForCausalLM.register(MGMConfig, MGMMixtralForCausalLM) -------------------------------------------------------------------------------- /segment_anything/utils/onnx.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch.nn import functional as F 10 | 11 | from typing import Tuple 12 | 13 | from ..modeling import Sam 14 | from .amg import calculate_stability_score 15 | 16 | 17 | class SamOnnxModel(nn.Module): 18 | """ 19 | This model should not be called directly, but is used in ONNX export. 20 | It combines the prompt encoder, mask decoder, and mask postprocessing of Sam, 21 | with some functions modified to enable model tracing. Also supports extra 22 | options controlling what information. See the ONNX export script for details. 
23 | """ 24 | 25 | def __init__( 26 | self, 27 | model: Sam, 28 | return_single_mask: bool, 29 | use_stability_score: bool = False, 30 | return_extra_metrics: bool = False, 31 | ) -> None: 32 | super().__init__() 33 | self.mask_decoder = model.mask_decoder 34 | self.model = model 35 | self.img_size = model.image_encoder.img_size 36 | self.return_single_mask = return_single_mask 37 | self.use_stability_score = use_stability_score 38 | self.stability_score_offset = 1.0 39 | self.return_extra_metrics = return_extra_metrics 40 | 41 | @staticmethod 42 | def resize_longest_image_size( 43 | input_image_size: torch.Tensor, longest_side: int 44 | ) -> torch.Tensor: 45 | input_image_size = input_image_size.to(torch.float32) 46 | scale = longest_side / torch.max(input_image_size) 47 | transformed_size = scale * input_image_size 48 | transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64) 49 | return transformed_size 50 | 51 | def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor: 52 | point_coords = point_coords + 0.5 53 | point_coords = point_coords / self.img_size 54 | point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords) 55 | point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding) 56 | 57 | point_embedding = point_embedding * (point_labels != -1) 58 | point_embedding = point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * ( 59 | point_labels == -1 60 | ) 61 | 62 | for i in range(self.model.prompt_encoder.num_point_embeddings): 63 | point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[ 64 | i 65 | ].weight * (point_labels == i) 66 | 67 | return point_embedding 68 | 69 | def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor: 70 | mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(input_mask) 71 | mask_embedding = mask_embedding + ( 72 | 1 - has_mask_input 73 | ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1) 74 | return mask_embedding 75 | 76 | def mask_postprocessing(self, masks: torch.Tensor, orig_im_size: torch.Tensor) -> torch.Tensor: 77 | masks = F.interpolate( 78 | masks, 79 | size=(self.img_size, self.img_size), 80 | mode="bilinear", 81 | align_corners=False, 82 | ) 83 | 84 | prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size).to(torch.int64) 85 | masks = masks[..., : prepadded_size[0], : prepadded_size[1]] # type: ignore 86 | 87 | orig_im_size = orig_im_size.to(torch.int64) 88 | h, w = orig_im_size[0], orig_im_size[1] 89 | masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False) 90 | return masks 91 | 92 | def select_masks( 93 | self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int 94 | ) -> Tuple[torch.Tensor, torch.Tensor]: 95 | # Determine if we should return the multiclick mask or not from the number of points. 96 | # The reweighting is used to avoid control flow. 
97 | score_reweight = torch.tensor( 98 | [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)] 99 | ).to(iou_preds.device) 100 | score = iou_preds + (num_points - 2.5) * score_reweight 101 | best_idx = torch.argmax(score, dim=1) 102 | masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1) 103 | iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1) 104 | 105 | return masks, iou_preds 106 | 107 | @torch.no_grad() 108 | def forward( 109 | self, 110 | image_embeddings: torch.Tensor, 111 | point_coords: torch.Tensor, 112 | point_labels: torch.Tensor, 113 | mask_input: torch.Tensor, 114 | has_mask_input: torch.Tensor, 115 | orig_im_size: torch.Tensor, 116 | ): 117 | sparse_embedding = self._embed_points(point_coords, point_labels) 118 | dense_embedding = self._embed_masks(mask_input, has_mask_input) 119 | 120 | masks, scores = self.model.mask_decoder.predict_masks( 121 | image_embeddings=image_embeddings, 122 | image_pe=self.model.prompt_encoder.get_dense_pe(), 123 | sparse_prompt_embeddings=sparse_embedding, 124 | dense_prompt_embeddings=dense_embedding, 125 | ) 126 | 127 | if self.use_stability_score: 128 | scores = calculate_stability_score( 129 | masks, self.model.mask_threshold, self.stability_score_offset 130 | ) 131 | 132 | if self.return_single_mask: 133 | masks, scores = self.select_masks(masks, scores, point_coords.shape[1]) 134 | 135 | upscaled_masks = self.mask_postprocessing(masks, orig_im_size) 136 | 137 | if self.return_extra_metrics: 138 | stability_scores = calculate_stability_score( 139 | upscaled_masks, self.model.mask_threshold, self.stability_score_offset 140 | ) 141 | areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1) 142 | return upscaled_masks, scores, stability_scores, areas, masks 143 | 144 | return upscaled_masks, scores, masks 145 | -------------------------------------------------------------------------------- /deepseek_vl/models/modeling_vlm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023-2024 DeepSeek. 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to 6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 7 | # the Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all 11 | # copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 15 | # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 16 | # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 17 | # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 18 | # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
19 | 20 | import torch 21 | from attrdict import AttrDict 22 | from einops import rearrange 23 | from transformers import ( 24 | AutoConfig, 25 | AutoModelForCausalLM, 26 | LlamaConfig, 27 | LlamaForCausalLM, 28 | PreTrainedModel, 29 | ) 30 | from transformers.configuration_utils import PretrainedConfig 31 | 32 | from deepseek_vl.models.clip_encoder import CLIPVisionTower, HybridVisionTower 33 | from deepseek_vl.models.projector import MlpProjector 34 | 35 | 36 | def model_name_to_cls(cls_name): 37 | if "MlpProjector" in cls_name: 38 | cls = MlpProjector 39 | 40 | elif "CLIPVisionTower" in cls_name: 41 | cls = CLIPVisionTower 42 | 43 | elif "HybridVisionTower" in cls_name: 44 | cls = HybridVisionTower 45 | 46 | else: 47 | raise ValueError(f"class_name {cls_name} is invalid.") 48 | 49 | return cls 50 | 51 | 52 | class VisionConfig(PretrainedConfig): 53 | model_type = "vision" 54 | cls: str = "" 55 | params: AttrDict = {} 56 | 57 | def __init__(self, **kwargs): 58 | super().__init__(**kwargs) 59 | 60 | self.cls = kwargs.get("cls", "") 61 | if not isinstance(self.cls, str): 62 | self.cls = self.cls.__name__ 63 | 64 | self.params = AttrDict(kwargs.get("params", {})) 65 | 66 | 67 | class AlignerConfig(PretrainedConfig): 68 | model_type = "aligner" 69 | cls: str = "" 70 | params: AttrDict = {} 71 | 72 | def __init__(self, **kwargs): 73 | super().__init__(**kwargs) 74 | 75 | self.cls = kwargs.get("cls", "") 76 | if not isinstance(self.cls, str): 77 | self.cls = self.cls.__name__ 78 | 79 | self.params = AttrDict(kwargs.get("params", {})) 80 | 81 | 82 | class MultiModalityConfig(PretrainedConfig): 83 | model_type = "multi_modality" 84 | vision_config: VisionConfig 85 | aligner_config: AlignerConfig 86 | language_config: LlamaConfig 87 | 88 | def __init__(self, **kwargs): 89 | super().__init__(**kwargs) 90 | vision_config = kwargs.get("vision_config", {}) 91 | self.vision_config = VisionConfig(**vision_config) 92 | 93 | aligner_config = kwargs.get("aligner_config", {}) 94 | self.aligner_config = AlignerConfig(**aligner_config) 95 | 96 | language_config = kwargs.get("language_config", {}) 97 | if isinstance(language_config, LlamaConfig): 98 | self.language_config = language_config 99 | else: 100 | self.language_config = LlamaConfig(**language_config) 101 | 102 | 103 | class MultiModalityPreTrainedModel(PreTrainedModel): 104 | config_class = MultiModalityConfig 105 | base_model_prefix = "multi_modality" 106 | _no_split_modules = [] 107 | _skip_keys_device_placement = "past_key_values" 108 | 109 | 110 | class MultiModalityCausalLM(MultiModalityPreTrainedModel): 111 | def __init__(self, config: MultiModalityConfig): 112 | super().__init__(config) 113 | 114 | vision_config = config.vision_config 115 | vision_cls = model_name_to_cls(vision_config.cls) 116 | self.vision_model = vision_cls(**vision_config.params) 117 | 118 | aligner_config = config.aligner_config 119 | aligner_cls = model_name_to_cls(aligner_config.cls) 120 | self.aligner = aligner_cls(aligner_config.params) 121 | 122 | language_config = config.language_config 123 | self.language_model = LlamaForCausalLM(language_config) 124 | 125 | def prepare_inputs_embeds( 126 | self, 127 | input_ids: torch.LongTensor, 128 | pixel_values: torch.FloatTensor, 129 | images_seq_mask: torch.LongTensor, 130 | images_emb_mask: torch.LongTensor, 131 | **kwargs, 132 | ): 133 | """ 134 | 135 | Args: 136 | input_ids (torch.LongTensor): [b, T] 137 | pixel_values (torch.FloatTensor): [b, n_images, 3, h, w] 138 | images_seq_mask (torch.BoolTensor): [b, T] 139 | 
images_emb_mask (torch.BoolTensor): [b, n_images, n_image_tokens] 140 | 141 | assert torch.sum(images_seq_mask) == torch.sum(images_emb_mask) 142 | 143 | Returns: 144 | input_embeds (torch.Tensor): [b, T, D] 145 | """ 146 | 147 | bs, n = pixel_values.shape[0:2] 148 | images = rearrange(pixel_values, "b n c h w -> (b n) c h w") 149 | # [b x n, T2, D] 150 | images_embeds = self.aligner(self.vision_model(images)) 151 | 152 | # [b x n, T2, D] -> [b, n x T2, D] 153 | images_embeds = rearrange(images_embeds, "(b n) t d -> b (n t) d", b=bs, n=n) 154 | # [b, n, T2] -> [b, n x T2] 155 | images_emb_mask = rearrange(images_emb_mask, "b n t -> b (n t)") 156 | 157 | # [b, T, D] 158 | input_ids[input_ids < 0] = 0 # ignore the image embeddings 159 | inputs_embeds = self.language_model.get_input_embeddings()(input_ids) 160 | 161 | # replace with the image embeddings 162 | inputs_embeds[images_seq_mask] = images_embeds[images_emb_mask] 163 | 164 | return inputs_embeds 165 | 166 | 167 | AutoConfig.register("vision", VisionConfig) 168 | AutoConfig.register("aligner", AlignerConfig) 169 | AutoConfig.register("multi_modality", MultiModalityConfig) 170 | AutoModelForCausalLM.register(MultiModalityConfig, MultiModalityCausalLM) 171 | -------------------------------------------------------------------------------- /mgm/model/builder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ------------------------------------------------------------------------ 15 | # Modified from LLaVA (https://github.com/haotian-liu/LLaVA) 16 | # Copyright 2024 Yanwei Li 17 | # ------------------------------------------------------------------------ 18 | 19 | import os 20 | import warnings 21 | 22 | from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig 23 | import torch 24 | from mgm.model import * 25 | from mgm.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 26 | 27 | def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda", use_flash_attn=False, **kwargs): 28 | kwargs = {"device_map": device_map, **kwargs} 29 | 30 | if device != "cuda": 31 | kwargs['device_map'] = {"": device} 32 | 33 | if load_8bit: 34 | kwargs['load_in_8bit'] = True 35 | elif load_4bit: 36 | kwargs['load_in_4bit'] = True 37 | kwargs['quantization_config'] = BitsAndBytesConfig( 38 | load_in_4bit=True, 39 | bnb_4bit_compute_dtype=torch.float16, 40 | bnb_4bit_use_double_quant=True, 41 | bnb_4bit_quant_type='nf4' 42 | ) 43 | else: 44 | kwargs['torch_dtype'] = torch.float16 45 | 46 | if use_flash_attn: 47 | kwargs['attn_implementation'] = 'flash_attention_2' 48 | 49 | if 'mgm' in model_name.lower(): 50 | # Load MGM model 51 | if model_base is not None: 52 | # this may be mm projector only 53 | print('Loading MGM from base model...') 54 | 55 | if "8x7b" in model_name.lower(): 56 | tokenizer = AutoTokenizer.from_pretrained(model_base) 57 | model = MGMMixtralForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, **kwargs) 58 | elif "2b" in model_name.lower(): 59 | tokenizer = AutoTokenizer.from_pretrained(model_base) 60 | model = MGMGemmaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, **kwargs) 61 | else: 62 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) 63 | model = MGMLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, **kwargs) 64 | mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu') 65 | mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()} 66 | model.load_state_dict(mm_projector_weights, strict=False) 67 | else: 68 | if "8x7b" in model_name.lower(): 69 | tokenizer = AutoTokenizer.from_pretrained(model_path) 70 | model = MGMMixtralForCausalLM.from_pretrained(model_path, **kwargs) 71 | elif "2b" in model_name.lower(): 72 | tokenizer = AutoTokenizer.from_pretrained(model_path) 73 | model = MGMGemmaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs) 74 | else: 75 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) 76 | model = MGMLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs) 77 | 78 | else: 79 | # Load language model 80 | if model_base is not None: 81 | # PEFT model 82 | from peft import PeftModel 83 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) 84 | model = AutoModelForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, **kwargs) 85 | print(f"Loading LoRA weights from {model_path}") 86 | model = PeftModel.from_pretrained(model, model_path) 87 | print(f"Merging weights") 88 | model = model.merge_and_unload() 89 | print('Convert to FP16...') 90 | model.to(torch.float16) 91 | else: 92 | if 'mpt' in model_name.lower(): 93 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True) 94 | model = 
AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, trust_remote_code=True, **kwargs) 95 | else: 96 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) 97 | model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs) 98 | 99 | image_processor = None 100 | 101 | mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False) 102 | mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True) 103 | if mm_use_im_patch_token: 104 | tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) 105 | if mm_use_im_start_end: 106 | tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True) 107 | 108 | model.resize_token_embeddings(len(tokenizer)) 109 | 110 | vision_tower = model.get_vision_tower() 111 | if not vision_tower.is_loaded: 112 | vision_tower.load_model() 113 | vision_tower.to(device=device, dtype=torch.float16) 114 | image_processor = vision_tower.image_processor 115 | 116 | if 'mgm' in model_name.lower(): 117 | vision_tower_aux = model.get_vision_tower_aux() 118 | if not vision_tower_aux.is_loaded: 119 | vision_tower_aux.load_model() 120 | vision_tower_aux.to(device=device, dtype=torch.float16) 121 | 122 | # initialize attention modules 123 | model.config.model_path = model_path 124 | model.get_model().initialize_uni_modules(model.config, for_eval=True) 125 | 126 | if hasattr(model.config, "max_sequence_length"): 127 | context_len = model.config.max_sequence_length 128 | else: 129 | context_len = 2048 130 | 131 | return tokenizer, model, image_processor, context_len -------------------------------------------------------------------------------- /mgm/model/language_model/mgm_gemma.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ------------------------------------------------------------------------ 15 | # Modified from LLaVA (https://github.com/haotian-liu/LLaVA) 16 | # Copyright 2024 Yanwei Li 17 | # ------------------------------------------------------------------------ 18 | 19 | from typing import List, Optional, Tuple, Union 20 | 21 | import torch 22 | import torch.nn as nn 23 | 24 | try: 25 | from transformers import AutoConfig, AutoModelForCausalLM, \ 26 | GemmaConfig, GemmaModel, GemmaForCausalLM 27 | except: 28 | print("New model not imported. 
Try to update Transformers to 4.38.0 or later.") 29 | from transformers.modeling_outputs import CausalLMOutputWithPast 30 | from transformers.generation.utils import GenerateOutput 31 | from transformers.generation.utils import logging 32 | 33 | from ..mgm_arch import MGMMetaModel, MGMMetaForCausalLM 34 | 35 | logger = logging.get_logger(__name__) 36 | 37 | class MGMConfig(GemmaConfig): 38 | model_type = "mgm_gemma" 39 | 40 | 41 | class MGMGemmaModel(MGMMetaModel, GemmaModel): 42 | config_class = MGMConfig 43 | 44 | def __init__(self, config: GemmaConfig): 45 | super(MGMGemmaModel, self).__init__(config) 46 | 47 | class MGMGemmaForCausalLM(GemmaForCausalLM, MGMMetaForCausalLM): 48 | config_class = MGMConfig 49 | 50 | def __init__(self, config): 51 | super(GemmaForCausalLM, self).__init__(config) 52 | self.model = MGMGemmaModel(config) 53 | self.vocab_size = config.vocab_size 54 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 55 | 56 | # Initialize weights and apply final processing 57 | self.post_init() 58 | 59 | def get_model(self): 60 | return self.model 61 | 62 | def forward( 63 | self, 64 | input_ids: torch.LongTensor = None, 65 | attention_mask: Optional[torch.Tensor] = None, 66 | position_ids: Optional[torch.LongTensor] = None, 67 | past_key_values: Optional[List[torch.FloatTensor]] = None, 68 | inputs_embeds: Optional[torch.FloatTensor] = None, 69 | labels: Optional[torch.LongTensor] = None, 70 | use_cache: Optional[bool] = None, 71 | cache_position: Optional[torch.LongTensor] = None, 72 | output_attentions: Optional[bool] = None, 73 | output_hidden_states: Optional[bool] = None, 74 | images: Optional[torch.FloatTensor] = None, 75 | images_aux: Optional[torch.FloatTensor] = None, 76 | return_dict: Optional[bool] = None, 77 | mask_ids: Optional[torch.LongTensor] = None, 78 | ) -> Union[Tuple, CausalLMOutputWithPast]: 79 | image_places = None 80 | if inputs_embeds is None: 81 | ( 82 | input_ids, 83 | position_ids, 84 | attention_mask, 85 | past_key_values, 86 | inputs_embeds, 87 | labels, mask_ids, image_places 88 | ) = self.prepare_inputs_labels_for_multimodal( 89 | input_ids, 90 | position_ids, 91 | attention_mask, 92 | past_key_values, 93 | labels, 94 | images, 95 | images_aux, mask_ids=mask_ids 96 | ) 97 | assert return_dict 98 | 99 | output = super().forward( 100 | input_ids=input_ids, 101 | attention_mask=attention_mask, 102 | position_ids=position_ids, 103 | past_key_values=past_key_values, 104 | inputs_embeds=inputs_embeds, 105 | labels=labels, 106 | use_cache=use_cache, 107 | cache_position=cache_position, 108 | output_attentions=output_attentions, 109 | output_hidden_states=output_hidden_states, 110 | return_dict=return_dict 111 | ) 112 | output.mask_ids = mask_ids 113 | output.image_places = image_places 114 | 115 | return output 116 | 117 | @torch.no_grad() 118 | def generate( 119 | self, 120 | inputs: Optional[torch.Tensor] = None, 121 | images: Optional[torch.Tensor] = None, 122 | images_aux: Optional[torch.FloatTensor] = None, 123 | **kwargs, 124 | ) -> Union[GenerateOutput, torch.LongTensor]: 125 | position_ids = kwargs.pop("position_ids", None) 126 | attention_mask = kwargs.pop("attention_mask", None) 127 | if "inputs_embeds" in kwargs: 128 | raise NotImplementedError("`inputs_embeds` is not supported") 129 | 130 | if images is not None: 131 | ( 132 | inputs, 133 | position_ids, 134 | attention_mask, 135 | _, 136 | inputs_embeds, 137 | _, mask_ids, image_places 138 | ) = self.prepare_inputs_labels_for_multimodal( 139 | inputs, 140 | 
position_ids, 141 | attention_mask, 142 | None, 143 | None, 144 | images, 145 | images_aux 146 | ) 147 | else: 148 | inputs_embeds = self.get_model().embed_tokens(inputs) 149 | 150 | return super().generate( 151 | position_ids=position_ids, 152 | attention_mask=attention_mask, 153 | inputs_embeds=inputs_embeds, 154 | **kwargs 155 | ) 156 | 157 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): 158 | images = kwargs.pop("images", None) 159 | images_aux = kwargs.pop("images_aux", None) 160 | _inputs = super().prepare_inputs_for_generation( 161 | input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs 162 | ) 163 | if images is not None: 164 | _inputs['images'] = images 165 | if images_aux is not None: 166 | _inputs['images_aux'] = images_aux 167 | return _inputs 168 | 169 | AutoConfig.register("mgm_gemma", MGMConfig) 170 | AutoModelForCausalLM.register(MGMConfig, MGMGemmaForCausalLM) -------------------------------------------------------------------------------- /scripts/multiprocess_eval_png.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | import argparse 5 | from flmm.datasets.png import PNGDataset 6 | from tqdm import tqdm 7 | from xtuner.registry import BUILDER 8 | from mmengine.config import Config 9 | from xtuner.model.utils import guess_load_checkpoint 10 | from accelerate import Accelerator 11 | from accelerate.utils import gather_object 12 | from xtuner.utils.constants import DEFAULT_IMAGE_TOKEN 13 | 14 | accelerator = Accelerator() 15 | 16 | 17 | def average_accuracy(ious): 18 | ious = ious.cpu().numpy() 19 | accuracy = [] 20 | average_acc = 0 21 | thresholds = np.arange(0, 1, 0.00001) 22 | for t in thresholds: 23 | predictions = (ious >= t).astype(int) 24 | TP = np.sum(predictions) 25 | a = TP / len(predictions) 26 | 27 | accuracy.append(a) 28 | for i, t in enumerate(zip(thresholds[:-1], thresholds[1:])): 29 | average_acc += (np.abs(t[1] - t[0])) * accuracy[i] 30 | 31 | return average_acc 32 | 33 | 34 | def compute_mask_IoU(masks, target): 35 | temp = masks * target 36 | intersection = temp.sum(dim=-1) 37 | union = ((masks + target) - temp).sum(dim=-1) 38 | return intersection, union, intersection / (union + 1e-12) 39 | 40 | 41 | def mask2box(mask): 42 | ys, xs = np.where(mask) 43 | y0, y1 = ys.min(), ys.max() 44 | x0, x1 = xs.min(), xs.max() 45 | 46 | return np.array([x0, y0, x1, y1]) 47 | 48 | 49 | def mask2point(mask, image_h, image_w): 50 | h, w = mask.shape 51 | ys, xs = np.where(mask) 52 | ys, xs = (image_h * (ys.astype(np.float32) + 0.5) / h, 53 | image_w * (xs.astype(np.float32) + 0.5) / w) 54 | return np.stack([xs, ys], axis=1) 55 | 56 | 57 | def mask2logits(mask, eps=1e-3): 58 | def inv_sigmoid(x): 59 | return np.log(x / (1 - x)) 60 | 61 | logits = np.zeros(mask.shape, dtype="float32") 62 | logits[mask > 0] = 1 - eps 63 | logits[mask < 1] = eps 64 | logits = inv_sigmoid(logits) 65 | 66 | return logits 67 | 68 | 69 | if __name__ == '__main__': 70 | parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter) 71 | parser.add_argument('config', help='config file path.') 72 | parser.add_argument('--checkpoint', default=None, type=str) 73 | parser.add_argument('--debug', action='store_true') 74 | args = parser.parse_args() 75 | 76 | accelerator = Accelerator() 77 | # each GPU creates a string 78 | message = [f"Hello this is GPU {accelerator.process_index}"] 79 | 
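    # Illustrative sketch (annotation) of the metric helpers defined above,
    # with hypothetical 1x4 flattened masks:
    #     compute_mask_IoU(torch.tensor([[1., 1., 0., 0.]]),
    #                      torch.tensor([[1., 0., 1., 0.]]))
    #     # -> intersection 1.0, union 3.0, IoU ~0.333
    # average_accuracy then sweeps IoU thresholds from 0 to 1 in 1e-5 steps and
    # integrates the fraction of samples whose IoU clears each threshold.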
# collect the messages from all GPUs 80 | messages = gather_object(message) 81 | # output the messages only on the main process with accelerator.print() 82 | accelerator.print(messages) 83 | 84 | cfg = Config.fromfile(args.config) 85 | prompt_template = cfg.prompt_template 86 | tokenizer = cfg.tokenizer 87 | image_processor = cfg.image_processor 88 | prompt = cfg.get('prompt', None) 89 | 90 | print(f'Device: {accelerator.device}', flush=True) 91 | model = BUILDER.build(cfg.model) 92 | if args.checkpoint is not None: 93 | state_dict = guess_load_checkpoint(args.checkpoint) 94 | missing, unexpected = model.load_state_dict(state_dict, strict=False) 95 | accelerator.print(f"Unexpected parameters: {unexpected}") 96 | model = model.to(device=accelerator.device) 97 | model.eval() 98 | 99 | dataset_params = dict(json_file='data/coco/annotations/png_coco_val2017.json', 100 | panoptic_json_file='data/coco/annotations/panoptic_val2017.json', 101 | panoptic_png_path='data/coco/annotations/panoptic_val2017', 102 | tokenizer=tokenizer, 103 | image_processor=image_processor, 104 | prompt_template=prompt_template, 105 | local_path='data/coco/val2017', 106 | ceph_path='openmmlab:s3://openmmlab/datasets/detection/coco/val2017', 107 | image2tensor=cfg.get('image2tensor', True), 108 | add_image_token=cfg.get('add_image_token', False), 109 | image_token=cfg.get('image_token', DEFAULT_IMAGE_TOKEN) 110 | ) 111 | if prompt is not None: 112 | dataset_params.update(prompt=prompt) 113 | png_dataset = PNGDataset(**dataset_params) 114 | 115 | mask_ious = [] 116 | isthing = [] 117 | plural = [] 118 | pixel_accs = [] 119 | 120 | # sync GPUs and start the timer 121 | accelerator.wait_for_everyone() 122 | 123 | data_ids = list(range(len(png_dataset))) 124 | if args.debug: 125 | data_ids = data_ids[:100] 126 | 127 | # divide the prompt list onto the available GPUs 128 | with accelerator.split_between_processes(data_ids) as sub_ids: 129 | 130 | for idx in tqdm(sub_ids, disable=not accelerator.is_main_process): 131 | data_sample = png_dataset[idx] 132 | with torch.no_grad(): 133 | pred_mask_logits = model.predict(data_sample) 134 | masks = data_sample['gt_masks'].to(pred_mask_logits.device) 135 | gt_masks = masks.float().cpu() 136 | 137 | pred_masks = F.interpolate(pred_mask_logits[None].float().sigmoid(), 138 | size=masks.shape[-2:], mode='bilinear')[0].cpu() 139 | pred_masks = (pred_masks > 0.5).float() 140 | 141 | assert pred_masks.shape == gt_masks.shape 142 | mask_cnt = pred_masks.shape[0] 143 | 144 | mask_infos = data_sample['mask_infos'] 145 | sub_mask_ious = [compute_mask_IoU(pred_masks.flatten(1, 2), gt_masks.flatten(1, 2))[-1]] 146 | sub_isthing = [torch.tensor([mask_info['isthing'] for mask_info in mask_infos])] 147 | sub_plural = [torch.tensor([mask_info['plural'] for mask_info in mask_infos])] 148 | pixel_acc = [torch.eq(pred_masks, gt_masks).float().flatten(1, 2).mean(-1)] 149 | 150 | mask_ious += sub_mask_ious 151 | isthing += sub_isthing 152 | plural += sub_plural 153 | pixel_accs += pixel_acc 154 | 155 | mask_ious = gather_object(mask_ious) 156 | isthing = gather_object(isthing) 157 | plural = gather_object(plural) 158 | pixel_accs = gather_object(pixel_accs) 159 | 160 | if accelerator.is_main_process: 161 | mask_ious = torch.cat(mask_ious) 162 | isthing = torch.cat(isthing) 163 | plural = torch.cat(plural) 164 | 165 | AA = average_accuracy(mask_ious) 166 | AA_singulars = average_accuracy(mask_ious[torch.logical_not(plural)]) 167 | AA_plurals = average_accuracy(mask_ious[plural]) 168 | AA_things = 
average_accuracy(mask_ious[isthing]) 169 | AA_stuff = average_accuracy(mask_ious[torch.logical_not(isthing)]) 170 | 171 | accuracy = (mask_ious > 0.5).float().mean() 172 | 173 | pixel_accs = torch.cat(pixel_accs).mean() 174 | 175 | print(f"aIoU: {AA}, aIoU_singulars: {AA_singulars}, aIoU_plurals: {AA_plurals}, " 176 | f"aIoU_things: {AA_things}, aIoU_stuff: {AA_stuff}, aAcc@0.5: {accuracy}, " 177 | f"pixel_accs: {pixel_accs}", flush=True) 178 | -------------------------------------------------------------------------------- /segment_anything/modeling/mask_decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | 11 | from typing import List, Tuple, Type 12 | 13 | from .common import LayerNorm2d 14 | 15 | 16 | class MaskDecoder(nn.Module): 17 | def __init__( 18 | self, 19 | *, 20 | transformer_dim: int, 21 | transformer: nn.Module, 22 | num_multimask_outputs: int = 3, 23 | activation: Type[nn.Module] = nn.GELU, 24 | iou_head_depth: int = 3, 25 | iou_head_hidden_dim: int = 256, 26 | ) -> None: 27 | """ 28 | Predicts masks given an image and prompt embeddings, using a 29 | transformer architecture. 30 | 31 | Arguments: 32 | transformer_dim (int): the channel dimension of the transformer 33 | transformer (nn.Module): the transformer used to predict masks 34 | num_multimask_outputs (int): the number of masks to predict 35 | when disambiguating masks 36 | activation (nn.Module): the type of activation to use when 37 | upscaling masks 38 | iou_head_depth (int): the depth of the MLP used to predict 39 | mask quality 40 | iou_head_hidden_dim (int): the hidden dimension of the MLP 41 | used to predict mask quality 42 | """ 43 | super().__init__() 44 | self.transformer_dim = transformer_dim 45 | self.transformer = transformer 46 | 47 | self.num_multimask_outputs = num_multimask_outputs 48 | 49 | self.iou_token = nn.Embedding(1, transformer_dim) 50 | self.num_mask_tokens = num_multimask_outputs + 1 51 | self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) 52 | 53 | self.output_upscaling = nn.Sequential( 54 | nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), 55 | LayerNorm2d(transformer_dim // 4), 56 | activation(), 57 | nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), 58 | activation(), 59 | ) 60 | self.output_hypernetworks_mlps = nn.ModuleList( 61 | [ 62 | MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) 63 | for i in range(self.num_mask_tokens) 64 | ] 65 | ) 66 | 67 | self.iou_prediction_head = MLP( 68 | transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth 69 | ) 70 | 71 | def forward( 72 | self, 73 | image_embeddings: torch.Tensor, 74 | image_pe: torch.Tensor, 75 | sparse_prompt_embeddings: torch.Tensor, 76 | dense_prompt_embeddings: torch.Tensor, 77 | multimask_output: bool, 78 | ) -> Tuple[torch.Tensor, torch.Tensor]: 79 | """ 80 | Predict masks given image and prompt embeddings. 
81 | 82 | Arguments: 83 | image_embeddings (torch.Tensor): the embeddings from the image encoder 84 | image_pe (torch.Tensor): positional encoding with the shape of image_embeddings 85 | sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes 86 | dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs 87 | multimask_output (bool): Whether to return multiple masks or a single 88 | mask. 89 | 90 | Returns: 91 | torch.Tensor: batched predicted masks 92 | torch.Tensor: batched predictions of mask quality 93 | """ 94 | masks, iou_pred = self.predict_masks( 95 | image_embeddings=image_embeddings, 96 | image_pe=image_pe, 97 | sparse_prompt_embeddings=sparse_prompt_embeddings, 98 | dense_prompt_embeddings=dense_prompt_embeddings, 99 | ) 100 | 101 | # Select the correct mask or masks for output 102 | if multimask_output: 103 | mask_slice = slice(1, None) 104 | else: 105 | mask_slice = slice(0, 1) 106 | masks = masks[:, mask_slice, :, :] 107 | iou_pred = iou_pred[:, mask_slice] 108 | 109 | # Prepare output 110 | return masks, iou_pred 111 | 112 | def predict_masks( 113 | self, 114 | image_embeddings: torch.Tensor, 115 | image_pe: torch.Tensor, 116 | sparse_prompt_embeddings: torch.Tensor, 117 | dense_prompt_embeddings: torch.Tensor, 118 | ) -> Tuple[torch.Tensor, torch.Tensor]: 119 | """Predicts masks. See 'forward' for more details.""" 120 | # Concatenate output tokens 121 | output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0) 122 | output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1) 123 | tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) 124 | 125 | # Expand per-image data in batch direction to be per-mask 126 | src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) 127 | src = src + dense_prompt_embeddings 128 | pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) 129 | b, c, h, w = src.shape 130 | 131 | # Run the transformer 132 | hs, src = self.transformer(src, pos_src, tokens) 133 | iou_token_out = hs[:, 0, :] 134 | mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :] 135 | 136 | # Upscale mask embeddings and predict masks using the mask tokens 137 | src = src.transpose(1, 2).view(b, c, h, w) 138 | upscaled_embedding = self.output_upscaling(src) 139 | hyper_in_list: List[torch.Tensor] = [] 140 | for i in range(self.num_mask_tokens): 141 | hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])) 142 | hyper_in = torch.stack(hyper_in_list, dim=1) 143 | b, c, h, w = upscaled_embedding.shape 144 | masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) 145 | 146 | # Generate mask quality predictions 147 | iou_pred = self.iou_prediction_head(iou_token_out) 148 | 149 | return masks, iou_pred 150 | 151 | 152 | # Lightly adapted from 153 | # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa 154 | class MLP(nn.Module): 155 | def __init__( 156 | self, 157 | input_dim: int, 158 | hidden_dim: int, 159 | output_dim: int, 160 | num_layers: int, 161 | sigmoid_output: bool = False, 162 | ) -> None: 163 | super().__init__() 164 | self.num_layers = num_layers 165 | h = [hidden_dim] * (num_layers - 1) 166 | self.layers = nn.ModuleList( 167 | nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) 168 | ) 169 | self.sigmoid_output = sigmoid_output 170 | 171 | def forward(self, x): 172 | for i, layer in 
enumerate(self.layers): 173 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) 174 | if self.sigmoid_output: 175 | x = F.sigmoid(x) 176 | return x 177 | -------------------------------------------------------------------------------- /scripts/demo/multiprocess_infer_png.py: -------------------------------------------------------------------------------- 1 | import os 2 | from sklearn.cluster import KMeans 3 | import numpy as np 4 | from PIL import Image 5 | import torch 6 | import torch.nn.functional as F 7 | import argparse 8 | from src.datasets.png import PNGDataset 9 | from tqdm import tqdm 10 | from xtuner.registry import BUILDER 11 | from mmengine.config import Config 12 | from xtuner.model.utils import guess_load_checkpoint 13 | from accelerate import Accelerator 14 | from accelerate.utils import gather_object 15 | from xtuner.utils.constants import DEFAULT_IMAGE_TOKEN 16 | from scripts.demo.utils import colors 17 | 18 | def compute_mask_IoU(masks, target): 19 | temp = masks * target 20 | intersection = temp.sum(dim=-1) 21 | union = ((masks + target) - temp).sum(dim=-1) 22 | return intersection, union, intersection / (union + 1e-12) 23 | 24 | def do_kmeans(feature_map, gt_mask): 25 | c, h, w = feature_map.shape 26 | feature_map = feature_map.view(c, h*w).T.contiguous() 27 | feature_map = F.normalize(feature_map, dim=-1).cpu().numpy() 28 | cluster_method = KMeans(n_clusters=2, n_init=10) 29 | # fit model and predict clusters 30 | results = cluster_method.fit_predict(feature_map) 31 | 32 | mask1 = torch.from_numpy(results.reshape(h, w) == 0).float() 33 | mask2 = torch.from_numpy(results.reshape(h, w) == 1).float() 34 | 35 | masks = F.interpolate(torch.stack([mask1, mask2])[None], size=gt_mask.shape, mode='bilinear')[0] 36 | ious = compute_mask_IoU(masks.view(2, -1), torch.from_numpy(gt_mask).float().view(1, -1))[-1] 37 | 38 | return masks[ious.argmax()] > 0 39 | 40 | 41 | if __name__ == '__main__': 42 | parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter) 43 | parser.add_argument('config', help='config file path.') 44 | parser.add_argument('--checkpoint', 45 | default='checkpoints/frozen_deepseek_vl_1_3b_unet_sam_l_iter_95080.pth', type=str) 46 | parser.add_argument('--save_dir', default='data/deepseek1_3b_png', type=str) 47 | args = parser.parse_args() 48 | 49 | os.makedirs(args.save_dir, exist_ok=True) 50 | for subset in ['gt', 'sam', 'conv', 'attn', 'attn_all']: 51 | os.makedirs(os.path.join(args.save_dir, subset), exist_ok=True) 52 | 53 | accelerator = Accelerator() 54 | # each GPU creates a string 55 | message = [f"Hello this is GPU {accelerator.process_index}"] 56 | # collect the messages from all GPUs 57 | messages = gather_object(message) 58 | # output the messages only on the main process with accelerator.print() 59 | accelerator.print(messages) 60 | 61 | cfg = Config.fromfile(args.config) 62 | prompt_template = cfg.prompt_template 63 | tokenizer = cfg.tokenizer 64 | image_processor = cfg.image_processor 65 | prompt = cfg.get('prompt', None) 66 | 67 | print(f'Device: {accelerator.device}', flush=True) 68 | model = BUILDER.build(cfg.model) 69 | if args.checkpoint is not None: 70 | state_dict = guess_load_checkpoint(args.checkpoint) 71 | missing, unexpected = model.load_state_dict(state_dict, strict=False) 72 | accelerator.print(f"Unexpected parameters: {unexpected}") 73 | model = model.to(device=accelerator.device) 74 | model.eval() 75 | 76 | dataset_params = dict(json_file='data/png_coco_val2017.json', 77 | 
panoptic_json_file='data/coco/annotations/panoptic_val2017.json', 78 | panoptic_png_path='data/coco/panoptic_val2017', 79 | tokenizer=tokenizer, 80 | image_processor=image_processor, 81 | prompt_template=prompt_template, 82 | local_path='data/coco/val2017', 83 | ceph_path='openmmlab:s3://openmmlab/datasets/detection/coco/val2017', 84 | image2tensor=cfg.get('image2tensor', True), 85 | add_image_token=cfg.get('add_image_token', False), 86 | image_token=cfg.get('image_token', DEFAULT_IMAGE_TOKEN) 87 | ) 88 | if prompt is not None: 89 | dataset_params.update(prompt=prompt) 90 | png_dataset = PNGDataset(**dataset_params) 91 | 92 | mask_ious = [] 93 | isthing = [] 94 | plural = [] 95 | pixel_accs = [] 96 | 97 | # sync GPUs and start the timer 98 | accelerator.wait_for_everyone() 99 | 100 | data_ids = list(range(len(png_dataset)))[:100] 101 | # divide the prompt list onto the available GPUs 102 | with accelerator.split_between_processes(data_ids) as sub_ids: 103 | for idx in tqdm(sub_ids, disable=not accelerator.is_main_process): 104 | data_sample = png_dataset[idx] 105 | with torch.no_grad(): 106 | output = model._forward(data_sample) 107 | gt_masks = data_sample['gt_masks'].cpu().numpy() > 0 108 | pred_masks = F.interpolate(output['pred_masks'][None].float().cpu(), 109 | size=gt_masks.shape[-2:], mode='bilinear')[0].numpy() > 0 110 | sam_pred_masks = F.interpolate(output['sam_pred_masks'][None].float().cpu(), 111 | size=gt_masks.shape[-2:], mode='bilinear')[0].numpy() > 0 112 | mask_attentions = output['mask_attentions'] 113 | attn_masks = torch.stack([do_kmeans(mask_attention, gt_mask) for mask_attention, gt_mask in 114 | zip(mask_attentions, gt_masks)]) 115 | # attn_masks = F.interpolate(attn_masks[None].float().cpu(), 116 | # size=gt_masks.shape[-2:], mode='bilinear')[0].numpy() > 0 117 | 118 | file_name = os.path.basename(data_sample['file_name']) 119 | 120 | image = np.array(data_sample['image']).astype(np.float32) 121 | sam_image = image.copy() 122 | gt_image = image.copy() 123 | conv_image = image.copy() 124 | attn_image = image.copy() 125 | 126 | for color_id, (gt_mask, sam_mask, cnn_mask, attn_mask) in enumerate( 127 | zip(gt_masks, sam_pred_masks, pred_masks, attn_masks)): 128 | sam_image[sam_mask] = sam_image[sam_mask] * 0.2 + np.array(colors[color_id]).reshape((1, 1, 3)) * 0.8 129 | gt_image[gt_mask] = gt_image[gt_mask] * 0.2 + np.array(colors[color_id]).reshape((1, 1, 3)) * 0.8 130 | conv_image[cnn_mask] = conv_image[cnn_mask] * 0.2 + np.array(colors[color_id]).reshape((1, 1, 3)) * 0.8 131 | attn_image[attn_mask] = attn_image[attn_mask] * 0.2 + np.array(colors[color_id]).reshape((1, 1, 3)) * 0.8 132 | 133 | 134 | all_in_one = np.concatenate([image, attn_image, conv_image, sam_image, gt_image], axis=1) 135 | 136 | sam_image = Image.fromarray(sam_image.astype(np.uint8)) 137 | gt_image = Image.fromarray(gt_image.astype(np.uint8)) 138 | conv_image = Image.fromarray(conv_image.astype(np.uint8)) 139 | attn_image = Image.fromarray(attn_image.astype(np.uint8)) 140 | 141 | all_in_one = Image.fromarray(all_in_one.astype(np.uint8)) 142 | 143 | 144 | sam_image.save(os.path.join(args.save_dir, f'sam/{file_name}')) 145 | gt_image.save(os.path.join(args.save_dir, f'gt/{file_name}')) 146 | conv_image.save(os.path.join(args.save_dir, f'conv/{file_name}')) 147 | attn_image.save(os.path.join(args.save_dir, f'attn/{file_name}')) 148 | all_in_one.save(os.path.join(args.save_dir, file_name)) 149 | 150 | np.save(os.path.join(args.save_dir, f'attn_all/{file_name[:-4]}.npy'), mask_attentions.cpu().numpy()) 
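# ----------------------------------------------------------------------
# Editor's note: illustrative sketch, not part of the original script and
# never executed. It shows how do_kmeans() defined above separates an
# attention/feature map into two clusters and keeps the cluster that best
# overlaps a (here synthetic) ground-truth mask; all shapes below are made up.
def _kmeans_demo():
    import numpy as np
    import torch
    feature_map = torch.randn(8, 16, 16)          # fake 8-channel attention map
    gt_mask = np.zeros((64, 64), dtype=bool)      # fake ground-truth mask
    gt_mask[16:48, 16:48] = True
    fg_mask = do_kmeans(feature_map, gt_mask)     # boolean tensor of shape (64, 64)
    print(f"foreground pixels selected by k-means: {int(fg_mask.sum())}")
# ----------------------------------------------------------------------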
151 | -------------------------------------------------------------------------------- /segment_anything/modeling/sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | 11 | from typing import Any, Dict, List, Tuple 12 | 13 | from .image_encoder import ImageEncoderViT 14 | from .mask_decoder import MaskDecoder 15 | from .prompt_encoder import PromptEncoder 16 | 17 | 18 | class Sam(nn.Module): 19 | mask_threshold: float = 0.0 20 | image_format: str = "RGB" 21 | 22 | def __init__( 23 | self, 24 | image_encoder: ImageEncoderViT, 25 | prompt_encoder: PromptEncoder, 26 | mask_decoder: MaskDecoder, 27 | pixel_mean: List[float] = [123.675, 116.28, 103.53], 28 | pixel_std: List[float] = [58.395, 57.12, 57.375], 29 | ) -> None: 30 | """ 31 | SAM predicts object masks from an image and input prompts. 32 | 33 | Arguments: 34 | image_encoder (ImageEncoderViT): The backbone used to encode the 35 | image into image embeddings that allow for efficient mask prediction. 36 | prompt_encoder (PromptEncoder): Encodes various types of input prompts. 37 | mask_decoder (MaskDecoder): Predicts masks from the image embeddings 38 | and encoded prompts. 39 | pixel_mean (list(float)): Mean values for normalizing pixels in the input image. 40 | pixel_std (list(float)): Std values for normalizing pixels in the input image. 41 | """ 42 | super().__init__() 43 | self.image_encoder = image_encoder 44 | self.prompt_encoder = prompt_encoder 45 | self.mask_decoder = mask_decoder 46 | self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False) 47 | self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) 48 | 49 | @property 50 | def device(self) -> Any: 51 | return self.pixel_mean.device 52 | 53 | @property 54 | def dtype(self) -> Any: 55 | return self.pixel_mean.dtype 56 | 57 | @torch.no_grad() 58 | def forward( 59 | self, 60 | batched_input: List[Dict[str, Any]], 61 | multimask_output: bool, 62 | ) -> List[Dict[str, torch.Tensor]]: 63 | """ 64 | Predicts masks end-to-end from provided images and prompts. 65 | If prompts are not known in advance, using SamPredictor is 66 | recommended over calling the model directly. 67 | 68 | Arguments: 69 | batched_input (list(dict)): A list over input images, each a 70 | dictionary with the following keys. A prompt key can be 71 | excluded if it is not present. 72 | 'image': The image as a torch tensor in 3xHxW format, 73 | already transformed for input to the model. 74 | 'original_size': (tuple(int, int)) The original size of 75 | the image before transformation, as (H, W). 76 | 'point_coords': (torch.Tensor) Batched point prompts for 77 | this image, with shape BxNx2. Already transformed to the 78 | input frame of the model. 79 | 'point_labels': (torch.Tensor) Batched labels for point prompts, 80 | with shape BxN. 81 | 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4. 82 | Already transformed to the input frame of the model. 83 | 'mask_inputs': (torch.Tensor) Batched mask inputs to the model, 84 | in the form Bx1xHxW. 85 | multimask_output (bool): Whether the model should predict multiple 86 | disambiguating masks, or return a single mask. 
87 | 88 | Returns: 89 | (list(dict)): A list over input images, where each element is 90 | as dictionary with the following keys. 91 | 'masks': (torch.Tensor) Batched binary mask predictions, 92 | with shape BxCxHxW, where B is the number of input prompts, 93 | C is determined by multimask_output, and (H, W) is the 94 | original size of the image. 95 | 'iou_predictions': (torch.Tensor) The model's predictions 96 | of mask quality, in shape BxC. 97 | 'low_res_logits': (torch.Tensor) Low resolution logits with 98 | shape BxCxHxW, where H=W=256. Can be passed as mask input 99 | to subsequent iterations of prediction. 100 | """ 101 | input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0) 102 | image_embeddings = self.image_encoder(input_images) 103 | 104 | outputs = [] 105 | for image_record, curr_embedding in zip(batched_input, image_embeddings): 106 | if "point_coords" in image_record: 107 | points = (image_record["point_coords"], image_record["point_labels"]) 108 | else: 109 | points = None 110 | sparse_embeddings, dense_embeddings = self.prompt_encoder( 111 | points=points, 112 | boxes=image_record.get("boxes", None), 113 | masks=image_record.get("mask_inputs", None), 114 | ) 115 | low_res_masks, iou_predictions = self.mask_decoder( 116 | image_embeddings=curr_embedding.unsqueeze(0), 117 | image_pe=self.prompt_encoder.get_dense_pe(), 118 | sparse_prompt_embeddings=sparse_embeddings, 119 | dense_prompt_embeddings=dense_embeddings, 120 | multimask_output=multimask_output, 121 | ) 122 | masks = self.postprocess_masks( 123 | low_res_masks, 124 | input_size=image_record["image"].shape[-2:], 125 | original_size=image_record["original_size"], 126 | ) 127 | masks = masks > self.mask_threshold 128 | outputs.append( 129 | { 130 | "masks": masks, 131 | "iou_predictions": iou_predictions, 132 | "low_res_logits": low_res_masks, 133 | } 134 | ) 135 | return outputs 136 | 137 | def postprocess_masks( 138 | self, 139 | masks: torch.Tensor, 140 | input_size: Tuple[int, ...], 141 | original_size: Tuple[int, ...], 142 | ) -> torch.Tensor: 143 | """ 144 | Remove padding and upscale masks to the original image size. 145 | 146 | Arguments: 147 | masks (torch.Tensor): Batched masks from the mask_decoder, 148 | in BxCxHxW format. 149 | input_size (tuple(int, int)): The size of the image input to the 150 | model, in (H, W) format. Used to remove padding. 151 | original_size (tuple(int, int)): The original size of the image 152 | before resizing for input to the model, in (H, W) format. 153 | 154 | Returns: 155 | (torch.Tensor): Batched masks in BxCxHxW format, where (H, W) 156 | is given by original_size. 
157 | """ 158 | masks = F.interpolate( 159 | masks.float(), 160 | (self.image_encoder.img_size, self.image_encoder.img_size), 161 | mode="bilinear", 162 | align_corners=False, 163 | ).to(masks) 164 | masks = masks[..., : input_size[0], : input_size[1]] 165 | masks = F.interpolate(masks.float(), original_size, mode="bilinear", align_corners=False).to(masks) 166 | return masks 167 | 168 | def preprocess(self, x: torch.Tensor) -> torch.Tensor: 169 | """Normalize pixel values and pad to a square input.""" 170 | # Normalize colors 171 | x = (x - self.pixel_mean) / self.pixel_std 172 | 173 | # Pad 174 | h, w = x.shape[-2:] 175 | padh = self.image_encoder.img_size - h 176 | padw = self.image_encoder.img_size - w 177 | x = F.pad(x, (0, padw, 0, padh)) 178 | return x 179 | -------------------------------------------------------------------------------- /scripts/multiprocess_eval_refcoco.py: -------------------------------------------------------------------------------- 1 | from mmdet.datasets import RefCocoDataset 2 | from flmm.datasets.transforms import PILLoadImageFromFile, RefCOCO2PNG 3 | from mmdet.datasets.transforms import LoadAnnotations 4 | from mmdet.evaluation import RefSegMetric 5 | import argparse 6 | from mmengine.config import Config 7 | from xtuner.model.utils import guess_load_checkpoint 8 | from xtuner.registry import BUILDER 9 | from xtuner.utils.constants import DEFAULT_IMAGE_TOKEN 10 | from accelerate import Accelerator 11 | from accelerate.utils import gather_object 12 | from mmdet.structures.mask import BitmapMasks 13 | 14 | from tqdm import tqdm 15 | import torch 16 | import torch.nn.functional as F 17 | from time import time 18 | 19 | 20 | if __name__ == '__main__': 21 | parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter) 22 | parser.add_argument('config', help='config file path.') 23 | parser.add_argument('--checkpoint', default=None, type=str) 24 | parser.add_argument('--debug', action='store_true') 25 | parser.add_argument('--ceph', action='store_true') 26 | parser.add_argument('--concat', action='store_true') 27 | args = parser.parse_args() 28 | 29 | ### Initialize accelerator 30 | accelerator = Accelerator() 31 | # each GPU creates a string 32 | message = [f"Hello this is GPU {accelerator.process_index}"] 33 | # collect the messages from all GPUs 34 | messages = gather_object(message) 35 | # output the messages only on the main process with accelerator.print() 36 | accelerator.print(messages) 37 | 38 | cfg = Config.fromfile(args.config) 39 | prompt_template = cfg.prompt_template 40 | tokenizer = cfg.tokenizer 41 | image_processor = cfg.image_processor 42 | prompt = cfg.get('prompt', None) 43 | model = BUILDER.build(cfg.model) 44 | 45 | if args.checkpoint is not None: 46 | state_dict = guess_load_checkpoint(args.checkpoint) 47 | missing, unexpected = model.load_state_dict(state_dict, strict=False) 48 | accelerator.print(f"Unexpected parameters: {unexpected}") 49 | 50 | print(f"Start moving model to device: {accelerator.device}", flush=True) 51 | tik = time() 52 | model = model.to(device=accelerator.device) 53 | print(f"Finished moving model to device: {accelerator.device}, time used: {time() - tik}", flush=True) 54 | model.eval() 55 | 56 | if args.ceph: 57 | backend_args = dict( 58 | backend='petrel', 59 | path_mapping=dict({ 60 | 'data/coco/train2014/': 'openmmlab:s3://openmmlab/datasets/detection/coco/train2014/' 61 | })) 62 | else: 63 | backend_args = None 64 | 65 | refcoco2png_params = dict( 66 | type=RefCOCO2PNG, 67 | 
image_processor=image_processor, 68 | tokenizer=tokenizer, 69 | prompt_template=prompt_template, 70 | concat=args.concat, 71 | image2tensor=cfg.get('image2tensor', True), 72 | add_image_token=cfg.get('add_image_token', False), 73 | image_token=cfg.get('image_token', DEFAULT_IMAGE_TOKEN) 74 | ) 75 | accelerator.print(f"Do concatenation? {args.concat}") 76 | if prompt is not None: 77 | refcoco2png_params.update(prompt=prompt) 78 | 79 | # ref_coco data pipeline 80 | test_pipeline = [ 81 | dict(type=PILLoadImageFromFile, backend_args=backend_args), 82 | dict( 83 | type=LoadAnnotations, 84 | with_mask=True, 85 | with_bbox=False, 86 | with_seg=False, 87 | with_label=False), 88 | refcoco2png_params 89 | ] 90 | 91 | refcoco_subsets = dict() 92 | for split in ['val', 'testA', 'testB']: 93 | refcoco_subsets[f'refcoco_{split}'] = dict( 94 | ann_file='refcoco/instances.json', 95 | split_file='refcoco/refs(unc).p', 96 | split=split) 97 | 98 | for split in ['val', 'testA', 'testB']: 99 | refcoco_subsets[f'refcoco+_{split}'] = dict( 100 | ann_file='refcoco+/instances.json', 101 | split_file='refcoco+/refs(unc).p', 102 | split=split) 103 | 104 | for split in ['val', 'test']: 105 | refcoco_subsets[f'refcocog_{split}'] = dict( 106 | ann_file='refcocog/instances.json', 107 | split_file='refcocog/refs(umd).p', 108 | split=split) 109 | 110 | for name, subset in refcoco_subsets.items(): 111 | accelerator.print(f"Start evaluating {name}") 112 | dataset = RefCocoDataset( 113 | data_root='data/coco/', 114 | data_prefix=dict(img_path='train2014/'), 115 | text_mode='select_first', 116 | pipeline=test_pipeline, 117 | **subset 118 | ) 119 | # sync GPUs and start the timer 120 | accelerator.wait_for_everyone() 121 | 122 | data_ids = list(range(len(dataset))) 123 | if args.debug: 124 | data_ids = data_ids[:100] 125 | 126 | results = [] 127 | # divide the prompt list onto the available GPUs 128 | with accelerator.split_between_processes(data_ids) as sub_ids: 129 | for idx in tqdm(sub_ids, disable=not accelerator.is_main_process): 130 | data_sample = dataset[idx] 131 | if args.concat: 132 | with torch.no_grad(): 133 | pred_mask_logits = model.predict(data_sample) 134 | 135 | gt_masks = data_sample['gt_masks'].numpy() > 0 136 | pred_masks = F.interpolate(pred_mask_logits[None].float().sigmoid(), 137 | size=gt_masks.shape[-2:], mode='bilinear')[0].cpu() 138 | pred_masks = pred_masks > 0.5 139 | 140 | assert len(pred_masks) == len(gt_masks) 141 | mask_cnt = pred_masks.shape[0] 142 | 143 | # Formulate the output into the format that the evaluator accepts 144 | results.append(dict(pred_instances=dict(masks=pred_masks), 145 | gt_masks=BitmapMasks(masks=gt_masks, 146 | height=gt_masks.shape[1], 147 | width=gt_masks.shape[2])) 148 | ) 149 | else: 150 | for sub_data_sample in data_sample: 151 | with torch.no_grad(): 152 | pred_mask_logits = model.predict(sub_data_sample) 153 | 154 | gt_masks = sub_data_sample['gt_masks'].numpy() > 0 155 | pred_masks = F.interpolate(pred_mask_logits[None].float().sigmoid(), 156 | size=gt_masks.shape[-2:], mode='bilinear')[0].cpu() 157 | pred_masks = pred_masks > 0.5 158 | 159 | assert len(pred_masks) == len(gt_masks) 160 | mask_cnt = pred_masks.shape[0] 161 | assert mask_cnt == 1 162 | 163 | # Formulate the output into the format that the evaluator accepts 164 | results.append(dict(pred_instances=dict(masks=pred_masks), 165 | gt_masks=BitmapMasks(masks=gt_masks, 166 | height=gt_masks.shape[1], 167 | width=gt_masks.shape[2])) 168 | ) 169 | results = gather_object(results) 170 | if 
accelerator.is_main_process: 171 | accelerator.print(f"Collected {len(results)} result samples from all gpus") 172 | evaluator = RefSegMetric(metric=['cIoU', 'mIoU']) 173 | evaluator.process(data_batch=dict(), data_samples=results) 174 | metrics = evaluator.compute_metrics(evaluator.results) 175 | accelerator.print(f"Evaluation results on {name}: {metrics}") 176 | accelerator.print(f"Finished evaluating {name}") 177 | -------------------------------------------------------------------------------- /flmm/runner.py: -------------------------------------------------------------------------------- 1 | import time 2 | import warnings 3 | import mmengine 4 | from typing import Optional 5 | from collections import OrderedDict 6 | from mmengine.runner import Runner 7 | from mmengine.dist import master_only 8 | from mmengine.fileio import FileClient, join_path 9 | from mmengine.model import is_model_wrapper 10 | from mmengine.utils import apply_to, get_git_hash 11 | from mmengine.optim import OptimWrapper 12 | from mmengine.runner.checkpoint import (find_latest_checkpoint, 13 | save_checkpoint, weights_to_cpu) 14 | from xtuner.model.utils import guess_load_checkpoint 15 | 16 | 17 | class CustomRunner(Runner): 18 | def load_or_resume(self) -> None: 19 | """load or resume checkpoint.""" 20 | if self._has_loaded: 21 | return None 22 | 23 | # decide to load from checkpoint or resume from checkpoint 24 | resume_from = None 25 | if self._resume and self._load_from is None: 26 | # auto resume from the latest checkpoint 27 | resume_from = find_latest_checkpoint(self.work_dir) 28 | self.logger.info( 29 | f'Auto resumed from the latest checkpoint {resume_from}.') 30 | elif self._resume and self._load_from is not None: 31 | # resume from the specified checkpoint 32 | resume_from = self._load_from 33 | 34 | if resume_from is not None: 35 | self.resume(resume_from) 36 | self._has_loaded = True 37 | elif self._load_from is not None: 38 | # self.load_checkpoint(self._load_from) # todo: customize for deepspeed 39 | state_dict = guess_load_checkpoint(self._load_from) 40 | if is_model_wrapper(self.model): 41 | self.model.module.load_state_dict(state_dict, strict=False) 42 | else: 43 | self.model.load_state_dict(state_dict, strict=False) 44 | self.logger.info(f'Load checkpoint from {self._load_from}.') 45 | self._has_loaded = True 46 | 47 | @master_only 48 | def save_checkpoint( 49 | self, 50 | out_dir: str, 51 | filename: str, 52 | file_client_args: Optional[dict] = None, 53 | save_optimizer: bool = True, 54 | save_param_scheduler: bool = True, 55 | meta: Optional[dict] = None, 56 | by_epoch: bool = True, 57 | backend_args: Optional[dict] = None, 58 | ): 59 | """Save checkpoints. 60 | 61 | ``CheckpointHook`` invokes this method to save checkpoints 62 | periodically. 63 | 64 | Args: 65 | out_dir (str): The directory that checkpoints are saved. 66 | filename (str): The checkpoint filename. 67 | file_client_args (dict, optional): Arguments to instantiate a 68 | FileClient. See :class:`mmengine.fileio.FileClient` for 69 | details. Defaults to None. It will be deprecated in future. 70 | Please use `backend_args` instead. 71 | save_optimizer (bool): Whether to save the optimizer to 72 | the checkpoint. Defaults to True. 73 | save_param_scheduler (bool): Whether to save the param_scheduler 74 | to the checkpoint. Defaults to True. 75 | meta (dict, optional): The meta information to be saved in the 76 | checkpoint. Defaults to None. 77 | by_epoch (bool): Decide the number of epoch or iteration saved in 78 | checkpoint. 
Defaults to True. 79 | backend_args (dict, optional): Arguments to instantiate the 80 | prefix of uri corresponding backend. Defaults to None. 81 | New in v0.2.0. 82 | """ 83 | if meta is None: 84 | meta = {} 85 | elif not isinstance(meta, dict): 86 | raise TypeError( 87 | f'meta should be a dict or None, but got {type(meta)}') 88 | 89 | if by_epoch: 90 | # self.epoch increments 1 after 91 | # `self.call_hook('after_train_epoch)` but `save_checkpoint` is 92 | # called by `after_train_epoch`` method of `CheckpointHook` so 93 | # `epoch` should be `self.epoch + 1` 94 | meta.setdefault('epoch', self.epoch + 1) 95 | meta.setdefault('iter', self.iter) 96 | else: 97 | meta.setdefault('epoch', self.epoch) 98 | meta.setdefault('iter', self.iter + 1) 99 | 100 | if file_client_args is not None: 101 | warnings.warn( 102 | '"file_client_args" will be deprecated in future. ' 103 | 'Please use "backend_args" instead', DeprecationWarning) 104 | if backend_args is not None: 105 | raise ValueError( 106 | '"file_client_args" and "backend_args" cannot be set at ' 107 | 'the same time.') 108 | 109 | file_client = FileClient.infer_client(file_client_args, out_dir) 110 | filepath = file_client.join_path(out_dir, filename) 111 | else: 112 | filepath = join_path( # type: ignore 113 | out_dir, filename, backend_args=backend_args) 114 | 115 | meta.update( 116 | cfg=self.cfg.pretty_text, 117 | seed=self.seed, 118 | experiment_name=self.experiment_name, 119 | time=time.strftime('%Y%m%d_%H%M%S', time.localtime()), 120 | mmengine_version=mmengine.__version__ + get_git_hash()) 121 | 122 | if hasattr(self.train_dataloader.dataset, 'metainfo'): 123 | meta.update(dataset_meta=self.train_dataloader.dataset.metainfo) 124 | 125 | if is_model_wrapper(self.model): 126 | model = self.model.module 127 | else: 128 | model = self.model 129 | 130 | # model parameters 131 | model_parameters = {k: v.detach() for k, v in model.named_parameters() if v.requires_grad} 132 | 133 | checkpoint = { 134 | 'meta': 135 | meta, 136 | 'state_dict': 137 | weights_to_cpu(OrderedDict(model_parameters)), 138 | 'message_hub': 139 | apply_to(self.message_hub.state_dict(), 140 | lambda x: hasattr(x, 'cpu'), lambda x: x.cpu()), 141 | } 142 | # save optimizer state dict to checkpoint 143 | if save_optimizer: 144 | if isinstance(self.optim_wrapper, OptimWrapper): 145 | checkpoint['optimizer'] = apply_to( 146 | self.optim_wrapper.state_dict(), 147 | lambda x: hasattr(x, 'cpu'), lambda x: x.cpu()) 148 | else: 149 | raise TypeError( 150 | 'self.optim_wrapper should be an `OptimWrapper` ' 151 | 'or `OptimWrapperDict` instance, but got ' 152 | f'{self.optim_wrapper}') 153 | 154 | # save param scheduler state dict 155 | if save_param_scheduler and self.param_schedulers is None: 156 | self.logger.warning( 157 | '`save_param_scheduler` is True but `self.param_schedulers` ' 158 | 'is None, so skip saving parameter schedulers') 159 | save_param_scheduler = False 160 | if save_param_scheduler: 161 | if isinstance(self.param_schedulers, dict): 162 | checkpoint['param_schedulers'] = dict() 163 | for name, schedulers in self.param_schedulers.items(): 164 | checkpoint['param_schedulers'][name] = [] 165 | for scheduler in schedulers: 166 | state_dict = scheduler.state_dict() 167 | checkpoint['param_schedulers'][name].append(state_dict) 168 | else: 169 | checkpoint['param_schedulers'] = [] 170 | for scheduler in self.param_schedulers: # type: ignore 171 | state_dict = scheduler.state_dict() # type: ignore 172 | checkpoint['param_schedulers'].append(state_dict) 173 | 174 | 
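        # Editor's note: illustrative sketch (not in the original runner and
        # never called). It contrasts the trainable-only parameter dict built
        # above with a full state_dict(), which is why checkpoints of models
        # with a frozen LMM stay small.
        def _trainable_only_demo():
            import torch.nn as nn
            toy = nn.Linear(4, 4)
            toy.weight.requires_grad_(False)   # pretend the weight is frozen
            trainable = {k: v.detach() for k, v in toy.named_parameters() if v.requires_grad}
            print(sorted(trainable))           # ['bias'] only
            print(sorted(toy.state_dict()))    # ['bias', 'weight']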
self.call_hook('before_save_checkpoint', checkpoint=checkpoint) 175 | save_checkpoint( 176 | checkpoint, 177 | filepath, 178 | file_client_args=file_client_args, 179 | backend_args=backend_args) 180 | -------------------------------------------------------------------------------- /scripts/visual_cot/visual_cot_inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import numpy as np 5 | from glob import glob 6 | from accelerate import Accelerator 7 | from tqdm import tqdm 8 | from accelerate.utils import gather_object 9 | from mmengine.config import Config 10 | from xtuner.registry import BUILDER 11 | from PIL import Image 12 | from xtuner.model.utils import guess_load_checkpoint 13 | import mmcv 14 | from torch.nn.functional import interpolate 15 | 16 | def get_iou(bb1, bb2): 17 | assert bb1[0] < bb1[2] 18 | assert bb1[1] < bb1[3] 19 | assert bb2[0] < bb2[2] 20 | assert bb2[1] < bb2[3] 21 | 22 | # determine the coordinates of the intersection rectangle 23 | x_left = max(bb1[0], bb2[0]) 24 | y_top = max(bb1[1], bb2[1]) 25 | x_right = min(bb1[2], bb2[2]) 26 | y_bottom = min(bb1[3], bb2[3]) 27 | 28 | if x_right < x_left or y_bottom < y_top: 29 | return 0.0 30 | 31 | # The intersection of two axis-aligned bounding boxes is always an 32 | # axis-aligned bounding box 33 | intersection_area = (x_right - x_left) * (y_bottom - y_top) 34 | 35 | # compute the area of both AABBs 36 | bb1_area = (bb1[2] - bb1[0]) * (bb1[3] - bb1[1]) 37 | bb2_area = (bb2[2] - bb2[0]) * (bb2[3] - bb2[1]) 38 | 39 | # compute the intersection over union by taking the intersection 40 | # area and dividing it by the sum of prediction + ground-truth 41 | # areas - the interesection area 42 | iou = intersection_area / float(bb1_area + bb2_area - intersection_area) 43 | assert iou >= 0.0 44 | assert iou <= 1.0 45 | return iou 46 | 47 | 48 | def draw_box(image, box): 49 | image = np.array(image.convert('RGB')) 50 | image = mmcv.imshow_bboxes(img=image, 51 | bboxes=np.array(box).reshape(1, 4), 52 | colors=(255, 0, 0), 53 | thickness=2, 54 | show=False) 55 | 56 | return Image.fromarray(image) 57 | 58 | 59 | def draw_mask(image, mask): 60 | image = np.array(image.convert('RGB')).astype(np.float32) 61 | image[mask] = image[mask] * 0.5 + np.array([255, 0, 0], dtype=np.float32).reshape(1, 1, 3) * 0.5 62 | image = image.astype(np.uint8) 63 | image = mmcv.imshow_bboxes(img=image, 64 | bboxes=np.array(box).reshape(1, 4), 65 | colors=(255, 0, 0), 66 | thickness=2, 67 | show=False) 68 | 69 | return Image.fromarray(image) 70 | 71 | 72 | if __name__ == '__main__': 73 | parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter) 74 | parser.add_argument('config', help='config file path.') 75 | parser.add_argument('--checkpoint', 76 | default='checkpoints/frozen_deepseek_vl_1_3b_unet_sam_l_iter_95080.pth', type=str) 77 | parser.add_argument('--image_folder', default='data', type=str) 78 | parser.add_argument('--version', default='v1', type=str) 79 | parser.add_argument('--save_folder', default='visual_cot', type=str) 80 | parser.add_argument('--debug', action='store_true') 81 | parser.add_argument('--discard_sam', action='store_true') 82 | parser.add_argument('--box_scale', default=1.0, type=float) 83 | args = parser.parse_args() 84 | accelerator = Accelerator() 85 | model_name = os.path.basename(args.config)[:-3] 86 | os.makedirs(args.save_folder, exist_ok=True) 87 | args.save_folder = os.path.join(args.save_folder, 88 | 
f'{model_name}_visual_cot_{args.version}') 89 | if args.debug: 90 | args.save_folder += 'debug' 91 | os.makedirs(args.save_folder, exist_ok=True) 92 | 93 | message = [f"Hello this is GPU {accelerator.process_index}"] 94 | # collect the messages from all GPUs 95 | messages = gather_object(message) 96 | # output the messages only on the main process with accelerator.print() 97 | accelerator.print(messages) 98 | 99 | cfg = Config.fromfile(args.config) 100 | prompt_template = cfg.prompt_template 101 | tokenizer = cfg.tokenizer 102 | image_processor = cfg.image_processor 103 | prompt = cfg.get('prompt', None) 104 | 105 | print(f'Device: {accelerator.device}', flush=True) 106 | model = BUILDER.build(cfg.model) 107 | state_dict = guess_load_checkpoint(args.checkpoint) 108 | missing, unexpected = model.load_state_dict(state_dict, strict=False) 109 | accelerator.print(f"Unexpected parameters: {unexpected}") 110 | model._prepare_for_generation(image_processor=image_processor, 111 | prompt_template=prompt_template, 112 | max_thought_tokens=16, 113 | max_new_tokens=32, 114 | lmm_name=cfg.lmm_name, 115 | additional_prompt='\nAnswer the question using a single word or phrase.', 116 | box_scale=args.box_scale, 117 | use_sam=not args.discard_sam) 118 | model = model.to(device=accelerator.device) 119 | model.eval() 120 | 121 | json_files = glob("scripts/visual_cot/benchmark/*.json") 122 | for json_file in json_files: 123 | accelerator.print(f"Processing {json_file}") 124 | 125 | with open(json_file, 'r') as f: 126 | data = json.load(f) 127 | # sync GPUs and start the timer 128 | accelerator.wait_for_everyone() 129 | data_ids = list(range(len(data))) 130 | if args.debug: 131 | data_ids = data_ids[::50] 132 | 133 | results = [] 134 | # ious = [] 135 | os.makedirs(os.path.join(args.save_folder, f'{os.path.basename(json_file)[:-4]}'), exist_ok=True) 136 | with accelerator.split_between_processes(data_ids) as sub_ids: 137 | for idx in tqdm(sub_ids, disable=not accelerator.is_main_process): 138 | data_sample = data[idx] 139 | image = Image.open(os.path.join(args.image_folder, data_sample['image'][0])) 140 | question = data_sample['conversations'][0]['value'].replace( 141 | 'Please provide the bounding box coordinate ' 142 | 'of the region that can help you answer the question better.', 143 | '' 144 | ) 145 | question = question.replace('', '').strip() 146 | gt_bbox = data_sample['image'][1].split('###')[-1].replace('[', '').replace(']', '') 147 | gt_bbox = [int(x) for x in gt_bbox.split(',')] 148 | thought, box, answer, mask = getattr(model, f'visual_cot_{args.version}')(image, question, gt_bbox) 149 | # iou = get_iou(box, gt_bbox) 150 | # ious.append(iou) 151 | image = draw_box(image, box) 152 | if mask is not None: 153 | mask = interpolate(mask[None, None].float(), size=(image.height, image.width), mode='bilinear') 154 | mask = (mask[0, 0] > 0.0).cpu().numpy() 155 | image = draw_mask(image, mask) 156 | image.save(os.path.join(args.save_folder, 157 | f"{os.path.basename(json_file)[:-4]}/{os.path.basename(data_sample['image'][0])}")) 158 | results.append(dict(thought=thought, 159 | box=box, 160 | gt_bbox=gt_bbox, 161 | # iou=iou, 162 | answer=answer, 163 | question_id=data_sample['question_id'], 164 | question=question, 165 | image=data_sample['image'][0], 166 | gt=data_sample['conversations'][-1]['value'])) 167 | results = gather_object(results) 168 | # ious = gather_object(ious) 169 | if accelerator.is_main_process: 170 | accelerator.print(f"Collected {len(results)} result samples from all gpus") 171 | # 
accelerator.print(f"Average IoU on {json_file}: {sum(ious) / len(ious)}") 172 | with open(os.path.join(args.save_folder, os.path.basename(json_file)), 'w') as f: 173 | json.dump(results, f, indent=4) 174 | -------------------------------------------------------------------------------- /flmm/datasets/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from typing import Optional, Dict, Union, Tuple, List 3 | from PIL import Image 4 | import mmengine.fileio as fileio 5 | from mmengine.logging import print_log 6 | import io 7 | from mmcv.transforms import LoadImageFromFile, BaseTransform 8 | from xtuner.registry import BUILDER 9 | from xtuner.utils.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN 10 | import torch 11 | import torch.nn.functional as F 12 | import copy 13 | 14 | try: 15 | from petrel_client.client import Client 16 | except: 17 | Client = None 18 | 19 | 20 | class PILLoadImageFromFile(LoadImageFromFile): 21 | def __init__(self, **kwargs): 22 | backend_args = kwargs.pop('backend_args', None) 23 | if Client is None: 24 | backend_args = None 25 | super().__init__(backend_args=backend_args, **kwargs) 26 | 27 | def transform(self, results: dict) -> Optional[dict]: 28 | """Functions to load image. 29 | 30 | Args: 31 | results (dict): Result dict from 32 | :class:`mmengine.dataset.BaseDataset`. 33 | 34 | Returns: 35 | dict: The dict contains loaded image and meta information. 36 | """ 37 | 38 | filename = results['img_path'] 39 | try: 40 | if self.file_client_args is not None: 41 | file_client = fileio.FileClient.infer_client( 42 | self.file_client_args, filename) 43 | img_bytes = file_client.get(filename) 44 | else: 45 | img_bytes = fileio.get( 46 | filename, backend_args=self.backend_args) 47 | img = Image.open(io.BytesIO(img_bytes)) 48 | except Exception as e: 49 | if self.ignore_empty: 50 | return None 51 | else: 52 | raise e 53 | # in some cases, images are not read successfully, the img would be 54 | # `None`, refer to https://github.com/open-mmlab/mmpretrain/issues/1427 55 | assert img is not None, f'failed to load image: {filename}' 56 | results['img'] = img 57 | results['img_shape'] = (img.height, img.width) 58 | results['ori_shape'] = (img.height, img.width) 59 | return results 60 | 61 | 62 | class RefCOCO2PNG(BaseTransform): 63 | def __init__(self, 64 | image_processor=None, 65 | tokenizer=None, 66 | prompt_template=None, 67 | prompt='\nWhat is shown in this image?', 68 | concat=True, 69 | image2tensor=True, 70 | add_image_token=False, 71 | image_token=DEFAULT_IMAGE_TOKEN): 72 | self.tokenizer = BUILDER.build(tokenizer) 73 | self.image_processor = BUILDER.build(image_processor) 74 | self.concat = concat 75 | self.image2tensor = image2tensor 76 | self.image_token = image_token 77 | 78 | self.add_image_token = add_image_token 79 | if add_image_token: 80 | print_log(f"Manually add image token: {self.image_token}") 81 | special_tokens_dict = {'additional_special_tokens': [self.image_token, ]} 82 | num_added_toks = self.tokenizer.add_special_tokens(special_tokens_dict) 83 | assert num_added_toks == 1 84 | 85 | self.image_token_idx = self.tokenizer.encode(self.image_token, add_special_tokens=False)[-1] 86 | print_log(f"Image token: {self.tokenizer.decode(self.image_token_idx)}") 87 | 88 | self.prompt = self.tokenizer.encode( 89 | prompt_template['INSTRUCTION'].format(input=prompt), 90 | add_special_tokens=True) 91 | self.prompt_template = prompt_template 92 | 
93 | def transform(self, results): 94 | if self.concat: 95 | return self.transform_concat(results) 96 | else: 97 | return self.transform_split(results) 98 | 99 | def transform_split(self, results): 100 | all_results = [] 101 | for inst_id, instant_text in enumerate(results['text']): 102 | new_results = copy.deepcopy(results) 103 | new_results['text'] = [instant_text] 104 | new_results['gt_masks'] = results['gt_masks'][inst_id:inst_id+1] 105 | all_results.append(self.transform_concat(new_results)) 106 | 107 | return all_results 108 | 109 | def transform_concat(self, results: dict): 110 | 111 | caption_input_ids = [] 112 | mask_ids = [-1] * len(self.prompt) 113 | split_token_id = self.tokenizer.encode('.', add_special_tokens=False)[-1] 114 | 115 | for inst_id, instant_text in enumerate(results['text']): 116 | segment_input_ids = self.tokenizer.encode(instant_text, add_special_tokens=False) 117 | caption_input_ids += segment_input_ids 118 | mask_ids += [inst_id] * len(segment_input_ids) 119 | 120 | caption_input_ids.append(split_token_id) 121 | mask_ids.append(-1) 122 | 123 | input_ids = self.prompt + caption_input_ids 124 | input_ids = torch.tensor(input_ids, dtype=torch.long) 125 | mask_ids = torch.tensor(mask_ids) 126 | 127 | image = results['img'] 128 | image_data = self.image_processor.preprocess(image) 129 | 130 | pixel_values = image_data['pixel_values'][0] 131 | if self.image2tensor: 132 | pixel_values = torch.from_numpy(pixel_values) 133 | meta_data = image_data['meta_datas'][0] 134 | 135 | assert len(results['gt_masks'].masks) == len(results['text']) 136 | mask_cnt = len(results['text']) 137 | 138 | masks = torch.from_numpy(results['gt_masks'].masks).float() 139 | 140 | h, w = meta_data['image_shape']['height'], meta_data['image_shape']['width'] 141 | gt_masks = masks.clone() 142 | masks = F.interpolate(masks[None], size=(h, w))[0] 143 | 144 | p_h, p_w = meta_data['padded_shape']['height'], meta_data['padded_shape']['width'] 145 | 146 | padded_masks = torch.zeros(mask_cnt, p_h, p_w, dtype=masks.dtype) 147 | padding = meta_data['padding'] 148 | 149 | padded_masks[:, padding['before_height']:p_h - padding['after_height'], 150 | padding['before_width']:p_w - padding['after_width']] = masks 151 | 152 | # todo: add labels 153 | prompt_len = len(self.prompt) 154 | labels = torch.ones_like(input_ids) * IGNORE_INDEX 155 | labels[prompt_len:] = input_ids[prompt_len:] 156 | 157 | if self.add_image_token: 158 | input_ids[input_ids == self.image_token_idx] = IMAGE_TOKEN_INDEX 159 | 160 | return dict(input_ids=input_ids, 161 | mask_ids=mask_ids, 162 | pixel_values=pixel_values, 163 | padded_masks=padded_masks, 164 | masks=masks, # shape is kept 165 | gt_masks=gt_masks, 166 | image_sizes=torch.tensor(image_data['image_sizes'][0]), 167 | image=image, 168 | meta_data=meta_data, 169 | labels=labels) 170 | 171 | 172 | if __name__ == '__main__': 173 | from mmdet.datasets import RefCocoDataset 174 | from mmengine.config import Config 175 | from mmdet.datasets.transforms import LoadAnnotations 176 | 177 | cfg = Config.fromfile('configs/fuyu/frozen_fuyu_8b_unet_sam_l_refcoco_png.py') 178 | prompt_template = cfg.prompt_template 179 | tokenizer = cfg.tokenizer 180 | image_processor = cfg.image_processor 181 | prompt = cfg.get('prompt', None) 182 | 183 | refcoco2png_params = dict( 184 | type=RefCOCO2PNG, 185 | image_processor=image_processor, 186 | tokenizer=tokenizer, 187 | prompt_template=prompt_template, 188 | 189 | ) 190 | if prompt is not None: 191 | refcoco2png_params.update(prompt=prompt) 192 | 193 | 
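    # ------------------------------------------------------------------
    # Editor's note: illustrative sketch, not part of the original file and
    # never called. transform_concat() above resizes the ground-truth masks to
    # the processed image size and then pastes them into the padded (square)
    # frame using the `padding` entry of the processor's meta data. A minimal
    # stand-alone version of that paste step, with made-up sizes:
    def _pad_masks_demo():
        import torch
        import torch.nn.functional as F
        masks = (torch.rand(2, 480, 640) > 0.5).float()    # 2 masks, landscape image
        h, w, p_h, p_w = 96, 128, 128, 128                  # resized and padded sizes
        padding = dict(before_height=16, after_height=16,   # vertical letterboxing
                       before_width=0, after_width=0)
        masks = F.interpolate(masks[None], size=(h, w))[0]
        padded = torch.zeros(2, p_h, p_w)
        padded[:, padding['before_height']:p_h - padding['after_height'],
               padding['before_width']:p_w - padding['after_width']] = masks
        return padded
    # ------------------------------------------------------------------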
test_pipeline = [ 194 | dict(type=PILLoadImageFromFile, backend_args=None), 195 | dict( 196 | type=LoadAnnotations, 197 | with_mask=True, 198 | with_bbox=False, 199 | with_seg=False, 200 | with_label=False), 201 | refcoco2png_params 202 | ] 203 | 204 | dataset = RefCocoDataset( 205 | data_root='data/coco/', 206 | data_prefix=dict(img_path='train2014/'), 207 | text_mode='select_first', 208 | pipeline=test_pipeline, 209 | ann_file='refcoco/instances.json', 210 | split_file='refcoco/refs(unc).p', 211 | split='val' 212 | ) 213 | 214 | 215 | for data in dataset: 216 | print(data.keys()) 217 | -------------------------------------------------------------------------------- /deepseek_vl/models/image_processing_vlm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023-2024 DeepSeek. 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to 6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 7 | # the Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all 11 | # copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 15 | # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 16 | # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 17 | # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 18 | # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
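# ----------------------------------------------------------------------
# Editor's note: illustrative sketch added for documentation; it is not part
# of the original DeepSeek-VL file and is never called. expand2square(),
# defined further below, pads a resized image to a square canvas and records
# how much padding was added on each side; downstream code uses that meta
# information to paste ground-truth masks into the same padded frame. The toy
# function here only reproduces the padding arithmetic for a landscape image
# (width > height), with hypothetical sizes.
def _expand2square_meta_demo(width=640, height=480):
    before_height = (width - height) // 2             # rows added on top
    after_height = (width - height) - before_height   # rows added at the bottom
    return dict(padding=dict(before_height=before_height, after_height=after_height,
                             before_width=0, after_width=0),
                image_shape=dict(height=height, width=width),
                padded_shape=dict(height=width, width=width))
# ----------------------------------------------------------------------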
19 | 20 | from typing import List, Tuple, Union 21 | 22 | import numpy as np 23 | import torch 24 | import torchvision 25 | import torchvision.transforms.functional 26 | from PIL import Image 27 | from transformers import AutoImageProcessor, PretrainedConfig 28 | from transformers.image_processing_utils import BaseImageProcessor, BatchFeature 29 | from transformers.image_utils import to_numpy_array 30 | from transformers.utils import logging 31 | from flmm.utils import multi_apply 32 | 33 | logger = logging.get_logger(__name__) 34 | 35 | ImageType = Union[np.ndarray, torch.Tensor, Image.Image] 36 | IMAGENET_MEAN = (0.48145466, 0.4578275, 0.40821073) 37 | IMAGENET_STD = (0.26862954, 0.26130258, 0.27577711) 38 | IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5) 39 | IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5) 40 | 41 | 42 | def expand2square(pil_img, background_color): 43 | pil_img = pil_img.convert('RGB') 44 | width, height = pil_img.size 45 | if width == height: 46 | result = pil_img 47 | before_height = after_height = before_width = after_width = 0 48 | elif width > height: 49 | result = Image.new(pil_img.mode, (width, width), background_color) 50 | result.paste(pil_img, (0, (width - height) // 2)) 51 | before_height = (width - height) // 2 52 | after_height = (width - height) - before_height 53 | before_width = after_width = 0 54 | else: 55 | result = Image.new(pil_img.mode, (height, height), background_color) 56 | result.paste(pil_img, ((height - width) // 2, 0)) 57 | before_width = (height - width) // 2 58 | after_width = (height - width) - before_width 59 | before_height = after_height = 0 60 | 61 | meta = dict(padding=dict(before_height=before_height, after_height=after_height, 62 | before_width=before_width, after_width=after_width), 63 | image_shape=dict(height=height, width=width), 64 | padded_shape=dict(height=max(height, width), width=max(height, width))) 65 | 66 | return result, meta 67 | 68 | 69 | class VLMImageProcessorConfig(PretrainedConfig): 70 | model_type = "deepseek_vlm" 71 | image_size: int 72 | min_size: int 73 | image_mean: Union[Tuple[float, float, float], List[float]] 74 | image_std: Union[Tuple[float, float, float], List[float]] 75 | rescale_factor: float 76 | do_normalize: bool 77 | 78 | def __init__( 79 | self, 80 | image_size: int, 81 | min_size: int = 14, 82 | image_mean: Union[Tuple[float, float, float], List[float]] = ( 83 | 0.48145466, 84 | 0.4578275, 85 | 0.40821073, 86 | ), 87 | image_std: Union[Tuple[float, float, float], List[float]] = ( 88 | 0.26862954, 89 | 0.26130258, 90 | 0.27577711, 91 | ), 92 | rescale_factor: float = 1.0 / 255.0, 93 | do_normalize: bool = True, 94 | **kwargs, 95 | ): 96 | self.image_size = image_size 97 | self.min_size = min_size 98 | self.image_mean = image_mean 99 | self.image_std = image_std 100 | self.rescale_factor = rescale_factor 101 | self.do_normalize = do_normalize 102 | 103 | super().__init__(**kwargs) 104 | 105 | 106 | class VLMImageProcessor(BaseImageProcessor): 107 | model_input_names = ["pixel_values"] 108 | 109 | def __init__( 110 | self, 111 | image_size: int, 112 | min_size: int = 14, 113 | image_mean: Union[Tuple[float, float, float], List[float]] = ( 114 | 0.48145466, 115 | 0.4578275, 116 | 0.40821073, 117 | ), 118 | image_std: Union[Tuple[float, float, float], List[float]] = ( 119 | 0.26862954, 120 | 0.26130258, 121 | 0.27577711, 122 | ), 123 | rescale_factor: float = 1.0 / 255.0, 124 | do_normalize: bool = True, 125 | **kwargs, 126 | ): 127 | super().__init__(**kwargs) 128 | 129 | self.image_size = image_size 130 | 
self.rescale_factor = rescale_factor 131 | self.image_mean = image_mean 132 | self.image_std = image_std 133 | self.min_size = min_size 134 | self.do_normalize = do_normalize 135 | 136 | if image_mean is None: 137 | self.background_color = (127, 127, 127) 138 | else: 139 | self.background_color = tuple([int(x * 255) for x in image_mean]) 140 | 141 | def resize(self, pil_img: Image): 142 | """ 143 | 144 | Args: 145 | pil_img (PIL.Image): [H, W, 3] in PIL.Image in RGB 146 | 147 | Returns: 148 | x (np.ndarray): [3, self.image_size, self.image_size] 149 | """ 150 | 151 | width, height = pil_img.size 152 | max_size = max(width, height) 153 | 154 | size = [ 155 | max(int(height / max_size * self.image_size), self.min_size), 156 | max(int(width / max_size * self.image_size), self.min_size), 157 | ] 158 | 159 | if width <= 0 or height <= 0 or size[0] <= 0 or size[1] <= 0: 160 | print(f"orig size = {pil_img.size}, new size = {size}") 161 | raise ValueError("Invalid size!") 162 | 163 | pil_img = torchvision.transforms.functional.resize( 164 | pil_img, 165 | size, 166 | interpolation=torchvision.transforms.functional.InterpolationMode.BICUBIC, 167 | antialias=True, 168 | ) 169 | 170 | pil_img, meta = expand2square(pil_img, self.background_color) 171 | x = to_numpy_array(pil_img) 172 | 173 | # [H, W, 3] -> [3, H, W] 174 | x = np.transpose(x, (2, 0, 1)) 175 | 176 | return x, meta 177 | 178 | def preprocess(self, images, return_tensors=None, **kwargs): 179 | # resize and pad to [self.image_size, self.image_size] 180 | # then convert from [H, W, 3] to [3, H, W] 181 | # images: List[np.ndarray] = [self.resize(image) for image in images] 182 | if not isinstance(images, (list, tuple)): 183 | images = [images] 184 | image_sizes = [(image.height, image.width) for image in images] 185 | images, meta_datas = multi_apply(self.resize, images) 186 | 187 | # resacle from [0, 255] -> [0, 1] 188 | images = [ 189 | self.rescale( 190 | image=image, 191 | scale=self.rescale_factor, 192 | input_data_format="channels_first", 193 | ) 194 | for image in images 195 | ] 196 | 197 | # normalize 198 | if self.do_normalize: 199 | images = [ 200 | self.normalize( 201 | image=image, 202 | mean=self.image_mean, 203 | std=self.image_std, 204 | input_data_format="channels_first", 205 | ) 206 | for image in images 207 | ] 208 | 209 | data = {"pixel_values": images} 210 | if not return_tensors: 211 | data.update({"image_sizes": image_sizes, "meta_datas": meta_datas}) 212 | output = BatchFeature(data=data, tensor_type=return_tensors) 213 | if return_tensors: 214 | output.image_sizes = image_sizes 215 | output.meta_datas = meta_datas 216 | 217 | return output 218 | 219 | @property 220 | def default_shape(self): 221 | return [3, self.image_size, self.image_size] 222 | 223 | 224 | AutoImageProcessor.register(VLMImageProcessorConfig, VLMImageProcessor) 225 | 226 | 227 | if __name__ == "__main__": 228 | image_processor = VLMImageProcessor( 229 | image_size=1024, 230 | image_mean=IMAGENET_INCEPTION_MEAN, 231 | image_std=IMAGENET_INCEPTION_STD, 232 | do_normalize=True, 233 | ) 234 | -------------------------------------------------------------------------------- /deepseek_vl/models/clip_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023-2024 DeepSeek. 
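# Illustrative sketch (not part of this file): using the VLMImageProcessor from
# image_processing_vlm.py above. preprocess() resizes the longest edge to
# image_size, pads to a square with a background colour derived from image_mean,
# and, when return_tensors is not set, also returns the original image_sizes and
# per-image meta_datas recording the padding, which downstream code can use to
# map mask predictions back onto the unpadded image. 'example.jpg' is a placeholder.
#
#     from PIL import Image
#     processor = VLMImageProcessor(
#         image_size=1024,
#         image_mean=IMAGENET_INCEPTION_MEAN,
#         image_std=IMAGENET_INCEPTION_STD,
#         do_normalize=True,
#     )
#     out = processor.preprocess([Image.open('example.jpg')])
#     pixel_values = out['pixel_values'][0]      # numpy array, [3, 1024, 1024]
#     padding = out['meta_datas'][0]['padding']  # before/after height and width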
2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to 6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 7 | # the Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all 11 | # copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 15 | # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 16 | # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 17 | # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 18 | # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 19 | 20 | from typing import Dict, List, Literal, Optional, Tuple, Union 21 | 22 | import torch 23 | import torch.nn as nn 24 | import torchvision.transforms 25 | from einops import rearrange 26 | 27 | from deepseek_vl.models.sam import create_sam_vit 28 | from deepseek_vl.models.siglip_vit import create_siglip_vit 29 | 30 | 31 | class CLIPVisionTower(nn.Module): 32 | def __init__( 33 | self, 34 | model_name: str = "siglip_large_patch16_384", 35 | image_size: Union[Tuple[int, int], int] = 336, 36 | select_feature: str = "patch", 37 | select_layer: int = -2, 38 | select_layers: list = None, 39 | ckpt_path: str = "", 40 | pixel_mean: Optional[List[float]] = None, 41 | pixel_std: Optional[List[float]] = None, 42 | **kwargs, 43 | ): 44 | super().__init__() 45 | 46 | self.model_name = model_name 47 | self.select_feature = select_feature 48 | self.select_layer = select_layer 49 | self.select_layers = select_layers 50 | 51 | vision_tower_params = { 52 | "model_name": model_name, 53 | "image_size": image_size, 54 | "ckpt_path": ckpt_path, 55 | "select_layer": select_layer, 56 | } 57 | vision_tower_params.update(kwargs) 58 | self.vision_tower, self.forward_kwargs = self.build_vision_tower( 59 | vision_tower_params 60 | ) 61 | 62 | if pixel_mean is not None and pixel_std is not None: 63 | image_norm = torchvision.transforms.Normalize( 64 | mean=pixel_mean, std=pixel_std 65 | ) 66 | else: 67 | image_norm = None 68 | 69 | self.image_norm = image_norm 70 | 71 | def build_vision_tower(self, vision_tower_params): 72 | if self.model_name.startswith("siglip"): 73 | self.select_feature = "same" 74 | vision_tower = create_siglip_vit(**vision_tower_params) 75 | forward_kwargs = dict() 76 | 77 | elif self.model_name.startswith("sam"): 78 | vision_tower = create_sam_vit(**vision_tower_params) 79 | forward_kwargs = dict() 80 | 81 | else: # huggingface 82 | from transformers import CLIPVisionModel 83 | 84 | vision_tower = CLIPVisionModel.from_pretrained(**vision_tower_params) 85 | forward_kwargs = dict(output_hidden_states=True) 86 | 87 | return vision_tower, forward_kwargs 88 | 89 | def feature_select(self, image_forward_outs): 90 | if isinstance(image_forward_outs, torch.Tensor): 91 | # the output has been the self.select_layer"s features 92 | image_features = image_forward_outs 93 | else: 94 | image_features = image_forward_outs.hidden_states[self.select_layer] 95 | 96 | if self.select_feature == 
"patch": 97 | # if the output has cls_token 98 | image_features = image_features[:, 1:] 99 | elif self.select_feature == "cls_patch": 100 | image_features = image_features 101 | elif self.select_feature == "same": 102 | image_features = image_features 103 | 104 | else: 105 | raise ValueError(f"Unexpected select feature: {self.select_feature}") 106 | return image_features 107 | 108 | def forward(self, images): 109 | """ 110 | 111 | Args: 112 | images (torch.Tensor): [b, 3, H, W] 113 | 114 | Returns: 115 | image_features (torch.Tensor): [b, n_patch, d] 116 | """ 117 | 118 | if self.image_norm is not None: 119 | images = self.image_norm(images) 120 | 121 | image_forward_outs = self.vision_tower(images, **self.forward_kwargs) 122 | image_features = self.feature_select(image_forward_outs) 123 | return image_features 124 | 125 | 126 | class HybridVisionTower(nn.Module): 127 | def __init__( 128 | self, 129 | high_res_cfg: Dict, 130 | low_res_cfg: Dict, 131 | freeze_high: bool = False, 132 | freeze_low: bool = False, 133 | concat_type: Literal["feature", "sequence", "add", "tuple"] = "tuple", 134 | **ignore_kwargs, 135 | ): 136 | super().__init__() 137 | 138 | self.vision_tower_high = CLIPVisionTower(**high_res_cfg) 139 | self.vision_tower_low = CLIPVisionTower(**low_res_cfg) 140 | self.low_res_size = low_res_cfg["image_size"] 141 | self.concat_type = concat_type 142 | 143 | self.high_layer_norm = nn.LayerNorm(high_res_cfg.get("output_dim", 1024)) 144 | self.low_layer_norm = nn.LayerNorm(low_res_cfg.get("output_dim", 1024)) 145 | 146 | if freeze_high: 147 | for p_name, p in self.vision_tower_high.named_parameters(): 148 | p.requires_grad = False 149 | self.vision_tower_high = self.vision_tower_high.eval() 150 | else: 151 | # train donwsamples and neck 152 | for p_name, p in self.vision_tower_high.named_parameters(): 153 | if "downsamples" in p_name or "neck" in p_name: 154 | p.requires_grad = True 155 | else: 156 | p.requires_grad = False 157 | 158 | if freeze_low: 159 | for p in self.vision_tower_low.parameters(): 160 | p.requires_grad = False 161 | self.vision_tower_low = self.vision_tower_low.eval() 162 | 163 | self.resize = torchvision.transforms.Resize(self.low_res_size, antialias=True) 164 | 165 | def forward(self, images: torch.Tensor): 166 | """ 167 | 168 | Args: 169 | images (torch.Tensor): [bs, 3, H, W] 170 | 171 | Returns: 172 | res (torch.Tensor): [bs, t, c] 173 | """ 174 | 175 | # [bs, c, h, w] 176 | high_images = images 177 | 178 | # [bs, c, h_low, w_low] 179 | low_images = self.resize(images) 180 | 181 | # separately run two vision towers 182 | # run high_res vision tower 183 | high_res = self.vision_tower_high(high_images) 184 | # [bs, c, h, w] -> [bs, h*w, c] 185 | high_res = rearrange(high_res, "b c h w -> b (h w) c") 186 | # run low_res vision tower 187 | low_res = self.vision_tower_low(low_images) 188 | 189 | if self.concat_type == "feature": 190 | images_features = torch.cat([high_res, low_res], dim=-1) 191 | elif self.concat_type == "sequence": 192 | images_features = torch.cat([high_res, low_res], dim=1) 193 | elif self.concat_type == "add": 194 | images_features = high_res + low_res 195 | elif self.concat_type == "tuple": 196 | images_features = (high_res, low_res) 197 | 198 | else: 199 | raise ValueError( 200 | "Currently only support `feature`, `sequence`, `add` and `tuple` concat type." 
201 | ) 202 | 203 | return images_features 204 | 205 | 206 | if __name__ == "__main__": 207 | image_size = 1024 208 | x = torch.zeros(2, 3, image_size, image_size).bfloat16().cuda() 209 | 210 | high_res_cfg = dict( 211 | model_name="sam_b_downsample", 212 | select_feature="same", 213 | image_size=image_size, 214 | pixel_mean=(0.48145466, 0.4578275, 0.40821073), 215 | pixel_std=(0.26862954, 0.26130258, 0.27577711), 216 | select_layer=-1, 217 | ckpt_path="", 218 | ) 219 | 220 | low_res_cfg = dict( 221 | model_name="siglip_large_patch16_384", 222 | select_feature="same", 223 | image_size=384, 224 | pixel_mean=(0.5, 0.5, 0.5), 225 | pixel_std=(0.5, 0.5, 0.5), 226 | select_layer=-1, 227 | ckpt_path="", 228 | ) 229 | 230 | net = ( 231 | HybridVisionTower( 232 | high_res_cfg=high_res_cfg, 233 | low_res_cfg=low_res_cfg, 234 | freeze_high=True, 235 | freeze_low=True, 236 | concat_type="tuple", 237 | ) 238 | .bfloat16() 239 | .cuda() 240 | ) 241 | high_x, low_x = net(x) 242 | print(x.shape, high_x.shape, low_x.shape) 243 | -------------------------------------------------------------------------------- /mgm/model/multimodal_encoder/openclip_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import os 5 | import json 6 | import logging 7 | import deepspeed 8 | from pathlib import Path 9 | from open_clip.factory import load_state_dict, get_model_config 10 | from open_clip.model import CLIPVisionCfg, CLIPTextCfg, _build_vision_tower, convert_to_custom_text_state_dict, resize_pos_embed 11 | from typing import Dict, Optional 12 | from transformers.deepspeed import deepspeed_config, is_deepspeed_zero3_enabled 13 | 14 | 15 | class OpenCLIPVisionTower(nn.Module): 16 | def __init__(self, vision_tower, args, delay_load=False): 17 | super().__init__() 18 | 19 | self.is_loaded = False 20 | self.vision_tower_name = vision_tower 21 | self.vision_config = json.load(open(os.path.join(vision_tower,'open_clip_config.json'), 'r')) 22 | self.is_optimize = getattr(args, 'optimize_vision_tower_aux', False) 23 | self.is_droppath = getattr(args, 'drop_path', True) 24 | 25 | if not delay_load: 26 | self.load_model() 27 | 28 | def load_model(self): 29 | ckpt_path = os.path.join(self.vision_tower_name, 'open_clip_pytorch_model.bin') 30 | if 'convnext' in self.vision_tower_name: 31 | if 'large' in self.vision_tower_name and ('d-320' in self.vision_tower_name 32 | or 'd_320' in self.vision_tower_name): 33 | self.model_type = 'convnext_large_d_320' 34 | self.model_channel = [192, 384, 768, 1536] # stage 0-3 35 | elif 'base' in self.vision_tower_name and 'w-320' in self.vision_tower_name: 36 | self.model_type = 'convnext_base_w_320' 37 | self.model_channel = [128, 256, 512, 1024] 38 | elif 'xxlarge' in self.vision_tower_name: 39 | self.model_type = 'convnext_xxlarge' 40 | self.model_channel = [384, 768, 1536, 3072] 41 | 42 | clip_model = CLIP(**get_model_config(self.model_type), drop_path=self.is_droppath) 43 | clip_model.visual.trunk.norm_pre = None 44 | clip_model.visual.trunk.head = None 45 | clip_model.visual.head = None 46 | print(f'Loading pretrained weights ({self.model_type}).') 47 | load_checkpoint(clip_model, ckpt_path, strict=False) 48 | 49 | self.is_loaded = True 50 | # decompose stem and stages blocks in vision tower 51 | self.vision_stem = clip_model.visual.trunk.stem 52 | self.vision_stages = clip_model.visual.trunk.stages 53 | self.vision_stem.requires_grad_(False) 54 | 
self.vision_stages.requires_grad_(False) 55 | 56 | def forward(self, images): 57 | if type(images) is list: 58 | image_features = [] 59 | for image in images: 60 | image_feature = self.backbone(image.to(device=self.device, dtype=self.dtype).unsqueeze(0)) 61 | image_features.append(image_feature) 62 | else: 63 | image_features = self.backbone(images.to(device=self.device, dtype=self.dtype)) 64 | 65 | return image_features 66 | 67 | def backbone(self, images): 68 | if not self.is_optimize: 69 | with torch.no_grad(): 70 | results = self.basic_forward(images) 71 | else: 72 | results = self.basic_forward(images) 73 | 74 | target_size = (results['stage_0'].shape[-2], results['stage_0'].shape[-1]) 75 | result_cat = [] 76 | for _stage in results: 77 | if _stage == 'stage_0': 78 | result_cat.append(results[_stage].contiguous()) 79 | else: 80 | result_cat.append(F.interpolate(results[_stage].float().contiguous() , 81 | size=target_size, 82 | mode='bilinear', 83 | align_corners=False).to(dtype=results[_stage].dtype)) 84 | result_cat = torch.cat(result_cat, dim=1) 85 | 86 | return result_cat.contiguous() 87 | 88 | def basic_forward(self, images): 89 | results = {} 90 | x = self.vision_stem(images) 91 | for _idx in range(len(self.vision_stages)): 92 | x = self.vision_stages[_idx](x) 93 | results[f'stage_{_idx}'] = x 94 | return results 95 | 96 | @property 97 | def dummy_feature(self): 98 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) 99 | 100 | @property 101 | def dtype(self): 102 | return self.vision_stem[0].weight.dtype 103 | 104 | @property 105 | def device(self): 106 | return self.vision_stem[0].weight.device 107 | 108 | @property 109 | def config(self): 110 | return self.vision_config 111 | 112 | @property 113 | def hidden_size(self): 114 | return sum(self.model_channel) 115 | 116 | # modified function from open_clip to support zero3 stage 117 | def load_checkpoint(model, checkpoint_path, strict=True): 118 | if Path(checkpoint_path).suffix in ('.npz', '.npy'): 119 | from open_clip.big_vision import load_big_vision_weights 120 | load_big_vision_weights(model, checkpoint_path) 121 | return {} 122 | 123 | state_dict = load_state_dict(checkpoint_path) 124 | # detect old format and make compatible with new format 125 | if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'): 126 | state_dict = convert_to_custom_text_state_dict(state_dict) 127 | # If loading a non-SigLIP model for SigLIP training. See https://github.com/mlfoundations/open_clip/issues/712 128 | # if 'logit_bias' not in state_dict and model.logit_bias is not None: 129 | # state_dict["logit_bias"] = torch.zeros_like(state_dict["logit_scale"]) 130 | # Certain text transformers no longer expect position_ids after transformers==4.31 131 | position_id_key = 'text.transformer.embeddings.position_ids' 132 | if position_id_key in state_dict and not hasattr(model, position_id_key): 133 | del state_dict[position_id_key] 134 | resize_pos_embed(state_dict, model) 135 | # resize_text_pos_embed(state_dict, model) 136 | #incompatible_keys = model.load_state_dict(state_dict, strict=strict) 137 | if is_deepspeed_zero3_enabled(): 138 | 139 | error_msgs = [] 140 | 141 | def load(module: nn.Module, state_dict, prefix=""): 142 | metadata = None 143 | 144 | local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) 145 | args = (state_dict, prefix, local_metadata, True, [], [], error_msgs) 146 | # Parameters of module and children will start with prefix. 
We can exit early if there are none in this 147 | # state_dict 148 | if len([key for key in state_dict if key.startswith(prefix)]) > 0: 149 | if is_deepspeed_zero3_enabled(): 150 | # In sharded models, each shard has only part of the full state_dict, so only gather 151 | # parameters that are in the current state_dict. 152 | named_parameters = dict(module.named_parameters(prefix=prefix[:-1], recurse=False)) 153 | params_to_gather = [named_parameters[k] for k in state_dict.keys() if k in named_parameters] 154 | if len(params_to_gather) > 0: 155 | # because zero3 puts placeholders in model params, this context 156 | # manager gathers (unpartitions) the params of the current layer, then loads from 157 | # the state dict and then re-partitions them again 158 | with deepspeed.zero.GatheredParameters(params_to_gather, modifier_rank=0): 159 | if torch.distributed.get_rank() == 0: 160 | module._load_from_state_dict(*args) 161 | else: 162 | module._load_from_state_dict(*args) 163 | 164 | for name, child in module._modules.items(): 165 | if child is not None: 166 | load(child, state_dict, prefix + name + ".") 167 | 168 | load(model, state_dict) 169 | incompatible_keys = [] 170 | else: 171 | incompatible_keys = model.load_state_dict(state_dict, strict=strict) 172 | logging.info(f"incompatible_keys.missing_keys: {incompatible_keys.missing_keys}") 173 | return incompatible_keys 174 | 175 | class CLIP(nn.Module): 176 | output_dict: torch.jit.Final[bool] 177 | 178 | def __init__( 179 | self, 180 | embed_dim: int, 181 | vision_cfg: CLIPVisionCfg, 182 | text_cfg: CLIPTextCfg, 183 | quick_gelu: bool = False, 184 | cast_dtype: Optional[torch.dtype] = None, 185 | output_dict: bool = False, 186 | drop_path: bool = False, 187 | ): 188 | super().__init__() 189 | self.output_dict = output_dict 190 | 191 | # Fix drop path during training 192 | if not drop_path: 193 | print('Not using drop path during training.') 194 | vision_cfg['timm_drop_path'] = 0.0 195 | 196 | self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype) 197 | -------------------------------------------------------------------------------- /segment_anything/modeling/transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torch import Tensor, nn 9 | 10 | import math 11 | from typing import Tuple, Type 12 | 13 | from .common import MLPBlock 14 | 15 | 16 | class TwoWayTransformer(nn.Module): 17 | def __init__( 18 | self, 19 | depth: int, 20 | embedding_dim: int, 21 | num_heads: int, 22 | mlp_dim: int, 23 | activation: Type[nn.Module] = nn.ReLU, 24 | attention_downsample_rate: int = 2, 25 | ) -> None: 26 | """ 27 | A transformer decoder that attends to an input image using 28 | queries whose positional embedding is supplied. 29 | 30 | Args: 31 | depth (int): number of layers in the transformer 32 | embedding_dim (int): the channel dimension for the input embeddings 33 | num_heads (int): the number of heads for multihead attention. 
Must 34 | divide embedding_dim 35 | mlp_dim (int): the channel dimension internal to the MLP block 36 | activation (nn.Module): the activation to use in the MLP block 37 | """ 38 | super().__init__() 39 | self.depth = depth 40 | self.embedding_dim = embedding_dim 41 | self.num_heads = num_heads 42 | self.mlp_dim = mlp_dim 43 | self.layers = nn.ModuleList() 44 | 45 | for i in range(depth): 46 | self.layers.append( 47 | TwoWayAttentionBlock( 48 | embedding_dim=embedding_dim, 49 | num_heads=num_heads, 50 | mlp_dim=mlp_dim, 51 | activation=activation, 52 | attention_downsample_rate=attention_downsample_rate, 53 | skip_first_layer_pe=(i == 0), 54 | ) 55 | ) 56 | 57 | self.final_attn_token_to_image = Attention( 58 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 59 | ) 60 | self.norm_final_attn = nn.LayerNorm(embedding_dim) 61 | 62 | def forward( 63 | self, 64 | image_embedding: Tensor, 65 | image_pe: Tensor, 66 | point_embedding: Tensor, 67 | ) -> Tuple[Tensor, Tensor]: 68 | """ 69 | Args: 70 | image_embedding (torch.Tensor): image to attend to. Should be shape 71 | B x embedding_dim x h x w for any h and w. 72 | image_pe (torch.Tensor): the positional encoding to add to the image. Must 73 | have the same shape as image_embedding. 74 | point_embedding (torch.Tensor): the embedding to add to the query points. 75 | Must have shape B x N_points x embedding_dim for any N_points. 76 | 77 | Returns: 78 | torch.Tensor: the processed point_embedding 79 | torch.Tensor: the processed image_embedding 80 | """ 81 | # BxCxHxW -> BxHWxC == B x N_image_tokens x C 82 | bs, c, h, w = image_embedding.shape 83 | image_embedding = image_embedding.flatten(2).permute(0, 2, 1) 84 | image_pe = image_pe.flatten(2).permute(0, 2, 1) 85 | 86 | # Prepare queries 87 | queries = point_embedding 88 | keys = image_embedding 89 | 90 | # Apply transformer blocks and final layernorm 91 | for layer in self.layers: 92 | queries, keys = layer( 93 | queries=queries, 94 | keys=keys, 95 | query_pe=point_embedding, 96 | key_pe=image_pe, 97 | ) 98 | 99 | # Apply the final attention layer from the points to the image 100 | q = queries + point_embedding 101 | k = keys + image_pe 102 | attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) 103 | queries = queries + attn_out 104 | queries = self.norm_final_attn(queries) 105 | 106 | return queries, keys 107 | 108 | 109 | class TwoWayAttentionBlock(nn.Module): 110 | def __init__( 111 | self, 112 | embedding_dim: int, 113 | num_heads: int, 114 | mlp_dim: int = 2048, 115 | activation: Type[nn.Module] = nn.ReLU, 116 | attention_downsample_rate: int = 2, 117 | skip_first_layer_pe: bool = False, 118 | ) -> None: 119 | """ 120 | A transformer block with four layers: (1) self-attention of sparse 121 | inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp 122 | block on sparse inputs, and (4) cross attention of dense inputs to sparse 123 | inputs. 
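        (Within this file the block is stacked `depth` times by TwoWayTransformer
        above, which sets skip_first_layer_pe=True only for its first layer.)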
124 | 125 | Arguments: 126 | embedding_dim (int): the channel dimension of the embeddings 127 | num_heads (int): the number of heads in the attention layers 128 | mlp_dim (int): the hidden dimension of the mlp block 129 | activation (nn.Module): the activation of the mlp block 130 | skip_first_layer_pe (bool): skip the PE on the first layer 131 | """ 132 | super().__init__() 133 | self.self_attn = Attention(embedding_dim, num_heads) 134 | self.norm1 = nn.LayerNorm(embedding_dim) 135 | 136 | self.cross_attn_token_to_image = Attention( 137 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 138 | ) 139 | self.norm2 = nn.LayerNorm(embedding_dim) 140 | 141 | self.mlp = MLPBlock(embedding_dim, mlp_dim, activation) 142 | self.norm3 = nn.LayerNorm(embedding_dim) 143 | 144 | self.norm4 = nn.LayerNorm(embedding_dim) 145 | self.cross_attn_image_to_token = Attention( 146 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 147 | ) 148 | 149 | self.skip_first_layer_pe = skip_first_layer_pe 150 | 151 | def forward( 152 | self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor 153 | ) -> Tuple[Tensor, Tensor]: 154 | # Self attention block 155 | if self.skip_first_layer_pe: 156 | queries = self.self_attn(q=queries, k=queries, v=queries) 157 | else: 158 | q = queries + query_pe 159 | attn_out = self.self_attn(q=q, k=q, v=queries) 160 | queries = queries + attn_out 161 | queries = self.norm1(queries) 162 | 163 | # Cross attention block, tokens attending to image embedding 164 | q = queries + query_pe 165 | k = keys + key_pe 166 | attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) 167 | queries = queries + attn_out 168 | queries = self.norm2(queries) 169 | 170 | # MLP block 171 | mlp_out = self.mlp(queries) 172 | queries = queries + mlp_out 173 | queries = self.norm3(queries) 174 | 175 | # Cross attention block, image embedding attending to tokens 176 | q = queries + query_pe 177 | k = keys + key_pe 178 | attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries) 179 | keys = keys + attn_out 180 | keys = self.norm4(keys) 181 | 182 | return queries, keys 183 | 184 | 185 | class Attention(nn.Module): 186 | """ 187 | An attention layer that allows for downscaling the size of the embedding 188 | after projection to queries, keys, and values. 189 | """ 190 | 191 | def __init__( 192 | self, 193 | embedding_dim: int, 194 | num_heads: int, 195 | downsample_rate: int = 1, 196 | ) -> None: 197 | super().__init__() 198 | self.embedding_dim = embedding_dim 199 | self.internal_dim = embedding_dim // downsample_rate 200 | self.num_heads = num_heads 201 | assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim." 
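        # For example, with embedding_dim=256 and downsample_rate=2 the internal
        # dimension is 128: the q/k/v projections below map 256 -> 128, attention
        # runs with num_heads heads of 128 // num_heads channels each, and
        # out_proj maps 128 back up to 256.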
202 | 203 | self.q_proj = nn.Linear(embedding_dim, self.internal_dim) 204 | self.k_proj = nn.Linear(embedding_dim, self.internal_dim) 205 | self.v_proj = nn.Linear(embedding_dim, self.internal_dim) 206 | self.out_proj = nn.Linear(self.internal_dim, embedding_dim) 207 | 208 | def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: 209 | b, n, c = x.shape 210 | x = x.reshape(b, n, num_heads, c // num_heads) 211 | return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head 212 | 213 | def _recombine_heads(self, x: Tensor) -> Tensor: 214 | b, n_heads, n_tokens, c_per_head = x.shape 215 | x = x.transpose(1, 2) 216 | return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C 217 | 218 | def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: 219 | # Input projections 220 | q = self.q_proj(q) 221 | k = self.k_proj(k) 222 | v = self.v_proj(v) 223 | 224 | # Separate into heads 225 | q = self._separate_heads(q, self.num_heads) 226 | k = self._separate_heads(k, self.num_heads) 227 | v = self._separate_heads(v, self.num_heads) 228 | 229 | # Attention 230 | _, _, _, c_per_head = q.shape 231 | attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens 232 | attn = attn / math.sqrt(c_per_head) 233 | attn = torch.softmax(attn, dim=-1) 234 | 235 | # Get output 236 | out = attn @ v 237 | out = self._recombine_heads(out) 238 | out = self.out_proj(out) 239 | 240 | return out 241 | -------------------------------------------------------------------------------- /flmm/datasets/llava_processors.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Optional, Union 2 | 3 | import numpy as np 4 | 5 | from transformers.image_processing_utils import BatchFeature, get_size_dict 6 | from transformers.image_transforms import ( 7 | convert_to_rgb, 8 | get_resize_output_image_size, 9 | resize, 10 | to_channel_dimension_format, 11 | ) 12 | from transformers.image_utils import ( 13 | ChannelDimension, 14 | ImageInput, 15 | PILImageResampling, 16 | get_image_size, 17 | infer_channel_dimension_format, 18 | is_scaled_image, 19 | make_list_of_images, 20 | to_numpy_array, 21 | valid_images, 22 | validate_kwargs, 23 | validate_preprocess_arguments, 24 | ) 25 | from transformers.utils import TensorType 26 | from transformers.models.clip.image_processing_clip import logger, CLIPImageProcessor 27 | from flmm.utils import multi_apply 28 | 29 | 30 | class CustomLlavaImageProcessor(CLIPImageProcessor): 31 | 32 | def resize( 33 | self, 34 | image: np.ndarray, 35 | size: Dict[str, int], 36 | resample: PILImageResampling = PILImageResampling.BICUBIC, 37 | data_format: Optional[Union[str, ChannelDimension]] = None, 38 | input_data_format: Optional[Union[str, ChannelDimension]] = None, 39 | **kwargs, 40 | ) -> np.ndarray: 41 | """ 42 | Resize an image. The shortest edge of the image is resized to size["shortest_edge"], with the longest edge 43 | resized to keep the input aspect ratio. 44 | 45 | Args: 46 | image (`np.ndarray`): 47 | Image to resize. 48 | size (`Dict[str, int]`): 49 | Size of the output image. 50 | resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`): 51 | Resampling filter to use when resiizing the image. 52 | data_format (`str` or `ChannelDimension`, *optional*): 53 | The channel dimension format of the image. If not provided, it will be the same as the input image. 54 | input_data_format (`ChannelDimension` or `str`, *optional*): 55 | The channel dimension format of the input image. 
If not provided, it will be inferred. 56 | """ 57 | default_to_square = True 58 | if "shortest_edge" in size: 59 | size = size["shortest_edge"] 60 | default_to_square = False 61 | # customization: force the largest edge to size 62 | h, w = get_image_size(image, channel_dim=input_data_format) 63 | if h > w: 64 | size = (size, int(w * size / h)) 65 | else: 66 | size = (int(h * size / w), size) 67 | elif "height" in size and "width" in size: 68 | size = (size["height"], size["width"]) 69 | else: 70 | raise ValueError("Size must contain either 'shortest_edge' or 'height' and 'width'.") 71 | 72 | output_size = get_resize_output_image_size( 73 | image, 74 | size=size, 75 | default_to_square=default_to_square, 76 | input_data_format=input_data_format, 77 | ) 78 | return resize( 79 | image, 80 | size=output_size, 81 | resample=resample, 82 | data_format=data_format, 83 | input_data_format=input_data_format, 84 | **kwargs, 85 | ) 86 | 87 | def preprocess( 88 | self, 89 | images: ImageInput, 90 | do_resize: bool = None, 91 | size: Dict[str, int] = None, 92 | resample: PILImageResampling = None, 93 | do_center_crop: bool = None, 94 | crop_size: int = None, 95 | do_rescale: bool = None, 96 | rescale_factor: float = None, 97 | do_normalize: bool = None, 98 | image_mean: Optional[Union[float, List[float]]] = None, 99 | image_std: Optional[Union[float, List[float]]] = None, 100 | do_convert_rgb: bool = None, 101 | return_tensors: Optional[Union[str, TensorType]] = None, 102 | data_format: Optional[ChannelDimension] = ChannelDimension.FIRST, 103 | input_data_format: Optional[Union[str, ChannelDimension]] = None, 104 | **kwargs, 105 | ): 106 | do_resize = do_resize if do_resize is not None else self.do_resize 107 | size = size if size is not None else self.size 108 | size = get_size_dict(size, param_name="size", default_to_square=False) 109 | resample = resample if resample is not None else self.resample 110 | do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop 111 | crop_size = crop_size if crop_size is not None else self.crop_size 112 | crop_size = get_size_dict(crop_size, param_name="crop_size", default_to_square=True) 113 | do_rescale = do_rescale if do_rescale is not None else self.do_rescale 114 | rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor 115 | do_normalize = do_normalize if do_normalize is not None else self.do_normalize 116 | image_mean = image_mean if image_mean is not None else self.image_mean 117 | image_std = image_std if image_std is not None else self.image_std 118 | do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb 119 | 120 | validate_kwargs(captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_processor_keys) 121 | 122 | images = make_list_of_images(images) 123 | 124 | if not valid_images(images): 125 | raise ValueError( 126 | "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " 127 | "torch.Tensor, tf.Tensor or jax.ndarray." 128 | ) 129 | validate_preprocess_arguments( 130 | do_rescale=do_rescale, 131 | rescale_factor=rescale_factor, 132 | do_normalize=do_normalize, 133 | image_mean=image_mean, 134 | image_std=image_std, 135 | do_center_crop=do_center_crop, 136 | crop_size=crop_size, 137 | do_resize=do_resize, 138 | size=size, 139 | resample=resample, 140 | ) 141 | 142 | if do_convert_rgb: 143 | images = [convert_to_rgb(image) for image in images] 144 | 145 | # All transformations expect numpy arrays. 
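        # From here the per-image flow is: convert to numpy, record the original
        # (height, width) in image_sizes, resize so the longest edge matches
        # `size` (the customization above), pad to a square filled with the mean
        # colour while collecting the padding in meta_datas, then rescale,
        # normalize and move channels first; center crop is intentionally skipped.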
146 | images = [to_numpy_array(image) for image in images] 147 | 148 | if is_scaled_image(images[0]) and do_rescale: 149 | logger.warning_once( 150 | "It looks like you are trying to rescale already rescaled images. If the input" 151 | " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." 152 | ) 153 | 154 | if input_data_format is None: 155 | # We assume that all images have the same channel dimension format. 156 | input_data_format = infer_channel_dimension_format(images[0]) 157 | 158 | image_sizes = [get_image_size(image, channel_dim=input_data_format) for image in images] 159 | 160 | if do_resize: 161 | images = [ 162 | self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format) 163 | for image in images 164 | ] 165 | 166 | # we do not apppy center crop 167 | # if do_center_crop: 168 | # images = [ 169 | # self.center_crop(image=image, size=crop_size, input_data_format=input_data_format) for image in images 170 | # ] 171 | 172 | images, meta_datas = multi_apply(self.pad, images) 173 | 174 | if do_rescale: 175 | images = [ 176 | self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) 177 | for image in images 178 | ] 179 | 180 | if do_normalize: 181 | images = [ 182 | self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format) 183 | for image in images 184 | ] 185 | 186 | images = [ 187 | to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images 188 | ] 189 | 190 | data = {"pixel_values": images, "image_sizes": image_sizes, "meta_datas": meta_datas} 191 | 192 | return BatchFeature(data=data, tensor_type=return_tensors) 193 | 194 | 195 | def pad(self, image): 196 | pad_value = np.array(tuple(int(x * 255) for x in self.image_mean), dtype=image.dtype) 197 | assert isinstance(image, np.ndarray) 198 | h, w, _ = image.shape 199 | size = max(h, w) 200 | new_image = np.ones((size, size, 3), dtype=image.dtype) * pad_value 201 | 202 | pad_height, pad_width = size - h, size - w 203 | before_height, before_width = pad_height // 2, pad_width // 2 204 | after_height, after_width = pad_height - before_height, pad_width - before_width 205 | 206 | new_image[before_height:size-after_height, before_width:size-after_width] = image 207 | 208 | meta = dict(padding=dict(before_height=before_height, after_height=after_height, 209 | before_width=before_width, after_width=after_width), 210 | image_shape=dict(height=h, width=w), 211 | padded_shape=dict(height=size, width=size)) 212 | 213 | return new_image, meta 214 | --------------------------------------------------------------------------------
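As a rough usage sketch of the CustomLlavaImageProcessor above (not part of the repository): it can typically be instantiated from an existing CLIP preprocessing config via from_pretrained, and the meta_datas it returns make it possible to crop a prediction made at the padded resolution back to the resized image region. The checkpoint name, image path and mask below are placeholders.

import numpy as np
from PIL import Image
from flmm.datasets.llava_processors import CustomLlavaImageProcessor

# Placeholder checkpoint: any CLIP-style image-processor config should work here.
processor = CustomLlavaImageProcessor.from_pretrained('openai/clip-vit-large-patch14-336')

out = processor.preprocess(Image.open('example.jpg'), return_tensors=None)
pixel_values = out['pixel_values'][0]   # [3, S, S] padded square input
meta = out['meta_datas'][0]             # padding bookkeeping produced by pad()

# Undo the square padding on a placeholder prediction made at the padded
# resolution; meta['image_shape'] is the post-resize (pre-padding) shape,
# while out['image_sizes'] keeps the original input size.
pad, padded = meta['padding'], meta['padded_shape']
pred = np.zeros((padded['height'], padded['width']), dtype=np.uint8)
unpadded = pred[pad['before_height']:padded['height'] - pad['after_height'],
                pad['before_width']:padded['width'] - pad['after_width']]
assert unpadded.shape == (meta['image_shape']['height'], meta['image_shape']['width'])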