├── LICENSE ├── README.md ├── Transform.py ├── data_loader.py ├── data_manager.py ├── eval_metrics.py ├── heterogeneity_loss.py ├── model.py ├── pre_process_sysu.py ├── test.py ├── train.py └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 98zyx 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Hetero-center-loss-for-cross-modality-person-re-id 2 | Code for paper "Hetero-center loss for cross-modality person re-identification" 3 | 4 | ## Update: 5 | 2020-06-03: 6 | Because pytorch split the data of a batch to each gpu when using multi-gpus, the loss may be incorrectly computed. The users may need to use single gpu to reproduce the experimental results in the paper. 7 | 8 | 2020-01-07: 9 | we fix up a bug in learning rate schedule, before that only the first three parameter group's learning rate will be correctly decay to 1/10. However, after fixing up the bug, the model's performance still stay the same. The updated model and code have been upload. 10 | 11 | ## Requirments: 12 | **pytorch: 0.4.1(the higher version may lead to performance fluctuation)** 13 | 14 | torchvision: 0.2.1 15 | 16 | numpy: 1.17.4 17 | 18 | python: 3.7 19 | 20 | 21 | ## Dataset: 22 | **SYSU-MM01** 23 | 24 | **Reg-DB** 25 | 26 | 27 | ## Run: 28 | ### SYSU-MM01: 29 | 1. prepare training set 30 | ``` 31 | python pre_process_sysu.py 32 | ``` 33 | 2. train model 34 | ``` 35 | python train.py --dataset sysu --lr 0.01 --drop 0.0 --trial 1 --gpu 1 --epochs 60 --w_hc 0.5 --per_img 8 36 | ``` 37 | * (Notice that you need to set the 88 line in train.py to your SYSU-MM01 dataset path) 38 | 39 | 3. evaluate model(single-shot all-search) 40 | ``` 41 | python test.py --dataset sysu --lr 0.01 --drop 0.0 --trial 1 --gpu 1 --low-dim 512 --resume 'Your model name' --w_hc 0.5 --mode all --gall-mode single --model_path 'Your model path' 42 | ``` 43 | 44 | ### Reg-DB: 45 | 1. train model 46 | ``` 47 | python train.py --dataset regdb --lr 0.01 --drop 0.0 --trial 1 --gpu 1 --epochs 60 --w_hc 0.5 --per_img 8 48 | ``` 49 | 50 | 2. evaluate model 51 | ``` 52 | python test.py --dataset regdb --lr 0.01 --drop 0.0 --trial 1 --gpu 1 --low-dim 512 --resume 'Your model name' --w_hc 0.5 --model_path 'Your model path' 53 | ``` 54 | 55 | ## Results: 56 | Dataset | Rank1 | mAP | model 57 | ---- | ----- | ------ | ----- 58 | SYSU-MM01 | ~56% | ~54% | [BaiduYun(code:y2em)](https://pan.baidu.com/s/1Ty1WCBVUZvzGk-cQLK432w) 59 | RegDB | ~83% | ~72% | [BaiduYun(code:y2em)](https://pan.baidu.com/s/1Ty1WCBVUZvzGk-cQLK432w) 60 | 61 | ## Tips: 62 | Because this is the first time I use Github to release my code, maybe this project is a little difficult to read and use. If you have any question, please don't hesitate to contact me (zhuyuanxin98@outlook.com). I will reply to you as soon as possible. 63 | 64 | Most of the code are borrowed from https://github.com/mangye16/Cross-Modal-Re-ID-baseline. I am very grateful to the author (@[mangye16](https://github.com/mangye16)) for his contribution and help. 65 | 66 | **If you think this project useful, please give me a star and cite following papers:** 67 | 68 | [1] Zhu Y, Yang Z, Wang L, et al. Hetero-Center Loss for Cross-Modality Person Re-Identification[J]. Neurocomputing, 2019. 69 | 70 | [2] Ye M, Lan X, Wang Z, et al. Bi-directional Center-Constrained Top-Ranking for Visible Thermal Person Re-Identification[J]. IEEE Transactions on Information Forensics and Security, 2019. 71 | 72 | -------------------------------------------------------------------------------- /Transform.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from torchvision.transforms import * 4 | from PIL import Image 5 | import random 6 | import math 7 | 8 | 9 | class RectScale(object): 10 | def __init__(self, height, width, interpolation=Image.BILINEAR): 11 | self.height = height 12 | self.width = width 13 | self.interpolation = interpolation 14 | 15 | def __call__(self, img): 16 | w, h = img.size 17 | if h == self.height and w == self.width: 18 | return img 19 | return img.resize((self.width, self.height), self.interpolation) -------------------------------------------------------------------------------- /data_loader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image, ImageChops 3 | from torchvision import transforms 4 | import random 5 | import torch 6 | import torchvision.datasets as datasets 7 | import torch.utils.data as data 8 | 9 | 10 | class SYSUData(data.Dataset): 11 | def __init__(self, data_dir, transform=None, colorIndex = None, thermalIndex = None): 12 | 13 | # Load training images (path) and labels 14 | train_color_image = np.load(data_dir + 'train_rgb_resized_img.npy') 15 | self.train_color_label = np.load(data_dir + 'train_rgb_resized_label.npy') 16 | 17 | train_thermal_image = np.load(data_dir + 'train_ir_resized_img.npy') 18 | self.train_thermal_label = np.load(data_dir + 'train_ir_resized_label.npy') 19 | 20 | # RGB format 21 | self.train_color_image = train_color_image 22 | self.train_thermal_image = train_thermal_image 23 | self.transform = transform 24 | self.cIndex = colorIndex 25 | self.tIndex = thermalIndex 26 | 27 | def __getitem__(self, index): 28 | 29 | img1, target1 = self.train_color_image[self.cIndex[index]], self.train_color_label[self.cIndex[index]] 30 | img2, target2 = self.train_thermal_image[self.tIndex[index]], self.train_thermal_label[self.tIndex[index]] 31 | 32 | img1 = self.transform(img1) 33 | img2 = self.transform(img2) 34 | 35 | return img1, img2, target1, target2 36 | 37 | def __len__(self): 38 | return len(self.train_color_label) 39 | 40 | 41 | class RegDBData(data.Dataset): 42 | def __init__(self, data_dir, trial, transform=None, colorIndex = None, thermalIndex = None): 43 | # Load training images (path) and labels 44 | train_color_list = data_dir + 'idx/train_visible_{}'.format(trial)+ '.txt' 45 | train_thermal_list = data_dir + 'idx/train_thermal_{}'.format(trial)+ '.txt' 46 | 47 | color_img_file, train_color_label = load_data(train_color_list) 48 | thermal_img_file, train_thermal_label = load_data(train_thermal_list) 49 | 50 | train_color_image = [] 51 | for i in range(len(color_img_file)): 52 | img = Image.open(data_dir+ color_img_file[i]) 53 | img = img.resize((144, 288), Image.ANTIALIAS) 54 | pix_array = np.array(img) 55 | train_color_image.append(pix_array) 56 | train_color_image = np.array(train_color_image) 57 | 58 | train_thermal_image = [] 59 | for i in range(len(thermal_img_file)): 60 | img = Image.open(data_dir+ thermal_img_file[i]) 61 | img = img.resize((144, 288), Image.ANTIALIAS) 62 | pix_array = np.array(img) 63 | train_thermal_image.append(pix_array) 64 | train_thermal_image = np.array(train_thermal_image) 65 | 66 | # RGB format 67 | self.train_color_image = train_color_image 68 | self.train_color_label = train_color_label 69 | 70 | # RGB format 71 | self.train_thermal_image = train_thermal_image 72 | self.train_thermal_label = train_thermal_label 73 | 74 | self.transform = transform 75 | self.cIndex = colorIndex 76 | self.tIndex = thermalIndex 77 | 78 | def __getitem__(self, index): 79 | 80 | img1, target1 = self.train_color_image[self.cIndex[index]], self.train_color_label[self.cIndex[index]] 81 | img2, target2 = self.train_thermal_image[self.tIndex[index]], self.train_thermal_label[self.tIndex[index]] 82 | 83 | img1 = self.transform(img1) 84 | img2 = self.transform(img2) 85 | 86 | return img1, img2, target1, target2 87 | 88 | def __len__(self): 89 | return len(self.train_color_label) 90 | 91 | class TestData(data.Dataset): 92 | def __init__(self, test_img_file, test_label, transform=None, img_size = (224,224)): 93 | test_image = [] 94 | for i in range(len(test_img_file)): 95 | img = Image.open(test_img_file[i]) 96 | img = img.resize((img_size[0], img_size[1]), Image.ANTIALIAS) 97 | pix_array = np.array(img) 98 | test_image.append(pix_array) 99 | test_image = np.array(test_image) 100 | self.test_image = test_image 101 | self.test_label = test_label 102 | self.transform = transform 103 | 104 | def __getitem__(self, index): 105 | img1, target1 = self.test_image[index], self.test_label[index] 106 | img1 = self.transform(img1) 107 | return img1, target1 108 | 109 | def __len__(self): 110 | return len(self.test_image) 111 | 112 | def load_data(input_data_path ): 113 | with open(input_data_path) as f: 114 | data_file_list = open(input_data_path, 'rt').read().splitlines() 115 | # Get full list of image and labels 116 | file_image = [s.split(' ')[0] for s in data_file_list] 117 | file_label = [int(s.split(' ')[1]) for s in data_file_list] 118 | 119 | return file_image, file_label 120 | -------------------------------------------------------------------------------- /data_manager.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os 3 | import sys 4 | import numpy as np 5 | import random 6 | 7 | def process_query_sysu(data_path, mode = 'all', relabel=False): 8 | if mode== 'all': 9 | ir_cameras = ['cam3','cam6'] 10 | elif mode =='indoor': 11 | ir_cameras = ['cam3','cam6'] 12 | 13 | file_path = os.path.join(data_path,'exp/test_id.txt') 14 | files_rgb = [] 15 | files_ir = [] 16 | 17 | with open(file_path, 'r') as file: 18 | ids = file.read().splitlines() 19 | ids = [int(y) for y in ids[0].split(',')] 20 | ids = ["%04d" % x for x in ids] 21 | 22 | for id in sorted(ids): 23 | for cam in ir_cameras: 24 | img_dir = os.path.join(data_path,cam,id) 25 | if os.path.isdir(img_dir): 26 | new_files = sorted([img_dir+'/'+i for i in os.listdir(img_dir)]) 27 | files_ir.extend(new_files) 28 | query_img = [] 29 | query_id = [] 30 | query_cam = [] 31 | for img_path in files_ir: 32 | camid, pid = int(img_path[-15]), int(img_path[-13:-9]) 33 | query_img.append(img_path) 34 | query_id.append(pid) 35 | query_cam.append(camid) 36 | return query_img, np.array(query_id), np.array(query_cam) 37 | 38 | def process_gallery_sysu(data_path, mode = 'all', trial = 0, relabel=False, gall_mode='single'): 39 | 40 | # random.seed(trial) 41 | #gall_mode = 'single' 42 | if mode== 'all': 43 | rgb_cameras = ['cam1','cam2','cam4','cam5'] 44 | elif mode =='indoor': 45 | rgb_cameras = ['cam1','cam2'] 46 | 47 | file_path = os.path.join(data_path,'exp/test_id.txt') 48 | files_rgb = [] 49 | with open(file_path, 'r') as file: 50 | ids = file.read().splitlines() 51 | ids = [int(y) for y in ids[0].split(',')] 52 | ids = ["%04d" % x for x in ids] 53 | #np.random.seed(1) 54 | for id in sorted(ids): 55 | for cam in rgb_cameras: 56 | img_dir = os.path.join(data_path,cam,id) 57 | if os.path.isdir(img_dir): 58 | new_files = sorted([img_dir+'/'+i for i in os.listdir(img_dir)]) 59 | if gall_mode == 'single': 60 | files_rgb.append(random.choice(new_files)) 61 | if gall_mode == 'multi': 62 | files_rgb.append(np.random.choice(new_files, 10, replace=False)) 63 | gall_img = [] 64 | gall_id = [] 65 | gall_cam = [] 66 | for img_path in files_rgb: 67 | if gall_mode == 'single': 68 | camid, pid = int(img_path[-15]), int(img_path[-13:-9]) 69 | gall_img.append(img_path) 70 | gall_id.append(pid) 71 | gall_cam.append(camid) 72 | if gall_mode == 'multi': 73 | for i in img_path: 74 | camid, pid = int(i[-15]), int(i[-13:-9]) 75 | gall_img.append(i) 76 | gall_id.append(pid) 77 | gall_cam.append(camid) 78 | return gall_img, np.array(gall_id), np.array(gall_cam) 79 | 80 | def process_test_regdb(img_dir, trial = 1, modal = 'visible'): 81 | if modal=='visible': 82 | input_data_path = img_dir + 'idx/test_visible_{}'.format(trial) + '.txt' 83 | elif modal=='thermal': 84 | input_data_path = img_dir + 'idx/test_thermal_{}'.format(trial) + '.txt' 85 | 86 | with open(input_data_path) as f: 87 | data_file_list = open(input_data_path, 'rt').read().splitlines() 88 | # Get full list of image and labels 89 | file_image = [img_dir + '/' + s.split(' ')[0] for s in data_file_list] 90 | file_label = [int(s.split(' ')[1]) for s in data_file_list] 91 | 92 | return file_image, np.array(file_label) -------------------------------------------------------------------------------- /eval_metrics.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os 3 | import glob 4 | import re 5 | import sys 6 | import os.path as osp 7 | import numpy as np 8 | 9 | import random 10 | from time import time 11 | 12 | """Cross-Modality ReID""" 13 | 14 | def eval_sysu(distmat, q_pids, g_pids, q_camids, g_camids, max_rank = 20): 15 | """Evaluation with sysu metric 16 | Key: for each query identity, its gallery images from the same camera view are discarded. 17 | """ 18 | num_q, num_g = distmat.shape 19 | if num_g < max_rank: 20 | max_rank = num_g 21 | print("Note: number of gallery samples is quite small, got {}".format(num_g)) 22 | indices = np.argsort(distmat, axis=1) 23 | pred_label = g_pids[indices] 24 | matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32) 25 | 26 | # compute cmc curve for each query 27 | new_all_cmc = [] 28 | all_cmc = [] 29 | all_AP = [] 30 | num_valid_q = 0. # number of valid query 31 | for q_idx in range(num_q): 32 | # get query pid and camid 33 | q_pid = q_pids[q_idx] 34 | q_camid = q_camids[q_idx] 35 | 36 | # remove gallery samples that have the same pid and camid with query 37 | order = indices[q_idx] 38 | remove = (q_camid == 3) & (g_camids[order] == 2) 39 | keep = np.invert(remove) 40 | 41 | # compute cmc curve 42 | # the cmc calculation is different from standard protocol 43 | # we follow the protocol of the author's released code 44 | new_cmc = pred_label[q_idx][keep] 45 | new_index = np.unique(new_cmc, return_index=True)[1] 46 | new_cmc = [new_cmc[index] for index in sorted(new_index)] 47 | 48 | new_match = (new_cmc == q_pid).astype(np.int32) 49 | new_cmc = new_match.cumsum() 50 | new_all_cmc.append(new_cmc[:max_rank]) 51 | 52 | orig_cmc = matches[q_idx][keep] # binary vector, positions with value 1 are correct matches 53 | if not np.any(orig_cmc): 54 | # this condition is true when query identity does not appear in gallery 55 | continue 56 | 57 | cmc = orig_cmc.cumsum() 58 | cmc[cmc > 1] = 1 59 | 60 | all_cmc.append(cmc[:max_rank]) 61 | num_valid_q += 1. 62 | 63 | # compute average precision 64 | # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision 65 | num_rel = orig_cmc.sum() 66 | tmp_cmc = orig_cmc.cumsum() 67 | tmp_cmc = [x / (i+1.) for i, x in enumerate(tmp_cmc)] 68 | tmp_cmc = np.asarray(tmp_cmc) * orig_cmc 69 | AP = tmp_cmc.sum() / num_rel 70 | all_AP.append(AP) 71 | 72 | assert num_valid_q > 0, "Error: all query identities do not appear in gallery" 73 | 74 | all_cmc = np.asarray(all_cmc).astype(np.float32) 75 | all_cmc = all_cmc.sum(0) / num_valid_q 76 | 77 | new_all_cmc = np.asarray(new_all_cmc).astype(np.float32) 78 | new_all_cmc = new_all_cmc.sum(0) / num_valid_q 79 | mAP = np.mean(all_AP) 80 | 81 | return new_all_cmc, mAP 82 | 83 | 84 | 85 | def eval_regdb(distmat, q_pids, g_pids, max_rank = 20): 86 | num_q, num_g = distmat.shape 87 | if num_g < max_rank: 88 | max_rank = num_g 89 | print("Note: number of gallery samples is quite small, got {}".format(num_g)) 90 | indices = np.argsort(distmat, axis=1) 91 | matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32) 92 | 93 | # compute cmc curve for each query 94 | all_cmc = [] 95 | all_AP = [] 96 | num_valid_q = 0. # number of valid query 97 | 98 | # only two cameras 99 | q_camids = np.ones(num_q).astype(np.int32) 100 | g_camids = 2* np.ones(num_g).astype(np.int32) 101 | 102 | for q_idx in range(num_q): 103 | # get query pid and camid 104 | q_pid = q_pids[q_idx] 105 | q_camid = q_camids[q_idx] 106 | 107 | # remove gallery samples that have the same pid and camid with query 108 | order = indices[q_idx] 109 | remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid) 110 | keep = np.invert(remove) 111 | 112 | # compute cmc curve 113 | raw_cmc = matches[q_idx][keep] # binary vector, positions with value 1 are correct matches 114 | if not np.any(raw_cmc): 115 | # this condition is true when query identity does not appear in gallery 116 | continue 117 | 118 | cmc = raw_cmc.cumsum() 119 | cmc[cmc > 1] = 1 120 | 121 | all_cmc.append(cmc[:max_rank]) 122 | num_valid_q += 1. 123 | 124 | # compute average precision 125 | # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision 126 | num_rel = raw_cmc.sum() 127 | tmp_cmc = raw_cmc.cumsum() 128 | tmp_cmc = [x / (i+1.) for i, x in enumerate(tmp_cmc)] 129 | tmp_cmc = np.asarray(tmp_cmc) * raw_cmc 130 | AP = tmp_cmc.sum() / num_rel 131 | all_AP.append(AP) 132 | 133 | assert num_valid_q > 0, "Error: all query identities do not appear in gallery" 134 | 135 | all_cmc = np.asarray(all_cmc).astype(np.float32) 136 | all_cmc = all_cmc.sum(0) / num_valid_q 137 | mAP = np.mean(all_AP) 138 | 139 | return all_cmc, mAP -------------------------------------------------------------------------------- /heterogeneity_loss.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch import nn, tensor 3 | import torch 4 | from torch.autograd import Variable 5 | 6 | class hetero_loss(nn.Module): 7 | def __init__(self, margin=0.1, dist_type = 'l2'): 8 | super(hetero_loss, self).__init__() 9 | self.margin = margin 10 | self.dist_type = dist_type 11 | if dist_type == 'l2': 12 | self.dist = nn.MSELoss(reduction='sum') 13 | if dist_type == 'cos': 14 | self.dist = nn.CosineSimilarity(dim=0) 15 | if dist_type == 'l1': 16 | self.dist = nn.L1Loss() 17 | 18 | def forward(self, feat1, feat2, label1, label2): 19 | feat_size = feat1.size()[1] 20 | feat_num = feat1.size()[0] 21 | label_num = len(label1.unique()) 22 | feat1 = feat1.chunk(label_num, 0) 23 | feat2 = feat2.chunk(label_num, 0) 24 | #loss = Variable(.cuda()) 25 | for i in range(label_num): 26 | center1 = torch.mean(feat1[i], dim=0) 27 | center2 = torch.mean(feat2[i], dim=0) 28 | if self.dist_type == 'l2' or self.dist_type == 'l1': 29 | if i == 0: 30 | dist = max(0, self.dist(center1, center2) - self.margin) 31 | else: 32 | dist += max(0, self.dist(center1, center2) - self.margin) 33 | elif self.dist_type == 'cos': 34 | if i == 0: 35 | dist = max(0, 1-self.dist(center1, center2) - self.margin) 36 | else: 37 | dist += max(0, 1-self.dist(center1, center2) - self.margin) 38 | 39 | return dist 40 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import init 4 | from torchvision import models 5 | from torch.autograd import Variable 6 | 7 | class Normalize(nn.Module): 8 | def __init__(self, power=2): 9 | super(Normalize, self).__init__() 10 | self.power = power 11 | 12 | def forward(self, x): 13 | norm = x.pow(self.power).sum(1, keepdim=True).pow(1./self.power) 14 | out = x.div(norm) 15 | return out 16 | 17 | # ##################################################################### 18 | def weights_init_kaiming(m): 19 | classname = m.__class__.__name__ 20 | # print(classname) 21 | if classname.find('Conv') != -1: 22 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 23 | elif classname.find('Linear') != -1: 24 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_out') 25 | # init.normal_(m.weight.data, 0, 0.001) 26 | init.zeros_(m.bias.data) 27 | elif classname.find('BatchNorm1d') != -1: 28 | init.normal_(m.weight.data, 1.0, 0.01) 29 | init.zeros_(m.bias.data) 30 | 31 | def weights_init_classifier(m): 32 | classname = m.__class__.__name__ 33 | if classname.find('Linear') != -1: 34 | init.normal_(m.weight.data, 0, 0.001) 35 | init.zeros_(m.bias.data) 36 | 37 | # Defines the new fc layer and classification layer 38 | # |--Linear--|--bn--|--relu--|--Linear--| 39 | class FeatureBlock(nn.Module): 40 | def __init__(self, input_dim, low_dim, dropout=0.5, relu=True): 41 | super(FeatureBlock, self).__init__() 42 | feat_block = [] 43 | feat_block += [nn.Linear(input_dim, low_dim)] 44 | feat_block += [nn.BatchNorm1d(low_dim)] 45 | 46 | feat_block = nn.Sequential(*feat_block) 47 | feat_block.apply(weights_init_kaiming) 48 | self.feat_block = feat_block 49 | def forward(self, x): 50 | x = self.feat_block(x) 51 | return x 52 | 53 | class ClassBlock(nn.Module): 54 | def __init__(self, input_dim, class_num, dropout=0.5, relu=True): 55 | super(ClassBlock, self).__init__() 56 | classifier = [] 57 | if relu: 58 | classifier += [nn.LeakyReLU(0.1)] 59 | if dropout: 60 | classifier += [nn.Dropout(p=dropout)] 61 | 62 | classifier += [nn.Linear(input_dim, class_num)] 63 | classifier = nn.Sequential(*classifier) 64 | classifier.apply(weights_init_classifier) 65 | 66 | self.classifier = classifier 67 | def forward(self, x): 68 | x = self.classifier(x) 69 | return x 70 | 71 | # Define the ResNet18-based Model 72 | class visible_net_resnet(nn.Module): 73 | def __init__(self, arch ='resnet18'): 74 | super(visible_net_resnet, self).__init__() 75 | if arch =='resnet18': 76 | model_ft = models.resnet18(pretrained=True) 77 | elif arch =='resnet50': 78 | model_ft = models.resnet50(pretrained=True) 79 | 80 | for mo in model_ft.layer4[0].modules(): 81 | if isinstance(mo, nn.Conv2d): 82 | mo.stride = (1, 1) 83 | 84 | # avg pooling to global pooling 85 | model_ft.avgpool = nn.AdaptiveAvgPool2d((1,1)) 86 | self.visible = model_ft 87 | self.dropout = nn.Dropout(p=0.5) 88 | def forward(self, x): 89 | x = self.visible.conv1(x) 90 | x = self.visible.bn1(x) 91 | x = self.visible.relu(x) 92 | x = self.visible.maxpool(x) 93 | x = self.visible.layer1(x) 94 | x = self.visible.layer2(x) 95 | x = self.visible.layer3(x) 96 | x = self.visible.layer4(x) 97 | num_part = 6 98 | # pool size 99 | sx = x.size(2) / num_part 100 | sx = int(sx) 101 | kx = x.size(2) - sx * (num_part - 1) 102 | kx = int(kx) 103 | x = nn.functional.avg_pool2d(x, kernel_size=(kx, x.size(3)), stride=(sx, x.size(3))) 104 | #x = self.visible.avgpool(x) 105 | x = x.view(x.size(0), x.size(1), x.size(2)) 106 | # x = self.dropout(x) 107 | return x 108 | 109 | class thermal_net_resnet(nn.Module): 110 | def __init__(self, arch ='resnet18'): 111 | super(thermal_net_resnet, self).__init__() 112 | if arch =='resnet18': 113 | model_ft = models.resnet18(pretrained=True) 114 | elif arch =='resnet50': 115 | model_ft = models.resnet50(pretrained=True) 116 | 117 | for mo in model_ft.layer4[0].modules(): 118 | if isinstance(mo, nn.Conv2d): 119 | mo.stride = (1, 1) 120 | 121 | # avg pooling to global pooling 122 | model_ft.avgpool = nn.AdaptiveAvgPool2d((1,1)) 123 | self.thermal = model_ft 124 | self.dropout = nn.Dropout(p=0.5) 125 | def forward(self, x): 126 | x = self.thermal.conv1(x) 127 | x = self.thermal.bn1(x) 128 | x = self.thermal.relu(x) 129 | x = self.thermal.maxpool(x) 130 | x = self.thermal.layer1(x) 131 | x = self.thermal.layer2(x) 132 | x = self.thermal.layer3(x) 133 | x = self.thermal.layer4(x) 134 | num_part = 6 # number of part 135 | # pool size 136 | sx = x.size(2) / num_part 137 | sx = int(sx) 138 | kx = x.size(2) - sx * (num_part-1) 139 | kx = int(kx) 140 | x = nn.functional.avg_pool2d(x, kernel_size=(kx, x.size(3)), stride=(sx, x.size(3))) 141 | #x = self.thermal.avgpool(x) 142 | x = x.view(x.size(0), x.size(1), x.size(2)) 143 | # x = self.dropout(x) 144 | return x 145 | 146 | class embed_net(nn.Module): 147 | def __init__(self, low_dim, class_num, drop = 0.5, arch ='resnet50'): 148 | super(embed_net, self).__init__() 149 | if arch =='resnet18': 150 | self.visible_net = visible_net_resnet(arch = arch) 151 | self.thermal_net = thermal_net_resnet(arch = arch) 152 | pool_dim = 512 153 | elif arch =='resnet50': 154 | self.visible_net = visible_net_resnet(arch = arch) 155 | self.thermal_net = thermal_net_resnet(arch = arch) 156 | pool_dim = 2048 157 | 158 | self.feature1 = FeatureBlock(pool_dim, low_dim, dropout = drop) 159 | self.feature2 = FeatureBlock(pool_dim, low_dim, dropout=drop) 160 | self.feature3 = FeatureBlock(pool_dim, low_dim, dropout=drop) 161 | self.feature4 = FeatureBlock(pool_dim, low_dim, dropout=drop) 162 | self.feature5 = FeatureBlock(pool_dim, low_dim, dropout=drop) 163 | self.feature6 = FeatureBlock(pool_dim, low_dim, dropout=drop) 164 | self.classifier1 = ClassBlock(low_dim, class_num, dropout = drop) 165 | self.classifier2 = ClassBlock(low_dim, class_num, dropout = drop) 166 | self.classifier3 = ClassBlock(low_dim, class_num, dropout=drop) 167 | self.classifier4 = ClassBlock(low_dim, class_num, dropout=drop) 168 | self.classifier5 = ClassBlock(low_dim, class_num, dropout=drop) 169 | self.classifier6 = ClassBlock(low_dim, class_num, dropout=drop) 170 | 171 | self.l2norm = Normalize(2) 172 | 173 | def forward(self, x1, x2, modal = 0 ): 174 | if modal==0: 175 | x1 = self.visible_net(x1) 176 | x1 = x1.chunk(6,2) 177 | x1_0 = x1[0].contiguous().view(x1[0].size(0),-1) 178 | x1_1 = x1[1].contiguous().view(x1[1].size(0), -1) 179 | x1_2 = x1[2].contiguous().view(x1[2].size(0), -1) 180 | x1_3 = x1[3].contiguous().view(x1[3].size(0), -1) 181 | x1_4 = x1[4].contiguous().view(x1[4].size(0), -1) 182 | x1_5 = x1[5].contiguous().view(x1[5].size(0), -1) 183 | x2 = self.thermal_net(x2) 184 | x2 = x2.chunk(6, 2) 185 | x2_0 = x2[0].contiguous().view(x2[0].size(0), -1) 186 | x2_1 = x2[1].contiguous().view(x2[1].size(0), -1) 187 | x2_2 = x2[2].contiguous().view(x2[2].size(0), -1) 188 | x2_3 = x2[3].contiguous().view(x2[3].size(0), -1) 189 | x2_4 = x2[4].contiguous().view(x2[4].size(0), -1) 190 | x2_5 = x2[5].contiguous().view(x2[5].size(0), -1) 191 | x_0 = torch.cat((x1_0, x2_0), 0) 192 | x_1 = torch.cat((x1_1, x2_1), 0) 193 | x_2 = torch.cat((x1_2, x2_2), 0) 194 | x_3 = torch.cat((x1_3, x2_3), 0) 195 | x_4 = torch.cat((x1_4, x2_4), 0) 196 | x_5 = torch.cat((x1_5, x2_5), 0) 197 | elif modal ==1: 198 | x = self.visible_net(x1) 199 | x = x.chunk(6,2) 200 | x_0 = x[0].contiguous().view(x[0].size(0),-1) 201 | x_1 = x[1].contiguous().view(x[1].size(0), -1) 202 | x_2 = x[2].contiguous().view(x[2].size(0), -1) 203 | x_3 = x[3].contiguous().view(x[3].size(0), -1) 204 | x_4 = x[4].contiguous().view(x[4].size(0), -1) 205 | x_5 = x[5].contiguous().view(x[5].size(0), -1) 206 | elif modal ==2: 207 | x = self.thermal_net(x2) 208 | x = x.chunk(6, 2) 209 | x_0 = x[0].contiguous().view(x[0].size(0), -1) 210 | x_1 = x[1].contiguous().view(x[1].size(0), -1) 211 | x_2 = x[2].contiguous().view(x[2].size(0), -1) 212 | x_3 = x[3].contiguous().view(x[3].size(0), -1) 213 | x_4 = x[4].contiguous().view(x[4].size(0), -1) 214 | x_5 = x[5].contiguous().view(x[5].size(0), -1) 215 | 216 | y_0 = self.feature1(x_0) 217 | y_1 = self.feature2(x_1) 218 | y_2 = self.feature3(x_2) 219 | y_3 = self.feature4(x_3) 220 | y_4 = self.feature5(x_4) 221 | y_5 = self.feature6(x_5) 222 | #y = self.feature(x) 223 | out_0 = self.classifier1(y_0) 224 | out_1 = self.classifier2(y_1) 225 | out_2 = self.classifier3(y_2) 226 | out_3 = self.classifier4(y_3) 227 | out_4 = self.classifier5(y_4) 228 | out_5 = self.classifier6(y_5) 229 | #out = self.classifier(y) 230 | if self.training: 231 | return (out_0, out_1, out_2, out_3, out_4, out_5), (self.l2norm(y_0), self.l2norm(y_1), self.l2norm(y_2), self.l2norm(y_3), self.l2norm(y_4), self.l2norm(y_5)) 232 | else: 233 | x_0 = self.l2norm(x_0) 234 | x_1 = self.l2norm(x_1) 235 | x_2 = self.l2norm(x_2) 236 | x_3 = self.l2norm(x_3) 237 | x_4 = self.l2norm(x_4) 238 | x_5 = self.l2norm(x_5) 239 | x = torch.cat((x_0, x_1, x_2, x_3, x_4, x_5), 1) 240 | y_0 = self.l2norm(y_0) 241 | y_1 = self.l2norm(y_1) 242 | y_2 = self.l2norm(y_2) 243 | y_3 = self.l2norm(y_3) 244 | y_4 = self.l2norm(y_4) 245 | y_5 = self.l2norm(y_5) 246 | y = torch.cat((y_0, y_1, y_2, y_3, y_4, y_5), 1) 247 | return x, y 248 | 249 | 250 | # debug model structure 251 | 252 | # net = embed_net(512, 319) 253 | # net.train() 254 | # input = Variable(torch.FloatTensor(8, 3, 224, 224)) 255 | # x, y = net(input, input) -------------------------------------------------------------------------------- /pre_process_sysu.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | import pdb 4 | import os 5 | 6 | data_path = '/home/omnisky/person_reID/HHL-master/stargan4reid/SYSU-MM01' 7 | 8 | rgb_cameras = ['cam1','cam2','cam4','cam5'] 9 | ir_cameras = ['cam3','cam6'] 10 | 11 | # load id info 12 | file_path_train = os.path.join(data_path,'exp/train_id.txt') 13 | file_path_val = os.path.join(data_path,'exp/val_id.txt') 14 | with open(file_path_train, 'r') as file: 15 | ids = file.read().splitlines() 16 | ids = [int(y) for y in ids[0].split(',')] 17 | id_train = ["%04d" % x for x in ids] 18 | 19 | with open(file_path_val, 'r') as file: 20 | ids = file.read().splitlines() 21 | ids = [int(y) for y in ids[0].split(',')] 22 | id_val = ["%04d" % x for x in ids] 23 | 24 | # combine train and val split 25 | id_train.extend(id_val) 26 | 27 | files_rgb = [] 28 | files_ir = [] 29 | for id in sorted(id_train): 30 | for cam in rgb_cameras: 31 | img_dir = os.path.join(data_path,cam,id) 32 | if os.path.isdir(img_dir): 33 | new_files = sorted([img_dir+'/'+i for i in os.listdir(img_dir)]) 34 | files_rgb.extend(new_files) 35 | 36 | for cam in ir_cameras: 37 | img_dir = os.path.join(data_path,cam,id) 38 | if os.path.isdir(img_dir): 39 | new_files = sorted([img_dir+'/'+i for i in os.listdir(img_dir)]) 40 | files_ir.extend(new_files) 41 | 42 | # relabel 43 | pid_container = set() 44 | for img_path in files_ir: 45 | pid = int(img_path[-13:-9]) 46 | pid_container.add(pid) 47 | pid2label = {pid:label for label, pid in enumerate(pid_container)} 48 | fix_image_width = 144 49 | fix_image_height = 288 50 | def read_imgs(train_image): 51 | train_img = [] 52 | train_label = [] 53 | for img_path in train_image: 54 | # img 55 | img = Image.open(img_path) 56 | img = img.resize((fix_image_width, fix_image_height), Image.ANTIALIAS) 57 | pix_array = np.array(img) 58 | 59 | train_img.append(pix_array) 60 | 61 | # label 62 | pid = int(img_path[-13:-9]) 63 | pid = pid2label[pid] 64 | train_label.append(pid) 65 | return np.array(train_img), np.array(train_label) 66 | 67 | # rgb imges 68 | train_img, train_label = read_imgs(files_rgb) 69 | np.save(data_path + 'train_rgb_resized_img.npy', train_img) 70 | np.save(data_path + 'train_rgb_resized_label.npy', train_label) 71 | 72 | # ir imges 73 | train_img, train_label = read_imgs(files_ir) 74 | np.save(data_path + 'train_ir_resized_img.npy', train_img) 75 | np.save(data_path + 'train_ir_resized_label.npy', train_label) 76 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | import sys 4 | import torch 5 | import torch.nn as nn 6 | import torch.optim as optim 7 | import torch.nn.functional as F 8 | import torch.backends.cudnn as cudnn 9 | from torch.autograd import Variable 10 | import torch.utils.data as data 11 | #import torchvision 12 | #import torchvision.transforms as transforms 13 | from data_loader import SYSUData, RegDBData, TestData 14 | from data_manager import * 15 | from eval_metrics import eval_sysu, eval_regdb 16 | from model import embed_net 17 | from utils import * 18 | import time 19 | import scipy.io as scio 20 | import Transform as transforms 21 | 22 | parser = argparse.ArgumentParser(description='PyTorch Cross-Modality Training') 23 | parser.add_argument('--dataset', default='sysu', help='dataset name: regdb or sysu]') 24 | parser.add_argument('--lr', default=0.01, type=float, help='learning rate') 25 | parser.add_argument('--optim', default='sgd', type=str, help='optimizer') 26 | parser.add_argument('--arch', default='resnet50', type=str, help='network baseline') 27 | parser.add_argument('--resume', '-r', default='', type=str, help='resume from checkpoint') 28 | parser.add_argument('--model_path', default='save_model/', type=str, help='model save path') 29 | parser.add_argument('--log_path', default='log2/', type=str, help='log save path') 30 | parser.add_argument('--workers', default=4, type=int, metavar='N', 31 | help='number of data loading workers (default: 4)') 32 | parser.add_argument('--low-dim', default=512, type=int, 33 | metavar='D', help='feature dimension') 34 | parser.add_argument('--img_w', default=144, type=int, 35 | metavar='imgw', help='img width') 36 | parser.add_argument('--img_h', default=288, type=int, 37 | metavar='imgh', help='img height') 38 | parser.add_argument('--batch-size', default=32, type=int, 39 | metavar='B', help='training batch size') 40 | parser.add_argument('--test-batch', default=64, type=int, 41 | metavar='tb', help='testing batch size') 42 | parser.add_argument('--method', default='id', type=str, 43 | metavar='m', help='Method type') 44 | parser.add_argument('--drop', default=0.0, type=float, 45 | metavar='drop', help='dropout ratio') 46 | parser.add_argument('--trial', default=1, type=int, 47 | metavar='t', help='trial') 48 | parser.add_argument('--gpu', default='0', type=str, 49 | help='gpu device ids for CUDA_VISIBLE_DEVICES') 50 | parser.add_argument('--mode', default='all', type=str, help='all or indoor') 51 | parser.add_argument('--per_img', default=8, type=int, 52 | help='number of samples of an id in every batch') 53 | parser.add_argument('--w_hc', default=0.5, type=float, 54 | help='weight of Hetero-Center Loss') 55 | parser.add_argument('--thd', default=0, type=float, 56 | help='threshold of Hetero-Center Loss') 57 | parser.add_argument('--gall-mode', default='single', type=str, help='single or multi') 58 | 59 | args = parser.parse_args() 60 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 61 | 62 | torch.manual_seed(1) 63 | torch.cuda.manual_seed(1) 64 | torch.cuda.manual_seed_all(1) # 为所有GPU设置随机种子 65 | np.random.seed(1) 66 | random.seed(1) 67 | 68 | dataset = args.dataset 69 | if dataset == 'sysu': 70 | data_path = '/home/omnisky/person_reID/zyx/Cross-Modal-Re-ID-baseline-master/Cross-Modal-Re-ID-baseline-master/dataset/' 71 | log_path = args.log_path + 'sysu_log/' 72 | n_class = 395 73 | test_mode = [1, 2] 74 | elif dataset =='regdb': 75 | data_path = 'RegDB/' 76 | n_class = 206 77 | test_mode = [2, 1] 78 | 79 | if not os.path.isdir(log_path): 80 | os.makedirs(log_path) 81 | 82 | if args.method =='id': 83 | suffix = dataset + '_id_bn_relu' 84 | suffix = suffix + '_drop_{}'.format(args.drop) 85 | suffix = suffix + '_lr_{:1.1e}'.format(args.lr) 86 | suffix = suffix + '_dim_{}'.format(args.low_dim) 87 | suffix = suffix + '_whc_{}'.format(args.w_hc) 88 | suffix = suffix + '_thd_{}'.format(args.thd) 89 | suffix = suffix + '_pimg_{}'.format(args.per_img) 90 | suffix = suffix + '_gm_{}'.format(args.gall_mode) 91 | suffix = suffix + '_m_{}'.format(args.mode) 92 | test_log_file = open(log_path + suffix + '.txt', "w") 93 | sys.stdout = Logger(log_path + suffix + '_os.txt') 94 | 95 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 96 | best_acc = 0 # best test accuracy 97 | start_epoch = 0 98 | 99 | print('==> Building model..') 100 | net = embed_net(args.low_dim, n_class, drop = args.drop, arch=args.arch) 101 | net.to(device) 102 | cudnn.benchmark = True 103 | 104 | print('==> Resuming from checkpoint..') 105 | checkpoint_path = args.model_path 106 | if len(args.resume)>0: 107 | model_path = checkpoint_path + args.resume 108 | if os.path.isfile(model_path): 109 | print('==> loading checkpoint {}'.format(args.resume)) 110 | print('==> loading checkpoint {}'.format(args.resume), file=test_log_file) 111 | checkpoint = torch.load(model_path) 112 | start_epoch = checkpoint['epoch'] 113 | net.load_state_dict(checkpoint['net']) 114 | print('==> loaded checkpoint {} (epoch {})' 115 | .format(args.resume, checkpoint['epoch'])) 116 | print('==> loaded checkpoint {} (epoch {})' 117 | .format(args.resume, checkpoint['epoch']), file=test_log_file) 118 | else: 119 | print('==> no checkpoint found at {}'.format(args.resume)) 120 | print('==> no checkpoint found at {}'.format(args.resume), file=test_log_file) 121 | 122 | 123 | if args.method =='id': 124 | criterion = nn.CrossEntropyLoss() 125 | criterion.to(device) 126 | 127 | print('==> Loading data..') 128 | print('==> Loading data..', file=test_log_file) 129 | # Data loading code 130 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]) 131 | transform_train = transforms.Compose([ 132 | transforms.ToPILImage(), 133 | transforms.RandomCrop((args.img_h,args.img_w)), 134 | transforms.RandomHorizontalFlip(), 135 | transforms.ToTensor(), 136 | normalize, 137 | ]) 138 | 139 | transform_test = transforms.Compose([ 140 | transforms.ToPILImage(), 141 | #transforms.Resize((args.img_h,args.img_w)), 142 | transforms.RectScale(args.img_h, args.img_w), 143 | transforms.ToTensor(), 144 | normalize, 145 | ]) 146 | 147 | end = time.time() 148 | 149 | if dataset =='sysu': 150 | # testing set 151 | query_img, query_label, query_cam = process_query_sysu(data_path, mode = args.mode) 152 | gall_img, gall_label, gall_cam = process_gallery_sysu(data_path, mode = args.mode, trial = 0, gall_mode=args.gall_mode) 153 | 154 | 155 | elif dataset =='regdb': 156 | # training set 157 | trainset = RegDBData(data_path, args.trial, transform=transform_train) 158 | # generate the idx of each person identity 159 | color_pos, thermal_pos = GenIdx(trainset.train_color_label, trainset.train_thermal_label) 160 | 161 | # testing set 162 | query_img, query_label = process_test_regdb(data_path, trial = args.trial, modal = 'visible') 163 | gall_img, gall_label = process_test_regdb(data_path, trial = args.trial, modal = 'thermal') 164 | 165 | gallset = TestData(gall_img, gall_label, transform = transform_test, img_size =(args.img_w,args.img_h)) 166 | gall_loader = data.DataLoader(gallset, batch_size=args.test_batch, shuffle=False, num_workers=args.workers) 167 | 168 | nquery = len(query_label) 169 | ngall = len(gall_label) 170 | print("Dataset statistics:") 171 | print(" ------------------------------") 172 | print(" subset | # ids | # images") 173 | print(" ------------------------------") 174 | print(" query | {:5d} | {:8d}".format(len(np.unique(query_label)), nquery)) 175 | print(" gallery | {:5d} | {:8d}".format(len(np.unique(gall_label)), ngall)) 176 | print(" ------------------------------") 177 | 178 | print("Dataset statistics:", file=test_log_file) 179 | print(" ------------------------------", file=test_log_file) 180 | print(" subset | # ids | # images", file=test_log_file) 181 | print(" ------------------------------", file=test_log_file) 182 | print(" query | {:5d} | {:8d}".format(len(np.unique(query_label)), nquery), file=test_log_file) 183 | print(" gallery | {:5d} | {:8d}".format(len(np.unique(gall_label)), ngall), file=test_log_file) 184 | print(" ------------------------------", file=test_log_file) 185 | 186 | queryset = TestData(query_img, query_label, transform = transform_test, img_size =(args.img_w, args.img_h)) 187 | query_loader = data.DataLoader(queryset, batch_size=args.test_batch, shuffle=False, num_workers=4) 188 | print('Data Loading Time:\t {:.3f}'.format(time.time()-end)) 189 | print('Data Loading Time:\t {:.3f}'.format(time.time()-end), file=test_log_file) 190 | 191 | feature_dim = args.low_dim 192 | 193 | if args.arch =='resnet50': 194 | pool_dim = 2048 195 | elif args.arch =='resnet18': 196 | pool_dim = 512 197 | 198 | def extract_gall_feat(gall_loader): 199 | net.eval() 200 | print('Extracting Gallery Feature...') 201 | print('Extracting Gallery Feature...', file=test_log_file) 202 | start = time.time() 203 | ptr = 0 204 | gall_feat = np.zeros((ngall, 6*feature_dim)) 205 | gall_feat_pool = np.zeros((ngall, 6*pool_dim)) 206 | with torch.no_grad(): 207 | for batch_idx, (input, label ) in enumerate(gall_loader): 208 | batch_num = input.size(0) 209 | input = Variable(input.cuda()) 210 | pool_feat, feat = net(input, input, test_mode[0]) 211 | gall_feat[ptr:ptr+batch_num,: ] = feat.detach().cpu().numpy() 212 | gall_feat_pool[ptr:ptr+batch_num,: ] = pool_feat.detach().cpu().numpy() 213 | ptr = ptr + batch_num 214 | print('Extracting Time:\t {:.3f}'.format(time.time() - start), file=test_log_file) 215 | print('Extracting Time:\t {:.3f}'.format(time.time()-start)) 216 | return gall_feat, gall_feat_pool 217 | 218 | def extract_query_feat(query_loader): 219 | net.eval() 220 | print ('Extracting Query Feature...') 221 | print('Extracting Query Feature...', file=test_log_file) 222 | start = time.time() 223 | ptr = 0 224 | query_feat = np.zeros((nquery, 6*feature_dim)) 225 | query_feat_pool = np.zeros((nquery, 6*pool_dim)) 226 | with torch.no_grad(): 227 | for batch_idx, (input, label ) in enumerate(query_loader): 228 | batch_num = input.size(0) 229 | input = Variable(input.cuda()) 230 | pool_feat, feat = net(input, input, test_mode[1]) 231 | query_feat[ptr:ptr+batch_num,: ] = feat.detach().cpu().numpy() 232 | query_feat_pool[ptr:ptr+batch_num,: ] = pool_feat.detach().cpu().numpy() 233 | ptr = ptr + batch_num 234 | print('Extracting Time:\t {:.3f}'.format(time.time()-start)) 235 | print('Extracting Time:\t {:.3f}'.format(time.time() - start), file=test_log_file) 236 | return query_feat, query_feat_pool 237 | 238 | query_feat, query_feat_pool = extract_query_feat(query_loader) 239 | 240 | all_cmc = 0 241 | all_mAP = 0 242 | all_cmc_pool = 0 243 | if dataset =='regdb': 244 | gall_feat, gall_feat_pool = extract_gall_feat(gall_loader) 245 | # fc feature 246 | distmat = np.matmul(query_feat, np.transpose(gall_feat)) 247 | cmc, mAP = eval_regdb(-distmat, query_label, gall_label) 248 | 249 | # pool5 feature 250 | distmat_pool = np.matmul(query_feat_pool, np.transpose(gall_feat_pool)) 251 | cmc_pool, mAP_pool = eval_regdb(-distmat_pool, query_label, gall_label) 252 | 253 | print ('Test Trial: {}'.format(args.trial)) 254 | print('FC: top-1: {:.2%} | top-5: {:.2%} | top-10: {:.2%}| top-20: {:.2%}'.format( 255 | cmc[0], cmc[4], cmc[9], cmc[19])) 256 | print('mAP: {:.2%}'.format(mAP)) 257 | print('POOL5: top-1: {:.2%} | top-5: {:.2%} | top-10: {:.2%}| top-20: {:.2%}'.format( 258 | cmc_pool[0], cmc_pool[4], cmc_pool[9], cmc_pool[19])) 259 | print('mAP: {:.2%}'.format(mAP_pool)) 260 | 261 | print('Test Trial: {}'.format(args.trial), file=test_log_file) 262 | print('FC: top-1: {:.2%} | top-5: {:.2%} | top-10: {:.2%}| top-20: {:.2%}'.format( 263 | cmc[0], cmc[4], cmc[9], cmc[19]), file=test_log_file) 264 | print('mAP: {:.2%}'.format(mAP), file=test_log_file) 265 | print('POOL5: top-1: {:.2%} | top-5: {:.2%} | top-10: {:.2%}| top-20: {:.2%}'.format( 266 | cmc_pool[0], cmc_pool[4], cmc_pool[9], cmc_pool[19]), file=test_log_file) 267 | print('mAP: {:.2%}'.format(mAP_pool), file=test_log_file) 268 | 269 | elif dataset =='sysu': 270 | for trial in range(10): 271 | gall_img, gall_label, gall_cam = process_gallery_sysu(data_path, mode = args.mode, trial = trial, gall_mode=args.gall_mode) 272 | 273 | trial_gallset = TestData(gall_img, gall_label, transform = transform_test,img_size =(args.img_w,args.img_h)) 274 | trial_gall_loader = data.DataLoader(trial_gallset, batch_size=args.test_batch, shuffle=False, num_workers=4) 275 | 276 | gall_feat, gall_feat_pool = extract_gall_feat(trial_gall_loader) 277 | 278 | # fc feature 279 | distmat = np.matmul(query_feat, np.transpose(gall_feat)) 280 | cmc, mAP = eval_sysu(-distmat, query_label, gall_label,query_cam, gall_cam) 281 | 282 | # pool5 feature 283 | distmat_pool = np.matmul(query_feat_pool, np.transpose(gall_feat_pool)) 284 | cmc_pool, mAP_pool = eval_sysu(-distmat_pool, query_label, gall_label,query_cam, gall_cam) 285 | if trial ==0: 286 | all_cmc = cmc 287 | all_mAP = mAP 288 | all_cmc_pool = cmc_pool 289 | all_mAP_pool = mAP_pool 290 | else: 291 | all_cmc = all_cmc + cmc 292 | all_mAP = all_mAP + mAP 293 | all_cmc_pool = all_cmc_pool + cmc_pool 294 | all_mAP_pool = all_mAP_pool + mAP_pool 295 | 296 | print ('Test Trial: {}'.format(trial)) 297 | print('FC: top-1: {:.2%} | top-5: {:.2%} | top-10: {:.2%}| top-20: {:.2%}'.format( 298 | cmc[0], cmc[4], cmc[9], cmc[19])) 299 | print('mAP: {:.2%}'.format(mAP)) 300 | print('POOL5: top-1: {:.2%} | top-5: {:.2%} | top-10: {:.2%}| top-20: {:.2%}'.format( 301 | cmc_pool[0], cmc_pool[4], cmc_pool[9], cmc_pool[19])) 302 | print('mAP: {:.2%}'.format(mAP_pool)) 303 | 304 | print('Test Trial: {}'.format(trial), file=test_log_file) 305 | print('FC: top-1: {:.2%} | top-5: {:.2%} | top-10: {:.2%}| top-20: {:.2%}'.format( 306 | cmc[0], cmc[4], cmc[9], cmc[19]), file=test_log_file) 307 | print('mAP: {:.2%}'.format(mAP), file=test_log_file) 308 | print('POOL5: top-1: {:.2%} | top-5: {:.2%} | top-10: {:.2%}| top-20: {:.2%}'.format( 309 | cmc_pool[0], cmc_pool[4], cmc_pool[9], cmc_pool[19]), file=test_log_file) 310 | print('mAP: {:.2%}'.format(mAP_pool), file=test_log_file) 311 | 312 | cmc = all_cmc /10 313 | mAP = all_mAP /10 314 | 315 | cmc_pool = all_cmc_pool /10 316 | mAP_pool = all_mAP_pool /10 317 | print ('All Average:') 318 | print('FC: top-1: {:.2%} | top-5: {:.2%} | top-10: {:.2%}| top-20: {:.2%}'.format(cmc[0], cmc[4], cmc[9], cmc[19])) 319 | print('mAP: {:.2%}'.format(mAP)) 320 | print('POOL5: top-1: {:.2%} | top-5: {:.2%} | top-10: {:.2%}| top-20: {:.2%}'.format( 321 | cmc_pool[0], cmc_pool[4], cmc_pool[9], cmc_pool[19])) 322 | print('mAP: {:.2%}'.format(mAP_pool)) 323 | 324 | print('All Average:', file=test_log_file) 325 | print('FC: top-1: {:.2%} | top-5: {:.2%} | top-10: {:.2%}| top-20: {:.2%}'.format(cmc[0], cmc[4], cmc[9], cmc[19]), file=test_log_file) 326 | print('mAP: {:.2%}'.format(mAP), file=test_log_file) 327 | print('POOL5: top-1: {:.2%} | top-5: {:.2%} | top-10: {:.2%}| top-20: {:.2%}'.format( 328 | cmc_pool[0], cmc_pool[4], cmc_pool[9], cmc_pool[19]), file=test_log_file) 329 | print('mAP: {:.2%}'.format(mAP_pool), file=test_log_file) -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | import sys 4 | import time 5 | import torch 6 | import torch.nn as nn 7 | import torch.optim as optim 8 | import torch.backends.cudnn as cudnn 9 | from torch.autograd import Variable 10 | import torch.utils.data as data 11 | #import torchvision 12 | #import torchvision.transforms as transforms 13 | from data_loader import SYSUData, RegDBData, TestData 14 | from data_manager import * 15 | from eval_metrics import eval_sysu, eval_regdb 16 | from model import embed_net 17 | from utils import * 18 | import Transform as transforms 19 | from heterogeneity_loss import hetero_loss 20 | import xlwt,xlrd 21 | from torch.backends import cudnn 22 | 23 | parser = argparse.ArgumentParser(description='PyTorch Cross-Modality Training') 24 | parser.add_argument('--dataset', default='sysu', help='dataset name: regdb or sysu]') 25 | parser.add_argument('--lr', default=0.01, type=float, help='learning rate') 26 | parser.add_argument('--optim', default='sgd', type=str, help='optimizer') 27 | parser.add_argument('--arch', default='resnet50', type=str, 28 | help='network baseline:resnet18 or resnet50') 29 | parser.add_argument('--resume', '-r', default='', type=str, 30 | help='resume from checkpoint') 31 | parser.add_argument('--test-only', action='store_true', help='test only') 32 | parser.add_argument('--model_path', default='save_model/', type=str, 33 | help='model save path') 34 | parser.add_argument('--save_epoch', default=20, type=int, 35 | metavar='s', help='save model every 10 epochs') 36 | parser.add_argument('--log_path', default='log/', type=str, 37 | help='log save path') 38 | parser.add_argument('--workers', default=4, type=int, metavar='N', 39 | help='number of data loading workers (default: 4)') 40 | parser.add_argument('--low-dim', default=512, type=int, 41 | metavar='D', help='feature dimension') 42 | parser.add_argument('--img_w', default=144, type=int, 43 | metavar='imgw', help='img width') 44 | parser.add_argument('--img_h', default=288, type=int, 45 | metavar='imgh', help='img height') 46 | parser.add_argument('--batch-size', default=32, type=int, 47 | metavar='B', help='training batch size') 48 | parser.add_argument('--test-batch', default=64, type=int, 49 | metavar='tb', help='testing batch size') 50 | parser.add_argument('--method', default='id', type=str, 51 | metavar='m', help='method type') 52 | parser.add_argument('--drop', default=0.0, type=float, 53 | metavar='drop', help='dropout ratio') 54 | parser.add_argument('--trial', default=1, type=int, 55 | metavar='t', help='trial (only for RegDB dataset)') 56 | parser.add_argument('--gpu', default='3', type=str, 57 | help='gpu device ids for CUDA_VISIBLE_DEVICES') 58 | parser.add_argument('--mode', default='all', type=str, help='all or indoor') 59 | parser.add_argument('--per_img', default=8, type=int, 60 | help='number of samples of an id in every batch') 61 | parser.add_argument('--w_hc', default=0.5, type=float, 62 | help='weight of Hetero-Center Loss') 63 | parser.add_argument('--thd', default=0, type=float, 64 | help='threshold of Hetero-Center Loss') 65 | parser.add_argument('--epochs', default=60, type=int, 66 | help='weight of Hetero-Center Loss') 67 | parser.add_argument('--dist-type', default='l2', type=str, 68 | help='type of distance') 69 | 70 | 71 | torch.manual_seed(1) 72 | torch.cuda.manual_seed(1) 73 | torch.cuda.manual_seed_all(1) # 为所有GPU设置随机种子 74 | np.random.seed(1) 75 | random.seed(1) 76 | 77 | def worker_init_fn(worker_id): 78 | # After creating the workers, each worker has an independent seed that is initialized to the curent random seed + the id of the worker 79 | np.random.seed(0 + worker_id) 80 | 81 | args = parser.parse_args() 82 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 83 | #fix_random(0) 84 | 85 | 86 | dataset = args.dataset 87 | if dataset == 'sysu': 88 | data_path = '/home/omnisky/person_reID/zyx/Cross-Modal-Re-ID-baseline-master/Cross-Modal-Re-ID-baseline-master/dataset/SYSU-MM01/' 89 | log_path = args.log_path + 'sysu_log/' 90 | test_mode = [1, 2] # thermal to visible 91 | elif dataset =='regdb': 92 | data_path = 'RegDB/' 93 | log_path = args.log_path + 'regdb_log/' 94 | test_mode = [2, 1] # visible to thermal 95 | 96 | checkpoint_path = args.model_path 97 | 98 | if not os.path.isdir(log_path): 99 | os.makedirs(log_path) 100 | if not os.path.isdir(checkpoint_path): 101 | os.makedirs(checkpoint_path) 102 | 103 | if args.method =='id': 104 | suffix = dataset + '_id_bn_relu' 105 | suffix = suffix + '_drop_{}'.format(args.drop) 106 | suffix = suffix + '_lr_{:1.1e}'.format(args.lr) 107 | suffix = suffix + '_dim_{}'.format(args.low_dim) 108 | suffix = suffix + '_whc_{}'.format(args.w_hc) 109 | suffix = suffix + '_thd_{}'.format(args.thd) 110 | suffix = suffix + '_pimg_{}'.format(args.per_img) 111 | suffix = suffix + '_ds_{}'.format(args.dist_type) 112 | suffix = suffix + '_md_{}'.format(args.mode) 113 | if not args.optim == 'sgd': 114 | suffix = suffix + '_' + args.optim 115 | suffix = suffix + '_' + args.arch 116 | if dataset =='regdb': 117 | suffix = suffix + '_trial_{}'.format(args.trial) 118 | 119 | test_log_file = open(log_path + suffix + '.txt', "w") 120 | sys.stdout = Logger(log_path + suffix + '_os.txt') 121 | 122 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 123 | best_acc = 0 # best test accuracy 124 | start_epoch = 0 125 | feature_dim = args.low_dim 126 | 127 | print('==> Loading data..') 128 | # Data loading code 129 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]) 130 | transform_train = transforms.Compose([ 131 | transforms.ToPILImage(), 132 | #transforms.Pad(10), 133 | transforms.RectScale(args.img_h, args.img_w), 134 | transforms.RandomCrop((args.img_h,args.img_w)), 135 | transforms.RandomHorizontalFlip(), 136 | transforms.ToTensor(), 137 | normalize, 138 | ]) 139 | transform_test = transforms.Compose([ 140 | transforms.ToPILImage(), 141 | #transforms.Resize((args.img_h,args.img_w)), 142 | transforms.RectScale(args.img_h, args.img_w), 143 | transforms.ToTensor(), 144 | normalize, 145 | ]) 146 | 147 | end = time.time() 148 | if dataset =='sysu': 149 | # training set 150 | trainset = SYSUData(data_path, transform=transform_train) 151 | # generate the idx of each person identity 152 | color_pos, thermal_pos = GenIdx(trainset.train_color_label, trainset.train_thermal_label) 153 | 154 | # testing set 155 | query_img, query_label, query_cam = process_query_sysu(data_path, mode = args.mode) 156 | gall_img, gall_label, gall_cam = process_gallery_sysu(data_path, mode = args.mode, trial = 0) 157 | 158 | elif dataset =='regdb': 159 | # training set 160 | trainset = RegDBData(data_path, args.trial, transform=transform_train) 161 | # generate the idx of each person identity 162 | color_pos, thermal_pos = GenIdx(trainset.train_color_label, trainset.train_thermal_label) 163 | 164 | # testing set 165 | query_img, query_label = process_test_regdb(data_path, trial = args.trial, modal = 'visible') 166 | gall_img, gall_label = process_test_regdb(data_path, trial = args.trial, modal = 'thermal') 167 | 168 | gallset = TestData(gall_img, gall_label, transform = transform_test, img_size =(args.img_w,args.img_h)) 169 | queryset = TestData(query_img, query_label, transform = transform_test, img_size =(args.img_w,args.img_h)) 170 | 171 | # testing data loader 172 | gall_loader = data.DataLoader(gallset, batch_size=args.test_batch, shuffle=False, num_workers=args.workers, worker_init_fn=worker_init_fn) 173 | query_loader = data.DataLoader(queryset, batch_size=args.test_batch, shuffle=False, num_workers=args.workers, worker_init_fn=worker_init_fn) 174 | 175 | n_class = len(np.unique(trainset.train_color_label)) 176 | nquery = len(query_label) 177 | ngall = len(gall_label) 178 | 179 | print('Dataset {} statistics:'.format(dataset)) 180 | print(' ------------------------------') 181 | print(' subset | # ids | # images') 182 | print(' ------------------------------') 183 | print(' visible | {:5d} | {:8d}'.format(n_class, len(trainset.train_color_label))) 184 | print(' thermal | {:5d} | {:8d}'.format(n_class, len(trainset.train_thermal_label))) 185 | print(' ------------------------------') 186 | print(' query | {:5d} | {:8d}'.format(len(np.unique(query_label)), nquery)) 187 | print(' gallery | {:5d} | {:8d}'.format(len(np.unique(gall_label)), ngall)) 188 | print(' ------------------------------') 189 | print('Data Loading Time:\t {:.3f}'.format(time.time()-end)) 190 | 191 | 192 | print('==> Building model..') 193 | net = embed_net(args.low_dim, n_class, drop = args.drop, arch=args.arch) 194 | net.to(device) 195 | #cudnn.benchmark = True 196 | cudnn.benckmark = False 197 | cudnn.deterministic = True 198 | 199 | if len(args.resume)>0: 200 | model_path = checkpoint_path + args.resume 201 | if os.path.isfile(model_path): 202 | print('==> loading checkpoint {}'.format(args.resume)) 203 | checkpoint = torch.load(model_path) 204 | start_epoch = checkpoint['epoch'] 205 | net.load_state_dict(checkpoint['net']) 206 | print('==> loaded checkpoint {} (epoch {})' 207 | .format(args.resume, checkpoint['epoch'])) 208 | else: 209 | print('==> no checkpoint found at {}'.format(args.resume)) 210 | 211 | if args.method =='id': 212 | thd = args.thd 213 | criterion = nn.CrossEntropyLoss() 214 | criterion.to(device) 215 | criterion_het = hetero_loss(margin=thd, dist_type=args.dist_type) 216 | criterion_het.to(device) 217 | 218 | ignored_params = list(map(id, net.feature1.parameters())) \ 219 | + list(map(id, net.feature2.parameters())) \ 220 | + list(map(id, net.feature3.parameters())) \ 221 | + list(map(id, net.feature4.parameters())) \ 222 | + list(map(id, net.feature5.parameters())) \ 223 | + list(map(id, net.feature6.parameters())) \ 224 | + list(map(id, net.classifier1.parameters())) \ 225 | + list(map(id, net.classifier2.parameters())) \ 226 | + list(map(id, net.classifier3.parameters()))\ 227 | + list(map(id, net.classifier4.parameters()))\ 228 | + list(map(id, net.classifier5.parameters()))\ 229 | + list(map(id, net.classifier6.parameters())) 230 | base_params = filter(lambda p: id(p) not in ignored_params, net.parameters()) 231 | if args.optim == 'sgd': 232 | optimizer = optim.SGD([ 233 | {'params': base_params, 'lr': 0.1*args.lr}, 234 | {'params': net.feature1.parameters(), 'lr': args.lr}, 235 | {'params': net.feature2.parameters(), 'lr': args.lr}, 236 | {'params': net.feature3.parameters(), 'lr': args.lr}, 237 | {'params': net.feature4.parameters(), 'lr': args.lr}, 238 | {'params': net.feature5.parameters(), 'lr': args.lr}, 239 | {'params': net.feature6.parameters(), 'lr': args.lr}, 240 | {'params': net.classifier1.parameters(), 'lr': args.lr}, 241 | {'params': net.classifier2.parameters(), 'lr': args.lr}, 242 | {'params': net.classifier3.parameters(), 'lr': args.lr}, 243 | {'params': net.classifier4.parameters(), 'lr': args.lr}, 244 | {'params': net.classifier5.parameters(), 'lr': args.lr}, 245 | {'params': net.classifier6.parameters(), 'lr': args.lr}], 246 | weight_decay=5e-4, momentum=0.9, nesterov=True) 247 | elif args.optim == 'adam': 248 | optimizer = optim.Adam([ 249 | {'params': base_params, 'lr': 0.1*args.lr}, 250 | {'params': net.feature.parameters(), 'lr': args.lr}, 251 | {'params': net.classifier.parameters(), 'lr': args.lr}],weight_decay=5e-4) 252 | 253 | def adjust_learning_rate(optimizer, epoch): 254 | """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" 255 | 256 | if epoch < 30: 257 | lr = args.lr 258 | elif epoch >= 30 and epoch < 60: 259 | lr = args.lr * 0.1 260 | #lr = args.lr 261 | else: 262 | lr = args.lr * 0.01 263 | #lr = args.lr 264 | 265 | optimizer.param_groups[0]['lr'] = 0.1*lr 266 | optimizer.param_groups[1]['lr'] = lr 267 | optimizer.param_groups[2]['lr'] = lr 268 | optimizer.param_groups[3]['lr'] = lr 269 | optimizer.param_groups[4]['lr'] = lr 270 | optimizer.param_groups[5]['lr'] = lr 271 | optimizer.param_groups[6]['lr'] = lr 272 | optimizer.param_groups[7]['lr'] = lr 273 | optimizer.param_groups[8]['lr'] = lr 274 | optimizer.param_groups[9]['lr'] = lr 275 | optimizer.param_groups[10]['lr'] = lr 276 | optimizer.param_groups[11]['lr'] = lr 277 | optimizer.param_groups[12]['lr'] = lr 278 | return lr 279 | 280 | def train(epoch, loss_log): 281 | current_lr = adjust_learning_rate(optimizer, epoch) 282 | train_loss = AverageMeter() 283 | data_time = AverageMeter() 284 | batch_time = AverageMeter() 285 | correct = 0 286 | total = 0 287 | 288 | # switch to train mode 289 | net.train() 290 | end = time.time() 291 | #num_batch = 0 292 | #epoch_loss = 0 293 | for batch_idx, (input1, input2, label1, label2) in enumerate(trainloader): 294 | input1 = Variable(input1.cuda()) 295 | input2 = Variable(input2.cuda()) 296 | 297 | labels = torch.cat((label1,label2),0) 298 | labels = Variable(labels.cuda().long()) 299 | label1 = Variable(label1.cuda().long()) 300 | label2 = Variable(label2.cuda().long()) 301 | data_time.update(time.time() - end) 302 | 303 | outputs, feat = net(input1, input2) 304 | if args.method =='id': 305 | loss0 = criterion(outputs[0], labels) 306 | loss1 = criterion(outputs[1], labels) 307 | loss2 = criterion(outputs[2], labels) 308 | loss3 = criterion(outputs[3], labels) 309 | loss4 = criterion(outputs[4], labels) 310 | loss5 = criterion(outputs[5], labels) 311 | het_feat0 = feat[0].chunk(2, 0) 312 | het_feat1 = feat[1].chunk(2, 0) 313 | het_feat2 = feat[2].chunk(2, 0) 314 | het_feat3 = feat[3].chunk(2, 0) 315 | het_feat4 = feat[4].chunk(2, 0) 316 | het_feat5 = feat[5].chunk(2, 0) 317 | loss_c0 = criterion_het(het_feat0[0], het_feat0[1], label1, label2) 318 | loss_c1 = criterion_het(het_feat1[0], het_feat1[1], label1, label2) 319 | loss_c2 = criterion_het(het_feat2[0], het_feat2[1], label1, label2) 320 | loss_c3 = criterion_het(het_feat3[0], het_feat3[1], label1, label2) 321 | loss_c4 = criterion_het(het_feat4[0], het_feat4[1], label1, label2) 322 | loss_c5 = criterion_het(het_feat5[0], het_feat5[1], label1, label2) 323 | loss0 = loss0 + w_hc * loss_c0 324 | loss1 = loss1 + w_hc * loss_c1 325 | loss2 = loss2 + w_hc * loss_c2 326 | loss3 = loss3 + w_hc * loss_c3 327 | loss4 = loss4 + w_hc * loss_c4 328 | loss5 = loss5 + w_hc * loss_c5 329 | 330 | 331 | _, predicted = outputs[0].max(1) 332 | correct += predicted.eq(labels).sum().item() 333 | 334 | optimizer.zero_grad() 335 | torch.autograd.backward([loss0, loss1, loss2, loss3, loss4, loss5], [torch.tensor(1.0).cuda(), torch.tensor(1.0).cuda(), torch.tensor(1.0).cuda(), torch.tensor(1.0).cuda(), torch.tensor(1.0).cuda(), torch.tensor(1.0).cuda()]) 336 | #loss.backward() 337 | optimizer.step() 338 | loss = (loss0 + loss1 + loss2 + loss3 + loss4 + loss5) / 6 339 | #epoch_loss = epoch_loss + loss.item() 340 | train_loss.update(loss.item(), 2*input1.size(0)) 341 | 342 | total += labels.size(0) 343 | 344 | # measure elapsed time 345 | batch_time.update(time.time() - end) 346 | end = time.time() 347 | if batch_idx%10 ==0: 348 | print('Epoch: [{}][{}/{}] ' 349 | 'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) ' 350 | 'Data: {data_time.val:.3f} ({data_time.avg:.3f}) ' 351 | 'lr:{} ' 352 | 'Loss: {train_loss.val:.4f} ({train_loss.avg:.4f}) ' 353 | 'Accu: {:.2f}' .format( 354 | epoch, batch_idx, len(trainloader),current_lr, 355 | 100.*correct/total, batch_time=batch_time, 356 | data_time=data_time, train_loss=train_loss)) 357 | #num_batch = num_batch + 1 358 | #epoch_loss = epoch_loss / num_batch 359 | #loss_log.append(epoch_loss) 360 | 361 | 362 | def test(epoch): 363 | # switch to evaluation mode 364 | net.eval() 365 | print ('Extracting Gallery Feature...') 366 | start = time.time() 367 | ptr = 0 368 | gall_feat = np.zeros((ngall, 6*args.low_dim)) 369 | with torch.no_grad(): 370 | for batch_idx, (input, label ) in enumerate(gall_loader): 371 | batch_num = input.size(0) 372 | input = Variable(input.cuda()) 373 | feat_pool, feat = net(input, input, test_mode[0]) 374 | gall_feat[ptr:ptr+batch_num,: ] = feat.detach().cpu().numpy() 375 | ptr = ptr + batch_num 376 | print('Extracting Time:\t {:.3f}'.format(time.time()-start)) 377 | 378 | # switch to evaluation mode 379 | net.eval() 380 | print ('Extracting Query Feature...') 381 | start = time.time() 382 | ptr = 0 383 | query_feat = np.zeros((nquery, 6*args.low_dim)) 384 | with torch.no_grad(): 385 | for batch_idx, (input, label ) in enumerate(query_loader): 386 | batch_num = input.size(0) 387 | input = Variable(input.cuda()) 388 | feat_pool, feat = net(input, input, test_mode[1]) 389 | query_feat[ptr:ptr+batch_num,: ] = feat.detach().cpu().numpy() 390 | ptr = ptr + batch_num 391 | print('Extracting Time:\t {:.3f}'.format(time.time()-start)) 392 | 393 | start = time.time() 394 | # compute the similarity 395 | distmat = np.matmul(query_feat, np.transpose(gall_feat)) 396 | 397 | # evaluation 398 | if dataset =='regdb': 399 | cmc, mAP = eval_regdb(-distmat, query_label, gall_label) 400 | elif dataset =='sysu': 401 | cmc, mAP = eval_sysu(-distmat, query_label, gall_label, query_cam, gall_cam) 402 | print('Evaluation Time:\t {:.3f}'.format(time.time()-start)) 403 | return cmc, mAP 404 | 405 | # training 406 | print('==> Start Training...') 407 | per_img = args.per_img 408 | per_id = args.batch_size / per_img 409 | w_hc = args.w_hc 410 | loss_log = [] 411 | for epoch in range(start_epoch, args.epochs+1-start_epoch): 412 | 413 | print('==> Preparing Data Loader...') 414 | # identity sampler 415 | sampler = IdentitySampler(trainset.train_color_label, \ 416 | trainset.train_thermal_label, color_pos, thermal_pos, args.batch_size, per_img) 417 | trainset.cIndex = sampler.index1 # color index 418 | trainset.tIndex = sampler.index2 # thermal index 419 | trainloader = data.DataLoader(trainset, batch_size=args.batch_size,\ 420 | sampler = sampler, num_workers=args.workers, drop_last =True) 421 | 422 | # training 423 | train(epoch, loss_log) 424 | 425 | if epoch > 0 and epoch%2 ==0: 426 | print ('Test Epoch: {}'.format(epoch)) 427 | print ('Test Epoch: {}'.format(epoch),file=test_log_file) 428 | # testing 429 | cmc, mAP = test(epoch) 430 | 431 | print('FC: Rank-1: {:.2%} | Rank-10: {:.2%} | Rank-20: {:.2%}| mAP: {:.2%}'.format( 432 | cmc[0], cmc[9], cmc[19], mAP)) 433 | print('FC: Rank-1: {:.2%} | Rank-10: {:.2%} | Rank-20: {:.2%}| mAP: {:.2%}'.format( 434 | cmc[0], cmc[9], cmc[19], mAP), file = test_log_file) 435 | test_log_file.flush() 436 | 437 | # save model 438 | if cmc[0] > best_acc: # not the real best for sysu-mm01 439 | best_acc = cmc[0] 440 | state = { 441 | 'net': net.state_dict(), 442 | 'cmc': cmc, 443 | 'mAP': mAP, 444 | 'epoch': epoch, 445 | } 446 | torch.save(state, checkpoint_path + suffix + '_best.t') 447 | 448 | # save model every 20 epochs 449 | if epoch > 10 and epoch%args.save_epoch ==0: 450 | state = { 451 | 'net': net.state_dict(), 452 | 'cmc': cmc, 453 | 'mAP': mAP, 454 | 'epoch': epoch, 455 | } 456 | torch.save(state, checkpoint_path + suffix + '_epoch_{}.t'.format(epoch)) 457 | 458 | ''' 459 | f = xlwt.Workbook() 460 | sheet1 = f.add_sheet(u'sheet1', cell_overwrite_ok=True) # 创建sheet 461 | # 将数据写入第 i 行,第 j 列 462 | i = 0 463 | for data in loss_log: 464 | sheet1.write(i, 0, data) 465 | i = i + 1 466 | f.save('log/sysu_log/'+'whc_{}'.format(args.w_hc)+'.xls'+'thd_{}'.format(args.thd)+'.xls' ) # 保存文件 467 | ''' 468 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import defaultdict 3 | import numbers 4 | import numpy as np 5 | from torch.utils.data.sampler import Sampler 6 | import sys 7 | import os.path as osp 8 | import scipy.io as scio 9 | 10 | def GenIdx( train_color_label, train_thermal_label): 11 | color_pos = [] 12 | unique_label_color = np.unique(train_color_label) 13 | for i in range(len(unique_label_color)): 14 | tmp_pos = [k for k,v in enumerate(train_color_label) if v==unique_label_color[i]] 15 | color_pos.append(tmp_pos) 16 | 17 | thermal_pos = [] 18 | unique_label_thermal = np.unique(train_thermal_label) 19 | for i in range(len(unique_label_thermal)): 20 | tmp_pos = [k for k,v in enumerate(train_thermal_label) if v==unique_label_thermal[i]] 21 | thermal_pos.append(tmp_pos) 22 | return color_pos, thermal_pos 23 | 24 | 25 | class IdentitySampler(Sampler): 26 | """Sample person identities evenly in each batch. 27 | Args: 28 | train_color_label, train_thermal_label: labels of two modalities 29 | color_pos, thermal_pos: positions of each identity 30 | batchSize: batch size 31 | """ 32 | 33 | def __init__(self, train_color_label, train_thermal_label, color_pos, thermal_pos, batchSize, per_img): 34 | uni_label = np.unique(train_color_label) 35 | self.n_classes = len(uni_label) 36 | 37 | sample_color = np.arange(batchSize) 38 | sample_thermal = np.arange(batchSize) 39 | N = np.maximum(len(train_color_label), len(train_thermal_label)) 40 | 41 | #per_img = 4 42 | per_id = batchSize / per_img 43 | for j in range(N//batchSize+1): 44 | batch_idx = np.random.choice(uni_label, int(per_id), replace = False) 45 | 46 | for s, i in enumerate(range(0, batchSize, per_img)): 47 | sample_color[i:i+per_img] = np.random.choice(color_pos[batch_idx[s]], per_img, replace=False) 48 | sample_thermal[i:i+per_img] = np.random.choice(thermal_pos[batch_idx[s]], per_img, replace=False) 49 | 50 | if j ==0: 51 | index1= sample_color 52 | index2= sample_thermal 53 | else: 54 | index1 = np.hstack((index1, sample_color)) 55 | index2 = np.hstack((index2, sample_thermal)) 56 | 57 | self.index1 = index1 58 | self.index2 = index2 59 | self.N = N 60 | 61 | def __iter__(self): 62 | return iter(np.arange(len(self.index1))) 63 | 64 | def __len__(self): 65 | return self.N 66 | 67 | class AverageMeter(object): 68 | """Computes and stores the average and current value""" 69 | def __init__(self): 70 | self.reset() 71 | 72 | def reset(self): 73 | self.val = 0 74 | self.avg = 0 75 | self.sum = 0 76 | self.count = 0 77 | 78 | def update(self, val, n=1): 79 | self.val = val 80 | self.sum += val * n 81 | self.count += n 82 | self.avg = self.sum / self.count 83 | 84 | def mkdir_if_missing(directory): 85 | if not osp.exists(directory): 86 | try: 87 | os.makedirs(directory) 88 | except OSError as e: 89 | if e.errno != errno.EEXIST: 90 | raise 91 | class Logger(object): 92 | """ 93 | Write console output to external text file. 94 | Code imported from https://github.com/Cysu/open-reid/blob/master/reid/utils/logging.py. 95 | """ 96 | def __init__(self, fpath=None): 97 | self.console = sys.stdout 98 | self.file = None 99 | if fpath is not None: 100 | mkdir_if_missing(osp.dirname(fpath)) 101 | self.file = open(fpath, 'w') 102 | 103 | def __del__(self): 104 | self.close() 105 | 106 | def __enter__(self): 107 | pass 108 | 109 | def __exit__(self, *args): 110 | self.close() 111 | 112 | def write(self, msg): 113 | self.console.write(msg) 114 | if self.file is not None: 115 | self.file.write(msg) 116 | 117 | def flush(self): 118 | self.console.flush() 119 | if self.file is not None: 120 | self.file.flush() 121 | os.fsync(self.file.fileno()) 122 | 123 | def close(self): 124 | self.console.close() 125 | if self.file is not None: 126 | self.file.close() 127 | 128 | --------------------------------------------------------------------------------