├── Images
│   ├── GIF.gif
│   ├── result_rationale.jpg
│   ├── result_rationale.pdf
│   ├── result_report.jpg
│   ├── result_report.pdf
│   ├── result_report_generation.jpg
│   ├── result_report_generation.pdf
│   ├── result_vqa.jpg
│   └── result_vqa.pdf
├── Quick_demo
│   ├── Language_files
│   │   ├── config.json
│   │   ├── special_tokens_map.json
│   │   ├── tokenizer.model
│   │   └── tokenizer_config.json
│   ├── MedKEBERT
│   │   ├── config.json
│   │   ├── special_tokens_map.json
│   │   ├── tokenizer.json
│   │   ├── tokenizer_config.json
│   │   └── vocab.txt
│   ├── Model
│   │   └── RadFM
│   │       ├── __init__.py
│   │       ├── __pycache__
│   │       │   ├── __init__.cpython-39.pyc
│   │       │   ├── blocks.cpython-39.pyc
│   │       │   ├── helpers.cpython-39.pyc
│   │       │   ├── multimodality_model.cpython-39.pyc
│   │       │   ├── my_embedding_layer.cpython-39.pyc
│   │       │   ├── position_encoding.cpython-39.pyc
│   │       │   ├── transformer_decoder.cpython-39.pyc
│   │       │   ├── utils.cpython-39.pyc
│   │       │   └── vit_3d.cpython-39.pyc
│   │       ├── blocks.py
│   │       ├── helpers.py
│   │       ├── multimodality_model.py
│   │       ├── my_embedding_layer.py
│   │       ├── position_encoding.py
│   │       ├── transformer_decoder.py
│   │       ├── utils.py
│   │       └── vit_3d.py
│   ├── test.py
│   └── view1_frontal.jpg
├── README.md
├── requirements.txt
└── src
    ├── Dataset
    │   ├── dataset
    │   │   ├── MedPix_dataset.py
    │   │   ├── __init__.py
    │   │   ├── __pycache__
    │   │   │   ├── MedPix_dataset.cpython-39.pyc
    │   │   │   ├── __init__.cpython-310.pyc
    │   │   │   ├── __init__.cpython-39.pyc
    │   │   │   ├── binary.cpython-39.pyc
    │   │   │   ├── case_report.cpython-39.pyc
    │   │   │   ├── chestxray.cpython-310.pyc
    │   │   │   ├── chestxray.cpython-39.pyc
    │   │   │   ├── paper_inline.cpython-39.pyc
    │   │   │   ├── pmcoa.cpython-39.pyc
    │   │   │   ├── pmcvqa.cpython-39.pyc
    │   │   │   ├── radiopaedia.cpython-39.pyc
    │   │   │   └── radiovqa.cpython-39.pyc
    │   │   ├── binary.py
    │   │   ├── caption_prompt.json
    │   │   ├── case_report.py
    │   │   ├── chestxray.py
    │   │   ├── cls_prompt.json
    │   │   ├── data_csv
    │   │   │   └── README.md
    │   │   ├── dicom_to_png_for_VinDR_sampled_using_mammo.py
    │   │   ├── jpg2nii_data_convert.py
    │   │   ├── mammo_prompt.json
    │   │   ├── modality_prompt.json
    │   │   ├── nii2npy_for_radiopaedio.py
    │   │   ├── paper_inline.py
    │   │   ├── pmcoa.py
    │   │   ├── radiology_feature_prompt.json
    │   │   ├── radiopaedia.py
    │   │   ├── report_prompt.json
    │   │   ├── spinexr_prompt.json
    │   │   ├── vqa.py
    │   │   └── yes_no_prompt.json
    │   ├── multi_dataset.py
    │   ├── multi_dataset_test.py
    │   └── multi_dataset_test_for_close.py
    ├── Model
    │   └── RadFM
    │       ├── __init__.py
    │       ├── __pycache__
    │       │   ├── __init__.cpython-39.pyc
    │       │   ├── blocks.cpython-39.pyc
    │       │   ├── helpers.cpython-39.pyc
    │       │   ├── multimodality_model.cpython-39.pyc
    │       │   ├── my_embedding_layer.cpython-39.pyc
    │       │   ├── position_encoding.cpython-39.pyc
    │       │   ├── transformer_decoder.cpython-39.pyc
    │       │   ├── utils.cpython-39.pyc
    │       │   └── vit_3d.cpython-39.pyc
    │       ├── blocks.py
    │       ├── helpers.py
    │       ├── multimodality_model.py
    │       ├── my_embedding_layer.py
    │       ├── position_encoding.py
    │       ├── transformer_decoder.py
    │       ├── utils.py
    │       └── vit_3d.py
    ├── My_Trainer
    │   ├── __pycache__
    │   │   └── trainer.cpython-39.pyc
    │   └── trainer.py
    ├── datasampler.py
    ├── output_csv_example
    │   └── caption_example.csv
    ├── test.py
    └── train.py
/Images/GIF.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaoyi-wu/RadFM/9efba84cced9ffdcbce00d6005f255414ffa8c36/Images/GIF.gif
--------------------------------------------------------------------------------
/Images/result_rationale.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaoyi-wu/RadFM/9efba84cced9ffdcbce00d6005f255414ffa8c36/Images/result_rationale.jpg
--------------------------------------------------------------------------------
/Images/result_rationale.pdf:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaoyi-wu/RadFM/9efba84cced9ffdcbce00d6005f255414ffa8c36/Images/result_rationale.pdf -------------------------------------------------------------------------------- /Images/result_report.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaoyi-wu/RadFM/9efba84cced9ffdcbce00d6005f255414ffa8c36/Images/result_report.jpg -------------------------------------------------------------------------------- /Images/result_report.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaoyi-wu/RadFM/9efba84cced9ffdcbce00d6005f255414ffa8c36/Images/result_report.pdf -------------------------------------------------------------------------------- /Images/result_report_generation.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaoyi-wu/RadFM/9efba84cced9ffdcbce00d6005f255414ffa8c36/Images/result_report_generation.jpg -------------------------------------------------------------------------------- /Images/result_report_generation.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaoyi-wu/RadFM/9efba84cced9ffdcbce00d6005f255414ffa8c36/Images/result_report_generation.pdf -------------------------------------------------------------------------------- /Images/result_vqa.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaoyi-wu/RadFM/9efba84cced9ffdcbce00d6005f255414ffa8c36/Images/result_vqa.jpg -------------------------------------------------------------------------------- /Images/result_vqa.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaoyi-wu/RadFM/9efba84cced9ffdcbce00d6005f255414ffa8c36/Images/result_vqa.pdf -------------------------------------------------------------------------------- /Quick_demo/Language_files/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_name_or_path": "/home/cs/leijiayu/wuchaoyi/Finetune_LLAMA/LLAMA_Model/llama-13b-hf", 3 | "architectures": [ 4 | "LlamaForCausalLM" 5 | ], 6 | "bos_token_id": 0, 7 | "eos_token_id": 1, 8 | "hidden_act": "silu", 9 | "hidden_size": 5120, 10 | "initializer_range": 0.02, 11 | "intermediate_size": 13824, 12 | "max_sequence_length": 2048, 13 | "model_type": "llama", 14 | "num_attention_heads": 40, 15 | "num_hidden_layers": 40, 16 | "pad_token_id": -1, 17 | "rms_norm_eps": 1e-06, 18 | "tie_word_embeddings": false, 19 | "torch_dtype": "float32", 20 | "transformers_version": "4.28.0.dev0", 21 | "use_cache": true, 22 | "vocab_size": 32000 23 | } 24 | -------------------------------------------------------------------------------- /Quick_demo/Language_files/special_tokens_map.json: -------------------------------------------------------------------------------- 1 | {} -------------------------------------------------------------------------------- /Quick_demo/Language_files/tokenizer.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaoyi-wu/RadFM/9efba84cced9ffdcbce00d6005f255414ffa8c36/Quick_demo/Language_files/tokenizer.model 
-------------------------------------------------------------------------------- /Quick_demo/Language_files/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | {"bos_token": "", "eos_token": "", "model_max_length": 1000000000000000019884624838656, "tokenizer_class": "LlamaTokenizer", "unk_token": ""} -------------------------------------------------------------------------------- /Quick_demo/MedKEBERT/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_name_or_path": "xmcmic/Med-KEBERT", 3 | "architectures": [ 4 | "BertModel" 5 | ], 6 | "attention_probs_dropout_prob": 0.1, 7 | "classifier_dropout": null, 8 | "gradient_checkpointing": false, 9 | "hidden_act": "gelu", 10 | "hidden_dropout_prob": 0.1, 11 | "hidden_size": 768, 12 | "initializer_range": 0.02, 13 | "intermediate_size": 3072, 14 | "layer_norm_eps": 1e-12, 15 | "max_position_embeddings": 512, 16 | "model_type": "bert", 17 | "num_attention_heads": 12, 18 | "num_hidden_layers": 12, 19 | "output_hidden_states": true, 20 | "pad_token_id": 0, 21 | "position_embedding_type": "absolute", 22 | "torch_dtype": "float32", 23 | "transformers_version": "4.24.0", 24 | "type_vocab_size": 2, 25 | "use_cache": true, 26 | "vocab_size": 30522 27 | } 28 | -------------------------------------------------------------------------------- /Quick_demo/MedKEBERT/special_tokens_map.json: -------------------------------------------------------------------------------- 1 | { 2 | "cls_token": "[CLS]", 3 | "mask_token": "[MASK]", 4 | "pad_token": "[PAD]", 5 | "sep_token": "[SEP]", 6 | "unk_token": "[UNK]" 7 | } 8 | -------------------------------------------------------------------------------- /Quick_demo/MedKEBERT/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "cls_token": "[CLS]", 3 | "do_basic_tokenize": true, 4 | "do_lower_case": true, 5 | "mask_token": "[MASK]", 6 | "name_or_path": "xmcmic/Med-KEBERT", 7 | "never_split": null, 8 | "pad_token": "[PAD]", 9 | "sep_token": "[SEP]", 10 | "special_tokens_map_file": null, 11 | "strip_accents": null, 12 | "tokenize_chinese_chars": true, 13 | "tokenizer_class": "BertTokenizer", 14 | "unk_token": "[UNK]" 15 | } 16 | -------------------------------------------------------------------------------- /Quick_demo/Model/RadFM/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaoyi-wu/RadFM/9efba84cced9ffdcbce00d6005f255414ffa8c36/Quick_demo/Model/RadFM/__init__.py -------------------------------------------------------------------------------- /Quick_demo/Model/RadFM/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaoyi-wu/RadFM/9efba84cced9ffdcbce00d6005f255414ffa8c36/Quick_demo/Model/RadFM/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /Quick_demo/Model/RadFM/__pycache__/blocks.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaoyi-wu/RadFM/9efba84cced9ffdcbce00d6005f255414ffa8c36/Quick_demo/Model/RadFM/__pycache__/blocks.cpython-39.pyc -------------------------------------------------------------------------------- /Quick_demo/Model/RadFM/__pycache__/helpers.cpython-39.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaoyi-wu/RadFM/9efba84cced9ffdcbce00d6005f255414ffa8c36/Quick_demo/Model/RadFM/__pycache__/helpers.cpython-39.pyc -------------------------------------------------------------------------------- /Quick_demo/Model/RadFM/__pycache__/multimodality_model.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaoyi-wu/RadFM/9efba84cced9ffdcbce00d6005f255414ffa8c36/Quick_demo/Model/RadFM/__pycache__/multimodality_model.cpython-39.pyc -------------------------------------------------------------------------------- /Quick_demo/Model/RadFM/__pycache__/my_embedding_layer.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaoyi-wu/RadFM/9efba84cced9ffdcbce00d6005f255414ffa8c36/Quick_demo/Model/RadFM/__pycache__/my_embedding_layer.cpython-39.pyc -------------------------------------------------------------------------------- /Quick_demo/Model/RadFM/__pycache__/position_encoding.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaoyi-wu/RadFM/9efba84cced9ffdcbce00d6005f255414ffa8c36/Quick_demo/Model/RadFM/__pycache__/position_encoding.cpython-39.pyc -------------------------------------------------------------------------------- /Quick_demo/Model/RadFM/__pycache__/transformer_decoder.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaoyi-wu/RadFM/9efba84cced9ffdcbce00d6005f255414ffa8c36/Quick_demo/Model/RadFM/__pycache__/transformer_decoder.cpython-39.pyc -------------------------------------------------------------------------------- /Quick_demo/Model/RadFM/__pycache__/utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaoyi-wu/RadFM/9efba84cced9ffdcbce00d6005f255414ffa8c36/Quick_demo/Model/RadFM/__pycache__/utils.cpython-39.pyc -------------------------------------------------------------------------------- /Quick_demo/Model/RadFM/__pycache__/vit_3d.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaoyi-wu/RadFM/9efba84cced9ffdcbce00d6005f255414ffa8c36/Quick_demo/Model/RadFM/__pycache__/vit_3d.cpython-39.pyc -------------------------------------------------------------------------------- /Quick_demo/Model/RadFM/blocks.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import Tuple, Union, Callable, Optional 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from torch import nn 7 | from torch.utils.checkpoint import checkpoint 8 | 9 | class PMC_CLIP_cfg: 10 | backbone: str = 'ModifiedRN50' # ['RN50', 'ModifiedRN50', 'MAE'] 11 | layers: Union[Tuple[int, int, int, int], int] = [3,4,6,3] 12 | width: int = 64 13 | head_width: int = 64 14 | mlp_ratio: float = 4.0 15 | patch_size: int = 16 16 | image_size: Union[Tuple[int, int], int] = 224 17 | timm_model_name: str = None # a valid model name overrides layers, width, patch_size 18 | timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model 19 | timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', 
'') 20 | timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '') 21 | patch_dropout: float = 0.0 # patch dropout rate, no dropout by default 22 | drop_attention_rate: float = 0. # Transformer Dropout 23 | patch_size: None 24 | 25 | class Bottleneck(nn.Module): 26 | expansion = 4 27 | 28 | def __init__(self, inplanes, planes, stride=1): 29 | super().__init__() 30 | 31 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 32 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 33 | self.bn1 = nn.BatchNorm2d(planes) 34 | self.relu1 = nn.ReLU(inplace=True) 35 | 36 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) 37 | self.bn2 = nn.BatchNorm2d(planes) 38 | self.relu2 = nn.ReLU(inplace=True) 39 | 40 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() 41 | 42 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 43 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 44 | self.relu3 = nn.ReLU(inplace=True) 45 | 46 | self.downsample = None 47 | self.stride = stride 48 | 49 | if stride > 1 or inplanes != planes * Bottleneck.expansion: 50 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 51 | self.downsample = nn.Sequential(OrderedDict([ 52 | ("-1", nn.AvgPool2d(stride)), 53 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), 54 | ("1", nn.BatchNorm2d(planes * self.expansion)) 55 | ])) 56 | 57 | def forward(self, x: torch.Tensor): 58 | identity = x 59 | 60 | out = self.relu1(self.bn1(self.conv1(x))) 61 | out = self.relu2(self.bn2(self.conv2(out))) 62 | out = self.avgpool(out) 63 | out = self.bn3(self.conv3(out)) 64 | 65 | if self.downsample is not None: 66 | identity = self.downsample(x) 67 | 68 | out += identity 69 | out = self.relu3(out) 70 | return out 71 | 72 | 73 | class AttentionPool2d(nn.Module): 74 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): 75 | super().__init__() 76 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) 77 | self.k_proj = nn.Linear(embed_dim, embed_dim) 78 | self.q_proj = nn.Linear(embed_dim, embed_dim) 79 | self.v_proj = nn.Linear(embed_dim, embed_dim) 80 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 81 | self.num_heads = num_heads 82 | 83 | def forward(self, x): 84 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC 85 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 86 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 87 | x, _ = F.multi_head_attention_forward( 88 | query=x, key=x, value=x, 89 | embed_dim_to_check=x.shape[-1], 90 | num_heads=self.num_heads, 91 | q_proj_weight=self.q_proj.weight, 92 | k_proj_weight=self.k_proj.weight, 93 | v_proj_weight=self.v_proj.weight, 94 | in_proj_weight=None, 95 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 96 | bias_k=None, 97 | bias_v=None, 98 | add_zero_attn=False, 99 | dropout_p=0, 100 | out_proj_weight=self.c_proj.weight, 101 | out_proj_bias=self.c_proj.bias, 102 | use_separate_proj_weight=True, 103 | training=self.training, 104 | need_weights=False 105 | ) 106 | 107 | return x[0] 108 | 109 | 110 | class ResNet(nn.Module): 111 | """ 112 | RN50 113 | """ 114 | 115 | def __init__( 116 | self, layers, output_dim, heads, image_size=224, width=64, 117 | block=Bottleneck, 118 | 
): 119 | super().__init__() 120 | self.output_dim = output_dim 121 | self.image_size = image_size 122 | 123 | # the 1-layer stem 124 | self.conv1 = nn.Conv2d(3, width, kernel_size=3, stride=2, padding=1, bias=False) 125 | self.bn1 = nn.BatchNorm2d(width) 126 | self.relu1 = nn.ReLU(inplace=True) 127 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 128 | 129 | # residual layers 130 | self._inplanes = width # this is a *mutable* variable used during construction 131 | self.layer1 = self._make_layer(width, layers[0]) 132 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2) 133 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2) 134 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2) 135 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 136 | # self.head = nn.Linear(512 * 6, output_dim) 137 | self.head = nn.Linear(512 * block.expansion, output_dim) 138 | 139 | # embed_dim = width * 32 # the ResNet feature dimension 140 | # self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim) 141 | 142 | self.init_parameters() 143 | 144 | def _make_layer( 145 | self, 146 | planes, blocks, stride=1, 147 | block=Bottleneck, 148 | ): 149 | layers = [block(self._inplanes, planes, stride)] 150 | 151 | self._inplanes = planes * block.expansion 152 | for _ in range(1, blocks): 153 | layers.append(block(self._inplanes, planes)) 154 | 155 | return nn.Sequential(*layers) 156 | 157 | def init_parameters(self): 158 | for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]: 159 | for name, param in resnet_block.named_parameters(): 160 | if name.endswith("bn3.weight"): 161 | nn.init.zeros_(param) 162 | 163 | def lock(self, unlocked_groups=0, freeze_bn_stats=False): 164 | assert unlocked_groups == 0, 'partial locking not currently supported for this model' 165 | for param in self.parameters(): 166 | param.requires_grad = False 167 | if freeze_bn_stats: 168 | freeze_batch_norm_2d(self) 169 | 170 | @torch.jit.ignore 171 | def set_grad_checkpointing(self, enable=True): 172 | # FIXME support for non-transformer 173 | pass 174 | 175 | def stem(self, x): 176 | x = self.relu1(self.bn1(self.conv1(x))) 177 | x = self.maxpool(x) 178 | return x 179 | 180 | def forward(self, x): 181 | # x[0]: [batch_size, 3, 224, 224] 182 | # x[1]: [batch_size, 1] 183 | x = self.stem(x) # [batch_size, 64, 56, 56] 184 | x = self.layer1(x) 185 | x = self.layer2(x) 186 | x = self.layer3(x) 187 | x = self.layer4(x) # [batch_size, 2048, 7, 7] 188 | x = self.avgpool(x) # [batch_size, 2048, 1, 1] 189 | x = torch.flatten(x, 1) # [batch_size, 2048*1*1] 190 | x = self.head(x) # [batch_size, 1024] 191 | 192 | visual_output = dict.fromkeys(["image_features", "mim_loss"], None) 193 | visual_output.update({ 194 | 'image_features': x, 195 | }) 196 | 197 | return visual_output 198 | 199 | 200 | class ModifiedResNet(nn.Module): 201 | """ 202 | A ResNet class that is similar to torchvision's but contains the following changes: 203 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 
204 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 205 | - The final pooling layer is a QKV attention instead of an average pool 206 | """ 207 | 208 | def __init__(self, layers, output_dim, heads, image_size=224, width=64): 209 | super().__init__() 210 | self.output_dim = output_dim 211 | self.image_size = image_size 212 | 213 | # the 3-layer stem 214 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) 215 | self.bn1 = nn.BatchNorm2d(width // 2) 216 | self.relu1 = nn.ReLU(inplace=True) 217 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) 218 | self.bn2 = nn.BatchNorm2d(width // 2) 219 | self.relu2 = nn.ReLU(inplace=True) 220 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) 221 | self.bn3 = nn.BatchNorm2d(width) 222 | self.relu3 = nn.ReLU(inplace=True) 223 | self.avgpool = nn.AvgPool2d(2) 224 | 225 | # residual layers 226 | self._inplanes = width # this is a *mutable* variable used during construction 227 | self.layer1 = self._make_layer(width, layers[0]) 228 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2) 229 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2) 230 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2) 231 | 232 | embed_dim = width * 32 # the ResNet feature dimension 233 | self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim) 234 | 235 | self.init_parameters() 236 | 237 | def _make_layer(self, planes, blocks, stride=1): 238 | layers = [Bottleneck(self._inplanes, planes, stride)] 239 | 240 | self._inplanes = planes * Bottleneck.expansion 241 | for _ in range(1, blocks): 242 | layers.append(Bottleneck(self._inplanes, planes)) 243 | 244 | return nn.Sequential(*layers) 245 | 246 | def init_parameters(self): 247 | if self.attnpool is not None: 248 | std = self.attnpool.c_proj.in_features ** -0.5 249 | nn.init.normal_(self.attnpool.q_proj.weight, std=std) 250 | nn.init.normal_(self.attnpool.k_proj.weight, std=std) 251 | nn.init.normal_(self.attnpool.v_proj.weight, std=std) 252 | nn.init.normal_(self.attnpool.c_proj.weight, std=std) 253 | 254 | for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]: 255 | for name, param in resnet_block.named_parameters(): 256 | if name.endswith("bn3.weight"): 257 | nn.init.zeros_(param) 258 | 259 | def lock(self, unlocked_groups=0, freeze_bn_stats=False): 260 | assert unlocked_groups == 0, 'partial locking not currently supported for this model' 261 | for param in self.parameters(): 262 | param.requires_grad = False 263 | if freeze_bn_stats: 264 | freeze_batch_norm_2d(self) 265 | 266 | @torch.jit.ignore 267 | def set_grad_checkpointing(self, enable=True): 268 | # FIXME support for non-transformer 269 | pass 270 | 271 | def stem(self, x): 272 | x = self.relu1(self.bn1(self.conv1(x))) 273 | x = self.relu2(self.bn2(self.conv2(x))) 274 | x = self.relu3(self.bn3(self.conv3(x))) 275 | x = self.avgpool(x) 276 | return x 277 | 278 | def forward(self, x): 279 | x = self.stem(x) 280 | x = self.layer1(x) 281 | x = self.layer2(x) 282 | x = self.layer3(x) 283 | x = self.layer4(x) 284 | x = self.attnpool(x) 285 | 286 | visual_output = dict.fromkeys(["image_features", "mim_loss"], None) 287 | visual_output.update({ 288 | 'image_features': x, 289 | }) 290 | 291 | return visual_output 292 | 293 | 294 | class LayerNorm(nn.LayerNorm): 295 | """Subclass torch's LayerNorm to handle fp16.""" 296 | 297 | def 
forward(self, x: torch.Tensor): 298 | orig_type = x.dtype 299 | x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) 300 | return x.to(orig_type) 301 | 302 | 303 | class QuickGELU(nn.Module): 304 | # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory 305 | def forward(self, x: torch.Tensor): 306 | return x * torch.sigmoid(1.702 * x) 307 | 308 | 309 | class ResidualAttentionBlock(nn.Module): 310 | def __init__( 311 | self, d_model: int, n_head: int, mlp_ratio: float = 4.0, act_layer: Callable = nn.GELU, 312 | drop_attention_rate: float = 0., 313 | ): 314 | super().__init__() 315 | 316 | self.attn = nn.MultiheadAttention( 317 | embed_dim=d_model, 318 | num_heads=n_head, 319 | dropout=drop_attention_rate, 320 | ) 321 | self.ln_1 = LayerNorm(d_model) 322 | mlp_width = int(d_model * mlp_ratio) 323 | self.mlp = nn.Sequential(OrderedDict([ 324 | ("c_fc", nn.Linear(d_model, mlp_width)), 325 | ("gelu", act_layer()), 326 | ("c_proj", nn.Linear(mlp_width, d_model)) 327 | ])) 328 | self.ln_2 = LayerNorm(d_model) 329 | 330 | def attention(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): 331 | return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0] 332 | 333 | def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): 334 | x = x + self.attention(self.ln_1(x), attn_mask=attn_mask) 335 | x = x + self.mlp(self.ln_2(x)) 336 | return x 337 | 338 | 339 | class PatchDropout(nn.Module): 340 | """ 341 | https://arxiv.org/abs/2212.00794 342 | """ 343 | 344 | def __init__(self, prob, exclude_first_token=True): 345 | super().__init__() 346 | assert 0 <= prob < 1. 347 | self.prob = prob 348 | self.exclude_first_token = exclude_first_token # exclude CLS token 349 | 350 | def forward(self, x): 351 | if not self.training or self.prob == 0.: 352 | return x 353 | 354 | if self.exclude_first_token: 355 | cls_tokens, x = x[:, :1], x[:, 1:] 356 | else: 357 | cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1]) 358 | 359 | batch = x.size()[0] 360 | num_tokens = x.size()[1] 361 | 362 | batch_indices = torch.arange(batch) 363 | batch_indices = batch_indices[..., None] 364 | 365 | keep_prob = 1 - self.prob 366 | num_patches_keep = max(1, int(num_tokens * keep_prob)) 367 | 368 | rand = torch.randn(batch, num_tokens) 369 | patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices 370 | 371 | x = x[batch_indices, patch_indices_keep] 372 | 373 | if self.exclude_first_token: 374 | x = torch.cat((cls_tokens, x), dim=1) 375 | 376 | return x 377 | 378 | 379 | class Transformer(nn.Module): 380 | def __init__( 381 | self, width: int, layers: int, heads: int, mlp_ratio: float = 4.0, act_layer: Callable = nn.GELU, 382 | drop_attention_rate: float = 0., 383 | ): 384 | super().__init__() 385 | self.width = width 386 | self.layers = layers 387 | self.grad_checkpointing = False 388 | 389 | self.resblocks = nn.ModuleList([ 390 | ResidualAttentionBlock(width, heads, mlp_ratio, act_layer=act_layer, drop_attention_rate=drop_attention_rate) 391 | for _ in range(layers) 392 | ]) 393 | 394 | def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): 395 | for r in self.resblocks: 396 | if self.grad_checkpointing and not torch.jit.is_scripting(): 397 | x = checkpoint(r, x, attn_mask) 398 | else: 399 | x = r(x, attn_mask=attn_mask) 400 | return x -------------------------------------------------------------------------------- /Quick_demo/Model/RadFM/helpers.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | Taken from https://github.com/lucidrains/flamingo-pytorch 3 | """ 4 | 5 | import torch 6 | from einops import rearrange, repeat 7 | from einops_exts import rearrange_many 8 | from torch import einsum, nn 9 | 10 | 11 | def exists(val): 12 | return val is not None 13 | 14 | 15 | def FeedForward(dim, mult=4): 16 | inner_dim = int(dim * mult) 17 | return nn.Sequential( 18 | nn.LayerNorm(dim), 19 | nn.Linear(dim, inner_dim, bias=False), 20 | nn.GELU(), 21 | nn.Linear(inner_dim, dim, bias=False), 22 | ) 23 | 24 | 25 | class PerceiverAttention(nn.Module): 26 | def __init__(self, *, dim, dim_head=64, heads=8): 27 | super().__init__() 28 | self.scale = dim_head**-0.5 29 | self.heads = heads 30 | inner_dim = dim_head * heads 31 | 32 | self.norm_media = nn.LayerNorm(dim) 33 | self.norm_latents = nn.LayerNorm(dim) 34 | 35 | self.to_q = nn.Linear(dim, inner_dim, bias=False) 36 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) 37 | self.to_out = nn.Linear(inner_dim, dim, bias=False) 38 | 39 | def forward(self, x, latents): 40 | """ 41 | Args: 42 | x (torch.Tensor): image features 43 | shape (b, T, n1, D) 44 | latent (torch.Tensor): latent features 45 | shape (b, T, n2, D) 46 | """ 47 | x = self.norm_media(x) 48 | latents = self.norm_latents(latents) 49 | 50 | h = self.heads 51 | 52 | q = self.to_q(latents) 53 | kv_input = torch.cat((x, latents), dim=-2) 54 | k, v = self.to_kv(kv_input).chunk(2, dim=-1) 55 | q, k, v = rearrange_many((q, k, v), "b t n (h d) -> b h t n d", h=h) 56 | q = q * self.scale 57 | 58 | # attention 59 | sim = einsum("... i d, ... j d -> ... i j", q, k) 60 | sim = sim - sim.amax(dim=-1, keepdim=True).detach() 61 | attn = sim.softmax(dim=-1) 62 | 63 | out = einsum("... i j, ... j d -> ... 
i d", attn, v) 64 | out = rearrange(out, "b h t n d -> b t n (h d)", h=h) 65 | return self.to_out(out) 66 | 67 | 68 | class PerceiverResampler(nn.Module): 69 | def __init__( 70 | self, 71 | *, 72 | dim, 73 | depth=6, 74 | dim_head=64, 75 | heads=8, 76 | num_latents=64, 77 | max_num_media=None, 78 | max_num_frames=None, 79 | ff_mult=4, 80 | ): 81 | super().__init__() 82 | self.latents = nn.Parameter(torch.randn(num_latents, dim)) 83 | self.frame_embs = ( 84 | nn.Parameter(torch.randn(max_num_frames, dim)) 85 | if exists(max_num_frames) 86 | else None 87 | ) 88 | self.media_time_embs = ( 89 | nn.Parameter(torch.randn(max_num_media, 1, dim)) 90 | if exists(max_num_media) 91 | else None 92 | ) 93 | 94 | self.layers = nn.ModuleList([]) 95 | for _ in range(depth): 96 | self.layers.append( 97 | nn.ModuleList( 98 | [ 99 | PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), 100 | FeedForward(dim=dim, mult=ff_mult), 101 | ] 102 | ) 103 | ) 104 | 105 | self.norm = nn.LayerNorm(dim) 106 | 107 | def forward(self, x): 108 | """ 109 | Args: 110 | x (torch.Tensor): image features 111 | shape (b, T, F, v, D) 112 | Returns: 113 | shape (b, T, n, D) where n is self.num_latents 114 | """ 115 | b, T, F, v = x.shape[:4] 116 | 117 | # frame and media time embeddings 118 | if exists(self.frame_embs): 119 | frame_embs = repeat(self.frame_embs[:F], "F d -> b T F v d", b=b, T=T, v=v) 120 | x = x + frame_embs 121 | x = rearrange( 122 | x, "b T F v d -> b T (F v) d" 123 | ) # flatten the frame and spatial dimensions 124 | if exists(self.media_time_embs): 125 | x = x + self.media_time_embs[:T] 126 | 127 | # blocks 128 | latents = repeat(self.latents, "n d -> b T n d", b=b, T=T) 129 | for attn, ff in self.layers: 130 | latents = attn(x, latents) + latents 131 | latents = ff(latents) + latents 132 | return self.norm(latents) 133 | 134 | 135 | # gated cross attention 136 | 137 | 138 | class MaskedCrossAttention(nn.Module): 139 | def __init__( 140 | self, 141 | *, 142 | dim, 143 | dim_visual, 144 | dim_head=64, 145 | heads=8, 146 | only_attend_immediate_media=True, 147 | ): 148 | super().__init__() 149 | self.scale = dim_head**-0.5 150 | self.heads = heads 151 | inner_dim = dim_head * heads 152 | 153 | self.norm = nn.LayerNorm(dim) 154 | 155 | self.to_q = nn.Linear(dim, inner_dim, bias=False) 156 | self.to_kv = nn.Linear(dim_visual, inner_dim * 2, bias=False) 157 | self.to_out = nn.Linear(inner_dim, dim, bias=False) 158 | 159 | # whether for text to only attend to immediate preceding image, or all previous images 160 | self.only_attend_immediate_media = only_attend_immediate_media 161 | 162 | def forward(self, x, media, media_locations=None, attend_previous=True): 163 | """ 164 | Args: 165 | x (torch.Tensor): text features 166 | shape (B, T_txt, D_txt) 167 | media (torch.Tensor): image features 168 | shape (B, T_img, n, D_img) where n is the dim of the latents 169 | media_locations: boolean mask identifying the media tokens in x 170 | shape (B, T_txt) 171 | attend_previous: bool 172 | If false, ignores immediately preceding image and starts attending when following image 173 | """ 174 | _, T_img, n = media.shape[:3] 175 | h = self.heads 176 | 177 | x = self.norm(x) 178 | 179 | q = self.to_q(x) 180 | media = rearrange(media, "b t n d -> b (t n) d") 181 | 182 | k, v = self.to_kv(media).chunk(2, dim=-1) 183 | q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=h) 184 | 185 | q = q * self.scale 186 | 187 | sim = einsum("... i d, ... j d -> ... 
i j", q, k) 188 | 189 | if exists(media_locations): 190 | # at each boolean of True, increment the time counter (relative to media time) 191 | text_time = media_locations.cumsum(dim=-1) 192 | media_time = torch.arange(T_img, device=x.device) + 1 193 | 194 | if not attend_previous: 195 | text_time[~media_locations] += 1 196 | # make sure max is still the number of images in the sequence 197 | text_time[ 198 | text_time 199 | > repeat( 200 | torch.count_nonzero(media_locations, dim=1), 201 | "b -> b i", 202 | i=text_time.shape[1], 203 | ) 204 | ] = 0 205 | 206 | # text time must equal media time if only attending to most immediate image 207 | # otherwise, as long as text time is greater than media time (if attending to all previous images / media) 208 | mask_op = torch.eq if self.only_attend_immediate_media else torch.ge 209 | 210 | text_to_media_mask = mask_op( 211 | rearrange(text_time, "b i -> b 1 i 1"), 212 | repeat(media_time, "j -> 1 1 1 (j n)", n=n), 213 | ) 214 | sim = sim.masked_fill(~text_to_media_mask, -torch.finfo(sim.dtype).max) 215 | 216 | sim = sim - sim.amax(dim=-1, keepdim=True).detach() 217 | attn = sim.softmax(dim=-1) 218 | 219 | if exists(media_locations) and self.only_attend_immediate_media: 220 | # any text without a preceding media needs to have attention zeroed out 221 | text_without_media_mask = text_time == 0 222 | text_without_media_mask = rearrange( 223 | text_without_media_mask, "b i -> b 1 i 1" 224 | ) 225 | attn = attn.masked_fill(text_without_media_mask, 0.0) 226 | 227 | out = einsum("... i j, ... j d -> ... i d", attn, v) 228 | out = rearrange(out, "b h n d -> b n (h d)") 229 | return self.to_out(out) 230 | 231 | 232 | class GatedCrossAttentionBlock(nn.Module): 233 | def __init__( 234 | self, 235 | *, 236 | dim, 237 | dim_visual, 238 | dim_head=64, 239 | heads=8, 240 | ff_mult=4, 241 | only_attend_immediate_media=True, 242 | ): 243 | super().__init__() 244 | self.attn = MaskedCrossAttention( 245 | dim=dim, 246 | dim_visual=dim_visual, 247 | dim_head=dim_head, 248 | heads=heads, 249 | only_attend_immediate_media=only_attend_immediate_media, 250 | ) 251 | self.attn_gate = nn.Parameter(torch.tensor([0.0])) 252 | 253 | self.ff = FeedForward(dim, mult=ff_mult) 254 | self.ff_gate = nn.Parameter(torch.tensor([0.0])) 255 | 256 | def forward( 257 | self, 258 | x, 259 | media, 260 | media_locations=None, 261 | attend_previous=True, 262 | ): 263 | x = ( 264 | self.attn( 265 | x, 266 | media, 267 | media_locations=media_locations, 268 | attend_previous=attend_previous, 269 | ) 270 | * self.attn_gate.tanh() 271 | + x 272 | ) 273 | x = self.ff(x) * self.ff_gate.tanh() + x 274 | 275 | return x 276 | -------------------------------------------------------------------------------- /Quick_demo/Model/RadFM/multimodality_model.py: -------------------------------------------------------------------------------- 1 | # Import necessary libraries 2 | from torch import nn 3 | from transformers.models.llama import LlamaForCausalLM 4 | from .my_embedding_layer import MyEmbedding 5 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss 6 | import tqdm.auto as tqdm 7 | import torch.nn as nn 8 | import torch 9 | from torch.utils.checkpoint import checkpoint 10 | from torch.autograd import Variable 11 | import numpy as np 12 | 13 | class MultiLLaMAForCausalLM(nn.Module): 14 | """ 15 | A multimodal LLaMA model that combines language and vision inputs 16 | for causal language modeling tasks. 
17 | """ 18 | def __init__(self, lang_model_path): 19 | """ 20 | Initialize the multimodal model. 21 | 22 | Args: 23 | lang_model_path (str): Path to the pretrained language model 24 | """ 25 | super(MultiLLaMAForCausalLM, self).__init__() 26 | 27 | # Load pretrained LLaMA model 28 | self.lang_model = LlamaForCausalLM.from_pretrained( 29 | lang_model_path, 30 | ) 31 | 32 | # Enable gradient checkpointing for memory efficiency 33 | self.lang_model.gradient_checkpointing_enable() 34 | self.lang_model.enable_input_require_grads() 35 | 36 | # Initialize custom embedding layer and share weights with language model 37 | self.embedding_layer = MyEmbedding() 38 | self.embedding_layer.weight = self.lang_model.get_input_embeddings().weight 39 | 40 | # Set model dimensions 41 | self.hidden_dim = 5120 42 | self.voc_size = 32000 43 | 44 | def forward(self, lang_x, vision_x, attention_mask, labels, loss_reweight, key_words_query): 45 | """ 46 | Forward pass for the multimodal model. 47 | 48 | Args: 49 | lang_x: Language input tokens 50 | vision_x: Vision input features 51 | attention_mask: Attention mask for language inputs 52 | labels: Target labels for language modeling 53 | loss_reweight: Weights for calculating loss (to prioritize certain tokens) 54 | key_words_query: Query for highlighting important words 55 | 56 | Returns: 57 | Dictionary containing model outputs including loss and logits 58 | """ 59 | if labels.shape == lang_x.shape: 60 | # Set embedding mode to handle text inputs 61 | self.embedding_layer.flag = 'Text' 62 | 63 | # Get embeddings and matching loss from embedding layer 64 | input_embedding, loss_match = self.embedding_layer(lang_x, vision_x, key_words_query) 65 | 66 | # Forward pass through the language model 67 | output = self.lang_model(inputs_embeds=input_embedding, attention_mask=attention_mask, labels=labels) 68 | logits = output['logits'] 69 | 70 | # Initialize regularization loss 71 | loss_reg = None 72 | if labels is not None: 73 | # Shift logits and labels for next-token prediction 74 | shift_logits = logits[..., :-1, :].contiguous() 75 | shift_labels = labels[..., 1:].contiguous() 76 | shift_loss_reweight = loss_reweight[..., 1:].contiguous() 77 | 78 | # Prepare for loss calculation 79 | loss_fct = CrossEntropyLoss(reduction='none') 80 | shift_logits = shift_logits.view(-1, self.voc_size) 81 | shift_labels = shift_labels.view(-1) 82 | shift_loss_reweight = shift_loss_reweight.view(-1) 83 | 84 | # Ensure tensors are on the same device 85 | shift_labels = shift_labels.to(shift_logits.device) 86 | shift_loss_reweight = shift_loss_reweight.to(shift_logits.device) 87 | 88 | # Calculate weighted cross-entropy loss 89 | loss_reg = loss_fct(shift_logits, shift_labels) 90 | loss_reg = torch.sum(shift_loss_reweight * loss_reg) / torch.sum(shift_loss_reweight) 91 | 92 | # Combine losses 93 | loss = loss_reg 94 | if loss_match is not None: 95 | loss = 0.8 * loss + 0.2 * loss_match 96 | 97 | # Calculate accuracy metrics 98 | logits = output['logits'][..., :-1, :].contiguous().detach() 99 | total = len(labels) 100 | predictions = torch.argmax(logits, dim=-1) 101 | labels = labels[..., 1:].contiguous() 102 | 103 | # Count correct predictions (ignoring padding tokens with -100) 104 | Acc = torch.sum(torch.all(torch.logical_or(predictions == labels, labels == -100), dim=-1)) 105 | Accuracy = Acc / total 106 | 107 | return dict( 108 | # loss_reg = loss_reg, 109 | # loss_matching = loss_matching, 110 | logits=Accuracy, 111 | loss=output['loss'], 112 | ) 113 | 114 | ### useless for now 
ignore the folowing codes ### 115 | # if labels.shape == vision_x.shape: 116 | # self.embedding_layer.flag = 'Seg' 117 | # input_embedding = self.embedding_layer(lang_x, vision_x) 118 | 119 | def generate(self, lang_x, vision_x): 120 | """ 121 | Generate text based on language and vision inputs. 122 | 123 | Args: 124 | lang_x: Language input tokens 125 | vision_x: Vision input features 126 | 127 | Returns: 128 | Generated token sequence 129 | """ 130 | # Set embedding mode to text generation 131 | self.embedding_layer.flag = 'Text' 132 | 133 | with torch.no_grad(): 134 | # Get embeddings from the embedding layer 135 | input_embedding, _ = self.embedding_layer(lang_x, vision_x) 136 | 137 | # Generate text using language model 138 | generation = self.lang_model.generate( 139 | inputs_embeds=input_embedding, 140 | max_new_tokens=200, 141 | top_k=50 142 | ) 143 | 144 | return generation -------------------------------------------------------------------------------- /Quick_demo/Model/RadFM/my_embedding_layer.py: -------------------------------------------------------------------------------- 1 | # Import necessary libraries 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch 5 | from .helpers import PerceiverResampler 6 | from .utils import get_visual_encoder 7 | from einops import rearrange, repeat 8 | from einops_exts import rearrange_many 9 | import torchvision 10 | from .vit_3d import ViT 11 | from einops.layers.torch import Rearrange 12 | from .transformer_decoder import TransformerDecoder, TransformerDecoderLayer 13 | from torch.utils.checkpoint import checkpoint 14 | from torch.autograd import Variable 15 | import random 16 | from transformers import AutoTokenizer, AutoModel 17 | 18 | class MyEmbedding(nn.Module): 19 | """ 20 | Custom embedding layer for multimodal inputs that combines text and vision features. 21 | """ 22 | def __init__(self, num_embeddings=32000, embedding_dim=5120, perceiver_num=32, vis_dim=768, 23 | patch_size=32, frame_patch_size=4, seg_channel=256): 24 | """ 25 | Initialize the multimodal embedding layer. 
26 | 27 | Args: 28 | num_embeddings (int): Size of vocabulary for text embeddings 29 | embedding_dim (int): Dimension of output embeddings 30 | perceiver_num (int): Number of latent vectors in perceiver 31 | vis_dim (int): Dimension of vision features 32 | patch_size (int): Size of image patches 33 | frame_patch_size (int): Size of 3D frame patches 34 | seg_channel (int): Number of segmentation channels 35 | """ 36 | super().__init__() 37 | self.num_embeddings = num_embeddings 38 | self.embedding_dim = embedding_dim 39 | # Standard embedding weight matrix for text tokens 40 | self.weight = nn.Parameter(torch.torch.randn((num_embeddings, embedding_dim))) 41 | # Special token weights for figures/images 42 | self.figure_token_weight = nn.Parameter(torch.randn((2, embedding_dim))) 43 | self.flag = 'Text' # Mode flag: 'Text' or 'Seg' 44 | self.patch_size = patch_size 45 | self.frame_patch_size = frame_patch_size 46 | self.seg_channel = seg_channel 47 | 48 | ## the MedKEBERT can be downloaded from https://huggingface.co/xmcmic/Med-KEBERT/tree/main ## 49 | # Initialize medical domain BERT model for keyword understanding 50 | self.bert_tokenizer = AutoTokenizer.from_pretrained("xmcmic/Med-KEBERT") 51 | self.bert_model = AutoModel.from_pretrained("xmcmic/Med-KEBERT") 52 | # Project BERT outputs to vision feature space 53 | self.bert_projection_fc = nn.Linear(768, vis_dim) 54 | 55 | # 3D Vision Transformer for processing volumetric medical images 56 | self.vision_encoder = ViT( 57 | image_size=512, # image size 58 | frames=512, # max number of frames 59 | image_patch_size=patch_size, # image patch size 60 | frame_patch_size=frame_patch_size, # frame patch size 61 | dim=vis_dim, 62 | depth=12, 63 | heads=8, 64 | mlp_dim=2048, 65 | dropout=0.1, 66 | emb_dropout=0.1 67 | ) 68 | 69 | # Upscaling layers for vision features (used in segmentation mode) 70 | self.output_upscaling = nn.Sequential( 71 | nn.ConvTranspose3d(vis_dim, vis_dim // 4, kernel_size=2, stride=2), 72 | nn.BatchNorm3d(vis_dim // 4), 73 | nn.GELU(), 74 | nn.ConvTranspose3d(vis_dim // 4, vis_dim // 8, kernel_size=2, stride=2), 75 | nn.GELU(), 76 | ) 77 | 78 | # Transformer decoder for cross-attention between text and vision 79 | decoder_layer = TransformerDecoderLayer(d_model=vis_dim, nhead=8, normalize_before=True) 80 | decoder_norm = nn.LayerNorm(vis_dim) 81 | self.transformer_decoder = TransformerDecoder(decoder_layer=decoder_layer, num_layers=4, norm=decoder_norm) 82 | 83 | # MLP for processing transformer decoder outputs 84 | self.transformer_decoder_mlp = nn.Sequential( 85 | nn.Linear(vis_dim, vis_dim // 4), 86 | nn.GELU(), 87 | nn.Linear(vis_dim // 4, vis_dim // 8), 88 | nn.GELU(), 89 | ) 90 | self.vis_dim = vis_dim 91 | 92 | # Perceiver resampler to reduce sequence length of vision features 93 | self.perceiver = PerceiverResampler(dim=self.vis_dim, num_latents=perceiver_num) 94 | # Final projection to embedding dimension 95 | self.fc = nn.Linear(self.vis_dim, self.embedding_dim) 96 | # Classification head for matching keywords 97 | self.cls_head = nn.Linear(self.vis_dim // 8, 1) 98 | 99 | 100 | def forward(self, text_input, vision_x, key_words_query=None): 101 | """ 102 | Forward pass for the embedding layer. 
103 | 104 | Args: 105 | text_input: Text token indices [B, L] 106 | vision_x: Visual input features [B, S, C, H, W, D] 107 | key_words_query: Optional list of key words for contrastive learning 108 | 109 | Returns: 110 | tuple: (output_embeddings, loss_matching) 111 | - output_embeddings: Combined embeddings for text and vision 112 | - loss_matching: Contrastive loss for keyword matching (or None) 113 | """ 114 | if self.flag == 'Text': 115 | # Process in text mode 116 | B, S, C, H, W, D = vision_x.shape 117 | # Reshape for batch processing 118 | vision_x = rearrange(vision_x, "b S c h w d-> (b S) c h w d") 119 | 120 | # Process through vision encoder 121 | vision_x, pos_embedding = self.vision_encoder(vision_x) 122 | 123 | # Reshape back to batch format 124 | vision_x = rearrange(vision_x, "(b s F) v d -> b s F v d", b=B, s=S, F=1) 125 | 126 | loss_matching = None 127 | 128 | if key_words_query is not None: 129 | ## we do not use the following parts in final version. 130 | ## You can quota the following codes and if so the bert models will be useless. 131 | # key_words_query list[list[str]] B, words, each word matches corresponding vision_x embedding 132 | 133 | # Extract and deduplicate keywords 134 | query_words = [item for sublist in key_words_query for item in sublist] 135 | query_words = list(set(query_words)) 136 | 137 | # Limit number of keywords to process 138 | if len(query_words) > 16: 139 | random.shuffle(query_words) 140 | query_words = query_words[0:16] 141 | 142 | if query_words != []: 143 | # Create binary labels for contrastive learning 144 | contrastive_labels = torch.zeros(B, len(query_words)) # B Q 145 | for i, sublist in enumerate(key_words_query): 146 | for j, item in enumerate(query_words): 147 | if item in sublist: 148 | contrastive_labels[i, j] = 1 149 | contrastive_labels = contrastive_labels.to(vision_x.dtype).to(vision_x.device) 150 | 151 | # Get BERT embeddings for keywords 152 | with torch.no_grad(): 153 | query_words_embedding = self.bert_tokenizer( 154 | query_words, 155 | padding='max_length', 156 | truncation=True, 157 | max_length=256, 158 | return_tensors="pt" 159 | ) 160 | query_words_embedding = self.bert_model( 161 | input_ids=query_words_embedding['input_ids'].to(vision_x.device), 162 | attention_mask=query_words_embedding['attention_mask'].to(vision_x.device) 163 | )['last_hidden_state'][:, 0, :].to(vision_x.dtype).to(vision_x.device) # Q,D 164 | 165 | # Project BERT embeddings to vision space 166 | query_words_embedding = self.bert_projection_fc(query_words_embedding) 167 | query_words_embedding = query_words_embedding.unsqueeze(0).repeat(B, 1, 1) # B,Q,D 168 | _, N, _ = query_words_embedding.shape 169 | 170 | # Pool vision features 171 | image_embedding = vision_x.mean(dim=1) # B V D average pooling to remove multimodality 172 | image_embedding = rearrange(image_embedding, "b F v d -> b (F v) d") 173 | pos_embedding = rearrange(pos_embedding, "(b s) v d -> b s v d", b=B, s=S)[:, 0, :, :] 174 | 175 | # Prepare inputs for transformer decoder 176 | image_embedding = image_embedding.transpose(0, 1) # (H/P W/P D/P) B D 177 | pos_embedding = pos_embedding.transpose(0, 1) # (H/P W/P D/P) B D 178 | query_words_embedding = query_words_embedding.transpose(0, 1) # N B D 179 | 180 | # Cross-attention between keywords and image features 181 | oo_embedding, _ = self.transformer_decoder( 182 | query_words_embedding, image_embedding, pos=pos_embedding 183 | ) 184 | oo_embedding = oo_embedding.transpose(0, 1) # B Q D 185 | oo_embedding = rearrange(oo_embedding, 'b 
n d -> (b n) d') 186 | oo_embedding = self.transformer_decoder_mlp(oo_embedding) 187 | oo_embedding = self.cls_head(oo_embedding).mean(dim=-1) 188 | oo_embedding = rearrange(oo_embedding, '(b n) -> b n', b=B, n=N) # B Q 189 | 190 | # Calculate contrastive loss 191 | loss_matching = F.binary_cross_entropy_with_logits(oo_embedding, contrastive_labels) 192 | 193 | # Process vision features through perceiver resampler 194 | vision_x = self.perceiver(vision_x) # reshapes to (b, S, n, d) 195 | 196 | n = vision_x.shape[2] 197 | 198 | # Project vision features to embedding dimension 199 | vision_x = rearrange(vision_x, "b s n d -> (b s n) d") 200 | vision_x = self.fc(vision_x) 201 | vision_x = rearrange(vision_x, "(b T) d -> b T d", b=B, T=n*S) 202 | 203 | # Combine text and vision embeddings 204 | embedding_weight = torch.cat([self.weight, self.figure_token_weight], dim=0) 205 | embedding_weight = embedding_weight.unsqueeze(0).repeat(B, 1, 1) 206 | embedding_weight = torch.cat([embedding_weight, vision_x], dim=1) 207 | 208 | # Convert text indices to one-hot and compute final embeddings 209 | text_input = F.one_hot(text_input, embedding_weight.shape[1]).to(vision_x.dtype).to(vision_x.device) 210 | out_put = torch.matmul(text_input, embedding_weight) 211 | 212 | ## useless for now. ignore the folowing code## 213 | # if self.flag == 'Seg': 214 | # B,C,H,W,D = vision_x.shape 215 | # _,N,_ = text_input.shape 216 | # latent_embedding, pos_embedding = self.vision_encoder(vision_x) # B (H/P W/P D/P) D 217 | 218 | # image_embedding = latent_embedding.transpose(0,1) # (H/P W/P D/P) B D 219 | # pos_embedding = pos_embedding.transpose(0,1) # (H/P W/P D/P) B D 220 | # text_input = text_input.transpose(0,1) # N B D 221 | 222 | # mask_embedding,_ = self.transformer_decoder(text_input, image_embedding, pos = pos_embedding) 223 | # mask_embedding = mask_embedding.transpose(0,1) # B N D 224 | # mask_embedding = rearrange(mask_embedding, 'b n d -> (b n) d') 225 | # mask_embedding = self.transformer_decoder_mlp(mask_embedding) 226 | # mask_embedding = rearrange(mask_embedding, '(b n) d -> b n d', b=B, n=N,d = self.vis_dim // 8) 227 | 228 | # vision_x = rearrange(latent_embedding,'b (h w d) c -> b c h w d', h = (H // self.patch_size), w = (W // self.patch_size), d = (D // self.frame_patch_size), c=self.vis_dim) 229 | # vision_x = self.output_upscaling(vision_x) #B C H/4 W/4 D/4 230 | # out_put = torch.einsum('bchwd,bnc->bnhwd', vision_x, mask_embedding) 231 | 232 | return out_put, loss_matching 233 | 234 | # model = MyEmbedding(vision_encoder_path = '') 235 | # text_input = torch.randint(low=0, high=3210, size=(4,2048)) 236 | # image_input = torch.randn((4,3,3,512,512,4)) 237 | # key_words_query = [[],[],[],['consoliation']] 238 | # print(model(text_input, image_input, key_words_query)) -------------------------------------------------------------------------------- /Quick_demo/Model/RadFM/position_encoding.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | Various positional encodings for the transformer. 4 | """ 5 | import math 6 | import torch 7 | from torch import nn 8 | from einops.layers.torch import Rearrange 9 | from einops import rearrange, repeat 10 | 11 | class PositionEmbeddingSine(nn.Module): 12 | """ 13 | This is a more standard version of the position embedding, very similar to the one 14 | used by the Attention is all you need paper, generalized to work on images. 
15 | """ 16 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): 17 | super().__init__() 18 | self.num_pos_feats = num_pos_feats 19 | self.temperature = temperature 20 | self.normalize = normalize 21 | if scale is not None and normalize is False: 22 | raise ValueError("normalize should be True if scale is passed") 23 | if scale is None: 24 | scale = 2 * math.pi 25 | self.scale = scale 26 | 27 | def forward(self, tensor_list): 28 | x = tensor_list.tensors 29 | mask = tensor_list.mask 30 | assert mask is not None 31 | not_mask = ~mask 32 | y_embed = not_mask.cumsum(1, dtype=torch.float32) 33 | x_embed = not_mask.cumsum(2, dtype=torch.float32) 34 | if self.normalize: 35 | eps = 1e-6 36 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale 37 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale 38 | 39 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) 40 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 41 | 42 | pos_x = x_embed[:, :, :, None] / dim_t 43 | pos_y = y_embed[:, :, :, None] / dim_t 44 | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) 45 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) 46 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 47 | return pos 48 | 49 | 50 | class PositionEmbeddingLearned(nn.Module): 51 | """ 52 | Absolute pos embedding, learned. 53 | """ 54 | def __init__(self, num_pos_feats=256): 55 | super().__init__() 56 | self.row_embed = nn.Embedding(50, num_pos_feats) 57 | self.col_embed = nn.Embedding(50, num_pos_feats) 58 | self.reset_parameters() 59 | 60 | def reset_parameters(self): 61 | nn.init.uniform_(self.row_embed.weight) 62 | nn.init.uniform_(self.col_embed.weight) 63 | 64 | def forward(self, tensor_list): 65 | x = tensor_list.tensors 66 | h, w = x.shape[-2:] 67 | i = torch.arange(w, device=x.device) 68 | j = torch.arange(h, device=x.device) 69 | x_emb = self.col_embed(i) 70 | y_emb = self.row_embed(j) 71 | pos = torch.cat([ 72 | x_emb.unsqueeze(0).repeat(h, 1, 1), 73 | y_emb.unsqueeze(1).repeat(1, w, 1), 74 | ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1) 75 | return pos 76 | 77 | class PositionEmbeddingLearned3d(nn.Module): 78 | """ 79 | Absolute pos embedding, learned. 
80 | """ 81 | def __init__(self, num_pos_feats=256,h_patch_num = 16, w_patch_num = 16,d_patch_num = 64): 82 | super().__init__() 83 | self.h_patch_num = h_patch_num 84 | self.w_patch_num = w_patch_num 85 | self.d_patch_num = d_patch_num 86 | self.row_embed = nn.Embedding(h_patch_num, num_pos_feats) 87 | self.col_embed = nn.Embedding(w_patch_num, num_pos_feats) 88 | self.dep_embed = nn.Embedding(d_patch_num, num_pos_feats) 89 | self.reset_parameters() 90 | 91 | def reset_parameters(self): 92 | nn.init.uniform_(self.row_embed.weight) 93 | nn.init.uniform_(self.col_embed.weight) 94 | nn.init.uniform_(self.dep_embed.weight) 95 | 96 | def forward(self, B, h, w, d,x): 97 | i = (torch.arange(h, device=x.device) + 1)* (self.h_patch_num // h) -1 98 | j = (torch.arange(w, device=x.device) + 1)* (self.w_patch_num // w) -1 99 | k = (torch.arange(d, device=x.device) + 1)* (self.d_patch_num // d) -1 100 | x_emb = self.row_embed(i).unsqueeze(1).unsqueeze(2).repeat(1,w,d,1) 101 | y_emb = self.col_embed(j).unsqueeze(0).unsqueeze(2).repeat(h,1,d,1) 102 | z_emb = self.dep_embed(k).unsqueeze(0).unsqueeze(1).repeat(h,w,1,1) 103 | pos = torch.cat([x_emb,y_emb,z_emb,], dim=-1).unsqueeze(0).repeat(B, 1, 1, 1, 1) 104 | pos = rearrange(pos,'b h w d c -> b (h w d) c') 105 | return pos 106 | 107 | def build_position_encoding(args): 108 | N_steps = args.hidden_dim // 2 109 | if args.position_embedding in ('v2', 'sine'): 110 | # TODO find a better way of exposing other arguments 111 | position_embedding = PositionEmbeddingSine(N_steps, normalize=True) 112 | elif args.position_embedding in ('v3', 'learned'): 113 | position_embedding = PositionEmbeddingLearned(N_steps) 114 | else: 115 | raise ValueError(f"not supported {args.position_embedding}") 116 | 117 | return position_embedding 118 | 119 | # Pos = PositionEmbeddingLearned3d() 120 | # x = torch.randn((8,3,32,32,1)) 121 | # print(Pos(8,16,16,1,x)) -------------------------------------------------------------------------------- /Quick_demo/Model/RadFM/transformer_decoder.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code modified from DETR tranformer: 3 | https://github.com/facebookresearch/detr 4 | Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 5 | """ 6 | 7 | import copy 8 | from typing import Optional, List 9 | import pickle as cp 10 | 11 | import torch 12 | import torch.nn.functional as F 13 | from torch import nn, Tensor 14 | 15 | 16 | class TransformerDecoder(nn.Module): 17 | 18 | def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False): 19 | super().__init__() 20 | self.layers = _get_clones(decoder_layer, num_layers) 21 | self.num_layers = num_layers 22 | self.norm = norm 23 | self.return_intermediate = return_intermediate 24 | 25 | def forward(self, tgt, memory, 26 | tgt_mask: Optional[Tensor] = None, 27 | memory_mask: Optional[Tensor] = None, 28 | tgt_key_padding_mask: Optional[Tensor] = None, 29 | memory_key_padding_mask: Optional[Tensor] = None, 30 | pos: Optional[Tensor] = None, 31 | query_pos: Optional[Tensor] = None): 32 | output = tgt 33 | T,B,C = memory.shape 34 | intermediate = [] 35 | atten_layers = [] 36 | for n,layer in enumerate(self.layers): 37 | 38 | residual=True 39 | output,ws = layer(output, memory, tgt_mask=tgt_mask, 40 | memory_mask=memory_mask, 41 | tgt_key_padding_mask=tgt_key_padding_mask, 42 | memory_key_padding_mask=memory_key_padding_mask, 43 | pos=pos, query_pos=query_pos,residual=residual) 44 | atten_layers.append(ws) 45 | if self.return_intermediate: 46 | intermediate.append(self.norm(output)) 47 | if self.norm is not None: 48 | output = self.norm(output) 49 | if self.return_intermediate: 50 | intermediate.pop() 51 | intermediate.append(output) 52 | 53 | if self.return_intermediate: 54 | return torch.stack(intermediate) 55 | return output,atten_layers 56 | 57 | 58 | 59 | class TransformerDecoderLayer(nn.Module): 60 | 61 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, 62 | activation="relu", normalize_before=False): 63 | super().__init__() 64 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 65 | self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 66 | # Implementation of Feedforward model 67 | self.linear1 = nn.Linear(d_model, dim_feedforward) 68 | self.dropout = nn.Dropout(dropout) 69 | self.linear2 = nn.Linear(dim_feedforward, d_model) 70 | 71 | self.norm1 = nn.LayerNorm(d_model) 72 | self.norm2 = nn.LayerNorm(d_model) 73 | self.norm3 = nn.LayerNorm(d_model) 74 | self.dropout1 = nn.Dropout(dropout) 75 | self.dropout2 = nn.Dropout(dropout) 76 | self.dropout3 = nn.Dropout(dropout) 77 | 78 | self.activation = _get_activation_fn(activation) 79 | self.normalize_before = normalize_before 80 | def with_pos_embed(self, tensor, pos: Optional[Tensor]): 81 | return tensor if pos is None else tensor + pos 82 | 83 | def forward_post(self, tgt, memory, 84 | tgt_mask: Optional[Tensor] = None, 85 | memory_mask: Optional[Tensor] = None, 86 | tgt_key_padding_mask: Optional[Tensor] = None, 87 | memory_key_padding_mask: Optional[Tensor] = None, 88 | pos: Optional[Tensor] = None, 89 | query_pos: Optional[Tensor] = None, 90 | residual=True): 91 | q = k = self.with_pos_embed(tgt, query_pos) 92 | tgt2,ws = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask, 93 | key_padding_mask=tgt_key_padding_mask) 94 | tgt = self.norm1(tgt) 95 | tgt2,ws = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos), 96 | key=self.with_pos_embed(memory, pos), 97 | value=memory, attn_mask=memory_mask, 98 | key_padding_mask=memory_key_padding_mask) 99 | 100 | 101 | # attn_weights [B,NUM_Q,T] 102 | tgt = tgt + self.dropout2(tgt2) 103 | tgt = self.norm2(tgt) 104 | tgt2 = 
self.linear2(self.dropout(self.activation(self.linear1(tgt)))) 105 | tgt = tgt + self.dropout3(tgt2) 106 | tgt = self.norm3(tgt) 107 | return tgt,ws 108 | 109 | def forward_pre(self, tgt, memory, 110 | tgt_mask: Optional[Tensor] = None, 111 | memory_mask: Optional[Tensor] = None, 112 | tgt_key_padding_mask: Optional[Tensor] = None, 113 | memory_key_padding_mask: Optional[Tensor] = None, 114 | pos: Optional[Tensor] = None, 115 | query_pos: Optional[Tensor] = None): 116 | tgt2 = self.norm1(tgt) 117 | q = k = self.with_pos_embed(tgt2, query_pos) 118 | tgt2,ws = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, 119 | key_padding_mask=tgt_key_padding_mask) 120 | tgt = tgt + self.dropout1(tgt2) 121 | tgt2 = self.norm2(tgt) 122 | tgt2,attn_weights = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos), 123 | key=self.with_pos_embed(memory, pos), 124 | value=memory, attn_mask=memory_mask, 125 | key_padding_mask=memory_key_padding_mask) 126 | tgt = tgt + self.dropout2(tgt2) 127 | tgt2 = self.norm3(tgt) 128 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) 129 | tgt = tgt + self.dropout3(tgt2) 130 | return tgt,attn_weights 131 | 132 | def forward(self, tgt, memory, 133 | tgt_mask: Optional[Tensor] = None, 134 | memory_mask: Optional[Tensor] = None, 135 | tgt_key_padding_mask: Optional[Tensor] = None, 136 | memory_key_padding_mask: Optional[Tensor] = None, 137 | pos: Optional[Tensor] = None, 138 | query_pos: Optional[Tensor] = None, 139 | residual=True): 140 | if self.normalize_before: 141 | return self.forward_pre(tgt, memory, tgt_mask, memory_mask, 142 | tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) 143 | return self.forward_post(tgt, memory, tgt_mask, memory_mask, 144 | tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos,residual) 145 | 146 | 147 | def _get_clones(module, N): 148 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 149 | 150 | 151 | 152 | def _get_activation_fn(activation): 153 | """Return an activation function given a string""" 154 | if activation == "relu": 155 | return F.relu 156 | if activation == "gelu": 157 | return F.gelu 158 | if activation == "glu": 159 | return F.glu 160 | raise RuntimeError(F"activation should be relu/gelu, not {activation}.") 161 | -------------------------------------------------------------------------------- /Quick_demo/Model/RadFM/utils.py: -------------------------------------------------------------------------------- 1 | from .blocks import ModifiedResNet,PMC_CLIP_cfg 2 | import torch 3 | from torchvision import transforms 4 | from PIL import Image 5 | import torch.nn as nn 6 | def extend_instance(obj, mixin): 7 | """Apply mixins to a class instance after creation""" 8 | base_cls = obj.__class__ 9 | base_cls_name = obj.__class__.__name__ 10 | obj.__class__ = type( 11 | base_cls_name, (mixin, base_cls), {} 12 | ) # mixin needs to go first for our forward() logic to work 13 | 14 | 15 | def getattr_recursive(obj, att): 16 | """ 17 | Return nested attribute of obj 18 | Example: getattr_recursive(obj, 'a.b.c') is equivalent to obj.a.b.c 19 | """ 20 | if att == "": 21 | return obj 22 | i = att.find(".") 23 | if i < 0: 24 | return getattr(obj, att) 25 | else: 26 | return getattr_recursive(getattr(obj, att[:i]), att[i + 1 :]) 27 | 28 | 29 | def setattr_recursive(obj, att, val): 30 | """ 31 | Set nested attribute of obj 32 | Example: setattr_recursive(obj, 'a.b.c', val) is equivalent to obj.a.b.c = val 33 | """ 34 | if "." 
in att: 35 | obj = getattr_recursive(obj, ".".join(att.split(".")[:-1])) 36 | setattr(obj, att.split(".")[-1], val) 37 | 38 | 39 | 40 | def get_visual_encoder(model_str): 41 | """ 42 | Args: 43 | str (_type_): str_to_model_path 44 | Return: 45 | vision_model, visual_dim, img_preprocessor 46 | """ 47 | normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 48 | img_preprocessor = transforms.Compose([ 49 | transforms.Resize((512,512), interpolation=Image.BICUBIC), 50 | transforms.ToTensor(), 51 | normalize, 52 | ]) 53 | if 'PMC-CLIP' in model_str: 54 | #vision_cfg = json.load(open(model_args.visual_model_config,'r'))['vision_cfg'] 55 | vision_cfg = PMC_CLIP_cfg() 56 | vision_heads = vision_cfg.width * 32 // vision_cfg.head_width 57 | vision_model = ModifiedResNet( 58 | layers=vision_cfg.layers, 59 | heads=vision_heads, 60 | output_dim = 768, 61 | image_size=vision_cfg.image_size, 62 | width=vision_cfg.width 63 | ) 64 | vision_model = vision_load_pretrain(vision_model,model_str) 65 | vision_model = nn.Sequential(*list(vision_model.children())[:-2]) 66 | visual_dim = 1024 67 | return vision_model,visual_dim,img_preprocessor 68 | 69 | def vision_load_pretrain(resnet,model_path): 70 | checkpoint = torch.load(model_path, map_location='cpu') 71 | state_dict = checkpoint['state_dict'] 72 | state_dict = {k.replace('module.visual.',''): v for k, v in state_dict.items() if '.visual' in k} 73 | resnet.load_state_dict(state_dict) 74 | return resnet 75 | -------------------------------------------------------------------------------- /Quick_demo/Model/RadFM/vit_3d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from einops import rearrange, repeat 5 | from einops.layers.torch import Rearrange 6 | from .position_encoding import PositionEmbeddingLearned3d 7 | 8 | # helpers 9 | 10 | def pair(t): 11 | return t if isinstance(t, tuple) else (t, t) 12 | 13 | # classes 14 | 15 | class PreNorm(nn.Module): 16 | def __init__(self, dim, fn): 17 | super().__init__() 18 | self.norm = nn.LayerNorm(dim) 19 | self.fn = fn 20 | def forward(self, x, **kwargs): 21 | return self.fn(self.norm(x), **kwargs) 22 | 23 | class FeedForward(nn.Module): 24 | def __init__(self, dim, hidden_dim, dropout = 0.): 25 | super().__init__() 26 | self.net = nn.Sequential( 27 | nn.Linear(dim, hidden_dim), 28 | nn.GELU(), 29 | nn.Dropout(dropout), 30 | nn.Linear(hidden_dim, dim), 31 | nn.Dropout(dropout) 32 | ) 33 | def forward(self, x): 34 | return self.net(x) 35 | 36 | class Attention(nn.Module): 37 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): 38 | super().__init__() 39 | inner_dim = dim_head * heads 40 | project_out = not (heads == 1 and dim_head == dim) 41 | 42 | self.heads = heads 43 | self.scale = dim_head ** -0.5 44 | 45 | self.attend = nn.Softmax(dim = -1) 46 | self.dropout = nn.Dropout(dropout) 47 | 48 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) 49 | 50 | self.to_out = nn.Sequential( 51 | nn.Linear(inner_dim, dim), 52 | nn.Dropout(dropout) 53 | ) if project_out else nn.Identity() 54 | 55 | def forward(self, x): 56 | qkv = self.to_qkv(x).chunk(3, dim = -1) 57 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) 58 | 59 | dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale 60 | 61 | attn = self.attend(dots) 62 | attn = self.dropout(attn) 63 | 64 | out = torch.matmul(attn, v) 65 | out = rearrange(out, 'b h n d -> b n (h d)') 66 | return self.to_out(out) 67 
| 68 | class Transformer(nn.Module): 69 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): 70 | super().__init__() 71 | self.layers = nn.ModuleList([]) 72 | for _ in range(depth): 73 | self.layers.append(nn.ModuleList([ 74 | PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), 75 | PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)) 76 | ])) 77 | def forward(self, x): 78 | for attn, ff in self.layers: 79 | x = attn(x) + x 80 | x = ff(x) + x 81 | return x 82 | 83 | class ViT(nn.Module): 84 | def __init__(self, *, image_size, image_patch_size, frames, frame_patch_size, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.): 85 | super().__init__() 86 | image_height, image_width = pair(image_size) 87 | patch_height, patch_width = pair(image_patch_size) 88 | 89 | assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.' 90 | assert frames % frame_patch_size == 0, 'Frames must be divisible by frame patch size' 91 | 92 | self.patch_height = patch_height 93 | self.patch_width = patch_width 94 | self.frame_patch_size = frame_patch_size 95 | 96 | num_patches = (image_height // patch_height) * (image_width // patch_width) * (frames // frame_patch_size) 97 | patch_dim = channels * patch_height * patch_width * frame_patch_size 98 | 99 | assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)' 100 | 101 | self.to_patch_embedding = nn.Sequential( 102 | Rearrange('b c (h p1) (w p2) (f pf) -> b (h w f) (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size), 103 | nn.LayerNorm(patch_dim), 104 | nn.Linear(patch_dim, dim), 105 | nn.LayerNorm(dim), 106 | ) 107 | 108 | self.pos_embedding = PositionEmbeddingLearned3d(dim // 3,(image_height // patch_height), (image_width // patch_width), (frames // frame_patch_size)) 109 | self.dropout = nn.Dropout(emb_dropout) 110 | 111 | self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout) 112 | 113 | def forward(self, video): 114 | B, C, H, W, D = video.shape 115 | x = self.to_patch_embedding(video) 116 | b, n, _ = x.shape 117 | 118 | pos = self.pos_embedding(B, H // self.patch_height, W // self.patch_width, D // self.frame_patch_size,x) 119 | x += pos 120 | x = self.dropout(x) 121 | 122 | x = self.transformer(x) 123 | return x,pos 124 | -------------------------------------------------------------------------------- /Quick_demo/test.py: -------------------------------------------------------------------------------- 1 | # Import necessary libraries for data processing, model loading, and inference 2 | import tqdm.auto as tqdm 3 | import torch.nn.functional as F 4 | from typing import Optional, Dict, Sequence 5 | from typing import List, Optional, Tuple, Union 6 | import transformers 7 | from dataclasses import dataclass, field 8 | from Model.RadFM.multimodality_model import MultiLLaMAForCausalLM 9 | import torch 10 | from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer 11 | from torchvision import transforms 12 | from PIL import Image 13 | 14 | def get_tokenizer(tokenizer_path, max_img_size=100, image_num=32): 15 | ''' 16 | Initialize the tokenizer with special tokens for image handling 17 | 18 | Args: 19 | tokenizer_path: Path to the base tokenizer 20 | max_img_size: Maximum number of images supported in a prompt 21 | image_num: Number of token embeddings per image 22 | 23 | Returns: 
24 | Tuple of (tokenizer, image_padding_tokens) 25 | ''' 26 | if isinstance(tokenizer_path, str): 27 | image_padding_tokens = [] 28 | # Load the base tokenizer from the provided path 29 | text_tokenizer = LlamaTokenizer.from_pretrained( 30 | tokenizer_path, 31 | ) 32 | # Define initial special tokens for image markup 33 | special_token = {"additional_special_tokens": ["<image>", "</image>"]} 34 | 35 | # Generate unique tokens for each image position and patch 36 | for i in range(max_img_size): 37 | image_padding_token = "" 38 | 39 | for j in range(image_num): 40 | image_token = "<image" + str(i * image_num + j) + ">" 41 | image_padding_token = image_padding_token + image_token 42 | special_token["additional_special_tokens"].append("<image" + str(i * image_num + j) + ">") 43 | 44 | # Store the concatenated tokens for each image 45 | image_padding_tokens.append(image_padding_token) 46 | 47 | # Add all special tokens to the tokenizer 48 | text_tokenizer.add_special_tokens( 49 | special_token 50 | ) 51 | 52 | # Configure standard special tokens for LLaMA models 53 | text_tokenizer.pad_token_id = 0 54 | text_tokenizer.bos_token_id = 1 55 | text_tokenizer.eos_token_id = 2 56 | 57 | return text_tokenizer, image_padding_tokens 58 | 59 | def combine_and_preprocess(question, image_list, image_padding_tokens): 60 | ''' 61 | Combine text and images into a multimodal input format 62 | 63 | Args: 64 | question: Text input or question to process 65 | image_list: List of images with their metadata 66 | image_padding_tokens: Special tokens for image placeholders 67 | 68 | Returns: 69 | Tuple of (processed_text, processed_images_tensor) 70 | ''' 71 | # Define image transformation pipeline 72 | transform = transforms.Compose([ 73 | transforms.RandomResizedCrop([512, 512], scale=(0.8, 1.0), interpolation=transforms.InterpolationMode.BICUBIC), 74 | transforms.ToTensor(), 75 | ]) 76 | 77 | images = [] 78 | new_qestions = [_ for _ in question] # Convert question string to list of characters 79 | padding_index = 0 80 | 81 | # Process each image in the list 82 | for img in image_list: 83 | img_path = img['img_path'] 84 | position = img['position'] # Where to insert the image in the text 85 | 86 | # Load and transform the image 87 | image = Image.open(img_path).convert('RGB') 88 | image = transform(image) 89 | image = image.unsqueeze(0).unsqueeze(-1) # Add batch and depth dimensions (c,w,h,d) 90 | 91 | # Resize the image to target dimensions 92 | target_H = 512 93 | target_W = 512 94 | target_D = 4 95 | # This can be different for 3D and 2D images. For demonstration we here set this as the default sizes for 2D images.
96 | images.append(torch.nn.functional.interpolate(image, size=(target_H, target_W, target_D))) 97 | 98 | # Insert image placeholder token at the specified position in text 99 | new_qestions[position] = "<image>" + image_padding_tokens[padding_index] + "</image>" + new_qestions[position] 100 | padding_index += 1 101 | 102 | # Stack all images into a batch and add batch dimension 103 | vision_x = torch.cat(images, dim=1).unsqueeze(0) # Cat tensors and expand the batch_size dim 104 | 105 | # Join the character list back into a string 106 | text = ''.join(new_qestions) 107 | return text, vision_x 108 | 109 | 110 | def main(): 111 | ''' 112 | Main function to demonstrate the RadFM model inference 113 | ''' 114 | print("Setup tokenizer") 115 | # Initialize tokenizer with special image tokens 116 | text_tokenizer, image_padding_tokens = get_tokenizer('./Language_files') 117 | print("Finish loading tokenizer") 118 | 119 | ### Initialize a simple case for demo ### 120 | print("Setup demo case") 121 | # Define a medical question about a chest X-ray 122 | question = "Can you identify any visible signs of Cardiomegaly in the image?" 123 | 124 | # Specify the image path and where to insert it in the question 125 | image = [ 126 | { 127 | 'img_path': './view1_frontal.jpg', 128 | 'position': 0, # Insert at the beginning of the question 129 | }, # Can add arbitrary number of images 130 | ] 131 | 132 | # Combine text and images into model-ready format 133 | text, vision_x = combine_and_preprocess(question, image, image_padding_tokens) 134 | print("Finish loading demo case") 135 | 136 | print("Setup Model") 137 | # Initialize the multimodal model 138 | model = MultiLLaMAForCausalLM( 139 | lang_model_path='./Language_files', # Build up model based on LLaMa-13B config 140 | ) 141 | 142 | # Load pretrained model weights 143 | ckpt = torch.load('./pytorch_model.bin', map_location='cpu') # Please download our checkpoint from huggingface and decompress the original zip file first 144 | model.load_state_dict(ckpt) 145 | print("Finish loading model") 146 | 147 | # Move model to GPU and set to evaluation mode 148 | model = model.to('cuda') 149 | model.eval() 150 | 151 | # Run inference without gradient computation 152 | with torch.no_grad(): 153 | # Tokenize the combined text with image placeholders 154 | lang_x = text_tokenizer( 155 | text, max_length=2048, truncation=True, return_tensors="pt" 156 | )['input_ids'].to('cuda') 157 | 158 | # Move image tensor to GPU 159 | vision_x = vision_x.to('cuda') 160 | 161 | # Generate text response 162 | generation = model.generate(lang_x, vision_x) 163 | 164 | # Decode the generated token IDs to text 165 | generated_texts = text_tokenizer.batch_decode(generation, skip_special_tokens=True) 166 | 167 | # Print results 168 | print('---------------------------------------------------') 169 | print('Input: ', question) 170 | print('Output: ', generated_texts[0]) 171 | 172 | 173 | if __name__ == "__main__": 174 | main() -------------------------------------------------------------------------------- /Quick_demo/view1_frontal.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaoyi-wu/RadFM/9efba84cced9ffdcbce00d6005f255414ffa8c36/Quick_demo/view1_frontal.jpg -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | einops==0.6.1 2 | einops-exts==0.0.4 3 | huggingface-hub==0.16.4 4 | nibabel==5.1.0 5 |
nmslib==2.1.1 6 | opencv-python==4.8.0.76 7 | pandas==2.0.3 8 | Pillow==9.4.0 9 | pytz==2023.3 10 | PyYAML==6.0.1 11 | scikit-learn==1.3.0 12 | scipy==1.11.2 13 | scispacy 14 | sentencepiece==0.1.99 15 | SimpleITK==2.2.1 16 | spacy==3.6.1 17 | spacy-alignments==0.9.0 18 | spacy-legacy==3.0.12 19 | spacy-loggers==1.0.4 20 | spacy-transformers==1.2.5 21 | tokenizers==0.13.3 22 | torch==2.0.1 23 | torchaudio==2.0.2 24 | torchvision==0.15.2 25 | tqdm==4.66.1 26 | transformers==4.28.1 27 | -------------------------------------------------------------------------------- /src/Dataset/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .radiopaedia import RadioVQA_Dataset,Radio_Modality_Dataset,Radiofeatures_Dataset,RadioCaption_Dataset 2 | from .binary import Binary_Dataset 3 | from .chestxray import ChestXray_Dataset 4 | from .vqa import VQA_Dataset 5 | from .pmcoa import PMCOA_Dataset 6 | from .paper_inline import Paper_Inline_dataset 7 | from .case_report import CaseReport_dataset 8 | from .MedPix_dataset import MedPix_Multi_Dataset,MedPix_Single_Dataset,MedPix_QA_Dataset 9 | -------------------------------------------------------------------------------- /src/Dataset/dataset/__pycache__/MedPix_dataset.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaoyi-wu/RadFM/9efba84cced9ffdcbce00d6005f255414ffa8c36/src/Dataset/dataset/__pycache__/MedPix_dataset.cpython-39.pyc -------------------------------------------------------------------------------- /src/Dataset/dataset/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaoyi-wu/RadFM/9efba84cced9ffdcbce00d6005f255414ffa8c36/src/Dataset/dataset/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /src/Dataset/dataset/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaoyi-wu/RadFM/9efba84cced9ffdcbce00d6005f255414ffa8c36/src/Dataset/dataset/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /src/Dataset/dataset/__pycache__/binary.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaoyi-wu/RadFM/9efba84cced9ffdcbce00d6005f255414ffa8c36/src/Dataset/dataset/__pycache__/binary.cpython-39.pyc -------------------------------------------------------------------------------- /src/Dataset/dataset/__pycache__/case_report.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaoyi-wu/RadFM/9efba84cced9ffdcbce00d6005f255414ffa8c36/src/Dataset/dataset/__pycache__/case_report.cpython-39.pyc -------------------------------------------------------------------------------- /src/Dataset/dataset/__pycache__/chestxray.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaoyi-wu/RadFM/9efba84cced9ffdcbce00d6005f255414ffa8c36/src/Dataset/dataset/__pycache__/chestxray.cpython-310.pyc -------------------------------------------------------------------------------- /src/Dataset/dataset/__pycache__/chestxray.cpython-39.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaoyi-wu/RadFM/9efba84cced9ffdcbce00d6005f255414ffa8c36/src/Dataset/dataset/__pycache__/chestxray.cpython-39.pyc -------------------------------------------------------------------------------- /src/Dataset/dataset/__pycache__/paper_inline.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaoyi-wu/RadFM/9efba84cced9ffdcbce00d6005f255414ffa8c36/src/Dataset/dataset/__pycache__/paper_inline.cpython-39.pyc -------------------------------------------------------------------------------- /src/Dataset/dataset/__pycache__/pmcoa.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaoyi-wu/RadFM/9efba84cced9ffdcbce00d6005f255414ffa8c36/src/Dataset/dataset/__pycache__/pmcoa.cpython-39.pyc -------------------------------------------------------------------------------- /src/Dataset/dataset/__pycache__/pmcvqa.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaoyi-wu/RadFM/9efba84cced9ffdcbce00d6005f255414ffa8c36/src/Dataset/dataset/__pycache__/pmcvqa.cpython-39.pyc -------------------------------------------------------------------------------- /src/Dataset/dataset/__pycache__/radiopaedia.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaoyi-wu/RadFM/9efba84cced9ffdcbce00d6005f255414ffa8c36/src/Dataset/dataset/__pycache__/radiopaedia.cpython-39.pyc -------------------------------------------------------------------------------- /src/Dataset/dataset/__pycache__/radiovqa.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaoyi-wu/RadFM/9efba84cced9ffdcbce00d6005f255414ffa8c36/src/Dataset/dataset/__pycache__/radiovqa.cpython-39.pyc -------------------------------------------------------------------------------- /src/Dataset/dataset/binary.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import json 3 | import logging 4 | import os 5 | import re 6 | import difflib 7 | import sys 8 | import torch 9 | import random 10 | from abc import abstractmethod 11 | from itertools import islice 12 | from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union 13 | from collections.abc import Mapping 14 | from torch.utils.data import DataLoader 15 | import PIL 16 | from torch.utils.data import Dataset 17 | import numpy as np 18 | import pandas as pd 19 | from tqdm import tqdm 20 | from torchvision import transforms 21 | from collections import defaultdict 22 | from PIL import Image 23 | 24 | class Binary_Dataset(Dataset): 25 | """_summary_ 26 | Args: 27 | Dataset (_type_): caption task formulated as vqa task for Chestxray classification dataset 28 | csv_path (_type_): path to csv file 29 | prompt_json_file (_type_): path to json file containing binary cls prompts, the answer is yes/no 30 | Output: 31 | Dict: { 32 | "image_dict": {"image": image, "position": {"question": 0}}, # image is a tensor of shape [c,w,h,d] [3,512,512,1], position is a dict, random choice of 0 or len(question) 33 | "question": question, # random choice of caption prompts 34 | "answer":answer, # caption 35 | } 36 | """ 37 | def 
__init__(self,csv_path,prompt_json_file): 38 | data_info = pd.read_csv(csv_path) 39 | self.img_path_list = np.asarray(data_info['image_path']) 40 | self.disease_list = np.asarray(data_info['disease']) 41 | self.answer_list = np.asarray(data_info['label']) 42 | self.transform = transforms.Compose([ 43 | transforms.RandomResizedCrop([512,512],scale=(0.8, 1.0), interpolation=transforms.InterpolationMode.BICUBIC), 44 | transforms.ToTensor(), 45 | ]) 46 | with open(prompt_json_file, 'r') as f: 47 | self.caption_prompts = json.load(f)['caption_prompt'] 48 | self.map_answer = {0:'no',1:'yes'} 49 | 50 | def __len__(self): 51 | return len(self.img_path_list) 52 | 53 | def __getitem__(self, index): 54 | img_path = self.img_path_list[index] 55 | image = Image.open(img_path).convert('RGB') 56 | image = self.transform(image) 57 | image = image.unsqueeze(-1) # c,w,h,d 58 | answer = self.map_answer[self.answer_list[index]] 59 | question = random.choice(self.caption_prompts).replace('disease',self.disease_list[index]) 60 | image_dict = [{ 61 | "image": image, 62 | "position": { 63 | "question": len(question) 64 | } 65 | }] 66 | return { 67 | "image_dict": image_dict, 68 | "question": question, 69 | "answer":answer, 70 | } 71 | -------------------------------------------------------------------------------- /src/Dataset/dataset/caption_prompt.json: -------------------------------------------------------------------------------- 1 | { 2 | "caption_prompt": [ 3 | "Can you provide a caption consists of finding and impression for this medical image?", 4 | "Describe the finding and impression of the medical image you see.", 5 | "Please caption this medical scan with finding and impression.", 6 | "What is the finding and impression of this image?", 7 | "Describe this medical scan with finding and impression.", 8 | "Please write a caption consists of finding and impression for this image.", 9 | "Can you summarize with finding and impression the images presented?", 10 | "Please caption this scan with finding and impression.", 11 | "Please provide a caption consists of finding and impression for this medical image.", 12 | "Can you provide a summary consists of finding and impression of this radiograph?", 13 | "What are the findings and impression presented in this medical scan?", 14 | "Please write a caption consists of finding and impression for this scan.", 15 | "Can you provide a description consists of finding and impression of this medical scan?", 16 | "Please caption this medical scan with finding and impression.", 17 | "Can you provide a caption consists of finding and impression for this medical scan?" 18 | ] 19 | } -------------------------------------------------------------------------------- /src/Dataset/dataset/case_report.py: -------------------------------------------------------------------------------- 1 | # Import necessary libraries for data processing, image handling, and model integration 2 | from torch.utils.data import Dataset 3 | import numpy as np 4 | import transformers 5 | import pandas as pd 6 | import copy 7 | import random 8 | import os 9 | import numpy as np 10 | import tqdm 11 | import torch 12 | import json 13 | from PIL import Image 14 | import torchvision 15 | from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer 16 | from torchvision import transforms 17 | from ast import literal_eval 18 | 19 | class CaseReport_dataset(Dataset): 20 | """ 21 | Dataset class for medical case reports with associated images. 
22 | 23 | This dataset processes medical case reports containing text and referenced images, 24 | formatting them for multimodal medical AI training or inference. 25 | """ 26 | def __init__(self, csv_path, img_path): 27 | """ 28 | Initialize the dataset. 29 | 30 | Args: 31 | csv_path: Path to CSV file containing case reports data 32 | img_path: Base path to the directory containing images 33 | """ 34 | self.img_path = img_path # Root directory for images 35 | self.question_list = pd.read_csv(csv_path) # Load dataset from CSV 36 | 37 | # Define image transformation pipeline 38 | # normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 39 | self.transform = transforms.Compose([ 40 | # Crop and resize images to 512x512, maintaining 80-100% of original content 41 | transforms.RandomResizedCrop([512, 512], scale=(0.8, 1.0), interpolation=transforms.InterpolationMode.BICUBIC), 42 | # Convert to tensor with values in [0, 1] 43 | transforms.ToTensor(), 44 | # normalize, # Commented out normalization 45 | ]) 46 | 47 | 48 | def __len__(self): 49 | """Return the total number of samples in the dataset""" 50 | return len(self.question_list) 51 | 52 | def __getitem__(self, idx): 53 | """ 54 | Get a single sample from the dataset 55 | 56 | Args: 57 | idx: Index of the sample to retrieve 58 | 59 | Returns: 60 | Dictionary containing the processed sample with image, question, and answer 61 | """ 62 | # Get the row from dataframe 63 | sample = self.question_list.iloc[idx] 64 | 65 | # Extract metadata and content 66 | PMC_id = sample['PMC_id'] # PubMed Central ID 67 | img_ref = literal_eval(sample['img_ref']) # List of image references 68 | context = str(sample['context']) # Case context 69 | 70 | # Truncate long contexts to focus on beginning and end 71 | sentences = context.split('.') 72 | if len(sentences) > 5: 73 | first_sentence = sentences[0] # Keep the first sentence 74 | last_sentences = ". ".join(context.split('.')[-4:]) # Keep the last 4 sentences 75 | context = first_sentence + '. 
' + last_sentences 76 | 77 | # Format question by combining context and actual question 78 | question = str(context) + '\n' + str(sample['question']).replace('Q:', '') 79 | 80 | # Clean up answer formatting 81 | answer = str(sample['answer']).replace('A:', '') 82 | 83 | # Process each referenced image 84 | images = [] 85 | for img_id in img_ref: 86 | # Construct the full image path 87 | img_path = self.img_path + '/' + PMC_id + '_' + img_id + '.jpg' 88 | 89 | try: 90 | # Load and transform the image 91 | image = Image.open(img_path).convert('RGB') 92 | image = self.transform(image) 93 | 94 | # Randomly decide where to place the image in the text 95 | # Either at the end of question or at the end of context 96 | if random.random() > 0.5: 97 | images.append({'image': image, "position": {"question": len(question)}}) 98 | else: 99 | images.append({'image': image, "position": {"question": len(context)}}) 100 | except: 101 | # Skip images that can't be loaded 102 | continue 103 | 104 | # Return formatted sample 105 | return { 106 | "image_dict": images, # List of images with position information 107 | "question": question, # Formatted question text 108 | "answer": answer, # Answer text 109 | } 110 | 111 | # Example usage (commented out) 112 | # csv_path = '/gpfs/home/cs/leijiayu/wuchaoyi/multi_modal/Data/GPT_realdata/casa_report_train.csv' 113 | # img_path = '/home/cs/leijiayu/data/all_images/figures/' 114 | # dataset = CaseReport_dataset(csv_path, img_path) 115 | # print(dataset[0]) -------------------------------------------------------------------------------- /src/Dataset/dataset/chestxray.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import json 3 | import logging 4 | import os 5 | import re 6 | import difflib 7 | import sys 8 | import torch 9 | import random 10 | from abc import abstractmethod 11 | from itertools import islice 12 | from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union 13 | from collections.abc import Mapping 14 | from torch.utils.data import DataLoader 15 | import PIL 16 | from torch.utils.data import Dataset 17 | import numpy as np 18 | import pandas as pd 19 | from tqdm import tqdm 20 | from torchvision import transforms 21 | from collections import defaultdict 22 | from PIL import Image 23 | 24 | class ChestXray_Dataset(Dataset): 25 | """_summary_ 26 | Args: 27 | Dataset (_type_): caption task formulated as vqa task for Chestxray classification dataset 28 | csv_path (_type_): path to csv file 29 | img_root_dir (_type_): path to image root directory 30 | prompt_json_file (_type_): path to json file containing caption prompts 31 | Output: 32 | Dict: { 33 | "image_dict": {"image": image, "position": {"question": 0}}, # image is a tensor of shape [c,w,h,d] [3,512,512,1], position is a dict, random choice of 0 or len(question) 34 | "question": question, # random choice of caption prompts 35 | "answer":answer, # caption 36 | } 37 | """ 38 | def __init__(self,csv_path,prompt_json_file): 39 | data_info = pd.read_csv(csv_path) 40 | self.img_path_list = np.asarray(data_info['image_path']) 41 | self.answer_list = np.asarray(data_info['label']) 42 | self.transform = transforms.Compose([ 43 | transforms.RandomResizedCrop([512,512],scale=(0.8, 1.0), interpolation=transforms.InterpolationMode.BICUBIC), 44 | transforms.ToTensor(), 45 | ]) 46 | with open(prompt_json_file, 'r') as f: 47 | self.caption_prompts = json.load(f)['caption_prompt'] 48 | 49 | def __len__(self): 50 | return 
len(self.img_path_list) 51 | 52 | def __getitem__(self, index): 53 | img_path = self.img_path_list[index] 54 | try: 55 | image = Image.open(img_path).convert('RGB') 56 | image = self.transform(image) 57 | image = image.unsqueeze(-1) # c,w,h,d 58 | except: 59 | image = np.random.randn(3,512,512,4) 60 | 61 | answer = self.answer_list[index] 62 | question = random.choice(self.caption_prompts) 63 | image_dict = [{ 64 | "image": image, 65 | "position": { 66 | "question": len(question) 67 | } 68 | }] 69 | return { 70 | "image_dict": image_dict, 71 | "question": question, 72 | "answer":answer, 73 | } 74 | 75 | 76 | if __name__ == "__main__": 77 | test_dataset = ChestXray_Dataset(csv_path = '../data_csv/chestxray.csv', 78 | prompt_json_file = './cls_prompt.json') 79 | for i in range(10): 80 | test_data = test_dataset[i] 81 | print(test_data['image_dict'][0]['image'].shape) # [3,512,512,1] 82 | # Make sure every chestxray img_path in the csv points to an existing image 83 | 84 | 85 | 86 | 87 | -------------------------------------------------------------------------------- /src/Dataset/dataset/cls_prompt.json: -------------------------------------------------------------------------------- 1 | { 2 | "caption_prompt": [ 3 | "What is the diagnosis for this chest X-ray?", 4 | "Based on this X-ray, what type of lung disease is suspected?", 5 | "Can you identify any abnormality in this chest X-ray?", 6 | "What are the findings in this chest X-ray?", 7 | "What pathology is indicated by this chest X-ray?", 8 | "What lung disease is likely present in this chest X-ray?", 9 | "What are the potential causes of the findings in this chest X-ray?", 10 | "What are your conclusions from this chest X-ray?", 11 | "What is your interpretation of this chest X-ray?", 12 | "What abnormalities are present in this chest X-ray?", 13 | "What is the differential diagnosis for the findings in this chest X-ray?" 14 | ] 15 | } -------------------------------------------------------------------------------- /src/Dataset/dataset/data_csv/README.md: -------------------------------------------------------------------------------- 1 | Please check the [data_csv](https://huggingface.co/datasets/chaoyi-wu/RadFM_data_csv) dataset to download the train/test split csv files we used, and make sure the image paths in them point to your local copies of the data.
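For example, the `image_path` column of each downloaded csv can be rewritten so that it points at your local image directory before the dataset classes (e.g. `ChestXray_Dataset`) try to open the files. A minimal sketch, assuming a csv with an `image_path` column, a flat local image folder, and hypothetical placeholder paths `/data/RadFM_images` and `chestxray.csv`:

```python
# Sketch: remap the image_path column of a downloaded split csv to a local image root.
# "/data/RadFM_images" and "chestxray.csv" are placeholder paths -- substitute your own.
import os
import pandas as pd

LOCAL_ROOT = "/data/RadFM_images"  # hypothetical local image directory (flat layout assumed)

df = pd.read_csv("chestxray.csv")
# Keep only the file name from the original path and prepend the local root.
df["image_path"] = df["image_path"].map(lambda p: os.path.join(LOCAL_ROOT, os.path.basename(p)))
# Optionally drop rows whose images are missing locally, so __getitem__ never hits a dead path.
df = df[df["image_path"].map(os.path.exists)]
df.to_csv("chestxray_local.csv", index=False)
```

The dataset classes read the paths directly from these csv files, so no code change is needed once the paths resolve locally.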
-------------------------------------------------------------------------------- /src/Dataset/dataset/dicom_to_png_for_VinDR_sampled_using_mammo.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import csv 4 | import json 5 | import imageio 6 | 7 | import pandas as pd 8 | import numpy as np 9 | from tqdm import tqdm 10 | 11 | import matplotlib.pyplot as plt 12 | from pydicom import dcmread 13 | 14 | def dcm_to_png(dcm_path,save_png_path): 15 | ds = dcmread(dcm_path) 16 | arr = ds.pixel_array 17 | img_array = arr.copy() 18 | cv2.normalize(arr, img_array, 0, 255, cv2.NORM_MINMAX) 19 | img_array = np.array(img_array,dtype='uint8') 20 | # img_array = cv2.resize(img_array, (512,512), interpolation = cv2.INTER_LINEAR) 21 | imageio.imwrite(save_png_path,img_array) 22 | 23 | def preprocess_csv(csv_path,data_dir,save_data_dir): 24 | data_info = pd.read_csv(csv_path) 25 | patient_file_list = data_info.iloc[:,0] 26 | img_file_list = data_info.iloc[:,2] 27 | for idx in tqdm(range(len(img_file_list))): 28 | patient_file = patient_file_list[idx] 29 | img_file = img_file_list[idx] 30 | img_path = os.path.join(data_dir,str(patient_file),str(img_file)+'.dicom') 31 | os.makedirs(os.path.join(save_data_dir,str(patient_file)), exist_ok=True) 32 | save_img_path = os.path.join(save_data_dir,str(patient_file),str(img_file)+'.png') 33 | dcm_to_png(img_path,save_img_path) 34 | 35 | 36 | csv_path = './DATA/VinDr/VinDr-Mammo/1.0.0/breast-level_annotations.csv' 37 | data_dir = './DATA/VinDr/VinDr-Mammo/1.0.0/images' 38 | save_data_dir = './DATA/VinDr/VinDr-Mammo/process/images' 39 | os.makedirs(save_data_dir, exist_ok=True) 40 | preprocess_csv(csv_path,data_dir,save_data_dir) 41 | -------------------------------------------------------------------------------- /src/Dataset/dataset/jpg2nii_data_convert.py: -------------------------------------------------------------------------------- 1 | #processed cases accoring to case_id_list, and save a csv file, with image path and image caption 2 | import os 3 | import cv2 4 | import csv 5 | import json 6 | import subprocess 7 | import pandas as pd 8 | import numpy as np 9 | import SimpleITK as sitk 10 | from tqdm import tqdm 11 | from collections import defaultdict 12 | 13 | def get_image(single_image_dir,single_image_filenames): 14 | # single_image_filenames 15 | single_image_filenames.sort(key=lambda x: int(x.split('.')[0])) 16 | image_list = [] 17 | for image_filename in single_image_filenames: 18 | image_file = os.path.join(single_image_dir, image_filename) 19 | #read jpeg to 2D array 20 | image_array = cv2.imread(image_file,0) 21 | if image_array is not None: 22 | image_size = image_array.shape 23 | image_array = cv2.resize(image_array,(512,512),interpolation = cv2.INTER_LINEAR) 24 | image_list.append(image_array) 25 | else: 26 | pass 27 | image_array = np.array(image_list) #c,w,h 28 | if len(image_array.shape) == 3: 29 | if image_array.shape[0] < image_array.shape[1]: 30 | image_array = image_array.transpose((1, 2, 0)) 31 | # image_array = np.transpose(image_array, (2,0,1)) # w,h,c 32 | return image_array 33 | 34 | gray_list = ['CT','MRI','X-ray','Ultrasound','Mammography'] 35 | 36 | def convert_case(case_id,image_root_dir,json_root_dir,save_case_dict,save_root_dir=None): 37 | # save_image_dir 38 | case_images_dir = os.path.join(image_root_dir, case_id) 39 | case_json_path = os.path.join(json_root_dir, case_id+'.json') 40 | with open(case_json_path, 'r') as f: 41 | data = json.load(f) 42 | image_nums = 
(len(data.keys())-1)//2 43 | for image_num in range(1,image_nums+1): 44 | case_dict = defaultdict(list) 45 | image_dir = os.path.join(case_images_dir, str(image_num)) #./images/1/1 46 | image_caption = data[str(image_num) + '详情'] 47 | image_modality = data[str(image_num)][0]['modality'] 48 | 49 | single_image_names = os.listdir(image_dir) 50 | single_image_names.sort(key=lambda x: int(x.split('_')[1])) 51 | save_image_series = [] 52 | 53 | for single_image_name in single_image_names: 54 | single_image_dir = os.path.join(image_dir, single_image_name) 55 | 56 | save_npy_dir = os.path.join(save_root_dir,str(case_id),str(image_num)) 57 | 58 | 59 | single_image_filenames = os.listdir(single_image_dir) 60 | if len(os.listdir(single_image_dir)) == 1: 61 | # 2D image 62 | image_file = os.path.join(single_image_dir, single_image_filenames[0]) 63 | save_image_array = cv2.imread(image_file) # w,h,c 64 | else: 65 | save_image_array = get_image(single_image_dir,single_image_filenames) 66 | if not os.path.exists(save_npy_dir): 67 | os.makedirs(save_npy_dir) 68 | # print(save_image_array.shape) 69 | if save_image_array is not None: 70 | if len(save_image_array.shape) <= 5 and len(save_image_array.shape) >=2: 71 | save_nii_path = os.path.join(save_npy_dir,single_image_name+'.nii.gz') 72 | out = sitk.GetImageFromArray(save_image_array) 73 | sitk.WriteImage(out, save_nii_path) 74 | save_image_series.append(save_nii_path) 75 | else: 76 | save_npy_path = os.path.join(save_npy_dir,single_image_name+'.npy') 77 | np.save(save_npy_path,save_image_array) 78 | save_image_series.append(save_npy_path) 79 | case_dict['image'] = save_image_series 80 | case_dict['image_caption'] = image_caption 81 | case_dict['image_modality'] = image_modality 82 | save_case_dict.append(case_dict) 83 | 84 | if __name__ == "__main__": 85 | # case_id,image_root_dir,json_root_dir 86 | import argparse 87 | parser = argparse.ArgumentParser() 88 | parser.add_argument('--index', default=0, type=int) 89 | parser.add_argument('--add_index', default=0, type=int) 90 | parser.add_argument('--start_index', default=1, type=int) 91 | parser.add_argument('--end_index', default=1000, type=int) 92 | args = parser.parse_args() 93 | 94 | image_root_dir = '/mnt/petrelfs/share_data/zhangxiaoman/DATA/Radio_VQA/processed_file/images' 95 | json_root_dir = '/mnt/petrelfs/share_data/zhangxiaoman/DATA/Radio_VQA/processed_file/jsons' 96 | save_root_dir = '/mnt/petrelfs/share_data/zhangxiaoman/DATA/Radio_VQA/processed_file/npys' 97 | save_case_dict = [] 98 | 99 | args.start_index = args.index*1000+1 + args.add_index 100 | args.end_index = (args.index+1)*1000+1 101 | 102 | for case_id in tqdm(range(args.start_index,args.end_index)): 103 | case_id = str(case_id) 104 | convert_case(case_id,image_root_dir,json_root_dir,save_case_dict,save_root_dir) 105 | # CT_0 (200, 630, 630, 3) 106 | 107 | # save to csv 108 | save_json_file = '/mnt/petrelfs/share_data/zhangxiaoman/DATA/Radio_VQA/processed_file/processed_jsons/processed_json_'+str(args.index)+'.json' 109 | with open(save_json_file, 'w', encoding='utf-8') as f: 110 | json.dump(save_case_dict, f, ensure_ascii=False,indent=4) 111 | # B, S, T, W, H, Z 112 | # srun --partition=medai --mpi=pmi2 --quotatype=auto --gres=gpu:0 -n1 --ntasks-per-node=1 python data_convert.py --index 2 --add_index 24 113 | # cd /mnt/petrelfs/share_data/zhangxiaoman/DATA/Radio_VQA/jpeg2npy -------------------------------------------------------------------------------- /src/Dataset/dataset/mammo_prompt.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "caption_prompt": [ 3 | "What is the diagnosis for this mammogram?", 4 | "Based on this X-ray, what type of breast disease is suspected?", 5 | "Can you identify any abnormality in this mammogram?", 6 | "What are the findings in this mammogram?", 7 | "What pathology is indicated by this mammogram?", 8 | "What lung disease is likely present in this mammogram?", 9 | "What are the potential causes of the findings in this mammogram?", 10 | "What are your conclusions from this mammogram?", 11 | "What is your interpretation of this mammogram?", 12 | "What abnormalities are present in this mammogram?", 13 | "What is the differential diagnosis for the findings in this mammogram?", 14 | "What is the diagnosis for this breast X-ray?", 15 | "Can you identify any abnormality in this breast X-ray?", 16 | "What are the findings in this breast X-ray?", 17 | "What pathology is indicated by this breast X-ray?", 18 | "What lung disease is likely present in this breast X-ray?", 19 | "What are the potential causes of the findings in this breast X-ray?", 20 | "What are your conclusions from this breast X-ray?", 21 | "What is your interpretation of this breast X-ray?", 22 | "What abnormalities are present in this breast X-ray?", 23 | "What is the differential diagnosis for the findings in this breast X-ray?" 24 | ] 25 | } -------------------------------------------------------------------------------- /src/Dataset/dataset/modality_prompt.json: -------------------------------------------------------------------------------- 1 | { 2 | "caption_prompt": [ 3 | "What modality is used to take this image?", 4 | "What type of imaging modality is used to acquire the above image?", 5 | "What imaging modality is used?", 6 | "What imaging modality was used to take this image?", 7 | "What is the imaging modality?" 8 | ], 9 | "modality_prompt": [ 10 | "Is this image a modality scan?", 11 | "Is the given image a modality scan?", 12 | "Is the given image a modality?" 
13 | ] 14 | } 15 | -------------------------------------------------------------------------------- /src/Dataset/dataset/nii2npy_for_radiopaedio.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import csv 4 | import json 5 | import subprocess 6 | import pandas as pd 7 | import numpy as np 8 | import SimpleITK as sitk 9 | from tqdm import tqdm 10 | from scipy import ndimage 11 | from collections import defaultdict 12 | 13 | def resize_array(array_list, shape_list): 14 | if len(array_list) == 0: 15 | return None 16 | # Get the median value of the c dimension 17 | c_values = [shape[3] for shape in shape_list] 18 | z = np.median(c_values) 19 | 20 | # Resize each array to the same size 21 | resized_arrays = [] 22 | for array in array_list: 23 | resized_array = ndimage.zoom(array, (3/array.shape[0],512/array.shape[1], 512/array.shape[2], z/array.shape[3]), order=0) 24 | # print(resized_array.shape) 25 | if resized_array.shape[3] == z: 26 | resized_arrays.append(resized_array) 27 | else: 28 | if resized_array.shape[3] > z: 29 | resized_arrays.append(resized_array[:,:,:,:int(z)]) 30 | else: 31 | resized_arrays.append(np.pad(resized_array, ((0,0),(0,0),(0,0),(0,int(z-resized_array.shape[3]))), 'constant', constant_values=0)) 32 | # Convert the list of arrays to a numpy array 33 | resized_array = np.array(resized_arrays) 34 | 35 | return resized_array 36 | 37 | def process_image_list(image_path_list): 38 | image_shape_list = [] 39 | image_array_list = [] 40 | for image_path in image_path_list: 41 | if os.path.exists(image_path) == False: 42 | continue 43 | elif image_path.split('.')[-1] == 'npy': 44 | image_array = np.load(image_path) #c,w,h,d 45 | try: 46 | image_array = cv2.resize(image_array,(512,512)) 47 | if len(image_array.shape) == 2: 48 | image_array = image_array[np.newaxis,:,:,np.newaxis] 49 | # 1wh1 to 3wh1 50 | image_array = np.concatenate([image_array,image_array,image_array],axis=0) 51 | elif len(image_array.shape) == 3: 52 | #whc to cwh 53 | image_array = image_array.transpose(2,0,1)[:,:,:,np.newaxis] 54 | 55 | image_shape_list.append(image_array.shape) 56 | image_array_list.append(image_array) 57 | except: 58 | pass 59 | else: 60 | itk_image = sitk.ReadImage(image_path) 61 | image_array = sitk.GetArrayFromImage(itk_image) #c,w,h,d 62 | if image_array.shape[0] != 512: 63 | image_array = cv2.resize(image_array,(512,512)) 64 | if len(image_array.shape) == 2: 65 | image_array = image_array[np.newaxis,:,:,np.newaxis] 66 | image_array = np.concatenate([image_array,image_array,image_array],axis=0) 67 | elif len(image_array.shape) == 3: 68 | image_array = image_array[np.newaxis,:,:,:] 69 | image_array = np.concatenate([image_array,image_array,image_array],axis=0) 70 | image_shape_list.append(image_array.shape) 71 | image_array_list.append(image_array) 72 | save_image_array = resize_array(image_array_list, image_shape_list) 73 | return save_image_array 74 | 75 | 76 | def process_json_file(json_file,save_json_file,save_root_dir): 77 | if not os.path.exists(save_root_dir): 78 | os.makedirs(save_root_dir) 79 | with open(json_file, 'r') as f: 80 | data = json.load(f) 81 | data_len = len(data) 82 | for i in tqdm(range(data_len)): 83 | samples = data[i]['samples'] 84 | for sample_i in tqdm(range(len(samples))): 85 | if samples[sample_i]['image'] == []: 86 | samples.pop(sample_i) 87 | else: 88 | image_path_list = samples[sample_i]['image'] 89 | case_id = image_path_list[0].split('/')[-3] 90 | save_image_array = 
process_image_list(image_path_list) 91 | if save_image_array is not None: 92 | save_image_path = os.path.join(save_root_dir, str(case_id)+'_'+str(sample_i)+'.npy') 93 | np.save(save_image_path,save_image_array) 94 | # If you want to upload to AWS (S3) while processing, refer to this commented-out block 95 | # save_aws_image_path = save_image_path.replace('/mnt/petrelfs/share_data/zhangxiaoman/DATA/','s3://zhangxiaoman_hdd_new_share/') 96 | # os.system(f'aws s3 cp {save_image_path} {save_aws_image_path} --endpoint-url=http://10.140.27.254') 97 | # os.remove(save_image_path) 98 | # data[i]['npy_path'] = save_aws_image_path 99 | data[i]['samples']['npy_path'] = save_image_path 100 | data[i]['samples']['image_size'] = save_image_array.shape 101 | else: 102 | print(i,image_path_list) 103 | if len(samples) == 0: 104 | data.pop(i) 105 | 106 | with open(save_json_file, 'w') as f: 107 | json.dump(data, f,ensure_ascii=False,indent=4) 108 | 109 | 110 | if __name__ == "__main__": 111 | json_file = '../processed_file/processed_jsons/processed_json_2023-11-18.json' 112 | save_json_file = '../processed_file/processed_jsons/processed_json_2023-11-18-npy.json' 113 | save_root_dir = '../processed_file/processed_images' 114 | 115 | process_json_file(json_file,save_json_file,save_root_dir) 116 | -------------------------------------------------------------------------------- /src/Dataset/dataset/paper_inline.py: -------------------------------------------------------------------------------- 1 | # Import necessary libraries for data processing, image handling, and model integration 2 | from torch.utils.data import Dataset 3 | import numpy as np 4 | import transformers 5 | import pandas as pd 6 | import copy 7 | import random 8 | import os 9 | import numpy as np 10 | import tqdm 11 | import torch 12 | import json 13 | from PIL import Image 14 | import torchvision 15 | from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer 16 | from torchvision import transforms 17 | 18 | class Paper_Inline_dataset(Dataset): 19 | """ 20 | Dataset class for processing scientific papers with inline images. 21 | 22 | This dataset extracts text and associated images from scientific papers, 23 | preparing them for multimodal model training. 24 | """ 25 | def __init__(self, csv_path, img_path, sample_sentence_length=50, max_img_size=3): 26 | """ 27 | Initialize the dataset.
28 | 29 | Args: 30 | csv_path: Path to CSV file containing paper metadata 31 | img_path: Root directory for paper figures 32 | sample_sentence_length: Maximum number of sentences to include in a sample 33 | max_img_size: Maximum number of images to include in a sample 34 | """ 35 | self.max_img_size = max_img_size 36 | self.sample_sentence_length = sample_sentence_length 37 | self.img_path = img_path 38 | # Load paper paths from CSV 39 | self.paper_path = np.array(pd.read_csv(csv_path)['PMC_path']) 40 | 41 | # Define image transformation pipeline 42 | # normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 43 | self.transform = transforms.Compose([ 44 | # Crop and resize images to 512x512, maintaining 80-100% of original content 45 | transforms.RandomResizedCrop([512, 512], scale=(0.8, 1.0), interpolation=transforms.InterpolationMode.BICUBIC), 46 | # Convert to tensor with values in [0, 1] 47 | transforms.ToTensor(), 48 | # normalize, # Commented out normalization 49 | ]) 50 | 51 | 52 | def __len__(self): 53 | """Return the total number of papers in the dataset""" 54 | return self.paper_path.shape[0] 55 | 56 | def __getitem__(self, idx): 57 | """ 58 | Get a single sample from the dataset 59 | 60 | Args: 61 | idx: Index of the paper to retrieve 62 | 63 | Returns: 64 | Dictionary containing the processed sample with images, question, and answer 65 | """ 66 | # Load the paper JSON file 67 | paper_json = self.paper_path[idx] 68 | # Extract PMC ID from the file path 69 | PMC_name = paper_json.rsplit('/', 2)[-1].split('.')[0] 70 | # Load the list of sentences with image references 71 | sentences_list = json.load(open(paper_json, 'r')) 72 | # Process the paper to extract text and images 73 | image_dict, question, answer = self.random_sample_sentence(sentences_list, PMC_name) 74 | 75 | # Return formatted sample 76 | # Note: question is empty since this is for pretraining with full paper text 77 | return { 78 | "image_dict": image_dict, # List of images with position information 79 | "question": question, # Empty string for this dataset 80 | "answer": answer, # Full text content 81 | } 82 | 83 | def random_sample_sentence(self, sentences_list, PMC_name): 84 | """ 85 | Sample a segment of sentences from a paper and process inline images 86 | 87 | Args: 88 | sentences_list: List of sentences with image references 89 | PMC_name: PubMed Central ID for the paper 90 | 91 | Returns: 92 | Tuple of (processed_images, question_text, answer_text) 93 | """ 94 | sentences_length = len(sentences_list) 95 | 96 | # Select a segment of the paper - either randomly or around image references 97 | p = random.random() 98 | if p >= 0.5: 99 | # Random segment selection 100 | if len(sentences_list) > self.sample_sentence_length: 101 | start = random.randint(0, sentences_length - self.sample_sentence_length) 102 | sentences_list = sentences_list[start:(start + self.sample_sentence_length)] 103 | else: 104 | # Try to select a segment containing images 105 | if len(sentences_list) > self.sample_sentence_length: 106 | sample_start = [] 107 | # Find sentences with image references 108 | for sentence_id in range(len(sentences_list)): 109 | if sentences_list[sentence_id]['img_ref'] != []: 110 | # Start 10 sentences before the image if possible 111 | if sentence_id - 10 < 0: 112 | sample_start.append(0) 113 | else: 114 | if sentence_id - 10 > sentences_length - self.sample_sentence_length: 115 | sample_start.append(sentences_length - self.sample_sentence_length) 116 | else: 117 | 
sample_start.append(sentence_id - 10) 118 | 119 | # If no images found, select random segment 120 | if sample_start == []: 121 | start = random.randint(0, sentences_length - self.sample_sentence_length) 122 | sentences_list = sentences_list[start:(start + self.sample_sentence_length)] 123 | else: 124 | # Select a random segment that contains images 125 | start = sample_start[random.randint(0, len(sample_start) - 1)] 126 | sentences_list = sentences_list[start:(start + self.sample_sentence_length)] 127 | 128 | # Process the selected segment 129 | text = '' 130 | images = [] 131 | for ix in sentences_list: 132 | sentence = ix 133 | if sentence["img_ref"] == []: 134 | # Add plain text without images 135 | text = text + sentence['text'] 136 | else: 137 | # Stop if we've reached the maximum number of images 138 | if len(images) + len(sentence["img_ref"]) > self.max_img_size: 139 | break 140 | 141 | # Process each image referenced in the sentence 142 | for img_id in sentence["img_ref"]: 143 | img_path = self.img_path + '/' + PMC_name + '_' + img_id + '.jpg' 144 | if os.path.exists(img_path): 145 | try: 146 | # Load and transform the image 147 | image = Image.open(img_path).convert('RGB') 148 | image = self.transform(image) 149 | # Add image with position information 150 | images.append({'image': image, "position": {"answer": len(text)}}) 151 | except: 152 | # Skip images that can't be loaded 153 | continue 154 | # Add the text after processing images 155 | text = text + sentence['text'] 156 | 157 | # For this dataset, we don't use a question-answer format 158 | # Instead, all text is in the "answer" field 159 | question = '' 160 | answer = text 161 | 162 | return images, question, answer 163 | 164 | # Example usage (commented out) 165 | # csv_path = '/home/cs/leijiayu/wuchaoyi/multi_modal/Data/train_paper.csv' 166 | # img_path = '/home/cs/leijiayu/data/all_images/figures/' 167 | # dataset = multi_paper_dataset(csv_path, img_path) 168 | # print(dataset[0]) -------------------------------------------------------------------------------- /src/Dataset/dataset/pmcoa.py: -------------------------------------------------------------------------------- 1 | # Import necessary libraries for data processing, image handling, and model interaction 2 | import csv 3 | import json 4 | import logging 5 | import os 6 | import re 7 | import difflib 8 | import sys 9 | import torch 10 | import random 11 | from abc import abstractmethod 12 | from itertools import islice 13 | from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union 14 | from collections.abc import Mapping 15 | from torch.utils.data import DataLoader 16 | import PIL 17 | from torch.utils.data import Dataset 18 | import numpy as np 19 | import pandas as pd 20 | from tqdm import tqdm 21 | from torchvision import transforms 22 | from collections import defaultdict 23 | from PIL import Image 24 | 25 | class PMCOA_Dataset(Dataset): 26 | """ 27 | Dataset for processing scientific figures and captions from PubMed Central Open Access (PMC-OA). 28 | 29 | This dataset formulates image captioning as a visual question answering task, 30 | where the model is prompted with a question about an image and should respond 31 | with an appropriate caption. 
32 | 33 | Args: 34 | csv_path: Path to CSV file with columns [PMC_ID, Figure_path, Caption] 35 | img_root_dir: Path to image root directory containing figure images 36 | prompt_json_file: Path to JSON file containing caption prompts 37 | 38 | Output: 39 | Dict: { 40 | "image_dict": [{"image": image, "position": {"question": position}}], 41 | # image is a tensor of shape [c,w,h,d] [3,512,512,1] 42 | # position is where to insert the image - either at start (0) or end of question 43 | "question": question, # randomly selected caption prompt 44 | "answer": answer, # original caption from the paper 45 | } 46 | """ 47 | def __init__(self, csv_path, img_root_dir, prompt_json_file): 48 | """ 49 | Initialize the dataset. 50 | 51 | Args: 52 | csv_path: Path to CSV file with figure metadata 53 | img_root_dir: Root directory containing figure images 54 | prompt_json_file: JSON file with caption prompts 55 | """ 56 | self.img_root_dir = img_root_dir 57 | 58 | # Load metadata from CSV file 59 | data_info = pd.read_csv(csv_path) 60 | self.img_path_list = np.asarray(data_info['Figure_path']) 61 | self.caption_list = np.asarray(data_info['Caption']) 62 | 63 | # Define image transformation pipeline 64 | # normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 65 | self.transform = transforms.Compose([ 66 | # Crop and resize images to 512x512, maintaining 80-100% of original content 67 | transforms.RandomResizedCrop([512, 512], scale=(0.8, 1.0), interpolation=transforms.InterpolationMode.BICUBIC), 68 | # Convert to tensor with values in [0, 1] 69 | transforms.ToTensor(), 70 | # normalize, # Commented out normalization 71 | ]) 72 | 73 | # Load caption prompts from JSON file 74 | with open(prompt_json_file, 'r') as f: 75 | self.caption_prompts = json.load(f)['caption_prompt'] 76 | 77 | 78 | def __len__(self): 79 | """Return the total number of samples in the dataset""" 80 | return len(self.img_path_list) 81 | 82 | def __getitem__(self, index): 83 | """ 84 | Get a single sample from the dataset 85 | 86 | Args: 87 | index: Index of the sample to retrieve 88 | 89 | Returns: 90 | Dictionary containing processed sample with image, question prompt, and caption answer 91 | """ 92 | # Get the image filename and construct full path 93 | file_name = self.img_path_list[index] 94 | img_path = os.path.join(self.img_root_dir, file_name) 95 | 96 | # Load and preprocess the image 97 | image = Image.open(img_path).convert('RGB') 98 | image = self.transform(image) # normalize to [0,1] 99 | image = image.unsqueeze(-1) # add depth dimension [C, H, W, 1] 100 | 101 | # Get the caption and a random prompt 102 | answer = self.caption_list[index] 103 | question = random.choice(self.caption_prompts) 104 | 105 | # Randomly decide whether to place the image before or after the question 106 | if random.random() < 0.5: 107 | # Place image before the question 108 | image_dict = { 109 | "image": image, 110 | "position": { 111 | "question": 0 # At the beginning of question 112 | } 113 | } 114 | else: 115 | # Place image after the question 116 | image_dict = { 117 | "image": image, 118 | "position": { 119 | "question": len(question) # At the end of question 120 | } 121 | } 122 | 123 | # Return formatted sample 124 | return { 125 | "image_dict": [image_dict], # List containing one image with position info 126 | "question": question, # Caption prompt 127 | "answer": answer, # Ground truth caption 128 | } 129 | 130 | if __name__ == "__main__": 131 | # Example usage for testing the dataset 132 | test_dataset = 
PMCOA_Dataset( 133 | csv_path='../data_csv/pmcoa_image_caption_train.csv', 134 | img_root_dir='/home/cs/leijiayu/data/PMCVQA/caption_T060_filtered_top4_sep_v0_subfigures', 135 | prompt_json_file='./caption_prompt.json' 136 | ) 137 | 138 | # Test the first 10 samples 139 | for i in range(10): 140 | test_data = test_dataset[i] 141 | print(test_data['image_dict'][0]['image'].shape) # Should print [3,512,512,1] -------------------------------------------------------------------------------- /src/Dataset/dataset/radiology_feature_prompt.json: -------------------------------------------------------------------------------- 1 | { 2 | "caption_prompt": [ 3 | "What disease can be diagnosed from these radiological images and what specific features are typically observed on the images?", 4 | "Identify the disease that is typically associated with these radiological images and describe the classic radiological presentation.", 5 | "Based on the provided images, which disease is most likely to be diagnosed and how does it manifest on radiological examinations?", 6 | "Determine the disease that corresponds to the given radiographic images and describe the characteristic radiological features.", 7 | "With these radiological images, which disease would you suspect and what specific radiographic patterns are typically seen?", 8 | "Analyze the provided images and identify the disease that is commonly associated with such radiological findings. Discuss the characteristic radiographic manifestations.", 9 | "From these radiological images, diagnose the disease and explain the typical radiological presentation observed.", 10 | "Assess the radiographic images and determine the disease that is commonly linked to these findings. Describe the typical radiological features associated with this disease.", 11 | "Examine the provided radiological images and identify the disease that would most likely be diagnosed based on the characteristic radiologic appearance.", 12 | "Based on the presented radiographic findings, indicate the disease that is commonly associated with these images and describe the typical radiological patterns observed." 13 | ] 14 | } 15 | -------------------------------------------------------------------------------- /src/Dataset/dataset/report_prompt.json: -------------------------------------------------------------------------------- 1 | { 2 | "caption_prompt": [ 3 | "Can you provide a radiology report for this medical image?", 4 | "Describe the medical image you see.", 5 | "What is depicted in this picture?", 6 | "Please report this medical scan.", 7 | "What is the medical significance of this image?", 8 | "What can you infer from this picture?", 9 | "Can you provide a quick summary of this image?", 10 | "Describe this medical scan.", 11 | "Please write a radiology report for this image.", 12 | "Can you summarize the images presented?", 13 | "Please generate a radiology report for this scan.", 14 | "Describe the regions of interest in this scan.", 15 | "Please provide a caption for this medical image.", 16 | "Can you provide a brief summary of this radiograph?", 17 | "Describe the structures involved in this medical image.", 18 | "What are the findings presented in this medical scan?", 19 | "Please write a radiology report for this scan.", 20 | "Can you provide a description of this medical scan?", 21 | "Please caption this medical scan.", 22 | "Can you provide a report summary for this medical scan?" 
23 | ] 24 | } -------------------------------------------------------------------------------- /src/Dataset/dataset/spinexr_prompt.json: -------------------------------------------------------------------------------- 1 | { 2 | "caption_prompt": [ 3 | "What is the diagnosis for this spine X-ray?", 4 | "Based on this X-ray, what type of spine disease is suspected?", 5 | "Can you identify any abnormality in this spine X-ray?", 6 | "What are the findings in this spine X-ray?", 7 | "What pathology is indicated by this spine X-ray?", 8 | "What lung disease is likely present in this spine X-ray?", 9 | "What are the potential causes of the findings in this spine X-ray?", 10 | "What are your conclusions from this spine X-ray?", 11 | "What is your interpretation of this spine X-ray?", 12 | "What abnormalities are present in this spine X-ray?", 13 | "What is the differential diagnosis for the findings in this spine X-ray?" 14 | ] 15 | } -------------------------------------------------------------------------------- /src/Dataset/dataset/vqa.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import json 3 | import logging 4 | import os 5 | import re 6 | import difflib 7 | import sys 8 | import torch 9 | import random 10 | from abc import abstractmethod 11 | from itertools import islice 12 | from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union 13 | from collections.abc import Mapping 14 | from torch.utils.data import DataLoader 15 | import PIL 16 | from torch.utils.data import Dataset 17 | import numpy as np 18 | import pandas as pd 19 | from tqdm import tqdm 20 | from torchvision import transforms 21 | from collections import defaultdict 22 | from PIL import Image 23 | 24 | class VQA_Dataset(Dataset): 25 | """_summary_ 26 | Args: 27 | Dataset (_type_): 28 | csv_path (_type_): path to csv file 29 | Output: 30 | Dict: { 31 | "image_dict": {"image": image, "position": {"question": 0}}, # image is a tensor of shape [c,w,h,d] [3,512,512,1], position is a dict, random choice of 0 or len(question) 32 | "question": question, # random choice of caption prompts 33 | "answer":answer, # caption 34 | } 35 | """ 36 | def __init__(self,csv_path): 37 | data_info = pd.read_csv(csv_path) 38 | self.img_root_dir_list = np.asarray(data_info['img_root_dir']) 39 | self.img_path_list = np.asarray(data_info['Figure_path']) 40 | self.question_list = np.asarray(data_info['Question']) 41 | self.answer_list = np.asarray(data_info['Answer']) 42 | # PMC_ID,Figure_path,Question,Answer 43 | self.transform = transforms.Compose([ 44 | transforms.RandomResizedCrop([512,512],scale=(0.8, 1.0), interpolation=transforms.InterpolationMode.BICUBIC), 45 | transforms.ToTensor(), 46 | ]) 47 | 48 | 49 | def __len__(self): 50 | return len(self.img_path_list) 51 | 52 | def __getitem__(self, index): 53 | file_name = self.img_path_list[index] 54 | img_root_dir = self.img_root_dir_list[index] 55 | img_path = os.path.join(img_root_dir,file_name) 56 | image = Image.open(img_path).convert('RGB') 57 | image = self.transform(image) 58 | image = image.unsqueeze(-1) 59 | answer = self.answer_list[index] 60 | question = str(self.question_list[index]) 61 | if random.random() < 0.5: 62 | image_dict = { 63 | "image": image, 64 | "position": { 65 | "question": 0 66 | } 67 | } 68 | else: 69 | image_dict = { 70 | "image": image, 71 | "position": { 72 | "question": len(question) 73 | } 74 | } 75 | return { 76 | "image_dict": [image_dict], 77 | "question": question, 78 | 
"answer":answer, 79 | } 80 | 81 | if __name__ == "__main__": 82 | test_dataset = PMCVQA_Dataset(csv_path = '../data_csv/pmcvqa_train.csv') 83 | for i in range(10): 84 | test_data = test_dataset[i] 85 | print(test_data['image_dict'][0]['image'].shape) # [3,512,512,1] 86 | 87 | 88 | 89 | 90 | 91 | -------------------------------------------------------------------------------- /src/Dataset/dataset/yes_no_prompt.json: -------------------------------------------------------------------------------- 1 | { 2 | "caption_prompt": [ 3 | "Is the disease visible in the image?", 4 | "Does the image show signs of disease?", 5 | "Does the image show any disease?", 6 | "Is there any disease in the affected area?", 7 | "Does the image depict any visible disease?", 8 | "Is there an presence of disease in the image?", 9 | "Are there any visible signs of disease in the image?", 10 | "Does the image exhibit any disease?", 11 | "Are there disease visible in the image?", 12 | "Does the image show any signs of disease?", 13 | "Can you identify any visible signs of disease in the image?", 14 | "Is there any indication of disease in the image?", 15 | "Does the image show signs of disease?", 16 | "Does the image show any visible signs of disease?" 17 | ] 18 | } 19 | 20 | 21 | -------------------------------------------------------------------------------- /src/Model/RadFM/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaoyi-wu/RadFM/9efba84cced9ffdcbce00d6005f255414ffa8c36/src/Model/RadFM/__init__.py -------------------------------------------------------------------------------- /src/Model/RadFM/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaoyi-wu/RadFM/9efba84cced9ffdcbce00d6005f255414ffa8c36/src/Model/RadFM/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /src/Model/RadFM/__pycache__/blocks.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaoyi-wu/RadFM/9efba84cced9ffdcbce00d6005f255414ffa8c36/src/Model/RadFM/__pycache__/blocks.cpython-39.pyc -------------------------------------------------------------------------------- /src/Model/RadFM/__pycache__/helpers.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaoyi-wu/RadFM/9efba84cced9ffdcbce00d6005f255414ffa8c36/src/Model/RadFM/__pycache__/helpers.cpython-39.pyc -------------------------------------------------------------------------------- /src/Model/RadFM/__pycache__/multimodality_model.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaoyi-wu/RadFM/9efba84cced9ffdcbce00d6005f255414ffa8c36/src/Model/RadFM/__pycache__/multimodality_model.cpython-39.pyc -------------------------------------------------------------------------------- /src/Model/RadFM/__pycache__/my_embedding_layer.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaoyi-wu/RadFM/9efba84cced9ffdcbce00d6005f255414ffa8c36/src/Model/RadFM/__pycache__/my_embedding_layer.cpython-39.pyc -------------------------------------------------------------------------------- 
/src/Model/RadFM/__pycache__/position_encoding.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaoyi-wu/RadFM/9efba84cced9ffdcbce00d6005f255414ffa8c36/src/Model/RadFM/__pycache__/position_encoding.cpython-39.pyc -------------------------------------------------------------------------------- /src/Model/RadFM/__pycache__/transformer_decoder.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaoyi-wu/RadFM/9efba84cced9ffdcbce00d6005f255414ffa8c36/src/Model/RadFM/__pycache__/transformer_decoder.cpython-39.pyc -------------------------------------------------------------------------------- /src/Model/RadFM/__pycache__/utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaoyi-wu/RadFM/9efba84cced9ffdcbce00d6005f255414ffa8c36/src/Model/RadFM/__pycache__/utils.cpython-39.pyc -------------------------------------------------------------------------------- /src/Model/RadFM/__pycache__/vit_3d.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaoyi-wu/RadFM/9efba84cced9ffdcbce00d6005f255414ffa8c36/src/Model/RadFM/__pycache__/vit_3d.cpython-39.pyc -------------------------------------------------------------------------------- /src/Model/RadFM/blocks.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import Tuple, Union, Callable, Optional 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from torch import nn 7 | from torch.utils.checkpoint import checkpoint 8 | 9 | class PMC_CLIP_cfg: 10 | backbone: str = 'ModifiedRN50' # ['RN50', 'ModifiedRN50', 'MAE'] 11 | layers: Union[Tuple[int, int, int, int], int] = [3,4,6,3] 12 | width: int = 64 13 | head_width: int = 64 14 | mlp_ratio: float = 4.0 15 | patch_size: int = 16 16 | image_size: Union[Tuple[int, int], int] = 224 17 | timm_model_name: str = None # a valid model name overrides layers, width, patch_size 18 | timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model 19 | timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '') 20 | timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '') 21 | patch_dropout: float = 0.0 # patch dropout rate, no dropout by default 22 | drop_attention_rate: float = 0. # Transformer Dropout 23 | patch_size: None 24 | 25 | class Bottleneck(nn.Module): 26 | expansion = 4 27 | 28 | def __init__(self, inplanes, planes, stride=1): 29 | super().__init__() 30 | 31 | # all conv layers have stride 1. 
an avgpool is performed after the second convolution when stride > 1 32 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 33 | self.bn1 = nn.BatchNorm2d(planes) 34 | self.relu1 = nn.ReLU(inplace=True) 35 | 36 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) 37 | self.bn2 = nn.BatchNorm2d(planes) 38 | self.relu2 = nn.ReLU(inplace=True) 39 | 40 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() 41 | 42 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 43 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 44 | self.relu3 = nn.ReLU(inplace=True) 45 | 46 | self.downsample = None 47 | self.stride = stride 48 | 49 | if stride > 1 or inplanes != planes * Bottleneck.expansion: 50 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 51 | self.downsample = nn.Sequential(OrderedDict([ 52 | ("-1", nn.AvgPool2d(stride)), 53 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), 54 | ("1", nn.BatchNorm2d(planes * self.expansion)) 55 | ])) 56 | 57 | def forward(self, x: torch.Tensor): 58 | identity = x 59 | 60 | out = self.relu1(self.bn1(self.conv1(x))) 61 | out = self.relu2(self.bn2(self.conv2(out))) 62 | out = self.avgpool(out) 63 | out = self.bn3(self.conv3(out)) 64 | 65 | if self.downsample is not None: 66 | identity = self.downsample(x) 67 | 68 | out += identity 69 | out = self.relu3(out) 70 | return out 71 | 72 | 73 | class AttentionPool2d(nn.Module): 74 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): 75 | super().__init__() 76 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) 77 | self.k_proj = nn.Linear(embed_dim, embed_dim) 78 | self.q_proj = nn.Linear(embed_dim, embed_dim) 79 | self.v_proj = nn.Linear(embed_dim, embed_dim) 80 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 81 | self.num_heads = num_heads 82 | 83 | def forward(self, x): 84 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC 85 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 86 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 87 | x, _ = F.multi_head_attention_forward( 88 | query=x, key=x, value=x, 89 | embed_dim_to_check=x.shape[-1], 90 | num_heads=self.num_heads, 91 | q_proj_weight=self.q_proj.weight, 92 | k_proj_weight=self.k_proj.weight, 93 | v_proj_weight=self.v_proj.weight, 94 | in_proj_weight=None, 95 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 96 | bias_k=None, 97 | bias_v=None, 98 | add_zero_attn=False, 99 | dropout_p=0, 100 | out_proj_weight=self.c_proj.weight, 101 | out_proj_bias=self.c_proj.bias, 102 | use_separate_proj_weight=True, 103 | training=self.training, 104 | need_weights=False 105 | ) 106 | 107 | return x[0] 108 | 109 | 110 | class ResNet(nn.Module): 111 | """ 112 | RN50 113 | """ 114 | 115 | def __init__( 116 | self, layers, output_dim, heads, image_size=224, width=64, 117 | block=Bottleneck, 118 | ): 119 | super().__init__() 120 | self.output_dim = output_dim 121 | self.image_size = image_size 122 | 123 | # the 1-layer stem 124 | self.conv1 = nn.Conv2d(3, width, kernel_size=3, stride=2, padding=1, bias=False) 125 | self.bn1 = nn.BatchNorm2d(width) 126 | self.relu1 = nn.ReLU(inplace=True) 127 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 128 | 129 | # residual layers 130 | self._inplanes = width # this is a 
*mutable* variable used during construction 131 | self.layer1 = self._make_layer(width, layers[0]) 132 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2) 133 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2) 134 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2) 135 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 136 | # self.head = nn.Linear(512 * 6, output_dim) 137 | self.head = nn.Linear(512 * block.expansion, output_dim) 138 | 139 | # embed_dim = width * 32 # the ResNet feature dimension 140 | # self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim) 141 | 142 | self.init_parameters() 143 | 144 | def _make_layer( 145 | self, 146 | planes, blocks, stride=1, 147 | block=Bottleneck, 148 | ): 149 | layers = [block(self._inplanes, planes, stride)] 150 | 151 | self._inplanes = planes * block.expansion 152 | for _ in range(1, blocks): 153 | layers.append(block(self._inplanes, planes)) 154 | 155 | return nn.Sequential(*layers) 156 | 157 | def init_parameters(self): 158 | for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]: 159 | for name, param in resnet_block.named_parameters(): 160 | if name.endswith("bn3.weight"): 161 | nn.init.zeros_(param) 162 | 163 | def lock(self, unlocked_groups=0, freeze_bn_stats=False): 164 | assert unlocked_groups == 0, 'partial locking not currently supported for this model' 165 | for param in self.parameters(): 166 | param.requires_grad = False 167 | if freeze_bn_stats: 168 | freeze_batch_norm_2d(self) 169 | 170 | @torch.jit.ignore 171 | def set_grad_checkpointing(self, enable=True): 172 | # FIXME support for non-transformer 173 | pass 174 | 175 | def stem(self, x): 176 | x = self.relu1(self.bn1(self.conv1(x))) 177 | x = self.maxpool(x) 178 | return x 179 | 180 | def forward(self, x): 181 | # x[0]: [batch_size, 3, 224, 224] 182 | # x[1]: [batch_size, 1] 183 | x = self.stem(x) # [batch_size, 64, 56, 56] 184 | x = self.layer1(x) 185 | x = self.layer2(x) 186 | x = self.layer3(x) 187 | x = self.layer4(x) # [batch_size, 2048, 7, 7] 188 | x = self.avgpool(x) # [batch_size, 2048, 1, 1] 189 | x = torch.flatten(x, 1) # [batch_size, 2048*1*1] 190 | x = self.head(x) # [batch_size, 1024] 191 | 192 | visual_output = dict.fromkeys(["image_features", "mim_loss"], None) 193 | visual_output.update({ 194 | 'image_features': x, 195 | }) 196 | 197 | return visual_output 198 | 199 | 200 | class ModifiedResNet(nn.Module): 201 | """ 202 | A ResNet class that is similar to torchvision's but contains the following changes: 203 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 
204 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 205 | - The final pooling layer is a QKV attention instead of an average pool 206 | """ 207 | 208 | def __init__(self, layers, output_dim, heads, image_size=224, width=64): 209 | super().__init__() 210 | self.output_dim = output_dim 211 | self.image_size = image_size 212 | 213 | # the 3-layer stem 214 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) 215 | self.bn1 = nn.BatchNorm2d(width // 2) 216 | self.relu1 = nn.ReLU(inplace=True) 217 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) 218 | self.bn2 = nn.BatchNorm2d(width // 2) 219 | self.relu2 = nn.ReLU(inplace=True) 220 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) 221 | self.bn3 = nn.BatchNorm2d(width) 222 | self.relu3 = nn.ReLU(inplace=True) 223 | self.avgpool = nn.AvgPool2d(2) 224 | 225 | # residual layers 226 | self._inplanes = width # this is a *mutable* variable used during construction 227 | self.layer1 = self._make_layer(width, layers[0]) 228 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2) 229 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2) 230 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2) 231 | 232 | embed_dim = width * 32 # the ResNet feature dimension 233 | self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim) 234 | 235 | self.init_parameters() 236 | 237 | def _make_layer(self, planes, blocks, stride=1): 238 | layers = [Bottleneck(self._inplanes, planes, stride)] 239 | 240 | self._inplanes = planes * Bottleneck.expansion 241 | for _ in range(1, blocks): 242 | layers.append(Bottleneck(self._inplanes, planes)) 243 | 244 | return nn.Sequential(*layers) 245 | 246 | def init_parameters(self): 247 | if self.attnpool is not None: 248 | std = self.attnpool.c_proj.in_features ** -0.5 249 | nn.init.normal_(self.attnpool.q_proj.weight, std=std) 250 | nn.init.normal_(self.attnpool.k_proj.weight, std=std) 251 | nn.init.normal_(self.attnpool.v_proj.weight, std=std) 252 | nn.init.normal_(self.attnpool.c_proj.weight, std=std) 253 | 254 | for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]: 255 | for name, param in resnet_block.named_parameters(): 256 | if name.endswith("bn3.weight"): 257 | nn.init.zeros_(param) 258 | 259 | def lock(self, unlocked_groups=0, freeze_bn_stats=False): 260 | assert unlocked_groups == 0, 'partial locking not currently supported for this model' 261 | for param in self.parameters(): 262 | param.requires_grad = False 263 | if freeze_bn_stats: 264 | freeze_batch_norm_2d(self) 265 | 266 | @torch.jit.ignore 267 | def set_grad_checkpointing(self, enable=True): 268 | # FIXME support for non-transformer 269 | pass 270 | 271 | def stem(self, x): 272 | x = self.relu1(self.bn1(self.conv1(x))) 273 | x = self.relu2(self.bn2(self.conv2(x))) 274 | x = self.relu3(self.bn3(self.conv3(x))) 275 | x = self.avgpool(x) 276 | return x 277 | 278 | def forward(self, x): 279 | x = self.stem(x) 280 | x = self.layer1(x) 281 | x = self.layer2(x) 282 | x = self.layer3(x) 283 | x = self.layer4(x) 284 | x = self.attnpool(x) 285 | 286 | visual_output = dict.fromkeys(["image_features", "mim_loss"], None) 287 | visual_output.update({ 288 | 'image_features': x, 289 | }) 290 | 291 | return visual_output 292 | 293 | 294 | class LayerNorm(nn.LayerNorm): 295 | """Subclass torch's LayerNorm to handle fp16.""" 296 | 297 | def 
forward(self, x: torch.Tensor): 298 | orig_type = x.dtype 299 | x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) 300 | return x.to(orig_type) 301 | 302 | 303 | class QuickGELU(nn.Module): 304 | # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory 305 | def forward(self, x: torch.Tensor): 306 | return x * torch.sigmoid(1.702 * x) 307 | 308 | 309 | class ResidualAttentionBlock(nn.Module): 310 | def __init__( 311 | self, d_model: int, n_head: int, mlp_ratio: float = 4.0, act_layer: Callable = nn.GELU, 312 | drop_attention_rate: float = 0., 313 | ): 314 | super().__init__() 315 | 316 | self.attn = nn.MultiheadAttention( 317 | embed_dim=d_model, 318 | num_heads=n_head, 319 | dropout=drop_attention_rate, 320 | ) 321 | self.ln_1 = LayerNorm(d_model) 322 | mlp_width = int(d_model * mlp_ratio) 323 | self.mlp = nn.Sequential(OrderedDict([ 324 | ("c_fc", nn.Linear(d_model, mlp_width)), 325 | ("gelu", act_layer()), 326 | ("c_proj", nn.Linear(mlp_width, d_model)) 327 | ])) 328 | self.ln_2 = LayerNorm(d_model) 329 | 330 | def attention(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): 331 | return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0] 332 | 333 | def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): 334 | x = x + self.attention(self.ln_1(x), attn_mask=attn_mask) 335 | x = x + self.mlp(self.ln_2(x)) 336 | return x 337 | 338 | 339 | class PatchDropout(nn.Module): 340 | """ 341 | https://arxiv.org/abs/2212.00794 342 | """ 343 | 344 | def __init__(self, prob, exclude_first_token=True): 345 | super().__init__() 346 | assert 0 <= prob < 1. 347 | self.prob = prob 348 | self.exclude_first_token = exclude_first_token # exclude CLS token 349 | 350 | def forward(self, x): 351 | if not self.training or self.prob == 0.: 352 | return x 353 | 354 | if self.exclude_first_token: 355 | cls_tokens, x = x[:, :1], x[:, 1:] 356 | else: 357 | cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1]) 358 | 359 | batch = x.size()[0] 360 | num_tokens = x.size()[1] 361 | 362 | batch_indices = torch.arange(batch) 363 | batch_indices = batch_indices[..., None] 364 | 365 | keep_prob = 1 - self.prob 366 | num_patches_keep = max(1, int(num_tokens * keep_prob)) 367 | 368 | rand = torch.randn(batch, num_tokens) 369 | patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices 370 | 371 | x = x[batch_indices, patch_indices_keep] 372 | 373 | if self.exclude_first_token: 374 | x = torch.cat((cls_tokens, x), dim=1) 375 | 376 | return x 377 | 378 | 379 | class Transformer(nn.Module): 380 | def __init__( 381 | self, width: int, layers: int, heads: int, mlp_ratio: float = 4.0, act_layer: Callable = nn.GELU, 382 | drop_attention_rate: float = 0., 383 | ): 384 | super().__init__() 385 | self.width = width 386 | self.layers = layers 387 | self.grad_checkpointing = False 388 | 389 | self.resblocks = nn.ModuleList([ 390 | ResidualAttentionBlock(width, heads, mlp_ratio, act_layer=act_layer, drop_attention_rate=drop_attention_rate) 391 | for _ in range(layers) 392 | ]) 393 | 394 | def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): 395 | for r in self.resblocks: 396 | if self.grad_checkpointing and not torch.jit.is_scripting(): 397 | x = checkpoint(r, x, attn_mask) 398 | else: 399 | x = r(x, attn_mask=attn_mask) 400 | return x -------------------------------------------------------------------------------- /src/Model/RadFM/helpers.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | Taken from https://github.com/lucidrains/flamingo-pytorch 3 | """ 4 | 5 | import torch 6 | from einops import rearrange, repeat 7 | from einops_exts import rearrange_many 8 | from torch import einsum, nn 9 | 10 | 11 | def exists(val): 12 | return val is not None 13 | 14 | 15 | def FeedForward(dim, mult=4): 16 | inner_dim = int(dim * mult) 17 | return nn.Sequential( 18 | nn.LayerNorm(dim), 19 | nn.Linear(dim, inner_dim, bias=False), 20 | nn.GELU(), 21 | nn.Linear(inner_dim, dim, bias=False), 22 | ) 23 | 24 | 25 | class PerceiverAttention(nn.Module): 26 | def __init__(self, *, dim, dim_head=64, heads=8): 27 | super().__init__() 28 | self.scale = dim_head**-0.5 29 | self.heads = heads 30 | inner_dim = dim_head * heads 31 | 32 | self.norm_media = nn.LayerNorm(dim) 33 | self.norm_latents = nn.LayerNorm(dim) 34 | 35 | self.to_q = nn.Linear(dim, inner_dim, bias=False) 36 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) 37 | self.to_out = nn.Linear(inner_dim, dim, bias=False) 38 | 39 | def forward(self, x, latents): 40 | """ 41 | Args: 42 | x (torch.Tensor): image features 43 | shape (b, T, n1, D) 44 | latent (torch.Tensor): latent features 45 | shape (b, T, n2, D) 46 | """ 47 | x = self.norm_media(x) 48 | latents = self.norm_latents(latents) 49 | 50 | h = self.heads 51 | 52 | q = self.to_q(latents) 53 | kv_input = torch.cat((x, latents), dim=-2) 54 | k, v = self.to_kv(kv_input).chunk(2, dim=-1) 55 | q, k, v = rearrange_many((q, k, v), "b t n (h d) -> b h t n d", h=h) 56 | q = q * self.scale 57 | 58 | # attention 59 | sim = einsum("... i d, ... j d -> ... i j", q, k) 60 | sim = sim - sim.amax(dim=-1, keepdim=True).detach() 61 | attn = sim.softmax(dim=-1) 62 | 63 | out = einsum("... i j, ... j d -> ... 
i d", attn, v) 64 | out = rearrange(out, "b h t n d -> b t n (h d)", h=h) 65 | return self.to_out(out) 66 | 67 | 68 | class PerceiverResampler(nn.Module): 69 | def __init__( 70 | self, 71 | *, 72 | dim, 73 | depth=6, 74 | dim_head=64, 75 | heads=8, 76 | num_latents=64, 77 | max_num_media=None, 78 | max_num_frames=None, 79 | ff_mult=4, 80 | ): 81 | super().__init__() 82 | self.latents = nn.Parameter(torch.randn(num_latents, dim)) 83 | self.frame_embs = ( 84 | nn.Parameter(torch.randn(max_num_frames, dim)) 85 | if exists(max_num_frames) 86 | else None 87 | ) 88 | self.media_time_embs = ( 89 | nn.Parameter(torch.randn(max_num_media, 1, dim)) 90 | if exists(max_num_media) 91 | else None 92 | ) 93 | 94 | self.layers = nn.ModuleList([]) 95 | for _ in range(depth): 96 | self.layers.append( 97 | nn.ModuleList( 98 | [ 99 | PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), 100 | FeedForward(dim=dim, mult=ff_mult), 101 | ] 102 | ) 103 | ) 104 | 105 | self.norm = nn.LayerNorm(dim) 106 | 107 | def forward(self, x): 108 | """ 109 | Args: 110 | x (torch.Tensor): image features 111 | shape (b, T, F, v, D) 112 | Returns: 113 | shape (b, T, n, D) where n is self.num_latents 114 | """ 115 | b, T, F, v = x.shape[:4] 116 | 117 | # frame and media time embeddings 118 | if exists(self.frame_embs): 119 | frame_embs = repeat(self.frame_embs[:F], "F d -> b T F v d", b=b, T=T, v=v) 120 | x = x + frame_embs 121 | x = rearrange( 122 | x, "b T F v d -> b T (F v) d" 123 | ) # flatten the frame and spatial dimensions 124 | if exists(self.media_time_embs): 125 | x = x + self.media_time_embs[:T] 126 | 127 | # blocks 128 | latents = repeat(self.latents, "n d -> b T n d", b=b, T=T) 129 | for attn, ff in self.layers: 130 | latents = attn(x, latents) + latents 131 | latents = ff(latents) + latents 132 | return self.norm(latents) 133 | 134 | 135 | # gated cross attention 136 | 137 | 138 | class MaskedCrossAttention(nn.Module): 139 | def __init__( 140 | self, 141 | *, 142 | dim, 143 | dim_visual, 144 | dim_head=64, 145 | heads=8, 146 | only_attend_immediate_media=True, 147 | ): 148 | super().__init__() 149 | self.scale = dim_head**-0.5 150 | self.heads = heads 151 | inner_dim = dim_head * heads 152 | 153 | self.norm = nn.LayerNorm(dim) 154 | 155 | self.to_q = nn.Linear(dim, inner_dim, bias=False) 156 | self.to_kv = nn.Linear(dim_visual, inner_dim * 2, bias=False) 157 | self.to_out = nn.Linear(inner_dim, dim, bias=False) 158 | 159 | # whether for text to only attend to immediate preceding image, or all previous images 160 | self.only_attend_immediate_media = only_attend_immediate_media 161 | 162 | def forward(self, x, media, media_locations=None, attend_previous=True): 163 | """ 164 | Args: 165 | x (torch.Tensor): text features 166 | shape (B, T_txt, D_txt) 167 | media (torch.Tensor): image features 168 | shape (B, T_img, n, D_img) where n is the dim of the latents 169 | media_locations: boolean mask identifying the media tokens in x 170 | shape (B, T_txt) 171 | attend_previous: bool 172 | If false, ignores immediately preceding image and starts attending when following image 173 | """ 174 | _, T_img, n = media.shape[:3] 175 | h = self.heads 176 | 177 | x = self.norm(x) 178 | 179 | q = self.to_q(x) 180 | media = rearrange(media, "b t n d -> b (t n) d") 181 | 182 | k, v = self.to_kv(media).chunk(2, dim=-1) 183 | q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=h) 184 | 185 | q = q * self.scale 186 | 187 | sim = einsum("... i d, ... j d -> ... 
i j", q, k) 188 | 189 | if exists(media_locations): 190 | # at each boolean of True, increment the time counter (relative to media time) 191 | text_time = media_locations.cumsum(dim=-1) 192 | media_time = torch.arange(T_img, device=x.device) + 1 193 | 194 | if not attend_previous: 195 | text_time[~media_locations] += 1 196 | # make sure max is still the number of images in the sequence 197 | text_time[ 198 | text_time 199 | > repeat( 200 | torch.count_nonzero(media_locations, dim=1), 201 | "b -> b i", 202 | i=text_time.shape[1], 203 | ) 204 | ] = 0 205 | 206 | # text time must equal media time if only attending to most immediate image 207 | # otherwise, as long as text time is greater than media time (if attending to all previous images / media) 208 | mask_op = torch.eq if self.only_attend_immediate_media else torch.ge 209 | 210 | text_to_media_mask = mask_op( 211 | rearrange(text_time, "b i -> b 1 i 1"), 212 | repeat(media_time, "j -> 1 1 1 (j n)", n=n), 213 | ) 214 | sim = sim.masked_fill(~text_to_media_mask, -torch.finfo(sim.dtype).max) 215 | 216 | sim = sim - sim.amax(dim=-1, keepdim=True).detach() 217 | attn = sim.softmax(dim=-1) 218 | 219 | if exists(media_locations) and self.only_attend_immediate_media: 220 | # any text without a preceding media needs to have attention zeroed out 221 | text_without_media_mask = text_time == 0 222 | text_without_media_mask = rearrange( 223 | text_without_media_mask, "b i -> b 1 i 1" 224 | ) 225 | attn = attn.masked_fill(text_without_media_mask, 0.0) 226 | 227 | out = einsum("... i j, ... j d -> ... i d", attn, v) 228 | out = rearrange(out, "b h n d -> b n (h d)") 229 | return self.to_out(out) 230 | 231 | 232 | class GatedCrossAttentionBlock(nn.Module): 233 | def __init__( 234 | self, 235 | *, 236 | dim, 237 | dim_visual, 238 | dim_head=64, 239 | heads=8, 240 | ff_mult=4, 241 | only_attend_immediate_media=True, 242 | ): 243 | super().__init__() 244 | self.attn = MaskedCrossAttention( 245 | dim=dim, 246 | dim_visual=dim_visual, 247 | dim_head=dim_head, 248 | heads=heads, 249 | only_attend_immediate_media=only_attend_immediate_media, 250 | ) 251 | self.attn_gate = nn.Parameter(torch.tensor([0.0])) 252 | 253 | self.ff = FeedForward(dim, mult=ff_mult) 254 | self.ff_gate = nn.Parameter(torch.tensor([0.0])) 255 | 256 | def forward( 257 | self, 258 | x, 259 | media, 260 | media_locations=None, 261 | attend_previous=True, 262 | ): 263 | x = ( 264 | self.attn( 265 | x, 266 | media, 267 | media_locations=media_locations, 268 | attend_previous=attend_previous, 269 | ) 270 | * self.attn_gate.tanh() 271 | + x 272 | ) 273 | x = self.ff(x) * self.ff_gate.tanh() + x 274 | 275 | return x 276 | -------------------------------------------------------------------------------- /src/Model/RadFM/multimodality_model.py: -------------------------------------------------------------------------------- 1 | # Import necessary libraries 2 | from torch import nn 3 | from transformers.models.llama import LlamaForCausalLM 4 | from .my_embedding_layer import MyEmbedding 5 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss 6 | import tqdm.auto as tqdm 7 | import torch.nn as nn 8 | import torch 9 | from torch.utils.checkpoint import checkpoint 10 | from torch.autograd import Variable 11 | import numpy as np 12 | 13 | class MultiLLaMAForCausalLM(nn.Module): 14 | """ 15 | A multimodal LLaMA model that combines language and vision inputs 16 | for causal language modeling tasks. 
17 | """ 18 | def __init__(self, lang_model_path): 19 | """ 20 | Initialize the multimodal model. 21 | 22 | Args: 23 | lang_model_path (str): Path to the pretrained language model 24 | """ 25 | super(MultiLLaMAForCausalLM, self).__init__() 26 | 27 | # Load pretrained LLaMA model 28 | self.lang_model = LlamaForCausalLM.from_pretrained( 29 | lang_model_path, 30 | ) 31 | 32 | # Enable gradient checkpointing for memory efficiency 33 | self.lang_model.gradient_checkpointing_enable() 34 | self.lang_model.enable_input_require_grads() 35 | 36 | # Initialize custom embedding layer and share weights with language model 37 | self.embedding_layer = MyEmbedding() 38 | self.embedding_layer.weight = self.lang_model.get_input_embeddings().weight 39 | 40 | # Set model dimensions 41 | self.hidden_dim = 5120 42 | self.voc_size = 32000 43 | 44 | def forward(self, lang_x, vision_x, attention_mask, labels, loss_reweight, key_words_query): 45 | """ 46 | Forward pass for the multimodal model. 47 | 48 | Args: 49 | lang_x: Language input tokens 50 | vision_x: Vision input features 51 | attention_mask: Attention mask for language inputs 52 | labels: Target labels for language modeling 53 | loss_reweight: Weights for calculating loss (to prioritize certain tokens) 54 | key_words_query: Query for highlighting important words 55 | 56 | Returns: 57 | Dictionary containing model outputs including loss and logits 58 | """ 59 | if labels.shape == lang_x.shape: 60 | # Set embedding mode to handle text inputs 61 | self.embedding_layer.flag = 'Text' 62 | 63 | # Get embeddings and matching loss from embedding layer 64 | input_embedding, loss_match = self.embedding_layer(lang_x, vision_x, key_words_query) 65 | 66 | # Forward pass through the language model 67 | output = self.lang_model(inputs_embeds=input_embedding, attention_mask=attention_mask, labels=labels) 68 | logits = output['logits'] 69 | 70 | # Initialize regularization loss 71 | loss_reg = None 72 | if labels is not None: 73 | # Shift logits and labels for next-token prediction 74 | shift_logits = logits[..., :-1, :].contiguous() 75 | shift_labels = labels[..., 1:].contiguous() 76 | shift_loss_reweight = loss_reweight[..., 1:].contiguous() 77 | 78 | # Prepare for loss calculation 79 | loss_fct = CrossEntropyLoss(reduction='none') 80 | shift_logits = shift_logits.view(-1, self.voc_size) 81 | shift_labels = shift_labels.view(-1) 82 | shift_loss_reweight = shift_loss_reweight.view(-1) 83 | 84 | # Ensure tensors are on the same device 85 | shift_labels = shift_labels.to(shift_logits.device) 86 | shift_loss_reweight = shift_loss_reweight.to(shift_logits.device) 87 | 88 | # Calculate weighted cross-entropy loss 89 | loss_reg = loss_fct(shift_logits, shift_labels) 90 | loss_reg = torch.sum(shift_loss_reweight * loss_reg) / torch.sum(shift_loss_reweight) 91 | 92 | # Combine losses 93 | loss = loss_reg 94 | if loss_match is not None: 95 | loss = 0.8 * loss + 0.2 * loss_match 96 | 97 | # Calculate accuracy metrics 98 | logits = output['logits'][..., :-1, :].contiguous().detach() 99 | total = len(labels) 100 | predictions = torch.argmax(logits, dim=-1) 101 | labels = labels[..., 1:].contiguous() 102 | 103 | # Count correct predictions (ignoring padding tokens with -100) 104 | Acc = torch.sum(torch.all(torch.logical_or(predictions == labels, labels == -100), dim=-1)) 105 | Accuracy = Acc / total 106 | 107 | return dict( 108 | # loss_reg = loss_reg, 109 | # loss_matching = loss_matching, 110 | logits=Accuracy, 111 | loss=output['loss'], 112 | ) 113 | 114 | ### useless for now 
ignore the folowing codes ### 115 | # if labels.shape == vision_x.shape: 116 | # self.embedding_layer.flag = 'Seg' 117 | # input_embedding = self.embedding_layer(lang_x, vision_x) 118 | 119 | def generate(self, lang_x, vision_x): 120 | """ 121 | Generate text based on language and vision inputs. 122 | 123 | Args: 124 | lang_x: Language input tokens 125 | vision_x: Vision input features 126 | 127 | Returns: 128 | Generated token sequence 129 | """ 130 | # Set embedding mode to text generation 131 | self.embedding_layer.flag = 'Text' 132 | 133 | with torch.no_grad(): 134 | # Get embeddings from the embedding layer 135 | input_embedding, _ = self.embedding_layer(lang_x, vision_x) 136 | 137 | # Generate text using language model 138 | generation = self.lang_model.generate( 139 | inputs_embeds=input_embedding, 140 | max_new_tokens=200, 141 | top_k=50 142 | ) 143 | 144 | return generation -------------------------------------------------------------------------------- /src/Model/RadFM/my_embedding_layer.py: -------------------------------------------------------------------------------- 1 | # Import necessary libraries 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch 5 | from .helpers import PerceiverResampler 6 | from .utils import get_visual_encoder 7 | from einops import rearrange, repeat 8 | from einops_exts import rearrange_many 9 | import torchvision 10 | from .vit_3d import ViT 11 | from einops.layers.torch import Rearrange 12 | from .transformer_decoder import TransformerDecoder, TransformerDecoderLayer 13 | from torch.utils.checkpoint import checkpoint 14 | from torch.autograd import Variable 15 | import random 16 | from transformers import AutoTokenizer, AutoModel 17 | 18 | class MyEmbedding(nn.Module): 19 | """ 20 | Custom embedding layer for multimodal inputs that combines text and vision features. 21 | """ 22 | def __init__(self, num_embeddings=32000, embedding_dim=5120, perceiver_num=32, vis_dim=768, 23 | patch_size=32, frame_patch_size=4, seg_channel=256): 24 | """ 25 | Initialize the multimodal embedding layer. 
26 | 27 | Args: 28 | num_embeddings (int): Size of vocabulary for text embeddings 29 | embedding_dim (int): Dimension of output embeddings 30 | perceiver_num (int): Number of latent vectors in perceiver 31 | vis_dim (int): Dimension of vision features 32 | patch_size (int): Size of image patches 33 | frame_patch_size (int): Size of 3D frame patches 34 | seg_channel (int): Number of segmentation channels 35 | """ 36 | super().__init__() 37 | self.num_embeddings = num_embeddings 38 | self.embedding_dim = embedding_dim 39 | # Standard embedding weight matrix for text tokens 40 | self.weight = nn.Parameter(torch.torch.randn((num_embeddings, embedding_dim))) 41 | # Special token weights for figures/images 42 | self.figure_token_weight = nn.Parameter(torch.randn((2, embedding_dim))) 43 | self.flag = 'Text' # Mode flag: 'Text' or 'Seg' 44 | self.patch_size = patch_size 45 | self.frame_patch_size = frame_patch_size 46 | self.seg_channel = seg_channel 47 | 48 | ## the MedKEBERT can be downloaded from https://huggingface.co/xmcmic/Med-KEBERT/tree/main ## 49 | # Initialize medical domain BERT model for keyword understanding 50 | self.bert_tokenizer = AutoTokenizer.from_pretrained("xmcmic/Med-KEBERT") 51 | self.bert_model = AutoModel.from_pretrained("xmcmic/Med-KEBERT") 52 | # Project BERT outputs to vision feature space 53 | self.bert_projection_fc = nn.Linear(768, vis_dim) 54 | 55 | # 3D Vision Transformer for processing volumetric medical images 56 | self.vision_encoder = ViT( 57 | image_size=512, # image size 58 | frames=512, # max number of frames 59 | image_patch_size=patch_size, # image patch size 60 | frame_patch_size=frame_patch_size, # frame patch size 61 | dim=vis_dim, 62 | depth=12, 63 | heads=8, 64 | mlp_dim=2048, 65 | dropout=0.1, 66 | emb_dropout=0.1 67 | ) 68 | 69 | # Upscaling layers for vision features (used in segmentation mode) 70 | self.output_upscaling = nn.Sequential( 71 | nn.ConvTranspose3d(vis_dim, vis_dim // 4, kernel_size=2, stride=2), 72 | nn.BatchNorm3d(vis_dim // 4), 73 | nn.GELU(), 74 | nn.ConvTranspose3d(vis_dim // 4, vis_dim // 8, kernel_size=2, stride=2), 75 | nn.GELU(), 76 | ) 77 | 78 | # Transformer decoder for cross-attention between text and vision 79 | decoder_layer = TransformerDecoderLayer(d_model=vis_dim, nhead=8, normalize_before=True) 80 | decoder_norm = nn.LayerNorm(vis_dim) 81 | self.transformer_decoder = TransformerDecoder(decoder_layer=decoder_layer, num_layers=4, norm=decoder_norm) 82 | 83 | # MLP for processing transformer decoder outputs 84 | self.transformer_decoder_mlp = nn.Sequential( 85 | nn.Linear(vis_dim, vis_dim // 4), 86 | nn.GELU(), 87 | nn.Linear(vis_dim // 4, vis_dim // 8), 88 | nn.GELU(), 89 | ) 90 | self.vis_dim = vis_dim 91 | 92 | # Perceiver resampler to reduce sequence length of vision features 93 | self.perceiver = PerceiverResampler(dim=self.vis_dim, num_latents=perceiver_num) 94 | # Final projection to embedding dimension 95 | self.fc = nn.Linear(self.vis_dim, self.embedding_dim) 96 | # Classification head for matching keywords 97 | self.cls_head = nn.Linear(self.vis_dim // 8, 1) 98 | 99 | 100 | def forward(self, text_input, vision_x, key_words_query=None): 101 | """ 102 | Forward pass for the embedding layer. 
103 | 104 | Args: 105 | text_input: Text token indices [B, L] 106 | vision_x: Visual input features [B, S, C, H, W, D] 107 | key_words_query: Optional list of key words for contrastive learning 108 | 109 | Returns: 110 | tuple: (output_embeddings, loss_matching) 111 | - output_embeddings: Combined embeddings for text and vision 112 | - loss_matching: Contrastive loss for keyword matching (or None) 113 | """ 114 | if self.flag == 'Text': 115 | # Process in text mode 116 | B, S, C, H, W, D = vision_x.shape 117 | # Reshape for batch processing 118 | vision_x = rearrange(vision_x, "b S c h w d-> (b S) c h w d") 119 | 120 | # Process through vision encoder 121 | vision_x, pos_embedding = self.vision_encoder(vision_x) 122 | 123 | # Reshape back to batch format 124 | vision_x = rearrange(vision_x, "(b s F) v d -> b s F v d", b=B, s=S, F=1) 125 | 126 | loss_matching = None 127 | 128 | if key_words_query is not None: 129 | ## we do not use the following parts in final version. 130 | ## You can quota the following codes and if so the bert models will be useless. 131 | # key_words_query list[list[str]] B, words, each word matches corresponding vision_x embedding 132 | 133 | # Extract and deduplicate keywords 134 | query_words = [item for sublist in key_words_query for item in sublist] 135 | query_words = list(set(query_words)) 136 | 137 | # Limit number of keywords to process 138 | if len(query_words) > 16: 139 | random.shuffle(query_words) 140 | query_words = query_words[0:16] 141 | 142 | if query_words != []: 143 | # Create binary labels for contrastive learning 144 | contrastive_labels = torch.zeros(B, len(query_words)) # B Q 145 | for i, sublist in enumerate(key_words_query): 146 | for j, item in enumerate(query_words): 147 | if item in sublist: 148 | contrastive_labels[i, j] = 1 149 | contrastive_labels = contrastive_labels.to(vision_x.dtype).to(vision_x.device) 150 | 151 | # Get BERT embeddings for keywords 152 | with torch.no_grad(): 153 | query_words_embedding = self.bert_tokenizer( 154 | query_words, 155 | padding='max_length', 156 | truncation=True, 157 | max_length=256, 158 | return_tensors="pt" 159 | ) 160 | query_words_embedding = self.bert_model( 161 | input_ids=query_words_embedding['input_ids'].to(vision_x.device), 162 | attention_mask=query_words_embedding['attention_mask'].to(vision_x.device) 163 | )['last_hidden_state'][:, 0, :].to(vision_x.dtype).to(vision_x.device) # Q,D 164 | 165 | # Project BERT embeddings to vision space 166 | query_words_embedding = self.bert_projection_fc(query_words_embedding) 167 | query_words_embedding = query_words_embedding.unsqueeze(0).repeat(B, 1, 1) # B,Q,D 168 | _, N, _ = query_words_embedding.shape 169 | 170 | # Pool vision features 171 | image_embedding = vision_x.mean(dim=1) # B V D average pooling to remove multimodality 172 | image_embedding = rearrange(image_embedding, "b F v d -> b (F v) d") 173 | pos_embedding = rearrange(pos_embedding, "(b s) v d -> b s v d", b=B, s=S)[:, 0, :, :] 174 | 175 | # Prepare inputs for transformer decoder 176 | image_embedding = image_embedding.transpose(0, 1) # (H/P W/P D/P) B D 177 | pos_embedding = pos_embedding.transpose(0, 1) # (H/P W/P D/P) B D 178 | query_words_embedding = query_words_embedding.transpose(0, 1) # N B D 179 | 180 | # Cross-attention between keywords and image features 181 | oo_embedding, _ = self.transformer_decoder( 182 | query_words_embedding, image_embedding, pos=pos_embedding 183 | ) 184 | oo_embedding = oo_embedding.transpose(0, 1) # B Q D 185 | oo_embedding = rearrange(oo_embedding, 'b 
n d -> (b n) d') 186 | oo_embedding = self.transformer_decoder_mlp(oo_embedding) 187 | oo_embedding = self.cls_head(oo_embedding).mean(dim=-1) 188 | oo_embedding = rearrange(oo_embedding, '(b n) -> b n', b=B, n=N) # B Q 189 | 190 | # Calculate contrastive loss 191 | loss_matching = F.binary_cross_entropy_with_logits(oo_embedding, contrastive_labels) 192 | 193 | # Process vision features through perceiver resampler 194 | vision_x = self.perceiver(vision_x) # reshapes to (b, S, n, d) 195 | 196 | n = vision_x.shape[2] 197 | 198 | # Project vision features to embedding dimension 199 | vision_x = rearrange(vision_x, "b s n d -> (b s n) d") 200 | vision_x = self.fc(vision_x) 201 | vision_x = rearrange(vision_x, "(b T) d -> b T d", b=B, T=n*S) 202 | 203 | # Combine text and vision embeddings 204 | embedding_weight = torch.cat([self.weight, self.figure_token_weight], dim=0) 205 | embedding_weight = embedding_weight.unsqueeze(0).repeat(B, 1, 1) 206 | embedding_weight = torch.cat([embedding_weight, vision_x], dim=1) 207 | 208 | # Convert text indices to one-hot and compute final embeddings 209 | text_input = F.one_hot(text_input, embedding_weight.shape[1]).to(vision_x.dtype).to(vision_x.device) 210 | out_put = torch.matmul(text_input, embedding_weight) 211 | 212 | ## useless for now. ignore the folowing code## 213 | # if self.flag == 'Seg': 214 | # B,C,H,W,D = vision_x.shape 215 | # _,N,_ = text_input.shape 216 | # latent_embedding, pos_embedding = self.vision_encoder(vision_x) # B (H/P W/P D/P) D 217 | 218 | # image_embedding = latent_embedding.transpose(0,1) # (H/P W/P D/P) B D 219 | # pos_embedding = pos_embedding.transpose(0,1) # (H/P W/P D/P) B D 220 | # text_input = text_input.transpose(0,1) # N B D 221 | 222 | # mask_embedding,_ = self.transformer_decoder(text_input, image_embedding, pos = pos_embedding) 223 | # mask_embedding = mask_embedding.transpose(0,1) # B N D 224 | # mask_embedding = rearrange(mask_embedding, 'b n d -> (b n) d') 225 | # mask_embedding = self.transformer_decoder_mlp(mask_embedding) 226 | # mask_embedding = rearrange(mask_embedding, '(b n) d -> b n d', b=B, n=N,d = self.vis_dim // 8) 227 | 228 | # vision_x = rearrange(latent_embedding,'b (h w d) c -> b c h w d', h = (H // self.patch_size), w = (W // self.patch_size), d = (D // self.frame_patch_size), c=self.vis_dim) 229 | # vision_x = self.output_upscaling(vision_x) #B C H/4 W/4 D/4 230 | # out_put = torch.einsum('bchwd,bnc->bnhwd', vision_x, mask_embedding) 231 | 232 | return out_put, loss_matching 233 | 234 | # model = MyEmbedding(vision_encoder_path = '') 235 | # text_input = torch.randint(low=0, high=3210, size=(4,2048)) 236 | # image_input = torch.randn((4,3,3,512,512,4)) 237 | # key_words_query = [[],[],[],['consoliation']] 238 | # print(model(text_input, image_input, key_words_query)) -------------------------------------------------------------------------------- /src/Model/RadFM/position_encoding.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | Various positional encodings for the transformer. 4 | """ 5 | import math 6 | import torch 7 | from torch import nn 8 | from einops.layers.torch import Rearrange 9 | from einops import rearrange, repeat 10 | 11 | class PositionEmbeddingSine(nn.Module): 12 | """ 13 | This is a more standard version of the position embedding, very similar to the one 14 | used by the Attention is all you need paper, generalized to work on images. 
15 | """ 16 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): 17 | super().__init__() 18 | self.num_pos_feats = num_pos_feats 19 | self.temperature = temperature 20 | self.normalize = normalize 21 | if scale is not None and normalize is False: 22 | raise ValueError("normalize should be True if scale is passed") 23 | if scale is None: 24 | scale = 2 * math.pi 25 | self.scale = scale 26 | 27 | def forward(self, tensor_list): 28 | x = tensor_list.tensors 29 | mask = tensor_list.mask 30 | assert mask is not None 31 | not_mask = ~mask 32 | y_embed = not_mask.cumsum(1, dtype=torch.float32) 33 | x_embed = not_mask.cumsum(2, dtype=torch.float32) 34 | if self.normalize: 35 | eps = 1e-6 36 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale 37 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale 38 | 39 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) 40 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 41 | 42 | pos_x = x_embed[:, :, :, None] / dim_t 43 | pos_y = y_embed[:, :, :, None] / dim_t 44 | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) 45 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) 46 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 47 | return pos 48 | 49 | 50 | class PositionEmbeddingLearned(nn.Module): 51 | """ 52 | Absolute pos embedding, learned. 53 | """ 54 | def __init__(self, num_pos_feats=256): 55 | super().__init__() 56 | self.row_embed = nn.Embedding(50, num_pos_feats) 57 | self.col_embed = nn.Embedding(50, num_pos_feats) 58 | self.reset_parameters() 59 | 60 | def reset_parameters(self): 61 | nn.init.uniform_(self.row_embed.weight) 62 | nn.init.uniform_(self.col_embed.weight) 63 | 64 | def forward(self, tensor_list): 65 | x = tensor_list.tensors 66 | h, w = x.shape[-2:] 67 | i = torch.arange(w, device=x.device) 68 | j = torch.arange(h, device=x.device) 69 | x_emb = self.col_embed(i) 70 | y_emb = self.row_embed(j) 71 | pos = torch.cat([ 72 | x_emb.unsqueeze(0).repeat(h, 1, 1), 73 | y_emb.unsqueeze(1).repeat(1, w, 1), 74 | ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1) 75 | return pos 76 | 77 | class PositionEmbeddingLearned3d(nn.Module): 78 | """ 79 | Absolute pos embedding, learned. 
80 | """ 81 | def __init__(self, num_pos_feats=256,h_patch_num = 16, w_patch_num = 16,d_patch_num = 64): 82 | super().__init__() 83 | self.h_patch_num = h_patch_num 84 | self.w_patch_num = w_patch_num 85 | self.d_patch_num = d_patch_num 86 | self.row_embed = nn.Embedding(h_patch_num, num_pos_feats) 87 | self.col_embed = nn.Embedding(w_patch_num, num_pos_feats) 88 | self.dep_embed = nn.Embedding(d_patch_num, num_pos_feats) 89 | self.reset_parameters() 90 | 91 | def reset_parameters(self): 92 | nn.init.uniform_(self.row_embed.weight) 93 | nn.init.uniform_(self.col_embed.weight) 94 | nn.init.uniform_(self.dep_embed.weight) 95 | 96 | def forward(self, B, h, w, d,x): 97 | i = (torch.arange(h, device=x.device) + 1)* (self.h_patch_num // h) -1 98 | j = (torch.arange(w, device=x.device) + 1)* (self.w_patch_num // w) -1 99 | k = (torch.arange(d, device=x.device) + 1)* (self.d_patch_num // d) -1 100 | x_emb = self.row_embed(i).unsqueeze(1).unsqueeze(2).repeat(1,w,d,1) 101 | y_emb = self.col_embed(j).unsqueeze(0).unsqueeze(2).repeat(h,1,d,1) 102 | z_emb = self.dep_embed(k).unsqueeze(0).unsqueeze(1).repeat(h,w,1,1) 103 | pos = torch.cat([x_emb,y_emb,z_emb,], dim=-1).unsqueeze(0).repeat(B, 1, 1, 1, 1) 104 | pos = rearrange(pos,'b h w d c -> b (h w d) c') 105 | return pos 106 | 107 | def build_position_encoding(args): 108 | N_steps = args.hidden_dim // 2 109 | if args.position_embedding in ('v2', 'sine'): 110 | # TODO find a better way of exposing other arguments 111 | position_embedding = PositionEmbeddingSine(N_steps, normalize=True) 112 | elif args.position_embedding in ('v3', 'learned'): 113 | position_embedding = PositionEmbeddingLearned(N_steps) 114 | else: 115 | raise ValueError(f"not supported {args.position_embedding}") 116 | 117 | return position_embedding 118 | 119 | # Pos = PositionEmbeddingLearned3d() 120 | # x = torch.randn((8,3,32,32,1)) 121 | # print(Pos(8,16,16,1,x)) -------------------------------------------------------------------------------- /src/Model/RadFM/transformer_decoder.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code modified from DETR transformer: 3 | https://github.com/facebookresearch/detr 4 | Copyright (c) Facebook, Inc. and its affiliates.
All Rights Reserved 5 | """ 6 | 7 | import copy 8 | from typing import Optional, List 9 | import pickle as cp 10 | 11 | import torch 12 | import torch.nn.functional as F 13 | from torch import nn, Tensor 14 | 15 | 16 | class TransformerDecoder(nn.Module): 17 | 18 | def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False): 19 | super().__init__() 20 | self.layers = _get_clones(decoder_layer, num_layers) 21 | self.num_layers = num_layers 22 | self.norm = norm 23 | self.return_intermediate = return_intermediate 24 | 25 | def forward(self, tgt, memory, 26 | tgt_mask: Optional[Tensor] = None, 27 | memory_mask: Optional[Tensor] = None, 28 | tgt_key_padding_mask: Optional[Tensor] = None, 29 | memory_key_padding_mask: Optional[Tensor] = None, 30 | pos: Optional[Tensor] = None, 31 | query_pos: Optional[Tensor] = None): 32 | output = tgt 33 | T,B,C = memory.shape 34 | intermediate = [] 35 | atten_layers = [] 36 | for n,layer in enumerate(self.layers): 37 | 38 | residual=True 39 | output,ws = layer(output, memory, tgt_mask=tgt_mask, 40 | memory_mask=memory_mask, 41 | tgt_key_padding_mask=tgt_key_padding_mask, 42 | memory_key_padding_mask=memory_key_padding_mask, 43 | pos=pos, query_pos=query_pos,residual=residual) 44 | atten_layers.append(ws) 45 | if self.return_intermediate: 46 | intermediate.append(self.norm(output)) 47 | if self.norm is not None: 48 | output = self.norm(output) 49 | if self.return_intermediate: 50 | intermediate.pop() 51 | intermediate.append(output) 52 | 53 | if self.return_intermediate: 54 | return torch.stack(intermediate) 55 | return output,atten_layers 56 | 57 | 58 | 59 | class TransformerDecoderLayer(nn.Module): 60 | 61 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, 62 | activation="relu", normalize_before=False): 63 | super().__init__() 64 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 65 | self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 66 | # Implementation of Feedforward model 67 | self.linear1 = nn.Linear(d_model, dim_feedforward) 68 | self.dropout = nn.Dropout(dropout) 69 | self.linear2 = nn.Linear(dim_feedforward, d_model) 70 | 71 | self.norm1 = nn.LayerNorm(d_model) 72 | self.norm2 = nn.LayerNorm(d_model) 73 | self.norm3 = nn.LayerNorm(d_model) 74 | self.dropout1 = nn.Dropout(dropout) 75 | self.dropout2 = nn.Dropout(dropout) 76 | self.dropout3 = nn.Dropout(dropout) 77 | 78 | self.activation = _get_activation_fn(activation) 79 | self.normalize_before = normalize_before 80 | def with_pos_embed(self, tensor, pos: Optional[Tensor]): 81 | return tensor if pos is None else tensor + pos 82 | 83 | def forward_post(self, tgt, memory, 84 | tgt_mask: Optional[Tensor] = None, 85 | memory_mask: Optional[Tensor] = None, 86 | tgt_key_padding_mask: Optional[Tensor] = None, 87 | memory_key_padding_mask: Optional[Tensor] = None, 88 | pos: Optional[Tensor] = None, 89 | query_pos: Optional[Tensor] = None, 90 | residual=True): 91 | q = k = self.with_pos_embed(tgt, query_pos) 92 | tgt2,ws = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask, 93 | key_padding_mask=tgt_key_padding_mask) 94 | tgt = self.norm1(tgt) 95 | tgt2,ws = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos), 96 | key=self.with_pos_embed(memory, pos), 97 | value=memory, attn_mask=memory_mask, 98 | key_padding_mask=memory_key_padding_mask) 99 | 100 | 101 | # attn_weights [B,NUM_Q,T] 102 | tgt = tgt + self.dropout2(tgt2) 103 | tgt = self.norm2(tgt) 104 | tgt2 = 
self.linear2(self.dropout(self.activation(self.linear1(tgt)))) 105 | tgt = tgt + self.dropout3(tgt2) 106 | tgt = self.norm3(tgt) 107 | return tgt,ws 108 | 109 | def forward_pre(self, tgt, memory, 110 | tgt_mask: Optional[Tensor] = None, 111 | memory_mask: Optional[Tensor] = None, 112 | tgt_key_padding_mask: Optional[Tensor] = None, 113 | memory_key_padding_mask: Optional[Tensor] = None, 114 | pos: Optional[Tensor] = None, 115 | query_pos: Optional[Tensor] = None): 116 | tgt2 = self.norm1(tgt) 117 | q = k = self.with_pos_embed(tgt2, query_pos) 118 | tgt2,ws = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, 119 | key_padding_mask=tgt_key_padding_mask) 120 | tgt = tgt + self.dropout1(tgt2) 121 | tgt2 = self.norm2(tgt) 122 | tgt2,attn_weights = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos), 123 | key=self.with_pos_embed(memory, pos), 124 | value=memory, attn_mask=memory_mask, 125 | key_padding_mask=memory_key_padding_mask) 126 | tgt = tgt + self.dropout2(tgt2) 127 | tgt2 = self.norm3(tgt) 128 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) 129 | tgt = tgt + self.dropout3(tgt2) 130 | return tgt,attn_weights 131 | 132 | def forward(self, tgt, memory, 133 | tgt_mask: Optional[Tensor] = None, 134 | memory_mask: Optional[Tensor] = None, 135 | tgt_key_padding_mask: Optional[Tensor] = None, 136 | memory_key_padding_mask: Optional[Tensor] = None, 137 | pos: Optional[Tensor] = None, 138 | query_pos: Optional[Tensor] = None, 139 | residual=True): 140 | if self.normalize_before: 141 | return self.forward_pre(tgt, memory, tgt_mask, memory_mask, 142 | tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) 143 | return self.forward_post(tgt, memory, tgt_mask, memory_mask, 144 | tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos,residual) 145 | 146 | 147 | def _get_clones(module, N): 148 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 149 | 150 | 151 | 152 | def _get_activation_fn(activation): 153 | """Return an activation function given a string""" 154 | if activation == "relu": 155 | return F.relu 156 | if activation == "gelu": 157 | return F.gelu 158 | if activation == "glu": 159 | return F.glu 160 | raise RuntimeError(F"activation should be relu/gelu, not {activation}.") 161 | -------------------------------------------------------------------------------- /src/Model/RadFM/utils.py: -------------------------------------------------------------------------------- 1 | from .blocks import ModifiedResNet,PMC_CLIP_cfg 2 | import torch 3 | from torchvision import transforms 4 | from PIL import Image 5 | import torch.nn as nn 6 | def extend_instance(obj, mixin): 7 | """Apply mixins to a class instance after creation""" 8 | base_cls = obj.__class__ 9 | base_cls_name = obj.__class__.__name__ 10 | obj.__class__ = type( 11 | base_cls_name, (mixin, base_cls), {} 12 | ) # mixin needs to go first for our forward() logic to work 13 | 14 | 15 | def getattr_recursive(obj, att): 16 | """ 17 | Return nested attribute of obj 18 | Example: getattr_recursive(obj, 'a.b.c') is equivalent to obj.a.b.c 19 | """ 20 | if att == "": 21 | return obj 22 | i = att.find(".") 23 | if i < 0: 24 | return getattr(obj, att) 25 | else: 26 | return getattr_recursive(getattr(obj, att[:i]), att[i + 1 :]) 27 | 28 | 29 | def setattr_recursive(obj, att, val): 30 | """ 31 | Set nested attribute of obj 32 | Example: setattr_recursive(obj, 'a.b.c', val) is equivalent to obj.a.b.c = val 33 | """ 34 | if "." 
in att: 35 | obj = getattr_recursive(obj, ".".join(att.split(".")[:-1])) 36 | setattr(obj, att.split(".")[-1], val) 37 | 38 | 39 | 40 | def get_visual_encoder(model_str): 41 | """ 42 | Args: 43 | str (_type_): str_to_model_path 44 | Return: 45 | vision_model, visual_dim, img_preprocessor 46 | """ 47 | normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) 48 | img_preprocessor = transforms.Compose([ 49 | transforms.Resize((512,512), interpolation=Image.BICUBIC), 50 | transforms.ToTensor(), 51 | normalize, 52 | ]) 53 | if 'PMC-CLIP' in model_str: 54 | #vision_cfg = json.load(open(model_args.visual_model_config,'r'))['vision_cfg'] 55 | vision_cfg = PMC_CLIP_cfg() 56 | vision_heads = vision_cfg.width * 32 // vision_cfg.head_width 57 | vision_model = ModifiedResNet( 58 | layers=vision_cfg.layers, 59 | heads=vision_heads, 60 | output_dim = 768, 61 | image_size=vision_cfg.image_size, 62 | width=vision_cfg.width 63 | ) 64 | vision_model = vision_load_pretrain(vision_model,model_str) 65 | vision_model = nn.Sequential(*list(vision_model.children())[:-2]) 66 | visual_dim = 1024 67 | return vision_model,visual_dim,img_preprocessor 68 | 69 | def vision_load_pretrain(resnet,model_path): 70 | checkpoint = torch.load(model_path, map_location='cpu') 71 | state_dict = checkpoint['state_dict'] 72 | state_dict = {k.replace('module.visual.',''): v for k, v in state_dict.items() if '.visual' in k} 73 | resnet.load_state_dict(state_dict) 74 | return resnet 75 | -------------------------------------------------------------------------------- /src/Model/RadFM/vit_3d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from einops import rearrange, repeat 5 | from einops.layers.torch import Rearrange 6 | from .position_encoding import PositionEmbeddingLearned3d 7 | 8 | # helpers 9 | 10 | def pair(t): 11 | return t if isinstance(t, tuple) else (t, t) 12 | 13 | # classes 14 | 15 | class PreNorm(nn.Module): 16 | def __init__(self, dim, fn): 17 | super().__init__() 18 | self.norm = nn.LayerNorm(dim) 19 | self.fn = fn 20 | def forward(self, x, **kwargs): 21 | return self.fn(self.norm(x), **kwargs) 22 | 23 | class FeedForward(nn.Module): 24 | def __init__(self, dim, hidden_dim, dropout = 0.): 25 | super().__init__() 26 | self.net = nn.Sequential( 27 | nn.Linear(dim, hidden_dim), 28 | nn.GELU(), 29 | nn.Dropout(dropout), 30 | nn.Linear(hidden_dim, dim), 31 | nn.Dropout(dropout) 32 | ) 33 | def forward(self, x): 34 | return self.net(x) 35 | 36 | class Attention(nn.Module): 37 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): 38 | super().__init__() 39 | inner_dim = dim_head * heads 40 | project_out = not (heads == 1 and dim_head == dim) 41 | 42 | self.heads = heads 43 | self.scale = dim_head ** -0.5 44 | 45 | self.attend = nn.Softmax(dim = -1) 46 | self.dropout = nn.Dropout(dropout) 47 | 48 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) 49 | 50 | self.to_out = nn.Sequential( 51 | nn.Linear(inner_dim, dim), 52 | nn.Dropout(dropout) 53 | ) if project_out else nn.Identity() 54 | 55 | def forward(self, x): 56 | qkv = self.to_qkv(x).chunk(3, dim = -1) 57 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) 58 | 59 | dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale 60 | 61 | attn = self.attend(dots) 62 | attn = self.dropout(attn) 63 | 64 | out = torch.matmul(attn, v) 65 | out = rearrange(out, 'b h n d -> b n (h d)') 66 | return self.to_out(out) 67 | 68 | 
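# The Transformer class below stacks `depth` pre-norm blocks over token sequences
# of shape (b, n, dim); each block applies self-attention followed by a
# feed-forward MLP, both with residual connections:
#   x = x + Attention(LayerNorm(x))
#   x = x + FeedForward(LayerNorm(x))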
class Transformer(nn.Module): 69 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): 70 | super().__init__() 71 | self.layers = nn.ModuleList([]) 72 | for _ in range(depth): 73 | self.layers.append(nn.ModuleList([ 74 | PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), 75 | PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)) 76 | ])) 77 | def forward(self, x): 78 | for attn, ff in self.layers: 79 | x = attn(x) + x 80 | x = ff(x) + x 81 | return x 82 | 83 | class ViT(nn.Module): 84 | def __init__(self, *, image_size, image_patch_size, frames, frame_patch_size, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.): 85 | super().__init__() 86 | image_height, image_width = pair(image_size) 87 | patch_height, patch_width = pair(image_patch_size) 88 | 89 | assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.' 90 | assert frames % frame_patch_size == 0, 'Frames must be divisible by frame patch size' 91 | 92 | self.patch_height = patch_height 93 | self.patch_width = patch_width 94 | self.frame_patch_size = frame_patch_size 95 | 96 | num_patches = (image_height // patch_height) * (image_width // patch_width) * (frames // frame_patch_size) 97 | patch_dim = channels * patch_height * patch_width * frame_patch_size 98 | 99 | assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)' 100 | 101 | self.to_patch_embedding = nn.Sequential( 102 | Rearrange('b c (h p1) (w p2) (f pf) -> b (h w f) (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size), 103 | nn.LayerNorm(patch_dim), 104 | nn.Linear(patch_dim, dim), 105 | nn.LayerNorm(dim), 106 | ) 107 | 108 | self.pos_embedding = PositionEmbeddingLearned3d(dim // 3,(image_height // patch_height), (image_width // patch_width), (frames // frame_patch_size)) 109 | self.dropout = nn.Dropout(emb_dropout) 110 | 111 | self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout) 112 | 113 | def forward(self, video): 114 | B, C, H, W, D = video.shape 115 | x = self.to_patch_embedding(video) 116 | b, n, _ = x.shape 117 | 118 | pos = self.pos_embedding(B, H // self.patch_height, W // self.patch_width, D // self.frame_patch_size,x) 119 | x += pos 120 | x = self.dropout(x) 121 | 122 | x = self.transformer(x) 123 | return x,pos 124 | -------------------------------------------------------------------------------- /src/My_Trainer/__pycache__/trainer.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chaoyi-wu/RadFM/9efba84cced9ffdcbce00d6005f255414ffa8c36/src/My_Trainer/__pycache__/trainer.cpython-39.pyc -------------------------------------------------------------------------------- /src/datasampler.py: -------------------------------------------------------------------------------- 1 | import torch.distributed as dist 2 | import math 3 | from torch.utils.data.sampler import Sampler 4 | from torch.utils.data.sampler import Sampler 5 | from torch.utils.data import DataLoader, DistributedSampler 6 | import random 7 | import torch 8 | from Dataset.multi_dataset import multi_dataset 9 | 10 | def make_batch(index_list, batch_size, drop_last): 11 | if drop_last: 12 | batches = [] 13 | whole_batch_num = len(index_list)//batch_size 14 | for _ in range(whole_batch_num): 15 | batches.append(index_list[batch_size*_:(batch_size*(_+1))]) 16 | else: 
17 | batches = [] 18 | whole_batch_num = math.ceil(len(index_list)/batch_size) 19 | for _ in range(whole_batch_num): 20 | batches.append(index_list[batch_size*_:(batch_size*(_+1))]) 21 | return batches 22 | 23 | def batch_generation(dataset,batch_size_2D, batch_size_3D,drop_last=False,shuffle = True, seed = 0): 24 | 25 | len_2D = len(dataset.data_whole_2D) 26 | len_3D = len(dataset.data_whole_3D) 27 | index_2D = list(range(len_2D)) 28 | index_3D = list(range(len_2D,(len_2D+len_3D))) 29 | assert len(index_2D) + len(index_3D) == len(dataset.data_whole) 30 | 31 | if shuffle: 32 | # deterministically shuffle based on epoch and seed 33 | g = torch.Generator() 34 | g.manual_seed(seed) 35 | random.shuffle(index_2D) 36 | random.shuffle(index_3D) 37 | 38 | batch_2D = make_batch(index_2D, batch_size_2D, drop_last) 39 | batch_3D = make_batch(index_3D, batch_size_3D, drop_last) 40 | 41 | batch_chunk = batch_2D + batch_3D 42 | return batch_chunk 43 | 44 | class My_DistributedBatchSampler(Sampler): 45 | """ Iterable wrapper that distributes data across multiple workers. 46 | 47 | Args: 48 | iterable (iterable) 49 | num_replicas (int, optional): Number of processes participating in distributed training. 50 | rank (int, optional): Rank of the current process within ``num_replicas``. 51 | 52 | Example: 53 | >>> list(DistributedSampler(range(10), num_replicas=2, rank=0)) 54 | [0, 2, 4, 6, 8] 55 | >>> list(DistributedSampler(range(10), num_replicas=2, rank=1)) 56 | [1, 3, 5, 7, 9] 57 | """ 58 | 59 | def __init__(self, dataset, num_replicas=None, rank=None, batch_size_2D = 4, batch_size_3D = 1, drop_last = False, shuffle = True, seed: int = 0): 60 | self.num_replicas = num_replicas 61 | self.rank = rank 62 | self.drop_last = drop_last 63 | self.shuffle = shuffle 64 | self.dataset = dataset 65 | self.batch_size_2D = batch_size_2D 66 | self.batch_size_3D = batch_size_3D 67 | self.seed = seed 68 | self.epoch = 0 69 | 70 | if num_replicas is None or rank is None: # pragma: no cover 71 | if not torch.distributed.is_initialized(): 72 | raise RuntimeError('Requires `torch.distributed` to be initialized.') 73 | 74 | self.num_replicas = ( 75 | torch.distributed.get_world_size() if num_replicas is None else num_replicas) 76 | self.rank = torch.distributed.get_rank() if rank is None else rank 77 | 78 | indices = batch_generation(self.dataset,self.batch_size_2D,self.batch_size_3D,self.drop_last,self.shuffle) 79 | if self.rank >= self.num_replicas: 80 | raise IndexError('`rank` must be smaller than the `num_replicas`.') 81 | 82 | if self.drop_last and len(indices) % self.num_replicas != 0: # type: ignore[arg-type] 83 | # Split to nearest available length that is evenly divisible. 84 | # This is to ensure each rank receives the same amount of data when 85 | # using this Sampler. 
86 | self.num_samples = math.ceil( 87 | (len(indices) - self.num_replicas) / self.num_replicas # type: ignore[arg-type] 88 | ) 89 | else: 90 | self.num_samples = math.ceil(len(indices) / self.num_replicas) # type: ignore[arg-type] 91 | self.total_size = self.num_samples * self.num_replicas 92 | 93 | def __iter__(self): 94 | indices = batch_generation(self.dataset,self.batch_size_2D,self.batch_size_3D,self.drop_last,self.shuffle,self.seed + self.epoch) 95 | # print(indices) 96 | if self.shuffle: 97 | # deterministically shuffle based on epoch and seed 98 | g = torch.Generator() 99 | g.manual_seed(self.seed + self.epoch) 100 | random.shuffle(indices) 101 | 102 | if not self.drop_last: 103 | # add extra samples to make it evenly divisible 104 | padding_size = self.total_size - len(indices) 105 | if padding_size <= len(indices): 106 | indices += indices[:padding_size] 107 | else: 108 | indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size] 109 | else: 110 | # remove tail of data to make it evenly divisible. 111 | indices = indices[:self.total_size] 112 | assert len(indices) == self.total_size 113 | 114 | # subsample 115 | indices = indices[self.rank:self.total_size:self.num_replicas] 116 | assert len(indices) == self.num_samples 117 | 118 | return iter(indices) 119 | 120 | def __len__(self): 121 | return self.num_samples 122 | 123 | def set_epoch(self, epoch: int) -> None: 124 | r""" 125 | Set the epoch for this sampler. 126 | 127 | When :attr:`shuffle=True`, this ensures all replicas 128 | use a different random ordering for each epoch. Otherwise, the next iteration of this 129 | sampler will yield the same ordering. 130 | 131 | Args: 132 | epoch (int): Epoch number. 133 | """ 134 | self.epoch = epoch 135 | 136 | 137 | # print(My_DistributedBatchSampler) 138 | # Train_dataset = multi_dataset(text_tokenizer = '/mnt/petrelfs/share_data/zhangxiaoman/CODE/RadFM/src/Language_models/tokenizer') 139 | 140 | # DDP_sample_0 = list(My_DistributedBatchSampler(dataset= Train_dataset , num_replicas = 32, rank = 0,)) 141 | # DDP_sample_1 = list(My_DistributedBatchSampler(dataset= Train_dataset , num_replicas = 32, rank = 1,)) 142 | 143 | # for ii in DDP_sample_0: 144 | # print(ii) 145 | 146 | # for ii in DDP_sample_1: 147 | # print(ii) 148 | 149 | -------------------------------------------------------------------------------- /src/test.py: -------------------------------------------------------------------------------- 1 | # Import necessary libraries for data processing, modeling, and utilities 2 | import tqdm.auto as tqdm 3 | import torch.nn.functional as F 4 | from typing import Optional, Dict, Sequence 5 | from typing import List, Optional, Tuple, Union 6 | import transformers 7 | from My_Trainer.trainer import Trainer 8 | from dataclasses import dataclass, field 9 | from Dataset.multi_dataset_test import multi_dataset 10 | from Model.RadFM.multimodality_model import MultiLLaMAForCausalLM 11 | from datasampler import My_DistributedBatchSampler 12 | import torch 13 | from torch.utils.data import DataLoader 14 | import csv 15 | import random 16 | import numpy as np 17 | 18 | def setup_seed(seed): 19 | """ 20 | Set random seeds for reproducibility across different libraries 21 | 22 | Args: 23 | seed: Integer seed value to use 24 | """ 25 | torch.manual_seed(seed) 26 | torch.cuda.manual_seed_all(seed) 27 | np.random.seed(seed) 28 | random.seed(seed) 29 | torch.backends.cudnn.deterministic = True 30 | 31 | # Set seed for reproducibility 32 | setup_seed(20) 33 | 34 | 35 | @dataclass 36 
| class ModelArguments: 37 | """ 38 | Arguments related to model paths and configuration 39 | """ 40 | lang_encoder_path: Optional[str] = field(default="/home/cs/leijiayu/wuchaoyi/book_pretrain/Results/Book_mix_2048_13B_full/checkpoint-45800") 41 | tokenizer_path: str = field(default='/home/cs/leijiayu/wuchaoyi/Finetune_LLAMA/LLAMA_Model/tokenizer', metadata={"help": "Path to the tokenizer data."}) 42 | #vision_encoder_path: str = field(default='/home/cs/leijiayu/wuchaoyi/multi_modal/src/PMC-CLIP/checkpoint.pt', metadata={"help": "Path to the vision_encoder."}) 43 | 44 | 45 | @dataclass 46 | class DataArguments: 47 | """ 48 | Arguments related to dataset configuration and testing modes 49 | """ 50 | Mode: Optional[str] = field(default="Train") 51 | test_split: Optional[str] = field(default="open") 52 | 53 | @dataclass 54 | class TrainingArguments(transformers.TrainingArguments): 55 | """ 56 | Custom training arguments extending HuggingFace's TrainingArguments 57 | with additional parameters for multimodal training 58 | """ 59 | remove_unused_columns: bool = field(default = False) 60 | batch_size_2D: int = field(default = 4) # Batch size for 2D data 61 | batch_size_3D: int = field(default = 1) # Batch size for 3D data 62 | output_dir: Optional[str] = field(default="/home/cs/leijiayu/wuchaoyi/multi_modal/src/Results/BLIP_overfit/") 63 | cache_dir: Optional[str] = field(default=None) 64 | optim: str = field(default="adamw_torch") 65 | 66 | 67 | @dataclass 68 | class DataCollator(object): 69 | """ 70 | Data collator for preparing batches of multimodal inputs for the model 71 | """ 72 | 73 | def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: 74 | # Extract different components from the input instances 75 | vision_xs, lang_xs, attention_masks, labels = tuple( 76 | [instance[key] for instance in instances] 77 | for key in ('vision_x','lang_x', 'attention_mask', 'labels') 78 | ) 79 | 80 | # Stack language tensors along batch dimension 81 | lang_xs = torch.cat([_.unsqueeze(0) for _ in lang_xs], dim=0) 82 | attention_masks = torch.cat([_.unsqueeze(0) for _ in attention_masks], dim=0) 83 | labels = torch.cat([_.unsqueeze(0) for _ in labels], dim=0) 84 | 85 | # Set target dimensions for resizing vision inputs 86 | target_H = 512 87 | target_W = 512 88 | target_D = 4 89 | MAX_D = 0 90 | 91 | # Reduce resolution for single samples to save memory 92 | if len(vision_xs) == 1: 93 | target_H = 256 94 | target_W = 256 95 | 96 | # Define possible depth values for 3D data 97 | D_list = list(range(4,65,4)) 98 | # Adjust depth values for large inputs 99 | if len(vision_xs) == 1: 100 | if vision_xs[0].shape[0] > 6: 101 | D_list = list(range(4,33,4)) 102 | 103 | # Find maximum depth in current batch 104 | for ii in vision_xs: 105 | try: 106 | D = ii.shape[-1] 107 | if D > MAX_D: 108 | MAX_D = D 109 | except: 110 | continue 111 | 112 | # Select closest target depth from available options 113 | for temp_D in D_list: 114 | if abs(temp_D - MAX_D) < abs(target_D - MAX_D): 115 | target_D = temp_D 116 | 117 | # Resize all vision inputs to target dimensions 118 | vision_xs = [torch.nn.functional.interpolate(s, size=(target_H, target_W, target_D)) for s in vision_xs] 119 | 120 | # Pad sequence for variable-length vision inputs 121 | vision_xs = torch.nn.utils.rnn.pad_sequence( 122 | vision_xs, batch_first=True, padding_value=0 123 | ) 124 | print(vision_xs.shape) 125 | 126 | # Return collated batch 127 | return dict( 128 | lang_x=lang_xs, 129 | vision_x=vision_xs, 130 | 
attention_mask=attention_masks, 131 | labels=labels, 132 | ) 133 | 134 | def main(): 135 | """ 136 | Main function to set up and run the inference process 137 | """ 138 | # Parse command-line arguments 139 | parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments)) 140 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 141 | 142 | # Set custom data sampler 143 | training_args.data_sampler = My_DistributedBatchSampler 144 | 145 | print("Setup Data") 146 | # Initialize test dataset with specified split 147 | Test_dataset = multi_dataset(text_tokenizer=model_args.tokenizer_path, test_split=data_args.test_split) 148 | 149 | # Configure DataLoader for test dataset 150 | Test_dataloader = DataLoader( 151 | Test_dataset, 152 | batch_size=1, 153 | num_workers=1, 154 | pin_memory=True, 155 | sampler=None, 156 | shuffle=True, 157 | collate_fn=None, 158 | drop_last=False, 159 | ) 160 | 161 | print("Setup Model") 162 | # Initialize the multimodal model 163 | model = MultiLLaMAForCausalLM( 164 | lang_model_path=model_args.lang_encoder_path, 165 | ) 166 | 167 | # Load pre-trained model checkpoint 168 | ckpt = torch.load('/gpfs/home/cs/leijiayu/wuchaoyi/wangyingjie/src/Results/backup/checkpoint-17600/pytorch_model.bin', map_location='cpu') 169 | # ckpt.pop('embedding_layer.figure_token_weight') 170 | model.load_state_dict(ckpt, strict=False) 171 | model = model.to('cuda') 172 | model.eval() # Set model to evaluation mode 173 | 174 | # Create output CSV file for results 175 | with open('output_whole_2_epoch' + data_args.test_split + '.csv', mode='w') as outfile: 176 | writer = csv.writer(outfile) 177 | writer.writerow(["Question", "Ground Truth", "Pred", 'belong_to']) 178 | cc = 0 179 | 180 | # Process each sample in the test dataset 181 | for sample in tqdm.tqdm(Test_dataloader): 182 | question = sample["question"] 183 | belong_to = sample['belong_to'] 184 | # img_pp = sample['img_path'] 185 | 186 | # Tokenize the question text 187 | lang_x = Test_dataset.text_tokenizer( 188 | question, max_length=2048, truncation=True, return_tensors="pt" 189 | )['input_ids'].to('cuda') 190 | 191 | # Get vision input 192 | vision_x = sample["vision_x"].to('cuda') 193 | answer = sample['answer'] 194 | 195 | try: 196 | # Generate text based on text and vision inputs 197 | generation = model.generate(lang_x, vision_x) 198 | generated_texts = Test_dataset.text_tokenizer.batch_decode(generation, skip_special_tokens=True) 199 | 200 | # Write results to CSV 201 | writer.writerow([question, answer, generated_texts, belong_to]) 202 | cc = cc + 1 203 | # if cc >= 10000: 204 | # break 205 | except: 206 | continue 207 | 208 | if __name__ == "__main__": 209 | main() -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | # Import necessary libraries 2 | import tqdm.auto as tqdm 3 | import torch.nn.functional as F 4 | from typing import Optional, Dict, Sequence 5 | from typing import List, Optional, Tuple, Union 6 | import transformers 7 | from My_Trainer.trainer import Trainer 8 | from dataclasses import dataclass, field 9 | from Dataset.multi_dataset import multi_dataset 10 | from Model.RadFM.multimodality_model import MultiLLaMAForCausalLM 11 | from datasampler import My_DistributedBatchSampler 12 | from datasets import load_metric 13 | from Dataset.multi_dataset_test_for_close import multi_dataset_close 14 | import numpy as np 15 | import torch 16 
| 17 | 18 | def compute_metrics(eval_preds): 19 | """ 20 | Compute evaluation metrics from prediction outputs. 21 | Returns the mean accuracy across all predictions. 22 | 23 | Args: 24 | eval_preds: Prediction outputs from the model 25 | 26 | Returns: 27 | Dictionary containing accuracy metric 28 | """ 29 | # metric = load_metric("glue", "mrpc") 30 | ACCs = eval_preds.predictions 31 | # print(ACCs) 32 | return {"accuracy": np.mean(ACCs, axis=-1)} 33 | 34 | @dataclass 35 | class ModelArguments: 36 | """ 37 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune. 38 | """ 39 | lang_encoder_path: Optional[str] = field(default="/home/cs/leijiayu/wuchaoyi/book_pretrain/Results/Book_mix_2048_13B_full/checkpoint-45800") 40 | tokenizer_path: str = field(default='/home/cs/leijiayu/wuchaoyi/Finetune_LLAMA/LLAMA_Model/tokenizer', 41 | metadata={"help": "Path to the tokenizer data."}) 42 | 43 | 44 | 45 | @dataclass 46 | class DataArguments: 47 | """ 48 | Arguments pertaining to data processing mode. 49 | """ 50 | Mode: Optional[str] = field(default="Train") 51 | 52 | @dataclass 53 | class TrainingArguments(transformers.TrainingArguments): 54 | """ 55 | Custom training arguments extending the HuggingFace TrainingArguments class. 56 | Includes additional parameters specific to this multimodal training setup. 57 | """ 58 | remove_unused_columns: bool = field(default=False) 59 | batch_size_2D: int = field(default=4) # Batch size for 2D data 60 | batch_size_3D: int = field(default=1) # Batch size for 3D data 61 | output_dir: Optional[str] = field(default="/home/cs/leijiayu/wuchaoyi/multi_modal/src/Results/BLIP_overfit/") 62 | cache_dir: Optional[str] = field(default=None) 63 | optim: str = field(default="adamw_torch") 64 | 65 | 66 | @dataclass 67 | class DataCollator(object): 68 | """ 69 | Data collator that handles batching of multimodal inputs. 70 | Processes vision and language inputs, handles padding, and resizes vision inputs. 
71 | """ 72 | 73 | def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: 74 | # Extract different data components from instances 75 | vision_xs, lang_xs, attention_masks, labels, loss_reweight, key_words_query = tuple( 76 | [instance[key] for instance in instances] 77 | for key in ('vision_x', 'lang_x', 'attention_mask', 'labels', 'loss_reweight', 'key_words_query') 78 | ) 79 | 80 | # Stack language tensors along batch dimension 81 | lang_xs = torch.cat([_.unsqueeze(0) for _ in lang_xs], dim=0) 82 | attention_masks = torch.cat([_.unsqueeze(0) for _ in attention_masks], dim=0) 83 | labels = torch.cat([_.unsqueeze(0) for _ in labels], dim=0) 84 | loss_reweight = torch.cat([_.unsqueeze(0) for _ in loss_reweight], dim=0) 85 | 86 | # Set target dimensions for vision input resizing 87 | target_H = 512 88 | target_W = 512 89 | target_D = 4 90 | MAX_D = 0 91 | 92 | # Define possible depth values for 3D data 93 | D_list = list(range(4, 65, 4)) 94 | # Adjust depth range for larger inputs 95 | if len(vision_xs) == 1: 96 | if vision_xs[0].shape[0] > 6: 97 | D_list = list(range(4, 33, 4)) 98 | 99 | # Find maximum depth in current batch 100 | for ii in vision_xs: 101 | try: 102 | D = ii.shape[-1] 103 | if D > MAX_D: 104 | MAX_D = D 105 | except: 106 | continue 107 | 108 | # Select closest target depth from available options 109 | for temp_D in D_list: 110 | if abs(temp_D - MAX_D) < abs(target_D - MAX_D): 111 | target_D = temp_D 112 | 113 | # Reduce image dimensions for larger depth inputs with small batch size 114 | if len(vision_xs) == 1 and target_D > 4: 115 | target_H = 256 116 | target_W = 256 117 | 118 | # Resize all vision inputs to target dimensions 119 | vision_xs = [torch.nn.functional.interpolate(s, size=(target_H, target_W, target_D)) for s in vision_xs] 120 | 121 | # Pad sequence for variable-length vision inputs 122 | vision_xs = torch.nn.utils.rnn.pad_sequence( 123 | vision_xs, batch_first=True, padding_value=0 124 | ) 125 | print(vision_xs.shape, vision_xs.dtype) 126 | 127 | # Return collated batch 128 | return dict( 129 | lang_x=lang_xs, 130 | vision_x=vision_xs, 131 | attention_mask=attention_masks, 132 | labels=labels, 133 | loss_reweight=loss_reweight, 134 | key_words_query=key_words_query 135 | ) 136 | 137 | def main(): 138 | """ 139 | Main function to set up and run the training process. 140 | Parses arguments, initializes datasets, model, and trainer. 
141 | """ 142 | # Parse command-line arguments 143 | parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments)) 144 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 145 | 146 | # Set custom data sampler 147 | training_args.data_sampler = My_DistributedBatchSampler 148 | 149 | print("Setup Data") 150 | # Initialize training and evaluation datasets 151 | Train_dataset = multi_dataset(text_tokenizer=model_args.tokenizer_path) 152 | Eval_dataset = multi_dataset_close(text_tokenizer=model_args.tokenizer_path) 153 | 154 | print("Setup Model") 155 | # Initialize the multimodal model 156 | model = MultiLLaMAForCausalLM( 157 | lang_model_path=model_args.lang_encoder_path, 158 | ) 159 | 160 | # Setup trainer with model, datasets, and configurations 161 | trainer = Trainer( 162 | model=model, 163 | train_dataset=Train_dataset, 164 | eval_dataset=Eval_dataset, 165 | args=training_args, 166 | data_collator=DataCollator(), 167 | compute_metrics=compute_metrics 168 | ) 169 | 170 | # Start training 171 | trainer.train() 172 | # Save training state 173 | trainer.save_state() 174 | 175 | if __name__ == "__main__": 176 | main() --------------------------------------------------------------------------------
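A minimal standalone sketch of the depth-selection step used by the DataCollator classes in src/test.py and src/train.py above: the collator picks the candidate in D_list closest to the deepest volume in the batch, then interpolates every volume to that depth. The MAX_D value below (22) is an assumed example input, not a value taken from the repository.

# Standalone sketch of the collators' target-depth selection (assumed example input).
D_list = list(range(4, 65, 4))   # candidate target depths: 4, 8, ..., 64
MAX_D = 22                       # assumed: deepest volume in this batch has 22 slices
target_D = 4                     # default depth, as in the collators above
for temp_D in D_list:
    if abs(temp_D - MAX_D) < abs(target_D - MAX_D):
        target_D = temp_D
print(target_D)                  # 20 -- on a tie, the earlier (smaller) candidate is kept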