├── .gitignore ├── README.md ├── figures ├── m3l_results.png └── overview.png ├── main.py ├── reid ├── __init__.py ├── datasets │ ├── __init__.py │ ├── cuhk03.py │ ├── cuhknp.py │ ├── dukemtmc.py │ ├── market1501.py │ ├── msmt17v1.py │ └── msmt17v2.py ├── evaluation_metrics │ ├── __init__.py │ ├── classification.py │ └── ranking.py ├── evaluators.py ├── loss │ └── triplet.py ├── models │ ├── IBNMeta.py │ ├── MetaModules.py │ ├── __init__.py │ ├── memory.py │ └── resMeta.py ├── solver │ ├── __init__.py │ ├── build.py │ └── lr_scheduler.py ├── trainers.py └── utils │ ├── __init__.py │ ├── data │ ├── __init__.py │ ├── base_dataset.py │ ├── preprocessor.py │ ├── sampler.py │ └── transforms.py │ ├── faiss_rerank.py │ ├── faiss_utils.py │ ├── logging.py │ ├── meters.py │ ├── osutils.py │ ├── rerank.py │ └── serialization.py └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.pth 3 | __pycache__ 4 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## [Learning to Generalize Unseen Domains via Memory-based Multi-Source Meta-Learning for Person Re-Identification](https://arxiv.org/abs/2012.00417) (CVPR 2021) 2 | 3 | 5 | 6 | 7 | ![](figures/overview.png) 8 | 9 | ### Requirements 10 | 11 | - CUDA>=10.0 12 | - At least three 2080-Ti GPUs 13 | - Other necessary packages listed in [requirements.txt](requirements.txt) 14 | - Training Data 15 | 16 | The model is trained and evaluated on [Market-1501](https://drive.google.com/file/d/0B8-rUzbwVRk0c054eEozWG9COHM/view), [DukeMTMC-reID](https://drive.google.com/file/d/1jjE85dRCMOgRtvJ5RQV9-Afs-2_5dY3O/view), [MSMT17_V1](https://www.pkuvmc.com/dataset.html), [MSMT17_V2](https://www.pkuvmc.com/dataset.html), [CUHK03](https://drive.google.com/file/d/1ILKiSthHm_XVeRQU2ThWNDVSO7lKWAZ_/view?usp=sharing) and [CUHK-NP](https://github.com/zhunzhong07/person-re-ranking/blob/master/CUHK03-NP/README.md) 17 | 18 | 19 | 20 | Unzip all datasets and ensure the file structure is as follow: 21 | 22 | ``` 23 | data 24 | │ 25 | └─── market1501 / dukemtmc / cuhknp / cuhk03 / msmt17v1 / msmt17v2 26 | │ 27 | └─── DukeMTMC-reID / Market-1501-v15.09.15 / detected / cuhk03_release / MSMT17_V1 / MSMT17_V2 28 | ``` 29 | 59 | 60 | 61 | *Note:* 62 | 63 | In default, for CUHK03, we use the old protocol (CUHK03, 26,263 images of 1,367 IDs for training) as the source domain for training the model and the detected subset of the new protocol (CUHK-NP) as the target domain for testing; for MSMT17, we use the MSMT17\_V2 for both training and testing. 64 | We also provide the results of using *the detected subset of CUHK-NP* (7,365 images of 767 IDs for training) and *MSMT17\_V1* for both training and testing, and we recommend using this setting in future studies. 65 | 66 | 71 | 72 | 73 | ### Run 74 | ``` 75 | ARCH=resMeta/IBNMeta 76 | SRC1/SRC2/SRC3=market1501/dukemtmc/cuhk03/cuhknp/msmt17v1/msmt17v2 77 | TARGET=market1501/dukemtmc/cuhknp/msmt17v1/msmt17v2 78 | 79 | # train 80 | CUDA_VISIBLE_DEVICES=0,1,2 python main.py \ 81 | -a $ARCH --BNNeck \ 82 | --dataset_src1 $SRC1 --dataset_src2 $SRC2 --dataset_src3 $SRC3 -d $TARGET \ 83 | --logs-dir $LOG_DIR --data-dir $DATA_DIR 84 | 85 | # evaluate 86 | python main.py \ 87 | -a $ARCH -d $TARGET \ 88 | --logs-dir $LOG_DIR --data-dir $DATA_DIR \ 89 | --evaluate --resume $RESUME 90 | ``` 91 | 92 | ### Results 93 | ![](figures/m3l_results.png) 94 | 95 | You can download the above models in the paper from [Google Drive](https://drive.google.com/drive/folders/1P_1nsTirOQ_8OZU0rgEx9eH1M34v5S0v?usp=sharing). The model is named as `$TARGET_$ARCH.pth.tar`. 96 | 97 | ### Acknowledgments 98 | This repo borrows partially from [MWNet](https://github.com/xjtushujun/meta-weight-net), 99 | [ECN](https://github.com/zhunzhong07/ECN) and 100 | [SpCL](https://github.com/yxgeee/SpCL). 101 | 102 | ### Citation 103 | ``` 104 | @inproceedings{zhao2021learning, 105 | title={Learning to Generalize Unseen Domains via Memory-based Multi-Source Meta-Learning for Person Re-Identification}, 106 | author={Zhao, Yuyang and Zhong, Zhun and Yang, Fengxiang and Luo, Zhiming and Lin, Yaojin and Li, Shaozi and Nicu, Sebe}, 107 | booktitle={CVPR}, 108 | year={2021}, 109 | } 110 | ``` 111 | 112 | ### Contact 113 | Email: yuyangzhao98@gmail.com -------------------------------------------------------------------------------- /figures/m3l_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HeliosZhao/M3L/fa6c2344b2974cb9d0cab9b54e78ce99fb770305/figures/m3l_results.png -------------------------------------------------------------------------------- /figures/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HeliosZhao/M3L/fa6c2344b2974cb9d0cab9b54e78ce99fb770305/figures/overview.png -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import argparse 3 | import os.path as osp 4 | import torch 5 | from torch import nn 6 | from torch.backends import cudnn 7 | from torch.utils.data import DataLoader 8 | import torch.nn.functional as F 9 | 10 | import random 11 | import numpy as np 12 | import sys 13 | import collections 14 | import copy 15 | import time 16 | from datetime import timedelta 17 | from reid import datasets 18 | from reid import models 19 | from reid.models.memory import MemoryClassifier 20 | from reid.trainers import Trainer 21 | from reid.evaluators import Evaluator, extract_features 22 | from reid.utils.data import IterLoader 23 | from reid.utils.data import transforms as T 24 | from reid.utils.data.sampler import RandomMultipleGallerySampler 25 | from reid.utils.data.preprocessor import Preprocessor 26 | from reid.utils.logging import Logger 27 | from reid.utils.serialization import load_checkpoint, save_checkpoint, copy_state_dict 28 | from reid.loss.triplet import TripletLoss 29 | from reid.solver import WarmupMultiStepLR 30 | 31 | start_epoch = best_mAP = 0 32 | 33 | 34 | def get_data(name, data_dir): 35 | root = osp.join(data_dir, name) 36 | dataset = datasets.create(name, root) 37 | return dataset 38 | 39 | 40 | 41 | def get_train_loader(args, dataset, height, width, batch_size, workers, 42 | num_instances, iters, trainset=None): 43 | 44 | normalizer = T.Normalize(mean=[0.485, 0.456, 0.406], 45 | std=[0.229, 0.224, 0.225]) 46 | train_transformer = T.Compose([ 47 | T.Resize((height, width), interpolation=3), 48 | T.RandomHorizontalFlip(p=0.5), 49 | T.Pad(10), 50 | T.RandomCrop((height, width)), 51 | T.ToTensor(), 52 | normalizer]) 53 | 54 | train_set = sorted(dataset.train) if trainset is None else sorted(trainset) 55 | rmgs_flag = num_instances > 0 56 | if rmgs_flag: 57 | sampler = RandomMultipleGallerySampler(train_set, num_instances) 58 | else: 59 | sampler = None 60 | train_loader = IterLoader( 61 | DataLoader(Preprocessor(train_set, root=dataset.images_dir, transform=train_transformer), 62 | batch_size=batch_size, num_workers=workers, sampler=sampler, 63 | shuffle=not rmgs_flag, pin_memory=True, drop_last=True), length=None) 64 | 65 | return train_loader 66 | 67 | 68 | def get_test_loader(dataset, height, width, batch_size, workers, testset=None): 69 | normalizer = T.Normalize(mean=[0.485, 0.456, 0.406], 70 | std=[0.229, 0.224, 0.225]) 71 | test_transformer = T.Compose([ 72 | T.Resize((height, width), interpolation=3), 73 | T.ToTensor(), 74 | normalizer 75 | ]) 76 | 77 | if testset is None: 78 | testset = list(set(dataset.query) | set(dataset.gallery)) 79 | 80 | test_loader = DataLoader( 81 | Preprocessor(testset, root=dataset.images_dir, transform=test_transformer), 82 | batch_size=batch_size, num_workers=workers, 83 | shuffle=False, pin_memory=True) 84 | 85 | return test_loader 86 | 87 | 88 | def create_model(args, num_classes): 89 | model = models.create(args.arch, num_features=args.features, dropout=args.dropout, norm=True,num_classes=num_classes, BNNeck=args.BNNeck) 90 | # use CUDA 91 | model.cuda() 92 | if args.resume: 93 | checkpoint = load_checkpoint(args.resume) 94 | model.copyWeight_eval(checkpoint['state_dict']) 95 | model = nn.DataParallel(model) 96 | return model 97 | 98 | 99 | def main(): 100 | args = parser.parse_args() 101 | 102 | if args.seed is not None: 103 | random.seed(args.seed) 104 | np.random.seed(args.seed) 105 | torch.manual_seed(args.seed) 106 | cudnn.deterministic = True 107 | 108 | main_worker(args) 109 | 110 | 111 | def main_worker(args): 112 | global start_epoch, best_mAP 113 | start_time = time.monotonic() 114 | 115 | cudnn.benchmark = True 116 | 117 | sys.stdout = Logger(osp.join(args.logs_dir, 'log.txt')) 118 | print("==========\nArgs:{}\n==========".format(args)) 119 | 120 | # Create datasets 121 | iters = args.iters if (args.iters > 0) else None 122 | print("==> Load datasets") 123 | dataset_src1 = get_data(args.dataset_src1,args.data_dir) 124 | dataset_src2 = get_data(args.dataset_src2,args.data_dir) 125 | dataset_src3 = get_data(args.dataset_src3,args.data_dir) 126 | dataset = get_data(args.dataset, args.data_dir) 127 | 128 | datasets_src = [dataset_src1, dataset_src2, dataset_src3] 129 | # camMar, camDuke = get_data('marCam', args.data_dir), get_data('dukeCam', args.data_dir) 130 | train_loader_src1 = get_train_loader(args, dataset_src1, args.height, args.width, 131 | args.batch_size, args.workers, args.num_instances, iters) 132 | train_loader_src2 = get_train_loader(args, dataset_src2, args.height, args.width, 133 | args.batch_size, args.workers, args.num_instances, iters) 134 | train_loader_src3 = get_train_loader(args, dataset_src3, args.height, args.width, 135 | args.batch_size, args.workers, args.num_instances, iters) 136 | 137 | test_loader = get_test_loader(dataset, args.height, args.width, args.test_batch_size, args.workers) 138 | 139 | train_loader = [train_loader_src1, train_loader_src2, train_loader_src3] 140 | 141 | num_classes1 = dataset_src1.num_train_pids 142 | num_classes2 = dataset_src2.num_train_pids 143 | num_classes3 = dataset_src3.num_train_pids 144 | num_classes = [num_classes1, num_classes2, num_classes3] 145 | print(' number classes = ', num_classes) 146 | # Create model 147 | model = create_model(args, num_classes=[0,0,0]) 148 | 149 | # Evaluator 150 | evaluator = Evaluator(model) 151 | if args.evaluate: 152 | evaluator.evaluate(test_loader, dataset.query, dataset.gallery, cmc_flag=True) 153 | return 154 | 155 | print("==> Initialize source-domain class centroids and memorys ") 156 | 157 | source_centers_all = [] 158 | memories = [] 159 | 160 | for dataset_i in range(len(datasets_src)): 161 | 162 | dataset_source = datasets_src[dataset_i] 163 | sour_cluster_loader = get_test_loader(dataset_source, args.height, args.width, 164 | args.test_batch_size, args.workers, testset=sorted(dataset_source.train)) 165 | source_features, _ = extract_features(model, sour_cluster_loader, print_freq=50) 166 | sour_fea_dict = collections.defaultdict(list) 167 | 168 | for f, pid, _, _ in sorted(dataset_source.train): 169 | sour_fea_dict[pid].append(source_features[f].unsqueeze(0)) 170 | 171 | source_centers = [torch.cat(sour_fea_dict[pid], 0).mean(0) for pid in sorted(sour_fea_dict.keys())] 172 | source_centers = torch.stack(source_centers, 0) ## pid,2048 173 | source_centers = F.normalize(source_centers, dim=1).cuda() 174 | source_centers_all.append(source_centers) 175 | 176 | curMemo = MemoryClassifier(model.module.num_features, source_centers.shape[0], 177 | temp=args.temp, momentum=args.momentum).cuda() 178 | curMemo.features = source_centers 179 | curMemo.labels = torch.arange(num_classes[dataset_i]).cuda() 180 | curMemo = nn.DataParallel(curMemo) 181 | memories.append(curMemo) 182 | 183 | del source_centers, sour_cluster_loader, sour_fea_dict 184 | 185 | # Optimizer 186 | params = [{"params": [value]} for value in model.module.params() if value.requires_grad] 187 | 188 | optimizer = torch.optim.Adam(params, lr=args.lr, weight_decay=args.weight_decay) 189 | lr_scheduler = WarmupMultiStepLR(optimizer, milestones=[30, 50], gamma=0.1, warmup_factor=0.01, 190 | warmup_iters=10, warmup_method="linear") 191 | 192 | criterion = TripletLoss(args.margin, args.num_instances, False).cuda() 193 | trainer = Trainer(args, model, memories, criterion) 194 | 195 | 196 | 197 | for epoch in range(args.epochs): 198 | # Calculate distance 199 | print('==> start training epoch {} \t ==> learning rate = {}'.format(epoch, optimizer.param_groups[0]['lr'])) 200 | torch.cuda.empty_cache() 201 | trainer.train(epoch, train_loader, optimizer, 202 | print_freq=args.print_freq, train_iters=args.iters) 203 | 204 | if (epoch + 1) % args.eval_step == 0 or (epoch == args.epochs - 1): 205 | mAP = evaluator.evaluate(test_loader, dataset.query, dataset.gallery, cmc_flag=False) 206 | is_best = (mAP > best_mAP) 207 | best_mAP = max(mAP, best_mAP) 208 | save_checkpoint({ 209 | 'state_dict': model.state_dict(), 210 | 'epoch': epoch + 1, 211 | 'best_mAP': best_mAP, 212 | }, is_best, fpath=osp.join(args.logs_dir, 'checkpoint.pth.tar')) 213 | 214 | print('\n * Finished epoch {:3d} model mAP: {:5.1%} best: {:5.1%}{}\n'. 215 | format(epoch, mAP, best_mAP, ' *' if is_best else '')) 216 | 217 | lr_scheduler.step() 218 | 219 | print('==> Test with the best model:') 220 | checkpoint = load_checkpoint(osp.join(args.logs_dir, 'model_best.pth.tar')) 221 | model.load_state_dict(checkpoint['state_dict']) 222 | evaluator.evaluate(test_loader, dataset.query, dataset.gallery, cmc_flag=True) 223 | end_time = time.monotonic() 224 | print('Total running time: ', timedelta(seconds=end_time - start_time)) 225 | 226 | 227 | if __name__ == '__main__': 228 | parser = argparse.ArgumentParser(description="Self-paced contrastive learning on unsupervised re-ID") 229 | # data 230 | parser.add_argument('-d', '--dataset', type=str, default='market1501', 231 | choices=datasets.names()) 232 | parser.add_argument('--dataset_src1', type=str, default='cuhknp', 233 | choices=datasets.names()) 234 | parser.add_argument('--dataset_src2', type=str, default='dukemtmc', 235 | choices=datasets.names()) 236 | parser.add_argument('--dataset_src3', type=str, default='msmt17v1', 237 | choices=datasets.names()) 238 | parser.add_argument('-b', '--batch-size', type=int, default=64) 239 | parser.add_argument('--test-batch-size', type=int, default=256) 240 | parser.add_argument('-j', '--workers', type=int, default=4) 241 | parser.add_argument('--height', type=int, default=256, help="input height") 242 | parser.add_argument('--width', type=int, default=128, help="input width") 243 | parser.add_argument('--num-instances', type=int, default=4, 244 | help="each minibatch consist of " 245 | "(batch_size // num_instances) identities, and " 246 | "each identity has num_instances instances, " 247 | "default: 4") 248 | 249 | # model 250 | parser.add_argument('-a', '--arch', type=str, default='resMeta', 251 | choices=models.names()) 252 | parser.add_argument('--features', type=int, default=0) 253 | parser.add_argument('--dropout', type=float, default=0) 254 | parser.add_argument('--BNNeck', action='store_true', 255 | help="use triplet and BNNeck") 256 | parser.add_argument('--momentum', type=float, default=0.2, 257 | help="update momentum for the hybrid memory") 258 | parser.add_argument('--BNtype', type=str, default='sample', 259 | help=" MetaBN type ") 260 | ##loss 261 | parser.add_argument('--margin', type=float, default=0.3, 262 | help="margin of the triplet loss, default: 0.3") 263 | 264 | # optimizer 265 | parser.add_argument('--lr', type=float, default=0.00035, 266 | help="learning rate") 267 | parser.add_argument('--weight-decay', type=float, default=5e-4) 268 | parser.add_argument('--epochs', type=int, default=60) 269 | parser.add_argument('--iters', type=int, default=200) 270 | parser.add_argument('--step-size', type=int, default=20) 271 | # training configs 272 | parser.add_argument('--seed', type=int, default=1) 273 | parser.add_argument('--print-freq', type=int, default=5) 274 | parser.add_argument('--eval-step', type=int, default=1) 275 | parser.add_argument('--temp', type=float, default=0.05, 276 | help="temperature for scaling contrastive loss") 277 | # path 278 | working_dir = osp.dirname(osp.abspath(__file__)) 279 | parser.add_argument('--data-dir', type=str, metavar='PATH', 280 | default=osp.join(working_dir, 'data')) 281 | parser.add_argument('--logs-dir', type=str, metavar='PATH', 282 | default=osp.join(working_dir, 'logs')) 283 | parser.add_argument('--resume', type=str, default='') 284 | parser.add_argument('--evaluate', action='store_true', 285 | help="evaluation only") 286 | main() 287 | -------------------------------------------------------------------------------- /reid/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from . import datasets 4 | from . import evaluation_metrics 5 | from . import models 6 | from . import utils 7 | from . import evaluators 8 | from . import trainers 9 | 10 | __version__ = '0.1.0' 11 | -------------------------------------------------------------------------------- /reid/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import warnings 3 | 4 | from .dukemtmc import DukeMTMC 5 | from .market1501 import Market1501 6 | from .msmt17v2 import MSMT17_V2 7 | from .msmt17v1 import MSMT17_V1 8 | from .cuhk03 import CUHK03 9 | from .cuhknp import CUHK_NP 10 | 11 | __factory = { 12 | 'market1501': Market1501, 13 | 'dukemtmc': DukeMTMC, 14 | 'msmt17v2': MSMT17_V2, 15 | 'msmt17v1': MSMT17_V1, 16 | 'cuhk03': CUHK03, 17 | 'cuhknp': CUHK_NP 18 | } 19 | 20 | 21 | def names(): 22 | return sorted(__factory.keys()) 23 | 24 | 25 | def create(name, root, *args, **kwargs): 26 | """ 27 | Create a dataset instance. 28 | 29 | Parameters 30 | ---------- 31 | name : str 32 | The dataset name. 33 | root : str 34 | The path to the dataset directory. 35 | split_id : int, optional 36 | The index of data split. Default: 0 37 | num_val : int or float, optional 38 | When int, it means the number of validation identities. When float, 39 | it means the proportion of validation to all the trainval. Default: 100 40 | download : bool, optional 41 | If True, will download the dataset. Default: False 42 | """ 43 | if name not in __factory: 44 | raise KeyError("Unknown dataset:", name) 45 | return __factory[name](root, *args, **kwargs) 46 | 47 | 48 | def get_dataset(name, root, *args, **kwargs): 49 | warnings.warn("get_dataset is deprecated. Use create instead.") 50 | return create(name, root, *args, **kwargs) 51 | -------------------------------------------------------------------------------- /reid/datasets/cuhk03.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os.path as osp 3 | import json 4 | import numpy as np 5 | 6 | def read_json(fpath): 7 | with open(fpath, 'r') as f: 8 | obj = json.load(f) 9 | return obj 10 | 11 | 12 | def _pluck(identities, indices, relabel=False): 13 | ret = [] 14 | for index, pid in enumerate(indices): 15 | pid_images = identities[pid] 16 | for camid, cam_images in enumerate(pid_images): 17 | for fname in cam_images: 18 | name = osp.splitext(fname)[0] 19 | x, y, _ = map(int, name.split('_')) 20 | assert pid == x and camid == y 21 | if relabel: 22 | ret.append((fname, index, camid, 0)) 23 | else: 24 | ret.append((fname, pid, camid, 0)) 25 | return ret 26 | 27 | def _pluck_gallery(identities, indices, relabel=False): 28 | ret = [] 29 | for index, pid in enumerate(indices): 30 | pid_images = identities[pid] 31 | for camid, cam_images in enumerate(pid_images): 32 | if len(cam_images[:-1])==0: 33 | for fname in cam_images: 34 | name = osp.splitext(fname)[0] 35 | x, y, _ = map(int, name.split('_')) 36 | assert pid == x and camid == y 37 | if relabel: 38 | ret.append((fname, index, camid)) 39 | else: 40 | ret.append((fname, pid, camid)) 41 | else: 42 | for fname in cam_images[:-1]: 43 | name = osp.splitext(fname)[0] 44 | x, y, _ = map(int, name.split('_')) 45 | assert pid == x and camid == y 46 | if relabel: 47 | ret.append((fname, index, camid, 0)) 48 | else: 49 | ret.append((fname, pid, camid, 0)) 50 | return ret 51 | 52 | def _pluck_query(identities, indices, relabel=False): 53 | ret = [] 54 | for index, pid in enumerate(indices): 55 | pid_images = identities[pid] 56 | for camid, cam_images in enumerate(pid_images): 57 | for fname in cam_images[-1:]: 58 | name = osp.splitext(fname)[0] 59 | x, y, _ = map(int, name.split('_')) 60 | assert pid == x and camid == y 61 | if relabel: 62 | ret.append((fname, index, camid, 0)) 63 | else: 64 | ret.append((fname, pid, camid, 0)) 65 | return ret 66 | 67 | 68 | class CUHK03(object): 69 | def __init__(self, root, split_id=0, verbose=True): 70 | super(CUHK03, self).__init__() 71 | self.root = osp.join(root,'cuhk03_release') 72 | self.split_id = split_id 73 | self.verbose = verbose 74 | self.meta = None 75 | self.split = None 76 | self.train, self.val, self.trainval = [], [], [] 77 | self.query, self.gallery = [], [] 78 | self.num_train_ids, self.num_val_ids, self.num_trainval_ids = 0, 0, 0 79 | self._check_integrity() 80 | self.load() 81 | @property 82 | def images_dir(self): 83 | return osp.join(self.root, 'images') 84 | 85 | def load(self): 86 | splits = read_json(osp.join(self.root, 'splits.json')) 87 | if self.split_id >= len(splits): 88 | raise ValueError("split_id exceeds total splits {}" 89 | .format(len(splits))) 90 | self.split = splits[self.split_id] 91 | 92 | # Randomly split train / val 93 | trainval_pids = np.asarray(self.split['trainval']) 94 | np.random.shuffle(trainval_pids) 95 | num = len(trainval_pids) 96 | 97 | train_pids = sorted(trainval_pids) 98 | 99 | self.meta = read_json(osp.join(self.root, 'meta.json')) 100 | identities = self.meta['identities'] 101 | self.train = _pluck(identities, train_pids, relabel=True) 102 | self.query = _pluck_query(identities, self.split['query']) 103 | #self.gallery = _pluck(identities, self.split['gallery']) 104 | self.gallery = _pluck_gallery(identities, self.split['gallery']) 105 | self.num_train_pids = len(train_pids) 106 | 107 | if self.verbose: 108 | print(self.__class__.__name__, "dataset loaded") 109 | print(" subset | # ids | # images") 110 | print(" ---------------------------") 111 | print(" train | {:5d} | {:8d}" 112 | .format(self.num_train_pids, len(self.train))) 113 | print(" trainval | {:5d} | {:8d}" 114 | .format(self.num_trainval_ids, len(self.trainval))) 115 | print(" query | {:5d} | {:8d}" 116 | .format(len(self.split['query']), len(self.query))) 117 | print(" gallery | {:5d} | {:8d}" 118 | .format(len(self.split['gallery']), len(self.gallery))) 119 | 120 | def _check_integrity(self): 121 | return osp.isdir(osp.join(self.root, 'images')) and \ 122 | osp.isfile(osp.join(self.root, 'meta.json')) and \ 123 | osp.isfile(osp.join(self.root, 'splits.json')) 124 | -------------------------------------------------------------------------------- /reid/datasets/cuhknp.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os.path as osp 3 | from glob import glob 4 | import re 5 | 6 | 7 | class CUHK_NP(object): 8 | 9 | def __init__(self, root): 10 | self.root = osp.join(root,'detected') 11 | self.train_path = 'bounding_box_train' 12 | self.gallery_path = 'bounding_box_test' 13 | self.query_path = 'query' 14 | self.train, self.query, self.gallery = [], [], [] 15 | self.num_train_pids, self.num_query_pids, self.num_gallery_pids = 0, 0, 0 16 | self.has_time_info = False 17 | self.load() 18 | 19 | @property 20 | def images_dir(self): 21 | return None 22 | 23 | def preprocess(self, path, relabel=True): 24 | pattern = re.compile(r'([-\d]+)_c(\d)') 25 | all_pids = {} 26 | ret = [] 27 | fpaths = sorted(glob(osp.join(self.root, path, '*.png'))) 28 | for fpath in fpaths: 29 | fname = osp.basename(fpath) 30 | pid, cam = map(int, pattern.search(fname).groups()) 31 | if pid == -1: continue 32 | if relabel: 33 | if pid not in all_pids: 34 | all_pids[pid] = len(all_pids) 35 | else: 36 | if pid not in all_pids: 37 | all_pids[pid] = pid 38 | pid = all_pids[pid] 39 | cam -= 1 40 | ret.append((fpath, pid, cam, 0)) 41 | return ret, int(len(all_pids)) 42 | 43 | def load(self): 44 | self.train, self.num_train_pids = self.preprocess(self.train_path) 45 | self.gallery, self.num_gallery_pids = self.preprocess(self.gallery_path, False) 46 | self.query, self.num_query_pids = self.preprocess(self.query_path, False) 47 | 48 | print(self.__class__.__name__, "dataset loaded") 49 | print(" subset | # ids | # images") 50 | print(" ---------------------------") 51 | print(" train | {:5d} | {:8d}" 52 | .format(self.num_train_pids, len(self.train))) 53 | print(" query | {:5d} | {:8d}" 54 | .format(self.num_query_pids, len(self.query))) 55 | print(" gallery | {:5d} | {:8d}" 56 | .format(self.num_gallery_pids, len(self.gallery))) 57 | -------------------------------------------------------------------------------- /reid/datasets/dukemtmc.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os.path as osp 3 | import glob 4 | import re 5 | import urllib 6 | import zipfile 7 | 8 | from ..utils.data import BaseImageDataset 9 | from ..utils.osutils import mkdir_if_missing 10 | from ..utils.serialization import write_json 11 | 12 | 13 | class DukeMTMC(BaseImageDataset): 14 | """ 15 | DukeMTMC-reID 16 | Reference: 17 | 1. Ristani et al. Performance Measures and a Data Set for Multi-Target, Multi-Camera Tracking. ECCVW 2016. 18 | 2. Zheng et al. Unlabeled Samples Generated by GAN Improve the Person Re-identification Baseline in vitro. ICCV 2017. 19 | URL: https://github.com/layumi/DukeMTMC-reID_evaluation 20 | 21 | Dataset statistics: 22 | # identities: 1404 (train + query) 23 | # images:16522 (train) + 2228 (query) + 17661 (gallery) 24 | # cameras: 8 25 | """ 26 | dataset_dir = '.' 27 | 28 | def __init__(self, root, verbose=True, **kwargs): 29 | super(DukeMTMC, self).__init__() 30 | self.dataset_dir = osp.join(root, self.dataset_dir) 31 | self.dataset_url = 'http://vision.cs.duke.edu/DukeMTMC/data/misc/DukeMTMC-reID.zip' 32 | self.train_dir = osp.join(self.dataset_dir, 'DukeMTMC-reID/bounding_box_train') 33 | self.query_dir = osp.join(self.dataset_dir, 'DukeMTMC-reID/query') 34 | self.gallery_dir = osp.join(self.dataset_dir, 'DukeMTMC-reID/bounding_box_test') 35 | 36 | self._download_data() 37 | self._check_before_run() 38 | 39 | train = self._process_dir(self.train_dir, relabel=True) 40 | query = self._process_dir(self.query_dir, relabel=False) 41 | gallery = self._process_dir(self.gallery_dir, relabel=False) 42 | 43 | if verbose: 44 | print("=> DukeMTMC-reID loaded") 45 | self.print_dataset_statistics(train, query, gallery) 46 | 47 | self.train = train 48 | self.query = query 49 | self.gallery = gallery 50 | 51 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 52 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 53 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 54 | 55 | def _download_data(self): 56 | if osp.exists(self.dataset_dir): 57 | print("This dataset has been downloaded.") 58 | return 59 | 60 | print("Creating directory {}".format(self.dataset_dir)) 61 | mkdir_if_missing(self.dataset_dir) 62 | fpath = osp.join(self.dataset_dir, osp.basename(self.dataset_url)) 63 | 64 | print("Downloading DukeMTMC-reID dataset") 65 | urllib.request.urlretrieve(self.dataset_url, fpath) 66 | 67 | print("Extracting files") 68 | zip_ref = zipfile.ZipFile(fpath, 'r') 69 | zip_ref.extractall(self.dataset_dir) 70 | zip_ref.close() 71 | 72 | def _check_before_run(self): 73 | """Check if all files are available before going deeper""" 74 | if not osp.exists(self.dataset_dir): 75 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 76 | if not osp.exists(self.train_dir): 77 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 78 | if not osp.exists(self.query_dir): 79 | raise RuntimeError("'{}' is not available".format(self.query_dir)) 80 | if not osp.exists(self.gallery_dir): 81 | raise RuntimeError("'{}' is not available".format(self.gallery_dir)) 82 | 83 | def _process_dir(self, dir_path, relabel=False): 84 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 85 | pattern = re.compile(r'([-\d]+)_c(\d)') 86 | 87 | pid_container = set() 88 | for img_path in img_paths: 89 | pid, _ = map(int, pattern.search(img_path).groups()) 90 | pid_container.add(pid) 91 | pid2label = {pid: label for label, pid in enumerate(pid_container)} # index and their corres pid 92 | 93 | dataset = [] 94 | for img_path in img_paths: 95 | pid, camid = map(int, pattern.search(img_path).groups()) 96 | assert 1 <= camid <= 8 97 | camid -= 1 # index starts from 0 98 | if relabel: pid = pid2label[pid] 99 | dataset.append((img_path, pid, camid, 2)) 100 | 101 | return dataset 102 | -------------------------------------------------------------------------------- /reid/datasets/market1501.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os.path as osp 3 | import glob 4 | import re 5 | import urllib 6 | import zipfile 7 | import random 8 | from ..utils.data import BaseImageDataset 9 | from ..utils.osutils import mkdir_if_missing 10 | from ..utils.serialization import write_json 11 | 12 | 13 | class Market1501(BaseImageDataset): 14 | """ 15 | Market1501 16 | Reference: 17 | Zheng et al. Scalable Person Re-identification: A Benchmark. ICCV 2015. 18 | URL: http://www.liangzheng.org/Project/project_reid.html 19 | 20 | Dataset statistics: 21 | # identities: 1501 (+1 for background) 22 | # images: 12936 (train) + 3368 (query) + 15913 (gallery) 23 | """ 24 | dataset_dir = 'Market-1501-v15.09.15' 25 | 26 | def __init__(self, root, verbose=True, **kwargs): 27 | super(Market1501, self).__init__() 28 | self.dataset_dir = osp.join(root, self.dataset_dir) 29 | self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train') 30 | self.query_dir = osp.join(self.dataset_dir, 'query') 31 | self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test') 32 | 33 | self._check_before_run() 34 | 35 | train = self._process_dir(self.train_dir, relabel=True) 36 | query = self._process_dir(self.query_dir, relabel=False) 37 | gallery = self._process_dir(self.gallery_dir, relabel=False) 38 | 39 | if verbose: 40 | print("=> Market1501 loaded") 41 | self.print_dataset_statistics(train, query, gallery) 42 | 43 | self.train = train 44 | self.query = query 45 | self.gallery = gallery 46 | 47 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 48 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 49 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 50 | 51 | self.query_recollect = self._recollect_test_data(self.query, 640) 52 | self.gallery_recollect = self._recollect_test_data(self.gallery, 640) 53 | 54 | def _recollect_test_data(self, partition, sample_num): 55 | recollect_part = partition 56 | random.shuffle(recollect_part) 57 | recollect_part_sampled = recollect_part[:sample_num] 58 | return recollect_part_sampled 59 | 60 | def _check_before_run(self): 61 | """Check if all files are available before going deeper""" 62 | if not osp.exists(self.dataset_dir): 63 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 64 | if not osp.exists(self.train_dir): 65 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 66 | if not osp.exists(self.query_dir): 67 | raise RuntimeError("'{}' is not available".format(self.query_dir)) 68 | if not osp.exists(self.gallery_dir): 69 | raise RuntimeError("'{}' is not available".format(self.gallery_dir)) 70 | 71 | def _process_dir(self, dir_path, relabel=False): 72 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 73 | pattern = re.compile(r'([-\d]+)_c(\d)') 74 | 75 | pid_container = set() 76 | for img_path in img_paths: 77 | pid, _ = map(int, pattern.search(img_path).groups()) 78 | if pid == -1: continue # junk images are just ignored 79 | pid_container.add(pid) 80 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 81 | 82 | dataset = [] 83 | for img_path in img_paths: 84 | pid, camid = map(int, pattern.search(img_path).groups()) 85 | if pid == -1: continue # junk images are just ignored 86 | assert 0 <= pid <= 1501 # pid == 0 means background 87 | assert 1 <= camid <= 6 88 | camid -= 1 # index starts from 0 89 | if relabel: pid = pid2label[pid] 90 | dataset.append((img_path, pid, camid, 1)) 91 | 92 | return dataset 93 | -------------------------------------------------------------------------------- /reid/datasets/msmt17v1.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os.path as osp 3 | import tarfile 4 | 5 | import glob 6 | import re 7 | import urllib 8 | import zipfile 9 | 10 | from ..utils.osutils import mkdir_if_missing 11 | from ..utils.serialization import write_json 12 | 13 | 14 | def _pluck_msmt(list_file, subdir, pattern=re.compile(r'([-\d]+)_c([-\d]+)_([-\d]+)')): 15 | with open(list_file, 'r') as f: 16 | lines = f.readlines() 17 | ret = [] 18 | pids = [] 19 | for line in lines: 20 | line = line.strip() 21 | fname = line.split(' ')[0] 22 | pid, cam, _ = map(int, pattern.search(osp.basename(fname)).groups()) 23 | if pid not in pids: 24 | pids.append(pid) 25 | ret.append((osp.join(subdir, fname), pid, cam, 3)) 26 | return ret, pids 27 | 28 | 29 | class Dataset_MSMT(object): 30 | def __init__(self, root): 31 | self.root = root 32 | self.train, self.val, self.trainval = [], [], [] 33 | self.query, self.gallery = [], [] 34 | self.num_train_ids, self.num_val_ids, self.num_trainval_ids = 0, 0, 0 35 | 36 | @property 37 | def images_dir(self): 38 | return osp.join(self.root, 'MSMT17_V1') 39 | 40 | def load(self, verbose=True): 41 | exdir = osp.join(self.root, 'MSMT17_V1') 42 | self.train, train_pids = _pluck_msmt(osp.join(exdir, 'list_train.txt'), 'bounding_box_train') 43 | self.val, val_pids = _pluck_msmt(osp.join(exdir, 'list_val.txt'), 'bounding_box_train') 44 | self.train = self.train + self.val 45 | self.query, query_pids = _pluck_msmt(osp.join(exdir, 'list_query.txt'), 'query') 46 | self.gallery, gallery_pids = _pluck_msmt(osp.join(exdir, 'list_gallery.txt'), 'bounding_box_test') 47 | self.num_train_pids = len(list(set(train_pids).union(set(val_pids)))) 48 | 49 | if verbose: 50 | print(self.__class__.__name__, "dataset loaded") 51 | print(" subset | # ids | # images") 52 | print(" ---------------------------") 53 | print(" train | {:5d} | {:8d}" 54 | .format(self.num_train_pids, len(self.train))) 55 | print(" query | {:5d} | {:8d}" 56 | .format(len(query_pids), len(self.query))) 57 | print(" gallery | {:5d} | {:8d}" 58 | .format(len(gallery_pids), len(self.gallery))) 59 | 60 | 61 | class MSMT17_V1(Dataset_MSMT): 62 | 63 | def __init__(self, root, split_id=0, download=True): 64 | super(MSMT17_V1, self).__init__(root) 65 | 66 | if download: 67 | self.download() 68 | 69 | self.load() 70 | 71 | def download(self): 72 | raw_dir = osp.join(self.root) 73 | mkdir_if_missing(raw_dir) 74 | 75 | # Download the raw zip file 76 | fpath = osp.join(raw_dir, 'MSMT17_V1') 77 | if osp.isdir(fpath): 78 | print("Using downloaded file: " + fpath) 79 | else: 80 | raise RuntimeError("Please download the dataset manually to {}".format(fpath)) -------------------------------------------------------------------------------- /reid/datasets/msmt17v2.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os.path as osp 3 | 4 | 5 | class MSMT17_V2(object): 6 | 7 | def __init__(self, root, combine_all=False): 8 | 9 | self.images_dir = osp.join(root,'MSMT17_V2') 10 | self.combine_all = combine_all 11 | self.train_path = 'mask_train_v2' 12 | self.test_path = 'mask_test_v2' 13 | self.train_list_file = 'list_train.txt' 14 | self.val_list_file = 'list_val.txt' 15 | self.gallery_list_file = 'list_gallery.txt' 16 | self.query_list_file = 'list_query.txt' 17 | self.gallery_path = self.test_path 18 | self.query_path = self.test_path 19 | self.train, self.val, self.query, self.gallery = [], [], [], [] 20 | self.num_train_pids, self.num_val_ids, self.num_query_ids, self.num_gallery_ids = 0, 0, 0, 0 21 | self.has_time_info = False 22 | self.load() 23 | 24 | def preprocess(self, list_file, subpath): 25 | with open(osp.join(self.images_dir, list_file), 'r') as txt: 26 | lines = txt.readlines() 27 | 28 | data = [] 29 | all_pids = {} 30 | 31 | for img_idx, img_info in enumerate(lines): 32 | img_path, pid = img_info.split(' ') 33 | pid = int(pid) # no need to relabel 34 | if pid not in all_pids: 35 | all_pids[pid] = pid 36 | camid = int(img_path.split('_')[2]) - 1 # index starts from 0 37 | data.append((osp.join(subpath,img_path), pid, camid, 3)) 38 | return data, int(len(all_pids)) 39 | 40 | def load(self): 41 | self.train, self.num_train_pids = self.preprocess(self.train_list_file,self.train_path) 42 | self.val, self.num_val_ids = self.preprocess(self.val_list_file, self.train_path) 43 | self.gallery, self.num_gallery_ids = self.preprocess(self.gallery_list_file,self.test_path) 44 | self.query, self.num_query_ids = self.preprocess(self.query_list_file,self.test_path) 45 | 46 | self.train += self.val 47 | if self.combine_all: 48 | for item in self.train: 49 | item[0] = osp.join(self.train_path, item[0]) 50 | for item in self.gallery: 51 | item[0] = osp.join(self.gallery_path, item[0]) 52 | item[1] += self.num_train_pids 53 | for item in self.query: 54 | item[0] = osp.join(self.query_path, item[0]) 55 | item[1] += self.num_train_pids 56 | self.train += self.gallery 57 | self.train += self.query 58 | self.num_train_pids += self.num_gallery_ids 59 | self.train_path = '' 60 | 61 | print(self.__class__.__name__, "dataset loaded") 62 | print(" subset | # ids | # images") 63 | print(" ---------------------------") 64 | print(" train | {:5d} | {:8d}" 65 | .format(self.num_train_pids, len(self.train))) 66 | print(" query | {:5d} | {:8d}" 67 | .format(self.num_query_ids, len(self.query))) 68 | print(" gallery | {:5d} | {:8d}" 69 | .format(self.num_gallery_ids, len(self.gallery))) 70 | -------------------------------------------------------------------------------- /reid/evaluation_metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .classification import accuracy 4 | from .ranking import cmc, mean_ap 5 | 6 | __all__ = [ 7 | 'accuracy', 8 | 'cmc', 9 | 'mean_ap' 10 | ] 11 | -------------------------------------------------------------------------------- /reid/evaluation_metrics/classification.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | from ..utils import to_torch 5 | 6 | 7 | def accuracy(output, target, topk=(1,)): 8 | with torch.no_grad(): 9 | output, target = to_torch(output), to_torch(target) 10 | maxk = max(topk) 11 | batch_size = target.size(0) 12 | 13 | _, pred = output.topk(maxk, 1, True, True) 14 | pred = pred.t() 15 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 16 | 17 | ret = [] 18 | for k in topk: 19 | correct_k = correct[:k].view(-1).float().sum(dim=0, keepdim=True) 20 | ret.append(correct_k.mul_(1. / batch_size)) 21 | return ret 22 | -------------------------------------------------------------------------------- /reid/evaluation_metrics/ranking.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from collections import defaultdict 3 | 4 | import numpy as np 5 | from sklearn.metrics import average_precision_score 6 | 7 | from ..utils import to_numpy 8 | 9 | 10 | def _unique_sample(ids_dict, num): 11 | mask = np.zeros(num, dtype=np.bool) 12 | for _, indices in ids_dict.items(): 13 | i = np.random.choice(indices) 14 | mask[i] = True 15 | return mask 16 | 17 | 18 | def cmc(distmat, query_ids=None, gallery_ids=None, 19 | query_cams=None, gallery_cams=None, topk=100, 20 | separate_camera_set=False, 21 | single_gallery_shot=False, 22 | first_match_break=False): 23 | distmat = to_numpy(distmat) 24 | m, n = distmat.shape 25 | # Fill up default values 26 | if query_ids is None: 27 | query_ids = np.arange(m) 28 | if gallery_ids is None: 29 | gallery_ids = np.arange(n) 30 | if query_cams is None: 31 | query_cams = np.zeros(m).astype(np.int32) 32 | if gallery_cams is None: 33 | gallery_cams = np.ones(n).astype(np.int32) 34 | # Ensure numpy array 35 | query_ids = np.asarray(query_ids) 36 | gallery_ids = np.asarray(gallery_ids) 37 | query_cams = np.asarray(query_cams) 38 | gallery_cams = np.asarray(gallery_cams) 39 | # Sort and find correct matches 40 | indices = np.argsort(distmat, axis=1) #, dtype='int32') 41 | matches = (gallery_ids[indices] == query_ids[:, np.newaxis]) 42 | # Compute CMC for each query 43 | ret = np.zeros(topk) 44 | num_valid_queries = 0 45 | for i in range(m): 46 | # Filter out the same id and same camera 47 | valid = ((gallery_ids[indices[i]] != query_ids[i]) | 48 | (gallery_cams[indices[i]] != query_cams[i])) 49 | if separate_camera_set: 50 | # Filter out samples from same camera 51 | valid &= (gallery_cams[indices[i]] != query_cams[i]) 52 | if not np.any(matches[i, valid]): continue 53 | if single_gallery_shot: 54 | repeat = 10 55 | gids = gallery_ids[indices[i][valid]] 56 | inds = np.where(valid)[0] 57 | ids_dict = defaultdict(list) 58 | for j, x in zip(inds, gids): 59 | ids_dict[x].append(j) 60 | else: 61 | repeat = 1 62 | for _ in range(repeat): 63 | if single_gallery_shot: 64 | # Randomly choose one instance for each id 65 | sampled = (valid & _unique_sample(ids_dict, len(valid))) 66 | index = np.nonzero(matches[i, sampled])[0] 67 | else: 68 | index = np.nonzero(matches[i, valid])[0] 69 | delta = 1. / (len(index) * repeat) 70 | for j, k in enumerate(index): 71 | if k - j >= topk: break 72 | if first_match_break: 73 | ret[k - j] += 1 74 | break 75 | ret[k - j] += delta 76 | num_valid_queries += 1 77 | if num_valid_queries == 0: 78 | raise RuntimeError("No valid query") 79 | return ret.cumsum() / num_valid_queries 80 | 81 | 82 | def mean_ap(distmat, query_ids=None, gallery_ids=None, 83 | query_cams=None, gallery_cams=None): 84 | distmat = to_numpy(distmat) 85 | m, n = distmat.shape 86 | # Fill up default values 87 | if query_ids is None: 88 | query_ids = np.arange(m) 89 | if gallery_ids is None: 90 | gallery_ids = np.arange(n) 91 | if query_cams is None: 92 | query_cams = np.zeros(m).astype(np.int32) 93 | if gallery_cams is None: 94 | gallery_cams = np.ones(n).astype(np.int32) 95 | # Ensure numpy array 96 | query_ids = np.asarray(query_ids) 97 | gallery_ids = np.asarray(gallery_ids) 98 | query_cams = np.asarray(query_cams) 99 | gallery_cams = np.asarray(gallery_cams) 100 | # Sort and find correct matches 101 | indices = np.argsort(distmat, axis=1) #, dtype='int32') 102 | matches = (gallery_ids[indices] == query_ids[:, np.newaxis]) 103 | # Compute AP for each query 104 | aps = [] 105 | for i in range(m): 106 | # Filter out the same id and same camera 107 | valid = ((gallery_ids[indices[i]] != query_ids[i]) | 108 | (gallery_cams[indices[i]] != query_cams[i])) 109 | y_true = matches[i, valid] 110 | y_score = -distmat[i][indices[i]][valid] 111 | if not np.any(y_true): continue 112 | aps.append(average_precision_score(y_true, y_score)) 113 | if len(aps) == 0: 114 | raise RuntimeError("No valid query") 115 | return np.mean(aps) 116 | -------------------------------------------------------------------------------- /reid/evaluators.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import time 3 | import collections 4 | from collections import OrderedDict 5 | import numpy as np 6 | import torch 7 | import random 8 | import copy 9 | 10 | from .evaluation_metrics import cmc, mean_ap 11 | from .utils.meters import AverageMeter 12 | from .utils.rerank import re_ranking 13 | from .utils import to_torch 14 | 15 | 16 | def extract_cnn_feature(model, inputs): 17 | inputs = to_torch(inputs).cuda() 18 | outputs = model(inputs) 19 | # print(outputs.shape) 20 | outputs = outputs.data.cpu() 21 | return outputs 22 | 23 | 24 | def extract_features(model, data_loader, print_freq=50): 25 | model.eval() 26 | batch_time = AverageMeter() 27 | data_time = AverageMeter() 28 | 29 | features = OrderedDict() 30 | labels = OrderedDict() 31 | 32 | end = time.time() 33 | with torch.no_grad(): 34 | for i, (imgs, fnames, pids, _,_, _) in enumerate(data_loader): 35 | data_time.update(time.time() - end) 36 | 37 | outputs = extract_cnn_feature(model, imgs) 38 | for fname, output, pid in zip(fnames, outputs, pids): 39 | features[fname] = output 40 | labels[fname] = pid 41 | 42 | batch_time.update(time.time() - end) 43 | end = time.time() 44 | 45 | if (i + 1) % print_freq == 0: 46 | print('Extract Features: [{}/{}]\t' 47 | 'Time {:.3f} ({:.3f})\t' 48 | 'Data {:.3f} ({:.3f})\t' 49 | .format(i + 1, len(data_loader), 50 | batch_time.val, batch_time.avg, 51 | data_time.val, data_time.avg)) 52 | 53 | return features, labels 54 | 55 | 56 | def pairwise_distance(features, query=None, gallery=None): 57 | if query is None and gallery is None: 58 | n = len(features) 59 | x = torch.cat(list(features.values())) 60 | x = x.view(n, -1) 61 | dist_m = torch.pow(x, 2).sum(dim=1, keepdim=True) * 2 62 | dist_m = dist_m.expand(n, n) - 2 * torch.mm(x, x.t()) 63 | return dist_m 64 | 65 | x = torch.cat([features[f].unsqueeze(0) for f, _, _, _ in query], 0) 66 | y = torch.cat([features[f].unsqueeze(0) for f, _, _, _ in gallery], 0) 67 | m, n = x.size(0), y.size(0) 68 | x = x.view(m, -1) 69 | y = y.view(n, -1) 70 | dist_m = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) + \ 71 | torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t() 72 | dist_m.addmm_(1, -2, x, y.t()) 73 | return dist_m, x.numpy(), y.numpy() 74 | 75 | 76 | def evaluate_all(query_features, gallery_features, distmat, query=None, gallery=None, 77 | query_ids=None, gallery_ids=None, 78 | query_cams=None, gallery_cams=None, 79 | cmc_topk=(1, 5, 10), cmc_flag=False): 80 | if query is not None and gallery is not None: 81 | query_ids = [pid for _, pid, _ , _ in query] 82 | gallery_ids = [pid for _, pid, _, _ in gallery] 83 | query_cams = [cam for _, _, cam , _ in query] 84 | gallery_cams = [cam for _, _, cam , _ in gallery] 85 | else: 86 | assert (query_ids is not None and gallery_ids is not None 87 | and query_cams is not None and gallery_cams is not None) 88 | 89 | # Compute mean AP 90 | mAP = mean_ap(distmat, query_ids, gallery_ids, query_cams, gallery_cams) 91 | print('Mean AP: {:4.1%}'.format(mAP)) 92 | 93 | if not cmc_flag: 94 | return mAP 95 | 96 | cmc_configs = { 97 | 'market1501': dict(separate_camera_set=False, 98 | single_gallery_shot=False, 99 | first_match_break=True), } 100 | cmc_scores = {name: cmc(distmat, query_ids, gallery_ids, 101 | query_cams, gallery_cams, **params) 102 | for name, params in cmc_configs.items()} 103 | 104 | print('CMC Scores:') 105 | for k in cmc_topk: 106 | print(' top-{:<4}{:12.1%}'.format(k, cmc_scores['market1501'][k - 1])) 107 | return cmc_scores['market1501'], mAP 108 | 109 | 110 | class Evaluator(object): 111 | def __init__(self, model): 112 | super(Evaluator, self).__init__() 113 | self.model = model 114 | 115 | def evaluate(self, data_loader, query, gallery, cmc_flag=False, rerank=False): 116 | features, _ = extract_features(self.model, data_loader) 117 | distmat, query_features, gallery_features = pairwise_distance(features, query, gallery) 118 | results = evaluate_all(query_features, gallery_features, distmat, query=query, gallery=gallery, 119 | cmc_flag=cmc_flag) 120 | 121 | if not rerank: 122 | return results 123 | 124 | print('Applying person re-ranking ...') 125 | distmat_qq = pairwise_distance(features, query, query) 126 | distmat_gg = pairwise_distance(features, gallery, gallery) 127 | distmat = re_ranking(distmat.numpy(), distmat_qq.numpy(), distmat_gg.numpy()) 128 | return evaluate_all(query_features, gallery_features, distmat, query=query, gallery=gallery, cmc_flag=cmc_flag) 129 | -------------------------------------------------------------------------------- /reid/loss/triplet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | from torch import nn 5 | from torch.autograd import Variable 6 | from torch.nn import functional as F 7 | from scipy.stats import norm 8 | 9 | import numpy as np 10 | 11 | class TripletLoss(nn.Module): 12 | def __init__(self, margin=0, num_instances=0, use_semi=True): 13 | super(TripletLoss, self).__init__() 14 | self.margin = margin 15 | self.ranking_loss = nn.MarginRankingLoss(margin=self.margin) 16 | self.K = num_instances 17 | self.use_semi = use_semi 18 | 19 | def forward(self, inputs, targets): 20 | n = inputs.size(0) 21 | P = n/self.K 22 | t0 = 20.0 23 | t1 = 40.0 24 | 25 | # Compute pairwise distance, replace by the official when merged 26 | dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n) 27 | dist = dist + dist.t() 28 | dist.addmm_(1, -2, inputs, inputs.t()) 29 | dist = dist.clamp(min=1e-12).sqrt() # for numerical stability 30 | # For each anchor, find the hardest positive and negative 31 | mask = targets.expand(n, n).eq(targets.expand(n, n).t()) 32 | dist_ap, dist_an = [], [] 33 | if self.use_semi: 34 | for i in range(P): 35 | for j in range(self.K): 36 | neg_examples = dist[i*self.K+j][mask[i*self.K+j] == 0] 37 | for pair in range(j+1, self.K): 38 | ap = dist[i*self.K+j][i*self.K+pair] 39 | dist_ap.append(ap) 40 | dist_an.append(neg_examples.min()) 41 | else: 42 | for i in range(n): 43 | dist_ap.append(torch.max(dist[i][mask[i]])) 44 | dist_an.append(torch.min(dist[i][mask[i] == 0])) 45 | dist_ap = [dist_ap[i].unsqueeze(0) for i in range(len(dist_ap))] 46 | dist_ap = torch.cat(dist_ap) 47 | dist_an = [dist_an[i].unsqueeze(0) for i in range(len(dist_ap))] 48 | dist_an = torch.cat(dist_an) 49 | 50 | # Compute ranking hinge loss 51 | y = dist_an.data.new() 52 | y.resize_as_(dist_an.data) 53 | y.fill_(1) 54 | y = Variable(y) 55 | loss = self.ranking_loss(dist_an, dist_ap, y) 56 | prec = (dist_an.data > dist_ap.data).sum() * 1. / y.size(0) 57 | return loss 58 | -------------------------------------------------------------------------------- /reid/models/IBNMeta.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | from torch.nn import init 7 | from torch.autograd import Variable 8 | from torchvision.models import resnet50, resnet34 9 | import math 10 | import os 11 | from .MetaModules import * 12 | 13 | 14 | class IBN(nn.Module): 15 | def __init__(self, planes): 16 | super(IBN, self).__init__() 17 | half1 = int(planes / 2) 18 | self.half = half1 19 | half2 = planes - half1 20 | self.IN = MetaInstanceNorm2d(half1, affine=True) 21 | self.BN = MetaBatchNorm2d(half2) 22 | 23 | def forward(self, x): 24 | split = torch.split(x, self.half, 1) 25 | out1 = self.IN(split[0].contiguous()) 26 | out2 = self.BN(split[1].contiguous()) 27 | out = torch.cat((out1, out2), 1) 28 | return out 29 | 30 | 31 | class Bottleneck_IBN(nn.Module): 32 | expansion = 4 33 | 34 | def __init__(self, inplanes, planes, ibn=False, stride=1, downsample=None): 35 | super(Bottleneck_IBN, self).__init__() 36 | self.conv1 = MetaConv2d(inplanes, planes, kernel_size=1, bias=False) 37 | if ibn: 38 | self.bn1 = IBN(planes) 39 | else: 40 | self.bn1 = MetaBatchNorm2d(planes) 41 | self.conv2 = MetaConv2d(planes, planes, kernel_size=3, stride=stride, 42 | padding=1, bias=False) 43 | self.bn2 = MetaBatchNorm2d(planes) 44 | self.conv3 = MetaConv2d(planes, planes * self.expansion, kernel_size=1, bias=False) 45 | self.bn3 = MetaBatchNorm2d(planes * self.expansion) 46 | self.relu = nn.ReLU(inplace=True) 47 | self.downsample = downsample 48 | self.stride = stride 49 | 50 | def forward(self, x): 51 | residual = x 52 | 53 | out = self.conv1(x) 54 | out = self.bn1(out) 55 | out = self.relu(out) 56 | 57 | out = self.conv2(out) 58 | out = self.bn2(out) 59 | out = self.relu(out) 60 | 61 | out = self.conv3(out) 62 | out = self.bn3(out) 63 | 64 | if self.downsample is not None: 65 | residual = self.downsample(x) 66 | 67 | out += residual 68 | out = self.relu(out) 69 | 70 | return out 71 | 72 | 73 | class MetaResNet_IBN_a_base(MetaModule): 74 | 75 | def __init__(self, layers, block=Bottleneck_IBN): 76 | scale = 64 77 | self.inplanes = scale 78 | super(MetaResNet_IBN_a_base, self).__init__() 79 | self.conv1 = MetaConv2d(3, scale, kernel_size=7, stride=2, padding=3, 80 | bias=False) 81 | self.bn1 = MetaBatchNorm2d(scale) 82 | # self.relu = nn.ReLU(inplace=True) 83 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 84 | self.layer1 = self._make_layer(block, scale, layers[0]) 85 | self.layer2 = self._make_layer(block, scale * 2, layers[1], stride=2) 86 | self.layer3 = self._make_layer(block, scale * 4, layers[2], stride=2) 87 | self.layer4 = self._make_layer(block, scale * 8, layers[3], stride=2) 88 | # self.avgpool = nn.AvgPool2d(7) 89 | # self.fc = nn.Linear(scale * 8 * block.expansion, num_classes) 90 | 91 | for m in self.modules(): 92 | if isinstance(m, MetaConv2d): 93 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 94 | m.weight.data.normal_(0, math.sqrt(2. / n)) 95 | elif isinstance(m, MetaBatchNorm2d): 96 | m.weight.data.fill_(1) 97 | m.bias.data.zero_() 98 | elif isinstance(m, MetaInstanceNorm2d): 99 | m.weight.data.fill_(1) 100 | m.bias.data.zero_() 101 | 102 | def _make_layer(self, block, planes, blocks, stride=1): 103 | downsample = None 104 | if stride != 1 or self.inplanes != planes * block.expansion: 105 | downsample = nn.Sequential( 106 | MetaConv2d(self.inplanes, planes * block.expansion, 107 | kernel_size=1, stride=stride, bias=False), 108 | MetaBatchNorm2d(planes * block.expansion), 109 | ) 110 | 111 | layers = [] 112 | ibn = True 113 | if planes == 512: 114 | ibn = False 115 | layers.append(block(self.inplanes, planes, ibn, stride, downsample)) 116 | self.inplanes = planes * block.expansion 117 | for i in range(1, blocks): 118 | layers.append(block(self.inplanes, planes, ibn)) 119 | 120 | return nn.Sequential(*layers) 121 | 122 | def forward(self, x): 123 | x = self.conv1(x) 124 | x = self.bn1(x) 125 | # x = self.relu(x) 126 | x = self.maxpool(x) 127 | 128 | x = self.layer1(x) 129 | x = self.layer2(x) 130 | x = self.layer3(x) 131 | x = self.layer4(x) 132 | 133 | # x = self.avgpool(x) 134 | # x = x.view(x.size(0), -1) 135 | # x = self.fc(x) 136 | 137 | return x 138 | 139 | 140 | class MetaIBN(MetaModule): 141 | def __init_with_imagenet(self, baseModel): 142 | model = resnet50(pretrained=True) 143 | del model.fc 144 | baseModel.copyWeight(model.state_dict()) 145 | 146 | def getBase(self): 147 | baseModel = MetaResNet_IBN_a_base([3, 4, 6, 3]) 148 | self.__init_with_imagenet(baseModel) 149 | return baseModel 150 | 151 | def __init__(self, num_features=0, dropout=0, cut_at_pooling=False, norm=True, num_classes=[0,0,0], BNNeck=False): 152 | super(MetaIBN, self).__init__() 153 | self.num_features = num_features 154 | self.dropout = dropout 155 | self.cut_at_pooling = cut_at_pooling 156 | self.num_classes1 = num_classes[0] 157 | self.num_classes2 = num_classes[1] 158 | self.num_classes3 = num_classes[2] 159 | self.has_embedding = num_features > 0 160 | self.norm = norm 161 | self.BNNeck = BNNeck 162 | if self.dropout > 0: 163 | self.drop = nn.Dropout(self.dropout) 164 | # Construct base (pretrained) resnet 165 | self.base = self.getBase() 166 | self.base.layer4[0].conv2.stride = (1, 1) 167 | self.base.layer4[0].downsample[0].stride = (1, 1) 168 | self.gap = nn.AdaptiveAvgPool2d(1) 169 | out_planes = 2048 170 | if self.has_embedding: 171 | self.feat = MetaLinear(out_planes, self.num_features) 172 | init.kaiming_normal_(self.feat.weight, mode='fan_out') 173 | init.constant_(self.feat.bias, 0) 174 | else: 175 | # Change the num_features to CNN output channels 176 | self.num_features = out_planes 177 | 178 | self.feat_bn = MixUpBatchNorm1d(self.num_features) 179 | init.constant_(self.feat_bn.weight, 1) 180 | init.constant_(self.feat_bn.bias, 0) 181 | 182 | self.reset_IN() 183 | 184 | def reset_IN(self): 185 | for m in self.modules(): 186 | if isinstance(m, MetaInstanceNorm2d): 187 | if m.affine: 188 | init.constant_(m.weight, 1) 189 | init.constant_(m.bias, 0) 190 | 191 | def forward(self, x, MTE='', save_index=0): 192 | x= self.base(x) 193 | x = self.gap(x) 194 | x = x.view(x.size(0), -1) 195 | 196 | if self.cut_at_pooling: 197 | return x 198 | 199 | if self.has_embedding: 200 | bn_x = self.feat_bn(self.feat(x)) 201 | else: 202 | bn_x= self.feat_bn(x, MTE, save_index) 203 | tri_features = x 204 | 205 | if self.training is False: 206 | bn_x = F.normalize(bn_x) 207 | return bn_x 208 | 209 | if isinstance(bn_x, list): 210 | output = [] 211 | for bnfeature in bn_x: 212 | if self.norm: 213 | bnfeature = F.normalize(bnfeature) 214 | output.append(bnfeature) 215 | if self.BNNeck: 216 | return output, tri_features 217 | else: 218 | return output 219 | 220 | if self.norm: 221 | bn_x = F.normalize(bn_x) 222 | elif self.has_embedding: 223 | bn_x = F.relu(bn_x) 224 | 225 | if self.dropout > 0: 226 | bn_x = self.drop(bn_x) 227 | 228 | if self.BNNeck: 229 | return bn_x, tri_features 230 | else: 231 | return bn_x 232 | 233 | 234 | -------------------------------------------------------------------------------- /reid/models/MetaModules.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | from torch.nn import init 7 | from torch.autograd import Variable 8 | from torchvision.models import resnet50, resnet34 9 | import math 10 | import os 11 | import numpy as np 12 | 13 | 14 | def to_var(x, requires_grad=True): 15 | if torch.cuda.is_available(): x = x.cuda() 16 | return Variable(x, requires_grad=requires_grad) 17 | 18 | 19 | class MetaModule(nn.Module): 20 | # adopted from: Adrien Ecoffet https://github.com/AdrienLE 21 | def params(self): 22 | for name, param in self.named_params(self): 23 | yield param 24 | 25 | def named_leaves(self): 26 | return [] 27 | 28 | def named_submodules(self): 29 | return [] 30 | 31 | def named_params(self, curr_module=None, memo=None, prefix=''): 32 | if memo is None: 33 | memo = set() 34 | 35 | if hasattr(curr_module, 'named_leaves'): 36 | for name, p in curr_module.named_leaves(): 37 | if p is not None and p not in memo: 38 | memo.add(p) 39 | yield prefix + ('.' if prefix else '') + name, p 40 | else: 41 | for name, p in curr_module._parameters.items(): 42 | if p is not None and p not in memo: 43 | memo.add(p) 44 | yield prefix + ('.' if prefix else '') + name, p 45 | 46 | for mname, module in curr_module.named_children(): 47 | submodule_prefix = prefix + ('.' if prefix else '') + mname 48 | for name, p in self.named_params(module, memo, submodule_prefix): 49 | yield name, p 50 | 51 | def update_params(self, lr_inner, source_params=None, 52 | solver='sgd', beta1=0.9, beta2=0.999, weight_decay=5e-4): 53 | if solver == 'sgd': 54 | for tgt, src in zip(self.named_params(self), source_params): 55 | name_t, param_t = tgt 56 | grad = src if src is not None else 0 57 | tmp = param_t - lr_inner * grad 58 | self.set_param(self, name_t, tmp) 59 | elif solver == 'adam': 60 | for tgt, gradVal in zip(self.named_params(self), source_params): 61 | name_t, param_t = tgt 62 | exp_avg, exp_avg_sq = torch.zeros_like(param_t.data), \ 63 | torch.zeros_like(param_t.data) 64 | bias_correction1 = 1 - beta1 65 | bias_correction2 = 1 - beta2 66 | gradVal.add_(weight_decay, param_t) 67 | exp_avg.mul_(beta1).add_(1 - beta1, gradVal) 68 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, gradVal, gradVal) 69 | exp_avg_sq.add_(1e-8) # to avoid possible nan in backward 70 | denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(1e-8) 71 | step_size = lr_inner / bias_correction1 72 | newParam = param_t.addcdiv(-step_size, exp_avg, denom) 73 | self.set_param(self, name_t, newParam) 74 | 75 | def setParams(self, params): 76 | for tgt, param in zip(self.named_params(self), params): 77 | name_t, _ = tgt 78 | self.set_param(self, name_t, param) 79 | 80 | def set_param(self, curr_mod, name, param): 81 | if '.' in name: 82 | n = name.split('.') 83 | module_name = n[0] 84 | rest = '.'.join(n[1:]) 85 | for name, mod in curr_mod.named_children(): 86 | if module_name == name: 87 | self.set_param(mod, rest, param) 88 | break 89 | else: 90 | setattr(curr_mod, name, param) 91 | 92 | def setBN(self, inPart, name, param): 93 | if '.' in name: 94 | part = name.split('.') 95 | self.setBN(getattr(inPart, part[0]), '.'.join(part[1:]), param) 96 | else: 97 | setattr(inPart, name, param) 98 | 99 | def detach_params(self): 100 | for name, param in self.named_params(self): 101 | self.set_param(self, name, param.detach()) 102 | 103 | def copyModel(self, newModel, same_var=False): 104 | # copy meta model to meta model 105 | tarName = list(map(lambda v: v, newModel.state_dict().keys())) 106 | 107 | # requires_grad 108 | partName, partW = list(map(lambda v: v[0], newModel.named_params(newModel))), list( 109 | map(lambda v: v[1], newModel.named_params(newModel))) # new model's weight 110 | 111 | metaName, metaW = list(map(lambda v: v[0], self.named_params(self))), list( 112 | map(lambda v: v[1], self.named_params(self))) 113 | bnNames = list(set(tarName) - set(partName)) 114 | 115 | # copy vars 116 | for name, param in zip(metaName, partW): 117 | if not same_var: 118 | param = to_var(param.data.clone(), requires_grad=True) 119 | self.set_param(self, name, param) 120 | # copy training mean var 121 | tarName = newModel.state_dict() 122 | for name in bnNames: 123 | param = to_var(tarName[name], requires_grad=False) 124 | self.setBN(self, name, param) 125 | 126 | def copyWeight(self, modelW): 127 | # copy state_dict to buffers 128 | curName = list(map(lambda v: v[0], self.named_params(self))) 129 | tarNames = set() 130 | for name in modelW.keys(): 131 | # print(name) 132 | if name.startswith("module"): 133 | tarNames.add(".".join(name.split(".")[1:])) 134 | else: 135 | tarNames.add(name) 136 | bnNames = list(tarNames - set(curName)) 137 | for tgt in self.named_params(self): 138 | name_t, param_t = tgt 139 | # print(name_t) 140 | module_name_t = 'module.' + name_t 141 | if name_t in modelW: 142 | param = to_var(modelW[name_t], requires_grad=True) 143 | self.set_param(self, name_t, param) 144 | elif module_name_t in modelW: 145 | param = to_var(modelW['module.' + name_t], requires_grad=True) 146 | self.set_param(self, name_t, param) 147 | else: 148 | continue 149 | 150 | 151 | def copyWeight_eval(self, modelW): 152 | # copy state_dict to buffers 153 | curName = list(map(lambda v: v[0], self.named_params(self))) 154 | tarNames = set() 155 | for name in modelW.keys(): 156 | # print(name) 157 | if name.startswith("module"): 158 | tarNames.add(".".join(name.split(".")[1:])) 159 | else: 160 | tarNames.add(name) 161 | bnNames = list(tarNames - set(curName)) ## in BN resMeta bnNames only contains running var/mean 162 | for tgt in self.named_params(self): 163 | name_t, param_t = tgt 164 | # print(name_t) 165 | module_name_t = 'module.' + name_t 166 | if name_t in modelW: 167 | param = to_var(modelW[name_t], requires_grad=True) 168 | self.set_param(self, name_t, param) 169 | elif module_name_t in modelW: 170 | param = to_var(modelW['module.' + name_t], requires_grad=True) 171 | self.set_param(self, name_t, param) 172 | else: 173 | continue 174 | 175 | for name in bnNames: 176 | try: 177 | param = to_var(modelW[name], requires_grad=False) 178 | except: 179 | param = to_var(modelW['module.' + name], requires_grad=False) 180 | self.setBN(self, name, param) 181 | 182 | 183 | class MetaLinear(MetaModule): 184 | def __init__(self, *args, **kwargs): 185 | super().__init__() 186 | ignore = nn.Linear(*args, **kwargs) 187 | self.in_features = args[0] 188 | self.out_features = args[1] 189 | 190 | self.register_buffer('weight', to_var(ignore.weight.data, requires_grad=True)) 191 | self.register_buffer('bias', to_var(ignore.bias.data, requires_grad=True) if ignore.bias is not None else None) 192 | 193 | def forward(self, x): 194 | return F.linear(x, self.weight, self.bias) 195 | 196 | def named_leaves(self): 197 | return [('weight', self.weight), ('bias', self.bias)] 198 | 199 | 200 | class MetaConv2d(MetaModule): 201 | def __init__(self, *args, **kwargs): 202 | super().__init__() 203 | ignore = nn.Conv2d(*args, **kwargs) 204 | self.in_channels = ignore.in_channels 205 | self.out_channels = ignore.out_channels 206 | self.stride = ignore.stride 207 | self.padding = ignore.padding 208 | self.dilation = ignore.dilation 209 | self.groups = ignore.groups 210 | self.kernel_size = ignore.kernel_size 211 | 212 | self.register_buffer('weight', to_var(ignore.weight.data, requires_grad=True)) 213 | 214 | if ignore.bias is not None: 215 | self.register_buffer('bias', to_var(ignore.bias.data, requires_grad=True)) 216 | else: 217 | self.register_buffer('bias', None) 218 | 219 | def forward(self, x): 220 | return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) 221 | 222 | def named_leaves(self): 223 | return [('weight', self.weight), ('bias', self.bias)] 224 | 225 | 226 | class MetaBatchNorm2d(MetaModule): 227 | def __init__(self, *args, **kwargs): 228 | super().__init__() 229 | ignore = nn.BatchNorm2d(*args, **kwargs) 230 | 231 | self.num_features = ignore.num_features 232 | self.eps = ignore.eps 233 | self.momentum = ignore.momentum 234 | self.affine = ignore.affine 235 | self.track_running_stats = ignore.track_running_stats 236 | 237 | if self.affine: 238 | self.register_buffer('weight', to_var(ignore.weight.data, requires_grad=True)) 239 | self.register_buffer('bias', to_var(ignore.bias.data, requires_grad=True)) 240 | 241 | if self.track_running_stats: 242 | self.register_buffer('running_mean', torch.zeros(self.num_features)) 243 | self.register_buffer('running_var', torch.ones(self.num_features)) 244 | self.register_buffer('num_batches_tracked', torch.LongTensor([0]).squeeze()) 245 | else: 246 | self.register_buffer('running_mean', None) 247 | self.register_buffer('running_var', None) 248 | self.register_buffer('num_batches_tracked', None) 249 | 250 | def forward(self, x): 251 | val2 = self.weight.sum() 252 | res = F.batch_norm(x, self.running_mean, self.running_var, self.weight, self.bias, 253 | self.training or not self.track_running_stats, self.momentum, self.eps) 254 | return res 255 | 256 | def named_leaves(self): 257 | return [('weight', self.weight), ('bias', self.bias)] 258 | 259 | 260 | class MetaBatchNorm1d(MetaModule): 261 | def __init__(self, *args, **kwargs): 262 | super().__init__() 263 | ignore = nn.BatchNorm1d(*args, **kwargs) 264 | 265 | self.num_features = ignore.num_features 266 | self.eps = ignore.eps 267 | self.momentum = ignore.momentum 268 | self.affine = ignore.affine 269 | self.track_running_stats = ignore.track_running_stats 270 | 271 | if self.affine: 272 | self.register_buffer('weight', to_var(ignore.weight.data, requires_grad=True)) 273 | self.register_buffer('bias', to_var(ignore.bias.data, requires_grad=True)) 274 | 275 | if self.track_running_stats: 276 | self.register_buffer('running_mean', torch.zeros(self.num_features)) 277 | self.register_buffer('running_var', torch.ones(self.num_features)) 278 | self.register_buffer('num_batches_tracked', torch.LongTensor([0]).squeeze()) 279 | else: 280 | self.register_buffer('running_mean', None) 281 | self.register_buffer('running_var', None) 282 | self.register_buffer('num_batches_tracked', None) 283 | 284 | def forward(self, x): 285 | return F.batch_norm(x, self.running_mean, self.running_var, self.weight, self.bias, 286 | self.training or not self.track_running_stats, self.momentum, self.eps) 287 | ## meta test set this one to False self.training or not self.track_running_stats 288 | def named_leaves(self): 289 | return [('weight', self.weight), ('bias', self.bias)] 290 | 291 | 292 | class MetaInstanceNorm2d(MetaModule): 293 | def __init__(self, *args, **kwargs): 294 | super().__init__() 295 | ignore = nn.InstanceNorm2d(*args, **kwargs) 296 | 297 | self.num_features = ignore.num_features 298 | self.eps = ignore.eps 299 | self.momentum = ignore.momentum 300 | self.affine = ignore.affine 301 | self.track_running_stats = ignore.track_running_stats 302 | 303 | if self.affine: 304 | self.register_buffer('weight', to_var(ignore.weight.data, requires_grad=True)) 305 | self.register_buffer('bias', to_var(ignore.bias.data, requires_grad=True)) 306 | else: 307 | self.register_buffer('weight', None) 308 | self.register_buffer('bias', None) 309 | 310 | if self.track_running_stats: 311 | self.register_buffer('running_mean', torch.zeros(self.num_features)) 312 | self.register_buffer('running_var', torch.ones(self.num_features)) 313 | self.register_buffer('num_batches_tracked', torch.LongTensor([0]).squeeze()) 314 | else: 315 | self.register_buffer('running_mean', None) 316 | self.register_buffer('running_var', None) 317 | self.register_buffer('num_batches_tracked', None) 318 | 319 | self.reset_parameters() 320 | 321 | def reset_parameters(self) -> None: 322 | if self.affine: 323 | init.constant_(self.weight, 1) 324 | init.constant_(self.bias, 0) 325 | 326 | def forward(self, x): 327 | 328 | res = F.instance_norm(x, self.running_mean, self.running_var, self.weight, self.bias, 329 | self.training or not self.track_running_stats, self.momentum, self.eps) 330 | return res 331 | 332 | def named_leaves(self): 333 | return [('weight', self.weight), ('bias', self.bias)] 334 | 335 | class MixUpBatchNorm1d(MetaBatchNorm1d): 336 | def __init__(self, num_features, eps=1e-5, momentum=0.1, 337 | affine=True, track_running_stats=True): 338 | super(MixUpBatchNorm1d, self).__init__( 339 | num_features, eps, momentum, affine, track_running_stats) 340 | 341 | self.register_buffer('meta_mean1', torch.zeros(self.num_features)) 342 | self.register_buffer('meta_var1', torch.zeros(self.num_features)) 343 | self.register_buffer('meta_mean2', torch.zeros(self.num_features)) 344 | self.register_buffer('meta_var2', torch.zeros(self.num_features)) 345 | self.device_count = torch.cuda.device_count() 346 | 347 | def forward(self, input, MTE='', save_index=0): 348 | exponential_average_factor = 0.0 349 | 350 | if self.training and self.track_running_stats: 351 | if self.num_batches_tracked is not None: 352 | self.num_batches_tracked += 1 353 | if self.momentum is None: # use cumulative moving average 354 | exponential_average_factor = 1.0 / float(self.num_batches_tracked) 355 | else: # use exponential moving average 356 | exponential_average_factor = self.momentum 357 | 358 | # calculate running estimates 359 | if self.training: 360 | if MTE == 'sample': 361 | from torch.distributions.normal import Normal 362 | Distri1 = Normal(self.meta_mean1, self.meta_var1) 363 | Distri2 = Normal(self.meta_mean2, self.meta_var2) 364 | sample1 = Distri1.sample([input.size(0), ]) 365 | sample2 = Distri2.sample([input.size(0), ]) 366 | lam = np.random.beta(1., 1.) 367 | inputmix1 = lam * sample1 + (1-lam) * input 368 | inputmix2 = lam * sample2 + (1-lam) * input 369 | 370 | mean1 = inputmix1.mean(dim=0) 371 | var1 = inputmix1.var(dim=0, unbiased=False) 372 | mean2 = inputmix2.mean(dim=0) 373 | var2 = inputmix2.var(dim=0, unbiased=False) 374 | 375 | output1 = (inputmix1 - mean1[None, :]) / (torch.sqrt(var1[None, :] + self.eps)) 376 | output2 = (inputmix2 - mean2[None, :]) / (torch.sqrt(var2[None, :] + self.eps)) 377 | if self.affine: 378 | output1 = output1 * self.weight[None, :] + self.bias[None, :] 379 | output2 = output2 * self.weight[None, :] + self.bias[None, :] 380 | return [output1, output2] 381 | 382 | else: 383 | mean = input.mean(dim=0) 384 | # use biased var in train 385 | var = input.var(dim=0, unbiased=False) 386 | n = input.numel() / input.size(1) 387 | 388 | with torch.no_grad(): 389 | running_mean = exponential_average_factor * mean \ 390 | + (1 - exponential_average_factor) * self.running_mean 391 | # update running_var with unbiased var 392 | running_var = exponential_average_factor * var * n / (n - 1) \ 393 | + (1 - exponential_average_factor) * self.running_var 394 | self.running_mean.copy_(running_mean) 395 | self.running_var.copy_(running_var) 396 | if save_index == 1: 397 | self.meta_mean1.copy_(mean) 398 | self.meta_var1.copy_(var) 399 | elif save_index == 2: 400 | self.meta_mean2.copy_(mean) 401 | self.meta_var2.copy_(var) 402 | 403 | else: 404 | mean = self.running_mean 405 | var = self.running_var 406 | 407 | input = (input - mean[None, :]) / (torch.sqrt(var[None, :] + self.eps)) 408 | if self.affine: 409 | input = input * self.weight[None, :] + self.bias[None, :] 410 | 411 | return input 412 | -------------------------------------------------------------------------------- /reid/models/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .IBNMeta import MetaIBN 4 | from .resMeta import MetaResNet 5 | 6 | __factory = { 7 | 'resMeta': MetaResNet, 8 | 'IBNMeta': MetaIBN, 9 | 10 | } 11 | 12 | 13 | def names(): 14 | return sorted(__factory.keys()) 15 | 16 | 17 | def create(name, *args, **kwargs): 18 | """ 19 | Create a model instance. 20 | 21 | Parameters 22 | ---------- 23 | name : str 24 | Model name. Can be one of 'inception', 'resnet18', 'resnet34', 25 | 'resnet50', 'resnet101', and 'resnet152'. 26 | pretrained : bool, optional 27 | Only applied for 'resnet*' models. If True, will use ImageNet pretrained 28 | model. Default: True 29 | cut_at_pooling : bool, optional 30 | If True, will cut the model before the last global pooling layer and 31 | ignore the remaining kwargs. Default: False 32 | num_features : int, optional 33 | If positive, will append a Linear layer after the global pooling layer, 34 | with this number of output units, followed by a BatchNorm layer. 35 | Otherwise these layers will not be appended. Default: 256 for 36 | 'inception', 0 for 'resnet*' 37 | norm : bool, optional 38 | If True, will normalize the feature to be unit L2-norm for each sample. 39 | Otherwise will append a ReLU layer after the above Linear layer if 40 | num_features > 0. Default: False 41 | dropout : float, optional 42 | If positive, will append a Dropout layer with this dropout rate. 43 | Default: 0 44 | num_classes : int, optional 45 | If positive, will append a Linear layer at the end as the classifier 46 | with this number of output units. Default: 0 47 | """ 48 | if name not in __factory: 49 | raise KeyError("Unknown model:", name) 50 | return __factory[name](*args, **kwargs) 51 | -------------------------------------------------------------------------------- /reid/models/memory.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.nn import init 4 | from torch import nn, autograd 5 | import numpy as np 6 | 7 | class MC(autograd.Function): 8 | 9 | @staticmethod 10 | def forward(ctx, inputs, indexes, features, momentum): 11 | ctx.features = features 12 | ctx.momentum = momentum 13 | ctx.save_for_backward(inputs, indexes) 14 | outputs = inputs.mm(ctx.features.t()) 15 | 16 | return outputs 17 | 18 | @staticmethod 19 | def backward(ctx, grad_outputs): 20 | inputs, indexes = ctx.saved_tensors 21 | grad_inputs = None 22 | if ctx.needs_input_grad[0]: 23 | grad_inputs = grad_outputs.mm(ctx.features) 24 | 25 | return grad_inputs, None, None, None 26 | 27 | 28 | def mc(inputs, indexes, features, momentum=0.5): 29 | return MC.apply(inputs, indexes, features, torch.Tensor([momentum]).to(inputs.device)) 30 | 31 | 32 | class MemoryClassifier(nn.Module): 33 | def __init__(self, num_features, num_samples, temp=0.05, momentum=0.2): 34 | super(MemoryClassifier, self).__init__() 35 | self.num_features = num_features 36 | self.num_samples = num_samples 37 | self.momentum = momentum 38 | self.temp = temp 39 | 40 | self.register_buffer('features', torch.zeros(num_samples, num_features)) 41 | self.register_buffer('labels', torch.zeros(num_samples).long()) 42 | 43 | def MomentumUpdate(self, inputs, indexes): 44 | # momentum update 45 | for x, y in zip(inputs, indexes): 46 | self.features[y] = self.momentum * self.features[y] + (1. - self.momentum) * x 47 | self.features[y] = self.features[y] / self.features[y].norm() 48 | 49 | 50 | def forward(self, inputs, indexes): 51 | 52 | sim = mc(inputs, indexes, self.features, self.momentum) ## B * C 53 | 54 | sim = sim / self.temp 55 | 56 | loss = F.cross_entropy(sim, indexes) 57 | return loss 58 | 59 | 60 | -------------------------------------------------------------------------------- /reid/models/resMeta.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | from torch.nn import init 7 | from torch.autograd import Variable 8 | from torchvision.models import resnet50, resnet34 9 | import math 10 | import os 11 | import numpy as np 12 | from .MetaModules import * 13 | 14 | 15 | class Bottleneck(nn.Module): 16 | expansion = 4 17 | 18 | def __init__(self, inplanes, planes, stride=1, downsample=None): 19 | super(Bottleneck, self).__init__() 20 | self.conv1 = MetaConv2d(inplanes, planes, kernel_size=1, bias=False) 21 | self.bn1 = MetaBatchNorm2d(planes) 22 | self.conv2 = MetaConv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 23 | self.bn2 = MetaBatchNorm2d(planes) 24 | self.conv3 = MetaConv2d(planes, planes * 4, kernel_size=1, bias=False) 25 | self.bn3 = MetaBatchNorm2d(planes * 4) 26 | self.relu = nn.ReLU(inplace=True) 27 | self.downsample = downsample 28 | self.stride = stride 29 | 30 | def forward(self, x): 31 | residual = x 32 | 33 | out = self.conv1(x) 34 | out = self.bn1(out) 35 | out = self.relu(out) 36 | 37 | out = self.conv2(out) 38 | out = self.bn2(out) 39 | out = self.relu(out) 40 | 41 | out = self.conv3(out) 42 | out = self.bn3(out) 43 | 44 | if self.downsample is not None: 45 | residual = self.downsample(x) 46 | 47 | out += residual 48 | out = self.relu(out) 49 | 50 | return out 51 | 52 | class BasicBlock(nn.Module): 53 | expansion = 1 54 | 55 | def __init__(self, inplanes, planes, stride=1, downsample=None): 56 | super(BasicBlock, self).__init__() 57 | 58 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 59 | self.conv1 = MetaConv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 60 | self.bn1 = MetaBatchNorm2d(planes) 61 | self.relu = nn.ReLU(inplace=True) 62 | self.conv2 = MetaConv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 63 | self.bn2 = MetaBatchNorm2d(planes) 64 | self.downsample = downsample 65 | self.stride = stride 66 | 67 | def forward(self, x): 68 | identity = x 69 | out = self.conv1(x) 70 | out = self.bn1(out) 71 | out = self.relu(out) 72 | 73 | out = self.conv2(out) 74 | out = self.bn2(out) 75 | 76 | if self.downsample is not None: 77 | identity = self.downsample(x) 78 | 79 | out += identity 80 | out = self.relu(out) 81 | 82 | return out 83 | 84 | class MetaResNetBase(MetaModule): 85 | def __init__(self, layers, block=Bottleneck): 86 | super(MetaResNetBase, self).__init__() 87 | self.inplanes = 64 88 | self.conv1 = MetaConv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 89 | self.bn1 = MetaBatchNorm2d(64) 90 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 91 | self.layer1 = self._make_layer(block, 64, layers[0]) 92 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 93 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 94 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 95 | 96 | def _make_layer(self, block, planes, blocks, stride=1): 97 | downsample = None 98 | if stride != 1 or self.inplanes != planes * block.expansion: 99 | downsample = nn.Sequential( 100 | MetaConv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False), 101 | MetaBatchNorm2d(planes * block.expansion), 102 | ) 103 | 104 | layers = [ 105 | block(self.inplanes, planes, stride, downsample) 106 | ] 107 | 108 | self.inplanes = planes * block.expansion 109 | for i in range(1, blocks): 110 | layers.append(block(self.inplanes, planes)) 111 | 112 | return nn.Sequential(*layers) 113 | 114 | def forward(self, x, MTE=False): 115 | x = self.conv1(x) 116 | x = self.bn1(x) 117 | x = self.maxpool(x) 118 | 119 | x = self.layer1(x) 120 | x = self.layer2(x) 121 | x = self.layer3(x) 122 | x = self.layer4(x) 123 | 124 | return x 125 | 126 | 127 | class MetaResNet(MetaModule): 128 | def __init_with_imagenet(self, baseModel): 129 | model = resnet50(pretrained=True) 130 | del model.fc 131 | baseModel.copyWeight(model.state_dict()) 132 | 133 | def getBase(self): 134 | baseModel = MetaResNetBase([3, 4, 6, 3]) 135 | self.__init_with_imagenet(baseModel) 136 | return baseModel 137 | 138 | def __init__(self, num_features=0, dropout=0, cut_at_pooling=False, norm=True, num_classes=[0,0,0], BNNeck=False): 139 | super(MetaResNet, self).__init__() 140 | self.num_features = num_features 141 | self.dropout = dropout 142 | self.cut_at_pooling = cut_at_pooling 143 | self.num_classes1 = num_classes[0] 144 | self.num_classes2 = num_classes[1] 145 | self.num_classes3 = num_classes[2] 146 | self.has_embedding = num_features > 0 147 | self.norm = norm 148 | self.BNNeck = BNNeck 149 | if self.dropout > 0: 150 | self.drop = nn.Dropout(self.dropout) 151 | # Construct base (pretrained) resnet 152 | self.base = self.getBase() 153 | self.base.layer4[0].conv2.stride = (1, 1) 154 | self.base.layer4[0].downsample[0].stride = (1, 1) 155 | self.gap = nn.AdaptiveAvgPool2d(1) 156 | out_planes = 2048 157 | if self.has_embedding: 158 | self.feat = MetaLinear(out_planes, self.num_features) 159 | init.kaiming_normal_(self.feat.weight, mode='fan_out') 160 | init.constant_(self.feat.bias, 0) 161 | else: 162 | # Change the num_features to CNN output channels 163 | self.num_features = out_planes 164 | 165 | self.feat_bn = MixUpBatchNorm1d(self.num_features) 166 | init.constant_(self.feat_bn.weight, 1) 167 | init.constant_(self.feat_bn.bias, 0) 168 | 169 | def forward(self, x, MTE='', save_index=0): 170 | x= self.base(x) 171 | x = self.gap(x) 172 | x = x.view(x.size(0), -1) 173 | 174 | if self.cut_at_pooling: 175 | return x 176 | 177 | if self.has_embedding: 178 | bn_x = self.feat_bn(self.feat(x)) 179 | else: 180 | bn_x = self.feat_bn(x, MTE, save_index) 181 | tri_features = x 182 | 183 | if self.training is False: 184 | bn_x = F.normalize(bn_x) 185 | return bn_x 186 | 187 | if isinstance(bn_x, list): 188 | output = [] 189 | for bnfeature in bn_x: 190 | if self.norm: 191 | bnfeature = F.normalize(bnfeature) 192 | output.append(bnfeature) 193 | if self.BNNeck: 194 | return output, tri_features 195 | else: 196 | return output 197 | 198 | if self.norm: 199 | bn_x = F.normalize(bn_x) 200 | elif self.has_embedding: 201 | bn_x = F.relu(bn_x) 202 | 203 | if self.dropout > 0: 204 | bn_x = self.drop(bn_x) 205 | 206 | if self.BNNeck: 207 | return bn_x, tri_features 208 | else: 209 | return bn_x 210 | 211 | 212 | -------------------------------------------------------------------------------- /reid/solver/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | from .build import make_optimizer 3 | from .lr_scheduler import WarmupMultiStepLR -------------------------------------------------------------------------------- /reid/solver/build.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | import torch 4 | 5 | 6 | def make_optimizer(args, model): 7 | params = [] 8 | for key, value in model.named_parameters(): 9 | if not value.requires_grad: 10 | continue 11 | lr = args.lr 12 | weight_decay = args.weight_decay 13 | params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}] 14 | 15 | optimizer = getattr(torch.optim, 'Adam')(params) 16 | return optimizer 17 | -------------------------------------------------------------------------------- /reid/solver/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | from bisect import bisect_right 3 | import torch 4 | 5 | 6 | class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler): 7 | def __init__( 8 | self, 9 | optimizer, 10 | milestones, 11 | gamma=0.1, 12 | warmup_factor=1.0 / 3, 13 | warmup_iters=500, 14 | warmup_method="linear", 15 | last_epoch=-1, 16 | ): 17 | if not list(milestones) == sorted(milestones): 18 | raise ValueError( 19 | "Milestones should be a list of" " increasing integers. Got {}", 20 | milestones, 21 | ) 22 | 23 | if warmup_method not in ("constant", "linear"): 24 | raise ValueError( 25 | "Only 'constant' or 'linear' warmup_method accepted" 26 | "got {}".format(warmup_method) 27 | ) 28 | self.milestones = milestones 29 | self.gamma = gamma 30 | self.warmup_factor = warmup_factor 31 | self.warmup_iters = warmup_iters 32 | self.warmup_method = warmup_method 33 | super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch) 34 | 35 | def get_lr(self): 36 | warmup_factor = 1 37 | if self.last_epoch < self.warmup_iters: 38 | if self.warmup_method == "constant": 39 | warmup_factor = self.warmup_factor 40 | elif self.warmup_method == "linear": 41 | alpha = self.last_epoch / self.warmup_iters 42 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha 43 | return [ 44 | base_lr 45 | * warmup_factor 46 | * self.gamma ** bisect_right(self.milestones, self.last_epoch) 47 | for base_lr in self.base_lrs 48 | ] 49 | -------------------------------------------------------------------------------- /reid/trainers.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import, division 2 | import time 3 | import numpy as np 4 | import collections 5 | 6 | import torch 7 | import torch.nn as nn 8 | import math 9 | import torch.nn.functional as F 10 | from .utils.meters import AverageMeter 11 | from .models import * 12 | from .evaluation_metrics import accuracy 13 | from .models.MetaModules import MixUpBatchNorm1d as MixUp1D 14 | 15 | class Trainer(object): 16 | def __init__(self, args, model, memory, criterion): 17 | super(Trainer, self).__init__() 18 | self.model = model 19 | self.memory = memory 20 | self.device = 'cuda' if torch.cuda.is_available() else 'cpu' 21 | self.criterion = criterion 22 | self.args = args 23 | 24 | def train(self, epoch, data_loaders, optimizer, print_freq=10, train_iters=400): 25 | self.model.train() 26 | 27 | batch_time = AverageMeter() 28 | data_time = AverageMeter() 29 | losses = AverageMeter() 30 | losses_meta_train = AverageMeter() 31 | losses_meta_test = AverageMeter() 32 | metaLR = optimizer.param_groups[0]['lr'] 33 | 34 | source_count = len(data_loaders) 35 | 36 | 37 | end = time.time() 38 | for i in range(train_iters): 39 | metaTestID = np.random.choice(source_count) 40 | network_bns = [x for x in list(self.model.modules()) if isinstance(x, MixUp1D)] 41 | 42 | for bn in network_bns: 43 | bn.meta_mean1 = torch.zeros(bn.meta_mean1.size()).float().cuda() 44 | bn.meta_var1 = torch.zeros(bn.meta_var1.size()).float().cuda() 45 | bn.meta_mean2 = torch.zeros(bn.meta_mean2.size()).float().cuda() 46 | bn.meta_var2 = torch.zeros(bn.meta_var2.size()).float().cuda() 47 | 48 | # with torch.autograd.set_detect_anomaly(True): 49 | if True: 50 | data_loader_index = [i for i in range(source_count)] ## 0 2 51 | del data_loader_index[metaTestID] 52 | batch_data = [data_loaders[i].next() for i in range(source_count)] 53 | metaTestinputs = batch_data[metaTestID] 54 | data_time.update(time.time() - end) 55 | # process inputs 56 | testInputs, testPids, _, _, _ = self._parse_data(metaTestinputs) 57 | loss_meta_train = 0. 58 | save_index = 0 59 | for t in data_loader_index: # 0 1 60 | data_time.update(time.time() - end) 61 | traininputs = batch_data[t] 62 | save_index += 1 63 | inputs, targets, _, _, _ = self._parse_data(traininputs) 64 | 65 | f_out, tri_features = self.model(inputs, MTE='', save_index=save_index) 66 | loss_mtr_tri = self.criterion(tri_features, targets) 67 | loss_s = self.memory[t](f_out, targets).mean() 68 | 69 | loss_meta_train = loss_meta_train + loss_s + loss_mtr_tri 70 | 71 | loss_meta_train = loss_meta_train / (source_count - 1) 72 | 73 | self.model.zero_grad() 74 | grad_info = torch.autograd.grad(loss_meta_train, self.model.module.params(), create_graph=True) 75 | self.newMeta = create(self.args.arch, norm=True, BNNeck=self.args.BNNeck) 76 | # creatmodel = time.time() 77 | self.newMeta.copyModel(self.model.module) 78 | # copymodel = time.time() 79 | self.newMeta.update_params( 80 | lr_inner=metaLR, source_params=grad_info, solver='adam' 81 | ) 82 | 83 | del grad_info 84 | 85 | self.newMeta = nn.DataParallel(self.newMeta).to(self.device) 86 | 87 | f_test, mte_tri = self.newMeta(testInputs, MTE=self.args.BNtype) 88 | 89 | loss_meta_test = 0. 90 | if isinstance(f_test, list): 91 | for feature in f_test: 92 | loss_meta_test += self.memory[metaTestID](feature, testPids).mean() 93 | loss_meta_test /= len(f_test) 94 | 95 | else: 96 | loss_meta_test = self.memory[metaTestID](f_test, testPids).mean() 97 | 98 | loss_mte_tri = self.criterion(mte_tri, testPids) 99 | loss_meta_test = loss_meta_test + loss_mte_tri 100 | 101 | loss_final = loss_meta_train + loss_meta_test 102 | losses_meta_train.update(loss_meta_train.item()) 103 | losses_meta_test.update(loss_meta_test.item()) 104 | 105 | optimizer.zero_grad() 106 | loss_final.backward() 107 | optimizer.step() 108 | 109 | with torch.no_grad(): 110 | for m_ind in range(source_count): 111 | imgs, pids, _, _, _ = self._parse_data(batch_data[m_ind]) 112 | f_new, _ = self.model(imgs) 113 | self.memory[m_ind].module.MomentumUpdate(f_new, pids) 114 | 115 | 116 | losses.update(loss_final.item()) 117 | 118 | # print log 119 | batch_time.update(time.time() - end) 120 | end = time.time() 121 | 122 | 123 | if (i + 1) % print_freq == 0: 124 | print('Epoch: [{}][{}/{}]\t' 125 | 'Time {:.3f} ({:.3f})\t' 126 | 'Total loss {:.3f} ({:.3f})\t' 127 | 'Loss {:.3f}({:.3f})\t' 128 | 'LossMeta {:.3f}({:.3f})' 129 | .format(epoch, i + 1, train_iters, 130 | batch_time.val, batch_time.avg, 131 | losses.val, losses.avg, 132 | losses_meta_train.val, losses_meta_train.avg, 133 | losses_meta_test.val, losses_meta_test.avg)) 134 | 135 | def _parse_data(self, inputs): 136 | imgs, names, pids, cams, dataset_id, indexes = inputs 137 | return imgs.cuda(), pids.cuda(), indexes.cuda(), cams.cuda(), dataset_id.cuda() 138 | 139 | 140 | -------------------------------------------------------------------------------- /reid/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | 5 | 6 | def to_numpy(tensor): 7 | if torch.is_tensor(tensor): 8 | return tensor.cpu().numpy() 9 | elif type(tensor).__module__ != 'numpy': 10 | raise ValueError("Cannot convert {} to numpy array" 11 | .format(type(tensor))) 12 | return tensor 13 | 14 | 15 | def to_torch(ndarray): 16 | if type(ndarray).__module__ == 'numpy': 17 | return torch.from_numpy(ndarray) 18 | elif not torch.is_tensor(ndarray): 19 | raise ValueError("Cannot convert {} to torch tensor" 20 | .format(type(ndarray))) 21 | return ndarray 22 | -------------------------------------------------------------------------------- /reid/utils/data/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .base_dataset import BaseDataset, BaseImageDataset 4 | from .preprocessor import Preprocessor 5 | 6 | 7 | class IterLoader: 8 | def __init__(self, loader, length=None): 9 | self.loader = loader 10 | self.length = length 11 | self.iter = None 12 | 13 | def __len__(self): 14 | if self.length is not None: 15 | return self.length 16 | return len(self.loader) 17 | 18 | def new_epoch(self): 19 | self.iter = iter(self.loader) 20 | 21 | def next(self): 22 | try: 23 | return next(self.iter) 24 | except: 25 | self.iter = iter(self.loader) 26 | return next(self.iter) 27 | -------------------------------------------------------------------------------- /reid/utils/data/base_dataset.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | import numpy as np 3 | 4 | 5 | class BaseDataset(object): 6 | """ 7 | Base class of reid dataset 8 | """ 9 | 10 | def get_imagedata_info(self, data): 11 | pids, cams = [], [] 12 | for _, pid, camid, _ in data: 13 | pids += [pid] 14 | cams += [camid] 15 | pids = set(pids) 16 | cams = set(cams) 17 | num_pids = len(pids) 18 | num_cams = len(cams) 19 | num_imgs = len(data) 20 | return num_pids, num_imgs, num_cams 21 | 22 | def print_dataset_statistics(self): 23 | raise NotImplementedError 24 | 25 | @property 26 | def images_dir(self): 27 | return None 28 | 29 | 30 | class BaseImageDataset(BaseDataset): 31 | """ 32 | Base class of image reid dataset 33 | """ 34 | 35 | def print_dataset_statistics(self, train, query, gallery): 36 | num_train_pids, num_train_imgs, num_train_cams = self.get_imagedata_info(train) 37 | num_query_pids, num_query_imgs, num_query_cams = self.get_imagedata_info(query) 38 | num_gallery_pids, num_gallery_imgs, num_gallery_cams = self.get_imagedata_info(gallery) 39 | 40 | print("Dataset statistics:") 41 | print(" ----------------------------------------") 42 | print(" subset | # ids | # images | # cameras") 43 | print(" ----------------------------------------") 44 | print(" train | {:5d} | {:8d} | {:9d}".format(num_train_pids, num_train_imgs, num_train_cams)) 45 | print(" query | {:5d} | {:8d} | {:9d}".format(num_query_pids, num_query_imgs, num_query_cams)) 46 | print(" gallery | {:5d} | {:8d} | {:9d}".format(num_gallery_pids, num_gallery_imgs, num_gallery_cams)) 47 | print(" ----------------------------------------") 48 | -------------------------------------------------------------------------------- /reid/utils/data/preprocessor.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import os 3 | import os.path as osp 4 | from torch.utils.data import DataLoader, Dataset 5 | import numpy as np 6 | import random 7 | import math 8 | from PIL import Image 9 | 10 | 11 | class Preprocessor(Dataset): 12 | def __init__(self, dataset, root=None, transform=None): 13 | super(Preprocessor, self).__init__() 14 | self.dataset = dataset 15 | self.root = root 16 | self.transform = transform 17 | 18 | def __len__(self): 19 | return len(self.dataset) 20 | 21 | def __getitem__(self, indices): 22 | return self._get_single_item(indices) 23 | 24 | def _get_single_item(self, index): 25 | fname, pid, camid, dataset_id = self.dataset[index] 26 | fpath = fname 27 | if self.root is not None: 28 | fpath = osp.join(self.root, fname) 29 | 30 | img = Image.open(fpath).convert('RGB') 31 | 32 | if self.transform is not None: 33 | img = self.transform(img) 34 | 35 | return img, fname, pid, camid, dataset_id, index 36 | -------------------------------------------------------------------------------- /reid/utils/data/sampler.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from collections import defaultdict 3 | import math 4 | 5 | import numpy as np 6 | import copy 7 | import random 8 | import torch 9 | from torch.utils.data.sampler import ( 10 | Sampler, SequentialSampler, RandomSampler, SubsetRandomSampler, 11 | WeightedRandomSampler) 12 | 13 | 14 | def No_index(a, b): 15 | assert isinstance(a, list) 16 | return [i for i, j in enumerate(a) if j != b] 17 | 18 | 19 | class RandomIdentitySampler(Sampler): 20 | def __init__(self, data_source, num_instances): 21 | self.data_source = data_source 22 | self.num_instances = num_instances 23 | self.index_dic = defaultdict(list) 24 | for index, (_, pid, _) in enumerate(data_source): 25 | self.index_dic[pid].append(index) 26 | self.pids = list(self.index_dic.keys()) 27 | self.num_samples = len(self.pids) 28 | 29 | def __len__(self): 30 | return self.num_samples * self.num_instances 31 | 32 | def __iter__(self): 33 | indices = torch.randperm(self.num_samples).tolist() 34 | ret = [] 35 | for i in indices: 36 | pid = self.pids[i] 37 | t = self.index_dic[pid] 38 | if len(t) >= self.num_instances: 39 | t = np.random.choice(t, size=self.num_instances, replace=False) 40 | else: 41 | t = np.random.choice(t, size=self.num_instances, replace=True) 42 | ret.extend(t) 43 | return iter(ret) 44 | 45 | 46 | class RandomMultipleGallerySampler(Sampler): 47 | def __init__(self, data_source, num_instances=4): 48 | self.data_source = data_source 49 | self.index_pid = defaultdict(int) 50 | self.pid_cam = defaultdict(list) 51 | self.pid_index = defaultdict(list) 52 | self.num_instances = num_instances 53 | 54 | for index, (_, pid, cam, _) in enumerate(data_source): 55 | if (pid<0): continue 56 | self.index_pid[index] = pid 57 | self.pid_cam[pid].append(cam) 58 | self.pid_index[pid].append(index) 59 | 60 | self.pids = list(self.pid_index.keys()) 61 | self.num_samples = len(self.pids) 62 | 63 | def __len__(self): 64 | return self.num_samples * self.num_instances 65 | 66 | def __iter__(self): 67 | indices = torch.randperm(len(self.pids)).tolist() 68 | ret = [] 69 | 70 | for kid in indices: 71 | i = random.choice(self.pid_index[self.pids[kid]]) 72 | 73 | _, i_pid, i_cam, _ = self.data_source[i] 74 | 75 | ret.append(i) 76 | 77 | pid_i = self.index_pid[i] 78 | cams = self.pid_cam[pid_i] 79 | index = self.pid_index[pid_i] 80 | select_cams = No_index(cams, i_cam) 81 | 82 | if select_cams: 83 | 84 | if len(select_cams) >= self.num_instances: 85 | cam_indexes = np.random.choice(select_cams, size=self.num_instances-1, replace=False) 86 | else: 87 | cam_indexes = np.random.choice(select_cams, size=self.num_instances-1, replace=True) 88 | 89 | for kk in cam_indexes: 90 | ret.append(index[kk]) 91 | 92 | else: 93 | select_indexes = No_index(index, i) 94 | if (not select_indexes): continue 95 | if len(select_indexes) >= self.num_instances: 96 | ind_indexes = np.random.choice(select_indexes, size=self.num_instances-1, replace=False) 97 | else: 98 | ind_indexes = np.random.choice(select_indexes, size=self.num_instances-1, replace=True) 99 | 100 | for kk in ind_indexes: 101 | ret.append(index[kk]) 102 | 103 | 104 | return iter(ret) 105 | -------------------------------------------------------------------------------- /reid/utils/data/transforms.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from torchvision.transforms import * 4 | from PIL import Image 5 | import random 6 | import math 7 | import numpy as np 8 | 9 | class RectScale(object): 10 | def __init__(self, height, width, interpolation=Image.BILINEAR): 11 | self.height = height 12 | self.width = width 13 | self.interpolation = interpolation 14 | 15 | def __call__(self, img): 16 | w, h = img.size 17 | if h == self.height and w == self.width: 18 | return img 19 | return img.resize((self.width, self.height), self.interpolation) 20 | 21 | 22 | class RandomSizedRectCrop(object): 23 | def __init__(self, height, width, interpolation=Image.BILINEAR): 24 | self.height = height 25 | self.width = width 26 | self.interpolation = interpolation 27 | 28 | def __call__(self, img): 29 | for attempt in range(10): 30 | area = img.size[0] * img.size[1] 31 | target_area = random.uniform(0.64, 1.0) * area 32 | aspect_ratio = random.uniform(2, 3) 33 | 34 | h = int(round(math.sqrt(target_area * aspect_ratio))) 35 | w = int(round(math.sqrt(target_area / aspect_ratio))) 36 | 37 | if w <= img.size[0] and h <= img.size[1]: 38 | x1 = random.randint(0, img.size[0] - w) 39 | y1 = random.randint(0, img.size[1] - h) 40 | 41 | img = img.crop((x1, y1, x1 + w, y1 + h)) 42 | assert(img.size == (w, h)) 43 | 44 | return img.resize((self.width, self.height), self.interpolation) 45 | 46 | # Fallback 47 | scale = RectScale(self.height, self.width, 48 | interpolation=self.interpolation) 49 | return scale(img) 50 | 51 | 52 | class RandomErasing(object): 53 | """ Randomly selects a rectangle region in an image and erases its pixels. 54 | 'Random Erasing Data Augmentation' by Zhong et al. 55 | See https://arxiv.org/pdf/1708.04896.pdf 56 | Args: 57 | probability: The probability that the Random Erasing operation will be performed. 58 | sl: Minimum proportion of erased area against input image. 59 | sh: Maximum proportion of erased area against input image. 60 | r1: Minimum aspect ratio of erased area. 61 | mean: Erasing value. 62 | """ 63 | 64 | def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=(0.4914, 0.4822, 0.4465)): 65 | self.probability = probability 66 | self.mean = mean 67 | self.sl = sl 68 | self.sh = sh 69 | self.r1 = r1 70 | 71 | def __call__(self, img): 72 | 73 | if random.uniform(0, 1) >= self.probability: 74 | return img 75 | 76 | for attempt in range(100): 77 | area = img.size()[1] * img.size()[2] 78 | 79 | target_area = random.uniform(self.sl, self.sh) * area 80 | aspect_ratio = random.uniform(self.r1, 1 / self.r1) 81 | 82 | h = int(round(math.sqrt(target_area * aspect_ratio))) 83 | w = int(round(math.sqrt(target_area / aspect_ratio))) 84 | 85 | if w < img.size()[2] and h < img.size()[1]: 86 | x1 = random.randint(0, img.size()[1] - h) 87 | y1 = random.randint(0, img.size()[2] - w) 88 | if img.size()[0] == 3: 89 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0] 90 | img[1, x1:x1 + h, y1:y1 + w] = self.mean[1] 91 | img[2, x1:x1 + h, y1:y1 + w] = self.mean[2] 92 | else: 93 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0] 94 | return img 95 | 96 | return img 97 | -------------------------------------------------------------------------------- /reid/utils/faiss_rerank.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | CVPR2017 paper:Zhong Z, Zheng L, Cao D, et al. Re-ranking Person Re-identification with k-reciprocal Encoding[J]. 2017. 5 | url:http://openaccess.thecvf.com/content_cvpr_2017/papers/Zhong_Re-Ranking_Person_Re-Identification_CVPR_2017_paper.pdf 6 | Matlab version: https://github.com/zhunzhong07/person-re-ranking 7 | """ 8 | 9 | import faiss, torch 10 | import os, sys, time 11 | import numpy as np 12 | from scipy.spatial.distance import cdist 13 | import gc 14 | import torch.nn.functional as F 15 | 16 | from .faiss_utils import search_index_pytorch, search_raw_array_pytorch, \ 17 | index_init_gpu, index_init_cpu 18 | 19 | 20 | def k_reciprocal_neigh(initial_rank, i, k1): 21 | forward_k_neigh_index = initial_rank[i, :k1 + 1] 22 | backward_k_neigh_index = initial_rank[forward_k_neigh_index, :k1 + 1] 23 | fi = np.where(backward_k_neigh_index == i)[0] 24 | return forward_k_neigh_index[fi] 25 | 26 | 27 | def compute_jaccard_distance(target_features, k1=20, k2=6, print_flag=True, search_option=0, use_float16=False): 28 | end = time.time() 29 | if print_flag: 30 | print('Computing jaccard distance...') 31 | 32 | ngpus = faiss.get_num_gpus() 33 | N = target_features.size(0) 34 | mat_type = np.float16 if use_float16 else np.float32 35 | 36 | if search_option == 0: 37 | # GPU + PyTorch CUDA Tensors (1) 38 | res = faiss.StandardGpuResources() 39 | res.setDefaultNullStreamAllDevices() 40 | _, initial_rank = search_raw_array_pytorch(res, target_features, target_features, k1) 41 | initial_rank = initial_rank.cpu().numpy() 42 | elif search_option == 1: 43 | # GPU + PyTorch CUDA Tensors (2) 44 | res = faiss.StandardGpuResources() 45 | index = faiss.GpuIndexFlatL2(res, target_features.size(-1)) 46 | index.add(target_features.cpu().numpy()) 47 | _, initial_rank = search_index_pytorch(index, target_features, k1) 48 | res.syncDefaultStreamCurrentDevice() 49 | initial_rank = initial_rank.cpu().numpy() 50 | elif search_option == 2: 51 | # GPU 52 | index = index_init_gpu(ngpus, target_features.size(-1)) 53 | index.add(target_features.cpu().numpy()) 54 | _, initial_rank = index.search(target_features.cpu().numpy(), k1) 55 | else: 56 | # CPU 57 | index = index_init_cpu(target_features.size(-1)) 58 | index.add(target_features.cpu().numpy()) 59 | _, initial_rank = index.search(target_features.cpu().numpy(), k1) 60 | 61 | nn_k1 = [] 62 | nn_k1_half = [] 63 | for i in range(N): 64 | nn_k1.append(k_reciprocal_neigh(initial_rank, i, k1)) 65 | nn_k1_half.append(k_reciprocal_neigh(initial_rank, i, int(np.around(k1 / 2)))) 66 | 67 | V = np.zeros((N, N), dtype=mat_type) 68 | for i in range(N): 69 | k_reciprocal_index = nn_k1[i] 70 | k_reciprocal_expansion_index = k_reciprocal_index 71 | for candidate in k_reciprocal_index: 72 | candidate_k_reciprocal_index = nn_k1_half[candidate] 73 | if (len(np.intersect1d(candidate_k_reciprocal_index, k_reciprocal_index)) > 2 / 3 * len( 74 | candidate_k_reciprocal_index)): 75 | k_reciprocal_expansion_index = np.append(k_reciprocal_expansion_index, candidate_k_reciprocal_index) 76 | 77 | k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index) # element-wise unique 78 | dist = 2 - 2 * torch.mm(target_features[i].unsqueeze(0).contiguous(), 79 | target_features[k_reciprocal_expansion_index].t()) 80 | if use_float16: 81 | V[i, k_reciprocal_expansion_index] = F.softmax(-dist, dim=1).view(-1).cpu().numpy().astype(mat_type) 82 | else: 83 | V[i, k_reciprocal_expansion_index] = F.softmax(-dist, dim=1).view(-1).cpu().numpy() 84 | 85 | del nn_k1, nn_k1_half 86 | 87 | if k2 != 1: 88 | V_qe = np.zeros_like(V, dtype=mat_type) 89 | for i in range(N): 90 | V_qe[i, :] = np.mean(V[initial_rank[i, :k2], :], axis=0) 91 | V = V_qe 92 | del V_qe 93 | 94 | del initial_rank 95 | 96 | invIndex = [] 97 | for i in range(N): 98 | invIndex.append(np.where(V[:, i] != 0)[0]) # len(invIndex)=all_num 99 | 100 | jaccard_dist = np.zeros((N, N), dtype=mat_type) 101 | for i in range(N): 102 | temp_min = np.zeros((1, N), dtype=mat_type) 103 | indNonZero = np.where(V[i, :] != 0)[0] 104 | indImages = [invIndex[ind] for ind in indNonZero] 105 | for j in range(len(indNonZero)): 106 | temp_min[0, indImages[j]] = temp_min[0, indImages[j]] + np.minimum(V[i, indNonZero[j]], 107 | V[indImages[j], indNonZero[j]]) 108 | 109 | jaccard_dist[i] = 1 - temp_min / (2 - temp_min) 110 | # jaccard_dist[i] = 1-temp_min/(temp_max+1e-6) 111 | 112 | del invIndex, V 113 | 114 | pos_bool = (jaccard_dist < 0) 115 | jaccard_dist[pos_bool] = 0.0 116 | if print_flag: 117 | print("Jaccard distance computing time cost: {}".format(time.time() - end)) 118 | 119 | return jaccard_dist 120 | -------------------------------------------------------------------------------- /reid/utils/faiss_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import faiss 4 | import torch 5 | 6 | 7 | def swig_ptr_from_FloatTensor(x): 8 | assert x.is_contiguous() 9 | assert x.dtype == torch.float32 10 | return faiss.cast_integer_to_float_ptr( 11 | x.storage().data_ptr() + x.storage_offset() * 4) 12 | 13 | 14 | def swig_ptr_from_LongTensor(x): 15 | assert x.is_contiguous() 16 | assert x.dtype == torch.int64, 'dtype=%s' % x.dtype 17 | return faiss.cast_integer_to_long_ptr( 18 | x.storage().data_ptr() + x.storage_offset() * 8) 19 | 20 | 21 | def search_index_pytorch(index, x, k, D=None, I=None): 22 | """call the search function of an index with pytorch tensor I/O (CPU 23 | and GPU supported)""" 24 | assert x.is_contiguous() 25 | n, d = x.size() 26 | assert d == index.d 27 | 28 | if D is None: 29 | D = torch.empty((n, k), dtype=torch.float32, device=x.device) 30 | else: 31 | assert D.size() == (n, k) 32 | 33 | if I is None: 34 | I = torch.empty((n, k), dtype=torch.int64, device=x.device) 35 | else: 36 | assert I.size() == (n, k) 37 | torch.cuda.synchronize() 38 | xptr = swig_ptr_from_FloatTensor(x) 39 | Iptr = swig_ptr_from_LongTensor(I) 40 | Dptr = swig_ptr_from_FloatTensor(D) 41 | index.search_c(n, xptr, 42 | k, Dptr, Iptr) 43 | torch.cuda.synchronize() 44 | return D, I 45 | 46 | 47 | def search_raw_array_pytorch(res, xb, xq, k, D=None, I=None, 48 | metric=faiss.METRIC_L2): 49 | assert xb.device == xq.device 50 | 51 | nq, d = xq.size() 52 | if xq.is_contiguous(): 53 | xq_row_major = True 54 | elif xq.t().is_contiguous(): 55 | xq = xq.t() # I initially wrote xq:t(), Lua is still haunting me :-) 56 | xq_row_major = False 57 | else: 58 | raise TypeError('matrix should be row or column-major') 59 | 60 | xq_ptr = swig_ptr_from_FloatTensor(xq) 61 | 62 | nb, d2 = xb.size() 63 | assert d2 == d 64 | if xb.is_contiguous(): 65 | xb_row_major = True 66 | elif xb.t().is_contiguous(): 67 | xb = xb.t() 68 | xb_row_major = False 69 | else: 70 | raise TypeError('matrix should be row or column-major') 71 | xb_ptr = swig_ptr_from_FloatTensor(xb) 72 | 73 | if D is None: 74 | D = torch.empty(nq, k, device=xb.device, dtype=torch.float32) 75 | else: 76 | assert D.shape == (nq, k) 77 | assert D.device == xb.device 78 | 79 | if I is None: 80 | I = torch.empty(nq, k, device=xb.device, dtype=torch.int64) 81 | else: 82 | assert I.shape == (nq, k) 83 | assert I.device == xb.device 84 | 85 | D_ptr = swig_ptr_from_FloatTensor(D) 86 | I_ptr = swig_ptr_from_LongTensor(I) 87 | 88 | faiss.bruteForceKnn(res, metric, 89 | xb_ptr, xb_row_major, nb, 90 | xq_ptr, xq_row_major, nq, 91 | d, k, D_ptr, I_ptr) 92 | 93 | return D, I 94 | 95 | 96 | def kMeans(data, numCluster, useGpu=True): 97 | kFunc = faiss.Kmeans(data.shape[1], numCluster, gpu=useGpu) 98 | kFunc.cp.max_points_per_centroid = ((data.shape[0] + numCluster - 1) // numCluster) 99 | if data.is_cuda: 100 | data = data.cpu() 101 | kFunc.train(data.numpy()) 102 | # assign labels 103 | _, labels = kFunc.index.search(data.numpy(), 1) 104 | return kFunc.centroids, labels.squeeze() 105 | 106 | 107 | def index_init_gpu(ngpus, feat_dim): 108 | flat_config = [] 109 | for i in range(ngpus): 110 | cfg = faiss.GpuIndexFlatConfig() 111 | cfg.useFloat16 = False 112 | cfg.device = i 113 | flat_config.append(cfg) 114 | 115 | res = [faiss.StandardGpuResources() for i in range(ngpus)] 116 | indexes = [faiss.GpuIndexFlatL2(res[i], feat_dim, flat_config[i]) for i in range(ngpus)] 117 | index = faiss.IndexShards(feat_dim) 118 | for sub_index in indexes: 119 | index.add_shard(sub_index) 120 | index.reset() 121 | return index 122 | 123 | 124 | def index_init_cpu(feat_dim): 125 | return faiss.IndexFlatL2(feat_dim) 126 | -------------------------------------------------------------------------------- /reid/utils/logging.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import os 3 | import sys 4 | 5 | from .osutils import mkdir_if_missing 6 | 7 | 8 | class Logger(object): 9 | def __init__(self, fpath=None): 10 | self.console = sys.stdout 11 | self.file = None 12 | if fpath is not None: 13 | mkdir_if_missing(os.path.dirname(fpath)) 14 | self.file = open(fpath, 'w') 15 | 16 | def __del__(self): 17 | self.close() 18 | 19 | def __enter__(self): 20 | pass 21 | 22 | def __exit__(self, *args): 23 | self.close() 24 | 25 | def write(self, msg): 26 | self.console.write(msg) 27 | if self.file is not None: 28 | self.file.write(msg) 29 | 30 | def flush(self): 31 | self.console.flush() 32 | if self.file is not None: 33 | self.file.flush() 34 | os.fsync(self.file.fileno()) 35 | 36 | def close(self): 37 | self.console.close() 38 | if self.file is not None: 39 | self.file.close() 40 | -------------------------------------------------------------------------------- /reid/utils/meters.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | 4 | class AverageMeter(object): 5 | """Computes and stores the average and current value""" 6 | 7 | def __init__(self): 8 | self.val = 0 9 | self.avg = 0 10 | self.sum = 0 11 | self.count = 0 12 | 13 | def reset(self): 14 | self.val = 0 15 | self.avg = 0 16 | self.sum = 0 17 | self.count = 0 18 | 19 | def update(self, val, n=1): 20 | self.val = val 21 | self.sum += val * n 22 | self.count += n 23 | self.avg = self.sum / self.count 24 | -------------------------------------------------------------------------------- /reid/utils/osutils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import os 3 | import errno 4 | 5 | 6 | def mkdir_if_missing(dir_path): 7 | try: 8 | os.makedirs(dir_path) 9 | except OSError as e: 10 | if e.errno != errno.EEXIST: 11 | raise 12 | -------------------------------------------------------------------------------- /reid/utils/rerank.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python2/python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Source: https://github.com/zhunzhong07/person-re-ranking 5 | Created on Mon Jun 26 14:46:56 2017 6 | @author: luohao 7 | Modified by Houjing Huang, 2017-12-22. 8 | - This version accepts distance matrix instead of raw features. 9 | - The difference of `/` division between python 2 and 3 is handled. 10 | - numpy.float16 is replaced by numpy.float32 for numerical precision. 11 | CVPR2017 paper:Zhong Z, Zheng L, Cao D, et al. Re-ranking Person Re-identification with k-reciprocal Encoding[J]. 2017. 12 | url:http://openaccess.thecvf.com/content_cvpr_2017/papers/Zhong_Re-Ranking_Person_Re-Identification_CVPR_2017_paper.pdf 13 | Matlab version: https://github.com/zhunzhong07/person-re-ranking 14 | API 15 | q_g_dist: query-gallery distance matrix, numpy array, shape [num_query, num_gallery] 16 | q_q_dist: query-query distance matrix, numpy array, shape [num_query, num_query] 17 | g_g_dist: gallery-gallery distance matrix, numpy array, shape [num_gallery, num_gallery] 18 | k1, k2, lambda_value: parameters, the original paper is (k1=20, k2=6, lambda_value=0.3) 19 | Returns: 20 | final_dist: re-ranked distance, numpy array, shape [num_query, num_gallery] 21 | """ 22 | from __future__ import absolute_import 23 | from __future__ import print_function 24 | from __future__ import division 25 | 26 | __all__ = ['re_ranking'] 27 | 28 | import numpy as np 29 | 30 | 31 | def re_ranking(q_g_dist, q_q_dist, g_g_dist, k1=20, k2=6, lambda_value=0.3): 32 | 33 | # The following naming, e.g. gallery_num, is different from outer scope. 34 | # Don't care about it. 35 | 36 | original_dist = np.concatenate( 37 | [np.concatenate([q_q_dist, q_g_dist], axis=1), 38 | np.concatenate([q_g_dist.T, g_g_dist], axis=1)], 39 | axis=0) 40 | original_dist = np.power(original_dist, 2).astype(np.float32) 41 | original_dist = np.transpose(1. * original_dist/np.max(original_dist,axis = 0)) 42 | V = np.zeros_like(original_dist).astype(np.float32) 43 | initial_rank = np.argsort(original_dist).astype(np.int32) 44 | 45 | query_num = q_g_dist.shape[0] 46 | gallery_num = q_g_dist.shape[0] + q_g_dist.shape[1] 47 | all_num = gallery_num 48 | 49 | for i in range(all_num): 50 | # k-reciprocal neighbors 51 | forward_k_neigh_index = initial_rank[i,:k1+1] 52 | backward_k_neigh_index = initial_rank[forward_k_neigh_index,:k1+1] 53 | fi = np.where(backward_k_neigh_index==i)[0] 54 | k_reciprocal_index = forward_k_neigh_index[fi] 55 | k_reciprocal_expansion_index = k_reciprocal_index 56 | for j in range(len(k_reciprocal_index)): 57 | candidate = k_reciprocal_index[j] 58 | candidate_forward_k_neigh_index = initial_rank[candidate,:int(np.around(k1/2.))+1] 59 | candidate_backward_k_neigh_index = initial_rank[candidate_forward_k_neigh_index,:int(np.around(k1/2.))+1] 60 | fi_candidate = np.where(candidate_backward_k_neigh_index == candidate)[0] 61 | candidate_k_reciprocal_index = candidate_forward_k_neigh_index[fi_candidate] 62 | if len(np.intersect1d(candidate_k_reciprocal_index,k_reciprocal_index))> 2./3*len(candidate_k_reciprocal_index): 63 | k_reciprocal_expansion_index = np.append(k_reciprocal_expansion_index,candidate_k_reciprocal_index) 64 | 65 | k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index) 66 | weight = np.exp(-original_dist[i,k_reciprocal_expansion_index]) 67 | V[i,k_reciprocal_expansion_index] = 1.*weight/np.sum(weight) 68 | original_dist = original_dist[:query_num,] 69 | if k2 != 1: 70 | V_qe = np.zeros_like(V,dtype=np.float32) 71 | for i in range(all_num): 72 | V_qe[i,:] = np.mean(V[initial_rank[i,:k2],:],axis=0) 73 | V = V_qe 74 | del V_qe 75 | del initial_rank 76 | invIndex = [] 77 | for i in range(gallery_num): 78 | invIndex.append(np.where(V[:,i] != 0)[0]) 79 | 80 | jaccard_dist = np.zeros_like(original_dist,dtype = np.float32) 81 | 82 | 83 | for i in range(query_num): 84 | temp_min = np.zeros(shape=[1,gallery_num],dtype=np.float32) 85 | indNonZero = np.where(V[i,:] != 0)[0] 86 | indImages = [] 87 | indImages = [invIndex[ind] for ind in indNonZero] 88 | for j in range(len(indNonZero)): 89 | temp_min[0,indImages[j]] = temp_min[0,indImages[j]]+ np.minimum(V[i,indNonZero[j]],V[indImages[j],indNonZero[j]]) 90 | jaccard_dist[i] = 1-temp_min/(2.-temp_min) 91 | 92 | final_dist = jaccard_dist*(1-lambda_value) + original_dist*lambda_value 93 | del original_dist 94 | del V 95 | del jaccard_dist 96 | final_dist = final_dist[:query_num,query_num:] 97 | return final_dist 98 | -------------------------------------------------------------------------------- /reid/utils/serialization.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import json 3 | import os.path as osp 4 | import shutil 5 | 6 | import torch 7 | from torch.nn import Parameter 8 | 9 | from .osutils import mkdir_if_missing 10 | 11 | 12 | def read_json(fpath): 13 | with open(fpath, 'r') as f: 14 | obj = json.load(f) 15 | return obj 16 | 17 | 18 | def write_json(obj, fpath): 19 | mkdir_if_missing(osp.dirname(fpath)) 20 | with open(fpath, 'w') as f: 21 | json.dump(obj, f, indent=4, separators=(',', ': ')) 22 | 23 | 24 | def save_checkpoint(state, is_best, fpath='checkpoint.pth.tar'): 25 | mkdir_if_missing(osp.dirname(fpath)) 26 | torch.save(state, fpath) 27 | if is_best: 28 | shutil.copy(fpath, osp.join(osp.dirname(fpath), 'model_best.pth.tar')) 29 | 30 | 31 | def load_checkpoint(fpath): 32 | if osp.isfile(fpath): 33 | # checkpoint = torch.load(fpath) 34 | checkpoint = torch.load(fpath, map_location=torch.device('cpu')) 35 | print("=> Loaded checkpoint '{}'".format(fpath)) 36 | return checkpoint 37 | else: 38 | raise ValueError("=> No checkpoint found at '{}'".format(fpath)) 39 | 40 | 41 | def copy_state_dict(state_dict, model, strip=None): 42 | tgt_state = model.state_dict() 43 | copied_names = set() 44 | for name, param in state_dict.items(): 45 | if strip is not None and name.startswith(strip): 46 | name = name[len(strip):] 47 | if name not in tgt_state: 48 | continue 49 | if isinstance(param, Parameter): 50 | param = param.data 51 | if param.size() != tgt_state[name].size(): 52 | print('mismatch:', name, param.size(), tgt_state[name].size()) 53 | continue 54 | tgt_state[name].copy_(param) 55 | copied_names.add(name) 56 | 57 | missing = set(tgt_state.keys()) - copied_names 58 | if len(missing) > 0: 59 | print("missing keys in state_dict:", missing) 60 | 61 | return model 62 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.3.1 2 | scipy==1.3.1 3 | torchvision 4 | seaborn==0.9.0 5 | numpy==1.17.2 6 | matplotlib==3.1.1 7 | Pillow==8.2.0 8 | scikit_learn==0.24.1 9 | --------------------------------------------------------------------------------