├── README.md ├── data_loader.py ├── data_manager.py ├── eval_metrics.py ├── loss.py ├── model.py ├── pre_process_sysu.py ├── random_erasing.py ├── resnet.py ├── test.py ├── train.py └── utils.py /README.md: --------------------------------------------------------------------------------
1 | # Towards-a-Unified-Middle-Modality-Learning-for-Visible-Infrared-Person-Re-Identification
2 | 
3 | [Paper](https://dl.acm.org/doi/10.1145/3474085.3475250)
4 | 
5 | This repository contains the PyTorch code for our proposed MMN method for cross-modality person re-identification.
6 | 
7 | 
8 | ### Training.
9 | Train a model by
10 | ```bash
11 | python train.py --dataset sysu
12 | ```
13 | 
14 | - `--dataset`: which dataset to use, "sysu" or "regdb".
15 | 
16 | ### Results.
17 | 
18 | The results may fluctuate slightly across runs and might be improved by fine-tuning the hyper-parameters.
19 | 
20 | 
21 | |Datasets | Rank@1 | mAP |
22 | | -------- | ----- | ----- |
23 | |RegDB [1] | ~ 91.6% | ~ 84.1% |
24 | |SYSU-MM01 [2] | ~ 70.6% | ~ 66.9% |
25 | 
26 | 
27 | ### Citation
28 | 
29 | Please kindly cite this paper in your publications if it helps your research:
30 | ```
31 | @inproceedings{zhang2021towards,
32 |   title={Towards a Unified Middle Modality Learning for Visible-Infrared Person Re-Identification},
33 |   author={Zhang, Yukang and Yan, Yan and Lu, Yang and Wang, Hanzi},
34 |   booktitle={Proceedings of the 29th ACM International Conference on Multimedia},
35 |   pages={788--796},
36 |   year={2021}
37 | }
38 | ```
39 | 
40 | Our code is based on [mangye16](https://github.com/mangye16/Cross-Modal-Re-ID-baseline) [3, 4].
41 | 
42 | ### References.
43 | 
44 | 
45 | [1] D. T. Nguyen, H. G. Hong, K. W. Kim, and K. R. Park. Person recognition system based on a combination of body images from visible light and thermal cameras. Sensors, 17(3):605, 2017.
46 | 
47 | [2] A. Wu, W.-S. Zheng, H.-X. Yu, S. Gong, and J. Lai. RGB-infrared cross-modality person re-identification. In IEEE International Conference on Computer Vision (ICCV), pages 5380–5389, 2017.
48 | 
49 | [3] M. Ye, J. Shen, G. Lin, T. Xiang, L. Shao, and S. C. H. Hoi. Deep learning for person re-identification: A survey and outlook. IEEE Transactions on Pattern Analysis and Machine Intelligence (TPAMI), 2020.
50 | 
51 | [4] M. Ye, X. Lan, Z. Wang, and P. C. Yuen. Bi-directional center-constrained top-ranking for visible thermal person re-identification. IEEE Transactions on Information Forensics and Security (TIFS), 2019.
52 | 
53 | 
54 | If you have any questions, please feel free to contact us: zhangyk@stu.xmu.edu.cn.
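### Testing.

The test entry point is `test.py`; a typical invocation (the checkpoint name below is the script's default and stands for whatever was saved under `--model_path` during training) looks like

```bash
python test.py --dataset sysu --mode all --resume 'sysu_agw_p4_n4_lr_0.1_seed_0_best.t' --gpu 0
```

- `--mode`: "all" or "indoor" evaluation protocol (SYSU-MM01 only).
- `--resume`: checkpoint file name under `--model_path` (default `save_model/`).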
55 | 
-------------------------------------------------------------------------------- /data_loader.py: --------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 | import torch.utils.data as data
4 | 
5 | 
6 | class SYSUData(data.Dataset):
7 |     def __init__(self, data_dir, transform=None, colorIndex = None, thermalIndex = None):
8 | 
9 |         # Load preprocessed training images and labels
10 |         train_color_image = np.load(data_dir + 'train_rgb_resized_img.npy')
11 |         self.train_color_label = np.load(data_dir + 'train_rgb_resized_label.npy')
12 | 
13 |         train_thermal_image = np.load(data_dir + 'train_ir_resized_img.npy')
14 |         self.train_thermal_label = np.load(data_dir + 'train_ir_resized_label.npy')
15 | 
16 |         # images were already saved as RGB arrays by pre_process_sysu.py
17 |         self.train_color_image = train_color_image
18 |         self.train_thermal_image = train_thermal_image
19 |         self.transform = transform
20 |         self.cIndex = colorIndex
21 |         self.tIndex = thermalIndex
22 | 
23 |     def __getitem__(self, index):
24 | 
25 |         img1, target1 = self.train_color_image[self.cIndex[index]], self.train_color_label[self.cIndex[index]]
26 |         img2, target2 = self.train_thermal_image[self.tIndex[index]], self.train_thermal_label[self.tIndex[index]]
27 | 
28 |         img1 = self.transform(img1)
29 |         img2 = self.transform(img2)
30 | 
31 |         return img1, img2, target1, target2
32 | 
33 |     def __len__(self):
34 |         return len(self.train_color_label)
35 | 
36 | 
37 | class RegDBData(data.Dataset):
38 |     def __init__(self, data_dir, trial, transform=None, colorIndex = None, thermalIndex = None):
39 |         # Load training image paths and labels
40 |         train_color_list = data_dir + 'idx/train_visible_{}'.format(trial)+ '.txt'
41 |         train_thermal_list = data_dir + 'idx/train_thermal_{}'.format(trial)+ '.txt'
42 | 
43 |         color_img_file, train_color_label = load_data(train_color_list)
44 |         thermal_img_file, train_thermal_label = load_data(train_thermal_list)
45 | 
46 |         train_color_image = []
47 |         for i in range(len(color_img_file)):
48 | 
49 |             img = Image.open(data_dir+ color_img_file[i])
50 |             img = img.resize((192, 384), Image.ANTIALIAS)
51 |             pix_array = np.array(img)
52 |             train_color_image.append(pix_array)
53 |         train_color_image = np.array(train_color_image)
54 | 
55 |         train_thermal_image = []
56 |         for i in range(len(thermal_img_file)):
57 |             img = Image.open(data_dir+ thermal_img_file[i])
58 |             img = img.resize((192, 384), Image.ANTIALIAS)
59 |             pix_array = np.array(img)
60 |             train_thermal_image.append(pix_array)
61 |         train_thermal_image = np.array(train_thermal_image)
62 | 
63 |         # images are loaded as RGB via PIL
64 |         self.train_color_image = train_color_image
65 |         self.train_color_label = train_color_label
66 | 
67 |         # images are loaded as RGB via PIL
68 |         self.train_thermal_image = train_thermal_image
69 |         self.train_thermal_label = train_thermal_label
70 | 
71 |         self.transform = transform
72 |         self.cIndex = colorIndex
73 |         self.tIndex = thermalIndex
74 | 
75 |     def __getitem__(self, index):
76 | 
77 |         img1, target1 = self.train_color_image[self.cIndex[index]], self.train_color_label[self.cIndex[index]]
78 |         img2, target2 = self.train_thermal_image[self.tIndex[index]], self.train_thermal_label[self.tIndex[index]]
79 | 
80 |         img1 = self.transform(img1)
81 |         img2 = self.transform(img2)
82 | 
83 |         return img1, img2, target1, target2
84 | 
85 |     def __len__(self):
86 |         return len(self.train_color_label)
87 | 
88 | class TestData(data.Dataset):
89 |     def __init__(self, test_img_file, test_label, transform=None, img_size = (192,384)):
90 | 
91 |         test_image = []
92 |         for i in range(len(test_img_file)):
93 |             img = Image.open(test_img_file[i])
94 |             img = img.resize((img_size[0], img_size[1]), Image.ANTIALIAS)
95 |             pix_array = np.array(img)
96 |             test_image.append(pix_array)
97 |         test_image = np.array(test_image)
98 |         self.test_image = test_image
99 |         self.test_label = test_label
100 |         self.transform = transform
101 | 
102 |     def __getitem__(self, index):
103 |         img1, target1 = self.test_image[index], self.test_label[index]
104 |         img1 = self.transform(img1)
105 |         return img1, target1
106 | 
107 |     def __len__(self):
108 |         return len(self.test_image)
109 | 
110 | class TestDataOld(data.Dataset):
111 |     def __init__(self, data_dir, test_img_file, test_label, transform=None, img_size = (192,384)):
112 | 
113 |         test_image = []
114 |         for i in range(len(test_img_file)):
115 |             img = Image.open(data_dir + test_img_file[i])
116 |             img = img.resize((img_size[0], img_size[1]), Image.ANTIALIAS)
117 |             pix_array = np.array(img)
118 |             test_image.append(pix_array)
119 |         test_image = np.array(test_image)
120 |         self.test_image = test_image
121 |         self.test_label = test_label
122 |         self.transform = transform
123 | 
124 |     def __getitem__(self, index):
125 |         img1, target1 = self.test_image[index], self.test_label[index]
126 |         img1 = self.transform(img1)
127 |         return img1, target1
128 | 
129 |     def __len__(self):
130 |         return len(self.test_image)
131 | def load_data(input_data_path):
132 |     with open(input_data_path, 'rt') as f:
133 |         data_file_list = f.read().splitlines()
134 |     # Get full list of image paths and labels
135 |     file_image = [s.split(' ')[0] for s in data_file_list]
136 |     file_label = [int(s.split(' ')[1]) for s in data_file_list]
137 | 
138 |     return file_image, file_label
139 | 
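A quick usage sketch for the loaders above, assuming SYSU-MM01 has been preprocessed with `pre_process_sysu.py`; the transform mirrors the one in `test.py`/`train.py`, and the manual index assignment is only a smoke test (training normally sets `cIndex`/`tIndex` via the identity sampler in `train.py`):

```python
import numpy as np
import torchvision.transforms as transforms
from data_loader import SYSUData

transform_train = transforms.Compose([
    transforms.ToPILImage(),
    transforms.RandomCrop((384, 192)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

data_path = './Datasets/SYSU-MM01/'
trainset = SYSUData(data_path, transform=transform_train)
# Pair the first few images of each modality directly, just to exercise __getitem__.
trainset.cIndex = np.arange(4)
trainset.tIndex = np.arange(4)
img1, img2, t1, t2 = trainset[0]
print(img1.shape, img2.shape, t1, t2)  # torch.Size([3, 384, 192]) ...
```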
-------------------------------------------------------------------------------- /data_manager.py: --------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 | import os
3 | import numpy as np
4 | import random
5 | 
6 | def process_query_sysu(data_path, mode = 'all', relabel=False):
7 |     if mode== 'all':
8 |         ir_cameras = ['cam3','cam6']
9 |     elif mode =='indoor':
10 |         ir_cameras = ['cam3','cam6']
11 | 
12 |     file_path = os.path.join(data_path,'exp/test_id.txt')
13 |     files_rgb = []
14 |     files_ir = []
15 | 
16 |     with open(file_path, 'r') as file:
17 |         ids = file.read().splitlines()
18 |         ids = [int(y) for y in ids[0].split(',')]
19 |         ids = ["%04d" % x for x in ids]
20 | 
21 |     for id in sorted(ids):
22 |         for cam in ir_cameras:
23 |             img_dir = os.path.join(data_path,cam,id)
24 |             if os.path.isdir(img_dir):
25 |                 new_files = sorted([img_dir+'/'+i for i in os.listdir(img_dir)])
26 |                 files_ir.extend(new_files)
27 |     query_img = []
28 |     query_id = []
29 |     query_cam = []
30 |     for img_path in files_ir:
31 |         camid, pid = int(img_path[-15]), int(img_path[-13:-9])
32 |         query_img.append(img_path)
33 |         query_id.append(pid)
34 |         query_cam.append(camid)
35 |     return query_img, np.array(query_id), np.array(query_cam)
36 | 
37 | def process_gallery_sysu(data_path, mode = 'all', trial = 0, relabel=False):
38 | 
39 |     random.seed(trial)
40 | 
41 |     if mode== 'all':
42 |         rgb_cameras = ['cam1','cam2','cam4','cam5']
43 |     elif mode =='indoor':
44 |         rgb_cameras = ['cam1','cam2']
45 | 
46 |     file_path = os.path.join(data_path,'exp/test_id.txt')
47 |     files_rgb = []
48 |     with open(file_path, 'r') as file:
49 |         ids = file.read().splitlines()
50 |         ids = [int(y) for y in ids[0].split(',')]
51 |         ids = ["%04d" % x for x in ids]
52 | 
53 |     for id in sorted(ids):
54 |         for cam in rgb_cameras:
55 |             img_dir = os.path.join(data_path,cam,id)
56 |             if os.path.isdir(img_dir):
57 |                 new_files = sorted([img_dir+'/'+i for i in os.listdir(img_dir)])
58 |                 files_rgb.append(random.choice(new_files))
59 |     gall_img = []
60 |     gall_id = []
61 |     gall_cam = []
62 |     for img_path in files_rgb:
63 |         camid, pid = int(img_path[-15]), int(img_path[-13:-9])
64 |         gall_img.append(img_path)
65 |         gall_id.append(pid)
66 |         gall_cam.append(camid)
67 |     return gall_img, np.array(gall_id), np.array(gall_cam)
68 | 
69 | def process_test_regdb(img_dir, trial = 1, modal = 'visible'):
70 |     if modal=='visible':
71 |         input_data_path = img_dir + 'idx/test_visible_{}'.format(trial) + '.txt'
72 |     elif modal=='thermal':
73 |         input_data_path = img_dir + 'idx/test_thermal_{}'.format(trial) + '.txt'
74 | 
75 |     with open(input_data_path, 'rt') as f:
76 |         data_file_list = f.read().splitlines()
77 |     # Get full list of image paths and labels
78 |     file_image = [img_dir + '/' + s.split(' ')[0] for s in data_file_list]
79 |     file_label = [int(s.split(' ')[1]) for s in data_file_list]
80 | 
81 |     return file_image, np.array(file_label)
-------------------------------------------------------------------------------- /eval_metrics.py: --------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 | import numpy as np
3 | """Cross-Modality ReID"""
4 | import pdb
5 | 
6 | def eval_sysu(distmat, q_pids, g_pids, q_camids, g_camids, max_rank = 20):
7 |     """Evaluation with SYSU metric.
8 |     Key: for each query identity, its gallery images from the same camera view are discarded, following the original setting of the dataset.
9 |     """
10 |     num_q, num_g = distmat.shape
11 |     if num_g < max_rank:
12 |         max_rank = num_g
13 |         print("Note: number of gallery samples is quite small, got {}".format(num_g))
14 |     indices = np.argsort(distmat, axis=1)
15 |     pred_label = g_pids[indices]
16 |     matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)
17 | 
18 |     # compute cmc curve for each query
19 |     new_all_cmc = []
20 |     all_cmc = []
21 |     all_AP = []
22 |     all_INP = []
23 |     num_valid_q = 0. # number of valid query
24 |     for q_idx in range(num_q):
25 |         # get query pid and camid
26 |         q_pid = q_pids[q_idx]
27 |         q_camid = q_camids[q_idx]
28 | 
29 |         # remove cam2 gallery samples for cam3 queries (the two cameras are in the same location in SYSU-MM01)
30 |         order = indices[q_idx]
31 |         remove = (q_camid == 3) & (g_camids[order] == 2)
32 |         keep = np.invert(remove)
33 | 
34 |         # compute cmc curve
35 |         # the cmc calculation is different from standard protocol
36 |         # we follow the protocol of the author's released code
37 |         new_cmc = pred_label[q_idx][keep]
38 |         new_index = np.unique(new_cmc, return_index=True)[1]
39 |         new_cmc = [new_cmc[index] for index in sorted(new_index)]
40 | 
41 |         new_match = (new_cmc == q_pid).astype(np.int32)
42 |         new_cmc = new_match.cumsum()
43 |         new_all_cmc.append(new_cmc[:max_rank])
44 | 
45 |         orig_cmc = matches[q_idx][keep] # binary vector, positions with value 1 are correct matches
46 |         if not np.any(orig_cmc):
47 |             # this condition is true when query identity does not appear in gallery
48 |             continue
49 | 
50 |         cmc = orig_cmc.cumsum()
51 | 
52 |         # compute mINP
53 |         # reference: Deep Learning for Person Re-identification: A Survey and Outlook
54 |         pos_idx = np.where(orig_cmc == 1)
55 |         pos_max_idx = np.max(pos_idx)
56 |         inp = cmc[pos_max_idx]/ (pos_max_idx + 1.0)
57 |         all_INP.append(inp)
58 | 
59 |         cmc[cmc > 1] = 1
60 | 
61 |         all_cmc.append(cmc[:max_rank])
62 |         num_valid_q += 1.
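        # Worked example for the AP computed below: if orig_cmc = [1, 0, 1],
        # precision@k over ranks is [1.0, 0.5, 0.667]; masking by orig_cmc keeps
        # only the correct ranks, so AP = (1.0 + 0.667) / 2 ≈ 0.833.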
63 | 
64 |         # compute average precision
65 |         # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision
66 |         num_rel = orig_cmc.sum()
67 |         tmp_cmc = orig_cmc.cumsum()
68 |         tmp_cmc = [x / (i+1.) for i, x in enumerate(tmp_cmc)]
69 |         tmp_cmc = np.asarray(tmp_cmc) * orig_cmc
70 |         AP = tmp_cmc.sum() / num_rel
71 |         all_AP.append(AP)
72 | 
73 |     assert num_valid_q > 0, "Error: all query identities do not appear in gallery"
74 | 
75 |     all_cmc = np.asarray(all_cmc).astype(np.float32)
76 |     all_cmc = all_cmc.sum(0) / num_valid_q # standard CMC
77 | 
78 |     new_all_cmc = np.asarray(new_all_cmc).astype(np.float32)
79 |     new_all_cmc = new_all_cmc.sum(0) / num_valid_q
80 |     mAP = np.mean(all_AP)
81 |     mINP = np.mean(all_INP)
82 |     return new_all_cmc, mAP, mINP
83 | 
84 | 
85 | 
86 | def eval_regdb(distmat, q_pids, g_pids, max_rank = 20):
87 |     num_q, num_g = distmat.shape
88 |     if num_g < max_rank:
89 |         max_rank = num_g
90 |         print("Note: number of gallery samples is quite small, got {}".format(num_g))
91 |     indices = np.argsort(distmat, axis=1)
92 |     matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)
93 | 
94 |     # compute cmc curve for each query
95 |     all_cmc = []
96 |     all_AP = []
97 |     all_INP = []
98 |     num_valid_q = 0. # number of valid query
99 | 
100 |     # only two cameras
101 |     q_camids = np.ones(num_q).astype(np.int32)
102 |     g_camids = 2* np.ones(num_g).astype(np.int32)
103 | 
104 |     for q_idx in range(num_q):
105 |         # get query pid and camid
106 |         q_pid = q_pids[q_idx]
107 |         q_camid = q_camids[q_idx]
108 | 
109 |         # remove gallery samples that have the same pid and camid with query
110 |         order = indices[q_idx]
111 |         remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid)
112 |         keep = np.invert(remove)
113 | 
114 |         # compute cmc curve
115 |         raw_cmc = matches[q_idx][keep] # binary vector, positions with value 1 are correct matches
116 |         if not np.any(raw_cmc):
117 |             # this condition is true when query identity does not appear in gallery
118 |             continue
119 | 
120 |         cmc = raw_cmc.cumsum()
121 | 
122 |         # compute mINP
123 |         # reference: Deep Learning for Person Re-identification: A Survey and Outlook
124 |         pos_idx = np.where(raw_cmc == 1)
125 |         pos_max_idx = np.max(pos_idx)
126 |         inp = cmc[pos_max_idx]/ (pos_max_idx + 1.0)
127 |         all_INP.append(inp)
128 | 
129 |         cmc[cmc > 1] = 1
130 | 
131 |         all_cmc.append(cmc[:max_rank])
132 |         num_valid_q += 1.
133 | 
134 |         # compute average precision
135 |         # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision
136 |         num_rel = raw_cmc.sum()
137 |         tmp_cmc = raw_cmc.cumsum()
138 |         tmp_cmc = [x / (i+1.)
for i, x in enumerate(tmp_cmc)] 139 | tmp_cmc = np.asarray(tmp_cmc) * raw_cmc 140 | AP = tmp_cmc.sum() / num_rel 141 | all_AP.append(AP) 142 | 143 | assert num_valid_q > 0, "Error: all query identities do not appear in gallery" 144 | 145 | all_cmc = np.asarray(all_cmc).astype(np.float32) 146 | all_cmc = all_cmc.sum(0) / num_valid_q 147 | mAP = np.mean(all_AP) 148 | mINP = np.mean(all_INP) 149 | return all_cmc, mAP, mINP -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.autograd.function import Function 6 | from torch.autograd import Variable 7 | 8 | 9 | class DCLoss(nn.Module): 10 | def __init__(self, num=2): 11 | super(DCLoss, self).__init__() 12 | self.num = num 13 | self.fc1 = nn.Sequential(nn.Linear(2048, 256, bias=False)) 14 | self.fc2 = nn.Sequential(nn.BatchNorm1d(256), nn.ReLU(), nn.Linear(256, 256), nn.BatchNorm1d(256)) 15 | 16 | def forward(self, x): 17 | x = F.normalize(x) 18 | x = self.fc1(x) 19 | x = self.fc2(x) 20 | x = F.normalize(x) 21 | loss = 0 22 | num = int(x.size(0) / self.num) 23 | for i in range(self.num): 24 | for j in range(self.num): 25 | if i dist_ap3: 150 | 151 | loss1 = torch.abs(dist_ap2 - dist_ap3.detach())# + dist_an2.detach() - dist_an3 152 | else: 153 | loss1 = torch.abs(dist_ap2.detach() - dist_ap3)# + dist_an2.detach() - dist_an3 154 | 155 | return loss1# + loss2 156 | 157 | 158 | 159 | 160 | 161 | # Adaptive weights 162 | def softmax_weights(dist, mask): 163 | max_v = torch.max(dist * mask, dim=1, keepdim=True)[0] 164 | diff = dist - max_v 165 | Z = torch.sum(torch.exp(diff) * mask, dim=1, keepdim=True) + 1e-6 # avoid division by zero 166 | W = torch.exp(diff) * mask / Z 167 | return W 168 | 169 | def normalize(x, axis=-1): 170 | """Normalizing to unit length along the specified dimension. 171 | Args: 172 | x: pytorch Variable 173 | Returns: 174 | x: pytorch Variable, same shape as input 175 | """ 176 | x = 1. 
* x / (torch.norm(x, 2, axis, keepdim=True).expand_as(x) + 1e-12)
177 |     return x
178 | 
179 | class TripletLoss_WRT(nn.Module):
180 |     """Weighted Regularized Triplet."""
181 | 
182 |     def __init__(self):
183 |         super(TripletLoss_WRT, self).__init__()
184 |         self.ranking_loss = nn.SoftMarginLoss()
185 | 
186 |     def forward(self, inputs, targets, normalize_feature=False):
187 |         if normalize_feature:
188 |             inputs = normalize(inputs, axis=-1)
189 |         dist_mat = pdist_torch(inputs, inputs)
190 | 
191 |         N = dist_mat.size(0)
192 |         # shape [N, N]
193 |         is_pos = targets.expand(N, N).eq(targets.expand(N, N).t()).float()
194 |         is_neg = targets.expand(N, N).ne(targets.expand(N, N).t()).float()
195 | 
196 |         # `dist_ap` means distance(anchor, positive)
197 |         # both `dist_ap` and `relative_p_inds` with shape [N, 1]
198 |         dist_ap = dist_mat * is_pos
199 |         dist_an = dist_mat * is_neg
200 | 
201 |         weights_ap = softmax_weights(dist_ap, is_pos)
202 |         weights_an = softmax_weights(-dist_an, is_neg)
203 |         furthest_positive = torch.sum(dist_ap * weights_ap, dim=1)
204 |         closest_negative = torch.sum(dist_an * weights_an, dim=1)
205 | 
206 |         y = furthest_positive.new().resize_as_(furthest_positive).fill_(1)
207 |         loss = self.ranking_loss(closest_negative - furthest_positive, y)
208 | 
209 | 
210 |         # compute accuracy
211 |         correct = torch.ge(closest_negative, furthest_positive).sum().item()
212 |         return loss, correct
213 | 
214 | def pdist_torch(emb1, emb2):
215 |     '''
216 |     compute the Euclidean distance matrix between embeddings1 and embeddings2
217 |     using gpu
218 |     '''
219 |     m, n = emb1.shape[0], emb2.shape[0]
220 |     emb1_pow = torch.pow(emb1, 2).sum(dim = 1, keepdim = True).expand(m, n)
221 |     emb2_pow = torch.pow(emb2, 2).sum(dim = 1, keepdim = True).expand(n, m).t()
222 |     dist_mtx = emb1_pow + emb2_pow
223 |     dist_mtx = dist_mtx.addmm_(1, -2, emb1, emb2.t())
224 |     # dist_mtx = dist_mtx.clamp(min = 1e-12)
225 |     dist_mtx = dist_mtx.clamp(min = 1e-12).sqrt()
226 |     return dist_mtx
227 | 
228 | 
229 | def pdist_np(emb1, emb2):
230 |     '''
231 |     compute the Euclidean distance matrix between embeddings1 and embeddings2
232 |     using cpu
233 |     '''
234 |     m, n = emb1.shape[0], emb2.shape[0]
235 |     emb1_pow = np.square(emb1).sum(axis = 1)[..., np.newaxis]
236 |     emb2_pow = np.square(emb2).sum(axis = 1)[np.newaxis, ...]
237 |     dist_mtx = -2 * np.matmul(emb1, emb2.T) + emb1_pow + emb2_pow
238 |     # dist_mtx = np.sqrt(dist_mtx.clip(min = 1e-12))
239 |     return dist_mtx
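A brief usage sketch for `TripletLoss_WRT` above. The batch layout and feature dimension are illustrative only, and this assumes the older PyTorch this repo targets (`pdist_torch` relies on the legacy `addmm_(beta, alpha, mat1, mat2)` signature):

```python
import torch
from loss import TripletLoss_WRT

# Illustrative PK-style batch: 2 identities x 4 samples, 2048-d features.
feats = torch.randn(8, 2048)
labels = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1])

criterion = TripletLoss_WRT()
loss, correct = criterion(feats, labels, normalize_feature=True)
# Soft-margin loss over softmax-weighted positives/negatives, plus the number
# of anchors whose weighted closest negative is farther than the furthest positive.
print(loss.item(), correct)
```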
-------------------------------------------------------------------------------- /model.py: --------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import init
4 | from resnet import resnet50, resnet18
5 | import torch.nn.functional as F
6 | 
7 | class Normalize(nn.Module):
8 |     def __init__(self, power=2):
9 |         super(Normalize, self).__init__()
10 |         self.power = power
11 | 
12 |     def forward(self, x):
13 |         norm = x.pow(self.power).sum(1, keepdim=True).pow(1. / self.power)
14 |         out = x.div(norm)
15 |         return out
16 | 
17 | 
18 | # #####################################################################
19 | def weights_init_kaiming(m):
20 |     classname = m.__class__.__name__
21 |     # print(classname)
22 |     if classname.find('Conv') != -1:
23 |         init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
24 |     elif classname.find('Linear') != -1:
25 |         init.kaiming_normal_(m.weight.data, a=0, mode='fan_out')
26 |         init.zeros_(m.bias.data)
27 |     elif classname.find('BatchNorm1d') != -1:
28 |         init.normal_(m.weight.data, 1.0, 0.01)
29 |         init.zeros_(m.bias.data)
30 | 
31 | def weights_init_classifier(m):
32 |     classname = m.__class__.__name__
33 |     if classname.find('Linear') != -1:
34 |         init.normal_(m.weight.data, 0, 0.001)
35 |         if m.bias is not None:
36 |             init.zeros_(m.bias.data)
37 | 
38 | 
39 | def my_weights_init(m):
40 |     if isinstance(m, nn.Linear):
41 |         nn.init.constant_(m.weight, 0.333)
42 |         nn.init.constant_(m.bias, 0.0)
43 |     if isinstance(m, nn.Conv2d):
44 |         nn.init.constant_(m.weight, 0.333)
45 |         nn.init.constant_(m.bias, 0.0)
46 | 
47 | 
48 | class visible_module(nn.Module):
49 |     def __init__(self, arch='resnet50'):
50 |         super(visible_module, self).__init__()
51 | 
52 |         model_v = resnet50(pretrained=True,
53 |                            last_conv_stride=1, last_conv_dilation=1)
54 |         # avg pooling to global pooling
55 |         self.visible = model_v
56 | 
57 |     def forward(self, x):
58 |         x = self.visible.conv1(x)
59 |         x = self.visible.bn1(x)
60 |         x = self.visible.relu(x)
61 |         x = self.visible.maxpool(x)
62 |         return x
63 | 
64 | 
65 | class thermal_module(nn.Module):
66 |     def __init__(self, arch='resnet50'):
67 |         super(thermal_module, self).__init__()
68 | 
69 |         model_t = resnet50(pretrained=True,
70 |                            last_conv_stride=1, last_conv_dilation=1)
71 |         # avg pooling to global pooling
72 |         self.thermal = model_t
73 | 
74 |     def forward(self, x):
75 |         x = self.thermal.conv1(x)
76 |         x = self.thermal.bn1(x)
77 |         x = self.thermal.relu(x)
78 |         x = self.thermal.maxpool(x)
79 |         return x
80 | 
81 | 
82 | class base_resnet(nn.Module):
83 |     def __init__(self, arch='resnet50'):
84 |         super(base_resnet, self).__init__()
85 | 
86 |         model_base = resnet50(pretrained=True,
87 |                               last_conv_stride=1, last_conv_dilation=1)
88 |         # avg pooling to global pooling
89 |         model_base.avgpool = nn.AdaptiveAvgPool2d((1, 1))
90 |         self.base = model_base
91 | 
92 |     def forward(self, x):
93 |         x = self.base.layer1(x)
94 |         x = self.base.layer2(x)
95 |         x = self.base.layer3(x)
96 |         x = self.base.layer4(x)
97 |         return x
98 | 
99 | class embed_net(nn.Module):
100 |     def __init__(self, class_num, no_local= 'on', gm_pool = 'on', arch='resnet50'):
101 |         super(embed_net, self).__init__()
102 | 
103 |         self.thermal_module = thermal_module(arch=arch)
104 |         self.visible_module = visible_module(arch=arch)
105 |         self.base_resnet = base_resnet(arch=arch)
106 | 
107 |         pool_dim = 2048
108 |         self.l2norm = Normalize(2)
109 |         self.bottleneck1 = nn.BatchNorm1d(pool_dim)
110 |         self.bottleneck1.bias.requires_grad_(False)  # no shift
111 |         self.bottleneck1.apply(weights_init_kaiming)
112 |         self.classifier1 = nn.Linear(pool_dim, class_num, bias=False)
113 |         self.classifier1.apply(weights_init_classifier)
114 | 
115 |         self.bottleneck2 = nn.BatchNorm1d(pool_dim)
116 |         self.bottleneck2.bias.requires_grad_(False)  # no shift
117 |         self.bottleneck2.apply(weights_init_kaiming)
118 |         self.classifier2 = nn.Linear(pool_dim, class_num, bias=False)
119 |         self.classifier2.apply(weights_init_classifier)
120 | 
121 |         self.bottleneck3 = nn.BatchNorm1d(pool_dim)
122 |         self.bottleneck3.bias.requires_grad_(False)  #
no shift 123 | self.bottleneck3.apply(weights_init_kaiming) 124 | self.classifier3 = nn.Linear(pool_dim, class_num, bias=False) 125 | self.classifier3.apply(weights_init_classifier) 126 | 127 | self.bottleneck4 = nn.BatchNorm1d(pool_dim) 128 | self.bottleneck4.bias.requires_grad_(False) # no shift 129 | self.bottleneck4.apply(weights_init_kaiming) 130 | self.classifier4 = nn.Linear(pool_dim, class_num, bias=False) 131 | self.classifier4.apply(weights_init_classifier) 132 | 133 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 134 | self.encode1 = nn.Conv2d(3, 1, 1) 135 | self.encode1.apply(my_weights_init) 136 | self.fc1 = nn.Conv2d(1, 1, 1) 137 | self.fc1.apply(my_weights_init) 138 | self.bn1 = nn.BatchNorm2d(1) 139 | self.bn1.apply(weights_init_kaiming) 140 | 141 | 142 | self.encode2 = nn.Conv2d(3, 1, 1) 143 | self.encode2.apply(my_weights_init) 144 | self.fc2 = nn.Conv2d(1, 1, 1) 145 | self.fc2.apply(my_weights_init) 146 | self.bn2 = nn.BatchNorm2d(1) 147 | self.bn2.apply(weights_init_kaiming) 148 | 149 | 150 | self.decode = nn.Conv2d(1, 3, 1) 151 | self.decode.apply(my_weights_init) 152 | 153 | def forward(self, x1, x2, modal=0): 154 | if modal == 0: 155 | gray1 = F.relu(self.encode1(x1)) 156 | gray1 = self.bn1(F.relu(self.fc1(gray1))) 157 | 158 | gray2 = F.relu(self.encode2(x2)) 159 | gray2 = self.bn2(F.relu(self.fc2(gray2))) 160 | 161 | gray = F.relu(self.decode(torch.cat((gray1, gray2),0))) 162 | 163 | gray1, gray2 = torch.chunk(gray, 2, 0) 164 | xo = torch.cat((x1, x2), 0) 165 | 166 | x1 = self.visible_module(torch.cat((x1, gray1),0)) 167 | x2 = self.thermal_module(torch.cat((x2, gray2),0)) 168 | 169 | x = torch.cat((x1, x2), 0) 170 | elif modal == 1: 171 | gray1 = F.relu(self.encode1(x1)) 172 | gray1 = self.bn1(F.relu(self.fc1(gray1))) 173 | gray1 = F.relu(self.decode(gray1)) 174 | 175 | x = self.visible_module(torch.cat((x1, gray1),0)) 176 | elif modal == 2: 177 | gray2 = F.relu(self.encode2(x2)) 178 | gray2 = self.bn2(F.relu(self.fc2(gray2))) 179 | gray2 = F.relu(self.decode(gray2)) 180 | 181 | x = self.thermal_module(torch.cat((x2, gray2),0)) 182 | 183 | 184 | # shared block 185 | x = self.base_resnet.base.layer1(x) 186 | x = self.base_resnet.base.layer2(x) 187 | x = self.base_resnet.base.layer3(x) 188 | x = self.base_resnet.base.layer4(x) 189 | x41, x42, x43, x44 = torch.chunk(x, 4, 2) 190 | 191 | x41 = self.avgpool(x41) 192 | x42 = self.avgpool(x42) 193 | x43 = self.avgpool(x43) 194 | x44 = self.avgpool(x44) 195 | x41 = x41.view(x41.size(0), x41.size(1)) 196 | x42 = x42.view(x42.size(0), x42.size(1)) 197 | x43 = x43.view(x43.size(0), x43.size(1)) 198 | x44 = x44.view(x44.size(0), x44.size(1)) 199 | 200 | feat41 = self.bottleneck1(x41) 201 | feat42 = self.bottleneck2(x42) 202 | feat43 = self.bottleneck3(x43) 203 | feat44 = self.bottleneck4(x44) 204 | 205 | if self.training: 206 | return x41, x42, x43, x44, self.classifier1(feat41), self.classifier2(feat42), self.classifier3(feat43), self.classifier4(feat44), [xo, gray] 207 | else: 208 | return self.l2norm(torch.cat((x41, x42, x43, x44),1)), self.l2norm(torch.cat((feat41, feat42, feat43, feat44),1)) 209 | 210 | 211 | 212 | 213 | 214 | 215 | 216 | 217 | 218 | -------------------------------------------------------------------------------- /pre_process_sysu.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | import pdb 4 | import os 5 | 6 | data_path = './Datasets/SYSU-MM01/' 7 | 8 | rgb_cameras = ['cam1','cam2','cam4','cam5'] 9 | ir_cameras = 
['cam3','cam6']
10 | 
11 | # load id info
12 | file_path_train = os.path.join(data_path,'exp/train_id.txt')
13 | file_path_val = os.path.join(data_path,'exp/val_id.txt')
14 | with open(file_path_train, 'r') as file:
15 |     ids = file.read().splitlines()
16 |     ids = [int(y) for y in ids[0].split(',')]
17 |     id_train = ["%04d" % x for x in ids]
18 | 
19 | with open(file_path_val, 'r') as file:
20 |     ids = file.read().splitlines()
21 |     ids = [int(y) for y in ids[0].split(',')]
22 |     id_val = ["%04d" % x for x in ids]
23 | 
24 | # combine train and val split
25 | id_train.extend(id_val)
26 | 
27 | files_rgb = []
28 | files_ir = []
29 | for id in sorted(id_train):
30 |     for cam in rgb_cameras:
31 |         img_dir = os.path.join(data_path,cam,id)
32 |         if os.path.isdir(img_dir):
33 |             new_files = sorted([img_dir+'/'+i for i in os.listdir(img_dir)])
34 |             files_rgb.extend(new_files)
35 | 
36 |     for cam in ir_cameras:
37 |         img_dir = os.path.join(data_path,cam,id)
38 |         if os.path.isdir(img_dir):
39 |             new_files = sorted([img_dir+'/'+i for i in os.listdir(img_dir)])
40 |             files_ir.extend(new_files)
41 | 
42 | # relabel
43 | pid_container = set()
44 | for img_path in files_ir:
45 |     pid = int(img_path[-13:-9])
46 |     pid_container.add(pid)
47 | pid2label = {pid:label for label, pid in enumerate(pid_container)}
48 | fix_image_width = 192
49 | fix_image_height = 384
50 | def read_imgs(train_image):
51 |     train_img = []
52 |     train_label = []
53 |     for img_path in train_image:
54 |         # img
55 |         img = Image.open(img_path)
56 |         img = img.resize((fix_image_width, fix_image_height), Image.ANTIALIAS)
57 |         pix_array = np.array(img)
58 | 
59 |         train_img.append(pix_array)
60 | 
61 |         # label
62 |         pid = int(img_path[-13:-9])
63 |         pid = pid2label[pid]
64 |         train_label.append(pid)
65 |     return np.array(train_img), np.array(train_label)
66 | 
67 | # RGB images
68 | train_img, train_label = read_imgs(files_rgb)
69 | np.save(data_path + 'train_rgb_resized_img.npy', train_img)
70 | np.save(data_path + 'train_rgb_resized_label.npy', train_label)
71 | 
72 | # IR images
73 | train_img, train_label = read_imgs(files_ir)
74 | np.save(data_path + 'train_ir_resized_img.npy', train_img)
75 | np.save(data_path + 'train_ir_resized_label.npy', train_label)
76 | 
-------------------------------------------------------------------------------- /random_erasing.py: --------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | 
3 | from torchvision.transforms import *
4 | 
5 | from PIL import Image
6 | import random
7 | import math
8 | import numpy as np
9 | from torch import nn
10 | 
11 | 
12 | class ColorJitter(object):
13 |     def __init__(self):
14 |         self.color_jitter = transforms.RandomChoice([
15 |             transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),
16 |         ])
17 | 
18 |     def __call__(self, image):
19 |         h, w = 288, 144
20 |         mask = np.ones((h, w, 3), dtype=np.uint8)
21 |         for i in range(h//36):
22 |             for j in range(w//36):
23 |                 if (i + j)%2==1:
24 |                     mask[i*36:i*36+36,j*36:j*36+36, :] = 0
25 |                 else:
26 |                     mask[i*36:i*36+36,j*36:j*36+36, :] = 1
27 | 
28 |         img = self.color_jitter(image) * mask + image * (1 - mask)
29 |         #img = image * mask + image * (1 - mask)
30 |         return img
31 | 
32 | 
33 | class RandomErasing(object):
34 |     """ Randomly selects a rectangle region in an image and erases its pixels.
35 |         'Random Erasing Data Augmentation' by Zhong et al.
36 |         See https://arxiv.org/pdf/1708.04896.pdf
37 |     Args:
38 |         probability: The probability that the Random Erasing operation will be performed.
39 | sl: Minimum proportion of erased area against input image. 40 | sh: Maximum proportion of erased area against input image. 41 | r1: Minimum aspect ratio of erased area. 42 | mean: Erasing value. 43 | """ 44 | 45 | def __init__(self, probability = 0.5, sl = 0.02, sh = 0.4, r1 = 0.3, mean=[0.4914, 0.4822, 0.4465]): 46 | self.probability = probability 47 | self.mean = mean 48 | self.sl = sl 49 | self.sh = sh 50 | self.r1 = r1 51 | 52 | def __call__(self, img): 53 | 54 | if random.uniform(0, 1) > self.probability: 55 | return img 56 | 57 | for attempt in range(100): 58 | area = img.size()[1] * img.size()[2] 59 | 60 | target_area = random.uniform(self.sl, self.sh) * area 61 | aspect_ratio = random.uniform(self.r1, 1/self.r1) 62 | 63 | h = int(round(math.sqrt(target_area * aspect_ratio))) 64 | w = int(round(math.sqrt(target_area / aspect_ratio))) 65 | 66 | if w < img.size()[2] and h < img.size()[1]: 67 | x1 = random.randint(0, img.size()[1] - h) 68 | y1 = random.randint(0, img.size()[2] - w) 69 | if img.size()[0] == 3: 70 | img[0, x1:x1+h, y1:y1+w] = self.mean[0] 71 | img[1, x1:x1+h, y1:y1+w] = self.mean[1] 72 | img[2, x1:x1+h, y1:y1+w] = self.mean[2] 73 | else: 74 | img[0, x1:x1+h, y1:y1+w] = self.mean[0] 75 | return img 76 | 77 | return img 78 | -------------------------------------------------------------------------------- /resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | import torch.utils.model_zoo as model_zoo 4 | 5 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 6 | 'resnet152'] 7 | 8 | model_urls = { 9 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 10 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 11 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 12 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 13 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 14 | } 15 | 16 | 17 | def conv3x3(in_planes, out_planes, stride=1, dilation=1): 18 | """3x3 convolution with padding""" 19 | # original padding is 1; original dilation is 1 20 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 21 | padding=dilation, bias=False, dilation=dilation) 22 | 23 | 24 | class BasicBlock(nn.Module): 25 | expansion = 1 26 | 27 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1): 28 | super(BasicBlock, self).__init__() 29 | self.conv1 = conv3x3(inplanes, planes, stride, dilation) 30 | self.bn1 = nn.BatchNorm2d(planes) 31 | self.relu = nn.ReLU(inplace=True) 32 | self.conv2 = conv3x3(planes, planes) 33 | self.bn2 = nn.BatchNorm2d(planes) 34 | self.downsample = downsample 35 | self.stride = stride 36 | 37 | def forward(self, x): 38 | residual = x 39 | 40 | out = self.conv1(x) 41 | out = self.bn1(out) 42 | out = self.relu(out) 43 | 44 | out = self.conv2(out) 45 | out = self.bn2(out) 46 | 47 | if self.downsample is not None: 48 | residual = self.downsample(x) 49 | 50 | out += residual 51 | out = self.relu(out) 52 | 53 | return out 54 | 55 | 56 | class Bottleneck(nn.Module): 57 | expansion = 4 58 | 59 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1): 60 | super(Bottleneck, self).__init__() 61 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 62 | self.bn1 = nn.BatchNorm2d(planes) 63 | # original padding is 1; original dilation is 1 64 | self.conv2 = nn.Conv2d(planes, planes, 
kernel_size=3, stride=stride, padding=dilation, bias=False, dilation=dilation) 65 | self.bn2 = nn.BatchNorm2d(planes) 66 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 67 | self.bn3 = nn.BatchNorm2d(planes * 4) 68 | self.relu = nn.ReLU(inplace=True) 69 | self.downsample = downsample 70 | self.stride = stride 71 | 72 | def forward(self, x): 73 | residual = x 74 | 75 | out = self.conv1(x) 76 | out = self.bn1(out) 77 | out = self.relu(out) 78 | 79 | out = self.conv2(out) 80 | out = self.bn2(out) 81 | out = self.relu(out) 82 | 83 | out = self.conv3(out) 84 | out = self.bn3(out) 85 | 86 | if self.downsample is not None: 87 | residual = self.downsample(x) 88 | 89 | out += residual 90 | out = self.relu(out) 91 | 92 | return out 93 | 94 | 95 | class ResNet(nn.Module): 96 | 97 | def __init__(self, block, layers, last_conv_stride=2, last_conv_dilation=1): 98 | 99 | self.inplanes = 64 100 | super(ResNet, self).__init__() 101 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 102 | bias=False) 103 | self.bn1 = nn.BatchNorm2d(64) 104 | self.relu = nn.ReLU(inplace=True) 105 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 106 | self.layer1 = self._make_layer(block, 64, layers[0]) 107 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 108 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 109 | self.layer4 = self._make_layer(block, 512, layers[3], stride=last_conv_stride, dilation=last_conv_dilation) 110 | 111 | for m in self.modules(): 112 | if isinstance(m, nn.Conv2d): 113 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 114 | m.weight.data.normal_(0, math.sqrt(2. / n)) 115 | elif isinstance(m, nn.BatchNorm2d): 116 | m.weight.data.fill_(1) 117 | m.bias.data.zero_() 118 | 119 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1): 120 | downsample = None 121 | if stride != 1 or self.inplanes != planes * block.expansion: 122 | downsample = nn.Sequential( 123 | nn.Conv2d(self.inplanes, planes * block.expansion, 124 | kernel_size=1, stride=stride, bias=False), 125 | nn.BatchNorm2d(planes * block.expansion), 126 | ) 127 | 128 | layers = [] 129 | layers.append(block(self.inplanes, planes, stride, downsample, dilation)) 130 | self.inplanes = planes * block.expansion 131 | for i in range(1, blocks): 132 | layers.append(block(self.inplanes, planes)) 133 | 134 | return nn.Sequential(*layers) 135 | 136 | def forward(self, x): 137 | x = self.conv1(x) 138 | x = self.bn1(x) 139 | x = self.relu(x) 140 | x = self.maxpool(x) 141 | 142 | x = self.layer1(x) 143 | x = self.layer2(x) 144 | x = self.layer3(x) 145 | x = self.layer4(x) 146 | 147 | return x 148 | 149 | 150 | def remove_fc(state_dict): 151 | """Remove the fc layer parameters from state_dict.""" 152 | # for key, value in state_dict.items(): 153 | for key, value in list(state_dict.items()): 154 | if key.startswith('fc.'): 155 | del state_dict[key] 156 | return state_dict 157 | 158 | 159 | def resnet18(pretrained=False, **kwargs): 160 | """Constructs a ResNet-18 model. 161 | Args: 162 | pretrained (bool): If True, returns a model pre-trained on ImageNet 163 | """ 164 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 165 | if pretrained: 166 | model.load_state_dict(remove_fc(model_zoo.load_url(model_urls['resnet18']))) 167 | return model 168 | 169 | 170 | def resnet34(pretrained=False, **kwargs): 171 | """Constructs a ResNet-34 model. 
172 |     Args:
173 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
174 |     """
175 |     model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
176 |     if pretrained:
177 |         model.load_state_dict(remove_fc(model_zoo.load_url(model_urls['resnet34'])))
178 |     return model
179 | 
180 | 
181 | def resnet50(pretrained=False, **kwargs):
182 |     """Constructs a ResNet-50 model.
183 |     Args:
184 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
185 |     """
186 |     model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
187 |     if pretrained:
188 |         model.load_state_dict(remove_fc(model_zoo.load_url(model_urls['resnet50'])))
189 |     return model
190 | 
191 | 
192 | def resnet101(pretrained=False, **kwargs):
193 |     """Constructs a ResNet-101 model.
194 |     Args:
195 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
196 |     """
197 |     model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
198 |     if pretrained:
199 |         model.load_state_dict(
200 |             remove_fc(model_zoo.load_url(model_urls['resnet101'])))
201 |     return model
202 | 
203 | 
204 | def resnet152(pretrained=False, **kwargs):
205 |     """Constructs a ResNet-152 model.
206 |     Args:
207 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
208 |     """
209 |     model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
210 |     if pretrained:
211 |         model.load_state_dict(
212 |             remove_fc(model_zoo.load_url(model_urls['resnet152'])))
213 |     return model
-------------------------------------------------------------------------------- /test.py: --------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import argparse
3 | import time
4 | import torch.nn as nn
5 | import torch.backends.cudnn as cudnn
6 | from torch.autograd import Variable
7 | import torch.utils.data as data
8 | import torchvision.transforms as transforms
9 | from data_loader import SYSUData, RegDBData, TestData
10 | from data_manager import *
11 | from eval_metrics import eval_sysu, eval_regdb
12 | from model import embed_net
13 | from utils import *
14 | import pdb
15 | import scipy.io
16 | 
17 | parser = argparse.ArgumentParser(description='PyTorch Cross-Modality Testing')
18 | parser.add_argument('--dataset', default='sysu', help='dataset name: regdb or sysu')
19 | parser.add_argument('--lr', default=0.1 , type=float, help='learning rate, 0.00035 for adam')
20 | parser.add_argument('--optim', default='sgd', type=str, help='optimizer')
21 | parser.add_argument('--arch', default='resnet50', type=str, help='network baseline: resnet50')
22 | parser.add_argument('--resume', '-r', default='sysu_agw_p4_n4_lr_0.1_seed_0_best.t', type=str, help='resume from checkpoint')
23 | parser.add_argument('--test-only', action='store_true', help='test only')
24 | parser.add_argument('--model_path', default='save_model/', type=str, help='model save path')
25 | parser.add_argument('--save_epoch', default=20, type=int, metavar='s', help='save model every save_epoch epochs')
26 | parser.add_argument('--log_path', default='log/', type=str, help='log save path')
27 | parser.add_argument('--vis_log_path', default='log/vis_log/', type=str, help='tensorboard log save path')
28 | parser.add_argument('--workers', default=4, type=int, metavar='N', help='number of data loading workers (default: 4)')
29 | parser.add_argument('--img_w', default=192, type=int, metavar='imgw', help='img width')
30 | parser.add_argument('--img_h', default=384, type=int, metavar='imgh', help='img
height')
31 | parser.add_argument('--batch-size', default=8, type=int, metavar='B', help='training batch size')
32 | parser.add_argument('--test-batch', default=64, type=int, metavar='tb', help='testing batch size')
33 | parser.add_argument('--method', default='agw', type=str, metavar='m', help='method type: base or agw')
34 | parser.add_argument('--margin', default=0.3, type=float, metavar='margin', help='triplet loss margin')
35 | parser.add_argument('--num_pos', default=5, type=int, help='num of pos per identity in each modality')
36 | parser.add_argument('--trial', default=1, type=int, metavar='t', help='trial (only for RegDB dataset)')
37 | parser.add_argument('--seed', default=0, type=int, metavar='t', help='random seed')
38 | parser.add_argument('--gpu', default='4', type=str, help='gpu device ids for CUDA_VISIBLE_DEVICES')
39 | parser.add_argument('--mode', default='all', type=str, help='all or indoor for sysu')
40 | parser.add_argument('--tvsearch', action='store_true', default = True, help='whether thermal to visible search on RegDB')
41 | args = parser.parse_args()
42 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
43 | 
44 | dataset = args.dataset
45 | if dataset == 'sysu':
46 |     data_path = './Datasets/SYSU-MM01/'
47 |     n_class = 395
48 |     test_mode = [1, 2]
49 | elif dataset =='regdb':
50 |     data_path = './Datasets/RegDB/'
51 |     n_class = 206
52 |     test_mode = [2, 1]
53 | 
54 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
55 | best_acc = 0  # best test accuracy
56 | start_epoch = 0
57 | pool_dim = 2048 * 4
58 | print('==> Building model..')
59 | if args.method =='base':
60 |     net = embed_net(n_class, no_local= 'off', gm_pool = 'off', arch=args.arch)
61 | else:
62 |     net = embed_net(n_class, no_local= 'on', gm_pool = 'on', arch=args.arch)
63 | net.to(device)
64 | cudnn.benchmark = True
65 | 
66 | checkpoint_path = args.model_path
67 | 
68 | if args.method =='id':
69 |     criterion = nn.CrossEntropyLoss()
70 |     criterion.to(device)
71 | 
72 | print('==> Loading data..')
73 | # Data loading code
74 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])
75 | transform_train = transforms.Compose([
76 |     transforms.ToPILImage(),
77 |     transforms.RandomCrop((args.img_h,args.img_w)),
78 |     transforms.RandomHorizontalFlip(),
79 |     transforms.ToTensor(),
80 |     normalize,
81 | ])
82 | 
83 | transform_test = transforms.Compose([
84 |     transforms.ToPILImage(),
85 |     transforms.Resize((args.img_h,args.img_w)),
86 |     transforms.ToTensor(),
87 |     normalize,
88 | ])
89 | 
90 | end = time.time()
91 | 
92 | 
93 | 
94 | def extract_gall_feat(gall_loader):
95 |     net.eval()
96 |     print ('Extracting Gallery Feature...')
97 |     start = time.time()
98 |     ptr = 0
99 |     gall_feat_pool = np.zeros((ngall, pool_dim))
100 |     gall_feat_fc = np.zeros((ngall, pool_dim))
101 | 
102 |     Xgall_feat_pool = np.zeros((ngall, pool_dim))
103 |     Xgall_feat_fc = np.zeros((ngall, pool_dim))
104 |     with torch.no_grad():
105 |         for batch_idx, (input, label ) in enumerate(gall_loader):
106 |             batch_num = input.size(0)
107 |             input = Variable(input.cuda())
108 |             feat_pool, feat_fc = net(input, input, test_mode[0])
109 |             gall_feat_pool[ptr:ptr+batch_num,: ] = feat_pool[:batch_num].detach().cpu().numpy()
110 |             gall_feat_fc[ptr:ptr+batch_num,: ] = feat_fc[:batch_num].detach().cpu().numpy()
111 | 
112 |             Xgall_feat_pool[ptr:ptr+batch_num,: ] = feat_pool[batch_num:].detach().cpu().numpy()
113 |             Xgall_feat_fc[ptr:ptr+batch_num,: ] = feat_fc[batch_num:].detach().cpu().numpy()
114 |             ptr = ptr + batch_num
115 |     print('Extracting Time:\t
{:.3f}'.format(time.time()-start)) 116 | return gall_feat_pool, gall_feat_fc, Xgall_feat_pool, Xgall_feat_fc 117 | 118 | def extract_query_feat(query_loader): 119 | net.eval() 120 | print ('Extracting Query Feature...') 121 | start = time.time() 122 | ptr = 0 123 | query_feat_pool = np.zeros((nquery, pool_dim)) 124 | query_feat_fc = np.zeros((nquery, pool_dim)) 125 | 126 | Xquery_feat_pool = np.zeros((nquery, pool_dim)) 127 | Xquery_feat_fc = np.zeros((nquery, pool_dim)) 128 | with torch.no_grad(): 129 | for batch_idx, (input, label ) in enumerate(query_loader): 130 | batch_num = input.size(0) 131 | input = Variable(input.cuda()) 132 | feat_pool, feat_fc = net(input, input, test_mode[1]) 133 | query_feat_pool[ptr:ptr+batch_num,: ] = feat_pool[:batch_num].detach().cpu().numpy() 134 | query_feat_fc[ptr:ptr+batch_num,: ] = feat_fc[:batch_num].detach().cpu().numpy() 135 | 136 | Xquery_feat_pool[ptr:ptr+batch_num,: ] = feat_pool[batch_num:].detach().cpu().numpy() 137 | Xquery_feat_fc[ptr:ptr+batch_num,: ] = feat_fc[batch_num:].detach().cpu().numpy() 138 | ptr = ptr + batch_num 139 | print('Extracting Time:\t {:.3f}'.format(time.time()-start)) 140 | return query_feat_pool, query_feat_fc, Xquery_feat_pool, Xquery_feat_fc 141 | 142 | 143 | if dataset == 'sysu': 144 | 145 | print('==> Resuming from checkpoint..') 146 | if len(args.resume) > 0: 147 | model_path = checkpoint_path + args.resume 148 | # model_path = checkpoint_path + 'sysu_awg_p4_n8_lr_0.1_seed_0_best.t' 149 | if os.path.isfile(model_path): 150 | print('==> loading checkpoint {}'.format(args.resume)) 151 | checkpoint = torch.load(model_path) 152 | net.load_state_dict(checkpoint['net']) 153 | print('==> loaded checkpoint {} (epoch {})' 154 | .format(args.resume, checkpoint['epoch'])) 155 | else: 156 | print('==> no checkpoint found at {}'.format(args.resume)) 157 | 158 | # testing set 159 | query_img, query_label, query_cam = process_query_sysu(data_path, mode=args.mode) 160 | gall_img, gall_label, gall_cam = process_gallery_sysu(data_path, mode=args.mode, trial=0) 161 | 162 | nquery = len(query_label) 163 | ngall = len(gall_label) 164 | print("Dataset statistics:") 165 | print(" ------------------------------") 166 | print(" subset | # ids | # images") 167 | print(" ------------------------------") 168 | print(" query | {:5d} | {:8d}".format(len(np.unique(query_label)), nquery)) 169 | print(" gallery | {:5d} | {:8d}".format(len(np.unique(gall_label)), ngall)) 170 | print(" ------------------------------") 171 | 172 | queryset = TestData(query_img, query_label, transform=transform_test, img_size=(args.img_w, args.img_h)) 173 | query_loader = data.DataLoader(queryset, batch_size=args.test_batch, shuffle=False, num_workers=4) 174 | print('Data Loading Time:\t {:.3f}'.format(time.time() - end)) 175 | 176 | query_feat_pool, query_feat_fc, Xquery_feat_pool, Xquery_feat_fc = extract_query_feat(query_loader) 177 | for trial in range(10): 178 | gall_img, gall_label, gall_cam = process_gallery_sysu(data_path, mode=args.mode, trial=trial) 179 | 180 | trial_gallset = TestData(gall_img, gall_label, transform=transform_test, img_size=(args.img_w, args.img_h)) 181 | trial_gall_loader = data.DataLoader(trial_gallset, batch_size=args.test_batch, shuffle=False, num_workers=4) 182 | 183 | gall_feat_pool, gall_feat_fc, Xgall_feat_pool, Xgall_feat_fc = extract_gall_feat(trial_gall_loader) 184 | 185 | # fc feature 186 | distmat = np.matmul(query_feat_fc, np.transpose(gall_feat_fc)) 187 | cmc, mAP, mINP = eval_sysu(-distmat, query_label, gall_label, 
query_cam, gall_cam) 188 | 189 | distmat1 = np.matmul(query_feat_fc, np.transpose(Xgall_feat_fc)) 190 | cmc1, mAP1, mINP1 = eval_sysu(-distmat1, query_label, gall_label, query_cam, gall_cam) 191 | 192 | distmat2 = np.matmul(Xquery_feat_fc, np.transpose(gall_feat_fc)) 193 | cmc2, mAP2, mINP2 = eval_sysu(-distmat2, query_label, gall_label, query_cam, gall_cam) 194 | 195 | distmat3 = np.matmul(Xquery_feat_fc, np.transpose(Xgall_feat_fc)) 196 | cmc3, mAP3, mINP3 = eval_sysu(-distmat3, query_label, gall_label, query_cam, gall_cam) 197 | 198 | result = {'gallery_f':gall_feat_fc,'gallery_label':gall_label,'gallery_cam':gall_cam,'query_f':query_feat_fc,'query_label':query_label,'query_cam':query_cam} 199 | scipy.io.savemat('pytorch_result.mat',result) 200 | 201 | distmat4 = distmat + distmat1 + distmat2 + distmat3 202 | cmc4, mAP4, mINP4 = eval_sysu(-distmat4, query_label, gall_label, query_cam, gall_cam) 203 | 204 | distmat5 = np.minimum(np.minimum(distmat, distmat1), np.minimum(distmat2, distmat3)) 205 | cmc5, mAP5, mINP5 = eval_sysu(-distmat5, query_label, gall_label, query_cam, gall_cam) 206 | 207 | distmat6 = np.maximum(np.maximum(distmat, distmat1), np.maximum(distmat2, distmat3)) 208 | cmc6, mAP6, mINP6 = eval_sysu(-distmat6, query_label, gall_label, query_cam, gall_cam) 209 | 210 | 211 | if trial == 0: 212 | all_cmc = cmc 213 | all_mAP = mAP 214 | all_mINP = mINP 215 | 216 | all_cmc1 = cmc1 217 | all_mAP1 = mAP1 218 | all_mINP1 = mINP1 219 | 220 | all_cmc2 = cmc2 221 | all_mAP2 = mAP2 222 | all_mINP2 = mINP2 223 | 224 | all_cmc3 = cmc3 225 | all_mAP3 = mAP3 226 | all_mINP3 = mINP3 227 | 228 | all_cmc4 = cmc4 229 | all_mAP4 = mAP4 230 | all_mINP4 = mINP4 231 | 232 | all_cmc5 = cmc5 233 | all_mAP5 = mAP5 234 | all_mINP5 = mINP5 235 | 236 | all_cmc6 = cmc6 237 | all_mAP6 = mAP6 238 | all_mINP6 = mINP6 239 | else: 240 | all_cmc = all_cmc + cmc 241 | all_mAP = all_mAP + mAP 242 | all_mINP = all_mINP + mINP 243 | 244 | all_cmc1 = all_cmc1 + cmc1 245 | all_mAP1 = all_mAP1 + mAP1 246 | all_mINP1 = all_mINP1 + mINP1 247 | 248 | all_cmc2 = all_cmc2 + cmc2 249 | all_mAP2 = all_mAP2 + mAP2 250 | all_mINP2 = all_mINP2 + mINP2 251 | 252 | all_cmc3 = all_cmc3 + cmc3 253 | all_mAP3 = all_mAP3 + mAP3 254 | all_mINP3 = all_mINP3 + mINP3 255 | 256 | all_cmc4 = all_cmc4 + cmc4 257 | all_mAP4 = all_mAP4 + mAP4 258 | all_mINP4 = all_mINP4 + mINP4 259 | 260 | all_cmc5 = all_cmc5 + cmc5 261 | all_mAP5 = all_mAP5 + mAP5 262 | all_mINP5 = all_mINP5 + mINP5 263 | 264 | all_cmc6 = all_cmc6 + cmc6 265 | all_mAP6 = all_mAP6 + mAP6 266 | all_mINP6 = all_mINP6 + mINP6 267 | 268 | print('Test Trial: {}'.format(trial)) 269 | print('FC: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 270 | cmc[0], cmc[4], cmc[9], cmc[19], mAP, mINP)) 271 | 272 | print('FC: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 273 | cmc1[0], cmc1[4], cmc1[9], cmc1[19], mAP1, mINP1)) 274 | 275 | print('FC: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 276 | cmc2[0], cmc2[4], cmc2[9], cmc2[19], mAP2, mINP2)) 277 | 278 | print('FC: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 279 | cmc3[0], cmc3[4], cmc3[9], cmc3[19], mAP3, mINP3)) 280 | 281 | print('FC: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 282 | cmc4[0], cmc4[4], cmc4[9], cmc4[19], mAP4, mINP4)) 283 | 284 | 
print('FC: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format(
285 |         cmc5[0], cmc5[4], cmc5[9], cmc5[19], mAP5, mINP5))
286 | 
287 |     print('FC: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format(
288 |         cmc6[0], cmc6[4], cmc6[9], cmc6[19], mAP6, mINP6))
289 | 
290 | 
291 | elif dataset == 'regdb':
292 | 
293 |     for trial in range(10):
294 |         test_trial = trial +1
295 |         #model_path = checkpoint_path + args.resume
296 |         model_path = checkpoint_path + 'regdb_agw_p4_n4_lr_0.1_seed_0_trial_{}_best.t'.format(test_trial)
297 |         print(model_path)
298 |         print(os.path.isfile(model_path))
299 |         if os.path.isfile(model_path):
300 |             print('==> loading checkpoint {}'.format(model_path))
301 |             checkpoint = torch.load(model_path)
302 |             net.load_state_dict(checkpoint['net'])
303 |         else:
304 |             print('==> no checkpoint found at {}'.format(model_path))
305 | 
306 | 
307 |         # training set
308 |         trainset = RegDBData(data_path, test_trial, transform=transform_train)
309 |         # generate the idx of each person identity
310 |         color_pos, thermal_pos = GenIdx(trainset.train_color_label, trainset.train_thermal_label)
311 | 
312 |         # testing set
313 |         query_img, query_label = process_test_regdb(data_path, trial=test_trial, modal='visible')
314 |         gall_img, gall_label = process_test_regdb(data_path, trial=test_trial, modal='thermal')
315 | 
316 |         gallset = TestData(gall_img, gall_label, transform=transform_test, img_size=(args.img_w, args.img_h))
317 |         gall_loader = data.DataLoader(gallset, batch_size=args.test_batch, shuffle=False, num_workers=args.workers)
318 | 
319 |         nquery = len(query_label)
320 |         ngall = len(gall_label)
321 | 
322 |         queryset = TestData(query_img, query_label, transform=transform_test, img_size=(args.img_w, args.img_h))
323 |         query_loader = data.DataLoader(queryset, batch_size=args.test_batch, shuffle=False, num_workers=4)
324 |         print('Data Loading Time:\t {:.3f}'.format(time.time() - end))
325 | 
326 | 
327 |         query_feat_pool, query_feat_fc, Xquery_feat_pool, Xquery_feat_fc = extract_query_feat(query_loader)
328 |         gall_feat_pool, gall_feat_fc, Xgall_feat_pool, Xgall_feat_fc = extract_gall_feat(gall_loader)
329 | 
330 |         if args.tvsearch:
331 |             # pool5 feature eval_regdb
332 |             distmat = np.matmul(gall_feat_fc, np.transpose(query_feat_fc))
333 |             cmc, mAP, mINP = eval_regdb(-distmat, gall_label, query_label)
334 | 
335 |             distmat1 = np.matmul(Xgall_feat_fc, np.transpose(query_feat_fc))
336 |             cmc1, mAP1, mINP1 = eval_regdb(-distmat1, gall_label, query_label)
337 | 
338 |             distmat2 = np.matmul(gall_feat_fc, np.transpose(Xquery_feat_fc))
339 |             cmc2, mAP2, mINP2 = eval_regdb(-distmat2, gall_label, query_label)
340 | 
341 |             distmat3 = np.matmul(Xgall_feat_fc, np.transpose(Xquery_feat_fc))
342 |             cmc3, mAP3, mINP3 = eval_regdb(-distmat3, gall_label, query_label)
343 | 
344 |             #result = {'gallery_f':gall_feat_fc,'gallery_label':gall_label,'gallery_cam':gall_cam,'query_f':query_feat_fc,'query_label':query_label,'query_cam':query_cam}
345 |             #scipy.io.savemat('pytorch_result.mat',result)
346 | 
347 |             distmat4 = distmat + distmat1 + distmat2 + distmat3
348 |             cmc4, mAP4, mINP4 = eval_regdb(-distmat4, gall_label, query_label)
349 | 
350 |             distmat5 = np.minimum(np.minimum(distmat, distmat1), np.minimum(distmat2, distmat3))
351 |             cmc5, mAP5, mINP5 = eval_regdb(-distmat5, gall_label, query_label)
352 | 
353 |             distmat6 = np.maximum(np.maximum(distmat, distmat1), np.maximum(distmat2,
distmat3)) 354 | cmc6, mAP6, mINP6 = eval_regdb(-distmat6, gall_label, query_label) 355 | else: 356 | # pool5 feature 357 | distmat = np.matmul(query_feat_fc, np.transpose(gall_feat_fc)) 358 | cmc, mAP, mINP = eval_regdb(-distmat, query_label, gall_label) 359 | 360 | distmat1 = np.matmul(query_feat_fc, np.transpose(Xgall_feat_fc)) 361 | cmc1, mAP1, mINP1 = eval_regdb(-distmat1, query_label, gall_label) 362 | 363 | distmat2 = np.matmul(Xquery_feat_fc, np.transpose(gall_feat_fc)) 364 | cmc2, mAP2, mINP2 = eval_regdb(-distmat2, query_label, gall_label) 365 | 366 | distmat3 = np.matmul(Xquery_feat_fc, np.transpose(Xgall_feat_fc)) 367 | cmc3, mAP3, mINP3 = eval_regdb(-distmat3, query_label, gall_label) 368 | 369 | #result = {'gallery_f':gall_feat_fc,'gallery_label':gall_label,'gallery_cam':gall_cam,'query_f':query_feat_fc,'query_label':query_label,'query_cam':query_cam} 370 | #scipy.io.savemat('pytorch_result.mat',result) 371 | 372 | distmat4 = distmat + distmat1 + distmat2 + distmat3 373 | cmc4, mAP4, mINP4 = eval_regdb(-distmat4, query_label, gall_label) 374 | 375 | distmat5 = np.minimum(np.minimum(distmat, distmat1), np.minimum(distmat2, distmat3)) 376 | cmc5, mAP5, mINP5 = eval_regdb(-distmat5, query_label, gall_label) 377 | 378 | distmat6 = np.maximum(np.maximum(distmat, distmat1), np.maximum(distmat2, distmat3)) 379 | cmc6, mAP6, mINP6 = eval_regdb(-distmat6, query_label, gall_label) 380 | 381 | 382 | if trial == 0: 383 | all_cmc = cmc 384 | all_mAP = mAP 385 | all_mINP = mINP 386 | 387 | all_cmc1 = cmc1 388 | all_mAP1 = mAP1 389 | all_mINP1 = mINP1 390 | 391 | all_cmc2 = cmc2 392 | all_mAP2 = mAP2 393 | all_mINP2 = mINP2 394 | 395 | all_cmc3 = cmc3 396 | all_mAP3 = mAP3 397 | all_mINP3 = mINP3 398 | 399 | all_cmc4 = cmc4 400 | all_mAP4 = mAP4 401 | all_mINP4 = mINP4 402 | 403 | all_cmc5 = cmc5 404 | all_mAP5 = mAP5 405 | all_mINP5 = mINP5 406 | 407 | all_cmc6 = cmc6 408 | all_mAP6 = mAP6 409 | all_mINP6 = mINP6 410 | else: 411 | all_cmc = all_cmc + cmc 412 | all_mAP = all_mAP + mAP 413 | all_mINP = all_mINP + mINP 414 | 415 | all_cmc1 = all_cmc1 + cmc1 416 | all_mAP1 = all_mAP1 + mAP1 417 | all_mINP1 = all_mINP1 + mINP1 418 | 419 | all_cmc2 = all_cmc2 + cmc2 420 | all_mAP2 = all_mAP2 + mAP2 421 | all_mINP2 = all_mINP2 + mINP2 422 | 423 | all_cmc3 = all_cmc3 + cmc3 424 | all_mAP3 = all_mAP3 + mAP3 425 | all_mINP3 = all_mINP3 + mINP3 426 | 427 | all_cmc4 = all_cmc4 + cmc4 428 | all_mAP4 = all_mAP4 + mAP4 429 | all_mINP4 = all_mINP4 + mINP4 430 | 431 | all_cmc5 = all_cmc5 + cmc5 432 | all_mAP5 = all_mAP5 + mAP5 433 | all_mINP5 = all_mINP5 + mINP5 434 | 435 | all_cmc6 = all_cmc6 + cmc6 436 | all_mAP6 = all_mAP6 + mAP6 437 | all_mINP6 = all_mINP6 + mINP6 438 | 439 | print('Test Trial: {}'.format(trial)) 440 | print('FC: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 441 | cmc[0], cmc[4], cmc[9], cmc[19], mAP, mINP)) 442 | 443 | print('FC: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 444 | cmc1[0], cmc1[4], cmc1[9], cmc1[19], mAP1, mINP1)) 445 | 446 | print('FC: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 447 | cmc2[0], cmc2[4], cmc2[9], cmc2[19], mAP2, mINP2)) 448 | 449 | print('FC: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 450 | cmc3[0], cmc3[4], cmc3[9], cmc3[19], mAP3, mINP3)) 451 | 452 | print('FC: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: 
{:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 453 | cmc4[0], cmc4[4], cmc4[9], cmc4[19], mAP4, mINP4)) 454 | 455 | print('FC: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 456 | cmc5[0], cmc5[4], cmc5[9], cmc5[19], mAP5, mINP5)) 457 | 458 | print('FC: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 459 | cmc6[0], cmc6[4], cmc6[9], cmc6[19], mAP6, mINP6)) 460 | cmc = all_cmc / 10 461 | mAP = all_mAP / 10 462 | mINP = all_mINP / 10 463 | 464 | cmc1 = all_cmc1 / 10 465 | mAP1 = all_mAP1 / 10 466 | mINP1 = all_mINP1 / 10 467 | 468 | cmc2 = all_cmc2 / 10 469 | mAP2 = all_mAP2 / 10 470 | mINP2 = all_mINP2 / 10 471 | 472 | cmc3 = all_cmc3 / 10 473 | mAP3 = all_mAP3 / 10 474 | mINP3 = all_mINP3 / 10 475 | 476 | cmc4 = all_cmc4 / 10 477 | mAP4 = all_mAP4 / 10 478 | mINP4 = all_mINP4 / 10 479 | 480 | cmc5 = all_cmc5 / 10 481 | mAP5 = all_mAP5 / 10 482 | mINP5 = all_mINP5 / 10 483 | 484 | cmc6 = all_cmc6 / 10 485 | mAP6 = all_mAP6 / 10 486 | mINP6 = all_mINP6 / 10 487 | 488 | print('All Average:') 489 | print('FC: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 490 | cmc[0], cmc[4], cmc[9], cmc[19], mAP, mINP)) 491 | 492 | print('FC: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 493 | cmc1[0], cmc1[4], cmc1[9], cmc1[19], mAP1, mINP1)) 494 | 495 | print('FC: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 496 | cmc2[0], cmc2[4], cmc2[9], cmc2[19], mAP2, mINP2)) 497 | 498 | print('FC: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 499 | cmc3[0], cmc3[4], cmc3[9], cmc3[19], mAP3, mINP3)) 500 | 501 | print('FC: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 502 | cmc4[0], cmc4[4], cmc4[9], cmc4[19], mAP4, mINP4)) 503 | 504 | print('FC: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 505 | cmc5[0], cmc5[4], cmc5[9], cmc5[19], mAP5, mINP5)) 506 | 507 | print('FC: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 508 | cmc6[0], cmc6[4], cmc6[9], cmc6[19], mAP6, mINP6)) 509 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | import sys 4 | import time 5 | import torch 6 | import torch.nn as nn 7 | import torch.optim as optim 8 | import torch.backends.cudnn as cudnn 9 | from torch.autograd import Variable 10 | import torch.utils.data as data 11 | import torchvision 12 | import torchvision.transforms as transforms 13 | from data_loader import SYSUData, RegDBData, TestData 14 | from data_manager import * 15 | from eval_metrics import eval_sysu, eval_regdb 16 | from model import embed_net 17 | from utils import * 18 | from loss import OriTripletLoss, TriLoss, DCLoss 19 | from tensorboardX import SummaryWriter 20 | from random_erasing import RandomErasing 21 | 22 | parser = argparse.ArgumentParser(description='PyTorch Cross-Modality Training') 23 | parser.add_argument('--dataset', default='sysu', help='dataset name: regdb or sysu') 24 | parser.add_argument('--lr', default=0.1, type=float, help='learning rate, 0.00035 for adam') 25 |
parser.add_argument('--optim', default='sgd', type=str, help='optimizer') 26 | parser.add_argument('--arch', default='resnet50', type=str, help='network baseline: resnet18 or resnet50') 27 | parser.add_argument('--resume', '-r', default='', type=str, help='resume from checkpoint') 28 | parser.add_argument('--test-only', action='store_true', help='test only') 29 | parser.add_argument('--model_path', default='save_model/', type=str, help='model save path') 30 | parser.add_argument('--save_epoch', default=20, type=int, metavar='s', help='save a checkpoint every save_epoch epochs (default: 20)') 31 | parser.add_argument('--log_path', default='log/', type=str, help='log save path') 32 | parser.add_argument('--vis_log_path', default='log/vis_log/', type=str, help='tensorboard log save path') 33 | parser.add_argument('--workers', default=4, type=int, metavar='N', help='number of data loading workers (default: 4)') 34 | parser.add_argument('--img_w', default=192, type=int, metavar='imgw', help='img width') 35 | parser.add_argument('--img_h', default=384, type=int, metavar='imgh', help='img height') 36 | parser.add_argument('--batch-size', default=4, type=int, metavar='B', help='training batch size') 37 | parser.add_argument('--test-batch', default=64, type=int, metavar='tb', help='testing batch size') 38 | parser.add_argument('--method', default='agw', type=str, metavar='m', help='method type: base or agw') 39 | parser.add_argument('--margin', default=0.3, type=float, metavar='margin', help='triplet loss margin') 40 | parser.add_argument('--num_pos', default=4, type=int, help='num of pos per identity in each modality') 41 | parser.add_argument('--trial', default=1, type=int, metavar='t', help='trial (only for RegDB dataset)') 42 | parser.add_argument('--seed', default=0, type=int, metavar='t', help='random seed') 43 | parser.add_argument('--gpu', default='0', type=str, help='gpu device ids for CUDA_VISIBLE_DEVICES') 44 | parser.add_argument('--mode', default='all', type=str, help='all or indoor') 45 | parser.add_argument('--delta', default=0.2, type=float, metavar='delta', help='dcl weights, 0.2 for PCB, 0.5 for resnet50') 46 | 47 | args = parser.parse_args() 48 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 49 | 50 | set_seed(args.seed) 51 | 52 | dataset = args.dataset 53 | if dataset == 'sysu': 54 | data_path = './Datasets/SYSU-MM01/' 55 | log_path = args.log_path + 'sysu_log/' 56 | test_mode = [1, 2] # thermal to visible 57 | elif dataset == 'regdb': 58 | data_path = './Datasets/RegDB/' 59 | log_path = args.log_path + 'regdb_log/' 60 | test_mode = [2, 1] # visible to thermal 61 | 62 | checkpoint_path = args.model_path 63 | 64 | if not os.path.isdir(log_path): 65 | os.makedirs(log_path) 66 | if not os.path.isdir(checkpoint_path): 67 | os.makedirs(checkpoint_path) 68 | if not os.path.isdir(args.vis_log_path): 69 | os.makedirs(args.vis_log_path) 70 | 71 | suffix = dataset 72 | if args.method == 'agw': 73 | suffix = suffix + '_agw_p{}_n{}_lr_{}_seed_{}'.format(args.num_pos, args.batch_size, args.lr, args.seed) 74 | else: 75 | suffix = suffix + '_base_p{}_n{}_lr_{}_seed_{}'.format(args.num_pos, args.batch_size, args.lr, args.seed) 76 | 77 | 78 | if args.optim != 'sgd': 79 | suffix = suffix + '_' + args.optim 80 | 81 | if dataset == 'regdb': 82 | suffix = suffix + '_trial_{}'.format(args.trial) 83 | 84 | sys.stdout = Logger(log_path + suffix + '_os.txt') 85 | 86 | vis_log_dir = args.vis_log_path + suffix + '/' 87 | 88 | if not os.path.isdir(vis_log_dir): 89 | os.makedirs(vis_log_dir) 90 | writer = SummaryWriter(vis_log_dir) 91 |
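# For orientation, with the defaults above (--dataset sysu, num_pos 4, batch-size 4,
# lr 0.1, seed 0) the naming scheme just built resolves to (illustrative values,
# derived from the format strings above; adjust for your own run):
#   suffix            -> 'sysu_agw_p4_n4_lr_0.1_seed_0'
#   stdout log        -> 'log/sysu_log/sysu_agw_p4_n4_lr_0.1_seed_0_os.txt'
#   tensorboard dir   -> 'log/vis_log/sysu_agw_p4_n4_lr_0.1_seed_0/'
#   best checkpoint   -> 'save_model/sysu_agw_p4_n4_lr_0.1_seed_0_best.t' (saved in the epoch loop below)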
print("==========\nArgs:{}\n==========".format(args)) 92 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 93 | best_acc = 0 # best test accuracy 94 | start_epoch = 0 95 | 96 | print('==> Loading data..') 97 | # Data loading code 98 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 99 | transform_train = transforms.Compose([ 100 | transforms.ToPILImage(), 101 | transforms.Pad(10), 102 | transforms.RandomCrop((args.img_h, args.img_w)), 103 | transforms.RandomHorizontalFlip(), 104 | transforms.ToTensor(), 105 | normalize, 106 | RandomErasing(probability = 0.5, mean=[0.0, 0.0, 0.0]), 107 | ]) 108 | transform_test = transforms.Compose([ 109 | transforms.ToPILImage(), 110 | transforms.Resize((args.img_h, args.img_w)), 111 | transforms.ToTensor(), 112 | normalize, 113 | ]) 114 | 115 | end = time.time() 116 | if dataset == 'sysu': 117 | # training set 118 | trainset = SYSUData(data_path, transform=transform_train) 119 | # generate the idx of each person identity 120 | color_pos, thermal_pos = GenIdx(trainset.train_color_label, trainset.train_thermal_label) 121 | 122 | # testing set 123 | query_img, query_label, query_cam = process_query_sysu(data_path, mode=args.mode) 124 | gall_img, gall_label, gall_cam = process_gallery_sysu(data_path, mode=args.mode, trial=0) 125 | 126 | elif dataset == 'regdb': 127 | # training set 128 | trainset = RegDBData(data_path, args.trial, transform=transform_train) 129 | # generate the idx of each person identity 130 | color_pos, thermal_pos = GenIdx(trainset.train_color_label, trainset.train_thermal_label) 131 | 132 | # testing set 133 | query_img, query_label = process_test_regdb(data_path, trial=args.trial, modal='visible') 134 | gall_img, gall_label = process_test_regdb(data_path, trial=args.trial, modal='thermal') 135 | 136 | gallset = TestData(gall_img, gall_label, transform=transform_test, img_size=(args.img_w, args.img_h)) 137 | queryset = TestData(query_img, query_label, transform=transform_test, img_size=(args.img_w, args.img_h)) 138 | 139 | # testing data loader 140 | gall_loader = data.DataLoader(gallset, batch_size=args.test_batch, shuffle=False, num_workers=args.workers) 141 | query_loader = data.DataLoader(queryset, batch_size=args.test_batch, shuffle=False, num_workers=args.workers) 142 | 143 | n_class = len(np.unique(trainset.train_color_label)) 144 | nquery = len(query_label) 145 | ngall = len(gall_label) 146 | 147 | print('Dataset {} statistics:'.format(dataset)) 148 | print(' ------------------------------') 149 | print(' subset | # ids | # images') 150 | print(' ------------------------------') 151 | print(' visible | {:5d} | {:8d}'.format(n_class, len(trainset.train_color_label))) 152 | print(' thermal | {:5d} | {:8d}'.format(n_class, len(trainset.train_thermal_label))) 153 | print(' ------------------------------') 154 | print(' query | {:5d} | {:8d}'.format(len(np.unique(query_label)), nquery)) 155 | print(' gallery | {:5d} | {:8d}'.format(len(np.unique(gall_label)), ngall)) 156 | print(' ------------------------------') 157 | print('Data Loading Time:\t {:.3f}'.format(time.time() - end)) 158 | 159 | print('==> Building model..') 160 | if args.method =='base': 161 | net = embed_net(n_class, no_local= 'off', gm_pool = 'off', arch=args.arch) 162 | else: 163 | net = embed_net(n_class, no_local= 'on', gm_pool = 'on', arch=args.arch) 164 | net.to(device) 165 | cudnn.benchmark = True 166 | 167 | if len(args.resume) > 0: 168 | model_path = checkpoint_path + args.resume 169 | if 
os.path.isfile(model_path): 170 | print('==> loading checkpoint {}'.format(args.resume)) 171 | checkpoint = torch.load(model_path) 172 | start_epoch = checkpoint['epoch'] 173 | net.load_state_dict(checkpoint['net']) 174 | print('==> loaded checkpoint {} (epoch {})' 175 | .format(args.resume, checkpoint['epoch'])) 176 | else: 177 | print('==> no checkpoint found at {}'.format(args.resume)) 178 | 179 | # define loss function 180 | criterion_id = nn.CrossEntropyLoss() 181 | 182 | loader_batch = args.batch_size * args.num_pos 183 | criterion_tri = OriTripletLoss(batch_size=loader_batch, margin=args.margin) 184 | self_critical = TriLoss(batch_size=loader_batch, margin=args.margin)  # defined here but not used in the training loop below 185 | criterion_div = DCLoss(num=2) 186 | 187 | criterion_id.to(device) 188 | criterion_tri.to(device) 189 | criterion_div.to(device) 190 | 191 | 192 | if args.optim == 'sgd': 193 | ignored_params = list(map(id, net.bottleneck1.parameters())) \ 194 | + list(map(id, net.bottleneck2.parameters())) \ 195 | + list(map(id, net.bottleneck3.parameters())) \ 196 | + list(map(id, net.bottleneck4.parameters())) \ 197 | + list(map(id, net.classifier1.parameters())) \ 198 | + list(map(id, net.classifier2.parameters())) \ 199 | + list(map(id, net.classifier3.parameters())) \ 200 | + list(map(id, net.classifier4.parameters())) 201 | 202 | base_params = filter(lambda p: id(p) not in ignored_params, net.parameters()) 203 | 204 | optimizer = optim.SGD([ 205 | {'params': base_params, 'lr': 0.1 * args.lr}, 206 | {'params': net.bottleneck1.parameters(), 'lr': args.lr}, 207 | {'params': net.bottleneck2.parameters(), 'lr': args.lr}, 208 | {'params': net.bottleneck3.parameters(), 'lr': args.lr}, 209 | {'params': net.bottleneck4.parameters(), 'lr': args.lr}, 210 | {'params': net.classifier1.parameters(), 'lr': args.lr}, 211 | {'params': net.classifier2.parameters(), 'lr': args.lr}, 212 | {'params': net.classifier3.parameters(), 'lr': args.lr}, 213 | {'params': net.classifier4.parameters(), 'lr': args.lr}], 214 | weight_decay=5e-4, momentum=0.9, nesterov=True) 215 | 216 | # exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1) 217 | def adjust_learning_rate(optimizer, epoch): 218 | """Linear warm-up over the first 10 epochs, then decay the LR by 10x at epochs 20 and 50""" 219 | if epoch < 10: 220 | lr = args.lr * (epoch + 1) / 10 221 | elif epoch >= 10 and epoch < 20: 222 | lr = args.lr 223 | elif epoch >= 20 and epoch < 50: 224 | lr = args.lr * 0.1 225 | elif epoch >= 50: 226 | lr = args.lr * 0.01 227 | 228 | optimizer.param_groups[0]['lr'] = 0.1 * lr 229 | for i in range(len(optimizer.param_groups) - 1): 230 | optimizer.param_groups[i + 1]['lr'] = lr 231 | 232 | return lr 233 | 234 | def train(epoch): 235 | 236 | current_lr = adjust_learning_rate(optimizer, epoch) 237 | train_loss = AverageMeter() 238 | id_loss = AverageMeter() 239 | tri_loss = AverageMeter() 240 | data_time = AverageMeter() 241 | batch_time = AverageMeter() 242 | correct = 0 243 | total = 0 244 | 245 | # switch to train mode 246 | net.train() 247 | end = time.time() 248 | 249 | for batch_idx, (input1, input2, label1, label2) in enumerate(trainloader): 250 | 251 | labels = torch.cat((label1, label1, label2, label2), 0) 252 | 253 | input1 = Variable(input1.cuda()) 254 | input2 = Variable(input2.cuda()) 255 | 256 | labels = Variable(labels.cuda()) 257 | data_time.update(time.time() - end) 258 | 259 | 260 | feat1, feat2, feat3, feat4, out1, out2, out3, out4, g1 = net(input1, input2) 261 | 262 | loss_id = (criterion_id(out1, labels) + criterion_id(out2, labels) + criterion_id(out3, labels) + criterion_id(out4, labels))*0.25
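# A note on the batch layout assumed below (inferred from labels = cat(label1,
# label1, label2, label2) above; the code itself does not spell this out): each
# feat returned by net(input1, input2) stacks four streams along the batch
# dimension, most likely [visible, middle modality from visible, infrared,
# middle modality from infrared]. torch.chunk(feat, 4, 0) then recovers one
# chunk per stream, and the triplet terms pair chunks across the modality gap,
# e.g. (ft11, ft13) = visible vs. infrared, while (ft12, ft14) are the two
# middle-modality chunks, which is also the pair that DCLoss is applied to.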
263 | 264 | lbs = torch.cat((label1, label2), 0) 265 | 266 | ft11, ft12, ft13, ft14 = torch.chunk(feat1, 4, 0) 267 | ft21, ft22, ft23, ft24 = torch.chunk(feat2, 4, 0) 268 | ft31, ft32, ft33, ft34 = torch.chunk(feat3, 4, 0) 269 | ft41, ft42, ft43, ft44 = torch.chunk(feat4, 4, 0) 270 | ba = criterion_tri(ft11, label1)[1]  # criterion_tri returns (loss, correct count); keep the count for the running accuracy 271 | 272 | loss_tri1 = (criterion_tri(torch.cat((ft11, ft13),0), lbs)[0] + criterion_tri(torch.cat((ft11, ft14),0), lbs)[0] + criterion_tri(torch.cat((ft12, ft13),0), lbs)[0] + criterion_tri(torch.cat((ft12, ft14),0), lbs)[0])/4 273 | loss_tri2 = (criterion_tri(torch.cat((ft21, ft23),0), lbs)[0] + criterion_tri(torch.cat((ft21, ft24),0), lbs)[0] + criterion_tri(torch.cat((ft22, ft23),0), lbs)[0] + criterion_tri(torch.cat((ft22, ft24),0), lbs)[0])/4 274 | loss_tri3 = (criterion_tri(torch.cat((ft31, ft33),0), lbs)[0] + criterion_tri(torch.cat((ft31, ft34),0), lbs)[0] + criterion_tri(torch.cat((ft32, ft33),0), lbs)[0] + criterion_tri(torch.cat((ft32, ft34),0), lbs)[0])/4 275 | loss_tri4 = (criterion_tri(torch.cat((ft41, ft43),0), lbs)[0] + criterion_tri(torch.cat((ft41, ft44),0), lbs)[0] + criterion_tri(torch.cat((ft42, ft43),0), lbs)[0] + criterion_tri(torch.cat((ft42, ft44),0), lbs)[0])/4 276 | 277 | loss_tri = (loss_tri1 + loss_tri2 + loss_tri3 + loss_tri4)/4 278 | 279 | loss_dcl = (criterion_div(torch.cat((ft12, ft14),0)) + criterion_div(torch.cat((ft22, ft24),0)) + criterion_div(torch.cat((ft32, ft34),0)) + criterion_div(torch.cat((ft42, ft44),0)))*0.25*args.delta 280 | 281 | 282 | correct += (ba / 2) 283 | _, predicted = out1.max(1) 284 | correct += (predicted.eq(labels).sum().item() / 2) 285 | 286 | loss = loss_id + loss_tri + loss_dcl 287 | optimizer.zero_grad() 288 | loss.backward() 289 | optimizer.step() 290 | 291 | # update running statistics 292 | train_loss.update(loss.item(), 2 * input1.size(0)) 293 | id_loss.update(loss_id.item(), 2 * input1.size(0)) 294 | tri_loss.update(loss_tri.item(), 2 * input1.size(0)) 295 | total += labels.size(0) 296 | 297 | # measure elapsed time 298 | batch_time.update(time.time() - end) 299 | end = time.time() 300 | if batch_idx % 50 == 0: 301 | print('Epoch: [{}][{}/{}] ' 302 | 'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) ' 303 | 'lr:{:.3f} ' 304 | 'Loss: {train_loss.val:.4f} ({train_loss.avg:.4f}) ' 305 | 'iLoss: {id_loss.val:.4f} ({id_loss.avg:.4f}) ' 306 | 'TLoss: {tri_loss.val:.4f} ({tri_loss.avg:.4f}) ' 307 | 'Accu: {:.2f}'.format( 308 | epoch, batch_idx, len(trainloader), current_lr, 309 | 100.
* correct / total, batch_time=batch_time, 310 | train_loss=train_loss, id_loss=id_loss, tri_loss=tri_loss)) 311 | 312 | writer.add_scalar('total_loss', train_loss.avg, epoch) 313 | writer.add_scalar('id_loss', id_loss.avg, epoch) 314 | writer.add_scalar('tri_loss', tri_loss.avg, epoch) 315 | writer.add_scalar('lr', current_lr, epoch) 316 | 317 | 318 | def test(epoch): 319 | # switch to evaluation mode 320 | net.eval() 321 | print('Extracting Gallery Feature...') 322 | start = time.time() 323 | ptr = 0 324 | gall_feat = np.zeros((ngall, 2048*4)) 325 | gall_feat_att = np.zeros((ngall, 2048*4)) 326 | Xgall_feat = np.zeros((ngall, 2048*4)) 327 | Xgall_feat_att = np.zeros((ngall, 2048*4)) 328 | with torch.no_grad(): 329 | for batch_idx, (input, label) in enumerate(gall_loader): 330 | batch_num = input.size(0) 331 | input = Variable(input.cuda()) 332 | feat, feat_att = net(input, input, test_mode[0]) 333 | gall_feat[ptr:ptr + batch_num, :] = feat[:batch_num].detach().cpu().numpy() 334 | gall_feat_att[ptr:ptr + batch_num, :] = feat_att[:batch_num].detach().cpu().numpy() 335 | Xgall_feat[ptr:ptr + batch_num, :] = feat[batch_num:].detach().cpu().numpy() 336 | Xgall_feat_att[ptr:ptr + batch_num, :] = feat_att[batch_num:].detach().cpu().numpy() 337 | ptr = ptr + batch_num 338 | print('Extracting Time:\t {:.3f}'.format(time.time() - start)) 339 | 340 | # switch to evaluation 341 | net.eval() 342 | print('Extracting Query Feature...') 343 | start = time.time() 344 | ptr = 0 345 | query_feat = np.zeros((nquery, 2048*4)) 346 | query_feat_att = np.zeros((nquery, 2048*4)) 347 | Xquery_feat = np.zeros((nquery, 2048*4)) 348 | Xquery_feat_att = np.zeros((nquery, 2048*4)) 349 | with torch.no_grad(): 350 | for batch_idx, (input, label) in enumerate(query_loader): 351 | batch_num = input.size(0) 352 | input = Variable(input.cuda()) 353 | feat, feat_att = net(input, input, test_mode[1]) 354 | query_feat[ptr:ptr + batch_num, :] = feat[:batch_num].detach().cpu().numpy() 355 | query_feat_att[ptr:ptr + batch_num, :] = feat_att[:batch_num].detach().cpu().numpy() 356 | Xquery_feat[ptr:ptr + batch_num, :] = feat[batch_num:].detach().cpu().numpy() 357 | Xquery_feat_att[ptr:ptr + batch_num, :] = feat_att[batch_num:].detach().cpu().numpy() 358 | ptr = ptr + batch_num 359 | print('Extracting Time:\t {:.3f}'.format(time.time() - start)) 360 | 361 | start = time.time() 362 | # compute the similarity 363 | distmat = np.matmul(query_feat, np.transpose(gall_feat)) 364 | distmat_att = np.matmul(query_feat_att, np.transpose(gall_feat_att)) 365 | 366 | Xdistmat = np.matmul(query_feat, np.transpose(Xgall_feat)) 367 | Xdistmat_att = np.matmul(query_feat_att, np.transpose(Xgall_feat_att)) 368 | 369 | distmatX = np.matmul(Xquery_feat, np.transpose(gall_feat)) 370 | distmat_attX = np.matmul(Xquery_feat_att, np.transpose(gall_feat_att)) 371 | 372 | XXdistmat = np.matmul(Xquery_feat, np.transpose(Xgall_feat)) 373 | XXdistmat_att = np.matmul(Xquery_feat_att, np.transpose(Xgall_feat_att)) 374 | # evaluation 375 | 376 | cmc, mAP, mINP = eval_regdb(-distmat, query_label, gall_label) 377 | cmc_att, mAP_att, mINP_att = eval_regdb(-distmat_att, query_label, gall_label) 378 | 379 | Xcmc, XmAP, XmINP = eval_regdb(-Xdistmat, query_label, gall_label) 380 | Xcmc_att, XmAP_att, XmINP_att = eval_regdb(-Xdistmat_att, query_label, gall_label) 381 | 382 | cmcX, mAPX, mINPX = eval_regdb(-distmatX, query_label, gall_label) 383 | cmc_attX, mAP_attX, mINP_attX = eval_regdb(-distmat_attX, query_label, gall_label) 384 | 385 | XXcmc, XXmAP, XXmINP = 
eval_regdb(-XXdistmat, query_label, gall_label) 386 | XXcmc_att, XXmAP_att, XXmINP_att = eval_regdb(-XXdistmat_att, query_label, gall_label) 387 | print('Evaluation Time:\t {:.3f}'.format(time.time() - start)) 388 | 389 | return cmc, mAP, mINP, cmc_att, mAP_att, mINP_att, \ 390 | Xcmc, XmAP, XmINP, Xcmc_att, XmAP_att, XmINP_att, \ 391 | cmcX, mAPX, mINPX, cmc_attX, mAP_attX, mINP_attX, \ 392 | XXcmc, XXmAP, XXmINP, XXcmc_att, XXmAP_att, XXmINP_att 393 | 394 | # training 395 | print('==> Start Training...') 396 | # start_epoch stays 0 unless a checkpoint was resumed above 397 | for epoch in range(start_epoch, 81): 398 | 399 | print('==> Preparing Data Loader...') 400 | # identity sampler 401 | sampler = IdentitySampler(trainset.train_color_label, \ 402 | trainset.train_thermal_label, color_pos, thermal_pos, args.num_pos, args.batch_size, 403 | epoch) 404 | 405 | trainset.cIndex = sampler.index1 # color index 406 | trainset.tIndex = sampler.index2 # thermal index 407 | print(epoch) 408 | print(trainset.cIndex) 409 | print(trainset.tIndex) 410 | 411 | loader_batch = args.batch_size * args.num_pos 412 | 413 | trainloader = data.DataLoader(trainset, batch_size=loader_batch, \ 414 | sampler=sampler, num_workers=args.workers, drop_last=True) 415 | 416 | # training 417 | train(epoch) 418 | 419 | if epoch > 0 and epoch % 2 == 0: 420 | print('Test Epoch: {}'.format(epoch)) 421 | 422 | # testing 423 | cmc, mAP, mINP, cmc_att, mAP_att, mINP_att, \ 424 | Xcmc, XmAP, XmINP, Xcmc_att, XmAP_att, XmINP_att, \ 425 | cmcX, mAPX, mINPX, cmc_attX, mAP_attX, mINP_attX, \ 426 | XXcmc, XXmAP, XXmINP, XXcmc_att, XXmAP_att, XXmINP_att = test(epoch) 427 | # save the best model 428 | if cmc_att[0] > best_acc: # not the real best for sysu-mm01 429 | best_acc = cmc_att[0] 430 | best_epoch = epoch 431 | state = { 432 | 'net': net.state_dict(), 433 | 'cmc': cmc_att, 434 | 'mAP': mAP_att, 435 | 'mINP': mINP_att, 436 | 'epoch': epoch, 437 | } 438 | torch.save(state, checkpoint_path + suffix + '_best.t') 439 | 440 | # save the model periodically 441 | if epoch > 10 and epoch % args.save_epoch == 0: 442 | state = { 443 | 'net': net.state_dict(), 444 | 'cmc': cmc, 445 | 'mAP': mAP, 'mINP': mINP, 446 | 'epoch': epoch, 447 | } 448 | torch.save(state, checkpoint_path + suffix + '_epoch_{}.t'.format(epoch)) 449 | 450 | print('POOL: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 451 | cmc[0], cmc[4], cmc[9], cmc[19], mAP, mINP)) 452 | print('FC: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 453 | cmc_att[0], cmc_att[4], cmc_att[9], cmc_att[19], mAP_att, mINP_att)) 454 | print('POOL: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 455 | Xcmc[0], Xcmc[4], Xcmc[9], Xcmc[19], XmAP, XmINP)) 456 | print('FC: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 457 | Xcmc_att[0], Xcmc_att[4], Xcmc_att[9], Xcmc_att[19], XmAP_att, XmINP_att)) 458 | print('POOL: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 459 | cmcX[0], cmcX[4], cmcX[9], cmcX[19], mAPX, mINPX)) 460 | print('FC: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 461 | cmc_attX[0], cmc_attX[4], cmc_attX[9], cmc_attX[19], mAP_attX, mINP_attX)) 462 | print('POOL: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 463 | XXcmc[0], XXcmc[4], XXcmc[9], XXcmc[19], XXmAP,
XXmINP)) 464 | print('FC: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 465 | XXcmc_att[0], XXcmc_att[4], XXcmc_att[9], XXcmc_att[19], XXmAP_att, XXmINP_att)) 466 | print('Best Epoch [{}]'.format(best_epoch)) -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from torch.utils.data.sampler import Sampler 4 | import sys 5 | import os.path as osp 6 | import torch 7 | import errno  # used by mkdir_if_missing below 8 | def load_data(input_data_path): 9 | with open(input_data_path, 'rt') as f: 10 | data_file_list = f.read().splitlines() 11 | # Get the full list of image paths and labels 12 | file_image = [s.split(' ')[0] for s in data_file_list] 13 | file_label = [int(s.split(' ')[1]) for s in data_file_list] 14 | 15 | return file_image, file_label 16 | 17 | 18 | def GenIdx(train_color_label, train_thermal_label): 19 | color_pos = [] 20 | unique_label_color = np.unique(train_color_label) 21 | for i in range(len(unique_label_color)): 22 | tmp_pos = [k for k,v in enumerate(train_color_label) if v==unique_label_color[i]] 23 | color_pos.append(tmp_pos) 24 | 25 | thermal_pos = [] 26 | unique_label_thermal = np.unique(train_thermal_label) 27 | for i in range(len(unique_label_thermal)): 28 | tmp_pos = [k for k,v in enumerate(train_thermal_label) if v==unique_label_thermal[i]] 29 | thermal_pos.append(tmp_pos) 30 | return color_pos, thermal_pos 31 | 32 | def GenCamIdx(gall_img, gall_label, mode): 33 | if mode == 'indoor': 34 | camIdx = [1,2] 35 | else: 36 | camIdx = [1,2,4,5] 37 | gall_cam = [] 38 | for i in range(len(gall_img)): 39 | gall_cam.append(int(gall_img[i][-10])) 40 | 41 | sample_pos = [] 42 | unique_label = np.unique(gall_label) 43 | for i in range(len(unique_label)): 44 | for j in range(len(camIdx)): 45 | id_pos = [k for k,v in enumerate(gall_label) if v==unique_label[i] and gall_cam[k]==camIdx[j]] 46 | if id_pos: 47 | sample_pos.append(id_pos) 48 | return sample_pos 49 | 50 | def ExtractCam(gall_img): 51 | gall_cam = [] 52 | for i in range(len(gall_img)): 53 | cam_id = int(gall_img[i][-10]) 54 | # if cam_id ==3: 55 | # cam_id = 2 56 | gall_cam.append(cam_id) 57 | 58 | return np.array(gall_cam) 59 | 60 | class IdentitySampler(Sampler): 61 | """Sample person identities evenly in each batch.
62 | Args: 63 | train_color_label, train_thermal_label: labels of the two modalities 64 | color_pos, thermal_pos: positions of each identity 65 | batchSize: batch size 66 | """ 67 | 68 | def __init__(self, train_color_label, train_thermal_label, color_pos, thermal_pos, num_pos, batchSize, epoch): 69 | uni_label = np.unique(train_color_label) 70 | self.n_classes = len(uni_label) 71 | 72 | 73 | N = np.maximum(len(train_color_label), len(train_thermal_label)) 74 | for j in range(int(N/(batchSize*num_pos))+1): 75 | batch_idx = np.random.choice(uni_label, batchSize, replace=False) 76 | for i in range(batchSize): 77 | sample_color = np.random.choice(color_pos[batch_idx[i]], num_pos) 78 | sample_thermal = np.random.choice(thermal_pos[batch_idx[i]], num_pos) 79 | 80 | if j == 0 and i == 0: 81 | index1 = sample_color 82 | index2 = sample_thermal 83 | else: 84 | index1 = np.hstack((index1, sample_color)) 85 | index2 = np.hstack((index2, sample_thermal)) 86 | 87 | self.index1 = index1 88 | self.index2 = index2 89 | self.N = N 90 | 91 | def __iter__(self): 92 | return iter(np.arange(len(self.index1))) 93 | 94 | def __len__(self): 95 | return self.N 96 | 97 | class AverageMeter(object): 98 | """Computes and stores the average and current value""" 99 | def __init__(self): 100 | self.reset() 101 | 102 | def reset(self): 103 | self.val = 0 104 | self.avg = 0 105 | self.sum = 0 106 | self.count = 0 107 | 108 | def update(self, val, n=1): 109 | self.val = val 110 | self.sum += val * n 111 | self.count += n 112 | self.avg = self.sum / self.count 113 | 114 | def mkdir_if_missing(directory): 115 | if not osp.exists(directory): 116 | try: 117 | os.makedirs(directory) 118 | except OSError as e: 119 | if e.errno != errno.EEXIST: 120 | raise 121 | class Logger(object): 122 | """ 123 | Write console output to external text file. 124 | Code imported from https://github.com/Cysu/open-reid/blob/master/reid/utils/logging.py. 125 | """ 126 | def __init__(self, fpath=None): 127 | self.console = sys.stdout 128 | self.file = None 129 | if fpath is not None: 130 | mkdir_if_missing(osp.dirname(fpath)) 131 | self.file = open(fpath, 'w') 132 | 133 | def __del__(self): 134 | self.close() 135 | 136 | def __enter__(self): 137 | return self  # support use as a context manager 138 | 139 | def __exit__(self, *args): 140 | self.close() 141 | 142 | def write(self, msg): 143 | self.console.write(msg) 144 | if self.file is not None: 145 | self.file.write(msg) 146 | 147 | def flush(self): 148 | self.console.flush() 149 | if self.file is not None: 150 | self.file.flush() 151 | os.fsync(self.file.fileno()) 152 | 153 | def close(self): 154 | self.console.close() 155 | if self.file is not None: 156 | self.file.close() 157 | 158 | def set_seed(seed, cuda=True): 159 | np.random.seed(seed) 160 | torch.manual_seed(seed) 161 | if cuda: 162 | torch.cuda.manual_seed(seed) 163 | 164 | def set_requires_grad(nets, requires_grad=False): 165 | """Set requires_grad=False for all the networks to avoid unnecessary computations 166 | Parameters: 167 | nets (network list) -- a list of networks 168 | requires_grad (bool) -- whether the networks require gradients or not 169 | """ 170 | if not isinstance(nets, list): 171 | nets = [nets] 172 | for net in nets: 173 | if net is not None: 174 | for param in net.parameters(): 175 | param.requires_grad = requires_grad --------------------------------------------------------------------------------
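A closing note on the evaluation logic in test.py above: the four similarity matrices built from the (real, middle-modality) query/gallery feature pairs are fused by summation, element-wise minimum, and element-wise maximum before ranking. The minimal NumPy sketch below reproduces just that fusion step; the shapes, names, and random stand-in features are illustrative only, not part of the repository.

```python
import numpy as np

# Stand-in L2-normalized features: 4 query and 6 gallery samples, 8-dim each.
rng = np.random.default_rng(0)

def l2norm(x):
    return x / np.linalg.norm(x, axis=1, keepdims=True)

query_fc, Xquery_fc = l2norm(rng.normal(size=(4, 8))), l2norm(rng.normal(size=(4, 8)))
gall_fc, Xgall_fc = l2norm(rng.normal(size=(6, 8))), l2norm(rng.normal(size=(6, 8)))

# Four cosine-similarity matrices: real/real, real/middle, middle/real, middle/middle.
d0 = query_fc @ gall_fc.T
d1 = query_fc @ Xgall_fc.T
d2 = Xquery_fc @ gall_fc.T
d3 = Xquery_fc @ Xgall_fc.T

# Fusions as in test.py; eval_regdb receives the negated matrix, so a higher
# similarity acts as a smaller "distance" and hence a better rank.
d_sum = d0 + d1 + d2 + d3
d_min = np.minimum(np.minimum(d0, d1), np.minimum(d2, d3))
d_max = np.maximum(np.maximum(d0, d1), np.maximum(d2, d3))
print(d_sum.shape, d_min.shape, d_max.shape)  # (4, 6) each
```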