├── main.png ├── test_img.jpg ├── test.bash ├── ReadMe.md ├── test_extract_id.py ├── opts.py └── st_adapter_DyKnow.py /main.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianyuan168326/All-in-One-MedReID-Pytorch/HEAD/main.png -------------------------------------------------------------------------------- /test_img.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianyuan168326/All-in-One-MedReID-Pytorch/HEAD/test_img.jpg -------------------------------------------------------------------------------- /test.bash: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python -m torch.distributed.launch --nproc_per_node 1 --master_port=28349 test_extract_id.py --lr 5e-6 \ 2 | --epochs 20 --optim adamw --save_every_n_epochs 3 --wd 5e-2 --backbone clip_Tadapt --clip_grad -1 --online_compute_id \ 3 | --batch_size 6 --text_use_method identity_CL --mimic_all --lora_rank 16 --margin -1 --ft_all --relation_type difference --start_epoch 0 \ 4 | --exp_dir ./ \ 5 | --resume_main MAMI_pretrained.pth 6 | -------------------------------------------------------------------------------- /ReadMe.md: -------------------------------------------------------------------------------- 1 | # All-in-One Medical Image Re-Identification (CVPR2025) 2 | 3 | MaMI is the first unified re-identification model for medical images, capable of handling various imaging modalities such as X-ray, CT, fundus, and pathology images. By leveraging a Continuous Modality-based Parameter Adapter (ComPA) and integrating medical priors, MaMI supports both historical data-assisted diagnosis and privacy protection applications. 4 | 5 |
6 | ![MAMI](main.png) 7 |
8 | 9 | 10 | 11 | ## Features 12 | 13 | - **Unified Multi-Modality Model** 14 | A single model supports multiple medical image modalities without having to train separate models for each modality. 15 | 16 | - **Continuous Modality Adaptive Parameterization** 17 | The ComPA module generates continuous modality representations to dynamically adapt model parameters based on the input image. 18 | 19 | - **Medical Priors Integration** 20 | Incorporates pre-trained Medical Foundation Models (MFMs) to enhance feature discrimination, capturing subtle identity-related cues for more robust re-identification. 21 | 22 | ## Requirements 23 | 24 | - Python 3.8+ 25 | - PyTorch 2.0+ (the model uses `torch.nn.functional.scaled_dot_product_attention`) 26 | - CUDA-enabled GPU (recommended: NVIDIA RTX 4090) 27 | 28 | ## Pre-trained Model 29 | MAMI_pretrained.pth 30 | 31 | Google Drive link: https://drive.google.com/file/d/159KsDSgzgFCdSLN-I1iO0I2Z8Dm7dpiK/view?usp=sharing 32 | 33 | BaiduYun (百度网盘) link: https://pan.baidu.com/s/1GD-TsafqYhbXM6VwQVUwIA?pwd=fbge 34 | 35 | ## Usage 36 | 37 | 1. Download the pre-trained model from one of the links above. 38 | 39 | 2. Copy the file "MAMI_pretrained.pth" to this directory, and set `CLIP_VIT_B16_PATH` in st_adapter_DyKnow.py to your local OpenAI CLIP ViT-B/16 checkpoint (`ViT-B-16.pt`). 40 | 41 | 3. Run "test.bash". 42 | 43 | ## ToDo 44 | - [ ] **Open Source Train/Validation Split** 45 | Publish the scripts/configuration used to generate the train/val split for the dataset. 46 | 47 | - [ ] **Open Source the Benchmark Methods** 48 | 49 | - [ ] **Open Source Training Code** 50 | Release the complete training pipeline including all scripts and configuration files.
51 | 52 | 53 | ## Citation 54 | If you find our work useful, please cite: 55 | 56 | @article{tian2025towards, 57 | title={Towards All-in-One Medical Image Re-Identification}, 58 | author={Tian, Yuan and Ji, Kaiyuan and Zhang, Rongzhao and Jiang, Yankai and Li, Chunyi and Wang, Xiaosong and Zhai, Guangtao}, 59 | journal={arXiv preprint arXiv:2503.08173}, 60 | year={2025} 61 | } 62 | -------------------------------------------------------------------------------- /test_extract_id.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 4 | os.environ['TORCH_HOME']='/root/autodl-tmp/torch_home' 5 | os.environ['HF_HOME']='/root/autodl-tmp/huggingface_cache' 6 | os.environ['HF_ENDPOINT']='https://hf-mirror.com' 7 | 8 | import torch 9 | from opts import parser 10 | from torchvision.transforms import Compose, InterpolationMode, Normalize, ToTensor 11 | import torch.distributed as dist 12 | from PIL import Image 13 | import torchvision.transforms as T 14 | dist.init_process_group(backend='nccl') # nccl是GPU设备上最快、最推荐的后端 15 | args = parser.parse_args() 16 | local_rank = args.local_rank 17 | torch.cuda.set_device(local_rank) 18 | 19 | from torch.utils.tensorboard import SummaryWriter 20 | if local_rank == 0: 21 | tb_logger = SummaryWriter(log_dir=os.path.join(args.exp_dir,'board'),flush_secs=10) 22 | checkpoint_dir = os.path.join(args.exp_dir,'checkpoints') 23 | if not os.path.exists(checkpoint_dir): 24 | os.mkdir(checkpoint_dir) 25 | log_training = open(os.path.join(args.exp_dir,'log.csv'), 'w') 26 | log_training.write(str(args)) 27 | 28 | device = "cuda" 29 | from st_adapter_DyKnow import clip_vit_base_patch16_adapter24x384,clip_vit_base_patch16 30 | 31 | # from ema import EMA 32 | import torchvision.transforms as T 33 | 34 | model =clip_vit_base_patch16_adapter24x384(num_classes=1,args = args,lora_rank =args.lora_rank ).to(device).train() 35 | val_transform 
=T.Compose( 36 | [ 37 | T.Resize(size=256, antialias=True), 38 | ToTensor(), 39 | Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 40 | T.CenterCrop((224,224)) 41 | ] 42 | ) 43 | 44 | 45 | if os.path.isfile(args.resume_main): 46 | print(("=> loading checkpoint '{}'".format(args.resume_main))) 47 | checkpoint = torch.load(args.resume_main,map_location='cpu') 48 | checkpoint_model = checkpoint["model"] 49 | ckeck_copy = {} 50 | ks =checkpoint_model.keys() 51 | for k in ks: 52 | new_k = k 53 | if k.startswith("module"): 54 | if "router_MLP.2" in k or "knowledge_pool" in k: 55 | if args.lora_rank<0: 56 | continue 57 | new_k = k[7:] 58 | ckeck_copy[new_k] = checkpoint_model[k] 59 | model.load_state_dict(ckeck_copy,strict=False) 60 | 61 | 62 | 63 | contain_trainable = False 64 | for name, param in model.named_parameters(): 65 | if param.requires_grad: 66 | contain_trainable = True 67 | break 68 | if contain_trainable: 69 | model = torch.nn.parallel.DistributedDataParallel(model.cuda(), device_ids=[torch.cuda.current_device()],\ 70 | broadcast_buffers=False , find_unused_parameters=True) 71 | else: 72 | model = model.cuda() 73 | 74 | 75 | ## we accept both single-image (such as Xray, fundus) and sequence (such as CT) medical images. 
76 | ## The input dimension is (B,C,H,W) or (B,L,H,W), where L is the sequence length 77 | 78 | ## 79 | img_path = "test_img.jpg" 80 | img = Image.open(img_path).convert('RGB') 81 | img_tensor = val_transform(img) 82 | image_tensor_example = img_tensor.unsqueeze(0) 83 | input = image_tensor_example 84 | if len(input.size()) ==4: 85 | input = input.unsqueeze(2) 86 | ft_x,_ = model(input.to(device)) 87 | 88 | print("ID tensor", ft_x) -------------------------------------------------------------------------------- /opts.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | parser = argparse.ArgumentParser(description="PyTorch implementation of action recognition models") 3 | parser.add_argument('--exp_dir', type=str, default="experiments/exp_test",help='Path to experiment output directory') 4 | parser.add_argument('--train_csv', type=str, default="",help='Path to experiment output directory') 5 | parser.add_argument('--optim', type=str, default="sgd",help='') 6 | parser.add_argument('--verification_csv', type=str, default="",help='Path to experiment output directory') 7 | parser.add_argument('--epochs', default=20, type=int, metavar='N', 8 | help='number of total epochs to run') 9 | parser.add_argument('--save_every_n_epochs', default=5, type=int, metavar='N', 10 | help='number of total epochs to run') 11 | parser.add_argument('--test_every_n_epochs', default=1, type=int, metavar='N', 12 | help='number of total epochs to run') 13 | parser.add_argument('--lora_rank', default=8, type=int, metavar='N', 14 | help='number of total epochs to run') 15 | parser.add_argument('--code_number', default=16, type=int, metavar='N', 16 | help='number of total epochs to run') 17 | parser.add_argument('--resume_main', default='', type=str, metavar='PATH', 18 | help='path to latest resume_main checkpoint (default: none)') 19 | parser.add_argument('--data_modality', default='image', type=str, 20 | help='path to latest resume_main 
checkpoint (default: none)') 21 | parser.add_argument('--backbone', default='clip_stadapt_imagenet', type=str, 22 | help='path to latest resume_main checkpoint (default: none)') 23 | parser.add_argument('--lr', '--learning-rate', default=0.0001, type=float, 24 | metavar='LR', help='initial learning rate') 25 | parser.add_argument('--wd', default=0.0001, type=float, 26 | metavar='LR', help='initial learning rate') 27 | parser.add_argument('--margin', default=0.4, type=float, 28 | metavar='LR', help='initial learning rate') 29 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 30 | help='evaluate model on validation set') 31 | parser.add_argument( '--only_eval_privacy', dest='only_eval_privacy', action='store_true', 32 | help='evaluate model on validation set') 33 | parser.add_argument('--batch_size', default=8, type=int, help='Batch size for training') 34 | parser.add_argument('--start_epoch', default=0, type=int, help='Batch size for training') 35 | parser.add_argument('--text_use_method', type=str, default="none",help='Path to experiment output directory') 36 | parser.add_argument('--relation_type', type=str, default="mlp",help='Path to experiment output directory') 37 | parser.add_argument('--mimic_all', action='store_true', help='Enable verbose mode') 38 | parser.add_argument('--no_codebook', action='store_true', help='Enable verbose mode') 39 | parser.add_argument('--ft_all', action='store_true', help='Enable verbose mode') 40 | parser.add_argument('--ft_backbone_lr_multi', default=1, type=float, 41 | metavar='LR', help='initial learning rate') 42 | parser.add_argument('--train_modality', type=str, default="all",help='') 43 | 44 | parser.add_argument('--online_compute_id', action='store_true', help='Enable verbose mode') 45 | parser.add_argument('--aug_color', action='store_true', help='Enable verbose mode') 46 | parser.add_argument('--aug_resize', action='store_true', help='Enable verbose mode') 47 | 48 | 
parser.add_argument('--no_randE', action='store_true', help='Enable verbose mode') 49 | parser.add_argument('--no_randCrop', action='store_true', help='Enable verbose mode') 50 | parser.add_argument('--clip_grad', default=0, type=float, 51 | metavar='LR', help='initial learning rate') 52 | 53 | parser.add_argument('--local-rank', type=int, default="0") 54 | 55 | -------------------------------------------------------------------------------- /st_adapter_DyKnow.py: -------------------------------------------------------------------------------- 1 | # modified from: https://github.com/openai/CLIP/blob/a9b1bf5920416aaeaec965c25dd9e8f98c864f16/clip/model.py 2 | 3 | from typing import Tuple 4 | from collections import OrderedDict 5 | import math 6 | import functools 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | CLIP_VIT_B16_PATH = "/root/autodl-tmp/patient_triple/code/open-metric-learning/my_code/pretrained_models/ViT-B-16.pt" 13 | DWCONV3D_DISABLE_CUDNN = True 14 | class Adapter(nn.Module): 15 | 16 | def __init__(self, in_channels, adapter_channels, kernel_size): 17 | super().__init__() 18 | self.fc1 = nn.Linear(in_channels, adapter_channels) 19 | self.conv = nn.Conv3d( 20 | adapter_channels, adapter_channels, 21 | kernel_size=kernel_size, 22 | stride=(1, 1, 1), 23 | padding=tuple(x // 2 for x in kernel_size), 24 | groups=adapter_channels, 25 | ) 26 | self.fc2 = nn.Linear(adapter_channels, in_channels) 27 | nn.init.constant_(self.conv.weight, 0.) 28 | nn.init.constant_(self.conv.bias, 0.) 29 | nn.init.constant_(self.fc1.bias, 0.) 30 | nn.init.constant_(self.fc2.bias, 0.) 
31 | 32 | def forward(self, x, T): 33 | BT, L, C = x.size() 34 | B = BT // T 35 | Ca = self.conv.in_channels 36 | H = W = round(math.sqrt(L - 1)) 37 | assert L - 1 == H * W 38 | x_id = x 39 | x = x[:, 1:, :] 40 | x = self.fc1(x) 41 | x = x.view(B, T, H, W, Ca).permute(0, 4, 1, 2, 3).contiguous() 42 | 43 | cudnn_enabled = torch.backends.cudnn.enabled 44 | torch.backends.cudnn.enabled = cudnn_enabled and DWCONV3D_DISABLE_CUDNN 45 | x = self.conv(x) 46 | torch.backends.cudnn.enabled = cudnn_enabled 47 | 48 | x = x.permute(0, 2, 3, 4, 1).contiguous().view(BT, L - 1, Ca) 49 | x = self.fc2(x) 50 | x_id[:, 1:, :] += x 51 | return x_id 52 | 53 | 54 | class LayerNorm(nn.LayerNorm): 55 | """Subclass torch's LayerNorm to handle fp16.""" 56 | 57 | def forward(self, x: torch.Tensor): 58 | orig_type = x.dtype 59 | ret = super().forward(x.type(torch.float32)) 60 | return ret.type(orig_type) 61 | 62 | 63 | class QuickGELU(nn.Module): 64 | def forward(self, x: torch.Tensor): 65 | return x * torch.sigmoid(1.702 * x) 66 | class GroupedLinear(nn.Module): 67 | def __init__(self, in_features, out_features, groups=1, bias=True): 68 | super(GroupedLinear, self).__init__() 69 | self.in_features = in_features 70 | self.out_features = out_features 71 | self.groups = groups 72 | self.bias = bias 73 | 74 | assert in_features % groups == 0, "in_features must be divisible by groups" 75 | assert out_features % groups == 0, "out_features must be divisible by groups" 76 | 77 | self.weight = nn.Parameter(torch.Tensor(groups, in_features // groups, out_features // groups)) 78 | if bias: 79 | self.bias = nn.Parameter(torch.Tensor(groups, out_features // groups)) 80 | else: 81 | self.register_parameter('bias', None) 82 | 83 | self.reset_parameters() 84 | 85 | def reset_parameters(self): 86 | nn.init.constant_(self.weight,0) 87 | nn.init.constant_(self.bias,0) 88 | # nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) 89 | # if self.bias is not None: 90 | # fan_in, _ = 
nn.init._calculate_fan_in_and_fan_out(self.weight) 91 | # bound = 1 / math.sqrt(fan_in) 92 | # nn.init.uniform_(self.bias, -bound, bound) 93 | 94 | def forward(self, x): 95 | B, C = x.size() 96 | assert C == self.in_features, "Input feature dimension does not match" 97 | 98 | x = x.view(B, self.groups, C // self.groups) 99 | x = torch.einsum('bgi,gio->bgo', x, self.weight) 100 | 101 | if self.bias is not None: 102 | x = x + self.bias 103 | 104 | x = x.reshape(B, self.out_features) 105 | return x 106 | 107 | class GroupedMLP(nn.Module): 108 | def __init__(self, in_features, out_features, groups=1, bias=True): 109 | super(GroupedMLP, self).__init__() 110 | self.l1 = GroupedLinear(in_features, in_features//4,1) 111 | self.l2 = GroupedLinear(in_features//4, out_features,groups) 112 | 113 | 114 | def forward(self, x): 115 | return self.l2 (F.leaky_relu(self.l1(x),0.1)) 116 | 117 | class GroupedMLP(nn.Module): 118 | def __init__(self, in_features, out_features, groups=1, bias=True): 119 | super(GroupedMLP, self).__init__() 120 | self.l11 = GroupedLinear(in_features, in_features//2,1) 121 | self.l21 = GroupedLinear(in_features//2, in_features//4,1) 122 | self.l31 = GroupedLinear(in_features//4, out_features,groups) 123 | 124 | 125 | def forward(self, x): 126 | res = x 127 | x = self.l21 (F.leaky_relu(self.l11(x),0.1,True)) 128 | x = F.leaky_relu(x,0.1,True) 129 | return self.l31(x) 130 | 131 | 132 | class ResidualAttentionBlock(nn.Module): 133 | def __init__(self, 134 | d_model: int, 135 | n_head: int, 136 | adapter_width: int, 137 | adapter_kernel_size: Tuple[int, int, int], 138 | adapter_pre_attn: bool, 139 | adapter_pre_mlp: bool, 140 | lora_rank , 141 | ft_all 142 | ) -> None: 143 | super().__init__() 144 | 145 | self.attn = nn.MultiheadAttention(d_model, n_head) 146 | self.ln_1 = LayerNorm(d_model) 147 | self.mlp = nn.Sequential(OrderedDict([ 148 | ("c_fc", nn.Linear(d_model, d_model * 4)), 149 | ("gelu", QuickGELU()), 150 | ("c_proj", nn.Linear(d_model * 4, 
d_model)) 151 | ])) 152 | self.ln_2 = LayerNorm(d_model) 153 | if not ft_all: 154 | for _ in self.attn.parameters(): 155 | _.requires_grad = False 156 | for _ in self.ln_1.parameters(): 157 | _.requires_grad = False 158 | for _ in self.mlp.parameters(): 159 | _.requires_grad = False 160 | for _ in self.ln_2.parameters(): 161 | _.requires_grad = False 162 | 163 | adapter_class = functools.partial( 164 | Adapter, 165 | in_channels=d_model, 166 | adapter_channels=adapter_width, 167 | kernel_size=adapter_kernel_size, 168 | ) 169 | self.adapter_pre_attn = \ 170 | adapter_class() if adapter_pre_attn else None 171 | self.adapter_pre_mlp = \ 172 | adapter_class() if adapter_pre_mlp else None 173 | self.d_model = d_model 174 | self.lora_rank = lora_rank 175 | if lora_rank>0: 176 | self.lora_rank_dynamic = lora_rank_dynamic = lora_rank 177 | G = 16 178 | self.lora_mlp_q_a = GroupedMLP(d_model, lora_rank_dynamic * d_model,groups=G) 179 | self.lora_mlp_q_b = GroupedMLP(d_model, lora_rank_dynamic * d_model,groups=G) 180 | self.lora_mlp_k_a = GroupedMLP(d_model, lora_rank_dynamic * d_model,groups=G) 181 | self.lora_mlp_k_b = GroupedMLP(d_model, lora_rank_dynamic * d_model,groups=G) 182 | self.lora_mlp_v_a = GroupedMLP(d_model, lora_rank_dynamic * d_model,groups=G) 183 | self.lora_mlp_v_b = GroupedMLP(d_model, lora_rank_dynamic * d_model,groups=G) 184 | 185 | self.linear_a_q_static = nn.Parameter(torch.zeros(1, self.lora_rank, self.d_model)) 186 | self.linear_b_q_static = nn.Parameter(torch.zeros(1, self.lora_rank, self.d_model)) 187 | self.linear_a_k_static = nn.Parameter(torch.zeros(1, self.lora_rank, self.d_model)) 188 | self.linear_b_k_static = nn.Parameter(torch.zeros(1, self.lora_rank, self.d_model)) 189 | self.linear_a_v_static = nn.Parameter(torch.zeros(1, self.lora_rank, self.d_model)) 190 | self.linear_b_v_static = nn.Parameter(torch.zeros(1, self.lora_rank, self.d_model)) 191 | # if ft_all: 192 | # self.linear_a_q_static.requires_grad = False 193 | # 
self.linear_b_q_static.requires_grad = False 194 | # self.linear_a_k_static.requires_grad = False 195 | # self.linear_b_k_static.requires_grad = False 196 | # self.linear_a_v_static.requires_grad = False 197 | # self.linear_b_v_static.requires_grad = False 198 | # LORA parameter generation MLPs -for FFN 199 | G = 32 200 | self.lora_mlp_c_fc_a = GroupedMLP(d_model, lora_rank_dynamic * (d_model * 4),groups=G) 201 | self.lora_mlp_c_fc_b = GroupedMLP(d_model, lora_rank_dynamic * d_model,groups=G) 202 | self.lora_mlp_c_proj_a = GroupedMLP(d_model, lora_rank_dynamic * d_model,groups=G) 203 | self.lora_mlp_c_proj_b = GroupedMLP(d_model, lora_rank_dynamic * (d_model * 4),groups=G) 204 | 205 | self.linear_a_c_fc_static = nn.Parameter(torch.zeros(1, self.lora_rank, self.d_model * 4)) 206 | self.linear_b_c_fc_static = nn.Parameter(torch.zeros(1, self.lora_rank, self.d_model)) 207 | self.linear_a_c_proj_static = nn.Parameter(torch.zeros(1, self.lora_rank, self.d_model)) 208 | self.linear_b_c_proj_static = nn.Parameter(torch.zeros(1, self.lora_rank, self.d_model * 4)) 209 | # if ft_all: 210 | # self.linear_a_c_fc_static.requires_grad = False 211 | # self.linear_b_c_fc_static.requires_grad = False 212 | # self.linear_a_c_proj_static.requires_grad = False 213 | # self.linear_b_c_proj_static.requires_grad = False 214 | 215 | 216 | def attention(self, x: torch.Tensor,x_route_context) -> torch.Tensor: 217 | 218 | B, L, C = x.size() 219 | H = self.attn.num_heads 220 | 221 | if self.lora_rank>0: 222 | # B C x_route_context 223 | # Generate LORA parameters 224 | linear_a_q = self.lora_mlp_q_a(x_route_context).view(B, self.lora_rank_dynamic, C)+self.linear_a_q_static 225 | linear_b_q = self.lora_mlp_q_b(x_route_context).view(B, self.lora_rank_dynamic, C)+self.linear_b_q_static 226 | linear_a_k = self.lora_mlp_k_a(x_route_context).view(B, self.lora_rank_dynamic, C)+self.linear_a_k_static 227 | linear_b_k = self.lora_mlp_k_b(x_route_context).view(B, self.lora_rank_dynamic, 
C)+self.linear_b_k_static 228 | linear_a_v = self.lora_mlp_v_a(x_route_context).view(B, self.lora_rank_dynamic, C)+self.linear_a_v_static 229 | linear_b_v = self.lora_mlp_v_b(x_route_context).view(B, self.lora_rank_dynamic, C)+self.linear_b_v_static 230 | 231 | 232 | new_q = torch.einsum('bkc,blk->blc', linear_b_q, torch.einsum('bri,bti->btr', linear_a_q, x)) 233 | ## B K C * B L K 234 | new_k = torch.einsum('bkc,blk->blc', linear_b_k, torch.einsum('bri,bti->btr', linear_a_k, x)) 235 | new_v = torch.einsum('bkc,blk->blc', linear_b_v, torch.einsum('bri,bti->btr', linear_a_v, x)) 236 | ## 237 | 238 | qkv = F.linear(x, weight=self.attn.in_proj_weight, bias=self.attn.in_proj_bias) 239 | if self.lora_rank>0: 240 | qkv[:, :, :self.d_model] += new_q 241 | qkv[:, :, self.d_model:-self.d_model] += new_k 242 | qkv[:, :, -self.d_model:] += new_v 243 | qkv = qkv.view(B, L, H * 3, -1).permute(0, 2, 1, 3) 244 | q, k, v = qkv.split([H, H, H], dim=1) 245 | out = F.scaled_dot_product_attention(q, k, v) 246 | out = out.permute(0, 2, 1, 3).flatten(-2) 247 | out = self.attn.out_proj(out) 248 | 249 | return out 250 | 251 | def mlp_wrapper(self, x: torch.Tensor,x_route_context) -> torch.Tensor: 252 | if self.lora_rank>0: 253 | B, L, C = x.size() 254 | linear_a_c_fc = self.lora_mlp_c_fc_a(x_route_context).view(B, self.lora_rank, C * 4)+self.linear_a_c_fc_static 255 | linear_b_c_fc = self.lora_mlp_c_fc_b(x_route_context).view(B, self.lora_rank, C)+self.linear_b_c_fc_static 256 | # linear_a_c_proj = self.lora_mlp_c_proj_a(x_route_context).view(B, self.lora_rank, C)+self.linear_a_c_proj_static 257 | # linear_b_c_proj = self.lora_mlp_c_proj_b(x_route_context).view(B, self.lora_rank, C * 4)+self.linear_b_c_proj_static 258 | 259 | c_fc_out = F.linear(x, weight=self.mlp.c_fc.weight, bias=self.mlp.c_fc.bias) 260 | 261 | if self.lora_rank>0: 262 | new_c_fc = torch.einsum('bkc,blk->blc', linear_a_c_fc, torch.einsum('bri,bti->btr', linear_b_c_fc, x)) 263 | ### B K C * B K L 264 | c_fc_out += 
new_c_fc 265 | 266 | c_fc_out = self.mlp.gelu(c_fc_out) 267 | 268 | # Compute c_proj with LORA 269 | c_proj_out = F.linear(c_fc_out, weight=self.mlp.c_proj.weight, bias=self.mlp.c_proj.bias) 270 | # if self.lora_rank>0: 271 | # new_c_proj = torch.einsum('bkc,blk->blc', linear_a_c_proj, torch.einsum('bri,bti->btr', linear_b_c_proj, c_fc_out)) 272 | # c_proj_out += new_c_proj 273 | 274 | return c_proj_out 275 | 276 | 277 | def forward(self, 278 | x: torch.Tensor, 279 | num_frames: int, 280 | x_route_context 281 | ) -> torch.Tensor: 282 | if self.adapter_pre_attn is not None: 283 | x = self.adapter_pre_attn(x, num_frames) 284 | x = x + self.attention(self.ln_1(x),x_route_context) 285 | if self.adapter_pre_mlp is not None: 286 | x = self.adapter_pre_mlp(x, num_frames) 287 | x = x + self.mlp_wrapper(self.ln_2(x),x_route_context) 288 | return x 289 | 290 | 291 | class Transformer(nn.Module): 292 | def __init__(self, 293 | width: int, 294 | layers: int, 295 | heads: int, 296 | adapter_width: int, 297 | adapter_layers: int, 298 | adapter_kernel_size: Tuple[int, int, int], 299 | adapter_pre_attn: bool, 300 | adapter_pre_mlp: bool, 301 | lora_rank, 302 | ft_all 303 | ): 304 | super().__init__() 305 | self.width = width 306 | self.layers = layers 307 | self.resblocks = nn.ModuleList([ 308 | ResidualAttentionBlock( 309 | d_model=width, 310 | n_head=heads, 311 | adapter_width=adapter_width, 312 | adapter_kernel_size=adapter_kernel_size, 313 | adapter_pre_attn=adapter_pre_attn and i >= layers - adapter_layers, 314 | adapter_pre_mlp=adapter_pre_mlp and i >= layers - adapter_layers, 315 | lora_rank=lora_rank, 316 | ft_all = ft_all 317 | ) 318 | for i in range(layers) 319 | ]) 320 | 321 | def forward(self, x: torch.Tensor, num_frames: int,x_route_context) -> torch.Tensor: 322 | for block in self.resblocks: 323 | x = block(x, num_frames,x_route_context) 324 | return x 325 | 326 | 327 | class VisionTransformer(nn.Module): 328 | def __init__(self, 329 | input_resolution: int, 330 | 
patch_size: int, 331 | width: int, 332 | layers: int, 333 | heads: int, 334 | num_classes: int, 335 | adapter_width: int, 336 | adapter_layers: int, 337 | adapter_kernel_size: Tuple[int, int, int], 338 | adapter_pre_attn: bool, 339 | adapter_pre_mlp: bool, 340 | lora_rank, 341 | args 342 | ): 343 | super().__init__() 344 | self.input_resolution = input_resolution 345 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, 346 | kernel_size=patch_size, stride=patch_size, bias=False) 347 | self.router_spatial_transform = nn.Sequential( 348 | nn.Conv2d(in_channels=width, out_channels=width, 349 | kernel_size=1, stride=1,padding=0, bias=True), 350 | nn.LeakyReLU(0.1,True), 351 | nn.Conv2d(in_channels=width, out_channels=width, 352 | kernel_size=1, stride=1,padding=0, bias=True), 353 | ) 354 | self.router_group = args.code_number 355 | self.knowledge_group = 1 356 | self.router_MLP = nn.Sequential( 357 | nn.Linear(width,width), 358 | nn.LeakyReLU(0.1,True), 359 | nn.Linear(width,self.knowledge_group*self.router_group ), 360 | ) 361 | self.ft_comp_MLP = nn.Sequential( 362 | nn.LeakyReLU(0.1,False), 363 | nn.Linear(width,width), 364 | nn.LeakyReLU(0.1,False), 365 | nn.Linear(width,width), 366 | ) 367 | nn.init.constant_(self.ft_comp_MLP[3].weight,0) 368 | nn.init.constant_(self.ft_comp_MLP[3].bias,0) 369 | self.knowledge_pool = nn.Parameter(torch.randn(self.knowledge_group,self.router_group, 768//self.router_group)) 370 | 371 | scale = width ** -0.5 372 | self.class_embedding = nn.Parameter(scale * torch.randn(width)) 373 | self.positional_embedding = nn.Parameter( 374 | scale * torch.randn( 375 | (input_resolution // patch_size) ** 2 + 1, width 376 | ) 377 | ) 378 | self.ln_pre = LayerNorm(width) 379 | 380 | self.transformer = Transformer(width, layers, heads, 381 | adapter_width, adapter_layers, adapter_kernel_size, 382 | adapter_pre_attn, adapter_pre_mlp,lora_rank =lora_rank,ft_all =args.ft_all) 383 | 384 | self.ln_post = LayerNorm(width) 385 | self.lora_rank = 
lora_rank 386 | # for n, p in self.named_parameters(): 387 | # if 'adapter' not in n: 388 | # p.requires_grad_(False) 389 | 390 | self.dropout = nn.Dropout(0.5) 391 | self.fc = nn.Linear(width, num_classes) 392 | nn.init.normal_(self.fc.weight, std=0.02) 393 | nn.init.constant_(self.fc.bias, 0.) 394 | self.last_mlp = nn.Sequential( 395 | nn.Linear(768,768), 396 | nn.LeakyReLU(0.1), 397 | nn.Linear(768,768), 398 | ) 399 | self.args = args 400 | 401 | def forward(self, x: torch.Tensor,return_mode = "default"): 402 | B, T = x.size(0), x.size(2) 403 | x = x.permute(0, 2, 1, 3, 4).flatten(0, 1) 404 | x = self.conv1(x) # shape = [*, width, grid, grid] 405 | x_for_route = self.router_spatial_transform(x) 406 | x_for_route_root =x_for_route.mean(-1).mean(-1) 407 | x_for_route = self.router_MLP(x_for_route_root) ##B K 408 | x_for_ft_comp = self.ft_comp_MLP(x_for_route_root) 409 | x_for_route = x_for_route.reshape(B*T,self.knowledge_group,self.router_group ) ## B, KG, RG 410 | x_for_route = F.softmax(x_for_route,dim=-1) ## B K 411 | 412 | #### 1 K C 413 | x_route_context = self.knowledge_pool.unsqueeze(0) * x_for_route.unsqueeze(-1)## B, KG, RG, C//RG 414 | x_route_context = x_route_context.sum(1) ## B RG, C//RG 415 | x_route_context = x_route_context.reshape(B*T, 768) 416 | if self.args.no_codebook: 417 | x_route_context = x_for_route_root 418 | if return_mode == "modality_token": 419 | return x_route_context.reshape(B,T,768).mean(1) 420 | spatial_size = tuple(x.size()[2:]) 421 | x = x.flatten(-2).permute(0, 2, 1)##B N C 422 | if self.lora_rank>0: 423 | x = x + x_for_ft_comp.unsqueeze(1) 424 | x = torch.cat([ 425 | self.class_embedding.view(1, 1, -1).expand(x.shape[0], -1, -1), x 426 | ], dim=1) # [*, grid ** 2 + 1, width] 427 | x = x + self.positional_embedding.to(x.dtype) 428 | x = self.ln_pre(x) 429 | 430 | x = x.view(B, T, x.size(1), x.size(2)).flatten(0, 1) # BT, L, D 431 | 432 | x = self.transformer(x, T,x_route_context) 433 | 434 | x = x.contiguous().view(B, T, 
spatial_size[0] * spatial_size[1] + 1, x.size(-1)) 435 | # x_global = x[:, :, 0, :].mean(dim=1) 436 | x_local = x[:, :, 1:, :].mean(dim=1) 437 | x_global = x_local.mean(1) 438 | H,W = 14,14 439 | B,S,C = x_local.size() 440 | assert S == H*W 441 | x_local =x_local.reshape(B,H,W,C).permute(0,3,1,2) 442 | # x_global = self.ln_post(x_global) 443 | return self.last_mlp(x_global), x_local 444 | 445 | 446 | 447 | def clip_vit_base_patch16_adapter24x384(**kwargs): 448 | model = VisionTransformer( 449 | input_resolution=224, 450 | patch_size=16, 451 | width=768, 452 | layers=12, 453 | heads=12, 454 | adapter_width=384, 455 | adapter_layers=12, 456 | adapter_kernel_size=(3, 1, 1), 457 | adapter_pre_attn=True, 458 | adapter_pre_mlp=False, 459 | **kwargs, 460 | ) 461 | assert CLIP_VIT_B16_PATH is not None, \ 462 | 'Please set CLIP_VIT_B16_PATH in configs.py.' 463 | checkpoint = torch.jit.load(CLIP_VIT_B16_PATH, map_location='cpu') 464 | model.load_state_dict(checkpoint.visual.state_dict(), strict=False) 465 | return model 466 | 467 | def clip_vit_base_patch16(**kwargs): 468 | model = VisionTransformer( 469 | input_resolution=224, 470 | patch_size=16, 471 | width=768, 472 | layers=12, 473 | heads=12, 474 | adapter_width=384, 475 | adapter_layers=12, 476 | adapter_kernel_size=(3, 1, 1), 477 | adapter_pre_attn=False, 478 | adapter_pre_mlp=False, 479 | **kwargs, 480 | ) 481 | assert CLIP_VIT_B16_PATH is not None, \ 482 | 'Please set CLIP_VIT_B16_PATH in configs.py.' 
483 | checkpoint = torch.jit.load(CLIP_VIT_B16_PATH, map_location='cpu') 484 | model.load_state_dict(checkpoint.visual.state_dict(), strict=False) 485 | print(model) 486 | return model 487 | 488 | def clip_vit_base_patch16_adapter12x384(**kwargs): 489 | model = VisionTransformer( 490 | input_resolution=224, 491 | patch_size=16, 492 | width=768, 493 | layers=12, 494 | heads=12, 495 | adapter_width=384, 496 | adapter_layers=12, 497 | adapter_kernel_size=(3, 1, 1), 498 | adapter_pre_attn=False, 499 | adapter_pre_mlp=False, 500 | **kwargs, 501 | ) 502 | assert CLIP_VIT_B16_PATH is not None, \ 503 | 'Please set CLIP_VIT_B16_PATH in configs.py' 504 | checkpoint = torch.jit.load(CLIP_VIT_B16_PATH, map_location='cpu') 505 | print(model.load_state_dict(checkpoint.visual.state_dict(), strict=False)) 506 | return model --------------------------------------------------------------------------------