├── README.md ├── cam_functions.py ├── mars_train.py ├── reid ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ └── __init__.cpython-37.pyc ├── data │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-37.pyc │ │ ├── dataloader.cpython-36.pyc │ │ ├── dataloader.cpython-37.pyc │ │ ├── datasequence.cpython-36.pyc │ │ ├── datasequence.cpython-37.pyc │ │ ├── sampler.cpython-36.pyc │ │ ├── sampler.cpython-37.pyc │ │ ├── seqpreprocessor.cpython-36.pyc │ │ ├── seqpreprocessor.cpython-37.pyc │ │ ├── seqtransforms.cpython-36.pyc │ │ ├── seqtransforms.cpython-37.pyc │ │ ├── video_loader.cpython-36.pyc │ │ └── video_loader.cpython-37.pyc │ ├── data_manager.py │ ├── dataloader.py │ ├── datasequence.py │ ├── sampler.py │ ├── samplers.py │ ├── seqpreprocessor.py │ ├── seqtransforms.py │ └── video_loader.py ├── dataset │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-37.pyc │ │ ├── duke.cpython-36.pyc │ │ ├── duke.cpython-37.pyc │ │ ├── ilidsvidsequence.cpython-36.pyc │ │ ├── ilidsvidsequence.cpython-37.pyc │ │ ├── mars.cpython-36.pyc │ │ ├── mars.cpython-37.pyc │ │ ├── prid2011sequence.cpython-36.pyc │ │ └── prid2011sequence.cpython-37.pyc │ ├── duke.py │ ├── ilidsvidsequence.py │ ├── mars.py │ └── prid2011sequence.py ├── evaluator │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-37.pyc │ │ ├── attevaluator.cpython-36.pyc │ │ ├── attevaluator.cpython-37.pyc │ │ ├── eva_functions.cpython-36.pyc │ │ ├── eva_functions.cpython-37.pyc │ │ ├── rerank.cpython-36.pyc │ │ ├── rerank.cpython-37.pyc │ │ ├── visualize.cpython-36.pyc │ │ └── visualize.cpython-37.pyc │ ├── attevaluator.py │ ├── eva_functions.py │ ├── evaluator.py │ ├── rerank.py │ └── visualize.py ├── loss │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-37.pyc │ │ ├── oim.cpython-36.pyc │ │ ├── oim.cpython-37.pyc │ │ ├── pairloss.cpython-36.pyc │ │ ├── pairloss.cpython-37.pyc │ │ ├── triplet.cpython-36.pyc │ │ └── triplet_oim.cpython-36.pyc │ ├── oim.py │ ├── pairloss.py │ ├── triplet.py │ └── triplet_oim.py ├── models │ ├── Siamese.py │ ├── Siamese_video.py │ ├── __init__.py │ ├── __pycache__ │ │ ├── Siamese.cpython-36.pyc │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-37.pyc │ │ ├── basebranch.cpython-36.pyc │ │ ├── grl_model.cpython-36.pyc │ │ ├── resnet.cpython-36.pyc │ │ ├── resnet.cpython-37.pyc │ │ ├── resnets1.cpython-36.pyc │ │ └── resnets1.cpython-37.pyc │ ├── basebranch.py │ ├── grl_model.py │ ├── resnet.py │ └── resnets1.py └── train │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-37.pyc │ └── trainer.cpython-36.pyc │ └── trainer.py ├── test_all.py ├── utils ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-37.pyc │ ├── logging.cpython-36.pyc │ ├── logging.cpython-37.pyc │ ├── meters.cpython-36.pyc │ ├── meters.cpython-37.pyc │ ├── osutils.cpython-36.pyc │ ├── osutils.cpython-37.pyc │ ├── serialization.cpython-36.pyc │ └── serialization.cpython-37.pyc ├── logging.py ├── meters.py ├── osutils.py └── serialization.py └── visualize.py /README.md: -------------------------------------------------------------------------------- 1 | # GRL 2 | This repo is the implementation of "Watching You: Global-guided Reciprocal Learning for Video-based Person Re-identification" 3 | # Requirements: 4 | python==3.6 5 | pytorch==1.0 6 | # Usage 7 | To train the model, please run 8 | python 
mars_train.py 9 | -------------------------------------------------------------------------------- /cam_functions.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 _*- 3 | # @author: ycy 4 | # @contact: asuradayuci@gmail.com 5 | # @time: 2019/8/13 8:06 PM 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import matplotlib.pyplot as plt  # required: visual_batch and visual_batch_eval below call plt.* 11 | import os 12 | from visualize import visualize 13 | from torchvision.utils import save_image 14 | from utils.osutils import mkdir_if_missing 15 | 16 | dirs_now = os.path.dirname(os.path.abspath(__file__)) 17 | PATH = dirs_now + '/mask_eval/' 18 | PATH_EVAL = dirs_now + '/SADouteval/' 19 | 20 | 21 | def visual_batch(cam, image, k, save_dir, mode): 22 | cam = cam.squeeze() # b, t, 16, 8 torch.Size([240, 16, 8]) 23 | b, t, h, w = cam.size() # b, t, 16, 8 torch.Size([240, 16, 8]) 24 | image = torch.stack(image, 0).contiguous() # [bt, 1, 3, 256, 128] torch.Size([240, 1, 3, 256, 128]) 25 | # cam = cam.view(8, 8, *cam.size()[1:]) # 8,8,16,8 26 | 27 | image = image.view(b, t, 1, *image.size()[-3:]) # b, t, 1, 3, 256, 128 28 | path = PATH + save_dir 29 | mkdir_if_missing(path) 30 | for i in range(cam.size(0)): # b 31 | fig = plt.figure(figsize=(15, 15)) 32 | for j in range(cam.size(1)): # t 33 | ax1 = plt.subplot(3, 8, j+1) 34 | ax1.axis('off') 35 | plt.title('cam', fontsize=18) 36 | plt.imshow(cam[i][j].detach().cpu().numpy(), alpha=0.6, cmap='jet') 37 | 38 | ax3 = plt.subplot(3, 8, j + 17) 39 | ax3.axis('off') 40 | plt.title('cam+img', fontsize=18) 41 | cam_ij = cam[i][j].unsqueeze(0) 42 | cam_ij = cam_ij.unsqueeze(0) 43 | images_ij = image[i][j] 44 | heatmap, raw_image = visualize(images_ij, cam_ij) 45 | heatmap = heatmap.squeeze().cpu().numpy().transpose(1, 2, 0) 46 | plt.imshow(heatmap) 47 | 48 | ax4 = plt.subplot(3, 8, j + 9) 49 | ax4.axis('off') 50 | plt.title('raw_image', fontsize=18) 51 | raw_image = raw_image.squeeze().cpu().numpy().transpose(1, 2, 0) 52 | plt.imshow(raw_image) 53 | # fig.tight_layout() 54 | fig.savefig(path + "/iter_{}index_{}_{}.jpg".format(k, i, mode)) 55 | 56 | 57 | def visual_batch_eval(cam, image, length, k): 58 | cam = torch.stack(cam, 0).contiguous() # torch.Size([240, 16, 8]) 59 | image = torch.stack(image, 0).contiguous() # torch.Size([240, 1, 3, 256, 128]) 60 | 61 | cam = cam.view(30, 8, 16, -1) # 8,8,16,8 62 | 63 | image = image.view(30, 8, 1, 3, 256, -1) 64 | path = PATH_EVAL + "fenzhi{}".format(k) 65 | mkdir_if_missing(path) 66 | for i in range(cam.size(0)): 67 | fig = plt.figure(figsize=(15, 15)) 68 | for j in range(cam.size(1)): 69 | ax1 = plt.subplot(3, 8, j+1) 70 | ax1.axis('off') 71 | plt.title('cam', fontsize=18) 72 | plt.imshow(cam[i][j].detach().cpu().numpy(), alpha=0.6, cmap='jet') 73 | 74 | ax3 = plt.subplot(3, 8, j + 17) 75 | ax3.axis('off') 76 | plt.title('cam+img', fontsize=18) 77 | cam_ij = cam[i][j].unsqueeze(0) 78 | cam_ij = cam_ij.unsqueeze(0) 79 | images_ij = image[i][j] 80 | heatmap, raw_image = visualize(images_ij, cam_ij) 81 | heatmap = heatmap.squeeze().cpu().numpy().transpose(1, 2, 0) 82 | plt.imshow(heatmap) 83 | 84 | ax4 = plt.subplot(3, 8, j + 9) 85 | ax4.axis('off') 86 | plt.title('raw_image', fontsize=18) 87 | raw_image = raw_image.squeeze().cpu().numpy().transpose(1, 2, 0) 88 | plt.imshow(raw_image) 89 | # fig.tight_layout() 90 | fig.savefig(PATH_EVAL + "fenzhi{}/cambatch_{}.jpg".format(k, i)) --------------------------------------------------------------------------------
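For reference, a minimal smoke test for `visual_batch`, with shapes inferred from the unpacking and shape comments above (`cam` is a `(b, t, h, w)` stack of attention maps, `image` a list of `b*t` frame tensors of shape `(1, 3, 256, 128)`). The random tensors and the `'debug'` directory name are illustrative only, and the behavior of `visualize` from visualize.py is assumed to match these shapes:

```python
# Hypothetical smoke test for visual_batch; not part of the repo.
import torch
from cam_functions import visual_batch

b, t = 30, 8                                    # 30 tracklets of 8 frames each
cam = torch.rand(b, t, 16, 8)                   # per-frame spatial attention maps
image = [torch.rand(1, 3, 256, 128) for _ in range(b * t)]  # b*t frame tensors
visual_batch(cam, image, k=0, save_dir='debug', mode='train')  # writes mask_eval/debug/*.jpg
```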
/mars_train.py: -------------------------------------------------------------------------------- 1 | # system tool 2 | from __future__ import print_function, absolute_import 3 | import argparse 4 | import os 5 | import os.path as osp 6 | import sys 7 | 8 | # computation tool 9 | import torch 10 | import numpy as np 11 | 12 | # device tool 13 | import torch.backends.cudnn as cudnn 14 | from utils.logging import Logger 15 | from reid import models 16 | from utils.serialization import load_checkpoint, save_cnn_checkpoint, save_siamese_checkpoint 17 | from utils.serialization import remove_repeat_tensorboard_files 18 | from reid.loss import PairLoss, OIMLoss 19 | from reid.data import get_data 20 | from reid.train import SEQTrainer 21 | from reid.evaluator import ATTEvaluator 22 | 23 | 24 | def save_checkpoint(cnn_model, siamese_model, epoch, best_top1, is_best): 25 | save_cnn_checkpoint({ 26 | 'state_dict': cnn_model.state_dict(), 27 | 'epoch': epoch + 1, 28 | 'best_top1': best_top1, 29 | }, is_best, fpath=osp.join(args.logs_dir, 'cnn_checkpoint.pth.tar')) 30 | 31 | save_siamese_checkpoint({ 32 | 'state_dict': siamese_model.state_dict(), 33 | 'epoch': epoch + 1, 34 | 'best_top1': best_top1, 35 | }, is_best, fpath=osp.join(args.logs_dir, 'siamese_checkpoint.pth.tar')) 36 | 37 | 38 | def load_best_checkpoint(cnn_model, siamese_model): 39 | checkpoint0 = load_checkpoint(osp.join(args.logs_dir, 'cnnmodel_best.pth.tar')) 40 | cnn_model.load_state_dict(checkpoint0['state_dict']) 41 | 42 | checkpoint1 = load_checkpoint(osp.join(args.logs_dir, 'siamesemodel_best.pth.tar')) 43 | siamese_model.load_state_dict(checkpoint1['state_dict']) 44 | 45 | 46 | def main(args): 47 | 48 | np.random.seed(args.seed) 49 | torch.manual_seed(args.seed) 50 | torch.cuda.manual_seed_all(args.seed) 51 | cudnn.benchmark = True 52 | os.environ['CUDA_VISIBLE_DEVICES'] = '0,1' 53 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 54 | 55 | # log files: bump the run index so an existing log is never overwritten 56 | run = 0 57 | if args.evaluate == 1: 58 | while osp.exists("%s" % (osp.join(args.logs_dir, 'log_test{}.txt'.format(run)))): 59 | run += 1 60 | 61 | sys.stdout = Logger(osp.join(args.logs_dir, 'log_test{}.txt'.format(run))) 62 | else: 63 | while osp.exists("%s" % (osp.join(args.logs_dir, 'log_train{}.txt'.format(run)))): 64 | run += 1 65 | 66 | sys.stdout = Logger(osp.join(args.logs_dir, 'log_train{}.txt'.format(run))) 67 | print("==========\nArgs:{}\n==========".format(args)) 68 | 69 | # data loaders 70 | dataset, num_classes, train_loader, query_loader, gallery_loader = \ 71 | get_data(args.dataset, args.split, args.data_dir, 72 | args.batch_size, args.seq_len, args.seq_srd, 73 | args.workers, only_eval=False) 74 | 75 | # create model 76 | cnn_model = models.create(args.arch1, num_features=args.features, dropout=args.dropout, numclasses=num_classes) 77 | siamese_model = models.create(args.arch2, input_num=args.features, output_num=512, class_num=2) 78 | siamese_model_uncorr = models.create('siamese_video', input_num=2048, output_num=512, class_num=2) 79 | 80 | cnn_model = torch.nn.DataParallel(cnn_model).to(device) 81 | siamese_model = siamese_model.to(device) 82 | siamese_model_uncorr = siamese_model_uncorr.to(device) 83 | 84 | # Loss function 85 | criterion_corr = OIMLoss(2048, num_classes, scalar=args.oim_scalar, momentum=args.oim_momentum) 86 | criterion_uncorr = OIMLoss(2048, num_classes, scalar=args.oim_scalar, momentum=args.oim_momentum) 87 | criterion_veri = PairLoss() 88 | 89 | criterion_corr.to(device) 90 | criterion_uncorr.to(device) 91 | 
criterion_veri.to(device) 92 | 93 | # Optimizer 94 | base_param_ids = set(map(id, cnn_model.module.backbone.parameters())) 95 | new_params = [p for p in cnn_model.parameters() if 96 | id(p) not in base_param_ids] 97 | 98 | param_groups = [ 99 | {'params': cnn_model.module.backbone.parameters(), 'lr_mult': 1}, 100 | {'params': new_params, 'lr_mult': 2}, 101 | {'params': siamese_model.parameters(), 'lr_mult': 2}, 102 | {'params': siamese_model_uncorr.parameters(), 'lr_mult': 2} 103 | ] 104 | 105 | optimizer = torch.optim.SGD(param_groups, lr=args.lr, 106 | momentum=args.momentum, 107 | weight_decay=args.weight_decay, 108 | nesterov=True) 109 | 110 | def adjust_lr(epoch): 111 | lr = args.lr * (0.1 ** (epoch//args.lr_step))  # step decay: with lr=0.001, lr_step=15 this gives 1e-3 (epochs 0-14), 1e-4 (15-29), 1e-5 (30-44), ... 112 | print(lr) 113 | for g in optimizer.param_groups: 114 | g['lr'] = lr * g.get('lr_mult', 1) 115 | 116 | # Evaluator (testing) 117 | evaluator = ATTEvaluator(cnn_model, siamese_model, only_eval=False) 118 | best_top1 = 0 119 | if args.evaluate == 1: 120 | load_best_checkpoint(cnn_model, siamese_model) 121 | top1 = evaluator.evaluate(dataset.query, dataset.gallery, query_loader, gallery_loader, args.logs_dir, args.visual, args.rerank) 122 | print('best rank-1 accuracy is', top1) 123 | else: 124 | # Trainer: instantiate the trainer class 125 | tensorboard_train_logdir = osp.join(args.logs_dir, 'train_log') 126 | remove_repeat_tensorboard_files(tensorboard_train_logdir) 127 | 128 | trainer = SEQTrainer(cnn_model, siamese_model, siamese_model_uncorr, criterion_veri, criterion_corr, criterion_uncorr, 129 | tensorboard_train_logdir) 130 | for epoch in range(args.start_epoch, args.epochs): 131 | adjust_lr(epoch) 132 | trainer.train(epoch, train_loader, optimizer) 133 | 134 | # Evaluate every 5 epochs, at the final epoch, and every 3 epochs once past epoch 30. 135 | if (epoch+1) % 5 == 0 or (epoch+1) == args.epochs or ((epoch+1) > 30 and (epoch+1) % 3 == 0): 136 | top1 = evaluator.evaluate(dataset.query, dataset.gallery, query_loader, gallery_loader, args.logs_dir, args.visual, args.rerank) 137 | is_best = top1 > best_top1 138 | if is_best: 139 | best_top1 = top1 140 | save_checkpoint(cnn_model, siamese_model, epoch, best_top1, is_best) 141 | del top1 142 | torch.cuda.empty_cache() 143 | 144 | 145 | if __name__ == '__main__': 146 | parser = argparse.ArgumentParser(description="ID Training ResNet Model") 147 | 148 | # DATA 149 | parser.add_argument('-d', '--dataset', type=str, default='mars', 150 | choices=['ilidsvidsequence', 'prid2011sequence', 'mars', 'duke']) 151 | parser.add_argument('-b', '--batch-size', type=int, default=16) 152 | 153 | parser.add_argument('-j', '--workers', type=int, default=8) 154 | 155 | parser.add_argument('--seq_len', type=int, default=8) 156 | 157 | parser.add_argument('--seq_srd', type=int, default=4) 158 | 159 | parser.add_argument('--split', type=int, default=0) 160 | 161 | # MODEL 162 | # CNN model 163 | parser.add_argument('--arch1', type=str, default='resnet50_grl', 164 | choices=['resnet50_grl', 'resnet50']) 165 | parser.add_argument('--features', type=int, default=2048) 166 | parser.add_argument('--dropout', type=float, default=0.0) 167 | 168 | # Siamese model 169 | parser.add_argument('--arch2', type=str, default='siamese', 170 | choices=models.names()) 171 | 172 | # Criterion model 173 | parser.add_argument('--loss', type=str, default='oim', 174 | choices=['xentropy', 'oim', 'triplet']) 175 | parser.add_argument('--oim-scalar', type=float, default=30) 176 | parser.add_argument('--oim-momentum', type=float, default=0.5) 177 | parser.add_argument('--sampling-rate', type=int, default=3) 178 | parser.add_argument('--sample_method', 
type=str, default='rrs') 179 | 180 | # OPTIMIZER 181 | parser.add_argument('--seed', type=int, default=0) 182 | parser.add_argument('--lr', type=float, default=0.001) 183 | 184 | parser.add_argument('--lr_step', type=float, default=15) 185 | 186 | parser.add_argument('--momentum', type=float, default=0.9) 187 | parser.add_argument('--weight-decay', type=float, default=5e-4) 188 | parser.add_argument('--cnn_resume', type=str, default='', metavar='PATH') 189 | 190 | # TRAINER 191 | parser.add_argument('--start-epoch', type=int, default=0) 192 | parser.add_argument('--epochs', type=int, default=60) 193 | # EVAL 194 | parser.add_argument('--evaluate', type=int, default=0) 195 | parser.add_argument('--visual', type=int, default=0, help='visual the result') 196 | parser.add_argument('--rerank', type=int, default=0, help='rerank the result') 197 | # misc 198 | working_dir = osp.dirname(osp.abspath(__file__)) 199 | parser.add_argument('--data-dir', type=str, metavar='PATH', 200 | default='') 201 | parser.add_argument('--logs-dir', type=str, metavar='PATH', 202 | default=osp.join(working_dir, 'log/grl')) 203 | 204 | args = parser.parse_args() 205 | 206 | # main function 207 | main(args) 208 | -------------------------------------------------------------------------------- /reid/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flysnowtiger/GRL/26eed6542b636ab373dbc909d5d92b2d2b0cd7ae/reid/__init__.py -------------------------------------------------------------------------------- /reid/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flysnowtiger/GRL/26eed6542b636ab373dbc909d5d92b2d2b0cd7ae/reid/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /reid/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flysnowtiger/GRL/26eed6542b636ab373dbc909d5d92b2d2b0cd7ae/reid/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /reid/data/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from .sampler import * 3 | from .datasequence import Datasequence 4 | from .seqpreprocessor import SeqTrainPreprocessor 5 | from .seqpreprocessor import SeqTestPreprocessor 6 | from .dataloader import get_data 7 | # from .video_loader import VideoDataset 8 | -------------------------------------------------------------------------------- /reid/data/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flysnowtiger/GRL/26eed6542b636ab373dbc909d5d92b2d2b0cd7ae/reid/data/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /reid/data/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flysnowtiger/GRL/26eed6542b636ab373dbc909d5d92b2d2b0cd7ae/reid/data/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /reid/data/__pycache__/dataloader.cpython-36.pyc: 
-------------------------------------------------------------------------------- [binary __pycache__/*.pyc entries omitted] -------------------------------------------------------------------------------- /reid/data/__pycache__/video_loader.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/flysnowtiger/GRL/26eed6542b636ab373dbc909d5d92b2d2b0cd7ae/reid/data/__pycache__/video_loader.cpython-36.pyc -------------------------------------------------------------------------------- /reid/data/__pycache__/video_loader.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flysnowtiger/GRL/26eed6542b636ab373dbc909d5d92b2d2b0cd7ae/reid/data/__pycache__/video_loader.cpython-37.pyc -------------------------------------------------------------------------------- /reid/data/dataloader.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os.path as osp 3 | from torch.utils.data import DataLoader 4 | from reid.dataset import get_sequence 5 | from reid.data import seqtransforms as T 6 | from reid.data import SeqTrainPreprocessor 7 | from reid.data import SeqTestPreprocessor 8 | from reid.data import RandomPairSampler, RandomPairSamplerForMars 9 | from reid.data.video_loader import VideoDataset 10 | 11 | 12 | def get_data(dataset_name, split_id, data_dir, batch_size, seq_len, seq_srd, workers, only_eval): 13 | 14 | if dataset_name != 'mars' and dataset_name != 'duke': 15 | root = osp.join(data_dir, dataset_name) 16 | dataset = get_sequence(dataset_name, root, split_id=split_id, 17 | seq_len=seq_len, seq_srd=seq_srd, num_val=1, download=True) 18 | train_set = dataset.trainval 19 | num_classes = dataset.num_trainval_ids 20 | normalizer = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 21 | 22 | train_processor = SeqTrainPreprocessor(train_set, dataset, seq_len, 23 | transform=T.Compose([T.RectScale(256, 128), 24 | T.RandomHorizontalFlip(), 25 | T.RandomSizedEarser(), 26 | T.ToTensor(), normalizer])) 27 | 28 | query_processor = SeqTestPreprocessor(dataset.query, dataset, seq_len, 29 | transform=T.Compose([T.RectScale(256, 128), 30 | T.ToTensor(), normalizer])) 31 | 32 | gallery_processor = SeqTestPreprocessor(dataset.gallery, dataset, seq_len, 33 | transform=T.Compose([T.RectScale(256, 128), 34 | T.ToTensor(), normalizer])) 35 | 36 | train_loader = DataLoader(train_processor, batch_size=batch_size, num_workers=workers, 37 | sampler=RandomPairSampler(train_set), pin_memory=True, drop_last=True) 38 | 39 | query_loader = DataLoader(query_processor, batch_size=8, 40 | num_workers=workers, shuffle=False, pin_memory=True, drop_last=False) 41 | 42 | gallery_loader = DataLoader(gallery_processor, batch_size=8, 43 | num_workers=workers, shuffle=False, pin_memory=True, drop_last=False) 44 | 45 | else: 46 | dataset = get_sequence(dataset_name) # mars数据集 47 | train_set = dataset.train # 8298 48 | 49 | num_classes = dataset.num_train_pids # 625 50 | 51 | normalizer = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 52 | 53 | train_processor = VideoDataset(train_set, seq_len=seq_len, sample='rrs_train', 54 | transform=T.Compose([T.RectScale(256, 128), 55 | T.RandomHorizontalFlip(), 56 | T.RandomSizedEarser(), 57 | T.ToTensor(), normalizer])) 58 | 59 | 60 | if only_eval: 61 | sampler_method = 'dense' 62 | batch_size_eval = 1 63 | else: 64 | sampler_method = 'rrs_test' 65 | batch_size_eval = 30 66 | query_processor = VideoDataset(dataset.query, seq_len=seq_len, sample=sampler_method, 67 | transform=T.Compose([T.RectScale(256, 128), 68 | T.ToTensor(), normalizer])) 69 | 70 | gallery_processor = 
VideoDataset(dataset.gallery, seq_len=seq_len, sample=sampler_method, 71 | transform=T.Compose([T.RectScale(256, 128), 72 | T.ToTensor(), normalizer])) 73 | 74 | train_loader = DataLoader(train_processor, batch_size=batch_size, num_workers=workers, 75 | sampler=RandomPairSamplerForMars(train_set), pin_memory=True, drop_last=True) 76 | 77 | query_loader = DataLoader(query_processor, batch_size=batch_size_eval, shuffle=False, pin_memory=True, drop_last=False) 78 | 79 | gallery_loader = DataLoader(gallery_processor, batch_size=batch_size_eval, shuffle=False, pin_memory=True, drop_last=False) 80 | 81 | return dataset, num_classes, train_loader, query_loader, gallery_loader 82 | -------------------------------------------------------------------------------- /reid/data/datasequence.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os.path as osp 3 | import numpy as np 4 | from utils.serialization import read_json 5 | import torch 6 | 7 | 8 | def _pluckseq(identities, indices, seq_len, seq_str): 9 | ret = [] 10 | for index, pid in enumerate(indices): 11 | pid_images = identities[pid] 12 | for camid, cam_images in enumerate(pid_images): 13 | seqall = len(cam_images) 14 | seq_inds = [(start_ind, start_ind + seq_len)\ 15 | for start_ind in range(0, seqall-seq_len, seq_str)] 16 | 17 | if not seq_inds: 18 | seq_inds = [(0, seqall)] 19 | for seq_ind in seq_inds: 20 | ret.append((seq_ind[0], seq_ind[1], pid, index, camid)) 21 | return ret 22 | 23 | 24 | class Datasequence(object): 25 | def __init__(self, root, split_id=0): 26 | self.root = root 27 | self.split_id = split_id 28 | self.meta = None 29 | self.split = None 30 | self.train, self.val, self.trainval = [], [], [] 31 | self.query, self.gallery = [], [] 32 | self.num_train_ids, self.num_val_ids, self.num_trainval_ids = 0, 0, 0 33 | self.identities = [] 34 | 35 | @property 36 | def images_dir(self): 37 | return osp.join(self.root, 'images') 38 | 39 | def load(self, seq_len, seq_str, num_val=0.3, verbose=True): 40 | splits = read_json(osp.join(self.root, 'splits.json')) # 根据splits.json文件 41 | if self.split_id >= len(splits): 42 | raise ValueError("split_id exceeds total splits {}" 43 | .format(len(splits))) 44 | 45 | self.split = splits[self.split_id] 46 | 47 | # Randomly split train / val 48 | trainval_pids = np.asarray(self.split['trainval']) # 100个元素的数组 49 | np.random.shuffle(trainval_pids) 50 | num = len(trainval_pids) # 100 51 | 52 | if isinstance(num_val, float): 53 | num_val = int(round(num * num_val)) 54 | if num_val >= num or num_val < 0: 55 | raise ValueError("num_val exceeds total identities {}" 56 | .format(num)) 57 | 58 | train_pids = sorted(trainval_pids[:-num_val]) # 99 59 | val_pids = sorted(trainval_pids[-num_val:]) # 1 60 | 61 | # comments validation set changes every time it loads 62 | 63 | self.meta = read_json(osp.join(self.root, 'meta.json')) # 字典 64 | identities = self.meta['identities'] 65 | self.identities = identities 66 | self.train = _pluckseq(identities, train_pids, seq_len, seq_str) # 这里确定tracklets的长度 67 | self.val = _pluckseq(identities, val_pids, seq_len, seq_str) 68 | self.trainval = _pluckseq(identities, trainval_pids, seq_len, seq_str) 69 | # res = len(self.trainval) % 4 70 | # length1 = len(self.trainval) - res 71 | # length2 = len(self.val) - res 72 | # self.val = self.val[0:length2] 73 | # self.trainval = self.trainval[0:length1] 74 | self.num_train_ids = len(train_pids) 75 | self.num_val_ids = len(val_pids) 76 | 
self.num_trainval_ids = len(trainval_pids) 77 | 78 | if verbose: 79 | print(self.__class__.__name__, "dataset loaded") 80 | print(" subset | # ids | # sequences") 81 | print(" ---------------------------") 82 | print(" train | {:5d} | {:8d}" 83 | .format(self.num_train_ids, len(self.train))) 84 | print(" val | {:5d} | {:8d}" 85 | .format(self.num_val_ids, len(self.val))) 86 | print(" trainval | {:5d} | {:8d}" 87 | .format(self.num_trainval_ids, len(self.trainval))) 88 | print(" query | {:5d} | {:8d}" 89 | .format(len(self.split['query']), len(self.split['query']))) 90 | print(" gallery | {:5d} | {:8d}" 91 | .format(len(self.split['gallery']), len(self.split['gallery']))) 92 | 93 | def _check_integrity(self): 94 | return osp.isdir(osp.join(self.root, 'images')) and \ 95 | osp.isfile(osp.join(self.root, 'meta.json')) and \ 96 | osp.isfile(osp.join(self.root, 'splits.json')) 97 | -------------------------------------------------------------------------------- /reid/data/sampler.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from collections import defaultdict 3 | 4 | import numpy as np 5 | import torch 6 | 7 | from torch.utils.data.sampler import ( 8 | Sampler, SequentialSampler, RandomSampler, SubsetRandomSampler, 9 | WeightedRandomSampler) 10 | 11 | 12 | def No_index(a, b): # a = : [2, 2, 3, 4, 4, 4] b = 4 13 | assert isinstance(a, list) 14 | return [i for i, j in enumerate(a) if j != b] 15 | 16 | 17 | class RandomIdentitySampler(Sampler): 18 | 19 | def __init__(self, data_source, num_instances=1): 20 | self.data_source = data_source 21 | self.num_instances = num_instances 22 | self.index_dic = defaultdict(list) 23 | for index, (_, pid, _) in enumerate(data_source): 24 | self.index_dic[pid].append(index) 25 | self.pids = list(self.index_dic.keys()) 26 | self.num_samples = len(data_source) 27 | 28 | def __len__(self): 29 | return self.num_samples * self.num_instances 30 | 31 | def __iter__(self): 32 | indices = torch.randperm(self.num_samples) 33 | ret = [] 34 | for i in indices: 35 | pid = self.pids[i] 36 | t = self.index_dic[pid] 37 | if len(t) >= self.num_instances: 38 | t = np.random.choice(t, size=self.num_instances, replace=False) 39 | else: 40 | t = np.random.choice(t, size=self.num_instances, replace=True) 41 | ret.extend(t) 42 | return iter(ret) 43 | 44 | 45 | class RandomPairSampler(Sampler): 46 | def __init__(self, data_source): 47 | self.data_source = data_source 48 | self.index_pid = defaultdict(int) 49 | self.pid_cam = defaultdict(list) 50 | self.pid_index = defaultdict(list) 51 | self.num_samples = len(data_source) 52 | for index, (_, _, _, pid, cam) in enumerate(data_source): 53 | self.index_pid[index] = pid 54 | self.pid_cam[pid].append(cam) 55 | self.pid_index[pid].append(index) 56 | 57 | def __len__(self): 58 | return self.num_samples * 2 59 | 60 | def __iter__(self): 61 | indices = torch.randperm(self.num_samples) 62 | ret = [] 63 | for i in indices: 64 | i = int(i) 65 | _, _, i_label, i_pid, i_cam = self.data_source[i] # relabel ? 
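# For each randomly drawn anchor tracklet i, the sampler appends i and then one
# positive tracklet of the same identity taken from a different camera
# (No_index filters out entries matching the anchor's camera i_cam), so the
# loader emits anchor/positive pairs back-to-back and __len__ is 2 * num_samples.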
66 | ret.append(i) 67 | pid_i = self.index_pid[i] 68 | cams = self.pid_cam[pid_i] 69 | index = self.pid_index[pid_i] 70 | select_cams = No_index(cams, i_cam) 71 | try: 72 | select_camind = np.random.choice(select_cams) 73 | except ValueError: 74 | print(cams) 75 | print(pid_i) 76 | print(i_label) 77 | select_ind = index[select_camind] 78 | ret.append(select_ind) 79 | 80 | return iter(ret) 81 | 82 | 83 | class RandomPairSamplerForMars(Sampler): 84 | def __init__(self, data_source): 85 | self.data_source = data_source 86 | self.index_pid = defaultdict(int) 87 | self.pid_cam = defaultdict(list) 88 | self.pid_index = defaultdict(list) 89 | self.num_samples = len(data_source) 90 | for index, (_, pid, cam) in enumerate(data_source): 91 | self.index_pid[index] = pid 92 | self.pid_cam[pid].append(cam) 93 | self.pid_index[pid].append(index) 94 | 95 | def __len__(self): 96 | return self.num_samples * 2 97 | 98 | def __iter__(self): 99 | indices = torch.randperm(self.num_samples) # 8298 打乱顺序 100 | ret = [] 101 | for i in indices: 102 | i = int(i) # 第3367行 103 | _, i_pid, i_cam = self.data_source[i] # relabel ? 104 | ret.append(i) # [3367] 105 | pid_i = self.index_pid[i] # pid_i = 182 106 | cams = self.pid_cam[pid_i] # : [2, 2, 3, 4, 4, 4] 107 | index = self.pid_index[pid_i] # : [3363, 3364, 3365, 3366, 3367, 3368] 108 | if len(set(cams)) == 1: # 只有1个cam 109 | if len(index) == 1: # 只有1个cam并且只有一个tracklet 110 | select_camind = 0 111 | else: 112 | select_cams = No_index(index, i) 113 | select_camind = np.random.choice(select_cams) 114 | else: 115 | select_cams = No_index(cams, i_cam) 116 | try: 117 | select_camind = np.random.choice(select_cams) 118 | except ValueError: 119 | print(cams) 120 | print(pid_i) 121 | # print(i_label) 122 | select_ind = index[select_camind] 123 | ret.append(select_ind) 124 | 125 | return iter(ret) 126 | -------------------------------------------------------------------------------- /reid/data/samplers.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import copy 3 | import random 4 | import torch 5 | from collections import defaultdict 6 | 7 | import numpy as np 8 | from torch.utils.data.sampler import Sampler 9 | 10 | class RandomIdentitySampler(Sampler): 11 | """ 12 | Randomly sample N identities, then for each identity, 13 | randomly sample K instances, therefore batch size is N*K. 14 | 15 | Code imported from https://github.com/Cysu/open-reid/blob/master/reid/utils/data/sampler.py. 16 | 17 | Args: 18 | data_source (Dataset): dataset to sample from. 19 | num_instances (int): number of instances per identity. 
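Example (a sketch; assumes `train_set` yields (img_paths, pid, camid, sid)
tuples, matching the unpacking in the constructor below):
    sampler = RandomIdentitySampler(train_set, num_instances=4)
    loader = DataLoader(train_set, batch_size=16, sampler=sampler)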
20 | """ 21 | def __init__(self, data_source, num_instances=4): 22 | super(RandomIdentitySampler).__init__() 23 | self.data_source = data_source 24 | self.num_instances = num_instances 25 | self.index_dic = defaultdict(list) 26 | for index, (_, pid, _, sid) in enumerate(data_source): 27 | self.index_dic[pid].append(index) 28 | self.pids = list(self.index_dic.keys()) 29 | self.num_identities = len(self.pids) 30 | 31 | def __iter__(self): 32 | indices = torch.randperm(self.num_identities) 33 | ret = [] 34 | for i in indices: 35 | pid = self.pids[i] 36 | t = self.index_dic[pid] 37 | replace = False if len(t) >= self.num_instances else True 38 | t = np.random.choice(t, size=self.num_instances, replace=replace) 39 | ret.extend(t) 40 | # print(ret) 41 | return iter(ret) 42 | 43 | def __len__(self): 44 | return self.num_identities * self.num_instances 45 | class RandomIdentitySamplerStrongBasaline(Sampler): 46 | """ 47 | Randomly sample N identities, then for each identity, 48 | randomly sample K instances, therefore batch size is N*K. 49 | Args: 50 | - data_source (list): list of (img_path, pid, camid). 51 | - num_instances (int): number of instances per identity in a batch. 52 | - batch_size (int): number of examples in a batch. 53 | """ 54 | 55 | def __init__(self, data_source, batch_size, num_instances): 56 | self.data_source = data_source 57 | self.batch_size = batch_size 58 | self.num_instances = num_instances 59 | self.num_pids_per_batch = self.batch_size // self.num_instances 60 | self.index_dic = defaultdict(list) 61 | for index, (_, pid, _) in enumerate(self.data_source): 62 | self.index_dic[pid].append(index) 63 | self.pids = list(self.index_dic.keys()) 64 | 65 | # estimate number of examples in an epoch 66 | self.length = 0 67 | for pid in self.pids: 68 | idxs = self.index_dic[pid] 69 | num = len(idxs) 70 | if num < self.num_instances: 71 | num = self.num_instances 72 | self.length += num - num % self.num_instances 73 | 74 | def __iter__(self): 75 | batch_idxs_dict = defaultdict(list) 76 | 77 | for pid in self.pids: 78 | idxs = copy.deepcopy(self.index_dic[pid]) 79 | if len(idxs) < self.num_instances: 80 | idxs = np.random.choice(idxs, size=self.num_instances, replace=True) 81 | random.shuffle(idxs) 82 | batch_idxs = [] 83 | for idx in idxs: 84 | batch_idxs.append(idx) 85 | if len(batch_idxs) == self.num_instances: 86 | batch_idxs_dict[pid].append(batch_idxs) 87 | batch_idxs = [] 88 | 89 | avai_pids = copy.deepcopy(self.pids) 90 | final_idxs = [] 91 | 92 | while len(avai_pids) >= self.num_pids_per_batch: 93 | selected_pids = random.sample(avai_pids, self.num_pids_per_batch) 94 | for pid in selected_pids: 95 | batch_idxs = batch_idxs_dict[pid].pop(0) 96 | final_idxs.extend(batch_idxs) 97 | if len(batch_idxs_dict[pid]) == 0: 98 | avai_pids.remove(pid) 99 | 100 | self.length = len(final_idxs) 101 | return iter(final_idxs) 102 | 103 | def __len__(self): 104 | return self.length 105 | 106 | 107 | # New add by gu 108 | class RandomIdentitySampler_alignedreid(Sampler): 109 | """ 110 | Randomly sample N identities, then for each identity, 111 | randomly sample K instances, therefore batch size is N*K. 112 | 113 | Code imported from https://github.com/Cysu/open-reid/blob/master/reid/utils/data/sampler.py. 114 | 115 | Args: 116 | data_source (Dataset): dataset to sample from. 117 | num_instances (int): number of instances per identity. 
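Example (a sketch; here `data_source` yields (img_path, pid, camid)
triplets, matching the constructor below):
    sampler = RandomIdentitySampler_alignedreid(train_set, num_instances=4)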
118 | """ 119 | def __init__(self, data_source, num_instances): 120 | self.data_source = data_source 121 | self.num_instances = num_instances 122 | self.index_dic = defaultdict(list) 123 | for index, (_, pid, _) in enumerate(data_source): 124 | self.index_dic[pid].append(index) 125 | self.pids = list(self.index_dic.keys()) 126 | self.num_identities = len(self.pids) 127 | 128 | def __iter__(self): 129 | indices = torch.randperm(self.num_identities) 130 | ret = [] 131 | for i in indices: 132 | pid = self.pids[i] 133 | t = self.index_dic[pid] 134 | replace = False if len(t) >= self.num_instances else True 135 | t = np.random.choice(t, size=self.num_instances, replace=replace) 136 | ret.extend(t) 137 | return iter(ret) 138 | 139 | def __len__(self): 140 | return self.num_identities * self.num_instances -------------------------------------------------------------------------------- /reid/data/seqpreprocessor.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import os.path as osp 3 | import torch 4 | from PIL import Image 5 | 6 | 7 | class SeqTrainPreprocessor(object): 8 | def __init__(self, seqset, dataset, seq_len, transform=None): 9 | super(SeqTrainPreprocessor, self).__init__() 10 | self.seqset = seqset 11 | self.identities = dataset.identities 12 | self.transform = transform 13 | self.seq_len = seq_len 14 | self.root = [dataset.images_dir] 15 | self.root.append(dataset.other_dir) 16 | 17 | def __len__(self): 18 | return len(self.seqset) 19 | 20 | def __getitem__(self, indices): 21 | if isinstance(indices, (tuple, list)): 22 | return [self._get_single_item(index) for index in indices] 23 | return self._get_single_item(indices) 24 | 25 | def _get_single_item(self, index): 26 | 27 | start_ind, end_ind, pid, label, camid = self.seqset[index] 28 | 29 | imgseq = [] 30 | flowseq = [] 31 | for ind in range(start_ind, end_ind): 32 | fname = self.identities[pid][camid][ind] 33 | fpath_img = osp.join(self.root[0], fname) 34 | imgrgb = Image.open(fpath_img).convert('RGB') 35 | fpath_flow = osp.join(self.root[1], fname) 36 | flowrgb = Image.open(fpath_flow).convert('RGB') 37 | imgseq.append(imgrgb) 38 | flowseq.append(flowrgb) 39 | 40 | while len(imgseq) < self.seq_len: 41 | imgseq.append(imgrgb) 42 | flowseq.append(flowrgb) 43 | 44 | seq = [imgseq, flowseq] 45 | 46 | if self.transform is not None: 47 | seq = self.transform(seq) 48 | 49 | img_tensor = torch.stack(seq[0], 0) 50 | 51 | flow_tensor = torch.stack(seq[1], 0) 52 | 53 | return img_tensor, flow_tensor, label, camid 54 | 55 | 56 | class SeqTestPreprocessor(object): 57 | 58 | def __init__(self, seqset, dataset, seq_len, transform=None): 59 | super(SeqTestPreprocessor, self).__init__() 60 | self.seqset = seqset 61 | self.identities = dataset.identities 62 | self.transform = transform 63 | self.seq_len = seq_len 64 | self.root = [dataset.images_dir] 65 | self.root.append(dataset.other_dir) 66 | 67 | def __len__(self): 68 | return len(self.seqset) 69 | 70 | def __getitem__(self, indices): 71 | if isinstance(indices, (tuple, list)): 72 | return [self._get_single_item(index) for index in indices] 73 | return self._get_single_item(indices) 74 | 75 | def _get_single_item(self, index): 76 | 77 | start_ind, end_ind, pid, label, camid = self.seqset[index] 78 | 79 | imgseq = [] 80 | flowseq = [] 81 | for ind in range(start_ind, end_ind): 82 | fname = self.identities[pid][camid][ind] 83 | fpath_img = osp.join(self.root[0], fname) 84 | imgrgb = 
Image.open(fpath_img).convert('RGB') 85 | fpath_flow = osp.join(self.root[1], fname) 86 | flowrgb = Image.open(fpath_flow).convert('RGB') 87 | imgseq.append(imgrgb) 88 | flowseq.append(flowrgb) 89 | 90 | while len(imgseq) < self.seq_len: 91 | imgseq.append(imgrgb) 92 | flowseq.append(flowrgb) 93 | 94 | seq = [imgseq, flowseq] 95 | 96 | if self.transform is not None: 97 | seq = self.transform(seq) 98 | 99 | img_tensor = torch.stack(seq[0], 0) 100 | 101 | if len(self.root) == 2: 102 | flow_tensor = torch.stack(seq[1], 0) 103 | else: 104 | flow_tensor = None 105 | 106 | return img_tensor, flow_tensor, pid, camid 107 | -------------------------------------------------------------------------------- /reid/data/seqtransforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import random 4 | from PIL import Image, ImageOps 5 | import numpy as np 6 | 7 | 8 | class Compose(object): 9 | """Composes several transforms together. 10 | 11 | Args: 12 | transforms (List[Transform]): list of transforms to compose. 13 | 14 | Example: 15 | >>> transforms.Compose([ 16 | >>> transforms.CenterCrop(10), 17 | >>> transforms.ToTensor(), 18 | >>> ]) 19 | """ 20 | 21 | def __init__(self, transforms): 22 | self.transforms = transforms 23 | 24 | def __call__(self, seqs): 25 | for t in self.transforms: 26 | seqs = t(seqs) 27 | return seqs 28 | 29 | 30 | class RectScale(object): 31 | def __init__(self, height, width, interpolation=Image.BILINEAR): 32 | self.height = height 33 | self.width = width 34 | self.interpolation = interpolation 35 | 36 | def __call__(self, seqs): # seqs = list[[image0,...,image8]] 37 | modallen = len(seqs) # 1个list 38 | framelen = len(seqs[0]) # 8 39 | new_seqs = [[[] for _ in range(framelen)] for _ in range(modallen)] # : [[[], [], [], [], [], [], [], []]] 40 | 41 | for modal_ind, modal in enumerate(seqs): # 遍历modal,总共有1个 42 | for frame_ind, frame in enumerate(modal): # 遍历每一帧图片 43 | w, h = frame.size # w:128,h:256 44 | if h == self.height and w == self.width: 45 | new_seqs[modal_ind][frame_ind] = frame 46 | else: 47 | new_seqs[modal_ind][frame_ind] = frame.resize((self.width, self.height), self.interpolation) 48 | 49 | return new_seqs 50 | 51 | 52 | class RandomSizedRectCrop(object): 53 | def __init__(self, height, width, interpolation=Image.BILINEAR): 54 | self.height = height 55 | self.width = width 56 | self.interpolation = interpolation 57 | 58 | def __call__(self, seqs): 59 | sample_img = seqs[0][0] 60 | for attempt in range(10): 61 | area = sample_img.size[0] * sample_img.size[1] 62 | target_area = random.uniform(0.64, 1.0) * area 63 | aspect_ratio = random.uniform(2, 3) 64 | 65 | h = int(round(math.sqrt(target_area * aspect_ratio))) 66 | w = int(round(math.sqrt(target_area / aspect_ratio))) 67 | 68 | if w <= sample_img.size[0] and h <= sample_img.size[1]: 69 | x1 = random.randint(0, sample_img.size[0] - w) 70 | y1 = random.randint(0, sample_img.size[1] - h) 71 | 72 | sample_img = sample_img.crop((x1, y1, x1 + w, y1 + h)) 73 | assert (sample_img.size == (w, h)) 74 | modallen = len(seqs) 75 | framelen = len(seqs[0]) 76 | new_seqs = [[[] for _ in range(framelen)] for _ in range(modallen)] 77 | 78 | for modal_ind, modal in enumerate(seqs): 79 | for frame_ind, frame in enumerate(modal): 80 | 81 | frame = frame.crop((x1, y1, x1 + w, y1 + h)) 82 | new_seqs[modal_ind][frame_ind] = frame.resize((self.width, self.height), self.interpolation) 83 | 84 | return new_seqs 85 | 86 | # Fallback 87 | scale = RectScale(self.height, 
self.width, 88 | interpolation=self.interpolation) 89 | return scale(seqs) 90 | 91 | 92 | class RandomSizedEarser(object): 93 | 94 | def __init__(self, sl=0.02, sh=0.2, asratio=0.3, p=0.5): 95 | self.sl = sl 96 | self.sh = sh 97 | self.asratio = asratio 98 | self.p = p 99 | 100 | def __call__(self, seqs): 101 | modallen = len(seqs) 102 | framelen = len(seqs[0]) 103 | new_seqs = [[[] for _ in range(framelen)] for _ in range(modallen)] 104 | for modal_ind, modal in enumerate(seqs): 105 | for frame_ind, frame in enumerate(modal): 106 | p1 = random.uniform(0.0, 1.0) 107 | W = frame.size[0] 108 | H = frame.size[1] 109 | area = H * W 110 | 111 | if p1 > self.p: 112 | new_seqs[modal_ind][frame_ind] = frame 113 | else: 114 | gen = True 115 | while gen: 116 | Se = random.uniform(self.sl, self.sh) * area 117 | re = random.uniform(self.asratio, 1 / self.asratio) 118 | He = np.sqrt(Se * re) 119 | We = np.sqrt(Se / re) 120 | xe = random.uniform(0, W - We) 121 | ye = random.uniform(0, H - He) 122 | if xe + We <= W and ye + He <= H and xe > 0 and ye > 0: 123 | x1 = int(np.ceil(xe)) 124 | y1 = int(np.ceil(ye)) 125 | x2 = int(np.floor(x1 + We)) 126 | y2 = int(np.floor(y1 + He)) 127 | part1 = frame.crop((x1, y1, x2, y2)) 128 | Rc = random.randint(0, 255) 129 | Gc = random.randint(0, 255) 130 | Bc = random.randint(0, 255) 131 | I = Image.new('RGB', part1.size, (Rc, Gc, Bc)) 132 | frame.paste(I, part1.size) 133 | break 134 | 135 | new_seqs[modal_ind][frame_ind] = frame 136 | 137 | return new_seqs 138 | 139 | 140 | class RandomHorizontalFlip(object): 141 | """Randomly horizontally flips the given PIL.Image Sequence with a probability of 0.5 142 | """ 143 | def __call__(self, seqs): 144 | if random.random() < 0.5: 145 | modallen = len(seqs) 146 | framelen = len(seqs[0]) 147 | new_seqs = [[[] for _ in range(framelen)] for _ in range(modallen)] 148 | for modal_ind, modal in enumerate(seqs): 149 | for frame_ind, frame in enumerate(modal): 150 | new_seqs[modal_ind][frame_ind] = frame.transpose(Image.FLIP_LEFT_RIGHT) 151 | return new_seqs 152 | return seqs 153 | 154 | 155 | class ToTensor(object): 156 | 157 | def __call__(self, seqs): 158 | modallen = len(seqs) 159 | framelen = len(seqs[0]) 160 | new_seqs = [[[] for _ in range(framelen)] for _ in range(modallen)] 161 | pic = seqs[0][0] 162 | 163 | # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK 164 | if pic.mode == 'YCbCr': 165 | nchannel = 3 166 | elif pic.mode == 'I;16': 167 | nchannel = 1 168 | else: 169 | nchannel = len(pic.mode) 170 | 171 | if pic.mode == 'I': 172 | for modal_ind, modal in enumerate(seqs): 173 | for frame_ind, frame in enumerate(modal): 174 | img = torch.from_numpy(np.array(frame, np.int32, copy=False)) 175 | img = img.view(pic.size[1], pic.size[0], nchannel) 176 | new_seqs[modal_ind][frame_ind] = img.transpose(0, 1).transpose(0, 2).contiguous() 177 | 178 | elif pic.mode == 'I;16': 179 | for modal_ind, modal in enumerate(seqs): 180 | for frame_ind, frame in enumerate(modal): 181 | img = torch.from_numpy(np.array(frame, np.int16, copy=False)) 182 | img = img.view(pic.size[1], pic.size[0], nchannel) 183 | new_seqs[modal_ind][frame_ind] = img.transpose(0, 1).transpose(0, 2).contiguous() 184 | else: 185 | for modal_ind, modal in enumerate(seqs): 186 | for frame_ind, frame in enumerate(modal): 187 | img = torch.ByteTensor(torch.ByteStorage.from_buffer(frame.tobytes())) 188 | img = img.view(pic.size[1], pic.size[0], nchannel) 189 | img = img.transpose(0, 1).transpose(0, 2).contiguous() 190 | new_seqs[modal_ind][frame_ind] = 
img.float().div(255) 191 | 192 | return new_seqs 193 | 194 | 195 | class Normalize(object): 196 | """Given mean: (R, G, B) and std: (R, G, B), 197 | will normalize each channel of the torch.*Tensor, i.e. 198 | channel = (channel - mean) / std 199 | """ 200 | def __init__(self, mean, std): 201 | self.mean = mean 202 | self.std = std 203 | 204 | def __call__(self, seqs): 205 | # TODO: make efficient 206 | modallen = len(seqs) 207 | framelen = len(seqs[0]) 208 | new_seqs = [[[] for _ in range(framelen)] for _ in range(modallen)] 209 | 210 | for modal_ind, modal in enumerate(seqs): 211 | for frame_ind, frame in enumerate(modal): 212 | for t, m, s in zip(frame, self.mean, self.std): 213 | t.sub_(m).div_(s) 214 | new_seqs[modal_ind][frame_ind] = frame 215 | 216 | return new_seqs 217 | -------------------------------------------------------------------------------- /reid/data/video_loader.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os 3 | from PIL import Image 4 | import numpy as np 5 | import math 6 | import torch 7 | from torch.utils.data import Dataset 8 | import random 9 | 10 | 11 | class VideoDataset(Dataset): 12 | """Video Person ReID Dataset. 13 | Note batch data has shape (batch, seq_len, channel, height, width). 14 | """ 15 | sample_methods = ['evenly', 'random', 'dense'] 16 | 17 | def __init__(self, dataset, seq_len=15, sample='evenly', transform=None): 18 | self.dataset = dataset 19 | self.seq_len = seq_len 20 | self.sample = sample 21 | self.transform = transform 22 | 23 | def __len__(self): 24 | return len(self.dataset) 25 | 26 | def __getitem__(self, indices): 27 | if isinstance(indices, (tuple, list)): 28 | return [self.__get_single_item__(index) for index in indices] 29 | return self.__get_single_item__(indices) 30 | 31 | def __get_single_item__(self, index): 32 | S = self.seq_len 33 | img_paths, pid, camid = self.dataset[index] 34 | num = len(img_paths) # 27 35 | """rss 操作""" 36 | sample_clip = [] 37 | frame_indices = list(range(num)) 38 | if num < S: # 8 = chunk的数目,每个tracklet分成8段,每段随机选一帧 39 | strip = list(range(num)) + [frame_indices[-1]] * (S - num) 40 | for s in range(S): 41 | pool = strip[s * 1:(s + 1) * 1] 42 | sample_clip.append(list(pool)) 43 | else: 44 | inter_val = math.ceil(num / S) 45 | strip = list(range(num)) + [frame_indices[-1]] * (inter_val * S - num) 46 | for s in range(S): 47 | pool = strip[inter_val * s:inter_val * (s + 1)] 48 | sample_clip.append(list(pool)) 49 | 50 | sample_clip = np.array(sample_clip) 51 | 52 | if self.sample == 'random': 53 | """ 54 | Randomly sample seq_len consecutive frames from num frames, 55 | if num is smaller than seq_len, then replicate items. 56 | This sampling strategy is used in training phase. 
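For example, num=5 and seq_len=8 yields frame indices [0, 1, 2, 3, 4, 0, 1, 2]:
the clip is padded by replaying frames from its start until seq_len is reached.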
57 | """ 58 | frame_indices = list(range(num)) 59 | rand_end = max(0, len(frame_indices) - self.seq_len - 1) 60 | begin_index = random.randint(0, rand_end) 61 | end_index = min(begin_index + self.seq_len, len(frame_indices)) 62 | 63 | indices = frame_indices[begin_index:end_index] 64 | 65 | for index in indices: 66 | if len(indices) >= self.seq_len: 67 | break 68 | indices.append(index) 69 | indices = np.array(indices) 70 | imgseq = [] 71 | for index in indices: 72 | index = int(index) 73 | img_path = img_paths[index] 74 | img = Image.open(img_path).convert('RGB') # 3x224x112 75 | imgseq.append(img) 76 | 77 | seq = [imgseq] 78 | if self.transform is not None: 79 | seq = self.transform(seq) 80 | 81 | img_tensor = torch.stack(seq[0], dim=0) # seq_len 4x3x224x112 82 | flow_tensor = None 83 | 84 | return img_tensor, pid, camid 85 | 86 | elif self.sample == 'dense': 87 | """ 88 | Sample all frames in a video into a list of clips, each clip contains seq_len frames, batch_size needs to be set to 1. 89 | This sampling strategy is used in test phase. 90 | """ 91 | cur_index = 0 92 | frame_indices = list(range(num)) # 27 93 | indices_list = [] 94 | while num-cur_index > self.seq_len: 95 | indices_list.append(frame_indices[cur_index:cur_index+self.seq_len]) 96 | cur_index += self.seq_len 97 | 98 | last_seq = frame_indices[cur_index:] 99 | 100 | for index in last_seq: 101 | if len(last_seq) >= self.seq_len: 102 | break 103 | last_seq.append(index) 104 | 105 | indices_list.append(last_seq) # : [[0, 1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13, 14, 15], [16, 17, 18, 19, 20, 21, 22, 23], [24, 25, 24, 25, 24, 25, 24, 25]] 106 | imgs_list = [] 107 | for indices in indices_list: # : [0, 1, 2, 3, 4, 5, 6, 7] 108 | imgs = [] 109 | for index in indices: 110 | index = int(index) 111 | img_path = img_paths[index] 112 | img = Image.open(img_path).convert('RGB') 113 | # img = img.unsqueeze(0) 114 | imgs.append(img) 115 | 116 | imgs = [imgs] 117 | if self.transform is not None: 118 | imgs = self.transform(imgs) 119 | imgs = torch.stack(imgs[0], 0) # torch.Size([8, 3, 224, 112]) 120 | imgs_list.append(imgs) 121 | imgs_tensor = torch.stack(imgs_list) # torch.Size([13, 8, 3, 224, 112]) 122 | flow_tensor = None 123 | return imgs_tensor, pid, camid 124 | elif self.sample == 'rrs_train': 125 | idx = np.random.choice(sample_clip.shape[1], sample_clip.shape[0]) 126 | number = sample_clip[np.arange(len(sample_clip)), idx] 127 | # imgseq = [] 128 | img_paths = np.array(list(img_paths)) # img_paths原始为tuple,转换成数组 129 | # flow_paths = np.array([img_path.replace('Mars', 'Mars_optical') for img_path in img_paths]) 130 | imgseq = [Image.open(img_path).convert('RGB') for img_path in img_paths[number]] 131 | # flowseq = [Image.open(flow_path).convert('RGB') for flow_path in flow_paths[number]] 132 | 133 | seq = [imgseq] 134 | # seq = [imgseq, flowseq] 135 | if self.transform is not None: 136 | seq = self.transform(seq) 137 | 138 | img_tensor = torch.stack(seq[0], dim=0) # seq_len 4x3x224x112 139 | # flow_tensor = torch.stack(seq[1], dim=0) # seq_len 4x3x224x112 140 | 141 | return img_tensor, pid, camid 142 | elif self.sample == 'rrs_test': 143 | number = sample_clip[:, 0] 144 | img_paths = np.array(list(img_paths)) # img_paths原始为tuple,转换成数组 145 | # flow_paths = np.array([img_path.replace('Mars', 'Mars_optical') for img_path in img_paths]) 146 | imgseq = [Image.open(img_path).convert('RGB') for img_path in img_paths[number]] 147 | # flowseq = [Image.open(flow_path).convert('RGB') for flow_path in flow_paths[number]] 148 | 149 | seq = 
[imgseq] 150 | # seq = [imgseq, flowseq] 151 | if self.transform is not None: 152 | seq = self.transform(seq) 153 | img_tensor = torch.stack(seq[0], dim=0) # torch.Size([8, 3, 256, 128]) 154 | # flow_tensor = torch.stack(seq[1], dim=0) 155 | return img_tensor, pid, camid 156 | else: 157 | raise KeyError("Unknown sample method: {}. Expected one of {}".format(self.sample, self.sample_methods)) 158 | -------------------------------------------------------------------------------- /reid/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from .ilidsvidsequence import iLIDSVIDSEQUENCE 3 | from .prid2011sequence import PRID2011SEQUENCE 4 | from .mars import Mars 5 | from .duke import DukeMTMCVidReID 6 | 7 | 8 | def get_sequence(name, *args, **kwargs): 9 | __factory = { 10 | 'ilidsvidsequence': iLIDSVIDSEQUENCE, 11 | 'prid2011sequence': PRID2011SEQUENCE, 12 | 'mars': Mars, 13 | 'duke': DukeMTMCVidReID, 14 | } 15 | 16 | if name not in __factory: 17 | raise KeyError("Unknown dataset", name) 18 | return __factory[name](*args, **kwargs) -------------------------------------------------------------------------------- /reid/dataset/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flysnowtiger/GRL/26eed6542b636ab373dbc909d5d92b2d2b0cd7ae/reid/dataset/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /reid/dataset/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flysnowtiger/GRL/26eed6542b636ab373dbc909d5d92b2d2b0cd7ae/reid/dataset/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /reid/dataset/__pycache__/duke.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flysnowtiger/GRL/26eed6542b636ab373dbc909d5d92b2d2b0cd7ae/reid/dataset/__pycache__/duke.cpython-36.pyc -------------------------------------------------------------------------------- /reid/dataset/__pycache__/duke.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flysnowtiger/GRL/26eed6542b636ab373dbc909d5d92b2d2b0cd7ae/reid/dataset/__pycache__/duke.cpython-37.pyc -------------------------------------------------------------------------------- /reid/dataset/__pycache__/ilidsvidsequence.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flysnowtiger/GRL/26eed6542b636ab373dbc909d5d92b2d2b0cd7ae/reid/dataset/__pycache__/ilidsvidsequence.cpython-36.pyc -------------------------------------------------------------------------------- /reid/dataset/__pycache__/ilidsvidsequence.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flysnowtiger/GRL/26eed6542b636ab373dbc909d5d92b2d2b0cd7ae/reid/dataset/__pycache__/ilidsvidsequence.cpython-37.pyc -------------------------------------------------------------------------------- /reid/dataset/__pycache__/mars.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/flysnowtiger/GRL/26eed6542b636ab373dbc909d5d92b2d2b0cd7ae/reid/dataset/__pycache__/mars.cpython-36.pyc -------------------------------------------------------------------------------- /reid/dataset/__pycache__/mars.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flysnowtiger/GRL/26eed6542b636ab373dbc909d5d92b2d2b0cd7ae/reid/dataset/__pycache__/mars.cpython-37.pyc -------------------------------------------------------------------------------- /reid/dataset/__pycache__/prid2011sequence.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flysnowtiger/GRL/26eed6542b636ab373dbc909d5d92b2d2b0cd7ae/reid/dataset/__pycache__/prid2011sequence.cpython-36.pyc -------------------------------------------------------------------------------- /reid/dataset/__pycache__/prid2011sequence.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flysnowtiger/GRL/26eed6542b636ab373dbc909d5d92b2d2b0cd7ae/reid/dataset/__pycache__/prid2011sequence.cpython-37.pyc -------------------------------------------------------------------------------- /reid/dataset/duke.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os 3 | import glob 4 | import re 5 | import sys 6 | import urllib 7 | import tarfile 8 | import zipfile 9 | import os.path as osp 10 | from scipy.io import loadmat 11 | import numpy as np 12 | 13 | import sys 14 | from utils.osutils import mkdir_if_missing 15 | from utils.serialization import write_json, read_json 16 | 17 | 18 | class DukeMTMCVidReID(object): 19 | """ 20 | DukeMTMCVidReID 21 | Reference: 22 | Wu et al. Exploit the Unknown Gradually: One-Shot Video-Based Person 23 | Re-Identification by Stepwise Learning. CVPR 2018. 
24 | URL: https://github.com/Yu-Wu/DukeMTMC-VideoReID 25 | 26 | Dataset statistics: 27 | # identities: 702 (train) + 702 (test) 28 | # tracklets: 2196 (train) + 2636 (test) 29 | """ 30 | root = '/home/ycy/data/DukeMTMC-VideoReID' 31 | 32 | def __init__(self, min_seq_len=0, verbose=True, **kwargs): 33 | self.dataset_dir = self.root 34 | self.train_dir = osp.join(self.dataset_dir, 'train') 35 | self.query_dir = osp.join(self.dataset_dir, 'query') 36 | self.gallery_dir = osp.join(self.dataset_dir, 'gallery') 37 | self.split_train_json_path = osp.join(self.dataset_dir, 'split_train.json') 38 | self.split_train_dense_json_path = osp.join(self.dataset_dir, 'split_train_dense.json') 39 | self.split_query_json_path = osp.join(self.dataset_dir, 'split_query.json') 40 | self.split_gallery_json_path = osp.join(self.dataset_dir, 'split_gallery.json') 41 | 42 | self.min_seq_len = min_seq_len 43 | self._check_before_run() 44 | print( 45 | "Note: if root path is changed, the previously generated json files need to be re-generated (so delete them first)") 46 | 47 | train, num_train_tracklets, num_train_pids, num_imgs_train = \ 48 | self._process_dir(self.train_dir, self.split_train_json_path, relabel=True) 49 | train_dense, num_train_tracklets_dense, num_train_pids_dense, num_imgs_train_dense = \ 50 | self._process_dir_dense(self.train_dir, self.split_train_dense_json_path, relabel=True, sampling_step=32) 51 | query, num_query_tracklets, num_query_pids, num_imgs_query = \ 52 | self._process_dir(self.query_dir, self.split_query_json_path, relabel=False) 53 | gallery, num_gallery_tracklets, num_gallery_pids, num_imgs_gallery = \ 54 | self._process_dir(self.gallery_dir, self.split_gallery_json_path, relabel=False) 55 | 56 | print("the number of tracklets under dense sampling for train set: {}".format(num_train_tracklets_dense)) 57 | 58 | num_imgs_per_tracklet = num_imgs_train + num_imgs_query + num_imgs_gallery 59 | min_num = np.min(num_imgs_per_tracklet) 60 | max_num = np.max(num_imgs_per_tracklet) 61 | avg_num = np.mean(num_imgs_per_tracklet) 62 | 63 | num_total_pids = num_train_pids + num_query_pids 64 | num_total_tracklets = num_train_tracklets + num_query_tracklets + num_gallery_tracklets 65 | 66 | if verbose: 67 | print("=> DukeMTMC-VideoReID loaded") 68 | print("Dataset statistics:") 69 | print(" ------------------------------") 70 | print(" subset | # ids | # tracklets") 71 | print(" ------------------------------") 72 | print(" train | {:5d} | {:8d}".format(num_train_pids, num_train_tracklets)) 73 | print(" train_dense | {:5d} | {:8d}".format(num_train_pids_dense, num_train_tracklets_dense)) 74 | print(" query | {:5d} | {:8d}".format(num_query_pids, num_query_tracklets)) 75 | print(" gallery | {:5d} | {:8d}".format(num_gallery_pids, num_gallery_tracklets)) 76 | print(" ------------------------------") 77 | print(" total | {:5d} | {:8d}".format(num_total_pids, num_total_tracklets)) 78 | print(" number of images per tracklet: {} ~ {}, average {:.1f}".format(min_num, max_num, avg_num)) 79 | print(" ------------------------------") 80 | 81 | self.train = train 82 | self.train_dense = train_dense 83 | self.query = query 84 | self.gallery = gallery 85 | 86 | self.num_train_pids = num_train_pids 87 | self.num_query_pids = num_query_pids 88 | self.num_gallery_pids = num_gallery_pids 89 | 90 | def _check_before_run(self): 91 | """Check if all files are available before going deeper""" 92 | if not osp.exists(self.dataset_dir): 93 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 94 | if 
not osp.exists(self.train_dir): 95 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 96 | if not osp.exists(self.query_dir): 97 | raise RuntimeError("'{}' is not available".format(self.query_dir)) 98 | if not osp.exists(self.gallery_dir): 99 | raise RuntimeError("'{}' is not available".format(self.gallery_dir)) 100 | 101 | def _process_dir(self, dir_path, json_path, relabel): 102 | if osp.exists(json_path): 103 | print("=> {} generated before, awesome!".format(json_path)) 104 | split = read_json(json_path) 105 | return split['tracklets'], split['num_tracklets'], split['num_pids'], split['num_imgs_per_tracklet'] 106 | 107 | print("=> Automatically generating split (might take a while for the first time, have a coffee)") 108 | pdirs = glob.glob(osp.join(dir_path, '*')) # avoid .DS_Store; collect all identity folders in the dataset 109 | print("Processing {} with {} person identities".format(dir_path, len(pdirs))) 110 | 111 | pid_container = set() # the folder names are the person ids, gathered as a set; 702 folders means 702 person ids 112 | for pdir in pdirs: 113 | pid = int(osp.basename(pdir)) 114 | pid_container.add(pid) 115 | pid2label = {pid: label for label, pid in enumerate(pid_container)} # relabel 116 | 117 | tracklets = [] 118 | num_imgs_per_tracklet = [] # list holding the number of images of each tracklet 119 | for pdir in pdirs: # walk through each id folder to collect its images, i.e. the video frames of that id 120 | pid = int(osp.basename(pdir)) # e.g. pid = 817 121 | if relabel: pid = pid2label[pid] # relabel 122 | tdirs = glob.glob(osp.join(pdir, '*')) # all tracklets in the folder; one id has several video sequences 123 | for tdir in tdirs: 124 | raw_img_paths = glob.glob(osp.join(tdir, '*.jpg')) # absolute paths of the images in this tracklet, unordered 125 | num_imgs = len(raw_img_paths) # e.g. 162; tracklet length = number of images 126 | 127 | if num_imgs < self.min_seq_len: 128 | continue 129 | 130 | num_imgs_per_tracklet.append(num_imgs) 131 | img_paths = [] 132 | for img_idx in range(num_imgs): # sort the unordered image indices of this tracklet 133 | # some tracklet starts from 0002 instead of 0001 134 | img_idx_name = 'F' + str(img_idx + 1).zfill(4) # F0001 135 | res = glob.glob(osp.join(tdir, '*' + img_idx_name + '*.jpg')) # absolute path of the image with this frame index 136 | if len(res) == 0: # some frame indices may be missing, so skip them 137 | print("Warn: index name {} in {} is missing, jump to next".format(img_idx_name, tdir)) 138 | continue 139 | img_paths.append(res[0]) 140 | img_name = osp.basename(img_paths[0]) # image name format: '0817_C1_F0001_X207382.jpg' 141 | if img_name.find('_') == -1: 142 | # old naming format: 0001C6F0099X30823.jpg 143 | camid = int(img_name[5]) - 1 144 | else: 145 | # new naming format: 0001_C6_F0099_X30823.jpg 146 | camid = int(img_name[6]) - 1 147 | img_paths = tuple(img_paths) 148 | tracklets.append((img_paths, pid, camid)) # absolute paths of all images of the tracklet, plus person id and camid => similar to the MARS dataset 149 | 150 | num_pids = len(pid_container) # number of ids in the training set 151 | num_tracklets = len(tracklets) 152 | 153 | print("Saving split to {}".format(json_path)) 154 | split_dict = { 155 | 'tracklets': tracklets, 156 | 'num_tracklets': num_tracklets, 157 | 'num_pids': num_pids, 158 | 'num_imgs_per_tracklet': num_imgs_per_tracklet, 159 | } 160 | write_json(split_dict, json_path) 161 | 162 | return tracklets, num_tracklets, num_pids, num_imgs_per_tracklet 163 | 164 | def _process_dir_dense(self, dir_path, json_path, relabel, sampling_step=32): 165 | if osp.exists(json_path): 166 | print("=> {} generated before, awesome!".format(json_path)) 167 | split = read_json(json_path) 168 | return split['tracklets'], split['num_tracklets'], split['num_pids'], split['num_imgs_per_tracklet'] 169 | 170 | print("=> Automatically generating split (might take a while for the first time, have a 
coffee)") 171 | pdirs = glob.glob(osp.join(dir_path, '*')) # avoid .DS_Store 172 | print("Processing {} with {} person identities".format(dir_path, len(pdirs))) 173 | 174 | pid_container = set() 175 | for pdir in pdirs: 176 | pid = int(osp.basename(pdir)) 177 | pid_container.add(pid) 178 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 179 | 180 | tracklets = [] 181 | num_imgs_per_tracklet = [] 182 | for pdir in pdirs: 183 | pid = int(osp.basename(pdir)) 184 | if relabel: pid = pid2label[pid] 185 | tdirs = glob.glob(osp.join(pdir, '*')) 186 | for tdir in tdirs: 187 | raw_img_paths = glob.glob(osp.join(tdir, '*.jpg')) 188 | num_imgs = len(raw_img_paths) 189 | 190 | if num_imgs < self.min_seq_len: 191 | continue 192 | 193 | num_imgs_per_tracklet.append(num_imgs) 194 | img_paths = [] 195 | for img_idx in range(num_imgs): 196 | # some tracklet starts from 0002 instead of 0001 197 | img_idx_name = 'F' + str(img_idx + 1).zfill(4) 198 | res = glob.glob(osp.join(tdir, '*' + img_idx_name + '*.jpg')) 199 | if len(res) == 0: 200 | print("Warn: index name {} in {} is missing, jump to next".format(img_idx_name, tdir)) 201 | continue 202 | img_paths.append(res[0]) 203 | img_name = osp.basename(img_paths[0]) 204 | if img_name.find('_') == -1: 205 | # old naming format: 0001C6F0099X30823.jpg 206 | camid = int(img_name[5]) - 1 207 | else: 208 | # new naming format: 0001_C6_F0099_X30823.jpg 209 | camid = int(img_name[6]) - 1 210 | img_paths = tuple(img_paths) 211 | 212 | # dense sampling 213 | num_sampling = len(img_paths) // sampling_step 214 | if num_sampling == 0: 215 | tracklets.append((img_paths, pid, camid)) 216 | else: 217 | for idx in range(num_sampling): 218 | if idx == num_sampling - 1: 219 | tracklets.append((img_paths[idx * sampling_step:], pid, camid)) 220 | else: 221 | tracklets.append((img_paths[idx * sampling_step: (idx + 1) * sampling_step], pid, camid)) 222 | 223 | num_pids = len(pid_container) 224 | num_tracklets = len(tracklets) 225 | 226 | print("Saving split to {}".format(json_path)) 227 | split_dict = { 228 | 'tracklets': tracklets, 229 | 'num_tracklets': num_tracklets, 230 | 'num_pids': num_pids, 231 | 'num_imgs_per_tracklet': num_imgs_per_tracklet, 232 | } 233 | write_json(split_dict, json_path) 234 | 235 | return tracklets, num_tracklets, num_pids, num_imgs_per_tracklet 236 | 237 | 238 | if __name__ == '__main__': 239 | # test 240 | dataset = DukeMTMCVidReID() 241 | -------------------------------------------------------------------------------- /reid/dataset/ilidsvidsequence.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import os 3 | import os.path as osp 4 | from reid.data.datasequence import Datasequence 5 | from utils.osutils import mkdir_if_missing 6 | from utils.serialization import write_json 7 | import tarfile 8 | from glob import glob 9 | import shutil 10 | import scipy.io as sio 11 | 12 | datasetname = 'iLIDS-VID' 13 | flowname = 'Farneback' 14 | 15 | 16 | class infostruct(object): 17 | pass 18 | 19 | 20 | class iLIDSVIDSEQUENCE(Datasequence): 21 | 22 | def __init__(self, root, split_id=0, seq_len=12, seq_srd=6, num_val=1, download=False): 23 | super(iLIDSVIDSEQUENCE, self).__init__(root, split_id=split_id) 24 | 25 | if download: 26 | self.download() 27 | 28 | if not self._check_integrity(): 29 | self.imgextract() 30 | 31 | self.load(seq_len, seq_srd, num_val) 32 | 33 | self.query, query_pid, query_camid, query_num = self._pluckseq_cam(self.identities, 
self.split['query'], 34 | seq_len, seq_srd, 0) 35 | self.queryinfo = infostruct() 36 | self.queryinfo.pid = query_pid 37 | self.queryinfo.camid = query_camid 38 | self.queryinfo.tranum = query_num 39 | 40 | self.gallery, gallery_pid, gallery_camid, gallery_num = self._pluckseq_cam(self.identities, 41 | self.split['gallery'], 42 | seq_len, seq_srd, 1) 43 | self.galleryinfo = infostruct() 44 | self.galleryinfo.pid = gallery_pid 45 | self.galleryinfo.camid = gallery_camid 46 | self.galleryinfo.tranum = gallery_num 47 | 48 | @property 49 | def other_dir(self): 50 | return osp.join(self.root, 'others') 51 | 52 | def download(self): 53 | 54 | if self._check_integrity(): 55 | print("Files already downloaded and verified") 56 | return 57 | 58 | raw_dir = osp.join(self.root, 'raw') 59 | mkdir_if_missing(raw_dir) 60 | 61 | fpath1 = osp.join(raw_dir, datasetname + '.tar') 62 | fpath2 = osp.join(raw_dir, flowname + '.tar') 63 | 64 | if osp.isfile(fpath1) and osp.isfile(fpath2): 65 | print("Using the downloaded files: " + fpath1 + " " + fpath2) 66 | else: 67 | print("Please download the files first") 68 | raise RuntimeError("Downloaded file missing!") 69 | 70 | def imgextract(self): 71 | 72 | raw_dir = osp.join(self.root, 'raw') 73 | exdir1 = osp.join(raw_dir, datasetname) 74 | exdir2 = osp.join(raw_dir, flowname) 75 | fpath1 = osp.join(raw_dir, datasetname + '.tar') 76 | fpath2 = osp.join(raw_dir, flowname + '.tar') 77 | 78 | if not osp.isdir(exdir1): 79 | print("Extracting tar file") 80 | cwd = os.getcwd() 81 | tar = tarfile.open(fpath1) 82 | mkdir_if_missing(exdir1) 83 | os.chdir(exdir1) 84 | tar.extractall() 85 | tar.close() 86 | os.chdir(cwd) 87 | 88 | if not osp.isdir(exdir2): 89 | print("Extracting tar file") 90 | cwd = os.getcwd() 91 | tar = tarfile.open(fpath2) 92 | mkdir_if_missing(exdir2) 93 | os.chdir(exdir2) 94 | tar.extractall() 95 | tar.close() 96 | os.chdir(cwd) 97 | 98 | # reorganizing the dataset 99 | # Format 100 | 101 | temp_images_dir = osp.join(self.root, 'temp_images') 102 | mkdir_if_missing(temp_images_dir) 103 | 104 | temp_others_dir = osp.join(self.root, 'temp_others') 105 | mkdir_if_missing(temp_others_dir) 106 | 107 | images_dir = osp.join(self.root, 'images') 108 | mkdir_if_missing(images_dir) 109 | 110 | others_dir = osp.join(self.root, 'others') 111 | mkdir_if_missing(others_dir) 112 | 113 | fpaths1 = sorted(glob(osp.join(exdir1, 'i-LIDS-VID/sequences', '*/*/*.png'))) 114 | fpaths2 = sorted(glob(osp.join(exdir2, flowname, '*/*/*.png'))) 115 | 116 | identities_imgraw = [[[] for _ in range(2)] for _ in range(319)] 117 | identities_otherraw = [[[] for _ in range(2)] for _ in range(319)] 118 | 119 | # image information 120 | for fpath in fpaths1: 121 | fname = osp.basename(fpath) 122 | fname_list = fname.split('_') 123 | cam_name = fname_list[0] 124 | pid_name = fname_list[1] 125 | cam = int(cam_name[-1]) 126 | pid = int(pid_name[-3:]) 127 | temp_fname = ('{:08d}_{:02d}_{:04d}.png' 128 | .format(pid, cam, len(identities_imgraw[pid - 1][cam - 1]))) 129 | identities_imgraw[pid - 1][cam - 1].append(temp_fname) 130 | shutil.copy(fpath, osp.join(temp_images_dir, temp_fname)) 131 | 132 | identities_temp = [x for x in identities_imgraw if x != [[], []]] 133 | identities_images = identities_temp 134 | 135 | for pid in range(len(identities_temp)): 136 | for cam in range(2): 137 | for img in range(len(identities_images[pid][cam])): 138 | temp_fname = identities_temp[pid][cam][img] 139 | fname = ('{:08d}_{:02d}_{:04d}.png' 140 | .format(pid, cam, img)) 141 | 
identities_images[pid][cam][img] = fname 142 | shutil.copy(osp.join(temp_images_dir, temp_fname), osp.join(images_dir, fname)) 143 | 144 | shutil.rmtree(temp_images_dir) 145 | 146 | # flow information 147 | 148 | for fpath in fpaths2: 149 | fname = osp.basename(fpath) 150 | fname_list = fname.split('_') 151 | cam_name = fname_list[0] 152 | pid_name = fname_list[1] 153 | cam = int(cam_name[-1]) 154 | pid = int(pid_name[-3:]) 155 | temp_fname = ('{:08d}_{:02d}_{:04d}.png' 156 | .format(pid, cam, len(identities_otherraw[pid - 1][cam - 1]))) 157 | identities_otherraw[pid - 1][cam - 1].append(temp_fname) 158 | shutil.copy(fpath, osp.join(temp_others_dir, temp_fname)) 159 | 160 | identities_temp = [x for x in identities_otherraw if x != [[], []]] 161 | identities_others = identities_temp 162 | 163 | for pid in range(len(identities_temp)): 164 | for cam in range(2): 165 | for img in range(len(identities_others[pid][cam])): 166 | temp_fname = identities_temp[pid][cam][img] 167 | fname = ('{:08d}_{:02d}_{:04d}.png' 168 | .format(pid, cam, img)) 169 | identities_others[pid][cam][img] = fname 170 | shutil.copy(osp.join(temp_others_dir, temp_fname), osp.join(others_dir, fname)) 171 | 172 | shutil.rmtree(temp_others_dir) 173 | 174 | meta = {'name': 'iLIDS-sequence', 'shot': 'sequence', 'num_cameras': 2, 175 | 'identities': identities_images} 176 | 177 | write_json(meta, osp.join(self.root, 'meta.json')) 178 | 179 | # Consider fixed training and testing split 180 | splitmat_name = osp.join(exdir1, 'train-test people splits', 'train_test_splits_ilidsvid.mat') 181 | data = sio.loadmat(splitmat_name) 182 | person_list = data['ls_set'] 183 | num = len(identities_images) 184 | splits = [] 185 | 186 | for i in range(10): 187 | pids = (person_list[i] - 1).tolist() 188 | trainval_pids = sorted(pids[:num // 2]) 189 | test_pids = sorted(pids[num // 2:]) 190 | split = {'trainval': trainval_pids, 191 | 'query': test_pids, 192 | 'gallery': test_pids} 193 | splits.append(split) 194 | write_json(splits, osp.join(self.root, 'splits.json')) 195 | 196 | def _pluckseq_cam(self, identities, indices, seq_len, seq_str, camid): 197 | ret = [] 198 | per_id = [] 199 | cam_id = [] 200 | tra_num = [] 201 | 202 | for index, pid in enumerate(indices): 203 | pid_images = identities[pid] 204 | cam_images = pid_images[camid] 205 | seqall = len(cam_images) 206 | seq_inds = [(start_ind, start_ind + seq_len) for start_ind in range(0, seqall - seq_len, seq_str)] 207 | if not seq_inds: 208 | seq_inds = [(0, seqall)] 209 | for seq_ind in seq_inds: 210 | ret.append((seq_ind[0], seq_ind[1], pid, index, camid)) 211 | per_id.append(pid) 212 | cam_id.append(camid) 213 | tra_num.append(len(seq_inds)) 214 | return ret, per_id, cam_id, tra_num 215 | -------------------------------------------------------------------------------- /reid/dataset/mars.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import print_function, absolute_import 3 | import os.path as osp 4 | from scipy.io import loadmat 5 | import numpy as np 6 | from utils.serialization import write_json, read_json 7 | 8 | 9 | class infostruct(object): 10 | pass 11 | 12 | 13 | class Mars(object): 14 | root = '/home/snowtiger/snowtiger/data/MARS' 15 | train_name_path = osp.join(root, 'info/train_name.txt') 16 | test_name_path = osp.join(root, 'info/test_name.txt') 17 | track_train_info_path = osp.join(root, 'info/tracks_train_info.mat') 18 | track_test_info_path = osp.join(root, 'info/tracks_test_info.mat') 19 | query_IDX_path = 
osp.join(root, 'info/query_IDX.mat') 20 | split_train_json_path = osp.join(root, 'split_train.json') 21 | split_query_json_path = osp.join(root, 'split_query.json') 22 | split_gallery_json_path = osp.join(root, 'split_gallery.json') 23 | 24 | # prepare meta data: per-tracklet info [start_index end_index pid camid] 25 | def __init__(self, min_seq_len=0): 26 | 27 | self._check_before_run() 28 | 29 | train_names = self._get_names(self.train_name_path) # : '0001C1T0001F001.jpg' 30 | test_names = self._get_names(self.test_name_path) # : '00-1C1T0001F001.jpg' 31 | track_train = loadmat(self.track_train_info_path)[ 32 | 'track_train_info'] # numpy.ndarray (8298, 4) [[1 16 1 1],[17 95 1 1] ...] 33 | track_test = loadmat(self.track_test_info_path)[ 34 | 'track_test_info'] # numpy.ndarray (12180, 4) [[1 24 -1 1][25 34 -1 1]] 35 | query_IDX = loadmat(self.query_IDX_path)['query_IDX'].squeeze() # numpy.ndarray (1980,) [4130, 4138...] 36 | query_IDX -= 1 # index from 0 [4129,4137....] 37 | track_query = track_test[query_IDX, :] # tracklet info of the corresponding rows, [[171610 171649 2 1],[172214 172313 2 2]...] 38 | gallery_IDX = [i for i in range(track_test.shape[0]) if i not in query_IDX] # gallery = 10200 39 | # gallery_IDX = [i for i in range(track_test.shape[0])] # gallery = 12180 40 | track_gallery = track_test[gallery_IDX, :] # : (10200, 4) [[1 24 -1 1][25 34 -1 1]...] 41 | 42 | train, num_train_tracklets, num_train_pids, num_train_imgs = \ 43 | self._process_data(train_names, track_train, home_dir='bbox_train', relabel=True, 44 | min_seq_len=min_seq_len, json_path=self.split_train_json_path) 45 | 46 | query, num_query_tracklets, num_query_pids, num_query_imgs, query_pid, query_camid = \ 47 | self._process_gallery_data(test_names, track_query, home_dir='bbox_test', relabel=False, 48 | min_seq_len=min_seq_len, json_path=self.split_query_json_path,) 49 | 50 | gallery, num_gallery_tracklets, num_gallery_pids, num_gallery_imgs, gallery_pid, gallery_camid = \ 51 | self._process_gallery_data(test_names, track_gallery, home_dir='bbox_test', relabel=False, 52 | min_seq_len=min_seq_len, json_path=self.split_gallery_json_path) 53 | 54 | num_imgs_per_tracklet = num_train_imgs + num_query_imgs + num_gallery_imgs 55 | min_num = np.min(num_imgs_per_tracklet) 56 | max_num = np.max(num_imgs_per_tracklet) 57 | avg_num = np.mean(num_imgs_per_tracklet) 58 | 59 | num_total_pids = num_train_pids + num_query_pids 60 | num_total_tracklets = num_train_tracklets + num_query_tracklets + num_gallery_tracklets 61 | 62 | print("=> MARS loaded") 63 | print("Dataset statistics:") 64 | print(" ------------------------------") 65 | print(" subset | # ids | # tracklets") 66 | print(" ------------------------------") 67 | print(" train | {:5d} | {:8d}".format(num_train_pids, num_train_tracklets)) 68 | print(" query | {:5d} | {:8d}".format(num_query_pids, num_query_tracklets)) 69 | print(" gallery | {:5d} | {:8d}".format(num_gallery_pids, num_gallery_tracklets)) 70 | print(" ------------------------------") 71 | print(" total | {:5d} | {:8d}".format(num_total_pids, num_total_tracklets)) 72 | print(" number of images per tracklet: {} ~ {}, average {:.1f}".format(min_num, max_num, avg_num)) 73 | print(" ------------------------------") 74 | 75 | self.train = train 76 | self.query = query 77 | self.gallery = gallery 78 | 79 | self.num_train_pids = num_train_pids 80 | self.num_query_pids = num_query_pids 81 | self.num_gallery_pids = num_gallery_pids 82 | 83 | self.queryinfo = infostruct() 84 | self.queryinfo.pid = query_pid 85 | self.queryinfo.camid = query_camid 86 | 
self.queryinfo.tranum = num_query_imgs 87 | 88 | self.galleryinfo = infostruct() 89 | self.galleryinfo.pid = gallery_pid 90 | self.galleryinfo.camid = gallery_camid 91 | self.galleryinfo.tranum = num_gallery_imgs 92 | 93 | def _check_before_run(self): 94 | """Check if all files are available before going deeper""" 95 | if not osp.exists(self.root): 96 | raise RuntimeError("'{}' is not available".format(self.root)) 97 | if not osp.exists(self.train_name_path): 98 | raise RuntimeError("'{}' is not available".format(self.train_name_path)) 99 | if not osp.exists(self.test_name_path): 100 | raise RuntimeError("'{}' is not available".format(self.test_name_path)) 101 | if not osp.exists(self.track_train_info_path): 102 | raise RuntimeError("'{}' is not available".format(self.track_train_info_path)) 103 | if not osp.exists(self.track_test_info_path): 104 | raise RuntimeError("'{}' is not available".format(self.track_test_info_path)) 105 | if not osp.exists(self.query_IDX_path): 106 | raise RuntimeError("'{}' is not available".format(self.query_IDX_path)) 107 | 108 | def _get_names(self, fpath): 109 | names = [] 110 | with open(fpath, 'r') as f: 111 | for line in f: 112 | new_line = line.rstrip() 113 | names.append(new_line) 114 | return names 115 | 116 | def _process_data(self, names, meta_data, home_dir=None, relabel=False, min_seq_len=0, json_path=''): 117 | if osp.exists(json_path): 118 | print("=> {} generated before, awesome!".format(json_path)) 119 | split = read_json(json_path) 120 | return split['tracklets'], split['num_tracklets'], split['num_pids'], split['num_imgs_per_tracklet'] 121 | print("=> Automatically generating split (might take a while for the first time, have a coffee)") 122 | assert home_dir in ['bbox_train', 'bbox_test'] 123 | num_tracklets = meta_data.shape[0] # 8298 TODO: should this be increased? 124 | pid_list = list(set(meta_data[:, 2].tolist())) # 625 ids => [1 3 5 7 9...] 125 | num_pids = len(pid_list) 126 | 127 | if relabel: 128 | pid2label = {pid: label for label, pid in enumerate(pid_list)} # {1:0,3:1,5:2,...} 129 | tracklets = [] 130 | num_imgs_per_tracklet = [] 131 | 132 | for tracklet_idx in range(num_tracklets): 133 | data = meta_data[tracklet_idx, ...] # [1 16 1 1] 134 | start_index, end_index, pid, camid = data 135 | if pid == -1: 136 | continue # junk images are just ignored 137 | assert 1 <= camid <= 6 138 | if relabel: 139 | pid = pid2label[pid] # pid = 0 140 | camid -= 1 141 | # index starts from 0 142 | img_names = names[start_index - 1:end_index] 143 | # :['0001C1T0001F001.jpg'.. '0001C1T0001F016.jpg'] 144 | 145 | # make sure image names correspond to the same person 146 | pnames = [img_name[:4] for img_name in img_names] # pnames = ['0001','0001'...] 147 | assert len(set(pnames)) == 1, "Error: a single tracklet contains different person images" 148 | 149 | # make sure all images are captured under the same camera 150 | camnames = [img_name[5] for img_name in img_names] # camnames = ['1','1'...] 151 | assert len(set(camnames)) == 1, "Error: images are captured under different cameras!"
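# --- Illustrative aside (not part of the original file): how one row of
# tracks_train_info.mat expands into image names; values are the hypothetical
# ones from the comments above.
#   data = [1, 16, 1, 1]  ->  start_index=1, end_index=16, pid=1, camid=1
#   img_names = names[start_index - 1:end_index] == names[0:16]
#             == ['0001C1T0001F001.jpg', ..., '0001C1T0001F016.jpg']
#   pnames == ['0001'] * 16 and camnames == ['1'] * 16, so both asserts hold.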
152 | 153 | # append image names with directory information 154 | # '/media/ying/0BDD17830BDD1783/ReIdDataset/Mars/bbox_train/0001/0001C1T0001F001.jpg' 155 | img_paths = [osp.join(self.root, home_dir, img_name[:4], img_name) for img_name in img_names] # list<16> 156 | if len(img_paths) >= min_seq_len: 157 | img_paths = tuple(img_paths) 158 | tracklets.append((img_paths, int(pid), int(camid))) # (('.jpg', '.jpg', path of every image), 0 = person id, 0 = camid) 159 | num_imgs_per_tracklet.append(len(img_paths)) # [16,79,15...] number of image frames in each tracklet 160 | 161 | num_tracklets = len(tracklets) # 8298 162 | 163 | print("Saving split to {}".format(json_path)) 164 | split_dict = { 165 | 'tracklets': tracklets, 166 | 'num_tracklets': num_tracklets, 167 | 'num_pids': num_pids, 168 | 'num_imgs_per_tracklet': num_imgs_per_tracklet, 169 | } 170 | write_json(split_dict, json_path) 171 | 172 | return tracklets, num_tracklets, num_pids, num_imgs_per_tracklet 173 | 174 | def _process_gallery_data(self, names, meta_data, home_dir=None, relabel=False, min_seq_len=0, json_path=''): 175 | if osp.exists(json_path): 176 | print("=> {} generated before, awesome!".format(json_path)) 177 | split = read_json(json_path) 178 | return split['tracklets'], split['num_tracklets'], split['num_pids'], split['num_imgs_per_tracklet'], split['pids'], split['camid'] 179 | 180 | assert home_dir in ['bbox_train', 'bbox_test'] 181 | num_tracklets = meta_data.shape[0] # 8298 TODO: should this be increased? 182 | pid_list = list(set(meta_data[:, 2].tolist())) # 625 ids => [1 3 5 7 9...] 183 | num_pids = len(pid_list) # 626 622 184 | 185 | if relabel: 186 | pid2label = {pid: label for label, pid in enumerate(pid_list)} # {1:0,3:1,5:2,...} 187 | tracklets = [] 188 | num_imgs_per_tracklet = [] 189 | gallery_pid = [] 190 | gallery_camid = [] 191 | 192 | for tracklet_idx in range(num_tracklets): 193 | data = meta_data[tracklet_idx, ...] # [1 16 1 1] 194 | start_index, end_index, pid, camid = data 195 | 196 | if pid == -1: 197 | continue # junk images are just ignored 198 | assert 1 <= camid <= 6 199 | if relabel: 200 | pid = pid2label[pid] # pid = 0 201 | camid -= 1 202 | # index starts from 0 203 | img_names = names[start_index - 1:end_index] 204 | # :['0001C1T0001F001.jpg'.. '0001C1T0001F016.jpg'] 205 | 206 | # make sure image names correspond to the same person 207 | pnames = [img_name[:4] for img_name in img_names] # pnames = ['0001','0001'...] 208 | assert len(set(pnames)) == 1, "Error: a single tracklet contains different person images" 209 | 210 | # make sure all images are captured under the same camera 211 | camnames = [img_name[5] for img_name in img_names] # camnames = ['1','1'...] 212 | assert len(set(camnames)) == 1, "Error: images are captured under different cameras!"
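# --- Illustrative aside (not part of the original file): a sketch, with
# hypothetical values, of the JSON written below for the gallery split; the
# extra 'pids'/'camid' fields let _process_gallery_data reload its per-tracklet
# labels straight from disk on later runs.
#   split_dict = {'tracklets': [((path_1, ..., path_16), 2, 0), ...],
#                 'num_tracklets': 9330, 'num_pids': 622,
#                 'num_imgs_per_tracklet': [16, 79, 15, ...],
#                 'pids': [2, ...], 'camid': [0, ...]}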
213 | 214 | # append image names with directory information 215 | # '/media/ying/0BDD17830BDD1783/ReIdDataset/Mars/bbox_train/0001/0001C1T0001F001.jpg' 216 | img_paths = [osp.join(self.root, home_dir, img_name[:4], img_name) for img_name in img_names] # list<16> 217 | if len(img_paths) >= min_seq_len: 218 | img_paths = tuple(img_paths) 219 | tracklets.append((img_paths, int(pid), int(camid))) # (('.jpg', '.jpg', path of every image), 0 = person id, 0 = camid) 220 | num_imgs_per_tracklet.append(len(img_paths)) # [16,79,15...] number of image frames in each tracklet 221 | gallery_pid.append(int(pid)) 222 | gallery_camid.append(int(camid)) 223 | num_tracklets = len(tracklets) # 8298 224 | print("Saving split to {}".format(json_path)) 225 | split_dict = { 226 | 'tracklets': tracklets, 227 | 'num_tracklets': num_tracklets, 228 | 'num_pids': num_pids, 229 | 'num_imgs_per_tracklet': num_imgs_per_tracklet, 230 | 'pids': gallery_pid, 231 | 'camid': gallery_camid, 232 | } 233 | write_json(split_dict, json_path) 234 | return tracklets, num_tracklets, num_pids, num_imgs_per_tracklet, gallery_pid, gallery_camid 235 | 236 | 237 | if __name__ == '__main__': 238 | # test 239 | dataset = Mars() 240 | -------------------------------------------------------------------------------- /reid/dataset/prid2011sequence.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import os 3 | import os.path as osp 4 | from reid.data.datasequence import Datasequence 5 | from utils.osutils import mkdir_if_missing 6 | from utils.serialization import write_json 7 | import tarfile 8 | from glob import glob 9 | import shutil 10 | import numpy as np 11 | 12 | datasetname = 'prid_2011' 13 | flowname = 'prid2011flow' 14 | 15 | 16 | class infostruct(object): 17 | pass 18 | 19 | 20 | class PRID2011SEQUENCE(Datasequence): 21 | 22 | def __init__(self, root, split_id=0, seq_len=12, seq_srd=6, num_val=1, download=False): 23 | super(PRID2011SEQUENCE, self).__init__(root, split_id=split_id) 24 | 25 | if download: 26 | self.download() 27 | 28 | if not self._check_integrity(): 29 | self.imgextract() 30 | 31 | self.load(seq_len, seq_srd, num_val) 32 | 33 | self.query, query_pid, query_camid, query_num = self._pluckseq_cam(self.identities, self.split['query'], 34 | seq_len, seq_srd, 0) 35 | self.queryinfo = infostruct() 36 | self.queryinfo.pid = query_pid 37 | self.queryinfo.camid = query_camid 38 | self.queryinfo.tranum = query_num 39 | 40 | self.gallery, gallery_pid, gallery_camid, gallery_num = self._pluckseq_cam(self.identities, 41 | self.split['gallery'], 42 | seq_len, seq_srd, 1) 43 | self.galleryinfo = infostruct() 44 | self.galleryinfo.pid = gallery_pid 45 | self.galleryinfo.camid = gallery_camid 46 | self.galleryinfo.tranum = gallery_num 47 | 48 | @property 49 | def other_dir(self): 50 | return osp.join(self.root, 'others') 51 | 52 | def download(self): 53 | 54 | if self._check_integrity(): 55 | print("Files already downloaded and verified") 56 | return 57 | 58 | raw_dir = osp.join(self.root, 'raw') 59 | mkdir_if_missing(raw_dir) 60 | 61 | fpath1 = osp.join(raw_dir, datasetname + '.tar') 62 | fpath2 = osp.join(raw_dir, flowname + '.tar') 63 | 64 | if osp.isfile(fpath1) and osp.isfile(fpath2): 65 | print("Using the downloaded files: " + fpath1 + " " + fpath2) 66 | else: 67 | print("Please download the files first") 68 | raise RuntimeError("Downloaded file missing!") 69 | 70 | def imgextract(self): 71 | 72 | raw_dir = osp.join(self.root, 'raw') 73 | # raw_dir = /media/ying/0BDD17830BDD1783/video_reid 
_prid/data/prid2011sequence/raw 74 | exdir1 = osp.join(raw_dir, datasetname) 75 | # exdir1 = /media/ying/0BDD17830BDD1783/video_reid _prid/data/prid2011sequence/raw/prid_2011 76 | exdir2 = osp.join(raw_dir, flowname) 77 | # exdir2 = /media/ying/0BDD17830BDD1783/video_reid _prid/data/prid2011sequence/raw/prid2011flow 78 | fpath1 = osp.join(raw_dir, datasetname + '.tar') 79 | # fpath1 = /media/ying/0BDD17830BDD1783/video_reid _prid/data/prid2011sequence/raw/prid_2011.tar 80 | fpath2 = osp.join(raw_dir, flowname + '.tar') 81 | # fpath2 = /media/ying/0BDD17830BDD1783/video_reid _prid/data/prid2011sequence/raw/prid2011flow.tar 82 | 83 | if not osp.isdir(exdir1): 84 | print("Extracting tar file") 85 | cwd = os.getcwd() 86 | tar_ref = tarfile.open(fpath1) 87 | mkdir_if_missing(exdir1) 88 | os.chdir(exdir1) 89 | tar_ref.extractall() 90 | tar_ref.close() 91 | os.chdir(cwd) 92 | 93 | if not osp.isdir(exdir2): 94 | print("Extracting tar file") 95 | cwd = os.getcwd() 96 | tar_ref = tarfile.open(fpath2) 97 | mkdir_if_missing(exdir2) 98 | os.chdir(exdir2) 99 | tar_ref.extractall() 100 | tar_ref.close() 101 | os.chdir(cwd) 102 | 103 | # reorganizing the dataset 104 | # Format 105 | temp_images_dir = osp.join(self.root, 'temp_images') 106 | mkdir_if_missing(temp_images_dir) 107 | 108 | temp_others_dir = osp.join(self.root, 'temp_others') 109 | mkdir_if_missing(temp_others_dir) 110 | 111 | images_dir = osp.join(self.root, 'images') 112 | mkdir_if_missing(images_dir) 113 | # images_dir = /media/ying/0BDD17830BDD1783/video_reid _prid/data/prid2011sequence/images 114 | 115 | others_dir = osp.join(self.root, 'others') 116 | mkdir_if_missing(others_dir) 117 | # others_dir = /media/ying/0BDD17830BDD1783/video_reid _prid/data/prid2011sequence/others 118 | 119 | fpaths1 = sorted(glob(osp.join(exdir1, 'prid_2011/multi_shot', '*/*/*.png'))) # absolute paths of all images 120 | fpaths2 = sorted(glob(osp.join(exdir2, 'prid2011flow', '*/*/*.png'))) 121 | 122 | identities_imgraw = [[[] for _ in range(2)] for _ in range(200)] # 200 entries of [ [], [] ] 123 | identities_otherraw = [[[] for _ in range(2)] for _ in range(200)] 124 | 125 | for fpath in fpaths1: 126 | fname = fpath 127 | fname_list = fname.split('/') 128 | cam_name = fname_list[-3] # cam_a / cam_b 129 | pid_name = fname_list[-2] # person_001 130 | frame_name = fname_list[-1] # 0001.png 131 | cam_id = 1 if cam_name == 'cam_a' else 2 # cam_id = 1 / 2 132 | pid_id = int(pid_name.split('_')[-1]) # pid_id = 001 133 | if pid_id > 200: 134 | continue 135 | frame_id = int(frame_name.split('.')[-2]) # frame_id = 0001 136 | temp_fname = ('{:08d}_{:02d}_{:04d}.png' 137 | .format(pid_id-1, cam_id-1, frame_id-1)) 138 | identities_imgraw[pid_id - 1][cam_id - 1].append(temp_fname) 139 | shutil.copy(fpath, osp.join(temp_images_dir, temp_fname)) 140 | 141 | identities_temp = [x for x in identities_imgraw if x != [[], []]] 142 | identities_images = identities_temp 143 | 144 | for pid in range(len(identities_temp)): 145 | for cam in range(2): 146 | for img in range(len(identities_images[pid][cam])): 147 | temp_fname = identities_temp[pid][cam][img] 148 | fname = ('{:08d}_{:02d}_{:04d}.png'.format(pid, cam, img)) 149 | identities_images[pid][cam][img] = fname 150 | shutil.copy(osp.join(temp_images_dir, temp_fname), osp.join(images_dir, fname)) 151 | 152 | shutil.rmtree(temp_images_dir) 153 | 154 | for fpath in fpaths2: 155 | fname = fpath 156 | fname_list = fname.split('/') 157 | cam_name = fname_list[-3] # cam_a / cam_b 158 | pid_name = fname_list[-2] # person_001 159 | frame_name = fname_list[-1] # 
0001.png 160 | cam_id = 1 if cam_name == 'cam_a' else 2 # cam_id = 1 / 2 161 | pid_id = int(pid_name.split('_')[-1]) # pid_id = 001 162 | if pid_id > 200: 163 | continue 164 | frame_id = int(frame_name.split('.')[-2]) # frame_id = 0001 165 | temp_fname = ('{:08d}_{:02d}_{:04d}.png' 166 | .format(pid_id-1, cam_id-1, frame_id-1)) 167 | identities_otherraw[pid_id - 1][cam_id - 1].append(temp_fname) 168 | shutil.copy(fpath, osp.join(temp_others_dir, temp_fname)) 169 | 170 | identities_temp = [x for x in identities_otherraw if x != [[], []]] 171 | identities_others = identities_temp 172 | 173 | for pid in range(len(identities_temp)): 174 | for cam in range(2): 175 | for img in range(len(identities_others[pid][cam])): 176 | temp_fname = identities_temp[pid][cam][img] 177 | fname = ('{:08d}_{:02d}_{:04d}.png'.format(pid, cam, img)) 178 | identities_others[pid][cam][img] = fname 179 | shutil.copy(osp.join(temp_others_dir, temp_fname), osp.join(others_dir, fname)) 180 | 181 | shutil.rmtree(temp_others_dir) 182 | 183 | meta = {'name': 'prid-sequence', 'shot': 'sequence', 'num_cameras': 2, 184 | 'identities': identities_images} 185 | 186 | write_json(meta, osp.join(self.root, 'meta.json')) 187 | # Consider fixed training and testing split 188 | num = len(identities_images) 189 | splits = [] 190 | for i in range(20): 191 | pids = np.random.permutation(num) 192 | pids = pids.tolist() # the permutation is already 0-indexed 193 | trainval_pids = pids[:num // 2] 194 | test_pids = pids[num // 2:] 195 | split = {'trainval': trainval_pids, 196 | 'query': test_pids, 197 | 'gallery': test_pids} 198 | 199 | splits.append(split) 200 | write_json(splits, osp.join(self.root, 'splits.json')) 201 | 202 | def _pluckseq_cam(self, identities, indices, seq_len, seq_str, camid): 203 | ret = [] 204 | per_id = [] 205 | cam_id = [] 206 | tra_num = [] 207 | 208 | for index, pid in enumerate(indices): 209 | pid_images = identities[pid] 210 | cam_images = pid_images[camid] 211 | seqall = len(cam_images) 212 | seq_inds = [(start_ind, start_ind + seq_len) for start_ind in range(0, seqall - seq_len, seq_str)] 213 | if not seq_inds: 214 | seq_inds = [(0, seqall)] 215 | for seq_ind in seq_inds: 216 | ret.append((seq_ind[0], seq_ind[1], pid, index, camid)) 217 | per_id.append(pid) 218 | cam_id.append(camid) 219 | tra_num.append(len(seq_inds)) 220 | return ret, per_id, cam_id, tra_num 221 | -------------------------------------------------------------------------------- /reid/evaluator/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from .eva_functions import accuracy, cmc, mean_ap 3 | from .attevaluator import ATTEvaluator 4 | from .rerank import re_ranking 5 | 6 | __all__ = [ 7 | 'accuracy', 8 | 'cmc', 9 | 'mean_ap', 10 | 're_ranking', 11 | ] 12 | 13 | 14 | -------------------------------------------------------------------------------- /reid/evaluator/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flysnowtiger/GRL/26eed6542b636ab373dbc909d5d92b2d2b0cd7ae/reid/evaluator/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /reid/evaluator/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flysnowtiger/GRL/26eed6542b636ab373dbc909d5d92b2d2b0cd7ae/reid/evaluator/__pycache__/__init__.cpython-37.pyc 
-------------------------------------------------------------------------------- /reid/evaluator/__pycache__/attevaluator.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flysnowtiger/GRL/26eed6542b636ab373dbc909d5d92b2d2b0cd7ae/reid/evaluator/__pycache__/attevaluator.cpython-36.pyc -------------------------------------------------------------------------------- /reid/evaluator/__pycache__/attevaluator.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flysnowtiger/GRL/26eed6542b636ab373dbc909d5d92b2d2b0cd7ae/reid/evaluator/__pycache__/attevaluator.cpython-37.pyc -------------------------------------------------------------------------------- /reid/evaluator/__pycache__/eva_functions.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flysnowtiger/GRL/26eed6542b636ab373dbc909d5d92b2d2b0cd7ae/reid/evaluator/__pycache__/eva_functions.cpython-36.pyc -------------------------------------------------------------------------------- /reid/evaluator/__pycache__/eva_functions.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flysnowtiger/GRL/26eed6542b636ab373dbc909d5d92b2d2b0cd7ae/reid/evaluator/__pycache__/eva_functions.cpython-37.pyc -------------------------------------------------------------------------------- /reid/evaluator/__pycache__/rerank.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flysnowtiger/GRL/26eed6542b636ab373dbc909d5d92b2d2b0cd7ae/reid/evaluator/__pycache__/rerank.cpython-36.pyc -------------------------------------------------------------------------------- /reid/evaluator/__pycache__/rerank.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flysnowtiger/GRL/26eed6542b636ab373dbc909d5d92b2d2b0cd7ae/reid/evaluator/__pycache__/rerank.cpython-37.pyc -------------------------------------------------------------------------------- /reid/evaluator/__pycache__/visualize.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flysnowtiger/GRL/26eed6542b636ab373dbc909d5d92b2d2b0cd7ae/reid/evaluator/__pycache__/visualize.cpython-36.pyc -------------------------------------------------------------------------------- /reid/evaluator/__pycache__/visualize.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flysnowtiger/GRL/26eed6542b636ab373dbc909d5d92b2d2b0cd7ae/reid/evaluator/__pycache__/visualize.cpython-37.pyc -------------------------------------------------------------------------------- /reid/evaluator/attevaluator.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import time 3 | import math 4 | import torch 5 | from utils.meters import AverageMeter 6 | from utils import to_torch 7 | from .eva_functions import cmc, mean_ap, evaluate, evaluate_zhengliang 8 | from .rerank import re_ranking 9 | from .visualize import visualize_ranked_results, visualize_in_pic 10 | import numpy as np 11 | from torch import nn 12 | import scipy.io 13 | 14 | 15 | def 
evaluate_seq(distmat, query_pids, query_camids, gallery_pids, gallery_camids, path, cmc_topk=[1, 5, 10, 20]): 16 | query_ids = np.array(query_pids) # : (1980,) 17 | gallery_ids = np.array(gallery_pids) # : (9330,) 18 | query_cams = np.array(query_camids) # : (1980,) 19 | gallery_cams = np.array(gallery_camids) # : (9330,) 20 | 21 | ## 22 | cmc_scores, mAP = evaluate(distmat, query_ids, gallery_ids, query_cams, gallery_cams) 23 | 24 | print('Mean AP: {:4.1%}'.format(mAP)) 25 | 26 | for r in cmc_topk: 27 | print("Rank-{:<3}: {:.1%}".format(r, cmc_scores[r-1])) 28 | print("------------------") 29 | 30 | return cmc_scores[0] 31 | 32 | 33 | def pairwise_distance_tensor(query_x, gallery_x): 34 | m, n = query_x.size(0), gallery_x.size(0) 35 | x = query_x.view(m, -1) 36 | y = gallery_x.view(n, -1) 37 | dist = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) +\ 38 | torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t() 39 | dist.addmm_(1, -2, x, y.t()) 40 | dist = dist.clamp(min=1e-12).sqrt() 41 | return dist 42 | 43 | 44 | def cosin_dist(qf, gf): 45 | dist = -torch.mm(qf, gf.t()) 46 | return dist 47 | 48 | 49 | class ATTEvaluator(object): 50 | 51 | def __init__(self, cnn_model, Siamese_model, only_eval): 52 | super(ATTEvaluator, self).__init__() 53 | self.cnn_model = cnn_model 54 | self.siamese_model = Siamese_model 55 | self.softmax = nn.Softmax(dim=-1) 56 | self.only_eval = only_eval 57 | 58 | @torch.no_grad() 59 | def extract_feature(self, data_loader): 60 | 61 | self.cnn_model.eval() 62 | self.siamese_model.eval() 63 | 64 | qf, q_pids, q_camids = [], [], [] 65 | for i, inputs in enumerate(data_loader): 66 | imgs, pids, camids = inputs 67 | 68 | if self.only_eval: 69 | b, n, s, c, h, w = imgs.size() # 1, 5, 8, c, h, w 70 | imgs = imgs.view(b*n, s, c, h, w).cuda() 71 | with torch.no_grad(): 72 | if b*n > 8: # if the sequence is too long, split it into chunks of batch size 8 73 | feat_list = [] # temporary list holding the features 74 | num = int(math.ceil(b*n*1.0/8)) # number of chunks of size 8 75 | for y in range(num): 76 | clips = imgs[y*8:(y+1)*8, :, :, :, :].cuda() # up to 8, 8, c, h, w 77 | x_uncorr, feats_corr = self.cnn_model(clips) 78 | 79 | out_frame = self.siamese_model.self_attention(feats_corr) 80 | out_feat = torch.cat((x_uncorr, out_frame, feats_corr.mean(dim=1)), dim=1) 81 | 82 | feat_list.append(out_feat) 83 | feat_list = torch.cat(feat_list, 0) 84 | feat_list = torch.mean(feat_list, dim=0) 85 | qf.append(feat_list.unsqueeze(0)) 86 | q_pids.extend(pids) 87 | q_camids.extend(camids) 88 | else: 89 | x_uncorr, feats_corr = self.cnn_model(imgs) 90 | 91 | out_frame = self.siamese_model.self_attention(feats_corr) 92 | out_feat = torch.cat((x_uncorr, out_frame, feats_corr.mean(dim=1)), dim=1) 93 | 94 | out_feat = out_feat.view(n, -1) 95 | out_feat = torch.mean(out_feat, dim=0) 96 | qf.append(out_feat.unsqueeze(0)) 97 | q_pids.extend(pids) 98 | q_camids.extend(camids) 99 | torch.cuda.empty_cache() 100 | else: 101 | b, s, c, h, w = imgs.size() 102 | imgs = imgs.view(b, s, c, h, w) 103 | imgs = to_torch(imgs) 104 | 105 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 106 | imgs = imgs.to(device) 107 | 108 | with torch.no_grad(): 109 | x_uncorr, feats_corr = self.cnn_model(imgs) 110 | 111 | out_frame = self.siamese_model.self_attention(feats_corr) 112 | out_feat = torch.cat((x_uncorr, out_frame, feats_corr.mean(dim=1)), dim=1) 113 | 114 | qf.append(out_feat) 115 | q_pids.extend(pids) 116 | q_camids.extend(camids) 117 | torch.cuda.empty_cache() 118 | 119 | qf = torch.cat(qf, 0) 120 | q_pids = np.asarray(q_pids) 121 | q_camids = 
np.asarray(q_camids) 122 | 123 | return qf, q_pids, q_camids 124 | 125 | def evaluate(self, query, gallery, query_loader, gallery_loader, path, visual, rerank): 126 | 127 | 128 | 129 | 130 | if visual: 131 | result = scipy.io.loadmat(path+'dist.mat') 132 | distmat = result['distmat'] 133 | save_dir = path + 'visual' 134 | visual_id = 4 135 | visualize_in_pic(distmat, query, gallery, save_dir, visual_id) 136 | 137 | else: 138 | qf, q_pids, q_camids = self.extract_feature(query_loader) 139 | torch.cuda.empty_cache() 140 | print('Done, obtained {}-by-{} matrix'.format(qf.size(0), qf.size(1))) 141 | 142 | gf, g_pids, g_camids = self.extract_feature(gallery_loader) 143 | gf = torch.cat((qf, gf), 0) 144 | g_pids = np.append(q_pids, g_pids) 145 | g_camids = np.append(q_camids, g_camids) 146 | print('Done, obtained {}-by-{} matrix'.format(gf.size(0), gf.size(1))) 147 | 148 | print("Computing distance matrix") 149 | 150 | distmat = cosin_dist(qf, gf).cpu().numpy() 151 | if rerank: 152 | print('Applying person re-ranking ...') 153 | distmat_qq = pairwise_distance_tensor(qf, qf).cpu().numpy() 154 | distmat_gg = pairwise_distance_tensor(gf, gf).cpu().numpy() 155 | distmat = re_ranking(distmat, distmat_qq, distmat_gg) 156 | 157 | print("save matrices for visualization") 158 | 159 | del query_loader 160 | del gallery_loader 161 | final = evaluate_seq(distmat, q_pids, q_camids, g_pids, g_camids, path) 162 | torch.cuda.empty_cache() 163 | return final 164 | -------------------------------------------------------------------------------- /reid/evaluator/eva_functions.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from collections import defaultdict 3 | 4 | import numpy as np 5 | import torch 6 | from sklearn.metrics import average_precision_score 7 | from utils import to_torch, to_numpy 8 | 9 | 10 | def _unique_sample(ids_dict, num): 11 | mask = np.zeros(num, dtype=np.bool) 12 | for _, indices in ids_dict.items(): 13 | i = np.random.choice(indices) 14 | mask[i] = True 15 | return mask 16 | 17 | 18 | def cmc(distmat, query_ids=None, gallery_ids=None, 19 | query_cams=None, gallery_cams=None, topk=100, 20 | separate_camera_set=False, 21 | single_gallery_shot=False, 22 | first_match_break=False): 23 | distmat = to_numpy(distmat) 24 | m, n = distmat.shape 25 | # Fill up default values 26 | if query_ids is None: 27 | query_ids = np.arange(m) 28 | if gallery_ids is None: 29 | gallery_ids = np.arange(n) 30 | if query_cams is None: 31 | query_cams = np.zeros(m).astype(np.int32) 32 | if gallery_cams is None: 33 | gallery_cams = np.ones(n).astype(np.int32) 34 | # Ensure numpy array 35 | query_ids = np.asarray(query_ids) 36 | gallery_ids = np.asarray(gallery_ids) 37 | query_cams = np.asarray(query_cams) 38 | gallery_cams = np.asarray(gallery_cams) 39 | # Sort and find correct matches 40 | indices = np.argsort(distmat, axis=1) 41 | matches = (gallery_ids[indices] == query_ids[:, np.newaxis]) 42 | # Compute CMC for each query 43 | ret = np.zeros(topk) 44 | num_valid_queries = 0 45 | for i in range(m): 46 | # Filter out the same id and same camera 47 | valid = ((gallery_ids[indices[i]] != query_ids[i]) | 48 | (gallery_cams[indices[i]] != query_cams[i])) 49 | if separate_camera_set: 50 | # Filter out samples from same camera 51 | valid &= (gallery_cams[indices[i]] != query_cams[i]) 52 | if not np.any(matches[i, valid]): continue 53 | if single_gallery_shot: 54 | repeat = 10 55 | gids = 
gallery_ids[indices[i][valid]] 56 | inds = np.where(valid)[0] 57 | ids_dict = defaultdict(list) 58 | for j, x in zip(inds, gids): 59 | ids_dict[x].append(j) 60 | else: 61 | repeat = 1 62 | for _ in range(repeat): 63 | if single_gallery_shot: 64 | # Randomly choose one instance for each id 65 | sampled = (valid & _unique_sample(ids_dict, len(valid))) 66 | index = np.nonzero(matches[i, sampled])[0] 67 | else: 68 | index = np.nonzero(matches[i, valid])[0] 69 | delta = 1. / (len(index) * repeat) 70 | for j, k in enumerate(index): 71 | if k - j >= topk: break 72 | if first_match_break: 73 | ret[k - j] += 1 74 | break 75 | ret[k - j] += delta 76 | num_valid_queries += 1 77 | if num_valid_queries == 0: 78 | raise RuntimeError("No valid query") 79 | return ret.cumsum() / num_valid_queries 80 | 81 | 82 | def mean_ap(distmat, query_ids=None, gallery_ids=None, 83 | query_cams=None, gallery_cams=None): 84 | distmat = to_numpy(distmat) # : (100, 100) 85 | m, n = distmat.shape 86 | # Fill up default values 87 | if query_ids is None: 88 | query_ids = np.arange(m) 89 | if gallery_ids is None: 90 | gallery_ids = np.arange(n) 91 | if query_cams is None: 92 | query_cams = np.zeros(m).astype(np.int32) 93 | if gallery_cams is None: 94 | gallery_cams = np.ones(n).astype(np.int32) 95 | # Ensure numpy array 96 | query_ids = np.asarray(query_ids) 97 | gallery_ids = np.asarray(gallery_ids) 98 | query_cams = np.asarray(query_cams) 99 | gallery_cams = np.asarray(gallery_cams) 100 | # Sort and find correct matches 101 | indices = np.argsort(distmat, axis=1) 102 | matches = (gallery_ids[indices] == query_ids[:, np.newaxis]) 103 | # Compute AP for each query 104 | aps = [] 105 | for i in range(m): 106 | # Filter out the same id and same camera 107 | valid = ((gallery_ids[indices[i]] != query_ids[i]) | 108 | (gallery_cams[indices[i]] != query_cams[i])) 109 | y_true = matches[i, valid] 110 | y_score = -distmat[i][indices[i]][valid] 111 | if not np.any(y_true): continue 112 | aps.append(average_precision_score(y_true, y_score)) 113 | if len(aps) == 0: 114 | raise RuntimeError("No valid query") 115 | return np.mean(aps) 116 | 117 | 118 | def accuracy(output, target, topk=(1,)): 119 | output, target = to_torch(output), to_torch(target) 120 | maxk = max(topk) 121 | batch_size = target.size(0) 122 | 123 | _, pred = output.topk(maxk, 1, True, True) 124 | pred = pred.t() 125 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 126 | 127 | ret = [] 128 | for k in topk: 129 | correct_k = correct[:k].view(-1).float().sum(0) 130 | ret.append(correct_k.mul_(1. / batch_size)) 131 | return ret 132 | 133 | 134 | def evaluate(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=100): 135 | num_q, num_g = distmat.shape # 1980,9330 136 | if num_g < max_rank: 137 | max_rank = num_g 138 | print("Note: number of gallery samples is quite small, got {}".format(num_g)) 139 | indices = np.argsort(distmat, axis=1) # torch.Size([1980, 9330]) 140 | 141 | matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32) # torch.Size([1980, 9330]) 142 | 143 | # compute cmc curve for each query 144 | all_cmc = [] 145 | all_AP = [] 146 | num_valid_q = 0. 
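# --- Worked toy example (illustrative only, not part of the original file):
# suppose that after masking, one query has orig_cmc = [0, 1, 0, 1].
# cmc = orig_cmc.cumsum() -> [0, 1, 1, 2]; clipping values > 1 gives
# [0, 1, 1, 1], i.e. rank-1 accuracy 0 and rank-2..4 accuracy 1 for this query.
# AP: num_rel = 2, precisions [0/1, 1/2, 1/3, 2/4] are kept only at the hit
# positions, so AP = (0.5 + 0.5) / 2 = 0.5.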
147 | for q_idx in range(num_q): 148 | # get query pid and camid 149 | q_pid = q_pids[q_idx] 150 | q_camid = q_camids[q_idx] 151 | # remove gallery samples that have the same pid and camid with query 152 | order = indices[q_idx] # torch.Size([9330]) 153 | remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid) 154 | keep = np.invert(remove) 155 | 156 | # compute cmc curve 157 | orig_cmc = matches[q_idx][keep] # binary vector, positions with value 1 are correct matches 158 | 159 | if not np.any(orig_cmc): 160 | # this condition is true when query identity does not appear in gallery 161 | continue 162 | 163 | cmc = orig_cmc.cumsum() 164 | cmc[cmc > 1] = 1 165 | all_cmc.append(cmc[:max_rank]) 166 | num_valid_q += 1. 167 | 168 | # compute average precision 169 | # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision 170 | num_rel = orig_cmc.sum() 171 | tmp_cmc = orig_cmc.cumsum() 172 | tmp_cmc = [x / (i+1.) for i, x in enumerate(tmp_cmc)] 173 | tmp_cmc = np.asarray(tmp_cmc) * orig_cmc 174 | AP = tmp_cmc.sum() / num_rel 175 | 176 | all_AP.append(AP) 177 | 178 | assert num_valid_q > 0, "Error: all query identities do not appear in gallery" 179 | 180 | all_cmc = np.asarray(all_cmc).astype(np.float32) 181 | all_cmc = all_cmc.sum(0) / num_valid_q # : (50,) 182 | mAP = np.mean(all_AP) 183 | 184 | return all_cmc, mAP 185 | 186 | 187 | def evaluate_zhengliang(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=100): 188 | cmc = np.zeros((distmat.shape[0], max_rank)) # (1980, 100) 189 | ap = np.zeros(distmat.shape[0]) # (1980,) 190 | 191 | junk_mask0 = (g_pids == -1) # gallery samples with id -1 are junk and are ignored; 9330 in total 192 | num_valid_q = 0. 193 | for k in range(distmat.shape[0]): 194 | score = distmat[k, :] 195 | good_idx = np.where((q_pids[k] == g_pids) & (q_camids[k] != g_camids))[0] # 18 196 | if len(good_idx) == 0: 197 | num_valid_q = num_valid_q 198 | continue 199 | else: 200 | num_valid_q += 1 201 | junk_mask1 = ((q_pids[k] == g_pids) & (q_camids[k] == g_camids)) 202 | junk_idx = np.where(junk_mask0 | junk_mask1)[0] 203 | sort_idx = np.argsort(score)[:max_rank] 204 | ap[k], cmc[k, :] = Compute_AP(good_idx, junk_idx, sort_idx) 205 | 206 | all_cmc = np.asarray(cmc).astype(np.float32) 207 | CMC = all_cmc.sum(0) / num_valid_q # : (50,) 208 | 209 | mAP = np.mean(ap) 210 | return CMC, mAP 211 | 212 | 213 | def Compute_AP(good_idx, junk_idx, index): 214 | cmc = np.zeros((len(index), )) 215 | num_real = len(good_idx) 216 | 217 | old_recall = 0 218 | old_precision = 1. 
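# --- Illustrative aside (not part of the original file): the loop below
# integrates precision over recall with the trapezoidal rule,
# ap += (recall - old_recall) * (old_precision + precision) / 2.
# E.g. (hypothetical) a hit at rank 1 with num_real = 2 moves recall
# from 0 to 0.5 while precision stays 1.0, so ap gains 0.5 * (1 + 1) / 2 = 0.5.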
219 | ap = 0 220 | intersect_size = 0 221 | j = 0 222 | good_now = 0 223 | njunk = 0 224 | for n in range(len(index)): # rank N 225 | flag = 0 226 | if np.any(good_idx == index[n]): 227 | cmc[n - njunk:] = 1 228 | flag = 1 # good image 229 | good_now += 1 230 | if np.any(junk_idx == index[n]): 231 | njunk += 1 232 | continue # junk image 233 | 234 | if flag == 1: 235 | intersect_size += 1 236 | recall = intersect_size / num_real # 1 / 21 = 0.047 237 | precision = intersect_size / (j + 1) # 1 238 | ap += (recall - old_recall) * (old_precision + precision) / 2 239 | old_recall = recall 240 | old_precision = precision 241 | j += 1 242 | 243 | if good_now == num_real: 244 | return ap, cmc 245 | return ap, cmc -------------------------------------------------------------------------------- /reid/evaluator/evaluator.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import time 3 | import torch 4 | from torch.autograd import Variable 5 | from utils.meters import AverageMeter 6 | from utils import to_numpy 7 | from .eva_functions import cmc, mean_ap 8 | import numpy as np 9 | 10 | 11 | def evaluate_seq(distmat, query_pids, query_camids, gallery_pids, gallery_camids, cmc_topk=(1, 5, 10)): 12 | query_ids = np.array(query_pids) 13 | gallery_ids = np.array(gallery_pids) 14 | query_cams = np.array(query_camids) 15 | gallery_cams = np.array(gallery_camids) 16 | 17 | ## 18 | mAP = mean_ap(distmat, query_ids, gallery_ids, query_cams, gallery_cams) 19 | print('Mean AP: {:4.1%}'.format(mAP)) 20 | 21 | cmc_configs = { 22 | 'allshots': dict(separate_camera_set=False, 23 | single_gallery_shot=False, 24 | first_match_break=False), 25 | 'cuhk03': dict(separate_camera_set=True, 26 | single_gallery_shot=True, 27 | first_match_break=False), 28 | 'market1501': dict(separate_camera_set=False, 29 | single_gallery_shot=False, 30 | first_match_break=True)} 31 | cmc_scores = {name: cmc(distmat, query_ids, gallery_ids, 32 | query_cams, gallery_cams, **params) 33 | for name, params in cmc_configs.items()} 34 | 35 | print('CMC Scores{:>12}{:>12}{:>12}' 36 | .format('allshots', 'cuhk03', 'market1501')) 37 | for k in cmc_topk: 38 | print(' top-{:<4}{:12.1%}{:12.1%}{:12.1%}' 39 | .format(k, cmc_scores['allshots'][k - 1], 40 | cmc_scores['cuhk03'][k - 1], 41 | cmc_scores['market1501'][k - 1])) 42 | 43 | # Use the allshots cmc top-1 score for validation criterion 44 | return mAP 45 | 46 | 47 | def pairwise_distance_tensor(query_x, gallery_x): 48 | 49 | m, n = query_x.size(0), gallery_x.size(0) 50 | x = query_x.view(m, -1) 51 | y = gallery_x.view(n, -1) 52 | dist = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) +\ 53 | torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t() 54 | dist.addmm_(1, -2, x, y.t()) 55 | 56 | return dist 57 | 58 | 59 | class CNNEvaluator(object): 60 | 61 | def __init__(self, cnn_model, mode): 62 | super(CNNEvaluator, self).__init__() 63 | self.cnn_model = cnn_model 64 | self.mode = mode 65 | 66 | def extract_feature(self, cnn_model, data_loader): 67 | print_freq = 50 68 | cnn_model.eval() 69 | batch_time = AverageMeter() 70 | data_time = AverageMeter() 71 | end = time.time() 72 | 73 | allfeatures = 0 74 | 75 | for i, (imgs, flows, _, _) in enumerate(data_loader): 76 | data_time.update(time.time() - end) 77 | imgs = Variable(imgs, volatile=True) 78 | flows = Variable(flows, volatile=True) 79 | 80 | if i == 0: 81 | out_feat = self.cnn_model(imgs, flows, self.mode) 82 | allfeatures = out_feat.data 83 | 
preimgs = imgs
84 |                 preflows = flows
85 |             elif imgs.size(0):
[... rest of reid/evaluator/evaluator.py is truncated in this dump ...]
--------------------------------------------------------------------------------
/reid/evaluator/rerank.py:
--------------------------------------------------------------------------------
[... lines 1-44 of rerank.py are truncated in this dump ...]
45 |     original_dist = np.power(original_dist, 2).astype(np.float32)  # : (11310, 11310)
46 |     original_dist = np.transpose(1. * original_dist / np.max(original_dist, axis=0))  # : (11310, 11310)
47 |     V = np.zeros_like(original_dist).astype(np.float32)  # : (11310, 11310)
48 |     initial_rank = np.argsort(original_dist).astype(np.int32)  # : (11310, 11310)
49 |
50 |     query_num = q_g_dist.shape[0]  # 1980
51 |     gallery_num = q_g_dist.shape[0] + q_g_dist.shape[1]  # 11310
52 |     all_num = gallery_num  # 11310
53 |
54 |     for i in range(all_num):
55 |         # k-reciprocal neighbors
56 |         forward_k_neigh_index = initial_rank[i, :k1 + 1]  # 21
57 |         backward_k_neigh_index = initial_rank[forward_k_neigh_index, :k1 + 1]  # : (21, 21)
58 |         fi = np.where(backward_k_neigh_index == i)[0]  # [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19]
59 |         k_reciprocal_index = forward_k_neigh_index[fi]  # : (20,)
60 |         k_reciprocal_expansion_index = k_reciprocal_index  # : (20,)
61 |         for j in range(len(k_reciprocal_index)):
62 |             candidate = k_reciprocal_index[j]  # 0
63 |             candidate_forward_k_neigh_index = initial_rank[candidate, :int(np.around(k1 / 2.)) + 1]  # : (11,)
64 |             candidate_backward_k_neigh_index = initial_rank[candidate_forward_k_neigh_index,
65 |                                                             :int(np.around(k1 / 2.)) + 1]  # : (11, 11)
66 |             fi_candidate = np.where(candidate_backward_k_neigh_index == candidate)[0]  # : (7,)
67 |             candidate_k_reciprocal_index = candidate_forward_k_neigh_index[fi_candidate]  # [ 0 5238 5251 5245 1 5252 2]
68 |             if len(np.intersect1d(candidate_k_reciprocal_index, k_reciprocal_index)) > 2. / 3 * len(
69 |                     candidate_k_reciprocal_index):
70 |                 k_reciprocal_expansion_index = np.append(k_reciprocal_expansion_index, candidate_k_reciprocal_index)
71 |
72 |         k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index)  # : (23,)
73 |         weight = np.exp(-original_dist[i, k_reciprocal_expansion_index])  # : (23,)
74 |         V[i, k_reciprocal_expansion_index] = 1. * weight / np.sum(weight)
75 |     original_dist = original_dist[:query_num, ]  # : (1980, 11310) after keeping only the query rows
76 |     if k2 != 1:
77 |         V_qe = np.zeros_like(V, dtype=np.float32)  # : (11310, 11310)
78 |         for i in range(all_num):
79 |             V_qe[i, :] = np.mean(V[initial_rank[i, :k2], :], axis=0)
80 |         V = V_qe  # : (11310, 11310)
81 |         del V_qe
82 |     del initial_rank
83 |     invIndex = []
84 |     for i in range(gallery_num):
85 |         invIndex.append(np.where(V[:, i] != 0)[0])
86 |
87 |     jaccard_dist = np.zeros_like(original_dist, dtype=np.float32)  # : (1980, 11310)
88 |
89 |     for i in range(query_num):
90 |         temp_min = np.zeros(shape=[1, gallery_num], dtype=np.float32)
91 |         indNonZero = np.where(V[i, :] != 0)[0]
92 |         indImages = []
93 |         indImages = [invIndex[ind] for ind in indNonZero]
94 |         for j in range(len(indNonZero)):
95 |             temp_min[0, indImages[j]] = temp_min[0, indImages[j]] + np.minimum(V[i, indNonZero[j]],
96 |                                                                                V[indImages[j], indNonZero[j]])
97 |         jaccard_dist[i] = 1 - temp_min / (2. - temp_min)
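# k-reciprocal re-ranking: this implementation matches Zhong et al.,
# "Re-ranking Person Re-identification with k-reciprocal Encoding" (CVPR 2017).
# V holds each sample's k-reciprocal neighbour set as a Gaussian-weighted
# sparse vector, the loop above converts V into a query-to-gallery Jaccard
# distance, and the line below blends it with the original distance:
# lambda_value weights the original metric (0.0 = pure Jaccard distance,
# 1.0 = no re-ranking).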
98 |
99 |     final_dist = jaccard_dist * (1 - lambda_value) + original_dist * lambda_value  # : (1980, 11310)
100 |     del original_dist
101 |     del V
102 |     del jaccard_dist
103 |     final_dist = final_dist[:query_num, query_num:]  # : (1980, 9330)
104 |     return final_dist
105 |
--------------------------------------------------------------------------------
/reid/evaluator/visualize.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding:utf-8 -*-
3 | # @time: 2019/6/12 11:09 AM
4 | # @Author: Yu Ci
5 | __all__ = ['visualize_ranked_results', 'visualize_in_pic']
6 |
7 | import torch
8 | import numpy as np
9 | import os
10 | import os.path as osp
11 | import shutil
12 | import matplotlib.pyplot as plt  # restored: imshow() and visualize_in_pic() below need pyplot
13 |
14 | from utils.osutils import mkdir_if_missing
15 |
16 |
17 | def visualize_ranked_results(distmat, queryloader, galleryloader, save_dir='', visual_id=2, topk=10):
18 |     """Visualizes ranked results; all outputs for one query are saved into a single folder.
19 |
20 |     Supports both image-reid and video-reid.
21 |
22 |     Args:
23 |         distmat (numpy.ndarray): distance matrix of shape (num_query, num_gallery).
24 |         queryloader (tuple): tuples of (img_path(s), pid, camid).
25 |         galleryloader (tuple): tuples of (img_path(s), pid, camid).
26 |         save_dir (str): directory to save output images.
27 |         visual_id (int, optional): only show this one id.
28 |         topk (int, optional): denoting top-k images in the rank list to be visualized.
29 |     """
30 |     num_q, num_g = distmat.shape
31 |
32 |     print('Visualizing top-{} ranks'.format(topk))
33 |     print('# query: {}\n# gallery {}'.format(num_q, num_g))
34 |     print('Saving images to "{}"'.format(save_dir))
35 |
36 |     query = queryloader  # 1980 tuples of (img_path(s), pid, camid)
37 |     gallery = galleryloader  # 9330 tuples of (img_path(s), pid, camid)
38 |     assert num_q == len(query)
39 |     assert num_g == len(gallery)
40 |
41 |     indices = np.argsort(distmat, axis=1)  # : (1980, 9330)
42 |     mkdir_if_missing(save_dir)  # '/home/ying/Desktop/mars_rank/log/debug_for_eval/split0visual'
43 |
44 |     def _cp_img_to(src, dst, rank, prefix):
45 |         """
46 |         Args:
47 |             src: image path or tuple (for vidreid)
48 |             dst: target directory  # '/home/ying/Desktop/mars_rank/log/debug_for_eval/split0visual/0016C1T0006F001.jpg'
49 |             rank: int, denoting ranked position, starting from 1
50 |             prefix: string (query or gallery)
51 |         """
52 |         if isinstance(src, tuple) or isinstance(src, list):  # video reid
53 |             dst = osp.join(dst, prefix + '_top' + str(rank).zfill(3))  # '/home/ying/Desktop/mars_rank/log/debug_for_eval/split0visual/0016C1T0006F001.jpg/query_top000'
54 |             mkdir_if_missing(dst)
55 |             for img_path in src:  # copy the images into the target folder
56 |                 shutil.copy(img_path, dst)
57 |         else:
58 |             dst = osp.join(dst, prefix + '_top' + str(rank).zfill(3) + '_name_' + osp.basename(src))
59 |             shutil.copy(src, dst)
60 |
61 |     for q_idx in range(num_q):  # for speed, only one query's rank list is written out; visual_id is the query's position in the tuple list, not an actual person id
62 |         if q_idx == visual_id:  # 14
63 |             qimg_path, qpid, qcamid = query[q_idx]  # qpid = 16, camid = 0
64 |
65 |             if isinstance(qimg_path, tuple) or isinstance(qimg_path, list):  # the folder holding this query's rank list is named after the query's first image
66 |                 qdir = osp.join(save_dir, osp.basename(qimg_path[0]))  # '/home/ying/Desktop/mars_rank/log/debug_for_eval/split0visual/0016C1T0006F001.jpg'
67 |             else:
68 |                 qdir = osp.join(save_dir, osp.basename(qimg_path))
69 |             mkdir_if_missing(qdir)  # create the folder that stores this query's rank results
70 |             _cp_img_to(qimg_path, qdir, rank=0, prefix='query')  # copy the query images into the result folder
71 |
72 |             rank_idx = 1
73 |             for g_idx in indices[q_idx, :]:  # 3291, 3288, 3289, 3290, 3293
74 |                 gimg_path, gpid, gcamid = gallery[g_idx]
75 |                 invalid = (qpid == gpid) & (qcamid == gcamid)  # True: skip same-id matches from the same camera
76 |                 if not invalid:
77 |                     _cp_img_to(gimg_path, qdir, rank=rank_idx, prefix='gallery')
78 |                     rank_idx += 1
79 |                     if rank_idx > topk:
80 |                         break
81 |     print("Done")
82 |
83 |
84 | def visualize_in_pic(distmat, queryloader, galleryloader, save_dir='', visual_id=2, topk=9):
85 |     """
86 |
87 |     distmat (numpy.ndarray): distance matrix of shape (num_query, num_gallery).
88 |     queryloader (tuple): tuples of (img_path(s), pid, camid).
89 |     galleryloader (tuple): tuples of (img_path(s), pid, camid).
90 |     save_dir (str): directory to save output images.
91 |     visual_id (int, optional): only show this one id.
92 |     topk (int, optional): denoting top-k images in the rank list to be visualized.
93 |     """
94 |     num_q, num_g = distmat.shape
95 |
96 |     print('Visualizing top-{} ranks'.format(topk+1))
97 |     print('# query: {}\n# gallery {}'.format(num_q, num_g))
98 |     print('Saving images to "{}"'.format(save_dir))
99 |
100 |     query = queryloader  # 1980 tuples of (img_path(s), pid, camid)
101 |     gallery = galleryloader  # 9330 tuples of (img_path(s), pid, camid)
102 |     assert num_q == len(query)
103 |     assert num_g == len(gallery)
104 |
105 |     indices = np.argsort(distmat, axis=1)  # : (1980, 9330)
106 |     mkdir_if_missing(save_dir)  # '/home/ying/Desktop/mars_rank/log/debug_for_eval/split0visual'
107 |
108 |     def imshow(path, title=None):
109 |         """Imshow for Tensor."""
110 |         im = plt.imread(path)
111 |         plt.imshow(im)
112 |         if title is not None:
113 |             plt.title(title)
114 |         # plt.pause(0.001)  # pause a bit so that plots are updated
115 |     flag = 0
116 |     for q_idx in range(num_q):  # for speed, only one pid's rank list is drawn; query pids run 2, 4, 6, 8, 10, ...
117 |         qimg_path, qpid, qcamid = query[q_idx]  # qpid = 16, camid = 0
118 |
119 |         if qpid == visual_id:  # 14
120 |             flag = 1
121 |             fig = plt.figure(figsize=(25, 8))
122 |             ax = plt.subplot(1, 11, 1)
123 |             ax.axis('off')
124 |             imshow(qimg_path[0], 'query, pid:{}'.format(qpid))
125 |
126 |             rank_idx = 0
127 |             for g_idx in indices[q_idx, :]:  # 3291, 3288, 3289, 3290, 3293
128 |                 gimg_path, gpid, gcamid = gallery[g_idx]
129 |                 # invalid = (qpid == gpid) & (qcamid == gcamid)  # True would skip same-id matches from the same camera
130 |                 invalid = False
131 |                 if not invalid:
132 |                     rank_idx += 1
133 |                     ax = plt.subplot(1, 11, rank_idx+1)
134 |                     ax.axis('off')
135 |                     imshow(gimg_path[0])
136 |                     if qpid == gpid:
137 |                         ax.set_title('rank:{},pid{}_{}'.format(rank_idx, gpid, gcamid), color='green')
138 |                     else:
139 |                         ax.set_title('rank:{},pid{}_{}'.format(rank_idx, gpid, gcamid), color='red')
140 |
141 |                     if rank_idx > topk:
142 |                         break
143 |             fig.savefig("show_{}_{}.png".format(qpid, qcamid))
144 |             break
145 |     if flag == 1:
146 |         print("Done")
147 |     else:
148 |         print("No matched person in query_dataset, try another id")
149 |
--------------------------------------------------------------------------------
/reid/loss/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | from .oim import oim, OIM, OIMLoss
4 | from .pairloss import PairLoss
5 | from .triplet_oim import TripletLoss_OIM
6 | from .triplet import TripletLoss
7 |
8 | __all__ = [
9 |     'oim',
10 |     'OIM',
11 |     'OIMLoss',
12 |     'PairLoss',
13 |     'TripletLoss',
14 |     'TripletLoss_OIM'
15 | ]
16 |
17 |
18 |
--------------------------------------------------------------------------------
/reid/loss/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/flysnowtiger/GRL/26eed6542b636ab373dbc909d5d92b2d2b0cd7ae/reid/loss/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /reid/loss/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flysnowtiger/GRL/26eed6542b636ab373dbc909d5d92b2d2b0cd7ae/reid/loss/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /reid/loss/__pycache__/oim.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flysnowtiger/GRL/26eed6542b636ab373dbc909d5d92b2d2b0cd7ae/reid/loss/__pycache__/oim.cpython-36.pyc -------------------------------------------------------------------------------- /reid/loss/__pycache__/oim.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flysnowtiger/GRL/26eed6542b636ab373dbc909d5d92b2d2b0cd7ae/reid/loss/__pycache__/oim.cpython-37.pyc -------------------------------------------------------------------------------- /reid/loss/__pycache__/pairloss.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flysnowtiger/GRL/26eed6542b636ab373dbc909d5d92b2d2b0cd7ae/reid/loss/__pycache__/pairloss.cpython-36.pyc -------------------------------------------------------------------------------- /reid/loss/__pycache__/pairloss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flysnowtiger/GRL/26eed6542b636ab373dbc909d5d92b2d2b0cd7ae/reid/loss/__pycache__/pairloss.cpython-37.pyc -------------------------------------------------------------------------------- /reid/loss/__pycache__/triplet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flysnowtiger/GRL/26eed6542b636ab373dbc909d5d92b2d2b0cd7ae/reid/loss/__pycache__/triplet.cpython-36.pyc -------------------------------------------------------------------------------- /reid/loss/__pycache__/triplet_oim.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flysnowtiger/GRL/26eed6542b636ab373dbc909d5d92b2d2b0cd7ae/reid/loss/__pycache__/triplet_oim.cpython-36.pyc -------------------------------------------------------------------------------- /reid/loss/oim.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn, autograd 6 | 7 | 8 | class OIM(autograd.Function): 9 | def __init__(self, lut, momentum=0.5): 10 | super(OIM, self).__init__() 11 | self.lut = lut # torch.Size([625, 128]) 12 | self.momentum = momentum 13 | 14 | def forward(self, inputs, targets): 15 | self.save_for_backward(inputs, targets) # inputs: torch.Size([64, 128]) 16 | outputs = inputs.mm(self.lut.t()) # (64, 128) * (128, 625) 17 | return outputs # torch.Size([64, 625]) 18 | 19 | def backward(self, grad_outputs): 20 | inputs, targets = self.saved_tensors 21 | grad_inputs = None 22 | if self.needs_input_grad[0]: 23 | grad_inputs = grad_outputs.mm(self.lut) 24 | for x, y in zip(inputs, targets): 25 | self.lut[y] = 
self.momentum * self.lut[y] + (1. - self.momentum) * x 26 | self.lut[y] /= self.lut[y].norm() 27 | return grad_inputs, None 28 | 29 | 30 | def oim(inputs, targets, lut, momentum=0.5): 31 | return OIM(lut, momentum=momentum)(inputs, targets) 32 | 33 | 34 | class OIMLoss(nn.Module): 35 | def __init__(self, num_features, num_classes, scalar=1.0, momentum=0.5, 36 | weight=None, size_average=True): 37 | super(OIMLoss, self).__init__() 38 | self.num_features = num_features # 512 39 | self.num_classes = num_classes # 625 40 | self.momentum = momentum # 0.5 41 | self.scalar = scalar # 30 42 | self.weight = weight # None 43 | self.register_buffer('lut', torch.zeros(num_classes, num_features)) 44 | self.size_average = size_average # True 45 | 46 | def forward(self, inputs, targets):# 47 | 48 | inputs = oim(inputs, targets, self.lut, momentum=self.momentum) 49 | 50 | inputs *= self.scalar 51 | 52 | loss = F.cross_entropy(inputs, targets, weight=self.weight) 53 | return loss, inputs 54 | -------------------------------------------------------------------------------- /reid/loss/pairloss.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | from torch import nn 5 | 6 | from reid.evaluator import accuracy 7 | 8 | 9 | class PairLoss(nn.Module): 10 | def __init__(self): 11 | super(PairLoss, self).__init__() 12 | 13 | # self.sigmod = nn.Sigmoid() 14 | self.BCE = nn.BCELoss() 15 | self.BCE.size_average = True 16 | self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 17 | 18 | def forward(self, score, tar_probe, tar_gallery): 19 | cls_Size = score.size() # torch.Size([4, 2]) 20 | N_probe = cls_Size[0] # 4 21 | N_gallery = cls_Size[0] 22 | 23 | tar_gallery = tar_gallery.unsqueeze(1) # 6,1 tensor([[ 94],[ 10],[ 15],[ 16],[ 75],[ 39]]) 24 | tar_probe = tar_probe.unsqueeze(0) # 1,6 tensor([[ 94, 10, 15, 16, 75, 39]]) 25 | mask = tar_probe.expand(N_probe, N_gallery).eq(tar_gallery.expand(N_probe, N_gallery)) 26 | mask = mask.view(-1).cpu().numpy().tolist() 27 | # [1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1] 28 | 29 | score = score.contiguous() # torch.Size([4, 4]) 30 | samplers = score.view(-1) # torch.Size([16]) 31 | 32 | # samplers = self.sigmod(samplers) 33 | # labels = Variable(torch.Tensor(mask).cuda()) 34 | labels = torch.Tensor(mask).to(self.device) 35 | 36 | loss = self.BCE(samplers, labels) 37 | 38 | samplers_data = samplers.data # torch.Size([36]) 39 | samplers_neg = 1 - samplers_data 40 | samplerdata = torch.cat((samplers_neg.unsqueeze(1), samplers_data.unsqueeze(1)), 1) # torch.Size([36, 2]) 41 | 42 | labeldata = torch.LongTensor(mask).to(self.device) 43 | prec, = accuracy(samplerdata, labeldata) 44 | 45 | return loss, prec 46 | -------------------------------------------------------------------------------- /reid/loss/triplet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class TripletLoss(nn.Module): 6 | 7 | def __init__(self, margin=0, batch_hard=False, dim=2048): 8 | super(TripletLoss, self).__init__() 9 | self.batch_hard = batch_hard # True 10 | if isinstance(margin, float) or margin == 'soft': 11 | self.margin = margin # ‘soft’ 12 | else: 13 | raise NotImplementedError( 14 | 'The margin {} is not recognized in TripletLoss()'.format(margin)) 15 | 16 | def forward(self, feat, id=None, pos_mask=None, neg_mask=None, 
mode='id', dis_func='eu', n_dis=0): 17 | if dis_func == 'cdist': 18 | feat = feat / feat.norm(p=2, dim=1, keepdim=True) 19 | dist = self.cdist(feat, feat) 20 | elif dis_func == 'eu': 21 | dist = self.cdist(feat, feat) # torch.Size([8, 8]) 22 | 23 | if mode == 'id': 24 | if id is None: 25 | raise RuntimeError('foward is in id mode, please input id!') 26 | else: 27 | identity_mask = torch.eye(feat.size(0)).byte() # torch.Size([8, 8]) 28 | identity_mask = identity_mask.cuda() if id.is_cuda else identity_mask 29 | same_id_mask = torch.eq(id.unsqueeze(1), id.unsqueeze(0)) 30 | negative_mask = same_id_mask ^ 1 # ^ 异或操作,同为0,异为1 31 | positive_mask = same_id_mask ^ identity_mask 32 | elif mode == 'mask': 33 | if pos_mask is None or neg_mask is None: 34 | raise RuntimeError('foward is in mask mode, please input pos_mask & neg_mask!') 35 | else: 36 | positive_mask = pos_mask 37 | same_id_mask = neg_mask ^ 1 38 | negative_mask = neg_mask 39 | else: 40 | raise ValueError('unrecognized mode') 41 | 42 | if self.batch_hard: 43 | if n_dis != 0: 44 | img_dist = dist[:-n_dis, :-n_dis] 45 | max_positive = (img_dist * positive_mask[:-n_dis, :-n_dis].float()).max(1)[0] 46 | min_negative = (img_dist + 1e5 * same_id_mask[:-n_dis, :-n_dis].float()).min(1)[0] 47 | dis_min_negative = dist[:-n_dis, -n_dis:].min(1)[0] 48 | z_origin = max_positive - min_negative 49 | # z_dis = max_positive - dis_min_negative 50 | else: 51 | max_positive = dist * positive_mask.float() 52 | max_positive = max_positive.max(1)[0] 53 | # tensor([11.2461, 11.0022, 11.2461, 11.1370, 8.9170, 8.4666, 8.9170, 8.4710]) 54 | same_id_mask = 1e5 * same_id_mask.float() 55 | min_negative = dist + same_id_mask 56 | min_negative = min_negative.min(1)[ 57 | 0] # tensor([ 9.5545, 10.7909, 9.3813, 10.1063, 10.7909, 10.1282, 9.3813, 11.3685], 58 | z = max_positive - min_negative 59 | # tensor([3.6010, 2.3646, 2.0217, 2.9059, 0.7020, 2.2440, 1.1459, 1.0037], 60 | else: 61 | pos = positive_mask.topk(k=1, dim=1)[1].view(-1, 1) 62 | positive = torch.gather(dist, dim=1, index=pos) 63 | pos = negative_mask.topk(k=1, dim=1)[1].view(-1, 1) 64 | negative = torch.gather(dist, dim=1, index=pos) 65 | z = positive - negative 66 | 67 | if isinstance(self.margin, float): 68 | b_loss = torch.clamp(z + self.margin, min=0) 69 | elif self.margin == 'soft': 70 | if n_dis != 0: 71 | b_loss = torch.log(1 + torch.exp(z_origin)) + -0.5 * dis_min_negative # + torch.log(1+torch.exp(z_dis)) 72 | else: 73 | b_loss = torch.log(1 + torch.exp(z)) 74 | else: 75 | raise NotImplementedError("How do you even get here!") 76 | return b_loss 77 | 78 | def cdist(self, a, b): 79 | ''' 80 | Returns euclidean distance between a and b 81 | 82 | Args: 83 | a (2D Tensor): A batch of vectors shaped (B1, D) 84 | b (2D Tensor): A batch of vectors shaped (B2, D) 85 | Returns: 86 | A matrix of all pairwise distance between all vectors in a and b, 87 | will be shape of (B1, B2) 88 | ''' 89 | diff = a.unsqueeze(1) - b.unsqueeze(0) 90 | return ((diff ** 2).sum(2) + 1e-12).sqrt() 91 | -------------------------------------------------------------------------------- /reid/loss/triplet_oim.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class TripletLoss_OIM(nn.Module): 6 | 7 | def __init__(self, margin=0, batch_hard=False, dim=2048): 8 | super(TripletLoss_OIM, self).__init__() 9 | self.batch_hard = batch_hard # True 10 | if isinstance(margin, float) or margin == 'soft': 11 | self.margin = margin # ‘soft’ 12 | else: 13 | 
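# Accepted margins: a float (hard hinge, torch.clamp(z + margin, min=0)) or the
# string 'soft' (softplus, log(1 + exp(z))) -- see forward() below. trainer.py
# builds this loss as TripletLoss_OIM('soft', True). A minimal usage sketch
# (the names `feats`, `pids` and `oim_lut` are illustrative, not from this repo):
#
#     criterion = TripletLoss_OIM('soft', batch_hard=True)
#     per_sample = criterion(feats, oim_lut, id=pids)  # anchor-vs-LUT distances
#     loss = per_sample.mean()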
raise NotImplementedError( 14 | 'The margin {} is not recognized in TripletLoss()'.format(margin)) 15 | 16 | def forward(self, feat, lut, id=None, pos_mask=None, neg_mask=None, mode='id', dis_func='eu', n_dis=0): 17 | feat_OIM = [] 18 | for i in id: 19 | feat_OIM.append(lut[i].unsqueeze(0)) 20 | feat_oim = torch.cat(feat_OIM, dim=0) 21 | if dis_func == 'cdist': 22 | feat = feat / feat.norm(p=2, dim=1, keepdim=True) 23 | dist = self.cdist(feat, feat) 24 | elif dis_func == 'eu': 25 | dist = self.cdist(feat, feat_oim) # torch.Size([8, 8]) 26 | 27 | if mode == 'id': 28 | if id is None: 29 | raise RuntimeError('foward is in id mode, please input id!') 30 | else: 31 | identity_mask = torch.eye(feat.size(0)).byte() # torch.Size([8, 8]) 32 | identity_mask = identity_mask.cuda() if id.is_cuda else identity_mask 33 | same_id_mask = torch.eq(id.unsqueeze(1), id.unsqueeze(0)) 34 | negative_mask = same_id_mask ^ 1 # ^ 异或操作,同为0,异为1 35 | positive_mask = same_id_mask ^ identity_mask 36 | elif mode == 'mask': 37 | if pos_mask is None or neg_mask is None: 38 | raise RuntimeError('foward is in mask mode, please input pos_mask & neg_mask!') 39 | else: 40 | positive_mask = pos_mask 41 | same_id_mask = neg_mask ^ 1 42 | negative_mask = neg_mask 43 | else: 44 | raise ValueError('unrecognized mode') 45 | 46 | if self.batch_hard: 47 | if n_dis != 0: 48 | img_dist = dist[:-n_dis, :-n_dis] 49 | max_positive = (img_dist * positive_mask[:-n_dis, :-n_dis].float()).max(1)[0] 50 | min_negative = (img_dist + 1e5 * same_id_mask[:-n_dis, :-n_dis].float()).min(1)[0] 51 | dis_min_negative = dist[:-n_dis, -n_dis:].min(1)[0] 52 | z_origin = max_positive - min_negative 53 | # z_dis = max_positive - dis_min_negative 54 | else: 55 | max_positive = dist * positive_mask.float() 56 | max_positive = max_positive.max(1)[0] 57 | # tensor([11.2461, 11.0022, 11.2461, 11.1370, 8.9170, 8.4666, 8.9170, 8.4710]) 58 | same_id_mask = 1e5 * same_id_mask.float() 59 | min_negative = dist + same_id_mask 60 | min_negative = min_negative.min(1)[ 61 | 0] # tensor([ 9.5545, 10.7909, 9.3813, 10.1063, 10.7909, 10.1282, 9.3813, 11.3685], 62 | z = max_positive - min_negative # tensor([3.6010, 2.3646, 2.0217, 2.9059, 0.7020, 2.2440, 1.1459, 1.0037], 63 | else: 64 | pos = positive_mask.topk(k=1, dim=1)[1].view(-1, 1) 65 | positive = torch.gather(dist, dim=1, index=pos) 66 | pos = negative_mask.topk(k=1, dim=1)[1].view(-1, 1) 67 | negative = torch.gather(dist, dim=1, index=pos) 68 | z = positive - negative 69 | 70 | if isinstance(self.margin, float): 71 | b_loss = torch.clamp(z + self.margin, min=0) 72 | elif self.margin == 'soft': 73 | if n_dis != 0: 74 | b_loss = torch.log(1 + torch.exp(z_origin)) + -0.5 * dis_min_negative # + torch.log(1+torch.exp(z_dis)) 75 | else: 76 | b_loss = torch.log(1 + torch.exp(z)) 77 | else: 78 | raise NotImplementedError("How do you even get here!") 79 | return b_loss 80 | 81 | def cdist(self, a, b): 82 | ''' 83 | Returns euclidean distance between a and b 84 | 85 | Args: 86 | a (2D Tensor): A batch of vectors shaped (B1, D) 87 | b (2D Tensor): A batch of vectors shaped (B2, D) 88 | Returns: 89 | A matrix of all pairwise distance between all vectors in a and b, 90 | will be shape of (B1, B2) 91 | ''' 92 | diff = a.unsqueeze(1) - b.unsqueeze(0) 93 | return ((diff ** 2).sum(2) + 1e-12).sqrt() 94 | -------------------------------------------------------------------------------- /reid/models/Siamese.py: -------------------------------------------------------------------------------- 1 | #!usr/bin/env python 2 | # -*- coding:utf-8 
_*- 3 | # @author: ycy 4 | # @contact: asuradayuci@gmail.com 5 | # @time: 2019/9/7 下午2:53 6 | import torch 7 | from torch import nn 8 | import torch.nn.functional as F 9 | #!usr/bin/env python 10 | # -*- coding:utf-8 _*- 11 | # @author: ycy 12 | # @contact: asuradayuci@gmail.com 13 | # @time: 2019/9/7 下午2:53 14 | import torch 15 | from torch import nn 16 | 17 | 18 | def weights_init_kaiming(m): 19 | classname = m.__class__.__name__ 20 | if classname.find('Linear') != -1: 21 | nn.init.kaiming_uniform_(m.weight, mode='fan_out') 22 | nn.init.constant_(m.bias, 0.0) 23 | elif classname.find('Conv') != -1: 24 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in') 25 | if m.bias is not None: 26 | nn.init.constant_(m.bias, 0.0) 27 | elif classname.find('BatchNorm') != -1: 28 | if m.affine: 29 | nn.init.constant_(m.weight, 1.0) 30 | nn.init.constant_(m.bias, 0.0) 31 | 32 | 33 | def weights_init_classifier(m): 34 | classname = m.__class__.__name__ 35 | if classname.find('Linear') != -1: 36 | nn.init.normal_(m.weight, std=0.001) 37 | # if m.bias: 38 | # nn.init.constant_(m.bias, 0.0) 39 | nn.init.constant_(m.bias, 0.0) 40 | 41 | 42 | class Siamese(nn.Module): 43 | 44 | def __init__(self, input_num, output_num, class_num): 45 | super(Siamese, self).__init__() 46 | 47 | self.input_num = input_num 48 | self.output_num = output_num 49 | self.class_num = class_num 50 | self.feat_num = input_num 51 | # linear_Q 52 | self.featQ = nn.Linear(self.input_num, self.output_num) 53 | self.featQ_bn = nn.BatchNorm1d(self.output_num) 54 | self.featQ.apply(weights_init_kaiming) 55 | self.featQ_bn.apply(weights_init_kaiming) 56 | 57 | # linear_K 58 | self.featK = nn.Linear(self.input_num, self.output_num) 59 | self.featK_bn = nn.BatchNorm1d(self.output_num) 60 | self.featK.apply(weights_init_kaiming) 61 | self.featK_bn.apply(weights_init_kaiming) 62 | 63 | # linear_V 64 | self.featV = nn.Linear(self.input_num, self.output_num) 65 | self.featV_bn = nn.BatchNorm1d(self.output_num) 66 | self.featV.apply(weights_init_kaiming) 67 | self.featV_bn.apply(weights_init_kaiming) 68 | 69 | # Softmax 70 | self.softmax = nn.Softmax(dim=-1) 71 | 72 | # BCE classifier 73 | self.classifierBN = nn.BatchNorm1d(self.feat_num) 74 | self.classifierlinear = nn.Linear(self.feat_num, self.class_num) 75 | self.classifierBN.apply(weights_init_kaiming) 76 | self.classifierlinear.apply(weights_init_classifier) 77 | 78 | 79 | def self_attention(self, input): 80 | size = input.size() 81 | batch = size[0] 82 | len = size[1] 83 | 84 | Qs = input.view(batch * len, -1) 85 | Qs = self.featQ(Qs) 86 | Qs = self.featQ_bn(Qs) 87 | Qs = Qs / Qs.norm(2, 1).unsqueeze(1).expand_as(Qs) 88 | Qs = Qs.contiguous().view(batch, len, -1) 89 | 90 | K = input.view(batch*len, -1) 91 | K = self.featK(K) 92 | K = self.featK_bn(K) 93 | K = K / K.norm(2, 1).unsqueeze(1).expand_as(K) 94 | K = K.view(batch, len, -1) 95 | 96 | weights = torch.matmul(Qs, K.transpose(-1, -2)) 97 | weights = self.softmax(weights) 98 | 99 | V = input.view(batch, len, -1) 100 | pool_input = torch.matmul(weights, V) 101 | 102 | pool_input = pool_input.sum(1) 103 | pool_input = pool_input / pool_input.norm(2, 1).unsqueeze(1).expand_as(pool_input) 104 | pool_input = pool_input.squeeze(1) 105 | 106 | return pool_input 107 | 108 | def forward(self, x): 109 | xsize = x.size() 110 | sample_num = xsize[0] 111 | 112 | if sample_num % 2 != 0: 113 | raise RuntimeError("the batch size should be even number!") 114 | 115 | seq_len = x.size()[1] # 8 116 | x = x.view(int(sample_num/2), 2, seq_len, -1) 117 | 118 | 
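# After the view above, the batch is treated as consecutive (probe, gallery)
# pairs: index 0 on dim 1 selects every probe sequence and index 1 its paired
# gallery sequence, which is why the batch size must be even (checked above).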
probe_x = x[:, 0, :, :] 119 | probe_x = probe_x.contiguous() 120 | gallery_x = x[:, 1, :, :] 121 | gallery_x = gallery_x.contiguous() 122 | 123 | pooled_probe = self.self_attention(probe_x) 124 | pooled_gallery = self.self_attention(gallery_x) 125 | 126 | siamese_out = torch.cat((pooled_probe, pooled_gallery)) 127 | probesize = pooled_probe.size() 128 | gallerysize = pooled_gallery.size() 129 | probe_batch = probesize[0] 130 | gallery_batch = gallerysize[0] 131 | 132 | pooled_gallery = pooled_gallery.unsqueeze(0) 133 | pooled_probe = pooled_probe.unsqueeze(1) 134 | 135 | diff = pooled_probe - pooled_gallery 136 | diff = torch.pow(diff, 2) 137 | diff = diff.view(probe_batch * gallery_batch, -1).contiguous() 138 | diff = self.classifierBN(diff) 139 | cls_encode = self.classifierlinear(diff) 140 | cls_encode = cls_encode.view(probe_batch, gallery_batch, -1) 141 | 142 | return cls_encode, siamese_out 143 | -------------------------------------------------------------------------------- /reid/models/Siamese_video.py: -------------------------------------------------------------------------------- 1 | #!usr/bin/env python 2 | # -*- coding:utf-8 _*- 3 | # @author: ycy 4 | # @contact: asuradayuci@gmail.com 5 | # @time: 2019/9/7 下午2:53 6 | import torch 7 | from torch import nn 8 | import torch.nn.functional as F 9 | #!usr/bin/env python 10 | # -*- coding:utf-8 _*- 11 | # @author: ycy 12 | # @contact: asuradayuci@gmail.com 13 | # @time: 2019/9/7 下午2:53 14 | import torch 15 | from torch import nn 16 | 17 | 18 | def weights_init_kaiming(m): 19 | classname = m.__class__.__name__ 20 | if classname.find('Linear') != -1: 21 | nn.init.kaiming_uniform_(m.weight, mode='fan_out') 22 | nn.init.constant_(m.bias, 0.0) 23 | elif classname.find('Conv') != -1: 24 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in') 25 | if m.bias is not None: 26 | nn.init.constant_(m.bias, 0.0) 27 | elif classname.find('BatchNorm') != -1: 28 | if m.affine: 29 | nn.init.constant_(m.weight, 1.0) 30 | nn.init.constant_(m.bias, 0.0) 31 | 32 | 33 | def weights_init_classifier(m): 34 | classname = m.__class__.__name__ 35 | if classname.find('Linear') != -1: 36 | nn.init.normal_(m.weight, std=0.001) 37 | # if m.bias: 38 | # nn.init.constant_(m.bias, 0.0) 39 | nn.init.constant_(m.bias, 0.0) 40 | 41 | 42 | class Siamese_video(nn.Module): 43 | 44 | def __init__(self, input_num=2048, output_num=2048, class_num=2): 45 | super(Siamese_video, self).__init__() 46 | 47 | # self.input_num = 2048 48 | # self.output_num = 512 49 | self.class_num = class_num 50 | self.feat_num = input_num 51 | # linear_Q 52 | # self.featQ = nn.Linear(self.input_num, self.output_num) 53 | # self.featQ_bn = nn.BatchNorm1d(self.output_num) 54 | # self.featQ.apply(weights_init_kaiming) 55 | # self.featQ_bn.apply(weights_init_kaiming) 56 | # 57 | # # linear_K 58 | # self.featK = nn.Linear(self.input_num, self.output_num) 59 | # self.featK_bn = nn.BatchNorm1d(self.output_num) 60 | # self.featK.apply(weights_init_kaiming) 61 | # self.featK_bn.apply(weights_init_kaiming) 62 | # 63 | # # linear_V 64 | # self.featV = nn.Linear(self.input_num, self.output_num) 65 | # self.featV_bn = nn.BatchNorm1d(self.output_num) 66 | # self.featV.apply(weights_init_kaiming) 67 | # self.featV_bn.apply(weights_init_kaiming) 68 | # 69 | # # Softmax 70 | # self.softmax = nn.Softmax(dim=-1) 71 | # 72 | # # numti_head 73 | # self.d_k = 128 74 | # self.head = 4 75 | 76 | # BCE classifier 77 | self.classifierBN = nn.BatchNorm1d(self.feat_num) 78 | self.classifierlinear = nn.Linear(self.feat_num, 
self.class_num) 79 | self.classifierBN.apply(weights_init_kaiming) 80 | self.classifierlinear.apply(weights_init_classifier) 81 | self.muti_head = False 82 | 83 | def self_attention(self, probe_value, probe_base): 84 | pro_size = probe_value.size() # torch.Size([4, 8, 128]) 85 | pro_batch = pro_size[0] 86 | pro_len = pro_size[1] 87 | 88 | Qs = probe_base.view(pro_batch * pro_len, -1) # 32 , 2048 89 | Qs = self.featQ(Qs) 90 | Qs = self.featQ_bn(Qs) # 32, 128 91 | Qs = Qs / Qs.norm(2, 1).unsqueeze(1).expand_as(Qs) # torch.Size([32, 256]) 92 | if self.muti_head: 93 | Qs = Qs.contiguous().view(pro_batch, -1, self.head, self.d_k).transpose(1, 2) # torch.Size([4, 4, 8, 64]) 94 | else: 95 | Qs = Qs.contiguous().view(pro_batch, pro_len, -1) # torch.Size([4, 8, 512]) 96 | 97 | # generating Keys, key 不等于 value 98 | K = probe_base.view(pro_batch*pro_len, -1) 99 | K = self.featK(K) 100 | K = self.featK_bn(K) 101 | K = K / K.norm(2, 1).unsqueeze(1).expand_as(K) 102 | if self.muti_head: 103 | tmp_k = K.view(pro_batch, -1, self.head, self.d_k).transpose(1, 2) # torch.Size([4, 4, 8, 64]) 104 | else: 105 | tmp_k = K.view(pro_batch, pro_len, -1) # torch.Size([4, 8, 512]) 106 | 107 | # 1.single= [4,8, 512] * [4, 512, 8] = 4, 8, 8 108 | weights = torch.matmul(Qs, tmp_k.transpose(-1, -2)) # 2. muti:torch.Size([4, 4, 8, 8]) 109 | 110 | weights = self.softmax(weights) # 4 * 8 * 8 torch.Size([4, 4, 8, 8]) 111 | 112 | if self.muti_head: 113 | V = probe_value.view(pro_batch, -1, self.head, self.d_k).transpose(1, 2) 114 | else: 115 | V = probe_value.view(pro_batch, pro_len, -1) 116 | 117 | pool_probe = torch.matmul(weights, V) # ([4, 8, 8]) * ([4, 8, 512]) = 4 * 8 * 512 torch.Size([4, 4, 8, 64]) 118 | if self.muti_head: 119 | pool_probe = pool_probe.transpose(1, 2).contiguous() # torch.Size([4, 8, 4, 64]) 120 | pool_probe = pool_probe.view(pro_batch, -1, self.head * self.d_k) # torch.Size([4, 8, 512]) 121 | 122 | pool_probe = pool_probe.sum(1) # torch.Size([4, 128]) 123 | # pool_probe = torch.mean(probe_value, dim=1) 124 | pool_probe = pool_probe / pool_probe.norm(2, 1).unsqueeze(1).expand_as(pool_probe) # 单位向量 125 | pool_probe = pool_probe.squeeze(1) 126 | 127 | return pool_probe, pool_probe 128 | 129 | def forward(self, x): 130 | # xsize = x.size() # 12,8,128 131 | # sample_num = xsize[0] # 12 132 | # 133 | # if sample_num % 2 != 0: 134 | # raise RuntimeError("the batch size should be even number!") 135 | # 136 | # seq_len = x.size()[1] # 8 137 | # x = x.view(int(sample_num/2), 2, seq_len, -1) # torch.Size([6, 2, 8, 128]) 138 | # input = input.view(int(sample_num/2), 2, seq_len, -1) # torch.Size([6, 2, 8, 2048]) => raw 139 | # probe_x = x[:, 0, :, :] 140 | # probe_x = probe_x.contiguous() # torch.Size([6, 8, 128]) 141 | # gallery_x = x[:, 1, :, :] 142 | # gallery_x = gallery_x.contiguous() # torch.Size([6, 8, 128]) 143 | # 144 | # probe_input = input[:, 0, :, :] 145 | # probe_input = probe_input.contiguous() # torch.Size([6, 8, 2048]) 146 | # gallery_input = input[:, 1, :, :] 147 | # gallery_input = gallery_input.contiguous() # torch.Size([6, 8, 2048]) 148 | # 149 | # # self-pooling pooled_probe:torch.Size([6, 128]) hidden_probe:torch.Size([6, 128]) 150 | # pooled_probe, probe_out_raw = self.self_attention(probe_x, probe_input) 151 | # # pooled_probe = probe_x.mean(dim=1) 152 | # # probe_out_raw = probe_input.mean(dim=1) 153 | # # 154 | # pooled_gallery, gallery_out_raw = self.self_attention(gallery_x, gallery_input) 155 | # # pooled_gallery = gallery_x.mean(dim=1) 156 | # # gallery_out_raw = 
gallery_input.mean(dim=1) 157 | 158 | batchsize = x.size(0) 159 | 160 | x = x.reshape(int(batchsize/2), 2, -1) 161 | pooled_probe = x[:,0,:] 162 | pooled_gallery = x[:,1,:] 163 | 164 | 165 | siamese_out = torch.cat((pooled_probe, pooled_gallery)) 166 | probesize = pooled_probe.size() # 4, 2048 167 | gallerysize = pooled_gallery.size() # 4, 2048 168 | probe_batch = probesize[0] # 4 169 | gallery_batch = gallerysize[0] # 4 170 | 171 | # pooled_gallery: 4, 4, 2048 172 | pooled_gallery = pooled_gallery.unsqueeze(0) # 1, 4, 2048 173 | 174 | pooled_probe = pooled_probe.unsqueeze(1) # 4, 1, 2048 175 | 176 | diff = pooled_probe - pooled_gallery 177 | diff = torch.pow(diff, 2) # torch.Size([4, 4, 2048]) 178 | diff = diff.view(probe_batch * gallery_batch, -1).contiguous() # torch.Size([16, 2048]) 179 | diff = self.classifierBN(diff) 180 | # diff = diff / diff.norm(2, 1).unsqueeze(1).expand_as(diff) 181 | cls_encode = self.classifierlinear(diff) # torch.Size([16, 2]) 182 | cls_encode = cls_encode.view(probe_batch, gallery_batch, -1) # torch.Size([4, 4, 2]) 183 | 184 | return cls_encode, siamese_out 185 | -------------------------------------------------------------------------------- /reid/models/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from .resnet import * 3 | from .grl_model import * 4 | from .Siamese import Siamese 5 | from .Siamese_video import Siamese_video 6 | 7 | 8 | __factory = { 9 | 'resnet50': resnet50, 10 | 'siamese': Siamese, 11 | 'siamese_video': Siamese_video, 12 | 'resnet50_grl': resnet50_grl, 13 | } 14 | 15 | 16 | def names(): 17 | return sorted(__factory.keys()) 18 | 19 | 20 | def create(name, *args, **kwargs): 21 | """ 22 | Create a model instance. 23 | Parameters 24 | ---------- 25 | name : str 26 | Model name. Can be one of 'inception', 'resnet18', 'resnet34', 27 | 'resnet50', 'resnet101', and 'resnet152'. 28 | pretrained : bool, optional 29 | Only applied for 'resnet*' models. If True, will use ImageNet pretrained 30 | model. Default: True 31 | cut_at_pooling : bool, optional 32 | If True, will cut the model before the last global pooling layer and 33 | ignore the remaining kwargs. Default: False 34 | num_features : int, optional 35 | If positive, will append a Linear layer after the global pooling layer, 36 | with this number of output units, followed by a BatchNorm layer. 37 | Otherwise these layers will not be appended. Default: 256 for 38 | 'inception', 0 for 'resnet*' 39 | norm : bool, optional 40 | If True, will normalize the feature to be unit L2-norm for each sample. 41 | Otherwise will append a ReLU layer after the above Linear layer if 42 | num_features > 0. Default: False 43 | dropout : float, optional 44 | If positive, will append a Dropout layer with this dropout rate. 
45 | Default: 0 46 | """ 47 | if name not in __factory: 48 | raise KeyError("Unknown model:", name) 49 | return __factory[name](*args, **kwargs) 50 | -------------------------------------------------------------------------------- /reid/models/__pycache__/Siamese.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flysnowtiger/GRL/26eed6542b636ab373dbc909d5d92b2d2b0cd7ae/reid/models/__pycache__/Siamese.cpython-36.pyc -------------------------------------------------------------------------------- /reid/models/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flysnowtiger/GRL/26eed6542b636ab373dbc909d5d92b2d2b0cd7ae/reid/models/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /reid/models/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flysnowtiger/GRL/26eed6542b636ab373dbc909d5d92b2d2b0cd7ae/reid/models/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /reid/models/__pycache__/basebranch.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flysnowtiger/GRL/26eed6542b636ab373dbc909d5d92b2d2b0cd7ae/reid/models/__pycache__/basebranch.cpython-36.pyc -------------------------------------------------------------------------------- /reid/models/__pycache__/grl_model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flysnowtiger/GRL/26eed6542b636ab373dbc909d5d92b2d2b0cd7ae/reid/models/__pycache__/grl_model.cpython-36.pyc -------------------------------------------------------------------------------- /reid/models/__pycache__/resnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flysnowtiger/GRL/26eed6542b636ab373dbc909d5d92b2d2b0cd7ae/reid/models/__pycache__/resnet.cpython-36.pyc -------------------------------------------------------------------------------- /reid/models/__pycache__/resnet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flysnowtiger/GRL/26eed6542b636ab373dbc909d5d92b2d2b0cd7ae/reid/models/__pycache__/resnet.cpython-37.pyc -------------------------------------------------------------------------------- /reid/models/__pycache__/resnets1.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flysnowtiger/GRL/26eed6542b636ab373dbc909d5d92b2d2b0cd7ae/reid/models/__pycache__/resnets1.cpython-36.pyc -------------------------------------------------------------------------------- /reid/models/__pycache__/resnets1.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flysnowtiger/GRL/26eed6542b636ab373dbc909d5d92b2d2b0cd7ae/reid/models/__pycache__/resnets1.cpython-37.pyc -------------------------------------------------------------------------------- /reid/models/basebranch.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 
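# Backbone (below) runs ResNet-50 on every frame, mean-pools a clip-level
# global feature (glo_fc), broadcasts it over the 16x8 spatial map, and scores
# each location against that global context with corr_atte; the sigmoid
# corr_map then splits the features into a correlated stream (x * corr_map)
# and an uncorrelated stream (x * (1 - corr_map)).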
2 | # Licensed under the MIT license. 3 | 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | from __future__ import unicode_literals 8 | 9 | import math 10 | import sys 11 | import os 12 | 13 | import torch 14 | import torch as th 15 | from torch import nn 16 | import torch.nn.functional as F 17 | 18 | from .resnets1 import resnet50_s1 19 | 20 | 21 | class Backbone(nn.Module): 22 | def __init__(self, height=256, width=128): 23 | super(Backbone, self).__init__() 24 | # resnet50 25 | resnet2d = resnet50_s1(pretrained=True) 26 | 27 | self.base = nn.Sequential( 28 | resnet2d.conv1, 29 | resnet2d.bn1, 30 | nn.ReLU(), 31 | resnet2d.maxpool, 32 | resnet2d.layer1, 33 | resnet2d.layer2, 34 | resnet2d.layer3, 35 | resnet2d.layer4, 36 | ) 37 | 38 | self.glo_fc = nn.Sequential(nn.Linear(2048, 1024), 39 | nn.BatchNorm1d(1024), 40 | nn.ReLU()) 41 | 42 | self.corr_atte = nn.Sequential( 43 | nn.Conv2d(2048 + 1024, 1024, 1, 1, bias=False), 44 | nn.BatchNorm2d(1024), 45 | nn.Conv2d(1024, 256, 1, 1, bias=False), 46 | nn.BatchNorm2d(256), 47 | nn.ReLU(), 48 | nn.Conv2d(256, 1, 1, 1, bias=False), 49 | nn.BatchNorm2d(1), 50 | ) 51 | 52 | def forward(self, x, b, t): 53 | 54 | x = self.base(x) 55 | 56 | x_4 = x.view(b, t, x.size(1), x.size(2), x.size(3)) 57 | 58 | x_glo = x_4.mean(dim=-1).mean(dim=-1).mean(dim=1) 59 | glo = self.glo_fc(x_glo).view(b,1, 1024, 1, 1).contiguous().expand(b,t, 1024, 16, 8).contiguous().view(b*t,1024, 16,8) 60 | 61 | x_corr = torch.cat((x, glo), dim=1) 62 | corr_map = self.corr_atte(x_corr) 63 | corr_map = F.sigmoid(corr_map).view(b * t, 1, 16, 8).contiguous() 64 | 65 | x_corr = x * corr_map 66 | x_uncorr = x*(1-corr_map) 67 | 68 | return x_uncorr, x_corr, corr_map 69 | -------------------------------------------------------------------------------- /reid/models/grl_model.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | from __future__ import unicode_literals 6 | 7 | import os 8 | import math 9 | import torch 10 | from torch import nn 11 | from torch.nn import functional as F 12 | from torch.nn import init 13 | from torch.autograd import Variable 14 | 15 | import torchvision 16 | import numpy as np 17 | 18 | from .basebranch import Backbone 19 | 20 | __all__ = ['resnet50_grl'] 21 | 22 | 23 | # =================== 24 | # Initialization 25 | # =================== 26 | 27 | def weights_init_kaiming(m): 28 | classname = m.__class__.__name__ 29 | if classname.find('Linear') != -1: 30 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out') 31 | nn.init.constant_(m.bias, 0.0) 32 | elif classname.find('Conv') != -1: 33 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in') 34 | if m.bias is not None: 35 | nn.init.constant_(m.bias, 0.0) 36 | elif classname.find('BatchNorm') != -1: 37 | if m.affine: 38 | nn.init.constant_(m.weight, 1.0) 39 | nn.init.constant_(m.bias, 0.0) 40 | 41 | 42 | def weights_init_classifier(m): 43 | classname = m.__class__.__name__ 44 | if classname.find('Linear') != -1: 45 | nn.init.normal_(m.weight, std=0.001) 46 | if m.bias: 47 | nn.init.constant_(m.bias, 0.0) 48 | 49 | 50 | 51 | class BasicBlock(nn.Module): 52 | 53 | def __init__(self, inplanes, planes, stride=1, downsample=None): 54 | super(BasicBlock, self).__init__() 55 | 56 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 57 | self.bn1 = 
nn.BatchNorm2d(planes) 58 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=1, 59 | # stride=stride, 60 | # padding=1, 61 | bias=False) 62 | self.bn2 = nn.BatchNorm2d(planes) 63 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 64 | self.bn3 = nn.BatchNorm2d(planes * 4) 65 | self.relu = nn.ReLU() 66 | 67 | def forward(self, x1, x2): 68 | x = x1 + x2 69 | residual = x 70 | 71 | out = self.conv1(x) 72 | out = self.bn1(out) 73 | out = self.relu(out) 74 | 75 | out = self.conv2(out) 76 | out = self.bn2(out) 77 | out = self.relu(out) 78 | 79 | out = self.conv3(out) 80 | out = self.bn3(out) 81 | 82 | out += residual 83 | out = self.relu(out) 84 | 85 | return out 86 | 87 | class TRLBlock(nn.Module): 88 | def __init__(self, feat_num): 89 | super(TRLBlock, self).__init__() 90 | self.feat_num = feat_num 91 | self.feat_num_half = int(feat_num / 2) 92 | 93 | self.uncorr_memo_forward = BasicBlock(2048, 512) 94 | 95 | self.forward_f1 = nn.Sequential(nn.Conv2d(2048, 2048, 1, 1), 96 | nn.ReLU(), 97 | ) 98 | 99 | self.forward_f2 = nn.Sequential(nn.Conv2d(2048, 2048, 1, 1), 100 | nn.ReLU(), 101 | ) 102 | 103 | self.channel_atte_foreward_corr = nn.Sequential( 104 | nn.Linear(2048, 2048 // 16, bias=False), 105 | nn.ReLU(inplace=True), 106 | nn.Linear(2048// 16, 2048, bias=False), 107 | nn.Sigmoid(), 108 | ) 109 | 110 | 111 | ####################################################3 112 | 113 | self.uncorr_memo_backward = BasicBlock(2048, 512) 114 | 115 | self.backward_f1 = nn.Sequential(nn.Conv2d(2048, 2048, 1, 1), 116 | nn.ReLU(), 117 | ) 118 | 119 | self.backward_f2 = nn.Sequential(nn.Conv2d(2048, 2048, 1, 1), 120 | nn.ReLU(), 121 | ) 122 | 123 | self.channel_atte_backward_corr = nn.Sequential( 124 | nn.Linear(2048, 2048 // 16, bias=False), 125 | nn.ReLU(inplace=True), 126 | nn.Linear(2048// 16 , 2048, bias=False), 127 | nn.Sigmoid(), 128 | ) 129 | 130 | 131 | def forward(self, x_uncorr, x_corr): 132 | b, t, c, h, w = x_corr.size() 133 | 134 | f_step_forward = [] 135 | f_step_backward = [] 136 | 137 | x_uncorr_memo_forward = x_uncorr.mean(dim=1) ## b*c*h*w 138 | x_uncorr_memo_backward = x_uncorr.mean(dim=1) 139 | 140 | 141 | for i in range(0,t,1): 142 | x_corr_forward = x_corr[:,i,:,:,:] 143 | x_uncorr_forward = x_uncorr[:,i,:,:,:] 144 | 145 | 146 | f11 = self.forward_f1(x_uncorr_memo_forward) 147 | f21 = self.forward_f2( x_corr_forward )# 148 | 149 | c_atte = self.channel_atte_foreward_corr((f11 - f21).pow(2).mean(dim=-1).mean(dim=-1)) 150 | x_temp = x_corr_forward * c_atte.view(b, c, 1, 1).contiguous().expand(b, c, h, w) + x_corr_forward 151 | f_step_forward.append(x_temp.mean(dim=-1).mean(dim=-1)) 152 | 153 | x_uncorr_memo_forward = self.uncorr_memo_forward(x_uncorr_memo_forward, x_uncorr_forward) 154 | 155 | ######################### 156 | 157 | x_corr_backward = x_corr[:, t-1-i, :, :, :] 158 | x_uncorr_backward = x_uncorr[:, t-1-i, :, :, :] 159 | 160 | f12 = self.backward_f1( x_uncorr_memo_backward ) 161 | f22 = self.backward_f2( x_corr_backward )# 162 | 163 | c_atte = self.channel_atte_backward_corr((f12 - f22).pow(2).mean(dim=-1).mean(dim=-1)) 164 | x_temp = x_corr_backward * c_atte.view(b, c, 1, 1).contiguous().expand(b, c, h, w) + x_corr_backward 165 | f_step_backward.append(x_temp.mean(dim=-1).mean(dim=-1)) 166 | 167 | x_uncorr_memo_backward = self.uncorr_memo_backward(x_uncorr_memo_backward, x_uncorr_backward) 168 | 169 | 170 | temp = [] 171 | for i in range(t): 172 | temp.append(f_step_backward[t-1-i]) 173 | f_step_backward = torch.stack(temp, dim=1) 174 | f_step_forward = 
torch.stack(f_step_forward, dim=1) 175 | 176 | f_corr = f_step_forward + f_step_backward 177 | 178 | f_uncorr = x_uncorr_memo_forward.mean(dim=-1).mean(dim=-1) + x_uncorr_memo_backward.mean(dim=-1).mean(dim=-1) 179 | 180 | return f_uncorr, f_corr 181 | 182 | 183 | 184 | class ResNet50_GRL_Model(nn.Module): 185 | ''' 186 | Backbone: ResNet-50 + GRL modules. 187 | ''' 188 | 189 | def __init__(self, num_feat=2048, num_features=512, height=256, width=128, pretrained=True, 190 | dropout=0, numclasses=0): 191 | super(ResNet50_GRL_Model, self).__init__() 192 | self.pretrained = pretrained 193 | self.num_feat = num_feat # resnet output 194 | self.dropout = dropout 195 | self.num_classes = numclasses 196 | self.output_dim = num_features # bnneck 197 | print('Num of features: {}.'.format(self.num_feat)) 198 | 199 | self.backbone = Backbone(height=height, width=width) 200 | 201 | self.temporal_learning_block = TRLBlock(2048) 202 | # # 203 | self.corr_bn = nn.BatchNorm1d(2048) 204 | init.constant_(self.corr_bn.weight, 1) 205 | init.constant_(self.corr_bn.bias, 0) 206 | 207 | self.uncorr_bn = nn.BatchNorm1d(2048) 208 | init.constant_(self.uncorr_bn.weight, 1) 209 | init.constant_(self.uncorr_bn.bias, 0) 210 | 211 | def forward(self, inputs, training=True): 212 | b, t, c, h, w = inputs.size() 213 | im_input = inputs.view(b * t, c, h, w) # 80, 3, 256, 128 214 | x_uncorr, x_corr,corr_map = self.backbone(im_input, b, t) # b*t,2048,16,8 215 | 216 | ########################### 217 | x_corr = x_corr.view(b, t, x_corr.size(1), x_corr.size(2), x_corr.size(3)) 218 | x_uncorr = x_uncorr.view(b, t, x_uncorr.size(1), x_uncorr.size(2), x_uncorr.size(3)) 219 | 220 | x_uncorr, x_corr = self.temporal_learning_block(x_uncorr, x_corr) # 221 | 222 | x_corr = self.corr_bn(x_corr.view(b * t, 2048)).view(b, t, 2048) 223 | x_corr = F.normalize(x_corr, p=2, dim=2) 224 | 225 | x_uncorr = self.uncorr_bn(x_uncorr.view(b, 2048)).view(b, 2048) 226 | x_uncorr = F.normalize(x_uncorr, p=2, dim=1) 227 | 228 | return x_uncorr, x_corr 229 | 230 | 231 | def resnet50_grl(*args, **kwargs): 232 | return ResNet50_GRL_Model(*args, **kwargs) 233 | -------------------------------------------------------------------------------- /reid/models/resnet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import torch 3 | import torch.nn.functional as F 4 | import torch.nn.init as init 5 | from torch import nn 6 | import torchvision 7 | 8 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 9 | 'resnet152'] 10 | 11 | 12 | class ResNet(nn.Module): 13 | __factory = { 14 | 18: torchvision.models.resnet18, 15 | 34: torchvision.models.resnet34, 16 | 50: torchvision.models.resnet50, 17 | 101: torchvision.models.resnet101, 18 | 152: torchvision.models.resnet152, 19 | } 20 | 21 | def __init__(self, depth, pretrained=True, cut_at_pooling=False, 22 | num_features=0, dropout=0, numclasses=0): 23 | super(ResNet, self).__init__() 24 | 25 | self.depth = depth 26 | self.pretrained = pretrained 27 | self.cut_at_pooling = cut_at_pooling 28 | 29 | # Construct base (pretrain) resnet 30 | if depth not in ResNet.__factory: 31 | raise KeyError("Unsupported depth:", depth) 32 | self.base = ResNet.__factory[depth](pretrained=pretrained) 33 | self.base.layer4[0].conv2.stride = (1, 1) 34 | self.base.layer4[0].downsample[0].stride = (1, 1) 35 | 36 | self.classifier = nn.Linear(self.base.fc.in_features, numclasses) # 2048, C 37 | init.kaiming_uniform_(self.classifier.weight, 
mode='fan_out') 38 | init.constant_(self.classifier.bias, 0) 39 | if not self.cut_at_pooling: 40 | self.num_features = num_features 41 | self.dropout = dropout 42 | self.has_embedding = num_features > 0 43 | 44 | out_planes = self.base.fc.in_features 45 | self.feat_bn2 = nn.BatchNorm1d(out_planes) 46 | init.constant_(self.feat_bn2.weight, 1) 47 | init.constant_(self.feat_bn2.bias, 0) 48 | # Append new layers 49 | if self.has_embedding: 50 | self.feat = nn.Linear(out_planes, self.num_features) 51 | self.feat_bn = nn.BatchNorm1d(self.num_features) 52 | init.kaiming_uniform_(self.feat.weight, mode='fan_out') 53 | init.constant_(self.feat.bias, 0) 54 | init.constant_(self.feat_bn.weight, 1) 55 | init.constant_(self.feat_bn.bias, 0) 56 | else: 57 | self.num_features = out_planes 58 | 59 | if self.dropout > 0: 60 | self.drop = nn.Dropout(self.dropout) 61 | 62 | if not self.pretrained: 63 | self.reset_params() 64 | 65 | def forward(self, imgs): 66 | # todo: change the base model 67 | img_size = imgs.size() 68 | # motion_size = motions.size() 69 | batch_sz = img_size[0] 70 | seq_len = img_size[1] 71 | imgs = imgs.view(-1, img_size[2], img_size[3], img_size[4]) 72 | 73 | for name, module in self.base._modules.items(): 74 | 75 | if name == 'conv1': 76 | # x = module(imgs) + self.conv0(motions) 77 | x = module(imgs) 78 | continue 79 | if name == 'avgpool': 80 | break 81 | x = module(x) 82 | 83 | x = F.avg_pool2d(x, x.size()[2:]) # torch.Size([64, 2048, 1, 1]) 84 | x = x.view(x.size(0), -1) # torch.Size([64, 2048]) 85 | raw = self.feat_bn2(x) 86 | raw = raw / raw.norm(2, 1).unsqueeze(1).expand_as(raw) 87 | raw = raw.squeeze(1) 88 | raw = raw.view(batch_sz, seq_len, -1) # torch.Size([8, 8, 2048]) 89 | 90 | x = self.feat(x) # 64,128 91 | x = self.feat_bn(x) 92 | 93 | x = x / x.norm(2, 1).unsqueeze(1).expand_as(x) 94 | x = x.squeeze(1) 95 | x = x.view(batch_sz, seq_len, -1) # 8,8,128 96 | return x, raw 97 | 98 | def reset_params(self): 99 | for m in self.modules(): 100 | if isinstance(m, nn.Conv2d): 101 | init.kaiming_uniform_(m.weight, mode='fan_out') 102 | if m.bias is not None: 103 | init.constant(m.bias, 0) 104 | elif isinstance(m, nn.BatchNorm2d): 105 | init.constant(m.weight, 1) 106 | init.constant(m.bias, 0) 107 | elif isinstance(m, nn.Linear): 108 | init.normal(m.weight, std=0.001) 109 | if m.bias is not None: 110 | init.constant(m.bias, 0) 111 | 112 | def guiyihua(self, x): 113 | x_min = x.min() 114 | x_max = x.max() 115 | x_1 = (x - x_min) / (x_max - x_min) 116 | return x_1 117 | 118 | 119 | def resnet18(**kwargs): 120 | return ResNet(18, **kwargs) 121 | 122 | 123 | def resnet34(**kwargs): 124 | return ResNet(34, **kwargs) 125 | 126 | 127 | def resnet50(**kwargs): 128 | return ResNet(50, **kwargs) 129 | 130 | 131 | def resnet101(**kwargs): 132 | return ResNet(101, **kwargs) 133 | 134 | 135 | def resnet152(**kwargs): 136 | return ResNet(152, **kwargs) 137 | -------------------------------------------------------------------------------- /reid/models/resnets1.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | import torch.utils.model_zoo as model_zoo 4 | 5 | 6 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50_s1', 'resnet101', 7 | 'resnet152'] 8 | 9 | 10 | model_urls = { 11 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 12 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 13 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 14 | 
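# Official torchvision checkpoint URLs; resnet50_s1() below fetches the
# 'resnet50' entry via model_zoo.load_url when pretrained=True.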
'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 15 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 16 | } 17 | 18 | 19 | def conv3x3(in_planes, out_planes, stride=1): 20 | """3x3 convolution with padding""" 21 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 22 | padding=1, bias=False) 23 | 24 | 25 | class BasicBlock(nn.Module): 26 | expansion = 1 27 | 28 | def __init__(self, inplanes, planes, stride=1, downsample=None): 29 | super(BasicBlock, self).__init__() 30 | self.conv1 = conv3x3(inplanes, planes, stride) 31 | self.bn1 = nn.BatchNorm2d(planes) 32 | self.relu = nn.ReLU(inplace=True) 33 | self.conv2 = conv3x3(planes, planes) 34 | self.bn2 = nn.BatchNorm2d(planes) 35 | self.downsample = downsample 36 | self.stride = stride 37 | 38 | def forward(self, x): 39 | residual = x 40 | 41 | out = self.conv1(x) 42 | out = self.bn1(out) 43 | out = self.relu(out) 44 | 45 | out = self.conv2(out) 46 | out = self.bn2(out) 47 | 48 | if self.downsample is not None: 49 | residual = self.downsample(x) 50 | 51 | out += residual 52 | out = self.relu(out) 53 | 54 | return out 55 | 56 | 57 | class Bottleneck(nn.Module): 58 | expansion = 4 59 | 60 | def __init__(self, inplanes, planes, stride=1, downsample=None): 61 | super(Bottleneck, self).__init__() 62 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 63 | self.bn1 = nn.BatchNorm2d(planes) 64 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 65 | padding=1, bias=False) 66 | self.bn2 = nn.BatchNorm2d(planes) 67 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 68 | self.bn3 = nn.BatchNorm2d(planes * 4) 69 | self.relu = nn.ReLU(inplace=True) 70 | self.downsample = downsample 71 | self.stride = stride 72 | 73 | def forward(self, x): 74 | residual = x 75 | 76 | out = self.conv1(x) 77 | out = self.bn1(out) 78 | out = self.relu(out) 79 | 80 | out = self.conv2(out) 81 | out = self.bn2(out) 82 | out = self.relu(out) 83 | 84 | out = self.conv3(out) 85 | out = self.bn3(out) 86 | 87 | if self.downsample is not None: 88 | residual = self.downsample(x) 89 | 90 | out += residual 91 | out = self.relu(out) 92 | 93 | return out 94 | 95 | 96 | class ResNet(nn.Module): 97 | 98 | def __init__(self, block, layers, num_classes=1000): 99 | self.inplanes = 64 100 | super(ResNet, self).__init__() 101 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 102 | self.bn1 = nn.BatchNorm2d(64) 103 | self.relu = nn.ReLU(inplace=True) 104 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 105 | 106 | self.layer1 = self._make_layer(block, 64, layers[0]) 107 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 108 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 109 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1) 110 | self.avgpool = nn.AvgPool2d(7, stride=1) 111 | self.fc = nn.Linear(512 * block.expansion, num_classes) 112 | 113 | for m in self.modules(): 114 | if isinstance(m, nn.Conv2d): 115 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 116 | m.weight.data.normal_(0, math.sqrt(2. 
/ n))
117 |             elif isinstance(m, nn.BatchNorm2d):
118 |                 m.weight.data.fill_(1)
119 |                 m.bias.data.zero_()
120 | 
121 |     def _make_layer(self, block, planes, blocks, stride=1):
122 |         downsample = None
123 |         if stride != 1 or self.inplanes != planes * block.expansion:
124 |             downsample = nn.Sequential(
125 |                 nn.Conv2d(self.inplanes, planes * block.expansion,
126 |                           kernel_size=1, stride=stride, bias=False),
127 |                 nn.BatchNorm2d(planes * block.expansion),
128 |             )
129 | 
130 |         layers = []
131 |         layers.append(block(self.inplanes, planes, stride, downsample))
132 |         self.inplanes = planes * block.expansion
133 |         for i in range(1, blocks):
134 |             layers.append(block(self.inplanes, planes))
135 | 
136 |         return nn.Sequential(*layers)
137 | 
138 |     def forward(self, x):
139 |         x = self.conv1(x)
140 |         x = self.bn1(x)
141 |         x = self.relu(x)
142 |         x = self.maxpool(x)
143 | 
144 |         x = self.layer1(x)
145 |         x = self.layer2(x)
146 |         x = self.layer3(x)
147 |         x = self.layer4(x)
148 | 
149 |         x = self.avgpool(x)
150 |         x = x.view(x.size(0), -1)
151 |         x = self.fc(x)
152 | 
153 |         return x
154 | 
155 | 
156 | def resnet18(pretrained=False, **kwargs):
157 |     """Constructs a ResNet-18 model.
158 | 
159 |     Args:
160 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
161 |     """
162 |     model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
163 |     if pretrained:
164 |         model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
165 |     return model
166 | 
167 | 
168 | def resnet34(pretrained=False, **kwargs):
169 |     """Constructs a ResNet-34 model.
170 | 
171 |     Args:
172 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
173 |     """
174 |     model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
175 |     if pretrained:
176 |         model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
177 |     return model
178 | 
179 | 
180 | def resnet50_s1(pretrained=True, **kwargs):
181 |     """Constructs a ResNet-50 model whose last stage uses stride 1 (the '_s1' variant).
182 | 
183 |     Args:
184 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
185 |     """
186 |     model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
187 |     if pretrained:
188 |         model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
189 |     return model
190 | 
191 | if __name__ == '__main__':  # quick smoke test; resnet101/resnet152 below are defined only after this block runs
192 |     net = resnet50_s1()
193 |     layer4 = net.layer4
194 | 
195 |     layer4 = nn.Sequential(layer4[0], layer4[1])
196 |     print(layer4)
197 | 
198 | 
199 | def resnet101(pretrained=False, **kwargs):
200 |     """Constructs a ResNet-101 model.
201 | 
202 |     Args:
203 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
204 |     """
205 |     model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
206 |     if pretrained:
207 |         model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
208 |     return model
209 | 
210 | 
211 | def resnet152(pretrained=False, **kwargs):
212 |     """Constructs a ResNet-152 model.
213 | 214 | Args: 215 | pretrained (bool): If True, returns a model pre-trained on ImageNet 216 | """ 217 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 218 | if pretrained: 219 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 220 | return model 221 | -------------------------------------------------------------------------------- /reid/train/__init__.py: -------------------------------------------------------------------------------- 1 | from .trainer import SEQTrainer -------------------------------------------------------------------------------- /reid/train/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flysnowtiger/GRL/26eed6542b636ab373dbc909d5d92b2d2b0cd7ae/reid/train/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /reid/train/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flysnowtiger/GRL/26eed6542b636ab373dbc909d5d92b2d2b0cd7ae/reid/train/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /reid/train/__pycache__/trainer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flysnowtiger/GRL/26eed6542b636ab373dbc909d5d92b2d2b0cd7ae/reid/train/__pycache__/trainer.cpython-36.pyc -------------------------------------------------------------------------------- /reid/train/trainer.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import time 3 | import torch 4 | from reid.evaluator import accuracy 5 | from utils.meters import AverageMeter 6 | import torch.nn.functional as F 7 | from tensorboardX import SummaryWriter 8 | from visualize import reverse_normalize 9 | from cam_functions import visual_batch 10 | 11 | from reid.loss import TripletLoss, TripletLoss_OIM 12 | criterion_triplet_oim = TripletLoss_OIM('soft', True) 13 | criterion_triplet = TripletLoss('soft', True) 14 | 15 | 16 | class BaseTrainer(object): 17 | 18 | def __init__(self, model, criterion): 19 | super(BaseTrainer, self).__init__() 20 | self.model = model 21 | self.criterion_ver = criterion 22 | self.criterion_ver_uncorr = criterion 23 | self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 24 | 25 | def train(self, epoch, data_loader, optimizer1): 26 | self.model.train() 27 | 28 | batch_time = AverageMeter() 29 | data_time = AverageMeter() 30 | losses = AverageMeter() 31 | 32 | precisions = AverageMeter() 33 | precisions1 = AverageMeter() 34 | precisions2 = AverageMeter() 35 | 36 | 37 | end = time.time() 38 | for i, inputs in enumerate(data_loader): 39 | data_time.update(time.time() - end) 40 | 41 | inputs, targets = self._parse_data(inputs) 42 | 43 | all_loss, uncorr_prec_id_vid, corr_prec_id_vid, corr_prec_id_frame = self._forward(inputs, targets, i, epoch) 44 | loss = all_loss 45 | 46 | losses.update(loss.item(), targets.size(0)) 47 | 48 | precisions.update(uncorr_prec_id_vid, targets.size(0)) 49 | precisions1.update(corr_prec_id_vid, targets.size(0)) 50 | precisions2.update(corr_prec_id_frame, targets.size(0)) 51 | 52 | 53 | optimizer1.zero_grad() 54 | loss.backward() 55 | optimizer1.step() 56 | 57 | batch_time.update(time.time() - end) 58 | end = 
time.time()
59 |             print_freq = 100
60 |             num_step = len(data_loader)
61 |             num_iter = num_step * epoch + i
62 | 
63 |             self.writer.add_scalar('train/total_loss_step', losses.val, num_iter)
64 | 
65 |             self.writer.add_scalar('train/total_loss_avg', losses.avg, num_iter)
66 | 
67 | 
68 |             if (i + 1) % print_freq == 0:
69 |                 print('Epoch: [{}][{}/{}]\t'
70 |                       'Loss {:.3f} ({:.3f})\t'
71 |                       'uncorr_vid {:.2%} ({:.2%})\t'
72 |                       'corr_vid {:.2%} ({:.2%})\t'
73 |                       'corr_frame {:.2%} ({:.2%})\t'
74 |                       .format(epoch, i + 1, len(data_loader), losses.val, losses.avg,
75 |                               precisions.val, precisions.avg,
76 |                               precisions1.val, precisions1.avg,
77 |                               precisions2.val, precisions2.avg
78 |                               ))
79 | 
80 |     def _parse_data(self, inputs):
81 |         raise NotImplementedError
82 | 
83 |     def _forward(self, inputs, targets, i, epoch):
84 |         raise NotImplementedError
85 | 
86 | 
87 | class SEQTrainer(BaseTrainer):
88 | 
89 |     def __init__(self, cnn_model, siamese_model, siamese_model_uncorr, criterion_veri, criterion_corr, criterion_uncorr, logdir):
90 |         super(SEQTrainer, self).__init__(cnn_model, criterion_veri)
91 |         self.siamese_model = siamese_model
92 |         self.siamese_model_uncorr = siamese_model_uncorr
93 | 
94 |         self.criterion_uncorr = criterion_uncorr
95 |         self.criterion_corr = criterion_corr
96 | 
97 |         self.writer = SummaryWriter(log_dir=logdir)
98 | 
99 |     def _parse_data(self, inputs):
100 |         imgs, pids, _ = inputs
101 |         imgs = imgs.to(self.device)
102 |         inputs = [imgs]
103 | 
104 |         targets = pids.to(self.device)
105 |         return inputs, targets
106 | 
107 |     def _forward(self, inputs, targets, i, epoch):
108 |         batch_size = inputs[0].size(0)
109 |         seq_len = inputs[0].size(1)
110 | 
111 |         x_uncorr, x_corr = self.model(inputs[0])
112 | 
113 |         # uncorr_id_loss_vid, output_id = self.criterion_uncorr(x_uncorr, targets)
114 |         # uncorr_prec_id_vid, = accuracy(output_id.data, targets.data)
115 | 
116 |         # expand the video-level labels to frame level for the ID loss
117 |         frame_corr = x_corr.view(batch_size * seq_len, -1)
118 | 
119 |         targetX = targets.unsqueeze(1)
120 |         targetX = targetX.expand(batch_size, seq_len)
121 |         targetX = targetX.contiguous()
122 |         targetX = targetX.view(batch_size * seq_len, -1)  # (batch_size * seq_len, 1)
123 |         targetX = targetX.squeeze(1)
124 | 
125 |         # frame-level identification loss
126 |         corr_id_loss_frame, output_id = self.criterion_corr(frame_corr, targetX)
127 |         corr_prec_id_frame, = accuracy(output_id.data, targetX.data)
128 | 
129 |         # verification labels: each consecutive pair in the batch is (probe, gallery)
130 |         targets = targets.data
131 |         targets = targets.view(int(batch_size / 2), -1)
132 |         tar_probe = targets[:, 0]
133 |         tar_gallery = targets[:, 1]
134 | 
135 |         target = torch.cat((tar_probe, tar_gallery))
136 | 
137 |         encode_scores, siamese_out = self.siamese_model(x_corr)
138 |         corr_id_loss_vid, output_id = self.criterion_corr(siamese_out, target)
139 |         corr_prec_id_vid, = accuracy(output_id.data, target.data)
140 | 
141 |         corr_loss_tri = criterion_triplet(siamese_out, target).mean()
142 | 
143 |         # verification loss for pair-wise video features
144 |         encode_size = encode_scores.size()
145 |         encodemat = encode_scores.view(-1, 2)
146 |         encodemat = F.softmax(encodemat, dim=-1)
147 |         encodemat = encodemat.view(encode_size[0], encode_size[1], 2)
148 |         encodemat0 = encodemat[:, :, 1]
149 |         corr_loss_ver, corr_prec_ver = self.criterion_ver(encodemat0, tar_probe, tar_gallery)
150 | 
151 |         encode_scores, siamese_out = self.siamese_model_uncorr(x_uncorr)
152 |         uncorr_id_loss_vid, output_id = self.criterion_uncorr(siamese_out, target)
153 |         uncorr_prec_id_vid, = accuracy(output_id.data, target.data)
154 | 
155 |         # uncorr_loss_tri = criterion_triplet(siamese_out,
target).mean()
156 | 
157 |         encode_size = encode_scores.size()
158 |         encodemat = encode_scores.view(-1, 2)
159 |         encodemat = F.softmax(encodemat, dim=-1)
160 |         encodemat = encodemat.view(encode_size[0], encode_size[1], 2)
161 |         encodemat0 = encodemat[:, :, 1]
162 |         uncorr_loss_ver, uncorr_prec_ver = self.criterion_ver_uncorr(encodemat0, tar_probe, tar_gallery)
163 | 
164 | 
165 |         corr_loss = corr_id_loss_frame + corr_id_loss_vid + corr_loss_ver * 20 + corr_loss_tri  # verification term up-weighted by a fixed factor of 20
166 |         uncorr_loss = uncorr_id_loss_vid  # + corr_loss_ver * 10 (disabled)
167 | 
168 |         all_loss = uncorr_loss + corr_loss
169 | 
170 |         return all_loss, uncorr_prec_id_vid, corr_prec_id_vid, corr_prec_id_frame
171 | 
172 |     def train(self, epoch, data_loader, optimizer1):
173 |         self.siamese_model.train()
174 |         self.siamese_model_uncorr.train()
175 | 
176 |         super(SEQTrainer, self).train(epoch, data_loader, optimizer1)
177 | 
178 | 
--------------------------------------------------------------------------------
/test_all.py:
--------------------------------------------------------------------------------
1 | # system tool
2 | from __future__ import print_function, absolute_import
3 | import argparse
4 | import os
5 | import os.path as osp
6 | import sys
7 | 
8 | # computation tool
9 | import torch
10 | import numpy as np
11 | 
12 | # device tool
13 | import torch.backends.cudnn as cudnn
14 | # from tensorboardX import SummaryWriter
15 | # import adabound
16 | # utils
17 | from utils.logging import Logger
18 | from reid import models
19 | from utils.serialization import load_checkpoint, save_cnn_checkpoint, save_siamese_checkpoint
20 | from utils.serialization import remove_repeat_tensorboard_files
21 | from reid.loss import PairLoss, OIMLoss
22 | from reid.data import get_data
23 | from reid.train import SEQTrainer
24 | from reid.evaluator import ATTEvaluator
25 | 
26 | 
27 | def save_checkpoint(cnn_model, siamese_model, epoch, best_top1, is_best):
28 |     save_cnn_checkpoint({
29 |         'state_dict': cnn_model.state_dict(),
30 |         'epoch': epoch + 1,
31 |         'best_top1': best_top1,
32 |     }, is_best, fpath=osp.join(args.logs_dir, 'cnn_checkpoint.pth.tar'))
33 | 
34 |     save_siamese_checkpoint({
35 |         'state_dict': siamese_model.state_dict(),
36 |         'epoch': epoch + 1,
37 |         'best_top1': best_top1,
38 |     }, is_best, fpath=osp.join(args.logs_dir, 'siamese_checkpoint.pth.tar'))
39 | 
40 | 
41 | def load_best_checkpoint(cnn_model, siamese_model):
42 |     checkpoint0 = load_checkpoint(osp.join(args.logs_dir, 'cnnmodel_best.pth.tar'))
43 |     cnn_model.load_state_dict(checkpoint0['state_dict'])
44 | 
45 |     checkpoint1 = load_checkpoint(osp.join(args.logs_dir, 'siamesemodel_best.pth.tar'))
46 |     siamese_model.load_state_dict(checkpoint1['state_dict'])
47 | 
48 | 
49 | def main(args):
50 |     np.random.seed(args.seed)
51 |     torch.manual_seed(args.seed)
52 |     torch.cuda.manual_seed_all(args.seed)
53 |     cudnn.benchmark = True
54 |     # os.environ['CUDA_VISIBLE_DEVICES'] = '1'
55 |     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
56 | 
57 |     # log file: pick an unused filename so earlier logs are not overwritten
58 |     run = 0
59 |     if args.evaluate == 1:
60 |         while osp.exists("%s" % (osp.join(args.logs_dir, 'log_testall{}.txt'.format(run)))):
61 |             run += 1
62 | 
63 |         sys.stdout = Logger(osp.join(args.logs_dir, 'log_testall{}.txt'.format(run)))
64 |     else:
65 |         while osp.exists("%s" % (osp.join(args.logs_dir, 'log_train{}.txt'.format(run)))):
66 |             run += 1
67 | 
68 |         sys.stdout = Logger(osp.join(args.logs_dir, 'log_train{}.txt'.format(run)))
69 |     print("==========\nArgs:{}\n==========".format(args))
70 | 
71 |     dataset, num_classes, train_loader,
query_loader, gallery_loader = \
72 |         get_data(args.dataset, args.split, args.data_dir,
73 |                  args.batch_size, args.seq_len, args.seq_srd,
74 |                  args.workers, only_eval=True)
75 | 
76 |     cnn_model = models.create(args.arch1, num_features=args.features, dropout=args.dropout, numclasses=num_classes)
77 | 
78 |     # create Siamese model
79 |     siamese_model = models.create(args.arch2, input_num=args.features, output_num=512, class_num=2)
80 | 
81 |     cnn_model = torch.nn.DataParallel(cnn_model).to(device)
82 |     siamese_model = siamese_model.to(device)
83 | 
84 |     tensorboard_train_logdir = osp.join(args.logs_dir, 'train_log')
85 |     remove_repeat_tensorboard_files(tensorboard_train_logdir)
86 |     # Evaluator (test phase)
87 | 
88 |     evaluator = ATTEvaluator(cnn_model, siamese_model, only_eval=True)
89 | 
90 |     load_best_checkpoint(cnn_model, siamese_model)
91 |     top1 = evaluator.evaluate(dataset.query, dataset.gallery, query_loader, gallery_loader, args.logs_dir1, args.visul,
92 |                               args.rerank)
93 | 
94 | 
95 | if __name__ == '__main__':
96 |     parser = argparse.ArgumentParser(description="ID ResNet model evaluation")
97 | 
98 |     # DATA
99 |     parser.add_argument('-d', '--dataset', type=str, default='mars',
100 |                         choices=['ilidsvidsequence', 'prid2011sequence', 'mars'])
101 |     parser.add_argument('-b', '--batch-size', type=int, default=32)
102 | 
103 |     parser.add_argument('-j', '--workers', type=int, default=8)
104 | 
105 |     parser.add_argument('--seq_len', type=int, default=8)
106 | 
107 |     parser.add_argument('--seq_srd', type=int, default=4)
108 | 
109 |     parser.add_argument('--split', type=int, default=0)
110 | 
111 |     # MODEL
112 |     # CNN model
113 |     parser.add_argument('--arch1', '--a1', type=str, default='resnet50_rga',  # dest is 'arch1', matching args.arch1 in main()
114 |                         choices=['resnet50_rga', 'resnet50'])
115 |     parser.add_argument('--features', type=int, default=512)
116 |     parser.add_argument('--dropout', type=float, default=0.0)
117 | 
118 |     # Siamese model
119 |     parser.add_argument('--arch2', '--a2', type=str, default='siamese',  # dest is 'arch2', matching args.arch2 in main()
120 |                         choices=models.names())
121 | 
122 |     # Criterion model
123 |     parser.add_argument('--loss', type=str, default='oim',
124 |                         choices=['xentropy', 'oim', 'triplet'])
125 |     parser.add_argument('--oim-scalar', type=float, default=20)
126 |     parser.add_argument('--oim-momentum', type=float, default=0.5)
127 |     parser.add_argument('--sampling-rate', type=int, default=3)
128 |     parser.add_argument('--sample_method', type=str, default='rrs')
129 | 
130 |     # OPTIMIZER
131 |     parser.add_argument('--seed', type=int, default=1)
132 |     parser.add_argument('--lr1', type=float, default=0.001)
133 |     parser.add_argument('--lr2', type=float, default=0.001)
134 |     parser.add_argument('--lr3', type=float, default=1.0)
135 | 
136 |     parser.add_argument('--lr1step', type=float, default=15)
137 |     parser.add_argument('--lr2step', type=float, default=20)
138 |     parser.add_argument('--lr3step', type=float, default=40)
139 | 
140 |     parser.add_argument('--momentum', type=float, default=0.9)
141 |     parser.add_argument('--weight-decay', type=float, default=5e-4)
142 |     parser.add_argument('--cnn_resume', type=str, default='', metavar='PATH')
143 | 
144 |     # TRAINER
145 |     parser.add_argument('--start-epoch', type=int, default=0)
146 |     parser.add_argument('--epochs', type=int, default=60)
147 |     # EVAL
148 |     parser.add_argument('--evaluate', type=int, default=1)
149 |     parser.add_argument('--visul', type=int, default=0, help='visualize the result')
150 |     parser.add_argument('--rerank', type=int, default=0, help='rerank the result')
151 |     # misc
152 |     working_dir = osp.dirname(osp.abspath(__file__))
153 | 
parser.add_argument('--data-dir', type=str, metavar='PATH', 154 | default='/home/ycy/data') 155 | parser.add_argument('--logs-dir', type=str, metavar='PATH', 156 | default=osp.join(working_dir, 'log/no_rga_6_8*4')) 157 | parser.add_argument('--logs-dir1', type=str, metavar='PATH', 158 | default=osp.join(working_dir, 'log/no_rga_6_8*4/split0')) 159 | 160 | args = parser.parse_args() 161 | 162 | # main function 163 | main(args) 164 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def to_numpy(tensor): 5 | if torch.is_tensor(tensor): 6 | return tensor.cpu().numpy() 7 | elif type(tensor).__module__ != 'numpy': 8 | raise ValueError("Cannot convert {} to numpy array" 9 | .format(type(tensor))) 10 | return tensor 11 | 12 | 13 | def to_torch(ndarray): 14 | if type(ndarray).__module__ == 'numpy': 15 | return torch.from_numpy(ndarray) 16 | elif not torch.is_tensor(ndarray): 17 | raise ValueError("Cannot convert {} to torch tensor" 18 | .format(type(ndarray))) 19 | return ndarray 20 | 21 | -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flysnowtiger/GRL/26eed6542b636ab373dbc909d5d92b2d2b0cd7ae/utils/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flysnowtiger/GRL/26eed6542b636ab373dbc909d5d92b2d2b0cd7ae/utils/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/logging.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flysnowtiger/GRL/26eed6542b636ab373dbc909d5d92b2d2b0cd7ae/utils/__pycache__/logging.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/logging.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flysnowtiger/GRL/26eed6542b636ab373dbc909d5d92b2d2b0cd7ae/utils/__pycache__/logging.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/meters.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flysnowtiger/GRL/26eed6542b636ab373dbc909d5d92b2d2b0cd7ae/utils/__pycache__/meters.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/meters.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flysnowtiger/GRL/26eed6542b636ab373dbc909d5d92b2d2b0cd7ae/utils/__pycache__/meters.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/osutils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/flysnowtiger/GRL/26eed6542b636ab373dbc909d5d92b2d2b0cd7ae/utils/__pycache__/osutils.cpython-36.pyc 
--------------------------------------------------------------------------------
/utils/__pycache__/osutils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/flysnowtiger/GRL/26eed6542b636ab373dbc909d5d92b2d2b0cd7ae/utils/__pycache__/osutils.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/serialization.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/flysnowtiger/GRL/26eed6542b636ab373dbc909d5d92b2d2b0cd7ae/utils/__pycache__/serialization.cpython-36.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/serialization.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/flysnowtiger/GRL/26eed6542b636ab373dbc909d5d92b2d2b0cd7ae/utils/__pycache__/serialization.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/logging.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | import os
3 | import sys
4 | 
5 | from .osutils import mkdir_if_missing
6 | 
7 | 
8 | class Logger(object):
9 |     def __init__(self, fpath=None):
10 |         self.console = sys.stdout
11 |         self.file = None
12 |         if fpath is not None:
13 |             mkdir_if_missing(os.path.dirname(fpath))
14 |             self.file = open(fpath, 'w')
15 | 
16 |     def __del__(self):
17 |         self.close()
18 | 
19 |     def __enter__(self):
20 |         return self  # so 'with Logger(...) as log:' binds the logger instead of None
21 | 
22 |     def __exit__(self, *args):
23 |         self.close()
24 | 
25 |     def write(self, msg):
26 |         self.console.write(msg)
27 |         if self.file is not None:
28 |             self.file.write(msg)
29 | 
30 |     def flush(self):
31 |         self.console.flush()
32 |         if self.file is not None:
33 |             self.file.flush()
34 |             os.fsync(self.file.fileno())
35 | 
36 |     def close(self):
37 |         self.console.close()
38 |         if self.file is not None:
39 |             self.file.close()
40 | 
--------------------------------------------------------------------------------
/utils/meters.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | 
3 | 
4 | class AverageMeter(object):
5 |     """Computes and stores the average and current value"""
6 | 
7 |     def __init__(self):
8 |         self.val = 0
9 |         self.avg = 0
10 |         self.sum = 0
11 |         self.count = 0
12 | 
13 |     def reset(self):
14 |         self.val = 0
15 |         self.avg = 0
16 |         self.sum = 0
17 |         self.count = 0
18 | 
19 |     def update(self, val, n=1):
20 |         self.val = val
21 |         self.sum += val * n
22 |         self.count += n
23 |         self.avg = self.sum / self.count
24 | 
--------------------------------------------------------------------------------
/utils/osutils.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | import os
3 | import errno
4 | 
5 | 
6 | def mkdir_if_missing(dir_path):
7 |     try:
8 |         os.makedirs(dir_path)
9 |     except OSError as e:
10 |         if e.errno != errno.EEXIST:
11 |             raise
12 | 
--------------------------------------------------------------------------------
/utils/serialization.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 | import json
3 | import os
4 | import os.path as osp
5 | import shutil
6 | 
7 | import torch
8 | from torch.nn import Parameter
9 | 
10 | from
.osutils import mkdir_if_missing
11 | 
12 | 
13 | def read_json(fpath):
14 |     with open(fpath, 'r') as f:
15 |         obj = json.load(f)
16 |     return obj
17 | 
18 | 
19 | def write_json(obj, fpath):
20 |     mkdir_if_missing(osp.dirname(fpath))
21 |     with open(fpath, 'w') as f:
22 |         json.dump(obj, f, indent=4, separators=(',', ': '))
23 | 
24 | 
25 | def save_cnn_checkpoint(state, is_best, fpath='checkpoint.pth.tar'):
26 |     mkdir_if_missing(osp.dirname(fpath))
27 |     torch.save(state, fpath)
28 |     if is_best:
29 |         shutil.copy(fpath, osp.join(osp.dirname(fpath), 'cnnmodel_best.pth.tar'))
30 | 
31 | 
32 | def save_att_checkpoint(state, is_best, fpath='checkpoint.pth.tar'):
33 |     mkdir_if_missing(osp.dirname(fpath))
34 |     torch.save(state, fpath)
35 |     if is_best:
36 |         shutil.copy(fpath, osp.join(osp.dirname(fpath), 'attmodel_best.pth.tar'))
37 | 
38 | 
39 | def save_siamese_checkpoint(state, is_best, fpath='checkpoint.pth.tar'):
40 |     mkdir_if_missing(osp.dirname(fpath))
41 |     torch.save(state, fpath)
42 |     if is_best:
43 |         shutil.copy(fpath, osp.join(osp.dirname(fpath), 'siamesemodel_best.pth.tar'))
44 | 
45 | 
46 | def save_cls_checkpoint(state, is_best, fpath='checkpoint.pth.tar'):
47 |     mkdir_if_missing(osp.dirname(fpath))
48 |     torch.save(state, fpath)
49 |     if is_best:
50 |         shutil.copy(fpath, osp.join(osp.dirname(fpath), 'clsmodel_best.pth.tar'))
51 | 
52 | 
53 | def load_checkpoint(fpath):
54 |     if osp.isfile(fpath):
55 |         checkpoint = torch.load(fpath)
56 |         print("=> Loaded checkpoint '{}'".format(fpath))
57 |         return checkpoint
58 |     else:
59 |         raise ValueError("=> No checkpoint found at '{}'".format(fpath))
60 | 
61 | 
62 | def copy_state_dict(state_dict, model, strip=None):
63 |     tgt_state = model.state_dict()
64 |     copied_names = set()
65 |     for name, param in state_dict.items():
66 |         if strip is not None and name.startswith(strip):
67 |             name = name[len(strip):]
68 |         if name not in tgt_state:
69 |             continue
70 |         if isinstance(param, Parameter):
71 |             param = param.data
72 |         if param.size() != tgt_state[name].size():
73 |             print('mismatch:', name, param.size(), tgt_state[name].size())
74 |             continue
75 |         tgt_state[name].copy_(param)
76 |         copied_names.add(name)
77 | 
78 |     missing = set(tgt_state.keys()) - copied_names
79 |     if len(missing) > 0:
80 |         print("missing keys in state_dict:", missing)
81 | 
82 |     return model
83 | 
84 | 
85 | def remove_repeat_tensorboard_files(path):
86 |     if osp.exists(path):  # 1. the log directory (e.g. logs/test_log) exists
87 |         if os.listdir(path):  # and it contains files
88 |             files = os.listdir(path)
89 |             for file in files:
90 |                 file_path = osp.join(path, file)  # absolute path of the file
91 |                 os.remove(file_path)  # delete the stale event file
92 | 
93 | 
--------------------------------------------------------------------------------
/visualize.py:
--------------------------------------------------------------------------------
1 | 
2 | 
3 | import torch
4 | import torch.nn.functional as F
5 | 
6 | import numpy as np
7 | import cv2
8 | 
9 | 
10 | def reverse_normalize(x, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):  # undoes ImageNet normalization, in place
11 |     x[:, 0, :, :] = x[:, 0, :, :] * std[0] + mean[0]
12 |     x[:, 1, :, :] = x[:, 1, :, :] * std[1] + mean[1]
13 |     x[:, 2, :, :] = x[:, 2, :, :] * std[2] + mean[2]
14 |     return x
15 | 
16 | 
17 | def visualize(img, cam):
18 |     """
19 |     Synthesize an image with CAM to make a result image.
20 |     Args:
21 |         img: (Tensor) shape => (1, 3, H, W)
22 |         cam: (Tensor) shape => (1, 1, H', W')
23 |     Returns:
24 |         (result, raw image): Tensors, each of shape (1, 3, H, W)
25 |     """
26 | 
27 |     _, _, H, W = img.shape
28 |     cam = F.interpolate(cam, size=(H, W), mode='bilinear', align_corners=False)  # torch.Size([1, 1, 256, 128])
29 |     cam = 255 * cam.squeeze()  # torch.Size([256, 128])
30 |     cam = cam.detach().cpu()
31 |     heatmap = cv2.applyColorMap(np.uint8(cam), cv2.COLORMAP_JET)  # (256, 128, 3)
32 |     heatmap = torch.from_numpy(heatmap.transpose(2, 0, 1))  # torch.Size([3, 256, 128])
33 |     heatmap = heatmap.float() / 255  # torch.Size([3, 256, 128])
34 |     b, g, r = heatmap.split(1)
35 |     heatmap = torch.cat([r, g, b])  # BGR (OpenCV) -> RGB, torch.Size([3, 256, 128])
36 | 
37 |     result = heatmap + img.cpu()
38 |     # result = heatmap
39 |     result = result.div(result.max())
40 | 
41 |     return result, img.cpu()
42 | 
43 | 
44 | def visualize2(cam):
45 |     """
46 |     Render a CAM alone as a colorized heatmap (no base image is blended in).
47 |     Args:
48 |         cam: (Tensor) shape => (1, 1, H', W')
49 |     Returns:
50 |         heatmap (Tensor): shape => (3, H', W')
51 |     """
52 | 
53 |     _, _, H, W = cam.shape
54 |     cam = F.interpolate(cam, size=(H, W), mode='bilinear', align_corners=False)  # no-op: resizes to its own size
55 |     cam = 255 * cam.squeeze()
56 |     cam = cam.detach().cpu()
57 |     heatmap = cv2.applyColorMap(np.uint8(cam), cv2.COLORMAP_JET)
58 |     heatmap = torch.from_numpy(heatmap.transpose(2, 0, 1))
59 |     heatmap = heatmap.float() / 255
60 |     b, g, r = heatmap.split(1)
61 |     heatmap = torch.cat([r, g, b])  # BGR -> RGB
62 | 
63 |     result = heatmap
64 |     result = result.div(result.max())
65 | 
66 |     return result
--------------------------------------------------------------------------------
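
Usage sketch (illustrative, not a file from the repository): the snippet below shows how `reverse_normalize` and `visualize` from visualize.py compose a CAM overlay. The tensor shapes and the output filenames are assumptions chosen to match the shape comments above.

    import torch
    from torchvision.utils import save_image
    from visualize import reverse_normalize, visualize

    # Dummy inputs (assumed shapes): one ImageNet-normalized frame and one
    # low-resolution attention map with values in [0, 1].
    img = torch.randn(1, 3, 256, 128)   # (1, 3, H, W)
    cam = torch.rand(1, 1, 16, 8)       # (1, 1, H', W')

    img = reverse_normalize(img)        # undo mean/std normalization (modifies img in place)
    overlay, raw = visualize(img, cam)  # upsample the CAM, colorize it, blend it with the frame
    save_image(overlay, 'cam_overlay.jpg')  # hypothetical output path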
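
Similarly, a minimal sketch of tolerant checkpoint restoration with utils/serialization.py; the checkpoint path and the stand-in model here are hypothetical:

    import torch.nn as nn
    from utils.serialization import load_checkpoint, copy_state_dict

    model = nn.Linear(2048, 512)  # stand-in for any nn.Module, e.g. one from models.create(...)
    checkpoint = load_checkpoint('log/cnn_checkpoint.pth.tar')  # hypothetical path
    # copy_state_dict skips missing or shape-mismatched keys and can strip a
    # prefix such as 'module.' left behind by nn.DataParallel checkpoints.
    model = copy_state_dict(checkpoint['state_dict'], model, strip='module.')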