├── PCB.py ├── README.md ├── RPP.py ├── reid ├── __init__.py ├── __init__.pyc ├── datasets │ ├── __init__.py │ ├── __init__.pyc │ ├── duke.py │ ├── duke.pyc │ ├── market.py │ └── market.pyc ├── evaluation_metrics │ ├── __init__.py │ ├── __init__.pyc │ ├── classification.py │ ├── classification.pyc │ ├── ranking.py │ └── ranking.pyc ├── evaluators.py ├── evaluators.pyc ├── feature_extraction │ ├── __init__.py │ ├── __init__.pyc │ ├── cnn.py │ ├── cnn.pyc │ ├── database.py │ └── database.pyc ├── models │ ├── .resnet.py.swo │ ├── .resnet_mask.py.swo │ ├── __init__.py │ ├── __init__.pyc │ ├── resnet.py │ ├── resnet.pyc │ ├── resnet_mask.pyc │ ├── resnet_rpp.py │ └── resnet_rpp.pyc ├── trainers_partloss.py ├── trainers_partloss.pyc └── utils │ ├── __init__.py │ ├── __init__.pyc │ ├── data │ ├── __init__.py │ ├── __init__.pyc │ ├── dataset.py │ ├── dataset.pyc │ ├── preprocessor.py │ ├── preprocessor.pyc │ ├── sampler.py │ ├── sampler.pyc │ ├── transforms.py │ └── transforms.pyc │ ├── logging.py │ ├── logging.pyc │ ├── meters.py │ ├── meters.pyc │ ├── osutils.py │ ├── osutils.pyc │ ├── progress │ ├── LICENSE │ ├── MANIFEST.in │ ├── README.rst │ ├── progress │ │ ├── __init__.py │ │ ├── __init__.pyc │ │ ├── bar.py │ │ ├── bar.pyc │ │ ├── counter.py │ │ ├── helpers.py │ │ ├── helpers.pyc │ │ └── spinner.py │ ├── setup.py │ └── test_progress.py │ ├── serialization.py │ └── serialization.pyc ├── train_PCB.sh └── train_RPP.sh /PCB.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import argparse 3 | import os.path as osp 4 | 5 | import numpy as np 6 | import sys 7 | import torch 8 | from torch import nn 9 | from torch.backends import cudnn 10 | from torch.utils.data import DataLoader 11 | 12 | from reid import datasets 13 | from reid import models 14 | from reid.trainers_partloss import Trainer 15 | from reid.evaluators import Evaluator 16 | from reid.utils.data import transforms 
as T 17 | from reid.utils.data.preprocessor import Preprocessor 18 | from reid.utils.logging import Logger 19 | from reid.utils.serialization import load_checkpoint, save_checkpoint 20 | 21 | def get_data(name, data_dir, height, width, batch_size, workers): 22 | root = osp.join(data_dir, name) 23 | root = data_dir 24 | dataset = datasets.create(name, root) 25 | 26 | normalizer = T.Normalize(mean=[0.485, 0.456, 0.406], 27 | std=[0.229, 0.224, 0.225]) 28 | 29 | num_classes = dataset.num_train_ids 30 | 31 | train_transformer = T.Compose([ 32 | T.RectScale(height, width), 33 | T.RandomHorizontalFlip(), 34 | T.ToTensor(), 35 | normalizer, 36 | ]) 37 | 38 | test_transformer = T.Compose([ 39 | T.RectScale(height, width), 40 | T.ToTensor(), 41 | normalizer, 42 | ]) 43 | 44 | train_loader = DataLoader( 45 | Preprocessor(dataset.train, root=osp.join(dataset.images_dir,dataset.train_path), 46 | transform=train_transformer), 47 | batch_size=batch_size, num_workers=workers, 48 | shuffle=True, pin_memory=True, drop_last=True) 49 | 50 | query_loader = DataLoader( 51 | Preprocessor(dataset.query, root=osp.join(dataset.images_dir,dataset.query_path), 52 | transform=test_transformer), 53 | batch_size=batch_size, num_workers=workers, 54 | shuffle=False, pin_memory=True) 55 | 56 | gallery_loader = DataLoader( 57 | Preprocessor(dataset.gallery, root=osp.join(dataset.images_dir,dataset.gallery_path), 58 | transform=test_transformer), 59 | batch_size=batch_size, num_workers=workers, 60 | shuffle=False, pin_memory=True) 61 | 62 | 63 | return dataset, num_classes, train_loader, query_loader, gallery_loader 64 | 65 | 66 | def main(args): 67 | np.random.seed(args.seed) 68 | torch.manual_seed(args.seed) 69 | torch.cuda.manual_seed_all(args.seed) 70 | cudnn.benchmark = True 71 | 72 | # Redirect print to both console and log file 73 | if not args.evaluate: 74 | sys.stdout = Logger(osp.join(args.logs_dir, 'log.txt')) 75 | 76 | # Create data loaders 77 | if args.height is None or args.width is 
None: 78 | args.height, args.width = (144, 56) if args.arch == 'inception' else \ 79 | (256, 128) 80 | dataset, num_classes, train_loader, query_loader, gallery_loader = \ 81 | get_data(args.dataset, args.data_dir, args.height, 82 | args.width, args.batch_size, args.workers, 83 | ) 84 | 85 | 86 | # Create model 87 | model = models.create(args.arch, num_features=args.features, 88 | dropout=args.dropout, num_classes=num_classes,cut_at_pooling=False, FCN=True) 89 | 90 | # Load from checkpoint 91 | start_epoch = best_top1 = 0 92 | if args.resume: 93 | checkpoint = load_checkpoint(args.resume) 94 | model_dict = model.state_dict() 95 | checkpoint_load = {k: v for k, v in (checkpoint['state_dict']).items() if k in model_dict} 96 | model_dict.update(checkpoint_load) 97 | model.load_state_dict(model_dict) 98 | # model.load_state_dict(checkpoint['state_dict']) 99 | start_epoch = checkpoint['epoch'] 100 | best_top1 = checkpoint['best_top1'] 101 | print("=> Start epoch {} best top1 {:.1%}" 102 | .format(start_epoch, best_top1)) 103 | 104 | model = nn.DataParallel(model).cuda() 105 | 106 | 107 | # Evaluator 108 | evaluator = Evaluator(model) 109 | if args.evaluate: 110 | print("Test:") 111 | evaluator.evaluate(query_loader, gallery_loader, dataset.query, dataset.gallery) 112 | return 113 | 114 | # Criterion 115 | criterion = nn.CrossEntropyLoss().cuda() 116 | 117 | # Optimizer 118 | if hasattr(model.module, 'base'): 119 | base_param_ids = set(map(id, model.module.base.parameters())) 120 | new_params = [p for p in model.parameters() if 121 | id(p) not in base_param_ids] 122 | param_groups = [ 123 | {'params': model.module.base.parameters(), 'lr_mult': 0.1}, 124 | {'params': new_params, 'lr_mult': 1.0}] 125 | else: 126 | param_groups = model.parameters() 127 | optimizer = torch.optim.SGD(param_groups, lr=args.lr, 128 | momentum=args.momentum, 129 | weight_decay=args.weight_decay, 130 | nesterov=True) 131 | 132 | # Trainer 133 | trainer = Trainer(model, criterion, 0, 0, 
SMLoss_mode=0) 134 | 135 | # Schedule learning rate 136 | def adjust_lr(epoch): 137 | step_size = 60 if args.arch == 'inception' else args.step_size 138 | lr = args.lr * (0.1 ** (epoch // step_size)) 139 | for g in optimizer.param_groups: 140 | g['lr'] = lr * g.get('lr_mult', 1) 141 | 142 | # Start training 143 | for epoch in range(start_epoch, args.epochs): 144 | adjust_lr(epoch) 145 | trainer.train(epoch, train_loader, optimizer) 146 | is_best = True 147 | save_checkpoint({ 148 | 'state_dict': model.module.state_dict(), 149 | 'epoch': epoch + 1, 150 | 'best_top1': best_top1, 151 | }, is_best, fpath=osp.join(args.logs_dir, 'checkpoint.pth.tar')) 152 | 153 | # Final test 154 | print('Test with best model:') 155 | checkpoint = load_checkpoint(osp.join(args.logs_dir, 'checkpoint.pth.tar')) 156 | model.module.load_state_dict(checkpoint['state_dict']) 157 | evaluator.evaluate(query_loader, gallery_loader, dataset.query, dataset.gallery) 158 | 159 | 160 | if __name__ == '__main__': 161 | parser = argparse.ArgumentParser(description="Softmax loss classification") 162 | # data 163 | parser.add_argument('-d', '--dataset', type=str, default='cuhk03', 164 | choices=datasets.names()) 165 | parser.add_argument('-b', '--batch-size', type=int, default=256) 166 | parser.add_argument('-j', '--workers', type=int, default=4) 167 | parser.add_argument('--split', type=int, default=0) 168 | parser.add_argument('--height', type=int, 169 | help="input height, default: 256 for resnet*, " 170 | "144 for inception") 171 | parser.add_argument('--width', type=int, 172 | help="input width, default: 128 for resnet*, " 173 | "56 for inception") 174 | parser.add_argument('--combine-trainval', action='store_true', 175 | help="train and val sets together for training, " 176 | "val set alone for validation") 177 | # model 178 | parser.add_argument('-a', '--arch', type=str, default='resnet50', 179 | choices=models.names()) 180 | parser.add_argument('--features', type=int, default=128) 181 | 
parser.add_argument('--dropout', type=float, default=0.5) 182 | # optimizer 183 | parser.add_argument('--lr', type=float, default=0.1, 184 | help="learning rate of new parameters, for pretrained " 185 | "parameters it is 10 times smaller than this") 186 | parser.add_argument('--momentum', type=float, default=0.9) 187 | parser.add_argument('--weight-decay', type=float, default=5e-4) 188 | # training configs 189 | parser.add_argument('--resume', type=str, default='', metavar='PATH') 190 | parser.add_argument('--evaluate', action='store_true', 191 | help="evaluation only") 192 | parser.add_argument('--epochs', type=int, default=50) 193 | parser.add_argument('--step-size',type=int, default=40) 194 | parser.add_argument('--seed', type=int, default=1) 195 | parser.add_argument('--print-freq', type=int, default=1) 196 | # misc 197 | working_dir = osp.dirname(osp.abspath(__file__)) 198 | parser.add_argument('--data-dir', type=str, metavar='PATH', 199 | default=osp.join(working_dir, 'data')) 200 | parser.add_argument('--logs-dir', type=str, metavar='PATH', 201 | default=osp.join(working_dir, 'logs')) 202 | main(parser.parse_args()) 203 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Part-based Convolutional Baseline for Person Retrieval and the Refined Part Pooling 2 | 3 | Code for the paper [Beyond Part Models: Person Retrieval with Refined Part Pooling (and A Strong Convolutional Baseline)](https://arxiv.org/pdf/1711.09349.pdf). 4 | 5 | **This code is ONLY** released for academic use. 6 | 7 | ## Preparation 8 | 9 | 10 | **Prerequisite: Python 2.7 and Pytorch 0.3+** 11 | 12 | 1. Install [Pytorch](https://pytorch.org/) 13 | 14 | 2. Download dataset 15 | a. Market-1501 [BaiduYun](https://pan.baidu.com/s/1ntIi2Op?errno=0&errmsg=Auth%20Login%20Sucess&&bduss=&ssnerror=0&traceid=) 16 | b. 
DukeMTMC-reID[BaiduYun](https://pan.baidu.com/share/init?surl=jS0XM7Var5nQGcbf9xUztw) (password:bhbh) 17 | c. Move them to ```~/datasets/Market-1501/(DukeMTMC-reID)``` 18 | 19 | 20 | ## train PCB 21 | 22 | 23 | ```sh train_PCB.sh``` 24 | With Pytorch 0.4.0, we shall get about 93.0% rank-1 accuracy and 78.0% mAP on Market-1501. 25 | 26 | 27 | ## train RPP 28 | 29 | 30 | ```sh train_RPP.sh``` 31 | With Pytorch 0.4.0, we shall get about 93.5% rank-1 accuracy and 81.5% mAP on Market-1501. 32 | 33 | 34 | ## Citiaion 35 | 36 | 37 | Please cite this paper in your publications if it helps your research: 38 | 39 | 40 | ``` 41 | @inproceedings{sun2018PCB, 42 | author = {Yifan Sun and 43 | Liang Zheng and 44 | Yi Yang and 45 | Qi Tian and 46 | Shengjin Wang}, 47 | title = {Beyond Part Models: Person Retrieval with Refined Part Pooling (and A Strong Convolutional Baseline)}, 48 | booktitle = {ECCV}, 49 | year = {2018}, 50 | } 51 | ``` 52 | -------------------------------------------------------------------------------- /RPP.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import argparse 3 | import os.path as osp 4 | 5 | import numpy as np 6 | import sys 7 | import torch 8 | from torch import nn 9 | from torch.backends import cudnn 10 | from torch.utils.data import DataLoader 11 | 12 | from reid import datasets 13 | from reid import models 14 | from reid.trainers_partloss import Trainer 15 | from reid.evaluators import Evaluator 16 | from reid.utils.data import transforms as T 17 | #import torchvision.transforms as transforms 18 | from reid.utils.data.preprocessor import Preprocessor 19 | from reid.utils.logging import Logger 20 | from reid.utils.serialization import load_checkpoint, save_checkpoint 21 | 22 | 23 | def get_data(name, data_dir, height, width, batch_size, workers): 24 | root = osp.join(data_dir, name) 25 | root = data_dir 26 | dataset = datasets.create(name, root) 
27 | 28 | normalizer = T.Normalize(mean=[0.485, 0.456, 0.406], 29 | std=[0.229, 0.224, 0.225]) 30 | 31 | num_classes = dataset.num_train_ids 32 | 33 | train_transformer = T.Compose([ 34 | T.RectScale(height, width), 35 | T.RandomHorizontalFlip(), 36 | T.ToTensor(), 37 | normalizer, 38 | ]) 39 | 40 | test_transformer = T.Compose([ 41 | T.RectScale(height, width), 42 | T.ToTensor(), 43 | normalizer, 44 | ]) 45 | 46 | train_loader = DataLoader( 47 | Preprocessor(dataset.train, root=osp.join(dataset.images_dir,dataset.train_path), 48 | transform=train_transformer), 49 | batch_size=batch_size, num_workers=workers, 50 | shuffle=True, pin_memory=True, drop_last=True) 51 | 52 | query_loader = DataLoader( 53 | Preprocessor(dataset.query, root=osp.join(dataset.images_dir,dataset.query_path), 54 | transform=test_transformer), 55 | batch_size=batch_size, num_workers=workers, 56 | shuffle=False, pin_memory=True) 57 | 58 | gallery_loader = DataLoader( 59 | Preprocessor(dataset.gallery, root=osp.join(dataset.images_dir,dataset.gallery_path), 60 | transform=test_transformer), 61 | batch_size=batch_size, num_workers=workers, 62 | shuffle=False, pin_memory=True) 63 | 64 | 65 | return dataset, num_classes, train_loader, query_loader, gallery_loader 66 | 67 | 68 | 69 | 70 | def main(args): 71 | np.random.seed(args.seed) 72 | torch.manual_seed(args.seed) 73 | torch.cuda.manual_seed_all(args.seed) 74 | cudnn.benchmark = True 75 | 76 | # Redirect print to both console and log file 77 | if not args.evaluate: 78 | sys.stdout = Logger(osp.join(args.logs_dir, 'log.txt')) 79 | 80 | # Create data loaders 81 | if args.height is None or args.width is None: 82 | args.height, args.width = (144, 56) if args.arch == 'inception' else \ 83 | (256, 128) 84 | dataset, num_classes, train_loader, query_loader, gallery_loader = \ 85 | get_data(args.dataset, args.data_dir, args.height, 86 | args.width, args.batch_size, args.workers, 87 | ) 88 | 89 | 90 | # Create model 91 | model = models.create(args.arch, 
num_features=args.features, 92 | dropout=args.dropout, num_classes=num_classes,cut_at_pooling=False, FCN=True, T=args.T, dim=args.dim) 93 | 94 | # Load from checkpoint 95 | start_epoch = best_top1 = 0 96 | if args.resume: 97 | checkpoint = load_checkpoint(args.resume) 98 | 99 | #======================added by syf, to remove undeployed layers============= 100 | model_dict = model.state_dict() 101 | checkpoint_load = {k: v for k, v in (checkpoint['state_dict']).items() if k in model_dict} 102 | model_dict.update(checkpoint_load) 103 | model.load_state_dict(model_dict) 104 | # model.load_state_dict(checkpoint['state_dict']) 105 | start_epoch = checkpoint['epoch'] 106 | best_top1 = checkpoint['best_top1'] 107 | print("=> Start epoch {} best top1 {:.1%}" 108 | .format(start_epoch, best_top1)) 109 | model = nn.DataParallel(model).cuda() 110 | 111 | 112 | # Evaluator 113 | evaluator = Evaluator(model) 114 | if args.evaluate: 115 | print("Test:") 116 | evaluator.evaluate(query_loader, gallery_loader, dataset.query, dataset.gallery) 117 | return 118 | 119 | # Criterion 120 | criterion = nn.CrossEntropyLoss().cuda() 121 | 122 | # Optimizer 123 | if hasattr(model.module, 'base'): 124 | base_param_ids = set(map(id, model.module.base.parameters())) 125 | new_params = [p for p in model.parameters() if 126 | id(p) not in base_param_ids] 127 | param_groups = [ 128 | {'params': model.module.base.parameters(), 'lr_mult': args.lr_mult}, 129 | {'params': new_params, 'lr_mult': 1.0}] 130 | else: 131 | param_groups = model.parameters() 132 | optimizer = torch.optim.SGD(param_groups, lr=args.lr, 133 | momentum=args.momentum, 134 | weight_decay=args.weight_decay, 135 | nesterov=True) 136 | 137 | # Trainer 138 | trainer = Trainer(model, criterion, 0, 0, SMLoss_mode=0) 139 | 140 | # Schedule learning rate 141 | def adjust_lr(epoch): 142 | step_size = 60 if args.arch == 'inception' else args.step_size 143 | lr = args.lr * (0.1 ** (epoch // step_size)) 144 | for g in optimizer.param_groups: 
145 | g['lr'] = lr * g.get('lr_mult', 1) 146 | 147 | # Start training 148 | for epoch in range(start_epoch, args.epochs): 149 | adjust_lr(epoch) 150 | trainer.train(epoch, train_loader, optimizer) 151 | is_best = True 152 | save_checkpoint({ 153 | 'state_dict': model.module.state_dict(), 154 | 'epoch': epoch + 1, 155 | 'best_top1': best_top1, 156 | }, is_best, fpath=osp.join(args.logs_dir, 'checkpoint.pth.tar')) 157 | 158 | # Final test 159 | print('Test with best model:') 160 | checkpoint = load_checkpoint(osp.join(args.logs_dir, 'checkpoint.pth.tar')) 161 | model.module.load_state_dict(checkpoint['state_dict']) 162 | evaluator.evaluate(query_loader, gallery_loader, dataset.query, dataset.gallery) 163 | 164 | 165 | if __name__ == '__main__': 166 | parser = argparse.ArgumentParser(description="Softmax loss classification") 167 | # data 168 | parser.add_argument('-d', '--dataset', type=str, default='cuhk03', 169 | choices=datasets.names()) 170 | parser.add_argument('-b', '--batch-size', type=int, default=256) 171 | parser.add_argument('-j', '--workers', type=int, default=4) 172 | parser.add_argument('--split', type=int, default=0) 173 | parser.add_argument('--height', type=int, 174 | help="input height, default: 256 for resnet*, " 175 | "144 for inception") 176 | parser.add_argument('--width', type=int, 177 | help="input width, default: 128 for resnet*, " 178 | "56 for inception") 179 | parser.add_argument('--combine-trainval', action='store_true', 180 | help="train and val sets together for training, " 181 | "val set alone for validation") 182 | # model 183 | parser.add_argument('-a', '--arch', type=str, default='resnet50', 184 | choices=models.names()) 185 | parser.add_argument('--features', type=int, default=128) 186 | parser.add_argument('--dropout', type=float, default=0.5) 187 | # optimizer 188 | parser.add_argument('--lr', type=float, default=0.1, 189 | help="learning rate of new parameters, for pretrained " 190 | "parameters it is 10 times smaller than 
this") 191 | parser.add_argument('--momentum', type=float, default=0.9) 192 | parser.add_argument('--weight-decay', type=float, default=5e-4) 193 | # training configs 194 | parser.add_argument('--resume', type=str, default='', metavar='PATH') 195 | parser.add_argument('--evaluate', action='store_true', 196 | help="evaluation only") 197 | parser.add_argument('--epochs', type=int, default=50) 198 | parser.add_argument('--T', type=float, default =1)# temperature for soft pooling 199 | parser.add_argument('--lr_mult', type=float, default =1)# temperature for soft pooling 200 | parser.add_argument('--step_size',type=int, default =100) 201 | parser.add_argument('--dim', type=int, default =256)# temperature for soft pooling 202 | parser.add_argument('--seed', type=int, default=1) 203 | parser.add_argument('--print-freq', type=int, default=1) 204 | # metric learning 205 | parser.add_argument('--dist-metric', type=str, default='euclidean', 206 | choices=['euclidean', 'kissme']) 207 | # misc 208 | working_dir = osp.dirname(osp.abspath(__file__)) 209 | parser.add_argument('--data-dir', type=str, metavar='PATH', 210 | default=osp.join(working_dir, 'data')) 211 | parser.add_argument('--logs-dir', type=str, metavar='PATH', 212 | default=osp.join(working_dir, 'logs')) 213 | main(parser.parse_args()) 214 | -------------------------------------------------------------------------------- /reid/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from . import datasets 4 | from . import evaluation_metrics 5 | from . import feature_extraction 6 | from . import models 7 | from . import utils 8 | from . 
import evaluators 9 | __version__ = '0.2.0' 10 | -------------------------------------------------------------------------------- /reid/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/syfafterzy/PCB_RPP_for_reID/e29cf54486427d1423277d4c793e39ac0eeff87c/reid/__init__.pyc -------------------------------------------------------------------------------- /reid/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from .duke import Duke 3 | from .market import Market 4 | 5 | 6 | __factory = { 7 | 'market': Market, 8 | 'duke': Duke, 9 | } 10 | 11 | 12 | def names(): 13 | return sorted(__factory.keys()) 14 | 15 | 16 | def create(name, root, *args, **kwargs): 17 | """ 18 | Create a dataset instance. 19 | 20 | Parameters 21 | ---------- 22 | name : str 23 | The dataset name. Can be one of 'market', 'duke'. 24 | root : str 25 | The path to the dataset directory. 
26 | """ 27 | if name not in __factory: 28 | raise KeyError("Unknown dataset:", name) 29 | return __factory[name](root, *args, **kwargs) 30 | -------------------------------------------------------------------------------- /reid/datasets/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/syfafterzy/PCB_RPP_for_reID/e29cf54486427d1423277d4c793e39ac0eeff87c/reid/datasets/__init__.pyc -------------------------------------------------------------------------------- /reid/datasets/duke.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os.path as osp 3 | import numpy as np 4 | import pdb 5 | from glob import glob 6 | import re 7 | 8 | 9 | class Duke(object): 10 | 11 | def __init__(self, root): 12 | 13 | self.images_dir = osp.join(root) 14 | self.train_path = 'bounding_box_train' 15 | self.gallery_path = 'bounding_box_test' 16 | self.query_path = 'query' 17 | self.camstyle_path = 'bounding_box_train_camstyle' 18 | self.train, self.query, self.gallery, self.camstyle = [], [], [], [] 19 | self.num_train_ids, self.num_query_ids, self.num_gallery_ids, self.num_camstyle_ids = 0, 0, 0, 0 20 | self.load() 21 | 22 | def preprocess(self, path, relabel=True): 23 | pattern = re.compile(r'([-\d]+)_c(\d)') 24 | all_pids = {} 25 | ret = [] 26 | fpaths = sorted(glob(osp.join(self.images_dir, path, '*.jpg'))) 27 | for fpath in fpaths: 28 | fname = osp.basename(fpath) 29 | pid, cam = map(int, pattern.search(fname).groups()) 30 | if pid == -1: continue 31 | if relabel: 32 | if pid not in all_pids: 33 | all_pids[pid] = len(all_pids) 34 | else: 35 | if pid not in all_pids: 36 | all_pids[pid] = pid 37 | pid = all_pids[pid] 38 | cam -= 1 39 | ret.append((fname, pid, cam)) 40 | return ret, int(len(all_pids)) 41 | 42 | def load(self): 43 | self.train, self.num_train_ids = self.preprocess(self.train_path) 44 | 
self.gallery, self.num_gallery_ids = self.preprocess(self.gallery_path, False) 45 | self.query, self.num_query_ids = self.preprocess(self.query_path, False) 46 | self.camstyle, self.num_camstyle_ids = self.preprocess(self.camstyle_path) 47 | 48 | print(self.__class__.__name__, "dataset loaded") 49 | print(" subset | # ids | # images") 50 | print(" ---------------------------") 51 | print(" train | {:5d} | {:8d}" 52 | .format(self.num_train_ids, len(self.train))) 53 | print(" query | {:5d} | {:8d}" 54 | .format(self.num_query_ids, len(self.query))) 55 | print(" gallery | {:5d} | {:8d}" 56 | .format(self.num_gallery_ids, len(self.gallery))) 57 | print(" camstyle | {:5d} | {:8d}" 58 | .format(self.num_camstyle_ids, len(self.camstyle))) 59 | -------------------------------------------------------------------------------- /reid/datasets/duke.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/syfafterzy/PCB_RPP_for_reID/e29cf54486427d1423277d4c793e39ac0eeff87c/reid/datasets/duke.pyc -------------------------------------------------------------------------------- /reid/datasets/market.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os.path as osp 3 | import numpy as np 4 | import pdb 5 | from glob import glob 6 | import re 7 | 8 | 9 | class Market(object): 10 | 11 | def __init__(self, root): 12 | 13 | self.images_dir = osp.join(root) 14 | self.train_path = 'bounding_box_train' 15 | self.gallery_path = 'bounding_box_test' 16 | self.query_path = 'query' 17 | # self.camstyle_path = 'bounding_box_train_camstyle' 18 | self.train, self.query, self.gallery = [], [], [] 19 | self.num_train_ids, self.num_query_ids, self.num_gallery_ids = 0, 0, 0 20 | self.load() 21 | 22 | def preprocess(self, path, relabel=True): 23 | pattern = re.compile(r'([-\d]+)_c(\d)') 24 | all_pids = {} 25 | ret = [] 26 | fpaths = 
sorted(glob(osp.join(self.images_dir, path, '*.jpg'))) 27 | for fpath in fpaths: 28 | fname = osp.basename(fpath) 29 | pid, cam = map(int, pattern.search(fname).groups()) 30 | if pid == -1: continue 31 | if relabel: 32 | if pid not in all_pids: 33 | all_pids[pid] = len(all_pids) 34 | else: 35 | if pid not in all_pids: 36 | all_pids[pid] = pid 37 | pid = all_pids[pid] 38 | cam -= 1 39 | ret.append((fname, pid, cam)) 40 | return ret, int(len(all_pids)) 41 | 42 | def load(self): 43 | self.train, self.num_train_ids = self.preprocess(self.train_path) 44 | self.gallery, self.num_gallery_ids = self.preprocess(self.gallery_path, False) 45 | self.query, self.num_query_ids = self.preprocess(self.query_path, False) 46 | 47 | print(self.__class__.__name__, "dataset loaded") 48 | print(" subset | # ids | # images") 49 | print(" ---------------------------") 50 | print(" train | {:5d} | {:8d}" 51 | .format(self.num_train_ids, len(self.train))) 52 | print(" query | {:5d} | {:8d}" 53 | .format(self.num_query_ids, len(self.query))) 54 | print(" gallery | {:5d} | {:8d}" 55 | .format(self.num_gallery_ids, len(self.gallery))) 56 | -------------------------------------------------------------------------------- /reid/datasets/market.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/syfafterzy/PCB_RPP_for_reID/e29cf54486427d1423277d4c793e39ac0eeff87c/reid/datasets/market.pyc -------------------------------------------------------------------------------- /reid/evaluation_metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .classification import accuracy 4 | from .ranking import cmc, mean_ap 5 | 6 | __all__ = [ 7 | 'accuracy', 8 | 'cmc', 9 | 'mean_ap', 10 | ] 11 | -------------------------------------------------------------------------------- /reid/evaluation_metrics/__init__.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/syfafterzy/PCB_RPP_for_reID/e29cf54486427d1423277d4c793e39ac0eeff87c/reid/evaluation_metrics/__init__.pyc -------------------------------------------------------------------------------- /reid/evaluation_metrics/classification.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from ..utils import to_torch 4 | 5 | 6 | def accuracy(output, target, topk=(1,)): 7 | output, target = to_torch(output), to_torch(target) 8 | maxk = max(topk) 9 | batch_size = target.size(0) 10 | 11 | _, pred = output.topk(maxk, 1, True, True) 12 | pred = pred.t() 13 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 14 | 15 | ret = [] 16 | for k in topk: 17 | correct_k = correct[:k].view(-1).float().sum(0) 18 | ret.append(correct_k.mul_(1. / batch_size)) 19 | return ret 20 | -------------------------------------------------------------------------------- /reid/evaluation_metrics/classification.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/syfafterzy/PCB_RPP_for_reID/e29cf54486427d1423277d4c793e39ac0eeff87c/reid/evaluation_metrics/classification.pyc -------------------------------------------------------------------------------- /reid/evaluation_metrics/ranking.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from collections import defaultdict 3 | 4 | import numpy as np 5 | from sklearn.metrics import average_precision_score 6 | 7 | from ..utils import to_numpy 8 | 9 | 10 | def _unique_sample(ids_dict, num): 11 | mask = np.zeros(num, dtype=np.bool) 12 | for _, indices in ids_dict.items(): 13 | i = np.random.choice(indices) 14 | mask[i] = True 15 | return mask 16 | 17 | 18 | def cmc(distmat, query_ids=None, gallery_ids=None, 19 | 
query_cams=None, gallery_cams=None, topk=100, 20 | separate_camera_set=False, 21 | single_gallery_shot=False, 22 | first_match_break=False): 23 | distmat = to_numpy(distmat) 24 | m, n = distmat.shape 25 | # Fill up default values 26 | if query_ids is None: 27 | query_ids = np.arange(m) 28 | if gallery_ids is None: 29 | gallery_ids = np.arange(n) 30 | if query_cams is None: 31 | query_cams = np.zeros(m).astype(np.int32) 32 | if gallery_cams is None: 33 | gallery_cams = np.ones(n).astype(np.int32) 34 | # Ensure numpy array 35 | query_ids = np.asarray(query_ids) 36 | gallery_ids = np.asarray(gallery_ids) 37 | query_cams = np.asarray(query_cams) 38 | gallery_cams = np.asarray(gallery_cams) 39 | # Sort and find correct matches 40 | indices = np.argsort(distmat, axis=1) 41 | matches = (gallery_ids[indices] == query_ids[:, np.newaxis]) 42 | # Compute CMC for each query 43 | ret = np.zeros(topk) 44 | num_valid_queries = 0 45 | for i in range(m): 46 | # Filter out the same id and same camera 47 | valid = ((gallery_ids[indices[i]] != query_ids[i]) | 48 | (gallery_cams[indices[i]] != query_cams[i])) 49 | if separate_camera_set: 50 | # Filter out samples from same camera 51 | valid &= (gallery_cams[indices[i]] != query_cams[i]) 52 | if not np.any(matches[i, valid]): continue 53 | if single_gallery_shot: 54 | repeat = 10 55 | gids = gallery_ids[indices[i][valid]] 56 | inds = np.where(valid)[0] 57 | ids_dict = defaultdict(list) 58 | for j, x in zip(inds, gids): 59 | ids_dict[x].append(j) 60 | else: 61 | repeat = 1 62 | for _ in range(repeat): 63 | if single_gallery_shot: 64 | # Randomly choose one instance for each id 65 | sampled = (valid & _unique_sample(ids_dict, len(valid))) 66 | index = np.nonzero(matches[i, sampled])[0] 67 | else: 68 | index = np.nonzero(matches[i, valid])[0] 69 | delta = 1. 
/ (len(index) * repeat) 70 | for j, k in enumerate(index): 71 | if k - j >= topk: break 72 | if first_match_break: 73 | ret[k - j] += 1 74 | break 75 | ret[k - j] += delta 76 | num_valid_queries += 1 77 | if num_valid_queries == 0: 78 | raise RuntimeError("No valid query") 79 | return ret.cumsum() / num_valid_queries 80 | 81 | 82 | def mean_ap(distmat, query_ids=None, gallery_ids=None, 83 | query_cams=None, gallery_cams=None): 84 | distmat = to_numpy(distmat) 85 | m, n = distmat.shape 86 | # Fill up default values 87 | if query_ids is None: 88 | query_ids = np.arange(m) 89 | if gallery_ids is None: 90 | gallery_ids = np.arange(n) 91 | if query_cams is None: 92 | query_cams = np.zeros(m).astype(np.int32) 93 | if gallery_cams is None: 94 | gallery_cams = np.ones(n).astype(np.int32) 95 | # Ensure numpy array 96 | query_ids = np.asarray(query_ids) 97 | gallery_ids = np.asarray(gallery_ids) 98 | query_cams = np.asarray(query_cams) 99 | gallery_cams = np.asarray(gallery_cams) 100 | # Sort and find correct matches 101 | indices = np.argsort(distmat, axis=1) 102 | matches = (gallery_ids[indices] == query_ids[:, np.newaxis]) 103 | # Compute AP for each query 104 | aps = [] 105 | for i in range(m): 106 | # Filter out the same id and same camera 107 | valid = ((gallery_ids[indices[i]] != query_ids[i]) | 108 | (gallery_cams[indices[i]] != query_cams[i])) 109 | y_true = matches[i, valid] 110 | y_score = -distmat[i][indices[i]][valid] 111 | if not np.any(y_true): continue 112 | aps.append(average_precision_score(y_true, y_score)) 113 | if len(aps) == 0: 114 | raise RuntimeError("No valid query") 115 | return np.mean(aps) 116 | -------------------------------------------------------------------------------- /reid/evaluation_metrics/ranking.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/syfafterzy/PCB_RPP_for_reID/e29cf54486427d1423277d4c793e39ac0eeff87c/reid/evaluation_metrics/ranking.pyc 
-------------------------------------------------------------------------------- /reid/evaluators.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import time 3 | from collections import OrderedDict 4 | 5 | import torch 6 | 7 | from .evaluation_metrics import cmc, mean_ap 8 | from .feature_extraction import extract_cnn_feature 9 | from .utils.meters import AverageMeter 10 | 11 | 12 | def extract_features(model, data_loader, print_freq=10): 13 | model.eval() 14 | batch_time = AverageMeter() 15 | data_time = AverageMeter() 16 | 17 | features = OrderedDict() 18 | labels = OrderedDict() 19 | 20 | end = time.time() 21 | for i, (imgs, fnames, pids, _) in enumerate(data_loader): 22 | data_time.update(time.time() - end) 23 | 24 | outputs = extract_cnn_feature(model, imgs) 25 | for fname, output, pid in zip(fnames, outputs, pids): 26 | features[fname] = output 27 | labels[fname] = pid 28 | 29 | batch_time.update(time.time() - end) 30 | end = time.time() 31 | 32 | if (i + 1) % print_freq == 0: 33 | print('Extract Features: [{}/{}]\t' 34 | 'Time {:.3f} ({:.3f})\t' 35 | 'Data {:.3f} ({:.3f})\t' 36 | .format(i + 1, len(data_loader), 37 | batch_time.val, batch_time.avg, 38 | data_time.val, data_time.avg)) 39 | 40 | return features, labels 41 | 42 | 43 | def pairwise_distance(query_features, gallery_features, query=None, gallery=None): 44 | if query is None and gallery is None: 45 | n = len(features) 46 | x = torch.cat(list(features.values())) 47 | x = x.view(n, -1) 48 | dist = torch.pow(x, 2).sum(1) * 2 49 | dist = dist.expand(n, n) - 2 * torch.mm(x, x.t()) 50 | return dist 51 | 52 | x = torch.cat([query_features[f].unsqueeze(0) for f, _, _ in query], 0) 53 | y = torch.cat([gallery_features[f].unsqueeze(0) for f, _, _ in gallery], 0) 54 | m, n = x.size(0), y.size(0) 55 | x = x.view(m, -1) 56 | y = y.view(n, -1) 57 | dist = torch.pow(x, 2).sum(1).unsqueeze(1).expand(m, n) + \ 58 | 
torch.pow(y, 2).sum(1).unsqueeze(1).expand(n, m).t() 59 | dist.addmm_(1, -2, x, y.t()) 60 | return dist 61 | 62 | 63 | def evaluate_all(distmat, query=None, gallery=None, 64 | query_ids=None, gallery_ids=None, 65 | query_cams=None, gallery_cams=None, 66 | cmc_topk=(1, 5, 10)): 67 | if query is not None and gallery is not None: 68 | query_ids = [pid for _, pid, _ in query] 69 | gallery_ids = [pid for _, pid, _ in gallery] 70 | query_cams = [cam for _, _, cam in query] 71 | gallery_cams = [cam for _, _, cam in gallery] 72 | else: 73 | assert (query_ids is not None and gallery_ids is not None 74 | and query_cams is not None and gallery_cams is not None) 75 | 76 | # Compute mean AP 77 | mAP = mean_ap(distmat, query_ids, gallery_ids, query_cams, gallery_cams) 78 | print('Mean AP: {:4.1%}'.format(mAP)) 79 | 80 | # Compute all kinds of CMC scores 81 | cmc_configs = { 82 | 'allshots': dict(separate_camera_set=False, 83 | single_gallery_shot=False, 84 | first_match_break=False), 85 | 'cuhk03': dict(separate_camera_set=True, 86 | single_gallery_shot=True, 87 | first_match_break=False), 88 | 'market1501': dict(separate_camera_set=False, 89 | single_gallery_shot=False, 90 | first_match_break=True)} 91 | cmc_scores = {name: cmc(distmat, query_ids, gallery_ids, 92 | query_cams, gallery_cams, **params) 93 | for name, params in cmc_configs.items()} 94 | 95 | print('CMC Scores{:>12}{:>12}{:>12}' 96 | .format('allshots', 'cuhk03', 'market1501')) 97 | for k in cmc_topk: 98 | print(' top-{:<4}{:12.1%}{:12.1%}{:12.1%}' 99 | .format(k, cmc_scores['allshots'][k - 1], 100 | cmc_scores['cuhk03'][k - 1], 101 | cmc_scores['market1501'][k - 1])) 102 | 103 | # Use the allshots cmc top-1 score for validation criterion 104 | return cmc_scores['allshots'][0] 105 | 106 | 107 | class Evaluator(object): 108 | def __init__(self, model): 109 | super(Evaluator, self).__init__() 110 | self.model = model 111 | 112 | def evaluate(self, query_loader, gallery_loader, query, gallery): 113 | print('extracting 
query features\n') 114 | query_features, _ = extract_features(self.model, query_loader) 115 | print('extracting gallery features\n') 116 | gallery_features, _ = extract_features(self.model, gallery_loader) 117 | distmat = pairwise_distance(query_features, gallery_features, query, gallery) 118 | return evaluate_all(distmat, query=query, gallery=gallery) 119 | -------------------------------------------------------------------------------- /reid/evaluators.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/syfafterzy/PCB_RPP_for_reID/e29cf54486427d1423277d4c793e39ac0eeff87c/reid/evaluators.pyc -------------------------------------------------------------------------------- /reid/feature_extraction/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .cnn import extract_cnn_feature 4 | from .database import FeatureDatabase 5 | 6 | __all__ = [ 7 | 'extract_cnn_feature', 8 | 'FeatureDatabase', 9 | ] 10 | -------------------------------------------------------------------------------- /reid/feature_extraction/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/syfafterzy/PCB_RPP_for_reID/e29cf54486427d1423277d4c793e39ac0eeff87c/reid/feature_extraction/__init__.pyc -------------------------------------------------------------------------------- /reid/feature_extraction/cnn.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from collections import OrderedDict 3 | 4 | from torch.autograd import Variable 5 | 6 | from ..utils import to_torch 7 | 8 | 9 | def extract_cnn_feature(model, inputs, modules=None, return_mask = False): 10 | model.eval() 11 | inputs = to_torch(inputs) 12 | inputs = Variable(inputs, volatile=True) 13 | if modules is None: 14 | tmp = 
model(inputs) 15 | outputs = tmp[0] 16 | outputs = outputs.data.cpu() 17 | if return_mask: 18 | mask = tmp[4] 19 | mask = mask.data.cpu() 20 | return outputs, mask 21 | return outputs 22 | # Register forward hook for each module 23 | outputs = OrderedDict() 24 | handles = [] 25 | for m in modules: 26 | outputs[id(m)] = None 27 | def func(m, i, o): outputs[id(m)] = o.data.cpu() 28 | handles.append(m.register_forward_hook(func)) 29 | model(inputs) 30 | for h in handles: 31 | h.remove() 32 | return list(outputs.values()) 33 | -------------------------------------------------------------------------------- /reid/feature_extraction/cnn.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/syfafterzy/PCB_RPP_for_reID/e29cf54486427d1423277d4c793e39ac0eeff87c/reid/feature_extraction/cnn.pyc -------------------------------------------------------------------------------- /reid/feature_extraction/database.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import h5py 4 | import numpy as np 5 | from torch.utils.data import Dataset 6 | 7 | 8 | class FeatureDatabase(Dataset): 9 | def __init__(self, *args, **kwargs): 10 | super(FeatureDatabase, self).__init__() 11 | self.fid = h5py.File(*args, **kwargs) 12 | 13 | def __enter__(self): 14 | return self 15 | 16 | def __exit__(self, exc_type, exc_val, exc_tb): 17 | self.close() 18 | 19 | def __getitem__(self, keys): 20 | if isinstance(keys, (tuple, list)): 21 | return [self._get_single_item(k) for k in keys] 22 | return self._get_single_item(keys) 23 | 24 | def _get_single_item(self, key): 25 | return np.asarray(self.fid[key]) 26 | 27 | def __setitem__(self, key, value): 28 | if key in self.fid: 29 | if self.fid[key].shape == value.shape and \ 30 | self.fid[key].dtype == value.dtype: 31 | self.fid[key][...] 
= value 32 | else: 33 | del self.fid[key] 34 | self.fid.create_dataset(key, data=value) 35 | else: 36 | self.fid.create_dataset(key, data=value) 37 | 38 | def __delitem__(self, key): 39 | del self.fid[key] 40 | 41 | def __len__(self): 42 | return len(self.fid) 43 | 44 | def __iter__(self): 45 | return iter(self.fid) 46 | 47 | def flush(self): 48 | self.fid.flush() 49 | 50 | def close(self): 51 | self.fid.close() 52 | -------------------------------------------------------------------------------- /reid/feature_extraction/database.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/syfafterzy/PCB_RPP_for_reID/e29cf54486427d1423277d4c793e39ac0eeff87c/reid/feature_extraction/database.pyc -------------------------------------------------------------------------------- /reid/models/.resnet.py.swo: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/syfafterzy/PCB_RPP_for_reID/e29cf54486427d1423277d4c793e39ac0eeff87c/reid/models/.resnet.py.swo -------------------------------------------------------------------------------- /reid/models/.resnet_mask.py.swo: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/syfafterzy/PCB_RPP_for_reID/e29cf54486427d1423277d4c793e39ac0eeff87c/reid/models/.resnet_mask.py.swo -------------------------------------------------------------------------------- /reid/models/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .resnet import * 4 | from .resnet_rpp import resnet50_rpp 5 | 6 | __factory = { 7 | 'resnet18': resnet18, 8 | 'resnet34': resnet34, 9 | 'resnet50': resnet50, 10 | 'resnet101': resnet101, 11 | 'resnet152': resnet152, 12 | 'resnet50_rpp': resnet50_rpp, 13 | } 14 | 15 | 16 | def names(): 17 | return sorted(__factory.keys()) 18 | 19 | 20 | def 
create(name, *args, **kwargs): 21 | """ 22 | Create a model instance. 23 | 24 | Parameters 25 | ---------- 26 | name : str 27 | Model name. Can be one of 'inception', 'resnet18', 'resnet34', 28 | 'resnet50', 'resnet101', and 'resnet152'. 29 | pretrained : bool, optional 30 | Only applied for 'resnet*' models. If True, will use ImageNet pretrained 31 | model. Default: True 32 | cut_at_pooling : bool, optional 33 | If True, will cut the model before the last global pooling layer and 34 | ignore the remaining kwargs. Default: False 35 | num_features : int, optional 36 | If positive, will append a Linear layer after the global pooling layer, 37 | with this number of output units, followed by a BatchNorm layer. 38 | Otherwise these layers will not be appended. Default: 256 for 39 | 'inception', 0 for 'resnet*' 40 | norm : bool, optional 41 | If True, will normalize the feature to be unit L2-norm for each sample. 42 | Otherwise will append a ReLU layer after the above Linear layer if 43 | num_features > 0. Default: False 44 | dropout : float, optional 45 | If positive, will append a Dropout layer with this dropout rate. 46 | Default: 0 47 | num_classes : int, optional 48 | If positive, will append a Linear layer at the end as the classifier 49 | with this number of output units. 
Default: 0 50 | """ 51 | if name not in __factory: 52 | raise KeyError("Unknown model:", name) 53 | return __factory[name](*args, **kwargs) 54 | -------------------------------------------------------------------------------- /reid/models/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/syfafterzy/PCB_RPP_for_reID/e29cf54486427d1423277d4c793e39ac0eeff87c/reid/models/__init__.pyc -------------------------------------------------------------------------------- /reid/models/resnet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from torch.nn import init 6 | import torchvision 7 | import torch 8 | #from torch_deform_conv.layers import ConvOffset2D 9 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 10 | 'resnet152'] 11 | 12 | 13 | class ResNet(nn.Module): 14 | __factory = { 15 | 18: torchvision.models.resnet18, 16 | 34: torchvision.models.resnet34, 17 | 50: torchvision.models.resnet50, 18 | 101: torchvision.models.resnet101, 19 | 152: torchvision.models.resnet152, 20 | } 21 | 22 | def __init__(self, depth, pretrained=True, cut_at_pooling=False, 23 | num_features=0, norm=False, dropout=0, num_classes=0, FCN=False, radius=1., thresh=0.5): 24 | super(ResNet, self).__init__() 25 | 26 | self.depth = depth 27 | self.pretrained = pretrained 28 | self.cut_at_pooling = cut_at_pooling 29 | self.FCN=FCN 30 | 31 | # Construct base (pretrained) resnet 32 | if depth not in ResNet.__factory: 33 | raise KeyError("Unsupported depth:", depth) 34 | self.base = ResNet.__factory[depth](pretrained=pretrained) 35 | 36 | #==========================add dilation=============================# 37 | if self.FCN: 38 | for mo in self.base.layer4[0].modules(): 39 | if isinstance(mo, nn.Conv2d): 40 | mo.stride = (1,1) 41 | #================append 
conv for FCN==============================# 42 | self.num_features = num_features 43 | self.num_classes = 751 #num_classes 44 | self.dropout = dropout 45 | out_planes = self.base.fc.in_features 46 | self.local_conv = nn.Conv2d(out_planes, self.num_features, kernel_size=1,padding=0,bias=False) 47 | init.kaiming_normal(self.local_conv.weight, mode= 'fan_out') 48 | # init.constant(self.local_conv.bias,0) 49 | self.feat_bn2d = nn.BatchNorm2d(self.num_features) #may not be used, not working on caffe 50 | init.constant(self.feat_bn2d.weight,1) #initialize BN, may not be used 51 | init.constant(self.feat_bn2d.bias,0) # iniitialize BN, may not be used 52 | 53 | ##---------------------------stripe1----------------------------------------------# 54 | self.instance0 = nn.Linear(self.num_features, self.num_classes) 55 | init.normal(self.instance0.weight, std=0.001) 56 | init.constant(self.instance0.bias, 0) 57 | ##---------------------------stripe1----------------------------------------------# 58 | ##---------------------------stripe1----------------------------------------------# 59 | self.instance1 = nn.Linear(self.num_features, self.num_classes) 60 | init.normal(self.instance1.weight, std=0.001) 61 | init.constant(self.instance1.bias, 0) 62 | ##---------------------------stripe1----------------------------------------------# 63 | ##---------------------------stripe1----------------------------------------------# 64 | self.instance2 = nn.Linear(self.num_features, self.num_classes) 65 | init.normal(self.instance2.weight, std=0.001) 66 | init.constant(self.instance2.bias, 0) 67 | ##---------------------------stripe1----------------------------------------------# 68 | ##---------------------------stripe1----------------------------------------------# 69 | self.instance3 = nn.Linear(self.num_features, self.num_classes) 70 | init.normal(self.instance3.weight, std=0.001) 71 | init.constant(self.instance3.bias, 0) 72 | 
##---------------------------stripe1----------------------------------------------# 73 | ##---------------------------stripe1----------------------------------------------# 74 | self.instance4 = nn.Linear(self.num_features, self.num_classes) 75 | init.normal(self.instance4.weight, std=0.001) 76 | init.constant(self.instance4.bias, 0) 77 | ##---------------------------stripe1----------------------------------------------# 78 | ##---------------------------stripe1----------------------------------------------# 79 | self.instance5 = nn.Linear(self.num_features, self.num_classes) 80 | init.normal(self.instance5.weight, std=0.001) 81 | init.constant(self.instance5.bias, 0) 82 | 83 | self.drop = nn.Dropout(self.dropout) 84 | 85 | elif not self.cut_at_pooling: 86 | self.num_features = num_features 87 | self.norm = norm 88 | self.dropout = dropout 89 | self.has_embedding = num_features > 0 90 | self.num_classes = num_classes 91 | 92 | self.radius = nn.Parameter(torch.FloatTensor([radius])) 93 | self.thresh = nn.Parameter(torch.FloatTensor([thresh])) 94 | 95 | 96 | 97 | out_planes = self.base.fc.in_features 98 | 99 | # Append new layers 100 | if self.has_embedding: 101 | self.feat = nn.Linear(out_planes, self.num_features, bias=False) 102 | self.feat_bn = nn.BatchNorm1d(self.num_features) 103 | init.kaiming_normal(self.feat.weight, mode='fan_out') 104 | else: 105 | # Change the num_features to CNN output channels 106 | self.num_features = out_planes 107 | if self.dropout > 0: 108 | self.drop = nn.Dropout(self.dropout) 109 | if self.num_classes > 0: 110 | self.classifier = nn.Linear(self.num_features, self.num_classes,bias=True) 111 | init.normal(self.classifier.weight, std=0.001) 112 | init.constant(self.classifier.bias, 0) 113 | 114 | if not self.pretrained: 115 | self.reset_params() 116 | 117 | def forward(self, x): 118 | for name, module in self.base._modules.items(): 119 | if name == 'avgpool': 120 | break 121 | x = module(x) 122 | 123 | if self.cut_at_pooling: 124 | 
return x 125 | #=======================FCN===============================# 126 | if self.FCN: 127 | y = x.unsqueeze(1) 128 | y = F.avg_pool3d(x,(16,1,1)).squeeze(1) 129 | sx = x.size(2)/6 130 | kx = x.size(2)-sx*5 131 | x = F.avg_pool2d(x,kernel_size=(kx,x.size(3)),stride=(sx,x.size(3))) # H4 W8 132 | #========================================================================# 133 | 134 | out0 = x.view(x.size(0),-1) 135 | out0 = x/x.norm(2,1).unsqueeze(1).expand_as(x) 136 | x = self.drop(x) 137 | x = self.local_conv(x) 138 | out1 = x/x.norm(2,1).unsqueeze(1).expand_as(x) 139 | x = self.feat_bn2d(x) 140 | x = F.relu(x) # relu for local_conv feature 141 | 142 | x = x.chunk(6,2) 143 | x0 = x[0].contiguous().view(x[0].size(0),-1) 144 | x1 = x[1].contiguous().view(x[1].size(0),-1) 145 | x2 = x[2].contiguous().view(x[2].size(0),-1) 146 | x3 = x[3].contiguous().view(x[3].size(0),-1) 147 | x4 = x[4].contiguous().view(x[4].size(0),-1) 148 | x5 = x[5].contiguous().view(x[5].size(0),-1) 149 | 150 | c0 = self.instance0(x0) 151 | c1 = self.instance1(x1) 152 | c2 = self.instance2(x2) 153 | c3 = self.instance3(x3) 154 | c4 = self.instance4(x4) 155 | c5 = self.instance5(x5) 156 | return out0, (c0, c1, c2, c3, c4, c5) 157 | 158 | #==========================================================# 159 | 160 | 161 | x = F.avg_pool2d(x, x.size()[2:]) 162 | x = x.view(x.size(0), -1) 163 | out1 = x.view(x.size(0),-1) 164 | center = out1.mean(0).unsqueeze(0).expand_as(out1) 165 | out2 = x/x.norm(2,1).unsqueeze(1).expand_as(x) 166 | 167 | if self.has_embedding: 168 | x = self.feat(x) 169 | out3 = x/ x.norm(2,1).unsqueeze(1).expand_as(x) 170 | x = self.feat_bn(x) 171 | 172 | if self.norm: 173 | x = x / x.norm(2, 1).unsqueeze(1).expand_as(x) 174 | elif self.has_embedding: # adding relu after fc, not used in softmax but in tripletloss 175 | x = F.relu(x) 176 | out4 = x/ x.norm(2,1).unsqueeze(1).expand_as(x) 177 | if self.dropout > 0: 178 | x = self.drop(x) 179 | if self.num_classes > 0: 180 | x = 
self.classifier(x) 181 | 182 | return out2, x, out2, out2 183 | 184 | 185 | def reset_params(self): 186 | for m in self.modules(): 187 | if isinstance(m, nn.Conv2d): 188 | init.kaiming_normal(m.weight, mode='fan_out') 189 | if m.bias is not None: 190 | init.constant(m.bias, 0) 191 | elif isinstance(m, nn.BatchNorm2d): 192 | init.constant(m.weight, 1) 193 | init.constant(m.bias, 0) 194 | elif isinstance(m, nn.Linear): 195 | init.normal(m.weight, std=0.001) 196 | if m.bias is not None: 197 | init.constant(m.bias, 0) 198 | 199 | 200 | def resnet18(**kwargs): 201 | return ResNet(18, **kwargs) 202 | 203 | 204 | def resnet34(**kwargs): 205 | return ResNet(34, **kwargs) 206 | 207 | 208 | def resnet50(**kwargs): 209 | return ResNet(50, **kwargs) 210 | 211 | 212 | def resnet101(**kwargs): 213 | return ResNet(101, **kwargs) 214 | 215 | 216 | def resnet152(**kwargs): 217 | return ResNet(152, **kwargs) 218 | -------------------------------------------------------------------------------- /reid/models/resnet.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/syfafterzy/PCB_RPP_for_reID/e29cf54486427d1423277d4c793e39ac0eeff87c/reid/models/resnet.pyc -------------------------------------------------------------------------------- /reid/models/resnet_mask.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/syfafterzy/PCB_RPP_for_reID/e29cf54486427d1423277d4c793e39ac0eeff87c/reid/models/resnet_mask.pyc -------------------------------------------------------------------------------- /reid/models/resnet_rpp.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from torch.nn import init 6 | import torchvision 7 | import torch 8 | __all__ = ['resnet50_rpp'] 9 | 10 | 11 | class ResNet(nn.Module): 12 | 
__factory = { 13 | 18: torchvision.models.resnet18, 14 | 34: torchvision.models.resnet34, 15 | 50: torchvision.models.resnet50, 16 | 101: torchvision.models.resnet101, 17 | 152: torchvision.models.resnet152, 18 | } 19 | 20 | def __init__(self, depth, pretrained=True, cut_at_pooling=False, 21 | num_features=0, norm=False, dropout=0, num_classes=0, FCN=False, T=1, dim = 256): 22 | super(ResNet, self).__init__() 23 | 24 | self.depth = depth 25 | self.pretrained = pretrained 26 | self.cut_at_pooling = cut_at_pooling 27 | self.FCN=FCN 28 | self.T = T 29 | self.reduce_dim = dim 30 | if depth not in ResNet.__factory: 31 | raise KeyError("Unsupported depth:", depth) 32 | self.base = ResNet.__factory[depth](pretrained=pretrained) 33 | 34 | #==========================add dilation=============================# 35 | if self.FCN: 36 | self.base.layer4[0].conv2.stride=(1,1) 37 | self.base.layer4[0].downsample[0].stride=(1,1) 38 | #================append conv for FCN==============================# 39 | self.num_features = num_features 40 | self.num_classes = num_classes 41 | self.dropout = dropout 42 | self.local_conv = nn.Conv2d(2048, self.num_features, kernel_size=1,padding=0,bias=False) 43 | init.kaiming_normal(self.local_conv.weight, mode= 'fan_out') 44 | # init.constant(self.local_conv.bias,0) 45 | self.feat_bn2d = nn.BatchNorm2d(self.num_features) #may not be used, not working on caffe 46 | init.constant(self.feat_bn2d.weight,1) #initialize BN, may not be used 47 | init.constant(self.feat_bn2d.bias,0) # iniitialize BN, may not be used 48 | 49 | ##---------------------------stripe1----------------------------------------------# 50 | self.instance0 = nn.Linear(self.num_features, self.num_classes) 51 | init.normal(self.instance0.weight, std=0.001) 52 | init.constant(self.instance0.bias, 0) 53 | ##---------------------------stripe1----------------------------------------------# 54 | ##---------------------------stripe1----------------------------------------------# 55 | 
self.instance1 = nn.Linear(self.num_features, self.num_classes) 56 | init.normal(self.instance1.weight, std=0.001) 57 | init.constant(self.instance1.bias, 0) 58 | ##---------------------------stripe1----------------------------------------------# 59 | ##---------------------------stripe1----------------------------------------------# 60 | self.instance2 = nn.Linear(self.num_features, self.num_classes) 61 | init.normal(self.instance2.weight, std=0.001) 62 | init.constant(self.instance2.bias, 0) 63 | ##---------------------------stripe1----------------------------------------------# 64 | ##---------------------------stripe1----------------------------------------------# 65 | self.instance3 = nn.Linear(self.num_features, self.num_classes) 66 | init.normal(self.instance3.weight, std=0.001) 67 | init.constant(self.instance3.bias, 0) 68 | ##---------------------------stripe1----------------------------------------------# 69 | ##---------------------------stripe1----------------------------------------------# 70 | self.instance4 = nn.Linear(self.num_features, self.num_classes) 71 | init.normal(self.instance4.weight, std=0.001) 72 | init.constant(self.instance4.bias, 0) 73 | ##---------------------------stripe1----------------------------------------------# 74 | ##---------------------------stripe1----------------------------------------------# 75 | self.instance5 = nn.Linear(self.num_features, self.num_classes) 76 | init.normal(self.instance5.weight, std=0.001) 77 | init.constant(self.instance5.bias, 0) 78 | ##---------------------------stripe1----------------------------------------------# 79 | self.drop = nn.Dropout(self.dropout) 80 | self.local_mask = nn.Conv2d(self.reduce_dim, 6 , kernel_size=1,padding=0,bias=True) 81 | init.kaiming_normal(self.local_mask.weight, mode= 'fan_out') 82 | init.constant(self.local_mask.bias,0) 83 | 84 | #===================================================================# 85 | 86 | elif not self.cut_at_pooling: 87 | self.num_features = 
num_features 88 | self.norm = norm 89 | self.dropout = dropout 90 | self.has_embedding = num_features > 0 91 | self.num_classes = num_classes 92 | 93 | out_planes = self.base.fc.in_features 94 | 95 | # Append new layers 96 | if self.has_embedding: 97 | self.feat = nn.Linear(out_planes, self.num_features, bias=False) 98 | self.feat_bn = nn.BatchNorm1d(self.num_features) 99 | init.kaiming_normal(self.feat.weight, mode='fan_out') 100 | # init.constant(self.feat.bias, 0) 101 | init.constant(self.feat_bn.weight, 1) 102 | init.constant(self.feat_bn.bias, 0) 103 | else: 104 | # Change the num_features to CNN output channels 105 | self.num_features = out_planes 106 | if self.dropout > 0: 107 | self.drop = nn.Dropout(self.dropout) 108 | if self.num_classes > 0: 109 | # self.classifier = nn.Linear(self.num_features, self.num_classes) 110 | self.classifier = nn.Linear(self.num_features, self.num_classes) 111 | init.normal(self.classifier.weight, std=0.001) 112 | init.constant(self.classifier.bias, 0) 113 | 114 | if not self.pretrained: 115 | self.reset_params() 116 | 117 | def forward(self, x): 118 | for name, module in self.base._modules.items(): 119 | if name == 'avgpool': 120 | break 121 | x = module(x) 122 | 123 | if self.cut_at_pooling: 124 | return x 125 | #=======================FCN===============================# 126 | if self.FCN: 127 | T = self.T 128 | y = self.drop(x).unsqueeze(1) 129 | stride = 2048/self.reduce_dim 130 | y = F.avg_pool3d(y,kernel_size=(stride,1,1),stride=(stride,1,1)).squeeze(1) 131 | center = F.avg_pool2d(y,(y.size(2),y.size(3))) 132 | y = y-center.expand_as(y) 133 | local_mask = self.local_mask(y) 134 | local_mask = F.softmax(T*local_mask) #using softmax mode 135 | 136 | lw = local_mask.chunk(6,1) 137 | x = x*6 138 | f0 = x*(lw[0].expand_as(x)) 139 | f1 = x*(lw[1].expand_as(x)) 140 | f2 = x*(lw[2].expand_as(x)) 141 | f3 = x*(lw[3].expand_as(x)) 142 | f4 = x*(lw[4].expand_as(x)) 143 | f5 = x*(lw[5].expand_as(x)) 144 | f0 = 
F.avg_pool2d(f0,kernel_size=(f0.size(2),f0.size(3))) 145 | f1 = F.avg_pool2d(f1,kernel_size=(f1.size(2),f1.size(3))) 146 | f2 = F.avg_pool2d(f2,kernel_size=(f2.size(2),f2.size(3))) 147 | f3 = F.avg_pool2d(f3,kernel_size=(f3.size(2),f3.size(3))) 148 | f4 = F.avg_pool2d(f4,kernel_size=(f4.size(2),f4.size(3))) 149 | f5 = F.avg_pool2d(f5,kernel_size=(f5.size(2),f5.size(3))) 150 | x = torch.cat((f0,f1,f2,f3,f4,f5),2) 151 | feat = torch.cat((f0,f1,f2,f3,f4,f5),2) 152 | 153 | out0 = feat/feat.norm(2,1).unsqueeze(1).expand_as(feat) 154 | 155 | x = self.drop(x) 156 | x = self.local_conv(x) 157 | 158 | out1 = x.view(x.size(0),-1) 159 | out1 = x/x.norm(2,1).unsqueeze(1).expand_as(x) 160 | 161 | x = self.feat_bn2d(x) 162 | out1 = x/x.norm(2,1).unsqueeze(1).expand_as(x) 163 | x = F.relu(x) # relu for local_conv feature 164 | x = x.chunk(6,2) 165 | x0 = x[0].contiguous().view(x[0].size(0),-1) 166 | x1 = x[1].contiguous().view(x[1].size(0),-1) 167 | x2 = x[2].contiguous().view(x[2].size(0),-1) 168 | x3 = x[3].contiguous().view(x[3].size(0),-1) 169 | x4 = x[4].contiguous().view(x[4].size(0),-1) 170 | x5 = x[5].contiguous().view(x[5].size(0),-1) 171 | c0 = self.instance0(x0) 172 | c1 = self.instance1(x1) 173 | c2 = self.instance2(x2) 174 | c3 = self.instance3(x3) 175 | c4 = self.instance4(x4) 176 | c5 = self.instance5(x5) 177 | return out0, (c0, c1, c2, c3, c4, c5), local_mask 178 | 179 | #==========================================================# 180 | 181 | 182 | x = F.avg_pool2d(x, x.size()[2:]) 183 | x = x.view(x.size(0), -1) 184 | out1 = x 185 | out1 = x / x.norm(2,1).unsqueeze(1).expand_as(x) 186 | if self.has_embedding: 187 | x = self.feat(x) 188 | x = self.feat_bn(x) 189 | out2 = x/ x.norm(2,1).unsqueeze(1).expand_as(x) 190 | if self.norm: 191 | x = x / x.norm(2, 1).unsqueeze(1).expand_as(x) 192 | if self.dropout > 0: 193 | x = self.drop(x) 194 | if self.num_classes > 0: 195 | x = self.classifier(x) 196 | 197 | 198 | return out2, x 199 | 200 | 201 | def reset_params(self): 
202 | for m in self.modules(): 203 | if isinstance(m, nn.Conv2d): 204 | init.kaiming_normal(m.weight, mode='fan_out') 205 | if m.bias is not None: 206 | init.constant(m.bias, 0) 207 | elif isinstance(m, nn.BatchNorm2d): 208 | init.constant(m.weight, 1) 209 | init.constant(m.bias, 0) 210 | elif isinstance(m, nn.Linear): 211 | init.normal(m.weight, std=0.001) 212 | if m.bias is not None: 213 | init.constant(m.bias, 0) 214 | 215 | 216 | 217 | 218 | 219 | 220 | def resnet50_rpp(**kwargs): 221 | return ResNet(50, **kwargs) 222 | -------------------------------------------------------------------------------- /reid/models/resnet_rpp.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/syfafterzy/PCB_RPP_for_reID/e29cf54486427d1423277d4c793e39ac0eeff87c/reid/models/resnet_rpp.pyc -------------------------------------------------------------------------------- /reid/trainers_partloss.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import time 3 | 4 | import torch 5 | from torch.autograd import Variable 6 | 7 | from .evaluation_metrics import accuracy 8 | from .utils.meters import AverageMeter 9 | from .utils import Bar 10 | from torch.nn import functional as F 11 | 12 | class BaseTrainer(object): 13 | def __init__(self, model, criterion, X, Y, SMLoss_mode=0): 14 | super(BaseTrainer, self).__init__() 15 | self.model = model 16 | self.criterion = criterion 17 | 18 | def train(self, epoch, data_loader, optimizer, print_freq=1): 19 | self.model.train() 20 | 21 | 22 | batch_time = AverageMeter() 23 | data_time = AverageMeter() 24 | losses = AverageMeter() 25 | precisions = AverageMeter() 26 | end = time.time() 27 | 28 | bar = Bar('Processing', max=len(data_loader)) 29 | for i, inputs in enumerate(data_loader): 30 | data_time.update(time.time() - end) 31 | 32 | inputs, targets = self._parse_data(inputs) 33 | loss0, 
loss1, loss2, loss3, loss4, loss5, prec1 = self._forward(inputs, targets) 34 | #=================================================================================== 35 | loss = (loss0+loss1+loss2+loss3+loss4+loss5)/6 36 | losses.update(loss.data[0], targets.size(0)) 37 | precisions.update(prec1, targets.size(0)) 38 | 39 | optimizer.zero_grad() 40 | torch.autograd.backward([ loss0, loss1, loss2, loss3, loss4, loss5],[torch.ones(1).cuda(), torch.ones(1).cuda(), torch.ones(1).cuda(),torch.ones(1).cuda(),torch.ones(1).cuda(),torch.ones(1).cuda(),torch.ones(1).cuda()]) 41 | optimizer.step() 42 | 43 | batch_time.update(time.time() - end) 44 | end = time.time() 45 | 46 | # plot progress 47 | bar.suffix = 'Epoch: [{N_epoch}][{N_batch}/{N_size}] | Time {N_bt:.3f} {N_bta:.3f} | Data {N_dt:.3f} {N_dta:.3f} | Loss {N_loss:.3f} {N_lossa:.3f} | Prec {N_prec:.2f} {N_preca:.2f}'.format( 48 | N_epoch=epoch, N_batch=i + 1, N_size=len(data_loader), 49 | N_bt=batch_time.val, N_bta=batch_time.avg, 50 | N_dt=data_time.val, N_dta=data_time.avg, 51 | N_loss=losses.val, N_lossa=losses.avg, 52 | N_prec=precisions.val, N_preca=precisions.avg, 53 | ) 54 | bar.next() 55 | bar.finish() 56 | 57 | 58 | 59 | def _parse_data(self, inputs): 60 | raise NotImplementedError 61 | 62 | def _forward(self, inputs, targets): 63 | raise NotImplementedError 64 | 65 | 66 | class Trainer(BaseTrainer): 67 | def _parse_data(self, inputs): 68 | imgs, _, pids, _ = inputs 69 | inputs = [Variable(imgs)] 70 | targets = Variable(pids.cuda()) 71 | return inputs, targets 72 | 73 | def _forward(self, inputs, targets): 74 | outputs = self.model(*inputs) 75 | index = (targets-751).data.nonzero().squeeze_() 76 | 77 | if isinstance(self.criterion, torch.nn.CrossEntropyLoss): 78 | loss0 = self.criterion(outputs[1][0],targets) 79 | loss1 = self.criterion(outputs[1][1],targets) 80 | loss2 = self.criterion(outputs[1][2],targets) 81 | loss3 = self.criterion(outputs[1][3],targets) 82 | loss4 = self.criterion(outputs[1][4],targets) 
83 | loss5 = self.criterion(outputs[1][5],targets) 84 | prec, = accuracy(outputs[1][2].data, targets.data) 85 | prec = prec[0] 86 | 87 | elif isinstance(self.criterion, OIMLoss): 88 | loss, outputs = self.criterion(outputs, targets) 89 | prec, = accuracy(outputs.data, targets.data) 90 | prec = prec[0] 91 | elif isinstance(self.criterion, TripletLoss): 92 | loss, prec = self.criterion(outputs, targets) 93 | else: 94 | raise ValueError("Unsupported loss:", self.criterion) 95 | return loss0, loss1, loss2, loss3, loss4, loss5, prec 96 | -------------------------------------------------------------------------------- /reid/trainers_partloss.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/syfafterzy/PCB_RPP_for_reID/e29cf54486427d1423277d4c793e39ac0eeff87c/reid/trainers_partloss.pyc -------------------------------------------------------------------------------- /reid/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | 5 | 6 | def to_numpy(tensor): 7 | if torch.is_tensor(tensor): 8 | return tensor.cpu().numpy() 9 | elif type(tensor).__module__ != 'numpy': 10 | raise ValueError("Cannot convert {} to numpy array" 11 | .format(type(tensor))) 12 | return tensor 13 | 14 | 15 | def to_torch(ndarray): 16 | if type(ndarray).__module__ == 'numpy': 17 | return torch.from_numpy(ndarray) 18 | elif not torch.is_tensor(ndarray): 19 | raise ValueError("Cannot convert {} to torch tensor" 20 | .format(type(ndarray))) 21 | return ndarray 22 | 23 | 24 | # progress bar 25 | import os, sys 26 | sys.path.append(os.path.join(os.path.dirname(__file__), "progress")) 27 | from progress.bar import Bar as Bar 28 | 29 | -------------------------------------------------------------------------------- /reid/utils/__init__.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/syfafterzy/PCB_RPP_for_reID/e29cf54486427d1423277d4c793e39ac0eeff87c/reid/utils/__init__.pyc -------------------------------------------------------------------------------- /reid/utils/data/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .dataset import Dataset 4 | from .preprocessor import Preprocessor 5 | -------------------------------------------------------------------------------- /reid/utils/data/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/syfafterzy/PCB_RPP_for_reID/e29cf54486427d1423277d4c793e39ac0eeff87c/reid/utils/data/__init__.pyc -------------------------------------------------------------------------------- /reid/utils/data/dataset.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os.path as osp 3 | 4 | import numpy as np 5 | 6 | from ..serialization import read_json 7 | 8 | 9 | def _pluck(identities, indices, relabel=False): 10 | ret = [] 11 | for index, pid in enumerate(indices): 12 | pid_images = identities[pid] 13 | for camid, cam_images in enumerate(pid_images): 14 | for fname in cam_images: 15 | name = osp.splitext(fname)[0] 16 | x, y, _ = map(int, name.split('_')) 17 | assert pid == x and camid == y 18 | if relabel: 19 | ret.append((fname, index, camid)) 20 | else: 21 | ret.append((fname, pid, camid)) 22 | return ret 23 | 24 | 25 | class Dataset(object): 26 | def __init__(self, root, split_id=0): 27 | self.root = root 28 | self.split_id = split_id 29 | self.meta = None 30 | self.split = None 31 | self.train, self.val, self.trainval = [], [], [] 32 | self.query, self.gallery = [], [] 33 | self.num_train_ids, self.num_val_ids, self.num_trainval_ids = 0, 0, 0 34 | 35 | @property 36 | def images_dir(self): 37 | return osp.join(self.root, 
'images') 38 | 39 | def load(self, num_val=0.3, verbose=True): 40 | splits = read_json(osp.join(self.root, 'splits.json')) 41 | if self.split_id >= len(splits): 42 | raise ValueError("split_id exceeds total splits {}" 43 | .format(len(splits))) 44 | self.split = splits[self.split_id] 45 | 46 | # Randomly split train / val 47 | trainval_pids = np.asarray(self.split['trainval']) 48 | np.random.shuffle(trainval_pids) 49 | num = len(trainval_pids) 50 | if isinstance(num_val, float): 51 | num_val = int(round(num * num_val)) 52 | if num_val >= num or num_val < 0: 53 | raise ValueError("num_val exceeds total identities {}" 54 | .format(num)) 55 | train_pids = sorted(trainval_pids[:-num_val]) 56 | val_pids = sorted(trainval_pids[-num_val:]) 57 | 58 | self.meta = read_json(osp.join(self.root, 'meta.json')) 59 | identities = self.meta['identities'] 60 | self.train = _pluck(identities, train_pids, relabel=True) 61 | self.val = _pluck(identities, val_pids, relabel=True) 62 | self.trainval = _pluck(identities, trainval_pids, relabel=True) 63 | self.query = _pluck(identities, self.split['query']) 64 | self.gallery = _pluck(identities, self.split['gallery']) 65 | self.num_train_ids = len(train_pids) 66 | self.num_val_ids = len(val_pids) 67 | self.num_trainval_ids = len(trainval_pids) 68 | 69 | if verbose: 70 | print(self.__class__.__name__, "dataset loaded") 71 | print(" subset | # ids | # images") 72 | print(" ---------------------------") 73 | print(" train | {:5d} | {:8d}" 74 | .format(self.num_train_ids, len(self.train))) 75 | print(" val | {:5d} | {:8d}" 76 | .format(self.num_val_ids, len(self.val))) 77 | print(" trainval | {:5d} | {:8d}" 78 | .format(self.num_trainval_ids, len(self.trainval))) 79 | print(" query | {:5d} | {:8d}" 80 | .format(len(self.split['query']), len(self.query))) 81 | print(" gallery | {:5d} | {:8d}" 82 | .format(len(self.split['gallery']), len(self.gallery))) 83 | 84 | def _check_integrity(self): 85 | return osp.isdir(osp.join(self.root, 'images')) 
and \ 86 | osp.isfile(osp.join(self.root, 'meta.json')) and \ 87 | osp.isfile(osp.join(self.root, 'splits.json')) 88 | -------------------------------------------------------------------------------- /reid/utils/data/dataset.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/syfafterzy/PCB_RPP_for_reID/e29cf54486427d1423277d4c793e39ac0eeff87c/reid/utils/data/dataset.pyc -------------------------------------------------------------------------------- /reid/utils/data/preprocessor.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import os.path as osp 3 | 4 | from PIL import Image 5 | 6 | 7 | class Preprocessor(object): 8 | def __init__(self, dataset, root=None, transform=None): 9 | super(Preprocessor, self).__init__() 10 | self.dataset = dataset 11 | self.root = root 12 | self.transform = transform 13 | 14 | def __len__(self): 15 | return len(self.dataset) 16 | 17 | def __getitem__(self, indices): 18 | if isinstance(indices, (tuple, list)): 19 | return [self._get_single_item(index) for index in indices] 20 | return self._get_single_item(indices) 21 | 22 | def _get_single_item(self, index): 23 | fname, pid, camid = self.dataset[index] 24 | fpath = fname 25 | if self.root is not None: 26 | fpath = osp.join(self.root, fname) 27 | img = Image.open(fpath).convert('RGB') 28 | if self.transform is not None: 29 | img = self.transform(img) 30 | return img, fname, pid, camid 31 | -------------------------------------------------------------------------------- /reid/utils/data/preprocessor.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/syfafterzy/PCB_RPP_for_reID/e29cf54486427d1423277d4c793e39ac0eeff87c/reid/utils/data/preprocessor.pyc -------------------------------------------------------------------------------- /reid/utils/data/sampler.py: 
-------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from collections import defaultdict 3 | 4 | import numpy as np 5 | import torch 6 | from torch.utils.data.sampler import ( 7 | Sampler, SequentialSampler, RandomSampler, SubsetRandomSampler, 8 | WeightedRandomSampler) 9 | 10 | 11 | class RandomIdentitySampler(Sampler): 12 | def __init__(self, data_source, num_instances=1): 13 | self.data_source = data_source 14 | self.num_instances = num_instances 15 | self.index_dic = defaultdict(list) 16 | for index, (_, pid, _) in enumerate(data_source): 17 | self.index_dic[pid].append(index) 18 | self.pids = list(self.index_dic.keys()) 19 | self.num_samples = len(self.pids) 20 | 21 | def __len__(self): 22 | return self.num_samples * self.num_instances 23 | 24 | def __iter__(self): 25 | indices = torch.randperm(self.num_samples) 26 | ret = [] 27 | for i in indices: 28 | pid = self.pids[i] 29 | t = self.index_dic[pid] 30 | if len(t) >= self.num_instances: 31 | t = np.random.choice(t, size=self.num_instances, replace=False) 32 | else: 33 | t = np.random.choice(t, size=self.num_instances, replace=True) 34 | ret.extend(t) 35 | return iter(ret) 36 | ''' 37 | def __iter__(self): 38 | indices = torch.randperm(self.num_samples-1) 39 | ret = [] 40 | u = self.index_dic[751] 41 | for i in indices: 42 | pid = self.pids[i] 43 | t = self.index_dic[pid] 44 | v = np.random.choice(u, size = self.num_instances, replace=False) 45 | if len(t) >= self.num_instances: 46 | t = np.random.choice(t, size=self.num_instances, replace=False) 47 | else: 48 | t = np.random.choice(t, size=self.num_instances, replace=True) 49 | ret.extend(t) 50 | ret.extend(v) 51 | return iter(ret) 52 | ''' 53 | -------------------------------------------------------------------------------- /reid/utils/data/sampler.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/syfafterzy/PCB_RPP_for_reID/e29cf54486427d1423277d4c793e39ac0eeff87c/reid/utils/data/sampler.pyc -------------------------------------------------------------------------------- /reid/utils/data/transforms.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from torchvision.transforms import * 4 | from PIL import Image 5 | import random 6 | import math 7 | 8 | 9 | class RectScale(object): 10 | def __init__(self, height, width, interpolation=Image.BILINEAR): 11 | self.height = height 12 | self.width = width 13 | self.interpolation = interpolation 14 | 15 | def __call__(self, img): 16 | w, h = img.size 17 | if h == self.height and w == self.width: 18 | return img 19 | return img.resize((self.width, self.height), self.interpolation) 20 | 21 | 22 | class RandomSizedRectCrop(object): 23 | def __init__(self, height, width, interpolation=Image.BILINEAR): 24 | self.height = height 25 | self.width = width 26 | self.interpolation = interpolation 27 | 28 | def __call__(self, img): 29 | for attempt in range(10): 30 | area = img.size[0] * img.size[1] 31 | target_area = random.uniform(0.64, 1.0) * area 32 | aspect_ratio = random.uniform(2, 3) 33 | 34 | h = int(round(math.sqrt(target_area * aspect_ratio))) 35 | w = int(round(math.sqrt(target_area / aspect_ratio))) 36 | 37 | if w <= img.size[0] and h <= img.size[1]: 38 | x1 = random.randint(0, img.size[0] - w) 39 | y1 = random.randint(0, img.size[1] - h) 40 | 41 | img = img.crop((x1, y1, x1 + w, y1 + h)) 42 | assert(img.size == (w, h)) 43 | 44 | return img.resize((self.width, self.height), self.interpolation) 45 | 46 | # Fallback 47 | scale = RectScale(self.height, self.width, 48 | interpolation=self.interpolation) 49 | return scale(img) 50 | -------------------------------------------------------------------------------- /reid/utils/data/transforms.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/syfafterzy/PCB_RPP_for_reID/e29cf54486427d1423277d4c793e39ac0eeff87c/reid/utils/data/transforms.pyc -------------------------------------------------------------------------------- /reid/utils/logging.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import os 3 | import sys 4 | 5 | from .osutils import mkdir_if_missing 6 | 7 | 8 | class Logger(object): 9 | def __init__(self, fpath=None): 10 | self.console = sys.stdout 11 | self.file = None 12 | if fpath is not None: 13 | mkdir_if_missing(os.path.dirname(fpath)) 14 | self.file = open(fpath, 'w') 15 | 16 | def __del__(self): 17 | self.close() 18 | 19 | def __enter__(self): 20 | pass 21 | 22 | def __exit__(self, *args): 23 | self.close() 24 | 25 | def write(self, msg): 26 | self.console.write(msg) 27 | if self.file is not None: 28 | self.file.write(msg) 29 | 30 | def flush(self): 31 | self.console.flush() 32 | if self.file is not None: 33 | self.file.flush() 34 | os.fsync(self.file.fileno()) 35 | 36 | def close(self): 37 | self.console.close() 38 | if self.file is not None: 39 | self.file.close() 40 | -------------------------------------------------------------------------------- /reid/utils/logging.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/syfafterzy/PCB_RPP_for_reID/e29cf54486427d1423277d4c793e39ac0eeff87c/reid/utils/logging.pyc -------------------------------------------------------------------------------- /reid/utils/meters.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | 4 | class AverageMeter(object): 5 | """Computes and stores the average and current value""" 6 | 7 | def __init__(self): 8 | self.val = 0 9 | self.avg = 0 10 | self.sum = 0 11 | self.count 
= 0 12 | 13 | def reset(self): 14 | self.val = 0 15 | self.avg = 0 16 | self.sum = 0 17 | self.count = 0 18 | 19 | def update(self, val, n=1): 20 | self.val = val 21 | self.sum += val * n 22 | self.count += n 23 | self.avg = self.sum / self.count 24 | -------------------------------------------------------------------------------- /reid/utils/meters.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/syfafterzy/PCB_RPP_for_reID/e29cf54486427d1423277d4c793e39ac0eeff87c/reid/utils/meters.pyc -------------------------------------------------------------------------------- /reid/utils/osutils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import os 3 | import errno 4 | 5 | 6 | def mkdir_if_missing(dir_path): 7 | try: 8 | os.makedirs(dir_path) 9 | except OSError as e: 10 | if e.errno != errno.EEXIST: 11 | raise 12 | -------------------------------------------------------------------------------- /reid/utils/osutils.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/syfafterzy/PCB_RPP_for_reID/e29cf54486427d1423277d4c793e39ac0eeff87c/reid/utils/osutils.pyc -------------------------------------------------------------------------------- /reid/utils/progress/LICENSE: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2012 Giorgos Verigakis 2 | # 3 | # Permission to use, copy, modify, and distribute this software for any 4 | # purpose with or without fee is hereby granted, provided that the above 5 | # copyright notice and this permission notice appear in all copies. 6 | # 7 | # THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES 8 | # WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 9 | # MERCHANTABILITY AND FITNESS. 
IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR 10 | # ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 11 | # WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN 12 | # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF 13 | # OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 14 | -------------------------------------------------------------------------------- /reid/utils/progress/MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.rst LICENSE 2 | -------------------------------------------------------------------------------- /reid/utils/progress/README.rst: -------------------------------------------------------------------------------- 1 | Easy progress reporting for Python 2 | ================================== 3 | 4 | |pypi| 5 | 6 | |demo| 7 | 8 | .. |pypi| image:: https://img.shields.io/pypi/v/progress.svg 9 | .. |demo| image:: https://raw.github.com/verigak/progress/master/demo.gif 10 | :alt: Demo 11 | 12 | Bars 13 | ---- 14 | 15 | There are 7 progress bars to choose from: 16 | 17 | - ``Bar`` 18 | - ``ChargingBar`` 19 | - ``FillingSquaresBar`` 20 | - ``FillingCirclesBar`` 21 | - ``IncrementalBar`` 22 | - ``PixelBar`` 23 | - ``ShadyBar`` 24 | 25 | To use them, just call ``next`` to advance and ``finish`` to finish: 26 | 27 | .. code-block:: python 28 | 29 | from progress.bar import Bar 30 | 31 | bar = Bar('Processing', max=20) 32 | for i in range(20): 33 | # Do some work 34 | bar.next() 35 | bar.finish() 36 | 37 | The result will be a bar like the following: :: 38 | 39 | Processing |############# | 42/100 40 | 41 | To simplify the common case where the work is done in an iterator, you can 42 | use the ``iter`` method: 43 | 44 | .. 
code-block:: python 45 | 46 | for i in Bar('Processing').iter(it): 47 | # Do some work 48 | 49 | Progress bars are very customizable; you can change their width, their fill 50 | character, their suffix and more: 51 | 52 | .. code-block:: python 53 | 54 | bar = Bar('Loading', fill='@', suffix='%(percent)d%%') 55 | 56 | This will produce a bar like the following: :: 57 | 58 | Loading |@@@@@@@@@@@@@ | 42% 59 | 60 | You can use a number of template arguments in ``message`` and ``suffix``: 61 | 62 | ========== ================================ 63 | Name Value 64 | ========== ================================ 65 | index current value 66 | max maximum value 67 | remaining max - index 68 | progress index / max 69 | percent progress * 100 70 | avg simple moving average time per item (in seconds) 71 | elapsed elapsed time in seconds 72 | elapsed_td elapsed as a timedelta (useful for printing as a string) 73 | eta avg * remaining 74 | eta_td eta as a timedelta (useful for printing as a string) 75 | ========== ================================ 76 | 77 | Instead of passing all configuration options on instantiation, you can create 78 | your custom subclass: 79 | 80 | .. code-block:: python 81 | 82 | class FancyBar(Bar): 83 | message = 'Loading' 84 | fill = '*' 85 | suffix = '%(percent).1f%% - %(eta)ds' 86 | 87 | You can also override any of the arguments or create your own: 88 | 89 | .. code-block:: python 90 | 91 | class SlowBar(Bar): 92 | suffix = '%(remaining_hours)d hours remaining' 93 | @property 94 | def remaining_hours(self): 95 | return self.eta // 3600 96 | 97 | 98 | Spinners 99 | ======== 100 | 101 | For actions with an unknown number of steps you can use a spinner: 102 | 103 | ..
code-block:: python 104 | 105 | from progress.spinner import Spinner 106 | 107 | spinner = Spinner('Loading ') 108 | while state != 'FINISHED': 109 | # Do some work 110 | spinner.next() 111 | 112 | There are 5 predefined spinners: 113 | 114 | - ``Spinner`` 115 | - ``PieSpinner`` 116 | - ``MoonSpinner`` 117 | - ``LineSpinner`` 118 | - ``PixelSpinner`` 119 | 120 | 121 | Other 122 | ===== 123 | 124 | There are a number of other classes available too, please check the source or 125 | subclass one of them to create your own. 126 | 127 | 128 | License 129 | ======= 130 | 131 | progress is licensed under ISC 132 | -------------------------------------------------------------------------------- /reid/utils/progress/progress/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2012 Giorgos Verigakis 2 | # 3 | # Permission to use, copy, modify, and distribute this software for any 4 | # purpose with or without fee is hereby granted, provided that the above 5 | # copyright notice and this permission notice appear in all copies. 6 | # 7 | # THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES 8 | # WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 9 | # MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR 10 | # ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 11 | # WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN 12 | # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF 13 | # OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 
14 | 15 | from __future__ import division 16 | 17 | from collections import deque 18 | from datetime import timedelta 19 | from math import ceil 20 | from sys import stderr 21 | from time import time 22 | 23 | 24 | __version__ = '1.3' 25 | 26 | 27 | class Infinite(object): 28 | file = stderr 29 | sma_window = 10 # Simple Moving Average window 30 | 31 | def __init__(self, *args, **kwargs): 32 | self.index = 0 33 | self.start_ts = time() 34 | self.avg = 0 35 | self._ts = self.start_ts 36 | self._xput = deque(maxlen=self.sma_window) 37 | for key, val in kwargs.items(): 38 | setattr(self, key, val) 39 | 40 | def __getitem__(self, key): 41 | if key.startswith('_'): 42 | return None 43 | return getattr(self, key, None) 44 | 45 | @property 46 | def elapsed(self): 47 | return int(time() - self.start_ts) 48 | 49 | @property 50 | def elapsed_td(self): 51 | return timedelta(seconds=self.elapsed) 52 | 53 | def update_avg(self, n, dt): 54 | if n > 0: 55 | self._xput.append(dt / n) 56 | self.avg = sum(self._xput) / len(self._xput) 57 | 58 | def update(self): 59 | pass 60 | 61 | def start(self): 62 | pass 63 | 64 | def finish(self): 65 | pass 66 | 67 | def next(self, n=1): 68 | now = time() 69 | dt = now - self._ts 70 | self.update_avg(n, dt) 71 | self._ts = now 72 | self.index = self.index + n 73 | self.update() 74 | 75 | def iter(self, it): 76 | try: 77 | for x in it: 78 | yield x 79 | self.next() 80 | finally: 81 | self.finish() 82 | 83 | 84 | class Progress(Infinite): 85 | def __init__(self, *args, **kwargs): 86 | super(Progress, self).__init__(*args, **kwargs) 87 | self.max = kwargs.get('max', 100) 88 | 89 | @property 90 | def eta(self): 91 | return int(ceil(self.avg * self.remaining)) 92 | 93 | @property 94 | def eta_td(self): 95 | return timedelta(seconds=self.eta) 96 | 97 | @property 98 | def percent(self): 99 | return self.progress * 100 100 | 101 | @property 102 | def progress(self): 103 | return min(1, self.index / self.max) 104 | 105 | @property 106 | def 
remaining(self): 107 | return max(self.max - self.index, 0) 108 | 109 | def start(self): 110 | self.update() 111 | 112 | def goto(self, index): 113 | incr = index - self.index 114 | self.next(incr) 115 | 116 | def iter(self, it): 117 | try: 118 | self.max = len(it) 119 | except TypeError: 120 | pass 121 | 122 | try: 123 | for x in it: 124 | yield x 125 | self.next() 126 | finally: 127 | self.finish() 128 | -------------------------------------------------------------------------------- /reid/utils/progress/progress/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/syfafterzy/PCB_RPP_for_reID/e29cf54486427d1423277d4c793e39ac0eeff87c/reid/utils/progress/progress/__init__.pyc -------------------------------------------------------------------------------- /reid/utils/progress/progress/bar.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Copyright (c) 2012 Giorgos Verigakis 4 | # 5 | # Permission to use, copy, modify, and distribute this software for any 6 | # purpose with or without fee is hereby granted, provided that the above 7 | # copyright notice and this permission notice appear in all copies. 8 | # 9 | # THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES 10 | # WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 11 | # MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR 12 | # ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 13 | # WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN 14 | # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF 15 | # OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 16 | 17 | from __future__ import unicode_literals 18 | from . 
import Progress 19 | from .helpers import WritelnMixin 20 | 21 | 22 | class Bar(WritelnMixin, Progress): 23 | width = 32 24 | message = '' 25 | suffix = '%(index)d/%(max)d' 26 | bar_prefix = ' |' 27 | bar_suffix = '| ' 28 | empty_fill = ' ' 29 | fill = '#' 30 | hide_cursor = True 31 | 32 | def update(self): 33 | filled_length = int(self.width * self.progress) 34 | empty_length = self.width - filled_length 35 | 36 | message = self.message % self 37 | bar = self.fill * filled_length 38 | empty = self.empty_fill * empty_length 39 | suffix = self.suffix % self 40 | line = ''.join([message, self.bar_prefix, bar, empty, self.bar_suffix, 41 | suffix]) 42 | self.writeln(line) 43 | 44 | 45 | class ChargingBar(Bar): 46 | suffix = '%(percent)d%%' 47 | bar_prefix = ' ' 48 | bar_suffix = ' ' 49 | empty_fill = '∙' 50 | fill = '█' 51 | 52 | 53 | class FillingSquaresBar(ChargingBar): 54 | empty_fill = '▢' 55 | fill = '▣' 56 | 57 | 58 | class FillingCirclesBar(ChargingBar): 59 | empty_fill = '◯' 60 | fill = '◉' 61 | 62 | 63 | class IncrementalBar(Bar): 64 | phases = (' ', '▏', '▎', '▍', '▌', '▋', '▊', '▉', '█') 65 | 66 | def update(self): 67 | nphases = len(self.phases) 68 | filled_len = self.width * self.progress 69 | nfull = int(filled_len) # Number of full chars 70 | phase = int((filled_len - nfull) * nphases) # Phase of last char 71 | nempty = self.width - nfull # Number of empty chars 72 | 73 | message = self.message % self 74 | bar = self.phases[-1] * nfull 75 | current = self.phases[phase] if phase > 0 else '' 76 | empty = self.empty_fill * max(0, nempty - len(current)) 77 | suffix = self.suffix % self 78 | line = ''.join([message, self.bar_prefix, bar, current, empty, 79 | self.bar_suffix, suffix]) 80 | self.writeln(line) 81 | 82 | 83 | class PixelBar(IncrementalBar): 84 | phases = ('⡀', '⡄', '⡆', '⡇', '⣇', '⣧', '⣷', '⣿') 85 | 86 | 87 | class ShadyBar(IncrementalBar): 88 | phases = (' ', '░', '▒', '▓', '█') 89 | 
# ---- /reid/utils/progress/progress/bar.pyc (compiled artifact) ----
# https://raw.githubusercontent.com/syfafterzy/PCB_RPP_for_reID/e29cf54486427d1423277d4c793e39ac0eeff87c/reid/utils/progress/progress/bar.pyc

