├── .gitignore ├── IDE.py ├── PCB.py ├── README.md ├── ZJU.py ├── reid ├── __init__.py ├── camstyle_trainer.py ├── datasets │ ├── __init__.py │ ├── ai_city.py │ ├── dukemtmc.py │ ├── market1501.py │ ├── veri.py │ └── veri_affinity │ │ └── affinity_matrix.txt ├── evaluation_metrics │ ├── __init__.py │ ├── classification.py │ └── ranking.py ├── evaluators.py ├── feature_extraction │ ├── __init__.py │ ├── cnn.py │ └── database.py ├── loss │ ├── __init__.py │ ├── label_smooth.py │ └── triplet.py ├── metric │ ├── MLP_model.py │ ├── __init__.py │ ├── metric_evaluate.py │ ├── metric_trainer.py │ └── reid_feat_dataset.py ├── models │ ├── IDE_model.py │ ├── PCB_model.py │ └── __init__.py ├── prepare │ ├── __init__.py │ ├── add_aic_gps.py │ ├── affinity_matrix.py │ ├── ensemble.py │ ├── extract_bbox.py │ └── label_det_dataset.py ├── trainers.py └── utils │ ├── __init__.py │ ├── data │ ├── __init__.py │ ├── dataset.py │ ├── og_sampler.py │ ├── preprocessor.py │ ├── transforms.py │ └── zju_sampler.py │ ├── draw_curve.py │ ├── get_loaders.py │ ├── logger.py │ ├── meters.py │ └── serialization.py ├── reid_metric.py ├── requirements.txt ├── save_cnn_feature.py └── triplet.py /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | 3 | *~ 4 | 5 | # temporary files which can be created if a process still has a handle open of a deleted file 6 | .fuse_hidden* 7 | 8 | # KDE directory preferences 9 | .directory 10 | 11 | # Linux trash folder which might appear on any partition or disk 12 | .Trash-* 13 | 14 | # .nfs files are created when an open file is removed but is still being accessed 15 | .nfs* 16 | 17 | 18 | *.DS_Store 19 | .AppleDouble 20 | .LSOverride 21 | 22 | # Icon must end with two \r 23 | Icon 24 | 25 | 26 | # Thumbnails 27 | ._* 28 | 29 | # Files that might appear in the root of a volume 30 | .DocumentRevisions-V100 31 | .fseventsd 32 | .Spotlight-V100 33 | .TemporaryItems 34 | .Trashes 35 | .VolumeIcon.icns 36 | .com.apple.timemachine.donotpresent 37 | 38 | # Directories potentially created on remote AFP share 39 | .AppleDB 40 | .AppleDesktop 41 | Network Trash Folder 42 | Temporary Items 43 | .apdisk 44 | 45 | 46 | # swap 47 | [._]*.s[a-v][a-z] 48 | [._]*.sw[a-p] 49 | [._]s[a-v][a-z] 50 | [._]sw[a-p] 51 | # session 52 | Session.vim 53 | # temporary 54 | .netrwhist 55 | *~ 56 | # auto-generated tag files 57 | tags 58 | 59 | 60 | # cache files for sublime text 61 | *.tmlanguage.cache 62 | *.tmPreferences.cache 63 | *.stTheme.cache 64 | 65 | # workspace files are user-specific 66 | *.sublime-workspace 67 | 68 | # project files should be checked into the repository, unless a significant 69 | # proportion of contributors will probably not be using SublimeText 70 | # *.sublime-project 71 | 72 | # sftp configuration file 73 | sftp-config.json 74 | 75 | # Package control specific files 76 | Package Control.last-run 77 | Package Control.ca-list 78 | Package Control.ca-bundle 79 | Package Control.system-ca-bundle 80 | Package Control.cache/ 81 | Package Control.ca-certs/ 82 | Package Control.merged-ca-bundle 83 | Package Control.user-ca-bundle 84 | oscrypto-ca-bundle.crt 85 | bh_unicode_properties.cache 86 | 87 | # Sublime-github package stores a github token in this file 88 | # https://packagecontrol.io/packages/sublime-github 89 | GitHub.sublime-settings 90 | 91 | 92 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and Webstorm 93 | # Reference: 
https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 94 | 95 | # User-specific stuff: 96 | .idea 97 | .idea/**/workspace.xml 98 | .idea/**/tasks.xml 99 | 100 | # Sensitive or high-churn files: 101 | .idea/**/dataSources/ 102 | .idea/**/dataSources.ids 103 | .idea/**/dataSources.xml 104 | .idea/**/dataSources.local.xml 105 | .idea/**/sqlDataSources.xml 106 | .idea/**/dynamic.xml 107 | .idea/**/uiDesigner.xml 108 | 109 | # Gradle: 110 | .idea/**/gradle.xml 111 | .idea/**/libraries 112 | 113 | # Mongo Explorer plugin: 114 | .idea/**/mongoSettings.xml 115 | 116 | ## File-based project format: 117 | *.iws 118 | 119 | ## Plugin-specific files: 120 | 121 | # IntelliJ 122 | /out/ 123 | 124 | # mpeltonen/sbt-idea plugin 125 | .idea_modules/ 126 | 127 | # JIRA plugin 128 | atlassian-ide-plugin.xml 129 | 130 | # Crashlytics plugin (for Android Studio and IntelliJ) 131 | com_crashlytics_export_strings.xml 132 | crashlytics.properties 133 | crashlytics-build.properties 134 | fabric.properties 135 | 136 | 137 | # Byte-compiled / optimized / DLL files 138 | __pycache__/ 139 | *.py[cod] 140 | *$py.class 141 | 142 | # C extensions 143 | *.so 144 | 145 | # Distribution / packaging 146 | .Python 147 | env/ 148 | build/ 149 | develop-eggs/ 150 | dist/ 151 | downloads/ 152 | eggs/ 153 | .eggs/ 154 | lib/ 155 | lib64/ 156 | parts/ 157 | sdist/ 158 | var/ 159 | *.egg-info/ 160 | .installed.cfg 161 | *.egg 162 | 163 | # PyInstaller 164 | # Usually these files are written by a python script from a template 165 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 166 | *.manifest 167 | *.spec 168 | 169 | # Installer logs 170 | pip-log.txt 171 | pip-delete-this-directory.txt 172 | 173 | # Unit test / coverage reports 174 | htmlcov/ 175 | .tox/ 176 | .coverage 177 | .coverage.* 178 | .cache 179 | nosetests.xml 180 | coverage.xml 181 | *,cover 182 | .hypothesis/ 183 | 184 | # Translations 185 | *.mo 186 | *.pot 187 | 188 | # Django stuff: 189 | *.log 190 | local_settings.py 191 | 192 | # Flask stuff: 193 | instance/ 194 | .webassets-cache 195 | 196 | # Scrapy stuff: 197 | .scrapy 198 | 199 | # Sphinx documentation 200 | docs/_build/ 201 | 202 | # PyBuilder 203 | target/ 204 | 205 | # IPython Notebook 206 | .ipynb_checkpoints 207 | 208 | # pyenv 209 | .python-version 210 | 211 | # celery beat schedule file 212 | celerybeat-schedule 213 | 214 | # dotenv 215 | .env 216 | 217 | # virtualenv 218 | venv/ 219 | ENV/ 220 | 221 | # Spyder project settings 222 | .spyderproject 223 | 224 | # Rope project settings 225 | .ropeproject 226 | 227 | 228 | # Project specific 229 | examples/data 230 | examples/logs 231 | 232 | logs/ 233 | IDE_sigma.pt 234 | *.lnk 235 | *.bak 236 | reid/models/PCB_model - Copy.py 237 | ~$pcb_rpp_result.xlsx 238 | *.h5 239 | -------------------------------------------------------------------------------- /IDE.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.environ['OMP_NUM_THREADS'] = '1' 4 | import argparse 5 | import datetime 6 | import sys 7 | import shutil 8 | from distutils.dir_util import copy_tree 9 | import time 10 | import numpy as np 11 | import torch 12 | from reid import models 13 | from reid.camstyle_trainer import CamStyleTrainer 14 | from reid.evaluators import Evaluator 15 | from reid.loss import * 16 | from reid.trainers import Trainer 17 | from reid.utils.logger import Logger 18 | from reid.utils.draw_curve import * 19 | from reid.utils.get_loaders import * 20 | from 
reid.utils.serialization import save_checkpoint 21 | 22 | ''' 23 | no crop for duke_tracking by default check 24 | RE check 25 | ''' 26 | 27 | 28 | def main(args): 29 | # seed 30 | if args.seed is not None: 31 | np.random.seed(args.seed) 32 | torch.manual_seed(args.seed) 33 | torch.backends.cudnn.deterministic = True 34 | torch.backends.cudnn.benchmark = False 35 | else: 36 | torch.backends.cudnn.benchmark = True 37 | 38 | if args.logs_dir is None: 39 | args.logs_dir = osp.join(f'logs/ide/{args.dataset}', datetime.datetime.today().strftime('%Y-%m-%d_%H-%M-%S')) 40 | else: 41 | args.logs_dir = osp.join(f'logs/ide/{args.dataset}', args.logs_dir) 42 | if args.train: 43 | os.makedirs(args.logs_dir, exist_ok=True) 44 | copy_tree('./reid', args.logs_dir + '/scripts/reid') 45 | for script in os.listdir('.'): 46 | if script.split('.')[-1] == 'py': 47 | dst_file = os.path.join(args.logs_dir, 'scripts', os.path.basename(script)) 48 | shutil.copyfile(script, dst_file) 49 | sys.stdout = Logger(os.path.join(args.logs_dir, 'log.txt'), ) 50 | print('Settings:') 51 | print(vars(args)) 52 | print('\n') 53 | 54 | # Create data loaders 55 | dataset, num_classes, train_loader, query_loader, gallery_loader, camstyle_loader = \ 56 | get_data(args.dataset, args.data_dir, args.height, args.width, args.batch_size, args.num_workers, 57 | args.combine_trainval, args.crop, args.tracking_icams, args.tracking_fps, args.re, 0, args.camstyle) 58 | 59 | # Create model 60 | model = models.create('ide', feature_dim=args.feature_dim, num_classes=num_classes, norm=args.norm, 61 | dropout=args.dropout, last_stride=args.last_stride, arch=args.arch) 62 | 63 | # Load from checkpoint 64 | start_epoch = best_top1 = 0 65 | if args.resume: 66 | resume_fname = osp.join(f'logs/ide/{args.dataset}', args.resume, 'model_best.pth.tar') 67 | model, start_epoch, best_top1 = checkpoint_loader(model, resume_fname) 68 | print("=> Last epoch {} best top1 {:.1%}".format(start_epoch, best_top1)) 69 | start_epoch += 1 70 | model = nn.DataParallel(model).cuda() 71 | 72 | # Criterion 73 | criterion = nn.CrossEntropyLoss().cuda() if not args.LSR else LSR_loss().cuda() 74 | 75 | # Optimizer 76 | if hasattr(model.module, 'base'): # low learning_rate the base network (aka. 
ResNet-50) 77 | base_param_ids = set(map(id, model.module.base.parameters())) 78 | new_params = [p for p in model.parameters() if id(p) not in base_param_ids] 79 | param_groups = [{'params': model.module.base.parameters(), 'lr_mult': 0.1}, 80 | {'params': new_params, 'lr_mult': 1.0}] 81 | else: 82 | param_groups = model.parameters() 83 | optimizer = torch.optim.SGD(param_groups, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, 84 | nesterov=True) 85 | 86 | # Trainer 87 | if args.camstyle == 0: 88 | trainer = Trainer(model, criterion) 89 | else: 90 | trainer = CamStyleTrainer(model, criterion, camstyle_loader) 91 | 92 | # Evaluator 93 | evaluator = Evaluator(model) 94 | 95 | if args.train: 96 | # Schedule learning rate 97 | def adjust_lr(epoch): 98 | step_size = args.step_size 99 | lr = args.lr * (0.1 ** (epoch // step_size)) 100 | for g in optimizer.param_groups: 101 | g['lr'] = lr * g.get('lr_mult', 1) 102 | 103 | # Draw Curve 104 | epoch_s = [] 105 | loss_s = [] 106 | prec_s = [] 107 | eval_epoch_s = [] 108 | eval_top1_s = [] 109 | 110 | # Start training 111 | for epoch in range(start_epoch + 1, args.epochs + 1): 112 | t0 = time.time() 113 | adjust_lr(epoch) 114 | # train_loss, train_prec = 0, 0 115 | train_loss, train_prec = trainer.train(epoch, train_loader, optimizer, fix_bn=args.fix_bn) 116 | 117 | if epoch < args.start_save: 118 | continue 119 | 120 | if epoch % 5 == 0: 121 | top1 = evaluator.evaluate(query_loader, gallery_loader, dataset.query, dataset.gallery) 122 | eval_epoch_s.append(epoch) 123 | eval_top1_s.append(top1) 124 | else: 125 | top1 = 0 126 | 127 | is_best = top1 >= best_top1 128 | best_top1 = max(top1, best_top1) 129 | save_checkpoint({ 130 | 'state_dict': model.module.state_dict(), 131 | 'epoch': epoch, 132 | 'best_top1': best_top1, 133 | }, is_best, fpath=osp.join(args.logs_dir, 'checkpoint.pth.tar')) 134 | epoch_s.append(epoch) 135 | loss_s.append(train_loss) 136 | prec_s.append(train_prec) 137 | draw_curve(os.path.join(args.logs_dir, 'train_curve.jpg'), epoch_s, loss_s, prec_s, 138 | eval_epoch_s, None, eval_top1_s) 139 | 140 | t1 = time.time() 141 | t_epoch = t1 - t0 142 | print('\n * Finished epoch {:3d} top1: {:5.1%} best_eval: {:5.1%} {}\n'. 
143 | format(epoch, top1, best_top1, ' *' if is_best else '')) 144 | print('*************** Epoch takes time: {:^10.2f} *********************\n'.format(t_epoch)) 145 | pass 146 | 147 | # Final test 148 | print('Test with best model:') 149 | model, start_epoch, best_top1 = checkpoint_loader(model, osp.join(args.logs_dir, 'model_best.pth.tar')) 150 | print("=> Start epoch {} best top1 {:.1%}".format(start_epoch, best_top1)) 151 | 152 | evaluator.evaluate(query_loader, gallery_loader, dataset.query, dataset.gallery) 153 | else: 154 | print("Test:") 155 | evaluator.evaluate(query_loader, gallery_loader, dataset.query, dataset.gallery) 156 | pass 157 | 158 | 159 | if __name__ == '__main__': 160 | parser = argparse.ArgumentParser(description="Softmax loss classification") 161 | # data 162 | parser.add_argument('-d', '--dataset', type=str, default='market1501', choices=datasets.names()) 163 | parser.add_argument('-b', '--batch-size', type=int, default=64, help="batch size") 164 | parser.add_argument('-j', '--num-workers', type=int, default=4) 165 | parser.add_argument('--height', type=int, default=256, help="input height, default: 256 for resnet*") 166 | parser.add_argument('--width', type=int, default=128, help="input width, default: 128 for resnet*") 167 | parser.add_argument('--combine-trainval', action='store_true', 168 | help="train and val sets together for training, val set alone for validation") 169 | parser.add_argument('--tracking_icams', type=int, default=0, help="specify if train on single iCam") 170 | parser.add_argument('--tracking_fps', type=int, default=1, help="specify if train on single iCam") 171 | parser.add_argument('--re', type=float, default=0, help="random erasing") 172 | parser.add_argument('--crop', type=bool, default=1, help="resize then crop, default: True") 173 | # model 174 | parser.add_argument('--feature_dim', type=int, default=256) 175 | parser.add_argument('--dropout', type=float, default=0.5) 176 | parser.add_argument('-s', '--last_stride', type=int, default=2, choices=[1, 2]) 177 | parser.add_argument('--norm', action='store_true', help="normalize feat, default: False") 178 | parser.add_argument('--arch', type=str, default='resnet50', choices=['resnet50', 'densenet121'], 179 | help='architecture for base network') 180 | # optimizer 181 | parser.add_argument('--lr', type=float, default=0.1, 182 | help="learning rate of new parameters, for pretrained " 183 | "parameters it is 10 times smaller than this") 184 | parser.add_argument('--momentum', type=float, default=0.9) 185 | parser.add_argument('--weight-decay', type=float, default=5e-4) 186 | parser.add_argument('--LSR', action='store_true', help="use label smooth loss") 187 | # training configs 188 | parser.add_argument('--train', action='store_true', help="train IDE model from start") 189 | parser.add_argument('--fix_bn', type=bool, default=0, help="fix (skip training) BN in base network") 190 | parser.add_argument('--resume', type=str, default=None, metavar='PATH') 191 | parser.add_argument('--epochs', type=int, default=60) 192 | parser.add_argument('--step-size', type=int, default=40) 193 | parser.add_argument('--start_save', type=int, default=0, help="start saving checkpoints after specific epoch") 194 | parser.add_argument('--seed', type=int, default=None) 195 | parser.add_argument('--print-freq', type=int, default=1) 196 | # camstyle batchsize 197 | parser.add_argument('--camstyle', type=int, default=0) 198 | parser.add_argument('--fake_pooling', type=int, default=1) 199 | # misc 200 | 
parser.add_argument('--data-dir', type=str, metavar='PATH', default=osp.expanduser('~/Data')) 201 | parser.add_argument('--logs-dir', type=str, metavar='PATH', default=None) 202 | main(parser.parse_args()) 203 | -------------------------------------------------------------------------------- /PCB.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.environ['OMP_NUM_THREADS'] = '1' 4 | import argparse 5 | import datetime 6 | import sys 7 | import shutil 8 | from distutils.dir_util import copy_tree 9 | import time 10 | import numpy as np 11 | import torch 12 | from reid import models 13 | from reid.evaluators import Evaluator 14 | from reid.trainers import Trainer 15 | from reid.utils.logger import Logger 16 | from reid.utils.draw_curve import * 17 | from reid.utils.get_loaders import * 18 | from reid.utils.serialization import save_checkpoint 19 | 20 | ''' 21 | ideas for better training from Dr. Yifan Sun 22 | 23 | train resnet BN by default check 24 | no crop check 25 | batch_size = 64 , lr = 0.1 check 26 | dropout -- possible at layer: pool5 check 27 | ''' 28 | 29 | 30 | def main(args): 31 | # seed 32 | if args.seed is not None: 33 | np.random.seed(args.seed) 34 | torch.manual_seed(args.seed) 35 | torch.backends.cudnn.deterministic = True 36 | torch.backends.cudnn.benchmark = False 37 | else: 38 | torch.backends.cudnn.benchmark = True 39 | 40 | if args.logs_dir is None: 41 | args.logs_dir = osp.join(f'logs/pcb/{args.dataset}', datetime.datetime.today().strftime('%Y-%m-%d_%H-%M-%S')) 42 | else: 43 | args.logs_dir = osp.join(f'logs/pcb/{args.dataset}', args.logs_dir) 44 | if args.train: 45 | os.makedirs(args.logs_dir, exist_ok=True) 46 | copy_tree('./reid', args.logs_dir + '/scripts/reid') 47 | for script in os.listdir('.'): 48 | if script.split('.')[-1] == 'py': 49 | dst_file = os.path.join(args.logs_dir, 'scripts', os.path.basename(script)) 50 | shutil.copyfile(script, dst_file) 51 | sys.stdout = Logger(os.path.join(args.logs_dir, 'log.txt'), ) 52 | print('Settings:') 53 | print(vars(args)) 54 | print('\n') 55 | 56 | # Create data loaders 57 | dataset, num_classes, train_loader, query_loader, gallery_loader, camstyle_loader = \ 58 | get_data(args.dataset, args.data_dir, args.height, args.width, args.batch_size, args.num_workers, 59 | args.combine_trainval, args.crop, args.tracking_icams, args.tracking_fps, args.re, 0, args.camstyle) 60 | 61 | # Create model 62 | model = models.create('pcb', feature_dim=args.feature_dim, num_classes=num_classes, norm=args.norm, 63 | dropout=args.dropout, last_stride=args.last_stride) 64 | 65 | # Load from checkpoint 66 | start_epoch = best_top1 = 0 67 | if args.resume: 68 | resume_fname = osp.join(f'logs/pcb/{args.dataset}', args.resume, 'model_best.pth.tar') 69 | model, start_epoch, best_top1 = checkpoint_loader(model, resume_fname) 70 | print("=> Last epoch {} best top1 {:.1%}".format(start_epoch, best_top1)) 71 | start_epoch += 1 72 | model = nn.DataParallel(model).cuda() 73 | 74 | # Criterion 75 | criterion = nn.CrossEntropyLoss().cuda() 76 | 77 | # Optimizer 78 | if hasattr(model.module, 'base'): # low learning_rate the base network (aka. 
ResNet-50) 79 | base_param_ids = set(map(id, model.module.base.parameters())) 80 | new_params = [p for p in model.parameters() if id(p) not in base_param_ids] 81 | param_groups = [{'params': model.module.base.parameters(), 'lr_mult': 0.1}, 82 | {'params': new_params, 'lr_mult': 1.0}] 83 | else: 84 | param_groups = model.parameters() 85 | optimizer = torch.optim.SGD(param_groups, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay, 86 | nesterov=True) 87 | 88 | # Trainer 89 | trainer = Trainer(model, criterion) 90 | 91 | # Evaluator 92 | evaluator = Evaluator(model) 93 | 94 | if args.train: 95 | # Schedule learning rate 96 | def adjust_lr(epoch): 97 | step_size = args.step_size 98 | lr = args.lr * (0.1 ** (epoch // step_size)) 99 | for g in optimizer.param_groups: 100 | g['lr'] = lr * g.get('lr_mult', 1) 101 | 102 | # Draw Curve 103 | epoch_s = [] 104 | loss_s = [] 105 | prec_s = [] 106 | eval_epoch_s = [] 107 | eval_top1_s = [] 108 | 109 | # Start training 110 | for epoch in range(start_epoch + 1, args.epochs + 1): 111 | t0 = time.time() 112 | adjust_lr(epoch) 113 | # train_loss, train_prec = 0, 0 114 | train_loss, train_prec = trainer.train(epoch, train_loader, optimizer, fix_bn=args.fix_bn) 115 | 116 | if epoch < args.start_save: 117 | continue 118 | 119 | if epoch % 5 == 0: 120 | top1 = evaluator.evaluate(query_loader, gallery_loader, dataset.query, dataset.gallery) 121 | eval_epoch_s.append(epoch) 122 | eval_top1_s.append(top1) 123 | else: 124 | top1 = 0 125 | 126 | is_best = top1 >= best_top1 127 | best_top1 = max(top1, best_top1) 128 | save_checkpoint({ 129 | 'state_dict': model.module.state_dict(), 130 | 'epoch': epoch, 131 | 'best_top1': best_top1, 132 | }, is_best, fpath=osp.join(args.logs_dir, 'checkpoint.pth.tar')) 133 | epoch_s.append(epoch) 134 | loss_s.append(train_loss) 135 | prec_s.append(train_prec) 136 | draw_curve(os.path.join(args.logs_dir, 'train_curve.jpg'), epoch_s, loss_s, prec_s, 137 | eval_epoch_s, None, eval_top1_s) 138 | 139 | t1 = time.time() 140 | t_epoch = t1 - t0 141 | print('\n * Finished epoch {:3d} top1: {:5.1%} best_eval: {:5.1%} {}\n'. 
142 | format(epoch, top1, best_top1, ' *' if is_best else '')) 143 | print('*************** Epoch takes time: {:^10.2f} *********************\n'.format(t_epoch)) 144 | pass 145 | 146 | # Final test 147 | print('Test with best model:') 148 | model, start_epoch, best_top1 = checkpoint_loader(model, osp.join(args.logs_dir, 'model_best.pth.tar')) 149 | print("=> Start epoch {} best top1 {:.1%}".format(start_epoch, best_top1)) 150 | 151 | evaluator.evaluate(query_loader, gallery_loader, dataset.query, dataset.gallery) 152 | else: 153 | print("Test:") 154 | evaluator.evaluate(query_loader, gallery_loader, dataset.query, dataset.gallery) 155 | pass 156 | 157 | 158 | if __name__ == '__main__': 159 | parser = argparse.ArgumentParser(description="Softmax loss classification") 160 | # data 161 | parser.add_argument('-d', '--dataset', type=str, default='market1501', choices=datasets.names()) 162 | parser.add_argument('-b', '--batch-size', type=int, default=64, help="batch size") 163 | parser.add_argument('-j', '--num-workers', type=int, default=4) 164 | parser.add_argument('--height', type=int, default=384, help="input height, default: 384 for PCB*") 165 | parser.add_argument('--width', type=int, default=128, help="input width, default: 128 for resnet*") 166 | parser.add_argument('--combine-trainval', action='store_true', 167 | help="train and val sets together for training, val set alone for validation") 168 | parser.add_argument('--tracking_icams', type=int, default=0, help="specify if train on single iCam") 169 | parser.add_argument('--tracking_fps', type=int, default=1, help="specify if train on single iCam") 170 | parser.add_argument('--re', type=float, default=0, help="random erasing") 171 | parser.add_argument('--crop', type=bool, default=0, help="resize then crop, default: False") 172 | # model 173 | parser.add_argument('--feature_dim', type=int, default=256) 174 | parser.add_argument('--dropout', type=float, default=0.5) 175 | parser.add_argument('-s', '--last_stride', type=int, default=1, choices=[1, 2]) 176 | parser.add_argument('--norm', action='store_true', help="normalize feat, default: False") 177 | parser.add_argument('--arch', type=str, default='resnet50', choices=['resnet50', 'densenet121'], 178 | help='architecture for base network') 179 | # optimizer 180 | parser.add_argument('--lr', type=float, default=0.1, 181 | help="learning rate of new parameters, for pretrained " 182 | "parameters it is 10 times smaller than this") 183 | parser.add_argument('--momentum', type=float, default=0.9) 184 | parser.add_argument('--weight-decay', type=float, default=5e-4) 185 | # training configs 186 | parser.add_argument('--train', action='store_true', help="train PCB model from start") 187 | parser.add_argument('--fix_bn', type=bool, default=0, help="fix (skip training) BN in base network") 188 | parser.add_argument('--resume', type=str, default=None, metavar='PATH') 189 | parser.add_argument('--epochs', type=int, default=60) 190 | parser.add_argument('--step-size', type=int, default=40) 191 | parser.add_argument('--start_save', type=int, default=0, help="start saving checkpoints after specific epoch") 192 | parser.add_argument('--seed', type=int, default=None) 193 | parser.add_argument('--print-freq', type=int, default=1) 194 | # camstyle batchsize 195 | parser.add_argument('--camstyle', type=int, default=0) 196 | # misc 197 | parser.add_argument('--data-dir', type=str, metavar='PATH', default=osp.expanduser('~/Data')) 198 | parser.add_argument('--logs-dir', type=str, metavar='PATH', default=None) 
199 | main(parser.parse_args())
200 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------

# Open-ReID-tracking

This repo is based on Cysu's [open-reid](https://github.com/Cysu/open-reid), a great re-ID library. For performance, we implemented several additional baseline models on top of it. For utility, we added functions that support the tracking-by-detection workflow.

- update all models for performance & readability.
- add ```data/README.md```; check it for folder structure & dataset download.
- add ```requirements.txt```; use ```conda install --file requirements.txt``` to install.
- add BN after the feature layer in ```reid/models/IDE_model.py``` for separation, which improves performance.
- fix high CPU usage by adding ```os.environ['OMP_NUM_THREADS'] = '1'``` in runnable files.
- NEW: we adopt a baseline from Hao Luo \[[git](https://github.com/michuanhaohao/reid-strong-baseline), [paper](https://arxiv.org/abs/1903.07071)\]. See ```ZJU.py```. We achieve competitive performance with the same `IDE_model.py`.

Please use this repo alongside our flavor of the [DeepCC](https://github.com/hou-yz/DeepCC_aic) tracker for tracking.

## Model
- IDE \[[paper](https://www.cv-foundation.org/openaccess/content_iccv_2015/papers/Zheng_Scalable_Person_Re-Identification_ICCV_2015_paper.pdf)\]
- Triplet \[[paper](https://arxiv.org/abs/1703.07737)\]
- PCB \[[git](https://github.com/syfafterzy/PCB_RPP_for_reID), [paper](http://openaccess.thecvf.com/content_ECCV_2018/papers/Yifan_Sun_Beyond_Part_Models_ECCV_2018_paper.pdf)\]
- ZJU \[[git](https://github.com/michuanhaohao/reid-strong-baseline), [paper](https://arxiv.org/abs/1903.07071)\]


## Data
The re-ID datasets should be stored in a file structure like this:
```
~
└───Data
    └───AIC19
    │   │ track-1 data
    │   │ ...
    │
    └───AIC19-reid
    │   │ track-2 data
    │   │ ...
    │
    └───VeRi
    │   │ ...
    │
    └───DukeMTMC-reID
    │   │ ...
    │
    └───Market-1501-v15.09.15
        │ ...
```


## Usage
### Re-ID
Training from scratch:
```shell script
CUDA_VISIBLE_DEVICES=0 python3 IDE.py -d market1501 --train
```
This automatically saves your logs to `./logs/ide/market1501/YYYY-MM-DD_HH-MM-SS`, where `YYYY-MM-DD_HH-MM-SS` is the time stamp at which training started.

Resume & evaluate:
```shell script
CUDA_VISIBLE_DEVICES=0 python3 IDE.py -d market1501 --resume YYYY-MM-DD_HH-MM-SS
```

### Feature Extraction for Tracking (to be updated)
We describe the workflow for a simple model. For the full ensemble model, please check `reid/prepare/ensemble.py`.

First, use the following to extract detection bounding boxes from the videos:
```shell script
python3 reid/prepare/extract_bbox.py
```

Next, train the baseline on re-ID data from AI-City 2019 (track-2):
```shell script
# train
CUDA_VISIBLE_DEVICES=0,1 python3 ZJU.py --train -d aic_reid --logs-dir logs/ZJU/256/aic_reid/lr001_colorjitter --colorjitter --height 256 --width 256 --lr 0.01 --step-size 30,60,80 --warmup 10 --LSR --backbone densenet121 --features 256 --BNneck -s 1 -b 64 --epochs 120
```
Then, compute the detection bounding box features:
```shell script
# gt feat (optional)
# CUDA_VISIBLE_DEVICES=0,1 python3 save_cnn_feature.py -a zju --backbone densenet121 --resume logs/ZJU/256/aic_reid/lr001_colorjitter/model_best.pth.tar --features 256 --height 256 --width 256 --l0_name zju_lr001_colorjitter_256 --BNneck -s 1 -d aic --type gt_all -b 64
# reid feat (parameter tuning, see DeepCC_aic)
CUDA_VISIBLE_DEVICES=0,1 python3 save_cnn_feature.py -a zju --backbone densenet121 --resume logs/ZJU/256/aic_reid/lr001_colorjitter/model_best.pth.tar --features 256 --height 256 --width 256 --l0_name zju_lr001_colorjitter_256 --BNneck -s 1 -d aic --type gt_mini -b 64
# det feat (tracking pre-requisite, see DeepCC_aic)
CUDA_VISIBLE_DEVICES=0,1 python3 save_cnn_feature.py -a zju --backbone densenet121 --resume logs/ZJU/256/aic_reid/lr001_colorjitter/model_best.pth.tar --features 256 --height 256 --width 256 --l0_name zju_lr001_colorjitter_256 --BNneck -s 1 -d aic --type detections --det_time trainval -b 64
CUDA_VISIBLE_DEVICES=0,1 python3 save_cnn_feature.py -a zju --backbone densenet121 --resume logs/ZJU/256/aic_reid/lr001_colorjitter/model_best.pth.tar --features 256 --height 256 --width 256 --l0_name zju_lr001_colorjitter_256 --BNneck -s 1 -d aic --type detections --det_time test -b 64
```

## Implementation details

Cross-entropy loss:
- `batch_size = 64`.
- `learning rate = 0.1`, step decay after 40 epochs. Train for 60 epochs in total.
- 0.1x learning rate for the `resnet-50` base.
- `weight decay = 5e-4`.
- SGD optimizer, `momentum = 0.9`, `nesterov = true`.

Triplet loss:
- `margin = 0.3`.
- `ims_per_id = 4`, `ids_per_batch = 32`.
- `learning rate = 2e-4`, exponential decay after 150 epochs. Train for 300 epochs in total.
- unified learning rate for the `resnet-50` base and the `fc` feature layer.
- `weight decay = 5e-4`.
- Adam optimizer.


`Default` settings:
- IDE
  - `stride = 2` in the last conv block.
  - `h x w = 256 x 128`.
  - random horizontal flip + random crop.
- Triplet
  - `stride = 2` in the last conv block.
  - `h x w = 256 x 128`.
  - random horizontal flip + random crop.
- PCB
  - `stride = 1` in the last conv block.
  - `h x w = 384 x 128`.
  - random horizontal flip.
- ZJU
  - cross entropy + triplet.
  - `ims_per_id = 4`, `ids_per_batch = 16`.
  - `h x w = 256 x 128`.
  - warmup for 10 epochs.
  - random horizontal flip + pad 10 pixels then random crop + random erasing with `re = 0.5`.
  - label smoothing.
  - `stride = 1` in the last conv block.
  - ~~BNneck.~~
  - ~~center loss.~~

`Tracking` settings for IDE, Triplet, and PCB:
- `stride = 1` in the last conv block.
- `h x w = 384 x 128`.
- random horizontal flip + random erasing with `re = 0.5`.

`Raw` setting for ZJU:
- cross entropy + triplet.
- `ims_per_id = 4`, `ids_per_batch = 16`.
- `h x w = 256 x 128`.
- random horizontal flip + pad 10 pixels then random crop.
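For reference, the warmup schedule above is what `ZJU.py` implements in its `adjust_lr`: a linear ramp from `0.01 * lr` to `lr` over the first `warmup` epochs, followed by a 10x decay at each `--step-size` milestone. A standalone sketch (the function name `zju_lr` is ours; `base_lr` corresponds to `--lr`, `milestones` to `--step-size`):

```python
from bisect import bisect_right

def zju_lr(epoch, base_lr=3.5e-4, warmup=10, milestones=(30, 60, 80)):
    # linear warmup: factor ramps from 0.01 (epoch 0) to 1 (epoch == warmup)
    if epoch <= warmup:
        alpha = epoch / warmup
        warmup_factor = 0.01 * (1 - alpha) + alpha
    else:
        warmup_factor = 1
    # step decay: divide the learning rate by 10 at each milestone passed
    return base_lr * warmup_factor * 0.1 ** bisect_right(milestones, epoch)
```

For example, `zju_lr(10)` returns the full `base_lr`, while epochs 30, 60, and 80 each scale it by a further `0.1`.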
134 | 135 | 136 | 137 | ## Experiment Results 138 | 139 | | dataset | model | settings | mAP (%) | Rank-1 (%) | 140 | | --- | --- | --- | :---: | :---: | 141 | | Duke|IDE|Default | 58.70 | 77.56 | 142 | | Duke|Triplet|Default | 62.40 | 78.19 | 143 | | Duke|PCB|Default | 68.72 | 83.12 | 144 | | Duke|ZJU|Default | 75.20 | 86.71 | 145 | | Market|IDE|Default | 69.34 | 86.58 | 146 | | Market|Triplet|Default | 72.42 | 86.55 | 147 | | Market|PCB|Default | 77.53 | 92.52 | 148 | | Market|ZJU|Default | 85.37 | 93.79 | 149 | 150 | 163 | -------------------------------------------------------------------------------- /ZJU.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.environ['OMP_NUM_THREADS'] = '1' 4 | import argparse 5 | import datetime 6 | import sys 7 | import shutil 8 | from distutils.dir_util import copy_tree 9 | import time 10 | import numpy as np 11 | import torch 12 | from bisect import bisect_right 13 | from reid import models 14 | from reid.evaluators import Evaluator 15 | from reid.loss import * 16 | from reid.trainers import Trainer 17 | from reid.utils.logger import Logger 18 | from reid.utils.draw_curve import * 19 | from reid.utils.get_loaders import * 20 | from reid.utils.serialization import save_checkpoint 21 | 22 | ''' 23 | tricks from ZJU paper 24 | 25 | warmup check 26 | re check 27 | lsr check 28 | s=1 check 29 | BNneck skip 30 | centerloss skip 31 | ''' 32 | 33 | 34 | def main(args): 35 | args.step_size = args.step_size.split(',') 36 | args.step_size = [int(x) for x in args.step_size] 37 | # seed 38 | if args.seed is not None: 39 | np.random.seed(args.seed) 40 | torch.manual_seed(args.seed) 41 | torch.backends.cudnn.deterministic = True 42 | torch.backends.cudnn.benchmark = False 43 | else: 44 | torch.backends.cudnn.benchmark = True 45 | 46 | if args.logs_dir is None: 47 | args.logs_dir = osp.join(f'logs/zju/{args.dataset}', datetime.datetime.today().strftime('%Y-%m-%d_%H-%M-%S')) 48 | else: 49 | args.logs_dir = osp.join(f'logs/zju/{args.dataset}', args.logs_dir) 50 | if args.train: 51 | os.makedirs(args.logs_dir, exist_ok=True) 52 | copy_tree('./reid', args.logs_dir + '/scripts/reid') 53 | for script in os.listdir('.'): 54 | if script.split('.')[-1] == 'py': 55 | dst_file = os.path.join(args.logs_dir, 'scripts', os.path.basename(script)) 56 | shutil.copyfile(script, dst_file) 57 | sys.stdout = Logger(os.path.join(args.logs_dir, 'log.txt'), ) 58 | print('Settings:') 59 | print(vars(args)) 60 | print('\n') 61 | 62 | # Create data loaders 63 | dataset, num_classes, train_loader, query_loader, gallery_loader, camstyle_loader = \ 64 | get_data(args.dataset, args.data_dir, args.height, args.width, args.batch_size, args.num_workers, 65 | args.combine_trainval, args.crop, args.tracking_icams, args.tracking_fps, args.re, args.num_instances, 66 | camstyle=0, zju=1, colorjitter=args.colorjitter) 67 | 68 | # Create model 69 | model = models.create('ide', feature_dim=args.feature_dim, norm=args.norm, 70 | num_classes=num_classes, last_stride=args.last_stride, arch=args.arch) 71 | 72 | # Load from checkpoint 73 | start_epoch = best_top1 = 0 74 | if args.resume: 75 | resume_fname = osp.join(f'logs/zju/{args.dataset}', args.resume, 'model_best.pth.tar') 76 | model, start_epoch, best_top1 = checkpoint_loader(model, resume_fname) 77 | print("=> Last epoch {} best top1 {:.1%}".format(start_epoch, best_top1)) 78 | start_epoch += 1 79 | model = nn.DataParallel(model).cuda() 80 | 81 | # Criterion 82 | criterion = [LSR_loss().cuda() if 
args.LSR else nn.CrossEntropyLoss().cuda(), 83 | TripletLoss(margin=None if args.softmargin else args.margin).cuda()] 84 | 85 | # Optimizer 86 | if 'aic' in args.dataset: 87 | # Optimizer 88 | if hasattr(model.module, 'base'): # low learning_rate the base network (aka. DenseNet-121) 89 | base_param_ids = set(map(id, model.module.base.parameters())) 90 | new_params = [p for p in model.parameters() if id(p) not in base_param_ids] 91 | param_groups = [{'params': model.module.base.parameters(), 'lr_mult': 1}, 92 | {'params': new_params, 'lr_mult': 2}] 93 | else: 94 | param_groups = model.parameters() 95 | optimizer = torch.optim.SGD(param_groups, lr=args.lr, momentum=args.momentum, 96 | weight_decay=args.weight_decay) 97 | else: 98 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, ) 99 | 100 | # Trainer 101 | trainer = Trainer(model, criterion) 102 | 103 | # Evaluator 104 | evaluator = Evaluator(model) 105 | 106 | if args.train: 107 | # Schedule learning rate 108 | def adjust_lr(epoch): 109 | if epoch <= args.warmup: 110 | alpha = epoch / args.warmup 111 | warmup_factor = 0.01 * (1 - alpha) + alpha 112 | else: 113 | warmup_factor = 1 114 | lr = args.lr * warmup_factor * (0.1 ** bisect_right(args.step_size, epoch)) 115 | print('Current learning rate: {}'.format(lr)) 116 | for g in optimizer.param_groups: 117 | if 'aic' in args.dataset: 118 | g['lr'] = lr * g.get('lr_mult', 1) 119 | else: 120 | g['lr'] = lr 121 | 122 | # Draw Curve 123 | epoch_s = [] 124 | loss_s = [] 125 | prec_s = [] 126 | eval_epoch_s = [] 127 | eval_top1_s = [] 128 | 129 | # Start training 130 | for epoch in range(start_epoch + 1, args.epochs + 1): 131 | t0 = time.time() 132 | adjust_lr(epoch) 133 | # train_loss, train_prec = 0, 0 134 | train_loss, train_prec = trainer.train(epoch, train_loader, optimizer, fix_bn=args.fix_bn, print_freq=10) 135 | 136 | if epoch < args.start_save: 137 | continue 138 | 139 | if epoch % 10 == 0: 140 | top1 = evaluator.evaluate(query_loader, gallery_loader, dataset.query, dataset.gallery) 141 | eval_epoch_s.append(epoch) 142 | eval_top1_s.append(top1) 143 | else: 144 | top1 = 0 145 | 146 | is_best = top1 >= best_top1 147 | best_top1 = max(top1, best_top1) 148 | save_checkpoint({ 149 | 'state_dict': model.module.state_dict(), 150 | 'epoch': epoch, 151 | 'best_top1': best_top1, 152 | }, is_best, fpath=osp.join(args.logs_dir, 'checkpoint.pth.tar')) 153 | epoch_s.append(epoch) 154 | loss_s.append(train_loss) 155 | prec_s.append(train_prec) 156 | draw_curve(os.path.join(args.logs_dir, 'train_curve.jpg'), epoch_s, loss_s, prec_s, 157 | eval_epoch_s, None, eval_top1_s) 158 | 159 | t1 = time.time() 160 | t_epoch = t1 - t0 161 | print('\n * Finished epoch {:3d} top1: {:5.1%} best_eval: {:5.1%} {}\n'. 
162 | format(epoch, top1, best_top1, ' *' if is_best else '')) 163 | print('*************** Epoch takes time: {:^10.2f} *********************\n'.format(t_epoch)) 164 | pass 165 | 166 | # Final test 167 | print('Test with best model:') 168 | model, start_epoch, best_top1 = checkpoint_loader(model, osp.join(args.logs_dir, 'model_best.pth.tar')) 169 | print("=> Start epoch {} best top1 {:.1%}".format(start_epoch, best_top1)) 170 | 171 | evaluator.evaluate(query_loader, gallery_loader, dataset.query, dataset.gallery) 172 | else: 173 | print("Test:") 174 | evaluator.evaluate(query_loader, gallery_loader, dataset.query, dataset.gallery) 175 | pass 176 | 177 | 178 | if __name__ == '__main__': 179 | parser = argparse.ArgumentParser(description="ZJU baseline") 180 | # data 181 | parser.add_argument('-d', '--dataset', type=str, default='market1501', choices=datasets.names()) 182 | parser.add_argument('-b', '--batch-size', type=int, default=64, help="batch size") 183 | parser.add_argument('-j', '--num-workers', type=int, default=4) 184 | parser.add_argument('--height', type=int, default=256, help="input height, default: 256 for resnet*") 185 | parser.add_argument('--width', type=int, default=128, help="input width, default: 128 for resnet*") 186 | parser.add_argument('--combine-trainval', action='store_true', 187 | help="train and val sets together for training, val set alone for validation") 188 | parser.add_argument('--tracking_icams', type=int, default=0, help="specify if train on single iCam") 189 | parser.add_argument('--tracking_fps', type=int, default=1, help="specify if train on single iCam") 190 | parser.add_argument('--re', type=float, default=0.5, help="random erasing") 191 | parser.add_argument('--crop', type=bool, default=1, help="resize then crop, default: True") 192 | parser.add_argument('--colorjitter', action='store_true', help="resize then crop, default: True") 193 | # model 194 | parser.add_argument('--feature_dim', type=int, default=256) 195 | parser.add_argument('--dropout', type=float, default=0) 196 | parser.add_argument('-s', '--last_stride', type=int, default=1, choices=[1, 2]) 197 | parser.add_argument('--norm', action='store_true', help="normalize feat, default: False") 198 | parser.add_argument('--arch', type=str, default='resnet50', choices=['resnet50', 'densenet121'], 199 | help='architecture for base network') 200 | # loss 201 | parser.add_argument('--margin', type=float, default=0.3, help="margin of the triplet loss, default: 0.3") 202 | parser.add_argument('--softmargin', action='store_true', help="use softmargin triplet loss, default: false") 203 | parser.add_argument('--num-instances', type=int, default=4, 204 | help="each minibatch consist of " 205 | "(batch_size // num_instances) identities, and " 206 | "each identity has num_instances instances, default: 4") 207 | # optimizer 208 | parser.add_argument('--lr', type=float, default=0.00035, 209 | help="learning rate of new parameters, for pretrained " 210 | "parameters it is 10 times smaller than this") 211 | parser.add_argument('--momentum', type=float, default=0.9) 212 | parser.add_argument('--weight-decay', type=float, default=5e-4) 213 | parser.add_argument('--LSR', type=bool, default=1, help="use label smooth loss, default: True") 214 | # training configs 215 | parser.add_argument('--train', action='store_true', help="train IDE model from start") 216 | parser.add_argument('--fix_bn', type=bool, default=0, help="fix (skip training) BN in base network") 217 | parser.add_argument('--resume', type=str, 
default=None, metavar='PATH') 218 | parser.add_argument('--warmup', type=int, default=10) 219 | parser.add_argument('--epochs', type=int, default=120) 220 | parser.add_argument('--step-size', default='30,60,80') 221 | parser.add_argument('--start_save', type=int, default=0, help="start saving checkpoints after specific epoch") 222 | parser.add_argument('--seed', type=int, default=None) 223 | parser.add_argument('--print-freq', type=int, default=1) 224 | # misc 225 | parser.add_argument('--data-dir', type=str, metavar='PATH', default=osp.expanduser('~/Data')) 226 | parser.add_argument('--logs-dir', type=str, metavar='PATH', default=None) 227 | main(parser.parse_args()) 228 | -------------------------------------------------------------------------------- /reid/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from . import datasets 4 | from . import evaluation_metrics 5 | from . import evaluators 6 | from . import feature_extraction 7 | from . import loss 8 | from . import models 9 | from . import trainers 10 | from . import utils 11 | 12 | __version__ = '0.2.0' 13 | -------------------------------------------------------------------------------- /reid/camstyle_trainer.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | 3 | import time 4 | 5 | import torch 6 | from torch import nn 7 | from torch.autograd import Variable 8 | 9 | from .evaluation_metrics import accuracy 10 | from .loss import * 11 | from .trainers import BaseTrainer 12 | from .utils.meters import AverageMeter 13 | 14 | 15 | class CamStyleTrainer(BaseTrainer): 16 | def __init__(self, model, criterion, camstyle_loader): 17 | super(CamStyleTrainer, self).__init__(model, criterion) 18 | self.camstyle_loader = camstyle_loader 19 | self.camstyle_loader_iter = iter(self.camstyle_loader) 20 | 21 | def train(self, epoch, data_loader, optimizer, fix_bn=False, print_freq=10): 22 | self.model.train() 23 | 24 | if fix_bn: 25 | # set the bn layers to eval() and don't change weight & bias 26 | for m in self.model.module.base.modules(): 27 | if isinstance(m, nn.BatchNorm2d): 28 | m.eval() 29 | if m.affine: 30 | m.weight.requires_grad = False 31 | m.bias.requires_grad = False 32 | 33 | batch_time = AverageMeter() 34 | data_time = AverageMeter() 35 | losses = AverageMeter() 36 | precisions = AverageMeter() 37 | 38 | end = time.time() 39 | for i, inputs in enumerate(data_loader): 40 | data_time.update(time.time() - end) 41 | 42 | try: 43 | camstyle_inputs = next(self.camstyle_loader_iter) 44 | except: 45 | self.camstyle_loader_iter = iter(self.camstyle_loader) 46 | camstyle_inputs = next(self.camstyle_loader_iter) 47 | inputs, targets = self._parse_data(inputs) 48 | camstyle_inputs, camstyle_targets = self._parse_data(camstyle_inputs) 49 | loss, prec1 = self._forward(inputs, targets, camstyle_inputs, camstyle_targets) 50 | 51 | losses.update(loss.item(), targets.size(0)) 52 | precisions.update(prec1, targets.size(0)) 53 | 54 | optimizer.zero_grad() 55 | loss.backward() 56 | optimizer.step() 57 | 58 | batch_time.update(time.time() - end) 59 | end = time.time() 60 | 61 | if (i + 1) % print_freq == 0: 62 | print('Epoch: [{}][{}/{}]\t' 63 | 'Time {:.3f} ({:.3f})\t' 64 | 'Data {:.3f} ({:.3f})\t' 65 | 'Loss {:.3f} ({:.3f})\t' 66 | 'Prec {:.2%} ({:.2%})\t' 67 | .format(epoch, i + 1, len(data_loader), 68 | batch_time.val, batch_time.avg, 69 | data_time.val, 
data_time.avg, 70 | losses.val, losses.avg, 71 | precisions.val, precisions.avg)) 72 | 73 | return losses.avg, precisions.avg 74 | 75 | def _parse_data(self, inputs): 76 | imgs, _, pids, _ = inputs 77 | inputs = Variable(imgs.cuda()) 78 | targets = Variable(pids.cuda()) 79 | return inputs, targets 80 | 81 | def _forward(self, inputs, targets, camstyle_inputs, camstyle_targets): 82 | outputs = self.model(inputs) 83 | camstyle_outputs = self.model(camstyle_inputs) 84 | if isinstance(self.criterion, torch.nn.CrossEntropyLoss): 85 | # if isinstance(self.model.module, IDE_model) or isinstance(self.model.module, PCB_model): 86 | prediction_s = outputs[1] 87 | loss = 0 88 | for pred in prediction_s: 89 | loss += self.criterion(pred, targets) 90 | prediction = prediction_s[0] 91 | prec, = accuracy(prediction.data, targets.data) 92 | # else: 93 | # loss = self.criterion(outputs, targets) 94 | # prec, = accuracy(outputs.data, targets.data) 95 | prec = prec.item() 96 | elif isinstance(self.criterion, TripletLoss): 97 | loss, prec = self.criterion(outputs, targets) 98 | else: 99 | raise ValueError("Unsupported loss:", self.criterion) 100 | # label soft loss 101 | camstyle_loss = self._lsr_loss(camstyle_outputs[1][0], camstyle_targets) 102 | loss += camstyle_loss 103 | return loss, prec 104 | 105 | def _lsr_loss(self, outputs, targets): 106 | num_class = outputs.size()[1] 107 | targets = self._class_to_one_hot(targets.data.cpu(), num_class) 108 | targets = Variable(targets.cuda()) 109 | outputs = torch.nn.LogSoftmax(dim=1)(outputs) 110 | loss = - (targets * outputs) 111 | loss = loss.sum(dim=1) 112 | loss = loss.mean(dim=0) 113 | return loss 114 | 115 | def _class_to_one_hot(self, targets, num_class): 116 | targets = torch.unsqueeze(targets, 1) 117 | targets_onehot = torch.FloatTensor(targets.size()[0], num_class) 118 | targets_onehot.zero_() 119 | targets_onehot.scatter_(1, targets, 0.9) 120 | targets_onehot.add_(0.1 / num_class) 121 | return targets_onehot 122 | -------------------------------------------------------------------------------- /reid/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .ai_city import AI_City 4 | from .dukemtmc import DukeMTMC 5 | from .market1501 import Market1501 6 | from .veri import VeRi 7 | 8 | __factory = { 9 | 'market1501': Market1501, 10 | 'duke_tracking': DukeMTMC, 11 | 'duke_reid': DukeMTMC, 12 | 'aic_tracking': AI_City, 13 | 'aic_reid': AI_City, 14 | 'veri': VeRi, 15 | } 16 | 17 | 18 | def names(): 19 | return sorted(__factory.keys()) 20 | 21 | 22 | def create(name, *args, **kwargs): 23 | """ 24 | Create a dataset instance. 25 | 26 | Parameters 27 | ---------- 28 | name : str 29 | The dataset name. Can be one of 'viper', 'cuhk01', 'cuhk03', 30 | 'market1501', and 'dukemtmc'. 31 | root : str 32 | The path to the dataset directory. 33 | split_id : int, optional 34 | The index of data split. Default: 0 35 | num_val : int or float, optional 36 | When int, it means the number of validation identities. When float, 37 | it means the proportion of validation to all the trainval. Default: 100 38 | download : bool, optional 39 | If True, will download the dataset. 
Default: False 40 | """ 41 | if name not in __factory: 42 | raise KeyError("Unknown dataset:", name) 43 | return __factory[name](*args, **kwargs) 44 | -------------------------------------------------------------------------------- /reid/datasets/ai_city.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | 3 | import os.path as osp 4 | import re 5 | import xml.dom.minidom as XD 6 | from collections import defaultdict 7 | from glob import glob 8 | 9 | 10 | class AI_City(object): 11 | 12 | def __init__(self, root, data_type='reid', fps=10, trainval=False, gt_type='gt'): 13 | if data_type == 'tracking_gt': 14 | self.root = osp.join(root, 'AIC19') 15 | if not trainval: 16 | train_dir = osp.join(root, f'AIC19/ALL_{gt_type}_bbox/train') 17 | else: 18 | train_dir = osp.join(root, f'AIC19/ALL_{gt_type}_bbox/trainval') 19 | val_dir = osp.join(root, f'AIC19/ALL_gt_bbox/val') 20 | self.train_path = osp.join(train_dir, f'gt_bbox_{fps}_fps') 21 | self.gallery_path = osp.join(val_dir, 'gt_bbox_1_fps') 22 | self.query_path = osp.join(val_dir, 'gt_bbox_1_fps') 23 | elif data_type == 'tracking_det': 24 | self.root = root 25 | self.train_path = root 26 | self.gallery_path = None 27 | self.query_path = None 28 | elif data_type == 'reid': # reid 29 | self.root = osp.join(root, 'AIC19-reid') 30 | self.train_path = osp.join(root, 'AIC19-reid/image_train') 31 | self.gallery_path = osp.join(root, 'VeRi/image_query/') 32 | self.query_path = osp.join(root, 'VeRi/image_test/') 33 | 34 | xml_dir = osp.join(root, 'AIC19-reid/train_label.xml') 35 | self.reid_info = XD.parse(xml_dir).documentElement.getElementsByTagName('Item') 36 | self.index_by_fname_dict = defaultdict() 37 | for index in range(len(self.reid_info)): 38 | fname = self.reid_info[index].getAttribute('imageName') 39 | self.index_by_fname_dict[fname] = index 40 | elif data_type == 'reid_test': # reid_test for feature extraction 41 | self.root = osp.join(root, 'AIC19-reid') 42 | self.train_path = None 43 | self.gallery_path = osp.join(root, 'AIC19-reid/image_test') 44 | self.query_path = osp.join(root, 'AIC19-reid/image_query') 45 | else: 46 | raise Exception 47 | 48 | self.train, self.query, self.gallery = [], [], [] 49 | self.num_train_ids, self.num_query_ids, self.num_gallery_ids = 0, 0, 0 50 | self.num_cams = 40 51 | 52 | self.data_type = data_type 53 | self.load() 54 | 55 | def preprocess(self, path, relabel=True, type='reid'): 56 | if type == 'tracking_det': 57 | pattern = re.compile(r'c([-\d]+)_f(\d+)') 58 | elif type == 'tracking_gt': 59 | pattern = re.compile(r'([-\d]+)_c(\d+)') 60 | else: # reid 61 | pattern = None 62 | all_pids = {} 63 | ret = [] 64 | if path is None: 65 | return ret, int(len(all_pids)) 66 | fpaths = sorted(glob(osp.join(path, '*.jpg'))) 67 | for fpath in fpaths: 68 | fname = osp.basename(fpath) 69 | if type == 'tracking_det': 70 | cam, frame = map(int, pattern.search(fname).groups()) 71 | pid = 1 72 | elif type == 'tracking_gt': 73 | pid, cam = map(int, pattern.search(fname).groups()) 74 | elif type == 'reid': # reid 75 | pid, cam = map(int, [self.reid_info[self.index_by_fname_dict[fname]].getAttribute('vehicleID'), 76 | self.reid_info[self.index_by_fname_dict[fname]].getAttribute('cameraID')[1:]]) 77 | else: # reid test 78 | pid, cam = 1, 1 79 | if pid == -1: continue 80 | if relabel: 81 | if pid not in all_pids: 82 | all_pids[pid] = len(all_pids) 83 | else: 84 | if pid not in all_pids: 85 | all_pids[pid] = pid 86 | pid = 
all_pids[pid] 87 | ret.append((fname, pid, cam - 1)) 88 | return ret, int(len(all_pids)) 89 | 90 | def load(self): 91 | self.train, self.num_train_ids = self.preprocess(self.train_path, True, self.data_type) 92 | self.gallery, self.num_gallery_ids = self.preprocess(self.gallery_path, False, 93 | 'reid_test' if self.data_type == 'reid_test' else 'tracking_gt') 94 | self.query, self.num_query_ids = self.preprocess(self.query_path, False, 95 | 'reid_test' if self.data_type == 'reid_test' else 'tracking_gt') 96 | 97 | print(self.__class__.__name__, "dataset loaded") 98 | print(" subset | # ids | # images") 99 | print(" ---------------------------") 100 | print(" train | {:5d} | {:8d}" 101 | .format(self.num_train_ids, len(self.train))) 102 | print(" query | {:5d} | {:8d}" 103 | .format(self.num_query_ids, len(self.query))) 104 | print(" gallery | {:5d} | {:8d}" 105 | .format(self.num_gallery_ids, len(self.gallery))) 106 | -------------------------------------------------------------------------------- /reid/datasets/dukemtmc.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | 3 | import os.path as osp 4 | import re 5 | from glob import glob 6 | 7 | 8 | class DukeMTMC(object): 9 | 10 | def __init__(self, root, data_type='reid', iCams=None, fps=1, trainval=False): 11 | if iCams is None: 12 | iCams = list(range(1, 9)) 13 | if data_type == 'tracking_gt': 14 | self.root = osp.join(root, 'DukeMTMC') 15 | if not trainval: 16 | train_dir = osp.join(root, 'DukeMTMC/ALL_gt_bbox/train') 17 | else: 18 | train_dir = osp.join(root, 'DukeMTMC/ALL_gt_bbox/trainval') 19 | val_dir = osp.join(root, 'DukeMTMC/ALL_gt_bbox/val') 20 | self.train_path = osp.join(train_dir, f'gt_bbox_{fps}_fps') 21 | self.gallery_path = osp.join(val_dir, 'gt_bbox_1_fps') 22 | self.query_path = osp.join(val_dir, 'gt_bbox_1_fps') 23 | elif data_type == 'tracking_det': 24 | self.root = root 25 | self.train_path = root 26 | self.gallery_path = None 27 | self.query_path = None 28 | elif data_type == 'reid': # reid 29 | self.root = osp.join(root, 'DukeMTMC-reID') 30 | self.train_path = osp.join(root, 'DukeMTMC-reID/bounding_box_train') 31 | self.gallery_path = osp.join(root, 'DukeMTMC-reID/bounding_box_test') 32 | self.query_path = osp.join(root, 'DukeMTMC-reID/query') 33 | else: 34 | raise Exception 35 | 36 | self.camstyle_path = osp.join(root, 'DukeMTMC-reID/bounding_box_train_camstyle') 37 | self.train, self.query, self.gallery, self.camstyle = [], [], [], [] 38 | self.num_train_ids, self.num_query_ids, self.num_gallery_ids, self.num_camstyle_ids = 0, 0, 0, 0 39 | self.num_cams = 8 40 | 41 | self.data_type = data_type 42 | self.iCams = iCams 43 | self.load() 44 | 45 | def preprocess(self, path, relabel=True, type='reid'): 46 | if type == 'tracking_det': 47 | pattern = re.compile(r'c(\d+)_f(\d+)') 48 | else: 49 | pattern = re.compile(r'([-\d]+)_c(\d+)') 50 | all_pids = {} 51 | ret = [] 52 | if path is None: 53 | return ret, int(len(all_pids)) 54 | if type == 'tracking_gt': 55 | fpaths = [] 56 | for iCam in self.iCams: 57 | fpaths += sorted(glob(osp.join(path, 'camera' + str(iCam), '*.jpg'))) 58 | else: 59 | fpaths = sorted(glob(osp.join(path, '*.jpg'))) 60 | for fpath in fpaths: 61 | fname = osp.basename(fpath) 62 | if type == 'tracking_det': 63 | cam, frame = map(int, pattern.search(fname).groups()) 64 | pid = 8000 65 | else: 66 | pid, cam = map(int, pattern.search(fname).groups()) 67 | if type == 'tracking_gt': 68 | fname = osp.join('camera' 
+ str(cam), osp.basename(fpath)) 69 | if pid == -1: continue 70 | if relabel: 71 | if pid not in all_pids: 72 | all_pids[pid] = len(all_pids) 73 | else: 74 | if pid not in all_pids: 75 | all_pids[pid] = pid 76 | pid = all_pids[pid] 77 | ret.append((fname, pid, cam - 1)) 78 | return ret, int(len(all_pids)) 79 | 80 | def load(self): 81 | self.train, self.num_train_ids = self.preprocess(self.train_path, True, self.data_type) 82 | self.gallery, self.num_gallery_ids = self.preprocess(self.gallery_path, False, self.data_type) 83 | self.query, self.num_query_ids = self.preprocess(self.query_path, False, self.data_type) 84 | self.camstyle, self.num_camstyle_ids = self.preprocess(self.camstyle_path, True, self.data_type) 85 | 86 | print(self.__class__.__name__, "dataset loaded") 87 | print(" subset | # ids | # images") 88 | print(" ---------------------------") 89 | print(" train | {:5d} | {:8d}" 90 | .format(self.num_train_ids, len(self.train))) 91 | print(" query | {:5d} | {:8d}" 92 | .format(self.num_query_ids, len(self.query))) 93 | print(" gallery | {:5d} | {:8d}" 94 | .format(self.num_gallery_ids, len(self.gallery))) 95 | print(" camstyle | {:5d} | {:8d}" 96 | .format(self.num_camstyle_ids, len(self.camstyle))) 97 | -------------------------------------------------------------------------------- /reid/datasets/market1501.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | 3 | import os.path as osp 4 | import re 5 | from glob import glob 6 | 7 | 8 | class Market1501(object): 9 | def __init__(self, root): 10 | self.root = osp.join(root, 'Market-1501-v15.09.15') 11 | self.train_path = osp.join(root, 'Market-1501-v15.09.15/bounding_box_train') 12 | self.gallery_path = osp.join(root, 'Market-1501-v15.09.15/bounding_box_test') 13 | self.query_path = osp.join(root, 'Market-1501-v15.09.15/query') 14 | self.camstyle_path = osp.join(root, 'Market-1501-v15.09.15/bounding_box_train_camstyle') 15 | self.train, self.query, self.gallery, self.camstyle = [], [], [], [] 16 | self.num_train_ids, self.num_query_ids, self.num_gallery_ids, self.num_camstyle_ids = 0, 0, 0, 0 17 | self.num_cams = 6 18 | self.load() 19 | 20 | def preprocess(self, path, relabel=True): 21 | pattern = re.compile(r'([-\d]+)_c(\d+)') 22 | all_pids = {} 23 | ret = [] 24 | fpaths = sorted(glob(osp.join(path, '*.jpg'))) 25 | for fpath in fpaths: 26 | fname = osp.basename(fpath) 27 | pid, cam = map(int, pattern.search(fname).groups()) 28 | if pid == -1: continue 29 | if relabel: 30 | if pid not in all_pids: 31 | all_pids[pid] = len(all_pids) 32 | else: 33 | if pid not in all_pids: 34 | all_pids[pid] = pid 35 | pid = all_pids[pid] 36 | ret.append((fname, pid, cam - 1)) 37 | return ret, int(len(all_pids)) 38 | 39 | def load(self): 40 | self.train, self.num_train_ids = self.preprocess(self.train_path) 41 | self.gallery, self.num_gallery_ids = self.preprocess(self.gallery_path, False) 42 | self.query, self.num_query_ids = self.preprocess(self.query_path, False) 43 | self.camstyle, self.num_camstyle_ids = self.preprocess(self.camstyle_path) 44 | 45 | print(self.__class__.__name__, "dataset loaded") 46 | print(" subset | # ids | # images") 47 | print(" ---------------------------") 48 | print(" train | {:5d} | {:8d}" 49 | .format(self.num_train_ids, len(self.train))) 50 | print(" query | {:5d} | {:8d}" 51 | .format(self.num_query_ids, len(self.query))) 52 | print(" gallery | {:5d} | {:8d}" 53 | .format(self.num_gallery_ids, len(self.gallery))) 54 | 
print(" camstyle | {:5d} | {:8d}" 55 | .format(self.num_camstyle_ids, len(self.camstyle))) 56 | -------------------------------------------------------------------------------- /reid/datasets/veri.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | 3 | import os.path as osp 4 | import re 5 | from glob import glob 6 | 7 | 8 | class VeRi(object): 9 | def __init__(self, root): 10 | self.root = osp.join(root, 'VeRi') 11 | self.train_path = osp.join(root, 'VeRi/image_train/') 12 | self.gallery_path = osp.join(root, 'VeRi/image_test/') 13 | self.query_path = osp.join(root, 'VeRi/image_query/') 14 | 15 | self.train, self.query, self.gallery = [], [], [] 16 | self.num_train_ids, self.num_query_ids, self.num_gallery_ids = 0, 0, 0 17 | self.num_cams = 20 18 | 19 | self.load() 20 | 21 | def preprocess(self, path): 22 | pattern = re.compile(r'(\d+)_c(\d+)') 23 | all_pids = {} 24 | ret = [] 25 | fpaths = sorted(glob(osp.join(path, '*.jpg'))) 26 | for fpath in fpaths: 27 | fname = osp.basename(fpath) 28 | pid, cam = map(int, pattern.search(fname).groups()) 29 | if pid == -1: continue 30 | if pid not in all_pids: 31 | all_pids[pid] = len(all_pids) 32 | pid = all_pids[pid] 33 | ret.append((fname, pid, cam - 1)) 34 | return ret, int(len(all_pids)) 35 | 36 | def load(self): 37 | self.train, self.num_train_ids = self.preprocess(self.train_path) 38 | self.gallery, self.num_gallery_ids = self.preprocess(self.gallery_path) 39 | self.query, self.num_query_ids = self.preprocess(self.query_path) 40 | 41 | print(self.__class__.__name__, "dataset loaded") 42 | print(" subset | # ids | # images") 43 | print(" ---------------------------") 44 | print(" train | {:5d} | {:8d}" 45 | .format(self.num_train_ids, len(self.train))) 46 | print(" query | {:5d} | {:8d}" 47 | .format(self.num_query_ids, len(self.query))) 48 | print(" gallery | {:5d} | {:8d}" 49 | .format(self.num_gallery_ids, len(self.gallery))) 50 | -------------------------------------------------------------------------------- /reid/datasets/veri_affinity/affinity_matrix.txt: -------------------------------------------------------------------------------- 1 | 0 1 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 2 | 1 0 0 0 0 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 3 | 0 0 0 1 0 0 0 0 1 1 0 0 0 0 0 0 0 0 0 0 4 | 0 0 1 0 0 0 0 0 1 1 0 0 0 0 0 0 0 0 0 0 5 | 0 0 0 0 0 0 1 1 0 0 0 0 0 0 0 0 0 0 0 0 6 | 0 1 0 0 0 0 0 0 0 1 1 0 0 0 0 0 0 0 0 0 7 | 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 8 | 0 0 0 0 1 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 9 | 0 0 1 1 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 10 | 0 0 1 1 0 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 11 | 0 1 0 0 0 1 0 0 0 1 0 0 0 0 0 0 0 0 0 0 12 | 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 1 1 0 0 13 | 1 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 14 | 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 0 0 0 0 15 | 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 1 0 16 | 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 17 | 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 1 0 18 | 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 19 | 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 1 0 0 1 20 | 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 -------------------------------------------------------------------------------- /reid/evaluation_metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .classification import accuracy 4 | from .ranking import cmc, mean_ap 5 | 6 | __all__ = [ 7 | 'accuracy', 8 | 'cmc', 9 | 'mean_ap', 10 | ] 11 | 
-------------------------------------------------------------------------------- /reid/evaluation_metrics/classification.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from ..utils import to_torch 4 | 5 | 6 | def accuracy(output, target, topk=(1,)): 7 | output, target = to_torch(output), to_torch(target) 8 | maxk = max(topk) 9 | batch_size = target.size(0) 10 | 11 | _, pred = output.topk(maxk, 1, True, True) 12 | pred = pred.t() 13 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 14 | 15 | ret = [] 16 | for k in topk: 17 | correct_k = correct[:k].reshape(-1).float().sum(dim=0, keepdim=True) # reshape, not view: `correct` inherits non-contiguous strides from `pred.t()` 18 | ret.append(correct_k.mul_(1. / batch_size)) 19 | return ret 20 | -------------------------------------------------------------------------------- /reid/evaluation_metrics/ranking.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from collections import defaultdict 4 | 5 | import numpy as np 6 | from sklearn.metrics import average_precision_score 7 | 8 | from ..utils import to_numpy 9 | 10 | 11 | def _unique_sample(ids_dict, num): 12 | mask = np.zeros(num, dtype=bool) # the `np.bool` alias was removed in NumPy 1.24 13 | for _, indices in ids_dict.items(): 14 | i = np.random.choice(indices) 15 | mask[i] = True 16 | return mask 17 | 18 | 19 | def cmc(distmat, query_ids=None, gallery_ids=None, 20 | query_cams=None, gallery_cams=None, topk=100, 21 | separate_camera_set=False, 22 | single_gallery_shot=False, 23 | first_match_break=False): 24 | distmat = to_numpy(distmat) 25 | m, n = distmat.shape 26 | # Fill up default values 27 | if query_ids is None: 28 | query_ids = np.arange(m) 29 | if gallery_ids is None: 30 | gallery_ids = np.arange(n) 31 | if query_cams is None: 32 | query_cams = np.zeros(m).astype(np.int32) 33 | if gallery_cams is None: 34 | gallery_cams = np.ones(n).astype(np.int32) 35 | # Ensure numpy array 36 | query_ids = np.asarray(query_ids) 37 | gallery_ids = np.asarray(gallery_ids) 38 | query_cams = np.asarray(query_cams) 39 | gallery_cams = np.asarray(gallery_cams) 40 | # Sort and find correct matches 41 | indices = np.argsort(distmat, axis=1) 42 | matches = (gallery_ids[indices] == query_ids[:, np.newaxis]) 43 | # Compute CMC for each query 44 | ret = np.zeros(topk) 45 | num_valid_queries = 0 46 | for i in range(m): 47 | # Filter out the same id and same camera 48 | valid = ((gallery_ids[indices[i]] != query_ids[i]) | 49 | (gallery_cams[indices[i]] != query_cams[i])) 50 | if separate_camera_set: 51 | # Filter out samples from same camera 52 | valid &= (gallery_cams[indices[i]] != query_cams[i]) 53 | if not np.any(matches[i, valid]): continue 54 | if single_gallery_shot: 55 | repeat = 10 56 | gids = gallery_ids[indices[i][valid]] 57 | inds = np.where(valid)[0] 58 | ids_dict = defaultdict(list) 59 | for j, x in zip(inds, gids): 60 | ids_dict[x].append(j) 61 | else: 62 | repeat = 1 63 | for _ in range(repeat): 64 | if single_gallery_shot: 65 | # Randomly choose one instance for each id 66 | sampled = (valid & _unique_sample(ids_dict, len(valid))) 67 | index = np.nonzero(matches[i, sampled])[0] 68 | else: 69 | index = np.nonzero(matches[i, valid])[0] 70 | delta = 1. 
/ (len(index) * repeat) 71 | for j, k in enumerate(index): 72 | if k - j >= topk: break 73 | if first_match_break: 74 | ret[k - j] += 1 75 | break 76 | ret[k - j] += delta 77 | num_valid_queries += 1 78 | if num_valid_queries == 0: 79 | raise RuntimeError("No valid query") 80 | return ret.cumsum() / num_valid_queries 81 | 82 | 83 | def mean_ap(distmat, query_ids=None, gallery_ids=None, 84 | query_cams=None, gallery_cams=None): 85 | distmat = to_numpy(distmat) 86 | m, n = distmat.shape 87 | # Fill up default values 88 | if query_ids is None: 89 | query_ids = np.arange(m) 90 | if gallery_ids is None: 91 | gallery_ids = np.arange(n) 92 | if query_cams is None: 93 | query_cams = np.zeros(m).astype(np.int32) 94 | if gallery_cams is None: 95 | gallery_cams = np.ones(n).astype(np.int32) 96 | # Ensure numpy array 97 | query_ids = np.asarray(query_ids) 98 | gallery_ids = np.asarray(gallery_ids) 99 | query_cams = np.asarray(query_cams) 100 | gallery_cams = np.asarray(gallery_cams) 101 | # Sort and find correct matches 102 | indices = np.argsort(distmat, axis=1) 103 | matches = (gallery_ids[indices] == query_ids[:, np.newaxis]) 104 | # Compute AP for each query 105 | aps = [] 106 | for i in range(m): 107 | # Filter out the same id and same camera 108 | valid = ((gallery_ids[indices[i]] != query_ids[i]) | 109 | (gallery_cams[indices[i]] != query_cams[i])) 110 | y_true = matches[i, valid] 111 | y_score = -distmat[i][indices[i]][valid] 112 | if not np.any(y_true): continue 113 | aps.append(average_precision_score(y_true, y_score)) 114 | if len(aps) == 0: 115 | raise RuntimeError("No valid query") 116 | return np.mean(aps) 117 | -------------------------------------------------------------------------------- /reid/evaluators.py: -------------------------------------------------------------------------------- 1 | import time 2 | import copy 3 | from collections import OrderedDict 4 | import torch 5 | from .evaluation_metrics import cmc, mean_ap 6 | from .feature_extraction import extract_cnn_feature 7 | from .utils.meters import AverageMeter 8 | 9 | 10 | def extract_features(model, data_loader, print_freq=100): 11 | model.eval() 12 | batch_time = AverageMeter() 13 | data_time = AverageMeter() 14 | 15 | features = OrderedDict() 16 | labels = OrderedDict() 17 | 18 | end = time.time() 19 | for i, (imgs, fnames, pids, _) in enumerate(data_loader): 20 | data_time.update(time.time() - end) 21 | 22 | outputs = extract_cnn_feature(model, imgs) 23 | for fname, output, pid in zip(fnames, outputs, pids): 24 | features[fname] = output 25 | labels[fname] = pid 26 | 27 | batch_time.update(time.time() - end) 28 | end = time.time() 29 | 30 | if (i + 1) % print_freq == 0 or (i + 1) == len(data_loader): 31 | print('Extract Features: [{}/{}]\t' 32 | 'Time {:.3f} ({:.3f})\t' 33 | 'Data {:.3f} ({:.3f})\t' 34 | .format(i + 1, len(data_loader), 35 | batch_time.val, batch_time.avg, 36 | data_time.val, data_time.avg)) 37 | 38 | return features, labels 39 | 40 | 41 | def pairwise_distance(query_features, gallery_features, query=None, gallery=None): 42 | if query is not None and gallery is not None: 43 | x = torch.cat([query_features[f].unsqueeze(0) for f, _, _ in query], 0) 44 | y = torch.cat([gallery_features[f].unsqueeze(0) for f, _, _ in gallery], 0) 45 | else: 46 | x = copy.deepcopy(query_features) 47 | y = copy.deepcopy(gallery_features) 48 | m, n = x.size(0), y.size(0) 49 | x = x.view(m, -1) 50 | y = y.view(n, -1) 51 | dist = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) + \ 52 | torch.pow(y, 2).sum(dim=1, 
keepdim=True).expand(n, m).t() 53 | dist.addmm_(x, y.t(), beta=1, alpha=-2) # dist += -2 * x @ y.t(); the positional addmm_(beta, alpha, ...) form was removed in recent PyTorch 54 | return dist 55 | 56 | 57 | def evaluate_all(distmat, query=None, gallery=None, 58 | query_ids=None, gallery_ids=None, 59 | query_cams=None, gallery_cams=None, 60 | cmc_topk=(1, 5, 10)): 61 | if query is not None and gallery is not None: 62 | query_ids = [pid for _, pid, _ in query] 63 | gallery_ids = [pid for _, pid, _ in gallery] 64 | query_cams = [cam for _, _, cam in query] 65 | gallery_cams = [cam for _, _, cam in gallery] 66 | else: 67 | assert (query_ids is not None and gallery_ids is not None 68 | and query_cams is not None and gallery_cams is not None) 69 | 70 | # Compute mean AP 71 | mAP = mean_ap(distmat, query_ids, gallery_ids, query_cams, gallery_cams) 72 | # print('Mean AP: {:4.1%}'.format(mAP)) 73 | 74 | # Compute all kinds of CMC scores 75 | cmc_configs = { 76 | # 'allshots': dict(separate_camera_set=False, 77 | # single_gallery_shot=False, 78 | # first_match_break=False), 79 | # 'cuhk03': dict(separate_camera_set=True, 80 | # single_gallery_shot=True, 81 | # first_match_break=False), 82 | 'market1501': dict(separate_camera_set=False, 83 | single_gallery_shot=False, 84 | first_match_break=True)} 85 | cmc_scores = {name: cmc(distmat, query_ids, gallery_ids, 86 | query_cams, gallery_cams, **params) 87 | for name, params in cmc_configs.items()} 88 | 89 | print('[mAP: {:5.2%}], [cmc1: {:5.2%}], [cmc5: {:5.2%}], [cmc10: {:5.2%}]' 90 | .format(mAP, *cmc_scores['market1501'][[0, 4, 9]])) 91 | 92 | # Use the market1501 cmc top-1 score as the validation criterion 93 | return cmc_scores['market1501'][0] 94 | 95 | 96 | class Evaluator(object): 97 | def __init__(self, model): 98 | super(Evaluator, self).__init__() 99 | self.model = model 100 | 101 | def evaluate(self, query_loader, gallery_loader, query, gallery, ): 102 | self.model.eval() 103 | print('extracting query features\n') 104 | query_features, _ = extract_features(self.model, query_loader) 105 | print('extracting gallery features\n') 106 | gallery_features, _ = extract_features(self.model, gallery_loader) 107 | distmat = pairwise_distance(query_features, gallery_features, query, gallery) 108 | return evaluate_all(distmat, query=query, gallery=gallery) 109 | -------------------------------------------------------------------------------- /reid/feature_extraction/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .cnn import extract_cnn_feature 4 | from .database import FeatureDatabase 5 | 6 | __all__ = [ 7 | 'extract_cnn_feature', 8 | 'FeatureDatabase', 9 | ] 10 | -------------------------------------------------------------------------------- /reid/feature_extraction/cnn.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from collections import OrderedDict 4 | 5 | import torch 6 | from torch.autograd import Variable 7 | 8 | from ..utils import to_torch 9 | 10 | 11 | def extract_cnn_feature(model, inputs, modules=None): 12 | model.eval() 13 | inputs = to_torch(inputs) 14 | inputs = Variable(inputs, requires_grad=False) 15 | if modules is None: 16 | # if isinstance(model.module, IDE_model) or isinstance(model.module, PCB_model): 17 | with torch.no_grad(): 18 | outputs = model(inputs) 19 | outputs = outputs[0] 20 | # else: 21 | # outputs = model(inputs) 22 | outputs = outputs.data.cpu() 23 | return outputs 24 | # Register forward hook for each module 25 | outputs = OrderedDict() 26 | 
handles = [] 27 | for m in modules: 28 | outputs[id(m)] = None 29 | 30 | def func(m, i, o): outputs[id(m)] = o.data.cpu() 31 | 32 | handles.append(m.register_forward_hook(func)) 33 | model(inputs) 34 | for h in handles: 35 | h.remove() 36 | return list(outputs.values()) 37 | -------------------------------------------------------------------------------- /reid/feature_extraction/database.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import h5py 4 | import numpy as np 5 | from torch.utils.data import Dataset 6 | 7 | 8 | class FeatureDatabase(Dataset): 9 | def __init__(self, *args, **kwargs): 10 | super(FeatureDatabase, self).__init__() 11 | self.fid = h5py.File(*args, **kwargs) 12 | 13 | def __enter__(self): 14 | return self 15 | 16 | def __exit__(self, exc_type, exc_val, exc_tb): 17 | self.close() 18 | 19 | def __getitem__(self, keys): 20 | if isinstance(keys, (tuple, list)): 21 | return [self._get_single_item(k) for k in keys] 22 | return self._get_single_item(keys) 23 | 24 | def _get_single_item(self, key): 25 | return np.asarray(self.fid[key]) 26 | 27 | def __setitem__(self, key, value): 28 | if key in self.fid: 29 | if self.fid[key].shape == value.shape and \ 30 | self.fid[key].dtype == value.dtype: 31 | self.fid[key][...] = value 32 | else: 33 | del self.fid[key] 34 | self.fid.create_dataset(key, data=value) 35 | else: 36 | self.fid.create_dataset(key, data=value) 37 | 38 | def __delitem__(self, key): 39 | del self.fid[key] 40 | 41 | def __len__(self): 42 | return len(self.fid) 43 | 44 | def __iter__(self): 45 | return iter(self.fid) 46 | 47 | def flush(self): 48 | self.fid.flush() 49 | 50 | def close(self): 51 | self.fid.close() 52 | -------------------------------------------------------------------------------- /reid/loss/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .label_smooth import LSR_loss 4 | from .triplet import TripletLoss 5 | 6 | __all__ = [ 7 | 'TripletLoss', 8 | 'LSR_loss' 9 | ] 10 | -------------------------------------------------------------------------------- /reid/loss/label_smooth.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | 3 | import torch 4 | from torch import nn 5 | 6 | 7 | class LSR_loss(nn.Module): 8 | 9 | def __init__(self, e=0.1): 10 | super().__init__() 11 | self.log_softmax = nn.LogSoftmax(dim=1) 12 | self.e = e 13 | 14 | def _one_hot(self, labels, classes, value=1): 15 | one_hot = torch.zeros(labels.size(0), classes) 16 | # labels and value_added size must match 17 | labels = labels.view(labels.size(0), -1) 18 | value_added = torch.Tensor(labels.size(0), 1).fill_(value) 19 | value_added = value_added.to(labels.device) 20 | one_hot = one_hot.to(labels.device) 21 | one_hot.scatter_add_(1, labels, value_added) 22 | return one_hot 23 | 24 | def _smooth_label(self, target, length, smooth_factor): 25 | one_hot = self._one_hot(target, length, value=1 - smooth_factor) 26 | one_hot += smooth_factor / length 27 | return one_hot.to(target.device) 28 | 29 | def forward(self, x, target): 30 | smoothed_target = self._smooth_label(target, x.size(1), self.e) 31 | x = self.log_softmax(x) 32 | loss = torch.sum(- x * smoothed_target, dim=1) 33 | return torch.mean(loss) 34 | -------------------------------------------------------------------------------- /reid/loss/triplet.py: 
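The triplet loss below mines, for each anchor in the batch, the hardest positive and hardest negative by distance ("batch hard" mining). A minimal sketch of driving it with an identity-balanced P x K batch (8 ids x 4 images here; in this repo such batches come from the samplers under reid/utils/data; the random features are a stand-in for model embeddings):

import torch
from reid.loss import TripletLoss

criterion = TripletLoss(margin=0.3)            # margin=None falls back to nn.SoftMarginLoss
feats = torch.randn(32, 256)                   # stand-in for 256-dim embeddings
labels = torch.arange(8).repeat_interleave(4)  # 8 identities x 4 images each
loss, prec, dist_ap, dist_an = criterion(feats, labels)
print('loss {:.3f}, triplet prec {:.1%}'.format(loss.item(), prec.item()))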
-------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | from torch import nn 5 | 6 | 7 | def normalize(x, axis=-1): 8 | """Normalizing to unit length along the specified dimension. 9 | Args: 10 | x: pytorch Variable 11 | Returns: 12 | x: pytorch Variable, same shape as input 13 | """ 14 | x = 1. * x / (torch.norm(x, 2, axis, keepdim=True).expand_as(x) + 1e-12) 15 | return x 16 | 17 | 18 | def euclidean_dist(x, y): 19 | """ 20 | Args: 21 | x: pytorch Variable, with shape [m, d] 22 | y: pytorch Variable, with shape [n, d] 23 | Returns: 24 | dist: pytorch Variable, with shape [m, n] 25 | """ 26 | m, n = x.size(0), y.size(0) 27 | xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n) 28 | yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t() 29 | dist = xx + yy 30 | dist.addmm_(x, y.t(), beta=1, alpha=-2) # dist = xx + yy - 2 * x @ y.t(); updated from the removed positional addmm_ signature 31 | dist = dist.clamp(min=1e-12).sqrt() # for numerical stability 32 | return dist 33 | 34 | 35 | def hard_example_mining(dist_mat, labels, return_inds=False): 36 | """For each anchor, find the hardest positive and negative sample. 37 | Args: 38 | dist_mat: pytorch Variable, pair wise distance between samples, shape [N, N] 39 | labels: pytorch LongTensor, with shape [N] 40 | return_inds: whether to also return the indices; slightly faster if `False` 41 | Returns: 42 | dist_ap: pytorch Variable, distance(anchor, positive); shape [N] 43 | dist_an: pytorch Variable, distance(anchor, negative); shape [N] 44 | p_inds: pytorch LongTensor, with shape [N]; 45 | indices of selected hard positive samples; 0 <= p_inds[i] <= N - 1 46 | n_inds: pytorch LongTensor, with shape [N]; 47 | indices of selected hard negative samples; 0 <= n_inds[i] <= N - 1 48 | NOTE: Only consider the case in which all labels have same num of samples, 49 | thus we can cope with all anchors in parallel. 
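Example (illustrative only; an identity-balanced batch of 2 ids x 2 images):
>>> feats = torch.randn(4, 128)
>>> labels = torch.tensor([0, 0, 1, 1])
>>> dist_mat = euclidean_dist(feats, feats)
>>> dist_ap, dist_an = hard_example_mining(dist_mat, labels)
>>> dist_ap.shape, dist_an.shape
(torch.Size([4]), torch.Size([4]))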
50 | """ 51 | 52 | assert len(dist_mat.size()) == 2 53 | assert dist_mat.size(0) == dist_mat.size(1) 54 | N = dist_mat.size(0) 55 | 56 | # shape [N, N] 57 | is_pos = labels.expand(N, N).eq(labels.expand(N, N).t()) 58 | is_neg = labels.expand(N, N).ne(labels.expand(N, N).t()) 59 | 60 | # `dist_ap` means distance(anchor, positive) 61 | # both `dist_ap` and `relative_p_inds` with shape [N, 1] 62 | dist_ap, relative_p_inds = torch.max(dist_mat[is_pos].contiguous().view(N, -1), 1, keepdim=True) 63 | # `dist_an` means distance(anchor, negative) 64 | # both `dist_an` and `relative_n_inds` with shape [N, 1] 65 | dist_an, relative_n_inds = torch.min(dist_mat[is_neg].contiguous().view(N, -1), 1, keepdim=True) 66 | # shape [N] 67 | dist_ap = dist_ap.squeeze(1) 68 | dist_an = dist_an.squeeze(1) 69 | 70 | if return_inds: 71 | # shape [N, N] 72 | ind = (labels.new().resize_as_(labels).copy_(torch.arange(0, N).long()).unsqueeze(0).expand(N, N)) 73 | # shape [N, 1] 74 | p_inds = torch.gather(ind[is_pos].contiguous().view(N, -1), 1, relative_p_inds.data) 75 | n_inds = torch.gather(ind[is_neg].contiguous().view(N, -1), 1, relative_n_inds.data) 76 | # shape [N] 77 | p_inds = p_inds.squeeze(1) 78 | n_inds = n_inds.squeeze(1) 79 | return dist_ap, dist_an, p_inds, n_inds 80 | 81 | return dist_ap, dist_an 82 | 83 | 84 | class TripletLoss(nn.Module): 85 | def __init__(self, margin=None): 86 | super(TripletLoss, self).__init__() 87 | self.margin = margin 88 | if self.margin is not None: 89 | self.ranking_loss = nn.MarginRankingLoss(margin=margin) 90 | else: 91 | self.ranking_loss = nn.SoftMarginLoss() 92 | 93 | def forward(self, global_feat, labels, normalize_feature=False): 94 | if normalize_feature: 95 | global_feat = normalize(global_feat, axis=-1) 96 | # shape [N, N] 97 | dist_mat = euclidean_dist(global_feat, global_feat) 98 | dist_ap, dist_an = hard_example_mining(dist_mat, labels) 99 | y = dist_an.new().resize_as_(dist_an).fill_(1) 100 | if self.margin is not None: 101 | loss = self.ranking_loss(dist_an, dist_ap, y) 102 | else: 103 | loss = self.ranking_loss(dist_an - dist_ap, y) 104 | prec = (dist_an.data > dist_ap.data).sum().float() / y.size(0) 105 | return loss, prec, dist_ap, dist_an 106 | -------------------------------------------------------------------------------- /reid/metric/MLP_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn import init 5 | import numpy as np 6 | 7 | 8 | class MLP_metric(nn.Module): 9 | def __init__(self, feature_dim=256, num_class=0): 10 | super(MLP_metric, self).__init__() 11 | self.num_class = num_class 12 | layer_dim = 128 13 | 14 | self.fc1 = nn.Sequential(nn.Linear(feature_dim, layer_dim), nn.ReLU()) 15 | self.fc2 = nn.Sequential(nn.Linear(layer_dim, layer_dim), nn.ReLU()) 16 | self.fc3 = nn.Sequential(nn.Linear(layer_dim, layer_dim), nn.ReLU()) 17 | self.dropout = nn.Dropout(0.5) 18 | self.classifier = nn.Linear(layer_dim, self.num_class, bias=False) 19 | init.normal_(self.classifier.weight, std=0.001) 20 | 21 | def forward(self, feat1, feat2): 22 | out = self.fc1((feat2 - feat1).abs()) 23 | out = self.fc2(out) 24 | out = self.fc3(out) 25 | out = self.dropout(out) 26 | out = self.classifier(out) 27 | return out 28 | -------------------------------------------------------------------------------- /reid/metric/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/hou-yz/open-reid-tracking/7a40462231e27e607dd094412550f26270b4efb8/reid/metric/__init__.py -------------------------------------------------------------------------------- /reid/metric/metric_evaluate.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | from reid.evaluators import evaluate_all, pairwise_distance 5 | 6 | 7 | def metric_distance(model, query_features, gallery_features): 8 | dist = np.zeros([len(query_features), len(gallery_features)]) 9 | step = 1024 10 | for i in range(len(query_features)): 11 | for j in range(0, len(gallery_features), step): 12 | numel = min(step, len(gallery_features) - j) 13 | query_feat = query_features[i].view([1, -1]).repeat([numel, 1]).cuda() 14 | gallery_feat = gallery_features[j:j + numel].cuda() 15 | output = model(query_feat, gallery_feat) 16 | dist[i, j:j + numel] = F.softmax(output, dim=1)[:, 0].cpu().detach().numpy() 17 | return dist 18 | 19 | 20 | def metric_evaluate(model, query_set, gallery_set): 21 | model.eval() 22 | print('=> L2 distance') 23 | dist = pairwise_distance(query_set.features, gallery_set.features) 24 | evaluate_all(dist, query_ids=query_set.labels[:, 1], gallery_ids=gallery_set.labels[:, 1], 25 | query_cams=query_set.labels[:, 0], gallery_cams=gallery_set.labels[:, 0], ) 26 | print('=> Metric') 27 | dist = metric_distance(model, query_set.features, gallery_set.features) 28 | evaluate_all(dist, query_ids=query_set.labels[:, 1], gallery_ids=gallery_set.labels[:, 1], 29 | query_cams=query_set.labels[:, 0], gallery_cams=gallery_set.labels[:, 0], ) 30 | return 31 | -------------------------------------------------------------------------------- /reid/metric/metric_trainer.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | 4 | 5 | class BaseTrainer(object): 6 | def __init__(self): 7 | super(BaseTrainer, self).__init__() 8 | 9 | 10 | class CNNTrainer(BaseTrainer): 11 | def __init__(self, model, criterion, ): 12 | super(BaseTrainer, self).__init__() 13 | self.model = model 14 | self.criterion = criterion 15 | 16 | def train(self, epoch, data_loader, optimizer, log_interval=100, cyclic_scheduler=None, ): 17 | self.model.train() 18 | losses = 0 19 | correct = 0 20 | miss = 0 21 | t0 = time.time() 22 | for batch_idx, (data, target) in enumerate(data_loader): 23 | feat1, feat2, target = data[0].cuda(), data[1].cuda(), target.cuda() 24 | optimizer.zero_grad() 25 | output = self.model(feat1, feat2) 26 | pred = torch.argmax(output, 1) 27 | correct += pred.eq(target).sum().item() 28 | miss += target.shape[0] - pred.eq(target).sum().item() 29 | loss = self.criterion(output, target) 30 | loss.backward() 31 | optimizer.step() 32 | losses += loss.item() 33 | if cyclic_scheduler is not None: 34 | if isinstance(cyclic_scheduler, torch.optim.lr_scheduler.CosineAnnealingWarmRestarts): 35 | cyclic_scheduler.step(epoch - 1 + batch_idx / len(data_loader)) 36 | elif isinstance(cyclic_scheduler, torch.optim.lr_scheduler.OneCycleLR): 37 | cyclic_scheduler.step() 38 | if (batch_idx + 1) % log_interval == 0: 39 | # print(cyclic_scheduler.last_epoch, optimizer.param_groups[0]['lr']) 40 | t1 = time.time() 41 | t_epoch = t1 - t0 42 | print('Train Epoch: {}, Batch:{}, \tLoss: {:.6f}, Prec: {:.1f}%, Time: {:.3f}'.format( 43 | epoch, (batch_idx + 1), losses / (batch_idx + 1), 100. 
* correct / (correct + miss), t_epoch)) 44 | 45 | t1 = time.time() 46 | t_epoch = t1 - t0 47 | print('Train Epoch: {}, Batch:{}, \tLoss: {:.6f}, Prec: {:.1f}%, Time: {:.3f}'.format( 48 | epoch, len(data_loader), losses / len(data_loader), 100. * correct / (correct + miss), t_epoch)) 49 | 50 | return losses / len(data_loader), correct / (correct + miss) 51 | 52 | def test(self, test_loader): 53 | self.model.eval() 54 | losses = 0 55 | correct = 0 56 | miss = 0 57 | t0 = time.time() 58 | 59 | for batch_idx, (data, target) in enumerate(test_loader): 60 | feat1, feat2, target = data[0].cuda(), data[1].cuda(), target.cuda() 61 | with torch.no_grad(): 62 | output = self.model(feat1, feat2) 63 | pred = torch.argmax(output, 1) 64 | correct += pred.eq(target).sum().item() 65 | miss += target.shape[0] - pred.eq(target).sum().item() 66 | loss = self.criterion(output, target) 67 | losses += loss.item() 68 | 69 | print('Test, Loss: {:.6f}, Prec: {:.1f}%, time: {:.1f}'.format(losses / (len(test_loader) + 1), 70 | 100. * correct / (correct + miss), 71 | time.time() - t0)) 72 | 73 | return losses / len(test_loader), correct / (correct + miss) 74 | -------------------------------------------------------------------------------- /reid/metric/reid_feat_dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import h5py 4 | from glob import glob 5 | import os.path as osp 6 | from collections import defaultdict 7 | from torch.utils.data import Dataset 8 | 9 | 10 | class HyperFeat(Dataset): 11 | def __init__(self, root, ): 12 | self.root = root 13 | self.data = [] 14 | fpaths = sorted(glob(osp.join(root, '*.h5'))) 15 | for fpath in fpaths: 16 | h5file = h5py.File(fpath, 'r') 17 | self.data.append(np.array(h5file['emb'])) 18 | self.data = np.concatenate(self.data, axis=0) 19 | self.data = self.data[self.data[:, 1] != -1, :] # rm -1 terms 20 | self.features, self.labels = torch.from_numpy(self.data[:, 3:]).float(), self.data[:, :3] 21 | del self.data 22 | # iCam, pid, centerFrame, 256-dim feat 23 | self.feature_dim = self.features.shape[1] 24 | self.pid_dic = [] 25 | self.index_by_icam_pid_dic = defaultdict(dict) 26 | self.index_by_pid_dic = defaultdict(list) 27 | self.index_by_pid_icam_dic = defaultdict(dict) 28 | for index in range(self.labels.shape[0]): 29 | [icam, pid, frame] = self.labels[index, :] 30 | if pid not in self.pid_dic: 31 | self.pid_dic.append(pid) 32 | 33 | if icam not in self.index_by_icam_pid_dic: 34 | self.index_by_icam_pid_dic[icam] = defaultdict(list) 35 | self.index_by_icam_pid_dic[icam][pid].append(index) 36 | 37 | self.index_by_pid_dic[pid].append(index) 38 | 39 | if pid not in self.index_by_pid_icam_dic: 40 | self.index_by_pid_icam_dic[pid] = defaultdict(list) 41 | self.index_by_pid_icam_dic[pid][icam].append(index) 42 | 43 | pass 44 | 45 | def __getitem__(self, index): 46 | feat = self.features[index, :] 47 | iCam, pid, frame = map(int, self.labels[index, :]) 48 | return feat, iCam, pid, frame 49 | 50 | def __len__(self): 51 | return self.labels.shape[0] 52 | 53 | 54 | class SiameseHyperFeat(Dataset): 55 | def __init__(self, h_dataset, ): 56 | self.h_dataset = h_dataset 57 | self.feature_dim = self.h_dataset.feature_dim 58 | 59 | def __len__(self): 60 | return len(self.h_dataset) 61 | 62 | def __getitem__(self, index): 63 | feat1, cam1, pid1, frame1 = self.h_dataset.__getitem__(index) 64 | target = np.random.randint(0, 2) 65 | if pid1 == -1: 66 | target = 0 67 | 68 | # 1 for same 69 | if target == 1: 70 | 
siamese_index = index 71 | index_pool = self.h_dataset.index_by_pid_dic[pid1] 72 | if len(index_pool) > 1: 73 | while siamese_index == index: 74 | siamese_index = np.random.choice(index_pool) 75 | # 0 for different 76 | else: 77 | pid_pool = self.h_dataset.pid_dic 78 | pid2 = np.random.choice(pid_pool) 79 | if len(pid_pool) > 1: 80 | while pid2 == pid1: 81 | pid2 = np.random.choice(pid_pool) 82 | index_pool = self.h_dataset.index_by_pid_dic[pid2] 83 | siamese_index = np.random.choice(index_pool) 84 | 85 | feat2, cam2, pid2, frame2 = self.h_dataset.__getitem__(siamese_index) 86 | if target != (pid1 == pid2): 87 | target = (pid1 == pid2) 88 | pass 89 | 90 | return (feat1, feat2), target 91 | -------------------------------------------------------------------------------- /reid/models/IDE_model.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from torch.nn import init 6 | from torchvision.models import resnet50, densenet121 7 | 8 | 9 | class IDE_model(nn.Module): 10 | def __init__(self, feature_dim=256, num_classes=0, norm=False, dropout=0, last_stride=2, arch='resnet50'): 11 | super(IDE_model, self).__init__() 12 | # Create IDE_only model 13 | self.feature_dim = feature_dim 14 | self.num_classes = num_classes 15 | self.norm = norm 16 | 17 | if arch == 'resnet50': 18 | self.base = nn.Sequential(*list(resnet50(pretrained=True).children())[:-2]) 19 | if last_stride != 2: 20 | # decrease the downsampling rate 21 | # change the stride2 conv layer in self.layer4 to stride=1 22 | self.base[7][0].conv2.stride = last_stride 23 | # change the downsampling layer in self.layer4 to stride=1 24 | self.base[7][0].downsample[0].stride = last_stride 25 | base_dim = 2048 26 | elif arch == 'densenet121': 27 | self.base = nn.Sequential(*list(densenet121(pretrained=True).children())[:-1])[0] 28 | if last_stride != 2: 29 | # remove the pooling layer in last transition block 30 | self.base[-3][-1].stride = 1 31 | self.base[-3][-1].kernel_size = 1 32 | pass 33 | base_dim = 1024 34 | else: 35 | raise Exception('Please select arch from [resnet50, densenet121]!') 36 | 37 | self.global_avg_pool = nn.AdaptiveAvgPool2d((1, 1)) 38 | 39 | # feat & feat_bn 40 | if self.feature_dim > 0: 41 | self.feat_fc = nn.Linear(base_dim, feature_dim) 42 | init.kaiming_normal_(self.feat_fc.weight, mode='fan_out') 43 | init.constant_(self.feat_fc.bias, 0.0) 44 | else: 45 | feature_dim = base_dim 46 | 47 | self.feat_bn = nn.BatchNorm1d(feature_dim) 48 | init.constant_(self.feat_bn.weight, 1) 49 | init.constant_(self.feat_bn.bias, 0) 50 | 51 | # dropout before classifier 52 | self.dropout = dropout 53 | if self.dropout > 0: 54 | self.drop_layer = nn.Dropout2d(self.dropout) 55 | 56 | # classifier: 57 | if self.num_classes > 0: 58 | self.classifier = nn.Linear(feature_dim, self.num_classes, bias=False) 59 | init.normal_(self.classifier.weight, std=0.001) 60 | pass 61 | 62 | def forward(self, x): 63 | """ 64 | Returns: 65 | h_s: each member with shape [N, c] 66 | prediction_s: each member with shape [N, num_classes] 67 | """ 68 | x = self.base(x) 69 | x = self.global_avg_pool(x).view(x.shape[0], -1) 70 | feature_base = x 71 | 72 | if self.feature_dim > 0: 73 | x = self.feat_fc(x) 74 | feature_ide = x 75 | else: 76 | feature_ide = feature_base 77 | 78 | x = self.feat_bn(x) 79 | # no relu after feature_fc 80 | 81 | if self.dropout > 0: 82 | x = self.drop_layer(x) 83 | 84 | prediction_s = [] 85 | 
if self.num_classes > 0 and self.training: 86 | prediction = self.classifier(x) 87 | prediction_s.append(prediction) 88 | 89 | if self.norm: 90 | feature_ide = F.normalize(feature_ide) 91 | 92 | return feature_ide, tuple(prediction_s) 93 | -------------------------------------------------------------------------------- /reid/models/PCB_model.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from torch.nn import init 6 | from torchvision.models import resnet50, densenet121 7 | 8 | 9 | class PCB_model(nn.Module): 10 | def __init__(self, num_stripes=6, feature_dim=256, num_classes=0, norm=False, dropout=0, last_stride=1, 11 | arch='resnet50'): 12 | super(PCB_model, self).__init__() 13 | # Create PCB_only model 14 | self.num_stripes = num_stripes 15 | self.feature_dim = feature_dim 16 | self.num_classes = num_classes 17 | self.norm = norm 18 | 19 | if arch == 'resnet50': 20 | self.base = nn.Sequential(*list(resnet50(pretrained=True).children())[:-2]) 21 | if last_stride != 2: 22 | # decrease the downsampling rate 23 | # change the stride2 conv layer in self.layer4 to stride=1 24 | self.base[7][0].conv2.stride = last_stride 25 | # change the downsampling layer in self.layer4 to stride=1 26 | self.base[7][0].downsample[0].stride = last_stride 27 | base_dim = 2048 28 | elif arch == 'densenet121': 29 | self.base = nn.Sequential(*list(densenet121(pretrained=True).children())[:-1])[0] 30 | if last_stride != 2: 31 | # remove the pooling layer in last transition block 32 | self.base[-3][-1].stride = 1 33 | self.base[-3][-1].kernel_size = 1 34 | pass 35 | base_dim = 1024 36 | else: 37 | raise Exception('Please select arch from [resnet50, densenet121]!') 38 | 39 | self.avg_pool = nn.AdaptiveAvgPool2d((6, 1)) 40 | 41 | # feat & feat_bn 42 | if self.feature_dim > 0: 43 | self.feat_fc = nn.Conv2d(base_dim, feature_dim, kernel_size=1, padding=0) 44 | init.kaiming_normal_(self.feat_fc.weight, mode='fan_out') 45 | init.constant_(self.feat_fc.bias, 0.0) 46 | else: 47 | feature_dim = base_dim 48 | 49 | self.feat_bn = nn.BatchNorm2d(feature_dim) 50 | init.constant_(self.feat_bn.weight, 1) 51 | init.constant_(self.feat_bn.bias, 0) 52 | 53 | # dropout before classifier 54 | self.dropout = dropout 55 | if self.dropout > 0: 56 | self.drop_layer = nn.Dropout2d(self.dropout) 57 | 58 | # 6 branches of classifiers: 59 | if self.num_classes > 0: 60 | self.classifier_s = nn.ModuleList() 61 | for _ in range(self.num_stripes): 62 | classifier = nn.Linear(feature_dim, self.num_classes, bias=False) 63 | init.normal_(classifier.weight, std=0.001) 64 | self.classifier_s.append(classifier) 65 | 66 | def forward(self, x): 67 | """ 68 | Returns: 69 | h_s: each member with shape [N, c] 70 | prediction_s: each member with shape [N, num_classes] 71 | """ 72 | x = self.base(x) 73 | x = self.avg_pool(x) 74 | 75 | if self.dropout > 0: 76 | x = self.drop_layer(x) 77 | 78 | feature_base = x / x.norm(2, 1).unsqueeze(1).expand_as(x) 79 | feature_base = feature_base.view(feature_base.shape[0], -1) 80 | 81 | if self.feature_dim > 0: 82 | x = self.feat_fc(x) 83 | feature_pcb = x / x.norm(2, 1).unsqueeze(1).expand_as(x) 84 | feature_pcb = feature_pcb.view(feature_pcb.shape[0], -1) 85 | else: 86 | feature_pcb = feature_base 87 | 88 | x = self.feat_bn(x) 89 | # no relu after feature_fc 90 | 91 | x_s = x.chunk(self.num_stripes, 2) 92 | prediction_s = [] 93 | if self.num_classes > 0 and 
self.training: 94 | for i in range(self.num_stripes): 95 | prediction_s.append(self.classifier_s[i](x_s[i].view(x.shape[0], -1))) 96 | 97 | if self.norm: 98 | feature_pcb = F.normalize(feature_pcb) 99 | 100 | return feature_pcb, tuple(prediction_s) 101 | -------------------------------------------------------------------------------- /reid/models/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .IDE_model import * 4 | from .PCB_model import * 5 | 6 | __factory = { 7 | 'pcb': PCB_model, 8 | 'ide': IDE_model, 9 | } 10 | 11 | 12 | def names(): 13 | return sorted(__factory.keys()) 14 | 15 | 16 | def create(name, *args, **kwargs): 17 | """ 18 | Create a model instance. 19 | 20 | Parameters 21 | ---------- 22 | name : str 23 | Model name. Can be one of 'ide' and 'pcb'. 24 | feature_dim : int, optional 25 | If positive, will append an embedding layer (Linear for 'ide', 1x1 Conv 26 | for 'pcb') after the global pooling layer, with this number of output 27 | units, followed by a BatchNorm layer. Otherwise the backbone output 28 | dimension is used. Default: 256 29 | norm : bool, optional 30 | If True, will normalize the feature to be unit L2-norm for each sample. 31 | Default: False 32 | dropout : float, optional 33 | If positive, will append a Dropout layer with this dropout rate. 34 | Default: 0 35 | last_stride : int, optional 36 | Stride of the last down-sampling block in the backbone; set to 1 to 37 | keep a higher-resolution feature map. Default: 2 for 'ide', 1 for 'pcb' 38 | arch : str, optional 39 | Backbone architecture, one of 'resnet50' and 'densenet121'. 40 | Default: 'resnet50' 41 | num_stripes : int, optional 42 | Number of horizontal stripes ('pcb' only). Default: 6 43 | num_classes : int, optional 44 | If positive, will append a Linear layer at the end as the classifier 45 | with this number of output units. 
Default: 0 46 | """ 47 | if name not in __factory: 48 | raise KeyError("Unknown model:", name) 49 | return __factory[name](*args, **kwargs) 50 | -------------------------------------------------------------------------------- /reid/prepare/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hou-yz/open-reid-tracking/7a40462231e27e607dd094412550f26270b4efb8/reid/prepare/__init__.py -------------------------------------------------------------------------------- /reid/prepare/add_aic_gps.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | 4 | import cv2 5 | import numpy as np 6 | from numpy.linalg import inv 7 | 8 | data_path = 'D:/Data/AIC19/' if os.name == 'nt' else osp.expanduser('~/Data/AIC19/') 9 | scenes = [1, 2, 3, 4, 5] 10 | folder_by_scene = {1: 'train', 11 | 2: 'test', 12 | 3: 'train', 13 | 4: 'train', 14 | 5: 'test', } 15 | world_centers = {1: np.array([42.525678, -90.723601]), 16 | 2: np.array([42.491916, -90.723723]), 17 | 3: np.array([42.498780, -90.686393]), 18 | 4: np.array([42.498780, -90.686393]), 19 | 5: np.array([42.498780, -90.686393]), } 20 | 21 | world_scale = 6371000 / 180 * np.pi 22 | 23 | 24 | def image2gps(feet_pos, parameters, scene): 25 | feet_pos = feet_pos.reshape(-1, 1, 2) 26 | if 'intrinsic' in parameters: 27 | # Have to provide P matrix for appropriate scaling 28 | feet_pos = cv2.undistortPoints(feet_pos, parameters['intrinsic'], parameters['distortion'], 29 | P=parameters['intrinsic']) 30 | world_pos = cv2.perspectiveTransform(feet_pos, inv(parameters['homography'])).reshape(-1, 2) 31 | world_pos = (world_pos - world_centers[scene]) * world_scale 32 | return world_pos[:, ::-1] 33 | 34 | 35 | def gps2image(world_pos, parameters, scene): 36 | world_pos = world_pos[:, ::-1] / world_scale + world_centers[scene] 37 | world_pos = world_pos.reshape(-1, 1, 2) 38 | feet_pos = cv2.perspectiveTransform(world_pos, parameters['homography']).reshape(-1, 2) 39 | if 'intrinsic' in parameters: 40 | rvec = np.array([0, 0, 0], dtype=np.float32) 41 | tvec = np.array([0, 0, 0], dtype=np.float32) 42 | feet_pos, _ = cv2.projectPoints( 43 | np.matmul(inv(parameters['intrinsic']), 44 | np.concatenate((feet_pos, np.ones(feet_pos.shape[0]).reshape(-1, 1)), axis=1).T, 45 | ).T, 46 | rvec, tvec, parameters['intrinsic'], parameters['distortion']) 47 | return feet_pos 48 | 49 | 50 | if __name__ == '__main__': 51 | for scene in scenes: 52 | scene_path = osp.join(data_path, folder_by_scene[scene], 'S{:02d}'.format(scene)) 53 | frame_offset_fname = osp.join(data_path, 'cam_timestamp', 'S{:02d}.txt'.format(scene)) 54 | frame_offset = {} 55 | with open(frame_offset_fname) as f: 56 | for line in f: 57 | (key, val) = line.split(' ') 58 | key = int(key[1:]) 59 | val = 10 * float(val) 60 | frame_offset[key] = val 61 | for camera_dir in sorted(os.listdir(scene_path)): 62 | iCam = int(camera_dir[1:]) 63 | calibration_fname = osp.join(data_path, 'calibration', camera_dir, 'calibration.txt') 64 | parameters = {} 65 | with open(calibration_fname) as f: 66 | for line in f: 67 | (key, val) = line.split(':') 68 | key = key.split(' ')[0].lower() 69 | if key == 'reprojection': key = 'error' 70 | if ';' in val: 71 | val = np.fromstring(val.replace(';', ' '), dtype=float, sep=' ').reshape([3, 3]) 72 | else: 73 | val = np.fromstring(val, dtype=float, sep=' ') 74 | parameters[key] = val 75 | pass 76 | bbox_types = ['gt', 'det'] if folder_by_scene[scene] == 
'train' else ['det'] 77 | for bbox_type in bbox_types: 78 | bbox_file = osp.join(scene_path, camera_dir, bbox_type, 79 | 'gt.txt' if bbox_type == 'gt' else 'det_ssd512.txt') 80 | bboxs = np.loadtxt(bbox_file, delimiter=',') 81 | feet_pos = np.array([bboxs[:, 2] + bboxs[:, 4] / 2, bboxs[:, 3] + bboxs[:, 5]]).T 82 | world_pos = image2gps(feet_pos, parameters, scene) 83 | new_feet_pos = gps2image(world_pos, parameters, scene) 84 | error = np.mean(np.sum(new_feet_pos - feet_pos, axis=1)) 85 | 86 | bboxs[:, 7] = iCam 87 | bboxs[:, 8] = bboxs[:, 0] + frame_offset[iCam] 88 | bboxs = bboxs[:, :9] 89 | bboxs = np.concatenate((bboxs, world_pos), axis=1) 90 | bbox_gps_file = osp.join(scene_path, camera_dir, bbox_type, 91 | 'gt_gps.txt' if bbox_type == 'gt' else 'det_ssd512_gps.txt') 92 | np.savetxt(bbox_gps_file, bboxs, delimiter=',', fmt='%g') 93 | pass 94 | -------------------------------------------------------------------------------- /reid/prepare/affinity_matrix.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os.path as osp 3 | import glob 4 | import re 5 | 6 | path = osp.expanduser('~/Data/VeRi/image_train') 7 | fpaths = sorted(glob.glob(osp.join(path, '*.jpg'))) 8 | pattern = re.compile(r'(\d+)_c(\d+)_(\d+)') 9 | all_pids = {} 10 | ret = [] 11 | for fpath in fpaths: 12 | fname = osp.basename(fpath) 13 | pid, line, frame = map(int, pattern.search(fname).groups()) 14 | if pid == -1: continue 15 | if pid not in all_pids: 16 | all_pids[pid] = len(all_pids) 17 | pid = all_pids[pid] 18 | ret.append((pid, line - 1, frame)) 19 | 20 | affinity_matrix = np.zeros([20, 20]) 21 | pid_cam_frame = np.array(ret) 22 | for pid in all_pids.values(): 23 | indices = np.where(pid_cam_frame[:, 0] == pid) 24 | samepid_cam_frame = pid_cam_frame[indices] 25 | samepid_cam_frame = samepid_cam_frame[samepid_cam_frame[:, 2].argsort()] 26 | for i in range(samepid_cam_frame.shape[0]): 27 | if i == 0: 28 | last_line = samepid_cam_frame[i, :] 29 | else: 30 | last_line = line 31 | line = samepid_cam_frame[i, :] 32 | if last_line[1] != line[1] or line[2] - last_line[2] > 200: 33 | affinity_matrix[last_line[1], line[1]] += 1 34 | affinity_matrix += affinity_matrix.T 35 | np.savetxt('affinity_matrix.txt', affinity_matrix, '%d') 36 | pass 37 | -------------------------------------------------------------------------------- /reid/prepare/ensemble.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import re 4 | from glob import glob 5 | 6 | import h5py 7 | import numpy as np 8 | from sklearn.preprocessing import normalize 9 | 10 | models = ['lr001', 'lr001_softmargin', 'lr001_colorjitter'] 11 | dirs = ['gt_all', 'gt_mini', ] # 'test', 'trainval', 12 | 13 | for data_dir in dirs: 14 | 15 | models_feat = {} 16 | models_header = {} 17 | for model in models: 18 | if data_dir == 'gt_mini': 19 | folder = osp.join('/home/houyz/Code/DeepCC/experiments', 'zju_{}_gt_trainval'.format(model)) 20 | elif data_dir == 'gt_all': 21 | folder = osp.join('/home/houyz/Data/AIC19/L0-features', 'gt_features_zju_{}'.format(model)) 22 | else: 23 | folder = osp.join('/home/houyz/Data/AIC19/L0-features', 24 | 'det_features_zju_{}_{}_ssd'.format(model, data_dir)) 25 | fnames = sorted(glob(osp.join(folder, '*.h5'))) 26 | 27 | pattern = re.compile(r'(\d+)') 28 | for fname in fnames: 29 | h5file = h5py.File(fname, 'r') 30 | data = np.array(h5file['emb']) 31 | cam = int(pattern.search(osp.basename(fname)).groups()[0]) 
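# Ensemble note: each model's features are L2-normalized row-wise and then
# concatenated across models; the later division by sqrt(len(models)) makes
# the squared L2 distance between two ensembled vectors equal the average of
# the per-model squared distances, so no single model dominates the ranking.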
32 | if cam not in models_feat: 33 | models_feat[cam] = np.array([]) 34 | models_header[cam] = data[:, :3 if 'gt' in data_dir else 2] 35 | 36 | data = data[:, 3 if 'gt' in data_dir else 2:] 37 | data = normalize(data, axis=1) 38 | models_feat[cam] = np.hstack([models_feat[cam], data]) if models_feat[cam].size else data 39 | pass 40 | for cam in models_feat.keys(): 41 | models_feat[cam] /= len(models) ** 0.5 42 | ensemble_feat = np.hstack([models_header[cam], models_feat[cam]]) 43 | if data_dir == 'gt_mini': 44 | folder = osp.join('/home/houyz/Code/DeepCC/experiments', 'zju_lr001_ensemble_gt_trainval') 45 | elif data_dir == 'gt_all': 46 | folder = osp.join('/home/houyz/Data/AIC19/L0-features', 'gt_features_zju_lr001_ensemble') 47 | else: 48 | folder = osp.join('/home/houyz/Data/AIC19/L0-features', 49 | 'det_features_zju_lr001_ensemble_{}_ssd'.format(data_dir)) 50 | 51 | output_fname = folder + '/features%d.h5' % cam 52 | if not osp.exists(folder): 53 | os.makedirs(folder) 54 | with h5py.File(output_fname, 'w') as f: 55 | f.create_dataset('emb', data=ensemble_feat, dtype=float, maxshape=(None, None)) 56 | pass 57 | 58 | pass 59 | -------------------------------------------------------------------------------- /reid/prepare/extract_bbox.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import os 3 | import os.path as osp 4 | 5 | import cv2 6 | import numpy as np 7 | import psutil 8 | 9 | path = '~/Data/AIC19/' 10 | og_fps = 10 11 | 12 | 13 | def get_bbox(type='gt', det_time='train', fps=10, det_type='ssd'): 14 | # type = ['gt','det','labeled'] 15 | data_path = osp.join(osp.expanduser(path), 'test' if det_time == 'test' else 'train') 16 | save_path = osp.join(osp.expanduser('~/Data/AIC19/ALL_{}_bbox/'.format(type)), det_time) 17 | 18 | if type == 'gt' or type == 'labeled': 19 | save_path = osp.join(save_path, 'gt_bbox_{}_fps'.format(fps)) 20 | fps_pooling = int(og_fps / fps) # use minimal number of gt's to train ide model 21 | else: 22 | save_path = osp.join(save_path, det_type) 23 | 24 | os.makedirs(save_path, exist_ok=True) 25 | 26 | # scene selection for train/val 27 | if det_time == 'train': 28 | scenes = ['S03', 'S04'] 29 | elif det_time == 'trainval': 30 | scenes = ['S01', 'S03', 'S04'] 31 | elif det_time == 'val': 32 | scenes = ['S01'] 33 | elif det_time == 'test' and type == 'gt': 34 | scenes = ['S02', 'S06'] 35 | else: # test 36 | scenes = os.listdir(data_path) 37 | 38 | for scene in scenes: 39 | scene_path = osp.join(data_path, scene) 40 | for camera_dir in os.listdir(scene_path): 41 | iCam = int(camera_dir[1:]) 42 | # get bboxs 43 | if type == 'gt': 44 | if det_time == 'test': 45 | bbox_filename = osp.join('/home/houyz/Code/DeepCC/experiments/aic_label_det/L3-identities', 46 | 'cam{}_test.txt'.format(iCam)) 47 | delimiter = None 48 | else: 49 | bbox_filename = osp.join(scene_path, camera_dir, 'gt', 'gt.txt') 50 | delimiter = ',' 51 | elif type == 'labeled': 52 | bbox_filename = osp.join(scene_path, camera_dir, 'det', 53 | 'det_{}_labeled.txt'.format('ssd512' if det_type == 'ssd' else 'yolo3')) 54 | delimiter = ',' 55 | else: # det 56 | bbox_filename = osp.join(scene_path, camera_dir, 'det', 57 | 'det_{}.txt'.format('ssd512' if det_type == 'ssd' else 'yolo3')) 58 | delimiter = ',' 59 | bboxs = np.loadtxt(bbox_filename, delimiter=delimiter) 60 | if type == 'gt' or type == 'labeled': 61 | bboxs = bboxs[np.where(bboxs[:, 0] % fps_pooling == 0)[0], :] 62 | 63 | # get frame_pics 64 | video_file = osp.join(scene_path, camera_dir, 
'vdo.avi') 65 | video_reader = cv2.VideoCapture(video_file) 66 | # get vcap property 67 | width = video_reader.get(cv2.CAP_PROP_FRAME_WIDTH) # float 68 | height = video_reader.get(cv2.CAP_PROP_FRAME_HEIGHT) # float 69 | 70 | # enlarge by 40 pixel for detection 71 | if type == 'det' or type == 'labeled': 72 | bboxs[:, 2:4] = bboxs[:, 2:4] - 20 73 | bboxs[:, 4:6] = bboxs[:, 4:6] + 40 74 | 75 | # bboxs 76 | bbox_top = np.maximum(bboxs[:, 3], 0) 77 | bbox_bottom = np.minimum(bboxs[:, 3] + bboxs[:, 5], height - 1) 78 | bbox_left = np.maximum(bboxs[:, 2], 0) 79 | bbox_right = np.minimum(bboxs[:, 2] + bboxs[:, 4], width - 1) 80 | bboxs[:, 2:6] = np.stack((bbox_top, bbox_bottom, bbox_left, bbox_right), axis=1) 81 | 82 | # frame_pics = [] 83 | frame_num = 0 84 | success = video_reader.isOpened() 85 | printed_img_count = 0 86 | while success: 87 | assert psutil.virtual_memory().percent < 95, "reading video will be killed!!!!!!" 88 | 89 | success, frame_pic = video_reader.read() 90 | frame_num = frame_num + 1 91 | bboxs_in_frame = bboxs[bboxs[:, 0] == frame_num, :] 92 | 93 | for index in range(bboxs_in_frame.shape[0]): 94 | frame = int(bboxs_in_frame[index, 0]) 95 | pid = int(bboxs_in_frame[index, 1]) 96 | bbox_top = int(bboxs_in_frame[index, 2]) 97 | bbox_bottom = int(bboxs_in_frame[index, 3]) 98 | bbox_left = int(bboxs_in_frame[index, 4]) 99 | bbox_right = int(bboxs_in_frame[index, 5]) 100 | 101 | bbox_pic = frame_pic[bbox_top:bbox_bottom, bbox_left:bbox_right] 102 | if bbox_pic.size == 0: 103 | continue 104 | 105 | if type == 'gt' or type == 'labeled': 106 | save_file = osp.join(save_path, "{:04d}_c{:02d}_f{:05d}.jpg".format(pid, iCam, frame)) 107 | else: 108 | save_file = osp.join(save_path, 'c{:02d}_f{:05d}_{:03d}.jpg'.format(iCam, frame, index)) 109 | 110 | cv2.imwrite(save_file, bbox_pic) 111 | cv2.waitKey(0) 112 | printed_img_count += 1 113 | 114 | cv2.waitKey(0) 115 | video_reader.release() 116 | # assert printed_img_count == bboxs.shape[0] 117 | 118 | print(video_file, 'completed!') 119 | print(scene, 'completed!') 120 | print(save_path, 'completed!') 121 | 122 | 123 | if __name__ == '__main__': 124 | print('{}'.format(datetime.datetime.today().strftime('%Y-%m-%d_%H-%M-%S'))) 125 | get_bbox(type='gt', fps=10, det_time='trainval') 126 | # get_bbox(type='labeled', det_time='trainval', fps=1) 127 | get_bbox(type='det', det_time='trainval', det_type='ssd') 128 | get_bbox(type='det', det_time='test', det_type='ssd') 129 | print('{}'.format(datetime.datetime.today().strftime('%Y-%m-%d_%H-%M-%S'))) 130 | print('Job Completed!') 131 | -------------------------------------------------------------------------------- /reid/prepare/label_det_dataset.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import os 3 | 4 | import numpy as np 5 | import pandas as pd 6 | 7 | 8 | def bbox_ious(boxA, boxB): 9 | if boxA.size == 0 or boxB.size == 0: 10 | return np.array([]) 11 | boxA[:, 2:4] = np.array([boxA[:, 0] + boxA[:, 2], boxA[:, 1] + boxA[:, 3]]).T 12 | boxB[:, 2:4] = np.array([boxB[:, 0] + boxB[:, 2], boxB[:, 1] + boxB[:, 3]]).T 13 | # determine the (x, y)-coordinates of the intersection rectangle 14 | xA = np.maximum(np.array([boxA[:, 0]]).T, boxB[:, 0]) 15 | yA = np.maximum(np.array([boxA[:, 1]]).T, boxB[:, 1]) 16 | xB = np.minimum(np.array([boxA[:, 2]]).T, boxB[:, 2]) 17 | yB = np.minimum(np.array([boxA[:, 3]]).T, boxB[:, 3]) 18 | # compute the area of intersection rectangle 19 | interArea = np.maximum(0, xB - xA + 1) * np.maximum(0, yB - yA + 1) 20 | # compute the area of both the prediction and ground-truth 21 | # rectangles 22 | boxAArea = np.array([(boxA[:, 2] - boxA[:, 0] + 1) * (boxA[:, 3] - boxA[:, 1] + 1)]).T 23 | boxBArea = (boxB[:, 2] - boxB[:, 0] + 1) * (boxB[:, 3] - boxB[:, 1] + 1) 24 | # compute the intersection over union by taking the intersection 25 | # area and dividing it by the sum of prediction + ground-truth 26 | # areas - the intersection area 27 | ious = interArea / (boxAArea + boxBArea - interArea) 28 | # return the intersection over union value 29 | return ious 30 | 31 | 32 | def main(det_time='train', IoUthreshold=0.3): 33 | data_dir = os.path.expanduser('~/Data/AIC19/train') 34 | 35 | if det_time == 'train': 36 | scenes = ['S03', 'S04'] 37 | elif det_time == 'trainval': 38 | scenes = ['S01', 'S03', 'S04'] 39 | elif det_time == 'val': 40 | scenes = ['S01'] 41 | else: 42 | scenes = None 43 | 44 | # loop for subsets 45 | for scene in scenes: 46 | scene_dir = os.path.join(data_dir, scene) 47 | # savedir = os.path.join(data_dir, 'labeled') 48 | # if not os.path.exists(savedir): 49 | # os.mkdir(savedir) 50 | 51 | # loop for cameras 52 | for camera in os.listdir(scene_dir): 53 | gt_file_path = os.path.join(scene_dir, camera, 'gt', 'gt.txt') 54 | det_file_path = os.path.join(scene_dir, camera, 'det', 'det_ssd512.txt') 55 | gt_file = np.array(pd.read_csv(gt_file_path, header=None)) 56 | det_file = np.array(pd.read_csv(det_file_path, header=None)) 57 | # frame, id, bbox*4, score 58 | frames = np.unique(gt_file[:, 0]) 59 | for frame in frames: 60 | gt_line_ids = np.where(gt_file[:, 0] == frame)[0] 61 | same_frame_gt_bboxs = gt_file[gt_line_ids, 2:6] 62 | det_line_ids = np.where(det_file[:, 0] == frame)[0] 63 | same_frame_det_bboxs = det_file[det_line_ids, 2:6] 64 | ious = bbox_ious(same_frame_gt_bboxs, same_frame_det_bboxs) 65 | if ious.size == 0: 66 | continue 67 | label = np.argmax(ious, axis=1) 68 | det_file[det_line_ids[label], 1] = gt_file[gt_line_ids, 1] 69 | det_file[det_line_ids[np.max(ious, axis=0) < IoUthreshold], 1] = -1 70 | pass 71 | np.savetxt(os.path.join(scene_dir, camera, 'det', 'det_ssd512_labeled.txt'), 72 | det_file, delimiter=',', fmt='%d') 73 | # np.savetxt(os.path.join(savedir, '{}_det_ssd512_labeled.txt'.format(camera)), 74 | # det_file, delimiter=',', fmt='%d') 75 | 76 | print(camera, 'is completed') 77 | print(scene, 'is completed') 78 | 79 | 80 | if __name__ == '__main__': 81 | print('{}'.format(datetime.datetime.today().strftime('%Y-%m-%d_%H-%M-%S'))) 82 | main(det_time='trainval') 83 | print('{}'.format(datetime.datetime.today().strftime('%Y-%m-%d_%H-%M-%S'))) 84 | print('Job Completed!') 85 | -------------------------------------------------------------------------------- /reid/trainers.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | 3 | import time 4 | 5 | import torch 6 | from torch import nn 7 | from torch.autograd import Variable 8 | 9 | from .evaluation_metrics import accuracy 10 | from .loss import * 11 | from .utils.meters import AverageMeter 12 | 13 | 14 | class BaseTrainer(object): 15 | def __init__(self, model, criterion): 16 | super(BaseTrainer, self).__init__() 17 | self.model = model 18 | self.criterion = criterion 19 | 20 | def train(self, epoch, data_loader, optimizer): 21 | raise NotImplementedError 22 | 23 | def _parse_data(self, inputs): 24 | raise NotImplementedError 25 | 26 | def _forward(self, inputs, targets): 27 | raise NotImplementedError 28 | 29 | 30 | class 
Trainer(BaseTrainer): 31 | def train(self, epoch, data_loader, optimizer, fix_bn=False, print_freq=10): 32 | self.model.train() 33 | 34 | is_triplet = isinstance(self.criterion, TripletLoss) 35 | if isinstance(self.criterion, list): 36 | is_triplet = isinstance(self.criterion[1], TripletLoss) 37 | if isinstance(self.criterion, TripletLoss) or isinstance(self.criterion, list): 38 | margin = self.criterion.margin if isinstance(self.criterion, TripletLoss) else self.criterion[1].margin 39 | 40 | # detailed logging for triplet 41 | if isinstance(self.criterion, TripletLoss): 42 | # For recording precision, satisfying margin, etc 43 | prec_meter = AverageMeter() 44 | sm_meter = AverageMeter() 45 | dist_ap_meter = AverageMeter() 46 | dist_an_meter = AverageMeter() 47 | loss_meter = AverageMeter() 48 | if fix_bn: 49 | # set the bn layers to eval() and don't change weight & bias 50 | for m in self.model.module.base.modules(): 51 | if isinstance(m, nn.BatchNorm2d): 52 | m.eval() 53 | if m.affine: 54 | m.weight.requires_grad = False 55 | m.bias.requires_grad = False 56 | 57 | batch_time = AverageMeter() 58 | data_time = AverageMeter() 59 | losses = AverageMeter() 60 | precisions = AverageMeter() 61 | 62 | end = time.time() 63 | for i, inputs in enumerate(data_loader): 64 | data_time.update(time.time() - end) 65 | 66 | inputs, targets = self._parse_data(inputs) 67 | if isinstance(self.criterion, TripletLoss): 68 | loss, prec1, dist_ap, dist_an = self._forward(inputs, targets) 69 | # the proportion of triplets that satisfy margin 70 | sm = (dist_an > dist_ap + margin).data.float().mean() 71 | # average (anchor, positive) distance 72 | d_ap = dist_ap.data.mean() 73 | # average (anchor, negative) distance 74 | d_an = dist_an.data.mean() 75 | prec_meter.update(prec1) 76 | sm_meter.update(sm) 77 | dist_ap_meter.update(d_ap) 78 | dist_an_meter.update(d_an) 79 | loss_meter.update(loss) 80 | # tri_log = ('prec {:.2%}, sm {:.2%}, d_ap {:.4f}, d_an {:.4f}, loss {:.4f}'.format( 81 | # prec_meter.val, sm_meter.val, dist_ap_meter.val, dist_an_meter.val, loss_meter.val, )) 82 | # print(tri_log) 83 | else: 84 | loss, prec1 = self._forward(inputs, targets) 85 | 86 | losses.update(loss.item(), targets.size(0)) 87 | precisions.update(prec1, targets.size(0)) 88 | 89 | optimizer.zero_grad() 90 | loss.backward() 91 | optimizer.step() 92 | 93 | batch_time.update(time.time() - end) 94 | end = time.time() 95 | 96 | if (i + 1) % print_freq == 0 and not isinstance(self.criterion, TripletLoss): 97 | print('Epoch: [{}][{}/{}]\t' 98 | 'Time {:.3f} ({:.3f})\t' 99 | 'Data {:.3f} ({:.3f})\t' 100 | 'Loss {:.3f} ({:.3f})\t' 101 | 'Prec {:.2%} ({:.2%})\t' 102 | .format(epoch, i + 1, len(data_loader), 103 | batch_time.val, batch_time.avg, 104 | data_time.val, data_time.avg, 105 | losses.val, losses.avg, 106 | precisions.val, precisions.avg)) 107 | 108 | # detailed logging at the end of epoch for triplet 109 | if isinstance(self.criterion, TripletLoss): 110 | time_log = 'Epoch [{}], {:.2f}s'.format(epoch, batch_time.avg * len(data_loader), ) 111 | tri_log = (', prec {:.2%}, sm {:.2%}, d_ap {:.4f}, d_an {:.4f}, loss {:.4f}'.format( 112 | prec_meter.val, sm_meter.val, dist_ap_meter.val, dist_an_meter.val, loss_meter.val, )) 113 | print(time_log + tri_log) 114 | 115 | return losses.avg, precisions.avg 116 | 117 | def _parse_data(self, inputs): 118 | imgs, _, pids, _ = inputs 119 | inputs = [Variable(imgs)] 120 | targets = Variable(pids.cuda()) 121 | return inputs, targets 122 | 123 | def _forward(self, inputs, targets): 124 | outputs = 
self.model(*inputs) 125 | if isinstance(self.criterion, torch.nn.CrossEntropyLoss) or isinstance(self.criterion, LSR_loss): 126 | # if isinstance(self.model.module, IDE_model) or isinstance(self.model.module, PCB_model): 127 | prediction = outputs[1] 128 | loss = 0 129 | for pred in prediction: 130 | loss += self.criterion(pred, targets) 131 | prediction = prediction[0] 132 | prec, = accuracy(prediction.data, targets.data) 133 | # else: 134 | # loss = self.criterion(outputs, targets) 135 | # prec, = accuracy(outputs.data, targets.data) 136 | prec = prec.item() 137 | pass 138 | elif isinstance(self.criterion, TripletLoss): 139 | # if isinstance(self.model.module, PCB_model) or isinstance(self.model.module, IDE_model): 140 | outputs = outputs[0] # = x_s 141 | return self.criterion(outputs, targets) 142 | elif isinstance(self.criterion[1], TripletLoss): 143 | # if isinstance(self.model.module, PCB_model) or isinstance(self.model.module, IDE_model): 144 | feat = outputs[0] # = x_s 145 | prediction = outputs[1][0] 146 | loss = self.criterion[0](prediction, targets) + self.criterion[1](feat, targets)[0] 147 | prec, = accuracy(prediction.data, targets.data) 148 | prec = prec.item() 149 | pass 150 | else: 151 | raise ValueError("Unsupported loss:", self.criterion) 152 | return loss, prec 153 | -------------------------------------------------------------------------------- /reid/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | 5 | 6 | def to_numpy(tensor): 7 | if torch.is_tensor(tensor): 8 | return tensor.cpu().numpy() 9 | elif type(tensor).__module__ != 'numpy': 10 | raise ValueError("Cannot convert {} to numpy array" 11 | .format(type(tensor))) 12 | return tensor 13 | 14 | 15 | def to_torch(ndarray): 16 | if type(ndarray).__module__ == 'numpy': 17 | return torch.from_numpy(ndarray) 18 | elif not torch.is_tensor(ndarray): 19 | raise ValueError("Cannot convert {} to torch tensor" 20 | .format(type(ndarray))) 21 | return ndarray 22 | -------------------------------------------------------------------------------- /reid/utils/data/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .dataset import Dataset 4 | from .preprocessor import Preprocessor 5 | -------------------------------------------------------------------------------- /reid/utils/data/dataset.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import os.path as osp 4 | 5 | import numpy as np 6 | 7 | from ..serialization import read_json 8 | 9 | 10 | def _pluck(identities, indices, relabel=False): 11 | ret = [] 12 | for index, pid in enumerate(indices): 13 | pid_images = identities[pid] 14 | for camid, cam_images in enumerate(pid_images): 15 | for fname in cam_images: 16 | name = osp.splitext(fname)[0] 17 | # x, y, _ = map(int, name.split('_')) 18 | x, y, _ = name.split('_') 19 | x = int(x) 20 | if 'c' in y: 21 | y = int(y.split('c')[-1]) - 1 22 | else: 23 | y = int(y) 24 | 25 | assert pid == x and camid == y 26 | if relabel: 27 | ret.append((fname, index, camid)) 28 | else: 29 | ret.append((fname, pid, camid)) 30 | return ret 31 | 32 | 33 | class Dataset(object): 34 | def __init__(self, root, split_id=0): 35 | self.root = root 36 | self.split_id = split_id 37 | self.meta = None 38 | self.split = None 39 | self.train, self.val, self.trainval = 
[], [], [] 40 | self.query, self.gallery = [], [] 41 | self.num_train_ids, self.num_val_ids, self.num_trainval_ids = 0, 0, 0 42 | 43 | @property 44 | def images_dir(self): 45 | return osp.join(self.root, 'images') 46 | 47 | def load(self, num_val=0.3, verbose=True): 48 | splits = read_json(osp.join(self.root, 'splits.json')) 49 | if self.split_id >= len(splits): 50 | raise ValueError("split_id exceeds total splits {}" 51 | .format(len(splits))) 52 | self.split = splits[self.split_id] 53 | 54 | # Randomly split train / val 55 | trainval_pids = np.asarray(self.split['trainval']) 56 | np.random.shuffle(trainval_pids) 57 | num = len(trainval_pids) 58 | if isinstance(num_val, float): 59 | num_val = int(round(num * num_val)) 60 | if num_val >= num or num_val < 0: 61 | raise ValueError("num_val exceeds total identities {}" 62 | .format(num)) 63 | if num_val: 64 | train_pids = sorted(trainval_pids[:-num_val]) 65 | val_pids = sorted(trainval_pids[-num_val:]) 66 | else: 67 | train_pids = sorted(trainval_pids) 68 | val_pids = sorted([]) 69 | 70 | self.meta = read_json(osp.join(self.root, 'meta.json')) 71 | identities = self.meta['identities'] 72 | self.train = _pluck(identities, train_pids, relabel=True) 73 | self.val = _pluck(identities, val_pids, relabel=True) 74 | self.trainval = _pluck(identities, trainval_pids, relabel=True) 75 | self.query = _pluck(identities, self.split['query']) 76 | self.gallery = _pluck(identities, self.split['gallery']) 77 | self.num_train_ids = len(train_pids) 78 | self.num_val_ids = len(val_pids) 79 | self.num_trainval_ids = len(trainval_pids) 80 | 81 | if verbose: 82 | print(self.__class__.__name__, "dataset loaded") 83 | print(" subset | # ids | # images") 84 | print(" ---------------------------") 85 | print(" train | {:5d} | {:8d}" 86 | .format(self.num_train_ids, len(self.train))) 87 | print(" val | {:5d} | {:8d}" 88 | .format(self.num_val_ids, len(self.val))) 89 | print(" trainval | {:5d} | {:8d}" 90 | .format(self.num_trainval_ids, len(self.trainval))) 91 | print(" query | {:5d} | {:8d}" 92 | .format(len(self.split['query']), len(self.query))) 93 | print(" gallery | {:5d} | {:8d}" 94 | .format(len(self.split['gallery']), len(self.gallery))) 95 | 96 | def _check_integrity(self): 97 | return osp.isdir(osp.join(self.root, 'images')) and \ 98 | osp.isfile(osp.join(self.root, 'meta.json')) and \ 99 | osp.isfile(osp.join(self.root, 'splits.json')) 100 | -------------------------------------------------------------------------------- /reid/utils/data/og_sampler.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from collections import defaultdict 4 | 5 | import numpy as np 6 | import torch 7 | from torch.utils.data.sampler import Sampler 8 | 9 | 10 | class RandomIdentitySampler(Sampler): 11 | def __init__(self, data_source, num_instances=1): 12 | super(RandomIdentitySampler, self).__init__(data_source) 13 | self.data_source = data_source 14 | self.num_instances = num_instances 15 | self.index_dic = defaultdict(list) 16 | for index, (_, pid, _) in enumerate(data_source): 17 | self.index_dic[pid].append(index) 18 | self.pids = list(self.index_dic.keys()) 19 | self.num_samples = len(self.pids) 20 | 21 | def __len__(self): 22 | return self.num_samples * self.num_instances 23 | 24 | def __iter__(self): 25 | indices = torch.randperm(self.num_samples) 26 | ret = [] 27 | for i in indices: 28 | pid = self.pids[i] 29 | t = self.index_dic[pid] 30 | if len(t) >= self.num_instances: 31 | t = 
np.random.choice(t, size=self.num_instances, replace=False) 32 | else: 33 | t = np.random.choice(t, size=self.num_instances, replace=True) 34 | ret.extend(t) 35 | return iter(ret) 36 | -------------------------------------------------------------------------------- /reid/utils/data/preprocessor.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import os.path as osp 4 | 5 | from PIL import Image 6 | 7 | 8 | class Preprocessor(object): 9 | def __init__(self, dataset, root=None, transform=None): 10 | super(Preprocessor, self).__init__() 11 | self.dataset = dataset 12 | self.root = root 13 | self.transform = transform 14 | 15 | def __len__(self): 16 | return len(self.dataset) 17 | 18 | def __getitem__(self, indices): 19 | if isinstance(indices, (tuple, list)): 20 | return [self._get_single_item(index) for index in indices] 21 | return self._get_single_item(indices) 22 | 23 | def _get_single_item(self, index): 24 | fname, pid, camid = self.dataset[index] 25 | fpath = fname 26 | if self.root is not None: 27 | fpath = osp.join(self.root, fname) 28 | img = Image.open(fpath).convert('RGB') 29 | if self.transform is not None: 30 | img = self.transform(img) 31 | return img, fname, pid, camid 32 | -------------------------------------------------------------------------------- /reid/utils/data/transforms.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import math 4 | import random 5 | from torchvision.transforms import * 6 | 7 | 8 | class RandomErasing(object): 9 | """ Randomly selects a rectangle region in an image and erases its pixels. 10 | 'Random Erasing Data Augmentation' by Zhong et al. 11 | See https://arxiv.org/pdf/1708.04896.pdf 12 | Args: 13 | probability: The probability that the Random Erasing operation will be performed. 14 | sl: Minimum proportion of erased area against input image. 15 | sh: Maximum proportion of erased area against input image. 16 | r1: Minimum aspect ratio of erased area. 17 | mean: Erasing value. 
18 | """ 19 | 20 | def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=(0.4914, 0.4822, 0.4465)): 21 | self.probability = probability 22 | self.mean = mean 23 | self.sl = sl 24 | self.sh = sh 25 | self.r1 = r1 26 | 27 | def __call__(self, img): 28 | 29 | if random.uniform(0, 1) >= self.probability: 30 | return img 31 | 32 | for attempt in range(100): 33 | area = img.size()[1] * img.size()[2] 34 | 35 | target_area = random.uniform(self.sl, self.sh) * area 36 | aspect_ratio = random.uniform(self.r1, 1 / self.r1) 37 | 38 | h = int(round(math.sqrt(target_area * aspect_ratio))) 39 | w = int(round(math.sqrt(target_area / aspect_ratio))) 40 | 41 | if w < img.size()[2] and h < img.size()[1]: 42 | x1 = random.randint(0, img.size()[1] - h) 43 | y1 = random.randint(0, img.size()[2] - w) 44 | if img.size()[0] == 3: 45 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0] 46 | img[1, x1:x1 + h, y1:y1 + w] = self.mean[1] 47 | img[2, x1:x1 + h, y1:y1 + w] = self.mean[2] 48 | else: 49 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0] 50 | return img 51 | 52 | return img 53 | -------------------------------------------------------------------------------- /reid/utils/data/zju_sampler.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: liaoxingyu2@jd.com 5 | """ 6 | 7 | import copy 8 | import random 9 | from collections import defaultdict 10 | 11 | import numpy as np 12 | from torch.utils.data.sampler import Sampler 13 | 14 | 15 | class ZJU_RandomIdentitySampler(Sampler): 16 | """ 17 | Randomly sample N identities, then for each identity, 18 | randomly sample K instances, therefore batch size is N*K. 19 | Args: 20 | - data_source (list): list of (img_path, pid, camid). 21 | - num_instances (int): number of instances per identity in a batch. 22 | - batch_size (int): number of examples in a batch. 
23 | """ 24 | 25 | def __init__(self, data_source, batch_size, num_instances): 26 | super(ZJU_RandomIdentitySampler, self).__init__(data_source) 27 | self.data_source = data_source 28 | self.batch_size = batch_size 29 | self.num_instances = num_instances 30 | self.num_pids_per_batch = self.batch_size // self.num_instances 31 | self.index_dic = defaultdict(list) 32 | for index, (_, pid, _) in enumerate(self.data_source): 33 | self.index_dic[pid].append(index) 34 | self.pids = list(self.index_dic.keys()) 35 | 36 | # estimate number of examples in an epoch 37 | # only a rough estimation 38 | self.length = 0 39 | for pid in self.pids: 40 | idxs = self.index_dic[pid] 41 | num = len(idxs) 42 | if num < self.num_instances: 43 | num = self.num_instances 44 | self.length += num - num % self.num_instances 45 | pass 46 | 47 | def __iter__(self): 48 | batch_idxs_dict = defaultdict(list) 49 | 50 | for pid in self.pids: 51 | idxs = copy.deepcopy(self.index_dic[pid]) 52 | if len(idxs) < self.num_instances: 53 | idxs = np.random.choice(idxs, size=self.num_instances, replace=True) 54 | random.shuffle(idxs) 55 | batch_idxs = [] 56 | for idx in idxs: 57 | batch_idxs.append(idx) 58 | if len(batch_idxs) == self.num_instances: 59 | batch_idxs_dict[pid].append(batch_idxs) 60 | batch_idxs = [] 61 | 62 | avai_pids = copy.deepcopy(self.pids) 63 | final_idxs = [] 64 | 65 | while len(avai_pids) >= self.num_pids_per_batch: 66 | selected_pids = random.sample(avai_pids, self.num_pids_per_batch) 67 | for pid in selected_pids: 68 | batch_idxs = batch_idxs_dict[pid].pop(0) 69 | final_idxs.extend(batch_idxs) 70 | if len(batch_idxs_dict[pid]) == 0: 71 | avai_pids.remove(pid) 72 | 73 | return iter(final_idxs) 74 | 75 | def __len__(self): 76 | return self.length 77 | -------------------------------------------------------------------------------- /reid/utils/draw_curve.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | 3 | matplotlib.use('agg') 4 | import matplotlib.pyplot as plt 5 | 6 | 7 | def draw_curve(path, x_epoch, train_loss, train_prec, test_x_epoch=None, test_loss=None, test_prec=None, ): 8 | fig = plt.figure() 9 | ax0 = fig.add_subplot(121, title="loss") 10 | ax1 = fig.add_subplot(122, title="prec") 11 | ax0.plot(x_epoch, train_loss, 'bo-', label='train: {:.3f}'.format(train_loss[-1])) 12 | ax1.plot(x_epoch, train_prec, 'bo-', label='train: {:.3f}'.format(train_prec[-1])) 13 | if test_x_epoch: 14 | if test_loss: 15 | ax0.plot(test_x_epoch, test_loss, 'ro-', label='test: {:.3f}'.format(test_loss[-1])) 16 | if test_prec: 17 | ax1.plot(test_x_epoch, test_prec, 'ro-', label='test: {:.3f}'.format(test_prec[-1])) 18 | else: 19 | if test_loss: 20 | ax0.plot(x_epoch, test_loss, 'ro-', label='test: {:.3f}'.format(test_loss[-1])) 21 | if test_prec: 22 | ax1.plot(x_epoch, test_prec, 'ro-', label='test: {:.3f}'.format(test_prec[-1])) 23 | ax0.legend() 24 | ax1.legend() 25 | fig.savefig(path) 26 | plt.close(fig) 27 | -------------------------------------------------------------------------------- /reid/utils/get_loaders.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | from torch import nn 3 | from torch.utils.data import DataLoader 4 | from reid import datasets 5 | from reid.utils.serialization import load_checkpoint 6 | from reid.utils.data.og_sampler import RandomIdentitySampler 7 | from reid.utils.data.zju_sampler import ZJU_RandomIdentitySampler 8 | from reid.utils.data import transforms as T 9 | from 
reid.utils.data.preprocessor import Preprocessor 10 | 11 | 12 | 13 | def get_data(name, data_dir, height, width, batch_size, workers, 14 | combine_trainval, crop, tracking_icams, fps, re=0, num_instances=0, camstyle=0, zju=0, colorjitter=0): 15 | # if name == 'market1501': 16 | # root = osp.join(data_dir, 'Market-1501-v15.09.15') 17 | # elif name == 'duke_reid': 18 | # root = osp.join(data_dir, 'DukeMTMC-reID') 19 | # elif name == 'duke_tracking': 20 | # root = osp.join(data_dir, 'DukeMTMC') 21 | # else: 22 | # root = osp.join(data_dir, name) 23 | if name == 'duke_tracking': 24 | if tracking_icams != 0: 25 | tracking_icams = [tracking_icams] 26 | else: 27 | tracking_icams = list(range(1, 9)) 28 | dataset = datasets.create(name, data_dir, data_type='tracking_gt', iCams=tracking_icams, fps=fps, 29 | trainval=combine_trainval) 30 | elif name == 'aic_tracking': 31 | dataset = datasets.create(name, data_dir, data_type='tracking_gt', fps=fps, trainval=combine_trainval) 32 | else: 33 | dataset = datasets.create(name, data_dir) 34 | normalizer = T.Normalize(mean=[0.485, 0.456, 0.406], 35 | std=[0.229, 0.224, 0.225]) 36 | num_classes = dataset.num_train_ids 37 | 38 | train_transformer = T.Compose([ 39 | T.ColorJitter(brightness=0.1 * colorjitter, contrast=0.1 * colorjitter, saturation=0.1 * colorjitter, hue=0), 40 | T.Resize((height, width)), 41 | T.RandomHorizontalFlip(), 42 | T.Pad(10 * crop), 43 | T.RandomCrop((height, width)), 44 | T.ToTensor(), 45 | normalizer, 46 | T.RandomErasing(probability=re), 47 | ]) 48 | test_transformer = T.Compose([ 49 | T.Resize((height, width)), 50 | # T.RectScale(height, width, interpolation=3), 51 | T.ToTensor(), 52 | normalizer, 53 | ]) 54 | 55 | if zju: 56 | train_loader = DataLoader( 57 | Preprocessor(dataset.train, root=dataset.train_path, transform=train_transformer), 58 | batch_size=batch_size, num_workers=workers, 59 | sampler=ZJU_RandomIdentitySampler(dataset.train, batch_size, num_instances) if num_instances else None, 60 | shuffle=False if num_instances else True, pin_memory=True, drop_last=False if num_instances else True) 61 | else: 62 | train_loader = DataLoader( 63 | Preprocessor(dataset.train, root=dataset.train_path, transform=train_transformer), 64 | batch_size=batch_size, num_workers=workers, 65 | sampler=RandomIdentitySampler(dataset.train, num_instances) if num_instances else None, 66 | shuffle=False if num_instances else True, pin_memory=True, drop_last=True) 67 | query_loader = DataLoader( 68 | Preprocessor(dataset.query, root=dataset.query_path, transform=test_transformer), 69 | batch_size=batch_size, num_workers=workers, 70 | shuffle=False, pin_memory=True) 71 | gallery_loader = DataLoader( 72 | Preprocessor(dataset.gallery, root=dataset.gallery_path, transform=test_transformer), 73 | batch_size=batch_size, num_workers=workers, 74 | shuffle=False, pin_memory=True) 75 | if camstyle <= 0: 76 | camstyle_loader = None 77 | else: 78 | camstyle_loader = DataLoader( 79 | Preprocessor(dataset.camstyle, root=dataset.camstyle_path, transform=train_transformer), 80 | batch_size=camstyle, num_workers=workers, 81 | shuffle=True, pin_memory=True, drop_last=True) 82 | return dataset, num_classes, train_loader, query_loader, gallery_loader, camstyle_loader 83 | 84 | 85 | def checkpoint_loader(model, path): 86 | checkpoint = load_checkpoint(path) 87 | pretrained_dict = checkpoint['state_dict'] 88 | if isinstance(model, nn.DataParallel): 89 | Parallel = 1 90 | model = model.module.cpu() 91 | else: 92 | Parallel = 0 93 | 94 | model_dict = model.state_dict() 
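    # checkpoints in this repo store model.module.state_dict() (see the save_checkpoint calls), so the DataParallel wrapper is stripped above and restored after the partial load below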
95 | # 1. filter out unnecessary keys 96 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 97 | # if eval_only: 98 | # keys_to_del = [] 99 | # for key in pretrained_dict.keys(): 100 | # if 'classifier' in key: 101 | # keys_to_del.append(key) 102 | # for key in keys_to_del: 103 | # del pretrained_dict[key] 104 | # pass 105 | # 2. overwrite entries in the existing state dict 106 | model_dict.update(pretrained_dict) 107 | # 3. load the new state dict 108 | model.load_state_dict(model_dict) 109 | 110 | start_epoch = checkpoint['epoch'] 111 | best_top1 = checkpoint['best_top1'] 112 | 113 | if Parallel: 114 | model = nn.DataParallel(model).cuda() 115 | 116 | return model, start_epoch, best_top1 117 | -------------------------------------------------------------------------------- /reid/utils/logger.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import os 4 | import sys 5 | 6 | 7 | class Logger(object): 8 | def __init__(self, fpath=None): 9 | self.console = sys.stdout 10 | self.file = None 11 | if fpath is not None: 12 | os.makedirs(os.path.dirname(fpath), exist_ok=True) 13 | self.file = open(fpath, 'w') 14 | 15 | def __del__(self): 16 | self.close() 17 | 18 | def __enter__(self): 19 | pass 20 | 21 | def __exit__(self, *args): 22 | self.close() 23 | 24 | def write(self, msg): 25 | self.console.write(msg) 26 | if self.file is not None: 27 | self.file.write(msg) 28 | 29 | def flush(self): 30 | self.console.flush() 31 | if self.file is not None: 32 | self.file.flush() 33 | os.fsync(self.file.fileno()) 34 | 35 | def close(self): 36 | self.console.close() 37 | if self.file is not None: 38 | self.file.close() 39 | -------------------------------------------------------------------------------- /reid/utils/meters.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | 4 | class AverageMeter(object): 5 | """Computes and stores the average and current value""" 6 | 7 | def __init__(self): 8 | self.val = 0 9 | self.avg = 0 10 | self.sum = 0 11 | self.count = 0 12 | 13 | def reset(self): 14 | self.val = 0 15 | self.avg = 0 16 | self.sum = 0 17 | self.count = 0 18 | 19 | def update(self, val, n=1): 20 | self.val = val 21 | self.sum += val * n 22 | self.count += n 23 | self.avg = self.sum / self.count 24 | -------------------------------------------------------------------------------- /reid/utils/serialization.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | 3 | import json 4 | import os 5 | import os.path as osp 6 | 7 | import torch 8 | from torch.nn import Parameter 9 | 10 | 11 | def read_json(fpath): 12 | with open(fpath, 'r') as f: 13 | obj = json.load(f) 14 | return obj 15 | 16 | 17 | def write_json(obj, fpath): 18 | os.makedirs(osp.dirname(fpath), exist_ok=True) 19 | with open(fpath, 'w') as f: 20 | json.dump(obj, f, indent=4, separators=(',', ': ')) 21 | 22 | 23 | def save_checkpoint(state, is_best, fpath='checkpoint.pth.tar'): 24 | os.makedirs(osp.dirname(fpath), exist_ok=True) 25 | if int(state['epoch']) % 10 == 0: 26 | torch.save(state, fpath) 27 | if is_best: 28 | torch.save(state, osp.join(osp.dirname(fpath), 'model_best.pth.tar')) 29 | 30 | 31 | def load_checkpoint(fpath): 32 | if osp.isfile(fpath): 33 | checkpoint = torch.load(fpath) 34 | print("=> Loaded checkpoint '{}'".format(fpath)) 35 | return 
checkpoint 36 | else: 37 | raise ValueError("=> No checkpoint found at '{}'".format(fpath)) 38 | 39 | 40 | def copy_state_dict(state_dict, model, strip=None): 41 | tgt_state = model.state_dict() 42 | copied_names = set() 43 | for name, param in state_dict.items(): 44 | if strip is not None and name.startswith(strip): 45 | name = name[len(strip):] 46 | if name not in tgt_state: 47 | continue 48 | if isinstance(param, Parameter): 49 | param = param.data 50 | if param.size() != tgt_state[name].size(): 51 | print('mismatch:', name, param.size(), tgt_state[name].size()) 52 | continue 53 | tgt_state[name].copy_(param) 54 | copied_names.add(name) 55 | 56 | missing = set(tgt_state.keys()) - copied_names 57 | if len(missing) > 0: 58 | print("missing keys in state_dict:", missing) 59 | 60 | return model 61 | -------------------------------------------------------------------------------- /reid_metric.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.environ['OMP_NUM_THREADS'] = '1' 4 | import argparse 5 | import os.path as osp 6 | import datetime 7 | import torch 8 | import torch.nn as nn 9 | import torch.optim as optim 10 | from torch.utils.data import DataLoader 11 | from reid import datasets 12 | from reid.utils.draw_curve import draw_curve 13 | from reid.metric.metric_trainer import CNNTrainer 14 | from reid.metric.reid_feat_dataset import * 15 | from reid.metric.MLP_model import MLP_metric 16 | from reid.metric.metric_evaluate import metric_evaluate 17 | 18 | 19 | def main(args): 20 | # dataset path 21 | if args.logs_dir is None: 22 | args.logs_dir = osp.join(f'logs/metric/mlp/{args.dataset}') 23 | else: 24 | args.logs_dir = osp.join(f'logs/metric/mlp/{args.dataset}', args.logs_dir) 25 | os.makedirs(args.logs_dir, exist_ok=True) 26 | 27 | root = osp.expanduser('~/Data') 28 | if 'duke' in args.dataset: 29 | if 'tracking' in args.dataset: 30 | root = osp.join(root, 'DukeMTMC') 31 | else: 32 | root = osp.join(root, 'DukeMTMC-reID') 33 | elif 'aic' in args.dataset: 34 | if 'tracking' in args.dataset: 35 | root = osp.join(root, 'AIC19') 36 | else: 37 | root = osp.join(root, 'AIC19-reid') 38 | elif args.dataset == 'market1501': 39 | root = osp.join(root, 'Market1501') 40 | elif args.dataset == 'veri': 41 | root = osp.join(root, 'VeRi') 42 | else: 43 | raise Exception 44 | root += '/L0-features/' 45 | assert args.data_dir, 'Must provide data directory' 46 | train_dir = args.data_dir 47 | query_dir = args.data_dir.replace('trainval', 'query') 48 | gallery_dir = args.data_dir.replace('trainval', 'gallery') 49 | 50 | feat_trainset = HyperFeat(root + train_dir) 51 | feat_queryset = HyperFeat(root + query_dir) 52 | feat_galleryset = HyperFeat(root + gallery_dir) 53 | siamese_trainset = SiameseHyperFeat(feat_trainset) 54 | siamese_testset = SiameseHyperFeat(feat_galleryset) 55 | 56 | train_loader = DataLoader(siamese_trainset, batch_size=args.batch_size, 57 | num_workers=args.num_workers, pin_memory=True, shuffle=True) 58 | test_loader = DataLoader(siamese_testset, batch_size=args.batch_size, 59 | num_workers=args.num_workers, pin_memory=True) 60 | 61 | # model 62 | model = MLP_metric(feature_dim=siamese_trainset.feature_dim, num_class=2).cuda() 63 | optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) 64 | # scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 20, 1) 65 | scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=args.lr, 66 | steps_per_epoch=len(train_loader), epochs=args.epochs)
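    # OneCycleLR is configured for steps_per_epoch * epochs total steps, which assumes CNNTrainer steps it once per batch through the cyclic_scheduler argument below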
67 | 68 | trainer = CNNTrainer(model, nn.CrossEntropyLoss(), ) 69 | 70 | if args.train: 71 | # Draw Curve 72 | x_epoch = [] 73 | train_loss_s = [] 74 | train_prec_s = [] 75 | test_loss_s = [] 76 | test_prec_s = [] 77 | # reuse the optimizer created above: re-creating it here would detach it 78 | # from the OneCycleLR scheduler and leave the learning rate unscheduled 79 | for epoch in range(1, args.epochs + 1): 80 | train_loss, train_prec = trainer.train(epoch, train_loader, optimizer, cyclic_scheduler=scheduler) 81 | test_loss, test_prec = trainer.test(test_loader, ) 82 | x_epoch.append(epoch) 83 | train_loss_s.append(train_loss) 84 | train_prec_s.append(train_prec) 85 | test_loss_s.append(test_loss) 86 | test_prec_s.append(test_prec) 87 | draw_curve(args.logs_dir + '/MetricNet.jpg', x_epoch, train_loss_s, train_prec_s, 88 | None, test_loss_s, test_prec_s) 89 | pass 90 | torch.save({'state_dict': model.state_dict(), }, args.logs_dir + '/model.pth.tar') 91 | 92 | checkpoint = torch.load(args.logs_dir + '/model.pth.tar') 93 | model_dict = checkpoint['state_dict'] 94 | model.load_state_dict(model_dict) 95 | trainer.test(test_loader) 96 | metric_evaluate(model, feat_queryset, feat_galleryset) 97 | 98 | 99 | if __name__ == '__main__': 100 | # Training settings 101 | parser = argparse.ArgumentParser(description='Metric learning on top of re-ID features') 102 | parser.add_argument('--model', type=str, default='mlp', choices=['mlp', 'gcn']) 103 | parser.add_argument('-d', '--dataset', type=str, default='duke_reid', choices=datasets.names()) 104 | parser.add_argument('-b', '--batch-size', type=int, default=64, metavar='N', 105 | help='input batch size for training (default: 64)') 106 | parser.add_argument('-j', '--num-workers', type=int, default=1) 107 | parser.add_argument('--epochs', type=int, default=40, metavar='N') 108 | parser.add_argument('--lr', type=float, default=1e-3, metavar='LR') 109 | parser.add_argument('--combine-trainval', action='store_true', 110 | help="train and val sets together for training, val set alone for validation") 111 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', help='SGD momentum (default: 0.9)') 112 | parser.add_argument('--weight-decay', type=float, default=5e-4) 113 | parser.add_argument('--train', action='store_true') 114 | parser.add_argument('--resume', type=str, default=None, metavar='PATH') 115 | parser.add_argument('--log-interval', type=int, default=300, metavar='N', 116 | help='how many batches to wait before logging training status') 117 | 118 | parser.add_argument('--data-dir', type=str, default=None, metavar='PATH') 119 | parser.add_argument('--logs-dir', type=str, default=None, metavar='PATH') 120 | args = parser.parse_args() 121 | main(args) 122 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | scikit-learn 2 | h5py 3 | matplotlib 4 | pandas 5 | torch 6 | torchvision 7 | -------------------------------------------------------------------------------- /save_cnn_feature.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | 3 | import argparse 4 | import json 5 | import os 6 | import re 7 | import time 8 | 9 | import h5py 10 | import numpy as np 11 | import torch 12 | from torch.backends import cudnn 13 | 14 | from reid import models 15 | from reid.datasets import * 16 | from reid.feature_extraction import
extract_cnn_feature 17 | from reid.utils.meters import AverageMeter 18 | from reid.utils.get_loaders import * 19 | 20 | 21 | def save_file(lines, args, root, if_created): 22 | # write file 23 | if args.data_type == 'tracking_det': 24 | folder_name = root + f"/L0-features/det_{args.det_time}_features_{args.model}_{args.resume}" 25 | if args.dataset == 'aic': 26 | folder_name += f'_{args.det_type}' 27 | elif args.data_type == 'reid': 28 | folder_name = root + f"/L0-features/reid_trainval_features_{args.model}_{args.resume}" 29 | elif args.data_type == 'tracking_gt': # only extract ground truth data from 'train' set 30 | folder_name = root + f"/L0-features/gt_{args.det_time}_features_{args.model}_{args.resume}" 31 | elif args.data_type == 'reid_test': # reid_test: query/gallery 32 | folder_name = root + f"/L0-features/reid_{args.reid_test}_features_{args.model}_{args.resume}" 33 | else: 34 | raise Exception 35 | 36 | if args.re: 37 | folder_name += '_RE' 38 | if args.crop: 39 | folder_name += '_CROP' 40 | 41 | os.makedirs(folder_name, exist_ok=True) 42 | with open(osp.join(folder_name, 'args.json'), 'w') as fp: 43 | json.dump(vars(args), fp, indent=1) 44 | for cam in range(len(lines)): 45 | output_fname = folder_name + '/features%d.h5' % (cam + 1) 46 | if args.tracking_icams != 0 and cam + 1 != args.tracking_icams and args.tracking_icams is not None: 47 | continue 48 | if not lines[cam]: 49 | continue 50 | 51 | if not if_created[cam]: 52 | with h5py.File(output_fname, 'w') as f: 53 | mat_data = np.vstack(lines[cam]) 54 | f.create_dataset('emb', data=mat_data, dtype=float, maxshape=(None, None)) 55 | pass 56 | if_created[cam] = 1 57 | else: 58 | with h5py.File(output_fname, 'a') as f: 59 | mat_data = np.vstack(lines[cam]) 60 | f['emb'].resize((f['emb'].shape[0] + mat_data.shape[0]), axis=0) 61 | f['emb'][-mat_data.shape[0]:] = mat_data 62 | pass 63 | 64 | return if_created 65 | 66 | 67 | def extract_n_save(model, data_loader, args, root, num_cams, is_detection=True, use_fname=True, gt_type='reid'): 68 | model.eval() 69 | print_freq = 1000 70 | batch_time = AverageMeter() 71 | data_time = AverageMeter() 72 | 73 | if_created = [0 for _ in range(num_cams)] 74 | lines = [[] for _ in range(num_cams)] 75 | 76 | end = time.time() 77 | for i, (imgs, fnames, pids, cams) in enumerate(data_loader): 78 | cams += 1 79 | outputs = extract_cnn_feature(model, imgs) 80 | for fname, output, pid, cam in zip(fnames, outputs, pids, cams): 81 | if is_detection: 82 | pattern = re.compile(r'c(\d+)_f(\d+)') 83 | cam, frame = map(int, pattern.search(fname).groups()) 84 | # f_names[cam - 1].append(fname) 85 | # features[cam - 1].append(output.numpy()) 86 | line = np.concatenate([np.array([cam, 0, frame]), output.numpy()]) 87 | else: 88 | if use_fname: 89 | pattern = re.compile(r'(\d+)_c(\d+)_f(\d+)') 90 | pid, cam, frame = map(int, pattern.search(fname).groups()) 91 | else: 92 | cam, pid = cam.numpy(), pid.numpy() 93 | frame = -1 * np.ones_like(pid) 94 | # line = output.numpy() 95 | line = np.concatenate([np.array([cam, pid, frame]), output.numpy()]) 96 | lines[cam - 1].append(line) 97 | batch_time.update(time.time() - end) 98 | end = time.time() 99 | 100 | if (i + 1) % print_freq == 0: 101 | print('Extract Features: [{}/{}]\t' 102 | 'Time {:.3f} ({:.3f})\t' 103 | 'Data {:.3f} ({:.3f})\t' 104 | .format(i + 1, len(data_loader), 105 | batch_time.val, batch_time.avg, 106 | data_time.val, data_time.avg)) 107 | 108 | if_created = save_file(lines, args, root, if_created) 109 | 110 | lines = [[] for _ in range(num_cams)] 111 | 
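    # final flush: write out any features accumulated since the last periodic save above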
112 | save_file(lines, args, root, if_created) 113 | return 114 | 115 | 116 | def main(args): 117 | # seed 118 | if args.seed is not None: 119 | np.random.seed(args.seed) 120 | torch.manual_seed(args.seed) 121 | torch.backends.cudnn.deterministic = True 122 | torch.backends.cudnn.benchmark = False 123 | else: 124 | torch.backends.cudnn.benchmark = True 125 | 126 | tic = time.time() 127 | # default to all cameras when no single iCam is specified (mirrors get_loaders) 128 | tracking_icams = [args.tracking_icams] if args.tracking_icams else list(range(1, 9)) 129 | 130 | if args.data_type == 'tracking_det': 131 | if args.dataset == 'duke_tracking': 132 | dataset_dir = osp.join(args.data_dir, 'DukeMTMC', 'ALL_det_bbox', f'det_bbox_OpenPose_{args.det_time}') 133 | elif args.dataset == 'aic_tracking': 134 | dataset_dir = osp.join(args.data_dir, 'AIC19', 'ALL_det_bbox', 135 | f'det_bbox_{args.det_type}_{args.det_time}', ) 136 | fps = None 137 | use_fname = True 138 | elif args.data_type == 'reid': 139 | # args.det_time = 'trainval' 140 | dataset_dir = None 141 | fps = 1 142 | use_fname = False 143 | elif args.data_type == 'tracking_gt': 144 | if args.dataset == 'aic': 145 | args.det_time = 'trainval' 146 | dataset_dir = None 147 | fps = 60 if args.dataset == 'duke' else 10 148 | use_fname = True 149 | elif args.data_type == 'reid_test': # reid_test 150 | dataset_dir = None 151 | fps = 1 152 | use_fname = False 153 | else: 154 | raise Exception 155 | 156 | print(dataset_dir) 157 | if args.dataset == 'duke_tracking': 158 | dataset = DukeMTMC(dataset_dir, data_type=args.data_type, iCams=tracking_icams, fps=fps, 159 | trainval=args.det_time == 'trainval') 160 | elif args.dataset == 'aic_tracking': # aic 161 | dataset = AI_City(dataset_dir, data_type=args.data_type, fps=fps, trainval=args.det_time == 'trainval', 162 | gt_type=args.gt_type) 163 | else: 164 | dataset = datasets.create(args.dataset, args.data_dir) 165 | 166 | normalizer = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 167 | test_transformer = T.Compose([ 168 | T.Resize([args.height, args.width]), 169 | T.Pad(10 * args.crop), 170 | T.RandomCrop([args.height, args.width]), 171 | T.ToTensor(), 172 | normalizer, 173 | T.RandomErasing(probability=args.re), ]) 174 | # Create model 175 | model = models.create(args.model, feature_dim=args.features, num_classes=0, norm=args.norm, 176 | dropout=args.dropout, last_stride=args.last_stride, arch=args.arch) 177 | # Load from checkpoint 178 | assert args.resume, 'must provide resume directory' 179 | resume_fname = osp.join(f'logs/{args.model}/{args.dataset}', args.resume, 'model_best.pth.tar') 180 | model, start_epoch, best_top1 = checkpoint_loader(model, resume_fname) 181 | print(f"=> Last epoch {start_epoch}") 182 | model = nn.DataParallel(model).cuda() 183 | model.eval() 184 | toc = time.time() - tic 185 | print('*************** initialization takes time: {:^10.2f} *********************\n'.format(toc)) 186 | 187 | tic = time.time() 188 | if args.data_type == 'reid_test': 189 | args.reid_test = 'query' 190 | data_loader = DataLoader(Preprocessor(dataset.query, root=dataset.query_path, transform=test_transformer), 191 | batch_size=args.batch_size, num_workers=args.num_workers, shuffle=False) 192 | extract_n_save(model, data_loader, args, dataset.root, dataset.num_cams, 193 | is_detection=False, use_fname=use_fname) 194 | args.reid_test = 'gallery' 195 | data_loader = DataLoader(Preprocessor(dataset.gallery, root=dataset.gallery_path, transform=test_transformer), 196 | batch_size=args.batch_size, num_workers=args.num_workers, shuffle=False) 197 | extract_n_save(model, data_loader,
args, dataset.root, dataset.num_cams, 198 | is_detection=False, use_fname=use_fname) 199 | else: 200 | data_loader = DataLoader(Preprocessor(dataset.train, root=dataset.train_path, transform=test_transformer), 201 | batch_size=args.batch_size, num_workers=args.num_workers, shuffle=False) 202 | extract_n_save(model, data_loader, args, dataset.root, dataset.num_cams, 203 | is_detection=args.data_type == 'tracking_det', use_fname=use_fname) 204 | toc = time.time() - tic 205 | print('*************** compute features takes time: {:^10.2f} *********************\n'.format(toc)) 206 | pass 207 | 208 | 209 | if __name__ == '__main__': 210 | parser = argparse.ArgumentParser(description="Save re-ID features") 211 | # data 212 | parser.add_argument('--model', type=str, default='ide', choices=models.names()) 213 | parser.add_argument('-a', '--arch', type=str, default='resnet50', choices=['resnet50', 'densenet121'], 214 | help='architecture for base network') 215 | parser.add_argument('-d', '--dataset', type=str, default='duke', choices=datasets.names()) 216 | parser.add_argument('--data_type', type=str, default='reid', 217 | choices=['tracking_det', 'reid', 'tracking_gt', 'reid_test']) 218 | parser.add_argument('-b', '--batch-size', type=int, default=64, help="batch size") 219 | parser.add_argument('-j', '--num-workers', type=int, default=4) 220 | parser.add_argument('--height', type=int, default=256, help="input height, default: 256 for resnet*") 221 | parser.add_argument('--width', type=int, default=128, help="input width, default: 128 for resnet*") 222 | # model 223 | parser.add_argument('--resume', type=str, default=None, metavar='PATH') 224 | parser.add_argument('--features', type=int, default=256) 225 | parser.add_argument('--dropout', type=float, default=0.5, help='0.5 for ide/pcb, 0 for triplet/zju') 226 | parser.add_argument('-s', '--last_stride', type=int, default=2, choices=[1, 2]) 227 | parser.add_argument('--norm', action='store_true', help="normalize feat, default: False") 228 | # misc 229 | parser.add_argument('--data-dir', type=str, metavar='PATH', default=osp.expanduser('~/Data')) 230 | parser.add_argument('--logs-dir', type=str, metavar='PATH', default=None) 231 | parser.add_argument('--det_time', type=str, metavar='PATH', default='val', 232 | choices=['trainval_nano', 'trainval', 'train', 'val', 'test_all', 'test']) 233 | parser.add_argument('--det_type', type=str, default='ssd', choices=['ssd', 'yolo']) 234 | parser.add_argument('--gt_type', type=str, default='gt', choices=['gt', 'labeled']) 235 | parser.add_argument('--tracking_icams', type=int, default=None, help="specify if train on single iCam") 236 | parser.add_argument('--seed', type=int, default=None) 237 | # data jittering 238 | parser.add_argument('--re', type=float, default=0, help="random erasing") 239 | parser.add_argument('--crop', action='store_true', help="resize then crop, default: False") 240 | main(parser.parse_args()) 241 | -------------------------------------------------------------------------------- /triplet.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.environ['OMP_NUM_THREADS'] = '1' 4 | import argparse 5 | import datetime 6 | import sys 7 | import shutil 8 | from distutils.dir_util import copy_tree 9 | import numpy as np 10 | import torch 11 | from reid import models 12 | from reid.evaluators import Evaluator 13 | from reid.loss import TripletLoss 14 | from reid.trainers import Trainer 15 | from reid.utils.logger import Logger 16 | from 
reid.utils.draw_curve import * 17 | from reid.utils.get_loaders import * 18 | from reid.utils.serialization import save_checkpoint 19 | 20 | ''' 21 | triplet loss 22 | ''' 23 | 24 | 25 | def main(args): 26 | # seed 27 | if args.seed is not None: 28 | np.random.seed(args.seed) 29 | torch.manual_seed(args.seed) 30 | torch.backends.cudnn.deterministic = True 31 | torch.backends.cudnn.benchmark = False 32 | else: 33 | torch.backends.cudnn.benchmark = True 34 | 35 | if args.logs_dir is None: 36 | args.logs_dir = osp.join(f'logs/triplet/{args.dataset}', datetime.datetime.today().strftime('%Y-%m-%d_%H-%M-%S')) 37 | else: 38 | args.logs_dir = osp.join(f'logs/triplet/{args.dataset}', args.logs_dir) 39 | if args.train: 40 | os.makedirs(args.logs_dir, exist_ok=True) 41 | copy_tree('./reid', args.logs_dir + '/scripts/reid') 42 | for script in os.listdir('.'): 43 | if script.split('.')[-1] == 'py': 44 | dst_file = os.path.join(args.logs_dir, 'scripts', os.path.basename(script)) 45 | shutil.copyfile(script, dst_file) 46 | sys.stdout = Logger(os.path.join(args.logs_dir, 'log.txt'), ) 47 | print('Settings:') 48 | print(vars(args)) 49 | print('\n') 50 | 51 | # Create data loaders 52 | assert args.num_instances > 1, "num_instances should be larger than 1" 53 | assert args.batch_size % args.num_instances == 0, 'num_instances should divide batch_size' 54 | dataset, num_classes, train_loader, query_loader, gallery_loader, _ = \ 55 | get_data(args.dataset, args.data_dir, args.height, args.width, args.batch_size, args.num_workers, 56 | args.combine_trainval, args.crop, args.tracking_icams, args.tracking_fps, args.re, args.num_instances, 57 | False) 58 | 59 | # Create model for triplet (num_classes = 0, num_instances > 0) 60 | model = models.create('ide', feature_dim=args.feature_dim, num_classes=0, norm=args.norm, 61 | dropout=args.dropout, last_stride=args.last_stride) 62 | 63 | # Load from checkpoint 64 | start_epoch = best_top1 = 0 65 | if args.resume: 66 | resume_fname = osp.join(f'logs/triplet/{args.dataset}', args.resume, 'model_best.pth.tar') 67 | model, start_epoch, best_top1 = checkpoint_loader(model, resume_fname) 68 | print("=> Last epoch {} best top1 {:.1%}".format(start_epoch, best_top1)) 69 | start_epoch += 1 70 | model = nn.DataParallel(model).cuda() 71 | 72 | # Criterion 73 | criterion = TripletLoss(margin=args.margin).cuda() 74 | 75 | # Optimizer 76 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) 77 | 78 | # Trainer 79 | trainer = Trainer(model, criterion) 80 | 81 | # Evaluator 82 | evaluator = Evaluator(model) 83 | 84 | if args.train: 85 | # Schedule learning rate 86 | def adjust_lr(epoch): 87 | if epoch <= args.step_size: 88 | lr = args.lr 89 | else: 90 | lr = args.lr * (0.001 ** (float(epoch - args.step_size) / (args.epochs - args.step_size))) 91 | for g in optimizer.param_groups: 92 | g['lr'] = lr * g.get('lr_mult', 1) 93 | 94 | # Draw Curve 95 | epoch_s = [] 96 | loss_s = [] 97 | prec_s = [] 98 | eval_epoch_s = [] 99 | eval_top1_s = [] 100 | 101 | # Start training 102 | for epoch in range(start_epoch + 1, args.epochs + 1): 103 | adjust_lr(epoch) 104 | # train_loss, train_prec = 0, 0 105 | train_loss, train_prec = trainer.train(epoch, train_loader, optimizer, fix_bn=args.fix_bn) 106 | 107 | if epoch < args.start_save: 108 | continue 109 | 110 | if epoch % 25 == 0: 111 | top1 = evaluator.evaluate(query_loader, gallery_loader, dataset.query, dataset.gallery) 112 | eval_epoch_s.append(epoch) 113 | eval_top1_s.append(top1) 114 | else: 115 | top1 = 0 
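            # epochs without an evaluation report top1 = 0, so they can only be flagged as 'best' while best_top1 is still 0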
116 | 117 | is_best = top1 >= best_top1 118 | best_top1 = max(top1, best_top1) 119 | save_checkpoint({ 120 | 'state_dict': model.module.state_dict(), 121 | 'epoch': epoch, 122 | 'best_top1': best_top1, 123 | }, is_best, fpath=osp.join(args.logs_dir, 'checkpoint.pth.tar')) 124 | epoch_s.append(epoch) 125 | loss_s.append(train_loss) 126 | prec_s.append(train_prec) 127 | draw_curve(os.path.join(args.logs_dir, 'train_curve.jpg'), epoch_s, loss_s, prec_s, 128 | eval_epoch_s, None, eval_top1_s) 129 | pass 130 | 131 | # Final test 132 | print('Test with best model:') 133 | model, start_epoch, best_top1 = checkpoint_loader(model, osp.join(args.logs_dir, 'model_best.pth.tar')) 134 | print("=> Start epoch {} best top1 {:.1%}".format(start_epoch, best_top1)) 135 | 136 | evaluator.evaluate(query_loader, gallery_loader, dataset.query, dataset.gallery) 137 | else: 138 | print("Test:") 139 | evaluator.evaluate(query_loader, gallery_loader, dataset.query, dataset.gallery) 140 | pass 141 | 142 | 143 | if __name__ == '__main__': 144 | parser = argparse.ArgumentParser(description="Triplet loss classification") 145 | # data 146 | parser.add_argument('-d', '--dataset', type=str, default='market1501', choices=datasets.names()) 147 | parser.add_argument('-b', '--batch-size', type=int, default=128, help="batch size") 148 | parser.add_argument('-j', '--num-workers', type=int, default=4) 149 | parser.add_argument('--height', type=int, default=256, help="input height, default: 256 for resnet*") 150 | parser.add_argument('--width', type=int, default=128, help="input width, default: 128 for resnet*") 151 | parser.add_argument('--combine-trainval', action='store_true', 152 | help="train and val sets together for training, val set alone for validation") 153 | parser.add_argument('--tracking_icams', type=int, default=0, help="specify if train on single iCam") 154 | parser.add_argument('--tracking_fps', type=int, default=1, help="fps at which tracking ground truth is sampled") 155 | parser.add_argument('--re', type=float, default=0, help="random erasing") 156 | parser.add_argument('--crop', type=int, default=1, help="resize then crop, default: 1") 157 | # model 158 | parser.add_argument('--feature_dim', type=int, default=256) 159 | parser.add_argument('--dropout', type=float, default=0) 160 | parser.add_argument('-s', '--last_stride', type=int, default=2, choices=[1, 2]) 161 | parser.add_argument('--norm', action='store_true', help="normalize feat, default: False") 162 | parser.add_argument('--arch', type=str, default='resnet50', choices=['resnet50', 'densenet121'], 163 | help='architecture for base network') 164 | # loss 165 | parser.add_argument('--margin', type=float, default=0.3, help="margin of the triplet loss, default: 0.3") 166 | parser.add_argument('--num-instances', type=int, default=4, 167 | help="each minibatch consists of " 168 | "(batch_size // num_instances) identities, and " 169 | "each identity has num_instances instances, " 170 | "default: 4") 171 | # optimizer 172 | parser.add_argument('--lr', type=float, default=2e-4, help="learning rate of ALL parameters") 173 | parser.add_argument('--weight-decay', type=float, default=5e-4) 174 | # training configs 175 | parser.add_argument('--train', action='store_true', help="train the model from scratch") 176 | parser.add_argument('--fix_bn', type=int, default=0, help="fix (skip training) BN in base network") 177 | parser.add_argument('--resume', type=str, default=None, metavar='PATH') 178 | parser.add_argument('--epochs', type=int, default=300) 179 |
parser.add_argument('--step-size', type=int, default=150) 180 | parser.add_argument('--start_save', type=int, default=0, help="start saving checkpoints after specific epoch") 181 | parser.add_argument('--seed', type=int, default=None) 182 | parser.add_argument('--print-freq', type=int, default=10) 183 | # misc 184 | parser.add_argument('--data-dir', type=str, metavar='PATH', default=osp.expanduser('~/Data')) 185 | parser.add_argument('--logs-dir', type=str, metavar='PATH', default=None) 186 | main(parser.parse_args()) 187 | --------------------------------------------------------------------------------