├── Downstreams ├── CMPM-C │ ├── dataloader.py │ ├── dataset_split.py │ ├── loss.py │ ├── model.py │ ├── optimizer.py │ ├── test.py │ ├── train.py │ └── utils.py └── readme.txt ├── LICENSE ├── PLIPmodel.py ├── README.md ├── assets ├── SYNTH-PEDES.png ├── abstract.png └── examples.png ├── checkpoints └── readme.txt ├── data └── readme.txt ├── dataset_split.py ├── requirements.txt ├── test_dataloader.py ├── textual_model.py ├── utils.py ├── visual_model.py └── zs_infer.py /Downstreams/CMPM-C/dataloader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data as data 3 | import os 4 | import json 5 | from transformers import BertTokenizer 6 | import random 7 | import copy 8 | from torch.utils.data import DataLoader 9 | from PIL import Image 10 | from prefetch_generator import BackgroundGenerator 11 | import numpy as np 12 | 13 | class DataLoaderX(DataLoader): 14 | def __iter__(self): 15 | return BackgroundGenerator(super().__iter__()) 16 | 17 | class Dataset(data.Dataset): 18 | def __init__(self, image_path, dataset_path, transform=None ,flip_transform=None): 19 | assert transform is not None, 'transform must not be None' 20 | self.impath = image_path 21 | self.datapath = dataset_path 22 | with open(dataset_path, 'r', encoding='utf8') as fp: 23 | self.dataset = json.load(fp) 24 | self.transform = transform 25 | self.flip_transform = flip_transform 26 | self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') 27 | 28 | def caption_to_tokens(self, caption): 29 | result = self.tokenizer(caption, padding="max_length", max_length=64, truncation=True, return_tensors='pt') 30 | token, mask = result["input_ids"], result["attention_mask"] 31 | token, mask = token.squeeze(), mask.squeeze() 32 | return token,mask 33 | 34 | def __getitem__(self, index): 35 | caption = self.dataset[index]["captions"][0] 36 | label = self.dataset[index]["id"] 37 | file_path = self.dataset[index]["file_path"] 38 | image = Image.open(os.path.join(self.impath, file_path)).convert('RGB') 39 | image_gt = self.transform(image) 40 | tokens,masks = self.caption_to_tokens(caption) 41 | tokens = torch.tensor(tokens) 42 | label = torch.tensor(label) 43 | if self.flip_transform == None: 44 | return image_gt, tokens, masks, label 45 | else: 46 | return image_gt, self.flip_transform(image), tokens, masks, label 47 | 48 | def __len__(self): 49 | return len(self.dataset) 50 | 51 | class Dataset_test_image(data.Dataset): 52 | def __init__(self, image_path, dataset_path, transform=None): 53 | assert transform is not None, 'transform must not be None' 54 | self.impath = image_path 55 | self.datapath = dataset_path 56 | with open(dataset_path, 'r', encoding='utf8') as fp: 57 | self.dataset = json.load(fp) 58 | self.transform = transform 59 | print("Information about image gallery:{}".format(len(self))) 60 | 61 | def __getitem__(self, index): 62 | label = self.dataset[index]["id"] 63 | file_path = self.dataset[index]["file_path"] 64 | image = Image.open(os.path.join(self.impath, file_path)).convert('RGB') 65 | image_gt = self.transform(image) 66 | label = torch.tensor(label) 67 | return label,image_gt 68 | 69 | def __len__(self): 70 | return len(self.dataset) 71 | 72 | class Dataset_test_text(data.Dataset): 73 | def __init__(self, image_path, dataset_path): 74 | self.impath = image_path 75 | self.datapath = dataset_path 76 | with open(dataset_path, 'r', encoding='utf8') as fp: 77 | self.dataset = json.load(fp) 78 | self.tokenizer = 
BertTokenizer.from_pretrained('bert-base-uncased') 79 | self.initial_data = [] 80 | self.caption_depart_initial() 81 | print("Information about text query:{}".format(len(self))) 82 | 83 | def __len__(self): 84 | return len(self.initial_data) 85 | 86 | def caption_to_tokens(self, caption): 87 | result = self.tokenizer(caption, padding="max_length", max_length=64, truncation=True, return_tensors='pt') 88 | token, mask = result["input_ids"], result["attention_mask"] 89 | token, mask = token.squeeze(), mask.squeeze() 90 | return token, mask 91 | 92 | def caption_depart_initial(self): 93 | for i in range(len(self.dataset)): 94 | item = self.dataset[i] 95 | label = item["id"] 96 | captions_list = item["captions"] 97 | for j in range(len(captions_list)): 98 | caption = captions_list[j] 99 | self.initial_data.append([label,caption]) 100 | 101 | def __getitem__(self, index): 102 | caption = self.initial_data[index][1] 103 | label = self.initial_data[index][0] 104 | caption_tokens,masks = self.caption_to_tokens(caption) 105 | caption_tokens = torch.tensor(caption_tokens) 106 | label = torch.tensor(label) 107 | return label,caption_tokens,masks 108 | 109 | 110 | def get_loader(image_path, dataset_path,transform,flip_transform, batch_size, num_workers,distributed=False): 111 | dataset = Dataset(image_path=image_path,dataset_path=dataset_path,transform=transform,flip_transform=flip_transform) 112 | if distributed == False: 113 | dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers) 114 | else: 115 | train_sampler = torch.utils.data.distributed.DistributedSampler(dataset,shuffle=True) 116 | dataloader = DataLoaderX(dataset,batch_size=batch_size,num_workers=num_workers,pin_memory=True,sampler=train_sampler,shuffle=False) 117 | return dataloader 118 | 119 | def get_loader_test(image_path, dataset_path,transform, batch_size, num_workers): 120 | image_dataset = Dataset_test_image(image_path=image_path,dataset_path=dataset_path,transform=transform) 121 | image_dataloader = DataLoader(image_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers) 122 | text_dataset = Dataset_test_text(image_path=image_path,dataset_path=dataset_path) 123 | text_dataloader = DataLoader(text_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers) 124 | return image_dataloader,text_dataloader -------------------------------------------------------------------------------- /Downstreams/CMPM-C/dataset_split.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | def cuhk_train_depart(train_path,train_path_depart): 4 | with open(train_path) as f: 5 | dataset = json.load(f) 6 | output = [] 7 | for i in range(len(dataset)): 8 | data = dataset[i] 9 | captions = data["captions"] 10 | if len(captions)!=2: 11 | print("{}:{}".format(data["id"],captions)) 12 | for j in range(len(captions)): 13 | dict = {} 14 | dict["split"] = data["split"] 15 | dict["id"] = data["id"] 16 | dict["file_path"] = data["file_path"] 17 | dict["captions"] = [data["captions"][j]] 18 | output.append(dict) 19 | with open(train_path_depart,"w") as f: 20 | json.dump(output,f,indent=4) 21 | print("completed!") 22 | 23 | def TrainValidTest_split(path,train_path,test_path,valid_path): 24 | with open(path,"r") as f: 25 | dataset = json.load(f) 26 | train_output=[] 27 | test_output=[] 28 | valid_output=[] 29 | for i in range(len(dataset)): 30 | data = dataset[i] 31 | split = data["split"] 32 | if split == "train": 33 | train_output.append(data) 34 | 
elif split =="test": 35 | test_output.append(data) 36 | else: 37 | valid_output.append(data) 38 | if (i+1) % 100 == 0: 39 | print("{}/{} completed".format(i+1,len(dataset))) 40 | print("The train_set capacity:{}".format(len(train_output))) 41 | print("The test_set capacity:{}".format(len(test_output))) 42 | print("The valid_set capacity:{}".format(len(valid_output))) 43 | with open(train_path,"w") as f : 44 | json.dump(train_output,f,indent=4) 45 | with open(test_path,"w") as f : 46 | json.dump(test_output,f,indent=4) 47 | with open(valid_path,"w") as f : 48 | json.dump(valid_output,f,indent=4) 49 | 50 | if __name__ =="__main__": 51 | train_path = "data/CUHK-PEDES/CUHK-PEDES-train.json" 52 | train_path_depart = "data/CUHK-PEDES/CUHK-PEDES-train-depart.json" 53 | test_path = "data/CUHK-PEDES/CUHK-PEDES-test.json" 54 | valid_path = "data/CUHK-PEDES/CUHK-PEDES-valid.json" 55 | dataset_path = "data/CUHK-PEDES/reid_raw.json" 56 | TrainValidTest_split(dataset_path, train_path, test_path,valid_path) 57 | cuhk_train_depart(train_path,train_path_depart) 58 | 59 | train_path = "data/ICFG-PEDES/ICFG-PEDES-train.json" 60 | test_path = "data/ICFG-PEDES/ICFG-PEDES-test.json" 61 | valid_path = "data/ICFG-PEDES/ICFG-PEDES-valid.json" 62 | dataset_path = "data/ICFG-PEDES/ICFG_PEDES.json" 63 | TrainValidTest_split(dataset_path, train_path, test_path, valid_path) -------------------------------------------------------------------------------- /Downstreams/CMPM-C/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn.parameter import Parameter 5 | 6 | class Loss_calc(nn.Module): 7 | def __init__(self, args): 8 | super(Loss_calc, self).__init__() 9 | self.epsilon = args.epsilon 10 | self.W =Parameter(torch.randn(args.feature_size, args.num_classes)) 11 | self.init_weight() 12 | def init_weight(self): 13 | nn.init.xavier_uniform_(self.W.data, gain=1) 14 | 15 | def compute_cmpm_loss(self, image_embeddings, text_embeddings, labels): 16 | """ 17 | Cross-Modal Projection Matching Loss(CMPM) 18 | :param image_embeddings: Tensor with dtype torch.float32 19 | :param text_embeddings: Tensor with dtype torch.float32 20 | :param labels: Tensor with dtype torch.int32 21 | :return: 22 | i2t_loss: cmpm loss for image projected to text 23 | t2i_loss: cmpm loss for text projected to image 24 | pos_avg_sim: average cosine-similarity for positive pairs 25 | neg_avg_sim: averate cosine-similarity for negative pairs 26 | """ 27 | 28 | batch_size = image_embeddings.shape[0] 29 | labels_reshape = torch.reshape(labels, (batch_size, 1)) 30 | labels_dist = labels_reshape - labels_reshape.t() 31 | labels_mask = (labels_dist == 0) 32 | 33 | image_norm = image_embeddings / image_embeddings.norm(dim=1, keepdim=True) 34 | text_norm = text_embeddings / text_embeddings.norm(dim=1, keepdim=True) 35 | image_proj_text = torch.matmul(image_embeddings, text_norm.t()) 36 | text_proj_image = torch.matmul(text_embeddings, image_norm.t()) 37 | 38 | # normalize the true matching distribution 39 | labels_mask_norm = labels_mask.float() / labels_mask.float().norm(dim=1) 40 | 41 | i2t_pred = F.softmax(image_proj_text, dim=1) 42 | i2t_loss = i2t_pred * (F.log_softmax(image_proj_text, dim=1) - torch.log(labels_mask_norm + self.epsilon)) 43 | t2i_pred = F.softmax(text_proj_image, dim=1) 44 | t2i_loss = t2i_pred * (F.log_softmax(text_proj_image, dim=1) - torch.log(labels_mask_norm + self.epsilon)) 45 | 46 | cmpm_loss = 
torch.mean(torch.sum(i2t_loss, dim=1)) + torch.mean(torch.sum(t2i_loss, dim=1)) 47 | 48 | return cmpm_loss 49 | 50 | def compute_cmpc_loss(self, image_embeddings, text_embeddings, labels): 51 | """ 52 | Cross-Modal Projection Classfication loss(CMPC) 53 | :param image_embeddings: Tensor with dtype torch.float32 54 | :param text_embeddings: Tensor with dtype torch.float32 55 | :param labels: Tensor with dtype torch.int32 56 | :return: 57 | """ 58 | criterion = nn.CrossEntropyLoss(reduction='mean') 59 | self.W_norm = self.W / self.W.norm(dim=0) 60 | image_norm = image_embeddings / image_embeddings.norm(dim=1, keepdim=True) 61 | text_norm = text_embeddings / text_embeddings.norm(dim=1, keepdim=True) 62 | image_proj_text = torch.sum(image_embeddings * text_norm, dim=1, keepdim=True) * text_norm 63 | text_proj_image = torch.sum(text_embeddings * image_norm, dim=1, keepdim=True) * image_norm 64 | image_logits = torch.matmul(image_proj_text, self.W_norm) 65 | text_logits = torch.matmul(text_proj_image, self.W_norm) 66 | cmpc_loss = criterion(image_logits, labels) + criterion(text_logits, labels) 67 | image_pred = torch.argmax(image_logits, dim=1) 68 | text_pred = torch.argmax(text_logits, dim=1) 69 | image_precision = torch.mean((image_pred == labels).float()) 70 | text_precision = torch.mean((text_pred == labels).float()) 71 | return cmpc_loss, image_pred, text_pred, image_precision, text_precision 72 | 73 | def forward(self, global_visual_embed, global_textual_embed,IDlabels): 74 | cmpm_loss = self.compute_cmpm_loss(global_visual_embed,global_textual_embed,IDlabels) 75 | cmpc_loss,image_pred,text_pred,image_precision, text_precision = self.compute_cmpc_loss(global_visual_embed, global_textual_embed,IDlabels) 76 | loss = cmpm_loss + cmpc_loss 77 | 78 | return cmpm_loss, cmpc_loss, loss, image_precision, text_precision -------------------------------------------------------------------------------- /Downstreams/CMPM-C/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from torch.nn import init 5 | import torchvision.models as models 6 | from collections import OrderedDict 7 | from transformers import BertModel 8 | 9 | def weights_init_kaiming(m): 10 | classname = m.__class__.__name__ 11 | if classname.find('Conv2d') != -1: 12 | init.kaiming_normal_(m.weight.data, mode='fan_out', nonlinearity='relu') 13 | elif classname.find('Linear') != -1: 14 | init.kaiming_normal(m.weight.data, a=0, mode='fan_out') 15 | init.constant_(m.bias.data, 0.0) 16 | elif classname.find('BatchNorm1d') != -1: 17 | init.normal(m.weight.data, 1.0, 0.02) 18 | init.constant_(m.bias.data, 0.0) 19 | elif classname.find('BatchNorm2d') != -1: 20 | init.constant_(m.weight.data, 1) 21 | init.constant_(m.bias.data, 0) 22 | 23 | class conv(nn.Module): 24 | 25 | def __init__(self, input_dim, output_dim, relu=False, BN=False): 26 | super(conv, self).__init__() 27 | 28 | block = [] 29 | block += [nn.Conv2d(input_dim, output_dim, kernel_size=1, bias=False)] 30 | 31 | if BN: 32 | block += [nn.BatchNorm2d(output_dim)] 33 | if relu: 34 | block += [nn.LeakyReLU(0.25, inplace=True)] 35 | 36 | self.block = nn.Sequential(*block) 37 | self.block.apply(weights_init_kaiming) 38 | 39 | def forward(self, x): 40 | x = self.block(x) 41 | x = x.squeeze(3).squeeze(2) 42 | return x 43 | 44 | class Bottleneck(nn.Module): 45 | expansion = 4 46 | def __init__(self, inplanes, planes, stride=1): 47 | super().__init__() 48 | # all conv 
layers have stride 1. an avgpool is performed after the second convolution when stride > 1 49 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 50 | self.bn1 = nn.BatchNorm2d(planes) 51 | self.relu1 = nn.ReLU(inplace=True) 52 | 53 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) 54 | self.bn2 = nn.BatchNorm2d(planes) 55 | self.relu2 = nn.ReLU(inplace=True) 56 | 57 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() 58 | 59 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 60 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 61 | self.relu3 = nn.ReLU(inplace=True) 62 | 63 | self.downsample = None 64 | self.stride = stride 65 | 66 | if stride > 1 or inplanes != planes * Bottleneck.expansion: 67 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 68 | self.downsample = nn.Sequential(OrderedDict([ 69 | ("-1", nn.AvgPool2d(stride)), 70 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), 71 | ("1", nn.BatchNorm2d(planes * self.expansion)) 72 | ])) 73 | 74 | def forward(self, x: torch.Tensor): 75 | identity = x 76 | 77 | out = self.relu1(self.bn1(self.conv1(x))) 78 | out = self.relu2(self.bn2(self.conv2(out))) 79 | out = self.avgpool(out) 80 | out = self.bn3(self.conv3(out)) 81 | 82 | if self.downsample is not None: 83 | identity = self.downsample(x) 84 | 85 | out += identity 86 | out = self.relu3(out) 87 | return out 88 | 89 | class AttentionPool2d(nn.Module): 90 | def __init__(self, spacial_dim_x: int,spacial_dim_y: int, embed_dim: int, num_heads: int, output_dim: int = None): 91 | super().__init__() 92 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim_x * spacial_dim_y + 1, embed_dim) / embed_dim ** 0.5) 93 | self.k_proj = nn.Linear(embed_dim, embed_dim) 94 | self.q_proj = nn.Linear(embed_dim, embed_dim) 95 | self.v_proj = nn.Linear(embed_dim, embed_dim) 96 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 97 | self.num_heads = num_heads 98 | 99 | def forward(self, x): 100 | x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC 101 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 102 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 103 | x, _ = F.multi_head_attention_forward( 104 | query=x[:1], key=x, value=x, 105 | embed_dim_to_check=x.shape[-1], 106 | num_heads=self.num_heads, 107 | q_proj_weight=self.q_proj.weight, 108 | k_proj_weight=self.k_proj.weight, 109 | v_proj_weight=self.v_proj.weight, 110 | in_proj_weight=None, 111 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 112 | bias_k=None, 113 | bias_v=None, 114 | add_zero_attn=False, 115 | dropout_p=0, 116 | out_proj_weight=self.c_proj.weight, 117 | out_proj_bias=self.c_proj.bias, 118 | use_separate_proj_weight=True, 119 | training=self.training, 120 | need_weights=False 121 | ) 122 | return x.squeeze(0) 123 | 124 | class Image_encoder_ModifiedResNet(nn.Module): 125 | """ 126 | A ResNet class that is similar to torchvision's but contains the following changes: 127 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 
128 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 129 | - The final pooling layer is a QKV attention instead of an average pool 130 | """ 131 | def __init__(self, layers, output_dim, heads, input_resolution=[256,128], width=64): 132 | super().__init__() 133 | self.output_dim = output_dim 134 | self.input_resolution = input_resolution 135 | 136 | # the 3-layer stem 137 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) 138 | self.bn1 = nn.BatchNorm2d(width // 2) 139 | self.relu1 = nn.ReLU(inplace=True) 140 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) 141 | self.bn2 = nn.BatchNorm2d(width // 2) 142 | self.relu2 = nn.ReLU(inplace=True) 143 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) 144 | self.bn3 = nn.BatchNorm2d(width) 145 | self.relu3 = nn.ReLU(inplace=True) 146 | self.avgpool = nn.AvgPool2d(2) 147 | 148 | # residual layers 149 | self._inplanes = width # this is a *mutable* variable used during construction 150 | self.layer1 = self._make_layer(width, layers[0]) 151 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2) 152 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2) 153 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2) 154 | 155 | embed_dim = width * 32 # the ResNet feature dimension 156 | self.attnpool = AttentionPool2d(input_resolution[0] // 32,input_resolution[1] // 32, embed_dim, heads, output_dim) 157 | self.initialize_parameters() 158 | def _make_layer(self, planes, blocks, stride=1): 159 | layers = [Bottleneck(self._inplanes, planes, stride)] 160 | 161 | self._inplanes = planes * Bottleneck.expansion 162 | for _ in range(1, blocks): 163 | layers.append(Bottleneck(self._inplanes, planes)) 164 | 165 | return nn.Sequential(*layers) 166 | 167 | def initialize_parameters(self): 168 | if self.attnpool is not None: 169 | std = self.attnpool.c_proj.in_features ** -0.5 170 | nn.init.normal_(self.attnpool.q_proj.weight, std=std) 171 | nn.init.normal_(self.attnpool.k_proj.weight, std=std) 172 | nn.init.normal_(self.attnpool.v_proj.weight, std=std) 173 | nn.init.normal_(self.attnpool.c_proj.weight, std=std) 174 | 175 | for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]: 176 | for name, param in resnet_block.named_parameters(): 177 | if name.endswith("bn3.weight"): 178 | nn.init.zeros_(param) 179 | 180 | def forward(self, x): 181 | def stem(x): 182 | x = self.relu1(self.bn1(self.conv1(x))) 183 | x = self.relu2(self.bn2(self.conv2(x))) 184 | x = self.relu3(self.bn3(self.conv3(x))) 185 | x = self.avgpool(x) 186 | return x 187 | 188 | x = x.type(self.conv1.weight.dtype) 189 | x = stem(x) 190 | x = self.layer1(x) 191 | x = self.layer2(x) 192 | x = self.layer3(x) 193 | x = self.layer4(x) 194 | feat = self.attnpool(x) 195 | return feat 196 | 197 | class Text_encoder(nn.Module): 198 | def __init__(self, encoder_type: str): 199 | super(Text_encoder, self).__init__() 200 | self.encoder = BertModel.from_pretrained(encoder_type) 201 | 202 | def forward(self, token, mask): 203 | x = self.encoder(input_ids=token, attention_mask=mask) 204 | pooler_output = x.pooler_output 205 | return pooler_output 206 | 207 | class Model(nn.Module): 208 | def __init__(self, image_encoder,text_encoder): 209 | super().__init__() 210 | self.image_encoder = image_encoder 211 | self.text_encoder = text_encoder 212 | 213 | def encode_image(self,image): 214 | return 
self.image_encoder(image) 215 | 216 | def encode_text(self,text,mask): 217 | return self.text_encoder(text, mask) 218 | 219 | def forward(self, image,text,masks): 220 | global_image_out = self.image_encoder(image) 221 | global_text_out = self.text_encoder(text,masks) 222 | return global_image_out,global_text_out 223 | 224 | def Create_model(args): 225 | image_encoder = Image_encoder_ModifiedResNet(args.layers,args.img_dim,args.heads,input_resolution=[args.width,args.height]) 226 | text_encoder = Text_encoder(encoder_type=args.txt_backbone) 227 | model = Model(image_encoder, text_encoder) 228 | return model 229 | 230 | 231 | -------------------------------------------------------------------------------- /Downstreams/CMPM-C/optimizer.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import Tensor 4 | from torch.optim.optimizer import Optimizer 5 | 6 | from typing import List 7 | 8 | class Adan(Optimizer): 9 | """ 10 | Implements a pytorch variant of Adan 11 | Adan was proposed in 12 | Adan: Adaptive Nesterov Momentum Algorithm for Faster Optimizing Deep Models[J]. arXiv preprint arXiv:2208.06677, 2022. 13 | https://arxiv.org/abs/2208.06677 14 | Arguments: 15 | params (iterable): iterable of parameters to optimize or dicts defining parameter groups. 16 | lr (float, optional): learning rate. (default: 1e-3) 17 | betas (Tuple[float, float, flot], optional): coefficients used for computing 18 | running averages of gradient and its norm. (default: (0.98, 0.92, 0.99)) 19 | eps (float, optional): term added to the denominator to improve 20 | numerical stability. (default: 1e-8) 21 | weight_decay (float, optional): decoupled weight decay (L2 penalty) (default: 0) 22 | max_grad_norm (float, optional): value used to clip 23 | global grad norm (default: 0.0 no clip) 24 | no_prox (bool): how to perform the decoupled weight decay (default: False) 25 | foreach (bool): if True would use torch._foreach implementation. It's faster but uses 26 | slightly more memory. 
(default: True) 27 | """ 28 | 29 | def __init__(self, params, lr=1e-3, betas=(0.98, 0.92, 0.99), eps=1e-8, 30 | weight_decay=0.0, max_grad_norm=0.0, no_prox=False, foreach: bool = True): 31 | if not 0.0 <= max_grad_norm: 32 | raise ValueError("Invalid Max grad norm: {}".format(max_grad_norm)) 33 | if not 0.0 <= lr: 34 | raise ValueError("Invalid learning rate: {}".format(lr)) 35 | if not 0.0 <= eps: 36 | raise ValueError("Invalid epsilon value: {}".format(eps)) 37 | if not 0.0 <= betas[0] < 1.0: 38 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 39 | if not 0.0 <= betas[1] < 1.0: 40 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 41 | if not 0.0 <= betas[2] < 1.0: 42 | raise ValueError("Invalid beta parameter at index 2: {}".format(betas[2])) 43 | defaults = dict(lr=lr, betas=betas, eps=eps, 44 | weight_decay=weight_decay, 45 | max_grad_norm=max_grad_norm, no_prox=no_prox, foreach=foreach) 46 | super().__init__(params, defaults) 47 | 48 | def __setstate__(self, state): 49 | super(Adan, self).__setstate__(state) 50 | for group in self.param_groups: 51 | group.setdefault('no_prox', False) 52 | 53 | @torch.no_grad() 54 | def restart_opt(self): 55 | for group in self.param_groups: 56 | group['step'] = 0 57 | for p in group['params']: 58 | if p.requires_grad: 59 | state = self.state[p] 60 | # State initialization 61 | 62 | # Exponential moving average of gradient values 63 | state['exp_avg'] = torch.zeros_like(p) 64 | # Exponential moving average of squared gradient values 65 | state['exp_avg_sq'] = torch.zeros_like(p) 66 | # Exponential moving average of gradient difference 67 | state['exp_avg_diff'] = torch.zeros_like(p) 68 | 69 | @torch.no_grad() 70 | def step(self, closure=None): 71 | """ 72 | Performs a single optimization step. 
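        Arguments:
            closure (callable, optional): a closure that re-evaluates the model and returns the loss;
                it is called under torch.enable_grad() before the parameter update.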
73 | """ 74 | 75 | loss = None 76 | if closure is not None: 77 | with torch.enable_grad(): 78 | loss = closure() 79 | 80 | if self.defaults['max_grad_norm'] > 0: 81 | device = self.param_groups[0]['params'][0].device 82 | global_grad_norm = torch.zeros(1, device=device) 83 | 84 | max_grad_norm = torch.tensor(self.defaults['max_grad_norm'], device=device) 85 | for group in self.param_groups: 86 | 87 | for p in group['params']: 88 | if p.grad is not None: 89 | grad = p.grad 90 | global_grad_norm.add_(grad.pow(2).sum()) 91 | 92 | global_grad_norm = torch.sqrt(global_grad_norm) 93 | 94 | clip_global_grad_norm = torch.clamp(max_grad_norm / (global_grad_norm + group['eps']), max=1.0) 95 | else: 96 | clip_global_grad_norm = 1.0 97 | 98 | for group in self.param_groups: 99 | params_with_grad = [] 100 | grads = [] 101 | exp_avgs = [] 102 | exp_avg_sqs = [] 103 | exp_avg_diffs = [] 104 | pre_grads = [] 105 | 106 | beta1, beta2, beta3 = group['betas'] 107 | # assume same step across group now to simplify things 108 | # per parameter step can be easily support by making it tensor, or pass list into kernel 109 | if 'step' in group: 110 | group['step'] += 1 111 | else: 112 | group['step'] = 1 113 | 114 | bias_correction1 = 1.0 - beta1 ** group['step'] 115 | bias_correction2 = 1.0 - beta2 ** group['step'] 116 | bias_correction3 = 1.0 - beta3 ** group['step'] 117 | 118 | for p in group['params']: 119 | if p.grad is None: 120 | continue 121 | params_with_grad.append(p) 122 | grads.append(p.grad) 123 | 124 | state = self.state[p] 125 | if len(state) == 0: 126 | state['exp_avg'] = torch.zeros_like(p) 127 | state['exp_avg_sq'] = torch.zeros_like(p) 128 | state['exp_avg_diff'] = torch.zeros_like(p) 129 | 130 | if 'pre_grad' not in state or group['step'] == 1: 131 | # at first step grad wouldn't be clipped by `clip_global_grad_norm` 132 | # this is only to simplify implementation 133 | state['pre_grad'] = p.grad 134 | 135 | exp_avgs.append(state['exp_avg']) 136 | exp_avg_sqs.append(state['exp_avg_sq']) 137 | exp_avg_diffs.append(state['exp_avg_diff']) 138 | pre_grads.append(state['pre_grad']) 139 | 140 | kwargs = dict( 141 | params=params_with_grad, 142 | grads=grads, 143 | exp_avgs=exp_avgs, 144 | exp_avg_sqs=exp_avg_sqs, 145 | exp_avg_diffs=exp_avg_diffs, 146 | pre_grads=pre_grads, 147 | beta1=beta1, 148 | beta2=beta2, 149 | beta3=beta3, 150 | bias_correction1=bias_correction1, 151 | bias_correction2=bias_correction2, 152 | bias_correction3_sqrt=math.sqrt(bias_correction3), 153 | lr=group['lr'], 154 | weight_decay=group['weight_decay'], 155 | eps=group['eps'], 156 | no_prox=group['no_prox'], 157 | clip_global_grad_norm=clip_global_grad_norm, 158 | ) 159 | if group["foreach"]: 160 | copy_grads = _multi_tensor_adan(**kwargs) 161 | else: 162 | copy_grads = _single_tensor_adan(**kwargs) 163 | 164 | for p, copy_grad in zip(params_with_grad, copy_grads): 165 | self.state[p]['pre_grad'] = copy_grad 166 | 167 | return loss 168 | 169 | 170 | def _single_tensor_adan( 171 | params: List[Tensor], 172 | grads: List[Tensor], 173 | exp_avgs: List[Tensor], 174 | exp_avg_sqs: List[Tensor], 175 | exp_avg_diffs: List[Tensor], 176 | pre_grads: List[Tensor], 177 | *, 178 | beta1: float, 179 | beta2: float, 180 | beta3: float, 181 | bias_correction1: float, 182 | bias_correction2: float, 183 | bias_correction3_sqrt: float, 184 | lr: float, 185 | weight_decay: float, 186 | eps: float, 187 | no_prox: bool, 188 | clip_global_grad_norm: Tensor, 189 | ): 190 | copy_grads = [] 191 | for i, param in enumerate(params): 192 | grad = 
grads[i] 193 | exp_avg = exp_avgs[i] 194 | exp_avg_sq = exp_avg_sqs[i] 195 | exp_avg_diff = exp_avg_diffs[i] 196 | pre_grad = pre_grads[i] 197 | 198 | grad = grad.mul_(clip_global_grad_norm) 199 | copy_grads.append(grad.clone()) 200 | 201 | diff = grad - pre_grad 202 | update = grad + beta2 * diff 203 | 204 | exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) # m_t 205 | exp_avg_diff.mul_(beta2).add_(diff, alpha=1 - beta2) # diff_t 206 | exp_avg_sq.mul_(beta3).addcmul_(update, update, value=1 - beta3) # n_t 207 | 208 | denom = ((exp_avg_sq).sqrt() / bias_correction3_sqrt).add_(eps) 209 | update = ((exp_avg / bias_correction1 + beta2 * exp_avg_diff / bias_correction2)).div_(denom) 210 | 211 | if no_prox: 212 | param.mul_(1 - lr * weight_decay) 213 | param.add_(update, alpha=-lr) 214 | else: 215 | param.add_(update, alpha=-lr) 216 | param.div_(1 + lr * weight_decay) 217 | return copy_grads 218 | 219 | 220 | def _multi_tensor_adan( 221 | params: List[Tensor], 222 | grads: List[Tensor], 223 | exp_avgs: List[Tensor], 224 | exp_avg_sqs: List[Tensor], 225 | exp_avg_diffs: List[Tensor], 226 | pre_grads: List[Tensor], 227 | *, 228 | beta1: float, 229 | beta2: float, 230 | beta3: float, 231 | bias_correction1: float, 232 | bias_correction2: float, 233 | bias_correction3_sqrt: float, 234 | lr: float, 235 | weight_decay: float, 236 | eps: float, 237 | no_prox: bool, 238 | clip_global_grad_norm: Tensor, 239 | ): 240 | if clip_global_grad_norm < 1.0: 241 | torch._foreach_mul_(grads, clip_global_grad_norm.item()) 242 | copy_grads = [g.clone() for g in grads] 243 | 244 | diff = torch._foreach_sub(grads, pre_grads) 245 | # NOTE: line below while looking identical gives different result, due to float precision errors. 246 | # using mul+add produces identical results to single-tensor, using add+alpha doesn't 247 | # On cuda this difference doesn't matter due to its' own precision non-determinism 248 | # update = torch._foreach_add(grads, torch._foreach_mul(diff, beta2)) 249 | update = torch._foreach_add(grads, diff, alpha=beta2) 250 | 251 | torch._foreach_mul_(exp_avgs, beta1) 252 | torch._foreach_add_(exp_avgs, grads, alpha=1 - beta1) # m_t 253 | 254 | torch._foreach_mul_(exp_avg_diffs, beta2) 255 | torch._foreach_add_(exp_avg_diffs, diff, alpha=1 - beta2) # diff_t 256 | 257 | torch._foreach_mul_(exp_avg_sqs, beta3) 258 | torch._foreach_addcmul_(exp_avg_sqs, update, update, value=1 - beta3) # n_t 259 | 260 | denom = torch._foreach_sqrt(exp_avg_sqs) 261 | torch._foreach_div_(denom, bias_correction3_sqrt) 262 | torch._foreach_add_(denom, eps) 263 | 264 | update = torch._foreach_div(exp_avgs, bias_correction1) 265 | # NOTE: same issue as above. beta2 * diff / bias_correction2 != diff * (beta2 / bias_correction2) 266 | # using faster version by default. 
267 | # torch._foreach_add_(update, torch._foreach_div(torch._foreach_mul(exp_avg_diffs, beta2), bias_correction2)) 268 | torch._foreach_add_(update, torch._foreach_mul(exp_avg_diffs, beta2 / bias_correction2)) 269 | torch._foreach_div_(update, denom) 270 | 271 | if no_prox: 272 | torch._foreach_mul_(params, 1 - lr * weight_decay) 273 | torch._foreach_add_(params, update, alpha=-lr) 274 | else: 275 | torch._foreach_add_(params, update, alpha=-lr) 276 | torch._foreach_div_(params, 1 + lr * weight_decay) 277 | return copy_grads 278 | 279 | def create_optimizer(params,args): 280 | if args.optimizer =="adam": 281 | print("The optimizer is {}".format(args.optimizer)) 282 | optimizer = torch.optim.Adam(params,lr=args.lr, betas=(args.adam_alpha, args.adam_beta), eps=args.epsilon) 283 | return optimizer 284 | elif args.optimizer =="adan": 285 | print("The optimizer is {}".format(args.optimizer)) 286 | optimizer = Adan(params, lr=args.lr, weight_decay=args.adan_weight_decay, betas=args.adan_opt_betas, eps=args.adan_opt_eps, 287 | max_grad_norm=args.adan_max_grad_norm, no_prox=args.adan_no_prox) 288 | return optimizer -------------------------------------------------------------------------------- /Downstreams/CMPM-C/test.py: -------------------------------------------------------------------------------- 1 | import torchvision.transforms as transforms 2 | from utils import * 3 | import os 4 | import shutil 5 | from dataloader import get_loader_test 6 | import argparse 7 | from model import Create_model 8 | import warnings 9 | warnings.filterwarnings("ignore") 10 | 11 | def Test_parse_args(): 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('--model_path', type=str, default='log/CUHK-CMPMC') 14 | parser.add_argument('--image_dir', type=str, default='data/CUHK-PEDES/') 15 | parser.add_argument('--caption_path', type=str, 16 | default='data/CUHK-PEDES/CUHK-PEDES-test.json', 17 | help='path for test annotation json file') 18 | parser.add_argument('--best_dir', type=str, default='log/CUHK-CMPMC') 19 | parser.add_argument('--flip_eval', type=bool, default=False) 20 | parser.add_argument('--mean', default=[0.357, 0.323, 0.328], type=list) 21 | parser.add_argument('--std', default=[0.252, 0.242, 0.239], type=list) 22 | # *********************************************************************************************************************** 23 | # 设置模型backbone的类型和参数 24 | parser.add_argument('--img_backbone', type=str, default='ModifiedResNet', 25 | help="ResNet:xxx, ModifiedResNet, ViT:xxx") 26 | parser.add_argument('--txt_backbone', type=str, default="bert-base-uncased") 27 | parser.add_argument('--img_dim', type=int, default=768, help='dimension of image embedding vectors') 28 | parser.add_argument('--text_dim', type=int, default=768, help='dimension of text embedding vectors') 29 | parser.add_argument('--patch_size', type=int, default=16, help='Just for ViT model') 30 | parser.add_argument('--layers', type=list, default=[3, 4, 6, 3], help='Just for ModifiedResNet model') 31 | parser.add_argument('--heads', type=int, default=8, help='Just for ModifiedResNet model') 32 | 33 | parser.add_argument('--height', type=int, default=256) 34 | parser.add_argument('--width', type=int, default=128) 35 | 36 | # 设置超参数 37 | parser.add_argument('--num_epoches', type=int, default=30) 38 | parser.add_argument('--batch_size', type=int, default=128) 39 | parser.add_argument('--num_workers', type=int, default=4) 40 | parser.add_argument('--device',type=str,default="cuda:0") 41 | 
parser.add_argument('--feature_size', type=int, default=768) 42 | args = parser.parse_args() 43 | return args 44 | 45 | def test(image_test_loader,text_test_loader, model): 46 | # switch to evaluate mode 47 | model = model.eval() 48 | device = next(model.parameters()).device 49 | 50 | qids, gids, qfeats, gfeats = [], [], [], [] 51 | # text 52 | for pid, caption,mask in text_test_loader: 53 | caption = caption.to(device) 54 | mask = mask.to(device) 55 | with torch.no_grad(): 56 | text_feat = model.encode_text(caption,mask) 57 | qids.append(pid.view(-1)) # flatten 58 | qfeats.append(text_feat) 59 | qids = torch.cat(qids, 0) 60 | qfeats = torch.cat(qfeats, 0) 61 | 62 | # image 63 | for pid, img in image_test_loader: 64 | img = img.to(device) 65 | with torch.no_grad(): 66 | img_feat = model.encode_image(img) 67 | gids.append(pid.view(-1)) # flatten 68 | gfeats.append(img_feat) 69 | gids = torch.cat(gids, 0) 70 | gfeats = torch.cat(gfeats, 0) 71 | ac_top1_t2i, ac_top5_t2i, ac_top10_t2i, mAP = test_map(qfeats, qids, gfeats, gids) 72 | return ac_top1_t2i, ac_top5_t2i, ac_top10_t2i, mAP 73 | 74 | 75 | def Test_main(args): 76 | device = args.device 77 | transform = transforms.Compose([ 78 | transforms.Resize((256,128), interpolation=3), 79 | transforms.ToTensor(), 80 | transforms.Normalize(args.mean, 81 | args.std) 82 | ]) 83 | 84 | image_test_loader,text_test_loader = get_loader_test(args.image_dir, args.caption_path, transform, args.batch_size,args.num_workers) 85 | 86 | ac_t2i_top1_best = 0.0 87 | ac_t2i_top5_best = 0.0 88 | ac_t2i_top10_best = 0.0 89 | mAP_best = 0.0 90 | best = 0 91 | dst_best = args.best_dir + "/model_best" + ".pth" 92 | model = Create_model(args).to(device) 93 | 94 | for i in range(0, args.num_epoches): 95 | if i%2!=0 and i!=args.num_epoches-1: 96 | continue 97 | model_file = os.path.join(args.model_path, str(i+1))+".pth.tar" 98 | print(model_file) 99 | if os.path.isdir(model_file): 100 | continue 101 | checkpoint = torch.load(model_file) 102 | model.image_encoder.load_state_dict(checkpoint["ImgEncoder_state_dict"]) 103 | model.text_encoder.load_state_dict(checkpoint["TxtEncoder_state_dict"]) 104 | ac_top1_t2i, ac_top5_t2i, ac_top10_t2i, mAP = test(image_test_loader,text_test_loader, model) 105 | if ac_top1_t2i > ac_t2i_top1_best: 106 | ac_t2i_top1_best = ac_top1_t2i 107 | ac_t2i_top5_best = ac_top5_t2i 108 | ac_t2i_top10_best = ac_top10_t2i 109 | mAP_best = mAP 110 | best = i 111 | shutil.copyfile(model_file, dst_best) 112 | 113 | print('Epo{}: {:.5f} {:.5f} {:.5f} {:.5f}'.format( 114 | best, ac_t2i_top1_best, ac_t2i_top5_best, ac_t2i_top10_best, mAP_best)) 115 | 116 | if __name__ == '__main__': 117 | args = Test_parse_args() 118 | Test_main(args) -------------------------------------------------------------------------------- /Downstreams/CMPM-C/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from dataloader import get_loader 3 | from model import Create_model 4 | from loss import Loss_calc 5 | from torchvision import transforms 6 | from utils import * 7 | import time 8 | from optimizer import create_optimizer 9 | import torch.distributed as dist 10 | from torch.nn.parallel import DistributedDataParallel as DDP 11 | 12 | 13 | def Train_parse_args(): 14 | parser = argparse.ArgumentParser(description='command for train by Resnet') 15 | parser.add_argument('--mean', default=[0.357, 0.323, 0.328], type=list) 16 | parser.add_argument('--std', default=[0.252, 0.242, 0.239], type=list) 17 | 
#*********************************************************************************************************************** 18 | #设置数据集路径和输出路径等 19 | parser.add_argument('--name', default='CUHK-CMPMC', type=str, help='output model name') 20 | parser.add_argument('--checkpoint_dir', type=str, 21 | default="log", 22 | help='directory to store checkpoint') 23 | parser.add_argument('--log_dir', type=str, 24 | default="log", 25 | help='directory to store log') 26 | parser.add_argument('--image_path', type=str, 27 | default=r'data/CUHK-PEDES', 28 | help='directory to store dataset') 29 | parser.add_argument('--dataset_path', type=str, 30 | default=r'data/CUHK-PEDES-train-depart.json', 31 | help='directory to annotations') 32 | parser.add_argument('--checkpoint_path', type=str,default=r'checkpoints/PLIP_RN50.pth.tar') 33 | #*********************************************************************************************************************** 34 | #设置模型backbone的类型和参数 35 | parser.add_argument('--img_backbone', type=str, default='ModifiedResNet',help="ResNet:xxx, ModifiedResNet, ViT:xxx") 36 | parser.add_argument('--txt_backbone', type=str, default="bert-base-uncased") 37 | parser.add_argument('--img_dim', type=int, default=768, help='dimension of image embedding vectors') 38 | parser.add_argument('--text_dim', type=int, default=768, help='dimension of text embedding vectors') 39 | parser.add_argument('--layers', type=list, default=[3,4,6,3], help='Just for ModifiedResNet model') 40 | parser.add_argument('--heads', type=int, default=8, help='Just for ModifiedResNet model') 41 | parser.add_argument('--feature_size', type=int, default=768) 42 | parser.add_argument('--num_classes', type=int, default=11003) # CUHK:11003 ICFG:3102 THE NUMBER OF IDENTITIES 43 | #*********************************************************************************************************************** 44 | #设置训练预处理超参数 45 | parser.add_argument('--height', type=int, default=256) 46 | parser.add_argument('--width', type=int, default=128) 47 | 48 | #*********************************************************************************************************************** 49 | #设置学习率等超参数 50 | parser.add_argument('--save_every', type=int, default=1, help='step size for saving trained models') 51 | parser.add_argument('--batch_size', type=int, default=32) 52 | parser.add_argument('--epochs', type=int, default=30) 53 | parser.add_argument('--warm_epoch', default=5, type=int, help='the first K epoch that needs warm up') 54 | parser.add_argument('--lr', type=float, default=0.0001, help='the learning rate of optimizer') 55 | parser.add_argument('--lr_decay_type', type=str, default='MultiStepLR', 56 | help='One of "MultiStepLR" or "StepLR" or "get_linear_schedule_with_warmup"') 57 | parser.add_argument('--lr_decay_ratio', type=float, default=0.1) 58 | parser.add_argument('--epoches_decay', type=str, default='20', help='#epoches when learning rate decays') 59 | parser.add_argument('--backbone_frozen', type=bool, default=False) 60 | #*********************************************************************************************************************** 61 | # 设置优化器超参数 62 | parser.add_argument('--optimizer', type=str, default="adan", help='The optimizer type:adam or adan') 63 | 64 | parser.add_argument('--adam_alpha', type=float, default=0.9) 65 | parser.add_argument('--adam_beta', type=float, default=0.999) 66 | parser.add_argument('--epsilon', type=float, default=1e-8) 67 | 68 | parser.add_argument('--adan_max-grad-norm', type=float, default=0.0, 
69 | help='if the l2 norm is large than this hyper-parameter, then we clip the gradient (default: 0.0, no gradient clip)') 70 | parser.add_argument('--adan_weight-decay', type=float, default=0.02, 71 | help='weight decay, similar one used in AdamW (default: 0.02)') 72 | parser.add_argument('--adan_opt-eps', default=1e-8, type=float, metavar='EPSILON', 73 | help='optimizer epsilon to avoid the bad case where second-order moment is zero (default: None, use opt default 1e-8 in adan)') 74 | parser.add_argument('--adan_opt-betas', default=[0.98, 0.92, 0.99], type=float, nargs='+', metavar='BETA', 75 | help='optimizer betas in Adan (default: None, use opt default [0.98, 0.92, 0.99] in Adan)') 76 | parser.add_argument('--adan_no-prox', action='store_true', default=False, 77 | help='whether perform weight decay like AdamW (default=False)') 78 | 79 | #*********************************************************************************************************************** 80 | #其他设置 81 | parser.add_argument('--num_workers', type=int, default=4) 82 | parser.add_argument('--gpus', type=str, default="0,1,2,3") 83 | parser.add_argument('--local_rank', type=int) 84 | args = parser.parse_args() 85 | return args 86 | 87 | 88 | def train(args): 89 | rank = int(os.environ["RANK"]) 90 | local_rank = int(os.environ["LOCAL_RANK"]) 91 | torch.cuda.set_device(rank % torch.cuda.device_count()) 92 | dist.init_process_group(backend="nccl") 93 | device = torch.device("cuda", local_rank) 94 | print(f"[init] == local rank: {local_rank}, global rank: {rank} ==") 95 | 96 | train_information_setting(args) 97 | checkpoint_dir = os.path.join(args.checkpoint_dir, args.name) 98 | 99 | transform = transforms.Compose([ 100 | transforms.Resize((args.height, args.width),interpolation=3), 101 | transforms.RandomHorizontalFlip(p=0.5), 102 | transforms.ToTensor(), 103 | transforms.Normalize(args.mean,args.std) 104 | ]) 105 | 106 | epochs = args.epochs 107 | 108 | model = Create_model(args).to(device) 109 | model_file = args.checkpoint_path 110 | checkpoint = torch.load(model_file) 111 | model.image_encoder.load_state_dict(checkpoint["ImgEncoder_state_dict"]) 112 | model.text_encoder.load_state_dict(checkpoint["TxtEncoder_state_dict"]) 113 | Loss = Loss_calc(args).to(device) 114 | params = [{"params": model.image_encoder.parameters(), "lr": args.lr}, 115 | {"params": model.text_encoder.parameters(), "lr": args.lr}, 116 | {"params": Loss.parameters(), "lr": args.lr * 10}] 117 | 118 | 119 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) # 多卡bn层同步 120 | optimizer = create_optimizer(params, args) 121 | 122 | train_dataloader = get_loader(args.image_path, args.dataset_path, transform, None,args.batch_size,args.num_workers,distributed=True) 123 | scheduler = lr_scheduler(optimizer, args,len(train_dataloader)) 124 | PrintInformation(args, model) 125 | 126 | model = DDP(model, device_ids=[local_rank], output_device=local_rank, broadcast_buffers=True,find_unused_parameters=True) 127 | 128 | if rank == 0: 129 | print(" ======= Training ======= \n") 130 | 131 | model.train() 132 | 133 | for epoch in range(epochs): 134 | print("**********************************************************") 135 | print(f">>> Training epoch {epoch + 1}") 136 | total_loss = 0 137 | start = time.time() 138 | train_dataloader.sampler.set_epoch(epoch) 139 | if epoch < args.warm_epoch: 140 | print('learning rate warm_up') 141 | optimizer = gradual_warmup(epoch, optimizer, epochs=args.warm_epoch) 142 | 143 | optimizer.zero_grad() 144 | for idx, (images_gt, 
targets, masks, labels) in enumerate( 145 | train_dataloader): 146 | images_gt, targets, masks, labels = images_gt.to(device), targets.to(device), masks.to(device), labels.to(device) - 1 147 | global_visual_embed, global_textual_embed = model(images_gt, targets, masks) 148 | IDlabels = labels 149 | cmpm_loss, cmpc_loss, loss, image_precision, text_precision = Loss(global_visual_embed, global_textual_embed, IDlabels) 150 | #print(loss) 151 | loss.backward() 152 | optimizer.step() 153 | optimizer.zero_grad() 154 | total_loss += loss.item() 155 | if idx % 50 == 0 : 156 | print( 157 | "Train Epoch:[{}/{}] iteration:[{}/{}] cmpm_loss:{:.4f} cmpc_loss:{:.4f} " 158 | "image_pre:{:.4f} text_pre:{:.4f}" 159 | .format(epoch + 1, args.epochs, idx, len(train_dataloader), cmpm_loss.item(), 160 | cmpc_loss.item(), 161 | image_precision, text_precision)) 162 | scheduler.step() 163 | Epoch_time = time.time() - start 164 | print("Average loss is :{}".format(total_loss / len(train_dataloader))) 165 | print('Epoch_training complete in {:.0f}m {:.0f}s'.format( 166 | Epoch_time // 60, Epoch_time % 60)) 167 | 168 | if epoch % args.save_every == 0 or epoch == epochs - 1: 169 | state = {"epoch": epoch + 1, 170 | "ImgEncoder_state_dict": model.module.image_encoder.state_dict(), 171 | "TxtEncoder_state_dict": model.module.text_encoder.state_dict() 172 | } 173 | save_checkpoint(state, epoch + 1, checkpoint_dir) 174 | 175 | def main(): 176 | args = Train_parse_args() 177 | train(args) 178 | 179 | if __name__ == '__main__': 180 | main() 181 | -------------------------------------------------------------------------------- /Downstreams/CMPM-C/utils.py: -------------------------------------------------------------------------------- 1 | import errno 2 | import sys 3 | import os.path as osp 4 | import torch.utils.data as data 5 | import os 6 | import torch 7 | import numpy as np 8 | import random 9 | import json 10 | import torch.nn as nn 11 | from einops.layers.torch import Rearrange 12 | from transformers import get_linear_schedule_with_warmup 13 | import yaml 14 | 15 | def setup_seed(seed): 16 | torch.manual_seed(seed) 17 | torch.cuda.manual_seed(seed) 18 | torch.cuda.manual_seed_all(seed) 19 | np.random.seed(seed) # Numpy module. 20 | random.seed(seed) # Python random module. 
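    # the two cuDNN settings below trade speed for reproducibility: disable benchmark autotuning and force deterministic kernels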
21 | torch.backends.cudnn.benchmark = False 22 | torch.backends.cudnn.deterministic = True 23 | print("The seed is {}".format(seed)) 24 | 25 | def lr_scheduler(optimizer, args,loader_length): 26 | if args.lr_decay_type == "get_linear_schedule_with_warmup": 27 | scheduler = get_linear_schedule_with_warmup( 28 | optimizer, num_warmup_steps=int(args.epochs * loader_length * 0.1/args.accumulation_steps), 29 | num_training_steps=int(args.epochs * loader_length/args.accumulation_steps) 30 | ) 31 | print("lr_scheduler is get_linear_schedule_with_warmup") 32 | else: 33 | if '_' in args.epoches_decay: 34 | epoches_list = args.epoches_decay.split('_') 35 | epoches_list = [int(e) for e in epoches_list] 36 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, epoches_list, gamma=args.lr_decay_ratio) 37 | print("lr_scheduler is MultiStepLR") 38 | else: 39 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, int(args.epoches_decay), gamma=args.lr_decay_ratio) 40 | print("lr_scheduler is StepLR") 41 | return scheduler 42 | 43 | def save_checkpoint(state, epoch, dst): 44 | if not os.path.exists(dst): 45 | os.makedirs(dst) 46 | filename = os.path.join(dst, str(epoch)) + '.pth.tar' 47 | torch.save(state, filename) 48 | 49 | def gradual_warmup(epoch,optimizer,epochs): 50 | if epoch == 0: 51 | warmup_percent_done = (epoch + 1) / epochs 52 | else: 53 | warmup_percent_done = (epoch + 1) / epoch 54 | for param_group in optimizer.param_groups: 55 | param_group['lr'] = param_group['lr']*warmup_percent_done 56 | return optimizer 57 | 58 | def PrintInformation(args,model): 59 | # with open(args.annotation_path,"r") as f: 60 | # dataset = json.load(f) 61 | # num = len(dataset) 62 | print("The image model is: {}".format(args.img_backbone)) 63 | print("The language model is: {}".format(args.txt_backbone)) 64 | print('Number of model parameters: {}'.format( 65 | sum([p.data.nelement() for p in model.parameters()]))) 66 | print("Checkpoints are saved in: {}".format(args.name)) 67 | print("The original learning rate is {}".format(args.lr)) 68 | 69 | 70 | def train_information_setting(args): #配置好训练信息的输出文件 71 | name = args.name 72 | # set some paths 73 | log_dir = args.log_dir 74 | log_dir = os.path.join(log_dir, name) 75 | sys.stdout = Logger(os.path.join(log_dir, "train_log.txt")) 76 | opt_dir = os.path.join('log', name) 77 | if not os.path.exists(opt_dir): 78 | os.makedirs(opt_dir) 79 | with open('%s/opts_train.yaml' % opt_dir, 'w') as fp: 80 | yaml.dump(vars(args), fp, default_flow_style=False) 81 | 82 | def mkdir_if_missing(directory): 83 | if not osp.exists(directory): 84 | try: 85 | os.makedirs(directory) 86 | except OSError as e: 87 | if e.errno != errno.EEXIST: 88 | raise 89 | 90 | class Logger(object): 91 | """ 92 | Write console output to external text file. 93 | Code imported from https://github.com/Cysu/open-reid/blob/master/reid/utils/logging.py. 
94 | """ 95 | def __init__(self, fpath=None): 96 | self.console = sys.stdout 97 | self.file = None 98 | if fpath is not None: 99 | mkdir_if_missing(os.path.dirname(fpath)) 100 | self.file = open(fpath, 'w') 101 | 102 | def __del__(self): 103 | self.close() 104 | 105 | def __enter__(self): 106 | pass 107 | 108 | def __exit__(self, *args): 109 | self.close() 110 | 111 | def write(self, msg): 112 | self.console.write(msg) 113 | if self.file is not None: 114 | self.file.write(msg) 115 | 116 | def flush(self): 117 | self.console.flush() 118 | if self.file is not None: 119 | self.file.flush() 120 | os.fsync(self.file.fileno()) 121 | 122 | def close(self): 123 | self.console.close() 124 | if self.file is not None: 125 | self.file.close() 126 | 127 | def compute_topk(query, gallery, target_query, target_gallery, k=[1,10], reverse=False): 128 | result = [] 129 | query = query / (query.norm(dim=1,keepdim=True)+1e-12) 130 | gallery = gallery / (gallery.norm(dim=1,keepdim=True)+1e-12) 131 | sim_cosine = torch.matmul(query, gallery.t()) 132 | result.extend(topk(sim_cosine, target_gallery, target_query, k)) 133 | if reverse: 134 | result.extend(topk(sim_cosine, target_query, target_gallery, k, dim=0)) 135 | return result 136 | 137 | def topk(sim, target_gallery, target_query, k=[1,10], dim=1): 138 | result = [] 139 | maxk = max(k) 140 | size_total = len(target_gallery) 141 | _, pred_index = sim.topk(maxk, dim, True, True) 142 | pred_labels = target_gallery[pred_index] 143 | if dim == 1: 144 | pred_labels = pred_labels.t() 145 | correct = pred_labels.eq(target_query.view(1,-1).expand_as(pred_labels)) 146 | 147 | for topk in k: 148 | correct_k = torch.sum(correct[:topk], dim=0) 149 | correct_k = torch.sum(correct_k > 0).float() 150 | result.append(correct_k * 100 / size_total) 151 | return result 152 | 153 | def test_map(query_feature,query_label,gallery_feature, gallery_label): 154 | query_feature = query_feature / (query_feature.norm(dim=1, keepdim=True) + 1e-12) 155 | gallery_feature = gallery_feature / (gallery_feature.norm(dim=1, keepdim=True) + 1e-12) 156 | CMC = torch.IntTensor(len(gallery_label)).zero_() 157 | ap = 0.0 158 | for i in range(len(query_label)): 159 | ap_tmp, CMC_tmp = evaluate(query_feature[i], query_label[i], gallery_feature, gallery_label) 160 | 161 | if CMC_tmp[0] == -1: 162 | continue 163 | CMC = CMC + CMC_tmp 164 | ap += ap_tmp 165 | CMC = CMC.float() 166 | CMC = CMC / len(query_label) 167 | print('Rank@1:%f Rank@5:%f Rank@10:%f mAP:%f' % (CMC[0], CMC[4], CMC[9], ap / len(query_label))) 168 | return CMC[0], CMC[4], CMC[9], ap / len(query_label) 169 | 170 | def evaluate(qf, ql, gf, gl): 171 | query = qf.view(-1, 1) 172 | score = torch.mm(gf, query) 173 | score = score.squeeze(1).cpu() 174 | score = score.numpy() 175 | index = np.argsort(score) 176 | index = index[::-1] 177 | gl=gl.cuda().data.cpu().numpy() 178 | ql=ql.cuda().data.cpu().numpy() 179 | query_index = np.argwhere(gl == ql) 180 | CMC_tmp = compute_mAP(index, query_index) 181 | return CMC_tmp 182 | 183 | def compute_mAP(index, good_index): 184 | ap = 0 185 | cmc = torch.IntTensor(len(index)).zero_() 186 | if good_index.size == 0: # if empty 187 | cmc[0] = -1 188 | return ap, cmc 189 | # find good_index index 190 | ngood = len(good_index) 191 | mask = np.in1d(index, good_index) 192 | rows_good = np.argwhere(mask == True) 193 | rows_good = rows_good.flatten() 194 | 195 | cmc[rows_good[0]:] = 1 196 | for i in range(ngood): 197 | d_recall = 1.0 / ngood 198 | precision = (i + 1) * 1.0 / (rows_good[i] + 1) 199 | if 
rows_good[i] != 0: 200 | old_precision = i * 1.0 / rows_good[i] 201 | else: 202 | old_precision = 1.0 203 | ap = ap + d_recall * (old_precision + precision) / 2 204 | 205 | return ap, cmc -------------------------------------------------------------------------------- /Downstreams/readme.txt: -------------------------------------------------------------------------------- 1 | These are some experiments for the downstream tasks. -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 jlongzuo 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /PLIPmodel.py: -------------------------------------------------------------------------------- 1 | from visual_model import * 2 | from textual_model import Textual_encoder 3 | 4 | class PLIP_MResNet(nn.Module): 5 | def __init__(self, image_encoder,text_encoder): 6 | super().__init__() 7 | self.image_encoder = image_encoder 8 | self.text_encoder = text_encoder 9 | 10 | def get_text_global_embedding(self,caption,mask): 11 | global_text_out = self.text_encoder.get_global_embedding(caption,mask) 12 | return global_text_out 13 | 14 | def get_image_embeddings(self,image): 15 | global_image_out, _,_,_,_ = self.image_encoder(image) 16 | return global_image_out 17 | 18 | def forward(self, image,text,masks): 19 | global_image_out,x1,x2,x3,x4 = self.image_encoder(image) 20 | global_text_out, part_text_out = self.text_encoder(text,masks) 21 | return global_image_out,x1,x2,x3,x4,global_text_out,part_text_out 22 | 23 | def Create_PLIP_Model(args): 24 | if args.plip_model == "MResNet_BERT": 25 | image_encoder = Image_encoder_ModifiedResNet(args.layers,args.img_dim,args.heads,input_resolution=[args.width,args.height]) 26 | text_encoder = Textual_encoder(encoder_type=args.txt_backbone) 27 | model = PLIP_MResNet(image_encoder, text_encoder) 28 | return model 29 | else: 30 | raise RuntimeError(f"The image backbone you input does not meet the specification!") 31 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PLIP 2 | **PLIP** is a novel **L**anguage-**I**mage **P**re-training framework for generic **P**erson representation learning which benefits a range of 
downstream person-centric tasks. 3 | 4 | Also, we present a large-scale person dataset named **SYNTH-PEDES** to verify its effectiveness, where the Stylish Pedestrian Attributes-union Captioning method **(SPAC)** is proposed to synthesize diverse textual descriptions. 5 | 6 | Experiments show that our model not only significantly improves existing methods on downstream tasks, but also shows great ability in the few-shot and domain generalization settings. More details can be found at our paper [PLIP: Language-Image Pre-training for Person Representation Learning](http://export.arxiv.org/abs/2305.08386). 7 | 8 |
9 | 10 | ## News 11 | * 🔥[06.1] **SYNTH-PEDES** is released. Welcome to download and use it! 12 | * 🔥[06.1] The code for **CMPM/C fine-tuning** is released! It leads to SOTA performance without bells and whistles! 13 | * 🔥[05.31] The pre-trained model and **zero-shot inference** code are released! 14 | 15 | ## SYNTH-PEDES 16 | SYNTH-PEDES is by far the largest person dataset with textual descriptions, built without any human annotation effort. Every person image has 2 or 3 different textual descriptions and 6 attribute annotations. The dataset is released at [Baidu Yun](https://pan.baidu.com/s/11jQ3gvkn77b3jjVx-quQxQ?pwd=1037). 17 | 18 | **Note that SYNTH-PEDES can only be used for research; any commercial usage is forbidden.** 19 | 20 | Below is a comparison of SYNTH-PEDES with other popular datasets. 21 |
22 | 23 | These are some examples of our SYNTH-PEDES dataset. 24 |
25 | 26 | Annotation format: 27 | ``` 28 | { 29 | "id": 7, 30 | "file_path": "Part1/7/1.jpg", 31 | "attributes": [ 32 | "man,black hair,black shirt,pink shorts,black shoes,unknown" 33 | ], 34 | "captions": [ 35 | "A man in his mid-twenties with short black hair is wearing a black t-shirt over light pink trousers. He is also wearing black shoes.", 36 | "The man with short black hair is wearing a black shirt and salmon pink shorts. He is also wearing black shoes." 37 | ], 38 | "prompt_caption": [ 39 | "A man with black hair is wearing a black shirt with pink shorts and a pair of black shoes." 40 | ] 41 | } 42 | ``` 43 | 44 | ## Models 45 | We utilize ResNet50 and Bert as our encoders. After pre-training, we fine-tune and evaluate the performance on three downstream tasks. The checkpoints have been released at [Baidu Yun](https://pan.baidu.com/s/1LjT-x6kjGwpO2EP4Ni7bCA?pwd=1037) and [Google Drive](https://drive.google.com/file/d/1Cpid6AGHXF_is5ULB3UJKMGvl6kf2Tmg/view?usp=sharing). 46 | 47 | ### CUHK-PEDES dataset (Text Re-ID R@1/R@10) 48 | | Pre-train | CMPM/C | SSAN | LGUR | 49 | | :---: |:---: |:---: | :---: 50 | | IN sup | 54.81/83.22 | 61.37/86.73 | 64.21/87.93 51 | | IN unsup |55.34/83.76| 61.97/86.63| 65.33/88.47 52 | | CLIP |55.67/83.82| 62.09/86.89| 64.70/88.76 53 | | LUP |57.21/84.68| 63.91/88.36| 65.42/89.36 54 | | LUP-NL |57.35/84.77| 63.71/87.46| 64.68/88.69 55 | | **PLIP(ours)** |**69.23/91.16**| **64.91/88.39**| **67.22/89.49** 56 | 57 | ### ICFG-PEDES dataset (Text Re-ID R@1/R@10) 58 | | Pre-train | CMPM/C | SSAN | LGUR | 59 | | :---: |:---: |:---: | :---: 60 | | IN sup | 47.61/75.48| 54.23/79.53| 57.42/81.45 61 | | IN unsup |48.34/75.66| 55.27/79.64| 59.90/82.94 62 | | CLIP |48.12/75.51| 53.58/78.96| 58.35/82.02 63 | | LUP |50.12/76.23| 56.51/80.41| 60.33/83.06 64 | | LUP-NL |49.64/76.15| 55.59/79.78| 60.25/82.84 65 | | **PLIP(ours)** |**64.25/86.32**| **60.12/82.84**| **62.27/83.96** 66 | 67 | ### Market1501 & DukeMTMC (Image Re-ID mAP/cmc1) 68 | | Methods | Market1501 | DukeMTMC | 69 | | :---: |:---: |:---: 70 | | BOT | 85.9/94.5 |76.4/86.4 71 | | BDB |86.7/95.3| 76.0/89.0 72 | | MGN |87.5/95.1 |79.4/89.0 73 | | ABDNet |88.3/95.6| 78.6/89.0 74 | | **PLIP+BOT** | 88.0/95.1| 77.0/86.5 75 | | **PLIP+BDB** |88.4/95.7| 78.2/89.8 76 | | **PLIP+MGN** |90.6/96.3| **81.7**/90.3 77 | | **PLIP+ABDNet**|**91.2**/**96.7** |81.6/**90.9** 78 | 79 | ### Evaluate on PETA & PA-100K & RAP (PAR mA/F1) 80 | | Methods | PETA | PA-100K | RAP 81 | | :---: |:---: |:---: |:---: 82 | | DeepMAR | 80.14/83.56| 78.28/84.32| 76.81/78.94 83 | | Rethink |83.96/86.35 |80.21/87.40 |79.27/79.95 84 | | VTB |84.12/86.63| 81.02/87.31| 81.43/80.63 85 | | Label2Label |84.08/86.57 |82.24/87.08| 81.82/80.93 86 | | **PLIP+DeepMAR** | 82.46/85.87 |80.33/87.24 |78.96/80.12 87 | | **PLIP+Rethink**|85.56/87.63| 82.09/88.12| 81.87/81.53 88 | | **PLIP+VTB** |86.03/**88.14**| 83.24/88.57 |83.64/**81.78** 89 | | **PLIP+Label2Label** |**86.12**/88.08 |**84.36**/**88.63**| **83.77**/81.49 90 | 91 | 92 | ## Usage 93 | ### Install Requirements 94 | we use 4 RTX3090 24G GPU for training and evaluation. 95 | 96 | Create conda environment. 97 | ``` 98 | conda create --name PLIP --file requirements.txt 99 | conda activate PLIP 100 | ``` 101 | 102 | ### Datasets Prepare 103 | Download the CUHK-PEDES dataset from [here](https://github.com/ShuangLI59/Person-Search-with-Natural-Language-Description) and ICFG-PEDES dataset from [here](https://github.com/zifyloo/SSAN). 
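After downloading, the raw annotation files can be sanity-checked with a few lines of Python (a minimal sketch; the `id`, `file_path`, `captions`, and `split` fields are the ones read by `dataset_split.py` and the dataloaders in this repo, and the path assumes the folder layout shown in the next step):
```
import json

# Load the raw CUHK-PEDES annotations and print one entry.
with open("data/CUHK-PEDES/reid_raw.json", "r", encoding="utf8") as fp:
    dataset = json.load(fp)

print("Total entries:", len(dataset))
sample = dataset[0]
print(sample["id"], sample["file_path"], sample["split"])
for caption in sample["captions"]:
    print("-", caption)
```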
104 | 105 | Organize them in the `data` folder as follows: 106 | ``` 107 | |-- data/ 108 | | |-- CUHK-PEDES/ 109 | | |-- imgs 110 | | |-- cam_a 111 | | |-- cam_b 112 | | |-- ... 113 | | |-- reid_raw.json 114 | | 115 | | |-- ICFG-PEDES/ 116 | | |-- imgs 117 | | |-- test 118 | | |-- train 119 | | |-- ICFG_PEDES.json 120 | | 121 | | |-- SYNTH-PEDES/ 122 | | |-- Part1 123 | | |-- ... 124 | | |-- Part11 125 | | |-- synthpedes_dataset.json 126 | ``` 127 | 128 | ### Zero-shot Inference 129 | Our pre-trained model can be directly transferred to downstream tasks, especially text-based Re-ID. 130 | 131 | 1. Run the following to generate the train/test/valid json files: 132 | ``` 133 | python dataset_split.py 134 | ``` 135 | 136 | 2. Then you can evaluate by running: 137 | ``` 138 | python zs_infer.py 139 | ``` 140 | 141 | ### Fine-tuning Inference 142 | Almost all existing downstream person-centric methods can be improved by replacing their backbone with our pre-trained model. Taking CMPM/C as an example: 143 | 144 | 1. Go to the CMPM/C root: 145 | ``` 146 | cd Downstreams/CMPM-C 147 | ``` 148 | 149 | 2. Run the following to train and test. Note that you can modify the code yourself for single-GPU training: 150 | ``` 151 | python dataset_split.py 152 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 train.py 153 | python test.py 154 | ``` 155 | 156 | 157 | 158 | ### Evaluate on Other Methods and Tasks 159 | By simply replacing the visual backbone with our pre-trained model, almost all existing methods on downstream tasks achieve significant improvements. For example, you can try the following repositories: 160 | 161 | **Text-based Re-ID:** 162 | [SSAN](https://github.com/zifyloo/SSAN), [LGUR](https://github.com/ZhiyinShao-H/LGUR) 163 | 164 | **Image-based Re-ID:** 165 | [BOT](https://github.com/michuanhaohao/reid-strong-baseline), [MGN](https://github.com/seathiefwang/MGN-pytorch), [ABD-Net](https://github.com/VITA-Group/ABD-Net) 166 | 167 | **Person Attribute Recognition:** 168 | [Rethink](https://github.com/valencebond/Rethinking_of_PAR), [Label2label](https://github.com/Li-Wanhua/Label2Label/tree/main/Pedestrian_Attribute), [VTB](https://github.com/cxh0519/VTB) 169 | 170 | ## Reference 171 | If you use PLIP in your research, please cite it with the following BibTeX entry: 172 | ``` 173 | @inproceedings{ 174 | zuo2024plip, 175 | title={PLIP: Language-Image Pre-training for Person Representation Learning}, 176 | author={Jialong Zuo and Jiahao Hong and Feng Zhang and Changqian Yu and Hanyu Zhou and Changxin Gao and Nong Sang and Jingdong Wang}, 177 | booktitle={The Thirty-eighth Annual Conference on Neural Information Processing Systems}, 178 | year={2024} 179 | } 180 | ``` 181 | -------------------------------------------------------------------------------- /assets/SYNTH-PEDES.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zplusdragon/PLIP/09890e33c52dfbf25068d3ed8a07d79179df98ee/assets/SYNTH-PEDES.png -------------------------------------------------------------------------------- /assets/abstract.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zplusdragon/PLIP/09890e33c52dfbf25068d3ed8a07d79179df98ee/assets/abstract.png -------------------------------------------------------------------------------- /assets/examples.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Zplusdragon/PLIP/09890e33c52dfbf25068d3ed8a07d79179df98ee/assets/examples.png -------------------------------------------------------------------------------- /checkpoints/readme.txt: -------------------------------------------------------------------------------- 1 | You should put the checkpoints file here. -------------------------------------------------------------------------------- /data/readme.txt: -------------------------------------------------------------------------------- 1 | The datasets for training or testing should be put here. 2 | -------------------------------------------------------------------------------- /dataset_split.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | def TrainValidTest_split(path,train_path,test_path,valid_path): 4 | with open(path,"r") as f: 5 | dataset = json.load(f) 6 | train_output=[] 7 | test_output=[] 8 | valid_output=[] 9 | for i in range(len(dataset)): 10 | data = dataset[i] 11 | split = data["split"] 12 | if split == "train": 13 | train_output.append(data) 14 | elif split =="test": 15 | test_output.append(data) 16 | else: 17 | valid_output.append(data) 18 | if (i+1) % 100 == 0: 19 | print("{}/{} completed".format(i+1,len(dataset))) 20 | print("The train_set capacity:{}".format(len(train_output))) 21 | print("The test_set capacity:{}".format(len(test_output))) 22 | print("The valid_set capacity:{}".format(len(valid_output))) 23 | with open(train_path,"w") as f : 24 | json.dump(train_output,f,indent=4) 25 | with open(test_path,"w") as f : 26 | json.dump(test_output,f,indent=4) 27 | with open(valid_path,"w") as f : 28 | json.dump(valid_output,f,indent=4) 29 | 30 | if __name__ =="__main__": 31 | train_path = "data/CUHK-PEDES/CUHK-PEDES-train.json" 32 | test_path = "data/CUHK-PEDES/CUHK-PEDES-test.json" 33 | valid_path = "data/CUHK-PEDES/CUHK-PEDES-valid.json" 34 | dataset_path = "data/CUHK-PEDES/reid_raw.json" 35 | TrainValidTest_split(dataset_path, train_path, test_path,valid_path) 36 | 37 | train_path = "data/ICFG-PEDES/ICFG-PEDES-train.json" 38 | test_path = "data/ICFG-PEDES/ICFG-PEDES-test.json" 39 | valid_path = "data/ICFG-PEDES/ICFG-PEDES-valid.json" 40 | dataset_path = "data/ICFG-PEDES/ICFG_PEDES.json" 41 | TrainValidTest_split(dataset_path, train_path, test_path, valid_path) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # This file may be used to create an environment using: 2 | # $ conda create --name --file 3 | # platform: linux-64 4 | _libgcc_mutex=0.1=main.conda 5 | _openmp_mutex=5.1=1_gnu.conda 6 | blas=1.0=mkl.conda 7 | brotli=1.0.9=he6710b0_2.conda 8 | brotlipy=0.7.0=py310h7f8727e_1002.conda 9 | bzip2=1.0.8=h7b6447c_0.conda 10 | ca-certificates=2023.01.10=h06a4308_0.conda 11 | certifi=2022.12.7=py310h06a4308_0.conda 12 | cffi=1.15.0=py310hd667e15_1.conda 13 | charset-normalizer=2.0.4=pyhd3eb1b0_0.conda 14 | cryptography=37.0.1=py310h9ce1e76_0.conda 15 | cudatoolkit=11.3.1=h2bc3f7f_2.conda 16 | cycler=0.11.0=pyhd3eb1b0_0.conda 17 | dbus=1.13.18=hb2f20db_0.conda 18 | expat=2.4.4=h295c915_0.conda 19 | ffmpeg=4.3=hf484d3e_0 20 | flit-core=3.8.0=py310h06a4308_0.conda 21 | fontconfig=2.13.1=h6c09931_0.conda 22 | fonttools=4.25.0=pyhd3eb1b0_0.conda 23 | freetype=2.11.0=h70c0345_0.conda 24 | giflib=5.2.1=h7b6447c_0.conda 25 | glib=2.69.1=h4ff587b_1.conda 26 | gmp=6.2.1=h295c915_3.conda 27 | 
gnutls=3.6.15=he1e5248_0.conda 28 | gst-plugins-base=1.14.0=h8213a91_2.conda 29 | gstreamer=1.14.0=h28cd5cc_2.conda 30 | icu=58.2=he6710b0_3.conda 31 | idna=3.3=pyhd3eb1b0_0.conda 32 | imageio=2.9.0=pyhd3eb1b0_0.conda 33 | intel-openmp=2021.4.0=h06a4308_3561.conda 34 | jpeg=9e=h7f8727e_0.conda 35 | kiwisolver=1.4.2=py310h295c915_0.conda 36 | lame=3.100=h7b6447c_0.conda 37 | lcms2=2.12=h3be6417_0.conda 38 | ld_impl_linux-64=2.38=h1181459_1.conda 39 | libffi=3.3=he6710b0_2.conda 40 | libgcc-ng=11.2.0=h1234567_1.conda 41 | libgomp=11.2.0=h1234567_1.conda 42 | libiconv=1.16=h7f8727e_2.conda 43 | libidn2=2.3.2=h7f8727e_0.conda 44 | libpng=1.6.37=hbc83047_0.conda 45 | libstdcxx-ng=11.2.0=h1234567_1.conda 46 | libtasn1=4.16.0=h27cfd23_0.conda 47 | libtiff=4.2.0=h2818925_1.conda 48 | libunistring=0.9.10=h27cfd23_0.conda 49 | libuuid=1.0.3=h7f8727e_2.conda 50 | libwebp=1.2.2=h55f646e_0.conda 51 | libwebp-base=1.2.2=h7f8727e_0.conda 52 | libxcb=1.15=h7f8727e_0.conda 53 | libxml2=2.9.14=h74e7548_0.conda 54 | lz4-c=1.9.3=h295c915_1.conda 55 | matplotlib=3.5.1=py310h06a4308_1.conda 56 | matplotlib-base=3.5.1=py310ha18d171_1.conda 57 | mkl=2021.4.0=h06a4308_640.conda 58 | mkl-service=2.4.0=py310h7f8727e_0.conda 59 | mkl_fft=1.3.1=py310hd6ae3a3_0.conda 60 | mkl_random=1.2.2=py310h00e6091_0.conda 61 | munkres=1.1.4=py_0.conda 62 | ncurses=6.3=h5eee18b_3.conda 63 | nettle=3.7.3=hbbd107a_1.conda 64 | numpy=1.22.3=py310hfa59a62_0.conda 65 | numpy-base=1.22.3=py310h9585f30_0.conda 66 | openh264=2.1.1=h4ff587b_0.conda 67 | openssl=1.1.1t=h7f8727e_0.conda 68 | packaging=21.3=pyhd3eb1b0_0.conda 69 | pcre=8.45=h295c915_0.conda 70 | pillow=9.2.0=py310hace64e9_1.conda 71 | prettytable=3.5.0=py310h06a4308_0.conda 72 | pycparser=2.21=pyhd3eb1b0_0.conda 73 | pyopenssl=22.0.0=pyhd3eb1b0_0.conda 74 | pyparsing=3.0.4=pyhd3eb1b0_0.conda 75 | pyqt=5.9.2=py310h295c915_6.conda 76 | pysocks=1.7.1=py310h06a4308_0.conda 77 | python=3.10.4=h12debd9_0 78 | python-dateutil=2.8.2=pyhd3eb1b0_0.conda 79 | pytorch=1.12.0=py3.10_cuda11.3_cudnn8.3.2_0 80 | pytorch-mutex=1.0=cuda 81 | qt=5.9.7=h5867ecd_1.conda 82 | readline=8.1.2=h7f8727e_1.conda 83 | requests=2.28.1=py310h06a4308_0.conda 84 | setuptools=61.2.0=py310h06a4308_0.conda 85 | sip=4.19.13=py310h295c915_0.conda 86 | six=1.16.0=pyhd3eb1b0_1.conda 87 | sqlite=3.38.5=hc218d9a_0.conda 88 | tk=8.6.12=h1ccaba5_0.conda 89 | torchaudio=0.12.0=py310_cu113 90 | torchvision=0.13.0=py310_cu113 91 | tornado=6.1=py310h7f8727e_0.conda 92 | tqdm=4.64.0=py310h06a4308_0.conda 93 | typing_extensions=4.4.0=py310h06a4308_0.conda 94 | tzdata=2022a=hda174b7_0.conda 95 | urllib3=1.26.9=py310h06a4308_0.conda 96 | wcwidth=0.2.5=pyhd3eb1b0_0.conda 97 | wheel=0.37.1=pyhd3eb1b0_0.conda 98 | xz=5.2.5=h7f8727e_1.conda 99 | yaml=0.2.5=h7b6447c_0.conda 100 | zlib=1.2.12=h7f8727e_2.conda 101 | zstd=1.5.2=ha4553b6_0.conda 102 | -------------------------------------------------------------------------------- /test_dataloader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data as data 3 | import os 4 | import json 5 | from transformers import BertTokenizer 6 | import random 7 | import copy 8 | from torch.utils.data import DataLoader 9 | from PIL import Image 10 | from prefetch_generator import BackgroundGenerator 11 | import numpy as np 12 | 13 | class Dataset_test_image(data.Dataset): 14 | def __init__(self, image_path, dataset_path, transform=None): 15 | assert transform is not None, 'transform must not be None' 16 | self.impath = 
image_path 17 | self.datapath = dataset_path 18 | with open(dataset_path, 'r', encoding='utf8') as fp: 19 | self.dataset = json.load(fp) 20 | self.transform = transform 21 | print("Information about image gallery:{}".format(len(self))) 22 | 23 | def __getitem__(self, index): 24 | label = self.dataset[index]["id"] 25 | file_path = self.dataset[index]["file_path"] 26 | image = Image.open(os.path.join(self.impath, file_path)).convert('RGB') 27 | image_gt = self.transform(image) 28 | label = torch.tensor(label) 29 | return label,image_gt 30 | 31 | def __len__(self): 32 | return len(self.dataset) 33 | 34 | class Dataset_test_text(data.Dataset): 35 | def __init__(self, image_path, dataset_path): 36 | self.impath = image_path 37 | self.datapath = dataset_path 38 | with open(dataset_path, 'r', encoding='utf8') as fp: 39 | self.dataset = json.load(fp) 40 | self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') 41 | self.initial_data = [] 42 | self.caption_depart_initial() 43 | print("Information about text query:{}".format(len(self))) 44 | 45 | def caption_to_tokens(self, caption): 46 | result = self.tokenizer(caption, padding="max_length", max_length=64, truncation=True, return_tensors='pt') 47 | token, mask = result["input_ids"], result["attention_mask"] 48 | token, mask = token.squeeze(), mask.squeeze() 49 | return token, mask 50 | 51 | def caption_depart_initial(self): 52 | for i in range(len(self.dataset)): 53 | item = self.dataset[i] 54 | label = item["id"] 55 | captions_list = item["captions"] 56 | for j in range(len(captions_list)): 57 | caption = captions_list[j] 58 | self.initial_data.append([label,caption]) 59 | 60 | def __getitem__(self, index): 61 | caption = self.initial_data[index][1] 62 | label = self.initial_data[index][0] 63 | caption_tokens, masks = self.caption_to_tokens(caption) 64 | caption_tokens = torch.tensor(caption_tokens) 65 | label = torch.tensor(label) 66 | return label, caption_tokens, masks 67 | 68 | def __len__(self): 69 | return len(self.initial_data) 70 | 71 | def get_loader_test(args, transform, batch_size, num_workers): 72 | image_path = args.image_path 73 | test_path = args.test_path 74 | dataset_image = Dataset_test_image(image_path, test_path,transform=transform) 75 | dataset_text = Dataset_test_text(image_path,test_path) 76 | image_dataloader = DataLoader(dataset_image, batch_size=batch_size, shuffle=False, num_workers=num_workers) 77 | text_dataloader = DataLoader(dataset_text, batch_size=batch_size, shuffle=False, num_workers=num_workers) 78 | return image_dataloader, text_dataloader -------------------------------------------------------------------------------- /textual_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from transformers import BertModel 5 | from torch.nn import init 6 | from einops import rearrange 7 | 8 | class Textual_encoder(nn.Module): 9 | def __init__(self, encoder_type: str): 10 | super(Textual_encoder, self).__init__() 11 | self.encoder = BertModel.from_pretrained(encoder_type) 12 | unfreeze_layers = ['layer.8','layer.9','layer.10', 'layer.11', 'pooler'] 13 | for name, param in self.encoder.named_parameters(): 14 | param.requires_grad = False 15 | for ele in unfreeze_layers: 16 | if ele in name: 17 | param.requires_grad = True 18 | break 19 | 20 | def get_global_embedding(self,token,mask): 21 | x = self.encoder(input_ids=token, attention_mask=mask) 22 | pooler_output = x.pooler_output 23 | return 
pooler_output 24 | 25 | def get_local_embedding(self,token,mask): 26 | x = self.encoder(input_ids=token, attention_mask=mask) 27 | hidden_states = x.last_hidden_state 28 | return hidden_states 29 | 30 | def forward(self, token, mask): 31 | x = self.encoder(input_ids=token, attention_mask=mask) 32 | hidden_states = x.last_hidden_state 33 | pooler_output = x.pooler_output 34 | return pooler_output, hidden_states 35 | 36 | 37 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import errno 2 | import sys 3 | import os.path as osp 4 | import torch.utils.data as data 5 | import os 6 | import torch 7 | import numpy as np 8 | import random 9 | import json 10 | import torch.nn as nn 11 | from einops.layers.torch import Rearrange 12 | from transformers import get_linear_schedule_with_warmup 13 | import yaml 14 | 15 | 16 | def mask_ratio_scheduler(current_epoch, total_epochs,lower_ratio,upper_ratio,ratio_type="linear"): 17 | if ratio_type == "linear": 18 | return (current_epoch/total_epochs)*(upper_ratio-lower_ratio)+lower_ratio 19 | elif ratio_type == "square": 20 | return ((current_epoch/total_epochs)**2)*(upper_ratio-lower_ratio)+lower_ratio 21 | elif ratio_type == "squareroot": 22 | return ((current_epoch / total_epochs) ** (1/2)) * (upper_ratio - lower_ratio) + lower_ratio 23 | 24 | def setup_seed(seed): 25 | print('The seed is {}'.format(seed)) 26 | torch.manual_seed(seed) 27 | torch.cuda.manual_seed(seed) 28 | torch.cuda.manual_seed_all(seed) 29 | np.random.seed(seed) # Numpy module. 30 | random.seed(seed) # Python random module. 31 | torch.backends.cudnn.benchmark = False 32 | torch.backends.cudnn.deterministic = True 33 | 34 | def image_to_patch(image,patch_size): 35 | patch_height, patch_width = (patch_size,patch_size) 36 | patch_embedding = nn.Sequential( 37 | Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_height, p2=patch_width)) 38 | x = patch_embedding(image) 39 | return x 40 | 41 | def lr_scheduler(optimizer, args,loader_length): 42 | if args.lr_decay_type == "get_linear_schedule_with_warmup": 43 | scheduler = get_linear_schedule_with_warmup( 44 | optimizer, num_warmup_steps=int(args.epochs * loader_length * 0.1/args.accumulation_steps), 45 | num_training_steps=int(args.epochs * loader_length/args.accumulation_steps) 46 | ) 47 | print("lr_scheduler is get_linear_schedule_with_warmup") 48 | else: 49 | if '_' in args.epoches_decay: 50 | epoches_list = args.epoches_decay.split('_') 51 | epoches_list = [int(e) for e in epoches_list] 52 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, epoches_list, gamma=args.lr_decay_ratio) 53 | print("lr_scheduler is MultiStepLR") 54 | else: 55 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, int(args.epoches_decay), gamma=args.lr_decay_ratio) 56 | print("lr_scheduler is StepLR") 57 | return scheduler 58 | 59 | def load_checkpoint(model,resume): 60 | start_epoch=0 61 | if os.path.isfile(resume): 62 | checkpoint = torch.load(resume) 63 | # checkpoint= torch.load(resume, map_location='cuda:0') 64 | start_epoch = checkpoint['epoch'] 65 | model.load_state_dict(checkpoint['state_dict']) 66 | print('Load checkpoint at epoch %d.' 
% (start_epoch)) 67 | return start_epoch,model 68 | 69 | def save_checkpoint(state, epoch, dst): 70 | if not os.path.exists(dst): 71 | os.makedirs(dst) 72 | filename = os.path.join(dst, str(epoch)) + '.pth.tar' 73 | torch.save(state, filename) 74 | 75 | def gradual_warmup(epoch,optimizer,epochs): 76 | if epoch == 0: 77 | warmup_percent_done = (epoch + 1) / epochs 78 | else: 79 | warmup_percent_done = (epoch + 1) / epoch 80 | for param_group in optimizer.param_groups: 81 | param_group['lr'] = param_group['lr']*warmup_percent_done 82 | return optimizer 83 | 84 | def compute_topk(query, gallery, target_query, target_gallery, k=[1,10], reverse=False): 85 | result = [] 86 | query = query / (query.norm(dim=1,keepdim=True)+1e-12) 87 | gallery = gallery / (gallery.norm(dim=1,keepdim=True)+1e-12) 88 | sim_cosine = torch.matmul(query, gallery.t()) 89 | result.extend(topk(sim_cosine, target_gallery, target_query, k)) 90 | if reverse: 91 | result.extend(topk(sim_cosine, target_query, target_gallery, k, dim=0)) 92 | return result 93 | 94 | def topk(sim, target_gallery, target_query, k=[1,10], dim=1): 95 | result = [] 96 | maxk = max(k) 97 | size_total = len(target_gallery) 98 | _, pred_index = sim.topk(maxk, dim, True, True) 99 | pred_labels = target_gallery[pred_index] 100 | if dim == 1: 101 | pred_labels = pred_labels.t() 102 | correct = pred_labels.eq(target_query.view(1,-1).expand_as(pred_labels)) 103 | 104 | for topk in k: 105 | correct_k = torch.sum(correct[:topk], dim=0) 106 | correct_k = torch.sum(correct_k > 0).float() 107 | result.append(correct_k * 100 / size_total) 108 | return result 109 | 110 | def test_map(query_feature,query_label,gallery_feature, gallery_label): 111 | query_feature = query_feature / (query_feature.norm(dim=1, keepdim=True) + 1e-12) 112 | gallery_feature = gallery_feature / (gallery_feature.norm(dim=1, keepdim=True) + 1e-12) 113 | CMC = torch.IntTensor(len(gallery_label)).zero_() 114 | ap = 0.0 115 | for i in range(len(query_label)): 116 | ap_tmp, CMC_tmp = evaluate(query_feature[i], query_label[i], gallery_feature, gallery_label) 117 | 118 | if CMC_tmp[0] == -1: 119 | continue 120 | CMC = CMC + CMC_tmp 121 | ap += ap_tmp 122 | CMC = CMC.float() 123 | CMC = CMC / len(query_label) 124 | print('Rank@1:%f Rank@5:%f Rank@10:%f mAP:%f' % (CMC[0], CMC[4], CMC[9], ap / len(query_label))) 125 | return CMC[0], CMC[4], CMC[9], ap / len(query_label) 126 | 127 | def evaluate(qf, ql, gf, gl): 128 | query = qf.view(-1, 1) 129 | score = torch.mm(gf, query) 130 | score = score.squeeze(1).cpu() 131 | score = score.numpy() 132 | index = np.argsort(score) 133 | index = index[::-1] 134 | gl=gl.cuda().data.cpu().numpy() 135 | ql=ql.cuda().data.cpu().numpy() 136 | query_index = np.argwhere(gl == ql) 137 | CMC_tmp = compute_mAP(index, query_index) 138 | return CMC_tmp 139 | 140 | def compute_mAP(index, good_index): 141 | ap = 0 142 | cmc = torch.IntTensor(len(index)).zero_() 143 | if good_index.size == 0: # if empty 144 | cmc[0] = -1 145 | return ap, cmc 146 | # find good_index index 147 | ngood = len(good_index) 148 | mask = np.in1d(index, good_index) 149 | rows_good = np.argwhere(mask == True) 150 | rows_good = rows_good.flatten() 151 | 152 | cmc[rows_good[0]:] = 1 153 | for i in range(ngood): 154 | d_recall = 1.0 / ngood 155 | precision = (i + 1) * 1.0 / (rows_good[i] + 1) 156 | if rows_good[i] != 0: 157 | old_precision = i * 1.0 / rows_good[i] 158 | else: 159 | old_precision = 1.0 160 | ap = ap + d_recall * (old_precision + precision) / 2 161 | 162 | return ap, cmc 163 | 164 | def 
mkdir_if_missing(directory): 165 | if not osp.exists(directory): 166 | try: 167 | os.makedirs(directory) 168 | except OSError as e: 169 | if e.errno != errno.EEXIST: 170 | raise 171 | 172 | def PrintInformation(args,model): 173 | # with open(args.annotation_path,"r") as f: 174 | # dataset = json.load(f) 175 | # num = len(dataset) 176 | print("The image model is: {}".format(args.img_backbone)) 177 | print("The language model is: {}".format(args.txt_backbone)) 178 | print('Number of model parameters: {}'.format( 179 | sum([p.data.nelement() for p in model.parameters()]))) 180 | print("The VAP task is {}".format(args.vap_type)) 181 | print("Checkpoints are saved in: {}".format(args.name)) 182 | #print("The number of training samples is {}".format(12498736)) 183 | print("The original learning rate is {}".format(args.lr)) 184 | print("The pretrain setting is {}".format(args.pretrain)) 185 | if args.content_mask_ratio == args.content_mask_ratio_upper: 186 | print("The mask ratio schedule: Constant") 187 | else: 188 | print("The mask ratio schedule:{}".format(args.mask_ratio_schedule)) 189 | print("The VLM loss :{}".format(args.VLM_loss)) 190 | print("The SIC loss :{}".format(args.SIC_loss)) 191 | 192 | def train_information_setting(args): # set up the output files that record the training information 193 | name = args.name 194 | # set some paths 195 | log_dir = args.log_dir 196 | log_dir = os.path.join(log_dir, name) 197 | if args.CONTINUE == False: 198 | sys.stdout = Logger(os.path.join(log_dir, "train_log.txt")) 199 | opt_dir = os.path.join('log', name) 200 | if not os.path.exists(opt_dir): 201 | os.makedirs(opt_dir) 202 | with open('%s/opts_train.yaml' % opt_dir, 'w') as fp: 203 | yaml.dump(vars(args), fp, default_flow_style=False) 204 | else: 205 | sys.stdout = Logger(os.path.join(log_dir, "train_log_CONTINUE.txt")) 206 | opt_dir = os.path.join('log', name) 207 | if not os.path.exists(opt_dir): 208 | os.makedirs(opt_dir) 209 | with open('%s/opts_train_CONTINUE.yaml' % opt_dir, 'w') as fp: 210 | yaml.dump(vars(args), fp, default_flow_style=False) 211 | 212 | class Logger(object): 213 | """ 214 | Write console output to external text file. 215 | Code imported from https://github.com/Cysu/open-reid/blob/master/reid/utils/logging.py. 
216 | """ 217 | def __init__(self, fpath=None): 218 | self.console = sys.stdout 219 | self.file = None 220 | if fpath is not None: 221 | mkdir_if_missing(os.path.dirname(fpath)) 222 | self.file = open(fpath, 'w') 223 | 224 | def __del__(self): 225 | self.close() 226 | 227 | def __enter__(self): 228 | pass 229 | 230 | def __exit__(self, *args): 231 | self.close() 232 | 233 | def write(self, msg): 234 | self.console.write(msg) 235 | if self.file is not None: 236 | self.file.write(msg) 237 | 238 | def flush(self): 239 | self.console.flush() 240 | if self.file is not None: 241 | self.file.flush() 242 | os.fsync(self.file.fileno()) 243 | 244 | def close(self): 245 | self.console.close() 246 | if self.file is not None: 247 | self.file.close() 248 | 249 | class Vocabulary(object): 250 | """Simple vocabulary wrapper.""" 251 | def __init__(self): 252 | self.word2idx = {} 253 | self.idx2word = {} 254 | self.idx = 0 255 | 256 | def add_word(self, word): 257 | if not word in self.word2idx: 258 | self.word2idx[word] = self.idx 259 | self.idx2word[self.idx] = word 260 | self.idx += 1 261 | 262 | def __call__(self, word): 263 | if not word in self.word2idx: 264 | return self.word2idx['[UNK]'] 265 | return self.word2idx[word] 266 | 267 | def __len__(self): 268 | return len(self.word2idx) 269 | 270 | def model_resume_setting(model,Loss,args): 271 | checkpoint_dir = os.path.join(args.checkpoint_dir, args.name) 272 | model_files_list = os.listdir(checkpoint_dir) 273 | model_files_list = [int(x[:-8]) for x in model_files_list if x[-3:] == 'tar'] 274 | current_epoch = max(model_files_list) 275 | print("Continue the training at Epoch{}".format(current_epoch)) 276 | assert current_epoch>=10,'Current epoch must be greater than the warm_up epoch!' 277 | if current_epoch >= int(args.epoches_decay): 278 | args.lr = args.lr*args.lr_decay_ratio 279 | args.epochs = args.epochs - current_epoch 280 | args.warm_epoch = 0 281 | args.epoches_decay = '999' 282 | if current_epoch < int(args.epoches_decay): 283 | args.lr = args.lr 284 | args.epochs = args.epochs - current_epoch 285 | args.warm_epoch = 0 286 | args.epoches_decay = str(int(args.epoches_decay)-current_epoch) 287 | model_file = os.path.join(checkpoint_dir, str(current_epoch) + ".pth.tar") 288 | checkpoint = torch.load(model_file, map_location="cpu") 289 | model.image_encoder.load_state_dict(checkpoint["ImgEncoder_state_dict"]) 290 | model.text_encoder.load_state_dict(checkpoint["TxtEncoder_state_dict"]) 291 | Loss.load_state_dict(checkpoint["Decoder_state_dict"]) 292 | return current_epoch -------------------------------------------------------------------------------- /visual_model.py: -------------------------------------------------------------------------------- 1 | from torch import nn, einsum 2 | import torch.nn.functional as F 3 | from einops import rearrange, repeat 4 | from einops.layers.torch import Rearrange 5 | import os, sys 6 | from torch.nn import init 7 | import torchvision.models as models 8 | from collections import OrderedDict 9 | from typing import Tuple, Union 10 | import numpy as np 11 | import torch 12 | 13 | #***************************************************************************************** 14 | #MResNet,visual encoder. 15 | #***************************************************************************************** 16 | class Bottleneck(nn.Module): 17 | expansion = 4 18 | def __init__(self, inplanes, planes, stride=1): 19 | super().__init__() 20 | # all conv layers have stride 1. 
an avgpool is performed after the second convolution when stride > 1 21 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 22 | self.bn1 = nn.BatchNorm2d(planes) 23 | self.relu1 = nn.ReLU(inplace=True) 24 | 25 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) 26 | self.bn2 = nn.BatchNorm2d(planes) 27 | self.relu2 = nn.ReLU(inplace=True) 28 | 29 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() 30 | 31 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 32 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 33 | self.relu3 = nn.ReLU(inplace=True) 34 | 35 | self.downsample = None 36 | self.stride = stride 37 | 38 | if stride > 1 or inplanes != planes * Bottleneck.expansion: 39 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 40 | self.downsample = nn.Sequential(OrderedDict([ 41 | ("-1", nn.AvgPool2d(stride)), 42 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), 43 | ("1", nn.BatchNorm2d(planes * self.expansion)) 44 | ])) 45 | 46 | def forward(self, x: torch.Tensor): 47 | identity = x 48 | 49 | out = self.relu1(self.bn1(self.conv1(x))) 50 | out = self.relu2(self.bn2(self.conv2(out))) 51 | out = self.avgpool(out) 52 | out = self.bn3(self.conv3(out)) 53 | 54 | if self.downsample is not None: 55 | identity = self.downsample(x) 56 | 57 | out = out+identity 58 | out = self.relu3(out) 59 | return out 60 | 61 | class AttentionPool2d(nn.Module): 62 | def __init__(self, spacial_dim_x: int,spacial_dim_y: int, embed_dim: int, num_heads: int, output_dim: int = None): 63 | super().__init__() 64 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim_x * spacial_dim_y + 1, embed_dim) / embed_dim ** 0.5) 65 | self.k_proj = nn.Linear(embed_dim, embed_dim) 66 | self.q_proj = nn.Linear(embed_dim, embed_dim) 67 | self.v_proj = nn.Linear(embed_dim, embed_dim) 68 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 69 | self.num_heads = num_heads 70 | 71 | def forward(self, x): 72 | x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC 73 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 74 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 75 | x, _ = F.multi_head_attention_forward( 76 | query=x[:1], key=x, value=x, 77 | embed_dim_to_check=x.shape[-1], 78 | num_heads=self.num_heads, 79 | q_proj_weight=self.q_proj.weight, 80 | k_proj_weight=self.k_proj.weight, 81 | v_proj_weight=self.v_proj.weight, 82 | in_proj_weight=None, 83 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 84 | bias_k=None, 85 | bias_v=None, 86 | add_zero_attn=False, 87 | dropout_p=0, 88 | out_proj_weight=self.c_proj.weight, 89 | out_proj_bias=self.c_proj.bias, 90 | use_separate_proj_weight=True, 91 | training=self.training, 92 | need_weights=True 93 | ) 94 | return x.squeeze(0) 95 | 96 | class Image_encoder_ModifiedResNet(nn.Module): 97 | """ 98 | A ResNet class that is similar to torchvision's but contains the following changes: 99 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 
100 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 101 | - The final pooling layer is a QKV attention instead of an average pool 102 | """ 103 | def __init__(self, layers, output_dim, heads, input_resolution=[256,128], width=64): 104 | super().__init__() 105 | self.output_dim = output_dim 106 | self.input_resolution = input_resolution 107 | 108 | # the 3-layer stem 109 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) 110 | self.bn1 = nn.BatchNorm2d(width // 2) 111 | self.relu1 = nn.ReLU(inplace=True) 112 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) 113 | self.bn2 = nn.BatchNorm2d(width // 2) 114 | self.relu2 = nn.ReLU(inplace=True) 115 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) 116 | self.bn3 = nn.BatchNorm2d(width) 117 | self.relu3 = nn.ReLU(inplace=True) 118 | self.avgpool = nn.AvgPool2d(2) 119 | 120 | # residual layers 121 | self._inplanes = width # this is a *mutable* variable used during construction 122 | self.layer1 = self._make_layer(width, layers[0]) 123 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2) 124 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2) 125 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2) 126 | 127 | embed_dim = width * 32 # the ResNet feature dimension 128 | self.attnpool = AttentionPool2d(input_resolution[0] // 32,input_resolution[1] // 32, embed_dim, heads, output_dim) 129 | self.initialize_parameters() 130 | 131 | def _make_layer(self, planes, blocks, stride=1): 132 | layers = [Bottleneck(self._inplanes, planes, stride)] 133 | 134 | self._inplanes = planes * Bottleneck.expansion 135 | for _ in range(1, blocks): 136 | layers.append(Bottleneck(self._inplanes, planes)) 137 | 138 | return nn.Sequential(*layers) 139 | 140 | def initialize_parameters(self): 141 | if self.attnpool is not None: 142 | std = self.attnpool.c_proj.in_features ** -0.5 143 | nn.init.normal_(self.attnpool.q_proj.weight, std=std) 144 | nn.init.normal_(self.attnpool.k_proj.weight, std=std) 145 | nn.init.normal_(self.attnpool.v_proj.weight, std=std) 146 | nn.init.normal_(self.attnpool.c_proj.weight, std=std) 147 | 148 | for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]: 149 | for name, param in resnet_block.named_parameters(): 150 | if name.endswith("bn3.weight"): 151 | nn.init.zeros_(param) 152 | 153 | def forward(self, x): 154 | def stem(x): 155 | x = self.relu1(self.bn1(self.conv1(x))) 156 | x = self.relu2(self.bn2(self.conv2(x))) 157 | x = self.relu3(self.bn3(self.conv3(x))) 158 | x = self.avgpool(x) 159 | return x 160 | 161 | x = x.type(self.conv1.weight.dtype) 162 | x = stem(x) 163 | x1 = self.layer1(x) 164 | x2 = self.layer2(x1) 165 | x3 = self.layer3(x2) 166 | x4 = self.layer4(x3) 167 | feat = self.attnpool(x4) 168 | return feat,x1,x2,x3,x4 169 | 170 | 171 | 172 | -------------------------------------------------------------------------------- /zs_infer.py: -------------------------------------------------------------------------------- 1 | import torchvision.transforms as transforms 2 | import torch 3 | import yaml 4 | from utils import * 5 | import time 6 | import os 7 | import shutil 8 | import torch.backends.cudnn as cudnn 9 | from test_dataloader import get_loader_test 10 | import pickle 11 | import argparse 12 | from PLIPmodel import Create_PLIP_Model 13 | 14 | def Test_parse_args(): 15 | parser = argparse.ArgumentParser() 16 | 
parser.add_argument('--model_path', type=str, default='checkpoints/PLIP_MRN50.pth.tar') 17 | parser.add_argument('--image_path', type=str, default='data/CUHK-PEDES/imgs') 18 | parser.add_argument('--test_path', type=str, 19 | default='data/CUHK-PEDES/CUHK-PEDES-test.json', 20 | help='path for test annotation json file') 21 | # *********************************************************************************************************************** 22 | # Settings for the model backbone type and parameters 23 | parser.add_argument('--plip_model', type=str, default='MResNet_BERT') 24 | parser.add_argument('--img_backbone', type=str, default='ModifiedResNet', 25 | help="ResNet:xxx, ModifiedResNet, ViT:xxx") 26 | parser.add_argument('--txt_backbone', type=str, default="bert-base-uncased") 27 | parser.add_argument('--img_dim', type=int, default=768, help='dimension of image embedding vectors') 28 | parser.add_argument('--text_dim', type=int, default=768, help='dimension of text embedding vectors') 29 | parser.add_argument('--layers', type=list, default=[3, 4, 6, 3], help='Just for ModifiedResNet model') 30 | parser.add_argument('--heads', type=int, default=8, help='Just for ModifiedResNet model') 31 | parser.add_argument('--height', type=int, default=256) 32 | parser.add_argument('--width', type=int, default=128) 33 | 34 | # Hyperparameter settings 35 | parser.add_argument('--batch_size', type=int, default=128) 36 | parser.add_argument('--num_workers', type=int, default=4) 37 | parser.add_argument('--device',type=str,default="cuda:0") 38 | parser.add_argument('--feature_size', type=int, default=768) 39 | args = parser.parse_args() 40 | return args 41 | 42 | def test(image_test_loader,text_test_loader, model): 43 | # switch to evaluate mode 44 | model = model.eval() 45 | device = next(model.parameters()).device 46 | 47 | qids, gids, qfeats, gfeats = [], [], [], [] 48 | # text 49 | for pid, caption,mask in text_test_loader: 50 | caption = caption.to(device) 51 | mask = mask.to(device) 52 | with torch.no_grad(): 53 | text_feat = model.get_text_global_embedding(caption,mask) 54 | qids.append(pid.view(-1)) # flatten 55 | qfeats.append(text_feat) 56 | qids = torch.cat(qids, 0) 57 | qfeats = torch.cat(qfeats, 0) 58 | 59 | # image 60 | for pid, img in image_test_loader: 61 | img = img.to(device) 62 | with torch.no_grad(): 63 | img_feat = model.get_image_embeddings(img) 64 | gids.append(pid.view(-1)) # flatten 65 | gfeats.append(img_feat) 66 | gids = torch.cat(gids, 0) 67 | gfeats = torch.cat(gfeats, 0) 68 | ac_top1_t2i, ac_top5_t2i, ac_top10_t2i, mAP = test_map(qfeats, qids, gfeats, gids) 69 | return ac_top1_t2i, ac_top5_t2i, ac_top10_t2i, mAP 70 | 71 | 72 | def Test_main(args): 73 | device = args.device 74 | transform = transforms.Compose([ 75 | transforms.Resize((256,128), interpolation=3), 76 | transforms.ToTensor(), 77 | transforms.Normalize((0.357, 0.323, 0.328), 78 | (0.252, 0.242, 0.239)) 79 | ]) 80 | 81 | image_test_loader,text_test_loader = get_loader_test(args, transform, args.batch_size,args.num_workers) 82 | model = Create_PLIP_Model(args).to(device) 83 | 84 | model_file = args.model_path 85 | print(model_file) 86 | if os.path.isdir(model_file): 87 | raise RuntimeError('args.model_path should point to a checkpoint file, not a directory') 88 | checkpoint = torch.load(model_file,map_location='cpu') 89 | model.image_encoder.load_state_dict(checkpoint["ImgEncoder_state_dict"]) 90 | model.text_encoder.load_state_dict(checkpoint["TxtEncoder_state_dict"]) 91 | ac_top1_t2i, ac_top5_t2i, ac_top10_t2i, mAP = test(image_test_loader,text_test_loader, model) 92 | 93 | print('R@1:{:.5f} R@5:{:.5f} R@10:{:.5f} 
mAP:{:.5f}'.format(ac_top1_t2i, ac_top5_t2i, ac_top10_t2i, mAP)) 94 | 95 | import warnings 96 | warnings.filterwarnings("ignore") 97 | if __name__ == '__main__': 98 | args = Test_parse_args() 99 | Test_main(args) 100 | --------------------------------------------------------------------------------
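As a usage note for `zs_infer.py` above: the same zero-shot evaluation can also be driven from a short script instead of the CLI entry point. This is a minimal sketch that only reuses functions defined in this repository (`Test_parse_args`, `Create_PLIP_Model`, `get_loader_test`, `test`) and assumes the default checkpoint path from the argument parser:
```
import torch
import torchvision.transforms as transforms
from PLIPmodel import Create_PLIP_Model
from test_dataloader import get_loader_test
from zs_infer import Test_parse_args, test

# Parse the defaults defined in zs_infer.py (paths can be overridden on the command line).
args = Test_parse_args()

# Same preprocessing as Test_main: resize to 256x128 and normalize.
transform = transforms.Compose([
    transforms.Resize((args.height, args.width), interpolation=3),
    transforms.ToTensor(),
    transforms.Normalize((0.357, 0.323, 0.328), (0.252, 0.242, 0.239)),
])
image_loader, text_loader = get_loader_test(args, transform, args.batch_size, args.num_workers)

# Build the model and load the released pre-trained weights.
model = Create_PLIP_Model(args).to(args.device)
checkpoint = torch.load(args.model_path, map_location="cpu")
model.image_encoder.load_state_dict(checkpoint["ImgEncoder_state_dict"])
model.text_encoder.load_state_dict(checkpoint["TxtEncoder_state_dict"])

r1, r5, r10, mAP = test(image_loader, text_loader, model)
print("R@1:{:.5f} R@5:{:.5f} R@10:{:.5f} mAP:{:.5f}".format(r1, r5, r10, mAP))
```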