├── main.png
├── test_img.jpg
├── test.bash
├── ReadMe.md
├── test_extract_id.py
├── opts.py
└── st_adapter_DyKnow.py
/main.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tianyuan168326/All-in-One-MedReID-Pytorch/HEAD/main.png
--------------------------------------------------------------------------------
/test_img.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tianyuan168326/All-in-One-MedReID-Pytorch/HEAD/test_img.jpg
--------------------------------------------------------------------------------
/test.bash:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python -m torch.distributed.launch --nproc_per_node 1 --master_port=28349 test_extract_id.py --lr 5e-6 \
2 | --epochs 20 --optim adamw --save_every_n_epochs 3 --wd 5e-2 --backbone clip_Tadapt --clip_grad -1 --online_compute_id \
3 | --batch_size 6 --text_use_method identity_CL --mimic_all --lora_rank 16 --margin -1 --ft_all --relation_type difference --start_epoch 0 \
4 | --exp_dir ./ \
5 | --resume_main MAMI_pretrained.pth
6 |
--------------------------------------------------------------------------------
/ReadMe.md:
--------------------------------------------------------------------------------
1 | # All-in-One Medical Image Re-Identification (CVPR2025)
2 |
3 | MaMI is the first unified re-identification model for medical images, capable of handling various imaging modalities such as X-ray, CT, fundus, and pathology images. By leveraging a Continuous Modality-based Parameter Adapter (ComPA) and integrating medical priors, MaMI supports both historical-data-assisted diagnosis and privacy-protection applications.
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 | ## Features
12 |
13 | - **Unified Multi-Modality Model**
14 | A single model supports multiple medical image modalities without having to train separate models for each modality.
15 |
16 | - **Continuous Modality Adaptive Parameterization**
17 | The ComPA module generates a continuous modality representation from the input image and uses it to dynamically adapt the model parameters (see the sketch after this list).
18 |
19 | - **Medical Priors Integration**
20 | Incorporates pre-trained Medical Foundation Models (MFMs) to enhance feature discrimination, capturing subtle identity-related cues for more robust re-identification.
21 |
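Below is a minimal, illustrative sketch of this idea, not the released implementation (the real ComPA/LoRA modules live in `st_adapter_DyKnow.py`; the class and parameter names here are invented for illustration): a router turns the input into a continuous modality code, and that code generates low-rank weight deltas that adapt a frozen backbone layer per sample.

```python
import torch
import torch.nn as nn


class ModalityAdaptedLinear(nn.Module):
    """Toy sketch: a frozen linear layer adapted by modality-conditioned low-rank deltas."""

    def __init__(self, dim: int, rank: int):
        super().__init__()
        self.base = nn.Linear(dim, dim)          # stands in for a frozen backbone projection
        self.base.requires_grad_(False)
        self.gen_a = nn.Linear(dim, rank * dim)  # generates the down-projection A from the modality code
        self.gen_b = nn.Linear(dim, rank * dim)  # generates the up-projection B from the modality code
        self.rank = rank

    def forward(self, x: torch.Tensor, modality_code: torch.Tensor) -> torch.Tensor:
        # x: (B, L, C) tokens; modality_code: (B, C) continuous modality representation
        B, L, C = x.shape
        A = self.gen_a(modality_code).view(B, self.rank, C)
        Bm = self.gen_b(modality_code).view(B, self.rank, C)
        delta = torch.einsum('brc,blc->blr', A, x)       # project tokens down to rank r
        delta = torch.einsum('brc,blr->blc', Bm, delta)  # project back up to C channels
        return self.base(x) + delta                      # frozen path + per-sample adaptation
```

In the released model, this kind of low-rank adaptation is applied to the Q/K/V and MLP projections of each CLIP ViT block (see `ResidualAttentionBlock` in `st_adapter_DyKnow.py`), with the modality code obtained by softly routing over a learnable knowledge pool.
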
22 | ## Requirements
23 |
24 | - Python 3.7+
25 | - PyTorch 2.0+ (the code uses `torch.nn.functional.scaled_dot_product_attention`)
26 | - CUDA-enabled GPU (recommended: NVIDIA RTX 4090)
27 |
28 | ## Pre-trained Model
29 | MAMI_pretrained.pth
30 |
31 | Google Drive Link: https://drive.google.com/file/d/159KsDSgzgFCdSLN-I1iO0I2Z8Dm7dpiK/view?usp=sharing
32 |
33 | BaiduYun (百度网盘) Link: https://pan.baidu.com/s/1GD-TsafqYhbXM6VwQVUwIA?pwd=fbge
34 |
35 | ## Usage
36 |
37 | 1. Download the pre-trained model from one of the links above.
38 | 
39 | 2. Copy the file `MAMI_pretrained.pth` to this directory (the repository root).
40 | 
41 | 3. Run `bash test.bash`, as shown below.
42 |
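For example (the `gdown` step is optional and assumes the `gdown` package is installed; any download method works):

```bash
pip install gdown                         # optional: helper for downloading from Google Drive
gdown 159KsDSgzgFCdSLN-I1iO0I2Z8Dm7dpiK   # MAMI_pretrained.pth
# or download manually from the Google Drive / BaiduYun links above, then:
cp /path/to/MAMI_pretrained.pth .
bash test.bash                            # extracts the identity feature for test_img.jpg
```
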
43 | ## ToDo
44 | - [ ] **Open Source Train/Validation Split**
45 | Publish the scripts/configuration used to generate the train/val split for the dataset.
46 |
47 | - [ ] **Open Source The Benchmark Methods**
48 |
49 | - [ ] **Open Source Training Code**
50 | Release the complete training pipeline including all scripts and configuration files.
51 |
52 |
53 | ## Citation
54 | If you find our work useful, please cite:
55 |
56 |     @article{tian2025towards,
57 |       title={Towards All-in-One Medical Image Re-Identification},
58 |       author={Tian, Yuan and Ji, Kaiyuan and Zhang, Rongzhao and Jiang, Yankai and Li, Chunyi and Wang, Xiaosong and Zhai, Guangtao},
59 |       journal={arXiv preprint arXiv:2503.08173},
60 |       year={2025}
61 |     }
62 |
--------------------------------------------------------------------------------
/test_extract_id.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
4 | os.environ['TORCH_HOME']='/root/autodl-tmp/torch_home'
5 | os.environ['HF_HOME']='/root/autodl-tmp/huggingface_cache'
6 | os.environ['HF_ENDPOINT']='https://hf-mirror.com'
7 |
8 | import torch
9 | from opts import parser
10 | from torchvision.transforms import Compose, InterpolationMode, Normalize, ToTensor
11 | import torch.distributed as dist
12 | from PIL import Image
13 | import torchvision.transforms as T
14 | dist.init_process_group(backend='nccl')  # NCCL is the fastest and recommended backend for multi-GPU training
15 | args = parser.parse_args()
16 | local_rank = args.local_rank
17 | torch.cuda.set_device(local_rank)
18 |
19 | from torch.utils.tensorboard import SummaryWriter
20 | if local_rank == 0:
21 | tb_logger = SummaryWriter(log_dir=os.path.join(args.exp_dir,'board'),flush_secs=10)
22 | checkpoint_dir = os.path.join(args.exp_dir,'checkpoints')
23 | if not os.path.exists(checkpoint_dir):
24 | os.mkdir(checkpoint_dir)
25 | log_training = open(os.path.join(args.exp_dir,'log.csv'), 'w')
26 | log_training.write(str(args))
27 |
28 | device = "cuda"
29 | from st_adapter_DyKnow import clip_vit_base_patch16_adapter24x384,clip_vit_base_patch16
30 |
31 | # from ema import EMA
32 | import torchvision.transforms as T
33 |
34 | model = clip_vit_base_patch16_adapter24x384(num_classes=1, args=args, lora_rank=args.lora_rank).to(device).train()
35 | val_transform =T.Compose(
36 | [
37 | T.Resize(size=256, antialias=True),
38 | ToTensor(),
39 | Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
40 | T.CenterCrop((224,224))
41 | ]
42 | )
43 |
44 |
45 | if os.path.isfile(args.resume_main):
46 | print(("=> loading checkpoint '{}'".format(args.resume_main)))
47 | checkpoint = torch.load(args.resume_main,map_location='cpu')
48 | checkpoint_model = checkpoint["model"]
49 |     checkpoint_copy = {}
50 |     ks = checkpoint_model.keys()
51 |     for k in ks:
52 |         new_k = k
53 |         if k.startswith("module"):
54 |             if "router_MLP.2" in k or "knowledge_pool" in k:
55 |                 if args.lora_rank < 0:
56 |                     continue  # skip codebook/router weights when LoRA is disabled
57 |             new_k = k[7:]  # strip the "module." prefix added by DistributedDataParallel
58 |         checkpoint_copy[new_k] = checkpoint_model[k]
59 |     model.load_state_dict(checkpoint_copy, strict=False)
60 |
61 |
62 |
63 | contain_trainable = False
64 | for name, param in model.named_parameters():
65 | if param.requires_grad:
66 | contain_trainable = True
67 | break
68 | if contain_trainable:
69 | model = torch.nn.parallel.DistributedDataParallel(model.cuda(), device_ids=[torch.cuda.current_device()],\
70 | broadcast_buffers=False , find_unused_parameters=True)
71 | else:
72 | model = model.cuda()
73 |
74 |
75 | ## We accept both single-image modalities (such as X-ray and fundus) and image sequences (such as CT).
76 | ## The input shape is (B, C, H, W) for a single image or (B, C, L, H, W) for a sequence, where L is the sequence length.
77 | 
78 | ## Single-image example:
79 | img_path = "test_img.jpg"
80 | img = Image.open(img_path).convert('RGB')
81 | img_tensor = val_transform(img)
82 | image_tensor_example = img_tensor.unsqueeze(0)
83 | input = image_tensor_example
84 | if len(input.size()) == 4:
85 |     input = input.unsqueeze(2)  # add a length-1 sequence dimension: (B, C, H, W) -> (B, C, 1, H, W)
86 | ft_x, _ = model(input.to(device))
87 |
88 | print("ID tensor", ft_x)
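# Hypothetical sketch (not part of the released script): for a sequence modality such as CT,
# one could stack L pre-processed slices along a new depth dimension and feed the 5-D tensor
# directly; `slice_list` below is an assumed list of PIL slice images.
#   slices = torch.stack([val_transform(s) for s in slice_list], dim=1)  # (C, L, H, W)
#   ft_seq, _ = model(slices.unsqueeze(0).to(device))                    # (1, C, L, H, W)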
--------------------------------------------------------------------------------
/opts.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | parser = argparse.ArgumentParser(description="PyTorch implementation of MaMI: all-in-one medical image re-identification")
3 | parser.add_argument('--exp_dir', type=str, default="experiments/exp_test",help='Path to experiment output directory')
4 | parser.add_argument('--train_csv', type=str, default="", help='Path to the training CSV file')
5 | parser.add_argument('--optim', type=str, default="sgd", help='Optimizer to use (e.g., sgd, adamw)')
6 | parser.add_argument('--verification_csv', type=str, default="", help='Path to the verification CSV file')
7 | parser.add_argument('--epochs', default=20, type=int, metavar='N',
8 | help='number of total epochs to run')
9 | parser.add_argument('--save_every_n_epochs', default=5, type=int, metavar='N',
10 |                     help='save a checkpoint every N epochs')
11 | parser.add_argument('--test_every_n_epochs', default=1, type=int, metavar='N',
12 |                     help='run evaluation every N epochs')
13 | parser.add_argument('--lora_rank', default=8, type=int, metavar='N',
14 |                     help='LoRA rank (a negative value disables the LoRA branches)')
15 | parser.add_argument('--code_number', default=16, type=int, metavar='N',
16 |                     help='number of codes in the modality knowledge pool')
17 | parser.add_argument('--resume_main', default='', type=str, metavar='PATH',
18 | help='path to latest resume_main checkpoint (default: none)')
19 | parser.add_argument('--data_modality', default='image', type=str,
20 |                     help='input data modality')
21 | parser.add_argument('--backbone', default='clip_stadapt_imagenet', type=str,
22 |                     help='backbone architecture name')
23 | parser.add_argument('--lr', '--learning-rate', default=0.0001, type=float,
24 | metavar='LR', help='initial learning rate')
25 | parser.add_argument('--wd', default=0.0001, type=float,
26 |                     metavar='W', help='weight decay')
27 | parser.add_argument('--margin', default=0.4, type=float,
28 |                     metavar='M', help='margin for the metric-learning loss')
29 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
30 | help='evaluate model on validation set')
parser.add_argument('--only_eval_privacy', dest='only_eval_privacy', action='store_true',
32 |                     help='only run the privacy evaluation')
33 | parser.add_argument('--batch_size', default=8, type=int, help='Batch size for training')
34 | parser.add_argument('--start_epoch', default=0, type=int, help='epoch to start training from')
35 | parser.add_argument('--text_use_method', type=str, default="none", help='how text supervision is used (e.g., identity_CL)')
36 | parser.add_argument('--relation_type', type=str, default="mlp", help='relation modelling type (e.g., mlp, difference)')
37 | parser.add_argument('--mimic_all', action='store_true', help='use all MIMIC data for training')
38 | parser.add_argument('--no_codebook', action='store_true', help='bypass the modality codebook and use the raw router features')
39 | parser.add_argument('--ft_all', action='store_true', help='fine-tune all backbone parameters instead of freezing them')
40 | parser.add_argument('--ft_backbone_lr_multi', default=1, type=float,
41 |                     metavar='M', help='learning-rate multiplier for the fine-tuned backbone')
42 | parser.add_argument('--train_modality', type=str, default="all", help='which modalities to use for training')
43 |
44 | parser.add_argument('--online_compute_id', action='store_true', help='compute identity features online')
45 | parser.add_argument('--aug_color', action='store_true', help='enable color augmentation')
46 | parser.add_argument('--aug_resize', action='store_true', help='enable resize augmentation')
47 |
48 | parser.add_argument('--no_randE', action='store_true', help='disable random erasing augmentation')
49 | parser.add_argument('--no_randCrop', action='store_true', help='disable random crop augmentation')
50 | parser.add_argument('--clip_grad', default=0, type=float,
51 |                     metavar='NORM', help='max gradient norm for clipping')
52 |
53 | parser.add_argument('--local-rank', type=int, default=0)
54 |
55 |
--------------------------------------------------------------------------------
/st_adapter_DyKnow.py:
--------------------------------------------------------------------------------
1 | # modified from: https://github.com/openai/CLIP/blob/a9b1bf5920416aaeaec965c25dd9e8f98c864f16/clip/model.py
2 |
3 | from typing import Tuple
4 | from collections import OrderedDict
5 | import math
6 | import functools
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 |
12 | CLIP_VIT_B16_PATH = "/root/autodl-tmp/patient_triple/code/open-metric-learning/my_code/pretrained_models/ViT-B-16.pt"
13 | DWCONV3D_DISABLE_CUDNN = True
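# ST-Adapter-style temporal adapter: a channel bottleneck (fc1), a zero-initialized
# depth-wise 3D convolution over (T, H, W), and an expansion (fc2). The result is added
# residually to the patch tokens (the CLS token is left untouched), so the module acts
# as an identity mapping at the start of training.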
14 | class Adapter(nn.Module):
15 |
16 | def __init__(self, in_channels, adapter_channels, kernel_size):
17 | super().__init__()
18 | self.fc1 = nn.Linear(in_channels, adapter_channels)
19 | self.conv = nn.Conv3d(
20 | adapter_channels, adapter_channels,
21 | kernel_size=kernel_size,
22 | stride=(1, 1, 1),
23 | padding=tuple(x // 2 for x in kernel_size),
24 | groups=adapter_channels,
25 | )
26 | self.fc2 = nn.Linear(adapter_channels, in_channels)
27 | nn.init.constant_(self.conv.weight, 0.)
28 | nn.init.constant_(self.conv.bias, 0.)
29 | nn.init.constant_(self.fc1.bias, 0.)
30 | nn.init.constant_(self.fc2.bias, 0.)
31 |
32 | def forward(self, x, T):
33 | BT, L, C = x.size()
34 | B = BT // T
35 | Ca = self.conv.in_channels
36 | H = W = round(math.sqrt(L - 1))
37 | assert L - 1 == H * W
38 | x_id = x
39 | x = x[:, 1:, :]
40 | x = self.fc1(x)
41 | x = x.view(B, T, H, W, Ca).permute(0, 4, 1, 2, 3).contiguous()
42 |
43 | cudnn_enabled = torch.backends.cudnn.enabled
44 | torch.backends.cudnn.enabled = cudnn_enabled and DWCONV3D_DISABLE_CUDNN
45 | x = self.conv(x)
46 | torch.backends.cudnn.enabled = cudnn_enabled
47 |
48 | x = x.permute(0, 2, 3, 4, 1).contiguous().view(BT, L - 1, Ca)
49 | x = self.fc2(x)
50 | x_id[:, 1:, :] += x
51 | return x_id
52 |
53 |
54 | class LayerNorm(nn.LayerNorm):
55 | """Subclass torch's LayerNorm to handle fp16."""
56 |
57 | def forward(self, x: torch.Tensor):
58 | orig_type = x.dtype
59 | ret = super().forward(x.type(torch.float32))
60 | return ret.type(orig_type)
61 |
62 |
63 | class QuickGELU(nn.Module):
64 | def forward(self, x: torch.Tensor):
65 | return x * torch.sigmoid(1.702 * x)
66 | class GroupedLinear(nn.Module):
67 | def __init__(self, in_features, out_features, groups=1, bias=True):
68 | super(GroupedLinear, self).__init__()
69 | self.in_features = in_features
70 | self.out_features = out_features
71 | self.groups = groups
72 | self.bias = bias
73 |
74 | assert in_features % groups == 0, "in_features must be divisible by groups"
75 | assert out_features % groups == 0, "out_features must be divisible by groups"
76 |
77 | self.weight = nn.Parameter(torch.Tensor(groups, in_features // groups, out_features // groups))
78 | if bias:
79 | self.bias = nn.Parameter(torch.Tensor(groups, out_features // groups))
80 | else:
81 | self.register_parameter('bias', None)
82 |
83 | self.reset_parameters()
84 |
85 | def reset_parameters(self):
86 | nn.init.constant_(self.weight,0)
87 | nn.init.constant_(self.bias,0)
88 | # nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
89 | # if self.bias is not None:
90 | # fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
91 | # bound = 1 / math.sqrt(fan_in)
92 | # nn.init.uniform_(self.bias, -bound, bound)
93 |
94 | def forward(self, x):
95 | B, C = x.size()
96 | assert C == self.in_features, "Input feature dimension does not match"
97 |
98 | x = x.view(B, self.groups, C // self.groups)
99 | x = torch.einsum('bgi,gio->bgo', x, self.weight)
100 |
101 | if self.bias is not None:
102 | x = x + self.bias
103 |
104 | x = x.reshape(B, self.out_features)
105 | return x
106 |
107 | class GroupedMLP(nn.Module):
108 | def __init__(self, in_features, out_features, groups=1, bias=True):
109 | super(GroupedMLP, self).__init__()
110 | self.l1 = GroupedLinear(in_features, in_features//4,1)
111 | self.l2 = GroupedLinear(in_features//4, out_features,groups)
112 |
113 |
114 | def forward(self, x):
115 | return self.l2 (F.leaky_relu(self.l1(x),0.1))
116 |
117 | # NOTE: this second definition shadows the GroupedMLP above; only this version is used.
118 | class GroupedMLP(nn.Module):
119 |     def __init__(self, in_features, out_features, groups=1, bias=True):
120 |         super(GroupedMLP, self).__init__()
121 |         self.l11 = GroupedLinear(in_features, in_features//2, 1)
122 |         self.l21 = GroupedLinear(in_features//2, in_features//4, 1)
123 |         self.l31 = GroupedLinear(in_features//4, out_features, groups)
124 | 
125 |     def forward(self, x):
126 |         x = self.l21(F.leaky_relu(self.l11(x), 0.1, True))
127 |         x = F.leaky_relu(x, 0.1, True)
128 |         return self.l31(x)
129 | 
130 |
131 |
132 | class ResidualAttentionBlock(nn.Module):
133 | def __init__(self,
134 | d_model: int,
135 | n_head: int,
136 | adapter_width: int,
137 | adapter_kernel_size: Tuple[int, int, int],
138 | adapter_pre_attn: bool,
139 | adapter_pre_mlp: bool,
140 | lora_rank ,
141 | ft_all
142 | ) -> None:
143 | super().__init__()
144 |
145 | self.attn = nn.MultiheadAttention(d_model, n_head)
146 | self.ln_1 = LayerNorm(d_model)
147 | self.mlp = nn.Sequential(OrderedDict([
148 | ("c_fc", nn.Linear(d_model, d_model * 4)),
149 | ("gelu", QuickGELU()),
150 | ("c_proj", nn.Linear(d_model * 4, d_model))
151 | ]))
152 | self.ln_2 = LayerNorm(d_model)
153 | if not ft_all:
154 | for _ in self.attn.parameters():
155 | _.requires_grad = False
156 | for _ in self.ln_1.parameters():
157 | _.requires_grad = False
158 | for _ in self.mlp.parameters():
159 | _.requires_grad = False
160 | for _ in self.ln_2.parameters():
161 | _.requires_grad = False
162 |
163 | adapter_class = functools.partial(
164 | Adapter,
165 | in_channels=d_model,
166 | adapter_channels=adapter_width,
167 | kernel_size=adapter_kernel_size,
168 | )
169 | self.adapter_pre_attn = \
170 | adapter_class() if adapter_pre_attn else None
171 | self.adapter_pre_mlp = \
172 | adapter_class() if adapter_pre_mlp else None
173 | self.d_model = d_model
174 | self.lora_rank = lora_rank
175 | if lora_rank>0:
176 | self.lora_rank_dynamic = lora_rank_dynamic = lora_rank
177 | G = 16
178 | self.lora_mlp_q_a = GroupedMLP(d_model, lora_rank_dynamic * d_model,groups=G)
179 | self.lora_mlp_q_b = GroupedMLP(d_model, lora_rank_dynamic * d_model,groups=G)
180 | self.lora_mlp_k_a = GroupedMLP(d_model, lora_rank_dynamic * d_model,groups=G)
181 | self.lora_mlp_k_b = GroupedMLP(d_model, lora_rank_dynamic * d_model,groups=G)
182 | self.lora_mlp_v_a = GroupedMLP(d_model, lora_rank_dynamic * d_model,groups=G)
183 | self.lora_mlp_v_b = GroupedMLP(d_model, lora_rank_dynamic * d_model,groups=G)
184 |
185 | self.linear_a_q_static = nn.Parameter(torch.zeros(1, self.lora_rank, self.d_model))
186 | self.linear_b_q_static = nn.Parameter(torch.zeros(1, self.lora_rank, self.d_model))
187 | self.linear_a_k_static = nn.Parameter(torch.zeros(1, self.lora_rank, self.d_model))
188 | self.linear_b_k_static = nn.Parameter(torch.zeros(1, self.lora_rank, self.d_model))
189 | self.linear_a_v_static = nn.Parameter(torch.zeros(1, self.lora_rank, self.d_model))
190 | self.linear_b_v_static = nn.Parameter(torch.zeros(1, self.lora_rank, self.d_model))
191 | # if ft_all:
192 | # self.linear_a_q_static.requires_grad = False
193 | # self.linear_b_q_static.requires_grad = False
194 | # self.linear_a_k_static.requires_grad = False
195 | # self.linear_b_k_static.requires_grad = False
196 | # self.linear_a_v_static.requires_grad = False
197 | # self.linear_b_v_static.requires_grad = False
198 | # LORA parameter generation MLPs -for FFN
199 | G = 32
200 | self.lora_mlp_c_fc_a = GroupedMLP(d_model, lora_rank_dynamic * (d_model * 4),groups=G)
201 | self.lora_mlp_c_fc_b = GroupedMLP(d_model, lora_rank_dynamic * d_model,groups=G)
202 | self.lora_mlp_c_proj_a = GroupedMLP(d_model, lora_rank_dynamic * d_model,groups=G)
203 | self.lora_mlp_c_proj_b = GroupedMLP(d_model, lora_rank_dynamic * (d_model * 4),groups=G)
204 |
205 | self.linear_a_c_fc_static = nn.Parameter(torch.zeros(1, self.lora_rank, self.d_model * 4))
206 | self.linear_b_c_fc_static = nn.Parameter(torch.zeros(1, self.lora_rank, self.d_model))
207 | self.linear_a_c_proj_static = nn.Parameter(torch.zeros(1, self.lora_rank, self.d_model))
208 | self.linear_b_c_proj_static = nn.Parameter(torch.zeros(1, self.lora_rank, self.d_model * 4))
209 | # if ft_all:
210 | # self.linear_a_c_fc_static.requires_grad = False
211 | # self.linear_b_c_fc_static.requires_grad = False
212 | # self.linear_a_c_proj_static.requires_grad = False
213 | # self.linear_b_c_proj_static.requires_grad = False
214 |
215 |
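    # ComPA-style dynamic LoRA: the per-sample modality code `x_route_context` is mapped
    # by the lora_mlp_* generators (plus the static *_static terms) to low-rank A/B matrices,
    # whose product adds an input-conditioned delta to the frozen CLIP Q/K/V projections.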
216 | def attention(self, x: torch.Tensor,x_route_context) -> torch.Tensor:
217 |
218 | B, L, C = x.size()
219 | H = self.attn.num_heads
220 |
221 | if self.lora_rank>0:
222 | # B C x_route_context
223 | # Generate LORA parameters
224 | linear_a_q = self.lora_mlp_q_a(x_route_context).view(B, self.lora_rank_dynamic, C)+self.linear_a_q_static
225 | linear_b_q = self.lora_mlp_q_b(x_route_context).view(B, self.lora_rank_dynamic, C)+self.linear_b_q_static
226 | linear_a_k = self.lora_mlp_k_a(x_route_context).view(B, self.lora_rank_dynamic, C)+self.linear_a_k_static
227 | linear_b_k = self.lora_mlp_k_b(x_route_context).view(B, self.lora_rank_dynamic, C)+self.linear_b_k_static
228 | linear_a_v = self.lora_mlp_v_a(x_route_context).view(B, self.lora_rank_dynamic, C)+self.linear_a_v_static
229 | linear_b_v = self.lora_mlp_v_b(x_route_context).view(B, self.lora_rank_dynamic, C)+self.linear_b_v_static
230 |
231 |
232 |             # low-rank delta: combine (B, K, C) with (B, L, K) to get a (B, L, C) update
233 |             new_q = torch.einsum('bkc,blk->blc', linear_b_q, torch.einsum('bri,bti->btr', linear_a_q, x))
234 |             new_k = torch.einsum('bkc,blk->blc', linear_b_k, torch.einsum('bri,bti->btr', linear_a_k, x))
235 |             new_v = torch.einsum('bkc,blk->blc', linear_b_v, torch.einsum('bri,bti->btr', linear_a_v, x))
236 | 
237 |
238 | qkv = F.linear(x, weight=self.attn.in_proj_weight, bias=self.attn.in_proj_bias)
239 | if self.lora_rank>0:
240 | qkv[:, :, :self.d_model] += new_q
241 | qkv[:, :, self.d_model:-self.d_model] += new_k
242 | qkv[:, :, -self.d_model:] += new_v
243 | qkv = qkv.view(B, L, H * 3, -1).permute(0, 2, 1, 3)
244 | q, k, v = qkv.split([H, H, H], dim=1)
245 | out = F.scaled_dot_product_attention(q, k, v)
246 | out = out.permute(0, 2, 1, 3).flatten(-2)
247 | out = self.attn.out_proj(out)
248 |
249 | return out
250 |
251 | def mlp_wrapper(self, x: torch.Tensor,x_route_context) -> torch.Tensor:
252 | if self.lora_rank>0:
253 | B, L, C = x.size()
254 | linear_a_c_fc = self.lora_mlp_c_fc_a(x_route_context).view(B, self.lora_rank, C * 4)+self.linear_a_c_fc_static
255 | linear_b_c_fc = self.lora_mlp_c_fc_b(x_route_context).view(B, self.lora_rank, C)+self.linear_b_c_fc_static
256 | # linear_a_c_proj = self.lora_mlp_c_proj_a(x_route_context).view(B, self.lora_rank, C)+self.linear_a_c_proj_static
257 | # linear_b_c_proj = self.lora_mlp_c_proj_b(x_route_context).view(B, self.lora_rank, C * 4)+self.linear_b_c_proj_static
258 |
259 | c_fc_out = F.linear(x, weight=self.mlp.c_fc.weight, bias=self.mlp.c_fc.bias)
260 |
261 | if self.lora_rank>0:
262 |             new_c_fc = torch.einsum('bkc,blk->blc', linear_a_c_fc, torch.einsum('bri,bti->btr', linear_b_c_fc, x))
263 |             ### low-rank delta for c_fc: (B, K, 4C) combined with (B, L, K) -> (B, L, 4C)
264 | c_fc_out += new_c_fc
265 |
266 | c_fc_out = self.mlp.gelu(c_fc_out)
267 |
268 | # Compute c_proj with LORA
269 | c_proj_out = F.linear(c_fc_out, weight=self.mlp.c_proj.weight, bias=self.mlp.c_proj.bias)
270 | # if self.lora_rank>0:
271 | # new_c_proj = torch.einsum('bkc,blk->blc', linear_a_c_proj, torch.einsum('bri,bti->btr', linear_b_c_proj, c_fc_out))
272 | # c_proj_out += new_c_proj
273 |
274 | return c_proj_out
275 |
276 |
277 | def forward(self,
278 | x: torch.Tensor,
279 | num_frames: int,
280 | x_route_context
281 | ) -> torch.Tensor:
282 | if self.adapter_pre_attn is not None:
283 | x = self.adapter_pre_attn(x, num_frames)
284 | x = x + self.attention(self.ln_1(x),x_route_context)
285 | if self.adapter_pre_mlp is not None:
286 | x = self.adapter_pre_mlp(x, num_frames)
287 | x = x + self.mlp_wrapper(self.ln_2(x),x_route_context)
288 | return x
289 |
290 |
291 | class Transformer(nn.Module):
292 | def __init__(self,
293 | width: int,
294 | layers: int,
295 | heads: int,
296 | adapter_width: int,
297 | adapter_layers: int,
298 | adapter_kernel_size: Tuple[int, int, int],
299 | adapter_pre_attn: bool,
300 | adapter_pre_mlp: bool,
301 | lora_rank,
302 | ft_all
303 | ):
304 | super().__init__()
305 | self.width = width
306 | self.layers = layers
307 | self.resblocks = nn.ModuleList([
308 | ResidualAttentionBlock(
309 | d_model=width,
310 | n_head=heads,
311 | adapter_width=adapter_width,
312 | adapter_kernel_size=adapter_kernel_size,
313 | adapter_pre_attn=adapter_pre_attn and i >= layers - adapter_layers,
314 | adapter_pre_mlp=adapter_pre_mlp and i >= layers - adapter_layers,
315 | lora_rank=lora_rank,
316 | ft_all = ft_all
317 | )
318 | for i in range(layers)
319 | ])
320 |
321 | def forward(self, x: torch.Tensor, num_frames: int,x_route_context) -> torch.Tensor:
322 | for block in self.resblocks:
323 | x = block(x, num_frames,x_route_context)
324 | return x
325 |
326 |
327 | class VisionTransformer(nn.Module):
328 | def __init__(self,
329 | input_resolution: int,
330 | patch_size: int,
331 | width: int,
332 | layers: int,
333 | heads: int,
334 | num_classes: int,
335 | adapter_width: int,
336 | adapter_layers: int,
337 | adapter_kernel_size: Tuple[int, int, int],
338 | adapter_pre_attn: bool,
339 | adapter_pre_mlp: bool,
340 | lora_rank,
341 | args
342 | ):
343 | super().__init__()
344 | self.input_resolution = input_resolution
345 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width,
346 | kernel_size=patch_size, stride=patch_size, bias=False)
347 | self.router_spatial_transform = nn.Sequential(
348 | nn.Conv2d(in_channels=width, out_channels=width,
349 | kernel_size=1, stride=1,padding=0, bias=True),
350 | nn.LeakyReLU(0.1,True),
351 | nn.Conv2d(in_channels=width, out_channels=width,
352 | kernel_size=1, stride=1,padding=0, bias=True),
353 | )
354 | self.router_group = args.code_number
355 | self.knowledge_group = 1
356 | self.router_MLP = nn.Sequential(
357 | nn.Linear(width,width),
358 | nn.LeakyReLU(0.1,True),
359 | nn.Linear(width,self.knowledge_group*self.router_group ),
360 | )
361 | self.ft_comp_MLP = nn.Sequential(
362 | nn.LeakyReLU(0.1,False),
363 | nn.Linear(width,width),
364 | nn.LeakyReLU(0.1,False),
365 | nn.Linear(width,width),
366 | )
367 | nn.init.constant_(self.ft_comp_MLP[3].weight,0)
368 | nn.init.constant_(self.ft_comp_MLP[3].bias,0)
369 | self.knowledge_pool = nn.Parameter(torch.randn(self.knowledge_group,self.router_group, 768//self.router_group))
370 |
371 | scale = width ** -0.5
372 | self.class_embedding = nn.Parameter(scale * torch.randn(width))
373 | self.positional_embedding = nn.Parameter(
374 | scale * torch.randn(
375 | (input_resolution // patch_size) ** 2 + 1, width
376 | )
377 | )
378 | self.ln_pre = LayerNorm(width)
379 |
380 | self.transformer = Transformer(width, layers, heads,
381 | adapter_width, adapter_layers, adapter_kernel_size,
382 | adapter_pre_attn, adapter_pre_mlp,lora_rank =lora_rank,ft_all =args.ft_all)
383 |
384 | self.ln_post = LayerNorm(width)
385 | self.lora_rank = lora_rank
386 | # for n, p in self.named_parameters():
387 | # if 'adapter' not in n:
388 | # p.requires_grad_(False)
389 |
390 | self.dropout = nn.Dropout(0.5)
391 | self.fc = nn.Linear(width, num_classes)
392 | nn.init.normal_(self.fc.weight, std=0.02)
393 | nn.init.constant_(self.fc.bias, 0.)
394 | self.last_mlp = nn.Sequential(
395 | nn.Linear(768,768),
396 | nn.LeakyReLU(0.1),
397 | nn.Linear(768,768),
398 | )
399 | self.args = args
400 |
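    # Forward pass: input of shape (B, C, T, H, W) -> patch embedding -> ComPA routing
    # produces a per-sample modality code -> modality-conditioned transformer blocks ->
    # returns a global identity feature and a 14x14 local feature map.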
401 | def forward(self, x: torch.Tensor,return_mode = "default"):
402 | B, T = x.size(0), x.size(2)
403 | x = x.permute(0, 2, 1, 3, 4).flatten(0, 1)
404 | x = self.conv1(x) # shape = [*, width, grid, grid]
405 | x_for_route = self.router_spatial_transform(x)
406 | x_for_route_root =x_for_route.mean(-1).mean(-1)
407 |         x_for_route = self.router_MLP(x_for_route_root)  # (B*T, KG*RG) routing logits
408 |         x_for_ft_comp = self.ft_comp_MLP(x_for_route_root)
409 |         x_for_route = x_for_route.reshape(B*T, self.knowledge_group, self.router_group)  # (B*T, KG, RG)
410 |         x_for_route = F.softmax(x_for_route, dim=-1)  # soft routing weights over the knowledge pool
411 | 
412 |         # weight the knowledge pool (1, KG, RG, C//RG) by the routing weights
413 |         x_route_context = self.knowledge_pool.unsqueeze(0) * x_for_route.unsqueeze(-1)  # (B*T, KG, RG, C//RG)
414 |         x_route_context = x_route_context.sum(1)  # (B*T, RG, C//RG)
415 |         x_route_context = x_route_context.reshape(B*T, 768)  # continuous modality code
416 | if self.args.no_codebook:
417 | x_route_context = x_for_route_root
418 | if return_mode == "modality_token":
419 | return x_route_context.reshape(B,T,768).mean(1)
420 | spatial_size = tuple(x.size()[2:])
421 | x = x.flatten(-2).permute(0, 2, 1)##B N C
422 | if self.lora_rank>0:
423 | x = x + x_for_ft_comp.unsqueeze(1)
424 | x = torch.cat([
425 | self.class_embedding.view(1, 1, -1).expand(x.shape[0], -1, -1), x
426 | ], dim=1) # [*, grid ** 2 + 1, width]
427 | x = x + self.positional_embedding.to(x.dtype)
428 | x = self.ln_pre(x)
429 |
430 | x = x.view(B, T, x.size(1), x.size(2)).flatten(0, 1) # BT, L, D
431 |
432 | x = self.transformer(x, T,x_route_context)
433 |
434 | x = x.contiguous().view(B, T, spatial_size[0] * spatial_size[1] + 1, x.size(-1))
435 | # x_global = x[:, :, 0, :].mean(dim=1)
436 | x_local = x[:, :, 1:, :].mean(dim=1)
437 | x_global = x_local.mean(1)
438 | H,W = 14,14
439 | B,S,C = x_local.size()
440 | assert S == H*W
441 | x_local =x_local.reshape(B,H,W,C).permute(0,3,1,2)
442 | # x_global = self.ln_post(x_global)
443 | return self.last_mlp(x_global), x_local
444 |
445 |
446 |
447 | def clip_vit_base_patch16_adapter24x384(**kwargs):
448 | model = VisionTransformer(
449 | input_resolution=224,
450 | patch_size=16,
451 | width=768,
452 | layers=12,
453 | heads=12,
454 | adapter_width=384,
455 | adapter_layers=12,
456 | adapter_kernel_size=(3, 1, 1),
457 | adapter_pre_attn=True,
458 | adapter_pre_mlp=False,
459 | **kwargs,
460 | )
461 |     assert CLIP_VIT_B16_PATH is not None, \
462 |         'Please set CLIP_VIT_B16_PATH at the top of st_adapter_DyKnow.py.'
463 | checkpoint = torch.jit.load(CLIP_VIT_B16_PATH, map_location='cpu')
464 | model.load_state_dict(checkpoint.visual.state_dict(), strict=False)
465 | return model
466 |
467 | def clip_vit_base_patch16(**kwargs):
468 | model = VisionTransformer(
469 | input_resolution=224,
470 | patch_size=16,
471 | width=768,
472 | layers=12,
473 | heads=12,
474 | adapter_width=384,
475 | adapter_layers=12,
476 | adapter_kernel_size=(3, 1, 1),
477 | adapter_pre_attn=False,
478 | adapter_pre_mlp=False,
479 | **kwargs,
480 | )
481 |     assert CLIP_VIT_B16_PATH is not None, \
482 |         'Please set CLIP_VIT_B16_PATH at the top of st_adapter_DyKnow.py.'
483 | checkpoint = torch.jit.load(CLIP_VIT_B16_PATH, map_location='cpu')
484 | model.load_state_dict(checkpoint.visual.state_dict(), strict=False)
485 | print(model)
486 | return model
487 |
488 | def clip_vit_base_patch16_adapter12x384(**kwargs):
489 | model = VisionTransformer(
490 | input_resolution=224,
491 | patch_size=16,
492 | width=768,
493 | layers=12,
494 | heads=12,
495 | adapter_width=384,
496 | adapter_layers=12,
497 | adapter_kernel_size=(3, 1, 1),
498 | adapter_pre_attn=False,
499 | adapter_pre_mlp=False,
500 | **kwargs,
501 | )
502 |     assert CLIP_VIT_B16_PATH is not None, \
503 |         'Please set CLIP_VIT_B16_PATH at the top of st_adapter_DyKnow.py.'
504 | checkpoint = torch.jit.load(CLIP_VIT_B16_PATH, map_location='cpu')
505 | print(model.load_state_dict(checkpoint.visual.state_dict(), strict=False))
506 | return model
--------------------------------------------------------------------------------