├── GASNet ├── README.md ├── main_gasnet.py ├── reid │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── img_evaluators.cpython-36.pyc │ │ ├── img_evaluators.cpython-38.pyc │ │ ├── img_trainers.cpython-36.pyc │ │ └── img_trainers.cpython-38.pyc │ ├── data_manager │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-36.pyc │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── cuhk03.cpython-36.pyc │ │ │ ├── test_show_dataset.cpython-38.pyc │ │ │ ├── vru.cpython-36.pyc │ │ │ └── vru.cpython-38.pyc │ │ ├── test_show_dataset.py │ │ └── vru.py │ ├── evaluation_metrics │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-36.pyc │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── classification.cpython-36.pyc │ │ │ ├── classification.cpython-38.pyc │ │ │ ├── evaluate.cpython-36.pyc │ │ │ ├── evaluate.cpython-38.pyc │ │ │ ├── ranking.cpython-36.pyc │ │ │ └── ranking.cpython-38.pyc │ │ ├── classification.py │ │ ├── evaluate.py │ │ └── ranking.py │ ├── img_evaluators.py │ ├── img_trainers.py │ ├── loss │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-36.pyc │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── loss_set.cpython-36.pyc │ │ │ └── loss_set.cpython-38.pyc │ │ └── loss_set.py │ ├── models │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-36.pyc │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── rga_model.cpython-36.pyc │ │ │ └── rga_model.cpython-38.pyc │ │ ├── models_utils │ │ │ ├── __init__.py │ │ │ ├── __pycache__ │ │ │ │ ├── __init__.cpython-36.pyc │ │ │ │ ├── __init__.cpython-38.pyc │ │ │ │ ├── os_modules.cpython-36.pyc │ │ │ │ ├── os_modules.cpython-38.pyc │ │ │ │ ├── rga_branches.cpython-36.pyc │ │ │ │ ├── rga_branches.cpython-38.pyc │ │ │ │ ├── rga_modules.cpython-36.pyc │ │ │ │ └── rga_modules.cpython-38.pyc │ │ │ ├── os_modules.py │ │ │ ├── rga_branches.py │ │ │ └── rga_modules.py │ │ └── rga_model.py │ └── utils │ │ ├── __init__.py │ │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── lr_scheduler.cpython-36.pyc │ │ ├── lr_scheduler.cpython-38.pyc │ │ ├── meters.cpython-36.pyc │ │ ├── meters.cpython-38.pyc │ │ ├── osutils.cpython-36.pyc │ │ ├── osutils.cpython-38.pyc │ │ ├── serialization.cpython-36.pyc │ │ └── serialization.cpython-38.pyc │ │ ├── data │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-36.pyc │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── preprocessor.cpython-36.pyc │ │ │ ├── preprocessor.cpython-38.pyc │ │ │ ├── sampler.cpython-36.pyc │ │ │ ├── sampler.cpython-38.pyc │ │ │ ├── transforms.cpython-36.pyc │ │ │ └── transforms.cpython-38.pyc │ │ ├── preprocessor.py │ │ ├── sampler.py │ │ └── transforms.py │ │ ├── logging.py │ │ ├── lr_scheduler.py │ │ ├── meters.py │ │ ├── osutils.py │ │ └── serialization.py └── requirements.txt ├── README.md └── VRU └── README.md /GASNet/README.md: -------------------------------------------------------------------------------- 1 | # Vehicle Re-identification based on UAV Viewpoint: Dataset and Method 2 | 3 | A PyTorch implementation of GASNet 4 | 5 | ## Abstract 6 | 7 | High-resolution remote sensing images bring a large amount of data as well as challenges to traditional vision tasks. Vehicle re-identification (ReID), as an essential vision task that can utilize remote sensing images, has been widely used in suspect vehicle search, cross-border vehicle tracking, traffic behavior analysis, and automatic toll collection systems. 
Although there have been a large number of studies on vehicle ReID, most of them are based on fixed surveillance cameras and do not take full advantage of high-resolution remote sensing images. Compared with images collected by fixed surveillance cameras, high-resolution remote sensing images based on Unmanned Aerial Vehicles (UAVs) have the characteristics of rich viewpoints and a wide range of scale variations. These characteristics bring richer information to vehicle ReID tasks and have the potential to improve the performance of vehicle ReID models. However, to the best of our knowledge, there is a shortage of large open-source datasets for vehicle ReID based on UAV views, which is not conducive to promoting UAV-view-based vehicle ReID research. To address this issue, we construct a large-scale vehicle ReID dataset named VRU (the abbreviation of Vehicle Re-identification based on UAV), which consists of 172,137 images of 15,085 vehicles captured by UAVs, in which each vehicle has multiple images taken from various viewpoints. Compared with the existing vehicle ReID datasets based on UAVs, the VRU dataset has a larger volume and is fully open source. Since most of the existing vehicle ReID methods are designed for fixed surveillance cameras, it is difficult for these methods to adapt to UAV-based vehicle ReID images with multi-viewpoint and multi-scale characteristics. Thus, this work proposes a Global Attention and full-Scale Network (GASNet) for the vehicle ReID task based on UAV images. To verify the effectiveness of GASNet, it is compared with the baseline models on the VRU dataset. The experimental results show that GASNet achieves 97.45% Rank-1 and 98.51% mAP, outperforming the baselines by 3.43%/2.08% in terms of Rank-1/mAP. Our major contributions can be summarized as follows: (1) the provision of an open-source UAV-based vehicle ReID dataset, and (2) the proposal of a state-of-the-art model for UAV-based vehicle ReID. 8 | 9 | ## Examples 10 | 11 | Please download the pre-trained model from this [link](https://pan.baidu.com/s/1XPSgZI92ClK8lcas_v9sRg?pwd=hqj0) and put it in the ./weights/pre_train/ folder. 12 | 13 | Please download the VRU dataset from this [link](https://github.com/GeoX-Lab/ReID/tree/main/VRU) and put it in the ./datasets/ folder. 14 | 15 | Train: 16 | 17 | `python main_gasnet.py` 18 | 19 | Test: 20 | 21 | `python main_gasnet.py --evaluate` 22 | 23 | ## Citation 24 | 25 | If you find this code or dataset useful for your research, please cite our paper. 26 | 27 | ``` 28 | Bibtex 29 | @Article{rs14184603, 30 | AUTHOR = {Lu, Mingming and Xu, Yongchuan and Li, Haifeng}, 31 | TITLE = {Vehicle Re-Identification Based on UAV Viewpoint: Dataset and Method}, 32 | JOURNAL = {Remote Sensing}, 33 | VOLUME = {14}, 34 | YEAR = {2022}, 35 | NUMBER = {18}, 36 | ARTICLE-NUMBER = {4603}, 37 | URL = {https://www.mdpi.com/2072-4292/14/18/4603}, 38 | ISSN = {2072-4292}, 39 | DOI = {10.3390/rs14184603} 40 | } 41 | ``` 42 | -------------------------------------------------------------------------------- /GASNet/main_gasnet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 
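# Usage sketch (the paths are illustrative assumptions based on the README and the
# checkpoint-saving logic below; the flags themselves are defined in the argparse
# section at the bottom of this file):
#   train:    python main_gasnet.py -d VRU --data-dir ./datasets/
#   evaluate: python main_gasnet.py --evaluate --checkpoint ./weights/VRU/use_o_scale/checkpoint_88.pth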
3 | 4 | from __future__ import print_function, absolute_import 5 | import argparse 6 | import os 7 | import os.path as osp 8 | import numpy as np 9 | import sys 10 | 11 | import torch 12 | from torch import nn 13 | from torch.backends import cudnn 14 | from torch.utils.data import DataLoader 15 | from torch.utils.tensorboard import SummaryWriter 16 | from datetime import datetime 17 | # from torchsummary import summary 18 | 19 | from reid import data_manager 20 | from reid import models 21 | from reid.img_trainers import ImgTrainer 22 | from reid.img_evaluators import ImgEvaluator 23 | from reid.loss.loss_set import TripletHardLoss, CrossEntropyLabelSmoothLoss 24 | from reid.utils.data import transforms as T 25 | from reid.utils.data.preprocessor import Preprocessor 26 | from reid.utils.data.sampler import RandomIdentitySampler 27 | from reid.utils.serialization import load_checkpoint, save_checkpoint 28 | from reid.utils.lr_scheduler import LRScheduler 29 | 30 | sys.path.append(os.path.join(os.path.dirname(__file__))) 31 | 32 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 33 | 34 | def get_data(name, data_dir, height, width, batch_size, num_instances, 35 | workers): 36 | # Datasets 37 | if name == 'VRU': 38 | dataset_name = 'VRU' 39 | dataset = data_manager.init_imgreid_dataset( 40 | root=data_dir, name=dataset_name 41 | ) 42 | dataset.images_dir = osp.join(data_dir, 'Pic') 43 | # Num. of training IDs 44 | num_classes = dataset.num_train_cids if name == "VRU" else dataset.num_train_pids 45 | 46 | train_transformer = T.Compose([ 47 | T.Random2DTranslation(height, width), 48 | T.RandomHorizontalFlip(), 49 | T.ToTensor(), 50 | ]) 51 | 52 | test_transformer = T.Compose([ 53 | T.RectScale(height, width), 54 | T.ToTensor(), 55 | ]) 56 | 57 | train_loader = DataLoader( 58 | Preprocessor(dataset.train, root=dataset.images_dir, transform=train_transformer), 59 | batch_size=batch_size, num_workers=workers, 60 | sampler=RandomIdentitySampler(dataset.train, num_instances), 61 | pin_memory=True, drop_last=True) 62 | 63 | query_loader = DataLoader( 64 | Preprocessor(dataset.query, root=dataset.images_dir, transform=test_transformer), 65 | batch_size=batch_size, num_workers=workers, 66 | shuffle=False, pin_memory=True) 67 | 68 | gallery_loader = DataLoader( 69 | Preprocessor(dataset.gallery, root=dataset.images_dir, transform=test_transformer), 70 | batch_size=batch_size, num_workers=workers, 71 | shuffle=False, pin_memory=True) 72 | 73 | return dataset, num_classes, train_loader, query_loader, gallery_loader 74 | 75 | 76 | # Tee stdout to a txt log file 77 | class Logger(object): 78 | def __init__(self, filename='default.log', stream=sys.stdout): 79 | self.terminal = stream 80 | self.log = open(filename, 'a') 81 | 82 | def write(self, message): 83 | self.terminal.write(message) 84 | self.log.write(message) 85 | 86 | def flush(self): 87 | pass 88 | 89 | 90 | def main(args): 91 | # Set the seeds 92 | np.random.seed(args.seed) 93 | torch.manual_seed(args.seed) 94 | cudnn.benchmark = True 95 | checkpoint = "checkpoint_88.pth" 96 | # Set the log output path 97 | theTime = datetime.now().strftime('%Y_%m_%d_%H_%M_%S') 98 | if args.evaluate: 99 | sys.stdout = Logger(f'./logs/Test/GASNet_Test_{theTime}_{args.use_o_scale}_{args.dataset}.log', sys.stdout) 100 | else: 101 | sys.stdout = Logger(f'./logs/Train/GASNet_Train_{theTime}_{args.use_o_scale}_{args.dataset}.log', sys.stdout) 102 | 103 | # Create data loaders 104 | assert args.num_instances > 1, "num_instances should be greater than 1" 105 | assert args.batch_size % args.num_instances == 0, \ 106 | 
'num_instances should divide batch_size' 107 | if args.height is None or args.width is None: 108 | args.height, args.width = (144, 56) if args.arch == 'inception' else \ 109 | (256, 128) 110 | 111 | dataset, num_classes, train_loader, query_loader, gallery_loader = \ 112 | get_data(args.dataset, args.data_dir, args.height, 113 | args.width, args.batch_size, args.num_instances, args.workers) 114 | 115 | # Summary Writer 116 | if not args.evaluate: 117 | TIMESTAMP = "{0:%Y-%m-%dT%H-%M-%S/}".format(datetime.now()) 118 | log_dir = osp.join(args.logs_dir, 'tensorboard_log' + TIMESTAMP) 119 | print(log_dir) 120 | summary_writer = SummaryWriter(log_dir) 121 | else: 122 | summary_writer = None 123 | 124 | # Create model 125 | model = models.create(args.arch, pretrained=False, num_feat=args.features, 126 | height=args.height, width=args.width, dropout=args.dropout, 127 | num_classes=num_classes, branch_name=args.branch_name, use_o_scale=args.use_o_scale) 128 | 129 | # Load from checkpoint 130 | start_epoch = best_top1 = 0 131 | device_ids = [0] 132 | 133 | # test/evaluate the model 134 | if args.evaluate: 135 | if args.use_o_scale: 136 | evaluate_weight = torch.load(args.checkpoint) 137 | else: 138 | evaluate_weight = torch.load(args.checkpoint) 139 | # model.eval() 140 | model.load_state_dict(evaluate_weight) 141 | model = nn.DataParallel(model, device_ids=device_ids).cuda() 142 | evaluator = ImgEvaluator(model, file_path=args.logs_dir, use_o_scale=args.use_o_scale) 143 | if args.use_o_scale: 144 | feats_list = ['feat_gasnet', 'feat_gasnet_', 'feat_cls'] 145 | evaluator.eval_worerank(query_loader, gallery_loader, dataset.query, dataset.gallery, 146 | metric=['cosine', 'euclidean'], 147 | types_list=feats_list) 148 | return 149 | else: 150 | feats_list = ['feat_rga', 'feat_rga_'] 151 | evaluator.eval_worerank(query_loader, gallery_loader, dataset.query, dataset.gallery, 152 | metric=['cosine', 'euclidean'], 153 | types_list=feats_list) 154 | return 155 | elif args.resume: 156 | torch.cuda.set_device(0) 157 | checkpoint = load_checkpoint(f'./logs/RGA-SC/{args.dataset}/checkpoint_88_88.pth.tar') 158 | model.load_state_dict(checkpoint['state_dict']) 159 | model = nn.DataParallel(model, device_ids=device_ids).cuda() 160 | start_epoch = checkpoint['epoch'] 161 | best_top1 = checkpoint['best_top1'] 162 | print("=> Start epoch {} best top1 {:.1%}".format(start_epoch, best_top1)) 163 | print("=> Start epoch {}".format(start_epoch)) 164 | 165 | else: 166 | print("=> Start train a new model!!") 167 | pre_train_weight = torch.load(f'./weights/pre_train/resnet50-19c8e357.pth') 168 | model.load_state_dict(pre_train_weight, strict=False) 169 | model = nn.DataParallel(model, device_ids=device_ids).cuda() 170 | 171 | # model = nn.DataParallel(model, device_ids=device_ids).cuda() 172 | # Criterion 173 | criterion_cls = CrossEntropyLabelSmoothLoss(num_classes).cuda() 174 | criterion_tri = TripletHardLoss(margin=args.margin) 175 | criterion = [criterion_cls, criterion_tri] 176 | 177 | # Trainer 178 | trainer = ImgTrainer(model, criterion, summary_writer, use_o_scale=args.use_o_scale) 179 | 180 | # Optimizer 181 | if hasattr(model.module, 'backbone'): 182 | base_param_ids = set(map(id, model.module.backbone.parameters())) 183 | new_params = [p for p in model.parameters() if id(p) not in base_param_ids] 184 | param_groups = [ 185 | {'params': filter(lambda p: p.requires_grad, model.module.backbone.parameters()), 'lr_mult': 1.0}, 186 | {'params': filter(lambda p: p.requires_grad, new_params), 'lr_mult': 1.0}] 187 | 
else: 188 | param_groups = model.parameters() 189 | if args.optimizer == 'sgd': 190 | optimizer = torch.optim.SGD(param_groups, lr=args.lr, 191 | momentum=args.momentum, 192 | weight_decay=args.weight_decay, 193 | nesterov=True) 194 | elif args.optimizer == 'adam': 195 | optimizer = torch.optim.Adam( 196 | param_groups, lr=args.lr, weight_decay=args.weight_decay 197 | ) 198 | else: 199 | raise NameError 200 | # if args.resume and checkpoint.has_key('optimizer'): 201 | # optimizer.load_state_dict(checkpoint['optimizer']) 202 | if args.resume and 'optimizer' in checkpoint: 203 | optimizer.load_state_dict(checkpoint['optimizer']) 204 | 205 | # Learning rate scheduler 206 | lr_scheduler = LRScheduler(base_lr=0.0008, step=[88, 95, 105, 115, 125, 130, 135, 140, 145, 150, 155, 160, 165, 170, 175, 180], 207 | factor=0.5, warmup_epoch=15, 208 | warmup_begin_lr=0.000008) 209 | 210 | # Start training 211 | for epoch in range(start_epoch, args.epochs): 212 | lr = lr_scheduler.update(epoch) 213 | for param_group in optimizer.param_groups: 214 | param_group['lr'] = lr 215 | print('[Info] Epoch [{}] learning rate update to {:.3e}'.format(epoch+1, lr)) 216 | trainer.train(epoch, train_loader, optimizer, random_erasing=args.random_erasing, empty_cache=args.empty_cache) 217 | if (epoch + 1) % 1 == 0 and (epoch + 1) >= args.start_save: 218 | is_best = False 219 | save_checkpoint({ 220 | 'state_dict': model.module.state_dict(), 221 | 'epoch': epoch + 1, 222 | 'best_top1': best_top1, 223 | 'optimizer': optimizer.state_dict(), 224 | }, epoch + 1, is_best, save_interval=1, fpath=osp.join(args.logs_dir, f'{args.dataset}/checkpoint_{epoch + 1}.pth.tar')) 225 | if args.use_o_scale: 226 | torch.save(model.module.state_dict(), f'./weights/{args.dataset}/use_o_scale/checkpoint_{epoch + 1}.pth') 227 | else: 228 | torch.save(model.module.state_dict(), f'./weights/{args.dataset}/use_rga/checkpoint_{epoch + 1}.pth') 229 | 230 | print("-----------Start validate!!!-------------") 231 | if args.use_o_scale: 232 | evaluate_weight = torch.load(f'./weights/{args.dataset}/use_o_scale/checkpoint_{epoch + 1}.pth') 233 | else: 234 | evaluate_weight = torch.load(f'./weights/{args.dataset}/use_rga/checkpoint_{epoch + 1}.pth') 235 | # model.eval() 236 | model.load_state_dict(evaluate_weight, strict=False) 237 | evaluator = ImgEvaluator(model, file_path=args.logs_dir, use_o_scale=args.use_o_scale) 238 | if args.use_o_scale: 239 | feats_list = ['feat_gasnet', 'feat_gasnet_', 'feat_cls'] 240 | evaluator.eval_worerank(query_loader, gallery_loader, dataset.query, dataset.gallery, 241 | metric=['cosine', 'euclidean'], 242 | types_list=feats_list) 243 | else: 244 | feats_list = ['feat_rga', 'feat_rga_'] 245 | evaluator.eval_worerank(query_loader, gallery_loader, dataset.query, dataset.gallery, 246 | metric=['cosine', 'euclidean'], 247 | types_list=feats_list) 248 | 249 | if __name__ == '__main__': 250 | torch.cuda.empty_cache() 251 | # os.environ['CUDA_VISIBLE_DEVICES'] = '0,1' 252 | def str2bool(v): 253 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 254 | return True 255 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 256 | return False 257 | else: 258 | raise argparse.ArgumentTypeError('Unsupported value encountered.') 259 | 260 | 261 | parser = argparse.ArgumentParser(description="Softmax loss classification") 262 | # data 263 | parser.add_argument('-d', '--dataset', type=str, default='VRU') 264 | parser.add_argument('-b', '--batch-size', type=int, default=24) 265 | parser.add_argument('-j', '--workers', type=int, default=0) 266 | 
parser.add_argument('--height', type=int, 267 | help="input height, default: 256 for resnet*, " 268 | "144 for inception") 269 | parser.add_argument('--width', type=int, 270 | help="input width, default: 128 for resnet*, " 271 | "56 for inception") 272 | parser.add_argument('--combine-trainval', action='store_true', 273 | help="train and val sets together for training, " 274 | "val set alone for validation") 275 | parser.add_argument('--num-instances', type=int, default=4, 276 | help="each minibatch consists of " 277 | "(batch_size // num_instances) identities, and " 278 | "each identity has num_instances instances, " 279 | "default: 4") 280 | # model 281 | parser.add_argument('-a', '--arch', type=str, default='resnet50_rga', 282 | choices=models.names()) 283 | parser.add_argument('--features', type=int, default=2048) 284 | parser.add_argument('--dropout', type=float, default=0) 285 | parser.add_argument('--branch_name', type=str, default='rgasc') 286 | parser.add_argument('--use_rgb', type=str2bool, default=True) 287 | parser.add_argument('--use_bn', type=str2bool, default=True) 288 | parser.add_argument('--use_o_scale', type=str2bool, default=True) 289 | # loss 290 | parser.add_argument('--margin', type=float, default=0.3, 291 | help="margin of the triplet loss, default: 0.3") 292 | # optimizer 293 | parser.add_argument('-opt', '--optimizer', type=str, default='adam') 294 | parser.add_argument('--lr', type=float, default=0.1, 295 | help="learning rate of new parameters, for pretrained " 296 | "parameters it is 10 times smaller than this") 297 | parser.add_argument('--momentum', type=float, default=0.9) 298 | parser.add_argument('--weight-decay', type=float, default=5e-4) 299 | # training configs 300 | parser.add_argument('--num_gpu', type=int, default=1) 301 | parser.add_argument('--resume', action='store_true', 302 | help='continue to train') 303 | parser.add_argument('--rerank', action='store_true', 304 | help="evaluation with re-ranking") 305 | parser.add_argument('--epochs', type=int, default=180) 306 | parser.add_argument('--start_save', type=int, default=1, 307 | help="start saving checkpoints after specific epoch") 308 | parser.add_argument('--seed', type=int, default=16) 309 | parser.add_argument('--print-freq', type=int, default=1) 310 | parser.add_argument('--empty_cache', type=str2bool, default=False) 311 | parser.add_argument('--random_erasing', type=str2bool, default=True) 312 | # testing configs 313 | parser.add_argument('--evaluate', action='store_true', 314 | help="evaluation only") 315 | parser.add_argument('--checkpoint', type=str, default="", 316 | help="load the checkpoint for testing") 317 | # metric learning 318 | parser.add_argument('--dist-metric', type=str, default='euclidean', 319 | choices=['euclidean', 'kissme']) 320 | # misc 321 | working_dir = osp.dirname(osp.abspath(__file__)) 322 | parser.add_argument('--data-dir', type=str, metavar='PATH', 323 | default='./dataset/') 324 | parser.add_argument('--logs-dir', type=str, metavar='PATH', 325 | default="./logs/RGA-SC/") 326 | parser.add_argument('--logs-file', type=str, metavar='PATH', 327 | default='log.txt') 328 | main(parser.parse_args()) 329 | -------------------------------------------------------------------------------- /GASNet/reid/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from . import evaluation_metrics 4 | from . import loss 5 | from . import models 6 | from . import utils 7 | from . 
import img_trainers 8 | from . import img_evaluators 9 | 10 | __version__ = '0.2.0' 11 | -------------------------------------------------------------------------------- /GASNet/reid/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GeoX-Lab/ReID/a4871c799c4a45a3eebb165bc4cda85c431c0253/GASNet/reid/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /GASNet/reid/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GeoX-Lab/ReID/a4871c799c4a45a3eebb165bc4cda85c431c0253/GASNet/reid/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /GASNet/reid/__pycache__/img_evaluators.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GeoX-Lab/ReID/a4871c799c4a45a3eebb165bc4cda85c431c0253/GASNet/reid/__pycache__/img_evaluators.cpython-36.pyc -------------------------------------------------------------------------------- /GASNet/reid/__pycache__/img_evaluators.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GeoX-Lab/ReID/a4871c799c4a45a3eebb165bc4cda85c431c0253/GASNet/reid/__pycache__/img_evaluators.cpython-38.pyc -------------------------------------------------------------------------------- /GASNet/reid/__pycache__/img_trainers.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GeoX-Lab/ReID/a4871c799c4a45a3eebb165bc4cda85c431c0253/GASNet/reid/__pycache__/img_trainers.cpython-36.pyc -------------------------------------------------------------------------------- /GASNet/reid/__pycache__/img_trainers.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GeoX-Lab/ReID/a4871c799c4a45a3eebb165bc4cda85c431c0253/GASNet/reid/__pycache__/img_trainers.cpython-38.pyc -------------------------------------------------------------------------------- /GASNet/reid/data_manager/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from .vru import VRU 6 | from .test_show_dataset import VRU_test 7 | 8 | __imgreid_factory = { 9 | 'VRU': VRU, 10 | 'VRU_show': VRU_test 11 | } 12 | 13 | def get_names(): 14 | return list(__imgreid_factory.keys()) 15 | 16 | def init_imgreid_dataset(name, **kwargs): 17 | if name not in list(__imgreid_factory.keys()): 18 | raise KeyError("Invalid dataset, got '{}', but expected to be one of {}".format(name, list(__imgreid_factory.keys()))) 19 | return __imgreid_factory[name](**kwargs) 20 | -------------------------------------------------------------------------------- /GASNet/reid/data_manager/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GeoX-Lab/ReID/a4871c799c4a45a3eebb165bc4cda85c431c0253/GASNet/reid/data_manager/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- 
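A minimal sketch of how the dataset factory above is consumed, mirroring the call in main_gasnet.py (the ./datasets path is an assumption about the local layout):

```python
from reid import data_manager

# 'VRU' must be one of the keys registered in __imgreid_factory above;
# unknown names raise a KeyError listing the valid choices.
dataset = data_manager.init_imgreid_dataset(name='VRU', root='./datasets')
print(dataset.num_train_cids)  # number of training vehicle identities
```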
/GASNet/reid/data_manager/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GeoX-Lab/ReID/a4871c799c4a45a3eebb165bc4cda85c431c0253/GASNet/reid/data_manager/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /GASNet/reid/data_manager/__pycache__/cuhk03.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GeoX-Lab/ReID/a4871c799c4a45a3eebb165bc4cda85c431c0253/GASNet/reid/data_manager/__pycache__/cuhk03.cpython-36.pyc -------------------------------------------------------------------------------- /GASNet/reid/data_manager/__pycache__/test_show_dataset.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GeoX-Lab/ReID/a4871c799c4a45a3eebb165bc4cda85c431c0253/GASNet/reid/data_manager/__pycache__/test_show_dataset.cpython-38.pyc -------------------------------------------------------------------------------- /GASNet/reid/data_manager/__pycache__/vru.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GeoX-Lab/ReID/a4871c799c4a45a3eebb165bc4cda85c431c0253/GASNet/reid/data_manager/__pycache__/vru.cpython-36.pyc -------------------------------------------------------------------------------- /GASNet/reid/data_manager/__pycache__/vru.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GeoX-Lab/ReID/a4871c799c4a45a3eebb165bc4cda85c431c0253/GASNet/reid/data_manager/__pycache__/vru.cpython-38.pyc -------------------------------------------------------------------------------- /GASNet/reid/data_manager/test_show_dataset.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import numpy as np 3 | 4 | from ..utils.osutils import mkdir_if_missing 5 | from ..utils.serialization import write_json, read_json 6 | 7 | class VRU_test(object): 8 | dataset_dir = 'VRU' 9 | 10 | def __init__(self, root='datasets', split_id=0, verbose=True, **kwargs): 11 | super(VRU_test, self).__init__() 12 | self.dataset_dir = osp.join(root, self.dataset_dir) 13 | self.label_dir = osp.join(self.dataset_dir, 'train_test_split') 14 | self.imgs_dir = osp.join(self.dataset_dir, 'Pic') 15 | self.split_labeled_json_path = osp.join(self.dataset_dir, 'test_10.json') 16 | 17 | self._check_before_run() 18 | self._preprocess() 19 | 20 | split_path = self.split_labeled_json_path 21 | 22 | splits = read_json(split_path) 23 | assert split_id < len(splits), "Condition split_id ({}) < len(splits) ({}) is false".format(split_id, len(splits)) 24 | split = splits[split_id] 25 | print("Split index = {}".format(split_id)) 26 | 27 | query = split['query'] # list 28 | gallery = split['gallery'] # list 29 | 30 | num_query_cids = split['num_query_pids'] # int 31 | num_gallery_cids = split['num_gallery_pids'] # int 32 | num_total_cids = num_query_cids 33 | 34 | num_query_imgs = split['num_query_imgs'] # int 35 | num_gallery_imgs = split['num_gallery_imgs'] # int 36 | num_total_imgs = num_query_imgs 37 | 38 | if verbose: 39 | print("=> VRU loaded") 40 | print("Dataset statistics:") 41 | print(" ------------------------------") 42 | print(" subset | # ids | # images") 43 | print(" ------------------------------") 44 | print(" 
query | {:5d} | {:8d}".format(num_query_cids, num_query_imgs)) 45 | print(" gallery | {:5d} | {:8d}".format(num_gallery_cids, num_gallery_imgs)) 46 | print(" ------------------------------") 47 | print(" total | {:5d} | {:8d}".format(num_total_cids, num_total_imgs)) 48 | print(" ------------------------------") 49 | 50 | self.query = query 51 | self.gallery = gallery 52 | 53 | self.num_query_cids = num_query_cids 54 | self.num_gallery_cids = num_gallery_cids 55 | 56 | def _check_before_run(self): 57 | """Check if all files are available before going deeper""" 58 | if not osp.exists(self.dataset_dir): 59 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 60 | if not osp.exists(self.label_dir): 61 | raise RuntimeError("'{}' is not available".format(self.label_dir)) 62 | if not osp.exists(self.imgs_dir): 63 | raise RuntimeError("'{}' is not available".format(self.imgs_dir)) 64 | # if not osp.exists(self.split_labeled_json_path): 65 | # raise RuntimeError("'{}' is not available".format(self.split_labeled_json_path)) 66 | 67 | def _preprocess(self): 68 | """ 69 | Parse the query/gallery split txt file shipped with VRU and cache 70 | the resulting split as a json file. 71 | For each vehicle ID with more than one image, one random image goes 72 | to the query set and the rest go to the gallery set. 73 | """ 74 | print("Note: if root path is changed, the previously generated json files need to be re-generated (delete them first)") 75 | 76 | def _extract_split(label_dir, label_file_name, pic_dir, split_name): # 77 | split_name_path = osp.join(label_dir, label_file_name) 78 | split_file = open(split_name_path) 79 | line = split_file.readline() 80 | train, query, gallery = [], [], [] 81 | car_dic = {} 82 | # Process the training set 83 | if split_name == "train": 84 | car_count = 0 85 | while line: 86 | line_list = line.split() # split the line into a list of fields 87 | if line_list[1] in car_dic: 88 | car_dic[line_list[1]] = car_dic[line_list[1]] + 1 89 | else: 90 | car_dic[line_list[1]] = 1 91 | car_id = car_count 92 | car_count += 1 93 | line_list[0] = osp.join(pic_dir, line_list[0]+".jpg") 94 | line_list[1] = int(car_id) 95 | line_list.append(0) # camera ID; unused, just a placeholder 96 | train.append(line_list) 97 | line = split_file.readline() 98 | split_file.close() 99 | return train, len(car_dic), len(train) 100 | # Process the query and gallery sets 101 | else: 102 | while line: 103 | line_list = line.split() 104 | if line_list[1] in car_dic: 105 | car_dic[line_list[1]].append(int(line_list[0])) 106 | else: 107 | car_dic[line_list[1]] = [int(line_list[0])] 108 | line = split_file.readline() 109 | num_cids = 0 110 | car_id = 0 111 | for key, value in car_dic.items(): 112 | if len(value) > 1: 113 | num_cids += 1 114 | gallery_list = [] 115 | # choose a random image to set the gallery 116 | gallery_index = np.random.randint(0, len(value)) 117 | gallery_list.append(osp.join(pic_dir, str(value[gallery_index]) + ".jpg")) 118 | gallery_list.append(car_id) 119 | gallery_list.append(0) 120 | # gallery.append(gallery_list) # (variant with the smaller gallery set) 121 | query.append(gallery_list) # the query set is the smaller one 122 | 123 | # put the rest of the images into the gallery 124 | value.pop(gallery_index) 125 | for i in range(len(value)): 126 | query_list = [] 127 | query_list.append(osp.join(pic_dir, str(value[i])+".jpg")) 128 | query_list.append(car_id) 129 | query_list.append(0) 130 | # query.append(query_list) 131 | gallery.append(query_list) 132 | car_id += 1 133 | return query, gallery, num_cids, num_cids, len(query), len(gallery) 134 | 135 | query, gallery, num_query_cids, num_gallery_cids, num_query_imgs, num_gallery_imgs = 
_extract_split(self.label_dir, "test_list_10.txt", self.imgs_dir, "test") 136 | 137 | splits = [{ 138 | 'query': query, 'gallery': gallery, 139 | 'num_query_pids': num_query_cids, 'num_query_imgs': num_query_imgs, 140 | 'num_gallery_pids': num_gallery_cids, 'num_gallery_imgs': num_gallery_imgs, 141 | }] 142 | write_json(splits, self.split_labeled_json_path) 143 | 144 | -------------------------------------------------------------------------------- /GASNet/reid/data_manager/vru.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import numpy as np 3 | 4 | from ..utils.serialization import write_json, read_json 5 | 6 | class VRU(object): 7 | dataset_dir = 'VRU' 8 | 9 | def __init__(self, root='datasets', split_id=0, verbose=True, **kwargs): 10 | super(VRU, self).__init__() 11 | self.dataset_dir = osp.join(root, self.dataset_dir) 12 | self.label_dir = osp.join(self.dataset_dir, 'train_test_split') 13 | self.imgs_dir = osp.join(self.dataset_dir, 'Pic') 14 | # self.split_labeled_json_path = osp.join(self.dataset_dir, 'train_validation.json') 15 | self.split_labeled_json_path = osp.join(self.dataset_dir, 'test_1200.json') 16 | # self.split_labeled_json_path = osp.join(self.dataset_dir, 'test_1200_big_gallery.json') 17 | # self.split_labeled_json_path = osp.join(self.dataset_dir, 'test_2400.json') 18 | # self.split_labeled_json_path = osp.join(self.dataset_dir, 'test_2400_big_gallery.json') 19 | # self.split_labeled_json_path = osp.join(self.dataset_dir, 'test_8000.json') 20 | # self.split_labeled_json_path = osp.join(self.dataset_dir, 'test_8000_big_gallery.json') 21 | # self.split_labeled_json_path = osp.join(self.dataset_dir, 'validation.json') 22 | # self.split_labeled_json_path = osp.join(self.dataset_dir, 'validation_big_gallery.json') 23 | 24 | 25 | self._check_before_run() 26 | self._preprocess() 27 | 28 | split_path = self.split_labeled_json_path 29 | 30 | splits = read_json(split_path) 31 | assert split_id < len(splits), "Condition split_id ({}) < len(splits) ({}) is false".format(split_id, len(splits)) 32 | split = splits[split_id] 33 | print("Split index = {}".format(split_id)) 34 | 35 | train = split['train'] # list 36 | query = split['query'] # list 37 | gallery = split['gallery'] # list 38 | 39 | num_train_cids = split['num_train_pids'] # int: number of vehicle identities 40 | num_query_cids = split['num_query_pids'] # int 41 | num_gallery_cids = split['num_gallery_pids'] # int 42 | num_total_cids = num_train_cids + num_query_cids 43 | 44 | num_train_imgs = split['num_train_imgs'] # int 45 | num_query_imgs = split['num_query_imgs'] # int 46 | num_gallery_imgs = split['num_gallery_imgs'] # int 47 | num_total_imgs = num_train_imgs + num_query_imgs 48 | 49 | if verbose: 50 | print("=> VRU loaded") 51 | print("Dataset statistics:") 52 | print(" ------------------------------") 53 | print(" subset | # ids | # images") 54 | print(" ------------------------------") 55 | print(" train | {:5d} | {:8d}".format(num_train_cids, num_train_imgs)) 56 | print(" query | {:5d} | {:8d}".format(num_query_cids, num_query_imgs)) 57 | print(" gallery | {:5d} | {:8d}".format(num_gallery_cids, num_gallery_imgs)) 58 | print(" ------------------------------") 59 | print(" total | {:5d} | {:8d}".format(num_total_cids, num_total_imgs)) 60 | print(" ------------------------------") 61 | 62 | self.train = train 63 | self.query = query 64 | self.gallery = gallery 65 | 66 | self.num_train_cids = num_train_cids 67 | self.num_query_cids = num_query_cids 68 | 
self.num_gallery_cids = num_gallery_cids 69 | 70 | def _check_before_run(self): 71 | """Check if all files are available before going deeper""" 72 | if not osp.exists(self.dataset_dir): 73 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 74 | if not osp.exists(self.label_dir): 75 | raise RuntimeError("'{}' is not available".format(self.label_dir)) 76 | if not osp.exists(self.imgs_dir): 77 | raise RuntimeError("'{}' is not available".format(self.imgs_dir)) 78 | # if not osp.exists(self.split_labeled_json_path): 79 | # raise RuntimeError("'{}' is not available".format(self.split_labeled_json_path)) 80 | 81 | def _preprocess(self): 82 | """ 83 | Parse the train/test split txt files shipped with VRU and cache the splits as a json file. 84 | The train list is read as-is, with vehicle labels remapped to consecutive integer IDs. 85 | For each test vehicle ID with more than one image, one random image goes to the query set 86 | and the rest go to the gallery set. 87 | """ 88 | print("Note: if root path is changed, the previously generated json files need to be re-generated (delete them first)") 89 | 90 | def _extract_split(label_dir, label_file_name, pic_dir, split_name): # 91 | split_name_path = osp.join(label_dir, label_file_name) 92 | split_file = open(split_name_path) 93 | line = split_file.readline() 94 | train, query, gallery = [], [], [] 95 | car_dic = {} 96 | # Process the training set 97 | if split_name == "train": 98 | car_count = 0 99 | while line: 100 | line_list = line.split() # split the line into a list of fields 101 | if line_list[1] in car_dic: 102 | car_dic[line_list[1]] = car_dic[line_list[1]] + 1 103 | else: 104 | car_dic[line_list[1]] = 1 105 | car_id = car_count 106 | car_count += 1 107 | line_list[0] = osp.join(pic_dir, line_list[0]+".jpg") 108 | line_list[1] = int(car_id) 109 | line_list.append(0) # camera ID; unused, just a placeholder 110 | train.append(line_list) 111 | line = split_file.readline() 112 | split_file.close() 113 | return train, len(car_dic), len(train) 114 | # Process the query and gallery sets 115 | else: 116 | while line: 117 | line_list = line.split() 118 | if line_list[1] in car_dic: 119 | car_dic[line_list[1]].append(int(line_list[0])) 120 | else: 121 | car_dic[line_list[1]] = [int(line_list[0])] 122 | line = split_file.readline() 123 | num_cids = 0 124 | car_id = 0 125 | for key, value in car_dic.items(): 126 | if len(value) > 1: 127 | num_cids += 1 128 | gallery_list = [] 129 | # choose a random image to set the gallery 130 | gallery_index = np.random.randint(0, len(value)) 131 | gallery_list.append(osp.join(pic_dir, str(value[gallery_index]) + ".jpg")) 132 | gallery_list.append(car_id) 133 | gallery_list.append(0) 134 | # gallery.append(gallery_list) # (variant with the smaller gallery set) 135 | query.append(gallery_list) # the query set is the smaller one 136 | 137 | # put the rest of the images into the gallery 138 | value.pop(gallery_index) 139 | for i in range(len(value)): 140 | query_list = [] 141 | query_list.append(osp.join(pic_dir, str(value[i])+".jpg")) 142 | query_list.append(car_id) 143 | query_list.append(0) 144 | # query.append(query_list) 145 | gallery.append(query_list) 146 | car_id += 1 147 | return query, gallery, num_cids, num_cids, len(query), len(gallery) 148 | 149 | train, num_train_cids, num_train_imgs = _extract_split(self.label_dir, "train_list.txt", self.imgs_dir, "train") 150 | query, gallery, num_query_cids, num_gallery_cids, num_query_imgs, num_gallery_imgs = _extract_split(self.label_dir, "test_list_1200.txt", self.imgs_dir, "test") 151 | 152 | splits = [{ 153 | 'train': train, 'query': query, 'gallery': gallery, 154 | 'num_train_pids': num_train_cids, 'num_train_imgs': num_train_imgs, 155 | 'num_query_pids': num_query_cids, 'num_query_imgs': 
num_query_imgs, 156 | 'num_gallery_pids': num_gallery_cids, 'num_gallery_imgs': num_gallery_imgs, 157 | }] 158 | write_json(splits, self.split_labeled_json_path) -------------------------------------------------------------------------------- /GASNet/reid/evaluation_metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .classification import accuracy 4 | from .ranking import cmc, mean_ap 5 | 6 | __all__ = [ 7 | 'accuracy', 8 | 'cmc', 9 | 'mean_ap', 10 | ] 11 | -------------------------------------------------------------------------------- /GASNet/reid/evaluation_metrics/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GeoX-Lab/ReID/a4871c799c4a45a3eebb165bc4cda85c431c0253/GASNet/reid/evaluation_metrics/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /GASNet/reid/evaluation_metrics/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GeoX-Lab/ReID/a4871c799c4a45a3eebb165bc4cda85c431c0253/GASNet/reid/evaluation_metrics/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /GASNet/reid/evaluation_metrics/__pycache__/classification.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GeoX-Lab/ReID/a4871c799c4a45a3eebb165bc4cda85c431c0253/GASNet/reid/evaluation_metrics/__pycache__/classification.cpython-36.pyc -------------------------------------------------------------------------------- /GASNet/reid/evaluation_metrics/__pycache__/classification.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GeoX-Lab/ReID/a4871c799c4a45a3eebb165bc4cda85c431c0253/GASNet/reid/evaluation_metrics/__pycache__/classification.cpython-38.pyc -------------------------------------------------------------------------------- /GASNet/reid/evaluation_metrics/__pycache__/evaluate.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GeoX-Lab/ReID/a4871c799c4a45a3eebb165bc4cda85c431c0253/GASNet/reid/evaluation_metrics/__pycache__/evaluate.cpython-36.pyc -------------------------------------------------------------------------------- /GASNet/reid/evaluation_metrics/__pycache__/evaluate.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GeoX-Lab/ReID/a4871c799c4a45a3eebb165bc4cda85c431c0253/GASNet/reid/evaluation_metrics/__pycache__/evaluate.cpython-38.pyc -------------------------------------------------------------------------------- /GASNet/reid/evaluation_metrics/__pycache__/ranking.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GeoX-Lab/ReID/a4871c799c4a45a3eebb165bc4cda85c431c0253/GASNet/reid/evaluation_metrics/__pycache__/ranking.cpython-36.pyc -------------------------------------------------------------------------------- /GASNet/reid/evaluation_metrics/__pycache__/ranking.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/GeoX-Lab/ReID/a4871c799c4a45a3eebb165bc4cda85c431c0253/GASNet/reid/evaluation_metrics/__pycache__/ranking.cpython-38.pyc -------------------------------------------------------------------------------- /GASNet/reid/evaluation_metrics/classification.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from ..utils import to_torch 4 | 5 | 6 | def accuracy(output, target, topk=(1,)): 7 | output, target = to_torch(output), to_torch(target) 8 | maxk = max(topk) 9 | batch_size = target.size(0) 10 | 11 | _, pred = output.topk(maxk, 1, True, True) # return the top-k results; args: (how many top values to return, dim to sort along, largest first, return in sorted order) 12 | pred = pred.t() 13 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 14 | 15 | ret = [] 16 | for k in topk: 17 | correct_k = correct[:k].view(-1).float().sum(dim=0, keepdim=True) 18 | ret.append(correct_k.mul_(1. / batch_size)) 19 | return ret 20 | -------------------------------------------------------------------------------- /GASNet/reid/evaluation_metrics/evaluate.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os.path as osp 3 | 4 | import torch 5 | 6 | from ..utils.serialization import read_json 7 | import matplotlib.pyplot as plt 8 | from tensorboardX import SummaryWriter 9 | import torchvision.transforms as transforms 10 | # import 11 | 12 | def evaluation(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=50): 13 | """ 14 | Inputs: 15 | 1. distmat, the distance matrix 16 | 2. q_pids, the identity IDs of the query set 17 | 3. g_pids, the identity IDs of the gallery set 18 | 4. q_camids, the camera IDs of the query set 19 | 5. g_camids, the camera IDs of the gallery set 20 | 6. max_rank, the largest N for Rank-N in the CMC curve 21 | 22 | Outputs: 23 | 1. all_cmc, the list of Rank-N values of the CMC curve 24 | 2. mAP, the mean of the average precision over all valid queries 25 | """ 26 | q_pids = np.asarray(q_pids) 27 | g_pids = np.asarray(g_pids) 28 | query_cams = np.asarray(q_camids) 29 | gallery_cams = np.asarray(g_camids) 30 | num_q, num_g = distmat.shape # size of the distance matrix 31 | if num_g < max_rank: # if the gallery has fewer than max_rank items, shrink max_rank to num_g 32 | print("Note: the gallery set is too small, changing 'max_rank' from {} to {}.".format(max_rank, num_g)) 33 | max_rank = num_g 34 | 35 | """ 36 | np.argsort sorts each row and returns the indices of the row's elements in ascending order. 37 | For example, given the list [1,2,0] it returns [2,0,1], meaning element 2 is the smallest, element 0 the second smallest, and element 1 the largest. 38 | This function therefore sorts by distance and yields the index order after sorting. 39 | """ 40 | indices = np.argsort(distmat, axis=1) # sort by distance 41 | 42 | """ 43 | Visualize the retrieval results 44 | """ 45 | result_show_indices = indices[:, :10] 46 | dataset_dir = "vru" 47 | root = "datasets" 48 | dataset_dir = osp.join(root, dataset_dir) 49 | split_path = osp.join(dataset_dir, "test_10.json") 50 | splits = read_json(split_path) 51 | split = splits[0] 52 | query_indx = split["query"] 53 | gallery_indx = split["gallery"] 54 | query_dir = [] 55 | gallery_dir = [] 56 | for i in range(len(query_indx)): 57 | query_dir.append(query_indx[i][0]) 58 | for i in range(len(gallery_indx)): 59 | gallery_dir.append(gallery_indx[i][0]) 60 | result_dir = [] 61 | result_show_indices = result_show_indices.numpy().tolist() 62 | for i in range(len(result_show_indices)): 63 | result_d = result_show_indices[i] 64 | re = [] 65 | for j in range(len(result_d)): 66 | re.append(gallery_dir[result_d[j]]) 67 | result_dir.append(re) 68 | N = len(query_dir) 69 | M = 11 70 | image_num = 0 71 | for i in range(len(query_dir)): 72 | query_img = plt.imread(query_dir[i]) 73 | image_num += 1 74 | plt.subplot(N, M, image_num) 75 | plt.imshow(query_img) 76 | plt.xticks([]) 77 | plt.yticks([]) 78 | for j in range(10): 79 | result_img = 
plt.imread(result_dir[i][j]) 80 | image_num += 1 81 | plt.subplot(N, M, image_num) 82 | plt.imshow(result_img) 83 | plt.xticks([]) 84 | plt.yticks([]) 85 | plt.savefig('test2.jpg') 86 | plt.show() 87 | image = plt.imread("test2.jpg") 88 | image = torch.from_numpy(np.transpose(image, (2,0,1))) 89 | writer = SummaryWriter("./result_tb") 90 | writer.add_image("The result of ReID", image) 91 | writer.close() 92 | 93 | """ 94 | g_pids[indices] builds a matrix with the same shape as the distance matrix, but its elements are the gallery identity IDs reordered, per row, by ascending distance. 95 | Suppose one row of the distance matrix is [1,0,5,6]; sorting that row ascending gives the index list [1,0,2,3], and the gallery identity IDs are [4,5,6,8]. 96 | Then the corresponding row of g_pids[indices] is [5,4,6,8], i.e. the feature of that row should be assigned to the class with ID 5. 97 | 98 | q_pids[:, np.newaxis] adds one dimension to q_pids. 99 | Matching g_pids[indices] against q_pids sets an element to 1 where the IDs at that position agree, and 0 otherwise. 100 | This yields a correspondence matrix, matches, with the same shape as the distance matrix. 101 | 102 | Element (i, j) of matches indicates whether the identity ID of query i equals the identity ID of the gallery item that is j-th closest to it. 103 | For example, matches[1][3]=1 means that query 1 and the gallery item ranked third by distance belong to the same identity. 104 | """ 105 | matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32) # match IDs and build the matches matrix, used for computing CMC and AP 106 | # matches = (g_pids[indices] == q_pids[:, np.newaxis]) # match IDs and build the matches matrix, used for computing CMC and AP 107 | all_cmc = [] # the CMC data of each query 108 | all_AP = [] # the AP of each query 109 | num_valid_q = 0. # number of queries valid for CMC/mAP computation, used to average the overall Rank-N 110 | 111 | for q_idx in range(num_q): # for each item in the query set 112 | 113 | q_pid = q_pids[q_idx] # identity ID of this query 114 | order = indices[q_idx] # gallery items sorted by distance to this query 115 | 116 | # remove gallery items with the same camera ID and identity ID as this query; same-camera same-identity items do not satisfy the cross-camera requirement 117 | # remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid) # bool list of the elements that need to be removed 118 | # keep = np.invert(remove) # invert remove to get the bool list of the elements to keep 119 | 120 | orig_cmc = matches[q_idx] # keep only the elements where keep is True, giving the match list of this query 121 | if not np.any(orig_cmc): # if this query never appears in the kept gallery set, it is invalid for CMC/mAP; skip to the next query 122 | continue 123 | 124 | """ 125 | Compute the CMC data of each query 126 | """ 127 | # cumulative sum of the match list 128 | cmc = orig_cmc.cumsum() 129 | # derive this query's Rank-N over the gallery from the cumulative sum 130 | cmc[cmc > 1] = 1 # clip the values greater than 1 to 1 131 | # append this query's CMC data to all_cmc; max_rank controls how many columns are kept per row (50 by default) 132 | all_cmc.append(cmc[:max_rank]) 133 | 134 | """ 135 | Compute the AP of each query 136 | """ 137 | # total number of correct matches of this query 138 | num_rel = orig_cmc.sum() 139 | # cumulative sum of the match list 140 | tmp_cmc = orig_cmc.cumsum() 141 | # precision at each match position 142 | tmp_cmc = [x / (i + 1.) for i, x in enumerate(tmp_cmc)] 143 | # set the precision at incorrect matches to 0 144 | tmp_cmc = np.asarray(tmp_cmc) * orig_cmc 145 | # compute the average precision 146 | AP = tmp_cmc.sum() / num_rel 147 | # append this query's AP to all_AP for the later mAP computation 148 | all_AP.append(AP) 149 | 150 | # count the queries valid for CMC/mAP computation, used to average the overall Rank-N 151 | num_valid_q += 1. 
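    # Worked example (illustrative numbers): for a query with orig_cmc = [0, 1, 0, 1],
    # the cumsum is [0, 1, 1, 2]; clipping to 1 yields the CMC row [0, 1, 1, 1]
    # (Rank-1 = 0, Rank-2 and beyond = 1). For the AP: the per-position precisions
    # are [0/1, 1/2, 1/3, 2/4]; masking by orig_cmc leaves [0, 0.5, 0, 0.5],
    # so with num_rel = 2 correct matches, AP = (0.5 + 0.5) / 2 = 0.5.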
152 | 153 | # if the number of queries valid for CMC computation is not positive, raise an error: no query appears in the gallery 154 | assert num_valid_q > 0, "Error: all query identities do not appear in gallery" 155 | all_cmc = np.asarray(all_cmc).astype(np.float32) # convert all_cmc to an np.array 156 | all_cmc = all_cmc.sum(0) / num_valid_q # sum the Rank-N of all valid queries column-wise and average, giving the Rank-N of the overall CMC curve 157 | mAP = np.mean(all_AP) # the mAP is the mean of the average precision of all valid queries 158 | 159 | return all_cmc, mAP 160 | -------------------------------------------------------------------------------- /GASNet/reid/evaluation_metrics/ranking.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from collections import defaultdict 3 | import pdb 4 | import numpy as np 5 | from sklearn.metrics import average_precision_score 6 | 7 | from ..utils import to_numpy 8 | 9 | 10 | def _unique_sample(ids_dict, num): 11 | mask = np.zeros(num, dtype=bool) 12 | for _, indices in ids_dict.items(): 13 | i = np.random.choice(indices) 14 | mask[i] = True 15 | return mask 16 | 17 | 18 | def cmc(distmat, query_ids=None, gallery_ids=None, 19 | query_cams=None, gallery_cams=None, topk=100, 20 | separate_camera_set=False, 21 | single_gallery_shot=False, 22 | first_match_break=False): 23 | distmat = to_numpy(distmat) 24 | m, n = distmat.shape 25 | # Fill up default values 26 | if query_ids is None: 27 | query_ids = np.arange(m) 28 | if gallery_ids is None: 29 | gallery_ids = np.arange(n) 30 | if query_cams is None: 31 | query_cams = np.zeros(m).astype(np.int32) 32 | if gallery_cams is None: 33 | gallery_cams = np.ones(n).astype(np.int32) 34 | # Ensure numpy array 35 | query_ids = np.asarray(query_ids) 36 | gallery_ids = np.asarray(gallery_ids) 37 | query_cams = np.asarray(query_cams) 38 | gallery_cams = np.asarray(gallery_cams) 39 | # Sort and find correct matches 40 | indices = np.argsort(distmat, axis=1) 41 | matches = (gallery_ids[indices] == query_ids[:, np.newaxis]) 42 | # Compute CMC for each query 43 | ret = np.zeros(topk) 44 | num_valid_queries = 0 45 | for i in range(m): 46 | # Filter out the same id and same camera 47 | valid = ((gallery_ids[indices[i]] != query_ids[i]) | 48 | (gallery_cams[indices[i]] != query_cams[i])) 49 | if separate_camera_set: 50 | # Filter out samples from same camera 51 | valid &= (gallery_cams[indices[i]] != query_cams[i]) 52 | if not np.any(matches[i, valid]): continue 53 | # if not np.any(matches[i]): continue 54 | if single_gallery_shot: 55 | repeat = 10 56 | # gids = gallery_ids[indices[i][valid]] 57 | # inds = np.where(valid)[0] 58 | gids = gallery_ids[indices[i]] 59 | inds = np.where(valid)[0] 60 | ids_dict = defaultdict(list) 61 | for j, x in zip(inds, gids): 62 | ids_dict[x].append(j) 63 | else: 64 | repeat = 1 65 | for _ in range(repeat): 66 | if single_gallery_shot: 67 | # Randomly choose one instance for each id 68 | sampled = (valid & _unique_sample(ids_dict, len(valid))) 69 | index = np.nonzero(matches[i, sampled])[0] 70 | else: 71 | index = np.nonzero(matches[i, valid])[0] 72 | delta = 1. 
/ (len(index) * repeat) 73 | for j, k in enumerate(index): 74 | if k - j >= topk: break 75 | if first_match_break: 76 | ret[k - j] += 1 77 | break 78 | ret[k - j] += delta 79 | num_valid_queries += 1 80 | if num_valid_queries == 0: 81 | raise RuntimeError("No valid query") 82 | return ret.cumsum() / num_valid_queries 83 | 84 | 85 | def mean_ap(distmat, query_ids=None, gallery_ids=None, 86 | query_cams=None, gallery_cams=None): 87 | distmat = to_numpy(distmat) 88 | m, n = distmat.shape 89 | # Fill up default values 90 | if query_ids is None: 91 | query_ids = np.arange(m) 92 | if gallery_ids is None: 93 | gallery_ids = np.arange(n) 94 | if query_cams is None: 95 | query_cams = np.zeros(m).astype(np.int32) 96 | if gallery_cams is None: 97 | gallery_cams = np.ones(n).astype(np.int32) 98 | # Ensure numpy array 99 | query_ids = np.asarray(query_ids) 100 | gallery_ids = np.asarray(gallery_ids) 101 | query_cams = np.asarray(query_cams) 102 | gallery_cams = np.asarray(gallery_cams) 103 | # Sort and find correct matches 104 | indices = np.argsort(distmat, axis=1) # sort ascending 105 | a = gallery_ids[indices] 106 | b = query_ids[:, np.newaxis] 107 | matches = (gallery_ids[indices] == query_ids[:, np.newaxis]) 108 | # matches = (query_ids[indices] == gallery_ids[:, np.newaxis]) 109 | # Compute AP for each query 110 | aps = [] 111 | for i in range(m): 112 | # Filter out the same id and same camera 113 | # valid = ((gallery_ids[indices[i]] != query_ids[i]) | 114 | # (gallery_cams[indices[i]] != query_cams[i])) 115 | valid = (gallery_ids[indices[i]] != query_ids[i]) 116 | # y_true = matches[i, valid] 117 | # y_score = -distmat[i][indices[i]][valid] 118 | y_true = matches[i] 119 | y_score = -distmat[i][indices[i]] 120 | if not np.any(y_true): # np.any is True if any element is True 121 | continue 122 | aps.append(average_precision_score(y_true, y_score)) 123 | # if len(aps) == 0: 124 | # raise RuntimeError("No valid query") 125 | return np.mean(aps) 126 | -------------------------------------------------------------------------------- /GASNet/reid/img_evaluators.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | import torchvision 6 | from torch.autograd import Variable 7 | from collections import OrderedDict, Iterable, defaultdict 8 | 9 | from .utils import to_torch 10 | from .utils.meters import AverageMeter 11 | from .evaluation_metrics import cmc, mean_ap 12 | from .evaluation_metrics.evaluate import evaluation 13 | 14 | import time 15 | import copy 16 | import numpy as np 17 | 18 | # ====================== 19 | # Extracting Features 20 | # ====================== 21 | def inference_feature(model, inputs, feat_type, modules=None, use_o_scale=True): 22 | model.eval() 23 | inputs = [to_torch(inputs)] 24 | 25 | ## Feature Inference 26 | if modules is None: 27 | model_out = model(inputs, training=False, use_o_scale=use_o_scale) # two values: raw features and normalized features 28 | if isinstance(feat_type, list): 29 | outputs = [] 30 | if use_o_scale: 31 | for i in range(len(feat_type)): 32 | # if feat_type[i] == 'feat_rga': # raw features from the global attention branch 33 | # outputs.append(model_out[0].data.cpu()) 34 | # elif feat_type[i] == 'feat_rga_': # normalized features 35 | # outputs.append(model_out[1].data.cpu()) 36 | # elif feat_type[i] == 'feat_osc': # raw features from the full-scale branch 37 | # outputs.append(model_out[2].data.cpu()) 38 | # elif feat_type[i] == 'feat_osc_': # normalized features 39 | # outputs.append(model_out[3].data.cpu()) 40 | if feat_type[i] == 'feat_gasnet': # raw features from the global attention branch 
41 | outputs.append(model_out[0].data.cpu()) 42 | elif feat_type[i] == 'feat_gasnet_': # normalized features 43 | outputs.append(model_out[1].data.cpu()) 44 | elif feat_type[i] == 'feat_cls': 45 | outputs.append(model_out[2].data.cpu()) 46 | else: 47 | raise ValueError("Cannot support this type of features: {}." 48 | .format(feat_type)) 49 | return outputs 50 | else: 51 | for i in range(len(feat_type)): 52 | if feat_type[i] == 'feat_rga': 53 | outputs.append(model_out[0].data.cpu()) 54 | elif feat_type[i] == 'feat_rga_': 55 | outputs.append(model_out[1].data.cpu()) 56 | elif feat_type[i] == 'feat_cls': 57 | outputs.append(model_out[2].data.cpu()) 58 | else: 59 | raise ValueError("Cannot support this type of features: {}." 60 | .format(feat_type)) 61 | return outputs 62 | elif isinstance(feat_type, str): 63 | if feat_type == 'feat_': 64 | outputs = model_out[0] 65 | elif feat_type == 'feat': 66 | outputs = model_out[1] 67 | else: 68 | raise ValueError("Cannot support this type of features: {}." 69 | .format(feat_type)) 70 | outputs = outputs.data.cpu() 71 | return outputs 72 | else: 73 | raise NameError 74 | 75 | ## Register forward hook for each module 76 | outputs = OrderedDict() 77 | handles = [] 78 | for m in modules: 79 | outputs[id(m)] = None 80 | def func(m, i, o): outputs[id(m)] = o.data.cpu() 81 | handles.append(m.register_forward_hook(func)) 82 | model(inputs) 83 | for h in handles: 84 | h.remove() 85 | return list(outputs.values()) 86 | 87 | def extract_features(model, data_loader, normlizer, flipper, to_pil, to_tensor, 88 | feat_type, uv_size=(32, 32), print_freq=1, metric=None, use_o_scale=True): 89 | model.eval() 90 | batch_time = AverageMeter() 91 | data_time = AverageMeter() 92 | 93 | if isinstance(feat_type, list): 94 | features = {} 95 | labels = {} 96 | for feat_name in feat_type: 97 | features[feat_name] = OrderedDict() 98 | labels[feat_name] = OrderedDict() 99 | elif isinstance(feat_type, str): 100 | features = OrderedDict() 101 | labels = OrderedDict() 102 | else: 103 | raise NameError 104 | 105 | end = time.time() 106 | for i, (imgs, fnames, pids, _) in enumerate(data_loader): 107 | data_time.update(time.time() - end) 108 | in_size = imgs.size() 109 | 110 | ## Extract features 111 | if flipper is not None: 112 | imgs_flip = copy.deepcopy(imgs) 113 | else: 114 | imgs_flip = None 115 | for j in range(in_size[0]): 116 | imgs[j, :, :, :] = normlizer(imgs[j, :, :, :]) 117 | if flipper is not None: 118 | imgs_flip[j, :, :, :] = to_tensor(flipper(to_pil(imgs_flip[j, :, :, :]))) 119 | imgs_flip[j, :, :, :] = normlizer(imgs_flip[j, :, :, :]) 120 | if flipper is not None: 121 | output_unflip = inference_feature(model, imgs, feat_type, use_o_scale=use_o_scale) 122 | output_flip = inference_feature(model, imgs_flip, feat_type, use_o_scale=use_o_scale) 123 | outputs = [] 124 | for jj in range(len(output_unflip)): 125 | outputs.append((output_unflip[jj] + output_flip[jj]) / 2) 126 | else: 127 | outputs = inference_feature(model, imgs, feat_type, use_o_scale=use_o_scale) 128 | 129 | ## Save Features 130 | if isinstance(feat_type, list): 131 | for ii, feat_name in enumerate(feat_type): 132 | for fname, output, pid in zip(fnames, outputs[ii], pids): 133 | features[feat_name][fname] = output 134 | labels[feat_name][fname] = pid 135 | elif isinstance(feat_type, str): 136 | for fname, output, pid in zip(fnames, outputs, pids): 137 | features[fname] = output 138 | labels[fname] = pid 139 | else: 140 | raise NameError 141 | 142 | batch_time.update(time.time() - end) 143 | end = time.time() 144 | 145 | 
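        # (Illustrative) resulting structure when feat_type is a list:
        # features['feat_gasnet']['123.jpg'] holds that image's feature vector and
        # labels['feat_gasnet']['123.jpg'] its vehicle ID; '123.jpg' is a hypothetical key.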
146 | 
147 | # if (i + 1) % print_freq == 0:
148 | # print('Extract Features: [{}/{}]\t'
149 | # 'Time {:.3f} ({:.3f})\t'
150 | # 'Data {:.3f} ({:.3f})\t'
151 | # .format(i + 1, len(data_loader),
152 | # batch_time.val, batch_time.avg,
153 | # data_time.val, data_time.avg))
154 | 
155 | return features, labels
156 | 
157 | 
158 | # =============
159 | # Evaluator
160 | # =============
161 | class ImgEvaluator(object):
162 | def __init__(self, model, file_path, flip_embedding=False, use_o_scale=True):
163 | super(ImgEvaluator, self).__init__()
164 | self.model = model
165 | self.file_path = file_path
166 | self.use_o_scale = use_o_scale
167 | self.normlizer = torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
168 | ## added for flipping embedding evaluation
169 | if flip_embedding:
170 | self.flipper = torchvision.transforms.RandomHorizontalFlip(p=1.0)
171 | print ('[Info] Flip Embedding is ENABLED in evaluation!')
172 | else:
173 | self.flipper = None
174 | print ('[Info] Flip Embedding is DISABLED in evaluation!')
175 | self.to_pil = torchvision.transforms.ToPILImage()
176 | self.to_tensor = torchvision.transforms.ToTensor()
177 | 
178 | def eval_worerank(self, query_loader, gallery_loader, query, gallery, metric,
179 | types_list, cmc_topk=(1, 5, 10)):
180 | print("----- Start extracting query features -----")
181 | t = time.time()
182 | query_features_list, _ = extract_features(self.model, query_loader, \
183 | self.normlizer, self.flipper, self.to_pil, self.to_tensor, types_list, use_o_scale=self.use_o_scale)
184 | print(f"Extracting the query features took {time.time() - t:.2f}s")
185 | print("----- Start extracting gallery features -----")
186 | t = time.time()
187 | gallery_features_list, _ = extract_features(self.model, gallery_loader, \
188 | self.normlizer, self.flipper, self.to_pil, self.to_tensor, types_list, use_o_scale=self.use_o_scale)
189 | print(f"Extracting the gallery features took {time.time() - t:.2f}s")
190 | query_features = {}
191 | gallery_features = {}
192 | for feat_name in types_list:
193 | x_q = torch.cat([query_features_list[feat_name][fname].unsqueeze(0) for fname, _, _ in query], 0)
194 | x_q = x_q.view(x_q.size(0), -1)
195 | query_features[feat_name] = x_q
196 | 
197 | x_g = torch.cat([gallery_features_list[feat_name][fname].unsqueeze(0) for fname, _, _ in gallery], 0)
198 | x_g = x_g.view(x_g.size(0), -1)
199 | gallery_features[feat_name] = x_g
200 | 
201 | query_ids = [pid for _, pid, _ in query]
202 | gallery_ids = [pid for _, pid, _ in gallery]
203 | query_cams = [cam for _, _, cam in query]
204 | gallery_cams = [cam for _, _, cam in gallery]
205 | 
206 | for feat_name in types_list:
207 | for dist_type in metric:
208 | print('Evaluated with "{}" features and "{}" metric:'.format(feat_name, dist_type))
209 | x = query_features[feat_name]
210 | y = gallery_features[feat_name]
211 | m, n = x.size(0), y.size(0)
212 | 
213 | # Calculate the distance matrix
214 | if dist_type == 'euclidean':
215 | t = time.time()
216 | dist = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) + \
217 | torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t()
218 | dist.addmm_(x, y.t(), beta=1, alpha=-2)
219 | print(f"Computing the euclidean distance took {time.time()-t:.2f}s")
220 | elif dist_type == 'cosine':
221 | t = time.time()
222 | x = F.normalize(x, p=2, dim=1)
223 | y = F.normalize(y, p=2, dim=1)
224 | dist = 1 - torch.mm(x, y.t()) # torch.mm yields the cosine similarity matrix; cosine distance = 1 - cosine similarity
225 | print(f"Computing the cosine distance took {time.time() - t:.2f}s")
226 | else:
227 | raise NameError
228 | 
229 | # Compute mean AP
230 | t = time.time()
231 | cmc, mAP = evaluation(dist, query_ids, gallery_ids, query_cams, gallery_cams, max_rank=10)
232 | print(f"Computing the CMC and mAP took {time.time() - t:.2f}s")
233 | print("CMC:", cmc)
234 | print("mAP:{:.3%}".format(mAP))
235 | # mAP = mean_ap(dist, query_ids, gallery_ids, query_cams, gallery_cams)
236 | # print('Mean AP: {:4.3%}'.format(mAP))
237 | 
238 | # # Compute CMC scores
239 | # cmc_configs = {
240 | # 'rank_results': dict(separate_camera_set=False,
241 | # single_gallery_shot=False,
242 | # first_match_break=True)}
243 | # cmc_scores = {name: cmc(dist, query_ids, gallery_ids,
244 | # query_cams, gallery_cams, **params)
245 | # for name, params in cmc_configs.items()}
246 | # 
247 | # print('CMC Scores')
248 | # for k in cmc_topk:
249 | # print(' top-{:<4}{:12.1%}'
250 | # .format(k, cmc_scores['rank_results'][k - 1]))
251 | return
-------------------------------------------------------------------------------- /GASNet/reid/img_trainers.py: --------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 | 
3 | import io
4 | import time
5 | import sys
6 | import os
7 | # import netron
8 | 
9 | import torch
10 | import torchvision
11 | import numpy as np
12 | from torch.autograd import Variable
13 | # from scipy import misc
14 | 
15 | from .evaluation_metrics import accuracy
16 | from .utils.meters import AverageMeter
17 | from .utils.data.transforms import RandomErasing
18 | 
19 | class BaseTrainer(object):
20 | def __init__(self, model, criterion, summary_writer, prob=0.5, use_o_scale=True, mean=[0.4914, 0.4822, 0.4465]):
21 | super(BaseTrainer, self).__init__()
22 | self.model = model
23 | self.criterion = criterion
24 | self.summary_writer = summary_writer
25 | self.use_o_scale = use_o_scale
26 | self.normlizer = torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
27 | self.eraser = RandomErasing(probability=prob, mean=[0., 0., 0.])
28 | 
29 | def train(self, epoch, data_loader, optimizer, random_erasing, empty_cache=False, print_freq=10):
30 | self.model.train()
31 | 
32 | batch_time = AverageMeter()
33 | data_time = AverageMeter()
34 | losses = AverageMeter()
35 | precisions = AverageMeter()
36 | 
37 | end = time.time()
38 | alloss = 0.0
39 | for i, inputs in enumerate(data_loader): # inputs[0]: image tensors, inputs[1]: file paths, inputs[2]: identity IDs, inputs[3]: camera IDs
40 | data_time.update(time.time() - end)
41 | 
42 | ori_inputs, targets = self._parse_data(inputs) # drops the camera IDs and keeps each image's vehicle ID
43 | in_size = inputs[0].size() # (B, C, H, W)
44 | for j in range(in_size[0]):
45 | ori_inputs[0][j, :, :, :] = self.normlizer(ori_inputs[0][j, :, :, :])
46 | if random_erasing:
47 | ori_inputs[0][j, :, :, :] = self.eraser(ori_inputs[0][j, :, :, :])
48 | loss, all_loss, prec1 = self._forward(ori_inputs, targets)
49 | 
50 | losses.update(loss.data, targets.size(0))
51 | precisions.update(prec1, targets.size(0))
52 | 
53 | # tensorboard
54 | alloss += loss.item()
55 | if self.summary_writer is not None:
56 | if i % 1000 == 0 :
57 | global_step = epoch * len(data_loader) + i
58 | self.summary_writer.add_scalar('loss', alloss / 1000, global_step)
59 | self.summary_writer.add_scalar('rga_loss_cls', all_loss[0], global_step)
60 | self.summary_writer.add_scalar('rga_loss_tri', all_loss[1], global_step)
61 | self.summary_writer.add_scalar('oscale_loss_cls', all_loss[2], global_step)
62 | self.summary_writer.add_scalar('oscale_loss_tri', all_loss[3], global_step)
63 | alloss = 0.0
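As a hedged sanity check on the expanded Euclidean distance used in `eval_worerank` above, the identity ||x - y||^2 = ||x||^2 + ||y||^2 - 2 x y^T can be verified against `torch.cdist` on random data:

```python
import torch

x, y = torch.randn(4, 8), torch.randn(5, 8)
m, n = x.size(0), y.size(0)
# Squared norms broadcast into an (m, n) grid, then -2 * x @ y.t() via addmm_.
dist = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) + \
       torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t()
dist.addmm_(x, y.t(), beta=1, alpha=-2)
assert torch.allclose(dist.clamp(min=0).sqrt(), torch.cdist(x, y), atol=1e-5)
```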
64 | optimizer.zero_grad()
65 | loss.backward()
66 | optimizer.step()
67 | if empty_cache:
68 | torch.cuda.empty_cache()
69 | 
70 | batch_time.update(time.time() - end)
71 | end = time.time()
72 | 
73 | if (i + 1) % print_freq == 0:
74 | print('Epoch: [{}][{}/{}]\t'
75 | 'Total_Loss: {:.5f} rga_cls: {:.5f} rga_tri: {:.5f} oscale_cls: {:.5f} oscale_tri: {:.5f}\t'
76 | 'Prec {:.5%} ({:.5%})\t'
77 | .format(epoch + 1, i + 1, len(data_loader),
78 | loss, all_loss[0], all_loss[1], all_loss[2], all_loss[3],
79 | precisions.val, precisions.avg))
80 | 
81 | 
82 | def _parse_data(self, inputs):
83 | raise NotImplementedError
84 | 
85 | def _forward(self, inputs, targets):
86 | raise NotImplementedError
87 | 
88 | 
89 | class ImgTrainer(BaseTrainer):
90 | def _parse_data(self, inputs):
91 | imgs, _, pids, _ = inputs
92 | inputs = [Variable(imgs)]
93 | targets = Variable(pids.cuda())
94 | return inputs, targets
95 | 
96 | def _forward(self, inputs, targets): # inputs [b, c, h, w]
97 | # a = inputs
98 | # b = targets
99 | 
100 | # six values are returned here
101 | outputs = self.model(inputs, training=True, use_o_scale=self.use_o_scale) # outputs[0]: features for the triplet loss, outputs[1]: unknown, outputs[2]: class logits
102 | 
103 | # rga_loss
104 | rga_loss_cls = self.criterion[0](outputs[2], targets)
105 | rga_loss_tri = self.criterion[1](outputs[0], targets)
106 | 
107 | # oscale_loss
108 | if self.use_o_scale:
109 | oscale_loss_cls = self.criterion[0](outputs[5], targets)
110 | oscale_loss_tri = self.criterion[1](outputs[3], targets)
111 | else:
112 | oscale_loss_cls = 0
113 | oscale_loss_tri = 0
114 | 
115 | loss = rga_loss_cls + rga_loss_tri + oscale_loss_cls + oscale_loss_tri
116 | 
117 | losses = [rga_loss_cls, rga_loss_tri, oscale_loss_cls, oscale_loss_tri]
118 | 
119 | if self.use_o_scale:
120 | prec, = accuracy(outputs[2].data + outputs[5].data, targets.data) # add the class logits of the two branches
121 | else:
122 | prec, = accuracy(outputs[2].data, targets.data)
123 | prec = prec[0]
124 | return loss, losses, prec
125 | 
126 | 
-------------------------------------------------------------------------------- /GASNet/reid/loss/__init__.py: --------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | 
-------------------------------------------------------------------------------- /GASNet/reid/loss/__pycache__/__init__.cpython-36.pyc: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/GeoX-Lab/ReID/a4871c799c4a45a3eebb165bc4cda85c431c0253/GASNet/reid/loss/__pycache__/__init__.cpython-36.pyc
-------------------------------------------------------------------------------- /GASNet/reid/loss/__pycache__/__init__.cpython-38.pyc: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/GeoX-Lab/ReID/a4871c799c4a45a3eebb165bc4cda85c431c0253/GASNet/reid/loss/__pycache__/__init__.cpython-38.pyc
-------------------------------------------------------------------------------- /GASNet/reid/loss/__pycache__/loss_set.cpython-36.pyc: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/GeoX-Lab/ReID/a4871c799c4a45a3eebb165bc4cda85c431c0253/GASNet/reid/loss/__pycache__/loss_set.cpython-36.pyc
-------------------------------------------------------------------------------- /GASNet/reid/loss/__pycache__/loss_set.cpython-38.pyc: --------------------------------------------------------------------------------
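Ahead of the loss definitions in `loss_set.py` below, here is a hedged sketch of how `ImgTrainer._forward` above combines the four terms; the P x K batch layout (4 IDs, 2 images each) and the 0.3 margin are illustrative assumptions, not values taken from the paper:

```python
import torch
from reid.loss.loss_set import CrossEntropyLabelSmoothLoss, TripletHardLoss

num_ids, k = 4, 2                                     # P x K sampling: 4 ids, 2 images each
targets = torch.arange(num_ids).repeat_interleave(k)  # [0, 0, 1, 1, 2, 2, 3, 3]
cls_rga, cls_osc = torch.randn(8, num_ids), torch.randn(8, num_ids)
feat_rga, feat_osc = torch.randn(8, 2048), torch.randn(8, 2048)

ce = CrossEntropyLabelSmoothLoss(num_classes=num_ids, use_gpu=False)
tri = TripletHardLoss(margin=0.3)                     # margin value assumed
loss = ce(cls_rga, targets) + tri(feat_rga, targets) \
     + ce(cls_osc, targets) + tri(feat_osc, targets)
print(loss.item())
```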
https://raw.githubusercontent.com/GeoX-Lab/ReID/a4871c799c4a45a3eebb165bc4cda85c431c0253/GASNet/reid/loss/__pycache__/loss_set.cpython-38.pyc
-------------------------------------------------------------------------------- /GASNet/reid/loss/loss_set.py: --------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 | from __future__ import unicode_literals
5 | 
6 | import torch
7 | from torch import nn
8 | from torch.autograd import Variable
9 | import torch.nn.functional as F
10 | 
11 | import numpy as np
12 | 
13 | 
14 | def normalize(x, axis=-1):
15 | """Normalizing to unit length along the specified dimension.
16 | Args:
17 | x: pytorch Variable
18 | Returns:
19 | x: pytorch Variable, same shape as input
20 | """
21 | x = 1. * x / (torch.norm(x, 2, axis, keepdim=True).expand_as(x) + 1e-12)
22 | return x
23 | 
24 | def euclidean_dist(x, y):
25 | """
26 | Args:
27 | x: pytorch Variable, with shape [m, d]
28 | y: pytorch Variable, with shape [n, d]
29 | Returns:
30 | dist: pytorch Variable, with shape [m, n]
31 | """
32 | m, n = x.size(0), y.size(0)
33 | xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n)
34 | yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t()
35 | dist = xx + yy
36 | dist.addmm_(x, y.t(), beta=1, alpha=-2)
37 | dist = dist.clamp(min=1e-12).sqrt() # for numerical stability
38 | return dist
39 | 
40 | def cosine_dist(x, y):
41 | """
42 | Args:
43 | x: pytorch Variable, with shape [m, d]
44 | y: pytorch Variable, with shape [n, d]
45 | """
46 | x_normed = F.normalize(x, p=2, dim=1)
47 | y_normed = F.normalize(y, p=2, dim=1)
48 | return 1 - torch.mm(x_normed, y_normed.t())
49 | 
50 | def cosine_similarity(x, y):
51 | """
52 | Args:
53 | x: pytorch Variable, with shape [m, d]
54 | y: pytorch Variable, with shape [n, d]
55 | """
56 | x_normed = F.normalize(x, p=2, dim=1)
57 | y_normed = F.normalize(y, p=2, dim=1)
58 | return torch.mm(x_normed, y_normed.t())
59 | 
60 | 
61 | def hard_example_mining(dist_mat, labels, return_inds=False):
62 | """For each anchor, find the hardest positive and negative sample.
63 | Args:
64 | dist_mat: pytorch Variable, pair wise distance between samples, shape [N, N]
65 | labels: pytorch LongTensor, with shape [N]
66 | return_inds: whether to return the indices. Save time if `False`(?)
67 | Returns:
68 | dist_ap: pytorch Variable, distance(anchor, positive); shape [N]
69 | dist_an: pytorch Variable, distance(anchor, negative); shape [N]
70 | p_inds: pytorch LongTensor, with shape [N];
71 | indices of selected hard positive samples; 0 <= p_inds[i] <= N - 1
72 | n_inds: pytorch LongTensor, with shape [N];
73 | indices of selected hard negative samples; 0 <= n_inds[i] <= N - 1
74 | NOTE: Only consider the case in which all labels have same num of samples,
75 | thus we can cope with all anchors in parallel.
76 | """ 77 | assert len(dist_mat.size()) == 2 78 | assert dist_mat.size(0) == dist_mat.size(1) 79 | N = dist_mat.size(0) 80 | 81 | is_pos = labels.expand(N, N).eq(labels.expand(N, N).t()) 82 | is_neg = labels.expand(N, N).ne(labels.expand(N, N).t()) 83 | 84 | dist_ap, relative_p_inds = torch.max( 85 | dist_mat[is_pos].contiguous().view(N, -1), 1, keepdim=True) 86 | dist_an, relative_n_inds = torch.min( 87 | dist_mat[is_neg].contiguous().view(N, -1), 1, keepdim=True) 88 | 89 | dist_ap = dist_ap.squeeze(1) 90 | dist_an = dist_an.squeeze(1) 91 | 92 | if return_inds: 93 | ind = (labels.new().resize_as_(labels) 94 | .copy_(torch.arange(0, N).long()) 95 | .unsqueeze(0).expand(N, N)) 96 | p_inds = torch.gather( 97 | ind[is_pos].contiguous().view(N, -1), 1, relative_p_inds.data) 98 | n_inds = torch.gather( 99 | ind[is_neg].contiguous().view(N, -1), 1, relative_n_inds.data) 100 | p_inds = p_inds.squeeze(1) 101 | n_inds = n_inds.squeeze(1) 102 | return dist_ap, dist_an, p_inds, n_inds 103 | 104 | return dist_ap, dist_an 105 | 106 | 107 | # ============== 108 | # Triplet Loss 109 | # ============== 110 | class TripletHardLoss(object): 111 | """Modified from Tong Xiao's open-reid (https://github.com/Cysu/open-reid). 112 | Related Triplet Loss theory can be found in paper 'In Defense of the Triplet 113 | Loss for Person Re-Identification'.""" 114 | def __init__(self, margin=None, metric="euclidean"): 115 | self.margin = margin 116 | self.metric = metric 117 | if margin is not None: 118 | self.ranking_loss = nn.MarginRankingLoss(margin=margin) 119 | else: 120 | self.ranking_loss = nn.SoftMarginLoss() 121 | 122 | def __call__(self, global_feat, labels, normalize_feature=False): 123 | if normalize_feature: 124 | global_feat = normalize(global_feat, axis=-1) 125 | 126 | if self.metric == "euclidean": 127 | dist_mat = euclidean_dist(global_feat, global_feat) 128 | elif self.metric == "cosine": 129 | dist_mat = cosine_dist(global_feat, global_feat) 130 | else: 131 | raise NameError 132 | 133 | dist_ap, dist_an = hard_example_mining( 134 | dist_mat, labels) 135 | y = dist_an.new().resize_as_(dist_an).fill_(1) 136 | 137 | if self.margin is not None: 138 | loss = self.ranking_loss(dist_an, dist_ap, y) 139 | else: 140 | loss = self.ranking_loss(dist_an - dist_ap, y) 141 | prec = (dist_an.data > dist_ap.data).sum() * 1. / y.size(0) 142 | return loss 143 | 144 | # ====================== 145 | # Classification Loss 146 | # ====================== 147 | class CrossEntropyLabelSmoothLoss(nn.Module): 148 | """Cross entropy loss with label smoothing regularizer. 149 | Reference: 150 | Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016. 151 | Equation: y = (1 - epsilon) * y + epsilon / K. 152 | Args: 153 | num_classes (int): number of classes. 154 | epsilon (float): weight. 
155 | """ 156 | def __init__(self, num_classes, epsilon=0.1, use_gpu=True): 157 | super(CrossEntropyLabelSmoothLoss, self).__init__() 158 | self.num_classes = num_classes 159 | self.epsilon = epsilon 160 | self.use_gpu = use_gpu 161 | self.logsoftmax = nn.LogSoftmax(dim=1) 162 | 163 | def forward(self, inputs, targets): 164 | """ 165 | Args: 166 | inputs: prediction matrix (before softmax) with shape (batch_size, num_classes) 167 | targets: ground truth labels with shape (num_classes) 168 | """ 169 | log_probs = self.logsoftmax(inputs) 170 | targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).cpu(), 1) 171 | if self.use_gpu: 172 | targets = targets.cuda() 173 | targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes 174 | loss = (- targets * log_probs).mean(0).sum() 175 | return loss 176 | -------------------------------------------------------------------------------- /GASNet/reid/models/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .rga_model import * 4 | from .rga_model import ResNet50_RGA_Model 5 | 6 | __factory = { 7 | 'resnet50_rga': resnet50_rga, 8 | } 9 | 10 | 11 | def names(): 12 | return sorted(__factory.keys()) 13 | 14 | def create(name, *args, **kwargs): 15 | # print("args:",args) 16 | # print("**kwargs:",kwargs) 17 | if name not in __factory: 18 | raise KeyError("Unknown model:", name) 19 | return __factory[name](*args, **kwargs) # *args: 创建一个元组 **kwargs:创建一个字典 20 | # return ResNet50_RGA_Model(*args, **kwargs) # *args: 创建一个元组 **kwargs:创建一个字典 21 | -------------------------------------------------------------------------------- /GASNet/reid/models/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GeoX-Lab/ReID/a4871c799c4a45a3eebb165bc4cda85c431c0253/GASNet/reid/models/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /GASNet/reid/models/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GeoX-Lab/ReID/a4871c799c4a45a3eebb165bc4cda85c431c0253/GASNet/reid/models/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /GASNet/reid/models/__pycache__/rga_model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GeoX-Lab/ReID/a4871c799c4a45a3eebb165bc4cda85c431c0253/GASNet/reid/models/__pycache__/rga_model.cpython-36.pyc -------------------------------------------------------------------------------- /GASNet/reid/models/__pycache__/rga_model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GeoX-Lab/ReID/a4871c799c4a45a3eebb165bc4cda85c431c0253/GASNet/reid/models/__pycache__/rga_model.cpython-38.pyc -------------------------------------------------------------------------------- /GASNet/reid/models/models_utils/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import -------------------------------------------------------------------------------- /GASNet/reid/models/models_utils/__pycache__/__init__.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/GeoX-Lab/ReID/a4871c799c4a45a3eebb165bc4cda85c431c0253/GASNet/reid/models/models_utils/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /GASNet/reid/models/models_utils/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GeoX-Lab/ReID/a4871c799c4a45a3eebb165bc4cda85c431c0253/GASNet/reid/models/models_utils/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /GASNet/reid/models/models_utils/__pycache__/os_modules.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GeoX-Lab/ReID/a4871c799c4a45a3eebb165bc4cda85c431c0253/GASNet/reid/models/models_utils/__pycache__/os_modules.cpython-36.pyc -------------------------------------------------------------------------------- /GASNet/reid/models/models_utils/__pycache__/os_modules.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GeoX-Lab/ReID/a4871c799c4a45a3eebb165bc4cda85c431c0253/GASNet/reid/models/models_utils/__pycache__/os_modules.cpython-38.pyc -------------------------------------------------------------------------------- /GASNet/reid/models/models_utils/__pycache__/rga_branches.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GeoX-Lab/ReID/a4871c799c4a45a3eebb165bc4cda85c431c0253/GASNet/reid/models/models_utils/__pycache__/rga_branches.cpython-36.pyc -------------------------------------------------------------------------------- /GASNet/reid/models/models_utils/__pycache__/rga_branches.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GeoX-Lab/ReID/a4871c799c4a45a3eebb165bc4cda85c431c0253/GASNet/reid/models/models_utils/__pycache__/rga_branches.cpython-38.pyc -------------------------------------------------------------------------------- /GASNet/reid/models/models_utils/__pycache__/rga_modules.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GeoX-Lab/ReID/a4871c799c4a45a3eebb165bc4cda85c431c0253/GASNet/reid/models/models_utils/__pycache__/rga_modules.cpython-36.pyc -------------------------------------------------------------------------------- /GASNet/reid/models/models_utils/__pycache__/rga_modules.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GeoX-Lab/ReID/a4871c799c4a45a3eebb165bc4cda85c431c0253/GASNet/reid/models/models_utils/__pycache__/rga_modules.cpython-38.pyc -------------------------------------------------------------------------------- /GASNet/reid/models/models_utils/os_modules.py: -------------------------------------------------------------------------------- 1 | from __future__ import division, absolute_import 2 | import warnings 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | ########## 8 | # Basic layers 9 | ########## 10 | class ConvLayer(nn.Module): 11 | """Convolution layer (conv + bn + relu).""" 12 | 13 | def __init__( 14 | self, 15 | in_channels, 16 
| out_channels, 17 | kernel_size, 18 | stride=1, 19 | padding=0, 20 | groups=1, 21 | IN=False 22 | ): 23 | super(ConvLayer, self).__init__() 24 | self.conv = nn.Conv2d( 25 | in_channels, 26 | out_channels, 27 | kernel_size, 28 | stride=stride, 29 | padding=padding, 30 | bias=False, 31 | groups=groups 32 | ) 33 | if IN: 34 | self.bn = nn.InstanceNorm2d(out_channels, affine=True) 35 | else: 36 | self.bn = nn.BatchNorm2d(out_channels) 37 | self.relu = nn.ReLU(inplace=True) 38 | 39 | def forward(self, x): 40 | x = self.conv(x) 41 | x = self.bn(x) 42 | x = self.relu(x) 43 | return x 44 | 45 | 46 | class Conv1x1(nn.Module): 47 | """1x1 convolution + bn + relu.""" 48 | 49 | def __init__(self, in_channels, out_channels, stride=1, groups=1): 50 | super(Conv1x1, self).__init__() 51 | self.conv = nn.Conv2d( 52 | in_channels, 53 | out_channels, 54 | 1, 55 | stride=stride, 56 | padding=0, 57 | bias=False, 58 | groups=groups 59 | ) 60 | self.bn = nn.BatchNorm2d(out_channels) 61 | self.relu = nn.ReLU(inplace=True) 62 | 63 | def forward(self, x): 64 | x = self.conv(x) 65 | x = self.bn(x) 66 | x = self.relu(x) 67 | return x 68 | 69 | 70 | class Conv1x1Linear(nn.Module): 71 | """1x1 convolution + bn (w/o non-linearity).""" 72 | 73 | def __init__(self, in_channels, out_channels, stride=1): 74 | super(Conv1x1Linear, self).__init__() 75 | self.conv = nn.Conv2d( 76 | in_channels, out_channels, 1, stride=stride, padding=0, bias=False 77 | ) 78 | self.bn = nn.BatchNorm2d(out_channels) 79 | 80 | def forward(self, x): 81 | x = self.conv(x) 82 | x = self.bn(x) 83 | return x 84 | 85 | 86 | class Conv3x3(nn.Module): 87 | """3x3 convolution + bn + relu.""" 88 | 89 | def __init__(self, in_channels, out_channels, stride=1, groups=1): 90 | super(Conv3x3, self).__init__() 91 | self.conv = nn.Conv2d( 92 | in_channels, 93 | out_channels, 94 | 3, 95 | stride=stride, 96 | padding=1, 97 | bias=False, 98 | groups=groups 99 | ) 100 | self.bn = nn.BatchNorm2d(out_channels) 101 | self.relu = nn.ReLU(inplace=True) 102 | 103 | def forward(self, x): 104 | x = self.conv(x) 105 | x = self.bn(x) 106 | x = self.relu(x) 107 | return x 108 | 109 | 110 | class LightConv3x3(nn.Module): 111 | """Lightweight 3x3 convolution. 112 | 113 | 1x1 (linear) + dw 3x3 (nonlinear). 
114 | """ 115 | 116 | def __init__(self, in_channels, out_channels): 117 | super(LightConv3x3, self).__init__() 118 | self.conv1 = nn.Conv2d( 119 | in_channels, out_channels, 1, stride=1, padding=0, bias=False 120 | ) 121 | self.conv2 = nn.Conv2d( 122 | out_channels, 123 | out_channels, 124 | 3, 125 | stride=1, 126 | padding=1, 127 | bias=False, 128 | groups=out_channels 129 | ) 130 | self.bn = nn.BatchNorm2d(out_channels) 131 | self.relu = nn.ReLU(inplace=True) 132 | 133 | def forward(self, x): 134 | x = self.conv1(x) 135 | x = self.conv2(x) 136 | x = self.bn(x) 137 | x = self.relu(x) 138 | return x 139 | 140 | 141 | ########## 142 | # Building blocks for omni-scale feature learning 143 | ########## 144 | class ChannelGate(nn.Module): 145 | """A mini-network that generates channel-wise gates conditioned on input tensor.""" 146 | 147 | def __init__( 148 | self, 149 | in_channels, 150 | num_gates=None, 151 | return_gates=False, 152 | gate_activation='sigmoid', 153 | reduction=16, 154 | layer_norm=False 155 | ): 156 | super(ChannelGate, self).__init__() 157 | if num_gates is None: 158 | num_gates = in_channels 159 | self.return_gates = return_gates 160 | self.global_avgpool = nn.AdaptiveAvgPool2d(1) 161 | self.fc1 = nn.Conv2d( 162 | in_channels, 163 | in_channels // reduction, 164 | kernel_size=1, 165 | bias=True, 166 | padding=0 167 | ) 168 | self.norm1 = None 169 | if layer_norm: 170 | self.norm1 = nn.LayerNorm((in_channels // reduction, 1, 1)) 171 | self.relu = nn.ReLU(inplace=True) 172 | self.fc2 = nn.Conv2d( 173 | in_channels // reduction, 174 | num_gates, 175 | kernel_size=1, 176 | bias=True, 177 | padding=0 178 | ) 179 | if gate_activation == 'sigmoid': 180 | self.gate_activation = nn.Sigmoid() 181 | elif gate_activation == 'relu': 182 | self.gate_activation = nn.ReLU(inplace=True) 183 | elif gate_activation == 'linear': 184 | self.gate_activation = None 185 | else: 186 | raise RuntimeError( 187 | "Unknown gate activation: {}".format(gate_activation) 188 | ) 189 | 190 | def forward(self, x): 191 | input = x 192 | x = self.global_avgpool(x) 193 | x = self.fc1(x) 194 | if self.norm1 is not None: 195 | x = self.norm1(x) 196 | x = self.relu(x) 197 | x = self.fc2(x) 198 | if self.gate_activation is not None: 199 | x = self.gate_activation(x) 200 | if self.return_gates: 201 | return x 202 | return input * x 203 | 204 | class OSBlock(nn.Module): 205 | """Omni-scale feature learning block.""" 206 | 207 | def __init__( 208 | self, 209 | in_channels, 210 | out_channels, 211 | IN=False, 212 | bottleneck_reduction=4, 213 | **kwargs 214 | ): 215 | super(OSBlock, self).__init__() 216 | mid_channels = out_channels // bottleneck_reduction 217 | self.conv1 = Conv1x1(in_channels, mid_channels) 218 | self.conv2a = LightConv3x3(mid_channels, mid_channels) 219 | self.conv2b = nn.Sequential( 220 | LightConv3x3(mid_channels, mid_channels), 221 | LightConv3x3(mid_channels, mid_channels), 222 | ) 223 | self.conv2c = nn.Sequential( 224 | LightConv3x3(mid_channels, mid_channels), 225 | LightConv3x3(mid_channels, mid_channels), 226 | LightConv3x3(mid_channels, mid_channels), 227 | ) 228 | self.conv2d = nn.Sequential( 229 | LightConv3x3(mid_channels, mid_channels), 230 | LightConv3x3(mid_channels, mid_channels), 231 | LightConv3x3(mid_channels, mid_channels), 232 | LightConv3x3(mid_channels, mid_channels), 233 | ) 234 | self.gate = ChannelGate(mid_channels) 235 | self.conv3 = Conv1x1Linear(mid_channels, out_channels) 236 | self.downsample = None 237 | if in_channels != out_channels: 238 | self.downsample = 
Conv1x1Linear(in_channels, out_channels) 239 | self.IN = None 240 | if IN: 241 | self.IN = nn.InstanceNorm2d(out_channels, affine=True) 242 | 243 | def forward(self, x): 244 | identity = x 245 | x1 = self.conv1(x) 246 | x2a = self.conv2a(x1) 247 | x2b = self.conv2b(x1) 248 | x2c = self.conv2c(x1) 249 | x2d = self.conv2d(x1) 250 | x2 = self.gate(x2a) + self.gate(x2b) + self.gate(x2c) + self.gate(x2d) 251 | x3 = self.conv3(x2) 252 | if self.downsample is not None: 253 | identity = self.downsample(identity) 254 | out = x3 + identity 255 | if self.IN is not None: 256 | out = self.IN(out) 257 | return F.relu(out) -------------------------------------------------------------------------------- /GASNet/reid/models/models_utils/rga_branches.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 3 | 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | from __future__ import unicode_literals 8 | 9 | import math 10 | import sys 11 | import os 12 | sys.path.append(os.path.dirname(__file__)) 13 | 14 | import torch 15 | import torch as th 16 | from torch import nn 17 | from torch.autograd import Variable 18 | 19 | from rga_modules import RGA_Module 20 | from os_modules import OSBlock 21 | 22 | # WEIGHT_PATH = os.path.join(os.path.dirname(__file__), '../..')+'/weights/pre_train/resnet50-19c8e357.pth' 23 | WEIGHT_PATH = os.path.join(os.path.dirname(__file__), '../..')+'/checkpoint/chechpoint_0.pth' 24 | 25 | def weights_init_kaiming(m): 26 | classname = m.__class__.__name__ 27 | if classname.find('Linear') != -1: 28 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out') 29 | nn.init.constant_(m.bias, 0.0) 30 | elif classname.find('Conv') != -1: 31 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in') 32 | if m.bias is not None: 33 | nn.init.constant_(m.bias, 0.0) 34 | elif classname.find('BatchNorm') != -1: 35 | if m.affine: 36 | nn.init.normal_(m.weight, 1.0, 0.02) 37 | nn.init.constant_(m.bias, 0.0) 38 | 39 | 40 | def weights_init_fc(m): 41 | classname = m.__class__.__name__ 42 | if classname.find('Linear') != -1: 43 | nn.init.normal_(m.weight, std=0.001) 44 | nn.init.constant_(m.bias, 0.0) 45 | elif classname.find('BatchNorm') != -1: 46 | if m.affine: 47 | nn.init.normal_(m.weight, 1.0, 0.02) 48 | nn.init.constant_(m.bias, 0.0) 49 | 50 | 51 | class Bottleneck(nn.Module): 52 | expansion = 4 53 | 54 | def __init__(self, in_channels, out_channels, stride=1, downsample=None): 55 | super(Bottleneck, self).__init__() 56 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False) 57 | self.bn1 = nn.BatchNorm2d(out_channels) 58 | self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, 59 | padding=1, bias=False) 60 | self.bn2 = nn.BatchNorm2d(out_channels) 61 | self.conv3 = nn.Conv2d(out_channels, out_channels * 4, kernel_size=1, bias=False) 62 | self.bn3 = nn.BatchNorm2d(out_channels * 4) 63 | self.relu = nn.ReLU(inplace=True) 64 | self.downsample = downsample 65 | self.stride = stride 66 | 67 | def forward(self, x): 68 | residual = x 69 | 70 | out = self.conv1(x) 71 | out = self.bn1(out) 72 | out = self.relu(out) 73 | 74 | out = self.conv2(out) 75 | out = self.bn2(out) 76 | out = self.relu(out) 77 | 78 | out = self.conv3(out) 79 | out = self.bn3(out) 80 | 81 | if self.downsample is not None: 82 | residual = self.downsample(x) 83 | 84 | out += residual 85 | out = self.relu(out) 86 | 
87 | return out 88 | 89 | 90 | class RGA_Branch(nn.Module): 91 | def __init__(self, pretrained=True, last_stride=1, block=Bottleneck, layers=[3, 4, 6, 3], 92 | spa_on=True, cha_on=True, s_ratio=8, c_ratio=8, d_ratio=8, height=256, width=128, 93 | model_path=WEIGHT_PATH, use_o_scale=True): 94 | super(RGA_Branch, self).__init__() 95 | 96 | print('Use_Spatial_Att: {};\tUse_Channel_Att: {};\tUse_O_Scale: {}.'.format(spa_on, cha_on, use_o_scale)) 97 | 98 | self.in_channels = 64 99 | self.use_o_scale = use_o_scale 100 | 101 | # Networks 102 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 103 | self.bn1 = nn.BatchNorm2d(64) 104 | self.relu = nn.ReLU(inplace=True) 105 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 106 | self.layer1 = self._make_rga_layer(block, 64, layers[0]) 107 | self.layer2 = self._make_rga_layer(block, 128, layers[1], stride=2) 108 | self.layer3 = self._make_rga_layer(block, 256, layers[2], stride=2) 109 | self.layer4 = self._make_rga_layer(block, 512, layers[3], stride=last_stride) 110 | 111 | 112 | # RGA Modules 113 | self.rga_att1 = RGA_Module(256, (height//4)*(width//4), use_spatial=spa_on, use_channel=cha_on, 114 | cha_ratio=c_ratio, spa_ratio=s_ratio, down_ratio=d_ratio) 115 | self.rga_att2 = RGA_Module(512, (height//8)*(width//8), use_spatial=spa_on, use_channel=cha_on, 116 | cha_ratio=c_ratio, spa_ratio=s_ratio, down_ratio=d_ratio) 117 | self.rga_att3 = RGA_Module(1024, (height//16)*(width//16), use_spatial=spa_on, use_channel=cha_on, 118 | cha_ratio=c_ratio, spa_ratio=s_ratio, down_ratio=d_ratio) 119 | self.rga_att4 = RGA_Module(2048, (height//16)*(width//16), use_spatial=spa_on, use_channel=cha_on, 120 | cha_ratio=c_ratio, spa_ratio=s_ratio, down_ratio=d_ratio) 121 | 122 | # OS Modules 123 | if use_o_scale: 124 | self.o_scale1 = OSBlock(in_channels=1024, out_channels=1024, IN=True) 125 | self.o_scale2 = OSBlock(in_channels=1024, out_channels=2048, IN=True) 126 | 127 | # Load the pre-trained model weights 128 | if pretrained: 129 | self.load_specific_param(self.conv1.state_dict(), 'conv1', model_path) 130 | self.load_specific_param(self.bn1.state_dict(), 'bn1', model_path) 131 | self.load_partial_param(self.layer1.state_dict(), 1, model_path) 132 | self.load_partial_param(self.layer2.state_dict(), 2, model_path) 133 | self.load_partial_param(self.layer3.state_dict(), 3, model_path) 134 | self.load_partial_param(self.layer4.state_dict(), 4, model_path) 135 | 136 | def _make_rga_layer(self, block, channels, blocks, stride=1): 137 | downsample = None 138 | if stride != 1 or self.in_channels != channels * block.expansion: 139 | downsample = nn.Sequential( 140 | nn.Conv2d(self.in_channels, channels * block.expansion, 141 | kernel_size=1, stride=stride, bias=False), 142 | nn.BatchNorm2d(channels * block.expansion), 143 | ) 144 | 145 | layers = [] 146 | layers.append(block(self.in_channels, channels, stride, downsample)) 147 | self.in_channels = channels * block.expansion 148 | for i in range(1, blocks): 149 | layers.append(block(self.in_channels, channels)) 150 | 151 | return nn.Sequential(*layers) 152 | 153 | def load_partial_param(self, state_dict, model_index, model_path): 154 | param_dict = torch.load(model_path) 155 | for i in state_dict: 156 | key = 'layer{}.'.format(model_index)+i 157 | if 'num_batches_tracked' in key: 158 | continue 159 | state_dict[i].copy_(param_dict[key]) 160 | del param_dict 161 | 162 | def load_specific_param(self, state_dict, param_name, model_path): 163 | param_dict = torch.load(model_path) 
164 | for i in state_dict:
165 | key = param_name + '.' + i
166 | if 'num_batches_tracked' in key:
167 | continue
168 | state_dict[i].copy_(param_dict[key])
169 | del param_dict
170 | 
171 | def forward(self, x):
172 | x = self.conv1(x)
173 | x = self.bn1(x)
174 | x = self.relu(x)
175 | x = self.maxpool(x)
176 | x = self.layer1(x)
177 | x = self.rga_att1(x)
178 | 
179 | x = self.layer2(x)
180 | x = self.rga_att2(x)
181 | 
182 | x = self.layer3(x)
183 | if self.use_o_scale:
184 | y = self.o_scale1(x) # insert the os module
185 | y = self.o_scale2(y)
186 | x = self.rga_att3(x) # the rga module
187 | 
188 | x = self.layer4(x)
189 | x = self.rga_att4(x)
190 | 
191 | if self.use_o_scale:
192 | return x, y
193 | else:
194 | return x
195 | 
-------------------------------------------------------------------------------- /GASNet/reid/models/models_utils/rga_modules.py: --------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 | 
4 | import torch
5 | from torch import nn
6 | from torch.nn import functional as F
7 | 
8 | import pdb
9 | 
10 | # ===================
11 | # RGA Module
12 | # ===================
13 | 
14 | class RGA_Module(nn.Module):
15 | def __init__(self, in_channel, in_spatial, use_spatial=True, use_channel=True, \
16 | cha_ratio=8, spa_ratio=8, down_ratio=8):
17 | super(RGA_Module, self).__init__()
18 | 
19 | self.in_channel = in_channel
20 | self.in_spatial = in_spatial
21 | 
22 | self.use_spatial = use_spatial
23 | self.use_channel = use_channel
24 | 
25 | self.inter_channel = in_channel // cha_ratio # presumably used to reduce the feature dimension
26 | self.inter_spatial = in_spatial // spa_ratio
27 | 
28 | # Embedding functions for original features
29 | if self.use_spatial:
30 | self.gx_spatial = nn.Sequential(
31 | nn.Conv2d(in_channels=self.in_channel, out_channels=self.inter_channel,
32 | kernel_size=1, stride=1, padding=0, bias=False),
33 | nn.BatchNorm2d(self.inter_channel),
34 | nn.ReLU()
35 | )
36 | if self.use_channel:
37 | self.gx_channel = nn.Sequential(
38 | nn.Conv2d(in_channels=self.in_spatial, out_channels=self.inter_spatial,
39 | kernel_size=1, stride=1, padding=0, bias=False),
40 | nn.BatchNorm2d(self.inter_spatial),
41 | nn.ReLU()
42 | )
43 | 
44 | # Embedding functions for relation features
45 | if self.use_spatial:
46 | self.gg_spatial = nn.Sequential(
47 | nn.Conv2d(in_channels=self.in_spatial * 2, out_channels=self.inter_spatial,
48 | kernel_size=1, stride=1, padding=0, bias=False),
49 | nn.BatchNorm2d(self.inter_spatial),
50 | nn.ReLU()
51 | )
52 | if self.use_channel:
53 | self.gg_channel = nn.Sequential(
54 | nn.Conv2d(in_channels=self.in_channel*2, out_channels=self.inter_channel,
55 | kernel_size=1, stride=1, padding=0, bias=False),
56 | nn.BatchNorm2d(self.inter_channel),
57 | nn.ReLU()
58 | )
59 | 
60 | # Networks for learning attention weights
61 | if self.use_spatial:
62 | num_channel_s = 1 + self.inter_spatial
63 | self.W_spatial = nn.Sequential(
64 | nn.Conv2d(in_channels=num_channel_s, out_channels=num_channel_s//down_ratio,
65 | kernel_size=1, stride=1, padding=0, bias=False),
66 | nn.BatchNorm2d(num_channel_s//down_ratio),
67 | nn.ReLU(),
68 | nn.Conv2d(in_channels=num_channel_s//down_ratio, out_channels=1,
69 | kernel_size=1, stride=1, padding=0, bias=False),
70 | nn.BatchNorm2d(1)
71 | )
72 | if self.use_channel:
73 | num_channel_c = 1 + self.inter_channel
74 | self.W_channel = nn.Sequential(
75 | nn.Conv2d(in_channels=num_channel_c, out_channels=num_channel_c//down_ratio,
76 |
kernel_size=1, stride=1, padding=0, bias=False), 77 | nn.BatchNorm2d(num_channel_c//down_ratio), 78 | nn.ReLU(), 79 | nn.Conv2d(in_channels=num_channel_c//down_ratio, out_channels=1, 80 | kernel_size=1, stride=1, padding=0, bias=False), 81 | nn.BatchNorm2d(1) 82 | ) 83 | 84 | # Embedding functions for modeling relations 85 | if self.use_spatial: 86 | self.theta_spatial = nn.Sequential( 87 | nn.Conv2d(in_channels=self.in_channel, out_channels=self.inter_channel, 88 | kernel_size=1, stride=1, padding=0, bias=False), 89 | nn.BatchNorm2d(self.inter_channel), 90 | nn.ReLU() 91 | ) 92 | self.phi_spatial = nn.Sequential( 93 | nn.Conv2d(in_channels=self.in_channel, out_channels=self.inter_channel, 94 | kernel_size=1, stride=1, padding=0, bias=False), 95 | nn.BatchNorm2d(self.inter_channel), 96 | nn.ReLU() 97 | ) 98 | if self.use_channel: 99 | self.theta_channel = nn.Sequential( 100 | nn.Conv2d(in_channels=self.in_spatial, out_channels=self.inter_spatial, 101 | kernel_size=1, stride=1, padding=0, bias=False), 102 | nn.BatchNorm2d(self.inter_spatial), 103 | nn.ReLU() 104 | ) 105 | self.phi_channel = nn.Sequential( 106 | nn.Conv2d(in_channels=self.in_spatial, out_channels=self.inter_spatial, 107 | kernel_size=1, stride=1, padding=0, bias=False), 108 | nn.BatchNorm2d(self.inter_spatial), 109 | nn.ReLU() 110 | ) 111 | 112 | def forward(self, x): 113 | b, c, h, w = x.size() 114 | 115 | if self.use_spatial: 116 | # spatial attention 117 | theta_xs = self.theta_spatial(x) 118 | phi_xs = self.phi_spatial(x) 119 | theta_xs = theta_xs.view(b, self.inter_channel, -1) 120 | theta_xs = theta_xs.permute(0, 2, 1) 121 | phi_xs = phi_xs.view(b, self.inter_channel, -1) 122 | Gs = torch.matmul(theta_xs, phi_xs) 123 | Gs_in = Gs.permute(0, 2, 1).view(b, h*w, h, w) 124 | Gs_out = Gs.view(b, h*w, h, w) 125 | Gs_joint = torch.cat((Gs_in, Gs_out), 1) 126 | Gs_joint = self.gg_spatial(Gs_joint) 127 | 128 | g_xs = self.gx_spatial(x) 129 | g_xs = torch.mean(g_xs, dim=1, keepdim=True) 130 | ys = torch.cat((g_xs, Gs_joint), 1) 131 | 132 | W_ys = self.W_spatial(ys) 133 | if not self.use_channel: 134 | # out = F.sigmoid(W_ys.expand_as(x)) * x 135 | out = torch.sigmoid(W_ys.expand_as(x)) * x 136 | return out 137 | else: 138 | # x = F.sigmoid(W_ys.expand_as(x)) * x 139 | x = torch.sigmoid(W_ys.expand_as(x)) * x 140 | 141 | if self.use_channel: 142 | # channel attention 143 | xc = x.view(b, c, -1).permute(0, 2, 1).unsqueeze(-1) 144 | theta_xc = self.theta_channel(xc).squeeze(-1).permute(0, 2, 1) 145 | phi_xc = self.phi_channel(xc).squeeze(-1) 146 | Gc = torch.matmul(theta_xc, phi_xc) 147 | Gc_in = Gc.permute(0, 2, 1).unsqueeze(-1) 148 | Gc_out = Gc.unsqueeze(-1) 149 | Gc_joint = torch.cat((Gc_in, Gc_out), 1) 150 | Gc_joint = self.gg_channel(Gc_joint) 151 | 152 | g_xc = self.gx_channel(xc) 153 | g_xc = torch.mean(g_xc, dim=1, keepdim=True) 154 | yc = torch.cat((g_xc, Gc_joint), 1) 155 | 156 | W_yc = self.W_channel(yc).transpose(1, 2) 157 | # out = F.sigmoid(W_yc) * x 158 | out = torch.sigmoid(W_yc) * x 159 | 160 | return out -------------------------------------------------------------------------------- /GASNet/reid/models/rga_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT license. 
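A hedged usage sketch of the `RGA_Module` defined in `rga_modules.py` above: the module attends over a `(B, C, H, W)` feature map, so `in_spatial` must equal `H*W` of the incoming features (the shapes here are illustrative):

```python
import torch
from reid.models.models_utils.rga_modules import RGA_Module

x = torch.randn(2, 256, 64, 32)                 # e.g. stage-1 features, H*W = 2048
att = RGA_Module(in_channel=256, in_spatial=64 * 32,
                 use_spatial=True, use_channel=True,
                 cha_ratio=8, spa_ratio=8, down_ratio=8)
print(att(x).shape)                             # torch.Size([2, 256, 64, 32])
```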
3 | 
4 | from __future__ import absolute_import
5 | from __future__ import division
6 | from __future__ import print_function
7 | from __future__ import unicode_literals
8 | 
9 | import os
10 | import math
11 | import torch
12 | from torch import nn
13 | from torch.nn import functional as F
14 | from torch.nn import init
15 | from torch.autograd import Variable
16 | 
17 | import torchvision
18 | import numpy as np
19 | 
20 | from .models_utils.rga_branches import RGA_Branch
21 | 
22 | __all__ = ['resnet50_rga']
23 | WEIGHT_PATH = os.path.join(os.path.dirname(__file__), '../..')+'/weights/pre_train/resnet50-19c8e357.pth'
24 | # WEIGHT_PATH = os.path.join(os.path.dirname(__file__), '../..')+'/checkpoint/chechpoint_0.pth'
25 | 
26 | # ===================
27 | # Initialization
28 | # ===================
29 | 
30 | def weights_init_kaiming(m):
31 | classname = m.__class__.__name__
32 | if classname.find('Linear') != -1:
33 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out')
34 | nn.init.constant_(m.bias, 0.0)
35 | elif classname.find('Conv') != -1:
36 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
37 | if m.bias is not None:
38 | nn.init.constant_(m.bias, 0.0)
39 | elif classname.find('BatchNorm') != -1:
40 | if m.affine:
41 | nn.init.constant_(m.weight, 1.0)
42 | nn.init.constant_(m.bias, 0.0)
43 | 
44 | 
45 | def weights_init_classifier(m):
46 | classname = m.__class__.__name__
47 | if classname.find('Linear') != -1:
48 | nn.init.normal_(m.weight, std=0.001)
49 | if m.bias is not None:
50 | nn.init.constant_(m.bias, 0.0)
51 | 
52 | # ===============
53 | # RGA Model
54 | # ===============
55 | 
56 | class ResNet50_RGA_Model(nn.Module):
57 | '''
58 | Backbone: ResNet-50 + RGA modules.
59 | '''
60 | def __init__(self, pretrained=True, num_feat=2048, height=256, width=128,
61 | dropout=0, num_classes=0, last_stride=1, branch_name='rgasc', scale=8, d_scale=8,
62 | model_path=WEIGHT_PATH, use_o_scale=True):
63 | super(ResNet50_RGA_Model, self).__init__()
64 | self.pretrained = pretrained
65 | self.num_feat = num_feat
66 | self.dropout = dropout
67 | self.num_classes = num_classes
68 | self.branch_name = branch_name
69 | self.use_o_scale = use_o_scale
70 | print ('Num of features: {}.'.format(self.num_feat))
71 | 
72 | if 'rgasc' in branch_name:
73 | spa_on=True
74 | cha_on=True
75 | elif 'rgas' in branch_name:
76 | spa_on=True
77 | cha_on=False
78 | elif 'rgac' in branch_name:
79 | spa_on=False
80 | cha_on=True
81 | else:
82 | raise NameError
83 | 
84 | self.backbone = RGA_Branch(pretrained=pretrained, last_stride=last_stride,
85 | spa_on=spa_on, cha_on=cha_on, height=height, width=width,
86 | s_ratio=scale, c_ratio=scale, d_ratio=d_scale, model_path=model_path, use_o_scale=use_o_scale)
87 | 
88 | self.feat_bn = nn.BatchNorm1d(self.num_feat)
89 | self.feat_bn.bias.requires_grad_(False)
90 | if self.dropout > 0:
91 | self.drop = nn.Dropout(self.dropout)
92 | self.cls_ = nn.Linear(self.num_feat, self.num_classes, bias=False) # project the feature dimension to the number of classes
93 | 
94 | self.feat_bn.apply(weights_init_kaiming)
95 | self.cls_.apply(weights_init_classifier)
96 | 
97 | # split the feature from the backbone into two parts: a batch-normalized feature and the class logits
98 | def _split_feat(self, feature, training):
99 | feat = self.feat_bn(feature) # batch-norm the feature; used for the classification branch
100 | if self.dropout > 0:
101 | feat = self.drop(feat)
102 | # if training and self.num_classes is not None:
103 | # cls_feat = self.cls_(feat) # class logits
104 | # return feat, cls_feat
105 | # if training and self.num_classes is not None:
106 | # cls_feat = self.cls_(feat) # class logits
107 | # return feat, cls_feat
108 | # elif not training:
109 | # return feat
110 | cls_feat = self.cls_(feat) # class logits
111 | return feat, cls_feat
112 | 
113 | def forward(self, inputs, training=True, use_o_scale=True):
114 | im_input = inputs[0]
115 | feat_ = self.backbone(im_input)
116 | 
117 | # for inspecting the model structure with netron
118 | # a = feat_[0].size()
119 | # b = a[2].item()
120 | # c = a[3].item()
121 | # f = feat_[1].size()
122 | # d = f[2].item()
123 | # e = f[3].item()
124 | 
125 | 
126 | if use_o_scale:
127 | feat_rga = F.avg_pool2d(feat_[0], feat_[0].size()[2:]).view(feat_[0].size(0), -1) # global attention branch
128 | feat_osc = F.avg_pool2d(feat_[1], feat_[1].size()[2:]).view(feat_[1].size(0), -1) # omni-scale branch
129 | # use the two lines below when viewing the model structure with netron
130 | # feat_rga = F.avg_pool2d(feat_[0], kernel_size=[b, c]).view(feat_[0].size(0), -1) # global attention branch
131 | # feat_osc = F.avg_pool2d(feat_[1], kernel_size=[d, e]).view(feat_[1].size(0), -1) # omni-scale branch
132 | if training:
133 | feat_rga_, cls_rga = self._split_feat(feat_rga, True) # call _split_feat
134 | feat_osc_, cls_osc = self._split_feat(feat_osc, True) # call _split_feat
135 | return (feat_rga, feat_rga_, cls_rga, feat_osc, feat_osc_, cls_osc)
136 | else:
137 | feat_rga_, cls_rga = self._split_feat(feat_rga, False) # call _split_feat; returns the normalized feature
138 | feat_osc_, cls_osc = self._split_feat(feat_osc, False) # call _split_feat; returns the normalized feature
139 | # return (feat_rga, feat_rga_, feat_osc, feat_osc_)
140 | return (feat_rga + feat_osc, feat_rga_ + feat_osc_, cls_rga + cls_osc)
141 | else:
142 | feat_rga = F.avg_pool2d(feat_, feat_.size()[2:]).view(feat_.size(0), -1) # global attention branch
143 | if training:
144 | feat_rga_, cls_rga = self._split_feat(feat_rga, True) # call _split_feat
145 | return (feat_rga, feat_rga_, cls_rga)
146 | else:
147 | feat_rga_, feat_cls = self._split_feat(feat_rga, False) # call _split_feat
148 | return (feat_rga, feat_rga_, feat_cls)
149 | 
150 | 
151 | def resnet50_rga(*args, **kwargs):
152 | return ResNet50_RGA_Model(*args, **kwargs)
153 | 
154 | 
-------------------------------------------------------------------------------- /GASNet/reid/utils/__init__.py: --------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | 
3 | import torch
4 | 
5 | 
6 | def to_numpy(tensor):
7 | if torch.is_tensor(tensor):
8 | return tensor.cpu().numpy()
9 | elif type(tensor).__module__ != 'numpy':
10 | raise ValueError("Cannot convert {} to numpy array"
11 | .format(type(tensor)))
12 | return tensor
13 | 
14 | 
15 | def to_torch(ndarray):
16 | if type(ndarray).__module__ == 'numpy':
17 | return torch.from_numpy(ndarray)
18 | elif not torch.is_tensor(ndarray):
19 | raise ValueError("Cannot convert {} to torch tensor"
20 | .format(type(ndarray)))
21 | return ndarray
22 | 
-------------------------------------------------------------------------------- /GASNet/reid/utils/__pycache__/__init__.cpython-36.pyc: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/GeoX-Lab/ReID/a4871c799c4a45a3eebb165bc4cda85c431c0253/GASNet/reid/utils/__pycache__/__init__.cpython-36.pyc
-------------------------------------------------------------------------------- /GASNet/reid/utils/__pycache__/__init__.cpython-38.pyc: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/GeoX-Lab/ReID/a4871c799c4a45a3eebb165bc4cda85c431c0253/GASNet/reid/utils/__pycache__/__init__.cpython-38.pyc
-------------------------------------------------------------------------------- /GASNet/reid/utils/__pycache__/lr_scheduler.cpython-36.pyc:
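Stepping back to `rga_model.py` above, a hedged sketch of the `forward` contract: six tensors during training (raw, normalized, and logit outputs for each branch) and three branch-summed tensors at test time. The ID count is invented, and `pretrained=False` skips the weight file:

```python
import torch
from reid.models.rga_model import resnet50_rga

model = resnet50_rga(pretrained=False, num_classes=100, height=256, width=128)
imgs = torch.randn(2, 3, 256, 128)

train_out = model([imgs], training=True)   # (feat_rga, feat_rga_, cls_rga, feat_osc, feat_osc_, cls_osc)
assert len(train_out) == 6
feat, feat_, logits = model([imgs], training=False)  # branch outputs summed element-wise
```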
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/GeoX-Lab/ReID/a4871c799c4a45a3eebb165bc4cda85c431c0253/GASNet/reid/utils/__pycache__/lr_scheduler.cpython-36.pyc -------------------------------------------------------------------------------- /GASNet/reid/utils/__pycache__/lr_scheduler.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GeoX-Lab/ReID/a4871c799c4a45a3eebb165bc4cda85c431c0253/GASNet/reid/utils/__pycache__/lr_scheduler.cpython-38.pyc -------------------------------------------------------------------------------- /GASNet/reid/utils/__pycache__/meters.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GeoX-Lab/ReID/a4871c799c4a45a3eebb165bc4cda85c431c0253/GASNet/reid/utils/__pycache__/meters.cpython-36.pyc -------------------------------------------------------------------------------- /GASNet/reid/utils/__pycache__/meters.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GeoX-Lab/ReID/a4871c799c4a45a3eebb165bc4cda85c431c0253/GASNet/reid/utils/__pycache__/meters.cpython-38.pyc -------------------------------------------------------------------------------- /GASNet/reid/utils/__pycache__/osutils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GeoX-Lab/ReID/a4871c799c4a45a3eebb165bc4cda85c431c0253/GASNet/reid/utils/__pycache__/osutils.cpython-36.pyc -------------------------------------------------------------------------------- /GASNet/reid/utils/__pycache__/osutils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GeoX-Lab/ReID/a4871c799c4a45a3eebb165bc4cda85c431c0253/GASNet/reid/utils/__pycache__/osutils.cpython-38.pyc -------------------------------------------------------------------------------- /GASNet/reid/utils/__pycache__/serialization.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GeoX-Lab/ReID/a4871c799c4a45a3eebb165bc4cda85c431c0253/GASNet/reid/utils/__pycache__/serialization.cpython-36.pyc -------------------------------------------------------------------------------- /GASNet/reid/utils/__pycache__/serialization.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GeoX-Lab/ReID/a4871c799c4a45a3eebb165bc4cda85c431c0253/GASNet/reid/utils/__pycache__/serialization.cpython-38.pyc -------------------------------------------------------------------------------- /GASNet/reid/utils/data/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .preprocessor import Preprocessor -------------------------------------------------------------------------------- /GASNet/reid/utils/data/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GeoX-Lab/ReID/a4871c799c4a45a3eebb165bc4cda85c431c0253/GASNet/reid/utils/data/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- 
/GASNet/reid/utils/data/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GeoX-Lab/ReID/a4871c799c4a45a3eebb165bc4cda85c431c0253/GASNet/reid/utils/data/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /GASNet/reid/utils/data/__pycache__/preprocessor.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GeoX-Lab/ReID/a4871c799c4a45a3eebb165bc4cda85c431c0253/GASNet/reid/utils/data/__pycache__/preprocessor.cpython-36.pyc -------------------------------------------------------------------------------- /GASNet/reid/utils/data/__pycache__/preprocessor.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GeoX-Lab/ReID/a4871c799c4a45a3eebb165bc4cda85c431c0253/GASNet/reid/utils/data/__pycache__/preprocessor.cpython-38.pyc -------------------------------------------------------------------------------- /GASNet/reid/utils/data/__pycache__/sampler.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GeoX-Lab/ReID/a4871c799c4a45a3eebb165bc4cda85c431c0253/GASNet/reid/utils/data/__pycache__/sampler.cpython-36.pyc -------------------------------------------------------------------------------- /GASNet/reid/utils/data/__pycache__/sampler.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GeoX-Lab/ReID/a4871c799c4a45a3eebb165bc4cda85c431c0253/GASNet/reid/utils/data/__pycache__/sampler.cpython-38.pyc -------------------------------------------------------------------------------- /GASNet/reid/utils/data/__pycache__/transforms.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GeoX-Lab/ReID/a4871c799c4a45a3eebb165bc4cda85c431c0253/GASNet/reid/utils/data/__pycache__/transforms.cpython-36.pyc -------------------------------------------------------------------------------- /GASNet/reid/utils/data/__pycache__/transforms.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GeoX-Lab/ReID/a4871c799c4a45a3eebb165bc4cda85c431c0253/GASNet/reid/utils/data/__pycache__/transforms.cpython-38.pyc -------------------------------------------------------------------------------- /GASNet/reid/utils/data/preprocessor.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import os.path as osp 3 | 4 | from PIL import Image 5 | 6 | 7 | class Preprocessor(object): 8 | def __init__(self, dataset, root=None, transform=None): 9 | super(Preprocessor, self).__init__() 10 | self.dataset = dataset 11 | self.root = root 12 | self.transform = transform 13 | 14 | def __len__(self): 15 | return len(self.dataset) 16 | 17 | def __getitem__(self, indices): 18 | if isinstance(indices, (tuple, list)): 19 | return [self._get_single_item(index) for index in indices] 20 | return self._get_single_item(indices) 21 | 22 | def _get_single_item(self, index): 23 | fname, pid, camid = self.dataset[index] 24 | # if self.root is not None: 25 | # fpath = osp.join(self.root, fname) 26 | fpath = fname 27 | # print(fpath) 28 | img = 
Image.open(fpath).convert('RGB') 29 | if self.transform is not None: 30 | img = self.transform(img) 31 | return img, fname, pid, camid 32 | -------------------------------------------------------------------------------- /GASNet/reid/utils/data/sampler.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from collections import defaultdict 3 | 4 | import copy 5 | import random 6 | import numpy as np 7 | import torch 8 | from torch.utils.data.sampler import ( 9 | Sampler, SequentialSampler, RandomSampler, SubsetRandomSampler, 10 | WeightedRandomSampler) 11 | 12 | 13 | class RandomIdentitySampler(Sampler): 14 | def __init__(self, data_source, num_instances=1): 15 | self.data_source = data_source 16 | self.num_instances = num_instances 17 | self.index_dic = defaultdict(list) 18 | for index, (_, pid, _) in enumerate(data_source): 19 | self.index_dic[pid].append(index) 20 | self.pids = list(self.index_dic.keys()) 21 | self.num_samples = len(self.pids) 22 | 23 | def __len__(self): 24 | return self.num_samples * self.num_instances 25 | 26 | def __iter__(self): 27 | indices = torch.randperm(self.num_samples) 28 | ret = [] 29 | for i in indices: 30 | pid = self.pids[i] 31 | t = self.index_dic[pid] 32 | if len(t) >= self.num_instances: 33 | t = np.random.choice(t, size=self.num_instances, replace=False) 34 | else: 35 | t = np.random.choice(t, size=self.num_instances, replace=True) 36 | ret.extend(t) 37 | return iter(ret) 38 | -------------------------------------------------------------------------------- /GASNet/reid/utils/data/transforms.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from torchvision.transforms import * 4 | import torchvision 5 | import torch 6 | from PIL import Image 7 | import random 8 | import math 9 | 10 | import numpy as np 11 | 12 | class RectScale(object): 13 | def __init__(self, height, width, interpolation=Image.BILINEAR): 14 | self.height = height 15 | self.width = width 16 | self.interpolation = interpolation 17 | 18 | def __call__(self, img): 19 | w, h = img.size 20 | if h == self.height and w == self.width: 21 | return img 22 | return img.resize((self.width, self.height), self.interpolation) 23 | 24 | 25 | class RandomSizedRectCrop(object): 26 | def __init__(self, height, width, interpolation=Image.BILINEAR): 27 | self.height = height 28 | self.width = width 29 | self.interpolation = interpolation 30 | 31 | def __call__(self, img): 32 | for attempt in range(10): 33 | area = img.size[0] * img.size[1] 34 | target_area = random.uniform(0.64, 1.0) * area 35 | aspect_ratio = random.uniform(2, 3) 36 | 37 | h = int(round(math.sqrt(target_area * aspect_ratio))) 38 | w = int(round(math.sqrt(target_area / aspect_ratio))) 39 | 40 | if w <= img.size[0] and h <= img.size[1]: 41 | x1 = random.randint(0, img.size[0] - w) 42 | y1 = random.randint(0, img.size[1] - h) 43 | 44 | img = img.crop((x1, y1, x1 + w, y1 + h)) 45 | assert(img.size == (w, h)) 46 | 47 | return img.resize((self.width, self.height), self.interpolation) 48 | 49 | scale = RectScale(self.height, self.width, 50 | interpolation=self.interpolation) 51 | return scale(img) 52 | 53 | 54 | class RandomErasing(object): 55 | """ Randomly selects a rectangle region in an image and erases its pixels. 56 | 'Random Erasing Data Augmentation' by Zhong et al. 
57 | See https://arxiv.org/pdf/1708.04896.pdf 58 | Args: 59 | probability: The probability that the Random Erasing operation will be performed. 60 | sl: Minimum proportion of erased area against input image. 61 | sh: Maximum proportion of erased area against input image. 62 | r1: Minimum aspect ratio of erased area. 63 | mean: Erasing value. 64 | """ 65 | 66 | def __init__(self, probability = 0.5, sl = 0.02, sh = 0.4, r1 = 0.3, mean=[0.4914, 0.4822, 0.4465]): 67 | self.probability = probability 68 | self.mean = mean 69 | self.sl = sl 70 | self.sh = sh 71 | self.r1 = r1 72 | 73 | def __call__(self, img): 74 | 75 | if random.uniform(0, 1) > self.probability: 76 | return img 77 | 78 | for attempt in range(100): 79 | area = img.size()[1] * img.size()[2] 80 | 81 | target_area = random.uniform(self.sl, self.sh) * area 82 | aspect_ratio = random.uniform(self.r1, 1/self.r1) 83 | 84 | h = int(round(math.sqrt(target_area * aspect_ratio))) 85 | w = int(round(math.sqrt(target_area / aspect_ratio))) 86 | 87 | if w < img.size()[2] and h < img.size()[1]: 88 | x1 = random.randint(0, img.size()[1] - h) 89 | y1 = random.randint(0, img.size()[2] - w) 90 | if img.size()[0] == 3: 91 | img[0, x1:x1+h, y1:y1+w] = self.mean[0] 92 | img[1, x1:x1+h, y1:y1+w] = self.mean[1] 93 | img[2, x1:x1+h, y1:y1+w] = self.mean[2] 94 | else: 95 | img[0, x1:x1+h, y1:y1+w] = self.mean[0] 96 | return img 97 | 98 | return img 99 | 100 | 101 | class Random2DTranslation(object): 102 | """ 103 | With a probability, first increase image size to (1 + 1/8), and then perform random crop. 104 | 105 | Args: 106 | - height (int): target height. 107 | - width (int): target width. 108 | - p (float): probability of performing this transformation. Default: 0.5. 109 | """ 110 | def __init__(self, height, width, p=0.5, interpolation=Image.BILINEAR): 111 | self.height = height 112 | self.width = width 113 | self.p = p 114 | self.interpolation = interpolation 115 | 116 | def __call__(self, img): 117 | """ 118 | Args: 119 | - img (PIL Image): Image to be cropped. 
120 | """ 121 | if random.uniform(0, 1) > self.p: 122 | return img.resize((self.width, self.height), self.interpolation) 123 | 124 | new_width, new_height = int(round(self.width * 1.125)), int(round(self.height * 1.125)) 125 | resized_img = img.resize((new_width, new_height), self.interpolation) 126 | x_maxrange = new_width - self.width 127 | y_maxrange = new_height - self.height 128 | x1 = int(round(random.uniform(0, x_maxrange))) 129 | y1 = int(round(random.uniform(0, y_maxrange))) 130 | croped_img = resized_img.crop((x1, y1, x1 + self.width, y1 + self.height)) 131 | return croped_img 132 | -------------------------------------------------------------------------------- /GASNet/reid/utils/logging.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import os 3 | import sys 4 | 5 | from .osutils import mkdir_if_missing 6 | 7 | 8 | class Logger(object): 9 | def __init__(self, fpath=None): 10 | self.console = sys.stdout 11 | self.file = None 12 | if fpath is not None: 13 | mkdir_if_missing(os.path.dirname(fpath)) 14 | self.file = open(fpath, 'w') 15 | 16 | def __del__(self): 17 | self.close() 18 | 19 | def __enter__(self): 20 | pass 21 | 22 | def __exit__(self, *args): 23 | self.close() 24 | 25 | def write(self, msg): 26 | self.console.write(msg) 27 | if self.file is not None: 28 | self.file.write(msg) 29 | 30 | def flush(self): 31 | self.console.flush() 32 | if self.file is not None: 33 | self.file.flush() 34 | os.fsync(self.file.fileno()) 35 | 36 | def close(self): 37 | self.console.close() 38 | if self.file is not None: 39 | self.file.close() 40 | -------------------------------------------------------------------------------- /GASNet/reid/utils/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | from __future__ import unicode_literals 5 | 6 | 7 | class LRScheduler(object): 8 | """Base class of a learning rate scheduler. 9 | 10 | A scheduler returns a new learning rate based on the number of updates that have 11 | been performed. 12 | 13 | Parameters 14 | ---------- 15 | base_lr : float, optional 16 | The initial learning rate. 17 | warmup_epoch: int 18 | number of warmup steps used before this scheduler starts decay 19 | warmup_begin_lr: float 20 | if using warmup, the learning rate from which it starts warming up 21 | warmup_mode: string 22 | warmup can be done in two modes. 
/GASNet/reid/utils/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 | from __future__ import unicode_literals
5 | 
6 | 
7 | class LRScheduler(object):
8 |     """Base class of a learning rate scheduler.
9 | 
10 |     A scheduler returns a new learning rate based on the number of updates that have
11 |     been performed.
12 | 
13 |     Parameters
14 |     ----------
15 |     base_lr : float, optional
16 |         The initial learning rate.
17 |     warmup_epoch: int
18 |         number of warmup epochs used before this scheduler starts decay
19 |     warmup_begin_lr: float
20 |         if using warmup, the learning rate from which it starts warming up
21 |     warmup_mode: string
22 |         warmup can be done in two modes.
23 |         'linear' mode gradually increases lr with each step in equal increments
24 |         'constant' mode keeps lr at warmup_begin_lr for warmup_epoch epochs
25 |     """
26 | 
27 |     def __init__(self, base_lr=0.01, step=(30, 60), factor=0.1,
28 |                  warmup_epoch=0, warmup_begin_lr=0, warmup_mode='linear'):
29 |         self.base_lr = base_lr
30 |         self.learning_rate = base_lr
31 |         if isinstance(step, (tuple, list)):
32 |             self.step = step
33 |         else:
34 |             self.step = [step * (i + 1) for i in range(20)]
35 |         self.factor = factor
36 |         assert isinstance(warmup_epoch, int)
37 |         self.warmup_epoch = warmup_epoch
38 | 
39 |         self.warmup_final_lr = base_lr
40 |         self.warmup_begin_lr = warmup_begin_lr
41 |         if self.warmup_begin_lr > self.warmup_final_lr:
42 |             raise ValueError("Base lr has to be higher than warmup_begin_lr")
43 |         if self.warmup_epoch < 0:
44 |             raise ValueError("Warmup epochs has to be positive or 0")
45 |         if warmup_mode not in ['linear', 'constant']:
46 |             raise ValueError("Supports only linear and constant modes of warmup")
47 |         self.warmup_mode = warmup_mode
48 | 
49 |     def update(self, num_epoch):
50 |         if self.warmup_epoch > num_epoch:  # still inside the warmup phase
51 |             # warmup strategy
52 |             if self.warmup_mode == 'linear':
53 |                 self.learning_rate = self.warmup_begin_lr + (self.warmup_final_lr - self.warmup_begin_lr) * \
54 |                     num_epoch / self.warmup_epoch
55 |             elif self.warmup_mode == 'constant':
56 |                 self.learning_rate = self.warmup_begin_lr
57 | 
58 |         else:  # warmup finished: step decay
59 |             count = sum([1 for s in self.step if s <= num_epoch])
60 |             self.learning_rate = self.base_lr * pow(self.factor, count)  # decay once per milestone passed
61 |         return self.learning_rate
62 | 
63 | 
--------------------------------------------------------------------------------
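A short training-loop sketch for `LRScheduler` (the hyper-parameter values and the stand-in model are illustrative assumptions, not the repo's configuration):

```python
# Sketch: drive an optimizer with the warmup + step-decay schedule above.
import torch

from reid.utils.lr_scheduler import LRScheduler

model = torch.nn.Linear(8, 2)  # stand-in model, for illustration only
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scheduler = LRScheduler(base_lr=0.01, step=[30, 60], factor=0.1,
                        warmup_epoch=10, warmup_begin_lr=1e-4)

for epoch in range(80):
    lr = scheduler.update(epoch)
    # The scheduler only computes the rate; the caller applies it.
    for group in optimizer.param_groups:
        group['lr'] = lr
```

With these values the rate climbs linearly from 1e-4 to 0.01 over the first 10 epochs, then is multiplied by 0.1 at epochs 30 and 60.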
/GASNet/reid/utils/meters.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | 
3 | 
4 | class AverageMeter(object):
5 |     """Computes and stores the average and current value"""
6 | 
7 |     def __init__(self):
8 |         self.val = 0
9 |         self.avg = 0
10 |         self.sum = 0
11 |         self.count = 0
12 | 
13 |     def reset(self):
14 |         self.val = 0
15 |         self.avg = 0
16 |         self.sum = 0
17 |         self.count = 0
18 | 
19 |     def update(self, val, n=1):
20 |         self.val = val
21 |         self.sum += val * n
22 |         self.count += n
23 |         self.avg = self.sum / self.count
24 | 
--------------------------------------------------------------------------------
/GASNet/reid/utils/osutils.py:
--------------------------------------------------------------------------------
1 | from __future__ import division, print_function, absolute_import
2 | import os
3 | import sys
4 | import json
5 | import time
6 | import errno
7 | import numpy as np
8 | import random
9 | import os.path as osp
10 | import warnings
11 | import PIL
12 | import torch
13 | from PIL import Image
14 | 
15 | __all__ = [
16 |     'mkdir_if_missing', 'check_isfile', 'read_json', 'write_json',
17 |     'set_random_seed', 'download_url', 'read_image', 'collect_env_info'
18 | ]
19 | 
20 | 
21 | def mkdir_if_missing(dirname):
22 |     """Creates dirname if it is missing."""
23 |     if not osp.exists(dirname):
24 |         try:
25 |             os.makedirs(dirname)
26 |         except OSError as e:
27 |             if e.errno != errno.EEXIST:
28 |                 raise
29 | 
30 | 
31 | def check_isfile(fpath):
32 |     """Checks if the given path is a file.
33 |     Args:
34 |         fpath (str): file path.
35 |     Returns:
36 |         bool
37 |     """
38 |     isfile = osp.isfile(fpath)
39 |     if not isfile:
40 |         warnings.warn('No file found at "{}"'.format(fpath))
41 |     return isfile
42 | 
43 | 
44 | def read_json(fpath):
45 |     """Reads json file from a path."""
46 |     with open(fpath, 'r') as f:
47 |         obj = json.load(f)
48 |     return obj
49 | 
50 | 
51 | def write_json(obj, fpath):
52 |     """Writes to a json file."""
53 |     mkdir_if_missing(osp.dirname(fpath))
54 |     with open(fpath, 'w') as f:
55 |         json.dump(obj, f, indent=4, separators=(',', ': '))
56 | 
57 | 
58 | def set_random_seed(seed):
59 |     random.seed(seed)
60 |     np.random.seed(seed)
61 |     torch.manual_seed(seed)
62 |     torch.cuda.manual_seed_all(seed)
63 | 
64 | 
65 | def download_url(url, dst):
66 |     """Downloads file from a url to a destination.
67 |     Args:
68 |         url (str): url to download file.
69 |         dst (str): destination path.
70 |     """
71 |     from six.moves import urllib
72 |     print('* url="{}"'.format(url))
73 |     print('* destination="{}"'.format(dst))
74 | 
75 |     def _reporthook(count, block_size, total_size):
76 |         global start_time
77 |         if count == 0:
78 |             start_time = time.time()
79 |             return
80 |         duration = time.time() - start_time
81 |         progress_size = int(count * block_size)
82 |         speed = int(progress_size / (1024 * duration))
83 |         percent = int(count * block_size * 100 / total_size)
84 |         sys.stdout.write(
85 |             '\r...%d%%, %d MB, %d KB/s, %d seconds passed' %
86 |             (percent, progress_size / (1024 * 1024), speed, duration)
87 |         )
88 |         sys.stdout.flush()
89 | 
90 |     urllib.request.urlretrieve(url, dst, _reporthook)
91 |     sys.stdout.write('\n')
92 | 
93 | 
94 | def read_image(path):
95 |     """Reads image from path using ``PIL.Image``.
96 |     Args:
97 |         path (str): path to an image.
98 |     Returns:
99 |         PIL image
100 |     """
101 |     if not osp.exists(path):
102 |         raise IOError('"{}" does not exist'.format(path))
103 |     # Retry a few times to ride out transient I/O failures, then give up
104 |     # instead of looping forever on an unreadable file.
105 |     for _ in range(3):
106 |         try:
107 |             return Image.open(path).convert('RGB')
108 |         except IOError:
109 |             print('IOError incurred when reading "{}". Retrying.'.format(path))
110 |     raise IOError('Failed to read image "{}"'.format(path))
111 | 
112 | 
113 | def collect_env_info():
114 |     """Returns env info as a string.
115 |     Code source: github.com/facebookresearch/maskrcnn-benchmark
116 |     """
117 |     from torch.utils.collect_env import get_pretty_env_info
118 |     env_str = get_pretty_env_info()
119 |     env_str += '\n Pillow ({})'.format(PIL.__version__)
120 |     return env_str
--------------------------------------------------------------------------------
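A reproducibility sketch using the helpers above (the seed value is arbitrary):

```python
# Sketch: seed every RNG the code base touches, then log the environment.
from reid.utils.osutils import collect_env_info, set_random_seed

set_random_seed(42)        # seeds random, numpy, torch and all CUDA devices
print(collect_env_info())  # PyTorch/CUDA build info plus the Pillow version
```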
/GASNet/reid/utils/serialization.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 | import json
3 | import os.path as osp
4 | import shutil
5 | 
6 | import torch
7 | from torch.nn import Parameter
8 | 
9 | from .osutils import mkdir_if_missing
10 | 
11 | 
12 | def read_json(fpath):
13 |     with open(fpath, 'r') as f:
14 |         obj = json.load(f)
15 |     return obj
16 | 
17 | 
18 | def write_json(obj, fpath):
19 |     mkdir_if_missing(osp.dirname(fpath))
20 |     with open(fpath, 'w') as f:
21 |         json.dump(obj, f, indent=4, separators=(',', ': '))
22 | 
23 | 
24 | def save_checkpoint(state, epoch, is_best, save_interval=1, fpath='checkpoint.pth.tar'):
25 |     # save_interval is kept for API compatibility but is not used here.
26 |     dirpath = osp.dirname(fpath)
27 |     fname = osp.basename(fpath)
28 |     mkdir_if_missing(dirpath)
29 |     docpath = osp.join(dirpath, fname.split('.')[0] + '_{}.pth.tar'.format(epoch))
30 |     torch.save(state, docpath)
31 |     if is_best:
32 |         # Copy the checkpoint that was just written; fpath itself is never created.
33 |         shutil.copy(docpath, osp.join(dirpath, 'model_best.pth.tar'))
34 | 
35 | 
36 | def load_checkpoint(fpath):
37 |     if osp.isfile(fpath):
38 |         checkpoint = torch.load(fpath)
39 |         print("=> Loaded checkpoint '{}'".format(fpath))
40 |         return checkpoint
41 |     else:
42 |         raise ValueError("=> No checkpoint found at '{}'".format(fpath))
43 | 
44 | 
45 | def copy_state_dict(state_dict, model, strip=None):
46 |     tgt_state = model.state_dict()
47 |     copied_names = set()
48 |     for name, param in state_dict.items():
49 |         if strip is not None and name.startswith(strip):
50 |             name = name[len(strip):]
51 |         if name not in tgt_state:
52 |             continue
53 |         if isinstance(param, Parameter):
54 |             param = param.data
55 |         if param.size() != tgt_state[name].size():
56 |             print('mismatch:', name, param.size(), tgt_state[name].size())
57 |             continue
58 |         tgt_state[name].copy_(param)
59 |         copied_names.add(name)
60 | 
61 |     missing = set(tgt_state.keys()) - copied_names
62 |     if len(missing) > 0:
63 |         print("missing keys in state_dict:", missing)
64 | 
65 |     return model
66 | 
67 | 
68 | def unfreeze_all_params(model):
69 |     model.train()
70 |     for p in model.parameters():
71 |         p.requires_grad_(True)
72 | 
73 | 
74 | def freeze_specific_params(module):
75 |     module.eval()
76 |     for p in module.parameters():
77 |         p.requires_grad_(False)
78 | 
--------------------------------------------------------------------------------
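A round-trip sketch for the checkpoint helpers above (the paths, epoch number, and stand-in model are illustrative assumptions):

```python
# Sketch: save an epoch checkpoint, then restore its weights into a model.
import torch

from reid.utils.serialization import (copy_state_dict, load_checkpoint,
                                      save_checkpoint)

model = torch.nn.Linear(8, 2)  # stand-in for the real network
state = {'state_dict': model.state_dict(), 'epoch': 10}

# Writes logs/checkpoint_10.pth.tar and refreshes logs/model_best.pth.tar.
save_checkpoint(state, epoch=10, is_best=True, fpath='logs/checkpoint.pth.tar')

checkpoint = load_checkpoint('logs/checkpoint_10.pth.tar')
model = copy_state_dict(checkpoint['state_dict'], model)
```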
/GASNet/requirements.txt:
--------------------------------------------------------------------------------
1 | matplotlib==3.5.2
2 | numpy==1.23.1
3 | Pillow==9.2.0
4 | scikit_learn==1.1.2
5 | scipy==1.9.1
6 | six==1.16.0
7 | tensorboardX==2.5.1
8 | torch==1.6.0
9 | torchvision==0.7.0
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ReID
2 | ## Vehicle Re-identification based on UAV Viewpoint: Dataset and Method
3 | 
4 | The GASNet directory provides the Global Attention and full-Scale Network (GASNet) for the vehicle ReID task based on UAV images.
5 | The VRU directory provides a large-scale vehicle ReID dataset named VRU (short for Vehicle Re-identification based on UAV), which consists of 172,137 images of 15,085 vehicles captured by UAVs, so that each vehicle has multiple images from various viewpoints.
6 | ## GASNet
7 | ### Abstract
8 | High-resolution remote sensing images bring a large amount of data as well as challenges to traditional vision tasks. Vehicle re-identification (ReID), as an essential vision task that can utilize remote sensing images, has been widely used in suspect vehicle search, cross-border vehicle tracking, traffic behavior analysis, and automatic toll collection systems. Although there have been a large number of studies on vehicle ReID, most of them are based on fixed surveillance cameras and do not take full advantage of high-resolution remote sensing images. Compared with images collected by fixed surveillance cameras, high-resolution remote sensing images based on Unmanned Aerial Vehicles (UAVs) have the characteristics of rich viewpoints and a wide range of scale variations. These characteristics bring richer information to vehicle ReID tasks and have the potential to improve the performance of vehicle ReID models. However, to the best of our knowledge, there is a shortage of large open-source datasets for vehicle ReID based on UAV views, which is not conducive to promoting UAV-view-based vehicle ReID research. To address this issue, we construct a large-scale vehicle ReID dataset named VRU (the abbreviation of Vehicle Re-identification based on UAV), which consists of 172,137 images of 15,085 vehicles captured by UAVs, through which each vehicle has multiple images from various viewpoints. Compared with the existing vehicle ReID datasets based on UAVs, the VRU dataset has a larger volume and is fully open source. Since most of the existing vehicle ReID methods are designed for fixed surveillance cameras, it is difficult for these methods to adapt to UAV-based vehicle ReID images with multi-viewpoint and multi-scale characteristics. Thus, this work proposes a Global Attention and full-Scale Network (GASNet) for the vehicle ReID task based on UAV images. To verify the effectiveness of our GASNet, GASNet is compared with the baseline models on the VRU dataset. The experiment results show that GASNet can achieve 97.45% Rank-1 and 98.51% mAP, which outperforms those baselines by 3.43%/2.08% in terms of Rank-1/mAP. Thus, our major contributions can be summarized as follows: (1) the provision of an open-source UAV-based vehicle ReID dataset, (2) the proposal of a state-of-the-art model for UAV-based vehicle ReID.
9 | 
10 | ### Examples
11 | Please download the pre-trained model from this [link](https://pan.baidu.com/s/1XPSgZI92ClK8lcas_v9sRg?pwd=hqj0) and put it in the ./weights/pre_train/ folder.
12 | 
13 | Please download the VRU dataset from this [link](https://github.com/GeoX-Lab/ReID/tree/main/VRU) and put it in the ./datasets/ folder.
14 | 
15 | train
16 | ```
17 | python main_gasnet.py
18 | ```
19 | 
20 | test
21 | ```
22 | python main_gasnet.py --evaluate
23 | ```
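For orientation, a minimal sketch of how the data utilities under `GASNet/reid/utils/data` fit together; the record path, input size, and normalization statistics below are illustrative assumptions rather than the repo's settings:

```python
# Sketch: feed (image_path, vehicle_id, camera_id) records through the
# repo's Preprocessor, RandomIdentitySampler and custom transforms.
from torch.utils.data import DataLoader
from torchvision import transforms as T

from reid.utils.data.preprocessor import Preprocessor
from reid.utils.data.sampler import RandomIdentitySampler
from reid.utils.data.transforms import Random2DTranslation, RandomErasing

dataset = [('datasets/VRU/0001_c1_000001.jpg', 1, 1)]  # hypothetical record

train_transform = T.Compose([
    Random2DTranslation(256, 256),  # enlarge by 1/8, then random crop (PIL stage)
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406],  # ImageNet statistics (assumption)
                std=[0.229, 0.224, 0.225]),
    RandomErasing(probability=0.5),  # tensor stage: erases a random rectangle
])

train_loader = DataLoader(
    Preprocessor(dataset, transform=train_transform),
    sampler=RandomIdentitySampler(dataset, num_instances=4),
    batch_size=32,
)

for imgs, fnames, pids, camids in train_loader:
    break  # each batch groups num_instances images per sampled vehicle ID
```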
24 | ### Citation
25 | If you find this code or dataset useful for your research, please cite our paper.
26 | 
27 | ```Bibtex
28 | @Article{rs14184603,
29 |   AUTHOR = {Lu, Mingming and Xu, Yongchuan and Li, Haifeng},
30 |   TITLE = {Vehicle Re-Identification Based on UAV Viewpoint: Dataset and Method},
31 |   JOURNAL = {Remote Sensing},
32 |   VOLUME = {14},
33 |   YEAR = {2022},
34 |   NUMBER = {18},
35 |   ARTICLE-NUMBER = {4603},
36 |   URL = {https://www.mdpi.com/2072-4292/14/18/4603},
37 |   ISSN = {2072-4292},
38 |   DOI = {10.3390/rs14184603}
39 | }
40 | ```
--------------------------------------------------------------------------------
/VRU/README.md:
--------------------------------------------------------------------------------
1 | ## VRU dataset
2 | 
3 | We use UAVs to construct a vehicle image dataset, named VRU, for the vehicle ReID task. To collect vehicle image data under various scenes, five 'DJI Mavic 2 Pro' UAVs were deployed. A total of 172,137 images containing 15,085 vehicle instances were obtained. The comparison between the VRU dataset and other UAV-based vehicle re-identification datasets is as follows:
4 | 
5 | | Datasets    | VRU     | UAV-VeID | VRAI     |
6 | | ----------- | ------- | -------- | -------- |
7 | | Identities  | 15085   | 4601     | 13022    |
8 | | Images      | 172137  | 41917    | 137613   |
9 | | Multi-view  | $\surd$ | $\surd$  | $\surd$  |
10 | | Multi-scale | $\surd$ | $\surd$  | $\surd$  |
11 | | Weather     | $\surd$ | $\surd$  | $\times$ |
12 | | Lighting    | $\surd$ | $\surd$  | $\times$ |
13 | | Fully open  | $\surd$ | $\times$ | $\times$ |
14 | 
15 | Now, the VRU dataset has been open-sourced and can be downloaded from the [Baidu network disk](https://pan.baidu.com/s/1s5RcJK0wAfg3INYuRjG5zw?pwd=382t) or [Google Drive](https://drive.google.com/file/d/1ESeeYeqbf1TIUChXNcevJK_0fyVGQpgZ/view?usp=share_link).
16 | 
17 | 
18 | ## Citation
19 | 
20 | If you find this dataset useful for your research, please cite our paper.
21 | 
22 | ```
23 | Bibtex
24 | @Article{rs14184603,
25 |   AUTHOR = {Lu, Mingming and Xu, Yongchuan and Li, Haifeng},
26 |   TITLE = {Vehicle Re-Identification Based on UAV Viewpoint: Dataset and Method},
27 |   JOURNAL = {Remote Sensing},
28 |   VOLUME = {14},
29 |   YEAR = {2022},
30 |   NUMBER = {18},
31 |   ARTICLE-NUMBER = {4603},
32 |   URL = {https://www.mdpi.com/2072-4292/14/18/4603},
33 |   ISSN = {2072-4292},
34 |   DOI = {10.3390/rs14184603}
35 | }
36 | ```
37 | 
38 | 
--------------------------------------------------------------------------------