├── Downstreams ├── CMPM-C │ ├── dataloader.py │ ├── dataset_split.py │ ├── loss.py │ ├── model.py │ ├── optimizer.py │ ├── test.py │ ├── train.py │ └── utils.py └── readme.txt ├── LICENSE ├── PLIPmodel.py ├── README.md ├── assets ├── SYNTH-PEDES.png ├── abstract.png └── examples.png ├── checkpoints └── readme.txt ├── data └── readme.txt ├── dataset_split.py ├── requirements.txt ├── test_dataloader.py ├── textual_model.py ├── utils.py ├── visual_model.py └── zs_infer.py /Downstreams/CMPM-C/dataloader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data as data 3 | import os 4 | import json 5 | from transformers import BertTokenizer 6 | import random 7 | import copy 8 | from torch.utils.data import DataLoader 9 | from PIL import Image 10 | from prefetch_generator import BackgroundGenerator 11 | import numpy as np 12 | 13 | class DataLoaderX(DataLoader): 14 | def __iter__(self): 15 | return BackgroundGenerator(super().__iter__()) 16 | 17 | class Dataset(data.Dataset): 18 | def __init__(self, image_path, dataset_path, transform=None ,flip_transform=None): 19 | assert transform is not None, 'transform must not be None' 20 | self.impath = image_path 21 | self.datapath = dataset_path 22 | with open(dataset_path, 'r', encoding='utf8') as fp: 23 | self.dataset = json.load(fp) 24 | self.transform = transform 25 | self.flip_transform = flip_transform 26 | self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') 27 | 28 | def caption_to_tokens(self, caption): 29 | result = self.tokenizer(caption, padding="max_length", max_length=64, truncation=True, return_tensors='pt') 30 | token, mask = result["input_ids"], result["attention_mask"] 31 | token, mask = token.squeeze(), mask.squeeze() 32 | return token,mask 33 | 34 | def __getitem__(self, index): 35 | caption = self.dataset[index]["captions"][0] 36 | label = self.dataset[index]["id"] 37 | file_path = self.dataset[index]["file_path"] 38 | image = Image.open(os.path.join(self.impath, file_path)).convert('RGB') 39 | image_gt = self.transform(image) 40 | tokens,masks = self.caption_to_tokens(caption) 41 | tokens = torch.tensor(tokens) 42 | label = torch.tensor(label) 43 | if self.flip_transform == None: 44 | return image_gt, tokens, masks, label 45 | else: 46 | return image_gt, self.flip_transform(image), tokens, masks, label 47 | 48 | def __len__(self): 49 | return len(self.dataset) 50 | 51 | class Dataset_test_image(data.Dataset): 52 | def __init__(self, image_path, dataset_path, transform=None): 53 | assert transform is not None, 'transform must not be None' 54 | self.impath = image_path 55 | self.datapath = dataset_path 56 | with open(dataset_path, 'r', encoding='utf8') as fp: 57 | self.dataset = json.load(fp) 58 | self.transform = transform 59 | print("Information about image gallery:{}".format(len(self))) 60 | 61 | def __getitem__(self, index): 62 | label = self.dataset[index]["id"] 63 | file_path = self.dataset[index]["file_path"] 64 | image = Image.open(os.path.join(self.impath, file_path)).convert('RGB') 65 | image_gt = self.transform(image) 66 | label = torch.tensor(label) 67 | return label,image_gt 68 | 69 | def __len__(self): 70 | return len(self.dataset) 71 | 72 | class Dataset_test_text(data.Dataset): 73 | def __init__(self, image_path, dataset_path): 74 | self.impath = image_path 75 | self.datapath = dataset_path 76 | with open(dataset_path, 'r', encoding='utf8') as fp: 77 | self.dataset = json.load(fp) 78 | self.tokenizer = 
BertTokenizer.from_pretrained('bert-base-uncased') 79 | self.initial_data = [] 80 | self.caption_depart_initial() 81 | print("Information about text query:{}".format(len(self))) 82 | 83 | def __len__(self): 84 | return len(self.initial_data) 85 | 86 | def caption_to_tokens(self, caption): 87 | result = self.tokenizer(caption, padding="max_length", max_length=64, truncation=True, return_tensors='pt') 88 | token, mask = result["input_ids"], result["attention_mask"] 89 | token, mask = token.squeeze(), mask.squeeze() 90 | return token, mask 91 | 92 | def caption_depart_initial(self): 93 | for i in range(len(self.dataset)): 94 | item = self.dataset[i] 95 | label = item["id"] 96 | captions_list = item["captions"] 97 | for j in range(len(captions_list)): 98 | caption = captions_list[j] 99 | self.initial_data.append([label,caption]) 100 | 101 | def __getitem__(self, index): 102 | caption = self.initial_data[index][1] 103 | label = self.initial_data[index][0] 104 | caption_tokens,masks = self.caption_to_tokens(caption) 105 | caption_tokens = torch.tensor(caption_tokens) 106 | label = torch.tensor(label) 107 | return label,caption_tokens,masks 108 | 109 | 110 | def get_loader(image_path, dataset_path,transform,flip_transform, batch_size, num_workers,distributed=False): 111 | dataset = Dataset(image_path=image_path,dataset_path=dataset_path,transform=transform,flip_transform=flip_transform) 112 | if distributed == False: 113 | dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers) 114 | else: 115 | train_sampler = torch.utils.data.distributed.DistributedSampler(dataset,shuffle=True) 116 | dataloader = DataLoaderX(dataset,batch_size=batch_size,num_workers=num_workers,pin_memory=True,sampler=train_sampler,shuffle=False) 117 | return dataloader 118 | 119 | def get_loader_test(image_path, dataset_path,transform, batch_size, num_workers): 120 | image_dataset = Dataset_test_image(image_path=image_path,dataset_path=dataset_path,transform=transform) 121 | image_dataloader = DataLoader(image_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers) 122 | text_dataset = Dataset_test_text(image_path=image_path,dataset_path=dataset_path) 123 | text_dataloader = DataLoader(text_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers) 124 | return image_dataloader,text_dataloader -------------------------------------------------------------------------------- /Downstreams/CMPM-C/dataset_split.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | def cuhk_train_depart(train_path,train_path_depart): 4 | with open(train_path) as f: 5 | dataset = json.load(f) 6 | output = [] 7 | for i in range(len(dataset)): 8 | data = dataset[i] 9 | captions = data["captions"] 10 | if len(captions)!=2: 11 | print("{}:{}".format(data["id"],captions)) 12 | for j in range(len(captions)): 13 | dict = {} 14 | dict["split"] = data["split"] 15 | dict["id"] = data["id"] 16 | dict["file_path"] = data["file_path"] 17 | dict["captions"] = [data["captions"][j]] 18 | output.append(dict) 19 | with open(train_path_depart,"w") as f: 20 | json.dump(output,f,indent=4) 21 | print("completed!") 22 | 23 | def TrainValidTest_split(path,train_path,test_path,valid_path): 24 | with open(path,"r") as f: 25 | dataset = json.load(f) 26 | train_output=[] 27 | test_output=[] 28 | valid_output=[] 29 | for i in range(len(dataset)): 30 | data = dataset[i] 31 | split = data["split"] 32 | if split == "train": 33 | train_output.append(data) 34 | 
elif split =="test": 35 | test_output.append(data) 36 | else: 37 | valid_output.append(data) 38 | if (i+1) % 100 == 0: 39 | print("{}/{} completed".format(i+1,len(dataset))) 40 | print("The train_set capacity:{}".format(len(train_output))) 41 | print("The test_set capacity:{}".format(len(test_output))) 42 | print("The valid_set capacity:{}".format(len(valid_output))) 43 | with open(train_path,"w") as f : 44 | json.dump(train_output,f,indent=4) 45 | with open(test_path,"w") as f : 46 | json.dump(test_output,f,indent=4) 47 | with open(valid_path,"w") as f : 48 | json.dump(valid_output,f,indent=4) 49 | 50 | if __name__ =="__main__": 51 | train_path = "data/CUHK-PEDES/CUHK-PEDES-train.json" 52 | train_path_depart = "data/CUHK-PEDES/CUHK-PEDES-train-depart.json" 53 | test_path = "data/CUHK-PEDES/CUHK-PEDES-test.json" 54 | valid_path = "data/CUHK-PEDES/CUHK-PEDES-valid.json" 55 | dataset_path = "data/CUHK-PEDES/reid_raw.json" 56 | TrainValidTest_split(dataset_path, train_path, test_path,valid_path) 57 | cuhk_train_depart(train_path,train_path_depart) 58 | 59 | train_path = "data/ICFG-PEDES/ICFG-PEDES-train.json" 60 | test_path = "data/ICFG-PEDES/ICFG-PEDES-test.json" 61 | valid_path = "data/ICFG-PEDES/ICFG-PEDES-valid.json" 62 | dataset_path = "data/ICFG-PEDES/ICFG_PEDES.json" 63 | TrainValidTest_split(dataset_path, train_path, test_path, valid_path) -------------------------------------------------------------------------------- /Downstreams/CMPM-C/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn.parameter import Parameter 5 | 6 | class Loss_calc(nn.Module): 7 | def __init__(self, args): 8 | super(Loss_calc, self).__init__() 9 | self.epsilon = args.epsilon 10 | self.W =Parameter(torch.randn(args.feature_size, args.num_classes)) 11 | self.init_weight() 12 | def init_weight(self): 13 | nn.init.xavier_uniform_(self.W.data, gain=1) 14 | 15 | def compute_cmpm_loss(self, image_embeddings, text_embeddings, labels): 16 | """ 17 | Cross-Modal Projection Matching Loss(CMPM) 18 | :param image_embeddings: Tensor with dtype torch.float32 19 | :param text_embeddings: Tensor with dtype torch.float32 20 | :param labels: Tensor with dtype torch.int32 21 | :return: 22 | i2t_loss: cmpm loss for image projected to text 23 | t2i_loss: cmpm loss for text projected to image 24 | pos_avg_sim: average cosine-similarity for positive pairs 25 | neg_avg_sim: averate cosine-similarity for negative pairs 26 | """ 27 | 28 | batch_size = image_embeddings.shape[0] 29 | labels_reshape = torch.reshape(labels, (batch_size, 1)) 30 | labels_dist = labels_reshape - labels_reshape.t() 31 | labels_mask = (labels_dist == 0) 32 | 33 | image_norm = image_embeddings / image_embeddings.norm(dim=1, keepdim=True) 34 | text_norm = text_embeddings / text_embeddings.norm(dim=1, keepdim=True) 35 | image_proj_text = torch.matmul(image_embeddings, text_norm.t()) 36 | text_proj_image = torch.matmul(text_embeddings, image_norm.t()) 37 | 38 | # normalize the true matching distribution 39 | labels_mask_norm = labels_mask.float() / labels_mask.float().norm(dim=1) 40 | 41 | i2t_pred = F.softmax(image_proj_text, dim=1) 42 | i2t_loss = i2t_pred * (F.log_softmax(image_proj_text, dim=1) - torch.log(labels_mask_norm + self.epsilon)) 43 | t2i_pred = F.softmax(text_proj_image, dim=1) 44 | t2i_loss = t2i_pred * (F.log_softmax(text_proj_image, dim=1) - torch.log(labels_mask_norm + self.epsilon)) 45 | 46 | cmpm_loss = 
torch.mean(torch.sum(i2t_loss, dim=1)) + torch.mean(torch.sum(t2i_loss, dim=1)) 47 | 48 | return cmpm_loss 49 | 50 | def compute_cmpc_loss(self, image_embeddings, text_embeddings, labels): 51 | """ 52 | Cross-Modal Projection Classfication loss(CMPC) 53 | :param image_embeddings: Tensor with dtype torch.float32 54 | :param text_embeddings: Tensor with dtype torch.float32 55 | :param labels: Tensor with dtype torch.int32 56 | :return: 57 | """ 58 | criterion = nn.CrossEntropyLoss(reduction='mean') 59 | self.W_norm = self.W / self.W.norm(dim=0) 60 | image_norm = image_embeddings / image_embeddings.norm(dim=1, keepdim=True) 61 | text_norm = text_embeddings / text_embeddings.norm(dim=1, keepdim=True) 62 | image_proj_text = torch.sum(image_embeddings * text_norm, dim=1, keepdim=True) * text_norm 63 | text_proj_image = torch.sum(text_embeddings * image_norm, dim=1, keepdim=True) * image_norm 64 | image_logits = torch.matmul(image_proj_text, self.W_norm) 65 | text_logits = torch.matmul(text_proj_image, self.W_norm) 66 | cmpc_loss = criterion(image_logits, labels) + criterion(text_logits, labels) 67 | image_pred = torch.argmax(image_logits, dim=1) 68 | text_pred = torch.argmax(text_logits, dim=1) 69 | image_precision = torch.mean((image_pred == labels).float()) 70 | text_precision = torch.mean((text_pred == labels).float()) 71 | return cmpc_loss, image_pred, text_pred, image_precision, text_precision 72 | 73 | def forward(self, global_visual_embed, global_textual_embed,IDlabels): 74 | cmpm_loss = self.compute_cmpm_loss(global_visual_embed,global_textual_embed,IDlabels) 75 | cmpc_loss,image_pred,text_pred,image_precision, text_precision = self.compute_cmpc_loss(global_visual_embed, global_textual_embed,IDlabels) 76 | loss = cmpm_loss + cmpc_loss 77 | 78 | return cmpm_loss, cmpc_loss, loss, image_precision, text_precision -------------------------------------------------------------------------------- /Downstreams/CMPM-C/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from torch.nn import init 5 | import torchvision.models as models 6 | from collections import OrderedDict 7 | from transformers import BertModel 8 | 9 | def weights_init_kaiming(m): 10 | classname = m.__class__.__name__ 11 | if classname.find('Conv2d') != -1: 12 | init.kaiming_normal_(m.weight.data, mode='fan_out', nonlinearity='relu') 13 | elif classname.find('Linear') != -1: 14 | init.kaiming_normal(m.weight.data, a=0, mode='fan_out') 15 | init.constant_(m.bias.data, 0.0) 16 | elif classname.find('BatchNorm1d') != -1: 17 | init.normal(m.weight.data, 1.0, 0.02) 18 | init.constant_(m.bias.data, 0.0) 19 | elif classname.find('BatchNorm2d') != -1: 20 | init.constant_(m.weight.data, 1) 21 | init.constant_(m.bias.data, 0) 22 | 23 | class conv(nn.Module): 24 | 25 | def __init__(self, input_dim, output_dim, relu=False, BN=False): 26 | super(conv, self).__init__() 27 | 28 | block = [] 29 | block += [nn.Conv2d(input_dim, output_dim, kernel_size=1, bias=False)] 30 | 31 | if BN: 32 | block += [nn.BatchNorm2d(output_dim)] 33 | if relu: 34 | block += [nn.LeakyReLU(0.25, inplace=True)] 35 | 36 | self.block = nn.Sequential(*block) 37 | self.block.apply(weights_init_kaiming) 38 | 39 | def forward(self, x): 40 | x = self.block(x) 41 | x = x.squeeze(3).squeeze(2) 42 | return x 43 | 44 | class Bottleneck(nn.Module): 45 | expansion = 4 46 | def __init__(self, inplanes, planes, stride=1): 47 | super().__init__() 48 | # all conv 
layers have stride 1. an avgpool is performed after the second convolution when stride > 1 49 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 50 | self.bn1 = nn.BatchNorm2d(planes) 51 | self.relu1 = nn.ReLU(inplace=True) 52 | 53 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) 54 | self.bn2 = nn.BatchNorm2d(planes) 55 | self.relu2 = nn.ReLU(inplace=True) 56 | 57 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() 58 | 59 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 60 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 61 | self.relu3 = nn.ReLU(inplace=True) 62 | 63 | self.downsample = None 64 | self.stride = stride 65 | 66 | if stride > 1 or inplanes != planes * Bottleneck.expansion: 67 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 68 | self.downsample = nn.Sequential(OrderedDict([ 69 | ("-1", nn.AvgPool2d(stride)), 70 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), 71 | ("1", nn.BatchNorm2d(planes * self.expansion)) 72 | ])) 73 | 74 | def forward(self, x: torch.Tensor): 75 | identity = x 76 | 77 | out = self.relu1(self.bn1(self.conv1(x))) 78 | out = self.relu2(self.bn2(self.conv2(out))) 79 | out = self.avgpool(out) 80 | out = self.bn3(self.conv3(out)) 81 | 82 | if self.downsample is not None: 83 | identity = self.downsample(x) 84 | 85 | out += identity 86 | out = self.relu3(out) 87 | return out 88 | 89 | class AttentionPool2d(nn.Module): 90 | def __init__(self, spacial_dim_x: int,spacial_dim_y: int, embed_dim: int, num_heads: int, output_dim: int = None): 91 | super().__init__() 92 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim_x * spacial_dim_y + 1, embed_dim) / embed_dim ** 0.5) 93 | self.k_proj = nn.Linear(embed_dim, embed_dim) 94 | self.q_proj = nn.Linear(embed_dim, embed_dim) 95 | self.v_proj = nn.Linear(embed_dim, embed_dim) 96 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 97 | self.num_heads = num_heads 98 | 99 | def forward(self, x): 100 | x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC 101 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 102 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 103 | x, _ = F.multi_head_attention_forward( 104 | query=x[:1], key=x, value=x, 105 | embed_dim_to_check=x.shape[-1], 106 | num_heads=self.num_heads, 107 | q_proj_weight=self.q_proj.weight, 108 | k_proj_weight=self.k_proj.weight, 109 | v_proj_weight=self.v_proj.weight, 110 | in_proj_weight=None, 111 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 112 | bias_k=None, 113 | bias_v=None, 114 | add_zero_attn=False, 115 | dropout_p=0, 116 | out_proj_weight=self.c_proj.weight, 117 | out_proj_bias=self.c_proj.bias, 118 | use_separate_proj_weight=True, 119 | training=self.training, 120 | need_weights=False 121 | ) 122 | return x.squeeze(0) 123 | 124 | class Image_encoder_ModifiedResNet(nn.Module): 125 | """ 126 | A ResNet class that is similar to torchvision's but contains the following changes: 127 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 
128 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 129 | - The final pooling layer is a QKV attention instead of an average pool 130 | """ 131 | def __init__(self, layers, output_dim, heads, input_resolution=[256,128], width=64): 132 | super().__init__() 133 | self.output_dim = output_dim 134 | self.input_resolution = input_resolution 135 | 136 | # the 3-layer stem 137 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) 138 | self.bn1 = nn.BatchNorm2d(width // 2) 139 | self.relu1 = nn.ReLU(inplace=True) 140 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) 141 | self.bn2 = nn.BatchNorm2d(width // 2) 142 | self.relu2 = nn.ReLU(inplace=True) 143 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) 144 | self.bn3 = nn.BatchNorm2d(width) 145 | self.relu3 = nn.ReLU(inplace=True) 146 | self.avgpool = nn.AvgPool2d(2) 147 | 148 | # residual layers 149 | self._inplanes = width # this is a *mutable* variable used during construction 150 | self.layer1 = self._make_layer(width, layers[0]) 151 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2) 152 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2) 153 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2) 154 | 155 | embed_dim = width * 32 # the ResNet feature dimension 156 | self.attnpool = AttentionPool2d(input_resolution[0] // 32,input_resolution[1] // 32, embed_dim, heads, output_dim) 157 | self.initialize_parameters() 158 | def _make_layer(self, planes, blocks, stride=1): 159 | layers = [Bottleneck(self._inplanes, planes, stride)] 160 | 161 | self._inplanes = planes * Bottleneck.expansion 162 | for _ in range(1, blocks): 163 | layers.append(Bottleneck(self._inplanes, planes)) 164 | 165 | return nn.Sequential(*layers) 166 | 167 | def initialize_parameters(self): 168 | if self.attnpool is not None: 169 | std = self.attnpool.c_proj.in_features ** -0.5 170 | nn.init.normal_(self.attnpool.q_proj.weight, std=std) 171 | nn.init.normal_(self.attnpool.k_proj.weight, std=std) 172 | nn.init.normal_(self.attnpool.v_proj.weight, std=std) 173 | nn.init.normal_(self.attnpool.c_proj.weight, std=std) 174 | 175 | for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]: 176 | for name, param in resnet_block.named_parameters(): 177 | if name.endswith("bn3.weight"): 178 | nn.init.zeros_(param) 179 | 180 | def forward(self, x): 181 | def stem(x): 182 | x = self.relu1(self.bn1(self.conv1(x))) 183 | x = self.relu2(self.bn2(self.conv2(x))) 184 | x = self.relu3(self.bn3(self.conv3(x))) 185 | x = self.avgpool(x) 186 | return x 187 | 188 | x = x.type(self.conv1.weight.dtype) 189 | x = stem(x) 190 | x = self.layer1(x) 191 | x = self.layer2(x) 192 | x = self.layer3(x) 193 | x = self.layer4(x) 194 | feat = self.attnpool(x) 195 | return feat 196 | 197 | class Text_encoder(nn.Module): 198 | def __init__(self, encoder_type: str): 199 | super(Text_encoder, self).__init__() 200 | self.encoder = BertModel.from_pretrained(encoder_type) 201 | 202 | def forward(self, token, mask): 203 | x = self.encoder(input_ids=token, attention_mask=mask) 204 | pooler_output = x.pooler_output 205 | return pooler_output 206 | 207 | class Model(nn.Module): 208 | def __init__(self, image_encoder,text_encoder): 209 | super().__init__() 210 | self.image_encoder = image_encoder 211 | self.text_encoder = text_encoder 212 | 213 | def encode_image(self,image): 214 | return 
self.image_encoder(image) 215 | 216 | def encode_text(self,text,mask): 217 | return self.text_encoder(text, mask) 218 | 219 | def forward(self, image,text,masks): 220 | global_image_out = self.image_encoder(image) 221 | global_text_out = self.text_encoder(text,masks) 222 | return global_image_out,global_text_out 223 | 224 | def Create_model(args): 225 | image_encoder = Image_encoder_ModifiedResNet(args.layers,args.img_dim,args.heads,input_resolution=[args.width,args.height]) 226 | text_encoder = Text_encoder(encoder_type=args.txt_backbone) 227 | model = Model(image_encoder, text_encoder) 228 | return model 229 | 230 | 231 | -------------------------------------------------------------------------------- /Downstreams/CMPM-C/optimizer.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import Tensor 4 | from torch.optim.optimizer import Optimizer 5 | 6 | from typing import List 7 | 8 | class Adan(Optimizer): 9 | """ 10 | Implements a pytorch variant of Adan 11 | Adan was proposed in 12 | Adan: Adaptive Nesterov Momentum Algorithm for Faster Optimizing Deep Models[J]. arXiv preprint arXiv:2208.06677, 2022. 13 | https://arxiv.org/abs/2208.06677 14 | Arguments: 15 | params (iterable): iterable of parameters to optimize or dicts defining parameter groups. 16 | lr (float, optional): learning rate. (default: 1e-3) 17 | betas (Tuple[float, float, flot], optional): coefficients used for computing 18 | running averages of gradient and its norm. (default: (0.98, 0.92, 0.99)) 19 | eps (float, optional): term added to the denominator to improve 20 | numerical stability. (default: 1e-8) 21 | weight_decay (float, optional): decoupled weight decay (L2 penalty) (default: 0) 22 | max_grad_norm (float, optional): value used to clip 23 | global grad norm (default: 0.0 no clip) 24 | no_prox (bool): how to perform the decoupled weight decay (default: False) 25 | foreach (bool): if True would use torch._foreach implementation. It's faster but uses 26 | slightly more memory. 
(default: True) 27 | """ 28 | 29 | def __init__(self, params, lr=1e-3, betas=(0.98, 0.92, 0.99), eps=1e-8, 30 | weight_decay=0.0, max_grad_norm=0.0, no_prox=False, foreach: bool = True): 31 | if not 0.0 <= max_grad_norm: 32 | raise ValueError("Invalid Max grad norm: {}".format(max_grad_norm)) 33 | if not 0.0 <= lr: 34 | raise ValueError("Invalid learning rate: {}".format(lr)) 35 | if not 0.0 <= eps: 36 | raise ValueError("Invalid epsilon value: {}".format(eps)) 37 | if not 0.0 <= betas[0] < 1.0: 38 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 39 | if not 0.0 <= betas[1] < 1.0: 40 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 41 | if not 0.0 <= betas[2] < 1.0: 42 | raise ValueError("Invalid beta parameter at index 2: {}".format(betas[2])) 43 | defaults = dict(lr=lr, betas=betas, eps=eps, 44 | weight_decay=weight_decay, 45 | max_grad_norm=max_grad_norm, no_prox=no_prox, foreach=foreach) 46 | super().__init__(params, defaults) 47 | 48 | def __setstate__(self, state): 49 | super(Adan, self).__setstate__(state) 50 | for group in self.param_groups: 51 | group.setdefault('no_prox', False) 52 | 53 | @torch.no_grad() 54 | def restart_opt(self): 55 | for group in self.param_groups: 56 | group['step'] = 0 57 | for p in group['params']: 58 | if p.requires_grad: 59 | state = self.state[p] 60 | # State initialization 61 | 62 | # Exponential moving average of gradient values 63 | state['exp_avg'] = torch.zeros_like(p) 64 | # Exponential moving average of squared gradient values 65 | state['exp_avg_sq'] = torch.zeros_like(p) 66 | # Exponential moving average of gradient difference 67 | state['exp_avg_diff'] = torch.zeros_like(p) 68 | 69 | @torch.no_grad() 70 | def step(self, closure=None): 71 | """ 72 | Performs a single optimization step. 
73 | """ 74 | 75 | loss = None 76 | if closure is not None: 77 | with torch.enable_grad(): 78 | loss = closure() 79 | 80 | if self.defaults['max_grad_norm'] > 0: 81 | device = self.param_groups[0]['params'][0].device 82 | global_grad_norm = torch.zeros(1, device=device) 83 | 84 | max_grad_norm = torch.tensor(self.defaults['max_grad_norm'], device=device) 85 | for group in self.param_groups: 86 | 87 | for p in group['params']: 88 | if p.grad is not None: 89 | grad = p.grad 90 | global_grad_norm.add_(grad.pow(2).sum()) 91 | 92 | global_grad_norm = torch.sqrt(global_grad_norm) 93 | 94 | clip_global_grad_norm = torch.clamp(max_grad_norm / (global_grad_norm + group['eps']), max=1.0) 95 | else: 96 | clip_global_grad_norm = 1.0 97 | 98 | for group in self.param_groups: 99 | params_with_grad = [] 100 | grads = [] 101 | exp_avgs = [] 102 | exp_avg_sqs = [] 103 | exp_avg_diffs = [] 104 | pre_grads = [] 105 | 106 | beta1, beta2, beta3 = group['betas'] 107 | # assume same step across group now to simplify things 108 | # per parameter step can be easily support by making it tensor, or pass list into kernel 109 | if 'step' in group: 110 | group['step'] += 1 111 | else: 112 | group['step'] = 1 113 | 114 | bias_correction1 = 1.0 - beta1 ** group['step'] 115 | bias_correction2 = 1.0 - beta2 ** group['step'] 116 | bias_correction3 = 1.0 - beta3 ** group['step'] 117 | 118 | for p in group['params']: 119 | if p.grad is None: 120 | continue 121 | params_with_grad.append(p) 122 | grads.append(p.grad) 123 | 124 | state = self.state[p] 125 | if len(state) == 0: 126 | state['exp_avg'] = torch.zeros_like(p) 127 | state['exp_avg_sq'] = torch.zeros_like(p) 128 | state['exp_avg_diff'] = torch.zeros_like(p) 129 | 130 | if 'pre_grad' not in state or group['step'] == 1: 131 | # at first step grad wouldn't be clipped by `clip_global_grad_norm` 132 | # this is only to simplify implementation 133 | state['pre_grad'] = p.grad 134 | 135 | exp_avgs.append(state['exp_avg']) 136 | exp_avg_sqs.append(state['exp_avg_sq']) 137 | exp_avg_diffs.append(state['exp_avg_diff']) 138 | pre_grads.append(state['pre_grad']) 139 | 140 | kwargs = dict( 141 | params=params_with_grad, 142 | grads=grads, 143 | exp_avgs=exp_avgs, 144 | exp_avg_sqs=exp_avg_sqs, 145 | exp_avg_diffs=exp_avg_diffs, 146 | pre_grads=pre_grads, 147 | beta1=beta1, 148 | beta2=beta2, 149 | beta3=beta3, 150 | bias_correction1=bias_correction1, 151 | bias_correction2=bias_correction2, 152 | bias_correction3_sqrt=math.sqrt(bias_correction3), 153 | lr=group['lr'], 154 | weight_decay=group['weight_decay'], 155 | eps=group['eps'], 156 | no_prox=group['no_prox'], 157 | clip_global_grad_norm=clip_global_grad_norm, 158 | ) 159 | if group["foreach"]: 160 | copy_grads = _multi_tensor_adan(**kwargs) 161 | else: 162 | copy_grads = _single_tensor_adan(**kwargs) 163 | 164 | for p, copy_grad in zip(params_with_grad, copy_grads): 165 | self.state[p]['pre_grad'] = copy_grad 166 | 167 | return loss 168 | 169 | 170 | def _single_tensor_adan( 171 | params: List[Tensor], 172 | grads: List[Tensor], 173 | exp_avgs: List[Tensor], 174 | exp_avg_sqs: List[Tensor], 175 | exp_avg_diffs: List[Tensor], 176 | pre_grads: List[Tensor], 177 | *, 178 | beta1: float, 179 | beta2: float, 180 | beta3: float, 181 | bias_correction1: float, 182 | bias_correction2: float, 183 | bias_correction3_sqrt: float, 184 | lr: float, 185 | weight_decay: float, 186 | eps: float, 187 | no_prox: bool, 188 | clip_global_grad_norm: Tensor, 189 | ): 190 | copy_grads = [] 191 | for i, param in enumerate(params): 192 | grad = 
grads[i] 193 | exp_avg = exp_avgs[i] 194 | exp_avg_sq = exp_avg_sqs[i] 195 | exp_avg_diff = exp_avg_diffs[i] 196 | pre_grad = pre_grads[i] 197 | 198 | grad = grad.mul_(clip_global_grad_norm) 199 | copy_grads.append(grad.clone()) 200 | 201 | diff = grad - pre_grad 202 | update = grad + beta2 * diff 203 | 204 | exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) # m_t 205 | exp_avg_diff.mul_(beta2).add_(diff, alpha=1 - beta2) # diff_t 206 | exp_avg_sq.mul_(beta3).addcmul_(update, update, value=1 - beta3) # n_t 207 | 208 | denom = ((exp_avg_sq).sqrt() / bias_correction3_sqrt).add_(eps) 209 | update = ((exp_avg / bias_correction1 + beta2 * exp_avg_diff / bias_correction2)).div_(denom) 210 | 211 | if no_prox: 212 | param.mul_(1 - lr * weight_decay) 213 | param.add_(update, alpha=-lr) 214 | else: 215 | param.add_(update, alpha=-lr) 216 | param.div_(1 + lr * weight_decay) 217 | return copy_grads 218 | 219 | 220 | def _multi_tensor_adan( 221 | params: List[Tensor], 222 | grads: List[Tensor], 223 | exp_avgs: List[Tensor], 224 | exp_avg_sqs: List[Tensor], 225 | exp_avg_diffs: List[Tensor], 226 | pre_grads: List[Tensor], 227 | *, 228 | beta1: float, 229 | beta2: float, 230 | beta3: float, 231 | bias_correction1: float, 232 | bias_correction2: float, 233 | bias_correction3_sqrt: float, 234 | lr: float, 235 | weight_decay: float, 236 | eps: float, 237 | no_prox: bool, 238 | clip_global_grad_norm: Tensor, 239 | ): 240 | if clip_global_grad_norm < 1.0: 241 | torch._foreach_mul_(grads, clip_global_grad_norm.item()) 242 | copy_grads = [g.clone() for g in grads] 243 | 244 | diff = torch._foreach_sub(grads, pre_grads) 245 | # NOTE: line below while looking identical gives different result, due to float precision errors. 246 | # using mul+add produces identical results to single-tensor, using add+alpha doesn't 247 | # On cuda this difference doesn't matter due to its' own precision non-determinism 248 | # update = torch._foreach_add(grads, torch._foreach_mul(diff, beta2)) 249 | update = torch._foreach_add(grads, diff, alpha=beta2) 250 | 251 | torch._foreach_mul_(exp_avgs, beta1) 252 | torch._foreach_add_(exp_avgs, grads, alpha=1 - beta1) # m_t 253 | 254 | torch._foreach_mul_(exp_avg_diffs, beta2) 255 | torch._foreach_add_(exp_avg_diffs, diff, alpha=1 - beta2) # diff_t 256 | 257 | torch._foreach_mul_(exp_avg_sqs, beta3) 258 | torch._foreach_addcmul_(exp_avg_sqs, update, update, value=1 - beta3) # n_t 259 | 260 | denom = torch._foreach_sqrt(exp_avg_sqs) 261 | torch._foreach_div_(denom, bias_correction3_sqrt) 262 | torch._foreach_add_(denom, eps) 263 | 264 | update = torch._foreach_div(exp_avgs, bias_correction1) 265 | # NOTE: same issue as above. beta2 * diff / bias_correction2 != diff * (beta2 / bias_correction2) 266 | # using faster version by default. 
267 | # torch._foreach_add_(update, torch._foreach_div(torch._foreach_mul(exp_avg_diffs, beta2), bias_correction2)) 268 | torch._foreach_add_(update, torch._foreach_mul(exp_avg_diffs, beta2 / bias_correction2)) 269 | torch._foreach_div_(update, denom) 270 | 271 | if no_prox: 272 | torch._foreach_mul_(params, 1 - lr * weight_decay) 273 | torch._foreach_add_(params, update, alpha=-lr) 274 | else: 275 | torch._foreach_add_(params, update, alpha=-lr) 276 | torch._foreach_div_(params, 1 + lr * weight_decay) 277 | return copy_grads 278 | 279 | def create_optimizer(params,args): 280 | if args.optimizer =="adam": 281 | print("The optimizer is {}".format(args.optimizer)) 282 | optimizer = torch.optim.Adam(params,lr=args.lr, betas=(args.adam_alpha, args.adam_beta), eps=args.epsilon) 283 | return optimizer 284 | elif args.optimizer =="adan": 285 | print("The optimizer is {}".format(args.optimizer)) 286 | optimizer = Adan(params, lr=args.lr, weight_decay=args.adan_weight_decay, betas=args.adan_opt_betas, eps=args.adan_opt_eps, 287 | max_grad_norm=args.adan_max_grad_norm, no_prox=args.adan_no_prox) 288 | return optimizer -------------------------------------------------------------------------------- /Downstreams/CMPM-C/test.py: -------------------------------------------------------------------------------- 1 | import torchvision.transforms as transforms 2 | from utils import * 3 | import os 4 | import shutil 5 | from dataloader import get_loader_test 6 | import argparse 7 | from model import Create_model 8 | import warnings 9 | warnings.filterwarnings("ignore") 10 | 11 | def Test_parse_args(): 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('--model_path', type=str, default='log/CUHK-CMPMC') 14 | parser.add_argument('--image_dir', type=str, default='data/CUHK-PEDES/') 15 | parser.add_argument('--caption_path', type=str, 16 | default='data/CUHK-PEDES/CUHK-PEDES-test.json', 17 | help='path for test annotation json file') 18 | parser.add_argument('--best_dir', type=str, default='log/CUHK-CMPMC') 19 | parser.add_argument('--flip_eval', type=bool, default=False) 20 | parser.add_argument('--mean', default=[0.357, 0.323, 0.328], type=list) 21 | parser.add_argument('--std', default=[0.252, 0.242, 0.239], type=list) 22 | # *********************************************************************************************************************** 23 | # Set the model backbone type and parameters 24 | parser.add_argument('--img_backbone', type=str, default='ModifiedResNet', 25 | help="ResNet:xxx, ModifiedResNet, ViT:xxx") 26 | parser.add_argument('--txt_backbone', type=str, default="bert-base-uncased") 27 | parser.add_argument('--img_dim', type=int, default=768, help='dimension of image embedding vectors') 28 | parser.add_argument('--text_dim', type=int, default=768, help='dimension of text embedding vectors') 29 | parser.add_argument('--patch_size', type=int, default=16, help='Just for ViT model') 30 | parser.add_argument('--layers', type=list, default=[3, 4, 6, 3], help='Just for ModifiedResNet model') 31 | parser.add_argument('--heads', type=int, default=8, help='Just for ModifiedResNet model') 32 | 33 | parser.add_argument('--height', type=int, default=256) 34 | parser.add_argument('--width', type=int, default=128) 35 | 36 | # Set the hyperparameters 37 | parser.add_argument('--num_epoches', type=int, default=30) 38 | parser.add_argument('--batch_size', type=int, default=128) 39 | parser.add_argument('--num_workers', type=int, default=4) 40 | parser.add_argument('--device',type=str,default="cuda:0") 41 | 
parser.add_argument('--feature_size', type=int, default=768) 42 | args = parser.parse_args() 43 | return args 44 | 45 | def test(image_test_loader,text_test_loader, model): 46 | # switch to evaluate mode 47 | model = model.eval() 48 | device = next(model.parameters()).device 49 | 50 | qids, gids, qfeats, gfeats = [], [], [], [] 51 | # text 52 | for pid, caption,mask in text_test_loader: 53 | caption = caption.to(device) 54 | mask = mask.to(device) 55 | with torch.no_grad(): 56 | text_feat = model.encode_text(caption,mask) 57 | qids.append(pid.view(-1)) # flatten 58 | qfeats.append(text_feat) 59 | qids = torch.cat(qids, 0) 60 | qfeats = torch.cat(qfeats, 0) 61 | 62 | # image 63 | for pid, img in image_test_loader: 64 | img = img.to(device) 65 | with torch.no_grad(): 66 | img_feat = model.encode_image(img) 67 | gids.append(pid.view(-1)) # flatten 68 | gfeats.append(img_feat) 69 | gids = torch.cat(gids, 0) 70 | gfeats = torch.cat(gfeats, 0) 71 | ac_top1_t2i, ac_top5_t2i, ac_top10_t2i, mAP = test_map(qfeats, qids, gfeats, gids) 72 | return ac_top1_t2i, ac_top5_t2i, ac_top10_t2i, mAP 73 | 74 | 75 | def Test_main(args): 76 | device = args.device 77 | transform = transforms.Compose([ 78 | transforms.Resize((256,128), interpolation=3), 79 | transforms.ToTensor(), 80 | transforms.Normalize(args.mean, 81 | args.std) 82 | ]) 83 | 84 | image_test_loader,text_test_loader = get_loader_test(args.image_dir, args.caption_path, transform, args.batch_size,args.num_workers) 85 | 86 | ac_t2i_top1_best = 0.0 87 | ac_t2i_top5_best = 0.0 88 | ac_t2i_top10_best = 0.0 89 | mAP_best = 0.0 90 | best = 0 91 | dst_best = args.best_dir + "/model_best" + ".pth" 92 | model = Create_model(args).to(device) 93 | 94 | for i in range(0, args.num_epoches): 95 | if i%2!=0 and i!=args.num_epoches-1: 96 | continue 97 | model_file = os.path.join(args.model_path, str(i+1))+".pth.tar" 98 | print(model_file) 99 | if os.path.isdir(model_file): 100 | continue 101 | checkpoint = torch.load(model_file) 102 | model.image_encoder.load_state_dict(checkpoint["ImgEncoder_state_dict"]) 103 | model.text_encoder.load_state_dict(checkpoint["TxtEncoder_state_dict"]) 104 | ac_top1_t2i, ac_top5_t2i, ac_top10_t2i, mAP = test(image_test_loader,text_test_loader, model) 105 | if ac_top1_t2i > ac_t2i_top1_best: 106 | ac_t2i_top1_best = ac_top1_t2i 107 | ac_t2i_top5_best = ac_top5_t2i 108 | ac_t2i_top10_best = ac_top10_t2i 109 | mAP_best = mAP 110 | best = i 111 | shutil.copyfile(model_file, dst_best) 112 | 113 | print('Epo{}: {:.5f} {:.5f} {:.5f} {:.5f}'.format( 114 | best, ac_t2i_top1_best, ac_t2i_top5_best, ac_t2i_top10_best, mAP_best)) 115 | 116 | if __name__ == '__main__': 117 | args = Test_parse_args() 118 | Test_main(args) -------------------------------------------------------------------------------- /Downstreams/CMPM-C/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from dataloader import get_loader 3 | from model import Create_model 4 | from loss import Loss_calc 5 | from torchvision import transforms 6 | from utils import * 7 | import time 8 | from optimizer import create_optimizer 9 | import torch.distributed as dist 10 | from torch.nn.parallel import DistributedDataParallel as DDP 11 | 12 | 13 | def Train_parse_args(): 14 | parser = argparse.ArgumentParser(description='command for train by Resnet') 15 | parser.add_argument('--mean', default=[0.357, 0.323, 0.328], type=list) 16 | parser.add_argument('--std', default=[0.252, 0.242, 0.239], type=list) 17 | 
#*********************************************************************************************************************** 18 | # Set the dataset path, output path, etc. 19 | parser.add_argument('--name', default='CUHK-CMPMC', type=str, help='output model name') 20 | parser.add_argument('--checkpoint_dir', type=str, 21 | default="log", 22 | help='directory to store checkpoint') 23 | parser.add_argument('--log_dir', type=str, 24 | default="log", 25 | help='directory to store log') 26 | parser.add_argument('--image_path', type=str, 27 | default=r'data/CUHK-PEDES', 28 | help='directory to store dataset') 29 | parser.add_argument('--dataset_path', type=str, 30 | default=r'data/CUHK-PEDES-train-depart.json', 31 | help='path to the annotation file') 32 | parser.add_argument('--checkpoint_path', type=str,default=r'checkpoints/PLIP_RN50.pth.tar') 33 | #*********************************************************************************************************************** 34 | # Set the model backbone type and parameters 35 | parser.add_argument('--img_backbone', type=str, default='ModifiedResNet',help="ResNet:xxx, ModifiedResNet, ViT:xxx") 36 | parser.add_argument('--txt_backbone', type=str, default="bert-base-uncased") 37 | parser.add_argument('--img_dim', type=int, default=768, help='dimension of image embedding vectors') 38 | parser.add_argument('--text_dim', type=int, default=768, help='dimension of text embedding vectors') 39 | parser.add_argument('--layers', type=list, default=[3,4,6,3], help='Just for ModifiedResNet model') 40 | parser.add_argument('--heads', type=int, default=8, help='Just for ModifiedResNet model') 41 | parser.add_argument('--feature_size', type=int, default=768) 42 | parser.add_argument('--num_classes', type=int, default=11003) # CUHK:11003 ICFG:3102 THE NUMBER OF IDENTITIES 43 | #*********************************************************************************************************************** 44 | # Set the training pre-processing hyperparameters 45 | parser.add_argument('--height', type=int, default=256) 46 | parser.add_argument('--width', type=int, default=128) 47 | 48 | #*********************************************************************************************************************** 49 | # Set the learning rate and related hyperparameters 50 | parser.add_argument('--save_every', type=int, default=1, help='step size for saving trained models') 51 | parser.add_argument('--batch_size', type=int, default=32) 52 | parser.add_argument('--epochs', type=int, default=30) 53 | parser.add_argument('--warm_epoch', default=5, type=int, help='the first K epochs that need warm-up') 54 | parser.add_argument('--lr', type=float, default=0.0001, help='the learning rate of optimizer') 55 | parser.add_argument('--lr_decay_type', type=str, default='MultiStepLR', 56 | help='One of "MultiStepLR" or "StepLR" or "get_linear_schedule_with_warmup"') 57 | parser.add_argument('--lr_decay_ratio', type=float, default=0.1) 58 | parser.add_argument('--epoches_decay', type=str, default='20', help='epochs at which the learning rate decays') 59 | parser.add_argument('--backbone_frozen', type=bool, default=False) 60 | #*********************************************************************************************************************** 61 | # Set the optimizer hyperparameters 62 | parser.add_argument('--optimizer', type=str, default="adan", help='The optimizer type: adam or adan') 63 | 64 | parser.add_argument('--adam_alpha', type=float, default=0.9) 65 | parser.add_argument('--adam_beta', type=float, default=0.999) 66 | parser.add_argument('--epsilon', type=float, default=1e-8) 67 | 68 | parser.add_argument('--adan_max-grad-norm', type=float, default=0.0, 
69 | help='if the l2 norm is larger than this hyper-parameter, then we clip the gradient (default: 0.0, no gradient clip)') 70 | parser.add_argument('--adan_weight-decay', type=float, default=0.02, 71 | help='weight decay, similar to the one used in AdamW (default: 0.02)') 72 | parser.add_argument('--adan_opt-eps', default=1e-8, type=float, metavar='EPSILON', 73 | help='optimizer epsilon to avoid the bad case where second-order moment is zero (default: None, use opt default 1e-8 in adan)') 74 | parser.add_argument('--adan_opt-betas', default=[0.98, 0.92, 0.99], type=float, nargs='+', metavar='BETA', 75 | help='optimizer betas in Adan (default: None, use opt default [0.98, 0.92, 0.99] in Adan)') 76 | parser.add_argument('--adan_no-prox', action='store_true', default=False, 77 | help='whether to perform weight decay like AdamW (default=False)') 78 | 79 | #*********************************************************************************************************************** 80 | # Other settings 81 | parser.add_argument('--num_workers', type=int, default=4) 82 | parser.add_argument('--gpus', type=str, default="0,1,2,3") 83 | parser.add_argument('--local_rank', type=int) 84 | args = parser.parse_args() 85 | return args 86 | 87 | 88 | def train(args): 89 | rank = int(os.environ["RANK"]) 90 | local_rank = int(os.environ["LOCAL_RANK"]) 91 | torch.cuda.set_device(rank % torch.cuda.device_count()) 92 | dist.init_process_group(backend="nccl") 93 | device = torch.device("cuda", local_rank) 94 | print(f"[init] == local rank: {local_rank}, global rank: {rank} ==") 95 | 96 | train_information_setting(args) 97 | checkpoint_dir = os.path.join(args.checkpoint_dir, args.name) 98 | 99 | transform = transforms.Compose([ 100 | transforms.Resize((args.height, args.width),interpolation=3), 101 | transforms.RandomHorizontalFlip(p=0.5), 102 | transforms.ToTensor(), 103 | transforms.Normalize(args.mean,args.std) 104 | ]) 105 | 106 | epochs = args.epochs 107 | 108 | model = Create_model(args).to(device) 109 | model_file = args.checkpoint_path 110 | checkpoint = torch.load(model_file) 111 | model.image_encoder.load_state_dict(checkpoint["ImgEncoder_state_dict"]) 112 | model.text_encoder.load_state_dict(checkpoint["TxtEncoder_state_dict"]) 113 | Loss = Loss_calc(args).to(device) 114 | params = [{"params": model.image_encoder.parameters(), "lr": args.lr}, 115 | {"params": model.text_encoder.parameters(), "lr": args.lr}, 116 | {"params": Loss.parameters(), "lr": args.lr * 10}] 117 | 118 | 119 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) # synchronize BN layers across GPUs 120 | optimizer = create_optimizer(params, args) 121 | 122 | train_dataloader = get_loader(args.image_path, args.dataset_path, transform, None,args.batch_size,args.num_workers,distributed=True) 123 | scheduler = lr_scheduler(optimizer, args,len(train_dataloader)) 124 | PrintInformation(args, model) 125 | 126 | model = DDP(model, device_ids=[local_rank], output_device=local_rank, broadcast_buffers=True,find_unused_parameters=True) 127 | 128 | if rank == 0: 129 | print(" ======= Training ======= \n") 130 | 131 | model.train() 132 | 133 | for epoch in range(epochs): 134 | print("**********************************************************") 135 | print(f">>> Training epoch {epoch + 1}") 136 | total_loss = 0 137 | start = time.time() 138 | train_dataloader.sampler.set_epoch(epoch) 139 | if epoch < args.warm_epoch: 140 | print('learning rate warm_up') 141 | optimizer = gradual_warmup(epoch, optimizer, epochs=args.warm_epoch) 142 | 143 | optimizer.zero_grad() 144 | for idx, (images_gt, 
targets, masks, labels) in enumerate( 145 | train_dataloader): 146 | images_gt, targets, masks, labels = images_gt.to(device), targets.to(device), masks.to(device), labels.to(device) - 1 147 | global_visual_embed, global_textual_embed = model(images_gt, targets, masks) 148 | IDlabels = labels 149 | cmpm_loss, cmpc_loss, loss, image_precision, text_precision = Loss(global_visual_embed, global_textual_embed, IDlabels) 150 | #print(loss) 151 | loss.backward() 152 | optimizer.step() 153 | optimizer.zero_grad() 154 | total_loss += loss.item() 155 | if idx % 50 == 0 : 156 | print( 157 | "Train Epoch:[{}/{}] iteration:[{}/{}] cmpm_loss:{:.4f} cmpc_loss:{:.4f} " 158 | "image_pre:{:.4f} text_pre:{:.4f}" 159 | .format(epoch + 1, args.epochs, idx, len(train_dataloader), cmpm_loss.item(), 160 | cmpc_loss.item(), 161 | image_precision, text_precision)) 162 | scheduler.step() 163 | Epoch_time = time.time() - start 164 | print("Average loss is :{}".format(total_loss / len(train_dataloader))) 165 | print('Epoch_training complete in {:.0f}m {:.0f}s'.format( 166 | Epoch_time // 60, Epoch_time % 60)) 167 | 168 | if epoch % args.save_every == 0 or epoch == epochs - 1: 169 | state = {"epoch": epoch + 1, 170 | "ImgEncoder_state_dict": model.module.image_encoder.state_dict(), 171 | "TxtEncoder_state_dict": model.module.text_encoder.state_dict() 172 | } 173 | save_checkpoint(state, epoch + 1, checkpoint_dir) 174 | 175 | def main(): 176 | args = Train_parse_args() 177 | train(args) 178 | 179 | if __name__ == '__main__': 180 | main() 181 | -------------------------------------------------------------------------------- /Downstreams/CMPM-C/utils.py: -------------------------------------------------------------------------------- 1 | import errno 2 | import sys 3 | import os.path as osp 4 | import torch.utils.data as data 5 | import os 6 | import torch 7 | import numpy as np 8 | import random 9 | import json 10 | import torch.nn as nn 11 | from einops.layers.torch import Rearrange 12 | from transformers import get_linear_schedule_with_warmup 13 | import yaml 14 | 15 | def setup_seed(seed): 16 | torch.manual_seed(seed) 17 | torch.cuda.manual_seed(seed) 18 | torch.cuda.manual_seed_all(seed) 19 | np.random.seed(seed) # Numpy module. 20 | random.seed(seed) # Python random module. 
21 | torch.backends.cudnn.benchmark = False 22 | torch.backends.cudnn.deterministic = True 23 | print("The seed is {}".format(seed)) 24 | 25 | def lr_scheduler(optimizer, args,loader_length): 26 | if args.lr_decay_type == "get_linear_schedule_with_warmup": 27 | scheduler = get_linear_schedule_with_warmup( 28 | optimizer, num_warmup_steps=int(args.epochs * loader_length * 0.1/args.accumulation_steps), 29 | num_training_steps=int(args.epochs * loader_length/args.accumulation_steps) 30 | ) 31 | print("lr_scheduler is get_linear_schedule_with_warmup") 32 | else: 33 | if '_' in args.epoches_decay: 34 | epoches_list = args.epoches_decay.split('_') 35 | epoches_list = [int(e) for e in epoches_list] 36 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, epoches_list, gamma=args.lr_decay_ratio) 37 | print("lr_scheduler is MultiStepLR") 38 | else: 39 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, int(args.epoches_decay), gamma=args.lr_decay_ratio) 40 | print("lr_scheduler is StepLR") 41 | return scheduler 42 | 43 | def save_checkpoint(state, epoch, dst): 44 | if not os.path.exists(dst): 45 | os.makedirs(dst) 46 | filename = os.path.join(dst, str(epoch)) + '.pth.tar' 47 | torch.save(state, filename) 48 | 49 | def gradual_warmup(epoch,optimizer,epochs): 50 | if epoch == 0: 51 | warmup_percent_done = (epoch + 1) / epochs 52 | else: 53 | warmup_percent_done = (epoch + 1) / epoch 54 | for param_group in optimizer.param_groups: 55 | param_group['lr'] = param_group['lr']*warmup_percent_done 56 | return optimizer 57 | 58 | def PrintInformation(args,model): 59 | # with open(args.annotation_path,"r") as f: 60 | # dataset = json.load(f) 61 | # num = len(dataset) 62 | print("The image model is: {}".format(args.img_backbone)) 63 | print("The language model is: {}".format(args.txt_backbone)) 64 | print('Number of model parameters: {}'.format( 65 | sum([p.data.nelement() for p in model.parameters()]))) 66 | print("Checkpoints are saved in: {}".format(args.name)) 67 | print("The original learning rate is {}".format(args.lr)) 68 | 69 | 70 | def train_information_setting(args): # set up the output files for training information 71 | name = args.name 72 | # set some paths 73 | log_dir = args.log_dir 74 | log_dir = os.path.join(log_dir, name) 75 | sys.stdout = Logger(os.path.join(log_dir, "train_log.txt")) 76 | opt_dir = os.path.join('log', name) 77 | if not os.path.exists(opt_dir): 78 | os.makedirs(opt_dir) 79 | with open('%s/opts_train.yaml' % opt_dir, 'w') as fp: 80 | yaml.dump(vars(args), fp, default_flow_style=False) 81 | 82 | def mkdir_if_missing(directory): 83 | if not osp.exists(directory): 84 | try: 85 | os.makedirs(directory) 86 | except OSError as e: 87 | if e.errno != errno.EEXIST: 88 | raise 89 | 90 | class Logger(object): 91 | """ 92 | Write console output to external text file. 93 | Code imported from https://github.com/Cysu/open-reid/blob/master/reid/utils/logging.py. 
94 | """ 95 | def __init__(self, fpath=None): 96 | self.console = sys.stdout 97 | self.file = None 98 | if fpath is not None: 99 | mkdir_if_missing(os.path.dirname(fpath)) 100 | self.file = open(fpath, 'w') 101 | 102 | def __del__(self): 103 | self.close() 104 | 105 | def __enter__(self): 106 | pass 107 | 108 | def __exit__(self, *args): 109 | self.close() 110 | 111 | def write(self, msg): 112 | self.console.write(msg) 113 | if self.file is not None: 114 | self.file.write(msg) 115 | 116 | def flush(self): 117 | self.console.flush() 118 | if self.file is not None: 119 | self.file.flush() 120 | os.fsync(self.file.fileno()) 121 | 122 | def close(self): 123 | self.console.close() 124 | if self.file is not None: 125 | self.file.close() 126 | 127 | def compute_topk(query, gallery, target_query, target_gallery, k=[1,10], reverse=False): 128 | result = [] 129 | query = query / (query.norm(dim=1,keepdim=True)+1e-12) 130 | gallery = gallery / (gallery.norm(dim=1,keepdim=True)+1e-12) 131 | sim_cosine = torch.matmul(query, gallery.t()) 132 | result.extend(topk(sim_cosine, target_gallery, target_query, k)) 133 | if reverse: 134 | result.extend(topk(sim_cosine, target_query, target_gallery, k, dim=0)) 135 | return result 136 | 137 | def topk(sim, target_gallery, target_query, k=[1,10], dim=1): 138 | result = [] 139 | maxk = max(k) 140 | size_total = len(target_gallery) 141 | _, pred_index = sim.topk(maxk, dim, True, True) 142 | pred_labels = target_gallery[pred_index] 143 | if dim == 1: 144 | pred_labels = pred_labels.t() 145 | correct = pred_labels.eq(target_query.view(1,-1).expand_as(pred_labels)) 146 | 147 | for topk in k: 148 | correct_k = torch.sum(correct[:topk], dim=0) 149 | correct_k = torch.sum(correct_k > 0).float() 150 | result.append(correct_k * 100 / size_total) 151 | return result 152 | 153 | def test_map(query_feature,query_label,gallery_feature, gallery_label): 154 | query_feature = query_feature / (query_feature.norm(dim=1, keepdim=True) + 1e-12) 155 | gallery_feature = gallery_feature / (gallery_feature.norm(dim=1, keepdim=True) + 1e-12) 156 | CMC = torch.IntTensor(len(gallery_label)).zero_() 157 | ap = 0.0 158 | for i in range(len(query_label)): 159 | ap_tmp, CMC_tmp = evaluate(query_feature[i], query_label[i], gallery_feature, gallery_label) 160 | 161 | if CMC_tmp[0] == -1: 162 | continue 163 | CMC = CMC + CMC_tmp 164 | ap += ap_tmp 165 | CMC = CMC.float() 166 | CMC = CMC / len(query_label) 167 | print('Rank@1:%f Rank@5:%f Rank@10:%f mAP:%f' % (CMC[0], CMC[4], CMC[9], ap / len(query_label))) 168 | return CMC[0], CMC[4], CMC[9], ap / len(query_label) 169 | 170 | def evaluate(qf, ql, gf, gl): 171 | query = qf.view(-1, 1) 172 | score = torch.mm(gf, query) 173 | score = score.squeeze(1).cpu() 174 | score = score.numpy() 175 | index = np.argsort(score) 176 | index = index[::-1] 177 | gl=gl.cuda().data.cpu().numpy() 178 | ql=ql.cuda().data.cpu().numpy() 179 | query_index = np.argwhere(gl == ql) 180 | CMC_tmp = compute_mAP(index, query_index) 181 | return CMC_tmp 182 | 183 | def compute_mAP(index, good_index): 184 | ap = 0 185 | cmc = torch.IntTensor(len(index)).zero_() 186 | if good_index.size == 0: # if empty 187 | cmc[0] = -1 188 | return ap, cmc 189 | # find good_index index 190 | ngood = len(good_index) 191 | mask = np.in1d(index, good_index) 192 | rows_good = np.argwhere(mask == True) 193 | rows_good = rows_good.flatten() 194 | 195 | cmc[rows_good[0]:] = 1 196 | for i in range(ngood): 197 | d_recall = 1.0 / ngood 198 | precision = (i + 1) * 1.0 / (rows_good[i] + 1) 199 | if 
rows_good[i] != 0: 200 | old_precision = i * 1.0 / rows_good[i] 201 | else: 202 | old_precision = 1.0 203 | ap = ap + d_recall * (old_precision + precision) / 2 204 | 205 | return ap, cmc -------------------------------------------------------------------------------- /Downstreams/readme.txt: -------------------------------------------------------------------------------- 1 | These are some experiments for the downstream tasks. -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 jlongzuo 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /PLIPmodel.py: -------------------------------------------------------------------------------- 1 | from visual_model import * 2 | from textual_model import Textual_encoder 3 | 4 | class PLIP_MResNet(nn.Module): 5 | def __init__(self, image_encoder,text_encoder): 6 | super().__init__() 7 | self.image_encoder = image_encoder 8 | self.text_encoder = text_encoder 9 | 10 | def get_text_global_embedding(self,caption,mask): 11 | global_text_out = self.text_encoder.get_global_embedding(caption,mask) 12 | return global_text_out 13 | 14 | def get_image_embeddings(self,image): 15 | global_image_out, _,_,_,_ = self.image_encoder(image) 16 | return global_image_out 17 | 18 | def forward(self, image,text,masks): 19 | global_image_out,x1,x2,x3,x4 = self.image_encoder(image) 20 | global_text_out, part_text_out = self.text_encoder(text,masks) 21 | return global_image_out,x1,x2,x3,x4,global_text_out,part_text_out 22 | 23 | def Create_PLIP_Model(args): 24 | if args.plip_model == "MResNet_BERT": 25 | image_encoder = Image_encoder_ModifiedResNet(args.layers,args.img_dim,args.heads,input_resolution=[args.width,args.height]) 26 | text_encoder = Textual_encoder(encoder_type=args.txt_backbone) 27 | model = PLIP_MResNet(image_encoder, text_encoder) 28 | return model 29 | else: 30 | raise RuntimeError(f"The image backbone you input does not meet the specification!") 31 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PLIP 2 | **PLIP** is a novel **L**anguage-**I**mage **P**re-training framework for generic **P**erson representation learning which benefits a range of 
downstream person-centric tasks. 3 | 4 | We also present **SYNTH-PEDES**, a large-scale person dataset built to verify its effectiveness, whose diverse textual descriptions are synthesized by the proposed Stylish Pedestrian Attributes-union Captioning method **(SPAC)**. 5 | 6 | Experiments show that our model not only significantly improves existing methods on downstream tasks, but also performs strongly in the few-shot and domain generalization settings. More details can be found in our paper [PLIP: Language-Image Pre-training for Person Representation Learning](http://export.arxiv.org/abs/2305.08386). 7 | 8 |
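As a quick orientation for how the released pieces fit together, here is a minimal text-to-image matching sketch. It is only an illustrative example, not the repository's official entry point (`zs_infer.py` plays that role): it wires up the downstream CMPM-C encoders from `Downstreams/CMPM-C/model.py` with the default arguments in `train.py`/`test.py`. The checkpoint path follows the training default, while the image file and caption below are placeholder assumptions.

```python
# Minimal sketch: score one image against one caption with the CMPM-C model.
# Assumed to run from Downstreams/CMPM-C/ with a pre-trained checkpoint available.
from argparse import Namespace

import torch
import torchvision.transforms as transforms
from PIL import Image
from transformers import BertTokenizer

from model import Create_model  # Downstreams/CMPM-C/model.py

# Defaults copied from train.py / test.py.
args = Namespace(layers=[3, 4, 6, 3], img_dim=768, heads=8,
                 width=128, height=256, txt_backbone="bert-base-uncased")
model = Create_model(args).eval()

# train.py / test.py store the two encoders under separate keys.
ckpt = torch.load("checkpoints/PLIP_RN50.pth.tar", map_location="cpu")
model.image_encoder.load_state_dict(ckpt["ImgEncoder_state_dict"])
model.text_encoder.load_state_dict(ckpt["TxtEncoder_state_dict"])

# Same preprocessing as the evaluation code.
transform = transforms.Compose([
    transforms.Resize((256, 128), interpolation=3),
    transforms.ToTensor(),
    transforms.Normalize([0.357, 0.323, 0.328], [0.252, 0.242, 0.239]),
])
image = transform(Image.open("example.jpg").convert("RGB")).unsqueeze(0)  # placeholder image

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
tok = tokenizer("a woman wearing a red coat and black trousers",  # placeholder caption
                padding="max_length", max_length=64, truncation=True, return_tensors="pt")

with torch.no_grad():
    img_feat = model.encode_image(image)
    txt_feat = model.encode_text(tok["input_ids"], tok["attention_mask"])

# Cosine similarity between the (image, caption) pair.
img_feat = img_feat / img_feat.norm(dim=1, keepdim=True)
txt_feat = txt_feat / txt_feat.norm(dim=1, keepdim=True)
print((img_feat @ txt_feat.t()).item())
```

For a full gallery, `test.py` builds the image and text loaders with `get_loader_test` and ranks every caption against every image via `test_map` in `utils.py`.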