├── 20210918132449.png ├── README.md ├── data_loader.py ├── data_manager.py ├── eval_metrics.py ├── image-20210909100353763.png ├── loss.py ├── memory_MGMRA.py ├── memory_SGMRA.py ├── memory_module_MGMRA.py ├── model_MGMRA.py ├── model_SGMRA.py ├── model_mine.py ├── pre_process_sysu.py ├── re_rank.py ├── resnet.py ├── test_mine_pcb.py ├── train_HCT.py ├── train_MGMRA.py ├── train_SGMRA.py └── utils.py /20210918132449.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chenfeng1271/MGMRA/d9c2986a4d692f76990888aca5d60fd11c389017/20210918132449.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # **[Memory Regulation and Alignment toward Generalizer RGB-Infrared Person Re-identification](https://arxiv.org/abs/2109.08843)** 2 | 3 | 4 | ### Updates 5 | 6 | - I am confident that I achieved the reported result in the SYSU indoor setting, although I am currently unable to reproduce it. However, all other results align with what was reported. Given that the reimplemented result stands at 77.35 in SYSU indoor Rank1, which was considered state-of-the-art at the time, there's no reason for me to manipulate my findings. Had I intended to misrepresent my results, I would have opted not to release my code. I urge future studies to compare their findings with my reimplemented results to ensure fairness. That said, I will not amend the results in my paper as I stand by their authenticity. 7 | - I have re-uploaded the MG-MRA code. I tested it on my own machine, and it should work. 8 | 9 | - You may notice that the input to each stage of MG-MRA differs from the paper. The two settings perform similarly: one is better on RegDB, the other on SYSU. I did not tune the hyper-parameters heavily, so with further tuning you may get slightly better results than reported. 
10 | 11 | - It is essentially a transformer decoder that differs only in its qkv setting. For now, you can refer to CRET [1] for a better understanding. 12 | 13 | - We also tried this work on RGB ReID and video ReID. It performs on par with the baseline model on RGB ReID and gives about a 1% improvement on video ReID without changing any hyper-parameter. 14 | 15 | [1] Ji K, Liu J, Hong W, et al. Cret: Cross-modal retrieval transformer for efficient text-video retrieval[C]//Proceedings of the 45th International ACM SIGIR Conference on Research and Development in Information Retrieval. 2022: 949-959. 16 | ### Highlights 17 | - The learned coarse-to-fine prototypes consistently provide domain-level semantic templates at various granularities, meeting the requirement for multi-level semantic alignment. 18 | - Our proposed MG-MRA boosts the performance of the baseline and of existing state-of-the-art methods, e.g., AGW and HCT, by a large margin at limited extra cost. We achieve a new state of the art on RegDB and SYSU-MM01 with 94.59%/88.18% and 72.50%/68.94% Rank1/mAP, respectively. 19 | - This work has some potentially interesting settings that have not been explored, such as the incompatibility with max pooling described in the discussion section of our paper. Moreover, some phenomena suggest that this module performs a form of regulation when you evaluate with the MGMRA branch output. Unfortunately, I have to move on to my next work and cannot investigate this further. 20 | 21 | 22 | 23 | 24 | 25 | ### Method 26 | ![image-20210909100353763](20210918132449.png) 27 | 28 | ### Results 29 | 30 | ![image-20210909100353763](image-20210909100353763.png) 31 | 32 | 33 | ### Usage 34 | Our code extends the PyTorch implementation of Cross-Modal-Re-ID-baseline on [GitHub](https://github.com/mangye16/Cross-Modal-Re-ID-baseline). Please refer to the official repo for details of data preparation. 
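The Rank1 and mAP figures quoted in the Highlights are per-query ranking metrics averaged over all queries, as computed in `eval_metrics.py` (which also reports mINP). The following is a minimal pure-Python sketch of that per-query computation. It assumes a hypothetical helper `rank_metrics`, which is not part of this repo, and a gallery that has already been sorted by ascending distance to the query.

```python
def rank_metrics(matches):
    """Per-query Rank-1, AP and INP (illustrative helper, not in the repo).

    matches: list of 0/1 flags, one per gallery sample, ordered from the
    closest to the farthest gallery sample; 1 means same identity as the query.
    """
    if not any(matches):
        # eval_metrics.py skips such queries; this sketch raises instead
        raise ValueError("query identity does not appear in the gallery")

    hits = 0          # number of correct matches seen so far
    precisions = []   # precision at each rank where a correct match occurs
    last_pos = 0      # 1-based rank of the hardest (last) correct match
    for rank, flag in enumerate(matches, start=1):
        if flag:
            hits += 1
            precisions.append(hits / rank)
            last_pos = rank

    rank1 = 1.0 if matches[0] else 0.0   # first CMC entry
    ap = sum(precisions) / hits          # average precision
    inp = hits / last_pos                # inverse negative penalty
    return rank1, ap, inp


if __name__ == "__main__":
    # correct matches at ranks 1 and 3
    print(rank_metrics([1, 0, 1]))
```

Averaging `rank1`, `ap` and `inp` over all valid queries gives Rank1, mAP and mINP. Note that `eval_sysu` additionally filters gallery samples by camera before ranking, which this sketch leaves out.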
35 | 36 | ### Training 37 | 38 | Train the original HCT method on RegDB with 39 | 40 | ```bash 41 | python train_HCT.py --dataset regdb --lr 0.1 --gpu 0 --batch-size 8 --num_pos 4 42 | ``` 43 | 44 | Train SG-MRA on RegDB with 45 | ```bash 46 | python train_SGMRA.py --dataset regdb --lr 0.1 --gpu 0 --batch-size 8 --num_pos 4 47 | ``` 48 | 49 | Train MG-MRA on RegDB with 50 | 51 | ```bash 52 | python train_MGMRA.py --dataset regdb --lr 0.1 --gpu 0 --batch-size 8 --num_pos 4 53 | ``` 54 | 55 | Train a model on SYSU-MM01 with 56 | 57 | ```bash 58 | python train_MGMRA.py --dataset sysu --lr 0.01 --batch-size 6 --num_pos 8 --gpu 0 59 | ``` 60 | 61 | **Parameters**: More parameters can be found in the manuscript and code. 62 | 63 | ### Reproduction 64 | The results reported in the paper should be easy to reproduce with this code; you can train it directly. We also provide weight files (our model on SYSU all-search and on the 10 trials of RegDB) for fast evaluation. You can download them from (link: https://pan.baidu.com/s/1Hs74Qsii0sK15ELt_wQH_Q 65 | code: 1111) and verify the performance against the following table. 
66 | 67 | | RegDB trial (visible to infrared) | Rank1 | mAP | 68 | | ---------------------------- | ----- | ----- | 69 | | 1 | 93.50 | 93.84 | 70 | | 2 | 95.24 | 95.07 | 71 | | 3 | 93.01 | 93.29 | 72 | | 4 | 93.30 | 93.73 | 73 | | 5 | 96.80 | 95.98 | 74 | | 6 | 94.85 | 93.88 | 75 | | 7 | 95.73 | 95.47 | 76 | | 8 | 95.49 | 95.34 | 77 | | 9 | 93.98 | 93.94 | 78 | | 10 | 95.05 | 94.57 | 79 | | mean | 94.7 | 94.5 | 80 | 81 | ### Reference 82 | ``` 83 | @article{chen2021memory, 84 | title={Memory Regulation and Alignment toward Generalizer RGB-Infrared Person Re-identification}, 85 | author={Chen, Feng and Wu, Fei and Wu, Qi and Wan, Zhiguo}, 86 | journal={arXiv preprint arXiv:2109.08843}, 87 | year={2021} 88 | } 89 | 90 | @article{arxiv20reidsurvey, 91 | title={Deep Learning for Person Re-identification: A Survey and Outlook}, 92 | author={Ye, Mang and Shen, Jianbing and Lin, Gaojie and Xiang, Tao and Shao, Ling and Hoi, Steven C. H.}, 93 | journal={arXiv preprint arXiv:2001.04193}, 94 | year={2020}, 95 | } 96 | ``` 97 | -------------------------------------------------------------------------------- /data_loader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | import torch.utils.data as data 4 | 5 | 6 | class SYSUData(data.Dataset): 7 | def __init__(self, data_dir, transform=None, colorIndex = None, thermalIndex = None): 8 | 9 | #data_dir = 'E:\chenfeng\dataset\SYSU-MM01/' 10 | # Load training images (path) and labels 11 | train_color_image = np.load(data_dir + 'train_rgb_resized_img.npy') 12 | self.train_color_label = np.load(data_dir + 'train_rgb_resized_label.npy') 13 | 14 | train_thermal_image = np.load(data_dir + 'train_ir_resized_img.npy') 15 | self.train_thermal_label = np.load(data_dir + 'train_ir_resized_label.npy') 16 | 17 | # BGR to RGB 18 | self.train_color_image = train_color_image 19 | self.train_thermal_image = train_thermal_image 20 | self.transform = transform 21 | self.cIndex = 
colorIndex 22 | self.tIndex = thermalIndex 23 | 24 | def __getitem__(self, index): 25 | 26 | img1, target1 = self.train_color_image[self.cIndex[index]], self.train_color_label[self.cIndex[index]] 27 | img2, target2 = self.train_thermal_image[self.tIndex[index]], self.train_thermal_label[self.tIndex[index]] 28 | 29 | img1 = self.transform(img1) 30 | img2 = self.transform(img2) 31 | 32 | return img1, img2, target1, target2 33 | 34 | def __len__(self): 35 | return len(self.train_color_label) 36 | 37 | 38 | class RegDBData(data.Dataset): 39 | def __init__(self, data_dir, trial, transform=None, colorIndex = None, thermalIndex = None): 40 | # Load training images (path) and labels 41 | #data_dir = 'E:\chenfeng\dataset\RegDB/' 42 | train_color_list = data_dir + 'idx/train_visible_{}'.format(trial)+ '.txt' 43 | train_thermal_list = data_dir + 'idx/train_thermal_{}'.format(trial)+ '.txt' 44 | 45 | color_img_file, train_color_label = load_data(train_color_list) 46 | thermal_img_file, train_thermal_label = load_data(train_thermal_list) 47 | 48 | train_color_image = [] 49 | for i in range(len(color_img_file)): 50 | 51 | img = Image.open(data_dir+ color_img_file[i]) 52 | img = img.resize((144, 288), Image.LANCZOS) # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter 53 | pix_array = np.array(img) 54 | train_color_image.append(pix_array) 55 | train_color_image = np.array(train_color_image) 56 | 57 | train_thermal_image = [] 58 | for i in range(len(thermal_img_file)): 59 | img = Image.open(data_dir+ thermal_img_file[i]) 60 | img = img.resize((144, 288), Image.LANCZOS) 61 | pix_array = np.array(img) 62 | train_thermal_image.append(pix_array) 63 | train_thermal_image = np.array(train_thermal_image) 64 | 65 | # BGR to RGB 66 | self.train_color_image = train_color_image 67 | self.train_color_label = train_color_label 68 | 69 | # BGR to RGB 70 | self.train_thermal_image = train_thermal_image 71 | self.train_thermal_label = train_thermal_label 72 | 73 | self.transform = transform 74 | self.cIndex = colorIndex 75 | self.tIndex = 
thermalIndex 76 | 77 | def __getitem__(self, index): 78 | 79 | img1, target1 = self.train_color_image[self.cIndex[index]], self.train_color_label[self.cIndex[index]] 80 | img2, target2 = self.train_thermal_image[self.tIndex[index]], self.train_thermal_label[self.tIndex[index]] 81 | 82 | img1 = self.transform(img1) 83 | img2 = self.transform(img2) 84 | 85 | return img1, img2, target1, target2 86 | 87 | def __len__(self): 88 | return len(self.train_color_label) 89 | 90 | class TestData(data.Dataset): 91 | def __init__(self, test_img_file, test_label, transform=None, img_size = (144,288)): 92 | 93 | test_image = [] 94 | for i in range(len(test_img_file)): 95 | img = Image.open(test_img_file[i]) 96 | img = img.resize((img_size[0], img_size[1]), Image.LANCZOS) # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter 97 | pix_array = np.array(img) 98 | test_image.append(pix_array) 99 | test_image = np.array(test_image) 100 | self.test_image = test_image 101 | self.test_label = test_label 102 | self.transform = transform 103 | 104 | def __getitem__(self, index): 105 | img1, target1 = self.test_image[index], self.test_label[index] 106 | img1 = self.transform(img1) 107 | return img1, target1 108 | 109 | def __len__(self): 110 | return len(self.test_image) 111 | 112 | class TestDataOld(data.Dataset): 113 | def __init__(self, data_dir, test_img_file, test_label, transform=None, img_size = (144,288)): 114 | 115 | test_image = [] 116 | for i in range(len(test_img_file)): 117 | img = Image.open(data_dir + test_img_file[i]) 118 | img = img.resize((img_size[0], img_size[1]), Image.LANCZOS) 119 | pix_array = np.array(img) 120 | test_image.append(pix_array) 121 | test_image = np.array(test_image) 122 | self.test_image = test_image 123 | self.test_label = test_label 124 | self.transform = transform 125 | 126 | def __getitem__(self, index): 127 | img1, target1 = self.test_image[index], self.test_label[index] 128 | img1 = self.transform(img1) 129 | return img1, target1 130 | 131 | def __len__(self): 132 | return len(self.test_image) 133 
| def load_data(input_data_path ): 134 | with open(input_data_path) as f: 135 | data_file_list = f.read().splitlines() # read through the handle opened above instead of opening the file twice 136 | # Get full list of image and labels 137 | file_image = [s.split(' ')[0] for s in data_file_list] 138 | file_label = [int(s.split(' ')[1]) for s in data_file_list] 139 | 140 | return file_image, file_label -------------------------------------------------------------------------------- /data_manager.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os 3 | import numpy as np 4 | import random 5 | 6 | def process_query_sysu(data_path, mode = 'all', relabel=False): 7 | if mode== 'all': 8 | ir_cameras = ['cam3','cam6'] 9 | elif mode =='indoor': 10 | ir_cameras = ['cam3','cam6'] 11 | 12 | file_path = os.path.join(data_path,'exp/test_id.txt') 13 | files_rgb = [] 14 | files_ir = [] 15 | 16 | with open(file_path, 'r') as file: 17 | ids = file.read().splitlines() 18 | ids = [int(y) for y in ids[0].split(',')] 19 | ids = ["%04d" % x for x in ids] 20 | 21 | for id in sorted(ids): 22 | for cam in ir_cameras: 23 | img_dir = os.path.join(data_path,cam,id) 24 | if os.path.isdir(img_dir): 25 | new_files = sorted([img_dir+'/'+i for i in os.listdir(img_dir)]) 26 | files_ir.extend(new_files) 27 | query_img = [] 28 | query_id = [] 29 | query_cam = [] 30 | for img_path in files_ir: 31 | camid, pid = int(img_path[-15]), int(img_path[-13:-9]) 32 | query_img.append(img_path) 33 | query_id.append(pid) 34 | query_cam.append(camid) 35 | return query_img, np.array(query_id), np.array(query_cam) 36 | 37 | def process_gallery_sysu(data_path, mode = 'all', trial = 0, relabel=False): 38 | 39 | random.seed(trial) 40 | 41 | if mode== 'all': 42 | rgb_cameras = ['cam1','cam2','cam4','cam5'] 43 | elif mode =='indoor': 44 | rgb_cameras = ['cam1','cam2'] 45 | 46 | file_path = os.path.join(data_path,'exp/test_id.txt') 47 | files_rgb = [] 48 | with 
open(file_path, 'r') as file: 49 | ids = file.read().splitlines() 50 | ids = [int(y) for y in ids[0].split(',')] 51 | ids = ["%04d" % x for x in ids] 52 | 53 | for id in sorted(ids): 54 | for cam in rgb_cameras: 55 | img_dir = os.path.join(data_path,cam,id) 56 | if os.path.isdir(img_dir): 57 | new_files = sorted([img_dir+'/'+i for i in os.listdir(img_dir)]) 58 | files_rgb.append(random.choice(new_files)) 59 | gall_img = [] 60 | gall_id = [] 61 | gall_cam = [] 62 | for img_path in files_rgb: 63 | camid, pid = int(img_path[-15]), int(img_path[-13:-9]) 64 | gall_img.append(img_path) 65 | gall_id.append(pid) 66 | gall_cam.append(camid) 67 | return gall_img, np.array(gall_id), np.array(gall_cam) 68 | 69 | def process_test_regdb(img_dir, trial = 1, modal = 'visible'): 70 | if modal=='visible': 71 | input_data_path = img_dir + 'idx/test_visible_{}'.format(trial) + '.txt' 72 | elif modal=='thermal': 73 | input_data_path = img_dir + 'idx/test_thermal_{}'.format(trial) + '.txt' 74 | 75 | with open(input_data_path) as f: 76 | data_file_list = f.read().splitlines() # read through the handle opened above instead of opening the file twice 77 | # Get full list of image and labels 78 | file_image = [img_dir + '/' + s.split(' ')[0] for s in data_file_list] 79 | file_label = [int(s.split(' ')[1]) for s in data_file_list] 80 | 81 | return file_image, np.array(file_label) -------------------------------------------------------------------------------- /eval_metrics.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import numpy as np 3 | """Cross-Modality ReID""" 4 | import pdb 5 | 6 | def eval_sysu(distmat, q_pids, g_pids, q_camids, g_camids, max_rank = 20): 7 | """Evaluation with sysu metric 8 | Key: for each query identity, its gallery images from the same camera view are discarded. "Following the original setting in the dataset" 9 | """ 10 | num_q, num_g = distmat.shape 11 | if num_g < max_rank: 12 | max_rank = num_g 13 | print("Note: number of gallery samples is quite small, got {}".format(num_g)) 14 | indices = np.argsort(distmat, axis=1) 15 | pred_label = g_pids[indices] 16 | matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32) 17 | 18 | # compute cmc curve for each query 19 | new_all_cmc = [] 20 | all_cmc = [] 21 | all_AP = [] 22 | all_INP = [] 23 | num_valid_q = 0. # number of valid queries 24 | for q_idx in range(num_q): 25 | # get query pid and camid 26 | q_pid = q_pids[q_idx] 27 | q_camid = q_camids[q_idx] 28 | 29 | # for queries from camera 3, discard gallery samples from camera 2 30 | order = indices[q_idx] 31 | remove = (q_camid == 3) & (g_camids[order] == 2) 32 | keep = np.invert(remove) 33 | 34 | # compute cmc curve 35 | # the cmc calculation is different from standard protocol 36 | # we follow the protocol of the author's released code 37 | new_cmc = pred_label[q_idx][keep] 38 | new_index = np.unique(new_cmc, return_index=True)[1] 39 | new_cmc = [new_cmc[index] for index in sorted(new_index)] 40 | 41 | new_match = (new_cmc == q_pid).astype(np.int32) 42 | new_cmc = new_match.cumsum() 43 | new_all_cmc.append(new_cmc[:max_rank]) 44 | 45 | orig_cmc = matches[q_idx][keep] # binary vector, positions with value 1 are correct matches 46 | if not np.any(orig_cmc): 47 | # this condition is true when query identity does not appear in gallery 48 | continue 49 | 50 | cmc = orig_cmc.cumsum() 51 | 52 | # compute mINP 53 | # reference: Deep Learning for Person Re-identification: A Survey and Outlook 54 | pos_idx = np.where(orig_cmc == 1) 55 | pos_max_idx = np.max(pos_idx) 56 | inp = cmc[pos_max_idx]/ (pos_max_idx + 1.0) 57 | all_INP.append(inp) 58 | 59 | cmc[cmc > 1] = 1 60 | 61 | all_cmc.append(cmc[:max_rank]) 62 | num_valid_q += 1. 
63 | 64 | # compute average precision 65 | # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision 66 | num_rel = orig_cmc.sum() 67 | tmp_cmc = orig_cmc.cumsum() 68 | tmp_cmc = [x / (i+1.) for i, x in enumerate(tmp_cmc)] 69 | tmp_cmc = np.asarray(tmp_cmc) * orig_cmc 70 | AP = tmp_cmc.sum() / num_rel 71 | all_AP.append(AP) 72 | 73 | assert num_valid_q > 0, "Error: all query identities do not appear in gallery" 74 | 75 | all_cmc = np.asarray(all_cmc).astype(np.float32) 76 | all_cmc = all_cmc.sum(0) / num_valid_q # standard CMC 77 | 78 | new_all_cmc = np.asarray(new_all_cmc).astype(np.float32) 79 | new_all_cmc = new_all_cmc.sum(0) / num_valid_q 80 | mAP = np.mean(all_AP) 81 | mINP = np.mean(all_INP) 82 | return new_all_cmc, mAP, mINP 83 | 84 | 85 | 86 | def eval_regdb(distmat, q_pids, g_pids, max_rank = 20): 87 | num_q, num_g = distmat.shape 88 | if num_g < max_rank: 89 | max_rank = num_g 90 | print("Note: number of gallery samples is quite small, got {}".format(num_g)) 91 | indices = np.argsort(distmat, axis=1) 92 | matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32) 93 | 94 | # compute cmc curve for each query 95 | all_cmc = [] 96 | all_AP = [] 97 | all_INP = [] 98 | num_valid_q = 0. 
# number of valid queries 99 | 100 | # only two cameras 101 | q_camids = np.ones(num_q).astype(np.int32) 102 | g_camids = 2* np.ones(num_g).astype(np.int32) 103 | 104 | for q_idx in range(num_q): 105 | # get query pid and camid 106 | q_pid = q_pids[q_idx] 107 | q_camid = q_camids[q_idx] 108 | 109 | # remove gallery samples that have the same pid and camid as the query 110 | order = indices[q_idx] 111 | remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid) 112 | keep = np.invert(remove) 113 | 114 | # compute cmc curve 115 | raw_cmc = matches[q_idx][keep] # binary vector, positions with value 1 are correct matches 116 | if not np.any(raw_cmc): 117 | # this condition is true when query identity does not appear in gallery 118 | continue 119 | 120 | cmc = raw_cmc.cumsum() 121 | 122 | # compute mINP 123 | # reference: Deep Learning for Person Re-identification: A Survey and Outlook 124 | pos_idx = np.where(raw_cmc == 1) 125 | pos_max_idx = np.max(pos_idx) 126 | inp = cmc[pos_max_idx]/ (pos_max_idx + 1.0) 127 | all_INP.append(inp) 128 | 129 | cmc[cmc > 1] = 1 130 | 131 | all_cmc.append(cmc[:max_rank]) 132 | num_valid_q += 1. 133 | 134 | # compute average precision 135 | # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision 136 | num_rel = raw_cmc.sum() 137 | tmp_cmc = raw_cmc.cumsum() 138 | tmp_cmc = [x / (i+1.) 
for i, x in enumerate(tmp_cmc)] 139 | tmp_cmc = np.asarray(tmp_cmc) * raw_cmc 140 | AP = tmp_cmc.sum() / num_rel 141 | all_AP.append(AP) 142 | 143 | assert num_valid_q > 0, "Error: all query identities do not appear in gallery" 144 | 145 | all_cmc = np.asarray(all_cmc).astype(np.float32) 146 | all_cmc = all_cmc.sum(0) / num_valid_q 147 | mAP = np.mean(all_AP) 148 | mINP = np.mean(all_INP) 149 | return all_cmc, mAP, mINP -------------------------------------------------------------------------------- /image-20210909100353763.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Chenfeng1271/MGMRA/d9c2986a4d692f76990888aca5d60fd11c389017/image-20210909100353763.png -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | import pdb 2 | import torch 3 | import numpy as np 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.autograd.function import Function 7 | from torch.autograd import Variable 8 | 9 | 10 | class CenterTripletLoss(nn.Module): 11 | """ Hetero-center-triplet-loss-for-VT-Re-ID 12 | "Parameters Sharing Exploration and Hetero-Center Triplet Loss for Visible-Thermal Person Re-Identification" 13 | [(arxiv)](https://arxiv.org/abs/2008.06223). 14 | 15 | Args: 16 | - margin (float): margin for triplet. 
17 | """ 18 | 19 | def __init__(self, batch_size, margin=0.3): 20 | super(CenterTripletLoss, self).__init__() 21 | self.margin = margin 22 | self.ranking_loss = nn.MarginRankingLoss(margin=margin) 23 | 24 | def forward(self, feats, labels): 25 | """ 26 | Args: 27 | - feats: feature matrix with shape (batch_size, feat_dim) 28 | - labels: ground truth labels with shape (batch_size) 29 | """ 30 | label_uni = labels.unique() 31 | targets = torch.cat([label_uni,label_uni]) 32 | label_num = len(label_uni) 33 | feat = feats.chunk(label_num*2, 0) 34 | center = [] 35 | for i in range(label_num*2): 36 | center.append(torch.mean(feat[i], dim=0, keepdim=True)) 37 | inputs = torch.cat(center) 38 | 39 | n = inputs.size(0) 40 | 41 | # Compute pairwise distance, replace by the official when merged 42 | dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n) 43 | dist = dist + dist.t() 44 | dist.addmm_(inputs, inputs.t(), beta=1, alpha=-2) # keyword form; the positional (beta, alpha, ...) overload is deprecated 45 | dist = dist.clamp(min=1e-12).sqrt() # for numerical stability 46 | 47 | # For each anchor, find the hardest positive and negative 48 | mask = targets.expand(n, n).eq(targets.expand(n, n).t()) 49 | dist_ap, dist_an = [], [] 50 | for i in range(n): 51 | dist_ap.append(dist[i][mask[i]].max().unsqueeze(0)) 52 | dist_an.append(dist[i][mask[i] == 0].min().unsqueeze(0)) 53 | dist_ap = torch.cat(dist_ap) 54 | dist_an = torch.cat(dist_an) 55 | 56 | # Compute ranking hinge loss 57 | y = torch.ones_like(dist_an) 58 | loss = self.ranking_loss(dist_an, dist_ap, y) 59 | 60 | # compute accuracy 61 | correct = torch.ge(dist_an, dist_ap).sum().item() 62 | return loss, correct 63 | 64 | 65 | 66 | 67 | 68 | class CrossEntropyLabelSmooth(nn.Module): 69 | """Cross entropy loss with label smoothing regularizer. 70 | Reference: 71 | Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016. 72 | Equation: y = (1 - epsilon) * y + epsilon / K. 73 | Args: 74 | num_classes (int): number of classes. 75 | epsilon (float): weight. 
76 | """ 77 | def __init__(self, num_classes, epsilon=0.1, use_gpu=True): 78 | super(CrossEntropyLabelSmooth, self).__init__() 79 | self.num_classes = num_classes 80 | self.epsilon = epsilon 81 | self.use_gpu = use_gpu 82 | self.logsoftmax = nn.LogSoftmax(dim=1) 83 | 84 | def forward(self, inputs, targets): 85 | """ 86 | Args: 87 | inputs: prediction matrix (before softmax) with shape (batch_size, num_classes) 88 | targets: ground truth labels with shape (batch_size) 89 | """ 90 | log_probs = self.logsoftmax(inputs) 91 | targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).data.cpu(), 1) 92 | if self.use_gpu: targets = targets.cuda() 93 | targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes 94 | loss = (- targets * log_probs).mean(0).sum() 95 | return loss 96 | 97 | 98 | class OriTripletLoss(nn.Module): 99 | """Triplet loss with hard positive/negative mining. 100 | 101 | Reference: 102 | Hermans et al. In Defense of the Triplet Loss for Person Re-Identification. arXiv:1703.07737. 103 | Code imported from https://github.com/Cysu/open-reid/blob/master/reid/loss/triplet.py. 104 | 105 | Args: 106 | - margin (float): margin for triplet. 
107 | """ 108 | 109 | def __init__(self, batch_size, margin=0.3): 110 | super(OriTripletLoss, self).__init__() 111 | self.margin = margin 112 | self.ranking_loss = nn.MarginRankingLoss(margin=margin) 113 | 114 | def forward(self, inputs, targets): 115 | """ 116 | Args: 117 | - inputs: feature matrix with shape (batch_size, feat_dim) 118 | - targets: ground truth labels with shape (batch_size) 119 | """ 120 | n = inputs.size(0) 121 | 122 | # Compute pairwise distance, replace by the official when merged 123 | dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n) 124 | dist = dist + dist.t() 125 | dist.addmm_(inputs, inputs.t(), beta=1, alpha=-2) # keyword form; the positional (beta, alpha, ...) overload is deprecated 126 | dist = dist.clamp(min=1e-12).sqrt() # for numerical stability 127 | 128 | # For each anchor, find the hardest positive and negative 129 | mask = targets.expand(n, n).eq(targets.expand(n, n).t()) 130 | dist_ap, dist_an = [], [] 131 | for i in range(n): 132 | dist_ap.append(dist[i][mask[i]].max().unsqueeze(0)) 133 | dist_an.append(dist[i][mask[i] == 0].min().unsqueeze(0)) 134 | dist_ap = torch.cat(dist_ap) 135 | dist_an = torch.cat(dist_an) 136 | 137 | # Compute ranking hinge loss 138 | y = torch.ones_like(dist_an) 139 | loss = self.ranking_loss(dist_an, dist_ap, y) 140 | 141 | # compute accuracy 142 | correct = torch.ge(dist_an, dist_ap).sum().item() 143 | return loss, correct 144 | 145 | 146 | 147 | 148 | # Adaptive weights 149 | def softmax_weights(dist, mask): 150 | max_v = torch.max(dist * mask, dim=1, keepdim=True)[0] 151 | diff = dist - max_v 152 | Z = torch.sum(torch.exp(diff) * mask, dim=1, keepdim=True) + 1e-6 # avoid division by zero 153 | W = torch.exp(diff) * mask / Z 154 | return W 155 | 156 | def normalize(x, axis=-1): 157 | """Normalizing to unit length along the specified dimension. 158 | Args: 159 | x: pytorch Variable 160 | Returns: 161 | x: pytorch Variable, same shape as input 162 | """ 163 | x = 1. 
* x / (torch.norm(x, 2, axis, keepdim=True).expand_as(x) + 1e-12) 164 | return x 165 | 166 | class TripletLoss_WRT(nn.Module): 167 | """Weighted Regularized Triplet.""" 168 | 169 | def __init__(self): 170 | super(TripletLoss_WRT, self).__init__() 171 | self.ranking_loss = nn.SoftMarginLoss() 172 | 173 | def forward(self, inputs, targets, normalize_feature=False): 174 | if normalize_feature: 175 | inputs = normalize(inputs, axis=-1) 176 | dist_mat = pdist_torch(inputs, inputs) 177 | 178 | N = dist_mat.size(0) 179 | # shape [N, N] 180 | is_pos = targets.expand(N, N).eq(targets.expand(N, N).t()).float() 181 | is_neg = targets.expand(N, N).ne(targets.expand(N, N).t()).float() 182 | 183 | # `dist_ap` means distance(anchor, positive), `dist_an` means distance(anchor, negative) 184 | # both are masked copies of `dist_mat` with shape [N, N] 185 | dist_ap = dist_mat * is_pos 186 | dist_an = dist_mat * is_neg 187 | 188 | weights_ap = softmax_weights(dist_ap, is_pos) 189 | weights_an = softmax_weights(-dist_an, is_neg) 190 | furthest_positive = torch.sum(dist_ap * weights_ap, dim=1) 191 | closest_negative = torch.sum(dist_an * weights_an, dim=1) 192 | 193 | y = furthest_positive.new().resize_as_(furthest_positive).fill_(1) 194 | loss = self.ranking_loss(closest_negative - furthest_positive, y) 195 | 196 | 197 | # compute accuracy 198 | correct = torch.ge(closest_negative, furthest_positive).sum().item() 199 | return loss, correct 200 | 201 | def pdist_torch(emb1, emb2): 202 | ''' 203 | compute the Euclidean distance matrix between embeddings1 and embeddings2 204 | using gpu 205 | ''' 206 | m, n = emb1.shape[0], emb2.shape[0] 207 | emb1_pow = torch.pow(emb1, 2).sum(dim = 1, keepdim = True).expand(m, n) 208 | emb2_pow = torch.pow(emb2, 2).sum(dim = 1, keepdim = True).expand(n, m).t() 209 | dist_mtx = emb1_pow + emb2_pow 210 | dist_mtx = dist_mtx.addmm_(emb1, emb2.t(), beta=1, alpha=-2) # keyword form; the positional (beta, alpha, ...) overload is deprecated 211 | # dist_mtx = dist_mtx.clamp(min = 1e-12) 212 | dist_mtx = dist_mtx.clamp(min = 1e-12).sqrt() 213 | return dist_mtx 214 | 215 | 216 | def 
pdist_np(emb1, emb2): 217 | ''' 218 | compute the squared Euclidean distance matrix between embeddings1 and embeddings2 219 | using cpu (the square root is left commented out below) 220 | ''' 221 | m, n = emb1.shape[0], emb2.shape[0] 222 | emb1_pow = np.square(emb1).sum(axis = 1)[..., np.newaxis] 223 | emb2_pow = np.square(emb2).sum(axis = 1)[np.newaxis, ...] 224 | dist_mtx = -2 * np.matmul(emb1, emb2.T) + emb1_pow + emb2_pow 225 | # dist_mtx = np.sqrt(dist_mtx.clip(min = 1e-12)) 226 | return dist_mtx 227 | 228 | 229 | class global_loss_idx(nn.Module): 230 | 231 | def __init__(self, batch_size, margin=0.3): 232 | super(global_loss_idx, self).__init__() 233 | self.margin = margin 234 | self.ranking_loss = nn.MarginRankingLoss(margin=margin) 235 | 236 | def forward(self, global_feat, labels): 237 | global_feat = normalize(global_feat, axis=-1) 238 | inputs = global_feat 239 | 240 | 241 | n = inputs.size(0) 242 | dist_mat = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n) 243 | dist_mat = dist_mat + dist_mat.t() 244 | dist_mat.addmm_(inputs, inputs.t(), beta=1, alpha=-2) # keyword form; the positional (beta, alpha, ...) overload is deprecated 245 | dist_mat = dist_mat.clamp(min=1e-12).sqrt() 246 | 247 | 248 | N = dist_mat.size(0) 249 | is_pos = labels.expand(N, N).eq(labels.expand(N, N).t()) 250 | is_neg = labels.expand(N, N).ne(labels.expand(N, N).t()) 251 | 252 | dist_ap, relative_p_inds = torch.max( 253 | dist_mat[is_pos].contiguous().view(N, -1), 1, keepdim=True) 254 | dist_an, relative_n_inds = torch.min( 255 | dist_mat[is_neg].contiguous().view(N, -1), 1, keepdim=True) 256 | 257 | dist_ap = dist_ap.squeeze(1) 258 | dist_an = dist_an.squeeze(1) 259 | 260 | ind = (labels.new().resize_as_(labels) 261 | .copy_(torch.arange(0, N).long()) 262 | .unsqueeze(0).expand(N, N)) 263 | 264 | p_inds = torch.gather( 265 | ind[is_pos].contiguous().view(N, -1), 1, relative_p_inds.data) 266 | n_inds = torch.gather( 267 | ind[is_neg].contiguous().view(N, -1), 1, relative_n_inds.data) 268 | # shape [N] 269 | p_inds = p_inds.squeeze(1) 270 | n_inds = n_inds.squeeze(1) 271 | 272 | 273 | label_uni = 
labels.unique() 274 | targets = torch.cat([label_uni, label_uni]) 275 | label_num = len(label_uni) 276 | global_feat = global_feat.chunk(label_num * 2, 0) 277 | center = [] 278 | for i in range(label_num * 2): 279 | center.append(torch.mean(global_feat[i], dim=0, keepdim=True)) 280 | 281 | inputs = torch.cat(center) 282 | 283 | n = inputs.size(0) 284 | # Compute pairwise distance, replace by the official when merged 285 | dist_c = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n) 286 | dist_c = dist_c + dist_c.t() 287 | dist_c.addmm_(inputs, inputs.t(), beta=1, alpha=-2) # keyword form; the positional (beta, alpha, ...) overload is deprecated 288 | dist_c = dist_c.clamp(min=1e-12).sqrt() # for numerical stability 289 | 290 | # For each anchor, find the hardest positive and negative 291 | mask = targets.expand(n, n).eq(targets.expand(n, n).t()) 292 | dist_ap_c, dist_an_c = [], [] 293 | for i in range(n): 294 | dist_ap_c.append(dist_c[i][mask[i]].max().unsqueeze(0)) 295 | dist_an_c.append(dist_c[i][mask[i] == 0].min().unsqueeze(0)) 296 | dist_ap_c = torch.cat(dist_ap_c) 297 | dist_an_c = torch.cat(dist_an_c) 298 | 299 | # Compute ranking hinge loss 300 | y = torch.ones_like(dist_an_c) 301 | loss = self.ranking_loss(dist_an_c, dist_ap_c, y) 302 | 303 | return loss, p_inds, n_inds, dist_ap, dist_an 304 | 305 | 306 | def batch_local_dist(x, y): 307 | """ 308 | Args: 309 | x: pytorch Variable, with shape [N, m, d] 310 | y: pytorch Variable, with shape [N, n, d] 311 | Returns: 312 | dist: pytorch Variable, with shape [N] 313 | """ 314 | assert len(x.size()) == 3 315 | assert len(y.size()) == 3 316 | assert x.size(0) == y.size(0) 317 | assert x.size(-1) == y.size(-1) 318 | 319 | # shape [N, m, n] 320 | dist_mat = batch_euclidean_dist(x, y) 321 | dist_mat = (torch.exp(dist_mat) - 1.) / (torch.exp(dist_mat) + 1.) 
322 | # shape [N] 323 | dist = shortest_dist(dist_mat.permute(1, 2, 0)) 324 | return dist 325 | 326 | def batch_euclidean_dist(x, y): 327 | """ 328 | Args: 329 | x: pytorch Variable, with shape [N, m, d] 330 | y: pytorch Variable, with shape [N, n, d] 331 | Returns: 332 | dist: pytorch Variable, with shape [N, m, n] 333 | """ 334 | assert len(x.size()) == 3 335 | assert len(y.size()) == 3 336 | assert x.size(0) == y.size(0) 337 | assert x.size(-1) == y.size(-1) 338 | 339 | N, m, d = x.size() 340 | N, n, d = y.size() 341 | 342 | # shape [N, m, n] 343 | xx = torch.pow(x, 2).sum(-1, keepdim=True).expand(N, m, n) 344 | yy = torch.pow(y, 2).sum(-1, keepdim=True).expand(N, n, m).permute(0, 2, 1) 345 | dist = xx + yy 346 | dist.baddbmm_(x, y.permute(0, 2, 1), beta=1, alpha=-2) # keyword form; the positional (beta, alpha, ...) overload is deprecated 347 | dist = dist.clamp(min=1e-12).sqrt() # for numerical stability 348 | return dist 349 | 350 | class local_loss_idx(nn.Module): 351 | 352 | def __init__(self, batch_size, margin=0.3): 353 | super(local_loss_idx, self).__init__() 354 | self.margin = margin 355 | self.ranking_loss = nn.MarginRankingLoss(margin=margin) 356 | 357 | def forward(self, local_feat, p_inds, n_inds, labels): 358 | local_feat = normalize(local_feat, axis=-1) 359 | 360 | dist_ap = batch_local_dist(local_feat, local_feat[p_inds.long()]) 361 | dist_an = batch_local_dist(local_feat, local_feat[n_inds.long()]) 362 | 363 | y = torch.ones_like(dist_an) 364 | loss = self.ranking_loss(dist_an, dist_ap, y) 365 | return loss, dist_ap, dist_an 366 | 367 | def local_dist(x, y): 368 | """ 369 | Args: 370 | x: pytorch Variable, with shape [M, m, d] 371 | y: pytorch Variable, with shape [N, n, d] 372 | Returns: 373 | dist: pytorch Variable, with shape [M, N] 374 | """ 375 | M, m, d = x.size() 376 | N, n, d = y.size() 377 | x = x.contiguous().view(M * m, d) 378 | y = y.contiguous().view(N * n, d) 379 | # shape [M * m, N * n] 380 | dist_mat = euclidean_dist(x, y) 381 | dist_mat = (torch.exp(dist_mat) - 1.) / (torch.exp(dist_mat) + 1.) 
382 | # shape [M * m, N * n] -> [M, m, N, n] -> [m, n, M, N] 383 | dist_mat = dist_mat.contiguous().view(M, m, N, n).permute(1, 3, 0, 2) 384 | # shape [M, N] 385 | dist_mat = shortest_dist(dist_mat) 386 | return dist_mat 387 | 388 | def shortest_dist(dist_mat): 389 | """Parallel version. 390 | Args: 391 | dist_mat: pytorch Variable, available shape: 392 | 1) [m, n] 393 | 2) [m, n, N], N is batch size 394 | 3) [m, n, *], * can be arbitrary additional dimensions 395 | Returns: 396 | dist: three cases corresponding to `dist_mat`: 397 | 1) scalar 398 | 2) pytorch Variable, with shape [N] 399 | 3) pytorch Variable, with shape [*] 400 | """ 401 | m, n = dist_mat.size()[:2] 402 | # Just offering some reference for accessing intermediate distance. 403 | dist = [[0 for _ in range(n)] for _ in range(m)] 404 | for i in range(m): 405 | for j in range(n): 406 | if (i == 0) and (j == 0): 407 | dist[i][j] = dist_mat[i, j] 408 | elif (i == 0) and (j > 0): 409 | dist[i][j] = dist[i][j - 1] + dist_mat[i, j] 410 | elif (i > 0) and (j == 0): 411 | dist[i][j] = dist[i - 1][j] + dist_mat[i, j] 412 | else: 413 | dist[i][j] = torch.min(dist[i - 1][j], dist[i][j - 1]) + dist_mat[i, j] 414 | dist = dist[-1][-1] 415 | return dist 416 | 417 | def hard_example_mining(dist_mat, labels, return_inds=False): 418 | """For each anchor, find the hardest positive and negative sample. 419 | Args: 420 | dist_mat: pytorch Variable, pair wise distance between samples, shape [N, N] 421 | labels: pytorch LongTensor, with shape [N] 422 | return_inds: whether to return the indices. Save time if `False`(?) 
423 | Returns: 424 | dist_ap: pytorch Variable, distance(anchor, positive); shape [N] 425 | dist_an: pytorch Variable, distance(anchor, negative); shape [N] 426 | p_inds: pytorch LongTensor, with shape [N]; 427 | indices of selected hard positive samples; 0 <= p_inds[i] <= N - 1 428 | n_inds: pytorch LongTensor, with shape [N]; 429 | indices of selected hard negative samples; 0 <= n_inds[i] <= N - 1 430 | NOTE: Only consider the case in which all labels have same num of samples, 431 | thus we can cope with all anchors in parallel. 432 | """ 433 | 434 | assert len(dist_mat.size()) == 2 435 | assert dist_mat.size(0) == dist_mat.size(1) 436 | N = dist_mat.size(0) 437 | 438 | # shape [N, N] 439 | is_pos = labels.expand(N, N).eq(labels.expand(N, N).t()) 440 | is_neg = labels.expand(N, N).ne(labels.expand(N, N).t()) 441 | 442 | # `dist_ap` means distance(anchor, positive) 443 | # both `dist_ap` and `relative_p_inds` with shape [N, 1] 444 | dist_ap, relative_p_inds = torch.max( 445 | dist_mat[is_pos].contiguous().view(N, -1), 1, keepdim=True) 446 | # `dist_an` means distance(anchor, negative) 447 | # both `dist_an` and `relative_n_inds` with shape [N, 1] 448 | dist_an, relative_n_inds = torch.min( 449 | dist_mat[is_neg].contiguous().view(N, -1), 1, keepdim=True) 450 | # shape [N] 451 | dist_ap = dist_ap.squeeze(1) 452 | dist_an = dist_an.squeeze(1) 453 | 454 | if return_inds: 455 | # shape [N, N] 456 | ind = (labels.new().resize_as_(labels) 457 | .copy_(torch.arange(0, N).long()) 458 | .unsqueeze( 0).expand(N, N)) 459 | # shape [N, 1] 460 | p_inds = torch.gather( 461 | ind[is_pos].contiguous().view(N, -1), 1, relative_p_inds.data) 462 | n_inds = torch.gather( 463 | ind[is_neg].contiguous().view(N, -1), 1, relative_n_inds.data) 464 | # shape [N] 465 | p_inds = p_inds.squeeze(1) 466 | n_inds = n_inds.squeeze(1) 467 | return dist_ap, dist_an, p_inds, n_inds 468 | 469 | return dist_ap, dist_an 470 | 471 | class BarlowTwins_loss(nn.Module): 472 | """ 
https://github.com/facebookresearch/barlowtwins. 473 | Reference: 474 | Barlow Twins: Self-Supervised Learning via Redundancy Reduction. 475 | 476 | """ 477 | 478 | def __init__(self, batch_size, margin=0.3): 479 | super(BarlowTwins_loss, self).__init__() 480 | self.margin = margin 481 | self.ranking_loss = nn.MarginRankingLoss(margin=margin) 482 | 483 | # projector 484 | 485 | def forward(self, inputs, targets): 486 | import pdb 487 | #pdb.set_trace() 488 | 489 | # normalization layer for the representations z1 and z2 490 | # z1 = nn.BatchNorm1d(input1) 491 | # z2 = nn.BatchNorm1d(input2) 492 | 493 | # inputs = torch.tensor([item.cpu().detach().numpy() for item in inputs]).cuda() 494 | 495 | feat_V, feat_T = torch.chunk(inputs, 2, dim=0) 496 | c_metrix = feat_V.T @ feat_T # empirical cross-correlation matrix 497 | 498 | n = inputs.size(0) 499 | 500 | c_metrix.div_(n) # normalize by n = inputs.size(0) 501 | 502 | on_diag = torch.diagonal(c_metrix).add_(-1).pow_(2).sum() 503 | 504 | off_diag = off_diagonal(c_metrix).pow_(2).sum() 505 | # increasing the off_diag weight from 0.00051 to 0.051 (in 10x steps) gradually improves results. 506 | loss = (on_diag + 0.051 * off_diag) / 2048 507 | return loss 508 | 509 | def off_diagonal(x): 510 | # return a flattened view of the off-diagonal elements of a square matrix 511 | n, m = x.shape 512 | assert n == m 513 | return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten() 514 | 515 | class BarlowTwins_loss_mem(nn.Module): 516 | """ https://github.com/facebookresearch/barlowtwins. 517 | Reference: 518 | Barlow Twins: Self-Supervised Learning via Redundancy Reduction. 
519 | 520 | """ 521 | 522 | def __init__(self, margin=0.3): 523 | super(BarlowTwins_loss_mem, self).__init__() 524 | self.margin = margin 525 | #self.ranking_loss = nn.MarginRankingLoss(margin=margin) 526 | 527 | # projector 528 | 529 | def forward(self, inputs): 530 | 531 | # normalization layer for the representations z1 and z2 532 | # z1 = nn.BatchNorm1d(input1) 533 | # z2 = nn.BatchNorm1d(input2) 534 | #feat_V, feat_T = torch.chunk(inputs, 2, dim=0) 535 | 536 | b = inputs.permute([0,2,1]) 537 | n = inputs.size(0) 538 | c = b @ inputs # empirical cross-correlation matrix 539 | c = c.permute([1,2,0]) 540 | 541 | 542 | 543 | 544 | c.div_(n) # normalize by n = inputs.size(0) 545 | 546 | on_diag = torch.diagonal(c).add_(-1).pow_(2).sum() 547 | #off_diag = off_diagonal(c).pow_(2).sum() 548 | off_diag = c.pow_(2).sum() - torch.diagonal(c).pow_(2).sum() 549 | 550 | # increasing the off_diag weight from 0.00051 to 0.051 (in 10x steps) gradually improves results. 551 | loss = (on_diag + 0.051 * off_diag) / 200 552 | return loss -------------------------------------------------------------------------------- /memory_MGMRA.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, print_function 2 | import torch 3 | from torch import nn 4 | import math 5 | from torch.nn.parameter import Parameter 6 | from torch.nn import functional as F 7 | import numpy as np 8 | 9 | ## in this version, the ins output is used as the sem input 10 | 11 | # 12 | class MemoryUnit(nn.Module): 13 | def __init__(self, ptt_num, num_cls, part_num,fea_dim, shrink_thres=0.0025): 14 | super(MemoryUnit, self).__init__() 15 | ''' 16 | the instance PTT is divided into cls_number x ptt_number per cls x part number per ptt 17 | ''' 18 | self.num_cls = num_cls 19 | self.ptt_num = ptt_num 20 | self.part_num = part_num 21 | 22 | self.mem_dim = ptt_num * num_cls * part_num # M 23 | self.fea_dim = fea_dim # C 24 | self.weight = Parameter(torch.Tensor(self.mem_dim, self.fea_dim)) # M x C 25 | #self.sem_weight = 
Parameter(torch.Tensor(self.num_cls, self.fea_dim)) # N x C 26 | self.bias = None 27 | self.shrink_thres= shrink_thres 28 | # self.hard_sparse_shrink_opt = nn.Hardshrink(lambd=shrink_thres) 29 | 30 | self.avgpool = nn.AdaptiveAvgPool1d(1) 31 | self.reweight_layer_part = nn.Conv1d(self.part_num,self.part_num,1) 32 | self.reweight_layer_ins = nn.Conv1d(self.ptt_num,self.ptt_num,1) 33 | 34 | self.reset_parameters() 35 | 36 | def reset_parameters(self): 37 | stdv = 1. / math.sqrt(self.weight.size(1)) 38 | self.weight.data.uniform_(-stdv, stdv) 39 | if self.bias is not None: 40 | self.bias.data.uniform_(-stdv, stdv) 41 | 42 | def get_update_query(self, mem, max_indices, score, query): 43 | m, d = mem.size() 44 | 45 | query_update = torch.zeros((m,d)).cuda() 46 | #random_update = torch.zeros((m,d)).cuda() 47 | for i in range(m): 48 | idx = torch.nonzero(max_indices.squeeze(1)==i) 49 | a, _ = idx.size() 50 | #ex = update_indices[0][i] 51 | if a != 0: 52 | #random_idx = torch.randperm(a)[0] 53 | #idx = idx[idx != ex] 54 | # query_update[i] = torch.sum(query[idx].squeeze(1), dim=0) 55 | query_update[i] = torch.sum(((score[idx,i] / torch.max(score[:,i])) *query[idx].squeeze(1)), dim=0) 56 | #random_update[i] = query[random_idx] * (score[random_idx,i] / torch.max(score[:,i])) 57 | else: 58 | query_update[i] = 0 59 | #random_update[i] = 0 60 | 61 | 62 | return query_update 63 | 64 | def forward(self, input, residual=False): 65 | ''' 66 | this is a bottom-up hierarchical statistics and summarization module 67 | all steps in main flow follow part -> prototype -> cls 68 | input = NHW x C 69 | total PTT M = num_cls (L) x ptt_num (T) x part_num (P) 70 | dimension C = fea_dim 71 | ''' 72 | ### for global part-unaware instance PTT, act as sub flow 73 | att_weight = F.linear(input, self.weight) # we don't split the part dimension, so it is part-unaware; NHW x M 74 | import pdb 75 | #pdb.set_trace() 76 | att_weight = F.softmax(att_weight, dim=1) # NHW x M 77 | ### update ### 78 | #_, 
gather_indice = torch.topk(att_weight, 1, dim=1) 79 | #ins_mem_sample_driven = self.get_update_query(self.weight, gather_indice, att_weight,input) 80 | #self.weight.data = F.normalize(ins_mem_sample_driven+ self.weight, dim=1) 81 | 82 | if self.shrink_thres >0: 83 | att_weight = hard_shrink_relu(att_weight, lambd=self.shrink_thres) 84 | att_weight = F.normalize(att_weight, p=1, dim=1) 85 | 86 | mem_trans = self.weight.permute(1, 0) # Mem^T, MxC 87 | output = F.linear(att_weight, mem_trans) # AttWeight x Mem^T^T = AW x Mem, (TxM) x (MxC) = TxC 88 | 89 | ### for global part-aware instance PTT 90 | import pdb 91 | #pdb.set_trace() 92 | 93 | self.reweight_part = self.weight.view(self.num_cls*self.ptt_num, self.fea_dim, -1).permute(0,2,1) 94 | self.reweight_part = (torch.sigmoid(self.reweight_layer_part(self.reweight_part))*self.reweight_part).permute(0,2,1) 95 | #self.reweight_part = (self.reweight_part).permute(0,2,1) 96 | self.part_ins_att = self.avgpool(self.reweight_part).squeeze(-1) 97 | ins_att_weight = F.linear(input, self.part_ins_att) # this is for global part-aware instance ptt which is not used in ours [NHW, C] x[C, M] = [NHW, M] 98 | ins_att_weight = F.softmax(ins_att_weight, dim=1) # NHW x LT 99 | if self.shrink_thres >0: 100 | ins_att_weight = hard_shrink_relu(ins_att_weight, lambd=self.shrink_thres) 101 | ins_att_weight = F.normalize(ins_att_weight, p=1, dim=1) 102 | 103 | ins_mem_trans = self.part_ins_att.permute(1, 0) # Mem^T, MxC 104 | output_part = F.linear(ins_att_weight, ins_mem_trans) # AttWeight x Mem^T^T = AW x Mem, (TxM) x (MxC) = TxC 105 | 106 | ### for semantic PTT 107 | #pdb.set_trace() 108 | self.reweight_ins = self.part_ins_att.view(self.num_cls, self.ptt_num, self.fea_dim) 109 | self.reweight_ins = (torch.sigmoid(self.reweight_layer_ins(self.reweight_ins))*self.reweight_ins).permute(0,2,1) 110 | #self.reweight_ins = (self.reweight_ins).permute(0,2,1) 111 | self.sem_att = self.avgpool(self.reweight_ins).squeeze(-1) 112 | sem_att_weight = 
F.linear(input, self.sem_att) 113 | sem_att_weight = F.softmax(sem_att_weight, dim=1) 114 | 115 | if self.shrink_thres >0: 116 | sem_att_weight = hard_shrink_relu(sem_att_weight, lambd=self.shrink_thres) 117 | sem_att_weight = F.normalize(sem_att_weight, p=1, dim=1) 118 | 119 | sem_mem_trans = self.sem_att.permute(1,0) 120 | output_sem = F.linear(sem_att_weight, sem_mem_trans) 121 | 122 | if residual: 123 | output_sem +=output 124 | 125 | 126 | #return {'output': output, 'att': att_weight} # output, att_weight 127 | return {'output': output_sem, 'att': sem_att_weight,'sem_attn': self.sem_att, 'output_part': output} 128 | 129 | 130 | def extra_repr(self): 131 | return 'mem_dim={}, fea_dim={}'.format( 132 | self.mem_dim, self.fea_dim is not None 133 | ) 134 | 135 | 136 | # NxCxHxW -> (NxHxW)xC -> addressing Mem, (NxHxW)xC -> NxCxHxW 137 | class MemModule(nn.Module): 138 | def __init__(self, ptt_num, num_cls, part_num, fea_dim, shrink_thres=0.0025, device='cuda'): 139 | super(MemModule, self).__init__() 140 | self.ptt_num = ptt_num 141 | self.num_cls = num_cls 142 | self.part_num = part_num 143 | ins_mem= False 144 | if ins_mem: 145 | self.mem_dim = ptt_num * num_cls * part_num# part-level instance 146 | else: 147 | self.mem_dim = num_cls# global semantic 148 | self.fea_dim = fea_dim 149 | self.shrink_thres = shrink_thres 150 | self.memory = MemoryUnit(self.ptt_num, self.num_cls, self.part_num, self.fea_dim, self.shrink_thres) 151 | 152 | def forward(self, input): 153 | s = input.data.shape 154 | l = len(s) 155 | 156 | if l == 3: 157 | x = input.permute(0, 2, 1) 158 | elif l == 4: 159 | x = input.permute(0, 2, 3, 1) 160 | elif l == 5: 161 | x = input.permute(0, 2, 3, 4, 1) 162 | else: 163 | x = [] 164 | print('wrong feature map size') 165 | x = x.contiguous() 166 | x = x.view(-1, s[1]) 167 | # 168 | y_and = self.memory(x) 169 | # 170 | y = y_and['output'] 171 | att = y_and['att'] 172 | y_part = y_and['output_part'] 173 | 174 | if l == 3: 175 | y = y.view(s[0], s[2], 
s[1]) 176 | y = y.permute(0, 2, 1) 177 | att = att.view(s[0], s[2], self.mem_dim) 178 | att = att.permute(0, 2, 1) 179 | elif l == 4: 180 | y = y.view(s[0], s[2], s[3], s[1]) 181 | y = y.permute(0, 3, 1, 2) 182 | att = att.view(s[0], s[2], s[3], self.mem_dim) 183 | att = att.permute(0, 3, 1, 2) 184 | elif l == 5: 185 | y = y.view(s[0], s[2], s[3], s[4], s[1]) 186 | y = y.permute(0, 4, 1, 2, 3) 187 | att = att.view(s[0], s[2], s[3], s[4], self.mem_dim) 188 | att = att.permute(0, 4, 1, 2, 3) 189 | else: 190 | y = x 191 | att = att 192 | print('wrong feature map size') 193 | return y, y_and['sem_attn'], y_part 194 | 195 | # relu based hard shrinkage function, only works for positive values 196 | def hard_shrink_relu(input, lambd=0, epsilon=1e-12): 197 | output = (F.relu(input-lambd) * input) / (torch.abs(input - lambd) + epsilon) 198 | return output 199 | 200 | -------------------------------------------------------------------------------- /memory_SGMRA.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, print_function 2 | import torch 3 | from torch import nn 4 | import math 5 | from torch.nn.parameter import Parameter 6 | from torch.nn import functional as F 7 | import numpy as np 8 | 9 | 10 | # 11 | class MemoryUnit(nn.Module): 12 | def __init__(self, ptt_num, num_cls, part_num,fea_dim, shrink_thres=0.0025): 13 | super(MemoryUnit, self).__init__() 14 | ''' 15 | the instance PTT is divided into cls_number x ptt_number per cls x part number per ptt 16 | ''' 17 | self.num_cls = num_cls 18 | self.ptt_num = ptt_num 19 | self.part_num = part_num 20 | 21 | self.mem_dim = ptt_num * num_cls * part_num # M 22 | self.fea_dim = fea_dim # C 23 | self.weight = Parameter(torch.Tensor(self.mem_dim, self.fea_dim)) # M x C 24 | #self.sem_weight = Parameter(torch.Tensor(self.num_cls, self.fea_dim)) # N x C 25 | self.bias = None 26 | self.shrink_thres= shrink_thres 27 | # self.hard_sparse_shrink_opt = 
nn.Hardshrink(lambd=shrink_thres) 28 | 29 | self.avgpool = nn.AdaptiveAvgPool1d(1) 30 | 31 | self.reset_parameters() 32 | 33 | def reset_parameters(self): 34 | stdv = 1. / math.sqrt(self.weight.size(1)) 35 | self.weight.data.uniform_(-stdv, stdv) 36 | if self.bias is not None: 37 | self.bias.data.uniform_(-stdv, stdv) 38 | 39 | def get_update_query(self, mem, max_indices, score, query): 40 | m, d = mem.size() 41 | 42 | query_update = torch.zeros((m,d)).cuda() 43 | #random_update = torch.zeros((m,d)).cuda() 44 | for i in range(m): 45 | idx = torch.nonzero(max_indices.squeeze(1)==i) 46 | a, _ = idx.size() 47 | #ex = update_indices[0][i] 48 | if a != 0: 49 | #random_idx = torch.randperm(a)[0] 50 | #idx = idx[idx != ex] 51 | # query_update[i] = torch.sum(query[idx].squeeze(1), dim=0) 52 | query_update[i] = torch.sum(((score[idx,i] / torch.max(score[:,i])) *query[idx].squeeze(1)), dim=0) 53 | #random_update[i] = query[random_idx] * (score[random_idx,i] / torch.max(score[:,i])) 54 | else: 55 | query_update[i] = 0 56 | #random_update[i] = 0 57 | 58 | 59 | return query_update 60 | 61 | def forward(self, input, residual=False): 62 | ''' 63 | this is a bottom-up hierarchical statistics and summarization module 64 | all steps in main flow follow part -> prototype -> cls 65 | input = NHW x C 66 | total PTT M = num_cls (L) x ptt_num (T) x part_num (P) 67 | dimension C = fea_dim 68 | ''' 69 | ### for global part-unaware instance PTT, act as sub flow 70 | att_weight = F.linear(input, self.weight) # we don't split the part dimension, so it is part-unaware; NHW x M 71 | att_weight = F.softmax(att_weight, dim=1) # NHW x M 72 | ### update ### 73 | #_, gather_indice = torch.topk(att_weight, 1, dim=1) 74 | #ins_mem_sample_driven = self.get_update_query(self.weight, gather_indice, att_weight,input) 75 | #self.weight.data = F.normalize(ins_mem_sample_driven+ self.weight, dim=1) 76 | 77 | if self.shrink_thres >0: 78 | att_weight = hard_shrink_relu(att_weight, lambd=self.shrink_thres) 79 | 
att_weight = F.normalize(att_weight, p=1, dim=1) 80 | 81 | mem_trans = self.weight.permute(1, 0) # Mem^T, MxC 82 | output = F.linear(att_weight, mem_trans) # AttWeight x Mem^T^T = AW x Mem, (TxM) x (MxC) = TxC 83 | 84 | 85 | 86 | #return {'output': output, 'att': att_weight} # output, att_weight 87 | return {'output': output, 'att': None,'sem_attn': self.weight} 88 | 89 | 90 | def extra_repr(self): 91 | return 'mem_dim={}, fea_dim={}'.format( 92 | self.mem_dim, self.fea_dim is not None 93 | ) 94 | 95 | 96 | # NxCxHxW -> (NxHxW)xC -> addressing Mem, (NxHxW)xC -> NxCxHxW 97 | class MemModule(nn.Module): 98 | def __init__(self, ptt_num, num_cls, part_num, fea_dim, shrink_thres=0.0025, device='cuda'): 99 | super(MemModule, self).__init__() 100 | self.ptt_num = ptt_num 101 | self.num_cls = num_cls 102 | self.part_num = part_num 103 | ins_mem= False 104 | if ins_mem: 105 | self.mem_dim = ptt_num * num_cls * part_num# part-level instance 106 | else: 107 | self.mem_dim = num_cls# global semantic 108 | self.fea_dim = fea_dim 109 | self.shrink_thres = shrink_thres 110 | self.memory = MemoryUnit(self.ptt_num, self.num_cls, self.part_num, self.fea_dim, self.shrink_thres) 111 | 112 | def forward(self, input): 113 | s = input.data.shape 114 | l = len(s) 115 | 116 | if l == 3: 117 | x = input.permute(0, 2, 1) 118 | elif l == 4: 119 | x = input.permute(0, 2, 3, 1) 120 | elif l == 5: 121 | x = input.permute(0, 2, 3, 4, 1) 122 | else: 123 | x = [] 124 | print('wrong feature map size') 125 | x = x.contiguous() 126 | x = x.view(-1, s[1]) 127 | # 128 | y_and = self.memory(x) 129 | # 130 | y = y_and['output'] 131 | att = y_and['att'] 132 | 133 | if l == 3: 134 | y = y.view(s[0], s[2], s[1]) 135 | y = y.permute(0, 2, 1) 136 | att = att.view(s[0], s[2], self.mem_dim) 137 | att = att.permute(0, 2, 1) 138 | elif l == 4: 139 | y = y.view(s[0], s[2], s[3], s[1]) 140 | y = y.permute(0, 3, 1, 2) 141 | #att = att.view(s[0], s[2], s[3], self.mem_dim) 142 | #att = att.permute(0, 3, 1, 2) 143 | 
elif l == 5: 144 | y = y.view(s[0], s[2], s[3], s[4], s[1]) 145 | y = y.permute(0, 4, 1, 2, 3) 146 | att = att.view(s[0], s[2], s[3], s[4], self.mem_dim) 147 | att = att.permute(0, 4, 1, 2, 3) 148 | else: 149 | y = x 150 | att = att 151 | print('wrong feature map size') 152 | return y, y_and['sem_attn'] 153 | 154 | # relu based hard shrinkage function, only works for positive values 155 | def hard_shrink_relu(input, lambd=0, epsilon=1e-12): 156 | output = (F.relu(input-lambd) * input) / (torch.abs(input - lambd) + epsilon) 157 | return output 158 | 159 | -------------------------------------------------------------------------------- /memory_module_MGMRA.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, print_function 2 | import torch 3 | from torch import nn 4 | import math 5 | from torch.nn.parameter import Parameter 6 | from torch.nn import functional as F 7 | import numpy as np 8 | 9 | 10 | class MemoryUnit(nn.Module): 11 | def __init__(self, ptt_num, num_cls, part_num,fea_dim, shrink_thres=0.0025): 12 | super(MemoryUnit, self).__init__() 13 | ''' 14 | the instance PTT is divided into cls_number x ptt_number per cls x part number per ptt 15 | ''' 16 | self.num_cls = num_cls 17 | self.ptt_num = ptt_num 18 | self.part_num = part_num 19 | 20 | self.mem_dim = ptt_num * num_cls * part_num # M 21 | self.fea_dim = fea_dim # C 22 | self.weight = Parameter(torch.Tensor(self.mem_dim, self.fea_dim)) # M x C 23 | #self.sem_weight = Parameter(torch.Tensor(self.num_cls, self.fea_dim)) # N x C 24 | self.bias = None 25 | self.shrink_thres= shrink_thres 26 | # self.hard_sparse_shrink_opt = nn.Hardshrink(lambd=shrink_thres) 27 | 28 | self.avgpool = nn.AdaptiveAvgPool1d(1) 29 | self.reweight_layer_part = nn.Conv1d(self.part_num,self.part_num,1) 30 | self.reweight_layer_ins = nn.Conv1d(self.ptt_num,self.ptt_num,1) 31 | 32 | self.reset_parameters() 33 | 34 | def reset_parameters(self): 35 | stdv = 1. 
/ math.sqrt(self.weight.size(1)) 36 | self.weight.data.uniform_(-stdv, stdv) 37 | if self.bias is not None: 38 | self.bias.data.uniform_(-stdv, stdv) 39 | 40 | def get_update_query(self, mem, max_indices, score, query): 41 | m, d = mem.size() 42 | 43 | query_update = torch.zeros((m,d)).cuda() 44 | #random_update = torch.zeros((m,d)).cuda() 45 | for i in range(m): 46 | idx = torch.nonzero(max_indices.squeeze(1)==i) 47 | a, _ = idx.size() 48 | #ex = update_indices[0][i] 49 | if a != 0: 50 | #random_idx = torch.randperm(a)[0] 51 | #idx = idx[idx != ex] 52 | # query_update[i] = torch.sum(query[idx].squeeze(1), dim=0) 53 | query_update[i] = torch.sum(((score[idx,i] / torch.max(score[:,i])) *query[idx].squeeze(1)), dim=0) 54 | #random_update[i] = query[random_idx] * (score[random_idx,i] / torch.max(score[:,i])) 55 | else: 56 | query_update[i] = 0 57 | #random_update[i] = 0 58 | 59 | 60 | return query_update 61 | 62 | def forward(self, input, residual=False): 63 | ''' 64 | this is a bottom-up hierarchical statistics and summarization module 65 | all steps in main flow follow part -> prototype -> cls 66 | input = NHW x C 67 | total PTT M = num_cls (L) x ptt_num (T) x part_num (P) 68 | dimension C = fea_dim 69 | ''' 70 | ### for global part-unaware instance PTT, act as sub flow 71 | att_weight = F.linear(input, self.weight) # we don't split the part dimension, so it is part-unaware; NHW x M 72 | import pdb 73 | #pdb.set_trace() 74 | att_weight = F.softmax(att_weight, dim=1) # NHW x M 75 | ### update ### 76 | #_, gather_indice = torch.topk(att_weight, 1, dim=1) 77 | #ins_mem_sample_driven = self.get_update_query(self.weight, gather_indice, att_weight,input) 78 | #self.weight.data = F.normalize(ins_mem_sample_driven+ self.weight, dim=1) 79 | 80 | if self.shrink_thres >0: 81 | att_weight = hard_shrink_relu(att_weight, lambd=self.shrink_thres) 82 | att_weight = F.normalize(att_weight, p=1, dim=1) 83 | 84 | mem_trans = self.weight.permute(1, 0) # Mem^T, MxC 85 | output_part = 
F.linear(att_weight, mem_trans) # AttWeight x Mem^T^T = AW x Mem, (TxM) x (MxC) = TxC 86 | 87 | ### for global part-aware instance PTT 88 | 89 | self.reweight_part = self.weight.view(self.num_cls*self.ptt_num, self.fea_dim, -1).permute(0,2,1) 90 | self.reweight_part = (torch.sigmoid(self.reweight_layer_part(self.reweight_part))*self.reweight_part).permute(0,2,1) 91 | self.part_ins_att = self.avgpool(self.reweight_part).squeeze(-1) 92 | ins_att_weight = F.linear(output_part, self.part_ins_att) # this is for global part-aware instance ptt which is not used in ours [NHW, C] x[C, M] = [NHW, M] 93 | ins_att_weight = F.softmax(ins_att_weight, dim=1) # NHW x LT 94 | if self.shrink_thres >0: 95 | ins_att_weight = hard_shrink_relu(ins_att_weight, lambd=self.shrink_thres) 96 | ins_att_weight = F.normalize(ins_att_weight, p=1, dim=1) 97 | 98 | ins_mem_trans = self.part_ins_att.permute(1, 0) # Mem^T, MxC 99 | output_ins = F.linear(ins_att_weight, ins_mem_trans) # AttWeight x Mem^T^T = AW x Mem, (TxM) x (MxC) = TxC 100 | 101 | ### for semantic PTT 102 | #pdb.set_trace() 103 | self.reweight_ins = self.part_ins_att.view(self.num_cls, self.ptt_num, self.fea_dim) 104 | self.reweight_ins = (torch.sigmoid(self.reweight_layer_ins(self.reweight_ins))*self.reweight_ins).permute(0,2,1) 105 | self.sem_att = self.avgpool(self.reweight_ins).squeeze(-1) 106 | sem_att_weight = F.linear(output_ins, self.sem_att) 107 | sem_att_weight = F.softmax(sem_att_weight, dim=1) 108 | 109 | if self.shrink_thres >0: 110 | sem_att_weight = hard_shrink_relu(sem_att_weight, lambd=self.shrink_thres) 111 | sem_att_weight = F.normalize(sem_att_weight, p=1, dim=1) 112 | 113 | sem_mem_trans = self.sem_att.permute(1,0) 114 | output_sem = F.linear(sem_att_weight, sem_mem_trans) 115 | 116 | if residual: 117 | output_sem +=output 118 | 119 | 120 | #return {'output': output, 'att': att_weight} # output, att_weight 121 | return {'output_sem': output_sem, 'output_part': output_part, 'output_ins':output_ins} 122 | 123 | 
124 | def extra_repr(self): 125 | return 'mem_dim={}, fea_dim={}'.format( 126 | self.mem_dim, self.fea_dim is not None 127 | ) 128 | 129 | 130 | # NxCxHxW -> (NxHxW)xC -> addressing Mem, (NxHxW)xC -> NxCxHxW 131 | class MemModule(nn.Module): 132 | def __init__(self, ptt_num, num_cls, part_num, fea_dim, shrink_thres=0.0025, device='cuda'): 133 | super(MemModule, self).__init__() 134 | self.ptt_num = ptt_num 135 | self.num_cls = num_cls 136 | self.part_num = part_num 137 | ins_mem= False 138 | if ins_mem: 139 | self.mem_dim = ptt_num * num_cls * part_num# part-level instance 140 | else: 141 | self.mem_dim = num_cls# global semantic 142 | self.fea_dim = fea_dim 143 | self.shrink_thres = shrink_thres 144 | self.memory = MemoryUnit(self.ptt_num, self.num_cls, self.part_num, self.fea_dim, self.shrink_thres) 145 | 146 | def forward(self, input): 147 | s = input.data.shape 148 | l = len(s) 149 | 150 | if l == 3: 151 | x = input.permute(0, 2, 1) 152 | elif l == 4: 153 | x = input.permute(0, 2, 3, 1) 154 | elif l == 5: 155 | x = input.permute(0, 2, 3, 4, 1) 156 | else: 157 | x = [] 158 | print('wrong feature map size') 159 | x = x.contiguous() 160 | x = x.view(-1, s[1]) 161 | # 162 | y_and = self.memory(x) 163 | # 164 | y_sem = y_and['output_sem'] 165 | y_ins = y_and['output_ins'] 166 | y_part = y_and['output_part'] 167 | 168 | 169 | if l == 4: 170 | y_sem = y_sem.view(s[0], s[2], s[3], s[1]) 171 | y_sem = y_sem.permute(0, 3, 1, 2) 172 | y_ins = y_ins.view(s[0], s[2], s[3], s[1]) 173 | y_ins = y_ins.permute(0, 3, 1, 2) 174 | y_part = y_part.view(s[0], s[2], s[3], s[1]) 175 | y_part = y_part.permute(0, 3, 1, 2) 176 | 177 | else: 178 | print('wrong feature map size') 179 | return y_sem, y_ins, y_part 180 | 181 | # relu based hard shrinkage function, only works for positive values 182 | def hard_shrink_relu(input, lambd=0, epsilon=1e-12): 183 | output = (F.relu(input-lambd) * input) / (torch.abs(input - lambd) + epsilon) 184 | return output 185 | 186 | 
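The sparsification step that each `MemoryUnit.forward` applies after the softmax (`hard_shrink_relu` followed by `F.normalize(..., p=1, dim=1)`) can be sketched on a single row of attention weights. This is a plain-Python restatement for illustration only, not the repository's tensor code: the helper `l1_normalize` and the sample row `att` are assumptions introduced here, and the `1e-12` floor mirrors the default `eps` of `F.normalize`.

```python
def hard_shrink_relu(values, lambd=0.0, epsilon=1e-12):
    """Elementwise relu(v - lambd) * v / (|v - lambd| + epsilon) on a list."""
    out = []
    for v in values:
        shifted = v - lambd
        # Entries at or below the threshold become 0; larger ones stay close to v.
        out.append(max(shifted, 0.0) * v / (abs(shifted) + epsilon))
    return out


def l1_normalize(values, eps=1e-12):
    """Hypothetical helper: rescale one row so its absolute values sum to 1."""
    total = sum(abs(v) for v in values)
    return [v / max(total, eps) for v in values]


# One assumed row of softmax attention over four memory slots.
att = [0.001, 0.002, 0.497, 0.5]
shrunk = hard_shrink_relu(att, lambd=0.0025)
sparse_att = l1_normalize(shrunk)
print(shrunk)
print(sparse_att)
```

With the default `shrink_thres=0.0025`, slots whose attention falls below the threshold are dropped and the remaining weights are rescaled, so each query is reconstructed from only a few prototypes.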
-------------------------------------------------------------------------------- /model_MGMRA.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import init 4 | import torch.nn.functional as F 5 | from resnet import resnet50, resnet18 6 | from memory_MGMRA import MemModule 7 | #from memory_module_h import MemModule 8 | import random 9 | 10 | ## this version uses memory for the part features 11 | 12 | class Normalize(nn.Module): 13 | def __init__(self, power=2): 14 | super(Normalize, self).__init__() 15 | self.power = power 16 | 17 | def forward(self, x): 18 | norm = x.pow(self.power).sum(1, keepdim=True).pow(1. / self.power) 19 | out = x.div(norm) 20 | return out 21 | 22 | class Non_local(nn.Module): 23 | def __init__(self, in_channels, reduc_ratio=2): 24 | super(Non_local, self).__init__() 25 | 26 | self.in_channels = in_channels 27 | self.inter_channels = reduc_ratio//reduc_ratio 28 | 29 | self.g = nn.Sequential( 30 | nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, 31 | padding=0), 32 | ) 33 | 34 | self.W = nn.Sequential( 35 | nn.Conv2d(in_channels=self.inter_channels, out_channels=self.in_channels, 36 | kernel_size=1, stride=1, padding=0), 37 | nn.BatchNorm2d(self.in_channels), 38 | ) 39 | nn.init.constant_(self.W[1].weight, 0.0) 40 | nn.init.constant_(self.W[1].bias, 0.0) 41 | 42 | 43 | 44 | self.theta = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, 45 | kernel_size=1, stride=1, padding=0) 46 | 47 | self.phi = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, 48 | kernel_size=1, stride=1, padding=0) 49 | 50 | def forward(self, x): 51 | ''' 52 | :param x: (b, c, t, h, w) 53 | :return: 54 | ''' 55 | 56 | batch_size = x.size(0) 57 | g_x = self.g(x).view(batch_size, self.inter_channels, -1) 58 | g_x = g_x.permute(0, 2, 1) 59 | 60 | theta_x = self.theta(x).view(batch_size, self.inter_channels, 
-1) 61 | theta_x = theta_x.permute(0, 2, 1) 62 | phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) 63 | f = torch.matmul(theta_x, phi_x) 64 | N = f.size(-1) 65 | # f_div_C = torch.nn.functional.softmax(f, dim=-1) 66 | f_div_C = f / N 67 | 68 | y = torch.matmul(f_div_C, g_x) 69 | y = y.permute(0, 2, 1).contiguous() 70 | y = y.view(batch_size, self.inter_channels, *x.size()[2:]) 71 | W_y = self.W(y) 72 | z = W_y + x 73 | 74 | return z 75 | 76 | 77 | # ##################################################################### 78 | def weights_init_kaiming(m): 79 | classname = m.__class__.__name__ 80 | # print(classname) 81 | if classname.find('Conv') != -1: 82 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 83 | elif classname.find('Linear') != -1: 84 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_out') 85 | init.zeros_(m.bias.data) 86 | elif classname.find('BatchNorm1d') != -1: 87 | init.normal_(m.weight.data, 1.0, 0.01) 88 | init.zeros_(m.bias.data) 89 | 90 | def weights_init_classifier(m): 91 | classname = m.__class__.__name__ 92 | if classname.find('Linear') != -1: 93 | init.normal_(m.weight.data, 0, 0.001) 94 | if m.bias: 95 | init.zeros_(m.bias.data) 96 | 97 | 98 | 99 | class visible_module(nn.Module): 100 | def __init__(self, arch='resnet50', share_net=1): 101 | super(visible_module, self).__init__() 102 | 103 | model_v = resnet50(pretrained=True, 104 | last_conv_stride=1, last_conv_dilation=1) 105 | # avg pooling to global pooling 106 | self.share_net = share_net 107 | 108 | if self.share_net == 0: 109 | pass 110 | else: 111 | self.visible = nn.ModuleList() 112 | self.visible.conv1 = model_v.conv1 113 | self.visible.bn1 = model_v.bn1 114 | self.visible.relu = model_v.relu 115 | self.visible.maxpool = model_v.maxpool 116 | if self.share_net > 1: 117 | for i in range(1, self.share_net): 118 | setattr(self.visible,'layer'+str(i), getattr(model_v,'layer'+str(i))) 119 | 120 | def forward(self, x): 121 | if self.share_net == 0: 122 | return x 
123 | else: 124 | x = self.visible.conv1(x) 125 | x = self.visible.bn1(x) 126 | x = self.visible.relu(x) 127 | x = self.visible.maxpool(x) 128 | 129 | if self.share_net > 1: 130 | for i in range(1, self.share_net): 131 | x = getattr(self.visible, 'layer'+str(i))(x) 132 | return x 133 | 134 | 135 | class thermal_module(nn.Module): 136 | def __init__(self, arch='resnet50', share_net=1): 137 | super(thermal_module, self).__init__() 138 | 139 | model_t = resnet50(pretrained=True, 140 | last_conv_stride=1, last_conv_dilation=1) 141 | # avg pooling to global pooling 142 | self.share_net = share_net 143 | 144 | if self.share_net == 0: 145 | pass 146 | else: 147 | self.thermal = nn.ModuleList() 148 | self.thermal.conv1 = model_t.conv1 149 | self.thermal.bn1 = model_t.bn1 150 | self.thermal.relu = model_t.relu 151 | self.thermal.maxpool = model_t.maxpool 152 | if self.share_net > 1: 153 | for i in range(1, self.share_net): 154 | setattr(self.thermal,'layer'+str(i), getattr(model_t,'layer'+str(i))) 155 | 156 | def forward(self, x): 157 | if self.share_net == 0: 158 | return x 159 | else: 160 | x = self.thermal.conv1(x) 161 | x = self.thermal.bn1(x) 162 | x = self.thermal.relu(x) 163 | x = self.thermal.maxpool(x) 164 | 165 | if self.share_net > 1: 166 | for i in range(1, self.share_net): 167 | x = getattr(self.thermal, 'layer'+str(i))(x) 168 | return x 169 | 170 | 171 | class base_resnet(nn.Module): 172 | def __init__(self, arch='resnet50', share_net=1): 173 | super(base_resnet, self).__init__() 174 | 175 | model_base = resnet50(pretrained=True, 176 | last_conv_stride=1, last_conv_dilation=1) 177 | # avg pooling to global pooling 178 | model_base.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 179 | self.share_net = share_net 180 | if self.share_net == 0: 181 | self.base = model_base 182 | else: 183 | self.base = nn.ModuleList() 184 | 185 | if self.share_net > 4: 186 | pass 187 | else: 188 | for i in range(self.share_net, 5): 189 | setattr(self.base,'layer'+str(i), 
getattr(model_base,'layer'+str(i))) 190 | 191 | def forward(self, x): 192 | if self.share_net == 0: 193 | x = self.base.conv1(x) 194 | x = self.base.bn1(x) 195 | x = self.base.relu(x) 196 | x = self.base.maxpool(x) 197 | 198 | x = self.base.layer1(x) 199 | x = self.base.layer2(x) 200 | x = self.base.layer3(x) 201 | x = self.base.layer4(x) 202 | return x 203 | elif self.share_net > 4: 204 | return x 205 | else: 206 | for i in range(self.share_net, 5): 207 | x = getattr(self.base, 'layer'+str(i))(x) 208 | return x 209 | 210 | 211 | 212 | class embed_net(nn.Module): 213 | def __init__(self, class_num, no_local= 'off', gm_pool = 'on', arch='resnet50', share_net=1, pcb='on',local_feat_dim=256, num_strips=6): 214 | super(embed_net, self).__init__() 215 | 216 | self.thermal_module = thermal_module(arch=arch, share_net=share_net) 217 | self.visible_module = visible_module(arch=arch, share_net=share_net) 218 | self.base_resnet = base_resnet(arch=arch, share_net=share_net) 219 | 220 | self.non_local = no_local 221 | self.pcb = pcb 222 | if self.non_local =='on': 223 | pass 224 | 225 | 226 | pool_dim = 2048 227 | self.l2norm = Normalize(2) 228 | self.gm_pool = gm_pool 229 | 230 | ##memory module 231 | self.mem_rep = MemModule(ptt_num=5, num_cls=206, part_num=6, fea_dim=pool_dim, shrink_thres =0.0025) 232 | self.pool_mem = nn.AdaptiveAvgPool2d((1,1)) 233 | self.bn = nn.BatchNorm2d(pool_dim) 234 | self.bottleneck = nn.BatchNorm1d(pool_dim) 235 | self.classifier = nn.Linear(pool_dim, class_num, bias=False) 236 | self.classifier.apply(weights_init_classifier) 237 | self.bottleneck.apply(weights_init_kaiming) 238 | 239 | if self.pcb == 'on': 240 | self.num_stripes=num_strips 241 | local_conv_out_channels=local_feat_dim 242 | 243 | self.local_conv_list = nn.ModuleList() 244 | for _ in range(self.num_stripes): 245 | conv = nn.Conv2d(pool_dim, local_conv_out_channels, 1) 246 | conv.apply(weights_init_kaiming) 247 | self.local_conv_list.append(nn.Sequential( 248 | conv, 249 | 
nn.BatchNorm2d(local_conv_out_channels), 250 | nn.ReLU(inplace=True) 251 | )) 252 | 253 | self.fc_list = nn.ModuleList() 254 | for _ in range(self.num_stripes): 255 | fc = nn.Linear(local_conv_out_channels, class_num) 256 | init.normal_(fc.weight, std=0.001) 257 | init.constant_(fc.bias, 0) 258 | self.fc_list.append(fc) 259 | 260 | 261 | else: 262 | self.bottleneck = nn.BatchNorm1d(pool_dim) 263 | self.bottleneck.bias.requires_grad_(False) # no shift 264 | 265 | self.classifier = nn.Linear(pool_dim, class_num, bias=False) 266 | 267 | self.bottleneck.apply(weights_init_kaiming) 268 | self.classifier.apply(weights_init_classifier) 269 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 270 | 271 | 272 | 273 | 274 | def forward(self, x1, x2, modal=0): 275 | if modal == 0: 276 | x1 = self.visible_module(x1) 277 | x2 = self.thermal_module(x2) 278 | x = torch.cat((x1, x2), 0) 279 | elif modal == 1: 280 | x = self.visible_module(x1) 281 | elif modal == 2: 282 | x = self.thermal_module(x2) 283 | 284 | # shared block 285 | if self.non_local == 'on': 286 | pass 287 | else: 288 | x = self.base_resnet(x) 289 | 290 | ## memory module 291 | #x_mem, att_mem = self.mem_rep(x) 292 | #x_mem += x 293 | #x_mem_pool = self.pool_mem(x_mem).view(x_mem.size(0), x_mem.size(1)) 294 | #x_mem_feat = self.bottleneck(x_mem_pool) 295 | 296 | 297 | if self.pcb == 'on': 298 | feat = x 299 | assert feat.size(2) % self.num_stripes == 0 300 | stripe_h = int(feat.size(2) / self.num_stripes) 301 | local_feat_list = [] 302 | logits_list = [] 303 | local_feat_mem_list = [] 304 | local_feat_mem_part_list = [] 305 | for i in range(self.num_stripes): 306 | # shape [N, C, 1, 1] 307 | 308 | # average pool 309 | #local_feat = F.avg_pool2d(feat[:, :, i * stripe_h: (i + 1) * stripe_h, :],(stripe_h, feat.size(-1))) 310 | if self.gm_pool == 'on': 311 | # gm pool 312 | local_feat = feat[:, :, i * stripe_h: (i + 1) * stripe_h, :] 313 | local_feat_mem, _, local_feat_mem_part = self.mem_rep(local_feat) 314 | 315 | 
local_feat_mem_part_list.append(local_feat_mem_part) 316 | local_feat_mem = local_feat + local_feat_mem 317 | b, c, h, w = local_feat.shape 318 | local_feat = local_feat.view(b,c,-1) 319 | p = 3.0 # regDB: 10.0 SYSU: 3.0 320 | local_feat = (torch.mean(local_feat**p, dim=-1) + 1e-12)**(1/p) 321 | else: 322 | # average pool 323 | #local_feat = F.avg_pool2d(feat[:, :, i * stripe_h: (i + 1) * stripe_h, :],(stripe_h, feat.size(-1))) 324 | local_feat = F.max_pool2d(feat[:, :, i * stripe_h: (i + 1) * stripe_h, :],(stripe_h, feat.size(-1))) 325 | 326 | 327 | # shape [N, c, 1, 1] 328 | local_feat = self.local_conv_list[i](local_feat.view(feat.size(0),feat.size(1),1,1)) 329 | 330 | 331 | # shape [N, c] 332 | local_feat = local_feat.view(local_feat.size(0), -1) 333 | local_feat_list.append(local_feat) 334 | local_feat_mem_list.append(local_feat_mem) 335 | 336 | 337 | if hasattr(self, 'fc_list'): 338 | logits_list.append(self.fc_list[i](local_feat)) 339 | 340 | 341 | 342 | feat_all = [lf for lf in local_feat_list] 343 | feat_all = torch.cat(feat_all, dim=1) 344 | 345 | feat_all_mem = [lf for lf in local_feat_mem_list] 346 | feat_all_mem = torch.cat(feat_all_mem, dim=2) 347 | 348 | lf_mem_pool = self.pool_mem(feat_all_mem).view(feat_all_mem.size(0), feat_all_mem.size(1)) 349 | lf_mem_feat = self.bottleneck(lf_mem_pool) 350 | 351 | ### this part is for part alignment, we then would change the discription here 352 | feat_all_part = [lf for lf in local_feat_mem_part_list] 353 | index = [i for i in range(len(feat_all_part))] 354 | random.shuffle(index) 355 | feat_all_part_shuffle = [feat_all_part[i] for i in index] 356 | feat_all_part_chunk = torch.cat(feat_all_part_shuffle, dim=1) 357 | p_1, p_2 = torch.chunk(feat_all_part_chunk,2,1) 358 | 359 | 360 | 361 | 362 | 363 | if self.training: 364 | #return local_feat_list, logits_list, feat_all , x_mem_pool+ lf_mem_pool, self.classifier(x_mem_feat+lf_mem_feat) 365 | return local_feat_list, logits_list, feat_all , lf_mem_pool, 
self.classifier(lf_mem_feat),[p_1,p_2] 366 | else: 367 | return self.l2norm(feat_all) 368 | else: 369 | if self.gm_pool == 'on': 370 | b, c, h, w = x.shape 371 | x = x.view(b, c, -1) 372 | p = 3.0 373 | x_pool = (torch.mean(x**p, dim=-1) + 1e-12)**(1/p) 374 | else: 375 | x_pool = self.avgpool(x) 376 | x_pool = x_pool.view(x_pool.size(0), x_pool.size(1)) 377 | 378 | feat = self.bottleneck(x_pool) 379 | 380 | if self.training: 381 | return x_pool, self.classifier(feat)#, scores 382 | else: 383 | return self.l2norm(x_pool), self.l2norm(feat) 384 | 385 | -------------------------------------------------------------------------------- /model_SGMRA.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import init 4 | import torch.nn.functional as F 5 | from resnet import resnet50, resnet18 6 | from memory_SGMRA import MemModule 7 | 8 | 9 | class Normalize(nn.Module): 10 | def __init__(self, power=2): 11 | super(Normalize, self).__init__() 12 | self.power = power 13 | 14 | def forward(self, x): 15 | norm = x.pow(self.power).sum(1, keepdim=True).pow(1. 
/ self.power) 16 | out = x.div(norm) 17 | return out 18 | 19 | class Non_local(nn.Module): 20 | def __init__(self, in_channels, reduc_ratio=2): 21 | super(Non_local, self).__init__() 22 | 23 | self.in_channels = in_channels 24 | self.inter_channels = reduc_ratio//reduc_ratio 25 | 26 | self.g = nn.Sequential( 27 | nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, 28 | padding=0), 29 | ) 30 | 31 | self.W = nn.Sequential( 32 | nn.Conv2d(in_channels=self.inter_channels, out_channels=self.in_channels, 33 | kernel_size=1, stride=1, padding=0), 34 | nn.BatchNorm2d(self.in_channels), 35 | ) 36 | nn.init.constant_(self.W[1].weight, 0.0) 37 | nn.init.constant_(self.W[1].bias, 0.0) 38 | 39 | 40 | 41 | self.theta = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, 42 | kernel_size=1, stride=1, padding=0) 43 | 44 | self.phi = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, 45 | kernel_size=1, stride=1, padding=0) 46 | 47 | def forward(self, x): 48 | ''' 49 | :param x: (b, c, t, h, w) 50 | :return: 51 | ''' 52 | 53 | batch_size = x.size(0) 54 | g_x = self.g(x).view(batch_size, self.inter_channels, -1) 55 | g_x = g_x.permute(0, 2, 1) 56 | 57 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1) 58 | theta_x = theta_x.permute(0, 2, 1) 59 | phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) 60 | f = torch.matmul(theta_x, phi_x) 61 | N = f.size(-1) 62 | # f_div_C = torch.nn.functional.softmax(f, dim=-1) 63 | f_div_C = f / N 64 | 65 | y = torch.matmul(f_div_C, g_x) 66 | y = y.permute(0, 2, 1).contiguous() 67 | y = y.view(batch_size, self.inter_channels, *x.size()[2:]) 68 | W_y = self.W(y) 69 | z = W_y + x 70 | 71 | return z 72 | 73 | 74 | # ##################################################################### 75 | def weights_init_kaiming(m): 76 | classname = m.__class__.__name__ 77 | # print(classname) 78 | if classname.find('Conv') != -1: 79 | 
        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
    elif classname.find('Linear') != -1:
        init.kaiming_normal_(m.weight.data, a=0, mode='fan_out')
        init.zeros_(m.bias.data)
    elif classname.find('BatchNorm1d') != -1:
        init.normal_(m.weight.data, 1.0, 0.01)
        init.zeros_(m.bias.data)

def weights_init_classifier(m):
    classname = m.__class__.__name__
    if classname.find('Linear') != -1:
        init.normal_(m.weight.data, 0, 0.001)
        # `if m.bias:` raises for a multi-element tensor; test against None instead
        if m.bias is not None:
            init.zeros_(m.bias.data)



class visible_module(nn.Module):
    def __init__(self, arch='resnet50', share_net=1):
        super(visible_module, self).__init__()

        model_v = resnet50(pretrained=True,
                           last_conv_stride=1, last_conv_dilation=1)
        # avg pooling to global pooling
        self.share_net = share_net

        if self.share_net == 0:
            pass
        else:
            self.visible = nn.ModuleList()
            self.visible.conv1 = model_v.conv1
            self.visible.bn1 = model_v.bn1
            self.visible.relu = model_v.relu
            self.visible.maxpool = model_v.maxpool
            if self.share_net > 1:
                for i in range(1, self.share_net):
                    setattr(self.visible,'layer'+str(i), getattr(model_v,'layer'+str(i)))

    def forward(self, x):
        if self.share_net == 0:
            return x
        else:
            x = self.visible.conv1(x)
            x = self.visible.bn1(x)
            x = self.visible.relu(x)
            x = self.visible.maxpool(x)

            if self.share_net > 1:
                for i in range(1, self.share_net):
                    x = getattr(self.visible, 'layer'+str(i))(x)
            return x


class thermal_module(nn.Module):
    def __init__(self, arch='resnet50', share_net=1):
        super(thermal_module, self).__init__()

        model_t = resnet50(pretrained=True,
                           last_conv_stride=1, last_conv_dilation=1)
        # avg pooling to global pooling
        self.share_net = share_net

        if self.share_net == 0:
            pass
        else:
            self.thermal
= nn.ModuleList() 145 | self.thermal.conv1 = model_t.conv1 146 | self.thermal.bn1 = model_t.bn1 147 | self.thermal.relu = model_t.relu 148 | self.thermal.maxpool = model_t.maxpool 149 | if self.share_net > 1: 150 | for i in range(1, self.share_net): 151 | setattr(self.thermal,'layer'+str(i), getattr(model_t,'layer'+str(i))) 152 | 153 | def forward(self, x): 154 | if self.share_net == 0: 155 | return x 156 | else: 157 | x = self.thermal.conv1(x) 158 | x = self.thermal.bn1(x) 159 | x = self.thermal.relu(x) 160 | x = self.thermal.maxpool(x) 161 | 162 | if self.share_net > 1: 163 | for i in range(1, self.share_net): 164 | x = getattr(self.thermal, 'layer'+str(i))(x) 165 | return x 166 | 167 | 168 | class base_resnet(nn.Module): 169 | def __init__(self, arch='resnet50', share_net=1): 170 | super(base_resnet, self).__init__() 171 | 172 | model_base = resnet50(pretrained=True, 173 | last_conv_stride=1, last_conv_dilation=1) 174 | # avg pooling to global pooling 175 | model_base.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 176 | self.share_net = share_net 177 | if self.share_net == 0: 178 | self.base = model_base 179 | else: 180 | self.base = nn.ModuleList() 181 | 182 | if self.share_net > 4: 183 | pass 184 | else: 185 | for i in range(self.share_net, 5): 186 | setattr(self.base,'layer'+str(i), getattr(model_base,'layer'+str(i))) 187 | 188 | def forward(self, x): 189 | if self.share_net == 0: 190 | x = self.base.conv1(x) 191 | x = self.base.bn1(x) 192 | x = self.base.relu(x) 193 | x = self.base.maxpool(x) 194 | 195 | x = self.base.layer1(x) 196 | x = self.base.layer2(x) 197 | x = self.base.layer3(x) 198 | x = self.base.layer4(x) 199 | return x 200 | elif self.share_net > 4: 201 | return x 202 | else: 203 | for i in range(self.share_net, 5): 204 | x = getattr(self.base, 'layer'+str(i))(x) 205 | return x 206 | 207 | 208 | 209 | class embed_net(nn.Module): 210 | def __init__(self, class_num, no_local= 'off', gm_pool = 'on', arch='resnet50', share_net=1, pcb='on',local_feat_dim=256, 
num_strips=6): 211 | super(embed_net, self).__init__() 212 | 213 | self.thermal_module = thermal_module(arch=arch, share_net=share_net) 214 | self.visible_module = visible_module(arch=arch, share_net=share_net) 215 | self.base_resnet = base_resnet(arch=arch, share_net=share_net) 216 | 217 | self.non_local = no_local 218 | self.pcb = pcb 219 | if self.non_local =='on': 220 | pass 221 | 222 | 223 | pool_dim = 2048 224 | self.l2norm = Normalize(2) 225 | self.gm_pool = gm_pool 226 | 227 | ##memory module 228 | self.mem_rep = MemModule(ptt_num=5, num_cls=206, part_num=6, fea_dim=pool_dim, shrink_thres =0.0025) 229 | self.pool_mem = nn.AdaptiveAvgPool2d((1,1)) 230 | self.bn = nn.BatchNorm2d(pool_dim) 231 | self.bottleneck = nn.BatchNorm1d(pool_dim) 232 | self.classifier = nn.Linear(pool_dim, class_num, bias=False) 233 | self.classifier.apply(weights_init_classifier) 234 | self.bottleneck.apply(weights_init_kaiming) 235 | 236 | if self.pcb == 'on': 237 | self.num_stripes=num_strips 238 | local_conv_out_channels=local_feat_dim 239 | 240 | self.local_conv_list = nn.ModuleList() 241 | for _ in range(self.num_stripes): 242 | conv = nn.Conv2d(pool_dim, local_conv_out_channels, 1) 243 | conv.apply(weights_init_kaiming) 244 | self.local_conv_list.append(nn.Sequential( 245 | conv, 246 | nn.BatchNorm2d(local_conv_out_channels), 247 | nn.ReLU(inplace=True) 248 | )) 249 | 250 | self.fc_list = nn.ModuleList() 251 | for _ in range(self.num_stripes): 252 | fc = nn.Linear(local_conv_out_channels, class_num) 253 | init.normal_(fc.weight, std=0.001) 254 | init.constant_(fc.bias, 0) 255 | self.fc_list.append(fc) 256 | 257 | 258 | else: 259 | self.bottleneck = nn.BatchNorm1d(pool_dim) 260 | self.bottleneck.bias.requires_grad_(False) # no shift 261 | 262 | self.classifier = nn.Linear(pool_dim, class_num, bias=False) 263 | 264 | self.bottleneck.apply(weights_init_kaiming) 265 | self.classifier.apply(weights_init_classifier) 266 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 267 | 268 | 269 | 
270 | 271 | def forward(self, x1, x2, modal=0): 272 | if modal == 0: 273 | x1 = self.visible_module(x1) 274 | x2 = self.thermal_module(x2) 275 | x = torch.cat((x1, x2), 0) 276 | elif modal == 1: 277 | x = self.visible_module(x1) 278 | elif modal == 2: 279 | x = self.thermal_module(x2) 280 | 281 | # shared block 282 | if self.non_local == 'on': 283 | pass 284 | else: 285 | x = self.base_resnet(x) 286 | 287 | ## memory module 288 | #x_mem, att_mem = self.mem_rep(x) 289 | #x_mem += x 290 | #x_mem_pool = self.pool_mem(x_mem).view(x_mem.size(0), x_mem.size(1)) 291 | #x_mem_feat = self.bottleneck(x_mem_pool) 292 | 293 | 294 | if self.pcb == 'on': 295 | feat = x 296 | assert feat.size(2) % self.num_stripes == 0 297 | stripe_h = int(feat.size(2) / self.num_stripes) 298 | local_feat_list = [] 299 | logits_list = [] 300 | local_feat_mem_list = [] 301 | for i in range(self.num_stripes): 302 | # shape [N, C, 1, 1] 303 | 304 | # average pool 305 | #local_feat = F.avg_pool2d(feat[:, :, i * stripe_h: (i + 1) * stripe_h, :],(stripe_h, feat.size(-1))) 306 | if self.gm_pool == 'on': 307 | # gm pool 308 | local_feat = feat[:, :, i * stripe_h: (i + 1) * stripe_h, :] 309 | local_feat_mem, _ = self.mem_rep(local_feat) 310 | local_feat_mem = local_feat + local_feat_mem 311 | b, c, h, w = local_feat.shape 312 | local_feat = local_feat.view(b,c,-1) 313 | p = 10.0 # regDB: 10.0 SYSU: 3.0 314 | local_feat = (torch.mean(local_feat**p, dim=-1) + 1e-12)**(1/p) 315 | else: 316 | # average pool 317 | #local_feat = F.avg_pool2d(feat[:, :, i * stripe_h: (i + 1) * stripe_h, :],(stripe_h, feat.size(-1))) 318 | local_feat = F.max_pool2d(feat[:, :, i * stripe_h: (i + 1) * stripe_h, :],(stripe_h, feat.size(-1))) 319 | 320 | 321 | # shape [N, c, 1, 1] 322 | local_feat = self.local_conv_list[i](local_feat.view(feat.size(0),feat.size(1),1,1)) 323 | 324 | 325 | # shape [N, c] 326 | local_feat = local_feat.view(local_feat.size(0), -1) 327 | local_feat_list.append(local_feat) 328 | 
local_feat_mem_list.append(local_feat_mem) 329 | 330 | 331 | if hasattr(self, 'fc_list'): 332 | logits_list.append(self.fc_list[i](local_feat)) 333 | 334 | 335 | 336 | feat_all = [lf for lf in local_feat_list] 337 | feat_all = torch.cat(feat_all, dim=1) 338 | 339 | feat_all_mem = [lf for lf in local_feat_mem_list] 340 | feat_all_mem = torch.cat(feat_all_mem, dim=2) 341 | 342 | lf_mem_pool = self.pool_mem(feat_all_mem).view(feat_all_mem.size(0), feat_all_mem.size(1)) 343 | lf_mem_feat = self.bottleneck(lf_mem_pool) 344 | 345 | 346 | if self.training: 347 | #return local_feat_list, logits_list, feat_all , x_mem_pool+ lf_mem_pool, self.classifier(x_mem_feat+lf_mem_feat) 348 | return local_feat_list, logits_list, feat_all , lf_mem_pool, self.classifier(lf_mem_feat) 349 | else: 350 | return self.l2norm(feat_all) 351 | else: 352 | if self.gm_pool == 'on': 353 | b, c, h, w = x.shape 354 | x = x.view(b, c, -1) 355 | p = 3.0 356 | x_pool = (torch.mean(x**p, dim=-1) + 1e-12)**(1/p) 357 | else: 358 | x_pool = self.avgpool(x) 359 | x_pool = x_pool.view(x_pool.size(0), x_pool.size(1)) 360 | 361 | feat = self.bottleneck(x_pool) 362 | 363 | if self.training: 364 | return x_pool, self.classifier(feat)#, scores 365 | else: 366 | return self.l2norm(x_pool), self.l2norm(feat) 367 | 368 | -------------------------------------------------------------------------------- /model_mine.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import init 4 | import torch.nn.functional as F 5 | from resnet import resnet50, resnet18 6 | 7 | class Normalize(nn.Module): 8 | def __init__(self, power=2): 9 | super(Normalize, self).__init__() 10 | self.power = power 11 | 12 | def forward(self, x): 13 | norm = x.pow(self.power).sum(1, keepdim=True).pow(1. 
/ self.power) 14 | out = x.div(norm) 15 | return out 16 | 17 | class Non_local(nn.Module): 18 | def __init__(self, in_channels, reduc_ratio=2): 19 | super(Non_local, self).__init__() 20 | 21 | self.in_channels = in_channels 22 | self.inter_channels = reduc_ratio//reduc_ratio 23 | 24 | self.g = nn.Sequential( 25 | nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, 26 | padding=0), 27 | ) 28 | 29 | self.W = nn.Sequential( 30 | nn.Conv2d(in_channels=self.inter_channels, out_channels=self.in_channels, 31 | kernel_size=1, stride=1, padding=0), 32 | nn.BatchNorm2d(self.in_channels), 33 | ) 34 | nn.init.constant_(self.W[1].weight, 0.0) 35 | nn.init.constant_(self.W[1].bias, 0.0) 36 | 37 | 38 | 39 | self.theta = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, 40 | kernel_size=1, stride=1, padding=0) 41 | 42 | self.phi = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, 43 | kernel_size=1, stride=1, padding=0) 44 | 45 | def forward(self, x): 46 | ''' 47 | :param x: (b, c, t, h, w) 48 | :return: 49 | ''' 50 | 51 | batch_size = x.size(0) 52 | g_x = self.g(x).view(batch_size, self.inter_channels, -1) 53 | g_x = g_x.permute(0, 2, 1) 54 | 55 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1) 56 | theta_x = theta_x.permute(0, 2, 1) 57 | phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) 58 | f = torch.matmul(theta_x, phi_x) 59 | N = f.size(-1) 60 | # f_div_C = torch.nn.functional.softmax(f, dim=-1) 61 | f_div_C = f / N 62 | 63 | y = torch.matmul(f_div_C, g_x) 64 | y = y.permute(0, 2, 1).contiguous() 65 | y = y.view(batch_size, self.inter_channels, *x.size()[2:]) 66 | W_y = self.W(y) 67 | z = W_y + x 68 | 69 | return z 70 | 71 | 72 | # ##################################################################### 73 | def weights_init_kaiming(m): 74 | classname = m.__class__.__name__ 75 | # print(classname) 76 | if classname.find('Conv') != -1: 77 | 
        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
    elif classname.find('Linear') != -1:
        init.kaiming_normal_(m.weight.data, a=0, mode='fan_out')
        init.zeros_(m.bias.data)
    elif classname.find('BatchNorm1d') != -1:
        init.normal_(m.weight.data, 1.0, 0.01)
        init.zeros_(m.bias.data)

def weights_init_classifier(m):
    classname = m.__class__.__name__
    if classname.find('Linear') != -1:
        init.normal_(m.weight.data, 0, 0.001)
        # `if m.bias:` raises for a multi-element tensor; test against None instead
        if m.bias is not None:
            init.zeros_(m.bias.data)



class visible_module(nn.Module):
    def __init__(self, arch='resnet50', share_net=1):
        super(visible_module, self).__init__()

        model_v = resnet50(pretrained=True,
                           last_conv_stride=1, last_conv_dilation=1)
        # avg pooling to global pooling
        self.share_net = share_net

        if self.share_net == 0:
            pass
        else:
            self.visible = nn.ModuleList()
            self.visible.conv1 = model_v.conv1
            self.visible.bn1 = model_v.bn1
            self.visible.relu = model_v.relu
            self.visible.maxpool = model_v.maxpool
            if self.share_net > 1:
                for i in range(1, self.share_net):
                    setattr(self.visible,'layer'+str(i), getattr(model_v,'layer'+str(i)))

    def forward(self, x):
        if self.share_net == 0:
            return x
        else:
            x = self.visible.conv1(x)
            x = self.visible.bn1(x)
            x = self.visible.relu(x)
            x = self.visible.maxpool(x)

            if self.share_net > 1:
                for i in range(1, self.share_net):
                    x = getattr(self.visible, 'layer'+str(i))(x)
            return x


class thermal_module(nn.Module):
    def __init__(self, arch='resnet50', share_net=1):
        super(thermal_module, self).__init__()

        model_t = resnet50(pretrained=True,
                           last_conv_stride=1, last_conv_dilation=1)
        # avg pooling to global pooling
        self.share_net = share_net

        if self.share_net == 0:
            pass
        else:
            self.thermal =
nn.ModuleList() 143 | self.thermal.conv1 = model_t.conv1 144 | self.thermal.bn1 = model_t.bn1 145 | self.thermal.relu = model_t.relu 146 | self.thermal.maxpool = model_t.maxpool 147 | if self.share_net > 1: 148 | for i in range(1, self.share_net): 149 | setattr(self.thermal,'layer'+str(i), getattr(model_t,'layer'+str(i))) 150 | 151 | def forward(self, x): 152 | if self.share_net == 0: 153 | return x 154 | else: 155 | x = self.thermal.conv1(x) 156 | x = self.thermal.bn1(x) 157 | x = self.thermal.relu(x) 158 | x = self.thermal.maxpool(x) 159 | 160 | if self.share_net > 1: 161 | for i in range(1, self.share_net): 162 | x = getattr(self.thermal, 'layer'+str(i))(x) 163 | return x 164 | 165 | 166 | class base_resnet(nn.Module): 167 | def __init__(self, arch='resnet50', share_net=1): 168 | super(base_resnet, self).__init__() 169 | 170 | model_base = resnet50(pretrained=True, 171 | last_conv_stride=1, last_conv_dilation=1) 172 | # avg pooling to global pooling 173 | model_base.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 174 | self.share_net = share_net 175 | if self.share_net == 0: 176 | self.base = model_base 177 | else: 178 | self.base = nn.ModuleList() 179 | 180 | if self.share_net > 4: 181 | pass 182 | else: 183 | for i in range(self.share_net, 5): 184 | setattr(self.base,'layer'+str(i), getattr(model_base,'layer'+str(i))) 185 | 186 | def forward(self, x): 187 | if self.share_net == 0: 188 | x = self.base.conv1(x) 189 | x = self.base.bn1(x) 190 | x = self.base.relu(x) 191 | x = self.base.maxpool(x) 192 | 193 | x = self.base.layer1(x) 194 | x = self.base.layer2(x) 195 | x = self.base.layer3(x) 196 | x = self.base.layer4(x) 197 | return x 198 | elif self.share_net > 4: 199 | return x 200 | else: 201 | for i in range(self.share_net, 5): 202 | x = getattr(self.base, 'layer'+str(i))(x) 203 | return x 204 | 205 | 206 | 207 | class embed_net(nn.Module): 208 | def __init__(self, class_num, no_local= 'off', gm_pool = 'on', arch='resnet50', share_net=1, pcb='on',local_feat_dim=256, 
num_strips=6): 209 | super(embed_net, self).__init__() 210 | 211 | self.thermal_module = thermal_module(arch=arch, share_net=share_net) 212 | self.visible_module = visible_module(arch=arch, share_net=share_net) 213 | self.base_resnet = base_resnet(arch=arch, share_net=share_net) 214 | 215 | self.non_local = no_local 216 | self.pcb = pcb 217 | if self.non_local =='on': 218 | pass 219 | 220 | 221 | pool_dim = 2048 222 | self.l2norm = Normalize(2) 223 | self.gm_pool = gm_pool 224 | 225 | if self.pcb == 'on': 226 | self.num_stripes=num_strips 227 | local_conv_out_channels=local_feat_dim 228 | 229 | self.local_conv_list = nn.ModuleList() 230 | for _ in range(self.num_stripes): 231 | conv = nn.Conv2d(pool_dim, local_conv_out_channels, 1) 232 | conv.apply(weights_init_kaiming) 233 | self.local_conv_list.append(nn.Sequential( 234 | conv, 235 | nn.BatchNorm2d(local_conv_out_channels), 236 | nn.ReLU(inplace=True) 237 | )) 238 | 239 | self.fc_list = nn.ModuleList() 240 | for _ in range(self.num_stripes): 241 | fc = nn.Linear(local_conv_out_channels, class_num) 242 | init.normal_(fc.weight, std=0.001) 243 | init.constant_(fc.bias, 0) 244 | self.fc_list.append(fc) 245 | 246 | 247 | else: 248 | self.bottleneck = nn.BatchNorm1d(pool_dim) 249 | self.bottleneck.bias.requires_grad_(False) # no shift 250 | 251 | self.classifier = nn.Linear(pool_dim, class_num, bias=False) 252 | 253 | self.bottleneck.apply(weights_init_kaiming) 254 | self.classifier.apply(weights_init_classifier) 255 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 256 | 257 | 258 | 259 | 260 | def forward(self, x1, x2, modal=0): 261 | if modal == 0: 262 | x1 = self.visible_module(x1) 263 | x2 = self.thermal_module(x2) 264 | x = torch.cat((x1, x2), 0) 265 | elif modal == 1: 266 | x = self.visible_module(x1) 267 | elif modal == 2: 268 | x = self.thermal_module(x2) 269 | 270 | # shared block 271 | if self.non_local == 'on': 272 | pass 273 | else: 274 | x = self.base_resnet(x) 275 | 276 | if self.pcb == 'on': 277 | feat = x 
278 | assert feat.size(2) % self.num_stripes == 0 279 | stripe_h = int(feat.size(2) / self.num_stripes) 280 | local_feat_list = [] 281 | logits_list = [] 282 | for i in range(self.num_stripes): 283 | # shape [N, C, 1, 1] 284 | 285 | # average pool 286 | #local_feat = F.avg_pool2d(feat[:, :, i * stripe_h: (i + 1) * stripe_h, :],(stripe_h, feat.size(-1))) 287 | if self.gm_pool == 'on': 288 | # gm pool 289 | local_feat = feat[:, :, i * stripe_h: (i + 1) * stripe_h, :] 290 | b, c, h, w = local_feat.shape 291 | local_feat = local_feat.view(b,c,-1) 292 | p = 10.0 # regDB: 10.0 SYSU: 3.0 293 | local_feat = (torch.mean(local_feat**p, dim=-1) + 1e-12)**(1/p) 294 | else: 295 | # average pool 296 | #local_feat = F.avg_pool2d(feat[:, :, i * stripe_h: (i + 1) * stripe_h, :],(stripe_h, feat.size(-1))) 297 | local_feat = F.max_pool2d(feat[:, :, i * stripe_h: (i + 1) * stripe_h, :],(stripe_h, feat.size(-1))) 298 | 299 | 300 | # shape [N, c, 1, 1] 301 | local_feat = self.local_conv_list[i](local_feat.view(feat.size(0),feat.size(1),1,1)) 302 | 303 | 304 | # shape [N, c] 305 | local_feat = local_feat.view(local_feat.size(0), -1) 306 | local_feat_list.append(local_feat) 307 | 308 | 309 | if hasattr(self, 'fc_list'): 310 | logits_list.append(self.fc_list[i](local_feat)) 311 | 312 | feat_all = [lf for lf in local_feat_list] 313 | feat_all = torch.cat(feat_all, dim=1) 314 | 315 | 316 | if self.training: 317 | return local_feat_list, logits_list, feat_all 318 | else: 319 | return self.l2norm(feat_all) 320 | else: 321 | if self.gm_pool == 'on': 322 | b, c, h, w = x.shape 323 | x = x.view(b, c, -1) 324 | p = 3.0 325 | x_pool = (torch.mean(x**p, dim=-1) + 1e-12)**(1/p) 326 | else: 327 | x_pool = self.avgpool(x) 328 | x_pool = x_pool.view(x_pool.size(0), x_pool.size(1)) 329 | 330 | feat = self.bottleneck(x_pool) 331 | 332 | if self.training: 333 | return x_pool, self.classifier(feat)#, scores 334 | else: 335 | return self.l2norm(x_pool), self.l2norm(feat) 
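The stripe-wise GeM pooling used in the `gm_pool == 'on'` branch above can be isolated in a few lines. The sketch below is a pure-Python illustration and not part of the repository: `gem_pool`, `split_stripes`, and `pool_stripes` are hypothetical helper names, and it operates on nested lists of a single channel rather than on tensors. It is meant to show how the exponent `p` (3.0 for SYSU and 10.0 for RegDB according to the inline comment) moves the pooled value from the plain mean toward the maximum.

```python
def gem_pool(values, p=3.0, eps=1e-12):
    """Generalized-mean (GeM) pooling over a flat list of non-negative activations.

    Mirrors the expression in the forward pass: (mean(x ** p) + eps) ** (1 / p).
    """
    if not values:
        raise ValueError("gem_pool needs at least one value")
    mean_pow = sum(v ** p for v in values) / len(values)
    return (mean_pow + eps) ** (1.0 / p)


def split_stripes(rows, num_stripes):
    """Split the rows of an H x W map into equal horizontal stripes (PCB style)."""
    if len(rows) % num_stripes != 0:
        raise ValueError("height must be divisible by num_stripes")
    stripe_h = len(rows) // num_stripes
    return [rows[i * stripe_h:(i + 1) * stripe_h] for i in range(num_stripes)]


def pool_stripes(rows, num_stripes, p=3.0):
    """GeM-pool every stripe of a single-channel H x W map into one scalar."""
    pooled = []
    for stripe in split_stripes(rows, num_stripes):
        flat = [v for row in stripe for v in row]
        pooled.append(gem_pool(flat, p=p))
    return pooled


if __name__ == "__main__":
    # Toy 6 x 2 single-channel map; values are post-ReLU, hence non-negative.
    fmap = [[0.0, 1.0], [2.0, 3.0], [1.0, 1.0], [4.0, 0.0], [0.5, 0.5], [2.0, 2.0]]
    for p in (1.0, 3.0, 10.0):
        print(p, [round(v, 3) for v in pool_stripes(fmap, num_stripes=3, p=p)])
```

With `p = 1` the result is the ordinary average of each stripe; larger `p` weights the strongest activations more heavily, which is one plausible reading of why the two datasets are given different exponents.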
--------------------------------------------------------------------------------
/pre_process_sysu.py:
--------------------------------------------------------------------------------
import numpy as np
from PIL import Image
import pdb
import os

data_path = '/media/hijune/datadisk/reid-data/SYSU RGB-IR Re-ID/SYSU-MM01'

rgb_cameras = ['cam1','cam2','cam4','cam5']
ir_cameras = ['cam3','cam6']

# load id info
file_path_train = os.path.join(data_path,'exp/train_id.txt')
file_path_val = os.path.join(data_path,'exp/val_id.txt')
with open(file_path_train, 'r') as file:
    ids = file.read().splitlines()
    ids = [int(y) for y in ids[0].split(',')]
    id_train = ["%04d" % x for x in ids]

with open(file_path_val, 'r') as file:
    ids = file.read().splitlines()
    ids = [int(y) for y in ids[0].split(',')]
    id_val = ["%04d" % x for x in ids]

# combine train and val split
id_train.extend(id_val)

files_rgb = []
files_ir = []
for id in sorted(id_train):
    for cam in rgb_cameras:
        img_dir = os.path.join(data_path,cam,id)
        if os.path.isdir(img_dir):
            new_files = sorted([img_dir+'/'+i for i in os.listdir(img_dir)])
            files_rgb.extend(new_files)

    for cam in ir_cameras:
        img_dir = os.path.join(data_path,cam,id)
        if os.path.isdir(img_dir):
            new_files = sorted([img_dir+'/'+i for i in os.listdir(img_dir)])
            files_ir.extend(new_files)

# relabel
pid_container = set()
for img_path in files_ir:
    pid = int(img_path[-13:-9])
    pid_container.add(pid)
pid2label = {pid:label for label, pid in enumerate(pid_container)}
fix_image_width = 144
fix_image_height = 288
def read_imgs(train_image):
    train_img = []
    train_label = []
    for img_path in train_image:
        # img
        img = Image.open(img_path)
        # Image.ANTIALIAS was removed in Pillow 10; Image.LANCZOS is the same filter
        img = img.resize((fix_image_width, fix_image_height), Image.LANCZOS)
        pix_array = np.array(img)

        train_img.append(pix_array)

        # label
        pid = int(img_path[-13:-9])
        pid = pid2label[pid]
        train_label.append(pid)
    return np.array(train_img), np.array(train_label)

# rgb images
train_img, train_label = read_imgs(files_rgb)
# data_path has no trailing slash, so join the file name instead of concatenating it
np.save(os.path.join(data_path, 'train_rgb_resized_img.npy'), train_img)
np.save(os.path.join(data_path, 'train_rgb_resized_label.npy'), train_label)

# ir images
train_img, train_label = read_imgs(files_ir)
np.save(os.path.join(data_path, 'train_ir_resized_img.npy'), train_img)
np.save(os.path.join(data_path, 'train_ir_resized_label.npy'), train_label)

--------------------------------------------------------------------------------
/re_rank.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from scipy.spatial.distance import cdist

def k_reciprocal(probFea,galFea,k1=20,k2=6,lambda_value=0.3, MemorySave = False, Minibatch = 2000):

    query_num = probFea.shape[0]
    all_num = query_num + galFea.shape[0]
    feat = np.append(probFea,galFea,axis = 0)
    feat = feat.astype(np.float16)
    #print('computing original distance')
    if MemorySave:
        original_dist = np.zeros(shape = [all_num,all_num],dtype = np.float16)
        i = 0
        while True:
            it = i + Minibatch
            if it < np.shape(feat)[0]:
                original_dist[i:it,] = np.power(cdist(feat[i:it,],feat),2).astype(np.float16)
            else:
                original_dist[i:,:] = np.power(cdist(feat[i:,],feat),2).astype(np.float16)
                break
            i = it
    else:
        original_dist = cdist(feat,feat).astype(np.float16)
        original_dist = np.power(original_dist,2).astype(np.float16)
    del feat
    gallery_num = original_dist.shape[0]
    original_dist = np.transpose(original_dist/np.max(original_dist,axis = 0))
    V = np.zeros_like(original_dist).astype(np.float16)
    initial_rank =
np.argsort(original_dist).astype(np.int32) 33 | 34 | 35 | #print('starting re_ranking') 36 | for i in range(all_num): 37 | # k-reciprocal neighbors 38 | forward_k_neigh_index = initial_rank[i,:k1+1] 39 | backward_k_neigh_index = initial_rank[forward_k_neigh_index,:k1+1] 40 | fi = np.where(backward_k_neigh_index==i)[0] 41 | k_reciprocal_index = forward_k_neigh_index[fi] 42 | k_reciprocal_expansion_index = k_reciprocal_index 43 | for j in range(len(k_reciprocal_index)): 44 | candidate = k_reciprocal_index[j] 45 | candidate_forward_k_neigh_index = initial_rank[candidate,:int(np.around(k1/2))+1] 46 | candidate_backward_k_neigh_index = initial_rank[candidate_forward_k_neigh_index,:int(np.around(k1/2))+1] 47 | fi_candidate = np.where(candidate_backward_k_neigh_index == candidate)[0] 48 | candidate_k_reciprocal_index = candidate_forward_k_neigh_index[fi_candidate] 49 | if len(np.intersect1d(candidate_k_reciprocal_index,k_reciprocal_index))> 2/3*len(candidate_k_reciprocal_index): 50 | k_reciprocal_expansion_index = np.append(k_reciprocal_expansion_index,candidate_k_reciprocal_index) 51 | 52 | k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index) 53 | weight = np.exp(-original_dist[i,k_reciprocal_expansion_index]) 54 | V[i,k_reciprocal_expansion_index] = weight/np.sum(weight) 55 | original_dist = original_dist[:query_num,] 56 | if k2 != 1: 57 | V_qe = np.zeros_like(V,dtype=np.float16) 58 | for i in range(all_num): 59 | V_qe[i,:] = np.mean(V[initial_rank[i,:k2],:],axis=0) 60 | V = V_qe 61 | del V_qe 62 | del initial_rank 63 | invIndex = [] 64 | for i in range(gallery_num): 65 | invIndex.append(np.where(V[:,i] != 0)[0]) 66 | 67 | jaccard_dist = np.zeros_like(original_dist,dtype = np.float16) 68 | 69 | 70 | for i in range(query_num): 71 | temp_min = np.zeros(shape=[1,gallery_num],dtype=np.float16) 72 | indNonZero = np.where(V[i,:] != 0)[0] 73 | indImages = [] 74 | indImages = [invIndex[ind] for ind in indNonZero] 75 | for j in range(len(indNonZero)): 76 | 
temp_min[0,indImages[j]] = temp_min[0,indImages[j]]+ np.minimum(V[i,indNonZero[j]],V[indImages[j],indNonZero[j]]) 77 | jaccard_dist[i] = 1-temp_min/(2-temp_min) 78 | 79 | final_dist = jaccard_dist*(1-lambda_value) + original_dist*lambda_value 80 | del original_dist 81 | del V 82 | del jaccard_dist 83 | final_dist = final_dist[:query_num,query_num:] 84 | return final_dist 85 | 86 | 87 | 88 | def random_walk(query_feat, gall_feat, alpha = 0.95): 89 | pg_sim = torch.from_numpy(np.matmul(query_feat, np.transpose(gall_feat))) 90 | gg_sim = torch.from_numpy(np.matmul(gall_feat, np.transpose(gall_feat))) 91 | 92 | one_diag = torch.eye(gg_sim.size(0), dtype=torch.double) 93 | # row normalization 94 | zeros_diag = gg_sim - gg_sim.diag().diag() 95 | A = F.softmax(zeros_diag, dim=1) 96 | 97 | A = (1-alpha) * torch.inverse(one_diag - alpha * A) 98 | pg_sim = torch.matmul(pg_sim, A.t()) 99 | 100 | return -pg_sim.numpy() -------------------------------------------------------------------------------- /resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | import torch.utils.model_zoo as model_zoo 4 | 5 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 6 | 'resnet152'] 7 | 8 | model_urls = { 9 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 10 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 11 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 12 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 13 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 14 | } 15 | 16 | 17 | def conv3x3(in_planes, out_planes, stride=1, dilation=1): 18 | """3x3 convolution with padding""" 19 | # original padding is 1; original dilation is 1 20 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 21 | padding=dilation, bias=False, dilation=dilation) 22 
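A note on `conv3x3` above: it sets `padding=dilation`, which for a 3x3 kernel keeps the spatial size unchanged at stride 1 whatever the dilation. The sketch below checks this with the standard convolution output-size formula in plain Python. The helper names `conv_out_size` and `conv3x3_out_size` are hypothetical and are not part of this repository.

```python
def conv_out_size(n, kernel_size=3, stride=1, padding=1, dilation=1):
    """Output length of one spatial dimension of a 2-D convolution."""
    return (n + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


def conv3x3_out_size(n, stride=1, dilation=1):
    # Same settings as conv3x3: kernel_size=3 and padding equal to dilation.
    return conv_out_size(n, kernel_size=3, stride=stride,
                         padding=dilation, dilation=dilation)


if __name__ == "__main__":
    for d in (1, 2, 4):
        # At stride 1 the size is preserved for every dilation.
        print(d, conv3x3_out_size(72, stride=1, dilation=d))
    # At stride 2 the size is halved, e.g. a 288-pixel input height.
    print(conv3x3_out_size(288, stride=2))
```

This appears to be why `last_conv_dilation` can be raised in `layer4` without changing the feature-map size produced by the first block, although the remaining blocks in `_make_layer` are built with the default dilation of 1.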


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride, dilation)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        # original padding is 1; original dilation is 1
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=dilation, bias=False, dilation=dilation)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNet(nn.Module):

    def __init__(self, block, layers, last_conv_stride=2, last_conv_dilation=1):

        self.inplanes = 64
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=last_conv_stride, dilation=last_conv_dilation)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, dilation))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        return x


def remove_fc(state_dict):
    """Remove the fc layer parameters from state_dict."""
    # for key, value in state_dict.items():
    for key, value in list(state_dict.items()):
        if key.startswith('fc.'):
            del state_dict[key]
    return state_dict


def resnet18(pretrained=False, **kwargs):
    """Constructs a ResNet-18 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    if pretrained:
        model.load_state_dict(remove_fc(model_zoo.load_url(model_urls['resnet18'])))
    return model


def resnet34(pretrained=False, **kwargs):
    """Constructs a ResNet-34 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
    if pretrained:
        model.load_state_dict(remove_fc(model_zoo.load_url(model_urls['resnet34'])))
    return model


def resnet50(pretrained=False, **kwargs):
    """Constructs a ResNet-50 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    if pretrained:
        model.load_state_dict(remove_fc(model_zoo.load_url(model_urls['resnet50'])))
    return model


def resnet101(pretrained=False, **kwargs):
    """Constructs a ResNet-101 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
    if pretrained:
        model.load_state_dict(
            remove_fc(model_zoo.load_url(model_urls['resnet101'])))
    return model


def resnet152(pretrained=False, **kwargs):
    """Constructs a ResNet-152 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
    if pretrained:
        model.load_state_dict(
            remove_fc(model_zoo.load_url(model_urls['resnet152'])))
    return model

--------------------------------------------------------------------------------
/test_mine_pcb.py:
--------------------------------------------------------------------------------
from __future__ import print_function
import argparse
import os
import time
import numpy as np
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.utils.data as data
import torchvision.transforms as transforms
from data_loader import SYSUData, RegDBData, TestData
from data_manager import *
from eval_metrics import eval_sysu, eval_regdb
# NOTE: model_mem_2 is not among the files listed in this repository;
# point this import at the module that defines embed_net.
from model_mem_2 import embed_net
from utils import *
import pdb
from re_rank import random_walk, k_reciprocal

parser = argparse.ArgumentParser(description='PyTorch Cross-Modality Training')
parser.add_argument('--dataset', default='sysu', help='dataset name: regdb or sysu')
parser.add_argument('--lr', default=0.1 , type=float, help='learning rate, 0.00035 for adam')
parser.add_argument('--optim', default='sgd', type=str, help='optimizer')
parser.add_argument('--arch', default='resnet50', type=str,
                    help='network baseline: resnet50')
parser.add_argument('--resume', '-r', default='', type=str,
                    help='resume from checkpoint')
parser.add_argument('--test-only', action='store_true', help='test only')
parser.add_argument('--model_path', default='save_model/', type=str,
                    help='model save path')
parser.add_argument('--save_epoch', default=20, type=int,
                    metavar='s', help='save the model every N epochs (default: 20)')
parser.add_argument('--log_path', default='log/', type=str,
                    help='log save path')
parser.add_argument('--vis_log_path', default='log/vis_log/', type=str,
                    help='visualization log save path')
parser.add_argument('--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('--img_w', default=144, type=int,
                    metavar='imgw', help='img width')
parser.add_argument('--img_h', default=288, type=int,
                    metavar='imgh', help='img height')
parser.add_argument('--batch-size', default=8, type=int,
                    metavar='B', help='training batch size')
parser.add_argument('--test-batch', default=64, type=int,
                    metavar='tb', help='testing batch size')
parser.add_argument('--method', default='base', type=str,
                    metavar='m', help='method type: base or agw')
parser.add_argument('--margin', default=0.3, type=float,
                    metavar='margin', help='triplet loss margin')
parser.add_argument('--num_pos', default=4, type=int,
                    help='num of pos per identity in each modality')
parser.add_argument('--trial', default=7, type=int,
                    metavar='t', help='trial (only for RegDB dataset)')
parser.add_argument('--seed', default=0, type=int,
                    metavar='t', help='random seed')
parser.add_argument('--gpu', default='0', type=str,
                    help='gpu device ids for CUDA_VISIBLE_DEVICES')
parser.add_argument('--mode', default='all', type=str, help='all or indoor for sysu')
parser.add_argument('--tvsearch', action='store_true', help='whether thermal to visible search on RegDB')

parser.add_argument('--share_net', default=2, type=int,
                    metavar='share', help='[1,2,3,4] the start number of shared network in the two-stream networks')
parser.add_argument('--re_rank', default='no', type=str, help='performing reranking. [random_walk | k_reciprocal | no]')
parser.add_argument('--pcb', default='on', type=str, help='performing PCB, on or off')

parser.add_argument('--w_center', default=1.0, type=float, help='the weight for center loss')

parser.add_argument('--local_feat_dim', default=256, type=int,
                    help='feature dimension of each local feature in PCB')
parser.add_argument('--num_strips', default=6, type=int,
                    help='num of local strips in PCB')

parser.add_argument('--label_smooth', default='on', type=str, help='performing label smooth or not')

args = parser.parse_args()
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

dataset = args.dataset
if dataset == 'sysu':
    # raw strings keep the Windows backslashes literal
    data_path = r'E:\chenfeng\dataset\SYSU-MM01/'
    n_class = 395
    test_mode = [1, 2]
elif dataset =='regdb':
    data_path = r'E:\chenfeng\dataset\RegDB/'
    n_class = 206
    test_mode = [2, 1]

device = 'cuda' if torch.cuda.is_available() else 'cpu'
best_acc = 0  # best test accuracy
start_epoch = 0
if args.pcb == 'on':
    pool_dim = args.num_strips * args.local_feat_dim
else:
    pool_dim = 2048
print('==> Building model..')
if args.method =='base':
    net = embed_net(n_class, no_local= 'off', gm_pool = 'on', arch=args.arch, share_net=args.share_net, pcb=args.pcb, local_feat_dim=args.local_feat_dim, num_strips=args.num_strips)
else:
    net = embed_net(n_class, no_local= 'on', gm_pool = 'on', arch=args.arch, share_net=args.share_net, pcb=args.pcb)
net.to(device)
cudnn.benchmark = True

checkpoint_path = args.model_path

if args.method =='id':
    criterion = nn.CrossEntropyLoss()
    criterion.to(device)

print('==> Loading data..')
# Data loading code
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])
transform_train = transforms.Compose([
    transforms.ToPILImage(),
    transforms.RandomCrop((args.img_h,args.img_w)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
])

transform_test = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((args.img_h,args.img_w)),
    transforms.ToTensor(),
    normalize,
])

end = time.time()



def extract_gall_feat(gall_loader):
    net.eval()
    print ('Extracting Gallery Feature...')
    start = time.time()
    ptr = 0
    gall_feat_pool = np.zeros((ngall, pool_dim))
    gall_feat_fc = np.zeros((ngall, pool_dim))
    with torch.no_grad():
        for batch_idx, (input, label ) in enumerate(gall_loader):
            batch_num = input.size(0)
            # follow the selected device instead of assuming CUDA
            input = input.to(device)
            if args.pcb == 'on':
                feat_pool = net(input, input, test_mode[0])
                gall_feat_pool[ptr:ptr+batch_num,: ] = feat_pool.detach().cpu().numpy()
            else:
                feat_pool, feat_fc = net(input, input, test_mode[0])
                gall_feat_pool[ptr:ptr+batch_num,: ] = feat_pool.detach().cpu().numpy()
                gall_feat_fc[ptr:ptr+batch_num,: ] = feat_fc.detach().cpu().numpy()
            ptr = ptr + batch_num
    print('Extracting Time:\t {:.3f}'.format(time.time()-start))
    if args.pcb == 'on':
        return gall_feat_pool
    else:
        return gall_feat_pool, gall_feat_fc

def extract_query_feat(query_loader):
    net.eval()
    print ('Extracting Query Feature...')
    start = time.time()
    ptr = 0
    query_feat_pool = np.zeros((nquery, pool_dim))
    query_feat_fc = np.zeros((nquery, pool_dim))
    with torch.no_grad():
        for batch_idx, (input, label ) in enumerate(query_loader):
            batch_num = input.size(0)
            # follow the selected device instead of assuming CUDA
            input = input.to(device)
            if args.pcb == 'on':
                feat_pool = net(input, input, test_mode[1])
                query_feat_pool[ptr:ptr+batch_num,: ] = feat_pool.detach().cpu().numpy()
            else:
                feat_pool, feat_fc = net(input, input, test_mode[1])
                query_feat_pool[ptr:ptr+batch_num,: ] = feat_pool.detach().cpu().numpy()
                query_feat_fc[ptr:ptr+batch_num,: ] = feat_fc.detach().cpu().numpy()
            ptr = ptr + batch_num
    print('Extracting Time:\t {:.3f}'.format(time.time()-start))
    if args.pcb == 'on':
        return query_feat_pool
    else:
        return query_feat_pool, query_feat_fc


if dataset == 'sysu':

    print('==> Resuming from checkpoint..')

    # model_path = checkpoint_path + args.resume
    model_path = checkpoint_path + 'sysu_c_tri_pcb_on_w_tri_1.0_s6_f256_share_net2_base_gm_k8_p6_lr_0.1_seed_0_best.t'
    if os.path.isfile(model_path):
        # report the path that is actually loaded (args.resume is unused here)
        print('==> loading checkpoint {}'.format(model_path))
        checkpoint = torch.load(model_path)
        net.load_state_dict(checkpoint['net'])
        print('==> loaded checkpoint {} (epoch {})'
              .format(model_path, checkpoint['epoch']))
    else:
        print('==> no checkpoint found at {}'.format(model_path))

    # testing set
    query_img, query_label, query_cam = process_query_sysu(data_path, mode=args.mode)
    gall_img, gall_label, gall_cam = process_gallery_sysu(data_path, mode=args.mode, trial=0)

    nquery = len(query_label)
    ngall = len(gall_label)
    print("Dataset statistics:")
    print(" ------------------------------")
    print(" subset | # ids | # images")
    print(" ------------------------------")
    print(" query | {:5d} | {:8d}".format(len(np.unique(query_label)), nquery))
    print(" gallery | {:5d} | {:8d}".format(len(np.unique(gall_label)), ngall))
    print(" ------------------------------")

    queryset = TestData(query_img, query_label, transform=transform_test, img_size=(args.img_w, args.img_h))
    query_loader = data.DataLoader(queryset, batch_size=args.test_batch, shuffle=False, num_workers=4)
    print('Data Loading Time:\t {:.3f}'.format(time.time() - end))

    if args.pcb == 'on':
        query_feat_pool = extract_query_feat(query_loader)
    else:
        query_feat_pool, query_feat_fc = extract_query_feat(query_loader)
    for trial in range(10):
        print('Test Trial: {}'.format(trial))
        gall_img, gall_label, gall_cam = process_gallery_sysu(data_path, mode=args.mode, trial=trial)

        trial_gallset = TestData(gall_img, gall_label, transform=transform_test, img_size=(args.img_w, args.img_h))
        trial_gall_loader = data.DataLoader(trial_gallset, batch_size=args.test_batch, shuffle=False, num_workers=4)

        if args.pcb == 'on':
            gall_feat_pool = extract_gall_feat(trial_gall_loader)
        else:
            gall_feat_pool, gall_feat_fc = extract_gall_feat(trial_gall_loader)

        if args.re_rank == 'random_walk':
            distmat_pool = random_walk(query_feat_pool, gall_feat_pool)
            if args.pcb == 'off': distmat = random_walk(query_feat_fc, gall_feat_fc)
        elif args.re_rank == 'k_reciprocal':
            distmat_pool = k_reciprocal(query_feat_pool, gall_feat_pool)
            if args.pcb == 'off': distmat = k_reciprocal(query_feat_fc, gall_feat_fc)
        elif args.re_rank == 'no':
            # compute the similarity
            distmat_pool = -np.matmul(query_feat_pool, np.transpose(gall_feat_pool))
            if args.pcb == 'off': distmat = -np.matmul(query_feat_fc, np.transpose(gall_feat_fc))
        # pool5 feature
        cmc_pool, mAP_pool, mINP_pool = eval_sysu(distmat_pool, query_label, gall_label, query_cam, gall_cam)

        if args.pcb == 'off':
            # fc feature
            cmc, mAP, mINP = eval_sysu(distmat, query_label, gall_label, query_cam, gall_cam)
        if trial == 0:
            if args.pcb == 'off':
                all_cmc = cmc
                all_mAP = mAP
                all_mINP = mINP
            all_cmc_pool = cmc_pool
            all_mAP_pool = mAP_pool
            all_mINP_pool = mINP_pool
        else:
            if args.pcb == 'off':
                all_cmc = all_cmc + cmc
                all_mAP = all_mAP + mAP
                all_mINP = all_mINP + mINP
            all_cmc_pool = all_cmc_pool + cmc_pool
            all_mAP_pool = all_mAP_pool + mAP_pool
            all_mINP_pool = all_mINP_pool + mINP_pool


        if args.pcb == 'off':
            print(
                'FC: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format(
                    cmc[0], cmc[4], cmc[9], cmc[19], mAP, mINP))
        print(
            'POOL: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format(
                cmc_pool[0], cmc_pool[4], cmc_pool[9], cmc_pool[19], mAP_pool, mINP_pool))


elif dataset == 'regdb':

    for trial in range(10):
        test_trial = trial +1
        print('Test Trial: {}'.format(test_trial))
        #model_path = checkpoint_path + 'regdbtest_share_net2_base_gm_p4_n8_lr_0.1_seed_0_trial_{}_best.t'.format(test_trial)
        model_path = checkpoint_path + 'regdb_c_tri_pcb_on_w_tri_2.0_s6_f256_share_net2_base_gm10_k4_p8_lr_0.1_seed_0_trial_{}_best.t'.format(test_trial)
        if os.path.isfile(model_path):
            print('==> loading checkpoint {}'.format(model_path))
            checkpoint = torch.load(model_path)
            net.load_state_dict(checkpoint['net'])

        # training set
        trainset = RegDBData(data_path, test_trial, transform=transform_train)
        # generate the idx of each person identity
        color_pos, thermal_pos = GenIdx(trainset.train_color_label, trainset.train_thermal_label)

        # testing set
        query_img, query_label = process_test_regdb(data_path, trial=test_trial, modal='visible')
        gall_img, gall_label = process_test_regdb(data_path, trial=test_trial, modal='thermal')

        gallset = TestData(gall_img, gall_label, transform=transform_test, img_size=(args.img_w, args.img_h))
        gall_loader = data.DataLoader(gallset, batch_size=args.test_batch, shuffle=False, num_workers=args.workers)

        nquery = len(query_label)
        ngall = len(gall_label)

        queryset = TestData(query_img, query_label, transform=transform_test, img_size=(args.img_w, args.img_h))
        query_loader = data.DataLoader(queryset, batch_size=args.test_batch, shuffle=False, num_workers=4)
        print('Data Loading Time:\t {:.3f}'.format(time.time() - end))

        if args.pcb == 'on':
            query_feat_pool = extract_query_feat(query_loader)
            gall_feat_pool = extract_gall_feat(gall_loader)
        else:
            query_feat_pool, query_feat_fc = extract_query_feat(query_loader)
            gall_feat_pool, gall_feat_fc = extract_gall_feat(gall_loader)

        if args.tvsearch:
            if args.re_rank == 'random_walk':
                distmat_pool = random_walk(gall_feat_pool, query_feat_pool)
                if args.pcb == 'off': distmat = random_walk(gall_feat_fc, query_feat_fc)
            elif args.re_rank == 'k_reciprocal':
                distmat_pool = k_reciprocal(gall_feat_pool, query_feat_pool)
                if args.pcb == 'off': distmat = k_reciprocal(gall_feat_fc, query_feat_fc)
            elif args.re_rank == 'no':
                # compute the similarity
                distmat_pool = -np.matmul(gall_feat_pool, np.transpose(query_feat_pool))
                if args.pcb == 'off': distmat = -np.matmul(gall_feat_fc, np.transpose(query_feat_fc))
            # pool5 feature
            cmc_pool, mAP_pool, mINP_pool = eval_regdb(distmat_pool, gall_label, query_label)
            if args.pcb == 'off':
                # fc feature
                cmc, mAP, mINP = eval_regdb(distmat,gall_label, query_label )
        else:
            if args.re_rank == 'random_walk':
                distmat_pool = random_walk(query_feat_pool, gall_feat_pool)
                if args.pcb == 'off': distmat = random_walk(query_feat_fc, gall_feat_fc)
            elif args.re_rank == 'k_reciprocal':
                distmat_pool = k_reciprocal(query_feat_pool, gall_feat_pool)
                if args.pcb == 'off': distmat = k_reciprocal(query_feat_fc, gall_feat_fc)
            elif args.re_rank == 'no':
                # compute the similarity
                distmat_pool = -np.matmul(query_feat_pool, np.transpose(gall_feat_pool))
                if args.pcb == 'off': distmat = -np.matmul(query_feat_fc, np.transpose(gall_feat_fc))
            # pool5 feature
            cmc_pool, mAP_pool, mINP_pool = eval_regdb(distmat_pool, query_label, gall_label)
            if args.pcb == 'off':
                # fc feature
                cmc, mAP, mINP = eval_regdb(distmat, query_label, gall_label)


        if trial == 0:
            if args.pcb == 'off':
                all_cmc = cmc
                all_mAP = mAP
                all_mINP = mINP
            all_cmc_pool = cmc_pool
            all_mAP_pool = mAP_pool
            all_mINP_pool = mINP_pool
        else:
            if args.pcb == 'off':
                all_cmc = all_cmc + cmc
                all_mAP = all_mAP + mAP
                all_mINP = all_mINP + mINP
            all_cmc_pool = all_cmc_pool + cmc_pool
            all_mAP_pool = all_mAP_pool + mAP_pool
            all_mINP_pool = all_mINP_pool + mINP_pool

        if args.pcb == 'off':
            print(
                'FC: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format(
                    cmc[0], cmc[4], cmc[9], cmc[19], mAP, mINP))
        print(
            'POOL: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format(
                cmc_pool[0], cmc_pool[4], cmc_pool[9], cmc_pool[19], mAP_pool, mINP_pool))

# average over the 10 trials; mINP is averaged as well so the summary
# does not report the last trial's value
if args.pcb == 'off':
    cmc = all_cmc / 10
    mAP = all_mAP / 10
    mINP = all_mINP / 10

cmc_pool = all_cmc_pool / 10
mAP_pool = all_mAP_pool / 10
mINP_pool = all_mINP_pool / 10
print('All Average:')

if args.pcb == 'off':
    print(
        'FC: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format(
            cmc[0], cmc[4], cmc[9], cmc[19], mAP, mINP))
print('POOL: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format(
    cmc_pool[0], cmc_pool[4], cmc_pool[9], cmc_pool[19], mAP_pool, mINP_pool))

--------------------------------------------------------------------------------
/train_HCT.py:
--------------------------------------------------------------------------------
from __future__ import print_function
import argparse
import sys
import time
import torch
import
torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torch.autograd import Variable
import torch.utils.data as data
import torchvision
import torchvision.transforms as transforms
from data_loader import SYSUData, RegDBData, TestData
from data_manager import *
from eval_metrics import eval_sysu, eval_regdb
#from model_main import embed_net
# NOTE: model_mem is not among the files listed in this repository;
# point this import at the module that defines embed_net.
from model_mem import embed_net
from utils import *
from loss import OriTripletLoss, HcTripletLoss, CrossEntropyLabelSmooth, EntropyLossEncap, BarlowTwins_loss_mem, MemTriLoss
from torch.optim import lr_scheduler
from tensorboardX import SummaryWriter
import torch.nn.functional as F
import math
import os
import numpy as np


parser = argparse.ArgumentParser(description='PyTorch Cross-Modality Training')
parser.add_argument('--dataset', default='regdb', help='dataset name: regdb or sysu')
parser.add_argument('--lr', default=0.3 , type=float, help='learning rate, 0.00035 for adam')
parser.add_argument('--optim', default='sgd', type=str, help='optimizer')
parser.add_argument('--arch', default='resnet50', type=str,
                    help='network baseline: resnet50')
parser.add_argument('--resume', '-r', default='', type=str,
                    help='resume from checkpoint')
parser.add_argument('--test-only', action='store_true', help='test only')
parser.add_argument('--model_path', default='save_model/', type=str,
                    help='model save path')
parser.add_argument('--save_epoch', default=20, type=int,
                    metavar='s', help='save the model every N epochs (default: 20)')
parser.add_argument('--log_path', default='log/', type=str,
                    help='log save path')
parser.add_argument('--vis_log_path', default='log/vis_log_ddag/', type=str,
                    help='visualization log save path')
parser.add_argument('--workers', default=0, type=int, metavar='N',
                    help='number of data loading workers (default: 0)')
parser.add_argument('--img_w', default=144, type=int,
                    metavar='imgw', help='img width')
parser.add_argument('--img_h', default=288, type=int,
                    metavar='imgh', help='img height')
parser.add_argument('--batch-size', default=8, type=int,
                    metavar='B', help='training batch size')
parser.add_argument('--test-batch', default=64, type=int,
                    metavar='tb', help='testing batch size')
parser.add_argument('--part', default=3, type=int,
                    metavar='tb', help='part number')
parser.add_argument('--drop', default=0.2, type=float,
                    metavar='drop', help='dropout ratio')
parser.add_argument('--margin', default=0.3, type=float,
                    metavar='margin', help='triplet loss margin')
parser.add_argument('--num_pos', default=6, type=int,
                    help='num of pos per identity in each modality')
parser.add_argument('--trial', default=10, type=int,
                    metavar='t', help='trial (only for RegDB dataset)')
parser.add_argument('--seed', default=0, type=int,
                    metavar='t', help='random seed')
parser.add_argument('--gpu', default='0', type=str,
                    help='gpu device ids for CUDA_VISIBLE_DEVICES')
parser.add_argument('--mode', default='all', type=str, help='all or indoor')

parser.add_argument('--cpool', default='no', type=str, help='The coarse branch pooling: no | wpa | avg | max | gem')
parser.add_argument('--bpool', default='avg', type=str, help='The backbone (fine branch) pooling: avg | max | gem')
parser.add_argument('--label_smooth', default='off', type=str, help='performing label smooth or not')
parser.add_argument('--hcloss', default='HcTri', type=str, help='OriTri, HcTri')
parser.add_argument('--margin_hc', default=0, type=float,
                    metavar='margin', help='additional hc triplet loss margin')
parser.add_argument('--fuse', default='sum', type=str, help='sum | cat')

args = parser.parse_args()
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

set_seed(args.seed)

dataset = args.dataset
if dataset == 'sysu':
    # TODO: define your data path
    # raw strings keep the Windows backslashes literal
    data_path = r'E:\chenfeng\dataset\SYSU-MM01/'
    log_path = os.path.join(args.log_path, 'sysu_log_ddag/')
    test_mode = [1, 2]  # infrared to visible
elif dataset =='regdb':
    # TODO: define your data path for RegDB dataset
    data_path = r'E:\chenfeng\dataset\RegDB/'
    log_path = os.path.join(args.log_path, 'regdb_log_ddag/')
    test_mode = [2, 1]  # visible to infrared

checkpoint_path = args.model_path

if not os.path.isdir(log_path):
    os.makedirs(log_path)
if not os.path.isdir(checkpoint_path):
    os.makedirs(checkpoint_path)
if not os.path.isdir(args.vis_log_path):
    os.makedirs(args.vis_log_path)

# log file name
suffix = dataset+'_bpool_{}_cpool_{}_hcloss_{}_fuse_{}'.format(args.bpool,args.cpool,args.hcloss,args.fuse) #c2f:coarse to fine sm: simple module

suffix = suffix + '_hcmargin_{}'.format(args.margin_hc) + '_gm_ls_{}_s1'.format(args.label_smooth) # ls: label_smooth

if args.cpool == 'wpa':
    suffix = suffix + '_P_{}'.format(args.part)
suffix = suffix + '_drop_{}_{}_{}_lr_{}_seed_{}'.format(args.drop, args.num_pos, args.batch_size, args.lr, args.seed)
if not args.optim == 'sgd':
    suffix = suffix + '_' + args.optim
if dataset == 'regdb':
    suffix = suffix + '_trial_{}'.format(args.trial)

sys.stdout = Logger(log_path + suffix + '_os.txt')

vis_log_dir = args.vis_log_path + suffix + '/'

if not os.path.isdir(vis_log_dir):
    os.makedirs(vis_log_dir)
writer = SummaryWriter(vis_log_dir)
print("==========\nArgs:{}\n==========".format(args))
device = 'cuda' if torch.cuda.is_available() else 'cpu'
best_acc = 0  # best test accuracy
start_epoch = 0

feature_dim = 2048
feature_dim_att = 2048 if args.fuse == "sum" else 4096

end = time.time()

print('==> Loading data..')
# Data loading code
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
transform_train = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Pad(10),
    transforms.RandomCrop((args.img_h, args.img_w)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
])
transform_test = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((args.img_h, args.img_w)),
    transforms.ToTensor(),
    normalize,
])


if dataset == 'sysu':
    # training set
    trainset = SYSUData(data_path, transform=transform_train)
    # generate the idx of each person identity
    color_pos, thermal_pos = GenIdx(trainset.train_color_label, trainset.train_thermal_label)

    # testing set
    query_img, query_label, query_cam = process_query_sysu(data_path, mode=args.mode)
    gall_img, gall_label, gall_cam = process_gallery_sysu(data_path, mode=args.mode, trial=0)

elif dataset == 'regdb':
    # training set
    trainset = RegDBData(data_path, args.trial, transform=transform_train)
    # generate the idx of each person identity
    color_pos, thermal_pos = GenIdx(trainset.train_color_label, trainset.train_thermal_label)

    # testing set
    query_img, query_label = process_test_regdb(data_path, trial=args.trial, modal='visible')
    gall_img, gall_label = process_test_regdb(data_path, trial=args.trial, modal='thermal')

gallset = TestData(gall_img, gall_label, transform=transform_test, img_size=(args.img_w, args.img_h))
queryset = TestData(query_img, query_label, transform=transform_test, img_size=(args.img_w, args.img_h))

# testing data loader
gall_loader = data.DataLoader(gallset, batch_size=args.test_batch, shuffle=False, num_workers=args.workers)
query_loader = data.DataLoader(queryset, batch_size=args.test_batch, shuffle=False, num_workers=args.workers)


n_class = len(np.unique(trainset.train_color_label))

nquery = len(query_label)
ngall = len(gall_label)

print('Dataset {} statistics:'.format(dataset))
print(' ------------------------------')
print(' subset | # ids | # images')
print(' ------------------------------')
print(' visible | {:5d} | {:8d}'.format(n_class, len(trainset.train_color_label)))
print(' thermal | {:5d} | {:8d}'.format(n_class, len(trainset.train_thermal_label)))
print(' ------------------------------')
print(' query | {:5d} | {:8d}'.format(len(np.unique(query_label)), nquery))
print(' gallery | {:5d} | {:8d}'.format(len(np.unique(gall_label)), ngall))
print(' ------------------------------')
print('Data Loading Time:\t {:.3f}'.format(time.time() - end))

print('==> Building model..')
net = embed_net(n_class, drop=args.drop, part=args.part, arch=args.arch, cpool=args.cpool,bpool=args.bpool,fuse=args.fuse)
net.to(device)
cudnn.benchmark = True

if len(args.resume) > 0:
    model_path = checkpoint_path + args.resume
    if os.path.isfile(model_path):
        print('==> loading checkpoint {}'.format(args.resume))
        checkpoint = torch.load(model_path)
        start_epoch = checkpoint['epoch']
        net.load_state_dict(checkpoint['net'])
        print('==> loaded checkpoint {} (epoch {})'
              .format(args.resume, checkpoint['epoch']))
    else:
        print('==> no checkpoint found at {}'.format(args.resume))

# define loss function
if args.label_smooth == 'on':
    criterion1 = CrossEntropyLabelSmooth(n_class)
else:
    criterion1 = nn.CrossEntropyLoss()
loader_batch = args.batch_size * args.num_pos
criterion2 = OriTripletLoss(batch_size=loader_batch, margin=args.margin)
#criterion2 = HcTripletLoss(batch_size=loader_batch, margin=args.margin)
if args.hcloss == 'OriTri':
| criterion_hc = OriTripletLoss(batch_size=loader_batch, margin=args.margin) 224 | if args.hcloss == 'HcTri': 225 | criterion_hc = HcTripletLoss(batch_size=loader_batch, margin=args.margin+args.margin_hc) 226 | if args.hcloss == 'no': 227 | pass 228 | criterion1.to(device) 229 | criterion2.to(device) 230 | if args.hcloss != 'no': 231 | criterion_hc.to(device) 232 | 233 | # memory att update 234 | tr_entropy_loss_func = BarlowTwins_loss_mem() 235 | tri_mem_loss_fuc = MemTriLoss() 236 | l1_mem_loss_func = nn.SmoothL1Loss() 237 | 238 | # optimizer 239 | if args.optim == 'sgd': 240 | if args.cpool != 'no': 241 | ignored_params = list(map(id, net.bottleneck.parameters())) \ 242 | + list(map(id, net.classifier.parameters())) \ 243 | + list(map(id, net.classifier_att.parameters())) \ 244 | + list(map(id, net.cpool_layer.parameters())) 245 | 246 | base_params = filter(lambda p: id(p) not in ignored_params, net.parameters()) 247 | 248 | optimizer_P = optim.SGD([ 249 | {'params': base_params, 'lr': 0.1 * args.lr}, 250 | {'params': net.bottleneck.parameters(), 'lr': args.lr}, 251 | {'params': net.classifier.parameters(), 'lr': args.lr}, 252 | {'params': net.classifier_att.parameters(), 'lr': args.lr}, 253 | {'params': net.cpool_layer.parameters(), 'lr': args.lr}, 254 | ], 255 | weight_decay=5e-4, momentum=0.9, nesterov=True) 256 | else: 257 | ignored_params = list(map(id, net.bottleneck.parameters())) \ 258 | + list(map(id, net.classifier.parameters())) 259 | 260 | base_params = filter(lambda p: id(p) not in ignored_params, net.parameters()) 261 | 262 | optimizer_P = optim.SGD([ 263 | {'params': base_params, 'lr': 0.1 * args.lr}, 264 | {'params': net.bottleneck.parameters(), 'lr': args.lr}, 265 | {'params': net.classifier.parameters(), 'lr': args.lr}, 266 | ], 267 | weight_decay=5e-4, momentum=0.9, nesterov=True) 268 | 269 | 270 | # exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1) 271 | def adjust_learning_rate(optimizer_P, epoch): 272 | """Sets the 
learning rate: linear warm-up over the first 10 epochs, full LR until epoch 20, then 0.1x until epoch 50 and 0.01x afterwards""" 273 | if epoch < 10: 274 | lr = args.lr * (epoch + 1) / 10 275 | elif 10 <= epoch < 20: 276 | lr = args.lr 277 | elif 20 <= epoch < 50: 278 | lr = args.lr * 0.1 279 | elif epoch >= 50: 280 | lr = args.lr * 0.01 281 | 282 | optimizer_P.param_groups[0]['lr'] = 0.1 * lr 283 | for i in range(len(optimizer_P.param_groups) - 1): 284 | optimizer_P.param_groups[i + 1]['lr'] = lr 285 | return lr 286 | 287 | 288 | def train(epoch): 289 | # adjust learning rate 290 | current_lr = adjust_learning_rate(optimizer_P, epoch) 291 | train_loss = AverageMeter() 292 | id_loss = AverageMeter() 293 | tri_loss = AverageMeter() 294 | data_time = AverageMeter() 295 | batch_time = AverageMeter() 296 | tri_mem_loss = AverageMeter() 297 | ce_mem_loss = AverageMeter() 298 | correct = 0 299 | total = 0 300 | 301 | # switch to train mode 302 | net.train() 303 | end = time.time() 304 | 305 | for batch_idx, (input1, input2, label1, label2) in enumerate(trainloader): 306 | 307 | labels = torch.cat((label1, label2), 0) 308 | 309 | input1 = Variable(input1.cuda()) 310 | input2 = Variable(input2.cuda()) 311 | 312 | labels = Variable(labels.cuda()) 313 | data_time.update(time.time() - end) 314 | 315 | if args.cpool != 'no': 316 | # Forward into the network 317 | feat, out0, feat_att, out_att, att_mem, feat_mem = net(input1, input2) 318 | # Part attention loss 319 | loss_p = criterion1(out_att, labels) 320 | if args.hcloss != 'no': 321 | loss_p_hc, _ = criterion_hc(feat_att, labels) 322 | else: 323 | # Forward into the network 324 | feat, out0, att_mem, feat_mem, x_mem_feat, out_mem = net(input1, input2) 325 | loss_mem_br_cls = criterion1(out_mem, labels.long()) 326 | loss_mem_br_tri,_ = criterion2(x_mem_feat, labels) 327 | 328 | 329 | # baseline loss: identity loss + triplet loss Eq. 
(1) 330 | loss_id = criterion1(out0, labels.long()) 331 | loss_tri, batch_acc = criterion2(feat, labels) 332 | # loss mem att 333 | loss_mem = tr_entropy_loss_func(att_mem) 334 | loss_mem_tri,_ = tri_mem_loss_fuc(feat_mem,labels,att_mem) 335 | #att_mem_c_1 , att_mem_c_2 = att_mem_c.chunk(2,dim=0) 336 | #loss_mem_c = l1_mem_loss_func(att_mem_c_1, att_mem_c_2) 337 | #loss_hc, _ = criterion_hc(feat, labels) 338 | correct += (batch_acc / 2) 339 | _, predicted = out0.max(1) 340 | correct += (predicted.eq(labels).sum().item() / 2) 341 | 342 | if args.cpool != 'no': 343 | # Instance-level part-aggregated feature learning Eq. (10) 344 | if args.hcloss != 'no': 345 | loss = loss_id + loss_tri + loss_p + loss_p_hc 346 | else: 347 | loss = loss_id + loss_tri + loss_p 348 | else: 349 | loss = loss_id + loss_tri + loss_mem_br_cls * 0.1 + loss_mem_br_tri # memory-branch losses are only computed on this path 350 | 351 | loss = loss + loss_mem + loss_mem_tri #+ loss_mem_c 352 | #loss = loss + loss_mem_tri 353 | 354 | # optimization 355 | optimizer_P.zero_grad() 356 | loss.backward() 357 | optimizer_P.step() 358 | 359 | # log different loss components 360 | train_loss.update(loss.item(), 2 * input1.size(0)) 361 | id_loss.update(loss_id.item(), 2 * input1.size(0)) 362 | tri_loss.update(loss_tri.item(), 2 * input1.size(0)) 363 | tri_mem_loss.update(loss_mem_tri.item(), 2 * input1.size(0)) 364 | ce_mem_loss.update(loss_mem.item(),2 * input1.size(0)) 365 | #graph_loss.update(loss_G.item(), 2 * input1.size(0)) 366 | total += labels.size(0) 367 | 368 | # measure elapsed time 369 | batch_time.update(time.time() - end) 370 | end = time.time() 371 | if batch_idx % 50 == 0: 372 | print('Epoch: [{}][{}/{}] ' 373 | 'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) ' 374 | 'lr:{:.4f} ' 375 | 'Loss: {train_loss.val:.4f} ({train_loss.avg:.4f}) ' 376 | 'iLoss: {id_loss.val:.4f} ({id_loss.avg:.4f}) ' 377 | 'TLoss: {tri_loss.val:.4f} ({tri_loss.avg:.4f}) ' 378 | 'TriMem: {trimem.val:.4f} ({trimem.avg:.4f}) ' 379 | 'CeMem: {cemem.val:.4f} 
({cemem.avg:.4f}) ' 380 | 'Accu: {:.2f}'.format( 381 | epoch, batch_idx, len(trainloader), current_lr, 382 | 100. * correct / total, batch_time=batch_time, 383 | train_loss=train_loss, id_loss=id_loss, tri_loss=tri_loss, trimem = tri_mem_loss, cemem=ce_mem_loss)) 384 | 385 | writer.add_scalar('total_loss', train_loss.avg, epoch) 386 | writer.add_scalar('id_loss', id_loss.avg, epoch) 387 | writer.add_scalar('tri_loss', tri_loss.avg, epoch) 388 | #writer.add_scalar('graph_loss', graph_loss.avg, epoch) 389 | writer.add_scalar('lr', current_lr, epoch) 390 | # computer wG 391 | #return 1. / (1. + train_loss.avg) 392 | 393 | def test(epoch): 394 | # switch to evaluation mode 395 | net.eval() 396 | print('Extracting Gallery Feature...') 397 | start = time.time() 398 | ptr = 0 399 | gall_feat = np.zeros((ngall, feature_dim)) 400 | gall_feat_att = np.zeros((ngall, feature_dim_att)) 401 | with torch.no_grad(): 402 | for batch_idx, (input, label) in enumerate(gall_loader): 403 | batch_num = input.size(0) 404 | input = Variable(input.cuda()) 405 | if args.cpool != 'no': 406 | feat, feat_att = net(input, input, test_mode[0]) 407 | gall_feat_att[ptr:ptr + batch_num, :] = feat_att.detach().cpu().numpy() 408 | gall_feat[ptr:ptr + batch_num, :] = feat.detach().cpu().numpy() 409 | else: 410 | feat, x_mem_feat = net(input, input, test_mode[0]) 411 | gall_feat[ptr:ptr + batch_num, :] = feat.detach().cpu().numpy() 412 | gall_feat_att[ptr:ptr + batch_num, :] = x_mem_feat.detach().cpu().numpy() 413 | ptr = ptr + batch_num 414 | print('Extracting Time:\t {:.3f}'.format(time.time() - start)) 415 | 416 | # switch to evaluation 417 | net.eval() 418 | print('Extracting Query Feature...') 419 | start = time.time() 420 | ptr = 0 421 | query_feat = np.zeros((nquery, feature_dim)) 422 | query_feat_att = np.zeros((nquery, feature_dim_att)) 423 | with torch.no_grad(): 424 | for batch_idx, (input, label) in enumerate(query_loader): 425 | batch_num = input.size(0) 426 | input = Variable(input.cuda()) 
427 | if args.cpool != 'no': 428 | feat, feat_att = net(input, input, test_mode[1]) 429 | query_feat_att[ptr:ptr + batch_num, :] = feat_att.detach().cpu().numpy() 430 | query_feat[ptr:ptr + batch_num, :] = feat.detach().cpu().numpy() 431 | else: 432 | feat, x_mem_feat = net(input, input, test_mode[1]) 433 | query_feat[ptr:ptr + batch_num, :] = feat.detach().cpu().numpy() 434 | query_feat_att[ptr:ptr + batch_num, :] = x_mem_feat.detach().cpu().numpy() 435 | ptr = ptr + batch_num 436 | print('Extracting Time:\t {:.3f}'.format(time.time() - start)) 437 | 438 | start = time.time() 439 | # compute the similarity (negated below so that it acts as a distance) 440 | distmat = np.matmul(query_feat, np.transpose(gall_feat)) 441 | if args.cpool != 'no': 442 | distmat_att = np.matmul(query_feat_att, np.transpose(gall_feat_att)) 443 | 444 | # evaluation 445 | if dataset == 'regdb': 446 | cmc, mAP, mINP = eval_regdb(-distmat, query_label, gall_label) 447 | if args.cpool != 'no': 448 | cmc_att, mAP_att, mINP_att = eval_regdb(-distmat_att, query_label, gall_label) 449 | elif dataset == 'sysu': 450 | cmc, mAP, mINP = eval_sysu(-distmat, query_label, gall_label, query_cam, gall_cam) 451 | if args.cpool != 'no': 452 | cmc_att, mAP_att, mINP_att = eval_sysu(-distmat_att, query_label, gall_label, query_cam, gall_cam) 453 | print('Evaluation Time:\t {:.3f}'.format(time.time() - start)) 454 | 455 | writer.add_scalar('rank1', cmc[0], epoch) 456 | writer.add_scalar('mAP', mAP, epoch) 457 | writer.add_scalar('mINP', mINP, epoch) 458 | if args.cpool != 'no': 459 | writer.add_scalar('rank1_att', cmc_att[0], epoch) 460 | writer.add_scalar('mAP_att', mAP_att, epoch) 461 | writer.add_scalar('mINP_att', mINP_att, epoch) 462 | return cmc, mAP, mINP, cmc_att, mAP_att, mINP_att 463 | else: 464 | return cmc, mAP, mINP 465 | 466 | 467 | # training 468 | print('==> Start Training...') 469 | for epoch in range(start_epoch, 61): # 61 epochs in total; a resumed run continues from start_epoch 470 | 471 | print('==> Preparing Data Loader...') 
472 | # identity sampler: 473 | sampler = IdentitySampler(trainset.train_color_label, \ 474 | trainset.train_thermal_label, color_pos, thermal_pos, args.num_pos, args.batch_size, 475 | epoch) 476 | 477 | trainset.cIndex = sampler.index1 # color index 478 | trainset.tIndex = sampler.index2 # infrared index 479 | '''print(epoch) 480 | print(trainset.cIndex) 481 | print(trainset.tIndex)''' 482 | 483 | loader_batch = args.batch_size * args.num_pos 484 | 485 | trainloader = data.DataLoader(trainset, batch_size=loader_batch, \ 486 | sampler=sampler, num_workers=args.workers, drop_last=True) 487 | 488 | # training 489 | train(epoch) 490 | 491 | if epoch > 0 and epoch % 5 == 0: 492 | print('Test Epoch: {}'.format(epoch)) 493 | 494 | if args.cpool != 'no': 495 | # testing 496 | cmc, mAP, mINP, cmc_att, mAP_att, mINP_att = test(epoch) 497 | # log output FC: f_bn, the fine branch feature FC_att: f_bnf, the coarse branch feature 498 | print('FC: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 499 | cmc[0], cmc[4], cmc[9], cmc[19], mAP, mINP)) 500 | 501 | print('FC_att: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 502 | cmc_att[0], cmc_att[4], cmc_att[9], cmc_att[19], mAP_att, mINP_att)) 503 | 504 | else: 505 | # testing 506 | cmc, mAP, mINP = test(epoch) 507 | # log output 508 | print('FC: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 509 | cmc[0], cmc[4], cmc[9], cmc[19], mAP, mINP)) 510 | 511 | # save model 512 | if args.cpool != 'no': 513 | if cmc_att[0] >= best_acc: # not the real best for sysu-mm01 514 | best_acc = cmc_att[0] 515 | best_epoch = epoch 516 | best_mAP = mAP_att 517 | best_mINP = mINP_att 518 | state = { 519 | 'net': net.state_dict(), 520 | 'cmc': cmc_att, 521 | 'mAP': mAP_att, 522 | 'epoch': epoch, 523 | } 524 | torch.save(state, checkpoint_path + suffix + '_best.t') 525 | else: 526 | if 
cmc[0] >= best_acc: # not the real best for sysu-mm01 527 | best_acc = cmc[0] 528 | best_epoch = epoch 529 | best_mAP = mAP 530 | best_mINP = mINP 531 | state = { 532 | 'net': net.state_dict(), 533 | 'cmc': cmc, 534 | 'mAP': mAP, 535 | 'epoch': epoch, 536 | } 537 | torch.save(state, checkpoint_path + suffix + '_best.t') 538 | 539 | print('Best Epoch [{}], Rank-1: {:.2%} | mAP: {:.2%}| mINP: {:.2%}'.format(best_epoch, best_acc, best_mAP, best_mINP)) -------------------------------------------------------------------------------- /train_MGMRA.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | import sys 4 | import time 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import torch.optim as optim 9 | import torch.backends.cudnn as cudnn 10 | from torch.autograd import Variable 11 | import torch.utils.data as data 12 | import torchvision 13 | import torchvision.transforms as transforms 14 | from data_loader import SYSUData, RegDBData, TestData 15 | from data_manager import * 16 | from eval_metrics import eval_sysu, eval_regdb 17 | from model_MGMRA import embed_net 18 | from utils import * 19 | from loss import OriTripletLoss, CenterTripletLoss, CrossEntropyLabelSmooth, TripletLoss_WRT 20 | from tensorboardX import SummaryWriter 21 | from re_rank import random_walk, k_reciprocal 22 | 23 | import numpy as np 24 | np.set_printoptions(threshold=np.inf) 25 | 26 | parser = argparse.ArgumentParser(description='PyTorch Cross-Modality Training') 27 | parser.add_argument('--dataset', default='regdb', help='dataset name: regdb or sysu]') 28 | parser.add_argument('--lr', default=0.01 , type=float, help='learning rate, 0.00035 for adam, 0.01 for sysu') 29 | parser.add_argument('--optim', default='sgd', type=str, help='optimizer') 30 | parser.add_argument('--arch', default='resnet50', type=str, 31 | help='network baseline:resnet18 or resnet50') 32 | 
parser.add_argument('--resume', '-r', default='', type=str, 33 | help='resume from checkpoint') 34 | parser.add_argument('--test-only', action='store_true', help='test only') 35 | parser.add_argument('--model_path', default='save_model/', type=str, 36 | help='model save path') 37 | parser.add_argument('--save_epoch', default=100, type=int, 38 | metavar='s', help='save model every 10 epochs') 39 | parser.add_argument('--log_path', default='log/', type=str, 40 | help='log save path') 41 | parser.add_argument('--vis_log_path', default='log/vis_log/', type=str, 42 | help='log save path') 43 | parser.add_argument('--workers', default=4, type=int, metavar='N', 44 | help='number of data loading workers (default: 4)') 45 | parser.add_argument('--img_w', default=144, type=int, 46 | metavar='imgw', help='img width') 47 | parser.add_argument('--img_h', default=288, type=int, 48 | metavar='imgh', help='img height') 49 | parser.add_argument('--batch-size', default=4, type=int, 50 | metavar='B', help='training batch size') 51 | parser.add_argument('--test-batch', default=64, type=int, 52 | metavar='tb', help='testing batch size') 53 | parser.add_argument('--method', default='base', type=str, 54 | metavar='m', help='method type: base or agw') 55 | parser.add_argument('--margin', default=0.3, type=float, 56 | metavar='margin', help='triplet loss margin') 57 | parser.add_argument('--num_pos', default=4, type=int, 58 | help='num of pos per identity in each modality') 59 | parser.add_argument('--trial', default=10, type=int, 60 | metavar='t', help='trial (only for RegDB dataset)') 61 | parser.add_argument('--seed', default=0, type=int, 62 | metavar='t', help='random seed') 63 | parser.add_argument('--gpu', default='0', type=str, 64 | help='gpu device ids for CUDA_VISIBLE_DEVICES') 65 | parser.add_argument('--mode', default='all', type=str, help='all or indoor') 66 | 67 | parser.add_argument('--share_net', default=2, type=int, 68 | metavar='share', help='[1,2,3,4,5]the start number of 
shared network in the two-stream networks') 69 | parser.add_argument('--re_rank', default='k_reciprocal', type=str, help='re-ranking method: [random_walk | k_reciprocal | no]') 70 | parser.add_argument('--pcb', default='on', type=str, help='performing PCB, on or off') 71 | parser.add_argument('--w_center', default=2.0, type=float, help='the weight for center loss') 72 | 73 | parser.add_argument('--local_feat_dim', default=256, type=int, 74 | help='feature dimension of each local feature in PCB') 75 | parser.add_argument('--num_strips', default=6, type=int, 76 | help='num of local strips in PCB') 77 | 78 | parser.add_argument('--label_smooth', default='on', type=str, help='label smoothing, on or off') 79 | 80 | args = parser.parse_args() 81 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 82 | 83 | set_seed(args.seed) 84 | 85 | dataset = args.dataset 86 | if dataset == 'sysu': 87 | # TODO: define your data path 88 | data_path = "/home/wmj/reid/data/SYSU/" 89 | log_path = os.path.join(args.log_path, 'sysu_log_ddag/') 90 | test_mode = [1, 2] # infrared to visible 91 | elif dataset == 'regdb': 92 | # TODO: define your data path for RegDB dataset 93 | data_path = "/home/wmj/reid/data/RegDB/" 94 | log_path = os.path.join(args.log_path, 'regdb_log_ddag/') 95 | test_mode = [2, 1] # [2,1] visible to infrared 96 | 97 | checkpoint_path = args.model_path 98 | 99 | if not os.path.isdir(log_path): 100 | os.makedirs(log_path) 101 | if not os.path.isdir(checkpoint_path): 102 | os.makedirs(checkpoint_path) 103 | if not os.path.isdir(args.vis_log_path): 104 | os.makedirs(args.vis_log_path) 105 | 106 | suffix = dataset+'_c_tri_pcb_{}_w_tri_{}'.format(args.pcb,args.w_center) 107 | if args.pcb=='on': 108 | suffix = suffix + '_s{}_f{}'.format(args.num_strips, args.local_feat_dim) 109 | 110 | suffix = suffix + '_share_net{}'.format(args.share_net) 111 | if args.method=='agw': 112 | suffix = suffix + '_agw_k{}_p{}_lr_{}_seed_{}'.format(args.num_pos, args.batch_size, args.lr, 
args.seed) 113 | else: 114 | suffix = suffix + '_base_gm10_k{}_p{}_lr_{}_seed_{}'.format(args.num_pos, args.batch_size, args.lr, args.seed) 115 | 116 | 117 | if not args.optim == 'sgd': 118 | suffix = suffix + '_' + args.optim 119 | 120 | if dataset == 'regdb': 121 | suffix = suffix + '_trial_{}'.format(args.trial) 122 | 123 | sys.stdout = Logger(log_path + suffix + '_os.txt') 124 | 125 | vis_log_dir = args.vis_log_path + suffix + '/' 126 | 127 | if not os.path.isdir(vis_log_dir): 128 | os.makedirs(vis_log_dir) 129 | writer = SummaryWriter(vis_log_dir) 130 | print("==========\nArgs:{}\n==========".format(args)) 131 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 132 | best_acc = 0 # best test accuracy 133 | start_epoch = 0 134 | 135 | print('==> Loading data..') 136 | # Data loading code 137 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 138 | transform_train = transforms.Compose([ 139 | transforms.ToPILImage(), 140 | transforms.Pad(10), 141 | transforms.RandomCrop((args.img_h, args.img_w)), 142 | transforms.RandomHorizontalFlip(), 143 | transforms.ToTensor(), 144 | normalize, 145 | ]) 146 | transform_test = transforms.Compose([ 147 | transforms.ToPILImage(), 148 | transforms.Resize((args.img_h, args.img_w)), 149 | transforms.ToTensor(), 150 | normalize, 151 | ]) 152 | 153 | end = time.time() 154 | if dataset == 'sysu': 155 | # training set 156 | trainset = SYSUData(data_path, transform=transform_train) 157 | # generate the idx of each person identity 158 | color_pos, thermal_pos = GenIdx(trainset.train_color_label, trainset.train_thermal_label) 159 | 160 | # testing set 161 | query_img, query_label, query_cam = process_query_sysu(data_path, mode=args.mode) 162 | gall_img, gall_label, gall_cam = process_gallery_sysu(data_path, mode=args.mode, trial=0) 163 | 164 | elif dataset == 'regdb': 165 | # training set 166 | trainset = RegDBData(data_path, args.trial, transform=transform_train) 167 | # generate the idx of 
each person identity 168 | color_pos, thermal_pos = GenIdx(trainset.train_color_label, trainset.train_thermal_label) 169 | 170 | # testing set 171 | if test_mode[0] == 2: 172 | ### V -> I 173 | query_img, query_label = process_test_regdb(data_path, trial=args.trial, modal='visible') 174 | gall_img, gall_label = process_test_regdb(data_path, trial=args.trial, modal='thermal') 175 | 176 | if test_mode[0] == 1: 177 | #### I -> V 178 | query_img, query_label = process_test_regdb(data_path, trial=args.trial, modal='thermal') 179 | gall_img, gall_label = process_test_regdb(data_path, trial=args.trial, modal='visible') 180 | 181 | gallset = TestData(gall_img, gall_label, transform=transform_test, img_size=(args.img_w, args.img_h)) 182 | queryset = TestData(query_img, query_label, transform=transform_test, img_size=(args.img_w, args.img_h)) 183 | 184 | # testing data loader 185 | gall_loader = data.DataLoader(gallset, batch_size=args.test_batch, shuffle=False, num_workers=args.workers) 186 | query_loader = data.DataLoader(queryset, batch_size=args.test_batch, shuffle=False, num_workers=args.workers) 187 | 188 | n_class = len(np.unique(trainset.train_color_label)) 189 | nquery = len(query_label) 190 | ngall = len(gall_label) 191 | 192 | print('Dataset {} statistics:'.format(dataset)) 193 | print(' ------------------------------') 194 | print(' subset | # ids | # images') 195 | print(' ------------------------------') 196 | print(' visible | {:5d} | {:8d}'.format(n_class, len(trainset.train_color_label))) 197 | print(' thermal | {:5d} | {:8d}'.format(n_class, len(trainset.train_thermal_label))) 198 | print(' ------------------------------') 199 | print(' query | {:5d} | {:8d}'.format(len(np.unique(query_label)), nquery)) 200 | print(' gallery | {:5d} | {:8d}'.format(len(np.unique(gall_label)), ngall)) 201 | print(' ------------------------------') 202 | print('Data Loading Time:\t {:.3f}'.format(time.time() - end)) 203 | 204 | print('==> Building model..') 205 | if 
args.method =='base': 206 | net = embed_net(n_class, no_local= 'off', gm_pool = 'on', arch=args.arch, share_net=args.share_net, pcb=args.pcb, local_feat_dim=args.local_feat_dim, num_strips=args.num_strips) 207 | else: 208 | net = embed_net(n_class, no_local= 'on', gm_pool = 'on', arch=args.arch, share_net=args.share_net, pcb=args.pcb) 209 | net.to(device) 210 | 211 | 212 | cudnn.benchmark = True 213 | 214 | if len(args.resume) > 0: 215 | model_path = checkpoint_path + args.resume 216 | if os.path.isfile(model_path): 217 | print('==> loading checkpoint {}'.format(args.resume)) 218 | checkpoint = torch.load(model_path) 219 | start_epoch = checkpoint['epoch'] 220 | net.load_state_dict(checkpoint['net']) 221 | print('==> loaded checkpoint {} (epoch {})' 222 | .format(args.resume, checkpoint['epoch'])) 223 | else: 224 | print('==> no checkpoint found at {}'.format(args.resume)) 225 | 226 | # define loss function 227 | if args.label_smooth == 'off': 228 | criterion_id = nn.CrossEntropyLoss() 229 | else: 230 | criterion_id = CrossEntropyLabelSmooth(n_class) 231 | 232 | if args.method == 'agw': 233 | criterion_tri = TripletLoss_WRT() 234 | else: 235 | loader_batch = args.batch_size * args.num_pos 236 | #criterion_tri= OriTripletLoss(batch_size=loader_batch, margin=args.margin) 237 | criterion_tri= CenterTripletLoss(batch_size=loader_batch, margin=args.margin) 238 | 239 | criterion_part = torch.nn.MSELoss() 240 | 241 | criterion_id.to(device) 242 | criterion_tri.to(device) 243 | 244 | 245 | 246 | if args.optim == 'sgd': 247 | if args.pcb == 'on': 248 | ignored_params = list(map(id, net.local_conv_list.parameters())) \ 249 | + list(map(id, net.fc_list.parameters())) 250 | 251 | base_params = filter(lambda p: id(p) not in ignored_params, net.parameters()) 252 | 253 | optimizer = optim.SGD([ 254 | {'params': base_params, 'lr': 0.1 * args.lr}, 255 | {'params': net.local_conv_list.parameters(), 'lr': args.lr}, 256 | {'params': net.fc_list.parameters(), 'lr': args.lr} 257 | ], 
258 | weight_decay=5e-4, momentum=0.9, nesterov=True) 259 | else: 260 | ignored_params = list(map(id, net.bottleneck.parameters())) \ 261 | + list(map(id, net.classifier.parameters())) 262 | 263 | base_params = filter(lambda p: id(p) not in ignored_params, net.parameters()) 264 | 265 | optimizer = optim.SGD([ 266 | {'params': base_params, 'lr': 0.1 * args.lr}, 267 | {'params': net.bottleneck.parameters(), 'lr': args.lr}, 268 | {'params': net.classifier.parameters(), 'lr': args.lr}], 269 | weight_decay=5e-4, momentum=0.9, nesterov=True) 270 | 271 | def adjust_learning_rate(optimizer, epoch): 272 | """Sets the learning rate: linear warm-up over the first 10 epochs, full LR until epoch 20, then 0.1x until epoch 50 and 0.01x afterwards""" 273 | if epoch < 10: 274 | lr = args.lr * (epoch + 1) / 10 275 | elif 10 <= epoch < 20: 276 | lr = args.lr 277 | elif 20 <= epoch < 50: 278 | lr = args.lr * 0.1 279 | elif epoch >= 50: 280 | lr = args.lr * 0.01 281 | 282 | optimizer.param_groups[0]['lr'] = 0.1 * lr 283 | for i in range(len(optimizer.param_groups) - 1): 284 | optimizer.param_groups[i + 1]['lr'] = lr 285 | 286 | return lr 287 | 288 | 289 | def train(epoch): 290 | 291 | current_lr = adjust_learning_rate(optimizer, epoch) 292 | train_loss = AverageMeter() 293 | id_loss = AverageMeter() 294 | tri_loss = AverageMeter() 295 | data_time = AverageMeter() 296 | batch_time = AverageMeter() 297 | correct = 0 298 | total = 0 299 | 300 | # switch to train mode 301 | net.train() 302 | end = time.time() 303 | 304 | for batch_idx, (input1, input2, label1, label2) in enumerate(trainloader): 305 | 306 | labels = torch.cat((label1, label2), 0) 307 | 308 | input1 = Variable(input1.cuda()) 309 | input2 = Variable(input2.cuda()) 310 | 311 | labels = Variable(labels.cuda()) 312 | data_time.update(time.time() - end) 313 | 314 | 315 | if args.pcb == 'on': 316 | feat, out0, feat_all, feat_mem, out_mem, out_part = net(input1, input2) 317 | loss_id = criterion_id(out0[0], labels.long()) 318 | loss_tri_l, batch_acc = criterion_tri(feat[0], 
labels) 319 | for i in range(len(feat)-1): 320 | loss_id += criterion_id(out0[i+1], labels.long()) 321 | loss_tri_l += criterion_tri(feat[i+1], labels)[0] 322 | loss_tri, batch_acc = criterion_tri(feat_all, labels) 323 | loss_tri += loss_tri_l * args.w_center # 324 | correct += batch_acc 325 | ### for mem branch 326 | loss_tri_mem, _ = criterion_tri(feat_mem, labels.long()) 327 | loss_id_mem = criterion_id(out_mem, labels.long()) 328 | 329 | loss_part = criterion_part(out_part[0], out_part[1]) 330 | loss = loss_id + loss_tri + loss_id_mem*0.1 + loss_tri_mem + (loss_part-0.1) * 0.1 331 | else: 332 | feat, out0 = net(input1, input2) 333 | loss_id = criterion_id(out0, labels.long()) 334 | 335 | loss_tri, batch_acc = criterion_tri(feat, labels) 336 | correct += (batch_acc / 2) 337 | _, predicted = out0.max(1) 338 | correct += (predicted.eq(labels).sum().item() / 2) 339 | loss = loss_id + loss_tri * args.w_center # 340 | 341 | 342 | optimizer.zero_grad() 343 | loss.backward() 344 | optimizer.step() 345 | 346 | # log the loss components 347 | train_loss.update(loss.item(), 2 * input1.size(0)) 348 | id_loss.update(loss_id.item(), 2 * input1.size(0)) 349 | tri_loss.update(loss_tri.item(), 2 * input1.size(0)) 350 | total += labels.size(0) 351 | 352 | # measure elapsed time 353 | batch_time.update(time.time() - end) 354 | end = time.time() 355 | if batch_idx % 50 == 0: 356 | print('Epoch: [{}][{}/{}] ' 357 | 'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) ' 358 | 'lr:{:.3f} ' 359 | 'Loss: {train_loss.val:.4f} ({train_loss.avg:.4f}) ' 360 | 'iLoss: {id_loss.val:.4f} ({id_loss.avg:.4f}) ' 361 | 'TLoss: {tri_loss.val:.4f} ({tri_loss.avg:.4f}) ' 362 | 'Accu: {:.2f}'.format( 363 | epoch, batch_idx, len(trainloader), current_lr, 364 | 100. 
* correct / total, batch_time=batch_time, 365 | train_loss=train_loss, id_loss=id_loss,tri_loss=tri_loss)) 366 | 367 | writer.add_scalar('total_loss', train_loss.avg, epoch) 368 | writer.add_scalar('id_loss', id_loss.avg, epoch) 369 | writer.add_scalar('tri_loss', tri_loss.avg, epoch) 370 | writer.add_scalar('lr', current_lr, epoch) 371 | 372 | 373 | def test(epoch): 374 | # switch to evaluation mode 375 | net.eval() 376 | print('Extracting Gallery Feature...') 377 | start = time.time() 378 | ptr = 0 379 | if args.pcb == 'on': 380 | feat_dim = args.num_strips * args.local_feat_dim 381 | else: 382 | feat_dim = 2048 383 | gall_feat = np.zeros((ngall, feat_dim)) 384 | gall_feat_att = np.zeros((ngall, feat_dim)) 385 | with torch.no_grad(): 386 | for batch_idx, (input, label) in enumerate(gall_loader): 387 | batch_num = input.size(0) 388 | input = Variable(input.cuda()) 389 | if args.pcb == 'on': 390 | feat = net(input, input, test_mode[0]) 391 | gall_feat[ptr:ptr + batch_num, :] = feat.detach().cpu().numpy() 392 | else: 393 | feat, feat_att = net(input, input, test_mode[0]) 394 | gall_feat[ptr:ptr + batch_num, :] = feat.detach().cpu().numpy() 395 | gall_feat_att[ptr:ptr + batch_num, :] = feat_att.detach().cpu().numpy() 396 | ptr = ptr + batch_num 397 | print('Extracting Time:\t {:.3f}'.format(time.time() - start)) 398 | 399 | # switch to evaluation 400 | net.eval() 401 | print('Extracting Query Feature...') 402 | start = time.time() 403 | ptr = 0 404 | 405 | query_feat = np.zeros((nquery, feat_dim)) 406 | query_feat_att = np.zeros((nquery, feat_dim)) 407 | with torch.no_grad(): 408 | for batch_idx, (input, label) in enumerate(query_loader): 409 | batch_num = input.size(0) 410 | input = Variable(input.cuda()) 411 | if args.pcb == 'on': 412 | feat = net(input, input, test_mode[1]) 413 | query_feat[ptr:ptr + batch_num, :] = feat.detach().cpu().numpy() 414 | else: 415 | feat, feat_att = net(input, input, test_mode[1]) 416 | query_feat[ptr:ptr + batch_num, :] = 
feat.detach().cpu().numpy() 417 | query_feat_att[ptr:ptr + batch_num, :] = feat_att.detach().cpu().numpy() 418 | ptr = ptr + batch_num 419 | print('Extracting Time:\t {:.3f}'.format(time.time() - start)) 420 | 421 | start = time.time() 422 | 423 | 424 | if args.re_rank == 'random_walk': 425 | distmat = random_walk(query_feat, gall_feat) 426 | if args.pcb == 'off': distmat_att = random_walk(query_feat_att, gall_feat_att) 427 | elif args.re_rank == 'k_reciprocal': 428 | distmat = k_reciprocal(query_feat, gall_feat) 429 | if args.pcb == 'off': distmat_att = k_reciprocal(query_feat_att, gall_feat_att) 430 | elif args.re_rank == 'no': 431 | # compute the similarity 432 | distmat = -np.matmul(query_feat, np.transpose(gall_feat)) 433 | if args.pcb == 'off': distmat_att = -np.matmul(query_feat_att, np.transpose(gall_feat_att)) 434 | 435 | # evaluation 436 | if dataset == 'regdb': 437 | cmc, mAP, mINP = eval_regdb(distmat, query_label, gall_label) 438 | if args.pcb == 'off': cmc_att, mAP_att, mINP_att = eval_regdb(distmat_att, query_label, gall_label) 439 | elif dataset == 'sysu': 440 | cmc, mAP, mINP = eval_sysu(distmat, query_label, gall_label, query_cam, gall_cam) 441 | if args.pcb == 'off': cmc_att, mAP_att, mINP_att = eval_sysu(distmat_att, query_label, gall_label, query_cam, gall_cam) 442 | print('Evaluation Time:\t {:.3f}'.format(time.time() - start)) 443 | 444 | writer.add_scalar('rank1', cmc[0], epoch) 445 | writer.add_scalar('mAP', mAP, epoch) 446 | writer.add_scalar('mINP', mINP, epoch) 447 | if args.pcb == 'off': 448 | writer.add_scalar('rank1_att', cmc_att[0], epoch) 449 | writer.add_scalar('mAP_att', mAP_att, epoch) 450 | writer.add_scalar('mINP_att', mINP_att, epoch) 451 | 452 | return cmc, mAP, mINP, cmc_att, mAP_att, mINP_att 453 | else: 454 | return cmc, mAP, mINP 455 | 456 | 457 | 458 | # training 459 | print('==> Start Training...') 460 | for epoch in range(start_epoch, 61 - start_epoch): 461 | 462 | print('==> Preparing Data Loader...') 463 | # identity 
sampler 464 | sampler = IdentitySampler(trainset.train_color_label, \ 465 | trainset.train_thermal_label, color_pos, thermal_pos, args.num_pos, args.batch_size, 466 | epoch) 467 | 468 | trainset.cIndex = sampler.index1 # color index 469 | trainset.tIndex = sampler.index2 # thermal index 470 | print(epoch) 471 | 472 | loader_batch = args.batch_size * args.num_pos 473 | 474 | trainloader = data.DataLoader(trainset, batch_size=loader_batch, \ 475 | sampler=sampler, num_workers=args.workers, drop_last=True) 476 | 477 | # training 478 | train(epoch) 479 | 480 | if epoch > 9 and epoch % 2 == 0: 481 | print('Test Epoch: {}'.format(epoch)) 482 | 483 | # testing 484 | if args.pcb == 'off': 485 | cmc, mAP, mINP, cmc_fc, mAP_fc, mINP_fc = test(epoch) 486 | else: 487 | cmc_fc, mAP_fc, mINP_fc = test(epoch) 488 | # save model 489 | if cmc_fc[0] > best_acc: # not the real best for sysu-mm01 490 | best_acc = cmc_fc[0] 491 | best_epoch = epoch 492 | best_mAP = mAP_fc 493 | best_mINP = mINP_fc 494 | state = { 495 | 'net': net.state_dict(), 496 | 'cmc': cmc_fc, 497 | 'mAP': mAP_fc, 498 | 'mINP': mINP_fc, 499 | 'epoch': epoch, 500 | } 501 | torch.save(state, checkpoint_path + suffix + '_best.t') 502 | 503 | if args.pcb == 'off': 504 | print('POOL: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 505 | cmc[0], cmc[4], cmc[9], cmc[19], mAP, mINP)) 506 | 507 | print('FC: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 508 | cmc_fc[0], cmc_fc[4], cmc_fc[9], cmc_fc[19], mAP_fc, mINP_fc)) 509 | print('Best Epoch [{}], Rank-1: {:.2%} | mAP: {:.2%}| mINP: {:.2%}'.format(best_epoch, best_acc, best_mAP, best_mINP)) 510 | 511 | 512 | -------------------------------------------------------------------------------- /train_SGMRA.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | import sys 4 
| import time 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import torch.optim as optim 9 | import torch.backends.cudnn as cudnn 10 | from torch.autograd import Variable 11 | import torch.utils.data as data 12 | import torchvision 13 | import torchvision.transforms as transforms 14 | from data_loader import SYSUData, RegDBData, TestData 15 | from data_manager import * 16 | from eval_metrics import eval_sysu, eval_regdb 17 | from model_SGMRA import embed_net 18 | from utils import * 19 | from loss import OriTripletLoss, CenterTripletLoss, CrossEntropyLabelSmooth, TripletLoss_WRT 20 | from tensorboardX import SummaryWriter 21 | from re_rank import random_walk, k_reciprocal 22 | 23 | import numpy as np 24 | np.set_printoptions(threshold=np.inf) 25 | 26 | parser = argparse.ArgumentParser(description='PyTorch Cross-Modality Training') 27 | parser.add_argument('--dataset', default='regdb', help='dataset name: regdb or sysu]') 28 | parser.add_argument('--lr', default=0.1 , type=float, help='learning rate, 0.00035 for adam') 29 | parser.add_argument('--optim', default='sgd', type=str, help='optimizer') 30 | parser.add_argument('--arch', default='resnet50', type=str, 31 | help='network baseline:resnet18 or resnet50') 32 | parser.add_argument('--resume', '-r', default='', type=str, 33 | help='resume from checkpoint') 34 | parser.add_argument('--test-only', action='store_true', help='test only') 35 | parser.add_argument('--model_path', default='save_model/', type=str, 36 | help='model save path') 37 | parser.add_argument('--save_epoch', default=100, type=int, 38 | metavar='s', help='save model every 10 epochs') 39 | parser.add_argument('--log_path', default='log/', type=str, 40 | help='log save path') 41 | parser.add_argument('--vis_log_path', default='log/vis_log/', type=str, 42 | help='log save path') 43 | parser.add_argument('--workers', default=4, type=int, metavar='N', 44 | help='number of data loading workers (default: 4)') 45 | 
parser.add_argument('--img_w', default=144, type=int,
                    metavar='imgw', help='img width')
parser.add_argument('--img_h', default=288, type=int,
                    metavar='imgh', help='img height')
parser.add_argument('--batch-size', default=8, type=int,
                    metavar='B', help='training batch size')
parser.add_argument('--test-batch', default=64, type=int,
                    metavar='tb', help='testing batch size')
parser.add_argument('--method', default='base', type=str,
                    metavar='m', help='method type: base or agw')
parser.add_argument('--margin', default=0.3, type=float,
                    metavar='margin', help='triplet loss margin')
parser.add_argument('--num_pos', default=4, type=int,
                    help='num of pos per identity in each modality')
parser.add_argument('--trial', default=10, type=int,
                    metavar='t', help='trial (only for RegDB dataset)')
parser.add_argument('--seed', default=0, type=int,
                    metavar='t', help='random seed')
parser.add_argument('--gpu', default='0', type=str,
                    help='gpu device ids for CUDA_VISIBLE_DEVICES')
parser.add_argument('--mode', default='all', type=str, help='all or indoor')

parser.add_argument('--share_net', default=2, type=int,
                    metavar='share', help='[1,2,3,4,5] the start number of shared network in the two-stream networks')
parser.add_argument('--re_rank', default='no', type=str, help='performing reranking. [random_walk | k_reciprocal | no]')
parser.add_argument('--pcb', default='on', type=str, help='performing PCB, on or off')
parser.add_argument('--w_center', default=2.0, type=float, help='the weight for center loss')

parser.add_argument('--local_feat_dim', default=256, type=int,
                    help='feature dimension of each local feature in PCB')
parser.add_argument('--num_strips', default=6, type=int,
                    help='num of local strips in PCB')

parser.add_argument('--label_smooth', default='on', type=str, help='performing label smooth or not')

args = parser.parse_args()
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

set_seed(args.seed)

dataset = args.dataset
if dataset == 'sysu':
    # TODO: define your data path
    data_path = "/home/wmj/reid/data/SYSU/"
    log_path = os.path.join(args.log_path, 'sysu_log_ddag/')
    test_mode = [1, 2]  # infrared to visible
elif dataset =='regdb':
    # TODO: define your data path for RegDB dataset
    data_path = "/home/wmj/reid/data/RegDB/"
    log_path = os.path.join(args.log_path, 'regdb_log_ddag/')
    test_mode = [2, 1]  # [2,1] visible to infrared

checkpoint_path = args.model_path

if not os.path.isdir(log_path):
    os.makedirs(log_path)
if not os.path.isdir(checkpoint_path):
    os.makedirs(checkpoint_path)
if not os.path.isdir(args.vis_log_path):
    os.makedirs(args.vis_log_path)

suffix = dataset+'_c_tri_pcb_{}_w_tri_{}'.format(args.pcb,args.w_center)
if args.pcb=='on':
    suffix = suffix + '_s{}_f{}'.format(args.num_strips, args.local_feat_dim)

suffix = suffix + '_share_net{}'.format(args.share_net)
if args.method=='agw':
    suffix = suffix + '_agw_k{}_p{}_lr_{}_seed_{}'.format(args.num_pos, args.batch_size, args.lr, args.seed)
else:
    suffix = suffix + '_base_gm10_k{}_p{}_lr_{}_seed_{}'.format(args.num_pos, args.batch_size, args.lr, args.seed)


if not args.optim == 'sgd':
    suffix = suffix + '_' + args.optim

if dataset == 'regdb':
    suffix = suffix + '_trial_{}'.format(args.trial)

sys.stdout = Logger(log_path + suffix + '_os.txt')

vis_log_dir = args.vis_log_path + suffix + '/'

if not os.path.isdir(vis_log_dir):
    os.makedirs(vis_log_dir)
writer = SummaryWriter(vis_log_dir)
print("==========\nArgs:{}\n==========".format(args))
device = 'cuda' if torch.cuda.is_available() else 'cpu'
best_acc = 0  # best test accuracy
start_epoch = 0

print('==> Loading data..')
# Data loading code
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
transform_train = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Pad(10),
    transforms.RandomCrop((args.img_h, args.img_w)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
])
transform_test = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((args.img_h, args.img_w)),
    transforms.ToTensor(),
    normalize,
])

end = time.time()
if dataset == 'sysu':
    # training set
    trainset = SYSUData(data_path, transform=transform_train)
    # generate the idx of each person identity
    color_pos, thermal_pos = GenIdx(trainset.train_color_label, trainset.train_thermal_label)

    # testing set
    query_img, query_label, query_cam = process_query_sysu(data_path, mode=args.mode)
    gall_img, gall_label, gall_cam = process_gallery_sysu(data_path, mode=args.mode, trial=0)

elif dataset == 'regdb':
    # training set
    trainset = RegDBData(data_path, args.trial, transform=transform_train)
    # generate the idx of each person identity
    color_pos, thermal_pos = GenIdx(trainset.train_color_label, trainset.train_thermal_label)

    # testing set
    query_img, query_label = process_test_regdb(data_path, trial=args.trial, modal='visible')
    gall_img, gall_label = process_test_regdb(data_path, trial=args.trial, modal='thermal')

gallset = TestData(gall_img, gall_label, transform=transform_test, img_size=(args.img_w, args.img_h))
queryset = TestData(query_img, query_label, transform=transform_test, img_size=(args.img_w, args.img_h))

# testing data loader
gall_loader = data.DataLoader(gallset, batch_size=args.test_batch, shuffle=False, num_workers=args.workers)
query_loader = data.DataLoader(queryset, batch_size=args.test_batch, shuffle=False, num_workers=args.workers)

n_class = len(np.unique(trainset.train_color_label))
nquery = len(query_label)
ngall = len(gall_label)

print('Dataset {} statistics:'.format(dataset))
print(' ------------------------------')
print(' subset | # ids | # images')
print(' ------------------------------')
print(' visible | {:5d} | {:8d}'.format(n_class, len(trainset.train_color_label)))
print(' thermal | {:5d} | {:8d}'.format(n_class, len(trainset.train_thermal_label)))
print(' ------------------------------')
print(' query | {:5d} | {:8d}'.format(len(np.unique(query_label)), nquery))
print(' gallery | {:5d} | {:8d}'.format(len(np.unique(gall_label)), ngall))
print(' ------------------------------')
print('Data Loading Time:\t {:.3f}'.format(time.time() - end))

print('==> Building model..')
if args.method =='base':
    net = embed_net(n_class, no_local= 'off', gm_pool = 'on', arch=args.arch, share_net=args.share_net, pcb=args.pcb, local_feat_dim=args.local_feat_dim, num_strips=args.num_strips)
else:
    net = embed_net(n_class, no_local= 'on', gm_pool = 'on', arch=args.arch, share_net=args.share_net, pcb=args.pcb)
net.to(device)


cudnn.benchmark = True

if len(args.resume) > 0:
    model_path = checkpoint_path + args.resume
    if os.path.isfile(model_path):
        print('==> loading checkpoint {}'.format(args.resume))
        checkpoint = torch.load(model_path)
        start_epoch = checkpoint['epoch']
        net.load_state_dict(checkpoint['net'])
        print('==> loaded checkpoint {} (epoch {})'
              .format(args.resume, checkpoint['epoch']))
    else:
        print('==> no checkpoint found at {}'.format(args.resume))

# define loss function
if args.label_smooth == 'off':
    criterion_id = nn.CrossEntropyLoss()
else:
    criterion_id = CrossEntropyLabelSmooth(n_class)

if args.method == 'agw':
    criterion_tri = TripletLoss_WRT()
else:
    loader_batch = args.batch_size * args.num_pos
    #criterion_tri= OriTripletLoss(batch_size=loader_batch, margin=args.margin)
    criterion_tri= CenterTripletLoss(batch_size=loader_batch, margin=args.margin)

criterion_id.to(device)
criterion_tri.to(device)



if args.optim == 'sgd':
    if args.pcb == 'on':
        ignored_params = list(map(id, net.local_conv_list.parameters())) \
                         + list(map(id, net.fc_list.parameters()))

        base_params = filter(lambda p: id(p) not in ignored_params, net.parameters())

        optimizer = optim.SGD([
            {'params': base_params, 'lr': 0.1 * args.lr},
            {'params': net.local_conv_list.parameters(), 'lr': args.lr},
            {'params': net.fc_list.parameters(), 'lr': args.lr}
            ],
            weight_decay=5e-4, momentum=0.9, nesterov=True)
    else:
        ignored_params = list(map(id, net.bottleneck.parameters())) \
                         + list(map(id, net.classifier.parameters()))

        base_params = filter(lambda p: id(p) not in ignored_params, net.parameters())

        optimizer = optim.SGD([
            {'params': base_params, 'lr': 0.1 * args.lr},
            {'params': net.bottleneck.parameters(), 'lr': args.lr},
            {'params': net.classifier.parameters(), 'lr': args.lr}],
            weight_decay=5e-4, momentum=0.9, nesterov=True)

def adjust_learning_rate(optimizer, epoch):
    """Linear warm-up over the first 10 epochs, full LR until epoch 20,
    then 0.1x until epoch 50 and 0.01x afterwards.
    The first parameter group (backbone) always runs at 0.1x the returned LR."""
    if epoch < 10:
        lr = args.lr * (epoch + 1) / 10
    elif epoch >= 10 and epoch < 20:
        lr = args.lr
    elif epoch >= 20 and epoch < 50:
        lr = args.lr * 0.1
    elif epoch >= 50:
        lr = args.lr * 0.01

    optimizer.param_groups[0]['lr'] = 0.1 * lr
    for i in range(len(optimizer.param_groups) - 1):
        optimizer.param_groups[i + 1]['lr'] = lr

    return lr


def train(epoch):

    current_lr = adjust_learning_rate(optimizer, epoch)
    train_loss = AverageMeter()
    id_loss = AverageMeter()
    tri_loss = AverageMeter()
    data_time = AverageMeter()
    batch_time = AverageMeter()
    correct = 0
    total = 0

    # switch to train mode
    net.train()
    end = time.time()

    for batch_idx, (input1, input2, label1, label2) in enumerate(trainloader):

        labels = torch.cat((label1, label2), 0)

        input1 = Variable(input1.cuda())
        input2 = Variable(input2.cuda())

        labels = Variable(labels.cuda())
        data_time.update(time.time() - end)


        if args.pcb == 'on':
            feat, out0, feat_all, feat_mem, out_mem = net(input1, input2)
            loss_id = criterion_id(out0[0], labels.long())
            loss_tri_l, batch_acc = criterion_tri(feat[0], labels)
            for i in range(len(feat)-1):
                loss_id += criterion_id(out0[i+1], labels.long())
                loss_tri_l += criterion_tri(feat[i+1], labels)[0]
            loss_tri, batch_acc = criterion_tri(feat_all, labels)
            loss_tri += loss_tri_l * args.w_center #
            correct += batch_acc
            ### for mem branch
            loss_tri_mem, batch_acc = criterion_tri(feat_mem, labels.long())
            loss_id_mem = criterion_id(out_mem, labels.long())
            loss = loss_id + loss_tri + loss_id_mem*0.1 + loss_tri_mem
        else:
            feat, out0 = net(input1, input2)
            loss_id = criterion_id(out0, labels)

            loss_tri, batch_acc = criterion_tri(feat, labels)
            correct += (batch_acc / 2)
            _, predicted = out0.max(1)
            correct += (predicted.eq(labels).sum().item() / 2)
            loss = loss_id + loss_tri * args.w_center #


        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # update P
        train_loss.update(loss.item(), 2 * input1.size(0))
        id_loss.update(loss_id.item(), 2 * input1.size(0))
        # .item() stores a plain float so the meter does not keep the autograd graph alive
        tri_loss.update(loss_tri.item(), 2 * input1.size(0))
        total += labels.size(0)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if batch_idx % 50 == 0:
            print('Epoch: [{}][{}/{}] '
                  'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                  'lr:{:.3f} '
                  'Loss: {train_loss.val:.4f} ({train_loss.avg:.4f}) '
                  'iLoss: {id_loss.val:.4f} ({id_loss.avg:.4f}) '
                  'TLoss: {tri_loss.val:.4f} ({tri_loss.avg:.4f}) '
                  'Accu: {:.2f}'.format(
                epoch, batch_idx, len(trainloader), current_lr,
                100. * correct / total, batch_time=batch_time,
                train_loss=train_loss, id_loss=id_loss,tri_loss=tri_loss))

    writer.add_scalar('total_loss', train_loss.avg, epoch)
    writer.add_scalar('id_loss', id_loss.avg, epoch)
    writer.add_scalar('tri_loss', tri_loss.avg, epoch)
    writer.add_scalar('lr', current_lr, epoch)


def test(epoch):
    # switch to evaluation mode
    net.eval()
    print('Extracting Gallery Feature...')
    start = time.time()
    ptr = 0
    if args.pcb == 'on':
        feat_dim = args.num_strips * args.local_feat_dim
    else:
        feat_dim = 2048
    gall_feat = np.zeros((ngall, feat_dim))
    gall_feat_att = np.zeros((ngall, feat_dim))
    with torch.no_grad():
        for batch_idx, (input, label) in enumerate(gall_loader):
            batch_num = input.size(0)
            input = Variable(input.cuda())
            if args.pcb == 'on':
                feat = net(input, input, test_mode[0])
                gall_feat[ptr:ptr + batch_num, :] = feat.detach().cpu().numpy()
            else:
                feat, feat_att = net(input, input, test_mode[0])
                gall_feat[ptr:ptr + batch_num, :] = feat.detach().cpu().numpy()
                gall_feat_att[ptr:ptr + batch_num, :] = feat_att.detach().cpu().numpy()
            ptr = ptr + batch_num
    print('Extracting Time:\t {:.3f}'.format(time.time() - start))

    # switch to evaluation
    net.eval()
    print('Extracting Query Feature...')
    start = time.time()
    ptr = 0

    query_feat = np.zeros((nquery, feat_dim))
    query_feat_att = np.zeros((nquery, feat_dim))
    with torch.no_grad():
        for batch_idx, (input, label) in enumerate(query_loader):
            batch_num = input.size(0)
            input = Variable(input.cuda())
            if args.pcb == 'on':
                feat = net(input, input, test_mode[1])
                query_feat[ptr:ptr + batch_num, :] = feat.detach().cpu().numpy()
            else:
                feat, feat_att = net(input, input, test_mode[1])
                query_feat[ptr:ptr + batch_num, :] = feat.detach().cpu().numpy()
                query_feat_att[ptr:ptr + batch_num, :] = feat_att.detach().cpu().numpy()
            ptr = ptr + batch_num
    print('Extracting Time:\t {:.3f}'.format(time.time() - start))

    start = time.time()


    if args.re_rank == 'random_walk':
        distmat = random_walk(query_feat, gall_feat)
        if args.pcb == 'off': distmat_att = random_walk(query_feat_att, gall_feat_att)
    elif args.re_rank == 'k_reciprocal':
        distmat = k_reciprocal(query_feat, gall_feat)
        if args.pcb == 'off': distmat_att = k_reciprocal(query_feat_att, gall_feat_att)
    elif args.re_rank == 'no':
        # compute the similarity
        distmat = -np.matmul(query_feat, np.transpose(gall_feat))
        if args.pcb == 'off': distmat_att = -np.matmul(query_feat_att, np.transpose(gall_feat_att))

    # evaluation
    if dataset == 'regdb':
        cmc, mAP, mINP = eval_regdb(distmat, query_label, gall_label)
        if args.pcb == 'off': cmc_att, mAP_att, mINP_att = eval_regdb(distmat_att, query_label, gall_label)
    elif dataset == 'sysu':
        cmc, mAP, mINP = eval_sysu(distmat, query_label, gall_label, query_cam, gall_cam)
        if args.pcb == 'off': cmc_att, mAP_att, mINP_att = eval_sysu(distmat_att, query_label, gall_label, query_cam, gall_cam)
    print('Evaluation Time:\t {:.3f}'.format(time.time() - start))

    writer.add_scalar('rank1', cmc[0], epoch)
    writer.add_scalar('mAP', mAP, epoch)
    writer.add_scalar('mINP', mINP, epoch)
    if args.pcb == 'off':
        writer.add_scalar('rank1_att', cmc_att[0], epoch)
        writer.add_scalar('mAP_att', mAP_att, epoch)
        writer.add_scalar('mINP_att', mINP_att, epoch)

        return cmc, mAP, mINP, cmc_att, mAP_att, mINP_att
    else:
        return cmc, mAP, mINP



# training
print('==> Start Training...')
for epoch in range(start_epoch, 61 - start_epoch):

    print('==> Preparing Data Loader...')
    # identity sampler
    sampler = IdentitySampler(trainset.train_color_label,
                              trainset.train_thermal_label, color_pos, thermal_pos, args.num_pos, args.batch_size,
                              epoch)

    trainset.cIndex = sampler.index1  # color index
    trainset.tIndex = sampler.index2  # thermal index
    print(epoch)

    loader_batch = args.batch_size * args.num_pos

    trainloader = data.DataLoader(trainset, batch_size=loader_batch,
                                  sampler=sampler, num_workers=args.workers, drop_last=True)

    # training
    train(epoch)

    if epoch > 9 and epoch % 2 == 0:
        print('Test Epoch: {}'.format(epoch))

        # testing
        if args.pcb == 'off':
            cmc, mAP, mINP, cmc_fc, mAP_fc, mINP_fc = test(epoch)
        else:
            cmc_fc, mAP_fc, mINP_fc = test(epoch)
        # save model
        if cmc_fc[0] > best_acc:  # not the real best for sysu-mm01
            best_acc = cmc_fc[0]
            best_epoch = epoch
            best_mAP = mAP_fc
            best_mINP = mINP_fc
            state = {
                'net': net.state_dict(),
                'cmc': cmc_fc,
                'mAP': mAP_fc,
                'mINP': mINP_fc,
                'epoch': epoch,
            }
            torch.save(state, checkpoint_path + suffix + '_best.t')

        if args.pcb == 'off':
            print('POOL: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format(
                cmc[0], cmc[4], cmc[9], cmc[19], mAP, mINP))

        print('FC: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format(
            cmc_fc[0], cmc_fc[4], cmc_fc[9], cmc_fc[19], mAP_fc, mINP_fc))
        print('Best Epoch [{}], Rank-1: {:.2%} | mAP: {:.2%}| mINP: {:.2%}'.format(best_epoch, best_acc, best_mAP, best_mINP))


--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import os
import numpy as np
from torch.utils.data.sampler import Sampler
import sys
import errno  # needed by mkdir_if_missing (errno.EEXIST)
import os.path as osp
import torch

def load_data(input_data_path ):
    # read the list file once; each line is "<image path> <label>"
    with open(input_data_path, 'rt') as f:
        data_file_list = f.read().splitlines()
    # Get full list of color image and labels
    file_image = [s.split(' ')[0] for s in data_file_list]
    file_label = [int(s.split(' ')[1]) for s in data_file_list]

    return file_image, file_label


def GenIdx( train_color_label, train_thermal_label):
    color_pos = []
    unique_label_color = np.unique(train_color_label)
    for i in range(len(unique_label_color)):
        tmp_pos = [k for k,v in enumerate(train_color_label) if v==unique_label_color[i]]
        color_pos.append(tmp_pos)

    thermal_pos = []
    unique_label_thermal = np.unique(train_thermal_label)
    for i in range(len(unique_label_thermal)):
        tmp_pos = [k for k,v in enumerate(train_thermal_label) if v==unique_label_thermal[i]]
        thermal_pos.append(tmp_pos)
    return color_pos, thermal_pos

def GenCamIdx(gall_img, gall_label, mode):
    if mode =='indoor':
        camIdx = [1,2]
    else:
        camIdx = [1,2,4,5]
    gall_cam = []
    for i in range(len(gall_img)):
        gall_cam.append(int(gall_img[i][-10]))

    sample_pos = []
    unique_label = np.unique(gall_label)
    for i in range(len(unique_label)):
        for j in range(len(camIdx)):
            id_pos = [k for k,v in enumerate(gall_label) if v==unique_label[i] and gall_cam[k]==camIdx[j]]
            if id_pos:
                sample_pos.append(id_pos)
    return sample_pos

def ExtractCam(gall_img):
    gall_cam = []
    for i in range(len(gall_img)):
        cam_id = int(gall_img[i][-10])
        # if cam_id ==3:
            # cam_id = 2
        gall_cam.append(cam_id)

    return np.array(gall_cam)

class IdentitySampler(Sampler):
    """Sample person identities evenly in each batch.
        Args:
            train_color_label, train_thermal_label: labels of two modalities
            color_pos, thermal_pos: positions of each identity
            batchSize: batch size
    """

    def __init__(self, train_color_label, train_thermal_label, color_pos, thermal_pos, num_pos, batchSize, epoch):
        uni_label = np.unique(train_color_label)
        self.n_classes = len(uni_label)


        N = np.maximum(len(train_color_label), len(train_thermal_label))
        for j in range(int(N/(batchSize*num_pos))+1):
            batch_idx = np.random.choice(uni_label, batchSize, replace = False)
            for i in range(batchSize):
                sample_color = np.random.choice(color_pos[batch_idx[i]], num_pos)
                sample_thermal = np.random.choice(thermal_pos[batch_idx[i]], num_pos)

                if j ==0 and i==0:
                    index1= sample_color
                    index2= sample_thermal
                else:
                    index1 = np.hstack((index1, sample_color))
                    index2 = np.hstack((index2, sample_thermal))

        self.index1 = index1
        self.index2 = index2
        self.N = N

    def __iter__(self):
        return iter(np.arange(len(self.index1)))

    def __len__(self):
        return self.N

class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def mkdir_if_missing(directory):
    if not osp.exists(directory):
        try:
            os.makedirs(directory)
        except OSError as e:
            if e.errno != errno.EEXIST:
                raise
class Logger(object):
    """
    Write console output to external text file.
    Code imported from https://github.com/Cysu/open-reid/blob/master/reid/utils/logging.py.
    """
    def __init__(self, fpath=None):
        self.console = sys.stdout
        self.file = None
        if fpath is not None:
            mkdir_if_missing(osp.dirname(fpath))
            self.file = open(fpath, 'w')

    def __del__(self):
        self.close()

    def __enter__(self):
        pass

    def __exit__(self, *args):
        self.close()

    def write(self, msg):
        self.console.write(msg)
        if self.file is not None:
            self.file.write(msg)

    def flush(self):
        self.console.flush()
        if self.file is not None:
            self.file.flush()
            os.fsync(self.file.fileno())

    def close(self):
        self.console.close()
        if self.file is not None:
            self.file.close()

def set_seed(seed, cuda=True):
    np.random.seed(seed)
    torch.manual_seed(seed)
    if cuda:
        torch.cuda.manual_seed(seed)

def set_requires_grad(nets, requires_grad=False):
    """Set requires_grad=False for all the networks to avoid unnecessary computations
    Parameters:
        nets (network list)   -- a list of networks
        requires_grad (bool)  -- whether the networks require gradients or not
    """
    if not isinstance(nets, list):
        nets = [nets]
    for net in nets:
        if net is not None:
            for param in net.parameters():
                param.requires_grad = requires_grad
--------------------------------------------------------------------------------
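The batch layout that `IdentitySampler` in utils.py builds can be illustrated with a small standard-library sketch. The helper `sample_identity_batch` and the toy position lists below are hypothetical and not part of the repository. The sketch mirrors one iteration of the sampler's outer loop under the assumption that identity labels are the integers 0..n-1 (the original indexes `color_pos` directly by label): identities are drawn without replacement, then `num_pos` image indices per identity are drawn with replacement in each modality.

```python
import random


def sample_identity_batch(color_pos, thermal_pos, num_ids, num_pos, rng):
    """Draw one batch: `num_ids` identities, `num_pos` images per identity and modality."""
    # Identities are drawn without replacement, like np.random.choice(..., replace=False).
    ids = rng.sample(range(len(color_pos)), num_ids)
    color_idx, thermal_idx = [], []
    for pid in ids:
        # Images are drawn with replacement, so an identity with fewer than
        # `num_pos` images still fills all of its slots.
        color_idx += rng.choices(color_pos[pid], k=num_pos)
        thermal_idx += rng.choices(thermal_pos[pid], k=num_pos)
    return ids, color_idx, thermal_idx


# Toy position lists: entry i holds the image indices of identity i (as GenIdx would return).
color_pos = [[0, 1], [2, 3, 4], [5]]
thermal_pos = [[0], [1, 2], [3, 4, 5]]

ids, color_idx, thermal_idx = sample_identity_batch(
    color_pos, thermal_pos, num_ids=2, num_pos=4, rng=random.Random(0))
print(ids, color_idx, thermal_idx)
```

Each consecutive group of `num_pos` entries in `color_idx` and `thermal_idx` belongs to the same identity, which is why the training scripts set the loader batch size to `args.batch_size * args.num_pos` and concatenate the two modalities' labels.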