├── LICENSE ├── README.md ├── data_loader.py ├── data_manager.py ├── eval_metrics.py ├── loss.py ├── model_mine.py ├── pre_process_sysu.py ├── random_aug.py ├── re_rank.py ├── resnet.py ├── run.sh ├── test.py ├── test.sh ├── test_mine_pcb.py ├── train_mine.py └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Visual Computing Lab -- IISc 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MMD-ReID 2 | Pytorch implementation for MMD-ReID: A Simple but Effective solution for Visible-Thermal Person ReID. Accepted at BMVC 2021 (Oral) 3 | 4 | **Paper link**: https://arxiv.org/abs/2111.05059 5 | 6 | **Github Code**: https://github.com/vcl-iisc/MMD-ReID 7 | 8 | **Presentation Slides**: https://drive.google.com/file/d/1S0sfA7PMyzqGPnG5izGBeZ7uClsJ1uA3/view?usp=sharing 9 | 10 | **Project webpage**: https://vcl-iisc.github.io/mmd-reid-web/ 11 | 12 | **Recorded Talk**: https://recorder-v3.slideslive.com/?share=55344&s=d3b53e98-4362-410a-825d-77706f8b71c4 13 | 14 | 15 | ### Dependencies: 16 | - Python 3.7 17 | - GPU memory ~ 10G 18 | - NumPy 1.19 19 | - PyTorch 1.8 20 | 21 | ### How to use this code: 22 | Our code extends the pytorch implementation of [Parameter Sharing Exploration and Hetero center triplet loss for VT Re-ID](https://github.com/hijune6/Hetero-center-triplet-loss-for-VT-Re-ID) in Github. Please refer to the offical repo for details of data preparation. 23 | 24 | ### Training: 25 | ``` 26 | python train_mine.py --dataset sysu --gpu 1 --pcb off --share_net 3 --batch-size 4 --num_pos 4 --dist_disc 'margin_mmd' --margin_mmd 1.40 --run_name 'margin_mmd1.40' 27 | ``` 28 | 29 | ### Testing: 30 | ``` 31 | python test.py --dataset sysu --gpu 0 --pcb off --share_net 3 --batch-size 4 --num_pos 4 --run_name 'margin_mmd1.40' 32 | ``` 33 | 34 | ### Results: 35 | 36 | | | Rank@1 | Rank@10 | Rank@20 | mAP | 37 | |---|--------------|----------------|----------|-----------| 38 | | SYSU-MM01 (All search Single shot) | 66.75% | 94.16% | 97.38% | 62.25% | 39 | | RegDB (Visible to Thermal) | 95.06% | 98.67% | 99.31% | 88.95% | 40 | 41 | ### Citation 42 | If you use this code, please cite our work as: 43 | ```bibtex 44 | @inproceedings{jambigi2021mmd, 45 | title={MMD-ReID: A Simple but Effective solution for Visible-Thermal Person ReID}, 46 | author={Jambigi, Chaitra and Rawal, Ruchit and Chakraborty, Anirban}, 47 | booktitle={British Machine Vision Conference}, 48 | year={2021} 49 | } 50 | ``` 51 | -------------------------------------------------------------------------------- /data_loader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | import torch.utils.data as data 4 | 5 | 6 | class SYSUData(data.Dataset): 7 | def __init__(self, data_dir, transform=None, colorIndex = None, thermalIndex = None): 8 | 9 | data_dir = './SYSU-MM01/' 10 | # Load training images (path) and labels 11 | train_color_image = np.load(data_dir + 'train_rgb_resized_img.npy') 12 | self.train_color_label = np.load(data_dir + 'train_rgb_resized_label.npy') 13 | 14 | train_thermal_image = np.load(data_dir + 'train_ir_resized_img.npy') 15 | self.train_thermal_label = np.load(data_dir + 'train_ir_resized_label.npy') 16 | 17 | # BGR to RGB 18 | self.train_color_image = train_color_image 19 | self.train_thermal_image = train_thermal_image 20 | self.transform = transform 21 | self.cIndex = colorIndex 22 | self.tIndex = thermalIndex 23 | 24 | def __getitem__(self, index): 25 | 26 | img1, target1 = self.train_color_image[self.cIndex[index]], self.train_color_label[self.cIndex[index]] 27 | img2, target2 = self.train_thermal_image[self.tIndex[index]], self.train_thermal_label[self.tIndex[index]] 28 | 29 | img1 = self.transform(img1) 30 | img2 = self.transform(img2) 31 | 32 | return img1, img2, target1, target2 33 | 34 | def __len__(self): 35 | return len(self.train_color_label) 36 | 37 | 38 | class RegDBData(data.Dataset): 39 | def __init__(self, data_dir, trial, transform=None, colorIndex = None, thermalIndex = None): 40 | # Load training images (path) and labels 41 | data_dir = './RegDB/' 42 | train_color_list = data_dir + 'idx/train_visible_{}'.format(trial)+ '.txt' 43 | train_thermal_list = data_dir + 'idx/train_thermal_{}'.format(trial)+ '.txt' 44 | 45 | color_img_file, train_color_label = load_data(train_color_list) 46 | thermal_img_file, train_thermal_label = load_data(train_thermal_list) 47 | 48 | train_color_image = [] 49 | for i in range(len(color_img_file)): 50 | 51 | img = Image.open(data_dir+ color_img_file[i]) 52 | img = img.resize((144, 288), Image.ANTIALIAS) 53 | pix_array = np.array(img) 54 | train_color_image.append(pix_array) 55 | train_color_image = np.array(train_color_image) 56 | 57 | train_thermal_image = [] 58 | for i in range(len(thermal_img_file)): 59 | img = Image.open(data_dir+ thermal_img_file[i]) 60 | img = img.resize((144, 288), Image.ANTIALIAS) 61 | pix_array = np.array(img) 62 | train_thermal_image.append(pix_array) 63 | train_thermal_image = np.array(train_thermal_image) 64 | 65 | # BGR to RGB 66 | self.train_color_image = train_color_image 67 | self.train_color_label = train_color_label 68 | 69 | # BGR to RGB 70 | self.train_thermal_image = train_thermal_image 71 | self.train_thermal_label = train_thermal_label 72 | 73 | self.transform = transform 74 | self.cIndex = colorIndex 75 | self.tIndex = thermalIndex 76 | 77 | def __getitem__(self, index): 78 | 79 | img1, target1 = self.train_color_image[self.cIndex[index]], self.train_color_label[self.cIndex[index]] 80 | img2, target2 = self.train_thermal_image[self.tIndex[index]], self.train_thermal_label[self.tIndex[index]] 81 | 82 | img1 = self.transform(img1) 83 | img2 = self.transform(img2) 84 | 85 | return img1, img2, target1, target2 86 | 87 | def __len__(self): 88 | return len(self.train_color_label) 89 | 90 | class TestData(data.Dataset): 91 | def __init__(self, test_img_file, test_label, transform=None, img_size = (144,288)): 92 | 93 | test_image = [] 94 | for i in range(len(test_img_file)): 95 | img = Image.open(test_img_file[i]) 96 | img = img.resize((img_size[0], img_size[1]), Image.ANTIALIAS) 97 | pix_array = np.array(img) 98 | test_image.append(pix_array) 99 | test_image = np.array(test_image) 100 | self.test_image = test_image 101 | self.test_label = test_label 102 | self.transform = transform 103 | 104 | def __getitem__(self, index): 105 | img1, target1 = self.test_image[index], self.test_label[index] 106 | img1 = self.transform(img1) 107 | return img1, target1 108 | 109 | def __len__(self): 110 | return len(self.test_image) 111 | 112 | class TestDataOld(data.Dataset): 113 | def __init__(self, data_dir, test_img_file, test_label, transform=None, img_size = (144,288)): 114 | 115 | test_image = [] 116 | for i in range(len(test_img_file)): 117 | img = Image.open(data_dir + test_img_file[i]) 118 | img = img.resize((img_size[0], img_size[1]), Image.ANTIALIAS) 119 | pix_array = np.array(img) 120 | test_image.append(pix_array) 121 | test_image = np.array(test_image) 122 | self.test_image = test_image 123 | self.test_label = test_label 124 | self.transform = transform 125 | 126 | def __getitem__(self, index): 127 | img1, target1 = self.test_image[index], self.test_label[index] 128 | img1 = self.transform(img1) 129 | return img1, target1 130 | 131 | def __len__(self): 132 | return len(self.test_image) 133 | def load_data(input_data_path ): 134 | with open(input_data_path) as f: 135 | data_file_list = open(input_data_path, 'rt').read().splitlines() 136 | # Get full list of image and labels 137 | file_image = [s.split(' ')[0] for s in data_file_list] 138 | file_label = [int(s.split(' ')[1]) for s in data_file_list] 139 | 140 | return file_image, file_label -------------------------------------------------------------------------------- /data_manager.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os 3 | import numpy as np 4 | import random 5 | 6 | def process_query_sysu(data_path, mode = 'all', relabel=False): 7 | if mode== 'all': 8 | ir_cameras = ['cam3','cam6'] 9 | elif mode =='indoor': 10 | ir_cameras = ['cam3','cam6'] 11 | 12 | file_path = os.path.join(data_path,'exp/test_id.txt') 13 | files_rgb = [] 14 | files_ir = [] 15 | 16 | with open(file_path, 'r') as file: 17 | ids = file.read().splitlines() 18 | ids = [int(y) for y in ids[0].split(',')] 19 | ids = ["%04d" % x for x in ids] 20 | 21 | for id in sorted(ids): 22 | for cam in ir_cameras: 23 | img_dir = os.path.join(data_path,cam,id) 24 | if os.path.isdir(img_dir): 25 | new_files = sorted([img_dir+'/'+i for i in os.listdir(img_dir)]) 26 | files_ir.extend(new_files) 27 | query_img = [] 28 | query_id = [] 29 | query_cam = [] 30 | for img_path in files_ir: 31 | camid, pid = int(img_path[-15]), int(img_path[-13:-9]) 32 | query_img.append(img_path) 33 | query_id.append(pid) 34 | query_cam.append(camid) 35 | return query_img, np.array(query_id), np.array(query_cam) 36 | 37 | def process_gallery_sysu(data_path, mode = 'all', trial = 0, relabel=False): 38 | 39 | random.seed(trial) 40 | 41 | if mode== 'all': 42 | rgb_cameras = ['cam1','cam2','cam4','cam5'] 43 | elif mode =='indoor': 44 | rgb_cameras = ['cam1','cam2'] 45 | 46 | file_path = os.path.join(data_path,'exp/test_id.txt') 47 | files_rgb = [] 48 | with open(file_path, 'r') as file: 49 | ids = file.read().splitlines() 50 | ids = [int(y) for y in ids[0].split(',')] 51 | ids = ["%04d" % x for x in ids] 52 | 53 | for id in sorted(ids): 54 | for cam in rgb_cameras: 55 | img_dir = os.path.join(data_path,cam,id) 56 | if os.path.isdir(img_dir): 57 | new_files = sorted([img_dir+'/'+i for i in os.listdir(img_dir)]) 58 | files_rgb.append(random.choice(new_files)) 59 | gall_img = [] 60 | gall_id = [] 61 | gall_cam = [] 62 | for img_path in files_rgb: 63 | camid, pid = int(img_path[-15]), int(img_path[-13:-9]) 64 | gall_img.append(img_path) 65 | gall_id.append(pid) 66 | gall_cam.append(camid) 67 | return gall_img, np.array(gall_id), np.array(gall_cam) 68 | 69 | def process_test_regdb(img_dir, trial = 1, modal = 'visible'): 70 | if modal=='visible': 71 | input_data_path = img_dir + 'idx/test_visible_{}'.format(trial) + '.txt' 72 | elif modal=='thermal': 73 | input_data_path = img_dir + 'idx/test_thermal_{}'.format(trial) + '.txt' 74 | 75 | with open(input_data_path) as f: 76 | data_file_list = open(input_data_path, 'rt').read().splitlines() 77 | # Get full list of image and labels 78 | file_image = [img_dir + '/' + s.split(' ')[0] for s in data_file_list] 79 | file_label = [int(s.split(' ')[1]) for s in data_file_list] 80 | 81 | return file_image, np.array(file_label) -------------------------------------------------------------------------------- /eval_metrics.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import numpy as np 3 | """Cross-Modality ReID""" 4 | import pdb 5 | 6 | def eval_sysu(distmat, q_pids, g_pids, q_camids, g_camids, max_rank = 20): 7 | """Evaluation with sysu metric 8 | Key: for each query identity, its gallery images from the same camera view are discarded. "Following the original setting in ite dataset" 9 | """ 10 | num_q, num_g = distmat.shape 11 | if num_g < max_rank: 12 | max_rank = num_g 13 | print("Note: number of gallery samples is quite small, got {}".format(num_g)) 14 | indices = np.argsort(distmat, axis=1) 15 | pred_label = g_pids[indices] 16 | matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32) 17 | 18 | # compute cmc curve for each query 19 | new_all_cmc = [] 20 | all_cmc = [] 21 | all_AP = [] 22 | all_INP = [] 23 | num_valid_q = 0. # number of valid query 24 | for q_idx in range(num_q): 25 | # get query pid and camid 26 | q_pid = q_pids[q_idx] 27 | q_camid = q_camids[q_idx] 28 | 29 | # remove gallery samples that have the same pid and camid with query 30 | order = indices[q_idx] 31 | remove = (q_camid == 3) & (g_camids[order] == 2) 32 | keep = np.invert(remove) 33 | 34 | # compute cmc curve 35 | # the cmc calculation is different from standard protocol 36 | # we follow the protocol of the author's released code 37 | new_cmc = pred_label[q_idx][keep] 38 | new_index = np.unique(new_cmc, return_index=True)[1] 39 | new_cmc = [new_cmc[index] for index in sorted(new_index)] 40 | 41 | new_match = (new_cmc == q_pid).astype(np.int32) 42 | new_cmc = new_match.cumsum() 43 | new_all_cmc.append(new_cmc[:max_rank]) 44 | 45 | orig_cmc = matches[q_idx][keep] # binary vector, positions with value 1 are correct matches 46 | if not np.any(orig_cmc): 47 | # this condition is true when query identity does not appear in gallery 48 | continue 49 | 50 | cmc = orig_cmc.cumsum() 51 | 52 | # compute mINP 53 | # refernece Deep Learning for Person Re-identification: A Survey and Outlook 54 | pos_idx = np.where(orig_cmc == 1) 55 | pos_max_idx = np.max(pos_idx) 56 | inp = cmc[pos_max_idx]/ (pos_max_idx + 1.0) 57 | all_INP.append(inp) 58 | 59 | cmc[cmc > 1] = 1 60 | 61 | all_cmc.append(cmc[:max_rank]) 62 | num_valid_q += 1. 63 | 64 | # compute average precision 65 | # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision 66 | num_rel = orig_cmc.sum() 67 | tmp_cmc = orig_cmc.cumsum() 68 | tmp_cmc = [x / (i+1.) for i, x in enumerate(tmp_cmc)] 69 | tmp_cmc = np.asarray(tmp_cmc) * orig_cmc 70 | AP = tmp_cmc.sum() / num_rel 71 | all_AP.append(AP) 72 | 73 | assert num_valid_q > 0, "Error: all query identities do not appear in gallery" 74 | 75 | all_cmc = np.asarray(all_cmc).astype(np.float32) 76 | all_cmc = all_cmc.sum(0) / num_valid_q # standard CMC 77 | 78 | new_all_cmc = np.asarray(new_all_cmc).astype(np.float32) 79 | new_all_cmc = new_all_cmc.sum(0) / num_valid_q 80 | mAP = np.mean(all_AP) 81 | mINP = np.mean(all_INP) 82 | return new_all_cmc, mAP, mINP 83 | 84 | 85 | 86 | def eval_regdb(distmat, q_pids, g_pids, max_rank = 20): 87 | num_q, num_g = distmat.shape 88 | if num_g < max_rank: 89 | max_rank = num_g 90 | print("Note: number of gallery samples is quite small, got {}".format(num_g)) 91 | indices = np.argsort(distmat, axis=1) 92 | matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32) 93 | 94 | # compute cmc curve for each query 95 | all_cmc = [] 96 | all_AP = [] 97 | all_INP = [] 98 | num_valid_q = 0. # number of valid query 99 | 100 | # only two cameras 101 | q_camids = np.ones(num_q).astype(np.int32) 102 | g_camids = 2* np.ones(num_g).astype(np.int32) 103 | 104 | for q_idx in range(num_q): 105 | # get query pid and camid 106 | q_pid = q_pids[q_idx] 107 | q_camid = q_camids[q_idx] 108 | 109 | # remove gallery samples that have the same pid and camid with query 110 | order = indices[q_idx] 111 | remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid) 112 | keep = np.invert(remove) 113 | 114 | # compute cmc curve 115 | raw_cmc = matches[q_idx][keep] # binary vector, positions with value 1 are correct matches 116 | if not np.any(raw_cmc): 117 | # this condition is true when query identity does not appear in gallery 118 | continue 119 | 120 | cmc = raw_cmc.cumsum() 121 | 122 | # compute mINP 123 | # refernece Deep Learning for Person Re-identification: A Survey and Outlook 124 | pos_idx = np.where(raw_cmc == 1) 125 | pos_max_idx = np.max(pos_idx) 126 | inp = cmc[pos_max_idx]/ (pos_max_idx + 1.0) 127 | all_INP.append(inp) 128 | 129 | cmc[cmc > 1] = 1 130 | 131 | all_cmc.append(cmc[:max_rank]) 132 | num_valid_q += 1. 133 | 134 | # compute average precision 135 | # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision 136 | num_rel = raw_cmc.sum() 137 | tmp_cmc = raw_cmc.cumsum() 138 | tmp_cmc = [x / (i+1.) for i, x in enumerate(tmp_cmc)] 139 | tmp_cmc = np.asarray(tmp_cmc) * raw_cmc 140 | AP = tmp_cmc.sum() / num_rel 141 | all_AP.append(AP) 142 | 143 | assert num_valid_q > 0, "Error: all query identities do not appear in gallery" 144 | 145 | all_cmc = np.asarray(all_cmc).astype(np.float32) 146 | all_cmc = all_cmc.sum(0) / num_valid_q 147 | mAP = np.mean(all_AP) 148 | mINP = np.mean(all_INP) 149 | return all_cmc, mAP, mINP -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.autograd.function import Function 6 | from torch.autograd import Variable 7 | 8 | 9 | class CenterTripletLoss(nn.Module): 10 | """ Hetero-center-triplet-loss-for-VT-Re-ID 11 | "Parameters Sharing Exploration and Hetero-Center Triplet Loss for Visible-Thermal Person Re-Identification" 12 | [(arxiv)](https://arxiv.org/abs/2008.06223). 13 | 14 | Args: 15 | - margin (float): margin for triplet. 16 | """ 17 | 18 | def __init__(self, batch_size, margin=0.3): 19 | super(CenterTripletLoss, self).__init__() 20 | self.margin = margin 21 | self.ranking_loss = nn.MarginRankingLoss(margin=margin) 22 | 23 | def forward(self, feats, labels): 24 | """ 25 | Args: 26 | - inputs: feature matrix with shape (batch_size, feat_dim) 27 | - targets: ground truth labels with shape (num_classes) 28 | """ 29 | label_uni = labels.unique() 30 | targets = torch.cat([label_uni,label_uni]) 31 | label_num = len(label_uni) 32 | feat = feats.chunk(label_num*2, 0) 33 | center = [] 34 | for i in range(label_num*2): 35 | center.append(torch.mean(feat[i], dim=0, keepdim=True)) 36 | inputs = torch.cat(center) 37 | 38 | n = inputs.size(0) 39 | 40 | # Compute pairwise distance, replace by the official when merged 41 | dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n) 42 | dist = dist + dist.t() 43 | dist.addmm_(1, -2, inputs, inputs.t()) 44 | dist = dist.clamp(min=1e-12).sqrt() # for numerical stability 45 | 46 | # For each anchor, find the hardest positive and negative 47 | mask = targets.expand(n, n).eq(targets.expand(n, n).t()) 48 | dist_ap, dist_an = [], [] 49 | for i in range(n): 50 | dist_ap.append(dist[i][mask[i]].max().unsqueeze(0)) 51 | dist_an.append(dist[i][mask[i] == 0].min().unsqueeze(0)) 52 | dist_ap = torch.cat(dist_ap) 53 | dist_an = torch.cat(dist_an) 54 | 55 | # Compute ranking hinge loss 56 | y = torch.ones_like(dist_an) 57 | loss = self.ranking_loss(dist_an, dist_ap, y) 58 | 59 | # compute accuracy 60 | correct = torch.ge(dist_an, dist_ap).sum().item() 61 | return loss, correct 62 | 63 | 64 | 65 | 66 | 67 | class CrossEntropyLabelSmooth(nn.Module): 68 | """Cross entropy loss with label smoothing regularizer. 69 | Reference: 70 | Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016. 71 | Equation: y = (1 - epsilon) * y + epsilon / K. 72 | Args: 73 | num_classes (int): number of classes. 74 | epsilon (float): weight. 75 | """ 76 | def __init__(self, num_classes, epsilon=0.1, use_gpu=True): 77 | super(CrossEntropyLabelSmooth, self).__init__() 78 | self.num_classes = num_classes 79 | self.epsilon = epsilon 80 | self.use_gpu = use_gpu 81 | self.logsoftmax = nn.LogSoftmax(dim=1) 82 | 83 | def forward(self, inputs, targets): 84 | """ 85 | Args: 86 | inputs: prediction matrix (before softmax) with shape (batch_size, num_classes) 87 | targets: ground truth labels with shape (num_classes) 88 | """ 89 | log_probs = self.logsoftmax(inputs) 90 | targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).data.cpu(), 1) 91 | if self.use_gpu: targets = targets.cuda() 92 | targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes 93 | loss = (- targets * log_probs).mean(0).sum() 94 | return loss 95 | 96 | 97 | class OriTripletLoss(nn.Module): 98 | """Triplet loss with hard positive/negative mining. 99 | 100 | Reference: 101 | Hermans et al. In Defense of the Triplet Loss for Person Re-Identification. arXiv:1703.07737. 102 | Code imported from https://github.com/Cysu/open-reid/blob/master/reid/loss/triplet.py. 103 | 104 | Args: 105 | - margin (float): margin for triplet. 106 | """ 107 | 108 | def __init__(self, batch_size, margin=0.3): 109 | super(OriTripletLoss, self).__init__() 110 | self.margin = margin 111 | self.ranking_loss = nn.MarginRankingLoss(margin=margin) 112 | 113 | def forward(self, inputs, targets): 114 | """ 115 | Args: 116 | - inputs: feature matrix with shape (batch_size, feat_dim) 117 | - targets: ground truth labels with shape (num_classes) 118 | """ 119 | n = inputs.size(0) 120 | 121 | # Compute pairwise distance, replace by the official when merged 122 | dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n) 123 | dist = dist + dist.t() 124 | dist.addmm_(1, -2, inputs, inputs.t()) 125 | dist = dist.clamp(min=1e-12).sqrt() # for numerical stability 126 | 127 | # For each anchor, find the hardest positive and negative 128 | mask = targets.expand(n, n).eq(targets.expand(n, n).t()) 129 | dist_ap, dist_an = [], [] 130 | for i in range(n): 131 | dist_ap.append(dist[i][mask[i]].max().unsqueeze(0)) 132 | dist_an.append(dist[i][mask[i] == 0].min().unsqueeze(0)) 133 | dist_ap = torch.cat(dist_ap) 134 | dist_an = torch.cat(dist_an) 135 | 136 | # Compute ranking hinge loss 137 | y = torch.ones_like(dist_an) 138 | loss = self.ranking_loss(dist_an, dist_ap, y) 139 | 140 | # compute accuracy 141 | correct = torch.ge(dist_an, dist_ap).sum().item() 142 | return loss, correct 143 | 144 | 145 | 146 | 147 | # Adaptive weights 148 | def softmax_weights(dist, mask): 149 | max_v = torch.max(dist * mask, dim=1, keepdim=True)[0] 150 | diff = dist - max_v 151 | Z = torch.sum(torch.exp(diff) * mask, dim=1, keepdim=True) + 1e-6 # avoid division by zero 152 | W = torch.exp(diff) * mask / Z 153 | return W 154 | 155 | def normalize(x, axis=-1): 156 | """Normalizing to unit length along the specified dimension. 157 | Args: 158 | x: pytorch Variable 159 | Returns: 160 | x: pytorch Variable, same shape as input 161 | """ 162 | x = 1. * x / (torch.norm(x, 2, axis, keepdim=True).expand_as(x) + 1e-12) 163 | return x 164 | 165 | class TripletLoss_WRT(nn.Module): 166 | """Weighted Regularized Triplet'.""" 167 | 168 | def __init__(self): 169 | super(TripletLoss_WRT, self).__init__() 170 | self.ranking_loss = nn.SoftMarginLoss() 171 | 172 | def forward(self, inputs, targets, normalize_feature=False): 173 | if normalize_feature: 174 | inputs = normalize(inputs, axis=-1) 175 | dist_mat = pdist_torch(inputs, inputs) 176 | 177 | N = dist_mat.size(0) 178 | # shape [N, N] 179 | is_pos = targets.expand(N, N).eq(targets.expand(N, N).t()).float() 180 | is_neg = targets.expand(N, N).ne(targets.expand(N, N).t()).float() 181 | 182 | # `dist_ap` means distance(anchor, positive) 183 | # both `dist_ap` and `relative_p_inds` with shape [N, 1] 184 | dist_ap = dist_mat * is_pos 185 | dist_an = dist_mat * is_neg 186 | 187 | weights_ap = softmax_weights(dist_ap, is_pos) 188 | weights_an = softmax_weights(-dist_an, is_neg) 189 | furthest_positive = torch.sum(dist_ap * weights_ap, dim=1) 190 | closest_negative = torch.sum(dist_an * weights_an, dim=1) 191 | 192 | y = furthest_positive.new().resize_as_(furthest_positive).fill_(1) 193 | loss = self.ranking_loss(closest_negative - furthest_positive, y) 194 | 195 | 196 | # compute accuracy 197 | correct = torch.ge(closest_negative, furthest_positive).sum().item() 198 | return loss, correct 199 | 200 | def pdist_torch(emb1, emb2): 201 | ''' 202 | compute the eucilidean distance matrix between embeddings1 and embeddings2 203 | using gpu 204 | ''' 205 | m, n = emb1.shape[0], emb2.shape[0] 206 | emb1_pow = torch.pow(emb1, 2).sum(dim = 1, keepdim = True).expand(m, n) 207 | emb2_pow = torch.pow(emb2, 2).sum(dim = 1, keepdim = True).expand(n, m).t() 208 | dist_mtx = emb1_pow + emb2_pow 209 | dist_mtx = dist_mtx.addmm_(1, -2, emb1, emb2.t()) 210 | # dist_mtx = dist_mtx.clamp(min = 1e-12) 211 | dist_mtx = dist_mtx.clamp(min = 1e-12).sqrt() 212 | return dist_mtx 213 | 214 | 215 | def pdist_np(emb1, emb2): 216 | ''' 217 | compute the eucilidean distance matrix between embeddings1 and embeddings2 218 | using cpu 219 | ''' 220 | m, n = emb1.shape[0], emb2.shape[0] 221 | emb1_pow = np.square(emb1).sum(axis = 1)[..., np.newaxis] 222 | emb2_pow = np.square(emb2).sum(axis = 1)[np.newaxis, ...] 223 | dist_mtx = -2 * np.matmul(emb1, emb2.T) + emb1_pow + emb2_pow 224 | # dist_mtx = np.sqrt(dist_mtx.clip(min = 1e-12)) 225 | return dist_mtx 226 | 227 | class MMD_Loss(nn.Module): 228 | def __init__(self, kernel_mul = 2.0, kernel_num = 5): 229 | super(MMD_Loss, self).__init__() 230 | self.kernel_num = kernel_num 231 | self.kernel_mul = kernel_mul 232 | self.fix_sigma = None 233 | return 234 | def guassian_kernel(self, source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None): 235 | n_samples = int(source.size()[0])+int(target.size()[0]) 236 | total = torch.cat([source, target], dim=0) 237 | 238 | total0 = total.unsqueeze(0).expand(int(total.size(0)), int(total.size(0)), int(total.size(1))) 239 | total1 = total.unsqueeze(1).expand(int(total.size(0)), int(total.size(0)), int(total.size(1))) 240 | L2_distance = ((total0-total1)**2).sum(2) 241 | if fix_sigma: 242 | bandwidth = fix_sigma 243 | else: 244 | bandwidth = torch.sum(L2_distance.data) / (n_samples**2-n_samples) 245 | bandwidth /= kernel_mul ** (kernel_num // 2) 246 | bandwidth_list = [bandwidth * (kernel_mul**i) + 1e-9 for i in range(kernel_num)] 247 | kernel_val = [torch.exp(-L2_distance / bandwidth_temp) for bandwidth_temp in bandwidth_list] 248 | return sum(kernel_val), L2_distance 249 | 250 | def forward(self, source, target): 251 | 252 | xx_batch, yy_batch, xy_batch, yx_batch = 0,0,0,0 253 | 254 | batch_size = int(source.size()[0]) 255 | kernels, L2dist = self.guassian_kernel(source, target, kernel_mul=self.kernel_mul, kernel_num=self.kernel_num, fix_sigma=self.fix_sigma) 256 | XX = kernels[:batch_size, :batch_size] 257 | YY = kernels[batch_size:, batch_size:] 258 | XY = kernels[:batch_size, batch_size:] 259 | YX = kernels[batch_size:, :batch_size] 260 | 261 | xx_batch = torch.mean(XX) 262 | yy_batch = torch.mean(YY) 263 | xy_batch = torch.mean(XY) 264 | yx_batch = torch.mean(YX) 265 | 266 | loss = torch.mean(XX + YY - XY -YX) 267 | return loss, torch.max(L2dist), [xx_batch, yy_batch, xy_batch, yx_batch] 268 | 269 | 270 | class MarginMMD_Loss(nn.Module): 271 | def __init__(self, kernel_mul = 2.0, kernel_num = 5, P=4, K=4, margin=None): 272 | super(MarginMMD_Loss, self).__init__() 273 | self.kernel_num = kernel_num 274 | self.kernel_mul = kernel_mul 275 | self.fix_sigma = None 276 | self.P = P 277 | self.K = K 278 | self.margin = margin 279 | if self.margin: 280 | print(f'Using Margin : {self.margin}') 281 | return 282 | 283 | def guassian_kernel(self, source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None): 284 | 285 | n_samples = int(source.size()[0])+int(target.size()[0]) 286 | total = torch.cat([source, target], dim=0) 287 | 288 | total0 = total.unsqueeze(0).expand(int(total.size(0)), int(total.size(0)), int(total.size(1))) 289 | total1 = total.unsqueeze(1).expand(int(total.size(0)), int(total.size(0)), int(total.size(1))) 290 | L2_distance = ((total0-total1)**2).sum(2) 291 | if fix_sigma: 292 | bandwidth = fix_sigma 293 | else: 294 | bandwidth = torch.sum(L2_distance.data) / (n_samples**2-n_samples) 295 | bandwidth /= kernel_mul ** (kernel_num // 2) 296 | bandwidth_list = [bandwidth * (kernel_mul**i) + 1e-9 for i in range(kernel_num)] 297 | kernel_val = [torch.exp(-L2_distance / bandwidth_temp) for bandwidth_temp in bandwidth_list] 298 | if torch.sum(torch.isnan(sum(kernel_val))): 299 | ## We encountered a Nan in Kernel 300 | print(f'Bandwidth List : {bandwidth_list}') 301 | print(f'L2 Distance : {L2_distance}') 302 | ## Check for Nan in L2 distance 303 | print(f'L2 Nan : {torch.sum(torch.isnan(L2_distance))}') 304 | for bandwidth_temp in bandwidth_list: 305 | print(f'Temp: {bandwidth_temp}') 306 | print(f'BW Nan : {torch.sum(torch.isnan(L2_distance / bandwidth_temp))}') 307 | return sum(kernel_val), L2_distance 308 | 309 | def forward(self, source, target, labels1=None, labels2=None): 310 | ## Source - [P*K, 2048], Target - [P*K, 2048] 311 | ## Devide them in "P" groups of "K" images 312 | rgb_features_list, ir_features_list = list(torch.split(source,[self.K]*self.P,dim=0)), list(torch.split(target,[self.K]*self.P,dim=0)) 313 | total_loss = torch.tensor([0.], requires_grad=True).to(torch.device('cuda')) 314 | if labels1 is not None and labels2 is not None: 315 | rgb_labels, ir_labels = torch.split(labels1, [self.K]*self.P, dim=0), torch.split(labels2, [self.K]*self.P, dim=0) 316 | print(f'RGB Labels : {rgb_labels}') 317 | print(f'IR Labels : {ir_labels}') 318 | 319 | xx_batch, yy_batch, xy_batch, yx_batch = 0,0,0,0 320 | 321 | for rgb_feat, ir_feat in zip(rgb_features_list, ir_features_list): 322 | source, target = rgb_feat, ir_feat ## 4, 2048 ; 4*2048 -> 4*2048 323 | ## (rgb, ir, mid) -> rgb - mid + ir- mid -> 324 | batch_size = int(source.size()[0]) 325 | kernels, l2dist = self.guassian_kernel(source, target, kernel_mul=self.kernel_mul, kernel_num=self.kernel_num, fix_sigma=self.fix_sigma) 326 | XX = kernels[:batch_size, :batch_size] 327 | YY = kernels[batch_size:, batch_size:] 328 | XY = kernels[:batch_size, batch_size:] 329 | YX = kernels[batch_size:, :batch_size] 330 | 331 | xx_batch += torch.mean(XX) 332 | yy_batch += torch.mean(YY) 333 | xy_batch += torch.mean(XY) 334 | yx_batch += torch.mean(YX) 335 | 336 | if self.margin: 337 | loss = torch.mean(XX + YY - XY -YX) 338 | if loss-self.margin > 0: 339 | total_loss += loss 340 | else: 341 | total_loss += torch.clamp(loss - self.margin, min=0) 342 | 343 | else: 344 | total_loss += torch.mean(XX + YY - XY -YX) 345 | 346 | total_loss /= self.P 347 | return total_loss, torch.max(l2dist), [xx_batch / self.P, yy_batch / self.P, xy_batch / self.P, yx_batch / self.P] -------------------------------------------------------------------------------- /model_mine.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import init 4 | import torch.nn.functional as F 5 | from resnet import resnet50, resnet18 6 | 7 | class Normalize(nn.Module): 8 | def __init__(self, power=2): 9 | super(Normalize, self).__init__() 10 | self.power = power 11 | 12 | def forward(self, x): 13 | norm = x.pow(self.power).sum(1, keepdim=True).pow(1. / self.power) 14 | out = x.div(norm) 15 | return out 16 | 17 | class Non_local(nn.Module): 18 | def __init__(self, in_channels, reduc_ratio=2): 19 | super(Non_local, self).__init__() 20 | 21 | self.in_channels = in_channels 22 | self.inter_channels = reduc_ratio//reduc_ratio 23 | 24 | self.g = nn.Sequential( 25 | nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, 26 | padding=0), 27 | ) 28 | 29 | self.W = nn.Sequential( 30 | nn.Conv2d(in_channels=self.inter_channels, out_channels=self.in_channels, 31 | kernel_size=1, stride=1, padding=0), 32 | nn.BatchNorm2d(self.in_channels), 33 | ) 34 | nn.init.constant_(self.W[1].weight, 0.0) 35 | nn.init.constant_(self.W[1].bias, 0.0) 36 | 37 | 38 | 39 | self.theta = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, 40 | kernel_size=1, stride=1, padding=0) 41 | 42 | self.phi = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, 43 | kernel_size=1, stride=1, padding=0) 44 | 45 | def forward(self, x): 46 | ''' 47 | :param x: (b, c, t, h, w) 48 | :return: 49 | ''' 50 | 51 | batch_size = x.size(0) 52 | g_x = self.g(x).view(batch_size, self.inter_channels, -1) 53 | g_x = g_x.permute(0, 2, 1) 54 | 55 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1) 56 | theta_x = theta_x.permute(0, 2, 1) 57 | phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) 58 | f = torch.matmul(theta_x, phi_x) 59 | N = f.size(-1) 60 | # f_div_C = torch.nn.functional.softmax(f, dim=-1) 61 | f_div_C = f / N 62 | 63 | y = torch.matmul(f_div_C, g_x) 64 | y = y.permute(0, 2, 1).contiguous() 65 | y = y.view(batch_size, self.inter_channels, *x.size()[2:]) 66 | W_y = self.W(y) 67 | z = W_y + x 68 | 69 | return z 70 | 71 | 72 | # ##################################################################### 73 | def weights_init_kaiming(m): 74 | classname = m.__class__.__name__ 75 | # print(classname) 76 | if classname.find('Conv') != -1: 77 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 78 | elif classname.find('Linear') != -1: 79 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_out') 80 | init.zeros_(m.bias.data) 81 | elif classname.find('BatchNorm1d') != -1: 82 | init.normal_(m.weight.data, 1.0, 0.01) 83 | init.zeros_(m.bias.data) 84 | 85 | def weights_init_classifier(m): 86 | classname = m.__class__.__name__ 87 | if classname.find('Linear') != -1: 88 | init.normal_(m.weight.data, 0, 0.001) 89 | if m.bias: 90 | init.zeros_(m.bias.data) 91 | 92 | 93 | 94 | class visible_module(nn.Module): 95 | def __init__(self, arch='resnet50', share_net=1): 96 | super(visible_module, self).__init__() 97 | 98 | model_v = resnet50(pretrained=True, 99 | last_conv_stride=1, last_conv_dilation=1) 100 | # avg pooling to global pooling 101 | self.share_net = share_net 102 | 103 | if self.share_net == 0: 104 | pass 105 | else: 106 | self.visible = nn.ModuleList() 107 | self.visible.conv1 = model_v.conv1 108 | self.visible.bn1 = model_v.bn1 109 | self.visible.relu = model_v.relu 110 | self.visible.maxpool = model_v.maxpool 111 | if self.share_net > 1: 112 | for i in range(1, self.share_net): 113 | setattr(self.visible,'layer'+str(i), getattr(model_v,'layer'+str(i))) 114 | 115 | def forward(self, x): 116 | if self.share_net == 0: 117 | return x 118 | else: 119 | x = self.visible.conv1(x) 120 | x = self.visible.bn1(x) 121 | x = self.visible.relu(x) 122 | x = self.visible.maxpool(x) 123 | 124 | if self.share_net > 1: 125 | for i in range(1, self.share_net): 126 | x = getattr(self.visible, 'layer'+str(i))(x) 127 | return x 128 | 129 | 130 | class thermal_module(nn.Module): 131 | def __init__(self, arch='resnet50', share_net=1): 132 | super(thermal_module, self).__init__() 133 | 134 | model_t = resnet50(pretrained=True, 135 | last_conv_stride=1, last_conv_dilation=1) 136 | # avg pooling to global pooling 137 | self.share_net = share_net 138 | 139 | if self.share_net == 0: 140 | pass 141 | else: 142 | self.thermal = nn.ModuleList() 143 | self.thermal.conv1 = model_t.conv1 144 | self.thermal.bn1 = model_t.bn1 145 | self.thermal.relu = model_t.relu 146 | self.thermal.maxpool = model_t.maxpool 147 | if self.share_net > 1: 148 | for i in range(1, self.share_net): 149 | setattr(self.thermal,'layer'+str(i), getattr(model_t,'layer'+str(i))) 150 | 151 | def forward(self, x): 152 | if self.share_net == 0: 153 | return x 154 | else: 155 | x = self.thermal.conv1(x) 156 | x = self.thermal.bn1(x) 157 | x = self.thermal.relu(x) 158 | x = self.thermal.maxpool(x) 159 | 160 | if self.share_net > 1: 161 | for i in range(1, self.share_net): 162 | x = getattr(self.thermal, 'layer'+str(i))(x) 163 | return x 164 | 165 | 166 | class base_resnet(nn.Module): 167 | def __init__(self, arch='resnet50', share_net=1): 168 | super(base_resnet, self).__init__() 169 | 170 | model_base = resnet50(pretrained=True, 171 | last_conv_stride=1, last_conv_dilation=1) 172 | # avg pooling to global pooling 173 | model_base.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 174 | self.share_net = share_net 175 | if self.share_net == 0: 176 | self.base = model_base 177 | else: 178 | self.base = nn.ModuleList() 179 | 180 | if self.share_net > 4: 181 | pass 182 | else: 183 | for i in range(self.share_net, 5): 184 | setattr(self.base,'layer'+str(i), getattr(model_base,'layer'+str(i))) 185 | 186 | def forward(self, x): 187 | if self.share_net == 0: 188 | x = self.base.conv1(x) 189 | x = self.base.bn1(x) 190 | x = self.base.relu(x) 191 | x = self.base.maxpool(x) 192 | 193 | x = self.base.layer1(x) 194 | x = self.base.layer2(x) 195 | x = self.base.layer3(x) 196 | x = self.base.layer4(x) 197 | return x 198 | elif self.share_net > 4: 199 | return x 200 | else: 201 | for i in range(self.share_net, 5): 202 | x = getattr(self.base, 'layer'+str(i))(x) 203 | return x 204 | 205 | 206 | 207 | class embed_net(nn.Module): 208 | def __init__(self, class_num, no_local= 'off', gm_pool = 'on', arch='resnet50', share_net=1, pcb='on',local_feat_dim=256, num_strips=6): 209 | super(embed_net, self).__init__() 210 | 211 | self.thermal_module = thermal_module(arch=arch, share_net=share_net) 212 | self.visible_module = visible_module(arch=arch, share_net=share_net) 213 | self.base_resnet = base_resnet(arch=arch, share_net=share_net) 214 | 215 | self.non_local = no_local 216 | self.pcb = pcb 217 | if self.non_local =='on': 218 | pass 219 | 220 | 221 | pool_dim = 2048 222 | self.l2norm = Normalize(2) 223 | self.gm_pool = gm_pool 224 | 225 | if self.pcb == 'on': 226 | self.num_stripes=num_strips 227 | local_conv_out_channels=local_feat_dim 228 | 229 | self.local_conv_list = nn.ModuleList() 230 | for _ in range(self.num_stripes): 231 | conv = nn.Conv2d(pool_dim, local_conv_out_channels, 1) 232 | conv.apply(weights_init_kaiming) 233 | self.local_conv_list.append(nn.Sequential( 234 | conv, 235 | nn.BatchNorm2d(local_conv_out_channels), 236 | nn.ReLU(inplace=True) 237 | )) 238 | 239 | self.fc_list = nn.ModuleList() 240 | for _ in range(self.num_stripes): 241 | fc = nn.Linear(local_conv_out_channels, class_num) 242 | init.normal_(fc.weight, std=0.001) 243 | init.constant_(fc.bias, 0) 244 | self.fc_list.append(fc) 245 | 246 | 247 | else: 248 | self.bottleneck = nn.BatchNorm1d(pool_dim) 249 | self.bottleneck.bias.requires_grad_(False) # no shift 250 | 251 | self.classifier = nn.Linear(pool_dim, class_num, bias=False) 252 | 253 | self.bottleneck.apply(weights_init_kaiming) 254 | self.classifier.apply(weights_init_classifier) 255 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 256 | 257 | 258 | 259 | 260 | def forward(self, x1, x2, modal=0): 261 | if modal == 0: 262 | x1 = self.visible_module(x1) 263 | x2 = self.thermal_module(x2) 264 | x = torch.cat((x1, x2), 0) 265 | elif modal == 1: 266 | x = self.visible_module(x1) 267 | elif modal == 2: 268 | x = self.thermal_module(x2) 269 | 270 | # shared block 271 | if self.non_local == 'on': 272 | pass 273 | else: 274 | x = self.base_resnet(x) 275 | 276 | if self.pcb == 'on': 277 | feat = x 278 | assert feat.size(2) % self.num_stripes == 0 279 | stripe_h = int(feat.size(2) / self.num_stripes) 280 | local_feat_list = [] 281 | logits_list = [] 282 | for i in range(self.num_stripes): 283 | # shape [N, C, 1, 1] 284 | 285 | # average pool 286 | #local_feat = F.avg_pool2d(feat[:, :, i * stripe_h: (i + 1) * stripe_h, :],(stripe_h, feat.size(-1))) 287 | if self.gm_pool == 'on': 288 | # gm pool 289 | local_feat = feat[:, :, i * stripe_h: (i + 1) * stripe_h, :] 290 | b, c, h, w = local_feat.shape 291 | local_feat = local_feat.view(b,c,-1) 292 | p = 10.0 # regDB: 10.0 SYSU: 3.0 293 | local_feat = (torch.mean(local_feat**p, dim=-1) + 1e-12)**(1/p) 294 | else: 295 | # average pool 296 | #local_feat = F.avg_pool2d(feat[:, :, i * stripe_h: (i + 1) * stripe_h, :],(stripe_h, feat.size(-1))) 297 | local_feat = F.max_pool2d(feat[:, :, i * stripe_h: (i + 1) * stripe_h, :],(stripe_h, feat.size(-1))) 298 | 299 | 300 | # shape [N, c, 1, 1] 301 | local_feat = self.local_conv_list[i](local_feat.view(feat.size(0),feat.size(1),1,1)) 302 | 303 | 304 | # shape [N, c] 305 | local_feat = local_feat.view(local_feat.size(0), -1) 306 | local_feat_list.append(local_feat) 307 | 308 | 309 | if hasattr(self, 'fc_list'): 310 | logits_list.append(self.fc_list[i](local_feat)) 311 | 312 | feat_all = [lf for lf in local_feat_list] 313 | feat_all = torch.cat(feat_all, dim=1) 314 | 315 | 316 | if self.training: 317 | return local_feat_list, logits_list, feat_all 318 | else: 319 | return self.l2norm(feat_all) 320 | else: 321 | if self.gm_pool == 'on': 322 | b, c, h, w = x.shape 323 | x = x.view(b, c, -1) 324 | p = 3.0 325 | x_pool = (torch.mean(x**p, dim=-1) + 1e-12)**(1/p) 326 | else: 327 | x_pool = self.avgpool(x) 328 | x_pool = x_pool.view(x_pool.size(0), x_pool.size(1)) 329 | 330 | feat = self.bottleneck(x_pool) 331 | 332 | if self.training: 333 | return x_pool, self.classifier(feat)#, scores 334 | else: 335 | return self.l2norm(x_pool), self.l2norm(feat) -------------------------------------------------------------------------------- /pre_process_sysu.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | import pdb 4 | import os 5 | 6 | data_path = '/media/hijune/datadisk/reid-data/SYSU RGB-IR Re-ID/SYSU-MM01' 7 | 8 | rgb_cameras = ['cam1','cam2','cam4','cam5'] 9 | ir_cameras = ['cam3','cam6'] 10 | 11 | # load id info 12 | file_path_train = os.path.join(data_path,'exp/train_id.txt') 13 | file_path_val = os.path.join(data_path,'exp/val_id.txt') 14 | with open(file_path_train, 'r') as file: 15 | ids = file.read().splitlines() 16 | ids = [int(y) for y in ids[0].split(',')] 17 | id_train = ["%04d" % x for x in ids] 18 | 19 | with open(file_path_val, 'r') as file: 20 | ids = file.read().splitlines() 21 | ids = [int(y) for y in ids[0].split(',')] 22 | id_val = ["%04d" % x for x in ids] 23 | 24 | # combine train and val split 25 | id_train.extend(id_val) 26 | 27 | files_rgb = [] 28 | files_ir = [] 29 | for id in sorted(id_train): 30 | for cam in rgb_cameras: 31 | img_dir = os.path.join(data_path,cam,id) 32 | if os.path.isdir(img_dir): 33 | new_files = sorted([img_dir+'/'+i for i in os.listdir(img_dir)]) 34 | files_rgb.extend(new_files) 35 | 36 | for cam in ir_cameras: 37 | img_dir = os.path.join(data_path,cam,id) 38 | if os.path.isdir(img_dir): 39 | new_files = sorted([img_dir+'/'+i for i in os.listdir(img_dir)]) 40 | files_ir.extend(new_files) 41 | 42 | # relabel 43 | pid_container = set() 44 | for img_path in files_ir: 45 | pid = int(img_path[-13:-9]) 46 | pid_container.add(pid) 47 | pid2label = {pid:label for label, pid in enumerate(pid_container)} 48 | fix_image_width = 144 49 | fix_image_height = 288 50 | def read_imgs(train_image): 51 | train_img = [] 52 | train_label = [] 53 | for img_path in train_image: 54 | # img 55 | img = Image.open(img_path) 56 | img = img.resize((fix_image_width, fix_image_height), Image.ANTIALIAS) 57 | pix_array = np.array(img) 58 | 59 | train_img.append(pix_array) 60 | 61 | # label 62 | pid = int(img_path[-13:-9]) 63 | pid = pid2label[pid] 64 | train_label.append(pid) 65 | return np.array(train_img), np.array(train_label) 66 | 67 | # rgb imges 68 | train_img, train_label = read_imgs(files_rgb) 69 | np.save(data_path + 'train_rgb_resized_img.npy', train_img) 70 | np.save(data_path + 'train_rgb_resized_label.npy', train_label) 71 | 72 | # ir imges 73 | train_img, train_label = read_imgs(files_ir) 74 | np.save(data_path + 'train_ir_resized_img.npy', train_img) 75 | np.save(data_path + 'train_ir_resized_label.npy', train_label) 76 | -------------------------------------------------------------------------------- /random_aug.py: -------------------------------------------------------------------------------- 1 | 2 | import math 3 | import random 4 | 5 | 6 | class RandomErasing(object): 7 | """ Randomly selects a rectangle region in an image and erases its pixels. 8 | 'Random Erasing Data Augmentation' by Zhong et al. 9 | See https://arxiv.org/pdf/1708.04896.pdf 10 | Args: 11 | probability: The probability that the Random Erasing operation will be performed. 12 | sl: Minimum proportion of erased area against input image. 13 | sh: Maximum proportion of erased area against input image. 14 | r1: Minimum aspect ratio of erased area. 15 | mean: Erasing value. 16 | """ 17 | 18 | def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=(0.4914, 0.4822, 0.4465)): 19 | self.probability = probability 20 | self.mean = mean 21 | self.sl = sl 22 | self.sh = sh 23 | self.r1 = r1 24 | 25 | def __call__(self, img): 26 | 27 | if random.uniform(0, 1) >= self.probability: 28 | return img 29 | 30 | for attempt in range(100): 31 | area = img.size()[1] * img.size()[2] 32 | 33 | target_area = random.uniform(self.sl, self.sh) * area 34 | aspect_ratio = random.uniform(self.r1, 1 / self.r1) 35 | 36 | h = int(round(math.sqrt(target_area * aspect_ratio))) 37 | w = int(round(math.sqrt(target_area / aspect_ratio))) 38 | 39 | if w < img.size()[2] and h < img.size()[1]: 40 | x1 = random.randint(0, img.size()[1] - h) 41 | y1 = random.randint(0, img.size()[2] - w) 42 | if img.size()[0] == 3: 43 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0] 44 | img[1, x1:x1 + h, y1:y1 + w] = self.mean[1] 45 | img[2, x1:x1 + h, y1:y1 + w] = self.mean[2] 46 | else: 47 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0] 48 | return img 49 | 50 | return img -------------------------------------------------------------------------------- /re_rank.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from scipy.spatial.distance import cdist 6 | 7 | def k_reciprocal(probFea,galFea,k1=20,k2=6,lambda_value=0.3, MemorySave = False, Minibatch = 2000): 8 | 9 | query_num = probFea.shape[0] 10 | all_num = query_num + galFea.shape[0] 11 | feat = np.append(probFea,galFea,axis = 0) 12 | feat = feat.astype(np.float16) 13 | #print('computing original distance') 14 | if MemorySave: 15 | original_dist = np.zeros(shape = [all_num,all_num],dtype = np.float16) 16 | i = 0 17 | while True: 18 | it = i + Minibatch 19 | if it < np.shape(feat)[0]: 20 | original_dist[i:it,] = np.power(cdist(feat[i:it,],feat),2).astype(np.float16) 21 | else: 22 | original_dist[i:,:] = np.power(cdist(feat[i:,],feat),2).astype(np.float16) 23 | break 24 | i = it 25 | else: 26 | original_dist = cdist(feat,feat).astype(np.float16) 27 | original_dist = np.power(original_dist,2).astype(np.float16) 28 | del feat 29 | gallery_num = original_dist.shape[0] 30 | original_dist = np.transpose(original_dist/np.max(original_dist,axis = 0)) 31 | V = np.zeros_like(original_dist).astype(np.float16) 32 | initial_rank = np.argsort(original_dist).astype(np.int32) 33 | 34 | 35 | #print('starting re_ranking') 36 | for i in range(all_num): 37 | # k-reciprocal neighbors 38 | forward_k_neigh_index = initial_rank[i,:k1+1] 39 | backward_k_neigh_index = initial_rank[forward_k_neigh_index,:k1+1] 40 | fi = np.where(backward_k_neigh_index==i)[0] 41 | k_reciprocal_index = forward_k_neigh_index[fi] 42 | k_reciprocal_expansion_index = k_reciprocal_index 43 | for j in range(len(k_reciprocal_index)): 44 | candidate = k_reciprocal_index[j] 45 | candidate_forward_k_neigh_index = initial_rank[candidate,:int(np.around(k1/2))+1] 46 | candidate_backward_k_neigh_index = initial_rank[candidate_forward_k_neigh_index,:int(np.around(k1/2))+1] 47 | fi_candidate = np.where(candidate_backward_k_neigh_index == candidate)[0] 48 | candidate_k_reciprocal_index = candidate_forward_k_neigh_index[fi_candidate] 49 | if len(np.intersect1d(candidate_k_reciprocal_index,k_reciprocal_index))> 2/3*len(candidate_k_reciprocal_index): 50 | k_reciprocal_expansion_index = np.append(k_reciprocal_expansion_index,candidate_k_reciprocal_index) 51 | 52 | k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index) 53 | weight = np.exp(-original_dist[i,k_reciprocal_expansion_index]) 54 | V[i,k_reciprocal_expansion_index] = weight/np.sum(weight) 55 | original_dist = original_dist[:query_num,] 56 | if k2 != 1: 57 | V_qe = np.zeros_like(V,dtype=np.float16) 58 | for i in range(all_num): 59 | V_qe[i,:] = np.mean(V[initial_rank[i,:k2],:],axis=0) 60 | V = V_qe 61 | del V_qe 62 | del initial_rank 63 | invIndex = [] 64 | for i in range(gallery_num): 65 | invIndex.append(np.where(V[:,i] != 0)[0]) 66 | 67 | jaccard_dist = np.zeros_like(original_dist,dtype = np.float16) 68 | 69 | 70 | for i in range(query_num): 71 | temp_min = np.zeros(shape=[1,gallery_num],dtype=np.float16) 72 | indNonZero = np.where(V[i,:] != 0)[0] 73 | indImages = [] 74 | indImages = [invIndex[ind] for ind in indNonZero] 75 | for j in range(len(indNonZero)): 76 | temp_min[0,indImages[j]] = temp_min[0,indImages[j]]+ np.minimum(V[i,indNonZero[j]],V[indImages[j],indNonZero[j]]) 77 | jaccard_dist[i] = 1-temp_min/(2-temp_min) 78 | 79 | final_dist = jaccard_dist*(1-lambda_value) + original_dist*lambda_value 80 | del original_dist 81 | del V 82 | del jaccard_dist 83 | final_dist = final_dist[:query_num,query_num:] 84 | return final_dist 85 | 86 | 87 | 88 | def random_walk(query_feat, gall_feat, alpha = 0.95): 89 | pg_sim = torch.from_numpy(np.matmul(query_feat, np.transpose(gall_feat))) 90 | gg_sim = torch.from_numpy(np.matmul(gall_feat, np.transpose(gall_feat))) 91 | 92 | one_diag = torch.eye(gg_sim.size(0), dtype=torch.double) 93 | # row normalization 94 | zeros_diag = gg_sim - gg_sim.diag().diag() 95 | A = F.softmax(zeros_diag, dim=1) 96 | 97 | A = (1-alpha) * torch.inverse(one_diag - alpha * A) 98 | pg_sim = torch.matmul(pg_sim, A.t()) 99 | 100 | return -pg_sim.numpy() -------------------------------------------------------------------------------- /resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | import torch.utils.model_zoo as model_zoo 4 | 5 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 6 | 'resnet152'] 7 | 8 | model_urls = { 9 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 10 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 11 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 12 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 13 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 14 | } 15 | 16 | 17 | def conv3x3(in_planes, out_planes, stride=1, dilation=1): 18 | """3x3 convolution with padding""" 19 | # original padding is 1; original dilation is 1 20 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 21 | padding=dilation, bias=False, dilation=dilation) 22 | 23 | 24 | class BasicBlock(nn.Module): 25 | expansion = 1 26 | 27 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1): 28 | super(BasicBlock, self).__init__() 29 | self.conv1 = conv3x3(inplanes, planes, stride, dilation) 30 | self.bn1 = nn.BatchNorm2d(planes) 31 | self.relu = nn.ReLU(inplace=True) 32 | self.conv2 = conv3x3(planes, planes) 33 | self.bn2 = nn.BatchNorm2d(planes) 34 | self.downsample = downsample 35 | self.stride = stride 36 | 37 | def forward(self, x): 38 | residual = x 39 | 40 | out = self.conv1(x) 41 | out = self.bn1(out) 42 | out = self.relu(out) 43 | 44 | out = self.conv2(out) 45 | out = self.bn2(out) 46 | 47 | if self.downsample is not None: 48 | residual = self.downsample(x) 49 | 50 | out += residual 51 | out = self.relu(out) 52 | 53 | return out 54 | 55 | 56 | class Bottleneck(nn.Module): 57 | expansion = 4 58 | 59 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1): 60 | super(Bottleneck, self).__init__() 61 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 62 | self.bn1 = nn.BatchNorm2d(planes) 63 | # original padding is 1; original dilation is 1 64 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=dilation, bias=False, dilation=dilation) 65 | self.bn2 = nn.BatchNorm2d(planes) 66 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 67 | self.bn3 = nn.BatchNorm2d(planes * 4) 68 | self.relu = nn.ReLU(inplace=True) 69 | self.downsample = downsample 70 | self.stride = stride 71 | 72 | def forward(self, x): 73 | residual = x 74 | 75 | out = self.conv1(x) 76 | out = self.bn1(out) 77 | out = self.relu(out) 78 | 79 | out = self.conv2(out) 80 | out = self.bn2(out) 81 | out = self.relu(out) 82 | 83 | out = self.conv3(out) 84 | out = self.bn3(out) 85 | 86 | if self.downsample is not None: 87 | residual = self.downsample(x) 88 | 89 | out += residual 90 | out = self.relu(out) 91 | 92 | return out 93 | 94 | 95 | class ResNet(nn.Module): 96 | 97 | def __init__(self, block, layers, last_conv_stride=2, last_conv_dilation=1): 98 | 99 | self.inplanes = 64 100 | super(ResNet, self).__init__() 101 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 102 | bias=False) 103 | self.bn1 = nn.BatchNorm2d(64) 104 | self.relu = nn.ReLU(inplace=True) 105 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 106 | self.layer1 = self._make_layer(block, 64, layers[0]) 107 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 108 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 109 | self.layer4 = self._make_layer(block, 512, layers[3], stride=last_conv_stride, dilation=last_conv_dilation) 110 | 111 | for m in self.modules(): 112 | if isinstance(m, nn.Conv2d): 113 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 114 | m.weight.data.normal_(0, math.sqrt(2. / n)) 115 | elif isinstance(m, nn.BatchNorm2d): 116 | m.weight.data.fill_(1) 117 | m.bias.data.zero_() 118 | 119 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1): 120 | downsample = None 121 | if stride != 1 or self.inplanes != planes * block.expansion: 122 | downsample = nn.Sequential( 123 | nn.Conv2d(self.inplanes, planes * block.expansion, 124 | kernel_size=1, stride=stride, bias=False), 125 | nn.BatchNorm2d(planes * block.expansion), 126 | ) 127 | 128 | layers = [] 129 | layers.append(block(self.inplanes, planes, stride, downsample, dilation)) 130 | self.inplanes = planes * block.expansion 131 | for i in range(1, blocks): 132 | layers.append(block(self.inplanes, planes)) 133 | 134 | return nn.Sequential(*layers) 135 | 136 | def forward(self, x): 137 | x = self.conv1(x) 138 | x = self.bn1(x) 139 | x = self.relu(x) 140 | x = self.maxpool(x) 141 | 142 | x = self.layer1(x) 143 | x = self.layer2(x) 144 | x = self.layer3(x) 145 | x = self.layer4(x) 146 | 147 | return x 148 | 149 | 150 | def remove_fc(state_dict): 151 | """Remove the fc layer parameters from state_dict.""" 152 | # for key, value in state_dict.items(): 153 | for key, value in list(state_dict.items()): 154 | if key.startswith('fc.'): 155 | del state_dict[key] 156 | return state_dict 157 | 158 | 159 | def resnet18(pretrained=False, **kwargs): 160 | """Constructs a ResNet-18 model. 161 | Args: 162 | pretrained (bool): If True, returns a model pre-trained on ImageNet 163 | """ 164 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 165 | if pretrained: 166 | model.load_state_dict(remove_fc(model_zoo.load_url(model_urls['resnet18']))) 167 | return model 168 | 169 | 170 | def resnet34(pretrained=False, **kwargs): 171 | """Constructs a ResNet-34 model. 172 | Args: 173 | pretrained (bool): If True, returns a model pre-trained on ImageNet 174 | """ 175 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 176 | if pretrained: 177 | model.load_state_dict(remove_fc(model_zoo.load_url(model_urls['resnet34']))) 178 | return model 179 | 180 | 181 | def resnet50(pretrained=False, **kwargs): 182 | """Constructs a ResNet-50 model. 183 | Args: 184 | pretrained (bool): If True, returns a model pre-trained on ImageNet 185 | """ 186 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 187 | if pretrained: 188 | # model.load_state_dict(remove_fc(model_zoo.load_url(model_urls['resnet50']))) 189 | model.load_state_dict(remove_fc(model_zoo.load_url(model_urls['resnet50']))) 190 | return model 191 | 192 | 193 | def resnet101(pretrained=False, **kwargs): 194 | """Constructs a ResNet-101 model. 195 | Args: 196 | pretrained (bool): If True, returns a model pre-trained on ImageNet 197 | """ 198 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 199 | if pretrained: 200 | model.load_state_dict( 201 | remove_fc(model_zoo.load_url(model_urls['resnet101']))) 202 | return model 203 | 204 | 205 | def resnet152(pretrained=False, **kwargs): 206 | """Constructs a ResNet-152 model. 207 | Args: 208 | pretrained (bool): If True, returns a model pre-trained on ImageNet 209 | """ 210 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 211 | if pretrained: 212 | model.load_state_dict( 213 | remove_fc(model_zoo.load_url(model_urls['resnet152']))) 214 | return model -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | # python train_mine.py --dataset sysu --gpu 0 --pcb off --share_net 3 --batch-size 4 --num_pos 4 --run_name 'baseline' 2 | # python train_mine.py --dataset sysu --gpu 1 --pcb off --share_net 3 --batch-size 4 --num_pos 4 --dist_disc 'mmd' --run_name 'mmd' 3 | python train_mine.py --dataset sysu --gpu 1 --pcb off --share_net 3 --batch-size 4 --num_pos 4 --dist_disc 'margin_mmd' --margin_mmd 1.40 --run_name 'margin_mmd1.40' -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | import time 4 | import torch.nn as nn 5 | import torch.backends.cudnn as cudnn 6 | from torch.autograd import Variable 7 | import torch.utils.data as data 8 | import torchvision.transforms as transforms 9 | from data_loader import SYSUData, RegDBData, TestData 10 | from data_manager import * 11 | from eval_metrics import eval_sysu, eval_regdb 12 | from model_mine import embed_net 13 | from utils import * 14 | import pdb 15 | from re_rank import random_walk, k_reciprocal 16 | import os 17 | import numpy as np 18 | 19 | parser = argparse.ArgumentParser(description='PyTorch Cross-Modality Training') 20 | parser.add_argument('--dataset', default='sysu', help='dataset name: regdb or sysu]') 21 | parser.add_argument('--lr', default=0.1 , type=float, help='learning rate, 0.00035 for adam') 22 | parser.add_argument('--optim', default='sgd', type=str, help='optimizer') 23 | parser.add_argument('--arch', default='resnet50', type=str, 24 | help='network baseline:resnet18 or resnet50') 25 | parser.add_argument('--resume', '-r', default='', type=str, 26 | help='resume from checkpoint') 27 | parser.add_argument('--test-only', action='store_true', help='test only') 28 | parser.add_argument('--model_path', default='save_model/', type=str, 29 | help='model save path') 30 | parser.add_argument('--save_epoch', default=100, type=int, 31 | metavar='s', help='save model every 10 epochs') 32 | parser.add_argument('--log_path', default='log/', type=str, 33 | help='log save path') 34 | parser.add_argument('--vis_log_path', default='log/vis_log/', type=str, 35 | help='log save path') 36 | parser.add_argument('--workers', default=4, type=int, metavar='N', 37 | help='number of data loading workers (default: 4)') 38 | parser.add_argument('--img_w', default=144, type=int, 39 | metavar='imgw', help='img width') 40 | parser.add_argument('--img_h', default=288, type=int, 41 | metavar='imgh', help='img height') 42 | parser.add_argument('--batch-size', default=8, type=int, 43 | metavar='B', help='training batch size') 44 | parser.add_argument('--test-batch', default=64, type=int, 45 | metavar='tb', help='testing batch size') 46 | parser.add_argument('--method', default='base', type=str, 47 | metavar='m', help='method type: base or agw') 48 | parser.add_argument('--margin', default=0.3, type=float, 49 | metavar='margin', help='triplet loss margin') 50 | parser.add_argument('--num_pos', default=4, type=int, 51 | help='num of pos per identity in each modality') 52 | parser.add_argument('--trial', default=1, type=int, 53 | metavar='t', help='trial (only for RegDB dataset)') 54 | parser.add_argument('--seed', default=0, type=int, 55 | metavar='t', help='random seed') 56 | parser.add_argument('--gpu', default='0', type=str, 57 | help='gpu device ids for CUDA_VISIBLE_DEVICES') 58 | parser.add_argument('--mode', default='all', type=str, help='all or indoor') 59 | 60 | parser.add_argument('--share_net', default=2, type=int, 61 | metavar='share', help='[1,2,3,4,5]the start number of shared network in the two-stream networks') 62 | parser.add_argument('--re_rank', default='no', type=str, help='performing reranking. [random_walk | k_reciprocal | no]') 63 | parser.add_argument('--pcb', default='on', type=str, help='performing PCB, on or off') 64 | parser.add_argument('--w_center', default=2.0, type=float, help='the weight for center loss') 65 | 66 | parser.add_argument('--local_feat_dim', default=256, type=int, 67 | help='feature dimention of each local feature in PCB') 68 | parser.add_argument('--num_strips', default=6, type=int, 69 | help='num of local strips in PCB') 70 | 71 | parser.add_argument('--label_smooth', default='on', type=str, help='performing label smooth or not') 72 | parser.add_argument('--run_name', type=str, 73 | help='Run Name for following experiment', default='test_run') 74 | 75 | parser.add_argument('--m', default=0.4, type=float, help='Value of Additive Margin') 76 | parser.add_argument('--num_trials', type=int, default=10 ,help='Number of Trials for Averaging') 77 | parser.add_argument('--tvsearch', action='store_true', help='Thermal to Visible search for RegDB') 78 | 79 | 80 | 81 | args = parser.parse_args() 82 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 83 | 84 | dataset = args.dataset 85 | if dataset == 'sysu': 86 | data_path = 'SYSU-MM01' 87 | n_class = 395 88 | test_mode = [1, 2] 89 | elif dataset =='regdb': 90 | data_path = 'RegDB/' 91 | n_class = 206 92 | test_mode = [2, 1] 93 | 94 | print(args.num_trials) 95 | num_trials = args.num_trials 96 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 97 | best_acc = 0 # best test accuracy 98 | start_epoch = 0 99 | if args.pcb == 'on': 100 | pool_dim = args.num_strips * args.local_feat_dim 101 | else: 102 | pool_dim = 2048 103 | print('==> Building model..') 104 | if args.method =='base': 105 | net = embed_net(n_class, no_local= 'off', gm_pool = 'on', arch=args.arch, share_net=args.share_net, pcb=args.pcb, local_feat_dim=args.local_feat_dim, num_strips=args.num_strips) 106 | else: 107 | net = embed_net(n_class, no_local= 'on', gm_pool = 'on', arch=args.arch, share_net=args.share_net, pcb=args.pcb) 108 | net.to(device) 109 | cudnn.benchmark = True 110 | 111 | 112 | checkpoint_path = args.model_path 113 | 114 | if args.method =='id': 115 | criterion = nn.CrossEntropyLoss() 116 | criterion.to(device) 117 | 118 | print('==> Loading data..') 119 | # Data loading code 120 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]) 121 | transform_train = transforms.Compose([ 122 | transforms.ToPILImage(), 123 | transforms.RandomCrop((args.img_h,args.img_w)), 124 | transforms.RandomHorizontalFlip(), 125 | transforms.ToTensor(), 126 | normalize, 127 | ]) 128 | 129 | transform_test = transforms.Compose([ 130 | transforms.ToPILImage(), 131 | transforms.Resize((args.img_h,args.img_w)), 132 | transforms.ToTensor(), 133 | normalize, 134 | ]) 135 | 136 | end = time.time() 137 | 138 | 139 | 140 | def extract_gall_feat(gall_loader): 141 | net.eval() 142 | print ('Extracting Gallery Feature...') 143 | start = time.time() 144 | ptr = 0 145 | gall_feat_pool = np.zeros((ngall, pool_dim)) 146 | gall_feat_fc = np.zeros((ngall, pool_dim)) 147 | with torch.no_grad(): 148 | for batch_idx, (input, label ) in enumerate(gall_loader): 149 | batch_num = input.size(0) 150 | input = Variable(input.cuda()) 151 | if args.pcb == 'on': 152 | feat_pool = net(input, input, test_mode[0]) 153 | gall_feat_pool[ptr:ptr+batch_num,: ] = feat_pool.detach().cpu().numpy() 154 | else: 155 | feat_pool, feat_fc = net(input, input, test_mode[0]) 156 | gall_feat_pool[ptr:ptr+batch_num,: ] = feat_pool.detach().cpu().numpy() 157 | gall_feat_fc[ptr:ptr+batch_num,: ] = feat_fc.detach().cpu().numpy() 158 | ptr = ptr + batch_num 159 | print('Extracting Time:\t {:.3f}'.format(time.time()-start)) 160 | if args.pcb == 'on': 161 | return gall_feat_pool 162 | else: 163 | return gall_feat_pool, gall_feat_fc 164 | 165 | def extract_query_feat(query_loader): 166 | net.eval() 167 | print ('Extracting Query Feature...') 168 | start = time.time() 169 | ptr = 0 170 | query_feat_pool = np.zeros((nquery, pool_dim)) 171 | query_feat_fc = np.zeros((nquery, pool_dim)) 172 | with torch.no_grad(): 173 | for batch_idx, (input, label ) in enumerate(query_loader): 174 | batch_num = input.size(0) 175 | input = Variable(input.cuda()) 176 | if args.pcb == 'on': 177 | feat_pool = net(input, input, test_mode[1]) 178 | query_feat_pool[ptr:ptr+batch_num,: ] = feat_pool.detach().cpu().numpy() 179 | else: 180 | feat_pool, feat_fc = net(input, input, test_mode[1]) 181 | query_feat_pool[ptr:ptr+batch_num,: ] = feat_pool.detach().cpu().numpy() 182 | query_feat_fc[ptr:ptr+batch_num,: ] = feat_fc.detach().cpu().numpy() 183 | ptr = ptr + batch_num 184 | print('Extracting Time:\t {:.3f}'.format(time.time()-start)) 185 | if args.pcb == 'on': 186 | return query_feat_pool 187 | else: 188 | return query_feat_pool, query_feat_fc 189 | 190 | suffix = args.run_name + '_' + dataset+'_c_tri_pcb_{}_w_tri_{}'.format(args.pcb,args.w_center) 191 | if args.pcb=='on': 192 | suffix = suffix + '_s{}_f{}'.format(args.num_strips, args.local_feat_dim) 193 | 194 | suffix = suffix + '_share_net{}'.format(args.share_net) 195 | if args.method=='agw': 196 | suffix = suffix + '_agw_k{}_p{}_lr_{}_seed_{}'.format(args.num_pos, args.batch_size, args.lr, args.seed) 197 | else: 198 | suffix = suffix + '_base_gm10_k{}_p{}_lr_{}_seed_{}'.format(args.num_pos, args.batch_size, args.lr, args.seed) 199 | 200 | 201 | if not args.optim == 'sgd': 202 | suffix = suffix + '_' + args.optim 203 | 204 | if dataset == 'regdb': 205 | suffix = suffix + '_trial_{}'.format(args.trial) 206 | 207 | model_path = checkpoint_path + suffix + '_best.t' 208 | 209 | print('model_path =', model_path) 210 | if os.path.isfile(model_path): 211 | print('==> loading checkpoint {}'.format(args.resume)) 212 | checkpoint = torch.load(model_path) 213 | net.load_state_dict(checkpoint['net']) 214 | print('==> loaded checkpoint {} (epoch {})' 215 | .format(args.resume, checkpoint['epoch'])) 216 | else: 217 | print('==> no checkpoint found at {}'.format(args.resume)) 218 | 219 | 220 | if dataset == 'sysu': 221 | 222 | metrics = {'Rank-1':[], 'mAP': [], 'mINP': [], 'Rank-5':[], 'Rank-10':[], 'Rank-20':[]} 223 | 224 | print('==> Resuming from checkpoint..') 225 | 226 | 227 | # testing set 228 | query_img, query_label, query_cam = process_query_sysu(data_path, mode=args.mode) 229 | gall_img, gall_label, gall_cam = process_gallery_sysu(data_path, mode=args.mode, trial=0) 230 | 231 | nquery = len(query_label) 232 | ngall = len(gall_label) 233 | print("Dataset statistics:") 234 | print(" ------------------------------") 235 | print(" subset | # ids | # images") 236 | print(" ------------------------------") 237 | print(" query | {:5d} | {:8d}".format(len(np.unique(query_label)), nquery)) 238 | print(" gallery | {:5d} | {:8d}".format(len(np.unique(gall_label)), ngall)) 239 | print(" ------------------------------") 240 | 241 | queryset = TestData(query_img, query_label, transform=transform_test, img_size=(args.img_w, args.img_h)) 242 | query_loader = data.DataLoader(queryset, batch_size=args.test_batch, shuffle=False, num_workers=4) 243 | print('Data Loading Time:\t {:.3f}'.format(time.time() - end)) 244 | 245 | if args.pcb == 'on': 246 | query_feat_pool = extract_query_feat(query_loader) 247 | else: 248 | query_feat_pool, query_feat_fc = extract_query_feat(query_loader) 249 | for trial in range(0,num_trials): 250 | print('Test Trial: {}'.format(trial)) 251 | gall_img, gall_label, gall_cam = process_gallery_sysu(data_path, mode=args.mode, trial=trial) 252 | 253 | trial_gallset = TestData(gall_img, gall_label, transform=transform_test, img_size=(args.img_w, args.img_h)) 254 | trial_gall_loader = data.DataLoader(trial_gallset, batch_size=args.test_batch, shuffle=False, num_workers=4) 255 | 256 | if args.pcb == 'on': 257 | gall_feat_pool = extract_gall_feat(trial_gall_loader) 258 | else: 259 | gall_feat_pool, gall_feat_fc = extract_gall_feat(trial_gall_loader) 260 | 261 | if args.re_rank == 'random_walk': 262 | distmat_pool = random_walk(query_feat_pool, gall_feat_pool) 263 | if args.pcb == 'off': distmat = random_walk(query_feat_fc, gall_feat_fc) 264 | elif args.re_rank == 'k_reciprocal': 265 | distmat_pool = k_reciprocal(query_feat_pool, gall_feat_pool) 266 | if args.pcb == 'off': distmat = k_reciprocal(query_feat_fc, gall_feat_fc) 267 | elif args.re_rank == 'no': 268 | # compute the similarity 269 | distmat_pool = -np.matmul(query_feat_pool, np.transpose(gall_feat_pool)) 270 | if args.pcb == 'off': 271 | 272 | ## ADDING CHANAGES FOR RE-RANKNG ## 273 | distmat = -np.matmul(query_feat_fc, np.transpose(gall_feat_fc)) 274 | 275 | 276 | # print(distmat) 277 | # exit(0) 278 | ################################### 279 | 280 | # pool5 feature 281 | cmc_pool, mAP_pool, mINP_pool = eval_sysu(distmat_pool, query_label, gall_label, query_cam, gall_cam) 282 | 283 | if args.pcb == 'off': 284 | # fc feature 285 | cmc, mAP, mINP = eval_sysu(distmat, query_label, gall_label, query_cam, gall_cam) 286 | if trial == 0: 287 | if args.pcb == 'off': 288 | all_cmc = cmc 289 | all_mAP = mAP 290 | all_mINP = mINP 291 | all_cmc_pool = cmc_pool 292 | all_mAP_pool = mAP_pool 293 | all_mINP_pool = mINP_pool 294 | else: 295 | if args.pcb == 'off': 296 | all_cmc = all_cmc + cmc 297 | all_mAP = all_mAP + mAP 298 | all_mINP = all_mINP + mINP 299 | all_cmc_pool = all_cmc_pool + cmc_pool 300 | all_mAP_pool = all_mAP_pool + mAP_pool 301 | all_mINP_pool = all_mINP_pool + mINP_pool 302 | 303 | metrics['Rank-1'].append(cmc[0]) 304 | metrics['Rank-5'].append(cmc[4]) 305 | metrics['Rank-10'].append(cmc[9]) 306 | metrics['Rank-20'].append(cmc[19]) 307 | metrics['mAP'].append(mAP) 308 | metrics['mINP'].append(mINP) 309 | 310 | if args.pcb == 'off': 311 | print( 312 | 'FC: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 313 | cmc[0], cmc[4], cmc[9], cmc[19], mAP, mINP)) 314 | print( 315 | 'POOL: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 316 | cmc_pool[0], cmc_pool[4], cmc_pool[9], cmc_pool[19], mAP_pool, mINP_pool)) 317 | 318 | 319 | elif dataset == 'regdb': 320 | 321 | metrics = {'Rank-1':[], 'mAP': [], 'mINP': [], 'Rank-5':[], 'Rank-10':[], 'Rank-20':[]} 322 | 323 | for trial in range(num_trials): 324 | test_trial = trial +1 325 | print('Test Trial: {}'.format(test_trial)) 326 | #model_path = checkpoint_path + 'regdbtest_share_net2_base_gm_p4_n8_lr_0.1_seed_0_trial_{}_best.t'.format(test_trial) 327 | if os.path.isfile(model_path): 328 | print('==> loading checkpoint {}'.format(args.resume)) 329 | checkpoint = torch.load(model_path) 330 | net.load_state_dict(checkpoint['net']) 331 | 332 | # training set 333 | trainset = RegDBData(data_path, test_trial, transform=transform_train) 334 | # generate the idx of each person identity 335 | color_pos, thermal_pos = GenIdx(trainset.train_color_label, trainset.train_thermal_label) 336 | 337 | # testing set 338 | query_img, query_label = process_test_regdb(data_path, trial=test_trial, modal='visible') 339 | gall_img, gall_label = process_test_regdb(data_path, trial=test_trial, modal='thermal') 340 | 341 | gallset = TestData(gall_img, gall_label, transform=transform_test, img_size=(args.img_w, args.img_h)) 342 | gall_loader = data.DataLoader(gallset, batch_size=args.test_batch, shuffle=False, num_workers=args.workers) 343 | 344 | nquery = len(query_label) 345 | ngall = len(gall_label) 346 | 347 | queryset = TestData(query_img, query_label, transform=transform_test, img_size=(args.img_w, args.img_h)) 348 | query_loader = data.DataLoader(queryset, batch_size=args.test_batch, shuffle=False, num_workers=4) 349 | print('Data Loading Time:\t {:.3f}'.format(time.time() - end)) 350 | 351 | if args.pcb == 'on': 352 | query_feat_pool = extract_query_feat(query_loader) 353 | gall_feat_pool = extract_gall_feat(gall_loader) 354 | else: 355 | query_feat_pool, query_feat_fc = extract_query_feat(query_loader) 356 | gall_feat_pool, gall_feat_fc = extract_gall_feat(gall_loader) 357 | 358 | if args.tvsearch: 359 | if args.re_rank == 'random_walk': 360 | distmat_pool = random_walk(gall_feat_pool, query_feat_pool) 361 | if args.pcb == 'off': distmat = random_walk(gall_feat_fc, query_feat_fc) 362 | elif args.re_rank == 'k_reciprocal': 363 | distmat_pool = k_reciprocal(gall_feat_pool, query_feat_pool) 364 | if args.pcb == 'off': distmat = k_reciprocal(gall_feat_fc, query_feat_fc) 365 | elif args.re_rank == 'no': 366 | # compute the similarity 367 | distmat_pool = -np.matmul(gall_feat_pool, np.transpose(query_feat_pool)) 368 | if args.pcb == 'off': distmat = -np.matmul(gall_feat_fc, np.transpose(query_feat_fc)) 369 | # pool5 feature 370 | cmc_pool, mAP_pool, mINP_pool = eval_regdb(distmat_pool, gall_label, query_label) 371 | if args.pcb == 'off': 372 | # fc feature 373 | cmc, mAP, mINP = eval_regdb(distmat,gall_label, query_label ) 374 | else: 375 | if args.re_rank == 'random_walk': 376 | distmat_pool = random_walk(query_feat_pool, gall_feat_pool) 377 | if args.pcb == 'off': distmat = random_walk(query_feat_fc, gall_feat_fc) 378 | elif args.re_rank == 'k_reciprocal': 379 | distmat_pool = k_reciprocal(query_feat_pool, gall_feat_pool) 380 | if args.pcb == 'off': distmat = k_reciprocal(query_feat_fc, gall_feat_fc) 381 | elif args.re_rank == 'no': 382 | # compute the similarity 383 | distmat_pool = -np.matmul(query_feat_pool, np.transpose(gall_feat_pool)) 384 | if args.pcb == 'off': distmat = -np.matmul(query_feat_fc, np.transpose(gall_feat_fc)) 385 | # pool5 feature 386 | cmc_pool, mAP_pool, mINP_pool = eval_regdb(distmat_pool, query_label, gall_label) 387 | if args.pcb == 'off': 388 | # fc feature 389 | cmc, mAP, mINP = eval_regdb(distmat, query_label, gall_label) 390 | 391 | 392 | if trial == 0: 393 | if args.pcb == 'off': 394 | all_cmc = cmc 395 | all_mAP = mAP 396 | all_mINP = mINP 397 | all_cmc_pool = cmc_pool 398 | all_mAP_pool = mAP_pool 399 | all_mINP_pool = mINP_pool 400 | else: 401 | if args.pcb == 'off': 402 | all_cmc = all_cmc + cmc 403 | all_mAP = all_mAP + mAP 404 | all_mINP = all_mINP + mINP 405 | all_cmc_pool = all_cmc_pool + cmc_pool 406 | all_mAP_pool = all_mAP_pool + mAP_pool 407 | all_mINP_pool = all_mINP_pool + mINP_pool 408 | 409 | metrics['Rank-1'].append(cmc[0]) 410 | metrics['Rank-5'].append(cmc[4]) 411 | metrics['Rank-10'].append(cmc[9]) 412 | metrics['Rank-20'].append(cmc[19]) 413 | metrics['mAP'].append(mAP) 414 | metrics['mINP'].append(mINP) 415 | 416 | if args.pcb == 'off': 417 | print( 418 | 'FC: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 419 | cmc[0], cmc[4], cmc[9], cmc[19], mAP, mINP)) 420 | print( 421 | 'POOL: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 422 | cmc_pool[0], cmc_pool[4], cmc_pool[9], cmc_pool[19], mAP_pool, mINP_pool)) 423 | if args.pcb == 'off': 424 | cmc = all_cmc / num_trials 425 | mAP = all_mAP / num_trials 426 | mINP = all_mINP / num_trials 427 | 428 | cmc_pool = all_cmc_pool / num_trials 429 | mAP_pool = all_mAP_pool / num_trials 430 | mINP_pool = all_mINP_pool / num_trials 431 | print('All Average:') 432 | 433 | if args.pcb == 'off': 434 | print( 435 | 'FC: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 436 | cmc[0], cmc[4], cmc[9], cmc[19], mAP, mINP)) 437 | print('POOL: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 438 | cmc_pool[0], cmc_pool[4], cmc_pool[9], cmc_pool[19], mAP_pool, mINP_pool)) 439 | 440 | print('=*'*50) 441 | 442 | print('Sanity Check') 443 | print('Average') 444 | print('FC: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 445 | np.mean(metrics['Rank-1']), np.mean(metrics['Rank-5']), np.mean(metrics['Rank-10']), np.mean(metrics['Rank-20']), np.mean(metrics['mAP']), np.mean(metrics['mINP']))) 446 | 447 | print('STD') 448 | print('FC: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 449 | np.std(metrics['Rank-1']), np.std(metrics['Rank-5']), np.std(metrics['Rank-10']), np.std(metrics['Rank-20']), np.std(metrics['mAP']), np.std(metrics['mINP']))) -------------------------------------------------------------------------------- /test.sh: -------------------------------------------------------------------------------- 1 | python test.py --dataset sysu --gpu 0 --pcb off --share_net 3 --batch-size 4 --num_pos 4 --run_name 'baseline' -------------------------------------------------------------------------------- /test_mine_pcb.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | import time 4 | import torch.nn as nn 5 | import torch.backends.cudnn as cudnn 6 | from torch.autograd import Variable 7 | import torch.utils.data as data 8 | import torchvision.transforms as transforms 9 | from data_loader import SYSUData, RegDBData, TestData 10 | from data_manager import * 11 | from eval_metrics import eval_sysu, eval_regdb 12 | from model_mine import embed_net 13 | from utils import * 14 | import pdb 15 | from re_rank import random_walk, k_reciprocal 16 | 17 | parser = argparse.ArgumentParser(description='PyTorch Cross-Modality Training') 18 | parser.add_argument('--dataset', default='sysu', help='dataset name: regdb or sysu]') 19 | parser.add_argument('--lr', default=0.1 , type=float, help='learning rate, 0.00035 for adam') 20 | parser.add_argument('--optim', default='sgd', type=str, help='optimizer') 21 | parser.add_argument('--arch', default='resnet50', type=str, 22 | help='network baseline: resnet50') 23 | parser.add_argument('--resume', '-r', default='', type=str, 24 | help='resume from checkpoint') 25 | parser.add_argument('--test-only', action='store_true', help='test only') 26 | parser.add_argument('--model_path', default='save_model/', type=str, 27 | help='model save path') 28 | parser.add_argument('--save_epoch', default=20, type=int, 29 | metavar='s', help='save model every 10 epochs') 30 | parser.add_argument('--log_path', default='log/', type=str, 31 | help='log save path') 32 | parser.add_argument('--vis_log_path', default='log/vis_log/', type=str, 33 | help='log save path') 34 | parser.add_argument('--workers', default=4, type=int, metavar='N', 35 | help='number of data loading workers (default: 4)') 36 | parser.add_argument('--img_w', default=144, type=int, 37 | metavar='imgw', help='img width') 38 | parser.add_argument('--img_h', default=288, type=int, 39 | metavar='imgh', help='img height') 40 | parser.add_argument('--batch-size', default=8, type=int, 41 | metavar='B', help='training batch size') 42 | parser.add_argument('--test-batch', default=64, type=int, 43 | metavar='tb', help='testing batch size') 44 | parser.add_argument('--method', default='base', type=str, 45 | metavar='m', help='method type: base or awg') 46 | parser.add_argument('--margin', default=0.3, type=float, 47 | metavar='margin', help='triplet loss margin') 48 | parser.add_argument('--num_pos', default=4, type=int, 49 | help='num of pos per identity in each modality') 50 | parser.add_argument('--trial', default=1, type=int, 51 | metavar='t', help='trial (only for RegDB dataset)') 52 | parser.add_argument('--seed', default=0, type=int, 53 | metavar='t', help='random seed') 54 | parser.add_argument('--gpu', default='0', type=str, 55 | help='gpu device ids for CUDA_VISIBLE_DEVICES') 56 | parser.add_argument('--mode', default='all', type=str, help='all or indoor for sysu') 57 | parser.add_argument('--tvsearch', action='store_true', help='whether thermal to visible search on RegDB') 58 | 59 | parser.add_argument('--share_net', default=2, type=int, 60 | metavar='share', help='[1,2,3,4]the start number of shared network in the two-stream networks') 61 | parser.add_argument('--re_rank', default='no', type=str, help='performing reranking. [random_walk | k_reciprocal | no]') 62 | parser.add_argument('--pcb', default='off', type=str, help='performing PCB, on or off') 63 | 64 | parser.add_argument('--w_center', default=2.0, type=float, help='the weight for center loss') 65 | 66 | parser.add_argument('--local_feat_dim', default=256, type=int, 67 | help='feature dimention of each local feature in PCB') 68 | parser.add_argument('--num_strips', default=6, type=int, 69 | help='num of local strips in PCB') 70 | 71 | parser.add_argument('--label_smooth', default='on', type=str, help='performing label smooth or not') 72 | 73 | args = parser.parse_args() 74 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 75 | 76 | dataset = args.dataset 77 | if dataset == 'sysu': 78 | data_path = './SYSU-MM01/' 79 | n_class = 395 80 | test_mode = [1, 2] 81 | elif dataset =='regdb': 82 | data_path = './RegDB/' 83 | n_class = 206 84 | test_mode = [2, 1] 85 | 86 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 87 | best_acc = 0 # best test accuracy 88 | start_epoch = 0 89 | if args.pcb == 'on': 90 | pool_dim = args.num_strips * args.local_feat_dim 91 | else: 92 | pool_dim = 2048 93 | print('==> Building model..') 94 | if args.method =='base': 95 | net = embed_net(n_class, no_local= 'off', gm_pool = 'on', arch=args.arch, share_net=args.share_net, pcb=args.pcb, local_feat_dim=args.local_feat_dim, num_strips=args.num_strips) 96 | else: 97 | net = embed_net(n_class, no_local= 'on', gm_pool = 'on', arch=args.arch, share_net=args.share_net, pcb=args.pcb) 98 | net.to(device) 99 | cudnn.benchmark = True 100 | 101 | checkpoint_path = args.model_path 102 | 103 | if args.method =='id': 104 | criterion = nn.CrossEntropyLoss() 105 | criterion.to(device) 106 | 107 | print('==> Loading data..') 108 | # Data loading code 109 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]) 110 | transform_train = transforms.Compose([ 111 | transforms.ToPILImage(), 112 | transforms.RandomCrop((args.img_h,args.img_w)), 113 | transforms.RandomHorizontalFlip(), 114 | transforms.ToTensor(), 115 | normalize, 116 | ]) 117 | 118 | transform_test = transforms.Compose([ 119 | transforms.ToPILImage(), 120 | transforms.Resize((args.img_h,args.img_w)), 121 | transforms.ToTensor(), 122 | normalize, 123 | ]) 124 | 125 | end = time.time() 126 | 127 | 128 | 129 | def extract_gall_feat(gall_loader): 130 | net.eval() 131 | print ('Extracting Gallery Feature...') 132 | start = time.time() 133 | ptr = 0 134 | gall_feat_pool = np.zeros((ngall, pool_dim)) 135 | gall_feat_fc = np.zeros((ngall, pool_dim)) 136 | with torch.no_grad(): 137 | for batch_idx, (input, label ) in enumerate(gall_loader): 138 | batch_num = input.size(0) 139 | input = Variable(input.cuda()) 140 | if args.pcb == 'on': 141 | feat_pool = net(input, input, test_mode[0]) 142 | gall_feat_pool[ptr:ptr+batch_num,: ] = feat_pool.detach().cpu().numpy() 143 | else: 144 | feat_pool, feat_fc = net(input, input, test_mode[0]) 145 | gall_feat_pool[ptr:ptr+batch_num,: ] = feat_pool.detach().cpu().numpy() 146 | gall_feat_fc[ptr:ptr+batch_num,: ] = feat_fc.detach().cpu().numpy() 147 | ptr = ptr + batch_num 148 | print('Extracting Time:\t {:.3f}'.format(time.time()-start)) 149 | if args.pcb == 'on': 150 | return gall_feat_pool 151 | else: 152 | return gall_feat_pool, gall_feat_fc 153 | 154 | def extract_query_feat(query_loader): 155 | net.eval() 156 | print ('Extracting Query Feature...') 157 | start = time.time() 158 | ptr = 0 159 | query_feat_pool = np.zeros((nquery, pool_dim)) 160 | query_feat_fc = np.zeros((nquery, pool_dim)) 161 | with torch.no_grad(): 162 | for batch_idx, (input, label ) in enumerate(query_loader): 163 | batch_num = input.size(0) 164 | input = Variable(input.cuda()) 165 | if args.pcb == 'on': 166 | feat_pool = net(input, input, test_mode[1]) 167 | query_feat_pool[ptr:ptr+batch_num,: ] = feat_pool.detach().cpu().numpy() 168 | else: 169 | feat_pool, feat_fc = net(input, input, test_mode[1]) 170 | query_feat_pool[ptr:ptr+batch_num,: ] = feat_pool.detach().cpu().numpy() 171 | query_feat_fc[ptr:ptr+batch_num,: ] = feat_fc.detach().cpu().numpy() 172 | ptr = ptr + batch_num 173 | print('Extracting Time:\t {:.3f}'.format(time.time()-start)) 174 | if args.pcb == 'on': 175 | return query_feat_pool 176 | else: 177 | return query_feat_pool, query_feat_fc 178 | 179 | 180 | if dataset == 'sysu': 181 | 182 | print('==> Resuming from checkpoint..') 183 | 184 | # model_path = checkpoint_path + args.resume 185 | model_path = checkpoint_path + 'sysu_c_tri_pcb_off_w_tri_2.0_share_net3_base_gm10_k4_p8_lr_0.1_seed_0_best.t' 186 | if os.path.isfile(model_path): 187 | print('==> loading checkpoint {}'.format(args.resume)) 188 | checkpoint = torch.load(model_path) 189 | net.load_state_dict(checkpoint['net']) 190 | print('==> loaded checkpoint {} (epoch {})' 191 | .format(args.resume, checkpoint['epoch'])) 192 | else: 193 | print('==> no checkpoint found at {}'.format(args.resume)) 194 | 195 | # testing set 196 | query_img, query_label, query_cam = process_query_sysu(data_path, mode=args.mode) 197 | gall_img, gall_label, gall_cam = process_gallery_sysu(data_path, mode=args.mode, trial=0) 198 | 199 | nquery = len(query_label) 200 | ngall = len(gall_label) 201 | print("Dataset statistics:") 202 | print(" ------------------------------") 203 | print(" subset | # ids | # images") 204 | print(" ------------------------------") 205 | print(" query | {:5d} | {:8d}".format(len(np.unique(query_label)), nquery)) 206 | print(" gallery | {:5d} | {:8d}".format(len(np.unique(gall_label)), ngall)) 207 | print(" ------------------------------") 208 | 209 | queryset = TestData(query_img, query_label, transform=transform_test, img_size=(args.img_w, args.img_h)) 210 | query_loader = data.DataLoader(queryset, batch_size=args.test_batch, shuffle=False, num_workers=4) 211 | print('Data Loading Time:\t {:.3f}'.format(time.time() - end)) 212 | 213 | if args.pcb == 'on': 214 | query_feat_pool = extract_query_feat(query_loader) 215 | else: 216 | query_feat_pool, query_feat_fc = extract_query_feat(query_loader) 217 | for trial in range(10): 218 | print('Test Trial: {}'.format(trial)) 219 | gall_img, gall_label, gall_cam = process_gallery_sysu(data_path, mode=args.mode, trial=trial) 220 | 221 | trial_gallset = TestData(gall_img, gall_label, transform=transform_test, img_size=(args.img_w, args.img_h)) 222 | trial_gall_loader = data.DataLoader(trial_gallset, batch_size=args.test_batch, shuffle=False, num_workers=4) 223 | 224 | if args.pcb == 'on': 225 | gall_feat_pool = extract_gall_feat(trial_gall_loader) 226 | else: 227 | gall_feat_pool, gall_feat_fc = extract_gall_feat(trial_gall_loader) 228 | 229 | if args.re_rank == 'random_walk': 230 | distmat_pool = random_walk(query_feat_pool, gall_feat_pool) 231 | if args.pcb == 'off': distmat = random_walk(query_feat_fc, gall_feat_fc) 232 | elif args.re_rank == 'k_reciprocal': 233 | distmat_pool = k_reciprocal(query_feat_pool, gall_feat_pool) 234 | if args.pcb == 'off': distmat = k_reciprocal(query_feat_fc, gall_feat_fc) 235 | elif args.re_rank == 'no': 236 | # compute the similarity 237 | distmat_pool = -np.matmul(query_feat_pool, np.transpose(gall_feat_pool)) 238 | if args.pcb == 'off': distmat = -np.matmul(query_feat_fc, np.transpose(gall_feat_fc)) 239 | # pool5 feature 240 | cmc_pool, mAP_pool, mINP_pool = eval_sysu(distmat_pool, query_label, gall_label, query_cam, gall_cam) 241 | 242 | if args.pcb == 'off': 243 | # fc feature 244 | cmc, mAP, mINP = eval_sysu(distmat, query_label, gall_label, query_cam, gall_cam) 245 | if trial == 0: 246 | if args.pcb == 'off': 247 | all_cmc = cmc 248 | all_mAP = mAP 249 | all_mINP = mINP 250 | all_cmc_pool = cmc_pool 251 | all_mAP_pool = mAP_pool 252 | all_mINP_pool = mINP_pool 253 | else: 254 | if args.pcb == 'off': 255 | all_cmc = all_cmc + cmc 256 | all_mAP = all_mAP + mAP 257 | all_mINP = all_mINP + mINP 258 | all_cmc_pool = all_cmc_pool + cmc_pool 259 | all_mAP_pool = all_mAP_pool + mAP_pool 260 | all_mINP_pool = all_mINP_pool + mINP_pool 261 | 262 | 263 | if args.pcb == 'off': 264 | print( 265 | 'FC: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 266 | cmc[0], cmc[4], cmc[9], cmc[19], mAP, mINP)) 267 | print( 268 | 'POOL: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 269 | cmc_pool[0], cmc_pool[4], cmc_pool[9], cmc_pool[19], mAP_pool, mINP_pool)) 270 | 271 | 272 | elif dataset == 'regdb': 273 | 274 | for trial in range(10): 275 | test_trial = trial +1 276 | print('Test Trial: {}'.format(test_trial)) 277 | #model_path = checkpoint_path + 'regdbtest_share_net2_base_gm_p4_n8_lr_0.1_seed_0_trial_{}_best.t'.format(test_trial) 278 | model_path = checkpoint_path + 'regdb_c_tri_pcb_on_w_tri_2.0_s6_f256_share_net2_base_gm10_k4_p8_lr_0.1_seed_0_trial_{}_best.t'.format(test_trial) 279 | if os.path.isfile(model_path): 280 | print('==> loading checkpoint {}'.format(args.resume)) 281 | checkpoint = torch.load(model_path) 282 | net.load_state_dict(checkpoint['net']) 283 | 284 | # training set 285 | trainset = RegDBData(data_path, test_trial, transform=transform_train) 286 | # generate the idx of each person identity 287 | color_pos, thermal_pos = GenIdx(trainset.train_color_label, trainset.train_thermal_label) 288 | 289 | # testing set 290 | query_img, query_label = process_test_regdb(data_path, trial=test_trial, modal='visible') 291 | gall_img, gall_label = process_test_regdb(data_path, trial=test_trial, modal='thermal') 292 | 293 | gallset = TestData(gall_img, gall_label, transform=transform_test, img_size=(args.img_w, args.img_h)) 294 | gall_loader = data.DataLoader(gallset, batch_size=args.test_batch, shuffle=False, num_workers=args.workers) 295 | 296 | nquery = len(query_label) 297 | ngall = len(gall_label) 298 | 299 | queryset = TestData(query_img, query_label, transform=transform_test, img_size=(args.img_w, args.img_h)) 300 | query_loader = data.DataLoader(queryset, batch_size=args.test_batch, shuffle=False, num_workers=4) 301 | print('Data Loading Time:\t {:.3f}'.format(time.time() - end)) 302 | 303 | if args.pcb == 'on': 304 | query_feat_pool = extract_query_feat(query_loader) 305 | gall_feat_pool = extract_gall_feat(gall_loader) 306 | else: 307 | query_feat_pool, query_feat_fc = extract_query_feat(query_loader) 308 | gall_feat_pool, gall_feat_fc = extract_gall_feat(gall_loader) 309 | 310 | if args.tvsearch: 311 | if args.re_rank == 'random_walk': 312 | distmat_pool = random_walk(gall_feat_pool, query_feat_pool) 313 | if args.pcb == 'off': distmat = random_walk(gall_feat_fc, query_feat_fc) 314 | elif args.re_rank == 'k_reciprocal': 315 | distmat_pool = k_reciprocal(gall_feat_pool, query_feat_pool) 316 | if args.pcb == 'off': distmat = k_reciprocal(gall_feat_fc, query_feat_fc) 317 | elif args.re_rank == 'no': 318 | # compute the similarity 319 | distmat_pool = -np.matmul(gall_feat_pool, np.transpose(query_feat_pool)) 320 | if args.pcb == 'off': distmat = -np.matmul(gall_feat_fc, np.transpose(query_feat_fc)) 321 | # pool5 feature 322 | cmc_pool, mAP_pool, mINP_pool = eval_regdb(distmat_pool, gall_label, query_label) 323 | if args.pcb == 'off': 324 | # fc feature 325 | cmc, mAP, mINP = eval_regdb(distmat,gall_label, query_label ) 326 | else: 327 | if args.re_rank == 'random_walk': 328 | distmat_pool = random_walk(query_feat_pool, gall_feat_pool) 329 | if args.pcb == 'off': distmat = random_walk(query_feat_fc, gall_feat_fc) 330 | elif args.re_rank == 'k_reciprocal': 331 | distmat_pool = k_reciprocal(query_feat_pool, gall_feat_pool) 332 | if args.pcb == 'off': distmat = k_reciprocal(query_feat_fc, gall_feat_fc) 333 | elif args.re_rank == 'no': 334 | # compute the similarity 335 | distmat_pool = -np.matmul(query_feat_pool, np.transpose(gall_feat_pool)) 336 | if args.pcb == 'off': distmat = -np.matmul(query_feat_fc, np.transpose(gall_feat_fc)) 337 | # pool5 feature 338 | cmc_pool, mAP_pool, mINP_pool = eval_regdb(distmat_pool, query_label, gall_label) 339 | if args.pcb == 'off': 340 | # fc feature 341 | cmc, mAP, mINP = eval_regdb(distmat, query_label, gall_label) 342 | 343 | 344 | if trial == 0: 345 | if args.pcb == 'off': 346 | all_cmc = cmc 347 | all_mAP = mAP 348 | all_mINP = mINP 349 | all_cmc_pool = cmc_pool 350 | all_mAP_pool = mAP_pool 351 | all_mINP_pool = mINP_pool 352 | else: 353 | if args.pcb == 'off': 354 | all_cmc = all_cmc + cmc 355 | all_mAP = all_mAP + mAP 356 | all_mINP = all_mINP + mINP 357 | all_cmc_pool = all_cmc_pool + cmc_pool 358 | all_mAP_pool = all_mAP_pool + mAP_pool 359 | all_mINP_pool = all_mINP_pool + mINP_pool 360 | 361 | if args.pcb == 'off': 362 | print( 363 | 'FC: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 364 | cmc[0], cmc[4], cmc[9], cmc[19], mAP, mINP)) 365 | print( 366 | 'POOL: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 367 | cmc_pool[0], cmc_pool[4], cmc_pool[9], cmc_pool[19], mAP_pool, mINP_pool)) 368 | if args.pcb == 'off': 369 | cmc = all_cmc / 10 370 | mAP = all_mAP / 10 371 | 372 | cmc_pool = all_cmc_pool / 10 373 | mAP_pool = all_mAP_pool / 10 374 | print('All Average:') 375 | 376 | if args.pcb == 'off': 377 | print( 378 | 'FC: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 379 | cmc[0], cmc[4], cmc[9], cmc[19], mAP, mINP)) 380 | print('POOL: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 381 | cmc_pool[0], cmc_pool[4], cmc_pool[9], cmc_pool[19], mAP_pool, mINP_pool)) -------------------------------------------------------------------------------- /train_mine.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | import sys 4 | import time 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import torch.optim as optim 9 | import torch.backends.cudnn as cudnn 10 | from torch.autograd import Variable 11 | import torch.utils.data as data 12 | import torchvision 13 | import torchvision.transforms as transforms 14 | from data_loader import SYSUData, RegDBData, TestData 15 | from data_manager import * 16 | from eval_metrics import eval_sysu, eval_regdb 17 | from model_mine import embed_net 18 | from utils import * 19 | from loss import OriTripletLoss, CenterTripletLoss, CrossEntropyLabelSmooth, TripletLoss_WRT, MMD_Loss, MarginMMD_Loss 20 | from tensorboardX import SummaryWriter 21 | from re_rank import random_walk, k_reciprocal 22 | 23 | from random_aug import RandomErasing 24 | 25 | import numpy as np 26 | np.set_printoptions(threshold=np.inf) 27 | 28 | parser = argparse.ArgumentParser(description='PyTorch Cross-Modality Training') 29 | parser.add_argument('--dataset', default='sysu', help='dataset name: regdb or sysu]') 30 | parser.add_argument('--lr', default=0.1 , type=float, help='learning rate, 0.00035 for adam') 31 | parser.add_argument('--optim', default='sgd', type=str, help='optimizer') 32 | parser.add_argument('--arch', default='resnet50', type=str, 33 | help='network baseline:resnet18 or resnet50') 34 | parser.add_argument('--resume', '-r', default='', type=str, 35 | help='resume from checkpoint') 36 | parser.add_argument('--test-only', action='store_true', help='test only') 37 | parser.add_argument('--model_path', default='save_model/', type=str, 38 | help='model save path') 39 | parser.add_argument('--save_epoch', default=100, type=int, 40 | metavar='s', help='save model every 10 epochs') 41 | parser.add_argument('--log_path', default='log/', type=str, 42 | help='log save path') 43 | parser.add_argument('--vis_log_path', default='log/vis_log/', type=str, 44 | help='log save path') 45 | parser.add_argument('--workers', default=4, type=int, metavar='N', 46 | help='number of data loading workers (default: 4)') 47 | parser.add_argument('--img_w', default=144, type=int, 48 | metavar='imgw', help='img width') 49 | parser.add_argument('--img_h', default=288, type=int, 50 | metavar='imgh', help='img height') 51 | parser.add_argument('--batch-size', default=4, type=int, 52 | metavar='B', help='training batch size') 53 | parser.add_argument('--test-batch', default=64, type=int, 54 | metavar='tb', help='testing batch size') 55 | parser.add_argument('--method', default='base', type=str, 56 | metavar='m', help='method type: base or agw') 57 | parser.add_argument('--margin', default=0.3, type=float, 58 | metavar='margin', help='triplet loss margin') 59 | parser.add_argument('--num_pos', default=4, type=int, 60 | help='num of pos per identity in each modality') 61 | parser.add_argument('--trial', default=1, type=int, 62 | metavar='t', help='trial (only for RegDB dataset)') 63 | parser.add_argument('--seed', default=0, type=int, 64 | metavar='t', help='random seed') 65 | parser.add_argument('--gpu', default='0', type=str, 66 | help='gpu device ids for CUDA_VISIBLE_DEVICES') 67 | parser.add_argument('--mode', default='all', type=str, help='all or indoor') 68 | 69 | parser.add_argument('--share_net', default=3, type=int, 70 | metavar='share', help='[1,2,3,4,5]the start number of shared network in the two-stream networks') 71 | parser.add_argument('--re_rank', default='no', type=str, help='performing reranking. [random_walk | k_reciprocal | no]') 72 | parser.add_argument('--pcb', default='off', type=str, help='performing PCB, on or off') 73 | parser.add_argument('--w_center', default=2.0, type=float, help='the weight for center loss') 74 | 75 | parser.add_argument('--local_feat_dim', default=256, type=int, 76 | help='feature dimention of each local feature in PCB') 77 | parser.add_argument('--num_strips', default=6, type=int, 78 | help='num of local strips in PCB') 79 | 80 | parser.add_argument('--aug', action='store_true', help='Use Random Erasing Augmentation') 81 | parser.add_argument('--label_smooth', default='off', type=str, help='performing label smooth or not') 82 | parser.add_argument('--dist_disc', type=str, help='Include Distribution Discripeancy Loss', default=None) 83 | parser.add_argument('--margin_mmd', default=0, type=float, help='Value of Margin For MMD Loss') 84 | parser.add_argument('--dist_w', default=0.25, type=float, help='Weight of Distribution Discrepancy Loss') 85 | 86 | parser.add_argument('--run_name', type=str, 87 | help='Run Name for following experiment', default='test_run') 88 | 89 | args = parser.parse_args() 90 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 91 | 92 | set_seed(args.seed) 93 | 94 | dataset = args.dataset 95 | if dataset == 'sysu': 96 | data_path = './SYSU-MM01' 97 | log_path = args.log_path + 'sysu_log/' 98 | test_mode = [1, 2] # thermal to visible 99 | elif dataset == 'regdb': 100 | data_path = './RegDB/' 101 | log_path = args.log_path + 'regdb_log/' 102 | test_mode = [2, 1] # visible to thermal 103 | 104 | checkpoint_path = args.model_path 105 | 106 | if not os.path.isdir(log_path): 107 | os.makedirs(log_path) 108 | if not os.path.isdir(checkpoint_path): 109 | os.makedirs(checkpoint_path) 110 | if not os.path.isdir(args.vis_log_path): 111 | os.makedirs(args.vis_log_path) 112 | 113 | suffix = args.run_name + '_' + dataset+'_c_tri_pcb_{}_w_tri_{}'.format(args.pcb,args.w_center) 114 | if args.pcb=='on': 115 | suffix = suffix + '_s{}_f{}'.format(args.num_strips, args.local_feat_dim) 116 | 117 | suffix = suffix + '_share_net{}'.format(args.share_net) 118 | if args.method=='agw': 119 | suffix = suffix + '_agw_k{}_p{}_lr_{}_seed_{}'.format(args.num_pos, args.batch_size, args.lr, args.seed) 120 | else: 121 | suffix = suffix + '_base_gm10_k{}_p{}_lr_{}_seed_{}'.format(args.num_pos, args.batch_size, args.lr, args.seed) 122 | 123 | 124 | if not args.optim == 'sgd': 125 | suffix = suffix + '_' + args.optim 126 | 127 | if dataset == 'regdb': 128 | suffix = suffix + '_trial_{}'.format(args.trial) 129 | 130 | sys.stdout = Logger(log_path + suffix + '_os.txt') 131 | 132 | vis_log_dir = args.vis_log_path + suffix + '/' 133 | 134 | if not os.path.isdir(vis_log_dir): 135 | os.makedirs(vis_log_dir) 136 | writer = SummaryWriter(vis_log_dir) 137 | print("==========\nArgs:{}\n==========".format(args)) 138 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 139 | best_acc = 0 # best test accuracy 140 | start_epoch = 0 141 | 142 | print('==> Loading data..') 143 | # Data loading code 144 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 145 | 146 | if args.aug: 147 | transform_train = transforms.Compose([ 148 | transforms.ToPILImage(), 149 | transforms.Pad(10), 150 | transforms.RandomCrop((args.img_h, args.img_w)), 151 | transforms.RandomHorizontalFlip(), 152 | transforms.ToTensor(), 153 | normalize, 154 | RandomErasing(probability=0.5, mean=[0.485, 0.456, 0.406]) 155 | ]) 156 | 157 | else: 158 | transform_train = transforms.Compose([ 159 | transforms.ToPILImage(), 160 | transforms.Pad(10), 161 | transforms.RandomCrop((args.img_h, args.img_w)), 162 | transforms.RandomHorizontalFlip(), 163 | transforms.ToTensor(), 164 | normalize, 165 | ]) 166 | 167 | transform_test = transforms.Compose([ 168 | transforms.ToPILImage(), 169 | transforms.Resize((args.img_h, args.img_w)), 170 | transforms.ToTensor(), 171 | normalize, 172 | ]) 173 | 174 | end = time.time() 175 | if dataset == 'sysu': 176 | # training set 177 | trainset = SYSUData(data_path, transform=transform_train) 178 | # generate the idx of each person identity 179 | color_pos, thermal_pos = GenIdx(trainset.train_color_label, trainset.train_thermal_label) 180 | 181 | # testing set 182 | query_img, query_label, query_cam = process_query_sysu(data_path, mode=args.mode) 183 | gall_img, gall_label, gall_cam = process_gallery_sysu(data_path, mode=args.mode, trial=0) 184 | 185 | elif dataset == 'regdb': 186 | # training set 187 | trainset = RegDBData(data_path, args.trial, transform=transform_train) 188 | # generate the idx of each person identity 189 | color_pos, thermal_pos = GenIdx(trainset.train_color_label, trainset.train_thermal_label) 190 | 191 | # testing set 192 | query_img, query_label = process_test_regdb(data_path, trial=args.trial, modal='visible') 193 | gall_img, gall_label = process_test_regdb(data_path, trial=args.trial, modal='thermal') 194 | 195 | gallset = TestData(gall_img, gall_label, transform=transform_test, img_size=(args.img_w, args.img_h)) 196 | queryset = TestData(query_img, query_label, transform=transform_test, img_size=(args.img_w, args.img_h)) 197 | 198 | # testing data loader 199 | gall_loader = data.DataLoader(gallset, batch_size=args.test_batch, shuffle=False, num_workers=args.workers) 200 | query_loader = data.DataLoader(queryset, batch_size=args.test_batch, shuffle=False, num_workers=args.workers) 201 | 202 | n_class = len(np.unique(trainset.train_color_label)) 203 | nquery = len(query_label) 204 | ngall = len(gall_label) 205 | 206 | print('Dataset {} statistics:'.format(dataset)) 207 | print(' ------------------------------') 208 | print(' subset | # ids | # images') 209 | print(' ------------------------------') 210 | print(' visible | {:5d} | {:8d}'.format(n_class, len(trainset.train_color_label))) 211 | print(' thermal | {:5d} | {:8d}'.format(n_class, len(trainset.train_thermal_label))) 212 | print(' ------------------------------') 213 | print(' query | {:5d} | {:8d}'.format(len(np.unique(query_label)), nquery)) 214 | print(' gallery | {:5d} | {:8d}'.format(len(np.unique(gall_label)), ngall)) 215 | print(' ------------------------------') 216 | print('Data Loading Time:\t {:.3f}'.format(time.time() - end)) 217 | 218 | print('==> Building model..') 219 | if args.method =='base': 220 | net = embed_net(n_class, no_local= 'off', gm_pool = 'on', arch=args.arch, share_net=args.share_net, pcb=args.pcb, local_feat_dim=args.local_feat_dim, num_strips=args.num_strips) 221 | else: 222 | net = embed_net(n_class, no_local= 'on', gm_pool = 'on', arch=args.arch, share_net=args.share_net, pcb=args.pcb) 223 | net.to(device) 224 | 225 | 226 | cudnn.benchmark = True 227 | 228 | if len(args.resume) > 0: 229 | model_path = checkpoint_path + args.resume 230 | if os.path.isfile(model_path): 231 | print('==> loading checkpoint {}'.format(args.resume)) 232 | checkpoint = torch.load(model_path) 233 | start_epoch = checkpoint['epoch'] 234 | net.load_state_dict(checkpoint['net']) 235 | print('==> loaded checkpoint {} (epoch {})' 236 | .format(args.resume, checkpoint['epoch'])) 237 | else: 238 | print('==> no checkpoint found at {}'.format(args.resume)) 239 | 240 | # define loss function 241 | if args.label_smooth == 'off': 242 | criterion_id = nn.CrossEntropyLoss() 243 | else: 244 | criterion_id = CrossEntropyLabelSmooth(n_class) 245 | 246 | if args.method == 'agw': 247 | criterion_tri = TripletLoss_WRT() 248 | else: 249 | loader_batch = args.batch_size * args.num_pos 250 | #criterion_tri= OriTripletLoss(batch_size=loader_batch, margin=args.margin) 251 | criterion_tri= CenterTripletLoss(batch_size=loader_batch, margin=args.margin) 252 | 253 | criterion_id.to(device) 254 | criterion_tri.to(device) 255 | 256 | criterion_mmd = MMD_Loss().to(device) 257 | criterion_margin_mmd = MarginMMD_Loss(margin=args.margin_mmd).to(device) 258 | 259 | 260 | if args.optim == 'sgd': 261 | if args.pcb == 'on': 262 | ignored_params = list(map(id, net.local_conv_list.parameters())) \ 263 | + list(map(id, net.fc_list.parameters())) 264 | 265 | base_params = filter(lambda p: id(p) not in ignored_params, net.parameters()) 266 | 267 | optimizer = optim.SGD([ 268 | {'params': base_params, 'lr': 0.1 * args.lr}, 269 | {'params': net.local_conv_list.parameters(), 'lr': args.lr}, 270 | {'params': net.fc_list.parameters(), 'lr': args.lr} 271 | ], 272 | weight_decay=5e-4, momentum=0.9, nesterov=True) 273 | else: 274 | ignored_params = list(map(id, net.bottleneck.parameters())) \ 275 | + list(map(id, net.classifier.parameters())) 276 | 277 | base_params = filter(lambda p: id(p) not in ignored_params, net.parameters()) 278 | 279 | optimizer = optim.SGD([ 280 | {'params': base_params, 'lr': 0.1 * args.lr}, 281 | {'params': net.bottleneck.parameters(), 'lr': args.lr}, 282 | {'params': net.classifier.parameters(), 'lr': args.lr}], 283 | weight_decay=5e-4, momentum=0.9, nesterov=True) 284 | 285 | def adjust_learning_rate(optimizer, epoch): 286 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" 287 | if epoch < 10: 288 | lr = args.lr * (epoch + 1) / 10 289 | elif epoch >= 10 and epoch < 20: 290 | lr = args.lr 291 | elif epoch >= 20 and epoch < 50: 292 | lr = args.lr * 0.1 293 | elif epoch >= 50: 294 | lr = args.lr * 0.01 295 | 296 | optimizer.param_groups[0]['lr'] = 0.1 * lr 297 | for i in range(len(optimizer.param_groups) - 1): 298 | optimizer.param_groups[i + 1]['lr'] = lr 299 | 300 | return lr 301 | 302 | 303 | def train(epoch): 304 | 305 | current_lr = adjust_learning_rate(optimizer, epoch) 306 | train_loss = AverageMeter() 307 | id_loss = AverageMeter() 308 | tri_loss = AverageMeter() 309 | data_time = AverageMeter() 310 | batch_time = AverageMeter() 311 | correct = 0 312 | total = 0 313 | 314 | # switch to train mode 315 | net.train() 316 | end = time.time() 317 | 318 | for batch_idx, (input1, input2, label1, label2) in enumerate(trainloader): 319 | 320 | labels = torch.cat((label1, label2), 0) 321 | 322 | input1 = Variable(input1.cuda()) 323 | input2 = Variable(input2.cuda()) 324 | 325 | labels = Variable(labels.cuda()) 326 | data_time.update(time.time() - end) 327 | 328 | 329 | if args.pcb == 'on': 330 | feat, out0, feat_all = net(input1, input2) 331 | loss_id = criterion_id(out0[0], labels) 332 | loss_tri_l, batch_acc = criterion_tri(feat[0], labels) 333 | for i in range(len(feat)-1): 334 | loss_id += criterion_id(out0[i+1], labels) 335 | loss_tri_l += criterion_tri(feat[i+1], labels)[0] 336 | loss_tri, batch_acc = criterion_tri(feat_all, labels) 337 | loss_tri += loss_tri_l * args.w_center # 338 | correct += batch_acc 339 | loss = loss_id + loss_tri 340 | else: 341 | feat, out0 = net(input1, input2) 342 | loss_id = criterion_id(out0, labels) 343 | 344 | loss_tri, batch_acc = criterion_tri(feat, labels) 345 | correct += (batch_acc / 2) 346 | _, predicted = out0.max(1) 347 | correct += (predicted.eq(labels).sum().item() / 2) 348 | loss = loss_id + loss_tri * args.w_center # 349 | 350 | if args.dist_disc == 'mmd': 351 | ## Apply Global MMD Loss on Pooling Layer 352 | feat_rgb, feat_ir = torch.split(feat, [label1.size(0),label2.size(0)], dim=0) 353 | loss_dist, l2max, expec = criterion_mmd(feat_rgb, feat_ir) ## Use Global MMD 354 | 355 | elif args.dist_disc == 'margin_mmd': 356 | ## Apply Margin MMD-ID Loss on Pooling Layer 357 | feat_rgb, feat_ir = torch.split(feat, [label1.size(0),label2.size(0)], dim=0) 358 | loss_dist, l2max, expec = criterion_margin_mmd(feat_rgb, feat_ir) ## Use MMD-ID 359 | 360 | 361 | if args.dist_disc is not None: 362 | loss = loss + loss_dist * args.dist_w ## Add Discrepancy Loss 363 | 364 | 365 | optimizer.zero_grad() 366 | loss.backward() 367 | optimizer.step() 368 | 369 | # update P 370 | train_loss.update(loss.item(), 2 * input1.size(0)) 371 | id_loss.update(loss_id.item(), 2 * input1.size(0)) 372 | tri_loss.update(loss_tri, 2 * input1.size(0)) 373 | total += labels.size(0) 374 | 375 | # measure elapsed time 376 | batch_time.update(time.time() - end) 377 | end = time.time() 378 | if batch_idx % 50 == 0: 379 | print('Epoch: [{}][{}/{}] ' 380 | 'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) ' 381 | 'lr:{:.3f} ' 382 | 'Loss: {train_loss.val:.4f} ({train_loss.avg:.4f}) ' 383 | 'iLoss: {id_loss.val:.4f} ({id_loss.avg:.4f}) ' 384 | 'TLoss: {tri_loss.val:.4f} ({tri_loss.avg:.4f}) ' 385 | 'Accu: {:.2f}'.format( 386 | epoch, batch_idx, len(trainloader), current_lr, 387 | 100. * correct / total, batch_time=batch_time, 388 | train_loss=train_loss, id_loss=id_loss,tri_loss=tri_loss)) 389 | 390 | writer.add_scalar('total_loss', train_loss.avg, epoch) 391 | writer.add_scalar('id_loss', id_loss.avg, epoch) 392 | writer.add_scalar('tri_loss', tri_loss.avg, epoch) 393 | writer.add_scalar('lr', current_lr, epoch) 394 | 395 | 396 | def test(epoch): 397 | # switch to evaluation mode 398 | net.eval() 399 | print('Extracting Gallery Feature...') 400 | start = time.time() 401 | ptr = 0 402 | if args.pcb == 'on': 403 | feat_dim = args.num_strips * args.local_feat_dim 404 | else: 405 | feat_dim = 2048 406 | gall_feat = np.zeros((ngall, feat_dim)) 407 | gall_feat_att = np.zeros((ngall, feat_dim)) 408 | with torch.no_grad(): 409 | for batch_idx, (input, label) in enumerate(gall_loader): 410 | batch_num = input.size(0) 411 | input = Variable(input.cuda()) 412 | if args.pcb == 'on': 413 | feat = net(input, input, test_mode[0]) 414 | gall_feat[ptr:ptr + batch_num, :] = feat.detach().cpu().numpy() 415 | else: 416 | feat, feat_att = net(input, input, test_mode[0]) 417 | gall_feat[ptr:ptr + batch_num, :] = feat.detach().cpu().numpy() 418 | gall_feat_att[ptr:ptr + batch_num, :] = feat_att.detach().cpu().numpy() 419 | ptr = ptr + batch_num 420 | print('Extracting Time:\t {:.3f}'.format(time.time() - start)) 421 | 422 | # switch to evaluation 423 | net.eval() 424 | print('Extracting Query Feature...') 425 | start = time.time() 426 | ptr = 0 427 | 428 | query_feat = np.zeros((nquery, feat_dim)) 429 | query_feat_att = np.zeros((nquery, feat_dim)) 430 | with torch.no_grad(): 431 | for batch_idx, (input, label) in enumerate(query_loader): 432 | batch_num = input.size(0) 433 | input = Variable(input.cuda()) 434 | if args.pcb == 'on': 435 | feat = net(input, input, test_mode[1]) 436 | query_feat[ptr:ptr + batch_num, :] = feat.detach().cpu().numpy() 437 | else: 438 | feat, feat_att = net(input, input, test_mode[1]) 439 | query_feat[ptr:ptr + batch_num, :] = feat.detach().cpu().numpy() 440 | query_feat_att[ptr:ptr + batch_num, :] = feat_att.detach().cpu().numpy() 441 | ptr = ptr + batch_num 442 | print('Extracting Time:\t {:.3f}'.format(time.time() - start)) 443 | 444 | start = time.time() 445 | 446 | 447 | if args.re_rank == 'random_walk': 448 | distmat = random_walk(query_feat, gall_feat) 449 | if args.pcb == 'off': distmat_att = random_walk(query_feat_att, gall_feat_att) 450 | elif args.re_rank == 'k_reciprocal': 451 | distmat = k_reciprocal(query_feat, gall_feat) 452 | if args.pcb == 'off': distmat_att = k_reciprocal(query_feat_att, gall_feat_att) 453 | elif args.re_rank == 'no': 454 | # compute the similarity 455 | distmat = -np.matmul(query_feat, np.transpose(gall_feat)) 456 | if args.pcb == 'off': distmat_att = -np.matmul(query_feat_att, np.transpose(gall_feat_att)) 457 | 458 | # evaluation 459 | if dataset == 'regdb': 460 | cmc, mAP, mINP = eval_regdb(distmat, query_label, gall_label) 461 | if args.pcb == 'off': cmc_att, mAP_att, mINP_att = eval_regdb(distmat_att, query_label, gall_label) 462 | elif dataset == 'sysu': 463 | cmc, mAP, mINP = eval_sysu(distmat, query_label, gall_label, query_cam, gall_cam) 464 | if args.pcb == 'off': cmc_att, mAP_att, mINP_att = eval_sysu(distmat_att, query_label, gall_label, query_cam, gall_cam) 465 | print('Evaluation Time:\t {:.3f}'.format(time.time() - start)) 466 | 467 | writer.add_scalar('rank1', cmc[0], epoch) 468 | writer.add_scalar('mAP', mAP, epoch) 469 | writer.add_scalar('mINP', mINP, epoch) 470 | if args.pcb == 'off': 471 | writer.add_scalar('rank1_att', cmc_att[0], epoch) 472 | writer.add_scalar('mAP_att', mAP_att, epoch) 473 | writer.add_scalar('mINP_att', mINP_att, epoch) 474 | 475 | return cmc, mAP, mINP, cmc_att, mAP_att, mINP_att 476 | else: 477 | return cmc, mAP, mINP 478 | 479 | 480 | 481 | # training 482 | print('==> Start Training...') 483 | for epoch in range(start_epoch, 61 - start_epoch): 484 | 485 | print('==> Preparing Data Loader...') 486 | # identity sampler 487 | sampler = IdentitySampler(trainset.train_color_label, \ 488 | trainset.train_thermal_label, color_pos, thermal_pos, args.num_pos, args.batch_size, 489 | epoch) 490 | 491 | trainset.cIndex = sampler.index1 # color index 492 | trainset.tIndex = sampler.index2 # thermal index 493 | print(epoch) 494 | 495 | loader_batch = args.batch_size * args.num_pos 496 | 497 | trainloader = data.DataLoader(trainset, batch_size=loader_batch, \ 498 | sampler=sampler, num_workers=args.workers, drop_last=True) 499 | 500 | # training 501 | train(epoch) 502 | 503 | if epoch > 9 and epoch % 2 == 0: 504 | print('Test Epoch: {}'.format(epoch)) 505 | 506 | # testing 507 | if args.pcb == 'off': 508 | cmc, mAP, mINP, cmc_fc, mAP_fc, mINP_fc = test(epoch) 509 | else: 510 | cmc_fc, mAP_fc, mINP_fc = test(epoch) 511 | # save model 512 | if cmc_fc[0] > best_acc: # not the real best for sysu-mm01 513 | best_acc = cmc_fc[0] 514 | best_epoch = epoch 515 | best_mAP = mAP_fc 516 | best_mINP = mINP_fc 517 | state = { 518 | 'net': net.state_dict(), 519 | 'cmc': cmc_fc, 520 | 'mAP': mAP_fc, 521 | 'mINP': mINP_fc, 522 | 'epoch': epoch, 523 | } 524 | torch.save(state, checkpoint_path + suffix + '_best.t') 525 | 526 | if args.pcb == 'off': 527 | print('POOL: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 528 | cmc[0], cmc[4], cmc[9], cmc[19], mAP, mINP)) 529 | 530 | print('FC: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 531 | cmc_fc[0], cmc_fc[4], cmc_fc[9], cmc_fc[19], mAP_fc, mINP_fc)) 532 | print('Best Epoch [{}], Rank-1: {:.2%} | mAP: {:.2%}| mINP: {:.2%}'.format(best_epoch, best_acc, best_mAP, best_mINP)) -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from torch.utils.data.sampler import Sampler 4 | import sys 5 | import os.path as osp 6 | import torch 7 | 8 | def load_data(input_data_path ): 9 | with open(input_data_path) as f: 10 | data_file_list = open(input_data_path, 'rt').read().splitlines() 11 | # Get full list of color image and labels 12 | file_image = [s.split(' ')[0] for s in data_file_list] 13 | file_label = [int(s.split(' ')[1]) for s in data_file_list] 14 | 15 | return file_image, file_label 16 | 17 | 18 | def GenIdx( train_color_label, train_thermal_label): 19 | color_pos = [] 20 | unique_label_color = np.unique(train_color_label) 21 | for i in range(len(unique_label_color)): 22 | tmp_pos = [k for k,v in enumerate(train_color_label) if v==unique_label_color[i]] 23 | color_pos.append(tmp_pos) 24 | 25 | thermal_pos = [] 26 | unique_label_thermal = np.unique(train_thermal_label) 27 | for i in range(len(unique_label_thermal)): 28 | tmp_pos = [k for k,v in enumerate(train_thermal_label) if v==unique_label_thermal[i]] 29 | thermal_pos.append(tmp_pos) 30 | return color_pos, thermal_pos 31 | 32 | def GenCamIdx(gall_img, gall_label, mode): 33 | if mode =='indoor': 34 | camIdx = [1,2] 35 | else: 36 | camIdx = [1,2,4,5] 37 | gall_cam = [] 38 | for i in range(len(gall_img)): 39 | gall_cam.append(int(gall_img[i][-10])) 40 | 41 | sample_pos = [] 42 | unique_label = np.unique(gall_label) 43 | for i in range(len(unique_label)): 44 | for j in range(len(camIdx)): 45 | id_pos = [k for k,v in enumerate(gall_label) if v==unique_label[i] and gall_cam[k]==camIdx[j]] 46 | if id_pos: 47 | sample_pos.append(id_pos) 48 | return sample_pos 49 | 50 | def ExtractCam(gall_img): 51 | gall_cam = [] 52 | for i in range(len(gall_img)): 53 | cam_id = int(gall_img[i][-10]) 54 | # if cam_id ==3: 55 | # cam_id = 2 56 | gall_cam.append(cam_id) 57 | 58 | return np.array(gall_cam) 59 | 60 | class IdentitySampler(Sampler): 61 | """Sample person identities evenly in each batch. 62 | Args: 63 | train_color_label, train_thermal_label: labels of two modalities 64 | color_pos, thermal_pos: positions of each identity 65 | batchSize: batch size 66 | """ 67 | 68 | def __init__(self, train_color_label, train_thermal_label, color_pos, thermal_pos, num_pos, batchSize, epoch): 69 | uni_label = np.unique(train_color_label) 70 | self.n_classes = len(uni_label) 71 | 72 | 73 | N = np.maximum(len(train_color_label), len(train_thermal_label)) 74 | for j in range(int(N/(batchSize*num_pos))+1): 75 | batch_idx = np.random.choice(uni_label, batchSize, replace = False) 76 | for i in range(batchSize): 77 | sample_color = np.random.choice(color_pos[batch_idx[i]], num_pos) 78 | sample_thermal = np.random.choice(thermal_pos[batch_idx[i]], num_pos) 79 | 80 | if j ==0 and i==0: 81 | index1= sample_color 82 | index2= sample_thermal 83 | else: 84 | index1 = np.hstack((index1, sample_color)) 85 | index2 = np.hstack((index2, sample_thermal)) 86 | 87 | self.index1 = index1 88 | self.index2 = index2 89 | self.N = N 90 | 91 | def __iter__(self): 92 | return iter(np.arange(len(self.index1))) 93 | 94 | def __len__(self): 95 | return self.N 96 | 97 | class AverageMeter(object): 98 | """Computes and stores the average and current value""" 99 | def __init__(self): 100 | self.reset() 101 | 102 | def reset(self): 103 | self.val = 0 104 | self.avg = 0 105 | self.sum = 0 106 | self.count = 0 107 | 108 | def update(self, val, n=1): 109 | self.val = val 110 | self.sum += val * n 111 | self.count += n 112 | self.avg = self.sum / self.count 113 | 114 | def mkdir_if_missing(directory): 115 | if not osp.exists(directory): 116 | try: 117 | os.makedirs(directory) 118 | except OSError as e: 119 | if e.errno != errno.EEXIST: 120 | raise 121 | class Logger(object): 122 | """ 123 | Write console output to external text file. 124 | Code imported from https://github.com/Cysu/open-reid/blob/master/reid/utils/logging.py. 125 | """ 126 | def __init__(self, fpath=None): 127 | self.console = sys.stdout 128 | self.file = None 129 | if fpath is not None: 130 | mkdir_if_missing(osp.dirname(fpath)) 131 | self.file = open(fpath, 'w') 132 | 133 | def __del__(self): 134 | self.close() 135 | 136 | def __enter__(self): 137 | pass 138 | 139 | def __exit__(self, *args): 140 | self.close() 141 | 142 | def write(self, msg): 143 | self.console.write(msg) 144 | if self.file is not None: 145 | self.file.write(msg) 146 | 147 | def flush(self): 148 | self.console.flush() 149 | if self.file is not None: 150 | self.file.flush() 151 | os.fsync(self.file.fileno()) 152 | 153 | def close(self): 154 | self.console.close() 155 | if self.file is not None: 156 | self.file.close() 157 | 158 | def set_seed(seed, cuda=True): 159 | np.random.seed(seed) 160 | torch.manual_seed(seed) 161 | if cuda: 162 | torch.cuda.manual_seed(seed) 163 | 164 | def set_requires_grad(nets, requires_grad=False): 165 | """Set requies_grad=Fasle for all the networks to avoid unnecessary computations 166 | Parameters: 167 | nets (network list) -- a list of networks 168 | requires_grad (bool) -- whether the networks require gradients or not 169 | """ 170 | if not isinstance(nets, list): 171 | nets = [nets] 172 | for net in nets: 173 | if net is not None: 174 | for param in net.parameters(): 175 | param.requires_grad = requires_grad --------------------------------------------------------------------------------