├── .gitattributes ├── .gitignore ├── README.md ├── __init__.py ├── deepseek_vl ├── __init__.py ├── models │ ├── __init__.py │ ├── clip_encoder.py │ ├── image_processing_vlm.py │ ├── modeling_vlm.py │ ├── processing_vlm.py │ ├── projector.py │ ├── sam.py │ └── siglip_vit.py └── utils │ ├── __init__.py │ ├── conversation.py │ └── io.py ├── nodes.py └── requirements.txt /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | checkpoints/ 3 | *.py[cod] 4 | *$py.class 5 | *.egg-info 6 | .pytest_cache -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ComfyUI nodes to use DeepSeek-VL 2 | 3 | https://huggingface.co/deepseek-ai 4 | 5 | ![image](https://github.com/kijai/ComfyUI-DeepSeek-VL/assets/40791699/6a1d7872-7960-48f3-a079-21903761eddb) 6 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | from .nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS 2 | 3 | __all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"] -------------------------------------------------------------------------------- /deepseek_vl/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023-2024 DeepSeek. 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to 6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 7 | # the Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all 11 | # copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 15 | # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 16 | # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 17 | # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 18 | # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 19 | 20 | 21 | # check if python version is above 3.10 22 | import sys 23 | 24 | if sys.version_info >= (3, 10): 25 | print("Python version is above 3.10, patching the collections module.") 26 | # Monkey patch collections 27 | import collections 28 | import collections.abc 29 | 30 | for type_name in collections.abc.__all__: 31 | setattr(collections, type_name, getattr(collections.abc, type_name)) 32 | -------------------------------------------------------------------------------- /deepseek_vl/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023-2024 DeepSeek. 
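The `deepseek_vl/__init__.py` patch above exists because Python 3.10 removed the ABC aliases (`collections.Mapping`, `collections.MutableMapping`, …) that older dependencies such as `attrdict` (imported by `modeling_vlm.py` and `projector.py`) still pull from `collections` rather than `collections.abc`. A minimal sketch of what the patch restores, using only the standard library:

```python
import collections
import collections.abc

# Python >= 3.10 dropped these aliases from `collections`; older packages
# (e.g. attrdict) still do `from collections import Mapping` and would fail.
for type_name in collections.abc.__all__:
    setattr(collections, type_name, getattr(collections.abc, type_name))

# After the patch the legacy attribute resolves again:
assert collections.Mapping is collections.abc.Mapping
print(isinstance({}, collections.Mapping))  # True
```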
2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to 6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 7 | # the Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all 11 | # copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 15 | # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 16 | # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 17 | # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 18 | # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 19 | 20 | from .image_processing_vlm import VLMImageProcessor 21 | from .modeling_vlm import MultiModalityCausalLM 22 | from .processing_vlm import VLChatProcessor 23 | 24 | __all__ = [ 25 | "VLMImageProcessor", 26 | "VLChatProcessor", 27 | "MultiModalityCausalLM", 28 | ] 29 | -------------------------------------------------------------------------------- /deepseek_vl/models/clip_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023-2024 DeepSeek. 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to 6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 7 | # the Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all 11 | # copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 15 | # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 16 | # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 17 | # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 18 | # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
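The three exports above (`VLMImageProcessor`, `VLChatProcessor`, `MultiModalityCausalLM`) are the pieces a caller needs for inference. A hedged sketch of the usual loading pattern, assuming a checkpoint laid out like the Hugging Face `deepseek-ai/deepseek-vl-1.3b-chat` repo (the checkpoint id and dtype choice here are illustrative, not part of this codebase):

```python
import torch
from transformers import AutoModelForCausalLM

from deepseek_vl.models import MultiModalityCausalLM, VLChatProcessor

model_path = "deepseek-ai/deepseek-vl-1.3b-chat"  # illustrative checkpoint id

# VLChatProcessor bundles the tokenizer with the VLMImageProcessor registered
# by image_processing_vlm.py further down.
processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
tokenizer = processor.tokenizer

# The AutoConfig/AutoModelForCausalLM registration at the bottom of
# modeling_vlm.py (or trust_remote_code) resolves "multi_modality" to this class.
model: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(
    model_path, trust_remote_code=True
)
model = model.to(torch.bfloat16).eval()
```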
19 | 20 | from typing import Dict, List, Literal, Optional, Tuple, Union 21 | 22 | import torch 23 | import torch.nn as nn 24 | import torchvision.transforms 25 | from einops import rearrange 26 | 27 | from .sam import create_sam_vit 28 | from .siglip_vit import create_siglip_vit 29 | 30 | 31 | class CLIPVisionTower(nn.Module): 32 | def __init__( 33 | self, 34 | model_name: str = "siglip_large_patch16_384", 35 | image_size: Union[Tuple[int, int], int] = 336, 36 | select_feature: str = "patch", 37 | select_layer: int = -2, 38 | select_layers: list = None, 39 | ckpt_path: str = "", 40 | pixel_mean: Optional[List[float]] = None, 41 | pixel_std: Optional[List[float]] = None, 42 | **kwargs, 43 | ): 44 | super().__init__() 45 | 46 | self.model_name = model_name 47 | self.select_feature = select_feature 48 | self.select_layer = select_layer 49 | self.select_layers = select_layers 50 | 51 | vision_tower_params = { 52 | "model_name": model_name, 53 | "image_size": image_size, 54 | "ckpt_path": ckpt_path, 55 | "select_layer": select_layer, 56 | } 57 | vision_tower_params.update(kwargs) 58 | self.vision_tower, self.forward_kwargs = self.build_vision_tower( 59 | vision_tower_params 60 | ) 61 | 62 | if pixel_mean is not None and pixel_std is not None: 63 | image_norm = torchvision.transforms.Normalize( 64 | mean=pixel_mean, std=pixel_std 65 | ) 66 | else: 67 | image_norm = None 68 | 69 | self.image_norm = image_norm 70 | 71 | def build_vision_tower(self, vision_tower_params): 72 | if self.model_name.startswith("siglip"): 73 | self.select_feature = "same" 74 | vision_tower = create_siglip_vit(**vision_tower_params) 75 | forward_kwargs = dict() 76 | 77 | elif self.model_name.startswith("sam"): 78 | vision_tower = create_sam_vit(**vision_tower_params) 79 | forward_kwargs = dict() 80 | 81 | else: # huggingface 82 | from transformers import CLIPVisionModel 83 | 84 | vision_tower = CLIPVisionModel.from_pretrained(**vision_tower_params) 85 | forward_kwargs = dict(output_hidden_states=True) 86 | 87 | return vision_tower, forward_kwargs 88 | 89 | def feature_select(self, image_forward_outs): 90 | if isinstance(image_forward_outs, torch.Tensor): 91 | # the output has been the self.select_layer"s features 92 | image_features = image_forward_outs 93 | else: 94 | image_features = image_forward_outs.hidden_states[self.select_layer] 95 | 96 | if self.select_feature == "patch": 97 | # if the output has cls_token 98 | image_features = image_features[:, 1:] 99 | elif self.select_feature == "cls_patch": 100 | image_features = image_features 101 | elif self.select_feature == "same": 102 | image_features = image_features 103 | 104 | else: 105 | raise ValueError(f"Unexpected select feature: {self.select_feature}") 106 | return image_features 107 | 108 | def forward(self, images): 109 | """ 110 | 111 | Args: 112 | images (torch.Tensor): [b, 3, H, W] 113 | 114 | Returns: 115 | image_features (torch.Tensor): [b, n_patch, d] 116 | """ 117 | 118 | if self.image_norm is not None: 119 | images = self.image_norm(images) 120 | 121 | image_forward_outs = self.vision_tower(images, **self.forward_kwargs) 122 | image_features = self.feature_select(image_forward_outs) 123 | return image_features 124 | 125 | 126 | class HybridVisionTower(nn.Module): 127 | def __init__( 128 | self, 129 | high_res_cfg: Dict, 130 | low_res_cfg: Dict, 131 | freeze_high: bool = False, 132 | freeze_low: bool = False, 133 | concat_type: Literal["feature", "sequence", "add", "tuple"] = "tuple", 134 | **ignore_kwargs, 135 | ): 136 | super().__init__() 137 | 138 | 
self.vision_tower_high = CLIPVisionTower(**high_res_cfg) 139 | self.vision_tower_low = CLIPVisionTower(**low_res_cfg) 140 | self.low_res_size = low_res_cfg["image_size"] 141 | self.concat_type = concat_type 142 | 143 | self.high_layer_norm = nn.LayerNorm(high_res_cfg.get("output_dim", 1024)) 144 | self.low_layer_norm = nn.LayerNorm(low_res_cfg.get("output_dim", 1024)) 145 | 146 | if freeze_high: 147 | for p_name, p in self.vision_tower_high.named_parameters(): 148 | p.requires_grad = False 149 | self.vision_tower_high = self.vision_tower_high.eval() 150 | else: 151 | # train donwsamples and neck 152 | for p_name, p in self.vision_tower_high.named_parameters(): 153 | if "downsamples" in p_name or "neck" in p_name: 154 | p.requires_grad = True 155 | else: 156 | p.requires_grad = False 157 | 158 | if freeze_low: 159 | for p in self.vision_tower_low.parameters(): 160 | p.requires_grad = False 161 | self.vision_tower_low = self.vision_tower_low.eval() 162 | 163 | self.resize = torchvision.transforms.Resize(self.low_res_size, antialias=True) 164 | 165 | def forward(self, images: torch.Tensor): 166 | """ 167 | 168 | Args: 169 | images (torch.Tensor): [bs, 3, H, W] 170 | 171 | Returns: 172 | res (torch.Tensor): [bs, t, c] 173 | """ 174 | 175 | # [bs, c, h, w] 176 | high_images = images 177 | 178 | # [bs, c, h_low, w_low] 179 | low_images = self.resize(images) 180 | 181 | # separately run two vision towers 182 | # run high_res vision tower 183 | high_res = self.vision_tower_high(high_images) 184 | # [bs, c, h, w] -> [bs, h*w, c] 185 | high_res = rearrange(high_res, "b c h w -> b (h w) c") 186 | # run low_res vision tower 187 | low_res = self.vision_tower_low(low_images) 188 | 189 | if self.concat_type == "feature": 190 | images_features = torch.cat([high_res, low_res], dim=-1) 191 | elif self.concat_type == "sequence": 192 | images_features = torch.cat([high_res, low_res], dim=1) 193 | elif self.concat_type == "add": 194 | images_features = high_res + low_res 195 | elif self.concat_type == "tuple": 196 | images_features = (high_res, low_res) 197 | 198 | else: 199 | raise ValueError( 200 | "Currently only support `feature`, `sequence`, `add` and `tuple` concat type." 201 | ) 202 | 203 | return images_features 204 | 205 | 206 | if __name__ == "__main__": 207 | image_size = 1024 208 | x = torch.zeros(2, 3, image_size, image_size).bfloat16().cuda() 209 | 210 | high_res_cfg = dict( 211 | model_name="sam_b_downsample", 212 | select_feature="same", 213 | image_size=image_size, 214 | pixel_mean=(0.48145466, 0.4578275, 0.40821073), 215 | pixel_std=(0.26862954, 0.26130258, 0.27577711), 216 | select_layer=-1, 217 | ckpt_path="", 218 | ) 219 | 220 | low_res_cfg = dict( 221 | model_name="siglip_large_patch16_384", 222 | select_feature="same", 223 | image_size=384, 224 | pixel_mean=(0.5, 0.5, 0.5), 225 | pixel_std=(0.5, 0.5, 0.5), 226 | select_layer=-1, 227 | ckpt_path="", 228 | ) 229 | 230 | net = ( 231 | HybridVisionTower( 232 | high_res_cfg=high_res_cfg, 233 | low_res_cfg=low_res_cfg, 234 | freeze_high=True, 235 | freeze_low=True, 236 | concat_type="tuple", 237 | ) 238 | .bfloat16() 239 | .cuda() 240 | ) 241 | high_x, low_x = net(x) 242 | print(x.shape, high_x.shape, low_x.shape) 243 | -------------------------------------------------------------------------------- /deepseek_vl/models/image_processing_vlm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023-2024 DeepSeek. 
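For reference, the four `concat_type` options in `HybridVisionTower` only differ in how the two towers' token sequences are combined. A toy sketch with dummy tensors, shaped to mirror the 1024-px SAM / 384-px SigLIP pairing from the `__main__` block above (the tensors themselves are fabricated):

```python
import torch

# Both branches produce 576 tokens with 1024-dim features:
# SAM-with-downsampling ends at a 24x24 grid, and (384 / 16) ** 2 = 576 for
# siglip_large_patch16_384.
high_res = torch.randn(2, 576, 1024)  # dummy stand-in for the SAM branch
low_res = torch.randn(2, 576, 1024)   # dummy stand-in for the SigLIP branch

feature = torch.cat([high_res, low_res], dim=-1)   # "feature":  [2, 576, 2048]
sequence = torch.cat([high_res, low_res], dim=1)   # "sequence": [2, 1152, 1024]
added = high_res + low_res                         # "add":      [2, 576, 1024]
as_tuple = (high_res, low_res)                     # "tuple": passed through as-is

print(feature.shape, sequence.shape, added.shape)
```

The `"tuple"` mode is what the `low_high_hybrid_split_mlp_gelu` branch of `MlpProjector` (further down in `projector.py`) expects to receive.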
2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to 6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 7 | # the Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all 11 | # copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 15 | # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 16 | # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 17 | # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 18 | # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 19 | 20 | from typing import List, Tuple, Union 21 | 22 | import numpy as np 23 | import torch 24 | import torchvision 25 | import torchvision.transforms.functional 26 | from PIL import Image 27 | from transformers import AutoImageProcessor, PretrainedConfig 28 | from transformers.image_processing_utils import BaseImageProcessor, BatchFeature 29 | from transformers.image_utils import to_numpy_array 30 | from transformers.utils import logging 31 | 32 | logger = logging.get_logger(__name__) 33 | 34 | ImageType = Union[np.ndarray, torch.Tensor, Image.Image] 35 | IMAGENET_MEAN = (0.48145466, 0.4578275, 0.40821073) 36 | IMAGENET_STD = (0.26862954, 0.26130258, 0.27577711) 37 | IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5) 38 | IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5) 39 | 40 | 41 | def expand2square(pil_img, background_color): 42 | width, height = pil_img.size 43 | if width == height: 44 | return pil_img 45 | elif width > height: 46 | result = Image.new(pil_img.mode, (width, width), background_color) 47 | result.paste(pil_img, (0, (width - height) // 2)) 48 | return result 49 | else: 50 | result = Image.new(pil_img.mode, (height, height), background_color) 51 | result.paste(pil_img, ((height - width) // 2, 0)) 52 | return result 53 | 54 | 55 | class VLMImageProcessorConfig(PretrainedConfig): 56 | model_type = "deepseek_vlm" 57 | image_size: int 58 | min_size: int 59 | image_mean: Union[Tuple[float, float, float], List[float]] 60 | image_std: Union[Tuple[float, float, float], List[float]] 61 | rescale_factor: float 62 | do_normalize: bool 63 | 64 | def __init__( 65 | self, 66 | image_size: int, 67 | min_size: int = 14, 68 | image_mean: Union[Tuple[float, float, float], List[float]] = ( 69 | 0.48145466, 70 | 0.4578275, 71 | 0.40821073, 72 | ), 73 | image_std: Union[Tuple[float, float, float], List[float]] = ( 74 | 0.26862954, 75 | 0.26130258, 76 | 0.27577711, 77 | ), 78 | rescale_factor: float = 1.0 / 255.0, 79 | do_normalize: bool = True, 80 | **kwargs, 81 | ): 82 | self.image_size = image_size 83 | self.min_size = min_size 84 | self.image_mean = image_mean 85 | self.image_std = image_std 86 | self.rescale_factor = rescale_factor 87 | self.do_normalize = do_normalize 88 | 89 | super().__init__(**kwargs) 90 | 91 | 92 | class VLMImageProcessor(BaseImageProcessor): 93 | model_input_names = ["pixel_values"] 94 | 95 | def __init__( 96 | self, 97 | image_size: int, 98 | min_size: int = 14, 99 | 
image_mean: Union[Tuple[float, float, float], List[float]] = ( 100 | 0.48145466, 101 | 0.4578275, 102 | 0.40821073, 103 | ), 104 | image_std: Union[Tuple[float, float, float], List[float]] = ( 105 | 0.26862954, 106 | 0.26130258, 107 | 0.27577711, 108 | ), 109 | rescale_factor: float = 1.0 / 255.0, 110 | do_normalize: bool = True, 111 | **kwargs, 112 | ): 113 | super().__init__(**kwargs) 114 | 115 | self.image_size = image_size 116 | self.rescale_factor = rescale_factor 117 | self.image_mean = image_mean 118 | self.image_std = image_std 119 | self.min_size = min_size 120 | self.do_normalize = do_normalize 121 | 122 | if image_mean is None: 123 | self.background_color = (127, 127, 127) 124 | else: 125 | self.background_color = tuple([int(x * 255) for x in image_mean]) 126 | 127 | def resize(self, pil_img: Image) -> np.ndarray: 128 | """ 129 | 130 | Args: 131 | pil_img (PIL.Image): [H, W, 3] in PIL.Image in RGB 132 | 133 | Returns: 134 | x (np.ndarray): [3, self.image_size, self.image_size] 135 | """ 136 | 137 | width, height = pil_img.size 138 | max_size = max(width, height) 139 | 140 | size = [ 141 | max(int(height / max_size * self.image_size), self.min_size), 142 | max(int(width / max_size * self.image_size), self.min_size), 143 | ] 144 | 145 | if width <= 0 or height <= 0 or size[0] <= 0 or size[1] <= 0: 146 | print(f"orig size = {pil_img.size}, new size = {size}") 147 | raise ValueError("Invalid size!") 148 | 149 | pil_img = torchvision.transforms.functional.resize( 150 | pil_img, 151 | size, 152 | interpolation=torchvision.transforms.functional.InterpolationMode.BICUBIC, 153 | antialias=True, 154 | ) 155 | 156 | pil_img = expand2square(pil_img, self.background_color) 157 | x = to_numpy_array(pil_img) 158 | 159 | # [H, W, 3] -> [3, H, W] 160 | x = np.transpose(x, (2, 0, 1)) 161 | 162 | return x 163 | 164 | def preprocess(self, images, return_tensors: str = "pt", **kwargs) -> BatchFeature: 165 | # resize and pad to [self.image_size, self.image_size] 166 | # then convert from [H, W, 3] to [3, H, W] 167 | images: List[np.ndarray] = [self.resize(image) for image in images] 168 | 169 | # resacle from [0, 255] -> [0, 1] 170 | images = [ 171 | self.rescale( 172 | image=image, 173 | scale=self.rescale_factor, 174 | input_data_format="channels_first", 175 | ) 176 | for image in images 177 | ] 178 | 179 | # normalize 180 | if self.do_normalize: 181 | images = [ 182 | self.normalize( 183 | image=image, 184 | mean=self.image_mean, 185 | std=self.image_std, 186 | input_data_format="channels_first", 187 | ) 188 | for image in images 189 | ] 190 | 191 | data = {"pixel_values": images} 192 | return BatchFeature(data=data, tensor_type=return_tensors) 193 | 194 | @property 195 | def default_shape(self): 196 | return [3, self.image_size, self.image_size] 197 | 198 | 199 | AutoImageProcessor.register(VLMImageProcessorConfig, VLMImageProcessor) 200 | 201 | 202 | if __name__ == "__main__": 203 | image_processor = VLMImageProcessor( 204 | image_size=1024, 205 | image_mean=IMAGENET_INCEPTION_MEAN, 206 | image_std=IMAGENET_INCEPTION_STD, 207 | do_normalize=True, 208 | ) 209 | -------------------------------------------------------------------------------- /deepseek_vl/models/modeling_vlm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023-2024 DeepSeek. 
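A quick sanity check of the image processor defined above: it resizes the longer side to `image_size`, letterboxes to a square with a background color derived from `image_mean`, rescales to [0, 1], and normalizes. The input image below is synthetic; the mean/std constants are the Inception values already defined in the module:

```python
import numpy as np
from PIL import Image

from deepseek_vl.models.image_processing_vlm import (
    IMAGENET_INCEPTION_MEAN,
    IMAGENET_INCEPTION_STD,
    VLMImageProcessor,
)

processor = VLMImageProcessor(
    image_size=1024,
    image_mean=IMAGENET_INCEPTION_MEAN,
    image_std=IMAGENET_INCEPTION_STD,
    do_normalize=True,
)

# A fabricated non-square RGB image: the short side is padded via expand2square
# before the final [3, 1024, 1024] tensor is produced.
img = Image.fromarray(np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8))
batch = processor([img], return_tensors="pt")
print(batch.pixel_values.shape)  # torch.Size([1, 3, 1024, 1024])
```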
2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to 6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 7 | # the Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all 11 | # copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 15 | # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 16 | # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 17 | # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 18 | # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 19 | 20 | import torch 21 | from attrdict import AttrDict 22 | from einops import rearrange 23 | from transformers import ( 24 | AutoConfig, 25 | AutoModelForCausalLM, 26 | LlamaConfig, 27 | LlamaForCausalLM, 28 | PreTrainedModel, 29 | ) 30 | from transformers.configuration_utils import PretrainedConfig 31 | 32 | from .clip_encoder import CLIPVisionTower, HybridVisionTower 33 | from .projector import MlpProjector 34 | 35 | 36 | def model_name_to_cls(cls_name): 37 | if "MlpProjector" in cls_name: 38 | cls = MlpProjector 39 | 40 | elif "CLIPVisionTower" in cls_name: 41 | cls = CLIPVisionTower 42 | 43 | elif "HybridVisionTower" in cls_name: 44 | cls = HybridVisionTower 45 | 46 | else: 47 | raise ValueError(f"class_name {cls_name} is invalid.") 48 | 49 | return cls 50 | 51 | 52 | class VisionConfig(PretrainedConfig): 53 | model_type = "vision" 54 | cls: str = "" 55 | params: AttrDict = {} 56 | 57 | def __init__(self, **kwargs): 58 | super().__init__(**kwargs) 59 | 60 | self.cls = kwargs.get("cls", "") 61 | if not isinstance(self.cls, str): 62 | self.cls = self.cls.__name__ 63 | 64 | self.params = AttrDict(kwargs.get("params", {})) 65 | 66 | 67 | class AlignerConfig(PretrainedConfig): 68 | model_type = "aligner" 69 | cls: str = "" 70 | params: AttrDict = {} 71 | 72 | def __init__(self, **kwargs): 73 | super().__init__(**kwargs) 74 | 75 | self.cls = kwargs.get("cls", "") 76 | if not isinstance(self.cls, str): 77 | self.cls = self.cls.__name__ 78 | 79 | self.params = AttrDict(kwargs.get("params", {})) 80 | 81 | 82 | class MultiModalityConfig(PretrainedConfig): 83 | model_type = "multi_modality" 84 | vision_config: VisionConfig 85 | aligner_config: AlignerConfig 86 | language_config: LlamaConfig 87 | 88 | def __init__(self, **kwargs): 89 | super().__init__(**kwargs) 90 | vision_config = kwargs.get("vision_config", {}) 91 | self.vision_config = VisionConfig(**vision_config) 92 | 93 | aligner_config = kwargs.get("aligner_config", {}) 94 | self.aligner_config = AlignerConfig(**aligner_config) 95 | 96 | language_config = kwargs.get("language_config", {}) 97 | if isinstance(language_config, LlamaConfig): 98 | self.language_config = language_config 99 | else: 100 | self.language_config = LlamaConfig(**language_config) 101 | 102 | 103 | class MultiModalityPreTrainedModel(PreTrainedModel): 104 | config_class = MultiModalityConfig 105 | base_model_prefix = "multi_modality" 106 | _no_split_modules 
= [] 107 | _skip_keys_device_placement = "past_key_values" 108 | 109 | 110 | class MultiModalityCausalLM(MultiModalityPreTrainedModel): 111 | def __init__(self, config: MultiModalityConfig): 112 | super().__init__(config) 113 | 114 | vision_config = config.vision_config 115 | vision_cls = model_name_to_cls(vision_config.cls) 116 | self.vision_model = vision_cls(**vision_config.params) 117 | 118 | aligner_config = config.aligner_config 119 | aligner_cls = model_name_to_cls(aligner_config.cls) 120 | self.aligner = aligner_cls(aligner_config.params) 121 | 122 | language_config = config.language_config 123 | self.language_model = LlamaForCausalLM(language_config) 124 | 125 | def prepare_inputs_embeds( 126 | self, 127 | input_ids: torch.LongTensor, 128 | pixel_values: torch.FloatTensor, 129 | images_seq_mask: torch.LongTensor, 130 | images_emb_mask: torch.LongTensor, 131 | **kwargs, 132 | ): 133 | """ 134 | 135 | Args: 136 | input_ids (torch.LongTensor): [b, T] 137 | pixel_values (torch.FloatTensor): [b, n_images, 3, h, w] 138 | images_seq_mask (torch.BoolTensor): [b, T] 139 | images_emb_mask (torch.BoolTensor): [b, n_images, n_image_tokens] 140 | 141 | assert torch.sum(images_seq_mask) == torch.sum(images_emb_mask) 142 | 143 | Returns: 144 | input_embeds (torch.Tensor): [b, T, D] 145 | """ 146 | 147 | bs, n = pixel_values.shape[0:2] 148 | images = rearrange(pixel_values, "b n c h w -> (b n) c h w") 149 | # [b x n, T2, D] 150 | images_embeds = self.aligner(self.vision_model(images)) 151 | 152 | # [b x n, T2, D] -> [b, n x T2, D] 153 | images_embeds = rearrange(images_embeds, "(b n) t d -> b (n t) d", b=bs, n=n) 154 | # [b, n, T2] -> [b, n x T2] 155 | images_emb_mask = rearrange(images_emb_mask, "b n t -> b (n t)") 156 | 157 | # [b, T, D] 158 | input_ids[input_ids < 0] = 0 # ignore the image embeddings 159 | inputs_embeds = self.language_model.get_input_embeddings()(input_ids) 160 | 161 | # replace with the image embeddings 162 | inputs_embeds[images_seq_mask] = images_embeds[images_emb_mask] 163 | 164 | return inputs_embeds 165 | 166 | 167 | AutoConfig.register("vision", VisionConfig) 168 | AutoConfig.register("aligner", AlignerConfig) 169 | AutoConfig.register("multi_modality", MultiModalityConfig) 170 | AutoModelForCausalLM.register(MultiModalityConfig, MultiModalityCausalLM) 171 | -------------------------------------------------------------------------------- /deepseek_vl/models/processing_vlm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023-2024 DeepSeek. 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to 6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 7 | # the Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all 11 | # copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 15 | # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
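The core trick in `prepare_inputs_embeds` above is the pair of boolean masks: `images_seq_mask` marks which positions of the token sequence are image slots, `images_emb_mask` marks which of the (possibly padded) image embeddings are real, and `inputs_embeds[images_seq_mask] = images_embeds[images_emb_mask]` splices the vision features into the text embedding matrix. A toy reproduction with fabricated shapes (2 text tokens, 3 image tokens, embedding dim 4):

```python
import torch

b, T, D = 1, 5, 4           # batch, sequence length, embed dim (toy numbers)
n_image_tokens = 3          # one image occupying 3 consecutive slots

inputs_embeds = torch.zeros(b, T, D)  # stand-in for the text token embeddings
# stand-in for aligner(vision_model(images)) flattened to [b, n x T2, D]
images_embeds = torch.arange(b * n_image_tokens * D, dtype=torch.float).view(
    b, n_image_tokens, D
)

images_seq_mask = torch.tensor([[False, True, True, True, False]])   # [b, T]
images_emb_mask = torch.ones(b, n_image_tokens, dtype=torch.bool)    # [b, n x T2]

# same splice as MultiModalityCausalLM.prepare_inputs_embeds
inputs_embeds[images_seq_mask] = images_embeds[images_emb_mask]
print(inputs_embeds[0])  # positions 1..3 now carry the image features
```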
IN NO EVENT SHALL THE AUTHORS OR 16 | # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 17 | # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 18 | # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 19 | 20 | from dataclasses import dataclass 21 | from typing import Dict, List 22 | 23 | import torch 24 | from PIL.Image import Image 25 | from transformers import LlamaTokenizerFast 26 | from transformers.processing_utils import ProcessorMixin 27 | 28 | from .image_processing_vlm import VLMImageProcessor 29 | from ..utils.conversation import get_conv_template 30 | 31 | 32 | class DictOutput(object): 33 | def keys(self): 34 | return self.__dict__.keys() 35 | 36 | def __getitem__(self, item): 37 | return self.__dict__[item] 38 | 39 | def __setitem__(self, key, value): 40 | self.__dict__[key] = value 41 | 42 | 43 | @dataclass 44 | class VLChatProcessorOutput(DictOutput): 45 | sft_format: str 46 | input_ids: torch.Tensor 47 | pixel_values: torch.Tensor 48 | num_image_tokens: torch.IntTensor 49 | 50 | def __len__(self): 51 | return len(self.input_ids) 52 | 53 | 54 | @dataclass 55 | class BatchedVLChatProcessorOutput(DictOutput): 56 | sft_format: List[str] 57 | input_ids: torch.Tensor 58 | pixel_values: torch.Tensor 59 | attention_mask: torch.Tensor 60 | images_seq_mask: torch.BoolTensor 61 | images_emb_mask: torch.BoolTensor 62 | 63 | def to(self, device, dtype=torch.bfloat16): 64 | self.input_ids = self.input_ids.to(device) 65 | self.attention_mask = self.attention_mask.to(device) 66 | self.images_seq_mask = self.images_seq_mask.to(device) 67 | self.images_emb_mask = self.images_emb_mask.to(device) 68 | self.pixel_values = self.pixel_values.to(device=device, dtype=dtype) 69 | return self 70 | 71 | 72 | class VLChatProcessor(ProcessorMixin): 73 | image_processor_class = "AutoImageProcessor" 74 | tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast") 75 | 76 | attributes = ["image_processor", "tokenizer"] 77 | 78 | system_prompt = ( 79 | "You are a helpful language and vision assistant. " 80 | "You are able to understand the visual content that the user provides, " 81 | "and assist the user with a variety of tasks using natural language." 
82 | ) 83 | 84 | def __init__( 85 | self, 86 | image_processor: VLMImageProcessor, 87 | tokenizer: LlamaTokenizerFast, 88 | image_tag: str = "", 89 | num_image_tokens: int = 576, 90 | add_special_token: bool = False, 91 | sft_format: str = "deepseek", 92 | mask_prompt: bool = True, 93 | ignore_id: int = -100, 94 | **kwargs, 95 | ): 96 | self.image_processor = image_processor 97 | self.tokenizer = tokenizer 98 | 99 | image_id = self.tokenizer.vocab.get(image_tag) 100 | if image_id is None: 101 | special_tokens = [image_tag] 102 | special_tokens_dict = {"additional_special_tokens": special_tokens} 103 | self.tokenizer.add_special_tokens(special_tokens_dict) 104 | print(f"Add image tag = {image_tag} to the tokenizer") 105 | 106 | self.image_tag = image_tag 107 | self.num_image_tokens = num_image_tokens 108 | self.add_special_token = add_special_token 109 | self.sft_format = sft_format 110 | self.mask_prompt = mask_prompt 111 | self.ignore_id = ignore_id 112 | 113 | super().__init__( 114 | image_processor, 115 | tokenizer, 116 | image_tag, 117 | num_image_tokens, 118 | add_special_token, 119 | sft_format, 120 | mask_prompt, 121 | ignore_id, 122 | **kwargs, 123 | ) 124 | 125 | def new_chat_template(self): 126 | conv = get_conv_template(self.sft_format) 127 | conv.set_system_message(self.system_prompt) 128 | return conv 129 | 130 | def apply_sft_template_for_multi_turn_prompts( 131 | self, 132 | conversations: List[Dict[str, str]], 133 | sft_format: str = "deepseek", 134 | system_prompt: str = "", 135 | ): 136 | """ 137 | Applies the SFT template to conversation. 138 | 139 | An example of conversation: 140 | conversation = [ 141 | { 142 | "role": "User", 143 | "content": " is Figure 1.\n is Figure 2.\nWhich image is brighter?", 144 | "images": [ 145 | "./multi-images/attribute_comparison_1.png", 146 | "./multi-images/attribute_comparison_2.png" 147 | ] 148 | }, 149 | { 150 | "role": "Assistant", 151 | "content": "" 152 | } 153 | ] 154 | 155 | Args: 156 | conversations (List[Dict]): A conversation with a List of Dict[str, str] text. 157 | sft_format (str, optional): The format of the SFT template to use. Defaults to "deepseek". 158 | system_prompt (str, optional): The system prompt to use in the SFT template. Defaults to "". 159 | 160 | Returns: 161 | sft_prompt (str): The formatted text. 
162 | """ 163 | 164 | conv = get_conv_template(sft_format) 165 | conv.set_system_message(system_prompt) 166 | for message in conversations: 167 | conv.append_message(message["role"], message["content"].strip()) 168 | sft_prompt = conv.get_prompt().strip() 169 | 170 | return sft_prompt 171 | 172 | @property 173 | def image_token(self): 174 | return self.image_tag 175 | 176 | @property 177 | def image_id(self): 178 | image_id = self.tokenizer.vocab.get(self.image_tag) 179 | return image_id 180 | 181 | @property 182 | def pad_id(self): 183 | pad_id = self.tokenizer.pad_token_id 184 | if pad_id is None: 185 | pad_id = self.tokenizer.eos_token_id 186 | 187 | return pad_id 188 | 189 | def add_image_token( 190 | self, 191 | image_indices: List[int], 192 | input_ids: torch.LongTensor, 193 | ): 194 | """ 195 | 196 | Args: 197 | image_indices (List[int]): [index_0, index_1, ..., index_j] 198 | input_ids (torch.LongTensor): [N] 199 | 200 | Returns: 201 | input_ids (torch.LongTensor): [N + image tokens] 202 | num_image_tokens (torch.IntTensor): [n_images] 203 | """ 204 | 205 | input_slices = [] 206 | 207 | start = 0 208 | for index in image_indices: 209 | if self.add_special_token: 210 | end = index + 1 211 | else: 212 | end = index 213 | 214 | # original text tokens 215 | input_slices.append(input_ids[start:end]) 216 | 217 | # add image tokens, and set the mask as False 218 | input_slices.append( 219 | self.image_id * torch.ones((self.num_image_tokens,), dtype=torch.long) 220 | ) 221 | start = index + 1 222 | 223 | # the left part 224 | input_slices.append(input_ids[start:]) 225 | 226 | # concat all slices 227 | input_ids = torch.cat(input_slices, dim=0) 228 | num_image_tokens = torch.IntTensor([self.num_image_tokens] * len(image_indices)) 229 | 230 | return input_ids, num_image_tokens 231 | 232 | def process_one( 233 | self, 234 | prompt: str = None, 235 | conversations: List[Dict[str, str]] = None, 236 | images: List[Image] = None, 237 | **kwargs, 238 | ): 239 | """ 240 | 241 | Args: 242 | prompt (str): the formatted prompt; 243 | conversations (List[Dict]): conversations with a list of messages; 244 | images (List[ImageType]): the list of images; 245 | **kwargs: 246 | 247 | Returns: 248 | outputs (BaseProcessorOutput): the output of the processor, 249 | - input_ids (torch.LongTensor): [N + image tokens] 250 | - target_ids (torch.LongTensor): [N + image tokens] 251 | - images (torch.FloatTensor): [n_images, 3, H, W] 252 | - image_id (int): the id of the image token 253 | - num_image_tokens (List[int]): the number of image tokens 254 | """ 255 | 256 | assert ( 257 | prompt is None or conversations is None 258 | ), "prompt and conversations cannot be used at the same time." 
259 | 260 | if prompt is None: 261 | # apply sft format 262 | sft_format = self.apply_sft_template_for_multi_turn_prompts( 263 | conversations=conversations, 264 | sft_format=self.sft_format, 265 | system_prompt=self.system_prompt, 266 | ) 267 | else: 268 | sft_format = prompt 269 | 270 | # tokenize 271 | input_ids = self.tokenizer.encode(sft_format) 272 | input_ids = torch.LongTensor(input_ids) 273 | 274 | # add image tokens to the input_ids 275 | image_token_mask: torch.BoolTensor = input_ids == self.image_id 276 | image_indices = image_token_mask.nonzero() 277 | input_ids, num_image_tokens = self.add_image_token( 278 | image_indices=image_indices, 279 | input_ids=input_ids, 280 | ) 281 | 282 | # load images 283 | images_outputs = self.image_processor(images, return_tensors="pt") 284 | 285 | prepare = VLChatProcessorOutput( 286 | sft_format=sft_format, 287 | input_ids=input_ids, 288 | pixel_values=images_outputs.pixel_values, 289 | num_image_tokens=num_image_tokens, 290 | ) 291 | 292 | return prepare 293 | 294 | def __call__( 295 | self, 296 | *, 297 | prompt: str = None, 298 | conversations: List[Dict[str, str]] = None, 299 | images: List[Image] = None, 300 | force_batchify: bool = True, 301 | **kwargs, 302 | ): 303 | """ 304 | 305 | Args: 306 | prompt (str): the formatted prompt; 307 | conversations (List[Dict]): conversations with a list of messages; 308 | images (List[ImageType]): the list of images; 309 | force_batchify (bool): force batchify the inputs; 310 | **kwargs: 311 | 312 | Returns: 313 | outputs (BaseProcessorOutput): the output of the processor, 314 | - input_ids (torch.LongTensor): [N + image tokens] 315 | - images (torch.FloatTensor): [n_images, 3, H, W] 316 | - image_id (int): the id of the image token 317 | - num_image_tokens (List[int]): the number of image tokens 318 | """ 319 | 320 | prepare = self.process_one( 321 | prompt=prompt, conversations=conversations, images=images 322 | ) 323 | 324 | if force_batchify: 325 | prepare = self.batchify([prepare]) 326 | 327 | return prepare 328 | 329 | def batchify( 330 | self, prepare_list: List[VLChatProcessorOutput] 331 | ) -> BatchedVLChatProcessorOutput: 332 | """ 333 | Preprocesses the inputs for multimodal inference. 334 | 335 | Args: 336 | prepare_list (List[VLChatProcessorOutput]): A list of VLChatProcessorOutput. 337 | 338 | Returns: 339 | BatchedVLChatProcessorOutput: A dictionary of the inputs to use for multimodal inference. 
340 | """ 341 | 342 | batch_size = len(prepare_list) 343 | sft_format = [] 344 | n_images = [] 345 | seq_lens = [] 346 | for prepare in prepare_list: 347 | n_images.append(len(prepare.num_image_tokens)) 348 | seq_lens.append(len(prepare)) 349 | 350 | input_token_max_len = max(seq_lens) 351 | max_n_images = max(1, max(n_images)) 352 | 353 | batched_input_ids = torch.full( 354 | (batch_size, input_token_max_len), self.pad_id 355 | ).long() # FIXME 356 | batched_attention_mask = torch.zeros((batch_size, input_token_max_len)).long() 357 | batched_pixel_values = torch.zeros( 358 | (batch_size, max_n_images, *self.image_processor.default_shape) 359 | ).float() 360 | batched_images_seq_mask = torch.zeros((batch_size, input_token_max_len)).bool() 361 | batched_images_emb_mask = torch.zeros( 362 | (batch_size, max_n_images, self.num_image_tokens) 363 | ).bool() 364 | 365 | for i, prepare in enumerate(prepare_list): 366 | input_ids = prepare.input_ids 367 | seq_len = len(prepare) 368 | n_image = len(prepare.num_image_tokens) 369 | # left-padding 370 | batched_attention_mask[i, -seq_len:] = 1 371 | batched_input_ids[i, -seq_len:] = torch.LongTensor(input_ids) 372 | batched_images_seq_mask[i, -seq_len:] = input_ids == self.image_id 373 | 374 | if n_image > 0: 375 | batched_pixel_values[i, :n_image] = prepare.pixel_values 376 | for j, n_image_tokens in enumerate(prepare.num_image_tokens): 377 | batched_images_emb_mask[i, j, :n_image_tokens] = True 378 | 379 | sft_format.append(prepare.sft_format) 380 | 381 | batched_prepares = BatchedVLChatProcessorOutput( 382 | input_ids=batched_input_ids, 383 | attention_mask=batched_attention_mask, 384 | pixel_values=batched_pixel_values, 385 | images_seq_mask=batched_images_seq_mask, 386 | images_emb_mask=batched_images_emb_mask, 387 | sft_format=sft_format, 388 | ) 389 | 390 | return batched_prepares 391 | -------------------------------------------------------------------------------- /deepseek_vl/models/projector.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023-2024 DeepSeek. 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to 6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 7 | # the Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all 11 | # copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 15 | # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 16 | # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 17 | # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 18 | # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
19 | 20 | from typing import Tuple, Union 21 | 22 | import torch 23 | import torch.nn as nn 24 | from attrdict import AttrDict 25 | 26 | 27 | class MlpProjector(nn.Module): 28 | def __init__(self, cfg): 29 | super().__init__() 30 | 31 | self.cfg = cfg 32 | 33 | if cfg.projector_type == "identity": 34 | modules = nn.Identity() 35 | 36 | elif cfg.projector_type == "linear": 37 | modules = nn.Linear(cfg.input_dim, cfg.n_embed) 38 | 39 | elif cfg.projector_type == "mlp_gelu": 40 | mlp_depth = cfg.get("depth", 1) 41 | modules = [nn.Linear(cfg.input_dim, cfg.n_embed)] 42 | for _ in range(1, mlp_depth): 43 | modules.append(nn.GELU()) 44 | modules.append(nn.Linear(cfg.n_embed, cfg.n_embed)) 45 | modules = nn.Sequential(*modules) 46 | 47 | elif cfg.projector_type == "low_high_hybrid_split_mlp_gelu": 48 | mlp_depth = cfg.get("depth", 1) 49 | self.high_up_proj = nn.Linear(cfg.input_dim, cfg.n_embed // 2) 50 | self.low_up_proj = nn.Linear(cfg.input_dim, cfg.n_embed // 2) 51 | 52 | modules = [] 53 | for _ in range(1, mlp_depth): 54 | modules.append(nn.GELU()) 55 | modules.append(nn.Linear(cfg.n_embed, cfg.n_embed)) 56 | modules = nn.Sequential(*modules) 57 | 58 | else: 59 | raise ValueError(f"Unknown projector type: {cfg.projector_type}") 60 | 61 | self.layers = modules 62 | 63 | def forward( 64 | self, x_or_tuple: Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor] 65 | ): 66 | """ 67 | 68 | Args: 69 | x_or_tuple (Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]: if it is a tuple of torch.Tensor, 70 | then it comes from the hybrid vision encoder, and x = high_res_x, low_res_x); 71 | otherwise it is the feature from the single vision encoder. 72 | 73 | Returns: 74 | x (torch.Tensor): [b, s, c] 75 | """ 76 | 77 | if isinstance(x_or_tuple, tuple): 78 | # self.cfg.projector_type == "low_high_hybrid_split_mlp_gelu": 79 | high_x, low_x = x_or_tuple 80 | high_x = self.high_up_proj(high_x) 81 | low_x = self.low_up_proj(low_x) 82 | x = torch.concat([high_x, low_x], dim=-1) 83 | else: 84 | x = x_or_tuple 85 | 86 | return self.layers(x) 87 | 88 | 89 | if __name__ == "__main__": 90 | cfg = AttrDict( 91 | input_dim=1024, 92 | n_embed=2048, 93 | depth=2, 94 | projector_type="low_high_hybrid_split_mlp_gelu", 95 | ) 96 | inputs = (torch.rand(4, 576, 1024), torch.rand(4, 576, 1024)) 97 | 98 | m = MlpProjector(cfg) 99 | out = m(inputs) 100 | print(out.shape) 101 | -------------------------------------------------------------------------------- /deepseek_vl/models/sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
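The projector above is the glue between vision features and the language model's embedding space. Besides the hybrid variant exercised in its `__main__` block, the plain `mlp_gelu` type is what a single-tower setup would use; a small sketch with a fabricated config:

```python
import torch
from attrdict import AttrDict

from deepseek_vl.models.projector import MlpProjector

cfg = AttrDict(
    input_dim=1024,   # vision feature width
    n_embed=2048,     # LLM hidden size
    depth=2,          # Linear -> GELU -> Linear
    projector_type="mlp_gelu",
)
proj = MlpProjector(cfg)

tokens = torch.randn(4, 576, 1024)  # [batch, n_patch, input_dim], dummy features
print(proj(tokens).shape)           # torch.Size([4, 576, 2048])
```

For the hybrid path, the `low_high_hybrid_split_mlp_gelu` branch first projects each stream to `n_embed // 2` and concatenates them, which is why it accepts the `"tuple"` output of `HybridVisionTower`.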
6 | 7 | import copy 8 | from dataclasses import dataclass 9 | from functools import partial 10 | from typing import List, Optional, Tuple, Type, Union 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | 16 | 17 | class MLPBlock(nn.Module): 18 | def __init__( 19 | self, 20 | embedding_dim: int, 21 | mlp_dim: int, 22 | act: Type[nn.Module] = nn.GELU, 23 | ) -> None: 24 | super().__init__() 25 | self.lin1 = nn.Linear(embedding_dim, mlp_dim) 26 | self.lin2 = nn.Linear(mlp_dim, embedding_dim) 27 | self.act = act() 28 | 29 | def forward(self, x: torch.Tensor) -> torch.Tensor: 30 | return self.lin2(self.act(self.lin1(x))) 31 | 32 | 33 | # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa 34 | # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa 35 | class LayerNorm2d(nn.Module): 36 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None: 37 | super().__init__() 38 | self.weight = nn.Parameter(torch.ones(num_channels)) 39 | self.bias = nn.Parameter(torch.zeros(num_channels)) 40 | self.eps = eps 41 | 42 | def forward(self, x: torch.Tensor) -> torch.Tensor: 43 | u = x.mean(1, keepdim=True) 44 | s = (x - u).pow(2).mean(1, keepdim=True) 45 | x = (x - u) / torch.sqrt(s + self.eps) 46 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 47 | return x 48 | 49 | 50 | # This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa 51 | class ImageEncoderViT(nn.Module): 52 | def __init__( 53 | self, 54 | img_size: int = 1024, 55 | patch_size: int = 16, 56 | in_chans: int = 3, 57 | embed_dim: int = 768, 58 | depth: int = 12, 59 | num_heads: int = 12, 60 | mlp_ratio: float = 4.0, 61 | out_chans: int = 256, 62 | qkv_bias: bool = True, 63 | norm_layer: Type[nn.Module] = nn.LayerNorm, 64 | act_layer: Type[nn.Module] = nn.GELU, 65 | use_abs_pos: bool = True, 66 | use_rel_pos: bool = False, 67 | rel_pos_zero_init: bool = True, 68 | window_size: int = 0, 69 | global_attn_indexes: Tuple[int, ...] = (), 70 | downsample_channels: Tuple[int, ...] = (512, 1024), 71 | ) -> None: 72 | """ 73 | Args: 74 | img_size (int): Input image size. 75 | patch_size (int): Patch size. 76 | in_chans (int): Number of input image channels. 77 | embed_dim (int): Patch embedding dimension. 78 | depth (int): Depth of ViT. 79 | num_heads (int): Number of attention heads in each ViT block. 80 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 81 | qkv_bias (bool): If True, add a learnable bias to query, key, value. 82 | norm_layer (nn.Module): Normalization layer. 83 | act_layer (nn.Module): Activation layer. 84 | use_abs_pos (bool): If True, use absolute positional embeddings. 85 | use_rel_pos (bool): If True, add relative positional embeddings to the attention map. 86 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. 87 | window_size (int): Window size for window attention blocks. 88 | global_attn_indexes (list): Indexes for blocks using global attention. 89 | downsample_channels (list): Channels for downsampling layers. 
90 | """ 91 | super().__init__() 92 | self.img_size = img_size 93 | 94 | self.patch_embed = PatchEmbed( 95 | kernel_size=(patch_size, patch_size), 96 | stride=(patch_size, patch_size), 97 | in_chans=in_chans, 98 | embed_dim=embed_dim, 99 | ) 100 | 101 | self.pos_embed: Optional[nn.Parameter] = None 102 | if use_abs_pos: 103 | # Initialize absolute positional embedding with pretrain image size. 104 | self.pos_embed = nn.Parameter( 105 | torch.zeros( 106 | 1, img_size // patch_size, img_size // patch_size, embed_dim 107 | ) 108 | ) 109 | 110 | self.blocks = nn.ModuleList() 111 | for i in range(depth): 112 | block = Block( 113 | dim=embed_dim, 114 | num_heads=num_heads, 115 | mlp_ratio=mlp_ratio, 116 | qkv_bias=qkv_bias, 117 | norm_layer=norm_layer, 118 | act_layer=act_layer, 119 | use_rel_pos=use_rel_pos, 120 | rel_pos_zero_init=rel_pos_zero_init, 121 | window_size=window_size if i not in global_attn_indexes else 0, 122 | input_size=(img_size // patch_size, img_size // patch_size), 123 | ) 124 | self.blocks.append(block) 125 | 126 | self.neck = nn.Sequential( 127 | nn.Conv2d( 128 | embed_dim, 129 | out_chans, 130 | kernel_size=1, 131 | bias=False, 132 | ), 133 | LayerNorm2d(out_chans), 134 | nn.Conv2d( 135 | out_chans, 136 | out_chans, 137 | kernel_size=3, 138 | padding=1, 139 | bias=False, 140 | ), 141 | LayerNorm2d(out_chans), 142 | ) 143 | 144 | in_channels = out_chans 145 | downsamples = [] 146 | for i in range(len(downsample_channels)): 147 | out_channels = downsample_channels[i] 148 | downsamples.append( 149 | nn.Conv2d( 150 | in_channels, 151 | out_channels, 152 | kernel_size=3, 153 | stride=2, 154 | padding=1, 155 | bias=False, 156 | ) 157 | ) 158 | in_channels = out_channels 159 | self.downsamples = nn.Sequential(*downsamples) 160 | 161 | self.sam_hd = True 162 | if self.sam_hd: 163 | self.hd_alpha_downsamples = nn.Parameter(torch.zeros(1)) 164 | # self.neck_hd = nn.Linear(embed_dim, embed_dim) 165 | self.neck_hd = copy.deepcopy(self.neck) 166 | # self.downsamples_hd = copy.deepcopy(self.downsamples) 167 | 168 | def forward(self, x: torch.Tensor) -> torch.Tensor: 169 | x = self.patch_embed(x) 170 | if self.pos_embed is not None: 171 | x = x + self.pos_embed 172 | 173 | global_features = [] 174 | for i, blk in enumerate(self.blocks): 175 | x = blk(x) 176 | if self.sam_hd and blk.window_size == 0: 177 | global_features.append(x) 178 | 179 | x = self.neck(x.permute(0, 3, 1, 2)) 180 | x_dtype = x.dtype 181 | x = F.interpolate( 182 | x.float(), size=(96, 96), mode="bilinear", align_corners=False 183 | ).to(x_dtype) 184 | x = self.downsamples(x) 185 | 186 | if self.sam_hd: 187 | first_global_feature = self.neck_hd(global_features[0].permute(0, 3, 1, 2)) 188 | x_dtype = first_global_feature.dtype 189 | first_global_feature = F.interpolate( 190 | first_global_feature.float(), 191 | size=(96, 96), 192 | mode="bilinear", 193 | align_corners=False, 194 | ) 195 | first_global_feature = self.downsamples(first_global_feature.to(x_dtype)) 196 | x = x + first_global_feature * self.hd_alpha_downsamples 197 | 198 | return x 199 | 200 | 201 | class Block(nn.Module): 202 | """Transformer blocks with support of window attention and residual propagation blocks""" 203 | 204 | def __init__( 205 | self, 206 | dim: int, 207 | num_heads: int, 208 | mlp_ratio: float = 4.0, 209 | qkv_bias: bool = True, 210 | norm_layer: Type[nn.Module] = nn.LayerNorm, 211 | act_layer: Type[nn.Module] = nn.GELU, 212 | use_rel_pos: bool = False, 213 | rel_pos_zero_init: bool = True, 214 | window_size: int = 0, 215 | input_size: 
Optional[Tuple[int, int]] = None, 216 | ) -> None: 217 | """ 218 | Args: 219 | dim (int): Number of input channels. 220 | num_heads (int): Number of attention heads in each ViT block. 221 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 222 | qkv_bias (bool): If True, add a learnable bias to query, key, value. 223 | norm_layer (nn.Module): Normalization layer. 224 | act_layer (nn.Module): Activation layer. 225 | use_rel_pos (bool): If True, add relative positional embeddings to the attention map. 226 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. 227 | window_size (int): Window size for window attention blocks. If it equals 0, then 228 | use global attention. 229 | input_size (tuple(int, int) or None): Input resolution for calculating the relative 230 | positional parameter size. 231 | """ 232 | super().__init__() 233 | self.norm1 = norm_layer(dim) 234 | self.attn = Attention( 235 | dim, 236 | num_heads=num_heads, 237 | qkv_bias=qkv_bias, 238 | use_rel_pos=use_rel_pos, 239 | rel_pos_zero_init=rel_pos_zero_init, 240 | input_size=input_size if window_size == 0 else (window_size, window_size), 241 | ) 242 | 243 | self.norm2 = norm_layer(dim) 244 | self.mlp = MLPBlock( 245 | embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer 246 | ) 247 | 248 | self.window_size = window_size 249 | 250 | def forward(self, x: torch.Tensor) -> torch.Tensor: 251 | shortcut = x 252 | x = self.norm1(x) 253 | # Window partition 254 | if self.window_size > 0: 255 | H, W = x.shape[1], x.shape[2] 256 | x, pad_hw = window_partition(x, self.window_size) 257 | 258 | x = self.attn(x) 259 | # Reverse window partition 260 | if self.window_size > 0: 261 | x = window_unpartition(x, self.window_size, pad_hw, (H, W)) 262 | 263 | x = shortcut + x 264 | x = x + self.mlp(self.norm2(x)) 265 | 266 | return x 267 | 268 | 269 | class Attention(nn.Module): 270 | """Multi-head Attention block with relative position embeddings.""" 271 | 272 | def __init__( 273 | self, 274 | dim: int, 275 | num_heads: int = 8, 276 | qkv_bias: bool = True, 277 | use_rel_pos: bool = False, 278 | rel_pos_zero_init: bool = True, 279 | input_size: Optional[Tuple[int, int]] = None, 280 | ) -> None: 281 | """ 282 | Args: 283 | dim (int): Number of input channels. 284 | num_heads (int): Number of attention heads. 285 | qkv_bias (bool): If True, add a learnable bias to query, key, value. 286 | rel_pos (bool): If True, add relative positional embeddings to the attention map. 287 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. 288 | input_size (tuple(int, int) or None): Input resolution for calculating the relative 289 | positional parameter size. 290 | """ 291 | super().__init__() 292 | self.num_heads = num_heads 293 | head_dim = dim // num_heads 294 | self.scale = head_dim**-0.5 295 | 296 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 297 | self.proj = nn.Linear(dim, dim) 298 | 299 | self.use_rel_pos = use_rel_pos 300 | if self.use_rel_pos: 301 | assert ( 302 | input_size is not None 303 | ), "Input size must be provided if using relative positional encoding." 
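Windowed versus global attention is the main structural knob in this backbone: blocks listed in `global_attn_indexes` attend over the whole patch grid, while the rest attend only within local windows. A small shape check, assuming the module above is importable as `deepseek_vl.models.sam` (the input tensor is random):

```python
import torch

from deepseek_vl.models.sam import Block

# Windowed block: the 64x64 token grid is padded to a multiple of 14, cut into
# 14x14 windows, attended locally, then stitched back to the original grid.
blk = Block(dim=768, num_heads=12, window_size=14, use_rel_pos=True)
x = torch.randn(1, 64, 64, 768)  # [B, H, W, C] patch-token grid
print(blk(x).shape)              # torch.Size([1, 64, 64, 768])

# window_size=0 would make the same block attend globally over all 64*64 tokens,
# which is what the entries in global_attn_indexes select in ImageEncoderViT.
```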
304 | # initialize relative positional embeddings 305 | self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) 306 | self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) 307 | 308 | def forward(self, x: torch.Tensor) -> torch.Tensor: 309 | B, H, W, _ = x.shape 310 | # qkv with shape (3, B, nHead, H * W, C) 311 | qkv = ( 312 | self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) 313 | ) 314 | # q, k, v with shape (B * nHead, H * W, C) 315 | q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0) 316 | 317 | def do_attention(q, k, v): 318 | attn = (q * self.scale) @ k.transpose(-2, -1) 319 | if self.use_rel_pos: 320 | attn = add_decomposed_rel_pos( 321 | attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W) 322 | ) 323 | 324 | attn = attn.softmax(dim=-1) 325 | x = ( 326 | (attn @ v) 327 | .view(B, self.num_heads, H, W, -1) 328 | .permute(0, 2, 3, 1, 4) 329 | .reshape(B, H, W, -1) 330 | ) 331 | 332 | return x 333 | 334 | # from haiscale.utils import on_demand_checkpoint 335 | # x = on_demand_checkpoint(do_attention, q, k, v) 336 | x = do_attention(q, k, v) 337 | x = self.proj(x) 338 | 339 | return x 340 | 341 | 342 | def window_partition( 343 | x: torch.Tensor, window_size: int 344 | ) -> Tuple[torch.Tensor, Tuple[int, int]]: 345 | """ 346 | Partition into non-overlapping windows with padding if needed. 347 | Args: 348 | x (tensor): input tokens with [B, H, W, C]. 349 | window_size (int): window size. 350 | 351 | Returns: 352 | windows: windows after partition with [B * num_windows, window_size, window_size, C]. 353 | (Hp, Wp): padded height and width before partition 354 | """ 355 | B, H, W, C = x.shape 356 | 357 | pad_h = (window_size - H % window_size) % window_size 358 | pad_w = (window_size - W % window_size) % window_size 359 | if pad_h > 0 or pad_w > 0: 360 | x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) 361 | Hp, Wp = H + pad_h, W + pad_w 362 | 363 | x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) 364 | windows = ( 365 | x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 366 | ) 367 | return windows, (Hp, Wp) 368 | 369 | 370 | def window_unpartition( 371 | windows: torch.Tensor, 372 | window_size: int, 373 | pad_hw: Tuple[int, int], 374 | hw: Tuple[int, int], 375 | ) -> torch.Tensor: 376 | """ 377 | Window unpartition into original sequences and removing padding. 378 | Args: 379 | windows (tensor): input tokens with [B * num_windows, window_size, window_size, C]. 380 | window_size (int): window size. 381 | pad_hw (Tuple): padded height and width (Hp, Wp). 382 | hw (Tuple): original height and width (H, W) before padding. 383 | 384 | Returns: 385 | x: unpartitioned sequences with [B, H, W, C]. 386 | """ 387 | Hp, Wp = pad_hw 388 | H, W = hw 389 | B = windows.shape[0] // (Hp * Wp // window_size // window_size) 390 | x = windows.view( 391 | B, Hp // window_size, Wp // window_size, window_size, window_size, -1 392 | ) 393 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) 394 | 395 | if Hp > H or Wp > W: 396 | x = x[:, :H, :W, :].contiguous() 397 | return x 398 | 399 | 400 | def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: 401 | """ 402 | Get relative positional embeddings according to the relative positions of 403 | query and key sizes. 404 | Args: 405 | q_size (int): size of query q. 406 | k_size (int): size of key k. 407 | rel_pos (Tensor): relative position embeddings (L, C). 
408 | 409 | Returns: 410 | Extracted positional embeddings according to relative positions. 411 | """ 412 | max_rel_dist = int(2 * max(q_size, k_size) - 1) 413 | # Interpolate rel pos if needed. 414 | if rel_pos.shape[0] != max_rel_dist: 415 | # Interpolate rel pos. 416 | rel_pos_resized = F.interpolate( 417 | rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), 418 | size=max_rel_dist, 419 | mode="linear", 420 | ) 421 | rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) 422 | else: 423 | rel_pos_resized = rel_pos 424 | 425 | # Scale the coords with short length if shapes for q and k are different. 426 | q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) 427 | k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) 428 | relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) 429 | 430 | return rel_pos_resized[relative_coords.long()] 431 | 432 | 433 | def add_decomposed_rel_pos( 434 | attn: torch.Tensor, 435 | q: torch.Tensor, 436 | rel_pos_h: torch.Tensor, 437 | rel_pos_w: torch.Tensor, 438 | q_size: Tuple[int, int], 439 | k_size: Tuple[int, int], 440 | ) -> torch.Tensor: 441 | """ 442 | Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. 443 | https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950 444 | Args: 445 | attn (Tensor): attention map. 446 | q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). 447 | rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis. 448 | rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis. 449 | q_size (Tuple): spatial sequence size of query q with (q_h, q_w). 450 | k_size (Tuple): spatial sequence size of key k with (k_h, k_w). 451 | 452 | Returns: 453 | attn (Tensor): attention map with added relative positional embeddings. 454 | """ 455 | q_h, q_w = q_size 456 | k_h, k_w = k_size 457 | Rh = get_rel_pos(q_h, k_h, rel_pos_h) 458 | Rw = get_rel_pos(q_w, k_w, rel_pos_w) 459 | 460 | B, _, dim = q.shape 461 | r_q = q.reshape(B, q_h, q_w, dim) 462 | rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) 463 | rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) 464 | 465 | attn = ( 466 | attn.view(B, q_h, q_w, k_h, k_w) 467 | + rel_h[:, :, :, :, None] 468 | + rel_w[:, :, :, None, :] 469 | ).view(B, q_h * q_w, k_h * k_w) 470 | 471 | return attn 472 | 473 | 474 | class PatchEmbed(nn.Module): 475 | """ 476 | Image to Patch Embedding. 477 | """ 478 | 479 | def __init__( 480 | self, 481 | kernel_size: Tuple[int, int] = (16, 16), 482 | stride: Tuple[int, int] = (16, 16), 483 | padding: Tuple[int, int] = (0, 0), 484 | in_chans: int = 3, 485 | embed_dim: int = 768, 486 | ) -> None: 487 | """ 488 | Args: 489 | kernel_size (Tuple): kernel size of the projection layer. 490 | stride (Tuple): stride of the projection layer. 491 | padding (Tuple): padding size of the projection layer. 492 | in_chans (int): Number of input image channels. 493 | embed_dim (int): Patch embedding dimension. 
494 | """ 495 | super().__init__() 496 | 497 | self.proj = nn.Conv2d( 498 | in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding 499 | ) 500 | 501 | def forward(self, x: torch.Tensor) -> torch.Tensor: 502 | x = self.proj(x) 503 | # B C H W -> B H W C 504 | x = x.permute(0, 2, 3, 1) 505 | return x 506 | 507 | 508 | @dataclass 509 | class SAMViTCfg: 510 | image_size: Union[Tuple[int, int], int] = 1024 511 | width: int = 1024 512 | layers: int = 23 513 | heads: int = 16 514 | patch_size: int = 16 515 | window_size: int = 14 516 | prompt_embed_dim: int = 256 517 | global_attn_indexes: Union[List[int], Tuple[int]] = (5, 11, 17, 23) 518 | downsample_channels: Union[List[int], Tuple[int]] = (512, 1024) 519 | 520 | 521 | SAM_MODEL_CONFIG = { 522 | "sam_vit_b": { 523 | "width": 768, 524 | "layers": 12, 525 | "heads": 12, 526 | "global_attn_indexes": [2, 5, 8, 11], 527 | "downsample_channels": (), 528 | }, 529 | "sam_b_downsample": { 530 | "width": 768, 531 | "layers": 12, 532 | "heads": 12, 533 | "global_attn_indexes": [2, 5, 8, 11], 534 | "downsample_channels": (512, 1024), 535 | }, 536 | "sam_vit_l": { 537 | "width": 1024, 538 | "layers": 24, 539 | "heads": 16, 540 | "global_attn_indexes": [5, 11, 17, 23], 541 | "downsample_channels": (), 542 | }, 543 | "sam_vit_h": { 544 | "width": 1280, 545 | "layers": 32, 546 | "heads": 16, 547 | "global_attn_indexes": [7, 15, 23, 31], 548 | "downsample_channels": (), 549 | }, 550 | } 551 | 552 | 553 | def create_sam_vit( 554 | model_name: str = "sam_b_downsample", 555 | image_size: int = 1024, 556 | ckpt_path: str = "", 557 | **kwargs, 558 | ): 559 | assert ( 560 | model_name in SAM_MODEL_CONFIG.keys() 561 | ), f"model name: {model_name} should be in {SAM_MODEL_CONFIG.keys()}" 562 | 563 | sam_cfg = SAMViTCfg(**SAM_MODEL_CONFIG[model_name]) 564 | image_encoder = ImageEncoderViT( 565 | depth=sam_cfg.layers, 566 | embed_dim=sam_cfg.width, 567 | img_size=image_size, 568 | mlp_ratio=4, 569 | norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), 570 | num_heads=sam_cfg.heads, 571 | patch_size=sam_cfg.patch_size, 572 | qkv_bias=True, 573 | use_rel_pos=True, 574 | global_attn_indexes=sam_cfg.global_attn_indexes, 575 | window_size=14, 576 | out_chans=sam_cfg.prompt_embed_dim, 577 | downsample_channels=sam_cfg.downsample_channels, 578 | ) 579 | 580 | if ckpt_path: 581 | state_dict = torch.load(ckpt_path) 582 | image_encoder.load_state_dict(state_dict, strict=False) 583 | print(f"SAM-ViT restores from {ckpt_path}") 584 | 585 | return image_encoder 586 | 587 | 588 | if __name__ == "__main__": 589 | x = torch.zeros(2, 3, 1024, 1024).bfloat16() 590 | # x.permute(0, 3, 1, 2) 591 | net = create_sam_vit().bfloat16() 592 | out = net(x) 593 | print(x.shape, out.shape) 594 | -------------------------------------------------------------------------------- /deepseek_vl/models/siglip_vit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023-2024 DeepSeek. 
2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to 6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 7 | # the Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all 11 | # copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 15 | # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 16 | # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 17 | # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 18 | # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 19 | 20 | # https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py 21 | import math 22 | import warnings 23 | from dataclasses import dataclass 24 | from functools import partial 25 | from typing import ( 26 | Callable, 27 | Dict, 28 | Final, 29 | List, 30 | Literal, 31 | Optional, 32 | Sequence, 33 | Set, 34 | Tuple, 35 | Type, 36 | Union, 37 | ) 38 | 39 | import torch 40 | import torch.nn as nn 41 | import torch.nn.functional as F 42 | from timm.layers import ( 43 | AttentionPoolLatent, 44 | DropPath, 45 | LayerType, 46 | Mlp, 47 | PatchDropout, 48 | PatchEmbed, 49 | resample_abs_pos_embed, 50 | ) 51 | from timm.models._manipulate import checkpoint_seq, named_apply 52 | 53 | 54 | def _no_grad_trunc_normal_(tensor, mean, std, a, b): 55 | # Cut & paste from PyTorch official master until it's in a few official releases - RW 56 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf 57 | def norm_cdf(x): 58 | # Computes standard normal cumulative distribution function 59 | return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 60 | 61 | if (mean < a - 2 * std) or (mean > b + 2 * std): 62 | warnings.warn( 63 | "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " 64 | "The distribution of values may be incorrect.", 65 | stacklevel=2, 66 | ) 67 | 68 | with torch.no_grad(): 69 | # Values are generated by using a truncated uniform distribution and 70 | # then using the inverse CDF for the normal distribution. 71 | # Get upper and lower cdf values 72 | l = norm_cdf((a - mean) / std) # noqa: E741 73 | u = norm_cdf((b - mean) / std) 74 | 75 | # Uniformly fill tensor with values from [l, u], then translate to 76 | # [2l-1, 2u-1]. 
77 | tensor.uniform_(2 * l - 1, 2 * u - 1) 78 | 79 | # Use inverse cdf transform for normal distribution to get truncated 80 | # standard normal 81 | tensor.erfinv_() 82 | 83 | # Transform to proper mean, std 84 | tensor.mul_(std * math.sqrt(2.0)) 85 | tensor.add_(mean) 86 | 87 | # Clamp to ensure it's in the proper range 88 | tensor.clamp_(min=a, max=b) 89 | return tensor 90 | 91 | 92 | def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0): 93 | # type: (torch.Tensor, float, float, float, float) -> torch.Tensor 94 | r"""The original timm.models.layers.weight_init.trunc_normal_ can not handle bfloat16 yet, here we first 95 | convert the tensor to float32, apply the trunc_normal_() in float32, and then convert it back to its orignal dtype. 96 | Fills the input Tensor with values drawn from a truncated normal distribution. The values are effectively drawn 97 | from the normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` 98 | with values outside :math:`[a, b]` redrawn until they are within 99 | the bounds. The method used for generating the random values works 100 | best when :math:`a \leq \text{mean} \leq b`. 101 | Args: 102 | tensor: an n-dimensional `torch.Tensor` 103 | mean: the mean of the normal distribution 104 | std: the standard deviation of the normal distribution 105 | a: the minimum cutoff value 106 | b: the maximum cutoff value 107 | Examples: 108 | >>> w = torch.empty(3, 5) 109 | >>> nn.init.trunc_normal_(w) 110 | """ 111 | 112 | with torch.no_grad(): 113 | dtype = tensor.dtype 114 | tensor_fp32 = tensor.float() 115 | tensor_fp32 = _no_grad_trunc_normal_(tensor_fp32, mean, std, a, b) 116 | tensor_dtype = tensor_fp32.to(dtype=dtype) 117 | tensor.copy_(tensor_dtype) 118 | 119 | 120 | def init_weights(self): 121 | if self.pos_embed is not None: 122 | trunc_normal_(self.pos_embed, std=self.pos_embed.shape[1] ** -0.5) 123 | trunc_normal_(self.latent, std=self.latent_dim**-0.5) 124 | 125 | 126 | def init_weights_vit_timm(module: nn.Module, name: str = "") -> None: 127 | """ViT weight initialization, original timm impl (for reproducibility)""" 128 | if isinstance(module, nn.Linear): 129 | trunc_normal_(module.weight, std=0.02) 130 | if module.bias is not None: 131 | nn.init.zeros_(module.bias) 132 | elif hasattr(module, "init_weights"): 133 | module.init_weights() 134 | 135 | 136 | class Attention(nn.Module): 137 | fused_attn: Final[bool] 138 | 139 | def __init__( 140 | self, 141 | dim: int, 142 | num_heads: int = 8, 143 | qkv_bias: bool = False, 144 | qk_norm: bool = False, 145 | attn_drop: float = 0.0, 146 | proj_drop: float = 0.0, 147 | norm_layer: nn.Module = nn.LayerNorm, 148 | ) -> None: 149 | super().__init__() 150 | assert dim % num_heads == 0, "dim should be divisible by num_heads" 151 | self.num_heads = num_heads 152 | self.head_dim = dim // num_heads 153 | self.scale = self.head_dim**-0.5 154 | # self.fused_attn = use_fused_attn() 155 | self.fused_attn = True 156 | 157 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 158 | self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() 159 | self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() 160 | self.attn_drop = nn.Dropout(attn_drop) 161 | self.proj = nn.Linear(dim, dim) 162 | self.proj_drop = nn.Dropout(proj_drop) if proj_drop > 0.0 else nn.Identity() 163 | 164 | def forward(self, x: torch.Tensor) -> torch.Tensor: 165 | B, N, C = x.shape 166 | qkv = ( 167 | self.qkv(x) 168 | .reshape(B, N, 3, self.num_heads, self.head_dim) 169 | .permute(2, 0, 3, 1, 4) 170 | ) 171 | q, k, v 
= qkv.unbind(0) 172 | q, k = self.q_norm(q), self.k_norm(k) 173 | 174 | if self.fused_attn: 175 | x = F.scaled_dot_product_attention( 176 | q, 177 | k, 178 | v, 179 | dropout_p=self.attn_drop.p if self.training else 0.0, 180 | ) 181 | else: 182 | q = q * self.scale 183 | attn = q @ k.transpose(-2, -1) 184 | attn = attn.softmax(dim=-1) 185 | attn = self.attn_drop(attn) 186 | x = attn @ v 187 | 188 | x = x.transpose(1, 2).reshape(B, N, C) 189 | x = self.proj(x) 190 | x = self.proj_drop(x) 191 | return x 192 | 193 | 194 | class LayerScale(nn.Module): 195 | def __init__( 196 | self, 197 | dim: int, 198 | init_values: float = 1e-5, 199 | inplace: bool = False, 200 | ) -> None: 201 | super().__init__() 202 | self.inplace = inplace 203 | self.gamma = nn.Parameter(init_values * torch.ones(dim)) 204 | 205 | def forward(self, x: torch.Tensor) -> torch.Tensor: 206 | return x.mul_(self.gamma) if self.inplace else x * self.gamma 207 | 208 | 209 | class Block(nn.Module): 210 | def __init__( 211 | self, 212 | dim: int, 213 | num_heads: int, 214 | mlp_ratio: float = 4.0, 215 | qkv_bias: bool = False, 216 | qk_norm: bool = False, 217 | proj_drop: float = 0.0, 218 | attn_drop: float = 0.0, 219 | init_values: Optional[float] = None, 220 | drop_path: float = 0.0, 221 | act_layer: nn.Module = nn.GELU, 222 | norm_layer: nn.Module = nn.LayerNorm, 223 | mlp_layer: nn.Module = Mlp, 224 | ) -> None: 225 | super().__init__() 226 | self.norm1 = norm_layer(dim) 227 | self.attn = Attention( 228 | dim, 229 | num_heads=num_heads, 230 | qkv_bias=qkv_bias, 231 | qk_norm=qk_norm, 232 | attn_drop=attn_drop, 233 | proj_drop=proj_drop, 234 | norm_layer=norm_layer, 235 | ) 236 | self.ls1 = ( 237 | LayerScale(dim, init_values=init_values) if init_values else nn.Identity() 238 | ) 239 | self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 240 | 241 | self.norm2 = norm_layer(dim) 242 | self.mlp = mlp_layer( 243 | in_features=dim, 244 | hidden_features=int(dim * mlp_ratio), 245 | act_layer=act_layer, 246 | drop=proj_drop, 247 | ) 248 | self.ls2 = ( 249 | LayerScale(dim, init_values=init_values) if init_values else nn.Identity() 250 | ) 251 | self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 252 | 253 | def forward(self, x: torch.Tensor) -> torch.Tensor: 254 | x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x)))) 255 | x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) 256 | return x 257 | 258 | 259 | class VisionTransformer(nn.Module): 260 | """Vision Transformer 261 | 262 | A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` 263 | - https://arxiv.org/abs/2010.11929 264 | """ 265 | 266 | dynamic_img_size: Final[bool] 267 | 268 | def __init__( 269 | self, 270 | img_size: Union[int, Tuple[int, int]] = 224, 271 | patch_size: Union[int, Tuple[int, int]] = 16, 272 | in_chans: int = 3, 273 | num_classes: int = 1000, 274 | global_pool: Literal["", "avg", "token", "map"] = "token", 275 | embed_dim: int = 768, 276 | depth: int = 12, 277 | num_heads: int = 12, 278 | mlp_ratio: float = 4.0, 279 | qkv_bias: bool = True, 280 | qk_norm: bool = False, 281 | init_values: Optional[float] = None, 282 | class_token: bool = True, 283 | no_embed_class: bool = False, 284 | reg_tokens: int = 0, 285 | pre_norm: bool = False, 286 | fc_norm: Optional[bool] = None, 287 | dynamic_img_size: bool = False, 288 | dynamic_img_pad: bool = False, 289 | drop_rate: float = 0.0, 290 | pos_drop_rate: float = 0.0, 291 | patch_drop_rate: float = 0.0, 292 | 
proj_drop_rate: float = 0.0, 293 | attn_drop_rate: float = 0.0, 294 | drop_path_rate: float = 0.0, 295 | weight_init: Literal["skip", "jax", "jax_nlhb", "moco", ""] = "", 296 | embed_layer: Callable = PatchEmbed, 297 | norm_layer: Optional[LayerType] = None, 298 | act_layer: Optional[LayerType] = None, 299 | block_fn: Type[nn.Module] = Block, 300 | mlp_layer: Type[nn.Module] = Mlp, 301 | ignore_head: bool = False, 302 | ) -> None: 303 | """ 304 | Args: 305 | img_size: Input image size. 306 | patch_size: Patch size. 307 | in_chans: Number of image input channels. 308 | num_classes: Mumber of classes for classification head. 309 | global_pool: Type of global pooling for final sequence (default: 'token'). 310 | embed_dim: Transformer embedding dimension. 311 | depth: Depth of transformer. 312 | num_heads: Number of attention heads. 313 | mlp_ratio: Ratio of mlp hidden dim to embedding dim. 314 | qkv_bias: Enable bias for qkv projections if True. 315 | init_values: Layer-scale init values (layer-scale enabled if not None). 316 | class_token: Use class token. 317 | no_embed_class: Don't include position embeddings for class (or reg) tokens. 318 | reg_tokens: Number of register tokens. 319 | fc_norm: Pre head norm after pool (instead of before), if None, enabled when global_pool == 'avg'. 320 | drop_rate: Head dropout rate. 321 | pos_drop_rate: Position embedding dropout rate. 322 | attn_drop_rate: Attention dropout rate. 323 | drop_path_rate: Stochastic depth rate. 324 | weight_init: Weight initialization scheme. 325 | embed_layer: Patch embedding layer. 326 | norm_layer: Normalization layer. 327 | act_layer: MLP activation layer. 328 | block_fn: Transformer block layer. 329 | """ 330 | super().__init__() 331 | assert global_pool in ("", "avg", "token", "map") 332 | assert class_token or global_pool != "token" 333 | use_fc_norm = global_pool == "avg" if fc_norm is None else fc_norm 334 | # norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6) 335 | # act_layer = get_act_layer(act_layer) or nn.GELU 336 | norm_layer = partial(nn.LayerNorm, eps=1e-6) 337 | act_layer = nn.GELU 338 | 339 | self.num_classes = num_classes 340 | self.global_pool = global_pool 341 | self.num_features = self.embed_dim = ( 342 | embed_dim # num_features for consistency with other models 343 | ) 344 | self.num_prefix_tokens = 1 if class_token else 0 345 | self.num_prefix_tokens += reg_tokens 346 | self.num_reg_tokens = reg_tokens 347 | self.has_class_token = class_token 348 | self.no_embed_class = ( 349 | no_embed_class # don't embed prefix positions (includes reg) 350 | ) 351 | self.dynamic_img_size = dynamic_img_size 352 | self.grad_checkpointing = False 353 | self.ignore_head = ignore_head 354 | 355 | embed_args = {} 356 | if dynamic_img_size: 357 | # flatten deferred until after pos embed 358 | embed_args.update(dict(strict_img_size=False, output_fmt="NHWC")) 359 | self.patch_embed = embed_layer( 360 | img_size=img_size, 361 | patch_size=patch_size, 362 | in_chans=in_chans, 363 | embed_dim=embed_dim, 364 | bias=not pre_norm, # disable bias if pre-norm is used (e.g. 
CLIP) 365 | dynamic_img_pad=dynamic_img_pad, 366 | **embed_args, 367 | ) 368 | num_patches = self.patch_embed.num_patches 369 | 370 | self.cls_token = ( 371 | nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None 372 | ) 373 | self.reg_token = ( 374 | nn.Parameter(torch.zeros(1, reg_tokens, embed_dim)) if reg_tokens else None 375 | ) 376 | embed_len = ( 377 | num_patches if no_embed_class else num_patches + self.num_prefix_tokens 378 | ) 379 | self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * 0.02) 380 | self.pos_drop = nn.Dropout(p=pos_drop_rate) 381 | if patch_drop_rate > 0: 382 | self.patch_drop = PatchDropout( 383 | patch_drop_rate, 384 | num_prefix_tokens=self.num_prefix_tokens, 385 | ) 386 | else: 387 | self.patch_drop = nn.Identity() 388 | self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity() 389 | 390 | dpr = [ 391 | x.item() for x in torch.linspace(0, drop_path_rate, depth) 392 | ] # stochastic depth decay rule 393 | self.blocks = nn.Sequential( 394 | *[ 395 | block_fn( 396 | dim=embed_dim, 397 | num_heads=num_heads, 398 | mlp_ratio=mlp_ratio, 399 | qkv_bias=qkv_bias, 400 | qk_norm=qk_norm, 401 | init_values=init_values, 402 | proj_drop=proj_drop_rate, 403 | attn_drop=attn_drop_rate, 404 | drop_path=dpr[i], 405 | norm_layer=norm_layer, 406 | act_layer=act_layer, 407 | mlp_layer=mlp_layer, 408 | ) 409 | for i in range(depth) 410 | ] 411 | ) 412 | self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity() 413 | 414 | # Classifier Head 415 | if global_pool == "map": 416 | AttentionPoolLatent.init_weights = init_weights 417 | self.attn_pool = AttentionPoolLatent( 418 | self.embed_dim, 419 | num_heads=num_heads, 420 | mlp_ratio=mlp_ratio, 421 | norm_layer=norm_layer, 422 | ) 423 | else: 424 | self.attn_pool = None 425 | self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity() 426 | self.head_drop = nn.Dropout(drop_rate) 427 | self.head = ( 428 | nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 429 | ) 430 | 431 | if weight_init != "skip": 432 | self.init_weights(weight_init) 433 | 434 | def init_weights(self, mode: Literal["jax", "jax_nlhb", "moco", ""] = "") -> None: 435 | assert mode in ("jax", "jax_nlhb", "moco", "") 436 | # head_bias = -math.log(self.num_classes) if "nlhb" in mode else 0.0 437 | trunc_normal_(self.pos_embed, std=0.02) 438 | if self.cls_token is not None: 439 | nn.init.normal_(self.cls_token, std=1e-6) 440 | named_apply(init_weights_vit_timm, self) 441 | 442 | @torch.jit.ignore 443 | def no_weight_decay(self) -> Set: 444 | return {"pos_embed", "cls_token", "dist_token"} 445 | 446 | @torch.jit.ignore 447 | def group_matcher(self, coarse: bool = False) -> Dict: 448 | return dict( 449 | stem=r"^cls_token|pos_embed|patch_embed", # stem and embed 450 | blocks=[(r"^blocks\.(\d+)", None), (r"^norm", (99999,))], 451 | ) 452 | 453 | @torch.jit.ignore 454 | def set_grad_checkpointing(self, enable: bool = True) -> None: 455 | self.grad_checkpointing = enable 456 | 457 | @torch.jit.ignore 458 | def get_classifier(self) -> nn.Module: 459 | return self.head 460 | 461 | def reset_classifier(self, num_classes: int, global_pool=None) -> None: 462 | self.num_classes = num_classes 463 | if global_pool is not None: 464 | assert global_pool in ("", "avg", "token", "map") 465 | if global_pool == "map" and self.attn_pool is None: 466 | assert ( 467 | False 468 | ), "Cannot currently add attention pooling in reset_classifier()." 
469 | elif global_pool != "map " and self.attn_pool is not None: 470 | self.attn_pool = None # remove attention pooling 471 | self.global_pool = global_pool 472 | self.head = ( 473 | nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 474 | ) 475 | 476 | def _pos_embed(self, x: torch.Tensor) -> torch.Tensor: 477 | if self.dynamic_img_size: 478 | B, H, W, C = x.shape 479 | pos_embed = resample_abs_pos_embed( 480 | self.pos_embed, 481 | (H, W), 482 | num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens, 483 | ) 484 | x = x.view(B, -1, C) 485 | else: 486 | pos_embed = self.pos_embed 487 | 488 | to_cat = [] 489 | if self.cls_token is not None: 490 | to_cat.append(self.cls_token.expand(x.shape[0], -1, -1)) 491 | if self.reg_token is not None: 492 | to_cat.append(self.reg_token.expand(x.shape[0], -1, -1)) 493 | 494 | if self.no_embed_class: 495 | # deit-3, updated JAX (big vision) 496 | # position embedding does not overlap with class token, add then concat 497 | x = x + pos_embed 498 | if to_cat: 499 | x = torch.cat(to_cat + [x], dim=1) 500 | else: 501 | # original timm, JAX, and deit vit impl 502 | # pos_embed has entry for class token, concat then add 503 | if to_cat: 504 | x = torch.cat(to_cat + [x], dim=1) 505 | x = x + pos_embed 506 | 507 | return self.pos_drop(x) 508 | 509 | def _intermediate_layers( 510 | self, 511 | x: torch.Tensor, 512 | n: Union[int, Sequence] = 1, 513 | ) -> List[torch.Tensor]: 514 | outputs, num_blocks = [], len(self.blocks) 515 | take_indices = set( 516 | range(num_blocks - n, num_blocks) if isinstance(n, int) else n 517 | ) 518 | 519 | # forward pass 520 | x = self.patch_embed(x) 521 | x = self._pos_embed(x) 522 | x = self.patch_drop(x) 523 | x = self.norm_pre(x) 524 | for i, blk in enumerate(self.blocks): 525 | x = blk(x) 526 | if i in take_indices: 527 | outputs.append(x) 528 | 529 | return outputs 530 | 531 | def get_intermediate_layers( 532 | self, 533 | x: torch.Tensor, 534 | n: Union[int, Sequence] = 1, 535 | reshape: bool = False, 536 | return_prefix_tokens: bool = False, 537 | norm: bool = False, 538 | ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: 539 | """Intermediate layer accessor (NOTE: This is a WIP experiment). 
540 | Inspired by DINO / DINOv2 interface 541 | """ 542 | # take last n blocks if n is an int, if in is a sequence, select by matching indices 543 | outputs = self._intermediate_layers(x, n) 544 | if norm: 545 | outputs = [self.norm(out) for out in outputs] 546 | prefix_tokens = [out[:, 0 : self.num_prefix_tokens] for out in outputs] 547 | outputs = [out[:, self.num_prefix_tokens :] for out in outputs] 548 | 549 | if reshape: 550 | grid_size = self.patch_embed.grid_size 551 | outputs = [ 552 | out.reshape(x.shape[0], grid_size[0], grid_size[1], -1) 553 | .permute(0, 3, 1, 2) 554 | .contiguous() 555 | for out in outputs 556 | ] 557 | 558 | if return_prefix_tokens: 559 | return tuple(zip(outputs, prefix_tokens)) 560 | return tuple(outputs) 561 | 562 | def forward_features(self, x: torch.Tensor) -> torch.Tensor: 563 | x = self.patch_embed(x) 564 | x = self._pos_embed(x) 565 | x = self.patch_drop(x) 566 | x = self.norm_pre(x) 567 | if self.grad_checkpointing and not torch.jit.is_scripting(): 568 | x = checkpoint_seq(self.blocks, x) 569 | else: 570 | x = self.blocks(x) 571 | x = self.norm(x) 572 | return x 573 | 574 | def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor: 575 | if self.attn_pool is not None: 576 | x = self.attn_pool(x) 577 | elif self.global_pool == "avg": 578 | x = x[:, self.num_prefix_tokens :].mean(dim=1) 579 | elif self.global_pool: 580 | x = x[:, 0] # class token 581 | x = self.fc_norm(x) 582 | x = self.head_drop(x) 583 | return x if pre_logits else self.head(x) 584 | 585 | def forward(self, x: torch.Tensor) -> torch.Tensor: 586 | x = self.forward_features(x) 587 | if not self.ignore_head: 588 | x = self.forward_head(x) 589 | return x 590 | 591 | 592 | @dataclass 593 | class SigLIPVisionCfg: 594 | width: int = 1152 595 | layers: Union[Tuple[int, int, int, int], int] = 27 596 | heads: int = 16 597 | patch_size: int = 14 598 | image_size: Union[Tuple[int, int], int] = 336 599 | global_pool: str = "map" 600 | mlp_ratio: float = 3.7362 601 | class_token: bool = False 602 | num_classes: int = 0 603 | use_checkpoint: bool = False 604 | 605 | 606 | SigLIP_MODEL_CONFIG = { 607 | "siglip_so400m_patch14_384": { 608 | "image_size": 336, 609 | "patch_size": 14, 610 | "width": 1152, 611 | "layers": 27, 612 | "heads": 16, 613 | "mlp_ratio": 3.7362, 614 | "global_pool": "map", 615 | "use_checkpoint": False, 616 | }, 617 | "siglip_so400m_patch14_224": { 618 | "image_size": 224, 619 | "patch_size": 14, 620 | "width": 1152, 621 | "layers": 27, 622 | "heads": 16, 623 | "mlp_ratio": 3.7362, 624 | "global_pool": "map", 625 | "use_checkpoint": False, 626 | }, 627 | "siglip_large_patch16_384": { 628 | "image_size": 384, 629 | "patch_size": 16, 630 | "width": 1024, 631 | "layers": 24, 632 | "heads": 16, 633 | "mlp_ratio": 4, 634 | "global_pool": "map", 635 | "use_checkpoint": False, 636 | }, 637 | } 638 | 639 | 640 | def create_siglip_vit( 641 | model_name: str = "siglip_so400m_patch14_384", 642 | image_size: int = 384, 643 | select_layer: int = -1, 644 | ckpt_path: str = "", 645 | **kwargs, 646 | ): 647 | assert ( 648 | model_name in SigLIP_MODEL_CONFIG.keys() 649 | ), f"model name should be in {SigLIP_MODEL_CONFIG.keys()}" 650 | 651 | vision_cfg = SigLIPVisionCfg(**SigLIP_MODEL_CONFIG[model_name]) 652 | 653 | if select_layer <= 0: 654 | layers = min(vision_cfg.layers, vision_cfg.layers + select_layer + 1) 655 | else: 656 | layers = min(vision_cfg.layers, select_layer) 657 | 658 | model = VisionTransformer( 659 | img_size=image_size, 660 | 
patch_size=vision_cfg.patch_size, 661 | embed_dim=vision_cfg.width, 662 | depth=layers, 663 | num_heads=vision_cfg.heads, 664 | mlp_ratio=vision_cfg.mlp_ratio, 665 | class_token=vision_cfg.class_token, 666 | global_pool=vision_cfg.global_pool, 667 | ignore_head=kwargs.get("ignore_head", True), 668 | weight_init=kwargs.get("weight_init", "skip"), 669 | num_classes=0, 670 | ) 671 | 672 | if ckpt_path: 673 | state_dict = torch.load(ckpt_path, map_location="cpu") 674 | 675 | incompatible_keys = model.load_state_dict(state_dict, strict=False) 676 | print( 677 | f"SigLIP-ViT restores from {ckpt_path},\n" 678 | f"\tincompatible_keys:', {incompatible_keys}." 679 | ) 680 | 681 | return model 682 | -------------------------------------------------------------------------------- /deepseek_vl/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023-2024 DeepSeek. 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to 6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 7 | # the Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all 11 | # copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 15 | # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 16 | # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 17 | # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 18 | # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 19 | -------------------------------------------------------------------------------- /deepseek_vl/utils/conversation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023-2024 DeepSeek. 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to 6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 7 | # the Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all 11 | # copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 15 | # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 16 | # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 17 | # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 18 | # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
19 | 20 | """ 21 | From https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py 22 | """ 23 | 24 | import dataclasses 25 | from enum import IntEnum, auto 26 | from typing import Dict, List 27 | 28 | 29 | class SeparatorStyle(IntEnum): 30 | """Separator styles.""" 31 | 32 | ADD_COLON_SINGLE = auto() 33 | ADD_COLON_TWO = auto() 34 | ADD_COLON_SPACE_SINGLE = auto() 35 | NO_COLON_SINGLE = auto() 36 | NO_COLON_TWO = auto() 37 | ADD_NEW_LINE_SINGLE = auto() 38 | LLAMA2 = auto() 39 | CHATGLM = auto() 40 | CHATML = auto() 41 | CHATINTERN = auto() 42 | DOLLY = auto() 43 | RWKV = auto() 44 | PHOENIX = auto() 45 | ROBIN = auto() 46 | DeepSeek = auto() 47 | PLAIN = auto() 48 | ALIGNMENT = auto() 49 | 50 | 51 | @dataclasses.dataclass 52 | class Conversation: 53 | """A class that manages prompt templates and keeps all conversation history.""" 54 | 55 | # The name of this template 56 | name: str 57 | # The template of the system prompt 58 | system_template: str = "{system_message}" 59 | # The system message 60 | system_message: str = "" 61 | # The names of two roles 62 | roles: List[str] = (("USER", "ASSISTANT"),) 63 | # All messages. Each item is (role, message). 64 | messages: List[List[str]] = () 65 | # The number of few shot examples 66 | offset: int = 0 67 | # The separator style and configurations 68 | sep_style: SeparatorStyle = SeparatorStyle.ADD_COLON_SINGLE 69 | sep: str = "\n" 70 | sep2: str = None 71 | # Stop criteria (the default one is EOS token) 72 | stop_str: str = None 73 | # Stops generation if meeting any token in this list 74 | stop_token_ids: List[int] = None 75 | 76 | def get_prompt(self) -> str: 77 | """Get the prompt for generation.""" 78 | system_prompt = self.system_template.format(system_message=self.system_message) 79 | 80 | if self.sep_style == SeparatorStyle.DeepSeek: 81 | seps = [self.sep, self.sep2] 82 | if system_prompt == "" or system_prompt is None: 83 | ret = "" 84 | else: 85 | ret = system_prompt + seps[0] 86 | for i, (role, message) in enumerate(self.messages): 87 | if message: 88 | ret += role + ": " + message + seps[i % 2] 89 | else: 90 | ret += role + ":" 91 | return ret 92 | elif self.sep_style == SeparatorStyle.LLAMA2: 93 | seps = [self.sep, self.sep2] 94 | if self.system_message: 95 | ret = system_prompt 96 | else: 97 | ret = "[INST] " 98 | for i, (role, message) in enumerate(self.messages): 99 | tag = self.roles[i % 2] 100 | if message: 101 | if type(message) is tuple: # multimodal message 102 | message, _ = message 103 | if i == 0: 104 | ret += message + " " 105 | else: 106 | ret += tag + " " + message + seps[i % 2] 107 | else: 108 | ret += tag 109 | return ret 110 | elif self.sep_style == SeparatorStyle.PLAIN: 111 | seps = [self.sep, self.sep2] 112 | ret = "" 113 | for i, (role, message) in enumerate(self.messages): 114 | if message: 115 | if type(message) is tuple: 116 | message, _, _ = message 117 | if i % 2 == 0: 118 | ret += message + seps[i % 2] 119 | else: 120 | ret += message + seps[i % 2] 121 | else: 122 | ret += "" 123 | return ret 124 | elif self.sep_style == SeparatorStyle.ALIGNMENT: 125 | seps = [self.sep, self.sep2] 126 | ret = "" 127 | for i, (role, message) in enumerate(self.messages): 128 | if message: 129 | if type(message) is tuple: 130 | message, _, _ = message 131 | if i % 2 == 0: 132 | ret += "\n" + seps[i % 2] 133 | else: 134 | ret += message + seps[i % 2] 135 | else: 136 | ret += "" 137 | return ret 138 | else: 139 | raise ValueError(f"Invalid style: {self.sep_style}") 140 | 141 | def get_prompt_for_current_round(self, 
content=None): 142 | """Get current round formatted question prompt during sft training""" 143 | if self.sep_style == SeparatorStyle.PLAIN: 144 | formatted_question = "\n" 145 | elif self.sep_style == SeparatorStyle.DeepSeek: 146 | formatted_question = ( 147 | f"{self.roles[0]}: " + content.strip() + self.sep + f"{self.roles[1]}:" 148 | ) 149 | else: 150 | raise ValueError(f"Unsupported sep_style: {self.sep_style}") 151 | return formatted_question 152 | 153 | def set_system_message(self, system_message: str): 154 | """Set the system message.""" 155 | self.system_message = system_message 156 | 157 | def append_message(self, role: str, message: str): 158 | """Append a new message.""" 159 | self.messages.append([role, message]) 160 | 161 | def reset_message(self): 162 | """Reset a new message.""" 163 | self.messages = [] 164 | 165 | def update_last_message(self, message: str): 166 | """Update the last output. 167 | 168 | The last message is typically set to be None when constructing the prompt, 169 | so we need to update it in-place after getting the response from a model. 170 | """ 171 | self.messages[-1][1] = message 172 | 173 | def to_gradio_chatbot(self): 174 | """Convert the conversation to gradio chatbot format.""" 175 | ret = [] 176 | for i, (role, msg) in enumerate(self.messages[self.offset :]): 177 | if i % 2 == 0: 178 | ret.append([msg, None]) 179 | else: 180 | ret[-1][-1] = msg 181 | return ret 182 | 183 | def to_openai_api_messages(self): 184 | """Convert the conversation to OpenAI chat completion format.""" 185 | system_prompt = self.system_template.format(system_message=self.system_message) 186 | ret = [{"role": "system", "content": system_prompt}] 187 | 188 | for i, (_, msg) in enumerate(self.messages[self.offset :]): 189 | if i % 2 == 0: 190 | ret.append({"role": "user", "content": msg}) 191 | else: 192 | if msg is not None: 193 | ret.append({"role": "assistant", "content": msg}) 194 | return ret 195 | 196 | def copy(self): 197 | return Conversation( 198 | name=self.name, 199 | system_template=self.system_template, 200 | system_message=self.system_message, 201 | roles=self.roles, 202 | messages=[[x, y] for x, y in self.messages], 203 | offset=self.offset, 204 | sep_style=self.sep_style, 205 | sep=self.sep, 206 | sep2=self.sep2, 207 | stop_str=self.stop_str, 208 | stop_token_ids=self.stop_token_ids, 209 | ) 210 | 211 | def dict(self): 212 | return { 213 | "template_name": self.name, 214 | "system_message": self.system_message, 215 | "roles": self.roles, 216 | "messages": self.messages, 217 | "offset": self.offset, 218 | } 219 | 220 | 221 | # A global registry for all conversation templates 222 | conv_templates: Dict[str, Conversation] = {} 223 | 224 | 225 | def register_conv_template(template: Conversation, override: bool = False): 226 | """Register a new conversation template.""" 227 | if not override: 228 | assert ( 229 | template.name not in conv_templates 230 | ), f"{template.name} has been registered." 231 | 232 | conv_templates[template.name] = template 233 | 234 | 235 | def get_conv_template(name: str) -> Conversation: 236 | """Get a conversation template.""" 237 | return conv_templates[name].copy() 238 | 239 | 240 | # llava_llama2 template 241 | register_conv_template( 242 | Conversation( 243 | name="llava_llama2", 244 | system_message="You are a helpful language and vision assistant. 
" 245 | "You are able to understand the visual content that the user provides, " 246 | "and assist the user with a variety of tasks using natural language.", 247 | system_template="[INST] <>\n{system_message}\n<>\n\n", 248 | roles=("[INST]", "[/INST]"), 249 | messages=(), 250 | offset=0, 251 | sep_style=SeparatorStyle.LLAMA2, 252 | sep=" ", 253 | sep2=" ", 254 | stop_token_ids=[2], 255 | ) 256 | ) 257 | 258 | # llama2 template 259 | # reference: https://github.com/facebookresearch/llama/blob/cfc3fc8c1968d390eb830e65c63865e980873a06/llama/generation.py#L212 260 | register_conv_template( 261 | Conversation( 262 | name="llama-2", 263 | system_template="[INST] <>\n{system_message}\n<>\n\n", 264 | roles=("[INST]", "[/INST]"), 265 | messages=(), 266 | offset=0, 267 | sep_style=SeparatorStyle.LLAMA2, 268 | sep=" ", 269 | sep2=" ", 270 | stop_token_ids=[2], 271 | ) 272 | ) 273 | 274 | 275 | # deepseek template 276 | register_conv_template( 277 | Conversation( 278 | name="deepseek", 279 | system_template="{system_message}", 280 | # system_message="You are a helpful assistant. Please answer truthfully and write out your " 281 | # "thinking step by step to be sure you get the right answer.", 282 | system_message="", 283 | roles=("User", "Assistant"), 284 | messages=(), 285 | offset=0, 286 | sep_style=SeparatorStyle.DeepSeek, 287 | sep="\n\n", 288 | sep2="<|end▁of▁sentence|>", 289 | stop_token_ids=[100001], 290 | stop_str=["User:", "<|end▁of▁sentence|>"], 291 | ) 292 | ) 293 | 294 | register_conv_template( 295 | Conversation( 296 | name="plain", 297 | system_template="", 298 | system_message="", 299 | roles=("", ""), 300 | messages=(), 301 | offset=0, 302 | sep_style=SeparatorStyle.PLAIN, 303 | sep="", 304 | sep2="", 305 | stop_token_ids=[2], 306 | stop_str=[""], 307 | ) 308 | ) 309 | 310 | 311 | register_conv_template( 312 | Conversation( 313 | name="alignment", 314 | system_template="", 315 | system_message="", 316 | roles=("", ""), 317 | messages=(), 318 | offset=0, 319 | sep_style=SeparatorStyle.ALIGNMENT, 320 | sep="", 321 | sep2="", 322 | stop_token_ids=[2], 323 | stop_str=[""], 324 | ) 325 | ) 326 | 327 | 328 | if __name__ == "__main__": 329 | # print("Llama-2 template:") 330 | # conv = get_conv_template("llama-2") 331 | # conv.set_system_message("You are a helpful, respectful and honest assistant.") 332 | # conv.append_message(conv.roles[0], "Hello!") 333 | # conv.append_message(conv.roles[1], "Hi!") 334 | # conv.append_message(conv.roles[0], "How are you?") 335 | # conv.append_message(conv.roles[1], None) 336 | # print(conv.get_prompt()) 337 | 338 | # print("\n") 339 | 340 | print("deepseek template:") 341 | conv = get_conv_template("deepseek") 342 | conv.append_message(conv.roles[0], "Hello!") 343 | conv.append_message(conv.roles[1], "Hi! This is Tony.") 344 | conv.append_message(conv.roles[0], "Who are you?") 345 | conv.append_message(conv.roles[1], "I am a helpful assistant.") 346 | conv.append_message(conv.roles[0], "How are you?") 347 | conv.append_message(conv.roles[1], None) 348 | print(conv.get_prompt()) 349 | -------------------------------------------------------------------------------- /deepseek_vl/utils/io.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023-2024 DeepSeek. 
2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to 6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 7 | # the Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all 11 | # copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 15 | # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 16 | # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 17 | # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 18 | # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 19 | 20 | import json 21 | from typing import Dict, List 22 | 23 | import PIL.Image 24 | import torch 25 | from transformers import AutoModelForCausalLM 26 | 27 | from deepseek_vl.models import MultiModalityCausalLM, VLChatProcessor 28 | 29 | 30 | def load_pretrained_model(model_path: str): 31 | vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path) 32 | tokenizer = vl_chat_processor.tokenizer 33 | 34 | vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained( 35 | model_path, trust_remote_code=True 36 | ) 37 | vl_gpt = vl_gpt.to(torch.bfloat16).cuda().eval() 38 | 39 | return tokenizer, vl_chat_processor, vl_gpt 40 | 41 | 42 | def load_pil_images(conversations: List[Dict[str, str]]) -> List[PIL.Image.Image]: 43 | """ 44 | 45 | Args: 46 | conversations (List[Dict[str, str]]): the conversations with a list of messages. An example is : 47 | [ 48 | { 49 | "role": "User", 50 | "content": "\nExtract all information from this image and convert them into markdown format.", 51 | "images": ["./examples/table_datasets.png"] 52 | }, 53 | {"role": "Assistant", "content": ""}, 54 | ] 55 | 56 | Returns: 57 | pil_images (List[PIL.Image.Image]): the list of PIL images. 
58 | 59 | """ 60 | 61 | pil_images = [] 62 | 63 | for message in conversations: 64 | if "images" not in message: 65 | continue 66 | 67 | for image_path in message["images"]: 68 | pil_img = PIL.Image.open(image_path) 69 | pil_img = pil_img.convert("RGB") 70 | pil_images.append(pil_img) 71 | 72 | return pil_images 73 | 74 | 75 | def load_json(filepath): 76 | with open(filepath, "r") as f: 77 | data = json.load(f) 78 | return data 79 | -------------------------------------------------------------------------------- /nodes.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torchvision.transforms import ToPILImage 4 | 5 | from transformers import AutoModelForCausalLM 6 | from .deepseek_vl.models import VLChatProcessor, MultiModalityCausalLM 7 | 8 | import comfy.model_management as mm 9 | from comfy.utils import ProgressBar 10 | import folder_paths 11 | 12 | script_directory = os.path.dirname(os.path.abspath(__file__)) 13 | 14 | class deepseek_vl_model_loader: 15 | 16 | @classmethod 17 | def INPUT_TYPES(s): 18 | return {"required": { 19 | "model": ( 20 | [ 21 | "deepseek-vl-1.3b-chat", 22 | "deepseek-vl-1.3b-base", 23 | "deepseek-vl-7b-chat", 24 | "deepseek-vl-7b-base", 25 | ], 26 | { 27 | "default": "deepseek-vl-7b-chat" 28 | }), 29 | }, 30 | } 31 | 32 | RETURN_TYPES = ("DEEPSEEKVLMODEL",) 33 | RETURN_NAMES = ("deepseek_model",) 34 | FUNCTION = "loadmodel" 35 | CATEGORY = "deepseek-vl" 36 | 37 | def loadmodel(self, model): 38 | mm.soft_empty_cache() 39 | dtype = mm.vae_dtype() 40 | device = mm.get_torch_device() 41 | custom_config = { 42 | "model": model, 43 | } 44 | if not hasattr(self, "model") or custom_config != self.current_config: 45 | self.current_config = custom_config 46 | model_dir = (os.path.join(folder_paths.models_dir, "LLM", "deepseek-vl")) 47 | checkpoint_path = os.path.join(model_dir, model) 48 | 49 | if not os.path.exists(checkpoint_path): 50 | print(f"Downloading {model}") 51 | from huggingface_hub import snapshot_download 52 | 53 | snapshot_download(repo_id=f"deepseek-ai/{model}", 54 | local_dir=checkpoint_path, 55 | local_dir_use_symlinks=False 56 | ) 57 | model_path = checkpoint_path 58 | else: 59 | model_path = os.path.join(folder_paths.models_dir, "LLM", "deepseek-vl", model) 60 | print(f"Loading model from {model_path}") 61 | 62 | vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path) 63 | tokenizer = vl_chat_processor.tokenizer 64 | 65 | vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True) 66 | vl_gpt = vl_gpt.to(dtype).to(device).eval() 67 | 68 | deepseek_vl_model = { 69 | "chat_processor": vl_chat_processor, 70 | "model": vl_gpt, 71 | "tokenizer": tokenizer 72 | } 73 | 74 | return (deepseek_vl_model,) 75 | 76 | class deepseek_vl_inference: 77 | @classmethod 78 | def INPUT_TYPES(s): 79 | return {"required": { 80 | "images": ("IMAGE",), 81 | "deepseek_vl_model": ("DEEPSEEKVLMODEL",), 82 | #"parameters": ("LLAMACPPARAMS", ), 83 | "prompt": ("STRING", {"multiline": True, "default": "Describe the image in detail.",}), 84 | 85 | }, 86 | } 87 | 88 | RETURN_TYPES = ("STRING",) 89 | RETURN_NAMES = ("output",) 90 | FUNCTION = "process" 91 | CATEGORY = "Llama-cpp" 92 | 93 | def process(self, images, deepseek_vl_model, prompt): 94 | 95 | mm.soft_empty_cache() 96 | device = mm.get_torch_device() 97 | offload_device = mm.unet_offload_device() 98 | images = images.permute(0, 3, 1, 2) 99 | to_pil = ToPILImage() 100 | 101 | 102 | 
vl_chat_processor = deepseek_vl_model["chat_processor"] 103 | model = deepseek_vl_model["model"] 104 | tokenizer = deepseek_vl_model["tokenizer"] 105 | ## single image conversation example 106 | 107 | conversation = [ 108 | { 109 | "role": "User", 110 | "content": f"{prompt}", 111 | }, 112 | {"role": "Assistant", "content": ""}, 113 | ] 114 | 115 | ## multiple images (or in-context learning) conversation example 116 | # conversation = [ 117 | # { 118 | # "role": "User", 119 | # "content": "A dog wearing nothing in the foreground, " 120 | # "a dog wearing a santa hat, " 121 | # "a dog wearing a wizard outfit, and " 122 | # "what"s the dog wearing?", 123 | # "images": [ 124 | # "images/dog_a.png", 125 | # "images/dog_b.png", 126 | # "images/dog_c.png", 127 | # "images/dog_d.png", 128 | # ], 129 | # }, 130 | # {"role": "Assistant", "content": ""} 131 | # ] 132 | pbar = ProgressBar(len(images)) 133 | answer_list = [] 134 | model.to(device) 135 | for img in images: 136 | pil_image = to_pil(img) 137 | prepare_inputs = vl_chat_processor( 138 | conversations=conversation, 139 | images=[pil_image], 140 | force_batchify=True 141 | ).to(device) 142 | 143 | # run image encoder to get the image embeddings 144 | inputs_embeds = model.prepare_inputs_embeds(**prepare_inputs) 145 | 146 | # run the model to get the response 147 | outputs = model.language_model.generate( 148 | inputs_embeds=inputs_embeds, 149 | attention_mask=prepare_inputs.attention_mask, 150 | pad_token_id=tokenizer.eos_token_id, 151 | bos_token_id=tokenizer.bos_token_id, 152 | eos_token_id=tokenizer.eos_token_id, 153 | max_new_tokens=512, 154 | do_sample=False, 155 | use_cache=True 156 | ) 157 | answer = tokenizer.decode(outputs[0].cpu(), skip_special_tokens=True) 158 | answer = answer.lstrip(" [User]\n\n") 159 | answer_list.append(answer) 160 | pbar.update(1) 161 | 162 | model.to(offload_device) 163 | #print(f"{prepare_inputs['sft_format'][0]}", answer) 164 | if (len(images)) > 1: 165 | return (answer_list,) 166 | else: 167 | return (answer_list[0],) 168 | 169 | class parameters: 170 | @classmethod 171 | def INPUT_TYPES(s): 172 | return {"required": { 173 | "max_tokens": ("INT", {"default": 32, "min": 0, "max": 4096, "step": 1}), 174 | "top_k": ("INT", {"default": 40, "min": 0, "max": 1000, "step": 1}), 175 | "top_p": ("FLOAT", {"default": 0.95, "min": 0.0, "max": 1.0, "step": 0.01}), 176 | "min_p": ("FLOAT", {"default": 0.05, "min": 0.0, "max": 1.0, "step": 0.01}), 177 | "typical_p": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), 178 | "temperature": ("FLOAT", {"default": 0.8, "min": 0.0, "max": 1.0, "step": 0.01}), 179 | "repeat_penalty": ("FLOAT", {"default": 1.1, "min": 0.0, "max": 10.0, "step": 0.01}), 180 | "frequency_penalty": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}), 181 | "presence_penalty": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}), 182 | "tfs_z": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), 183 | "mirostat_mode": ("INT", {"default": 0, "min": 0, "max": 1, "step": 1}), 184 | "mirostat_eta": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 1.0, "step": 0.01}), 185 | "mirostat_tau": ("FLOAT", {"default": 5.0, "min": 0.0, "max": 10.0, "step": 0.01}), 186 | } 187 | } 188 | 189 | RETURN_TYPES = ("LLAMACPPARAMS",) 190 | RETURN_NAMES = ("parameters",) 191 | FUNCTION = "process" 192 | CATEGORY = "Llama-cpp" 193 | 194 | def process(self, max_tokens, top_k, top_p, min_p, typical_p, temperature, repeat_penalty, 195 | frequency_penalty, 
presence_penalty, tfs_z, mirostat_mode, mirostat_eta, mirostat_tau, 196 | ): 197 | 198 | parameters_dict = { 199 | "max_tokens": max_tokens, 200 | "top_k": top_k, 201 | "top_p": top_p, 202 | "min_p": min_p, 203 | "typical_p": typical_p, 204 | "temperature": temperature, 205 | "repeat_penalty": repeat_penalty, 206 | "frequency_penalty": frequency_penalty, 207 | "presence_penalty": presence_penalty, 208 | "tfs_z": tfs_z, 209 | "mirostat_mode": mirostat_mode, 210 | "mirostat_eta": mirostat_eta, 211 | "mirostat_tau": mirostat_tau, 212 | } 213 | return (parameters_dict,) 214 | 215 | NODE_CLASS_MAPPINGS = { 216 | "deepseek_vl_model_loader": deepseek_vl_model_loader, 217 | "deepseek_vl_inference": deepseek_vl_inference, 218 | } 219 | 220 | NODE_DISPLAY_NAME_MAPPINGS = { 221 | "deepseek_vl_model_loader": "DeepSeek-VL Model Loader", 222 | "deepseek_vl_inference": "DeepSeek-VL Inference", 223 | } 224 | 225 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | attrdict 2 | huggingface_hub 3 | transformers>=4.38.2 4 | timm>=0.9.16 5 | sentencepiece --------------------------------------------------------------------------------
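For orientation, the snippet below is a minimal standalone sketch of how the pieces in this repository fit together outside of ComfyUI: VLChatProcessor packs a conversation plus PIL images into model inputs, MultiModalityCausalLM turns them into embeddings, and its language model generates the answer. It mirrors load_pretrained_model() in deepseek_vl/utils/io.py and the generation loop in nodes.py; the checkpoint id, the "<image_placeholder>" marker, and the image path are illustrative assumptions, and a CUDA device is assumed.

import torch
from transformers import AutoModelForCausalLM

from deepseek_vl.models import MultiModalityCausalLM, VLChatProcessor
from deepseek_vl.utils.io import load_pil_images

# Placeholder checkpoint id; the ComfyUI loader node downloads the same repos
# into models/LLM/deepseek-vl and loads them from there.
model_path = "deepseek-ai/deepseek-vl-1.3b-chat"

vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)
tokenizer = vl_chat_processor.tokenizer

vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(
    model_path, trust_remote_code=True
)
vl_gpt = vl_gpt.to(torch.bfloat16).cuda().eval()

# "<image_placeholder>" marks where image features are spliced into the prompt
# (assumed marker, following the upstream DeepSeek-VL examples); the path is illustrative.
conversation = [
    {
        "role": "User",
        "content": "<image_placeholder>Describe the image in detail.",
        "images": ["./example.png"],
    },
    {"role": "Assistant", "content": ""},
]

pil_images = load_pil_images(conversation)
prepare_inputs = vl_chat_processor(
    conversations=conversation, images=pil_images, force_batchify=True
).to(vl_gpt.device)

# Image encoder + projector produce the multimodal input embeddings.
inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)

outputs = vl_gpt.language_model.generate(
    inputs_embeds=inputs_embeds,
    attention_mask=prepare_inputs.attention_mask,
    pad_token_id=tokenizer.eos_token_id,
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
    max_new_tokens=512,
    do_sample=False,
    use_cache=True,
)
print(tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True))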