# ---- /reid/utils/progress/progress/counter.py ----
# -*- coding: utf-8 -*-

# Copyright (c) 2012 Giorgos Verigakis
#
# Permission to use, copy, modify, and distribute this software for any
# purpose with or without fee is hereby granted, provided that the above
# copyright notice and this permission notice appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

from __future__ import unicode_literals
from . import Infinite, Progress
from .helpers import WriteMixin


class Counter(WriteMixin, Infinite):
    """Display the number of steps completed so far (no upper bound)."""

    message = ''
    hide_cursor = True

    def update(self):
        done = self.index
        self.write(str(done))


class Countdown(WriteMixin, Progress):
    """Display the number of steps left before ``max`` is reached."""

    hide_cursor = True

    def update(self):
        left = self.remaining
        self.write(str(left))


class Stack(WriteMixin, Progress):
    """Display a single glyph from ``phases`` chosen by the fraction completed."""

    phases = (' ', '▁', '▂', '▃', '▄', '▅', '▆', '▇', '█')
    hide_cursor = True

    def update(self):
        glyphs = self.phases
        slot = int(self.progress * len(glyphs))
        # Full progress would index one past the end; clamp to the last glyph.
        self.write(glyphs[min(slot, len(glyphs) - 1)])


class Pie(Stack):
    """A ``Stack`` drawn as a filling circle."""

    phases = ('○', '◔', '◑', '◕', '●')

