├── Images
│   ├── GIF.gif
│   ├── result_rationale.jpg
│   ├── result_rationale.pdf
│   ├── result_report.jpg
│   ├── result_report.pdf
│   ├── result_report_generation.jpg
│   ├── result_report_generation.pdf
│   ├── result_vqa.jpg
│   └── result_vqa.pdf
├── Quick_demo
│   ├── Language_files
│   │   ├── config.json
│   │   ├── special_tokens_map.json
│   │   ├── tokenizer.model
│   │   └── tokenizer_config.json
│   ├── MedKEBERT
│   │   ├── config.json
│   │   ├── special_tokens_map.json
│   │   ├── tokenizer.json
│   │   ├── tokenizer_config.json
│   │   └── vocab.txt
│   ├── Model
│   │   └── RadFM
│   │       ├── __init__.py
│   │       ├── __pycache__
│   │       │   ├── __init__.cpython-39.pyc
│   │       │   ├── blocks.cpython-39.pyc
│   │       │   ├── helpers.cpython-39.pyc
│   │       │   ├── multimodality_model.cpython-39.pyc
│   │       │   ├── my_embedding_layer.cpython-39.pyc
│   │       │   ├── position_encoding.cpython-39.pyc
│   │       │   ├── transformer_decoder.cpython-39.pyc
│   │       │   ├── utils.cpython-39.pyc
│   │       │   └── vit_3d.cpython-39.pyc
│   │       ├── blocks.py
│   │       ├── helpers.py
│   │       ├── multimodality_model.py
│   │       ├── my_embedding_layer.py
│   │       ├── position_encoding.py
│   │       ├── transformer_decoder.py
│   │       ├── utils.py
│   │       └── vit_3d.py
│   ├── test.py
│   └── view1_frontal.jpg
├── README.md
├── requirements.txt
└── src
    ├── Dataset
    │   ├── dataset
    │   │   ├── MedPix_dataset.py
    │   │   ├── __init__.py
    │   │   ├── __pycache__
    │   │   │   ├── MedPix_dataset.cpython-39.pyc
    │   │   │   ├── __init__.cpython-310.pyc
    │   │   │   ├── __init__.cpython-39.pyc
    │   │   │   ├── binary.cpython-39.pyc
    │   │   │   ├── case_report.cpython-39.pyc
    │   │   │   ├── chestxray.cpython-310.pyc
    │   │   │   ├── chestxray.cpython-39.pyc
    │   │   │   ├── paper_inline.cpython-39.pyc
    │   │   │   ├── pmcoa.cpython-39.pyc
    │   │   │   ├── pmcvqa.cpython-39.pyc
    │   │   │   ├── radiopaedia.cpython-39.pyc
    │   │   │   └── radiovqa.cpython-39.pyc
    │   │   ├── binary.py
    │   │   ├── caption_prompt.json
    │   │   ├── case_report.py
    │   │   ├── chestxray.py
    │   │   ├── cls_prompt.json
    │   │   ├── data_csv
    │   │   │   └── README.md
    │   │   ├── dicom_to_png_for_VinDR_sampled_using_mammo.py
    │   │   ├── jpg2nii_data_convert.py
    │   │   ├── mammo_prompt.json
    │   │   ├── modality_prompt.json
    │   │   ├── nii2npy_for_radiopaedio.py
    │   │   ├── paper_inline.py
    │   │   ├── pmcoa.py
    │   │   ├── radiology_feature_prompt.json
    │   │   ├── radiopaedia.py
    │   │   ├── report_prompt.json
    │   │   ├── spinexr_prompt.json
    │   │   ├── vqa.py
    │   │   └── yes_no_prompt.json
    │   ├── multi_dataset.py
    │   ├── multi_dataset_test.py
    │   └── multi_dataset_test_for_close.py
    ├── Model
    │   └── RadFM
    │       ├── __init__.py
    │       ├── __pycache__
    │       │   ├── __init__.cpython-39.pyc
    │       │   ├── blocks.cpython-39.pyc
    │       │   ├── helpers.cpython-39.pyc
    │       │   ├── multimodality_model.cpython-39.pyc
    │       │   ├── my_embedding_layer.cpython-39.pyc
    │       │   ├── position_encoding.cpython-39.pyc
    │       │   ├── transformer_decoder.cpython-39.pyc
    │       │   ├── utils.cpython-39.pyc
    │       │   └── vit_3d.cpython-39.pyc
    │       ├── blocks.py
    │       ├── helpers.py
    │       ├── multimodality_model.py
    │       ├── my_embedding_layer.py
    │       ├── position_encoding.py
    │       ├── transformer_decoder.py
    │       ├── utils.py
    │       └── vit_3d.py
    ├── My_Trainer
    │   ├── __pycache__
    │   │   └── trainer.cpython-39.pyc
    │   └── trainer.py
    ├── datasampler.py
    ├── output_csv_example
    │   └── caption_example.csv
    ├── test.py
    └── train.py
/Images/GIF.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaoyi-wu/RadFM/9efba84cced9ffdcbce00d6005f255414ffa8c36/Images/GIF.gif
--------------------------------------------------------------------------------
/Images/result_rationale.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaoyi-wu/RadFM/9efba84cced9ffdcbce00d6005f255414ffa8c36/Images/result_rationale.jpg
--------------------------------------------------------------------------------
/Images/result_rationale.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaoyi-wu/RadFM/9efba84cced9ffdcbce00d6005f255414ffa8c36/Images/result_rationale.pdf
--------------------------------------------------------------------------------
/Images/result_report.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaoyi-wu/RadFM/9efba84cced9ffdcbce00d6005f255414ffa8c36/Images/result_report.jpg
--------------------------------------------------------------------------------
/Images/result_report.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaoyi-wu/RadFM/9efba84cced9ffdcbce00d6005f255414ffa8c36/Images/result_report.pdf
--------------------------------------------------------------------------------
/Images/result_report_generation.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaoyi-wu/RadFM/9efba84cced9ffdcbce00d6005f255414ffa8c36/Images/result_report_generation.jpg
--------------------------------------------------------------------------------
/Images/result_report_generation.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaoyi-wu/RadFM/9efba84cced9ffdcbce00d6005f255414ffa8c36/Images/result_report_generation.pdf
--------------------------------------------------------------------------------
/Images/result_vqa.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaoyi-wu/RadFM/9efba84cced9ffdcbce00d6005f255414ffa8c36/Images/result_vqa.jpg
--------------------------------------------------------------------------------
/Images/result_vqa.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaoyi-wu/RadFM/9efba84cced9ffdcbce00d6005f255414ffa8c36/Images/result_vqa.pdf
--------------------------------------------------------------------------------
/Quick_demo/Language_files/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "_name_or_path": "/home/cs/leijiayu/wuchaoyi/Finetune_LLAMA/LLAMA_Model/llama-13b-hf",
3 | "architectures": [
4 | "LlamaForCausalLM"
5 | ],
6 | "bos_token_id": 0,
7 | "eos_token_id": 1,
8 | "hidden_act": "silu",
9 | "hidden_size": 5120,
10 | "initializer_range": 0.02,
11 | "intermediate_size": 13824,
12 | "max_sequence_length": 2048,
13 | "model_type": "llama",
14 | "num_attention_heads": 40,
15 | "num_hidden_layers": 40,
16 | "pad_token_id": -1,
17 | "rms_norm_eps": 1e-06,
18 | "tie_word_embeddings": false,
19 | "torch_dtype": "float32",
20 | "transformers_version": "4.28.0.dev0",
21 | "use_cache": true,
22 | "vocab_size": 32000
23 | }
24 |
--------------------------------------------------------------------------------
/Quick_demo/Language_files/special_tokens_map.json:
--------------------------------------------------------------------------------
1 | {}
--------------------------------------------------------------------------------
/Quick_demo/Language_files/tokenizer.model:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaoyi-wu/RadFM/9efba84cced9ffdcbce00d6005f255414ffa8c36/Quick_demo/Language_files/tokenizer.model
--------------------------------------------------------------------------------
/Quick_demo/Language_files/tokenizer_config.json:
--------------------------------------------------------------------------------
1 | {"bos_token": "", "eos_token": "", "model_max_length": 1000000000000000019884624838656, "tokenizer_class": "LlamaTokenizer", "unk_token": ""}
--------------------------------------------------------------------------------
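
Note: Language_files ships only the LLaMA-13B configuration and tokenizer (hidden_size 5120, 40 layers, 40 heads, 32,000-token vocabulary); the model weights themselves must be obtained separately. A minimal loading sketch with Hugging Face transformers, assuming the files are read from this relative path:

# Sketch only: loads the bundled config/tokenizer, not the LLaMA weights.
from transformers import LlamaConfig, LlamaTokenizer

config = LlamaConfig.from_pretrained("./Quick_demo/Language_files")
tokenizer = LlamaTokenizer.from_pretrained("./Quick_demo/Language_files")

print(config.hidden_size, config.num_hidden_layers)        # 5120 40
print(tokenizer("chest x-ray, frontal view")["input_ids"][:8])
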
/Quick_demo/MedKEBERT/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "_name_or_path": "xmcmic/Med-KEBERT",
3 | "architectures": [
4 | "BertModel"
5 | ],
6 | "attention_probs_dropout_prob": 0.1,
7 | "classifier_dropout": null,
8 | "gradient_checkpointing": false,
9 | "hidden_act": "gelu",
10 | "hidden_dropout_prob": 0.1,
11 | "hidden_size": 768,
12 | "initializer_range": 0.02,
13 | "intermediate_size": 3072,
14 | "layer_norm_eps": 1e-12,
15 | "max_position_embeddings": 512,
16 | "model_type": "bert",
17 | "num_attention_heads": 12,
18 | "num_hidden_layers": 12,
19 | "output_hidden_states": true,
20 | "pad_token_id": 0,
21 | "position_embedding_type": "absolute",
22 | "torch_dtype": "float32",
23 | "transformers_version": "4.24.0",
24 | "type_vocab_size": 2,
25 | "use_cache": true,
26 | "vocab_size": 30522
27 | }
28 |
--------------------------------------------------------------------------------
/Quick_demo/MedKEBERT/special_tokens_map.json:
--------------------------------------------------------------------------------
1 | {
2 | "cls_token": "[CLS]",
3 | "mask_token": "[MASK]",
4 | "pad_token": "[PAD]",
5 | "sep_token": "[SEP]",
6 | "unk_token": "[UNK]"
7 | }
8 |
--------------------------------------------------------------------------------
/Quick_demo/MedKEBERT/tokenizer_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "cls_token": "[CLS]",
3 | "do_basic_tokenize": true,
4 | "do_lower_case": true,
5 | "mask_token": "[MASK]",
6 | "name_or_path": "xmcmic/Med-KEBERT",
7 | "never_split": null,
8 | "pad_token": "[PAD]",
9 | "sep_token": "[SEP]",
10 | "special_tokens_map_file": null,
11 | "strip_accents": null,
12 | "tokenize_chinese_chars": true,
13 | "tokenizer_class": "BertTokenizer",
14 | "unk_token": "[UNK]"
15 | }
16 |
--------------------------------------------------------------------------------
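
These MedKEBERT files mirror the xmcmic/Med-KEBERT checkpoint on the Hugging Face Hub, which my_embedding_layer.py loads by name to encode radiology keywords. A short sketch of the same loading pattern (using the local tokenizer path is an assumption; the model weights still come from the Hub):

import torch
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("./Quick_demo/MedKEBERT")
bert = AutoModel.from_pretrained("xmcmic/Med-KEBERT")      # BertModel, hidden_size 768

inputs = tokenizer(["pneumothorax", "pleural effusion"], padding=True, return_tensors="pt")
with torch.no_grad():
    cls_embeddings = bert(**inputs).last_hidden_state[:, 0, :]   # shape (2, 768)
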
/Quick_demo/Model/RadFM/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaoyi-wu/RadFM/9efba84cced9ffdcbce00d6005f255414ffa8c36/Quick_demo/Model/RadFM/__init__.py
--------------------------------------------------------------------------------
/Quick_demo/Model/RadFM/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaoyi-wu/RadFM/9efba84cced9ffdcbce00d6005f255414ffa8c36/Quick_demo/Model/RadFM/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/Quick_demo/Model/RadFM/__pycache__/blocks.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaoyi-wu/RadFM/9efba84cced9ffdcbce00d6005f255414ffa8c36/Quick_demo/Model/RadFM/__pycache__/blocks.cpython-39.pyc
--------------------------------------------------------------------------------
/Quick_demo/Model/RadFM/__pycache__/helpers.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaoyi-wu/RadFM/9efba84cced9ffdcbce00d6005f255414ffa8c36/Quick_demo/Model/RadFM/__pycache__/helpers.cpython-39.pyc
--------------------------------------------------------------------------------
/Quick_demo/Model/RadFM/__pycache__/multimodality_model.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaoyi-wu/RadFM/9efba84cced9ffdcbce00d6005f255414ffa8c36/Quick_demo/Model/RadFM/__pycache__/multimodality_model.cpython-39.pyc
--------------------------------------------------------------------------------
/Quick_demo/Model/RadFM/__pycache__/my_embedding_layer.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaoyi-wu/RadFM/9efba84cced9ffdcbce00d6005f255414ffa8c36/Quick_demo/Model/RadFM/__pycache__/my_embedding_layer.cpython-39.pyc
--------------------------------------------------------------------------------
/Quick_demo/Model/RadFM/__pycache__/position_encoding.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaoyi-wu/RadFM/9efba84cced9ffdcbce00d6005f255414ffa8c36/Quick_demo/Model/RadFM/__pycache__/position_encoding.cpython-39.pyc
--------------------------------------------------------------------------------
/Quick_demo/Model/RadFM/__pycache__/transformer_decoder.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaoyi-wu/RadFM/9efba84cced9ffdcbce00d6005f255414ffa8c36/Quick_demo/Model/RadFM/__pycache__/transformer_decoder.cpython-39.pyc
--------------------------------------------------------------------------------
/Quick_demo/Model/RadFM/__pycache__/utils.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaoyi-wu/RadFM/9efba84cced9ffdcbce00d6005f255414ffa8c36/Quick_demo/Model/RadFM/__pycache__/utils.cpython-39.pyc
--------------------------------------------------------------------------------
/Quick_demo/Model/RadFM/__pycache__/vit_3d.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaoyi-wu/RadFM/9efba84cced9ffdcbce00d6005f255414ffa8c36/Quick_demo/Model/RadFM/__pycache__/vit_3d.cpython-39.pyc
--------------------------------------------------------------------------------
/Quick_demo/Model/RadFM/blocks.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | from typing import Tuple, Union, Callable, Optional
3 |
4 | import torch
5 | import torch.nn.functional as F
6 | from torch import nn
7 | from torch.utils.checkpoint import checkpoint
8 |
9 | class PMC_CLIP_cfg:
10 | backbone: str = 'ModifiedRN50' # ['RN50', 'ModifiedRN50', 'MAE']
11 | layers: Union[Tuple[int, int, int, int], int] = [3,4,6,3]
12 | width: int = 64
13 | head_width: int = 64
14 | mlp_ratio: float = 4.0
15 | patch_size: int = 16
16 | image_size: Union[Tuple[int, int], int] = 224
17 | timm_model_name: str = None # a valid model name overrides layers, width, patch_size
18 | timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model
19 | timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
20 | timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '')
21 | patch_dropout: float = 0.0 # patch dropout rate, no dropout by default
22 | drop_attention_rate: float = 0. # Transformer Dropout
23 | patch_size: None
24 |
25 | class Bottleneck(nn.Module):
26 | expansion = 4
27 |
28 | def __init__(self, inplanes, planes, stride=1):
29 | super().__init__()
30 |
31 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
32 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
33 | self.bn1 = nn.BatchNorm2d(planes)
34 | self.relu1 = nn.ReLU(inplace=True)
35 |
36 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
37 | self.bn2 = nn.BatchNorm2d(planes)
38 | self.relu2 = nn.ReLU(inplace=True)
39 |
40 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
41 |
42 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
43 | self.bn3 = nn.BatchNorm2d(planes * self.expansion)
44 | self.relu3 = nn.ReLU(inplace=True)
45 |
46 | self.downsample = None
47 | self.stride = stride
48 |
49 | if stride > 1 or inplanes != planes * Bottleneck.expansion:
50 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
51 | self.downsample = nn.Sequential(OrderedDict([
52 | ("-1", nn.AvgPool2d(stride)),
53 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
54 | ("1", nn.BatchNorm2d(planes * self.expansion))
55 | ]))
56 |
57 | def forward(self, x: torch.Tensor):
58 | identity = x
59 |
60 | out = self.relu1(self.bn1(self.conv1(x)))
61 | out = self.relu2(self.bn2(self.conv2(out)))
62 | out = self.avgpool(out)
63 | out = self.bn3(self.conv3(out))
64 |
65 | if self.downsample is not None:
66 | identity = self.downsample(x)
67 |
68 | out += identity
69 | out = self.relu3(out)
70 | return out
71 |
72 |
73 | class AttentionPool2d(nn.Module):
74 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
75 | super().__init__()
76 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
77 | self.k_proj = nn.Linear(embed_dim, embed_dim)
78 | self.q_proj = nn.Linear(embed_dim, embed_dim)
79 | self.v_proj = nn.Linear(embed_dim, embed_dim)
80 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
81 | self.num_heads = num_heads
82 |
83 | def forward(self, x):
84 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
85 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
86 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
87 | x, _ = F.multi_head_attention_forward(
88 | query=x, key=x, value=x,
89 | embed_dim_to_check=x.shape[-1],
90 | num_heads=self.num_heads,
91 | q_proj_weight=self.q_proj.weight,
92 | k_proj_weight=self.k_proj.weight,
93 | v_proj_weight=self.v_proj.weight,
94 | in_proj_weight=None,
95 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
96 | bias_k=None,
97 | bias_v=None,
98 | add_zero_attn=False,
99 | dropout_p=0,
100 | out_proj_weight=self.c_proj.weight,
101 | out_proj_bias=self.c_proj.bias,
102 | use_separate_proj_weight=True,
103 | training=self.training,
104 | need_weights=False
105 | )
106 |
107 | return x[0]
108 |
109 |
110 | class ResNet(nn.Module):
111 | """
112 | RN50
113 | """
114 |
115 | def __init__(
116 | self, layers, output_dim, heads, image_size=224, width=64,
117 | block=Bottleneck,
118 | ):
119 | super().__init__()
120 | self.output_dim = output_dim
121 | self.image_size = image_size
122 |
123 | # the 1-layer stem
124 | self.conv1 = nn.Conv2d(3, width, kernel_size=3, stride=2, padding=1, bias=False)
125 | self.bn1 = nn.BatchNorm2d(width)
126 | self.relu1 = nn.ReLU(inplace=True)
127 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
128 |
129 | # residual layers
130 | self._inplanes = width # this is a *mutable* variable used during construction
131 | self.layer1 = self._make_layer(width, layers[0])
132 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
133 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
134 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
135 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
136 | # self.head = nn.Linear(512 * 6, output_dim)
137 | self.head = nn.Linear(512 * block.expansion, output_dim)
138 |
139 | # embed_dim = width * 32 # the ResNet feature dimension
140 | # self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim)
141 |
142 | self.init_parameters()
143 |
144 | def _make_layer(
145 | self,
146 | planes, blocks, stride=1,
147 | block=Bottleneck,
148 | ):
149 | layers = [block(self._inplanes, planes, stride)]
150 |
151 | self._inplanes = planes * block.expansion
152 | for _ in range(1, blocks):
153 | layers.append(block(self._inplanes, planes))
154 |
155 | return nn.Sequential(*layers)
156 |
157 | def init_parameters(self):
158 | for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]:
159 | for name, param in resnet_block.named_parameters():
160 | if name.endswith("bn3.weight"):
161 | nn.init.zeros_(param)
162 |
163 | def lock(self, unlocked_groups=0, freeze_bn_stats=False):
164 | assert unlocked_groups == 0, 'partial locking not currently supported for this model'
165 | for param in self.parameters():
166 | param.requires_grad = False
167 | if freeze_bn_stats:
168 | freeze_batch_norm_2d(self)
169 |
170 | @torch.jit.ignore
171 | def set_grad_checkpointing(self, enable=True):
172 | # FIXME support for non-transformer
173 | pass
174 |
175 | def stem(self, x):
176 | x = self.relu1(self.bn1(self.conv1(x)))
177 | x = self.maxpool(x)
178 | return x
179 |
180 | def forward(self, x):
181 | # x[0]: [batch_size, 3, 224, 224]
182 | # x[1]: [batch_size, 1]
183 | x = self.stem(x) # [batch_size, 64, 56, 56]
184 | x = self.layer1(x)
185 | x = self.layer2(x)
186 | x = self.layer3(x)
187 | x = self.layer4(x) # [batch_size, 2048, 7, 7]
188 | x = self.avgpool(x) # [batch_size, 2048, 1, 1]
189 | x = torch.flatten(x, 1) # [batch_size, 2048*1*1]
190 | x = self.head(x) # [batch_size, 1024]
191 |
192 | visual_output = dict.fromkeys(["image_features", "mim_loss"], None)
193 | visual_output.update({
194 | 'image_features': x,
195 | })
196 |
197 | return visual_output
198 |
199 |
200 | class ModifiedResNet(nn.Module):
201 | """
202 | A ResNet class that is similar to torchvision's but contains the following changes:
203 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
204 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
205 | - The final pooling layer is a QKV attention instead of an average pool
206 | """
207 |
208 | def __init__(self, layers, output_dim, heads, image_size=224, width=64):
209 | super().__init__()
210 | self.output_dim = output_dim
211 | self.image_size = image_size
212 |
213 | # the 3-layer stem
214 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
215 | self.bn1 = nn.BatchNorm2d(width // 2)
216 | self.relu1 = nn.ReLU(inplace=True)
217 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
218 | self.bn2 = nn.BatchNorm2d(width // 2)
219 | self.relu2 = nn.ReLU(inplace=True)
220 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
221 | self.bn3 = nn.BatchNorm2d(width)
222 | self.relu3 = nn.ReLU(inplace=True)
223 | self.avgpool = nn.AvgPool2d(2)
224 |
225 | # residual layers
226 | self._inplanes = width # this is a *mutable* variable used during construction
227 | self.layer1 = self._make_layer(width, layers[0])
228 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
229 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
230 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
231 |
232 | embed_dim = width * 32 # the ResNet feature dimension
233 | self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim)
234 |
235 | self.init_parameters()
236 |
237 | def _make_layer(self, planes, blocks, stride=1):
238 | layers = [Bottleneck(self._inplanes, planes, stride)]
239 |
240 | self._inplanes = planes * Bottleneck.expansion
241 | for _ in range(1, blocks):
242 | layers.append(Bottleneck(self._inplanes, planes))
243 |
244 | return nn.Sequential(*layers)
245 |
246 | def init_parameters(self):
247 | if self.attnpool is not None:
248 | std = self.attnpool.c_proj.in_features ** -0.5
249 | nn.init.normal_(self.attnpool.q_proj.weight, std=std)
250 | nn.init.normal_(self.attnpool.k_proj.weight, std=std)
251 | nn.init.normal_(self.attnpool.v_proj.weight, std=std)
252 | nn.init.normal_(self.attnpool.c_proj.weight, std=std)
253 |
254 | for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]:
255 | for name, param in resnet_block.named_parameters():
256 | if name.endswith("bn3.weight"):
257 | nn.init.zeros_(param)
258 |
259 | def lock(self, unlocked_groups=0, freeze_bn_stats=False):
260 | assert unlocked_groups == 0, 'partial locking not currently supported for this model'
261 | for param in self.parameters():
262 | param.requires_grad = False
263 | if freeze_bn_stats:
264 | freeze_batch_norm_2d(self)
265 |
266 | @torch.jit.ignore
267 | def set_grad_checkpointing(self, enable=True):
268 | # FIXME support for non-transformer
269 | pass
270 |
271 | def stem(self, x):
272 | x = self.relu1(self.bn1(self.conv1(x)))
273 | x = self.relu2(self.bn2(self.conv2(x)))
274 | x = self.relu3(self.bn3(self.conv3(x)))
275 | x = self.avgpool(x)
276 | return x
277 |
278 | def forward(self, x):
279 | x = self.stem(x)
280 | x = self.layer1(x)
281 | x = self.layer2(x)
282 | x = self.layer3(x)
283 | x = self.layer4(x)
284 | x = self.attnpool(x)
285 |
286 | visual_output = dict.fromkeys(["image_features", "mim_loss"], None)
287 | visual_output.update({
288 | 'image_features': x,
289 | })
290 |
291 | return visual_output
292 |
293 |
294 | class LayerNorm(nn.LayerNorm):
295 | """Subclass torch's LayerNorm to handle fp16."""
296 |
297 | def forward(self, x: torch.Tensor):
298 | orig_type = x.dtype
299 | x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
300 | return x.to(orig_type)
301 |
302 |
303 | class QuickGELU(nn.Module):
304 | # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory
305 | def forward(self, x: torch.Tensor):
306 | return x * torch.sigmoid(1.702 * x)
307 |
308 |
309 | class ResidualAttentionBlock(nn.Module):
310 | def __init__(
311 | self, d_model: int, n_head: int, mlp_ratio: float = 4.0, act_layer: Callable = nn.GELU,
312 | drop_attention_rate: float = 0.,
313 | ):
314 | super().__init__()
315 |
316 | self.attn = nn.MultiheadAttention(
317 | embed_dim=d_model,
318 | num_heads=n_head,
319 | dropout=drop_attention_rate,
320 | )
321 | self.ln_1 = LayerNorm(d_model)
322 | mlp_width = int(d_model * mlp_ratio)
323 | self.mlp = nn.Sequential(OrderedDict([
324 | ("c_fc", nn.Linear(d_model, mlp_width)),
325 | ("gelu", act_layer()),
326 | ("c_proj", nn.Linear(mlp_width, d_model))
327 | ]))
328 | self.ln_2 = LayerNorm(d_model)
329 |
330 | def attention(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
331 | return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0]
332 |
333 | def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
334 | x = x + self.attention(self.ln_1(x), attn_mask=attn_mask)
335 | x = x + self.mlp(self.ln_2(x))
336 | return x
337 |
338 |
339 | class PatchDropout(nn.Module):
340 | """
341 | https://arxiv.org/abs/2212.00794
342 | """
343 |
344 | def __init__(self, prob, exclude_first_token=True):
345 | super().__init__()
346 | assert 0 <= prob < 1.
347 | self.prob = prob
348 | self.exclude_first_token = exclude_first_token # exclude CLS token
349 |
350 | def forward(self, x):
351 | if not self.training or self.prob == 0.:
352 | return x
353 |
354 | if self.exclude_first_token:
355 | cls_tokens, x = x[:, :1], x[:, 1:]
356 | else:
357 | cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])
358 |
359 | batch = x.size()[0]
360 | num_tokens = x.size()[1]
361 |
362 | batch_indices = torch.arange(batch)
363 | batch_indices = batch_indices[..., None]
364 |
365 | keep_prob = 1 - self.prob
366 | num_patches_keep = max(1, int(num_tokens * keep_prob))
367 |
368 | rand = torch.randn(batch, num_tokens)
369 | patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices
370 |
371 | x = x[batch_indices, patch_indices_keep]
372 |
373 | if self.exclude_first_token:
374 | x = torch.cat((cls_tokens, x), dim=1)
375 |
376 | return x
377 |
378 |
379 | class Transformer(nn.Module):
380 | def __init__(
381 | self, width: int, layers: int, heads: int, mlp_ratio: float = 4.0, act_layer: Callable = nn.GELU,
382 | drop_attention_rate: float = 0.,
383 | ):
384 | super().__init__()
385 | self.width = width
386 | self.layers = layers
387 | self.grad_checkpointing = False
388 |
389 | self.resblocks = nn.ModuleList([
390 | ResidualAttentionBlock(width, heads, mlp_ratio, act_layer=act_layer, drop_attention_rate=drop_attention_rate)
391 | for _ in range(layers)
392 | ])
393 |
394 | def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
395 | for r in self.resblocks:
396 | if self.grad_checkpointing and not torch.jit.is_scripting():
397 | x = checkpoint(r, x, attn_mask)
398 | else:
399 | x = r(x, attn_mask=attn_mask)
400 | return x
--------------------------------------------------------------------------------
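
blocks.py keeps the PMC-CLIP-style ResNet and Transformer building blocks; PMC_CLIP_cfg defaults to a ModifiedRN50 backbone with [3, 4, 6, 3] bottleneck stages, width 64 and a 224-pixel input. A quick shape check of the ModifiedResNet defined above, run from inside Quick_demo/ (output_dim=768 and heads=32 are illustrative choices, not values fixed by this file):

import torch
from Model.RadFM.blocks import ModifiedResNet, PMC_CLIP_cfg

cfg = PMC_CLIP_cfg()
backbone = ModifiedResNet(
    layers=cfg.layers,            # [3, 4, 6, 3]
    output_dim=768,               # illustrative projection size
    heads=32,                     # attention-pool heads; must divide width * 32 = 2048
    image_size=cfg.image_size,    # 224
    width=cfg.width,              # 64
)
out = backbone(torch.randn(2, 3, 224, 224))
print(out["image_features"].shape)            # torch.Size([2, 768])
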
/Quick_demo/Model/RadFM/helpers.py:
--------------------------------------------------------------------------------
1 | """
2 | Taken from https://github.com/lucidrains/flamingo-pytorch
3 | """
4 |
5 | import torch
6 | from einops import rearrange, repeat
7 | from einops_exts import rearrange_many
8 | from torch import einsum, nn
9 |
10 |
11 | def exists(val):
12 | return val is not None
13 |
14 |
15 | def FeedForward(dim, mult=4):
16 | inner_dim = int(dim * mult)
17 | return nn.Sequential(
18 | nn.LayerNorm(dim),
19 | nn.Linear(dim, inner_dim, bias=False),
20 | nn.GELU(),
21 | nn.Linear(inner_dim, dim, bias=False),
22 | )
23 |
24 |
25 | class PerceiverAttention(nn.Module):
26 | def __init__(self, *, dim, dim_head=64, heads=8):
27 | super().__init__()
28 | self.scale = dim_head**-0.5
29 | self.heads = heads
30 | inner_dim = dim_head * heads
31 |
32 | self.norm_media = nn.LayerNorm(dim)
33 | self.norm_latents = nn.LayerNorm(dim)
34 |
35 | self.to_q = nn.Linear(dim, inner_dim, bias=False)
36 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
37 | self.to_out = nn.Linear(inner_dim, dim, bias=False)
38 |
39 | def forward(self, x, latents):
40 | """
41 | Args:
42 | x (torch.Tensor): image features
43 | shape (b, T, n1, D)
44 | latent (torch.Tensor): latent features
45 | shape (b, T, n2, D)
46 | """
47 | x = self.norm_media(x)
48 | latents = self.norm_latents(latents)
49 |
50 | h = self.heads
51 |
52 | q = self.to_q(latents)
53 | kv_input = torch.cat((x, latents), dim=-2)
54 | k, v = self.to_kv(kv_input).chunk(2, dim=-1)
55 | q, k, v = rearrange_many((q, k, v), "b t n (h d) -> b h t n d", h=h)
56 | q = q * self.scale
57 |
58 | # attention
59 | sim = einsum("... i d, ... j d -> ... i j", q, k)
60 | sim = sim - sim.amax(dim=-1, keepdim=True).detach()
61 | attn = sim.softmax(dim=-1)
62 |
63 | out = einsum("... i j, ... j d -> ... i d", attn, v)
64 | out = rearrange(out, "b h t n d -> b t n (h d)", h=h)
65 | return self.to_out(out)
66 |
67 |
68 | class PerceiverResampler(nn.Module):
69 | def __init__(
70 | self,
71 | *,
72 | dim,
73 | depth=6,
74 | dim_head=64,
75 | heads=8,
76 | num_latents=64,
77 | max_num_media=None,
78 | max_num_frames=None,
79 | ff_mult=4,
80 | ):
81 | super().__init__()
82 | self.latents = nn.Parameter(torch.randn(num_latents, dim))
83 | self.frame_embs = (
84 | nn.Parameter(torch.randn(max_num_frames, dim))
85 | if exists(max_num_frames)
86 | else None
87 | )
88 | self.media_time_embs = (
89 | nn.Parameter(torch.randn(max_num_media, 1, dim))
90 | if exists(max_num_media)
91 | else None
92 | )
93 |
94 | self.layers = nn.ModuleList([])
95 | for _ in range(depth):
96 | self.layers.append(
97 | nn.ModuleList(
98 | [
99 | PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
100 | FeedForward(dim=dim, mult=ff_mult),
101 | ]
102 | )
103 | )
104 |
105 | self.norm = nn.LayerNorm(dim)
106 |
107 | def forward(self, x):
108 | """
109 | Args:
110 | x (torch.Tensor): image features
111 | shape (b, T, F, v, D)
112 | Returns:
113 | shape (b, T, n, D) where n is self.num_latents
114 | """
115 | b, T, F, v = x.shape[:4]
116 |
117 | # frame and media time embeddings
118 | if exists(self.frame_embs):
119 | frame_embs = repeat(self.frame_embs[:F], "F d -> b T F v d", b=b, T=T, v=v)
120 | x = x + frame_embs
121 | x = rearrange(
122 | x, "b T F v d -> b T (F v) d"
123 | ) # flatten the frame and spatial dimensions
124 | if exists(self.media_time_embs):
125 | x = x + self.media_time_embs[:T]
126 |
127 | # blocks
128 | latents = repeat(self.latents, "n d -> b T n d", b=b, T=T)
129 | for attn, ff in self.layers:
130 | latents = attn(x, latents) + latents
131 | latents = ff(latents) + latents
132 | return self.norm(latents)
133 |
134 |
135 | # gated cross attention
136 |
137 |
138 | class MaskedCrossAttention(nn.Module):
139 | def __init__(
140 | self,
141 | *,
142 | dim,
143 | dim_visual,
144 | dim_head=64,
145 | heads=8,
146 | only_attend_immediate_media=True,
147 | ):
148 | super().__init__()
149 | self.scale = dim_head**-0.5
150 | self.heads = heads
151 | inner_dim = dim_head * heads
152 |
153 | self.norm = nn.LayerNorm(dim)
154 |
155 | self.to_q = nn.Linear(dim, inner_dim, bias=False)
156 | self.to_kv = nn.Linear(dim_visual, inner_dim * 2, bias=False)
157 | self.to_out = nn.Linear(inner_dim, dim, bias=False)
158 |
159 | # whether for text to only attend to immediate preceding image, or all previous images
160 | self.only_attend_immediate_media = only_attend_immediate_media
161 |
162 | def forward(self, x, media, media_locations=None, attend_previous=True):
163 | """
164 | Args:
165 | x (torch.Tensor): text features
166 | shape (B, T_txt, D_txt)
167 | media (torch.Tensor): image features
168 | shape (B, T_img, n, D_img) where n is the dim of the latents
169 | media_locations: boolean mask identifying the media tokens in x
170 | shape (B, T_txt)
171 | attend_previous: bool
172 | If false, ignores immediately preceding image and starts attending when following image
173 | """
174 | _, T_img, n = media.shape[:3]
175 | h = self.heads
176 |
177 | x = self.norm(x)
178 |
179 | q = self.to_q(x)
180 | media = rearrange(media, "b t n d -> b (t n) d")
181 |
182 | k, v = self.to_kv(media).chunk(2, dim=-1)
183 | q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=h)
184 |
185 | q = q * self.scale
186 |
187 | sim = einsum("... i d, ... j d -> ... i j", q, k)
188 |
189 | if exists(media_locations):
190 | # at each boolean of True, increment the time counter (relative to media time)
191 | text_time = media_locations.cumsum(dim=-1)
192 | media_time = torch.arange(T_img, device=x.device) + 1
193 |
194 | if not attend_previous:
195 | text_time[~media_locations] += 1
196 | # make sure max is still the number of images in the sequence
197 | text_time[
198 | text_time
199 | > repeat(
200 | torch.count_nonzero(media_locations, dim=1),
201 | "b -> b i",
202 | i=text_time.shape[1],
203 | )
204 | ] = 0
205 |
206 | # text time must equal media time if only attending to most immediate image
207 | # otherwise, as long as text time is greater than media time (if attending to all previous images / media)
208 | mask_op = torch.eq if self.only_attend_immediate_media else torch.ge
209 |
210 | text_to_media_mask = mask_op(
211 | rearrange(text_time, "b i -> b 1 i 1"),
212 | repeat(media_time, "j -> 1 1 1 (j n)", n=n),
213 | )
214 | sim = sim.masked_fill(~text_to_media_mask, -torch.finfo(sim.dtype).max)
215 |
216 | sim = sim - sim.amax(dim=-1, keepdim=True).detach()
217 | attn = sim.softmax(dim=-1)
218 |
219 | if exists(media_locations) and self.only_attend_immediate_media:
220 | # any text without a preceding media needs to have attention zeroed out
221 | text_without_media_mask = text_time == 0
222 | text_without_media_mask = rearrange(
223 | text_without_media_mask, "b i -> b 1 i 1"
224 | )
225 | attn = attn.masked_fill(text_without_media_mask, 0.0)
226 |
227 | out = einsum("... i j, ... j d -> ... i d", attn, v)
228 | out = rearrange(out, "b h n d -> b n (h d)")
229 | return self.to_out(out)
230 |
231 |
232 | class GatedCrossAttentionBlock(nn.Module):
233 | def __init__(
234 | self,
235 | *,
236 | dim,
237 | dim_visual,
238 | dim_head=64,
239 | heads=8,
240 | ff_mult=4,
241 | only_attend_immediate_media=True,
242 | ):
243 | super().__init__()
244 | self.attn = MaskedCrossAttention(
245 | dim=dim,
246 | dim_visual=dim_visual,
247 | dim_head=dim_head,
248 | heads=heads,
249 | only_attend_immediate_media=only_attend_immediate_media,
250 | )
251 | self.attn_gate = nn.Parameter(torch.tensor([0.0]))
252 |
253 | self.ff = FeedForward(dim, mult=ff_mult)
254 | self.ff_gate = nn.Parameter(torch.tensor([0.0]))
255 |
256 | def forward(
257 | self,
258 | x,
259 | media,
260 | media_locations=None,
261 | attend_previous=True,
262 | ):
263 | x = (
264 | self.attn(
265 | x,
266 | media,
267 | media_locations=media_locations,
268 | attend_previous=attend_previous,
269 | )
270 | * self.attn_gate.tanh()
271 | + x
272 | )
273 | x = self.ff(x) * self.ff_gate.tanh() + x
274 |
275 | return x
276 |
--------------------------------------------------------------------------------
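
helpers.py is adapted from lucidrains' flamingo-pytorch: the PerceiverResampler compresses a variable number of visual tokens into a fixed set of latents (my_embedding_layer.py instantiates it with 32 latents), and the gated cross-attention blocks are kept alongside it. A minimal shape sketch:

import torch
from Model.RadFM.helpers import PerceiverResampler

resampler = PerceiverResampler(dim=768, num_latents=32)
x = torch.randn(1, 2, 1, 256, 768)    # (b, T=2 images, F=1 frame, v=256 patch tokens, D=768)
latents = resampler(x)
print(latents.shape)                  # torch.Size([1, 2, 32, 768])
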
/Quick_demo/Model/RadFM/multimodality_model.py:
--------------------------------------------------------------------------------
1 | # Import necessary libraries
2 | from torch import nn
3 | from transformers.models.llama import LlamaForCausalLM
4 | from .my_embedding_layer import MyEmbedding
5 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
6 | import tqdm.auto as tqdm
7 | import torch.nn as nn
8 | import torch
9 | from torch.utils.checkpoint import checkpoint
10 | from torch.autograd import Variable
11 | import numpy as np
12 |
13 | class MultiLLaMAForCausalLM(nn.Module):
14 | """
15 | A multimodal LLaMA model that combines language and vision inputs
16 | for causal language modeling tasks.
17 | """
18 | def __init__(self, lang_model_path):
19 | """
20 | Initialize the multimodal model.
21 |
22 | Args:
23 | lang_model_path (str): Path to the pretrained language model
24 | """
25 | super(MultiLLaMAForCausalLM, self).__init__()
26 |
27 | # Load pretrained LLaMA model
28 | self.lang_model = LlamaForCausalLM.from_pretrained(
29 | lang_model_path,
30 | )
31 |
32 | # Enable gradient checkpointing for memory efficiency
33 | self.lang_model.gradient_checkpointing_enable()
34 | self.lang_model.enable_input_require_grads()
35 |
36 | # Initialize custom embedding layer and share weights with language model
37 | self.embedding_layer = MyEmbedding()
38 | self.embedding_layer.weight = self.lang_model.get_input_embeddings().weight
39 |
40 | # Set model dimensions
41 | self.hidden_dim = 5120
42 | self.voc_size = 32000
43 |
44 | def forward(self, lang_x, vision_x, attention_mask, labels, loss_reweight, key_words_query):
45 | """
46 | Forward pass for the multimodal model.
47 |
48 | Args:
49 | lang_x: Language input tokens
50 | vision_x: Vision input features
51 | attention_mask: Attention mask for language inputs
52 | labels: Target labels for language modeling
53 | loss_reweight: Weights for calculating loss (to prioritize certain tokens)
54 | key_words_query: Query for highlighting important words
55 |
56 | Returns:
57 | Dictionary containing model outputs including loss and logits
58 | """
59 | if labels.shape == lang_x.shape:
60 | # Set embedding mode to handle text inputs
61 | self.embedding_layer.flag = 'Text'
62 |
63 | # Get embeddings and matching loss from embedding layer
64 | input_embedding, loss_match = self.embedding_layer(lang_x, vision_x, key_words_query)
65 |
66 | # Forward pass through the language model
67 | output = self.lang_model(inputs_embeds=input_embedding, attention_mask=attention_mask, labels=labels)
68 | logits = output['logits']
69 |
70 | # Initialize regularization loss
71 | loss_reg = None
72 | if labels is not None:
73 | # Shift logits and labels for next-token prediction
74 | shift_logits = logits[..., :-1, :].contiguous()
75 | shift_labels = labels[..., 1:].contiguous()
76 | shift_loss_reweight = loss_reweight[..., 1:].contiguous()
77 |
78 | # Prepare for loss calculation
79 | loss_fct = CrossEntropyLoss(reduction='none')
80 | shift_logits = shift_logits.view(-1, self.voc_size)
81 | shift_labels = shift_labels.view(-1)
82 | shift_loss_reweight = shift_loss_reweight.view(-1)
83 |
84 | # Ensure tensors are on the same device
85 | shift_labels = shift_labels.to(shift_logits.device)
86 | shift_loss_reweight = shift_loss_reweight.to(shift_logits.device)
87 |
88 | # Calculate weighted cross-entropy loss
89 | loss_reg = loss_fct(shift_logits, shift_labels)
90 | loss_reg = torch.sum(shift_loss_reweight * loss_reg) / torch.sum(shift_loss_reweight)
91 |
92 | # Combine losses
93 | loss = loss_reg
94 | if loss_match is not None:
95 | loss = 0.8 * loss + 0.2 * loss_match
96 |
97 | # Calculate accuracy metrics
98 | logits = output['logits'][..., :-1, :].contiguous().detach()
99 | total = len(labels)
100 | predictions = torch.argmax(logits, dim=-1)
101 | labels = labels[..., 1:].contiguous()
102 |
103 | # Count correct predictions (ignoring padding tokens with -100)
104 | Acc = torch.sum(torch.all(torch.logical_or(predictions == labels, labels == -100), dim=-1))
105 | Accuracy = Acc / total
106 |
107 | return dict(
108 | # loss_reg = loss_reg,
109 | # loss_matching = loss_matching,
110 | logits=Accuracy,
111 | loss=output['loss'],
112 | )
113 |
114 | ### Unused for now; ignore the following code ###
115 | # if labels.shape == vision_x.shape:
116 | # self.embedding_layer.flag = 'Seg'
117 | # input_embedding = self.embedding_layer(lang_x, vision_x)
118 |
119 | def generate(self, lang_x, vision_x):
120 | """
121 | Generate text based on language and vision inputs.
122 |
123 | Args:
124 | lang_x: Language input tokens
125 | vision_x: Vision input features
126 |
127 | Returns:
128 | Generated token sequence
129 | """
130 | # Set embedding mode to text generation
131 | self.embedding_layer.flag = 'Text'
132 |
133 | with torch.no_grad():
134 | # Get embeddings from the embedding layer
135 | input_embedding, _ = self.embedding_layer(lang_x, vision_x)
136 |
137 | # Generate text using language model
138 | generation = self.lang_model.generate(
139 | inputs_embeds=input_embedding,
140 | max_new_tokens=200,
141 | top_k=50
142 | )
143 |
144 | return generation
--------------------------------------------------------------------------------
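
MultiLLaMAForCausalLM places MyEmbedding in front of a LLaMA-13B decoder: forward() computes a re-weighted next-token cross-entropy (optionally mixed 0.8/0.2 with the keyword-matching loss), while generate() only needs token ids and a 6-D vision tensor. A shape-level sketch of generation, assuming Language_files has been populated with full LLaMA-13B weights; real prompts need the image placeholder tokens that Quick_demo/test.py constructs:

# Shape-level sketch only; see Quick_demo/test.py for the official demo.
import torch
from Model.RadFM.multimodality_model import MultiLLaMAForCausalLM

model = MultiLLaMAForCausalLM(lang_model_path="./Language_files").eval()
lang_x = torch.randint(0, 32000, (1, 64))        # B, L token ids (placeholder prompt)
vision_x = torch.randn(1, 1, 3, 512, 512, 4)     # B, S, C, H, W, D
with torch.no_grad():
    generated_ids = model.generate(lang_x, vision_x)   # up to 200 new tokens
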
/Quick_demo/Model/RadFM/my_embedding_layer.py:
--------------------------------------------------------------------------------
1 | # Import necessary libraries
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torch
5 | from .helpers import PerceiverResampler
6 | from .utils import get_visual_encoder
7 | from einops import rearrange, repeat
8 | from einops_exts import rearrange_many
9 | import torchvision
10 | from .vit_3d import ViT
11 | from einops.layers.torch import Rearrange
12 | from .transformer_decoder import TransformerDecoder, TransformerDecoderLayer
13 | from torch.utils.checkpoint import checkpoint
14 | from torch.autograd import Variable
15 | import random
16 | from transformers import AutoTokenizer, AutoModel
17 |
18 | class MyEmbedding(nn.Module):
19 | """
20 | Custom embedding layer for multimodal inputs that combines text and vision features.
21 | """
22 | def __init__(self, num_embeddings=32000, embedding_dim=5120, perceiver_num=32, vis_dim=768,
23 | patch_size=32, frame_patch_size=4, seg_channel=256):
24 | """
25 | Initialize the multimodal embedding layer.
26 |
27 | Args:
28 | num_embeddings (int): Size of vocabulary for text embeddings
29 | embedding_dim (int): Dimension of output embeddings
30 | perceiver_num (int): Number of latent vectors in perceiver
31 | vis_dim (int): Dimension of vision features
32 | patch_size (int): Size of image patches
33 | frame_patch_size (int): Size of 3D frame patches
34 | seg_channel (int): Number of segmentation channels
35 | """
36 | super().__init__()
37 | self.num_embeddings = num_embeddings
38 | self.embedding_dim = embedding_dim
39 | # Standard embedding weight matrix for text tokens
40 | self.weight = nn.Parameter(torch.torch.randn((num_embeddings, embedding_dim)))
41 | # Special token weights for figures/images
42 | self.figure_token_weight = nn.Parameter(torch.randn((2, embedding_dim)))
43 | self.flag = 'Text' # Mode flag: 'Text' or 'Seg'
44 | self.patch_size = patch_size
45 | self.frame_patch_size = frame_patch_size
46 | self.seg_channel = seg_channel
47 |
48 | ## the MedKEBERT can be downloaded from https://huggingface.co/xmcmic/Med-KEBERT/tree/main ##
49 | # Initialize medical domain BERT model for keyword understanding
50 | self.bert_tokenizer = AutoTokenizer.from_pretrained("xmcmic/Med-KEBERT")
51 | self.bert_model = AutoModel.from_pretrained("xmcmic/Med-KEBERT")
52 | # Project BERT outputs to vision feature space
53 | self.bert_projection_fc = nn.Linear(768, vis_dim)
54 |
55 | # 3D Vision Transformer for processing volumetric medical images
56 | self.vision_encoder = ViT(
57 | image_size=512, # image size
58 | frames=512, # max number of frames
59 | image_patch_size=patch_size, # image patch size
60 | frame_patch_size=frame_patch_size, # frame patch size
61 | dim=vis_dim,
62 | depth=12,
63 | heads=8,
64 | mlp_dim=2048,
65 | dropout=0.1,
66 | emb_dropout=0.1
67 | )
68 |
69 | # Upscaling layers for vision features (used in segmentation mode)
70 | self.output_upscaling = nn.Sequential(
71 | nn.ConvTranspose3d(vis_dim, vis_dim // 4, kernel_size=2, stride=2),
72 | nn.BatchNorm3d(vis_dim // 4),
73 | nn.GELU(),
74 | nn.ConvTranspose3d(vis_dim // 4, vis_dim // 8, kernel_size=2, stride=2),
75 | nn.GELU(),
76 | )
77 |
78 | # Transformer decoder for cross-attention between text and vision
79 | decoder_layer = TransformerDecoderLayer(d_model=vis_dim, nhead=8, normalize_before=True)
80 | decoder_norm = nn.LayerNorm(vis_dim)
81 | self.transformer_decoder = TransformerDecoder(decoder_layer=decoder_layer, num_layers=4, norm=decoder_norm)
82 |
83 | # MLP for processing transformer decoder outputs
84 | self.transformer_decoder_mlp = nn.Sequential(
85 | nn.Linear(vis_dim, vis_dim // 4),
86 | nn.GELU(),
87 | nn.Linear(vis_dim // 4, vis_dim // 8),
88 | nn.GELU(),
89 | )
90 | self.vis_dim = vis_dim
91 |
92 | # Perceiver resampler to reduce sequence length of vision features
93 | self.perceiver = PerceiverResampler(dim=self.vis_dim, num_latents=perceiver_num)
94 | # Final projection to embedding dimension
95 | self.fc = nn.Linear(self.vis_dim, self.embedding_dim)
96 | # Classification head for matching keywords
97 | self.cls_head = nn.Linear(self.vis_dim // 8, 1)
98 |
99 |
100 | def forward(self, text_input, vision_x, key_words_query=None):
101 | """
102 | Forward pass for the embedding layer.
103 |
104 | Args:
105 | text_input: Text token indices [B, L]
106 | vision_x: Visual input features [B, S, C, H, W, D]
107 | key_words_query: Optional list of key words for contrastive learning
108 |
109 | Returns:
110 | tuple: (output_embeddings, loss_matching)
111 | - output_embeddings: Combined embeddings for text and vision
112 | - loss_matching: Contrastive loss for keyword matching (or None)
113 | """
114 | if self.flag == 'Text':
115 | # Process in text mode
116 | B, S, C, H, W, D = vision_x.shape
117 | # Reshape for batch processing
118 | vision_x = rearrange(vision_x, "b S c h w d-> (b S) c h w d")
119 |
120 | # Process through vision encoder
121 | vision_x, pos_embedding = self.vision_encoder(vision_x)
122 |
123 | # Reshape back to batch format
124 | vision_x = rearrange(vision_x, "(b s F) v d -> b s F v d", b=B, s=S, F=1)
125 |
126 | loss_matching = None
127 |
128 | if key_words_query is not None:
129 | ## We do not use the following parts in the final version.
130 | ## You can comment out the following code; if you do, the BERT models are not used.
131 | # key_words_query list[list[str]] B, words, each word matches corresponding vision_x embedding
132 |
133 | # Extract and deduplicate keywords
134 | query_words = [item for sublist in key_words_query for item in sublist]
135 | query_words = list(set(query_words))
136 |
137 | # Limit number of keywords to process
138 | if len(query_words) > 16:
139 | random.shuffle(query_words)
140 | query_words = query_words[0:16]
141 |
142 | if query_words != []:
143 | # Create binary labels for contrastive learning
144 | contrastive_labels = torch.zeros(B, len(query_words)) # B Q
145 | for i, sublist in enumerate(key_words_query):
146 | for j, item in enumerate(query_words):
147 | if item in sublist:
148 | contrastive_labels[i, j] = 1
149 | contrastive_labels = contrastive_labels.to(vision_x.dtype).to(vision_x.device)
150 |
151 | # Get BERT embeddings for keywords
152 | with torch.no_grad():
153 | query_words_embedding = self.bert_tokenizer(
154 | query_words,
155 | padding='max_length',
156 | truncation=True,
157 | max_length=256,
158 | return_tensors="pt"
159 | )
160 | query_words_embedding = self.bert_model(
161 | input_ids=query_words_embedding['input_ids'].to(vision_x.device),
162 | attention_mask=query_words_embedding['attention_mask'].to(vision_x.device)
163 | )['last_hidden_state'][:, 0, :].to(vision_x.dtype).to(vision_x.device) # Q,D
164 |
165 | # Project BERT embeddings to vision space
166 | query_words_embedding = self.bert_projection_fc(query_words_embedding)
167 | query_words_embedding = query_words_embedding.unsqueeze(0).repeat(B, 1, 1) # B,Q,D
168 | _, N, _ = query_words_embedding.shape
169 |
170 | # Pool vision features
171 | image_embedding = vision_x.mean(dim=1) # B V D average pooling to remove multimodality
172 | image_embedding = rearrange(image_embedding, "b F v d -> b (F v) d")
173 | pos_embedding = rearrange(pos_embedding, "(b s) v d -> b s v d", b=B, s=S)[:, 0, :, :]
174 |
175 | # Prepare inputs for transformer decoder
176 | image_embedding = image_embedding.transpose(0, 1) # (H/P W/P D/P) B D
177 | pos_embedding = pos_embedding.transpose(0, 1) # (H/P W/P D/P) B D
178 | query_words_embedding = query_words_embedding.transpose(0, 1) # N B D
179 |
180 | # Cross-attention between keywords and image features
181 | oo_embedding, _ = self.transformer_decoder(
182 | query_words_embedding, image_embedding, pos=pos_embedding
183 | )
184 | oo_embedding = oo_embedding.transpose(0, 1) # B Q D
185 | oo_embedding = rearrange(oo_embedding, 'b n d -> (b n) d')
186 | oo_embedding = self.transformer_decoder_mlp(oo_embedding)
187 | oo_embedding = self.cls_head(oo_embedding).mean(dim=-1)
188 | oo_embedding = rearrange(oo_embedding, '(b n) -> b n', b=B, n=N) # B Q
189 |
190 | # Calculate contrastive loss
191 | loss_matching = F.binary_cross_entropy_with_logits(oo_embedding, contrastive_labels)
192 |
193 | # Process vision features through perceiver resampler
194 | vision_x = self.perceiver(vision_x) # reshapes to (b, S, n, d)
195 |
196 | n = vision_x.shape[2]
197 |
198 | # Project vision features to embedding dimension
199 | vision_x = rearrange(vision_x, "b s n d -> (b s n) d")
200 | vision_x = self.fc(vision_x)
201 | vision_x = rearrange(vision_x, "(b T) d -> b T d", b=B, T=n*S)
202 |
203 | # Combine text and vision embeddings
204 | embedding_weight = torch.cat([self.weight, self.figure_token_weight], dim=0)
205 | embedding_weight = embedding_weight.unsqueeze(0).repeat(B, 1, 1)
206 | embedding_weight = torch.cat([embedding_weight, vision_x], dim=1)
207 |
208 | # Convert text indices to one-hot and compute final embeddings
209 | text_input = F.one_hot(text_input, embedding_weight.shape[1]).to(vision_x.dtype).to(vision_x.device)
210 | out_put = torch.matmul(text_input, embedding_weight)
211 |
212 | ## Unused for now; ignore the following code ##
213 | # if self.flag == 'Seg':
214 | # B,C,H,W,D = vision_x.shape
215 | # _,N,_ = text_input.shape
216 | # latent_embedding, pos_embedding = self.vision_encoder(vision_x) # B (H/P W/P D/P) D
217 |
218 | # image_embedding = latent_embedding.transpose(0,1) # (H/P W/P D/P) B D
219 | # pos_embedding = pos_embedding.transpose(0,1) # (H/P W/P D/P) B D
220 | # text_input = text_input.transpose(0,1) # N B D
221 |
222 | # mask_embedding,_ = self.transformer_decoder(text_input, image_embedding, pos = pos_embedding)
223 | # mask_embedding = mask_embedding.transpose(0,1) # B N D
224 | # mask_embedding = rearrange(mask_embedding, 'b n d -> (b n) d')
225 | # mask_embedding = self.transformer_decoder_mlp(mask_embedding)
226 | # mask_embedding = rearrange(mask_embedding, '(b n) d -> b n d', b=B, n=N,d = self.vis_dim // 8)
227 |
228 | # vision_x = rearrange(latent_embedding,'b (h w d) c -> b c h w d', h = (H // self.patch_size), w = (W // self.patch_size), d = (D // self.frame_patch_size), c=self.vis_dim)
229 | # vision_x = self.output_upscaling(vision_x) #B C H/4 W/4 D/4
230 | # out_put = torch.einsum('bchwd,bnc->bnhwd', vision_x, mask_embedding)
231 |
232 | return out_put, loss_matching
233 |
234 | # model = MyEmbedding()
235 | # text_input = torch.randint(low=0, high=3210, size=(4,2048))
236 | # image_input = torch.randn((4,3,3,512,512,4))
237 | # key_words_query = [[],[],[],['consolidation']]
238 | # print(model(text_input, image_input, key_words_query))
--------------------------------------------------------------------------------
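
The commented lines above give the intended smoke test for MyEmbedding: random token ids plus a (B, S, C, H, W, D) image tensor yield a (B, L, 5120) embedding and a keyword-matching loss (None when no keywords are passed). Running it requires the Med-KEBERT download triggered in __init__.
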
/Quick_demo/Model/RadFM/position_encoding.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | """
3 | Various positional encodings for the transformer.
4 | """
5 | import math
6 | import torch
7 | from torch import nn
8 | from einops.layers.torch import Rearrange
9 | from einops import rearrange, repeat
10 |
11 | class PositionEmbeddingSine(nn.Module):
12 | """
13 | This is a more standard version of the position embedding, very similar to the one
14 | used by the Attention is all you need paper, generalized to work on images.
15 | """
16 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
17 | super().__init__()
18 | self.num_pos_feats = num_pos_feats
19 | self.temperature = temperature
20 | self.normalize = normalize
21 | if scale is not None and normalize is False:
22 | raise ValueError("normalize should be True if scale is passed")
23 | if scale is None:
24 | scale = 2 * math.pi
25 | self.scale = scale
26 |
27 | def forward(self, tensor_list):
28 | x = tensor_list.tensors
29 | mask = tensor_list.mask
30 | assert mask is not None
31 | not_mask = ~mask
32 | y_embed = not_mask.cumsum(1, dtype=torch.float32)
33 | x_embed = not_mask.cumsum(2, dtype=torch.float32)
34 | if self.normalize:
35 | eps = 1e-6
36 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
37 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
38 |
39 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
40 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
41 |
42 | pos_x = x_embed[:, :, :, None] / dim_t
43 | pos_y = y_embed[:, :, :, None] / dim_t
44 | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
45 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
46 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
47 | return pos
48 |
49 |
50 | class PositionEmbeddingLearned(nn.Module):
51 | """
52 | Absolute pos embedding, learned.
53 | """
54 | def __init__(self, num_pos_feats=256):
55 | super().__init__()
56 | self.row_embed = nn.Embedding(50, num_pos_feats)
57 | self.col_embed = nn.Embedding(50, num_pos_feats)
58 | self.reset_parameters()
59 |
60 | def reset_parameters(self):
61 | nn.init.uniform_(self.row_embed.weight)
62 | nn.init.uniform_(self.col_embed.weight)
63 |
64 | def forward(self, tensor_list):
65 | x = tensor_list.tensors
66 | h, w = x.shape[-2:]
67 | i = torch.arange(w, device=x.device)
68 | j = torch.arange(h, device=x.device)
69 | x_emb = self.col_embed(i)
70 | y_emb = self.row_embed(j)
71 | pos = torch.cat([
72 | x_emb.unsqueeze(0).repeat(h, 1, 1),
73 | y_emb.unsqueeze(1).repeat(1, w, 1),
74 | ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1)
75 | return pos
76 |
77 | class PositionEmbeddingLearned3d(nn.Module):
78 | """
79 | Absolute pos embedding, learned.
80 | """
81 | def __init__(self, num_pos_feats=256,h_patch_num = 16, w_patch_num = 16,d_patch_num = 64):
82 | super().__init__()
83 | self.h_patch_num = h_patch_num
84 | self.w_patch_num = w_patch_num
85 | self.d_patch_num = d_patch_num
86 | self.row_embed = nn.Embedding(h_patch_num, num_pos_feats)
87 | self.col_embed = nn.Embedding(w_patch_num, num_pos_feats)
88 | self.dep_embed = nn.Embedding(d_patch_num, num_pos_feats)
89 | self.reset_parameters()
90 |
91 | def reset_parameters(self):
92 | nn.init.uniform_(self.row_embed.weight)
93 | nn.init.uniform_(self.col_embed.weight)
94 | nn.init.uniform_(self.dep_embed.weight)
95 |
96 | def forward(self, B, h, w, d,x):
97 | i = (torch.arange(h, device=x.device) + 1)* (self.h_patch_num // h) -1
98 | j = (torch.arange(w, device=x.device) + 1)* (self.w_patch_num // w) -1
99 | k = (torch.arange(d, device=x.device) + 1)* (self.d_patch_num // d) -1
100 | x_emb = self.row_embed(i).unsqueeze(1).unsqueeze(2).repeat(1,w,d,1)
101 | y_emb = self.col_embed(j).unsqueeze(0).unsqueeze(2).repeat(h,1,d,1)
102 | z_emb = self.dep_embed(k).unsqueeze(0).unsqueeze(1).repeat(h,w,1,1)
103 | pos = torch.cat([x_emb,y_emb,z_emb,], dim=-1).unsqueeze(0).repeat(B, 1, 1, 1, 1)
104 | pos = rearrange(pos,'b h w d c -> b (h w d) c')
105 | return pos
106 |
107 | def build_position_encoding(args):
108 | N_steps = args.hidden_dim // 2
109 | if args.position_embedding in ('v2', 'sine'):
110 | # TODO find a better way of exposing other arguments
111 | position_embedding = PositionEmbeddingSine(N_steps, normalize=True)
112 | elif args.position_embedding in ('v3', 'learned'):
113 | position_embedding = PositionEmbeddingLearned(N_steps)
114 | else:
115 | raise ValueError(f"not supported {args.position_embedding}")
116 |
117 | return position_embedding
118 |
119 | # Pos = PositionEmbeddingLearned3d()
120 | # x = torch.randn((8,3,32,32,1))
121 | # print(Pos(8,16,16,1,x))
--------------------------------------------------------------------------------
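
Of these encodings, only the 3D learned variant is relevant to RadFM's volumetric inputs; my_embedding_layer.py passes the positional embedding returned by the vision encoder into the transformer decoder as pos. Making the commented example above concrete (a small sketch, nothing more):

import torch
from Model.RadFM.position_encoding import PositionEmbeddingLearned3d

pos_embed = PositionEmbeddingLearned3d(num_pos_feats=256, h_patch_num=16, w_patch_num=16, d_patch_num=64)
x = torch.randn(8, 3, 32, 32, 1)      # only x.device is read inside forward
pos = pos_embed(8, 16, 16, 1, x)      # B=8, h=16, w=16, d=1
print(pos.shape)                      # torch.Size([8, 256, 768]), i.e. (B, h*w*d, 3*num_pos_feats)
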
/Quick_demo/Model/RadFM/transformer_decoder.py:
--------------------------------------------------------------------------------
1 | """
2 | Code modified from DETR tranformer:
3 | https://github.com/facebookresearch/detr
4 | Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
5 | """
6 |
7 | import copy
8 | from typing import Optional, List
9 | import pickle as cp
10 |
11 | import torch
12 | import torch.nn.functional as F
13 | from torch import nn, Tensor
14 |
15 |
16 | class TransformerDecoder(nn.Module):
17 |
18 | def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
19 | super().__init__()
20 | self.layers = _get_clones(decoder_layer, num_layers)
21 | self.num_layers = num_layers
22 | self.norm = norm
23 | self.return_intermediate = return_intermediate
24 |
25 | def forward(self, tgt, memory,
26 | tgt_mask: Optional[Tensor] = None,
27 | memory_mask: Optional[Tensor] = None,
28 | tgt_key_padding_mask: Optional[Tensor] = None,
29 | memory_key_padding_mask: Optional[Tensor] = None,
30 | pos: Optional[Tensor] = None,
31 | query_pos: Optional[Tensor] = None):
32 | output = tgt
33 | T,B,C = memory.shape
34 | intermediate = []
35 | atten_layers = []
36 | for n,layer in enumerate(self.layers):
37 |
38 | residual=True
39 | output,ws = layer(output, memory, tgt_mask=tgt_mask,
40 | memory_mask=memory_mask,
41 | tgt_key_padding_mask=tgt_key_padding_mask,
42 | memory_key_padding_mask=memory_key_padding_mask,
43 | pos=pos, query_pos=query_pos,residual=residual)
44 | atten_layers.append(ws)
45 | if self.return_intermediate:
46 | intermediate.append(self.norm(output))
47 | if self.norm is not None:
48 | output = self.norm(output)
49 | if self.return_intermediate:
50 | intermediate.pop()
51 | intermediate.append(output)
52 |
53 | if self.return_intermediate:
54 | return torch.stack(intermediate)
55 | return output,atten_layers
56 |
57 |
58 |
59 | class TransformerDecoderLayer(nn.Module):
60 |
61 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
62 | activation="relu", normalize_before=False):
63 | super().__init__()
64 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
65 | self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
66 | # Implementation of Feedforward model
67 | self.linear1 = nn.Linear(d_model, dim_feedforward)
68 | self.dropout = nn.Dropout(dropout)
69 | self.linear2 = nn.Linear(dim_feedforward, d_model)
70 |
71 | self.norm1 = nn.LayerNorm(d_model)
72 | self.norm2 = nn.LayerNorm(d_model)
73 | self.norm3 = nn.LayerNorm(d_model)
74 | self.dropout1 = nn.Dropout(dropout)
75 | self.dropout2 = nn.Dropout(dropout)
76 | self.dropout3 = nn.Dropout(dropout)
77 |
78 | self.activation = _get_activation_fn(activation)
79 | self.normalize_before = normalize_before
80 | def with_pos_embed(self, tensor, pos: Optional[Tensor]):
81 | return tensor if pos is None else tensor + pos
82 |
83 | def forward_post(self, tgt, memory,
84 | tgt_mask: Optional[Tensor] = None,
85 | memory_mask: Optional[Tensor] = None,
86 | tgt_key_padding_mask: Optional[Tensor] = None,
87 | memory_key_padding_mask: Optional[Tensor] = None,
88 | pos: Optional[Tensor] = None,
89 | query_pos: Optional[Tensor] = None,
90 | residual=True):
91 | q = k = self.with_pos_embed(tgt, query_pos)
92 | tgt2,ws = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
93 | key_padding_mask=tgt_key_padding_mask)
94 | tgt = self.norm1(tgt)
95 | tgt2,ws = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
96 | key=self.with_pos_embed(memory, pos),
97 | value=memory, attn_mask=memory_mask,
98 | key_padding_mask=memory_key_padding_mask)
99 |
100 |
101 | # attn_weights [B,NUM_Q,T]
102 | tgt = tgt + self.dropout2(tgt2)
103 | tgt = self.norm2(tgt)
104 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
105 | tgt = tgt + self.dropout3(tgt2)
106 | tgt = self.norm3(tgt)
107 | return tgt,ws
108 |
109 | def forward_pre(self, tgt, memory,
110 | tgt_mask: Optional[Tensor] = None,
111 | memory_mask: Optional[Tensor] = None,
112 | tgt_key_padding_mask: Optional[Tensor] = None,
113 | memory_key_padding_mask: Optional[Tensor] = None,
114 | pos: Optional[Tensor] = None,
115 | query_pos: Optional[Tensor] = None):
116 | tgt2 = self.norm1(tgt)
117 | q = k = self.with_pos_embed(tgt2, query_pos)
118 | tgt2,ws = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
119 | key_padding_mask=tgt_key_padding_mask)
120 | tgt = tgt + self.dropout1(tgt2)
121 | tgt2 = self.norm2(tgt)
122 | tgt2,attn_weights = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
123 | key=self.with_pos_embed(memory, pos),
124 | value=memory, attn_mask=memory_mask,
125 | key_padding_mask=memory_key_padding_mask)
126 | tgt = tgt + self.dropout2(tgt2)
127 | tgt2 = self.norm3(tgt)
128 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
129 | tgt = tgt + self.dropout3(tgt2)
130 | return tgt,attn_weights
131 |
132 | def forward(self, tgt, memory,
133 | tgt_mask: Optional[Tensor] = None,
134 | memory_mask: Optional[Tensor] = None,
135 | tgt_key_padding_mask: Optional[Tensor] = None,
136 | memory_key_padding_mask: Optional[Tensor] = None,
137 | pos: Optional[Tensor] = None,
138 | query_pos: Optional[Tensor] = None,
139 | residual=True):
140 | if self.normalize_before:
141 | return self.forward_pre(tgt, memory, tgt_mask, memory_mask,
142 | tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
143 | return self.forward_post(tgt, memory, tgt_mask, memory_mask,
144 | tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos,residual)
145 |
146 |
147 | def _get_clones(module, N):
148 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
149 |
150 |
151 |
152 | def _get_activation_fn(activation):
153 | """Return an activation function given a string"""
154 | if activation == "relu":
155 | return F.relu
156 | if activation == "gelu":
157 | return F.gelu
158 | if activation == "glu":
159 | return F.glu
160 | raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
161 |
--------------------------------------------------------------------------------
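A minimal usage sketch for the decoder defined in transformer_decoder.py above. Shapes follow nn.MultiheadAttention's (sequence, batch, embed_dim) convention, the import path assumes the script is run from inside Quick_demo, and the sizes are illustrative rather than RadFM's training configuration:

    import torch
    from torch import nn
    from Model.RadFM.transformer_decoder import TransformerDecoder, TransformerDecoderLayer

    layer = TransformerDecoderLayer(d_model=256, nhead=8)
    decoder = TransformerDecoder(layer, num_layers=2, norm=nn.LayerNorm(256))

    tgt = torch.randn(32, 4, 256)       # queries:          (num_queries, batch, d_model)
    memory = torch.randn(100, 4, 256)   # encoder features: (memory_len, batch, d_model)
    out, attn_per_layer = decoder(tgt, memory)
    print(out.shape, len(attn_per_layer))   # torch.Size([32, 4, 256]) 2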
/Quick_demo/Model/RadFM/utils.py:
--------------------------------------------------------------------------------
1 | from .blocks import ModifiedResNet,PMC_CLIP_cfg
2 | import torch
3 | from torchvision import transforms
4 | from PIL import Image
5 | import torch.nn as nn
6 | def extend_instance(obj, mixin):
7 | """Apply mixins to a class instance after creation"""
8 | base_cls = obj.__class__
9 | base_cls_name = obj.__class__.__name__
10 | obj.__class__ = type(
11 | base_cls_name, (mixin, base_cls), {}
12 | ) # mixin needs to go first for our forward() logic to work
13 |
14 |
15 | def getattr_recursive(obj, att):
16 | """
17 | Return nested attribute of obj
18 | Example: getattr_recursive(obj, 'a.b.c') is equivalent to obj.a.b.c
19 | """
20 | if att == "":
21 | return obj
22 | i = att.find(".")
23 | if i < 0:
24 | return getattr(obj, att)
25 | else:
26 | return getattr_recursive(getattr(obj, att[:i]), att[i + 1 :])
27 |
28 |
29 | def setattr_recursive(obj, att, val):
30 | """
31 | Set nested attribute of obj
32 | Example: setattr_recursive(obj, 'a.b.c', val) is equivalent to obj.a.b.c = val
33 | """
34 | if "." in att:
35 | obj = getattr_recursive(obj, ".".join(att.split(".")[:-1]))
36 | setattr(obj, att.split(".")[-1], val)
37 |
38 |
39 |
40 | def get_visual_encoder(model_str):
41 | """
42 | Args:
43 | model_str: path to the pretrained visual backbone checkpoint (e.g. a PMC-CLIP checkpoint)
44 | Return:
45 | vision_model, visual_dim, img_preprocessor
46 | """
47 | normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
48 | img_preprocessor = transforms.Compose([
49 | transforms.Resize((512,512), interpolation=Image.BICUBIC),
50 | transforms.ToTensor(),
51 | normalize,
52 | ])
53 | if 'PMC-CLIP' in model_str:
54 | #vision_cfg = json.load(open(model_args.visual_model_config,'r'))['vision_cfg']
55 | vision_cfg = PMC_CLIP_cfg()
56 | vision_heads = vision_cfg.width * 32 // vision_cfg.head_width
57 | vision_model = ModifiedResNet(
58 | layers=vision_cfg.layers,
59 | heads=vision_heads,
60 | output_dim = 768,
61 | image_size=vision_cfg.image_size,
62 | width=vision_cfg.width
63 | )
64 | vision_model = vision_load_pretrain(vision_model,model_str)
65 | vision_model = nn.Sequential(*list(vision_model.children())[:-2])
66 | visual_dim = 1024
67 | return vision_model,visual_dim,img_preprocessor
68 |
69 | def vision_load_pretrain(resnet,model_path):
70 | checkpoint = torch.load(model_path, map_location='cpu')
71 | state_dict = checkpoint['state_dict']
72 | state_dict = {k.replace('module.visual.',''): v for k, v in state_dict.items() if '.visual' in k}
73 | resnet.load_state_dict(state_dict)
74 | return resnet
75 |
--------------------------------------------------------------------------------
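A small illustration of the attribute helpers in utils.py above; the Inner/Outer toy modules are hypothetical, only the helper functions come from the file:

    import torch.nn as nn
    from Model.RadFM.utils import getattr_recursive, setattr_recursive

    class Inner(nn.Module):
        def __init__(self):
            super().__init__()
            self.proj = nn.Linear(4, 4)

    class Outer(nn.Module):
        def __init__(self):
            super().__init__()
            self.inner = Inner()

    m = Outer()
    print(getattr_recursive(m, 'inner.proj'))           # same object as m.inner.proj
    setattr_recursive(m, 'inner.proj', nn.Identity())   # swap the nested layer in place
    print(m.inner.proj)                                  # Identity()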
/Quick_demo/Model/RadFM/vit_3d.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | from einops import rearrange, repeat
5 | from einops.layers.torch import Rearrange
6 | from .position_encoding import PositionEmbeddingLearned3d
7 |
8 | # helpers
9 |
10 | def pair(t):
11 | return t if isinstance(t, tuple) else (t, t)
12 |
13 | # classes
14 |
15 | class PreNorm(nn.Module):
16 | def __init__(self, dim, fn):
17 | super().__init__()
18 | self.norm = nn.LayerNorm(dim)
19 | self.fn = fn
20 | def forward(self, x, **kwargs):
21 | return self.fn(self.norm(x), **kwargs)
22 |
23 | class FeedForward(nn.Module):
24 | def __init__(self, dim, hidden_dim, dropout = 0.):
25 | super().__init__()
26 | self.net = nn.Sequential(
27 | nn.Linear(dim, hidden_dim),
28 | nn.GELU(),
29 | nn.Dropout(dropout),
30 | nn.Linear(hidden_dim, dim),
31 | nn.Dropout(dropout)
32 | )
33 | def forward(self, x):
34 | return self.net(x)
35 |
36 | class Attention(nn.Module):
37 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
38 | super().__init__()
39 | inner_dim = dim_head * heads
40 | project_out = not (heads == 1 and dim_head == dim)
41 |
42 | self.heads = heads
43 | self.scale = dim_head ** -0.5
44 |
45 | self.attend = nn.Softmax(dim = -1)
46 | self.dropout = nn.Dropout(dropout)
47 |
48 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
49 |
50 | self.to_out = nn.Sequential(
51 | nn.Linear(inner_dim, dim),
52 | nn.Dropout(dropout)
53 | ) if project_out else nn.Identity()
54 |
55 | def forward(self, x):
56 | qkv = self.to_qkv(x).chunk(3, dim = -1)
57 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
58 |
59 | dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
60 |
61 | attn = self.attend(dots)
62 | attn = self.dropout(attn)
63 |
64 | out = torch.matmul(attn, v)
65 | out = rearrange(out, 'b h n d -> b n (h d)')
66 | return self.to_out(out)
67 |
68 | class Transformer(nn.Module):
69 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
70 | super().__init__()
71 | self.layers = nn.ModuleList([])
72 | for _ in range(depth):
73 | self.layers.append(nn.ModuleList([
74 | PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
75 | PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
76 | ]))
77 | def forward(self, x):
78 | for attn, ff in self.layers:
79 | x = attn(x) + x
80 | x = ff(x) + x
81 | return x
82 |
83 | class ViT(nn.Module):
84 | def __init__(self, *, image_size, image_patch_size, frames, frame_patch_size, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
85 | super().__init__()
86 | image_height, image_width = pair(image_size)
87 | patch_height, patch_width = pair(image_patch_size)
88 |
89 | assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
90 | assert frames % frame_patch_size == 0, 'Frames must be divisible by frame patch size'
91 |
92 | self.patch_height = patch_height
93 | self.patch_width = patch_width
94 | self.frame_patch_size = frame_patch_size
95 |
96 | num_patches = (image_height // patch_height) * (image_width // patch_width) * (frames // frame_patch_size)
97 | patch_dim = channels * patch_height * patch_width * frame_patch_size
98 |
99 | assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
100 |
101 | self.to_patch_embedding = nn.Sequential(
102 | Rearrange('b c (h p1) (w p2) (f pf) -> b (h w f) (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
103 | nn.LayerNorm(patch_dim),
104 | nn.Linear(patch_dim, dim),
105 | nn.LayerNorm(dim),
106 | )
107 |
108 | self.pos_embedding = PositionEmbeddingLearned3d(dim // 3,(image_height // patch_height), (image_width // patch_width), (frames // frame_patch_size))
109 | self.dropout = nn.Dropout(emb_dropout)
110 |
111 | self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
112 |
113 | def forward(self, video):
114 | B, C, H, W, D = video.shape
115 | x = self.to_patch_embedding(video)
116 | b, n, _ = x.shape
117 |
118 | pos = self.pos_embedding(B, H // self.patch_height, W // self.patch_width, D // self.frame_patch_size,x)
119 | x += pos
120 | x = self.dropout(x)
121 |
122 | x = self.transformer(x)
123 | return x,pos
124 |
--------------------------------------------------------------------------------
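A hedged instantiation sketch for the 3D ViT in vit_3d.py above; the hyper-parameters are illustrative, not the ones RadFM ships with. Note that dim must be divisible by 3, since PositionEmbeddingLearned3d concatenates three per-axis embeddings of size dim // 3:

    import torch
    from Model.RadFM.vit_3d import ViT

    vit = ViT(
        image_size=256, image_patch_size=32,   # 8 x 8 spatial patches
        frames=4, frame_patch_size=4,          # 1 depth patch
        dim=768, depth=2, heads=8, mlp_dim=1024,
    )
    volume = torch.randn(1, 3, 256, 256, 4)    # (B, C, H, W, D)
    tokens, pos = vit(volume)
    print(tokens.shape, pos.shape)             # both torch.Size([1, 64, 768])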
/Quick_demo/test.py:
--------------------------------------------------------------------------------
1 | # Import necessary libraries for data processing, model loading, and inference
2 | import tqdm.auto as tqdm
3 | import torch.nn.functional as F
4 | from typing import Optional, Dict, Sequence
5 | from typing import List, Optional, Tuple, Union
6 | import transformers
7 | from dataclasses import dataclass, field
8 | from Model.RadFM.multimodality_model import MultiLLaMAForCausalLM
9 | import torch
10 | from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer
11 | from torchvision import transforms
12 | from PIL import Image
13 |
14 | def get_tokenizer(tokenizer_path, max_img_size=100, image_num=32):
15 | '''
16 | Initialize the tokenizer with special tokens for image handling
17 |
18 | Args:
19 | tokenizer_path: Path to the base tokenizer
20 | max_img_size: Maximum number of images supported in a prompt
21 | image_num: Number of token embeddings per image
22 |
23 | Returns:
24 | Tuple of (tokenizer, image_padding_tokens)
25 | '''
26 | if isinstance(tokenizer_path, str):
27 | image_padding_tokens = []
28 | # Load the base tokenizer from the provided path
29 | text_tokenizer = LlamaTokenizer.from_pretrained(
30 | tokenizer_path,
31 | )
32 | # Define initial special tokens for image markup
33 | special_token = {"additional_special_tokens": ["<image>", "</image>"]}
34 |
35 | # Generate unique tokens for each image position and patch
36 | for i in range(max_img_size):
37 | image_padding_token = ""
38 |
39 | for j in range(image_num):
40 | image_token = "<image" + str(i * image_num + j) + ">"
41 | image_padding_token = image_padding_token + image_token
42 | special_token["additional_special_tokens"].append("<image" + str(i * image_num + j) + ">")
43 |
44 | # Store the concatenated tokens for each image
45 | image_padding_tokens.append(image_padding_token)
46 |
47 | # Add all special tokens to the tokenizer
48 | text_tokenizer.add_special_tokens(
49 | special_token
50 | )
51 |
52 | # Configure standard special tokens for LLaMA models
53 | text_tokenizer.pad_token_id = 0
54 | text_tokenizer.bos_token_id = 1
55 | text_tokenizer.eos_token_id = 2
56 |
57 | return text_tokenizer, image_padding_tokens
58 |
59 | def combine_and_preprocess(question, image_list, image_padding_tokens):
60 | '''
61 | Combine text and images into a multimodal input format
62 |
63 | Args:
64 | question: Text input or question to process
65 | image_list: List of images with their metadata
66 | image_padding_tokens: Special tokens for image placeholders
67 |
68 | Returns:
69 | Tuple of (processed_text, processed_images_tensor)
70 | '''
71 | # Define image transformation pipeline
72 | transform = transforms.Compose([
73 | transforms.RandomResizedCrop([512, 512], scale=(0.8, 1.0), interpolation=transforms.InterpolationMode.BICUBIC),
74 | transforms.ToTensor(),
75 | ])
76 |
77 | images = []
78 | new_questions = [_ for _ in question] # Convert question string to list of characters
79 | padding_index = 0
80 |
81 | # Process each image in the list
82 | for img in image_list:
83 | img_path = img['img_path']
84 | position = img['position'] # Where to insert the image in the text
85 |
86 | # Load and transform the image
87 | image = Image.open(img_path).convert('RGB')
88 | image = transform(image)
89 | image = image.unsqueeze(0).unsqueeze(-1) # Add batch and depth dimensions (c,w,h,d)
90 |
91 | # Resize the image to target dimensions
92 | target_H = 512
93 | target_W = 512
94 | target_D = 4
95 | # These sizes can differ between 3D and 2D images. For demonstration, we use the default sizes for 2D images here.
96 | images.append(torch.nn.functional.interpolate(image, size=(target_H, target_W, target_D)))
97 |
98 | # Insert image placeholder token at the specified position in text
99 | new_questions[position] = "<image>" + image_padding_tokens[padding_index] + "</image>" + new_questions[position]
100 | padding_index += 1
101 |
102 | # Stack all images into a batch and add batch dimension
103 | vision_x = torch.cat(images, dim=1).unsqueeze(0) # Cat tensors and expand the batch_size dim
104 |
105 | # Join the character list back into a string
106 | text = ''.join(new_questions)
107 | return text, vision_x
108 |
109 |
110 | def main():
111 | '''
112 | Main function to demonstrate the RadFM model inference
113 | '''
114 | print("Setup tokenizer")
115 | # Initialize tokenizer with special image tokens
116 | text_tokenizer, image_padding_tokens = get_tokenizer('./Language_files')
117 | print("Finish loading tokenizer")
118 |
119 | ### Initialize a simple case for demo ###
120 | print("Setup demo case")
121 | # Define a medical question about a chest X-ray
122 | question = "Can you identify any visible signs of Cardiomegaly in the image?"
123 |
124 | # Specify the image path and where to insert it in the question
125 | image = [
126 | {
127 | 'img_path': './view1_frontal.jpg',
128 | 'position': 0, # Insert at the beginning of the question
129 | }, # Can add arbitrary number of images
130 | ]
131 |
132 | # Combine text and images into model-ready format
133 | text, vision_x = combine_and_preprocess(question, image, image_padding_tokens)
134 | print("Finish loading demo case")
135 |
136 | print("Setup Model")
137 | # Initialize the multimodal model
138 | model = MultiLLaMAForCausalLM(
139 | lang_model_path='./Language_files', # Build up model based on LLaMa-13B config
140 | )
141 |
142 | # Load pretrained model weights
143 | ckpt = torch.load('./pytorch_model.bin', map_location='cpu') # Please download our checkpoint from huggingface and decompress the original zip file first
144 | model.load_state_dict(ckpt)
145 | print("Finish loading model")
146 |
147 | # Move model to GPU and set to evaluation mode
148 | model = model.to('cuda')
149 | model.eval()
150 |
151 | # Run inference without gradient computation
152 | with torch.no_grad():
153 | # Tokenize the combined text with image placeholders
154 | lang_x = text_tokenizer(
155 | text, max_length=2048, truncation=True, return_tensors="pt"
156 | )['input_ids'].to('cuda')
157 |
158 | # Move image tensor to GPU
159 | vision_x = vision_x.to('cuda')
160 |
161 | # Generate text response
162 | generation = model.generate(lang_x, vision_x)
163 |
164 | # Decode the generated token IDs to text
165 | generated_texts = text_tokenizer.batch_decode(generation, skip_special_tokens=True)
166 |
167 | # Print results
168 | print('---------------------------------------------------')
169 | print('Input: ', question)
170 | print('Output: ', generated_texts[0])
171 |
172 |
173 | if __name__ == "__main__":
174 | main()
--------------------------------------------------------------------------------
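The demo in test.py above feeds a single 2D X-ray and pads it to a depth of 4. A hedged sketch of how a 3D series could be prepared instead (the file name and target depth are illustrative, not values shipped with the demo):

    import torch
    import SimpleITK as sitk

    itk_image = sitk.ReadImage('./example_ct.nii.gz')                      # hypothetical 3D scan
    volume = torch.from_numpy(sitk.GetArrayFromImage(itk_image)).float()   # (D, H, W)
    volume = volume.permute(1, 2, 0).unsqueeze(0).repeat(3, 1, 1, 1)       # grayscale -> (3, H, W, D)
    volume = volume.unsqueeze(0)                                           # (1, 3, H, W, D)
    volume = torch.nn.functional.interpolate(volume, size=(512, 512, 64))  # resize like the 2D branch, deeper D
    vision_x = volume.unsqueeze(0)                                         # (1, 1, 3, 512, 512, 64), one image slot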
/Quick_demo/view1_frontal.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaoyi-wu/RadFM/9efba84cced9ffdcbce00d6005f255414ffa8c36/Quick_demo/view1_frontal.jpg
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | einops==0.6.1
2 | einops-exts==0.0.4
3 | huggingface-hub==0.16.4
4 | nibabel==5.1.0
5 | nmslib==2.1.1
6 | opencv-python==4.8.0.76
7 | pandas==2.0.3
8 | Pillow==9.4.0
9 | pytz==2023.3
10 | PyYAML==6.0.1
11 | scikit-learn==1.3.0
12 | scipy==1.11.2
13 | scispacy
14 | sentencepiece==0.1.99
15 | SimpleITK==2.2.1
16 | spacy==3.6.1
17 | spacy-alignments==0.9.0
18 | spacy-legacy==3.0.12
19 | spacy-loggers==1.0.4
20 | spacy-transformers==1.2.5
21 | tokenizers==0.13.3
22 | torch==2.0.1
23 | torchaudio==2.0.2
24 | torchvision==0.15.2
25 | tqdm==4.66.1
26 | transformers==4.28.1
27 |
--------------------------------------------------------------------------------
/src/Dataset/dataset/__init__.py:
--------------------------------------------------------------------------------
1 | from .radiopaedia import RadioVQA_Dataset,Radio_Modality_Dataset,Radiofeatures_Dataset,RadioCaption_Dataset
2 | from .binary import Binary_Dataset
3 | from .chestxray import ChestXray_Dataset
4 | from .vqa import VQA_Dataset
5 | from .pmcoa import PMCOA_Dataset
6 | from .paper_inline import Paper_Inline_dataset
7 | from .case_report import CaseReport_dataset
8 | from .MedPix_dataset import MedPix_Multi_Dataset,MedPix_Single_Dataset,MedPix_QA_Dataset
9 |
--------------------------------------------------------------------------------
/src/Dataset/dataset/__pycache__/MedPix_dataset.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaoyi-wu/RadFM/9efba84cced9ffdcbce00d6005f255414ffa8c36/src/Dataset/dataset/__pycache__/MedPix_dataset.cpython-39.pyc
--------------------------------------------------------------------------------
/src/Dataset/dataset/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaoyi-wu/RadFM/9efba84cced9ffdcbce00d6005f255414ffa8c36/src/Dataset/dataset/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/src/Dataset/dataset/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaoyi-wu/RadFM/9efba84cced9ffdcbce00d6005f255414ffa8c36/src/Dataset/dataset/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/src/Dataset/dataset/__pycache__/binary.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaoyi-wu/RadFM/9efba84cced9ffdcbce00d6005f255414ffa8c36/src/Dataset/dataset/__pycache__/binary.cpython-39.pyc
--------------------------------------------------------------------------------
/src/Dataset/dataset/__pycache__/case_report.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaoyi-wu/RadFM/9efba84cced9ffdcbce00d6005f255414ffa8c36/src/Dataset/dataset/__pycache__/case_report.cpython-39.pyc
--------------------------------------------------------------------------------
/src/Dataset/dataset/__pycache__/chestxray.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaoyi-wu/RadFM/9efba84cced9ffdcbce00d6005f255414ffa8c36/src/Dataset/dataset/__pycache__/chestxray.cpython-310.pyc
--------------------------------------------------------------------------------
/src/Dataset/dataset/__pycache__/chestxray.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaoyi-wu/RadFM/9efba84cced9ffdcbce00d6005f255414ffa8c36/src/Dataset/dataset/__pycache__/chestxray.cpython-39.pyc
--------------------------------------------------------------------------------
/src/Dataset/dataset/__pycache__/paper_inline.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaoyi-wu/RadFM/9efba84cced9ffdcbce00d6005f255414ffa8c36/src/Dataset/dataset/__pycache__/paper_inline.cpython-39.pyc
--------------------------------------------------------------------------------
/src/Dataset/dataset/__pycache__/pmcoa.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaoyi-wu/RadFM/9efba84cced9ffdcbce00d6005f255414ffa8c36/src/Dataset/dataset/__pycache__/pmcoa.cpython-39.pyc
--------------------------------------------------------------------------------
/src/Dataset/dataset/__pycache__/pmcvqa.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaoyi-wu/RadFM/9efba84cced9ffdcbce00d6005f255414ffa8c36/src/Dataset/dataset/__pycache__/pmcvqa.cpython-39.pyc
--------------------------------------------------------------------------------
/src/Dataset/dataset/__pycache__/radiopaedia.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaoyi-wu/RadFM/9efba84cced9ffdcbce00d6005f255414ffa8c36/src/Dataset/dataset/__pycache__/radiopaedia.cpython-39.pyc
--------------------------------------------------------------------------------
/src/Dataset/dataset/__pycache__/radiovqa.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaoyi-wu/RadFM/9efba84cced9ffdcbce00d6005f255414ffa8c36/src/Dataset/dataset/__pycache__/radiovqa.cpython-39.pyc
--------------------------------------------------------------------------------
/src/Dataset/dataset/binary.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import json
3 | import logging
4 | import os
5 | import re
6 | import difflib
7 | import sys
8 | import torch
9 | import random
10 | from abc import abstractmethod
11 | from itertools import islice
12 | from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
13 | from collections.abc import Mapping
14 | from torch.utils.data import DataLoader
15 | import PIL
16 | from torch.utils.data import Dataset
17 | import numpy as np
18 | import pandas as pd
19 | from tqdm import tqdm
20 | from torchvision import transforms
21 | from collections import defaultdict
22 | from PIL import Image
23 |
24 | class Binary_Dataset(Dataset):
25 | """_summary_
26 | Args:
27 | Dataset (_type_): caption task formulated as vqa task for Chestxray classification dataset
28 | csv_path (_type_): path to csv file
29 | prompt_json_file (_type_): path to json file containing binary cls prompts, the answer is yes/no
30 | Output:
31 | Dict: {
32 | "image_dict": {"image": image, "position": {"question": 0}}, # image is a tensor of shape [c,w,h,d] [3,512,512,1], position is a dict, random choice of 0 or len(question)
33 | "question": question, # random choice of caption prompts
34 | "answer":answer, # caption
35 | }
36 | """
37 | def __init__(self,csv_path,prompt_json_file):
38 | data_info = pd.read_csv(csv_path)
39 | self.img_path_list = np.asarray(data_info['image_path'])
40 | self.disease_list = np.asarray(data_info['disease'])
41 | self.answer_list = np.asarray(data_info['label'])
42 | self.transform = transforms.Compose([
43 | transforms.RandomResizedCrop([512,512],scale=(0.8, 1.0), interpolation=transforms.InterpolationMode.BICUBIC),
44 | transforms.ToTensor(),
45 | ])
46 | with open(prompt_json_file, 'r') as f:
47 | self.caption_prompts = json.load(f)['caption_prompt']
48 | self.map_answer = {0:'no',1:'yes'}
49 |
50 | def __len__(self):
51 | return len(self.img_path_list)
52 |
53 | def __getitem__(self, index):
54 | img_path = self.img_path_list[index]
55 | image = Image.open(img_path).convert('RGB')
56 | image = self.transform(image)
57 | image = image.unsqueeze(-1) # c,w,h,d
58 | answer = self.map_answer[self.answer_list[index]]
59 | question = random.choice(self.caption_prompts).replace('disease',self.disease_list[index])
60 | image_dict = [{
61 | "image": image,
62 | "position": {
63 | "question": len(question)
64 | }
65 | }]
66 | return {
67 | "image_dict": image_dict,
68 | "question": question,
69 | "answer":answer,
70 | }
71 |
--------------------------------------------------------------------------------
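A hedged usage sketch for Binary_Dataset in binary.py above. The csv path is a placeholder, the import assumes the script runs from src/Dataset, and yes_no_prompt.json is assumed to store its prompts under the 'caption_prompt' key that the constructor reads:

    from dataset.binary import Binary_Dataset

    # Placeholder paths: point these at your local csv / prompt files.
    dataset = Binary_Dataset(csv_path='./dataset/data_csv/binary_train.csv',
                             prompt_json_file='./dataset/yes_no_prompt.json')
    sample = dataset[0]
    print(sample['question'])                        # a yes/no prompt mentioning the csv's disease name
    print(sample['answer'])                          # "yes" or "no"
    print(sample['image_dict'][0]['image'].shape)    # torch.Size([3, 512, 512, 1])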
/src/Dataset/dataset/caption_prompt.json:
--------------------------------------------------------------------------------
1 | {
2 | "caption_prompt": [
3 | "Can you provide a caption consists of finding and impression for this medical image?",
4 | "Describe the finding and impression of the medical image you see.",
5 | "Please caption this medical scan with finding and impression.",
6 | "What is the finding and impression of this image?",
7 | "Describe this medical scan with finding and impression.",
8 | "Please write a caption consists of finding and impression for this image.",
9 | "Can you summarize with finding and impression the images presented?",
10 | "Please caption this scan with finding and impression.",
11 | "Please provide a caption consists of finding and impression for this medical image.",
12 | "Can you provide a summary consists of finding and impression of this radiograph?",
13 | "What are the findings and impression presented in this medical scan?",
14 | "Please write a caption consists of finding and impression for this scan.",
15 | "Can you provide a description consists of finding and impression of this medical scan?",
16 | "Please caption this medical scan with finding and impression.",
17 | "Can you provide a caption consists of finding and impression for this medical scan?"
18 | ]
19 | }
--------------------------------------------------------------------------------
/src/Dataset/dataset/case_report.py:
--------------------------------------------------------------------------------
1 | # Import necessary libraries for data processing, image handling, and model integration
2 | from torch.utils.data import Dataset
3 | import numpy as np
4 | import transformers
5 | import pandas as pd
6 | import copy
7 | import random
8 | import os
9 | import numpy as np
10 | import tqdm
11 | import torch
12 | import json
13 | from PIL import Image
14 | import torchvision
15 | from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer
16 | from torchvision import transforms
17 | from ast import literal_eval
18 |
19 | class CaseReport_dataset(Dataset):
20 | """
21 | Dataset class for medical case reports with associated images.
22 |
23 | This dataset processes medical case reports containing text and referenced images,
24 | formatting them for multimodal medical AI training or inference.
25 | """
26 | def __init__(self, csv_path, img_path):
27 | """
28 | Initialize the dataset.
29 |
30 | Args:
31 | csv_path: Path to CSV file containing case reports data
32 | img_path: Base path to the directory containing images
33 | """
34 | self.img_path = img_path # Root directory for images
35 | self.question_list = pd.read_csv(csv_path) # Load dataset from CSV
36 |
37 | # Define image transformation pipeline
38 | # normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
39 | self.transform = transforms.Compose([
40 | # Crop and resize images to 512x512, maintaining 80-100% of original content
41 | transforms.RandomResizedCrop([512, 512], scale=(0.8, 1.0), interpolation=transforms.InterpolationMode.BICUBIC),
42 | # Convert to tensor with values in [0, 1]
43 | transforms.ToTensor(),
44 | # normalize, # Commented out normalization
45 | ])
46 |
47 |
48 | def __len__(self):
49 | """Return the total number of samples in the dataset"""
50 | return len(self.question_list)
51 |
52 | def __getitem__(self, idx):
53 | """
54 | Get a single sample from the dataset
55 |
56 | Args:
57 | idx: Index of the sample to retrieve
58 |
59 | Returns:
60 | Dictionary containing the processed sample with image, question, and answer
61 | """
62 | # Get the row from dataframe
63 | sample = self.question_list.iloc[idx]
64 |
65 | # Extract metadata and content
66 | PMC_id = sample['PMC_id'] # PubMed Central ID
67 | img_ref = literal_eval(sample['img_ref']) # List of image references
68 | context = str(sample['context']) # Case context
69 |
70 | # Truncate long contexts to focus on beginning and end
71 | sentences = context.split('.')
72 | if len(sentences) > 5:
73 | first_sentence = sentences[0] # Keep the first sentence
74 | last_sentences = ". ".join(context.split('.')[-4:]) # Keep the last 4 sentences
75 | context = first_sentence + '. ' + last_sentences
76 |
77 | # Format question by combining context and actual question
78 | question = str(context) + '\n' + str(sample['question']).replace('Q:', '')
79 |
80 | # Clean up answer formatting
81 | answer = str(sample['answer']).replace('A:', '')
82 |
83 | # Process each referenced image
84 | images = []
85 | for img_id in img_ref:
86 | # Construct the full image path
87 | img_path = self.img_path + '/' + PMC_id + '_' + img_id + '.jpg'
88 |
89 | try:
90 | # Load and transform the image
91 | image = Image.open(img_path).convert('RGB')
92 | image = self.transform(image)
93 |
94 | # Randomly decide where to place the image in the text
95 | # Either at the end of question or at the end of context
96 | if random.random() > 0.5:
97 | images.append({'image': image, "position": {"question": len(question)}})
98 | else:
99 | images.append({'image': image, "position": {"question": len(context)}})
100 | except:
101 | # Skip images that can't be loaded
102 | continue
103 |
104 | # Return formatted sample
105 | return {
106 | "image_dict": images, # List of images with position information
107 | "question": question, # Formatted question text
108 | "answer": answer, # Answer text
109 | }
110 |
111 | # Example usage (commented out)
112 | # csv_path = '/gpfs/home/cs/leijiayu/wuchaoyi/multi_modal/Data/GPT_realdata/casa_report_train.csv'
113 | # img_path = '/home/cs/leijiayu/data/all_images/figures/'
114 | # dataset = CaseReport_dataset(csv_path, img_path)
115 | # print(dataset[0])
--------------------------------------------------------------------------------
/src/Dataset/dataset/chestxray.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import json
3 | import logging
4 | import os
5 | import re
6 | import difflib
7 | import sys
8 | import torch
9 | import random
10 | from abc import abstractmethod
11 | from itertools import islice
12 | from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
13 | from collections.abc import Mapping
14 | from torch.utils.data import DataLoader
15 | import PIL
16 | from torch.utils.data import Dataset
17 | import numpy as np
18 | import pandas as pd
19 | from tqdm import tqdm
20 | from torchvision import transforms
21 | from collections import defaultdict
22 | from PIL import Image
23 |
24 | class ChestXray_Dataset(Dataset):
25 | """_summary_
26 | Args:
27 | Dataset (_type_): caption task formulated as vqa task for Chestxray classification dataset
28 | csv_path (_type_): path to csv file
29 | img_root_dir (_type_): path to image root directory
30 | prompt_json_file (_type_): path to json file containing caption prompts
31 | Output:
32 | Dict: {
33 | "image_dict": {"image": image, "position": {"question": 0}}, # image is a tensor of shape [c,w,h,d] [3,512,512,1], position is a dict, random choice of 0 or len(question)
34 | "question": question, # random choice of caption prompts
35 | "answer":answer, # caption
36 | }
37 | """
38 | def __init__(self,csv_path,prompt_json_file):
39 | data_info = pd.read_csv(csv_path)
40 | self.img_path_list = np.asarray(data_info['image_path'])
41 | self.answer_list = np.asarray(data_info['label'])
42 | self.transform = transforms.Compose([
43 | transforms.RandomResizedCrop([512,512],scale=(0.8, 1.0), interpolation=transforms.InterpolationMode.BICUBIC),
44 | transforms.ToTensor(),
45 | ])
46 | with open(prompt_json_file, 'r') as f:
47 | self.caption_prompts = json.load(f)['caption_prompt']
48 |
49 | def __len__(self):
50 | return len(self.img_path_list)
51 |
52 | def __getitem__(self, index):
53 | img_path = self.img_path_list[index]
54 | try:
55 | image = Image.open(img_path).convert('RGB')
56 | image = self.transform(image)
57 | image = image.unsqueeze(-1) # c,w,h,d
58 | except:
59 | image = np.random.randn(3,512,512,4)
60 |
61 | answer = self.answer_list[index]
62 | question = random.choice(self.caption_prompts)
63 | image_dict = [{
64 | "image": image,
65 | "position": {
66 | "question": len(question)
67 | }
68 | }]
69 | return {
70 | "image_dict": image_dict,
71 | "question": question,
72 | "answer":answer,
73 | }
74 |
75 |
76 | if __name__ == "__main__":
77 | test_dataset = ChestXray_Dataset(csv_path = '../data_csv/chestxray.csv',
78 | prompt_json_file = './cls_prompt.json')
79 | for i in range(10):
80 | test_data = test_dataset[i]
81 | print(test_data['image_dict'][0]['image'].shape) # [3,512,512,1]
82 | # Note: make sure every chestxray img_path in the csv actually points to an existing image
83 |
84 |
85 |
86 |
87 |
--------------------------------------------------------------------------------
/src/Dataset/dataset/cls_prompt.json:
--------------------------------------------------------------------------------
1 | {
2 | "caption_prompt": [
3 | "What is the diagnosis for this chest X-ray?",
4 | "Based on this X-ray, what type of lung disease is suspected?",
5 | "Can you identify any abnormality in this chest X-ray?",
6 | "What are the findings in this chest X-ray?",
7 | "What pathology is indicated by this chest X-ray?",
8 | "What lung disease is likely present in this chest X-ray?",
9 | "What are the potential causes of the findings in this chest X-ray?",
10 | "What are your conclusions from this chest X-ray?",
11 | "What is your interpretation of this chest X-ray?",
12 | "What abnormalities are present in this chest X-ray?",
13 | "What is the differential diagnosis for the findings in this chest X-ray?"
14 | ]
15 | }
--------------------------------------------------------------------------------
/src/Dataset/dataset/data_csv/README.md:
--------------------------------------------------------------------------------
1 | Please check [data_csv](https://huggingface.co/datasets/chaoyi-wu/RadFM_data_csv) to download the train/test split csv files we used, and make sure the image paths in them are updated to match your local paths.
--------------------------------------------------------------------------------
/src/Dataset/dataset/dicom_to_png_for_VinDR_sampled_using_mammo.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import csv
4 | import json
5 | import imageio
6 |
7 | import pandas as pd
8 | import numpy as np
9 | from tqdm import tqdm
10 |
11 | import matplotlib.pyplot as plt
12 | from pydicom import dcmread
13 |
14 | def dcm_to_png(dcm_path,save_png_path):
15 | ds = dcmread(dcm_path)
16 | arr = ds.pixel_array
17 | img_array = arr.copy()
18 | cv2.normalize(arr, img_array, 0, 255, cv2.NORM_MINMAX)
19 | img_array = np.array(img_array,dtype='uint8')
20 | # img_array = cv2.resize(img_array, (512,512), interpolation = cv2.INTER_LINEAR)
21 | imageio.imwrite(save_png_path,img_array)
22 |
23 | def preprocess_csv(csv_path,data_dir,save_data_dir):
24 | data_info = pd.read_csv(csv_path)
25 | patient_file_list = data_info.iloc[:,0]
26 | img_file_list = data_info.iloc[:,2]
27 | for idx in tqdm(range(len(img_file_list))):
28 | patient_file = patient_file_list[idx]
29 | img_file = img_file_list[idx]
30 | img_path = os.path.join(data_dir,str(patient_file),str(img_file)+'.dicom')
31 | os.makedirs(os.path.join(save_data_dir,str(patient_file)), exist_ok=True)
32 | save_img_path = os.path.join(save_data_dir,str(patient_file),str(img_file)+'.png')
33 | dcm_to_png(img_path,save_img_path)
34 |
35 |
36 | csv_path = './DATA/VinDr/VinDr-Mammo/1.0.0/breast-level_annotations.csv'
37 | data_dir = './DATA/VinDr/VinDr-Mammo/1.0.0/images'
38 | save_data_dir = './DATA/VinDr/VinDr-Mammo/process/images'
39 | os.makedirs(save_data_dir, exist_ok=True)
40 | preprocess_csv(csv_path,data_dir,save_data_dir)
41 |
--------------------------------------------------------------------------------
/src/Dataset/dataset/jpg2nii_data_convert.py:
--------------------------------------------------------------------------------
1 | # Process cases according to case_id_list and save a csv file with image path and image caption
2 | import os
3 | import cv2
4 | import csv
5 | import json
6 | import subprocess
7 | import pandas as pd
8 | import numpy as np
9 | import SimpleITK as sitk
10 | from tqdm import tqdm
11 | from collections import defaultdict
12 |
13 | def get_image(single_image_dir,single_image_filenames):
14 | # single_image_filenames
15 | single_image_filenames.sort(key=lambda x: int(x.split('.')[0]))
16 | image_list = []
17 | for image_filename in single_image_filenames:
18 | image_file = os.path.join(single_image_dir, image_filename)
19 | #read jpeg to 2D array
20 | image_array = cv2.imread(image_file,0)
21 | if image_array is not None:
22 | image_size = image_array.shape
23 | image_array = cv2.resize(image_array,(512,512),interpolation = cv2.INTER_LINEAR)
24 | image_list.append(image_array)
25 | else:
26 | pass
27 | image_array = np.array(image_list) #c,w,h
28 | if len(image_array.shape) == 3:
29 | if image_array.shape[0] < image_array.shape[1]:
30 | image_array = image_array.transpose((1, 2, 0))
31 | # image_array = np.transpose(image_array, (2,0,1)) # w,h,c
32 | return image_array
33 |
34 | gray_list = ['CT','MRI','X-ray','Ultrasound','Mammography']
35 |
36 | def convert_case(case_id,image_root_dir,json_root_dir,save_case_dict,save_root_dir=None):
37 | # save_image_dir
38 | case_images_dir = os.path.join(image_root_dir, case_id)
39 | case_json_path = os.path.join(json_root_dir, case_id+'.json')
40 | with open(case_json_path, 'r') as f:
41 | data = json.load(f)
42 | image_nums = (len(data.keys())-1)//2
43 | for image_num in range(1,image_nums+1):
44 | case_dict = defaultdict(list)
45 | image_dir = os.path.join(case_images_dir, str(image_num)) #./images/1/1
46 | image_caption = data[str(image_num) + '详情']  # '详情' ("details") is the caption key used in the source JSON
47 | image_modality = data[str(image_num)][0]['modality']
48 |
49 | single_image_names = os.listdir(image_dir)
50 | single_image_names.sort(key=lambda x: int(x.split('_')[1]))
51 | save_image_series = []
52 |
53 | for single_image_name in single_image_names:
54 | single_image_dir = os.path.join(image_dir, single_image_name)
55 |
56 | save_npy_dir = os.path.join(save_root_dir,str(case_id),str(image_num))
57 |
58 |
59 | single_image_filenames = os.listdir(single_image_dir)
60 | if len(os.listdir(single_image_dir)) == 1:
61 | # 2D image
62 | image_file = os.path.join(single_image_dir, single_image_filenames[0])
63 | save_image_array = cv2.imread(image_file) # w,h,c
64 | else:
65 | save_image_array = get_image(single_image_dir,single_image_filenames)
66 | if not os.path.exists(save_npy_dir):
67 | os.makedirs(save_npy_dir)
68 | # print(save_image_array.shape)
69 | if save_image_array is not None:
70 | if len(save_image_array.shape) <= 5 and len(save_image_array.shape) >=2:
71 | save_nii_path = os.path.join(save_npy_dir,single_image_name+'.nii.gz')
72 | out = sitk.GetImageFromArray(save_image_array)
73 | sitk.WriteImage(out, save_nii_path)
74 | save_image_series.append(save_nii_path)
75 | else:
76 | save_npy_path = os.path.join(save_npy_dir,single_image_name+'.npy')
77 | np.save(save_npy_path,save_image_array)
78 | save_image_series.append(save_npy_path)
79 | case_dict['image'] = save_image_series
80 | case_dict['image_caption'] = image_caption
81 | case_dict['image_modality'] = image_modality
82 | save_case_dict.append(case_dict)
83 |
84 | if __name__ == "__main__":
85 | # case_id,image_root_dir,json_root_dir
86 | import argparse
87 | parser = argparse.ArgumentParser()
88 | parser.add_argument('--index', default=0, type=int)
89 | parser.add_argument('--add_index', default=0, type=int)
90 | parser.add_argument('--start_index', default=1, type=int)
91 | parser.add_argument('--end_index', default=1000, type=int)
92 | args = parser.parse_args()
93 |
94 | image_root_dir = '/mnt/petrelfs/share_data/zhangxiaoman/DATA/Radio_VQA/processed_file/images'
95 | json_root_dir = '/mnt/petrelfs/share_data/zhangxiaoman/DATA/Radio_VQA/processed_file/jsons'
96 | save_root_dir = '/mnt/petrelfs/share_data/zhangxiaoman/DATA/Radio_VQA/processed_file/npys'
97 | save_case_dict = []
98 |
99 | args.start_index = args.index*1000+1 + args.add_index
100 | args.end_index = (args.index+1)*1000+1
101 |
102 | for case_id in tqdm(range(args.start_index,args.end_index)):
103 | case_id = str(case_id)
104 | convert_case(case_id,image_root_dir,json_root_dir,save_case_dict,save_root_dir)
105 | # CT_0 (200, 630, 630, 3)
106 |
107 | # save to csv
108 | save_json_file = '/mnt/petrelfs/share_data/zhangxiaoman/DATA/Radio_VQA/processed_file/processed_jsons/processed_json_'+str(args.index)+'.json'
109 | with open(save_json_file, 'w', encoding='utf-8') as f:
110 | json.dump(save_case_dict, f, ensure_ascii=False,indent=4)
111 | # B, S, T, W, H, Z
112 | # srun --partition=medai --mpi=pmi2 --quotatype=auto --gres=gpu:0 -n1 --ntasks-per-node=1 python data_convert.py --index 2 --add_index 24
113 | # cd /mnt/petrelfs/share_data/zhangxiaoman/DATA/Radio_VQA/jpeg2npy
--------------------------------------------------------------------------------
/src/Dataset/dataset/mammo_prompt.json:
--------------------------------------------------------------------------------
1 | {
2 | "caption_prompt": [
3 | "What is the diagnosis for this mammogram?",
4 | "Based on this X-ray, what type of breast disease is suspected?",
5 | "Can you identify any abnormality in this mammogram?",
6 | "What are the findings in this mammogram?",
7 | "What pathology is indicated by this mammogram?",
8 | "What lung disease is likely present in this mammogram?",
9 | "What are the potential causes of the findings in this mammogram?",
10 | "What are your conclusions from this mammogram?",
11 | "What is your interpretation of this mammogram?",
12 | "What abnormalities are present in this mammogram?",
13 | "What is the differential diagnosis for the findings in this mammogram?",
14 | "What is the diagnosis for this breast X-ray?",
15 | "Can you identify any abnormality in this breast X-ray?",
16 | "What are the findings in this breast X-ray?",
17 | "What pathology is indicated by this breast X-ray?",
18 | "What lung disease is likely present in this breast X-ray?",
19 | "What are the potential causes of the findings in this breast X-ray?",
20 | "What are your conclusions from this breast X-ray?",
21 | "What is your interpretation of this breast X-ray?",
22 | "What abnormalities are present in this breast X-ray?",
23 | "What is the differential diagnosis for the findings in this breast X-ray?"
24 | ]
25 | }
--------------------------------------------------------------------------------
/src/Dataset/dataset/modality_prompt.json:
--------------------------------------------------------------------------------
1 | {
2 | "caption_prompt": [
3 | "What modality is used to take this image?",
4 | "What type of imaging modality is used to acquire the above image?",
5 | "What imaging modality is used?",
6 | "What imaging modality was used to take this image?",
7 | "What is the imaging modality?"
8 | ],
9 | "modality_prompt": [
10 | "Is this image a modality scan?",
11 | "Is the given image a modality scan?",
12 | "Is the given image a modality?"
13 | ]
14 | }
15 |
--------------------------------------------------------------------------------
/src/Dataset/dataset/nii2npy_for_radiopaedio.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import csv
4 | import json
5 | import subprocess
6 | import pandas as pd
7 | import numpy as np
8 | import SimpleITK as sitk
9 | from tqdm import tqdm
10 | from scipy import ndimage
11 | from collections import defaultdict
12 |
13 | def resize_array(array_list, shape_list):
14 | if len(array_list) == 0:
15 | return None
16 | # Get the median value of the c dimension
17 | c_values = [shape[3] for shape in shape_list]
18 | z = np.median(c_values)
19 |
20 | # Resize each array to the same size
21 | resized_arrays = []
22 | for array in array_list:
23 | resized_array = ndimage.zoom(array, (3/array.shape[0],512/array.shape[1], 512/array.shape[2], z/array.shape[3]), order=0)
24 | # print(resized_array.shape)
25 | if resized_array.shape[3] == z:
26 | resized_arrays.append(resized_array)
27 | else:
28 | if resized_array.shape[3] > z:
29 | resized_arrays.append(resized_array[:,:,:,:int(z)])
30 | else:
31 | resized_arrays.append(np.pad(resized_array, ((0,0),(0,0),(0,0),(0,int(z-resized_array.shape[3]))), 'constant', constant_values=0))
32 | # Convert the list of arrays to a numpy array
33 | resized_array = np.array(resized_arrays)
34 |
35 | return resized_array
36 |
37 | def process_image_list(image_path_list):
38 | image_shape_list = []
39 | image_array_list = []
40 | for image_path in image_path_list:
41 | if os.path.exists(image_path) == False:
42 | continue
43 | elif image_path.split('.')[-1] == 'npy':
44 | image_array = np.load(image_path) #c,w,h,d
45 | try:
46 | image_array = cv2.resize(image_array,(512,512))
47 | if len(image_array.shape) == 2:
48 | image_array = image_array[np.newaxis,:,:,np.newaxis]
49 | # 1wh1 to 3wh1
50 | image_array = np.concatenate([image_array,image_array,image_array],axis=0)
51 | elif len(image_array.shape) == 3:
52 | #whc to cwh
53 | image_array = image_array.transpose(2,0,1)[:,:,:,np.newaxis]
54 |
55 | image_shape_list.append(image_array.shape)
56 | image_array_list.append(image_array)
57 | except:
58 | pass
59 | else:
60 | itk_image = sitk.ReadImage(image_path)
61 | image_array = sitk.GetArrayFromImage(itk_image) #c,w,h,d
62 | if image_array.shape[0] != 512:
63 | image_array = cv2.resize(image_array,(512,512))
64 | if len(image_array.shape) == 2:
65 | image_array = image_array[np.newaxis,:,:,np.newaxis]
66 | image_array = np.concatenate([image_array,image_array,image_array],axis=0)
67 | elif len(image_array.shape) == 3:
68 | image_array = image_array[np.newaxis,:,:,:]
69 | image_array = np.concatenate([image_array,image_array,image_array],axis=0)
70 | image_shape_list.append(image_array.shape)
71 | image_array_list.append(image_array)
72 | save_image_array = resize_array(image_array_list, image_shape_list)
73 | return save_image_array
74 |
75 |
76 | def process_json_file(json_file,save_json_file,save_root_dir):
77 | if not os.path.exists(save_root_dir):
78 | os.makedirs(save_root_dir)
79 | with open(json_file, 'r') as f:
80 | data = json.load(f)
81 | data_len = len(data)
82 | for i in tqdm(range(data_len)):
83 | samples = data[i]['samples']
84 | for sample_i in tqdm(range(len(samples))):
85 | if samples[sample_i]['image'] == []:
86 | samples.pop(sample_i)
87 | else:
88 | image_path_list = samples[sample_i]['image']
89 | case_id = image_path_list[0].split('/')[-3]
90 | save_image_array = process_image_list(image_path_list)
91 | if save_image_array is not None:
92 | save_image_path = os.path.join(save_root_dir, str(case_id)+'_'+str(sample_i)+'.npy')
93 | np.save(save_image_path,save_image_array)
94 | # If you want to upload to AWS (S3) while processing, you can refer to this block
95 | # save_aws_image_path = save_image_path.replace('/mnt/petrelfs/share_data/zhangxiaoman/DATA/','s3://zhangxiaoman_hdd_new_share/')
96 | # os.system(f'aws s3 cp {save_image_path} {save_aws_image_path} --endpoint-url=http://10.140.27.254')
97 | # os.remove(save_image_path)
98 | # data[i]['npy_path'] = save_aws_image_path
99 | data[i]['samples'][sample_i]['npy_path'] = save_image_path
100 | data[i]['samples'][sample_i]['image_size'] = save_image_array.shape
101 | else:
102 | print(i,image_path_list)
103 | if len(samples) == 0:
104 | data.pop(i)
105 |
106 | with open(save_json_file, 'w') as f:
107 | json.dump(data, f,ensure_ascii=False,indent=4)
108 |
109 |
110 | if __name__ == "__main__":
111 | json_file = '../processed_file/processed_jsons/processed_json_2023-11-18.json'
112 | save_json_file = '../processed_file/processed_jsons/processed_json_2023-11-18-npy.json'
113 | save_root_dir = '../processed_file/processed_images'
114 |
115 | process_json_file(json_file,save_json_file,save_root_dir)
116 |
--------------------------------------------------------------------------------
/src/Dataset/dataset/paper_inline.py:
--------------------------------------------------------------------------------
1 | # Import necessary libraries for data processing, image handling, and model integration
2 | from torch.utils.data import Dataset
3 | import numpy as np
4 | import transformers
5 | import pandas as pd
6 | import copy
7 | import random
8 | import os
9 | import numpy as np
10 | import tqdm
11 | import torch
12 | import json
13 | from PIL import Image
14 | import torchvision
15 | from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer
16 | from torchvision import transforms
17 |
18 | class Paper_Inline_dataset(Dataset):
19 | """
20 | Dataset class for processing scientific papers with inline images.
21 |
22 | This dataset extracts text and associated images from scientific papers,
23 | preparing them for multimodal model training.
24 | """
25 | def __init__(self, csv_path, img_path, sample_sentence_length=50, max_img_size=3):
26 | """
27 | Initialize the dataset.
28 |
29 | Args:
30 | csv_path: Path to CSV file containing paper metadata
31 | img_path: Root directory for paper figures
32 | sample_sentence_length: Maximum number of sentences to include in a sample
33 | max_img_size: Maximum number of images to include in a sample
34 | """
35 | self.max_img_size = max_img_size
36 | self.sample_sentence_length = sample_sentence_length
37 | self.img_path = img_path
38 | # Load paper paths from CSV
39 | self.paper_path = np.array(pd.read_csv(csv_path)['PMC_path'])
40 |
41 | # Define image transformation pipeline
42 | # normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
43 | self.transform = transforms.Compose([
44 | # Crop and resize images to 512x512, maintaining 80-100% of original content
45 | transforms.RandomResizedCrop([512, 512], scale=(0.8, 1.0), interpolation=transforms.InterpolationMode.BICUBIC),
46 | # Convert to tensor with values in [0, 1]
47 | transforms.ToTensor(),
48 | # normalize, # Commented out normalization
49 | ])
50 |
51 |
52 | def __len__(self):
53 | """Return the total number of papers in the dataset"""
54 | return self.paper_path.shape[0]
55 |
56 | def __getitem__(self, idx):
57 | """
58 | Get a single sample from the dataset
59 |
60 | Args:
61 | idx: Index of the paper to retrieve
62 |
63 | Returns:
64 | Dictionary containing the processed sample with images, question, and answer
65 | """
66 | # Load the paper JSON file
67 | paper_json = self.paper_path[idx]
68 | # Extract PMC ID from the file path
69 | PMC_name = paper_json.rsplit('/', 2)[-1].split('.')[0]
70 | # Load the list of sentences with image references
71 | sentences_list = json.load(open(paper_json, 'r'))
72 | # Process the paper to extract text and images
73 | image_dict, question, answer = self.random_sample_sentence(sentences_list, PMC_name)
74 |
75 | # Return formatted sample
76 | # Note: question is empty since this is for pretraining with full paper text
77 | return {
78 | "image_dict": image_dict, # List of images with position information
79 | "question": question, # Empty string for this dataset
80 | "answer": answer, # Full text content
81 | }
82 |
83 | def random_sample_sentence(self, sentences_list, PMC_name):
84 | """
85 | Sample a segment of sentences from a paper and process inline images
86 |
87 | Args:
88 | sentences_list: List of sentences with image references
89 | PMC_name: PubMed Central ID for the paper
90 |
91 | Returns:
92 | Tuple of (processed_images, question_text, answer_text)
93 | """
94 | sentences_length = len(sentences_list)
95 |
96 | # Select a segment of the paper - either randomly or around image references
97 | p = random.random()
98 | if p >= 0.5:
99 | # Random segment selection
100 | if len(sentences_list) > self.sample_sentence_length:
101 | start = random.randint(0, sentences_length - self.sample_sentence_length)
102 | sentences_list = sentences_list[start:(start + self.sample_sentence_length)]
103 | else:
104 | # Try to select a segment containing images
105 | if len(sentences_list) > self.sample_sentence_length:
106 | sample_start = []
107 | # Find sentences with image references
108 | for sentence_id in range(len(sentences_list)):
109 | if sentences_list[sentence_id]['img_ref'] != []:
110 | # Start 10 sentences before the image if possible
111 | if sentence_id - 10 < 0:
112 | sample_start.append(0)
113 | else:
114 | if sentence_id - 10 > sentences_length - self.sample_sentence_length:
115 | sample_start.append(sentences_length - self.sample_sentence_length)
116 | else:
117 | sample_start.append(sentence_id - 10)
118 |
119 | # If no images found, select random segment
120 | if sample_start == []:
121 | start = random.randint(0, sentences_length - self.sample_sentence_length)
122 | sentences_list = sentences_list[start:(start + self.sample_sentence_length)]
123 | else:
124 | # Select a random segment that contains images
125 | start = sample_start[random.randint(0, len(sample_start) - 1)]
126 | sentences_list = sentences_list[start:(start + self.sample_sentence_length)]
127 |
128 | # Process the selected segment
129 | text = ''
130 | images = []
131 | for ix in sentences_list:
132 | sentence = ix
133 | if sentence["img_ref"] == []:
134 | # Add plain text without images
135 | text = text + sentence['text']
136 | else:
137 | # Stop if we've reached the maximum number of images
138 | if len(images) + len(sentence["img_ref"]) > self.max_img_size:
139 | break
140 |
141 | # Process each image referenced in the sentence
142 | for img_id in sentence["img_ref"]:
143 | img_path = self.img_path + '/' + PMC_name + '_' + img_id + '.jpg'
144 | if os.path.exists(img_path):
145 | try:
146 | # Load and transform the image
147 | image = Image.open(img_path).convert('RGB')
148 | image = self.transform(image)
149 | # Add image with position information
150 | images.append({'image': image, "position": {"answer": len(text)}})
151 | except:
152 | # Skip images that can't be loaded
153 | continue
154 | # Add the text after processing images
155 | text = text + sentence['text']
156 |
157 | # For this dataset, we don't use a question-answer format
158 | # Instead, all text is in the "answer" field
159 | question = ''
160 | answer = text
161 |
162 | return images, question, answer
163 |
164 | # Example usage (commented out)
165 | # csv_path = '/home/cs/leijiayu/wuchaoyi/multi_modal/Data/train_paper.csv'
166 | # img_path = '/home/cs/leijiayu/data/all_images/figures/'
167 | # dataset = Paper_Inline_dataset(csv_path, img_path)
168 | # print(dataset[0])
--------------------------------------------------------------------------------
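Each entry returned in `image_dict` above records the character offset inside the answer text at which the figure was referenced (`"position": {"answer": len(text)}`). A minimal sketch of how such offsets could be consumed downstream, assuming a hypothetical `<image>` placeholder string and an illustrative `interleave_placeholders` helper (neither is defined by this repository):

def interleave_placeholders(sample):
    # Splice a placeholder marker into the answer text at each recorded offset.
    text = sample["answer"]
    pieces, cursor = [], 0
    for img in sorted(sample["image_dict"], key=lambda d: d["position"]["answer"]):
        offset = img["position"]["answer"]
        pieces.append(text[cursor:offset])
        pieces.append("<image>")  # hypothetical placeholder, for illustration only
        cursor = offset
    pieces.append(text[cursor:])
    return "".join(pieces)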
/src/Dataset/dataset/pmcoa.py:
--------------------------------------------------------------------------------
1 | # Import necessary libraries for data processing, image handling, and model interaction
2 | import csv
3 | import json
4 | import logging
5 | import os
6 | import re
7 | import difflib
8 | import sys
9 | import torch
10 | import random
11 | from abc import abstractmethod
12 | from itertools import islice
13 | from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
14 | from collections.abc import Mapping
15 | from torch.utils.data import DataLoader
16 | import PIL
17 | from torch.utils.data import Dataset
18 | import numpy as np
19 | import pandas as pd
20 | from tqdm import tqdm
21 | from torchvision import transforms
22 | from collections import defaultdict
23 | from PIL import Image
24 |
25 | class PMCOA_Dataset(Dataset):
26 | """
27 | Dataset for processing scientific figures and captions from PubMed Central Open Access (PMC-OA).
28 |
29 | This dataset formulates image captioning as a visual question answering task,
30 | where the model is prompted with a question about an image and should respond
31 | with an appropriate caption.
32 |
33 | Args:
34 | csv_path: Path to CSV file with columns [PMC_ID, Figure_path, Caption]
35 | img_root_dir: Path to image root directory containing figure images
36 | prompt_json_file: Path to JSON file containing caption prompts
37 |
38 | Output:
39 | Dict: {
40 | "image_dict": [{"image": image, "position": {"question": position}}],
41 | # image is a tensor of shape [c,w,h,d] [3,512,512,1]
42 | # position is where to insert the image - either at start (0) or end of question
43 | "question": question, # randomly selected caption prompt
44 | "answer": answer, # original caption from the paper
45 | }
46 | """
47 | def __init__(self, csv_path, img_root_dir, prompt_json_file):
48 | """
49 | Initialize the dataset.
50 |
51 | Args:
52 | csv_path: Path to CSV file with figure metadata
53 | img_root_dir: Root directory containing figure images
54 | prompt_json_file: JSON file with caption prompts
55 | """
56 | self.img_root_dir = img_root_dir
57 |
58 | # Load metadata from CSV file
59 | data_info = pd.read_csv(csv_path)
60 | self.img_path_list = np.asarray(data_info['Figure_path'])
61 | self.caption_list = np.asarray(data_info['Caption'])
62 |
63 | # Define image transformation pipeline
64 | # normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
65 | self.transform = transforms.Compose([
66 | # Crop and resize images to 512x512, maintaining 80-100% of original content
67 | transforms.RandomResizedCrop([512, 512], scale=(0.8, 1.0), interpolation=transforms.InterpolationMode.BICUBIC),
68 | # Convert to tensor with values in [0, 1]
69 | transforms.ToTensor(),
70 | # normalize, # Commented out normalization
71 | ])
72 |
73 | # Load caption prompts from JSON file
74 | with open(prompt_json_file, 'r') as f:
75 | self.caption_prompts = json.load(f)['caption_prompt']
76 |
77 |
78 | def __len__(self):
79 | """Return the total number of samples in the dataset"""
80 | return len(self.img_path_list)
81 |
82 | def __getitem__(self, index):
83 | """
84 | Get a single sample from the dataset
85 |
86 | Args:
87 | index: Index of the sample to retrieve
88 |
89 | Returns:
90 | Dictionary containing processed sample with image, question prompt, and caption answer
91 | """
92 | # Get the image filename and construct full path
93 | file_name = self.img_path_list[index]
94 | img_path = os.path.join(self.img_root_dir, file_name)
95 |
96 | # Load and preprocess the image
97 | image = Image.open(img_path).convert('RGB')
98 | image = self.transform(image) # normalize to [0,1]
99 | image = image.unsqueeze(-1) # add depth dimension [C, H, W, 1]
100 |
101 | # Get the caption and a random prompt
102 | answer = self.caption_list[index]
103 | question = random.choice(self.caption_prompts)
104 |
105 | # Randomly decide whether to place the image before or after the question
106 | if random.random() < 0.5:
107 | # Place image before the question
108 | image_dict = {
109 | "image": image,
110 | "position": {
111 | "question": 0 # At the beginning of question
112 | }
113 | }
114 | else:
115 | # Place image after the question
116 | image_dict = {
117 | "image": image,
118 | "position": {
119 | "question": len(question) # At the end of question
120 | }
121 | }
122 |
123 | # Return formatted sample
124 | return {
125 | "image_dict": [image_dict], # List containing one image with position info
126 | "question": question, # Caption prompt
127 | "answer": answer, # Ground truth caption
128 | }
129 |
130 | if __name__ == "__main__":
131 | # Example usage for testing the dataset
132 | test_dataset = PMCOA_Dataset(
133 | csv_path='../data_csv/pmcoa_image_caption_train.csv',
134 | img_root_dir='/home/cs/leijiayu/data/PMCVQA/caption_T060_filtered_top4_sep_v0_subfigures',
135 | prompt_json_file='./caption_prompt.json'
136 | )
137 |
138 | # Test the first 10 samples
139 | for i in range(10):
140 | test_data = test_dataset[i]
141 | print(test_data['image_dict'][0]['image'].shape) # Should print [3,512,512,1]
--------------------------------------------------------------------------------
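For reference, a minimal sketch of the CSV layout that `PMCOA_Dataset` reads (only the `Figure_path` and `Caption` columns are used by the code above; the row values below are made up for illustration):

import pandas as pd

rows = [
    {"PMC_ID": "PMC0000001",
     "Figure_path": "PMC0000001_Fig1.jpg",
     "Caption": "Axial CT image of the chest at the level of the carina."},
]
# Write an illustrative metadata file in the format expected by PMCOA_Dataset.
pd.DataFrame(rows).to_csv("pmcoa_image_caption_train.csv", index=False)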
/src/Dataset/dataset/radiology_feature_prompt.json:
--------------------------------------------------------------------------------
1 | {
2 | "caption_prompt": [
3 | "What disease can be diagnosed from these radiological images and what specific features are typically observed on the images?",
4 | "Identify the disease that is typically associated with these radiological images and describe the classic radiological presentation.",
5 | "Based on the provided images, which disease is most likely to be diagnosed and how does it manifest on radiological examinations?",
6 | "Determine the disease that corresponds to the given radiographic images and describe the characteristic radiological features.",
7 | "With these radiological images, which disease would you suspect and what specific radiographic patterns are typically seen?",
8 | "Analyze the provided images and identify the disease that is commonly associated with such radiological findings. Discuss the characteristic radiographic manifestations.",
9 | "From these radiological images, diagnose the disease and explain the typical radiological presentation observed.",
10 | "Assess the radiographic images and determine the disease that is commonly linked to these findings. Describe the typical radiological features associated with this disease.",
11 | "Examine the provided radiological images and identify the disease that would most likely be diagnosed based on the characteristic radiologic appearance.",
12 | "Based on the presented radiographic findings, indicate the disease that is commonly associated with these images and describe the typical radiological patterns observed."
13 | ]
14 | }
15 |
--------------------------------------------------------------------------------
/src/Dataset/dataset/report_prompt.json:
--------------------------------------------------------------------------------
1 | {
2 | "caption_prompt": [
3 | "Can you provide a radiology report for this medical image?",
4 | "Describe the medical image you see.",
5 | "What is depicted in this picture?",
6 | "Please report this medical scan.",
7 | "What is the medical significance of this image?",
8 | "What can you infer from this picture?",
9 | "Can you provide a quick summary of this image?",
10 | "Describe this medical scan.",
11 | "Please write a radiology report for this image.",
12 | "Can you summarize the images presented?",
13 | "Please generate a radiology report for this scan.",
14 | "Describe the regions of interest in this scan.",
15 | "Please provide a caption for this medical image.",
16 | "Can you provide a brief summary of this radiograph?",
17 | "Describe the structures involved in this medical image.",
18 | "What are the findings presented in this medical scan?",
19 | "Please write a radiology report for this scan.",
20 | "Can you provide a description of this medical scan?",
21 | "Please caption this medical scan.",
22 | "Can you provide a report summary for this medical scan?"
23 | ]
24 | }
--------------------------------------------------------------------------------
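All of the prompt JSON files in this directory share the same structure: a single `caption_prompt` key holding a list of question templates, from which one prompt is drawn per sample. A minimal sketch of that loading pattern, mirroring the usage in pmcoa.py above:

import json
import random

# Any of the prompt files in this directory can be loaded the same way.
with open("report_prompt.json", "r") as f:
    prompts = json.load(f)["caption_prompt"]

question = random.choice(prompts)  # one prompt per training sample
print(question)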
/src/Dataset/dataset/spinexr_prompt.json:
--------------------------------------------------------------------------------
1 | {
2 | "caption_prompt": [
3 | "What is the diagnosis for this spine X-ray?",
4 | "Based on this X-ray, what type of spine disease is suspected?",
5 | "Can you identify any abnormality in this spine X-ray?",
6 | "What are the findings in this spine X-ray?",
7 | "What pathology is indicated by this spine X-ray?",
8 |         "What spine disease is likely present in this spine X-ray?",
9 | "What are the potential causes of the findings in this spine X-ray?",
10 | "What are your conclusions from this spine X-ray?",
11 | "What is your interpretation of this spine X-ray?",
12 | "What abnormalities are present in this spine X-ray?",
13 | "What is the differential diagnosis for the findings in this spine X-ray?"
14 | ]
15 | }
--------------------------------------------------------------------------------
/src/Dataset/dataset/vqa.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import json
3 | import logging
4 | import os
5 | import re
6 | import difflib
7 | import sys
8 | import torch
9 | import random
10 | from abc import abstractmethod
11 | from itertools import islice
12 | from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
13 | from collections.abc import Mapping
14 | from torch.utils.data import DataLoader
15 | import PIL
16 | from torch.utils.data import Dataset
17 | import numpy as np
18 | import pandas as pd
19 | from tqdm import tqdm
20 | from torchvision import transforms
21 | from collections import defaultdict
22 | from PIL import Image
23 |
24 | class VQA_Dataset(Dataset):
25 |     """
26 |     Dataset for visual question answering (VQA) samples listed in a CSV file.
27 |     Args:
28 |         csv_path: path to a CSV file with columns [img_root_dir, Figure_path, Question, Answer]
29 |     Output:
30 |         Dict: {
31 |             "image_dict": [{"image": image, "position": {"question": position}}], # image is a tensor of shape [c,w,h,d] [3,512,512,1]; position is randomly 0 (start of question) or len(question) (end of question)
32 |             "question": question, # question text from the CSV
33 |             "answer": answer, # answer text from the CSV
34 |             }
35 |     """
36 | def __init__(self,csv_path):
37 | data_info = pd.read_csv(csv_path)
38 | self.img_root_dir_list = np.asarray(data_info['img_root_dir'])
39 | self.img_path_list = np.asarray(data_info['Figure_path'])
40 | self.question_list = np.asarray(data_info['Question'])
41 | self.answer_list = np.asarray(data_info['Answer'])
42 |         # CSV columns used: img_root_dir, Figure_path, Question, Answer
43 | self.transform = transforms.Compose([
44 | transforms.RandomResizedCrop([512,512],scale=(0.8, 1.0), interpolation=transforms.InterpolationMode.BICUBIC),
45 | transforms.ToTensor(),
46 | ])
47 |
48 |
49 | def __len__(self):
50 | return len(self.img_path_list)
51 |
52 | def __getitem__(self, index):
53 | file_name = self.img_path_list[index]
54 | img_root_dir = self.img_root_dir_list[index]
55 | img_path = os.path.join(img_root_dir,file_name)
56 | image = Image.open(img_path).convert('RGB')
57 | image = self.transform(image)
58 | image = image.unsqueeze(-1)
59 | answer = self.answer_list[index]
60 | question = str(self.question_list[index])
61 | if random.random() < 0.5:
62 | image_dict = {
63 | "image": image,
64 | "position": {
65 | "question": 0
66 | }
67 | }
68 | else:
69 | image_dict = {
70 | "image": image,
71 | "position": {
72 | "question": len(question)
73 | }
74 | }
75 | return {
76 | "image_dict": [image_dict],
77 | "question": question,
78 | "answer":answer,
79 | }
80 |
81 | if __name__ == "__main__":
82 |     test_dataset = VQA_Dataset(csv_path = '../data_csv/pmcvqa_train.csv')
83 | for i in range(10):
84 | test_data = test_dataset[i]
85 | print(test_data['image_dict'][0]['image'].shape) # [3,512,512,1]
86 |
87 |
88 |
89 |
90 |
91 |
--------------------------------------------------------------------------------
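For reference, a minimal sketch of the CSV layout that `VQA_Dataset` reads (`img_root_dir`, `Figure_path`, `Question`, `Answer`); the row values below are made up for illustration:

import pandas as pd

rows = [
    {"img_root_dir": "/data/vqa_images",  # illustrative path
     "Figure_path": "sample_0001.jpg",
     "Question": "What imaging modality is shown?",
     "Answer": "CT"},
]
pd.DataFrame(rows).to_csv("pmcvqa_train.csv", index=False)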
/src/Dataset/dataset/yes_no_prompt.json:
--------------------------------------------------------------------------------
1 | {
2 | "caption_prompt": [
3 | "Is the disease visible in the image?",
4 | "Does the image show signs of disease?",
5 | "Does the image show any disease?",
6 | "Is there any disease in the affected area?",
7 | "Does the image depict any visible disease?",
8 |         "Is there a presence of disease in the image?",
9 | "Are there any visible signs of disease in the image?",
10 | "Does the image exhibit any disease?",
11 |         "Is there any disease visible in the image?",
12 | "Does the image show any signs of disease?",
13 | "Can you identify any visible signs of disease in the image?",
14 | "Is there any indication of disease in the image?",
15 | "Does the image show signs of disease?",
16 | "Does the image show any visible signs of disease?"
17 | ]
18 | }
19 |
20 |
21 |
--------------------------------------------------------------------------------
/src/Model/RadFM/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaoyi-wu/RadFM/9efba84cced9ffdcbce00d6005f255414ffa8c36/src/Model/RadFM/__init__.py
--------------------------------------------------------------------------------
/src/Model/RadFM/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaoyi-wu/RadFM/9efba84cced9ffdcbce00d6005f255414ffa8c36/src/Model/RadFM/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/src/Model/RadFM/__pycache__/blocks.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaoyi-wu/RadFM/9efba84cced9ffdcbce00d6005f255414ffa8c36/src/Model/RadFM/__pycache__/blocks.cpython-39.pyc
--------------------------------------------------------------------------------
/src/Model/RadFM/__pycache__/helpers.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaoyi-wu/RadFM/9efba84cced9ffdcbce00d6005f255414ffa8c36/src/Model/RadFM/__pycache__/helpers.cpython-39.pyc
--------------------------------------------------------------------------------
/src/Model/RadFM/__pycache__/multimodality_model.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaoyi-wu/RadFM/9efba84cced9ffdcbce00d6005f255414ffa8c36/src/Model/RadFM/__pycache__/multimodality_model.cpython-39.pyc
--------------------------------------------------------------------------------
/src/Model/RadFM/__pycache__/my_embedding_layer.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaoyi-wu/RadFM/9efba84cced9ffdcbce00d6005f255414ffa8c36/src/Model/RadFM/__pycache__/my_embedding_layer.cpython-39.pyc
--------------------------------------------------------------------------------
/src/Model/RadFM/__pycache__/position_encoding.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaoyi-wu/RadFM/9efba84cced9ffdcbce00d6005f255414ffa8c36/src/Model/RadFM/__pycache__/position_encoding.cpython-39.pyc
--------------------------------------------------------------------------------
/src/Model/RadFM/__pycache__/transformer_decoder.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaoyi-wu/RadFM/9efba84cced9ffdcbce00d6005f255414ffa8c36/src/Model/RadFM/__pycache__/transformer_decoder.cpython-39.pyc
--------------------------------------------------------------------------------
/src/Model/RadFM/__pycache__/utils.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaoyi-wu/RadFM/9efba84cced9ffdcbce00d6005f255414ffa8c36/src/Model/RadFM/__pycache__/utils.cpython-39.pyc
--------------------------------------------------------------------------------
/src/Model/RadFM/__pycache__/vit_3d.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaoyi-wu/RadFM/9efba84cced9ffdcbce00d6005f255414ffa8c36/src/Model/RadFM/__pycache__/vit_3d.cpython-39.pyc
--------------------------------------------------------------------------------
/src/Model/RadFM/blocks.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | from typing import Tuple, Union, Callable, Optional
3 |
4 | import torch
5 | import torch.nn.functional as F
6 | from torch import nn
7 | from torch.utils.checkpoint import checkpoint
8 |
9 | class PMC_CLIP_cfg:
10 | backbone: str = 'ModifiedRN50' # ['RN50', 'ModifiedRN50', 'MAE']
11 | layers: Union[Tuple[int, int, int, int], int] = [3,4,6,3]
12 | width: int = 64
13 | head_width: int = 64
14 | mlp_ratio: float = 4.0
15 | patch_size: int = 16
16 | image_size: Union[Tuple[int, int], int] = 224
17 | timm_model_name: str = None # a valid model name overrides layers, width, patch_size
18 | timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model
19 | timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
20 | timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '')
21 | patch_dropout: float = 0.0 # patch dropout rate, no dropout by default
22 | drop_attention_rate: float = 0. # Transformer Dropout
23 |
24 |
25 | class Bottleneck(nn.Module):
26 | expansion = 4
27 |
28 | def __init__(self, inplanes, planes, stride=1):
29 | super().__init__()
30 |
31 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
32 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
33 | self.bn1 = nn.BatchNorm2d(planes)
34 | self.relu1 = nn.ReLU(inplace=True)
35 |
36 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
37 | self.bn2 = nn.BatchNorm2d(planes)
38 | self.relu2 = nn.ReLU(inplace=True)
39 |
40 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
41 |
42 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
43 | self.bn3 = nn.BatchNorm2d(planes * self.expansion)
44 | self.relu3 = nn.ReLU(inplace=True)
45 |
46 | self.downsample = None
47 | self.stride = stride
48 |
49 | if stride > 1 or inplanes != planes * Bottleneck.expansion:
50 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
51 | self.downsample = nn.Sequential(OrderedDict([
52 | ("-1", nn.AvgPool2d(stride)),
53 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
54 | ("1", nn.BatchNorm2d(planes * self.expansion))
55 | ]))
56 |
57 | def forward(self, x: torch.Tensor):
58 | identity = x
59 |
60 | out = self.relu1(self.bn1(self.conv1(x)))
61 | out = self.relu2(self.bn2(self.conv2(out)))
62 | out = self.avgpool(out)
63 | out = self.bn3(self.conv3(out))
64 |
65 | if self.downsample is not None:
66 | identity = self.downsample(x)
67 |
68 | out += identity
69 | out = self.relu3(out)
70 | return out
71 |
72 |
73 | class AttentionPool2d(nn.Module):
74 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
75 | super().__init__()
76 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
77 | self.k_proj = nn.Linear(embed_dim, embed_dim)
78 | self.q_proj = nn.Linear(embed_dim, embed_dim)
79 | self.v_proj = nn.Linear(embed_dim, embed_dim)
80 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
81 | self.num_heads = num_heads
82 |
83 | def forward(self, x):
84 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
85 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
86 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
87 | x, _ = F.multi_head_attention_forward(
88 | query=x, key=x, value=x,
89 | embed_dim_to_check=x.shape[-1],
90 | num_heads=self.num_heads,
91 | q_proj_weight=self.q_proj.weight,
92 | k_proj_weight=self.k_proj.weight,
93 | v_proj_weight=self.v_proj.weight,
94 | in_proj_weight=None,
95 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
96 | bias_k=None,
97 | bias_v=None,
98 | add_zero_attn=False,
99 | dropout_p=0,
100 | out_proj_weight=self.c_proj.weight,
101 | out_proj_bias=self.c_proj.bias,
102 | use_separate_proj_weight=True,
103 | training=self.training,
104 | need_weights=False
105 | )
106 |
107 | return x[0]
108 |
109 |
110 | class ResNet(nn.Module):
111 | """
112 | RN50
113 | """
114 |
115 | def __init__(
116 | self, layers, output_dim, heads, image_size=224, width=64,
117 | block=Bottleneck,
118 | ):
119 | super().__init__()
120 | self.output_dim = output_dim
121 | self.image_size = image_size
122 |
123 | # the 1-layer stem
124 | self.conv1 = nn.Conv2d(3, width, kernel_size=3, stride=2, padding=1, bias=False)
125 | self.bn1 = nn.BatchNorm2d(width)
126 | self.relu1 = nn.ReLU(inplace=True)
127 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
128 |
129 | # residual layers
130 | self._inplanes = width # this is a *mutable* variable used during construction
131 | self.layer1 = self._make_layer(width, layers[0])
132 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
133 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
134 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
135 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
136 | # self.head = nn.Linear(512 * 6, output_dim)
137 | self.head = nn.Linear(512 * block.expansion, output_dim)
138 |
139 | # embed_dim = width * 32 # the ResNet feature dimension
140 | # self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim)
141 |
142 | self.init_parameters()
143 |
144 | def _make_layer(
145 | self,
146 | planes, blocks, stride=1,
147 | block=Bottleneck,
148 | ):
149 | layers = [block(self._inplanes, planes, stride)]
150 |
151 | self._inplanes = planes * block.expansion
152 | for _ in range(1, blocks):
153 | layers.append(block(self._inplanes, planes))
154 |
155 | return nn.Sequential(*layers)
156 |
157 | def init_parameters(self):
158 | for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]:
159 | for name, param in resnet_block.named_parameters():
160 | if name.endswith("bn3.weight"):
161 | nn.init.zeros_(param)
162 |
163 | def lock(self, unlocked_groups=0, freeze_bn_stats=False):
164 | assert unlocked_groups == 0, 'partial locking not currently supported for this model'
165 | for param in self.parameters():
166 | param.requires_grad = False
167 | if freeze_bn_stats:
168 | freeze_batch_norm_2d(self)
169 |
170 | @torch.jit.ignore
171 | def set_grad_checkpointing(self, enable=True):
172 | # FIXME support for non-transformer
173 | pass
174 |
175 | def stem(self, x):
176 | x = self.relu1(self.bn1(self.conv1(x)))
177 | x = self.maxpool(x)
178 | return x
179 |
180 | def forward(self, x):
181 | # x[0]: [batch_size, 3, 224, 224]
182 | # x[1]: [batch_size, 1]
183 | x = self.stem(x) # [batch_size, 64, 56, 56]
184 | x = self.layer1(x)
185 | x = self.layer2(x)
186 | x = self.layer3(x)
187 | x = self.layer4(x) # [batch_size, 2048, 7, 7]
188 | x = self.avgpool(x) # [batch_size, 2048, 1, 1]
189 | x = torch.flatten(x, 1) # [batch_size, 2048*1*1]
190 | x = self.head(x) # [batch_size, 1024]
191 |
192 | visual_output = dict.fromkeys(["image_features", "mim_loss"], None)
193 | visual_output.update({
194 | 'image_features': x,
195 | })
196 |
197 | return visual_output
198 |
199 |
200 | class ModifiedResNet(nn.Module):
201 | """
202 | A ResNet class that is similar to torchvision's but contains the following changes:
203 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
204 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
205 | - The final pooling layer is a QKV attention instead of an average pool
206 | """
207 |
208 | def __init__(self, layers, output_dim, heads, image_size=224, width=64):
209 | super().__init__()
210 | self.output_dim = output_dim
211 | self.image_size = image_size
212 |
213 | # the 3-layer stem
214 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
215 | self.bn1 = nn.BatchNorm2d(width // 2)
216 | self.relu1 = nn.ReLU(inplace=True)
217 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
218 | self.bn2 = nn.BatchNorm2d(width // 2)
219 | self.relu2 = nn.ReLU(inplace=True)
220 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
221 | self.bn3 = nn.BatchNorm2d(width)
222 | self.relu3 = nn.ReLU(inplace=True)
223 | self.avgpool = nn.AvgPool2d(2)
224 |
225 | # residual layers
226 | self._inplanes = width # this is a *mutable* variable used during construction
227 | self.layer1 = self._make_layer(width, layers[0])
228 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
229 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
230 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
231 |
232 | embed_dim = width * 32 # the ResNet feature dimension
233 | self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim)
234 |
235 | self.init_parameters()
236 |
237 | def _make_layer(self, planes, blocks, stride=1):
238 | layers = [Bottleneck(self._inplanes, planes, stride)]
239 |
240 | self._inplanes = planes * Bottleneck.expansion
241 | for _ in range(1, blocks):
242 | layers.append(Bottleneck(self._inplanes, planes))
243 |
244 | return nn.Sequential(*layers)
245 |
246 | def init_parameters(self):
247 | if self.attnpool is not None:
248 | std = self.attnpool.c_proj.in_features ** -0.5
249 | nn.init.normal_(self.attnpool.q_proj.weight, std=std)
250 | nn.init.normal_(self.attnpool.k_proj.weight, std=std)
251 | nn.init.normal_(self.attnpool.v_proj.weight, std=std)
252 | nn.init.normal_(self.attnpool.c_proj.weight, std=std)
253 |
254 | for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]:
255 | for name, param in resnet_block.named_parameters():
256 | if name.endswith("bn3.weight"):
257 | nn.init.zeros_(param)
258 |
259 | def lock(self, unlocked_groups=0, freeze_bn_stats=False):
260 | assert unlocked_groups == 0, 'partial locking not currently supported for this model'
261 | for param in self.parameters():
262 | param.requires_grad = False
263 | if freeze_bn_stats:
264 | freeze_batch_norm_2d(self)
265 |
266 | @torch.jit.ignore
267 | def set_grad_checkpointing(self, enable=True):
268 | # FIXME support for non-transformer
269 | pass
270 |
271 | def stem(self, x):
272 | x = self.relu1(self.bn1(self.conv1(x)))
273 | x = self.relu2(self.bn2(self.conv2(x)))
274 | x = self.relu3(self.bn3(self.conv3(x)))
275 | x = self.avgpool(x)
276 | return x
277 |
278 | def forward(self, x):
279 | x = self.stem(x)
280 | x = self.layer1(x)
281 | x = self.layer2(x)
282 | x = self.layer3(x)
283 | x = self.layer4(x)
284 | x = self.attnpool(x)
285 |
286 | visual_output = dict.fromkeys(["image_features", "mim_loss"], None)
287 | visual_output.update({
288 | 'image_features': x,
289 | })
290 |
291 | return visual_output
292 |
293 |
294 | class LayerNorm(nn.LayerNorm):
295 | """Subclass torch's LayerNorm to handle fp16."""
296 |
297 | def forward(self, x: torch.Tensor):
298 | orig_type = x.dtype
299 | x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
300 | return x.to(orig_type)
301 |
302 |
303 | class QuickGELU(nn.Module):
304 | # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory
305 | def forward(self, x: torch.Tensor):
306 | return x * torch.sigmoid(1.702 * x)
307 |
308 |
309 | class ResidualAttentionBlock(nn.Module):
310 | def __init__(
311 | self, d_model: int, n_head: int, mlp_ratio: float = 4.0, act_layer: Callable = nn.GELU,
312 | drop_attention_rate: float = 0.,
313 | ):
314 | super().__init__()
315 |
316 | self.attn = nn.MultiheadAttention(
317 | embed_dim=d_model,
318 | num_heads=n_head,
319 | dropout=drop_attention_rate,
320 | )
321 | self.ln_1 = LayerNorm(d_model)
322 | mlp_width = int(d_model * mlp_ratio)
323 | self.mlp = nn.Sequential(OrderedDict([
324 | ("c_fc", nn.Linear(d_model, mlp_width)),
325 | ("gelu", act_layer()),
326 | ("c_proj", nn.Linear(mlp_width, d_model))
327 | ]))
328 | self.ln_2 = LayerNorm(d_model)
329 |
330 | def attention(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
331 | return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0]
332 |
333 | def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
334 | x = x + self.attention(self.ln_1(x), attn_mask=attn_mask)
335 | x = x + self.mlp(self.ln_2(x))
336 | return x
337 |
338 |
339 | class PatchDropout(nn.Module):
340 | """
341 | https://arxiv.org/abs/2212.00794
342 | """
343 |
344 | def __init__(self, prob, exclude_first_token=True):
345 | super().__init__()
346 | assert 0 <= prob < 1.
347 | self.prob = prob
348 | self.exclude_first_token = exclude_first_token # exclude CLS token
349 |
350 | def forward(self, x):
351 | if not self.training or self.prob == 0.:
352 | return x
353 |
354 | if self.exclude_first_token:
355 | cls_tokens, x = x[:, :1], x[:, 1:]
356 | else:
357 | cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])
358 |
359 | batch = x.size()[0]
360 | num_tokens = x.size()[1]
361 |
362 | batch_indices = torch.arange(batch)
363 | batch_indices = batch_indices[..., None]
364 |
365 | keep_prob = 1 - self.prob
366 | num_patches_keep = max(1, int(num_tokens * keep_prob))
367 |
368 | rand = torch.randn(batch, num_tokens)
369 | patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices
370 |
371 | x = x[batch_indices, patch_indices_keep]
372 |
373 | if self.exclude_first_token:
374 | x = torch.cat((cls_tokens, x), dim=1)
375 |
376 | return x
377 |
378 |
379 | class Transformer(nn.Module):
380 | def __init__(
381 | self, width: int, layers: int, heads: int, mlp_ratio: float = 4.0, act_layer: Callable = nn.GELU,
382 | drop_attention_rate: float = 0.,
383 | ):
384 | super().__init__()
385 | self.width = width
386 | self.layers = layers
387 | self.grad_checkpointing = False
388 |
389 | self.resblocks = nn.ModuleList([
390 | ResidualAttentionBlock(width, heads, mlp_ratio, act_layer=act_layer, drop_attention_rate=drop_attention_rate)
391 | for _ in range(layers)
392 | ])
393 |
394 | def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
395 | for r in self.resblocks:
396 | if self.grad_checkpointing and not torch.jit.is_scripting():
397 | x = checkpoint(r, x, attn_mask)
398 | else:
399 | x = r(x, attn_mask=attn_mask)
400 | return x
--------------------------------------------------------------------------------
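A minimal sketch of running the `ModifiedResNet` backbone defined above on a dummy batch. The layer counts and width mirror the PMC_CLIP_cfg defaults; `output_dim=768` and `heads=32` are illustrative choices (32 heads gives a 64-dim head for the 2048-dim pooled features), not values fixed by this file:

import torch
from blocks import ModifiedResNet  # assumes blocks.py is on the import path

model = ModifiedResNet(layers=[3, 4, 6, 3], output_dim=768,
                       heads=32, image_size=224, width=64)
model.eval()

with torch.no_grad():
    out = model(torch.randn(2, 3, 224, 224))

print(out["image_features"].shape)  # torch.Size([2, 768])
print(out["mim_loss"])              # None: this backbone only returns image features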
/src/Model/RadFM/helpers.py:
--------------------------------------------------------------------------------
1 | """
2 | Taken from https://github.com/lucidrains/flamingo-pytorch
3 | """
4 |
5 | import torch
6 | from einops import rearrange, repeat
7 | from einops_exts import rearrange_many
8 | from torch import einsum, nn
9 |
10 |
11 | def exists(val):
12 | return val is not None
13 |
14 |
15 | def FeedForward(dim, mult=4):
16 | inner_dim = int(dim * mult)
17 | return nn.Sequential(
18 | nn.LayerNorm(dim),
19 | nn.Linear(dim, inner_dim, bias=False),
20 | nn.GELU(),
21 | nn.Linear(inner_dim, dim, bias=False),
22 | )
23 |
24 |
25 | class PerceiverAttention(nn.Module):
26 | def __init__(self, *, dim, dim_head=64, heads=8):
27 | super().__init__()
28 | self.scale = dim_head**-0.5
29 | self.heads = heads
30 | inner_dim = dim_head * heads
31 |
32 | self.norm_media = nn.LayerNorm(dim)
33 | self.norm_latents = nn.LayerNorm(dim)
34 |
35 | self.to_q = nn.Linear(dim, inner_dim, bias=False)
36 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
37 | self.to_out = nn.Linear(inner_dim, dim, bias=False)
38 |
39 | def forward(self, x, latents):
40 | """
41 | Args:
42 | x (torch.Tensor): image features
43 | shape (b, T, n1, D)
44 |             latents (torch.Tensor): latent features
45 | shape (b, T, n2, D)
46 | """
47 | x = self.norm_media(x)
48 | latents = self.norm_latents(latents)
49 |
50 | h = self.heads
51 |
52 | q = self.to_q(latents)
53 | kv_input = torch.cat((x, latents), dim=-2)
54 | k, v = self.to_kv(kv_input).chunk(2, dim=-1)
55 | q, k, v = rearrange_many((q, k, v), "b t n (h d) -> b h t n d", h=h)
56 | q = q * self.scale
57 |
58 | # attention
59 | sim = einsum("... i d, ... j d -> ... i j", q, k)
60 | sim = sim - sim.amax(dim=-1, keepdim=True).detach()
61 | attn = sim.softmax(dim=-1)
62 |
63 | out = einsum("... i j, ... j d -> ... i d", attn, v)
64 | out = rearrange(out, "b h t n d -> b t n (h d)", h=h)
65 | return self.to_out(out)
66 |
67 |
68 | class PerceiverResampler(nn.Module):
69 | def __init__(
70 | self,
71 | *,
72 | dim,
73 | depth=6,
74 | dim_head=64,
75 | heads=8,
76 | num_latents=64,
77 | max_num_media=None,
78 | max_num_frames=None,
79 | ff_mult=4,
80 | ):
81 | super().__init__()
82 | self.latents = nn.Parameter(torch.randn(num_latents, dim))
83 | self.frame_embs = (
84 | nn.Parameter(torch.randn(max_num_frames, dim))
85 | if exists(max_num_frames)
86 | else None
87 | )
88 | self.media_time_embs = (
89 | nn.Parameter(torch.randn(max_num_media, 1, dim))
90 | if exists(max_num_media)
91 | else None
92 | )
93 |
94 | self.layers = nn.ModuleList([])
95 | for _ in range(depth):
96 | self.layers.append(
97 | nn.ModuleList(
98 | [
99 | PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
100 | FeedForward(dim=dim, mult=ff_mult),
101 | ]
102 | )
103 | )
104 |
105 | self.norm = nn.LayerNorm(dim)
106 |
107 | def forward(self, x):
108 | """
109 | Args:
110 | x (torch.Tensor): image features
111 | shape (b, T, F, v, D)
112 | Returns:
113 | shape (b, T, n, D) where n is self.num_latents
114 | """
115 | b, T, F, v = x.shape[:4]
116 |
117 | # frame and media time embeddings
118 | if exists(self.frame_embs):
119 | frame_embs = repeat(self.frame_embs[:F], "F d -> b T F v d", b=b, T=T, v=v)
120 | x = x + frame_embs
121 | x = rearrange(
122 | x, "b T F v d -> b T (F v) d"
123 | ) # flatten the frame and spatial dimensions
124 | if exists(self.media_time_embs):
125 | x = x + self.media_time_embs[:T]
126 |
127 | # blocks
128 | latents = repeat(self.latents, "n d -> b T n d", b=b, T=T)
129 | for attn, ff in self.layers:
130 | latents = attn(x, latents) + latents
131 | latents = ff(latents) + latents
132 | return self.norm(latents)
133 |
134 |
135 | # gated cross attention
136 |
137 |
138 | class MaskedCrossAttention(nn.Module):
139 | def __init__(
140 | self,
141 | *,
142 | dim,
143 | dim_visual,
144 | dim_head=64,
145 | heads=8,
146 | only_attend_immediate_media=True,
147 | ):
148 | super().__init__()
149 | self.scale = dim_head**-0.5
150 | self.heads = heads
151 | inner_dim = dim_head * heads
152 |
153 | self.norm = nn.LayerNorm(dim)
154 |
155 | self.to_q = nn.Linear(dim, inner_dim, bias=False)
156 | self.to_kv = nn.Linear(dim_visual, inner_dim * 2, bias=False)
157 | self.to_out = nn.Linear(inner_dim, dim, bias=False)
158 |
159 | # whether for text to only attend to immediate preceding image, or all previous images
160 | self.only_attend_immediate_media = only_attend_immediate_media
161 |
162 | def forward(self, x, media, media_locations=None, attend_previous=True):
163 | """
164 | Args:
165 | x (torch.Tensor): text features
166 | shape (B, T_txt, D_txt)
167 | media (torch.Tensor): image features
168 | shape (B, T_img, n, D_img) where n is the dim of the latents
169 | media_locations: boolean mask identifying the media tokens in x
170 | shape (B, T_txt)
171 | attend_previous: bool
172 |                 If False, the text ignores the immediately preceding image and starts attending from the following image onward
173 | """
174 | _, T_img, n = media.shape[:3]
175 | h = self.heads
176 |
177 | x = self.norm(x)
178 |
179 | q = self.to_q(x)
180 | media = rearrange(media, "b t n d -> b (t n) d")
181 |
182 | k, v = self.to_kv(media).chunk(2, dim=-1)
183 | q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=h)
184 |
185 | q = q * self.scale
186 |
187 | sim = einsum("... i d, ... j d -> ... i j", q, k)
188 |
189 | if exists(media_locations):
190 | # at each boolean of True, increment the time counter (relative to media time)
191 | text_time = media_locations.cumsum(dim=-1)
192 | media_time = torch.arange(T_img, device=x.device) + 1
193 |
194 | if not attend_previous:
195 | text_time[~media_locations] += 1
196 | # make sure max is still the number of images in the sequence
197 | text_time[
198 | text_time
199 | > repeat(
200 | torch.count_nonzero(media_locations, dim=1),
201 | "b -> b i",
202 | i=text_time.shape[1],
203 | )
204 | ] = 0
205 |
206 | # text time must equal media time if only attending to most immediate image
207 | # otherwise, as long as text time is greater than media time (if attending to all previous images / media)
208 | mask_op = torch.eq if self.only_attend_immediate_media else torch.ge
209 |
210 | text_to_media_mask = mask_op(
211 | rearrange(text_time, "b i -> b 1 i 1"),
212 | repeat(media_time, "j -> 1 1 1 (j n)", n=n),
213 | )
214 | sim = sim.masked_fill(~text_to_media_mask, -torch.finfo(sim.dtype).max)
215 |
216 | sim = sim - sim.amax(dim=-1, keepdim=True).detach()
217 | attn = sim.softmax(dim=-1)
218 |
219 | if exists(media_locations) and self.only_attend_immediate_media:
220 | # any text without a preceding media needs to have attention zeroed out
221 | text_without_media_mask = text_time == 0
222 | text_without_media_mask = rearrange(
223 | text_without_media_mask, "b i -> b 1 i 1"
224 | )
225 | attn = attn.masked_fill(text_without_media_mask, 0.0)
226 |
227 | out = einsum("... i j, ... j d -> ... i d", attn, v)
228 | out = rearrange(out, "b h n d -> b n (h d)")
229 | return self.to_out(out)
230 |
231 |
232 | class GatedCrossAttentionBlock(nn.Module):
233 | def __init__(
234 | self,
235 | *,
236 | dim,
237 | dim_visual,
238 | dim_head=64,
239 | heads=8,
240 | ff_mult=4,
241 | only_attend_immediate_media=True,
242 | ):
243 | super().__init__()
244 | self.attn = MaskedCrossAttention(
245 | dim=dim,
246 | dim_visual=dim_visual,
247 | dim_head=dim_head,
248 | heads=heads,
249 | only_attend_immediate_media=only_attend_immediate_media,
250 | )
251 | self.attn_gate = nn.Parameter(torch.tensor([0.0]))
252 |
253 | self.ff = FeedForward(dim, mult=ff_mult)
254 | self.ff_gate = nn.Parameter(torch.tensor([0.0]))
255 |
256 | def forward(
257 | self,
258 | x,
259 | media,
260 | media_locations=None,
261 | attend_previous=True,
262 | ):
263 | x = (
264 | self.attn(
265 | x,
266 | media,
267 | media_locations=media_locations,
268 | attend_previous=attend_previous,
269 | )
270 | * self.attn_gate.tanh()
271 | + x
272 | )
273 | x = self.ff(x) * self.ff_gate.tanh() + x
274 |
275 | return x
276 |
--------------------------------------------------------------------------------
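A quick shape check for the `PerceiverResampler` above: it compresses a variable number of visual tokens per image (v) down to a fixed `num_latents`, keeping the batch (b) and media (T) dimensions. The sizes below are illustrative:

import torch
from helpers import PerceiverResampler  # assumes helpers.py is on the import path

resampler = PerceiverResampler(dim=768, depth=2, num_latents=32)

x = torch.randn(2, 1, 1, 256, 768)  # (b, T, F, v, D): 2 samples, 1 image, 256 patch tokens
with torch.no_grad():
    out = resampler(x)

print(out.shape)  # torch.Size([2, 1, 32, 768])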
/src/Model/RadFM/multimodality_model.py:
--------------------------------------------------------------------------------
1 | # Import necessary libraries
2 | from torch import nn
3 | from transformers.models.llama import LlamaForCausalLM
4 | from .my_embedding_layer import MyEmbedding
5 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
6 | import tqdm.auto as tqdm
7 | import torch.nn as nn
8 | import torch
9 | from torch.utils.checkpoint import checkpoint
10 | from torch.autograd import Variable
11 | import numpy as np
12 |
13 | class MultiLLaMAForCausalLM(nn.Module):
14 | """
15 | A multimodal LLaMA model that combines language and vision inputs
16 | for causal language modeling tasks.
17 | """
18 | def __init__(self, lang_model_path):
19 | """
20 | Initialize the multimodal model.
21 |
22 | Args:
23 | lang_model_path (str): Path to the pretrained language model
24 | """
25 | super(MultiLLaMAForCausalLM, self).__init__()
26 |
27 | # Load pretrained LLaMA model
28 | self.lang_model = LlamaForCausalLM.from_pretrained(
29 | lang_model_path,
30 | )
31 |
32 | # Enable gradient checkpointing for memory efficiency
33 | self.lang_model.gradient_checkpointing_enable()
34 | self.lang_model.enable_input_require_grads()
35 |
36 | # Initialize custom embedding layer and share weights with language model
37 | self.embedding_layer = MyEmbedding()
38 | self.embedding_layer.weight = self.lang_model.get_input_embeddings().weight
39 |
40 | # Set model dimensions
41 | self.hidden_dim = 5120
42 | self.voc_size = 32000
43 |
44 | def forward(self, lang_x, vision_x, attention_mask, labels, loss_reweight, key_words_query):
45 | """
46 | Forward pass for the multimodal model.
47 |
48 | Args:
49 | lang_x: Language input tokens
50 | vision_x: Vision input features
51 | attention_mask: Attention mask for language inputs
52 | labels: Target labels for language modeling
53 | loss_reweight: Weights for calculating loss (to prioritize certain tokens)
54 | key_words_query: Query for highlighting important words
55 |
56 | Returns:
57 | Dictionary containing model outputs including loss and logits
58 | """
59 | if labels.shape == lang_x.shape:
60 | # Set embedding mode to handle text inputs
61 | self.embedding_layer.flag = 'Text'
62 |
63 | # Get embeddings and matching loss from embedding layer
64 | input_embedding, loss_match = self.embedding_layer(lang_x, vision_x, key_words_query)
65 |
66 | # Forward pass through the language model
67 | output = self.lang_model(inputs_embeds=input_embedding, attention_mask=attention_mask, labels=labels)
68 | logits = output['logits']
69 |
70 | # Initialize regularization loss
71 | loss_reg = None
72 | if labels is not None:
73 | # Shift logits and labels for next-token prediction
74 | shift_logits = logits[..., :-1, :].contiguous()
75 | shift_labels = labels[..., 1:].contiguous()
76 | shift_loss_reweight = loss_reweight[..., 1:].contiguous()
77 |
78 | # Prepare for loss calculation
79 | loss_fct = CrossEntropyLoss(reduction='none')
80 | shift_logits = shift_logits.view(-1, self.voc_size)
81 | shift_labels = shift_labels.view(-1)
82 | shift_loss_reweight = shift_loss_reweight.view(-1)
83 |
84 | # Ensure tensors are on the same device
85 | shift_labels = shift_labels.to(shift_logits.device)
86 | shift_loss_reweight = shift_loss_reweight.to(shift_logits.device)
87 |
88 | # Calculate weighted cross-entropy loss
89 | loss_reg = loss_fct(shift_logits, shift_labels)
90 | loss_reg = torch.sum(shift_loss_reweight * loss_reg) / torch.sum(shift_loss_reweight)
91 |
92 | # Combine losses
93 | loss = loss_reg
94 | if loss_match is not None:
95 | loss = 0.8 * loss + 0.2 * loss_match
96 |
97 | # Calculate accuracy metrics
98 | logits = output['logits'][..., :-1, :].contiguous().detach()
99 | total = len(labels)
100 | predictions = torch.argmax(logits, dim=-1)
101 | labels = labels[..., 1:].contiguous()
102 |
103 | # Count correct predictions (ignoring padding tokens with -100)
104 | Acc = torch.sum(torch.all(torch.logical_or(predictions == labels, labels == -100), dim=-1))
105 | Accuracy = Acc / total
106 |
107 | return dict(
108 | # loss_reg = loss_reg,
109 | # loss_matching = loss_matching,
110 | logits=Accuracy,
111 | loss=output['loss'],
112 | )
113 |
114 |         ### unused for now; ignore the following code ###
115 | # if labels.shape == vision_x.shape:
116 | # self.embedding_layer.flag = 'Seg'
117 | # input_embedding = self.embedding_layer(lang_x, vision_x)
118 |
119 | def generate(self, lang_x, vision_x):
120 | """
121 | Generate text based on language and vision inputs.
122 |
123 | Args:
124 | lang_x: Language input tokens
125 | vision_x: Vision input features
126 |
127 | Returns:
128 | Generated token sequence
129 | """
130 | # Set embedding mode to text generation
131 | self.embedding_layer.flag = 'Text'
132 |
133 | with torch.no_grad():
134 | # Get embeddings from the embedding layer
135 | input_embedding, _ = self.embedding_layer(lang_x, vision_x)
136 |
137 | # Generate text using language model
138 | generation = self.lang_model.generate(
139 | inputs_embeds=input_embedding,
140 | max_new_tokens=200,
141 | top_k=50
142 | )
143 |
144 | return generation
--------------------------------------------------------------------------------
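A standalone sketch of the shift-and-reweight cross-entropy used in `forward()` above, with tiny made-up tensor sizes (vocabulary of 10 instead of 32000) so the arithmetic is easy to follow:

import torch
from torch.nn import CrossEntropyLoss

vocab = 10
logits = torch.randn(2, 6, vocab)          # [B, L, vocab]
labels = torch.randint(0, vocab, (2, 6))   # [B, L]
loss_reweight = torch.ones(2, 6)           # per-token weights

# Shift so position t predicts token t+1, then flatten.
shift_logits = logits[..., :-1, :].reshape(-1, vocab)
shift_labels = labels[..., 1:].reshape(-1)
shift_weight = loss_reweight[..., 1:].reshape(-1)

per_token = CrossEntropyLoss(reduction='none')(shift_logits, shift_labels)
loss = (shift_weight * per_token).sum() / shift_weight.sum()
print(loss)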
/src/Model/RadFM/my_embedding_layer.py:
--------------------------------------------------------------------------------
1 | # Import necessary libraries
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torch
5 | from .helpers import PerceiverResampler
6 | from .utils import get_visual_encoder
7 | from einops import rearrange, repeat
8 | from einops_exts import rearrange_many
9 | import torchvision
10 | from .vit_3d import ViT
11 | from einops.layers.torch import Rearrange
12 | from .transformer_decoder import TransformerDecoder, TransformerDecoderLayer
13 | from torch.utils.checkpoint import checkpoint
14 | from torch.autograd import Variable
15 | import random
16 | from transformers import AutoTokenizer, AutoModel
17 |
18 | class MyEmbedding(nn.Module):
19 | """
20 | Custom embedding layer for multimodal inputs that combines text and vision features.
21 | """
22 | def __init__(self, num_embeddings=32000, embedding_dim=5120, perceiver_num=32, vis_dim=768,
23 | patch_size=32, frame_patch_size=4, seg_channel=256):
24 | """
25 | Initialize the multimodal embedding layer.
26 |
27 | Args:
28 | num_embeddings (int): Size of vocabulary for text embeddings
29 | embedding_dim (int): Dimension of output embeddings
30 | perceiver_num (int): Number of latent vectors in perceiver
31 | vis_dim (int): Dimension of vision features
32 | patch_size (int): Size of image patches
33 | frame_patch_size (int): Size of 3D frame patches
34 | seg_channel (int): Number of segmentation channels
35 | """
36 | super().__init__()
37 | self.num_embeddings = num_embeddings
38 | self.embedding_dim = embedding_dim
39 | # Standard embedding weight matrix for text tokens
40 |         self.weight = nn.Parameter(torch.randn((num_embeddings, embedding_dim)))
41 | # Special token weights for figures/images
42 | self.figure_token_weight = nn.Parameter(torch.randn((2, embedding_dim)))
43 | self.flag = 'Text' # Mode flag: 'Text' or 'Seg'
44 | self.patch_size = patch_size
45 | self.frame_patch_size = frame_patch_size
46 | self.seg_channel = seg_channel
47 |
48 | ## the MedKEBERT can be downloaded from https://huggingface.co/xmcmic/Med-KEBERT/tree/main ##
49 | # Initialize medical domain BERT model for keyword understanding
50 | self.bert_tokenizer = AutoTokenizer.from_pretrained("xmcmic/Med-KEBERT")
51 | self.bert_model = AutoModel.from_pretrained("xmcmic/Med-KEBERT")
52 | # Project BERT outputs to vision feature space
53 | self.bert_projection_fc = nn.Linear(768, vis_dim)
54 |
55 | # 3D Vision Transformer for processing volumetric medical images
56 | self.vision_encoder = ViT(
57 | image_size=512, # image size
58 | frames=512, # max number of frames
59 | image_patch_size=patch_size, # image patch size
60 | frame_patch_size=frame_patch_size, # frame patch size
61 | dim=vis_dim,
62 | depth=12,
63 | heads=8,
64 | mlp_dim=2048,
65 | dropout=0.1,
66 | emb_dropout=0.1
67 | )
68 |
69 | # Upscaling layers for vision features (used in segmentation mode)
70 | self.output_upscaling = nn.Sequential(
71 | nn.ConvTranspose3d(vis_dim, vis_dim // 4, kernel_size=2, stride=2),
72 | nn.BatchNorm3d(vis_dim // 4),
73 | nn.GELU(),
74 | nn.ConvTranspose3d(vis_dim // 4, vis_dim // 8, kernel_size=2, stride=2),
75 | nn.GELU(),
76 | )
77 |
78 | # Transformer decoder for cross-attention between text and vision
79 | decoder_layer = TransformerDecoderLayer(d_model=vis_dim, nhead=8, normalize_before=True)
80 | decoder_norm = nn.LayerNorm(vis_dim)
81 | self.transformer_decoder = TransformerDecoder(decoder_layer=decoder_layer, num_layers=4, norm=decoder_norm)
82 |
83 | # MLP for processing transformer decoder outputs
84 | self.transformer_decoder_mlp = nn.Sequential(
85 | nn.Linear(vis_dim, vis_dim // 4),
86 | nn.GELU(),
87 | nn.Linear(vis_dim // 4, vis_dim // 8),
88 | nn.GELU(),
89 | )
90 | self.vis_dim = vis_dim
91 |
92 | # Perceiver resampler to reduce sequence length of vision features
93 | self.perceiver = PerceiverResampler(dim=self.vis_dim, num_latents=perceiver_num)
94 | # Final projection to embedding dimension
95 | self.fc = nn.Linear(self.vis_dim, self.embedding_dim)
96 | # Classification head for matching keywords
97 | self.cls_head = nn.Linear(self.vis_dim // 8, 1)
98 |
99 |
100 | def forward(self, text_input, vision_x, key_words_query=None):
101 | """
102 | Forward pass for the embedding layer.
103 |
104 | Args:
105 | text_input: Text token indices [B, L]
106 | vision_x: Visual input features [B, S, C, H, W, D]
107 | key_words_query: Optional list of key words for contrastive learning
108 |
109 | Returns:
110 | tuple: (output_embeddings, loss_matching)
111 | - output_embeddings: Combined embeddings for text and vision
112 | - loss_matching: Contrastive loss for keyword matching (or None)
113 | """
114 | if self.flag == 'Text':
115 | # Process in text mode
116 | B, S, C, H, W, D = vision_x.shape
117 | # Reshape for batch processing
118 | vision_x = rearrange(vision_x, "b S c h w d-> (b S) c h w d")
119 |
120 | # Process through vision encoder
121 | vision_x, pos_embedding = self.vision_encoder(vision_x)
122 |
123 | # Reshape back to batch format
124 | vision_x = rearrange(vision_x, "(b s F) v d -> b s F v d", b=B, s=S, F=1)
125 |
126 | loss_matching = None
127 |
128 | if key_words_query is not None:
129 |                 ## We do not use the following part in the final version.
130 |                 ## You can comment out the following code; if you do, the BERT models above become unused.
131 |                 # key_words_query: list[list[str]] with B sublists of key words; each word matches the corresponding vision_x embedding
132 |
133 | # Extract and deduplicate keywords
134 | query_words = [item for sublist in key_words_query for item in sublist]
135 | query_words = list(set(query_words))
136 |
137 | # Limit number of keywords to process
138 | if len(query_words) > 16:
139 | random.shuffle(query_words)
140 | query_words = query_words[0:16]
141 |
142 | if query_words != []:
143 | # Create binary labels for contrastive learning
144 | contrastive_labels = torch.zeros(B, len(query_words)) # B Q
145 | for i, sublist in enumerate(key_words_query):
146 | for j, item in enumerate(query_words):
147 | if item in sublist:
148 | contrastive_labels[i, j] = 1
149 | contrastive_labels = contrastive_labels.to(vision_x.dtype).to(vision_x.device)
150 |
151 | # Get BERT embeddings for keywords
152 | with torch.no_grad():
153 | query_words_embedding = self.bert_tokenizer(
154 | query_words,
155 | padding='max_length',
156 | truncation=True,
157 | max_length=256,
158 | return_tensors="pt"
159 | )
160 | query_words_embedding = self.bert_model(
161 | input_ids=query_words_embedding['input_ids'].to(vision_x.device),
162 | attention_mask=query_words_embedding['attention_mask'].to(vision_x.device)
163 | )['last_hidden_state'][:, 0, :].to(vision_x.dtype).to(vision_x.device) # Q,D
164 |
165 | # Project BERT embeddings to vision space
166 | query_words_embedding = self.bert_projection_fc(query_words_embedding)
167 | query_words_embedding = query_words_embedding.unsqueeze(0).repeat(B, 1, 1) # B,Q,D
168 | _, N, _ = query_words_embedding.shape
169 |
170 | # Pool vision features
171 |                     image_embedding = vision_x.mean(dim=1) # b F v d: average over the scan dimension S to merge multiple input images
172 | image_embedding = rearrange(image_embedding, "b F v d -> b (F v) d")
173 | pos_embedding = rearrange(pos_embedding, "(b s) v d -> b s v d", b=B, s=S)[:, 0, :, :]
174 |
175 | # Prepare inputs for transformer decoder
176 | image_embedding = image_embedding.transpose(0, 1) # (H/P W/P D/P) B D
177 | pos_embedding = pos_embedding.transpose(0, 1) # (H/P W/P D/P) B D
178 | query_words_embedding = query_words_embedding.transpose(0, 1) # N B D
179 |
180 | # Cross-attention between keywords and image features
181 | oo_embedding, _ = self.transformer_decoder(
182 | query_words_embedding, image_embedding, pos=pos_embedding
183 | )
184 | oo_embedding = oo_embedding.transpose(0, 1) # B Q D
185 | oo_embedding = rearrange(oo_embedding, 'b n d -> (b n) d')
186 | oo_embedding = self.transformer_decoder_mlp(oo_embedding)
187 | oo_embedding = self.cls_head(oo_embedding).mean(dim=-1)
188 | oo_embedding = rearrange(oo_embedding, '(b n) -> b n', b=B, n=N) # B Q
189 |
190 | # Calculate contrastive loss
191 | loss_matching = F.binary_cross_entropy_with_logits(oo_embedding, contrastive_labels)
192 |
193 | # Process vision features through perceiver resampler
194 | vision_x = self.perceiver(vision_x) # reshapes to (b, S, n, d)
195 |
196 | n = vision_x.shape[2]
197 |
198 | # Project vision features to embedding dimension
199 | vision_x = rearrange(vision_x, "b s n d -> (b s n) d")
200 | vision_x = self.fc(vision_x)
201 | vision_x = rearrange(vision_x, "(b T) d -> b T d", b=B, T=n*S)
202 |
203 | # Combine text and vision embeddings
204 | embedding_weight = torch.cat([self.weight, self.figure_token_weight], dim=0)
205 | embedding_weight = embedding_weight.unsqueeze(0).repeat(B, 1, 1)
206 | embedding_weight = torch.cat([embedding_weight, vision_x], dim=1)
207 |
208 | # Convert text indices to one-hot and compute final embeddings
209 | text_input = F.one_hot(text_input, embedding_weight.shape[1]).to(vision_x.dtype).to(vision_x.device)
210 | out_put = torch.matmul(text_input, embedding_weight)
211 |
212 |         ## unused for now; ignore the following code ##
213 | # if self.flag == 'Seg':
214 | # B,C,H,W,D = vision_x.shape
215 | # _,N,_ = text_input.shape
216 | # latent_embedding, pos_embedding = self.vision_encoder(vision_x) # B (H/P W/P D/P) D
217 |
218 | # image_embedding = latent_embedding.transpose(0,1) # (H/P W/P D/P) B D
219 | # pos_embedding = pos_embedding.transpose(0,1) # (H/P W/P D/P) B D
220 | # text_input = text_input.transpose(0,1) # N B D
221 |
222 | # mask_embedding,_ = self.transformer_decoder(text_input, image_embedding, pos = pos_embedding)
223 | # mask_embedding = mask_embedding.transpose(0,1) # B N D
224 | # mask_embedding = rearrange(mask_embedding, 'b n d -> (b n) d')
225 | # mask_embedding = self.transformer_decoder_mlp(mask_embedding)
226 | # mask_embedding = rearrange(mask_embedding, '(b n) d -> b n d', b=B, n=N,d = self.vis_dim // 8)
227 |
228 | # vision_x = rearrange(latent_embedding,'b (h w d) c -> b c h w d', h = (H // self.patch_size), w = (W // self.patch_size), d = (D // self.frame_patch_size), c=self.vis_dim)
229 | # vision_x = self.output_upscaling(vision_x) #B C H/4 W/4 D/4
230 | # out_put = torch.einsum('bchwd,bnc->bnhwd', vision_x, mask_embedding)
231 |
232 | return out_put, loss_matching
233 |
234 | # model = MyEmbedding(vision_encoder_path = '')
235 | # text_input = torch.randint(low=0, high=3210, size=(4,2048))
236 | # image_input = torch.randn((4,3,3,512,512,4))
237 | # key_words_query = [[],[],[],['consolidation']]
238 | # print(model(text_input, image_input, key_words_query))
--------------------------------------------------------------------------------
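For orientation, a shape-level sketch of how my_embedding_layer.py's MyEmbedding is driven. It mirrors the commented-out demo at the bottom of the file; the constructor argument, dummy shapes, and checkpoint path are illustrative assumptions, and pretrained vision/BERT weights are required before the commented call would actually run.

import torch
from Model.RadFM.my_embedding_layer import MyEmbedding  # run from src/, as train.py does

# Dummy inputs mirroring the demo: B=4 samples, up to S=3 images each,
# 3-channel 512x512 volumes of depth 4, and 2048 text-token positions.
text_input = torch.randint(low=0, high=3210, size=(4, 2048))
image_input = torch.randn((4, 3, 3, 512, 512, 4))
key_words_query = [[], [], [], ['consolidation']]  # per-sample keyword lists

# model = MyEmbedding(vision_encoder_path='path/to/pmc_clip_checkpoint.pt')  # placeholder path
# out_put, loss_matching = model(text_input, image_input, key_words_query)
# out_put: (4, 2048, D) fused token embeddings handed to the language model
# loss_matching: scalar keyword-image matching loss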
/src/Model/RadFM/position_encoding.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | """
3 | Various positional encodings for the transformer.
4 | """
5 | import math
6 | import torch
7 | from torch import nn
8 | from einops.layers.torch import Rearrange
9 | from einops import rearrange, repeat
10 |
11 | class PositionEmbeddingSine(nn.Module):
12 | """
13 | This is a more standard version of the position embedding, very similar to the one
14 | used by the Attention is all you need paper, generalized to work on images.
15 | """
16 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
17 | super().__init__()
18 | self.num_pos_feats = num_pos_feats
19 | self.temperature = temperature
20 | self.normalize = normalize
21 | if scale is not None and normalize is False:
22 | raise ValueError("normalize should be True if scale is passed")
23 | if scale is None:
24 | scale = 2 * math.pi
25 | self.scale = scale
26 |
27 | def forward(self, tensor_list):
28 | x = tensor_list.tensors
29 | mask = tensor_list.mask
30 | assert mask is not None
31 | not_mask = ~mask
32 | y_embed = not_mask.cumsum(1, dtype=torch.float32)
33 | x_embed = not_mask.cumsum(2, dtype=torch.float32)
34 | if self.normalize:
35 | eps = 1e-6
36 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
37 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
38 |
39 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
40 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
41 |
42 | pos_x = x_embed[:, :, :, None] / dim_t
43 | pos_y = y_embed[:, :, :, None] / dim_t
44 | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
45 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
46 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
47 | return pos
48 |
49 |
50 | class PositionEmbeddingLearned(nn.Module):
51 | """
52 | Absolute pos embedding, learned.
53 | """
54 | def __init__(self, num_pos_feats=256):
55 | super().__init__()
56 | self.row_embed = nn.Embedding(50, num_pos_feats)
57 | self.col_embed = nn.Embedding(50, num_pos_feats)
58 | self.reset_parameters()
59 |
60 | def reset_parameters(self):
61 | nn.init.uniform_(self.row_embed.weight)
62 | nn.init.uniform_(self.col_embed.weight)
63 |
64 | def forward(self, tensor_list):
65 | x = tensor_list.tensors
66 | h, w = x.shape[-2:]
67 | i = torch.arange(w, device=x.device)
68 | j = torch.arange(h, device=x.device)
69 | x_emb = self.col_embed(i)
70 | y_emb = self.row_embed(j)
71 | pos = torch.cat([
72 | x_emb.unsqueeze(0).repeat(h, 1, 1),
73 | y_emb.unsqueeze(1).repeat(1, w, 1),
74 | ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1)
75 | return pos
76 |
77 | class PositionEmbeddingLearned3d(nn.Module):
78 | """
79 | Absolute pos embedding, learned.
80 | """
81 | def __init__(self, num_pos_feats=256,h_patch_num = 16, w_patch_num = 16,d_patch_num = 64):
82 | super().__init__()
83 | self.h_patch_num = h_patch_num
84 | self.w_patch_num = w_patch_num
85 | self.d_patch_num = d_patch_num
86 | self.row_embed = nn.Embedding(h_patch_num, num_pos_feats)
87 | self.col_embed = nn.Embedding(w_patch_num, num_pos_feats)
88 | self.dep_embed = nn.Embedding(d_patch_num, num_pos_feats)
89 | self.reset_parameters()
90 |
91 | def reset_parameters(self):
92 | nn.init.uniform_(self.row_embed.weight)
93 | nn.init.uniform_(self.col_embed.weight)
94 | nn.init.uniform_(self.dep_embed.weight)
95 |
96 | def forward(self, B, h, w, d,x):
97 | i = (torch.arange(h, device=x.device) + 1)* (self.h_patch_num // h) -1
98 | j = (torch.arange(w, device=x.device) + 1)* (self.w_patch_num // w) -1
99 | k = (torch.arange(d, device=x.device) + 1)* (self.d_patch_num // d) -1
100 | x_emb = self.row_embed(i).unsqueeze(1).unsqueeze(2).repeat(1,w,d,1)
101 | y_emb = self.col_embed(j).unsqueeze(0).unsqueeze(2).repeat(h,1,d,1)
102 | z_emb = self.dep_embed(k).unsqueeze(0).unsqueeze(1).repeat(h,w,1,1)
103 | pos = torch.cat([x_emb,y_emb,z_emb,], dim=-1).unsqueeze(0).repeat(B, 1, 1, 1, 1)
104 | pos = rearrange(pos,'b h w d c -> b (h w d) c')
105 | return pos
106 |
107 | def build_position_encoding(args):
108 | N_steps = args.hidden_dim // 2
109 | if args.position_embedding in ('v2', 'sine'):
110 | # TODO find a better way of exposing other arguments
111 | position_embedding = PositionEmbeddingSine(N_steps, normalize=True)
112 | elif args.position_embedding in ('v3', 'learned'):
113 | position_embedding = PositionEmbeddingLearned(N_steps)
114 | else:
115 | raise ValueError(f"not supported {args.position_embedding}")
116 |
117 | return position_embedding
118 |
119 | # Pos = PositionEmbeddingLearned3d()
120 | # x = torch.randn((8,3,32,32,1))
121 | # print(Pos(8,16,16,1,x))
--------------------------------------------------------------------------------
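A small runnable sketch of PositionEmbeddingLearned3d, the variant that vit_3d.py actually uses; the grid sizes below are arbitrary and only meant to show the output layout.

import torch
from Model.RadFM.position_encoding import PositionEmbeddingLearned3d  # run from src/

pe3d = PositionEmbeddingLearned3d(num_pos_feats=256)  # defaults: a 16 x 16 x 64 patch grid
dummy = torch.empty(1)                                # forward only reads .device from this tensor
pos = pe3d(B=2, h=8, w=8, d=4, x=dummy)
print(pos.shape)  # torch.Size([2, 256, 768]): (B, h*w*d, 3*num_pos_feats)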
/src/Model/RadFM/transformer_decoder.py:
--------------------------------------------------------------------------------
1 | """
2 | Code modified from the DETR transformer:
3 | https://github.com/facebookresearch/detr
4 | Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
5 | """
6 |
7 | import copy
8 | from typing import Optional, List
9 | import pickle as cp
10 |
11 | import torch
12 | import torch.nn.functional as F
13 | from torch import nn, Tensor
14 |
15 |
16 | class TransformerDecoder(nn.Module):
17 |
18 | def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
19 | super().__init__()
20 | self.layers = _get_clones(decoder_layer, num_layers)
21 | self.num_layers = num_layers
22 | self.norm = norm
23 | self.return_intermediate = return_intermediate
24 |
25 | def forward(self, tgt, memory,
26 | tgt_mask: Optional[Tensor] = None,
27 | memory_mask: Optional[Tensor] = None,
28 | tgt_key_padding_mask: Optional[Tensor] = None,
29 | memory_key_padding_mask: Optional[Tensor] = None,
30 | pos: Optional[Tensor] = None,
31 | query_pos: Optional[Tensor] = None):
32 | output = tgt
33 | T,B,C = memory.shape
34 | intermediate = []
35 | atten_layers = []
36 | for n,layer in enumerate(self.layers):
37 |
38 | residual=True
39 | output,ws = layer(output, memory, tgt_mask=tgt_mask,
40 | memory_mask=memory_mask,
41 | tgt_key_padding_mask=tgt_key_padding_mask,
42 | memory_key_padding_mask=memory_key_padding_mask,
43 | pos=pos, query_pos=query_pos,residual=residual)
44 | atten_layers.append(ws)
45 | if self.return_intermediate:
46 | intermediate.append(self.norm(output))
47 | if self.norm is not None:
48 | output = self.norm(output)
49 | if self.return_intermediate:
50 | intermediate.pop()
51 | intermediate.append(output)
52 |
53 | if self.return_intermediate:
54 | return torch.stack(intermediate)
55 | return output,atten_layers
56 |
57 |
58 |
59 | class TransformerDecoderLayer(nn.Module):
60 |
61 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
62 | activation="relu", normalize_before=False):
63 | super().__init__()
64 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
65 | self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
66 | # Implementation of Feedforward model
67 | self.linear1 = nn.Linear(d_model, dim_feedforward)
68 | self.dropout = nn.Dropout(dropout)
69 | self.linear2 = nn.Linear(dim_feedforward, d_model)
70 |
71 | self.norm1 = nn.LayerNorm(d_model)
72 | self.norm2 = nn.LayerNorm(d_model)
73 | self.norm3 = nn.LayerNorm(d_model)
74 | self.dropout1 = nn.Dropout(dropout)
75 | self.dropout2 = nn.Dropout(dropout)
76 | self.dropout3 = nn.Dropout(dropout)
77 |
78 | self.activation = _get_activation_fn(activation)
79 | self.normalize_before = normalize_before
80 | def with_pos_embed(self, tensor, pos: Optional[Tensor]):
81 | return tensor if pos is None else tensor + pos
82 |
83 | def forward_post(self, tgt, memory,
84 | tgt_mask: Optional[Tensor] = None,
85 | memory_mask: Optional[Tensor] = None,
86 | tgt_key_padding_mask: Optional[Tensor] = None,
87 | memory_key_padding_mask: Optional[Tensor] = None,
88 | pos: Optional[Tensor] = None,
89 | query_pos: Optional[Tensor] = None,
90 | residual=True):
91 | q = k = self.with_pos_embed(tgt, query_pos)
92 | tgt2,ws = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
93 | key_padding_mask=tgt_key_padding_mask)
94 | tgt = self.norm1(tgt)  # note: unlike the original DETR layer, the self-attention output (tgt2) is not added back here
95 | tgt2,ws = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
96 | key=self.with_pos_embed(memory, pos),
97 | value=memory, attn_mask=memory_mask,
98 | key_padding_mask=memory_key_padding_mask)
99 |
100 |
101 | # attn_weights [B,NUM_Q,T]
102 | tgt = tgt + self.dropout2(tgt2)
103 | tgt = self.norm2(tgt)
104 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
105 | tgt = tgt + self.dropout3(tgt2)
106 | tgt = self.norm3(tgt)
107 | return tgt,ws
108 |
109 | def forward_pre(self, tgt, memory,
110 | tgt_mask: Optional[Tensor] = None,
111 | memory_mask: Optional[Tensor] = None,
112 | tgt_key_padding_mask: Optional[Tensor] = None,
113 | memory_key_padding_mask: Optional[Tensor] = None,
114 | pos: Optional[Tensor] = None,
115 | query_pos: Optional[Tensor] = None):
116 | tgt2 = self.norm1(tgt)
117 | q = k = self.with_pos_embed(tgt2, query_pos)
118 | tgt2,ws = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
119 | key_padding_mask=tgt_key_padding_mask)
120 | tgt = tgt + self.dropout1(tgt2)
121 | tgt2 = self.norm2(tgt)
122 | tgt2,attn_weights = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
123 | key=self.with_pos_embed(memory, pos),
124 | value=memory, attn_mask=memory_mask,
125 | key_padding_mask=memory_key_padding_mask)
126 | tgt = tgt + self.dropout2(tgt2)
127 | tgt2 = self.norm3(tgt)
128 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
129 | tgt = tgt + self.dropout3(tgt2)
130 | return tgt,attn_weights
131 |
132 | def forward(self, tgt, memory,
133 | tgt_mask: Optional[Tensor] = None,
134 | memory_mask: Optional[Tensor] = None,
135 | tgt_key_padding_mask: Optional[Tensor] = None,
136 | memory_key_padding_mask: Optional[Tensor] = None,
137 | pos: Optional[Tensor] = None,
138 | query_pos: Optional[Tensor] = None,
139 | residual=True):
140 | if self.normalize_before:
141 | return self.forward_pre(tgt, memory, tgt_mask, memory_mask,
142 | tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
143 | return self.forward_post(tgt, memory, tgt_mask, memory_mask,
144 | tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos,residual)
145 |
146 |
147 | def _get_clones(module, N):
148 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
149 |
150 |
151 |
152 | def _get_activation_fn(activation):
153 | """Return an activation function given a string"""
154 | if activation == "relu":
155 | return F.relu
156 | if activation == "gelu":
157 | return F.gelu
158 | if activation == "glu":
159 | return F.glu
160 | raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
161 |
--------------------------------------------------------------------------------
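Since transformer_decoder.py ships no demo, here is a minimal sketch of how the decoder is typically wired (keyword queries attending over image patch features, as in my_embedding_layer.py); the dimensions are arbitrary.

import torch
from torch import nn
from Model.RadFM.transformer_decoder import TransformerDecoder, TransformerDecoderLayer  # run from src/

d_model, nhead = 768, 8
layer = TransformerDecoderLayer(d_model, nhead)
decoder = TransformerDecoder(layer, num_layers=4, norm=nn.LayerNorm(d_model))

queries = torch.randn(32, 2, d_model)   # (num_queries, batch, dim), e.g. keyword embeddings
memory  = torch.randn(256, 2, d_model)  # (num_patches, batch, dim), e.g. ViT patch features
pos     = torch.randn(256, 2, d_model)  # positional encoding added to the memory keys

out, attn_per_layer = decoder(queries, memory, pos=pos)
print(out.shape)             # torch.Size([32, 2, 768])
print(len(attn_per_layer))   # 4: one cross-attention map per layer, each (batch, num_queries, num_patches)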
/src/Model/RadFM/utils.py:
--------------------------------------------------------------------------------
1 | from .blocks import ModifiedResNet,PMC_CLIP_cfg
2 | import torch
3 | from torchvision import transforms
4 | from PIL import Image
5 | import torch.nn as nn
6 | def extend_instance(obj, mixin):
7 | """Apply mixins to a class instance after creation"""
8 | base_cls = obj.__class__
9 | base_cls_name = obj.__class__.__name__
10 | obj.__class__ = type(
11 | base_cls_name, (mixin, base_cls), {}
12 | ) # mixin needs to go first for our forward() logic to work
13 |
14 |
15 | def getattr_recursive(obj, att):
16 | """
17 | Return nested attribute of obj
18 | Example: getattr_recursive(obj, 'a.b.c') is equivalent to obj.a.b.c
19 | """
20 | if att == "":
21 | return obj
22 | i = att.find(".")
23 | if i < 0:
24 | return getattr(obj, att)
25 | else:
26 | return getattr_recursive(getattr(obj, att[:i]), att[i + 1 :])
27 |
28 |
29 | def setattr_recursive(obj, att, val):
30 | """
31 | Set nested attribute of obj
32 | Example: setattr_recursive(obj, 'a.b.c', val) is equivalent to obj.a.b.c = val
33 | """
34 | if "." in att:
35 | obj = getattr_recursive(obj, ".".join(att.split(".")[:-1]))
36 | setattr(obj, att.split(".")[-1], val)
37 |
38 |
39 |
40 | def get_visual_encoder(model_str):
41 | """
42 | Args:
43 | model_str (str): path to a pretrained visual-encoder checkpoint (only 'PMC-CLIP' checkpoints are handled).
44 | Returns:
45 | vision_model, visual_dim, img_preprocessor
46 | """
47 | normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
48 | img_preprocessor = transforms.Compose([
49 | transforms.Resize((512,512), interpolation=Image.BICUBIC),
50 | transforms.ToTensor(),
51 | normalize,
52 | ])
53 | if 'PMC-CLIP' in model_str:
54 | #vision_cfg = json.load(open(model_args.visual_model_config,'r'))['vision_cfg']
55 | vision_cfg = PMC_CLIP_cfg()
56 | vision_heads = vision_cfg.width * 32 // vision_cfg.head_width
57 | vision_model = ModifiedResNet(
58 | layers=vision_cfg.layers,
59 | heads=vision_heads,
60 | output_dim = 768,
61 | image_size=vision_cfg.image_size,
62 | width=vision_cfg.width
63 | )
64 | vision_model = vision_load_pretrain(vision_model,model_str)
65 | vision_model = nn.Sequential(*list(vision_model.children())[:-2])
66 | visual_dim = 1024
67 | return vision_model,visual_dim,img_preprocessor
68 |
69 | def vision_load_pretrain(resnet,model_path):
70 | checkpoint = torch.load(model_path, map_location='cpu')
71 | state_dict = checkpoint['state_dict']
72 | state_dict = {k.replace('module.visual.',''): v for k, v in state_dict.items() if '.visual' in k}
73 | resnet.load_state_dict(state_dict)
74 | return resnet
75 |
--------------------------------------------------------------------------------
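A hedged usage sketch for utils.py: the checkpoint path is a placeholder for a locally downloaded PMC-CLIP checkpoint, and get_visual_encoder only handles model strings containing 'PMC-CLIP' (it falls through and returns None otherwise).

from Model.RadFM.utils import get_visual_encoder  # run from src/

vision_model, visual_dim, img_preprocessor = get_visual_encoder('path/to/PMC-CLIP/checkpoint.pt')  # placeholder path
print(visual_dim)  # 1024
# img_preprocessor resizes PIL images to 512x512 and normalizes with ImageNet statistics.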
/src/Model/RadFM/vit_3d.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | from einops import rearrange, repeat
5 | from einops.layers.torch import Rearrange
6 | from .position_encoding import PositionEmbeddingLearned3d
7 |
8 | # helpers
9 |
10 | def pair(t):
11 | return t if isinstance(t, tuple) else (t, t)
12 |
13 | # classes
14 |
15 | class PreNorm(nn.Module):
16 | def __init__(self, dim, fn):
17 | super().__init__()
18 | self.norm = nn.LayerNorm(dim)
19 | self.fn = fn
20 | def forward(self, x, **kwargs):
21 | return self.fn(self.norm(x), **kwargs)
22 |
23 | class FeedForward(nn.Module):
24 | def __init__(self, dim, hidden_dim, dropout = 0.):
25 | super().__init__()
26 | self.net = nn.Sequential(
27 | nn.Linear(dim, hidden_dim),
28 | nn.GELU(),
29 | nn.Dropout(dropout),
30 | nn.Linear(hidden_dim, dim),
31 | nn.Dropout(dropout)
32 | )
33 | def forward(self, x):
34 | return self.net(x)
35 |
36 | class Attention(nn.Module):
37 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
38 | super().__init__()
39 | inner_dim = dim_head * heads
40 | project_out = not (heads == 1 and dim_head == dim)
41 |
42 | self.heads = heads
43 | self.scale = dim_head ** -0.5
44 |
45 | self.attend = nn.Softmax(dim = -1)
46 | self.dropout = nn.Dropout(dropout)
47 |
48 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
49 |
50 | self.to_out = nn.Sequential(
51 | nn.Linear(inner_dim, dim),
52 | nn.Dropout(dropout)
53 | ) if project_out else nn.Identity()
54 |
55 | def forward(self, x):
56 | qkv = self.to_qkv(x).chunk(3, dim = -1)
57 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
58 |
59 | dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
60 |
61 | attn = self.attend(dots)
62 | attn = self.dropout(attn)
63 |
64 | out = torch.matmul(attn, v)
65 | out = rearrange(out, 'b h n d -> b n (h d)')
66 | return self.to_out(out)
67 |
68 | class Transformer(nn.Module):
69 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
70 | super().__init__()
71 | self.layers = nn.ModuleList([])
72 | for _ in range(depth):
73 | self.layers.append(nn.ModuleList([
74 | PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
75 | PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
76 | ]))
77 | def forward(self, x):
78 | for attn, ff in self.layers:
79 | x = attn(x) + x
80 | x = ff(x) + x
81 | return x
82 |
83 | class ViT(nn.Module):
84 | def __init__(self, *, image_size, image_patch_size, frames, frame_patch_size, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
85 | super().__init__()
86 | image_height, image_width = pair(image_size)
87 | patch_height, patch_width = pair(image_patch_size)
88 |
89 | assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
90 | assert frames % frame_patch_size == 0, 'Frames must be divisible by frame patch size'
91 |
92 | self.patch_height = patch_height
93 | self.patch_width = patch_width
94 | self.frame_patch_size = frame_patch_size
95 |
96 | num_patches = (image_height // patch_height) * (image_width // patch_width) * (frames // frame_patch_size)
97 | patch_dim = channels * patch_height * patch_width * frame_patch_size
98 |
99 | assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
100 |
101 | self.to_patch_embedding = nn.Sequential(
102 | Rearrange('b c (h p1) (w p2) (f pf) -> b (h w f) (p1 p2 pf c)', p1 = patch_height, p2 = patch_width, pf = frame_patch_size),
103 | nn.LayerNorm(patch_dim),
104 | nn.Linear(patch_dim, dim),
105 | nn.LayerNorm(dim),
106 | )
107 |
108 | self.pos_embedding = PositionEmbeddingLearned3d(dim // 3,(image_height // patch_height), (image_width // patch_width), (frames // frame_patch_size))
109 | self.dropout = nn.Dropout(emb_dropout)
110 |
111 | self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
112 |
113 | def forward(self, video):
114 | B, C, H, W, D = video.shape
115 | x = self.to_patch_embedding(video)
116 | b, n, _ = x.shape
117 |
118 | pos = self.pos_embedding(B, H // self.patch_height, W // self.patch_width, D // self.frame_patch_size,x)
119 | x += pos
120 | x = self.dropout(x)
121 |
122 | x = self.transformer(x)
123 | return x,pos
124 |
--------------------------------------------------------------------------------
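A minimal sketch of the 3D ViT in vit_3d.py on a single 4-slice volume. The hyperparameters are illustrative, not necessarily the values RadFM was trained with; note that dim must be divisible by 3 because the learned 3D positional embedding concatenates one embedding per axis.

import torch
from Model.RadFM.vit_3d import ViT  # run from src/

vit = ViT(
    image_size=512, image_patch_size=32,   # 16 x 16 spatial patches
    frames=4, frame_patch_size=4,          # depth 4 collapses into one "frame" patch
    dim=768, depth=12, heads=8, mlp_dim=2048,
)
video = torch.randn(1, 3, 512, 512, 4)     # (B, C, H, W, D)
tokens, pos = vit(video)
print(tokens.shape, pos.shape)             # both torch.Size([1, 256, 768])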
/src/My_Trainer/__pycache__/trainer.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chaoyi-wu/RadFM/9efba84cced9ffdcbce00d6005f255414ffa8c36/src/My_Trainer/__pycache__/trainer.cpython-39.pyc
--------------------------------------------------------------------------------
/src/datasampler.py:
--------------------------------------------------------------------------------
1 | import torch.distributed as dist
2 | import math
3 | from torch.utils.data.sampler import Sampler
4 |
5 | from torch.utils.data import DataLoader, DistributedSampler
6 | import random
7 | import torch
8 | from Dataset.multi_dataset import multi_dataset
9 |
10 | def make_batch(index_list, batch_size, drop_last):
11 | if drop_last:
12 | batches = []
13 | whole_batch_num = len(index_list)//batch_size
14 | for _ in range(whole_batch_num):
15 | batches.append(index_list[batch_size*_:(batch_size*(_+1))])
16 | else:
17 | batches = []
18 | whole_batch_num = math.ceil(len(index_list)/batch_size)
19 | for _ in range(whole_batch_num):
20 | batches.append(index_list[batch_size*_:(batch_size*(_+1))])
21 | return batches
22 |
23 | def batch_generation(dataset,batch_size_2D, batch_size_3D,drop_last=False,shuffle = True, seed = 0):
24 |
25 | len_2D = len(dataset.data_whole_2D)
26 | len_3D = len(dataset.data_whole_3D)
27 | index_2D = list(range(len_2D))
28 | index_3D = list(range(len_2D,(len_2D+len_3D)))
29 | assert len(index_2D) + len(index_3D) == len(dataset.data_whole)
30 |
31 | if shuffle:
32 | # deterministically shuffle based on epoch and seed
33 | # (use a seeded local RNG: the module-level random.shuffle would ignore `seed`)
34 | g = random.Random(seed)
35 | g.shuffle(index_2D)
36 | g.shuffle(index_3D)
37 |
38 | batch_2D = make_batch(index_2D, batch_size_2D, drop_last)
39 | batch_3D = make_batch(index_3D, batch_size_3D, drop_last)
40 |
41 | batch_chunk = batch_2D + batch_3D
42 | return batch_chunk
43 |
44 | class My_DistributedBatchSampler(Sampler):
45 | """ Iterable wrapper that distributes data across multiple workers.
46 |
47 | Args:
48 | iterable (iterable)
49 | num_replicas (int, optional): Number of processes participating in distributed training.
50 | rank (int, optional): Rank of the current process within ``num_replicas``.
51 |
52 | Example:
53 | >>> list(DistributedSampler(range(10), num_replicas=2, rank=0))
54 | [0, 2, 4, 6, 8]
55 | >>> list(DistributedSampler(range(10), num_replicas=2, rank=1))
56 | [1, 3, 5, 7, 9]
57 | """
58 |
59 | def __init__(self, dataset, num_replicas=None, rank=None, batch_size_2D = 4, batch_size_3D = 1, drop_last = False, shuffle = True, seed: int = 0):
60 | self.num_replicas = num_replicas
61 | self.rank = rank
62 | self.drop_last = drop_last
63 | self.shuffle = shuffle
64 | self.dataset = dataset
65 | self.batch_size_2D = batch_size_2D
66 | self.batch_size_3D = batch_size_3D
67 | self.seed = seed
68 | self.epoch = 0
69 |
70 | if num_replicas is None or rank is None: # pragma: no cover
71 | if not torch.distributed.is_initialized():
72 | raise RuntimeError('Requires `torch.distributed` to be initialized.')
73 |
74 | self.num_replicas = (
75 | torch.distributed.get_world_size() if num_replicas is None else num_replicas)
76 | self.rank = torch.distributed.get_rank() if rank is None else rank
77 |
78 | indices = batch_generation(self.dataset,self.batch_size_2D,self.batch_size_3D,self.drop_last,self.shuffle)
79 | if self.rank >= self.num_replicas:
80 | raise IndexError('`rank` must be smaller than the `num_replicas`.')
81 |
82 | if self.drop_last and len(indices) % self.num_replicas != 0: # type: ignore[arg-type]
83 | # Split to nearest available length that is evenly divisible.
84 | # This is to ensure each rank receives the same amount of data when
85 | # using this Sampler.
86 | self.num_samples = math.ceil(
87 | (len(indices) - self.num_replicas) / self.num_replicas # type: ignore[arg-type]
88 | )
89 | else:
90 | self.num_samples = math.ceil(len(indices) / self.num_replicas) # type: ignore[arg-type]
91 | self.total_size = self.num_samples * self.num_replicas
92 |
93 | def __iter__(self):
94 | indices = batch_generation(self.dataset,self.batch_size_2D,self.batch_size_3D,self.drop_last,self.shuffle,self.seed + self.epoch)
95 | # print(indices)
96 | if self.shuffle:
97 | # deterministically shuffle based on epoch and seed
98 | g = random.Random(self.seed + self.epoch)
99 | # every rank seeds the same local RNG so all replicas see the same batch order
100 | g.shuffle(indices)
101 |
102 | if not self.drop_last:
103 | # add extra samples to make it evenly divisible
104 | padding_size = self.total_size - len(indices)
105 | if padding_size <= len(indices):
106 | indices += indices[:padding_size]
107 | else:
108 | indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
109 | else:
110 | # remove tail of data to make it evenly divisible.
111 | indices = indices[:self.total_size]
112 | assert len(indices) == self.total_size
113 |
114 | # subsample
115 | indices = indices[self.rank:self.total_size:self.num_replicas]
116 | assert len(indices) == self.num_samples
117 |
118 | return iter(indices)
119 |
120 | def __len__(self):
121 | return self.num_samples
122 |
123 | def set_epoch(self, epoch: int) -> None:
124 | r"""
125 | Set the epoch for this sampler.
126 |
127 | When :attr:`shuffle=True`, this ensures all replicas
128 | use a different random ordering for each epoch. Otherwise, the next iteration of this
129 | sampler will yield the same ordering.
130 |
131 | Args:
132 | epoch (int): Epoch number.
133 | """
134 | self.epoch = epoch
135 |
136 |
137 | # print(My_DistributedBatchSampler)
138 | # Train_dataset = multi_dataset(text_tokenizer = '/mnt/petrelfs/share_data/zhangxiaoman/CODE/RadFM/src/Language_models/tokenizer')
139 |
140 | # DDP_sample_0 = list(My_DistributedBatchSampler(dataset= Train_dataset , num_replicas = 32, rank = 0,))
141 | # DDP_sample_1 = list(My_DistributedBatchSampler(dataset= Train_dataset , num_replicas = 32, rank = 1,))
142 |
143 | # for ii in DDP_sample_0:
144 | # print(ii)
145 |
146 | # for ii in DDP_sample_1:
147 | # print(ii)
148 |
149 |
--------------------------------------------------------------------------------
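A toy check of the batching helper in datasampler.py (pure Python, no dataset or distributed setup needed); it shows how make_batch chunks an index list, which batch_generation then applies separately to the 2D and 3D index ranges so a batch never mixes the two.

from datasampler import make_batch  # run from src/

print(make_batch(list(range(10)), batch_size=4, drop_last=False))
# [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
print(make_batch(list(range(10)), batch_size=4, drop_last=True))
# [[0, 1, 2, 3], [4, 5, 6, 7]]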
/src/test.py:
--------------------------------------------------------------------------------
1 | # Import necessary libraries for data processing, modeling, and utilities
2 | import tqdm.auto as tqdm
3 | import torch.nn.functional as F
4 | from typing import Optional, Dict, Sequence
5 | from typing import List, Optional, Tuple, Union
6 | import transformers
7 | from My_Trainer.trainer import Trainer
8 | from dataclasses import dataclass, field
9 | from Dataset.multi_dataset_test import multi_dataset
10 | from Model.RadFM.multimodality_model import MultiLLaMAForCausalLM
11 | from datasampler import My_DistributedBatchSampler
12 | import torch
13 | from torch.utils.data import DataLoader
14 | import csv
15 | import random
16 | import numpy as np
17 |
18 | def setup_seed(seed):
19 | """
20 | Set random seeds for reproducibility across different libraries
21 |
22 | Args:
23 | seed: Integer seed value to use
24 | """
25 | torch.manual_seed(seed)
26 | torch.cuda.manual_seed_all(seed)
27 | np.random.seed(seed)
28 | random.seed(seed)
29 | torch.backends.cudnn.deterministic = True
30 |
31 | # Set seed for reproducibility
32 | setup_seed(20)
33 |
34 |
35 | @dataclass
36 | class ModelArguments:
37 | """
38 | Arguments related to model paths and configuration
39 | """
40 | lang_encoder_path: Optional[str] = field(default="/home/cs/leijiayu/wuchaoyi/book_pretrain/Results/Book_mix_2048_13B_full/checkpoint-45800")
41 | tokenizer_path: str = field(default='/home/cs/leijiayu/wuchaoyi/Finetune_LLAMA/LLAMA_Model/tokenizer', metadata={"help": "Path to the tokenizer data."})
42 | #vision_encoder_path: str = field(default='/home/cs/leijiayu/wuchaoyi/multi_modal/src/PMC-CLIP/checkpoint.pt', metadata={"help": "Path to the vision_encoder."})
43 |
44 |
45 | @dataclass
46 | class DataArguments:
47 | """
48 | Arguments related to dataset configuration and testing modes
49 | """
50 | Mode: Optional[str] = field(default="Train")
51 | test_split: Optional[str] = field(default="open")
52 |
53 | @dataclass
54 | class TrainingArguments(transformers.TrainingArguments):
55 | """
56 | Custom training arguments extending HuggingFace's TrainingArguments
57 | with additional parameters for multimodal training
58 | """
59 | remove_unused_columns: bool = field(default = False)
60 | batch_size_2D: int = field(default = 4) # Batch size for 2D data
61 | batch_size_3D: int = field(default = 1) # Batch size for 3D data
62 | output_dir: Optional[str] = field(default="/home/cs/leijiayu/wuchaoyi/multi_modal/src/Results/BLIP_overfit/")
63 | cache_dir: Optional[str] = field(default=None)
64 | optim: str = field(default="adamw_torch")
65 |
66 |
67 | @dataclass
68 | class DataCollator(object):
69 | """
70 | Data collator for preparing batches of multimodal inputs for the model
71 | """
72 |
73 | def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
74 | # Extract different components from the input instances
75 | vision_xs, lang_xs, attention_masks, labels = tuple(
76 | [instance[key] for instance in instances]
77 | for key in ('vision_x','lang_x', 'attention_mask', 'labels')
78 | )
79 |
80 | # Stack language tensors along batch dimension
81 | lang_xs = torch.cat([_.unsqueeze(0) for _ in lang_xs], dim=0)
82 | attention_masks = torch.cat([_.unsqueeze(0) for _ in attention_masks], dim=0)
83 | labels = torch.cat([_.unsqueeze(0) for _ in labels], dim=0)
84 |
85 | # Set target dimensions for resizing vision inputs
86 | target_H = 512
87 | target_W = 512
88 | target_D = 4
89 | MAX_D = 0
90 |
91 | # Reduce resolution for single samples to save memory
92 | if len(vision_xs) == 1:
93 | target_H = 256
94 | target_W = 256
95 |
96 | # Define possible depth values for 3D data
97 | D_list = list(range(4,65,4))
98 | # Adjust depth values for large inputs
99 | if len(vision_xs) == 1:
100 | if vision_xs[0].shape[0] > 6:
101 | D_list = list(range(4,33,4))
102 |
103 | # Find maximum depth in current batch
104 | for ii in vision_xs:
105 | try:
106 | D = ii.shape[-1]
107 | if D > MAX_D:
108 | MAX_D = D
109 | except Exception:
110 | continue
111 |
112 | # Select closest target depth from available options
113 | for temp_D in D_list:
114 | if abs(temp_D - MAX_D) < abs(target_D - MAX_D):
115 | target_D = temp_D
116 |
117 | # Resize all vision inputs to target dimensions
118 | vision_xs = [torch.nn.functional.interpolate(s, size=(target_H, target_W, target_D)) for s in vision_xs]
119 |
120 | # Pad sequence for variable-length vision inputs
121 | vision_xs = torch.nn.utils.rnn.pad_sequence(
122 | vision_xs, batch_first=True, padding_value=0
123 | )
124 | print(vision_xs.shape)
125 |
126 | # Return collated batch
127 | return dict(
128 | lang_x=lang_xs,
129 | vision_x=vision_xs,
130 | attention_mask=attention_masks,
131 | labels=labels,
132 | )
133 |
134 | def main():
135 | """
136 | Main function to set up and run the inference process
137 | """
138 | # Parse command-line arguments
139 | parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
140 | model_args, data_args, training_args = parser.parse_args_into_dataclasses()
141 |
142 | # Set custom data sampler
143 | training_args.data_sampler = My_DistributedBatchSampler
144 |
145 | print("Setup Data")
146 | # Initialize test dataset with specified split
147 | Test_dataset = multi_dataset(text_tokenizer=model_args.tokenizer_path, test_split=data_args.test_split)
148 |
149 | # Configure DataLoader for test dataset
150 | Test_dataloader = DataLoader(
151 | Test_dataset,
152 | batch_size=1,
153 | num_workers=1,
154 | pin_memory=True,
155 | sampler=None,
156 | shuffle=True,
157 | collate_fn=None,
158 | drop_last=False,
159 | )
160 |
161 | print("Setup Model")
162 | # Initialize the multimodal model
163 | model = MultiLLaMAForCausalLM(
164 | lang_model_path=model_args.lang_encoder_path,
165 | )
166 |
167 | # Load pre-trained model checkpoint
168 | ckpt = torch.load('/gpfs/home/cs/leijiayu/wuchaoyi/wangyingjie/src/Results/backup/checkpoint-17600/pytorch_model.bin', map_location='cpu')
169 | # ckpt.pop('embedding_layer.figure_token_weight')
170 | model.load_state_dict(ckpt, strict=False)
171 | model = model.to('cuda')
172 | model.eval() # Set model to evaluation mode
173 |
174 | # Create output CSV file for results
175 | with open('output_whole_2_epoch' + data_args.test_split + '.csv', mode='w') as outfile:
176 | writer = csv.writer(outfile)
177 | writer.writerow(["Question", "Ground Truth", "Pred", 'belong_to'])
178 | cc = 0
179 |
180 | # Process each sample in the test dataset
181 | for sample in tqdm.tqdm(Test_dataloader):
182 | question = sample["question"]
183 | belong_to = sample['belong_to']
184 | # img_pp = sample['img_path']
185 |
186 | # Tokenize the question text
187 | lang_x = Test_dataset.text_tokenizer(
188 | question, max_length=2048, truncation=True, return_tensors="pt"
189 | )['input_ids'].to('cuda')
190 |
191 | # Get vision input
192 | vision_x = sample["vision_x"].to('cuda')
193 | answer = sample['answer']
194 |
195 | try:
196 | # Generate text based on text and vision inputs
197 | generation = model.generate(lang_x, vision_x)
198 | generated_texts = Test_dataset.text_tokenizer.batch_decode(generation, skip_special_tokens=True)
199 |
200 | # Write results to CSV
201 | writer.writerow([question, answer, generated_texts, belong_to])
202 | cc = cc + 1
203 | # if cc >= 10000:
204 | # break
205 | except Exception:  # skip samples whose generation fails and continue with the rest
206 | continue
207 |
208 | if __name__ == "__main__":
209 | main()
--------------------------------------------------------------------------------
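For reference, test.py's arguments can also be supplied programmatically. The sketch below parses the same dataclasses with an explicit argument list; it assumes the three dataclasses from test.py are in scope (e.g. the snippet is pasted into the file or the module is imported), all paths are placeholders, and HfArgumentParser falls back to the dataclass defaults for anything omitted.

import transformers
# assumes ModelArguments, DataArguments, TrainingArguments from test.py are in scope

parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses(args=[
    "--tokenizer_path", "./Quick_demo/Language_files",    # placeholder paths
    "--lang_encoder_path", "./Quick_demo/Language_files",
    "--test_split", "close",
    "--output_dir", "./Results",
])
print(data_args.test_split, training_args.batch_size_2D)  # "close" and 4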
/src/train.py:
--------------------------------------------------------------------------------
1 | # Import necessary libraries
2 | import tqdm.auto as tqdm
3 | import torch.nn.functional as F
4 | from typing import Optional, Dict, Sequence
5 | from typing import List, Optional, Tuple, Union
6 | import transformers
7 | from My_Trainer.trainer import Trainer
8 | from dataclasses import dataclass, field
9 | from Dataset.multi_dataset import multi_dataset
10 | from Model.RadFM.multimodality_model import MultiLLaMAForCausalLM
11 | from datasampler import My_DistributedBatchSampler
12 | # from datasets import load_metric  # unused: only referenced by the commented-out metric code below
13 | from Dataset.multi_dataset_test_for_close import multi_dataset_close
14 | import numpy as np
15 | import torch
16 |
17 |
18 | def compute_metrics(eval_preds):
19 | """
20 | Compute evaluation metrics from prediction outputs.
21 | Returns the mean accuracy across all predictions.
22 |
23 | Args:
24 | eval_preds: Prediction outputs from the model
25 |
26 | Returns:
27 | Dictionary containing accuracy metric
28 | """
29 | # metric = load_metric("glue", "mrpc")
30 | ACCs = eval_preds.predictions
31 | # print(ACCs)
32 | return {"accuracy": np.mean(ACCs, axis=-1)}
33 |
34 | @dataclass
35 | class ModelArguments:
36 | """
37 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune.
38 | """
39 | lang_encoder_path: Optional[str] = field(default="/home/cs/leijiayu/wuchaoyi/book_pretrain/Results/Book_mix_2048_13B_full/checkpoint-45800")
40 | tokenizer_path: str = field(default='/home/cs/leijiayu/wuchaoyi/Finetune_LLAMA/LLAMA_Model/tokenizer',
41 | metadata={"help": "Path to the tokenizer data."})
42 |
43 |
44 |
45 | @dataclass
46 | class DataArguments:
47 | """
48 | Arguments pertaining to data processing mode.
49 | """
50 | Mode: Optional[str] = field(default="Train")
51 |
52 | @dataclass
53 | class TrainingArguments(transformers.TrainingArguments):
54 | """
55 | Custom training arguments extending the HuggingFace TrainingArguments class.
56 | Includes additional parameters specific to this multimodal training setup.
57 | """
58 | remove_unused_columns: bool = field(default=False)
59 | batch_size_2D: int = field(default=4) # Batch size for 2D data
60 | batch_size_3D: int = field(default=1) # Batch size for 3D data
61 | output_dir: Optional[str] = field(default="/home/cs/leijiayu/wuchaoyi/multi_modal/src/Results/BLIP_overfit/")
62 | cache_dir: Optional[str] = field(default=None)
63 | optim: str = field(default="adamw_torch")
64 |
65 |
66 | @dataclass
67 | class DataCollator(object):
68 | """
69 | Data collator that handles batching of multimodal inputs.
70 | Processes vision and language inputs, handles padding, and resizes vision inputs.
71 | """
72 |
73 | def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
74 | # Extract different data components from instances
75 | vision_xs, lang_xs, attention_masks, labels, loss_reweight, key_words_query = tuple(
76 | [instance[key] for instance in instances]
77 | for key in ('vision_x', 'lang_x', 'attention_mask', 'labels', 'loss_reweight', 'key_words_query')
78 | )
79 |
80 | # Stack language tensors along batch dimension
81 | lang_xs = torch.cat([_.unsqueeze(0) for _ in lang_xs], dim=0)
82 | attention_masks = torch.cat([_.unsqueeze(0) for _ in attention_masks], dim=0)
83 | labels = torch.cat([_.unsqueeze(0) for _ in labels], dim=0)
84 | loss_reweight = torch.cat([_.unsqueeze(0) for _ in loss_reweight], dim=0)
85 |
86 | # Set target dimensions for vision input resizing
87 | target_H = 512
88 | target_W = 512
89 | target_D = 4
90 | MAX_D = 0
91 |
92 | # Define possible depth values for 3D data
93 | D_list = list(range(4, 65, 4))
94 | # Adjust depth range for larger inputs
95 | if len(vision_xs) == 1:
96 | if vision_xs[0].shape[0] > 6:
97 | D_list = list(range(4, 33, 4))
98 |
99 | # Find maximum depth in current batch
100 | for ii in vision_xs:
101 | try:
102 | D = ii.shape[-1]
103 | if D > MAX_D:
104 | MAX_D = D
105 | except Exception:
106 | continue
107 |
108 | # Select closest target depth from available options
109 | for temp_D in D_list:
110 | if abs(temp_D - MAX_D) < abs(target_D - MAX_D):
111 | target_D = temp_D
112 |
113 | # Reduce image dimensions for larger depth inputs with small batch size
114 | if len(vision_xs) == 1 and target_D > 4:
115 | target_H = 256
116 | target_W = 256
117 |
118 | # Resize all vision inputs to target dimensions
119 | vision_xs = [torch.nn.functional.interpolate(s, size=(target_H, target_W, target_D)) for s in vision_xs]
120 |
121 | # Pad sequence for variable-length vision inputs
122 | vision_xs = torch.nn.utils.rnn.pad_sequence(
123 | vision_xs, batch_first=True, padding_value=0
124 | )
125 | print(vision_xs.shape, vision_xs.dtype)
126 |
127 | # Return collated batch
128 | return dict(
129 | lang_x=lang_xs,
130 | vision_x=vision_xs,
131 | attention_mask=attention_masks,
132 | labels=labels,
133 | loss_reweight=loss_reweight,
134 | key_words_query=key_words_query
135 | )
136 |
137 | def main():
138 | """
139 | Main function to set up and run the training process.
140 | Parses arguments, initializes datasets, model, and trainer.
141 | """
142 | # Parse command-line arguments
143 | parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
144 | model_args, data_args, training_args = parser.parse_args_into_dataclasses()
145 |
146 | # Set custom data sampler
147 | training_args.data_sampler = My_DistributedBatchSampler
148 |
149 | print("Setup Data")
150 | # Initialize training and evaluation datasets
151 | Train_dataset = multi_dataset(text_tokenizer=model_args.tokenizer_path)
152 | Eval_dataset = multi_dataset_close(text_tokenizer=model_args.tokenizer_path)
153 |
154 | print("Setup Model")
155 | # Initialize the multimodal model
156 | model = MultiLLaMAForCausalLM(
157 | lang_model_path=model_args.lang_encoder_path,
158 | )
159 |
160 | # Setup trainer with model, datasets, and configurations
161 | trainer = Trainer(
162 | model=model,
163 | train_dataset=Train_dataset,
164 | eval_dataset=Eval_dataset,
165 | args=training_args,
166 | data_collator=DataCollator(),
167 | compute_metrics=compute_metrics
168 | )
169 |
170 | # Start training
171 | trainer.train()
172 | # Save training state
173 | trainer.save_state()
174 |
175 | if __name__ == "__main__":
176 | main()
--------------------------------------------------------------------------------
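To see what train.py's collator actually produces, here is a toy call with a single fake instance; it assumes the DataCollator class from train.py is in scope, and the tensor shapes are arbitrary, chosen only to exercise the resize/pad logic (depth 19 gets rounded to the nearest multiple of 4, and a single-sample batch with depth > 4 is downscaled to 256x256).

import torch
# assumes DataCollator from train.py is in scope

collator = DataCollator()
fake_instance = dict(
    vision_x=torch.randn(1, 3, 300, 300, 19),   # one image: (S, C, H, W, D)
    lang_x=torch.randint(0, 100, (2048,)),
    attention_mask=torch.ones(2048),
    labels=torch.randint(0, 100, (2048,)),
    loss_reweight=torch.ones(2048),
    key_words_query=[],
)
batch = collator([fake_instance])
print(batch['vision_x'].shape)   # torch.Size([1, 1, 3, 256, 256, 20])
print(batch['lang_x'].shape)     # torch.Size([1, 2048])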