├── Pytorch-PMT-VI-ReID
│   ├── config
│   │   ├── RegDB.yml
│   │   ├── SYSU.yml
│   │   └── config.py
│   ├── dataloader.py
│   ├── datamanager.py
│   ├── eval_metrics.py
│   ├── loss
│   │   ├── DCL.py
│   │   ├── MSEL.py
│   │   └── Triplet.py
│   ├── main.sh
│   ├── model
│   │   ├── make_model.py
│   │   └── vision_transformer.py
│   ├── optimizer.py
│   ├── process_sysu.py
│   ├── scheduler.py
│   ├── test.py
│   ├── train.py
│   ├── transforms.py
│   └── utils.py
└── README.md
/Pytorch-PMT-VI-ReID/config/RegDB.yml: -------------------------------------------------------------------------------- 1 | DATASET: 'regdb' 2 | 3 | START_EPOCH: 1 4 | MAX_EPOCH: 36 5 | BATCH_SIZE: 32 6 | NUM_POS: 4 7 | 8 | # PMT 9 | METHOD: 'PMT' # 'PMT' or 'base' 10 | PL_EPOCH: 6 11 | MSEL: 0.5 12 | DCL: 0.5 13 | MARGIN: 0.1 14 | 15 | 16 | 17 | --------------------------------------------------------------------------------
/Pytorch-PMT-VI-ReID/config/SYSU.yml: -------------------------------------------------------------------------------- 1 | DATASET: 'sysu' 2 | 3 | START_EPOCH: 1 4 | MAX_EPOCH: 24 5 | BATCH_SIZE: 32 6 | NUM_POS: 4 7 | 8 | 9 | METHOD: 'PMT' # 'PMT' or 'base' 10 | PL_EPOCH: 6 11 | MSEL: 0.5 12 | DCL: 0.5 13 | MARGIN: 0.1 14 | 15 | 16 | 17 | 18 | --------------------------------------------------------------------------------
/Pytorch-PMT-VI-ReID/config/config.py: -------------------------------------------------------------------------------- 1 | from yacs.config import CfgNode as CN 2 | cfg = CN() 3 | 4 | cfg.SEED = 0 5 | 6 | # dataset 7 | cfg.DATASET = 'sysu' # sysu or regdb 8 | cfg.DATA_PATH_SYSU = '../SYSU-MM01/' 9 | cfg.DATA_PATH_RegDB = '../RegDB/' 10 | cfg.PRETRAIN_PATH = '../jx_vit_base_p16_224-80ecf9dd.pth' 11 | 12 | cfg.START_EPOCH = 1 13 | cfg.MAX_EPOCH = 24 14 | 15 | cfg.H = 256 16 | cfg.W = 128 17 | cfg.BATCH_SIZE = 32 # num of images for each modality in a mini-batch 18 | cfg.NUM_POS = 4 19 | 20 | # PMT 21 | cfg.METHOD = 'PMT' 22 | cfg.PL_EPOCH = 6 # for PL strategy 23 | cfg.MSEL = 0.5 # weight for MSEL 24 | cfg.DCL = 0.5 # weight for DCL 25 | cfg.MARGIN = 0.1 # margin for triplet 26 | 27 | 28 | # model 29 | cfg.STRIDE_SIZE = [12,12] 30 | cfg.DROP_OUT = 0.03 31 | cfg.ATT_DROP_RATE = 0.0 32 | cfg.DROP_PATH = 0.1 33 | 34 | # optimizer 35 | cfg.OPTIMIZER_NAME = 'AdamW' # AdamW or SGD 36 | cfg.MOMENTUM = 0.9 # for SGD 37 | 38 | cfg.BASE_LR = 3e-4 39 | cfg.WEIGHT_DECAY = 1e-4 40 | cfg.WEIGHT_DECAY_BIAS = 1e-4 41 | cfg.BIAS_LR_FACTOR = 1 42 | 43 | cfg.LR_PRETRAIN = 0.5 44 | cfg.LR_MIN = 0.01 45 | cfg.LR_INIT = 0.01 46 | cfg.WARMUP_EPOCHS = 3 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | --------------------------------------------------------------------------------
/Pytorch-PMT-VI-ReID/dataloader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | import torch.utils.data as data 4 | from torch.utils.data.sampler import Sampler 5 | 6 | class SYSUData(data.Dataset): 7 | def __init__(self, data_dir, transform1=None, transform2=None, colorIndex=None, thermalIndex=None): 8 | # Load preprocessed training images and labels (numpy arrays produced by process_sysu.py) 9 | train_color_image = np.load(data_dir + 'train_rgb_resized_img.npy') 10 | self.train_color_label = np.load(data_dir + 'train_rgb_resized_label.npy') 11 | 12 | train_thermal_image = np.load(data_dir + 'train_ir_resized_img.npy') 13 | self.train_thermal_label = np.load(data_dir + 'train_ir_resized_label.npy') 14 | 15 | # RGB format 16 | self.train_color_image = train_color_image 17 | self.train_thermal_image = train_thermal_image 18 | self.transform1 = transform1 19 | self.transform2 = 
transform2 20 | self.cIndex = colorIndex 21 | self.tIndex = thermalIndex 22 | 23 | def __getitem__(self, index): 24 | 25 | img1, target1 = self.train_color_image[self.cIndex[index]], self.train_color_label[self.cIndex[index]] 26 | img2, target2 = self.train_thermal_image[self.tIndex[index]], self.train_thermal_label[self.tIndex[index]] 27 | 28 | img1 = self.transform1(img1) 29 | img2 = self.transform2(img2) 30 | 31 | return img1, img2, target1, target2 32 | 33 | def __len__(self): 34 | return len(self.train_color_label) 35 | 36 | 37 | class RegDBData(data.Dataset): 38 | def __init__(self, data_dir, trial, transform1=None,transform2 = None, colorIndex=None, thermalIndex=None): 39 | # Load training images (path) and labels 40 | train_color_list = data_dir + 'idx/train_visible_{}'.format(trial) + '.txt' 41 | train_thermal_list = data_dir + 'idx/train_thermal_{}'.format(trial) + '.txt' 42 | 43 | color_img_file, train_color_label = load_data(train_color_list) 44 | thermal_img_file, train_thermal_label = load_data(train_thermal_list) 45 | 46 | train_color_image = [] 47 | for i in range(len(color_img_file)): 48 | img = Image.open(data_dir + color_img_file[i]) 49 | img = img.resize((144, 288), Image.ANTIALIAS) 50 | pix_array = np.array(img) 51 | train_color_image.append(pix_array) 52 | train_color_image = np.array(train_color_image) 53 | 54 | train_thermal_image = [] 55 | for i in range(len(thermal_img_file)): 56 | img = Image.open(data_dir + thermal_img_file[i]) 57 | img = img.resize((144, 288), Image.ANTIALIAS) 58 | pix_array = np.array(img) 59 | train_thermal_image.append(pix_array) 60 | train_thermal_image = np.array(train_thermal_image) 61 | 62 | # RGB format 63 | self.train_color_image = train_color_image 64 | self.train_color_label = train_color_label 65 | 66 | # RGB format 67 | self.train_thermal_image = train_thermal_image 68 | self.train_thermal_label = train_thermal_label 69 | 70 | self.transform1 = transform1 71 | self.transform2 = transform2 72 | self.cIndex = colorIndex 73 | self.tIndex = thermalIndex 74 | 75 | def __getitem__(self, index): 76 | 77 | img1, target1 = self.train_color_image[self.cIndex[index]], self.train_color_label[self.cIndex[index]] 78 | img2, target2 = self.train_thermal_image[self.tIndex[index]], self.train_thermal_label[self.tIndex[index]] 79 | 80 | img1 = self.transform1(img1) 81 | img2 = self.transform2(img2) 82 | 83 | return img1, img2, target1, target2 84 | 85 | def __len__(self): 86 | return len(self.train_color_label) 87 | 88 | 89 | class TestData_RegDB(data.Dataset): 90 | def __init__(self, test_img_file, test_label, transform=None, img_size=(224, 224)): 91 | test_image = [] 92 | for i in range(len(test_img_file)): 93 | img = Image.open(test_img_file[i]) 94 | img = img.resize((img_size[0], img_size[1]), Image.ANTIALIAS) 95 | pix_array = np.array(img) 96 | test_image.append(pix_array) 97 | test_image = np.array(test_image) 98 | self.test_image = test_image 99 | self.test_label = test_label 100 | self.transform = transform 101 | 102 | def __getitem__(self, index): 103 | img1, target1 = self.test_image[index], self.test_label[index] 104 | img1 = self.transform(img1) 105 | 106 | return img1, target1 107 | 108 | def __len__(self): 109 | return len(self.test_image) 110 | 111 | class TestData(data.Dataset): 112 | def __init__(self, test_img_file, test_label, transform=None, img_size=(224, 224)): 113 | test_image = [] 114 | for i in range(len(test_img_file)): 115 | img = Image.open(test_img_file[i]) 116 | img = img.resize((img_size[0], img_size[1]), 
Image.ANTIALIAS) 117 | pix_array = np.array(img) 118 | test_image.append(pix_array) 119 | test_image = np.array(test_image) 120 | self.test_image = test_image 121 | self.test_label = test_label 122 | self.transform = transform 123 | 124 | def __getitem__(self, index): 125 | img1, target1 = self.test_image[index], self.test_label[index] 126 | img1 = self.transform(img1) 127 | return img1, target1 128 | 129 | def __len__(self): 130 | return len(self.test_image) 131 | 132 | 133 | def load_data(input_data_path): 134 | with open(input_data_path) as f: 135 | data_file_list = open(input_data_path, 'rt').read().splitlines() 136 | # Get full list of image and labels 137 | file_image = [s.split(' ')[0] for s in data_file_list] 138 | file_label = [int(s.split(' ')[1]) for s in data_file_list] 139 | 140 | return file_image, file_label 141 | 142 | 143 | def GenIdx(train_color_label, train_thermal_label): 144 | color_pos = [] 145 | unique_label_color = np.unique(train_color_label) 146 | for i in range(len(unique_label_color)): 147 | tmp_pos = [k for k, v in enumerate(train_color_label) if v == unique_label_color[i]] 148 | color_pos.append(tmp_pos) 149 | 150 | thermal_pos = [] 151 | unique_label_thermal = np.unique(train_thermal_label) 152 | for i in range(len(unique_label_thermal)): 153 | tmp_pos = [k for k, v in enumerate(train_thermal_label) if v == unique_label_thermal[i]] 154 | thermal_pos.append(tmp_pos) 155 | 156 | return color_pos, thermal_pos 157 | 158 | 159 | class IdentitySampler(Sampler): 160 | """Sample person identities evenly in each batch. 161 | Args: 162 | train_color_label, train_thermal_label: labels of two modalities 163 | color_pos, thermal_pos: positions of each identity 164 | batchSize: batch size 165 | """ 166 | 167 | def __init__(self, train_color_label, train_thermal_label, color_pos, thermal_pos, batchSize, per_img): 168 | uni_label = np.unique(train_color_label) 169 | self.n_classes = len(uni_label) 170 | 171 | sample_color = np.arange(batchSize) 172 | sample_thermal = np.arange(batchSize) 173 | N = np.maximum(len(train_color_label), len(train_thermal_label)) 174 | 175 | # per_img = 4 176 | per_id = batchSize / per_img 177 | for j in range(N // batchSize + 1): 178 | batch_idx = np.random.choice(uni_label, int(per_id), replace=False) 179 | 180 | for s, i in enumerate(range(0, batchSize, per_img)): 181 | sample_color[i:i + per_img] = np.random.choice(color_pos[batch_idx[s]], per_img, replace=False) 182 | sample_thermal[i:i + per_img] = np.random.choice(thermal_pos[batch_idx[s]], per_img, replace=False) 183 | 184 | if j == 0: 185 | index1 = sample_color 186 | index2 = sample_thermal 187 | else: 188 | index1 = np.hstack((index1, sample_color)) 189 | index2 = np.hstack((index2, sample_thermal)) 190 | 191 | self.index1 = index1 192 | self.index2 = index2 193 | self.N = N 194 | 195 | def __iter__(self): 196 | return iter(np.arange(len(self.index1))) 197 | 198 | def __len__(self): 199 | return self.N 200 | -------------------------------------------------------------------------------- /Pytorch-PMT-VI-ReID/datamanager.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os 3 | import numpy as np 4 | import random 5 | 6 | def process_query_sysu(data_path, mode='all', relabel=False): 7 | 8 | if mode== 'all': 9 | ir_cameras = ['cam3','cam6'] 10 | elif mode =='indoor': 11 | ir_cameras = ['cam3','cam6'] 12 | 13 | file_path = os.path.join(data_path, 'exp/test_id.txt') 14 | files_rgb = [] 15 
| files_ir = [] 16 | 17 | with open(file_path, 'r') as file: 18 | ids = file.read().splitlines() 19 | ids = [int(y) for y in ids[0].split(',')] 20 | ids = ["%04d" % x for x in ids] 21 | 22 | for id in sorted(ids): 23 | for cam in ir_cameras: 24 | img_dir = os.path.join(data_path, cam, id) 25 | if os.path.isdir(img_dir): 26 | new_files = sorted([img_dir + '/' + i for i in os.listdir(img_dir)]) 27 | files_ir.extend(new_files) 28 | query_img = [] 29 | query_id = [] 30 | query_cam = [] 31 | for img_path in files_ir: 32 | camid, pid = int(img_path[-15]), int(img_path[-13:-9]) 33 | query_img.append(img_path) 34 | query_id.append(pid) 35 | query_cam.append(camid) 36 | 37 | return query_img, np.array(query_id), np.array(query_cam) 38 | 39 | 40 | def process_gallery_sysu(data_path, mode='all', trial=0, relabel=False, gall_mode='single'): 41 | 42 | random.seed(trial) 43 | 44 | if mode == 'all': 45 | rgb_cameras = ['cam1', 'cam2', 'cam4', 'cam5'] 46 | elif mode == 'indoor': 47 | rgb_cameras = ['cam1', 'cam2'] 48 | 49 | file_path = os.path.join(data_path, 'exp/test_id.txt') 50 | files_rgb = [] 51 | with open(file_path, 'r') as file: 52 | ids = file.read().splitlines() 53 | ids = [int(y) for y in ids[0].split(',')] 54 | ids = ["%04d" % x for x in ids] 55 | 56 | for id in sorted(ids): 57 | for cam in rgb_cameras: 58 | img_dir = os.path.join(data_path, cam, id) 59 | if os.path.isdir(img_dir): 60 | new_files = sorted([img_dir + '/' + i for i in os.listdir(img_dir)]) 61 | if gall_mode == 'single': 62 | files_rgb.append(random.choice(new_files)) 63 | if gall_mode == 'multi': 64 | files_rgb.append(np.random.choice(new_files, 10, replace=False)) 65 | gall_img = [] 66 | gall_id = [] 67 | gall_cam = [] 68 | 69 | for img_path in files_rgb: 70 | if gall_mode == 'single': 71 | camid, pid = int(img_path[-15]), int(img_path[-13:-9]) 72 | gall_img.append(img_path) 73 | gall_id.append(pid) 74 | gall_cam.append(camid) 75 | 76 | if gall_mode == 'multi': 77 | for i in img_path: 78 | camid, pid = int(i[-15]), int(i[-13:-9]) 79 | gall_img.append(i) 80 | gall_id.append(pid) 81 | gall_cam.append(camid) 82 | 83 | return gall_img, np.array(gall_id), np.array(gall_cam) 84 | 85 | 86 | def process_test_regdb(img_dir, trial=1, modal='visible'): 87 | if modal == 'visible': 88 | input_data_path = img_dir + 'idx/test_visible_{}'.format(trial) + '.txt' 89 | elif modal == 'thermal': 90 | input_data_path = img_dir + 'idx/test_thermal_{}'.format(trial) + '.txt' 91 | 92 | with open(input_data_path) as f: 93 | data_file_list = open(input_data_path, 'rt').read().splitlines() 94 | # Get full list of image and labels 95 | file_image = [img_dir + '/' + s.split(' ')[0] for s in data_file_list] 96 | file_label = [int(s.split(' ')[1]) for s in data_file_list] 97 | 98 | return file_image, np.array(file_label) -------------------------------------------------------------------------------- /Pytorch-PMT-VI-ReID/eval_metrics.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import numpy as np 3 | 4 | def eval_sysu(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=20): 5 | """Evaluation with sysu metric 6 | Key: for each query identity, its gallery images from the same camera view are discarded. 
"Following the original setting in ite dataset" 7 | """ 8 | num_q, num_g = distmat.shape 9 | 10 | if num_g < max_rank: 11 | max_rank = num_g 12 | print("Note: number of gallery samples is quite small, got {}".format(num_g)) 13 | 14 | indices = np.argsort(distmat, axis=1) 15 | pred_label = g_pids[indices] 16 | matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32) 17 | # compute cmc curve for each query 18 | new_all_cmc = [] 19 | all_cmc = [] 20 | all_AP = [] 21 | all_INP = [] 22 | num_valid_q = 0. # number of valid query 23 | for q_idx in range(num_q): 24 | # get query pid and camid 25 | q_pid = q_pids[q_idx] 26 | q_camid = q_camids[q_idx] 27 | 28 | # remove gallery samples that have the same pid and camid with query 29 | order = indices[q_idx] 30 | remove = (q_camid == 3) & (g_camids[order] == 2) 31 | keep = np.invert(remove) 32 | 33 | # compute cmc curve 34 | # the cmc calculation is different from standard protocol 35 | # we follow the protocol of the author's released code 36 | new_cmc = pred_label[q_idx][keep] 37 | new_index = np.unique(new_cmc, return_index=True)[1] 38 | new_cmc = [new_cmc[index] for index in sorted(new_index)] 39 | 40 | new_match = (new_cmc == q_pid).astype(np.int32) 41 | new_cmc = new_match.cumsum() 42 | new_all_cmc.append(new_cmc[:max_rank]) 43 | 44 | orig_cmc = matches[q_idx][keep] # binary vector, positions with value 1 are correct matches 45 | if not np.any(orig_cmc): 46 | # this condition is true when query identity does not appear in gallery 47 | continue 48 | 49 | cmc = orig_cmc.cumsum() 50 | 51 | # compute mINP 52 | # refernece Deep Learning for Person Re-identification: A Survey and Outlook 53 | pos_idx = np.where(orig_cmc == 1) 54 | pos_max_idx = np.max(pos_idx) 55 | inp = cmc[pos_max_idx] / (pos_max_idx + 1.0) 56 | all_INP.append(inp) 57 | 58 | cmc[cmc > 1] = 1 59 | 60 | all_cmc.append(cmc[:max_rank]) 61 | num_valid_q += 1. 62 | 63 | # compute average precision 64 | # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision 65 | num_rel = orig_cmc.sum() 66 | tmp_cmc = orig_cmc.cumsum() 67 | tmp_cmc = [x / (i + 1.) for i, x in enumerate(tmp_cmc)] 68 | tmp_cmc = np.asarray(tmp_cmc) * orig_cmc 69 | AP = tmp_cmc.sum() / num_rel 70 | all_AP.append(AP) 71 | 72 | assert num_valid_q > 0, "Error: all query identities do not appear in gallery" 73 | 74 | all_cmc = np.asarray(all_cmc).astype(np.float32) 75 | all_cmc = all_cmc.sum(0) / num_valid_q # standard CMC 76 | 77 | new_all_cmc = np.asarray(new_all_cmc).astype(np.float32) 78 | new_all_cmc = new_all_cmc.sum(0) / num_valid_q 79 | mAP = np.mean(all_AP) 80 | mINP = np.mean(all_INP) 81 | return new_all_cmc, mAP, mINP 82 | 83 | 84 | def eval_regdb(distmat, q_pids, g_pids, max_rank=20): 85 | num_q, num_g = distmat.shape 86 | if num_g < max_rank: 87 | max_rank = num_g 88 | print("Note: number of gallery samples is quite small, got {}".format(num_g)) 89 | indices = np.argsort(distmat, axis=1) 90 | matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32) 91 | 92 | # compute cmc curve for each query 93 | all_cmc = [] 94 | all_AP = [] 95 | all_INP = [] 96 | num_valid_q = 0. 
84 | def eval_regdb(distmat, q_pids, g_pids, max_rank=20): 85 | num_q, num_g = distmat.shape 86 | if num_g < max_rank: 87 | max_rank = num_g 88 | print("Note: number of gallery samples is quite small, got {}".format(num_g)) 89 | indices = np.argsort(distmat, axis=1) 90 | matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32) 91 | 92 | # compute cmc curve for each query 93 | all_cmc = [] 94 | all_AP = [] 95 | all_INP = [] 96 | num_valid_q = 0. # number of valid queries 97 | 98 | # only two cameras 99 | q_camids = np.ones(num_q).astype(np.int32) 100 | g_camids = 2 * np.ones(num_g).astype(np.int32) 101 | 102 | for q_idx in range(num_q): 103 | # get query pid and camid 104 | q_pid = q_pids[q_idx] 105 | q_camid = q_camids[q_idx] 106 | 107 | # remove gallery samples that have the same pid and camid with the query 108 | order = indices[q_idx] 109 | remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid) 110 | keep = np.invert(remove) 111 | 112 | # compute cmc curve 113 | raw_cmc = matches[q_idx][keep] # binary vector, positions with value 1 are correct matches 114 | if not np.any(raw_cmc): 115 | # this condition is true when query identity does not appear in gallery 116 | continue 117 | 118 | cmc = raw_cmc.cumsum() 119 | 120 | # compute mINP 121 | # reference: Deep Learning for Person Re-identification: A Survey and Outlook 122 | pos_idx = np.where(raw_cmc == 1) 123 | pos_max_idx = np.max(pos_idx) 124 | inp = cmc[pos_max_idx] / (pos_max_idx + 1.0) 125 | all_INP.append(inp) 126 | 127 | cmc[cmc > 1] = 1 128 | 129 | all_cmc.append(cmc[:max_rank]) 130 | num_valid_q += 1. 131 | 132 | # compute average precision 133 | # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision 134 | num_rel = raw_cmc.sum() 135 | tmp_cmc = raw_cmc.cumsum() 136 | tmp_cmc = [x / (i + 1.) for i, x in enumerate(tmp_cmc)] 137 | tmp_cmc = np.asarray(tmp_cmc) * raw_cmc 138 | AP = tmp_cmc.sum() / num_rel 139 | all_AP.append(AP) 140 | 141 | assert num_valid_q > 0, "Error: all query identities do not appear in gallery" 142 | 143 | all_cmc = np.asarray(all_cmc).astype(np.float32) 144 | all_cmc = all_cmc.sum(0) / num_valid_q 145 | mAP = np.mean(all_AP) 146 | mINP = np.mean(all_INP) 147 | return all_cmc, mAP, mINP --------------------------------------------------------------------------------
/Pytorch-PMT-VI-ReID/loss/DCL.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | def normalize(x, axis=-1): 6 | """Normalize x to unit length along the specified dimension. 7 | Args: 8 | x: pytorch Tensor 9 | Returns: 10 | x: pytorch Tensor, same shape as input 11 | """ 12 | x = 1. * x / (torch.norm(x, 2, axis, keepdim=True).expand_as(x) + 1e-12)
13 | return x 14 | 15 | def pdist_torch(emb1, emb2): 16 | ''' 17 | compute the euclidean distance matrix between embeddings1 and embeddings2 18 | using gpu 19 | ''' 20 | m, n = emb1.shape[0], emb2.shape[0] 21 | emb1_pow = torch.pow(emb1, 2).sum(dim=1, keepdim=True).expand(m, n) 22 | emb2_pow = torch.pow(emb2, 2).sum(dim=1, keepdim=True).expand(n, m).t() 23 | dist_mtx = emb1_pow + emb2_pow 24 | dist_mtx = dist_mtx.addmm_(emb1, emb2.t(), beta=1, alpha=-2) # dist_mtx.addmm_(1, -2, emb1, emb2.t()) is deprecated 25 | dist_mtx = dist_mtx.clamp(min=1e-12).sqrt() 26 | return dist_mtx 27 | 28 | 29 | class DCL(nn.Module): 30 | def __init__(self, num_pos=4, feat_norm='no'): 31 | super(DCL, self).__init__() 32 | self.num_pos = num_pos 33 | self.feat_norm = feat_norm 34 | 35 | def forward(self, inputs, targets): 36 | if self.feat_norm == 'yes': 37 | inputs = F.normalize(inputs, p=2, dim=-1) 38 | 39 | N = inputs.size(0) 40 | id_num = N // 2 // self.num_pos 41 | 42 | is_neg = targets.expand(N, N).ne(targets.expand(N, N).t()) 43 | is_neg_c2i = is_neg[::self.num_pos, :].chunk(2, 0)[0] # mask [id_num, N] 44 | 45 | centers = [] 46 | for i in range(id_num): 47 | centers.append(inputs[targets == targets[i * self.num_pos]].mean(0)) 48 | centers = torch.stack(centers) 49 | 50 | dist_mat = pdist_torch(centers, inputs) # c-i 51 | 52 | an = dist_mat * is_neg_c2i 53 | an = an[an > 1e-6].view(id_num, -1) 54 | 55 | d_neg = torch.mean(an, dim=1, keepdim=True) 56 | mask_an = (an - d_neg).expand(id_num, N - 2 * self.num_pos).lt(0) # mask 57 | an = an * mask_an 58 | 59 | list_an = [] 60 | for i in range(id_num): 61 | list_an.append(torch.mean(an[i][an[i] > 1e-6])) 62 | an_mean = sum(list_an) / len(list_an) 63 | 64 | ap = dist_mat * ~is_neg_c2i 65 | ap_mean = torch.mean(ap[ap > 1e-6]) 66 | 67 | loss = ap_mean / an_mean 68 | 69 | return loss --------------------------------------------------------------------------------
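How these loss modules are combined lives in train.py, which this dump does not include, so the wiring below is an assumption based on the config weights (cfg.DCL, cfg.MSEL) and cfg.PL_EPOCH; only the constructor and forward signatures are taken from the files themselves:

import torch.nn as nn
from config.config import cfg
from loss.DCL import DCL
from loss.MSEL import MSEL
from loss.Triplet import TripletLoss

criterion_id = nn.CrossEntropyLoss()
criterion_tri = TripletLoss(margin=cfg.MARGIN)
criterion_dcl = DCL(num_pos=cfg.NUM_POS)
criterion_msel = MSEL(num_pos=cfg.NUM_POS)

def total_loss(score, feat, labels, epoch):
    # batch is ordered [RGB; IR] with cfg.NUM_POS images per identity and modality,
    # which is the layout DCL and MSEL assume when they chunk the batch in half
    loss = criterion_id(score, labels) + criterion_tri(feat, feat, labels)
    if epoch > cfg.PL_EPOCH:  # assumed: the weighted terms join after the PL stage
        loss = loss + cfg.DCL * criterion_dcl(feat, labels) + cfg.MSEL * criterion_msel(feat, labels)
    return loss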
/Pytorch-PMT-VI-ReID/loss/MSEL.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | 6 | def pdist_torch(emb1, emb2): 7 | ''' 8 | compute the euclidean distance matrix between embeddings1 and embeddings2 9 | using gpu 10 | ''' 11 | m, n = emb1.shape[0], emb2.shape[0] 12 | emb1_pow = torch.pow(emb1, 2).sum(dim=1, keepdim=True).expand(m, n) 13 | emb2_pow = torch.pow(emb2, 2).sum(dim=1, keepdim=True).expand(n, m).t() 14 | dist_mtx = emb1_pow + emb2_pow 15 | dist_mtx = dist_mtx.addmm_(emb1, emb2.t(), beta=1, alpha=-2) # dist_mtx.addmm_(1, -2, emb1, emb2.t()) is deprecated 16 | dist_mtx = dist_mtx.clamp(min=1e-12).sqrt() 17 | return dist_mtx 18 | 19 | class MSEL(nn.Module): 20 | def __init__(self, num_pos, feat_norm='no'): 21 | super(MSEL, self).__init__() 22 | self.num_pos = num_pos 23 | self.feat_norm = feat_norm 24 | 25 | def forward(self, inputs, targets): 26 | if self.feat_norm == 'yes': 27 | inputs = F.normalize(inputs, p=2, dim=-1) 28 | 29 | target, _ = targets.chunk(2, 0) 30 | N = target.size(0) 31 | 32 | dist_mat = pdist_torch(inputs, inputs) 33 | 34 | dist_intra_rgb = dist_mat[0 : N, 0 : N] 35 | dist_cross_rgb = dist_mat[0 : N, N : 2*N] 36 | dist_intra_ir = dist_mat[N : 2*N, N : 2*N] 37 | dist_cross_ir = dist_mat[N : 2*N, 0 : N] 38 | 39 | # shape [N, N] 40 | is_pos = target.expand(N, N).eq(target.expand(N, N).t()) 41 | 42 | dist_intra_rgb = is_pos * dist_intra_rgb 43 | intra_rgb, _ = dist_intra_rgb.topk(self.num_pos - 1, dim=1, largest=True, sorted=False) # remove itself 44 | intra_mean_rgb = torch.mean(intra_rgb, dim=1) 45 | 46 | dist_intra_ir = is_pos * dist_intra_ir 47 | intra_ir, _ = dist_intra_ir.topk(self.num_pos - 1, dim=1, largest=True, sorted=False) 48 | intra_mean_ir = torch.mean(intra_ir, dim=1) 49 | 50 | dist_cross_rgb = dist_cross_rgb[is_pos].contiguous().view(N, -1) # [N, num_pos] 51 | cross_mean_rgb = torch.mean(dist_cross_rgb, dim=1) 52 | 53 | dist_cross_ir = dist_cross_ir[is_pos].contiguous().view(N, -1) # [N, num_pos] 54 | cross_mean_ir = torch.mean(dist_cross_ir, dim=1) 55 | 56 | loss = (torch.mean(torch.pow(cross_mean_rgb - intra_mean_rgb, 2)) + 57 | torch.mean(torch.pow(cross_mean_ir - intra_mean_ir, 2))) / 2 58 | 59 | return loss 60 | 61 | 62 | class MSEL_Cos(nn.Module): # for features after bn 63 | def __init__(self, num_pos): 64 | super(MSEL_Cos, self).__init__() 65 | self.num_pos = num_pos 66 | 67 | def forward(self, inputs, targets): 68 | 69 | inputs = nn.functional.normalize(inputs, p=2, dim=1) 70 | 71 | target, _ = targets.chunk(2, 0) 72 | N = target.size(0) 73 | 74 | dist_mat = 1 - torch.matmul(inputs, torch.t(inputs)) 75 | 76 | dist_intra_rgb = dist_mat[0: N, 0: N] 77 | dist_cross_rgb = dist_mat[0: N, N: 2*N] 78 | dist_intra_ir = dist_mat[N: 2*N, N: 2*N] 79 | dist_cross_ir = dist_mat[N: 2*N, 0: N] 80 | 81 | # shape [N, N] 82 | is_pos = target.expand(N, N).eq(target.expand(N, N).t()) 83 | 84 | dist_intra_rgb = is_pos * dist_intra_rgb 85 | intra_rgb, _ = dist_intra_rgb.topk(self.num_pos - 1, dim=1, largest=True, sorted=False) # remove itself 86 | intra_mean_rgb = torch.mean(intra_rgb, dim=1) 87 | 88 | dist_intra_ir = is_pos * dist_intra_ir 89 | intra_ir, _ = dist_intra_ir.topk(self.num_pos - 1, dim=1, largest=True, sorted=False) 90 | intra_mean_ir = torch.mean(intra_ir, dim=1) 91 | 92 | dist_cross_rgb = dist_cross_rgb[is_pos].contiguous().view(N, -1) # [N, num_pos] 93 | cross_mean_rgb = torch.mean(dist_cross_rgb, dim=1) 94 | 95 | dist_cross_ir = dist_cross_ir[is_pos].contiguous().view(N, -1) # [N, num_pos] 96 | cross_mean_ir = torch.mean(dist_cross_ir, dim=1) 97 | 98 | loss = (torch.mean(torch.pow(cross_mean_rgb - intra_mean_rgb, 2)) + 99 | torch.mean(torch.pow(cross_mean_ir - intra_mean_ir, 2))) / 2 100 | 101 | return loss 102 | 103 | 104 | class MSEL_Feat(nn.Module): # compute MSEL loss by the distance between sample and center 105 | def __init__(self, num_pos): 106 | super(MSEL_Feat, self).__init__() 107 | self.num_pos = num_pos 108 | 109 | def forward(self, input1, input2): 110 | N = input1.size(0) 111 | id_num = N // self.num_pos 112 | 113 | feats_rgb = input1.chunk(id_num, 0) 114 | feats_ir = input2.chunk(id_num, 0) 115 | 116 | loss_list = [] 117 | for i in range(id_num): 118 | cross_center_rgb = torch.mean(feats_rgb[i], dim=0) # cross center 119 | cross_center_ir = torch.mean(feats_ir[i], dim=0) 120 | 121 | for j in range(self.num_pos): 122 | 123 | feat_rgb = feats_rgb[i][j] 124 | feat_ir = feats_ir[i][j] 125 | 126 | intra_feats_rgb = torch.cat((feats_rgb[i][0:j], feats_rgb[i][j+1:]), dim=0) # intra center 127 | intra_feats_ir = torch.cat((feats_ir[i][0:j], feats_ir[i][j+1:]), dim=0) 128 | 129 | intra_center_rgb = torch.mean(intra_feats_rgb, dim=0) 130 | intra_center_ir = torch.mean(intra_feats_ir, dim=0) 131 | 132 | dist_intra_rgb = pdist_torch(feat_rgb.view(1, -1), intra_center_rgb.view(1, -1)) 133 | dist_intra_ir = pdist_torch(feat_ir.view(1, -1), intra_center_ir.view(1, -1)) 134 | 135 | dist_cross_rgb = pdist_torch(feat_rgb.view(1, -1), cross_center_ir.view(1, -1)) 136 | dist_cross_ir = pdist_torch(feat_ir.view(1, -1), cross_center_rgb.view(1, -1)) 137 | 138 | 
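                # accumulate the squared gap between each sample's cross-modality center
                # distance and its leave-one-out intra-modality center distance, both directions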
loss_list.append(torch.pow(dist_cross_rgb - dist_intra_rgb, 2) + torch.pow(dist_cross_ir - dist_intra_ir, 2)) 139 | 140 | loss = sum(loss_list) / N / 2 141 | 142 | return loss 143 | -------------------------------------------------------------------------------- /Pytorch-PMT-VI-ReID/loss/Triplet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | def euclidean_dist(x, y, eps=1e-12): 6 | """ 7 | Args: 8 | x: pytorch Tensor, with shape [m, d] 9 | y: pytorch Tensor, with shape [n, d] 10 | Returns: 11 | dist: pytorch Tensor, with shape [m, n] 12 | """ 13 | m, n = x.size(0), y.size(0) 14 | xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n) 15 | yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t() 16 | dist = xx + yy 17 | dist.addmm_(x, y.t(), beta=1, alpha=-2) #dist.addmm_(1, -2, x, y.t()) 18 | dist = dist.clamp(min=eps).sqrt() 19 | 20 | return dist 21 | 22 | def hard_example_mining(dist_mat, target): 23 | """For each anchor, find the hardest positive and negative sample. 24 | Args: 25 | dist_mat: pytorch Tensor, pair wise distance between samples, shape [N, N] 26 | target: pytorch LongTensor, with shape [N] 27 | return_inds: whether to return the indices. Save time if `False`(?) 28 | Returns: 29 | dist_ap: pytorch Tensor, distance(anchor, positive); shape [N] 30 | dist_an: pytorch Tensor, distance(anchor, negative); shape [N] 31 | p_inds: pytorch LongTensor, with shape [N]; 32 | indices of selected hard positive samples; 0 <= p_inds[i] <= N - 1 33 | n_inds: pytorch LongTensor, with shape [N]; 34 | indices of selected hard negative samples; 0 <= n_inds[i] <= N - 1 35 | NOTE: Only consider the case in which all target have same num of samples, 36 | thus we can cope with all anchors in parallel. 37 | """ 38 | assert len(dist_mat.size()) == 2 39 | assert dist_mat.size(0) == dist_mat.size(1) 40 | N = dist_mat.size(0) 41 | 42 | # shape [N, N] 43 | is_pos = target.expand(N, N).eq(target.expand(N, N).t()) 44 | is_neg = target.expand(N, N).ne(target.expand(N, N).t()) 45 | 46 | dist_ap, relative_p_inds = torch.max( 47 | dist_mat[is_pos].contiguous().view(N, -1), 1, keepdim=True) 48 | 49 | dist_an, relative_n_inds = torch.min( 50 | dist_mat[is_neg].contiguous().view(N, -1), 1, keepdim=True) 51 | 52 | dist_ap = dist_ap.squeeze(1) 53 | dist_an = dist_an.squeeze(1) 54 | 55 | return dist_ap, dist_an 56 | 57 | 58 | class TripletLoss(nn.Module): 59 | """Modified from Tong Xiao's open-reid (https://github.com/Cysu/open-reid). 
60 | Related Triplet Loss theory can be found in paper 'In Defense of the Triplet 61 | Loss for Person Re-Identification'.""" 62 | def __init__(self, margin, feat_norm='yes'): 63 | super(TripletLoss, self).__init__() 64 | self.margin = margin 65 | self.feat_norm = feat_norm 66 | if margin >= 0: 67 | self.ranking_loss = nn.MarginRankingLoss(margin=margin) 68 | else: 69 | self.ranking_loss = nn.SoftMarginLoss() 70 | 71 | def forward(self, global_feat1, global_feat2, target): 72 | if self.feat_norm == 'yes': 73 | global_feat1 = F.normalize(global_feat1, p=2, dim=-1) 74 | global_feat2 = F.normalize(global_feat2, p=2, dim=-1) 75 | 76 | dist_mat = euclidean_dist(global_feat1, global_feat2) 77 | dist_ap, dist_an = hard_example_mining(dist_mat, target) 78 | 79 | y = dist_an.new().resize_as_(dist_an).fill_(1) 80 | if self.margin >= 0: 81 | loss = self.ranking_loss(dist_an, dist_ap, y) 82 | else: 83 | loss = self.ranking_loss(dist_an - dist_ap, y) 84 | 85 | return loss 86 | 87 | 88 | 89 | -------------------------------------------------------------------------------- /Pytorch-PMT-VI-ReID/main.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | for trial in 1 2 3 4 5 6 7 8 9 10 4 | do 5 | python train.py --trial $trial --config_file config/RegDB.yml 6 | done 7 | echo 'Done!' -------------------------------------------------------------------------------- /Pytorch-PMT-VI-ReID/model/make_model.py: -------------------------------------------------------------------------------- 1 | from model.vision_transformer import ViT 2 | import torch 3 | import torch.nn as nn 4 | 5 | # L2 norm 6 | class Normalize(nn.Module): 7 | def __init__(self, power=2): 8 | super(Normalize, self).__init__() 9 | self.power = power 10 | 11 | def forward(self, x): 12 | norm = x.pow(self.power).sum(1, keepdim=True).pow(1. 
/ self.power) 13 | out = x.div(norm) 14 | return out 15 | 16 | def weights_init_kaiming(m): 17 | classname = m.__class__.__name__ 18 | if classname.find('Linear') != -1: 19 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out') 20 | nn.init.constant_(m.bias, 0.0) 21 | 22 | elif classname.find('Conv') != -1: 23 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in') 24 | if m.bias is not None: 25 | nn.init.constant_(m.bias, 0.0) 26 | elif classname.find('BatchNorm') != -1: 27 | if m.affine: 28 | nn.init.constant_(m.weight, 1.0) 29 | nn.init.constant_(m.bias, 0.0) 30 | 31 | 32 | def weights_init_classifier(m): 33 | classname = m.__class__.__name__ 34 | if classname.find('Linear') != -1: 35 | nn.init.normal_(m.weight, std=0.001) 36 | if m.bias: 37 | nn.init.constant_(m.bias, 0.0) 38 | 39 | class build_vision_transformer(nn.Module): 40 | def __init__(self, num_classes, cfg): 41 | super(build_vision_transformer, self).__init__() 42 | self.in_planes = 768 43 | 44 | self.base = ViT(img_size=[cfg.H,cfg.W], 45 | stride_size=cfg.STRIDE_SIZE, 46 | drop_path_rate=cfg.DROP_PATH, 47 | drop_rate=cfg.DROP_OUT, 48 | attn_drop_rate=cfg.ATT_DROP_RATE) 49 | 50 | self.base.load_param(cfg.PRETRAIN_PATH) 51 | 52 | print('Loading pretrained ImageNet model......from {}'.format(cfg.PRETRAIN_PATH)) 53 | 54 | self.num_classes = num_classes 55 | 56 | self.classifier = nn.Linear(self.in_planes, self.num_classes, bias=False) 57 | self.classifier.apply(weights_init_classifier) 58 | 59 | self.bottleneck = nn.BatchNorm1d(self.in_planes) 60 | self.bottleneck.bias.requires_grad_(False) 61 | self.bottleneck.apply(weights_init_kaiming) 62 | 63 | self.l2norm = Normalize(2) 64 | 65 | 66 | def forward(self, x): 67 | features = self.base(x) 68 | feat = self.bottleneck(features) 69 | 70 | if self.training: 71 | cls_score = self.classifier(feat) 72 | 73 | return cls_score, features 74 | 75 | else: 76 | return self.l2norm(feat) 77 | 78 | 79 | def load_param(self, trained_path): 80 | param_dict = torch.load(trained_path) 81 | for i in param_dict: 82 | self.state_dict()[i.replace('module.', '')].copy_(param_dict[i]) 83 | print('Loading pretrained model from {}'.format(trained_path)) 84 | 85 | def load_param_finetune(self, model_path): 86 | param_dict = torch.load(model_path) 87 | for i in param_dict: 88 | self.state_dict()[i].copy_(param_dict[i]) 89 | print('Loading pretrained model for finetuning from {}'.format(model_path)) -------------------------------------------------------------------------------- /Pytorch-PMT-VI-ReID/model/vision_transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | import torch.nn.functional as F 5 | from itertools import repeat 6 | import collections.abc as container_abcs 7 | 8 | def _ntuple(n): 9 | def parse(x): 10 | if isinstance(x, container_abcs.Iterable): 11 | return x 12 | return tuple(repeat(x, n)) 13 | return parse 14 | 15 | to_2tuple = _ntuple(2) 16 | 17 | 18 | def weights_init_classifier(m): 19 | classname = m.__class__.__name__ 20 | if classname.find('Linear') != -1: 21 | nn.init.normal_(m.weight, std=0.001) 22 | if m.bias: 23 | nn.init.constant_(m.bias, 0.0) 24 | 25 | def drop_path(x, drop_prob: float = 0., training: bool = False): 26 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
27 | This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, 28 | the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... 29 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for 30 | changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 31 | 'survival rate' as the argument. 32 | """ 33 | if drop_prob == 0. or not training: 34 | return x 35 | keep_prob = 1 - drop_prob 36 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 37 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) 38 | random_tensor.floor_() # binarize 39 | output = x.div(keep_prob) * random_tensor 40 | return output 41 | 42 | class DropPath(nn.Module): 43 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 44 | """ 45 | def __init__(self, drop_prob=None): 46 | super(DropPath, self).__init__() 47 | self.drop_prob = drop_prob 48 | 49 | def forward(self, x): 50 | return drop_path(x, self.drop_prob, self.training) 51 | 52 | def resize_pos_embed(posemb, posemb_new, hight, width): 53 | # Rescale the grid of position embeddings when loading from state_dict. Adapted from 54 | # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 55 | ntok_new = posemb_new.shape[1] 56 | 57 | posemb_token, posemb_grid = posemb[:, :1], posemb[0, 1:] 58 | ntok_new -= 1 59 | 60 | gs_old = int(math.sqrt(len(posemb_grid))) 61 | print('Resized position embedding from size:{} to size: {} with height:{} width: {}'.format(posemb.shape, posemb_new.shape, hight, width)) 62 | posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) 63 | posemb_grid = F.interpolate(posemb_grid, size=(hight, width), mode='bilinear') 64 | posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, hight * width, -1) 65 | posemb = torch.cat([posemb_token, posemb_grid], dim=1) 66 | return posemb 67 | 68 | class Mlp(nn.Module): 69 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 70 | super().__init__() 71 | out_features = out_features or in_features 72 | hidden_features = hidden_features or in_features 73 | self.fc1 = nn.Linear(in_features, hidden_features) 74 | self.act = act_layer() 75 | self.fc2 = nn.Linear(hidden_features, out_features) 76 | self.drop = nn.Dropout(drop) 77 | 78 | def forward(self, x): 79 | x = self.fc1(x) 80 | x = self.act(x) 81 | x = self.drop(x) 82 | x = self.fc2(x) 83 | x = self.drop(x) 84 | return x 85 | 86 | 87 | class Attention(nn.Module): 88 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 89 | super().__init__() 90 | self.num_heads = num_heads 91 | head_dim = dim // num_heads 92 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 93 | self.scale = qk_scale or head_dim ** -0.5 94 | 95 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 96 | self.attn_drop = nn.Dropout(attn_drop) 97 | self.proj = nn.Linear(dim, dim) 98 | self.proj_drop = nn.Dropout(proj_drop) 99 | 100 | def forward(self, x): 101 | B, N, C = x.shape 102 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 103 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 104 | 105 | attn 
= (q @ k.transpose(-2, -1)) * self.scale 106 | attn = attn.softmax(dim=-1) 107 | attn = self.attn_drop(attn) 108 | 109 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) #[128, 211, 768] 110 | x = self.proj(x) 111 | x = self.proj_drop(x) 112 | return x 113 | 114 | 115 | class Block(nn.Module): 116 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 117 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 118 | super().__init__() 119 | self.norm1 = norm_layer(dim) 120 | self.attn = Attention( 121 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 122 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 123 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 124 | self.norm2 = norm_layer(dim) 125 | mlp_hidden_dim = int(dim * mlp_ratio) 126 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 127 | 128 | def forward(self, x): 129 | x = x + self.drop_path(self.attn(self.norm1(x))) 130 | x = x + self.drop_path(self.mlp(self.norm2(x))) 131 | return x 132 | 133 | 134 | class PatchEmbed_overlap(nn.Module): 135 | """ Image to Patch Embedding with overlapping patches""" 136 | def __init__(self, img_size=224, patch_size=16, stride_size=20, in_chans=3, embed_dim=768): 137 | super().__init__() 138 | img_size = to_2tuple(img_size) 139 | patch_size = to_2tuple(patch_size) 140 | stride_size_tuple = to_2tuple(stride_size) 141 | self.num_x = (img_size[1] - patch_size[1]) // stride_size_tuple[1] + 1 142 | self.num_y = (img_size[0] - patch_size[0]) // stride_size_tuple[0] + 1 143 | print('using stride: {}, and patch number is num_y{} * num_x{}'.format(stride_size, self.num_y, self.num_x)) 144 | num_patches = self.num_x * self.num_y 145 | self.img_size = img_size 146 | self.patch_size = patch_size 147 | self.num_patches = num_patches 148 | 149 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride_size) 150 | 151 | for m in self.modules(): 152 | if isinstance(m, nn.Conv2d): 153 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 154 | m.weight.data.normal_(0, math.sqrt(2. / n)) 155 | elif isinstance(m, nn.BatchNorm2d): 156 | m.weight.data.fill_(1) 157 | m.bias.data.zero_() 158 | elif isinstance(m, nn.InstanceNorm2d): 159 | m.weight.data.fill_(1) 160 | m.bias.data.zero_() 161 | 162 | def forward(self, x): 163 | B, C, H, W = x.shape #batch_size , channels , height ,width 164 | # FIXME look at relaxing size constraints 165 | assert H == self.img_size[0] and W == self.img_size[1], \ 166 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
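        # proj is a strided conv: (B, 3, H, W) -> (B, 768, num_y, num_x); flatten + transpose
        # below yield (B, num_y*num_x, 768) patch tokens. With the SYSU config (H=256, W=128,
        # patch 16, stride 12) that is 21*10 = 210 tokens, i.e. 211 once the cls token is
        # prepended, matching the shape comment in Attention.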
167 | 168 | x = self.proj(x) 169 | x = x.flatten(2).transpose(1, 2) 170 | 171 | return x 172 | 173 | 174 | class ViT(nn.Module): 175 | def __init__(self, img_size=224, patch_size=16, stride_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, 176 | num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0., 177 | drop_path_rate=0., norm_layer = nn.LayerNorm): 178 | super(ViT, self).__init__() 179 | self.num_classes = num_classes 180 | 181 | self.patch_embed = PatchEmbed_overlap(img_size=img_size, patch_size=patch_size, stride_size=stride_size, in_chans=in_chans,embed_dim=embed_dim) 182 | 183 | num_patches = self.patch_embed.num_patches 184 | 185 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 186 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) 187 | 188 | self.pos_drop = nn.Dropout(p=drop_rate) 189 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 190 | 191 | self.blocks = nn.ModuleList([ 192 | Block( 193 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 194 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) 195 | for i in range(depth)]) 196 | 197 | self.norm = norm_layer(embed_dim) 198 | 199 | trunc_normal_(self.cls_token, std=.02) 200 | trunc_normal_(self.pos_embed, std=.02) 201 | self.apply(self._init_weights) 202 | 203 | 204 | def _init_weights(self, m): 205 | if isinstance(m, nn.Linear): 206 | trunc_normal_(m.weight, std=.02) 207 | if isinstance(m, nn.Linear) and m.bias is not None: 208 | nn.init.constant_(m.bias, 0) 209 | 210 | elif isinstance(m, nn.LayerNorm): 211 | nn.init.constant_(m.bias, 0) 212 | nn.init.constant_(m.weight, 1.0) 213 | 214 | def forward_features(self, x): 215 | B = x.shape[0] 216 | x = self.patch_embed(x) 217 | 218 | cls_tokens = self.cls_token.expand(B, -1, -1) 219 | x = torch.cat((cls_tokens, x), dim=1) 220 | 221 | x = x + self.pos_embed 222 | 223 | x = self.pos_drop(x) 224 | 225 | for blk in self.blocks: 226 | x = blk(x) 227 | 228 | x = self.norm(x) 229 | 230 | return x[:, 0] 231 | 232 | def forward(self,x): 233 | x = self.forward_features(x) 234 | return x 235 | 236 | def load_param(self, model_path): 237 | param_dict = torch.load(model_path, map_location='cpu') 238 | if 'model' in param_dict: 239 | param_dict = param_dict['model'] 240 | if 'state_dict' in param_dict: 241 | param_dict = param_dict['state_dict'] 242 | for k, v in param_dict.items(): 243 | if 'head' in k or 'dist' in k: 244 | continue 245 | if 'patch_embed.proj.weight' in k and len(v.shape) < 4: 246 | # For old models that I trained prior to conv based patchification 247 | O, I, H, W = self.patch_embed.proj.weight.shape 248 | v = v.reshape(O, -1, H, W) 249 | elif k == 'pos_embed' and v.shape != self.pos_embed.shape: 250 | # To resize pos embedding when using model at different size from pretrained weights 251 | if 'distilled' in model_path: 252 | print('distill need to choose right cls token in the pth') 253 | v = torch.cat([v[:, 0:1], v[:, 2:]], dim=1) 254 | v = resize_pos_embed(v, self.pos_embed, self.patch_embed.num_y, self.patch_embed.num_x) 255 | try: 256 | self.state_dict()[k].copy_(v) 257 | except: 258 | print('===========================ERROR=========================') 259 | print('shape do not match in k :{}: param_dict{} vs self.state_dict(){}'.format(k, v.shape, self.state_dict()[k].shape)) 260 | 261 | 262 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): 263 | # 
type: (Tensor, float, float, float, float) -> Tensor 264 | r""" 265 | Examples: 266 | >>> w = torch.empty(3, 5) 267 | >>> nn.init.trunc_normal_(w) 268 | """ 269 | return _no_grad_trunc_normal_(tensor, mean, std, a, b) 270 | 271 | 272 | def _no_grad_trunc_normal_(tensor, mean, std, a, b): # standardization 273 | # Cut & paste from PyTorch official master until it's in a few official releases - RW 274 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf 275 | def norm_cdf(x): 276 | # Computes standard normal cumulative distribution function 277 | return (1. + math.erf(x / math.sqrt(2.))) / 2. 278 | 279 | if (mean < a - 2 * std) or (mean > b + 2 * std): 280 | print("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " 281 | "The distribution of values may be incorrect.") 282 | 283 | with torch.no_grad(): 284 | # Values are generated by using a truncated uniform distribution and 285 | # then using the inverse CDF for the normal distribution. 286 | # Get upper and lower cdf values 287 | l = norm_cdf((a - mean) / std) 288 | u = norm_cdf((b - mean) / std) 289 | 290 | # Uniformly fill tensor with values from [l, u], then translate to 291 | # [2l-1, 2u-1]. 292 | tensor.uniform_(2 * l - 1, 2 * u - 1) 293 | 294 | # Use inverse cdf transform for normal distribution to get truncated 295 | # standard normal 296 | tensor.erfinv_() 297 | 298 | # Transform to proper mean, std 299 | tensor.mul_(std * math.sqrt(2.)) 300 | tensor.add_(mean) 301 | 302 | # Clamp to ensure it's in the proper range 303 | tensor.clamp_(min=a, max=b) 304 | return tensor --------------------------------------------------------------------------------
/Pytorch-PMT-VI-ReID/optimizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def make_optimizer(cfg, model): 4 | params = [] 5 | for key, value in model.named_parameters(): 6 | if not value.requires_grad: 7 | continue 8 | 9 | lr = cfg.BASE_LR 10 | weight_decay = cfg.WEIGHT_DECAY 11 | 12 | if "bias" in key: 13 | lr = cfg.BASE_LR * cfg.BIAS_LR_FACTOR 14 | weight_decay = cfg.WEIGHT_DECAY_BIAS 15 | 16 | if "base.patch_embed.proj" in key: 17 | params +=[{"params": [value], "lr": lr * cfg.LR_PRETRAIN, "weight_decay": weight_decay}] 18 | continue 19 | 20 | if "base.blocks." 
in key: 21 | params +=[{"params": [value], "lr": lr * cfg.LR_PRETRAIN, "weight_decay": weight_decay}] 22 | continue 23 | 24 | params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}] 25 | 26 | if cfg.OPTIMIZER_NAME == 'SGD': 27 | optimizer = getattr(torch.optim, cfg.OPTIMIZER_NAME)(params, momentum=cfg.MOMENTUM) 28 | 29 | elif cfg.OPTIMIZER_NAME == 'AdamW': 30 | optimizer = torch.optim.AdamW(params, lr=cfg.BASE_LR, weight_decay=cfg.WEIGHT_DECAY) 31 | 32 | else: 33 | optimizer = getattr(torch.optim, cfg.OPTIMIZER_NAME)(params) 34 | 35 | return optimizer -------------------------------------------------------------------------------- /Pytorch-PMT-VI-ReID/process_sysu.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from PIL import Image 4 | from config.config import cfg 5 | 6 | root = cfg.DATA_PATH_SYSU 7 | 8 | rgb_cameras = ['cam1','cam2','cam4','cam5'] 9 | ir_cameras = ['cam3','cam6'] 10 | 11 | # load id info 12 | file_path_train = os.path.join(root, 'exp/train_id.txt') 13 | with open(file_path_train, 'r') as file: 14 | ids = file.read().splitlines() 15 | ids = [int(y) for y in ids[0].split(',')] 16 | id_train = ["%04d" % x for x in ids] 17 | 18 | file_path_val = os.path.join(root, 'exp/val_id.txt') 19 | with open(file_path_val, 'r') as file: 20 | ids = file.read().splitlines() 21 | ids = [int(y) for y in ids[0].split(',')] 22 | id_val = ["%04d" % x for x in ids] 23 | 24 | # combine train and val split 25 | id_train.extend(id_val) 26 | 27 | img_paths_rgb = [] 28 | img_paths_ir = [] 29 | for pid in sorted(id_train): 30 | for cam in rgb_cameras: 31 | img_dir = os.path.join(root, cam, pid) 32 | if os.path.isdir(img_dir): 33 | new_files = sorted([img_dir+'/'+i for i in os.listdir(img_dir)]) 34 | img_paths_rgb.extend(new_files) 35 | 36 | for cam in ir_cameras: 37 | img_dir = os.path.join(root, cam, pid) 38 | if os.path.isdir(img_dir): 39 | new_files = sorted([img_dir+'/'+i for i in os.listdir(img_dir)]) 40 | img_paths_ir.extend(new_files) 41 | 42 | # relabel 43 | pid_container = set() 44 | for img_path in img_paths_ir: 45 | pid = int(img_path[-13:-9]) 46 | pid_container.add(pid) 47 | 48 | pid2label = {pid:label for label, pid in enumerate(pid_container)} 49 | 50 | def read_imgs(img_paths, img_w, img_h): 51 | train_img = [] 52 | train_label = [] 53 | for img_path in img_paths: 54 | # img 55 | img = Image.open(img_path) 56 | img = img.resize((img_w, img_h), Image.ANTIALIAS) 57 | pix_array = np.array(img) 58 | train_img.append(pix_array) 59 | 60 | # label 61 | pid = int(img_path[-13:-9]) 62 | label = pid2label[pid] 63 | train_label.append(label) 64 | 65 | return np.array(train_img), np.array(train_label) 66 | 67 | 68 | train_img, train_label = read_imgs(img_paths_rgb, img_w=cfg.W, img_h=cfg.H) 69 | np.save(os.path.join(root, 'train_rgb_resized_img.npy'), train_img) 70 | np.save(os.path.join(root, 'train_rgb_resized_label.npy'), train_label) 71 | 72 | train_img, train_label = read_imgs(img_paths_ir, img_w=cfg.W, img_h=cfg.H) 73 | np.save(os.path.join(root, 'train_ir_resized_img.npy'), train_img) 74 | np.save(os.path.join(root, 'train_ir_resized_label.npy'), train_label) 75 | -------------------------------------------------------------------------------- /Pytorch-PMT-VI-ReID/scheduler.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Any 2 | import logging 3 | import math 4 | import torch 5 | 6 | _logger = logging.getLogger(__name__) 7 | 8 | 
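# --- Editor's sketch of the intended wiring, as it would appear in train.py (not part of
# --- this dump): make_optimizer gives the pretrained patch embedding and transformer
# --- blocks a reduced lr of BASE_LR * LR_PRETRAIN, while create_scheduler (defined at the
# --- bottom of this file) wraps the optimizer in a WARMUP_EPOCHS warmup followed by
# --- cosine decay down to LR_MIN * BASE_LR.
def _usage_sketch():
    from config.config import cfg
    from model.make_model import build_vision_transformer
    from optimizer import make_optimizer

    model = build_vision_transformer(num_classes=395, cfg=cfg)  # 395 train ids on SYSU-MM01, per test.py
    optimizer = make_optimizer(cfg, model)
    lr_scheduler = create_scheduler(cfg, optimizer)

    for epoch in range(cfg.START_EPOCH, cfg.MAX_EPOCH + 1):
        lr_scheduler.step(epoch)  # t_in_epochs=True, so lr values are recomputed per epoch
        # ... run one training epoch here ...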
class Scheduler: 9 | """ Parameter Scheduler Base Class 10 | A scheduler base class that can be used to schedule any optimizer parameter groups. 11 | 12 | Unlike the builtin PyTorch schedulers, this is intended to be consistently called 13 | * At the END of each epoch, before incrementing the epoch count, to calculate next epoch's value 14 | * At the END of each optimizer update, after incrementing the update count, to calculate next update's value 15 | 16 | The schedulers built on this should try to remain as stateless as possible (for simplicity). 17 | 18 | This family of schedulers is attempting to avoid the confusion of the meaning of 'last_epoch' 19 | and -1 values for special behaviour. All epoch and update counts must be tracked in the training 20 | code and explicitly passed in to the schedulers on the corresponding step or step_update call. 21 | 22 | Based on ideas from: 23 | * https://github.com/pytorch/fairseq/tree/master/fairseq/optim/lr_scheduler 24 | * https://github.com/allenai/allennlp/tree/master/allennlp/training/learning_rate_schedulers 25 | """ 26 | 27 | def __init__(self, 28 | optimizer: torch.optim.Optimizer, 29 | param_group_field: str, 30 | noise_range_t=None, 31 | noise_type='normal', 32 | noise_pct=0.67, 33 | noise_std=1.0, 34 | noise_seed=None, 35 | initialize: bool = True) -> None: 36 | self.optimizer = optimizer 37 | self.param_group_field = param_group_field 38 | self._initial_param_group_field = f"initial_{param_group_field}" 39 | if initialize: 40 | for i, group in enumerate(self.optimizer.param_groups): 41 | if param_group_field not in group: 42 | raise KeyError(f"{param_group_field} missing from param_groups[{i}]") 43 | group.setdefault(self._initial_param_group_field, group[param_group_field]) 44 | else: 45 | for i, group in enumerate(self.optimizer.param_groups): 46 | if self._initial_param_group_field not in group: 47 | raise KeyError(f"{self._initial_param_group_field} missing from param_groups[{i}]") 48 | self.base_values = [group[self._initial_param_group_field] for group in self.optimizer.param_groups] 49 | self.metric = None # any point to having this for all? 
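        # optional lr-noise settings (timm-style): _add_noise() below perturbs the scheduled
        # values only while the step count falls inside noise_range_t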
50 | self.noise_range_t = noise_range_t 51 | self.noise_pct = noise_pct 52 | self.noise_type = noise_type 53 | self.noise_std = noise_std 54 | self.noise_seed = noise_seed if noise_seed is not None else 42 55 | self.update_groups(self.base_values) 56 | 57 | def state_dict(self) -> Dict[str, Any]: 58 | return {key: value for key, value in self.__dict__.items() if key != 'optimizer'} 59 | 60 | def load_state_dict(self, state_dict: Dict[str, Any]) -> None: 61 | self.__dict__.update(state_dict) 62 | 63 | def get_epoch_values(self, epoch: int): 64 | return None 65 | 66 | def get_update_values(self, num_updates: int): 67 | return None 68 | 69 | def step(self, epoch: int, metric: float = None) -> None: 70 | self.metric = metric 71 | values = self.get_epoch_values(epoch) 72 | if values is not None: 73 | values = self._add_noise(values, epoch) 74 | self.update_groups(values) 75 | 76 | def step_update(self, num_updates: int, metric: float = None): 77 | self.metric = metric 78 | values = self.get_update_values(num_updates) 79 | if values is not None: 80 | values = self._add_noise(values, num_updates) 81 | self.update_groups(values) 82 | 83 | def update_groups(self, values): 84 | if not isinstance(values, (list, tuple)): 85 | values = [values] * len(self.optimizer.param_groups) 86 | for param_group, value in zip(self.optimizer.param_groups, values): 87 | param_group[self.param_group_field] = value 88 | 89 | def _add_noise(self, lrs, t): 90 | if self.noise_range_t is not None: 91 | if isinstance(self.noise_range_t, (list, tuple)): 92 | apply_noise = self.noise_range_t[0] <= t < self.noise_range_t[1] 93 | else: 94 | apply_noise = t >= self.noise_range_t 95 | if apply_noise: 96 | g = torch.Generator() 97 | g.manual_seed(self.noise_seed + t) 98 | if self.noise_type == 'normal': 99 | while True: 100 | # resample if noise out of percent limit, brute force but shouldn't spin much 101 | noise = torch.randn(1, generator=g).item() 102 | if abs(noise) < self.noise_pct: 103 | break 104 | else: 105 | noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct 106 | lrs = [v + v * noise for v in lrs] 107 | return lrs 108 | 109 | 110 | 111 | class CosineLRScheduler(Scheduler): 112 | """ 113 | Cosine decay with restarts. 114 | This is described in the paper https://arxiv.org/abs/1608.03983. 
115 | Inspiration from 116 | https://github.com/allenai/allennlp/blob/master/allennlp/training/learning_rate_schedulers/cosine.py 117 | """ 118 | 119 | def __init__(self, 120 | optimizer: torch.optim.Optimizer, 121 | t_initial: int, 122 | t_mul: float = 1., 123 | lr_min: float = 0., 124 | decay_rate: float = 1., 125 | warmup_t=0, 126 | warmup_lr_init=0, 127 | warmup_prefix=False, 128 | cycle_limit=0, 129 | t_in_epochs=True, 130 | noise_range_t=None, 131 | noise_pct=0.67, 132 | noise_std=1.0, 133 | noise_seed=42, 134 | initialize=True) -> None: 135 | super().__init__( 136 | optimizer, param_group_field="lr", 137 | noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed, 138 | initialize=initialize) 139 | 140 | assert t_initial > 0 141 | assert lr_min >= 0 142 | if t_initial == 1 and t_mul == 1 and decay_rate == 1: 143 | _logger.warning("Cosine annealing scheduler will have no effect on the learning " 144 | "rate since t_initial = t_mul = eta_mul = 1.") 145 | self.t_initial = t_initial 146 | self.t_mul = t_mul 147 | self.lr_min = lr_min 148 | self.decay_rate = decay_rate 149 | self.cycle_limit = cycle_limit 150 | self.warmup_t = warmup_t 151 | self.warmup_lr_init = warmup_lr_init 152 | self.warmup_prefix = warmup_prefix 153 | self.t_in_epochs = t_in_epochs 154 | if self.warmup_t: 155 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] 156 | super().update_groups(self.warmup_lr_init) 157 | else: 158 | self.warmup_steps = [1 for _ in self.base_values] 159 | 160 | def _get_lr(self, t): 161 | if t < self.warmup_t: 162 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] 163 | else: 164 | if self.warmup_prefix: 165 | t = t - self.warmup_t 166 | 167 | if self.t_mul != 1: 168 | i = math.floor(math.log(1 - t / self.t_initial * (1 - self.t_mul), self.t_mul)) 169 | t_i = self.t_mul ** i * self.t_initial 170 | t_curr = t - (1 - self.t_mul ** i) / (1 - self.t_mul) * self.t_initial 171 | else: 172 | i = t // self.t_initial 173 | t_i = self.t_initial 174 | t_curr = t - (self.t_initial * i) 175 | 176 | gamma = self.decay_rate ** i 177 | lr_min = self.lr_min * gamma 178 | lr_max_values = [v * gamma for v in self.base_values] 179 | 180 | if self.cycle_limit == 0 or (self.cycle_limit > 0 and i < self.cycle_limit): 181 | lrs = [ 182 | lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * t_curr / t_i)) for lr_max in lr_max_values 183 | ] 184 | else: 185 | lrs = [self.lr_min for _ in self.base_values] 186 | 187 | return lrs 188 | 189 | def get_epoch_values(self, epoch: int): 190 | if self.t_in_epochs: 191 | return self._get_lr(epoch) 192 | else: 193 | return None 194 | 195 | def get_update_values(self, num_updates: int): 196 | if not self.t_in_epochs: 197 | return self._get_lr(num_updates) 198 | else: 199 | return None 200 | 201 | def get_cycle_length(self, cycles=0): 202 | if not cycles: 203 | cycles = self.cycle_limit 204 | cycles = max(1, cycles) 205 | if self.t_mul == 1.0: 206 | return self.t_initial * cycles 207 | else: 208 | return int(math.floor(-self.t_initial * (self.t_mul ** cycles - 1) / (1 - self.t_mul))) 209 | 210 | 211 | 212 | def create_scheduler(cfg, optimizer): 213 | 214 | num_epochs = cfg.MAX_EPOCH 215 | lr_min = cfg.LR_MIN * cfg.BASE_LR 216 | warmup_lr_init = cfg.LR_INIT * cfg.BASE_LR 217 | warmup_t = cfg.WARMUP_EPOCHS 218 | 219 | lr_scheduler = CosineLRScheduler( 220 | optimizer, 221 | t_initial=num_epochs, 222 | lr_min=lr_min, 223 | t_mul= 1., 224 | decay_rate=1.0, 225 | warmup_lr_init=warmup_lr_init, 
226 | warmup_t=warmup_t, 227 | cycle_limit=1, 228 | t_in_epochs=True, 229 | noise_range_t=None, 230 | noise_pct=0.67, 231 | noise_std=1., 232 | noise_seed=42, 233 | ) 234 | 235 | return lr_scheduler -------------------------------------------------------------------------------- /Pytorch-PMT-VI-ReID/test.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch 3 | import torch.backends.cudnn as cudnn 4 | from torch.autograd import Variable 5 | import torch.utils.data as data 6 | from eval_metrics import eval_sysu, eval_regdb 7 | from dataloader import TestData, TestData_RegDB 8 | from datamanager import * 9 | from model.make_model import build_vision_transformer 10 | from config.config import cfg 11 | from transforms import transform_test 12 | from tqdm import tqdm 13 | import argparse 14 | 15 | parser = argparse.ArgumentParser(description='PMT Testing') 16 | parser.add_argument('--dataset', default='sysu', help='dataset name: regdb or sysu') 17 | parser.add_argument('--resume', '-r', default='', type=str, 18 | help='resume from checkpoint') 19 | parser.add_argument('--model_path', default='save_model/', type=str, 20 | help='model save path') 21 | parser.add_argument('--workers', default=0, type=int, metavar='N', 22 | help='number of data loading workers') 23 | parser.add_argument('--test-batch', default=128, type=int, 24 | help='testing batch size') 25 | parser.add_argument('--gpu', default='0', type=str, 26 | help='gpu device ids for CUDA_VISIBLE_DEVICES') 27 | parser.add_argument('--mode', default='all', type=str, 28 | help='all or indoor for sysu') 29 | parser.add_argument('--gall_mode', default='single', type=str, 30 | help='single or multi for sysu') 31 | parser.add_argument('--trial', default=1, type=int, help='trial for regdb') 32 | parser.add_argument('--tvsearch', default=False, type=lambda x: str(x).lower() == 'true',  # type=bool would treat any non-empty string (even 'False') as True 33 | help='whether thermal to visible search on regdb') 34 | args = parser.parse_args() 35 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 36 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 37 | 38 | if args.dataset == 'sysu': 39 | data_path = cfg.DATA_PATH_SYSU 40 | n_class = 395 41 | 42 | elif args.dataset == 'regdb': 43 | data_path = cfg.DATA_PATH_RegDB 44 | n_class = 206 45 | 46 | 47 | print('==> Building model..') 48 | model = build_vision_transformer(num_classes=n_class, cfg=cfg) 49 | model.to(device) 50 | cudnn.benchmark = True 51 | model.eval() 52 | 53 | def extract_gall_feat(gall_loader): 54 | model.eval() 55 | ptr = 0 56 | gall_feat = np.zeros((ngall, 768)) 57 | 58 | with torch.no_grad(): 59 | for batch_idx, (input, label) in enumerate(gall_loader): 60 | batch_num = input.size(0) 61 | input = Variable(input.cuda()) 62 | feat = model(input) 63 | gall_feat[ptr:ptr + batch_num, :] = feat.detach().cpu().numpy() 64 | ptr = ptr + batch_num 65 | 66 | return gall_feat 67 | 68 | def extract_query_feat(query_loader): 69 | model.eval() 70 | ptr = 0 71 | query_feat = np.zeros((nquery, 768)) 72 | 73 | with torch.no_grad(): 74 | for batch_idx, (input, label) in enumerate(query_loader): 75 | batch_num = input.size(0) 76 | input = Variable(input.cuda()) 77 | feat = model(input) 78 | query_feat[ptr:ptr + batch_num, :] = feat.detach().cpu().numpy() 79 | ptr = ptr + batch_num 80 | 81 | return query_feat 82 | 83 | all_cmc = 0 84 | all_mAP = 0 85 | all_mINP = 0 86 | 87 | if args.dataset == 'sysu': 88 | # load checkpoint 89 | print('==> Resuming from checkpoint..') 90 | if len(args.resume) > 0: 91 |
model_path = args.model_path + args.resume 92 | if os.path.isfile(model_path): 93 | print('==> loading checkpoint {}'.format(args.resume)) 94 | model.load_param(model_path) 95 | print('==> loaded checkpoint {}'.format(args.resume)) 96 | else: 97 | print('==> no checkpoint found at {}'.format(args.resume)) 98 | 99 | # Test set 100 | query_img, query_label, query_cam = process_query_sysu(data_path, mode=args.mode) 101 | gall_img, gall_label, gall_cam = process_gallery_sysu(data_path, mode=args.mode, trial=0, gall_mode=args.gall_mode) # indoor 102 | 103 | nquery = len(query_label) 104 | ngall = len(gall_label) 105 | print("Dataset statistics:") 106 | print(" ------------------------------") 107 | print(" subset | # ids | # images") 108 | print(" ------------------------------") 109 | print(" query | {:5d} | {:8d}".format(len(np.unique(query_label)), nquery)) 110 | print(" gallery | {:5d} | {:8d}".format(len(np.unique(gall_label)), ngall)) 111 | print(" ------------------------------") 112 | 113 | queryset = TestData(query_img, query_label, transform=transform_test, img_size=(cfg.W, cfg.H)) 114 | query_loader = data.DataLoader(queryset, batch_size=128, shuffle=False, num_workers=args.workers) 115 | 116 | query_feat = extract_query_feat(query_loader) 117 | 118 | for i in tqdm(range(10)): 119 | gall_img, gall_label, gall_cam = process_gallery_sysu(data_path, mode=args.mode, trial=i,gall_mode=args.gall_mode) #all 120 | 121 | trial_gallset = TestData(gall_img, gall_label, transform=transform_test, img_size=(cfg.W, cfg.H)) 122 | trial_gall_loader = data.DataLoader(trial_gallset, batch_size=args.test_batch, shuffle=False, num_workers=args.workers) 123 | 124 | gall_feat = extract_gall_feat(trial_gall_loader) 125 | 126 | distmat = -np.matmul(query_feat, np.transpose(gall_feat)) 127 | 128 | cmc, mAP, mInp = eval_sysu(distmat, query_label, gall_label, query_cam, gall_cam) 129 | print('\n mAP: {:.2%} | mInp:{:.2%} | top-1: {:.2%} | top-5: {:.2%} | top-10: {:.2%}| top-20: {:.2%} '.format(mAP,mInp,cmc[0],cmc[4],cmc[9],cmc[19])) 130 | 131 | all_cmc += cmc 132 | all_mAP += mAP 133 | all_mINP += mInp 134 | 135 | all_cmc /= 10.0 136 | all_mAP /= 10.0 137 | all_mINP /= 10.0 138 | print('\n Average:') 139 | print('mAP: {:.2%} | mInp:{:.2%} | top-1: {:.2%} | top-5: {:.2%} | top-10: {:.2%}| top-20: {:.2%}'.format(all_mAP,all_mINP,all_cmc[0],all_cmc[4],all_cmc[9],all_cmc[19])) 140 | 141 | # single test for regdb 142 | elif args.dataset == 'regdb': 143 | 144 | # load checkpoint 145 | print('==> Resuming from checkpoint..') 146 | if len(args.resume) > 0: 147 | model_path = args.model_path + args.resume 148 | if os.path.isfile(model_path): 149 | print('==> loading checkpoint {}'.format(args.resume)) 150 | model.load_param(model_path) 151 | print('==> loaded checkpoint {}'.format(args.resume)) 152 | else: 153 | print('==> no checkpoint found at {}'.format(args.resume)) 154 | 155 | 156 | query_img, query_label = process_test_regdb(data_path, trial=args.trial, modal='visible') 157 | gall_img, gall_label = process_test_regdb(data_path, trial=args.trial, modal='thermal') 158 | 159 | galleryset = TestData_RegDB(gall_img, gall_label, transform=transform_test, img_size=(cfg.W, cfg.H)) 160 | queryset = TestData_RegDB(query_img, query_label, transform=transform_test, img_size=(cfg.W, cfg.H)) 161 | 162 | nquery = len(query_label) 163 | ngall = len(gall_label) 164 | print("Dataset statistics:") 165 | print(" ------------------------------") 166 | print(" subset | # ids | # images") 167 | print(" ------------------------------") 
168 | print(" query | {:5d} | {:8d}".format(len(np.unique(query_label)), nquery)) 169 | print(" gallery | {:5d} | {:8d}".format(len(np.unique(gall_label)), ngall)) 170 | print(" ------------------------------") 171 | 172 | gall_loader = data.DataLoader(galleryset, batch_size=args.test_batch, shuffle=False, num_workers=args.workers) 173 | query_loader = data.DataLoader(queryset, batch_size=args.test_batch, shuffle=False, num_workers=args.workers) 174 | 175 | query_feat = extract_query_feat(query_loader) 176 | gall_feat = extract_gall_feat(gall_loader) 177 | 178 | if args.tvsearch: # T -> V 179 | distmat = -np.matmul(gall_feat, np.transpose(query_feat)) 180 | cm, mA, mInp = eval_regdb(distmat, gall_label, query_label) 181 | 182 | else: # V -> T 183 | distmat = -np.matmul(query_feat, np.transpose(gall_feat)) 184 | cm, mA, mInp = eval_regdb(distmat, query_label, gall_label) 185 | 186 | print('mAP: {:.2%} | mInp:{:.2%} | R-1: {:.2%} | R-5: {:.2%} | R-10: {:.2%}| R-20: {:.2%}'.format(mA,mInp,cm[0],cm[4],cm[9],cm[19])) 187 | 188 | 189 | # # 10 independent tests for regdb 190 | # elif args.dataset == 'regdb': 191 | # all_cmc = 0 192 | # all_mAP = 0 193 | # all_mINP = 0 194 | # 195 | # for trial in range(10): 196 | # 197 | # test_trial = trial + 1 198 | # model_path = args.model_path + 'RegDB_best_trial_{}.pth'.format(test_trial) 199 | # 200 | # if os.path.isfile(model_path): 201 | # print('==> loading checkpoint {}'.format(model_path)) 202 | # model.load_param(model_path) 203 | # else: 204 | # print('==> no checkpoint found at {}'.format(model_path)) 205 | # 206 | # # for single test 207 | # query_img, query_label = process_test_regdb(data_path, trial=test_trial, modal='visible') 208 | # gall_img, gall_label = process_test_regdb(data_path, trial=test_trial, modal='thermal') 209 | # 210 | # galleryset = TestData_RegDB(gall_img, gall_label, transform=transform_test, img_size=(cfg.W, cfg.H)) 211 | # queryset = TestData_RegDB(query_img, query_label, transform=transform_test, img_size=(cfg.W, cfg.H)) 212 | # 213 | # nquery = len(query_label) 214 | # ngall = len(gall_label) 215 | # 216 | # gall_loader = data.DataLoader(galleryset, batch_size=args.test_batch, shuffle=False, num_workers=args.workers) 217 | # query_loader = data.DataLoader(queryset, batch_size=args.test_batch, shuffle=False, num_workers=args.workers) 218 | # 219 | # query_feat = extract_query_feat(query_loader) 220 | # gall_feat = extract_gall_feat(gall_loader) 221 | # 222 | # if args.tvsearch: # T -> V 223 | # distmat = -np.matmul(gall_feat, np.transpose(query_feat)) 224 | # cmc, mAP, mINP = eval_regdb(distmat, gall_label, query_label) 225 | # 226 | # else: # V -> T 227 | # distmat = -np.matmul(query_feat, np.transpose(gall_feat)) 228 | # cmc, mAP, mINP = eval_regdb(distmat, query_label, gall_label) 229 | # 230 | # all_cmc += cmc 231 | # all_mAP += mAP 232 | # all_mINP += mINP 233 | # print('mAP: {:.2%} | mInp:{:.2%} | R-1: {:.2%} | R-5: {:.2%} | R-10: {:.2%}| R-20: {:.2%}'.format(mAP, mINP, cmc[0], cmc[4], cmc[9],cmc[19])) 234 | # 235 | # all_cmc /= 10.0 236 | # all_mAP /= 10.0 237 | # all_mINP /= 10.0 238 | # print('\n Average:') 239 | # print('mAP: {:.2%} | mInp:{:.2%} | top-1: {:.2%} | top-5: {:.2%} | top-10: {:.2%}| top-20: {:.2%}'.format(all_mAP, all_mINP, 240 | # all_cmc[0],all_cmc[4], 241 | # all_cmc[9],all_cmc[19])) 242 | -------------------------------------------------------------------------------- /Pytorch-PMT-VI-ReID/train.py: -------------------------------------------------------------------------------- 1 | from 
dataloader import SYSUData, RegDBData,TestData, GenIdx, IdentitySampler 2 | from datamanager import process_gallery_sysu, process_query_sysu, process_test_regdb 3 | import numpy as np 4 | import torch.utils.data as data 5 | from torch.autograd import Variable 6 | import torch 7 | from torch.cuda import amp 8 | import torch.nn as nn 9 | import os.path as osp 10 | import os 11 | from model.make_model import build_vision_transformer 12 | import time 13 | import optimizer 14 | from scheduler import create_scheduler 15 | from loss.Triplet import TripletLoss 16 | from loss.MSEL import MSEL 17 | from loss.DCL import DCL 18 | from utils import AverageMeter, set_seed 19 | from transforms import transform_rgb, transform_rgb2gray, transform_thermal, transform_test 20 | from optimizer import make_optimizer 21 | from config.config import cfg 22 | from eval_metrics import eval_sysu, eval_regdb 23 | import argparse 24 | 25 | parser = argparse.ArgumentParser(description="PMT Training") 26 | parser.add_argument('--config_file', default='config/SYSU.yml', 27 | help='path to config file', type=str) 28 | parser.add_argument('--trial', default=1, 29 | help='only for RegDB', type=int) 30 | parser.add_argument('--resume', '-r', default='', 31 | help='resume from checkpoint', type=str) 32 | parser.add_argument('--model_path', default='save_model/', 33 | help='model save path', type=str) 34 | parser.add_argument('--num_workers', default=0, 35 | help='number of data loading workers', type=int) 36 | parser.add_argument('--start_test', default=0, 37 | help='start to test in training', type=int) 38 | parser.add_argument('--test_batch', default=128, 39 | help='batch size for test', type=int) 40 | parser.add_argument('--test_epoch', default=2, 41 | help='test model every 2 epochs', type=int) 42 | parser.add_argument('--save_epoch', default=2, 43 | help='save model every 2 epochs', type=int) 44 | parser.add_argument('--gpu', default='0', 45 | help='gpu device ids for CUDA_VISIBLE_DEVICES', type=str) 46 | parser.add_argument("opts", help="Modify config options using the command-line", 47 | default=None,nargs=argparse.REMAINDER) 48 | args = parser.parse_args() 49 | 50 | if args.config_file != '': 51 | cfg.merge_from_file(args.config_file) 52 | cfg.merge_from_list(args.opts) 53 | cfg.freeze() 54 | 55 | set_seed(cfg.SEED) 56 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 57 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 58 | torch.cuda.empty_cache() 59 | 60 | 61 | if cfg.DATASET == 'sysu': 62 | data_path = cfg.DATA_PATH_SYSU 63 | trainset_gray = SYSUData(data_path, transform1=transform_rgb2gray, transform2=transform_thermal) 64 | color_pos_gray, thermal_pos_gray = GenIdx(trainset_gray.train_color_label, trainset_gray.train_thermal_label) 65 | 66 | trainset_rgb = SYSUData(data_path, transform1=transform_rgb, transform2=transform_thermal) 67 | color_pos_rgb, thermal_pos_rgb = GenIdx(trainset_rgb.train_color_label, trainset_rgb.train_thermal_label) 68 | 69 | elif cfg.DATASET == 'regdb': 70 | data_path = cfg.DATA_PATH_RegDB 71 | trainset_gray = RegDBData(data_path, args.trial, transform1=transform_rgb2gray,transform2=transform_thermal) 72 | color_pos_gray, thermal_pos_gray = GenIdx(trainset_gray.train_color_label, trainset_gray.train_thermal_label) 73 | 74 | trainset_rgb = RegDBData(data_path, args.trial, transform1=transform_rgb, transform2=transform_thermal) 75 | color_pos_rgb, thermal_pos_rgb = GenIdx(trainset_rgb.train_color_label, trainset_rgb.train_thermal_label) 76 | print('Current trial :', 
args.trial) 77 | 78 | 79 | num_classes = len(np.unique(trainset_rgb.train_color_label)) 80 | model = build_vision_transformer(num_classes=num_classes,cfg = cfg) 81 | model.to(device) 82 | 83 | # load checkpoint 84 | if len(args.resume) > 0: 85 | model_path = args.model_path + args.resume 86 | if os.path.isfile(model_path): 87 | print('==> loading checkpoint {}'.format(args.resume)) 88 | model.load_param(model_path) 89 | print('==> loaded checkpoint {}'.format(args.resume)) 90 | else: 91 | print('==> no checkpoint found at {}'.format(args.resume)) 92 | 93 | # Loss 94 | criterion_ID = nn.CrossEntropyLoss() 95 | criterion_Tri = TripletLoss(margin=cfg.MARGIN, feat_norm='no') 96 | criterion_DCL = DCL(num_pos=cfg.NUM_POS, feat_norm='no') 97 | criterion_MSEL = MSEL(num_pos=cfg.NUM_POS, feat_norm='no') 98 | 99 | optimizer = make_optimizer(cfg, model) 100 | scheduler = create_scheduler(cfg, optimizer) 101 | 102 | scaler = amp.GradScaler() 103 | 104 | 105 | if cfg.DATASET == 'sysu': # for test 106 | query_img, query_label, query_cam = process_query_sysu(data_path, mode='all') 107 | queryset = TestData(query_img, query_label, transform=transform_test, img_size=(cfg.W, cfg.H)) 108 | 109 | gall_img, gall_label, gall_cam = process_gallery_sysu(data_path, mode='all', trial=0, gall_mode='single') 110 | gallset = TestData(gall_img, gall_label, transform=transform_test, img_size=(cfg.W, cfg.H)) 111 | 112 | elif cfg.DATASET == 'regdb': 113 | query_img, query_label = process_test_regdb(data_path, trial=args.trial, modal='visible') 114 | queryset = TestData(query_img, query_label, transform=transform_test, img_size=(cfg.W, cfg.H)) 115 | 116 | gall_img, gall_label = process_test_regdb(data_path, trial=args.trial, modal='thermal') 117 | gallset = TestData(gall_img, gall_label, transform=transform_test, img_size=(cfg.W, cfg.H)) 118 | 119 | # Test loader 120 | query_loader = data.DataLoader(queryset, batch_size=args.test_batch, shuffle=False, num_workers=args.num_workers) 121 | gall_loader = data.DataLoader(gallset, batch_size=args.test_batch, shuffle=False, num_workers=args.num_workers) 122 | 123 | 124 | loss_meter = AverageMeter() 125 | loss_ce_meter = AverageMeter() 126 | loss_tri_meter = AverageMeter() 127 | acc_rgb_meter = AverageMeter() 128 | acc_ir_meter = AverageMeter() 129 | 130 | 131 | def train(epoch): 132 | start_time = time.time() 133 | 134 | loss_meter.reset() 135 | loss_ce_meter.reset() 136 | loss_tri_meter.reset() 137 | acc_rgb_meter.reset() 138 | acc_ir_meter.reset() 139 | 140 | scheduler.step(epoch) 141 | model.train() 142 | 143 | for idx, (input1, input2, label1, label2) in enumerate(trainloader): 144 | 145 | optimizer.zero_grad() 146 | input1 = input1.to(device) 147 | input2 = input2.to(device) 148 | label1 = label1.to(device) 149 | label2 = label2.to(device) 150 | labels = torch.cat((label1,label2),0) 151 | 152 | with amp.autocast(enabled=True): 153 | scores, feats = model(torch.cat([input1,input2])) 154 | 155 | score1, score2 = scores.chunk(2,0) 156 | feat1, feat2 = feats.chunk(2,0) 157 | loss_id = criterion_ID(score1, label1.long()) + criterion_ID(score2, label2.long()) 158 | 159 | if cfg.METHOD == 'PMT': 160 | if epoch <= cfg.PL_EPOCH : 161 | loss_tri = criterion_Tri(feat1, feat1, label1) + criterion_Tri(feat2, feat2, label2) # intra 162 | loss = loss_id + loss_tri 163 | 164 | else: 165 | loss_dcl = criterion_DCL(feats, labels) 166 | loss_msel = criterion_MSEL(feats, labels) 167 | 168 | loss_tri = criterion_Tri(feats, feats, labels) 169 | 170 | loss = loss_id + loss_tri + cfg.DCL * loss_dcl 
+ cfg.MSEL * loss_msel 171 | 172 | else: 173 | loss_tri = criterion_Tri(feats, feats, labels) 174 | loss = loss_id + loss_tri 175 | 176 | scaler.scale(loss).backward() 177 | scaler.step(optimizer) 178 | scaler.update() 179 | 180 | acc_rgb = (score1.max(1)[1] == label1).float().mean() 181 | acc_ir = (score2.max(1)[1] == label2).float().mean() 182 | 183 | loss_tri_meter.update(loss_tri.item()) 184 | loss_ce_meter.update(loss_id.item()) 185 | loss_meter.update(loss.item()) 186 | 187 | acc_rgb_meter.update(acc_rgb, 1) 188 | acc_ir_meter.update(acc_ir, 1) 189 | 190 | torch.cuda.synchronize() 191 | 192 | if (idx + 1) % 32 == 0: 193 | print('Epoch[{}] Iteration[{}/{}]' 194 | ' Loss: {:.3f}, Tri:{:.3f} CE:{:.3f}, ' 195 | 'Acc_RGB: {:.3f}, Acc_IR: {:.3f}, ' 196 | 'Base Lr: {:.2e} '.format(epoch, (idx+1), 197 | len(trainloader), loss_meter.avg, loss_tri_meter.avg, 198 | loss_ce_meter.avg, acc_rgb_meter.avg, acc_ir_meter.avg, 199 | optimizer.state_dict()['param_groups'][0]['lr'])) 200 | 201 | end_time = time.time() 202 | epoch_time = end_time - start_time 203 | print(' Epoch {} done. Epoch time: {:.1f} [min] '.format(epoch, epoch_time/60)) 204 | 205 | 206 | def test(query_loader, gall_loader, dataset='sysu'): 207 | model.eval() 208 | nquery = len(query_label) 209 | ngall = len(gall_label) 210 | print('Testing...') 211 | ptr = 0 212 | gall_feat = np.zeros((ngall, 768)) 213 | 214 | with torch.no_grad(): 215 | for batch_idx, (input, label) in enumerate(gall_loader): 216 | batch_num = input.size(0) 217 | input = Variable(input.cuda()) 218 | feat = model(input) 219 | gall_feat[ptr:ptr + batch_num, :] = feat.detach().cpu().numpy() 220 | ptr = ptr + batch_num 221 | 222 | ptr = 0 223 | query_feat = np.zeros((nquery, 768)) 224 | with torch.no_grad(): 225 | for batch_idx, (input, label) in enumerate(query_loader): 226 | batch_num = input.size(0) 227 | input = Variable(input.cuda()) 228 | feat = model(input) 229 | query_feat[ptr:ptr + batch_num, :] = feat.detach().cpu().numpy() 230 | ptr = ptr + batch_num 231 | 232 | distmat = -np.matmul(query_feat, np.transpose(gall_feat)) 233 | if dataset == 'sysu': 234 | cmc, mAP, mInp = eval_sysu(distmat, query_label, gall_label, query_cam, gall_cam) 235 | else: 236 | cmc, mAP, mInp = eval_regdb(distmat, query_label, gall_label) 237 | 238 | return cmc, mAP, mInp 239 | 240 | 241 | # Training 242 | best_mAP = 0 243 | print('==> Start Training...') 244 | for epoch in range(cfg.START_EPOCH, cfg.MAX_EPOCH + 1): 245 | 246 | print('==> Preparing Data Loader...') 247 | 248 | sampler_rgb = IdentitySampler(trainset_rgb.train_color_label, trainset_rgb.train_thermal_label, 249 | color_pos_rgb,thermal_pos_rgb, cfg.BATCH_SIZE, per_img=cfg.NUM_POS) 250 | 251 | # RGB-IR 252 | trainset_rgb.cIndex = sampler_rgb.index1 # color index 253 | trainset_rgb.tIndex = sampler_rgb.index2 254 | 255 | if cfg.METHOD == 'PMT': 256 | if epoch <= cfg.PL_EPOCH: 257 | sampler_gray = IdentitySampler(trainset_gray.train_color_label, trainset_gray.train_thermal_label, 258 | color_pos_gray, thermal_pos_gray, cfg.BATCH_SIZE, per_img=cfg.NUM_POS) # Gray 259 | # Gray-IR 260 | trainset_gray.cIndex = sampler_gray.index1 261 | trainset_gray.tIndex = sampler_gray.index2 262 | 263 | trainloader = data.DataLoader(trainset_gray, batch_size=cfg.BATCH_SIZE, sampler=sampler_gray, 264 | num_workers=args.num_workers,drop_last=True, pin_memory=True) 265 | 266 | else: 267 | trainloader = data.DataLoader(trainset_rgb, batch_size=cfg.BATCH_SIZE, sampler=sampler_rgb, 268 | num_workers=args.num_workers,
drop_last=True,pin_memory=True) 269 | 270 | else: 271 | trainloader = data.DataLoader(trainset_rgb, batch_size=cfg.BATCH_SIZE, sampler=sampler_rgb, 272 | num_workers=args.num_workers, drop_last=True, pin_memory=True) 273 | 274 | train(epoch) 275 | 276 | if epoch > args.start_test and epoch % args.test_epoch == 0: 277 | cmc, mAP, mINP = test(query_loader, gall_loader, cfg.DATASET) 278 | print(' mAP: {:.2%} | mInp:{:.2%} | top-1: {:.2%} | top-5: {:.2%} | top-10: {:.2%}| top-20: {:.2%}'.format(mAP,mINP,cmc[0],cmc[4],cmc[9],cmc[19])) 279 | 280 | if mAP > best_mAP: 281 | best_mAP = mAP 282 | if cfg.DATASET == 'sysu': 283 | torch.save(model.state_dict(), osp.join('./save_model', os.path.basename(args.config_file)[:-4] + '_best.pth')) # maybe not the best 284 | else: 285 | torch.save(model.state_dict(), osp.join('./save_model', os.path.basename(args.config_file)[:-4] + '_best_trial_{}.pth'.format(args.trial))) 286 | 287 | if epoch > 20 and epoch % args.save_epoch == 0: 288 | 289 | torch.save(model.state_dict(), osp.join('./save_model', os.path.basename(args.config_file)[:-4] + '_epoch{}.pth'.format(epoch))) 290 | 291 | 292 | 293 | 294 | 295 | 296 | -------------------------------------------------------------------------------- /Pytorch-PMT-VI-ReID/transforms.py: -------------------------------------------------------------------------------- 1 | import random 2 | from torchvision.transforms import * 3 | from PIL import Image 4 | import math 5 | from config.config import cfg 6 | 7 | class RandomErasing(object): 8 | """ Randomly selects a rectangle region in an image and erases its pixels. 9 | 'Random Erasing Data Augmentation' by Zhong et al. 10 | See https://arxiv.org/pdf/1708.04896.pdf 11 | Args: 12 | p: The prob that the Random Erasing operation will be performed. 13 | sl: Minimum proportion of erased area against input image. 14 | sh: Maximum proportion of erased area against input image. 15 | r1: Minimum aspect ratio of erased area. 16 | mean: Erasing value. 
17 | """ 18 | def __init__(self, p=0.5, sl=0.02, sh=0.4, r1=0.3, mean=(0.485, 0.456, 0.406)): 19 | self.p = p 20 | self.mean = mean 21 | self.sl = sl 22 | self.sh = sh 23 | self.r1 = r1 24 | 25 | def __call__(self, img): 26 | if random.uniform(0, 1) >= self.p: 27 | return img 28 | 29 | for attempt in range(100): 30 | area = img.size()[1] * img.size()[2] 31 | 32 | target_area = random.uniform(self.sl, self.sh) * area 33 | aspect_ratio = random.uniform(self.r1, 1 / self.r1) 34 | 35 | h = int(round(math.sqrt(target_area * aspect_ratio))) 36 | w = int(round(math.sqrt(target_area / aspect_ratio))) 37 | 38 | if w < img.size()[2] and h < img.size()[1]: 39 | x1 = random.randint(0, img.size()[1] - h) 40 | y1 = random.randint(0, img.size()[2] - w) 41 | if img.size()[0] == 3: 42 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0] 43 | img[1, x1:x1 + h, y1:y1 + w] = self.mean[1] 44 | img[2, x1:x1 + h, y1:y1 + w] = self.mean[2] 45 | else: 46 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0] 47 | return img 48 | 49 | return img 50 | 51 | 52 | class RectScale(object): 53 | def __init__(self, height, width, interpolation=Image.BILINEAR): 54 | self.height = height 55 | self.width = width 56 | self.interpolation = interpolation 57 | 58 | def __call__(self, img): 59 | w, h = img.size 60 | if h == self.height and w == self.width: 61 | return img 62 | return img.resize((self.width, self.height), self.interpolation) 63 | 64 | 65 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 66 | 67 | 68 | transform_mix_aug = [transforms.ColorJitter(brightness=0.3,contrast=0.3), 69 | transforms.GaussianBlur(21, sigma=(0.1, 3))] 70 | 71 | transform_rgb2gray = transforms.Compose([ 72 | transforms.ToPILImage(), 73 | RectScale(cfg.H, cfg.W), 74 | transforms.RandomHorizontalFlip(), 75 | transforms.Grayscale(num_output_channels=3), 76 | transforms.ToTensor(), 77 | normalize, 78 | RandomErasing(p=0.5) 79 | ]) 80 | 81 | transform_thermal = transforms.Compose([ 82 | transforms.ToPILImage(), 83 | RectScale(cfg.H, cfg.W), 84 | transforms.RandomHorizontalFlip(), 85 | transforms.RandomChoice(transform_mix_aug), 86 | transforms.ToTensor(), 87 | normalize, 88 | RandomErasing(p=0.5) 89 | ]) 90 | 91 | 92 | transform_rgb = transforms.Compose([ 93 | transforms.ToPILImage(), 94 | RectScale(cfg.H, cfg.W), 95 | transforms.RandomHorizontalFlip(), 96 | transforms.ToTensor(), 97 | normalize, 98 | RandomErasing(p=0.5) 99 | ]) 100 | 101 | 102 | transform_test = transforms.Compose([ 103 | transforms.ToPILImage(), 104 | RectScale(cfg.H, cfg.W), 105 | transforms.ToTensor(), 106 | normalize 107 | ]) -------------------------------------------------------------------------------- /Pytorch-PMT-VI-ReID/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import random 4 | 5 | def set_seed(seed): 6 | torch.manual_seed(seed) 7 | torch.cuda.manual_seed_all(seed) 8 | np.random.seed(seed) 9 | random.seed(seed) 10 | torch.backends.cudnn.benchmark = False 11 | torch.backends.cudnn.deterministic = True 12 | 13 | class AverageMeter(object): 14 | """Computes and stores the average and current value""" 15 | def __init__(self): 16 | self.val = 0 17 | self.avg = 0 18 | self.sum = 0 19 | self.count = 0 20 | 21 | def reset(self): 22 | self.val = 0 23 | self.avg = 0 24 | self.sum = 0 25 | self.count = 0 26 | 27 | def update(self, val, n=1): 28 | self.val = val 29 | self.sum += val * n 30 | self.count += n 31 | self.avg = self.sum / self.count 32 | 33 | 34 | 35 |
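# Illustrative usage sketch (an editorial addition, not part of the original file).
# AverageMeter is driven once per batch in train.py; `n` weights each update,
# e.g. n=batch_size when averaging a per-sample statistic. The loss values here
# are hypothetical.
if __name__ == '__main__':
    set_seed(0)
    meter = AverageMeter()
    for batch_loss in [0.93, 0.71, 0.58]:  # hypothetical per-batch losses
        meter.update(batch_loss, n=1)
    print('val: {:.2f}  avg: {:.2f}'.format(meter.val, meter.avg))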
36 | 37 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Progressive Modality-shared Transformer (PMT) 2 | Pytorch code for the paper "**Learning Progressive Modality-shared Transformers for Effective Visible-Infrared** 3 | 4 | **Person Re-identification**". 5 | 6 | ### 1. Results 7 | We adopt the Transformer-based ViT-Base/16 and the CNN-based AGW [3] as backbones, respectively. 8 | 9 | | Datasets | Backbone | Rank@1 | mAP | mINP | Model| 10 | | -------- | ----- | ----- | ----- | ----- |:----:| 11 | | SYSU-MM01 | ViT | ~ 67.53% | ~ 64.98% | ~ 51.86% |[GoogleDrive](https://drive.google.com/file/d/1S7Upn_8dWHNN5R3woazpocFU6J8hvCIe/view?usp=share_link)| 12 | | SYSU-MM01 | AGW* | ~ 67.09% | ~ 64.25% | ~ 50.89% | [GoogleDrive](https://drive.google.com/file/d/1FOvspAdWEtqebAoqt48-bFxq5ebKnUrG/view?usp=share_link)| 13 | 14 | **\*Both models may fluctuate slightly due to random data splitting. AGW\* means AGW with random erasing. Results may improve further with hyper-parameter finetuning.** 15 | 16 | ### 2. Datasets 17 | 18 | - (1) RegDB [1]: The RegDB dataset can be downloaded from this [website](http://dm.dongguk.edu/link.html). 19 | 20 | - (2) SYSU-MM01 [2]: The SYSU-MM01 dataset can be downloaded from this [website](http://isee.sysu.edu.cn/project/RGBIRReID.htm). 21 | 22 | ### 3. Requirements 23 | 24 | #### Prepare Pre-trained Model 25 | 26 | - You may need to download the ImageNet pretrained transformer model [ViT-Base](https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth). 27 | 28 | #### Prepare Training Data 29 | - You need to define the data path and pre-trained model path in `config.py`. 30 | - You need to run `python process_sysu.py` to prepare the SYSU-MM01 dataset; the training data will be stored in `.npy` format. 31 | 32 | 33 | ### 4. Training 34 | 35 | **Train PMT by** 36 | 37 | ``` 38 | python train.py --config_file config/SYSU.yml 39 | ``` 40 | - `--config_file`: the config file path. 41 | 42 | ### 5. Testing 43 | 44 | **Test a model on SYSU-MM01 dataset by** 45 | 46 | ``` 47 | python test.py --dataset 'sysu' --mode 'all' --resume 'model_path' --gall_mode 'single' --gpu 0 48 | ``` 49 | - `--dataset`: which dataset, "sysu" or "regdb". 50 | - `--mode`: "all" or "indoor" (only for sysu dataset). 51 | - `--gall_mode`: "single" or "multi" (only for sysu dataset). 52 | - `--resume`: the saved model path. 53 | - `--gpu`: which gpu to use. 54 | 55 | 56 | 57 | **Test a model on RegDB dataset by** 58 | 59 | ``` 60 | python test.py --dataset 'regdb' --resume 'model_path' --trial 1 --tvsearch True --gpu 0 61 | ``` 62 | 63 | - `--trial`: the testing trial, which should match the trained model (only for regdb dataset). 64 | 65 | - `--tvsearch`: whether to perform thermal-to-visible search, True or False (only for regdb dataset). 66 | 67 | 68 | 69 | ### 6. Citation 70 | 71 | Most of the backbone code is borrowed from [TransReID](https://github.com/damo-cv/TransReID) [4] and [AGW](https://github.com/mangye16/Cross-Modal-Re-ID-baseline) [3]. 72 | 73 | Thanks a lot for the authors' contributions.
74 | 75 | Please cite the following paper in your publications if it is helpful: 76 | 77 | ``` 78 | @article{lu2022learning, 79 | title={Learning Progressive Modality-shared Transformers for Effective Visible-Infrared Person Re-identification}, 80 | author={Lu, Hu and Zou, Xuezhang and Zhang, Pingping}, 81 | journal={arXiv preprint arXiv:2212.00226}, 82 | year={2022} 83 | } 84 | 85 | @inproceedings{he2021transreid, 86 | title={Transreid: Transformer-based object re-identification}, 87 | author={He, Shuting and Luo, Hao and Wang, Pichao and Wang, Fan and Li, Hao and Jiang, Wei}, 88 | booktitle={Proceedings of the IEEE/CVF international conference on computer vision}, 89 | pages={15013--15022}, 90 | year={2021} 91 | } 92 | 93 | @article{ye2021deep, 94 | title={Deep learning for person re-identification: A survey and outlook}, 95 | author={Ye, Mang and Shen, Jianbing and Lin, Gaojie and Xiang, Tao and Shao, Ling and Hoi, Steven CH}, 96 | journal={IEEE transactions on pattern analysis and machine intelligence}, 97 | volume={44}, 98 | number={6}, 99 | pages={2872--2893}, 100 | year={2021}, 101 | publisher={IEEE} 102 | } 103 | ``` 104 | 105 | ### 7. References 106 | 107 | [1] D. T. Nguyen, H. G. Hong, K. W. Kim, and K. R. Park. Person recognition system based on a combination of body images from visible light and thermal cameras. Sensors, 17(3):605, 2017. 108 | 109 | [2] A. Wu, W.-s. Zheng, H.-X. Yu, S. Gong, and J. Lai. RGB-infrared cross-modality person re-identification. In IEEE International Conference on Computer Vision (ICCV), pages 5380–5389, 2017. 110 | 111 | [3] M. Ye, J. Shen, G. Lin, T. Xiang, L. Shao, and S. C. H. Hoi. Deep learning for person re-identification: A survey and outlook. IEEE Transactions on Pattern Analysis and Machine Intelligence, 44(6):2872–2893, 2021. 112 | 113 | [4] S. He, H. Luo, P. Wang, F. Wang, H. Li, and W. Jiang. TransReID: Transformer-based object re-identification. In IEEE/CVF International Conference on Computer Vision (ICCV), pages 15013–15022, 2021. 114 | --------------------------------------------------------------------------------
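For readers tracing how `create_scheduler` in `scheduler.py` shapes the learning rate, below is a minimal, self-contained sketch (an illustration, not part of the repository) that reproduces the warmup-plus-cosine curve used by `train.py`. It assumes `scheduler.py` is importable from the repo root and plugs in the default values from `config/config.py` (BASE_LR=3e-4, LR_MIN=0.01, LR_INIT=0.01, WARMUP_EPOCHS=3, MAX_EPOCH=24); the dummy model and optimizer exist only to drive the scheduler.

```python
import torch
from scheduler import CosineLRScheduler  # this repo's scheduler.py

# Dummy model/optimizer purely to give the scheduler param groups to update.
model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)  # cfg.BASE_LR

# Mirrors create_scheduler(cfg, optimizer) with the shipped defaults:
# lr_min = LR_MIN * BASE_LR and warmup_lr_init = LR_INIT * BASE_LR.
scheduler = CosineLRScheduler(
    optimizer,
    t_initial=24,                # cfg.MAX_EPOCH
    lr_min=0.01 * 3e-4,          # cfg.LR_MIN * cfg.BASE_LR
    warmup_lr_init=0.01 * 3e-4,  # cfg.LR_INIT * cfg.BASE_LR
    warmup_t=3,                  # cfg.WARMUP_EPOCHS
    cycle_limit=1,
    t_in_epochs=True,
)

# train.py calls scheduler.step(epoch) at the start of each epoch (1-indexed):
# epochs 1-2 ramp linearly from warmup_lr_init, then the LR follows a single
# cosine arc from roughly BASE_LR down toward lr_min.
for epoch in range(1, 25):
    scheduler.step(epoch)
    print('epoch {:2d}: lr = {:.2e}'.format(epoch, optimizer.param_groups[0]['lr']))
```

Because `cycle_limit=1` and `t_initial` equals the total epoch budget, the schedule is one warmup-then-cosine pass with no restarts; the restart machinery in `CosineLRScheduler` is simply left unused by this configuration.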