├── GaitReID_CVPR2022-main (1).zip
├── README.md
├── args.py
├── requirements.txt
├── torchreid
├── __init__.py
├── data_manager.py
├── dataset_loader.py
├── datasets
│   ├── __init__.py
│   ├── bases.py
│   ├── casiab_video.py
│   ├── casiab_video_sub.py
│   ├── ltcc.py
│   ├── prcc.py
│   ├── real28.py
│   └── vc_clothe.py
├── eval_cylib
│   ├── Makefile
│   ├── eval_metrics_cy.pyx
│   ├── setup.py
│   └── test_cython.py
├── eval_metrics.py
├── losses
│   ├── __init__.py
│   ├── center_loss.py
│   ├── cross_entropy_loss.py
│   └── hard_mine_triplet_loss.py
├── models
│   ├── __init__.py
│   └── resnet.py
├── optimizers.py
├── samplers.py
├── transforms.py
└── utils
│   ├── avgmeter.py
│   ├── iotools.py
│   ├── loggers.py
│   ├── reidtools.py
│   └── torchtools.py
└── train_Baseline.py
/GaitReID_CVPR2022-main (1).zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jinx-USTC/GI-ReID/a5179eaf6cb2b6cc575e2a45aa0c81a6330943ee/GaitReID_CVPR2022-main (1).zip
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # GI-ReID
2 | Code release for "Cloth-Changing Person Re-identification from A Single Image with Gait Prediction and Regularization". More files are coming soon.
3 | 
4 | 
5 | Sorry for the delay. Due to an unexpected "Reject" from the reviewers, we currently provide only the materials for the Baseline (including the dataloader files for multiple cloth-changing ReID datasets). The remaining parts on gait prediction and regularization will be released upon acceptance.
6 | 
7 | # Training and eval settings of the Baseline w.r.t. the different cloth-changing datasets
8 | 
9 | ## LTCC standard eval setting (all training setting):
10 | nohup python train_imgreid_xent_htri.py --gpu-devices 1 --root /data1/Datasets_jinx/Cloth-ReID/reid_datasets/ -s ltcc -t ltcc -a resnet50_fc512 --height 256 --width 128 --max-epoch 60 --stepsize 20 40 --lr 0.0003 --train-batch-size 32 --test-batch-size 100 --save-dir ../debug_dir/test_ltcc3 --train-with-all-cloth --use-standard-metric --eval-freq 10 > ./debug3.log 2>&1 &
11 | 
12 | ## LTCC cloth-changing eval setting (all training setting):
13 | ### Note that if --use-cloth-changing-metric is used, the evaluation code must additionally remove gallery entries matching (g_pids[order] == q_pid) & ((g_cloids[order] == q_cloid) | (g_camids[order] == q_camid)); see the sketch right after this README:
14 | nohup python train_imgreid_xent_htri.py --gpu-devices 1 --root /data1/Datasets_jinx/Cloth-ReID/reid_datasets/ -s ltcc -t ltcc -a resnet50_fc512 --height 256 --width 128 --max-epoch 60 --stepsize 20 40 --lr 0.0003 --train-batch-size 32 --test-batch-size 100 --save-dir ../debug_dir/test_ltcc3 --train-with-all-cloth --use-cloth-changing-metric --eval-freq 10 > ./debug5.log 2>&1 &
15 | 
16 | ## PRCC same-cloth eval setting:
17 | ### The cuhk03 eval protocol must be used, and nothing is discarded:
18 | nohup python train_imgreid_xent_htri.py --gpu-devices 1 --root /data1/Datasets_jinx/Cloth-ReID/reid_datasets/ -s prcc -t prcc -a resnet50_fc512 --height 256 --width 128 --max-epoch 60 --stepsize 20 40 --lr 0.0003 --train-batch-size 32 --test-batch-size 100 --save-dir ../debug_dir/test_prcc --same-clothes --just-for-prcc-test --use-metric-cuhk03 --eval-freq 10 > ./debug8_prcc.log 2>&1 &
19 | 
20 | ## PRCC cloth-changing eval setting:
21 | nohup python train_imgreid_xent_htri.py --gpu-devices 1 --root /data1/Datasets_jinx/Cloth-ReID/reid_datasets/ -s prcc -t prcc -a resnet50_fc512 --height 256 --width 128 --max-epoch 60 --stepsize 20 40 --lr 0.0003 --train-batch-size 32 --test-batch-size 100 --save-dir ../debug_dir/test_prcc --cross-clothes --just-for-prcc-test --use-metric-cuhk03 --eval-freq 10 > ./debug9_prcc.log 2>&1 &
22 | 
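A minimal sketch of where that cloth-changing removal mask sits in a market1501-style CMC/mAP loop (the array names follow the condition above; the actual implementation lives in torchreid/eval_metrics.py and may differ in details):

```python
import numpy as np

def keep_mask_cloth_changing(order, q_pid, q_camid, q_cloid,
                             g_pids, g_camids, g_cloids):
    """LTCC cloth-changing protocol: for each query, discard gallery samples
    that share the identity AND (the clothes OR the camera) with the query."""
    remove = (g_pids[order] == q_pid) & (
        (g_cloids[order] == q_cloid) | (g_camids[order] == q_camid))
    return np.invert(remove)  # True = keep this ranked gallery entry
```

Gallery entries where the mask is False are dropped before computing the matches for that query, analogous to the standard same-id/same-camera removal in the standard-metric setting.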
--------------------------------------------------------------------------------
/args.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | 
3 | 
4 | def argument_parser():
5 |     parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
6 | 
7 |     # ************************************************************
8 |     # Datasets (general)
9 |     # ************************************************************
10 |     parser.add_argument('--root', type=str, default='data',
11 |                         help="root path to data directory")
12 |     parser.add_argument('-s', '--source-names', type=str, required=True, nargs='+',
13 |                         help="source datasets (delimited by space)")
14 |     parser.add_argument('-t', '--target-names', type=str, required=True, nargs='+',
15 |                         help="target datasets (delimited by space)")
16 |     parser.add_argument('-j', '--workers', default=4, type=int,
17 |                         help="number of data loading workers (tip: 4 or 8 times the number of gpus)")
18 |     parser.add_argument('--height', type=int, default=256,
19 |                         help="height of an image")
20 |     parser.add_argument('--width', type=int, default=128,
21 |                         help="width of an image")
22 |     parser.add_argument('--split-id', type=int, default=0,
23 |                         help="split index (note: 0-based)")
24 |     parser.add_argument('--train-sampler', type=str, default='',
25 |                         help="sampler for trainloader")
26 | 
27 |     # ************************************************************
28 |     # Video datasets
29 |     # ************************************************************
30 |     parser.add_argument('--seq-len', type=int, default=8,
31 |                         help="number of images to sample in a tracklet")
32 |     parser.add_argument('--sample-method', type=str, default='evenly',
33 |                         help="how to sample images from a tracklet")
34 |     parser.add_argument('--pool-tracklet-features', type=str, default='avg', choices=['avg', 'max'],
35 |                         help="how to pool features over a tracklet (for video reid)")
36 |     # ************************************************************
37 |     # Gait related
38 |     # ************************************************************
39 |     parser.add_argument('--middle-idx', type=int, default=3,
40 |                         help="which image to use in a gait cycle")
41 |     parser.add_argument('--loss-GaitSet', type=float, default=1.0,
42 |                         help="for loss balance")
43 |     parser.add_argument('--loss-GCP', type=float, default=1.0,
44 |                         help="for loss balance")
45 |     parser.add_argument('--lr_GCP', default=0.0005, type=float,
46 |                         help="initial learning rate")
47 |     parser.add_argument('--lr_GaitSet', default=0.0005, type=float,
48 |                         help="initial learning rate")
49 |     parser.add_argument('--stepsize_GCP', default=[40], nargs='+', type=int,
50 |                         help="stepsize to decay learning rate")
51 |     parser.add_argument('--stepsize_GaitSet', default=[120], nargs='+', type=int,
52 |                         help="stepsize to decay learning rate")
53 |     parser.add_argument('--epochs-only-for-GCP', type=int, default=80,
54 |                         help="total epochs just for training GCP")
55 |     parser.add_argument('--load-weights-GCP', type=str, default='',
56 |                         help="load GCP pretrained weights but ignore layers that don't match in size")
57 |     parser.add_argument('--load-weights-GaitSet', type=str, default='',
58 |                         help="load GaitSet pretrained weights but ignore layers that don't match in size")
59 |     parser.add_argument('--reid-dim', type=int, default=256,
60 |                         help="dims of the final ReID vector")
help="dims of final ReID vector") 61 | parser.add_argument('--gait-dim', type=int, default=256, 62 | help="dims of final final Gait vector") 63 | # total loss for GaitReID: 64 | parser.add_argument('--loss-ReID-cla-local', type=float, default=1.0, 65 | help="for loss balance") 66 | parser.add_argument('--loss-ReID-tri-local', type=float, default=1.0, 67 | help="for loss balance") 68 | parser.add_argument('--loss-ReID-cla-global', type=float, default=1.0, 69 | help="for loss balance") 70 | parser.add_argument('--loss-ReID-tri-global', type=float, default=1.0, 71 | help="for loss balance") 72 | parser.add_argument('--loss-CFL', type=float, default=1.0, 73 | help="for loss balance") 74 | parser.add_argument('--save-vis-gait-dir', type=str, default='log', 75 | help="path to save gait prediction results") 76 | parser.add_argument('--mask-height', type=int, default=64, 77 | help="height of used mask") 78 | parser.add_argument('--mask-width', type=int, default=64, 79 | help="width of used mask") 80 | 81 | # ************************************************************ 82 | # For Cloth-reid datasets visualization (gait predictions and feature maps) 83 | # ************************************************************ 84 | 85 | parser.add_argument('--save-visualization-results', type=str, default='log', 86 | help="path to save gait prediction results, in the paper") 87 | parser.add_argument('--vis-feat', action='store_true', 88 | help="for feature visualization") 89 | 90 | # ************************************************************ 91 | # Cloth-reid datasets related 92 | # ************************************************************ 93 | 94 | # LTCC datasets related, must choose one of two 95 | parser.add_argument('--train-with-all-cloth', action='store_true', 96 | help="use both of cloth-changing and cloth-consistent person images for training") 97 | parser.add_argument('--train-with-only-cloth-changing', action='store_true', 98 | help="use only cloth-changing person images for training") 99 | parser.add_argument('--use-standard-metric', action='store_true', 100 | help="calculate mAP/Rank-1, didn't discard the images with the same cloth") 101 | parser.add_argument('--use-cloth-changing-metric', action='store_true', 102 | help="calculate mAP/Rank-1, discard the images with the same cloth") 103 | 104 | # LTCC , concat mask for ablation: 105 | parser.add_argument('--concat-mask', action='store_true', 106 | help="concat mask for ablation in supp") 107 | 108 | 109 | # PRCC datasets related, must choose one of two 110 | parser.add_argument('--cross-clothes', action='store_true', 111 | help="the person matching between Camera views A and C was cross-clothes matching") 112 | parser.add_argument('--same-clothes', action='store_true', 113 | help="the person matching between Camera views A and B was performed without clothing changes") 114 | parser.add_argument('--just-for-prcc-test', action='store_true', 115 | help="didn't discard the images with the same cloth") 116 | 117 | 118 | # ************************************************************ 119 | # CUHK03-specific setting 120 | # ************************************************************ 121 | parser.add_argument('--cuhk03-labeled', action='store_true', 122 | help="use labeled images, if false, use detected images") 123 | parser.add_argument('--cuhk03-classic-split', action='store_true', 124 | help="use classic split by Li et al. 
CVPR'14") 125 | parser.add_argument('--use-metric-cuhk03', action='store_true', 126 | help="use cuhk03's metric for evaluation") 127 | 128 | # ************************************************************ 129 | # Optimization options 130 | # ************************************************************ 131 | parser.add_argument('--optim', type=str, default='adam', 132 | help="optimization algorithm (see optimizers.py)") 133 | parser.add_argument('--lr', default=0.0003, type=float, 134 | help="initial learning rate") 135 | parser.add_argument('--weight-decay', default=5e-04, type=float, 136 | help="weight decay") 137 | # sgd 138 | parser.add_argument('--momentum', default=0.9, type=float, 139 | help="momentum factor for sgd and rmsprop") 140 | parser.add_argument('--sgd-dampening', default=0, type=float, 141 | help="sgd's dampening for momentum") 142 | parser.add_argument('--sgd-nesterov', action='store_true', 143 | help="whether to enable sgd's Nesterov momentum") 144 | # rmsprop 145 | parser.add_argument('--rmsprop-alpha', default=0.99, type=float, 146 | help="rmsprop's smoothing constant") 147 | # adam/amsgrad 148 | parser.add_argument('--adam-beta1', default=0.9, type=float, 149 | help="exponential decay rate for adam's first moment") 150 | parser.add_argument('--adam-beta2', default=0.999, type=float, 151 | help="exponential decay rate for adam's second moment") 152 | 153 | # ************************************************************ 154 | # Training hyperparameters 155 | # ************************************************************ 156 | parser.add_argument('--max-epoch', default=60, type=int, 157 | help="maximum epochs to run") 158 | parser.add_argument('--start-epoch', default=0, type=int, 159 | help="manual epoch number (useful when restart)") 160 | parser.add_argument('--stepsize', default=[20, 40], nargs='+', type=int, 161 | help="stepsize to decay learning rate") 162 | parser.add_argument('--gamma', default=0.1, type=float, 163 | help="learning rate decay") 164 | 165 | parser.add_argument('--train-batch-size', default=8*4*8, type=int, 166 | help="training batch size") 167 | parser.add_argument('--test-batch-size', default=100, type=int, 168 | help="test batch size") 169 | 170 | parser.add_argument('--fixbase', action='store_true', 171 | help="always fix base network") 172 | parser.add_argument('--fixbase-epoch', type=int, default=0, 173 | help="how many epochs to fix base network (only train randomly initialized classifier)") 174 | parser.add_argument('--open-layers', type=str, nargs='+', default=['classifier'], 175 | help="open specified layers for training while keeping others frozen") 176 | 177 | # ************************************************************ 178 | # Cross entropy loss-specific setting 179 | # ************************************************************ 180 | parser.add_argument('--label-smooth', action='store_true', 181 | help="use label smoothing regularizer in cross entropy loss") 182 | 183 | # ************************************************************ 184 | # Hard triplet loss-specific setting 185 | # ************************************************************ 186 | parser.add_argument('--margin', type=float, default=0.3, 187 | help="margin for triplet loss") 188 | parser.add_argument('--num-instances', type=int, default=4, 189 | help="number of instances per identity") 190 | parser.add_argument('--htri-only', action='store_true', 191 | help="only use hard triplet loss") 192 | parser.add_argument('--lambda-xent', type=float, default=1, 193 | 
help="weight to balance cross entropy loss") 194 | parser.add_argument('--lambda-htri', type=float, default=1, 195 | help="weight to balance hard triplet loss") 196 | 197 | # ************************************************************ 198 | # Architecture 199 | # ************************************************************ 200 | parser.add_argument('-a', '--arch', type=str, default='resnet50') 201 | 202 | # ************************************************************ 203 | # Test settings 204 | # ************************************************************ 205 | parser.add_argument('--load-weights', type=str, default='', 206 | help="load pretrained weights but ignore layers that don't match in size") 207 | parser.add_argument('--evaluate', action='store_true', 208 | help="evaluate only") 209 | parser.add_argument('--eval-freq', type=int, default=-1, 210 | help="evaluation frequency (set to -1 to test only in the end)") 211 | parser.add_argument('--start-eval', type=int, default=0, 212 | help="start to evaluate after a specific epoch") 213 | 214 | # ************************************************************ 215 | # Miscs 216 | # ************************************************************ 217 | parser.add_argument('--print-freq', type=int, default=2, 218 | help="print frequency") 219 | parser.add_argument('--seed', type=int, default=1, 220 | help="manual seed") 221 | parser.add_argument('--resume', type=str, default='', metavar='PATH', 222 | help="resume from a checkpoint") 223 | parser.add_argument('--save-dir', type=str, default='log', 224 | help="path to save log and model weights") 225 | parser.add_argument('--use-cpu', action='store_true', 226 | help="use cpu") 227 | parser.add_argument('--gpu-devices', default='0', type=str, 228 | help='gpu device ids for CUDA_VISIBLE_DEVICES') 229 | parser.add_argument('--use-avai-gpus', action='store_true', 230 | help="use available gpus instead of specified devices (useful when using managed clusters)") 231 | parser.add_argument('--visualize-ranks', action='store_true', 232 | help="visualize ranked results, only available in evaluation mode") 233 | 234 | return parser 235 | 236 | 237 | def image_dataset_kwargs(parsed_args): 238 | """ 239 | Build kwargs for ImageDataManager in data_manager.py from 240 | the parsed command-line arguments. 241 | """ 242 | return { 243 | 'source_names': parsed_args.source_names, 244 | 'target_names': parsed_args.target_names, 245 | 'root': parsed_args.root, 246 | 'split_id': parsed_args.split_id, 247 | 'height': parsed_args.height, 248 | 'width': parsed_args.width, 249 | 'train_batch_size': parsed_args.train_batch_size, 250 | 'test_batch_size': parsed_args.test_batch_size, 251 | 'num_instances': parsed_args.num_instances, 252 | 'workers': parsed_args.workers, 253 | 'train_sampler': parsed_args.train_sampler, 254 | 'cuhk03_labeled': parsed_args.cuhk03_labeled, 255 | 'cuhk03_classic_split': parsed_args.cuhk03_classic_split 256 | } 257 | 258 | 259 | def video_dataset_kwargs(parsed_args): 260 | """ 261 | Build kwargs for VideoDataManager in data_manager.py from 262 | the parsed command-line arguments. 
263 | """ 264 | return { 265 | 'source_names': parsed_args.source_names, 266 | 'target_names': parsed_args.target_names, 267 | 'root': parsed_args.root, 268 | 'split_id': parsed_args.split_id, 269 | 'height': parsed_args.height, 270 | 'width': parsed_args.width, 271 | 'train_batch_size': parsed_args.train_batch_size, 272 | 'test_batch_size': parsed_args.test_batch_size, 273 | 'num_instances': parsed_args.num_instances, 274 | 'workers': parsed_args.workers, 275 | 'seq_len': parsed_args.seq_len, 276 | 'sample_method': parsed_args.sample_method 277 | } 278 | 279 | 280 | def optimizer_kwargs(parsed_args): 281 | """ 282 | Build kwargs for optimizer in optimizer.py from 283 | the parsed command-line arguments. 284 | """ 285 | return { 286 | 'optim': parsed_args.optim, 287 | 'lr': parsed_args.lr, 288 | 'weight_decay': parsed_args.weight_decay, 289 | 'momentum': parsed_args.momentum, 290 | 'sgd_dampening': parsed_args.sgd_dampening, 291 | 'sgd_nesterov': parsed_args.sgd_nesterov, 292 | 'rmsprop_alpha': parsed_args.rmsprop_alpha, 293 | 'adam_beta1': parsed_args.adam_beta1, 294 | 'adam_beta2': parsed_args.adam_beta2 295 | } -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | Cython 2 | h5py 3 | numpy 4 | Pillow 5 | scipy>=1.0.0 6 | torch>=1.0.1 7 | torchvision>=0.4.1 -------------------------------------------------------------------------------- /torchreid/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | deep-person-reid 3 | == 4 | 5 | Description: PyTorch implementation of deep person re-identification models. 6 | 7 | Github page: https://github.com/KaiyangZhou/deep-person-reid 8 | """ 9 | 10 | __author__ = 'Kaiyang Zhou' 11 | __email__ = 'k.zhou@qmul.ac.uk' -------------------------------------------------------------------------------- /torchreid/data_manager.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | 4 | from torch.utils.data import DataLoader 5 | 6 | from .dataset_loader import ImageDataset, VideoDataset 7 | from .datasets import init_imgreid_dataset, init_vidreid_dataset 8 | from .transforms import build_transforms, build_transforms_RGB 9 | from .samplers import RandomIdentitySampler 10 | 11 | 12 | class BaseDataManager(object): 13 | 14 | @property 15 | def num_train_pids(self): 16 | return self._num_train_pids 17 | 18 | @property 19 | def num_train_cams(self): 20 | return self._num_train_cams 21 | 22 | def return_dataloaders(self): 23 | """ 24 | Return trainloader and testloader dictionary 25 | """ 26 | return self.trainloader, self.testloader_dict 27 | 28 | def return_testdataset_by_name(self, name): 29 | """ 30 | Return query and gallery, each containing a list of (img_path, pid, camid). 
31 | """ 32 | return self.testdataset_dict[name]['query'], self.testdataset_dict[name]['gallery'] 33 | 34 | 35 | class ImageDataManager(BaseDataManager): 36 | """ 37 | Image-ReID data manager 38 | """ 39 | 40 | def __init__(self, 41 | use_gpu, 42 | source_names, 43 | target_names, 44 | root, 45 | split_id=0, 46 | height=256, 47 | width=128, 48 | train_batch_size=32, 49 | test_batch_size=100, 50 | workers=4, 51 | train_sampler='', 52 | num_instances=4, # number of instances per identity (for RandomIdentitySampler) 53 | cuhk03_labeled=False, # use cuhk03's labeled or detected images 54 | cuhk03_classic_split=False # use cuhk03's classic split or 767/700 split 55 | ): 56 | super(ImageDataManager, self).__init__() 57 | self.use_gpu = use_gpu 58 | self.source_names = source_names 59 | self.target_names = target_names 60 | self.root = root 61 | self.split_id = split_id 62 | self.height = height 63 | self.width = width 64 | self.train_batch_size = train_batch_size 65 | self.test_batch_size = test_batch_size 66 | self.workers = workers 67 | self.train_sampler = train_sampler 68 | self.num_instances = num_instances 69 | self.cuhk03_labeled = cuhk03_labeled 70 | self.cuhk03_classic_split = cuhk03_classic_split 71 | self.pin_memory = True if self.use_gpu else False 72 | 73 | # Build train and test transform functions 74 | transform_train = build_transforms_RGB(self.height, self.width, is_train=True) 75 | transform_test = build_transforms_RGB(self.height, self.width, is_train=False) 76 | 77 | print("=> Initializing TRAIN (source) datasets") 78 | self.train = [] 79 | self._num_train_pids = 0 80 | self._num_train_cloth_ids = 0 81 | self._num_train_cams = 0 82 | 83 | for name in self.source_names: 84 | dataset = init_imgreid_dataset( 85 | root=self.root, name=name, split_id=self.split_id, cuhk03_labeled=self.cuhk03_labeled, 86 | cuhk03_classic_split=self.cuhk03_classic_split 87 | ) 88 | 89 | for img_path, pid, cloth_id, camid in dataset.train: 90 | pid += self._num_train_pids 91 | cloth_id += self._num_train_cloth_ids 92 | camid += self._num_train_cams 93 | self.train.append((img_path, pid, cloth_id, camid)) 94 | 95 | self._num_train_pids += dataset.num_train_pids 96 | self._num_train_cloth_ids += dataset.num_train_cloth_ids 97 | self._num_train_cams += dataset.num_train_cams 98 | 99 | if self.train_sampler == 'RandomIdentitySampler': 100 | self.trainloader = DataLoader( 101 | ImageDataset(self.train, transform=transform_train, training=True), 102 | sampler=RandomIdentitySampler(self.train, self.train_batch_size, self.num_instances), 103 | batch_size=self.train_batch_size, shuffle=False, num_workers=self.workers, 104 | pin_memory=self.pin_memory, drop_last=True 105 | ) 106 | 107 | else: 108 | self.trainloader = DataLoader( 109 | ImageDataset(self.train, transform=transform_train, training=True), 110 | batch_size=self.train_batch_size, shuffle=True, num_workers=self.workers, 111 | pin_memory=self.pin_memory, drop_last=True 112 | ) 113 | 114 | print("=> Initializing TEST (target) datasets") 115 | self.testloader_dict = {name: {'query': None, 'gallery': None} for name in self.target_names} 116 | self.testdataset_dict = {name: {'query': None, 'gallery': None} for name in self.target_names} 117 | 118 | for name in self.target_names: 119 | dataset = init_imgreid_dataset( 120 | root=self.root, name=name, split_id=self.split_id, cuhk03_labeled=self.cuhk03_labeled, 121 | cuhk03_classic_split=self.cuhk03_classic_split 122 | ) 123 | 124 | self.testloader_dict[name]['query'] = DataLoader( 125 | 
ImageDataset(dataset.query, transform=transform_test, training=False), 126 | batch_size=self.test_batch_size, shuffle=False, num_workers=self.workers, 127 | pin_memory=self.pin_memory, drop_last=False 128 | ) 129 | 130 | self.testloader_dict[name]['gallery'] = DataLoader( 131 | ImageDataset(dataset.gallery, transform=transform_test, training=False), 132 | batch_size=self.test_batch_size, shuffle=False, num_workers=self.workers, 133 | pin_memory=self.pin_memory, drop_last=False 134 | ) 135 | 136 | self.testdataset_dict[name]['query'] = dataset.query 137 | self.testdataset_dict[name]['gallery'] = dataset.gallery 138 | 139 | print("\n") 140 | print(" **************** Summary ****************") 141 | print(" train names : {}".format(self.source_names)) 142 | print(" # train datasets : {}".format(len(self.source_names))) 143 | print(" # train ids : {}".format(self._num_train_pids)) 144 | print(" # train cloth_ids: {}".format(self._num_train_cloth_ids)) 145 | print(" # train images : {}".format(len(self.train))) 146 | print(" # train cameras : {}".format(self._num_train_cams)) 147 | print(" test names : {}".format(self.target_names)) 148 | print(" *****************************************") 149 | print("\n") 150 | 151 | 152 | class VideoDataManager(BaseDataManager): 153 | """ 154 | Video-ReID data manager 155 | """ 156 | 157 | def __init__(self, 158 | use_gpu, 159 | source_names, 160 | target_names, 161 | root, 162 | split_id=0, 163 | height=256, 164 | width=128, 165 | train_batch_size=32, 166 | test_batch_size=100, 167 | num_instances=4, 168 | workers=4, 169 | seq_len=15, 170 | sample_method='evenly', 171 | image_training=False # train the video-reid model with images rather than tracklets 172 | ): 173 | super(VideoDataManager, self).__init__() 174 | self.use_gpu = use_gpu 175 | self.source_names = source_names 176 | self.target_names = target_names 177 | self.root = root 178 | self.split_id = split_id 179 | self.height = height 180 | self.width = width 181 | self.train_batch_size = train_batch_size 182 | self.test_batch_size = test_batch_size 183 | self.workers = workers 184 | self.seq_len = seq_len 185 | self.sample_method = sample_method 186 | self.num_instances = num_instances 187 | self.image_training = image_training 188 | self.pin_memory = True if self.use_gpu else False 189 | 190 | # Build train and test transform functions 191 | transform_train = build_transforms(self.height, self.width, is_train=False) # Xin added 192 | transform_test = build_transforms(self.height, self.width, is_train=False) 193 | 194 | print("=> Initializing TRAIN (source) datasets") 195 | self.train = [] 196 | self._num_train_pids = 0 197 | self._num_train_cams = 0 198 | 199 | for name in self.source_names: 200 | dataset = init_vidreid_dataset(root=self.root, name=name, split_id=self.split_id) 201 | 202 | for img_paths, pid, camid in dataset.train: 203 | pid += self._num_train_pids 204 | camid += self._num_train_cams 205 | if self.image_training: 206 | # decompose tracklets into images 207 | for img_path in img_paths: 208 | self.train.append((img_path, pid, camid)) 209 | else: 210 | self.train.append((img_paths, pid, camid)) 211 | 212 | self._num_train_pids += dataset.num_train_pids 213 | self._num_train_cams += dataset.num_train_cams 214 | 215 | if image_training: 216 | # each batch has image data of shape (batch, channel, height, width) 217 | self.trainloader = DataLoader( 218 | ImageDataset(self.train, transform=transform_train), 219 | batch_size=self.train_batch_size, shuffle=True, num_workers=self.workers, 
220 | pin_memory=self.pin_memory, drop_last=True 221 | ) 222 | else: 223 | # each batch has image data of shape (batch=N*K, seq_len, channel, height, width) 224 | self.trainloader = DataLoader( 225 | VideoDataset(self.train, seq_len=self.seq_len, sample_method=self.sample_method, transform=transform_test), 226 | sampler=RandomIdentitySampler(self.train, self.train_batch_size, self.num_instances), 227 | batch_size=self.train_batch_size, shuffle=False, num_workers=self.workers, 228 | pin_memory=self.pin_memory, drop_last=True 229 | ) 230 | 231 | print("=> Initializing TEST (target) datasets") 232 | self.testloader_dict = {name: {'query': None, 'gallery': None} for name in self.target_names} 233 | self.testdataset_dict = {name: {'query': None, 'gallery': None} for name in self.target_names} 234 | 235 | for name in self.target_names: 236 | dataset = init_vidreid_dataset(root=self.root, name=name, split_id=self.split_id) 237 | 238 | self.testloader_dict[name]['query'] = DataLoader( 239 | VideoDataset(dataset.query, seq_len=self.seq_len, sample_method=self.sample_method, transform=transform_test), 240 | batch_size=self.test_batch_size, shuffle=False, num_workers=self.workers, 241 | pin_memory=self.pin_memory, drop_last=False, 242 | ) 243 | 244 | self.testloader_dict[name]['gallery'] = DataLoader( 245 | VideoDataset(dataset.gallery, seq_len=self.seq_len, sample_method=self.sample_method, transform=transform_test), 246 | batch_size=self.test_batch_size, shuffle=False, num_workers=self.workers, 247 | pin_memory=self.pin_memory, drop_last=False, 248 | ) 249 | 250 | self.testdataset_dict[name]['query'] = dataset.query 251 | self.testdataset_dict[name]['gallery'] = dataset.gallery 252 | 253 | print("\n") 254 | print(" **************** Summary ****************") 255 | print(" train names : {}".format(self.source_names)) 256 | print(" # train datasets : {}".format(len(self.source_names))) 257 | print(" # train ids : {}".format(self._num_train_pids)) 258 | if self.image_training: 259 | print(" # train images : {}".format(len(self.train))) 260 | else: 261 | print(" # train tracklets: {}".format(len(self.train))) 262 | print(" # train cameras : {}".format(self._num_train_cams)) 263 | print(" test names : {}".format(self.target_names)) 264 | print(" *****************************************") 265 | print("\n") -------------------------------------------------------------------------------- /torchreid/dataset_loader.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | import os 6 | from PIL import Image 7 | import numpy as np 8 | import os.path as osp 9 | import io 10 | 11 | import torch 12 | from torch.utils.data import Dataset 13 | from .transforms import build_transforms_grey, build_transforms_RGB 14 | 15 | from args import argument_parser 16 | # global variables 17 | parser = argument_parser() 18 | args = parser.parse_args() 19 | 20 | def read_image(img_path): 21 | """Keep reading image until succeed. 22 | This can avoid IOError incurred by heavy IO process.""" 23 | got_img = False 24 | if not osp.exists(img_path): 25 | raise IOError("{} does not exist".format(img_path)) 26 | while not got_img: 27 | try: 28 | img = Image.open(img_path) #.convert('RGB') Gait only have 1 channel 29 | got_img = True 30 | except IOError: 31 | print("IOError incurred when reading '{}'. Will redo. Don't worry. 
Just chill.".format(img_path)) 32 | pass 33 | return img 34 | 35 | def read_image_RGB(img_path): 36 | """Keep reading image until succeed. 37 | This can avoid IOError incurred by heavy IO process.""" 38 | got_img = False 39 | if not osp.exists(img_path): 40 | raise IOError("{} does not exist".format(img_path)) 41 | while not got_img: 42 | try: 43 | img = Image.open(img_path).convert('RGB') 44 | got_img = True 45 | except IOError: 46 | print("IOError incurred when reading '{}'. Will redo. Don't worry. Just chill.".format(img_path)) 47 | pass 48 | return img 49 | 50 | 51 | class ImageDataset(Dataset): 52 | """Image Person ReID Dataset""" 53 | def __init__(self, dataset, transform=None, training=True): 54 | self.dataset = dataset 55 | self.transform = transform 56 | self.training = training 57 | self.transform_grey = build_transforms_grey(args.mask_height, args.mask_width, is_train=True) 58 | self.transform_RGB = build_transforms_RGB(args.height, args.width, is_train=True) 59 | 60 | def __len__(self): 61 | return len(self.dataset) 62 | 63 | def __getitem__(self, index): 64 | img_path, pid, cloth_id, camid = self.dataset[index] 65 | img = read_image_RGB(img_path) 66 | 67 | if self.training: 68 | try: # our extracted masks 69 | if args.source_names[0] == 'prcc': 70 | mask = read_image(img_path.split('train')[0]+'train_mask'+img_path.split('train')[-1].split('.jpg')[0]+'.png') 71 | elif args.source_names[0] == 'ltcc': 72 | mask = read_image(img_path.split('train')[0]+'train_mask'+img_path.split('train')[-1]) 73 | except: # using PRCC provided sketch 74 | mask = read_image(img_path.split('/rgb/')[0]+'/sketch/'+img_path.split('/rgb/')[-1]) 75 | 76 | mask = self.transform_grey(mask) 77 | 78 | if self.transform is not None: 79 | img = self.transform_RGB(img) 80 | 81 | if self.training: 82 | return img, pid, cloth_id, camid, img_path, mask 83 | else: 84 | if args.concat_mask and args.source_names[0] == 'ltcc': 85 | try: 86 | mask = read_image(img_path.split('test')[0] + 'test_mask' + img_path.split('test')[-1]) 87 | except: 88 | mask = read_image(img_path.split('query')[0] + 'query_mask' + img_path.split('query')[-1]) 89 | 90 | mask = self.transform_grey(mask) 91 | return img, pid, cloth_id, camid, mask 92 | 93 | return img, pid, cloth_id, camid, img_path 94 | 95 | 96 | class VideoDataset(Dataset): 97 | """Video Person ReID Dataset. 98 | Note batch data has shape (batch, seq_len, channel, height, width). 99 | """ 100 | _sample_methods = ['evenly', 'random', 'all'] 101 | 102 | def __init__(self, dataset, seq_len=15, sample_method='evenly', transform=None): 103 | self.dataset = dataset 104 | self.seq_len = seq_len 105 | self.sample_method = sample_method 106 | self.transform = transform 107 | self.cut_padding = 10 108 | 109 | def __len__(self): 110 | return len(self.dataset) 111 | 112 | def __getitem__(self, index): 113 | img_paths, pid, camid = self.dataset[index] 114 | num = len(img_paths) 115 | 116 | if self.sample_method == 'random': 117 | """ 118 | Randomly sample seq_len items from num items, 119 | if num is smaller than seq_len, then replicate items 120 | """ 121 | indices = np.arange(num) 122 | replace = False if num >= self.seq_len else True 123 | indices = np.random.choice(indices, size=self.seq_len, replace=replace) 124 | # sort indices to keep temporal order (comment it to be order-agnostic) 125 | indices = np.sort(indices) 126 | 127 | elif self.sample_method == 'evenly': 128 | """ 129 | Evenly sample seq_len items from num items. 
130 | """ 131 | if num >= self.seq_len: 132 | num -= num % self.seq_len 133 | indices = np.arange(0, num, num/self.seq_len) 134 | else: 135 | # if num is smaller than seq_len, simply replicate the last image 136 | # until the seq_len requirement is satisfied 137 | indices = np.arange(0, num) 138 | num_pads = self.seq_len - num 139 | indices = np.concatenate([indices, np.ones(num_pads).astype(np.int32)*(num-1)]) 140 | assert len(indices) == self.seq_len 141 | 142 | elif self.sample_method == 'all': 143 | """ 144 | Sample all items, seq_len is useless now and batch_size needs 145 | to be set to 1. 146 | """ 147 | indices = np.arange(num) 148 | 149 | else: 150 | raise ValueError("Unknown sample method: {}. Expected one of {}".format(self.sample_method, self._sample_methods)) 151 | 152 | imgs = [] 153 | img_name = [] 154 | for index in indices: 155 | img_path = img_paths[int(index)] 156 | img = read_image(img_path) 157 | img_name.append(img_path) 158 | if self.transform is not None: 159 | img = self.transform(img) 160 | img = img.unsqueeze(0) 161 | imgs.append(img) 162 | imgs = torch.cat(imgs, dim=0) 163 | 164 | return imgs, pid, camid, img_name 165 | 166 | # return imgs[:, :, :, self.cut_padding:-self.cut_padding], pid, camid, img_paths 167 | -------------------------------------------------------------------------------- /torchreid/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from .casiab_video import Casiab 6 | from .casiab_video_sub import Casiab_sub 7 | 8 | from .ltcc import Ltcc 9 | from .prcc import Prcc 10 | 11 | 12 | __imgreid_factory = { 13 | 'ltcc': Ltcc, 14 | 'prcc': Prcc, 15 | } 16 | 17 | # gait dataset, for training Gait-Stream 18 | __vidreid_factory = { 19 | 'casiab': Casiab, 20 | 'casiab_processed': Casiab_sub, 21 | } 22 | 23 | 24 | def init_imgreid_dataset(name, **kwargs): 25 | if name not in list(__imgreid_factory.keys()): 26 | raise KeyError("Invalid dataset, got '{}', but expected to be one of {}".format(name, list(__imgreid_factory.keys()))) 27 | return __imgreid_factory[name](**kwargs) 28 | 29 | 30 | def init_vidreid_dataset(name, **kwargs): 31 | if name not in list(__vidreid_factory.keys()): 32 | raise KeyError("Invalid dataset, got '{}', but expected to be one of {}".format(name, list(__vidreid_factory.keys()))) 33 | return __vidreid_factory[name](**kwargs) -------------------------------------------------------------------------------- /torchreid/datasets/bases.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | 4 | import numpy as np 5 | 6 | 7 | class BaseDataset(object): 8 | """ 9 | Base class of reid dataset 10 | """ 11 | 12 | def get_imagedata_info(self, data): 13 | pids, cloth_ids, cams = [], [], [] 14 | for _, pid, cloth_id, camid in data: 15 | pids += [pid] 16 | cloth_ids += [cloth_id] 17 | cams += [camid] 18 | pids = set(pids) 19 | cloth_ids = set(cloth_ids) 20 | cams = set(cams) 21 | num_pids = len(pids) 22 | num_cloth_ids = len(cloth_ids) 23 | num_cams = len(cams) 24 | num_imgs = len(data) 25 | return num_pids, num_cloth_ids, num_imgs, num_cams 26 | 27 | def get_videodata_info(self, data, return_tracklet_info=False): 28 | pids, cams, tracklet_info = [], [], [] 29 | for img_paths, pid, camid in data: 30 | pids += [pid] 31 | cams += [camid] 32 | 
tracklet_info += [len(img_paths)] 33 | pids = set(pids) 34 | cams = set(cams) 35 | num_pids = len(pids) 36 | num_cams = len(cams) 37 | num_tracklets = len(data) 38 | if return_tracklet_info: 39 | return num_pids, num_tracklets, num_cams, tracklet_info 40 | return num_pids, num_tracklets, num_cams 41 | 42 | def print_dataset_statistics(self): 43 | raise NotImplementedError 44 | 45 | 46 | class BaseImageDataset(BaseDataset): 47 | """ 48 | Base class of image reid dataset 49 | """ 50 | 51 | def print_dataset_statistics(self, train, query, gallery): 52 | num_train_pids, num_train_cloth_ids, num_train_imgs, num_train_cams = self.get_imagedata_info(train) 53 | num_query_pids, num_query_cloth_ids, num_query_imgs, num_query_cams = self.get_imagedata_info(query) 54 | num_gallery_pids, num_gallery_cloth_ids, num_gallery_imgs, num_gallery_cams = self.get_imagedata_info(gallery) 55 | 56 | print("Dataset statistics:") 57 | print(" ----------------------------------------") 58 | print(" subset | # ids | # cloth_ids | # images | # cameras") 59 | print(" ----------------------------------------") 60 | print(" train | {:5d} | {:5d} | {:8d} | {:9d}".format(num_train_pids, num_train_cloth_ids, num_train_imgs, num_train_cams)) 61 | print(" query | {:5d} | {:5d} | {:8d} | {:9d}".format(num_query_pids, num_query_cloth_ids, num_query_imgs, num_query_cams)) 62 | print(" gallery | {:5d} | {:5d} | {:8d} | {:9d}".format(num_gallery_pids, num_gallery_cloth_ids, num_gallery_imgs, num_gallery_cams)) 63 | print(" ----------------------------------------") 64 | 65 | 66 | class BaseVideoDataset(BaseDataset): 67 | """ 68 | Base class of video reid dataset 69 | """ 70 | 71 | def print_dataset_statistics(self, train, query, gallery): 72 | num_train_pids, num_train_tracklets, num_train_cams, train_tracklet_info = \ 73 | self.get_videodata_info(train, return_tracklet_info=True) 74 | 75 | num_query_pids, num_query_tracklets, num_query_cams, query_tracklet_info = \ 76 | self.get_videodata_info(query, return_tracklet_info=True) 77 | 78 | num_gallery_pids, num_gallery_tracklets, num_gallery_cams, gallery_tracklet_info = \ 79 | self.get_videodata_info(gallery, return_tracklet_info=True) 80 | 81 | tracklet_info = train_tracklet_info + query_tracklet_info + gallery_tracklet_info 82 | min_num = np.min(tracklet_info) 83 | max_num = np.max(tracklet_info) 84 | avg_num = np.mean(tracklet_info) 85 | 86 | print("Dataset statistics:") 87 | print(" -------------------------------------------") 88 | print(" subset | # ids | # tracklets | # cameras") 89 | print(" -------------------------------------------") 90 | print(" train | {:5d} | {:11d} | {:9d}".format(num_train_pids, num_train_tracklets, num_train_cams)) 91 | print(" query | {:5d} | {:11d} | {:9d}".format(num_query_pids, num_query_tracklets, num_query_cams)) 92 | print(" gallery | {:5d} | {:11d} | {:9d}".format(num_gallery_pids, num_gallery_tracklets, num_gallery_cams)) 93 | print(" -------------------------------------------") 94 | print(" number of images per tracklet: {} ~ {}, average {:.2f}".format(min_num, max_num, avg_num)) 95 | print(" -------------------------------------------") -------------------------------------------------------------------------------- /torchreid/datasets/casiab_video.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | import glob 7 | import re 8 | import sys 9 | import urllib 10 | 
import tarfile
11 | import zipfile
12 | import os.path as osp
13 | from scipy.io import loadmat
14 | import numpy as np
15 | import h5py
16 | # from scipy.misc import imsave
17 | 
18 | from .bases import BaseVideoDataset
19 | 
20 | 
21 | class Casiab(BaseVideoDataset):
22 |     """
23 |     CASIA-B (gait dataset)
24 | 
25 |     Reference:
26 |     Yu et al. A Framework for Evaluating the Effect of View Angle, Clothing and Carrying Condition on Gait Recognition. ICPR 2006.
27 | 
28 |     Dataset statistics:
29 |     # identities: 124
30 |     # walking conditions: normal (nm) / bag (bg) / clothing (cl)
31 |     # views: 11 per subject
32 | 
33 |     Note: this loader only uses the normal-walking ('nm') sequences at the 90-degree view.
34 |     """
35 |     dataset_dir = 'casiab'
36 | 
37 |     def __init__(self, root='data', min_seq_len=0, verbose=True, **kwargs):
38 |         self.dataset_dir = osp.join(root, self.dataset_dir)
39 |         self.train_dir = osp.join(self.dataset_dir, 'train_candi')
40 |         self.gallery_dir = osp.join(self.dataset_dir, 'gallery_candi')
41 |         # self.track_train_info_path = osp.join(self.dataset_dir, 'info/tracks_train_info.mat')
42 |         # self.track_test_info_path = osp.join(self.dataset_dir, 'info/tracks_test_info.mat')
43 |         self.query_dir = osp.join(self.dataset_dir, 'query_candi')
44 | 
45 |         self._check_before_run()
46 | 
47 |         train = self._process_data(self.train_dir, relabel=True, min_seq_len=min_seq_len)
48 |         query = self._process_data(self.query_dir, relabel=False, min_seq_len=min_seq_len)
49 |         gallery = self._process_data(self.gallery_dir, relabel=False, min_seq_len=min_seq_len)
50 | 
51 |         if verbose:
52 |             print("=> Casiab loaded")
53 |             self.print_dataset_statistics(train, query, gallery)
54 | 
55 |         self.train = train
56 |         self.query = query
57 |         self.gallery = gallery
58 | 
59 |         self.num_train_pids, _, self.num_train_cams = self.get_videodata_info(self.train)
60 |         self.num_query_pids, _, self.num_query_cams = self.get_videodata_info(self.query)
61 |         self.num_gallery_pids, _, self.num_gallery_cams = self.get_videodata_info(self.gallery)
62 | 
63 |     def _check_before_run(self):
64 |         """Check if all files are available before going deeper"""
65 |         if not osp.exists(self.dataset_dir):
66 |             raise RuntimeError("'{}' is not available".format(self.dataset_dir))
67 | 
68 |     def _get_names(self, fpath):
69 |         names = []
70 |         with open(fpath, 'r') as f:
71 |             for line in f:
72 |                 new_line = line.rstrip()
73 |                 names.append(new_line)
74 |         return names
75 | 
76 |     def _process_data(self, dir_path, relabel=False, min_seq_len=0):
77 |         pid_paths = glob.glob(osp.join(dir_path, '*'))
78 | 
79 |         tracklets = []
80 |         for pid_path in pid_paths:
81 |             pid = int(osp.basename(pid_path))
82 | 
83 |             camid_paths = glob.glob(osp.join(pid_path, 'nm*'))
84 |             for camid_path in camid_paths:
85 |                 camid = int(osp.basename(camid_path)[4])
86 | 
87 |                 target_view_dir = osp.join(camid_path, '090')
88 |                 image_paths = glob.glob(osp.join(target_view_dir, '*'))
89 | 
90 |                 if len(image_paths) >= min_seq_len:
91 |                     img_paths = tuple(image_paths)
92 |                     tracklets.append((img_paths, pid, camid))  # store the tuple built above, as in casiab_video_sub.py
93 |         return tracklets
94 | 
--------------------------------------------------------------------------------
/torchreid/datasets/casiab_video_sub.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 | 
5 | import os
6 | import glob
7 | import re
8 | import sys
9 | import urllib
10 | import tarfile
11 | import zipfile
12 | import os.path as osp
13 | from scipy.io import loadmat
14 | import numpy as np
15 | import h5py
16 | 
17 | from .bases import BaseVideoDataset
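# The CASIA-B loaders assume the following on-disk layout (inferred from the
# glob patterns in _process_data/_process_dir; only normal-walking ('nm')
# sequences at the 90-degree view are used):
#
#   <root>/casiab/            (Casiab)  or  <root>/CASIA_pro/  (Casiab_sub)
#       train_candi/<pid>/nm-0X/090/*.png
#       query_candi/<pid>/nm-0X/090/*.png
#       gallery_candi/<pid>/nm-0X/090/*.png
#
# Casiab_sub additionally expects 090/sub*/ sub-tracklet folders holding at
# least min_seq_len (default 8) frames each.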
18 | 
19 | 
20 | class Casiab_sub(BaseVideoDataset):
21 |     """
22 |     CASIA-B (gait dataset), pre-processed into sub-tracklets
23 | 
24 |     Reference:
25 |     Yu et al. A Framework for Evaluating the Effect of View Angle, Clothing and Carrying Condition on Gait Recognition. ICPR 2006.
26 | 
27 |     Dataset statistics:
28 |     # identities: 124
29 |     # walking conditions: normal (nm) / bag (bg) / clothing (cl)
30 |     # views: 11 per subject
31 | 
32 |     Note: only the 'nm' sequences at the 90-degree view are used, split into 'sub*' sub-tracklets.
33 |     """
34 |     dataset_dir = 'CASIA_pro'
35 | 
36 |     def __init__(self, root='data', min_seq_len=8, verbose=True, **kwargs):
37 |         self.dataset_dir = osp.join(root, self.dataset_dir)
38 |         self.train_dir = osp.join(self.dataset_dir, 'train_candi')
39 |         self.gallery_dir = osp.join(self.dataset_dir, 'gallery_candi')
40 |         # self.track_train_info_path = osp.join(self.dataset_dir, 'info/tracks_train_info.mat')
41 |         # self.track_test_info_path = osp.join(self.dataset_dir, 'info/tracks_test_info.mat')
42 |         self.query_dir = osp.join(self.dataset_dir, 'query_candi')
43 | 
44 |         self._check_before_run()
45 | 
46 |         train = self._process_dir(self.train_dir, relabel=True, min_seq_len=min_seq_len)
47 |         query = self._process_dir(self.query_dir, relabel=False, min_seq_len=min_seq_len)
48 |         gallery = self._process_dir(self.gallery_dir, relabel=False, min_seq_len=min_seq_len)
49 | 
50 |         if verbose:
51 |             print("=> Casiab loaded")
52 |             self.print_dataset_statistics(train, query, gallery)
53 | 
54 |         self.train = train
55 |         self.query = query
56 |         self.gallery = gallery
57 | 
58 |         self.num_train_pids, _, self.num_train_cams = self.get_videodata_info(self.train)
59 |         self.num_query_pids, _, self.num_query_cams = self.get_videodata_info(self.query)
60 |         self.num_gallery_pids, _, self.num_gallery_cams = self.get_videodata_info(self.gallery)
61 | 
62 |     def _check_before_run(self):
63 |         """Check if all files are available before going deeper"""
64 |         if not osp.exists(self.dataset_dir):
65 |             raise RuntimeError("'{}' is not available".format(self.dataset_dir))
66 | 
67 |     def _get_names(self, fpath):
68 |         names = []
69 |         with open(fpath, 'r') as f:
70 |             for line in f:
71 |                 new_line = line.rstrip()
72 |                 names.append(new_line)
73 |         return names
74 | 
75 |     def _process_dir(self, dir_path, relabel=False, min_seq_len=8):
76 |         pid_paths = glob.glob(osp.join(dir_path, '*'))
77 | 
78 |         tracklets = []
79 |         for pid_path in pid_paths:
80 |             pid = int(osp.basename(pid_path))
81 | 
82 |             camid_paths = glob.glob(osp.join(pid_path, 'nm*'))
83 |             for camid_path in camid_paths:
84 |                 camid = int(osp.basename(camid_path)[4])
85 | 
86 |                 target_view_dir = osp.join(camid_path, '090')
87 |                 sub_tracklet_dirs = glob.glob(osp.join(target_view_dir, 'sub*'))
88 | 
89 |                 for sub_sub_tracklet_dir in sub_tracklet_dirs:
90 |                     image_paths = glob.glob(osp.join(sub_sub_tracklet_dir, '*'))
91 | 
92 |                     if len(image_paths) >= min_seq_len:
93 |                         img_paths = tuple(image_paths)
94 |                         tracklets.append((img_paths, pid, camid))
95 |         return tracklets
96 | 
--------------------------------------------------------------------------------
/torchreid/datasets/ltcc.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 | 
7 | import glob
8 | import re
9 | 
10 | import os.path as osp
11 | 
12 | from .bases import BaseImageDataset
13 | 
14 | from args import argument_parser
15 | # global variables
16 | parser = argument_parser()
17 | args = parser.parse_args()
18 | 
19 | 
20 | class Ltcc(BaseImageDataset):
21 |     """
22 |     LTCC_ReID dataset
23 |     Reference:
24 |     Qian et al. Long-Term Cloth-Changing Person Re-identification. ACCV 2020.
25 | 
26 | 
27 |     Dataset statistics:
28 |     # identities: 152 (+1 for background)
29 |     # images: (train) + (query) + (gallery)
30 |     """
31 |     dataset_dir = 'LTCC_ReID'
32 | 
33 |     def __init__(self, root='/home/jinx/data', verbose=True, **kwargs):
34 |         super(Ltcc, self).__init__()
35 |         self.dataset_dir = osp.join(root, self.dataset_dir)
36 |         self.train_dir = osp.join(self.dataset_dir, 'train')
37 |         self.query_dir = osp.join(self.dataset_dir, 'query')
38 |         self.gallery_dir = osp.join(self.dataset_dir, 'test')
39 | 
40 |         self._check_before_run()
41 | 
42 |         self.cloth_unchange_id_train_list = []
43 |         self.cloth_change_id_train_list = []
44 |         self.cloth_unchange_id_test_list = []
45 |         self.cloth_change_id_test_list = []
46 |         file = open(osp.join(self.dataset_dir, 'info/cloth-unchange_id_train.txt'))
47 |         cloth_unchange_id_train_list = file.readlines()
48 |         for i in cloth_unchange_id_train_list:
49 |             self.cloth_unchange_id_train_list.append(int(i.strip('\n')))
50 |         file = open(osp.join(self.dataset_dir, 'info/cloth-change_id_train.txt'))
51 |         cloth_change_id_train_list = file.readlines()
52 |         for i in cloth_change_id_train_list:
53 |             self.cloth_change_id_train_list.append(int(i.strip('\n')))
54 |         file = open(osp.join(self.dataset_dir, 'info/cloth-unchange_id_test.txt'))
55 |         cloth_unchange_id_test_list = file.readlines()
56 |         for i in cloth_unchange_id_test_list:
57 |             self.cloth_unchange_id_test_list.append(int(i.strip('\n')))
58 |         file = open(osp.join(self.dataset_dir, 'info/cloth-change_id_test.txt'))
59 |         cloth_change_id_test_list = file.readlines()
60 |         for i in cloth_change_id_test_list:
61 |             self.cloth_change_id_test_list.append(int(i.strip('\n')))
62 | 
63 |         train = self._process_dir_train(self.train_dir, self.cloth_unchange_id_train_list, self.cloth_change_id_train_list, relabel=True)
64 |         query = self._process_dir(self.query_dir, relabel=False)
65 |         gallery = self._process_dir(self.gallery_dir, relabel=False)
66 | 
67 |         if verbose:
68 |             print("=> LTCC_ReID loaded")
69 |             self.print_dataset_statistics(train, query, gallery)
70 | 
71 |         self.train = train
72 |         self.query = query
73 |         self.gallery = gallery
74 | 
75 |         self.num_train_pids, self.num_train_cloth_ids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train)
76 |         self.num_query_pids, self.num_query_cloth_ids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query)
77 |         self.num_gallery_pids, self.num_gallery_cloth_ids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery)
78 | 
79 |     def _check_before_run(self):
80 |         """Check if all files are available before going deeper"""
81 |         if not osp.exists(self.dataset_dir):
82 |             raise RuntimeError("'{}' is not available".format(self.dataset_dir))
83 |         if not osp.exists(self.train_dir):
84 |             raise RuntimeError("'{}' is not available".format(self.train_dir))
85 |         if not osp.exists(self.query_dir):
86 |             raise RuntimeError("'{}' is not available".format(self.query_dir))
87 |         if not osp.exists(self.gallery_dir):
88 |             raise RuntimeError("'{}' is not available".format(self.gallery_dir))
89 | 
90 |     def _process_dir_train(self, dir_path, cloth_unchange_id_train_list, cloth_change_id_train_list, relabel=False):
91 |         img_paths = glob.glob(osp.join(dir_path, '*.png'))
92 |         pattern = re.compile(r'(\d+)_(\d+)_c(\d+)')  # also capture the cloth ID, so three groups are returned; 'c(\d+)' covers two-digit camera ids such as c10-c12
93 | 
94 |         pid_container = set()
95 |         for img_path in img_paths:
96 |             img_name = osp.basename(img_path)
97 |             pid, _, _ = map(int, pattern.search(img_name).groups())
98 |             pid_container.add(pid)
99 | 
100 |         pid2label = {pid: label for label, pid in enumerate(pid_container)}
101 | 
102 |         dataset = []
103 |         for img_path in img_paths:
104 |             img_name = osp.basename(img_path)
105 |             pid, cloth_id, camid = map(int, pattern.search(img_name).groups())
106 |             assert 0 <= pid <= 151  # pid == 0 means background
107 |             assert 1 <= cloth_id <= 14
108 |             assert 1 <= camid <= 12
109 |             camid -= 1  # index starts from 0
110 |             if relabel: pid = pid2label[pid]
111 | 
112 |             if args.train_with_all_cloth:
113 |                 dataset.append((img_path, pid, cloth_id, camid))
114 |             elif args.train_with_only_cloth_changing:
115 |                 if pid in cloth_change_id_train_list:
116 |                     dataset.append((img_path, pid, cloth_id, camid))
117 | 
118 |         return dataset
119 | 
120 |     def _process_dir(self, dir_path, relabel=False):
121 |         img_paths = glob.glob(osp.join(dir_path, '*.png'))
122 |         pattern = re.compile(r'(\d+)_(\d+)_c(\d+)')  # also capture the cloth ID, so three groups are returned; 'c(\d+)' covers two-digit camera ids such as c10-c12
123 | 
124 |         pid_container = set()
125 |         for img_path in img_paths:
126 |             img_name = osp.basename(img_path)
127 |             pid, _, _ = map(int, pattern.search(img_name).groups())
128 |             pid_container.add(pid)
129 | 
130 |         pid2label = {pid: label for label, pid in enumerate(pid_container)}
131 | 
132 |         dataset = []
133 |         for img_path in img_paths:
134 |             img_name = osp.basename(img_path)
135 |             pid, cloth_id, camid = map(int, pattern.search(img_name).groups())
136 |             assert 0 <= pid <= 151  # pid == 0 means background
137 |             assert 1 <= cloth_id <= 14
138 |             assert 1 <= camid <= 12
139 |             camid -= 1  # index starts from 0
140 |             if relabel: pid = pid2label[pid]
141 |             dataset.append((img_path, pid, cloth_id, camid))
142 | 
143 |         return dataset
--------------------------------------------------------------------------------
/torchreid/datasets/prcc.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 | 
7 | import glob
8 | import re
9 | 
10 | import os.path as osp
11 | 
12 | from .bases import BaseImageDataset
13 | 
14 | from args import argument_parser
15 | # global variables
16 | parser = argument_parser()
17 | args = parser.parse_args()
18 | 
19 | class Prcc(BaseImageDataset):
20 |     """
21 |     PRCC
22 |     Reference:
23 |     Person Re-identification by Contour Sketch under Moderate Clothing Change.
24 |     TPAMI-2019
25 | 
26 |     Dataset statistics:
27 |     # identities: 221, with 3 camera views.
28 | # images: 150IDs (train) + 71IDs (test) 29 | 30 | Dataset statistics: (A--->C, cross-clothes settings) 31 | ---------------------------------------- 32 | subset | # ids | # cloth_ids | # images | # cameras 33 | ---------------------------------------- 34 | train | 150 | 3 | 17896 | 3 35 | query | 71 | 1 | 3384 | 1 36 | gallery | 71 | 1 | 3543 | 1 37 | ---------------------------------------- 38 | 39 | Two test settings: 40 | parser.add_argument('--cross-clothes', action='store_true', 41 | help="the person matching between Camera views A and C was cross-clothes matching") 42 | parser.add_argument('--same-clothes', action='store_true', 43 | help="the person matching between Camera views A and B was performed without clothing changes") 44 | """ 45 | dataset_dir = 'prcc/rgb' # could change to sketch/contour folder 46 | 47 | def __init__(self, root='/home/jinx/data', verbose=True, **kwargs): 48 | super(Prcc, self).__init__() 49 | self.dataset_dir = osp.join(root, self.dataset_dir) 50 | self.train_dir = osp.join(self.dataset_dir, 'train') 51 | self.validation_dir = osp.join(self.dataset_dir, 'val') 52 | self.probe_gallery_dir = osp.join(self.dataset_dir, 'test') 53 | 54 | self._check_before_run() 55 | 56 | train = self._process_dir(self.train_dir, relabel=True) 57 | query, gallery = self._process_test_dir(self.probe_gallery_dir, relabel=False) 58 | 59 | if verbose: 60 | print("=> PRCC dataset loaded") 61 | self.print_dataset_statistics(train, query, gallery) 62 | 63 | self.train = train 64 | self.query = query 65 | self.gallery = gallery 66 | 67 | self.num_train_pids, self.num_train_cloth_ids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 68 | self.num_query_pids, self.num_query_cloth_ids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 69 | self.num_gallery_pids, self.num_gallery_cloth_ids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 70 | 71 | def _check_before_run(self): 72 | """Check if all files are available before going deeper""" 73 | if not osp.exists(self.dataset_dir): 74 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 75 | if not osp.exists(self.train_dir): 76 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 77 | if not osp.exists(self.probe_gallery_dir): 78 | raise RuntimeError("'{}' is not available".format(self.probe_gallery_dir)) 79 | 80 | 81 | def _process_dir(self, dir_path, relabel=False): 82 | 83 | # Load from train 84 | pid_dirs_path = glob.glob(osp.join(dir_path, '*')) 85 | 86 | dataset = [] 87 | pid_container = set() 88 | camid_mapper = {'A': 1, 'B': 2, 'C': 3} 89 | for pid_dir_path in pid_dirs_path: 90 | img_paths = glob.glob(osp.join(pid_dir_path, '*.jp*')) 91 | for img_path in img_paths: 92 | pid = int(osp.basename(pid_dir_path)) 93 | pid_container.add(pid) 94 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 95 | 96 | for pid_dir_path in pid_dirs_path: 97 | img_paths = glob.glob(osp.join(pid_dir_path, '*.jp*')) 98 | for img_path in img_paths: 99 | pid = int(osp.basename(pid_dir_path)) 100 | camid = camid_mapper[osp.basename(img_path)[0]] 101 | cloth_id = camid 102 | camid -= 1 # index starts from 0 103 | if relabel: pid = pid2label[pid] 104 | dataset.append((img_path, pid, cloth_id, camid)) 105 | 106 | return dataset 107 | 108 | def _process_test_dir(self, dir_path, relabel=False): 109 | 110 | camid_dirs_path = glob.glob(osp.join(dir_path, '*')) 111 | 112 | query = [] 113 | gallery = [] 114 | 
pid_container = set() 115 | camid_mapper = {'A': 1, 'B': 2, 'C': 3} 116 | 117 | for camid_dir_path in camid_dirs_path: 118 | pid_dir_paths = glob.glob(osp.join(camid_dir_path, '*')) 119 | for pid_dir_path in pid_dir_paths: 120 | pid = int(osp.basename(pid_dir_path)) 121 | img_paths = glob.glob(osp.join(pid_dir_path, '*')) 122 | for img_path in img_paths: 123 | camid = camid_mapper[osp.basename(camid_dir_path)] 124 | camid -= 1 # index starts from 0 125 | if camid == 0: 126 | cloth_id = camid 127 | query.append((img_path, pid, cloth_id, camid)) 128 | else: 129 | if args.cross_clothes and camid == 2: 130 | cloth_id = camid 131 | gallery.append((img_path, pid, cloth_id, camid)) 132 | elif args.same_clothes and camid == 1: 133 | cloth_id = camid 134 | gallery.append((img_path, pid, cloth_id, camid)) 135 | 136 | return query, gallery 137 | -------------------------------------------------------------------------------- /torchreid/datasets/real28.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: Zhiheng Yin 4 | @contact: yzhiheng@umich.edu 5 | """ 6 | 7 | import glob 8 | import re 9 | 10 | import os.path as osp 11 | 12 | from .bases import BaseImageDataset 13 | 14 | 15 | class Real28(BaseImageDataset): 16 | """ 17 | Real28 18 | Reference: 19 | # Copyright (C) 2019 Fangbin Wan, Yang Wu, Xuelin Qian, Yanwei Fu, Fudan University 20 | # 21 | # Licensed to the Apache Software Foundation (ASF) under one or more 22 | # contributor license agreements. 23 | # The ASF licenses this file to You under the Apache License, Version 2.0 24 | # (the "License"); you may not use this file except in compliance with 25 | # the License. You may obtain a copy of the License at 26 | # 27 | # http://www.apache.org/licenses/LICENSE-2.0 28 | 29 | Dataset statistics: 30 | # identities: 28 31 | # images: 0 (train) + 4323 (query) + (gallery) 32 | """ 33 | dataset_dir = 'Rea128/Real28' 34 | 35 | def __init__(self, root='/home/haoluo/data', verbose=True, **kwargs): 36 | super(Real28, self).__init__() 37 | self.dataset_dir = osp.join(root, self.dataset_dir) 38 | self.train_dir = osp.join(self.dataset_dir, 'train') 39 | self.query_dir = osp.join(self.dataset_dir, 'query') 40 | self.gallery_dir = osp.join(self.dataset_dir, 'gallery') 41 | 42 | self._check_before_run() 43 | 44 | train = self._process_dir(self.train_dir, relabel=True) 45 | query = self._process_dir(self.query_dir, relabel=False) 46 | gallery = self._process_dir(self.gallery_dir, relabel=False) 47 | 48 | if verbose: 49 | print("=> Real28 loaded") 50 | self.print_dataset_statistics(train, query, gallery) 51 | 52 | self.train = train 53 | self.query = query 54 | self.gallery = gallery 55 | 56 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 57 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 58 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 59 | 60 | def _check_before_run(self): 61 | """Check if all files are available before going deeper""" 62 | if not osp.exists(self.dataset_dir): 63 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 64 | if not osp.exists(self.train_dir): 65 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 66 | if not osp.exists(self.query_dir): 67 | raise RuntimeError("'{}' is not available".format(self.query_dir)) 68 | if not osp.exists(self.gallery_dir): 69 | 
raise RuntimeError("'{}' is not available".format(self.gallery_dir))
 70 | 
 71 |     def _process_dir(self, dir_path, relabel=False):
 72 |         img_paths = glob.glob(osp.join(dir_path, '*.jp*'))
 73 |         pattern = re.compile(r'([\d]+)_([\d]+)')
 74 | 
 75 |         pid_container = set()
 76 |         for img_path in img_paths:
 77 |             pid, _ = map(int, pattern.search(img_path).groups())
 78 |             if pid == -1: continue  # junk images are just ignored
 79 |             pid_container.add(pid)
 80 |         pid2label = {pid: label for label, pid in enumerate(pid_container)}
 81 | 
 82 |         dataset = []
 83 |         for img_path in img_paths:
 84 |             pid, camid = map(int, pattern.search(img_path).groups())
 85 |             if pid == -1: continue  # junk images are just ignored
 86 |             assert 1 <= pid <= 28  # Real28 labels its 28 identities from 1
 87 |             assert 1 <= camid <= 4
 88 |             camid -= 1  # index starts from 0
 89 |             if relabel: pid = pid2label[pid]
 90 |             dataset.append((img_path, pid, camid))
 91 | 
 92 |         return dataset
 93 | 
--------------------------------------------------------------------------------
/torchreid/datasets/vc_clothe.py:
--------------------------------------------------------------------------------
  1 | import glob
  2 | import re
  3 | 
  4 | import os.path as osp
  5 | 
  6 | from .bases import BaseImageDataset
  7 | 
  8 | 
  9 | class VC_clothe(BaseImageDataset):
 10 |     """
 11 |     VC-Clothes dataset
 12 |     Reference:
 13 |     Wan et al. "When Person Re-identification Meets Changing Clothes." CVPR Workshops 2020.
 14 | 
 15 | 
 16 |     Dataset statistics (consistent with the asserts in _process_dir below):
 17 |     # identities: 512
 18 |     # images: 19060
 19 |     # cameras: 4
 20 | 
 21 |     """
 22 | 
 23 |     dataset_dir = 'VC-Clothes/VC-Clothes'
 24 | 
 25 |     def __init__(self, root='../', verbose=True, **kwargs):
 26 |         super(VC_clothe, self).__init__()
 27 |         self.dataset_dir = osp.join(root, self.dataset_dir)
 28 |         self.train_dir = osp.join(self.dataset_dir, 'train')
 29 |         self.query_dir = osp.join(self.dataset_dir, 'query')
 30 |         self.gallery_dir = osp.join(self.dataset_dir, 'gallery')
 31 | 
 32 |         self._check_before_run()
 33 | 
 34 |         train = self._process_dir(self.train_dir, relabel=True)
 35 |         query = self._process_dir(self.query_dir, relabel=False)
 36 |         gallery = self._process_dir(self.gallery_dir, relabel=False)
 37 | 
 38 |         if verbose:
 39 |             print("=> VC-Clothes dataset loaded")
 40 |             self.print_dataset_statistics(train, query, gallery)
 41 | 
 42 |         self.train = train
 43 |         self.query = query
 44 |         self.gallery = gallery
 45 | 
 46 |         self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train)
 47 |         self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query)
 48 |         self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery)
 49 | 
 50 |     def _check_before_run(self):
 51 |         """Check if all files are available before going deeper"""
 52 |         if not osp.exists(self.dataset_dir):
 53 |             raise RuntimeError("'{}' is not available".format(self.dataset_dir))
 54 |         if not osp.exists(self.train_dir):
 55 |             raise RuntimeError("'{}' is not available".format(self.train_dir))
 56 |         if not osp.exists(self.query_dir):
 57 |             raise RuntimeError("'{}' is not available".format(self.query_dir))
 58 |         if not osp.exists(self.gallery_dir):
 59 |             raise RuntimeError("'{}' is not available".format(self.gallery_dir))
 60 | 
 61 |     def _process_dir(self, dir_path, relabel=False):
 62 |         img_paths = glob.glob(osp.join(dir_path, '*.jp*'))
 63 |         pattern = re.compile(r'([\d]+)-([\d]+)')
 64 | 
 65 |         pid_container = set()
 66 |         for img_path in img_paths:
67 | pid, _ = map(int, pattern.search(img_path).groups()) 68 | if pid == -1: continue # junk images are just ignored 69 | pid_container.add(pid) 70 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 71 | 72 | dataset = [] 73 | for img_path in img_paths: 74 | pid, camid = map(int, pattern.search(img_path).groups()) 75 | if pid == -1: continue # junk images are just ignored 76 | assert 0 <= pid <= 512 # pid == 0 means background 77 | assert 1 <= camid <= 4 78 | camid -= 1 # index starts from 0 79 | if relabel: pid = pid2label[pid] 80 | dataset.append((img_path, pid, camid)) 81 | 82 | return dataset 83 | 84 | -------------------------------------------------------------------------------- /torchreid/eval_cylib/Makefile: -------------------------------------------------------------------------------- 1 | all: 2 | python setup.py build_ext --inplace 3 | rm -rf build 4 | 5 | clean: 6 | rm -rf build 7 | rm -f eval_metrics_cy.c *.so -------------------------------------------------------------------------------- /torchreid/eval_cylib/eval_metrics_cy.pyx: -------------------------------------------------------------------------------- 1 | # cython: boundscheck=False, wraparound=False, nonecheck=False, cdivision=True 2 | 3 | from __future__ import print_function 4 | 5 | import cython 6 | import numpy as np 7 | cimport numpy as np 8 | from collections import defaultdict 9 | import random 10 | 11 | 12 | """ 13 | Compiler directives: 14 | https://github.com/cython/cython/wiki/enhancements-compilerdirectives 15 | 16 | Cython tutorial: 17 | https://cython.readthedocs.io/en/latest/src/userguide/numpy_tutorial.html 18 | """ 19 | 20 | # Main interface 21 | cpdef evaluate_cy(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_metric_cuhk03=False): 22 | distmat = np.asarray(distmat, dtype=np.float32) 23 | q_pids = np.asarray(q_pids, dtype=np.int64) 24 | g_pids = np.asarray(g_pids, dtype=np.int64) 25 | q_camids = np.asarray(q_camids, dtype=np.int64) 26 | g_camids = np.asarray(g_camids, dtype=np.int64) 27 | if use_metric_cuhk03: 28 | return eval_cuhk03_cy(distmat, q_pids, g_pids, q_camids, g_camids, max_rank) 29 | return eval_market1501_cy(distmat, q_pids, g_pids, q_camids, g_camids, max_rank) 30 | 31 | 32 | cpdef eval_cuhk03_cy(float[:,:] distmat, long[:] q_pids, long[:]g_pids, 33 | long[:]q_camids, long[:]g_camids, long max_rank): 34 | 35 | cdef long num_q = distmat.shape[0] 36 | cdef long num_g = distmat.shape[1] 37 | 38 | if num_g < max_rank: 39 | max_rank = num_g 40 | print("Note: number of gallery samples is quite small, got {}".format(num_g)) 41 | 42 | cdef: 43 | long num_repeats = 10 44 | long[:,:] indices = np.argsort(distmat, axis=1) 45 | long[:,:] matches = (np.asarray(g_pids)[np.asarray(indices)] == np.asarray(q_pids)[:, np.newaxis]).astype(np.int64) 46 | 47 | float[:,:] all_cmc = np.zeros((num_q, max_rank), dtype=np.float32) 48 | float[:] all_AP = np.zeros(num_q, dtype=np.float32) 49 | float num_valid_q = 0. 
# number of valid query 50 | 51 | long q_idx, q_pid, q_camid, g_idx 52 | long[:] order = np.zeros(num_g, dtype=np.int64) 53 | long keep 54 | 55 | float[:] raw_cmc = np.zeros(num_g, dtype=np.float32) # binary vector, positions with value 1 are correct matches 56 | float[:] masked_raw_cmc = np.zeros(num_g, dtype=np.float32) 57 | float[:] cmc, masked_cmc 58 | long num_g_real, num_g_real_masked, rank_idx, rnd_idx 59 | unsigned long meet_condition 60 | float AP 61 | long[:] kept_g_pids, mask 62 | 63 | float num_rel 64 | float[:] tmp_cmc = np.zeros(num_g, dtype=np.float32) 65 | float tmp_cmc_sum 66 | 67 | for q_idx in range(num_q): 68 | # get query pid and camid 69 | q_pid = q_pids[q_idx] 70 | q_camid = q_camids[q_idx] 71 | 72 | # remove gallery samples that have the same pid and camid with query 73 | for g_idx in range(num_g): 74 | order[g_idx] = indices[q_idx, g_idx] 75 | num_g_real = 0 76 | meet_condition = 0 77 | kept_g_pids = np.zeros(num_g, dtype=np.int64) 78 | 79 | for g_idx in range(num_g): 80 | if (g_pids[order[g_idx]] != q_pid) or (g_camids[order[g_idx]] != q_camid): 81 | raw_cmc[num_g_real] = matches[q_idx][g_idx] 82 | kept_g_pids[num_g_real] = g_pids[order[g_idx]] 83 | num_g_real += 1 84 | if matches[q_idx][g_idx] > 1e-31: 85 | meet_condition = 1 86 | 87 | if not meet_condition: 88 | # this condition is true when query identity does not appear in gallery 89 | continue 90 | 91 | # cuhk03-specific setting 92 | g_pids_dict = defaultdict(list) # overhead! 93 | for g_idx in range(num_g_real): 94 | g_pids_dict[kept_g_pids[g_idx]].append(g_idx) 95 | 96 | cmc = np.zeros(max_rank, dtype=np.float32) 97 | AP = 0. 98 | for _ in range(num_repeats): 99 | mask = np.zeros(num_g_real, dtype=np.int64) 100 | 101 | for _, idxs in g_pids_dict.items(): 102 | # randomly sample one image for each gallery person 103 | rnd_idx = np.random.choice(idxs) 104 | #rnd_idx = idxs[0] # use deterministic for debugging 105 | mask[rnd_idx] = 1 106 | 107 | num_g_real_masked = 0 108 | for g_idx in range(num_g_real): 109 | if mask[g_idx] == 1: 110 | masked_raw_cmc[num_g_real_masked] = raw_cmc[g_idx] 111 | num_g_real_masked += 1 112 | 113 | masked_cmc = np.zeros(num_g, dtype=np.float32) 114 | function_cumsum(masked_raw_cmc, masked_cmc, num_g_real_masked) 115 | for g_idx in range(num_g_real_masked): 116 | if masked_cmc[g_idx] > 1: 117 | masked_cmc[g_idx] = 1 118 | 119 | for rank_idx in range(max_rank): 120 | cmc[rank_idx] += masked_cmc[rank_idx] / num_repeats 121 | 122 | # compute AP 123 | function_cumsum(masked_raw_cmc, tmp_cmc, num_g_real_masked) 124 | num_rel = 0 125 | tmp_cmc_sum = 0 126 | for g_idx in range(num_g_real_masked): 127 | tmp_cmc_sum += (tmp_cmc[g_idx] / (g_idx + 1.)) * masked_raw_cmc[g_idx] 128 | num_rel += masked_raw_cmc[g_idx] 129 | AP += tmp_cmc_sum / num_rel 130 | 131 | all_AP[q_idx] = AP / num_repeats 132 | for rank_idx in range(max_rank): 133 | all_cmc[q_idx, rank_idx] = cmc[rank_idx] 134 | num_valid_q += 1. 
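    # Recap of the CUHK03 single-gallery-shot protocol implemented above: for
    # every query, one gallery image per identity is drawn at random, CMC and
    # AP are computed on that reduced gallery, and the result is averaged over
    # num_repeats (= 10) random draws.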
135 | 136 | assert num_valid_q > 0, "Error: all query identities do not appear in gallery" 137 | 138 | # compute averaged cmc 139 | cdef float[:] avg_cmc = np.zeros(max_rank, dtype=np.float32) 140 | for rank_idx in range(max_rank): 141 | for q_idx in range(num_q): 142 | avg_cmc[rank_idx] += all_cmc[q_idx, rank_idx] 143 | avg_cmc[rank_idx] /= num_valid_q 144 | 145 | cdef float mAP = 0 146 | for q_idx in range(num_q): 147 | mAP += all_AP[q_idx] 148 | mAP /= num_valid_q 149 | 150 | return np.asarray(avg_cmc).astype(np.float32), mAP 151 | 152 | 153 | cpdef eval_market1501_cy(float[:,:] distmat, long[:] q_pids, long[:]g_pids, 154 | long[:]q_camids, long[:]g_camids, long max_rank): 155 | 156 | cdef long num_q = distmat.shape[0] 157 | cdef long num_g = distmat.shape[1] 158 | 159 | if num_g < max_rank: 160 | max_rank = num_g 161 | print("Note: number of gallery samples is quite small, got {}".format(num_g)) 162 | 163 | cdef: 164 | long[:,:] indices = np.argsort(distmat, axis=1) 165 | long[:,:] matches = (np.asarray(g_pids)[np.asarray(indices)] == np.asarray(q_pids)[:, np.newaxis]).astype(np.int64) 166 | 167 | float[:,:] all_cmc = np.zeros((num_q, max_rank), dtype=np.float32) 168 | float[:] all_AP = np.zeros(num_q, dtype=np.float32) 169 | float num_valid_q = 0. # number of valid query 170 | 171 | long q_idx, q_pid, q_camid, g_idx 172 | long[:] order = np.zeros(num_g, dtype=np.int64) 173 | long keep 174 | 175 | float[:] raw_cmc = np.zeros(num_g, dtype=np.float32) # binary vector, positions with value 1 are correct matches 176 | float[:] cmc = np.zeros(num_g, dtype=np.float32) 177 | long num_g_real, rank_idx 178 | unsigned long meet_condition 179 | 180 | float num_rel 181 | float[:] tmp_cmc = np.zeros(num_g, dtype=np.float32) 182 | float tmp_cmc_sum 183 | 184 | for q_idx in range(num_q): 185 | # get query pid and camid 186 | q_pid = q_pids[q_idx] 187 | q_camid = q_camids[q_idx] 188 | 189 | # remove gallery samples that have the same pid and camid with query 190 | for g_idx in range(num_g): 191 | order[g_idx] = indices[q_idx, g_idx] 192 | num_g_real = 0 193 | meet_condition = 0 194 | 195 | for g_idx in range(num_g): 196 | if (g_pids[order[g_idx]] != q_pid) or (g_camids[order[g_idx]] != q_camid): 197 | raw_cmc[num_g_real] = matches[q_idx][g_idx] 198 | num_g_real += 1 199 | if matches[q_idx][g_idx] > 1e-31: 200 | meet_condition = 1 201 | 202 | if not meet_condition: 203 | # this condition is true when query identity does not appear in gallery 204 | continue 205 | 206 | # compute cmc 207 | function_cumsum(raw_cmc, cmc, num_g_real) 208 | for g_idx in range(num_g_real): 209 | if cmc[g_idx] > 1: 210 | cmc[g_idx] = 1 211 | 212 | for rank_idx in range(max_rank): 213 | all_cmc[q_idx, rank_idx] = cmc[rank_idx] 214 | num_valid_q += 1. 
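        # The block below computes average precision for this query:
        #     AP = (1 / num_rel) * sum_k P(k) * rel(k)
        # where P(k) = tmp_cmc[k] / (k + 1) is the precision at rank k + 1 and
        # rel(k) = raw_cmc[k] marks whether the k-th kept gallery sample is a
        # correct match.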
215 | 216 | # compute average precision 217 | # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision 218 | function_cumsum(raw_cmc, tmp_cmc, num_g_real) 219 | num_rel = 0 220 | tmp_cmc_sum = 0 221 | for g_idx in range(num_g_real): 222 | tmp_cmc_sum += (tmp_cmc[g_idx] / (g_idx + 1.)) * raw_cmc[g_idx] 223 | num_rel += raw_cmc[g_idx] 224 | all_AP[q_idx] = tmp_cmc_sum / num_rel 225 | 226 | assert num_valid_q > 0, "Error: all query identities do not appear in gallery" 227 | 228 | # compute averaged cmc 229 | cdef float[:] avg_cmc = np.zeros(max_rank, dtype=np.float32) 230 | for rank_idx in range(max_rank): 231 | for q_idx in range(num_q): 232 | avg_cmc[rank_idx] += all_cmc[q_idx, rank_idx] 233 | avg_cmc[rank_idx] /= num_valid_q 234 | 235 | cdef float mAP = 0 236 | for q_idx in range(num_q): 237 | mAP += all_AP[q_idx] 238 | mAP /= num_valid_q 239 | 240 | return np.asarray(avg_cmc).astype(np.float32), mAP 241 | 242 | 243 | # Compute the cumulative sum 244 | cdef void function_cumsum(cython.numeric[:] src, cython.numeric[:] dst, long n): 245 | cdef long i 246 | dst[0] = src[0] 247 | for i in range(1, n): 248 | dst[i] = src[i] + dst[i - 1] -------------------------------------------------------------------------------- /torchreid/eval_cylib/setup.py: -------------------------------------------------------------------------------- 1 | from distutils.core import setup 2 | from distutils.extension import Extension 3 | from Cython.Build import cythonize 4 | import numpy as np 5 | 6 | 7 | def numpy_include(): 8 | try: 9 | numpy_include = np.get_include() 10 | except AttributeError: 11 | numpy_include = np.get_numpy_include() 12 | return numpy_include 13 | 14 | ext_modules = [ 15 | Extension('eval_metrics_cy', 16 | ['eval_metrics_cy.pyx'], 17 | include_dirs=[numpy_include()], 18 | ) 19 | ] 20 | 21 | setup( 22 | name='Cython-based reid evaluation code', 23 | ext_modules=cythonize(ext_modules) 24 | ) -------------------------------------------------------------------------------- /torchreid/eval_cylib/test_cython.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import sys 4 | import os.path as osp 5 | import timeit 6 | import numpy as np 7 | 8 | sys.path.insert(0, osp.dirname(osp.abspath(__file__)) + '/../..') 9 | 10 | from torchreid.eval_metrics import evaluate 11 | 12 | """ 13 | Test the speed of cython-based evaluation code. The speed improvements 14 | can be much bigger when using the real reid data, which contains a larger 15 | amount of query and gallery images. 16 | 17 | Note: you might encounter the following error: 18 | 'AssertionError: Error: all query identities do not appear in gallery'. 19 | This is normal because the inputs are random numbers. Just try again. 
20 | """ 21 | 22 | print("*** Compare running time ***") 23 | 24 | setup = ''' 25 | import sys 26 | import os.path as osp 27 | import numpy as np 28 | sys.path.insert(0, osp.dirname(osp.abspath(__file__)) + '/../..') 29 | from torchreid.eval_metrics import evaluate 30 | num_q = 30 31 | num_g = 300 32 | max_rank = 5 33 | distmat = np.random.rand(num_q, num_g) * 20 34 | q_pids = np.random.randint(0, num_q, size=num_q) 35 | g_pids = np.random.randint(0, num_g, size=num_g) 36 | q_camids = np.random.randint(0, 5, size=num_q) 37 | g_camids = np.random.randint(0, 5, size=num_g) 38 | ''' 39 | 40 | print("=> Using market1501's metric") 41 | pytime = timeit.timeit('evaluate(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_cython=False)', setup=setup, number=20) 42 | cytime = timeit.timeit('evaluate(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_cython=True)', setup=setup, number=20) 43 | print("Python time: {} s".format(pytime)) 44 | print("Cython time: {} s".format(cytime)) 45 | print("Cython is {} times faster than python\n".format(pytime / cytime)) 46 | 47 | print("=> Using cuhk03's metric") 48 | pytime = timeit.timeit('evaluate(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_metric_cuhk03=True, use_cython=False)', setup=setup, number=20) 49 | cytime = timeit.timeit('evaluate(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_metric_cuhk03=True, use_cython=True)', setup=setup, number=20) 50 | print("Python time: {} s".format(pytime)) 51 | print("Cython time: {} s".format(cytime)) 52 | print("Cython is {} times faster than python\n".format(pytime / cytime)) 53 | 54 | """ 55 | print("=> Check precision") 56 | 57 | num_q = 30 58 | num_g = 300 59 | max_rank = 5 60 | distmat = np.random.rand(num_q, num_g) * 20 61 | q_pids = np.random.randint(0, num_q, size=num_q) 62 | g_pids = np.random.randint(0, num_g, size=num_g) 63 | q_camids = np.random.randint(0, 5, size=num_q) 64 | g_camids = np.random.randint(0, 5, size=num_g) 65 | 66 | cmc, mAP = evaluate(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_cython=False) 67 | print("Python:\nmAP = {} \ncmc = {}\n".format(mAP, cmc)) 68 | cmc, mAP = evaluate(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_cython=True) 69 | print("Cython:\nmAP = {} \ncmc = {}\n".format(mAP, cmc)) 70 | """ -------------------------------------------------------------------------------- /torchreid/eval_metrics.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | import numpy as np 6 | import copy 7 | from collections import defaultdict 8 | import sys 9 | import warnings 10 | 11 | from args import argument_parser 12 | 13 | # global variables 14 | parser = argument_parser() 15 | args = parser.parse_args() 16 | 17 | try: 18 | from torchreid.eval_cylib.eval_metrics_cy import evaluate_cy 19 | 20 | IS_CYTHON_AVAI = True 21 | print("Using Cython evaluation code as the backend") 22 | except ImportError: 23 | IS_CYTHON_AVAI = False 24 | warnings.warn("Cython evaluation is UNAVAILABLE, which is highly recommended") 25 | 26 | 27 | 28 | def eval_cuhk03(distmat, q_pids, g_pids, q_cloth_ids, g_cloth_ids, q_camids, g_camids, max_rank): 29 | """Evaluation with cuhk03 metric 30 | Key: one image for each gallery identity is randomly sampled for each query identity. 31 | Random sampling is performed num_repeats times. 
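    With num_repeats = 10, each repeat keeps exactly one randomly chosen
    gallery image per identity before computing CMC and AP, and the final
    scores are the mean over the 10 repeats.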
 32 |     """
 33 |     num_repeats = 10
 34 |     num_q, num_g = distmat.shape
 35 | 
 36 |     if num_g < max_rank:
 37 |         max_rank = num_g
 38 |         print("Note: number of gallery samples is quite small, got {}".format(num_g))
 39 | 
 40 |     indices = np.argsort(distmat, axis=1)
 41 |     matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)
 42 | 
 43 |     # compute cmc curve for each query
 44 |     all_cmc = []
 45 |     all_AP = []
 46 |     num_valid_q = 0.  # number of valid query
 47 | 
 48 |     for q_idx in range(num_q):
 49 |         # get query pid and camid
 50 |         q_pid = q_pids[q_idx]
 51 |         q_cloth_id = q_cloth_ids[q_idx]
 52 |         q_camid = q_camids[q_idx]
 53 | 
 54 |         #if args.use_standard_metric:
 55 |         # remove gallery samples that have the same pid and camid with query
 56 |         order = indices[q_idx]
 57 |         remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid)
 58 |         keep = np.invert(remove)
 59 |         # elif args.use_cloth_changing_metric:
 60 |         #     # remove gallery samples that have the same pid, cloth_id and camid with query
 61 |         #     order = indices[q_idx]
 62 |         #     remove = (g_pids[order] == q_pid) & ((g_camids[order] == q_camid) | (g_cloth_ids[order] == q_cloth_id))
 63 |         #     keep = np.invert(remove)
 64 | 
 65 |         if args.just_for_prcc_test:  # PRCC: keep the whole gallery, nothing is removed
 66 |             # overwrite `remove` with an always-false mask so no gallery sample is filtered out
 67 |             order = indices[q_idx]
 68 |             remove = (g_pids[order] == (q_pid * 100000))  # never true for real pids
 69 |             keep = np.invert(remove)
 70 | 
 71 |         # compute cmc curve
 72 |         raw_cmc = matches[q_idx][keep]  # binary vector, positions with value 1 are correct matches
 73 |         if not np.any(raw_cmc):
 74 |             # this condition is true when query identity does not appear in gallery
 75 |             continue
 76 | 
 77 |         kept_g_pids = g_pids[order][keep]
 78 |         g_pids_dict = defaultdict(list)
 79 |         for idx, pid in enumerate(kept_g_pids):
 80 |             g_pids_dict[pid].append(idx)
 81 | 
 82 |         cmc, AP = 0., 0.
 83 |         for repeat_idx in range(num_repeats):
 84 |             mask = np.zeros(len(raw_cmc), dtype=bool)  # np.bool was removed from NumPy; use the builtin bool
 85 |             for _, idxs in g_pids_dict.items():
 86 |                 # randomly sample one image for each gallery person
 87 |                 rnd_idx = np.random.choice(idxs)
 88 |                 mask[rnd_idx] = True
 89 |             masked_raw_cmc = raw_cmc[mask]
 90 |             _cmc = masked_raw_cmc.cumsum()
 91 |             _cmc[_cmc > 1] = 1
 92 |             cmc += _cmc[:max_rank].astype(np.float32)
 93 |             # compute AP
 94 |             num_rel = masked_raw_cmc.sum()
 95 |             tmp_cmc = masked_raw_cmc.cumsum()
 96 |             tmp_cmc = [x / (i + 1.) for i, x in enumerate(tmp_cmc)]
 97 |             tmp_cmc = np.asarray(tmp_cmc) * masked_raw_cmc
 98 |             AP += tmp_cmc.sum() / num_rel
 99 | 
100 |         cmc /= num_repeats
101 |         AP /= num_repeats
102 |         all_cmc.append(cmc)
103 |         all_AP.append(AP)
104 |         num_valid_q += 1.
105 | 
106 |     assert num_valid_q > 0, "Error: all query identities do not appear in gallery"
107 | 
108 |     all_cmc = np.asarray(all_cmc).astype(np.float32)
109 |     all_cmc = all_cmc.sum(0) / num_valid_q
110 |     mAP = np.mean(all_AP)
111 | 
112 |     return all_cmc, mAP
113 | 
114 | 
115 | def eval_market1501(distmat, q_pids, g_pids, q_cloth_ids, g_cloth_ids, q_camids, g_camids, max_rank):
116 |     """Evaluation with market1501 metric
117 |     Key: for each query identity, its gallery images from the same camera view are discarded.
118 |     """
119 |     num_q, num_g = distmat.shape
120 | 
121 |     if num_g < max_rank:
122 |         max_rank = num_g
123 |         print("Note: number of gallery samples is quite small, got {}".format(num_g))
124 | 
125 |     indices = np.argsort(distmat, axis=1)
126 |     matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)
127 | 
128 |     # compute cmc curve for each query
129 |     all_cmc = []
130 |     all_AP = []
131 |     num_valid_q = 0.
# number of valid query
132 | 
133 |     for q_idx in range(num_q):
134 |         # get query pid and camid
135 |         q_pid = q_pids[q_idx]
136 |         q_cloth_id = q_cloth_ids[q_idx]
137 |         q_camid = q_camids[q_idx]
138 | 
139 |         if args.use_standard_metric:
140 |             # remove gallery samples that have the same pid and camid with query
141 |             order = indices[q_idx]
142 |             remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid)
143 |             keep = np.invert(remove)
144 |         elif args.use_cloth_changing_metric:
145 |             # remove gallery samples that have the same pid, cloth_id and camid with query
146 |             order = indices[q_idx]
147 |             remove = (g_pids[order] == q_pid) & ((g_camids[order] == q_camid) | (g_cloth_ids[order] == q_cloth_id))
148 |             keep = np.invert(remove)
149 |         else:  # default: if neither metric flag is set, fall back to the standard protocol
150 |             # remove gallery samples that have the same pid and camid with query
151 |             order = indices[q_idx]
152 |             remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid)
153 |             keep = np.invert(remove)
154 | 
155 | 
156 |         # compute cmc curve
157 |         raw_cmc = matches[q_idx][keep]  # binary vector, positions with value 1 are correct matches
158 |         if not np.any(raw_cmc):
159 |             # this condition is true when query identity does not appear in gallery
160 |             continue
161 | 
162 |         cmc = raw_cmc.cumsum()
163 |         cmc[cmc > 1] = 1
164 | 
165 |         all_cmc.append(cmc[:max_rank])
166 |         num_valid_q += 1.
167 | 
168 |         # compute average precision
169 |         # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision
170 |         num_rel = raw_cmc.sum()
171 |         tmp_cmc = raw_cmc.cumsum()
172 |         tmp_cmc = [x / (i + 1.) for i, x in enumerate(tmp_cmc)]
173 |         tmp_cmc = np.asarray(tmp_cmc) * raw_cmc
174 |         AP = tmp_cmc.sum() / num_rel
175 |         all_AP.append(AP)
176 | 
177 |     assert num_valid_q > 0, "Error: all query identities do not appear in gallery"
178 | 
179 |     all_cmc = np.asarray(all_cmc).astype(np.float32)
180 |     all_cmc = all_cmc.sum(0) / num_valid_q
181 |     mAP = np.mean(all_AP)
182 | 
183 |     return all_cmc, mAP
184 | 
185 | 
186 | def evaluate_py(distmat, q_pids, g_pids, q_cloth_ids, g_cloth_ids, q_camids, g_camids, max_rank, use_metric_cuhk03):
187 |     if use_metric_cuhk03:
188 |         return eval_cuhk03(distmat, q_pids, g_pids, q_cloth_ids, g_cloth_ids, q_camids, g_camids, max_rank)
189 |     else:
190 |         return eval_market1501(distmat, q_pids, g_pids, q_cloth_ids, g_cloth_ids, q_camids, g_camids, max_rank)
191 | 
192 | 
193 | def evaluate(distmat, q_pids, g_pids, q_cloth_ids, g_cloth_ids, q_camids, g_camids, max_rank=50,
194 |              use_metric_cuhk03=False, use_cython=True):
195 |     if use_cython and IS_CYTHON_AVAI:
196 |         # NOTE: not updated yet; cloth_id is not passed here, so don't use this path for cloth-changing eval (Cython isn't compiled either)
197 |         return evaluate_cy(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_metric_cuhk03)
198 |     else:
199 |         return evaluate_py(distmat, q_pids, g_pids, q_cloth_ids, g_cloth_ids, q_camids, g_camids, max_rank,
200 |                            use_metric_cuhk03)
--------------------------------------------------------------------------------
/torchreid/losses/__init__.py:
--------------------------------------------------------------------------------
  1 | from __future__ import absolute_import
  2 | from __future__ import division
  3 | from __future__ import print_function
  4 | 
  5 | from .cross_entropy_loss import CrossEntropyLoss
  6 | from .hard_mine_triplet_loss import TripletLoss
  7 | from .center_loss import CenterLoss
  8 | 
  9 | 
 10 | def DeepSupervision(criterion, xs, y):
 11 |     """
 12 |     Args:
 13 |     - criterion: loss function
 14 |     - xs: tuple of inputs
 15 |     - y: ground truth
 16 |     """
 17 |     loss = 0.
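    # Averages the criterion over all outputs. For example, with two classifier
    # heads y1 and y2:
    #     DeepSupervision(criterion_xent, (y1, y2), pids)
    # equals (criterion_xent(y1, pids) + criterion_xent(y2, pids)) / 2.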
 18 |     for x in xs:
 19 |         loss += criterion(x, y)
 20 |     loss /= len(xs)
 21 |     return loss
--------------------------------------------------------------------------------
/torchreid/losses/center_loss.py:
--------------------------------------------------------------------------------
  1 | from __future__ import absolute_import
  2 | from __future__ import division
  3 | 
  4 | import warnings
  5 | 
  6 | import torch
  7 | import torch.nn as nn
  8 | 
  9 | 
 10 | class CenterLoss(nn.Module):
 11 |     """Center loss.
 12 | 
 13 |     Reference:
 14 |     Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016.
 15 | 
 16 |     Args:
 17 |     - num_classes (int): number of classes.
 18 |     - feat_dim (int): feature dimension.
 19 |     """
 20 |     def __init__(self, num_classes=10, feat_dim=2, use_gpu=True):
 21 |         super(CenterLoss, self).__init__()
 22 |         warnings.warn("This method is deprecated")
 23 |         self.num_classes = num_classes
 24 |         self.feat_dim = feat_dim
 25 |         self.use_gpu = use_gpu
 26 | 
 27 |         if self.use_gpu:
 28 |             self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim).cuda())
 29 |         else:
 30 |             self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim))
 31 | 
 32 |     def forward(self, x, labels):
 33 |         """
 34 |         Args:
 35 |         - x: feature matrix with shape (batch_size, feat_dim).
 36 |         - labels: ground truth labels with shape (batch_size).
 37 |         """
 38 |         batch_size = x.size(0)
 39 |         distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \
 40 |                   torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t()
 41 |         distmat.addmm_(x, self.centers.t(), beta=1, alpha=-2)  # keyword form; the positional (beta, alpha) signature is deprecated
 42 | 
 43 |         classes = torch.arange(self.num_classes).long()
 44 |         if self.use_gpu: classes = classes.cuda()
 45 |         labels = labels.unsqueeze(1).expand(batch_size, self.num_classes)
 46 |         mask = labels.eq(classes.expand(batch_size, self.num_classes))
 47 | 
 48 |         dist = []
 49 |         for i in range(batch_size):
 50 |             value = distmat[i][mask[i]]
 51 |             value = value.clamp(min=1e-12, max=1e+12)  # for numerical stability
 52 |             dist.append(value)
 53 |         dist = torch.cat(dist)
 54 |         loss = dist.mean()
 55 | 
 56 |         return loss
--------------------------------------------------------------------------------
/torchreid/losses/cross_entropy_loss.py:
--------------------------------------------------------------------------------
  1 | from __future__ import absolute_import
  2 | from __future__ import division
  3 | 
  4 | import torch
  5 | import torch.nn as nn
  6 | 
  7 | 
  8 | class CrossEntropyLoss(nn.Module):
  9 |     """Cross entropy loss with label smoothing regularizer.
 10 | 
 11 |     Reference:
 12 |     Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016.
 13 | 
 14 |     Equation: y = (1 - epsilon) * y + epsilon / K.
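    For example, with K = 4 classes and epsilon = 0.1, a one-hot target
    (1, 0, 0, 0) becomes (0.925, 0.025, 0.025, 0.025).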
 15 | 
 16 |     Args:
 17 |     - num_classes (int): number of classes
 18 |     - epsilon (float): weight
 19 |     - use_gpu (bool): whether to use gpu devices
 20 |     - label_smooth (bool): whether to apply label smoothing, if False, epsilon = 0
 21 |     """
 22 |     def __init__(self, num_classes, epsilon=0.1, use_gpu=True, label_smooth=True):
 23 |         super(CrossEntropyLoss, self).__init__()
 24 |         self.num_classes = num_classes
 25 |         self.epsilon = epsilon if label_smooth else 0
 26 |         self.use_gpu = use_gpu
 27 |         self.logsoftmax = nn.LogSoftmax(dim=1)
 28 | 
 29 |     def forward(self, inputs, targets):
 30 |         """
 31 |         Args:
 32 |         - inputs: prediction matrix (before softmax) with shape (batch_size, num_classes)
 33 |         - targets: ground truth labels with shape (batch_size)
 34 |         """
 35 |         log_probs = self.logsoftmax(inputs)
 36 |         targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).data.cpu(), 1)
 37 |         if self.use_gpu: targets = targets.cuda()
 38 |         targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
 39 |         loss = (- targets * log_probs).mean(0).sum()
 40 |         return loss
--------------------------------------------------------------------------------
/torchreid/losses/hard_mine_triplet_loss.py:
--------------------------------------------------------------------------------
  1 | from __future__ import absolute_import
  2 | from __future__ import division
  3 | 
  4 | import torch
  5 | import torch.nn as nn
  6 | 
  7 | 
  8 | class TripletLoss(nn.Module):
  9 |     """Triplet loss with hard positive/negative mining.
 10 | 
 11 |     Reference:
 12 |     Hermans et al. In Defense of the Triplet Loss for Person Re-Identification. arXiv:1703.07737.
 13 |     Code imported from https://github.com/Cysu/open-reid/blob/master/reid/loss/triplet.py.
 14 | 
 15 |     Args:
 16 |     - margin (float): margin for triplet.
 17 |     """
 18 |     def __init__(self, margin=0.3):
 19 |         super(TripletLoss, self).__init__()
 20 |         self.margin = margin
 21 |         self.ranking_loss = nn.MarginRankingLoss(margin=margin)
 22 | 
 23 |     def forward(self, inputs, targets):
 24 |         """
 25 |         Args:
 26 |         - inputs: feature matrix with shape (batch_size, feat_dim)
 27 |         - targets: ground truth labels with shape (batch_size)
 28 |         """
 29 |         n = inputs.size(0)
 30 | 
 31 |         # Compute pairwise distance, replace by the official when merged
 32 |         dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n)
 33 |         dist = dist + dist.t()
 34 |         dist.addmm_(inputs, inputs.t(), beta=1, alpha=-2)  # keyword form; the positional (beta, alpha) signature is deprecated
 35 |         dist = dist.clamp(min=1e-12).sqrt()  # for numerical stability
 36 | 
 37 |         # For each anchor, find the hardest positive and negative
 38 |         mask = targets.expand(n, n).eq(targets.expand(n, n).t())
 39 |         dist_ap, dist_an = [], []
 40 |         for i in range(n):
 41 |             dist_ap.append(dist[i][mask[i]].max().unsqueeze(0))
 42 |             dist_an.append(dist[i][mask[i] == 0].min().unsqueeze(0))
 43 |         dist_ap = torch.cat(dist_ap)
 44 |         dist_an = torch.cat(dist_an)
 45 | 
 46 |         # Compute ranking hinge loss
 47 |         y = torch.ones_like(dist_an)
 48 |         loss = self.ranking_loss(dist_an, dist_ap, y)
 49 |         return loss
--------------------------------------------------------------------------------
/torchreid/models/__init__.py:
--------------------------------------------------------------------------------
  1 | from __future__ import absolute_import
  2 | 
  3 | from .resnet import *
  4 | 
  5 | 
  6 | __model_factory = {
  7 |     # image classification models
  8 |     'resnet50': resnet50,
  9 |     'resnet50_fc512': resnet50_fc512
 10 | }
 11 | 
 12 | 
 13 | def get_names():
 14 |     return list(__model_factory.keys())
 15 | 
 16 | 
 17 | def init_model(name, *args, **kwargs):
 18 |     if name not in list(__model_factory.keys()):
 19 | 
raise KeyError("Unknown model: {}".format(name)) 20 | return __model_factory[name](*args, **kwargs) -------------------------------------------------------------------------------- /torchreid/models/resnet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | 4 | import torch 5 | from torch import nn 6 | from torch.nn import functional as F 7 | import torchvision 8 | import torch.utils.model_zoo as model_zoo 9 | 10 | 11 | __all__ = ['resnet50', 'resnet50_fc512'] 12 | 13 | 14 | model_urls = { 15 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 16 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 17 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 18 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 19 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 20 | } 21 | 22 | 23 | def conv3x3(in_planes, out_planes, stride=1): 24 | """3x3 convolution with padding""" 25 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 26 | padding=1, bias=False) 27 | 28 | 29 | class BasicBlock(nn.Module): 30 | expansion = 1 31 | 32 | def __init__(self, inplanes, planes, stride=1, downsample=None): 33 | super(BasicBlock, self).__init__() 34 | self.conv1 = conv3x3(inplanes, planes, stride) 35 | self.bn1 = nn.BatchNorm2d(planes) 36 | self.relu = nn.ReLU(inplace=True) 37 | self.conv2 = conv3x3(planes, planes) 38 | self.bn2 = nn.BatchNorm2d(planes) 39 | self.downsample = downsample 40 | self.stride = stride 41 | 42 | def forward(self, x): 43 | residual = x 44 | 45 | out = self.conv1(x) 46 | out = self.bn1(out) 47 | out = self.relu(out) 48 | 49 | out = self.conv2(out) 50 | out = self.bn2(out) 51 | 52 | if self.downsample is not None: 53 | residual = self.downsample(x) 54 | 55 | out += residual 56 | out = self.relu(out) 57 | 58 | return out 59 | 60 | 61 | class Bottleneck(nn.Module): 62 | expansion = 4 63 | 64 | def __init__(self, inplanes, planes, stride=1, downsample=None): 65 | super(Bottleneck, self).__init__() 66 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 67 | self.bn1 = nn.BatchNorm2d(planes) 68 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 69 | padding=1, bias=False) 70 | self.bn2 = nn.BatchNorm2d(planes) 71 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) 72 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 73 | self.relu = nn.ReLU(inplace=True) 74 | self.downsample = downsample 75 | self.stride = stride 76 | 77 | def forward(self, x): 78 | residual = x 79 | 80 | out = self.conv1(x) 81 | out = self.bn1(out) 82 | out = self.relu(out) 83 | 84 | out = self.conv2(out) 85 | out = self.bn2(out) 86 | out = self.relu(out) 87 | 88 | out = self.conv3(out) 89 | out = self.bn3(out) 90 | 91 | if self.downsample is not None: 92 | residual = self.downsample(x) 93 | 94 | out += residual 95 | out = self.relu(out) 96 | 97 | return out 98 | 99 | def weights_init_kaiming(m): 100 | classname = m.__class__.__name__ 101 | if classname.find('Linear') != -1: 102 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out') 103 | nn.init.constant_(m.bias, 0.0) 104 | elif classname.find('Conv') != -1: 105 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in') 106 | if m.bias is not None: 107 | nn.init.constant_(m.bias, 0.0) 108 | elif classname.find('BatchNorm') != -1: 109 | if m.affine: 110 | 
nn.init.constant_(m.weight, 1.0) 111 | nn.init.constant_(m.bias, 0.0) 112 | 113 | def weights_init_classifier(m): 114 | classname = m.__class__.__name__ 115 | if classname.find('Linear') != -1: 116 | nn.init.normal_(m.weight, std=0.001) 117 | if m.bias: 118 | nn.init.constant_(m.bias, 0.0) 119 | 120 | class ResNet(nn.Module): 121 | """ 122 | Residual network 123 | 124 | Reference: 125 | He et al. Deep Residual Learning for Image Recognition. CVPR 2016. 126 | """ 127 | def __init__(self, num_classes, loss, block, layers, 128 | last_stride=2, 129 | fc_dims=None, 130 | dropout_p=None, 131 | **kwargs): 132 | self.inplanes = 64 133 | super(ResNet, self).__init__() 134 | self.loss = loss 135 | self.feature_dim = 512 * block.expansion 136 | 137 | # backbone network 138 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 139 | self.bn1 = nn.BatchNorm2d(64) 140 | self.relu = nn.ReLU(inplace=True) 141 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 142 | self.layer1 = self._make_layer(block, 64, layers[0]) 143 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 144 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 145 | self.layer4 = self._make_layer(block, 512, layers[3], stride=last_stride) 146 | 147 | self.global_avgpool = nn.AdaptiveAvgPool2d(1) 148 | self.fc = self._construct_fc_layer(fc_dims, 512 * block.expansion, dropout_p) 149 | self.classifier = nn.Linear(self.feature_dim, num_classes) 150 | 151 | self._init_params() 152 | 153 | 154 | def _make_layer(self, block, planes, blocks, stride=1): 155 | downsample = None 156 | if stride != 1 or self.inplanes != planes * block.expansion: 157 | downsample = nn.Sequential( 158 | nn.Conv2d(self.inplanes, planes * block.expansion, 159 | kernel_size=1, stride=stride, bias=False), 160 | nn.BatchNorm2d(planes * block.expansion), 161 | ) 162 | 163 | layers = [] 164 | layers.append(block(self.inplanes, planes, stride, downsample)) 165 | self.inplanes = planes * block.expansion 166 | for i in range(1, blocks): 167 | layers.append(block(self.inplanes, planes)) 168 | 169 | return nn.Sequential(*layers) 170 | 171 | def _construct_fc_layer(self, fc_dims, input_dim, dropout_p=None): 172 | """ 173 | Construct fully connected layer 174 | 175 | - fc_dims (list or tuple): dimensions of fc layers, if None, 176 | no fc layers are constructed 177 | - input_dim (int): input dimension 178 | - dropout_p (float): dropout probability, if None, dropout is unused 179 | """ 180 | if fc_dims is None: 181 | self.feature_dim = input_dim 182 | return None 183 | 184 | assert isinstance(fc_dims, (list, tuple)), "fc_dims must be either list or tuple, but got {}".format(type(fc_dims)) 185 | 186 | layers = [] 187 | for dim in fc_dims: 188 | layers.append(nn.Linear(input_dim, dim)) 189 | layers.append(nn.BatchNorm1d(dim)) 190 | layers.append(nn.ReLU(inplace=True)) 191 | if dropout_p is not None: 192 | layers.append(nn.Dropout(p=dropout_p)) 193 | input_dim = dim 194 | 195 | self.feature_dim = fc_dims[-1] 196 | 197 | return nn.Sequential(*layers) 198 | 199 | def _init_params(self): 200 | for m in self.modules(): 201 | if isinstance(m, nn.Conv2d): 202 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 203 | if m.bias is not None: 204 | nn.init.constant_(m.bias, 0) 205 | elif isinstance(m, nn.BatchNorm2d): 206 | nn.init.constant_(m.weight, 1) 207 | nn.init.constant_(m.bias, 0) 208 | elif isinstance(m, nn.BatchNorm1d): 209 | nn.init.constant_(m.weight, 1) 210 | nn.init.constant_(m.bias, 0) 211 | 
elif isinstance(m, nn.Linear): 212 | nn.init.normal_(m.weight, 0, 0.01) 213 | if m.bias is not None: 214 | nn.init.constant_(m.bias, 0) 215 | 216 | def featuremaps(self, x): 217 | x = self.conv1(x) 218 | x = self.bn1(x) 219 | x = self.relu(x) 220 | x = self.maxpool(x) 221 | x = self.layer1(x) 222 | x = self.layer2(x) 223 | x = self.layer3(x) 224 | x = self.layer4(x) 225 | return x 226 | 227 | def forward(self, x): 228 | f = self.featuremaps(x) 229 | v = self.global_avgpool(f) 230 | v = v.view(v.size(0), -1) 231 | 232 | if self.fc is not None: 233 | v = self.fc(v) 234 | 235 | # v = self.BN(v) 236 | 237 | if not self.training: 238 | return v 239 | 240 | y = self.classifier(v) 241 | 242 | if self.loss == {'xent'}: 243 | return y 244 | elif self.loss == {'xent', 'htri'}: 245 | return y, v 246 | else: 247 | raise KeyError("Unsupported loss: {}".format(self.loss)) 248 | 249 | 250 | def init_pretrained_weights(model, model_url): 251 | """ 252 | Initialize model with pretrained weights. 253 | Layers that don't match with pretrained layers in name or size are kept unchanged. 254 | """ 255 | pretrain_dict = model_zoo.load_url(model_url) 256 | model_dict = model.state_dict() 257 | pretrain_dict = {k: v for k, v in pretrain_dict.items() if k in model_dict and model_dict[k].size() == v.size()} 258 | model_dict.update(pretrain_dict) 259 | model.load_state_dict(model_dict) 260 | print("Initialized model with pretrained weights from {}".format(model_url)) 261 | 262 | 263 | """ 264 | Residual network configurations: 265 | -- 266 | resnet18: block=BasicBlock, layers=[2, 2, 2, 2] 267 | resnet34: block=BasicBlock, layers=[3, 4, 6, 3] 268 | resnet50: block=Bottleneck, layers=[3, 4, 6, 3] 269 | resnet101: block=Bottleneck, layers=[3, 4, 23, 3] 270 | resnet152: block=Bottleneck, layers=[3, 8, 36, 3] 271 | """ 272 | 273 | 274 | def resnet50(num_classes, loss, pretrained='imagenet', **kwargs): 275 | model = ResNet( 276 | num_classes=num_classes, 277 | loss=loss, 278 | block=Bottleneck, 279 | layers=[3, 4, 6, 3], 280 | last_stride=2, 281 | fc_dims=None, 282 | dropout_p=None, 283 | **kwargs 284 | ) 285 | if pretrained == 'imagenet': 286 | init_pretrained_weights(model, model_urls['resnet50']) 287 | return model 288 | 289 | 290 | def resnet50_fc512(num_classes, loss, pretrained='imagenet', **kwargs): 291 | model = ResNet( 292 | num_classes=num_classes, 293 | loss=loss, 294 | block=Bottleneck, 295 | layers=[3, 4, 6, 3], 296 | last_stride=1, 297 | fc_dims=[512], 298 | dropout_p=None, 299 | **kwargs 300 | ) 301 | if pretrained == 'imagenet': 302 | init_pretrained_weights(model, model_urls['resnet50']) 303 | return model -------------------------------------------------------------------------------- /torchreid/optimizers.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | 5 | 6 | def init_optimizer(params, 7 | optim='adam', 8 | lr=0.003, 9 | weight_decay=5e-4, 10 | momentum=0.9, # momentum factor for sgd and rmsprop 11 | sgd_dampening=0, # sgd's dampening for momentum 12 | sgd_nesterov=False, # whether to enable sgd's Nesterov momentum 13 | rmsprop_alpha=0.99, # rmsprop's smoothing constant 14 | adam_beta1=0.9, # exponential decay rate for adam's first moment 15 | adam_beta2=0.999 # # exponential decay rate for adam's second moment 16 | ): 17 | if optim == 'adam': 18 | return torch.optim.Adam(params, lr=lr, weight_decay=weight_decay, 19 | betas=(adam_beta1, adam_beta2)) 20 | 21 | elif optim == 'amsgrad': 22 | return 
torch.optim.Adam(params, lr=lr, weight_decay=weight_decay, 23 | betas=(adam_beta1, adam_beta2), amsgrad=True) 24 | 25 | elif optim == 'sgd': 26 | return torch.optim.SGD(params, lr=lr, momentum=momentum, weight_decay=weight_decay, 27 | dampening=sgd_dampening, nesterov=sgd_nesterov) 28 | 29 | elif optim == 'rmsprop': 30 | return torch.optim.RMSprop(params, lr=lr, momentum=momentum, weight_decay=weight_decay, 31 | alpha=rmsprop_alpha) 32 | 33 | else: 34 | raise ValueError("Unsupported optimizer: {}".format(optim)) -------------------------------------------------------------------------------- /torchreid/samplers.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | 4 | from collections import defaultdict 5 | import numpy as np 6 | import copy 7 | import random 8 | 9 | import torch 10 | from torch.utils.data.sampler import Sampler 11 | 12 | 13 | class RandomIdentitySampler(Sampler): 14 | """ 15 | Randomly sample N identities, then for each identity, 16 | randomly sample K instances, therefore batch size is N*K. 17 | 18 | Args: 19 | - data_source (list): list of (img_path, pid, camid). 20 | - num_instances (int): number of instances per identity in a batch. 21 | - batch_size (int): number of examples in a batch. 22 | """ 23 | def __init__(self, data_source, batch_size, num_instances): 24 | self.data_source = data_source 25 | self.batch_size = batch_size 26 | self.num_instances = num_instances 27 | self.num_pids_per_batch = self.batch_size // self.num_instances 28 | self.index_dic = defaultdict(list) 29 | for index, (_, pid, _) in enumerate(self.data_source): 30 | self.index_dic[pid].append(index) 31 | self.pids = list(self.index_dic.keys()) 32 | 33 | # estimate number of examples in an epoch 34 | self.length = 0 35 | for pid in self.pids: 36 | idxs = self.index_dic[pid] 37 | num = len(idxs) 38 | if num < self.num_instances: 39 | num = self.num_instances 40 | self.length += num - num % self.num_instances 41 | 42 | def __iter__(self): 43 | batch_idxs_dict = defaultdict(list) 44 | 45 | for pid in self.pids: 46 | idxs = copy.deepcopy(self.index_dic[pid]) 47 | if len(idxs) < self.num_instances: 48 | idxs = np.random.choice(idxs, size=self.num_instances, replace=True) 49 | random.shuffle(idxs) 50 | batch_idxs = [] 51 | for idx in idxs: 52 | batch_idxs.append(idx) 53 | if len(batch_idxs) == self.num_instances: 54 | batch_idxs_dict[pid].append(batch_idxs) 55 | batch_idxs = [] 56 | 57 | avai_pids = copy.deepcopy(self.pids) 58 | final_idxs = [] 59 | 60 | while len(avai_pids) >= self.num_pids_per_batch: 61 | selected_pids = random.sample(avai_pids, self.num_pids_per_batch) 62 | for pid in selected_pids: 63 | batch_idxs = batch_idxs_dict[pid].pop(0) 64 | final_idxs.extend(batch_idxs) 65 | if len(batch_idxs_dict[pid]) == 0: 66 | avai_pids.remove(pid) 67 | 68 | return iter(final_idxs) 69 | 70 | def __len__(self): 71 | return self.length -------------------------------------------------------------------------------- /torchreid/transforms.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | 4 | from torchvision.transforms import * 5 | import torch 6 | 7 | from PIL import Image 8 | import random 9 | import numpy as np 10 | 11 | 12 | class Random2DTranslation(object): 13 | """ 14 | With a probability, first increase image size to (1 + 1/8), and then perform random crop. 
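    For example, with height=256 and width=128 the image is first resized to
    144x288 (width x height, both sides scaled by 1.125), then a random 128x256
    window is cropped from it; with probability 1 - p the image is simply
    resized to the target size instead.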
15 | 16 | Args: 17 | - height (int): target image height. 18 | - width (int): target image width. 19 | - p (float): probability of performing this transformation. Default: 0.5. 20 | """ 21 | def __init__(self, height, width, p=0.5, interpolation=Image.BILINEAR): 22 | self.height = height 23 | self.width = width 24 | self.p = p 25 | self.interpolation = interpolation 26 | 27 | def __call__(self, img): 28 | """ 29 | Args: 30 | - img (PIL Image): Image to be cropped. 31 | """ 32 | if random.uniform(0, 1) > self.p: 33 | return img.resize((self.width, self.height), self.interpolation) 34 | 35 | new_width, new_height = int(round(self.width * 1.125)), int(round(self.height * 1.125)) 36 | resized_img = img.resize((new_width, new_height), self.interpolation) 37 | x_maxrange = new_width - self.width 38 | y_maxrange = new_height - self.height 39 | x1 = int(round(random.uniform(0, x_maxrange))) 40 | y1 = int(round(random.uniform(0, y_maxrange))) 41 | croped_img = resized_img.crop((x1, y1, x1 + self.width, y1 + self.height)) 42 | return croped_img 43 | 44 | 45 | def build_transforms(height, width, is_train, **kwargs): 46 | """Build transforms 47 | 48 | Args: 49 | - height (int): target image height. 50 | - width (int): target image width. 51 | - is_train (bool): train or test phase. 52 | """ 53 | 54 | # use imagenet mean and std as default 55 | imagenet_mean = [0.485, 0.456, 0.406] 56 | imagenet_std = [0.229, 0.224, 0.225] 57 | normalize = Normalize(mean=imagenet_mean, std=imagenet_std) 58 | 59 | transforms = [] 60 | 61 | if is_train: 62 | transforms += [Random2DTranslation(height, width)] 63 | transforms += [RandomHorizontalFlip()] 64 | else: 65 | transforms += [Resize((height, width))] 66 | 67 | transforms += [ToTensor()] 68 | transforms += [normalize] 69 | 70 | transforms = Compose(transforms) 71 | 72 | return transforms -------------------------------------------------------------------------------- /torchreid/utils/avgmeter.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | 4 | 5 | class AverageMeter(object): 6 | """Computes and stores the average and current value. 
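    Example (illustrative values):
        >>> losses = AverageMeter()
        >>> losses.update(0.5, n=32)   # batch loss 0.5 averaged over 32 samples
        >>> losses.avg
        0.5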
7 | 8 | Code imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262 9 | """ 10 | def __init__(self): 11 | self.reset() 12 | 13 | def reset(self): 14 | self.val = 0 15 | self.avg = 0 16 | self.sum = 0 17 | self.count = 0 18 | 19 | def update(self, val, n=1): 20 | self.val = val 21 | self.sum += val * n 22 | self.count += n 23 | self.avg = self.sum / self.count -------------------------------------------------------------------------------- /torchreid/utils/iotools.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import os 4 | import os.path as osp 5 | import errno 6 | import json 7 | import shutil 8 | 9 | import torch 10 | 11 | 12 | def mkdir_if_missing(directory): 13 | if not osp.exists(directory): 14 | try: 15 | os.makedirs(directory) 16 | except OSError as e: 17 | if e.errno != errno.EEXIST: 18 | raise 19 | 20 | 21 | def check_isfile(path): 22 | isfile = osp.isfile(path) 23 | if not isfile: 24 | print("=> Warning: no file found at '{}' (ignored)".format(path)) 25 | return isfile 26 | 27 | 28 | def read_json(fpath): 29 | with open(fpath, 'r') as f: 30 | obj = json.load(f) 31 | return obj 32 | 33 | 34 | def write_json(obj, fpath): 35 | mkdir_if_missing(osp.dirname(fpath)) 36 | with open(fpath, 'w') as f: 37 | json.dump(obj, f, indent=4, separators=(',', ': ')) 38 | 39 | 40 | def save_checkpoint(state, is_best=False, fpath='checkpoint.pth.tar'): 41 | if len(osp.dirname(fpath)) != 0: 42 | mkdir_if_missing(osp.dirname(fpath)) 43 | torch.save(state, fpath) 44 | if is_best: 45 | shutil.copy(fpath, osp.join(osp.dirname(fpath), 'best_model.pth.tar')) -------------------------------------------------------------------------------- /torchreid/utils/loggers.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import sys 4 | import os 5 | import os.path as osp 6 | 7 | from .iotools import mkdir_if_missing 8 | 9 | 10 | class Logger(object): 11 | """ 12 | Write console output to external text file. 13 | Code imported from https://github.com/Cysu/open-reid/blob/master/reid/utils/logging.py. 14 | """ 15 | def __init__(self, fpath=None): 16 | self.console = sys.stdout 17 | self.file = None 18 | if fpath is not None: 19 | mkdir_if_missing(osp.dirname(fpath)) 20 | self.file = open(fpath, 'w') 21 | 22 | def __del__(self): 23 | self.close() 24 | 25 | def __enter__(self): 26 | pass 27 | 28 | def __exit__(self, *args): 29 | self.close() 30 | 31 | def write(self, msg): 32 | self.console.write(msg) 33 | if self.file is not None: 34 | self.file.write(msg) 35 | 36 | def flush(self): 37 | self.console.flush() 38 | if self.file is not None: 39 | self.file.flush() 40 | os.fsync(self.file.fileno()) 41 | 42 | def close(self): 43 | self.console.close() 44 | if self.file is not None: 45 | self.file.close() 46 | 47 | 48 | class RankLogger(object): 49 | """ 50 | RankLogger records the rank1 matching accuracy obtained for each 51 | test dataset at specified evaluation steps and provides a function 52 | to show the summarized results, which are convenient for analysis. 53 | 54 | Args: 55 | - source_names (list): list of strings (names) of source datasets. 56 | - target_names (list): list of strings (names) of target datasets. 
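
    Example (dataset names are illustrative):
        >>> ranklogger = RankLogger(['ltcc'], ['ltcc'])
        >>> ranklogger.write('ltcc', 10, 0.351)
        >>> ranklogger.show_summary()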
57 | """ 58 | def __init__(self, source_names, target_names): 59 | self.source_names = source_names 60 | self.target_names = target_names 61 | self.logger = {name: {'epoch': [], 'rank1': []} for name in self.target_names} 62 | 63 | def write(self, name, epoch, rank1): 64 | self.logger[name]['epoch'].append(epoch) 65 | self.logger[name]['rank1'].append(rank1) 66 | 67 | def show_summary(self): 68 | print("=> Show summary") 69 | for name in self.target_names: 70 | from_where = 'source' if name in self.source_names else 'target' 71 | print("{} ({})".format(name, from_where)) 72 | for epoch, rank1 in zip(self.logger[name]['epoch'], self.logger[name]['rank1']): 73 | print("- epoch {}\t rank1 {:.1%}".format(epoch, rank1)) -------------------------------------------------------------------------------- /torchreid/utils/reidtools.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | 4 | import numpy as np 5 | import os 6 | import os.path as osp 7 | import shutil 8 | 9 | from .iotools import mkdir_if_missing 10 | 11 | 12 | def visualize_ranked_results(distmat, dataset, save_dir='log/ranked_results', topk=20): 13 | """ 14 | Visualize ranked results 15 | 16 | Support both imgreid and vidreid 17 | 18 | Args: 19 | - distmat: distance matrix of shape (num_query, num_gallery). 20 | - dataset: a 2-tuple containing (query, gallery), each contains a list of (img_path, pid, camid); 21 | for imgreid, img_path is a string, while for vidreid, img_path is a tuple containing 22 | a sequence of strings. 23 | - save_dir: directory to save output images. 24 | - topk: int, denoting top-k images in the rank list to be visualized. 25 | """ 26 | num_q, num_g = distmat.shape 27 | 28 | print("Visualizing top-{} ranks".format(topk)) 29 | print("# query: {}\n# gallery {}".format(num_q, num_g)) 30 | print("Saving images to '{}'".format(save_dir)) 31 | 32 | query, gallery = dataset 33 | assert num_q == len(query) 34 | assert num_g == len(gallery) 35 | 36 | indices = np.argsort(distmat, axis=1) 37 | mkdir_if_missing(save_dir) 38 | 39 | def _cp_img_to(src, dst, rank, prefix): 40 | """ 41 | - src: image path or tuple (for vidreid) 42 | - dst: target directory 43 | - rank: int, denoting ranked position, starting from 1 44 | - prefix: string 45 | """ 46 | if isinstance(src, tuple) or isinstance(src, list): 47 | dst = osp.join(dst, prefix + '_top' + str(rank).zfill(3)) 48 | mkdir_if_missing(dst) 49 | for img_path in src: 50 | shutil.copy(img_path, dst) 51 | else: 52 | dst = osp.join(dst, prefix + '_top' + str(rank).zfill(3) + '_name_' + osp.basename(src)) 53 | shutil.copy(src, dst) 54 | 55 | for q_idx in range(num_q): 56 | qimg_path, qpid, qcamid = query[q_idx] 57 | qdir = osp.join(save_dir, osp.basename(qimg_path)) 58 | mkdir_if_missing(qdir) 59 | _cp_img_to(qimg_path, qdir, rank=0, prefix='query') 60 | 61 | rank_idx = 1 62 | for g_idx in indices[q_idx,:]: 63 | gimg_path, gpid, gcamid = gallery[g_idx] 64 | invalid = (qpid == gpid) & (qcamid == gcamid) 65 | if not invalid: 66 | _cp_img_to(gimg_path, qdir, rank=rank_idx, prefix='gallery') 67 | rank_idx += 1 68 | if rank_idx > topk: 69 | break 70 | 71 | print("Done") 72 | -------------------------------------------------------------------------------- /torchreid/utils/torchtools.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ 
--------------------------------------------------------------------------------
/torchreid/utils/reidtools.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import print_function

import numpy as np
import os
import os.path as osp
import shutil

from .iotools import mkdir_if_missing


def visualize_ranked_results(distmat, dataset, save_dir='log/ranked_results', topk=20):
    """
    Visualize ranked results.

    Supports both imgreid and vidreid.

    Args:
    - distmat: distance matrix of shape (num_query, num_gallery).
    - dataset: a 2-tuple containing (query, gallery), each of which is a list of
      (img_path, pid, camid); for imgreid, img_path is a string, while for vidreid,
      img_path is a tuple containing a sequence of strings.
    - save_dir: directory to save output images.
    - topk: int, denoting top-k images in the rank list to be visualized.
    """
    num_q, num_g = distmat.shape

    print("Visualizing top-{} ranks".format(topk))
    print("# query: {}\n# gallery: {}".format(num_q, num_g))
    print("Saving images to '{}'".format(save_dir))

    query, gallery = dataset
    assert num_q == len(query)
    assert num_g == len(gallery)

    indices = np.argsort(distmat, axis=1)
    mkdir_if_missing(save_dir)

    def _cp_img_to(src, dst, rank, prefix):
        """
        - src: image path or tuple (for vidreid)
        - dst: target directory
        - rank: int, denoting ranked position, starting from 1
          (rank 0 is reserved for the query image itself)
        - prefix: string
        """
        if isinstance(src, (tuple, list)):
            dst = osp.join(dst, prefix + '_top' + str(rank).zfill(3))
            mkdir_if_missing(dst)
            for img_path in src:
                shutil.copy(img_path, dst)
        else:
            dst = osp.join(dst, prefix + '_top' + str(rank).zfill(3) + '_name_' + osp.basename(src))
            shutil.copy(src, dst)

    for q_idx in range(num_q):
        qimg_path, qpid, qcamid = query[q_idx]
        qdir = osp.join(save_dir, osp.basename(qimg_path))
        mkdir_if_missing(qdir)
        _cp_img_to(qimg_path, qdir, rank=0, prefix='query')

        rank_idx = 1
        for g_idx in indices[q_idx, :]:
            gimg_path, gpid, gcamid = gallery[g_idx]
            # skip junk matches: same identity captured by the same camera
            invalid = (qpid == gpid) & (qcamid == gcamid)
            if not invalid:
                _cp_img_to(gimg_path, qdir, rank=rank_idx, prefix='gallery')
                rank_idx += 1
                if rank_idx > topk:
                    break

    print("Done")
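A self-contained toy run of `visualize_ranked_results` (the 8x8 images and the distances are fabricated for illustration); it also shows the junk filter in action: the third gallery entry shares both pid and camid with the query and is therefore skipped.

import numpy as np
import cv2
from torchreid.utils.iotools import mkdir_if_missing
from torchreid.utils.reidtools import visualize_ranked_results

mkdir_if_missing('demo_imgs')
paths = []
for i in range(4):
    p = 'demo_imgs/img{}.jpg'.format(i)
    cv2.imwrite(p, np.zeros((8, 8, 3), dtype=np.uint8))  # tiny black image
    paths.append(p)

# entries are (img_path, pid, camid)
query = [(paths[0], 0, 0)]
gallery = [(paths[1], 0, 1), (paths[2], 1, 0), (paths[3], 0, 0)]
distmat = np.array([[0.2, 0.9, 0.1]])  # made-up query-to-gallery distances

# gallery[2] has the same pid AND camid as the query, so it is filtered out
visualize_ranked_results(distmat, (query, gallery), save_dir='demo_rank', topk=3)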
--------------------------------------------------------------------------------
/torchreid/utils/torchtools.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import torch
import torch.nn as nn


def adjust_learning_rate(optimizer, base_lr, epoch, stepsize=20, gamma=0.1,
                         linear_decay=False, final_lr=0, max_epoch=100):
    if linear_decay:
        # linearly decay learning rate from base_lr to final_lr
        frac_done = epoch / max_epoch
        lr = frac_done * final_lr + (1. - frac_done) * base_lr
    else:
        # decay learning rate by gamma every stepsize epochs
        lr = base_lr * (gamma ** (epoch // stepsize))

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


def set_bn_to_eval(m):
    # 1. no update for running mean and var
    # 2. scale and shift parameters are still trainable
    classname = m.__class__.__name__
    if classname.find('BatchNorm') != -1:
        m.eval()


def open_all_layers(model):
    """
    Open all layers in model for training.

    Args:
    - model (nn.Module): neural net model.
    """
    model.train()
    for p in model.parameters():
        p.requires_grad = True


def open_specified_layers(model, open_layers):
    """
    Open specified layers in model for training while keeping
    other layers frozen.

    Args:
    - model (nn.Module): neural net model.
    - open_layers (list): list of layer names.
    """
    if isinstance(model, nn.DataParallel):
        model = model.module

    for layer in open_layers:
        assert hasattr(model, layer), "'{}' is not an attribute of the model, please provide the correct name".format(layer)

    for name, module in model.named_children():
        if name in open_layers:
            module.train()
            for p in module.parameters():
                p.requires_grad = True
        else:
            module.eval()
            for p in module.parameters():
                p.requires_grad = False


def count_num_param(model):
    # parameter count in millions
    num_param = sum(p.numel() for p in model.parameters()) / 1e+06

    if isinstance(model, nn.DataParallel):
        model = model.module

    if hasattr(model, 'classifier') and isinstance(model.classifier, nn.Module):
        # we ignore the classifier because it is unused at test time
        num_param -= sum(p.numel() for p in model.classifier.parameters()) / 1e+06
    return num_param
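A small sketch of the freezing helpers above on a toy two-layer model (`ToyNet` is invented for illustration; as `count_num_param` assumes, the repo's models expose a `classifier` attribute):

import torch.nn as nn
from torchreid.utils.torchtools import (open_specified_layers,
                                        open_all_layers, count_num_param)

class ToyNet(nn.Module):
    def __init__(self):
        super(ToyNet, self).__init__()
        self.backbone = nn.Linear(16, 8)
        self.classifier = nn.Linear(8, 4)

    def forward(self, x):
        return self.classifier(self.backbone(x))

model = ToyNet()
open_specified_layers(model, ['classifier'])  # e.g. warm-up: train classifier only
assert not next(model.backbone.parameters()).requires_grad
open_all_layers(model)                        # then fine-tune everything
print('{:.6f} M params (classifier excluded)'.format(count_num_param(model)))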
--------------------------------------------------------------------------------
/train_Baseline.py:
--------------------------------------------------------------------------------
from __future__ import print_function
from __future__ import division

import os
import sys
import time
import datetime
import os.path as osp
import numpy as np
import cv2

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.optim import lr_scheduler
import torch.optim as optim

from args import argument_parser, image_dataset_kwargs, optimizer_kwargs
from torchreid.data_manager import ImageDataManager
from torchreid import models
# only the baseline losses are imported here; FullTripletLoss and CFLLoss belong
# to the gait branch, which is not shipped in this release (see README)
from torchreid.losses import CrossEntropyLoss, TripletLoss
from torchreid.utils.iotools import save_checkpoint, check_isfile
from torchreid.utils.avgmeter import AverageMeter
from torchreid.utils.loggers import Logger, RankLogger
from torchreid.utils.torchtools import count_num_param, open_all_layers, open_specified_layers
from torchreid.utils.reidtools import visualize_ranked_results
from torchreid.eval_metrics import evaluate
from torchreid.samplers import RandomIdentitySampler
from torchreid.optimizers import init_optimizer

# global variables
parser = argument_parser()
args = parser.parse_args()


def main():
    global args

    torch.manual_seed(args.seed)
    if not args.use_avai_gpus:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_devices
    use_gpu = torch.cuda.is_available()
    if args.use_cpu:
        use_gpu = False
    log_name = 'log_test.txt' if args.evaluate else 'log_train.txt'
    sys.stdout = Logger(osp.join(args.save_dir, log_name))
    print("==========\nArgs:{}\n==========".format(args))

    if use_gpu:
        print("Currently using GPU {}".format(args.gpu_devices))
        cudnn.benchmark = True
        torch.cuda.manual_seed_all(args.seed)
    else:
        print("Currently using CPU, however, GPU is highly recommended")

    print("Initializing image data manager")
    dm = ImageDataManager(use_gpu, **image_dataset_kwargs(args))
    trainloader, testloader_dict = dm.return_dataloaders()

    # ReID-Stream:
    print("Initializing ReID-Stream: {}".format(args.arch))
    model = models.init_model(name=args.arch, num_classes=dm.num_train_pids, reid_dim=args.reid_dim, loss={'xent', 'htri'})
    print("ReID Model size: {:.3f} M".format(count_num_param(model)))

    criterion_xent = CrossEntropyLoss(num_classes=dm.num_train_pids, use_gpu=use_gpu, label_smooth=args.label_smooth)
    criterion_htri = TripletLoss(margin=args.margin)

    # 2. Optimizer
    # Main ReID-Stream:
    optimizer = init_optimizer(model.parameters(), **optimizer_kwargs(args))
    scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=args.stepsize, gamma=args.gamma)

    if use_gpu:
        model = nn.DataParallel(model).cuda()

    if args.evaluate:
        print("Evaluate only")

        for name in args.target_names:
            print("Evaluating {} ...".format(name))
            queryloader = testloader_dict[name]['query']
            galleryloader = testloader_dict[name]['gallery']
            distmat = test(model, queryloader, galleryloader, use_gpu, return_distmat=True)

            if args.visualize_ranks:
                visualize_ranked_results(
                    distmat, dm.return_testdataset_by_name(name),
                    save_dir=osp.join(args.save_dir, 'ranked_results', name),
                    topk=20
                )
        return

    start_time = time.time()
    ranklogger = RankLogger(args.source_names, args.target_names)
    train_time = 0
    print("==> Start training")

    for epoch in range(args.start_epoch, args.max_epoch):
        start_train_time = time.time()
        train(epoch, model, criterion_xent, criterion_htri, optimizer, trainloader, use_gpu)
        train_time += round(time.time() - start_train_time)

        scheduler.step()

        # evaluate every eval_freq epochs once past start_eval, and always at
        # the final epoch
        if ((epoch + 1) > args.start_eval and args.eval_freq > 0 and (epoch + 1) % args.eval_freq == 0) \
                or (epoch + 1) == args.max_epoch:
            print("==> Test")

            for name in args.target_names:
                print("Evaluating {} ...".format(name))
                queryloader = testloader_dict[name]['query']
                galleryloader = testloader_dict[name]['gallery']
                rank1 = test(model, queryloader, galleryloader, use_gpu)
                ranklogger.write(name, epoch + 1, rank1)

            if use_gpu:
                state_dict = model.module.state_dict()
            else:
                state_dict = model.state_dict()

            save_checkpoint({
                'state_dict': state_dict,
                'rank1': rank1,
                'epoch': epoch,
            }, False, osp.join(args.save_dir, 'checkpoint_ep' + str(epoch + 1) + '.pth.tar'))

    elapsed = round(time.time() - start_time)
    elapsed = str(datetime.timedelta(seconds=elapsed))
    train_time = str(datetime.timedelta(seconds=train_time))
    print("Finished. Total elapsed time (h:m:s): {}. Training time (h:m:s): {}.".format(elapsed, train_time))
    ranklogger.show_summary()


def train(epoch, model, criterion_xent, criterion_htri, optimizer, trainloader, use_gpu):
    losses = AverageMeter()
    batch_time = AverageMeter()
    data_time = AverageMeter()

    model.train()

    end = time.time()
    for batch_idx, (imgs, pids, cloth_ids, _, img_paths, masks) in enumerate(trainloader):
        data_time.update(time.time() - end)

        if use_gpu:
            imgs, pids, cloth_ids, masks = imgs.cuda(), pids.cuda(), cloth_ids.cuda(), masks.cuda()

        # gait images (64x64) contain the person only in the middle columns,
        # while the reid masks fill the whole crop, so the masks are zero-padded
        # left and right to match the square gait layout:
        padding_length = (args.mask_height - args.mask_width) // 2
        left_right_padding = nn.ZeroPad2d((padding_length, padding_length, 0, 0))
        masks = left_right_padding(masks)

        # Main ReID-Stream:
        features, outputs = model(imgs)

        # ReID losses (classification + triplet):
        xent_loss = criterion_xent(outputs, pids)
        htri_loss = criterion_htri(features, pids)

        loss_total = args.loss_ReID_cla_local * xent_loss + args.loss_ReID_tri_local * htri_loss

        optimizer.zero_grad()
        loss_total.backward()
        optimizer.step()

        batch_time.update(time.time() - end)

        losses.update(loss_total.item(), pids.size(0))

        if (batch_idx + 1) % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.4f} ({data_time.avg:.4f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
                      epoch + 1, batch_idx + 1, len(trainloader), batch_time=batch_time,
                      data_time=data_time, loss=losses))

        end = time.time()
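# Worked example for the mask padding in train() (args.mask_height=64 and
# args.mask_width=44 are assumed values matching GaitSet-style silhouettes;
# they are not defaults confirmed by this release):
#   padding_length = (64 - 44) // 2 = 10
#   nn.ZeroPad2d((10, 10, 0, 0)) adds 10 zero columns on the left and right,
#   so a (N, 1, 64, 44) mask batch becomes (N, 1, 64, 64).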
def test(model, queryloader, galleryloader, use_gpu, ranks=[1, 5, 10, 20], return_distmat=False):
    batch_time = AverageMeter()

    model.eval()

    with torch.no_grad():
        qf, q_pids, q_cloth_ids, q_camids = [], [], [], []
        for batch_idx, (imgs, pids, cloth_ids, camids, _) in enumerate(queryloader):
            if use_gpu:
                imgs = imgs.cuda()

            end = time.time()
            features = model(imgs)  # a second (gait) output would normally come here too, but it is unused in this baseline
            batch_time.update(time.time() - end)

            features = features.data.cpu()
            qf.append(features)
            q_pids.extend(pids)
            q_cloth_ids.extend(cloth_ids)
            q_camids.extend(camids)
        qf = torch.cat(qf, 0)
        q_pids = np.asarray(q_pids)
        q_cloth_ids = np.asarray(q_cloth_ids)
        q_camids = np.asarray(q_camids)

        print("Extracted features for query set, obtained {}-by-{} matrix".format(qf.size(0), qf.size(1)))

        gf, g_pids, g_cloth_ids, g_camids = [], [], [], []
        for batch_idx, (imgs, pids, cloth_ids, camids, _) in enumerate(galleryloader):
            if use_gpu:
                imgs = imgs.cuda()

            end = time.time()
            features = model(imgs)
            batch_time.update(time.time() - end)

            features = features.data.cpu()
            gf.append(features)
            g_pids.extend(pids)
            g_cloth_ids.extend(cloth_ids)
            g_camids.extend(camids)
        gf = torch.cat(gf, 0)
        g_pids = np.asarray(g_pids)
        g_cloth_ids = np.asarray(g_cloth_ids)
        g_camids = np.asarray(g_camids)

        print("Extracted features for gallery set, obtained {}-by-{} matrix".format(gf.size(0), gf.size(1)))

    print("==> BatchTime(s)/BatchSize(img): {:.3f}/{}".format(batch_time.avg, args.test_batch_size))

    m, n = qf.size(0), gf.size(0)
    # squared Euclidean distances: ||q||^2 + ||g||^2 - 2 * q . g
    distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \
              torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t()
    distmat.addmm_(qf, gf.t(), beta=1, alpha=-2)  # keyword form; positional (beta, alpha) is deprecated
    distmat = distmat.numpy()

    print("Computing CMC and mAP")
    cmc, mAP = evaluate(distmat, q_pids, g_pids, q_cloth_ids, g_cloth_ids, q_camids, g_camids, use_metric_cuhk03=args.use_metric_cuhk03)

    print("Results ----------")
    print("mAP: {:.1%}".format(mAP))
    print("CMC curve")
    for r in ranks:
        print("Rank-{:<3}: {:.1%}".format(r, cmc[r - 1]))
    print("------------------")

    if return_distmat:
        return distmat
    return cmc[0]


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
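For reference, the distance computation in test() is the standard squared-Euclidean expansion ||q - g||^2 = ||q||^2 + ||g||^2 - 2 q.g; a quick self-contained check against torch.cdist (shapes are arbitrary, for illustration only):

import torch

qf = torch.randn(5, 8)   # 5 query features
gf = torch.randn(7, 8)   # 7 gallery features
m, n = qf.size(0), gf.size(0)
distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \
          torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t()
distmat.addmm_(qf, gf.t(), beta=1, alpha=-2)

reference = torch.cdist(qf, gf) ** 2  # squared Euclidean distances
assert torch.allclose(distmat, reference, atol=1e-5)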