├── AGW
│   ├── README.md
│   ├── data_loader.py
│   ├── data_manager.py
│   ├── eval_metrics.py
│   ├── loss.py
│   ├── model.py
│   ├── pre_process_sysu.py
│   ├── resnet.py
│   ├── test.py
│   ├── train.py
│   └── utils.py
├── DDAG
│   ├── LICENSE
│   ├── README.md
│   ├── attention.py
│   ├── data_loader.py
│   ├── data_manager.py
│   ├── eval_metrics.py
│   ├── loss.py
│   ├── model_main.py
│   ├── pre_process_sysu.py
│   ├── resnet.py
│   ├── test_ddag.py
│   ├── train_ddag.py
│   └── utils.py
├── README.md
└── pipeline.png
/AGW/README.md:
--------------------------------------------------------------------------------
1 | # Cross-Modal-Re-ID-baseline (AGW)
2 | PyTorch code for Cross-Modality Person Re-Identification (Visible-Thermal Re-ID) on the RegDB dataset [1] and the SYSU-MM01 dataset [2].
3 |
4 | We adopt the two-stream network structure introduced in [3]. ResNet50 is adopted as the backbone, and the softmax loss is adopted as the baseline.
5 |
6 | |Datasets | Pretrained| Rank@1 | mAP | mINP | Model|
7 | | -------- | ----- | ----- | ----- | ----- |------|
8 | |RegDB | ImageNet | ~ 70.05% | ~ 66.37% | ~ 50.19% | ----- |
9 | |SYSU-MM01 | ImageNet | ~ 47.50% | ~ 47.65% | ~ 35.30% | [GoogleDrive](https://drive.google.com/open?id=181K9PQGnej0K5xNX9DRBDPAf3K9JosYk)|
10 |
11 | *Results on both datasets may fluctuate slightly due to random splitting; they might be further improved by fine-tuning the hyper-parameters.
12 |
13 | ### 1. Prepare the datasets.
14 |
15 | - (1) RegDB Dataset [1]: The RegDB dataset can be downloaded from this [website](http://dm.dongguk.edu/link.html) by submitting a copyright form.
16 |
17 |   - (Named: "Dongguk Body-based Person Recognition Database (DBPerson-Recog-DB1)" on their website).
18 |
19 |   - A private download link can be requested by sending me an email (mangye16@gmail.com).
20 |
21 | - (2) SYSU-MM01 Dataset [2]: The SYSU-MM01 dataset can be downloaded from this [website](http://isee.sysu.edu.cn/project/RGBIRReID.htm).
22 |
23 |   - Run `python pre_process_sysu.py` to prepare the dataset; the training data will be stored in `.npy` format.
24 |
25 | ### 2. Training.
26 | Train a model by
27 | ```bash
28 | python train.py --dataset sysu --lr 0.1 --method agw --gpu 1
29 | ```
30 |
31 | - `--dataset`: which dataset, "sysu" or "regdb".
32 |
33 | - `--lr`: initial learning rate.
34 |
35 | - `--method`: "agw" for the full method or "base" for the baseline.
36 |
37 | - `--gpu`: which gpu to run on.
38 |
39 | You may need to manually define the data path first.
40 |
41 | **Parameters**: More parameters can be found in the script.
42 |
43 | **Sampling Strategy**: N (= batch size) person identities are randomly sampled at each step; then four visible and four thermal images are randomly selected for each identity. Details can be found in Lines 302-307 of `train.py`; a minimal sketch is given after Section 3 below.
44 |
45 | **Training Log**: The training log will be saved in `log/`, with the dataset name in the file name. Models will be saved in `save_model/`.
46 |
47 | ### 3. Testing.
48 |
49 | Test a model on the SYSU-MM01 or RegDB dataset by
50 | ```bash
51 | python test.py --mode all --resume 'model_path' --gpu 1 --dataset sysu
52 | ```
53 | - `--dataset`: which dataset, "sysu" or "regdb".
54 |
55 | - `--mode`: "all" for all search or "indoor" for indoor search (only for the sysu dataset).
56 |
57 | - `--trial`: testing trial (only for the RegDB dataset).
58 |
59 | - `--resume`: the saved model path.
60 |
61 | - `--gpu`: which gpu to run on.
62 |
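**Sampling sketch**: a minimal, self-contained illustration of the identity sampling strategy described in Section 2. This is only a sketch; the repository's actual sampler lives in `utils.py` and is wired up in `train.py`, and the function and parameter names below are illustrative.

```python
import numpy as np

def sample_batch(color_labels, thermal_labels, batch_ids=8, num_pos=4, rng=np.random):
    """Pick `batch_ids` identities, then `num_pos` visible and `num_pos` thermal
    images per identity; returns index arrays into the two modality arrays."""
    ids = rng.choice(np.unique(color_labels), batch_ids, replace=False)
    color_idx, thermal_idx = [], []
    for pid in ids:
        color_idx.extend(rng.choice(np.where(color_labels == pid)[0], num_pos))
        thermal_idx.extend(rng.choice(np.where(thermal_labels == pid)[0], num_pos))
    return np.array(color_idx), np.array(thermal_idx)
```

Index arrays like these play the role of the `colorIndex`/`thermalIndex` arguments consumed by the datasets in `data_loader.py`.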
63 | ### 4. Citation
64 |
65 | Please kindly cite this paper in your publications if it helps your research:
66 | ```
67 | @article{arxiv20reidsurvey,
68 |   title={Deep Learning for Person Re-identification: A Survey and Outlook},
69 |   author={Ye, Mang and Shen, Jianbing and Lin, Gaojie and Xiang, Tao and Shao, Ling and Hoi, Steven C. H.},
70 |   journal={arXiv preprint arXiv:2001.04193},
71 |   year={2020},
72 | }
73 | ```
74 |
75 | ### 5. References.
76 | [1] D. T. Nguyen, H. G. Hong, K. W. Kim, and K. R. Park. Person recognition system based on a combination of body images from visible
77 | light and thermal cameras. Sensors, 17(3):605, 2017.
78 |
79 | [2] A. Wu, W.-S. Zheng, H.-X. Yu, S. Gong, and J. Lai. RGB-infrared cross-modality person re-identification. In IEEE International Conference on Computer Vision (ICCV), pages 5380-5389, 2017.
80 |
81 | [3] M. Ye, Z. Wang, X. Lan, and P. C. Yuen. Visible thermal person re-identification via dual-constrained top-ranking. In International Joint Conference on Artificial Intelligence (IJCAI), pages 1092-1099, 2018.
82 |
83 | Contact: mangye16@gmail.com
84 |
--------------------------------------------------------------------------------
/AGW/data_loader.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 | import torch.utils.data as data
4 |
5 |
6 | class SYSUData(data.Dataset):
7 |     def __init__(self, data_dir, transform=None, colorIndex = None, thermalIndex = None):
8 |
9 |         # data_dir is supplied by the caller; the released code hard-coded '../SYSU_MM01/' here, shadowing the argument
10 |         # Load training images (path) and labels
11 |         train_color_image = np.load(data_dir + 'train_rgb_resized_img.npy')
12 |         self.train_color_label = np.load(data_dir + 'train_rgb_resized_label.npy')
13 |
14 |         train_thermal_image = np.load(data_dir + 'train_ir_resized_img.npy')
15 |         self.train_thermal_label = np.load(data_dir + 'train_ir_resized_label.npy')
16 |
17 |         # arrays were saved as RGB by pre_process_sysu.py, so no channel conversion is needed
18 |         self.train_color_image = train_color_image
19 |         self.train_thermal_image = train_thermal_image
20 |         self.transform = transform
21 |         self.cIndex = colorIndex  # indices produced by the identity sampler
22 |         self.tIndex = thermalIndex
23 |
24 |     def __getitem__(self, index):
25 |
26 |         img1, target1 = self.train_color_image[self.cIndex[index]], self.train_color_label[self.cIndex[index]]
27 |         img2, target2 = self.train_thermal_image[self.tIndex[index]], self.train_thermal_label[self.tIndex[index]]
28 |
29 |         img1 = self.transform(img1)
30 |         img2 = self.transform(img2)
31 |
32 |         return img1, img2, target1, target2
33 |
34 |     def __len__(self):
35 |         return len(self.train_color_label)
36 |
37 |
38 | class RegDBData(data.Dataset):
39 |     def __init__(self, data_dir, trial, transform=None, colorIndex = None, thermalIndex = None):
40 |         # Load training images (path) and labels
41 |         # data_dir is supplied by the caller; the released code hard-coded 'RegDB/' here, shadowing the argument
42 |         train_color_list = data_dir + 'idx/train_visible_{}'.format(trial) + '.txt'
43 |         train_thermal_list = data_dir + 'idx/train_thermal_{}'.format(trial) + '.txt'
44 |
45 |         color_img_file, train_color_label = load_data(train_color_list)
46 |         thermal_img_file, train_thermal_label = load_data(train_thermal_list)
47 |
48 |         train_color_image = []
49 |         for i in range(len(color_img_file)):
50 |
51 |             img = Image.open(data_dir + color_img_file[i])
52 |             img = img.resize((144, 288), Image.ANTIALIAS)
53 |             pix_array = np.array(img)
54 |             train_color_image.append(pix_array)
55 |         train_color_image = np.array(train_color_image)
56 |
57 |         train_thermal_image = []
58 |         for i in range(len(thermal_img_file)):
59 |             img = Image.open(data_dir + thermal_img_file[i])
60 |             img = img.resize((144, 288), Image.ANTIALIAS)
61 |             pix_array = np.array(img)
62 |             train_thermal_image.append(pix_array)
63 |         train_thermal_image = np.array(train_thermal_image)
64 |
65 |         # images are loaded with PIL, so the arrays are already RGB
66 |         self.train_color_image = train_color_image
67 |         self.train_color_label = train_color_label
68 |
69 |         # same for the thermal images
70 |         self.train_thermal_image = train_thermal_image
71 |         self.train_thermal_label = train_thermal_label
72 |
73 |         self.transform = transform
74 |         self.cIndex = colorIndex
75 |         self.tIndex = thermalIndex
76 |
77 |     def __getitem__(self, index):
78 |
79 |         img1, target1 = self.train_color_image[self.cIndex[index]], self.train_color_label[self.cIndex[index]]
80 |         img2, target2 = self.train_thermal_image[self.tIndex[index]], self.train_thermal_label[self.tIndex[index]]
81 |
82 |         img1 = self.transform(img1)
83 |         img2 = self.transform(img2)
84 |
85 |         return img1, img2, target1, target2
86 |
87 |     def __len__(self):
88 |         return len(self.train_color_label)
89 |
90 | class TestData(data.Dataset):
91 |     def __init__(self, test_img_file, test_label, transform=None, img_size = (144,288)):
92 |
93 |         test_image = []
94 |         for i in range(len(test_img_file)):
95 |             img = Image.open(test_img_file[i])
96 |             img = img.resize((img_size[0], img_size[1]), Image.ANTIALIAS)
97 |             pix_array = np.array(img)
98 |             test_image.append(pix_array)
99 |         test_image = np.array(test_image)
100 |         self.test_image = test_image
101 |         self.test_label = test_label
102 |         self.transform = transform
103 |
104 |     def __getitem__(self, index):
105 |         img1, target1 = self.test_image[index], self.test_label[index]
106 |         img1 = self.transform(img1)
107 |         return img1, target1
108 |
109 |     def __len__(self):
110 |         return len(self.test_image)
111 |
112 | class TestDataOld(data.Dataset):
113 |     def __init__(self, data_dir, test_img_file, test_label, transform=None, img_size = (144,288)):
114 |
115 |         test_image = []
116 |         for i in range(len(test_img_file)):
117 |             img = Image.open(data_dir + test_img_file[i])
118 |             img = img.resize((img_size[0], img_size[1]), Image.ANTIALIAS)
119 |             pix_array = np.array(img)
120 |             test_image.append(pix_array)
121 |         test_image = np.array(test_image)
122 |         self.test_image = test_image
123 |         self.test_label = test_label
124 |         self.transform = transform
125 |
126 |     def __getitem__(self, index):
127 |         img1, target1 = self.test_image[index], self.test_label[index]
128 |         img1 = self.transform(img1)
129 |         return img1, target1
130 |
131 |     def __len__(self):
132 |         return len(self.test_image)
133 | def load_data(input_data_path):
134 |     with open(input_data_path) as f:
135 |         data_file_list = f.read().splitlines()  # read via the open handle instead of opening the file a second time
136 |     # Get full list of images and labels
137 |     file_image = [s.split(' ')[0] for s in data_file_list]
138 |     file_label = [int(s.split(' ')[1]) for s in data_file_list]
139 |
140 |     return file_image, file_label
141 |
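# --- Illustrative smoke test (added sketch; not part of the original file) ---
# load_data() expects each line of a RegDB index file to be
# "<relative image path> <integer label>".
if __name__ == '__main__':
    import os, tempfile
    tmp = tempfile.NamedTemporaryFile('w', suffix='.txt', delete=False)
    tmp.write('v/0001.bmp 0\nv/0002.bmp 0\nv/0003.bmp 1\n')
    tmp.close()
    imgs, labels = load_data(tmp.name)
    assert imgs == ['v/0001.bmp', 'v/0002.bmp', 'v/0003.bmp'] and labels == [0, 0, 1]
    os.unlink(tmp.name)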
--------------------------------------------------------------------------------
/AGW/data_manager.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 | import os
3 | import numpy as np
4 | import random
5 |
6 | def process_query_sysu(data_path, mode = 'all', relabel=False):
7 |     if mode == 'all':
8 |         ir_cameras = ['cam3','cam6']  # per the dataset protocol, IR queries come from cam3/cam6 in both modes
9 |     elif mode == 'indoor':
10 |         ir_cameras = ['cam3','cam6']
11 |
12 |     file_path = os.path.join(data_path,'exp/test_id.txt')
13 |     files_rgb = []
14 |     files_ir = []
15 |
16 |     with open(file_path, 'r') as file:
17 |         ids = file.read().splitlines()
18 |         ids = [int(y) for y in ids[0].split(',')]
19 |         ids = ["%04d" % x for x in ids]
20 |
21 |     for id in sorted(ids):
22 |         for cam in ir_cameras:
23 |             img_dir = os.path.join(data_path,cam,id)
24 |             if os.path.isdir(img_dir):
25 |                 new_files = sorted([img_dir+'/'+i for i in os.listdir(img_dir)])
26 |                 files_ir.extend(new_files)
27 |     query_img = []
28 |     query_id = []
29 |     query_cam = []
30 |     for img_path in files_ir:
31 |         camid, pid = int(img_path[-15]), int(img_path[-13:-9])  # path ends with 'camX/pppp/nnnn.jpg': [-15] is the camera digit, [-13:-9] the person id
32 |         query_img.append(img_path)
33 |         query_id.append(pid)
34 |         query_cam.append(camid)
35 |     return query_img, np.array(query_id), np.array(query_cam)
36 |
37 | def process_gallery_sysu(data_path, mode = 'all', trial = 0, relabel=False):
38 |
39 |     random.seed(trial)
40 |
41 |     if mode == 'all':
42 |         rgb_cameras = ['cam1','cam2','cam4','cam5']
43 |     elif mode == 'indoor':
44 |         rgb_cameras = ['cam1','cam2']
45 |
46 |     file_path = os.path.join(data_path,'exp/test_id.txt')
47 |     files_rgb = []
48 |     with open(file_path, 'r') as file:
49 |         ids = file.read().splitlines()
50 |         ids = [int(y) for y in ids[0].split(',')]
51 |         ids = ["%04d" % x for x in ids]
52 |
53 |     for id in sorted(ids):
54 |         for cam in rgb_cameras:
55 |             img_dir = os.path.join(data_path,cam,id)
56 |             if os.path.isdir(img_dir):
57 |                 new_files = sorted([img_dir+'/'+i for i in os.listdir(img_dir)])
58 |                 files_rgb.append(random.choice(new_files))  # single-shot: one random image per identity/camera
59 |     gall_img = []
60 |     gall_id = []
61 |     gall_cam = []
62 |     for img_path in files_rgb:
63 |         camid, pid = int(img_path[-15]), int(img_path[-13:-9])
64 |         gall_img.append(img_path)
65 |         gall_id.append(pid)
66 |         gall_cam.append(camid)
67 |     return gall_img, np.array(gall_id), np.array(gall_cam)
68 |
69 |
70 | def process_gallery_sysu_multishot(data_path, mode='all', trial=0, relabel=False):
71 |     random.seed(trial)
72 |
73 |     if mode == 'all':
74 |         rgb_cameras = ['cam1', 'cam2', 'cam4', 'cam5']
75 |     elif mode == 'indoor':
76 |         rgb_cameras = ['cam1', 'cam2']
77 |
78 |     file_path = os.path.join(data_path, 'exp/test_id.txt')
79 |     files_rgb = []
80 |     with open(file_path, 'r') as file:
81 |         ids = file.read().splitlines()
82 |         ids = [int(y) for y in ids[0].split(',')]
83 |         ids = ["%04d" % x for x in ids]
84 |
85 |     for id in sorted(ids):
86 |         for cam in rgb_cameras:
87 |             img_dir = os.path.join(data_path, cam, id)
88 |             if os.path.isdir(img_dir):
89 |                 new_files = sorted([img_dir + '/' + i for i in os.listdir(img_dir)])
90 |                 files_rgb = files_rgb + random.sample(new_files, 10)  # multi-shot: ten random images per identity/camera
91 |     gall_img = []
92 |     gall_id = []
93 |     gall_cam = []
94 |     for img_path in files_rgb:
95 |         camid, pid = int(img_path[-15]), int(img_path[-13:-9])
96 |         gall_img.append(img_path)
97 |         gall_id.append(pid)
98 |         gall_cam.append(camid)
99 |     return gall_img, np.array(gall_id), np.array(gall_cam)
100 |
101 | def process_test_regdb(img_dir, trial = 1, modal = 'visible'):
102 |     if modal == 'visible':
103 |         input_data_path = img_dir + 'idx/test_visible_{}'.format(trial) + '.txt'
104 |     elif modal == 'thermal':
105 |         input_data_path = img_dir + 'idx/test_thermal_{}'.format(trial) + '.txt'
106 |
107 |     with open(input_data_path) as f:
108 |         data_file_list = f.read().splitlines()  # read via the open handle instead of opening the file a second time
109 |     # Get full list of images and labels
110 |     file_image = [img_dir + '/' + s.split(' ')[0] for s in data_file_list]
111 |     file_label = [int(s.split(' ')[1]) for s in data_file_list]
112 |
113 |     return file_image, np.array(file_label)
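# --- Illustrative usage (added sketch; not part of the original file) --------
# SYSU-MM01 single-shot evaluation: one fixed IR query set and ten gallery
# trials, each drawing one RGB image per identity/camera with a trial-seeded
# RNG, so every trial is reproducible. Assumes the dataset is unpacked under
# the hypothetical path below.
if __name__ == '__main__':
    data_path = 'SYSU-MM01/'
    query_img, query_id, query_cam = process_query_sysu(data_path, mode='all')
    for trial in range(10):
        gall_img, gall_id, gall_cam = process_gallery_sysu(data_path, mode='all', trial=trial)
        print('trial {}: {} queries, {} gallery images'.format(trial, len(query_img), len(gall_img)))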
--------------------------------------------------------------------------------
/AGW/eval_metrics.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 | import numpy as np
3 | """Cross-Modality ReID"""
4 | import pdb
5 |
6 | def eval_sysu(distmat, q_pids, g_pids, q_camids, g_camids, max_rank = 20):
7 |     """Evaluation with the SYSU-MM01 metric
8 |     Key: for each query, gallery images from the conflicting camera view are discarded, following the original setting of the dataset.
9 |     """
10 |     num_q, num_g = distmat.shape
11 |     if num_g < max_rank:
12 |         max_rank = num_g
13 |         print("Note: number of gallery samples is quite small, got {}".format(num_g))
14 |     indices = np.argsort(distmat, axis=1)
15 |     pred_label = g_pids[indices]
16 |     matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)
17 |
18 |     # compute cmc curve for each query
19 |     new_all_cmc = []
20 |     all_cmc = []
21 |     all_AP = []
22 |     all_INP = []
23 |     num_valid_q = 0. # number of valid queries
24 |     for q_idx in range(num_q):
25 |         # get query pid and camid
26 |         q_pid = q_pids[q_idx]
27 |         q_camid = q_camids[q_idx]
28 |
29 |         # remove gallery samples that conflict with the query view
30 |         order = indices[q_idx]
31 |         remove = (q_camid == 3) & (g_camids[order] == 2)  # protocol: cam2 gallery images are discarded for cam3 queries (same location)
32 |         keep = np.invert(remove)
33 |
34 |         # compute cmc curve
35 |         # the cmc calculation is different from the standard protocol
36 |         # we follow the protocol of the author's released code
37 |         new_cmc = pred_label[q_idx][keep]
38 |         new_index = np.unique(new_cmc, return_index=True)[1]
39 |         new_cmc = [new_cmc[index] for index in sorted(new_index)]
40 |
41 |         new_match = (new_cmc == q_pid).astype(np.int32)
42 |         new_cmc = new_match.cumsum()
43 |         new_all_cmc.append(new_cmc[:max_rank])
44 |
45 |         orig_cmc = matches[q_idx][keep] # binary vector, positions with value 1 are correct matches
46 |         if not np.any(orig_cmc):
47 |             # this condition is true when the query identity does not appear in the gallery
48 |             continue
49 |
50 |         cmc = orig_cmc.cumsum()
51 |
52 |         # compute mINP
53 |         # reference: Deep Learning for Person Re-identification: A Survey and Outlook
54 |         pos_idx = np.where(orig_cmc == 1)
55 |         pos_max_idx = np.max(pos_idx)
56 |         inp = cmc[pos_max_idx] / (pos_max_idx + 1.0)
57 |         all_INP.append(inp)
58 |
59 |         cmc[cmc > 1] = 1
60 |
61 |         all_cmc.append(cmc[:max_rank])
62 |         num_valid_q += 1.
63 |
64 |         # compute average precision
65 |         # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision
66 |         num_rel = orig_cmc.sum()
67 |         tmp_cmc = orig_cmc.cumsum()
68 |         tmp_cmc = [x / (i+1.) for i, x in enumerate(tmp_cmc)]
69 |         tmp_cmc = np.asarray(tmp_cmc) * orig_cmc
70 |         AP = tmp_cmc.sum() / num_rel
71 |         all_AP.append(AP)
72 |
73 |     assert num_valid_q > 0, "Error: all query identities do not appear in gallery"
74 |
75 |     all_cmc = np.asarray(all_cmc).astype(np.float32)
76 |     all_cmc = all_cmc.sum(0) / num_valid_q # standard CMC
77 |
78 |     new_all_cmc = np.asarray(new_all_cmc).astype(np.float32)
79 |     new_all_cmc = new_all_cmc.sum(0) / num_valid_q
80 |     mAP = np.mean(all_AP)
81 |     mINP = np.mean(all_INP)
82 |     return new_all_cmc, mAP, mINP
83 |
84 |
85 |
86 | def eval_regdb(distmat, q_pids, g_pids, max_rank = 20):
87 |     num_q, num_g = distmat.shape
88 |     if num_g < max_rank:
89 |         max_rank = num_g
90 |         print("Note: number of gallery samples is quite small, got {}".format(num_g))
91 |     indices = np.argsort(distmat, axis=1)
92 |     matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)
93 |
94 |     # compute cmc curve for each query
95 |     all_cmc = []
96 |     all_AP = []
97 |     all_INP = []
98 |     num_valid_q = 0. # number of valid queries
99 |
100 |     # only two cameras
101 |     q_camids = np.ones(num_q).astype(np.int32)
102 |     g_camids = 2 * np.ones(num_g).astype(np.int32)
103 |
104 |     for q_idx in range(num_q):
105 |         # get query pid and camid
106 |         q_pid = q_pids[q_idx]
107 |         q_camid = q_camids[q_idx]
108 |
109 |         # remove gallery samples that have the same pid and camid as the query
110 |         order = indices[q_idx]
111 |         remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid)
112 |         keep = np.invert(remove)
113 |
114 |         # compute cmc curve
115 |         raw_cmc = matches[q_idx][keep] # binary vector, positions with value 1 are correct matches
116 |         if not np.any(raw_cmc):
117 |             # this condition is true when the query identity does not appear in the gallery
118 |             continue
119 |
120 |         cmc = raw_cmc.cumsum()
121 |
122 |         # compute mINP
123 |         # reference: Deep Learning for Person Re-identification: A Survey and Outlook
124 |         pos_idx = np.where(raw_cmc == 1)
125 |         pos_max_idx = np.max(pos_idx)
126 |         inp = cmc[pos_max_idx] / (pos_max_idx + 1.0)
127 |         all_INP.append(inp)
128 |
129 |         cmc[cmc > 1] = 1
130 |
131 |         all_cmc.append(cmc[:max_rank])
132 |         num_valid_q += 1.
133 |
134 |         # compute average precision
135 |         # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision
136 |         num_rel = raw_cmc.sum()
137 |         tmp_cmc = raw_cmc.cumsum()
138 |         tmp_cmc = [x / (i+1.) for i, x in enumerate(tmp_cmc)]
139 |         tmp_cmc = np.asarray(tmp_cmc) * raw_cmc
140 |         AP = tmp_cmc.sum() / num_rel
141 |         all_AP.append(AP)
142 |
143 |     assert num_valid_q > 0, "Error: all query identities do not appear in gallery"
144 |
145 |     all_cmc = np.asarray(all_cmc).astype(np.float32)
146 |     all_cmc = all_cmc.sum(0) / num_valid_q
147 |     mAP = np.mean(all_AP)
148 |     mINP = np.mean(all_INP)
149 |     return all_cmc, mAP, mINP
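# --- Illustrative sanity check (added sketch; not part of the original file) -
# Toy example: 2 queries vs. 3 gallery images. Query 0's true match is ranked
# first (AP = 1.0); query 1's true match is ranked second (AP = 0.5), so
# mAP = 0.75 and rank-1 CMC = 0.5.
if __name__ == '__main__':
    distmat = np.array([[0.1, 0.9, 0.8],    # query with pid 1
                        [0.7, 0.6, 0.2]])   # query with pid 2
    q_pids = np.array([1, 2])
    g_pids = np.array([1, 2, 3])
    cmc, mAP, mINP = eval_regdb(distmat, q_pids, g_pids)
    print(cmc, mAP, mINP)  # expected: cmc[0] = 0.5, mAP = 0.75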
18 | """ 19 | 20 | def __init__(self, batch_size, margin=0.3): 21 | super(OriTripletLoss, self).__init__() 22 | self.margin = margin 23 | self.ranking_loss = nn.MarginRankingLoss(margin=margin) 24 | 25 | def forward(self, inputs, targets): 26 | """ 27 | Args: 28 | - inputs: feature matrix with shape (batch_size, feat_dim) 29 | - targets: ground truth labels with shape (num_classes) 30 | """ 31 | n = inputs.size(0) 32 | 33 | # Compute pairwise distance, replace by the official when merged 34 | dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n) 35 | dist = dist + dist.t() 36 | dist.addmm_(1, -2, inputs, inputs.t()) 37 | dist = dist.clamp(min=1e-12).sqrt() # for numerical stability 38 | 39 | # For each anchor, find the hardest positive and negative 40 | mask = targets.expand(n, n).eq(targets.expand(n, n).t()) 41 | dist_ap, dist_an = [], [] 42 | for i in range(n): 43 | dist_ap.append(dist[i][mask[i]].max().unsqueeze(0)) 44 | dist_an.append(dist[i][mask[i] == 0].min().unsqueeze(0)) 45 | dist_ap = torch.cat(dist_ap) 46 | dist_an = torch.cat(dist_an) 47 | 48 | # Compute ranking hinge loss 49 | y = torch.ones_like(dist_an) 50 | loss = self.ranking_loss(dist_an, dist_ap, y) 51 | 52 | # compute accuracy 53 | correct = torch.ge(dist_an, dist_ap).sum().item() 54 | return loss, correct 55 | 56 | 57 | 58 | 59 | 60 | # Adaptive weights 61 | def softmax_weights(dist, mask): 62 | max_v = torch.max(dist * mask, dim=1, keepdim=True)[0] 63 | diff = dist - max_v 64 | Z = torch.sum(torch.exp(diff) * mask, dim=1, keepdim=True) + 1e-6 # avoid division by zero 65 | W = torch.exp(diff) * mask / Z 66 | return W 67 | 68 | def normalize(x, axis=-1): 69 | """Normalizing to unit length along the specified dimension. 70 | Args: 71 | x: pytorch Variable 72 | Returns: 73 | x: pytorch Variable, same shape as input 74 | """ 75 | x = 1. 
55 |
56 |
57 |
58 |
59 |
60 | # Adaptive weights
61 | def softmax_weights(dist, mask):
62 |     max_v = torch.max(dist * mask, dim=1, keepdim=True)[0]
63 |     diff = dist - max_v
64 |     Z = torch.sum(torch.exp(diff) * mask, dim=1, keepdim=True) + 1e-6 # avoid division by zero
65 |     W = torch.exp(diff) * mask / Z
66 |     return W
67 |
68 | def normalize(x, axis=-1):
69 |     """Normalizing to unit length along the specified dimension.
70 |     Args:
71 |       x: pytorch Variable
72 |     Returns:
73 |       x: pytorch Variable, same shape as input
74 |     """
75 |     x = 1. * x / (torch.norm(x, 2, axis, keepdim=True).expand_as(x) + 1e-12)
76 |     return x
77 |
78 | class TripletLoss_WRT(nn.Module):
79 |     """Weighted Regularized Triplet loss."""
80 |
81 |     def __init__(self):
82 |         super(TripletLoss_WRT, self).__init__()
83 |         self.ranking_loss = nn.SoftMarginLoss()
84 |
85 |     def forward(self, inputs, targets, normalize_feature=False):
86 |         if normalize_feature:
87 |             inputs = normalize(inputs, axis=-1)
88 |         dist_mat = pdist_torch(inputs, inputs)
89 |
90 |         N = dist_mat.size(0)
91 |         # shape [N, N]
92 |         is_pos = targets.expand(N, N).eq(targets.expand(N, N).t()).float()
93 |         is_neg = targets.expand(N, N).ne(targets.expand(N, N).t()).float()
94 |
95 |         # `dist_ap` means distance(anchor, positive)
96 |         # both `dist_ap` and `dist_an` have shape [N, N]; the softmax weighting below softly selects the hardest examples
97 |         dist_ap = dist_mat * is_pos
98 |         dist_an = dist_mat * is_neg
99 |
100 |         weights_ap = softmax_weights(dist_ap, is_pos)
101 |         weights_an = softmax_weights(-dist_an, is_neg)
102 |         furthest_positive = torch.sum(dist_ap * weights_ap, dim=1)
103 |         closest_negative = torch.sum(dist_an * weights_an, dim=1)
104 |
105 |         y = furthest_positive.new().resize_as_(furthest_positive).fill_(1)
106 |         loss = self.ranking_loss(closest_negative - furthest_positive, y)
107 |
108 |
109 |         # compute accuracy
110 |         correct = torch.ge(closest_negative, furthest_positive).sum().item()
111 |         return loss, correct
112 |
113 | def pdist_torch(emb1, emb2):
114 |     '''
115 |     compute the Euclidean distance matrix between embeddings1 and embeddings2
116 |     using gpu
117 |     '''
118 |     m, n = emb1.shape[0], emb2.shape[0]
119 |     emb1_pow = torch.pow(emb1, 2).sum(dim = 1, keepdim = True).expand(m, n)
120 |     emb2_pow = torch.pow(emb2, 2).sum(dim = 1, keepdim = True).expand(n, m).t()
121 |     dist_mtx = emb1_pow + emb2_pow
122 |     dist_mtx = dist_mtx.addmm_(emb1, emb2.t(), beta=1, alpha=-2)  # keyword form; the old addmm_(1, -2, ...) signature is removed in recent PyTorch
123 |     # dist_mtx = dist_mtx.clamp(min = 1e-12)
124 |     dist_mtx = dist_mtx.clamp(min = 1e-12).sqrt()
125 |     return dist_mtx
126 |
127 |
128 | def pdist_np(emb1, emb2):
129 |     '''
130 |     compute the Euclidean distance matrix between embeddings1 and embeddings2
131 |     using cpu
132 |     '''
133 |     m, n = emb1.shape[0], emb2.shape[0]
134 |     emb1_pow = np.square(emb1).sum(axis = 1)[..., np.newaxis]
135 |     emb2_pow = np.square(emb2).sum(axis = 1)[np.newaxis, ...]
136 |     dist_mtx = -2 * np.matmul(emb1, emb2.T) + emb1_pow + emb2_pow  # note: returns *squared* distances; the sqrt below is left commented out
137 |     # dist_mtx = np.sqrt(dist_mtx.clip(min = 1e-12))
138 |     return dist_mtx
--------------------------------------------------------------------------------
/AGW/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import init
4 | from resnet import resnet50, resnet18
5 | #from vgg import vgg16_bn
6 |
7 | class Normalize(nn.Module):
8 |     def __init__(self, power=2):
9 |         super(Normalize, self).__init__()
10 |         self.power = power
11 |
12 |     def forward(self, x):
13 |         norm = x.pow(self.power).sum(1, keepdim=True).pow(1. / self.power)
14 |         out = x.div(norm)
15 |         return out
16 |
17 | class Non_local(nn.Module):
18 |     def __init__(self, in_channels, reduc_ratio=2):
19 |         super(Non_local, self).__init__()
20 |
21 |         self.in_channels = in_channels
22 |         self.inter_channels = reduc_ratio//reduc_ratio  # evaluates to 1; likely intended as in_channels//reduc_ratio, kept as released for checkpoint compatibility
23 |
24 |         self.g = nn.Sequential(
25 |             nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1,
26 |                       padding=0),
27 |         )
28 |
29 |         self.W = nn.Sequential(
30 |             nn.Conv2d(in_channels=self.inter_channels, out_channels=self.in_channels,
31 |                       kernel_size=1, stride=1, padding=0),
32 |             nn.BatchNorm2d(self.in_channels),
33 |         )
34 |         nn.init.constant_(self.W[1].weight, 0.0)  # zero-init so each non-local block starts as an identity mapping
35 |         nn.init.constant_(self.W[1].bias, 0.0)
36 |
37 |
38 |
39 |         self.theta = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
40 |                                kernel_size=1, stride=1, padding=0)
41 |
42 |         self.phi = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
43 |                              kernel_size=1, stride=1, padding=0)
44 |
45 |     def forward(self, x):
46 |         '''
47 |         :param x: (b, c, h, w)
48 |         :return:
49 |         '''
50 |
51 |         batch_size = x.size(0)
52 |         g_x = self.g(x).view(batch_size, self.inter_channels, -1)
53 |         g_x = g_x.permute(0, 2, 1)
54 |
55 |         theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
56 |         theta_x = theta_x.permute(0, 2, 1)
57 |         phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)
58 |         f = torch.matmul(theta_x, phi_x)
59 |         N = f.size(-1)
60 |         # f_div_C = torch.nn.functional.softmax(f, dim=-1)
61 |         f_div_C = f / N
62 |
63 |         y = torch.matmul(f_div_C, g_x)
64 |         y = y.permute(0, 2, 1).contiguous()
65 |         y = y.view(batch_size, self.inter_channels, *x.size()[2:])
66 |         W_y = self.W(y)
67 |         z = W_y + x
68 |
69 |         return z
70 |
71 |
72 | # #####################################################################
73 | def weights_init_kaiming(m):
74 |     classname = m.__class__.__name__
75 |     # print(classname)
76 |     if classname.find('Conv') != -1:
77 |         init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
78 |     elif classname.find('Linear') != -1:
79 |         init.kaiming_normal_(m.weight.data, a=0, mode='fan_out')
80 |         init.zeros_(m.bias.data)
81 |     elif classname.find('BatchNorm1d') != -1:
82 |         init.normal_(m.weight.data, 1.0, 0.01)
83 |         init.zeros_(m.bias.data)
84 |
85 | def weights_init_classifier(m):
86 |     classname = m.__class__.__name__
87 |     if classname.find('Linear') != -1:
88 |         init.normal_(m.weight.data, 0, 0.001)
89 |         if m.bias is not None:  # `if m.bias:` raises on multi-element bias tensors
90 |             init.zeros_(m.bias.data)
91 |
92 |
93 |
94 | class visible_module(nn.Module):
95 |     def __init__(self):
96 |         super(visible_module, self).__init__()
97 |
98 |         model_v = resnet50(pretrained=False,
99 |                            last_conv_stride=1, last_conv_dilation=1)
100 |         # avg pooling to global pooling
101 |         self.visible = model_v
102 |
103 |     def forward(self, x):
104 |         x = self.visible.conv1(x)
105 |         x = self.visible.bn1(x)
106 |         x = self.visible.relu(x)
107 |         x = self.visible.maxpool(x)
108 |         return x
109 |
110 |
111 | class thermal_module(nn.Module):
112 |     def __init__(self):
113 |         super(thermal_module, self).__init__()
114 |
115 |         model_t = resnet50(pretrained=False,
116 |                            last_conv_stride=1, last_conv_dilation=1)
117 |         # avg pooling to global pooling
118 |         self.thermal = model_t
119 |
120 |     def forward(self, x):
121 |         x = self.thermal.conv1(x)
122 |         x = self.thermal.bn1(x)
123 |         x = self.thermal.relu(x)
124 |         x = self.thermal.maxpool(x)
125 |         return x
126 |
127 |
128 | class base_resnet(nn.Module):
129 |     def __init__(self):
130 |         super(base_resnet, self).__init__()
131 |
132 |
model_base = resnet50(pretrained=False, 133 | last_conv_stride=1, last_conv_dilation=1) 134 | # avg pooling to global pooling 135 | model_base.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 136 | self.base = model_base 137 | 138 | def forward(self, x): 139 | x = self.base.layer1(x) 140 | x = self.base.layer2(x) 141 | x = self.base.layer3(x) 142 | x = self.base.layer4(x) 143 | return x 144 | 145 | 146 | class embed_net(nn.Module): 147 | def __init__(self, class_num, no_local= 'on', gm_pool = 'on', arch='resnet50'): 148 | super(embed_net, self).__init__() 149 | 150 | self.thermal_module = thermal_module() 151 | self.visible_module = visible_module() 152 | self.base_resnet = base_resnet() 153 | self.non_local = no_local 154 | if self.non_local =='on': 155 | layers=[3, 4, 6, 3] 156 | non_layers=[0,2,3,0] 157 | self.NL_1 = nn.ModuleList( 158 | [Non_local(256) for i in range(non_layers[0])]) 159 | self.NL_1_idx = sorted([layers[0] - (i + 1) for i in range(non_layers[0])]) 160 | self.NL_2 = nn.ModuleList( 161 | [Non_local(512) for i in range(non_layers[1])]) 162 | self.NL_2_idx = sorted([layers[1] - (i + 1) for i in range(non_layers[1])]) 163 | self.NL_3 = nn.ModuleList( 164 | [Non_local(1024) for i in range(non_layers[2])]) 165 | self.NL_3_idx = sorted([layers[2] - (i + 1) for i in range(non_layers[2])]) 166 | self.NL_4 = nn.ModuleList( 167 | [Non_local(2048) for i in range(non_layers[3])]) 168 | self.NL_4_idx = sorted([layers[3] - (i + 1) for i in range(non_layers[3])]) 169 | 170 | 171 | pool_dim = 2048 172 | self.l2norm = Normalize(2) 173 | self.bottleneck = nn.BatchNorm1d(pool_dim) 174 | self.bottleneck.bias.requires_grad_(False) # no shift 175 | 176 | self.classifier = nn.Linear(pool_dim, class_num, bias=False) 177 | 178 | self.bottleneck.apply(weights_init_kaiming) 179 | self.classifier.apply(weights_init_classifier) 180 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 181 | self.gm_pool = gm_pool 182 | 183 | def forward(self, x1, x2, modal=0): 184 | if modal == 0: 185 | x1 = self.visible_module(x1) 186 | x2 = self.thermal_module(x2) 187 | x = torch.cat((x1, x2), 0) 188 | elif modal == 1: 189 | x = self.visible_module(x1) 190 | elif modal == 2: 191 | x = self.thermal_module(x2) 192 | 193 | # shared block 194 | if self.non_local == 'on': 195 | NL1_counter = 0 196 | if len(self.NL_1_idx) == 0: self.NL_1_idx = [-1] 197 | for i in range(len(self.base_resnet.base.layer1)): 198 | x = self.base_resnet.base.layer1[i](x) 199 | if i == self.NL_1_idx[NL1_counter]: 200 | _, C, H, W = x.shape 201 | x = self.NL_1[NL1_counter](x) 202 | NL1_counter += 1 203 | # Layer 2 204 | NL2_counter = 0 205 | if len(self.NL_2_idx) == 0: self.NL_2_idx = [-1] 206 | for i in range(len(self.base_resnet.base.layer2)): 207 | x = self.base_resnet.base.layer2[i](x) 208 | if i == self.NL_2_idx[NL2_counter]: 209 | _, C, H, W = x.shape 210 | x = self.NL_2[NL2_counter](x) 211 | NL2_counter += 1 212 | # Layer 3 213 | NL3_counter = 0 214 | if len(self.NL_3_idx) == 0: self.NL_3_idx = [-1] 215 | for i in range(len(self.base_resnet.base.layer3)): 216 | x = self.base_resnet.base.layer3[i](x) 217 | if i == self.NL_3_idx[NL3_counter]: 218 | _, C, H, W = x.shape 219 | x = self.NL_3[NL3_counter](x) 220 | NL3_counter += 1 221 | # Layer 4 222 | NL4_counter = 0 223 | if len(self.NL_4_idx) == 0: self.NL_4_idx = [-1] 224 | for i in range(len(self.base_resnet.base.layer4)): 225 | x = self.base_resnet.base.layer4[i](x) 226 | if i == self.NL_4_idx[NL4_counter]: 227 | _, C, H, W = x.shape 228 | x = self.NL_4[NL4_counter](x) 229 | NL4_counter += 1 230 | else: 
231 |             x = self.base_resnet(x)
232 |         if self.gm_pool == 'on':
233 |             b, c, h, w = x.shape
234 |             x = x.view(b, c, -1)
235 |             p = 3.0  # GeM pooling exponent: larger p emphasizes high responses; p=1 reduces to average pooling
236 |             x_pool = (torch.mean(x**p, dim=-1) + 1e-12)**(1/p)
237 |         else:
238 |             x_pool = self.avgpool(x)
239 |             x_pool = x_pool.view(x_pool.size(0), x_pool.size(1))
240 |
241 |         feat = self.bottleneck(x_pool)
242 |
243 |         if self.training:
244 |             return x_pool, self.classifier(feat)
245 |         else:
246 |             return self.l2norm(x_pool), self.l2norm(feat)
247 |
--------------------------------------------------------------------------------
/AGW/pre_process_sysu.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 | import pdb
4 | import os
5 |
6 | data_path = 'SYSU-MM01/'  # trailing slash matters: the .npy files below are written inside the dataset folder
7 |
8 | rgb_cameras = ['cam1','cam2','cam4','cam5']
9 | ir_cameras = ['cam3','cam6']
10 |
11 | # load id info
12 | file_path_train = os.path.join(data_path,'exp/train_id.txt')
13 | file_path_val = os.path.join(data_path,'exp/val_id.txt')
14 | with open(file_path_train, 'r') as file:
15 |     ids = file.read().splitlines()
16 |     ids = [int(y) for y in ids[0].split(',')]
17 |     id_train = ["%04d" % x for x in ids]
18 |
19 | with open(file_path_val, 'r') as file:
20 |     ids = file.read().splitlines()
21 |     ids = [int(y) for y in ids[0].split(',')]
22 |     id_val = ["%04d" % x for x in ids]
23 |
24 | # combine train and val splits
25 | id_train.extend(id_val)
26 |
27 | files_rgb = []
28 | files_ir = []
29 | for id in sorted(id_train):
30 |     for cam in rgb_cameras:
31 |         img_dir = os.path.join(data_path,cam,id)
32 |         if os.path.isdir(img_dir):
33 |             new_files = sorted([img_dir+'/'+i for i in os.listdir(img_dir)])
34 |             files_rgb.extend(new_files)
35 |
36 |     for cam in ir_cameras:
37 |         img_dir = os.path.join(data_path,cam,id)
38 |         if os.path.isdir(img_dir):
39 |             new_files = sorted([img_dir+'/'+i for i in os.listdir(img_dir)])
40 |             files_ir.extend(new_files)
41 |
42 | # relabel
43 | pid_container = set()
44 | for img_path in files_ir:
45 |     pid = int(img_path[-13:-9])
46 |     pid_container.add(pid)
47 | pid2label = {pid:label for label, pid in enumerate(pid_container)}
48 | fix_image_width = 144
49 | fix_image_height = 288
50 | def read_imgs(train_image):
51 |     train_img = []
52 |     train_label = []
53 |     train_cam = []
54 |     for img_path in train_image:
55 |         # img
56 |         img = Image.open(img_path)
57 |         img = img.resize((fix_image_width, fix_image_height), Image.ANTIALIAS)
58 |         pix_array = np.array(img)
59 |
60 |         train_img.append(pix_array)
61 |
62 |         # label
63 |         pid = int(img_path[-13:-9])
64 |         cid = int(img_path[-15])
65 |         pid = pid2label[pid]
66 |         train_label.append(pid)
67 |         train_cam.append(cid)
68 |     return np.array(train_img), np.array(train_label), np.array(train_cam)
69 |
70 | # rgb images
71 | train_img, train_label, train_cam = read_imgs(files_rgb)
72 | np.save(data_path + 'train_rgb_resized_img.npy', train_img)
73 | np.save(data_path + 'train_rgb_resized_label.npy', train_label)
74 | np.save(data_path + 'train_rgb_resized_cam.npy', train_cam)
75 |
76 | # ir images
77 | train_img, train_label, train_cam = read_imgs(files_ir)
78 | np.save(data_path + 'train_ir_resized_img.npy', train_img)
79 | np.save(data_path + 'train_ir_resized_label.npy', train_label)
80 | np.save(data_path + 'train_ir_resized_cam.npy', train_cam)
81 |
--------------------------------------------------------------------------------
/AGW/resnet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import math
3 | import
torch.utils.model_zoo as model_zoo 4 | 5 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 6 | 'resnet152'] 7 | 8 | model_urls = { 9 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 10 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 11 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 12 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 13 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 14 | } 15 | 16 | 17 | def conv3x3(in_planes, out_planes, stride=1, dilation=1): 18 | """3x3 convolution with padding""" 19 | # original padding is 1; original dilation is 1 20 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 21 | padding=dilation, bias=False, dilation=dilation) 22 | 23 | 24 | class BasicBlock(nn.Module): 25 | expansion = 1 26 | 27 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1): 28 | super(BasicBlock, self).__init__() 29 | self.conv1 = conv3x3(inplanes, planes, stride, dilation) 30 | self.bn1 = nn.BatchNorm2d(planes) 31 | self.relu = nn.ReLU(inplace=True) 32 | self.conv2 = conv3x3(planes, planes) 33 | self.bn2 = nn.BatchNorm2d(planes) 34 | self.downsample = downsample 35 | self.stride = stride 36 | 37 | def forward(self, x): 38 | residual = x 39 | 40 | out = self.conv1(x) 41 | out = self.bn1(out) 42 | out = self.relu(out) 43 | 44 | out = self.conv2(out) 45 | out = self.bn2(out) 46 | 47 | if self.downsample is not None: 48 | residual = self.downsample(x) 49 | 50 | out += residual 51 | out = self.relu(out) 52 | 53 | return out 54 | 55 | 56 | class Bottleneck(nn.Module): 57 | expansion = 4 58 | 59 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1): 60 | super(Bottleneck, self).__init__() 61 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 62 | self.bn1 = nn.BatchNorm2d(planes) 63 | # original padding is 1; original dilation is 1 64 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=dilation, bias=False, dilation=dilation) 65 | self.bn2 = nn.BatchNorm2d(planes) 66 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 67 | self.bn3 = nn.BatchNorm2d(planes * 4) 68 | self.relu = nn.ReLU(inplace=True) 69 | self.downsample = downsample 70 | self.stride = stride 71 | 72 | def forward(self, x): 73 | residual = x 74 | 75 | out = self.conv1(x) 76 | out = self.bn1(out) 77 | out = self.relu(out) 78 | 79 | out = self.conv2(out) 80 | out = self.bn2(out) 81 | out = self.relu(out) 82 | 83 | out = self.conv3(out) 84 | out = self.bn3(out) 85 | 86 | if self.downsample is not None: 87 | residual = self.downsample(x) 88 | 89 | out += residual 90 | out = self.relu(out) 91 | 92 | return out 93 | 94 | 95 | class ResNet(nn.Module): 96 | 97 | def __init__(self, block, layers, last_conv_stride=2, last_conv_dilation=1): 98 | 99 | self.inplanes = 64 100 | super(ResNet, self).__init__() 101 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 102 | bias=False) 103 | self.bn1 = nn.BatchNorm2d(64) 104 | self.relu = nn.ReLU(inplace=True) 105 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 106 | self.layer1 = self._make_layer(block, 64, layers[0]) 107 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 108 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 109 | self.layer4 = self._make_layer(block, 512, layers[3], stride=last_conv_stride, dilation=last_conv_dilation) 
110 | 111 | for m in self.modules(): 112 | if isinstance(m, nn.Conv2d): 113 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 114 | #n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 115 | #m.weight.data.normal_(0, math.sqrt(2. / n)) 116 | elif isinstance(m, nn.BatchNorm2d): 117 | m.weight.data.fill_(1) 118 | m.bias.data.zero_() 119 | 120 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1): 121 | downsample = None 122 | if stride != 1 or self.inplanes != planes * block.expansion: 123 | downsample = nn.Sequential( 124 | nn.Conv2d(self.inplanes, planes * block.expansion, 125 | kernel_size=1, stride=stride, bias=False), 126 | nn.BatchNorm2d(planes * block.expansion), 127 | ) 128 | 129 | layers = [] 130 | layers.append(block(self.inplanes, planes, stride, downsample, dilation)) 131 | self.inplanes = planes * block.expansion 132 | for i in range(1, blocks): 133 | layers.append(block(self.inplanes, planes)) 134 | 135 | return nn.Sequential(*layers) 136 | 137 | def forward(self, x): 138 | x = self.conv1(x) 139 | x = self.bn1(x) 140 | x = self.relu(x) 141 | x = self.maxpool(x) 142 | 143 | x = self.layer1(x) 144 | x = self.layer2(x) 145 | x = self.layer3(x) 146 | x = self.layer4(x) 147 | 148 | return x 149 | 150 | 151 | def remove_fc(state_dict): 152 | """Remove the fc layer parameters from state_dict.""" 153 | # for key, value in state_dict.items(): 154 | for key, value in list(state_dict.items()): 155 | if key.startswith('fc.'): 156 | del state_dict[key] 157 | return state_dict 158 | 159 | 160 | def resnet18(pretrained=False, **kwargs): 161 | """Constructs a ResNet-18 model. 162 | Args: 163 | pretrained (bool): If True, returns a model pre-trained on ImageNet 164 | """ 165 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 166 | if pretrained: 167 | model.load_state_dict(remove_fc(model_zoo.load_url(model_urls['resnet18']))) 168 | return model 169 | 170 | 171 | def resnet34(pretrained=False, **kwargs): 172 | """Constructs a ResNet-34 model. 173 | Args: 174 | pretrained (bool): If True, returns a model pre-trained on ImageNet 175 | """ 176 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 177 | if pretrained: 178 | model.load_state_dict(remove_fc(model_zoo.load_url(model_urls['resnet34']))) 179 | return model 180 | 181 | 182 | def resnet50(pretrained=False, **kwargs): 183 | """Constructs a ResNet-50 model. 184 | Args: 185 | pretrained (bool): If True, returns a model pre-trained on ImageNet 186 | """ 187 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 188 | if pretrained: 189 | # model.load_state_dict(remove_fc(model_zoo.load_url(model_urls['resnet50']))) 190 | model.load_state_dict(remove_fc(model_zoo.load_url(model_urls['resnet50']))) 191 | return model 192 | 193 | 194 | def resnet101(pretrained=False, **kwargs): 195 | """Constructs a ResNet-101 model. 196 | Args: 197 | pretrained (bool): If True, returns a model pre-trained on ImageNet 198 | """ 199 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 200 | if pretrained: 201 | model.load_state_dict( 202 | remove_fc(model_zoo.load_url(model_urls['resnet101']))) 203 | return model 204 | 205 | 206 | def resnet152(pretrained=False, **kwargs): 207 | """Constructs a ResNet-152 model. 
208 |     Args:
209 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
210 |     """
211 |     model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
212 |     if pretrained:
213 |         model.load_state_dict(
214 |             remove_fc(model_zoo.load_url(model_urls['resnet152'])))
215 |     return model
216 |
--------------------------------------------------------------------------------
/AGW/test.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import argparse
3 | import time
4 | import torch.nn as nn
5 | import torch.backends.cudnn as cudnn
6 | from torch.autograd import Variable
7 | import torch.utils.data as data
8 | import torchvision.transforms as transforms
9 | from data_loader import SYSUData, RegDBData, TestData
10 | from data_manager import *
11 | from eval_metrics import eval_sysu, eval_regdb
12 | from model import embed_net
13 | from utils import *
14 | import pdb
15 |
16 | parser = argparse.ArgumentParser(description='PyTorch Cross-Modality Testing')
17 | parser.add_argument('--dataset', default='sysu', help='dataset name: regdb or sysu')
18 | parser.add_argument('--lr', default=0.1, type=float, help='learning rate, 0.00035 for adam')
19 | parser.add_argument('--optim', default='sgd', type=str, help='optimizer')
20 | parser.add_argument('--arch', default='resnet50', type=str,
21 |                     help='network baseline: resnet50')
22 | parser.add_argument('--resume', '-r', default='', type=str,
23 |                     help='resume from checkpoint')
24 | parser.add_argument('--test-only', action='store_true', help='test only')
25 | parser.add_argument('--model_path', default='save_model/', type=str,
26 |                     help='model save path')
27 | parser.add_argument('--save_epoch', default=20, type=int,
28 |                     metavar='s', help='save model every save_epoch epochs')
29 | parser.add_argument('--log_path', default='log/', type=str,
30 |                     help='log save path')
31 | parser.add_argument('--vis_log_path', default='log/vis_log/', type=str,
32 |                     help='log save path')
33 | parser.add_argument('--workers', default=4, type=int, metavar='N',
34 |                     help='number of data loading workers (default: 4)')
35 | parser.add_argument('--img_w', default=144, type=int,
36 |                     metavar='imgw', help='img width')
37 | parser.add_argument('--img_h', default=288, type=int,
38 |                     metavar='imgh', help='img height')
39 | parser.add_argument('--batch-size', default=8, type=int,
40 |                     metavar='B', help='training batch size')
41 | parser.add_argument('--test-batch', default=64, type=int,
42 |                     metavar='tb', help='testing batch size')
43 | parser.add_argument('--method', default='agw', type=str,
44 |                     metavar='m', help='method type: base or agw')
45 | parser.add_argument('--margin', default=0.3, type=float,
46 |                     metavar='margin', help='triplet loss margin')
47 | parser.add_argument('--num_pos', default=4, type=int,
48 |                     help='num of pos per identity in each modality')
49 | parser.add_argument('--trial', default=1, type=int,
50 |                     metavar='t', help='trial (only for RegDB dataset)')
51 | parser.add_argument('--seed', default=0, type=int,
52 |                     metavar='t', help='random seed')
53 | parser.add_argument('--gpu', default='0', type=str,
54 |                     help='gpu device ids for CUDA_VISIBLE_DEVICES')
55 | parser.add_argument('--shot', default='single', type=str, help='single or multiple shot')
56 | parser.add_argument('--mode', default='all', type=str, help='all or indoor for sysu')
57 | parser.add_argument('--tvsearch', action='store_true', help='whether thermal to visible search on RegDB')
58 | #parser.add_argument('--mask_layer',
nargs='+', default=[], help='mask layer') 59 | #parser.add_argument('--conv_mode', nargs='+', default=[], help='conv_mode') 60 | args = parser.parse_args() 61 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 62 | 63 | dataset = args.dataset 64 | shot = args.shot 65 | if dataset == 'sysu': 66 | data_path = 'SYSU-MM01/' 67 | n_class = 395 68 | test_mode = [1, 2] 69 | elif dataset =='regdb': 70 | data_path = 'RegDB/' 71 | n_class = 206 72 | test_mode = [2, 1] 73 | 74 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 75 | best_acc = 0 # best test accuracy 76 | start_epoch = 0 77 | pool_dim = 2048 78 | print('==> Building model..') 79 | if args.method =='base': 80 | net = embed_net(n_class, no_local= 'off', gm_pool = 'off', arch=args.arch) 81 | else: 82 | net = embed_net(n_class, no_local= 'on', gm_pool = 'on', arch=args.arch) 83 | net.to(device) 84 | cudnn.benchmark = True 85 | 86 | checkpoint_path = args.model_path 87 | 88 | if args.method =='id': 89 | criterion = nn.CrossEntropyLoss() 90 | criterion.to(device) 91 | 92 | print('==> Loading data..') 93 | # Data loading code 94 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]) 95 | transform_train = transforms.Compose([ 96 | transforms.ToPILImage(), 97 | transforms.RandomCrop((args.img_h,args.img_w)), 98 | transforms.RandomHorizontalFlip(), 99 | transforms.ToTensor(), 100 | normalize, 101 | ]) 102 | 103 | transform_test = transforms.Compose([ 104 | transforms.ToPILImage(), 105 | transforms.Resize((args.img_h,args.img_w)), 106 | transforms.ToTensor(), 107 | normalize, 108 | ]) 109 | 110 | end = time.time() 111 | 112 | 113 | 114 | def extract_gall_feat(gall_loader): 115 | net.eval() 116 | print ('Extracting Gallery Feature...') 117 | start = time.time() 118 | ptr = 0 119 | gall_feat_pool = np.zeros((ngall, pool_dim)) 120 | gall_feat_fc = np.zeros((ngall, pool_dim)) 121 | with torch.no_grad(): 122 | for batch_idx, (input, label ) in enumerate(gall_loader): 123 | batch_num = input.size(0) 124 | input = Variable(input.cuda()) 125 | feat_pool, feat_fc = net(input, input, test_mode[0]) 126 | gall_feat_pool[ptr:ptr+batch_num,: ] = feat_pool.detach().cpu().numpy() 127 | gall_feat_fc[ptr:ptr+batch_num,: ] = feat_fc.detach().cpu().numpy() 128 | ptr = ptr + batch_num 129 | print('Extracting Time:\t {:.3f}'.format(time.time()-start)) 130 | return gall_feat_pool, gall_feat_fc 131 | 132 | def extract_query_feat(query_loader): 133 | net.eval() 134 | print ('Extracting Query Feature...') 135 | start = time.time() 136 | ptr = 0 137 | query_feat_pool = np.zeros((nquery, pool_dim)) 138 | query_feat_fc = np.zeros((nquery, pool_dim)) 139 | with torch.no_grad(): 140 | for batch_idx, (input, label ) in enumerate(query_loader): 141 | batch_num = input.size(0) 142 | input = Variable(input.cuda()) 143 | feat_pool, feat_fc = net(input, input, test_mode[1]) 144 | query_feat_pool[ptr:ptr+batch_num,: ] = feat_pool.detach().cpu().numpy() 145 | query_feat_fc[ptr:ptr+batch_num,: ] = feat_fc.detach().cpu().numpy() 146 | ptr = ptr + batch_num 147 | print('Extracting Time:\t {:.3f}'.format(time.time()-start)) 148 | return query_feat_pool, query_feat_fc 149 | 150 | 151 | if dataset == 'sysu': 152 | 153 | print('==> Resuming from checkpoint..') 154 | 155 | model_path = checkpoint_path + args.resume 156 | 157 | if os.path.isfile(model_path): 158 | print('==> loading checkpoint {}'.format(args.resume)) 159 | checkpoint = torch.load(model_path) 160 | net.load_state_dict(checkpoint['net'], strict=False) 161 | print('==> loaded checkpoint {} 
(epoch {})' 162 | .format(args.resume, checkpoint['epoch'])) 163 | else: 164 | print('==> no checkpoint found at {}'.format(args.resume)) 165 | 166 | # testing set 167 | if shot == 'single': 168 | query_img, query_label, query_cam = process_query_sysu(data_path, mode=args.mode) 169 | gall_img, gall_label, gall_cam = process_gallery_sysu(data_path, mode=args.mode, trial=0) 170 | nquery = len(query_label) 171 | ngall = len(gall_label) 172 | print("Dataset statistics:") 173 | print(" ------------------------------") 174 | print(" subset | # ids | # images") 175 | print(" ------------------------------") 176 | print(" query | {:5d} | {:8d}".format(len(np.unique(query_label)), nquery)) 177 | print(" gallery | {:5d} | {:8d}".format(len(np.unique(gall_label)), ngall)) 178 | print(" ------------------------------") 179 | 180 | queryset = TestData(query_img, query_label, transform=transform_test, img_size=(args.img_w, args.img_h)) 181 | query_loader = data.DataLoader(queryset, batch_size=args.test_batch, shuffle=False, num_workers=4) 182 | print('Data Loading Time:\t {:.3f}'.format(time.time() - end)) 183 | 184 | query_feat_pool, query_feat_fc = extract_query_feat(query_loader) 185 | for trial in range(10): 186 | gall_img, gall_label, gall_cam = process_gallery_sysu(data_path, mode=args.mode, trial=trial) 187 | 188 | trial_gallset = TestData(gall_img, gall_label, transform=transform_test, img_size=(args.img_w, args.img_h)) 189 | trial_gall_loader = data.DataLoader(trial_gallset, batch_size=args.test_batch, shuffle=False, num_workers=4) 190 | 191 | gall_feat_pool, gall_feat_fc = extract_gall_feat(trial_gall_loader) 192 | 193 | # pool5 feature 194 | distmat_pool = np.matmul(query_feat_pool, np.transpose(gall_feat_pool)) 195 | cmc_pool, mAP_pool, mINP_pool = eval_sysu(-distmat_pool, query_label, gall_label, query_cam, gall_cam) 196 | 197 | # fc feature 198 | distmat = np.matmul(query_feat_fc, np.transpose(gall_feat_fc)) 199 | cmc, mAP, mINP = eval_sysu(-distmat, query_label, gall_label, query_cam, gall_cam) 200 | if trial == 0: 201 | all_cmc = cmc 202 | all_mAP = mAP 203 | all_mINP = mINP 204 | all_cmc_pool = cmc_pool 205 | all_mAP_pool = mAP_pool 206 | all_mINP_pool = mINP_pool 207 | else: 208 | all_cmc = all_cmc + cmc 209 | all_mAP = all_mAP + mAP 210 | all_mINP = all_mINP + mINP 211 | all_cmc_pool = all_cmc_pool + cmc_pool 212 | all_mAP_pool = all_mAP_pool + mAP_pool 213 | all_mINP_pool = all_mINP_pool + mINP_pool 214 | 215 | print('Test Trial: {}'.format(trial)) 216 | print( 217 | 'POOL: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 218 | cmc[0], cmc[4], cmc[9], cmc[19], mAP, mINP)) 219 | print( 220 | 'FC: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 221 | cmc_pool[0], cmc_pool[4], cmc_pool[9], cmc_pool[19], mAP_pool, mINP_pool)) 222 | cmc = all_cmc / 10 223 | mAP = all_mAP / 10 224 | mINP = all_mINP / 10 225 | 226 | cmc_pool = all_cmc_pool / 10 227 | mAP_pool = all_mAP_pool / 10 228 | mINP_pool = all_mINP_pool / 10 229 | print('All Average:') 230 | print( 231 | 'FC: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 232 | cmc[0], cmc[4], cmc[9], cmc[19], mAP, mINP)) 233 | print( 234 | 'POOL: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 235 | cmc_pool[0], cmc_pool[4], cmc_pool[9], cmc_pool[19], mAP_pool, mINP_pool)) 236 | else: 237 | query_img, query_label, query_cam = 
process_query_sysu(data_path, mode=args.mode) 238 | gall_img, gall_label, gall_cam = process_gallery_sysu_multishot(data_path, mode=args.mode, trial=0) 239 | nquery = len(query_label) 240 | ngall = len(gall_label) 241 | print("Dataset statistics:") 242 | print(" ------------------------------") 243 | print(" subset | # ids | # images") 244 | print(" ------------------------------") 245 | print(" query | {:5d} | {:8d}".format(len(np.unique(query_label)), nquery)) 246 | print(" gallery | {:5d} | {:8d}".format(len(np.unique(gall_label)), ngall)) 247 | print(" ------------------------------") 248 | 249 | queryset = TestData(query_img, query_label, transform=transform_test, img_size=(args.img_w, args.img_h)) 250 | query_loader = data.DataLoader(queryset, batch_size=args.test_batch, shuffle=False, num_workers=4) 251 | print('Data Loading Time:\t {:.3f}'.format(time.time() - end)) 252 | 253 | query_feat_pool, query_feat_fc = extract_query_feat(query_loader) 254 | for trial in range(10): 255 | gall_img, gall_label, gall_cam = process_gallery_sysu_multishot(data_path, mode=args.mode, trial=trial) 256 | 257 | trial_gallset = TestData(gall_img, gall_label, transform=transform_test, img_size=(args.img_w, args.img_h)) 258 | trial_gall_loader = data.DataLoader(trial_gallset, batch_size=args.test_batch, shuffle=False, num_workers=4) 259 | 260 | gall_feat_pool, gall_feat_fc = extract_gall_feat(trial_gall_loader) 261 | 262 | # pool5 feature 263 | distmat_pool = np.matmul(query_feat_pool, np.transpose(gall_feat_pool)) 264 | cmc_pool, mAP_pool, mINP_pool = eval_sysu(-distmat_pool, query_label, gall_label, query_cam, gall_cam) 265 | 266 | # fc feature 267 | distmat = np.matmul(query_feat_fc, np.transpose(gall_feat_fc)) 268 | cmc, mAP, mINP = eval_sysu(-distmat, query_label, gall_label, query_cam, gall_cam) 269 | if trial == 0: 270 | all_cmc = cmc 271 | all_mAP = mAP 272 | all_mINP = mINP 273 | all_cmc_pool = cmc_pool 274 | all_mAP_pool = mAP_pool 275 | all_mINP_pool = mINP_pool 276 | else: 277 | all_cmc = all_cmc + cmc 278 | all_mAP = all_mAP + mAP 279 | all_mINP = all_mINP + mINP 280 | all_cmc_pool = all_cmc_pool + cmc_pool 281 | all_mAP_pool = all_mAP_pool + mAP_pool 282 | all_mINP_pool = all_mINP_pool + mINP_pool 283 | 284 | print('Test Trial: {}'.format(trial)) 285 | print( 286 | 'FC: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 287 | cmc[0], cmc[4], cmc[9], cmc[19], mAP, mINP)) 288 | print( 289 | 'POOL: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 290 | cmc_pool[0], cmc_pool[4], cmc_pool[9], cmc_pool[19], mAP_pool, mINP_pool)) 291 | 292 | cmc = all_cmc / 10 293 | mAP = all_mAP / 10 294 | mINP = all_mINP / 10 295 | 296 | cmc_pool = all_cmc_pool / 10 297 | mAP_pool = all_mAP_pool / 10 298 | mINP_pool = all_mINP_pool / 10 299 | print('All Average:') 300 | print( 301 | 'FC: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 302 | cmc[0], cmc[4], cmc[9], cmc[19], mAP, mINP)) 303 | print( 304 | 'POOL: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 305 | cmc_pool[0], cmc_pool[4], cmc_pool[9], cmc_pool[19], mAP_pool, mINP_pool)) 306 | 307 | 308 | elif dataset == 'regdb': 309 | flag = 1 310 | for trial in range(10): 311 | test_trial = trial +1 312 | #model_path = checkpoint_path + args.resume 313 | model_path = checkpoint_path + 'regdb_'+ args.method 
+'_p4_n8_lr_0.1_seed_0_trial_{}_best.t'.format(test_trial) 314 | if os.path.isfile(model_path): 315 | print('==> loading checkpoint {}'.format(model_path)) 316 | checkpoint = torch.load(model_path) 317 | net.load_state_dict(checkpoint['net']) 318 | # training set 319 | trainset = RegDBData(data_path, test_trial, transform=transform_train) 320 | # generate the idx of each person identity 321 | color_pos, thermal_pos = GenIdx(trainset.train_color_label, trainset.train_thermal_label) 322 | 323 | # testing set 324 | query_img, query_label = process_test_regdb(data_path, trial=test_trial, modal='visible') 325 | gall_img, gall_label = process_test_regdb(data_path, trial=test_trial, modal='thermal') 326 | 327 | gallset = TestData(gall_img, gall_label, transform=transform_test, img_size=(args.img_w, args.img_h)) 328 | gall_loader = data.DataLoader(gallset, batch_size=args.test_batch, shuffle=False, num_workers=args.workers) 329 | 330 | nquery = len(query_label) 331 | ngall = len(gall_label) 332 | 333 | queryset = TestData(query_img, query_label, transform=transform_test, img_size=(args.img_w, args.img_h)) 334 | query_loader = data.DataLoader(queryset, batch_size=args.test_batch, shuffle=False, num_workers=4) 335 | print('Data Loading Time:\t {:.3f}'.format(time.time() - end)) 336 | 337 | query_feat_pool, query_feat_fc = extract_query_feat(query_loader) 338 | gall_feat_pool, gall_feat_fc = extract_gall_feat(gall_loader) 339 | 340 | if args.tvsearch: 341 | # pool5 feature 342 | distmat_pool = np.matmul(gall_feat_pool, np.transpose(query_feat_pool)) 343 | cmc_pool, mAP_pool, mINP_pool = eval_regdb(-distmat_pool, gall_label, query_label) 344 | 345 | # fc feature 346 | distmat = np.matmul(gall_feat_fc, np.transpose(query_feat_fc)) 347 | cmc, mAP, mINP = eval_regdb(-distmat, gall_label, query_label) 348 | else: 349 | # pool5 feature 350 | distmat_pool = np.matmul(query_feat_pool, np.transpose(gall_feat_pool)) 351 | cmc_pool, mAP_pool, mINP_pool = eval_regdb(-distmat_pool, query_label, gall_label) 352 | 353 | # fc feature 354 | distmat = np.matmul(query_feat_fc, np.transpose(gall_feat_fc)) 355 | cmc, mAP, mINP = eval_regdb(-distmat, query_label, gall_label) 356 | 357 | if trial == 0: 358 | all_cmc = cmc 359 | all_mAP = mAP 360 | all_mINP = mINP 361 | all_cmc_pool = cmc_pool 362 | all_mAP_pool = mAP_pool 363 | all_mINP_pool = mINP_pool 364 | else: 365 | all_cmc = all_cmc + cmc 366 | all_mAP = all_mAP + mAP 367 | all_mINP = all_mINP + mINP 368 | all_cmc_pool = all_cmc_pool + cmc_pool 369 | all_mAP_pool = all_mAP_pool + mAP_pool 370 | all_mINP_pool = all_mINP_pool + mINP_pool 371 | 372 | print('Test Trial: {}'.format(test_trial)) 373 | print( 374 | 'FC: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 375 | cmc[0], cmc[4], cmc[9], cmc[19], mAP, mINP)) 376 | print( 377 | 'POOL: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 378 | cmc_pool[0], cmc_pool[4], cmc_pool[9], cmc_pool[19], mAP_pool, mINP_pool)) 379 | else: 380 | flag = 0 381 | if flag == 1: 382 | cmc = all_cmc / 10 383 | mAP = all_mAP / 10 384 | mINP = all_mINP / 10 385 | 386 | cmc_pool = all_cmc_pool / 10 387 | mAP_pool = all_mAP_pool / 10 388 | mINP_pool = all_mINP_pool / 10 389 | print('All Average:') 390 | print( 391 | 'FC: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 392 | cmc[0], cmc[4], cmc[9], cmc[19], mAP, mINP)) 393 | print( 394 | 'POOL: Rank-1: {:.2%} | Rank-5: {:.2%} | 
Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 395 | cmc_pool[0], cmc_pool[4], cmc_pool[9], cmc_pool[19], mAP_pool, mINP_pool)) 396 | 397 | else: 398 | print('Not enough checkpoints! The results are unreliable!') 399 | 400 | -------------------------------------------------------------------------------- /AGW/train.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | import sys 4 | import time 5 | import torch 6 | import torch.nn as nn 7 | import torch.optim as optim 8 | import torch.backends.cudnn as cudnn 9 | from torch.autograd import Variable 10 | import torch.utils.data as data 11 | import torchvision 12 | import torchvision.transforms as transforms 13 | from data_loader import SYSUData, RegDBData, TestData 14 | from data_manager import * 15 | from eval_metrics import eval_sysu, eval_regdb 16 | from model import embed_net 17 | from utils import * 18 | from loss import OriTripletLoss, TripletLoss_WRT 19 | from tensorboardX import SummaryWriter 20 | from collections.abc import Iterable 21 | 22 | parser = argparse.ArgumentParser(description='PyTorch Cross-Modality Training') 23 | parser.add_argument('--dataset', default='sysu', help='dataset name: regdb or sysu') 24 | parser.add_argument('--lr', default=0.1, type=float, help='learning rate, 0.00035 for adam') 25 | parser.add_argument('--optim', default='sgd', type=str, help='optimizer') 26 | parser.add_argument('--arch', default='resnet50', type=str, 27 | help='network baseline: resnet18 or resnet50') 28 | parser.add_argument('--resume', '-r', default='', type=str, 29 | help='resume from checkpoint') 30 | parser.add_argument('--test-only', action='store_true', help='test only') 31 | parser.add_argument('--model_path', default='save_model/', type=str, 32 | help='model save path') 33 | parser.add_argument('--save_epoch', default=20, type=int, 34 | metavar='s', help='save the model every save_epoch epochs (default: 20)') 35 | parser.add_argument('--log_path', default='log/', type=str, 36 | help='log save path') 37 | parser.add_argument('--vis_log_path', default='log/vis_log/', type=str, 38 | help='tensorboard log save path') 39 | parser.add_argument('--workers', default=4, type=int, metavar='N', 40 | help='number of data loading workers (default: 4)') 41 | parser.add_argument('--img_w', default=144, type=int, 42 | metavar='imgw', help='img width') 43 | parser.add_argument('--img_h', default=288, type=int, 44 | metavar='imgh', help='img height') 45 | parser.add_argument('--batch-size', default=8, type=int, 46 | metavar='B', help='training batch size') 47 | parser.add_argument('--test-batch', default=64, type=int, 48 | metavar='tb', help='testing batch size') 49 | parser.add_argument('--method', default='agw', type=str, 50 | metavar='m', help='method type: base or agw') 51 | parser.add_argument('--margin', default=0.3, type=float, 52 | metavar='margin', help='triplet loss margin') 53 | parser.add_argument('--num_pos', default=4, type=int, 54 | help='num of pos per identity in each modality') 55 | parser.add_argument('--trial', default=1, type=int, 56 | metavar='t', help='trial (only for RegDB dataset)') 57 | parser.add_argument('--seed', default=0, type=int, 58 | metavar='t', help='random seed') 59 | parser.add_argument('--gpu', default='2', type=str, 60 | help='gpu device ids for CUDA_VISIBLE_DEVICES') 61 | parser.add_argument('--mode', default='all', type=str, help='all or indoor') 62 | 63 | args = parser.parse_args() 64 | os.environ['CUDA_VISIBLE_DEVICES'] =
args.gpu 65 | 66 | set_seed(args.seed) 67 | 68 | dataset = args.dataset 69 | if dataset == 'sysu': 70 | data_path = '../SYSU_MM01/' 71 | log_path = args.log_path + 'sysu_log/' 72 | test_mode = [1, 2] # thermal to visible 73 | elif dataset == 'regdb': 74 | data_path = 'RegDB/' 75 | log_path = args.log_path + 'regdb_log/' 76 | test_mode = [2, 1] # visible to thermal 77 | 78 | checkpoint_path = args.model_path 79 | 80 | if not os.path.isdir(log_path): 81 | os.makedirs(log_path) 82 | if not os.path.isdir(checkpoint_path): 83 | os.makedirs(checkpoint_path) 84 | if not os.path.isdir(args.vis_log_path): 85 | os.makedirs(args.vis_log_path) 86 | 87 | suffix = dataset 88 | if args.method=='agw': 89 | suffix = suffix + '_agw_p{}_n{}_lr_{}_seed_{}'.format(args.num_pos, args.batch_size, args.lr, args.seed) 90 | else: 91 | suffix = suffix + '_base_p{}_n{}_lr_{}_seed_{}'.format(args.num_pos, args.batch_size, args.lr, args.seed) 92 | 93 | 94 | if not args.optim == 'sgd': 95 | suffix = suffix + '_' + args.optim 96 | 97 | if dataset == 'regdb': 98 | suffix = suffix + '_trial_{}'.format(args.trial) 99 | 100 | sys.stdout = Logger(log_path + suffix + '_os.txt') 101 | 102 | vis_log_dir = args.vis_log_path + suffix + '/' 103 | 104 | if not os.path.isdir(vis_log_dir): 105 | os.makedirs(vis_log_dir) 106 | writer = SummaryWriter(vis_log_dir) 107 | print("==========\nArgs:{}\n==========".format(args)) 108 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 109 | best_acc = 0 # best test accuracy 110 | start_epoch = 0 111 | 112 | print('==> Loading data..') 113 | # Data loading code 114 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 115 | transform_train = transforms.Compose([ 116 | transforms.ToPILImage(), 117 | transforms.Pad(10), 118 | transforms.RandomCrop((args.img_h, args.img_w)), 119 | transforms.RandomHorizontalFlip(), 120 | transforms.ToTensor(), 121 | normalize, 122 | ]) 123 | transform_test = transforms.Compose([ 124 | transforms.ToPILImage(), 125 | transforms.Resize((args.img_h, args.img_w)), 126 | transforms.ToTensor(), 127 | normalize, 128 | ]) 129 | 130 | end = time.time() 131 | if dataset == 'sysu': 132 | # training set 133 | trainset = SYSUData(data_path, transform=transform_train) 134 | # generate the idx of each person identity 135 | color_pos, thermal_pos = GenIdx(trainset.train_color_label, trainset.train_thermal_label) 136 | 137 | # testing set 138 | query_img, query_label, query_cam = process_query_sysu(data_path, mode=args.mode) 139 | gall_img, gall_label, gall_cam = process_gallery_sysu(data_path, mode=args.mode, trial=0) 140 | 141 | elif dataset == 'regdb': 142 | # training set 143 | trainset = RegDBData(data_path, args.trial, transform=transform_train) 144 | # generate the idx of each person identity 145 | color_pos, thermal_pos = GenIdx(trainset.train_color_label, trainset.train_thermal_label) 146 | 147 | # testing set 148 | query_img, query_label = process_test_regdb(data_path, trial=args.trial, modal='visible') 149 | gall_img, gall_label = process_test_regdb(data_path, trial=args.trial, modal='thermal') 150 | 151 | gallset = TestData(gall_img, gall_label, transform=transform_test, img_size=(args.img_w, args.img_h)) 152 | queryset = TestData(query_img, query_label, transform=transform_test, img_size=(args.img_w, args.img_h)) 153 | 154 | # testing data loader 155 | gall_loader = data.DataLoader(gallset, batch_size=args.test_batch, shuffle=False, num_workers=args.workers) 156 | query_loader = data.DataLoader(queryset, 
batch_size=args.test_batch, shuffle=False, num_workers=args.workers) 157 | 158 | n_class = len(np.unique(trainset.train_color_label)) 159 | nquery = len(query_label) 160 | ngall = len(gall_label) 161 | 162 | print('Dataset {} statistics:'.format(dataset)) 163 | print(' ------------------------------') 164 | print(' subset | # ids | # images') 165 | print(' ------------------------------') 166 | print(' visible | {:5d} | {:8d}'.format(n_class, len(trainset.train_color_label))) 167 | print(' thermal | {:5d} | {:8d}'.format(n_class, len(trainset.train_thermal_label))) 168 | print(' ------------------------------') 169 | print(' query | {:5d} | {:8d}'.format(len(np.unique(query_label)), nquery)) 170 | print(' gallery | {:5d} | {:8d}'.format(len(np.unique(gall_label)), ngall)) 171 | print(' ------------------------------') 172 | print('Data Loading Time:\t {:.3f}'.format(time.time() - end)) 173 | 174 | print('==> Building model..') 175 | if args.method == 'base': 176 | net = embed_net(n_class, no_local='off', gm_pool='off', arch=args.arch) 177 | else: 178 | net = embed_net(n_class, no_local='on', gm_pool='on', arch=args.arch) 179 | net.to(device) 180 | cudnn.benchmark = True 181 | 182 | def remove_fc(state_dict): 183 | """Remove the fc layer parameters from state_dict.""" 184 | # for key, value in state_dict.items(): 185 | for key, value in list(state_dict.items()): 186 | if key.startswith('fc1.') or key.startswith('fc2.') or key.startswith('local_conv_list'): 187 | del state_dict[key] 188 | return state_dict 189 | 190 | if len(args.resume) > 0: 191 | model_path = checkpoint_path + args.resume 192 | if os.path.isfile(model_path): 193 | print('==> loading checkpoint {}'.format(args.resume)) 194 | net.load_state_dict(remove_fc(torch.load(model_path, map_location=torch.device('cpu'))), strict=False) 195 | print('==> loaded checkpoint {}'.format(args.resume)) 196 | 197 | else: 198 | print('==> no checkpoint found at {}'.format(args.resume)) 199 | 200 | # define loss function 201 | criterion_id = nn.CrossEntropyLoss() 202 | if args.method == 'agw': 203 | criterion_tri = TripletLoss_WRT() 204 | else: 205 | loader_batch = args.batch_size * args.num_pos 206 | criterion_tri = OriTripletLoss(batch_size=loader_batch, margin=args.margin) 207 | 208 | criterion_id.to(device) 209 | criterion_tri.to(device) 210 | 211 | 212 | if args.optim == 'sgd': 213 | ignored_params = list(map(id, net.bottleneck.parameters())) \ 214 | + list(map(id, net.classifier.parameters())) 215 | 216 | base_params = filter(lambda p: id(p) not in ignored_params, net.parameters()) 217 | 218 | optimizer = optim.SGD([ 219 | {'params': base_params, 'lr': 0.1 * args.lr}, 220 | {'params': net.bottleneck.parameters(), 'lr': args.lr}, 221 | {'params': net.classifier.parameters(), 'lr': args.lr}], 222 | weight_decay=5e-4, momentum=0.9, nesterov=True) 223 | 224 | # exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1) 225 | def adjust_learning_rate(optimizer, epoch): 226 | """Warm up the learning rate over the first 10 epochs, then decay it by 10x at epoch 20 and again at epoch 50""" 227 | if epoch < 10: 228 | lr = args.lr * (epoch + 1) / 10 229 | elif epoch >= 10 and epoch < 20: 230 | lr = args.lr 231 | elif epoch >= 20 and epoch < 50: 232 | lr = args.lr * 0.1 233 | elif epoch >= 50: 234 | lr = args.lr * 0.01 235 | 236 | optimizer.param_groups[0]['lr'] = 0.1 * lr 237 | for i in range(len(optimizer.param_groups) - 1): 238 | optimizer.param_groups[i + 1]['lr'] = lr 239 | 240 | return lr 241 | 242 | 243 | def train(epoch): 244 | 245 | current_lr =
adjust_learning_rate(optimizer, epoch) 246 | train_loss = AverageMeter() 247 | id_loss = AverageMeter() 248 | tri_loss = AverageMeter() 249 | data_time = AverageMeter() 250 | batch_time = AverageMeter() 251 | correct = 0 252 | total = 0 253 | 254 | # switch to train mode 255 | net.train() 256 | end = time.time() 257 | 258 | for batch_idx, (input1, input2, label1, label2) in enumerate(trainloader): 259 | 260 | labels = torch.cat((label1, label2), 0) 261 | 262 | input1 = Variable(input1.cuda()) 263 | input2 = Variable(input2.cuda()) 264 | 265 | labels = Variable(labels.cuda()) 266 | data_time.update(time.time() - end) 267 | 268 | 269 | feat, out0, = net(input1, input2) 270 | 271 | loss_id = criterion_id(out0, labels) 272 | loss_tri, batch_acc = criterion_tri(feat, labels) 273 | correct += (batch_acc / 2) 274 | _, predicted = out0.max(1) 275 | correct += (predicted.eq(labels).sum().item() / 2) 276 | 277 | loss = loss_id + loss_tri 278 | optimizer.zero_grad() 279 | loss.backward() 280 | optimizer.step() 281 | 282 | # update P 283 | train_loss.update(loss.item(), 2 * input1.size(0)) 284 | id_loss.update(loss_id.item(), 2 * input1.size(0)) 285 | tri_loss.update(loss_tri.item(), 2 * input1.size(0)) 286 | total += labels.size(0) 287 | 288 | # measure elapsed time 289 | batch_time.update(time.time() - end) 290 | end = time.time() 291 | if batch_idx % 50 == 0: 292 | print('Epoch: [{}][{}/{}] ' 293 | 'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) ' 294 | 'lr:{:.3f} ' 295 | 'Loss: {train_loss.val:.4f} ({train_loss.avg:.4f}) ' 296 | 'iLoss: {id_loss.val:.4f} ({id_loss.avg:.4f}) ' 297 | 'TLoss: {tri_loss.val:.4f} ({tri_loss.avg:.4f}) ' 298 | 'Accu: {:.2f}'.format( 299 | epoch, batch_idx, len(trainloader), current_lr, 300 | 100. * correct / total, batch_time=batch_time, 301 | train_loss=train_loss, id_loss=id_loss, tri_loss=tri_loss)) 302 | 303 | writer.add_scalar('total_loss', train_loss.avg, epoch) 304 | writer.add_scalar('id_loss', id_loss.avg, epoch) 305 | writer.add_scalar('tri_loss', tri_loss.avg, epoch) 306 | writer.add_scalar('lr', current_lr, epoch) 307 | 308 | 309 | def test(epoch): 310 | # switch to evaluation mode 311 | net.eval() 312 | print('Extracting Gallery Feature...') 313 | start = time.time() 314 | ptr = 0 315 | gall_feat = np.zeros((ngall, 2048)) 316 | gall_feat_att = np.zeros((ngall, 2048)) 317 | with torch.no_grad(): 318 | for batch_idx, (input, label) in enumerate(gall_loader): 319 | batch_num = input.size(0) 320 | input = Variable(input.cuda()) 321 | feat, feat_att = net(input, input, test_mode[0]) 322 | gall_feat[ptr:ptr + batch_num, :] = feat.detach().cpu().numpy() 323 | gall_feat_att[ptr:ptr + batch_num, :] = feat_att.detach().cpu().numpy() 324 | ptr = ptr + batch_num 325 | print('Extracting Time:\t {:.3f}'.format(time.time() - start)) 326 | 327 | # switch to evaluation 328 | net.eval() 329 | print('Extracting Query Feature...') 330 | start = time.time() 331 | ptr = 0 332 | query_feat = np.zeros((nquery, 2048)) 333 | query_feat_att = np.zeros((nquery, 2048)) 334 | with torch.no_grad(): 335 | for batch_idx, (input, label) in enumerate(query_loader): 336 | batch_num = input.size(0) 337 | input = Variable(input.cuda()) 338 | feat, feat_att = net(input, input, test_mode[1]) 339 | query_feat[ptr:ptr + batch_num, :] = feat.detach().cpu().numpy() 340 | query_feat_att[ptr:ptr + batch_num, :] = feat_att.detach().cpu().numpy() 341 | ptr = ptr + batch_num 342 | print('Extracting Time:\t {:.3f}'.format(time.time() - start)) 343 | 344 | start = time.time() 345 | # compute the 
similarity 346 | distmat = np.matmul(query_feat, np.transpose(gall_feat)) 347 | distmat_att = np.matmul(query_feat_att, np.transpose(gall_feat_att)) 348 | 349 | # evaluation 350 | if dataset == 'regdb': 351 | cmc, mAP, mINP = eval_regdb(-distmat, query_label, gall_label) 352 | cmc_att, mAP_att, mINP_att = eval_regdb(-distmat_att, query_label, gall_label) 353 | elif dataset == 'sysu': 354 | cmc, mAP, mINP = eval_sysu(-distmat, query_label, gall_label, query_cam, gall_cam) 355 | cmc_att, mAP_att, mINP_att = eval_sysu(-distmat_att, query_label, gall_label, query_cam, gall_cam) 356 | print('Evaluation Time:\t {:.3f}'.format(time.time() - start)) 357 | 358 | writer.add_scalar('rank1', cmc[0], epoch) 359 | writer.add_scalar('mAP', mAP, epoch) 360 | writer.add_scalar('mINP', mINP, epoch) 361 | writer.add_scalar('rank1_att', cmc_att[0], epoch) 362 | writer.add_scalar('mAP_att', mAP_att, epoch) 363 | writer.add_scalar('mINP_att', mINP_att, epoch) 364 | return cmc, mAP, mINP, cmc_att, mAP_att, mINP_att 365 | 366 | 367 | # training 368 | print('==> Start Training...') 369 | for epoch in range(start_epoch, 81 - start_epoch): 370 | 371 | print('==> Preparing Data Loader...') 372 | # identity sampler 373 | sampler = IdentitySampler(trainset.train_color_label, \ 374 | trainset.train_thermal_label, color_pos, thermal_pos, args.num_pos, args.batch_size, 375 | epoch) 376 | 377 | trainset.cIndex = sampler.index1 # color index 378 | trainset.tIndex = sampler.index2 # thermal index 379 | print(epoch) 380 | print(trainset.cIndex) 381 | print(trainset.tIndex) 382 | 383 | loader_batch = args.batch_size * args.num_pos 384 | 385 | trainloader = data.DataLoader(trainset, batch_size=loader_batch, \ 386 | sampler=sampler, num_workers=args.workers, drop_last=True) 387 | 388 | # training 389 | train(epoch) 390 | 391 | if epoch > 0 and epoch % 2 == 0: 392 | print('Test Epoch: {}'.format(epoch)) 393 | 394 | # testing 395 | cmc, mAP, mINP, cmc_att, mAP_att, mINP_att = test(epoch) 396 | # save model 397 | if cmc_att[0] > best_acc: # not the real best for sysu-mm01 398 | best_acc = cmc_att[0] 399 | best_epoch = epoch 400 | state = { 401 | 'net': net.state_dict(), 402 | 'cmc': cmc_att, 403 | 'mAP': mAP_att, 404 | 'mINP': mINP_att, 405 | 'epoch': epoch, 406 | } 407 | torch.save(state, checkpoint_path + suffix + '_best.t') 408 | 409 | # save model 410 | if epoch > 10 and epoch % args.save_epoch == 0: 411 | state = { 412 | 'net': net.state_dict(), 413 | 'cmc': cmc, 414 | 'mAP': mAP, 415 | 'epoch': epoch, 416 | } 417 | torch.save(state, checkpoint_path + suffix + '_epoch_{}.t'.format(epoch)) 418 | 419 | print('POOL: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 420 | cmc[0], cmc[4], cmc[9], cmc[19], mAP, mINP)) 421 | print('FC: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 422 | cmc_att[0], cmc_att[4], cmc_att[9], cmc_att[19], mAP_att, mINP_att)) 423 | print('Best Epoch [{}]'.format(best_epoch)) 424 | -------------------------------------------------------------------------------- /AGW/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from torch.utils.data.sampler import Sampler 4 | import sys 5 | import os.path as osp 6 | import torch 7 | import errno # needed by mkdir_if_missing below 8 | def load_data(input_data_path): 9 | with open(input_data_path) as f: 10 | data_file_list = f.read().splitlines() # read via the handle opened above instead of opening the file a second time 11 | # Get full list of color image and labels 12 |
file_image = [s.split(' ')[0] for s in data_file_list] 13 | file_label = [int(s.split(' ')[1]) for s in data_file_list] 14 | 15 | return file_image, file_label 16 | 17 | 18 | def GenIdx( train_color_label, train_thermal_label): 19 | color_pos = [] 20 | unique_label_color = np.unique(train_color_label) 21 | for i in range(len(unique_label_color)): 22 | tmp_pos = [k for k,v in enumerate(train_color_label) if v==unique_label_color[i]] 23 | color_pos.append(tmp_pos) 24 | 25 | thermal_pos = [] 26 | unique_label_thermal = np.unique(train_thermal_label) 27 | for i in range(len(unique_label_thermal)): 28 | tmp_pos = [k for k,v in enumerate(train_thermal_label) if v==unique_label_thermal[i]] 29 | thermal_pos.append(tmp_pos) 30 | return color_pos, thermal_pos 31 | 32 | def GenCamIdx(gall_img, gall_label, mode): 33 | if mode =='indoor': 34 | camIdx = [1,2] 35 | else: 36 | camIdx = [1,2,4,5] 37 | gall_cam = [] 38 | for i in range(len(gall_img)): 39 | gall_cam.append(int(gall_img[i][-10])) 40 | 41 | sample_pos = [] 42 | unique_label = np.unique(gall_label) 43 | for i in range(len(unique_label)): 44 | for j in range(len(camIdx)): 45 | id_pos = [k for k,v in enumerate(gall_label) if v==unique_label[i] and gall_cam[k]==camIdx[j]] 46 | if id_pos: 47 | sample_pos.append(id_pos) 48 | return sample_pos 49 | 50 | def ExtractCam(gall_img): 51 | gall_cam = [] 52 | for i in range(len(gall_img)): 53 | cam_id = int(gall_img[i][-10]) 54 | # if cam_id ==3: 55 | # cam_id = 2 56 | gall_cam.append(cam_id) 57 | 58 | return np.array(gall_cam) 59 | 60 | class IdentitySampler(Sampler): 61 | """Sample person identities evenly in each batch. 62 | Args: 63 | train_color_label, train_thermal_label: labels of two modalities 64 | color_pos, thermal_pos: positions of each identity 65 | batchSize: batch size 66 | """ 67 | 68 | def __init__(self, train_color_label, train_thermal_label, color_pos, thermal_pos, num_pos, batchSize, epoch): 69 | uni_label = np.unique(train_color_label) 70 | self.n_classes = len(uni_label) 71 | 72 | 73 | N = np.maximum(len(train_color_label), len(train_thermal_label)) 74 | for j in range(int(N/(batchSize*num_pos))+1): 75 | batch_idx = np.random.choice(uni_label, batchSize, replace = False) 76 | for i in range(batchSize): 77 | sample_color = np.random.choice(color_pos[batch_idx[i]], num_pos) 78 | sample_thermal = np.random.choice(thermal_pos[batch_idx[i]], num_pos) 79 | 80 | if j ==0 and i==0: 81 | index1= sample_color 82 | index2= sample_thermal 83 | else: 84 | index1 = np.hstack((index1, sample_color)) 85 | index2 = np.hstack((index2, sample_thermal)) 86 | 87 | self.index1 = index1 88 | self.index2 = index2 89 | self.N = N 90 | 91 | def __iter__(self): 92 | return iter(np.arange(len(self.index1))) 93 | 94 | def __len__(self): 95 | return self.N 96 | 97 | class AverageMeter(object): 98 | """Computes and stores the average and current value""" 99 | def __init__(self): 100 | self.reset() 101 | 102 | def reset(self): 103 | self.val = 0 104 | self.avg = 0 105 | self.sum = 0 106 | self.count = 0 107 | 108 | def update(self, val, n=1): 109 | self.val = val 110 | self.sum += val * n 111 | self.count += n 112 | self.avg = self.sum / self.count 113 | 114 | def mkdir_if_missing(directory): 115 | if not osp.exists(directory): 116 | try: 117 | os.makedirs(directory) 118 | except OSError as e: 119 | if e.errno != errno.EEXIST: 120 | raise 121 | class Logger(object): 122 | """ 123 | Write console output to external text file. 124 | Code imported from https://github.com/Cysu/open-reid/blob/master/reid/utils/logging.py. 
125 | """ 126 | def __init__(self, fpath=None): 127 | self.console = sys.stdout 128 | self.file = None 129 | if fpath is not None: 130 | mkdir_if_missing(osp.dirname(fpath)) 131 | self.file = open(fpath, 'w') 132 | 133 | def __del__(self): 134 | self.close() 135 | 136 | def __enter__(self): 137 | pass 138 | 139 | def __exit__(self, *args): 140 | self.close() 141 | 142 | def write(self, msg): 143 | self.console.write(msg) 144 | if self.file is not None: 145 | self.file.write(msg) 146 | 147 | def flush(self): 148 | self.console.flush() 149 | if self.file is not None: 150 | self.file.flush() 151 | os.fsync(self.file.fileno()) 152 | 153 | def close(self): 154 | self.console.close() 155 | if self.file is not None: 156 | self.file.close() 157 | 158 | def set_seed(seed, cuda=True): 159 | np.random.seed(seed) 160 | torch.manual_seed(seed) 161 | if cuda: 162 | torch.cuda.manual_seed(seed) 163 | 164 | def set_requires_grad(nets, requires_grad=False): 165 | """Set requies_grad=Fasle for all the networks to avoid unnecessary computations 166 | Parameters: 167 | nets (network list) -- a list of networks 168 | requires_grad (bool) -- whether the networks require gradients or not 169 | """ 170 | if not isinstance(nets, list): 171 | nets = [nets] 172 | for net in nets: 173 | if net is not None: 174 | for param in net.parameters(): 175 | param.requires_grad = requires_grad 176 | -------------------------------------------------------------------------------- /DDAG/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 mangye16 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /DDAG/README.md: -------------------------------------------------------------------------------- 1 | # DDAG 2 | Pytorch Code of DDAG for Visible-Infrared Person Re-Identification in ECCV 2020. [PDF](https://arxiv.org/pdf/2007.09314.pdf) 3 | 4 | ## Highlight 5 | 6 | The goal of this work is to learn a robust and discriminative cross-modality representation for visible-infrarerd person re-identification. 7 | 8 | - Intra-modality Weighted-Part Aggregation (IWPA): It learns discriminative part-aggregated features by mining the contextual part relation. 9 | 10 | - Cross-modality Graph Structured Attention (CGSA): It enhances the feature by incorporating the neighborhood information across two modalities. 
11 | 12 | ### Results on the SYSU-MM01 Dataset 13 | | Method | Datasets | Rank@1 | mAP | mINP | 14 | |------| -------- | ----- | ----- | ----- | 15 | | AGW [[1](https://github.com/mangye16/Cross-Modal-Re-ID-baseline)] |#SYSU-MM01 (All-Search) | ~ 47.50% | ~ 47.65% | ~ 35.30% | 16 | | DDAG |#SYSU-MM01 (All-Search) | ~ 54.75% | ~ 53.02% | ~ 39.62% | 17 | | AGW [[1](https://github.com/mangye16/Cross-Modal-Re-ID-baseline)] |#SYSU-MM01 (Indoor-Search) | ~ 54.17% | ~ 62.97% | ~ 59.23% | 18 | | DDAG |#SYSU-MM01 (Indoor-Search) | ~ 61.02% | ~ 67.98% | ~ 62.61% | 19 | 20 | *The code has been tested with Python 3.7 and PyTorch 1.0. Results on both datasets may fluctuate slightly due to random splitting. 21 | 22 | ### 1. Prepare the datasets. 23 | 24 | - (1) RegDB Dataset [1]: The RegDB dataset can be downloaded from this [website](http://dm.dongguk.edu/link.html) by submitting a copyright form. 25 | 26 | - (Named: "Dongguk Body-based Person Recognition Database (DBPerson-Recog-DB1)" on their website). 27 | 28 | - A private download link can be requested via sending me an email (mangye16@gmail.com). 29 | 30 | - (2) SYSU-MM01 Dataset [2]: The SYSU-MM01 dataset can be downloaded from this [website](http://isee.sysu.edu.cn/project/RGBIRReID.htm). 31 | 32 | - run `python pre_process_sysu.py` [link](https://github.com/mangye16/Cross-Modal-Re-ID-baseline/blob/master/pre_process_sysu.py) to prepare the dataset; the training data will be stored in ".npy" format. 33 | 34 | ### 2. Training. 35 | Train a model by 36 | ```bash 37 | python train_ddag.py --dataset sysu --lr 0.1 --graph --wpa --part 3 --gpu 0 38 | ``` 39 | 40 | - `--dataset`: which dataset "sysu" or "regdb". 41 | 42 | - `--lr`: initial learning rate. 43 | 44 | - `--graph`: use graph attention. 45 | 46 | - `--wpa`: use weighted part attention. 47 | 48 | - `--part`: the number of body parts. 49 | 50 | - `--gpu`: which gpu to run. 51 | 52 | You may need to manually define the data path first. 53 | 54 | 55 | ### 3. Testing. 56 | 57 | Test a model on SYSU-MM01 or RegDB dataset by 58 | ```bash 59 | python test_ddag.py --dataset sysu --mode all --wpa --graph --gpu 1 --resume 'model_path' 60 | ``` 61 | - `--dataset`: which dataset "sysu" or "regdb". 62 | 63 | - `--mode`: "all" or "indoor", i.e. all-search or indoor-search (only for the sysu dataset). 64 | 65 | - `--trial`: testing trial (only for RegDB dataset). 66 | 67 | - `--resume`: the saved model path (**important**). 68 | 69 | - `--gpu`: which gpu to run. 70 | 71 | ### 4. Citation 72 | 73 | Please kindly cite the following papers in your publications if they help your research: 74 | ``` 75 | @inproceedings{eccv20ddag, 76 | title={Dynamic Dual-Attentive Aggregation Learning for Visible-Infrared Person Re-Identification}, 77 | author={Ye, Mang and Shen, Jianbing and Crandall, David J. and Shao, Ling and Luo, Jiebo}, 78 | booktitle={European Conference on Computer Vision (ECCV)}, 79 | year={2020}, 80 | } 81 | ``` 82 | 83 | ``` 84 | @article{arxiv20reidsurvey, 85 | title={Deep Learning for Person Re-identification: A Survey and Outlook}, 86 | author={Ye, Mang and Shen, Jianbing and Lin, Gaojie and Xiang, Tao and Shao, Ling and Hoi, Steven C.
H.}, 87 | journal={arXiv preprint arXiv:2001.04193}, 88 | year={2020}, 89 | } 90 | ``` 91 | 92 | Contact: mangye16@gmail.com 93 | -------------------------------------------------------------------------------- /DDAG/attention.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | """ 6 | PART of the code is from the following link 7 | https://github.com/Diego999/pyGAT/blob/master/layers.py 8 | """ 9 | 10 | 11 | class Normalize(nn.Module): 12 | def __init__(self, power=2): 13 | super(Normalize, self).__init__() 14 | self.power = power 15 | 16 | def forward(self, x): 17 | norm = x.pow(self.power).sum(1, keepdim=True).pow(1. / self.power) 18 | out = x.div(norm) 19 | return out 20 | 21 | class GraphAttentionLayer(nn.Module): 22 | """ 23 | Simple GAT layer, similar to https://arxiv.org/abs/1710.10903 24 | """ 25 | 26 | def __init__(self, in_features, out_features, dropout, alpha=0.2, concat=True): 27 | super(GraphAttentionLayer, self).__init__() 28 | self.dropout = dropout 29 | self.in_features = in_features 30 | self.out_features = out_features 31 | self.alpha = alpha 32 | self.concat = concat 33 | 34 | self.W = nn.Parameter(torch.zeros(size=(in_features, out_features))) 35 | nn.init.xavier_uniform_(self.W.data, gain=1.414) 36 | self.a = nn.Parameter(torch.zeros(size=(2 * out_features, 1))) 37 | nn.init.xavier_uniform_(self.a.data, gain=1.414) 38 | 39 | self.leakyrelu = nn.LeakyReLU(self.alpha) 40 | 41 | def forward(self, input, adj): 42 | h = torch.mm(input, self.W) 43 | N = h.size()[0] 44 | 45 | a_input = torch.cat([h.repeat(1, N).view(N * N, -1), h.repeat(N, 1)], dim=1).view(N, -1, 2 * self.out_features) 46 | e = self.leakyrelu(torch.matmul(a_input, self.a).squeeze(2)) 47 | 48 | zero_vec = -9e15 * torch.ones_like(e) 49 | attention = torch.where(adj > 0, e, zero_vec) 50 | attention = F.softmax(attention, dim=1) 51 | attention = F.dropout(attention, self.dropout, training=self.training) 52 | h_prime = torch.matmul(attention, h) 53 | 54 | if self.concat: 55 | return F.elu(h_prime) 56 | else: 57 | return h_prime 58 | 59 | def __repr__(self): 60 | return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')' 61 | 62 | 63 | class SpecialSpmmFunction(torch.autograd.Function): 64 | """Special function for only sparse region backpropataion layer.""" 65 | 66 | @staticmethod 67 | def forward(ctx, indices, values, shape, b): 68 | assert indices.requires_grad == False 69 | a = torch.sparse_coo_tensor(indices, values, shape) 70 | ctx.save_for_backward(a, b) 71 | ctx.N = shape[0] 72 | return torch.matmul(a, b) 73 | 74 | @staticmethod 75 | def backward(ctx, grad_output): 76 | a, b = ctx.saved_tensors 77 | grad_values = grad_b = None 78 | if ctx.needs_input_grad[1]: 79 | grad_a_dense = grad_output.matmul(b.t()) 80 | edge_idx = a._indices()[0, :] * ctx.N + a._indices()[1, :] 81 | grad_values = grad_a_dense.view(-1)[edge_idx] 82 | if ctx.needs_input_grad[3]: 83 | grad_b = a.t().matmul(grad_output) 84 | return None, grad_values, None, grad_b 85 | 86 | 87 | class SpecialSpmm(nn.Module): 88 | def forward(self, indices, values, shape, b): 89 | return SpecialSpmmFunction.apply(indices, values, shape, b) 90 | 91 | 92 | class SpGraphAttentionLayer(nn.Module): 93 | """ 94 | Sparse version GAT layer, similar to https://arxiv.org/abs/1710.10903 95 | """ 96 | 97 | def __init__(self, in_features, out_features, dropout, alpha, concat=True): 98 | 
super(SpGraphAttentionLayer, self).__init__() 99 | self.in_features = in_features 100 | self.out_features = out_features 101 | self.alpha = alpha 102 | self.concat = concat 103 | 104 | self.W = nn.Parameter(torch.zeros(size=(in_features, out_features))) 105 | nn.init.xavier_normal_(self.W.data, gain=1.414) 106 | 107 | self.a = nn.Parameter(torch.zeros(size=(1, 2 * out_features))) 108 | nn.init.xavier_normal_(self.a.data, gain=1.414) 109 | 110 | self.dropout = nn.Dropout(dropout) 111 | self.leakyrelu = nn.LeakyReLU(self.alpha) 112 | self.special_spmm = SpecialSpmm() 113 | 114 | def forward(self, input, adj): 115 | dv = 'cuda' if input.is_cuda else 'cpu' 116 | 117 | N = input.size()[0] 118 | edge = adj.nonzero().t() 119 | 120 | h = torch.mm(input, self.W) 121 | # h: N x out 122 | assert not torch.isnan(h).any() 123 | 124 | # Self-attention on the nodes - Shared attention mechanism 125 | edge_h = torch.cat((h[edge[0, :], :], h[edge[1, :], :]), dim=1).t() 126 | # edge: 2*D x E 127 | 128 | edge_e = torch.exp(-self.leakyrelu(self.a.mm(edge_h).squeeze())) 129 | assert not torch.isnan(edge_e).any() 130 | # edge_e: E 131 | 132 | e_rowsum = self.special_spmm(edge, edge_e, torch.Size([N, N]), torch.ones(size=(N, 1), device=dv)) 133 | # e_rowsum: N x 1 134 | 135 | edge_e = self.dropout(edge_e) 136 | # edge_e: E 137 | 138 | h_prime = self.special_spmm(edge, edge_e, torch.Size([N, N]), h) 139 | assert not torch.isnan(h_prime).any() 140 | # h_prime: N x out 141 | 142 | h_prime = h_prime.div(e_rowsum) 143 | # h_prime: N x out 144 | assert not torch.isnan(h_prime).any() 145 | 146 | if self.concat: 147 | # if this layer is not last layer, 148 | return F.elu(h_prime) 149 | else: 150 | # if this layer is last layer, 151 | return h_prime 152 | 153 | def __repr__(self): 154 | return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')' 155 | 156 | 157 | class IWPA(nn.Module): 158 | """ 159 | Part attention layer, "Dynamic Dual-Attentive Aggregation Learning for Visible-Infrared Person Re-Identification" 160 | """ 161 | def __init__(self, in_channels, part = 3, inter_channels=None, out_channels=None): 162 | super(IWPA, self).__init__() 163 | 164 | self.in_channels = in_channels 165 | self.inter_channels = inter_channels 166 | self.out_channels = out_channels 167 | self.l2norm = Normalize(2) 168 | 169 | if self.inter_channels is None: 170 | self.inter_channels = in_channels 171 | 172 | if self.out_channels is None: 173 | self.out_channels = in_channels 174 | 175 | conv_nd = nn.Conv2d 176 | 177 | self.fc1 = nn.Sequential( 178 | conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, 179 | padding=0), 180 | ) 181 | 182 | self.fc2 = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 183 | kernel_size=1, stride=1, padding=0) 184 | 185 | self.fc3 = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 186 | kernel_size=1, stride=1, padding=0) 187 | 188 | self.W = nn.Sequential( 189 | conv_nd(in_channels=self.inter_channels, out_channels=self.out_channels, 190 | kernel_size=1, stride=1, padding=0), 191 | nn.BatchNorm2d(self.out_channels), 192 | ) 193 | nn.init.constant_(self.W[1].weight, 0.0) 194 | nn.init.constant_(self.W[1].bias, 0.0) 195 | 196 | 197 | self.bottleneck = nn.BatchNorm1d(in_channels) 198 | self.bottleneck.bias.requires_grad_(False) # no shift 199 | 200 | nn.init.normal_(self.bottleneck.weight.data, 1.0, 0.01) 201 | nn.init.zeros_(self.bottleneck.bias.data) 202 | 203 | # weighting 
vector of the part features 204 | self.gate = nn.Parameter(torch.FloatTensor(part)) 205 | nn.init.constant_(self.gate, 1/part) 206 | def forward(self, x, feat, t=None, part=0): 207 | bt, c, h, w = x.shape 208 | b = bt // t 209 | 210 | # get part features 211 | part_feat = F.adaptive_avg_pool2d(x, (part, 1)) 212 | part_feat = part_feat.view(b, t, c, part) 213 | part_feat = part_feat.permute(0, 2, 1, 3) # B, C, T, Part 214 | 215 | part_feat1 = self.fc1(part_feat).view(b, self.inter_channels, -1) # B, C//r, T*Part 216 | part_feat1 = part_feat1.permute(0, 2, 1) # B, T*Part, C//r 217 | 218 | part_feat2 = self.fc2(part_feat).view(b, self.inter_channels, -1) # B, C//r, T*Part 219 | 220 | part_feat3 = self.fc3(part_feat).view(b, self.inter_channels, -1) # B, C//r, T*Part 221 | part_feat3 = part_feat3.permute(0, 2, 1) # B, T*Part, C//r 222 | 223 | # get cross-part attention 224 | cpa_att = torch.matmul(part_feat1, part_feat2) # B, T*Part, T*Part 225 | cpa_att = F.softmax(cpa_att, dim=-1) 226 | 227 | # collect contextual information 228 | refined_part_feat = torch.matmul(cpa_att, part_feat3) # B, T*Part, C//r 229 | refined_part_feat = refined_part_feat.permute(0, 2, 1).contiguous() # B, C//r, T*Part 230 | refined_part_feat = refined_part_feat.view(b, self.inter_channels, part) # B, C//r, T, Part 231 | 232 | gate = F.softmax(self.gate, dim=-1) 233 | weight_part_feat = torch.matmul(refined_part_feat, gate) 234 | x = F.adaptive_avg_pool2d(x, (1, 1)) 235 | # weight_part_feat = weight_part_feat + x.view(x.size(0), x.size(1)) 236 | 237 | weight_part_feat = weight_part_feat + feat 238 | feat = self.bottleneck(weight_part_feat) 239 | 240 | return feat -------------------------------------------------------------------------------- /DDAG/data_loader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image, ImageChops 3 | from torchvision import transforms 4 | import random 5 | import pdb 6 | import torch 7 | import torchvision.datasets as datasets 8 | import torch.utils.data as data 9 | 10 | 11 | class SYSUData(data.Dataset): 12 | def __init__(self, data_dir, transform=None, colorIndex = None, thermalIndex = None): 13 | 14 | # data_dir = '../Datasets/SYSU-MM01/' 15 | 16 | # Load training images (path) and labels 17 | train_color_image = np.load(data_dir + 'train_rgb_resized_img.npy') 18 | self.train_color_label = np.load(data_dir + 'train_rgb_resized_label.npy') 19 | 20 | train_thermal_image = np.load(data_dir + 'train_ir_resized_img.npy') 21 | self.train_thermal_label = np.load(data_dir + 'train_ir_resized_label.npy') 22 | 23 | # BGR to RGB 24 | self.train_color_image = train_color_image 25 | self.train_thermal_image = train_thermal_image 26 | self.transform = transform 27 | self.cIndex = colorIndex 28 | self.tIndex = thermalIndex 29 | 30 | def __getitem__(self, index): 31 | 32 | img1, target1 = self.train_color_image[self.cIndex[index]], self.train_color_label[self.cIndex[index]] 33 | img2, target2 = self.train_thermal_image[self.tIndex[index]], self.train_thermal_label[self.tIndex[index]] 34 | 35 | img1 = self.transform(img1) 36 | img2 = self.transform(img2) 37 | 38 | return img1, img2, target1, target2 39 | 40 | def __len__(self): 41 | return len(self.train_color_label) 42 | 43 | 44 | class RegDBData(data.Dataset): 45 | def __init__(self, data_dir, trial, transform=None, colorIndex = None, thermalIndex = None): 46 | # Load training images (path) and labels 47 | # data_dir = '../Datasets/RegDB/' 48 | train_color_list = data_dir + 
'idx/train_visible_{}'.format(trial)+ '.txt' 49 | train_thermal_list = data_dir + 'idx/train_thermal_{}'.format(trial)+ '.txt' 50 | 51 | color_img_file, train_color_label = load_data(train_color_list) 52 | thermal_img_file, train_thermal_label = load_data(train_thermal_list) 53 | 54 | train_color_image = [] 55 | for i in range(len(color_img_file)): 56 | 57 | img = Image.open(data_dir+ color_img_file[i]) 58 | img = img.resize((144, 288), Image.ANTIALIAS) 59 | pix_array = np.array(img) 60 | train_color_image.append(pix_array) 61 | train_color_image = np.array(train_color_image) 62 | 63 | train_thermal_image = [] 64 | for i in range(len(thermal_img_file)): 65 | img = Image.open(data_dir+ thermal_img_file[i]) 66 | img = img.resize((144, 288), Image.ANTIALIAS) 67 | pix_array = np.array(img) 68 | train_thermal_image.append(pix_array) 69 | train_thermal_image = np.array(train_thermal_image) 70 | 71 | # BGR to RGB 72 | self.train_color_image = train_color_image 73 | self.train_color_label = train_color_label 74 | 75 | # BGR to RGB 76 | self.train_thermal_image = train_thermal_image 77 | self.train_thermal_label = train_thermal_label 78 | 79 | self.transform = transform 80 | self.cIndex = colorIndex 81 | self.tIndex = thermalIndex 82 | 83 | def __getitem__(self, index): 84 | 85 | img1, target1 = self.train_color_image[self.cIndex[index]], self.train_color_label[self.cIndex[index]] 86 | img2, target2 = self.train_thermal_image[self.tIndex[index]], self.train_thermal_label[self.tIndex[index]] 87 | 88 | img1 = self.transform(img1) 89 | img2 = self.transform(img2) 90 | 91 | return img1, img2, target1, target2 92 | 93 | def __len__(self): 94 | return len(self.train_color_label) 95 | 96 | class TestData(data.Dataset): 97 | def __init__(self, test_img_file, test_label, transform=None, img_size = (144,288)): 98 | 99 | test_image = [] 100 | for i in range(len(test_img_file)): 101 | img = Image.open(test_img_file[i]) 102 | img = img.resize((img_size[0], img_size[1]), Image.ANTIALIAS) 103 | pix_array = np.array(img) 104 | test_image.append(pix_array) 105 | test_image = np.array(test_image) 106 | self.test_image = test_image 107 | self.test_label = test_label 108 | self.transform = transform 109 | 110 | def __getitem__(self, index): 111 | img1, target1 = self.test_image[index], self.test_label[index] 112 | img1 = self.transform(img1) 113 | return img1, target1 114 | 115 | def __len__(self): 116 | return len(self.test_image) 117 | 118 | class TestDataOld(data.Dataset): 119 | def __init__(self, data_dir, test_img_file, test_label, transform=None, img_size = (144,288)): 120 | 121 | test_image = [] 122 | for i in range(len(test_img_file)): 123 | img = Image.open(data_dir + test_img_file[i]) 124 | img = img.resize((img_size[0], img_size[1]), Image.ANTIALIAS) 125 | pix_array = np.array(img) 126 | test_image.append(pix_array) 127 | test_image = np.array(test_image) 128 | self.test_image = test_image 129 | self.test_label = test_label 130 | self.transform = transform 131 | 132 | def __getitem__(self, index): 133 | img1, target1 = self.test_image[index], self.test_label[index] 134 | img1 = self.transform(img1) 135 | return img1, target1 136 | 137 | def __len__(self): 138 | return len(self.test_image) 139 | def load_data(input_data_path ): 140 | with open(input_data_path) as f: 141 | data_file_list = open(input_data_path, 'rt').read().splitlines() 142 | # Get full list of image and labels 143 | file_image = [s.split(' ')[0] for s in data_file_list] 144 | file_label = [int(s.split(' ')[1]) for s in data_file_list] 145 | 
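        # each idx line is parsed as "<relative_image_path> <integer_label>"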
146 | return file_image, file_label -------------------------------------------------------------------------------- /DDAG/data_manager.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os 3 | import sys 4 | import numpy as np 5 | import random 6 | 7 | def process_query_sysu(data_path, mode = 'all', relabel=False): 8 | if mode== 'all': 9 | ir_cameras = ['cam3','cam6'] 10 | elif mode =='indoor': 11 | ir_cameras = ['cam3','cam6'] 12 | 13 | file_path = os.path.join(data_path,'exp/test_id.txt') 14 | files_rgb = [] 15 | files_ir = [] 16 | 17 | with open(file_path, 'r') as file: 18 | ids = file.read().splitlines() 19 | ids = [int(y) for y in ids[0].split(',')] 20 | ids = ["%04d" % x for x in ids] 21 | 22 | for id in sorted(ids): 23 | for cam in ir_cameras: 24 | img_dir = os.path.join(data_path,cam,id) 25 | if os.path.isdir(img_dir): 26 | new_files = sorted([img_dir+'/'+i for i in os.listdir(img_dir)]) 27 | files_ir.extend(new_files) 28 | query_img = [] 29 | query_id = [] 30 | query_cam = [] 31 | for img_path in files_ir: 32 | camid, pid = int(img_path[-15]), int(img_path[-13:-9]) 33 | query_img.append(img_path) 34 | query_id.append(pid) 35 | query_cam.append(camid) 36 | return query_img, np.array(query_id), np.array(query_cam) 37 | 38 | def process_gallery_sysu(data_path, mode = 'all', trial = 0, relabel=False): 39 | 40 | random.seed(trial) 41 | 42 | if mode== 'all': 43 | rgb_cameras = ['cam1','cam2','cam4','cam5'] 44 | elif mode =='indoor': 45 | rgb_cameras = ['cam1','cam2'] 46 | 47 | file_path = os.path.join(data_path,'exp/test_id.txt') 48 | files_rgb = [] 49 | with open(file_path, 'r') as file: 50 | ids = file.read().splitlines() 51 | ids = [int(y) for y in ids[0].split(',')] 52 | ids = ["%04d" % x for x in ids] 53 | 54 | for id in sorted(ids): 55 | for cam in rgb_cameras: 56 | img_dir = os.path.join(data_path,cam,id) 57 | if os.path.isdir(img_dir): 58 | new_files = sorted([img_dir+'/'+i for i in os.listdir(img_dir)]) 59 | files_rgb.append(random.choice(new_files)) 60 | gall_img = [] 61 | gall_id = [] 62 | gall_cam = [] 63 | for img_path in files_rgb: 64 | camid, pid = int(img_path[-15]), int(img_path[-13:-9]) 65 | gall_img.append(img_path) 66 | gall_id.append(pid) 67 | gall_cam.append(camid) 68 | return gall_img, np.array(gall_id), np.array(gall_cam) 69 | 70 | def process_gallery_sysu_multishot(data_path, mode='all', trial=0, relabel=False): 71 | random.seed(trial) 72 | 73 | if mode == 'all': 74 | rgb_cameras = ['cam1', 'cam2', 'cam4', 'cam5'] 75 | elif mode == 'indoor': 76 | rgb_cameras = ['cam1', 'cam2'] 77 | 78 | file_path = os.path.join(data_path, 'exp/test_id.txt') 79 | files_rgb = [] 80 | with open(file_path, 'r') as file: 81 | ids = file.read().splitlines() 82 | ids = [int(y) for y in ids[0].split(',')] 83 | ids = ["%04d" % x for x in ids] 84 | 85 | for id in sorted(ids): 86 | for cam in rgb_cameras: 87 | img_dir = os.path.join(data_path, cam, id) 88 | if os.path.isdir(img_dir): 89 | new_files = sorted([img_dir + '/' + i for i in os.listdir(img_dir)]) 90 | files_rgb = files_rgb + random.sample(new_files, 10) 91 | gall_img = [] 92 | gall_id = [] 93 | gall_cam = [] 94 | for img_path in files_rgb: 95 | camid, pid = int(img_path[-15]), int(img_path[-13:-9]) 96 | gall_img.append(img_path) 97 | gall_id.append(pid) 98 | gall_cam.append(camid) 99 | return gall_img, np.array(gall_id), np.array(gall_cam) 100 | 101 | def process_test_regdb(img_dir, trial = 1, modal = 'visible'): 102 | if 
modal=='visible': 103 | input_data_path = img_dir + 'idx/test_visible_{}'.format(trial) + '.txt' 104 | elif modal=='thermal': 105 | input_data_path = img_dir + 'idx/test_thermal_{}'.format(trial) + '.txt' 106 | 107 | with open(input_data_path) as f: 108 | data_file_list = f.read().splitlines() # read via the handle opened above instead of opening the file a second time 109 | # Get full list of image and labels 110 | file_image = [img_dir + '/' + s.split(' ')[0] for s in data_file_list] 111 | file_label = [int(s.split(' ')[1]) for s in data_file_list] 112 | 113 | return file_image, np.array(file_label) -------------------------------------------------------------------------------- /DDAG/eval_metrics.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import numpy as np 3 | """Cross-Modality ReID""" 4 | 5 | def eval_sysu(distmat, q_pids, g_pids, q_camids, g_camids, max_rank = 20): 6 | """Evaluation with sysu metric 7 | Key: for each query identity, its gallery images from the same camera view are discarded, following the original setting of the dataset. 8 | """ 9 | num_q, num_g = distmat.shape 10 | if num_g < max_rank: 11 | max_rank = num_g 12 | print("Note: number of gallery samples is quite small, got {}".format(num_g)) 13 | indices = np.argsort(distmat, axis=1) 14 | pred_label = g_pids[indices] 15 | matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32) 16 | 17 | # compute cmc curve for each query 18 | new_all_cmc = [] 19 | all_cmc = [] 20 | all_AP = [] 21 | all_INP = [] 22 | num_valid_q = 0. # number of valid query 23 | for q_idx in range(num_q): 24 | # get query pid and camid 25 | q_pid = q_pids[q_idx] 26 | q_camid = q_camids[q_idx] 27 | 28 | # remove gallery samples that have the same pid and camid with query 29 | order = indices[q_idx] 30 | remove = (q_camid == 3) & (g_camids[order] == 2) 31 | keep = np.invert(remove) 32 | 33 | # compute cmc curve 34 | # the cmc calculation is different from standard protocol 35 | # we follow the protocol of the author's released code 36 | new_cmc = pred_label[q_idx][keep] 37 | new_index = np.unique(new_cmc, return_index=True)[1] 38 | new_cmc = [new_cmc[index] for index in sorted(new_index)] 39 | 40 | new_match = (new_cmc == q_pid).astype(np.int32) 41 | new_cmc = new_match.cumsum() 42 | new_all_cmc.append(new_cmc[:max_rank]) 43 | 44 | orig_cmc = matches[q_idx][keep] # binary vector, positions with value 1 are correct matches 45 | if not np.any(orig_cmc): 46 | # this condition is true when query identity does not appear in gallery 47 | continue 48 | 49 | cmc = orig_cmc.cumsum() 50 | 51 | # compute mINP 52 | # reference: Deep Learning for Person Re-identification: A Survey and Outlook 53 | pos_idx = np.where(orig_cmc == 1) 54 | pos_max_idx = np.max(pos_idx) 55 | inp = cmc[pos_max_idx] / (pos_max_idx + 1.0) 56 | all_INP.append(inp) 57 | 58 | cmc[cmc > 1] = 1 59 | 60 | all_cmc.append(cmc[:max_rank]) 61 | num_valid_q += 1. 62 | 63 | # compute average precision 64 | # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision 65 | num_rel = orig_cmc.sum() 66 | tmp_cmc = orig_cmc.cumsum() 67 | tmp_cmc = [x / (i+1.)
for i, x in enumerate(tmp_cmc)] 68 | tmp_cmc = np.asarray(tmp_cmc) * orig_cmc 69 | AP = tmp_cmc.sum() / num_rel 70 | all_AP.append(AP) 71 | 72 | assert num_valid_q > 0, "Error: all query identities do not appear in gallery" 73 | 74 | all_cmc = np.asarray(all_cmc).astype(np.float32) 75 | all_cmc = all_cmc.sum(0) / num_valid_q # standard CMC 76 | 77 | new_all_cmc = np.asarray(new_all_cmc).astype(np.float32) 78 | new_all_cmc = new_all_cmc.sum(0) / num_valid_q 79 | mAP = np.mean(all_AP) 80 | mINP = np.mean(all_INP) 81 | return new_all_cmc, mAP, mINP 82 | 83 | 84 | 85 | def eval_regdb(distmat, q_pids, g_pids, max_rank = 20): 86 | num_q, num_g = distmat.shape 87 | if num_g < max_rank: 88 | max_rank = num_g 89 | print("Note: number of gallery samples is quite small, got {}".format(num_g)) 90 | indices = np.argsort(distmat, axis=1) 91 | matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32) 92 | 93 | # compute cmc curve for each query 94 | all_cmc = [] 95 | all_AP = [] 96 | all_INP = [] 97 | num_valid_q = 0. # number of valid query 98 | 99 | # only two cameras 100 | q_camids = np.ones(num_q).astype(np.int32) 101 | g_camids = 2 * np.ones(num_g).astype(np.int32) 102 | 103 | for q_idx in range(num_q): 104 | # get query pid and camid 105 | q_pid = q_pids[q_idx] 106 | q_camid = q_camids[q_idx] 107 | 108 | # remove gallery samples that have the same pid and camid with query 109 | order = indices[q_idx] 110 | remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid) 111 | keep = np.invert(remove) 112 | 113 | # compute cmc curve 114 | raw_cmc = matches[q_idx][keep] # binary vector, positions with value 1 are correct matches 115 | if not np.any(raw_cmc): 116 | # this condition is true when query identity does not appear in gallery 117 | continue 118 | 119 | cmc = raw_cmc.cumsum() 120 | 121 | # compute mINP 122 | # reference: Deep Learning for Person Re-identification: A Survey and Outlook 123 | pos_idx = np.where(raw_cmc == 1) 124 | pos_max_idx = np.max(pos_idx) 125 | inp = cmc[pos_max_idx] / (pos_max_idx + 1.0) 126 | all_INP.append(inp) 127 | 128 | cmc[cmc > 1] = 1 129 | 130 | all_cmc.append(cmc[:max_rank]) 131 | num_valid_q += 1. 132 | 133 | # compute average precision 134 | # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision 135 | num_rel = raw_cmc.sum() 136 | tmp_cmc = raw_cmc.cumsum() 137 | tmp_cmc = [x / (i+1.)
for i, x in enumerate(tmp_cmc)] 138 | tmp_cmc = np.asarray(tmp_cmc) * raw_cmc 139 | AP = tmp_cmc.sum() / num_rel 140 | all_AP.append(AP) 141 | 142 | assert num_valid_q > 0, "Error: all query identities do not appear in gallery" 143 | 144 | all_cmc = np.asarray(all_cmc).astype(np.float32) 145 | all_cmc = all_cmc.sum(0) / num_valid_q 146 | mAP = np.mean(all_AP) 147 | mINP = np.mean(all_INP) 148 | return all_cmc, mAP, mINP -------------------------------------------------------------------------------- /DDAG/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.autograd.function import Function 6 | from torch.autograd import Variable 7 | 8 | class KLLoss(nn.Module): 9 | def __init__(self): 10 | super(KLLoss, self).__init__() 11 | def forward(self, pred, label): 12 | # pred: 2D matrix (batch_size, num_classes) 13 | # label: 1D vector indicating class number 14 | T=3 15 | 16 | predict = F.log_softmax(pred/T,dim=1) 17 | target_data = F.softmax(label/T,dim=1) 18 | target_data =target_data+10**(-7) 19 | target = Variable(target_data.data.cuda(),requires_grad=False) 20 | loss=T*T*((target*(target.log()-predict)).sum(1).sum()/target.size()[0]) 21 | return loss 22 | 23 | class OriTripletLoss(nn.Module): 24 | """Triplet loss with hard positive/negative mining. 25 | 26 | Reference: 27 | Hermans et al. In Defense of the Triplet Loss for Person Re-Identification. arXiv:1703.07737. 28 | Code imported from https://github.com/Cysu/open-reid/blob/master/reid/loss/triplet.py. 29 | 30 | Args: 31 | - margin (float): margin for triplet. 32 | """ 33 | 34 | def __init__(self, batch_size, margin=0.3): 35 | super(OriTripletLoss, self).__init__() 36 | self.margin = margin 37 | self.ranking_loss = nn.MarginRankingLoss(margin=margin) 38 | 39 | def forward(self, inputs, targets): 40 | """ 41 | Args: 42 | - inputs: feature matrix with shape (batch_size, feat_dim) 43 | - targets: ground truth labels with shape (num_classes) 44 | """ 45 | n = inputs.size(0) 46 | 47 | # Compute pairwise distance, replace by the official when merged 48 | dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n) 49 | dist = dist + dist.t() 50 | dist.addmm_(1, -2, inputs, inputs.t()) 51 | dist = dist.clamp(min=1e-12).sqrt() # for numerical stability 52 | 53 | # For each anchor, find the hardest positive and negative 54 | mask = targets.expand(n, n).eq(targets.expand(n, n).t()) 55 | dist_ap, dist_an = [], [] 56 | for i in range(n): 57 | dist_ap.append(dist[i][mask[i]].max().unsqueeze(0)) 58 | dist_an.append(dist[i][mask[i] == 0].min().unsqueeze(0)) 59 | dist_ap = torch.cat(dist_ap) 60 | dist_an = torch.cat(dist_an) 61 | 62 | # Compute ranking hinge loss 63 | y = torch.ones_like(dist_an) 64 | loss = self.ranking_loss(dist_an, dist_ap, y) 65 | 66 | # compute accuracy 67 | correct = torch.ge(dist_an, dist_ap).sum().item() 68 | return loss, correct 69 | 70 | 71 | class TripletLoss(nn.Module): 72 | """Triplet loss with hard positive/negative mining. 73 | 74 | Reference: 75 | Hermans et al. In Defense of the Triplet Loss for Person Re-Identification. arXiv:1703.07737. 76 | Code imported from https://github.com/Cysu/open-reid/blob/master/reid/loss/triplet.py. 77 | 78 | Args: 79 | - margin (float): margin for triplet. 
80 | """ 81 | def __init__(self, batch_size, margin=0.5): 82 | super(TripletLoss, self).__init__() 83 | self.margin = margin 84 | self.ranking_loss = nn.MarginRankingLoss(margin=margin) 85 | self.batch_size = batch_size 86 | self.mask = torch.eye(batch_size) 87 | def forward(self, input, target): 88 | """ 89 | Args: 90 | - input: feature matrix with shape (batch_size, feat_dim) 91 | - target: ground truth labels with shape (num_classes) 92 | """ 93 | n = self.batch_size 94 | input1 = input.narrow(0,0,n) 95 | input2 = input.narrow(0,n,n) 96 | 97 | # Compute pairwise distance, replace by the official when merged 98 | dist = pdist_torch(input1, input2) 99 | 100 | # For each anchor, find the hardest positive and negative 101 | # mask = target1.expand(n, n).eq(target1.expand(n, n).t()) 102 | dist_ap, dist_an = [], [] 103 | for i in range(n): 104 | dist_ap.append(dist[i,i].unsqueeze(0)) 105 | dist_an.append(dist[i][self.mask[i] == 0].min().unsqueeze(0)) 106 | dist_ap = torch.cat(dist_ap) 107 | dist_an = torch.cat(dist_an) 108 | 109 | # Compute ranking hinge loss 110 | y = torch.ones_like(dist_an) 111 | loss = self.ranking_loss(dist_an, dist_ap, y) 112 | 113 | # compute accuracy 114 | correct = torch.ge(dist_an, dist_ap).sum().item() 115 | return loss, correct*2 116 | 117 | class BiTripletLoss(nn.Module): 118 | """Triplet loss with hard positive/negative mining. 119 | 120 | Reference: 121 | Hermans et al. In Defense of the Triplet Loss for Person Re-Identification. arXiv:1703.07737. 122 | Code imported from https://github.com/Cysu/open-reid/blob/master/reid/loss/triplet.py. 123 | 124 | Args: 125 | - margin (float): margin for triplet.suffix 126 | """ 127 | def __init__(self, batch_size, margin=0.5): 128 | super(BiTripletLoss, self).__init__() 129 | self.margin = margin 130 | self.ranking_loss = nn.MarginRankingLoss(margin=margin) 131 | self.batch_size = batch_size 132 | self.mask = torch.eye(batch_size) 133 | def forward(self, input, target): 134 | """ 135 | Args: 136 | - input: feature matrix with shape (batch_size, feat_dim) 137 | - target: ground truth labels with shape (num_classes) 138 | """ 139 | n = self.batch_size 140 | input1 = input.narrow(0,0,n) 141 | input2 = input.narrow(0,n,n) 142 | 143 | # Compute pairwise distance, replace by the official when merged 144 | dist = pdist_torch(input1, input2) 145 | 146 | # For each anchor, find the hardest positive and negative 147 | # mask = target1.expand(n, n).eq(target1.expand(n, n).t()) 148 | dist_ap, dist_an = [], [] 149 | for i in range(n): 150 | dist_ap.append(dist[i,i].unsqueeze(0)) 151 | dist_an.append(dist[i][self.mask[i] == 0].min().unsqueeze(0)) 152 | dist_ap = torch.cat(dist_ap) 153 | dist_an = torch.cat(dist_an) 154 | 155 | # Compute ranking hinge loss 156 | y = torch.ones_like(dist_an) 157 | loss1 = self.ranking_loss(dist_an, dist_ap, y) 158 | 159 | # compute accuracy 160 | correct1 = torch.ge(dist_an, dist_ap).sum().item() 161 | 162 | # Compute pairwise distance, replace by the official when merged 163 | dist2 = pdist_torch(input2, input1) 164 | 165 | # For each anchor, find the hardest positive and negative 166 | dist_ap2, dist_an2 = [], [] 167 | for i in range(n): 168 | dist_ap2.append(dist2[i,i].unsqueeze(0)) 169 | dist_an2.append(dist2[i][self.mask[i] == 0].min().unsqueeze(0)) 170 | dist_ap2 = torch.cat(dist_ap2) 171 | dist_an2 = torch.cat(dist_an2) 172 | 173 | # Compute ranking hinge loss 174 | y2 = torch.ones_like(dist_an2) 175 | # loss2 = self.ranking_loss(dist_an2, dist_ap2, y2) 176 | 177 | loss2 = 
177 |         loss2 = torch.sum(torch.nn.functional.relu(dist_ap2 + self.margin - dist_an2))
178 | 
179 |         # compute accuracy
180 |         correct2 = torch.ge(dist_an2, dist_ap2).sum().item()
181 | 
182 |         loss = torch.add(loss1, loss2)
183 |         return loss, correct1 + correct2
184 | 
185 | 
186 | class BDTRLoss(nn.Module):
187 |     """Triplet loss with hard positive/negative mining.
188 | 
189 |     Reference:
190 |     Hermans et al. In Defense of the Triplet Loss for Person Re-Identification. arXiv:1703.07737.
191 |     Code imported from https://github.com/Cysu/open-reid/blob/master/reid/loss/triplet.py.
192 | 
193 |     Args:
194 |     - margin (float): margin for triplet.
195 |     """
196 |     def __init__(self, batch_size, margin=0.5):
197 |         super(BDTRLoss, self).__init__()
198 |         self.margin = margin
199 |         self.ranking_loss = nn.MarginRankingLoss(margin=margin)
200 |         self.batch_size = batch_size
201 |         self.mask = torch.eye(batch_size)
202 |     def forward(self, inputs, targets):
203 |         """
204 |         Args:
205 |         - inputs: feature matrix with shape (batch_size, feat_dim)
206 |         - targets: ground truth labels with shape (batch_size)
207 |         """
208 |         n = inputs.size(0)
209 | 
210 |         # Compute pairwise distance, replace by the official when merged
211 |         dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n)
212 |         dist = dist + dist.t()
213 |         dist.addmm_(inputs, inputs.t(), beta=1, alpha=-2)
214 |         dist = dist.clamp(min=1e-12).sqrt()  # for numerical stability
215 | 
216 |         # For each anchor, find the hardest positive and negative
217 |         mask = targets.expand(n, n).eq(targets.expand(n, n).t())
218 |         dist_ap, dist_an = [], []
219 |         for i in range(n):
220 |             dist_ap.append(dist[i][mask[i]].max().unsqueeze(0))
221 |             dist_an.append(dist[i][mask[i] == 0].min().unsqueeze(0))
222 |         dist_ap = torch.cat(dist_ap)
223 |         dist_an = torch.cat(dist_an)
224 | 
225 |         # Compute ranking hinge loss
226 |         y = torch.ones_like(dist_an)
227 |         loss = self.ranking_loss(dist_an, dist_ap, y)
228 |         correct = torch.ge(dist_an, dist_ap).sum().item()
229 |         return loss, correct
230 | 
231 | def pdist_torch(emb1, emb2):
232 |     '''
233 |     compute the euclidean distance matrix between embeddings1 and embeddings2
234 |     using gpu
235 |     '''
236 |     m, n = emb1.shape[0], emb2.shape[0]
237 |     emb1_pow = torch.pow(emb1, 2).sum(dim=1, keepdim=True).expand(m, n)
238 |     emb2_pow = torch.pow(emb2, 2).sum(dim=1, keepdim=True).expand(n, m).t()
239 |     dist_mtx = emb1_pow + emb2_pow
240 |     dist_mtx = dist_mtx.addmm_(emb1, emb2.t(), beta=1, alpha=-2)
241 |     # dist_mtx = dist_mtx.clamp(min = 1e-12)
242 |     dist_mtx = dist_mtx.clamp(min=1e-12).sqrt()
243 |     return dist_mtx
244 | 
245 | 
246 | def pdist_np(emb1, emb2):
247 |     '''
248 |     compute the euclidean distance matrix between embeddings1 and embeddings2
249 |     using cpu
250 |     '''
251 |     m, n = emb1.shape[0], emb2.shape[0]
252 |     emb1_pow = np.square(emb1).sum(axis=1)[..., np.newaxis]
253 |     emb2_pow = np.square(emb2).sum(axis=1)[np.newaxis, ...]
254 |     dist_mtx = -2 * np.matmul(emb1, emb2.T) + emb1_pow + emb2_pow
255 |     # dist_mtx = np.sqrt(dist_mtx.clip(min = 1e-12))
256 |     return dist_mtx
--------------------------------------------------------------------------------
/DDAG/model_main.py:
--------------------------------------------------------------------------------
  1 | import torch
  2 | import torch.nn as nn
  3 | from torch.nn import init
  4 | from torchvision import models
  5 | from torch.autograd import Variable
  6 | from resnet import resnet50, resnet18
  7 | import torch.nn.functional as F
  8 | import math
  9 | from attention import GraphAttentionLayer, IWPA
 10 | 
 11 | class Normalize(nn.Module):
 12 |     def __init__(self, power=2):
 13 |         super(Normalize, self).__init__()
 14 |         self.power = power
 15 | 
 16 |     def forward(self, x):
 17 |         norm = x.pow(self.power).sum(1, keepdim=True).pow(1. / self.power)
 18 |         out = x.div(norm)
 19 |         return out
 20 | 
 21 | # #####################################################################
 22 | def weights_init_kaiming(m):
 23 |     classname = m.__class__.__name__
 24 |     # print(classname)
 25 |     if classname.find('Conv') != -1:
 26 |         init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
 27 |     elif classname.find('Linear') != -1:
 28 |         init.kaiming_normal_(m.weight.data, a=0, mode='fan_out')
 29 |         init.zeros_(m.bias.data)
 30 |     elif classname.find('BatchNorm1d') != -1:
 31 |         init.normal_(m.weight.data, 1.0, 0.01)
 32 |         init.zeros_(m.bias.data)
 33 | 
 34 | 
 35 | def weights_init_classifier(m):
 36 |     classname = m.__class__.__name__
 37 |     if classname.find('Linear') != -1:
 38 |         init.normal_(m.weight.data, 0, 0.001)
 39 |         if m.bias is not None:
 40 |             init.zeros_(m.bias.data)
 41 | 
 42 | # Defines the new fc layer and classification layer
 43 | # |--Linear--|--bn--|--relu--|--Linear--|
 44 | class FeatureBlock(nn.Module):
 45 |     def __init__(self, input_dim, low_dim, dropout=0.5, relu=True):
 46 |         super(FeatureBlock, self).__init__()
 47 |         feat_block = []
 48 |         feat_block += [nn.Linear(input_dim, low_dim)]
 49 |         feat_block += [nn.BatchNorm1d(low_dim)]
 50 | 
 51 |         feat_block = nn.Sequential(*feat_block)
 52 |         feat_block.apply(weights_init_kaiming)
 53 |         self.feat_block = feat_block
 54 | 
 55 |     def forward(self, x):
 56 |         x = self.feat_block(x)
 57 |         return x
 58 | 
 59 | 
 60 | class ClassBlock(nn.Module):
 61 |     def __init__(self, input_dim, class_num, dropout=0.5, relu=True):
 62 |         super(ClassBlock, self).__init__()
 63 |         classifier = []
 64 |         if relu:
 65 |             classifier += [nn.LeakyReLU(0.1)]
 66 |         if dropout:
 67 |             classifier += [nn.Dropout(p=dropout)]
 68 | 
 69 |         classifier += [nn.Linear(input_dim, class_num)]
 70 |         classifier = nn.Sequential(*classifier)
 71 |         classifier.apply(weights_init_classifier)
 72 | 
 73 |         self.classifier = classifier
 74 | 
 75 |     def forward(self, x):
 76 |         x = self.classifier(x)
 77 |         return x
 78 | 
 79 | class visible_module(nn.Module):
 80 |     def __init__(self, arch='resnet50'):
 81 |         super(visible_module, self).__init__()
 82 | 
 83 |         model_v = resnet50(pretrained=False,
 84 |                            last_conv_stride=1, last_conv_dilation=1)
 85 |         # avg pooling to global pooling
 86 |         self.visible = model_v
 87 | 
 88 |     def forward(self, x):
 89 |         x = self.visible.conv1(x)
 90 |         x = self.visible.bn1(x)
 91 |         x = self.visible.relu(x)
 92 |         x = self.visible.maxpool(x)
 93 |         return x
 94 | 
 95 | 
 96 | class thermal_module(nn.Module):
 97 |     def __init__(self, arch='resnet50'):
 98 |         super(thermal_module, self).__init__()
 99 | 
100 |         model_t = resnet50(pretrained=False,
101 |                            last_conv_stride=1, last_conv_dilation=1)
102 |         # avg pooling to global pooling
103 |         self.thermal = model_t
104 | 
105 |     def
forward(self, x): 106 | x = self.thermal.conv1(x) 107 | x = self.thermal.bn1(x) 108 | x = self.thermal.relu(x) 109 | x = self.thermal.maxpool(x) 110 | return x 111 | 112 | 113 | class base_resnet(nn.Module): 114 | def __init__(self, arch='resnet50'): 115 | super(base_resnet, self).__init__() 116 | 117 | model_base = resnet50(pretrained=False, 118 | last_conv_stride=1, last_conv_dilation=1) 119 | # avg pooling to global pooling 120 | model_base.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 121 | self.base = model_base 122 | 123 | def forward(self, x): 124 | x = self.base.layer1(x) 125 | x = self.base.layer2(x) 126 | x = self.base.layer3(x) 127 | x = self.base.layer4(x) 128 | return x 129 | 130 | 131 | class embed_net(nn.Module): 132 | def __init__(self, low_dim, class_num, drop=0.2, part = 3, alpha=0.2, nheads=4, arch='resnet50', wpa = False): 133 | super(embed_net, self).__init__() 134 | 135 | self.thermal_module = thermal_module(arch=arch) 136 | self.visible_module = visible_module(arch=arch) 137 | self.base_resnet = base_resnet(arch=arch) 138 | pool_dim = 2048 139 | self.dropout = drop 140 | self.part = part 141 | self.lpa = wpa 142 | 143 | self.l2norm = Normalize(2) 144 | #self.bb = nn.BatchNorm2d(pool_dim) 145 | self.bottleneck = nn.BatchNorm1d(pool_dim) 146 | self.bottleneck.bias.requires_grad_(False) # no shift 147 | 148 | self.classifier = nn.Linear(pool_dim, class_num, bias=False) 149 | 150 | self.classifier1 = nn.Linear(pool_dim, class_num, bias=False) 151 | self.classifier2 = nn.Linear(pool_dim, class_num, bias=False) 152 | 153 | self.bottleneck.apply(weights_init_kaiming) 154 | #self.bb.apply(weights_init_kaiming) 155 | self.classifier.apply(weights_init_classifier) 156 | self.classifier1.apply(weights_init_classifier) 157 | self.classifier2.apply(weights_init_classifier) 158 | 159 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 160 | self.wpa = IWPA(pool_dim, part) 161 | 162 | self.attentions = [GraphAttentionLayer(pool_dim, low_dim, dropout=drop, alpha=alpha, concat=True) for _ in range(nheads)] 163 | for i, attention in enumerate(self.attentions): 164 | self.add_module('attention_{}'.format(i), attention) 165 | 166 | self.out_att = GraphAttentionLayer(low_dim * nheads, class_num, dropout=drop, alpha=alpha, concat=False) 167 | 168 | def forward(self, x1, x2, adj, modal=0, cpa = False): 169 | # domain specific block 170 | if modal == 0: 171 | x1 = self.visible_module(x1) 172 | x2 = self.thermal_module(x2) 173 | x = torch.cat((x1, x2), 0) 174 | elif modal == 1: 175 | x = self.visible_module(x1) 176 | elif modal == 2: 177 | x = self.thermal_module(x2) 178 | 179 | # shared four blocks 180 | x = self.base_resnet(x) 181 | x_pool = self.avgpool(x) 182 | x_pool = x_pool.view(x_pool.size(0), x_pool.size(1)) 183 | feat = self.bottleneck(x_pool) 184 | 185 | if self.lpa: 186 | # intra-modality weighted part attention 187 | feat_att = self.wpa(x, feat, 1, self.part) 188 | 189 | if self.training: 190 | # cross-modality graph attention 191 | x_g = F.dropout(x_pool, self.dropout, training=self.training) 192 | x_g = torch.cat([att(x_g, adj) for att in self.attentions], dim=1) 193 | x_g = F.dropout(x_g, self.dropout, training=self.training) 194 | x_g = F.elu(self.out_att(x_g, adj)) 195 | return x_pool, self.classifier(feat), self.classifier(feat_att), F.log_softmax(x_g, dim=1) 196 | else: 197 | return self.l2norm(feat), self.l2norm(feat_att) 198 | -------------------------------------------------------------------------------- /DDAG/pre_process_sysu.py: 
--------------------------------------------------------------------------------
  1 | import numpy as np
  2 | from PIL import Image
  3 | import pdb
  4 | import os
  5 | 
  6 | data_path = 'SYSU-MM01'
  7 | 
  8 | rgb_cameras = ['cam1','cam2','cam4','cam5']
  9 | ir_cameras = ['cam3','cam6']
 10 | 
 11 | # load id info
 12 | file_path_train = os.path.join(data_path,'exp/train_id.txt')
 13 | file_path_val = os.path.join(data_path,'exp/val_id.txt')
 14 | with open(file_path_train, 'r') as file:
 15 |     ids = file.read().splitlines()
 16 |     ids = [int(y) for y in ids[0].split(',')]
 17 |     id_train = ["%04d" % x for x in ids]
 18 | 
 19 | with open(file_path_val, 'r') as file:
 20 |     ids = file.read().splitlines()
 21 |     ids = [int(y) for y in ids[0].split(',')]
 22 |     id_val = ["%04d" % x for x in ids]
 23 | 
 24 | # combine train and val split
 25 | id_train.extend(id_val)
 26 | 
 27 | files_rgb = []
 28 | files_ir = []
 29 | for id in sorted(id_train):
 30 |     for cam in rgb_cameras:
 31 |         img_dir = os.path.join(data_path,cam,id)
 32 |         if os.path.isdir(img_dir):
 33 |             new_files = sorted([img_dir+'/'+i for i in os.listdir(img_dir)])
 34 |             files_rgb.extend(new_files)
 35 | 
 36 |     for cam in ir_cameras:
 37 |         img_dir = os.path.join(data_path,cam,id)
 38 |         if os.path.isdir(img_dir):
 39 |             new_files = sorted([img_dir+'/'+i for i in os.listdir(img_dir)])
 40 |             files_ir.extend(new_files)
 41 | 
 42 | # relabel
 43 | pid_container = set()
 44 | for img_path in files_ir:
 45 |     pid = int(img_path[-13:-9])
 46 |     pid_container.add(pid)
 47 | pid2label = {pid:label for label, pid in enumerate(pid_container)}
 48 | fix_image_width = 144
 49 | fix_image_height = 288
 50 | def read_imgs(train_image):
 51 |     train_img = []
 52 |     train_label = []
 53 |     train_cam = []
 54 |     for img_path in train_image:
 55 |         # img
 56 |         img = Image.open(img_path)
 57 |         img = img.resize((fix_image_width, fix_image_height), Image.ANTIALIAS)
 58 |         pix_array = np.array(img)
 59 | 
 60 |         train_img.append(pix_array)
 61 | 
 62 |         # label
 63 |         pid = int(img_path[-13:-9])
 64 |         cid = int(img_path[-15])
 65 |         pid = pid2label[pid]
 66 |         train_label.append(pid)
 67 |         train_cam.append(cid)
 68 |     return np.array(train_img), np.array(train_label), np.array(train_cam)
 69 | 
 70 | # rgb images
 71 | train_img, train_label, train_cam = read_imgs(files_rgb)
 72 | np.save(os.path.join(data_path, 'train_rgb_resized_img.npy'), train_img)
 73 | np.save(os.path.join(data_path, 'train_rgb_resized_label.npy'), train_label)
 74 | np.save(os.path.join(data_path, 'train_rgb_resized_cam.npy'), train_cam)
 75 | 
 76 | # ir images
 77 | train_img, train_label, train_cam = read_imgs(files_ir)
 78 | np.save(os.path.join(data_path, 'train_ir_resized_img.npy'), train_img)
 79 | np.save(os.path.join(data_path, 'train_ir_resized_label.npy'), train_label)
 80 | np.save(os.path.join(data_path, 'train_ir_resized_cam.npy'), train_cam)
 81 | 
--------------------------------------------------------------------------------
/DDAG/resnet.py:
--------------------------------------------------------------------------------
  1 | import torch.nn as nn
  2 | import math
  3 | import torch.utils.model_zoo as model_zoo
  4 | 
  5 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
  6 |            'resnet152']
  7 | 
  8 | model_urls = {
  9 |     'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
 10 |     'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
 11 |     'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
 12 |     'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
 13 |     'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
 14 | }
 15 | 
 16 | 
 17 | def conv3x3(in_planes,
out_planes, stride=1, dilation=1): 18 | """3x3 convolution with padding""" 19 | # original padding is 1; original dilation is 1 20 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 21 | padding=dilation, bias=False, dilation=dilation) 22 | 23 | 24 | class BasicBlock(nn.Module): 25 | expansion = 1 26 | 27 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1): 28 | super(BasicBlock, self).__init__() 29 | self.conv1 = conv3x3(inplanes, planes, stride, dilation) 30 | self.bn1 = nn.BatchNorm2d(planes) 31 | self.relu = nn.ReLU(inplace=True) 32 | self.conv2 = conv3x3(planes, planes) 33 | self.bn2 = nn.BatchNorm2d(planes) 34 | self.downsample = downsample 35 | self.stride = stride 36 | 37 | def forward(self, x): 38 | residual = x 39 | 40 | out = self.conv1(x) 41 | out = self.bn1(out) 42 | out = self.relu(out) 43 | 44 | out = self.conv2(out) 45 | out = self.bn2(out) 46 | 47 | if self.downsample is not None: 48 | residual = self.downsample(x) 49 | 50 | out += residual 51 | out = self.relu(out) 52 | 53 | return out 54 | 55 | 56 | class Bottleneck(nn.Module): 57 | expansion = 4 58 | 59 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1): 60 | super(Bottleneck, self).__init__() 61 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 62 | self.bn1 = nn.BatchNorm2d(planes) 63 | # original padding is 1; original dilation is 1 64 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=dilation, bias=False, dilation=dilation) 65 | self.bn2 = nn.BatchNorm2d(planes) 66 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 67 | self.bn3 = nn.BatchNorm2d(planes * 4) 68 | self.relu = nn.ReLU(inplace=True) 69 | self.downsample = downsample 70 | self.stride = stride 71 | 72 | def forward(self, x): 73 | residual = x 74 | 75 | out = self.conv1(x) 76 | out = self.bn1(out) 77 | out = self.relu(out) 78 | 79 | out = self.conv2(out) 80 | out = self.bn2(out) 81 | out = self.relu(out) 82 | 83 | out = self.conv3(out) 84 | out = self.bn3(out) 85 | 86 | if self.downsample is not None: 87 | residual = self.downsample(x) 88 | 89 | out += residual 90 | out = self.relu(out) 91 | 92 | return out 93 | 94 | 95 | class ResNet(nn.Module): 96 | 97 | def __init__(self, block, layers, last_conv_stride=2, last_conv_dilation=1): 98 | 99 | self.inplanes = 64 100 | super(ResNet, self).__init__() 101 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 102 | bias=False) 103 | self.bn1 = nn.BatchNorm2d(64) 104 | self.relu = nn.ReLU(inplace=True) 105 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 106 | self.layer1 = self._make_layer(block, 64, layers[0]) 107 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 108 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 109 | self.layer4 = self._make_layer(block, 512, layers[3], stride=last_conv_stride, dilation=last_conv_dilation) 110 | 111 | for m in self.modules(): 112 | if isinstance(m, nn.Conv2d): 113 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 114 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 115 | elif isinstance(m, nn.BatchNorm2d): 116 | m.weight.data.fill_(1) 117 | m.bias.data.zero_() 118 | 119 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1): 120 | downsample = None 121 | if stride != 1 or self.inplanes != planes * block.expansion: 122 | downsample = nn.Sequential( 123 | nn.Conv2d(self.inplanes, planes * block.expansion, 124 | kernel_size=1, stride=stride, bias=False), 125 | nn.BatchNorm2d(planes * block.expansion), 126 | ) 127 | 128 | layers = [] 129 | layers.append(block(self.inplanes, planes, stride, downsample, dilation)) 130 | self.inplanes = planes * block.expansion 131 | for i in range(1, blocks): 132 | layers.append(block(self.inplanes, planes)) 133 | 134 | return nn.Sequential(*layers) 135 | 136 | def forward(self, x): 137 | x = self.conv1(x) 138 | x = self.bn1(x) 139 | x = self.relu(x) 140 | x = self.maxpool(x) 141 | 142 | x = self.layer1(x) 143 | x = self.layer2(x) 144 | x = self.layer3(x) 145 | x = self.layer4(x) 146 | 147 | return x 148 | 149 | 150 | def remove_fc(state_dict): 151 | """Remove the fc layer parameters from state_dict.""" 152 | # for key, value in state_dict.items(): 153 | for key, value in list(state_dict.items()): 154 | if key.startswith('fc.'): 155 | del state_dict[key] 156 | return state_dict 157 | 158 | 159 | def resnet18(pretrained=False, **kwargs): 160 | """Constructs a ResNet-18 model. 161 | Args: 162 | pretrained (bool): If True, returns a model pre-trained on ImageNet 163 | """ 164 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 165 | if pretrained: 166 | model.load_state_dict(remove_fc(model_zoo.load_url(model_urls['resnet18']))) 167 | return model 168 | 169 | 170 | def resnet34(pretrained=False, **kwargs): 171 | """Constructs a ResNet-34 model. 172 | Args: 173 | pretrained (bool): If True, returns a model pre-trained on ImageNet 174 | """ 175 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 176 | if pretrained: 177 | model.load_state_dict(remove_fc(model_zoo.load_url(model_urls['resnet34']))) 178 | return model 179 | 180 | 181 | def resnet50(pretrained=False, **kwargs): 182 | """Constructs a ResNet-50 model. 183 | Args: 184 | pretrained (bool): If True, returns a model pre-trained on ImageNet 185 | """ 186 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 187 | if pretrained: 188 | # model.load_state_dict(remove_fc(model_zoo.load_url(model_urls['resnet50']))) 189 | model.load_state_dict(remove_fc(model_zoo.load_url(model_urls['resnet50']))) 190 | return model 191 | 192 | 193 | def resnet101(pretrained=False, **kwargs): 194 | """Constructs a ResNet-101 model. 195 | Args: 196 | pretrained (bool): If True, returns a model pre-trained on ImageNet 197 | """ 198 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 199 | if pretrained: 200 | model.load_state_dict( 201 | remove_fc(model_zoo.load_url(model_urls['resnet101']))) 202 | return model 203 | 204 | 205 | def resnet152(pretrained=False, **kwargs): 206 | """Constructs a ResNet-152 model. 
207 |     Args:
208 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
209 |     """
210 |     model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
211 |     if pretrained:
212 |         model.load_state_dict(
213 |             remove_fc(model_zoo.load_url(model_urls['resnet152'])))
214 |     return model
--------------------------------------------------------------------------------
/DDAG/test_ddag.py:
--------------------------------------------------------------------------------
  1 | from __future__ import print_function
  2 | import argparse
  3 | import sys
  4 | import torch
  5 | import torch.nn as nn
  6 | import torch.optim as optim
  7 | import torch.nn.functional as F
  8 | import torch.backends.cudnn as cudnn
  9 | from torch.autograd import Variable
 10 | import torch.utils.data as data
 11 | import torchvision
 12 | import torchvision.transforms as transforms
 13 | from data_loader import SYSUData, RegDBData, TestData
 14 | from data_manager import *
 15 | from eval_metrics import eval_sysu, eval_regdb
 16 | from model_main import embed_net
 17 | from utils import *
 18 | 
 19 | import time
 20 | import scipy.io as scio
 21 | 
 22 | parser = argparse.ArgumentParser(description='PyTorch Cross-Modality Training')
 23 | parser.add_argument('--dataset', default='sysu', help='dataset name: regdb or sysu')
 24 | parser.add_argument('--lr', default=0.01, type=float, help='learning rate')
 25 | parser.add_argument('--optim', default='sgd', type=str, help='optimizer')
 26 | parser.add_argument('--arch', default='resnet50', type=str, help='network baseline')
 27 | parser.add_argument('--resume', '-r', default='', type=str, help='resume from checkpoint')
 28 | parser.add_argument('--model_path', default='save_model/', type=str, help='model save path')
 29 | parser.add_argument('--log_path', default='log/', type=str, help='log save path')
 30 | parser.add_argument('--workers', default=4, type=int, metavar='N',
 31 |                     help='number of data loading workers (default: 4)')
 32 | parser.add_argument('--low-dim', default=512, type=int,
 33 |                     metavar='D', help='feature dimension')
 34 | parser.add_argument('--img_w', default=144, type=int,
 35 |                     metavar='imgw', help='img width')
 36 | parser.add_argument('--img_h', default=288, type=int,
 37 |                     metavar='imgh', help='img height')
 38 | parser.add_argument('--batch-size', default=32, type=int,
 39 |                     metavar='B', help='training batch size')
 40 | parser.add_argument('--part', default=3, type=int,
 41 |                     metavar='tb', help='part number')
 42 | parser.add_argument('--test-batch', default=64, type=int,
 43 |                     metavar='tb', help='testing batch size')
 44 | parser.add_argument('--method', default='id', type=str,
 45 |                     metavar='m', help='Method type')
 46 | parser.add_argument('--drop', default=0.0, type=float,
 47 |                     metavar='drop', help='dropout ratio')
 48 | parser.add_argument('--trial', default=1, type=int,
 49 |                     metavar='t', help='trial')
 50 | parser.add_argument('--gpu', default='0', type=str,
 51 |                     help='gpu device ids for CUDA_VISIBLE_DEVICES')
 52 | parser.add_argument('--mode', default='all', type=str, help='all or indoor')
 53 | parser.add_argument('--shot', default='single', type=str, help='single or multiple shot')
 54 | parser.add_argument('--graph', action='store_true', help='whether to add graph learning')
 55 | parser.add_argument('--wpa', action='store_true', help='whether to add weighted part attention')
 56 | args = parser.parse_args()
 57 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
 58 | np.random.seed(1)
 59 | dataset = args.dataset
 60 | if dataset == 'sysu':
 61 |     # TODO: define your data path for the SYSU-MM01 dataset
 62 | 
data_path = '../SYSU_MM01/' 63 | n_class = 395 64 | test_mode = [1, 2] 65 | elif dataset =='regdb': 66 | # TODO: define your data path for RegDB dataset 67 | data_path = 'YOUR DATA PATH' 68 | n_class = 206 69 | test_mode = [2, 1] 70 | 71 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 72 | best_acc = 0 # best test accuracy 73 | start_epoch = 0 74 | 75 | print('==> Building model..') 76 | net = embed_net(args.low_dim, n_class, drop=args.drop, part=args.part, arch=args.arch, wpa=args.wpa) 77 | net.to(device) 78 | cudnn.benchmark = True 79 | 80 | print('==> Resuming from checkpoint..') 81 | checkpoint_path = args.model_path 82 | if len(args.resume)>0: 83 | model_path = checkpoint_path + args.resume 84 | # model_path = checkpoint_path + 'test_best.t' 85 | if os.path.isfile(model_path): 86 | print('==> loading checkpoint {}'.format(args.resume)) 87 | checkpoint = torch.load(model_path) 88 | start_epoch = checkpoint['epoch'] 89 | # pdb.set_trace() 90 | net.load_state_dict(checkpoint['net']) 91 | print('==> loaded checkpoint {} (epoch {})' 92 | .format(args.resume, checkpoint['epoch'])) 93 | else: 94 | print('==> no checkpoint found at {}!!!!!!!!!!'.format(args.resume)) 95 | 96 | 97 | if args.method =='id': 98 | criterion = nn.CrossEntropyLoss() 99 | criterion.to(device) 100 | 101 | print('==> Loading data..') 102 | # Data loading code 103 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]) 104 | transform_train = transforms.Compose([ 105 | transforms.ToPILImage(), 106 | # transforms.Resize((280,150), interpolation=2), 107 | transforms.RandomCrop((args.img_h,args.img_w)), 108 | transforms.RandomHorizontalFlip(), 109 | transforms.ToTensor(), 110 | normalize, 111 | ]) 112 | 113 | transform_test = transforms.Compose([ 114 | transforms.ToPILImage(), 115 | transforms.Resize((args.img_h,args.img_w)), 116 | transforms.ToTensor(), 117 | normalize, 118 | ]) 119 | 120 | end = time.time() 121 | 122 | if dataset =='sysu': 123 | # testing set 124 | if args.shot == 'single': 125 | query_img, query_label, query_cam = process_query_sysu(data_path, mode = args.mode) 126 | gall_img, gall_label, gall_cam = process_gallery_sysu(data_path, mode = args.mode, trial = 0) 127 | 128 | nquery = len(query_label) 129 | ngall = len(gall_label) 130 | print("Dataset statistics:") 131 | print(" ------------------------------") 132 | print(" subset | # ids | # images") 133 | print(" ------------------------------") 134 | print(" query | {:5d} | {:8d}".format(len(np.unique(query_label)), nquery)) 135 | print(" gallery | {:5d} | {:8d}".format(len(np.unique(gall_label)), ngall)) 136 | print(" ------------------------------") 137 | 138 | 139 | queryset = TestData(query_img, query_label, transform = transform_test, img_size =(args.img_w, args.img_h)) 140 | query_loader = data.DataLoader(queryset, batch_size=args.test_batch, shuffle=False, num_workers=4) 141 | else: 142 | query_img, query_label, query_cam = process_query_sysu(data_path, mode = args.mode) 143 | gall_img, gall_label, gall_cam = process_gallery_sysu_multishot(data_path, mode = args.mode, trial = 0) 144 | 145 | nquery = len(query_label) 146 | ngall = len(gall_label) 147 | print("Dataset statistics:") 148 | print(" ------------------------------") 149 | print(" subset | # ids | # images") 150 | print(" ------------------------------") 151 | print(" query | {:5d} | {:8d}".format(len(np.unique(query_label)), nquery)) 152 | print(" gallery | {:5d} | {:8d}".format(len(np.unique(gall_label)), ngall)) 153 | print(" 
------------------------------") 154 | 155 | 156 | queryset = TestData(query_img, query_label, transform = transform_test, img_size =(args.img_w, args.img_h)) 157 | query_loader = data.DataLoader(queryset, batch_size=args.test_batch, shuffle=False, num_workers=4) 158 | 159 | elif dataset =='regdb': 160 | # training set 161 | trainset = RegDBData(data_path, args.trial, transform=transform_train) 162 | # generate the idx of each person identity 163 | color_pos, thermal_pos = GenIdx(trainset.train_color_label, trainset.train_thermal_label) 164 | 165 | # testing set 166 | query_img, query_label = process_test_regdb(data_path, trial = args.trial, modal = 'visible') 167 | gall_img, gall_label = process_test_regdb(data_path, trial = args.trial, modal = 'thermal') 168 | 169 | gallset = TestData(gall_img, gall_label, transform = transform_test, img_size =(args.img_w,args.img_h)) 170 | gall_loader = data.DataLoader(gallset, batch_size=args.test_batch, shuffle=False, num_workers=args.workers) 171 | 172 | print('Data Loading Time:\t {:.3f}'.format(time.time()-end)) 173 | 174 | feature_dim = 2048 175 | if args.arch =='resnet50': 176 | pool_dim = 2048 177 | elif args.arch =='resnet18': 178 | pool_dim = 512 179 | 180 | def extract_gall_feat(gall_loader): 181 | net.eval() 182 | print ('Extracting Gallery Feature...') 183 | start = time.time() 184 | ptr = 0 185 | gall_feat = np.zeros((ngall, feature_dim)) 186 | gall_feat_att = np.zeros((ngall, pool_dim)) 187 | with torch.no_grad(): 188 | for batch_idx, (input, label ) in enumerate(gall_loader): 189 | batch_num = input.size(0) 190 | input = Variable(input.cuda()) 191 | feat, feat_att = net(input, input, 0, test_mode[0]) 192 | gall_feat[ptr:ptr+batch_num,: ] = feat.detach().cpu().numpy() 193 | gall_feat_att[ptr:ptr+batch_num,: ] = feat_att.detach().cpu().numpy() 194 | ptr = ptr + batch_num 195 | print('Extracting Time:\t {:.3f}'.format(time.time()-start)) 196 | return gall_feat, gall_feat_att 197 | 198 | def extract_query_feat(query_loader): 199 | net.eval() 200 | print ('Extracting Query Feature...') 201 | start = time.time() 202 | ptr = 0 203 | query_feat = np.zeros((nquery, feature_dim)) 204 | query_feat_att = np.zeros((nquery, pool_dim)) 205 | with torch.no_grad(): 206 | for batch_idx, (input, label ) in enumerate(query_loader): 207 | batch_num = input.size(0) 208 | input = Variable(input.cuda()) 209 | feat, feat_att = net(input, input, 0, test_mode[1]) 210 | query_feat[ptr:ptr+batch_num,: ] = feat.detach().cpu().numpy() 211 | query_feat_att[ptr:ptr+batch_num,: ] = feat_att.detach().cpu().numpy() 212 | ptr = ptr + batch_num 213 | print('Extracting Time:\t {:.3f}'.format(time.time()-start)) 214 | return query_feat, query_feat_att 215 | 216 | query_feat, query_feat_att = extract_query_feat(query_loader) 217 | 218 | all_cmc = 0 219 | all_mAP = 0 220 | all_cmc_pool = 0 221 | 222 | if args.shot == 'single': 223 | for trial in range(10): 224 | gall_img, gall_label, gall_cam = process_gallery_sysu(data_path, mode = args.mode, trial = trial) 225 | 226 | trial_gallset = TestData(gall_img, gall_label, transform = transform_test,img_size =(args.img_w,args.img_h)) 227 | trial_gall_loader = data.DataLoader(trial_gallset, batch_size=args.test_batch, shuffle=False, num_workers=4) 228 | 229 | gall_feat, gall_feat_att = extract_gall_feat(trial_gall_loader) 230 | 231 | # fc feature 232 | distmat = np.matmul(query_feat, np.transpose(gall_feat)) 233 | cmc, mAP, mINP = eval_sysu(-distmat, query_label, gall_label,query_cam, gall_cam) 234 | 235 | # attention feature 236 | 
distmat_att = np.matmul(query_feat_att, np.transpose(gall_feat_att)) 237 | cmc_att, mAP_att, mINP_att = eval_sysu(-distmat_att, query_label, gall_label,query_cam, gall_cam) 238 | if trial ==0: 239 | all_cmc = cmc 240 | all_mAP = mAP 241 | all_mINP = mINP 242 | all_cmc_att = cmc_att 243 | all_mAP_att = mAP_att 244 | all_mINP_att = mINP_att 245 | else: 246 | all_cmc = all_cmc + cmc 247 | all_mAP = all_mAP + mAP 248 | all_mINP = all_mINP + mINP 249 | all_cmc_att = all_cmc_att + cmc_att 250 | all_mAP_att = all_mAP_att + mAP_att 251 | all_mINP_att = all_mINP_att + mINP_att 252 | 253 | print('Test Trial: {}'.format(trial)) 254 | print('FC: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 255 | cmc[0], cmc[4], cmc[9], cmc[19], mAP, mINP)) 256 | print('FC_att: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 257 | cmc_att[0], cmc_att[4], cmc_att[9], cmc_att[19], mAP_att, mINP_att)) 258 | 259 | cmc = all_cmc /10 260 | mAP = all_mAP /10 261 | mINP = all_mINP /10 262 | 263 | cmc_att = all_cmc_att /10 264 | mAP_att = all_mAP_att /10 265 | mINP_att = all_mINP_att /10 266 | print ('All Average:') 267 | print('FC: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 268 | cmc[0], cmc[4], cmc[9], cmc[19], mAP, mINP)) 269 | print('FC_att: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 270 | cmc_att[0], cmc_att[4], cmc_att[9], cmc_att[19], mAP_att, mINP_att)) 271 | else: 272 | for trial in range(10): 273 | gall_img, gall_label, gall_cam = process_gallery_sysu_multishot(data_path, mode = args.mode, trial = trial) 274 | 275 | trial_gallset = TestData(gall_img, gall_label, transform = transform_test,img_size =(args.img_w,args.img_h)) 276 | trial_gall_loader = data.DataLoader(trial_gallset, batch_size=args.test_batch, shuffle=False, num_workers=4) 277 | 278 | gall_feat, gall_feat_att = extract_gall_feat(trial_gall_loader) 279 | 280 | # fc feature 281 | distmat = np.matmul(query_feat, np.transpose(gall_feat)) 282 | cmc, mAP, mINP = eval_sysu(-distmat, query_label, gall_label,query_cam, gall_cam) 283 | 284 | # attention feature 285 | distmat_att = np.matmul(query_feat_att, np.transpose(gall_feat_att)) 286 | cmc_att, mAP_att, mINP_att = eval_sysu(-distmat_att, query_label, gall_label,query_cam, gall_cam) 287 | if trial ==0: 288 | all_cmc = cmc 289 | all_mAP = mAP 290 | all_mINP = mINP 291 | all_cmc_att = cmc_att 292 | all_mAP_att = mAP_att 293 | all_mINP_att = mINP_att 294 | else: 295 | all_cmc = all_cmc + cmc 296 | all_mAP = all_mAP + mAP 297 | all_mINP = all_mINP + mINP 298 | all_cmc_att = all_cmc_att + cmc_att 299 | all_mAP_att = all_mAP_att + mAP_att 300 | all_mINP_att = all_mINP_att + mINP_att 301 | 302 | print('Test Trial: {}'.format(trial)) 303 | print('FC: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 304 | cmc[0], cmc[4], cmc[9], cmc[19], mAP, mINP)) 305 | print('FC_att: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 306 | cmc_att[0], cmc_att[4], cmc_att[9], cmc_att[19], mAP_att, mINP_att)) 307 | 308 | cmc = all_cmc /10 309 | mAP = all_mAP /10 310 | mINP = all_mINP /10 311 | 312 | cmc_att = all_cmc_att /10 313 | mAP_att = all_mAP_att /10 314 | mINP_att = all_mINP_att /10 315 | print ('All Average:') 316 | print('FC: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| 
mAP: {:.2%}| mINP: {:.2%}'.format(
317 |         cmc[0], cmc[4], cmc[9], cmc[19], mAP, mINP))
318 |     print('FC_att: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format(
319 |         cmc_att[0], cmc_att[4], cmc_att[9], cmc_att[19], mAP_att, mINP_att))
320 | 
321 | 
--------------------------------------------------------------------------------
/DDAG/train_ddag.py:
--------------------------------------------------------------------------------
  1 | from __future__ import print_function
  2 | import argparse
  3 | import sys
  4 | import time
  5 | import torch
  6 | import torch.nn as nn
  7 | import torch.optim as optim
  8 | import torch.backends.cudnn as cudnn
  9 | from torch.autograd import Variable
 10 | import torch.utils.data as data
 11 | import torchvision
 12 | import torchvision.transforms as transforms
 13 | from data_loader import SYSUData, RegDBData, TestData
 14 | from data_manager import *
 15 | from eval_metrics import eval_sysu, eval_regdb
 16 | from model_main import embed_net
 17 | from utils import *
 18 | from loss import OriTripletLoss
 19 | from torch.optim import lr_scheduler
 20 | from tensorboardX import SummaryWriter
 21 | import torch.nn.functional as F
 22 | import math
 23 | 
 24 | parser = argparse.ArgumentParser(description='PyTorch Cross-Modality Training')
 25 | parser.add_argument('--dataset', default='sysu', help='dataset name: regdb or sysu')
 26 | parser.add_argument('--lr', default=0.1, type=float, help='learning rate, 0.00035 for adam')
 27 | parser.add_argument('--optim', default='sgd', type=str, help='optimizer')
 28 | parser.add_argument('--arch', default='resnet50', type=str,
 29 |                     help='network baseline: resnet50')
 30 | parser.add_argument('--resume', '-r', default='', type=str,
 31 |                     help='resume from checkpoint')
 32 | parser.add_argument('--test-only', action='store_true', help='test only')
 33 | parser.add_argument('--model_path', default='save_model/', type=str,
 34 |                     help='model save path')
 35 | parser.add_argument('--save_epoch', default=20, type=int,
 36 |                     metavar='s', help='save model every 20 epochs')
 37 | parser.add_argument('--log_path', default='log/', type=str,
 38 |                     help='log save path')
 39 | parser.add_argument('--vis_log_path', default='log/vis_log_ddag/', type=str,
 40 |                     help='tensorboard log save path')
 41 | parser.add_argument('--workers', default=4, type=int, metavar='N',
 42 |                     help='number of data loading workers (default: 4)')
 43 | parser.add_argument('--low-dim', default=512, type=int,
 44 |                     metavar='D', help='feature dimension')
 45 | parser.add_argument('--img_w', default=144, type=int,
 46 |                     metavar='imgw', help='img width')
 47 | parser.add_argument('--img_h', default=288, type=int,
 48 |                     metavar='imgh', help='img height')
 49 | parser.add_argument('--batch-size', default=8, type=int,
 50 |                     metavar='B', help='training batch size')
 51 | parser.add_argument('--test-batch', default=64, type=int,
 52 |                     metavar='tb', help='testing batch size')
 53 | parser.add_argument('--part', default=3, type=int,
 54 |                     metavar='tb', help='part number')
 55 | parser.add_argument('--method', default='id+tri', type=str,
 56 |                     metavar='m', help='method type')
 57 | parser.add_argument('--drop', default=0.2, type=float,
 58 |                     metavar='drop', help='dropout ratio')
 59 | parser.add_argument('--margin', default=0.3, type=float,
 60 |                     metavar='margin', help='triplet loss margin')
 61 | parser.add_argument('--num_pos', default=4, type=int,
 62 |                     help='num of pos per identity in each modality')
 63 | parser.add_argument('--trial', default=1, type=int,
 64 |                     metavar='t', help='trial (only
for RegDB dataset)') 65 | parser.add_argument('--seed', default=0, type=int, 66 | metavar='t', help='random seed') 67 | parser.add_argument('--gpu', default='0', type=str, 68 | help='gpu device ids for CUDA_VISIBLE_DEVICES') 69 | parser.add_argument('--mode', default='all', type=str, help='all or indoor') 70 | parser.add_argument('--lambda0', default=1.0, type=float, 71 | metavar='lambda0', help='graph attention weights') 72 | parser.add_argument('--graph', action='store_true', help='either add graph attention or not') 73 | parser.add_argument('--wpa', action='store_true', help='either add weighted part attention') 74 | 75 | args = parser.parse_args() 76 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 77 | 78 | set_seed(args.seed) 79 | 80 | dataset = args.dataset 81 | if dataset == 'sysu': 82 | # TODO: define your data path 83 | data_path = '../SYSU_MM01/' 84 | log_path = args.log_path + 'sysu_log_ddag/' 85 | test_mode = [1, 2] # infrared to visible 86 | elif dataset =='regdb': 87 | # TODO: define your data path for RegDB dataset 88 | data_path = '../RegDB/' 89 | log_path = args.log_path + 'regdb_log_ddag/' 90 | test_mode = [2, 1] # visible to infrared 91 | 92 | checkpoint_path = args.model_path 93 | 94 | if not os.path.isdir(log_path): 95 | os.makedirs(log_path) 96 | if not os.path.isdir(checkpoint_path): 97 | os.makedirs(checkpoint_path) 98 | if not os.path.isdir(args.vis_log_path): 99 | os.makedirs(args.vis_log_path) 100 | 101 | # log file name 102 | suffix = dataset 103 | if args.graph: 104 | suffix = suffix + '_G' 105 | if args.wpa: 106 | suffix = suffix + '_P_{}'.format(args.part) 107 | suffix = suffix + '_drop_{}_{}_{}_lr_{}_seed_{}'.format(args.drop, args.num_pos, args.batch_size, args.lr, args.seed) 108 | if not args.optim == 'sgd': 109 | suffix = suffix + '_' + args.optim 110 | if dataset == 'regdb': 111 | suffix = suffix + '_trial_{}'.format(args.trial) 112 | 113 | test_log_file = open(log_path + suffix + '.txt', "w") 114 | sys.stdout = Logger(log_path + suffix + '_os.txt') 115 | 116 | vis_log_dir = args.vis_log_path + suffix + '/' 117 | 118 | if not os.path.isdir(vis_log_dir): 119 | os.makedirs(vis_log_dir) 120 | writer = SummaryWriter(vis_log_dir) 121 | print("==========\nArgs:{}\n==========".format(args)) 122 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 123 | best_acc = 0 # best test accuracy 124 | start_epoch = 0 125 | feature_dim = args.low_dim 126 | wG = 0 127 | end = time.time() 128 | 129 | print('==> Loading data..') 130 | # Data loading code 131 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 132 | transform_train = transforms.Compose([ 133 | transforms.ToPILImage(), 134 | transforms.Pad(10), 135 | transforms.RandomCrop((args.img_h, args.img_w)), 136 | transforms.RandomHorizontalFlip(), 137 | transforms.ToTensor(), 138 | normalize, 139 | ]) 140 | transform_test = transforms.Compose([ 141 | transforms.ToPILImage(), 142 | transforms.Resize((args.img_h, args.img_w)), 143 | transforms.ToTensor(), 144 | normalize, 145 | ]) 146 | 147 | 148 | if dataset == 'sysu': 149 | # training set 150 | trainset = SYSUData(data_path, transform=transform_train) 151 | # generate the idx of each person identity 152 | color_pos, thermal_pos = GenIdx(trainset.train_color_label, trainset.train_thermal_label) 153 | 154 | # testing set 155 | query_img, query_label, query_cam = process_query_sysu(data_path, mode=args.mode) 156 | gall_img, gall_label, gall_cam = process_gallery_sysu(data_path, mode=args.mode, trial=0) 157 | 158 | elif dataset == 
'regdb': 159 | # training set 160 | trainset = RegDBData(data_path, args.trial, transform=transform_train) 161 | # generate the idx of each person identity 162 | color_pos, thermal_pos = GenIdx(trainset.train_color_label, trainset.train_thermal_label) 163 | 164 | # testing set 165 | query_img, query_label = process_test_regdb(data_path, trial=args.trial, modal='visible') 166 | gall_img, gall_label = process_test_regdb(data_path, trial=args.trial, modal='thermal') 167 | 168 | gallset = TestData(gall_img, gall_label, transform=transform_test, img_size=(args.img_w, args.img_h)) 169 | queryset = TestData(query_img, query_label, transform=transform_test, img_size=(args.img_w, args.img_h)) 170 | 171 | # testing data loader 172 | gall_loader = data.DataLoader(gallset, batch_size=args.test_batch, shuffle=False, num_workers=args.workers) 173 | query_loader = data.DataLoader(queryset, batch_size=args.test_batch, shuffle=False, num_workers=args.workers) 174 | 175 | n_class = len(np.unique(trainset.train_color_label)) 176 | nquery = len(query_label) 177 | ngall = len(gall_label) 178 | 179 | print('Dataset {} statistics:'.format(dataset)) 180 | print(' ------------------------------') 181 | print(' subset | # ids | # images') 182 | print(' ------------------------------') 183 | print(' visible | {:5d} | {:8d}'.format(n_class, len(trainset.train_color_label))) 184 | print(' thermal | {:5d} | {:8d}'.format(n_class, len(trainset.train_thermal_label))) 185 | print(' ------------------------------') 186 | print(' query | {:5d} | {:8d}'.format(len(np.unique(query_label)), nquery)) 187 | print(' gallery | {:5d} | {:8d}'.format(len(np.unique(gall_label)), ngall)) 188 | print(' ------------------------------') 189 | print('Data Loading Time:\t {:.3f}'.format(time.time() - end)) 190 | 191 | print('==> Building model..') 192 | net = embed_net(args.low_dim, n_class, drop=args.drop, part=args.part, arch=args.arch, wpa=args.wpa) 193 | net.to(device) 194 | cudnn.benchmark = True 195 | 196 | def remove_fc(state_dict): 197 | """Remove the fc layer parameters from state_dict.""" 198 | # for key, value in state_dict.items(): 199 | for key, value in list(state_dict.items()): 200 | if key.startswith('fc1.') or key.startswith('fc2.') or key.startswith('local_conv_list'): 201 | del state_dict[key] 202 | return state_dict 203 | 204 | if len(args.resume) > 0: 205 | model_path = checkpoint_path + args.resume 206 | if os.path.isfile(model_path): 207 | print('==> loading checkpoint {}'.format(args.resume)) 208 | net.load_state_dict(remove_fc(torch.load(model_path, map_location=torch.device('cpu'))), strict=False) 209 | print('==> loaded checkpoint {}' 210 | .format(args.resume)) 211 | else: 212 | print('==> no checkpoint found at {}'.format(args.resume)) 213 | 214 | # define loss function 215 | criterion1 = nn.CrossEntropyLoss() 216 | loader_batch = args.batch_size * args.num_pos 217 | criterion2 = OriTripletLoss(batch_size=loader_batch, margin=args.margin) 218 | criterion1.to(device) 219 | criterion2.to(device) 220 | 221 | # optimizer 222 | if args.optim == 'sgd': 223 | ignored_params = list(map(id, net.bottleneck.parameters())) \ 224 | + list(map(id, net.classifier.parameters())) \ 225 | + list(map(id, net.wpa.parameters())) \ 226 | + list(map(id, net.attention_0.parameters())) \ 227 | + list(map(id, net.attention_1.parameters())) \ 228 | + list(map(id, net.attention_2.parameters())) \ 229 | + list(map(id, net.attention_3.parameters())) \ 230 | + list(map(id, net.out_att.parameters())) 231 | 232 | base_params = filter(lambda p: 
id(p) not in ignored_params, net.parameters()) 233 | 234 | optimizer_P = optim.SGD([ 235 | {'params': base_params, 'lr': 0.1 * args.lr}, 236 | {'params': net.bottleneck.parameters(), 'lr': args.lr}, 237 | {'params': net.classifier.parameters(), 'lr': args.lr}, 238 | {'params': net.wpa.parameters(), 'lr': args.lr}, 239 | {'params': net.attention_0.parameters(), 'lr': args.lr}, 240 | {'params': net.attention_1.parameters(), 'lr': args.lr}, 241 | {'params': net.attention_2.parameters(), 'lr': args.lr}, 242 | {'params': net.attention_3.parameters(), 'lr': args.lr}, 243 | {'params': net.out_att.parameters(), 'lr': args.lr} ,], 244 | weight_decay=5e-4, momentum=0.9, nesterov=True) 245 | 246 | optimizer_G = optim.SGD([ 247 | {'params': net.attention_0.parameters(), 'lr': args.lr}, 248 | {'params': net.attention_1.parameters(), 'lr': args.lr}, 249 | {'params': net.attention_2.parameters(), 'lr': args.lr}, 250 | {'params': net.attention_3.parameters(), 'lr': args.lr}, 251 | {'params': net.out_att.parameters(), 'lr': args.lr}, ], 252 | weight_decay=5e-4, momentum=0.9, nesterov=True) 253 | 254 | # exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1) 255 | def adjust_learning_rate(optimizer_P, optimizer_G, epoch): 256 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" 257 | if epoch < 10: 258 | lr = args.lr * (epoch + 1) / 10 259 | elif epoch >= 10 and epoch < 30: 260 | lr = args.lr 261 | elif epoch >= 30 and epoch < 50: 262 | lr = args.lr * 0.1 263 | elif epoch >= 50: 264 | lr = args.lr * 0.01 265 | 266 | optimizer_P.param_groups[0]['lr'] = 0.1 * lr 267 | for i in range(len(optimizer_P.param_groups) - 1): 268 | optimizer_P.param_groups[i + 1]['lr'] = lr 269 | return lr 270 | 271 | 272 | def train(epoch, wG): 273 | # adjust learning rate 274 | current_lr = adjust_learning_rate(optimizer_P, optimizer_G, epoch) 275 | train_loss = AverageMeter() 276 | id_loss = AverageMeter() 277 | tri_loss = AverageMeter() 278 | graph_loss = AverageMeter() 279 | data_time = AverageMeter() 280 | batch_time = AverageMeter() 281 | correct = 0 282 | total = 0 283 | 284 | # switch to train mode 285 | net.train() 286 | end = time.time() 287 | 288 | for batch_idx, (input1, input2, label1, label2) in enumerate(trainloader): 289 | 290 | labels = torch.cat((label1, label2), 0) 291 | 292 | 293 | # Graph construction 294 | # one_hot = F.one_hot(labels, num_classes=n_class) # for version > 1.2 295 | one_hot = torch.index_select(torch.eye(n_class), dim = 0, index = labels) 296 | # Compute A in Eq. (6) 297 | adj = torch.mm(one_hot, torch.transpose(one_hot, 0, 1)).float() + torch.eye(labels.size()[0]).float() 298 | w_norm = adj.pow(2).sum(1, keepdim=True).pow(1. / 2) 299 | adj_norm = adj.div(w_norm) # normalized adjacency matrix 300 | 301 | input1 = Variable(input1.cuda()) 302 | input2 = Variable(input2.cuda()) 303 | 304 | labels = Variable(labels.cuda()) 305 | adj_norm = Variable(adj_norm.cuda()) 306 | data_time.update(time.time() - end) 307 | 308 | # Forward into the network 309 | feat, out0, out_att, output = net(input1, input2, adj_norm) 310 | 311 | # baseline loss: identity loss + triplet loss Eq. (1) 312 | loss_id = criterion1(out0, labels) 313 | loss_tri, batch_acc = criterion2(feat, labels) 314 | correct += (batch_acc / 2) 315 | _, predicted = out0.max(1) 316 | correct += (predicted.eq(labels).sum().item() / 2) 317 | 318 | # Part attention loss 319 | loss_p = criterion1(out_att, labels) 320 | 321 | # Graph attention loss Eq. 
(9) 322 | loss_G = F.nll_loss(output, labels) 323 | 324 | # Instance-level part-aggregated feature learning Eq. (10) 325 | loss = loss_id + loss_tri + loss_p 326 | # Overall loss Eq. (11) 327 | loss_total = loss + wG * loss_G 328 | 329 | # optimization 330 | optimizer_P.zero_grad() 331 | loss_total.backward() 332 | optimizer_P.step() 333 | 334 | # log different loss components 335 | train_loss.update(loss.item(), 2 * input1.size(0)) 336 | id_loss.update(loss_id.item(), 2 * input1.size(0)) 337 | tri_loss.update(loss_tri.item(), 2 * input1.size(0)) 338 | graph_loss.update(loss_G.item(), 2 * input1.size(0)) 339 | total += labels.size(0) 340 | 341 | # measure elapsed time 342 | batch_time.update(time.time() - end) 343 | end = time.time() 344 | if batch_idx % 50 == 0: 345 | print('Epoch: [{}][{}/{}] ' 346 | 'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) ' 347 | 'lr:{} ' 348 | 'Loss: {train_loss.val:.4f} ({train_loss.avg:.4f}) ' 349 | 'iLoss: {id_loss.val:.4f} ({id_loss.avg:.4f}) ' 350 | 'TLoss: {tri_loss.val:.4f} ({tri_loss.avg:.4f}) ' 351 | 'GLoss: {graph_loss.val:.4f} ({graph_loss.avg:.4f}) ' 352 | 'Accu: {:.2f}'.format( 353 | epoch, batch_idx, len(trainloader), current_lr, 354 | 100. * correct / total, batch_time=batch_time, 355 | train_loss=train_loss, id_loss=id_loss, tri_loss=tri_loss, graph_loss=graph_loss)) 356 | 357 | writer.add_scalar('total_loss', train_loss.avg, epoch) 358 | writer.add_scalar('id_loss', id_loss.avg, epoch) 359 | writer.add_scalar('tri_loss', tri_loss.avg, epoch) 360 | writer.add_scalar('graph_loss', graph_loss.avg, epoch) 361 | writer.add_scalar('lr', current_lr, epoch) 362 | # computer wG 363 | return 1. / (1. + train_loss.avg) 364 | 365 | def test(epoch): 366 | # switch to evaluation mode 367 | net.eval() 368 | print('Extracting Gallery Feature...') 369 | start = time.time() 370 | ptr = 0 371 | gall_feat = np.zeros((ngall, 2048)) 372 | gall_feat_att = np.zeros((ngall, 2048)) 373 | with torch.no_grad(): 374 | for batch_idx, (input, label) in enumerate(gall_loader): 375 | batch_num = input.size(0) 376 | input = Variable(input.cuda()) 377 | feat, feat_att = net(input, input, 0, test_mode[0]) 378 | gall_feat[ptr:ptr + batch_num, :] = feat.detach().cpu().numpy() 379 | gall_feat_att[ptr:ptr + batch_num, :] = feat_att.detach().cpu().numpy() 380 | ptr = ptr + batch_num 381 | print('Extracting Time:\t {:.3f}'.format(time.time() - start)) 382 | 383 | # switch to evaluation 384 | net.eval() 385 | print('Extracting Query Feature...') 386 | start = time.time() 387 | ptr = 0 388 | query_feat = np.zeros((nquery, 2048)) 389 | query_feat_att = np.zeros((nquery, 2048)) 390 | with torch.no_grad(): 391 | for batch_idx, (input, label) in enumerate(query_loader): 392 | batch_num = input.size(0) 393 | input = Variable(input.cuda()) 394 | feat, feat_att = net(input, input, 0, test_mode[1]) 395 | query_feat[ptr:ptr + batch_num, :] = feat.detach().cpu().numpy() 396 | query_feat_att[ptr:ptr + batch_num, :] = feat_att.detach().cpu().numpy() 397 | ptr = ptr + batch_num 398 | print('Extracting Time:\t {:.3f}'.format(time.time() - start)) 399 | 400 | start = time.time() 401 | # compute the similarity 402 | distmat = np.matmul(query_feat, np.transpose(gall_feat)) 403 | distmat_att = np.matmul(query_feat_att, np.transpose(gall_feat_att)) 404 | 405 | # evaluation 406 | if dataset == 'regdb': 407 | cmc, mAP, mINP = eval_regdb(-distmat, query_label, gall_label) 408 | cmc_att, mAP_att, mINP_att = eval_regdb(-distmat_att, query_label, gall_label) 409 | elif dataset == 'sysu': 410 | cmc, mAP, mINP 
= eval_sysu(-distmat, query_label, gall_label, query_cam, gall_cam)
411 |         cmc_att, mAP_att, mINP_att = eval_sysu(-distmat_att, query_label, gall_label, query_cam, gall_cam)
412 |     print('Evaluation Time:\t {:.3f}'.format(time.time() - start))
413 | 
414 |     writer.add_scalar('rank1', cmc[0], epoch)
415 |     writer.add_scalar('mAP', mAP, epoch)
416 |     writer.add_scalar('rank1_att', cmc_att[0], epoch)
417 |     writer.add_scalar('mAP_att', mAP_att, epoch)
418 |     writer.add_scalar('mINP', mINP, epoch)
419 |     writer.add_scalar('mINP_att', mINP_att, epoch)
420 |     return cmc, mAP, mINP, cmc_att, mAP_att, mINP_att
421 | 
422 | 
423 | # training
424 | print('==> Start Training...')
425 | for epoch in range(start_epoch, 81 - start_epoch):
426 | 
427 |     print('==> Preparing Data Loader...')
428 |     # identity sampler:
429 |     sampler = IdentitySampler(trainset.train_color_label, \
430 |                               trainset.train_thermal_label, color_pos, thermal_pos, args.num_pos, args.batch_size,
431 |                               epoch)
432 | 
433 |     trainset.cIndex = sampler.index1  # color index
434 |     trainset.tIndex = sampler.index2  # infrared index
435 |     print(epoch)
436 |     print(trainset.cIndex)
437 |     print(trainset.tIndex)
438 | 
439 |     loader_batch = args.batch_size * args.num_pos
440 | 
441 |     trainloader = data.DataLoader(trainset, batch_size=loader_batch, \
442 |                                   sampler=sampler, num_workers=args.workers, drop_last=True)
443 | 
444 |     # training
445 |     wG = train(epoch, wG)
446 | 
447 |     if epoch > 0 and epoch % 2 == 0:
448 |         print('Test Epoch: {}'.format(epoch))
449 |         print('Test Epoch: {}'.format(epoch), file=test_log_file)
450 | 
451 |         # testing
452 |         cmc, mAP, mINP, cmc_att, mAP_att, mINP_att = test(epoch)
453 |         # log output
454 |         print('FC: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format(
455 |             cmc[0], cmc[4], cmc[9], cmc[19], mAP, mINP))
456 |         print('FC: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format(
457 |             cmc[0], cmc[4], cmc[9], cmc[19], mAP, mINP), file=test_log_file)
458 | 
459 |         print('FC_att: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format(
460 |             cmc_att[0], cmc_att[4], cmc_att[9], cmc_att[19], mAP_att, mINP_att))
461 |         print('FC_att: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format(
462 |             cmc_att[0], cmc_att[4], cmc_att[9], cmc_att[19], mAP_att, mINP_att), file=test_log_file)
463 |         test_log_file.flush()
464 | 
465 |         # save model
466 |         if cmc_att[0] > best_acc:  # not the real best for sysu-mm01
467 |             best_acc = cmc_att[0]
468 |             best_epoch = epoch
469 |             state = {
470 |                 'net': net.state_dict(),
471 |                 'cmc': cmc_att,
472 |                 'mAP': mAP_att,
473 |                 'epoch': epoch,
474 |             }
475 |             torch.save(state, checkpoint_path + suffix + '_best.t')
476 |         print('Best Epoch [{}]'.format(best_epoch))
--------------------------------------------------------------------------------
/DDAG/utils.py:
--------------------------------------------------------------------------------
  1 | import os
  2 | from collections import defaultdict
  3 | import numbers
  4 | import numpy as np
  5 | from torch.utils.data.sampler import Sampler
  6 | import sys
  7 | import os.path as osp
  8 | import scipy.io as scio
  9 | import torch
 10 | 
 11 | def load_data(input_data_path):
 12 |     with open(input_data_path, 'rt') as f:
 13 |         data_file_list = f.read().splitlines()
 14 |         # Get full list of color image and labels
 15 |         file_image = [s.split(' ')[0] for s in data_file_list]
 16 |         file_label = [int(s.split(' ')[1]) for s in data_file_list]
 17 | 
 18 |     return file_image, file_label
 19 | 
 20 | 
 21 | def GenIdx(train_color_label, train_thermal_label):
 22 |     color_pos = []
 23 |     unique_label_color = np.unique(train_color_label)
 24 |     for i in range(len(unique_label_color)):
 25 |         tmp_pos = [k for k, v in enumerate(train_color_label) if v == unique_label_color[i]]
 26 |         color_pos.append(tmp_pos)
 27 | 
 28 |     thermal_pos = []
 29 |     unique_label_thermal = np.unique(train_thermal_label)
 30 |     for i in range(len(unique_label_thermal)):
 31 |         tmp_pos = [k for k, v in enumerate(train_thermal_label) if v == unique_label_thermal[i]]
 32 |         thermal_pos.append(tmp_pos)
 33 |     return color_pos, thermal_pos
 34 | 
 35 | def GenCamIdx(gall_img, gall_label, mode):
 36 |     if mode == 'indoor':
 37 |         camIdx = [1, 2]
 38 |     else:
 39 |         camIdx = [1, 2, 4, 5]
 40 |     gall_cam = []
 41 |     for i in range(len(gall_img)):
 42 |         gall_cam.append(int(gall_img[i][-10]))
 43 | 
 44 |     sample_pos = []
 45 |     unique_label = np.unique(gall_label)
 46 |     for i in range(len(unique_label)):
 47 |         for j in range(len(camIdx)):
 48 |             id_pos = [k for k, v in enumerate(gall_label) if v == unique_label[i] and gall_cam[k] == camIdx[j]]
 49 |             if id_pos:
 50 |                 sample_pos.append(id_pos)
 51 |     return sample_pos
 52 | 
 53 | def ExtractCam(gall_img):
 54 |     gall_cam = []
 55 |     for i in range(len(gall_img)):
 56 |         cam_id = int(gall_img[i][-10])
 57 |         # if cam_id == 3:
 58 |         #     cam_id = 2
 59 |         gall_cam.append(cam_id)
 60 | 
 61 |     return np.array(gall_cam)
 62 | 
 63 | class IdentitySampler(Sampler):
 64 |     """Sample person identities evenly in each batch.
 65 |     Args:
 66 |         train_color_label, train_thermal_label: labels of the two modalities
 67 |         color_pos, thermal_pos: positions of each identity
 68 |         batchSize: batch size
 69 |     """
 70 | 
 71 |     def __init__(self, train_color_label, train_thermal_label, color_pos, thermal_pos, num_pos, batchSize, epoch):
 72 |         uni_label = np.unique(train_color_label)
 73 |         self.n_classes = len(uni_label)
 74 | 
 75 | 
 76 |         N = np.maximum(len(train_color_label), len(train_thermal_label))
 77 |         for j in range(int(N / (batchSize * num_pos)) + 1):
 78 |             batch_idx = np.random.choice(uni_label, batchSize, replace=False)
 79 |             for i in range(batchSize):
 80 |                 sample_color = np.random.choice(color_pos[batch_idx[i]], num_pos)
 81 |                 sample_thermal = np.random.choice(thermal_pos[batch_idx[i]], num_pos)
 82 | 
 83 |                 if j == 0 and i == 0:
 84 |                     index1 = sample_color
 85 |                     index2 = sample_thermal
 86 |                 else:
 87 |                     index1 = np.hstack((index1, sample_color))
 88 |                     index2 = np.hstack((index2, sample_thermal))
 89 | 
 90 |         self.index1 = index1
 91 |         self.index2 = index2
 92 |         self.N = N
 93 | 
 94 |     def __iter__(self):
 95 |         return iter(np.arange(len(self.index1)))
 96 | 
 97 |     def __len__(self):
 98 |         return self.N
 99 | 
100 | class AverageMeter(object):
101 |     """Computes and stores the average and current value"""
102 |     def __init__(self):
103 |         self.reset()
104 | 
105 |     def reset(self):
106 |         self.val = 0
107 |         self.avg = 0
108 |         self.sum = 0
109 |         self.count = 0
110 | 
111 |     def update(self, val, n=1):
112 |         self.val = val
113 |         self.sum += val * n
114 |         self.count += n
115 |         self.avg = self.sum / self.count
116 | 
117 | def mkdir_if_missing(directory):
118 |     if not osp.exists(directory):
119 |         # exist_ok guards the race where the directory appears between the
120 |         # check and the creation (the original errno.EEXIST check referenced
121 |         # the errno module without importing it)
122 |         os.makedirs(directory, exist_ok=True)
123 | 
124 | class Logger(object):
125 |     """
126 |     Write console output to external text file.
127 |     Code imported from https://github.com/Cysu/open-reid/blob/master/reid/utils/logging.py.
128 | """ 129 | def __init__(self, fpath=None): 130 | self.console = sys.stdout 131 | self.file = None 132 | if fpath is not None: 133 | mkdir_if_missing(osp.dirname(fpath)) 134 | self.file = open(fpath, 'w') 135 | 136 | def __del__(self): 137 | self.close() 138 | 139 | def __enter__(self): 140 | pass 141 | 142 | def __exit__(self, *args): 143 | self.close() 144 | 145 | def write(self, msg): 146 | self.console.write(msg) 147 | if self.file is not None: 148 | self.file.write(msg) 149 | 150 | def flush(self): 151 | self.console.flush() 152 | if self.file is not None: 153 | self.file.flush() 154 | os.fsync(self.file.fileno()) 155 | 156 | def close(self): 157 | self.console.close() 158 | if self.file is not None: 159 | self.file.close() 160 | 161 | def set_seed(seed, cuda=True): 162 | np.random.seed(seed) 163 | torch.manual_seed(seed) 164 | if cuda: 165 | torch.cuda.manual_seed(seed) 166 | 167 | def set_requires_grad(nets, requires_grad=False): 168 | """Set requies_grad=Fasle for all the networks to avoid unnecessary computations 169 | Parameters: 170 | nets (network list) -- a list of networks 171 | requires_grad (bool) -- whether the networks require gradients or not 172 | """ 173 | if not isinstance(nets, list): 174 | nets = [nets] 175 | for net in nets: 176 | if net is not None: 177 | for param in net.parameters(): 178 | param.requires_grad = requires_grad -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Self-Supervised Modality-Aware Multiple Granularity Pre-Training for RGB-Infrared Person Re-Identification 2 | 3 | ![](pipeline.png) 4 | 5 | This is the offical PyTorch implementation of the paper 'Self-Supervised Modality-Aware Multiple Granularity Pre-Training for RGB-Infrared Person Re-Identification'. 6 | 7 | **Authors**: *Lin Wan, Qianyan Jing, Zongyuan Sun, Chuang Zhang, Zhihang Li, and Yehansen Chen* 8 | 9 | # Abstract 10 | 11 | RGB-Infrared person re-identification (RGB-IR ReID) aims to associate people across disjoint RGB and IR camera views. Currently, state-of-the-art performance of RGB-IR ReID is not as impressive as that of conventional ReID. Much of that is due to the notorious modality bias training issue brought by the single-modality ImageNet pre-training, which might yield RGB-biased representations that severely hinder the cross-modality image retrieval. This paper makes first attempt to tackle the task from a pre-training perspective. We propose a self-supervised pre-training solution, named Modality-Aware Multiple Granularity Learning (MMGL), which directly trains models from scratch only on multi-modal ReID datasets, but achieving competitive results against ImageNet pre-training, without using any external data or sophisticated tuning tricks. First, we develop a simple-but-effective 'permutation recovery' pretext task that globally maps shuffled RGB-IR images into a shared latent permutation space, providing modality-invariant global representations for downstream ReID tasks. Second, we present a part-aware cycle-contrastive (PCC) learning strategy that utilizes cross-modality cycle-consistency to maximize agreement between semantically similar RGB-IR image patches. This enables contrastive learning for the unpaired multi-modal scenarios, further improving the discriminability of local features without laborious instance augmentation. Based on these designs, MMGL effectively alleviates the modality bias training problem. 
Extensive experiments demonstrate that it learns better representations (+8.03% Rank-1 accuracy) with faster training speed (converging in only a few hours) and higher data efficiency (<5% of the data size) than ImageNet pre-training. The results also suggest that it generalizes well to various existing models and losses, and has promising transferability across datasets. 12 | 13 | # To Do List 14 | 15 | - [x] Release the fine-tuned models and training logs by MMGL 16 | - [ ] Release the source code (**The code is coming soon!**) 17 | - [ ] Release the pre-trained models and logs 18 | - [ ] Stay tuned 19 | 20 | # How to Use 21 | 22 | ## Environment 23 | 24 | **Packages** 25 | 26 | - Python 3.6.13 27 | - PyTorch 1.10.2 28 | - Numpy 1.19.2 29 | - Scipy 1.5.2 30 | - TensorboardX 2.2 31 | 32 | **Hardware** 33 | 34 | - A single Nvidia 2080Ti (original paper) / 3080Ti (what we use now) 35 | - GPU Memory: 12G 36 | - Nvidia Driver Version: 510.54 37 | - CUDA Version: 11.6 38 | 39 | ## Datasets 40 | 41 | - (1) RegDB: The RegDB dataset can be downloaded from this [website](http://dm.dongguk.edu/link.html) by submitting a copyright form. 42 | 43 | - (Named: "Dongguk Body-based Person Recognition Database (DBPerson-Recog-DB1)" on their website). 44 | 45 | 46 | - (2) SYSU-MM01: The SYSU-MM01 dataset can be downloaded from this [website](http://isee.sysu.edu.cn/project/RGBIRReID.htm). 47 | 48 | - Run `python pre_process_sysu.py` to preprocess the dataset; the training data will be stored in `.npy` format. 49 | 50 | - A private download link for both datasets can be provided by sending me an email (chenyehansen@gmail.com). 51 | 52 | ## Self-supervised Pre-Training with MMGL 53 | 54 | Only **single-GPU** training is supported for now. 55 | 56 | To do MMGL pre-training on a *two-stream* ResNet-50 backbone, run: 57 | ``` 58 | python train.py --dataset sysu --stream two --lr 0.1 --pcc --part --gpu 0 59 | ``` 60 | 61 | **Optional Hyper-Parameters:** 62 | 63 | `--num_stripe`: the number of partition stripes 64 | 65 | `--cl_weight`: the weight of the PCC loss 66 | 67 | `--cl_temp`: the temperature of the PCC loss 68 | 69 | 70 | To do MMGL pre-training on a *one-stream* ResNet-50 backbone, run: 71 | ``` 72 | python train.py --dataset sysu --stream one --lr 0.1 --pcc --part --gpu 0 73 | ``` 74 | 75 | **Pre-trained Models:** 76 | 77 | Backbone | Training Time | Permutation Accuracy | Model 78 | ---|:---:|:---:|:---: 79 | Two-Stream | 6h | 98.6% | available soon 80 | One-Stream | 6h | 97.5% | available soon 81 | 82 | 83 | ## Supervised RGB-Infrared Person Re-Identification 84 | 85 | Once pre-training is finished, move the pre-trained checkpoint into the corresponding `save_model/` directory of the method you want to fine-tune. 86 | 87 | To perform supervised RGB-IR ReID with Base / AGW, run: 88 | ``` 89 | cd AGW 90 | 91 | python train.py --dataset sysu (or regdb) --mode all --lr 0.1 --method agw (or base) --gpu 0 --resume 'write your checkpoint file name here' 92 | ``` 93 | 94 | Test a model on the SYSU-MM01 dataset by 95 | ``` 96 | python test.py --mode all --resume 'model_path' --gpu 0 --dataset sysu 97 | ``` 98 | - `--dataset`: "sysu". 99 | 100 | - `--mode`: "all" or "indoor", for all-search or indoor-search. 101 | 102 | - `--resume`: the saved model path. 103 | 104 | - `--gpu`: which gpu to run. 
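Putting the AGW steps above together, the following is a minimal end-to-end sketch of the pre-train / fine-tune / test workflow on SYSU-MM01. The checkpoint file name `sysu_mmgl_two_stream.t` is only a placeholder for whatever your pre-training run actually produces, and the sketch assumes the pre-training script writes its checkpoint under `save_model/`, as the fine-tuning scripts do:
```
# 1) MMGL self-supervised pre-training (two-stream backbone) on SYSU-MM01
python train.py --dataset sysu --stream two --lr 0.1 --pcc --part --gpu 0

# 2) Move the pre-trained checkpoint into AGW's save_model/ directory
#    (the file name below is a placeholder)
cp save_model/sysu_mmgl_two_stream.t AGW/save_model/

# 3) Fine-tune AGW from the MMGL weights
cd AGW
python train.py --dataset sysu --mode all --lr 0.1 --method agw --gpu 0 --resume 'sysu_mmgl_two_stream.t'

# 4) Evaluate the fine-tuned model (all-search mode)
python test.py --mode all --resume 'model_path' --gpu 0 --dataset sysu
```
The same pattern applies to DDAG below; only the directory and the script names change.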
105 | 106 | To perform supervised RGB-IR ReID with DDAG, run: 107 | ``` 108 | cd DDAG 109 | 110 | python train_ddag.py --dataset sysu (or regdb) --lr 0.1 --wpa --graph --gpu 0 --resume 'write your checkpoint file name here' 111 | ``` 112 | 113 | Test a model on the SYSU-MM01 dataset by 114 | 115 | ``` 116 | python test_ddag.py --dataset sysu --mode all --wpa --graph --gpu 1 --resume 'model_path' 117 | ``` 118 | - `--dataset`: "sysu". 119 | 120 | - `--mode`: "all" or "indoor", for all-search or indoor-search. 121 | 122 | - `--resume`: the saved model path (**important**). 123 | 124 | - `--gpu`: which gpu to run. 125 | 126 | # Results 127 | 128 | **MMGL Pre-Training Fine-Tuned Results (SYSU-MM01, Single-Shot & All-Search):** 129 | |Methods | Pretrained| Rank@1 | mAP | Model| 130 | | -------- | ----- | ----- | ----- |------| 131 | |AGW | MMGL | 56.97% | 54.61% | [Checkpoint](https://drive.google.com/file/d/1y_GmFSWiVtsu0_Zf5tENLU0BTf6j9qfB/view?usp=sharing) \| [Training Log](https://drive.google.com/file/d/1xSdwuZ6AP3J-8Qi-dOBFw4J723I7m6eS/view?usp=sharing)| 132 | |DDAG | MMGL | 56.75% | 53.96% |[Checkpoint](https://drive.google.com/file/d/1hXYVXwfwNdL5JS9BPWvGwGD5ZB3FPzCy/view?usp=sharing) \| [Training Log](https://drive.google.com/file/d/1rpwVqG0q_O-Jg7Yz9itx0VZj4Euxy6GK/view?usp=sharing)| 133 | 134 | \* Both methods may show some fluctuation due to random splitting. The results might be improved by fine-tuning the hyper-parameters. 135 | 136 | **ImageNet Supervised Pre-Training Fine-Tuned Results (Provided by Mang Ye):** 137 | 138 | |Methods | Pretrained| Rank@1 | mAP | Model| 139 | | -------- | ----- | ----- | ----- |------| 140 | |AGW | ImageNet | ~ 47.50% | ~ 47.65% | [Checkpoint](https://drive.google.com/open?id=181K9PQGnej0K5xNX9DRBDPAf3K9JosYk)| 141 | |DDAG | ImageNet | ~ 54.75% | ~ 53.02% |----- | 142 | 143 | # Citation 144 | 145 | Please cite this paper in your publications if it helps your research: 146 | ``` 147 | @article{wan2021self, 148 | title={Self-Supervised Modality-Aware Multiple Granularity Pre-Training for RGB-Infrared Person Re-Identification}, 149 | author={Wan, Lin and Jing, Qianyan and Sun, Zongyuan and Zhang, Chuang and Li, Zhihang and Chen, Yehansen}, 150 | journal={arXiv preprint arXiv:2112.06147}, 151 | year={2021} 152 | } 153 | ``` 154 | -------------------------------------------------------------------------------- /pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hansonchen1996/MMGL/3a6d05ff1b9fcb73777e75636eef02193bfeadd4/pipeline.png --------------------------------------------------------------------------------
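For reference, the training scripts in this repository save each checkpoint as a plain Python dict with `'net'` (the model `state_dict`), `'cmc'`, `'mAP'`, and `'epoch'` keys (see the model-saving block in `train_ddag.py` above), so the released checkpoints in the Results tables can be inspected and restored with a few lines of PyTorch. A minimal sketch; the file name `sysu_agw_best.t` is a placeholder, and `net` must be built with the same architecture as in the corresponding training script:
```
import torch

# A saved checkpoint is a dict with 'net' (state_dict), 'cmc', 'mAP'
# and 'epoch' keys, as written by the training scripts in this repo.
checkpoint = torch.load('save_model/sysu_agw_best.t', map_location='cpu')
print('epoch: {}, rank-1: {:.2%}, mAP: {:.2%}'.format(
    checkpoint['epoch'], checkpoint['cmc'][0], checkpoint['mAP']))

# Restore the weights into a model built as in the matching train.py:
# net.load_state_dict(checkpoint['net'])
```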