├── README.md
├── backbone.py
├── clean.sh
├── clean_step.sh
├── criterions.py
├── datasets.py
├── desktop.ini
├── eval.py
├── main.py
├── models.py
├── settings.py
├── solver.py
├── tensorboard.sh
└── transforms.py

/README.md:
--------------------------------------------------------------------------------

# Leaning Compact and Representative Features for Cross-Modality Person Re-Identification
PyTorch code for "Leaning Compact and Representative Features for Cross-Modality Person Re-Identification" (World Wide Web, CCF-B).

## [Highlights]
1. We devise an efficient Enumerate Angular Triplet (EAT) loss, which helps obtain an angularly separable common feature space by explicitly constraining the internal angles between different embedding features, thereby improving performance (see the sketch after this list).
2. Motivated by knowledge distillation, we propose a novel Cross-Modality Knowledge Distillation (CMKD) loss that reduces the modality discrepancy in the modality-specific feature extraction stage, benefiting the cross-modality person Re-ID task (also sketched below).
3. Our network achieves prominent results on both the SYSU-MM01 and RegDB datasets without any additional data augmentation tricks: a mean Average Precision (mAP) of 43.09% on SYSU-MM01 and 79.92% on RegDB.
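For reference, here are minimal, self-contained sketches of the two loss terms as they are realized in this repository. The function names and signatures below are illustrative only; the actual training code lives in `criterions.py` (`expATLoss`) and in the `lossx` term of `inf_batch` in `main.py`.

```python
import torch
import torch.nn.functional as F

def exp_angular_triplet(anchor, positive, negative, margin=1.0):
    # Angular triplet term (cf. expATLoss in criterions.py): rank
    # cos(anchor, positive) above cos(anchor, negative) by a margin,
    # then exponentiate the hinge: exp(max(0, cos_neg - cos_pos + margin)).
    cos_pos = F.cosine_similarity(anchor, positive)          # (B,)
    cos_neg = F.relu(F.cosine_similarity(anchor, negative))  # (B,), negatives clamped to 0
    target = torch.ones_like(cos_pos)                        # +1 means cos_pos should rank higher
    return torch.exp(F.margin_ranking_loss(cos_pos, cos_neg, target, margin=margin))

def cmkd_term(rgb_feat, pos_ir_feat, ir_feat, pos_rgb_feat):
    # Cross-modality distillation term (cf. lossx in main.py): pull the
    # modality-specific feature maps of the same identity together across
    # modalities with a symmetric MSE penalty.
    return F.mse_loss(rgb_feat, pos_ir_feat) + F.mse_loss(ir_feat, pos_rgb_feat)
```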
## [Prerequisite]
Python>=3.6
PyTorch>=1.0.0
OpenCV>=3.1.0
tensorboard-pytorch

## [Experiments]
Training:
python main.py -a train

Testing:
python main.py -a test -m checkpoint_name -s test_setting
The test settings for SYSU-MM01 are: "all_multi" (all-search mode, multi-shot), "all_single" (all-search mode, single-shot), "indoor_multi" (indoor-search mode, multi-shot), and "indoor_single" (indoor-search mode, single-shot).
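For example, to evaluate the best checkpoint saved during training (stored as `ckp_latest`, which is also the default value of `-m`) under the all-search, single-shot protocol:

python main.py -a test -m ckp_latest -s all_single

Note that `main.py` as shipped wires up the RegDB loaders (see `run_train_val` and `test_ckp`); the SYSU-MM01 settings above take effect once the SYSU dataset classes from `datasets.py` are plugged in.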
## [Cite]
If you find our paper/code useful, please consider citing the paper:

@article{gao2022leaning,
  title={Leaning compact and representative features for cross-modality person re-identification},
  author={Gao, Guangwei and Shao, Hao and Wu, Fei and Yang, Meng and Yu, Yi},
  journal={World Wide Web},
  pages={1--18},
  year={2022},
  publisher={Springer}
} 31 | -------------------------------------------------------------------------------- /backbone.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn 4 | 5 | 6 | class Bottleneck(nn.Module): 7 | expansion = 4 8 | 9 | def __init__(self, inplanes, planes, stride=1, downsample=None): 10 | super(Bottleneck, self).__init__() 11 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 12 | self.bn1 = nn.BatchNorm2d(planes) 13 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 14 | padding=1, bias=False) 15 | self.bn2 = nn.BatchNorm2d(planes) 16 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 17 | self.bn3 = nn.BatchNorm2d(planes * 4) 18 | self.relu = nn.ReLU(inplace=True) 19 | self.downsample = downsample 20 | self.stride = stride 21 | 22 | def forward(self, x): 23 | residual = x 24 | 25 | out = self.conv1(x) 26 | out = self.bn1(out) 27 | out = self.relu(out) 28 | 29 | out = self.conv2(out) 30 | out = self.bn2(out) 31 | out = self.relu(out) 32 | 33 | out = self.conv3(out) 34 | out = self.bn3(out) 35 | 36 | if self.downsample is not None: 37 | residual = self.downsample(x) 38 | 39 | out += residual 40 | out = self.relu(out) 41 | 42 | return out 43 | 44 | 45 | class ResNet(nn.Module): 46 | def __init__(self, last_stride=2, block=Bottleneck, layers=[3, 4, 6, 3]): 47 | self.inplanes = 64 48 | super().__init__() 49 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 50 | bias=False) 51 | self.bn1 = nn.BatchNorm2d(64) 52 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 53 | self.layer1 = self._make_layer(block, 64, layers[0]) 54 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 55 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 56 | self.layer4 = self._make_layer( 57 | block, 512, layers[3], stride=last_stride) 58 | self.relu = nn.ReLU(inplace=True) 59 | 60 | def _make_layer(self, block, planes, blocks, stride=1): 61 | downsample = None 62 | if stride != 1 or self.inplanes != planes * block.expansion: 63 | downsample = nn.Sequential( 64 | nn.Conv2d(self.inplanes, planes * block.expansion, 65 | kernel_size=1, stride=stride, bias=False), 66 | nn.BatchNorm2d(planes * block.expansion), 67 | ) 68 | 69 | layers = [] 70 | layers.append(block(self.inplanes, planes, stride, downsample)) 71 | self.inplanes = planes * block.expansion 72 | for i in range(1, blocks): 73 | layers.append(block(self.inplanes, planes)) 74 | 75 | return nn.Sequential(*layers) 76 | 77 | def forward(self, x): 78 | x = self.conv1(x) 79 | x = self.bn1(x) 80 | x = self.relu(x) 81 | x = self.maxpool(x) 82 | 83 | x = self.layer1(x) 84 | x = self.layer2(x) 85 | x = self.layer3(x) 86 | x = self.layer4(x) 87 | 88 | return x 89 | 90 | def load_param(self, model_path): 91 | param_dict = torch.load(model_path) 92 | for i in param_dict: 93 | if 'fc' in i: 94 | continue 95 | self.state_dict()[i].copy_(param_dict[i]) 96 | 97 | def random_init(self): 98 | for m in self.modules(): 99 | if isinstance(m, nn.Conv2d): 100 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 101 | m.weight.data.normal_(0, math.sqrt(2. / n)) 102 | elif isinstance(m, nn.BatchNorm2d): 103 | m.weight.data.fill_(1) 104 | m.bias.data.zero_() 105 | -------------------------------------------------------------------------------- /clean.sh: -------------------------------------------------------------------------------- 1 | #! 
/bin/bash 2 | rm -rf __pycache__ 3 | rm \.*\.swp 4 | rm -R ../logdir/* 5 | rm -R ../showdir/* 6 | rm ../models/ckp_step* 7 | -------------------------------------------------------------------------------- /clean_step.sh: -------------------------------------------------------------------------------- 1 | rm ../models/step* 2 | -------------------------------------------------------------------------------- /criterions.py: -------------------------------------------------------------------------------- 1 | """ 2 | Angular Triplet Loss 3 | YE, Hanrong et al, Bi-directional Exponential Angular Triplet Loss for RGB-Infrared Person Re-Identification 4 | """ 5 | 6 | import torch.nn.functional as F 7 | import torch 8 | from torch import nn 9 | import settings 10 | 11 | class expATLoss(): 12 | def __init__(self): 13 | self.marginloss = torch.nn.MarginRankingLoss(margin = settings.at_margin) 14 | 15 | def forward(self, anc, pos, neg): 16 | cos_pos = F.cosine_similarity(anc, pos) 17 | cos_neg = F.relu(F.cosine_similarity(anc, neg)) 18 | y_true = torch.ones_like(cos_pos) # target +1: rank cos_pos above cos_neg 19 | return torch.exp(self.marginloss(cos_pos, cos_neg.float(), y_true)) # exp(max(0, cos_neg - cos_pos + margin)) 20 | 21 | 22 | class CrossEntropyLabelSmoothLoss(nn.Module): 23 | """Cross entropy loss with label smoothing regularizer. 24 | Reference: 25 | Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016. 26 | Equation: y = (1 - epsilon) * y + epsilon / K. 27 | Args: 28 | num_classes (int): number of classes. 29 | epsilon (float): weight. 30 | """ 31 | def __init__(self, num_classes, epsilon=0.1, use_gpu=True): 32 | super(CrossEntropyLabelSmoothLoss, self).__init__() 33 | self.num_classes = num_classes 34 | self.epsilon = epsilon 35 | self.use_gpu = use_gpu 36 | self.logsoftmax = nn.LogSoftmax(dim=1) 37 | 38 | def forward(self, inputs, targets): 39 | """ 40 | Args: 41 | inputs: prediction matrix (before softmax) with shape (batch_size, num_classes) 42 | targets: ground truth labels with shape (batch_size) 43 | """ 44 | log_probs = self.logsoftmax(inputs) 45 | targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).data.cpu(), 1) 46 | if self.use_gpu: targets = targets.cuda() 47 | targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes 48 | loss = (- targets * log_probs).mean(0).sum() 49 | return loss 50 | -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | """ 2 | Angular Triplet Loss 3 | YE, Hanrong et al, Bi-directional Exponential Angular Triplet Loss for RGB-Infrared Person Re-Identification 4 | """ 5 | 6 | import glob 7 | import random 8 | import os 9 | import re 10 | import sys 11 | import urllib 12 | import tarfile 13 | import zipfile 14 | import os.path as osp 15 | from scipy.io import loadmat 16 | import numpy as np 17 | import h5py 18 | 19 | 20 | import time 21 | import settings 22 | import torch 23 | 24 | from torch.utils.data import Dataset 25 | from PIL import Image 26 | import torchvision.transforms as transforms 27 | 28 | class SYSU_triplet_dataset(Dataset): 29 | 30 | def __init__(self, data_folder = 'SYSU-MM01', transforms_list=None, mode='train', search_mode='all'): 31 | 32 | if mode == 'train': 33 | self.id_file = 'train_id.txt' 34 | elif mode == 'val': 35 | self.id_file = 'val_id.txt' 36 | else: 37 | self.id_file = 'test_id.txt' 38 | 39 | if
search_mode == 'all': 40 | self.rgb_cameras = ['cam1','cam2','cam4','cam5'] 41 | self.ir_cameras = ['cam3','cam6'] 42 | elif search_mode == 'indoor': 43 | self.rgb_cameras = ['cam1','cam2'] 44 | self.ir_cameras = ['cam3','cam6'] 45 | 46 | file_path = os.path.join(data_folder,'exp',self.id_file) 47 | 48 | with open(file_path, 'r') as file: 49 | self.ids = file.read().splitlines() 50 | 51 | #print(self.ids) 52 | self.ids = [int(y) for y in self.ids[0].split(',')] 53 | self.ids.sort() 54 | 55 | self.id_dict = {} 56 | 57 | for index, id in enumerate(self.ids): 58 | #print(index,id) 59 | self.id_dict[id] = index 60 | 61 | self.ids = ["%04d" % x for x in self.ids] 62 | 63 | self.transform = transforms_list 64 | 65 | self.files_rgb = {} 66 | self.files_ir = {} 67 | 68 | for id in sorted(self.ids): 69 | 70 | self.files_rgb[id] = [] 71 | self.files_ir[id] = [] 72 | 73 | for cam in self.rgb_cameras: 74 | img_dir = os.path.join(data_folder,cam,id) 75 | if os.path.isdir(img_dir): 76 | self.files_rgb[id].extend(sorted([img_dir+'/'+i for i in os.listdir(img_dir)])) 77 | for cam in self.ir_cameras: 78 | img_dir = os.path.join(data_folder,cam,id) 79 | if os.path.isdir(img_dir): 80 | self.files_ir[id].extend(sorted([img_dir+'/'+i for i in os.listdir(img_dir)])) 81 | 82 | self.all_files = [] 83 | 84 | for id in sorted(self.ids): 85 | self.all_files.extend(self.files_rgb[id]) 86 | 87 | def __getitem__(self, index): 88 | 89 | anchor_file = self.all_files[index] 90 | anchor_id = anchor_file.split('/')[-2] 91 | 92 | anchor_rgb = np.random.choice(self.files_rgb[anchor_id]) 93 | positive_rgb = np.random.choice([x for x in self.files_rgb[anchor_id] if x != anchor_rgb]) 94 | negative_id = np.random.choice([id for id in self.ids if id != anchor_id]) 95 | negative_rgb = np.random.choice(self.files_rgb[negative_id]) 96 | 97 | anchor_ir = np.random.choice(self.files_ir[anchor_id]) 98 | positive_ir = np.random.choice([x for x in self.files_ir[anchor_id] if x != anchor_ir]) 99 | negative_id = np.random.choice([id for id in self.ids if id != anchor_id]) 100 | negative_ir = np.random.choice(self.files_ir[negative_id]) 101 | 102 | anchor_label = np.array(self.id_dict[int(anchor_id)]) 103 | 104 | #print(anchor_file, positive_file, negative_file, anchor_id) 105 | 106 | anchor_rgb = Image.open(anchor_rgb) 107 | positive_rgb = Image.open(positive_rgb) 108 | negative_rgb = Image.open(negative_rgb) 109 | 110 | anchor_ir = Image.open(anchor_ir) 111 | positive_ir = Image.open(positive_ir) 112 | negative_ir = Image.open(negative_ir) 113 | 114 | if self.transform is not None: 115 | anchor_rgb = self.transform(anchor_rgb) 116 | positive_rgb = self.transform(positive_rgb) 117 | negative_rgb = self.transform(negative_rgb) 118 | 119 | anchor_ir = self.transform(anchor_ir) 120 | positive_ir = self.transform(positive_ir) 121 | negative_ir = self.transform(negative_ir) 122 | 123 | modality_rgb = torch.tensor([1,0]).float() 124 | modality_ir = torch.tensor([0,1]).float() 125 | 126 | return anchor_rgb, positive_rgb, negative_rgb, anchor_ir, positive_ir, negative_ir, anchor_label, modality_rgb, modality_ir 127 | 128 | def __len__(self): 129 | return len(self.all_files) 130 | 131 | 132 | 133 | class SYSU_eval_datasets(object): 134 | def __init__(self, data_folder = 'SYSU-MM01', search_mode='all', search_setting='single' , data_split='val', use_random=False, **kwargs): 135 | 136 | self.data_folder = data_folder 137 | self.train_id_file = 'train_id.txt' 138 | self.val_id_file = 'val_id.txt' 139 | self.test_id_file = 'test_id.txt' 140 | 141 | if 
search_mode == 'all': 142 | self.rgb_cameras = ['cam1','cam2','cam4','cam5'] 143 | self.ir_cameras = ['cam3','cam6'] 144 | elif search_mode == 'indoor': 145 | self.rgb_cameras = ['cam1','cam2'] 146 | self.ir_cameras = ['cam3','cam6'] 147 | 148 | if data_split == 'train': 149 | self.id_file = self.train_id_file 150 | elif data_split == 'val': 151 | self.id_file = self.val_id_file 152 | elif data_split == 'test': 153 | self.id_file = self.test_id_file 154 | 155 | self.search_setting = search_setting 156 | self.search_mode = search_mode 157 | self.use_random = use_random 158 | 159 | 160 | query, num_query_pids, num_query_imgs = self._process_query_images(id_file = self.id_file, relabel=False) 161 | gallery, num_gallery_pids, num_gallery_imgs = self._process_gallery_images(id_file = self.id_file, relabel=False) 162 | 163 | num_total_pids = num_query_pids 164 | num_total_imgs = num_query_imgs + num_gallery_imgs 165 | 166 | print("Dataset statistics:") 167 | print(" ------------------------------") 168 | print(" subset | # ids | # images") 169 | print(" ------------------------------") 170 | print(" query | {:5d} | {:8d}".format(num_query_pids, num_query_imgs)) 171 | print(" gallery | {:5d} | {:8d}".format(num_gallery_pids, num_gallery_imgs)) 172 | print(" ------------------------------") 173 | print(" total | {:5d} | {:8d}".format(num_total_pids, num_total_imgs)) 174 | print(" ------------------------------") 175 | 176 | self.query = query 177 | self.gallery = gallery 178 | 179 | self.num_query_pids = num_query_pids 180 | self.num_gallery_pids = num_gallery_pids 181 | 182 | def _process_query_images(self, id_file, relabel=False): 183 | 184 | file_path = os.path.join(self.data_folder,'exp',id_file) 185 | 186 | files_ir = [] 187 | 188 | with open(file_path, 'r') as file: 189 | ids = file.read().splitlines() 190 | ids = [int(y) for y in ids[0].split(',')] 191 | ids = ["%04d" % x for x in ids] 192 | 193 | for id in sorted(ids): 194 | for cam in self.ir_cameras: 195 | img_dir = os.path.join(self.data_folder,cam,id) 196 | if os.path.isdir(img_dir): 197 | new_files = sorted([img_dir+'/'+i for i in os.listdir(img_dir)]) 198 | files_ir.extend(new_files) #files_ir.append(random.choice(new_files)) 199 | pid_container = set() 200 | 201 | for img_path in files_ir: 202 | camid, pid = int(img_path.split('/')[-3].split('cam')[1]), int(img_path.split('/')[-2]) 203 | if pid == -1: continue # junk images are just ignored 204 | pid_container.add(pid) 205 | pid2label = {pid:label for label, pid in enumerate(pid_container)} 206 | 207 | dataset = [] 208 | for img_path in files_ir: 209 | camid, pid = int(img_path.split('/')[-3].split('cam')[1]), int(img_path.split('/')[-2]) 210 | if pid == -1: continue # junk images are just ignored 211 | if relabel: pid = pid2label[pid] 212 | dataset.append((img_path, pid, camid)) 213 | 214 | num_pids = len(pid_container) 215 | num_imgs = len(dataset) 216 | return dataset, num_pids, num_imgs 217 | 218 | def _process_gallery_images(self, id_file, relabel=False): 219 | if self.use_random: 220 | random.seed(time.time()) 221 | else: 222 | random.seed(1) 223 | 224 | file_path = os.path.join(self.data_folder,'exp',id_file) 225 | files_rgb = [] 226 | 227 | with open(file_path, 'r') as file: 228 | ids = file.read().splitlines() 229 | ids = [int(y) for y in ids[0].split(',')] 230 | ids = ["%04d" % x for x in ids] 231 | 232 | for id in sorted(ids): 233 | for cam in self.rgb_cameras: 234 | img_dir = os.path.join(self.data_folder,cam,id) 235 | if os.path.isdir(img_dir): 236 | new_files = 
sorted([img_dir+'/'+i for i in os.listdir(img_dir)]) 237 | if self.search_setting == 'single': 238 | files_rgb.append(random.choice(new_files)) 239 | elif self.search_setting == 'multi': 240 | files_rgb.extend(random.sample(new_files, k=10)) # multi-shot, 10 for each ca 241 | 242 | pid_container = set() 243 | for img_path in files_rgb: 244 | camid, pid = int(img_path.split('/')[-3].split('cam')[1]), int(img_path.split('/')[-2]) 245 | if pid == -1: continue # junk images are just ignored 246 | pid_container.add(pid) 247 | pid2label = {pid:label for label, pid in enumerate(pid_container)} 248 | 249 | dataset = [] 250 | for img_path in files_rgb: 251 | camid, pid = int(img_path.split('/')[-3].split('cam')[1]), int(img_path.split('/')[-2]) 252 | if pid == -1: continue # junk images are just ignored 253 | if relabel: pid = pid2label[pid] 254 | dataset.append((img_path, pid, camid)) 255 | 256 | num_pids = len(pid_container) 257 | num_imgs = len(dataset) 258 | return dataset, num_pids, num_imgs 259 | 260 | 261 | 262 | 263 | 264 | class Image_dataset(Dataset): 265 | """Image Person ReID Dataset""" 266 | def __init__(self, dataset, transform=None): 267 | self.dataset = dataset 268 | self.transform = transform 269 | 270 | def __len__(self): 271 | return len(self.dataset) 272 | 273 | def __getitem__(self, index): 274 | img_path, pid, camid = self.dataset[index] 275 | img = Image.open(img_path) 276 | if self.transform is not None: 277 | img = self.transform(img) 278 | return img, pid, camid 279 | 280 | class RegDB_triplet_dataset(Dataset): 281 | 282 | def __init__(self, data_dir, transforms_list=None, mode='train', trial=1): 283 | 284 | if mode == 'train': 285 | self.visible_files = 'train_visible_' + str(trial) + '.txt' 286 | self.thermal_files = 'train_thermal_' + str(trial) + '.txt' 287 | elif mode == 'val': 288 | self.visible_files = 'test_visible_' + str(trial) + '.txt' 289 | self.thermal_files = 'test_thermal_' + str(trial) + '.txt' 290 | else: 291 | self.visible_files = 'test_visible_' + str(trial) + '.txt' 292 | self.thermal_files = 'test_thermal_' + str(trial) + '.txt' 293 | 294 | 295 | color_list = os.path.join(data_dir, 'idx', self.visible_files) 296 | thermal_list = os.path.join(data_dir, 'idx', self.thermal_files) 297 | 298 | color_img_file, color_label = self.load_data(color_list) 299 | thermal_img_file, thermal_label = self.load_data(thermal_list) 300 | 301 | color_image = [] 302 | color_image_path = [] 303 | for i in range(len(color_img_file)): 304 | img_path = os.path.join(data_dir, color_img_file[i]) 305 | color_image_path.append(img_path) 306 | img = Image.open(img_path) 307 | img = img.resize(settings.inp_size[::-1]) #img.resize((144, 288), Image.ANTIALIAS) # (width, height) 308 | color_image.append(img) 309 | thermal_image = [] 310 | thermal_image_path = [] 311 | for i in range(len(thermal_img_file)): 312 | img_path = os.path.join(data_dir, thermal_img_file[i]) 313 | thermal_image_path.append(img_path) 314 | img = Image.open(img_path) 315 | img = img.resize(settings.inp_size[::-1], Image.ANTIALIAS) 316 | thermal_image.append(img) 317 | 318 | # make dict 319 | color_img_dict = {} 320 | for i in range(len(color_label)): 321 | label = color_label[i] 322 | if label not in color_img_dict.keys(): 323 | color_img_dict[label] = [] 324 | 325 | color_img_dict[label].append(i) 326 | 327 | thermal_img_dict = {} 328 | for i in range(len(thermal_label)): 329 | label = thermal_label[i] 330 | if label not in thermal_img_dict.keys(): 331 | thermal_img_dict[label] = [] 332 | 333 | 
thermal_img_dict[label].append(i) 334 | 335 | self.color_image = color_image 336 | self.color_label = color_label 337 | self.thermal_image = thermal_image 338 | self.thermal_label = thermal_label 339 | self.color_img_dict = color_img_dict 340 | self.thermal_img_dict = thermal_img_dict 341 | self.ids = list(self.color_img_dict.keys()) 342 | self.transform = transforms_list 343 | 344 | def load_data(self, input_data_path): 345 | with open(input_data_path) as f: 346 | data_file_list = open(input_data_path, 'rt').read().splitlines() 347 | # Get full list of image and labels 348 | file_image = [s.split(' ')[0] for s in data_file_list] 349 | file_label = [int(s.split(' ')[1]) for s in data_file_list] 350 | 351 | return file_image, file_label 352 | 353 | def __getitem__(self, index): 354 | 355 | anchor_file = self.color_image[index] 356 | anchor_id = self.color_label[index] 357 | 358 | anchor_rgb = anchor_file 359 | positive_rgb = self.color_image[np.random.choice([x for x in self.color_img_dict[anchor_id] if x != anchor_rgb])] 360 | negative_id = np.random.choice([id for id in self.ids if id != anchor_id]) 361 | negative_rgb = self.color_image[np.random.choice(self.color_img_dict[negative_id])] 362 | 363 | anchor_ir = self.thermal_image[np.random.choice(self.thermal_img_dict[anchor_id])] 364 | positive_ir = self.thermal_image[np.random.choice([x for x in self.thermal_img_dict[anchor_id] if x != anchor_ir])] 365 | negative_id = np.random.choice([id for id in self.ids if id != anchor_id]) 366 | negative_ir = self.thermal_image[np.random.choice(self.thermal_img_dict[negative_id])] 367 | 368 | anchor_label = np.array(anchor_id) 369 | 370 | if self.transform is not None: 371 | anchor_rgb = self.transform(anchor_rgb) 372 | positive_rgb = self.transform(positive_rgb) 373 | negative_rgb = self.transform(negative_rgb) 374 | 375 | anchor_ir = self.transform(anchor_ir) 376 | positive_ir = self.transform(positive_ir) 377 | negative_ir = self.transform(negative_ir) 378 | 379 | modality_rgb = torch.tensor([1,0]).float() 380 | modality_ir = torch.tensor([0,1]).float() 381 | 382 | return anchor_rgb, positive_rgb, negative_rgb, anchor_ir, positive_ir, negative_ir, anchor_label, modality_rgb, modality_ir 383 | 384 | def __len__(self): 385 | return len(self.color_label) 386 | 387 | 388 | class RegDB_eval_datasets(object): 389 | def __init__(self, data_dir, transforms_list=None, mode='train', trial=1): 390 | 391 | if mode == 'train': 392 | self.visible_files = 'train_visible_' + str(trial) + '.txt' 393 | self.thermal_files = 'train_thermal_' + str(trial) + '.txt' 394 | elif mode == 'val': 395 | self.visible_files = 'test_visible_' + str(trial) + '.txt' 396 | self.thermal_files = 'test_thermal_' + str(trial) + '.txt' 397 | else: 398 | self.visible_files = 'test_visible_' + str(trial) + '.txt' 399 | self.thermal_files = 'test_thermal_' + str(trial) + '.txt' 400 | 401 | 402 | color_list = os.path.join(data_dir, 'idx', self.visible_files) 403 | thermal_list = os.path.join(data_dir, 'idx', self.thermal_files) 404 | 405 | color_img_file, color_label = self.load_data(color_list) 406 | thermal_img_file, thermal_label = self.load_data(thermal_list) 407 | 408 | color_image = [] 409 | color_image_path = [] 410 | for i in range(len(color_img_file)): 411 | img_path = os.path.join(data_dir, color_img_file[i]) 412 | color_image_path.append(img_path) 413 | img = Image.open(img_path) 414 | img = img.resize(settings.inp_size[::-1]) 415 | color_image.append((img, color_label[i], img_path)) 416 | 417 | 418 | thermal_image = [] 419 | 
thermal_image_path = [] 420 | for i in range(len(thermal_img_file)): 421 | img_path = os.path.join(data_dir, thermal_img_file[i]) 422 | thermal_image_path.append(img_path) 423 | img = Image.open(img_path) 424 | img = img.resize(settings.inp_size[::-1], Image.ANTIALIAS) 425 | thermal_image.append((img, thermal_label[i], img_path)) 426 | 427 | # make dict 428 | color_img_dict = {} 429 | for i in range(len(color_label)): 430 | label = color_label[i] 431 | if label not in color_img_dict.keys(): 432 | color_img_dict[label] = [] 433 | 434 | thermal_img_dict = {} 435 | for i in range(len(thermal_label)): 436 | label = thermal_label[i] 437 | if label not in thermal_img_dict.keys(): 438 | thermal_img_dict[label] = [] 439 | 440 | color_ids = list(color_img_dict.keys()) 441 | thermal_ids = list(thermal_img_dict.keys()) 442 | 443 | query = thermal_image 444 | num_query_imgs = len(query) 445 | num_query_pids = len(thermal_ids) 446 | 447 | gallery = color_image 448 | num_gallery_pids = len(color_ids) 449 | num_gallery_imgs = len(gallery) 450 | 451 | num_total_pids = num_query_pids 452 | num_total_imgs = num_query_imgs + num_gallery_imgs 453 | 454 | print("Dataset statistics:") 455 | print(" ------------------------------") 456 | print(" subset | # ids | # images") 457 | print(" ------------------------------") 458 | print(" query | {:5d} | {:8d}".format(num_query_pids, num_query_imgs)) 459 | print(" gallery | {:5d} | {:8d}".format(num_gallery_pids, num_gallery_imgs)) 460 | print(" ------------------------------") 461 | print(" total | {:5d} | {:8d}".format(num_total_pids, num_total_imgs)) 462 | print(" ------------------------------") 463 | 464 | self.query = query 465 | self.gallery = gallery 466 | 467 | self.num_query_pids = num_query_pids 468 | self.num_gallery_pids = num_gallery_pids 469 | 470 | def load_data(self, input_data_path): 471 | with open(input_data_path) as f: 472 | data_file_list = open(input_data_path, 'rt').read().splitlines() 473 | # Get full list of image and labels 474 | file_image = [s.split(' ')[0] for s in data_file_list] 475 | file_label = [int(s.split(' ')[1]) for s in data_file_list] 476 | 477 | return file_image, file_label 478 | 479 | class RegDB_wrapper(Dataset): 480 | """For evaluation""" 481 | def __init__(self, dataset, transform=None): 482 | self.dataset = dataset 483 | self.transform = transform 484 | 485 | def __len__(self): 486 | return len(self.dataset) 487 | 488 | def __getitem__(self, index): 489 | img, pid, img_path = self.dataset[index] 490 | 491 | if self.transform is not None: 492 | img = self.transform(img) 493 | return img, pid, img_path 494 | 495 | if __name__ == '__main__': 496 | dataset = RegDB_triplet_dataset(settings.regdb_dir, settings.transforms_list, trial=2) 497 | print(len(dataset)) 498 | data = RegDB_eval_datasets(settings.regdb_dir, settings.test_transforms_list, trial=10) 499 | gallery_set = RegDB_wrapper(data.gallery) 500 | query_set = RegDB_wrapper(data.query) 501 | print(len(gallery_set)) 502 | 503 | 504 | 505 | -------------------------------------------------------------------------------- /desktop.ini: -------------------------------------------------------------------------------- 1 | [LocalizedFileNames] 2 | main.py=@main,0 3 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | """ 2 | Angular Triplet Loss 3 | YE, Hanrong et al, Bi-directional Exponential Angular Triplet Loss for RGB-Infrared Person 
Re-Identification 4 | """ 5 | from __future__ import print_function, absolute_import 6 | import numpy as np 7 | import copy 8 | from collections import defaultdict 9 | import sys 10 | import torch 11 | import matplotlib.pyplot as plt 12 | import pickle 13 | 14 | 15 | from IPython import embed 16 | 17 | 18 | 19 | def test(feature_generators, queryloader, galleryloader, use_gpu = True, ranks=[1, 5, 10, 20]): 20 | if type(feature_generators) is list: 21 | feature_generator_rgb = feature_generators[0] 22 | feature_generator_ir = feature_generators[1] 23 | 24 | else: 25 | feature_generator_rgb = feature_generators 26 | feature_generator_ir = feature_generators 27 | feature_generator_rgb.eval() 28 | feature_generator_ir.eval() 29 | 30 | with torch.no_grad(): 31 | qf, q_pids, q_camids = [], [], [] 32 | for batch_idx, (imgs, pids, camids) in enumerate(queryloader): 33 | if use_gpu: imgs = imgs.cuda() 34 | features = feature_generator_ir(imgs) # query features # use fi 35 | features = features.data#.cpu() 36 | 37 | qf.append(features) 38 | q_pids.extend(pids) 39 | q_camids.extend(camids) 40 | 41 | qf = torch.cat(qf, 0) 42 | q_pids = np.asarray(q_pids) 43 | q_camids = np.asarray(q_camids) 44 | 45 | print("Extracted features for query set, obtained {}-by-{} matrix".format(qf.size(0), qf.size(1))) 46 | 47 | gf, g_pids, g_camids = [], [], [] 48 | #end = time.time() 49 | for batch_idx, (imgs, pids, camids) in enumerate(galleryloader): 50 | if use_gpu: imgs = imgs.cuda() 51 | 52 | features = feature_generator_rgb(imgs) 53 | features = features.data#.cpu() 54 | 55 | gf.append(features) 56 | g_pids.extend(pids) 57 | g_camids.extend(camids) 58 | gf = torch.cat(gf, 0) 59 | g_pids = np.asarray(g_pids) 60 | g_camids = np.asarray(g_camids) 61 | 62 | print("Extracted features for gallery set, obtained {}-by-{} matrix".format(gf.size(0), gf.size(1))) 63 | 64 | 65 | qf = qf.view(qf.size(0),-1) 66 | gf = gf.view(gf.size(0),-1) 67 | 68 | # see norm 69 | q_norms = qf.norm(dim=1) 70 | print('q_norms:') 71 | print(q_norms) 72 | 73 | g_norms = gf.norm(dim=1) 74 | print('g_norms:') 75 | print(g_norms) 76 | m, n = qf.size(0), gf.size(0) 77 | distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \ 78 | torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t() 79 | distmat.addmm_(1, -2, qf, gf.t()) 80 | distmat = distmat.cpu().numpy() 81 | 82 | print("Computing CMC and mAP") 83 | cmc, mAP = evaluate(distmat, q_pids, g_pids, q_camids, g_camids) # use_metric_cuhk03=args.use_metric_cuhk03) 84 | 85 | print("Results ----------") 86 | print("mAP: {:.1%}".format(mAP)) 87 | print("CMC curve") 88 | for r in ranks: 89 | print("Rank-{:<3}: {:.1%}".format(r, cmc[r-1])) 90 | print("------------------") 91 | 92 | return distmat,cmc, mAP 93 | 94 | def evaluate(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=50): 95 | """Evaluation with SYSU metric 96 | Key: for each query identity in camera 3, its gallery images from camera 2 view are discarded. 97 | """ 98 | 99 | num_q, num_g = distmat.shape 100 | if num_g < max_rank: 101 | max_rank = num_g 102 | print("Note: number of gallery samples is quite small, got {}".format(num_g)) 103 | indices = np.argsort(distmat, axis=1) 104 | 105 | matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32) 106 | 107 | # compute cmc curve for each query 108 | all_cmc = [] 109 | all_AP = [] 110 | 111 | num_valid_q = 0. 
# number of valid query 112 | for q_idx in range(num_q): 113 | # get query pid and camid 114 | q_pid = q_pids[q_idx] 115 | q_camid = q_camids[q_idx] 116 | # remove gallery samples that have the same pid and camid with query 117 | order = indices[q_idx] 118 | remove = (q_camid == 3) & (g_camids[order] == 2) 119 | keep = np.invert(remove) 120 | 121 | 122 | if(not q_idx): 123 | print('Query ID',q_pid) 124 | for g_idx in range(20): 125 | print('Gallery ID Rank #', g_idx ,' : ', g_pids[order[g_idx]], 'distance : ', distmat[q_idx][order[g_idx]]) 126 | 127 | # compute cmc curve 128 | orig_cmc = matches[q_idx][keep] # binary vector, positions with value 1 are correct matches 129 | if not np.any(orig_cmc): 130 | # this condition is true when query identity does not appear in gallery 131 | continue 132 | 133 | cmc = orig_cmc.cumsum() 134 | cmc[cmc > 1] = 1 135 | 136 | all_cmc.append(cmc[:max_rank]) 137 | num_valid_q += 1. 138 | 139 | # compute average precision 140 | # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision 141 | num_rel = orig_cmc.sum() 142 | tmp_cmc = orig_cmc.cumsum() 143 | tmp_cmc = [x / (i+1.) for i, x in enumerate(tmp_cmc)] 144 | 145 | tmp_cmc = np.asarray(tmp_cmc) * orig_cmc 146 | AP = tmp_cmc.sum() / num_rel 147 | all_AP.append(AP) 148 | 149 | assert num_valid_q > 0, "Error: all query identities do not appear in gallery" 150 | 151 | all_cmc = np.asarray(all_cmc).astype(np.float32) 152 | all_cmc = all_cmc.sum(0) / num_valid_q 153 | mAP = np.mean(all_AP) 154 | return all_cmc, mAP 155 | 156 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | """ 2 | Angular Triplet Loss 3 | YE, Hanrong et al, Bi-directional Exponential Angular Triplet Loss for RGB-Infrared Person Re-Identification 4 | """ 5 | 6 | import settings 7 | import os 8 | os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" 9 | os.environ["CUDA_VISIBLE_DEVICES"]=settings.device_id 10 | import sys 11 | import argparse 12 | import csv 13 | import numpy as np 14 | import time 15 | import torch 16 | from torch import nn 17 | import torch.nn.functional as F 18 | from torch.optim import Adam 19 | from torch.optim.lr_scheduler import MultiStepLR 20 | from torch.utils.data import DataLoader 21 | from torch.utils.tensorboard import SummaryWriter 22 | from criterions import expATLoss, CrossEntropyLabelSmoothLoss 23 | import torchvision.transforms as transforms 24 | 25 | logger = settings.logger 26 | torch.cuda.manual_seed_all(66) 27 | torch.manual_seed(66) 28 | 29 | from datasets import RegDB_triplet_dataset, RegDB_eval_datasets, Image_dataset,RegDB_wrapper 30 | import itertools 31 | import solver 32 | from models import IdClassifier, FeatureEmbedder, Base_rgb,Base_ir 33 | from eval import test, evaluate 34 | 35 | 36 | 37 | from IPython import embed 38 | 39 | best_rank1 = 0 40 | 41 | 42 | 43 | def ensure_dir(dir_path): 44 | if not os.path.isdir(dir_path): 45 | os.makedirs(dir_path) 46 | 47 | class Session: 48 | def __init__(self): 49 | self.log_dir = settings.log_dir 50 | self.model_dir = settings.model_dir 51 | ensure_dir(settings.log_dir) 52 | ensure_dir(settings.model_dir) 53 | logger.info('set log dir as %s' % settings.log_dir) 54 | logger.info('set model dir as %s' % settings.model_dir) 55 | 56 | ##################################### Import models ########################### 57 | self.feature_rgb_generator = 
Base_rgb(last_stride=1,model_path=settings.pretrained_model_path) 58 | self.feature_ir_generator = Base_ir(last_stride=1,model_path=settings.pretrained_model_path) 59 | self.feature_embedder = FeatureEmbedder(last_stride=1,model_path=settings.pretrained_model_path) 60 | self.id_classifier = IdClassifier() 61 | 62 | if torch.cuda.is_available(): 63 | self.feature_rgb_generator.cuda() 64 | self.feature_ir_generator.cuda() 65 | self.feature_embedder.cuda() 66 | self.id_classifier.cuda() 67 | 68 | self.feature_rgb_generator = nn.DataParallel(self.feature_rgb_generator, device_ids=range(settings.num_gpu)) 69 | 70 | self.feature_ir_generator = nn.DataParallel(self.feature_ir_generator, device_ids=range(settings.num_gpu)) 71 | self.feature_embedder = nn.DataParallel(self.feature_embedder, device_ids=range(settings.num_gpu)) 72 | self.id_classifier = nn.DataParallel(self.id_classifier, device_ids=range(settings.num_gpu)) 73 | 74 | ############################# Get Losses & Optimizers ######################### 75 | self.criterion_at = expATLoss() 76 | self.loss1 = torch.nn.MSELoss() 77 | self.criterion_identity = CrossEntropyLabelSmoothLoss(settings.num_classes, epsilon=0.1) #torch.nn.CrossEntropyLoss() 78 | 79 | opt_models = [self.feature_rgb_generator, 80 | self.feature_ir_generator, 81 | self.feature_embedder, 82 | self.id_classifier] 83 | 84 | def make_optimizer(opt_models): 85 | train_params = [] 86 | 87 | for opt_model in opt_models: 88 | for key, value in opt_model.named_parameters(): 89 | if not value.requires_grad: 90 | continue 91 | lr = settings.BASE_LR 92 | weight_decay = settings.WEIGHT_DECAY 93 | if "bias" in key: 94 | lr = settings.BASE_LR * settings.BIAS_LR_FACTOR 95 | weight_decay = settings.WEIGHT_DECAY_BIAS 96 | train_params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}] 97 | 98 | 99 | optimizer = torch.optim.Adam(train_params) 100 | return optimizer 101 | 102 | self.optimizer_G = make_optimizer(opt_models) 103 | 104 | self.epoch_count = 0 105 | self.step = 0 106 | self.save_steps = settings.save_steps 107 | self.num_workers = settings.num_workers 108 | self.writers = {} 109 | self.dataloaders = {} 110 | 111 | self.sche_G = solver.WarmupMultiStepLR(self.optimizer_G, milestones=settings.iter_sche, gamma=0.1) # default setting s 112 | 113 | def tensorboard(self, name): 114 | self.writers[name] = SummaryWriter(os.path.join(self.log_dir, name + '.events')) 115 | return self.writers[name] 116 | 117 | 118 | def write(self, name, out): 119 | for k, v in out.items(): 120 | self.writers[name].add_scalar(name + '/' + k, v, self.step) 121 | 122 | 123 | out['G_lr'] = self.optimizer_G.param_groups[0]['lr'] 124 | out['step'] = self.step 125 | out['eooch_count'] = self.epoch_count 126 | outputs = [ 127 | "{}:{:.4g}".format(k, v) 128 | for k, v in out.items() 129 | ] 130 | logger.info(name + '--' + ' '.join(outputs)) 131 | 132 | def save_checkpoints(self, name): 133 | ckp_path = os.path.join(self.model_dir, name) 134 | obj = { 135 | 'feature_rgb_generator': self.feature_rgb_generator.state_dict(), 136 | 'feature_ir_generator': self.feature_ir_generator.state_dict(), 137 | 'feature_embedder': self.feature_embedder.state_dict(), 138 | 'id_classifier': self.id_classifier.state_dict(), 139 | 'clock': self.step, 140 | 'epoch_count': self.epoch_count, 141 | 'opt_G': self.optimizer_G.state_dict(), 142 | } 143 | torch.save(obj, ckp_path) 144 | 145 | def load_checkpoints(self, name): 146 | ckp_path = os.path.join(self.model_dir, name) 147 | try: 148 | obj = torch.load(ckp_path) 149 | 
print('load checkpoint: %s' %ckp_path) 150 | except FileNotFoundError: 151 | return 152 | self.feature_rgb_generator.load_state_dict(obj['feature_rgb_generator']) 153 | self.feature_ir_generator.load_state_dict(obj['feature_ir_generator']) 154 | self.feature_embedder.load_state_dict(obj['feature_embedder']) 155 | self.id_classifier.load_state_dict(obj['id_classifier']) 156 | self.optimizer_G.load_state_dict(obj['opt_G']) 157 | self.step = obj['clock'] 158 | self.epoch_count = obj['epoch_count'] 159 | self.sche_G.last_epoch = self.step 160 | 161 | 162 | def load_checkpoints_delf_init(self, name): 163 | ckp_path = os.path.join(self.model_dir, name) 164 | obj = torch.load(ckp_path) 165 | self.backbone.load_state_dict(obj['backbone']) 166 | 167 | def cal_fea(self, x, domain_mode): 168 | if domain_mode == 'rgb': 169 | feat = self.feature_rgb_generator(x) 170 | return feat,self.feature_embedder(feat) 171 | elif domain_mode == 'ir': 172 | feat = self.feature_ir_generator(x) 173 | return feat,self.feature_embedder(feat) 174 | 175 | 176 | def inf_batch(self, batch): 177 | alpha = settings.alpha 178 | beta = settings.beta 179 | 180 | anchor_rgb, positive_rgb, negative_rgb, anchor_ir, positive_ir, \ 181 | negative_ir, anchor_label, modality_rgb, modality_ir = batch 182 | 183 | if torch.cuda.is_available(): 184 | anchor_rgb = anchor_rgb.cuda() 185 | positive_rgb = positive_rgb.cuda() 186 | negative_rgb = negative_rgb.cuda() 187 | anchor_ir = anchor_ir.cuda() 188 | positive_ir = positive_ir.cuda() 189 | negative_ir = negative_ir.cuda() 190 | anchor_label = anchor_label.cuda() 191 | anchor_rgb_features1, anchor_rgb_features2 = self.cal_fea(anchor_rgb, 'rgb') 192 | positive_rgb_features1, positive_rgb_features2 = self.cal_fea(positive_rgb, 'rgb') 193 | negative_rgb_features1, negative_rgb_features2 = self.cal_fea(negative_rgb, 'rgb') 194 | 195 | anchor_ir_features1, anchor_ir_features2 = self.cal_fea(anchor_ir, 'ir') 196 | positive_ir_features1, positive_ir_features2 = self.cal_fea(positive_ir, 'ir') 197 | negative_ir_features1, negative_ir_features2 = self.cal_fea(negative_ir, 'ir') 198 | 199 | lossx = self.loss1(anchor_rgb_features1, positive_ir_features1) + self.loss1(anchor_ir_features1, 200 | positive_rgb_features1) 201 | at_loss_rgb = self.criterion_at.forward(anchor_rgb_features2, 202 | positive_ir_features2, negative_rgb_features2) 203 | 204 | at_loss_ir = self.criterion_at.forward(anchor_ir_features2, 205 | positive_rgb_features2, negative_ir_features2) 206 | 207 | at_loss = at_loss_rgb + at_loss_ir + lossx 208 | 209 | predicted_id_rgb = self.id_classifier(anchor_rgb_features2) 210 | predicted_id_ir = self.id_classifier(anchor_ir_features2) 211 | 212 | identity_loss = self.criterion_identity(predicted_id_rgb, anchor_label) + \ 213 | self.criterion_identity(predicted_id_ir, anchor_label) 214 | 215 | loss_G = alpha * at_loss + beta * identity_loss 216 | 217 | self.optimizer_G.zero_grad() 218 | loss_G.backward() 219 | self.optimizer_G.step() 220 | 221 | self.write('train_stats', {'loss_G': loss_G, 222 | 'at_loss': at_loss, 223 | 'identity_loss': identity_loss 224 | }) 225 | 226 | def run_train_val(ckp_name='ckp_latest'): 227 | sess = Session() 228 | sess.load_checkpoints(ckp_name) 229 | 230 | sess.tensorboard('train_stats') 231 | sess.tensorboard('val_stats') 232 | 233 | ######################## Get Datasets & Dataloaders ########################### 234 | 235 | train_dataset = RegDB_triplet_dataset(settings.data_folder, settings.transforms_list, trial=2) 236 | 237 | def get_train_dataloader(): 238 
| return iter(DataLoader(RegDB_triplet_dataset(data_dir=settings.data_folder, transforms_list=settings.transforms_list), batch_size=settings.train_batch_size, shuffle=True,num_workers=settings.num_workers, drop_last = True)) 239 | 240 | train_dataloader = get_train_dataloader() 241 | 242 | eval_val = RegDB_eval_datasets(settings.data_folder, settings.test_transforms_list, mode = 'val',trial=2) 243 | 244 | transform_test = settings.test_transforms_list 245 | 246 | val_queryloader = DataLoader( 247 | RegDB_wrapper(eval_val.query, transform=transform_test), 248 | batch_size=settings.val_batch_size, shuffle=False, num_workers=0, 249 | drop_last=False, 250 | ) 251 | 252 | val_galleryloader = DataLoader( 253 | RegDB_wrapper(eval_val.gallery, transform=transform_test), 254 | batch_size=settings.val_batch_size, shuffle=False, num_workers=0, 255 | drop_last=False, 256 | ) 257 | 258 | while sess.step < settings.iter_sche[-1]: 259 | sess.sche_G.step() 260 | sess.feature_rgb_generator.train() 261 | sess.feature_ir_generator.train() 262 | sess.feature_embedder.train() 263 | 264 | sess.id_classifier.train() 265 | 266 | try: 267 | batch_t = next(train_dataloader) 268 | except StopIteration: 269 | train_dataloader = get_train_dataloader() 270 | batch_t = next(train_dataloader) 271 | sess.epoch_count += 1 272 | 273 | sess.inf_batch(batch_t) 274 | 275 | 276 | 277 | if sess.step % settings.val_step ==0: 278 | sess.feature_rgb_generator.eval() 279 | sess.feature_ir_generator.eval() 280 | sess.feature_embedder.eval() 281 | sess.id_classifier.eval() 282 | _, test_ranks, test_mAP = test([nn.Sequential(sess.feature_rgb_generator, sess.feature_embedder), nn.Sequential(sess.feature_ir_generator, sess.feature_embedder)], val_queryloader, val_galleryloader) # test() returns (distmat, cmc, mAP) 283 | global best_rank1 284 | if best_rank1 < test_ranks[0] * 100.0: 285 | best_rank1 = test_ranks[0] * 100.0 286 | sess.save_checkpoints('ckp_latest') 287 | sess.save_checkpoints('ckp_latest_backup') 288 | sess.write('val_stats', {'test_mAP_percentage': test_mAP*100.0, \ 289 | 'test_rank-1_accuracy_percentage':test_ranks[0]*100.0,\ 290 | 'test_rank-5_accuracy_percentage':test_ranks[4]*100.0,\ 291 | 'test_rank-10_accuracy_percentage':test_ranks[9]*100.0,\ 292 | 'test_rank-20_accuracy_percentage':test_ranks[19]*100.0 293 | }) 294 | 295 | if sess.step % sess.save_steps == 0: 296 | sess.save_checkpoints('ckp_step_%d' % sess.step) 297 | logger.info('save model as ckp_step_%d' % sess.step) 298 | sess.step += 1 299 | 300 | 301 | def run_test(ckp, setting): 302 | if ckp == 'all': 303 | models = sorted(os.listdir('../models/')) 304 | csvfile = open('all_test_results.csv', 'w') 305 | writer = csv.writer(csvfile) 306 | 307 | writer.writerow(['ckp_name', 'mAP', 'R1', 'R5', 'R10', 'R20']) 308 | 309 | for mm in models: 310 | result = test_ckp(mm, setting) 311 | writer.writerow(result) 312 | 313 | csvfile.close() 314 | 315 | else: 316 | test_ckp(ckp, setting) 317 | 318 | 319 | def test_ckp(ckp_name, setting): 320 | sess = Session() 321 | sess.load_checkpoints(ckp_name) 322 | 323 | search_mode = setting.split('_')[0] # 'all' or 'indoor' 324 | search_setting = setting.split('_')[1] # 'single' or 'multi' 325 | 326 | transform_test = settings.test_transforms_list 327 | 328 | results_ranks = np.zeros(50) 329 | results_map = np.zeros(1) 330 | 331 | for i in range(settings.test_times): 332 | eval_test = RegDB_eval_datasets(settings.data_folder, settings.test_transforms_list, trial=10) 333 | 334 | test_queryloader = DataLoader( 335 | RegDB_wrapper(eval_test.query,
transform=transform_test), 336 | batch_size=settings.val_batch_size, shuffle=False, num_workers=0, 337 | drop_last=False, 338 | ) 339 | 340 | test_galleryloader = DataLoader( 341 | RegDB_wrapper(eval_test.gallery, transform=transform_test), 342 | batch_size=settings.val_batch_size, shuffle=False, num_workers=0, 343 | drop_last=False, 344 | ) 345 | 346 | distmat,test_ranks, test_mAP = test([nn.Sequential(sess.feature_rgb_generator, sess.feature_embedder), nn.Sequential(sess.feature_ir_generator, sess.feature_embedder)], test_queryloader, test_galleryloader) 347 | embed() 348 | results_ranks += test_ranks 349 | results_map += test_mAP 350 | 351 | logger.info('Test no.{} for model {} in setting {}, Test mAP: {}, R1: {}, R5: {}, R10: {}, R20: {}'.format(i, 352 | ckp_name, 353 | setting, 354 | test_mAP*100.0, 355 | test_ranks[0]*100.0, 356 | test_ranks[4]*100.0, 357 | test_ranks[9]*100.0, 358 | test_ranks[19]*100.0)) 359 | 360 | 361 | test_mAP = results_map / settings.test_times 362 | test_ranks = results_ranks / settings.test_times 363 | logger.info('For model {} in setting {}, AVG test mAP: {}, R1: {}, R5: {}, R10: {}, R20: {}'.format(ckp_name, 364 | setting, 365 | test_mAP*100.0, 366 | test_ranks[0]*100.0, 367 | test_ranks[4]*100.0, 368 | test_ranks[9]*100.0, 369 | test_ranks[19]*100.0)) 370 | 371 | return [ckp_name, test_mAP*100.0, test_ranks[0]*100.0, test_ranks[4]*100.0, test_ranks[9]*100.0, test_ranks[19]*100.0] 372 | 373 | 374 | if __name__ == '__main__': 375 | parser = argparse.ArgumentParser() 376 | parser.add_argument('-a', '--action', default='train') 377 | parser.add_argument('-m', '--model', default='ckp_latest') 378 | parser.add_argument('-s', '--setting', default='all_single') 379 | args = parser.parse_args(sys.argv[1:]) 380 | 381 | if args.action == 'train': 382 | run_train_val(args.model) 383 | elif args.action == 'test': 384 | run_test(args.model, args.setting) 385 | 386 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | """ 2 | Angular Triplet Loss 3 | YE, Hanrong et al, Bi-directional Exponential Angular Triplet Loss for RGB-Infrared Person Re-Identification 4 | """ 5 | 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torchvision import models 9 | from backbone import ResNet 10 | import settings 11 | import torch 12 | import math 13 | 14 | class Normalize(nn.Module): 15 | def __init__(self, power=2): 16 | super(Normalize, self).__init__() 17 | self.power = power 18 | 19 | def forward(self, x): 20 | norm = x.pow(self.power).sum(1, keepdim=True).pow(1. 
/ self.power) 21 | out = x.div(norm) 22 | return out 23 | 24 | 25 | class Non_local(nn.Module): 26 | def __init__(self, in_channels, reduc_ratio=2): 27 | super(Non_local, self).__init__() 28 | 29 | self.in_channels = in_channels 30 | self.inter_channels = reduc_ratio//reduc_ratio 31 | 32 | self.g = nn.Sequential( 33 | nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, 34 | padding=0), 35 | ) 36 | 37 | self.W = nn.Sequential( 38 | nn.Conv2d(in_channels=self.inter_channels, out_channels=self.in_channels, 39 | kernel_size=1, stride=1, padding=0), 40 | nn.BatchNorm2d(self.in_channels), 41 | ) 42 | nn.init.constant_(self.W[1].weight, 0.0) 43 | nn.init.constant_(self.W[1].bias, 0.0) 44 | 45 | 46 | 47 | self.theta = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, 48 | kernel_size=1, stride=1, padding=0) 49 | 50 | self.phi = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, 51 | kernel_size=1, stride=1, padding=0) 52 | 53 | def forward(self, x): 54 | ''' 55 | :param x: (b, c, t, h, w) 56 | :return: 57 | ''' 58 | 59 | batch_size = x.size(0) 60 | g_x = self.g(x).view(batch_size, self.inter_channels, -1) 61 | g_x = g_x.permute(0, 2, 1) 62 | 63 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1) 64 | theta_x = theta_x.permute(0, 2, 1) 65 | phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) 66 | f = torch.matmul(theta_x, phi_x) 67 | N = f.size(-1) 68 | # f_div_C = torch.nn.functional.softmax(f, dim=-1) 69 | f_div_C = f / N 70 | 71 | y = torch.matmul(f_div_C, g_x) 72 | y = y.permute(0, 2, 1).contiguous() 73 | y = y.view(batch_size, self.inter_channels, *x.size()[2:]) 74 | W_y = self.W(y) 75 | z = W_y + x 76 | 77 | return z 78 | 79 | class FeatureEmbedder(nn.Module): 80 | def __init__(self,last_stride,model_path,part=3): 81 | super(FeatureEmbedder, self).__init__() 82 | #self.gap = nn.AdaptiveAvgPool2d(1) 83 | self.bottleneck = nn.BatchNorm1d(2048) 84 | self.bottleneck.bias.requires_grad_(False) # no shift 85 | self.bottleneck.apply(weights_init_kaiming) 86 | self.base = ResNet(last_stride) 87 | self.base.load_param(model_path) 88 | layers = [3, 4, 6, 3] 89 | non_layers = [0, 2, 3, 0] 90 | self.part = part 91 | self.NL_2 = nn.ModuleList( 92 | [Non_local(512) for i in range(non_layers[1])]) 93 | self.NL_2_idx = sorted([layers[1] - (i + 1) for i in range(non_layers[1])]) 94 | self.NL_3 = nn.ModuleList( 95 | [Non_local(1024) for i in range(non_layers[2])]) 96 | self.NL_3_idx = sorted([layers[2] - (i + 1) for i in range(non_layers[2])]) 97 | self.NL_4 = nn.ModuleList( 98 | [Non_local(2048) for i in range(non_layers[3])]) 99 | self.NL_4_idx = sorted([layers[3] - (i + 1) for i in range(non_layers[3])]) 100 | def forward(self, x): 101 | NL2_counter = 0 102 | if len(self.NL_2_idx) == 0: self.NL_2_idx = [-1] 103 | for i in range(len(self.base.layer2)): 104 | x = self.base.layer2[i](x) 105 | if i == self.NL_2_idx[NL2_counter]: 106 | _, C, H, W = x.shape 107 | x = self.NL_2[NL2_counter](x) 108 | NL2_counter += 1 109 | # Layer 3 110 | NL3_counter = 0 111 | if len(self.NL_3_idx) == 0: self.NL_3_idx = [-1] 112 | for i in range(len(self.base.layer3)): 113 | x = self.base.layer3[i](x) 114 | if i == self.NL_3_idx[NL3_counter]: 115 | _, C, H, W = x.shape 116 | x = self.NL_3[NL3_counter](x) 117 | NL3_counter += 1 118 | # Layer 4 119 | NL4_counter = 0 120 | if len(self.NL_4_idx) == 0: self.NL_4_idx = [-1] 121 | for i in range(len(self.base.layer4)): 122 | x = self.base.layer4[i](x) 123 | if i == 
self.NL_4_idx[NL4_counter]: 124 | _, C, H, W = x.shape 125 | x = self.NL_4[NL4_counter](x) 126 | NL4_counter += 1 127 | b, c, h, w = x.shape 128 | y = x.view(b, c, -1) 129 | p = 3.0 130 | global_feat = (torch.mean(y ** p, dim=-1) + 1e-12) ** (1 / p) 131 | feat = global_feat.view(global_feat.shape[0], -1) # flatten to (bs, 2048) 132 | bnfeat = self.bottleneck(feat) # normalize for angular softmax 133 | return bnfeat 134 | 135 | class IdClassifier(nn.Module): 136 | def __init__(self, in_planes = 2048, num_classes = settings.num_classes): # train 296, val 99 137 | super(IdClassifier, self).__init__() 138 | self.classifier = nn.Linear(in_planes, num_classes, bias=False) 139 | self.classifier.apply(weights_init_classifier) 140 | self.dropout = 0.5 141 | self.l2norm = Normalize(2) 142 | def forward(self, x): 143 | x = x.view(x.size(0), -1) 144 | if self.training: 145 | x = F.dropout(x,self.dropout,training = self.training) 146 | x = F.elu(x) 147 | else : 148 | x = self.l2norm(x) 149 | out = self.classifier(x) 150 | return out 151 | 152 | def weights_init_kaiming(m): 153 | classname = m.__class__.__name__ 154 | if classname.find('Linear') != -1: 155 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out') 156 | nn.init.constant_(m.bias, 0.0) 157 | elif classname.find('Conv') != -1: 158 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in') 159 | if m.bias is not None: 160 | nn.init.constant_(m.bias, 0.0) 161 | elif classname.find('BatchNorm') != -1: 162 | if m.affine: 163 | nn.init.constant_(m.weight, 1.0) 164 | nn.init.constant_(m.bias, 0.0) 165 | 166 | 167 | def weights_init_classifier(m): 168 | classname = m.__class__.__name__ 169 | if classname.find('Linear') != -1: 170 | nn.init.normal_(m.weight, std=0.001) 171 | if m.bias: 172 | nn.init.constant_(m.bias, 0.0) 173 | 174 | 175 | class Baseline(nn.Module): 176 | def __init__(self, last_stride, model_path): 177 | super(Baseline, self).__init__() 178 | self.base = ResNet(last_stride) 179 | self.base.load_param(model_path) 180 | 181 | def forward(self, x): 182 | return self.base(x) # (b, 2048, 1, 1) 183 | 184 | class Base_rgb(nn.Module): 185 | def __init__(self, last_stride, model_path): 186 | super(Base_rgb, self).__init__() 187 | self.base = ResNet(last_stride) 188 | self.base.load_param(model_path) 189 | 190 | layers = [3, 4, 6, 3] 191 | non_layers = [0, 2, 3, 0] 192 | self.NL_1 = nn.ModuleList( 193 | [Non_local(256) for i in range(non_layers[0])]) 194 | self.NL_1_idx = sorted([layers[0] - (i + 1) for i in range(non_layers[0])]) 195 | def forward(self, x): 196 | 197 | x = self.base.conv1(x) 198 | x = self.base.bn1(x) 199 | x = self.base.relu(x) 200 | x = self.base.maxpool(x) 201 | NL1_counter = 0 202 | if len(self.NL_1_idx) == 0: self.NL_1_idx = [-1] 203 | for i in range(len(self.base.layer1)): 204 | x = self.base.layer1[i](x) 205 | if i == self.NL_1_idx[NL1_counter]: 206 | _, C, H, W = x.shape 207 | x = self.NL_1[NL1_counter](x) 208 | NL1_counter += 1 209 | return x 210 | 211 | class Base_ir(nn.Module): 212 | def __init__(self, last_stride, model_path): 213 | super(Base_ir, self).__init__() 214 | self.base = ResNet(last_stride) 215 | self.base.load_param(model_path) 216 | 217 | layers = [3, 4, 6, 3] 218 | non_layers = [0, 2, 3, 0] 219 | self.NL_1 = nn.ModuleList( 220 | [Non_local(256) for i in range(non_layers[0])]) 221 | self.NL_1_idx = sorted([layers[0] - (i + 1) for i in range(non_layers[0])]) 222 | def forward(self, x): 223 | 224 | x = self.base.conv1(x) 225 | x = self.base.bn1(x) 226 | x = self.base.relu(x) 227 | x = 
self.base.maxpool(x) 228 | NL1_counter = 0 229 | if len(self.NL_1_idx) == 0: self.NL_1_idx = [-1] 230 | for i in range(len(self.base.layer1)): 231 | x = self.base.layer1[i](x) 232 | if i == self.NL_1_idx[NL1_counter]: 233 | _, C, H, W = x.shape 234 | x = self.NL_1[NL1_counter](x) 235 | NL1_counter += 1 236 | return x -------------------------------------------------------------------------------- /settings.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import transforms 4 | 5 | G_lr = 3e-4 6 | BASE_LR = 3e-4 7 | BIAS_LR_FACTOR = 2 8 | WEIGHT_DECAY = 0.0005 9 | WEIGHT_DECAY_BIAS = 0. 10 | D_lr = 1e-4 11 | iter_sche = [10000, 20000, 30000] 12 | 13 | train_batch_size = 8 14 | val_batch_size = 16 15 | 16 | log_dir = '../logdir' 17 | show_dir = '../showdir' 18 | model_dir = '../models' 19 | data_folder = '/home/ggw/HaoShao/dataset/RegDB' 20 | pretrained_model_path = '/home/ggw/.cache/torch/checkpoints/resnet50-19c8e357.pth' 21 | 22 | model_path = os.path.join(model_dir, 'latest') 23 | save_steps = 5000 24 | latest_steps = 100 25 | val_step = 200 26 | 27 | num_workers = 4 28 | num_gpu = 1 29 | device_id = '1' 30 | num_classes = 296 31 | test_times = 10 # official setting 32 | 33 | # for showing logger 34 | logger = logging.getLogger('train') 35 | logger.setLevel(logging.INFO) 36 | 37 | ch = logging.StreamHandler() 38 | ch.setLevel(logging.INFO) 39 | 40 | formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') 41 | ch.setFormatter(formatter) 42 | logger.addHandler(ch) 43 | 44 | 45 | ############################# Hyper-parameters ################################ 46 | alpha = 1.0 47 | beta = 1.0 48 | at_margin = 1 49 | 50 | pixel_mean = [0.485, 0.456, 0.406] 51 | pixel_std = [0.229, 0.224, 0.225] 52 | inp_size = [384, 128] 53 | 54 | # transforms 55 | 56 | transforms_list = transforms.Compose([transforms.RectScale(*inp_size), 57 | transforms.RandomHorizontalFlip(), 58 | transforms.Pad(10), 59 | transforms.RandomCrop(inp_size), 60 | transforms.ToTensor(), 61 | transforms.Normalize(mean=pixel_mean, 62 | std=pixel_std), 63 | transforms.RandomErasing(probability=0.5, mean=pixel_mean)]) 64 | 65 | test_transforms_list = transforms.Compose([ 66 | transforms.RectScale(*inp_size), 67 | transforms.ToTensor(), 68 | transforms.Normalize(mean=pixel_mean, 69 | std=pixel_std)]) 70 | 71 | -------------------------------------------------------------------------------- /solver.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | from bisect import bisect_right 7 | import torch 8 | 9 | 10 | # FIXME ideally this would be achieved with a CombinedLRScheduler, 11 | # separating MultiStepLR with WarmupLR 12 | # but the current LRScheduler design doesn't allow it 13 | 14 | class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler): 15 | def __init__( 16 | self, 17 | optimizer, 18 | milestones, 19 | gamma=0.01, 20 | warmup_factor=1.0 / 3, 21 | warmup_iters=500, 22 | warmup_method="linear", 23 | last_epoch=-1, 24 | ): 25 | if not list(milestones) == sorted(milestones): 26 | raise ValueError( 27 | "Milestones should be a list of" " increasing integers. 
Got {}", 28 | milestones, 29 | ) 30 | 31 | if warmup_method not in ("constant", "linear"): 32 | raise ValueError( 33 | "Only 'constant' or 'linear' warmup_method accepted" 34 | "got {}".format(warmup_method) 35 | ) 36 | self.milestones = milestones 37 | self.gamma = gamma 38 | self.warmup_factor = warmup_factor 39 | self.warmup_iters = warmup_iters 40 | self.warmup_method = warmup_method 41 | super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch) 42 | 43 | def get_lr(self): 44 | warmup_factor = 1 45 | if self.last_epoch < self.warmup_iters: 46 | if self.warmup_method == "constant": 47 | warmup_factor = self.warmup_factor 48 | elif self.warmup_method == "linear": 49 | alpha = self.last_epoch / self.warmup_iters 50 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha 51 | return [ 52 | base_lr 53 | * warmup_factor 54 | * self.gamma ** bisect_right(self.milestones, self.last_epoch) 55 | for base_lr in self.base_lrs 56 | ] 57 | -------------------------------------------------------------------------------- /tensorboard.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | function rand() { 3 | min=$1 4 | max=$(($2-$min+1)) 5 | num=$(($RANDOM+1000000000000)) 6 | echo $(($num%$max+$min)) 7 | } 8 | 9 | #rnd=$(rand 3000 12000) 10 | tensorboard --logdir ../logdir --host 0.0.0.0 --port 17650 --reload_interval 3 11 | -------------------------------------------------------------------------------- /transforms.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from torchvision.transforms import * 3 | from PIL import Image 4 | import random 5 | import math 6 | 7 | 8 | class RectScale(object): 9 | def __init__(self, height, width, interpolation=Image.BILINEAR): 10 | self.height = height 11 | self.width = width 12 | self.interpolation = interpolation 13 | 14 | def __call__(self, img): 15 | w, h = img.size 16 | if h == self.height and w == self.width: 17 | return img 18 | return img.resize((self.width, self.height), self.interpolation) 19 | 20 | 21 | 22 | class RandomSizedRectCrop(object): 23 | def __init__(self, height, width, interpolation=Image.BILINEAR): 24 | self.height = height 25 | self.width = width 26 | self.interpolation = interpolation 27 | 28 | def __call__(self, img): 29 | for attempt in range(10): 30 | area = img.size[0] * img.size[1] 31 | target_area = random.uniform(0.64, 1.0) * area 32 | aspect_ratio = random.uniform(2, 3) 33 | 34 | h = int(round(math.sqrt(target_area * aspect_ratio))) 35 | w = int(round(math.sqrt(target_area / aspect_ratio))) 36 | 37 | if w <= img.size[0] and h <= img.size[1]: 38 | x1 = random.randint(0, img.size[0] - w) 39 | y1 = random.randint(0, img.size[1] - h) 40 | 41 | img = img.crop((x1, y1, x1 + w, y1 + h)) 42 | assert(img.size == (w, h)) 43 | 44 | return img.resize((self.width, self.height), self.interpolation) 45 | 46 | # Fallback 47 | scale = RectScale(self.height, self.width, 48 | interpolation=self.interpolation) 49 | return scale(img) 50 | 51 | 52 | class RandomErasing(object): 53 | """ Randomly selects a rectangle region in an image and erases its pixels. 54 | 'Random Erasing Data Augmentation' by Zhong et al. 55 | See https://arxiv.org/pdf/1708.04896.pdf 56 | Args: 57 | probability: The probability that the Random Erasing operation will be performed. 58 | sl: Minimum proportion of erased area against input image. 59 | sh: Maximum proportion of erased area against input image. 
60 | r1: Minimum aspect ratio of erased area. 61 | mean: Erasing value. 62 | """ 63 | 64 | def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=(0.4914, 0.4822, 0.4465)): 65 | self.probability = probability 66 | self.mean = mean 67 | self.sl = sl 68 | self.sh = sh 69 | self.r1 = r1 70 | 71 | def __call__(self, img): 72 | 73 | if random.uniform(0, 1) > self.probability: 74 | return img 75 | 76 | for attempt in range(100): 77 | area = img.size()[1] * img.size()[2] 78 | 79 | target_area = random.uniform(self.sl, self.sh) * area 80 | aspect_ratio = random.uniform(self.r1, 1 / self.r1) 81 | 82 | h = int(round(math.sqrt(target_area * aspect_ratio))) 83 | w = int(round(math.sqrt(target_area / aspect_ratio))) 84 | 85 | if w < img.size()[2] and h < img.size()[1]: 86 | x1 = random.randint(0, img.size()[1] - h) 87 | y1 = random.randint(0, img.size()[2] - w) 88 | if img.size()[0] == 3: 89 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0] 90 | img[1, x1:x1 + h, y1:y1 + w] = self.mean[1] 91 | img[2, x1:x1 + h, y1:y1 + w] = self.mean[2] 92 | else: 93 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0] 94 | return img 95 | 96 | return img 97 | --------------------------------------------------------------------------------