├── LICENSE
├── README.md
├── main.py
├── models
│   ├── __init__.py
│   ├── inflate.py
│   ├── integration_distribution_module.py
│   ├── resnet.py
│   ├── salient_to_broad_module.py
│   └── statistic_attention_block.py
├── transforms
│   ├── spatial_transforms.py
│   └── temporal_transforms.py
└── utils
    ├── data_manager.py
    ├── eval_metrics.py
    ├── losses.py
    ├── samplers.py
    ├── utils.py
    └── video_loader.py
/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 baist 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ### A codebase for video-based person re-identification 2 | 3 | Salient-to-Broad Transition for Video Person Re-identification (CVPR 2022) 4 | 5 | SANet: Statistic Attention Network for Video-Based Person Re-Identification (TCSVT 2021) 6 | 7 | ### Get started 8 | 9 | ```Shell 10 | # Train 11 | python main.py \ 12 | --arch ${sinet, sbnet, idnet, sanet} \ 13 | --dataset ${mars, lsvid, ...} \ 14 | --root ${path of dataset} \ 15 | --gpu_devices 0,1 \ 16 | --save_dir ${path for saving models and logs} \ 17 | 18 | # Test with all frames 19 | python main.py \ 20 | --arch ${sinet, sbnet, idnet, sanet} \ 21 | --dataset mars \ 22 | --root ${path of dataset} \ 23 | --gpu_devices 0,1 \ 24 | --save_dir ${path for saving logs} \ 25 | --evaluate --all_frames --resume ${path of pretrained model} 26 | ``` 27 | 28 | 29 | ### Pretrained models 30 | 31 | #### MARS 32 | | Methods | Paper (mAP/Rank-1) | Reproduce (mAP/Rank-1) | Download | 33 | |----- | -----| ----- | -----| 34 | | SBNet (ResNet50 + SBM) | 85.7/90.2 | 85.6/90.7 | [model](https://drive.google.com/file/d/1l0VeAzZ1-Z7Gbrp6jLuF_C6MIAiamrpQ/view?usp=sharing) | 35 | | IDNet (ResNet50 + IDM) | 85.9/90.5 | 85.9/90.4 | [model](https://drive.google.com/file/d/1XxJUxaUXDDB1cq6d6W5aGfwj54OvqW5N/view?usp=sharing) | 36 | | **SINet** (ResNet50 + SBM + IDM) | 86.2/91.0 | 86.3/90.9 | [model](https://drive.google.com/file/d/18YKaBdexzc49A-zhmT8vJY_xug_eLiYF/view?usp=sharing) | 37 | | **SANet** (ResNet50 + SA Block) | 86.0/91.2 | 86.7/91.2 | [model](https://drive.google.com/file/d/1yhX4trD02-ryJ7jRObstmHk9IZb9Smc3/view?usp=sharing) | 38 | 39 | 40 | #### LS-VID 41 | 42 | | Methods | Paper (mAP/Rank-1) | Reproduce (mAP/Rank-1) | Download | 43 | |----- | -----| ----- | -----| 44 | | SBNet (ResNet50 + SBM) | 77.1/85.1 | 77.2/85.3 | 
[model](https://drive.google.com/file/d/1bAxPRKoFoLluP3dVpzpsEJCY_kJhaL2v/view?usp=sharing) | 45 | | IDNet (Resnet50 + IDM) | 78.0/86.2 | 78.2/86.0 | [model](https://drive.google.com/file/d/1l-vH5huoodRjiNBbfWAZIjLdXekAi70X/view?usp=sharing) | 46 | | **SINet** (ResNet50 + SBM + IDM) | 79.6/87.4 | 79.9/87.2 | [model](https://drive.google.com/file/d/1Xdd_XUPyhbrrB06wDq_qUdUzMKdjD9FK/view?usp=sharing) | 47 | 48 | ### Citation 49 | 50 | If you use our code in your research or wish to refer to the baseline results, please use the following BibTeX entry. 51 | 52 | @inproceedings{bai2022SINet, 53 | title={Salient-to-Broad Transition for Video Person Re-identification}, 54 | author={Bai, Shutao and Ma, Bingpeng and Chang, Hong and Huang, Rui and Chen, Xilin}, 55 | booktitle={CVPR}, 56 | year={2022}, 57 | } 58 | 59 | @ARTICLE{9570321, 60 | author={Bai, Shutao and Ma, Bingpeng and Chang, Hong and Huang, Rui and Shan, Shiguang and Chen, Xilin}, 61 | journal={IEEE Transactions on Circuits and Systems for Video Technology}, 62 | title={SANet: Statistic Attention Network for Video-Based Person Re-Identification}, 63 | year={2021}, 64 | volume={}, 65 | number={}, 66 | pages={1-1}, 67 | doi={10.1109/TCSVT.2021.3119983} 68 | } 69 | 70 | ### Acknowledgments 71 | 72 | This code is based on the implementations of [**AP3D**](https://github.com/guxinqian/AP3D). 73 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os 3 | import sys 4 | import time 5 | import random 6 | import datetime 7 | import argparse 8 | import numpy as np 9 | import os.path as osp 10 | from tqdm import tqdm 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | import torch.backends.cudnn as cudnn 16 | from torch.optim import lr_scheduler 17 | from torch.utils.data import DataLoader 18 | 19 | import transforms.spatial_transforms as ST 20 | import transforms.temporal_transforms as TT 21 | from models import init_model 22 | from utils.losses import TripletLoss, InfoNce 23 | from utils.utils import AverageMeter, Logger, save_checkpoint, print_time 24 | from utils.eval_metrics import evaluate 25 | from utils.samplers import RandomIdentitySampler 26 | from utils import data_manager 27 | from utils.video_loader import VideoDataset, VideoDatasetInfer 28 | 29 | 30 | parser = argparse.ArgumentParser(description='Train video model') 31 | # Datasets 32 | parser.add_argument('--root', type=str, default='/home/guxinqian/data/') 33 | parser.add_argument('-d', '--dataset', type=str, default='lsvid', 34 | choices=data_manager.get_names()) 35 | parser.add_argument('-j', '--workers', default=4, type=int, 36 | help="number of data loading workers (default: 4)") 37 | parser.add_argument('--height', type=int, default=256, 38 | help="height of an image (default: 256)") 39 | parser.add_argument('--width', type=int, default=128, 40 | help="width of an image (default: 128)") 41 | # Augment 42 | parser.add_argument('--sample_stride', type=int, default=8, help="stride of images to sample in a tracklet") 43 | # Optimization options 44 | parser.add_argument('--max_epoch', default=160, type=int, 45 | help="maximum epochs to run") 46 | parser.add_argument('--start_epoch', default=0, type=int, 47 | help="manual epoch number (useful on restarts)") 48 | parser.add_argument('--train_batch', default=32, type=int, 49 | help="train batch size") 50 | 
parser.add_argument('--test_batch', default=32, type=int, help="has to be 1") 51 | parser.add_argument('--lr', '--learning-rate', default=0.0003, type=float, 52 | help="initial learning rate, use 0.0001 for rnn, use 0.0003 for pooling and attention") 53 | parser.add_argument('--stepsize', default=[40, 80, 120], nargs='+', type=int, 54 | help="stepsize to decay learning rate") 55 | parser.add_argument('--gamma', default=0.1, type=float, 56 | help="learning rate decay") 57 | parser.add_argument('--weight_decay', default=5e-04, type=float, 58 | help="weight decay (default: 5e-04)") 59 | parser.add_argument('--margin', type=float, default=0.3, help="margin for triplet loss") 60 | parser.add_argument('--distance', type=str, default='cosine', help="euclidean or consine") 61 | parser.add_argument('--num_instances', type=int, default=4, help="number of instances per identity") 62 | parser.add_argument('--losses', default=['xent', 'htri'], nargs='+', type=str, help="losses") 63 | # Architecture 64 | parser.add_argument('-a', '--arch', type=str, default='c2resnet50', help="c2resnet50, nonlocalresnet50") 65 | parser.add_argument('--pretrain', action='store_true', help="load params form pretrain model on kinetics") 66 | parser.add_argument('--pretrain_model_path', type=str, default='', metavar='PATH') 67 | # Miscs 68 | parser.add_argument('--seed', type=int, default=1, help="manual seed") 69 | parser.add_argument('--resume', type=str, default='', metavar='PATH') 70 | parser.add_argument('--evaluate', action='store_true', help="evaluation only") 71 | parser.add_argument('--eval_step', type=int, default=10, 72 | help="run evaluation for every N epochs (set to -1 to test after training)") 73 | parser.add_argument('--start_eval', type=int, default=0, help="start to evaluate after specific epoch") 74 | parser.add_argument('--save_dir', '--sd', type=str, default='') 75 | parser.add_argument('--use_cpu', action='store_true', help="use cpu") 76 | parser.add_argument('--gpu_devices', default='2,3', type=str, help='gpu device ids for CUDA_VISIBLE_DEVICES') 77 | 78 | parser.add_argument('--all_frames', action='store_true', help="evaluate with all frames ?") 79 | parser.add_argument('--seq_len', type=int, default=4, 80 | help="number of images to sample in a tracklet") 81 | parser.add_argument('--note', type=str, default='', help='additional description of this command') 82 | args = parser.parse_args() 83 | 84 | def specific_params(args): 85 | if args.arch in ['sinet', 'sbnet']: 86 | args.losses = ['xent', 'htri', 'infonce'] 87 | 88 | def main(): 89 | # fix the seed in random operation 90 | torch.backends.cudnn.benchmark = False 91 | torch.backends.cudnn.deterministic = True 92 | 93 | random.seed(args.seed) 94 | np.random.seed(args.seed) 95 | os.environ['PYTHONHASHSEED'] = str(args.seed) 96 | torch.manual_seed(args.seed) 97 | torch.cuda.manual_seed(args.seed) 98 | torch.cuda.manual_seed_all(args.seed) 99 | 100 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_devices 101 | use_gpu = torch.cuda.is_available() 102 | if args.use_cpu: use_gpu = False 103 | 104 | if not args.evaluate: 105 | sys.stdout = Logger(osp.join(args.save_dir, 'log_train.log')) 106 | elif args.all_frames: 107 | sys.stdout = Logger(osp.join(args.save_dir, 'log_eval_all_frames.log')) 108 | else: 109 | sys.stdout = Logger(osp.join(args.save_dir, 'log_eval_sampled_frames.log')) 110 | 111 | print_time("============ Initialized logger ============") 112 | print("\n".join("\t\t%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items()))) 113 | 
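    # Note (descriptive comment): `specific_params(args)`, called from __main__, has
    # already switched args.losses to ['xent', 'htri', 'infonce'] for the sinet/sbnet
    # architectures; this is what later makes the model return (y, f, x) instead of
    # (y, f) during training.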
114 | print_time("============ Description ============") 115 | print_time("\t\t %s\n" % args.note) 116 | 117 | print_time("The experiment will be stored in %s\n" % args.save_dir) 118 | 119 | if use_gpu: 120 | print("Currently using GPU {}".format(args.gpu_devices)) 121 | cudnn.benchmark = True 122 | torch.manual_seed(args.seed) 123 | torch.cuda.manual_seed(args.seed) 124 | torch.cuda.manual_seed_all(args.seed) 125 | torch.backends.cudnn.benchmark = False 126 | torch.backends.cudnn.deterministic = True 127 | else: 128 | print("Currently using CPU (GPU is highly recommended)") 129 | 130 | 131 | print_time("Initializing dataset {}".format(args.dataset)) 132 | dataset = data_manager.init_dataset(name=args.dataset, root=args.root) 133 | 134 | # Data augmentation 135 | spatial_transform_train = ST.Compose([ 136 | ST.Scale((args.height, args.width), interpolation=3), 137 | ST.RandomHorizontalFlip(), 138 | ST.ToTensor(), 139 | ST.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 140 | ST.RandomErasing()]) 141 | 142 | spatial_transform_test = ST.Compose([ 143 | ST.Scale((args.height, args.width), interpolation=3), 144 | ST.ToTensor(), 145 | ST.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 146 | 147 | temporal_transform_train = TT.TemporalRestrictedCrop(size=args.seq_len) 148 | temporal_transform_test = TT.TemporalRestrictedBeginCrop(size=args.seq_len) 149 | 150 | dataset_train = dataset.train 151 | dataset_query = dataset.query 152 | dataset_gallery = dataset.gallery 153 | 154 | pin_memory = True if use_gpu else False 155 | 156 | trainloader = DataLoader( 157 | VideoDataset( 158 | dataset_train, 159 | spatial_transform=spatial_transform_train, 160 | temporal_transform=temporal_transform_train), 161 | sampler=RandomIdentitySampler(dataset.train, num_instances=args.num_instances), 162 | batch_size=args.train_batch, num_workers=args.workers, 163 | pin_memory=pin_memory, drop_last=True,) 164 | 165 | queryloader_sampled_frames = DataLoader( 166 | VideoDataset(dataset_query, spatial_transform=spatial_transform_test, temporal_transform=temporal_transform_test), 167 | batch_size=args.test_batch, shuffle=False, num_workers=args.workers, 168 | pin_memory=pin_memory, drop_last=False) 169 | 170 | galleryloader_sampled_frames = DataLoader( 171 | VideoDataset(dataset_gallery, spatial_transform=spatial_transform_test, temporal_transform=temporal_transform_test), 172 | batch_size=args.test_batch, shuffle=False, num_workers=args.workers, 173 | pin_memory=pin_memory, drop_last=False) 174 | 175 | queryloader_all_frames = DataLoader( 176 | VideoDatasetInfer( 177 | dataset_query, spatial_transform=spatial_transform_test, seq_len=args.seq_len), 178 | batch_size=1, shuffle=False, num_workers=args.workers, 179 | pin_memory=pin_memory, drop_last=False) 180 | 181 | galleryloader_all_frames = DataLoader( 182 | VideoDatasetInfer(dataset_gallery, spatial_transform=spatial_transform_test, seq_len=args.seq_len), 183 | batch_size=1, shuffle=False, num_workers=args.workers, 184 | pin_memory=pin_memory, drop_last=False) 185 | 186 | print_time("Initializing model: {}".format(args.arch)) 187 | model = init_model( 188 | name=args.arch, 189 | num_classes = dataset.num_train_pids, 190 | losses=args.losses, 191 | seq_len=args.seq_len) 192 | 193 | print_time("Model Size w/o Classifier: {:.5f}M".format( 194 | sum(p.numel() for name, p in model.named_parameters() if 'classifier' not in name and 'projection' not in name)/1000000.0)) 195 | 196 | criterions = { 197 | 'xent': nn.CrossEntropyLoss(), 198 | 'htri': 
TripletLoss(margin=args.margin, distance=args.distance), 199 | 'infonce': InfoNce(num_instance=args.num_instances)} 200 | 201 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) 202 | scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=args.stepsize, gamma=args.gamma) 203 | start_epoch = args.start_epoch 204 | 205 | if args.pretrain: 206 | print("Loading pre-trained params from '{}'".format(args.pretrain_model_path)) 207 | pretrain_dict = torch.load(args.pretrain_model_path) 208 | model_dict = model.state_dict() 209 | model_dict.update(pretrain_dict) 210 | model.load_state_dict(model_dict) 211 | 212 | if args.resume: 213 | print_time("Loading checkpoint from '{}'".format(args.resume)) 214 | checkpoint = torch.load(args.resume) 215 | model.load_state_dict(checkpoint['state_dict']) 216 | start_epoch = checkpoint['epoch'] 217 | 218 | if use_gpu: 219 | model = nn.DataParallel(model).cuda() 220 | 221 | if args.evaluate: 222 | with torch.no_grad(): 223 | if args.all_frames: 224 | print_time('==> Evaluate with [all] frames!') 225 | test(model, queryloader_all_frames, galleryloader_all_frames, use_gpu) 226 | else: 227 | print_time('==> Evaluate with sampled [{}] frames per video!'.format(args.seq_len)) 228 | test(model, queryloader_sampled_frames, galleryloader_sampled_frames, use_gpu) 229 | return 230 | 231 | start_time = time.time() 232 | train_time = 0 233 | best_rank1 = -np.inf 234 | best_epoch = 0 235 | print_time("==> Start training") 236 | 237 | for epoch in range(start_epoch, args.max_epoch): 238 | start_train_time = time.time() 239 | train(epoch, model, criterions, optimizer, trainloader, use_gpu) 240 | train_time += round(time.time() - start_train_time) 241 | scheduler.step() 242 | 243 | if (epoch+1) >= args.start_eval and args.eval_step > 0 and (epoch+1) % args.eval_step == 0 or (epoch+1) == args.max_epoch: 244 | print_time("==> Test") 245 | with torch.no_grad(): 246 | rank1 = test(model, queryloader_sampled_frames, galleryloader_sampled_frames, use_gpu) 247 | 248 | is_best = rank1 > best_rank1 249 | if is_best: 250 | best_rank1 = rank1 251 | best_epoch = epoch + 1 252 | 253 | if use_gpu: state_dict = model.module.state_dict() 254 | else: state_dict = model.state_dict() 255 | save_checkpoint({ 256 | 'state_dict': state_dict, 257 | 'rank1': rank1, 258 | 'epoch': epoch, 259 | }, is_best, osp.join(args.save_dir, 'checkpoint_ep' + str(epoch+1) + '.pth.tar')) 260 | 261 | print_time("==> Best Rank-1 {:.1%}, achieved at epoch {}".format(best_rank1, best_epoch)) 262 | 263 | elapsed = round(time.time() - start_time) 264 | elapsed = str(datetime.timedelta(seconds=elapsed)) 265 | train_time = str(datetime.timedelta(seconds=train_time)) 266 | print_time('=='*50) 267 | print_time("Finished. Total elapsed time (h:m:s): {}. 
Training time (h:m:s): {}.".format(elapsed, train_time)) 268 | 269 | # using all frames to evaluate the final performance after training 270 | args.all_frames = True 271 | 272 | infer_epochs = [150] 273 | if best_epoch !=150: infer_epochs.append(best_epoch) 274 | 275 | for epoch in infer_epochs: 276 | best_checkpoint_path = osp.join(args.save_dir, 'checkpoint_ep' + str(epoch) + '.pth.tar') 277 | checkpoint = torch.load(best_checkpoint_path) 278 | model.module.load_state_dict(checkpoint['state_dict']) 279 | 280 | print_time('==> Evaluate with all frames!') 281 | print_time("Loading checkpoint from '{}'".format(best_checkpoint_path)) 282 | with torch.no_grad(): 283 | test(model, queryloader_all_frames, galleryloader_all_frames, use_gpu) 284 | return 285 | 286 | def train(epoch, model, criterions, optimizer, trainloader, use_gpu): 287 | batch_xent_loss = AverageMeter() 288 | batch_htri_loss = AverageMeter() 289 | batch_info_loss = AverageMeter() 290 | batch_loss = AverageMeter() 291 | batch_corrects = AverageMeter() 292 | batch_time = AverageMeter() 293 | data_time = AverageMeter() 294 | 295 | model.train() 296 | 297 | end = time.time() 298 | pd = tqdm(total=len(trainloader), ncols=120, leave=False) 299 | for batch_idx, (vids, pids, camid) in enumerate(trainloader): 300 | pd.set_postfix({'Acc': '{:>7.2%}'.format(batch_corrects.avg), }) 301 | pd.update(1) 302 | if (pids-pids[0]).sum() == 0: 303 | continue 304 | 305 | if use_gpu: 306 | vids = vids.cuda() 307 | pids = pids.cuda() 308 | 309 | # measure data loading time 310 | data_time.update(time.time() - end) 311 | 312 | # zero the parameter gradients 313 | optimizer.zero_grad() 314 | 315 | if 'infonce' in args.losses: 316 | y, f, x = model(vids) 317 | # combine hard triplet loss with cross entropy loss 318 | xent_loss = criterions['xent'](y, pids) 319 | htri_loss = criterions['htri'](f, pids) 320 | info_loss = criterions['infonce'](x) 321 | loss = xent_loss + htri_loss + 0.001 * info_loss 322 | else: 323 | y, f = model(vids) 324 | # combine hard triplet loss with cross entropy loss 325 | xent_loss = criterions['xent'](y, pids) 326 | htri_loss = criterions['htri'](f, pids) 327 | loss = xent_loss + htri_loss 328 | info_loss = htri_loss * 0 329 | 330 | # backward + optimize 331 | loss.backward() 332 | optimizer.step() 333 | 334 | # statistics 335 | _, preds = torch.max(y.data, 1) 336 | batch_corrects.update(torch.sum(preds == pids.data).float()/pids.size(0), pids.size(0)) 337 | batch_xent_loss.update(xent_loss.item(), pids.size(0)) 338 | batch_htri_loss.update(htri_loss.item(), pids.size(0)) 339 | batch_info_loss.update(info_loss.item(), pids.size(0)) 340 | batch_loss.update(loss.item(), pids.size(0)) 341 | 342 | # measure elapsed time 343 | batch_time.update(time.time() - end) 344 | end = time.time() 345 | 346 | pd.close() 347 | 348 | print_time('Epoch{:>4d} ' 349 | 'Time:{batch_time.sum:>5.1f}s ' 350 | 'Data:{data_time.sum:>4.1f}s | ' 351 | 'Loss:{loss.avg:>7.4f} ' 352 | 'xent:{ce.avg:>7.4f} ' 353 | 'htri:{ht.avg:>7.4f} | ' 354 | 'infonce:{info.avg:>7.4f} | ' 355 | 'Acc:{acc.avg:>7.2%} '.format( 356 | epoch+1, batch_time=batch_time, 357 | data_time=data_time, loss=batch_loss, 358 | ce=batch_xent_loss, ht=batch_htri_loss, 359 | info=batch_info_loss, acc=batch_corrects,)) 360 | 361 | def _cal_dist(qf, gf, distance='cosine'): 362 | """ 363 | :param logger: 364 | :param qf: (query_num, feat_dim) 365 | :param gf: (gallery_num, feat_dim) 366 | :param distance: 367 | cosine 368 | :return: 369 | distance matrix with shape, (query_num, gallery_num) 370 
| """ 371 | if distance == 'cosine': 372 | qf = F.normalize(qf, dim=1, p=2) 373 | gf = F.normalize(gf, dim=1, p=2) 374 | distmat = -torch.matmul(qf, gf.transpose(0, 1)) 375 | else: 376 | raise NotImplementedError 377 | return distmat 378 | 379 | def extract_feat_sampled_frames(model, vids, use_gpu=True): 380 | """ 381 | :param model: 382 | :param vids: (b, 3, t, 256, 128) 383 | :param use_gpu: 384 | :return: 385 | features: (b, c) 386 | """ 387 | if use_gpu: vids = vids.cuda() 388 | f = model(vids) # (b, t, c) 389 | f = f.mean(-1) 390 | f = f.data.cpu() 391 | return f 392 | 393 | def extract_feat_all_frames(model, vids, max_clip_per_batch=45, use_gpu=True): 394 | """ 395 | :param model: 396 | :param vids: (_, b, c, t, h, w) 397 | :param max_clip_per_batch: 398 | :param use_gpu: 399 | :return: 400 | f, (1, C) 401 | """ 402 | if use_gpu: 403 | vids = vids.cuda() 404 | _, b, c, t, h, w = vids.size() 405 | vids = vids.reshape(b, c, t, h, w) 406 | 407 | if max_clip_per_batch is not None and b > max_clip_per_batch: 408 | feat_set = [] 409 | for i in range((b - 1) // max_clip_per_batch + 1): 410 | clip = vids[i * max_clip_per_batch: (i + 1) * max_clip_per_batch] 411 | f = model(clip) # (max_clip_per_batch, t, c) 412 | f = f.mean(-1) 413 | feat_set.append(f) 414 | f = torch.cat(feat_set, dim=0) 415 | else: 416 | f = model(vids) # (b, t, c) 417 | f = f.mean(-1) # (b, c) 418 | 419 | f = f.mean(0, keepdim=True) 420 | f = f.data.cpu() 421 | return f 422 | 423 | def _feats_of_loader(model, loader, feat_func=extract_feat_sampled_frames, use_gpu=True): 424 | qf, q_pids, q_camids = [], [], [] 425 | 426 | pd = tqdm(total=len(loader), ncols=120, leave=False) 427 | for batch_idx, (vids, pids, camids) in enumerate(loader): 428 | pd.update(1) 429 | 430 | f = feat_func(model, vids, use_gpu=use_gpu) 431 | qf.append(f) 432 | q_pids.extend(pids.numpy()) 433 | q_camids.extend(camids.numpy()) 434 | pd.close() 435 | 436 | qf = torch.cat(qf, 0) 437 | q_pids = np.asarray(q_pids) 438 | q_camids = np.asarray(q_camids) 439 | 440 | return qf, q_pids, q_camids 441 | 442 | def _eval_format_logger(cmc, mAP, ranks, desc=''): 443 | print_time("Results {}".format(desc)) 444 | ptr = "mAP: {:.2%}".format(mAP) 445 | for r in ranks: 446 | ptr += " | R-{:<3}: {:.2%}".format(r, cmc[r - 1]) 447 | print_time(ptr) 448 | print_time("--------------------------------------") 449 | 450 | def test(model, queryloader, galleryloader, use_gpu, ranks=[1, 5, 10, 20]): 451 | since = time.time() 452 | model.eval() 453 | 454 | if args.all_frames: 455 | feat_func = extract_feat_all_frames 456 | else: 457 | feat_func = extract_feat_sampled_frames 458 | 459 | qf, q_pids, q_camids = _feats_of_loader( 460 | model, 461 | queryloader, 462 | feat_func, 463 | use_gpu=use_gpu) 464 | print_time("Extracted features for query set, obtained {} matrix".format(qf.shape)) 465 | 466 | gf, g_pids, g_camids = _feats_of_loader( 467 | model, 468 | galleryloader, 469 | feat_func, 470 | use_gpu=use_gpu) 471 | print_time("Extracted features for gallery set, obtained {} matrix".format(gf.shape)) 472 | 473 | if args.dataset == 'mars': 474 | # gallery set must contain query set, otherwise 140 query imgs will not have ground truth. 
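        # Standard MARS protocol: the query tracklets are appended to the gallery so
        # that every query has a ground-truth match; the evaluation code is expected
        # to discard gallery entries that share both pid and camid with the query,
        # so a query cannot simply match itself.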
475 | gf = torch.cat((qf, gf), 0) 476 | g_pids = np.append(q_pids, g_pids) 477 | g_camids = np.append(q_camids, g_camids) 478 | 479 | time_elapsed = time.time() - since 480 | print_time('Extracting features complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60)) 481 | 482 | print_time("Computing distance matrix") 483 | distmat = _cal_dist(qf=qf, gf=gf, distance=args.distance) 484 | print_time("Computing CMC and mAP") 485 | cmc, mAP = evaluate(distmat, q_pids, g_pids, q_camids, g_camids) 486 | _eval_format_logger(cmc, mAP, ranks, '') 487 | 488 | return cmc[0] 489 | 490 | 491 | if __name__ == '__main__': 492 | specific_params(args) 493 | main() 494 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .resnet import * 4 | __factory = { 5 | 'resnet50': ResNet50, 6 | 'sanet': SANet, 7 | 'sinet': SINet, 8 | 'idnet': IDNet, 9 | 'sbnet': SBNet, 10 | } 11 | 12 | def get_names(): 13 | return __factory.keys() 14 | 15 | def init_model(name, *args, **kwargs): 16 | if name not in __factory.keys(): 17 | raise KeyError("Unknown model: {}".format(name)) 18 | 19 | m = __factory[name](*args, **kwargs) 20 | 21 | return m 22 | -------------------------------------------------------------------------------- /models/inflate.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | def inflate_conv(conv2d, 7 | time_dim=3, 8 | time_padding=0, 9 | time_stride=1, 10 | time_dilation=1, 11 | center=False): 12 | # To preserve activations, padding should be by continuity and not zero 13 | # or no padding in time dimension 14 | kernel_dim = (time_dim, conv2d.kernel_size[0], conv2d.kernel_size[1]) 15 | padding = (time_padding, conv2d.padding[0], conv2d.padding[1]) 16 | stride = (time_stride, conv2d.stride[0], conv2d.stride[0]) 17 | dilation = (time_dilation, conv2d.dilation[0], conv2d.dilation[1]) 18 | conv3d = nn.Conv3d( 19 | conv2d.in_channels, 20 | conv2d.out_channels, 21 | kernel_dim, 22 | padding=padding, 23 | dilation=dilation, 24 | stride=stride, 25 | bias=True) 26 | 27 | # Repeat filter time_dim times along time dimension 28 | weight_2d = conv2d.weight.data 29 | if center: 30 | weight_3d = torch.zeros(*weight_2d.shape) 31 | weight_3d = weight_3d.unsqueeze(2).repeat(1, 1, time_dim, 1, 1) 32 | middle_idx = time_dim // 2 33 | weight_3d[:, :, middle_idx, :, :] = weight_2d 34 | else: 35 | weight_3d = weight_2d.unsqueeze(2).repeat(1, 1, time_dim, 1, 1) 36 | weight_3d = weight_3d / time_dim 37 | 38 | # Assign new params 39 | conv3d.weight = nn.Parameter(weight_3d) 40 | return conv3d 41 | 42 | 43 | def inflate_linear(linear2d, time_dim): 44 | """ 45 | Args: 46 | time_dim: final time dimension of the features 47 | """ 48 | linear3d = nn.Linear(linear2d.in_features * time_dim, 49 | linear2d.out_features) 50 | weight3d = linear2d.weight.data.repeat(1, time_dim) 51 | weight3d = weight3d / time_dim 52 | 53 | linear3d.weight = nn.Parameter(weight3d) 54 | linear3d.bias = linear2d.bias 55 | return linear3d 56 | 57 | 58 | def inflate_batch_norm(batch2d): 59 | batch3d = nn.BatchNorm3d(batch2d.num_features) 60 | batch2d._check_input_dim = batch3d._check_input_dim 61 | return batch2d 62 | 63 | 64 | def inflate_pool(pool2d, 65 | time_dim=1, 66 | time_padding=0, 67 | time_stride=None, 68 | time_dilation=1): 
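    # Inflate a 2D pooling layer to 3D: the spatial kernel/stride/padding are kept
    # from `pool2d`, a temporal dimension of size `time_dim` is prepended, and the
    # temporal stride defaults to `time_dim` (non-overlapping in time) when
    # `time_stride` is None.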
69 | kernel_dim = (time_dim, pool2d.kernel_size, pool2d.kernel_size) 70 | padding = (time_padding, pool2d.padding, pool2d.padding) 71 | if time_stride is None: 72 | time_stride = time_dim 73 | stride = (time_stride, pool2d.stride, pool2d.stride) 74 | if isinstance(pool2d, nn.MaxPool2d): 75 | dilation = (time_dilation, pool2d.dilation, pool2d.dilation) 76 | pool3d = nn.MaxPool3d( 77 | kernel_dim, 78 | padding=0, 79 | # padding=padding, 80 | dilation=dilation, 81 | stride=stride, 82 | ceil_mode=pool2d.ceil_mode) 83 | elif isinstance(pool2d, nn.AvgPool2d): 84 | pool3d = nn.AvgPool3d(kernel_dim, stride=stride) 85 | else: 86 | raise ValueError( 87 | '{} is not among known pooling classes'.format(type(pool2d))) 88 | return pool3d -------------------------------------------------------------------------------- /models/integration_distribution_module.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import math 4 | import torch.nn as nn 5 | from torch.nn import functional as F 6 | 7 | 8 | class IntegrationDistributionModule(nn.Module): 9 | def __init__(self, 10 | in_dim, 11 | factor=16, 12 | t=4): 13 | super(IntegrationDistributionModule, self).__init__() 14 | if in_dim == 256: 15 | in_thw = t * 64 * 32 16 | elif in_dim == 512: 17 | in_thw = t * 32 * 16 18 | else: 19 | in_thw = t * 16 * 8 20 | 21 | inter_dim = in_dim // factor 22 | inter_thw = in_thw // factor 23 | 24 | self.thw_reduction = nn.Sequential( 25 | nn.Conv2d(in_thw, inter_thw, kernel_size=(1, 1), stride=(1, 1)), 26 | nn.BatchNorm2d(inter_thw)) 27 | 28 | self.thw_expansion = nn.Sequential( 29 | nn.Conv2d(inter_thw, in_thw, kernel_size=(1, 1), stride=(1, 1))) 30 | 31 | self.chl_reduction = nn.Sequential( 32 | nn.Conv3d(in_dim, inter_dim, kernel_size=(1, 1, 1), stride=(1, 1, 1)), 33 | nn.BatchNorm3d(inter_dim)) 34 | 35 | self.chl_expansion = nn.Sequential( 36 | nn.Conv3d(inter_dim, in_dim, kernel_size=(1, 1, 1), stride=(1, 1, 1))) 37 | 38 | # init 39 | for m in self.modules(): 40 | if isinstance(m, nn.Conv3d) or isinstance(m, nn.Conv2d): 41 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 42 | m.weight.data.normal_(0, math.sqrt(2. / n)) 43 | elif isinstance(m, nn.BatchNorm3d) or isinstance(m, nn.BatchNorm1d) or isinstance(m, nn.BatchNorm2d): 44 | m.weight.data.fill_(1) 45 | m.bias.data.zero_() 46 | elif isinstance(m, nn.Conv1d): 47 | n = m.kernel_size[0] * m.out_channels 48 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 49 | 50 | self.zero_init(self.chl_expansion) 51 | 52 | def zero_init(self, W): 53 | nn.init.constant_(W[-1].weight.data, 0.0) 54 | nn.init.constant_(W[-1].bias.data, 0.0) 55 | 56 | def forward(self, x): 57 | ''' 58 | :param x: (b, c, t, h, w) 59 | :return: 60 | ''' 61 | x1 = self.chl_reduction(x) 62 | b, c, t, h, w = x1.size() 63 | 64 | x1 = x1.reshape(b, c, -1, 1).transpose(1, 2) # (b, t*h*w, c, 1) 65 | x2 = self.thw_reduction(x1) 66 | x3 = self.thw_expansion(x2) 67 | x3 = x3 + x1 68 | x3 = x3.transpose(1, 2).reshape(b, c, t, h, w) 69 | x4 = self.chl_expansion(x3) 70 | 71 | z = F.relu(x + x4) 72 | return z 73 | 74 | 75 | 76 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | 4 | from functools import partial 5 | import torchvision 6 | import torch.nn as nn 7 | from torch.nn import init 8 | from torch.nn import functional as F 9 | 10 | from models import inflate 11 | from models.statistic_attention_block import StatisticAttentionBlock 12 | from models.salient_to_broad_module import Salient2BroadModule 13 | from models.integration_distribution_module import IntegrationDistributionModule 14 | 15 | 16 | def weights_init_kaiming(m): 17 | classname = m.__class__.__name__ 18 | # print(classname) 19 | if classname.find('Conv') != -1: 20 | # init.kaiming_normal(m.weight.data, a=0, mode='fan_in') 21 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_out') 22 | init.constant_(m.bias.data, 0.0) 23 | elif classname.find('Linear') != -1: 24 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_out') 25 | init.constant_(m.bias.data, 0.0) 26 | elif classname.find('BatchNorm') != -1: 27 | init.normal_(m.weight.data, 1.0, 0.02) 28 | init.constant_(m.bias.data, 0.0) 29 | 30 | 31 | def weights_init_classifier(m): 32 | classname = m.__class__.__name__ 33 | if classname.find('Linear') != -1: 34 | init.normal_(m.weight.data, std=0.001) 35 | init.constant_(m.bias.data, 0.0) 36 | 37 | 38 | class Bottleneck3d(nn.Module): 39 | 40 | def __init__(self, bottleneck2d, inflate_time=False): 41 | super(Bottleneck3d, self).__init__() 42 | 43 | if inflate_time == True: 44 | self.conv1 = inflate.inflate_conv(bottleneck2d.conv1, time_dim=3, time_padding=1, center=True) 45 | else: 46 | self.conv1 = inflate.inflate_conv(bottleneck2d.conv1, time_dim=1) 47 | self.bn1 = inflate.inflate_batch_norm(bottleneck2d.bn1) 48 | self.conv2 = inflate.inflate_conv(bottleneck2d.conv2, time_dim=1) 49 | self.bn2 = inflate.inflate_batch_norm(bottleneck2d.bn2) 50 | self.conv3 = inflate.inflate_conv(bottleneck2d.conv3, time_dim=1) 51 | self.bn3 = inflate.inflate_batch_norm(bottleneck2d.bn3) 52 | self.relu = nn.ReLU(inplace=True) 53 | 54 | if bottleneck2d.downsample is not None: 55 | self.downsample = self._inflate_downsample(bottleneck2d.downsample) 56 | else: 57 | self.downsample = None 58 | 59 | def _inflate_downsample(self, downsample2d, time_stride=1): 60 | downsample3d = nn.Sequential( 61 | inflate.inflate_conv(downsample2d[0], time_dim=1, 62 | time_stride=time_stride), 63 | inflate.inflate_batch_norm(downsample2d[1])) 64 | return downsample3d 65 | 66 | def forward(self, x): 67 | residual = x 68 | out = self.conv1(x) 69 | out = self.bn1(out) 70 | out = self.relu(out) 71 | 72 | out = self.conv2(out) 73 | out = self.bn2(out) 74 | out = self.relu(out) 75 | 76 | out = self.conv3(out) 77 | out = self.bn3(out) 78 | 79 | if self.downsample is not None: 80 | 
residual = self.downsample(x) 81 | 82 | out += residual 83 | out = self.relu(out) 84 | 85 | return out 86 | 87 | 88 | class _ResNet50(nn.Module): 89 | 90 | def __init__(self, num_classes, losses, plugin_dict): 91 | super(_ResNet50, self).__init__() 92 | self.losses = losses 93 | 94 | resnet2d = torchvision.models.resnet50(pretrained=True) 95 | resnet2d.layer4[0].conv2.stride = (1, 1) 96 | resnet2d.layer4[0].downsample[0].stride = (1, 1) 97 | 98 | self.conv1 = inflate.inflate_conv(resnet2d.conv1, time_dim=1) 99 | self.bn1 = inflate.inflate_batch_norm(resnet2d.bn1) 100 | self.relu = nn.ReLU(inplace=True) 101 | self.maxpool = inflate.inflate_pool(resnet2d.maxpool, time_dim=1) 102 | 103 | self.layer1 = self._inflate_reslayer(resnet2d.layer1, plugin_dict[1]) 104 | self.layer2 = self._inflate_reslayer(resnet2d.layer2, plugin_dict[2]) 105 | self.layer3 = self._inflate_reslayer(resnet2d.layer3, plugin_dict[3]) 106 | self.layer4 = self._inflate_reslayer(resnet2d.layer4, plugin_dict[4]) 107 | 108 | self.bn = nn.BatchNorm1d(2048) 109 | self.bn.apply(weights_init_kaiming) 110 | 111 | self.classifier = nn.Linear(2048, num_classes) 112 | self.classifier.apply(weights_init_classifier) 113 | 114 | if 'infonce' in self.losses: 115 | self.projection = nn.Conv1d(2048, 2048, (1,), (1,)) 116 | self.projection.apply(weights_init_kaiming) 117 | 118 | def _inflate_reslayer(self, reslayer2d, plugin_dict): 119 | reslayers3d = [] 120 | for i, layer2d in enumerate(reslayer2d): 121 | layer3d = Bottleneck3d(layer2d) 122 | reslayers3d.append(layer3d) 123 | 124 | if i in plugin_dict: 125 | reslayers3d.append(plugin_dict[i](in_dim=layer2d.bn3.num_features)) 126 | 127 | return nn.Sequential(*reslayers3d) 128 | 129 | def forward(self, x): 130 | x = self.conv1(x) 131 | x = self.bn1(x) 132 | x = self.relu(x) 133 | x = self.maxpool(x) 134 | 135 | x = self.layer1(x) 136 | x = self.layer2(x) 137 | x = self.layer3(x) 138 | x = self.layer4(x) 139 | 140 | b, c, t, h, w = x.size() 141 | x = x.permute(0, 2, 1, 3, 4).contiguous() 142 | x = x.view(b*t, c, h, w) 143 | x = F.max_pool2d(x, x.size()[2:]) 144 | x = x.view(b, t, -1) 145 | x = x.transpose(1, 2) # (b, c, t) 146 | 147 | if not self.training: 148 | x = self.bn(x) 149 | return x 150 | 151 | v = x.mean(-1) 152 | f = self.bn(v) 153 | y = self.classifier(f) 154 | 155 | if 'infonce' in self.losses: 156 | x = self.bn(x) 157 | x = self.projection(x) 158 | return y, f, x 159 | 160 | return y, f 161 | 162 | 163 | def ResNet50(num_classes, losses): 164 | plugin_dict = { 165 | 1: {}, 166 | 2: {}, 167 | 3: {}, 168 | 4: {} 169 | } 170 | return _ResNet50(num_classes, losses, plugin_dict) 171 | 172 | 173 | def SANet(num_classes, losses, **kwargs): 174 | plugin_dict = { 175 | 1: {}, 176 | 2: {1: StatisticAttentionBlock, 177 | 3: StatisticAttentionBlock}, 178 | 3: {}, 179 | 4: {} 180 | } 181 | return _ResNet50(num_classes, losses, plugin_dict) 182 | 183 | 184 | def IDNet(num_classes, losses, seq_len, **kwargs): 185 | plugin_dict = { 186 | 1: {}, 187 | 2: {1: partial(IntegrationDistributionModule, t=seq_len), 188 | 3: partial(IntegrationDistributionModule, t=seq_len)}, 189 | 3: {}, 190 | 4: {} 191 | } 192 | return _ResNet50(num_classes, losses, plugin_dict) 193 | 194 | 195 | def SBNet(num_classes, losses, seq_len, **kwargs): 196 | plugin_dict = { 197 | 1: {}, 198 | 2: {}, 199 | 3: {1: partial(Salient2BroadModule, split_pos=0), 200 | 3: partial(Salient2BroadModule, split_pos=1), 201 | 5: partial(Salient2BroadModule, split_pos=2)}, 202 | 4: {} 203 | } 204 | return _ResNet50(num_classes, losses, 
plugin_dict) 205 | 206 | 207 | def SINet(num_classes, losses, seq_len, **kwargs): 208 | plugin_dict = { 209 | 1: {}, 210 | 2: {1: partial(IntegrationDistributionModule, t=seq_len), 211 | 3: partial(IntegrationDistributionModule, t=seq_len)}, 212 | 3: {1: partial(Salient2BroadModule, split_pos=0), 213 | 3: partial(Salient2BroadModule, split_pos=1), 214 | 5: partial(Salient2BroadModule, split_pos=2)}, 215 | 4: {} 216 | } 217 | return _ResNet50(num_classes, losses, plugin_dict) 218 | -------------------------------------------------------------------------------- /models/salient_to_broad_module.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | 4 | import math 5 | import logging 6 | import random 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | 13 | class Salient2BroadModule(nn.Module): 14 | def __init__(self, 15 | in_dim, 16 | inter_dim=None, 17 | split_pos=0, 18 | k=3, 19 | exp_beta=5.0, 20 | cpm_alpha=0.1): 21 | super().__init__() 22 | self.in_channels = in_dim 23 | self.inter_channels = inter_dim or in_dim // 4 24 | self.pos = split_pos 25 | self.k = k 26 | self.exp_beta = exp_beta 27 | self.cpm_alpha = cpm_alpha 28 | 29 | self.kernel = nn.Sequential( 30 | nn.Conv3d(self.in_channels, self.k * self.k, kernel_size=(1, 1, 1), stride=(1, 1, 1)), 31 | nn.BatchNorm3d(self.k * self.k), 32 | nn.ReLU()) 33 | 34 | self.se = nn.Sequential( 35 | nn.AdaptiveAvgPool3d(output_size=(1, 1, 1)), 36 | nn.Conv3d(self.in_channels, self.inter_channels, kernel_size=(1, 1, 1), stride=(1, 1, 1)), 37 | nn.ReLU(inplace=True), 38 | nn.Conv3d(self.inter_channels, self.in_channels, kernel_size=(1, 1, 1), stride=(1, 1, 1)), 39 | nn.Sigmoid()) 40 | 41 | for m in self.modules(): 42 | if isinstance(m, nn.Conv3d): 43 | n = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] * m.out_channels 44 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 45 | elif isinstance(m, nn.BatchNorm3d): 46 | m.weight.data.fill_(1) 47 | m.bias.data.zero_() 48 | 49 | def _suppress(self, a, exp_beta=5.0): 50 | """ 51 | :param a: (b, 1, t, h, w) 52 | :return: 53 | """ 54 | a_sup = (a < 1).float().detach() 55 | a_exp = torch.exp((a-1)*a_sup*exp_beta) 56 | a = a_exp * a_sup + (1 - a_sup) 57 | return a 58 | 59 | def _channel_center(self, x): 60 | """ 61 | :param x: (b, c, t, h, w) 62 | :return: (b, c, 1, 1, 1) 63 | """ 64 | center_w_pad = torch.mean(x, dim=(2,3,4), keepdim=True) 65 | center_wo_pad = torch.mean(x[:,:,:,1:-1, 1:-1], dim=(2,3,4), keepdim=True) 66 | center = center_wo_pad/(center_w_pad + 1e-8) 67 | return center 68 | 69 | def channel_attention_layer(self, x): 70 | se = self.se(x) # (b, c, 1, 1, 1) 71 | center = self._channel_center(x) # (b, c, 1, 1, 1) 72 | center = (center > 1).float().detach() 73 | return se * center 74 | 75 | def _forward(self, x, pos=None): 76 | ''' 77 | :param x: (b, c, t, h, w) 78 | :return: 79 | ''' 80 | pos = self.pos if pos is None else pos 81 | 82 | b, c, t, h, w = x.shape 83 | xf = x[:, :, :pos + 1] 84 | xl = x[:, :, pos + 1:] 85 | 86 | cal = self.channel_attention_layer(x) 87 | xf_se = F.relu(xf * cal) 88 | 89 | # k*k spatial attention 90 | spatial_att = self.kernel(xf_se) # (b, k*k, tf, h, w) 91 | # (b, tf*hw, k*k) 92 | spatial_att = spatial_att.reshape(b, self.k*self.k, -1).transpose(-2, -1) 93 | if self.k != 1: 94 | spatial_att = F.normalize(spatial_att, dim=-1, p=1) 95 | spatial_att = F.normalize(spatial_att, dim=1, p=1) 96 | 97 | # obtain k*k conv kernel 98 | xf_reshape = xf_se.reshape(b, c, -1) 99 | # (b, c, 1, k, k) 100 | kernel = torch.matmul(xf_reshape, spatial_att) 101 | kernel = kernel.reshape(b, c, 1, self.k, self.k) 102 | 103 | # perform convolution with calculated kernel 104 | xl_se = F.relu(xl * cal) # (1, b*c, tl, h, w) 105 | xl_reshape = xl_se.reshape(b*c, -1, h, w) 106 | 107 | pad = (self.k-1)//2 108 | xl_reshape = F.pad(xl_reshape, pad=[pad,pad,pad,pad], mode='replicate') 109 | xl_reshape = xl_reshape.unsqueeze(0) 110 | f = F.conv3d(xl_reshape, weight=kernel, bias=None, stride=1, groups=b) 111 | f = f / (self.k * self.k) 112 | 113 | # suppress operation 114 | f = f.reshape(b, -1, h*w) 115 | f = F.softmax(f, dim=-1) 116 | f = f.reshape(b, 1, -1, h, w).clamp_min(1e-4) 117 | 118 | f = 1.0 / (f * h * w) 119 | f = self._suppress(f, exp_beta=self.exp_beta) 120 | 121 | # cross propagation 122 | xl_res = xl * f + self.cpm_alpha * F.adaptive_avg_pool3d(xf, 1) 123 | xf_res = xf + self.cpm_alpha * F.adaptive_avg_pool3d((1-f)* xl, 1)/F.adaptive_avg_pool3d((1-f), 1) 124 | res = torch.cat([xf_res, xl_res], dim=2) 125 | 126 | return res 127 | 128 | def forward(self, x, pos=None): 129 | b, c, t, h, w = x.shape 130 | if t == 4: 131 | return self._forward(x, pos) 132 | else: 133 | assert t % 4 == 0 134 | x = x.reshape(b, c, 2, 4, h, w) 135 | x = x.transpose(1, 2).reshape(b*2, c, 4, h, w) 136 | x = self._forward(x, pos) 137 | x = x.reshape(b, 2, c, 4, h, w).transpose(1, 2) 138 | x = x.reshape(b, c, t, h, w) 139 | 140 | return x 141 | 142 | 143 | 144 | 145 | 146 | -------------------------------------------------------------------------------- /models/statistic_attention_block.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn import functional as F 6 | 7 | 8 | class StatisticAttentionBlock(nn.Module): 9 | """ 10 | SA block, statistic attention block 11 | include: 12 | 
1. down-channel: 13 | channel reduction for speed and channel shuffle、 14 | 2. get_moments: 15 | calculate moments 16 | 3. sta_distribute: 17 | distribute statistics based on the similarity 18 | 4. up-channel: 19 | channel recover + residual connection 20 | """ 21 | def __init__( 22 | self, 23 | in_dim, 24 | inter_dim=None, 25 | moments=None, 26 | moment_norm=True): 27 | super(StatisticAttentionBlock, self).__init__() 28 | self.in_dim = in_dim 29 | self.inter_dim = in_dim // 4 if inter_dim is None else inter_dim 30 | 31 | self.moments = [1, 2, 4, 5, 6] if moments is None else moments 32 | self.moment_norm = moment_norm 33 | 34 | self.down_channel = nn.Conv3d(self.in_dim, self.inter_dim, kernel_size=(1, 1, 1), stride=(1, 1, 1)) 35 | 36 | self.up_channel = nn.Sequential( 37 | nn.Conv3d(self.inter_dim, self.in_dim, kernel_size=(1, 1, 1), stride=(1, 1, 1)), 38 | nn.BatchNorm3d(self.in_dim)) 39 | 40 | # initialization 41 | for m in self.modules(): 42 | if isinstance(m, nn.Conv3d): 43 | nn.init.kaiming_normal_(m.weight) 44 | nn.init.constant_(m.bias, 0) 45 | elif isinstance(m, nn.BatchNorm3d): 46 | m.weight.data.fill_(1) 47 | m.bias.data.zero_() 48 | 49 | nn.init.constant_(self.up_channel[-1].weight.data, 0.0) 50 | nn.init.constant_(self.up_channel[-1].bias.data, 0.0) 51 | 52 | def forward(self, x): 53 | z = self.down_channel(x) 54 | y = get_moments(z, moments=self.moments, moment_norm=self.moment_norm) 55 | z = sta_distribute(z, y) 56 | z = self.up_channel(z) 57 | z = z + x 58 | z = F.relu(z) 59 | return z 60 | 61 | def get_moments(z, moments, moment_norm=True): 62 | """ 63 | :param z: (b, c, t, h, w) 64 | :param moments: e.g. [1, 2, 3, 4] 65 | :param moment_norm: True or False 66 | :return: 67 | (b, c, m), m=|moments| 68 | """ 69 | b, c, t, h, w = z.size() 70 | 71 | mean = F.adaptive_avg_pool3d(z, output_size=1) # (b, c, 1, 1, 1) 72 | mean = mean.reshape(b, c, 1) 73 | moments_set = [mean, ] 74 | 75 | z = z.reshape(b, c, t * h * w) # (b, c, t*h*w) 76 | 77 | if 2 in moments: 78 | variance = torch.mean((z - torch.mean(z, dim=-1, keepdim=True)) ** 2, dim=-1, keepdim=True) 79 | if moment_norm: variance = torch.sqrt(variance) 80 | moments_set.append(variance) 81 | 82 | for i in moments: 83 | if i <= 2: continue 84 | 85 | c_moment = torch.mean((z - torch.mean(z, dim=-1, keepdim=True)) ** i, dim=-1, keepdim=True) 86 | if moment_norm: 87 | c_moment = c_moment / (variance ** i) 88 | moments_set.append(c_moment) 89 | 90 | y = torch.cat(moments_set, dim=2) # (b, c, m) 91 | return y 92 | 93 | def sta_distribute(x, mv, norm=True): 94 | """ 95 | :param x: feature map, size: (b, c, t, h, w) 96 | :param mv: moment vectors, size: (b, c, m) 97 | :param norm: True or False 98 | :return: 99 | (b, c, t, h, w) 100 | """ 101 | b, c, t, h, w = x.size() 102 | 103 | if norm: 104 | mv = F.normalize(mv, dim=1) # (b, c, m) 105 | 106 | x = x.reshape(b, c, t * h * w).permute(0, 2, 1) # (b, t*h*w, c) 107 | mv = mv.reshape(b, c, -1) 108 | 109 | f = torch.matmul(x, mv) # (b, t*h*w, m) 110 | f = F.softmax(f, dim=-1) # (b, t*h*w, m) 111 | 112 | # (b, t*h*w, c) <- (b, t*h*w, m) * (b, m, c) 113 | x = torch.matmul(f, mv.permute(0, 2, 1)) 114 | x = x.permute(0, 2, 1).reshape(b, c, t, h, w) 115 | 116 | return x 117 | 118 | 119 | -------------------------------------------------------------------------------- /transforms/spatial_transforms.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import random 4 | import math 5 | import collections 6 | import 
numpy as np 7 | import torch 8 | 9 | from PIL import Image 10 | from torchvision import transforms 11 | 12 | try: 13 | import accimage 14 | except ImportError: 15 | accimage = None 16 | 17 | class Compose(object): 18 | """Composes several transforms together. 19 | Args: 20 | transforms (list of ``Transform`` objects): list of transforms to compose. 21 | Example: 22 | >>> transforms.Compose([ 23 | >>> transforms.CenterCrop(10), 24 | >>> transforms.ToTensor(), 25 | >>> ]) 26 | """ 27 | def __init__(self, transforms): 28 | self.transforms = transforms 29 | def __call__(self, img): 30 | for t in self.transforms: 31 | img = t(img) 32 | return img 33 | def randomize_parameters(self): 34 | for t in self.transforms: 35 | if hasattr(t, 'randomize_parameters'): 36 | t.randomize_parameters() 37 | 38 | 39 | class ToTensor(object): 40 | """Convert a ``PIL.Image`` or ``numpy.ndarray`` to tensor. 41 | Converts a PIL.Image or numpy.ndarray (H x W x C) in the range 42 | [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]. 43 | """ 44 | 45 | def __init__(self, norm_value=255): 46 | self.norm_value = norm_value 47 | 48 | def __call__(self, pic): 49 | """ 50 | Args: 51 | pic (PIL.Image or numpy.ndarray): Image to be converted to tensor. 52 | Returns: 53 | Tensor: Converted image. 54 | """ 55 | if isinstance(pic, np.ndarray): 56 | # handle numpy array 57 | img = torch.from_numpy(pic.transpose((2, 0, 1))) 58 | # backward compatibility 59 | return img.float().div(self.norm_value) 60 | 61 | if accimage is not None and isinstance(pic, accimage.Image): 62 | nppic = np.zeros( 63 | [pic.channels, pic.height, pic.width], dtype=np.float32) 64 | pic.copyto(nppic) 65 | return torch.from_numpy(nppic) 66 | 67 | # handle PIL Image 68 | if pic.mode == 'I': 69 | img = torch.from_numpy(np.array(pic, np.int32, copy=False)) 70 | elif pic.mode == 'I;16': 71 | img = torch.from_numpy(np.array(pic, np.int16, copy=False)) 72 | else: 73 | img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) 74 | # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK 75 | if pic.mode == 'YCbCr': 76 | nchannel = 3 77 | elif pic.mode == 'I;16': 78 | nchannel = 1 79 | else: 80 | nchannel = len(pic.mode) 81 | img = img.view(pic.size[1], pic.size[0], nchannel) 82 | # put it from HWC to CHW format 83 | # yikes, this transpose takes 80% of the loading time/CPU 84 | img = img.transpose(0, 1).transpose(0, 2).contiguous() 85 | if isinstance(img, torch.ByteTensor): 86 | return img.float().div(self.norm_value) 87 | else: 88 | return img 89 | 90 | def randomize_parameters(self): 91 | pass 92 | 93 | 94 | class Normalize(object): 95 | """Normalize an tensor image with mean and standard deviation. 96 | Given mean: (R, G, B) and std: (R, G, B), 97 | will normalize each channel of the torch.*Tensor, i.e. 98 | channel = (channel - mean) / std 99 | Args: 100 | mean (sequence): Sequence of means for R, G, B channels respecitvely. 101 | std (sequence): Sequence of standard deviations for R, G, B channels 102 | respecitvely. 103 | """ 104 | 105 | def __init__(self, mean, std): 106 | self.mean = mean 107 | self.std = std 108 | 109 | def __call__(self, tensor): 110 | """ Args: 111 | tensor (Tensor): Tensor image of size (C, H, W) to be normalized. 112 | Returns: 113 | Tensor: Normalized image. 
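            Example: main.py normalizes with the ImageNet statistics,
                Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]).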
114 | """ 115 | # TODO: make efficient 116 | for t, m, s in zip(tensor, self.mean, self.std): 117 | t.sub_(m).div_(s) 118 | return tensor 119 | 120 | def randomize_parameters(self): 121 | pass 122 | 123 | 124 | class Scale(object): 125 | """Rescale the input PIL.Image to the given size. 126 | Args: 127 | size (sequence or int): Desired output size. If size is a sequence like 128 | (w, h), output size will be matched to this. If size is an int, 129 | smaller edge of the image will be matched to this number. 130 | i.e, if height > width, then image will be rescaled to 131 | (size * height / width, size) 132 | interpolation (int, optional): Desired interpolation. Default is 133 | ``PIL.Image.BILINEAR`` 134 | """ 135 | 136 | def __init__(self, size, interpolation=Image.BILINEAR): 137 | assert isinstance(size, 138 | int) or (isinstance(size, collections.Iterable) and 139 | len(size) == 2) 140 | self.size = size 141 | self.interpolation = interpolation 142 | 143 | def __call__(self, img): 144 | """ 145 | Args: 146 | img (PIL.Image): Image to be scaled. 147 | Returns: 148 | PIL.Image: Rescaled image. 149 | """ 150 | if isinstance(self.size, int): 151 | w, h = img.size 152 | if (w <= h and w == self.size) or (h <= w and h == self.size): 153 | return img 154 | if w < h: 155 | ow = self.size 156 | oh = int(self.size * h / w) 157 | return img.resize((ow, oh), self.interpolation) 158 | else: 159 | oh = self.size 160 | ow = int(self.size * w / h) 161 | return img.resize((ow, oh), self.interpolation) 162 | else: 163 | return img.resize(self.size[::-1], self.interpolation) 164 | 165 | def randomize_parameters(self): 166 | pass 167 | 168 | 169 | class RandomHorizontalFlip(object): 170 | """Horizontally flip the given PIL.Image randomly with a probability of 0.5.""" 171 | 172 | def __call__(self, img): 173 | """ 174 | Args: 175 | img (PIL.Image): Image to be flipped. 176 | Returns: 177 | PIL.Image: Randomly flipped image. 178 | """ 179 | if self.p < 0.5: 180 | return img.transpose(Image.FLIP_LEFT_RIGHT) 181 | return img 182 | 183 | def randomize_parameters(self): 184 | self.p = random.random() 185 | 186 | 187 | class RandomErasing(object): 188 | """ Randomly selects a rectangle region in an image and erases its pixels. 189 | 'Random Erasing Data Augmentation' by Zhong et al. 190 | See https://arxiv.org/pdf/1708.04896.pdf 191 | Args: 192 | probability: The probability that the Random Erasing operation will be performed. 193 | sl: Minimum proportion of erased area against input image. 194 | sh: Maximum proportion of erased area against input image. 195 | r1: Minimum aspect ratio of erased area. 196 | mean: Erasing value. 197 | """ 198 | 199 | def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=[0.4914, 0.4822, 0.4465]): 200 | self.probability = probability 201 | self.mean = mean 202 | self.sl = sl 203 | self.sh = sh 204 | self.r1 = r1 205 | 206 | def __call__(self, img): 207 | 208 | if self.erasing > self.probability: 209 | return img 210 | 211 | for attempt in range(100): 212 | area = img.size()[1] * img.size()[2] 213 | 214 | target_area = random.uniform(self.sl, self.sh) * area 215 | aspect_ratio = random.uniform(self.r1, 1 / self.r1) 216 | 217 | h = int(round(math.sqrt(target_area * aspect_ratio))) 218 | w = int(round(math.sqrt(target_area / aspect_ratio))) 219 | 220 | if w < img.size()[2] and h < img.size()[1]: 221 | x1 = int(round(self.tl_x * (img.size()[1] - h))) 222 | y1 = int(round(self.tl_y * (img.size()[2] - w))) # Changed by baist! 
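                # For a (C, H, W) tensor, size()[1] is the height and size()[2] the
                # width, so x1/h index rows and y1/w index columns. tl_x and tl_y come
                # from randomize_parameters(), so all frames of a clip share the same
                # erased region when the transform is randomized once per clip.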
223 | 224 | if img.size()[0] == 3: 225 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0] 226 | img[1, x1:x1 + h, y1:y1 + w] = self.mean[1] 227 | img[2, x1:x1 + h, y1:y1 + w] = self.mean[2] 228 | else: 229 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0] 230 | return img 231 | 232 | return img 233 | 234 | def randomize_parameters(self): 235 | self.erasing = random.uniform(0, 1) < self.probability 236 | self.tl_x = random.random() 237 | self.tl_y = random.random() 238 | 239 | -------------------------------------------------------------------------------- /transforms/temporal_transforms.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import random 4 | import numpy as np 5 | 6 | 7 | class LoopPadding(object): 8 | 9 | def __init__(self, size): 10 | self.size = size 11 | 12 | def __call__(self, frame_indices): 13 | out = list(frame_indices) 14 | 15 | while len(out) < self.size: 16 | for index in out: 17 | if len(out) >= self.size: 18 | break 19 | out.append(index) 20 | 21 | return out 22 | 23 | 24 | class TemporalRandomCrop(object): 25 | """Temporally crop the given frame indices at a random location. 26 | 27 | If the number of frames is less than the size, 28 | loop the indices as many times as necessary to satisfy the size. 29 | 30 | Args: 31 | size (int): Desired output size of the crop. 32 | """ 33 | 34 | def __init__(self, seq_len=4, sample_stride=8, **kwargs): 35 | self.size = seq_len 36 | self.stride = sample_stride 37 | 38 | def __call__(self, frame_indices, stride=None): 39 | """ 40 | Args: 41 | frame_indices (list): frame indices to be cropped. 42 | Returns: 43 | list: Cropped frame indices. 44 | """ 45 | frame_indices = list(frame_indices) 46 | self.stride = stride if stride is not None else self.stride 47 | 48 | if len(frame_indices) >= self.size * self.stride: 49 | rand_end = len(frame_indices) - (self.size - 1) * self.stride - 1 50 | begin_index = random.randint(0, rand_end) 51 | end_index = begin_index + (self.size - 1) * self.stride + 1 52 | out = frame_indices[begin_index:end_index:self.stride] 53 | elif len(frame_indices) >= self.size: 54 | index = np.random.choice(len(frame_indices), size=self.size, replace=False) 55 | index.sort() 56 | out = [frame_indices[index[i]] for i in range(self.size)] 57 | else: 58 | index = np.random.choice(len(frame_indices), size=self.size, replace=True) 59 | index.sort() 60 | out = [frame_indices[index[i]] for i in range(self.size)] 61 | return out 62 | 63 | 64 | class TemporalBeginCrop(object): 65 | 66 | def __init__(self, size=4, sample_stride=8, **kwargs): 67 | self.size = size 68 | self.stride = sample_stride 69 | 70 | def __call__(self, frame_indices): 71 | frame_indices = list(frame_indices) 72 | size = self.size 73 | stride = self.stride 74 | if len(frame_indices) >= size * stride: 75 | out = frame_indices[0:(size-1)*stride + 1: stride] 76 | 77 | elif len(frame_indices) >= size: 78 | out = frame_indices[:size] 79 | else: 80 | index = np.random.choice(len(frame_indices), size=size, replace=True) 81 | index.sort() 82 | out = [frame_indices[index[i]] for i in range(size)] 83 | return out 84 | 85 | 86 | class TemporalRestrictedCrop(object): 87 | 88 | def __init__(self, size=4, **kwargs): 89 | self.size = size 90 | 91 | def __call__(self, frame_indices): 92 | """ 93 | Args: 94 | frame_indices (list): frame indices to be cropped. 95 | Returns: 96 | list: Cropped frame indices. 
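            The tracklet is split into `size` equal-length blocks and one frame is
            drawn at random from each block (the last block absorbs any remainder),
            so the sampled frames always span the whole tracklet in temporal order.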
97 | """ 98 | frame_indices = list(frame_indices) 99 | 100 | while len(frame_indices) < self.size: 101 | frame_indices.append(frame_indices[-1]) 102 | 103 | out = [] 104 | block_size = len(frame_indices)//self.size 105 | for i in range(self.size - 1): 106 | index = i*block_size + random.randint(0, block_size-1) 107 | out.append(frame_indices[index]) 108 | 109 | index = (self.size-1)*block_size + random.randint(0, len(frame_indices)-(self.size-1)*block_size-1) 110 | out.append(frame_indices[index]) 111 | 112 | return out 113 | 114 | 115 | class TemporalRestrictedBeginCrop(object): 116 | 117 | def __init__(self, size=4): 118 | self.size = size 119 | 120 | def __call__(self, frame_indices): 121 | """ 122 | Args: 123 | frame_indices (list): frame indices to be cropped. 124 | Returns: 125 | list: Cropped frame indices. 126 | """ 127 | frame_indices = list(frame_indices) 128 | 129 | while len(frame_indices) < self.size: 130 | frame_indices.append(frame_indices[-1]) 131 | 132 | out = [] 133 | block_size = len(frame_indices)//self.size 134 | for i in range(self.size): 135 | index = i*block_size 136 | out.append(frame_indices[index]) 137 | out.sort() 138 | 139 | return out 140 | 141 | 142 | tem_factory = { 143 | 'random': TemporalRandomCrop, 144 | 'begin': TemporalBeginCrop, 145 | 'restricted': TemporalRestrictedCrop, 146 | 'restrictedbegin': TemporalRestrictedBeginCrop, 147 | } 148 | 149 | 150 | -------------------------------------------------------------------------------- /utils/data_manager.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os 3 | import glob 4 | import urllib 5 | import tarfile 6 | import os.path as osp 7 | from scipy.io import loadmat 8 | import numpy as np 9 | import logging 10 | import h5py 11 | import math 12 | 13 | from utils.utils import mkdir_if_missing, write_json, read_json 14 | 15 | """Dataset classes""" 16 | 17 | class Mars(object): 18 | """ 19 | MARS 20 | 21 | Reference: 22 | Zheng et al. MARS: A Video Benchmark for Large-Scale Person Re-identification. ECCV 2016. 23 | 24 | Dataset statistics: 25 | # identities: 1261 26 | # tracklets: 8298 (train) + 1980 (query) + 11310 (gallery) 27 | # cameras: 6 28 | 29 | Note: 30 | # gallery set must contain query set, otherwise 140 query imgs will not have ground truth. 31 | # gallery imgs with label=-1 can be remove, which do not influence on final performance. 32 | 33 | Args: 34 | min_seq_len (int): tracklet with length shorter than this value will be discarded (default: 0). 
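        split_id (int): stored as `sampling_type`; 0 (default) keeps the full training set,
            while positive values (125 and 1250 are special-cased below) build a reduced
            training split with a limited number of identities and/or tracklets per identity.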
35 | """ 36 | 37 | def __init__(self, root=None, min_seq_len=0, split_id=0, *args, **kwargs): 38 | self._root = root 39 | self.train_name_path = osp.join(self._root, 'info/train_name.txt') 40 | self.test_name_path = osp.join(self._root, 'info/test_name.txt') 41 | self.track_train_info_path = osp.join(self._root, 'info/tracks_train_info.mat') 42 | self.track_test_info_path = osp.join(self._root, 'info/tracks_test_info.mat') 43 | self.query_IDX_path = osp.join(self._root, 'info/query_IDX.mat') 44 | 45 | self.sampling_type = split_id 46 | self._check_before_run() 47 | 48 | # prepare meta data 49 | train_names = self._get_names(self.train_name_path) 50 | test_names = self._get_names(self.test_name_path) 51 | track_train = loadmat(self.track_train_info_path)['track_train_info'] # numpy.ndarray (8298, 4) 52 | track_test = loadmat(self.track_test_info_path)['track_test_info'] # numpy.ndarray (12180, 4) 53 | query_IDX = loadmat(self.query_IDX_path)['query_IDX'].squeeze() # numpy.ndarray (1980,) 54 | query_IDX -= 1 # index from 0 55 | track_query = track_test[query_IDX,:] 56 | gallery_IDX = [i for i in range(track_test.shape[0]) if i not in query_IDX] 57 | track_gallery = track_test[gallery_IDX,:] 58 | # track_gallery = track_test 59 | 60 | train, num_train_tracklets, num_train_pids, num_train_imgs = \ 61 | self._process_data(train_names, track_train, home_dir='bbox_train', relabel=True, min_seq_len=min_seq_len) 62 | 63 | query, num_query_tracklets, num_query_pids, num_query_imgs = \ 64 | self._process_data(test_names, track_query, home_dir='bbox_test', relabel=False, min_seq_len=min_seq_len) 65 | 66 | gallery, num_gallery_tracklets, num_gallery_pids, num_gallery_imgs = \ 67 | self._process_data(test_names, track_gallery, home_dir='bbox_test', relabel=False, min_seq_len=min_seq_len) 68 | 69 | num_imgs_per_tracklet = num_train_imgs + num_gallery_imgs + num_query_imgs 70 | min_num = np.min(num_imgs_per_tracklet) 71 | max_num = np.max(num_imgs_per_tracklet) 72 | avg_num = np.mean(num_imgs_per_tracklet) 73 | 74 | num_total_pids = num_train_pids + num_gallery_pids 75 | num_total_tracklets = num_train_tracklets + num_gallery_tracklets + num_query_tracklets 76 | 77 | print("=> MARS loaded") 78 | print("Dataset statistics:") 79 | print(" ------------------------------") 80 | print(" subset | # ids | # tracklets") 81 | print(" ------------------------------") 82 | print(" train | {:5d} | {:8d}".format(num_train_pids, num_train_tracklets)) 83 | print(" query | {:5d} | {:8d}".format(num_query_pids, num_query_tracklets)) 84 | print(" gallery | {:5d} | {:8d}".format(num_gallery_pids, num_gallery_tracklets)) 85 | print(" ------------------------------") 86 | print(" total | {:5d} | {:8d}".format(num_total_pids, num_total_tracklets)) 87 | print(" number of images per tracklet: {} ~ {}, average {:.1f}".format(min_num, max_num, avg_num)) 88 | print(" ------------------------------") 89 | 90 | self.train = train 91 | self.query = query 92 | self.gallery = gallery 93 | 94 | self.num_train_pids = num_train_pids 95 | self.num_query_pids = num_query_pids 96 | self.num_gallery_pids = num_gallery_pids 97 | 98 | def _check_before_run(self): 99 | """Check if all files are available before going deeper""" 100 | if not osp.exists(self._root): 101 | raise RuntimeError("'{}' is not available".format(self._root)) 102 | if not osp.exists(self.train_name_path): 103 | raise RuntimeError("'{}' is not available".format(self.train_name_path)) 104 | if not osp.exists(self.test_name_path): 105 | raise RuntimeError("'{}' is not 
available".format(self.test_name_path)) 106 | if not osp.exists(self.track_train_info_path): 107 | raise RuntimeError("'{}' is not available".format(self.track_train_info_path)) 108 | if not osp.exists(self.track_test_info_path): 109 | raise RuntimeError("'{}' is not available".format(self.track_test_info_path)) 110 | if not osp.exists(self.query_IDX_path): 111 | raise RuntimeError("'{}' is not available".format(self.query_IDX_path)) 112 | 113 | def _get_names(self, fpath): 114 | names = [] 115 | with open(fpath, 'r') as f: 116 | for line in f: 117 | new_line = line.rstrip() 118 | names.append(new_line) 119 | return names 120 | 121 | def _process_data(self, names, meta_data, home_dir=None, relabel=False, min_seq_len=0): 122 | assert home_dir in ['bbox_train', 'bbox_test'] 123 | num_tracklets = meta_data.shape[0] 124 | pid_list = list(set(meta_data[:,2].tolist())) 125 | num_pids = len(pid_list) 126 | 127 | if relabel: pid2label = {pid:label for label, pid in enumerate(pid_list)} 128 | tracklets = [] 129 | num_imgs_per_tracklet = [] 130 | 131 | vids_per_pid_count = np.zeros(len(pid_list)) 132 | 133 | for tracklet_idx in range(num_tracklets): 134 | data = meta_data[tracklet_idx, ...] 135 | start_index, end_index, pid, camid = data 136 | if pid == -1: continue # junk images are just ignored 137 | assert 1 <= camid <= 6 138 | if relabel: 139 | pid = pid2label[pid] 140 | camid -= 1 # index starts from 0 141 | img_names = names[start_index-1:end_index] 142 | 143 | # make sure image names correspond to the same person 144 | pnames = [img_name[:4] for img_name in img_names] 145 | assert len(set(pnames)) == 1, "Error: a single tracklet contains different person images" 146 | 147 | # make sure all images are captured under the same camera 148 | camnames = [img_name[5] for img_name in img_names] 149 | assert len(set(camnames)) == 1, "Error: images are captured under different cameras!" 150 | 151 | # append image names with directory information 152 | img_paths = [osp.join(self._root, home_dir, img_name[:4], img_name) for img_name in img_names] 153 | 154 | if home_dir == 'bbox_train': 155 | if self.sampling_type == 1250: 156 | if vids_per_pid_count[pid] >= 2: 157 | continue 158 | vids_per_pid_count[pid] = vids_per_pid_count[pid] + 1 159 | 160 | elif self.sampling_type > 0: 161 | num_pids = self.sampling_type 162 | 163 | vids_thred = 2 164 | 165 | if self.sampling_type == 125: 166 | vids_thred = 13 167 | 168 | if pid >= self.sampling_type: continue 169 | 170 | if vids_per_pid_count[pid] >= vids_thred: 171 | continue 172 | vids_per_pid_count[pid] = vids_per_pid_count[pid] + 1 173 | else: 174 | pass 175 | 176 | 177 | if len(img_paths) >= min_seq_len: 178 | img_paths = tuple(img_paths) 179 | tracklets.append((img_paths, pid, camid)) 180 | num_imgs_per_tracklet.append(len(img_paths)) 181 | 182 | num_tracklets = len(tracklets) 183 | 184 | return tracklets, num_tracklets, num_pids, num_imgs_per_tracklet 185 | 186 | 187 | class DukeMTMCVidReID(object): 188 | """ 189 | DukeMTMCVidReID 190 | Reference: 191 | Wu et al. Exploit the Unknown Gradually: One-Shot Video-Based Person 192 | Re-Identification by Stepwise Learning. CVPR 2018. 
193 | URL: https://github.com/Yu-Wu/DukeMTMC-VideoReID 194 | 195 | Dataset statistics: 196 | # identities: 702 (train) + 702 (test) 197 | # tracklets: 2196 (train) + 2636 (test) 198 | """ 199 | 200 | def __init__(self, 201 | root='/data/baishutao/data/dukemtmc-video', 202 | sampling_step=32, 203 | min_seq_len=0, 204 | verbose=True, 205 | *args, **kwargs): 206 | self.dataset_dir = root 207 | self.dataset_url = 'http://vision.cs.duke.edu/DukeMTMC/data/misc/DukeMTMC-VideoReID.zip' 208 | 209 | self.train_dir = osp.join(self.dataset_dir, 'train') 210 | self.query_dir = osp.join(self.dataset_dir, 'query') 211 | self.gallery_dir = osp.join(self.dataset_dir, 'gallery') 212 | 213 | self.split_train_dense_json_path = osp.join(self.dataset_dir, 'split_train_dense_{}.json'.format(sampling_step)) 214 | self.split_train_json_path = osp.join(self.dataset_dir, 'split_train.json') 215 | self.split_query_json_path = osp.join(self.dataset_dir, 'split_query.json') 216 | self.split_gallery_json_path = osp.join(self.dataset_dir, 'split_gallery.json') 217 | 218 | self.split_train_1stframe_json_path = osp.join(self.dataset_dir, 'split_train_1stframe.json') 219 | self.split_query_1stframe_json_path = osp.join(self.dataset_dir, 'split_query_1stframe.json') 220 | self.split_gallery_1stframe_json_path = osp.join(self.dataset_dir, 'split_gallery_1stframe.json') 221 | 222 | self.min_seq_len = min_seq_len 223 | self._check_before_run() 224 | 225 | train, \ 226 | num_train_tracklets, \ 227 | num_train_pids, \ 228 | num_imgs_train = self._process_dir( 229 | self.train_dir, 230 | self.split_train_json_path, 231 | relabel=True) 232 | 233 | train_dense, \ 234 | num_train_tracklets_dense, \ 235 | num_train_pids_dense, \ 236 | num_imgs_train_dense = self._process_dir( 237 | self.train_dir, 238 | self.split_train_dense_json_path, 239 | relabel=True, 240 | sampling_step=sampling_step) 241 | 242 | query, \ 243 | num_query_tracklets, \ 244 | num_query_pids, \ 245 | num_imgs_query = self._process_dir( 246 | self.query_dir, 247 | self.split_query_json_path, 248 | relabel=False) 249 | gallery, \ 250 | num_gallery_tracklets, \ 251 | num_gallery_pids, \ 252 | num_imgs_gallery = self._process_dir( 253 | self.gallery_dir, 254 | self.split_gallery_json_path, 255 | relabel=False) 256 | 257 | print("the number of tracklets under dense sampling for train set: {}". 
258 | format(num_train_tracklets_dense)) 259 | 260 | num_imgs_per_tracklet = num_imgs_train + num_imgs_query + num_imgs_gallery 261 | min_num = np.min(num_imgs_per_tracklet) 262 | max_num = np.max(num_imgs_per_tracklet) 263 | avg_num = np.mean(num_imgs_per_tracklet) 264 | 265 | num_total_pids = num_train_pids + num_query_pids 266 | num_total_tracklets = num_train_tracklets + num_query_tracklets + num_gallery_tracklets 267 | 268 | if verbose: 269 | print("=> DukeMTMC-VideoReID loaded") 270 | print("Dataset statistics:") 271 | print(" ------------------------------") 272 | print(" subset | # ids | # tracklets") 273 | print(" ------------------------------") 274 | print(" train | {:5d} | {:8d}".format(num_train_pids, num_train_tracklets)) 275 | if sampling_step != 0: 276 | print(" train_d | {:5d} | {:8d}".format(num_train_pids_dense, num_train_tracklets_dense)) 277 | print(" query | {:5d} | {:8d}".format(num_query_pids, num_query_tracklets)) 278 | print(" gallery | {:5d} | {:8d}".format(num_gallery_pids, num_gallery_tracklets)) 279 | print(" ------------------------------") 280 | print(" total | {:5d} | {:8d}".format(num_total_pids, num_total_tracklets)) 281 | print(" number of images per tracklet: {} ~ {}, average {:.1f}".format(min_num, max_num, avg_num)) 282 | print(" ------------------------------") 283 | 284 | if sampling_step!=0: 285 | self.train = train_dense 286 | else: 287 | self.train = train 288 | 289 | self.query = query 290 | self.gallery = gallery 291 | 292 | self.num_train_pids = num_train_pids 293 | self.num_query_pids = num_query_pids 294 | self.num_gallery_pids = num_gallery_pids 295 | 296 | def _check_before_run(self): 297 | """Check if all files are available before going deeper""" 298 | if not osp.exists(self.dataset_dir): 299 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 300 | if not osp.exists(self.train_dir): 301 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 302 | if not osp.exists(self.query_dir): 303 | raise RuntimeError("'{}' is not available".format(self.query_dir)) 304 | if not osp.exists(self.gallery_dir): 305 | raise RuntimeError("'{}' is not available".format(self.gallery_dir)) 306 | 307 | def _process_dir(self, dir_path, json_path, relabel, sampling_step=0): 308 | if osp.exists(json_path): 309 | # print("=> {} generated before, awesome!".format(json_path)) 310 | split = read_json(json_path) 311 | return split['tracklets'], split['num_tracklets'], split['num_pids'], split['num_imgs_per_tracklet'] 312 | 313 | print("=> Automatically generating split (might take a while for the first time, have a coffe)") 314 | pdirs = glob.glob(osp.join(dir_path, '*')) # avoid .DS_Store 315 | print("Processing {} with {} person identities".format(dir_path, len(pdirs))) 316 | 317 | pid_container = set() 318 | for pdir in pdirs: 319 | pid = int(osp.basename(pdir)) 320 | pid_container.add(pid) 321 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 322 | 323 | tracklets = [] 324 | num_imgs_per_tracklet = [] 325 | for pdir in pdirs: 326 | pid = int(osp.basename(pdir)) 327 | if relabel: pid = pid2label[pid] 328 | tdirs = glob.glob(osp.join(pdir, '*')) 329 | for tdir in tdirs: 330 | raw_img_paths = glob.glob(osp.join(tdir, '*.jpg')) 331 | num_imgs = len(raw_img_paths) 332 | 333 | if num_imgs < self.min_seq_len: 334 | continue 335 | 336 | num_imgs_per_tracklet.append(num_imgs) 337 | img_paths = [] 338 | for img_idx in range(num_imgs): 339 | # some tracklet starts from 0002 instead of 0001 340 | img_idx_name = 'F' + 
str(img_idx + 1).zfill(4) 341 | res = glob.glob(osp.join(tdir, '*' + img_idx_name + '*.jpg')) 342 | if len(res) == 0: 343 | print("Warn: index name {} in {} is missing, jump to next".format(img_idx_name, tdir)) 344 | continue 345 | img_paths.append(res[0]) 346 | img_name = osp.basename(img_paths[0]) 347 | if img_name.find('_') == -1: 348 | # old naming format: 0001C6F0099X30823.jpg 349 | camid = int(img_name[5]) - 1 350 | else: 351 | # new naming format: 0001_C6_F0099_X30823.jpg 352 | camid = int(img_name[6]) - 1 353 | img_paths = tuple(img_paths) 354 | 355 | # dense sampling 356 | num_sampling = len(img_paths)//sampling_step 357 | if num_sampling == 0: 358 | tracklets.append((img_paths, pid, camid)) 359 | else: 360 | for idx in range(num_sampling): 361 | if idx == num_sampling - 1: 362 | tracklets.append((img_paths[idx*sampling_step:], pid, camid)) 363 | else: 364 | tracklets.append((img_paths[idx*sampling_step : (idx+1)*sampling_step], pid, camid)) 365 | 366 | num_pids = len(pid_container) 367 | num_tracklets = len(tracklets) 368 | 369 | print("Saving split to {}".format(json_path)) 370 | split_dict = {'tracklets': tracklets, 'num_tracklets': num_tracklets, 'num_pids': num_pids, 371 | 'num_imgs_per_tracklet': num_imgs_per_tracklet, } 372 | write_json(split_dict, json_path) 373 | 374 | return tracklets, num_tracklets, num_pids, num_imgs_per_tracklet 375 | 376 | 377 | class iLIDSVID(object): 378 | """ 379 | iLIDS-VID 380 | 381 | Reference: 382 | Wang et al. Person Re-Identification by Video Ranking. ECCV 2014. 383 | 384 | Dataset statistics: 385 | # identities: 300 386 | # tracklets: 600 387 | # cameras: 2 388 | 389 | Args: 390 | split_id (int): indicates which split to use. There are totally 10 splits. 391 | """ 392 | 393 | def __init__(self, root, split_id=0): 394 | print('Dataset: iLIDSVID spli_id :{}'.format(split_id)) 395 | 396 | self.root = root 397 | self.dataset_url = 'http://www.eecs.qmul.ac.uk/~xiatian/iLIDS-VID/iLIDS-VID.tar' 398 | self.data_dir = osp.join(root, 'i-LIDS-VID') 399 | self.split_dir = osp.join(root, 'train-test people splits') 400 | self.split_mat_path = osp.join(self.split_dir, 'train_test_splits_ilidsvid.mat') 401 | self.split_path = osp.join(root, 'splits.json') 402 | self.cam_1_path = osp.join(root, 'i-LIDS-VID/sequences/cam1') 403 | self.cam_2_path = osp.join(root, 'i-LIDS-VID/sequences/cam2') 404 | 405 | self._download_data() 406 | self._check_before_run() 407 | 408 | self._prepare_split() 409 | splits = read_json(self.split_path) 410 | if split_id >= len(splits): 411 | raise ValueError("split_id exceeds range, received {}, but expected between 0 and {}".format(split_id, len(splits)-1)) 412 | split = splits[split_id] 413 | train_dirs, test_dirs = split['train'], split['test'] 414 | 415 | train, num_train_tracklets, num_train_pids, num_imgs_train = \ 416 | self._process_data(train_dirs, cam1=True, cam2=True) 417 | query, num_query_tracklets, num_query_pids, num_imgs_query = \ 418 | self._process_data(test_dirs, cam1=True, cam2=False) 419 | gallery, num_gallery_tracklets, num_gallery_pids, num_imgs_gallery = \ 420 | self._process_data(test_dirs, cam1=False, cam2=True) 421 | 422 | num_imgs_per_tracklet = num_imgs_train + num_imgs_query + num_imgs_gallery 423 | min_num = np.min(num_imgs_per_tracklet) 424 | max_num = np.max(num_imgs_per_tracklet) 425 | avg_num = np.mean(num_imgs_per_tracklet) 426 | 427 | num_total_pids = num_train_pids + num_query_pids 428 | num_total_tracklets = num_train_tracklets + num_query_tracklets + num_gallery_tracklets 429 | 430 | 
print("=> iLIDS-VID loaded w/ split_id {}".format(split_id)) 431 | print("Dataset statistics:") 432 | print(" ------------------------------") 433 | print(" subset | # ids | # tracklets") 434 | print(" ------------------------------") 435 | print(" train | {:5d} | {:8d}".format(num_train_pids, num_train_tracklets)) 436 | print(" query | {:5d} | {:8d}".format(num_query_pids, num_query_tracklets)) 437 | print(" gallery | {:5d} | {:8d}".format(num_gallery_pids, num_gallery_tracklets)) 438 | print(" ------------------------------") 439 | print(" total | {:5d} | {:8d}".format(num_total_pids, num_total_tracklets)) 440 | print(" number of images per tracklet: {} ~ {}, average {:.1f}".format(min_num, max_num, avg_num)) 441 | print(" ------------------------------") 442 | 443 | self.train = train 444 | self.query = query 445 | self.gallery = gallery 446 | 447 | self.num_train_pids = num_train_pids 448 | self.num_query_pids = num_query_pids 449 | self.num_gallery_pids = num_gallery_pids 450 | 451 | def _download_data(self): 452 | if osp.exists(self.root): 453 | # print("This dataset has been downloaded.") 454 | return 455 | 456 | mkdir_if_missing(self.root) 457 | fpath = osp.join(self.root, osp.basename(self.dataset_url)) 458 | 459 | print("Downloading iLIDS-VID dataset") 460 | url_opener = urllib.URLopener() 461 | url_opener.retrieve(self.dataset_url, fpath) 462 | 463 | print("Extracting files") 464 | tar = tarfile.open(fpath) 465 | tar.extractall(path=self.root) 466 | tar.close() 467 | 468 | def _check_before_run(self): 469 | """Check if all files are available before going deeper""" 470 | if not osp.exists(self.root): 471 | raise RuntimeError("'{}' is not available".format(self.root)) 472 | if not osp.exists(self.data_dir): 473 | raise RuntimeError("'{}' is not available".format(self.data_dir)) 474 | if not osp.exists(self.split_dir): 475 | raise RuntimeError("'{}' is not available".format(self.split_dir)) 476 | 477 | def _prepare_split(self): 478 | if not osp.exists(self.split_path): 479 | # print("Creating splits") 480 | mat_split_data = loadmat(self.split_mat_path)['ls_set'] 481 | 482 | num_splits = mat_split_data.shape[0] 483 | num_total_ids = mat_split_data.shape[1] 484 | assert num_splits == 10 485 | assert num_total_ids == 300 486 | num_ids_each = num_total_ids/2 487 | 488 | # pids in mat_split_data are indices, so we need to transform them 489 | # to real pids 490 | person_cam1_dirs = os.listdir(self.cam_1_path) 491 | person_cam2_dirs = os.listdir(self.cam_2_path) 492 | 493 | # make sure persons in one camera view can be found in the other camera view 494 | assert set(person_cam1_dirs) == set(person_cam2_dirs) 495 | 496 | splits = [] 497 | for i_split in range(num_splits): 498 | # first 50% for testing and the remaining for training, following Wang et al. ECCV'14. 499 | train_idxs = sorted(list(mat_split_data[i_split,num_ids_each:])) 500 | test_idxs = sorted(list(mat_split_data[i_split,:num_ids_each])) 501 | 502 | train_idxs = [int(i)-1 for i in train_idxs] 503 | test_idxs = [int(i)-1 for i in test_idxs] 504 | 505 | # transform pids to person dir names 506 | train_dirs = [person_cam1_dirs[i] for i in train_idxs] 507 | test_dirs = [person_cam1_dirs[i] for i in test_idxs] 508 | 509 | split = {'train': train_dirs, 'test': test_dirs} 510 | splits.append(split) 511 | 512 | print("Totally {} splits are created, following Wang et al. 
ECCV'14".format(len(splits))) 513 | print("Split file is saved to {}".format(self.split_path)) 514 | write_json(splits, self.split_path) 515 | 516 | def _process_data(self, dirnames, cam1=True, cam2=True): 517 | tracklets = [] 518 | num_imgs_per_tracklet = [] 519 | dirname2pid = {dirname:i for i, dirname in enumerate(dirnames)} 520 | 521 | sampling_step = 0 522 | 523 | for dirname in dirnames: 524 | if cam1: 525 | person_dir = osp.join(self.cam_1_path, dirname) 526 | img_names = glob.glob(osp.join(person_dir, '*.png')) 527 | assert len(img_names) > 0 528 | # img_names = tuple(img_names) 529 | pid = dirname2pid[dirname] 530 | 531 | if sampling_step != 0: 532 | num_sampling = len(img_names) // sampling_step 533 | if num_sampling == 0: 534 | tracklets.append((img_names, pid, 0)) 535 | else: 536 | for idx in range(num_sampling): 537 | if idx == num_sampling - 1: 538 | tracklets.append((img_names[-sampling_step:], pid,0)) 539 | else: 540 | tracklets.append((img_names[idx * sampling_step: (idx + 1) * sampling_step], pid, 0)) 541 | else: 542 | tracklets.append((img_names, pid, 0)) 543 | # tracklets.append((img_names, pid, 0)) 544 | num_imgs_per_tracklet.append(len(img_names)) 545 | 546 | 547 | if cam2: 548 | person_dir = osp.join(self.cam_2_path, dirname) 549 | img_names = glob.glob(osp.join(person_dir, '*.png')) 550 | assert len(img_names) > 0 551 | # img_names = tuple(img_names) 552 | pid = dirname2pid[dirname] 553 | 554 | if sampling_step != 0: 555 | num_sampling = len(img_names) // sampling_step 556 | if num_sampling == 0: 557 | tracklets.append((img_names, pid, 1)) 558 | else: 559 | for idx in range(num_sampling): 560 | if idx == num_sampling - 1: 561 | tracklets.append((img_names[-sampling_step:], pid, 1)) 562 | else: 563 | tracklets.append((img_names[idx * sampling_step: (idx + 1) * sampling_step], pid, 1)) 564 | else: 565 | tracklets.append((img_names, pid, 1)) 566 | # tracklets.append((img_names, pid, 1)) 567 | num_imgs_per_tracklet.append(len(img_names)) 568 | 569 | num_tracklets = len(tracklets) 570 | num_pids = len(dirnames) 571 | 572 | return tracklets, num_tracklets, num_pids, num_imgs_per_tracklet 573 | 574 | 575 | class PRID(object): 576 | """ 577 | PRID 578 | 579 | Reference: 580 | Hirzer et al. Person Re-Identification by Descriptive and Discriminative Classification. SCIA 2011. 581 | 582 | Dataset statistics: 583 | # identities: 200 584 | # tracklets: 400 585 | # cameras: 2 586 | 587 | Args: 588 | split_id (int): indicates which split to use. There are totally 10 splits. 589 | min_seq_len (int): tracklet with length shorter than this value will be discarded (default: 0). 
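    Example (a sketch; '/path/to/prid2011' is a placeholder for a PRID root containing
    'multi_shot/cam_a', 'multi_shot/cam_b' and 'splits_prid2011.json'). Results on this dataset are
    commonly averaged over the 10 splits, e.g.:

        for split_id in range(10):
            dataset = PRID(root='/path/to/prid2011', split_id=split_id)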
590 | """ 591 | 592 | 593 | def __init__(self, root, split_id=0, min_seq_len=0): 594 | 595 | self.root = root 596 | self.dataset_url = 'https://files.icg.tugraz.at/f/6ab7e8ce8f/?raw=1' 597 | self.split_path = osp.join(root, 'splits_prid2011.json') 598 | self.cam_a_path = osp.join(root, 'multi_shot', 'cam_a') 599 | self.cam_b_path = osp.join(root, 'multi_shot', 'cam_b') 600 | 601 | self._check_before_run() 602 | splits = read_json(self.split_path) 603 | if split_id >= len(splits): 604 | raise ValueError("split_id exceeds range, received {}, but expected between 0 and {}".format(split_id, len(splits)-1)) 605 | split = splits[split_id] 606 | train_dirs, test_dirs = split['train'], split['test'] 607 | print("# train identites: {}, # test identites {}".format(len(train_dirs), len(test_dirs))) 608 | 609 | train, num_train_tracklets, num_train_pids, num_imgs_train = \ 610 | self._process_data(train_dirs, cam1=True, cam2=True) 611 | query, num_query_tracklets, num_query_pids, num_imgs_query = \ 612 | self._process_data(test_dirs, cam1=True, cam2=False) 613 | gallery, num_gallery_tracklets, num_gallery_pids, num_imgs_gallery = \ 614 | self._process_data(test_dirs, cam1=False, cam2=True) 615 | 616 | num_imgs_per_tracklet = num_imgs_train + num_imgs_query + num_imgs_gallery 617 | min_num = np.min(num_imgs_per_tracklet) 618 | max_num = np.max(num_imgs_per_tracklet) 619 | avg_num = np.mean(num_imgs_per_tracklet) 620 | 621 | num_total_pids = num_train_pids + num_query_pids 622 | num_total_tracklets = num_train_tracklets + num_query_tracklets + num_gallery_tracklets 623 | 624 | print("=> PRID-2011 loaded") 625 | print("Dataset statistics:") 626 | print(" ------------------------------") 627 | print(" subset | # ids | # tracklets") 628 | print(" ------------------------------") 629 | print(" train | {:5d} | {:8d}".format(num_train_pids, num_train_tracklets)) 630 | print(" query | {:5d} | {:8d}".format(num_query_pids, num_query_tracklets)) 631 | print(" gallery | {:5d} | {:8d}".format(num_gallery_pids, num_gallery_tracklets)) 632 | print(" ------------------------------") 633 | print(" total | {:5d} | {:8d}".format(num_total_pids, num_total_tracklets)) 634 | print(" number of images per tracklet: {} ~ {}, average {:.1f}".format(min_num, max_num, avg_num)) 635 | print(" ------------------------------") 636 | 637 | self.train = train 638 | self.query = query 639 | self.gallery = gallery 640 | 641 | self.num_train_pids = num_train_pids 642 | self.num_query_pids = num_query_pids 643 | self.num_gallery_pids = num_gallery_pids 644 | 645 | def _check_before_run(self): 646 | """Check if all files are available before going deeper""" 647 | if not osp.exists(self.root): 648 | raise RuntimeError("'{}' is not available".format(self.root)) 649 | 650 | def _process_data(self, dirnames, cam1=True, cam2=True): 651 | tracklets = [] 652 | num_imgs_per_tracklet = [] 653 | dirname2pid = {dirname:i for i, dirname in enumerate(dirnames)} 654 | 655 | for dirname in dirnames: 656 | if cam1: 657 | person_dir = osp.join(self.cam_a_path, dirname) 658 | img_names = glob.glob(osp.join(person_dir, '*.png')) 659 | assert len(img_names) > 0 660 | img_names = tuple(img_names) 661 | pid = dirname2pid[dirname] 662 | tracklets.append((img_names, pid, 0)) 663 | num_imgs_per_tracklet.append(len(img_names)) 664 | 665 | if cam2: 666 | person_dir = osp.join(self.cam_b_path, dirname) 667 | img_names = glob.glob(osp.join(person_dir, '*.png')) 668 | assert len(img_names) > 0 669 | img_names = tuple(img_names) 670 | pid = dirname2pid[dirname] 671 | 
tracklets.append((img_names, pid, 1)) 672 | num_imgs_per_tracklet.append(len(img_names)) 673 | 674 | num_tracklets = len(tracklets) 675 | num_pids = len(dirnames) 676 | 677 | return tracklets, num_tracklets, num_pids, num_imgs_per_tracklet 678 | 679 | 680 | class LSVID(object): 681 | """ 682 | LS-VID 683 | 684 | Reference: 685 | Li J, Wang J, Tian Q, Gao W and Zhang S Global-Local Temporal Representations for Video Person Re-Identification[J]. ICCV, 2019 686 | 687 | Dataset statistics: 688 | # identities: 3772 689 | # tracklets: 2831 (train) + 3504 (query) + 7829 (gallery) 690 | # cameras: 15 691 | 692 | Note: 693 | # gallery set must contain query set, otherwise 140 query imgs will not have ground truth. 694 | # gallery imgs with label=-1 can be remove, which do not influence on final performance. 695 | 696 | Args: 697 | min_seq_len (int): tracklet with length shorter than this value will be discarded (default: 0). 698 | """ 699 | 700 | def __init__(self, root=None, sampling_step=48, *args, **kwargs): 701 | self._root = root 702 | self.train_name_path = osp.join(self._root, 'list_sequence/list_seq_train.txt') 703 | self.test_name_path = osp.join(self._root, 'list_sequence/list_seq_test.txt') 704 | self.query_IDX_path = osp.join(self._root, 'test/data/info_test.mat') 705 | 706 | self._check_before_run() 707 | 708 | # prepare meta data 709 | track_train = self._get_names(self.train_name_path) 710 | track_test = self._get_names(self.test_name_path) 711 | 712 | track_train = np.array(track_train) 713 | track_test = np.array(track_test) 714 | 715 | query_IDX = h5py.File(self.query_IDX_path, mode='r')['query'][0,:] # numpy.ndarray (1980,) 716 | query_IDX = np.array(query_IDX, dtype=int) 717 | 718 | query_IDX -= 1 # index from 0 719 | track_query = track_test[query_IDX, :] 720 | 721 | gallery_IDX = [i for i in range(track_test.shape[0]) if i not in query_IDX] 722 | track_gallery = track_test[gallery_IDX, :] 723 | 724 | self.split_train_dense_json_path = osp.join(self._root,'split_train_dense_{}.json'.format(sampling_step)) 725 | self.split_train_json_path = osp.join(self._root, 'split_train.json') 726 | self.split_query_json_path = osp.join(self._root, 'split_query.json') 727 | self.split_gallery_json_path = osp.join(self._root, 'split_gallery.json') 728 | 729 | train, num_train_tracklets, num_train_pids, num_train_imgs = \ 730 | self._process_data(track_train, json_path=self.split_train_json_path, relabel=True) 731 | 732 | train_dense, num_train_tracklets_dense, num_train_pids_dense, num_train_imgs_dense = \ 733 | self._process_data(track_train, json_path=self.split_train_dense_json_path, relabel=True, sampling_step=sampling_step) 734 | 735 | query, num_query_tracklets, num_query_pids, num_query_imgs = \ 736 | self._process_data(track_query, json_path=self.split_query_json_path, relabel=False) 737 | 738 | gallery, num_gallery_tracklets, num_gallery_pids, num_gallery_imgs = \ 739 | self._process_data(track_gallery, json_path=self.split_gallery_json_path, relabel=False) 740 | 741 | num_imgs_per_tracklet = num_train_imgs + num_gallery_imgs + num_query_imgs 742 | min_num = np.min(num_imgs_per_tracklet) 743 | max_num = np.max(num_imgs_per_tracklet) 744 | avg_num = np.mean(num_imgs_per_tracklet) 745 | 746 | num_total_pids = num_train_pids + num_gallery_pids 747 | num_total_tracklets = num_train_tracklets + num_gallery_tracklets + num_query_tracklets 748 | 749 | print("=> LS-VID loaded") 750 | print("Dataset statistics:") 751 | print(" ------------------------------") 752 | print(" subset | # ids | # 
tracklets") 753 | print(" ------------------------------") 754 | print(" train | {:5d} | {:8d}".format(num_train_pids, num_train_tracklets)) 755 | if sampling_step != 0: 756 | print(" train_d | {:5d} | {:8d}".format(num_train_pids_dense, num_train_tracklets_dense)) 757 | print(" query | {:5d} | {:8d}".format(num_query_pids, num_query_tracklets)) 758 | print(" gallery | {:5d} | {:8d}".format(num_gallery_pids, num_gallery_tracklets)) 759 | print(" ------------------------------") 760 | print(" total | {:5d} | {:8d}".format(num_total_pids, num_total_tracklets)) 761 | print(" number of images per tracklet: {} ~ {}, average {:.1f}".format(min_num, max_num, avg_num)) 762 | print(" ------------------------------") 763 | 764 | if sampling_step != 0: 765 | self.train = train_dense 766 | else: 767 | self.train = train 768 | self.query = query 769 | self.gallery = gallery 770 | 771 | self.num_train_pids = num_train_pids 772 | self.num_query_pids = num_query_pids 773 | self.num_gallery_pids = num_gallery_pids 774 | 775 | def _check_before_run(self): 776 | """Check if all files are available before going deeper""" 777 | if not osp.exists(self._root): 778 | raise RuntimeError("'{}' is not available".format(self._root)) 779 | if not osp.exists(self.train_name_path): 780 | raise RuntimeError("'{}' is not available".format(self.train_name_path)) 781 | if not osp.exists(self.test_name_path): 782 | raise RuntimeError("'{}' is not available".format(self.test_name_path)) 783 | if not osp.exists(self.query_IDX_path): 784 | raise RuntimeError("'{}' is not available".format(self.query_IDX_path)) 785 | 786 | def _get_names(self, fpath): 787 | names = [] 788 | with open(fpath, 'r') as f: 789 | for line in f: 790 | new_line = line.rstrip() 791 | basepath, pid = new_line.split(' ') 792 | names.append([basepath, int(pid)]) 793 | return names 794 | 795 | def _process_data(self, 796 | meta_data, 797 | relabel=False, 798 | json_path=None, 799 | sampling_step=0): 800 | if osp.exists(json_path): 801 | split = read_json(json_path) 802 | return split['tracklets'], split['num_tracklets'], split['num_pids'], split['num_imgs_per_tracklet'] 803 | 804 | num_tracklets = meta_data.shape[0] 805 | pid_list = list(set(meta_data[:, 1].tolist())) 806 | num_pids = len(pid_list) 807 | 808 | if relabel: pid2label = {int(pid): label for label, pid in enumerate(pid_list)} 809 | tracklets = [] 810 | num_imgs_per_tracklet = [] 811 | 812 | vids_per_pid_count = np.zeros(len(pid_list)) 813 | 814 | for tracklet_idx in range(num_tracklets): 815 | tracklet_path = osp.join(self._root, meta_data[tracklet_idx, 0]) + '*' 816 | img_paths = glob.glob(tracklet_path) # avoid .DS_Store 817 | img_paths.sort() 818 | pid = int(meta_data[tracklet_idx, 1]) 819 | _, _, camid, _ = osp.basename(img_paths[0]).split('_')[:4] 820 | camid = int(camid) 821 | 822 | if pid == -1: continue # junk images are just ignored 823 | assert 1 <= camid <= 15 824 | if relabel: pid = pid2label[pid] 825 | camid -= 1 # index starts from 0 826 | 827 | num_sampling = len(img_paths) // sampling_step 828 | if num_sampling == 0: 829 | tracklets.append((img_paths, pid, camid)) 830 | else: 831 | for idx in range(num_sampling): 832 | if idx == num_sampling - 1: 833 | tracklets.append((img_paths[idx * sampling_step:], pid, camid)) 834 | else: 835 | tracklets.append((img_paths[idx * sampling_step: (idx + 1) * sampling_step], pid, camid)) 836 | num_imgs_per_tracklet.append(len(img_paths)) 837 | 838 | num_tracklets = len(tracklets) 839 | 840 | print("Saving split to {}".format(json_path)) 841 | 
split_dict = {'tracklets': tracklets, 'num_tracklets': num_tracklets, 'num_pids': num_pids, 842 | 'num_imgs_per_tracklet': num_imgs_per_tracklet, } 843 | write_json(split_dict, json_path) 844 | 845 | return tracklets, num_tracklets, num_pids, num_imgs_per_tracklet 846 | 847 | 848 | __factory = { 849 | 'mars': Mars, 850 | 'ilidsvid': iLIDSVID, 851 | 'prid': PRID, 852 | 'lsvid': LSVID, 853 | 'duke': DukeMTMCVidReID, 854 | } 855 | 856 | 857 | def get_names(): 858 | return __factory.keys() 859 | 860 | 861 | def init_dataset(name, root=None, *args, **kwargs): 862 | if name not in __factory.keys(): 863 | raise KeyError("Unknown dataset: {}".format(name)) 864 | 865 | return __factory[name](root=root, *args, **kwargs) 866 | -------------------------------------------------------------------------------- /utils/eval_metrics.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import numpy as np 3 | import torch 4 | 5 | 6 | def compute_ap_cmc(index, good_index, junk_index): 7 | ap = 0 8 | cmc = np.zeros(len(index)) 9 | 10 | # remove junk_index 11 | mask = np.in1d(index, junk_index, invert=True) 12 | index = index[mask] 13 | 14 | # find good_index index 15 | ngood = len(good_index) 16 | mask = np.in1d(index, good_index) 17 | rows_good = np.argwhere(mask==True) 18 | rows_good = rows_good.flatten() 19 | 20 | cmc[rows_good[0]:] = 1.0 21 | for i in range(ngood): 22 | d_recall = 1.0/ngood 23 | precision = (i+1)*1.0/(rows_good[i]+1) 24 | ap = ap + d_recall*precision 25 | 26 | return ap, cmc 27 | 28 | 29 | def evaluate(distmat, q_pids, g_pids, q_camids, g_camids): 30 | num_q, num_g = distmat.shape 31 | index = torch.argsort(distmat, dim=1) # from small to large 32 | index = index.numpy() 33 | 34 | num_no_gt = 0 # num of query imgs without groundtruth 35 | num_r1 = 0 36 | CMC = np.zeros(len(g_pids)) 37 | AP = 0 38 | 39 | for i in range(num_q): 40 | # ground truth index 41 | query_index = np.argwhere(g_pids==q_pids[i]) 42 | camera_index = np.argwhere(g_camids==q_camids[i]) 43 | good_index = np.setdiff1d(query_index, camera_index, assume_unique=True) 44 | if good_index.size == 0: 45 | num_no_gt += 1 46 | continue 47 | # remove gallery samples that have the same pid and camid with query 48 | junk_index = np.intersect1d(query_index, camera_index) 49 | 50 | ap_tmp, CMC_tmp = compute_ap_cmc(index[i], good_index, junk_index) 51 | if CMC_tmp[0]==1: 52 | num_r1 += 1 53 | CMC = CMC + CMC_tmp 54 | AP += ap_tmp 55 | 56 | # if num_no_gt > 0: 57 | # print("{} query imgs do not have groundtruth.".format(num_no_gt)) 58 | 59 | CMC = CMC / (num_q - num_no_gt) 60 | mAP = AP / (num_q - num_no_gt) 61 | 62 | return CMC, mAP 63 | 64 | -------------------------------------------------------------------------------- /utils/losses.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | from torch import nn 5 | import torch.nn.functional as F 6 | from torch.autograd import Variable 7 | 8 | 9 | class CrossEntropyLabelSmooth(nn.Module): 10 | """Cross entropy loss with label smoothing regularizer. 11 | 12 | Reference: 13 | Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016. 14 | Equation: y = (1 - epsilon) * y + epsilon / K. 15 | 16 | Args: 17 | num_classes (int): number of classes. 18 | epsilon (float): weight. 
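    Example (illustrative numbers; logits and pids are placeholder tensors): with num_classes=625 and
    epsilon=0.1, the ground-truth class receives weight 0.9 + 0.1/625 and every other class 0.1/625, i.e.

        criterion = CrossEntropyLabelSmooth(num_classes=625)
        loss = criterion(logits, pids)   # logits: (batch_size, 625) scores, pids: (batch_size,) integer labels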
19 | """ 20 | def __init__(self, num_classes, epsilon=0.1, use_gpu=True): 21 | super(CrossEntropyLabelSmooth, self).__init__() 22 | self.num_classes = num_classes 23 | self.epsilon = epsilon 24 | self.use_gpu = use_gpu 25 | self.logsoftmax = nn.LogSoftmax(dim=1) 26 | 27 | def forward(self, inputs, targets): 28 | """ 29 | Args: 30 | inputs: prediction matrix (before softmax) with shape (batch_size, num_classes) 31 | targets: ground truth labels with shape (batch_size) 32 | """ 33 | log_probs = self.logsoftmax(inputs) 34 | targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).data.cpu(), 1) 35 | if self.use_gpu: targets = targets.cuda() 36 | targets = Variable(targets, requires_grad=False) 37 | targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes 38 | loss = (- targets * log_probs).mean(0).sum() 39 | return loss 40 | 41 | 42 | class TripletLoss(nn.Module): 43 | """Triplet loss with hard positive/negative mining. 44 | Reference: 45 | Hermans et al. In Defense of the Triplet Loss for Person Re-Identification. arXiv:1703.07737. 46 | Code imported from https://github.com/Cysu/open-reid/blob/master/reid/loss/triplet.py. 47 | Args: 48 | margin (float): margin for triplet. 49 | """ 50 | 51 | def __init__(self, margin=0.3, distance='cosine'): 52 | super(TripletLoss, self).__init__() 53 | if distance not in ['euclidean', 'cosine']: 54 | raise KeyError("Unsupported distance: {}".format(distance)) 55 | self.distance = distance 56 | self.margin = margin 57 | self.ranking_loss = nn.MarginRankingLoss(margin=margin) 58 | 59 | def forward(self, input, target): 60 | """ 61 | :param input: feature matrix with shape (batch_size, feat_dim) 62 | :param target: ground truth labels with shape (batch_size) 63 | :return: 64 | """ 65 | n = input.size(0) 66 | # Compute pairwise distance, replace by the official when merged 67 | if self.distance == 'cosine': 68 | input = F.normalize(input, dim=-1) 69 | dist = - torch.matmul(input, input.t()) 70 | else: 71 | raise NotImplementedError 72 | 73 | # For each anchor, find the hardest positive and negative 74 | mask = target.expand(n, n).eq(target.expand(n, n).t()).float() 75 | dist_ap, _ = torch.topk(dist*mask - (1-mask), dim=-1, k=1) 76 | dist_an, _ = torch.topk(dist*(1-mask) + mask, dim=-1, k=1, largest=False) 77 | 78 | 79 | # Compute ranking hinge loss 80 | y = torch.ones_like(dist_an) 81 | loss = self.ranking_loss(dist_an, dist_ap, y) 82 | return loss 83 | 84 | 85 | class InfoNce(nn.Module): 86 | """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf. 
87 | It also supports the unsupervised contrastive loss in SimCLR""" 88 | 89 | def __init__(self, 90 | temperature=0.07, 91 | num_instance=4): 92 | 93 | super(InfoNce, self).__init__() 94 | self.temperature = temperature 95 | self.ni = num_instance 96 | 97 | def forward(self, features): 98 | """ 99 | :param features: (B, C, T) 100 | :param labels: (B) 101 | :return: 102 | """ 103 | b, c, t = features.shape 104 | if t == 8: 105 | features = features.reshape(b, c, 2, 4).transpose(1, 2).reshape(b*2, c, 4) 106 | b, c, t = features.shape 107 | 108 | ni = self.ni 109 | features = features.reshape(b//ni, ni, c, t).permute(0, 3, 1, 2).reshape(b//ni, t*ni, c) 110 | features = F.normalize(features, dim=-1) 111 | labels = torch.arange(0, t).reshape(t, 1).repeat(1, ni).reshape(t*ni, 1) 112 | # (t*ni, t*ni) 113 | mask = torch.eq(labels.view(-1, 1), labels.view(1, -1)).float().cuda() # (t*ni, t*ni) 114 | mask_pos = (1 - torch.eye(t*ni)).cuda() 115 | mask_pos = (mask * mask_pos).unsqueeze(0) 116 | 117 | # (b//ni, t*ni, t*ni) 118 | cos = torch.matmul(features, features.transpose(-1, -2)) 119 | 120 | logits = torch.div(cos, self.temperature) 121 | exp_neg_logits = (logits.exp() * (1-mask)).sum(dim=-1, keepdim=True) 122 | 123 | log_prob = logits - torch.log(exp_neg_logits + logits.exp()) 124 | loss = (log_prob * mask_pos).sum() / (mask_pos.sum()) 125 | loss = - loss 126 | return loss 127 | 128 | 129 | if __name__ == '__main__': 130 | loss = InfoNce() 131 | x = torch.rand(8, 16, 4).cuda() 132 | y = loss(x) 133 | print(y) 134 | -------------------------------------------------------------------------------- /utils/samplers.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | 4 | from collections import defaultdict 5 | import numpy as np 6 | import copy 7 | import random 8 | from torch.utils.data.sampler import Sampler 9 | 10 | 11 | class RandomIdentitySampler(Sampler): 12 | """ 13 | Randomly sample N identities, then for each identity, 14 | randomly sample K instances, therefore batch size is N*K. 15 | 16 | Args: 17 | - data_source (Dataset): dataset to sample from. 18 | - num_instances (int): number of instances per identity. 
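    Example (a sketch; train_set is any list of (item, pid, camid) tuples, such as the .train list built by
    the dataset classes in utils/data_manager.py, and video_dataset is the corresponding
    utils/video_loader.VideoDataset wrapping that same list):

        sampler = RandomIdentitySampler(train_set, num_instances=4)
        loader = torch.utils.data.DataLoader(video_dataset, sampler=sampler, batch_size=32, drop_last=True)
        # each batch then holds 32 // 4 = 8 identities with 4 tracklets each;
        # batch_size should be a multiple of num_instances.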
19 | """ 20 | def __init__(self, data_source, num_instances=4): 21 | self.data_source = data_source 22 | self.num_instances = num_instances 23 | self.index_dic = defaultdict(list) 24 | 25 | for index, (_, pid, _) in enumerate(data_source): 26 | self.index_dic[pid].append(index) 27 | 28 | self.pids = list(self.index_dic.keys()) 29 | self.num_identities = len(self.pids) 30 | 31 | # compute number of examples in an epoch 32 | self.length = 0 33 | for pid in self.pids: 34 | idxs = self.index_dic[pid] 35 | num = len(idxs) 36 | if num < self.num_instances: 37 | num = self.num_instances 38 | self.length += num - num % self.num_instances 39 | 40 | def __iter__(self): 41 | list_container = [] 42 | 43 | for pid in self.pids: 44 | idxs = copy.deepcopy(self.index_dic[pid]) 45 | if len(idxs) < self.num_instances: 46 | idxs = np.random.choice(idxs, size=self.num_instances, replace=True) 47 | random.shuffle(idxs) 48 | batch_idxs = [] 49 | for idx in idxs: 50 | batch_idxs.append(idx) 51 | if len(batch_idxs) == self.num_instances: 52 | list_container.append(batch_idxs) 53 | batch_idxs = [] 54 | 55 | random.shuffle(list_container) 56 | 57 | ret = [] 58 | for batch_idxs in list_container: 59 | ret.extend(batch_idxs) 60 | 61 | return iter(ret) 62 | 63 | def __len__(self): 64 | return self.length -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import os 3 | import sys 4 | import errno 5 | import shutil 6 | import json 7 | import os.path as osp 8 | 9 | import torch 10 | import time 11 | import math 12 | 13 | def mkdir_if_missing(directory): 14 | if not osp.exists(directory): 15 | try: 16 | os.makedirs(directory) 17 | except OSError as e: 18 | if e.errno != errno.EEXIST: 19 | raise 20 | 21 | class AverageMeter(object): 22 | """Computes and stores the average and current value. 23 | Code imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262 24 | """ 25 | def __init__(self): 26 | self.reset() 27 | 28 | def reset(self): 29 | self.val = 0 30 | self.avg = 0 31 | self.sum = 0 32 | self.count = 0 33 | 34 | def update(self, val, n=1): 35 | self.val = val 36 | self.sum += val * n 37 | self.count += n 38 | self.avg = self.sum / self.count 39 | 40 | def save_checkpoint(state, is_best, fpath='checkpoint.pth.tar'): 41 | mkdir_if_missing(osp.dirname(fpath)) 42 | torch.save(state, fpath) 43 | if is_best: 44 | shutil.copy(fpath, osp.join(osp.dirname(fpath), 'best_model.pth.tar')) 45 | 46 | def read_json(fpath): 47 | with open(fpath, 'r') as f: 48 | obj = json.load(f) 49 | return obj 50 | 51 | def write_json(obj, fpath): 52 | mkdir_if_missing(osp.dirname(fpath)) 53 | with open(fpath, 'w') as f: 54 | json.dump(obj, f, indent=4, separators=(',', ': ')) 55 | 56 | def print_time(string=''): 57 | ctime = time.strftime("%Y-%m-%d-%H:%M:%S", time.localtime()) 58 | res = ctime + ' | ' + string 59 | print(res) 60 | 61 | class Logger(object): 62 | """ 63 | Write console output to external text file. 64 | Code imported from https://github.com/Cysu/open-reid/blob/master/reid/utils/logging.py. 
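    Typically used by replacing sys.stdout (a sketch; save_dir is any output directory):

        sys.stdout = Logger(osp.join(save_dir, 'log_train.txt'))
        print('printed lines now go to both the console and the log file')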
65 | """ 66 | def __init__(self, fpath=None): 67 | self.console = sys.stdout 68 | self.file = None 69 | if fpath is not None: 70 | mkdir_if_missing(os.path.dirname(fpath)) 71 | self.file = open(fpath, 'w') 72 | 73 | def __del__(self): 74 | self.close() 75 | 76 | def __enter__(self): 77 | pass 78 | 79 | def __exit__(self, *args): 80 | self.close() 81 | 82 | def write(self, msg): 83 | self.console.write(msg) 84 | if self.file is not None: 85 | self.file.write(msg) 86 | 87 | def flush(self): 88 | self.console.flush() 89 | if self.file is not None: 90 | self.file.flush() 91 | os.fsync(self.file.fileno()) 92 | 93 | def close(self): 94 | self.console.close() 95 | if self.file is not None: 96 | self.file.close() 97 | -------------------------------------------------------------------------------- /utils/video_loader.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | 3 | import os 4 | import math 5 | import torch 6 | import functools 7 | import torch.utils.data as data 8 | from PIL import Image 9 | 10 | 11 | def pil_loader(path): 12 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 13 | with open(path, 'rb') as f: 14 | with Image.open(f) as img: 15 | return img.convert('RGB') 16 | 17 | def accimage_loader(path): 18 | try: 19 | import accimage 20 | return accimage.Image(path) 21 | except IOError: 22 | # Potentially a decoding problem, fall back to PIL.Image 23 | return pil_loader(path) 24 | 25 | def get_default_image_loader(): 26 | from torchvision import get_image_backend 27 | if get_image_backend() == 'accimage': 28 | return accimage_loader 29 | else: 30 | return pil_loader 31 | 32 | def image_loader(path): 33 | from torchvision import get_image_backend 34 | if get_image_backend() == 'accimage': 35 | return accimage_loader(path) 36 | else: 37 | return pil_loader(path) 38 | 39 | def video_loader(img_paths, image_loader): 40 | video = [] 41 | for image_path in img_paths: 42 | if os.path.exists(image_path): 43 | video.append(image_loader(image_path)) 44 | else: 45 | return video 46 | 47 | return video 48 | 49 | def get_default_video_loader(): 50 | image_loader = get_default_image_loader() 51 | return functools.partial(video_loader, image_loader=image_loader) 52 | 53 | 54 | class VideoDataset(data.Dataset): 55 | """Video Person ReID Dataset. 56 | Note: 57 | Batch data has shape N x C x T x H x W 58 | Args: 59 | dataset (list): List with items (img_paths, pid, camid) 60 | temporal_transform (callable, optional): A function/transform that takes in a list of frame indices 61 | and returns a transformed version 62 | target_transform (callable, optional): A function/transform that takes in the 63 | target and transforms it. 64 | loader (callable, optional): A function to load an video given its path and frame indices. 65 | """ 66 | 67 | def __init__(self, 68 | dataset, 69 | spatial_transform=None, 70 | temporal_transform=None, 71 | get_loader=get_default_video_loader): 72 | self.dataset = dataset 73 | self.spatial_transform = spatial_transform 74 | self.temporal_transform = temporal_transform 75 | self.loader = get_loader() 76 | 77 | def __len__(self): 78 | return len(self.dataset) 79 | 80 | def __getitem__(self, index): 81 | """ 82 | Args: 83 | index (int): Index 84 | 85 | Returns: 86 | tuple: (clip, pid, camid) where pid is identity of the clip. 
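        Example (illustrative; assumes the temporal transform samples 4 frames and the spatial
        transform yields 256 x 128 RGB tensors):

            clip, pid, camid = dataset[0]
            # clip.size() == (3, 4, 256, 128), i.e. C x T x H x W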
87 | """ 88 | img_paths, pid, camid = self.dataset[index] 89 | 90 | if self.temporal_transform is not None: 91 | img_paths = self.temporal_transform(img_paths) 92 | 93 | clip = self.loader(img_paths) 94 | 95 | if self.spatial_transform is not None: 96 | self.spatial_transform.randomize_parameters() 97 | clip = [self.spatial_transform(img) for img in clip] 98 | 99 | # trans T x C x H x W to C x T x H x W 100 | clip = torch.stack(clip, 0).permute(1, 0, 2, 3) 101 | 102 | return clip, pid, camid 103 | 104 | 105 | class VideoDatasetInfer(data.Dataset): 106 | """Video Person ReID Dataset. 107 | Note: 108 | Batch data has shape N x C x T x H x W 109 | Args: 110 | dataset (list): List with items (img_paths, pid, camid) 111 | temporal_transform (callable, optional): A function/transform that takes in a list of frame indices 112 | and returns a transformed version 113 | target_transform (callable, optional): A function/transform that takes in the 114 | target and transforms it. 115 | loader (callable, optional): A function to load an video given its path and frame indices. 116 | """ 117 | 118 | def __init__(self, 119 | dataset, 120 | seq_len=12, 121 | temporal_sampler='restricted', 122 | spatial_transform=None, 123 | get_loader=get_default_video_loader): 124 | self.dataset = dataset 125 | self.seq_len = seq_len 126 | self.temporal_sampler = temporal_sampler 127 | self.spatial_transform = spatial_transform 128 | 129 | self.loader = get_loader() 130 | 131 | 132 | def __len__(self): 133 | return len(self.dataset) 134 | 135 | 136 | @staticmethod 137 | def loop_padding(img_paths, size): 138 | img_paths = list(img_paths) 139 | exp_len = math.ceil(len(img_paths) / size) * size 140 | while len(img_paths) != exp_len: 141 | lack_num = exp_len - len(img_paths) 142 | if len(img_paths) > lack_num: 143 | img_paths.extend(img_paths[-lack_num:]) 144 | else: 145 | img_paths.extend(img_paths) 146 | 147 | img_paths.sort() 148 | assert len(img_paths) % size == 0, \ 149 | 'every clip must have {} frames, but we have {}' \ 150 | .format(size, len(img_paths)) 151 | return img_paths 152 | 153 | 154 | def __getitem__(self, index): 155 | img_paths, pid, camid = self.dataset[index] 156 | img_paths = self.loop_padding(img_paths, self.seq_len) 157 | 158 | clip = self.loader(img_paths) 159 | if self.spatial_transform is not None: 160 | self.spatial_transform.randomize_parameters() 161 | clip = [self.spatial_transform(img) for img in clip] 162 | 163 | # C x T x H x W 164 | clip = torch.stack(clip, 1) 165 | C, T, H, W = clip.size() 166 | 167 | # T//seq_len, C, seq_len, H, W 168 | if self.temporal_sampler == 'restricted': 169 | clip = clip.reshape(C, self.seq_len, T // self.seq_len, H, W).permute(2, 0, 1, 3, 4) 170 | else: 171 | raise NotImplementedError 172 | 173 | return clip, pid, camid 174 | --------------------------------------------------------------------------------
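The expected input types of `utils/eval_metrics.evaluate` are easy to get wrong, so here is a small self-contained sanity check (toy numbers, illustrative only): `distmat` is a (num_query, num_gallery) CPU torch tensor of distances, the id and camera arrays are 1-D numpy arrays, and gallery entries sharing both pid and camid with the query are treated as junk.

```python
import numpy as np
import torch

from utils.eval_metrics import evaluate

# two queries against four gallery tracklets; smaller distance = more similar
distmat = torch.tensor([[0.1, 0.9, 0.8, 0.7],
                        [0.9, 0.2, 0.6, 0.3]])
q_pids = np.array([1, 2])
g_pids = np.array([1, 2, 1, 2])
q_camids = np.array([0, 0])
g_camids = np.array([1, 1, 0, 0])  # same-pid / same-camid gallery entries count as junk

cmc, mAP = evaluate(distmat, q_pids, g_pids, q_camids, g_camids)
print('rank-1 {:.1%}  mAP {:.1%}'.format(cmc[0], mAP))  # both 100% in this toy setup
```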