├── README.md
├── reid
│   ├── __init__.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── bases.py
│   │   ├── cuhk03.py
│   │   ├── duke_mr_tkl.py
│   │   ├── duke_reid.py
│   │   ├── duke_si_tkl.py
│   │   ├── duke_vidreid.py
│   │   ├── ilids_vid.py
│   │   ├── market1501.py
│   │   ├── mars.py
│   │   ├── msmt17.py
│   │   └── prid2011.py
│   ├── dist_metric.py
│   ├── evaluation_metrics
│   │   ├── __init__.py
│   │   ├── classification.py
│   │   └── ranking.py
│   ├── evaluators_image.py
│   ├── evaluators_video.py
│   ├── feature_extraction
│   │   ├── __init__.py
│   │   ├── cnn.py
│   │   └── database.py
│   ├── loss
│   │   ├── __init__.py
│   │   └── poset_G2G.py
│   ├── metric_learning
│   │   ├── __init__.py
│   │   ├── euclidean.py
│   │   └── kissme.py
│   ├── models
│   │   ├── __init__.py
│   │   └── resnet_mt.py
│   ├── trainers.py
│   └── utils
│       ├── __init__.py
│       ├── data
│       │   ├── __init__.py
│       │   ├── dataset.py
│       │   ├── preprocessor_image.py
│       │   ├── preprocessor_video.py
│       │   ├── sampler.py
│       │   ├── sampler_mt.py
│       │   └── transforms.py
│       ├── iotools.py
│       ├── logging.py
│       ├── meters.py
│       ├── osutils.py
│       └── serialization.py
├── scripts
│   ├── cuhk03.sh
│   ├── duke_mr_tkl.sh
│   ├── duke_reid.sh
│   ├── duke_si_tkl.sh
│   ├── ilids-vid.sh
│   ├── market1501.sh
│   ├── mars.sh
│   ├── msmt17.sh
│   └── prid2011.sh
├── taudl_image.py
└── taudl_video.py

/README.md:
--------------------------------------------------------------------------------
# Tracklet Association Unsupervised Deep Learning (TAUDL)
PyTorch implementation for our paper [Link](http://openaccess.thecvf.com/content_ECCV_2018/papers/Minxian_Li_Unsupervised_Person_Re-identification_ECCV_2018_paper.pdf). This code is based on the [Open-ReID](https://github.com/Cysu/open-reid) library.

## Citation
Please cite the following paper in your publications if it helps your research:

```
@inproceedings{li2018unsupervised,
  title={Unsupervised person re-identification by deep learning tracklet association},
  author={Li, Minxian and Zhu, Xiatian and Gong, Shaogang},
  booktitle={Proceedings of the European Conference on Computer Vision (ECCV)},
  pages={737--753},
  year={2018}
}
```
--------------------------------------------------------------------------------
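As a quick orientation, a minimal smoke test of the package layout above (a sketch assuming the repository root is on `PYTHONPATH`; per-dataset launch commands live in the `scripts/` shell files):

```python
# Minimal sketch: verify the package imports and report its version.
import reid

print(reid.__version__)       # '0.2.0', set in reid/__init__.py
print(reid.datasets.names())  # factory keys from reid/datasets/__init__.py
```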
/reid/__init__.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import

from . import datasets
from . import evaluation_metrics
from . import feature_extraction
from . import loss
from . import metric_learning
from . import models
from . import utils
from . import dist_metric
# from . import evaluators
from . import evaluators_video
# from . import trainers_mt
# from . import trainers_us

__version__ = '0.2.0'
--------------------------------------------------------------------------------
/reid/datasets/__init__.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import
import warnings

from .cuhk03 import CUHK03
from .market1501 import Market1501
from .duke_reid import DukeMTMC_reID
from .msmt17 import MSMT17

from .prid2011 import PRID2011
from .ilids_vid import iLIDS_VID
from .mars import Mars
from .duke_vidreid import DukeMTMC_VidReID
from .duke_si_tkl import DukeMTMC_SITKL
from .duke_mr_tkl import DukeMTMC_MRTKL


__factory = {
    # image-based
    'CUHK03': CUHK03,
    'Market1501': Market1501,
    'DukeMTMC-reID': DukeMTMC_reID,
    'MSMT17': MSMT17,

    # video-based
    'PRID2011': PRID2011,
    'iLIDS-VID': iLIDS_VID,
    'Mars': Mars,
    'DukeMTMC-VideoReID': DukeMTMC_VidReID,
    'DukeMTMC-SI-Tracklet': DukeMTMC_SITKL,
    'DukeMTMC-MR-Tracklet': DukeMTMC_MRTKL
}


def names():
    return sorted(__factory.keys())


def create(name, root, *args, **kwargs):
    """
    Create a dataset instance.

    Parameters
    ----------
    name : str
        The dataset name. Can be one of 'CUHK03', 'Market1501',
        'DukeMTMC-reID', 'MSMT17', 'PRID2011', 'iLIDS-VID', 'Mars',
        'DukeMTMC-VideoReID', 'DukeMTMC-SI-Tracklet' and
        'DukeMTMC-MR-Tracklet'.
    root : str
        The path to the dataset directory.
    split_id : int, optional
        The index of data split. Default: 0
    num_val : int or float, optional
        When int, it means the number of validation identities. When float,
        it means the proportion of validation to all the trainval. Default: 100
    download : bool, optional
        If True, will download the dataset. Default: False
    """
    if name not in __factory:
        raise KeyError("Unknown dataset: {}".format(name))
    return __factory[name](root, *args, **kwargs)


def get_dataset(name, root, *args, **kwargs):
    warnings.warn("get_dataset is deprecated. Use create instead.")
    return create(name, root, *args, **kwargs)
--------------------------------------------------------------------------------
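A minimal usage sketch of the factory above ('CUHK03' is a real factory key; the local dataset path is hypothetical):

```python
from reid import datasets

# Instantiate a dataset by name; extra args are forwarded to the class.
cuhk03 = datasets.create('CUHK03', '/data/cuhk03', split_id=0)
print(cuhk03.num_train_pids, len(cuhk03.query), len(cuhk03.gallery))
```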
Use create instead.") 64 | return create(name, root, *args, **kwargs) 65 | -------------------------------------------------------------------------------- /reid/datasets/bases.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | 4 | import numpy as np 5 | import ipdb 6 | 7 | class BaseDataset(object): 8 | """ 9 | Base class of reid dataset 10 | """ 11 | 12 | def get_imagedata_info(self, data): 13 | pids, cams = [], [] 14 | for _, _, pid, tidpc, _, camid in data: 15 | pids += [pid] 16 | cams += [camid] 17 | pids = set(pids) 18 | cams = set(cams) 19 | num_pids = len(pids) 20 | num_cams = len(cams) 21 | num_imgs = len(data) 22 | 23 | num_tidspc = [] 24 | for cam in cams: 25 | tid_index_list = [index for index, (_, _, _, _, _, camid) in enumerate(data) if camid == cam] 26 | num_tidspc.append(len(tid_index_list)) 27 | return num_pids, num_imgs, num_cams, num_tidspc 28 | 29 | def get_videodata_info(self, data, return_tracklet_stats=False): 30 | pids, cams, tracklet_stats = [], [], [] 31 | for img_paths, tid, pid, tid_sub, pid_sub, camid in data: 32 | pids += [pid] 33 | cams += [camid] 34 | tracklet_stats += [len(img_paths)] 35 | pids = set(pids) 36 | cams = set(cams) 37 | num_pids = len(pids) 38 | num_cams = len(cams) 39 | num_tkls = len(data) 40 | 41 | num_tkls_sub, num_pids_sub = [], [] 42 | tkl_count_per_pid_sub = [] 43 | for cam in range(num_cams): 44 | indexes = [index for index, (_, _, _, _, _, camid) in enumerate(data) if camid == cam] 45 | tids_sub = [data[index][3] for index in indexes] 46 | pids_sub = [data[index][2] for index in indexes] 47 | pid_sub_list = list(set(pids_sub)) 48 | num_tkls_sub.append(len(tids_sub)) 49 | num_pids_sub.append(len(pid_sub_list)) 50 | 51 | # count tkls number per pid_sub 52 | for pid_sub in pid_sub_list: 53 | if not pid_sub == 702: 54 | tkl_count_per_pid_sub.append(pids_sub.count(pid_sub)) 55 | 56 | # with open('duketkl_tkls_count.txt', 'w') as f: 57 | # for item in tkl_count_per_pid_sub: 58 | # f.write("%s\n" % item) 59 | 60 | if return_tracklet_stats: 61 | return num_pids, num_tkls, num_cams, tracklet_stats 62 | return num_pids, num_tkls, num_cams, num_pids_sub, num_tkls_sub 63 | 64 | def print_dataset_statistics(self): 65 | raise NotImplementedError 66 | 67 | 68 | class BaseImageDataset(BaseDataset): 69 | """ 70 | Base class of image reid dataset 71 | """ 72 | 73 | def print_dataset_statistics(self, train, query, gallery): 74 | num_train_pids, num_train_tracklets, num_train_cams, train_tracklet_stats = \ 75 | self.get_videodata_info(train, return_tracklet_stats=True) 76 | 77 | num_query_pids, num_query_tracklets, num_query_cams, query_tracklet_stats = \ 78 | self.get_videodata_info(query, return_tracklet_stats=True) 79 | 80 | num_gallery_pids, num_gallery_tracklets, num_gallery_cams, gallery_tracklet_stats = \ 81 | self.get_videodata_info(gallery, return_tracklet_stats=True) 82 | 83 | tracklet_stats = train_tracklet_stats + query_tracklet_stats + gallery_tracklet_stats 84 | min_num = np.min(tracklet_stats) 85 | max_num = np.max(tracklet_stats) 86 | avg_num = np.mean(tracklet_stats) 87 | 88 | print("Dataset statistics:") 89 | print(" -------------------------------------------") 90 | print(" subset | # ids | # tracklets | # cameras") 91 | print(" -------------------------------------------") 92 | print(" train | {:5d} | {:11d} | {:9d}".format(num_train_pids, num_train_tracklets, num_train_cams)) 93 | print(" query | {:5d} | {:11d} | 
{:9d}".format(num_query_pids, num_query_tracklets, num_query_cams)) 94 | print(" gallery | {:5d} | {:11d} | {:9d}".format(num_gallery_pids, num_gallery_tracklets, num_gallery_cams)) 95 | print(" -------------------------------------------") 96 | print(" number of images per tracklet: {} ~ {}, average {:.2f}".format(min_num, max_num, avg_num)) 97 | print(" -------------------------------------------") 98 | 99 | 100 | class BaseVideoDataset(BaseDataset): 101 | """ 102 | Base class of video reid dataset 103 | """ 104 | 105 | def print_dataset_statistics(self, train, query, gallery): 106 | num_train_pids, num_train_tracklets, num_train_cams, train_tracklet_stats = \ 107 | self.get_videodata_info(train, return_tracklet_stats=True) 108 | 109 | num_query_pids, num_query_tracklets, num_query_cams, query_tracklet_stats = \ 110 | self.get_videodata_info(query, return_tracklet_stats=True) 111 | 112 | num_gallery_pids, num_gallery_tracklets, num_gallery_cams, gallery_tracklet_stats = \ 113 | self.get_videodata_info(gallery, return_tracklet_stats=True) 114 | 115 | tracklet_stats = train_tracklet_stats + query_tracklet_stats + gallery_tracklet_stats 116 | min_num = np.min(tracklet_stats) 117 | max_num = np.max(tracklet_stats) 118 | avg_num = np.mean(tracklet_stats) 119 | 120 | print("Dataset statistics:") 121 | print(" -------------------------------------------") 122 | print(" subset | # ids | # tracklets | # cameras") 123 | print(" -------------------------------------------") 124 | print(" train | {:5d} | {:11d} | {:9d}".format(num_train_pids, num_train_tracklets, num_train_cams)) 125 | print(" query | {:5d} | {:11d} | {:9d}".format(num_query_pids, num_query_tracklets, num_query_cams)) 126 | print(" gallery | {:5d} | {:11d} | {:9d}".format(num_gallery_pids, num_gallery_tracklets, num_gallery_cams)) 127 | print(" -------------------------------------------") 128 | print(" number of images per tracklet: {} ~ {}, average {:.2f}".format(min_num, max_num, avg_num)) 129 | print(" -------------------------------------------") -------------------------------------------------------------------------------- /reid/datasets/cuhk03.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | import glob 7 | import re 8 | import sys 9 | import urllib 10 | import tarfile 11 | import zipfile 12 | import os.path as osp 13 | from scipy.io import loadmat 14 | import numpy as np 15 | import h5py 16 | from scipy.misc import imsave 17 | from ..utils.iotools import mkdir_if_missing, write_json, read_json 18 | 19 | import ipdb 20 | 21 | 22 | class CUHK03(): 23 | """ 24 | CUHK03 25 | Reference: 26 | Li et al. DeepReID: Deep Filter Pairing Neural Network for Person Re-identification. CVPR 2014. 27 | URL: http://www.ee.cuhk.edu.hk/~xgwang/CUHK_identification.html#! 
/reid/datasets/cuhk03.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import glob
import re
import sys
import urllib
import tarfile
import zipfile
import os.path as osp
from scipy.io import loadmat
import numpy as np
import h5py
from scipy.misc import imsave  # removed in SciPy >= 1.2; see the compatibility note after this file
from ..utils.iotools import mkdir_if_missing, write_json, read_json

import ipdb


class CUHK03():
    """
    CUHK03
    Reference:
    Li et al. DeepReID: Deep Filter Pairing Neural Network for Person Re-identification. CVPR 2014.
    URL: http://www.ee.cuhk.edu.hk/~xgwang/CUHK_identification.html#!

    Dataset statistics:
    # identities: 1360
    # images: 13164
    # cameras: 6
    # splits: 20 (classic)
    Args:
        split_id (int): split index (default: 0)
        cuhk03_labeled (bool): whether to load labeled images; if false, detected images are loaded (default: False)
    """

    def __init__(self, root, split_id=0, cuhk03_labeled=False, cuhk03_classic_split=True, min_seq_len=0):
        self.dataset_dir = osp.join(root, '')
        self.data_dir = osp.join(self.dataset_dir, 'cuhk03_release')
        self.raw_mat_path = osp.join(self.data_dir, 'cuhk-03.mat')

        self.imgs_detected_dir = osp.join(self.dataset_dir, 'images_detected')
        self.imgs_labeled_dir = osp.join(self.dataset_dir, 'images_labeled')

        self.split_classic_det_json_path = osp.join(self.dataset_dir, 'splits_classic_detected.json')
        self.split_classic_lab_json_path = osp.join(self.dataset_dir, 'splits_classic_labeled.json')

        self.split_new_det_json_path = osp.join(self.dataset_dir, 'splits_new_detected.json')
        self.split_new_lab_json_path = osp.join(self.dataset_dir, 'splits_new_labeled.json')

        self.split_new_det_mat_path = osp.join(self.dataset_dir, 'cuhk03_new_protocol_config_detected.mat')
        self.split_new_lab_mat_path = osp.join(self.dataset_dir, 'cuhk03_new_protocol_config_labeled.mat')

        self._check_before_run()
        self._preprocess()

        if cuhk03_labeled:
            image_type = 'labeled'
            split_path = self.split_classic_lab_json_path if cuhk03_classic_split else self.split_new_lab_json_path
        else:
            image_type = 'detected'
            split_path = self.split_classic_det_json_path if cuhk03_classic_split else self.split_new_det_json_path

        splits = read_json(split_path)
        assert split_id < len(splits), "Condition split_id ({}) < len(splits) ({}) is false".format(split_id, len(splits))
        split = splits[split_id]
        print("Split index = {}".format(split_id))

        # the raw format (based on imgs) of each element is [img_path, pid, cam]
        train = split['train']
        query = split['query']
        gallery = split['gallery']

        # get train set
        tid_start = 0
        train_set, num_tkls_train, num_persons_train, \
            num_tkls_pc_train, num_persons_pc_train, \
            trainval_len_pertkl, trainval_len_pertkl_percam = \
            self.Build_Set(train, relabel=True, min_seq_len=min_seq_len, tid_start=tid_start)
        # get query set
        tid_start = 0
        query_set, num_tkls_query, num_persons_query, \
            num_tkls_pc_query, num_pids_pc_query, \
            query_len_pertkl, query_len_pertkl_percam = \
            self.Build_Set(query, relabel=False, min_seq_len=min_seq_len, tid_start=tid_start)
        # get gallery set
        tid_start = num_tkls_query
        gallery_set, num_tkls_gallery, num_persons_gallery, \
            num_tkls_pc_gallery, num_pids_pc_gallery, \
            gallery_len_pertkl, gallery_len_pertkl_percam = \
            self.Build_Set(gallery, relabel=False, min_seq_len=min_seq_len, tid_start=tid_start)

        num_total_pids = num_persons_train + num_persons_query
        num_train_imgs = sum(trainval_len_pertkl)
        num_query_imgs = sum(query_len_pertkl)
        num_gallery_imgs = sum(gallery_len_pertkl)
        num_total_imgs = num_train_imgs + num_query_imgs + num_gallery_imgs

        self.train = train_set
        self.num_train_pids = num_persons_train
        self.num_train_pids_sub = num_persons_pc_train
        self.num_train_tids_sub = num_tkls_pc_train
        self.query = query_set
        self.gallery = gallery_set

    def _check_before_run(self):
        """Check if all files are available before going deeper"""
        if not osp.exists(self.dataset_dir):
            raise RuntimeError("'{}' is not available".format(self.dataset_dir))
        if not osp.exists(self.data_dir):
            raise RuntimeError("'{}' is not available".format(self.data_dir))
        if not osp.exists(self.raw_mat_path):
            raise RuntimeError("'{}' is not available".format(self.raw_mat_path))
        if not osp.exists(self.split_new_det_mat_path):
            raise RuntimeError("'{}' is not available".format(self.split_new_det_mat_path))
        if not osp.exists(self.split_new_lab_mat_path):
            raise RuntimeError("'{}' is not available".format(self.split_new_lab_mat_path))

    def _preprocess(self):
        """
        This function is a bit complex and ugly. What it does is:
        1. Extract data from cuhk-03.mat and save as png images.
        2. Create 20 classic splits. (Li et al. CVPR'14)
        3. Create new split. (Zhong et al. CVPR'17)
        """
        print("Note: if root path is changed, the previously generated json files need to be re-generated (delete them first)")
        if osp.exists(self.imgs_labeled_dir) and \
                osp.exists(self.imgs_detected_dir) and \
                osp.exists(self.split_classic_det_json_path) and \
                osp.exists(self.split_classic_lab_json_path) and \
                osp.exists(self.split_new_det_json_path) and \
                osp.exists(self.split_new_lab_json_path):
            return

        mkdir_if_missing(self.imgs_detected_dir)
        mkdir_if_missing(self.imgs_labeled_dir)

        print("Extract image data from {} and save as png".format(self.raw_mat_path))
        mat = h5py.File(self.raw_mat_path, 'r')

        def _deref(ref):
            return mat[ref][:].T

        def _process_images(img_refs, campid, pid, save_dir):
            img_paths = []  # Note: some persons only have images for one view
            for imgid, img_ref in enumerate(img_refs):
                img = _deref(img_ref)
                # skip empty cell
                if img.size == 0 or img.ndim < 3: continue
                # images are saved with the following format, index-1 (ensures uniqueness)
                # campid: index of camera pair (1-5)
                # pid: index of person in 'campid'-th camera pair
                # viewid: index of view, {1, 2}
                # imgid: index of image, (1-10)
                viewid = 1 if imgid < 5 else 2
                img_name = '{:01d}_{:03d}_{:01d}_{:02d}.png'.format(campid + 1, pid + 1, viewid, imgid + 1)
                img_path = osp.join(save_dir, img_name)
                imsave(img_path, img)
                img_paths.append(img_path)
            return img_paths

        def _extract_img(name):
            print("Processing {} images (extract and save) ...".format(name))
            meta_data = []
            imgs_dir = self.imgs_detected_dir if name == 'detected' else self.imgs_labeled_dir
            for campid, camp_ref in enumerate(mat[name][0]):
                camp = _deref(camp_ref)
                num_pids = camp.shape[0]
                for pid in range(num_pids):
                    img_paths = _process_images(camp[pid, :], campid, pid, imgs_dir)
                    assert len(img_paths) > 0, "campid{}-pid{} has no images".format(campid, pid)
                    meta_data.append((campid + 1, pid + 1, img_paths))
                print("done camera pair {} with {} identities".format(campid + 1, num_pids))
            return meta_data

        meta_detected = _extract_img('detected')
        meta_labeled = _extract_img('labeled')

        def _extract_classic_split(meta_data, test_split):
            train, test = [], []
            num_train_pids, num_test_pids = 0, 0
            num_train_imgs, num_test_imgs = 0, 0
            for i, (campid, pid, img_paths) in enumerate(meta_data):

                if [campid, pid] in test_split:
                    for img_path in img_paths:
                        camid = int(osp.basename(img_path).split('_')[2])
                        test.append((img_path, num_test_pids, camid))
                    num_test_pids += 1
                    num_test_imgs += len(img_paths)
                else:
                    for img_path in img_paths:
                        camid = int(osp.basename(img_path).split('_')[2])
                        train.append((img_path, num_train_pids, camid))
                    num_train_pids += 1
                    num_train_imgs += len(img_paths)
            return train, num_train_pids, num_train_imgs, test, num_test_pids, num_test_imgs

        print("Creating classic splits (# = 20) ...")
        splits_classic_det, splits_classic_lab = [], []
        for split_ref in mat['testsets'][0]:
            test_split = _deref(split_ref).tolist()

            # create split for detected images
            train, num_train_pids, num_train_imgs, test, num_test_pids, num_test_imgs = \
                _extract_classic_split(meta_detected, test_split)
            splits_classic_det.append({
                'train': train, 'query': test, 'gallery': test,
                'num_train_pids': num_train_pids, 'num_train_imgs': num_train_imgs,
                'num_query_pids': num_test_pids, 'num_query_imgs': num_test_imgs,
                'num_gallery_pids': num_test_pids, 'num_gallery_imgs': num_test_imgs,
            })

            # create split for labeled images
            train, num_train_pids, num_train_imgs, test, num_test_pids, num_test_imgs = \
                _extract_classic_split(meta_labeled, test_split)
            splits_classic_lab.append({
                'train': train, 'query': test, 'gallery': test,
                'num_train_pids': num_train_pids, 'num_train_imgs': num_train_imgs,
                'num_query_pids': num_test_pids, 'num_query_imgs': num_test_imgs,
                'num_gallery_pids': num_test_pids, 'num_gallery_imgs': num_test_imgs,
            })

        write_json(splits_classic_det, self.split_classic_det_json_path)
        write_json(splits_classic_lab, self.split_classic_lab_json_path)

        def _extract_set(filelist, pids, pid2label, idxs, img_dir, relabel):
            tmp_set = []
            unique_pids = set()
            for idx in idxs:
                img_name = filelist[idx][0]
                camid = int(img_name.split('_')[2])
                pid = pids[idx]
                if relabel: pid = pid2label[pid]
                img_path = osp.join(img_dir, img_name)
                tmp_set.append((img_path, int(pid), camid))
                unique_pids.add(pid)
            return tmp_set, len(unique_pids), len(idxs)

        def _extract_new_split(split_dict, img_dir):
            train_idxs = split_dict['train_idx'].flatten() - 1  # index-0
            pids = split_dict['labels'].flatten()
            train_pids = set(pids[train_idxs])
            pid2label = {pid: label for label, pid in enumerate(train_pids)}
            query_idxs = split_dict['query_idx'].flatten() - 1
            gallery_idxs = split_dict['gallery_idx'].flatten() - 1
            filelist = split_dict['filelist'].flatten()
            train_info = _extract_set(filelist, pids, pid2label, train_idxs, img_dir, relabel=True)
            query_info = _extract_set(filelist, pids, pid2label, query_idxs, img_dir, relabel=False)
            gallery_info = _extract_set(filelist, pids, pid2label, gallery_idxs, img_dir, relabel=False)
            return train_info, query_info, gallery_info

        print("Creating new splits for detected images (767/700) ...")
        train_info, query_info, gallery_info = _extract_new_split(
            loadmat(self.split_new_det_mat_path),
            self.imgs_detected_dir,
        )
        splits = [{
            'train': train_info[0], 'query': query_info[0], 'gallery': gallery_info[0],
            'num_train_pids': train_info[1], 'num_train_imgs': train_info[2],
            'num_query_pids': query_info[1], 'num_query_imgs': query_info[2],
            'num_gallery_pids': gallery_info[1], 'num_gallery_imgs': gallery_info[2],
        }]
        write_json(splits, self.split_new_det_json_path)

        print("Creating new splits for labeled images (767/700) ...")
        train_info, query_info, gallery_info = _extract_new_split(
            loadmat(self.split_new_lab_mat_path),
            self.imgs_labeled_dir,
        )
        splits = [{
            'train': train_info[0], 'query': query_info[0], 'gallery': gallery_info[0],
            'num_train_pids': train_info[1], 'num_train_imgs': train_info[2],
            'num_query_pids': query_info[1], 'num_query_imgs': query_info[2],
            'num_gallery_pids': gallery_info[1], 'num_gallery_imgs': gallery_info[2],
        }]
        write_json(splits, self.split_new_lab_json_path)

    def Build_Set(self, dataset_raw, relabel=False, min_seq_len=0, tid_start=0):
        # the format of dataset_raw is [img_path, pid, cam]

        # Step_1: compute the tracklet number of dataset_raw
        pid_container = set()
        cam_container = set()
        for i, (_, pid, cam) in enumerate(dataset_raw):
            pid_container.add(pid)
            cam_container.add(cam)
        num_pids = len(pid_container)
        num_cams = len(cam_container)
        num_imgs = len(dataset_raw)

        # Step_2: get the dataset (based on tkl)
        dataset = []
        len_pertkl = []
        tid = tid_start
        for pid_idx, pid in enumerate(pid_container):
            for cam_idx, cam in enumerate(cam_container):
                img_names = []
                for i in range(num_imgs):
                    if dataset_raw[i][1] == pid and dataset_raw[i][2] == cam:
                        img_names.append(dataset_raw[i][0])
                if len(img_names) > min_seq_len:
                    img_names = tuple(img_names)
                    tid_pc = -1
                    pid_pc = -1
                    cam = cam - 1  # camera indices start from 0
                    dataset.append((img_names, tid, pid, tid_pc, pid_pc, cam))
                    tid += 1
                    len_pertkl.append(len(img_names))
        num_tkls = len(dataset)

        # ----------------------------- Next: get the tid_pc and pid_pc -----------------------------#
        num_tkls_pc = []
        num_pids_pc = []
        len_pertkl_pc = []
        for c in range(num_cams):
            # count tid per camera
            tkl_index_pc = [index for index, (_, _, _, _, _, camid) in enumerate(dataset) if camid == c]
            num_tkls_pc.append(len(tkl_index_pc))

            # count pid per camera
            pid_list_pc = [dataset[i][2] for i in tkl_index_pc]
            unique_pid_list_pc = list(set(pid_list_pc))
            num_pids_pc.append(len(unique_pid_list_pc))

            # count image number per tracklet
            len_pertkl_pc.append([len_pertkl[i] for i in tkl_index_pc])

            pid_percam2label = {pid: label for label, pid in enumerate(unique_pid_list_pc)}
            for i, tkl_index in enumerate(tkl_index_pc):
                tid = dataset[tkl_index][1]
                pid = dataset[tkl_index][2]
                tid_pc = i
                pid_pc = pid_percam2label[pid]
                cam = dataset[tkl_index][5]
                dataset[tkl_index] = (dataset[tkl_index][0], tid, pid, tid_pc, pid_pc, cam)
        assert num_tkls == sum(num_tkls_pc)
        # # check if pid starts from 0 and increments with 1
        # for idx, pid in enumerate(pid_container):
        #     assert idx == pid, "See code comment for explanation"
        return dataset, num_tkls, num_pids, num_tkls_pc, num_pids_pc, len_pertkl, len_pertkl_pc
--------------------------------------------------------------------------------
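Note: cuhk03.py (and duke_mr_tkl.py below) import `imsave` from `scipy.misc`, which was removed in SciPy 1.2. A small compatibility sketch, assuming `imageio` is installed, that keeps the old name working:

```python
# Compatibility shim: scipy.misc.imsave is gone in SciPy >= 1.2.0.
try:
    from scipy.misc import imsave  # SciPy < 1.2
except ImportError:
    from imageio import imwrite as imsave  # near drop-in replacement
```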
/reid/datasets/duke_mr_tkl.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import glob
import re
import sys
import urllib
import tarfile
import zipfile
import os.path as osp
from scipy.io import loadmat
import numpy as np
import h5py
from scipy.misc import imsave  # removed in SciPy >= 1.2; see the compatibility note above

from reid.utils.iotools import mkdir_if_missing, write_json, read_json
from .bases import BaseVideoDataset

import ipdb


class DukeMTMC_MRTKL(BaseVideoDataset):
    """
    DukeMTMC-MR-Tracklet

    Reference:
    Wu et al. Exploit the Unknown Gradually: One-Shot Video-Based Person
    Re-Identification by Stepwise Learning. CVPR 2018.

    URL: https://github.com/Yu-Wu/DukeMTMC-VideoReID

    Dataset statistics:
    # identities: 702 (train) + 1110 (test, 702 query)
    # tracklets: 2196 (train) + 2636 (test)
    """
    dataset_dir = ''

    def __init__(self, root='data', min_seq_len=0, verbose=True, **kwargs):
        self.dataset_dir = osp.join(root, self.dataset_dir)
        # self.dataset_url = ''
        self.train_dir = osp.join(self.dataset_dir, 'train')
        self.query_dir = osp.join(self.dataset_dir, 'query')
        self.gallery_dir = osp.join(self.dataset_dir, 'gallery')
        self.split_train_json_path = osp.join(self.dataset_dir, 'info/train_ext_filter.json')
        self.split_query_json_path = osp.join(self.dataset_dir, 'info/query_ext.json')
        self.split_gallery_json_path = osp.join(self.dataset_dir, 'info/gallery_ext.json')

        self.min_seq_len = min_seq_len
        # self._download_data()
        self._check_before_run()
        print("Note: if root path is changed, the previously generated json files need to be re-generated (so delete them first)")

        train = self._process_train_dir(self.train_dir, self.split_train_json_path, relabel=True, filter=False)
        query = self._process_test_dir(self.query_dir, self.split_query_json_path, relabel=False)
        gallery = self._process_test_dir(self.gallery_dir, self.split_gallery_json_path, relabel=False)

        if verbose:
            print("=> DukeMTMC-MR-Tracklet loaded")
            self.print_dataset_statistics(train, query, gallery)

        self.train = train
        self.query = query
        self.gallery = gallery

        self.num_train_pids, self.num_train_tkls, self.num_train_cams, self.num_train_pids_sub, self.num_train_tids_sub \
            = self.get_videodata_info(self.train)
        self.num_query_pids, self.num_query_tkls, self.num_query_cams, _, _ = self.get_videodata_info(self.query)
        self.num_gallery_pids, self.num_gallery_tkls, self.num_gallery_cams, _, _ = self.get_videodata_info(self.gallery)

    # def _download_data(self):
    #     if osp.exists(self.dataset_dir):
    #         print("This dataset has been downloaded.")
    #         return
    #
    #     print("Creating directory {}".format(self.dataset_dir))
    #     mkdir_if_missing(self.dataset_dir)
    #     fpath = osp.join(self.dataset_dir, osp.basename(self.dataset_url))
    #
    #     print("Downloading DukeMTMC-VideoReID dataset")
    #     urllib.urlretrieve(self.dataset_url, fpath)
    #
    #     print("Extracting files")
    #     zip_ref = zipfile.ZipFile(fpath, 'r')
    #     zip_ref.extractall(self.dataset_dir)
    #     zip_ref.close()

    def _check_before_run(self):
        """Check if all files are available before going deeper"""
        if not osp.exists(self.dataset_dir):
            raise RuntimeError("'{}' is not available".format(self.dataset_dir))
        if not osp.exists(self.train_dir):
            raise RuntimeError("'{}' is not available".format(self.train_dir))
        if not osp.exists(self.query_dir):
            raise RuntimeError("'{}' is not available".format(self.query_dir))
        if not osp.exists(self.gallery_dir):
            raise RuntimeError("'{}' is not available".format(self.gallery_dir))

    def _process_train_dir(self, dir_path, json_path, relabel, tid_start=0, filter=True):
        if osp.exists(json_path):
            print("=> {} generated before, awesome!".format(json_path))
            split = read_json(json_path)
            return split['tracklets']

        print("=> Automatically generating split (might take a while for the first time, have a coffee)")
        pdirs = sorted(glob.glob(osp.join(dir_path, '*')))  # avoid .DS_Store
        print("Processing '{}' with {} person identities".format(dir_path, len(pdirs)))

        pid_container = set()
        for pdir in pdirs:
            pid = int(osp.basename(pdir))
            pid_container.add(pid)
        pid2label = {pid: label for label, pid in enumerate(pid_container)}

        cam_list = []
        tracklets = []
        tid = tid_start  # tid is the key in video Re-ID (as the fname in image Re-ID)
        for pdir in pdirs:
            pid = int(osp.basename(pdir))
            if filter:
                if pid == -1:
                    continue
            if relabel: pid = pid2label[pid]
            tdirs = sorted(glob.glob(osp.join(pdir, '*')))

            for tdir in tdirs:
                raw_img_paths = sorted(glob.glob(osp.join(tdir, '*.jpg')))
                num_imgs = len(raw_img_paths)
                if num_imgs < self.min_seq_len:
                    continue

                img_paths = sorted(glob.glob(osp.join(tdir, '*.jpg')))
                img_name = osp.basename(img_paths[0])
                # new naming format: 0001_C6_030823.jpg
                camid = int(img_name[6]) - 1

                img_paths = tuple(img_paths)
                tid_sub = -1
                pid_sub = -1
                tracklets.append((img_paths, tid, pid, tid_sub, pid_sub, camid))
                tid += 1
                cam_list.append(camid)

        num_cams = len(list(set(cam_list)))
        start_tid_uic = 0
        for cam_index in range(num_cams):
            # count tid per camera
            tkl_index_list = [index for index, (_, _, _, _, _, camid) in enumerate(tracklets) if camid == cam_index]
            # count pid per camera
            pid_list_sub = [tracklets[j][2] for j in tkl_index_list]
            unique_pid_list_percam = list(set(pid_list_sub))
            start_tid_uic += len(unique_pid_list_percam)

            pid_percam2label = {pid: label for label, pid in enumerate(unique_pid_list_percam)}

            for index, tkl_index in enumerate(tkl_index_list):
                img_paths = tracklets[tkl_index][0]
                tid = tracklets[tkl_index][1]
                pid = tracklets[tkl_index][2]
                tid_sub = index
                pid_sub = pid_percam2label[pid]
                camid = tracklets[tkl_index][5]
                tracklets[tkl_index] = (img_paths, tid, pid, tid_sub, pid_sub, camid)

        print("Saving split to {}".format(json_path))
        split_dict = {
            'tracklets': tracklets,
        }
        write_json(split_dict, json_path)
        return tracklets

    def _process_test_dir(self, dir_path, json_path, relabel):
        if osp.exists(json_path):
            print("=> {} generated before, awesome!".format(json_path))
            split = read_json(json_path)
            return split['tracklets']

        print("=> Automatically generating split (might take a while for the first time, have a coffee)")
        pdirs = sorted(glob.glob(osp.join(dir_path, '*')))  # avoid .DS_Store
        print("Processing '{}' with {} person identities".format(dir_path, len(pdirs)))

        pid_container = set()
        for pdir in pdirs:
            pid = int(osp.basename(pdir))
            pid_container.add(pid)
        pid2label = {pid: label for label, pid in enumerate(pid_container)}

        cam_list = []
        tracklets = []
        for pdir in pdirs:
            pid = int(osp.basename(pdir))
            if relabel: pid = pid2label[pid]
            tdirs = sorted(glob.glob(osp.join(pdir, '*')))

            for tdir in tdirs:
                tid = int(osp.basename(tdir)[3:])
                raw_img_paths = sorted(glob.glob(osp.join(tdir, '*.jpg')))
                num_imgs = len(raw_img_paths)
                if num_imgs < self.min_seq_len:
                    continue

                img_paths = sorted(glob.glob(osp.join(tdir, '*.jpg')))
                img_name = osp.basename(img_paths[0])
                # new naming format: 0001_C6_030823.jpg
                camid = int(img_name[6]) - 1
                img_paths = tuple(img_paths)
                tid_sub = -1
                pid_sub = -1
                tracklets.append((img_paths, tid, pid, tid_sub, pid_sub, camid))
                cam_list.append(camid)

        num_cams = len(list(set(cam_list)))
        start_tid_uic = 0
        for cam_index in range(num_cams):
            # count tid per camera
            tkl_index_list = [index for index, (_, _, _, _, _, camid) in enumerate(tracklets) if camid == cam_index]
            # count pid per camera
            pid_list_sub = [tracklets[j][2] for j in tkl_index_list]
            unique_pid_list_percam = list(set(pid_list_sub))
            start_tid_uic += len(unique_pid_list_percam)

            pid_percam2label = {pid: label for label, pid in enumerate(unique_pid_list_percam)}

            for index, tkl_index in enumerate(tkl_index_list):
                img_paths = tracklets[tkl_index][0]
                tid = tracklets[tkl_index][1]
                pid = tracklets[tkl_index][2]
                tid_sub = index
                pid_sub = pid_percam2label[pid]
                camid = tracklets[tkl_index][5]
                tracklets[tkl_index] = (img_paths, tid, pid, tid_sub, pid_sub, camid)

        print("Saving split to {}".format(json_path))
        split_dict = {
            'tracklets': tracklets,
        }
        write_json(split_dict, json_path)

        return tracklets
--------------------------------------------------------------------------------
/reid/datasets/duke_reid.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import glob
import re
import os.path as osp
from .bases import BaseImageDataset


class DukeMTMC_reID(BaseImageDataset):
    """
    DukeMTMC-reID
    Reference:
    1. Ristani et al. Performance Measures and a Data Set for Multi-Target, Multi-Camera Tracking. ECCVW 2016.
    2. Zheng et al. Unlabeled Samples Generated by GAN Improve the Person Re-identification Baseline in vitro. ICCV 2017.
    URL: https://github.com/layumi/DukeMTMC-reID_evaluation

    Dataset statistics:
    # identities: 1404 (train + query)
    # images: 16522 (train) + 2228 (query) + 17661 (gallery)
    # cameras: 8
    """

    def __init__(self, root='data', verbose=True, **kwargs):
        super(DukeMTMC_reID, self).__init__()
        self.dataset_dir = root
        self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train')
        self.query_dir = osp.join(self.dataset_dir, 'query')
        self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test')

        self._check_before_run()

        train = self._struct_data(self.train_dir, relabel=True)
        query = self._struct_data(self.query_dir, relabel=False)
        gallery = self._struct_data(self.gallery_dir, relabel=False)

        if verbose:
            print("=> DukeMTMC-reID loaded")
            self.print_dataset_statistics(train, query, gallery)

        self.train = train
        self.query = query
        self.gallery = gallery

        self.num_train_pids, self.num_train_imgs, self.num_train_cams, self.num_train_tids_sub = self.get_imagedata_info(self.train)
        self.num_query_pids, self.num_query_imgs, self.num_query_cams, self.num_query_tids_sub = self.get_imagedata_info(self.query)
        self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams, self.num_gallery_tids_sub = self.get_imagedata_info(self.gallery)

    def _check_before_run(self):
        """Check if all files are available before going deeper"""
        if not osp.exists(self.dataset_dir):
            raise RuntimeError("'{}' is not available".format(self.dataset_dir))
        if not osp.exists(self.train_dir):
            raise RuntimeError("'{}' is not available".format(self.train_dir))
        if not osp.exists(self.query_dir):
            raise RuntimeError("'{}' is not available".format(self.query_dir))
        if not osp.exists(self.gallery_dir):
            raise RuntimeError("'{}' is not available".format(self.gallery_dir))

    def _struct_data(self, dir_path, relabel=False):
        img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
        pattern = re.compile(r'([-\d]+)_c(\d)')

        pid_container = set()
        for img_path in img_paths:
            pid, _ = map(int, pattern.search(img_path).groups())
            if pid == -1: continue  # junk images are just ignored
            pid_container.add(pid)
        pid2label = {pid: label for label, pid in enumerate(pid_container)}

        img_dataset = []
        tid = -1
        tid_pc = -1
        pid_pc = -1
        cam_list = []
        for img_path in img_paths:
            pid, camid = map(int, pattern.search(img_path).groups())
            assert 1 <= camid <= 8
            camid -= 1  # index starts from 0
            if relabel: pid = pid2label[pid]
            cam_list.append(camid)
            img_dataset.append((img_path, tid, pid, tid_pc, pid_pc, camid))

        tkl_dataset = []
        tid = -1
        cam_list = list(set(cam_list))
        for cam in cam_list:
            tid_index_list = [index for index, (_, _, _, _, _, camid) in enumerate(img_dataset) if camid == cam]
            pid_list_pc = [img_dataset[i][2] for i in tid_index_list]
            unique_pid_list_pc = list(set(pid_list_pc))
            pid_percam2label = {pid: label for label, pid in enumerate(unique_pid_list_pc)}

            for insert_pid in unique_pid_list_pc:
                img_index_list = [index for index, (_, _, pid, _, _, camid) in enumerate(img_dataset) if camid == cam and pid == insert_pid]
                img_names = tuple([osp.join(img_dataset[index][0]) for index in img_index_list])
                pid_pc = pid_percam2label[insert_pid]
                tid_pc = pid_pc
                tid += 1
                pid = insert_pid
                tkl_dataset.append((img_names, tid, pid, tid_pc, pid_pc, cam))
        return tkl_dataset
--------------------------------------------------------------------------------
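`_struct_data` above folds an image-level dataset into one pseudo-tracklet per (camera, identity) pair. The grouping idea in isolation, on toy records rather than the author's exact code:

```python
from collections import defaultdict

# Toy records of (img_path, pid, camid): one pseudo-tracklet per (camid, pid).
records = [('a.jpg', 3, 0), ('b.jpg', 3, 0), ('c.jpg', 3, 1)]
groups = defaultdict(list)
for img_path, pid, camid in records:
    groups[(camid, pid)].append(img_path)
pseudo_tracklets = [(tuple(paths), pid, camid) for (camid, pid), paths in groups.items()]
print(pseudo_tracklets)  # [(('a.jpg', 'b.jpg'), 3, 0), (('c.jpg',), 3, 1)]
```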
/reid/datasets/duke_si_tkl.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


import glob
import os.path as osp
from reid.utils.iotools import mkdir_if_missing, write_json, read_json
from .bases import BaseVideoDataset

import ipdb


class DukeMTMC_SITKL(BaseVideoDataset):
    """
    DukeMTMC-SI-Tracklet

    Reference:
    Minxian Li, Xiatian Zhu, Shaogang Gong. Unsupervised Tracklet Person Re-Identification. TPAMI 2019.

    URL: https://github.com/liminxian/DukeMTMC-SI-Tracklet

    Dataset statistics:
    # identities: 702 (train) + 701 (query) + 1086 (gallery)
    # tracklets: 5803 (train) + 701 (query) + 6143 (gallery)
    """
    dataset_dir = ''

    def __init__(self, root='data', min_seq_len=0, verbose=True, **kwargs):
        self.dataset_dir = osp.join(root, self.dataset_dir)
        # self.dataset_url = ''
        self.train_dir = osp.join(self.dataset_dir, 'train')
        self.query_dir = osp.join(self.dataset_dir, 'query')
        self.gallery_dir = osp.join(self.dataset_dir, 'gallery')
        self.split_train_json_path = osp.join(self.dataset_dir, 'info/train_ext.json')
        self.split_query_json_path = osp.join(self.dataset_dir, 'info/query_ext.json')
        self.split_gallery_json_path = osp.join(self.dataset_dir, 'info/gallery_ext.json')

        self.min_seq_len = min_seq_len
        # self._download_data()
        self._check_before_run()
        print("Note: if root path is changed, the previously generated json files need to be re-generated (so delete them first)")

        train = self._process_train_dir(self.train_dir, self.split_train_json_path, relabel=True)
        query = self._process_test_dir(self.query_dir, self.split_query_json_path, relabel=False)
        gallery = self._process_test_dir(self.gallery_dir, self.split_gallery_json_path, relabel=False)

        if verbose:
            print("=> DukeMTMC-SI-Tracklet loaded")
            self.print_dataset_statistics(train, query, gallery)

        self.train = train
        self.query = query
        self.gallery = gallery

        self.num_train_pids, self.num_train_tkls, self.num_train_cams, self.num_train_pids_sub, self.num_train_tids_sub \
            = self.get_videodata_info(self.train)
        self.num_query_pids, self.num_query_tkls, self.num_query_cams, _, _ = self.get_videodata_info(self.query)
        self.num_gallery_pids, self.num_gallery_tkls, self.num_gallery_cams, _, _ = self.get_videodata_info(self.gallery)


    # def _download_data(self):
    #     if osp.exists(self.dataset_dir):
    #         print("This dataset has been downloaded.")
    #         return
    #
    #     print("Creating directory {}".format(self.dataset_dir))
    #     mkdir_if_missing(self.dataset_dir)
    #     fpath = osp.join(self.dataset_dir, osp.basename(self.dataset_url))
    #
    #     print("Downloading DukeMTMC-VideoReID dataset")
    #     urllib.urlretrieve(self.dataset_url, fpath)
    #
    #     print("Extracting files")
    #     zip_ref = zipfile.ZipFile(fpath, 'r')
    #     zip_ref.extractall(self.dataset_dir)
    #     zip_ref.close()

    def _check_before_run(self):
        """Check if all files are available before going deeper"""
        if not osp.exists(self.dataset_dir):
            raise RuntimeError("'{}' is not available".format(self.dataset_dir))
        if not osp.exists(self.train_dir):
            raise RuntimeError("'{}' is not available".format(self.train_dir))
        if not osp.exists(self.query_dir):
            raise RuntimeError("'{}' is not available".format(self.query_dir))
        if not osp.exists(self.gallery_dir):
            raise RuntimeError("'{}' is not available".format(self.gallery_dir))

    def _process_train_dir(self, dir_path, json_path, relabel, tid_start=0, filter=True):
        if osp.exists(json_path):
            print("=> {} generated before, awesome!".format(json_path))
            split = read_json(json_path)
            return split['tracklets']

        print("=> Automatically generating split (might take a while for the first time, have a coffee)")
        pdirs = sorted(glob.glob(osp.join(dir_path, '*')))  # avoid .DS_Store
        print("Processing '{}' with {} person identities".format(dir_path, len(pdirs)))

        pid_container = set()
        for pdir in pdirs:
            pid = int(osp.basename(pdir))
            pid_container.add(pid)
        pid2label = {pid: label for label, pid in enumerate(pid_container)}

        cam_list = []
        tracklets = []
        tid = tid_start  # tid is the key in video Re-ID (as the fname in image Re-ID)
        for pdir in pdirs:
            pid = int(osp.basename(pdir))
            if relabel: pid = pid2label[pid]
            tdirs = sorted(glob.glob(osp.join(pdir, '*')))

            for tdir in tdirs:
                raw_img_paths = sorted(glob.glob(osp.join(tdir, '*.jpg')))
                num_imgs = len(raw_img_paths)
                if num_imgs < self.min_seq_len:
                    continue

                img_paths = sorted(glob.glob(osp.join(tdir, '*.jpg')))
                img_name = osp.basename(img_paths[0])
                # naming format: c6_030823.jpg
                camid = int(img_name[1]) - 1

                img_paths = tuple(img_paths)
                tid_sub = -1
                pid_sub = -1
                tracklets.append((img_paths, tid, pid, tid_sub, pid_sub, camid))
                tid += 1
                cam_list.append(camid)

        num_cams = len(list(set(cam_list)))
        start_tid_uic = 0
        for cam_index in range(num_cams):
            # count tid per camera
            tkl_index_list = [index for index, (_, _, _, _, _, camid) in enumerate(tracklets) if camid == cam_index]
            # count pid per camera
            pid_list_sub = [tracklets[j][2] for j in tkl_index_list]
            unique_pid_list_percam = list(set(pid_list_sub))
            start_tid_uic += len(unique_pid_list_percam)

            pid_percam2label = {pid: label for label, pid in enumerate(unique_pid_list_percam)}

            for index, tkl_index in enumerate(tkl_index_list):
                img_paths = tracklets[tkl_index][0]
                tid = tracklets[tkl_index][1]
                pid = tracklets[tkl_index][2]
                tid_sub = index
                pid_sub = pid_percam2label[pid]
                camid = tracklets[tkl_index][5]
                tracklets[tkl_index] = (img_paths, tid, pid, tid_sub, pid_sub, camid)

        print("Saving split to {}".format(json_path))
        split_dict = {
            'tracklets': tracklets,
        }
        write_json(split_dict, json_path)
        return tracklets

    def _process_test_dir(self, dir_path, json_path, relabel):
        if osp.exists(json_path):
            print("=> {} generated before, awesome!".format(json_path))
            split = read_json(json_path)
            return split['tracklets']

        print("=> Automatically generating split (might take a while for the first time, have a coffee)")
        pdirs = sorted(glob.glob(osp.join(dir_path, '*')))  # avoid .DS_Store
        print("Processing '{}' with {} person identities".format(dir_path, len(pdirs)))

        pid_container = set()
        for pdir in pdirs:
            pid = int(osp.basename(pdir))
            pid_container.add(pid)
        pid2label = {pid: label for label, pid in enumerate(pid_container)}

        cam_list = []
        tracklets = []
        for pdir in pdirs:
            pid = int(osp.basename(pdir))
            if relabel: pid = pid2label[pid]
            tdirs = sorted(glob.glob(osp.join(pdir, '*')))

            for tdir in tdirs:
                tid = int(osp.basename(tdir)[3:])
                raw_img_paths = sorted(glob.glob(osp.join(tdir, '*.jpg')))
                num_imgs = len(raw_img_paths)
                if num_imgs < self.min_seq_len:
                    continue

                img_paths = sorted(glob.glob(osp.join(tdir, '*.jpg')))
                img_name = osp.basename(img_paths[0])
                # naming format: c6_030823.jpg
                camid = int(img_name[1]) - 1
                img_paths = tuple(img_paths)
                tid_sub = -1
                pid_sub = -1
                tracklets.append((img_paths, tid, pid, tid_sub, pid_sub, camid))
                cam_list.append(camid)

        num_cams = len(list(set(cam_list)))
        start_tid_uic = 0
        for cam_index in range(num_cams):
            # count tid per camera
            tkl_index_list = [index for index, (_, _, _, _, _, camid) in enumerate(tracklets) if camid == cam_index]
            # count pid per camera
            pid_list_sub = [tracklets[j][2] for j in tkl_index_list]
            unique_pid_list_percam = list(set(pid_list_sub))
            start_tid_uic += len(unique_pid_list_percam)

            pid_percam2label = {pid: label for label, pid in enumerate(unique_pid_list_percam)}

            for index, tkl_index in enumerate(tkl_index_list):
                img_paths = tracklets[tkl_index][0]
                tid = tracklets[tkl_index][1]
                pid = tracklets[tkl_index][2]
                tid_sub = index
                pid_sub = pid_percam2label[pid]
                camid = tracklets[tkl_index][5]
                tracklets[tkl_index] = (img_paths, tid, pid, tid_sub, pid_sub, camid)

        print("Saving split to {}".format(json_path))
        split_dict = {
            'tracklets': tracklets,
        }
        write_json(split_dict, json_path)

        return tracklets
--------------------------------------------------------------------------------
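The per-camera relabelling block repeated in every `_process_*` method above reduces to this pattern (toy values; the original uses an unordered `list(set(...))`, sorted here only for a stable example):

```python
# Map the person ids seen within one camera to a compact 0..K-1 label space.
pids_sub = [17, 5, 17, 42]                      # pids of tracklets in one camera
unique_pids = sorted(set(pids_sub))             # sorted for reproducibility
pid_percam2label = {pid: label for label, pid in enumerate(unique_pids)}
print([pid_percam2label[p] for p in pids_sub])  # [1, 0, 1, 2]
```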
/reid/datasets/duke_vidreid.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import glob
import zipfile
import os.path as osp
from reid.utils.iotools import mkdir_if_missing, write_json, read_json
from .bases import BaseVideoDataset
import ipdb


class DukeMTMC_VidReID(BaseVideoDataset):
    """
    DukeMTMCVidReID

    Reference:
    Wu et al. Exploit the Unknown Gradually: One-Shot Video-Based Person
    Re-Identification by Stepwise Learning. CVPR 2018.

    URL: https://github.com/Yu-Wu/DukeMTMC-VideoReID

    Dataset statistics:
    # identities: 702 (train) + 702 (test)
    # tracklets: 2196 (train) + 2636 (test)
    """
    dataset_dir = ''

    def __init__(self, root='data', min_seq_len=0, verbose=True, **kwargs):
        self.dataset_dir = osp.join(root, self.dataset_dir)
        self.dataset_url = 'http://vision.cs.duke.edu/DukeMTMC/data/misc/DukeMTMC-VideoReID.zip'
        self.train_dir = osp.join(self.dataset_dir, 'train')
        self.query_dir = osp.join(self.dataset_dir, 'query')
        self.gallery_dir = osp.join(self.dataset_dir, 'gallery')
        self.split_train_json_path = osp.join(self.dataset_dir, 'info/split_train_ext.json')
        self.split_query_json_path = osp.join(self.dataset_dir, 'info/split_query_ext.json')
        self.split_gallery_json_path = osp.join(self.dataset_dir, 'info/split_gallery_ext.json')

        self.min_seq_len = min_seq_len
        self._download_data()
        self._check_before_run()
        print("Note: if root path is changed, the previously generated json files need to be re-generated (so delete them first)")

        train = self._process_dir(self.train_dir, self.split_train_json_path, relabel=True)
        query = self._process_dir(self.query_dir, self.split_query_json_path, relabel=False)
        gallery = self._process_dir(self.gallery_dir, self.split_gallery_json_path, relabel=False)

        if verbose:
            print("=> DukeMTMC-VideoReID loaded")
            self.print_dataset_statistics(train, query, gallery)

        self.train = train
        self.query = query
        self.gallery = gallery

        self.num_train_pids, self.num_train_tkls, self.num_train_cams, self.num_train_pids_sub, self.num_train_tids_sub \
            = self.get_videodata_info(self.train)
        self.num_query_pids, self.num_query_tkls, self.num_query_cams, _, _ = self.get_videodata_info(self.query)
        self.num_gallery_pids, self.num_gallery_tkls, self.num_gallery_cams, _, _ = self.get_videodata_info(self.gallery)

    def _download_data(self):
        if osp.exists(self.dataset_dir):
            print("This dataset has been downloaded.")
            return

        print("Creating directory {}".format(self.dataset_dir))
        mkdir_if_missing(self.dataset_dir)
        fpath = osp.join(self.dataset_dir, osp.basename(self.dataset_url))

        # print("Downloading DukeMTMC-VideoReID dataset")
        # urllib.urlretrieve(self.dataset_url, fpath)

        # note: the download call above is disabled, so the zip archive must
        # already have been placed at fpath manually before extraction
        print("Extracting files")
        zip_ref = zipfile.ZipFile(fpath, 'r')
        zip_ref.extractall(self.dataset_dir)
        zip_ref.close()

    def _check_before_run(self):
        """Check if all files are available before going deeper"""
        if not osp.exists(self.dataset_dir):
            raise RuntimeError("'{}' is not available".format(self.dataset_dir))
        if not osp.exists(self.train_dir):
            raise RuntimeError("'{}' is not available".format(self.train_dir))
        if not osp.exists(self.query_dir):
            raise RuntimeError("'{}' is not available".format(self.query_dir))
        if not osp.exists(self.gallery_dir):
            raise RuntimeError("'{}' is not available".format(self.gallery_dir))

    def _process_dir(self, dir_path, json_path, relabel):
        if osp.exists(json_path):
            print("=> {} generated before, awesome!".format(json_path))
            split = read_json(json_path)
            return split['tracklets']

        print("=> Automatically generating split (might take a while for the first time, have a coffee)")
        pdirs = sorted(glob.glob(osp.join(dir_path, '*')))  # avoid .DS_Store
        print("Processing '{}' with {} person identities".format(dir_path, len(pdirs)))

        pid_container = set()
        for pdir in pdirs:
            pid = int(osp.basename(pdir))
            pid_container.add(pid)
        pid2label = {pid: label for label, pid in enumerate(pid_container)}

        cam_list = []
        tracklets = []
        for pdir in pdirs:
            pid = int(osp.basename(pdir))
            if relabel: pid = pid2label[pid]
            tdirs = sorted(glob.glob(osp.join(pdir, '*')))

            for tdir in tdirs:
                tid = int(osp.basename(tdir)) - 1
                raw_img_paths = sorted(glob.glob(osp.join(tdir, '*.jpg')))
                num_imgs = len(raw_img_paths)

                if num_imgs < self.min_seq_len:
                    continue

                img_paths = []
                for img_idx in range(num_imgs):
                    # some tracklets start from 0002 instead of 0001
                    img_idx_name = 'F' + str(img_idx + 1).zfill(4)
                    res = sorted(glob.glob(osp.join(tdir, '*' + img_idx_name + '*.jpg')))
                    if len(res) == 0:
                        print("Warn: index name {} in {} is missing, skipping to the next".format(img_idx_name, tdir))
                        continue
                    img_paths.append(res[0])
                img_name = osp.basename(img_paths[0])
                if img_name.find('_') == -1:
                    # old naming format: 0001C6F0099X30823.jpg
                    camid = int(img_name[5]) - 1
                else:
                    # new naming format: 0001_C6_F0099_X30823.jpg
                    camid = int(img_name[6]) - 1
                img_paths = tuple(img_paths)
                tid_sub = -1
                pid_sub = -1
                tracklets.append((img_paths, tid, pid, tid_sub, pid_sub, camid))
                cam_list.append(camid)

        num_cams = len(list(set(cam_list)))
        start_tid_uic = 0
        for cam_index in range(num_cams):
            # count tid per camera
            tkl_index_list = [index for index, (_, _, _, _, _, camid) in enumerate(tracklets) if camid == cam_index]
            # count pid per camera
            pid_list_sub = [tracklets[j][2] for j in tkl_index_list]
            unique_pid_list_percam = list(set(pid_list_sub))
            start_tid_uic += len(unique_pid_list_percam)

            pid_percam2label = {pid: label for label, pid in enumerate(unique_pid_list_percam)}

            for index, tkl_index in enumerate(tkl_index_list):
                img_paths = tracklets[tkl_index][0]
                tid = tracklets[tkl_index][1]
                pid = tracklets[tkl_index][2]
                tid_sub = index
                pid_sub = pid_percam2label[pid]
                camid = tracklets[tkl_index][5]
                tracklets[tkl_index] = (img_paths, tid, pid, tid_sub, pid_sub, camid)

        print("Saving split to {}".format(json_path))
        split_dict = {
            'tracklets': tracklets,
        }
        write_json(split_dict, json_path)
        return tracklets
--------------------------------------------------------------------------------
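`_process_dir` above supports both DukeMTMC-VideoReID frame-naming formats; the camera-id parsing in isolation (file names hypothetical):

```python
# Old format: 0001C6F0099X30823.jpg  /  new format: 0001_C6_F0099_X30823.jpg
for name in ('0001C6F0099X30823.jpg', '0001_C6_F0099_X30823.jpg'):
    camid = int(name[5]) - 1 if name.find('_') == -1 else int(name[6]) - 1
    print(camid)  # 5 in both cases: camera C6, 0-based
```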
/reid/datasets/ilids_vid.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import glob
import urllib.request  # urlretrieve lives in urllib.request under Python 3
import tarfile
import os.path as osp
from scipy.io import loadmat
import numpy as np
from ..utils.iotools import mkdir_if_missing, write_json, read_json

import ipdb

class iLIDS_VID(object):
    """
    iLIDS-VID
    Reference:
    Wang et al. Person Re-Identification by Video Ranking. ECCV 2014.
    URL: http://www.eecs.qmul.ac.uk/~xiatian/downloads_qmul_iLIDS-VID_ReID_dataset.html

    Dataset statistics:
    # identities: 300
    # tracklets: 600
    # cameras: 2
    """
    dataset_dir = ''

    def __init__(self, root, split_id=10, min_seq_len=0):
        # note: only split ids 0-9 exist, so the default split_id must be overridden by the caller
        super(iLIDS_VID, self).__init__()

        self.dataset_dir = osp.join(root, self.dataset_dir)
        self.dataset_url = 'http://www.eecs.qmul.ac.uk/~xiatian/iLIDS-VID/iLIDS-VID.tar'
        self.data_dir = osp.join(self.dataset_dir, '')
        self.split_dir = osp.join(self.data_dir, 'info/train-test people splits')
        self.split_mat_path = osp.join(self.split_dir, 'train_test_splits_ilidsvid.mat')
        self.split_path = osp.join(self.data_dir, 'info/splits.json')
        self.cam_1_path = osp.join(self.data_dir, 'sequences/cam1')
        self.cam_2_path = osp.join(self.data_dir, 'sequences/cam2')

        self._download_data()
        self._check_before_run()

        self._prepare_split()
        splits = read_json(self.split_path)
        print("{}/{}".format(split_id, len(splits)))
        if split_id >= len(splits):
            raise ValueError(
                "split_id exceeds range, received {}, but expected between 0 and {}".format(split_id, len(splits) - 1))
        split = splits[split_id]
        train_dirs, test_dirs = split['train'], split['test']
        print("# train identities: {}, # test identities: {}".format(len(train_dirs), len(test_dirs)))

        train, num_train_tracklets, num_train_pids, num_imgs_train = \
            self._process_data(train_dirs, cam1=True, cam2=True)
        query, num_query_tracklets, num_query_pids, num_imgs_query = \
            self._process_data(test_dirs, cam1=True, cam2=False)
        gallery, num_gallery_tracklets, num_gallery_pids, num_imgs_gallery = \
            self._process_data(test_dirs, cam1=False, cam2=True)

        # get train set
        tid_start = 0
        train_set, num_tkls_train, num_persons_train, \
            num_tkls_pc_train, num_persons_pc_train, \
            trainval_len_pertkl, trainval_len_pertkl_percam = \
            self.Build_Set(train, relabel=True, min_seq_len=min_seq_len, tid_start=tid_start)
        # get test set
        tid_start = 0
        query_set, query_num_tracklets, query_num_pids, \
            query_num_tracklets_percam, query_num_pids_percam, \
            query_num_len_pertkl, query_len_pertkl_percam = \
            self.Build_Set(query, relabel=False, min_seq_len=min_seq_len, tid_start=tid_start)
        tid_start = query_num_tracklets
        gallery_set, gallery_num_tracklets, gallery_num_pids, \
            gallery_num_tracklets_percam, gallery_num_pids_percam, \
            gallery_len_pertkl, gallery_len_pertkl_percam = \
            self.Build_Set(gallery, relabel=False, min_seq_len=min_seq_len, tid_start=tid_start)

        num_imgs_per_tracklet = trainval_len_pertkl + query_num_len_pertkl + gallery_len_pertkl
        min_num = np.min(num_imgs_per_tracklet)
        max_num = np.max(num_imgs_per_tracklet)
        avg_num = np.mean(num_imgs_per_tracklet)
        print("  number of images per tracklet: {} ~ {}, average {}".format(min_num, max_num, avg_num))

        self.train = train_set
        self.num_train_pids = num_persons_train
        self.num_train_pids_sub = num_persons_pc_train
        self.num_train_tids_sub = num_tkls_pc_train
        self.query = query_set
        self.gallery = gallery_set
        return

    def _download_data(self):
        if osp.exists(self.dataset_dir):
            print("This dataset has been downloaded.")
            return

        mkdir_if_missing(self.dataset_dir)
        fpath = osp.join(self.dataset_dir, osp.basename(self.dataset_url))

        print("Downloading iLIDS-VID dataset")
        urllib.request.urlretrieve(self.dataset_url, fpath)

        print("Extracting files")
        tar = tarfile.open(fpath)
        tar.extractall(path=self.dataset_dir)
        tar.close()

    def _check_before_run(self):
        """Check if all files are available before going deeper"""
        if not osp.exists(self.dataset_dir):
            raise RuntimeError("'{}' is not available".format(self.dataset_dir))
        if not osp.exists(self.data_dir):
            raise RuntimeError("'{}' is not available".format(self.data_dir))
        if not osp.exists(self.split_dir):
            raise RuntimeError("'{}' is not available".format(self.split_dir))

    def _prepare_split(self):
        if not osp.exists(self.split_path):
            print("Creating splits ...")
            mat_split_data = loadmat(self.split_mat_path)['ls_set']

            num_splits = mat_split_data.shape[0]
            num_total_ids = mat_split_data.shape[1]
            assert num_splits == 10
            assert num_total_ids == 300
            num_ids_each = num_total_ids // 2

            # pids in mat_split_data are indices, so we need to transform them
            # to real pids
            person_cam1_dirs = sorted(glob.glob(osp.join(self.cam_1_path, '*')))
            person_cam2_dirs = sorted(glob.glob(osp.join(self.cam_2_path, '*')))

            person_cam1_dirs = [osp.basename(item) for item in person_cam1_dirs]
            person_cam2_dirs = [osp.basename(item) for item in person_cam2_dirs]

            # make sure persons in one camera view can be found in the other camera view
            assert set(person_cam1_dirs) == set(person_cam2_dirs)

            splits = []
            for i_split in range(num_splits):
                # first 50% for testing and the remaining for training, following Wang et al. ECCV'14.
                train_idxs = sorted(list(mat_split_data[i_split, num_ids_each:]))
                test_idxs = sorted(list(mat_split_data[i_split, :num_ids_each]))

                train_idxs = [int(i) - 1 for i in train_idxs]
                test_idxs = [int(i) - 1 for i in test_idxs]

                # transform pids to person dir names
                train_dirs = [person_cam1_dirs[i] for i in train_idxs]
                test_dirs = [person_cam1_dirs[i] for i in test_idxs]

                split = {'train': train_dirs, 'test': test_dirs}
                splits.append(split)

            print("{} splits are created in total, following Wang et al. ECCV'14".format(len(splits)))
            print("Split file is saved to {}".format(self.split_path))
            write_json(splits, self.split_path)

        print("Splits created")

    def _process_data(self, dirnames, cam1=True, cam2=True):
        tracklets = []
        num_imgs_per_tracklet = []
        dirname2pid = {dirname: i for i, dirname in enumerate(dirnames)}

        for dirname in dirnames:
            if cam1:
                person_dir = osp.join(self.cam_1_path, dirname)
                img_names = glob.glob(osp.join(person_dir, '*.png'))
                assert len(img_names) > 0
                img_names = tuple(img_names)
                pid = dirname2pid[dirname]
                tracklets.append((img_names, pid, 0))
                num_imgs_per_tracklet.append(len(img_names))

            if cam2:
                person_dir = osp.join(self.cam_2_path, dirname)
                img_names = glob.glob(osp.join(person_dir, '*.png'))
                assert len(img_names) > 0
                img_names = tuple(img_names)
                pid = dirname2pid[dirname]
                tracklets.append((img_names, pid, 1))
                num_imgs_per_tracklet.append(len(img_names))

        num_tracklets = len(tracklets)
        num_pids = len(dirnames)

        return tracklets, num_tracklets, num_pids, num_imgs_per_tracklet


    def Build_Set(self, dataset_raw, relabel=False, min_seq_len=0, tid_start=0):
        pid_container = set()
        cam_container = set()
        for i, (_, pid, cam) in enumerate(dataset_raw):
            pid_container.add(pid)
            cam_container.add(cam)
        num_pids = len(pid_container)
        num_cams = len(cam_container)
        num_tkls = len(dataset_raw)

        dataset = []
        len_pertkl = []
        tid = tid_start  # tid is the key in video Re-ID (as the fname in image Re-ID)
        for i in range(num_tkls):
            img_names = dataset_raw[i][0]
            pid = dataset_raw[i][1]
            cam = dataset_raw[i][2]
            assert 0 <= cam <= 1
            tid_percam = -1
            pid_percam = -1
            if len(img_names) >= min_seq_len:
                dataset.append((img_names, tid, pid, tid_percam, pid_percam, cam))
                tid += 1
                len_pertkl.append(len(img_names))

        # ----------------------------- Next: get the tid_pc and pid_pc -----------------------------#
        num_tkls_pc = []
        num_pids_pc = []
        len_pertkl_pc = []
        for i, c in enumerate(cam_container):
            # count tid per camera
            tkl_index_pc = [index for index, (_, _, _, _, _, camid) in enumerate(dataset) if camid == c]
            num_tkls_pc.append(len(tkl_index_pc))

            # count pid per camera
            pid_list_pc = [dataset[i][2] for i in tkl_index_pc]
            unique_pid_list_pc = list(set(pid_list_pc))
            num_pids_pc.append(len(unique_pid_list_pc))

            # count image number per tracklet
            len_pertkl_pc.append([len_pertkl[i] for i in tkl_index_pc])

            pid_percam2label = {pid: label for label, pid in enumerate(unique_pid_list_pc)}
            for i, tkl_index in enumerate(tkl_index_pc):
                tid = dataset[tkl_index][1]
                pid = dataset[tkl_index][2]
                tid_pc = i
                pid_pc = pid_percam2label[pid]
                cam = dataset[tkl_index][5]
                dataset[tkl_index] = (dataset[tkl_index][0], tid, pid, tid_pc, pid_pc, cam)
        assert num_tkls == sum(num_tkls_pc)
        # # check if pid starts from 0 and increments with 1
        # for idx, pid in enumerate(pid_container):
        #     assert idx == pid, "See code comment for explanation"
        return dataset, num_tkls, num_pids, num_tkls_pc, num_pids_pc, len_pertkl, len_pertkl_pc
--------------------------------------------------------------------------------
| 262 | -------------------------------------------------------------------------------- /reid/datasets/market1501.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import glob 6 | import re 7 | import os.path as osp 8 | from .bases import BaseImageDataset 9 | 10 | 11 | class Market1501(BaseImageDataset): 12 | """ 13 | Market1501 14 | 15 | Reference: 16 | Zheng et al. Scalable Person Re-identification: A Benchmark. ICCV 2015. 17 | 18 | URL: http://www.liangzheng.org/Project/project_reid.html 19 | 20 | Dataset statistics: 21 | # identities: 1501 (+1 for background) 22 | # images: 12936 (train) + 3368 (query) + 15913 (gallery) 23 | """ 24 | 25 | def __init__(self, root='data', verbose=True, **kwargs): 26 | super(Market1501, self).__init__() 27 | self.dataset_dir = root 28 | self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train') 29 | self.query_dir = osp.join(self.dataset_dir, 'query') 30 | self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test') 31 | 32 | self._check_before_run() 33 | 34 | train = self._struct_data(self.train_dir, relabel=True) 35 | query = self._struct_data(self.query_dir, relabel=False) 36 | gallery = self._struct_data(self.gallery_dir, relabel=False) 37 | 38 | if verbose: 39 | print("=> Market1501 loaded") 40 | self.print_dataset_statistics(train, query, gallery) 41 | 42 | self.train = train 43 | self.query = query 44 | self.gallery = gallery 45 | 46 | self.num_train_pids, self.num_train_imgs, self.num_train_cams, self.num_train_tids_sub = self.get_imagedata_info(self.train) 47 | self.num_query_pids, self.num_query_imgs, self.num_query_cams, self.num_query_tids_sub = self.get_imagedata_info(self.query) 48 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams, self.num_gallery_tids_sub = self.get_imagedata_info(self.gallery) 49 | 50 | def _check_before_run(self): 51 | """Check if all files are available before going deeper""" 52 | if not osp.exists(self.dataset_dir): 53 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 54 | if not osp.exists(self.train_dir): 55 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 56 | if not osp.exists(self.query_dir): 57 | raise RuntimeError("'{}' is not available".format(self.query_dir)) 58 | if not osp.exists(self.gallery_dir): 59 | raise RuntimeError("'{}' is not available".format(self.gallery_dir)) 60 | 61 | def _struct_data(self, dir_path, relabel=False): 62 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 63 | pattern = re.compile(r'([-\d]+)_c(\d)') 64 | 65 | pid_container = set() 66 | for img_path in img_paths: 67 | pid, _ = map(int, pattern.search(img_path).groups()) 68 | if pid == -1: continue # junk images are just ignored 69 | pid_container.add(pid) 70 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 71 | 72 | img_dataset = [] 73 | tid = -1 74 | tid_pc = -1 75 | pid_pc = -1 76 | cam_list = [] 77 | for img_path in img_paths: 78 | pid, camid = map(int, pattern.search(img_path).groups()) 79 | if pid == -1: continue # junk images are just ignored 80 | assert 0 <= pid <= 1501 # pid == 0 means background 81 | assert 1 <= camid <= 6 82 | camid -= 1 # index starts from 0 83 | if relabel: pid = pid2label[pid] 84 | cam_list.append(camid) 85 | img_dataset.append((img_path, tid, pid, tid_pc, pid_pc, camid)) 86 | 87 | tkl_dataset = [] 88 | tid = -1 89 | cam_list = 
list(set(cam_list)) 90 | for cam in cam_list: 91 | tid_index_list = [index for index, (_, _, _, _, _, camid) in enumerate(img_dataset) if camid == cam] 92 | pid_list_pc = [img_dataset[i][2] for i in tid_index_list] 93 | unique_pid_list_pc = list(set(pid_list_pc)) 94 | pid_percam2label = {pid: label for label, pid in enumerate(unique_pid_list_pc)} 95 | 96 | for insert_pid in unique_pid_list_pc: 97 | img_index_list = [index for index, (_, _, pid, _, _, camid) in enumerate(img_dataset) if camid == cam and pid == insert_pid] 98 | img_names = tuple([osp.join(img_dataset[index][0]) for index in img_index_list]) 99 | pid_pc = pid_percam2label[insert_pid] 100 | tid_pc = pid_pc 101 | tid += 1 102 | pid = insert_pid 103 | tkl_dataset.append((img_names, tid, pid, tid_pc, pid_pc, cam)) 104 | return tkl_dataset -------------------------------------------------------------------------------- /reid/datasets/mars.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os.path as osp 3 | import numpy as np 4 | import torch 5 | import random 6 | from scipy.io import loadmat 7 | from collections import defaultdict 8 | 9 | def nested_dict(n, type): 10 | if n == 1: 11 | return defaultdict(type) 12 | else: 13 | return defaultdict(lambda: nested_dict(n - 1, type)) 14 | 15 | 16 | def flatten_dataset(dataset): 17 | new_dataset = [] 18 | for tracklet in dataset: 19 | img_names, tid, pid, camid = tracklet 20 | for img_name in img_names: 21 | new_dataset.append((img_name, tid, pid, camid)) 22 | return new_dataset 23 | 24 | 25 | class Mars(object): 26 | def __init__(self, root, split_id=0, min_seq_len=0): 27 | super(Mars, self).__init__() 28 | 29 | # configure path 30 | self.dataset_dir = osp.join(root, '') 31 | train_name_path = osp.join(self.dataset_dir, 'info/train_name.txt') 32 | test_name_path = osp.join(self.dataset_dir, 'info/test_name.txt') 33 | track_train_path = osp.join(self.dataset_dir, 'info/tracks_train_info.mat') 34 | track_test_path = osp.join(self.dataset_dir, 'info/tracks_test_info.mat') 35 | query_IDX_path = osp.join(self.dataset_dir, 'info/query_IDX.mat') 36 | 37 | # build train & test set according to the configure files (unit: tracklet) 38 | track_train = loadmat(track_train_path)['track_train_info'] # numpy.ndarray (8298, 4) 39 | track_test = loadmat(track_test_path)['track_test_info'] # numpy.ndarray (12180, 4) 40 | query_IDX = loadmat(query_IDX_path)['query_IDX'].squeeze() 41 | query_IDX -= 1 # index from 0 42 | track_query = track_test[query_IDX, :] # numpy.ndarray (1980,4) 43 | gallery_IDX = [i for i in range(track_test.shape[0]) if i not in query_IDX] 44 | track_gallery = track_test[gallery_IDX, :] # numpy.ndarray (10200,4) 45 | train_names = self._get_names(train_name_path) 46 | test_names = self._get_names(test_name_path) 47 | 48 | # get train set 49 | tid_start = 0 50 | # trainval_set, trainval_num_tracklets, trainval_num_pids, \ 51 | # trainval_num_tracklets_percam, trainval_num_pids_percam, \ 52 | # trainval_len_pertkl, trainval_len_pertkl_percam = \ 53 | # self.Get_Set(train_names, track_train, home_dir=osp.join(self.dataset_dir,'bbox_train'), 54 | # relabel=True, min_seq_len=min_seq_len, tid_start=tid_start) 55 | trainval_set, trainval_num_tracklets, trainval_num_pids, \ 56 | trainval_num_tracklets_percam, trainval_num_pids_percam, \ 57 | trainval_len_pertkl = \ 58 | self.Get_Set_1T4P(train_names, track_train, home_dir=osp.join(self.dataset_dir, 'bbox_train'), 59 | multitask=True, 
relabel=True, min_seq_len=min_seq_len, tid_start=tid_start) 60 | 61 | # get test set 62 | tid_start = 0 63 | query_set, query_num_tracklets, query_num_pids, \ 64 | query_num_tracklets_percam, query_num_pids_percam, \ 65 | query_num_len_pertkl, query_len_pertkl_percam = \ 66 | self.Get_Set(test_names, track_query, home_dir=osp.join(self.dataset_dir,'bbox_test'), 67 | relabel=False, min_seq_len=min_seq_len, tid_start=tid_start) 68 | 69 | tid_start = query_num_tracklets 70 | gallery_set, gallery_num_tracklets, gallery_num_pids, \ 71 | gallery_num_tracklets_percam, gallery_num_pids_percam, \ 72 | gallery_len_pertkl, gallery_len_pertkl_percam = \ 73 | self.Get_Set(test_names, track_gallery, home_dir=osp.join(self.dataset_dir,'bbox_test'), 74 | relabel=False, min_seq_len=min_seq_len, tid_start=tid_start) 75 | 76 | # display info 77 | print("=========================================================================================") 78 | print("", self.__class__.__name__, "dataset loaded") 79 | print(" subset | # ids | # tracklets | Cam 1 | Cam 2 | Cam 3 | Cam 4 | Cam 5 | Cam 6 |") 80 | print(" ---------------------------------------------------------------------------------") 81 | print(" trainval | {:5d} | {:8d} | {:5d} | {:5d} | {:5d} | {:5d} | {:5d} | {:5d} |" 82 | .format(trainval_num_pids, trainval_num_tracklets, 83 | trainval_num_tracklets_percam[0],trainval_num_tracklets_percam[1], 84 | trainval_num_tracklets_percam[2], trainval_num_tracklets_percam[3], 85 | trainval_num_tracklets_percam[4], trainval_num_tracklets_percam[5])) 86 | print(" query | {:5d} | {:8d} | {:5d} | {:5d} | {:5d} | {:5d} | {:5d} | {:5d} |" 87 | .format(query_num_pids, query_num_tracklets, 88 | query_num_tracklets_percam[0], query_num_tracklets_percam[1], 89 | query_num_tracklets_percam[2], query_num_tracklets_percam[3], 90 | query_num_tracklets_percam[4], query_num_tracklets_percam[5])) 91 | print(" gallery | {:5d} | {:8d} | {:5d} | {:5d} | {:5d} | {:5d} | {:5d} | {:5d} |" 92 | .format(gallery_num_pids, gallery_num_tracklets, 93 | gallery_num_tracklets_percam[0], gallery_num_tracklets_percam[1], 94 | gallery_num_tracklets_percam[2], gallery_num_tracklets_percam[3], 95 | gallery_num_tracklets_percam[4], gallery_num_tracklets_percam[5])) 96 | 97 | num_imgs_per_tracklet = trainval_len_pertkl + query_num_len_pertkl + gallery_len_pertkl 98 | min_num = np.min(num_imgs_per_tracklet) 99 | max_num = np.max(num_imgs_per_tracklet) 100 | avg_num = np.mean(num_imgs_per_tracklet) 101 | print(" number of images per tracklet: {} ~ {}, average {}".format(min_num, max_num, avg_num)) 102 | 103 | self.train = trainval_set 104 | self.train_num_pids = trainval_num_pids 105 | self.num_train_pids_sub = trainval_num_pids_percam 106 | self.num_train_tids_sub = trainval_num_tracklets_percam 107 | self.query = query_set 108 | self.gallery = gallery_set 109 | return 110 | 111 | def _get_names(self, fpath): 112 | names = [] 113 | with open(fpath, 'r') as f: 114 | for line in f: 115 | new_line = line.rstrip() 116 | names.append(new_line) 117 | return names 118 | 119 | def Get_Set(self, names, meta_data, home_dir=None, relabel=False, min_seq_len=0, tid_start = 0): 120 | # assert home_dir in ['bbox_train', 'bbox_test'] 121 | num_tracklets = meta_data.shape[0] 122 | raw_pid_list = list(set(meta_data[:, 2])) 123 | num_pids = len(raw_pid_list) 124 | num_cams = len(set(meta_data[:, 3])) 125 | if relabel: 126 | pid2label = {pid: label for label, pid in enumerate(raw_pid_list)} 127 | else: 128 | pid2label = {pid: pid for label, pid in 
enumerate(raw_pid_list)} 129 | 130 | dataset = [] 131 | len_pertkl = [] 132 | tid = tid_start # tid is the key in video Re-ID (as the fname in image Re-ID) 133 | for i in range(num_tracklets): 134 | data = meta_data[i,...] 135 | start_index, end_index, raw_pid, camid = data 136 | assert 1 <= camid <= 6 137 | camid -= 1 138 | if raw_pid == -1: continue # junk images are just ignored 139 | img_names = names[start_index-1:end_index] 140 | 141 | # make sure image names correspond to the same person 142 | pnames = [img_name[:4] for img_name in img_names] 143 | assert len(set(pnames)) == 1, "error: a single tracklet contains different person images!" 144 | # make sure all images are captured under the same camera 145 | camnames = [img_name[5] for img_name in img_names] 146 | assert len(set(camnames)) == 1, "error: images are captured under different cameras!" 147 | 148 | # append image names with directory information 149 | img_names = [osp.join(home_dir, img_name[:4], img_name) for img_name in img_names] 150 | img_names = tuple(img_names) 151 | pid = pid2label[raw_pid] 152 | tid_percam = -1 153 | pid_percam = -1 154 | if len(img_names) >= min_seq_len: 155 | dataset.append((img_names, tid, pid, tid_percam, pid_percam, camid)) 156 | tid += 1 157 | len_pertkl.append(len(img_names)) 158 | num_tracklets = len(dataset) 159 | 160 | num_tracklets_percam = [] 161 | num_pids_percam = [] 162 | len_pertkl_percam = [] 163 | for i in range(num_cams): 164 | # count tid per camera 165 | tkl_index_list = [index for index, (_, _, _, _, _, camid) in enumerate(dataset) if camid == i] 166 | num_tracklets_percam.append(len(tkl_index_list)) 167 | # count pid per camera 168 | pid_list_percam = [dataset[j][2] for j in tkl_index_list] 169 | unique_pid_list_percam = list(set(pid_list_percam)) 170 | num_pids_percam.append(len(unique_pid_list_percam)) 171 | # count image number per tracklet 172 | len_pertkl_percam.append([len_pertkl[j] for j in tkl_index_list]) 173 | 174 | pid_percam2label = {pid: label for label, pid in enumerate(unique_pid_list_percam)} 175 | 176 | for j, tkl_index in enumerate(tkl_index_list): 177 | tid = dataset[tkl_index][1] 178 | pid = dataset[tkl_index][2] 179 | tid_percam = j 180 | pid_percam = pid_percam2label[pid] 181 | camid = dataset[tkl_index][5] 182 | dataset[tkl_index] = (dataset[tkl_index][0], tid, pid, tid_percam, pid_percam, camid) 183 | assert num_tracklets == sum(num_tracklets_percam) 184 | return dataset, num_tracklets, num_pids, num_tracklets_percam, num_pids_percam, len_pertkl, len_pertkl_percam 185 | 186 | 187 | def Get_Set_1T4P(self, names, meta_data, home_dir=None, multitask=False, relabel=False, min_seq_len=0, tid_start=0): 188 | # assert home_dir in ['bbox_train', 'bbox_test'] 189 | num_tracklets = meta_data.shape[0] 190 | raw_pid_list = list(set(meta_data[:, 2])) 191 | num_pids = len(raw_pid_list) 192 | num_cams = len(set(meta_data[:, 3])) 193 | if relabel: 194 | pid2label = {pid: label for label, pid in enumerate(raw_pid_list)} 195 | else: 196 | pid2label = {pid: pid for label, pid in enumerate(raw_pid_list)} 197 | 198 | dataset = [] 199 | len_pertkl = [] 200 | tid = tid_start # tid is the key in video Re-ID (as the fname in image Re-ID) 201 | for i in range(num_tracklets): 202 | data = meta_data[i, ...] 
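# Each meta_data row is [start_index, end_index, raw_pid, camid]; the start/end
# indices are 1-based positions into the global image-name list loaded above.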
203 | start_index, end_index, raw_pid, camid = data 204 | assert 1 <= camid <= 6 205 | camid -= 1 206 | img_names = names[start_index - 1:end_index] 207 | 208 | # make sure image names correspond to the same person 209 | pnames = [img_name[:4] for img_name in img_names] 210 | assert len(set(pnames)) == 1, "error: a single tracklet contains different person images!" 211 | 212 | # make sure all images are captured under the same camera 213 | camnames = [img_name[5] for img_name in img_names] 214 | assert len(set(camnames)) == 1, "error: images are captured under different cameras!" 215 | 216 | # append image names with directory information 217 | img_names = [osp.join(home_dir, img_name[:4], img_name) for img_name in img_names] 218 | img_names = tuple(img_names) 219 | pid = pid2label[raw_pid] 220 | tid_percam = -1 221 | pid_percam = -1 222 | if len(img_names) >= min_seq_len: 223 | dataset.append((img_names, tid, pid, tid_percam, pid_percam, camid)) 224 | tid += 1 225 | len_pertkl.append(len(img_names)) 226 | num_tracklets = len(dataset) 227 | 228 | if multitask == False: 229 | return dataset, num_tracklets, num_pids, len_pertkl 230 | else: 231 | mt_num_tracklets = [] 232 | num_pids_pc = [] 233 | mt_num_imgs_per_tracklet = [] 234 | for i in range(num_cams): 235 | tkl_index_list = [index for index, (_, _, _, _, _, camid) in enumerate(dataset) if camid == i] 236 | mt_num_tracklets.append(len(tkl_index_list)) 237 | mt_num_imgs_per_tracklet.append([len_pertkl[j] for j in tkl_index_list]) 238 | 239 | id_cami_list = [dataset[j][2] for j in tkl_index_list] 240 | id_unique_cami_list = list(set(id_cami_list)) 241 | num_pids_pc.append(len(id_unique_cami_list)) 242 | 243 | # Scheme 3: choose 1 tracklet per pid, rate of pid have more than 1 tracklet (test duplicate rate) 244 | tindex = 0 245 | new_dataset = [] 246 | num_tids_pc_new = [] 247 | for t in range(num_cams): 248 | tkl_index_list = [index for index, (_, _, _, _, _, camid) in enumerate(dataset) if camid == t] 249 | assert len(tkl_index_list) == mt_num_tracklets[t] 250 | 251 | num_sampling = num_pids_pc[t] 252 | tkl_per_pid = 1 253 | id_cami_list = [dataset[j][2] for j in tkl_index_list] 254 | id_unique_cami_list = list(set(id_cami_list)) 255 | id_indices = torch.randperm(len(id_unique_cami_list)) 256 | tkl_sample_no = 0 257 | for pid_sample_no in range(num_pids_pc[t]): 258 | # Step_1: after shuffling the pid order, pick up the pid one by one 259 | pid_anchor = id_unique_cami_list[id_indices[pid_sample_no]] 260 | # Step_2: get all tracklets belong to pid_anchor 261 | tid_list = [] 262 | for tid in tkl_index_list: 263 | if dataset[tid][2] == pid_anchor: 264 | tid_list.append(tid) 265 | # Step_3: pick up (tkl_per_pid) tracklets randomly 266 | if len(tid_list) >= tkl_per_pid: 267 | random.shuffle(tid_list) # shuffle the tid order 268 | for j in range(tkl_per_pid): 269 | tkl_index = tid_list[j] 270 | fnames = dataset[tkl_index][0] 271 | tid = tindex 272 | pid = dataset[tkl_index][2] 273 | tid_percam = tkl_sample_no 274 | pid_percam = pid_sample_no 275 | camid = dataset[tkl_index][5] 276 | 277 | assert camid == t 278 | new_dataset.append((fnames, tid, pid, tid_percam, pid_percam, camid)) 279 | tindex += 1 280 | tkl_sample_no += 1 281 | else: 282 | for j in range(len(tid_list)): 283 | tkl_index = tid_list[j] 284 | fnames = dataset[tkl_index][0] 285 | tid = tindex 286 | pid = dataset[tkl_index][2] 287 | tid_percam = tkl_sample_no 288 | pid_percam = pid_sample_no 289 | camid = dataset[tkl_index][5] 290 | 291 | assert camid == t 292 | 
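# This pid has fewer than tkl_per_pid tracklets: keep each one it has, then
# the padding loop below reuses them until tkl_per_pid samples are reached.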
new_dataset.append((fnames, tid, pid, tid_percam, pid_percam, camid)) 293 | tindex += 1 294 | tkl_sample_no += 1 295 | for j in range(tkl_per_pid - len(tid_list)): 296 | tkl_index = tid_list[j] 297 | fnames = dataset[tkl_index][0] 298 | tid = tindex 299 | pid = dataset[tkl_index][2] 300 | tid_percam = tkl_sample_no 301 | pid_percam = pid_sample_no 302 | camid = dataset[tkl_index][5] 303 | 304 | assert camid == t 305 | new_dataset.append((fnames, tid, pid, tid_percam, pid_percam, camid)) 306 | tindex += 1 307 | tkl_sample_no += 1 308 | if tkl_sample_no == tkl_per_pid * num_sampling: 309 | break 310 | assert tkl_sample_no == tkl_per_pid * num_sampling 311 | 312 | # # Increment sampling 313 | # tkl_sample_no = tkl_sample_no # keep tkl_sample_no 314 | # rate = 0.2 315 | # Inc_num_pids = int (mt_num_pids[t]*rate) 316 | # tkl_per_pid = 1 317 | # # select pid 318 | # id_cami_list = [dataset[j][2] for j in index_list] 319 | # id_unique_cami_list = list(set(id_cami_list)) 320 | # # id_indices = torch.randperm(len(id_unique_cami_list)) 321 | # for pid_sample_no in range(Inc_num_pids): # Increment pids 322 | # pid_index = id_unique_cami_list[id_indices[pid_sample_no]] 323 | # 324 | # # get all tracklets of pid_index 325 | # tkl_list = [] 326 | # for index in index_list: 327 | # if dataset[index][2] == pid_index: 328 | # tkl_list.append(index) 329 | # if len(tkl_list) >= tkl_per_pid: 330 | # # random.shuffle(tkl_list) # randomly select (tkl_per_pid) tracklets 331 | # for j in range(tkl_per_pid): 332 | # fnames = dataset[tkl_list[j]][0] 333 | # camid = dataset[tkl_list[j]][3] 334 | # assert camid == t 335 | # new_dataset.append((fnames, tindex, tkl_sample_no, camid)) 336 | # tindex += 1 337 | # tkl_sample_no += 1 338 | # else: 339 | # for j in range(len(tkl_list)): 340 | # fnames = dataset[tkl_list[j]][0] 341 | # camid = dataset[tkl_list[j]][3] 342 | # assert camid == t 343 | # new_dataset.append((fnames, tindex, tkl_sample_no, camid)) 344 | # tindex += 1 345 | # tkl_sample_no += 1 346 | # for j in range(tkl_per_pid - len(tkl_list)): 347 | # fnames = dataset[tkl_list[0]][0] 348 | # camid = dataset[tkl_list[0]][3] 349 | # assert camid == t 350 | # new_dataset.append((fnames, tindex, tkl_sample_no, camid)) 351 | # tindex += 1 352 | # tkl_sample_no += 1 353 | # assert tkl_sample_no == mt_num_pids[t] + Inc_num_pids 354 | num_tids_pc_new.append(tkl_sample_no) 355 | num_tracklets_percam = num_tids_pc_new 356 | num_pids_percam = num_pids_pc 357 | num_tracklets = sum(num_tracklets_percam) 358 | 359 | return new_dataset, num_tracklets, num_pids, num_tracklets_percam, num_pids_percam, len_pertkl 360 | 361 | 362 | 363 | 364 | 365 | 366 | -------------------------------------------------------------------------------- /reid/datasets/msmt17.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os.path as osp 3 | 4 | 5 | class MSMT17(object): 6 | """ 7 | MSMT17 8 | Reference: 9 | Wei et al. Person Transfer GAN to Bridge Domain Gap for Person Re-Identification. CVPR 2018. 
10 | URL: http://www.pkuvmc.com/publications/msmt17.html 11 | 12 | Dataset statistics: 13 | # identities: 4101 14 | # images: 32621 (train) + 11659 (query) + 82161 (gallery) 15 | # cameras: 15 16 | """ 17 | 18 | def __init__(self, root='data', split_id=0, min_seq_len=0, **kwargs): 19 | self.dataset_dir = osp.join(root, '') 20 | self.train_dir = osp.join(self.dataset_dir, 'mask_train_v2') 21 | self.test_dir = osp.join(self.dataset_dir, 'mask_test_v2') 22 | self.list_train_path = osp.join(self.dataset_dir, 'list_train.txt') 23 | self.list_val_path = osp.join(self.dataset_dir, 'list_val.txt') 24 | self.list_query_path = osp.join(self.dataset_dir, 'list_query.txt') 25 | self.list_gallery_path = osp.join(self.dataset_dir, 'list_gallery.txt') 26 | 27 | self._check_before_run() 28 | # get train set 29 | tid_start = 0 30 | train_set, num_tkls_train, num_persons_train, \ 31 | num_tkls_pc_train, num_persons_pc_train, \ 32 | trainval_len_pertkl, trainval_len_pertkl_percam = \ 33 | self.Build_Set(self.train_dir, self.list_train_path, min_seq_len=min_seq_len, tid_start=tid_start) 34 | # get query set 35 | tid_start = 0 36 | query_set, num_tkls_train_query, num_persons_query, \ 37 | num_tkls_pc_query, num_pids_pc_query, \ 38 | query_len_pertkl, query_len_pertkl_percam = \ 39 | self.Build_Set(self.test_dir, self.list_query_path, min_seq_len=min_seq_len, tid_start=tid_start) 40 | # get gallery set 41 | tid_start = num_tkls_train_query 42 | gallery_set, num_tkls_train_gallery, num_persons_gallery, \ 43 | num_tkls_pc_gallery, num_pids_pc_gallery, \ 44 | gallery_len_pertkl, gallery_len_pertkl_percam = \ 45 | self.Build_Set(self.test_dir, self.list_gallery_path, min_seq_len=min_seq_len, tid_start=tid_start) 46 | 47 | num_total_pids = num_persons_train + num_persons_query 48 | num_train_imgs = sum(trainval_len_pertkl) 49 | num_query_imgs = sum(query_len_pertkl) 50 | num_gallery_imgs = sum(gallery_len_pertkl) 51 | num_total_imgs = num_train_imgs + num_query_imgs + num_gallery_imgs 52 | 53 | print("=> MSMT17 loaded") 54 | print("Dataset statistics:") 55 | print(" ------------------------------") 56 | print(" subset | # ids | # images") 57 | print(" ------------------------------") 58 | print(" train | {:5d} | {:8d}".format(num_persons_train, num_train_imgs)) 59 | print(" query | {:5d} | {:8d}".format(num_persons_query, num_query_imgs)) 60 | print(" gallery | {:5d} | {:8d}".format(num_persons_gallery, num_gallery_imgs)) 61 | print(" ------------------------------") 62 | print(" total | {:5d} | {:8d}".format(num_total_pids, num_total_imgs)) 63 | print(" ------------------------------") 64 | 65 | self.train = train_set 66 | self.num_train_pids = num_persons_train 67 | self.num_train_pids_sub = num_persons_pc_train 68 | self.num_train_tids_sub = num_tkls_pc_train 69 | self.query = query_set 70 | self.gallery = gallery_set 71 | 72 | def _check_before_run(self): 73 | """Check if all files are available before going deeper""" 74 | if not osp.exists(self.dataset_dir): 75 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 76 | if not osp.exists(self.train_dir): 77 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 78 | if not osp.exists(self.test_dir): 79 | raise RuntimeError("'{}' is not available".format(self.test_dir)) 80 | 81 | def Build_Set(self, dir_path, list_path, min_seq_len=0, tid_start=0): 82 | # Get the dataset_img 83 | with open(list_path, 'r') as txt: 84 | lines = txt.readlines() 85 | dataset_img = [] 86 | pid_container = set() 87 | cam_container = set() 88 | for 
img_idx, img_info in enumerate(lines): 89 | img_path, pid = img_info.split(' ') 90 | pid = int(pid) # no need to relabel 91 | cam = int(img_path.split('_')[2]) 92 | img_name = osp.join(dir_path, img_path) 93 | 94 | # dataset.append((img_path, pid, camid)) 95 | dataset_img.append((img_name, pid, cam)) 96 | pid_container.add(pid) 97 | cam_container.add(cam) 98 | num_imgs = len(dataset_img) 99 | num_pids = len(pid_container) 100 | num_cams = len(cam_container) 101 | 102 | # Get the dataset(tkl) 103 | dataset = [] 104 | len_pertkl = [] 105 | tid = tid_start # tid is the key in video Re-ID (as the fname in image Re-ID) 106 | for pid_idx, pid in enumerate(pid_container): 107 | for cam_idx, cam in enumerate(cam_container): 108 | img_names = [] 109 | for i in range(num_imgs): 110 | if dataset_img[i][1] == pid and dataset_img[i][2] == cam: 111 | img_names.append(dataset_img[i][0]) 112 | if len(img_names) > min_seq_len: 113 | img_names = tuple(img_names) 114 | tid_pc = -1 115 | pid_pc = -1 116 | cam = cam - 1 # cam start from 0 117 | dataset.append((img_names, tid, pid, tid_pc, pid_pc, cam)) 118 | tid += 1 119 | len_pertkl.append(len(img_names)) 120 | num_tkls = len(dataset) 121 | 122 | #----------------------------- Next: get the tid_pc and pid_pc -----------------------------# 123 | num_tkls_pc = [] 124 | num_pids_pc = [] 125 | len_pertkl_pc = [] 126 | for c in range(num_cams): 127 | # count tid per camera 128 | tkl_index_pc = [index for index, (_, _, _, _, _, camid) in enumerate(dataset) if camid == c] 129 | num_tkls_pc.append(len(tkl_index_pc)) 130 | 131 | # count pid per camera 132 | pid_list_pc = [dataset[i][2] for i in tkl_index_pc] 133 | unique_pid_list_pc = list(set(pid_list_pc)) 134 | num_pids_pc.append(len(unique_pid_list_pc)) 135 | 136 | # count image number per tracklet 137 | len_pertkl_pc.append([len_pertkl[i] for i in tkl_index_pc]) 138 | 139 | pid_percam2label = {pid:label for label, pid in enumerate(unique_pid_list_pc)} 140 | for i, tkl_index in enumerate(tkl_index_pc): 141 | tid = dataset[tkl_index][1] 142 | pid = dataset[tkl_index][2] 143 | tid_pc = i 144 | pid_pc = pid_percam2label[pid] 145 | cam = dataset[tkl_index][5] 146 | dataset[tkl_index] = (dataset[tkl_index][0], tid, pid, tid_pc, pid_pc, cam) 147 | assert num_tkls == sum(num_tkls_pc) 148 | # # check if pid starts from 0 and increments with 1 149 | # for idx, pid in enumerate(pid_container): 150 | # assert idx == pid, "See code comment for explanation" 151 | return dataset, num_tkls, num_pids, num_tkls_pc, num_pids_pc, len_pertkl, len_pertkl_pc 152 | 153 | # def _process_dir(self, dir_path, list_path): 154 | # with open(list_path, 'r') as txt: 155 | # lines = txt.readlines() 156 | # dataset = [] 157 | # pid_container = set() 158 | # for img_idx, img_info in enumerate(lines): 159 | # img_path, pid = img_info.split(' ') 160 | # pid = int(pid) # no need to relabel 161 | # camid = int(img_path.split('_')[2]) 162 | # img_path = osp.join(dir_path, img_path) 163 | # dataset.append((img_path, pid, camid)) 164 | # pid_container.add(pid) 165 | # num_imgs = len(dataset) 166 | # num_pids = len(pid_container) 167 | # # check if pid starts from 0 and increments with 1 168 | # for idx, pid in enumerate(pid_container): 169 | # assert idx == pid, "See code comment for explanation" 170 | # return dataset, num_pids, num_imgs -------------------------------------------------------------------------------- /reid/datasets/prid2011.py: -------------------------------------------------------------------------------- 1 | from __future__ import 
absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import glob 6 | import os.path as osp 7 | import numpy as np 8 | from ..utils.iotools import mkdir_if_missing, write_json, read_json 9 | import ipdb 10 | 11 | 12 | class PRID2011(object): 13 | """ 14 | PRID2011 15 | Reference: 16 | Hirzer et al. Person Re-Identification by Descriptive and Discriminative Classification. SCIA 2011. 17 | URL: https://www.tugraz.at/institute/icg/research/team-bischof/lrs/downloads/PRID11/ 18 | 19 | Dataset statistics: 20 | # identities: 200 21 | # tracklets: 400 22 | # cameras: 2 23 | """ 24 | dataset_dir = '' 25 | def __init__(self, root, split_id=0, min_seq_len=0): 26 | self.dataset_dir = osp.join(root, self.dataset_dir) 27 | self.split_path = osp.join(self.dataset_dir, 'info/splits_prid2011.json') 28 | self.cam_a_path = osp.join(self.dataset_dir, 'multi_shot', 'cam_a') 29 | self.cam_b_path = osp.join(self.dataset_dir, 'multi_shot', 'cam_b') 30 | 31 | self._check_before_run() 32 | splits = read_json(self.split_path) 33 | print("Using split {}/{}".format(split_id, len(splits))) 34 | if split_id >= len(splits): 35 | raise ValueError( 36 | "split_id exceeds range, received {}, but expected between 0 and {}".format(split_id, len(splits) - 1)) 37 | split = splits[split_id] 38 | train_dirs, test_dirs = split['train'], split['test'] 39 | print("# train identities: {}, # test identities: {}".format(len(train_dirs), len(test_dirs))) 40 | 41 | train, num_train_tracklets, num_train_pids, num_imgs_train = \ 42 | self._process_data(train_dirs, cam1=True, cam2=True) 43 | query, num_query_tracklets, num_query_pids, num_imgs_query = \ 44 | self._process_data(test_dirs, cam1=True, cam2=False) 45 | gallery, num_gallery_tracklets, num_gallery_pids, num_imgs_gallery = \ 46 | self._process_data(test_dirs, cam1=False, cam2=True) 47 | 48 | # get train set 49 | tid_start = 0 50 | train_set, num_tkls_train, num_persons_train, \ 51 | num_tkls_pc_train, num_persons_pc_train, \ 52 | trainval_len_pertkl, trainval_len_pertkl_percam = \ 53 | self.Build_Set(train, relabel=True, min_seq_len=min_seq_len, tid_start=tid_start) 54 | # get test set 55 | tid_start = 0 56 | query_set, query_num_tracklets, query_num_pids, \ 57 | query_num_tracklets_percam, query_num_pids_percam, \ 58 | query_num_len_pertkl, query_len_pertkl_percam = \ 59 | self.Build_Set(query, relabel=False, min_seq_len=min_seq_len, tid_start=tid_start) 60 | tid_start = query_num_tracklets 61 | gallery_set, gallery_num_tracklets, gallery_num_pids, \ 62 | gallery_num_tracklets_percam, gallery_num_pids_percam, \ 63 | gallery_len_pertkl, gallery_len_pertkl_percam = \ 64 | self.Build_Set(gallery, relabel=False, min_seq_len=min_seq_len, tid_start=tid_start) 65 | 66 | num_imgs_per_tracklet = trainval_len_pertkl + query_num_len_pertkl + gallery_len_pertkl 67 | min_num = np.min(num_imgs_per_tracklet) 68 | max_num = np.max(num_imgs_per_tracklet) 69 | avg_num = np.mean(num_imgs_per_tracklet) 70 | print(" number of images per tracklet: {} ~ {}, average {}".format(min_num, max_num, avg_num)) 71 | 72 | self.train = train_set 73 | self.num_train_pids = num_persons_train 74 | self.num_train_pids_sub = num_persons_pc_train 75 | self.num_train_tids_sub = num_tkls_pc_train 76 | self.query = query_set 77 | self.gallery = gallery_set 78 | return 79 | 80 | def _check_before_run(self): 81 | """Check if all files are available before going deeper""" 82 | if not osp.exists(self.dataset_dir): 83 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 84 
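# _process_data gathers one multi-shot tracklet per (person, camera) pair as a
# raw (img_names, pid, camid) triple; Build_Set below expands these to 6-tuples.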
| 85 | def _process_data(self, dirnames, cam1=True, cam2=True): 86 | tracklets = [] 87 | num_imgs_per_tracklet = [] 88 | dirname2pid = {dirname: i for i, dirname in enumerate(dirnames)} 89 | 90 | for dirname in dirnames: 91 | if cam1: 92 | person_dir = osp.join(self.cam_a_path, dirname) 93 | img_names = glob.glob(osp.join(person_dir, '*.png')) 94 | assert len(img_names) > 0 95 | img_names = tuple(img_names) 96 | pid = dirname2pid[dirname] 97 | tracklets.append((img_names, pid, 0)) 98 | num_imgs_per_tracklet.append(len(img_names)) 99 | 100 | if cam2: 101 | person_dir = osp.join(self.cam_b_path, dirname) 102 | img_names = glob.glob(osp.join(person_dir, '*.png')) 103 | assert len(img_names) > 0 104 | img_names = tuple(img_names) 105 | pid = dirname2pid[dirname] 106 | tracklets.append((img_names, pid, 1)) 107 | num_imgs_per_tracklet.append(len(img_names)) 108 | 109 | num_tracklets = len(tracklets) 110 | num_pids = len(dirnames) 111 | 112 | return tracklets, num_tracklets, num_pids, num_imgs_per_tracklet 113 | 114 | 115 | def Build_Set(self, dataset_raw, relabel=False, min_seq_len=0, tid_start=0): 116 | pid_container = set() 117 | cam_container = set() 118 | for i, (_, pid, cam) in enumerate(dataset_raw): 119 | pid_container.add(pid) 120 | cam_container.add(cam) 121 | num_pids = len(pid_container) 122 | num_cams = len(cam_container) 123 | num_tkls = len(dataset_raw) 124 | 125 | dataset = [] 126 | len_pertkl = [] 127 | tid = tid_start # tid is the key in video Re-ID (as the fname in image Re-ID) 128 | for i in range(num_tkls): 129 | img_names = dataset_raw[i][0] 130 | pid = dataset_raw[i][1] 131 | cam = dataset_raw[i][2] 132 | assert 0 <= cam <= 1 133 | tid_percam = -1 134 | pid_percam = -1 135 | if len(img_names) >= min_seq_len: 136 | dataset.append((img_names, tid, pid, tid_percam, pid_percam, cam)) 137 | tid += 1 138 | len_pertkl.append(len(img_names)) 139 | 140 | # ----------------------------- Next: get the tid_pc and pid_pc -----------------------------# 141 | num_tkls_pc = [] 142 | num_pids_pc = [] 143 | len_pertkl_pc = [] 144 | for i, c in enumerate(cam_container): 145 | # count tid per camera 146 | tkl_index_pc = [index for index, (_, _, _, _, _, camid) in enumerate(dataset) if camid == c] 147 | num_tkls_pc.append(len(tkl_index_pc)) 148 | 149 | # count pid per camera 150 | pid_list_pc = [dataset[i][2] for i in tkl_index_pc] 151 | unique_pid_list_pc = list(set(pid_list_pc)) 152 | num_pids_pc.append(len(unique_pid_list_pc)) 153 | 154 | # count image number per tracklet 155 | len_pertkl_pc.append([len_pertkl[i] for i in tkl_index_pc]) 156 | 157 | pid_percam2label = {pid: label for label, pid in enumerate(unique_pid_list_pc)} 158 | for i, tkl_index in enumerate(tkl_index_pc): 159 | tid = dataset[tkl_index][1] 160 | pid = dataset[tkl_index][2] 161 | tid_pc = i 162 | pid_pc = pid_percam2label[pid] 163 | cam = dataset[tkl_index][5] 164 | dataset[tkl_index] = (dataset[tkl_index][0], tid, pid, tid_pc, pid_pc, cam) 165 | assert num_tkls == sum(num_tkls_pc) 166 | # # check if pid starts from 0 and increments with 1 167 | # for idx, pid in enumerate(pid_container): 168 | # assert idx == pid, "See code comment for explanation" 169 | return dataset, num_tkls, num_pids, num_tkls_pc, num_pids_pc, len_pertkl, len_pertkl_pc 170 | 171 | 172 | 173 | 174 | 175 | 176 | 177 | 178 | 179 | 180 | 181 | 182 | 183 | 184 | -------------------------------------------------------------------------------- /reid/dist_metric.py: -------------------------------------------------------------------------------- 1 | from 
__future__ import absolute_import 2 | 3 | import torch 4 | 5 | from .evaluators_image import extract_features 6 | from .metric_learning import get_metric 7 | 8 | 9 | class DistanceMetric(object): 10 | def __init__(self, algorithm='euclidean', *args, **kwargs): 11 | super(DistanceMetric, self).__init__() 12 | self.algorithm = algorithm 13 | self.metric = get_metric(algorithm, *args, **kwargs) 14 | 15 | def train(self, model, data_loader): 16 | if self.algorithm == 'euclidean': return 17 | features, labels = extract_features(model, data_loader) 18 | features = torch.stack(list(features.values())).numpy() # list() so dict values work under Python 3 19 | labels = torch.Tensor(list(labels.values())).numpy() 20 | self.metric.fit(features, labels) 21 | 22 | def transform(self, X): 23 | if torch.is_tensor(X): 24 | X = X.numpy() 25 | X = self.metric.transform(X) 26 | X = torch.from_numpy(X) 27 | else: 28 | X = self.metric.transform(X) 29 | return X 30 | 31 | -------------------------------------------------------------------------------- /reid/evaluation_metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .classification import accuracy 4 | from .ranking import cmc, mean_ap 5 | 6 | __all__ = [ 7 | 'accuracy', 8 | 'cmc', 9 | 'mean_ap', 10 | ] 11 | -------------------------------------------------------------------------------- /reid/evaluation_metrics/classification.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from ..utils import to_torch 4 | 5 | 6 | def accuracy(output, target, topk=(1,)): 7 | output, target = to_torch(output), to_torch(target) 8 | maxk = max(topk) 9 | batch_size = target.size(0) 10 | 11 | _, pred = output.topk(maxk, 1, True, True) 12 | pred = pred.t() 13 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 14 | 15 | ret = [] 16 | for k in topk: 17 | correct_k = correct[:k].view(-1).float().sum(dim=0, keepdim=True) 18 | ret.append(correct_k.mul_(1. 
/ batch_size)) 19 | return ret 20 | -------------------------------------------------------------------------------- /reid/evaluation_metrics/ranking.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from collections import defaultdict 3 | 4 | import numpy as np 5 | from sklearn.metrics import average_precision_score 6 | 7 | from ..utils import to_numpy 8 | 9 | 10 | def _unique_sample(ids_dict, num): 11 | mask = np.zeros(num, dtype=np.bool) 12 | for _, indices in ids_dict.items(): 13 | i = np.random.choice(indices) 14 | mask[i] = True 15 | return mask 16 | 17 | 18 | def cmc(distmat, query_ids=None, gallery_ids=None, 19 | query_cams=None, gallery_cams=None, topk=100, 20 | separate_camera_set=False, 21 | single_gallery_shot=False, 22 | first_match_break=False): 23 | distmat = to_numpy(distmat) 24 | m, n = distmat.shape 25 | # Fill up default values 26 | if query_ids is None: 27 | query_ids = np.arange(m) 28 | if gallery_ids is None: 29 | gallery_ids = np.arange(n) 30 | if query_cams is None: 31 | query_cams = np.zeros(m).astype(np.int32) 32 | if gallery_cams is None: 33 | gallery_cams = np.ones(n).astype(np.int32) 34 | # Ensure numpy array 35 | query_ids = np.asarray(query_ids) 36 | gallery_ids = np.asarray(gallery_ids) 37 | query_cams = np.asarray(query_cams) 38 | gallery_cams = np.asarray(gallery_cams) 39 | # Sort and find correct matches 40 | indices = np.argsort(distmat, axis=1) 41 | matches = (gallery_ids[indices] == query_ids[:, np.newaxis]) 42 | # Compute CMC for each query 43 | ret = np.zeros(topk) 44 | num_valid_queries = 0 45 | for i in range(m): 46 | # Filter out the same id and same camera 47 | valid = ((gallery_ids[indices[i]] != query_ids[i]) | 48 | (gallery_cams[indices[i]] != query_cams[i])) 49 | if separate_camera_set: 50 | # Filter out samples from same camera 51 | valid &= (gallery_cams[indices[i]] != query_cams[i]) 52 | if not np.any(matches[i, valid]): continue 53 | if single_gallery_shot: 54 | repeat = 10 55 | gids = gallery_ids[indices[i][valid]] 56 | inds = np.where(valid)[0] 57 | ids_dict = defaultdict(list) 58 | for j, x in zip(inds, gids): 59 | ids_dict[x].append(j) 60 | else: 61 | repeat = 1 62 | for _ in range(repeat): 63 | if single_gallery_shot: 64 | # Randomly choose one instance for each id 65 | sampled = (valid & _unique_sample(ids_dict, len(valid))) 66 | index = np.nonzero(matches[i, sampled])[0] 67 | else: 68 | index = np.nonzero(matches[i, valid])[0] 69 | delta = 1. 
/ (len(index) * repeat) 70 | for j, k in enumerate(index): 71 | if k - j >= topk: break 72 | if first_match_break: 73 | ret[k - j] += 1 74 | break 75 | ret[k - j] += delta 76 | num_valid_queries += 1 77 | if num_valid_queries == 0: 78 | raise RuntimeError("No valid query") 79 | return ret.cumsum() / num_valid_queries 80 | 81 | 82 | def mean_ap(distmat, query_ids=None, gallery_ids=None, 83 | query_cams=None, gallery_cams=None): 84 | distmat = to_numpy(distmat) 85 | m, n = distmat.shape 86 | # Fill up default values 87 | if query_ids is None: 88 | query_ids = np.arange(m) 89 | if gallery_ids is None: 90 | gallery_ids = np.arange(n) 91 | if query_cams is None: 92 | query_cams = np.zeros(m).astype(np.int32) 93 | if gallery_cams is None: 94 | gallery_cams = np.ones(n).astype(np.int32) 95 | # Ensure numpy array 96 | query_ids = np.asarray(query_ids) 97 | gallery_ids = np.asarray(gallery_ids) 98 | query_cams = np.asarray(query_cams) 99 | gallery_cams = np.asarray(gallery_cams) 100 | # Sort and find correct matches 101 | indices = np.argsort(distmat, axis=1) 102 | matches = (gallery_ids[indices] == query_ids[:, np.newaxis]) 103 | # Compute AP for each query 104 | aps = [] 105 | for i in range(m): 106 | # Filter out the same id and same camera 107 | valid = ((gallery_ids[indices[i]] != query_ids[i]) | 108 | (gallery_cams[indices[i]] != query_cams[i])) 109 | y_true = matches[i, valid] 110 | y_score = -distmat[i][indices[i]][valid] 111 | if not np.any(y_true): continue 112 | aps.append(average_precision_score(y_true, y_score)) 113 | if len(aps) == 0: 114 | raise RuntimeError("No valid query") 115 | return np.mean(aps) 116 | -------------------------------------------------------------------------------- /reid/evaluators_image.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import time 3 | import torch 4 | from collections import OrderedDict 5 | from .evaluation_metrics import cmc, mean_ap 6 | from .feature_extraction import extract_cnn_feature 7 | from .utils.meters import AverageMeter 8 | import ipdb 9 | 10 | def extract_features(model, data_loader, print_freq=1, metric=None): 11 | model.eval() 12 | batch_time = AverageMeter() 13 | data_time = AverageMeter() 14 | 15 | features = OrderedDict() 16 | labels = OrderedDict() 17 | 18 | end = time.time() 19 | for i, (imgs, fnames, tids, pids, tids_percam, pids_percam, cams) in enumerate(data_loader): 20 | data_time.update(time.time() - end) 21 | 22 | outputs = extract_cnn_feature(model, imgs, cams) 23 | 24 | for fname, output, pid in zip(fnames, outputs, pids): 25 | features[fname] = output 26 | labels[fname] = pid 27 | 28 | batch_time.update(time.time() - end) 29 | end = time.time() 30 | 31 | if (i + 1) % print_freq == 0: 32 | print('Extract Features: [{}/{}]\t' 33 | 'Time {:.3f} ({:.3f})\t' 34 | 'Data {:.3f} ({:.3f})\t' 35 | .format(i + 1, len(data_loader), 36 | batch_time.val, batch_time.avg, 37 | data_time.val, data_time.avg)) 38 | 39 | return features, labels 40 | 41 | 42 | def pairwise_distance(features, query=None, gallery=None, metric=None): 43 | if query is None and gallery is None: 44 | n = len(features) 45 | x = torch.cat(list(features.values())) 46 | x = x.view(n, -1) 47 | if metric is not None: 48 | x = metric.transform(x) 49 | dist = torch.pow(x, 2).sum(dim=1, keepdim=True) * 2 50 | dist = dist.expand(n, n) - 2 * torch.mm(x, x.t()) 51 | return dist 52 | 53 | x = 
torch.cat([features[f].unsqueeze(0) for f, _, _, _, _, _ in query], 0) 54 | y = torch.cat([features[f].unsqueeze(0) for f, _, _, _, _, _ in gallery], 0) 55 | m, n = x.size(0), y.size(0) 56 | x = x.view(m, -1) 57 | y = y.view(n, -1) 58 | if metric is not None: 59 | x = metric.transform(x) 60 | y = metric.transform(y) 61 | dist = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) + \ 62 | torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t() 63 | dist.addmm_(1, -2, x, y.t()) 64 | return dist 65 | 66 | 67 | def evaluate_all(distmat, query=None, gallery=None, 68 | query_ids=None, gallery_ids=None, 69 | query_cams=None, gallery_cams=None, 70 | cmc_topk=(1, 5, 10, 20)): 71 | if query is not None and gallery is not None: 72 | query_ids = [pid for _, _, pid, _, _, _ in query] 73 | gallery_ids = [pid for _, _, pid, _, _, _ in gallery] 74 | query_cams = [cam for _, _, _, _, _, cam in query] 75 | gallery_cams = [cam for _, _, _, _, _, cam in gallery] 76 | else: 77 | assert (query_ids is not None and gallery_ids is not None 78 | and query_cams is not None and gallery_cams is not None) 79 | 80 | # Compute mean AP 81 | mAP = mean_ap(distmat, query_ids, gallery_ids, query_cams, gallery_cams) 82 | print('Mean AP: {:4.1%}'.format(mAP)) 83 | 84 | # Compute all kinds of CMC scores 85 | cmc_configs = { 86 | 'allshots': dict(separate_camera_set=False, 87 | single_gallery_shot=False, 88 | first_match_break=False), 89 | 'cuhk03': dict(separate_camera_set=True, # In cuhk03, query and gallery sets are from different camera views. 90 | single_gallery_shot=True, # The gallery just includes a camera view 91 | first_match_break=False), 92 | 'market1501': dict(separate_camera_set=False, # In Market-1501, query and gallery sets could have same camera views. 93 | single_gallery_shot=False, # The gallery includes multiple camera views 94 | first_match_break=True)} 95 | cmc_scores = {name: cmc(distmat, query_ids, gallery_ids, 96 | query_cams, gallery_cams, **params) 97 | for name, params in cmc_configs.items()} 98 | 99 | print('CMC Scores{:>12}{:>12}{:>12}' 100 | .format('allshots', 'cuhk03', 'market1501')) 101 | for k in cmc_topk: 102 | print(' top-{:<4}{:12.1%}{:12.1%}{:12.1%}' 103 | .format(k, cmc_scores['allshots'][k - 1], 104 | cmc_scores['cuhk03'][k - 1], 105 | cmc_scores['market1501'][k - 1])) 106 | 107 | # Use the allshots cmc top-1 score for validation criterion 108 | return cmc_scores['allshots'][0] 109 | 110 | 111 | class Evaluator(object): 112 | def __init__(self, model): 113 | super(Evaluator, self).__init__() 114 | self.model = model 115 | 116 | def evaluate(self, data_loader, query, gallery, metric=None): 117 | features, _ = extract_features(self.model, data_loader) 118 | distmat = pairwise_distance(features, query, gallery, metric=metric) 119 | return evaluate_all(distmat, query=query, gallery=gallery) 120 | -------------------------------------------------------------------------------- /reid/evaluators_video.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import time 3 | from collections import OrderedDict 4 | 5 | import torch 6 | 7 | from .evaluation_metrics import cmc, mean_ap 8 | from .feature_extraction import extract_cnn_feature 9 | from .utils.meters import AverageMeter 10 | 11 | import ipdb 12 | 13 | def pooling(inputs, method='average'): 14 | _pooling_methods = ['average', 'max'] 15 | assert method in _pooling_methods, "method must be within {}".format(_pooling_methods) 16 | 17 | # num_frames 
= inputs.shape[0] 18 | dim = inputs.shape[1] 19 | 20 | if method == 'average': 21 | feature = torch.mean(inputs, 0) 22 | assert feature.shape[0] == dim 23 | return feature 24 | elif method == 'max': 25 | feature, _ = torch.max(inputs, 0) 26 | assert feature.shape[0] == dim 27 | return feature 28 | 29 | 30 | def extract_features(model, data_loader, print_freq=1, metric=None): 31 | model.eval() 32 | batch_time = AverageMeter() 33 | data_time = AverageMeter() 34 | 35 | features_tkl = OrderedDict() 36 | labels = OrderedDict() 37 | cams = OrderedDict() 38 | end = time.time() 39 | 40 | # VERY IMPORTANT: tid is the key in video Re-ID (as the fname in image Re-ID) 41 | for i, (imgs, fnames, tid, pid, _, _, cam) in enumerate(data_loader): # traverse all data in data_loader(test_data) 42 | data_time.update(time.time() - end) 43 | 44 | tkl_batch_size = imgs.size(0) 45 | num_instances = imgs.size(1) 46 | img_batch_size = tkl_batch_size*num_instances 47 | 48 | imgs = imgs.view(img_batch_size, imgs.size(2), imgs.size(3), imgs.size(4)) 49 | features = extract_cnn_feature(model, imgs, cam) 50 | 51 | for j in range(tkl_batch_size): 52 | index = tid[j] 53 | features_tkl[index] = features[j*num_instances:(j+1)*num_instances] 54 | labels[index] = pid[j] 55 | cams[index] = cam[j] 56 | 57 | batch_time.update(time.time() - end) 58 | end = time.time() 59 | 60 | if (i + 1) % print_freq == 0: 61 | print('Extract Features: [{}/{}]\t' 62 | 'Time {:.3f} ({:.3f})\t' 63 | 'Data {:.3f} ({:.3f})\t' 64 | .format(i + 1, len(data_loader), 65 | batch_time.val, batch_time.avg, 66 | data_time.val, data_time.avg)) 67 | 68 | return features_tkl, labels, cams 69 | 70 | def extract_features_pooling(model, data_loader, print_freq=1, metric=None): 71 | model.eval() 72 | batch_time = AverageMeter() 73 | data_time = AverageMeter() 74 | 75 | features_tkl = OrderedDict() 76 | labels = OrderedDict() 77 | cams = OrderedDict() 78 | end = time.time() 79 | 80 | # VERY IMPORTANT: tid is the key in video Re-ID (as the fname in image Re-ID) 81 | for i, (imgs, fnames, tid, pid, _, _, cam) in enumerate(data_loader): # traverse all data in data_loader(test_data) 82 | data_time.update(time.time() - end) 83 | 84 | tkl_batch_size = imgs.size(0) 85 | num_instances = imgs.size(1) 86 | img_batch_size = tkl_batch_size * num_instances 87 | imgs = imgs.view(img_batch_size, imgs.size(2), imgs.size(3), imgs.size(4)) 88 | 89 | features = extract_cnn_feature(model, imgs, cam) 90 | for j in range(tkl_batch_size): 91 | index = tid[j].item() 92 | features_tkl[index] = pooling(features[j*num_instances:(j+1)*num_instances], 'average') 93 | labels[index] = pid[j] 94 | cams[index] = cam[j] 95 | 96 | batch_time.update(time.time() - end) 97 | end = time.time() 98 | 99 | if (i + 1) % print_freq == 0: 100 | print('Extract Features: [{}/{}]\t' 101 | 'Time {:.3f} ({:.3f})\t' 102 | 'Data {:.3f} ({:.3f})\t' 103 | .format(i + 1, len(data_loader), 104 | batch_time.val, batch_time.avg, 105 | data_time.val, data_time.avg)) 106 | 107 | return features_tkl, labels, cams 108 | 109 | 110 | def pairwise_distance(features, query=None, gallery=None, metric=None): 111 | x = torch.cat([features[tid].unsqueeze(0) for _, tid, _, _, _, _ in query], 0) 112 | y = torch.cat([features[tid].unsqueeze(0) for _, tid, _, _, _, _ in gallery], 0) 113 | 114 | m, n = x.size(0), y.size(0) 115 | x = x.view(m, -1) 116 | y = y.view(n, -1) 117 | 118 | if metric is not None: 119 | x = metric.transform(x) 120 | y = metric.transform(y) 121 | dist = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) + \ 122 | 
torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t() 123 | dist.addmm_(1, -2, x, y.t()) 124 | return dist 125 | 126 | 127 | def set_min_distance(features, query=None, gallery=None, metric=None): 128 | x = torch.cat([features[tid] for _, tid, _, _, _, _ in query], 0) 129 | query_set_num = len(query) 130 | y = torch.cat([features[tid] for _, tid, _, _, _, _ in gallery], 0) 131 | gallery_set_num = len(gallery) 132 | 133 | m, n = x.size(0), y.size(0) 134 | x = x.view(m, -1) 135 | y = y.view(n, -1) 136 | 137 | if metric is not None: 138 | x = metric.transform(x) 139 | y = metric.transform(y) 140 | dist = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) + \ 141 | torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t() 142 | dist.addmm_(1, -2, x, y.t()) 143 | 144 | num_instances = m // query_set_num # integer division so the slice indices below stay ints 145 | set_dist = torch.FloatTensor(query_set_num, gallery_set_num).zero_().cuda() 146 | start = time.time() 147 | for i in range(query_set_num): 148 | for j in range(gallery_set_num): 149 | set_dist[i, j] = torch.min(dist[i*num_instances:(i+1)*num_instances, j*num_instances:(j+1)*num_instances]) 150 | end = time.time() 151 | print('Get set_dist time:{:.1f}s'.format(end-start)) 152 | 153 | return set_dist 154 | 155 | 156 | def evaluate_all(distmat, query=None, gallery=None, 157 | query_ids=None, gallery_ids=None, 158 | query_cams=None, gallery_cams=None, 159 | cmc_topk=(1, 5, 10, 20)): 160 | if query is not None and gallery is not None: 161 | query_ids = [pid for _, _, pid, _, _, _ in query] 162 | gallery_ids = [pid for _, _, pid, _, _, _ in gallery] 163 | query_cams = [cam for _, _, _, _, _, cam in query] 164 | gallery_cams = [cam for _, _, _, _, _, cam in gallery] 165 | else: 166 | assert (query_ids is not None and gallery_ids is not None 167 | and query_cams is not None and gallery_cams is not None) 168 | 169 | # Compute mean AP 170 | mAP = mean_ap(distmat, query_ids, gallery_ids, query_cams, gallery_cams) 171 | print('Mean AP: {:4.1%}'.format(mAP)) 172 | 173 | # Compute all kinds of CMC scores 174 | cmc_configs = { 175 | 'allshots': dict(separate_camera_set=False, 176 | single_gallery_shot=False, 177 | first_match_break=False), 178 | 'cuhk03': dict(separate_camera_set=True, # In cuhk03, query and gallery sets are from different camera views. 179 | single_gallery_shot=True, # The gallery includes just one camera view 180 | first_match_break=False), 181 | 'market1501': dict(separate_camera_set=False, # In Market-1501, query and gallery sets could have same camera views. 
182 |                          single_gallery_shot=False,  # The gallery includes multiple camera views
183 |                          first_match_break=True)}
184 |     cmc_scores = {name: cmc(distmat, query_ids, gallery_ids,
185 |                             query_cams, gallery_cams, **params)
186 |                   for name, params in cmc_configs.items()}
187 | 
188 |     print('CMC Scores{:>12}{:>12}{:>12}'
189 |           .format('allshots', 'cuhk03', 'market1501'))
190 |     for k in cmc_topk:
191 |         print('  top-{:<4}{:12.1%}{:12.1%}{:12.1%}'
192 |               .format(k, cmc_scores['allshots'][k - 1],
193 |                       cmc_scores['cuhk03'][k - 1],
194 |                       cmc_scores['market1501'][k - 1]))
195 | 
196 |     # Use the allshots cmc top-1 score as the validation criterion
197 |     return cmc_scores['allshots'][0]
198 | 
199 | 
200 | class Evaluator(object):
201 |     def __init__(self, model):
202 |         super(Evaluator, self).__init__()
203 |         self.model = model
204 | 
205 |     def evaluate(self, data_loader, query, gallery, metric=None):
206 |         features_tkl, labels, cams = extract_features_pooling(self.model, data_loader)
207 |         distmat = pairwise_distance(features_tkl, query, gallery, metric=metric)
208 |         return evaluate_all(distmat, query=query, gallery=gallery)
209 | 
210 |     # def evaluate(self, data_loader, query, gallery, metric=None):
211 |     #     features_tkl, labels, cams = extract_features(self.model, data_loader)
212 |     #     distmat = set_min_distance(features_tkl, query, gallery, metric=metric)
213 |     #     return evaluate_all(distmat, query=query, gallery=gallery)
214 | 
--------------------------------------------------------------------------------
/reid/feature_extraction/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | 
3 | from .cnn import extract_cnn_feature
4 | from .database import FeatureDatabase
5 | 
6 | __all__ = [
7 |     'extract_cnn_feature',
8 |     'FeatureDatabase',
9 | ]
10 | 
--------------------------------------------------------------------------------
/reid/feature_extraction/cnn.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from collections import OrderedDict
3 | 
4 | import torch
5 | 
6 | from ..utils import to_torch
7 | 
8 | 
9 | def extract_cnn_feature(model, inputs, camid, modules=None):
10 |     model.eval()
11 |     inputs = to_torch(inputs)
12 |     # inference only: torch.no_grad() replaces the removed `volatile` Variable flag
13 |     with torch.no_grad():
14 |         if modules is None:
15 |             # the task id (0) only satisfies the multi-task forward signature;
16 |             # in eval mode the model returns the shared embedding regardless
17 |             outputs = model(inputs, 0)
18 |             return outputs.cpu()
19 |         # Register a forward hook for each module of interest
20 |         outputs = OrderedDict()
21 |         handles = []
22 |         for m in modules:
23 |             outputs[id(m)] = None
24 |             def func(m, i, o): outputs[id(m)] = o.cpu()
25 |             handles.append(m.register_forward_hook(func))
26 |         model(inputs, 0)
27 |         for h in handles:
28 |             h.remove()
29 |         return list(outputs.values())
--------------------------------------------------------------------------------
/reid/feature_extraction/database.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | 
3 | import h5py
4 | import numpy as np
5 | from torch.utils.data import Dataset
6 | 
7 | 
8 | class FeatureDatabase(Dataset):
9 |     def __init__(self, *args, **kwargs):
10 |         super(FeatureDatabase, self).__init__()
11 |         self.fid = h5py.File(*args, **kwargs)
12 | 
13 |     def __enter__(self):
14 |         return self
15 | 
16 |     def __exit__(self, exc_type, exc_val, exc_tb):
17 |         self.close()
18 | 
19 |     def __getitem__(self, keys):
20 |         if
isinstance(keys, (tuple, list)): 21 | return [self._get_single_item(k) for k in keys] 22 | return self._get_single_item(keys) 23 | 24 | def _get_single_item(self, key): 25 | return np.asarray(self.fid[key]) 26 | 27 | def __setitem__(self, key, value): 28 | if key in self.fid: 29 | if self.fid[key].shape == value.shape and \ 30 | self.fid[key].dtype == value.dtype: 31 | self.fid[key][...] = value 32 | else: 33 | del self.fid[key] 34 | self.fid.create_dataset(key, data=value) 35 | else: 36 | self.fid.create_dataset(key, data=value) 37 | 38 | def __delitem__(self, key): 39 | del self.fid[key] 40 | 41 | def __len__(self): 42 | return len(self.fid) 43 | 44 | def __iter__(self): 45 | return iter(self.fid) 46 | 47 | def flush(self): 48 | self.fid.flush() 49 | 50 | def close(self): 51 | self.fid.close() 52 | -------------------------------------------------------------------------------- /reid/loss/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .poset_G2G import PosetLoss_G2G 4 | 5 | __all__ = [ 6 | 'PosetLoss_G2G', 7 | ] 8 | -------------------------------------------------------------------------------- /reid/loss/poset_G2G.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | import math 5 | from torch import nn 6 | from torch.autograd import Variable 7 | from ..utils.meters import AverageMeter 8 | import time 9 | import ipdb 10 | 11 | 12 | class PosetLoss_G2G(nn.Module): 13 | def __init__(self, margin=0): 14 | super(PosetLoss_G2G, self).__init__() 15 | self.margin = margin 16 | # self.ranking_loss = nn.MarginRankingLoss(margin=margin) 17 | 18 | def unique_index(self, L, e): 19 | return [i for i, v in enumerate(L) if v == e] 20 | 21 | def forward(self, batch_features, labels, camid): 22 | sample_num = batch_features.size(0) # sample_num is batchsize 23 | dim = batch_features.size(1) 24 | task_num = len(set(camid)) 25 | 26 | # batch_features to group_features 27 | num_instances = 0 28 | for i_sample in range(sample_num): 29 | if camid[i_sample] == camid[0]: 30 | num_instances += 1 31 | else: 32 | break 33 | group_num = int(sample_num / num_instances) # m is the num of classes in minibatch 34 | 35 | # Compute the mask via camid 36 | mask = torch.ByteTensor(group_num, sample_num).zero_().cuda() 37 | for i_group in range(group_num): 38 | for j_instance in range(num_instances): 39 | mask[i_group][i_group * num_instances + j_instance] = 1 40 | group_features = [] 41 | group_labels = [] 42 | group_camid = [] 43 | for i in range(group_num): 44 | feature = batch_features[mask[i].nonzero().squeeze(), :] 45 | feature_mean = torch.mean(feature, 0).unsqueeze(0) 46 | 47 | label_mean = labels[mask[i].nonzero().squeeze()][0] 48 | camid_mean = camid[mask[i].nonzero().squeeze()][0] 49 | 50 | group_features.append(feature_mean) 51 | group_labels.append(label_mean) 52 | group_camid.append(camid_mean) 53 | group_features = torch.cat(group_features) 54 | group_labels = torch.LongTensor(group_labels).cuda() 55 | group_camid = torch.LongTensor(group_camid).cuda() 56 | 57 | # Group to Group of the batch 58 | x = group_features 59 | y = group_features 60 | n = group_num 61 | m = group_num 62 | dist = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(n, m) + \ 63 | torch.pow(y, 2).sum(dim=1, keepdim=True).expand(m, n).t() 64 | dist.addmm_(1, -2, x, y.t()) 65 | dist = dist.clamp(min=1e-12).sqrt() # for numerical stability 66 
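        # The block below turns Euclidean distances into Gaussian-kernel
        # similarities, exp(-dist / (2 * sigma)), so nearby groups score near 1;
        # for each group, the K most similar cross-camera groups form the
        # numerator and all groups the denominator of the -log ratio that
        # defines the poset loss.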
| sigma = 4 67 | exp_dist = torch.exp(-dist / (2 * sigma)) 68 | 69 | mask = torch.ByteTensor(n, m).zero_().cuda() 70 | dist_near = Variable(torch.FloatTensor(n).zero_().cuda()) 71 | K = int(task_num / 2) 72 | for i in range(n): 73 | taskID = group_camid[i] 74 | mask[i] = (group_camid != taskID) 75 | # len of dist_neib_cross less than len of exp_dist 76 | dist_neib_cross = torch.masked_select(exp_dist[i].data, mask[i]) 77 | value_cross, index_cross = torch.sort(dist_neib_cross, 0, descending=True) # True: big2small False: small2big 78 | dist_near[i] = torch.sum(value_cross[:K]) 79 | dist_all = torch.sum(exp_dist, 1) 80 | 81 | # compute poset_loss 82 | quotient = dist_near / dist_all 83 | # entropy_loss = -torch.sum(quotient * torch.log(quotient))/sample_num 84 | poset_loss = torch.sum(-torch.log(quotient)) / n 85 | assert poset_loss.item() >= 0 86 | return poset_loss 87 | -------------------------------------------------------------------------------- /reid/metric_learning/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from metric_learn import (ITML_Supervised, LMNN, LSML_Supervised, 4 | SDML_Supervised, NCA, LFDA, RCA_Supervised) 5 | 6 | from .euclidean import Euclidean 7 | from .kissme import KISSME 8 | 9 | __factory = { 10 | 'euclidean': Euclidean, 11 | 'kissme': KISSME, 12 | 'itml': ITML_Supervised, 13 | 'lmnn': LMNN, 14 | 'lsml': LSML_Supervised, 15 | 'sdml': SDML_Supervised, 16 | 'nca': NCA, 17 | 'lfda': LFDA, 18 | 'rca': RCA_Supervised, 19 | } 20 | 21 | 22 | def get_metric(algorithm, *args, **kwargs): 23 | if algorithm not in __factory: 24 | raise KeyError("Unknown metric:", algorithm) 25 | return __factory[algorithm](*args, **kwargs) 26 | -------------------------------------------------------------------------------- /reid/metric_learning/euclidean.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import numpy as np 4 | from metric_learn.base_metric import BaseMetricLearner 5 | 6 | 7 | class Euclidean(BaseMetricLearner): 8 | def __init__(self): 9 | self.M_ = None 10 | 11 | def metric(self): 12 | return self.M_ 13 | 14 | def fit(self, X): 15 | self.M_ = np.eye(X.shape[1]) 16 | self.X_ = X 17 | 18 | def transform(self, X=None): 19 | if X is None: 20 | return self.X_ 21 | return X 22 | -------------------------------------------------------------------------------- /reid/metric_learning/kissme.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import numpy as np 4 | from metric_learn.base_metric import BaseMetricLearner 5 | 6 | 7 | def validate_cov_matrix(M): 8 | M = (M + M.T) * 0.5 9 | k = 0 10 | I = np.eye(M.shape[0]) 11 | while True: 12 | try: 13 | _ = np.linalg.cholesky(M) 14 | break 15 | except np.linalg.LinAlgError: 16 | # Find the nearest positive definite matrix for M. 
Modified from 17 | # http://www.mathworks.com/matlabcentral/fileexchange/42885-nearestspd 18 | # Might take several minutes 19 | k += 1 20 | w, v = np.linalg.eig(M) 21 | min_eig = v.min() 22 | M += (-min_eig * k * k + np.spacing(min_eig)) * I 23 | return M 24 | 25 | 26 | class KISSME(BaseMetricLearner): 27 | def __init__(self): 28 | self.M_ = None 29 | 30 | def metric(self): 31 | return self.M_ 32 | 33 | def fit(self, X, y=None): 34 | n = X.shape[0] 35 | if y is None: 36 | y = np.arange(n) 37 | X1, X2 = np.meshgrid(np.arange(n), np.arange(n)) 38 | X1, X2 = X1[X1 < X2], X2[X1 < X2] 39 | matches = (y[X1] == y[X2]) 40 | num_matches = matches.sum() 41 | num_non_matches = len(matches) - num_matches 42 | idxa = X1[matches] 43 | idxb = X2[matches] 44 | S = X[idxa] - X[idxb] 45 | C1 = S.transpose().dot(S) / num_matches 46 | p = np.random.choice(num_non_matches, num_matches, replace=False) 47 | idxa = X1[~matches] 48 | idxb = X2[~matches] 49 | idxa = idxa[p] 50 | idxb = idxb[p] 51 | S = X[idxa] - X[idxb] 52 | C0 = S.transpose().dot(S) / num_matches 53 | self.M_ = np.linalg.inv(C1) - np.linalg.inv(C0) 54 | self.M_ = validate_cov_matrix(self.M_) 55 | self.X_ = X 56 | -------------------------------------------------------------------------------- /reid/models/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from .resnet_mt import * 3 | 4 | 5 | __factory = { 6 | 'resnet18': resnet18, 7 | 'resnet34': resnet34, 8 | 'resnet50': resnet50, 9 | 'resnet101': resnet101, 10 | 'resnet152': resnet152, 11 | } 12 | 13 | 14 | def names(): 15 | return sorted(__factory.keys()) 16 | 17 | 18 | def create(name, *args, **kwargs): 19 | """ 20 | Create a model instance. 21 | 22 | Parameters 23 | ---------- 24 | name : str 25 | Model name. Can be one of 'inception', 'resnet18', 'resnet34', 26 | 'resnet50', 'resnet101', and 'resnet152'. 27 | pretrained : bool, optional 28 | Only applied for 'resnet*' models. If True, will use ImageNet pretrained 29 | model. Default: True 30 | cut_at_pooling : bool, optional 31 | If True, will cut the model before the last global pooling layer and 32 | ignore the remaining kwargs. Default: False 33 | num_features : int, optional 34 | If positive, will append a Linear layer after the global pooling layer, 35 | with this number of output units, followed by a BatchNorm layer. 36 | Otherwise these layers will not be appended. Default: 256 for 37 | 'inception', 0 for 'resnet*' 38 | norm : bool, optional 39 | If True, will normalize the feature to be unit L2-norm for each sample. 40 | Otherwise will append a ReLU layer after the above Linear layer if 41 | num_features > 0. Default: False 42 | dropout : float, optional 43 | If positive, will append a Dropout layer with this dropout rate. 44 | Default: 0 45 | num_classes : int, optional 46 | If positive, will append a Linear layer at the end as the classifier 47 | with this number of output units. 
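        For the multi-task backbone in this repository (resnet_mt), a list of
        per-camera class counts is passed instead, e.g. num_classes=[10, 12, 9]
        (hypothetical counts); one classifier head is then built per camera.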
Default: 0 48 | """ 49 | if name not in __factory: 50 | raise KeyError("Unknown model:", name) 51 | return __factory[name](*args, **kwargs) 52 | -------------------------------------------------------------------------------- /reid/models/resnet_mt.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from torch.nn import init 6 | import torchvision 7 | 8 | import ipdb 9 | 10 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 11 | 'resnet152'] 12 | 13 | class ResNet(nn.Module): 14 | __factory = { 15 | 18: torchvision.models.resnet18, 16 | 34: torchvision.models.resnet34, 17 | 50: torchvision.models.resnet50, 18 | 101: torchvision.models.resnet101, 19 | 152: torchvision.models.resnet152, 20 | } 21 | 22 | def __init__(self, depth, pretrained=True, cut_at_pooling=False, 23 | num_features=0, norm=False, dropout=0, num_classes=0, double_loss=False): 24 | super(ResNet, self).__init__() 25 | 26 | self.depth = depth 27 | self.pretrained = pretrained 28 | self.cut_at_pooling = cut_at_pooling 29 | self.double_loss = double_loss 30 | 31 | # Construct base (pretrained) resnet 32 | if depth not in ResNet.__factory: 33 | raise KeyError("Unsupported depth:", depth) 34 | self.base = ResNet.__factory[depth](pretrained=pretrained) 35 | 36 | if not self.cut_at_pooling: 37 | self.num_features = num_features 38 | self.norm = norm 39 | self.dropout = dropout 40 | self.has_embedding = num_features > 0 41 | self.num_classes = num_classes 42 | self.num_camera = len(num_classes) 43 | 44 | out_planes = self.base.fc.in_features 45 | 46 | # Append new layers 47 | if self.has_embedding: 48 | self.feat = nn.Linear(out_planes, self.num_features) 49 | self.feat_bn = nn.BatchNorm1d(self.num_features) 50 | init.kaiming_normal_(self.feat.weight, mode='fan_out') 51 | init.constant_(self.feat.bias, 0) 52 | init.constant_(self.feat_bn.weight, 1) 53 | init.constant_(self.feat_bn.bias, 0) 54 | else: 55 | # Change the num_features to CNN output channels 56 | self.num_features = out_planes 57 | if self.dropout > 0: 58 | self.drop = nn.Dropout(self.dropout) 59 | 60 | self.multitask_classifier = nn.ModuleList() 61 | for task_i in range(0, self.num_camera): 62 | classifier = nn.Linear(self.num_features, self.num_classes[task_i]) 63 | init.normal_(classifier.weight, std=0.001) 64 | init.constant_(classifier.bias, 0) 65 | self.multitask_classifier += [classifier] 66 | # module = [nn.Linear(self.num_features, self.num_classes[task_i])] 67 | # self.multitask_classifier += module 68 | 69 | if not self.pretrained: 70 | self.reset_params() 71 | 72 | def forward(self, x, cam_i): 73 | for name, module in self.base._modules.items(): 74 | if name == 'avgpool': 75 | break 76 | x = module(x) 77 | 78 | if self.cut_at_pooling: 79 | return x 80 | 81 | x = F.avg_pool2d(x, x.size()[2:]) 82 | x = x.view(x.size(0), -1) 83 | 84 | if self.has_embedding: 85 | x = self.feat(x) 86 | x = self.feat_bn(x) 87 | if self.norm: 88 | x = F.normalize(x) 89 | elif self.has_embedding: 90 | x = F.relu(x) 91 | if self.dropout > 0: 92 | x = self.drop(x) 93 | 94 | if not self.training: 95 | return x # x is self.num_features(2048D) 96 | elif self.num_classes[cam_i] > 0: 97 | prelogits = self.multitask_classifier[cam_i](x) 98 | 99 | if self.double_loss: 100 | return prelogits, x 101 | else: 102 | return prelogits 103 | 104 | def reset_params(self): 105 | for m in self.modules(): 106 | if isinstance(m, 
nn.Conv2d): 107 | init.kaiming_normal(m.weight, mode='fan_out') 108 | if m.bias is not None: 109 | init.constant(m.bias, 0) 110 | elif isinstance(m, nn.BatchNorm2d): 111 | init.constant(m.weight, 1) 112 | init.constant(m.bias, 0) 113 | elif isinstance(m, nn.Linear): 114 | init.normal(m.weight, std=0.001) 115 | if m.bias is not None: 116 | init.constant(m.bias, 0) 117 | 118 | 119 | def resnet18(**kwargs): 120 | return ResNet(18, **kwargs) 121 | 122 | 123 | def resnet34(**kwargs): 124 | return ResNet(34, **kwargs) 125 | 126 | 127 | def resnet50(**kwargs): 128 | return ResNet(50, **kwargs) 129 | 130 | 131 | def resnet101(**kwargs): 132 | return ResNet(101, **kwargs) 133 | 134 | 135 | def resnet152(**kwargs): 136 | return ResNet(152, **kwargs) 137 | -------------------------------------------------------------------------------- /reid/trainers.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import time 3 | 4 | import torch 5 | from torch.autograd import Variable 6 | from .evaluation_metrics import accuracy 7 | from .loss import PosetLoss_G2G 8 | from .utils.meters import AverageMeter 9 | import ipdb 10 | 11 | class BaseTrainer(object): 12 | def __init__(self, model, criterion_1, criterion_2, num_task): 13 | super(BaseTrainer, self).__init__() 14 | self.model = model 15 | self.criterion_1 = criterion_1 16 | self.criterion_2 = criterion_2 17 | self.num_task = num_task 18 | 19 | def train(self, epoch, data_loader, optimizer, print_freq=1): 20 | self.model.train() 21 | 22 | batch_time = AverageMeter() 23 | data_time = AverageMeter() 24 | losses = AverageMeter() 25 | precisions = AverageMeter() 26 | 27 | end = time.time() 28 | for i, inputs in enumerate(data_loader): 29 | data_time.update(time.time() - end) 30 | 31 | # inputs is the image data, targets is the pid label, camid is the cam 32 | inputs, targets, camid = self._parse_data(inputs) 33 | loss, loss_1, loss_2, prec1, loss_1_time, loss_2_time = self._forward(inputs, targets, camid) 34 | 35 | losses.update(loss.item(), targets.size(0)) 36 | precisions.update(prec1, targets.size(0)) 37 | 38 | optimizer.zero_grad() 39 | loss.backward() 40 | optimizer.step() 41 | 42 | batch_time.update(time.time() - end) 43 | end = time.time() 44 | 45 | if (i + 1) % print_freq == 0: 46 | print('Epoch: [{}][{}/{}]\t' 47 | 'Loss_1 {:.3f} \t' 48 | 'Loss_2 {:.3f} \t' 49 | 'Loss {:.3f} ({:.3f})\t' 50 | 'Prec {:.2%} ({:.2%})\t' 51 | 'Loss_1_Time {:.3f} ({:.3f})\t' 52 | 'Loss_2_Time {:.3f} ({:.3f})\t' 53 | .format(epoch, i + 1, len(data_loader), 54 | loss_1.item(), 55 | loss_2.item(), 56 | # 0, 57 | losses.val, losses.avg, 58 | precisions.val, precisions.avg, 59 | loss_1_time.val, loss_1_time.avg, 60 | loss_2_time.val, loss_2_time.avg)) 61 | 62 | def _parse_data(self, inputs): 63 | raise NotImplementedError 64 | 65 | def _forward(self, inputs, targets, camid): 66 | raise NotImplementedError 67 | 68 | def unique_index(L, e): 69 | return [i for i, v in enumerate(L) if v == e] 70 | 71 | class Trainer(BaseTrainer): 72 | def _parse_data(self, inputs): 73 | imgs, fname, _, _, tids_pc, _, camids = inputs # image_data, image_names, tids, pids, tids_pc, pids_pc, cams 74 | 75 | inputs = imgs.cuda() 76 | targets = tids_pc.cuda() 77 | camids = camids.cuda() 78 | return inputs, targets, camids 79 | 80 | def _forward(self, inputs, labels, camid): 81 | if isinstance(self.criterion_1, torch.nn.CrossEntropyLoss) and isinstance(self.criterion_2, PosetLoss_G2G): 82 | loss_1_time = 
AverageMeter() 83 | loss_2_time = AverageMeter() 84 | 85 | end = time.time() 86 | # compute loss_1 87 | loss_1 = Variable(torch.FloatTensor(1).zero_().cuda()) 88 | prec = 0 89 | batch_features = [] 90 | for t in range(self.num_task): 91 | sample_index = torch.LongTensor(unique_index(camid, t)).cuda() 92 | if len(sample_index) > 0: 93 | labels_t = torch.index_select(labels, 0, sample_index) # labels in task t 94 | inputs_t = Variable(torch.index_select(inputs, 0, sample_index)) 95 | prelogits_cam_i, features_cam_i = self.model(inputs_t, t) 96 | 97 | # loss_1 98 | loss_1 += self.criterion_1(prelogits_cam_i, Variable(labels_t)) 99 | prec_1, = accuracy(prelogits_cam_i.data, labels_t) 100 | prec += prec_1[0] 101 | 102 | batch_features.append(features_cam_i) # concentrate the features for computing loss_2 103 | batch_features = torch.cat(batch_features) 104 | prec = prec / self.num_task 105 | loss_1_time.update(time.time() - end) 106 | 107 | # compute loss_2 108 | end = time.time() 109 | loss_2 = self.criterion_2(batch_features, labels, camid) 110 | 111 | # sum 112 | lamda = 0.7 113 | loss = (1 - lamda) * loss_1 + lamda * loss_2 114 | else: 115 | raise ValueError("Unsupported loss:", self.criterion_1) 116 | return loss, (1 - lamda) * loss_1, lamda * loss_2, prec, loss_1_time, loss_2_time 117 | -------------------------------------------------------------------------------- /reid/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | 5 | 6 | def to_numpy(tensor): 7 | if torch.is_tensor(tensor): 8 | return tensor.cpu().numpy() 9 | elif type(tensor).__module__ != 'numpy': 10 | raise ValueError("Cannot convert {} to numpy array" 11 | .format(type(tensor))) 12 | return tensor 13 | 14 | 15 | def to_torch(ndarray): 16 | if type(ndarray).__module__ == 'numpy': 17 | return torch.from_numpy(ndarray) 18 | elif not torch.is_tensor(ndarray): 19 | raise ValueError("Cannot convert {} to torch tensor" 20 | .format(type(ndarray))) 21 | return ndarray 22 | -------------------------------------------------------------------------------- /reid/utils/data/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .dataset import Dataset 4 | from .preprocessor_image import Preprocessor_Image 5 | from .preprocessor_video import Preprocessor_Video 6 | -------------------------------------------------------------------------------- /reid/utils/data/dataset.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os.path as osp 3 | import numpy as np 4 | import ipdb 5 | 6 | from ..serialization import read_json 7 | 8 | def _pluck(identities, indices, relabel=False): 9 | ret = [] 10 | for index, pid in enumerate(indices): 11 | pid_images = identities[pid] 12 | for camid, cam_images in enumerate(pid_images): 13 | for fname in cam_images: 14 | name = osp.splitext(fname)[0] 15 | x, y, _ = map(int, name.split('_')) 16 | assert pid == x and camid == y 17 | if relabel: 18 | ret.append((fname, index, camid)) 19 | else: 20 | ret.append((fname, pid, camid)) 21 | return ret 22 | 23 | def _mt_pluck(identities, indices, num_cameras, relabel=False): 24 | temp = [] 25 | for index, pid in enumerate(indices): 26 | pid_images = identities[pid] 27 | for camid, cam_images in enumerate(pid_images): 28 | for fname in cam_images: 29 | name = osp.splitext(fname)[0] 
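                # fname encodes "<pid>_<camid>_<index>" (e.g. "00000012_01_0003.jpg",
                # an illustrative name), so both ids can be recovered and cross-checked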
30 | x, y, _ = map(int, name.split('_')) 31 | assert pid == x and camid == y 32 | temp.append((fname, pid, camid)) 33 | 34 | ret = [] 35 | id_num_cam = [] 36 | training_img_num = len(temp) 37 | for camid in range(0,num_cameras): 38 | pid_cam = -1 39 | for index in range(0,training_img_num): 40 | if temp[index][2] == camid: 41 | if pid_cam == -1: 42 | fname = temp[index][0] 43 | pid_cam = pid_cam + 1 44 | ret.append((fname, pid_cam, camid)) 45 | cur_pid_total = temp[index][1] 46 | else: 47 | fname = temp[index][0] 48 | if temp[index][1] == cur_pid_total: 49 | ret.append((fname, pid_cam, camid)) 50 | else: 51 | pid_cam = pid_cam + 1 52 | ret.append((fname, pid_cam, camid)) 53 | cur_pid_total = temp[index][1] 54 | id_num_cam.append(pid_cam+1) 55 | 56 | return ret, id_num_cam 57 | 58 | class Dataset(object): 59 | def __init__(self, root, split_id=0): 60 | self.root = root 61 | self.split_id = split_id 62 | self.meta = None 63 | self.split = None 64 | self.train, self.val, self.trainval = [], [], [] 65 | self.query, self.gallery = [], [] 66 | self.num_train_ids, self.num_val_ids, self.num_trainval_ids = 0, 0, 0 67 | 68 | @property 69 | def images_dir(self): 70 | return osp.join(self.root, 'images') 71 | 72 | def load(self, num_val=0.3, verbose=True): 73 | splits = read_json(osp.join(self.root, 'splits.json')) 74 | if self.split_id >= len(splits): 75 | raise ValueError("split_id exceeds total splits {}" 76 | .format(len(splits))) 77 | self.split = splits[self.split_id] 78 | 79 | # Randomly split train / val 80 | trainval_pids = np.asarray(self.split['trainval']) 81 | 82 | np.random.shuffle(trainval_pids) 83 | num = len(trainval_pids) 84 | if isinstance(num_val, float): 85 | num_val = int(round(num * num_val)) 86 | if num_val >= num or num_val < 0: 87 | raise ValueError("num_val exceeds total identities {}" 88 | .format(num)) 89 | train_pids = sorted(trainval_pids[:-num_val]) 90 | val_pids = sorted(trainval_pids[-num_val:]) 91 | 92 | self.meta = read_json(osp.join(self.root, 'meta.json')) 93 | identities = self.meta['identities'] 94 | num_cameras = self.meta['num_cameras'] 95 | 96 | self.train = _pluck(identities, train_pids, relabel=True) # relabel is necessary, id number of training set is the num_class 97 | self.val = _pluck(identities, val_pids, relabel=True) 98 | self.trainval = _pluck(identities, trainval_pids, relabel=True) 99 | self.query = _pluck(identities, self.split['query']) 100 | self.gallery = _pluck(identities, self.split['gallery']) 101 | self.num_train_ids = len(train_pids) 102 | self.num_val_ids = len(val_pids) 103 | self.num_trainval_ids = len(trainval_pids) 104 | 105 | self.mt_train, self.num_train_ids_cam = _mt_pluck(identities, train_pids, num_cameras, relabel=True) 106 | self.mt_trainval, self.num_trainval_ids_cam = _mt_pluck(identities, trainval_pids, num_cameras, relabel=True)# usually use this 107 | 108 | if verbose: 109 | print(self.__class__.__name__, "dataset loaded") 110 | print(" subset | # ids | # images") 111 | print(" ---------------------------") 112 | print(" train | {:5d} | {:8d}" 113 | .format(self.num_train_ids, len(self.train))) 114 | print(" val | {:5d} | {:8d}" 115 | .format(self.num_val_ids, len(self.val))) 116 | print(" trainval | {:5d} | {:8d}" 117 | .format(self.num_trainval_ids, len(self.trainval))) 118 | print(" query | {:5d} | {:8d}" 119 | .format(len(self.split['query']), len(self.query))) 120 | print(" gallery | {:5d} | {:8d}" 121 | .format(len(self.split['gallery']), len(self.gallery))) 122 | 123 | def _check_integrity(self): 124 | return 
osp.isdir(osp.join(self.root, 'images')) and \ 125 | osp.isfile(osp.join(self.root, 'meta.json')) and \ 126 | osp.isfile(osp.join(self.root, 'splits.json')) 127 | -------------------------------------------------------------------------------- /reid/utils/data/preprocessor_image.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import os.path as osp 3 | 4 | from PIL import Image 5 | 6 | class Preprocessor_Image(object): 7 | def __init__(self, dataset, root=None, transform=None): 8 | super(Preprocessor_Image, self).__init__() 9 | self.dataset = dataset 10 | self.root = root 11 | self.transform = transform 12 | 13 | def __len__(self): 14 | return len(self.dataset) 15 | 16 | def __getitem__(self, indices): 17 | if isinstance(indices, (tuple, list)): 18 | return [self._get_single_item(index) for index in indices] 19 | return self._get_single_item(indices) 20 | 21 | def _get_single_item(self, index): 22 | fname, tid, pid, tidpc, pidpc, camid = self.dataset[index] 23 | fpath = fname 24 | # if self.root is not None: 25 | # fpath = osp.join(self.root, fname) 26 | img = Image.open(fpath).convert('RGB') 27 | if self.transform is not None: 28 | img = self.transform(img) 29 | return img, fname, tid, pid, tidpc, pidpc, camid 30 | -------------------------------------------------------------------------------- /reid/utils/data/preprocessor_video.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import numpy as np 3 | import torch 4 | from PIL import Image 5 | 6 | import ipdb 7 | 8 | class Preprocessor_Video(object): 9 | """ 10 | This class deals with video-reid where each tracklet has a number 11 | of images and only a fixed number of images is selected. 
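    Sampling is controlled by `sample`: 'random' draws seq_len frames (with
    replacement when the tracklet is shorter), 'evenly' takes seq_len evenly
    spaced frames, and 'all' keeps every frame (batch_size must then be 1).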
12 | """ 13 | _sample_methods = ['evenly', 'random', 'all'] 14 | 15 | def __init__(self, dataset, root=None, transform=None, seq_len=1, sample='evenly'): 16 | super(Preprocessor_Video, self).__init__() 17 | assert sample in self._sample_methods, "sample must be within {}".format(self._sample_methods) 18 | assert transform is not None 19 | self.dataset = dataset 20 | self.root = root 21 | self.transform = transform 22 | self.seq_len = seq_len 23 | self.sample = sample 24 | 25 | def __len__(self): 26 | return len(self.dataset) 27 | 28 | def __getitem__(self, indices): 29 | if isinstance(indices, (tuple, list)): 30 | return [self._get_single_item(index) for index in indices] 31 | return self._get_single_item(indices) 32 | 33 | def _get_single_item(self, index): 34 | fpaths, tid, pid, tidpc, pidpc, camid = self.dataset[index] 35 | num = len(fpaths) # the frame num of a tracklet 36 | 37 | if self.sample == 'random': 38 | """ 39 | Randomly sample seq_len items from num items, if num is smaller than 40 | seq_len, replicating items is adopted 41 | """ 42 | indices = np.arange(num) 43 | if num >= self.seq_len: 44 | indices = np.random.choice(indices, size=self.seq_len, replace=False) 45 | else: 46 | indices = np.random.choice(indices, size=self.seq_len, replace=True) 47 | # TODO: disable the sorting to achieve order-agnostic 48 | #indices = np.sort(indices) 49 | elif self.sample == 'evenly': 50 | """ 51 | Evenly sample seq_len items from num items 52 | """ 53 | assert num >= self.seq_len, "condition failed: num ({}) >= self.seq_len ({})".format(num, self.seq_len) 54 | num -= num % self.seq_len 55 | indices = np.arange(0, num, num/self.seq_len) 56 | assert len(indices) == self.seq_len 57 | elif self.sample == 'all': 58 | """ 59 | Sample all items, seq_len is useless now and batch_size needs 60 | to be set to 1 otherwise error will occur 61 | """ 62 | indices = np.arange(num) 63 | else: 64 | raise KeyError("unknown sample method: {}".format(self.sample)) 65 | 66 | imgs = [] 67 | fnames = [] 68 | for idx in indices: 69 | fpath = fpaths[idx] 70 | img = Image.open(fpath).convert('RGB') 71 | if self.transform is not None: img = self.transform(img) 72 | img = img.unsqueeze(0) 73 | imgs.append(img) 74 | fnames.append(fpath) 75 | imgs = torch.cat(imgs, dim=0) 76 | 77 | # return imgs, fpaths[0], tid, pid, camid 78 | return imgs, fnames, tid, pid, tidpc, pidpc, camid 79 | -------------------------------------------------------------------------------- /reid/utils/data/sampler.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from collections import defaultdict 3 | 4 | import numpy as np 5 | import torch 6 | from torch.utils.data.sampler import ( 7 | Sampler, SequentialSampler, RandomSampler, SubsetRandomSampler, 8 | WeightedRandomSampler) 9 | 10 | import ipdb 11 | 12 | 13 | class RandomIdentitySampler(Sampler): 14 | def __init__(self, data_source, num_instances=1): 15 | self.data_source = data_source 16 | self.num_instances = num_instances 17 | self.index_dic = defaultdict(list) 18 | for index, (_, _, pid, _) in enumerate(data_source): 19 | self.index_dic[pid].append(index) 20 | self.pids = list(self.index_dic.keys()) 21 | self.num_samples = len(self.pids) 22 | 23 | def __len__(self): 24 | return self.num_samples * self.num_instances 25 | 26 | def __iter__(self): 27 | indices = torch.randperm(self.num_samples) 28 | ret = [] 29 | for i in indices: 30 | pid = self.pids[i] 31 | t = self.index_dic[pid] 32 | if len(t) >= 
self.num_instances: 33 | t = np.random.choice(t, size=self.num_instances, replace=False) 34 | else: 35 | t = np.random.choice(t, size=self.num_instances, replace=True) 36 | ret.extend(t) 37 | return iter(ret) 38 | -------------------------------------------------------------------------------- /reid/utils/data/sampler_mt.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from collections import defaultdict 3 | 4 | import numpy as np 5 | import torch 6 | from torch.utils.data.sampler import ( 7 | Sampler, SequentialSampler, RandomSampler, SubsetRandomSampler, 8 | WeightedRandomSampler) 9 | 10 | import ipdb 11 | 12 | def nested_dict(n, type): 13 | if n == 1: 14 | return defaultdict(type) 15 | else: 16 | return defaultdict(lambda: nested_dict(n-1, type)) 17 | 18 | class RandomIdentitySampler(Sampler): 19 | def __init__(self, data_source, num_instances=1, num_task=1): 20 | self.data_source = data_source 21 | self.num_instances = num_instances 22 | self.num_task = num_task 23 | self.index_dic_mt = nested_dict(2, list) 24 | 25 | for index, (_, _, _, tid_pc, _, camid) in enumerate(data_source): 26 | self.index_dic_mt[camid][tid_pc].append(index) # camid equals taskid 27 | 28 | self.pids = [0]*num_task 29 | self.num_samples = [0]*num_task 30 | for cam_index in range(num_task): 31 | self.pids[cam_index] = list(self.index_dic_mt[cam_index].keys()) 32 | self.num_samples[cam_index] = len(self.pids[cam_index]) 33 | 34 | def __len__(self): 35 | num_samples = 0 36 | for t in range(self.num_task): 37 | num_samples += self.num_samples[t] 38 | return num_samples * self.num_instances 39 | 40 | def __iter__(self): 41 | ret = [] 42 | indices = nested_dict(self.num_task, list) 43 | for t in range(self.num_task): 44 | indices[t] = torch.randperm(self.num_samples[t]) 45 | # train_num = max(self.num_samples) 46 | train_num = min(self.num_samples) 47 | 48 | for tkl_index in range(train_num): 49 | for cam_index in range(self.num_task): 50 | loop_tkl_index = tkl_index % len(self.pids[cam_index]) 51 | i = indices[cam_index][loop_tkl_index] # pid_index has been shuffled 52 | pid = self.pids[cam_index][i] 53 | fid = self.index_dic_mt[cam_index][pid] 54 | if len(fid) >= self.num_instances: 55 | fid = np.random.choice(fid, size=self.num_instances, replace=False) 56 | else: 57 | fid = np.random.choice(fid, size=self.num_instances, replace=True) 58 | ret.extend(fid) 59 | return iter(ret) 60 | -------------------------------------------------------------------------------- /reid/utils/data/transforms.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from torchvision.transforms import * 4 | from PIL import Image 5 | import math 6 | import random 7 | 8 | 9 | class RectScale(object): 10 | def __init__(self, height, width, interpolation=Image.BILINEAR): 11 | self.height = height 12 | self.width = width 13 | self.interpolation = interpolation 14 | 15 | def __call__(self, img): 16 | w, h = img.size 17 | if h == self.height and w == self.width: 18 | return img 19 | return img.resize((self.width, self.height), self.interpolation) 20 | 21 | 22 | class RandomSizedRectCrop(object): 23 | def __init__(self, height, width, interpolation=Image.BILINEAR): 24 | self.height = height 25 | self.width = width 26 | self.interpolation = interpolation 27 | 28 | def __call__(self, img): 29 | for attempt in range(10): 30 | area = img.size[0] * img.size[1] 31 | target_area = 
random.uniform(0.64, 1.0) * area 32 | aspect_ratio = random.uniform(2, 3) 33 | 34 | h = int(round(math.sqrt(target_area * aspect_ratio))) 35 | w = int(round(math.sqrt(target_area / aspect_ratio))) 36 | 37 | if w <= img.size[0] and h <= img.size[1]: 38 | x1 = random.randint(0, img.size[0] - w) 39 | y1 = random.randint(0, img.size[1] - h) 40 | 41 | img = img.crop((x1, y1, x1 + w, y1 + h)) 42 | assert(img.size == (w, h)) 43 | 44 | return img.resize((self.width, self.height), self.interpolation) 45 | 46 | # Fallback 47 | scale = RectScale(self.height, self.width, 48 | interpolation=self.interpolation) 49 | return scale(img) 50 | -------------------------------------------------------------------------------- /reid/utils/iotools.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import os 4 | import os.path as osp 5 | import errno 6 | import json 7 | import shutil 8 | 9 | import torch 10 | 11 | 12 | def mkdir_if_missing(directory): 13 | if not osp.exists(directory): 14 | try: 15 | os.makedirs(directory) 16 | except OSError as e: 17 | if e.errno != errno.EEXIST: 18 | raise 19 | 20 | 21 | def read_json(fpath): 22 | with open(fpath, 'r') as f: 23 | obj = json.load(f) 24 | return obj 25 | 26 | 27 | def write_json(obj, fpath): 28 | mkdir_if_missing(osp.dirname(fpath)) 29 | with open(fpath, 'w') as f: 30 | json.dump(obj, f, indent=4, separators=(',', ': ')) 31 | 32 | 33 | def save_checkpoint(state, is_best=False, fpath='checkpoint.pth.tar'): 34 | if len(osp.dirname(fpath)) != 0: 35 | mkdir_if_missing(osp.dirname(fpath)) 36 | torch.save(state, fpath) 37 | if is_best: 38 | shutil.copy(fpath, osp.join(osp.dirname(fpath), 'best_model.pth.tar')) -------------------------------------------------------------------------------- /reid/utils/logging.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import os 3 | import sys 4 | 5 | from .osutils import mkdir_if_missing 6 | 7 | 8 | class Logger(object): 9 | def __init__(self, fpath=None): 10 | self.console = sys.stdout 11 | self.file = None 12 | if fpath is not None: 13 | mkdir_if_missing(os.path.dirname(fpath)) 14 | self.file = open(fpath, 'w') 15 | 16 | def __del__(self): 17 | self.close() 18 | 19 | def __enter__(self): 20 | pass 21 | 22 | def __exit__(self, *args): 23 | self.close() 24 | 25 | def write(self, msg): 26 | self.console.write(msg) 27 | if self.file is not None: 28 | self.file.write(msg) 29 | 30 | def flush(self): 31 | self.console.flush() 32 | if self.file is not None: 33 | self.file.flush() 34 | os.fsync(self.file.fileno()) 35 | 36 | def close(self): 37 | self.console.close() 38 | if self.file is not None: 39 | self.file.close() 40 | -------------------------------------------------------------------------------- /reid/utils/meters.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | 4 | class AverageMeter(object): 5 | """Computes and stores the average and current value""" 6 | 7 | def __init__(self): 8 | self.val = 0 9 | self.avg = 0 10 | self.sum = 0 11 | self.count = 0 12 | 13 | def reset(self): 14 | self.val = 0 15 | self.avg = 0 16 | self.sum = 0 17 | self.count = 0 18 | 19 | def update(self, val, n=1): 20 | self.val = val 21 | self.sum += val * n 22 | self.count += n 23 | self.avg = self.sum / self.count 24 | -------------------------------------------------------------------------------- 
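A quick sanity check of the meter semantics above (toy numbers; `val` is the last update, `avg` is count-weighted):

```python
from reid.utils.meters import AverageMeter

m = AverageMeter()
m.update(2.0)       # val=2.0, avg=2.0
m.update(4.0, n=3)  # one value counted three times
assert m.val == 4.0
assert m.avg == (2.0 + 4.0 * 3) / 4  # 3.5
```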
/reid/utils/osutils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import os 3 | import errno 4 | 5 | 6 | def mkdir_if_missing(dir_path): 7 | try: 8 | os.makedirs(dir_path) 9 | except OSError as e: 10 | if e.errno != errno.EEXIST: 11 | raise 12 | -------------------------------------------------------------------------------- /reid/utils/serialization.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import json 3 | import os.path as osp 4 | import shutil 5 | 6 | import torch 7 | from torch.nn import Parameter 8 | 9 | from .osutils import mkdir_if_missing 10 | 11 | 12 | def read_json(fpath): 13 | with open(fpath, 'r') as f: 14 | obj = json.load(f) 15 | return obj 16 | 17 | 18 | def write_json(obj, fpath): 19 | mkdir_if_missing(osp.dirname(fpath)) 20 | with open(fpath, 'w') as f: 21 | json.dump(obj, f, indent=4, separators=(',', ': ')) 22 | 23 | 24 | def save_checkpoint(state, is_best, fpath='checkpoint.pth.tar'): 25 | mkdir_if_missing(osp.dirname(fpath)) 26 | torch.save(state, fpath) 27 | if is_best: 28 | shutil.copy(fpath, osp.join(osp.dirname(fpath), 'model_best.pth.tar')) 29 | 30 | 31 | def load_checkpoint(fpath): 32 | if osp.isfile(fpath): 33 | checkpoint = torch.load(fpath) 34 | print("=> Loaded checkpoint '{}'".format(fpath)) 35 | return checkpoint 36 | else: 37 | raise ValueError("=> No checkpoint found at '{}'".format(fpath)) 38 | 39 | 40 | def copy_state_dict(state_dict, model, strip=None): 41 | tgt_state = model.state_dict() 42 | copied_names = set() 43 | for name, param in state_dict.items(): 44 | if strip is not None and name.startswith(strip): 45 | name = name[len(strip):] 46 | if name not in tgt_state: 47 | continue 48 | if isinstance(param, Parameter): 49 | param = param.data 50 | if param.size() != tgt_state[name].size(): 51 | print('mismatch:', name, param.size(), tgt_state[name].size()) 52 | continue 53 | tgt_state[name].copy_(param) 54 | copied_names.add(name) 55 | 56 | missing = set(tgt_state.keys()) - copied_names 57 | if len(missing) > 0: 58 | print("missing keys in state_dict:", missing) 59 | 60 | return model 61 | -------------------------------------------------------------------------------- /scripts/cuhk03.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | data_dir=/home/liva7/Data 3 | data_set=CUHK03 4 | 5 | CUDA_VISIBLE_DEVICES=3 python3 taudl_image.py \ 6 | --data-dir ${data_dir} \ 7 | -d ${data_set} \ 8 | -b 128 \ 9 | -a resnet50 \ 10 | --features 2048 \ 11 | --epochs 200 \ 12 | --num-instances 4 \ 13 | --start_save 100 -------------------------------------------------------------------------------- /scripts/duke_mr_tkl.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | data_dir=/home/liva7/Data 3 | data_set=DukeMTMC-MR-Tracklet 4 | 5 | CUDA_VISIBLE_DEVICES=1,2,3 python3 taudl_video.py \ 6 | --data-dir ${data_dir} \ 7 | -d ${data_set} \ 8 | -b 384 \ 9 | -a resnet50 \ 10 | --features 2048 \ 11 | --epochs 200 \ 12 | --num-instances 4 \ 13 | --start_save 100 -------------------------------------------------------------------------------- /scripts/duke_reid.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | data_dir=/home/liva7/Data 3 | data_set=DukeMTMC-reID 4 | 5 | CUDA_VISIBLE_DEVICES=1,2,3 python3 
taudl_image.py \ 6 | --data-dir ${data_dir} \ 7 | -d ${data_set} \ 8 | -b 384 \ 9 | -a resnet50 \ 10 | --features 2048 \ 11 | --epochs 200 \ 12 | --num-instances 4 \ 13 | --start_save 100 -------------------------------------------------------------------------------- /scripts/duke_si_tkl.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | data_dir=/home/liva7/Data 3 | data_set=DukeMTMC-SI-Tracklet 4 | 5 | CUDA_VISIBLE_DEVICES=1,2,3 python3 taudl_video.py \ 6 | --data-dir ${data_dir} \ 7 | -d ${data_set} \ 8 | -b 384 \ 9 | -a resnet50 \ 10 | --features 2048 \ 11 | --epochs 200 \ 12 | --num-instances 4 \ 13 | --start_save 100 -------------------------------------------------------------------------------- /scripts/ilids-vid.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | data_dir=/home/liva7/Data 3 | data_set=iLIDS-VID 4 | 5 | CUDA_VISIBLE_DEVICES=1,2,3 python3 taudl_video.py \ 6 | --data-dir ${data_dir} \ 7 | -d ${data_set} \ 8 | -b 384 \ 9 | -a resnet50 \ 10 | --features 2048 \ 11 | --epochs 200 \ 12 | --num-instances 16 \ 13 | --start_save 100 -------------------------------------------------------------------------------- /scripts/market1501.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | data_dir=/home/liva7/Data 3 | data_set=Market1501 4 | 5 | CUDA_VISIBLE_DEVICES=1,2,3 python3 taudl_image.py \ 6 | --data-dir ${data_dir} \ 7 | -d ${data_set} \ 8 | -b 384 \ 9 | -a resnet50 \ 10 | --features 2048 \ 11 | --epochs 200 \ 12 | --num-instances 4 \ 13 | --start_save 100 -------------------------------------------------------------------------------- /scripts/mars.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | data_dir=/home/liva7/Data 3 | data_set=Mars 4 | 5 | CUDA_VISIBLE_DEVICES=1,2,3 python3 taudl_video.py \ 6 | --data-dir ${data_dir} \ 7 | -d ${data_set} \ 8 | -b 384 \ 9 | -a resnet50 \ 10 | --features 2048 \ 11 | --epochs 200 \ 12 | --num-instances 4 \ 13 | --start_save 100 -------------------------------------------------------------------------------- /scripts/msmt17.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | data_dir=/home/liva7/Data 3 | data_set=MSMT17 4 | 5 | CUDA_VISIBLE_DEVICES=1,2,3 python3 taudl_image.py \ 6 | --data-dir ${data_dir} \ 7 | -d ${data_set} \ 8 | -b 384 \ 9 | -a resnet50 \ 10 | --features 2048 \ 11 | --epochs 200 \ 12 | --num-instances 4 \ 13 | --start_save 100 -------------------------------------------------------------------------------- /scripts/prid2011.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | data_dir=/home/liva7/Data 3 | data_set=PRID2011 4 | 5 | CUDA_VISIBLE_DEVICES=1,2,3 python3 taudl_video.py \ 6 | --data-dir ${data_dir} \ 7 | -d ${data_set} \ 8 | -b 384 \ 9 | -a resnet50 \ 10 | --features 2048 \ 11 | --epochs 200 \ 12 | --num-instances 16 \ 13 | --start_save 100 -------------------------------------------------------------------------------- /taudl_image.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import argparse 3 | import os.path as osp 4 | 5 | import numpy as np 6 | import sys 7 | import torch 8 | import time 9 | import datetime 10 | from torch import nn 11 | 
from torch.backends import cudnn 12 | from torch.utils.data import DataLoader 13 | from reid import datasets 14 | from reid import models 15 | from reid.dist_metric import DistanceMetric 16 | from reid.loss import PosetLoss_G2G 17 | from reid.trainers import Trainer 18 | from reid.evaluators_image import Evaluator 19 | from reid.utils.data import transforms as T 20 | from reid.utils.data.preprocessor_image import Preprocessor_Image 21 | from reid.utils.data.sampler_mt import RandomIdentitySampler 22 | from reid.utils.logging import Logger 23 | from reid.utils.serialization import load_checkpoint, save_checkpoint 24 | import ipdb 25 | 26 | 27 | def flatten_dataset(dataset): 28 | new_dataset = [] 29 | for tracklet in dataset: 30 | img_names, tid, pid, tid_pc, pid_pc, camid = tracklet 31 | for img_name in img_names: 32 | new_dataset.append((img_name, tid, pid, tid_pc, pid_pc, camid)) 33 | return new_dataset 34 | 35 | 36 | def get_data(name, split_id, data_dir, height, width, batch_size, num_instances, workers, combine_trainval): 37 | root = osp.join(data_dir, name) 38 | dataset = datasets.create(name, root, split_id=split_id) 39 | normalizer = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 40 | 41 | # load train_set, query_set, gallery_set 42 | mt_train_set = dataset.train 43 | mt_num_classes = dataset.num_train_tids_sub 44 | query_set = dataset.query 45 | gallery_set = dataset.gallery 46 | 47 | train_transformer = T.Compose([ 48 | T.RandomSizedRectCrop(height, width), 49 | T.RandomHorizontalFlip(), 50 | T.ToTensor(), 51 | normalizer 52 | ]) 53 | 54 | test_transformer = T.Compose([ 55 | T.RectScale(height, width), 56 | T.ToTensor(), 57 | normalizer, 58 | ]) 59 | 60 | # Random ID 61 | mt_train_set = flatten_dataset(mt_train_set) 62 | num_task = len(mt_num_classes) # num_task equals camera number, each camera is a task 63 | mt_train_loader = DataLoader( 64 | Preprocessor_Image(mt_train_set, root=dataset.dataset_dir, transform=train_transformer), 65 | batch_size=batch_size, num_workers=workers, 66 | sampler=RandomIdentitySampler(mt_train_set, num_instances, num_task), # Here is different between softmax_loss 67 | pin_memory=True, drop_last=True) 68 | 69 | query_set = flatten_dataset(query_set) 70 | gallery_set = flatten_dataset(gallery_set) 71 | test_set = list(set(query_set) | set(gallery_set)) 72 | test_loader = DataLoader( 73 | Preprocessor_Image(test_set, root=dataset.dataset_dir, transform=test_transformer), 74 | batch_size=batch_size, num_workers=workers, 75 | shuffle=False, pin_memory=True) 76 | 77 | return mt_train_loader, mt_num_classes, test_loader, query_set, gallery_set 78 | 79 | 80 | def main(args): 81 | np.random.seed(args.seed) 82 | torch.manual_seed(args.seed) 83 | cudnn.benchmark = True 84 | 85 | start = time.time() 86 | 87 | # Redirect print to both console and log file 88 | if not args.evaluate: 89 | dt = datetime.datetime.now() 90 | sys.stdout = Logger(osp.join(args.logs_dir, 'log_' 91 | + str(dt.month).zfill(2) 92 | + str(dt.day).zfill(2) 93 | + str(dt.hour).zfill(2) 94 | + str(dt.minute).zfill(2) + '.txt')) 95 | 96 | # Create data loaders 97 | assert args.num_instances > 1, "num_instances should be greater than 1" 98 | assert args.batch_size % args.num_instances == 0, \ 99 | 'num_instances should divide batch_size' 100 | if args.height is None or args.width is None: 101 | args.height, args.width = (144, 56) if args.arch == 'inception' else \ 102 | (256, 128) 103 | mt_train_loader, mt_num_classes, test_loader, query_set, gallery_set = \ 104 | 
get_data(args.dataset, args.split, args.data_dir, args.height, 105 | args.width, args.batch_size, args.num_instances, args.workers, 106 | args.combine_trainval) 107 | 108 | # Create model 109 | # Hacking here to let the classifier be the last feature embedding layer 110 | # Net structure: avgpool -> FC(1024) -> FC(args.features) 111 | model = models.create(args.arch, num_features=args.features, 112 | dropout=args.dropout, num_classes=mt_num_classes, double_loss=True) 113 | model = nn.DataParallel(model).cuda() 114 | 115 | # Distance metric 116 | metric = DistanceMetric(algorithm=args.dist_metric) 117 | 118 | # Evaluator 119 | evaluator = Evaluator(model) 120 | 121 | # Criterion 122 | criterion_1 = nn.CrossEntropyLoss().cuda() 123 | criterion_2 = PosetLoss_G2G(margin=args.margin).cuda() 124 | 125 | # Optimizer 126 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, 127 | weight_decay=args.weight_decay) 128 | 129 | # Trainer 130 | num_task = len(mt_num_classes) # num_task equals camera number, each camera is a task 131 | trainer = Trainer(model, criterion_1, criterion_2, num_task) 132 | 133 | # Schedule learning rate 134 | def adjust_lr(epoch): 135 | lr = args.lr if epoch <= 100 else \ 136 | args.lr * (0.001 ** ((epoch - 100) / 50.0)) 137 | for g in optimizer.param_groups: 138 | g['lr'] = lr * g.get('lr_mult', 1) 139 | 140 | # Start training 141 | start_epoch = best_top1 = 0 142 | for epoch in range(start_epoch, args.epochs): 143 | adjust_lr(epoch) 144 | trainer.train(epoch, mt_train_loader, optimizer) 145 | if (epoch % args.start_save == (args.start_save - 1)): 146 | save_checkpoint({ 147 | 'state_dict': model.module.state_dict(), 148 | 'epoch': epoch + 1, 149 | 'best_top1': best_top1, 150 | }, 0, fpath=osp.join(args.logs_dir, 'checkpoint.pth.tar')) 151 | 152 | # Final test 153 | print('Test with the model after epoch {:d}:'.format(epoch + 1)) 154 | checkpoint = load_checkpoint(osp.join(args.logs_dir, 'checkpoint.pth.tar')) 155 | model.module.load_state_dict(checkpoint['state_dict']) 156 | metric.train(model, mt_train_loader) 157 | evaluator.evaluate(test_loader, query_set, gallery_set, metric) 158 | end = time.time() 159 | print('Total time: {:.1f}s'.format(end-start)) 160 | 161 | 162 | 163 | if __name__ == '__main__': 164 | parser = argparse.ArgumentParser(description="Triplet loss classification") 165 | # data 166 | parser.add_argument('-d', '--dataset', type=str, default='cuhk03', 167 | choices=datasets.names()) 168 | parser.add_argument('-b', '--batch-size', type=int, default=256) 169 | parser.add_argument('-j', '--workers', type=int, default=4) 170 | parser.add_argument('--split', type=int, default=0) 171 | parser.add_argument('--height', type=int, 172 | help="input height, default: 256 for resnet*, " 173 | "144 for inception") 174 | parser.add_argument('--width', type=int, 175 | help="input width, default: 128 for resnet*, " 176 | "56 for inception") 177 | parser.add_argument('--combine-trainval', action='store_true', 178 | help="train and val sets together for training, " 179 | "val set alone for validation", default=True) 180 | parser.add_argument('--num-instances', type=int, default=4, 181 | help="each minibatch consist of " 182 | "(batch_size // num_instances) identities, and " 183 | "each identity has num_instances instances, " 184 | "default: 4") 185 | # model 186 | parser.add_argument('-a', '--arch', type=str, default='resnet50', 187 | choices=models.names()) 188 | parser.add_argument('--features', type=int, default=128) 189 | parser.add_argument('--dropout', 
type=float, default=0) 190 | # loss 191 | parser.add_argument('--margin', type=float, default=0.1, 192 | help="margin of the triplet loss, default: 0.5") 193 | # optimizer 194 | parser.add_argument('--lr', type=float, default=0.00035, 195 | help="learning rate of all parameters") 196 | parser.add_argument('--weight-decay', type=float, default=5e-4) #0.1? 197 | # training configs 198 | parser.add_argument('--resume', type=str, default='', metavar='PATH') 199 | parser.add_argument('--evaluate', action='store_true', 200 | help="evaluation only") 201 | parser.add_argument('--epochs', type=int, default=150) 202 | parser.add_argument('--start_save', type=int, default=0, 203 | help="start saving checkpoints after specific epoch") 204 | parser.add_argument('--seed', type=int, default=1) 205 | parser.add_argument('--print-freq', type=int, default=1) 206 | # metric learning 207 | parser.add_argument('--dist-metric', type=str, default='euclidean', 208 | choices=['euclidean', 'kissme']) 209 | # misc 210 | working_dir = osp.dirname(osp.abspath(__file__)) 211 | parser.add_argument('--data-dir', type=str, metavar='PATH', 212 | default=osp.join(working_dir, 'data')) 213 | parser.add_argument('--logs-dir', type=str, metavar='PATH', 214 | default=osp.join(working_dir, 'logs')) 215 | 216 | main(parser.parse_args()) 217 | -------------------------------------------------------------------------------- /taudl_video.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import argparse 3 | import os.path as osp 4 | 5 | import numpy as np 6 | import sys 7 | import torch 8 | import time 9 | import datetime 10 | from torch import nn 11 | from torch.backends import cudnn 12 | from torch.utils.data import DataLoader 13 | from reid import datasets 14 | from reid import models 15 | from reid.dist_metric import DistanceMetric 16 | from reid.loss import PosetLoss_G2G 17 | from reid.trainers import Trainer 18 | from reid.evaluators_video import Evaluator 19 | from reid.utils.data import transforms as T 20 | from reid.utils.data.preprocessor_image import Preprocessor_Image 21 | from reid.utils.data.preprocessor_video import Preprocessor_Video 22 | from reid.utils.data.sampler_mt import RandomIdentitySampler 23 | from reid.utils.logging import Logger 24 | from reid.utils.serialization import load_checkpoint, save_checkpoint 25 | import ipdb 26 | 27 | 28 | def flatten_dataset(dataset): 29 | new_dataset = [] 30 | for tracklet in dataset: 31 | img_names, tid, pid, tid_pc, pid_pc, camid = tracklet 32 | for img_name in img_names: 33 | new_dataset.append((img_name, tid, pid, tid_pc, pid_pc, camid)) 34 | return new_dataset 35 | 36 | 37 | def get_data(name, split_id, data_dir, height, width, batch_size, num_instances, workers, combine_trainval): 38 | root = osp.join(data_dir, name) 39 | dataset = datasets.create(name, root, split_id=split_id) 40 | normalizer = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 41 | 42 | # load train_set, query_set, gallery_set 43 | mt_train_set = dataset.train 44 | mt_num_classes = dataset.num_train_tids_sub 45 | query_set = dataset.query 46 | gallery_set = dataset.gallery 47 | 48 | train_transformer = T.Compose([ 49 | T.RandomSizedRectCrop(height, width), 50 | T.RandomHorizontalFlip(), 51 | T.ToTensor(), 52 | normalizer 53 | ]) 54 | 55 | test_transformer = T.Compose([ 56 | T.RectScale(height, width), 57 | T.ToTensor(), 58 | normalizer, 59 | ]) 60 | 61 | # Random ID 62 | mt_train_set = 


def get_data(name, split_id, data_dir, height, width, batch_size,
             num_instances, workers, combine_trainval):
    root = osp.join(data_dir, name)
    dataset = datasets.create(name, root, split_id=split_id)
    normalizer = T.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])

    # Load train_set, query_set, gallery_set
    mt_train_set = dataset.train
    mt_num_classes = dataset.num_train_tids_sub
    query_set = dataset.query
    gallery_set = dataset.gallery

    train_transformer = T.Compose([
        T.RandomSizedRectCrop(height, width),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        normalizer,
    ])

    test_transformer = T.Compose([
        T.RectScale(height, width),
        T.ToTensor(),
        normalizer,
    ])

    # Training: flatten tracklets into frames and sample random identities
    # per camera (num_task equals the number of cameras; each camera is one
    # task)
    mt_train_set = flatten_dataset(mt_train_set)
    num_task = len(mt_num_classes)
    mt_train_loader = DataLoader(
        Preprocessor_Image(mt_train_set, root=dataset.dataset_dir,
                           transform=train_transformer),
        batch_size=batch_size, num_workers=workers,
        # The identity sampler, rather than plain shuffling, is what
        # distinguishes this loader from the softmax-loss setup
        sampler=RandomIdentitySampler(mt_train_set, num_instances, num_task),
        pin_memory=True, drop_last=True)

    # Normalize query/gallery entries to hashable tuples so they can be
    # deduplicated into a single test set
    query_set_new = []
    for img_paths, tid, pid, tid_sub, pid_sub, camid in query_set:
        query_set_new.append((tuple(img_paths), tid, pid, tid_sub, pid_sub, camid))
    gallery_set_new = []
    for img_paths, tid, pid, tid_sub, pid_sub, camid in gallery_set:
        gallery_set_new.append((tuple(img_paths), tid, pid, tid_sub, pid_sub, camid))
    test_set = list(set(query_set_new) | set(gallery_set_new))
    seq_len = 16
    test_loader = DataLoader(
        Preprocessor_Video(test_set, transform=test_transformer,
                           seq_len=seq_len, sample='random'),
        batch_size=batch_size // seq_len, num_workers=workers,
        shuffle=False, pin_memory=True)

    return mt_train_loader, mt_num_classes, test_loader, query_set, gallery_set
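

# Batch bookkeeping for the two loaders above (default settings assumed):
# with batch_size=256 and num_instances=4, the train loader draws
# 256 // 4 = 64 tracklet identities per batch, 4 frames each; the test loader
# feeds whole tracklets of seq_len=16 frames, so it carries
# 256 // 16 = 16 tracklets per batch.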


def main(args):
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    cudnn.benchmark = True

    start = time.time()

    # Redirect print to both console and log file
    if not args.evaluate:
        timestamp = datetime.datetime.now().strftime('%m%d%H%M')
        sys.stdout = Logger(osp.join(args.logs_dir, 'log_' + timestamp + '.txt'))

    # Create data loaders
    assert args.num_instances > 1, "num_instances should be greater than 1"
    assert args.batch_size % args.num_instances == 0, \
        'num_instances should divide batch_size'
    if args.height is None or args.width is None:
        args.height, args.width = (144, 56) if args.arch == 'inception' else \
                                  (256, 128)
    mt_train_loader, mt_num_classes, test_loader, query_set, gallery_set = \
        get_data(args.dataset, args.split, args.data_dir, args.height,
                 args.width, args.batch_size, args.num_instances, args.workers,
                 args.combine_trainval)

    # Create model
    # Hack: use the classifier head as the final feature embedding layer.
    # Net structure: avgpool -> FC(1024) -> FC(args.features)
    model = models.create(args.arch, num_features=args.features,
                          dropout=args.dropout, num_classes=mt_num_classes,
                          double_loss=True)
    model = nn.DataParallel(model).cuda()

    # Distance metric
    metric = DistanceMetric(algorithm=args.dist_metric)

    # Evaluator
    evaluator = Evaluator(model)

    # Criterion: per-camera softmax loss plus the poset association loss
    criterion_1 = nn.CrossEntropyLoss().cuda()
    criterion_2 = PosetLoss_G2G(margin=args.margin).cuda()

    # Optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr,
                                 weight_decay=args.weight_decay)

    # Trainer: one classification task per camera
    num_task = len(mt_num_classes)
    trainer = Trainer(model, criterion_1, criterion_2, num_task)

    # Schedule learning rate: constant for the first 100 epochs, then an
    # exponential decay of 0.001 per 50 epochs
    def adjust_lr(epoch):
        lr = args.lr if epoch <= 100 else \
            args.lr * (0.001 ** ((epoch - 100) / 50.0))
        for g in optimizer.param_groups:
            g['lr'] = lr * g.get('lr_mult', 1)

    # Start training
    start_epoch = best_top1 = 0
    for epoch in range(start_epoch, args.epochs):
        adjust_lr(epoch)
        trainer.train(epoch, mt_train_loader, optimizer)
        # Save a checkpoint every `start_save` epochs (0 = every epoch)
        if args.start_save == 0 or (epoch + 1) % args.start_save == 0:
            save_checkpoint({
                'state_dict': model.module.state_dict(),
                'epoch': epoch + 1,
                'best_top1': best_top1,
            }, 0, fpath=osp.join(args.logs_dir, 'checkpoint.pth.tar'))

    # Final test
    print('Test with the model after epoch {:d}:'.format(epoch + 1))
    checkpoint = load_checkpoint(osp.join(args.logs_dir, 'checkpoint.pth.tar'))
    model.module.load_state_dict(checkpoint['state_dict'])
    metric.train(model, mt_train_loader)
    evaluator.evaluate(test_loader, query_set, gallery_set, metric)
    end = time.time()
    print('Total time: {:.1f}s'.format(end - start))
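

# Checkpoint cadence for the loop above (a hypothetical setting): running
# with --start_save 50 and --epochs 150 writes checkpoint.pth.tar after
# epochs 50, 100 and 150, while the default of 0 saves after every epoch.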


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="TAUDL video-based training")
    # data
    parser.add_argument('-d', '--dataset', type=str, default='cuhk03',
                        choices=datasets.names())
    parser.add_argument('-b', '--batch-size', type=int, default=256)
    parser.add_argument('-j', '--workers', type=int, default=4)
    parser.add_argument('--split', type=int, default=0)
    parser.add_argument('--height', type=int,
                        help="input height, default: 256 for resnet*, "
                             "144 for inception")
    parser.add_argument('--width', type=int,
                        help="input width, default: 128 for resnet*, "
                             "56 for inception")
    parser.add_argument('--combine-trainval', action='store_true', default=True,
                        help="train and val sets together for training, "
                             "val set alone for validation")
    parser.add_argument('--num-instances', type=int, default=4,
                        help="each minibatch consists of "
                             "(batch_size // num_instances) identities, and "
                             "each identity has num_instances instances, "
                             "default: 4")
    # model
    parser.add_argument('-a', '--arch', type=str, default='resnet50',
                        choices=models.names())
    parser.add_argument('--features', type=int, default=128)
    parser.add_argument('--dropout', type=float, default=0)
    # loss
    parser.add_argument('--margin', type=float, default=0.1,
                        help="margin of the poset loss, default: 0.1")
    # optimizer
    parser.add_argument('--lr', type=float, default=0.00035,
                        help="learning rate of all parameters")
    parser.add_argument('--weight-decay', type=float, default=5e-4)
    # training configs
    parser.add_argument('--resume', type=str, default='', metavar='PATH')
    parser.add_argument('--evaluate', action='store_true',
                        help="evaluation only")
    parser.add_argument('--epochs', type=int, default=150)
    parser.add_argument('--start_save', type=int, default=0,
                        help="save a checkpoint every this many epochs "
                             "(0 = every epoch)")
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--print-freq', type=int, default=1)
    # metric learning
    parser.add_argument('--dist-metric', type=str, default='euclidean',
                        choices=['euclidean', 'kissme'])
    # misc
    working_dir = osp.dirname(osp.abspath(__file__))
    parser.add_argument('--data-dir', type=str, metavar='PATH',
                        default=osp.join(working_dir, 'data'))
    parser.add_argument('--logs-dir', type=str, metavar='PATH',
                        default=osp.join(working_dir, 'logs'))

    main(parser.parse_args())
--------------------------------------------------------------------------------