├── CXR_LLAVA_HF ├── __init__.py ├── VisualTransformer.py └── CXR_LLAVA_HF.py ├── IMG ├── demo.png └── img.jpg ├── requirements.txt ├── main.py └── README.md /CXR_LLAVA_HF/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /IMG/demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ECOFRI/CXR_LLaVA/HEAD/IMG/demo.png -------------------------------------------------------------------------------- /IMG/img.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ECOFRI/CXR_LLaVA/HEAD/IMG/img.jpg -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ECOFRI/CXR_LLaVA/HEAD/requirements.txt -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | from CXR_LLAVA.CXR_LLAVA import CXR_LLAVA_Loader 3 | from PIL import Image 4 | if __name__ == '__main__': 5 | model_path = "MODEL FOLDER PATH" 6 | img = Image.open(os.path.join(os.path.dirname(__file__), "IMG", "img.jpg")) 7 | loader = CXR_LLAVA_Loader(model_path=model_path, temperature=0.4, top_p=0.8) 8 | 9 | chat = [ 10 | {"role": "system", "content": "You are a helpful radiologist. Try to interpret chest x ray image and answer to the question that user provides."}, 11 | {"role": "user", "content": "\nWrite a radiologic report on the given chest radiograph, including information about atelectasis, cardiomegaly, consolidation, pulmonary edema, pleural effusion, and pneumothorax.\n"} 12 | ] 13 | 14 | response = loader.generate(chat,pil_image=img) 15 | print("QUESTION : %s"%chat[-1]['content']) 16 | print("RESPONSE : %s"%response) 17 | 18 | chat.append({"role":"assistant","content":response}) 19 | chat.append({"role":"user","content":"What is possible diagnosis?"}) 20 | 21 | response = loader.generate(chat,pil_image=img) 22 | 23 | print("QUESTION : %s"%chat[-1]['content']) 24 | print("RESPONSE : %s"%response) 25 | 26 | chat.append({"role": "assistant", "content": response}) 27 | chat.append({"role": "user", "content": "Should additional radiologic study needed?"}) 28 | 29 | response = loader.generate(chat,pil_image=img) 30 | 31 | print("QUESTION : %s"%chat[-1]['content']) 32 | print("RESPONSE : %s"%response) 33 | 34 | 35 | 36 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # CXR-LLaVA Model Card 3 | ### Multimodal Large Language Model Fine-Tuned for Chest X-ray Images 4 | 5 | CXR-LLaVA is an open-source, multimodal large language model specifically designed for generating radiologic reports from chest X-ray images. 6 | 7 | - **Arxiv Preprint Paper**: Explore the detailed scientific background of CXR LLaVA on [Arxiv](https://arxiv.org/abs/2310.18341). 8 | - **Demo Website**: Experience the model in action at [Radiologist App](https://radiologist.app/cxr-llava/viewer.php). 
9 | 
10 | 
11 | | Version | Input CXR resolution | Channels | Vision Encoder | Base LLM | Weight |
12 | |--|--|--|--|--|--|
13 | | v1.0 | 512x512 | RGB | RN50 | LLAMA2-13B-CHAT | Deprecated |
14 | | v2.0.1 (Latest) | 512x512 | Grayscale | ViT-L/16 | LLAMA2-7B-CHAT | Link |
15 | 
16 | You can interpret a CXR with just six lines of code.
17 | 
18 | (An NVIDIA GPU with more than 14 GB of VRAM is required.)
19 | ```python
20 | from transformers import AutoModel
21 | from PIL import Image
22 | model = AutoModel.from_pretrained("ECOFRI/CXR-LLAVA-v2", trust_remote_code=True)
23 | model = model.to("cuda")
24 | cxr_image = Image.open("img.jpg")
25 | response = model.write_radiologic_report(cxr_image)
26 | ```
27 | > The radiologic report reveals a large consolidation in the right upper lobe of the lungs. There is no evidence of pleural effusion or pneumothorax. The cardiac and mediastinal contours are normal.
28 | 
29 | 
30 | ## Usage Guide
31 | ### Install Dependencies
32 | Before you begin, make sure PyTorch is installed. Then install the additional required dependencies by running the following command in your terminal or command prompt:
33 | ```bash
34 | pip install transformers sentencepiece protobuf pillow
35 | ```
36 | 
37 | ### Importing Packages
38 | ```python
39 | from transformers import AutoModel
40 | from PIL import Image
41 | ```
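### Check GPU Availability (Optional)
CXR-LLaVA needs an NVIDIA GPU with more than 14 GB of VRAM, so it can be useful to confirm that a suitable CUDA device is visible before downloading the weights. The snippet below is an optional sanity check that relies only on standard PyTorch calls; it is not part of the CXR-LLaVA API.
```python
import torch

# Optional sanity check: confirm a CUDA device is visible and report its memory.
assert torch.cuda.is_available(), "CXR-LLaVA requires an NVIDIA GPU with >14 GB of VRAM."
props = torch.cuda.get_device_properties(0)
print(f"Found {props.name} with {props.total_memory / 1024**3:.1f} GB of VRAM")
```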
42 | ### Prepare CXR
43 | 
44 | 
45 | 
46 | Ensure you have a CXR image file ready, such as 'img.jpg'.
47 | 
48 | Use the following code to load the image:
49 | ```python
50 | cxr_image = Image.open("img.jpg")
51 | ```
52 | ### Load model
53 | Loading the CXR-LLaVA model is straightforward and takes only a couple of lines of code.
54 | 
55 | ```python
56 | model = AutoModel.from_pretrained("ECOFRI/CXR-LLAVA-v2", trust_remote_code=True)
57 | model = model.to("cuda")
58 | ```
59 | 
60 | ### Generating Radiologic Reports
61 | 
62 | To write a radiologic report for a chest radiograph:
63 | 
64 | 
65 | ```python
66 | response = model.write_radiologic_report(cxr_image)
67 | ```
68 | 
69 | > The radiologic report reveals a large consolidation in the right upper lobe of the lungs. There is no evidence of pleural effusion or pneumothorax. The cardiac and mediastinal contours are normal.
70 | 
71 | 
72 | ### Differential Diagnosis
73 | For differential diagnosis:
74 | 
75 | ```python
76 | response = model.write_differential_diagnosis(cxr_image)
77 | ```
78 | > Possible differential diagnoses for this patient include pneumonia, tuberculosis, lung abscess, or a neoplastic process such as lung cancer.
79 | 
80 | ### Question Answering
81 | To ask a question:
82 | ```python
83 | question = "What is the true meaning of consolidation?"
84 | response = model.ask_question(question=question, image=cxr_image)
85 | ```
86 | > Consolidation refers to the filling of the airspaces in the lungs with fluid, pus, blood, cells or other substances, resulting in a region of lung tissue that has become dense and solid rather than containing air.
87 | 
88 | ### Custom Prompt
89 | For custom interactions:
90 | ```python
91 | img = Image.open("img.jpg")
92 | chat = [
93 |     {"role": "system",
94 |      "content": "You are a helpful radiologist. Try to interpret chest x ray image and answer to the question that user provides."},
95 |     {"role": "user",
96 |      "content": "\nWrite a radiologic report on the given chest radiograph, including information about atelectasis, cardiomegaly, consolidation, pulmonary edema, pleural effusion, and pneumothorax.\n"}
97 | ]
98 | response = model.generate_cxr_repsonse(chat=chat, pil_image=img, temperature=0, top_p=1)
99 | ```
100 | 
101 | ## Intended Use
102 | ### Intended Use Cases
103 | CXR-LLaVA is designed for generating radiologic reports from chest X-ray images and is intended for research purposes. It can assist researchers in exploring the potential of multimodal large language models in interpreting chest X-rays. The model is suitable for assistant-like chat interactions related to chest X-ray interpretation.
104 | 
105 | ### Out-of-Scope Use
106 | * Interpreting non-CXR images or medical imaging modalities not covered in the training data, such as photographs or other types of radiological images; such inputs will result in meaningless outputs.
107 | * Clinical decision-making or direct patient care.
108 | 
109 | ## Training Data
110 | The CXR-LLaVA model was trained on multiple open CXR datasets, including the BrixIA, CheXpert, MIMIC, NIH, PadChest, RSNA COVID-19 AI Detection Challenge, and VinDR datasets.
111 | 
112 | Refer to our research article on [Arxiv](https://arxiv.org/abs/2310.18341) for more details.
113 | 
114 | ## Model Performance
115 | Refer to our research article on [Arxiv](https://arxiv.org/abs/2310.18341) for more details.
116 | 
117 | ## Model Release
118 | * Model (v2.0.1) Release Date: January 14, 2024.
119 | * Status: This is a static model trained on an offline dataset.
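Because the released checkpoint is static and the repository is loaded with `trust_remote_code=True`, pinning the exact repository revision you validated helps keep experiments reproducible. The sketch below uses the standard `revision` argument of `from_pretrained`; the value shown is only a placeholder, so substitute the commit hash or tag you actually validated.
```python
from transformers import AutoModel

# Pin the Hugging Face repository revision so results stay reproducible.
# "main" is a placeholder; replace it with the commit hash or tag you validated.
model = AutoModel.from_pretrained(
    "ECOFRI/CXR-LLAVA-v2",
    trust_remote_code=True,
    revision="main",
)
model = model.to("cuda")
```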
120 | 
121 | 
122 | ## Ethical Considerations
123 | **Research Use Only:** The CXR-LLaVA model is intended solely for research purposes. Users must ensure ethical and responsible use within a research setting. It should not be used for clinical diagnosis or treatment without thorough validation and regulatory approval.
124 | 
125 | **Informed Usage:** Users must be knowledgeable about the model's capabilities and limitations. They should interpret results within the context of their expertise and be aware of the potential implications of using the model.
126 | 
127 | **Data Privacy:** When using the model with patient data, researchers must adhere to all relevant data protection and privacy regulations. Anonymization of patient data is essential to maintain confidentiality and privacy.
128 | 
129 | ## Limitations
130 | **Domain-Specific Training:** The model was trained exclusively on chest X-ray (CXR) images. Inputting non-CXR images, such as photographs or other types of medical imaging, will result in meaningless outputs.
131 | 
132 | **Numerical Data Handling:** The model may struggle with accurately processing numerical data, including specific measurements or quantitative details often found in radiologic reports, such as the exact location or size of abnormalities.
133 | 
134 | **Image Quality:** The model processes 512x512 resolution grayscale images. Differences in image resolution or grayscale levels from those used during training could affect the model's performance. Higher resolution images or those with more grayscale levels might provide details that the model cannot accurately interpret.
135 | 
136 | **Bias and Generalizability:** The model was trained on specific datasets, which may not fully represent the diversity of clinical cases in different medical settings. This could lead to biases in the model's outputs. Users should interpret results cautiously and consider potential biases.
137 | 
138 | **Unpredictable Outputs:** As with all LLMs, the CXR-LLaVA model may produce unpredictable outputs. Safety testing and tuning tailored to specific applications are necessary before deploying any applications involving this model.
139 | 
140 | **Regulatory Approval:** The model has not undergone regulatory approval processes, such as FDA clearance. It must not be used for clinical decision-making or direct patient care without such approval.
141 | 
142 | ## Important Note
143 | CXR-LLaVA may generate incorrect interpretations of chest X-rays, omit crucial information, or provide inaccurate responses during interactions. Therefore, it should never be used for patient treatment. The model is intended solely for research purposes and should not be relied upon for clinical decision-making or direct patient care.
144 | 
145 | 
146 | 
147 | ## License Information
148 | CXR-LLaVA is available under a Creative Commons NonCommercial License.
149 | 
150 | Users must obtain the LLAMA-2 license prior to use. More details can be found [here](https://ai.meta.com/resources/models-and-libraries/llama-downloads/).
151 | 
152 | 
153 | Lastly, we extend our heartfelt thanks to all the contributors of the [LLaVA project](https://llava-vl.github.io/).
154 | 
--------------------------------------------------------------------------------
/CXR_LLAVA_HF/VisualTransformer.py:
--------------------------------------------------------------------------------
1 | '''
2 | Source code from OPEN_CLIP project.
3 | https://github.com/mlfoundations/open_clip/blob/main/LICENSE 4 | ''' 5 | 6 | from collections import OrderedDict 7 | import math 8 | from typing import Callable, Optional, Sequence, Tuple 9 | from functools import partial 10 | 11 | import torch 12 | from torch import nn 13 | from torch.nn import functional as F 14 | from torch.utils.checkpoint import checkpoint 15 | 16 | from itertools import repeat 17 | import collections.abc 18 | 19 | # From PyTorch internals 20 | def _ntuple(n): 21 | def parse(x): 22 | if isinstance(x, collections.abc.Iterable): 23 | return x 24 | return tuple(repeat(x, n)) 25 | return parse 26 | 27 | to_1tuple = _ntuple(1) 28 | to_2tuple = _ntuple(2) 29 | to_3tuple = _ntuple(3) 30 | to_4tuple = _ntuple(4) 31 | to_ntuple = lambda n, x: _ntuple(n)(x) 32 | 33 | class LayerNormFp32(nn.LayerNorm): 34 | """Subclass torch's LayerNorm to handle fp16 (by casting to float32 and back).""" 35 | 36 | def forward(self, x: torch.Tensor): 37 | orig_type = x.dtype 38 | x = F.layer_norm(x.to(torch.float32), self.normalized_shape, self.weight, self.bias, self.eps) 39 | 40 | #x = F.layer_norm(x.to(torch.bfloat16), self.normalized_shape, self.weight, self.bias, self.eps) 41 | return x.to(orig_type) 42 | 43 | 44 | class LayerNorm(nn.LayerNorm): 45 | """Subclass torch's LayerNorm (with cast back to input dtype).""" 46 | 47 | def forward(self, x: torch.Tensor): 48 | orig_type = x.dtype 49 | x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) 50 | return x.to(orig_type) 51 | 52 | 53 | class QuickGELU(nn.Module): 54 | # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory 55 | def forward(self, x: torch.Tensor): 56 | return x * torch.sigmoid(1.702 * x) 57 | 58 | 59 | class LayerScale(nn.Module): 60 | def __init__(self, dim, init_values=1e-5, inplace=False): 61 | super().__init__() 62 | self.inplace = inplace 63 | self.gamma = nn.Parameter(init_values * torch.ones(dim)) 64 | 65 | def forward(self, x): 66 | return x.mul_(self.gamma) if self.inplace else x * self.gamma 67 | 68 | 69 | class PatchDropout(nn.Module): 70 | """ 71 | https://arxiv.org/abs/2212.00794 72 | """ 73 | 74 | def __init__(self, prob, exclude_first_token=True): 75 | super().__init__() 76 | assert 0 <= prob < 1. 77 | self.prob = prob 78 | self.exclude_first_token = exclude_first_token # exclude CLS token 79 | 80 | def forward(self, x): 81 | if not self.training or self.prob == 0.: 82 | return x 83 | 84 | if self.exclude_first_token: 85 | cls_tokens, x = x[:, :1], x[:, 1:] 86 | else: 87 | cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1]) 88 | 89 | batch = x.size()[0] 90 | num_tokens = x.size()[1] 91 | 92 | batch_indices = torch.arange(batch) 93 | batch_indices = batch_indices[..., None] 94 | 95 | keep_prob = 1 - self.prob 96 | num_patches_keep = max(1, int(num_tokens * keep_prob)) 97 | 98 | rand = torch.randn(batch, num_tokens) 99 | patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices 100 | 101 | x = x[batch_indices, patch_indices_keep] 102 | 103 | if self.exclude_first_token: 104 | x = torch.cat((cls_tokens, x), dim=1) 105 | 106 | return x 107 | 108 | 109 | class Attention(nn.Module): 110 | def __init__( 111 | self, 112 | dim, 113 | num_heads=8, 114 | qkv_bias=True, 115 | scaled_cosine=False, 116 | scale_heads=False, 117 | logit_scale_max=math.log(1. / 0.01), 118 | attn_drop=0., 119 | proj_drop=0. 
120 | ): 121 | super().__init__() 122 | self.scaled_cosine = scaled_cosine 123 | self.scale_heads = scale_heads 124 | assert dim % num_heads == 0, 'dim should be divisible by num_heads' 125 | self.num_heads = num_heads 126 | self.head_dim = dim // num_heads 127 | self.scale = self.head_dim ** -0.5 128 | self.logit_scale_max = logit_scale_max 129 | 130 | # keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original 131 | self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale) 132 | if qkv_bias: 133 | self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3)) 134 | else: 135 | self.in_proj_bias = None 136 | 137 | if self.scaled_cosine: 138 | self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1)))) 139 | else: 140 | self.logit_scale = None 141 | self.attn_drop = nn.Dropout(attn_drop) 142 | if self.scale_heads: 143 | self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1))) 144 | else: 145 | self.head_scale = None 146 | self.out_proj = nn.Linear(dim, dim) 147 | self.out_drop = nn.Dropout(proj_drop) 148 | 149 | def forward(self, x, attn_mask: Optional[torch.Tensor] = None): 150 | L, N, C = x.shape 151 | q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1) 152 | q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) 153 | k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) 154 | v = v.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) 155 | 156 | if self.logit_scale is not None: 157 | attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2)) 158 | logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp() 159 | attn = attn.view(N, self.num_heads, L, L) * logit_scale 160 | attn = attn.view(-1, L, L) 161 | else: 162 | q = q * self.scale 163 | attn = torch.bmm(q, k.transpose(-1, -2)) 164 | 165 | if attn_mask is not None: 166 | if attn_mask.dtype == torch.bool: 167 | new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype) 168 | new_attn_mask.masked_fill_(attn_mask, float("-inf")) 169 | attn_mask = new_attn_mask 170 | attn += attn_mask 171 | 172 | attn = attn.softmax(dim=-1) 173 | attn = self.attn_drop(attn) 174 | 175 | x = torch.bmm(attn, v) 176 | if self.head_scale is not None: 177 | x = x.view(N, self.num_heads, L, C) * self.head_scale 178 | x = x.view(-1, L, C) 179 | x = x.transpose(0, 1).reshape(L, N, C) 180 | x = self.out_proj(x) 181 | x = self.out_drop(x) 182 | return x 183 | 184 | 185 | class AttentionalPooler(nn.Module): 186 | def __init__( 187 | self, 188 | d_model: int, 189 | context_dim: int, 190 | n_head: int = 8, 191 | n_queries: int = 256, 192 | norm_layer: Callable = LayerNorm 193 | ): 194 | super().__init__() 195 | self.query = nn.Parameter(torch.randn(n_queries, d_model)) 196 | self.attn = nn.MultiheadAttention(d_model, n_head, kdim=context_dim, vdim=context_dim) 197 | self.ln_q = norm_layer(d_model) 198 | self.ln_k = norm_layer(context_dim) 199 | 200 | def forward(self, x: torch.Tensor): 201 | x = self.ln_k(x).permute(1, 0, 2) # NLD -> LND 202 | N = x.shape[1] 203 | q = self.ln_q(self.query) 204 | out = self.attn(q.unsqueeze(1).expand(-1, N, -1), x, x, need_weights=False)[0] 205 | return out.permute(1, 0, 2) # LND -> NLD 206 | 207 | 208 | class ResidualAttentionBlock(nn.Module): 209 | def __init__( 210 | self, 211 | d_model: int, 212 | n_head: int, 213 | mlp_ratio: float = 4.0, 214 | ls_init_value: float = None, 215 | act_layer: Callable = nn.GELU, 216 | norm_layer: Callable = LayerNorm, 217 | 
is_cross_attention: bool = False, 218 | ): 219 | super().__init__() 220 | 221 | self.ln_1 = norm_layer(d_model) 222 | self.attn = nn.MultiheadAttention(d_model, n_head) 223 | self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() 224 | if is_cross_attention: 225 | self.ln_1_kv = norm_layer(d_model) 226 | 227 | self.ln_2 = norm_layer(d_model) 228 | mlp_width = int(d_model * mlp_ratio) 229 | self.mlp = nn.Sequential(OrderedDict([ 230 | ("c_fc", nn.Linear(d_model, mlp_width)), 231 | ("gelu", act_layer()), 232 | ("c_proj", nn.Linear(mlp_width, d_model)) 233 | ])) 234 | self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() 235 | 236 | def attention( 237 | self, 238 | q_x: torch.Tensor, 239 | k_x: Optional[torch.Tensor] = None, 240 | v_x: Optional[torch.Tensor] = None, 241 | attn_mask: Optional[torch.Tensor] = None, 242 | ): 243 | k_x = k_x if k_x is not None else q_x 244 | v_x = v_x if v_x is not None else q_x 245 | 246 | attn_mask = attn_mask.to(q_x.dtype) if attn_mask is not None else None 247 | return self.attn( 248 | q_x, k_x, v_x, need_weights=False, attn_mask=attn_mask 249 | )[0] 250 | 251 | def forward( 252 | self, 253 | q_x: torch.Tensor, 254 | k_x: Optional[torch.Tensor] = None, 255 | v_x: Optional[torch.Tensor] = None, 256 | attn_mask: Optional[torch.Tensor] = None, 257 | ): 258 | k_x = self.ln_1_kv(k_x) if hasattr(self, "ln_1_kv") and k_x is not None else None 259 | v_x = self.ln_1_kv(v_x) if hasattr(self, "ln_1_kv") and v_x is not None else None 260 | 261 | x = q_x + self.ls_1(self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask)) 262 | x = x + self.ls_2(self.mlp(self.ln_2(x))) 263 | return x 264 | 265 | 266 | class CustomResidualAttentionBlock(nn.Module): 267 | def __init__( 268 | self, 269 | d_model: int, 270 | n_head: int, 271 | mlp_ratio: float = 4.0, 272 | ls_init_value: float = None, 273 | act_layer: Callable = nn.GELU, 274 | norm_layer: Callable = LayerNorm, 275 | scale_cosine_attn: bool = False, 276 | scale_heads: bool = False, 277 | scale_attn: bool = False, 278 | scale_fc: bool = False, 279 | ): 280 | super().__init__() 281 | 282 | self.ln_1 = norm_layer(d_model) 283 | self.attn = Attention( 284 | d_model, n_head, 285 | scaled_cosine=scale_cosine_attn, 286 | scale_heads=scale_heads, 287 | ) 288 | self.ln_attn = norm_layer(d_model) if scale_attn else nn.Identity() 289 | self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() 290 | 291 | self.ln_2 = norm_layer(d_model) 292 | mlp_width = int(d_model * mlp_ratio) 293 | self.mlp = nn.Sequential(OrderedDict([ 294 | ("c_fc", nn.Linear(d_model, mlp_width)), 295 | ("gelu", act_layer()), 296 | ('ln', norm_layer(mlp_width) if scale_fc else nn.Identity()), 297 | ("c_proj", nn.Linear(mlp_width, d_model)) 298 | ])) 299 | self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() 300 | 301 | def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): 302 | x = x + self.ls_1(self.ln_attn(self.attn(self.ln_1(x), attn_mask=attn_mask))) 303 | x = x + self.ls_2(self.mlp(self.ln_2(x))) 304 | return x 305 | 306 | 307 | def _expand_token(token, batch_size: int): 308 | return token.view(1, 1, -1).expand(batch_size, -1, -1) 309 | 310 | 311 | class Transformer(nn.Module): 312 | def __init__( 313 | self, 314 | width: int, 315 | layers: int, 316 | heads: int, 317 | mlp_ratio: float = 4.0, 318 | ls_init_value: float = None, 319 | act_layer: Callable = nn.GELU, 
320 | norm_layer: Callable = LayerNorm, 321 | ): 322 | super().__init__() 323 | self.width = width 324 | self.layers = layers 325 | self.grad_checkpointing = False 326 | 327 | self.resblocks = nn.ModuleList([ 328 | ResidualAttentionBlock( 329 | width, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer) 330 | for _ in range(layers) 331 | ]) 332 | 333 | def get_cast_dtype(self) -> torch.dtype: 334 | if hasattr(self.resblocks[0].mlp.c_fc, 'int8_original_dtype'): 335 | return self.resblocks[0].mlp.c_fc.int8_original_dtype 336 | return self.resblocks[0].mlp.c_fc.weight.dtype 337 | 338 | def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): 339 | for r in self.resblocks: 340 | if self.grad_checkpointing and not torch.jit.is_scripting(): 341 | # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372 342 | x = checkpoint(r, x, None, None, attn_mask) 343 | else: 344 | x = r(x, attn_mask=attn_mask) 345 | return x 346 | 347 | 348 | class VisionTransformer(nn.Module): 349 | output_tokens: torch.jit.Final[bool] 350 | 351 | def __init__( 352 | self, 353 | in_channels:int, 354 | image_size: int, 355 | patch_size: int, 356 | width: int, 357 | layers: int, 358 | heads: int, 359 | mlp_ratio: float, 360 | ls_init_value: float = None, 361 | attentional_pool: bool = False, 362 | attn_pooler_queries: int = 256, 363 | attn_pooler_heads: int = 8, 364 | output_dim: int = 512, 365 | patch_dropout: float = 0., 366 | no_ln_pre: bool = False, 367 | pos_embed_type: str = 'learnable', 368 | pool_type: str = 'tok', 369 | final_ln_after_pool: bool = False, 370 | act_layer: Callable = nn.GELU, 371 | norm_layer: Callable = LayerNorm, 372 | output_tokens: bool = False, 373 | ): 374 | super().__init__() 375 | assert pool_type in ('tok', 'avg', 'none') 376 | self.output_tokens = output_tokens 377 | image_height, image_width = self.image_size = to_2tuple(image_size) 378 | patch_height, patch_width = self.patch_size = to_2tuple(patch_size) 379 | self.grid_size = (image_height // patch_height, image_width // patch_width) 380 | self.final_ln_after_pool = final_ln_after_pool # currently ignored w/ attn pool enabled 381 | self.output_dim = output_dim 382 | 383 | self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) 384 | 385 | # class embeddings and positional embeddings 386 | scale = width ** -0.5 387 | self.class_embedding = nn.Parameter(scale * torch.randn(width)) 388 | if pos_embed_type == 'learnable': 389 | self.positional_embedding = nn.Parameter( 390 | scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width)) 391 | elif pos_embed_type == 'sin_cos_2d': 392 | # fixed sin-cos embedding 393 | assert self.grid_size[0] == self.grid_size[1], \ 394 | 'currently sin cos 2d pos embedding only supports square input' 395 | self.positional_embedding = nn.Parameter( 396 | torch.zeros(self.grid_size[0] * self.grid_size[1] + 1, width), requires_grad=False) 397 | pos_embed_type = get_2d_sincos_pos_embed(width, self.grid_size[0], cls_token=True) 398 | self.positional_embedding.data.copy_(torch.from_numpy(pos_embed_type).float()) 399 | else: 400 | raise ValueError 401 | 402 | # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn 403 | self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. 
else nn.Identity() 404 | 405 | self.ln_pre = nn.Identity() if no_ln_pre else norm_layer(width) 406 | self.transformer = Transformer( 407 | width, 408 | layers, 409 | heads, 410 | mlp_ratio, 411 | ls_init_value=ls_init_value, 412 | act_layer=act_layer, 413 | norm_layer=norm_layer, 414 | ) 415 | 416 | if attentional_pool: 417 | if isinstance(attentional_pool, str): 418 | self.attn_pool_type = attentional_pool 419 | self.pool_type = 'none' 420 | if attentional_pool in ('parallel', 'cascade'): 421 | self.attn_pool = AttentionalPooler( 422 | output_dim, 423 | width, 424 | n_head=attn_pooler_heads, 425 | n_queries=attn_pooler_queries, 426 | ) 427 | self.attn_pool_contrastive = AttentionalPooler( 428 | output_dim, 429 | width, 430 | n_head=attn_pooler_heads, 431 | n_queries=1, 432 | ) 433 | else: 434 | assert False 435 | else: 436 | self.attn_pool_type = '' 437 | self.pool_type = pool_type 438 | self.attn_pool = AttentionalPooler( 439 | output_dim, 440 | width, 441 | n_head=attn_pooler_heads, 442 | n_queries=attn_pooler_queries, 443 | ) 444 | self.attn_pool_contrastive = None 445 | pool_dim = output_dim 446 | else: 447 | self.attn_pool = None 448 | pool_dim = width 449 | self.pool_type = pool_type 450 | 451 | self.ln_post = norm_layer(pool_dim) 452 | self.proj = nn.Parameter(scale * torch.randn(pool_dim, output_dim)) 453 | 454 | self.init_parameters() 455 | 456 | def lock(self, unlocked_groups=0, freeze_bn_stats=False): 457 | for param in self.parameters(): 458 | param.requires_grad = False 459 | 460 | if unlocked_groups != 0: 461 | groups = [ 462 | [ 463 | self.conv1, 464 | self.class_embedding, 465 | self.positional_embedding, 466 | self.ln_pre, 467 | ], 468 | *self.transformer.resblocks[:-1], 469 | [ 470 | self.transformer.resblocks[-1], 471 | self.ln_post, 472 | ], 473 | self.proj, 474 | ] 475 | 476 | def _unlock(x): 477 | if isinstance(x, Sequence): 478 | for g in x: 479 | _unlock(g) 480 | else: 481 | if isinstance(x, torch.nn.Parameter): 482 | x.requires_grad = True 483 | else: 484 | for p in x.parameters(): 485 | p.requires_grad = True 486 | 487 | _unlock(groups[-unlocked_groups:]) 488 | 489 | def init_parameters(self): 490 | # FIXME OpenAI CLIP did not define an init for the VisualTransformer 491 | # TODO experiment if default PyTorch init, below, or alternate init is best. 
492 | 493 | # nn.init.normal_(self.class_embedding, std=self.scale) 494 | # nn.init.normal_(self.positional_embedding, std=self.scale) 495 | # 496 | # proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) 497 | # attn_std = self.transformer.width ** -0.5 498 | # fc_std = (2 * self.transformer.width) ** -0.5 499 | # for block in self.transformer.resblocks: 500 | # nn.init.normal_(block.attn.in_proj_weight, std=attn_std) 501 | # nn.init.normal_(block.attn.out_proj.weight, std=proj_std) 502 | # nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) 503 | # nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) 504 | # 505 | # if self.text_projection is not None: 506 | # nn.init.normal_(self.text_projection, std=self.scale) 507 | pass 508 | 509 | @torch.jit.ignore 510 | def set_grad_checkpointing(self, enable=True): 511 | self.transformer.grad_checkpointing = enable 512 | 513 | def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 514 | if self.pool_type == 'avg': 515 | pooled, tokens = x[:, 1:].mean(dim=1), x[:, 1:] 516 | elif self.pool_type == 'tok': 517 | pooled, tokens = x[:, 0], x[:, 1:] 518 | else: 519 | pooled = tokens = x 520 | 521 | return pooled, tokens 522 | 523 | def forward(self, x: torch.Tensor): 524 | x = self.conv1(x) # shape = [*, width, grid, grid] 525 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 526 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 527 | 528 | # class embeddings and positional embeddings 529 | x = torch.cat([_expand_token(self.class_embedding, x.shape[0]).to(x.dtype), x], dim=1) 530 | # shape = [*, grid ** 2 + 1, width] 531 | x = x + self.positional_embedding.to(x.dtype) 532 | 533 | x = self.patch_dropout(x) 534 | x = self.ln_pre(x) 535 | 536 | x = x.permute(1, 0, 2) # NLD -> LND 537 | x = self.transformer(x) 538 | x = x.permute(1, 0, 2) # LND -> NLD 539 | 540 | if self.attn_pool is not None: 541 | if self.attn_pool_contrastive is not None: 542 | # This is untested, WIP pooling that should match paper 543 | x = self.ln_post(x) # TBD LN first or separate one after each pool? 
544 | tokens = self.attn_pool(x) 545 | if self.attn_pool_type == 'parallel': 546 | pooled = self.attn_pool_contrastive(x) 547 | else: 548 | assert self.attn_pool_type == 'cascade' 549 | pooled = self.attn_pool_contrastive(tokens) 550 | else: 551 | # this is the original OpenCLIP CoCa setup, does not match paper 552 | x = self.attn_pool(x) 553 | x = self.ln_post(x) 554 | pooled, tokens = self._global_pool(x) 555 | elif self.final_ln_after_pool: 556 | pooled, tokens = self._global_pool(x) 557 | pooled = self.ln_post(pooled) 558 | else: 559 | x = self.ln_post(x) 560 | pooled, tokens = self._global_pool(x) 561 | 562 | if self.proj is not None: 563 | pooled = pooled @ self.proj 564 | 565 | if self.output_tokens: 566 | return pooled, tokens 567 | 568 | return pooled 569 | 570 | 571 | def text_global_pool(x, text: Optional[torch.Tensor] = None, pool_type: str = 'argmax'): 572 | if pool_type == 'first': 573 | pooled, tokens = x[:, 0], x[:, 1:] 574 | elif pool_type == 'last': 575 | pooled, tokens = x[:, -1], x[:, :-1] 576 | elif pool_type == 'argmax': 577 | # take features from the eot embedding (eot_token is the highest number in each sequence) 578 | assert text is not None 579 | pooled, tokens = x[torch.arange(x.shape[0]), text.argmax(dim=-1)], x 580 | else: 581 | pooled = tokens = x 582 | 583 | return pooled, tokens 584 | 585 | 586 | class TextTransformer(nn.Module): 587 | output_tokens: torch.jit.Final[bool] 588 | 589 | def __init__( 590 | self, 591 | context_length: int = 77, 592 | vocab_size: int = 49408, 593 | width: int = 512, 594 | heads: int = 8, 595 | layers: int = 12, 596 | mlp_ratio: float = 4.0, 597 | ls_init_value: float = None, 598 | output_dim: int = 512, 599 | embed_cls: bool = False, 600 | no_causal_mask: bool = False, 601 | pad_id: int = 0, 602 | pool_type: str = 'argmax', 603 | proj_bias: bool = False, 604 | act_layer: Callable = nn.GELU, 605 | norm_layer: Callable = LayerNorm, 606 | output_tokens: bool = False, 607 | ): 608 | super().__init__() 609 | assert pool_type in ('first', 'last', 'argmax', 'none') 610 | self.output_tokens = output_tokens 611 | self.num_pos = self.context_length = context_length 612 | self.vocab_size = vocab_size 613 | self.width = width 614 | self.output_dim = output_dim 615 | self.heads = heads 616 | self.pad_id = pad_id 617 | self.pool_type = pool_type 618 | 619 | self.token_embedding = nn.Embedding(vocab_size, width) 620 | if embed_cls: 621 | self.cls_emb = nn.Parameter(torch.empty(width)) 622 | self.num_pos += 1 623 | else: 624 | self.cls_emb = None 625 | self.positional_embedding = nn.Parameter(torch.empty(self.num_pos, width)) 626 | self.transformer = Transformer( 627 | width=width, 628 | layers=layers, 629 | heads=heads, 630 | mlp_ratio=mlp_ratio, 631 | ls_init_value=ls_init_value, 632 | act_layer=act_layer, 633 | norm_layer=norm_layer, 634 | ) 635 | self.ln_final = norm_layer(width) 636 | 637 | if no_causal_mask: 638 | self.attn_mask = None 639 | else: 640 | self.register_buffer('attn_mask', self.build_causal_mask(), persistent=False) 641 | 642 | if proj_bias: 643 | self.text_projection = nn.Linear(width, output_dim) 644 | else: 645 | self.text_projection = nn.Parameter(torch.empty(width, output_dim)) 646 | 647 | self.init_parameters() 648 | 649 | def init_parameters(self): 650 | nn.init.normal_(self.token_embedding.weight, std=0.02) 651 | nn.init.normal_(self.positional_embedding, std=0.01) 652 | if self.cls_emb is not None: 653 | nn.init.normal_(self.cls_emb, std=0.01) 654 | 655 | proj_std = (self.transformer.width ** -0.5) * ((2 * 
self.transformer.layers) ** -0.5) 656 | attn_std = self.transformer.width ** -0.5 657 | fc_std = (2 * self.transformer.width) ** -0.5 658 | for block in self.transformer.resblocks: 659 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std) 660 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std) 661 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) 662 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) 663 | 664 | if self.text_projection is not None: 665 | if isinstance(self.text_projection, nn.Linear): 666 | nn.init.normal_(self.text_projection.weight, std=self.transformer.width ** -0.5) 667 | if self.text_projection.bias is not None: 668 | nn.init.zeros_(self.text_projection.bias) 669 | else: 670 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) 671 | 672 | @torch.jit.ignore 673 | def set_grad_checkpointing(self, enable=True): 674 | self.transformer.grad_checkpointing = enable 675 | 676 | def build_causal_mask(self): 677 | # lazily create causal attention mask, with full attention between the tokens 678 | # pytorch uses additive attention mask; fill with -inf 679 | mask = torch.empty(self.num_pos, self.num_pos) 680 | mask.fill_(float("-inf")) 681 | mask.triu_(1) # zero out the lower diagonal 682 | return mask 683 | 684 | def build_cls_mask(self, text, cast_dtype: torch.dtype): 685 | cls_mask = (text != self.pad_id).unsqueeze(1) 686 | cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=True) 687 | additive_mask = torch.empty(cls_mask.shape, dtype=cast_dtype, device=cls_mask.device) 688 | additive_mask.fill_(0) 689 | additive_mask.masked_fill_(~cls_mask, float("-inf")) 690 | additive_mask = torch.repeat_interleave(additive_mask, self.heads, 0) 691 | return additive_mask 692 | 693 | def forward(self, text): 694 | cast_dtype = self.transformer.get_cast_dtype() 695 | seq_len = text.shape[1] 696 | 697 | x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] 698 | attn_mask = self.attn_mask 699 | if self.cls_emb is not None: 700 | seq_len += 1 701 | x = torch.cat([x, _expand_token(self.cls_emb, x.shape[0])], dim=1) 702 | cls_mask = self.build_cls_mask(text, cast_dtype) 703 | if attn_mask is not None: 704 | attn_mask = attn_mask[None, :seq_len, :seq_len] + cls_mask[:, :seq_len, :seq_len] 705 | 706 | x = x + self.positional_embedding[:seq_len].to(cast_dtype) 707 | x = x.permute(1, 0, 2) # NLD -> LND 708 | x = self.transformer(x, attn_mask=attn_mask) 709 | x = x.permute(1, 0, 2) # LND -> NLD 710 | 711 | # x.shape = [batch_size, n_ctx, transformer.width] 712 | if self.cls_emb is not None: 713 | # presence of appended cls embed (CoCa) overrides pool_type, always take last token 714 | pooled, tokens = text_global_pool(x, pool_type='last') 715 | pooled = self.ln_final(pooled) # final LN applied after pooling in this case 716 | else: 717 | x = self.ln_final(x) 718 | pooled, tokens = text_global_pool(x, text, pool_type=self.pool_type) 719 | 720 | if self.text_projection is not None: 721 | if isinstance(self.text_projection, nn.Linear): 722 | pooled = self.text_projection(pooled) 723 | else: 724 | pooled = pooled @ self.text_projection 725 | 726 | if self.output_tokens: 727 | return pooled, tokens 728 | 729 | return pooled 730 | 731 | 732 | class MultimodalTransformer(Transformer): 733 | def __init__( 734 | self, 735 | width: int, 736 | layers: int, 737 | heads: int, 738 | context_length: int = 77, 739 | mlp_ratio: float = 4.0, 740 | ls_init_value: float = None, 741 | act_layer: Callable = nn.GELU, 742 | norm_layer: Callable = 
LayerNorm, 743 | output_dim: int = 512, 744 | ): 745 | 746 | super().__init__( 747 | width=width, 748 | layers=layers, 749 | heads=heads, 750 | mlp_ratio=mlp_ratio, 751 | ls_init_value=ls_init_value, 752 | act_layer=act_layer, 753 | norm_layer=norm_layer, 754 | ) 755 | self.context_length = context_length 756 | self.cross_attn = nn.ModuleList([ 757 | ResidualAttentionBlock( 758 | width, 759 | heads, 760 | mlp_ratio, 761 | ls_init_value=ls_init_value, 762 | act_layer=act_layer, 763 | norm_layer=norm_layer, 764 | is_cross_attention=True, 765 | ) 766 | for _ in range(layers) 767 | ]) 768 | 769 | self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False) 770 | 771 | self.ln_final = norm_layer(width) 772 | self.text_projection = nn.Parameter(torch.empty(width, output_dim)) 773 | 774 | def init_parameters(self): 775 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) 776 | attn_std = self.transformer.width ** -0.5 777 | fc_std = (2 * self.transformer.width) ** -0.5 778 | for block in self.transformer.resblocks: 779 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std) 780 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std) 781 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) 782 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) 783 | for block in self.transformer.cross_attn: 784 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std) 785 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std) 786 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) 787 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) 788 | 789 | if self.text_projection is not None: 790 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) 791 | 792 | def build_attention_mask(self): 793 | # lazily create causal attention mask, with full attention between the tokens 794 | # pytorch uses additive attention mask; fill with -inf 795 | mask = torch.empty(self.context_length, self.context_length) 796 | mask.fill_(float("-inf")) 797 | mask.triu_(1) # zero out the lower diagonal 798 | return mask 799 | 800 | def forward(self, image_embs, text_embs): 801 | text_embs = text_embs.permute(1, 0, 2) # NLD -> LNDsq 802 | image_embs = image_embs.permute(1, 0, 2) # NLD -> LND 803 | seq_len = text_embs.shape[0] 804 | 805 | for resblock, cross_attn in zip(self.resblocks, self.cross_attn): 806 | if self.grad_checkpointing and not torch.jit.is_scripting(): 807 | # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372 808 | text_embs = checkpoint(resblock, text_embs, None, None, self.attn_mask[:seq_len, :seq_len]) 809 | text_embs = checkpoint(cross_attn, text_embs, image_embs, image_embs, None) 810 | else: 811 | text_embs = resblock(text_embs, attn_mask=self.attn_mask[:seq_len, :seq_len]) 812 | text_embs = cross_attn(text_embs, k_x=image_embs, v_x=image_embs) 813 | 814 | x = text_embs.permute(1, 0, 2) # LND -> NLD 815 | x = self.ln_final(x) 816 | 817 | if self.text_projection is not None: 818 | x = x @ self.text_projection 819 | 820 | return x 821 | 822 | @torch.jit.ignore 823 | def set_grad_checkpointing(self, enable=True): 824 | self.grad_checkpointing = enable 825 | 826 | 827 | # Copyright (c) Meta Platforms, Inc. and affiliates. 828 | # All rights reserved. 829 | 830 | # This source code is licensed under the license found in the 831 | # LICENSE file in the root directory of this source tree. 
832 | # -------------------------------------------------------- 833 | # Position embedding utils 834 | # -------------------------------------------------------- 835 | 836 | import numpy as np 837 | 838 | import torch 839 | 840 | # -------------------------------------------------------- 841 | # 2D sine-cosine position embedding 842 | # References: 843 | # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py 844 | # MoCo v3: https://github.com/facebookresearch/moco-v3 845 | # -------------------------------------------------------- 846 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): 847 | """ 848 | grid_size: int of the grid height and width 849 | return: 850 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 851 | """ 852 | grid_h = np.arange(grid_size, dtype=np.float32) 853 | grid_w = np.arange(grid_size, dtype=np.float32) 854 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 855 | grid = np.stack(grid, axis=0) 856 | 857 | grid = grid.reshape([2, 1, grid_size, grid_size]) 858 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 859 | if cls_token: 860 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) 861 | return pos_embed 862 | 863 | 864 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 865 | assert embed_dim % 2 == 0 866 | 867 | # use half of dimensions to encode grid_h 868 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 869 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 870 | 871 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 872 | return emb 873 | 874 | 875 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 876 | """ 877 | embed_dim: output dimension for each position 878 | pos: a list of positions to be encoded: size (M,) 879 | out: (M, D) 880 | """ 881 | assert embed_dim % 2 == 0 882 | omega = np.arange(embed_dim // 2, dtype=float) 883 | omega /= embed_dim / 2. 884 | omega = 1. 
/ 10000**omega # (D/2,) 885 | 886 | pos = pos.reshape(-1) # (M,) 887 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 888 | 889 | emb_sin = np.sin(out) # (M, D/2) 890 | emb_cos = np.cos(out) # (M, D/2) 891 | 892 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 893 | return emb 894 | 895 | 896 | # -------------------------------------------------------- 897 | # Interpolate position embeddings for high-resolution 898 | # References: 899 | # DeiT: https://github.com/facebookresearch/deit 900 | # -------------------------------------------------------- 901 | def interpolate_pos_embed(model, checkpoint_model): 902 | if 'pos_embed' in checkpoint_model: 903 | pos_embed_checkpoint = checkpoint_model['pos_embed'] 904 | embedding_size = pos_embed_checkpoint.shape[-1] 905 | num_patches = model.patch_embed.num_patches 906 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches 907 | # height (== width) for the checkpoint position embedding 908 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 909 | # height (== width) for the new position embedding 910 | new_size = int(num_patches ** 0.5) 911 | # class_token and dist_token are kept unchanged 912 | if orig_size != new_size: 913 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) 914 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 915 | # only the position tokens are interpolated 916 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 917 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 918 | pos_tokens = torch.nn.functional.interpolate( 919 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 920 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 921 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 922 | checkpoint_model['pos_embed'] = new_pos_embed -------------------------------------------------------------------------------- /CXR_LLAVA_HF/CXR_LLAVA_HF.py: -------------------------------------------------------------------------------- 1 | from transformers import PretrainedConfig, PreTrainedModel 2 | import torch, transformers 3 | from typing import List, Optional, Tuple, Union 4 | from transformers.modeling_outputs import CausalLMOutputWithPast 5 | from .VisualTransformer import VisionTransformer, LayerNorm 6 | from functools import partial 7 | from transformers import TextIteratorStreamer 8 | from transformers import StoppingCriteria, GenerationConfig 9 | from threading import Thread 10 | from dataclasses import dataclass 11 | import numpy as np 12 | from PIL import Image 13 | # Model Constants 14 | IGNORE_INDEX = -100 15 | IMAGE_TOKEN_INDEX = -200 16 | DEFAULT_IMAGE_TOKEN = "" 17 | DEFAULT_IMAGE_PATCH_TOKEN = "" 18 | DEFAULT_IM_START_TOKEN = "" 19 | DEFAULT_IM_END_TOKEN = "" 20 | class AttrDict(dict): 21 | def __init__(self, *args, **kwargs): 22 | super(AttrDict, self).__init__(*args, **kwargs) 23 | self.__dict__ = self 24 | def __getattr__(self, key): 25 | if key in self: 26 | return self[key] 27 | raise AttributeError(f"'AttrDict' object has no attribute '{key}'") 28 | 29 | 30 | class CXRLLAVAConfig(PretrainedConfig): 31 | model_type = "CXR-LLAVA" 32 | 33 | def __init__(self, **kwargs,): 34 | 35 | if 'llama' in kwargs: 36 | self.llama = AttrDict(kwargs['llama']) 37 | del kwargs['llama'] 38 | 39 | self.__dict__.update(kwargs) 40 | super().__init__(**kwargs) 41 | 42 | 43 | class CXRLLAVAModel(PreTrainedModel): 44 | config_class = 
CXRLLAVAConfig 45 | 46 | def __init__(self, config): 47 | super().__init__(config) 48 | 49 | self.tokenizer = transformers.LlamaTokenizer.from_pretrained(config._name_or_path, add_special_tokens=False) 50 | self.tokenizer.pad_token = self.tokenizer.unk_token 51 | self.tokenizer.sep_token = self.tokenizer.unk_token 52 | self.tokenizer.cls_token = self.tokenizer.unk_token 53 | self.tokenizer.mask_token = self.tokenizer.unk_token 54 | 55 | vision_cfg = CLIPVisionCfg(**config.clip_vision_cfg) 56 | 57 | self.generation_config = GenerationConfig.from_pretrained(config._name_or_path) 58 | 59 | vision_heads = vision_cfg.width // vision_cfg.head_width 60 | norm_layer = LayerNorm 61 | act_layer = torch.nn.GELU 62 | if vision_cfg.norm_kwargs: 63 | norm_layer = partial(norm_layer, **vision_cfg.norm_kwargs) 64 | if vision_cfg.act_kwargs is not None: 65 | act_layer = partial(act_layer, **vision_cfg.act_kwargs) 66 | 67 | self.vision_tower = VisionTransformer( 68 | in_channels=1, 69 | image_size=vision_cfg.image_size, 70 | patch_size=vision_cfg.patch_size, 71 | width=vision_cfg.width, 72 | layers=vision_cfg.layers, 73 | heads=vision_heads, 74 | mlp_ratio=vision_cfg.mlp_ratio, 75 | ls_init_value=vision_cfg.ls_init_value, 76 | patch_dropout=vision_cfg.patch_dropout, 77 | attentional_pool=vision_cfg.attentional_pool, 78 | attn_pooler_queries=vision_cfg.attn_pooler_queries, 79 | attn_pooler_heads=vision_cfg.attn_pooler_heads, 80 | pos_embed_type=vision_cfg.pos_embed_type, 81 | no_ln_pre=vision_cfg.no_ln_pre, 82 | final_ln_after_pool=vision_cfg.final_ln_after_pool, 83 | pool_type=vision_cfg.pool_type, 84 | output_tokens=vision_cfg.output_tokens, 85 | output_dim=config.clip_embed_dim, 86 | act_layer=act_layer, 87 | norm_layer=norm_layer, 88 | ) 89 | 90 | self.vision_tower.image_processor = transformers.CLIPImageProcessor( 91 | do_resize=True, 92 | size={'shortest_edge': config.clip_vision_cfg['image_size']}, 93 | resample=True, 94 | do_center_crop=True, 95 | crop_size=config.clip_vision_cfg['image_size'], 96 | do_rescale=True, 97 | rescale_factor=1 / 255, 98 | do_normalize=True, 99 | image_mean=config.image_preprocess_cfg['mean'], 100 | image_std=config.image_preprocess_cfg['std'], 101 | do_convert_rgb=False 102 | ) 103 | 104 | def convert_dtype(dtype): 105 | if dtype == 'fp32': 106 | dtype = torch.float32 107 | elif dtype == 'fp16': 108 | dtype = torch.float16 109 | elif dtype == 'bf16': 110 | dtype = torch.bfloat16 111 | else: 112 | raise Exception("Unsupported dtype") 113 | return dtype 114 | 115 | self.clip_cast_dtype = convert_dtype(config.clip_vision_tower_dtype) 116 | self.mm_projector = torch.nn.Linear(config.mm_projector_dim, config.llama['hidden_size']) 117 | self.lm_head = torch.nn.Linear(config.llama.hidden_size, config.llama.vocab_size, bias=False) 118 | self.llama = transformers.LlamaModel(transformers.LlamaConfig(**config.llama)) 119 | 120 | self.llama = self.llama.to(torch.bfloat16) 121 | self.lm_head = self.lm_head.to(torch.bfloat16) 122 | self.vision_tower = self.vision_tower.to(torch.bfloat16) 123 | self.mm_projector = self.mm_projector.to(torch.bfloat16) 124 | 125 | def get_input_embeddings(self): 126 | return self.llama.get_input_embeddings() 127 | 128 | def get_vision_tower(self): 129 | return self.vision_tower 130 | 131 | def gradient_checkpointing_enable(self): 132 | return self.llama.gradient_checkpointing_enable() 133 | 134 | def encode_images(self, images): 135 | images = images.to(torch.bfloat16) 136 | 137 | def _expand_token(token, batch_size: int): 138 | return token.view(1, 1, 
-1).expand(batch_size, -1, -1) 139 | 140 | # open_clip ViT 141 | # https://github.com/mlfoundations/open_clip/blob/main/src/open_clip/transformer.py 142 | x = images 143 | x = self.vision_tower.conv1(x) # shape = [*, width, grid, grid] 144 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 145 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 146 | 147 | # class embeddings and positional embeddings 148 | x = torch.cat([_expand_token(self.vision_tower.class_embedding, x.shape[0]).to(x.dtype), x], dim=1) 149 | # shape = [*, grid ** 2 + 1, width] 150 | x = x + self.vision_tower.positional_embedding.to(x.dtype) 151 | 152 | x = self.vision_tower.patch_dropout(x) 153 | x = self.vision_tower.ln_pre(x) 154 | 155 | x = x.permute(1, 0, 2) # NLD -> LND 156 | x = self.vision_tower.transformer(x) 157 | x = x.permute(1, 0, 2) # LND -> NLD 158 | 159 | if self.vision_tower.attn_pool is not None: 160 | if self.vision_tower.attn_pool_contrastive is not None: 161 | # This is untested, WIP pooling that should match paper 162 | x = self.vision_tower.ln_post(x) # TBD LN first or separate one after each pool? 163 | tokens = self.vision_tower.attn_pool(x) 164 | if self.vision_tower.attn_pool_type == 'parallel': 165 | pooled = self.vision_tower.attn_pool_contrastive(x) 166 | else: 167 | assert self.vision_tower.attn_pool_type == 'cascade' 168 | pooled = self.vision_tower.attn_pool_contrastive(tokens) 169 | else: 170 | # this is the original OpenCLIP CoCa setup, does not match paper 171 | x = self.vision_tower.attn_pool(x) 172 | x = self.vision_tower.ln_post(x) 173 | pooled, tokens = self.vision_tower._global_pool(x) 174 | elif self.vision_tower.final_ln_after_pool: 175 | pooled, tokens = self.vision_tower._global_pool(x) 176 | pooled = self.vision_tower.ln_post(pooled) 177 | else: 178 | x = self.vision_tower.ln_post(x) 179 | pooled, tokens = self.vision_tower._global_pool(x) 180 | 181 | if self.vision_tower.proj is not None: 182 | pooled = pooled @ self.vision_tower.proj 183 | 184 | image_features = tokens 185 | image_features = image_features.to(torch.bfloat16) 186 | image_features = self.mm_projector(image_features) 187 | 188 | image_features = image_features.to(torch.bfloat16) 189 | return image_features 190 | 191 | def forward( 192 | self, 193 | input_ids: torch.LongTensor = None, 194 | attention_mask: Optional[torch.Tensor] = None, 195 | past_key_values: Optional[List[torch.FloatTensor]] = None, 196 | inputs_embeds: Optional[torch.FloatTensor] = None, 197 | labels: Optional[torch.LongTensor] = None, # (1,4317) 198 | use_cache: Optional[bool] = None, 199 | output_attentions: Optional[bool] = None, 200 | output_hidden_states: Optional[bool] = None, 201 | images: Optional[torch.FloatTensor] = None, 202 | return_dict: Optional[bool] = None, 203 | ) -> Union[Tuple, CausalLMOutputWithPast]: 204 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 205 | output_hidden_states = ( 206 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 207 | ) 208 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 209 | 210 | 211 | input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal( 212 | input_ids, attention_mask, past_key_values, labels, images) 213 | 214 | outputs = self.llama( 215 | input_ids=input_ids, 216 | attention_mask=attention_mask, 217 | past_key_values=past_key_values, 218 | 
inputs_embeds=inputs_embeds, 219 | use_cache=use_cache, 220 | output_attentions=output_attentions, 221 | output_hidden_states=output_hidden_states, 222 | return_dict=return_dict 223 | ) 224 | 225 | hidden_states = outputs[0] 226 | logits = self.lm_head(hidden_states) 227 | 228 | loss = None 229 | 230 | return CausalLMOutputWithPast( 231 | loss=loss, 232 | logits=logits, 233 | past_key_values=outputs.past_key_values, 234 | hidden_states=outputs.hidden_states, 235 | attentions=outputs.attentions, 236 | ) 237 | 238 | # original multimodal code 239 | def prepare_inputs_labels_for_multimodal( 240 | self, input_ids, attention_mask, past_key_values, labels, images 241 | ): 242 | vision_tower = self.vision_tower 243 | if vision_tower is None or images is None or input_ids.shape[1] == 1: 244 | if past_key_values is not None and vision_tower is not None and images is not None and input_ids.shape[ 245 | 1] == 1: 246 | attention_mask = torch.ones((attention_mask.shape[0], past_key_values[-1][-1].shape[-2] + 1), 247 | dtype=attention_mask.dtype, device=attention_mask.device) 248 | return input_ids, attention_mask, past_key_values, None, labels 249 | 250 | if type(images) is list or images.ndim == 5: 251 | concat_images = torch.cat([image for image in images], dim=0) 252 | image_features = self.encode_images(concat_images) 253 | split_sizes = [image.shape[0] for image in images] 254 | image_features = torch.split(image_features, split_sizes, dim=0) 255 | image_features = [x.flatten(0, 1) for x in image_features] 256 | else: 257 | image_features = self.encode_images(images) 258 | 259 | new_input_embeds = [] 260 | new_labels = [] if labels is not None else None 261 | cur_image_idx = 0 262 | for batch_idx, cur_input_ids in enumerate(input_ids): 263 | if (cur_input_ids == IMAGE_TOKEN_INDEX).sum() == 0: 264 | # multimodal LLM, but the current sample is not multimodal 265 | cur_input_embeds = self.llama.embed_tokens(cur_input_ids) 266 | cur_input_embeds = cur_input_embeds + (0. 
* self.mm_projector(vision_tower.dummy_feature)).sum() 267 | new_input_embeds.append(cur_input_embeds) 268 | if labels is not None: 269 | new_labels.append(labels[batch_idx]) 270 | cur_image_idx += 1 271 | continue 272 | image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0] 273 | cur_new_input_embeds = [] 274 | if labels is not None: 275 | cur_labels = labels[batch_idx] 276 | cur_new_labels = [] 277 | assert cur_labels.shape == cur_input_ids.shape 278 | while image_token_indices.numel() > 0: 279 | cur_image_features = image_features[cur_image_idx] 280 | image_token_start = image_token_indices[0] 281 | if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', 282 | False): 283 | cur_new_input_embeds.append(self.llama.embed_tokens(cur_input_ids[:image_token_start - 1]).detach()) 284 | cur_new_input_embeds.append( 285 | self.llama.embed_tokens(cur_input_ids[image_token_start - 1:image_token_start])) 286 | cur_new_input_embeds.append(cur_image_features) 287 | cur_new_input_embeds.append( 288 | self.llama.embed_tokens(cur_input_ids[image_token_start + 1:image_token_start + 2])) 289 | if labels is not None: 290 | cur_new_labels.append(cur_labels[:image_token_start]) 291 | cur_new_labels.append( 292 | torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, 293 | dtype=labels.dtype)) 294 | cur_new_labels.append(cur_labels[image_token_start:image_token_start + 1]) 295 | cur_labels = cur_labels[image_token_start + 2:] 296 | else: 297 | cur_new_input_embeds.append(self.llama.embed_tokens(cur_input_ids[:image_token_start])) 298 | cur_new_input_embeds.append(cur_image_features) 299 | if labels is not None: 300 | cur_new_labels.append(cur_labels[:image_token_start]) 301 | cur_new_labels.append( 302 | torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, 303 | dtype=labels.dtype)) 304 | cur_labels = cur_labels[image_token_start + 1:] 305 | cur_image_idx += 1 306 | if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', 307 | False): 308 | cur_input_ids = cur_input_ids[image_token_start + 2:] 309 | else: 310 | cur_input_ids = cur_input_ids[image_token_start + 1:] 311 | image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0] 312 | if cur_input_ids.numel() > 0: 313 | if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', 314 | False): 315 | cur_new_input_embeds.append(self.llama.embed_tokens(cur_input_ids).detach()) 316 | else: 317 | cur_new_input_embeds.append(self.llama.embed_tokens(cur_input_ids)) 318 | if labels is not None: 319 | cur_new_labels.append(cur_labels) 320 | cur_new_input_embeds = [x.to(device=self.device) for x in cur_new_input_embeds] 321 | 322 | cur_new_input_embeds = torch.cat(cur_new_input_embeds, dim=0) 323 | new_input_embeds.append(cur_new_input_embeds) 324 | if labels is not None: 325 | cur_new_labels = torch.cat(cur_new_labels, dim=0) 326 | new_labels.append(cur_new_labels) 327 | 328 | if any(x.shape != new_input_embeds[0].shape for x in new_input_embeds): 329 | max_len = max(x.shape[0] for x in new_input_embeds) 330 | 331 | new_input_embeds_align = [] 332 | for cur_new_embed in new_input_embeds: 333 | cur_new_embed = torch.cat((cur_new_embed, 334 | torch.zeros((max_len - cur_new_embed.shape[0], cur_new_embed.shape[1]), 335 | dtype=cur_new_embed.dtype, device=cur_new_embed.device)), dim=0) 336 | new_input_embeds_align.append(cur_new_embed) 337 | new_input_embeds = 

            if labels is not None:
                new_labels_align = []
                _new_labels = new_labels
                for cur_new_label in new_labels:
                    cur_new_label = torch.cat((cur_new_label,
                                               torch.full((max_len - cur_new_label.shape[0],), IGNORE_INDEX,
                                                          dtype=cur_new_label.dtype, device=cur_new_label.device)), dim=0)
                    new_labels_align.append(cur_new_label)
                new_labels = torch.stack(new_labels_align, dim=0)

            if attention_mask is not None:
                new_attention_mask = []
                for cur_attention_mask, cur_new_labels, cur_new_labels_align in zip(attention_mask, _new_labels, new_labels):
                    new_attn_mask_pad_left = torch.full((cur_new_labels.shape[0] - labels.shape[1],), True,
                                                        dtype=attention_mask.dtype, device=attention_mask.device)
                    new_attn_mask_pad_right = torch.full((cur_new_labels_align.shape[0] - cur_new_labels.shape[0],), False,
                                                         dtype=attention_mask.dtype, device=attention_mask.device)
                    cur_new_attention_mask = torch.cat((new_attn_mask_pad_left, cur_attention_mask, new_attn_mask_pad_right), dim=0)
                    new_attention_mask.append(cur_new_attention_mask)
                attention_mask = torch.stack(new_attention_mask, dim=0)
                assert attention_mask.shape == new_labels.shape
        else:
            new_input_embeds = torch.stack(new_input_embeds, dim=0)
            if labels is not None:
                new_labels = torch.stack(new_labels, dim=0)

            if attention_mask is not None:
                new_attn_mask_pad_left = torch.full((attention_mask.shape[0], new_input_embeds.shape[1] - input_ids.shape[1]),
                                                    True, dtype=attention_mask.dtype, device=attention_mask.device)
                attention_mask = torch.cat((new_attn_mask_pad_left, attention_mask), dim=1)
                assert attention_mask.shape == new_input_embeds.shape[:2]

        return None, attention_mask, past_key_values, new_input_embeds, new_labels
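
    # Note on prepare_inputs_labels_for_multimodal above: every IMAGE_TOKEN_INDEX
    # position in input_ids is replaced by the projected visual features, the
    # surrounding text tokens are embedded with the LLM embedding table, label
    # positions covering image features are filled with IGNORE_INDEX, and the batch is
    # right-padded to its longest sequence (zeros for embeddings, IGNORE_INDEX for
    # labels, False for the attention mask). Roughly, for a single sample:
    #
    #   input_ids : [t1, t2, IMG, t3]                   # IMG == IMAGE_TOKEN_INDEX
    #   embeds    : [e(t1), e(t2), v1, ..., vN, e(t3)]   # v1..vN = projected vision tokens
    #   labels    : [l1, l2, IGNORE x N, l3]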

    # sw-modified code

    def prepare_inputs_labels_for_multimodal_use_final_vector(
            self, input_ids, attention_mask, past_key_values, labels, images
    ):
        vision_tower = self.vision_tower
        if vision_tower is None or images is None or input_ids.shape[1] == 1:
            if past_key_values is not None and vision_tower is not None and images is not None and input_ids.shape[1] == 1:
                attention_mask = torch.ones((attention_mask.shape[0], past_key_values[-1][-1].shape[-2] + 1),
                                            dtype=attention_mask.dtype, device=attention_mask.device)
            return input_ids, attention_mask, past_key_values, None, labels

        if type(images) is list or images.ndim == 5:
            concat_images = torch.cat([image for image in images], dim=0)
            image_features = self.encode_images(concat_images)
            split_sizes = [image.shape[0] for image in images]
            image_features = torch.split(image_features, split_sizes, dim=0)
            image_features = [x.flatten(0, 1) for x in image_features]
        else:
            image_features = self.encode_images(images)

        new_input_embeds = []
        new_labels = [] if labels is not None else None
        cur_image_idx = 0
        for batch_idx, cur_input_ids in enumerate(input_ids):
            if (cur_input_ids == IMAGE_TOKEN_INDEX).sum() == 0:
                # multimodal LLM, but the current sample is not multimodal
                cur_input_embeds = self.llama.embed_tokens(cur_input_ids)
                cur_input_embeds = cur_input_embeds + (0. * self.mm_projector(vision_tower.dummy_feature)).sum()
                new_input_embeds.append(cur_input_embeds)
                if labels is not None:
                    new_labels.append(labels[batch_idx])
                cur_image_idx += 1
                continue
            image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]
            cur_new_input_embeds = []
            if labels is not None:
                cur_labels = labels[batch_idx]
                cur_new_labels = []
                assert cur_labels.shape == cur_input_ids.shape
            while image_token_indices.numel() > 0:
                cur_image_features = image_features[cur_image_idx]
                image_token_start = image_token_indices[0]
                if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
                    cur_new_input_embeds.append(self.llama.embed_tokens(cur_input_ids[:image_token_start - 1]).detach())
                    cur_new_input_embeds.append(self.llama.embed_tokens(cur_input_ids[image_token_start - 1:image_token_start]))
                    cur_new_input_embeds.append(cur_image_features)
                    cur_new_input_embeds.append(self.llama.embed_tokens(cur_input_ids[image_token_start + 1:image_token_start + 2]))
                    if labels is not None:
                        cur_new_labels.append(cur_labels[:image_token_start])
                        cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX,
                                                         device=labels.device, dtype=labels.dtype))
                        cur_new_labels.append(cur_labels[image_token_start:image_token_start + 1])
                        cur_labels = cur_labels[image_token_start + 2:]
                else:
                    cur_new_input_embeds.append(self.llama.embed_tokens(cur_input_ids[:image_token_start].to(self.device)))
                    cur_new_input_embeds.append(cur_image_features)
                    if labels is not None:
                        cur_new_labels.append(cur_labels[:image_token_start])
                        cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX,
                                                         device=labels.device, dtype=labels.dtype))
                        cur_labels = cur_labels[image_token_start + 1:]
                cur_image_idx += 1
                if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
                    cur_input_ids = cur_input_ids[image_token_start + 2:]
                else:
                    cur_input_ids = cur_input_ids[image_token_start + 1:]
                image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]
            if cur_input_ids.numel() > 0:
                if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
                    cur_new_input_embeds.append(self.llama.embed_tokens(cur_input_ids).detach())
                else:
                    cur_new_input_embeds.append(self.llama.embed_tokens(cur_input_ids.to(self.device)))
                if labels is not None:
                    # seowoo-edit
                    cur_labels = labels[batch_idx]
                    cur_new_labels.append(cur_labels)
            # [5120] -> [1, 5120]
            cur_new_input_embeds[1] = torch.unsqueeze(cur_new_input_embeds[1], dim=0)
            cur_new_input_embeds = [x.to(device=self.device) for x in cur_new_input_embeds]
            cur_new_input_embeds = torch.cat(cur_new_input_embeds, dim=0)
            new_input_embeds.append(cur_new_input_embeds)
            if labels is not None:
                cur_new_labels = torch.cat(cur_new_labels, dim=0)
                new_labels.append(cur_new_labels)

        if any(x.shape != new_input_embeds[0].shape for x in new_input_embeds):
            # print("if 204")
            max_len = max(x.shape[0] for x in new_input_embeds)

            new_input_embeds_align = []
            for cur_new_embed in new_input_embeds:
                cur_new_embed = torch.cat((cur_new_embed,
                                           torch.zeros((max_len - cur_new_embed.shape[0], cur_new_embed.shape[1]),
                                                       dtype=cur_new_embed.dtype, device=cur_new_embed.device)), dim=0)
                new_input_embeds_align.append(cur_new_embed)
            new_input_embeds = torch.stack(new_input_embeds_align, dim=0)

            if labels is not None:
                new_labels_align = []
                _new_labels = new_labels
                for cur_new_label in new_labels:
                    cur_new_label = torch.cat((cur_new_label,
                                               torch.full((max_len - cur_new_label.shape[0],), IGNORE_INDEX,
                                                          dtype=cur_new_label.dtype, device=cur_new_label.device)), dim=0)
                    new_labels_align.append(cur_new_label)
                new_labels = torch.stack(new_labels_align, dim=0)

            if attention_mask is not None:
                new_attention_mask = []
                for cur_attention_mask, cur_new_labels, cur_new_labels_align in zip(attention_mask, _new_labels, new_labels):
                    new_attn_mask_pad_left = torch.full((cur_new_labels.shape[0] - labels.shape[1],), True,
                                                        dtype=attention_mask.dtype, device=attention_mask.device)
                    new_attn_mask_pad_right = torch.full((cur_new_labels_align.shape[0] - cur_new_labels.shape[0],), False,
                                                         dtype=attention_mask.dtype, device=attention_mask.device)
                    cur_new_attention_mask = torch.cat((new_attn_mask_pad_left, cur_attention_mask, new_attn_mask_pad_right), dim=0)
                    new_attention_mask.append(cur_new_attention_mask)
                attention_mask = torch.stack(new_attention_mask, dim=0)
                assert attention_mask.shape == new_labels.shape
        else:
            new_input_embeds = torch.stack(new_input_embeds, dim=0)
            if labels is not None:
                new_labels = torch.stack(new_labels, dim=0)

            if attention_mask is not None:
                new_attn_mask_pad_left = torch.full((attention_mask.shape[0], new_input_embeds.shape[1] - input_ids.shape[1]),
                                                    True, dtype=attention_mask.dtype, device=attention_mask.device)
                attention_mask = torch.cat((new_attn_mask_pad_left, attention_mask), dim=1)
                assert attention_mask.shape == new_input_embeds.shape[:2]

        return None, attention_mask, past_key_values, new_input_embeds, labels

    def prepare_inputs_for_generation(
            self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        if past_key_values:
            input_ids = input_ids[:, -1:]
        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}
        model_inputs.update(
            {
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
                "images": kwargs.get("images", None),
            }
        )
        return model_inputs

    def apply_chat_template(self, chat):
        return self.tokenizer.apply_chat_template(chat, tokenize=False)

    def tokenizer_image_token(self, prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
        # '<image>' is assumed here as the image placeholder string (standard LLaVA convention)
        prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]

        def insert_separator(X, sep):
            return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]

        input_ids = []
        offset = 0
        if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
            offset = 1
            input_ids.append(prompt_chunks[0][0])

        for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
            input_ids.extend(x[offset:])

        if return_tensors is not None:
            if return_tensors == 'pt':
                return torch.tensor(input_ids, dtype=torch.long)
            raise ValueError(f'Unsupported tensor type: {return_tensors}')
        return input_ids
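
    # Note on tokenizer_image_token above: it tokenizes the text on either side of the
    # '<image>' placeholder and splices the image_token_index sentinel between the
    # chunks, keeping a single BOS token at the front. Rough illustration (token ids
    # invented; IMAGE_TOKEN_INDEX is -200 in standard LLaVA code):
    #
    #   prompt : "USER: <image>\nIs there pleural effusion?"
    #   result : [bos, ids("USER: "), IMAGE_TOKEN_INDEX, ids("\nIs there pleural effusion?")]
    #
    # The sentinel is later replaced with projected vision features in
    # prepare_inputs_labels_for_multimodal.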

    def write_radiologic_report(self, image, temperature=0.2, top_p=0.8):
        # the user turn is assumed to begin with the '<image>' placeholder, which
        # tokenizer_image_token replaces with the image token sentinel
        chat = [
            {"role": "system",
             "content": "You are a helpful radiologist. Try to interpret chest x ray image and answer to the question that user provides."},
            {"role": "user",
             "content": "<image>\nWrite a radiologic report on the given chest radiograph, including information about atelectasis, cardiomegaly, consolidation, pulmonary edema, pleural effusion, and pneumothorax.\n"}
        ]
        response = self.generate_cxr_repsonse(chat=chat, image=image, temperature=temperature, top_p=top_p)
        return response

    def write_differential_diagnosis(self, image, temperature=0.2, top_p=0.8):
        chat = [
            {"role": "system",
             "content": "You are a helpful radiologist. Try to interpret chest x ray image and answer to the question that user provides."},
            {"role": "user",
             "content": "<image>\nWhat are the possible differential diagnoses for this patient?\n"}
        ]
        response = self.generate_cxr_repsonse(chat=chat, image=image, temperature=temperature, top_p=top_p)
        return response

    def ask_question(self, question, image, temperature=0.2, top_p=0.8):
        chat = [
            {"role": "system",
             "content": "You are a helpful radiologist. Try to interpret chest x ray image and answer to the question that user provides."},
            {"role": "user",
             "content": "<image>\n" + question}
        ]
        response = self.generate_cxr_repsonse(chat=chat, image=image, temperature=temperature, top_p=top_p)
        return response
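
    # generate_cxr_repsonse below launches `generate` on a worker thread and drains a
    # TextIteratorStreamer to build the final string. The same pattern can stream text
    # to a caller incrementally, e.g. by yielding each `new_text` chunk from the
    # streamer loop instead of concatenating it.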

    def generate_cxr_repsonse(self, chat, image, temperature=0.2, top_p=0.8):
        with torch.no_grad():
            streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)

            if np.array(image).max() > 255:
                raise Exception("WARNING. 16-bit image is not supported.")

            image = image.convert('L')  # convert to grayscale
            image = np.array(image)

            if len(image.shape) == 2:
                image = np.expand_dims(image, axis=-1)  # (width, height) --> (width, height, 1)

            prompt = self.apply_chat_template(chat)
            images = self.vision_tower.image_processor(image, return_tensors='pt')['pixel_values']
            images = images.to(self.device)
            input_ids = self.tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
            # "</s>" (the LLaMA-2 end-of-sequence string) is assumed as the stop keyword
            stopping_criteria = KeywordsStoppingCriteria(["</s>"], self.tokenizer, input_ids)

            image_args = {"images": images}
            do_sample = True if temperature > 0.001 else False
            num_image_tokens = 1
            max_context_length = getattr(self.config, 'max_position_embeddings', 2048)

            max_new_tokens = min(512, max_context_length - input_ids.shape[-1] - num_image_tokens)
            thread = Thread(target=self.generate, kwargs=dict(
                inputs=input_ids,
                do_sample=do_sample,
                temperature=temperature,
                top_p=top_p,
                max_new_tokens=max_new_tokens,
                streamer=streamer,
                stopping_criteria=[stopping_criteria],
                use_cache=True,
                generation_config=self.generation_config,
                **image_args
            ))
            thread.start()
            generated_text = ""
            for new_text in streamer:
                generated_text += new_text

            return generated_text

    def tokenizer_image_token(self, prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
        # '<image>' is assumed as the image placeholder string, as in the method of the same name above
        prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]

        def insert_separator(X, sep):
            return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]

        input_ids = []
        offset = 0
        if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
            offset = 1
            input_ids.append(prompt_chunks[0][0])

        for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
            input_ids.extend(x[offset:])

        if return_tensors is not None:
            if return_tensors == 'pt':
                return torch.tensor(input_ids, dtype=torch.long)
            raise ValueError(f'Unsupported tensor type: {return_tensors}')
        return input_ids


class KeywordsStoppingCriteria(StoppingCriteria):
    def __init__(self, keywords, tokenizer, input_ids):
        self.keywords = keywords
        self.keyword_ids = []
        for keyword in keywords:
            cur_keyword_ids = tokenizer(keyword).input_ids
            if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
                cur_keyword_ids = cur_keyword_ids[1:]
            self.keyword_ids.append(torch.tensor(cur_keyword_ids))
        self.tokenizer = tokenizer
        self.start_len = input_ids.shape[1]

    def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        assert output_ids.shape[0] == 1, "Only support batch size 1 (yet)"  # TODO
        offset = min(output_ids.shape[1] - self.start_len, 3)
        self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
        for keyword_id in self.keyword_ids:
            # element-wise comparison; .all() avoids the ambiguous truth value of a multi-element tensor
            if (output_ids[0, -keyword_id.shape[0]:] == keyword_id).all():
                return True
        outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
        for keyword in self.keywords:
            if keyword in outputs:
                return True
        return False
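
# Illustrative sketch: one way to reuse KeywordsStoppingCriteria with a plain,
# non-streaming `generate` call. The "</s>" keyword (LLaMA-2 end-of-sequence string)
# and the helper name are assumptions, not part of the model's documented API.
def _example_keyword_stopping(model, tokenizer, input_ids, images):
    from transformers import StoppingCriteriaList

    stop_on_eos = KeywordsStoppingCriteria(["</s>"], tokenizer, input_ids)
    # `generate` evaluates every criterion after each decoding step and stops as soon
    # as one returns True; `images` is forwarded to the model through
    # prepare_inputs_for_generation above.
    return model.generate(
        inputs=input_ids,
        images=images,
        max_new_tokens=128,
        stopping_criteria=StoppingCriteriaList([stop_on_eos]),
    )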

@dataclass
class CLIPVisionCfg:
    layers: Union[Tuple[int, int, int, int], int] = 12
    width: int = 768
    head_width: int = 64
    mlp_ratio: float = 4.0
    patch_size: int = 16
    image_size: Union[Tuple[int, int], int] = 224

    ls_init_value: Optional[float] = None  # layer scale initial value
    patch_dropout: float = 0.  # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results
    attentional_pool: bool = False  # whether to use attentional pooler in the last embedding layer (overrides pool_type)
    attn_pooler_queries: int = 256  # n_queries for attentional pooler
    attn_pooler_heads: int = 8  # n heads for attentional_pooling
    no_ln_pre: bool = False  # disable pre transformer LayerNorm
    pos_embed_type: str = 'learnable'
    final_ln_after_pool: bool = False  # apply final LayerNorm after pooling
    pool_type: str = 'tok'
    output_tokens: bool = False
    act_kwargs: Optional[dict] = None
    norm_kwargs: Optional[dict] = None

    timm_model_name: Optional[str] = None  # a valid model name overrides layers, width, patch_size
    timm_model_pretrained: bool = False  # use (imagenet) pretrained weights for named model
    timm_pool: str = 'avg'  # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
    timm_proj: str = 'linear'  # linear projection for timm model output ('linear', 'mlp', '')
    timm_proj_bias: bool = False  # enable bias final projection
    timm_drop: float = 0.  # head dropout
    timm_drop_path: Optional[float] = None  # backbone stochastic depth
--------------------------------------------------------------------------------