# ---- /reid/utils/progress/progress/helpers.py ----
# Copyright (c) 2012 Giorgos Verigakis
#
# Permission to use, copy, modify, and distribute this software for any
# purpose with or without fee is hereby granted, provided that the above
# copyright notice and this permission notice appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
14 | 15 | from __future__ import print_function 16 | 17 | 18 | HIDE_CURSOR = '\x1b[?25l' 19 | SHOW_CURSOR = '\x1b[?25h' 20 | 21 | 22 | class WriteMixin(object): 23 | hide_cursor = False 24 | 25 | def __init__(self, message=None, **kwargs): 26 | super(WriteMixin, self).__init__(**kwargs) 27 | self._width = 0 28 | if message: 29 | self.message = message 30 | 31 | if self.file.isatty(): 32 | if self.hide_cursor: 33 | print(HIDE_CURSOR, end='', file=self.file) 34 | print(self.message, end='', file=self.file) 35 | self.file.flush() 36 | 37 | def write(self, s): 38 | if self.file.isatty(): 39 | b = '\b' * self._width 40 | c = s.ljust(self._width) 41 | print(b + c, end='', file=self.file) 42 | self._width = max(self._width, len(s)) 43 | self.file.flush() 44 | 45 | def finish(self): 46 | if self.file.isatty() and self.hide_cursor: 47 | print(SHOW_CURSOR, end='', file=self.file) 48 | 49 | 50 | class WritelnMixin(object): 51 | hide_cursor = False 52 | 53 | def __init__(self, message=None, **kwargs): 54 | super(WritelnMixin, self).__init__(**kwargs) 55 | if message: 56 | self.message = message 57 | 58 | if self.file.isatty() and self.hide_cursor: 59 | print(HIDE_CURSOR, end='', file=self.file) 60 | 61 | def clearln(self): 62 | if self.file.isatty(): 63 | print('\r\x1b[K', end='', file=self.file) 64 | 65 | def writeln(self, line): 66 | if self.file.isatty(): 67 | self.clearln() 68 | print(line, end='', file=self.file) 69 | self.file.flush() 70 | 71 | def finish(self): 72 | if self.file.isatty(): 73 | print(file=self.file) 74 | if self.hide_cursor: 75 | print(SHOW_CURSOR, end='', file=self.file) 76 | 77 | 78 | from signal import signal, SIGINT 79 | from sys import exit 80 | 81 | 82 | class SigIntMixin(object): 83 | """Registers a signal handler that calls finish on SIGINT""" 84 | 85 | def __init__(self, *args, **kwargs): 86 | super(SigIntMixin, self).__init__(*args, **kwargs) 87 | signal(SIGINT, self._sigint_handler) 88 | 89 | def _sigint_handler(self, signum, frame): 90 | 
self.finish() 91 | exit(0) 92 | -------------------------------------------------------------------------------- /reid/utils/progress/progress/helpers.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/syfafterzy/PCB_RPP_for_reID/e29cf54486427d1423277d4c793e39ac0eeff87c/reid/utils/progress/progress/helpers.pyc -------------------------------------------------------------------------------- /reid/utils/progress/progress/spinner.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Copyright (c) 2012 Giorgos Verigakis 4 | # 5 | # Permission to use, copy, modify, and distribute this software for any 6 | # purpose with or without fee is hereby granted, provided that the above 7 | # copyright notice and this permission notice appear in all copies. 8 | # 9 | # THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES 10 | # WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 11 | # MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR 12 | # ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 13 | # WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN 14 | # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF 15 | # OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 16 | 17 | from __future__ import unicode_literals 18 | from . 
import Infinite 19 | from .helpers import WriteMixin 20 | 21 | 22 | class Spinner(WriteMixin, Infinite): 23 | message = '' 24 | phases = ('-', '\\', '|', '/') 25 | hide_cursor = True 26 | 27 | def update(self): 28 | i = self.index % len(self.phases) 29 | self.write(self.phases[i]) 30 | 31 | 32 | class PieSpinner(Spinner): 33 | phases = ['◷', '◶', '◵', '◴'] 34 | 35 | 36 | class MoonSpinner(Spinner): 37 | phases = ['◑', '◒', '◐', '◓'] 38 | 39 | 40 | class LineSpinner(Spinner): 41 | phases = ['⎺', '⎻', '⎼', '⎽', '⎼', '⎻'] 42 | 43 | class PixelSpinner(Spinner): 44 | phases = ['⣾','⣷', '⣯', '⣟', '⡿', '⢿', '⣻', '⣽'] 45 | -------------------------------------------------------------------------------- /reid/utils/progress/setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from setuptools import setup 4 | 5 | import progress 6 | 7 | 8 | setup( 9 | name='progress', 10 | version=progress.__version__, 11 | description='Easy to use progress bars', 12 | long_description=open('README.rst').read(), 13 | author='Giorgos Verigakis', 14 | author_email='verigak@gmail.com', 15 | url='http://github.com/verigak/progress/', 16 | license='ISC', 17 | packages=['progress'], 18 | classifiers=[ 19 | 'Environment :: Console', 20 | 'Intended Audience :: Developers', 21 | 'License :: OSI Approved :: ISC License (ISCL)', 22 | 'Programming Language :: Python :: 2.6', 23 | 'Programming Language :: Python :: 2.7', 24 | 'Programming Language :: Python :: 3.3', 25 | 'Programming Language :: Python :: 3.4', 26 | 'Programming Language :: Python :: 3.5', 27 | 'Programming Language :: Python :: 3.6', 28 | ] 29 | ) 30 | -------------------------------------------------------------------------------- /reid/utils/progress/test_progress.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from __future__ import print_function 4 | 5 | import random 6 | import time 7 | 8 | 
from progress.bar import (Bar, ChargingBar, FillingSquaresBar, 9 | FillingCirclesBar, IncrementalBar, PixelBar, 10 | ShadyBar) 11 | from progress.spinner import (Spinner, PieSpinner, MoonSpinner, LineSpinner, 12 | PixelSpinner) 13 | from progress.counter import Counter, Countdown, Stack, Pie 14 | 15 | 16 | def sleep(): 17 | t = 0.01 18 | t += t * random.uniform(-0.1, 0.1) # Add some variance 19 | time.sleep(t) 20 | 21 | 22 | for bar_cls in (Bar, ChargingBar, FillingSquaresBar, FillingCirclesBar): 23 | suffix = '%(index)d/%(max)d [%(elapsed)d / %(eta)d / %(eta_td)s]' 24 | bar = bar_cls(bar_cls.__name__, suffix=suffix) 25 | for i in bar.iter(range(200)): 26 | sleep() 27 | 28 | for bar_cls in (IncrementalBar, PixelBar, ShadyBar): 29 | suffix = '%(percent)d%% [%(elapsed_td)s / %(eta)d / %(eta_td)s]' 30 | bar = bar_cls(bar_cls.__name__, suffix=suffix) 31 | for i in bar.iter(range(200)): 32 | sleep() 33 | 34 | for spin in (Spinner, PieSpinner, MoonSpinner, LineSpinner, PixelSpinner): 35 | for i in spin(spin.__name__ + ' ').iter(range(100)): 36 | sleep() 37 | print() 38 | 39 | for singleton in (Counter, Countdown, Stack, Pie): 40 | for i in singleton(singleton.__name__ + ' ').iter(range(100)): 41 | sleep() 42 | print() 43 | 44 | bar = IncrementalBar('Random', suffix='%(index)d') 45 | for i in range(100): 46 | bar.goto(random.randint(0, 100)) 47 | sleep() 48 | bar.finish() 49 | -------------------------------------------------------------------------------- /reid/utils/serialization.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import json 3 | import os.path as osp 4 | import shutil 5 | 6 | import torch 7 | from torch.nn import Parameter 8 | 9 | from .osutils import mkdir_if_missing 10 | 11 | 12 | def read_json(fpath): 13 | with open(fpath, 'r') as f: 14 | obj = json.load(f) 15 | return obj 16 | 17 | 18 | def write_json(obj, fpath): 19 | 
mkdir_if_missing(osp.dirname(fpath)) 20 | with open(fpath, 'w') as f: 21 | json.dump(obj, f, indent=4, separators=(',', ': ')) 22 | 23 | 24 | def save_checkpoint(state, is_best, fpath='checkpoint.pth.tar'): 25 | mkdir_if_missing(osp.dirname(fpath)) 26 | torch.save(state, fpath) 27 | if is_best: 28 | shutil.copy(fpath, osp.join(osp.dirname(fpath), 'model_best.pth.tar')) 29 | 30 | 31 | def load_checkpoint(fpath): 32 | if osp.isfile(fpath): 33 | checkpoint = torch.load(fpath) 34 | print("=> Loaded checkpoint '{}'".format(fpath)) 35 | return checkpoint 36 | else: 37 | raise ValueError("=> No checkpoint found at '{}'".format(fpath)) 38 | 39 | 40 | def copy_state_dict(state_dict, model, strip=None): 41 | tgt_state = model.state_dict() 42 | copied_names = set() 43 | for name, param in state_dict.items(): 44 | if strip is not None and name.startswith(strip): 45 | name = name[len(strip):] 46 | if name not in tgt_state: 47 | continue 48 | if isinstance(param, Parameter): 49 | param = param.data 50 | if param.size() != tgt_state[name].size(): 51 | print('mismatch:', name, param.size(), tgt_state[name].size()) 52 | continue 53 | tgt_state[name].copy_(param) 54 | copied_names.add(name) 55 | 56 | missing = set(tgt_state.keys()) - copied_names 57 | if len(missing) > 0: 58 | print("missing keys in state_dict:", missing) 59 | 60 | return model 61 | -------------------------------------------------------------------------------- /reid/utils/serialization.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/syfafterzy/PCB_RPP_for_reID/e29cf54486427d1423277d4c793e39ac0eeff87c/reid/utils/serialization.pyc -------------------------------------------------------------------------------- /train_PCB.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=2,3 python PCB.py -d market -a resnet50 -b 64 -j 4 --epochs 60 --log logs/market-1501/PCB/ --combine-trainval 
--feature 256 --height 384 --width 128 --step-size 40 --data-dir ~/datasets/Market-1501/ 2 | -------------------------------------------------------------------------------- /train_RPP.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=2,3 python PCB.py -d market -a resnet50 -b 64 -j 4 --log logs/market-1501/PCB_20epoch/ --feature 256 --height 384 --width 128 --epochs 20 --step-size 20 --data-dir ~/datasets/Market-1501/ 2 | 3 | 4 | CUDA_VISIBLE_DEVICES=2,3 python RPP.py -d market -a resnet50_rpp -b 64 -j 4 --log logs/market-1501/RPP/ --feature 256 --height 384 --width 128 --epochs 50 --step_size 20 --data-dir ~/datasets/Market-1501/ --resume logs/market-1501/PCB_20epoch/checkpoint.pth.tar 5 | --------------------------------------------------------------------------------