├── README.md
├── backbone.py
├── clean.sh
├── clean_step.sh
├── criterions.py
├── datasets.py
├── eval.py
├── main.py
├── models.py
├── settings.py
├── solver.py
├── tensorboard.sh
└── transforms.py
/README.md:
--------------------------------------------------------------------------------
1 | # Leaning Compact and Representative Features for Cross-Modality Person Re-Identification
2 | PyTorch code for "Leaning Compact and Representative Features for Cross-Modality Person Re-Identification" (World Wide Web, CCF-B).
3 |
4 | ## [Highlights]
5 |
1. We devise an efficient Enumerate Angular Triplet (EAT) loss that learns an angularly separable common feature space by explicitly constraining the internal angles between different embedding features, which improves performance.
6 |
2. Motivated by knowledge distillation, we propose a novel Cross-Modality Knowledge Distillation (CMKD) loss that reduces the modality discrepancy in the modality-specific feature extraction stage, benefiting the cross-modality person Re-ID task (see the sketch after this list).
7 |
3. Our network achieves prominent results on both the SYSU-MM01 and RegDB datasets without any extra data augmentation tricks, reaching a mean Average Precision (mAP) of 43.09% and 79.92% on SYSU-MM01 and RegDB, respectively.
8 |
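A minimal sketch of the two losses (illustrative shapes; it mirrors `expATLoss` in criterions.py and the MSE-based CMKD term in main.py):

```python
import torch
import torch.nn.functional as F

def eat_loss(anchor, positive, negative, margin=1.0):
    # Exponential angular triplet: rank the anchor-positive cosine similarity
    # above the (clamped) anchor-negative cosine similarity, then exponentiate.
    cos_pos = F.cosine_similarity(anchor, positive)
    cos_neg = F.relu(F.cosine_similarity(anchor, negative))
    y = torch.ones_like(cos_pos)  # target: cos_pos ranked higher than cos_neg
    return torch.exp(F.margin_ranking_loss(cos_pos, cos_neg, y, margin=margin))

def cmkd_loss(feat_rgb, feat_ir):
    # CMKD: align the modality-specific feature maps of paired RGB/IR images.
    return F.mse_loss(feat_rgb, feat_ir)
```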
9 | ## [Prerequisite]
10 |
Python>=3.6
11 |
PyTorch>=1.0.0
12 |
OpenCV>=3.1.0
13 |
tensorboard-pytorch
14 | ## [Experiments]
15 | Training:
16 |
python main.py -a train
17 |
Testing:
18 |
python main.py -a test -m checkpoint_name -s test_setting
19 |
The test settings of SYSU-MM01 include: "all_multi" (all search mode, multi-shot), "all_single" (all search mode, single-shot), "indoor_multi" (indoor search mode, multi-shot), "indoor_single" (indoor search mode, single-shot).
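For example, to evaluate the latest checkpoint saved during training (the default name "ckp_latest") under the all-search single-shot setting:

python main.py -a test -m ckp_latest -s all_single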
20 |
21 | ## [Cite]
22 | If you find our paper/codes useful, please kindly consider citing the paper:
23 |
@article{gao2022leaning,
24 |
title={Leaning compact and representative features for cross-modality person re-identification},
25 |
author={Gao, Guangwei and Shao, Hao and Wu, Fei and Yang, Meng and Yu, Yi},
26 |
journal={World Wide Web},
27 |
pages={1--18},
28 |
year={2022},
29 |
publisher={Springer}
30 |
}
31 |
--------------------------------------------------------------------------------
/backbone.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch import nn
4 |
5 |
6 | class Bottleneck(nn.Module):
7 | expansion = 4
8 |
9 | def __init__(self, inplanes, planes, stride=1, downsample=None):
10 | super(Bottleneck, self).__init__()
11 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
12 | self.bn1 = nn.BatchNorm2d(planes)
13 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
14 | padding=1, bias=False)
15 | self.bn2 = nn.BatchNorm2d(planes)
16 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
17 | self.bn3 = nn.BatchNorm2d(planes * 4)
18 | self.relu = nn.ReLU(inplace=True)
19 | self.downsample = downsample
20 | self.stride = stride
21 |
22 | def forward(self, x):
23 | residual = x
24 |
25 | out = self.conv1(x)
26 | out = self.bn1(out)
27 | out = self.relu(out)
28 |
29 | out = self.conv2(out)
30 | out = self.bn2(out)
31 | out = self.relu(out)
32 |
33 | out = self.conv3(out)
34 | out = self.bn3(out)
35 |
36 | if self.downsample is not None:
37 | residual = self.downsample(x)
38 |
39 | out += residual
40 | out = self.relu(out)
41 |
42 | return out
43 |
44 |
45 | class ResNet(nn.Module):
46 | def __init__(self, last_stride=2, block=Bottleneck, layers=[3, 4, 6, 3]):
47 | self.inplanes = 64
48 | super().__init__()
49 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
50 | bias=False)
51 | self.bn1 = nn.BatchNorm2d(64)
52 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
53 | self.layer1 = self._make_layer(block, 64, layers[0])
54 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
55 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
56 | self.layer4 = self._make_layer(
57 | block, 512, layers[3], stride=last_stride)
58 | self.relu = nn.ReLU(inplace=True)
59 |
60 | def _make_layer(self, block, planes, blocks, stride=1):
61 | downsample = None
62 | if stride != 1 or self.inplanes != planes * block.expansion:
63 | downsample = nn.Sequential(
64 | nn.Conv2d(self.inplanes, planes * block.expansion,
65 | kernel_size=1, stride=stride, bias=False),
66 | nn.BatchNorm2d(planes * block.expansion),
67 | )
68 |
69 | layers = []
70 | layers.append(block(self.inplanes, planes, stride, downsample))
71 | self.inplanes = planes * block.expansion
72 | for i in range(1, blocks):
73 | layers.append(block(self.inplanes, planes))
74 |
75 | return nn.Sequential(*layers)
76 |
77 | def forward(self, x):
78 | x = self.conv1(x)
79 | x = self.bn1(x)
80 | x = self.relu(x)
81 | x = self.maxpool(x)
82 |
83 | x = self.layer1(x)
84 | x = self.layer2(x)
85 | x = self.layer3(x)
86 | x = self.layer4(x)
87 |
88 | return x
89 |
90 | def load_param(self, model_path):
91 | param_dict = torch.load(model_path)
92 | for i in param_dict:
93 | if 'fc' in i:
94 | continue
95 | self.state_dict()[i].copy_(param_dict[i])
96 |
97 | def random_init(self):
98 | for m in self.modules():
99 | if isinstance(m, nn.Conv2d):
100 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
101 | m.weight.data.normal_(0, math.sqrt(2. / n))
102 | elif isinstance(m, nn.BatchNorm2d):
103 | m.weight.data.fill_(1)
104 | m.bias.data.zero_()
105 |
--------------------------------------------------------------------------------
/clean.sh:
--------------------------------------------------------------------------------
1 | #! /bin/bash
2 | rm -rf __pycache__
3 | rm -f .*.swp
4 | rm -R ../logdir/*
5 | rm -R ../showdir/*
6 | rm ../models/ckp_step*
7 |
--------------------------------------------------------------------------------
/clean_step.sh:
--------------------------------------------------------------------------------
1 | #! /bin/bash
2 | rm ../models/step*
3 |
--------------------------------------------------------------------------------
/criterions.py:
--------------------------------------------------------------------------------
1 | """
2 | Angular Triplet Loss
3 | YE, Hanrong et al, Bi-directional Exponential Angular Triplet Loss for RGB-Infrared Person Re-Identification
4 | """
5 |
6 | import torch.nn.functional as F
7 | import torch
8 | from torch import nn
9 | import settings
10 |
11 | class expATLoss():
12 | def __init__(self):
13 | self.marginloss = torch.nn.MarginRankingLoss(margin = settings.at_margin)
14 |
15 | def forward(self, anc, pos, neg):
16 | cos_pos = F.cosine_similarity(anc, pos)
17 | cos_neg = F.relu(F.cosine_similarity(anc, neg))
18 | y_true = torch.ones_like(cos_pos)
19 | return torch.exp(self.marginloss(cos_pos, cos_neg, y_true)) # exp(max(0, -(cos_pos - cos_neg) + margin))
20 |
21 |
22 | class CrossEntropyLabelSmoothLoss(nn.Module):
23 | """Cross entropy loss with label smoothing regularizer.
24 | Reference:
25 | Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016.
26 | Equation: y = (1 - epsilon) * y + epsilon / K.
27 | Args:
28 | num_classes (int): number of classes.
29 | epsilon (float): weight.
30 | """
31 | def __init__(self, num_classes, epsilon=0.1, use_gpu=True):
32 | super(CrossEntropyLabelSmoothLoss, self).__init__()
33 | self.num_classes = num_classes
34 | self.epsilon = epsilon
35 | self.use_gpu = use_gpu
36 | self.logsoftmax = nn.LogSoftmax(dim=1)
37 |
38 | def forward(self, inputs, targets):
39 | """
40 | Args:
41 | inputs: prediction matrix (before softmax) with shape (batch_size, num_classes)
42 | targets: ground truth labels with shape (batch_size)
43 | """
44 | log_probs = self.logsoftmax(inputs)
45 | targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).data.cpu(), 1)
46 | if self.use_gpu: targets = targets.cuda()
47 | targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
48 | loss = (- targets * log_probs).mean(0).sum()
49 | return loss
50 |
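51 |
52 | # Minimal smoke test on random tensors (illustrative only), following the
53 | # __main__ convention used in datasets.py; use_gpu=False keeps it on the CPU.
54 | if __name__ == '__main__':
55 | at = expATLoss()
56 | anc, pos, neg = torch.randn(4, 2048), torch.randn(4, 2048), torch.randn(4, 2048)
57 | print('expAT loss:', at.forward(anc, pos, neg).item())
58 | ce = CrossEntropyLabelSmoothLoss(num_classes=296, epsilon=0.1, use_gpu=False)
59 | print('label-smooth CE:', ce(torch.randn(4, 296), torch.randint(0, 296, (4,))).item())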
--------------------------------------------------------------------------------
/datasets.py:
--------------------------------------------------------------------------------
1 | """
2 | Angular Triplet Loss
3 | YE, Hanrong et al, Bi-directional Exponential Angular Triplet Loss for RGB-Infrared Person Re-Identification
4 | """
5 |
6 | import glob
7 | import random
8 | import os
9 | import re
10 | import sys
11 | import urllib
12 | import tarfile
13 | import zipfile
14 | import os.path as osp
15 | from scipy.io import loadmat
16 | import numpy as np
17 | import h5py
18 | #from scipy.misc import imsave # unused; scipy.misc.imsave was removed in SciPy >= 1.2
19 | import random
20 | import time
21 | import settings
22 | import torch
23 | import numpy as np
24 | from torch.utils.data import Dataset
25 | from PIL import Image
26 | import torchvision.transforms as transforms
27 |
28 | class SYSU_triplet_dataset(Dataset):
29 |
30 | def __init__(self, data_folder = 'SYSU-MM01', transforms_list=None, mode='train', search_mode='all'):
31 |
32 | if mode == 'train':
33 | self.id_file = 'train_id.txt'
34 | elif mode == 'val':
35 | self.id_file = 'val_id.txt'
36 | else:
37 | self.id_file = 'test_id.txt'
38 |
39 | if search_mode == 'all':
40 | self.rgb_cameras = ['cam1','cam2','cam4','cam5']
41 | self.ir_cameras = ['cam3','cam6']
42 | elif search_mode == 'indoor':
43 | self.rgb_cameras = ['cam1','cam2']
44 | self.ir_cameras = ['cam3','cam6']
45 |
46 | file_path = os.path.join(data_folder,'exp',self.id_file)
47 |
48 | with open(file_path, 'r') as file:
49 | self.ids = file.read().splitlines()
50 |
51 | #print(self.ids)
52 | self.ids = [int(y) for y in self.ids[0].split(',')]
53 | self.ids.sort()
54 |
55 | self.id_dict = {}
56 |
57 | for index, id in enumerate(self.ids):
58 | #print(index,id)
59 | self.id_dict[id] = index
60 |
61 | self.ids = ["%04d" % x for x in self.ids]
62 |
63 | self.transform = transforms_list
64 |
65 | self.files_rgb = {}
66 | self.files_ir = {}
67 |
68 | for id in sorted(self.ids):
69 |
70 | self.files_rgb[id] = []
71 | self.files_ir[id] = []
72 |
73 | for cam in self.rgb_cameras:
74 | img_dir = os.path.join(data_folder,cam,id)
75 | if os.path.isdir(img_dir):
76 | self.files_rgb[id].extend(sorted([img_dir+'/'+i for i in os.listdir(img_dir)]))
77 | for cam in self.ir_cameras:
78 | img_dir = os.path.join(data_folder,cam,id)
79 | if os.path.isdir(img_dir):
80 | self.files_ir[id].extend(sorted([img_dir+'/'+i for i in os.listdir(img_dir)]))
81 |
82 | self.all_files = []
83 |
84 | for id in sorted(self.ids):
85 | self.all_files.extend(self.files_rgb[id])
86 |
87 | def __getitem__(self, index):
88 |
89 | anchor_file = self.all_files[index]
90 | anchor_id = anchor_file.split('/')[-2]
91 |
92 | anchor_rgb = np.random.choice(self.files_rgb[anchor_id])
93 | positive_rgb = np.random.choice([x for x in self.files_rgb[anchor_id] if x != anchor_rgb])
94 | negative_id = np.random.choice([id for id in self.ids if id != anchor_id])
95 | negative_rgb = np.random.choice(self.files_rgb[negative_id])
96 |
97 | anchor_ir = np.random.choice(self.files_ir[anchor_id])
98 | positive_ir = np.random.choice([x for x in self.files_ir[anchor_id] if x != anchor_ir])
99 | negative_id = np.random.choice([id for id in self.ids if id != anchor_id])
100 | negative_ir = np.random.choice(self.files_ir[negative_id])
101 |
102 | anchor_label = np.array(self.id_dict[int(anchor_id)])
103 |
104 | #print(anchor_file, positive_file, negative_file, anchor_id)
105 |
106 | anchor_rgb = Image.open(anchor_rgb)
107 | positive_rgb = Image.open(positive_rgb)
108 | negative_rgb = Image.open(negative_rgb)
109 |
110 | anchor_ir = Image.open(anchor_ir)
111 | positive_ir = Image.open(positive_ir)
112 | negative_ir = Image.open(negative_ir)
113 |
114 | if self.transform is not None:
115 | anchor_rgb = self.transform(anchor_rgb)
116 | positive_rgb = self.transform(positive_rgb)
117 | negative_rgb = self.transform(negative_rgb)
118 |
119 | anchor_ir = self.transform(anchor_ir)
120 | positive_ir = self.transform(positive_ir)
121 | negative_ir = self.transform(negative_ir)
122 |
123 | modality_rgb = torch.tensor([1,0]).float()
124 | modality_ir = torch.tensor([0,1]).float()
125 |
126 | return anchor_rgb, positive_rgb, negative_rgb, anchor_ir, positive_ir, negative_ir, anchor_label, modality_rgb, modality_ir
127 |
128 | def __len__(self):
129 | return len(self.all_files)
130 |
131 |
132 |
133 | class SYSU_eval_datasets(object):
134 | def __init__(self, data_folder = 'SYSU-MM01', search_mode='all', search_setting='single' , data_split='val', use_random=False, **kwargs):
135 |
136 | self.data_folder = data_folder
137 | self.train_id_file = 'train_id.txt'
138 | self.val_id_file = 'val_id.txt'
139 | self.test_id_file = 'test_id.txt'
140 |
141 | if search_mode == 'all':
142 | self.rgb_cameras = ['cam1','cam2','cam4','cam5']
143 | self.ir_cameras = ['cam3','cam6']
144 | elif search_mode == 'indoor':
145 | self.rgb_cameras = ['cam1','cam2']
146 | self.ir_cameras = ['cam3','cam6']
147 |
148 | if data_split == 'train':
149 | self.id_file = self.train_id_file
150 | elif data_split == 'val':
151 | self.id_file = self.val_id_file
152 | elif data_split == 'test':
153 | self.id_file = self.test_id_file
154 |
155 | self.search_setting = search_setting
156 | self.search_mode = search_mode
157 | self.use_random = use_random
158 |
159 |
160 | query, num_query_pids, num_query_imgs = self._process_query_images(id_file = self.id_file, relabel=False)
161 | gallery, num_gallery_pids, num_gallery_imgs = self._process_gallery_images(id_file = self.id_file, relabel=False)
162 |
163 | num_total_pids = num_query_pids
164 | num_total_imgs = num_query_imgs + num_gallery_imgs
165 |
166 | print("Dataset statistics:")
167 | print(" ------------------------------")
168 | print(" subset | # ids | # images")
169 | print(" ------------------------------")
170 | print(" query | {:5d} | {:8d}".format(num_query_pids, num_query_imgs))
171 | print(" gallery | {:5d} | {:8d}".format(num_gallery_pids, num_gallery_imgs))
172 | print(" ------------------------------")
173 | print(" total | {:5d} | {:8d}".format(num_total_pids, num_total_imgs))
174 | print(" ------------------------------")
175 |
176 | self.query = query
177 | self.gallery = gallery
178 |
179 | self.num_query_pids = num_query_pids
180 | self.num_gallery_pids = num_gallery_pids
181 |
182 | def _process_query_images(self, id_file, relabel=False):
183 |
184 | file_path = os.path.join(self.data_folder,'exp',id_file)
185 |
186 | files_ir = []
187 |
188 | with open(file_path, 'r') as file:
189 | ids = file.read().splitlines()
190 | ids = [int(y) for y in ids[0].split(',')]
191 | ids = ["%04d" % x for x in ids]
192 |
193 | for id in sorted(ids):
194 | for cam in self.ir_cameras:
195 | img_dir = os.path.join(self.data_folder,cam,id)
196 | if os.path.isdir(img_dir):
197 | new_files = sorted([img_dir+'/'+i for i in os.listdir(img_dir)])
198 | files_ir.extend(new_files) #files_ir.append(random.choice(new_files))
199 | pid_container = set()
200 |
201 | for img_path in files_ir:
202 | camid, pid = int(img_path.split('/')[-3].split('cam')[1]), int(img_path.split('/')[-2])
203 | if pid == -1: continue # junk images are just ignored
204 | pid_container.add(pid)
205 | pid2label = {pid:label for label, pid in enumerate(pid_container)}
206 |
207 | dataset = []
208 | for img_path in files_ir:
209 | camid, pid = int(img_path.split('/')[-3].split('cam')[1]), int(img_path.split('/')[-2])
210 | if pid == -1: continue # junk images are just ignored
211 | if relabel: pid = pid2label[pid]
212 | dataset.append((img_path, pid, camid))
213 |
214 | num_pids = len(pid_container)
215 | num_imgs = len(dataset)
216 | return dataset, num_pids, num_imgs
217 |
218 | def _process_gallery_images(self, id_file, relabel=False):
219 | if self.use_random:
220 | random.seed(time.time())
221 | else:
222 | random.seed(1)
223 |
224 | file_path = os.path.join(self.data_folder,'exp',id_file)
225 | files_rgb = []
226 |
227 | with open(file_path, 'r') as file:
228 | ids = file.read().splitlines()
229 | ids = [int(y) for y in ids[0].split(',')]
230 | ids = ["%04d" % x for x in ids]
231 |
232 | for id in sorted(ids):
233 | for cam in self.rgb_cameras:
234 | img_dir = os.path.join(self.data_folder,cam,id)
235 | if os.path.isdir(img_dir):
236 | new_files = sorted([img_dir+'/'+i for i in os.listdir(img_dir)])
237 | if self.search_setting == 'single':
238 | files_rgb.append(random.choice(new_files))
239 | elif self.search_setting == 'multi':
240 | files_rgb.extend(random.sample(new_files, k=10)) # multi-shot: 10 images per camera
241 |
242 | pid_container = set()
243 | for img_path in files_rgb:
244 | camid, pid = int(img_path.split('/')[-3].split('cam')[1]), int(img_path.split('/')[-2])
245 | if pid == -1: continue # junk images are just ignored
246 | pid_container.add(pid)
247 | pid2label = {pid:label for label, pid in enumerate(pid_container)}
248 |
249 | dataset = []
250 | for img_path in files_rgb:
251 | camid, pid = int(img_path.split('/')[-3].split('cam')[1]), int(img_path.split('/')[-2])
252 | if pid == -1: continue # junk images are just ignored
253 | if relabel: pid = pid2label[pid]
254 | dataset.append((img_path, pid, camid))
255 |
256 | num_pids = len(pid_container)
257 | num_imgs = len(dataset)
258 | return dataset, num_pids, num_imgs
259 |
260 |
261 |
262 |
263 |
264 | class Image_dataset(Dataset):
265 | """Image Person ReID Dataset"""
266 | def __init__(self, dataset, transform=None):
267 | self.dataset = dataset
268 | self.transform = transform
269 |
270 | def __len__(self):
271 | return len(self.dataset)
272 |
273 | def __getitem__(self, index):
274 | img_path, pid, camid = self.dataset[index]
275 | img = Image.open(img_path)
276 | if self.transform is not None:
277 | img = self.transform(img)
278 | return img, pid, camid
279 |
280 | class RegDB_triplet_dataset(Dataset):
281 |
282 | def __init__(self, data_dir, transforms_list=None, mode='train', trial=1):
283 |
284 | if mode == 'train':
285 | self.visible_files = 'train_visible_' + str(trial) + '.txt'
286 | self.thermal_files = 'train_thermal_' + str(trial) + '.txt'
287 | elif mode == 'val':
288 | self.visible_files = 'test_visible_' + str(trial) + '.txt'
289 | self.thermal_files = 'test_thermal_' + str(trial) + '.txt'
290 | else:
291 | self.visible_files = 'test_visible_' + str(trial) + '.txt'
292 | self.thermal_files = 'test_thermal_' + str(trial) + '.txt'
293 |
294 |
295 | color_list = os.path.join(data_dir, 'idx', self.visible_files)
296 | thermal_list = os.path.join(data_dir, 'idx', self.thermal_files)
297 |
298 | color_img_file, color_label = self.load_data(color_list)
299 | thermal_img_file, thermal_label = self.load_data(thermal_list)
300 |
301 | color_image = []
302 | color_image_path = []
303 | for i in range(len(color_img_file)):
304 | img_path = os.path.join(data_dir, color_img_file[i])
305 | color_image_path.append(img_path)
306 | img = Image.open(img_path)
307 | img = img.resize(settings.inp_size[::-1]) #img.resize((144, 288), Image.ANTIALIAS) # (width, height)
308 | color_image.append(img)
309 | thermal_image = []
310 | thermal_image_path = []
311 | for i in range(len(thermal_img_file)):
312 | img_path = os.path.join(data_dir, thermal_img_file[i])
313 | thermal_image_path.append(img_path)
314 | img = Image.open(img_path)
315 | img = img.resize(settings.inp_size[::-1], Image.LANCZOS) # ANTIALIAS was removed in Pillow 10
316 | thermal_image.append(img)
317 |
318 | # make dict
319 | color_img_dict = {}
320 | for i in range(len(color_label)):
321 | label = color_label[i]
322 | if label not in color_img_dict.keys():
323 | color_img_dict[label] = []
324 |
325 | color_img_dict[label].append(i)
326 |
327 | thermal_img_dict = {}
328 | for i in range(len(thermal_label)):
329 | label = thermal_label[i]
330 | if label not in thermal_img_dict.keys():
331 | thermal_img_dict[label] = []
332 |
333 | thermal_img_dict[label].append(i)
334 |
335 | self.color_image = color_image
336 | self.color_label = color_label
337 | self.thermal_image = thermal_image
338 | self.thermal_label = thermal_label
339 | self.color_img_dict = color_img_dict
340 | self.thermal_img_dict = thermal_img_dict
341 | self.ids = list(self.color_img_dict.keys())
342 | self.transform = transforms_list
343 |
344 | def load_data(self, input_data_path):
345 | with open(input_data_path, 'rt') as f:
346 | data_file_list = f.read().splitlines()
347 | # Get full list of image and labels
348 | file_image = [s.split(' ')[0] for s in data_file_list]
349 | file_label = [int(s.split(' ')[1]) for s in data_file_list]
350 |
351 | return file_image, file_label
352 |
353 | def __getitem__(self, index):
354 |
355 | anchor_id = self.color_label[index]
356 | anchor_rgb = self.color_image[index]
357 |
358 | # dict entries are indices into the image lists, so filter positives by index
359 | positive_rgb = self.color_image[np.random.choice([x for x in self.color_img_dict[anchor_id] if x != index])]
360 | negative_id = np.random.choice([id for id in self.ids if id != anchor_id])
361 | negative_rgb = self.color_image[np.random.choice(self.color_img_dict[negative_id])]
362 |
363 | anchor_ir_idx = np.random.choice(self.thermal_img_dict[anchor_id])
364 | anchor_ir = self.thermal_image[anchor_ir_idx]
365 | positive_ir = self.thermal_image[np.random.choice([x for x in self.thermal_img_dict[anchor_id] if x != anchor_ir_idx])]
366 | negative_id = np.random.choice([id for id in self.ids if id != anchor_id])
367 | negative_ir = self.thermal_image[np.random.choice(self.thermal_img_dict[negative_id])]
368 | anchor_label = np.array(anchor_id)
369 |
370 | if self.transform is not None:
371 | anchor_rgb = self.transform(anchor_rgb)
372 | positive_rgb = self.transform(positive_rgb)
373 | negative_rgb = self.transform(negative_rgb)
374 |
375 | anchor_ir = self.transform(anchor_ir)
376 | positive_ir = self.transform(positive_ir)
377 | negative_ir = self.transform(negative_ir)
378 |
379 | modality_rgb = torch.tensor([1,0]).float()
380 | modality_ir = torch.tensor([0,1]).float()
381 |
382 | return anchor_rgb, positive_rgb, negative_rgb, anchor_ir, positive_ir, negative_ir, anchor_label, modality_rgb, modality_ir
383 |
384 | def __len__(self):
385 | return len(self.color_label)
386 |
387 |
388 | class RegDB_eval_datasets(object):
389 | def __init__(self, data_dir, transforms_list=None, mode='train', trial=1):
390 |
391 | if mode == 'train':
392 | self.visible_files = 'train_visible_' + str(trial) + '.txt'
393 | self.thermal_files = 'train_thermal_' + str(trial) + '.txt'
394 | elif mode == 'val':
395 | self.visible_files = 'test_visible_' + str(trial) + '.txt'
396 | self.thermal_files = 'test_thermal_' + str(trial) + '.txt'
397 | else:
398 | self.visible_files = 'test_visible_' + str(trial) + '.txt'
399 | self.thermal_files = 'test_thermal_' + str(trial) + '.txt'
400 |
401 |
402 | color_list = os.path.join(data_dir, 'idx', self.visible_files)
403 | thermal_list = os.path.join(data_dir, 'idx', self.thermal_files)
404 |
405 | color_img_file, color_label = self.load_data(color_list)
406 | thermal_img_file, thermal_label = self.load_data(thermal_list)
407 |
408 | color_image = []
409 | color_image_path = []
410 | for i in range(len(color_img_file)):
411 | img_path = os.path.join(data_dir, color_img_file[i])
412 | color_image_path.append(img_path)
413 | img = Image.open(img_path)
414 | img = img.resize(settings.inp_size[::-1])
415 | color_image.append((img, color_label[i], img_path))
416 |
417 |
418 | thermal_image = []
419 | thermal_image_path = []
420 | for i in range(len(thermal_img_file)):
421 | img_path = os.path.join(data_dir, thermal_img_file[i])
422 | thermal_image_path.append(img_path)
423 | img = Image.open(img_path)
424 | img = img.resize(settings.inp_size[::-1], Image.LANCZOS)
425 | thermal_image.append((img, thermal_label[i], img_path))
426 |
427 | # make dict
428 | color_img_dict = {}
429 | for i in range(len(color_label)):
430 | label = color_label[i]
431 | if label not in color_img_dict.keys():
432 | color_img_dict[label] = []
433 |
434 | thermal_img_dict = {}
435 | for i in range(len(thermal_label)):
436 | label = thermal_label[i]
437 | if label not in thermal_img_dict.keys():
438 | thermal_img_dict[label] = []
439 |
440 | color_ids = list(color_img_dict.keys())
441 | thermal_ids = list(thermal_img_dict.keys())
442 |
443 | query = thermal_image
444 | num_query_imgs = len(query)
445 | num_query_pids = len(thermal_ids)
446 |
447 | gallery = color_image
448 | num_gallery_pids = len(color_ids)
449 | num_gallery_imgs = len(gallery)
450 |
451 | num_total_pids = num_query_pids
452 | num_total_imgs = num_query_imgs + num_gallery_imgs
453 |
454 | print("Dataset statistics:")
455 | print(" ------------------------------")
456 | print(" subset | # ids | # images")
457 | print(" ------------------------------")
458 | print(" query | {:5d} | {:8d}".format(num_query_pids, num_query_imgs))
459 | print(" gallery | {:5d} | {:8d}".format(num_gallery_pids, num_gallery_imgs))
460 | print(" ------------------------------")
461 | print(" total | {:5d} | {:8d}".format(num_total_pids, num_total_imgs))
462 | print(" ------------------------------")
463 |
464 | self.query = query
465 | self.gallery = gallery
466 |
467 | self.num_query_pids = num_query_pids
468 | self.num_gallery_pids = num_gallery_pids
469 |
470 | def load_data(self, input_data_path):
471 | with open(input_data_path, 'rt') as f:
472 | data_file_list = f.read().splitlines()
473 | # Get full list of image and labels
474 | file_image = [s.split(' ')[0] for s in data_file_list]
475 | file_label = [int(s.split(' ')[1]) for s in data_file_list]
476 |
477 | return file_image, file_label
478 |
479 | class RegDB_wrapper(Dataset):
480 | """For evaluation"""
481 | def __init__(self, dataset, transform=None):
482 | self.dataset = dataset
483 | self.transform = transform
484 |
485 | def __len__(self):
486 | return len(self.dataset)
487 |
488 | def __getitem__(self, index):
489 | img, pid, img_path = self.dataset[index]
490 |
491 | if self.transform is not None:
492 | img = self.transform(img)
493 | return img, pid, img_path
494 |
495 | if __name__ == '__main__':
496 | dataset = RegDB_triplet_dataset(settings.regdb_dir, settings.transforms_list, trial=2)
497 | print(len(dataset))
498 | data = RegDB_eval_datasets(settings.regdb_dir, settings.test_transforms_list, trial=10)
499 | gallery_set = RegDB_wrapper(data.gallery)
500 | query_set = RegDB_wrapper(data.query)
501 | print(len(gallery_set))
502 |
503 |
504 |
505 |
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | """
2 | Angular Triplet Loss
3 | YE, Hanrong et al, Bi-directional Exponential Angular Triplet Loss for RGB-Infrared Person Re-Identification
4 | """
5 | from __future__ import print_function, absolute_import
6 | import numpy as np
7 | import copy
8 | from collections import defaultdict
9 | import sys
10 | import torch
11 | import matplotlib.pyplot as plt
12 | import pickle
13 |
14 |
15 | from IPython import embed
16 |
17 |
18 |
19 | def test(feature_generators, queryloader, galleryloader, use_gpu = True, ranks=[1, 5, 10, 20]):
20 | if type(feature_generators) is list:
21 | feature_generator_rgb = feature_generators[0]
22 | feature_generator_ir = feature_generators[1]
23 |
24 | else:
25 | feature_generator_rgb = feature_generators
26 | feature_generator_ir = feature_generators
27 | feature_generator_rgb.eval()
28 | feature_generator_ir.eval()
29 |
30 | with torch.no_grad():
31 | qf, q_pids, q_camids = [], [], []
32 | for batch_idx, (imgs, pids, camids) in enumerate(queryloader):
33 | if use_gpu: imgs = imgs.cuda()
34 | features = feature_generator_ir(imgs) # query features come from the IR branch
35 | features = features.data#.cpu()
36 |
37 | qf.append(features)
38 | q_pids.extend(pids)
39 | q_camids.extend(camids)
40 |
41 | qf = torch.cat(qf, 0)
42 | q_pids = np.asarray(q_pids)
43 | q_camids = np.asarray(q_camids)
44 |
45 | print("Extracted features for query set, obtained {}-by-{} matrix".format(qf.size(0), qf.size(1)))
46 |
47 | gf, g_pids, g_camids = [], [], []
48 | #end = time.time()
49 | for batch_idx, (imgs, pids, camids) in enumerate(galleryloader):
50 | if use_gpu: imgs = imgs.cuda()
51 |
52 | features = feature_generator_rgb(imgs)
53 | features = features.data#.cpu()
54 |
55 | gf.append(features)
56 | g_pids.extend(pids)
57 | g_camids.extend(camids)
58 | gf = torch.cat(gf, 0)
59 | g_pids = np.asarray(g_pids)
60 | g_camids = np.asarray(g_camids)
61 |
62 | print("Extracted features for gallery set, obtained {}-by-{} matrix".format(gf.size(0), gf.size(1)))
63 |
64 |
65 | qf = qf.view(qf.size(0),-1)
66 | gf = gf.view(gf.size(0),-1)
67 |
68 | # see norm
69 | q_norms = qf.norm(dim=1)
70 | print('q_norms:')
71 | print(q_norms)
72 |
73 | g_norms = gf.norm(dim=1)
74 | print('g_norms:')
75 | print(g_norms)
76 | m, n = qf.size(0), gf.size(0)
77 | distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \
78 | torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t()
79 | distmat.addmm_(qf, gf.t(), beta=1, alpha=-2) # squared Euclidean distances; the old addmm_(1, -2, ...) signature is deprecated
80 | distmat = distmat.cpu().numpy()
81 |
82 | print("Computing CMC and mAP")
83 | cmc, mAP = evaluate(distmat, q_pids, g_pids, q_camids, g_camids) # use_metric_cuhk03=args.use_metric_cuhk03)
84 |
85 | print("Results ----------")
86 | print("mAP: {:.1%}".format(mAP))
87 | print("CMC curve")
88 | for r in ranks:
89 | print("Rank-{:<3}: {:.1%}".format(r, cmc[r-1]))
90 | print("------------------")
91 |
92 | return distmat,cmc, mAP
93 |
94 | def evaluate(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=50):
95 | """Evaluation with SYSU metric
96 | Key: for each query identity in camera 3, its gallery images from camera 2 view are discarded.
97 | """
98 |
99 | num_q, num_g = distmat.shape
100 | if num_g < max_rank:
101 | max_rank = num_g
102 | print("Note: number of gallery samples is quite small, got {}".format(num_g))
103 | indices = np.argsort(distmat, axis=1)
104 |
105 | matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)
106 |
107 | # compute cmc curve for each query
108 | all_cmc = []
109 | all_AP = []
110 |
111 | num_valid_q = 0. # number of valid query
112 | for q_idx in range(num_q):
113 | # get query pid and camid
114 | q_pid = q_pids[q_idx]
115 | q_camid = q_camids[q_idx]
116 | # remove gallery samples that have the same pid and camid with query
117 | order = indices[q_idx]
118 | remove = (q_camid == 3) & (g_camids[order] == 2)
119 | keep = np.invert(remove)
120 |
121 |
122 | if q_idx == 0: # log the top-ranked gallery matches for the first query
123 | print('Query ID', q_pid)
124 | for g_idx in range(min(20, num_g)):
125 | print('Gallery ID Rank #', g_idx, ' : ', g_pids[order[g_idx]], 'distance : ', distmat[q_idx][order[g_idx]])
126 |
127 | # compute cmc curve
128 | orig_cmc = matches[q_idx][keep] # binary vector, positions with value 1 are correct matches
129 | if not np.any(orig_cmc):
130 | # this condition is true when query identity does not appear in gallery
131 | continue
132 |
133 | cmc = orig_cmc.cumsum()
134 | cmc[cmc > 1] = 1
135 |
136 | all_cmc.append(cmc[:max_rank])
137 | num_valid_q += 1.
138 |
139 | # compute average precision
140 | # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision
141 | num_rel = orig_cmc.sum()
142 | tmp_cmc = orig_cmc.cumsum()
143 | tmp_cmc = [x / (i+1.) for i, x in enumerate(tmp_cmc)]
144 |
145 | tmp_cmc = np.asarray(tmp_cmc) * orig_cmc
146 | AP = tmp_cmc.sum() / num_rel
147 | all_AP.append(AP)
148 |
149 | assert num_valid_q > 0, "Error: all query identities do not appear in gallery"
150 |
151 | all_cmc = np.asarray(all_cmc).astype(np.float32)
152 | all_cmc = all_cmc.sum(0) / num_valid_q
153 | mAP = np.mean(all_AP)
154 | return all_cmc, mAP
155 |
156 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | """
2 | Angular Triplet Loss
3 | YE, Hanrong et al, Bi-directional Exponential Angular Triplet Loss for RGB-Infrared Person Re-Identification
4 | """
5 |
6 | import settings
7 | import os
8 | os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
9 | os.environ["CUDA_VISIBLE_DEVICES"]=settings.device_id
10 | import sys
11 | import argparse
12 | import csv
13 | import numpy as np
14 | import time
15 | import torch
16 | from torch import nn
17 | import torch.nn.functional as F
18 | from torch.optim import Adam
19 | from torch.optim.lr_scheduler import MultiStepLR
20 | from torch.utils.data import DataLoader
21 | from torch.utils.tensorboard import SummaryWriter
22 | from criterions import expATLoss, CrossEntropyLabelSmoothLoss
23 | import torchvision.transforms as transforms
24 |
25 | logger = settings.logger
26 | torch.cuda.manual_seed_all(66)
27 | torch.manual_seed(66)
28 |
29 | from datasets import RegDB_triplet_dataset, RegDB_eval_datasets, Image_dataset,RegDB_wrapper
30 | import itertools
31 | import solver
32 | from models import IdClassifier, FeatureEmbedder, Base_rgb,Base_ir
33 | from eval import test, evaluate
34 |
35 |
36 |
37 | from IPython import embed
38 |
39 | best_rank1 = 0
40 |
41 |
42 |
43 | def ensure_dir(dir_path):
44 | if not os.path.isdir(dir_path):
45 | os.makedirs(dir_path)
46 |
47 | class Session:
48 | def __init__(self):
49 | self.log_dir = settings.log_dir
50 | self.model_dir = settings.model_dir
51 | ensure_dir(settings.log_dir)
52 | ensure_dir(settings.model_dir)
53 | logger.info('set log dir as %s' % settings.log_dir)
54 | logger.info('set model dir as %s' % settings.model_dir)
55 |
56 | ##################################### Import models ###########################
57 | self.feature_rgb_generator = Base_rgb(last_stride=1,model_path=settings.pretrained_model_path)
58 | self.feature_ir_generator = Base_ir(last_stride=1,model_path=settings.pretrained_model_path)
59 | self.feature_embedder = FeatureEmbedder(last_stride=1,model_path=settings.pretrained_model_path)
60 | self.id_classifier = IdClassifier()
61 |
62 | if torch.cuda.is_available():
63 | self.feature_rgb_generator.cuda()
64 | self.feature_ir_generator.cuda()
65 | self.feature_embedder.cuda()
66 | self.id_classifier.cuda()
67 |
68 | self.feature_rgb_generator = nn.DataParallel(self.feature_rgb_generator, device_ids=range(settings.num_gpu))
69 |
70 | self.feature_ir_generator = nn.DataParallel(self.feature_ir_generator, device_ids=range(settings.num_gpu))
71 | self.feature_embedder = nn.DataParallel(self.feature_embedder, device_ids=range(settings.num_gpu))
72 | self.id_classifier = nn.DataParallel(self.id_classifier, device_ids=range(settings.num_gpu))
73 |
74 | ############################# Get Losses & Optimizers #########################
75 | self.criterion_at = expATLoss()
76 | self.loss1 = torch.nn.MSELoss()
77 | self.criterion_identity = CrossEntropyLabelSmoothLoss(settings.num_classes, epsilon=0.1) #torch.nn.CrossEntropyLoss()
78 |
79 | opt_models = [self.feature_rgb_generator,
80 | self.feature_ir_generator,
81 | self.feature_embedder,
82 | self.id_classifier]
83 |
84 | def make_optimizer(opt_models):
85 | train_params = []
86 |
87 | for opt_model in opt_models:
88 | for key, value in opt_model.named_parameters():
89 | if not value.requires_grad:
90 | continue
91 | lr = settings.BASE_LR
92 | weight_decay = settings.WEIGHT_DECAY
93 | if "bias" in key:
94 | lr = settings.BASE_LR * settings.BIAS_LR_FACTOR
95 | weight_decay = settings.WEIGHT_DECAY_BIAS
96 | train_params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}]
97 |
98 |
99 | optimizer = torch.optim.Adam(train_params)
100 | return optimizer
101 |
102 | self.optimizer_G = make_optimizer(opt_models)
103 |
104 | self.epoch_count = 0
105 | self.step = 0
106 | self.save_steps = settings.save_steps
107 | self.num_workers = settings.num_workers
108 | self.writers = {}
109 | self.dataloaders = {}
110 |
111 | self.sche_G = solver.WarmupMultiStepLR(self.optimizer_G, milestones=settings.iter_sche, gamma=0.1) # default setting s
112 |
113 | def tensorboard(self, name):
114 | self.writers[name] = SummaryWriter(os.path.join(self.log_dir, name + '.events'))
115 | return self.writers[name]
116 |
117 |
118 | def write(self, name, out):
119 | for k, v in out.items():
120 | self.writers[name].add_scalar(name + '/' + k, v, self.step)
121 |
122 |
123 | out['G_lr'] = self.optimizer_G.param_groups[0]['lr']
124 | out['step'] = self.step
125 | out['epoch_count'] = self.epoch_count
126 | outputs = [
127 | "{}:{:.4g}".format(k, v)
128 | for k, v in out.items()
129 | ]
130 | logger.info(name + '--' + ' '.join(outputs))
131 |
132 | def save_checkpoints(self, name):
133 | ckp_path = os.path.join(self.model_dir, name)
134 | obj = {
135 | 'feature_rgb_generator': self.feature_rgb_generator.state_dict(),
136 | 'feature_ir_generator': self.feature_ir_generator.state_dict(),
137 | 'feature_embedder': self.feature_embedder.state_dict(),
138 | 'id_classifier': self.id_classifier.state_dict(),
139 | 'clock': self.step,
140 | 'epoch_count': self.epoch_count,
141 | 'opt_G': self.optimizer_G.state_dict(),
142 | }
143 | torch.save(obj, ckp_path)
144 |
145 | def load_checkpoints(self, name):
146 | ckp_path = os.path.join(self.model_dir, name)
147 | try:
148 | obj = torch.load(ckp_path)
149 | print('load checkpoint: %s' %ckp_path)
150 | except FileNotFoundError:
151 | return
152 | self.feature_rgb_generator.load_state_dict(obj['feature_rgb_generator'])
153 | self.feature_ir_generator.load_state_dict(obj['feature_ir_generator'])
154 | self.feature_embedder.load_state_dict(obj['feature_embedder'])
155 | self.id_classifier.load_state_dict(obj['id_classifier'])
156 | self.optimizer_G.load_state_dict(obj['opt_G'])
157 | self.step = obj['clock']
158 | self.epoch_count = obj['epoch_count']
159 | self.sche_G.last_epoch = self.step
160 |
161 |
162 | def load_checkpoints_delf_init(self, name):
163 | ckp_path = os.path.join(self.model_dir, name)
164 | obj = torch.load(ckp_path)
165 | self.backbone.load_state_dict(obj['backbone'])
166 |
167 | def cal_fea(self, x, domain_mode):
168 | if domain_mode == 'rgb':
169 | feat = self.feature_rgb_generator(x)
170 | return feat,self.feature_embedder(feat)
171 | elif domain_mode == 'ir':
172 | feat = self.feature_ir_generator(x)
173 | return feat,self.feature_embedder(feat)
174 |
175 |
176 | def inf_batch(self, batch):
177 | alpha = settings.alpha
178 | beta = settings.beta
179 |
180 | anchor_rgb, positive_rgb, negative_rgb, anchor_ir, positive_ir, \
181 | negative_ir, anchor_label, modality_rgb, modality_ir = batch
182 |
183 | if torch.cuda.is_available():
184 | anchor_rgb = anchor_rgb.cuda()
185 | positive_rgb = positive_rgb.cuda()
186 | negative_rgb = negative_rgb.cuda()
187 | anchor_ir = anchor_ir.cuda()
188 | positive_ir = positive_ir.cuda()
189 | negative_ir = negative_ir.cuda()
190 | anchor_label = anchor_label.cuda()
191 | anchor_rgb_features1, anchor_rgb_features2 = self.cal_fea(anchor_rgb, 'rgb')
192 | positive_rgb_features1, positive_rgb_features2 = self.cal_fea(positive_rgb, 'rgb')
193 | negative_rgb_features1, negative_rgb_features2 = self.cal_fea(negative_rgb, 'rgb')
194 |
195 | anchor_ir_features1, anchor_ir_features2 = self.cal_fea(anchor_ir, 'ir')
196 | positive_ir_features1, positive_ir_features2 = self.cal_fea(positive_ir, 'ir')
197 | negative_ir_features1, negative_ir_features2 = self.cal_fea(negative_ir, 'ir')
198 |
199 | lossx = self.loss1(anchor_rgb_features1, positive_ir_features1) + self.loss1(anchor_ir_features1,
200 | positive_rgb_features1)
201 | at_loss_rgb = self.criterion_at.forward(anchor_rgb_features2,
202 | positive_ir_features2, negative_rgb_features2)
203 |
204 | at_loss_ir = self.criterion_at.forward(anchor_ir_features2,
205 | positive_rgb_features2, negative_ir_features2)
206 |
207 | at_loss = at_loss_rgb + at_loss_ir + lossx
208 |
209 | predicted_id_rgb = self.id_classifier(anchor_rgb_features2)
210 | predicted_id_ir = self.id_classifier(anchor_ir_features2)
211 |
212 | identity_loss = self.criterion_identity(predicted_id_rgb, anchor_label) + \
213 | self.criterion_identity(predicted_id_ir, anchor_label)
214 |
215 | loss_G = alpha * at_loss + beta * identity_loss
216 |
217 | self.optimizer_G.zero_grad()
218 | loss_G.backward()
219 | self.optimizer_G.step()
220 |
221 | self.write('train_stats', {'loss_G': loss_G,
222 | 'at_loss': at_loss,
223 | 'identity_loss': identity_loss
224 | })
225 |
226 | def run_train_val(ckp_name='ckp_latest'):
227 | sess = Session()
228 | sess.load_checkpoints(ckp_name)
229 |
230 | sess.tensorboard('train_stats')
231 | sess.tensorboard('val_stats')
232 |
233 | ######################## Get Datasets & Dataloaders ###########################
234 |
235 | # get_train_dataloader() below builds its own RegDB_triplet_dataset, so no unused copy is loaded here
236 |
237 | def get_train_dataloader():
238 | return iter(DataLoader(RegDB_triplet_dataset(data_dir=settings.data_folder, transforms_list=settings.transforms_list), batch_size=settings.train_batch_size, shuffle=True,num_workers=settings.num_workers, drop_last = True))
239 |
240 | train_dataloader = get_train_dataloader()
241 |
242 | eval_val = RegDB_eval_datasets(settings.data_folder, settings.test_transforms_list, mode = 'val',trial=2)
243 |
244 | transform_test = settings.test_transforms_list
245 |
246 | val_queryloader = DataLoader(
247 | RegDB_wrapper(eval_val.query, transform=transform_test),
248 | batch_size=settings.val_batch_size, shuffle=False, num_workers=0,
249 | drop_last=False,
250 | )
251 |
252 | val_galleryloader = DataLoader(
253 | RegDB_wrapper(eval_val.gallery, transform=transform_test),
254 | batch_size=settings.val_batch_size, shuffle=False, num_workers=0,
255 | drop_last=False,
256 | )
257 |
258 | while sess.step < settings.iter_sche[-1]:
259 | sess.sche_G.step()
260 | sess.feature_rgb_generator.train()
261 | sess.feature_ir_generator.train()
262 | sess.feature_embedder.train()
263 |
264 | sess.id_classifier.train()
265 |
266 | try:
267 | batch_t = next(train_dataloader)
268 | except StopIteration:
269 | train_dataloader = get_train_dataloader()
270 | batch_t = next(train_dataloader)
271 | sess.epoch_count += 1
272 |
273 | sess.inf_batch(batch_t)
274 |
275 |
276 |
277 | if sess.step % settings.val_step ==0:
278 | sess.feature_rgb_generator.eval()
279 | sess.feature_ir_generator.eval()
280 | sess.feature_embedder.eval()
281 | sess.id_classifier.eval()
282 | _, test_ranks, test_mAP = test([nn.Sequential(sess.feature_rgb_generator, sess.feature_embedder), nn.Sequential(sess.feature_ir_generator, sess.feature_embedder)], val_queryloader, val_galleryloader) # test() returns (distmat, cmc, mAP)
283 | global best_rank1
284 | if best_rank1 < test_ranks[0] * 100.0:
285 | best_rank1 = test_ranks[0] * 100.0
286 | sess.save_checkpoints('ckp_latest')
287 | sess.save_checkpoints('ckp_latest_backup')
288 | sess.write('val_stats', {'test_mAP_percentage': test_mAP*100.0, \
289 | 'test_rank-1_accuracy_percentage':test_ranks[0]*100.0,\
290 | 'test_rank-5_accuracy_percentage':test_ranks[4]*100.0,\
291 | 'test_rank-10_accuracy_percentage':test_ranks[9]*100.0,\
292 | 'test_rank-20_accuracy_percentage':test_ranks[19]*100.0
293 | })
294 |
295 | if sess.step % sess.save_steps == 0:
296 | sess.save_checkpoints('ckp_step_%d' % sess.step)
297 | logger.info('save model as ckp_step_%d' % sess.step)
298 | sess.step += 1
299 |
300 |
301 | def run_test(ckp, setting):
302 | if ckp == 'all':
303 | models = sorted(os.listdir('../models/'))
304 | csvfile = open('all_test_results.csv', 'w')
305 | writer = csv.writer(csvfile)
306 |
307 | writer.writerow(['ckp_name', 'mAP', 'R1', 'R5', 'R10', 'R20'])
308 |
309 | for mm in models:
310 | result = test_ckp(mm, setting)
311 | writer.writerow(result)
312 |
313 | csvfile.close()
314 |
315 | else:
316 | test_ckp(ckp, setting)
317 |
318 |
319 | def test_ckp(ckp_name, setting):
320 | sess = Session()
321 | sess.load_checkpoints(ckp_name)
322 |
323 | search_mode = setting.split('_')[0] # 'all' or 'indoor'
324 | search_setting = setting.split('_')[1] # 'single' or 'multi'
325 |
326 | transform_test = settings.test_transforms_list
327 |
328 | results_ranks = np.zeros(50)
329 | results_map = np.zeros(1)
330 |
331 | for i in range(settings.test_times):
332 | eval_test = RegDB_eval_datasets(settings.data_folder, settings.test_transforms_list, trial=10)
333 |
334 | test_queryloader = DataLoader(
335 | RegDB_wrapper(eval_test.query, transform=transform_test),
336 | batch_size=settings.val_batch_size, shuffle=False, num_workers=0,
337 | drop_last=False,
338 | )
339 |
340 | test_galleryloader = DataLoader(
341 | RegDB_wrapper(eval_test.gallery, transform=transform_test),
342 | batch_size=settings.val_batch_size, shuffle=False, num_workers=0,
343 | drop_last=False,
344 | )
345 |
346 | distmat,test_ranks, test_mAP = test([nn.Sequential(sess.feature_rgb_generator, sess.feature_embedder), nn.Sequential(sess.feature_ir_generator, sess.feature_embedder)], test_queryloader, test_galleryloader)
347 | # embed() # interactive debug breakpoint; keep disabled for unattended test runs
348 | results_ranks += test_ranks
349 | results_map += test_mAP
350 |
351 | logger.info('Test no.{} for model {} in setting {}, Test mAP: {}, R1: {}, R5: {}, R10: {}, R20: {}'.format(i,
352 | ckp_name,
353 | setting,
354 | test_mAP*100.0,
355 | test_ranks[0]*100.0,
356 | test_ranks[4]*100.0,
357 | test_ranks[9]*100.0,
358 | test_ranks[19]*100.0))
359 |
360 |
361 | test_mAP = results_map / settings.test_times
362 | test_ranks = results_ranks / settings.test_times
363 | logger.info('For model {} in setting {}, AVG test mAP: {}, R1: {}, R5: {}, R10: {}, R20: {}'.format(ckp_name,
364 | setting,
365 | test_mAP*100.0,
366 | test_ranks[0]*100.0,
367 | test_ranks[4]*100.0,
368 | test_ranks[9]*100.0,
369 | test_ranks[19]*100.0))
370 |
371 | return [ckp_name, test_mAP*100.0, test_ranks[0]*100.0, test_ranks[4]*100.0, test_ranks[9]*100.0, test_ranks[19]*100.0]
372 |
373 |
374 | if __name__ == '__main__':
375 | parser = argparse.ArgumentParser()
376 | parser.add_argument('-a', '--action', default='train')
377 | parser.add_argument('-m', '--model', default='ckp_latest')
378 | parser.add_argument('-s', '--setting', default='all_single')
379 | args = parser.parse_args(sys.argv[1:])
380 |
381 | if args.action == 'train':
382 | run_train_val(args.model)
383 | elif args.action == 'test':
384 | run_test(args.model, args.setting)
385 |
386 |
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
1 | """
2 | Angular Triplet Loss
3 | YE, Hanrong et al, Bi-directional Exponential Angular Triplet Loss for RGB-Infrared Person Re-Identification
4 | """
5 |
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | from torchvision import models
9 | from backbone import ResNet
10 | import settings
11 | import torch
12 | import math
13 |
14 | class Normalize(nn.Module):
15 | def __init__(self, power=2):
16 | super(Normalize, self).__init__()
17 | self.power = power
18 |
19 | def forward(self, x):
20 | norm = x.pow(self.power).sum(1, keepdim=True).pow(1. / self.power)
21 | out = x.div(norm)
22 | return out
23 |
24 |
25 | class Non_local(nn.Module):
26 | def __init__(self, in_channels, reduc_ratio=2):
27 | super(Non_local, self).__init__()
28 |
29 | self.in_channels = in_channels
30 | self.inter_channels = in_channels // reduc_ratio # bottleneck width of the non-local block
31 |
32 | self.g = nn.Sequential(
33 | nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1,
34 | padding=0),
35 | )
36 |
37 | self.W = nn.Sequential(
38 | nn.Conv2d(in_channels=self.inter_channels, out_channels=self.in_channels,
39 | kernel_size=1, stride=1, padding=0),
40 | nn.BatchNorm2d(self.in_channels),
41 | )
42 | nn.init.constant_(self.W[1].weight, 0.0)
43 | nn.init.constant_(self.W[1].bias, 0.0)
44 |
45 |
46 |
47 | self.theta = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
48 | kernel_size=1, stride=1, padding=0)
49 |
50 | self.phi = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
51 | kernel_size=1, stride=1, padding=0)
52 |
53 | def forward(self, x):
54 | '''
55 | :param x: (b, c, t, h, w)
56 | :return:
57 | '''
58 |
59 | batch_size = x.size(0)
60 | g_x = self.g(x).view(batch_size, self.inter_channels, -1)
61 | g_x = g_x.permute(0, 2, 1)
62 |
63 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
64 | theta_x = theta_x.permute(0, 2, 1)
65 | phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)
66 | f = torch.matmul(theta_x, phi_x)
67 | N = f.size(-1)
68 | # f_div_C = torch.nn.functional.softmax(f, dim=-1)
69 | f_div_C = f / N
70 |
71 | y = torch.matmul(f_div_C, g_x)
72 | y = y.permute(0, 2, 1).contiguous()
73 | y = y.view(batch_size, self.inter_channels, *x.size()[2:])
74 | W_y = self.W(y)
75 | z = W_y + x
76 |
77 | return z
78 |
79 | class FeatureEmbedder(nn.Module):
80 | def __init__(self,last_stride,model_path,part=3):
81 | super(FeatureEmbedder, self).__init__()
82 | #self.gap = nn.AdaptiveAvgPool2d(1)
83 | self.bottleneck = nn.BatchNorm1d(2048)
84 | self.bottleneck.bias.requires_grad_(False) # no shift
85 | self.bottleneck.apply(weights_init_kaiming)
86 | self.base = ResNet(last_stride)
87 | self.base.load_param(model_path)
88 | layers = [3, 4, 6, 3]
89 | non_layers = [0, 2, 3, 0]
90 | self.part = part
91 | self.NL_2 = nn.ModuleList(
92 | [Non_local(512) for i in range(non_layers[1])])
93 | self.NL_2_idx = sorted([layers[1] - (i + 1) for i in range(non_layers[1])])
94 | self.NL_3 = nn.ModuleList(
95 | [Non_local(1024) for i in range(non_layers[2])])
96 | self.NL_3_idx = sorted([layers[2] - (i + 1) for i in range(non_layers[2])])
97 | self.NL_4 = nn.ModuleList(
98 | [Non_local(2048) for i in range(non_layers[3])])
99 | self.NL_4_idx = sorted([layers[3] - (i + 1) for i in range(non_layers[3])])
100 | def forward(self, x):
101 | NL2_counter = 0
102 | if len(self.NL_2_idx) == 0: self.NL_2_idx = [-1]
103 | for i in range(len(self.base.layer2)):
104 | x = self.base.layer2[i](x)
105 | if i == self.NL_2_idx[NL2_counter]:
106 | _, C, H, W = x.shape
107 | x = self.NL_2[NL2_counter](x)
108 | NL2_counter += 1
109 | # Layer 3
110 | NL3_counter = 0
111 | if len(self.NL_3_idx) == 0: self.NL_3_idx = [-1]
112 | for i in range(len(self.base.layer3)):
113 | x = self.base.layer3[i](x)
114 | if i == self.NL_3_idx[NL3_counter]:
115 | _, C, H, W = x.shape
116 | x = self.NL_3[NL3_counter](x)
117 | NL3_counter += 1
118 | # Layer 4
119 | NL4_counter = 0
120 | if len(self.NL_4_idx) == 0: self.NL_4_idx = [-1]
121 | for i in range(len(self.base.layer4)):
122 | x = self.base.layer4[i](x)
123 | if i == self.NL_4_idx[NL4_counter]:
124 | _, C, H, W = x.shape
125 | x = self.NL_4[NL4_counter](x)
126 | NL4_counter += 1
127 | b, c, h, w = x.shape
128 | y = x.view(b, c, -1)
129 | p = 3.0 # generalized-mean (GeM) pooling exponent
130 | global_feat = (torch.mean(y ** p, dim=-1) + 1e-12) ** (1 / p) # GeM pooling over the spatial dims
131 | feat = global_feat.view(global_feat.shape[0], -1) # flatten to (bs, 2048)
132 | bnfeat = self.bottleneck(feat) # normalize for angular softmax
133 | return bnfeat
134 |
135 | class IdClassifier(nn.Module):
136 | def __init__(self, in_planes = 2048, num_classes = settings.num_classes): # train 296, val 99
137 | super(IdClassifier, self).__init__()
138 | self.classifier = nn.Linear(in_planes, num_classes, bias=False)
139 | self.classifier.apply(weights_init_classifier)
140 | self.dropout = 0.5
141 | self.l2norm = Normalize(2)
142 | def forward(self, x):
143 | x = x.view(x.size(0), -1)
144 | if self.training:
145 | x = F.dropout(x,self.dropout,training = self.training)
146 | x = F.elu(x)
147 | else :
148 | x = self.l2norm(x)
149 | out = self.classifier(x)
150 | return out
151 |
152 | def weights_init_kaiming(m):
153 | classname = m.__class__.__name__
154 | if classname.find('Linear') != -1:
155 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out')
156 | nn.init.constant_(m.bias, 0.0)
157 | elif classname.find('Conv') != -1:
158 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
159 | if m.bias is not None:
160 | nn.init.constant_(m.bias, 0.0)
161 | elif classname.find('BatchNorm') != -1:
162 | if m.affine:
163 | nn.init.constant_(m.weight, 1.0)
164 | nn.init.constant_(m.bias, 0.0)
165 |
166 |
167 | def weights_init_classifier(m):
168 | classname = m.__class__.__name__
169 | if classname.find('Linear') != -1:
170 | nn.init.normal_(m.weight, std=0.001)
171 | if m.bias is not None:
172 | nn.init.constant_(m.bias, 0.0)
173 |
174 |
175 | class Baseline(nn.Module):
176 | def __init__(self, last_stride, model_path):
177 | super(Baseline, self).__init__()
178 | self.base = ResNet(last_stride)
179 | self.base.load_param(model_path)
180 |
181 | def forward(self, x):
182 | return self.base(x) # (b, 2048, 1, 1)
183 |
184 | class Base_rgb(nn.Module):
185 | def __init__(self, last_stride, model_path):
186 | super(Base_rgb, self).__init__()
187 | self.base = ResNet(last_stride)
188 | self.base.load_param(model_path)
189 |
190 | layers = [3, 4, 6, 3]
191 | non_layers = [0, 2, 3, 0]
192 | self.NL_1 = nn.ModuleList(
193 | [Non_local(256) for i in range(non_layers[0])])
194 | self.NL_1_idx = sorted([layers[0] - (i + 1) for i in range(non_layers[0])])
195 | def forward(self, x):
196 |
197 | x = self.base.conv1(x)
198 | x = self.base.bn1(x)
199 | x = self.base.relu(x)
200 | x = self.base.maxpool(x)
201 | NL1_counter = 0
202 | if len(self.NL_1_idx) == 0: self.NL_1_idx = [-1]
203 | for i in range(len(self.base.layer1)):
204 | x = self.base.layer1[i](x)
205 | if i == self.NL_1_idx[NL1_counter]:
206 | _, C, H, W = x.shape
207 | x = self.NL_1[NL1_counter](x)
208 | NL1_counter += 1
209 | return x
210 |
211 | class Base_ir(nn.Module):
212 | def __init__(self, last_stride, model_path):
213 | super(Base_ir, self).__init__()
214 | self.base = ResNet(last_stride)
215 | self.base.load_param(model_path)
216 |
217 | layers = [3, 4, 6, 3]
218 | non_layers = [0, 2, 3, 0]
219 | self.NL_1 = nn.ModuleList(
220 | [Non_local(256) for i in range(non_layers[0])])
221 | self.NL_1_idx = sorted([layers[0] - (i + 1) for i in range(non_layers[0])])
222 | def forward(self, x):
223 |
224 | x = self.base.conv1(x)
225 | x = self.base.bn1(x)
226 | x = self.base.relu(x)
227 | x = self.base.maxpool(x)
228 | NL1_counter = 0
229 | if len(self.NL_1_idx) == 0: self.NL_1_idx = [-1]
230 | for i in range(len(self.base.layer1)):
231 | x = self.base.layer1[i](x)
232 | if i == self.NL_1_idx[NL1_counter]:
233 | _, C, H, W = x.shape
234 | x = self.NL_1[NL1_counter](x)
235 | NL1_counter += 1
236 | return x
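237 |
238 |
239 | # Minimal wiring sketch (it assumes the ImageNet ResNet-50 weights referenced by
240 | # settings.pretrained_model_path are available, as in main.py): an RGB image
241 | # passes through its modality-specific stem (Base_rgb, conv1..layer1), then the
242 | # shared FeatureEmbedder (layer2..layer4 with non-local blocks, GeM pooling and
243 | # a BN bottleneck), and finally the IdClassifier.
244 | if __name__ == '__main__':
245 | rgb_stem = Base_rgb(last_stride=1, model_path=settings.pretrained_model_path)
246 | embedder = FeatureEmbedder(last_stride=1, model_path=settings.pretrained_model_path)
247 | classifier = IdClassifier()
248 | x = torch.randn(2, 3, *settings.inp_size) # (b, 3, 384, 128)
249 | feat = embedder(rgb_stem(x)) # (b, 2048)
250 | print(feat.shape, classifier(feat).shape) # -> (2, 2048) and (2, 296)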
--------------------------------------------------------------------------------
/settings.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | import transforms
4 |
5 | G_lr = 3e-4
6 | BASE_LR = 3e-4
7 | BIAS_LR_FACTOR = 2
8 | WEIGHT_DECAY = 0.0005
9 | WEIGHT_DECAY_BIAS = 0.
10 | D_lr = 1e-4
11 | iter_sche = [10000, 20000, 30000]
12 |
13 | train_batch_size = 8
14 | val_batch_size = 16
15 |
16 | log_dir = '../logdir'
17 | show_dir = '../showdir'
18 | model_dir = '../models'
19 | data_folder = '/home/ggw/HaoShao/dataset/RegDB'
20 | pretrained_model_path = '/home/ggw/.cache/torch/checkpoints/resnet50-19c8e357.pth'
21 |
22 | model_path = os.path.join(model_dir, 'latest')
23 | save_steps = 5000
24 | latest_steps = 100
25 | val_step = 200
26 |
27 | num_workers = 4
28 | num_gpu = 1
29 | device_id = '1'
30 | num_classes = 296
31 | test_times = 10 # official setting
32 |
33 | # for showing logger
34 | logger = logging.getLogger('train')
35 | logger.setLevel(logging.INFO)
36 |
37 | ch = logging.StreamHandler()
38 | ch.setLevel(logging.INFO)
39 |
40 | formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
41 | ch.setFormatter(formatter)
42 | logger.addHandler(ch)
43 |
44 |
45 | ############################# Hyper-parameters ################################
46 | alpha = 1.0
47 | beta = 1.0
48 | at_margin = 1
49 |
50 | pixel_mean = [0.485, 0.456, 0.406]
51 | pixel_std = [0.229, 0.224, 0.225]
52 | inp_size = [384, 128] # (height, width)
53 |
54 | # transforms
55 |
56 | transforms_list = transforms.Compose([transforms.RectScale(*inp_size),
57 | transforms.RandomHorizontalFlip(),
58 | transforms.Pad(10),
59 | transforms.RandomCrop(inp_size),
60 | transforms.ToTensor(),
61 | transforms.Normalize(mean=pixel_mean,
62 | std=pixel_std),
63 | transforms.RandomErasing(probability=0.5, mean=pixel_mean)])
64 |
65 | test_transforms_list = transforms.Compose([
66 | transforms.RectScale(*inp_size),
67 | transforms.ToTensor(),
68 | transforms.Normalize(mean=pixel_mean,
69 | std=pixel_std)])
70 |
71 |
--------------------------------------------------------------------------------
/solver.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: liaoxingyu
4 | @contact: sherlockliao01@gmail.com
5 | """
6 | from bisect import bisect_right
7 | import torch
8 |
9 |
10 | # FIXME ideally this would be achieved with a CombinedLRScheduler,
11 | # separating MultiStepLR with WarmupLR
12 | # but the current LRScheduler design doesn't allow it
13 |
14 | class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler):
15 | def __init__(
16 | self,
17 | optimizer,
18 | milestones,
19 | gamma=0.01,
20 | warmup_factor=1.0 / 3,
21 | warmup_iters=500,
22 | warmup_method="linear",
23 | last_epoch=-1,
24 | ):
25 | if not list(milestones) == sorted(milestones):
26 | raise ValueError(
27 | "Milestones should be a list of"
28 | " increasing integers. Got {}".format(milestones)
29 | )
30 |
31 | if warmup_method not in ("constant", "linear"):
32 | raise ValueError(
33 | "Only 'constant' or 'linear' warmup_method accepted, "
34 | "got {}".format(warmup_method)
35 | )
36 | self.milestones = milestones
37 | self.gamma = gamma
38 | self.warmup_factor = warmup_factor
39 | self.warmup_iters = warmup_iters
40 | self.warmup_method = warmup_method
41 | super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch)
42 |
43 | def get_lr(self):
44 | warmup_factor = 1
45 | if self.last_epoch < self.warmup_iters:
46 | if self.warmup_method == "constant":
47 | warmup_factor = self.warmup_factor
48 | elif self.warmup_method == "linear":
49 | alpha = self.last_epoch / self.warmup_iters
50 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha
51 | return [
52 | base_lr
53 | * warmup_factor
54 | * self.gamma ** bisect_right(self.milestones, self.last_epoch)
55 | for base_lr in self.base_lrs
56 | ]
57 |
--------------------------------------------------------------------------------
/tensorboard.sh:
--------------------------------------------------------------------------------
1 | #! /bin/bash
2 | function rand() {
3 | min=$1
4 | max=$(($2-$min+1))
5 | num=$(($RANDOM+1000000000000))
6 | echo $(($num%$max+$min))
7 | }
8 |
9 | #rnd=$(rand 3000 12000)
10 | tensorboard --logdir ../logdir --host 0.0.0.0 --port 17650 --reload_interval 3
11 |
--------------------------------------------------------------------------------
/transforms.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from torchvision.transforms import *
3 | from PIL import Image
4 | import random
5 | import math
6 |
7 |
8 | class RectScale(object):
9 | def __init__(self, height, width, interpolation=Image.BILINEAR):
10 | self.height = height
11 | self.width = width
12 | self.interpolation = interpolation
13 |
14 | def __call__(self, img):
15 | w, h = img.size
16 | if h == self.height and w == self.width:
17 | return img
18 | return img.resize((self.width, self.height), self.interpolation)
19 |
20 |
21 |
22 | class RandomSizedRectCrop(object):
23 | def __init__(self, height, width, interpolation=Image.BILINEAR):
24 | self.height = height
25 | self.width = width
26 | self.interpolation = interpolation
27 |
28 | def __call__(self, img):
29 | for attempt in range(10):
30 | area = img.size[0] * img.size[1]
31 | target_area = random.uniform(0.64, 1.0) * area
32 | aspect_ratio = random.uniform(2, 3)
33 |
34 | h = int(round(math.sqrt(target_area * aspect_ratio)))
35 | w = int(round(math.sqrt(target_area / aspect_ratio)))
36 |
37 | if w <= img.size[0] and h <= img.size[1]:
38 | x1 = random.randint(0, img.size[0] - w)
39 | y1 = random.randint(0, img.size[1] - h)
40 |
41 | img = img.crop((x1, y1, x1 + w, y1 + h))
42 | assert(img.size == (w, h))
43 |
44 | return img.resize((self.width, self.height), self.interpolation)
45 |
46 | # Fallback
47 | scale = RectScale(self.height, self.width,
48 | interpolation=self.interpolation)
49 | return scale(img)
50 |
51 |
52 | class RandomErasing(object):
53 | """ Randomly selects a rectangle region in an image and erases its pixels.
54 | 'Random Erasing Data Augmentation' by Zhong et al.
55 | See https://arxiv.org/pdf/1708.04896.pdf
56 | Args:
57 | probability: The probability that the Random Erasing operation will be performed.
58 | sl: Minimum proportion of erased area against input image.
59 | sh: Maximum proportion of erased area against input image.
60 | r1: Minimum aspect ratio of erased area.
61 | mean: Erasing value.
62 | """
63 |
64 | def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=(0.4914, 0.4822, 0.4465)):
65 | self.probability = probability
66 | self.mean = mean
67 | self.sl = sl
68 | self.sh = sh
69 | self.r1 = r1
70 |
71 | def __call__(self, img):
72 |
73 | if random.uniform(0, 1) > self.probability:
74 | return img
75 |
76 | for attempt in range(100):
77 | area = img.size()[1] * img.size()[2]
78 |
79 | target_area = random.uniform(self.sl, self.sh) * area
80 | aspect_ratio = random.uniform(self.r1, 1 / self.r1)
81 |
82 | h = int(round(math.sqrt(target_area * aspect_ratio)))
83 | w = int(round(math.sqrt(target_area / aspect_ratio)))
84 |
85 | if w < img.size()[2] and h < img.size()[1]:
86 | x1 = random.randint(0, img.size()[1] - h)
87 | y1 = random.randint(0, img.size()[2] - w)
88 | if img.size()[0] == 3:
89 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0]
90 | img[1, x1:x1 + h, y1:y1 + w] = self.mean[1]
91 | img[2, x1:x1 + h, y1:y1 + w] = self.mean[2]
92 | else:
93 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0]
94 | return img
95 |
96 | return img
97 |
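98 | # Quick demo (illustrative): RectScale resizes a PIL image to a fixed
99 | # height x width, while RandomErasing blanks a random rectangle of a CHW tensor.
100 | if __name__ == '__main__':
101 | import torch
102 | img = Image.new('RGB', (90, 200)) # PIL sizes are (width, height)
103 | print(RectScale(384, 128)(img).size) # -> (128, 384)
104 | erased = RandomErasing(probability=1.0)(torch.ones(3, 384, 128))
105 | print(erased.min().item() < 1.0) # almost surely True: a region was erased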
--------------------------------------------------------------------------------