├── README.md ├── core ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-35.pyc │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-37.pyc │ ├── __init__.cpython-38.pyc │ ├── base.cpython-35.pyc │ ├── base.cpython-36.pyc │ ├── base.cpython-37.pyc │ ├── base.cpython-38.pyc │ ├── test.cpython-38.pyc │ ├── train.cpython-36.pyc │ ├── train.cpython-37.pyc │ ├── train.cpython-38.pyc │ ├── train_heatmap_no_grad_reasoning.cpython-35.pyc │ ├── train_heatmapmask_id_reasoning.cpython-35.pyc │ ├── train_heatmapmask_no_grad_reasoning.cpython-35.pyc │ ├── train_heatmapmask_reasoning.cpython-35.pyc │ └── train_with_cam.cpython-35.pyc ├── base.py ├── test.py └── train.py ├── data_loader ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-38.pyc │ ├── dataset.cpython-38.pyc │ ├── loader.cpython-38.pyc │ ├── processing.cpython-38.pyc │ └── sampler.cpython-38.pyc ├── dataset.py ├── loader.py ├── processing.py └── sampler.py ├── main.py ├── network ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-35.pyc │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-37.pyc │ ├── __init__.cpython-38.pyc │ ├── gem_pool.cpython-38.pyc │ ├── lr.cpython-38.pyc │ ├── model.cpython-36.pyc │ ├── model.cpython-37.pyc │ └── model.cpython-38.pyc ├── clip │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-38.pyc │ │ ├── clip.cpython-38.pyc │ │ ├── model.cpython-38.pyc │ │ └── simple_tokenizer.cpython-38.pyc │ ├── bpe_simple_vocab_16e6.txt.gz │ ├── clip.py │ ├── model.py │ └── simple_tokenizer.py ├── gem_pool.py ├── lr.py ├── model.py └── processing.py ├── run.sh └── tools ├── MSEL.py ├── __init__.py ├── __pycache__ ├── __init__.cpython-35.pyc ├── __init__.cpython-36.pyc ├── __init__.cpython-37.pyc ├── __init__.cpython-38.pyc ├── eval_metrics.cpython-38.pyc ├── logger.cpython-35.pyc ├── logger.cpython-36.pyc ├── logger.cpython-37.pyc ├── logger.cpython-38.pyc ├── loss.cpython-35.pyc ├── loss.cpython-36.pyc ├── loss.cpython-37.pyc ├── loss.cpython-38.pyc ├── meter.cpython-35.pyc ├── meter.cpython-36.pyc ├── meter.cpython-37.pyc ├── meter.cpython-38.pyc ├── utils.cpython-35.pyc ├── utils.cpython-36.pyc ├── utils.cpython-37.pyc └── utils.cpython-38.pyc ├── eval_metrics.py ├── logger.py ├── loss.py ├── meter.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | The trained weights have been released. Link: https://pan.baidu.com/s/1kiMj34iGZMvxFrXmZ4d9EA 2 | Extraction code: b7l1 -------------------------------------------------------------------------------- /core/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import * 2 | from .train import * 3 | from .test import * -------------------------------------------------------------------------------- /core/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nengdong96/CSDN/2224959c5df9af6a58cd34a6f9ef6f6ab8177966/core/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /core/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nengdong96/CSDN/2224959c5df9af6a58cd34a6f9ef6f6ab8177966/core/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /core/__pycache__/__init__.cpython-37.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/nengdong96/CSDN/2224959c5df9af6a58cd34a6f9ef6f6ab8177966/core/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /core/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nengdong96/CSDN/2224959c5df9af6a58cd34a6f9ef6f6ab8177966/core/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /core/__pycache__/base.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nengdong96/CSDN/2224959c5df9af6a58cd34a6f9ef6f6ab8177966/core/__pycache__/base.cpython-35.pyc -------------------------------------------------------------------------------- /core/__pycache__/base.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nengdong96/CSDN/2224959c5df9af6a58cd34a6f9ef6f6ab8177966/core/__pycache__/base.cpython-36.pyc -------------------------------------------------------------------------------- /core/__pycache__/base.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nengdong96/CSDN/2224959c5df9af6a58cd34a6f9ef6f6ab8177966/core/__pycache__/base.cpython-37.pyc -------------------------------------------------------------------------------- /core/__pycache__/base.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nengdong96/CSDN/2224959c5df9af6a58cd34a6f9ef6f6ab8177966/core/__pycache__/base.cpython-38.pyc -------------------------------------------------------------------------------- /core/__pycache__/test.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nengdong96/CSDN/2224959c5df9af6a58cd34a6f9ef6f6ab8177966/core/__pycache__/test.cpython-38.pyc -------------------------------------------------------------------------------- /core/__pycache__/train.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nengdong96/CSDN/2224959c5df9af6a58cd34a6f9ef6f6ab8177966/core/__pycache__/train.cpython-36.pyc -------------------------------------------------------------------------------- /core/__pycache__/train.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nengdong96/CSDN/2224959c5df9af6a58cd34a6f9ef6f6ab8177966/core/__pycache__/train.cpython-37.pyc -------------------------------------------------------------------------------- /core/__pycache__/train.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nengdong96/CSDN/2224959c5df9af6a58cd34a6f9ef6f6ab8177966/core/__pycache__/train.cpython-38.pyc -------------------------------------------------------------------------------- /core/__pycache__/train_heatmap_no_grad_reasoning.cpython-35.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/nengdong96/CSDN/2224959c5df9af6a58cd34a6f9ef6f6ab8177966/core/__pycache__/train_heatmap_no_grad_reasoning.cpython-35.pyc -------------------------------------------------------------------------------- /core/__pycache__/train_heatmapmask_id_reasoning.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nengdong96/CSDN/2224959c5df9af6a58cd34a6f9ef6f6ab8177966/core/__pycache__/train_heatmapmask_id_reasoning.cpython-35.pyc -------------------------------------------------------------------------------- /core/__pycache__/train_heatmapmask_no_grad_reasoning.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nengdong96/CSDN/2224959c5df9af6a58cd34a6f9ef6f6ab8177966/core/__pycache__/train_heatmapmask_no_grad_reasoning.cpython-35.pyc -------------------------------------------------------------------------------- /core/__pycache__/train_heatmapmask_reasoning.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nengdong96/CSDN/2224959c5df9af6a58cd34a6f9ef6f6ab8177966/core/__pycache__/train_heatmapmask_reasoning.cpython-35.pyc -------------------------------------------------------------------------------- /core/__pycache__/train_with_cam.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nengdong96/CSDN/2224959c5df9af6a58cd34a6f9ef6f6ab8177966/core/__pycache__/train_with_cam.cpython-35.pyc -------------------------------------------------------------------------------- /core/base.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | 5 | from bisect import bisect_right 6 | from network import Model 7 | from network.lr import CosineLRScheduler 8 | from tools import os_walk, CrossEntropyLabelSmooth, SupConLoss, TripletLoss_WRT, MSEL, MSEL_Feat, MSEL_Cos 9 | 10 | def create_scheduler(optimizer, num_epochs, lr_min, warmup_lr_init, warmup_t, noise_range = None): 11 | 12 | lr_scheduler = CosineLRScheduler( 13 | optimizer, 14 | t_initial=num_epochs, 15 | lr_min=lr_min, 16 | t_mul= 1., 17 | decay_rate=0.1, 18 | warmup_lr_init=warmup_lr_init, 19 | warmup_t=warmup_t, 20 | cycle_limit=1, 21 | t_in_epochs=True, 22 | noise_range_t=noise_range, 23 | noise_pct= 0.67, 24 | noise_std= 1., 25 | noise_seed=42, 26 | ) 27 | 28 | return lr_scheduler 29 | 30 | class Base: 31 | def __init__(self, config): 32 | self.config = config 33 | 34 | self.pid_num = config.pid_num 35 | 36 | self.max_save_model_num = config.max_save_model_num 37 | self.output_path = config.output_path 38 | self.save_model_path = os.path.join(self.output_path, 'models/') 39 | self.save_logs_path = os.path.join(self.output_path, 'logs/') 40 | 41 | self.learning_rate = config.learning_rate 42 | self.weight_decay = config.weight_decay 43 | self.milestones = config.milestones 44 | 45 | self.img_h = config.img_h 46 | self.img_w = config.img_w 47 | 48 | self.stage1_learning_rate = config.stage1_learning_rate 49 | self.stage2_learning_rate = config.stage2_learning_rate 50 | self.stage1_weight_decay = config.stage1_weight_decay 51 | self.stage1_train_epochs = config.stage1_train_epochs 52 | self.stage1_lr_min = config.stage1_lr_min 53 | self.stage1_warmup_lr_init = config.stage1_warmup_lr_init 54 | self.stage1_warmup_epochs = 
config.stage1_warmup_epochs 55 | 56 | self._init_device() 57 | self._init_model() 58 | self._init_creiteron() 59 | 60 | def _init_device(self): 61 | self.device = torch.device('cuda') 62 | 63 | def _init_model(self): 64 | 65 | self.model = Model(self.pid_num, self.img_h, self.img_w) 66 | self.model = nn.DataParallel(self.model).to(self.device) 67 | 68 | def _init_creiteron(self): 69 | self.con_creiteron = SupConLoss(self.device) 70 | self.pid_creiteron = nn.CrossEntropyLoss() 71 | self.soft_pid_creiteron = CrossEntropyLabelSmooth() 72 | self.tri_creiteron = TripletLoss_WRT() 73 | self.msel_creiteron = MSEL(4) 74 | self.mselcos_creiteron = MSEL_Cos(4) 75 | self.mselfeat_creiteron = MSEL_Feat(4) 76 | 77 | def _init_optimizer_stage1(self): 78 | params = [] 79 | keys = [] 80 | for key, value in self.model.named_parameters(): 81 | if 'prompt_learner1' in key: 82 | lr = self.stage1_learning_rate 83 | weight_decay = self.stage1_weight_decay 84 | params += [{'params': [value], 'lr': lr, 'weight_decay': weight_decay}] 85 | keys += [[key]] 86 | if 'prompt_learner2' in key: 87 | lr = self.stage1_learning_rate 88 | weight_decay = self.stage1_weight_decay 89 | params += [{'params': [value], 'lr': lr, 'weight_decay': weight_decay}] 90 | keys += [[key]] 91 | 92 | self.model_optimizer_stage1 = getattr(torch.optim, 'Adam')(params) 93 | self.model_lr_scheduler_stage1 = create_scheduler(self.model_optimizer_stage1, 94 | num_epochs=self.stage1_train_epochs, lr_min=self.stage1_lr_min, 95 | warmup_lr_init=self.stage1_warmup_lr_init, 96 | warmup_t=self.stage1_warmup_epochs, noise_range=None) 97 | 98 | def _init_optimizer_stage2(self): 99 | params = [] 100 | keys = [] 101 | for key, value in self.model.named_parameters(): 102 | if 'attention_fusion' in key: 103 | lr = self.stage2_learning_rate 104 | weight_decay = self.stage1_weight_decay 105 | params += [{'params': [value], 'lr': lr, 'weight_decay': weight_decay}] 106 | keys += [[key]] 107 | 108 | self.model_optimizer_stage2 = getattr(torch.optim, 'Adam')(params) 109 | self.model_lr_scheduler_stage2 = create_scheduler(self.model_optimizer_stage2, 110 | num_epochs=self.stage1_train_epochs, lr_min=self.stage1_lr_min, 111 | warmup_lr_init=self.stage1_warmup_lr_init, 112 | warmup_t=self.stage1_warmup_epochs, noise_range=None) 113 | 114 | def _init_optimizer_stage3(self): 115 | params = [] 116 | keys = [] 117 | for key, value in self.model.named_parameters(): 118 | if 'prompt_learner1' in key: 119 | value.requires_grad_(False) 120 | continue 121 | if 'prompt_learner2' in key: 122 | value.requires_grad_(False) 123 | continue 124 | if 'attention_fusion' in key: 125 | value.requires_grad_(False) 126 | continue 127 | if 'text_encoder' in key: 128 | value.requires_grad_(False) 129 | continue 130 | lr = self.learning_rate 131 | if 'classifier' in key: 132 | lr = self.learning_rate * 2 133 | params += [{'params': [value], 'lr': lr, 'weight_decay': self.weight_decay}] 134 | keys += [[key]] 135 | 136 | self.model_optimizer_stage3 = getattr(torch.optim, 'Adam')(params) 137 | self.model_lr_scheduler_stage3 = WarmupMultiStepLR(self.model_optimizer_stage3, self.milestones, 138 | gamma=0.1, warmup_factor=0.01, warmup_iters=10) 139 | 140 | def save_model(self, save_epoch, is_best): 141 | if is_best: 142 | model_file_path = os.path.join(self.save_model_path, 'model_{}.pth'.format(save_epoch)) 143 | torch.save(self.model.state_dict(), model_file_path) 144 | 145 | if self.max_save_model_num > 0: 146 | root, _, files = os_walk(self.save_model_path) 147 | for file in files: 148 | if 
'.pth' not in file: 149 | files.remove(file) 150 | if len(files) > 1 * self.max_save_model_num: 151 | file_iters = sorted([int(file.replace('.pth', '').split('_')[1]) for file in files], reverse=False) 152 | 153 | model_file_path = os.path.join(root, 'model_{}.pth'.format(file_iters[0])) 154 | os.remove(model_file_path) 155 | 156 | def resume_last_model(self): 157 | root, _, files = os_walk(self.save_model_path) 158 | for file in files: 159 | if '.pth' not in file: 160 | files.remove(file) 161 | if len(files) > 0: 162 | indexes = [] 163 | for file in files: 164 | indexes.append(int(file.replace('.pth', '').split('_')[-1])) 165 | indexes = sorted(list(set(indexes)), reverse=False) 166 | self.resume_model(indexes[-1]) 167 | start_train_epoch = indexes[-1] 168 | return start_train_epoch 169 | else: 170 | return 0 171 | 172 | def resume_model(self, resume_epoch): 173 | model_path = os.path.join(self.save_model_path, 'model_{}.pth'.format(resume_epoch)) 174 | self.model.load_state_dict(torch.load(model_path), strict=False) 175 | print('Successfully resume model from {}'.format(model_path)) 176 | 177 | def set_train(self): 178 | self.model = self.model.train() 179 | 180 | self.training = True 181 | 182 | def set_eval(self): 183 | self.model = self.model.eval() 184 | 185 | self.training = False 186 | 187 | class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler): 188 | def __init__(self, optimizer, milestones, gamma=0.1, warmup_factor=1.0 / 3, warmup_iters=500, 189 | warmup_method='linear', last_epoch=-1): 190 | if not list(milestones) == sorted(milestones): 191 | raise ValueError( 192 | "Milestones should be a list of " " increasing integers. Got {}", milestones) 193 | 194 | if warmup_method not in ("constant", "linear"): 195 | raise ValueError( 196 | "Only 'constant' or 'linear' warmup method accepted got {}".format(warmup_method)) 197 | self.milestones = milestones 198 | self.gamma = gamma 199 | self.warmup_factor = warmup_factor 200 | self.warmup_iters = warmup_iters 201 | self.warmup_method = warmup_method 202 | super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch) 203 | 204 | def get_lr(self): 205 | warmup_factor = 1 206 | if self.last_epoch < self.warmup_iters: 207 | if self.warmup_method == "constant": 208 | warmup_factor = self.warmup_factor 209 | elif self.warmup_method == "linear": 210 | alpha = float(self.last_epoch) / float(self.warmup_iters) 211 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha 212 | 213 | return [ 214 | base_lr 215 | * warmup_factor 216 | * self.gamma ** bisect_right(self.milestones, self.last_epoch) 217 | for base_lr in self.base_lrs 218 | ] 219 | -------------------------------------------------------------------------------- /core/test.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import torch 4 | from torch.autograd import Variable 5 | from tools import eval_regdb, eval_sysu 6 | import os 7 | os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" 8 | 9 | def test(base, loader, config): 10 | base.set_eval() 11 | print('Extracting Query Feature...') 12 | ptr = 0 13 | query_feat = np.zeros((loader.n_query, 6144)) 14 | with torch.no_grad(): 15 | for batch_idx, (input, label) in enumerate(loader.query_loader): 16 | batch_num = input.size(0) 17 | input = Variable(input.cuda()) 18 | flip_input = torch.flip(input, [3]).cuda() 19 | feat = base.model(x2=input) 20 | flip_feat = base.model(x2=flip_input) 21 | feat = torch.cat([feat, flip_feat], dim=1) 22 | query_feat[ptr:ptr + batch_num, :] 
= feat.detach().cpu().numpy() 23 | ptr = ptr + batch_num 24 | 25 | print('Extracting Gallery Feature...') 26 | 27 | if loader.dataset == 'sysu': 28 | all_cmc = 0 29 | all_mAP = 0 30 | all_mINP = 0 31 | for i in range(10): 32 | ptr = 0 33 | gall_loader = loader.gallery_loaders[i] 34 | gall_feat = np.zeros((loader.n_gallery, 6144)) 35 | with torch.no_grad(): 36 | for batch_idx, (input, label) in enumerate(gall_loader): 37 | batch_num = input.size(0) 38 | input = Variable(input.cuda()) 39 | flip_input = torch.flip(input, [3]).cuda() 40 | feat = base.model(x1=input) 41 | flip_feat = base.model(x1=flip_input) 42 | feat = torch.cat([feat, flip_feat], dim=1) 43 | gall_feat[ptr:ptr + batch_num, :] = feat.detach().cpu().numpy() 44 | ptr = ptr + batch_num 45 | distmat = np.matmul(query_feat, np.transpose(gall_feat)) 46 | cmc, mAP, mINP = eval_sysu(-distmat, loader.query_label, loader.gall_label, loader.query_cam, 47 | loader.gall_cam) 48 | all_cmc += cmc 49 | all_mAP += mAP 50 | all_mINP += mINP 51 | all_cmc /= 10.0 52 | all_mAP /= 10.0 53 | all_mINP /= 10.0 54 | 55 | elif loader.dataset == 'regdb': 56 | gall_loader = loader.gallery_loaders 57 | gall_feat = np.zeros((loader.n_gallery, 6144)) 58 | ptr = 0 59 | with torch.no_grad(): 60 | for batch_idx, (input, label) in enumerate(gall_loader): 61 | batch_num = input.size(0) 62 | input = Variable(input.cuda()) 63 | flip_input = torch.flip(input, [3]).cuda() 64 | feat = base.model(x1=input) 65 | flip_feat = base.model(x1=flip_input) 66 | feat = torch.cat([feat, flip_feat], dim=1) 67 | gall_feat[ptr:ptr + batch_num, :] = feat.detach().cpu().numpy() 68 | 69 | ptr = ptr + batch_num 70 | if config.regdb_test_mode == 't-v': 71 | distmat = np.matmul(query_feat, np.transpose(gall_feat)) 72 | cmc, mAP, mINP = eval_regdb(-distmat, loader.query_label, loader.gall_label) 73 | else: 74 | distmat = np.matmul(gall_feat, np.transpose(query_feat)) 75 | cmc, mAP, mINP = eval_regdb(-distmat, loader.gall_label, loader.query_label) 76 | 77 | all_cmc, all_mAP, all_mINP = cmc, mAP, mINP 78 | 79 | 80 | return all_cmc, all_mAP, all_mINP -------------------------------------------------------------------------------- /core/train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from tools import MultiItemAverageMeter 3 | from network.processing import FeatureShuffling 4 | 5 | def train_stage1(base, num_image, i_ter, batch, visible_labels_list, visible_image_features_list, 6 | infrared_labels_list, infrared_image_features_list): 7 | base.set_train() 8 | meter = MultiItemAverageMeter() 9 | iter_list = torch.randperm(num_image).to(base.device) 10 | for i in range(i_ter): 11 | b_list = iter_list[i*batch: (i+1)*batch] 12 | rgb_target = visible_labels_list[b_list].long() 13 | ir_target = infrared_labels_list[b_list].long() 14 | rgb_image_features = visible_image_features_list[b_list] 15 | ir_image_features = infrared_image_features_list[b_list] 16 | rgb_text_features = base.model(label1=rgb_target, get_text=True) 17 | ir_text_features = base.model(label2=ir_target, get_text=True) 18 | image_features = torch.cat([rgb_image_features, ir_image_features], dim=0) 19 | text_features = torch.cat([rgb_text_features, ir_text_features], dim=0) 20 | target = torch.cat([rgb_target, ir_target], dim=0) 21 | loss_i2t = base.con_creiteron(image_features, text_features, target, target) 22 | loss_t2i = base.con_creiteron(text_features, image_features, target, target) 23 | 24 | loss = loss_i2t + loss_t2i 25 | 
base.model_optimizer_stage1.zero_grad() 26 | loss.backward() 27 | base.model_optimizer_stage1.step() 28 | 29 | meter.update({'loss_i2t': loss_i2t.data, 30 | 'loss_t2i': loss_t2i.data,}) 31 | 32 | return meter.get_val(), meter.get_str() 33 | 34 | def train_stage2(base, num_image, i_ter, batch, labels_list, image_features_list): 35 | base.set_train() 36 | meter = MultiItemAverageMeter() 37 | iter_list = torch.randperm(num_image).to(base.device) 38 | for i in range(i_ter): 39 | b_list = iter_list[i*batch: (i+1)*batch] 40 | target = labels_list[b_list].long() 41 | image_features = image_features_list[b_list] 42 | text_features = base.model(label=target, get_fusion_text=True) 43 | loss_i2t = base.con_creiteron(image_features, text_features, target, target) 44 | loss_t2i = base.con_creiteron(text_features, image_features, target, target) 45 | 46 | loss = loss_i2t + loss_t2i 47 | base.model_optimizer_stage2.zero_grad() 48 | loss.backward() 49 | base.model_optimizer_stage2.step() 50 | 51 | meter.update({'loss_i2t': loss_i2t.data, 52 | 'loss_t2i': loss_t2i.data,}) 53 | 54 | return meter.get_val(), meter.get_str() 55 | 56 | def train(base, loaders, text_features, config, current_epoch): 57 | 58 | base.set_train() 59 | meter = MultiItemAverageMeter() 60 | loader = loaders.get_train_loader() 61 | for i, (input1_0, input1_1, input2, label1, label2) in enumerate(loader): 62 | rgb_imgs1, rgb_imgs2, rgb_pids = input1_0, input1_1, label1 63 | ir_imgs, ir_pids = input2, label2 64 | rgb_imgs1, rgb_imgs2, rgb_pids = rgb_imgs1.to(base.device), rgb_imgs2.to(base.device), \ 65 | rgb_pids.to(base.device).long() 66 | ir_imgs, ir_pids = ir_imgs.to(base.device), ir_pids.to(base.device).long() 67 | 68 | rgb_imgs = torch.cat([rgb_imgs1, rgb_imgs2], dim=0) 69 | pids = torch.cat([rgb_pids, rgb_pids, ir_pids], dim=0) 70 | 71 | features, cls_score = base.model(x1=rgb_imgs, x2=ir_imgs) 72 | 73 | n = features[1].shape[0] // 3 74 | rgb_features = features[0].squeeze().narrow(0, 0, n)#32.2048 75 | ir_features = features[0].squeeze().narrow(0, 2 * n, n)#32,2048 76 | rgb_attn_features = features[1].narrow(0, 0, n) 77 | ir_attn_features = features[1].narrow(0, 2 * n, n) 78 | 79 | rgb_logits = rgb_attn_features @ text_features.t() 80 | ir_logits = ir_attn_features @ text_features.t() 81 | 82 | ide_loss = base.pid_creiteron(cls_score[0], pids) 83 | ide_loss_proj = base.pid_creiteron(cls_score[1], pids) 84 | triplet_loss = base.tri_creiteron(features[0].squeeze(), pids) 85 | triplet_loss_proj = base.tri_creiteron(features[1].squeeze(), pids) 86 | msel_loss = base.msel_creiteron(torch.cat([rgb_features, ir_features], dim=0), torch.cat([rgb_pids, ir_pids], dim=0)) 87 | msel_loss_proj = base.msel_creiteron(torch.cat([rgb_attn_features, ir_attn_features], dim=0), 88 | torch.cat([rgb_pids, ir_pids], dim=0)) 89 | 90 | rgb_i2t_ide_loss = base.pid_creiteron(rgb_logits, rgb_pids) 91 | ir_i2t_ide_loss = base.pid_creiteron(ir_logits, ir_pids) 92 | 93 | loss = ide_loss + ide_loss_proj + (msel_loss + msel_loss_proj) + \ 94 | config.lambda1 * (triplet_loss + triplet_loss_proj) + \ 95 | config.lambda2 * rgb_i2t_ide_loss + config.lambda3 * ir_i2t_ide_loss 96 | 97 | base.model_optimizer_stage3.zero_grad() 98 | loss.backward() 99 | base.model_optimizer_stage3.step() 100 | meter.update({'pid_loss': ide_loss.data, 101 | 'pid_loss_proj': ide_loss_proj.data, 102 | 'triplet_loss': triplet_loss.data, 103 | 'triplet_loss_proj': triplet_loss_proj.data, 104 | 'rgb_i2t_pid_loss': rgb_i2t_ide_loss.data, 105 | 'ir_i2t_pid_loss': ir_i2t_ide_loss.data, 106 | 
'msel_loss': msel_loss.data, 107 | 'msel_loss_proj': msel_loss_proj.data, 108 | }) 109 | return meter.get_val(), meter.get_str() 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | -------------------------------------------------------------------------------- /data_loader/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nengdong96/CSDN/2224959c5df9af6a58cd34a6f9ef6f6ab8177966/data_loader/__init__.py -------------------------------------------------------------------------------- /data_loader/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nengdong96/CSDN/2224959c5df9af6a58cd34a6f9ef6f6ab8177966/data_loader/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /data_loader/__pycache__/dataset.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nengdong96/CSDN/2224959c5df9af6a58cd34a6f9ef6f6ab8177966/data_loader/__pycache__/dataset.cpython-38.pyc -------------------------------------------------------------------------------- /data_loader/__pycache__/loader.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nengdong96/CSDN/2224959c5df9af6a58cd34a6f9ef6f6ab8177966/data_loader/__pycache__/loader.cpython-38.pyc -------------------------------------------------------------------------------- /data_loader/__pycache__/processing.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nengdong96/CSDN/2224959c5df9af6a58cd34a6f9ef6f6ab8177966/data_loader/__pycache__/processing.cpython-38.pyc -------------------------------------------------------------------------------- /data_loader/__pycache__/sampler.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nengdong96/CSDN/2224959c5df9af6a58cd34a6f9ef6f6ab8177966/data_loader/__pycache__/sampler.cpython-38.pyc -------------------------------------------------------------------------------- /data_loader/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | import random 4 | import numpy as np 5 | import torch.utils.data as data 6 | from PIL import Image 7 | 8 | class SYSUData(data.Dataset): 9 | def __init__(self, data_dir, transform1=None, transform2=None, transform3=None, colorIndex=None, thermalIndex=None): 10 | train_color_image = np.load(data_dir + 'train_rgb_resized_img.npy') 11 | self.train_color_label = np.load(data_dir + 'train_rgb_resized_label.npy') 12 | 13 | train_thermal_image = np.load(data_dir + 'train_ir_resized_img.npy') 14 | self.train_thermal_label = np.load(data_dir + 'train_ir_resized_label.npy') 15 | 16 | # RGB format 17 | self.train_color_image = train_color_image 18 | self.train_thermal_image = train_thermal_image 19 | self.transform1 = transform1 20 | self.transform2 = transform2 21 | self.transform3 = transform3 22 | self.cIndex = colorIndex 23 | self.tIndex = thermalIndex 24 | 25 | def __getitem__(self, index): 26 | 27 | img1, target1 = self.train_color_image[self.cIndex[index]], self.train_color_label[self.cIndex[index]] 28 | img2, target2 = self.train_thermal_image[self.tIndex[index]], 
self.train_thermal_label[self.tIndex[index]] 29 | 30 | img1_0 = self.transform1(img1) 31 | img1_1 = self.transform2(img1) 32 | img2 = self.transform3(img2) 33 | 34 | return img1_0, img1_1, img2, target1, target2 35 | 36 | def __len__(self): 37 | return len(self.train_color_label) 38 | 39 | class SYSUDataNormalSamples(data.Dataset): 40 | def __init__(self, data_dir, transform1=None, transform2=None, colorIndex=None, thermalIndex=None): 41 | train_color_image = np.load(data_dir + 'train_rgb_resized_img.npy') 42 | self.train_color_label = np.load(data_dir + 'train_rgb_resized_label.npy') 43 | 44 | train_thermal_image = np.load(data_dir + 'train_ir_resized_img.npy') 45 | self.train_thermal_label = np.load(data_dir + 'train_ir_resized_label.npy') 46 | 47 | # RGB format 48 | self.train_color_image = train_color_image 49 | self.train_thermal_image = train_thermal_image 50 | self.transform1 = transform1 51 | self.transform2 = transform2 52 | self.cIndex = colorIndex 53 | self.tIndex = thermalIndex 54 | 55 | def __getitem__(self, index): 56 | 57 | img1, target1 = self.train_color_image[self.cIndex[index]], self.train_color_label[self.cIndex[index]] 58 | img2, target2 = self.train_thermal_image[self.tIndex[index]], self.train_thermal_label[self.tIndex[index]] 59 | 60 | img1 = self.transform1(img1) 61 | img2 = self.transform2(img2) 62 | 63 | return img1, img2, target1, target2 64 | 65 | def __len__(self): 66 | return len(self.train_color_label) 67 | 68 | class SYSUDataRGBNormalSamples: 69 | def __init__(self, data_dir): 70 | train_color_image = np.load(data_dir + 'train_rgb_resized_img.npy') 71 | self.train_color_label = np.load(data_dir + 'train_rgb_resized_label.npy') 72 | 73 | # RGB format 74 | self.train_color_image = train_color_image 75 | 76 | samples = self._load_samples() 77 | self.samples = samples 78 | 79 | def _load_samples(self): 80 | samples = [] 81 | for i in range(self.train_color_label.shape[0]): 82 | samples.append([self.train_color_image[i], self.train_color_label[i]]) 83 | 84 | return samples 85 | 86 | class SYSUDataIRNormalSamples: 87 | def __init__(self, data_dir): 88 | train_thermal_image = np.load(data_dir + 'train_ir_resized_img.npy') 89 | self.train_thermal_label = np.load(data_dir + 'train_ir_resized_label.npy') 90 | 91 | # RGB format 92 | self.train_thermal_image = train_thermal_image 93 | 94 | samples = self._load_samples() 95 | self.samples = samples 96 | 97 | def _load_samples(self): 98 | samples = [] 99 | for i in range(self.train_thermal_label.shape[0]): 100 | samples.append([self.train_thermal_image[i], self.train_thermal_label[i]]) 101 | 102 | return samples 103 | 104 | 105 | class RegDBData(data.Dataset): 106 | def __init__(self, data_dir, trial, transform1=None, transform2=None, transform3=None, 107 | colorIndex=None, thermalIndex=None): 108 | train_color_list = data_dir + 'idx/train_visible_{}'.format(trial) + '.txt' 109 | train_thermal_list = data_dir + 'idx/train_thermal_{}'.format(trial) + '.txt' 110 | 111 | color_img_file, train_color_label = load_data(train_color_list) 112 | thermal_img_file, train_thermal_label = load_data(train_thermal_list) 113 | 114 | train_color_image = [] 115 | for i in range(len(color_img_file)): 116 | img = Image.open(data_dir + color_img_file[i]) 117 | img = img.resize((144, 288), Image.ANTIALIAS) 118 | pix_array = np.array(img) 119 | train_color_image.append(pix_array) 120 | train_color_image = np.array(train_color_image) 121 | 122 | train_thermal_image = [] 123 | for i in range(len(thermal_img_file)): 124 | img = 
Image.open(data_dir + thermal_img_file[i]) 125 | img = img.resize((144, 288), Image.ANTIALIAS) 126 | pix_array = np.array(img) 127 | train_thermal_image.append(pix_array) 128 | train_thermal_image = np.array(train_thermal_image) 129 | 130 | # RGB format 131 | self.train_color_image = train_color_image 132 | self.train_color_label = train_color_label 133 | 134 | # RGB format 135 | self.train_thermal_image = train_thermal_image 136 | self.train_thermal_label = train_thermal_label 137 | 138 | self.transform1 = transform1 139 | self.transform2 = transform2 140 | self.transform3 = transform3 141 | self.cIndex = colorIndex 142 | self.tIndex = thermalIndex 143 | 144 | def __getitem__(self, index): 145 | 146 | img1, target1 = self.train_color_image[self.cIndex[index]], self.train_color_label[self.cIndex[index]] 147 | img2, target2 = self.train_thermal_image[self.tIndex[index]], self.train_thermal_label[self.tIndex[index]] 148 | 149 | img1_0 = self.transform1(img1) 150 | img1_1 = self.transform2(img1) 151 | img2 = self.transform3(img2) 152 | 153 | return img1_0, img1_1, img2, target1, target2 154 | 155 | def __len__(self): 156 | return len(self.train_color_label) 157 | 158 | class RegDBDataNormalSamples(data.Dataset): 159 | def __init__(self, data_dir, trial, transform1=None, transform2=None, colorIndex=None, thermalIndex=None): 160 | train_color_list = data_dir + 'idx/train_visible_{}'.format(trial) + '.txt' 161 | train_thermal_list = data_dir + 'idx/train_thermal_{}'.format(trial) + '.txt' 162 | 163 | color_img_file, train_color_label = load_data(train_color_list) 164 | thermal_img_file, train_thermal_label = load_data(train_thermal_list) 165 | 166 | train_color_image = [] 167 | for i in range(len(color_img_file)): 168 | img = Image.open(data_dir + color_img_file[i]) 169 | img = img.resize((144, 288), Image.ANTIALIAS) 170 | pix_array = np.array(img) 171 | train_color_image.append(pix_array) 172 | train_color_image = np.array(train_color_image) 173 | 174 | train_thermal_image = [] 175 | for i in range(len(thermal_img_file)): 176 | img = Image.open(data_dir + thermal_img_file[i]) 177 | img = img.resize((144, 288), Image.ANTIALIAS) 178 | pix_array = np.array(img) 179 | train_thermal_image.append(pix_array) 180 | train_thermal_image = np.array(train_thermal_image) 181 | 182 | # RGB format 183 | self.train_color_image = train_color_image 184 | self.train_color_label = train_color_label 185 | 186 | # RGB format 187 | self.train_thermal_image = train_thermal_image 188 | self.train_thermal_label = train_thermal_label 189 | 190 | self.transform1 = transform1 191 | self.transform2 = transform2 192 | self.cIndex = colorIndex 193 | self.tIndex = thermalIndex 194 | 195 | def __getitem__(self, index): 196 | 197 | img1, target1 = self.train_color_image[self.cIndex[index]], self.train_color_label[self.cIndex[index]] 198 | img2, target2 = self.train_thermal_image[self.tIndex[index]], self.train_thermal_label[self.tIndex[index]] 199 | 200 | img1 = self.transform1(img1) 201 | img2 = self.transform2(img2) 202 | 203 | return img1, img2, target1, target2 204 | 205 | def __len__(self): 206 | return len(self.train_color_label) 207 | 208 | class RegDBDataRGBSamples(data.Dataset): 209 | def __init__(self, data_dir, trial): 210 | train_color_list = data_dir + 'idx/train_visible_{}'.format(trial) + '.txt' 211 | 212 | color_img_file, train_color_label = load_data(train_color_list) 213 | 214 | train_color_image = [] 215 | for i in range(len(color_img_file)): 216 | img = Image.open(data_dir + color_img_file[i]) 217 | img = 
img.resize((144, 288), Image.ANTIALIAS) 218 | pix_array = np.array(img) 219 | train_color_image.append(pix_array) 220 | train_color_image = np.array(train_color_image) 221 | 222 | # RGB format 223 | self.train_color_image = train_color_image 224 | self.train_color_label = train_color_label 225 | 226 | samples = self._load_samples() 227 | self.samples = samples 228 | 229 | def _load_samples(self): 230 | samples = [] 231 | for i in range(len(self.train_color_label)): 232 | samples.append([self.train_color_image[i], self.train_color_label[i]]) 233 | 234 | return samples 235 | 236 | class RegDBDataIRSamples(data.Dataset): 237 | def __init__(self, data_dir, trial): 238 | train_thermal_list = data_dir + 'idx/train_thermal_{}'.format(trial) + '.txt' 239 | 240 | thermal_img_file, train_thermal_label = load_data(train_thermal_list) 241 | 242 | train_thermal_image = [] 243 | for i in range(len(thermal_img_file)): 244 | img = Image.open(data_dir + thermal_img_file[i]) 245 | img = img.resize((144, 288), Image.ANTIALIAS) 246 | pix_array = np.array(img) 247 | train_thermal_image.append(pix_array) 248 | train_thermal_image = np.array(train_thermal_image) 249 | 250 | # RGB format 251 | self.train_thermal_image = train_thermal_image 252 | self.train_thermal_label = train_thermal_label 253 | 254 | samples = self._load_samples() 255 | self.samples = samples 256 | 257 | def _load_samples(self): 258 | samples = [] 259 | for i in range(len(self.train_thermal_label)): 260 | samples.append([self.train_thermal_image[i], self.train_thermal_label[i]]) 261 | 262 | return samples 263 | 264 | class TestData(data.Dataset): 265 | def __init__(self, test_img_file, test_label, transform=None, img_size=(224, 224)): 266 | test_image = [] 267 | for i in range(len(test_img_file)): 268 | img = Image.open(test_img_file[i]) 269 | img = img.resize((img_size[0], img_size[1]), Image.ANTIALIAS) 270 | pix_array = np.array(img) 271 | test_image.append(pix_array) 272 | test_image = np.array(test_image) 273 | self.test_image = test_image 274 | self.test_label = test_label 275 | self.transform = transform 276 | 277 | def __getitem__(self, index): 278 | img1, target1 = self.test_image[index], self.test_label[index] 279 | img1 = self.transform(img1) 280 | return img1, target1 281 | 282 | def __len__(self): 283 | return len(self.test_image) 284 | 285 | def load_data(input_data_path): 286 | with open(input_data_path) as f: 287 | data_file_list = open(input_data_path, 'rt').read().splitlines() 288 | # Get full list of image and labels 289 | file_image = [s.split(' ')[0] for s in data_file_list] 290 | file_label = [int(s.split(' ')[1]) for s in data_file_list] 291 | 292 | return file_image, file_label 293 | 294 | def process_query_sysu(data_path, mode='all', relabel=False): 295 | 296 | if mode == 'all': 297 | ir_cameras = ['cam3', 'cam6'] 298 | elif mode =='indoor': 299 | ir_cameras = ['cam3', 'cam6'] 300 | 301 | file_path = os.path.join(data_path, 'exp/test_id.txt') 302 | files_ir = [] 303 | 304 | with open(file_path, 'r') as file: 305 | ids = file.read().splitlines() 306 | ids = [int(y) for y in ids[0].split(',')] 307 | ids = ["%04d" % x for x in ids] 308 | 309 | for id in sorted(ids): 310 | for cam in ir_cameras: 311 | img_dir = os.path.join(data_path, cam, id) 312 | if os.path.isdir(img_dir): 313 | new_files = sorted([img_dir + '/' + i for i in os.listdir(img_dir)]) 314 | files_ir.extend(new_files) 315 | query_img = [] 316 | query_id = [] 317 | query_cam = [] 318 | for img_path in files_ir: 319 | camid, pid = int(img_path[-15]), 
int(img_path[-13:-9]) 320 | query_img.append(img_path) 321 | query_id.append(pid) 322 | query_cam.append(camid) 323 | 324 | return query_img, np.array(query_id), np.array(query_cam) 325 | 326 | def process_gallery_sysu(data_path, mode='all', trial=0, relabel=False, gall_mode='single'): 327 | 328 | random.seed(trial) 329 | 330 | if mode == 'all': 331 | rgb_cameras = ['cam1', 'cam2', 'cam4', 'cam5'] 332 | elif mode == 'indoor': 333 | rgb_cameras = ['cam1', 'cam2'] 334 | 335 | file_path = os.path.join(data_path, 'exp/test_id.txt') 336 | files_rgb = [] 337 | with open(file_path, 'r') as file: 338 | ids = file.read().splitlines() 339 | ids = [int(y) for y in ids[0].split(',')] 340 | ids = ["%04d" % x for x in ids] 341 | 342 | for id in sorted(ids): 343 | for cam in rgb_cameras: 344 | img_dir = os.path.join(data_path, cam, id) 345 | if os.path.isdir(img_dir): 346 | new_files = sorted([img_dir + '/' + i for i in os.listdir(img_dir)]) 347 | if gall_mode == 'single': 348 | files_rgb.append(random.choice(new_files)) 349 | if gall_mode == 'multi': 350 | files_rgb.append(np.random.choice(new_files, 10, replace=False)) 351 | gall_img = [] 352 | gall_id = [] 353 | gall_cam = [] 354 | 355 | for img_path in files_rgb: 356 | if gall_mode == 'single': 357 | camid, pid = int(img_path[-15]), int(img_path[-13:-9]) 358 | gall_img.append(img_path) 359 | gall_id.append(pid) 360 | gall_cam.append(camid) 361 | 362 | if gall_mode == 'multi': 363 | for i in img_path: 364 | camid, pid = int(i[-15]), int(i[-13:-9]) 365 | gall_img.append(i) 366 | gall_id.append(pid) 367 | gall_cam.append(camid) 368 | 369 | return gall_img, np.array(gall_id), np.array(gall_cam) 370 | 371 | 372 | def process_test_regdb(img_dir, trial=1, modal='visible'): 373 | if modal == 'visible': 374 | input_data_path = img_dir + 'idx/test_visible_{}'.format(trial) + '.txt' 375 | elif modal == 'thermal': 376 | input_data_path = img_dir + 'idx/test_thermal_{}'.format(trial) + '.txt' 377 | 378 | with open(input_data_path) as f: 379 | data_file_list = open(input_data_path, 'rt').read().splitlines() 380 | # Get full list of image and labels 381 | file_image = [img_dir + '/' + s.split(' ')[0] for s in data_file_list] 382 | file_label = [int(s.split(' ')[1]) for s in data_file_list] 383 | 384 | return file_image, np.array(file_label) 385 | 386 | class Dataset: 387 | 388 | def __init__(self, samples, transform): 389 | self.samples = samples 390 | self.transform = transform 391 | 392 | def __getitem__(self, index): 393 | this_sample = copy.deepcopy(self.samples[index]) 394 | if self.transform is not None: 395 | this_sample[0] = self.transform(this_sample[0]) 396 | this_sample[1] = np.array(this_sample[1]) 397 | return this_sample 398 | 399 | def __len__(self): 400 | return len(self.samples) 401 | 402 | -------------------------------------------------------------------------------- /data_loader/loader.py: -------------------------------------------------------------------------------- 1 | 2 | import torchvision.transforms as transforms 3 | from data_loader.dataset import SYSUData, RegDBData, TestData, process_query_sysu, process_gallery_sysu, \ 4 | process_test_regdb, SYSUDataNormalSamples, Dataset, SYSUDataRGBNormalSamples, SYSUDataIRNormalSamples, \ 5 | RegDBDataNormalSamples, RegDBDataRGBSamples, RegDBDataIRSamples 6 | from data_loader.processing import ChannelRandomErasing, ChannelAdapGray, ChannelExchange 7 | from data_loader.sampler import GenIdx, IdentitySampler 8 | 9 | import torch.utils.data as data 10 | 11 | class Loader: 12 | 13 | def __init__(self, 
config): 14 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 15 | 16 | self.transform_color1 = transforms.Compose([ 17 | transforms.ToPILImage(), 18 | transforms.Pad(10), 19 | transforms.RandomCrop((288, 144)), 20 | transforms.RandomHorizontalFlip(), 21 | transforms.ToTensor(), 22 | normalize, 23 | ChannelRandomErasing(probability=0.5)]) 24 | 25 | self.transform_color2 = transforms.Compose([ 26 | transforms.ToPILImage(), 27 | transforms.Pad(10), 28 | transforms.RandomCrop((288, 144)), 29 | transforms.RandomHorizontalFlip(), 30 | transforms.ToTensor(), 31 | normalize, 32 | ChannelRandomErasing(probability=0.5), 33 | ChannelExchange(gray=2)]) 34 | 35 | self.transform_thermal = transforms.Compose([ 36 | transforms.ToPILImage(), 37 | transforms.Pad(10), 38 | transforms.RandomCrop((288, 144)), 39 | transforms.RandomHorizontalFlip(), 40 | transforms.ToTensor(), 41 | normalize, 42 | ChannelRandomErasing(probability=0.5), 43 | ChannelAdapGray(probability=0.5)]) 44 | 45 | self.transform_test = transforms.Compose([ 46 | transforms.ToPILImage(), 47 | transforms.Resize((config.img_h, config.img_w)), 48 | transforms.ToTensor(), 49 | normalize]) 50 | 51 | self.dataset = config.dataset 52 | self.sysu_data_path = config.sysu_data_path 53 | self.regdb_data_path = config.regdb_data_path 54 | 55 | self.trial = config.trial 56 | 57 | self.img_w = config.img_w 58 | self.img_h = config.img_h 59 | 60 | self.num_pos = config.num_pos 61 | self.stage1_batch_size = config.stage1_batch_size 62 | self.batch_size = config.batch_size 63 | 64 | 65 | self.test_mode = config.test_mode 66 | 67 | self.gall_mode = config.gall_mode 68 | 69 | self.num_workers = config.num_workers 70 | 71 | self._loader() 72 | 73 | def _loader(self): 74 | if self.dataset == 'sysu': 75 | samples = SYSUData(self.sysu_data_path, transform1=self.transform_color1, transform2=self.transform_color2, 76 | transform3=self.transform_thermal) 77 | self.color_pos, self.thermal_pos = GenIdx(samples.train_color_label, samples.train_thermal_label) 78 | self.samples = samples 79 | 80 | rgb_samples = SYSUDataRGBNormalSamples(self.sysu_data_path) 81 | ir_samples = SYSUDataIRNormalSamples(self.sysu_data_path) 82 | 83 | self.stage1_rgb_loader = self.get_stage1_rgb_loader(rgb_samples) 84 | self.stage1_ir_loader = self.get_stage1_ir_loader(ir_samples) 85 | 86 | normal_samples = SYSUDataNormalSamples(self.sysu_data_path, transform1=self.transform_test, 87 | transform2=self.transform_test) 88 | self.normal_color_pos, self.normal_thermal_pos = GenIdx(normal_samples.train_color_label, 89 | normal_samples.train_thermal_label) 90 | self.normal_samples = normal_samples 91 | 92 | query_samples, gallery_samples_list = self._get_test_samples(self.dataset) 93 | query_loader = data.DataLoader(query_samples, batch_size=128, shuffle=False, drop_last=False, 94 | num_workers=self.num_workers) 95 | gallery_loaders = [] 96 | for i in range(10): 97 | gallery_loader = data.DataLoader(gallery_samples_list[i], batch_size=128, shuffle=False, 98 | drop_last=False, num_workers=self.num_workers) 99 | gallery_loaders.append(gallery_loader) 100 | self.query_loader = query_loader 101 | self.gallery_loaders = gallery_loaders 102 | 103 | elif self.dataset == 'regdb': 104 | samples = RegDBData(self.regdb_data_path, self.trial, transform1=self.transform_color1, 105 | transform2=self.transform_color2, transform3=self.transform_thermal) 106 | self.color_pos, self.thermal_pos = GenIdx(samples.train_color_label, samples.train_thermal_label) 107 | self.samples = 
samples 108 | 109 | rgb_samples = RegDBDataRGBSamples(self.regdb_data_path, self.trial) 110 | ir_samples = RegDBDataIRSamples(self.regdb_data_path, self.trial) 111 | 112 | self.stage1_rgb_loader = self.get_stage1_rgb_loader(rgb_samples) 113 | self.stage1_ir_loader = self.get_stage1_ir_loader(ir_samples) 114 | 115 | normal_samples = RegDBDataNormalSamples(self.regdb_data_path, self.trial, transform1=self.transform_test, 116 | transform2=self.transform_test) 117 | self.normal_color_pos, self.normal_thermal_pos = GenIdx(normal_samples.train_color_label, 118 | normal_samples.train_thermal_label) 119 | self.normal_samples = normal_samples 120 | 121 | query_samples, gallery_samples = self._get_test_samples(self.dataset) 122 | self.query_loader = data.DataLoader(query_samples, batch_size=128, shuffle=False, drop_last=False, 123 | num_workers=self.num_workers) 124 | gallery_loader = data.DataLoader(gallery_samples, batch_size=128, shuffle=False, drop_last=False, 125 | num_workers=self.num_workers) 126 | self.gallery_loaders = gallery_loader 127 | 128 | def _get_test_samples(self, dataset): 129 | if dataset == 'sysu': 130 | query_img, query_label, query_cam = process_query_sysu(self.sysu_data_path, mode=self.test_mode) 131 | query_samples = TestData(query_img, query_label, transform=self.transform_test, 132 | img_size=(self.img_w, self.img_h)) 133 | self.query_label = query_label 134 | self.query_cam = query_cam 135 | 136 | self.n_query = len(query_label) 137 | 138 | gallery_samples_list = [] 139 | for i in range(10): 140 | gall_img, gall_label, gall_cam = process_gallery_sysu(self.sysu_data_path, mode=self.test_mode, trial=i, 141 | gall_mode=self.gall_mode) 142 | self.gall_cam = gall_cam 143 | self.gall_label = gall_label 144 | self.n_gallery = len(gall_label) 145 | 146 | gallery_samples = TestData(gall_img, gall_label, transform=self.transform_test, 147 | img_size=(self.img_w, self.img_h)) 148 | gallery_samples_list.append(gallery_samples) 149 | return query_samples, gallery_samples_list 150 | elif self.dataset == 'regdb': 151 | query_img, query_label = process_test_regdb(self.regdb_data_path, trial=self.trial, modal='thermal') 152 | query_samples = TestData(query_img, query_label, transform=self.transform_test, 153 | img_size=(self.img_w, self.img_h)) 154 | self.query_label = query_label 155 | 156 | self.n_query = len(query_label) 157 | gall_img, gall_label = process_test_regdb(self.regdb_data_path, trial=self.trial, modal='visible') 158 | gallery_samples = TestData(gall_img, gall_label, transform=self.transform_test, 159 | img_size=(self.img_w, self.img_h)) 160 | self.gall_label = gall_label 161 | self.n_gallery = len(gall_label) 162 | return query_samples, gallery_samples 163 | 164 | def get_train_loader(self): 165 | sampler = IdentitySampler(self.samples.train_color_label, self.samples.train_thermal_label, self.color_pos, 166 | self.thermal_pos, self.num_pos, int(self.batch_size / self.num_pos)) 167 | self.samples.cIndex = sampler.index1 168 | self.samples.tIndex = sampler.index2 169 | train_loader = data.DataLoader(self.samples, batch_size=self.batch_size, 170 | sampler=sampler, num_workers=self.num_workers, drop_last=True) 171 | return train_loader 172 | 173 | def get_train_normal_loader(self): 174 | normal_sampler = IdentitySampler(self.normal_samples.train_color_label, self.normal_samples.train_thermal_label, 175 | self.normal_color_pos, self.normal_thermal_pos, self.num_pos, 176 | int(self.batch_size / self.num_pos)) 177 | self.normal_samples.cIndex = normal_sampler.index1 178 | 
self.normal_samples.tIndex = normal_sampler.index2 179 | normal_train_loader = data.DataLoader(self.normal_samples, batch_size=self.stage1_batch_size, 180 | sampler=normal_sampler, num_workers=self.num_workers, drop_last=True) 181 | return normal_train_loader 182 | 183 | def get_stage1_rgb_loader(self, rgb_samples): 184 | datset = Dataset(rgb_samples.samples, transform=self.transform_test) 185 | train_loader = data.DataLoader(datset, batch_size=self.stage1_batch_size, num_workers=self.num_workers, 186 | shuffle=True, drop_last=True) 187 | return train_loader 188 | 189 | def get_stage1_ir_loader(self, ir_samples): 190 | datset = Dataset(ir_samples.samples, transform=self.transform_test) 191 | train_loader = data.DataLoader(datset, batch_size=self.stage1_batch_size, num_workers=self.num_workers, 192 | shuffle=True, drop_last=True) 193 | return train_loader -------------------------------------------------------------------------------- /data_loader/processing.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import random 4 | import math 5 | 6 | class ChannelAdap(object): 7 | 8 | def __init__(self, probability=0.5): 9 | self.probability = probability 10 | 11 | def __call__(self, img): 12 | 13 | idx = random.randint(0, 3) 14 | 15 | if idx == 0: 16 | img[1, :, :] = img[0, :, :] 17 | img[2, :, :] = img[0, :, :] 18 | elif idx == 1: 19 | img[0, :, :] = img[1, :, :] 20 | img[2, :, :] = img[1, :, :] 21 | elif idx == 2: 22 | img[0, :, :] = img[2, :, :] 23 | img[1, :, :] = img[2, :, :] 24 | else: 25 | img = img 26 | 27 | return img 28 | 29 | 30 | class ChannelAdapGray(object): 31 | 32 | def __init__(self, probability=0.5): 33 | self.probability = probability 34 | 35 | def __call__(self, img): 36 | 37 | idx = random.randint(0, 3) 38 | 39 | if idx == 0: 40 | img[1, :, :] = img[0, :, :] 41 | img[2, :, :] = img[0, :, :] 42 | elif idx == 1: 43 | img[0, :, :] = img[1, :, :] 44 | img[2, :, :] = img[1, :, :] 45 | elif idx == 2: 46 | img[0, :, :] = img[2, :, :] 47 | img[1, :, :] = img[2, :, :] 48 | else: 49 | if random.uniform(0, 1) > self.probability: 50 | img = img 51 | else: 52 | tmp_img = 0.2989 * img[0, :, :] + 0.5870 * img[1, :, :] + 0.1140 * img[2, :, :] 53 | img[0, :, :] = tmp_img 54 | img[1, :, :] = tmp_img 55 | img[2, :, :] = tmp_img 56 | return img 57 | 58 | 59 | class ChannelRandomErasing(object): 60 | 61 | def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=[0.4914, 0.4822, 0.4465]): 62 | 63 | self.probability = probability 64 | self.mean = mean 65 | self.sl = sl 66 | self.sh = sh 67 | self.r1 = r1 68 | 69 | def __call__(self, img): 70 | 71 | if random.uniform(0, 1) > self.probability: 72 | return img 73 | 74 | for attempt in range(100): 75 | area = img.size()[1] * img.size()[2] 76 | 77 | target_area = random.uniform(self.sl, self.sh) * area 78 | aspect_ratio = random.uniform(self.r1, 1 / self.r1) 79 | 80 | h = int(round(math.sqrt(target_area * aspect_ratio))) 81 | w = int(round(math.sqrt(target_area / aspect_ratio))) 82 | 83 | if w < img.size()[2] and h < img.size()[1]: 84 | x1 = random.randint(0, img.size()[1] - h) 85 | y1 = random.randint(0, img.size()[2] - w) 86 | if img.size()[0] == 3: 87 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0] 88 | img[1, x1:x1 + h, y1:y1 + w] = self.mean[1] 89 | img[2, x1:x1 + h, y1:y1 + w] = self.mean[2] 90 | else: 91 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0] 92 | return img 93 | 94 | return img 95 | 96 | class ChannelExchange(object): 97 | 98 | def __init__(self, 
gray=2): 99 | self.gray = gray 100 | 101 | def __call__(self, img): 102 | 103 | idx = random.randint(0, self.gray) 104 | 105 | if idx == 0: 106 | img[1, :, :] = img[0, :, :] 107 | img[2, :, :] = img[0, :, :] 108 | elif idx == 1: 109 | img[0, :, :] = img[1, :, :] 110 | img[2, :, :] = img[1, :, :] 111 | elif idx == 2: 112 | img[0, :, :] = img[2, :, :] 113 | img[1, :, :] = img[2, :, :] 114 | else: 115 | tmp_img = 0.2989 * img[0, :, :] + 0.5870 * img[1, :, :] + 0.1140 * img[2, :, :] 116 | img[0, :, :] = tmp_img 117 | img[1, :, :] = tmp_img 118 | img[2, :, :] = tmp_img 119 | return img -------------------------------------------------------------------------------- /data_loader/sampler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch.utils.data.sampler import Sampler 3 | 4 | def GenIdx(train_color_label, train_thermal_label): 5 | color_pos = [] 6 | unique_label_color = np.unique(train_color_label) 7 | for i in range(len(unique_label_color)): 8 | tmp_pos = [k for k, v in enumerate(train_color_label) if v == unique_label_color[i]] 9 | color_pos.append(tmp_pos) 10 | 11 | thermal_pos = [] 12 | unique_label_thermal = np.unique(train_thermal_label) 13 | for i in range(len(unique_label_thermal)): 14 | tmp_pos = [k for k, v in enumerate(train_thermal_label) if v == unique_label_thermal[i]] 15 | thermal_pos.append(tmp_pos) 16 | 17 | return color_pos, thermal_pos 18 | 19 | class IdentitySampler(Sampler): 20 | """Sample person identities evenly in each batch. 21 | Args: 22 | train_color_label, train_thermal_label: labels of two modalities 23 | color_pos, thermal_pos: positions of each identity 24 | batchSize: batch size 25 | """ 26 | 27 | def __init__(self, train_color_label, train_thermal_label, color_pos, thermal_pos, num_pos, batchSize): 28 | uni_label = np.unique(train_color_label) 29 | self.n_classes = len(uni_label) 30 | 31 | N = np.maximum(len(train_color_label), len(train_thermal_label)) 32 | for j in range(int(N / (batchSize * num_pos)) + 1): 33 | batch_idx = np.random.choice(uni_label, batchSize, replace=False) 34 | for i in range(batchSize): 35 | sample_color = np.random.choice(color_pos[batch_idx[i]], num_pos) 36 | sample_thermal = np.random.choice(thermal_pos[batch_idx[i]], num_pos) 37 | 38 | if j == 0 and i == 0: 39 | index1 = sample_color 40 | index2 = sample_thermal 41 | else: 42 | index1 = np.hstack((index1, sample_color)) 43 | index2 = np.hstack((index2, sample_thermal)) 44 | 45 | self.index1 = index1 46 | self.index2 = index2 47 | self.N = N 48 | 49 | def __iter__(self): 50 | return iter(np.arange(len(self.index1))) 51 | 52 | def __len__(self): 53 | return self.N 54 | 55 | ''' 56 | 57 | class IdentitySampler(Sampler): 58 | 59 | def __init__(self, train_color_label, train_thermal_label, color_pos, thermal_pos, batchSize, per_img): 60 | uni_label = np.unique(train_color_label) 61 | self.n_classes = len(uni_label) 62 | 63 | sample_color = np.arange(batchSize) 64 | sample_thermal = np.arange(batchSize) 65 | N = np.maximum(len(train_color_label), len(train_thermal_label)) 66 | 67 | # per_img = 4 68 | per_id = batchSize / per_img 69 | for j in range(N // batchSize + 1): 70 | batch_idx = np.random.choice(uni_label, int(per_id), replace=False) 71 | 72 | for s, i in enumerate(range(0, batchSize, per_img)): 73 | sample_color[i:i + per_img] = np.random.choice(color_pos[batch_idx[s]], per_img, replace=False) 74 | sample_thermal[i:i + per_img] = np.random.choice(thermal_pos[batch_idx[s]], per_img, replace=False) 75 | 76 | if j 
== 0: 77 | index1 = sample_color 78 | index2 = sample_thermal 79 | else: 80 | index1 = np.hstack((index1, sample_color)) 81 | index2 = np.hstack((index2, sample_thermal)) 82 | 83 | self.index1 = index1 84 | self.index2 = index2 85 | self.N = N 86 | 87 | def __iter__(self): 88 | return iter(np.arange(len(self.index1))) 89 | 90 | def __len__(self): 91 | return self.N 92 | ''' -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import ast 4 | import torch 5 | import random 6 | import argparse 7 | import numpy as np 8 | 9 | 10 | from data_loader.loader import Loader 11 | from core import Base, train, train_stage1, train_stage2, test 12 | from tools import make_dirs, Logger, os_walk, time_now 13 | import warnings 14 | warnings.filterwarnings("ignore") 15 | 16 | best_mAP = 0 17 | best_rank1 = 0 18 | def seed_torch(seed): 19 | seed = int(seed) 20 | random.seed(seed) 21 | os.environ['PYTHONHASHSEED'] = str(seed) 22 | np.random.seed(seed) 23 | torch.manual_seed(seed) 24 | torch.cuda.manual_seed(seed) 25 | torch.cuda.manual_seed_all(seed) 26 | torch.backends.cudnn.deterministic = True 27 | torch.backends.cudnn.benchmark = False 28 | 29 | 30 | def main(config): 31 | global best_mAP 32 | global best_rank1 33 | 34 | loaders = Loader(config) 35 | model = Base(config) 36 | 37 | make_dirs(model.output_path) 38 | make_dirs(model.save_model_path) 39 | make_dirs(model.save_logs_path) 40 | 41 | logger = Logger(os.path.join(os.path.join(config.output_path, 'logs/'), 'log.txt')) 42 | logger('\n' * 3) 43 | logger(config) 44 | 45 | if config.mode == 'train': 46 | if config.resume_train_epoch >= 0: 47 | model.resume_model(config.resume_train_epoch) 48 | start_train_epoch = config.resume_train_epoch 49 | else: 50 | start_train_epoch = 0 51 | 52 | if config.auto_resume_training_from_lastest_step: 53 | root, _, files = os_walk(model.save_model_path) 54 | if len(files) > 0: 55 | indexes = [] 56 | for file in files: 57 | indexes.append(int(file.replace('.pth', '').split('_')[-1])) 58 | indexes = sorted(list(set(indexes)), reverse=False) 59 | model.resume_model(indexes[-1]) 60 | start_train_epoch = indexes[-1] 61 | logger('Time: {}, automatically resume training from the latest step (model {})'.format(time_now(), 62 | indexes[-1])) 63 | 64 | print('Start the 1st Stage of Training') 65 | print('Extracting Image Features') 66 | 67 | visible_image_features = [] 68 | visible_labels = [] 69 | infrared_image_features = [] 70 | infrared_labels = [] 71 | 72 | with torch.no_grad(): 73 | for i, data in enumerate(loaders.get_train_normal_loader()): 74 | rgb_imgs, rgb_pids = data[0].to(model.device), data[2].to(model.device) 75 | ir_imgs, ir_pids = data[1].to(model.device), data[3].to(model.device) 76 | rgb_image_features_proj = model.model(x1=rgb_imgs, get_image=True) 77 | ir_image_features_proj = model.model(x2=ir_imgs, get_image=True) 78 | for i, j, img_feat1, img_feat2 in zip(rgb_pids, ir_pids, rgb_image_features_proj, ir_image_features_proj): 79 | visible_labels.append(i) 80 | visible_image_features.append(img_feat1.cpu()) 81 | infrared_labels.append(j) 82 | infrared_image_features.append(img_feat2.cpu()) 83 | visible_labels_list = torch.stack(visible_labels, dim=0).cuda() 84 | infrared_labels_list = torch.stack(infrared_labels, dim=0).cuda() 85 | visible_image_features_list = torch.stack(visible_image_features, dim=0).cuda() 86 | infrared_image_features_list = 
torch.stack(infrared_image_features, dim=0).cuda() 87 | batch = config.stage1_batch_size 88 | num_image = infrared_labels_list.shape[0] 89 | i_ter = num_image // batch 90 | del visible_labels, visible_image_features, infrared_labels, infrared_image_features 91 | print('Visible and Infrared Image Features Extracted, Start Training') 92 | 93 | model._init_optimizer_stage1() 94 | 95 | for current_epoch in range(start_train_epoch, config.stage1_train_epochs): 96 | model.model_lr_scheduler_stage1.step(current_epoch) 97 | _, result = train_stage1(model, num_image, i_ter, batch, visible_labels_list, 98 | visible_image_features_list, infrared_labels_list, infrared_image_features_list) 99 | logger('Time: {}; Epoch: {}; LR: {}; {}'.format(time_now(), current_epoch, 100 | model.model_lr_scheduler_stage1._get_lr 101 | (current_epoch)[0], result)) 102 | 103 | print('The 1st Stage of Training Finished') 104 | 105 | print('Start the 2nd Stage of Training') 106 | print('Extracting Image Features') 107 | 108 | image_features = [] 109 | labels = [] 110 | 111 | with torch.no_grad(): 112 | for i, data in enumerate(loaders.get_train_normal_loader()): 113 | rgb_imgs, rgb_pids = data[0].to(model.device), data[2].to(model.device) 114 | ir_imgs, ir_pids = data[1].to(model.device), data[3].to(model.device) 115 | rgb_image_features_proj = model.model(x1=rgb_imgs, get_image=True) 116 | ir_image_features_proj = model.model(x2=ir_imgs, get_image=True) 117 | pids = torch.cat([rgb_pids, ir_pids], dim=0) 118 | image_features_proj = torch.cat([rgb_image_features_proj, ir_image_features_proj], dim=0) 119 | for i, img_feat in zip(pids, image_features_proj): 120 | labels.append(i) 121 | image_features.append(img_feat.cpu()) 122 | labels_list = torch.stack(labels, dim=0).cuda() 123 | image_features_list = torch.stack(image_features, dim=0).cuda() 124 | batch = config.batch_size * 2 125 | num_image = labels_list.shape[0] 126 | i_ter = num_image // batch 127 | del labels, image_features 128 | print('Image Features Extracted, Start Training') 129 | 130 | model._init_optimizer_stage2() 131 | 132 | for current_epoch in range(start_train_epoch, config.stage1_train_epochs): 133 | model.model_lr_scheduler_stage2.step(current_epoch) 134 | _, result = train_stage2(model, num_image, i_ter, batch, labels_list, 135 | image_features_list, ) 136 | logger('Time: {}; Epoch: {}; LR: {}; {}'.format(time_now(), current_epoch, 137 | model.model_lr_scheduler_stage2._get_lr 138 | (current_epoch)[0], result)) 139 | 140 | print('The 2nd Stage of Training Finished') 141 | 142 | print('Start the 3rd Stage of Training') 143 | print('Extracting Text Features') 144 | 145 | num_classes = model.model.module.num_classes 146 | batch = config.batch_size 147 | i_ter = num_classes // batch 148 | left = num_classes - batch * (num_classes // batch) 149 | if left != 0: 150 | i_ter = i_ter + 1 151 | text_features = [] 152 | with torch.no_grad(): 153 | for i in range(i_ter): 154 | if i + 1 != i_ter: 155 | l_list = torch.arange(i * batch, (i + 1) * batch) 156 | else: 157 | l_list = torch.arange(i * batch, num_classes) 158 | text_feature = model.model(label=l_list, get_fusion_text=True) 159 | text_features.append(text_feature.cpu()) 160 | text_features = torch.cat(text_features, 0).cuda() 161 | print('Text Features Extracted, Start Training') 162 | 163 | model._init_optimizer_stage3() 164 | 165 | for current_epoch in range(start_train_epoch, config.total_train_epoch): 166 | model.model_lr_scheduler_stage3.step(current_epoch) 167 | 168 | _, result = train(model, loaders, text_features, config, current_epoch) 169
| logger('Time: {}; Epoch: {}; LR, {}; {}'.format(time_now(), current_epoch, 170 | model.model_lr_scheduler_stage3.get_lr()[0], result)) 171 | 172 | if current_epoch + 1 >= 1 and (current_epoch + 1) % config.eval_epoch == 0: 173 | cmc, mAP, mINP = test(model, loaders, config) 174 | is_best_rank = (cmc[0] >= best_rank1) 175 | best_rank1 = max(cmc[0], best_rank1) 176 | model.save_model(current_epoch, is_best_rank) 177 | logger('Time: {}; Test on Dataset: {}, \nmINP: {} \nmAP: {} \n Rank: {}'.format(time_now(), 178 | config.dataset, 179 | mINP, mAP, cmc)) 180 | 181 | elif config.mode == 'test': 182 | model.resume_model(config.resume_test_model) 183 | cmc, mAP, mINP = test(model, loaders, config) 184 | logger('Time: {}; Test on Dataset: {}, \nmINP: {} \nmAP: {} \n Rank: {}'.format(time_now(), 185 | config.dataset, 186 | mINP, mAP, cmc)) 187 | 188 | if __name__ == '__main__': 189 | 190 | parser = argparse.ArgumentParser() 191 | parser.add_argument('--cuda', type=str, default='cuda') 192 | parser.add_argument('--mode', type=str, default='train', help='train, test') 193 | parser.add_argument('--test_mode', default='all', type=str, help='all or indoor') 194 | parser.add_argument('--gall_mode', default='single', type=str, help='single or multi') 195 | parser.add_argument('--regdb_test_mode', default='v-t', type=str, help='') 196 | parser.add_argument('--dataset', default='sysu', help='dataset name: regdb or sysu]') 197 | parser.add_argument('--sysu_data_path', type=str, default='/ssd/s01015/data/SYSU-MM01/') 198 | parser.add_argument('--regdb_data_path', type=str, default='/ssd/s01015/data/RegDB/') 199 | parser.add_argument('--trial', default=1, type=int, help='trial (only for RegDB dataset)') 200 | parser.add_argument('--batch-size', default=32, type=int, metavar='B', help='training batch size') 201 | parser.add_argument('--img_w', default=144, type=int, metavar='imgw', help='img width') 202 | parser.add_argument('--img_h', default=288, type=int, metavar='imgh', help='img height') 203 | parser.add_argument('--seed', type=int, default=1) 204 | parser.add_argument('--pid_num', type=int, default=395) 205 | parser.add_argument('--learning_rate', type=float, default=0.0003) 206 | parser.add_argument('--weight_decay', type=float, default=0.0005) 207 | parser.add_argument('--milestones', nargs='+', type=int, default=[40, 70], 208 | help='milestones for the learning rate decay') 209 | 210 | parser.add_argument('--stage1_batch-size', default=32, type=int, metavar='B', help='training batch size') 211 | parser.add_argument('--stage1_learning_rate', type=float, default=0.0003) 212 | parser.add_argument('--stage2_learning_rate', type=float, default=0.0003) 213 | parser.add_argument('--stage1_weight_decay', type=float, default=1e-4) 214 | parser.add_argument('--stage1_lr_min', type=float, default=1e-6) 215 | parser.add_argument('--stage1_warmup_lr_init', type=float, default=0.00001) 216 | parser.add_argument('--stage1_warmup_epochs', type=int, default=5) 217 | parser.add_argument('--stage1_train_epochs', type=int, default=60) 218 | 219 | parser.add_argument('--lambda1', type=float, default=0.15) 220 | parser.add_argument('--lambda2', type=float, default=0.05) 221 | parser.add_argument('--lambda3', type=float, default=0.1) 222 | 223 | parser.add_argument('--loss', default=1, type=int, 224 | help='num of pos per identity in each modality') 225 | 226 | parser.add_argument('--num_pos', default=4, type=int, 227 | help='num of pos per identity in each modality') 228 | parser.add_argument('--num_workers', default=8, 
type=int, 229 | help='num of pos per identity in each modality') 230 | parser.add_argument('--output_path', type=str, default='models/base/', 231 | help='path to save related informations') 232 | parser.add_argument('--max_save_model_num', type=int, default=1, help='0 for max num is infinit') 233 | parser.add_argument('--resume_train_epoch', type=int, default=-1, help='-1 for no resuming') 234 | parser.add_argument('--auto_resume_training_from_lastest_step', type=ast.literal_eval, default=True) 235 | parser.add_argument('--total_train_epoch', type=int, default=120) 236 | parser.add_argument('--eval_epoch', type=int, default=1) 237 | parser.add_argument('--resume_test_model', type=int, default=119, help='-1 for no resuming') 238 | 239 | config = parser.parse_args() 240 | seed_torch(config.seed) 241 | main(config) 242 | -------------------------------------------------------------------------------- /network/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import * -------------------------------------------------------------------------------- /network/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nengdong96/CSDN/2224959c5df9af6a58cd34a6f9ef6f6ab8177966/network/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /network/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nengdong96/CSDN/2224959c5df9af6a58cd34a6f9ef6f6ab8177966/network/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /network/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nengdong96/CSDN/2224959c5df9af6a58cd34a6f9ef6f6ab8177966/network/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /network/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nengdong96/CSDN/2224959c5df9af6a58cd34a6f9ef6f6ab8177966/network/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /network/__pycache__/gem_pool.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nengdong96/CSDN/2224959c5df9af6a58cd34a6f9ef6f6ab8177966/network/__pycache__/gem_pool.cpython-38.pyc -------------------------------------------------------------------------------- /network/__pycache__/lr.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nengdong96/CSDN/2224959c5df9af6a58cd34a6f9ef6f6ab8177966/network/__pycache__/lr.cpython-38.pyc -------------------------------------------------------------------------------- /network/__pycache__/model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nengdong96/CSDN/2224959c5df9af6a58cd34a6f9ef6f6ab8177966/network/__pycache__/model.cpython-36.pyc -------------------------------------------------------------------------------- /network/__pycache__/model.cpython-37.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/nengdong96/CSDN/2224959c5df9af6a58cd34a6f9ef6f6ab8177966/network/__pycache__/model.cpython-37.pyc -------------------------------------------------------------------------------- /network/__pycache__/model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nengdong96/CSDN/2224959c5df9af6a58cd34a6f9ef6f6ab8177966/network/__pycache__/model.cpython-38.pyc -------------------------------------------------------------------------------- /network/clip/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .clip import * -------------------------------------------------------------------------------- /network/clip/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nengdong96/CSDN/2224959c5df9af6a58cd34a6f9ef6f6ab8177966/network/clip/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /network/clip/__pycache__/clip.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nengdong96/CSDN/2224959c5df9af6a58cd34a6f9ef6f6ab8177966/network/clip/__pycache__/clip.cpython-38.pyc -------------------------------------------------------------------------------- /network/clip/__pycache__/model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nengdong96/CSDN/2224959c5df9af6a58cd34a6f9ef6f6ab8177966/network/clip/__pycache__/model.cpython-38.pyc -------------------------------------------------------------------------------- /network/clip/__pycache__/simple_tokenizer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nengdong96/CSDN/2224959c5df9af6a58cd34a6f9ef6f6ab8177966/network/clip/__pycache__/simple_tokenizer.cpython-38.pyc -------------------------------------------------------------------------------- /network/clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nengdong96/CSDN/2224959c5df9af6a58cd34a6f9ef6f6ab8177966/network/clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /network/clip/clip.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import urllib 4 | import warnings 5 | from typing import Union, List 6 | 7 | import torch 8 | from PIL import Image 9 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 10 | from tqdm import tqdm 11 | 12 | from .model import build_model 13 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer 14 | 15 | try: 16 | from torchvision.transforms import InterpolationMode 17 | BICUBIC = InterpolationMode.BICUBIC 18 | except ImportError: 19 | BICUBIC = Image.BICUBIC 20 | 21 | 22 | if torch.__version__.split(".") < ["1", "7", "1"]: 23 | warnings.warn("PyTorch version 1.7.1 or higher is recommended") 24 | 25 | 26 | __all__ = ["available_models", "load", "tokenize"] 27 | _tokenizer = _Tokenizer() 28 | 29 | _MODELS = { 30 | "RN50": 
"https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 31 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 32 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 33 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 34 | "ViT-B-32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 35 | "ViT-B-16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 36 | } 37 | 38 | 39 | def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")): 40 | os.makedirs(root, exist_ok=True) 41 | filename = os.path.basename(url) 42 | 43 | expected_sha256 = url.split("/")[-2] 44 | download_target = os.path.join(root, filename) 45 | 46 | if os.path.exists(download_target) and not os.path.isfile(download_target): 47 | raise RuntimeError(f"{download_target} exists and is not a regular file") 48 | 49 | if os.path.isfile(download_target): 50 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: 51 | return download_target 52 | else: 53 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 54 | 55 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 56 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop: 57 | while True: 58 | buffer = source.read(8192) 59 | if not buffer: 60 | break 61 | 62 | output.write(buffer) 63 | loop.update(len(buffer)) 64 | 65 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: 66 | raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") 67 | 68 | return download_target 69 | 70 | 71 | def _transform(n_px): 72 | return Compose([ 73 | Resize(n_px, interpolation=BICUBIC), 74 | CenterCrop(n_px), 75 | lambda image: image.convert("RGB"), 76 | ToTensor(), 77 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 78 | ]) 79 | 80 | 81 | def available_models() -> List[str]: 82 | return list(_MODELS.keys()) 83 | 84 | 85 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=False): 86 | 87 | if name in _MODELS: 88 | model_path = _download(_MODELS[name]) 89 | elif os.path.isfile(name): 90 | model_path = name 91 | else: 92 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}") 93 | 94 | try: 95 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() 96 | state_dict = None 97 | except RuntimeError: 98 | if jit: 99 | warnings.warn(f"File {model_path} is not a JIT archive. 
Loading as a state dict instead") 100 | jit = False 101 | state_dict = torch.load(model_path, map_location="cpu") 102 | 103 | if not jit: 104 | model = build_model(state_dict or model.state_dict()).to(device) 105 | if str(device) == "cpu": 106 | model.float() 107 | return model, _transform(model.visual.input_resolution) 108 | 109 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 110 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 111 | 112 | def patch_device(module): 113 | try: 114 | graphs = [module.graph] if hasattr(module, "graph") else [] 115 | except RuntimeError: 116 | graphs = [] 117 | 118 | if hasattr(module, "forward1"): 119 | graphs.append(module.forward1.graph) 120 | 121 | for graph in graphs: 122 | for node in graph.findAllNodes("prim::Constant"): 123 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 124 | node.copyAttributes(device_node) 125 | 126 | model.apply(patch_device) 127 | patch_device(model.encode_image) 128 | patch_device(model.encode_text) 129 | 130 | # patch dtype to float32 on CPU 131 | if str(device) == "cpu": 132 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 133 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 134 | float_node = float_input.node() 135 | 136 | def patch_float(module): 137 | try: 138 | graphs = [module.graph] if hasattr(module, "graph") else [] 139 | except RuntimeError: 140 | graphs = [] 141 | 142 | if hasattr(module, "forward1"): 143 | graphs.append(module.forward1.graph) 144 | 145 | for graph in graphs: 146 | for node in graph.findAllNodes("aten::to"): 147 | inputs = list(node.inputs()) 148 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 149 | if inputs[i].node()["value"] == 5: 150 | inputs[i].node().copyAttributes(float_node) 151 | 152 | model.apply(patch_float) 153 | patch_float(model.encode_image) 154 | patch_float(model.encode_text) 155 | 156 | model.float() 157 | 158 | return model, _transform(model.input_resolution.item()) 159 | 160 | 161 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor: 162 | # import pdb 163 | # pdb.set_trace() 164 | if isinstance(texts, str): 165 | texts = [texts] #['a photo of a face.'] 166 | 167 | sot_token = _tokenizer.encoder["<|startoftext|>"] #49406 168 | eot_token = _tokenizer.encoder["<|endoftext|>"] #49407 169 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 170 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) #1,77 171 | 172 | for i, tokens in enumerate(all_tokens): 173 | if len(tokens) > context_length: #context_length 77 174 | if truncate: 175 | tokens = tokens[:context_length] 176 | tokens[-1] = eot_token 177 | else: 178 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") 179 | result[i, :len(tokens)] = torch.tensor(tokens) 180 | 181 | return result 182 | -------------------------------------------------------------------------------- /network/clip/model.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import Tuple, Union 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import nn 8 | 9 | 10 | class Bottleneck(nn.Module): 11 | expansion = 4 12 | 13 | def __init__(self, inplanes, 
planes, stride=1): 14 | super().__init__() 15 | 16 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 17 | self.bn1 = nn.BatchNorm2d(planes) 18 | 19 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) 20 | self.bn2 = nn.BatchNorm2d(planes) 21 | 22 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() 23 | 24 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 25 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 26 | 27 | self.relu = nn.ReLU(inplace=True) 28 | self.downsample = None 29 | self.stride = stride 30 | 31 | if stride > 1 or inplanes != planes * Bottleneck.expansion: 32 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 33 | self.downsample = nn.Sequential(OrderedDict([ 34 | ("-1", nn.AvgPool2d(stride)), 35 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), 36 | ("1", nn.BatchNorm2d(planes * self.expansion)) 37 | ])) 38 | 39 | def forward(self, x: torch.Tensor): 40 | identity = x 41 | 42 | out = self.relu(self.bn1(self.conv1(x))) 43 | out = self.relu(self.bn2(self.conv2(out))) 44 | out = self.avgpool(out) 45 | out = self.bn3(self.conv3(out)) 46 | 47 | if self.downsample is not None: 48 | identity = self.downsample(x) 49 | 50 | out += identity 51 | out = self.relu(out) 52 | return out 53 | 54 | 55 | class AttentionPool2d(nn.Module): 56 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): 57 | super().__init__() 58 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / embed_dim ** 0.5) 59 | self.k_proj = nn.Linear(embed_dim, embed_dim) 60 | self.q_proj = nn.Linear(embed_dim, embed_dim) 61 | self.v_proj = nn.Linear(embed_dim, embed_dim) 62 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 63 | self.num_heads = num_heads 64 | 65 | def forward(self, x): 66 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 67 | 1) # NCHW -> (HW)NC #32,2048,7,7 ->49, 32, 2048 68 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 50,32,2048 69 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 70 | x, _ = F.multi_head_attention_forward( 71 | query=x, key=x, value=x, 72 | embed_dim_to_check=x.shape[-1], 73 | num_heads=self.num_heads, 74 | q_proj_weight=self.q_proj.weight, 75 | k_proj_weight=self.k_proj.weight, 76 | v_proj_weight=self.v_proj.weight, 77 | in_proj_weight=None, 78 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 79 | bias_k=None, 80 | bias_v=None, 81 | add_zero_attn=False, 82 | dropout_p=0, 83 | out_proj_weight=self.c_proj.weight, 84 | out_proj_bias=self.c_proj.bias, 85 | use_separate_proj_weight=True, 86 | training=self.training, 87 | need_weights=False 88 | ) 89 | 90 | return x 91 | 92 | 93 | class ModifiedResNet(nn.Module): 94 | 95 | def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): 96 | super().__init__() 97 | self.output_dim = output_dim 98 | self.input_resolution = input_resolution 99 | 100 | # the 3-layer stem 101 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) 102 | self.bn1 = nn.BatchNorm2d(width // 2) 103 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) 104 | self.bn2 = nn.BatchNorm2d(width // 2) 105 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) 106 | self.bn3 = nn.BatchNorm2d(width) 107 | self.avgpool = nn.AvgPool2d(2) 108 | 
self.relu = nn.ReLU(inplace=True) 109 | 110 | self._inplanes = width # this is a *mutable* variable used during construction 111 | self.layer1 = self._make_layer(width, layers[0]) 112 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2) 113 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2) 114 | self.layer4 = self._make_layer(width * 8, layers[3], stride=1) 115 | embed_dim = width * 32 # the ResNet feature dimension 116 | self.attnpool = AttentionPool2d(input_resolution, embed_dim, heads, output_dim) 117 | 118 | def _make_layer(self, planes, blocks, stride=1): 119 | layers = [Bottleneck(self._inplanes, planes, stride)] 120 | 121 | self._inplanes = planes * Bottleneck.expansion 122 | for _ in range(1, blocks): 123 | layers.append(Bottleneck(self._inplanes, planes)) 124 | 125 | return nn.Sequential(*layers) 126 | 127 | def forward(self, x): 128 | def stem(x): 129 | for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]: 130 | x = self.relu(bn(conv(x))) 131 | x = self.avgpool(x) 132 | return x 133 | 134 | x = x.type(self.conv1.weight.dtype) 135 | x = stem(x) 136 | x = self.layer1(x) 137 | x = self.layer2(x) 138 | x3 = self.layer3(x) 139 | x4 = self.layer4(x3) 140 | xproj = self.attnpool(x4) 141 | 142 | return x4 143 | 144 | 145 | class LayerNorm(nn.LayerNorm): 146 | """Subclass torch's LayerNorm to handle fp16.""" 147 | 148 | def forward(self, x: torch.Tensor): 149 | orig_type = x.dtype 150 | ret = super().forward(x.type(torch.float32)) 151 | return ret.type(orig_type) 152 | 153 | 154 | class QuickGELU(nn.Module): 155 | def forward(self, x: torch.Tensor): 156 | return x * torch.sigmoid(1.702 * x) 157 | 158 | 159 | class ResidualAttentionBlock(nn.Module): 160 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): 161 | super().__init__() 162 | 163 | self.attn = nn.MultiheadAttention(d_model, n_head) 164 | self.ln_1 = LayerNorm(d_model) 165 | self.mlp = nn.Sequential(OrderedDict([ 166 | ("c_fc", nn.Linear(d_model, d_model * 4)), 167 | ("gelu", QuickGELU()), 168 | ("c_proj", nn.Linear(d_model * 4, d_model)) 169 | ])) 170 | self.ln_2 = LayerNorm(d_model) 171 | self.attn_mask = attn_mask 172 | 173 | def attention(self, x: torch.Tensor): 174 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None 175 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] 176 | 177 | def forward(self, x: torch.Tensor): 178 | x = x + self.attention(self.ln_1(x)) 179 | x = x + self.mlp(self.ln_2(x)) 180 | return x 181 | 182 | 183 | class Transformer(nn.Module): 184 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): 185 | super().__init__() 186 | self.width = width 187 | self.layers = layers 188 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) 189 | 190 | def forward(self, x: torch.Tensor): 191 | return self.resblocks(x) 192 | 193 | 194 | class VisionTransformer(nn.Module): 195 | def __init__(self, h_resolution: int, w_resolution: int, patch_size: int, stride_size: int, width: int, layers: int, 196 | heads: int, output_dim: int): 197 | super().__init__() 198 | self.h_resolution = h_resolution 199 | self.w_resolution = w_resolution 200 | self.output_dim = output_dim 201 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=stride_size, 202 | bias=False) 203 | 204 | scale = width ** -0.5 205 | self.class_embedding = 
nn.Parameter(scale * torch.randn(width)) 206 | self.positional_embedding = nn.Parameter(scale * torch.randn(h_resolution * w_resolution + 1, width)) 207 | self.ln_pre = LayerNorm(width) 208 | 209 | self.transformer = Transformer(width, layers, heads) 210 | 211 | self.ln_post = LayerNorm(width) 212 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) 213 | 214 | def forward(self, x: torch.Tensor, cv_emb=None): 215 | x = self.conv1(x) # shape = [*, width, grid, grid] 216 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 217 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 218 | x = torch.cat( 219 | [self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), 220 | x], dim=1) # shape = [*, grid ** 2 + 1, width] 221 | if cv_emb != None: 222 | x[:, 0] = x[:, 0] + cv_emb 223 | x = x + self.positional_embedding.to(x.dtype) 224 | x = self.ln_pre(x) 225 | 226 | x = x.permute(1, 0, 2) # NLD -> LND 227 | 228 | x11 = self.transformer.resblocks[:11](x) 229 | x12 = self.transformer.resblocks[11](x11) 230 | x11 = x11.permute(1, 0, 2) # LND -> NLD 231 | x12 = x12.permute(1, 0, 2) # LND -> NLD 232 | 233 | x12 = self.ln_post(x12) 234 | 235 | if self.proj is not None: 236 | xproj = x12 @ self.proj 237 | 238 | return x11, x12, xproj 239 | 240 | 241 | class CLIP(nn.Module): 242 | def __init__(self, 243 | embed_dim: int, 244 | # vision 245 | image_resolution: int, 246 | vision_layers: Union[Tuple[int, int, int, int], int], 247 | vision_width: int, 248 | vision_patch_size: int, 249 | vision_stride_size: int, 250 | # text 251 | context_length: int, 252 | vocab_size: int, 253 | transformer_width: int, 254 | transformer_heads: int, 255 | transformer_layers: int, 256 | h_resolution: int, 257 | w_resolution: int 258 | ): 259 | super().__init__() 260 | 261 | self.context_length = context_length 262 | 263 | if isinstance(vision_layers, (tuple, list)): 264 | vision_heads = vision_width * 32 // 64 265 | self.visual = ModifiedResNet( 266 | layers=vision_layers, 267 | output_dim=embed_dim, 268 | heads=vision_heads, 269 | input_resolution=h_resolution * w_resolution, 270 | width=vision_width 271 | ) 272 | else: 273 | vision_heads = vision_width // 64 274 | self.visual = VisionTransformer( 275 | h_resolution=h_resolution, 276 | w_resolution=w_resolution, 277 | patch_size=vision_patch_size, 278 | stride_size=vision_stride_size, 279 | width=vision_width, 280 | layers=vision_layers, 281 | heads=vision_heads, 282 | output_dim=embed_dim 283 | ) 284 | 285 | self.transformer = Transformer( 286 | width=transformer_width, 287 | layers=transformer_layers, 288 | heads=transformer_heads, 289 | attn_mask=self.build_attention_mask() 290 | ) 291 | 292 | self.vocab_size = vocab_size 293 | self.token_embedding = nn.Embedding(vocab_size, transformer_width) 294 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) 295 | self.ln_final = LayerNorm(transformer_width) 296 | 297 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) 298 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 299 | 300 | self.initialize_parameters() 301 | 302 | def initialize_parameters(self): 303 | nn.init.normal_(self.token_embedding.weight, std=0.02) 304 | nn.init.normal_(self.positional_embedding, std=0.01) 305 | 306 | if isinstance(self.visual, ModifiedResNet): 307 | if self.visual.attnpool is not None: 308 | std = self.visual.attnpool.c_proj.in_features ** -0.5 309 | 
nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) 310 | nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) 311 | nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) 312 | nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) 313 | 314 | for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]: 315 | for name, param in resnet_block.named_parameters(): 316 | if name.endswith("bn3.weight"): 317 | nn.init.zeros_(param) 318 | 319 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) 320 | attn_std = self.transformer.width ** -0.5 321 | fc_std = (2 * self.transformer.width) ** -0.5 322 | for block in self.transformer.resblocks: 323 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std) 324 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std) 325 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) 326 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) 327 | 328 | if self.text_projection is not None: 329 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) 330 | 331 | def build_attention_mask(self): 332 | # lazily create causal attention mask, with full attention between the vision tokens 333 | # pytorch uses additive attention mask; fill with -inf 334 | mask = torch.empty(self.context_length, self.context_length) 335 | mask.fill_(float("-inf")) 336 | mask.triu_(1) # zero out the lower diagonal 337 | return mask 338 | 339 | @property 340 | def dtype(self): 341 | return self.visual.conv1.weight.dtype 342 | 343 | def encode_image(self, image): 344 | return self.visual(image.type(self.dtype)) 345 | 346 | def encode_text(self, text): 347 | x = self.token_embedding(text).type(self.dtype) 348 | 349 | x = x + self.positional_embedding.type(self.dtype) 350 | x = x.permute(1, 0, 2) 351 | x = self.transformer(x) 352 | x = x.permute(1, 0, 2) 353 | x = self.ln_final(x).type(self.dtype) 354 | 355 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection 356 | 357 | return x 358 | 359 | def forward(self, image, text): 360 | image_features = self.encode_image(image) 361 | text_features = self.encode_text(text) 362 | 363 | # normalized features 364 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 365 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 366 | 367 | # cosine similarity as logits 368 | logit_scale = self.logit_scale.exp() 369 | logits_per_image = logit_scale * image_features @ text_features.t() 370 | logits_per_text = logit_scale * text_features @ image_features.t() 371 | 372 | # shape = [global_batch_size, global_batch_size] 373 | return logits_per_image, logits_per_text 374 | 375 | 376 | def convert_weights(model: nn.Module): 377 | """Convert applicable model parameters to fp16""" 378 | 379 | def _convert_weights_to_fp16(l): 380 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): 381 | l.weight.data = l.weight.data.float() 382 | if l.bias is not None: 383 | l.bias.data = l.bias.data.float() 384 | 385 | if isinstance(l, nn.MultiheadAttention): 386 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: 387 | tensor = getattr(l, attr) 388 | if tensor is not None: 389 | tensor.data = tensor.data.float() 390 | 391 | for name in ["text_projection", "proj"]: 392 | if hasattr(l, name): 393 | attr = getattr(l, name) 394 | if attr is not None: 395 | attr.data = attr.data.float() 396 | 397 | model.apply(_convert_weights_to_fp16) 398 
| 399 | 400 | def build_model(state_dict: dict, h_resolution: int, w_resolution: int, vision_stride_size: int): 401 | vit = "visual.proj" in state_dict 402 | 403 | if vit: 404 | vision_width = state_dict["visual.conv1.weight"].shape[0] 405 | vision_layers = len( 406 | [k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) 407 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] 408 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) 409 | image_resolution = vision_patch_size * grid_size 410 | else: # RN50 411 | counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in 412 | [1, 2, 3, 4]] 413 | vision_layers = tuple(counts) 414 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] 415 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) 416 | vision_patch_size = None 417 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] 418 | image_resolution = output_width * 32 419 | 420 | embed_dim = state_dict["text_projection"].shape[1] 421 | context_length = state_dict["positional_embedding"].shape[0] # 77 (77,512) 422 | vocab_size = state_dict["token_embedding.weight"].shape[0] 423 | transformer_width = state_dict["ln_final.weight"].shape[0] 424 | transformer_heads = transformer_width // 64 425 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) 426 | 427 | model = CLIP( 428 | embed_dim, 429 | image_resolution, vision_layers, vision_width, vision_patch_size, vision_stride_size, 430 | context_length, vocab_size, transformer_width, transformer_heads, transformer_layers, 431 | h_resolution, w_resolution 432 | ) 433 | if vit: 434 | state_dict["visual.positional_embedding"] = resize_pos_embed(state_dict["visual.positional_embedding"], 435 | model.visual.positional_embedding, h_resolution, 436 | w_resolution) 437 | else: # RN50 438 | state_dict["visual.attnpool.positional_embedding"] = resize_pos_embed( 439 | state_dict["visual.attnpool.positional_embedding"], model.visual.attnpool.positional_embedding, 440 | h_resolution, w_resolution) 441 | 442 | for key in ["input_resolution", "context_length", "vocab_size"]: 443 | if key in state_dict: 444 | del state_dict[key] 445 | 446 | convert_weights(model) 447 | 448 | model.load_state_dict(state_dict) 449 | return model.eval() 450 | 451 | 452 | import math 453 | 454 | 455 | def resize_pos_embed(posemb, posemb_new, hight, width): 456 | # Rescale the grid of position embeddings when loading from state_dict. 
Adapted from 457 | # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 458 | print('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape) 459 | 460 | ntok_new = posemb_new.shape[0] # 129,2048 461 | 462 | posemb_token, posemb_grid = posemb[:1], posemb[1:] 463 | ntok_new -= 1 464 | 465 | gs_old = int(math.sqrt(len(posemb_grid))) # 14 466 | print('Position embedding resize to height:{} width: {}'.format(hight, width)) 467 | posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) 468 | posemb_grid = F.interpolate(posemb_grid, size=(hight, width), mode='bilinear') 469 | posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, hight * width, -1) 470 | posemb = torch.cat([posemb_token, posemb_grid.squeeze()], dim=0) 471 | return posemb -------------------------------------------------------------------------------- /network/clip/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | 18 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 19 | cs = bs[:] 20 | n = 0 21 | for b in range(2**8): 22 | if b not in bs: 23 | bs.append(b) 24 | cs.append(2**8+n) 25 | n += 1 26 | cs = [chr(n) for n in cs] 27 | return dict(zip(bs, cs)) 28 | 29 | 30 | def get_pairs(word): 31 | 32 | pairs = set() 33 | prev_char = word[0] 34 | for char in word[1:]: 35 | pairs.add((prev_char, char)) 36 | prev_char = char 37 | return pairs 38 | 39 | 40 | def basic_clean(text): 41 | text = ftfy.fix_text(text) 42 | text = html.unescape(html.unescape(text)) 43 | return text.strip() 44 | 45 | 46 | def whitespace_clean(text): 47 | text = re.sub(r'\s+', ' ', text) 48 | text = text.strip() 49 | return text 50 | 51 | 52 | class SimpleTokenizer(object): 53 | def __init__(self, bpe_path: str = default_bpe()): 54 | self.byte_encoder = bytes_to_unicode() 55 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 56 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 57 | merges = merges[1:49152-256-2+1] 58 | merges = [tuple(merge.split()) for merge in merges] 59 | vocab = list(bytes_to_unicode().values()) 60 | vocab = vocab + [v+'</w>' for v in vocab] 61 | for merge in merges: 62 | vocab.append(''.join(merge)) 63 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 64 | self.encoder = dict(zip(vocab, range(len(vocab)))) 65 | self.decoder = {v: k for k, v in self.encoder.items()} 66 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 67 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 68 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 69 | 70 | def bpe(self, token): 71 | if token in self.cache: 72 | return self.cache[token] 73 | word = tuple(token[:-1]) + ( token[-1] + '</w>',) 74 | pairs = get_pairs(word) 75 | 76 | if not pairs: 77 | return token+'</w>' 78 | 79 | while True: 80 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 81 | if bigram not in self.bpe_ranks: 82 | break 83 | first, second = bigram 84 | new_word = []
85 | i = 0 86 | while i < len(word): 87 | try: 88 | j = word.index(first, i) 89 | new_word.extend(word[i:j]) 90 | i = j 91 | except: 92 | new_word.extend(word[i:]) 93 | break 94 | 95 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 96 | new_word.append(first+second) 97 | i += 2 98 | else: 99 | new_word.append(word[i]) 100 | i += 1 101 | new_word = tuple(new_word) 102 | word = new_word 103 | if len(word) == 1: 104 | break 105 | else: 106 | pairs = get_pairs(word) 107 | word = ' '.join(word) 108 | self.cache[token] = word 109 | return word 110 | 111 | def encode(self, text): 112 | bpe_tokens = [] 113 | text = whitespace_clean(basic_clean(text)).lower() 114 | for token in re.findall(self.pat, text): 115 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 116 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 117 | return bpe_tokens 118 | 119 | def decode(self, tokens): 120 | text = ''.join([self.decoder[token] for token in tokens]) 121 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ') 122 | return text 123 | -------------------------------------------------------------------------------- /network/gem_pool.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | import torch 4 | from torch import nn 5 | 6 | 7 | class GeneralizedMeanPooling(nn.Module): 8 | 9 | def __init__(self, norm, output_size=1, eps=1e-6): 10 | super(GeneralizedMeanPooling, self).__init__() 11 | assert norm > 0 12 | self.p = float(norm) 13 | self.output_size = output_size 14 | self.eps = eps 15 | 16 | def forward(self, x): 17 | x = x.clamp(min=self.eps).pow(self.p) 18 | return torch.nn.functional.adaptive_avg_pool2d(x, self.output_size).pow(1.
/ self.p) 19 | 20 | def __repr__(self): 21 | return self.__class__.__name__ + '(' \ 22 | + str(self.p) + ', ' \ 23 | + 'output_size=' + str(self.output_size) + ')' 24 | 25 | 26 | class GeneralizedMeanPoolingP(GeneralizedMeanPooling): 27 | """ Same, but norm is trainable 28 | """ 29 | def __init__(self, norm=3, output_size=1, eps=1e-6): 30 | super(GeneralizedMeanPoolingP, self).__init__(norm, output_size, eps) 31 | self.p = nn.Parameter(torch.ones(1) * norm) -------------------------------------------------------------------------------- /network/lr.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | from bisect import bisect_right 4 | import torch 5 | 6 | from typing import Dict, Any 7 | 8 | 9 | class Scheduler: 10 | 11 | def __init__(self, 12 | optimizer: torch.optim.Optimizer, 13 | param_group_field: str, 14 | noise_range_t=None, 15 | noise_type='normal', 16 | noise_pct=0.67, 17 | noise_std=1.0, 18 | noise_seed=None, 19 | initialize: bool = True) -> None: 20 | self.optimizer = optimizer 21 | self.param_group_field = param_group_field 22 | self._initial_param_group_field = f"initial_{param_group_field}" 23 | if initialize: 24 | for i, group in enumerate(self.optimizer.param_groups): 25 | if param_group_field not in group: 26 | raise KeyError(f"{param_group_field} missing from param_groups[{i}]") 27 | group.setdefault(self._initial_param_group_field, group[param_group_field]) 28 | else: 29 | for i, group in enumerate(self.optimizer.param_groups): 30 | if self._initial_param_group_field not in group: 31 | raise KeyError(f"{self._initial_param_group_field} missing from param_groups[{i}]") 32 | self.base_values = [group[self._initial_param_group_field] for group in self.optimizer.param_groups] 33 | self.metric = None # any point to having this for all? 
34 | self.noise_range_t = noise_range_t 35 | self.noise_pct = noise_pct 36 | self.noise_type = noise_type 37 | self.noise_std = noise_std 38 | self.noise_seed = noise_seed if noise_seed is not None else 42 39 | self.update_groups(self.base_values) 40 | 41 | def state_dict(self) -> Dict[str, Any]: 42 | return {key: value for key, value in self.__dict__.items() if key != 'optimizer'} 43 | 44 | def load_state_dict(self, state_dict: Dict[str, Any]) -> None: 45 | self.__dict__.update(state_dict) 46 | 47 | def get_epoch_values(self, epoch: int): 48 | return None 49 | 50 | def get_update_values(self, num_updates: int): 51 | return None 52 | 53 | def step(self, epoch: int, metric: float = None) -> None: 54 | self.metric = metric 55 | values = self.get_epoch_values(epoch) 56 | if values is not None: 57 | values = self._add_noise(values, epoch) 58 | self.update_groups(values) 59 | 60 | def step_update(self, num_updates: int, metric: float = None): 61 | self.metric = metric 62 | values = self.get_update_values(num_updates) 63 | if values is not None: 64 | values = self._add_noise(values, num_updates) 65 | self.update_groups(values) 66 | 67 | def update_groups(self, values): 68 | if not isinstance(values, (list, tuple)): 69 | values = [values] * len(self.optimizer.param_groups) 70 | for param_group, value in zip(self.optimizer.param_groups, values): 71 | param_group[self.param_group_field] = value 72 | 73 | def _add_noise(self, lrs, t): 74 | if self.noise_range_t is not None: 75 | if isinstance(self.noise_range_t, (list, tuple)): 76 | apply_noise = self.noise_range_t[0] <= t < self.noise_range_t[1] 77 | else: 78 | apply_noise = t >= self.noise_range_t 79 | if apply_noise: 80 | g = torch.Generator() 81 | g.manual_seed(self.noise_seed + t) 82 | if self.noise_type == 'normal': 83 | while True: 84 | # resample if noise out of percent limit, brute force but shouldn't spin much 85 | noise = torch.randn(1, generator=g).item() 86 | if abs(noise) < self.noise_pct: 87 | break 88 | else: 89 | noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct 90 | lrs = [v + v * noise for v in lrs] 91 | return lrs 92 | 93 | 94 | _logger = logging.getLogger(__name__) 95 | 96 | 97 | class CosineLRScheduler(Scheduler): 98 | 99 | def __init__(self, 100 | optimizer: torch.optim.Optimizer, 101 | t_initial: int, 102 | t_mul: float = 1., 103 | lr_min: float = 0., 104 | decay_rate: float = 1., 105 | warmup_t=0, 106 | warmup_lr_init=0, 107 | warmup_prefix=False, 108 | cycle_limit=0, 109 | t_in_epochs=True, 110 | noise_range_t=None, 111 | noise_pct=0.67, 112 | noise_std=1.0, 113 | noise_seed=42, 114 | initialize=True) -> None: 115 | super().__init__( 116 | optimizer, param_group_field="lr", 117 | noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed, 118 | initialize=initialize) 119 | 120 | assert t_initial > 0 121 | assert lr_min >= 0 122 | if t_initial == 1 and t_mul == 1 and decay_rate == 1: 123 | _logger.warning("Cosine annealing scheduler will have no effect on the learning " 124 | "rate since t_initial = t_mul = eta_mul = 1.") 125 | self.t_initial = t_initial 126 | self.t_mul = t_mul 127 | self.lr_min = lr_min 128 | self.decay_rate = decay_rate 129 | self.cycle_limit = cycle_limit 130 | self.warmup_t = warmup_t 131 | self.warmup_lr_init = warmup_lr_init 132 | self.warmup_prefix = warmup_prefix 133 | self.t_in_epochs = t_in_epochs 134 | if self.warmup_t: 135 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] 136 | 
super().update_groups(self.warmup_lr_init) 137 | else: 138 | self.warmup_steps = [1 for _ in self.base_values] 139 | 140 | def _get_lr(self, t): 141 | if t < self.warmup_t: 142 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] 143 | else: 144 | if self.warmup_prefix: 145 | t = t - self.warmup_t 146 | 147 | if self.t_mul != 1: 148 | i = math.floor(math.log(1 - t / self.t_initial * (1 - self.t_mul), self.t_mul)) 149 | t_i = self.t_mul ** i * self.t_initial 150 | t_curr = t - (1 - self.t_mul ** i) / (1 - self.t_mul) * self.t_initial 151 | else: 152 | i = t // self.t_initial 153 | t_i = self.t_initial 154 | t_curr = t - (self.t_initial * i) 155 | 156 | gamma = self.decay_rate ** i 157 | lr_min = self.lr_min * gamma 158 | lr_max_values = [v * gamma for v in self.base_values] 159 | 160 | if self.cycle_limit == 0 or (self.cycle_limit > 0 and i < self.cycle_limit): 161 | lrs = [ 162 | lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * t_curr / t_i)) for lr_max in lr_max_values 163 | ] 164 | else: 165 | lrs = [self.lr_min for _ in self.base_values] 166 | 167 | return lrs 168 | 169 | def get_epoch_values(self, epoch: int): 170 | if self.t_in_epochs: 171 | return self._get_lr(epoch) 172 | else: 173 | return None 174 | 175 | def get_update_values(self, num_updates: int): 176 | if not self.t_in_epochs: 177 | return self._get_lr(num_updates) 178 | else: 179 | return None 180 | 181 | def get_cycle_length(self, cycles=0): 182 | if not cycles: 183 | cycles = self.cycle_limit 184 | cycles = max(1, cycles) 185 | if self.t_mul == 1.0: 186 | return self.t_initial * cycles 187 | else: 188 | return int(math.floor(-self.t_initial * (self.t_mul ** cycles - 1) / (1 - self.t_mul))) 189 | -------------------------------------------------------------------------------- /network/model.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import torch 3 | import torchvision 4 | import torch.nn as nn 5 | from .gem_pool import GeneralizedMeanPoolingP 6 | 7 | class Normalize(nn.Module): 8 | def __init__(self, power=2): 9 | super(Normalize, self).__init__() 10 | self.power = power 11 | 12 | def forward(self, x): 13 | norm = x.pow(self.power).sum(1, keepdim=True).pow(1. 
/ self.power) 14 | out = x.div(norm) 15 | return out 16 | 17 | def weights_init_kaiming(m): 18 | classname = m.__class__.__name__ 19 | if classname.find('Linear') != -1: 20 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out') 21 | nn.init.constant_(m.bias, 0.0) 22 | elif classname.find('Conv') != -1: 23 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in') 24 | if m.bias is not None: 25 | nn.init.constant_(m.bias, 0.0) 26 | elif classname.find('BatchNorm') != -1: 27 | if m.affine: 28 | nn.init.constant_(m.weight, 1.0) 29 | nn.init.constant_(m.bias, 0.0) 30 | elif classname.find('InstanceNorm') != -1: 31 | if m.affine: 32 | nn.init.constant_(m.weight, 1.0) 33 | nn.init.constant_(m.bias, 0.0) 34 | 35 | def weights_init_classifier(m): 36 | classname = m.__class__.__name__ 37 | if classname.find('Linear') != -1: 38 | nn.init.normal_(m.weight, std=0.001) 39 | if m.bias: 40 | nn.init.constant_(m.bias, 0.0) 41 | 42 | class Classifier(nn.Module): 43 | def __init__(self, pid_num): 44 | super(Classifier, self, ).__init__() 45 | self.pid_num = pid_num 46 | self.GEM = GeneralizedMeanPoolingP() 47 | self.BN = nn.BatchNorm1d(2048) 48 | self.BN.apply(weights_init_kaiming) 49 | 50 | self.classifier = nn.Linear(2048, self.pid_num, bias=False) 51 | self.classifier.apply(weights_init_classifier) 52 | 53 | self.l2_norm = Normalize(2) 54 | 55 | def forward(self, features_map): 56 | features = self.GEM(features_map) 57 | bn_features = self.BN(features.squeeze()) 58 | cls_score = self.classifier(bn_features) 59 | return features, cls_score, self.l2_norm(bn_features) 60 | 61 | class Classifier2(nn.Module): 62 | def __init__(self, pid_num): 63 | super(Classifier2, self, ).__init__() 64 | self.pid_num = pid_num 65 | self.BN = nn.BatchNorm1d(1024) 66 | self.BN.apply(weights_init_kaiming) 67 | 68 | self.classifier = nn.Linear(1024, self.pid_num, bias=False) 69 | self.classifier.apply(weights_init_classifier) 70 | 71 | self.l2_norm = Normalize(2) 72 | 73 | def forward(self, features): 74 | bn_features = self.BN(features.squeeze()) 75 | cls_score = self.classifier(bn_features) 76 | return cls_score, self.l2_norm(features) 77 | 78 | class PromptLearner1(nn.Module): 79 | def __init__(self, num_class, dtype, token_embedding): 80 | super().__init__() 81 | ctx_init = "A photo of a X X X X person." 
82 | ctx_dim = 512 83 | ctx_init = ctx_init.replace("_", " ") 84 | n_ctx = 4 85 | 86 | tokenized_prompts = clip.tokenize(ctx_init).cuda() 87 | with torch.no_grad(): 88 | embedding = token_embedding(tokenized_prompts).type(dtype) 89 | self.tokenized_prompts = tokenized_prompts # torch.Tensor 90 | 91 | n_cls_ctx = 4 92 | cls_vectors = torch.empty(num_class, n_cls_ctx, ctx_dim, dtype=dtype) 93 | nn.init.normal_(cls_vectors, std=0.02) 94 | self.cls_ctx = nn.Parameter(cls_vectors) 95 | 96 | self.register_buffer("token_prefix", embedding[:, :n_ctx + 1, :]) 97 | self.register_buffer("token_suffix", embedding[:, n_ctx + 1 + n_cls_ctx:, :]) 98 | self.num_class = num_class 99 | self.n_cls_ctx = n_cls_ctx 100 | 101 | def forward(self, label): 102 | cls_ctx = self.cls_ctx[label] 103 | b = label.shape[0] 104 | prefix = self.token_prefix.expand(b, -1, -1) 105 | suffix = self.token_suffix.expand(b, -1, -1) 106 | 107 | prompts = torch.cat( 108 | [ 109 | prefix, # (n_cls, 1, dim) 110 | cls_ctx, # (n_cls, n_ctx, dim) 111 | suffix, # (n_cls, *, dim) 112 | ], 113 | dim=1, 114 | ) 115 | return prompts 116 | 117 | class PromptLearner2(nn.Module): 118 | def __init__(self, num_class, dtype, token_embedding): 119 | super().__init__() 120 | ctx_init = "A photo of a X X X X person." 121 | ctx_dim = 512 122 | ctx_init = ctx_init.replace("_", " ") 123 | n_ctx = 4 124 | 125 | tokenized_prompts = clip.tokenize(ctx_init).cuda() 126 | with torch.no_grad(): 127 | embedding = token_embedding(tokenized_prompts).type(dtype) 128 | self.tokenized_prompts = tokenized_prompts # torch.Tensor 129 | 130 | n_cls_ctx = 4 131 | cls_vectors = torch.empty(num_class, n_cls_ctx, ctx_dim, dtype=dtype) 132 | nn.init.normal_(cls_vectors, std=0.02) 133 | self.cls_ctx = nn.Parameter(cls_vectors) 134 | 135 | self.register_buffer("token_prefix", embedding[:, :n_ctx + 1, :]) 136 | self.register_buffer("token_suffix", embedding[:, n_ctx + 1 + n_cls_ctx:, :]) 137 | self.num_class = num_class 138 | self.n_cls_ctx = n_cls_ctx 139 | 140 | def forward(self, label): 141 | cls_ctx = self.cls_ctx[label] 142 | b = label.shape[0] 143 | prefix = self.token_prefix.expand(b, -1, -1) 144 | suffix = self.token_suffix.expand(b, -1, -1) 145 | 146 | prompts = torch.cat( 147 | [ 148 | prefix, # (n_cls, 1, dim) 149 | cls_ctx, # (n_cls, n_ctx, dim) 150 | suffix, # (n_cls, *, dim) 151 | ], 152 | dim=1, 153 | ) 154 | return prompts 155 | 156 | class TextEncoder(nn.Module): 157 | def __init__(self, clip_model): 158 | super().__init__() 159 | self.transformer = clip_model.transformer 160 | self.positional_embedding = clip_model.positional_embedding 161 | self.ln_final = clip_model.ln_final 162 | self.text_projection = clip_model.text_projection 163 | self.dtype = clip_model.dtype 164 | 165 | def forward(self, prompts, tokenized_prompts): 166 | x = prompts + self.positional_embedding.type(self.dtype) 167 | x = x.permute(1, 0, 2) # NLD -> LND 168 | x = self.transformer(x) 169 | x = x.permute(1, 0, 2) # LND -> NLD 170 | x = self.ln_final(x).type(self.dtype) 171 | 172 | # x.shape = [batch_size, n_ctx, transformer.width] 173 | # take features from the eot embedding (eot_token is the highest number in each sequence) 174 | x = x[torch.arange(x.shape[0]), tokenized_prompts.argmax(dim=-1)] @ self.text_projection 175 | return x 176 | 177 | class AttentionFusion(nn.Module): 178 | def __init__(self, embed_dim): 179 | super(AttentionFusion, self).__init__() 180 | self.dropout_rate = 0.1 181 | self.embed_dim = embed_dim 182 | self.embed_dim_qkv = embed_dim 183 | 184 | self.embedding_q = 
nn.Sequential(nn.Linear(self.embed_dim, self.embed_dim_qkv), 185 | nn.Tanh(), nn.Dropout(self.dropout_rate)) 186 | self.embedding_k = nn.Sequential(nn.Linear(self.embed_dim, self.embed_dim_qkv), 187 | nn.Tanh(), nn.Dropout(self.dropout_rate)) 188 | self.embedding_v = nn.Sequential(nn.Linear(self.embed_dim, self.embed_dim_qkv), 189 | nn.Tanh(), nn.Dropout(self.dropout_rate)) 190 | self.embedding_common = nn.Sequential(nn.Linear(self.embed_dim_qkv, self.embed_dim)) 191 | self.softmax = nn.Softmax(dim=1) 192 | 193 | def q_k_v_product_attention(self, q_emb, k_emb, v_emb): 194 | weights = torch.bmm(q_emb, k_emb.permute(0, 2, 1)) 195 | weights = torch.div(weights, (self.embed_dim_qkv ** 0.5)) 196 | weights = self.softmax(weights) 197 | new_v_emb = weights.bmm(v_emb) 198 | return new_v_emb 199 | 200 | def forward(self, text_features1, text_features2): 201 | batch_size = text_features1.size(0) 202 | q_emb = self.embedding_q(text_features1.unsqueeze(1)) 203 | k_emb = self.embedding_k(text_features2.unsqueeze(1)) 204 | v_emb = self.embedding_v(text_features2.unsqueeze(1)) 205 | new_v_emb = self.q_k_v_product_attention(q_emb, k_emb, v_emb) 206 | new_text_features = self.embedding_common(new_v_emb) 207 | new_text_features = new_text_features.view(batch_size, self.embed_dim) + text_features1 208 | return new_text_features 209 | 210 | class Model(nn.Module): 211 | def __init__(self, num_classes, img_h, img_w): 212 | super(Model, self).__init__() 213 | self.in_planes = 2048 214 | self.num_classes = num_classes 215 | 216 | self.h_resolution = int((img_h - 16) // 16 + 1) 217 | self.w_resolution = int((img_w - 16) // 16 + 1) 218 | self.vision_stride_size = 16 219 | clip_model = load_clip_to_cpu('RN50', self.h_resolution, self.w_resolution, self.vision_stride_size) 220 | clip_model.to("cuda") 221 | 222 | self.image_encoder1 = nn.Sequential(clip_model.visual.conv1, clip_model.visual.bn1, clip_model.visual.conv2, 223 | clip_model.visual.bn2, clip_model.visual.conv3, clip_model.visual.bn3, 224 | clip_model.visual.relu, clip_model.visual.avgpool) 225 | self.image_encoder2 = copy.deepcopy(self.image_encoder1) 226 | 227 | self.image_encoder = nn.Sequential(clip_model.visual.layer1, clip_model.visual.layer2, clip_model.visual.layer3, 228 | clip_model.visual.layer4) 229 | self.attnpool = clip_model.visual.attnpool 230 | self.classifier = Classifier(self.num_classes) 231 | self.classifier2 = Classifier2(self.num_classes) 232 | 233 | self.prompt_learner1 = PromptLearner1(num_classes, clip_model.dtype, clip_model.token_embedding) 234 | self.prompt_learner2 = PromptLearner2(num_classes, clip_model.dtype, clip_model.token_embedding) 235 | self.text_encoder = TextEncoder(clip_model) 236 | self.attention_fusion = AttentionFusion(1024) 237 | 238 | def forward(self, x1=None, x2=None, label1=None, label2=None, label=None, get_image=False, get_text=False, 239 | get_fusion_text=False): 240 | if get_image == True: 241 | if x1 is not None and x2 is None: 242 | image_features_map1 = self.image_encoder1(x1) 243 | image_features_map1 = self.image_encoder(image_features_map1) 244 | image_features1_proj = self.attnpool(image_features_map1)[0] 245 | return image_features1_proj 246 | elif x1 is None and x2 is not None: 247 | image_features_map2 = self.image_encoder2(x2) 248 | image_features_map2 = self.image_encoder(image_features_map2) 249 | image_features2_proj = self.attnpool(image_features_map2)[0] 250 | return image_features2_proj 251 | 252 | if get_text == True: 253 | if label1 is not None and label2 is None: 254 | prompts1 = 
self.prompt_learner1(label1) 255 | text_features1 = self.text_encoder(prompts1, self.prompt_learner1.tokenized_prompts) 256 | return text_features1 257 | if label2 is not None and label1 is None: 258 | prompts2 = self.prompt_learner2(label2) 259 | text_features2 = self.text_encoder(prompts2, self.prompt_learner2.tokenized_prompts) 260 | return text_features2 261 | 262 | if get_fusion_text == True: 263 | prompts1 = self.prompt_learner1(label) 264 | text_features1 = self.text_encoder(prompts1, self.prompt_learner1.tokenized_prompts) 265 | prompts2 = self.prompt_learner2(label) 266 | text_features2 = self.text_encoder(prompts2, self.prompt_learner2.tokenized_prompts) 267 | text_features = self.attention_fusion(text_features1, text_features2) 268 | return text_features 269 | 270 | if x1 is not None and x2 is not None: 271 | 272 | image_features_map1 = self.image_encoder1(x1) 273 | image_features_map2 = self.image_encoder2(x2) 274 | image_features_maps = torch.cat([image_features_map1, image_features_map2], dim=0) 275 | image_features_maps = self.image_encoder(image_features_maps) 276 | image_features_proj = self.attnpool(image_features_maps)[0] 277 | features, cls_scores, _ = self.classifier(image_features_maps) 278 | cls_scores_proj, _ = self.classifier2(image_features_proj) 279 | 280 | return [features, image_features_proj], [cls_scores, cls_scores_proj] 281 | 282 | elif x1 is not None and x2 is None: 283 | 284 | image_features_map1 = self.image_encoder1(x1) 285 | image_features_map1 = self.image_encoder(image_features_map1) 286 | image_features1_proj = self.attnpool(image_features_map1)[0] 287 | _, _, test_features1 = self.classifier(image_features_map1) 288 | _, test_features1_proj = self.classifier2(image_features1_proj) 289 | 290 | return torch.cat([test_features1, test_features1_proj], dim=1) 291 | 292 | elif x1 is None and x2 is not None: 293 | 294 | image_features_map2 = self.image_encoder2(x2) 295 | image_features_map2 = self.image_encoder(image_features_map2) 296 | image_features2_proj = self.attnpool(image_features_map2)[0] 297 | _, _, test_features2 = self.classifier(image_features_map2) 298 | _, test_features2_proj = self.classifier2(image_features2_proj) 299 | 300 | return torch.cat([test_features2, test_features2_proj], dim=1) 301 | 302 | from .clip import clip 303 | def load_clip_to_cpu(backbone_name, h_resolution, w_resolution, vision_stride_size): 304 | url = clip._MODELS[backbone_name] 305 | model_path = clip._download(url) 306 | 307 | try: 308 | model = torch.jit.load(model_path, map_location="cpu").eval() 309 | state_dict = None 310 | 311 | except RuntimeError: 312 | state_dict = torch.load(model_path, map_location="cpu") 313 | 314 | model = clip.build_model(state_dict or model.state_dict(), h_resolution, w_resolution, vision_stride_size) 315 | 316 | return model 317 | 318 | 319 | -------------------------------------------------------------------------------- /network/processing.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | 4 | class FeatureShuffling(): 5 | def __init__(self): 6 | super(FeatureShuffling, self).__init__() 7 | 8 | def __call__(self, features1, features2): 9 | size, channel = features1.size(0), features1.size(1) 10 | num = 4 11 | 12 | a = list(range(size)) 13 | b = [a[i: i + num] for i in range(0, size, num)] 14 | c = [] 15 | for b1 in b: 16 | random.shuffle(b1) 17 | c.extend(b1) 18 | 19 | shuffling_features1 = torch.zeros([size, channel]).cuda() 20 | shuffling_features2 = 
torch.zeros([size, channel]).cuda() 21 | 22 | for i in range(size): 23 | shuffling_features1[i] = features1[c[i]] 24 | shuffling_features2[i] = features2[c[i]] 25 | return shuffling_features1, shuffling_features2 26 | 27 | 28 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python3 main.py --output_path sysu/base -------------------------------------------------------------------------------- /tools/MSEL.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | 6 | def pdist_torch(emb1, emb2): 7 | ''' 8 | compute the eucilidean distance matrix between embeddings1 and embeddings2 9 | using gpu 10 | ''' 11 | m, n = emb1.shape[0], emb2.shape[0] 12 | emb1_pow = torch.pow(emb1, 2).sum(dim=1, keepdim=True).expand(m, n) 13 | emb2_pow = torch.pow(emb2, 2).sum(dim=1, keepdim=True).expand(n, m).t() 14 | dist_mtx = emb1_pow + emb2_pow 15 | dist_mtx = dist_mtx.addmm_(1, -2, emb1, emb2.t()) 16 | dist_mtx = dist_mtx.clamp(min=1e-12).sqrt() 17 | return dist_mtx 18 | 19 | class MSEL(nn.Module): 20 | def __init__(self,num_pos,feat_norm = 'no'): 21 | super(MSEL, self).__init__() 22 | self.num_pos = num_pos 23 | self.feat_norm = feat_norm 24 | 25 | def forward(self, inputs, targets): 26 | if self.feat_norm == 'yes': 27 | inputs = F.normalize(inputs, p=2, dim=-1) 28 | 29 | target, _ = targets.chunk(2,0) 30 | N = target.size(0) 31 | 32 | dist_mat = pdist_torch(inputs, inputs) 33 | 34 | 35 | dist_intra_rgb = dist_mat[0 : N, 0 : N] 36 | dist_cross_rgb = dist_mat[0 : N, N : 2*N] 37 | dist_intra_ir = dist_mat[N : 2*N, N : 2*N] 38 | dist_cross_ir = dist_mat[N : 2*N, 0 : N] 39 | 40 | # shape [N, N] 41 | 42 | is_pos = target.expand(N, N).eq(target.expand(N, N).t()) 43 | 44 | 45 | dist_intra_rgb = is_pos * dist_intra_rgb 46 | intra_rgb, _ = dist_intra_rgb.topk(self.num_pos - 1, dim=1 ,largest = True, sorted = False) # remove itself 47 | intra_mean_rgb = torch.mean(intra_rgb, dim=1) 48 | 49 | dist_intra_ir = is_pos * dist_intra_ir 50 | intra_ir, _ = dist_intra_ir.topk(self.num_pos - 1, dim=1, largest=True, sorted=False) 51 | intra_mean_ir = torch.mean(intra_ir, dim=1) 52 | 53 | dist_cross_rgb = dist_cross_rgb[is_pos].contiguous().view(N, -1) # [N, num_pos] 54 | cross_mean_rgb = torch.mean(dist_cross_rgb, dim =1) 55 | 56 | dist_cross_ir = dist_cross_ir[is_pos].contiguous().view(N, -1) # [N, num_pos] 57 | cross_mean_ir = torch.mean(dist_cross_ir, dim=1) 58 | 59 | loss = (torch.mean(torch.pow(cross_mean_rgb - intra_mean_rgb, 2)) + 60 | torch.mean(torch.pow(cross_mean_ir - intra_mean_ir, 2))) / 2 61 | 62 | return loss 63 | 64 | 65 | class MSEL_Cos(nn.Module): # for features after bn 66 | def __init__(self,num_pos): 67 | super(MSEL_Cos, self).__init__() 68 | self.num_pos = num_pos 69 | 70 | def forward(self, inputs, targets): 71 | 72 | inputs = nn.functional.normalize(inputs, p=2, dim=1) 73 | 74 | target, _ = targets.chunk(2,0) 75 | N = target.size(0) 76 | 77 | dist_mat = 1 - torch.matmul(inputs, torch.t(inputs)) 78 | 79 | dist_intra_rgb = dist_mat[0: N, 0: N] 80 | dist_cross_rgb = dist_mat[0: N, N: 2*N] 81 | dist_intra_ir = dist_mat[N: 2*N, N: 2*N] 82 | dist_cross_ir = dist_mat[N: 2*N, 0: N] 83 | 84 | # shape [N, N] 85 | is_pos = target.expand(N, N).eq(target.expand(N, N).t()) 86 | 87 | dist_intra_rgb = is_pos * dist_intra_rgb 88 | intra_rgb, _ = 
dist_intra_rgb.topk(self.num_pos - 1, dim=1, largest=True, sorted=False) # remove itself 89 | intra_mean_rgb = torch.mean(intra_rgb, dim=1) 90 | 91 | dist_intra_ir = is_pos * dist_intra_ir 92 | intra_ir, _ = dist_intra_ir.topk(self.num_pos - 1, dim=1, largest=True, sorted=False) 93 | intra_mean_ir = torch.mean(intra_ir, dim=1) 94 | 95 | dist_cross_rgb = dist_cross_rgb[is_pos].contiguous().view(N, -1) # [N, num_pos] 96 | cross_mean_rgb = torch.mean(dist_cross_rgb, dim=1) 97 | 98 | dist_cross_ir = dist_cross_ir[is_pos].contiguous().view(N, -1) # [N, num_pos] 99 | cross_mean_ir = torch.mean(dist_cross_ir, dim=1) 100 | 101 | loss = (torch.mean(torch.pow(cross_mean_rgb - intra_mean_rgb, 2)) + 102 | torch.mean(torch.pow(cross_mean_ir - intra_mean_ir, 2))) / 2 103 | 104 | return loss 105 | 106 | 107 | class MSEL_Feat(nn.Module): # compute MSEL loss by the distance between sample and center 108 | def __init__(self, num_pos): 109 | super(MSEL_Feat, self).__init__() 110 | self.num_pos = num_pos 111 | 112 | def forward(self, input1, input2): 113 | N = input1.size(0) 114 | id_num = N // self.num_pos 115 | 116 | feats_rgb = input1.chunk(id_num, 0) 117 | feats_ir = input2.chunk(id_num, 0) 118 | 119 | loss_list = [] 120 | for i in range(id_num): 121 | cross_center_rgb = torch.mean(feats_rgb[i], dim=0) # cross center 122 | cross_center_ir = torch.mean(feats_ir[i], dim=0) 123 | 124 | for j in range(self.num_pos): 125 | 126 | feat_rgb = feats_rgb[i][j] 127 | feat_ir = feats_ir[i][j] 128 | 129 | intra_feats_rgb = torch.cat((feats_rgb[i][0:j], feats_rgb[i][j+1:]), dim=0) # intra center 130 | intra_feats_ir = torch.cat((feats_ir[i][0:j], feats_ir[i][j+1:]), dim=0) # intra center of the IR branch 131 | 132 | intra_center_rgb = torch.mean(intra_feats_rgb, dim=0) 133 | intra_center_ir = torch.mean(intra_feats_ir, dim=0) 134 | 135 | dist_intra_rgb = pdist_torch(feat_rgb.view(1, -1), intra_center_rgb.view(1, -1)) 136 | dist_intra_ir = pdist_torch(feat_ir.view(1, -1), intra_center_ir.view(1, -1)) 137 | 138 | dist_cross_rgb = pdist_torch(feat_rgb.view(1, -1), cross_center_ir.view(1, -1)) 139 | dist_cross_ir = pdist_torch(feat_ir.view(1, -1), cross_center_rgb.view(1, -1)) 140 | 141 | loss_list.append(torch.pow(dist_cross_rgb - dist_intra_rgb, 2) + torch.pow(dist_cross_ir - dist_intra_ir, 2)) 142 | 143 | loss = sum(loss_list) / N / 2 144 | 145 | return loss 146 | -------------------------------------------------------------------------------- /tools/__init__.py: -------------------------------------------------------------------------------- 1 | from .loss import * 2 | from .utils import * 3 | from .meter import * 4 | from .logger import * 5 | from .eval_metrics import * 6 | from .MSEL import * -------------------------------------------------------------------------------- /tools/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nengdong96/CSDN/2224959c5df9af6a58cd34a6f9ef6f6ab8177966/tools/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /tools/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nengdong96/CSDN/2224959c5df9af6a58cd34a6f9ef6f6ab8177966/tools/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /tools/__pycache__/__init__.cpython-37.pyc:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/nengdong96/CSDN/2224959c5df9af6a58cd34a6f9ef6f6ab8177966/tools/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /tools/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nengdong96/CSDN/2224959c5df9af6a58cd34a6f9ef6f6ab8177966/tools/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /tools/__pycache__/eval_metrics.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nengdong96/CSDN/2224959c5df9af6a58cd34a6f9ef6f6ab8177966/tools/__pycache__/eval_metrics.cpython-38.pyc -------------------------------------------------------------------------------- /tools/__pycache__/logger.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nengdong96/CSDN/2224959c5df9af6a58cd34a6f9ef6f6ab8177966/tools/__pycache__/logger.cpython-35.pyc -------------------------------------------------------------------------------- /tools/__pycache__/logger.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nengdong96/CSDN/2224959c5df9af6a58cd34a6f9ef6f6ab8177966/tools/__pycache__/logger.cpython-36.pyc -------------------------------------------------------------------------------- /tools/__pycache__/logger.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nengdong96/CSDN/2224959c5df9af6a58cd34a6f9ef6f6ab8177966/tools/__pycache__/logger.cpython-37.pyc -------------------------------------------------------------------------------- /tools/__pycache__/logger.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nengdong96/CSDN/2224959c5df9af6a58cd34a6f9ef6f6ab8177966/tools/__pycache__/logger.cpython-38.pyc -------------------------------------------------------------------------------- /tools/__pycache__/loss.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nengdong96/CSDN/2224959c5df9af6a58cd34a6f9ef6f6ab8177966/tools/__pycache__/loss.cpython-35.pyc -------------------------------------------------------------------------------- /tools/__pycache__/loss.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nengdong96/CSDN/2224959c5df9af6a58cd34a6f9ef6f6ab8177966/tools/__pycache__/loss.cpython-36.pyc -------------------------------------------------------------------------------- /tools/__pycache__/loss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nengdong96/CSDN/2224959c5df9af6a58cd34a6f9ef6f6ab8177966/tools/__pycache__/loss.cpython-37.pyc -------------------------------------------------------------------------------- /tools/__pycache__/loss.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/nengdong96/CSDN/2224959c5df9af6a58cd34a6f9ef6f6ab8177966/tools/__pycache__/loss.cpython-38.pyc -------------------------------------------------------------------------------- /tools/__pycache__/meter.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nengdong96/CSDN/2224959c5df9af6a58cd34a6f9ef6f6ab8177966/tools/__pycache__/meter.cpython-35.pyc -------------------------------------------------------------------------------- /tools/__pycache__/meter.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nengdong96/CSDN/2224959c5df9af6a58cd34a6f9ef6f6ab8177966/tools/__pycache__/meter.cpython-36.pyc -------------------------------------------------------------------------------- /tools/__pycache__/meter.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nengdong96/CSDN/2224959c5df9af6a58cd34a6f9ef6f6ab8177966/tools/__pycache__/meter.cpython-37.pyc -------------------------------------------------------------------------------- /tools/__pycache__/meter.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nengdong96/CSDN/2224959c5df9af6a58cd34a6f9ef6f6ab8177966/tools/__pycache__/meter.cpython-38.pyc -------------------------------------------------------------------------------- /tools/__pycache__/utils.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nengdong96/CSDN/2224959c5df9af6a58cd34a6f9ef6f6ab8177966/tools/__pycache__/utils.cpython-35.pyc -------------------------------------------------------------------------------- /tools/__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nengdong96/CSDN/2224959c5df9af6a58cd34a6f9ef6f6ab8177966/tools/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /tools/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nengdong96/CSDN/2224959c5df9af6a58cd34a6f9ef6f6ab8177966/tools/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /tools/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nengdong96/CSDN/2224959c5df9af6a58cd34a6f9ef6f6ab8177966/tools/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /tools/eval_metrics.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import numpy as np 3 | 4 | def eval_sysu(distmat, q_pids, g_pids, q_camids, g_camids, max_rank = 20): 5 | 6 | num_q, num_g = distmat.shape 7 | if num_g < max_rank: 8 | max_rank = num_g 9 | print("Note: number of gallery samples is quite small, got {}".format(num_g)) 10 | indices = np.argsort(distmat, axis=1) 11 | pred_label = g_pids[indices] 12 | matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32) 13 | 14 | new_all_cmc = [] 15 | all_cmc = [] 16 | all_AP = [] 17 | all_INP = [] 18 | 
num_valid_q = 0. 19 | for q_idx in range(num_q): 20 | q_pid = q_pids[q_idx] 21 | q_camid = q_camids[q_idx] 22 | 23 | order = indices[q_idx] 24 | remove = (q_camid == 3) & (g_camids[order] == 2) 25 | keep = np.invert(remove) 26 | new_cmc = pred_label[q_idx][keep] 27 | new_index = np.unique(new_cmc, return_index=True)[1] 28 | new_cmc = [new_cmc[index] for index in sorted(new_index)] 29 | 30 | new_match = (new_cmc == q_pid).astype(np.int32) 31 | new_cmc = new_match.cumsum() 32 | new_all_cmc.append(new_cmc[:max_rank]) 33 | 34 | orig_cmc = matches[q_idx][keep] 35 | if not np.any(orig_cmc): 36 | continue 37 | 38 | cmc = orig_cmc.cumsum() 39 | 40 | pos_idx = np.where(orig_cmc == 1) 41 | pos_max_idx = np.max(pos_idx) 42 | inp = cmc[pos_max_idx]/ (pos_max_idx + 1.0) 43 | all_INP.append(inp) 44 | 45 | cmc[cmc > 1] = 1 46 | 47 | all_cmc.append(cmc[:max_rank]) 48 | num_valid_q += 1. 49 | 50 | num_rel = orig_cmc.sum() 51 | tmp_cmc = orig_cmc.cumsum() 52 | tmp_cmc = [x / (i+1.) for i, x in enumerate(tmp_cmc)] 53 | tmp_cmc = np.asarray(tmp_cmc) * orig_cmc 54 | AP = tmp_cmc.sum() / num_rel 55 | all_AP.append(AP) 56 | 57 | assert num_valid_q > 0, "Error: all query identities do not appear in gallery" 58 | 59 | all_cmc = np.asarray(all_cmc).astype(np.float32) 60 | all_cmc = all_cmc.sum(0) / num_valid_q # standard CMC 61 | 62 | new_all_cmc = np.asarray(new_all_cmc).astype(np.float32) 63 | new_all_cmc = new_all_cmc.sum(0) / num_valid_q 64 | mAP = np.mean(all_AP) 65 | mINP = np.mean(all_INP) 66 | return new_all_cmc, mAP, mINP 67 | 68 | def eval_regdb(distmat, q_pids, g_pids, max_rank = 20): 69 | num_q, num_g = distmat.shape 70 | if num_g < max_rank: 71 | max_rank = num_g 72 | print("Note: number of gallery samples is quite small, got {}".format(num_g)) 73 | indices = np.argsort(distmat, axis=1) 74 | matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32) 75 | 76 | all_cmc = [] 77 | all_AP = [] 78 | all_INP = [] 79 | num_valid_q = 0. 80 | 81 | q_camids = np.ones(num_q).astype(np.int32) 82 | g_camids = 2 * np.ones(num_g).astype(np.int32) 83 | 84 | for q_idx in range(num_q): 85 | q_pid = q_pids[q_idx] 86 | q_camid = q_camids[q_idx] 87 | 88 | order = indices[q_idx] 89 | remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid) 90 | keep = np.invert(remove) 91 | 92 | raw_cmc = matches[q_idx][keep] 93 | if not np.any(raw_cmc): 94 | 95 | continue 96 | 97 | cmc = raw_cmc.cumsum() 98 | 99 | pos_idx = np.where(raw_cmc == 1) 100 | pos_max_idx = np.max(pos_idx) 101 | inp = cmc[pos_max_idx]/ (pos_max_idx + 1.0) 102 | all_INP.append(inp) 103 | 104 | cmc[cmc > 1] = 1 105 | 106 | all_cmc.append(cmc[:max_rank]) 107 | num_valid_q += 1. 108 | 109 | num_rel = raw_cmc.sum() 110 | tmp_cmc = raw_cmc.cumsum() 111 | tmp_cmc = [x / (i+1.) 
for i, x in enumerate(tmp_cmc)] 112 | tmp_cmc = np.asarray(tmp_cmc) * raw_cmc 113 | AP = tmp_cmc.sum() / num_rel 114 | all_AP.append(AP) 115 | 116 | assert num_valid_q > 0, "Error: all query identities do not appear in gallery" 117 | 118 | all_cmc = np.asarray(all_cmc).astype(np.float32) 119 | all_cmc = all_cmc.sum(0) / num_valid_q 120 | mAP = np.mean(all_AP) 121 | mINP = np.mean(all_INP) 122 | return all_cmc, mAP, mINP -------------------------------------------------------------------------------- /tools/logger.py: -------------------------------------------------------------------------------- 1 | class Logger: 2 | 3 | def __init__(self, log_file): 4 | self.log_file = log_file 5 | 6 | def __call__(self, input): 7 | input = str(input) 8 | with open(self.log_file, 'a') as f: 9 | f.writelines(input+'\n') 10 | print(input) -------------------------------------------------------------------------------- /tools/loss.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | 5 | class CrossEntropyLabelSmooth(nn.Module): 6 | def __init__(self, epsilon=0.1, use_gpu=True): 7 | super(CrossEntropyLabelSmooth, self).__init__() 8 | self.epsilon = epsilon 9 | self.use_gpu = use_gpu 10 | self.logsoftmax = nn.LogSoftmax(dim=1) 11 | 12 | def forward(self, inputs, targets): 13 | log_probs = self.logsoftmax(inputs) 14 | targets = targets.long() 15 | size = log_probs.size() 16 | targets = torch.zeros((size[0], size[1])).scatter_(1, targets.unsqueeze(1).data.cpu(), 1) 17 | 18 | if self.use_gpu: 19 | targets = targets.to(torch.device('cuda')) 20 | targets = (1 - self.epsilon) * targets + self.epsilon / size[1] 21 | loss = (-targets * log_probs).mean(0).sum() 22 | return loss 23 | 24 | class SupConLoss(nn.Module): 25 | def __init__(self, device): 26 | super(SupConLoss, self).__init__() 27 | self.device = device 28 | self.temperature = 1.0 29 | def forward(self, text_features, image_features, t_label, i_targets): 30 | batch_size = text_features.shape[0] 31 | batch_size_N = image_features.shape[0] 32 | mask = torch.eq(t_label.unsqueeze(1).expand(batch_size, batch_size_N), \ 33 | i_targets.unsqueeze(0).expand(batch_size,batch_size_N)).float().to(self.device) 34 | 35 | logits = torch.div(torch.matmul(text_features, image_features.T),self.temperature) 36 | # for numerical stability 37 | logits_max, _ = torch.max(logits, dim=1, keepdim=True) 38 | logits = logits - logits_max.detach() 39 | exp_logits = torch.exp(logits) 40 | log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True)) 41 | mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1) 42 | loss = - mean_log_prob_pos.mean() 43 | 44 | return loss 45 | 46 | def normalize(x, axis=-1): 47 | x = 1. 
* x / (torch.norm(x, 2, axis, keepdim=True).expand_as(x) + 1e-12) 48 | return x 49 | 50 | def pdist_torch(emb1, emb2): 51 | m, n = emb1.shape[0], emb2.shape[0] 52 | emb1_pow = torch.pow(emb1, 2).sum(dim=1, keepdim=True).expand(m, n) 53 | emb2_pow = torch.pow(emb2, 2).sum(dim=1, keepdim=True).expand(n, m).t() 54 | dist_mtx = emb1_pow + emb2_pow 55 | dist_mtx = dist_mtx.addmm_(1, -2, emb1, emb2.t()) 56 | dist_mtx = dist_mtx.clamp(min=1e-12).sqrt() 57 | return dist_mtx 58 | 59 | def softmax_weights(dist, mask): 60 | max_v = torch.max(dist * mask, dim=1, keepdim=True)[0] 61 | diff = dist - max_v 62 | Z = torch.sum(torch.exp(diff) * mask, dim=1, keepdim=True) + 1e-6 # avoid division by zero 63 | W = torch.exp(diff) * mask / Z 64 | return W 65 | 66 | class TripletLoss_WRT(nn.Module): 67 | 68 | def __init__(self): 69 | super(TripletLoss_WRT, self).__init__() 70 | self.ranking_loss = nn.SoftMarginLoss() 71 | 72 | def forward(self, inputs, targets, normalize_feature=False): 73 | if normalize_feature: 74 | inputs = normalize(inputs, axis=-1) 75 | dist_mat = pdist_torch(inputs, inputs) 76 | 77 | N = dist_mat.size(0) 78 | is_pos = targets.expand(N, N).eq(targets.expand(N, N).t()).float() 79 | is_neg = targets.expand(N, N).ne(targets.expand(N, N).t()).float() 80 | 81 | dist_ap = dist_mat * is_pos 82 | dist_an = dist_mat * is_neg 83 | 84 | weights_ap = softmax_weights(dist_ap, is_pos) 85 | weights_an = softmax_weights(-dist_an, is_neg) 86 | furthest_positive = torch.sum(dist_ap * weights_ap, dim=1) 87 | closest_negative = torch.sum(dist_an * weights_an, dim=1) 88 | 89 | y = furthest_positive.new().resize_as_(furthest_positive).fill_(1) 90 | loss = self.ranking_loss(closest_negative - furthest_positive, y) 91 | 92 | return loss 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | -------------------------------------------------------------------------------- /tools/meter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class CatMeter: 4 | 5 | def __init__(self): 6 | self.reset() 7 | 8 | def reset(self): 9 | self.val = None 10 | 11 | def update(self, val): 12 | if self.val is None: 13 | self.val = val 14 | else: 15 | self.val = torch.cat([self.val, val], dim=0) 16 | 17 | def get_val(self): 18 | return self.val 19 | 20 | def get_val_numpy(self): 21 | return self.val.data.cpu().numpy() 22 | 23 | class MultiItemAverageMeter: 24 | 25 | def __init__(self): 26 | self.content = {} 27 | 28 | def update(self, val): 29 | 30 | for key in list(val.keys()): 31 | value = val[key] 32 | if key not in list(self.content.keys()): 33 | self.content[key] = {'avg': value, 'sum': value, 'count': 1.0} 34 | else: 35 | self.content[key]['sum'] += value 36 | self.content[key]['count'] += 1.0 37 | self.content[key]['avg'] = self.content[key]['sum'] / self.content[key]['count'] 38 | 39 | def get_val(self): 40 | keys = list(self.content.keys()) 41 | values = [] 42 | for key in keys: 43 | try: 44 | values.append(self.content[key]['avg'].data.cpu().numpy()) 45 | except: 46 | values.append(self.content[key]['avg']) 47 | return keys, values 48 | 49 | def get_str(self): 50 | 51 | result = '' 52 | keys, values = self.get_val() 53 | 54 | for key, value in zip(keys, values): 55 | result += key 56 | result += ': ' 57 | result += str(value) 58 | result += '; ' 59 | 60 | return result 61 | -------------------------------------------------------------------------------- /tools/utils.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | 4 | def os_walk(folder_dir): 5 | for root, dirs, files in os.walk(folder_dir): 6 | files = sorted(files, reverse=True) 7 | dirs = sorted(dirs, reverse=True) 8 | return root, dirs, files 9 | 10 | def time_now(): 11 | return time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()) 12 | 13 | def make_dirs(dir): 14 | if not os.path.exists(dir): 15 | os.makedirs(dir) 16 | print('Successfully made dirs: {}'.format(dir)) 17 | else: 18 | print('Dirs already exist: {}'.format(dir)) 19 | 20 | --------------------------------------------------------------------------------
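
Usage note (not part of the repository): the fixed offsets in PromptLearner1/PromptLearner2 above, n_ctx + 1 and n_ctx + 1 + n_cls_ctx, follow from how CLIP tokenizes the template "A photo of a X X X X person.". The snippet below is a small illustrative check; it assumes the bundled network/clip package is importable from the project root.

import torch
from network.clip import clip

# The BPE tokenizer lower-cases the text, so the template becomes:
# [SOS, a, photo, of, a, x, x, x, x, person, ., EOS, pad, ...]  -> 77 ids in total
tokens = clip.tokenize("A photo of a X X X X person.")   # shape (1, 77)
print(tokens[0, :12])

# With n_ctx = 4 and n_cls_ctx = 4, the registered buffers therefore cover:
#   token_prefix = embedding[:, :5]   -> SOS + "a photo of a"
#   cls_ctx      = 4 learned vectors  -> replace the four "X" placeholders per identity
#   token_suffix = embedding[:, 9:]   -> "person . EOS" followed by padding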
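
Training-step sketch (not taken from core/train.py, whose schedule may differ): it shows how the forward modes of Model and the losses in tools/ might be wired together. The import paths, the batch composition (2 identities x 4 images per modality), the 288x144 image size, and the choice of which features feed which loss are illustrative assumptions.

import torch
from network.model import Model                      # assumed import path for model.py above
from tools.loss import SupConLoss, TripletLoss_WRT
from tools.MSEL import MSEL

num_classes, num_pos = 395, 4                        # assumed SYSU-MM01-style settings
model = Model(num_classes, img_h=288, img_w=144).cuda()

rgb = torch.randn(8, 3, 288, 144).cuda()             # visible images
ir = torch.randn(8, 3, 288, 144).cuda()              # infrared images
pids = torch.arange(2).repeat_interleave(4).cuda()   # 2 identities x 4 images per modality
targets = torch.cat([pids, pids], dim=0)             # joint forward stacks RGB first, then IR

# Prompt learning: contrast identity-conditioned text features against image features.
supcon = SupConLoss(torch.device('cuda'))
with torch.no_grad():
    img_rgb = model(x1=rgb, get_image=True)          # projected features from image_encoder1
    img_ir = model(x2=ir, get_image=True)            # projected features from image_encoder2
txt_rgb = model(label1=pids, get_text=True)          # PromptLearner1 -> TextEncoder
txt_ir = model(label2=pids, get_text=True)           # PromptLearner2 -> TextEncoder
loss_prompt = supcon(txt_rgb, img_rgb, pids, pids) + supcon(txt_ir, img_ir, pids, pids)

# Fused text features from both prompt learners via AttentionFusion.
txt_fused = model(label=pids, get_fusion_text=True)

# Joint two-modality forward: identity scores plus metric losses on the pooled features.
# cls_scores / cls_scores_proj would feed CrossEntropyLabelSmooth from tools/loss.py.
[features, features_proj], [cls_scores, cls_scores_proj] = model(x1=rgb, x2=ir)
tri = TripletLoss_WRT()
msel = MSEL(num_pos=num_pos)
loss_metric = tri(features_proj, targets) + msel(features_proj, targets)
(loss_prompt + loss_metric).backward()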
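
Evaluation sketch: eval_sysu in tools/eval_metrics.py expects a query-by-gallery distance matrix together with person and camera ids, and drops gallery samples from camera 2 for queries captured by camera 3, since those two cameras view the same location in SYSU-MM01. The numbers below are random stand-ins, purely to show the expected shapes and dtypes, not a real protocol run.

import numpy as np
from tools.eval_metrics import eval_sysu

num_q, num_g, dim = 10, 50, 2048
qf = np.random.randn(num_q, dim)                     # query (infrared) features
gf = np.random.randn(num_g, dim)                     # gallery (visible) features
distmat = np.linalg.norm(qf[:, None, :] - gf[None, :, :], axis=2)   # Euclidean distances

q_pids = np.random.randint(0, 5, num_q)              # person ids
g_pids = np.random.randint(0, 5, num_g)
q_camids = 3 * np.ones(num_q, dtype=np.int32)        # SYSU infrared cameras are 3 and 6
g_camids = np.random.choice([1, 2, 4, 5], num_g)     # SYSU visible cameras

cmc, mAP, mINP = eval_sysu(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=20)
print(cmc[0], mAP, mINP)                             # rank-1 accuracy, mean AP, mean inverse negative penalty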