├── LICENSE ├── README.md ├── conda.yaml ├── configs ├── cuhk_detected.yaml ├── cuhk_labeled.yaml ├── duke.yaml └── market.yaml ├── default_config.py ├── exp_FPB.py ├── requirements.txt └── torchreid ├── __init__.py ├── data ├── __init__.py ├── datamanager.py ├── datasets │ ├── __init__.py │ ├── dataset.py │ ├── image │ │ ├── __init__.py │ │ ├── cuhk01.py │ │ ├── cuhk02.py │ │ ├── cuhk03.py │ │ ├── cuhk03_detected.py │ │ ├── cuhk03_labeled.py │ │ ├── dukemtmcreid.py │ │ ├── grid.py │ │ ├── ilids.py │ │ ├── market1501.py │ │ ├── msmt17.py │ │ ├── prid.py │ │ ├── sensereid.py │ │ └── viper.py │ └── video │ │ ├── __init__.py │ │ ├── dukemtmcvidreid.py │ │ ├── ilidsvid.py │ │ ├── mars.py │ │ └── prid2011.py ├── sampler.py └── transforms.py ├── engine ├── __init__.py ├── engine.py ├── engine_vis.py ├── image │ ├── __init__.py │ ├── engine_FPB.py │ ├── softmax.py │ └── triplet.py └── video │ ├── __init__.py │ ├── softmax.py │ └── triplet.py ├── losses ├── __init__.py ├── center_loss.py ├── cross_entropy_loss.py ├── hard_mine_triplet_loss.py └── ranked_loss.py ├── metrics ├── __init__.py ├── accuracy.py ├── distance.py ├── rank.py └── rank_cylib │ ├── Makefile │ ├── __init__.py │ ├── rank_cy.pyx │ ├── setup.py │ └── test_cython.py ├── models ├── __init__.py ├── fpb.py ├── nn_utils.py └── pc.py ├── optim ├── __init__.py ├── lr_scheduler.py ├── optimizer.py └── warm_up.py └── utils ├── __init__.py ├── avgmeter.py ├── loggers.py ├── model_complexity.py ├── reidtools.py ├── rerank.py ├── tools.py └── torchtools.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Suofei Zhang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FPB 2 | 3 | To create exceutable environment with anaconda: 4 | ``` 5 | conda env create -f conda.yaml 6 | conda activate reid 7 | ``` 8 | 9 | To install required python packages with pip: 10 | ``` 11 | pip install -r requirements.txt 12 | ``` 13 | 14 | To train and evaluate FPB model on Market1501 dataset: 15 | ``` 16 | python exp_FPB.py --config-file ./configs/market.yaml --root /path/to/your/data 17 | ``` 18 | 19 | To train and evaluate FPB model on CUHK03_Detected dataset: 20 | ``` 21 | python exp_FPB.py --config-file ./configs/cuhk_detected.yaml --root /path/to/your/data 22 | ``` 23 | 24 | To train and evaluate FPB model on CUHK03_Labeled dataset: 25 | ``` 26 | python exp_FPB.py --config-file ./configs/cuhk_labeled.yaml --root /path/to/your/data 27 | ``` 28 | 29 | To train and evaluate FPB model on DukeMTMC dataset: 30 | ``` 31 | python exp_FPB.py --config-file ./configs/duke.yaml --root /path/to/your/data 32 | ``` 33 | 34 | All records and log file are in the `./log/` directory. -------------------------------------------------------------------------------- /conda.yaml: -------------------------------------------------------------------------------- 1 | name: reid 2 | channels: 3 | - defaults 4 | dependencies: 5 | - ca-certificates=2019.1.23=0 6 | - certifi=2019.3.9=py36_0 7 | - libedit=3.1.20181209=hc058e9b_0 8 | - libffi=3.2.1=hd88cf55_4 9 | - libgcc-ng=8.2.0=hdf63c60_1 10 | - libstdcxx-ng=8.2.0=hdf63c60_1 11 | - ncurses=6.1=he6710b0_1 12 | - openssl=1.1.1b=h7b6447c_1 13 | - pip=19.1.1=py36_0 14 | - python=3.6.8=h0371630_0 15 | - readline=7.0=h7b6447c_5 16 | - setuptools=41.0.1=py36_0 17 | - sqlite=3.28.0=h7b6447c_0 18 | - tk=8.6.8=hbc83047_0 19 | - wheel=0.33.4=py36_0 20 | - xz=5.2.4=h14c3975_4 21 | - zlib=1.2.11=h7b6447c_3 22 | - pip: 23 | - absl-py==0.8.0 24 | - astroid==2.3.3 25 | - attrs==19.1.0 26 | - backcall==0.1.0 27 | - beautifulsoup4==4.8.1 28 | - bleach==3.1.0 29 | - chardet==3.0.4 30 | - click==7.0 31 | - configargparse==0.14.0 32 | - cycler==0.10.0 33 | - cython==0.29.12 34 | - cython-based-reid-evaluation-code==0.0.0 35 | - decorator==4.4.0 36 | - defusedxml==0.6.0 37 | - dropblock==0.3.0 38 | - entrypoints==0.3 39 | - enum34==1.1.6 40 | - et-xmlfile==1.0.1 41 | - filelock==3.0.12 42 | - future==0.17.1 43 | - gdown==3.8.3 44 | - grpcio==1.23.0 45 | - h5py==2.10.0 46 | - html5lib==1.0.1 47 | - idna==2.8 48 | - ignite==1.0.0 49 | - ipykernel==5.1.1 50 | - ipython==7.5.0 51 | - ipython-genutils==0.2.0 52 | - ipywidgets==7.4.2 53 | - isort==4.3.21 54 | - jdcal==1.4.1 55 | - jedi==0.13.3 56 | - jinja2==2.10.1 57 | - joblib==0.14.0 58 | - json5==0.8.5 59 | - jsonschema==3.0.1 60 | - jupyter==1.0.0 61 | - jupyter-client==5.2.4 62 | - jupyter-console==6.0.0 63 | - jupyter-core==4.4.0 64 | - jupyterlab==1.1.4 65 | - jupyterlab-server==1.0.6 66 | - kiwisolver==1.1.0 67 | - lazy-object-proxy==1.4.3 68 | - lxml==4.4.1 69 | - markdown==3.1.1 70 | - markupsafe==1.1.1 71 | - matplotlib==3.1.0 72 | - mccabe==0.6.1 73 | - mistune==0.8.4 74 | - mock==3.0.5 75 | - navi==0.1.0a19 76 | - nbconvert==5.5.0 77 | - nbformat==4.4.0 78 | - notebook==5.7.8 79 | - numexpr==2.7.0 80 | - numpy==1.16.3 81 | - opencv-python==4.1.1.26 82 | - openpyxl==3.0.0 83 | - pandas==0.25.1 84 | - pandas-datareader==0.8.1 85 | - pandocfilters==1.4.2 86 | - parso==0.4.0 87 | - pexpect==4.7.0 88 | - pickleshare==0.7.5 89 | - pillow==6.0.0 90 | - prometheus-client==0.6.0 91 | - prompt-toolkit==2.0.9 92 | - protobuf==3.9.1 93 | - ptyprocess==0.6.0 94 | - pydispatcher==2.0.5 95 | - pygments==2.4.0 96 | - pylint==2.4.4 97 | - pyparsing==2.4.0 98 | - pyrsistent==0.15.2 99 | - python-dateutil==2.8.0 100 | - pytorch-ignite==0.1.1 101 | - pytz==2019.3 102 | - pyyaml==5.1.2 103 | - pyzmq==18.0.1 104 | - qtconsole==4.4.4 105 | - requests==2.22.0 106 | - scikit-learn==0.21.3 107 | - scipy==1.2.1 108 | - send2trash==1.5.0 109 | - six==1.12.0 110 | - soupsieve==1.9.4 111 | - sqlalchemy==1.3.10 112 | - tables==3.5.2 113 | - tb-nightly==2.0.0a20190917 114 | - terminado==0.8.2 115 | - testpath==0.4.2 116 | - torch==1.1.0 117 | - torchfile==0.1.0 118 | - torchsummary==1.5.1 119 | - torchvision==0.3.0 120 | - tornado==6.0.3 121 | - tqdm==4.35.0 122 | - traitlets==4.3.2 123 | - typed-ast==1.4.0 124 | - urllib3==1.25.2 125 | - visdom==0.1.8.8 126 | - wcwidth==0.1.7 127 | - webencodings==0.5.1 128 | - websocket-client==0.56.0 129 | - werkzeug==0.15.6 130 | - widgetsnbextension==3.4.2 131 | - wrapt==1.11.2 132 | - xlrd==1.2.0 133 | - yacs==0.1.6 134 | prefix: /home/user/anaconda3/envs/reid 135 | 136 | -------------------------------------------------------------------------------- /configs/cuhk_detected.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | name: 'fpb' 3 | pretrained: True 4 | 5 | data: 6 | type: 'image' 7 | sources: ['cuhk03_detected'] 8 | targets: ['cuhk03_detected'] 9 | height: 384 10 | width: 128 11 | combineall: False 12 | transforms: ['random_flip','random_erase','random_crop','random_patch'] 13 | save_dir: 'log' 14 | exp_name: 'cuhk_detected' 15 | 16 | sampler: 17 | train_sampler: 'RandomIdentitySampler' 18 | num_instances: 4 19 | 20 | loss: 21 | name: 'engine_FPB' 22 | softmax: 23 | label_smooth: True 24 | triplet: 25 | weight_t: 1. 26 | weight_x: 1. 27 | div_reg: True 28 | div_reg_start: 0 #50 29 | 30 | train: 31 | optim: 'adam' 32 | lr: 0.000035 33 | max_epoch: 120 34 | batch_size: 128 35 | lr_scheduler: 'warmup' 36 | warmup_multiplier: 10 37 | warmup_total_epoch: 19 38 | stepsize: [40, 70] 39 | 40 | cuhk03: 41 | labeled_images: False 42 | use_metric_cuhk03: False 43 | 44 | test: 45 | batch_size: 100 46 | dist_metric: 'euclidean' 47 | normalize_feature: False 48 | evaluate: False 49 | eval_freq: 1 50 | start_eval: 60 51 | rerank: False 52 | visactmap: False 53 | -------------------------------------------------------------------------------- /configs/cuhk_labeled.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | name: 'fpb' 3 | pretrained: True 4 | 5 | data: 6 | type: 'image' 7 | sources: ['cuhk03_labeled'] 8 | targets: ['cuhk03_labeled'] 9 | height: 384 10 | width: 128 11 | combineall: False 12 | transforms: ['random_flip','random_erase','random_crop','random_patch'] 13 | save_dir: 'log' 14 | exp_name: 'cuhk_labeled' 15 | 16 | sampler: 17 | train_sampler: 'RandomIdentitySampler' 18 | num_instances: 4 19 | 20 | loss: 21 | name: 'engine_FPB' 22 | softmax: 23 | label_smooth: True 24 | triplet: 25 | weight_t: 1. 26 | weight_x: 1. 27 | div_reg: True 28 | div_reg_start: 0 #50 29 | 30 | train: 31 | optim: 'adam' 32 | lr: 0.000035 33 | max_epoch: 120 34 | batch_size: 128 35 | lr_scheduler: 'warmup' 36 | warmup_multiplier: 10 37 | warmup_total_epoch: 19 38 | stepsize: [40, 70] 39 | 40 | cuhk03: 41 | labeled_images: True 42 | use_metric_cuhk03: False 43 | 44 | test: 45 | batch_size: 100 46 | dist_metric: 'euclidean' 47 | normalize_feature: False 48 | evaluate: False 49 | eval_freq: 1 50 | start_eval: 60 51 | rerank: False 52 | visactmap: False 53 | -------------------------------------------------------------------------------- /configs/duke.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | name: 'fpb' 3 | pretrained: True 4 | 5 | data: 6 | type: 'image' 7 | sources: ['dukemtmcreid'] 8 | targets: ['dukemtmcreid'] 9 | height: 384 10 | width: 128 11 | combineall: False 12 | transforms: ['random_flip','random_erase','random_crop','random_patch'] 13 | save_dir: 'log' 14 | exp_name: 'duke' 15 | 16 | sampler: 17 | train_sampler: 'RandomIdentitySampler' 18 | num_instances: 4 19 | 20 | loss: 21 | name: 'engine_FPB' 22 | softmax: 23 | label_smooth: True 24 | triplet: 25 | weight_t: 1. 26 | weight_x: 1. 27 | div_reg: True 28 | div_reg_start: 0 #50 29 | 30 | train: 31 | optim: 'adam' 32 | lr: 0.000035 33 | max_epoch: 120 34 | batch_size: 128 35 | lr_scheduler: 'warmup' 36 | warmup_multiplier: 10 37 | warmup_total_epoch: 19 38 | stepsize: [40, 70] 39 | 40 | 41 | test: 42 | batch_size: 100 43 | dist_metric: 'euclidean' 44 | normalize_feature: False 45 | evaluate: False 46 | eval_freq: 1 47 | start_eval: 60 48 | rerank: False 49 | visactmap: False 50 | -------------------------------------------------------------------------------- /configs/market.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | name: 'fpb' 3 | pretrained: True 4 | 5 | data: 6 | type: 'image' 7 | sources: ['market1501'] 8 | targets: ['market1501'] 9 | height: 384 10 | width: 128 11 | combineall: False 12 | transforms: ['random_flip','random_erase','random_crop','random_patch'] 13 | save_dir: 'log' 14 | exp_name: 'market1501' 15 | 16 | sampler: 17 | train_sampler: 'RandomIdentitySampler' 18 | num_instances: 4 19 | 20 | loss: 21 | name: 'engine_FPB' 22 | softmax: 23 | label_smooth: True 24 | triplet: 25 | weight_t: 1. 26 | weight_x: 1. 27 | div_reg: True 28 | div_reg_start: 0 #50 29 | 30 | train: 31 | optim: 'adam' 32 | lr: 0.000035 33 | max_epoch: 120 34 | batch_size: 64 35 | lr_scheduler: 'warmup' 36 | warmup_multiplier: 10 37 | warmup_total_epoch: 19 38 | stepsize: [40, 70] 39 | 40 | 41 | test: 42 | batch_size: 100 43 | dist_metric: 'euclidean' 44 | normalize_feature: False 45 | evaluate: False 46 | eval_freq: 1 47 | start_eval: 60 48 | rerank: False 49 | visactmap: False 50 | -------------------------------------------------------------------------------- /default_config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from yacs.config import CfgNode as CN 3 | 4 | 5 | def get_default_config(): 6 | cfg = CN() 7 | 8 | # model 9 | cfg.model = CN() 10 | cfg.model.name = 'resnet50' 11 | cfg.model.pretrained = True # automatically load pretrained model weights if available 12 | cfg.model.with_attention = True 13 | cfg.model.load_weights = '' # path to model weights 14 | cfg.model.resume = '' # path to checkpoint for resume training 15 | 16 | # data 17 | cfg.data = CN() 18 | cfg.data.type = 'image' 19 | cfg.data.root = 'reid-data' 20 | cfg.data.sources = ['market1501'] 21 | cfg.data.targets = ['market1501'] 22 | cfg.data.workers = 4 # number of data loading workers 23 | cfg.data.split_id = 0 # split index 24 | cfg.data.height = 256 # image height 25 | cfg.data.width = 128 # image width 26 | cfg.data.combineall = False # combine train, query and gallery for training 27 | cfg.data.transforms = ['random_flip'] # data augmentation 28 | cfg.data.norm_mean = [0.485, 0.456, 0.406] # default is imagenet mean 29 | cfg.data.norm_std = [0.229, 0.224, 0.225] # default is imagenet std 30 | cfg.data.save_dir = 'log' # path to save log 31 | cfg.data.exp_name = 'test' 32 | 33 | # specific datasets 34 | cfg.market1501 = CN() 35 | cfg.market1501.use_500k_distractors = False # add 500k distractors to the gallery set for market1501 36 | cfg.cuhk03 = CN() 37 | cfg.cuhk03.labeled_images = True # use labeled images, if False, use detected images 38 | cfg.cuhk03.classic_split = False # use classic split by Li et al. CVPR14 39 | cfg.cuhk03.use_metric_cuhk03 = False # use cuhk03's metric for evaluation 40 | 41 | # sampler 42 | cfg.sampler = CN() 43 | cfg.sampler.train_sampler = 'RandomSampler' 44 | cfg.sampler.num_instances = 4 # number of instances per identity for RandomIdentitySampler 45 | 46 | # video reid setting 47 | cfg.video = CN() 48 | cfg.video.seq_len = 15 # number of images to sample in a tracklet 49 | cfg.video.sample_method = 'evenly' # how to sample images from a tracklet 50 | cfg.video.pooling_method = 'avg' # how to pool features over a tracklet 51 | 52 | # train 53 | cfg.train = CN() 54 | cfg.train.optim = 'adam' 55 | cfg.train.lr = 0.0003 56 | cfg.train.weight_decay = 5e-4 57 | cfg.train.max_epoch = 60 58 | cfg.train.start_epoch = 0 59 | cfg.train.batch_size = 32 60 | cfg.train.fixbase_epoch = 0 # number of epochs to fix base layers 61 | cfg.train.open_layers = ['classifier'] # layers for training while keeping others frozen 62 | cfg.train.staged_lr = False # set different lr to different layers 63 | cfg.train.new_layers = ['classifier'] # newly added layers with default lr 64 | cfg.train.base_lr_mult = 0.1 # learning rate multiplier for base layers 65 | cfg.train.lr_scheduler = 'single_step' 66 | cfg.train.stepsize = [20] # stepsize to decay learning rate 67 | cfg.train.gamma = 0.1 # learning rate decay multiplier 68 | cfg.train.print_freq = 20 # print frequency 69 | cfg.train.seed = 1 # random seed 70 | cfg.train.warmup_multiplier = 100 71 | cfg.train.warmup_total_epoch = 9 72 | 73 | # optimizer 74 | cfg.sgd = CN() 75 | cfg.sgd.momentum = 0.9 # momentum factor for sgd and rmsprop 76 | cfg.sgd.dampening = 0. # dampening for momentum 77 | cfg.sgd.nesterov = False # Nesterov momentum 78 | cfg.rmsprop = CN() 79 | cfg.rmsprop.alpha = 0.99 # smoothing constant 80 | cfg.adam = CN() 81 | cfg.adam.beta1 = 0.9 # exponential decay rate for first moment 82 | cfg.adam.beta2 = 0.999 # exponential decay rate for second moment 83 | 84 | # loss 85 | cfg.loss = CN() 86 | cfg.loss.name = 'softmax' 87 | cfg.loss.softmax = CN() 88 | cfg.loss.softmax.label_smooth = True # use label smoothing regularizer 89 | cfg.loss.triplet = CN() 90 | cfg.loss.triplet.margin = 0.3 # distance margin 91 | cfg.loss.triplet.weight_t = 1. # weight to balance hard triplet loss 92 | cfg.loss.triplet.weight_x = 0. # weight to balance cross entropy loss 93 | cfg.loss.dynamic = CN() 94 | cfg.loss.dynamic.alpha = 0.25 95 | cfg.loss.dynamic.gamma = 2 96 | cfg.loss.dynamic.delta = 0.16 97 | cfg.loss.div_reg = False 98 | cfg.loss.div_reg_beta = 1e-6 99 | cfg.loss.div_reg_start = 23 100 | 101 | # test 102 | cfg.test = CN() 103 | cfg.test.batch_size = 100 104 | cfg.test.dist_metric = 'euclidean' # distance metric, ['euclidean', 'cosine'] 105 | cfg.test.normalize_feature = False # normalize feature vectors before computing distance 106 | cfg.test.ranks = [1, 5, 10, 20] # cmc ranks 107 | cfg.test.evaluate = False # test only 108 | cfg.test.multi_eval = False # when multiple datasets are used, whether to evaluate model on the combination of datasets rather than one by one 109 | cfg.test.eval_freq = -1 # evaluation frequency (-1 means to only test after training) 110 | cfg.test.start_eval = 0 # start to evaluate after a specific epoch 111 | cfg.test.rerank = False # use person re-ranking 112 | cfg.test.visrank = False # visualize ranked results (only available when cfg.test.evaluate=True) 113 | cfg.test.visrank_topk = 10 # top-k ranks to visualize 114 | cfg.test.visactmap = False # visualize CNN activation maps 115 | 116 | return cfg 117 | 118 | 119 | def imagedata_kwargs(cfg): 120 | return { 121 | 'root': cfg.data.root, 122 | 'sources': cfg.data.sources, 123 | 'targets': cfg.data.targets, 124 | 'height': cfg.data.height, 125 | 'width': cfg.data.width, 126 | 'transforms': cfg.data.transforms, 127 | 'norm_mean': cfg.data.norm_mean, 128 | 'norm_std': cfg.data.norm_std, 129 | 'use_gpu': cfg.use_gpu, 130 | 'split_id': cfg.data.split_id, 131 | 'combineall': cfg.data.combineall, 132 | 'batch_size_train': cfg.train.batch_size, 133 | 'batch_size_test': cfg.test.batch_size, 134 | 'workers': cfg.data.workers, 135 | 'num_instances': cfg.sampler.num_instances, 136 | 'train_sampler': cfg.sampler.train_sampler, 137 | # image 138 | 'cuhk03_labeled': cfg.cuhk03.labeled_images, 139 | 'cuhk03_classic_split': cfg.cuhk03.classic_split, 140 | 'market1501_500k': cfg.market1501.use_500k_distractors, 141 | } 142 | 143 | 144 | def videodata_kwargs(cfg): 145 | return { 146 | 'root': cfg.data.root, 147 | 'sources': cfg.data.sources, 148 | 'targets': cfg.data.targets, 149 | 'height': cfg.data.height, 150 | 'width': cfg.data.width, 151 | 'transforms': cfg.data.transforms, 152 | 'norm_mean': cfg.data.norm_mean, 153 | 'norm_std': cfg.data.norm_std, 154 | 'use_gpu': cfg.use_gpu, 155 | 'split_id': cfg.data.split_id, 156 | 'combineall': cfg.data.combineall, 157 | 'batch_size_train': cfg.train.batch_size, 158 | 'batch_size_test': cfg.test.batch_size, 159 | 'workers': cfg.data.workers, 160 | 'num_instances': cfg.sampler.num_instances, 161 | 'train_sampler': cfg.sampler.train_sampler, 162 | # video 163 | 'seq_len': cfg.video.seq_len, 164 | 'sample_method': cfg.video.sample_method 165 | } 166 | 167 | 168 | def optimizer_kwargs(cfg): 169 | return { 170 | 'optim': cfg.train.optim, 171 | 'lr': cfg.train.lr, 172 | 'weight_decay': cfg.train.weight_decay, 173 | 'momentum': cfg.sgd.momentum, 174 | 'sgd_dampening': cfg.sgd.dampening, 175 | 'sgd_nesterov': cfg.sgd.nesterov, 176 | 'rmsprop_alpha': cfg.rmsprop.alpha, 177 | 'adam_beta1': cfg.adam.beta1, 178 | 'adam_beta2': cfg.adam.beta2, 179 | 'staged_lr': cfg.train.staged_lr, 180 | 'new_layers': cfg.train.new_layers, 181 | 'base_lr_mult': cfg.train.base_lr_mult 182 | } 183 | 184 | 185 | def lr_scheduler_kwargs(cfg): 186 | return { 187 | 'lr_scheduler': cfg.train.lr_scheduler, 188 | 'stepsize': cfg.train.stepsize, 189 | 'gamma': cfg.train.gamma, 190 | 'max_epoch': cfg.train.max_epoch, 191 | 'warmup_multiplier': cfg.train.warmup_multiplier, 192 | 'warmup_total_epoch': cfg.train.warmup_total_epoch 193 | } 194 | 195 | 196 | def engine_run_kwargs(cfg): 197 | return { 198 | 'save_dir': cfg.data.save_dir, 199 | 'max_epoch': cfg.train.max_epoch, 200 | 'start_epoch': cfg.train.start_epoch, 201 | 'fixbase_epoch': cfg.train.fixbase_epoch, 202 | 'open_layers': cfg.train.open_layers, 203 | 'start_eval': cfg.test.start_eval, 204 | 'eval_freq': cfg.test.eval_freq, 205 | 'test_only': cfg.test.evaluate, 206 | 'multi_eval': cfg.test.multi_eval, 207 | 'print_freq': cfg.train.print_freq, 208 | 'dist_metric': cfg.test.dist_metric, 209 | 'normalize_feature': cfg.test.normalize_feature, 210 | 'visrank': cfg.test.visrank, 211 | 'visrank_topk': cfg.test.visrank_topk, 212 | 'use_metric_cuhk03': cfg.cuhk03.use_metric_cuhk03, 213 | 'ranks': cfg.test.ranks, 214 | 'rerank': cfg.test.rerank, 215 | 'visactmap': cfg.test.visactmap 216 | } 217 | -------------------------------------------------------------------------------- /exp_FPB.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import os.path as osp 4 | import warnings 5 | import time 6 | import argparse 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | from default_config import ( 12 | get_default_config, imagedata_kwargs, videodata_kwargs, 13 | optimizer_kwargs, lr_scheduler_kwargs, engine_run_kwargs 14 | ) 15 | import torchreid 16 | from torchreid.utils import ( 17 | Logger, set_random_seed, check_isfile, resume_from_checkpoint, 18 | load_pretrained_weights, compute_model_complexity, collect_env_info, calc_model_params 19 | ) 20 | from torchreid.models.nn_utils import OFPenalty 21 | 22 | def build_datamanager(cfg): 23 | if cfg.data.type == 'image': 24 | return torchreid.data.ImageDataManager(**imagedata_kwargs(cfg)) 25 | else: 26 | return torchreid.data.VideoDataManager(**videodata_kwargs(cfg)) 27 | 28 | 29 | def build_engine(cfg, datamanager, model, optimizer, scheduler): 30 | if cfg.data.type == 'image': 31 | if cfg.loss.name == 'softmax': 32 | engine = torchreid.engine.ImageSoftmaxEngine( 33 | datamanager, 34 | model, 35 | optimizer, 36 | scheduler=scheduler, 37 | use_gpu=cfg.use_gpu, 38 | label_smooth=cfg.loss.softmax.label_smooth 39 | ) 40 | if cfg.loss.name == 'triplet': 41 | engine = torchreid.engine.ImageTripletEngine( 42 | datamanager, 43 | model, 44 | optimizer, 45 | margin=cfg.loss.triplet.margin, 46 | weight_t=cfg.loss.triplet.weight_t, 47 | weight_x=cfg.loss.triplet.weight_x, 48 | scheduler=scheduler, 49 | use_gpu=cfg.use_gpu, 50 | label_smooth=cfg.loss.softmax.label_smooth 51 | ) 52 | if cfg.loss.name == 'engine_FPB': 53 | if cfg.loss.div_reg: 54 | div_penalty = OFPenalty(cfg.loss.div_reg_beta) 55 | else: 56 | div_penalty = None 57 | 58 | engine = torchreid.engine.ImageFPBEngine( 59 | datamanager, 60 | model, 61 | optimizer, 62 | margin=cfg.loss.triplet.margin, 63 | weight_t=cfg.loss.triplet.weight_t, 64 | weight_x=cfg.loss.triplet.weight_x, 65 | scheduler=scheduler, 66 | use_gpu=cfg.use_gpu, 67 | label_smooth=cfg.loss.softmax.label_smooth, 68 | div_penalty = div_penalty, 69 | div_start = cfg.loss.div_reg_start 70 | ) 71 | 72 | return engine 73 | 74 | 75 | def reset_config(cfg, args): 76 | if args.root: 77 | cfg.data.root = args.root 78 | if args.sources: 79 | cfg.data.sources = args.sources 80 | if args.targets: 81 | cfg.data.targets = args.targets 82 | if args.transforms: 83 | cfg.data.transforms = args.transforms 84 | 85 | cfg.data.save_dir = cfg.data.save_dir+'/'+cfg.data.exp_name 86 | 87 | def main(): 88 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 89 | parser.add_argument('--config-file', type=str, default='', help='path to config file') 90 | parser.add_argument('-s', '--sources', type=str, nargs='+', help='source datasets (delimited by space)') 91 | parser.add_argument('-t', '--targets', type=str, nargs='+', help='target datasets (delimited by space)') 92 | parser.add_argument('--transforms', type=str, nargs='+', help='data augmentation') 93 | parser.add_argument('--root', type=str, default='', help='path to data root') 94 | parser.add_argument('--gpu-devices', type=str, default='1',) 95 | parser.add_argument('opts', default=None, nargs=argparse.REMAINDER, help='Modify config options using the command-line') 96 | args = parser.parse_args() 97 | 98 | cfg = get_default_config() 99 | cfg.use_gpu = torch.cuda.is_available() 100 | if args.config_file: 101 | cfg.merge_from_file(args.config_file) 102 | reset_config(cfg, args) 103 | cfg.merge_from_list(args.opts) 104 | cfg.freeze() 105 | set_random_seed(cfg.train.seed) 106 | 107 | if cfg.use_gpu and args.gpu_devices: 108 | # if gpu_devices is not specified, all available gpus will be used 109 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_devices 110 | 111 | if osp.exists(cfg.data.save_dir): 112 | print('!!!The name of experiment already exists!!!') 113 | exit(1) 114 | 115 | log_name = 'test.log' if cfg.test.evaluate else 'train.log' 116 | # log_name += time.strftime('-%Y-%m-%d-%H-%M-%S') 117 | sys.stdout = Logger(osp.join(cfg.data.save_dir, log_name)) 118 | 119 | print('Show configuration\n{}\n'.format(cfg)) 120 | print('Collecting env info ...') 121 | print('** System info **\n{}\n'.format(collect_env_info())) 122 | 123 | if cfg.use_gpu: 124 | torch.backends.cudnn.benchmark = True 125 | 126 | datamanager = build_datamanager(cfg) 127 | 128 | print('Building model: {}'.format(cfg.model.name)) 129 | model = torchreid.models.build_model( 130 | name=cfg.model.name, 131 | num_classes=datamanager.num_train_pids, 132 | loss=cfg.loss.name, 133 | pretrained=cfg.model.pretrained, 134 | use_gpu=cfg.use_gpu 135 | ) 136 | num_params, flops = compute_model_complexity(model, (1, 3, cfg.data.height, cfg.data.width)) 137 | print('Model complexity: params={:,} flops={:,}'.format(num_params, flops)) 138 | calc_model_params(model) 139 | if cfg.model.load_weights and check_isfile(cfg.model.load_weights): 140 | load_pretrained_weights(model, cfg.model.load_weights) 141 | 142 | if cfg.use_gpu: 143 | model = nn.DataParallel(model).cuda() 144 | 145 | optimizer = torchreid.optim.build_optimizer(model, **optimizer_kwargs(cfg)) 146 | scheduler = torchreid.optim.build_lr_scheduler(optimizer, **lr_scheduler_kwargs(cfg)) 147 | 148 | if cfg.model.resume and check_isfile(cfg.model.resume): 149 | args.start_epoch = resume_from_checkpoint(cfg.model.resume, model, optimizer=optimizer) 150 | 151 | print('Building {}-engine for {}-reid'.format(cfg.loss.name, cfg.data.type)) 152 | engine = build_engine(cfg, datamanager, model, optimizer, scheduler) 153 | engine.run(**engine_run_kwargs(cfg)) 154 | 155 | 156 | if __name__ == '__main__': 157 | main() 158 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.8.0 2 | astroid==2.3.3 3 | attrs==19.1.0 4 | backcall==0.1.0 5 | beautifulsoup4==4.8.1 6 | bleach==3.1.0 7 | certifi==2019.3.9 8 | chardet==3.0.4 9 | Click==7.0 10 | ConfigArgParse==0.14.0 11 | cycler==0.10.0 12 | Cython==0.29.12 13 | Cython-based-reid-evaluation-code==0.0.0 14 | decorator==4.4.0 15 | defusedxml==0.6.0 16 | dropblock==0.3.0 17 | entrypoints==0.3 18 | enum34==1.1.6 19 | et-xmlfile==1.0.1 20 | filelock==3.0.12 21 | future==0.17.1 22 | gdown==3.8.3 23 | grpcio==1.23.0 24 | h5py==2.10.0 25 | html5lib==1.0.1 26 | idna==2.8 27 | ignite==1.0.0 28 | ipykernel==5.1.1 29 | ipython==7.5.0 30 | ipython-genutils==0.2.0 31 | ipywidgets==7.4.2 32 | isort==4.3.21 33 | jdcal==1.4.1 34 | jedi==0.13.3 35 | Jinja2==2.10.1 36 | joblib==0.14.0 37 | json5==0.8.5 38 | jsonschema==3.0.1 39 | jupyter==1.0.0 40 | jupyter-client==5.2.4 41 | jupyter-console==6.0.0 42 | jupyter-core==4.4.0 43 | jupyterlab==1.1.4 44 | jupyterlab-server==1.0.6 45 | kiwisolver==1.1.0 46 | lazy-object-proxy==1.4.3 47 | lxml==4.4.1 48 | Markdown==3.1.1 49 | MarkupSafe==1.1.1 50 | matplotlib==3.1.0 51 | mccabe==0.6.1 52 | mistune==0.8.4 53 | mock==3.0.5 54 | navi==0.1.0a19 55 | nbconvert==5.5.0 56 | nbformat==4.4.0 57 | notebook==5.7.8 58 | numexpr==2.7.0 59 | numpy==1.16.3 60 | opencv-python==4.1.1.26 61 | openpyxl==3.0.0 62 | pandas==0.25.1 63 | pandas-datareader==0.8.1 64 | pandocfilters==1.4.2 65 | parso==0.4.0 66 | pexpect==4.7.0 67 | pickleshare==0.7.5 68 | Pillow==6.0.0 69 | prometheus-client==0.6.0 70 | prompt-toolkit==2.0.9 71 | protobuf==3.9.1 72 | ptyprocess==0.6.0 73 | PyDispatcher==2.0.5 74 | Pygments==2.4.0 75 | pylint==2.4.4 76 | pyparsing==2.4.0 77 | pyrsistent==0.15.2 78 | python-dateutil==2.8.0 79 | pytorch-ignite==0.1.1 80 | pytz==2019.3 81 | PyYAML==5.1.2 82 | pyzmq==18.0.1 83 | qtconsole==4.4.4 84 | requests==2.22.0 85 | scikit-learn==0.21.3 86 | scipy==1.2.1 87 | Send2Trash==1.5.0 88 | six==1.12.0 89 | soupsieve==1.9.4 90 | SQLAlchemy==1.3.10 91 | tables==3.5.2 92 | tb-nightly==2.0.0a20190917 93 | terminado==0.8.2 94 | testpath==0.4.2 95 | torch==1.1.0 96 | torchfile==0.1.0 97 | torchsummary==1.5.1 98 | torchvision==0.3.0 99 | tornado==6.0.3 100 | tqdm==4.35.0 101 | traitlets==4.3.2 102 | typed-ast==1.4.0 103 | urllib3==1.25.2 104 | visdom==0.1.8.8 105 | wcwidth==0.1.7 106 | webencodings==0.5.1 107 | websocket-client==0.56.0 108 | Werkzeug==0.15.6 109 | widgetsnbextension==3.4.2 110 | wrapt==1.11.2 111 | xlrd==1.2.0 112 | yacs==0.1.6 -------------------------------------------------------------------------------- /torchreid/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | 4 | __version__ = '1.0.2' 5 | __author__ = 'Kaiyang Zhou' 6 | __description__ = 'Deep learning person re-identification in PyTorch' 7 | __url__ = 'https://github.com/KaiyangZhou/deep-person-reid' 8 | 9 | from torchreid import ( 10 | engine, 11 | models, 12 | losses, 13 | metrics, 14 | data, 15 | optim, 16 | utils 17 | ) 18 | -------------------------------------------------------------------------------- /torchreid/data/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | 4 | from .datasets import Dataset, ImageDataset, VideoDataset 5 | from .datasets import register_image_dataset 6 | from .datasets import register_video_dataset 7 | from .datamanager import ImageDataManager, VideoDataManager, ImageHybridDatamanager -------------------------------------------------------------------------------- /torchreid/data/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | 4 | from .dataset import Dataset, ImageDataset, VideoDataset 5 | from .image import * 6 | from .video import * 7 | 8 | 9 | __image_datasets = { 10 | 'market1501': Market1501, 11 | 'cuhk03': CUHK03, 12 | 'cuhk03_detected': CUHK03_Detected, 13 | 'cuhk03_labeled': CUHK03_Labeled, 14 | 'dukemtmcreid': DukeMTMCreID, 15 | 'msmt17': MSMT17, 16 | 'viper': VIPeR, 17 | 'grid': GRID, 18 | 'cuhk01': CUHK01, 19 | 'ilids': iLIDS, 20 | 'sensereid': SenseReID, 21 | 'prid': PRID, 22 | 'cuhk02': CUHK02 23 | } 24 | 25 | 26 | __video_datasets = { 27 | 'mars': Mars, 28 | 'ilidsvid': iLIDSVID, 29 | 'prid2011': PRID2011, 30 | 'dukemtmcvidreid': DukeMTMCVidReID 31 | } 32 | 33 | 34 | def init_image_dataset(name, **kwargs): 35 | """Initializes an image dataset.""" 36 | avai_datasets = list(__image_datasets.keys()) 37 | if name not in avai_datasets: 38 | raise ValueError('Invalid dataset name. Received "{}", ' 39 | 'but expected to be one of {}'.format(name, avai_datasets)) 40 | return __image_datasets[name](**kwargs) 41 | 42 | 43 | def init_video_dataset(name, **kwargs): 44 | """Initializes a video dataset.""" 45 | avai_datasets = list(__video_datasets.keys()) 46 | if name not in avai_datasets: 47 | raise ValueError('Invalid dataset name. Received "{}", ' 48 | 'but expected to be one of {}'.format(name, avai_datasets)) 49 | return __video_datasets[name](**kwargs) 50 | 51 | 52 | def register_image_dataset(name, dataset): 53 | """Registers a new image dataset. 54 | 55 | Args: 56 | name (str): key corresponding to the new dataset. 57 | dataset (Dataset): the new dataset class. 58 | 59 | Examples:: 60 | 61 | import torchreid 62 | import NewDataset 63 | torchreid.data.register_image_dataset('new_dataset', NewDataset) 64 | # single dataset case 65 | datamanager = torchreid.data.ImageDataManager( 66 | root='reid-data', 67 | sources='new_dataset' 68 | ) 69 | # multiple dataset case 70 | datamanager = torchreid.data.ImageDataManager( 71 | root='reid-data', 72 | sources=['new_dataset', 'dukemtmcreid'] 73 | ) 74 | """ 75 | global __image_datasets 76 | curr_datasets = list(__image_datasets.keys()) 77 | if name in curr_datasets: 78 | raise ValueError('The given name already exists, please choose ' 79 | 'another name excluding {}'.format(curr_datasets)) 80 | __image_datasets[name] = dataset 81 | 82 | 83 | def register_video_dataset(name, dataset): 84 | """Registers a new video dataset. 85 | 86 | Args: 87 | name (str): key corresponding to the new dataset. 88 | dataset (Dataset): the new dataset class. 89 | 90 | Examples:: 91 | 92 | import torchreid 93 | import NewDataset 94 | torchreid.data.register_video_dataset('new_dataset', NewDataset) 95 | # single dataset case 96 | datamanager = torchreid.data.VideoDataManager( 97 | root='reid-data', 98 | sources='new_dataset' 99 | ) 100 | # multiple dataset case 101 | datamanager = torchreid.data.VideoDataManager( 102 | root='reid-data', 103 | sources=['new_dataset', 'ilidsvid'] 104 | ) 105 | """ 106 | global __video_datasets 107 | curr_datasets = list(__video_datasets.keys()) 108 | if name in curr_datasets: 109 | raise ValueError('The given name already exists, please choose ' 110 | 'another name excluding {}'.format(curr_datasets)) 111 | __video_datasets[name] = dataset 112 | -------------------------------------------------------------------------------- /torchreid/data/datasets/image/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | 4 | from .market1501 import Market1501 5 | from .dukemtmcreid import DukeMTMCreID 6 | from .cuhk03 import CUHK03 7 | from .msmt17 import MSMT17 8 | from .viper import VIPeR 9 | from .grid import GRID 10 | from .cuhk01 import CUHK01 11 | from .ilids import iLIDS 12 | from .sensereid import SenseReID 13 | from .prid import PRID 14 | from .cuhk02 import CUHK02 15 | from .cuhk03_detected import CUHK03_Detected 16 | from .cuhk03_labeled import CUHK03_Labeled 17 | from .msmt17 import MSMT17 18 | -------------------------------------------------------------------------------- /torchreid/data/datasets/image/cuhk01.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | import sys 6 | import os 7 | import os.path as osp 8 | import glob 9 | import zipfile 10 | import numpy as np 11 | 12 | from torchreid.data.datasets import ImageDataset 13 | from torchreid.utils import read_json, write_json 14 | 15 | 16 | class CUHK01(ImageDataset): 17 | """CUHK01. 18 | 19 | Reference: 20 | Li et al. Human Reidentification with Transferred Metric Learning. ACCV 2012. 21 | 22 | URL: ``_ 23 | 24 | Dataset statistics: 25 | - identities: 971. 26 | - images: 3884. 27 | - cameras: 4. 28 | """ 29 | dataset_dir = 'cuhk01' 30 | dataset_url = None 31 | 32 | def __init__(self, root='', split_id=0, **kwargs): 33 | self.root = osp.abspath(osp.expanduser(root)) 34 | self.dataset_dir = osp.join(self.root, self.dataset_dir) 35 | self.download_dataset(self.dataset_dir, self.dataset_url) 36 | 37 | self.zip_path = osp.join(self.dataset_dir, 'CUHK01.zip') 38 | self.campus_dir = osp.join(self.dataset_dir, 'campus') 39 | self.split_path = osp.join(self.dataset_dir, 'splits.json') 40 | 41 | self.extract_file() 42 | 43 | required_files = [ 44 | self.dataset_dir, 45 | self.campus_dir 46 | ] 47 | self.check_before_run(required_files) 48 | 49 | self.prepare_split() 50 | splits = read_json(self.split_path) 51 | if split_id >= len(splits): 52 | raise ValueError('split_id exceeds range, received {}, but expected between 0 and {}'.format(split_id, len(splits)-1)) 53 | split = splits[split_id] 54 | 55 | train = split['train'] 56 | query = split['query'] 57 | gallery = split['gallery'] 58 | 59 | train = [tuple(item) for item in train] 60 | query = [tuple(item) for item in query] 61 | gallery = [tuple(item) for item in gallery] 62 | 63 | super(CUHK01, self).__init__(train, query, gallery, **kwargs) 64 | 65 | def extract_file(self): 66 | if not osp.exists(self.campus_dir): 67 | print('Extracting files') 68 | zip_ref = zipfile.ZipFile(self.zip_path, 'r') 69 | zip_ref.extractall(self.dataset_dir) 70 | zip_ref.close() 71 | 72 | def prepare_split(self): 73 | """ 74 | Image name format: 0001001.png, where first four digits represent identity 75 | and last four digits represent cameras. Camera 1&2 are considered the same 76 | view and camera 3&4 are considered the same view. 77 | """ 78 | if not osp.exists(self.split_path): 79 | print('Creating 10 random splits of train ids and test ids') 80 | img_paths = sorted(glob.glob(osp.join(self.campus_dir, '*.png'))) 81 | img_list = [] 82 | pid_container = set() 83 | for img_path in img_paths: 84 | img_name = osp.basename(img_path) 85 | pid = int(img_name[:4]) - 1 86 | camid = (int(img_name[4:7]) - 1) // 2 # result is either 0 or 1 87 | img_list.append((img_path, pid, camid)) 88 | pid_container.add(pid) 89 | 90 | num_pids = len(pid_container) 91 | num_train_pids = num_pids // 2 92 | 93 | splits = [] 94 | for _ in range(10): 95 | order = np.arange(num_pids) 96 | np.random.shuffle(order) 97 | train_idxs = order[:num_train_pids] 98 | train_idxs = np.sort(train_idxs) 99 | idx2label = {idx: label for label, idx in enumerate(train_idxs)} 100 | 101 | train, test_a, test_b = [], [], [] 102 | for img_path, pid, camid in img_list: 103 | if pid in train_idxs: 104 | train.append((img_path, idx2label[pid], camid)) 105 | else: 106 | if camid == 0: 107 | test_a.append((img_path, pid, camid)) 108 | else: 109 | test_b.append((img_path, pid, camid)) 110 | 111 | # use cameraA as query and cameraB as gallery 112 | split = { 113 | 'train': train, 114 | 'query': test_a, 115 | 'gallery': test_b, 116 | 'num_train_pids': num_train_pids, 117 | 'num_query_pids': num_pids - num_train_pids, 118 | 'num_gallery_pids': num_pids - num_train_pids 119 | } 120 | splits.append(split) 121 | 122 | # use cameraB as query and cameraA as gallery 123 | split = { 124 | 'train': train, 125 | 'query': test_b, 126 | 'gallery': test_a, 127 | 'num_train_pids': num_train_pids, 128 | 'num_query_pids': num_pids - num_train_pids, 129 | 'num_gallery_pids': num_pids - num_train_pids 130 | } 131 | splits.append(split) 132 | 133 | print('Totally {} splits are created'.format(len(splits))) 134 | write_json(splits, self.split_path) 135 | print('Split file saved to {}'.format(self.split_path)) 136 | -------------------------------------------------------------------------------- /torchreid/data/datasets/image/cuhk02.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | import sys 6 | import os 7 | import os.path as osp 8 | import glob 9 | 10 | from torchreid.data.datasets import ImageDataset 11 | 12 | 13 | class CUHK02(ImageDataset): 14 | """CUHK02. 15 | 16 | Reference: 17 | Li and Wang. Locally Aligned Feature Transforms across Views. CVPR 2013. 18 | 19 | URL: ``_ 20 | 21 | Dataset statistics: 22 | - 5 camera view pairs each with two cameras 23 | - 971, 306, 107, 193 and 239 identities from P1 - P5 24 | - totally 1,816 identities 25 | - image format is png 26 | 27 | Protocol: Use P1 - P4 for training and P5 for evaluation. 28 | """ 29 | dataset_dir = 'cuhk02' 30 | cam_pairs = ['P1', 'P2', 'P3', 'P4', 'P5'] 31 | test_cam_pair = 'P5' 32 | 33 | def __init__(self, root='', **kwargs): 34 | self.root = osp.abspath(osp.expanduser(root)) 35 | self.dataset_dir = osp.join(self.root, self.dataset_dir, 'Dataset') 36 | 37 | required_files = [self.dataset_dir] 38 | self.check_before_run(required_files) 39 | 40 | train, query, gallery = self.get_data_list() 41 | 42 | super(CUHK02, self).__init__(train, query, gallery, **kwargs) 43 | 44 | def get_data_list(self): 45 | num_train_pids, camid = 0, 0 46 | train, query, gallery = [], [], [] 47 | 48 | for cam_pair in self.cam_pairs: 49 | cam_pair_dir = osp.join(self.dataset_dir, cam_pair) 50 | 51 | cam1_dir = osp.join(cam_pair_dir, 'cam1') 52 | cam2_dir = osp.join(cam_pair_dir, 'cam2') 53 | 54 | impaths1 = glob.glob(osp.join(cam1_dir, '*.png')) 55 | impaths2 = glob.glob(osp.join(cam2_dir, '*.png')) 56 | 57 | if cam_pair == self.test_cam_pair: 58 | # add images to query 59 | for impath in impaths1: 60 | pid = osp.basename(impath).split('_')[0] 61 | pid = int(pid) 62 | query.append((impath, pid, camid)) 63 | camid += 1 64 | 65 | # add images to gallery 66 | for impath in impaths2: 67 | pid = osp.basename(impath).split('_')[0] 68 | pid = int(pid) 69 | gallery.append((impath, pid, camid)) 70 | camid += 1 71 | 72 | else: 73 | pids1 = [osp.basename(impath).split('_')[0] for impath in impaths1] 74 | pids2 = [osp.basename(impath).split('_')[0] for impath in impaths2] 75 | pids = set(pids1 + pids2) 76 | pid2label = {pid: label+num_train_pids for label, pid in enumerate(pids)} 77 | 78 | # add images to train from cam1 79 | for impath in impaths1: 80 | pid = osp.basename(impath).split('_')[0] 81 | pid = pid2label[pid] 82 | train.append((impath, pid, camid)) 83 | camid += 1 84 | 85 | # add images to train from cam1 86 | for impath in impaths1: 87 | pid = osp.basename(impath).split('_')[0] 88 | pid = pid2label[pid] 89 | train.append((impath, pid, camid)) 90 | camid += 1 91 | num_train_pids += len(pids) 92 | 93 | return train, query, gallery 94 | -------------------------------------------------------------------------------- /torchreid/data/datasets/image/cuhk03_detected.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | import sys 6 | import os 7 | import os.path as osp 8 | import glob 9 | import re 10 | 11 | from torchreid.data.datasets import ImageDataset 12 | 13 | 14 | class CUHK03_Detected(ImageDataset): 15 | """DukeMTMC-reID. 16 | 17 | Reference: 18 | - Ristani et al. Performance Measures and a Data Set for Multi-Target, Multi-Camera Tracking. ECCVW 2016. 19 | - Zheng et al. Unlabeled Samples Generated by GAN Improve the Person Re-identification Baseline in vitro. ICCV 2017. 20 | 21 | URL: ``_ 22 | 23 | Dataset statistics: 24 | - identities: 1404 (train + query). 25 | - images:16522 (train) + 2228 (query) + 17661 (gallery). 26 | - cameras: 8. 27 | """ 28 | dataset_dir = 'cuhk03' 29 | 30 | def __init__(self, root='', **kwargs): 31 | self.root = osp.abspath(osp.expanduser(root)) 32 | self.dataset_dir = osp.join(self.root, self.dataset_dir) 33 | self.train_dir = osp.join(self.dataset_dir, 'CUHK03_detected/bounding_box_train') 34 | self.query_dir = osp.join(self.dataset_dir, 'CUHK03_detected/query') 35 | self.gallery_dir = osp.join(self.dataset_dir, 'CUHK03_detected/bounding_box_test') 36 | 37 | required_files = [ 38 | self.dataset_dir, 39 | self.train_dir, 40 | self.query_dir, 41 | self.gallery_dir 42 | ] 43 | self.check_before_run(required_files) 44 | 45 | train = self.process_dir(self.train_dir, relabel=True) 46 | query = self.process_dir(self.query_dir, relabel=False) 47 | gallery = self.process_dir(self.gallery_dir, relabel=False) 48 | 49 | super(CUHK03_Detected, self).__init__(train, query, gallery, **kwargs) 50 | 51 | def process_dir(self, dir_path, relabel=False): 52 | img_paths = glob.glob(osp.join(dir_path, '*.png')) 53 | pattern = re.compile(r'([-\d]+)_c(\d)') 54 | 55 | pid_container = set() 56 | for img_path in img_paths: 57 | pid, _ = map(int, pattern.search(img_path).groups()) 58 | pid_container.add(pid) 59 | pid2label = {pid:label for label, pid in enumerate(pid_container)} 60 | 61 | data = [] 62 | for img_path in img_paths: 63 | pid, camid = map(int, pattern.search(img_path).groups()) 64 | #assert 1 <= camid <= 8 65 | camid -= 1 # index starts from 0 66 | if relabel: pid = pid2label[pid] 67 | data.append((img_path, pid, camid)) 68 | 69 | return data 70 | -------------------------------------------------------------------------------- /torchreid/data/datasets/image/cuhk03_labeled.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | import sys 6 | import os 7 | import os.path as osp 8 | import glob 9 | import re 10 | 11 | from torchreid.data.datasets import ImageDataset 12 | 13 | 14 | class CUHK03_Labeled(ImageDataset): 15 | """DukeMTMC-reID. 16 | 17 | Reference: 18 | - Ristani et al. Performance Measures and a Data Set for Multi-Target, Multi-Camera Tracking. ECCVW 2016. 19 | - Zheng et al. Unlabeled Samples Generated by GAN Improve the Person Re-identification Baseline in vitro. ICCV 2017. 20 | 21 | URL: ``_ 22 | 23 | Dataset statistics: 24 | - identities: 1404 (train + query). 25 | - images:16522 (train) + 2228 (query) + 17661 (gallery). 26 | - cameras: 8. 27 | """ 28 | dataset_dir = 'cuhk03' 29 | 30 | def __init__(self, root='', **kwargs): 31 | self.root = osp.abspath(osp.expanduser(root)) 32 | self.dataset_dir = osp.join(self.root, self.dataset_dir) 33 | self.train_dir = osp.join(self.dataset_dir, 'CUHK03_labeled/bounding_box_train') 34 | self.query_dir = osp.join(self.dataset_dir, 'CUHK03_labeled/query') 35 | self.gallery_dir = osp.join(self.dataset_dir, 'CUHK03_labeled/bounding_box_test') 36 | 37 | required_files = [ 38 | self.dataset_dir, 39 | self.train_dir, 40 | self.query_dir, 41 | self.gallery_dir 42 | ] 43 | self.check_before_run(required_files) 44 | 45 | train = self.process_dir(self.train_dir, relabel=True) 46 | query = self.process_dir(self.query_dir, relabel=False) 47 | gallery = self.process_dir(self.gallery_dir, relabel=False) 48 | 49 | super(CUHK03_Labeled, self).__init__(train, query, gallery, **kwargs) 50 | 51 | def process_dir(self, dir_path, relabel=False): 52 | img_paths = glob.glob(osp.join(dir_path, '*.png')) 53 | pattern = re.compile(r'([-\d]+)_c(\d)') 54 | 55 | pid_container = set() 56 | for img_path in img_paths: 57 | pid, _ = map(int, pattern.search(img_path).groups()) 58 | pid_container.add(pid) 59 | pid2label = {pid:label for label, pid in enumerate(pid_container)} 60 | 61 | data = [] 62 | for img_path in img_paths: 63 | pid, camid = map(int, pattern.search(img_path).groups()) 64 | #assert 1 <= camid <= 8 65 | camid -= 1 # index starts from 0 66 | if relabel: pid = pid2label[pid] 67 | data.append((img_path, pid, camid)) 68 | 69 | return data 70 | -------------------------------------------------------------------------------- /torchreid/data/datasets/image/dukemtmcreid.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | import sys 6 | import os 7 | import os.path as osp 8 | import glob 9 | import re 10 | 11 | from torchreid.data.datasets import ImageDataset 12 | 13 | 14 | class DukeMTMCreID(ImageDataset): 15 | """DukeMTMC-reID. 16 | 17 | Reference: 18 | - Ristani et al. Performance Measures and a Data Set for Multi-Target, Multi-Camera Tracking. ECCVW 2016. 19 | - Zheng et al. Unlabeled Samples Generated by GAN Improve the Person Re-identification Baseline in vitro. ICCV 2017. 20 | 21 | URL: ``_ 22 | 23 | Dataset statistics: 24 | - identities: 1404 (train + query). 25 | - images:16522 (train) + 2228 (query) + 17661 (gallery). 26 | - cameras: 8. 27 | """ 28 | dataset_dir = 'dukemtmc-reid' 29 | dataset_url = 'http://vision.cs.duke.edu/DukeMTMC/data/misc/DukeMTMC-reID.zip' 30 | 31 | def __init__(self, root='', **kwargs): 32 | self.root = osp.abspath(osp.expanduser(root)) 33 | self.dataset_dir = osp.join(self.root, self.dataset_dir) 34 | self.download_dataset(self.dataset_dir, self.dataset_url) 35 | self.train_dir = osp.join(self.dataset_dir, 'DukeMTMC-reID/bounding_box_train') 36 | self.query_dir = osp.join(self.dataset_dir, 'DukeMTMC-reID/query') 37 | self.gallery_dir = osp.join(self.dataset_dir, 'DukeMTMC-reID/bounding_box_test') 38 | 39 | required_files = [ 40 | self.dataset_dir, 41 | self.train_dir, 42 | self.query_dir, 43 | self.gallery_dir 44 | ] 45 | self.check_before_run(required_files) 46 | 47 | train = self.process_dir(self.train_dir, relabel=True) 48 | query = self.process_dir(self.query_dir, relabel=False) 49 | gallery = self.process_dir(self.gallery_dir, relabel=False) 50 | 51 | super(DukeMTMCreID, self).__init__(train, query, gallery, **kwargs) 52 | 53 | def process_dir(self, dir_path, relabel=False): 54 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 55 | pattern = re.compile(r'([-\d]+)_c(\d)') 56 | 57 | pid_container = set() 58 | for img_path in img_paths: 59 | pid, _ = map(int, pattern.search(img_path).groups()) 60 | pid_container.add(pid) 61 | pid2label = {pid:label for label, pid in enumerate(pid_container)} 62 | 63 | data = [] 64 | for img_path in img_paths: 65 | pid, camid = map(int, pattern.search(img_path).groups()) 66 | assert 1 <= camid <= 8 67 | camid -= 1 # index starts from 0 68 | if relabel: pid = pid2label[pid] 69 | data.append((img_path, pid, camid)) 70 | 71 | return data -------------------------------------------------------------------------------- /torchreid/data/datasets/image/grid.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | import sys 6 | import os 7 | import os.path as osp 8 | import glob 9 | from scipy.io import loadmat 10 | 11 | from torchreid.data.datasets import ImageDataset 12 | from torchreid.utils import read_json, write_json 13 | 14 | 15 | class GRID(ImageDataset): 16 | """GRID. 17 | 18 | Reference: 19 | Loy et al. Multi-camera activity correlation analysis. CVPR 2009. 20 | 21 | URL: ``_ 22 | 23 | Dataset statistics: 24 | - identities: 250. 25 | - images: 1275. 26 | - cameras: 8. 27 | """ 28 | dataset_dir = 'grid' 29 | dataset_url = 'http://personal.ie.cuhk.edu.hk/~ccloy/files/datasets/underground_reid.zip' 30 | 31 | def __init__(self, root='', split_id=0, **kwargs): 32 | self.root = osp.abspath(osp.expanduser(root)) 33 | self.dataset_dir = osp.join(self.root, self.dataset_dir) 34 | self.download_dataset(self.dataset_dir, self.dataset_url) 35 | 36 | self.probe_path = osp.join(self.dataset_dir, 'underground_reid', 'probe') 37 | self.gallery_path = osp.join(self.dataset_dir, 'underground_reid', 'gallery') 38 | self.split_mat_path = osp.join(self.dataset_dir, 'underground_reid', 'features_and_partitions.mat') 39 | self.split_path = osp.join(self.dataset_dir, 'splits.json') 40 | 41 | required_files = [ 42 | self.dataset_dir, 43 | self.probe_path, 44 | self.gallery_path, 45 | self.split_mat_path 46 | ] 47 | self.check_before_run(required_files) 48 | 49 | self.prepare_split() 50 | splits = read_json(self.split_path) 51 | if split_id >= len(splits): 52 | raise ValueError('split_id exceeds range, received {}, ' 53 | 'but expected between 0 and {}'.format(split_id, len(splits)-1)) 54 | split = splits[split_id] 55 | 56 | train = split['train'] 57 | query = split['query'] 58 | gallery = split['gallery'] 59 | 60 | train = [tuple(item) for item in train] 61 | query = [tuple(item) for item in query] 62 | gallery = [tuple(item) for item in gallery] 63 | 64 | super(GRID, self).__init__(train, query, gallery, **kwargs) 65 | 66 | def prepare_split(self): 67 | if not osp.exists(self.split_path): 68 | print('Creating 10 random splits') 69 | split_mat = loadmat(self.split_mat_path) 70 | trainIdxAll = split_mat['trainIdxAll'][0] # length = 10 71 | probe_img_paths = sorted(glob.glob(osp.join(self.probe_path, '*.jpeg'))) 72 | gallery_img_paths = sorted(glob.glob(osp.join(self.gallery_path, '*.jpeg'))) 73 | 74 | splits = [] 75 | for split_idx in range(10): 76 | train_idxs = trainIdxAll[split_idx][0][0][2][0].tolist() 77 | assert len(train_idxs) == 125 78 | idx2label = {idx: label for label, idx in enumerate(train_idxs)} 79 | 80 | train, query, gallery = [], [], [] 81 | 82 | # processing probe folder 83 | for img_path in probe_img_paths: 84 | img_name = osp.basename(img_path) 85 | img_idx = int(img_name.split('_')[0]) 86 | camid = int(img_name.split('_')[1]) - 1 # index starts from 0 87 | if img_idx in train_idxs: 88 | train.append((img_path, idx2label[img_idx], camid)) 89 | else: 90 | query.append((img_path, img_idx, camid)) 91 | 92 | # process gallery folder 93 | for img_path in gallery_img_paths: 94 | img_name = osp.basename(img_path) 95 | img_idx = int(img_name.split('_')[0]) 96 | camid = int(img_name.split('_')[1]) - 1 # index starts from 0 97 | if img_idx in train_idxs: 98 | train.append((img_path, idx2label[img_idx], camid)) 99 | else: 100 | gallery.append((img_path, img_idx, camid)) 101 | 102 | split = { 103 | 'train': train, 104 | 'query': query, 105 | 'gallery': gallery, 106 | 'num_train_pids': 125, 107 | 'num_query_pids': 125, 108 | 'num_gallery_pids': 900 109 | } 110 | splits.append(split) 111 | 112 | print('Totally {} splits are created'.format(len(splits))) 113 | write_json(splits, self.split_path) 114 | print('Split file saved to {}'.format(self.split_path)) -------------------------------------------------------------------------------- /torchreid/data/datasets/image/ilids.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | import sys 6 | import os 7 | import os.path as osp 8 | import glob 9 | import numpy as np 10 | import copy 11 | import random 12 | from collections import defaultdict 13 | 14 | from torchreid.data.datasets import ImageDataset 15 | from torchreid.utils import read_json, write_json 16 | 17 | 18 | class iLIDS(ImageDataset): 19 | """QMUL-iLIDS. 20 | 21 | Reference: 22 | Zheng et al. Associating Groups of People. BMVC 2009. 23 | 24 | Dataset statistics: 25 | - identities: 119. 26 | - images: 476. 27 | - cameras: 8 (not explicitly provided). 28 | """ 29 | dataset_dir = 'ilids' 30 | dataset_url = 'http://www.eecs.qmul.ac.uk/~jason/data/i-LIDS_Pedestrian.tgz' 31 | 32 | def __init__(self, root='', split_id=0, **kwargs): 33 | self.root = osp.abspath(osp.expanduser(root)) 34 | self.dataset_dir = osp.join(self.root, self.dataset_dir) 35 | self.download_dataset(self.dataset_dir, self.dataset_url) 36 | 37 | self.data_dir = osp.join(self.dataset_dir, 'i-LIDS_Pedestrian/Persons') 38 | self.split_path = osp.join(self.dataset_dir, 'splits.json') 39 | 40 | required_files = [ 41 | self.dataset_dir, 42 | self.data_dir 43 | ] 44 | self.check_before_run(required_files) 45 | 46 | self.prepare_split() 47 | splits = read_json(self.split_path) 48 | if split_id >= len(splits): 49 | raise ValueError('split_id exceeds range, received {}, but ' 50 | 'expected between 0 and {}'.format(split_id, len(splits)-1)) 51 | split = splits[split_id] 52 | 53 | train, query, gallery = self.process_split(split) 54 | 55 | super(iLIDS, self).__init__(train, query, gallery, **kwargs) 56 | 57 | def prepare_split(self): 58 | if not osp.exists(self.split_path): 59 | print('Creating splits ...') 60 | 61 | paths = glob.glob(osp.join(self.data_dir, '*.jpg')) 62 | img_names = [osp.basename(path) for path in paths] 63 | num_imgs = len(img_names) 64 | assert num_imgs == 476, 'There should be 476 images, but ' \ 65 | 'got {}, please check the data'.format(num_imgs) 66 | 67 | # store image names 68 | # image naming format: 69 | # the first four digits denote the person ID 70 | # the last four digits denote the sequence index 71 | pid_dict = defaultdict(list) 72 | for img_name in img_names: 73 | pid = int(img_name[:4]) 74 | pid_dict[pid].append(img_name) 75 | pids = list(pid_dict.keys()) 76 | num_pids = len(pids) 77 | assert num_pids == 119, 'There should be 119 identities, ' \ 78 | 'but got {}, please check the data'.format(num_pids) 79 | 80 | num_train_pids = int(num_pids * 0.5) 81 | num_test_pids = num_pids - num_train_pids # supposed to be 60 82 | 83 | splits = [] 84 | for _ in range(10): 85 | # randomly choose num_train_pids train IDs and num_test_pids test IDs 86 | pids_copy = copy.deepcopy(pids) 87 | random.shuffle(pids_copy) 88 | train_pids = pids_copy[:num_train_pids] 89 | test_pids = pids_copy[num_train_pids:] 90 | 91 | train = [] 92 | query = [] 93 | gallery = [] 94 | 95 | # for train IDs, all images are used in the train set. 96 | for pid in train_pids: 97 | img_names = pid_dict[pid] 98 | train.extend(img_names) 99 | 100 | # for each test ID, randomly choose two images, one for 101 | # query and the other one for gallery. 102 | for pid in test_pids: 103 | img_names = pid_dict[pid] 104 | samples = random.sample(img_names, 2) 105 | query.append(samples[0]) 106 | gallery.append(samples[1]) 107 | 108 | split = {'train': train, 'query': query, 'gallery': gallery} 109 | splits.append(split) 110 | 111 | print('Totally {} splits are created'.format(len(splits))) 112 | write_json(splits, self.split_path) 113 | print('Split file is saved to {}'.format(self.split_path)) 114 | 115 | def get_pid2label(self, img_names): 116 | pid_container = set() 117 | for img_name in img_names: 118 | pid = int(img_name[:4]) 119 | pid_container.add(pid) 120 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 121 | return pid2label 122 | 123 | def parse_img_names(self, img_names, pid2label=None): 124 | data = [] 125 | 126 | for img_name in img_names: 127 | pid = int(img_name[:4]) 128 | if pid2label is not None: 129 | pid = pid2label[pid] 130 | camid = int(img_name[4:7]) - 1 # 0-based 131 | img_path = osp.join(self.data_dir, img_name) 132 | data.append((img_path, pid, camid)) 133 | 134 | return data 135 | 136 | def process_split(self, split): 137 | train, query, gallery = [], [], [] 138 | train_pid2label = self.get_pid2label(split['train']) 139 | train = self.parse_img_names(split['train'], train_pid2label) 140 | query = self.parse_img_names(split['query']) 141 | gallery = self.parse_img_names(split['gallery']) 142 | return train, query, gallery -------------------------------------------------------------------------------- /torchreid/data/datasets/image/market1501.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | import sys 6 | import os 7 | import os.path as osp 8 | import glob 9 | import re 10 | import warnings 11 | 12 | from torchreid.data.datasets import ImageDataset 13 | 14 | 15 | class Market1501(ImageDataset): 16 | """Market1501. 17 | 18 | Reference: 19 | Zheng et al. Scalable Person Re-identification: A Benchmark. ICCV 2015. 20 | 21 | URL: ``_ 22 | 23 | Dataset statistics: 24 | - identities: 1501 (+1 for background). 25 | - images: 12936 (train) + 3368 (query) + 15913 (gallery). 26 | """ 27 | _junk_pids = [0, -1] 28 | dataset_dir = 'market1501' 29 | dataset_url = 'http://188.138.127.15:81/Datasets/Market-1501-v15.09.15.zip' 30 | 31 | def __init__(self, root='', market1501_500k=False, **kwargs): 32 | self.root = osp.abspath(osp.expanduser(root)) 33 | self.dataset_dir = osp.join(self.root, self.dataset_dir) 34 | self.download_dataset(self.dataset_dir, self.dataset_url) 35 | 36 | # allow alternative directory structure 37 | self.data_dir = self.dataset_dir 38 | data_dir = osp.join(self.data_dir, 'Market-1501-v15.09.15') 39 | if osp.isdir(data_dir): 40 | self.data_dir = data_dir 41 | else: 42 | warnings.warn('The current data structure is deprecated. Please ' 43 | 'put data folders such as "bounding_box_train" under ' 44 | '"Market-1501-v15.09.15".') 45 | 46 | self.train_dir = osp.join(self.data_dir, 'bounding_box_train') 47 | self.query_dir = osp.join(self.data_dir, 'query') 48 | self.gallery_dir = osp.join(self.data_dir, 'bounding_box_test') 49 | self.extra_gallery_dir = osp.join(self.data_dir, 'images') 50 | self.market1501_500k = market1501_500k 51 | 52 | required_files = [ 53 | self.data_dir, 54 | self.train_dir, 55 | self.query_dir, 56 | self.gallery_dir 57 | ] 58 | if self.market1501_500k: 59 | required_files.append(self.extra_gallery_dir) 60 | self.check_before_run(required_files) 61 | 62 | train = self.process_dir(self.train_dir, relabel=True) 63 | query = self.process_dir(self.query_dir, relabel=False) 64 | gallery = self.process_dir(self.gallery_dir, relabel=False) 65 | if self.market1501_500k: 66 | gallery += self.process_dir(self.extra_gallery_dir, relabel=False) 67 | 68 | super(Market1501, self).__init__(train, query, gallery, **kwargs) 69 | 70 | def process_dir(self, dir_path, relabel=False): 71 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 72 | pattern = re.compile(r'([-\d]+)_c(\d)') 73 | 74 | pid_container = set() 75 | for img_path in img_paths: 76 | pid, _ = map(int, pattern.search(img_path).groups()) 77 | if pid == -1: 78 | continue # junk images are just ignored 79 | pid_container.add(pid) 80 | pid2label = {pid:label for label, pid in enumerate(pid_container)} 81 | 82 | data = [] 83 | for img_path in img_paths: 84 | pid, camid = map(int, pattern.search(img_path).groups()) 85 | if pid == -1: 86 | continue # junk images are just ignored 87 | assert 0 <= pid <= 1501 # pid == 0 means background 88 | assert 1 <= camid <= 6 89 | camid -= 1 # index starts from 0 90 | if relabel: 91 | pid = pid2label[pid] 92 | data.append((img_path, pid, camid)) 93 | 94 | return data -------------------------------------------------------------------------------- /torchreid/data/datasets/image/msmt17.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | import sys 6 | import os 7 | import os.path as osp 8 | import glob 9 | import re 10 | 11 | from torchreid.data.datasets import ImageDataset 12 | 13 | 14 | class MSMT17(ImageDataset): 15 | """ 16 | MSMT17. 17 | 18 | Reference: 19 | Wei et al. Person Transfer GAN to Bridge Domain Gap for Person Re-Identification. CVPR 2018. 20 | 21 | URL: ``_ 22 | 23 | Dataset statistics: 24 | - identities: 4101. 25 | - images: 32621 (train) + 11659 (query) + 82161 (gallery). 26 | - cameras: 15. 27 | """ 28 | 29 | dataset_dir = 'msmt17' 30 | 31 | def __init__(self, root='', **kwargs): 32 | self.root = osp.abspath(osp.expanduser(root)) 33 | self.dataset_dir = osp.join(self.root, self.dataset_dir) 34 | self.train_dir = osp.join(self.dataset_dir, 'MSMT17/bounding_box_train') 35 | self.query_dir = osp.join(self.dataset_dir, 'MSMT17/query') 36 | self.gallery_dir = osp.join(self.dataset_dir, 'MSMT17/bounding_box_test') 37 | 38 | required_files = [ 39 | self.dataset_dir, 40 | self.train_dir, 41 | self.query_dir, 42 | self.gallery_dir 43 | ] 44 | self.check_before_run(required_files) 45 | 46 | train = self.process_dir(self.train_dir, relabel=True) 47 | query = self.process_dir(self.query_dir, relabel=False) 48 | gallery = self.process_dir(self.gallery_dir, relabel=False) 49 | 50 | super(MSMT17, self).__init__(train, query, gallery, **kwargs) 51 | 52 | def process_dir(self, dir_path, relabel=False): 53 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 54 | pattern = re.compile(r'([-\d]+)_c(\d)') 55 | 56 | pid_container = set() 57 | for img_path in img_paths: 58 | pid, _ = map(int, pattern.search(img_path).groups()) 59 | pid_container.add(pid) 60 | pid2label = {pid:label for label, pid in enumerate(pid_container)} 61 | 62 | data = [] 63 | for img_path in img_paths: 64 | pid, camid = map(int, pattern.search(img_path).groups()) 65 | assert 1 <= camid <= 15 66 | camid -= 1 # index starts from 0 67 | if relabel: pid = pid2label[pid] 68 | data.append((img_path, pid, camid)) 69 | 70 | return data 71 | -------------------------------------------------------------------------------- /torchreid/data/datasets/image/prid.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | import sys 6 | import os 7 | import os.path as osp 8 | import random 9 | 10 | from torchreid.data.datasets import ImageDataset 11 | from torchreid.utils import read_json, write_json 12 | 13 | 14 | class PRID(ImageDataset): 15 | """PRID (single-shot version of prid-2011) 16 | 17 | Reference: 18 | Hirzer et al. Person Re-Identification by Descriptive and Discriminative 19 | Classification. SCIA 2011. 20 | 21 | URL: ``_ 22 | 23 | Dataset statistics: 24 | - Two views. 25 | - View A captures 385 identities. 26 | - View B captures 749 identities. 27 | - 200 identities appear in both views. 28 | """ 29 | dataset_dir = 'prid2011' 30 | dataset_url = None 31 | 32 | def __init__(self, root='', split_id=0, **kwargs): 33 | self.root = osp.abspath(osp.expanduser(root)) 34 | self.dataset_dir = osp.join(self.root, self.dataset_dir) 35 | self.download_dataset(self.dataset_dir, self.dataset_url) 36 | 37 | self.cam_a_dir = osp.join(self.dataset_dir, 'prid_2011', 'single_shot', 'cam_a') 38 | self.cam_b_dir = osp.join(self.dataset_dir, 'prid_2011', 'single_shot', 'cam_b') 39 | self.split_path = osp.join(self.dataset_dir, 'splits_single_shot.json') 40 | 41 | required_files = [ 42 | self.dataset_dir, 43 | self.cam_a_dir, 44 | self.cam_b_dir 45 | ] 46 | self.check_before_run(required_files) 47 | 48 | self.prepare_split() 49 | splits = read_json(self.split_path) 50 | if split_id >= len(splits): 51 | raise ValueError('split_id exceeds range, received {}, but expected between 0 and {}'.format(split_id, len(splits)-1)) 52 | split = splits[split_id] 53 | 54 | train, query, gallery = self.process_split(split) 55 | 56 | super(PRID, self).__init__(train, query, gallery, **kwargs) 57 | 58 | def prepare_split(self): 59 | if not osp.exists(self.split_path): 60 | print('Creating splits ...') 61 | 62 | splits = [] 63 | for _ in range(10): 64 | # randomly sample 100 IDs for train and use the rest 100 IDs for test 65 | # (note: there are only 200 IDs appearing in both views) 66 | pids = [i for i in range(1, 201)] 67 | train_pids = random.sample(pids, 100) 68 | train_pids.sort() 69 | test_pids = [i for i in pids if i not in train_pids] 70 | split = {'train': train_pids, 'test': test_pids} 71 | splits.append(split) 72 | 73 | print('Totally {} splits are created'.format(len(splits))) 74 | write_json(splits, self.split_path) 75 | print('Split file is saved to {}'.format(self.split_path)) 76 | 77 | def process_split(self, split): 78 | train, query, gallery = [], [], [] 79 | train_pids = split['train'] 80 | test_pids = split['test'] 81 | 82 | train_pid2label = {pid: label for label, pid in enumerate(train_pids)} 83 | 84 | # train 85 | train = [] 86 | for pid in train_pids: 87 | img_name = 'person_' + str(pid).zfill(4) + '.png' 88 | pid = train_pid2label[pid] 89 | img_a_path = osp.join(self.cam_a_dir, img_name) 90 | train.append((img_a_path, pid, 0)) 91 | img_b_path = osp.join(self.cam_b_dir, img_name) 92 | train.append((img_b_path, pid, 1)) 93 | 94 | # query and gallery 95 | query, gallery = [], [] 96 | for pid in test_pids: 97 | img_name = 'person_' + str(pid).zfill(4) + '.png' 98 | img_a_path = osp.join(self.cam_a_dir, img_name) 99 | query.append((img_a_path, pid, 0)) 100 | img_b_path = osp.join(self.cam_b_dir, img_name) 101 | gallery.append((img_b_path, pid, 1)) 102 | for pid in range(201, 750): 103 | img_name = 'person_' + str(pid).zfill(4) + '.png' 104 | img_b_path = osp.join(self.cam_b_dir, img_name) 105 | gallery.append((img_b_path, pid, 1)) 106 | 107 | return train, query, gallery -------------------------------------------------------------------------------- /torchreid/data/datasets/image/sensereid.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | import sys 6 | import os 7 | import os.path as osp 8 | import glob 9 | import copy 10 | 11 | from torchreid.data.datasets import ImageDataset 12 | 13 | 14 | class SenseReID(ImageDataset): 15 | """SenseReID. 16 | 17 | This dataset is used for test purpose only. 18 | 19 | Reference: 20 | Zhao et al. Spindle Net: Person Re-identification with Human Body 21 | Region Guided Feature Decomposition and Fusion. CVPR 2017. 22 | 23 | URL: ``_ 24 | 25 | Dataset statistics: 26 | - query: 522 ids, 1040 images. 27 | - gallery: 1717 ids, 3388 images. 28 | """ 29 | dataset_dir = 'sensereid' 30 | dataset_url = None 31 | 32 | def __init__(self, root='', **kwargs): 33 | self.root = osp.abspath(osp.expanduser(root)) 34 | self.dataset_dir = osp.join(self.root, self.dataset_dir) 35 | self.download_dataset(self.dataset_dir, self.dataset_url) 36 | 37 | self.query_dir = osp.join(self.dataset_dir, 'SenseReID', 'test_probe') 38 | self.gallery_dir = osp.join(self.dataset_dir, 'SenseReID', 'test_gallery') 39 | 40 | required_files = [ 41 | self.dataset_dir, 42 | self.query_dir, 43 | self.gallery_dir 44 | ] 45 | self.check_before_run(required_files) 46 | 47 | query = self.process_dir(self.query_dir) 48 | gallery = self.process_dir(self.gallery_dir) 49 | 50 | # relabel 51 | g_pids = set() 52 | for _, pid, _ in gallery: 53 | g_pids.add(pid) 54 | pid2label = {pid: i for i, pid in enumerate(g_pids)} 55 | 56 | query = [(img_path, pid2label[pid], camid) for img_path, pid, camid in query] 57 | gallery = [(img_path, pid2label[pid], camid) for img_path, pid, camid in gallery] 58 | train = copy.deepcopy(query) + copy.deepcopy(gallery) # dummy variable 59 | 60 | super(SenseReID, self).__init__(train, query, gallery, **kwargs) 61 | 62 | def process_dir(self, dir_path): 63 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 64 | data = [] 65 | 66 | for img_path in img_paths: 67 | img_name = osp.splitext(osp.basename(img_path))[0] 68 | pid, camid = img_name.split('_') 69 | pid, camid = int(pid), int(camid) 70 | data.append((img_path, pid, camid)) 71 | 72 | return data -------------------------------------------------------------------------------- /torchreid/data/datasets/image/viper.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | import sys 6 | import os 7 | import os.path as osp 8 | import glob 9 | import numpy as np 10 | 11 | from torchreid.data.datasets import ImageDataset 12 | from torchreid.utils import read_json, write_json 13 | 14 | 15 | class VIPeR(ImageDataset): 16 | """VIPeR. 17 | 18 | Reference: 19 | Gray et al. Evaluating appearance models for recognition, reacquisition, and tracking. PETS 2007. 20 | 21 | URL: ``_ 22 | 23 | Dataset statistics: 24 | - identities: 632. 25 | - images: 632 x 2 = 1264. 26 | - cameras: 2. 27 | """ 28 | dataset_dir = 'viper' 29 | dataset_url = 'http://users.soe.ucsc.edu/~manduchi/VIPeR.v1.0.zip' 30 | 31 | def __init__(self, root='', split_id=0, **kwargs): 32 | self.root = osp.abspath(osp.expanduser(root)) 33 | self.dataset_dir = osp.join(self.root, self.dataset_dir) 34 | self.download_dataset(self.dataset_dir, self.dataset_url) 35 | 36 | self.cam_a_dir = osp.join(self.dataset_dir, 'VIPeR', 'cam_a') 37 | self.cam_b_dir = osp.join(self.dataset_dir, 'VIPeR', 'cam_b') 38 | self.split_path = osp.join(self.dataset_dir, 'splits.json') 39 | 40 | required_files = [ 41 | self.dataset_dir, 42 | self.cam_a_dir, 43 | self.cam_b_dir 44 | ] 45 | self.check_before_run(required_files) 46 | 47 | self.prepare_split() 48 | splits = read_json(self.split_path) 49 | if split_id >= len(splits): 50 | raise ValueError('split_id exceeds range, received {}, ' 51 | 'but expected between 0 and {}'.format(split_id, len(splits)-1)) 52 | split = splits[split_id] 53 | 54 | train = split['train'] 55 | query = split['query'] # query and gallery share the same images 56 | gallery = split['gallery'] 57 | 58 | train = [tuple(item) for item in train] 59 | query = [tuple(item) for item in query] 60 | gallery = [tuple(item) for item in gallery] 61 | 62 | super(VIPeR, self).__init__(train, query, gallery, **kwargs) 63 | 64 | def prepare_split(self): 65 | if not osp.exists(self.split_path): 66 | print('Creating 10 random splits of train ids and test ids') 67 | 68 | cam_a_imgs = sorted(glob.glob(osp.join(self.cam_a_dir, '*.bmp'))) 69 | cam_b_imgs = sorted(glob.glob(osp.join(self.cam_b_dir, '*.bmp'))) 70 | assert len(cam_a_imgs) == len(cam_b_imgs) 71 | num_pids = len(cam_a_imgs) 72 | print('Number of identities: {}'.format(num_pids)) 73 | num_train_pids = num_pids // 2 74 | 75 | """ 76 | In total, there will be 20 splits because each random split creates two 77 | sub-splits, one using cameraA as query and cameraB as gallery 78 | while the other using cameraB as query and cameraA as gallery. 79 | Therefore, results should be averaged over 20 splits (split_id=0~19). 80 | 81 | In practice, a model trained on split_id=0 can be applied to split_id=0&1 82 | as split_id=0&1 share the same training data (so on and so forth). 83 | """ 84 | splits = [] 85 | for _ in range(10): 86 | order = np.arange(num_pids) 87 | np.random.shuffle(order) 88 | train_idxs = order[:num_train_pids] 89 | test_idxs = order[num_train_pids:] 90 | assert not bool(set(train_idxs) & set(test_idxs)), 'Error: train and test overlap' 91 | 92 | train = [] 93 | for pid, idx in enumerate(train_idxs): 94 | cam_a_img = cam_a_imgs[idx] 95 | cam_b_img = cam_b_imgs[idx] 96 | train.append((cam_a_img, pid, 0)) 97 | train.append((cam_b_img, pid, 1)) 98 | 99 | test_a = [] 100 | test_b = [] 101 | for pid, idx in enumerate(test_idxs): 102 | cam_a_img = cam_a_imgs[idx] 103 | cam_b_img = cam_b_imgs[idx] 104 | test_a.append((cam_a_img, pid, 0)) 105 | test_b.append((cam_b_img, pid, 1)) 106 | 107 | # use cameraA as query and cameraB as gallery 108 | split = { 109 | 'train': train, 110 | 'query': test_a, 111 | 'gallery': test_b, 112 | 'num_train_pids': num_train_pids, 113 | 'num_query_pids': num_pids - num_train_pids, 114 | 'num_gallery_pids': num_pids - num_train_pids 115 | } 116 | splits.append(split) 117 | 118 | # use cameraB as query and cameraA as gallery 119 | split = { 120 | 'train': train, 121 | 'query': test_b, 122 | 'gallery': test_a, 123 | 'num_train_pids': num_train_pids, 124 | 'num_query_pids': num_pids - num_train_pids, 125 | 'num_gallery_pids': num_pids - num_train_pids 126 | } 127 | splits.append(split) 128 | 129 | print('Totally {} splits are created'.format(len(splits))) 130 | write_json(splits, self.split_path) 131 | print('Split file saved to {}'.format(self.split_path)) -------------------------------------------------------------------------------- /torchreid/data/datasets/video/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | 4 | from .mars import Mars 5 | from .ilidsvid import iLIDSVID 6 | from .prid2011 import PRID2011 7 | from .dukemtmcvidreid import DukeMTMCVidReID -------------------------------------------------------------------------------- /torchreid/data/datasets/video/dukemtmcvidreid.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | import sys 6 | import os 7 | import os.path as osp 8 | import glob 9 | import warnings 10 | 11 | from torchreid.data.datasets import VideoDataset 12 | from torchreid.utils import read_json, write_json 13 | 14 | 15 | class DukeMTMCVidReID(VideoDataset): 16 | """DukeMTMCVidReID. 17 | 18 | Reference: 19 | - Ristani et al. Performance Measures and a Data Set for Multi-Target, 20 | Multi-Camera Tracking. ECCVW 2016. 21 | - Wu et al. Exploit the Unknown Gradually: One-Shot Video-Based Person 22 | Re-Identification by Stepwise Learning. CVPR 2018. 23 | 24 | URL: ``_ 25 | 26 | Dataset statistics: 27 | - identities: 702 (train) + 702 (test). 28 | - tracklets: 2196 (train) + 2636 (test). 29 | """ 30 | dataset_dir = 'dukemtmc-vidreid' 31 | dataset_url = 'http://vision.cs.duke.edu/DukeMTMC/data/misc/DukeMTMC-VideoReID.zip' 32 | 33 | def __init__(self, root='', min_seq_len=0, **kwargs): 34 | self.root = osp.abspath(osp.expanduser(root)) 35 | self.dataset_dir = osp.join(self.root, self.dataset_dir) 36 | self.download_dataset(self.dataset_dir, self.dataset_url) 37 | 38 | self.train_dir = osp.join(self.dataset_dir, 'DukeMTMC-VideoReID/train') 39 | self.query_dir = osp.join(self.dataset_dir, 'DukeMTMC-VideoReID/query') 40 | self.gallery_dir = osp.join(self.dataset_dir, 'DukeMTMC-VideoReID/gallery') 41 | self.split_train_json_path = osp.join(self.dataset_dir, 'split_train.json') 42 | self.split_query_json_path = osp.join(self.dataset_dir, 'split_query.json') 43 | self.split_gallery_json_path = osp.join(self.dataset_dir, 'split_gallery.json') 44 | self.min_seq_len = min_seq_len 45 | 46 | required_files = [ 47 | self.dataset_dir, 48 | self.train_dir, 49 | self.query_dir, 50 | self.gallery_dir 51 | ] 52 | self.check_before_run(required_files) 53 | 54 | train = self.process_dir(self.train_dir, self.split_train_json_path, relabel=True) 55 | query = self.process_dir(self.query_dir, self.split_query_json_path, relabel=False) 56 | gallery = self.process_dir(self.gallery_dir, self.split_gallery_json_path, relabel=False) 57 | 58 | super(DukeMTMCVidReID, self).__init__(train, query, gallery, **kwargs) 59 | 60 | def process_dir(self, dir_path, json_path, relabel): 61 | if osp.exists(json_path): 62 | split = read_json(json_path) 63 | return split['tracklets'] 64 | 65 | print('=> Generating split json file (** this might take a while **)') 66 | pdirs = glob.glob(osp.join(dir_path, '*')) # avoid .DS_Store 67 | print('Processing "{}" with {} person identities'.format(dir_path, len(pdirs))) 68 | 69 | pid_container = set() 70 | for pdir in pdirs: 71 | pid = int(osp.basename(pdir)) 72 | pid_container.add(pid) 73 | pid2label = {pid:label for label, pid in enumerate(pid_container)} 74 | 75 | tracklets = [] 76 | for pdir in pdirs: 77 | pid = int(osp.basename(pdir)) 78 | if relabel: 79 | pid = pid2label[pid] 80 | tdirs = glob.glob(osp.join(pdir, '*')) 81 | for tdir in tdirs: 82 | raw_img_paths = glob.glob(osp.join(tdir, '*.jpg')) 83 | num_imgs = len(raw_img_paths) 84 | 85 | if num_imgs < self.min_seq_len: 86 | continue 87 | 88 | img_paths = [] 89 | for img_idx in range(num_imgs): 90 | # some tracklet starts from 0002 instead of 0001 91 | img_idx_name = 'F' + str(img_idx+1).zfill(4) 92 | res = glob.glob(osp.join(tdir, '*' + img_idx_name + '*.jpg')) 93 | if len(res) == 0: 94 | warnings.warn('Index name {} in {} is missing, skip'.format(img_idx_name, tdir)) 95 | continue 96 | img_paths.append(res[0]) 97 | img_name = osp.basename(img_paths[0]) 98 | if img_name.find('_') == -1: 99 | # old naming format: 0001C6F0099X30823.jpg 100 | camid = int(img_name[5]) - 1 101 | else: 102 | # new naming format: 0001_C6_F0099_X30823.jpg 103 | camid = int(img_name[6]) - 1 104 | img_paths = tuple(img_paths) 105 | tracklets.append((img_paths, pid, camid)) 106 | 107 | print('Saving split to {}'.format(json_path)) 108 | split_dict = {'tracklets': tracklets} 109 | write_json(split_dict, json_path) 110 | 111 | return tracklets -------------------------------------------------------------------------------- /torchreid/data/datasets/video/ilidsvid.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | import sys 6 | import os 7 | import os.path as osp 8 | import glob 9 | from scipy.io import loadmat 10 | 11 | from torchreid.data.datasets import VideoDataset 12 | from torchreid.utils import read_json, write_json 13 | 14 | 15 | class iLIDSVID(VideoDataset): 16 | """iLIDS-VID. 17 | 18 | Reference: 19 | Wang et al. Person Re-Identification by Video Ranking. ECCV 2014. 20 | 21 | URL: ``_ 22 | 23 | Dataset statistics: 24 | - identities: 300. 25 | - tracklets: 600. 26 | - cameras: 2. 27 | """ 28 | dataset_dir = 'ilids-vid' 29 | dataset_url = 'http://www.eecs.qmul.ac.uk/~xiatian/iLIDS-VID/iLIDS-VID.tar' 30 | 31 | def __init__(self, root='', split_id=0, **kwargs): 32 | self.root = osp.abspath(osp.expanduser(root)) 33 | self.dataset_dir = osp.join(self.root, self.dataset_dir) 34 | self.download_dataset(self.dataset_dir, self.dataset_url) 35 | 36 | self.data_dir = osp.join(self.dataset_dir, 'i-LIDS-VID') 37 | self.split_dir = osp.join(self.dataset_dir, 'train-test people splits') 38 | self.split_mat_path = osp.join(self.split_dir, 'train_test_splits_ilidsvid.mat') 39 | self.split_path = osp.join(self.dataset_dir, 'splits.json') 40 | self.cam_1_path = osp.join(self.dataset_dir, 'i-LIDS-VID/sequences/cam1') 41 | self.cam_2_path = osp.join(self.dataset_dir, 'i-LIDS-VID/sequences/cam2') 42 | 43 | required_files = [ 44 | self.dataset_dir, 45 | self.data_dir, 46 | self.split_dir 47 | ] 48 | self.check_before_run(required_files) 49 | 50 | self.prepare_split() 51 | splits = read_json(self.split_path) 52 | if split_id >= len(splits): 53 | raise ValueError('split_id exceeds range, received {}, but expected between 0 and {}'.format(split_id, len(splits)-1)) 54 | split = splits[split_id] 55 | train_dirs, test_dirs = split['train'], split['test'] 56 | 57 | train = self.process_data(train_dirs, cam1=True, cam2=True) 58 | query = self.process_data(test_dirs, cam1=True, cam2=False) 59 | gallery = self.process_data(test_dirs, cam1=False, cam2=True) 60 | 61 | super(iLIDSVID, self).__init__(train, query, gallery, **kwargs) 62 | 63 | def prepare_split(self): 64 | if not osp.exists(self.split_path): 65 | print('Creating splits ...') 66 | mat_split_data = loadmat(self.split_mat_path)['ls_set'] 67 | 68 | num_splits = mat_split_data.shape[0] 69 | num_total_ids = mat_split_data.shape[1] 70 | assert num_splits == 10 71 | assert num_total_ids == 300 72 | num_ids_each = num_total_ids // 2 73 | 74 | # pids in mat_split_data are indices, so we need to transform them 75 | # to real pids 76 | person_cam1_dirs = sorted(glob.glob(osp.join(self.cam_1_path, '*'))) 77 | person_cam2_dirs = sorted(glob.glob(osp.join(self.cam_2_path, '*'))) 78 | 79 | person_cam1_dirs = [osp.basename(item) for item in person_cam1_dirs] 80 | person_cam2_dirs = [osp.basename(item) for item in person_cam2_dirs] 81 | 82 | # make sure persons in one camera view can be found in the other camera view 83 | assert set(person_cam1_dirs) == set(person_cam2_dirs) 84 | 85 | splits = [] 86 | for i_split in range(num_splits): 87 | # first 50% for testing and the remaining for training, following Wang et al. ECCV'14. 88 | train_idxs = sorted(list(mat_split_data[i_split, num_ids_each:])) 89 | test_idxs = sorted(list(mat_split_data[i_split, :num_ids_each])) 90 | 91 | train_idxs = [int(i)-1 for i in train_idxs] 92 | test_idxs = [int(i)-1 for i in test_idxs] 93 | 94 | # transform pids to person dir names 95 | train_dirs = [person_cam1_dirs[i] for i in train_idxs] 96 | test_dirs = [person_cam1_dirs[i] for i in test_idxs] 97 | 98 | split = {'train': train_dirs, 'test': test_dirs} 99 | splits.append(split) 100 | 101 | print('Totally {} splits are created, following Wang et al. ECCV\'14'.format(len(splits))) 102 | print('Split file is saved to {}'.format(self.split_path)) 103 | write_json(splits, self.split_path) 104 | 105 | def process_data(self, dirnames, cam1=True, cam2=True): 106 | tracklets = [] 107 | dirname2pid = {dirname:i for i, dirname in enumerate(dirnames)} 108 | 109 | for dirname in dirnames: 110 | if cam1: 111 | person_dir = osp.join(self.cam_1_path, dirname) 112 | img_names = glob.glob(osp.join(person_dir, '*.png')) 113 | assert len(img_names) > 0 114 | img_names = tuple(img_names) 115 | pid = dirname2pid[dirname] 116 | tracklets.append((img_names, pid, 0)) 117 | 118 | if cam2: 119 | person_dir = osp.join(self.cam_2_path, dirname) 120 | img_names = glob.glob(osp.join(person_dir, '*.png')) 121 | assert len(img_names) > 0 122 | img_names = tuple(img_names) 123 | pid = dirname2pid[dirname] 124 | tracklets.append((img_names, pid, 1)) 125 | 126 | return tracklets -------------------------------------------------------------------------------- /torchreid/data/datasets/video/mars.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | import sys 6 | import os 7 | import os.path as osp 8 | from scipy.io import loadmat 9 | import warnings 10 | 11 | from torchreid.data.datasets import VideoDataset 12 | 13 | 14 | class Mars(VideoDataset): 15 | """MARS. 16 | 17 | Reference: 18 | Zheng et al. MARS: A Video Benchmark for Large-Scale Person Re-identification. ECCV 2016. 19 | 20 | URL: ``_ 21 | 22 | Dataset statistics: 23 | - identities: 1261. 24 | - tracklets: 8298 (train) + 1980 (query) + 9330 (gallery). 25 | - cameras: 6. 26 | """ 27 | dataset_dir = 'mars' 28 | dataset_url = None 29 | 30 | def __init__(self, root='', **kwargs): 31 | self.root = osp.abspath(osp.expanduser(root)) 32 | self.dataset_dir = osp.join(self.root, self.dataset_dir) 33 | self.download_dataset(self.dataset_dir, self.dataset_url) 34 | 35 | self.train_name_path = osp.join(self.dataset_dir, 'info/train_name.txt') 36 | self.test_name_path = osp.join(self.dataset_dir, 'info/test_name.txt') 37 | self.track_train_info_path = osp.join(self.dataset_dir, 'info/tracks_train_info.mat') 38 | self.track_test_info_path = osp.join(self.dataset_dir, 'info/tracks_test_info.mat') 39 | self.query_IDX_path = osp.join(self.dataset_dir, 'info/query_IDX.mat') 40 | 41 | required_files = [ 42 | self.dataset_dir, 43 | self.train_name_path, 44 | self.test_name_path, 45 | self.track_train_info_path, 46 | self.track_test_info_path, 47 | self.query_IDX_path 48 | ] 49 | self.check_before_run(required_files) 50 | 51 | train_names = self.get_names(self.train_name_path) 52 | test_names = self.get_names(self.test_name_path) 53 | track_train = loadmat(self.track_train_info_path)['track_train_info'] # numpy.ndarray (8298, 4) 54 | track_test = loadmat(self.track_test_info_path)['track_test_info'] # numpy.ndarray (12180, 4) 55 | query_IDX = loadmat(self.query_IDX_path)['query_IDX'].squeeze() # numpy.ndarray (1980,) 56 | query_IDX -= 1 # index from 0 57 | track_query = track_test[query_IDX,:] 58 | gallery_IDX = [i for i in range(track_test.shape[0]) if i not in query_IDX] 59 | track_gallery = track_test[gallery_IDX,:] 60 | 61 | train = self.process_data(train_names, track_train, home_dir='bbox_train', relabel=True) 62 | query = self.process_data(test_names, track_query, home_dir='bbox_test', relabel=False) 63 | gallery = self.process_data(test_names, track_gallery, home_dir='bbox_test', relabel=False) 64 | 65 | super(Mars, self).__init__(train, query, gallery, **kwargs) 66 | 67 | def get_names(self, fpath): 68 | names = [] 69 | with open(fpath, 'r') as f: 70 | for line in f: 71 | new_line = line.rstrip() 72 | names.append(new_line) 73 | return names 74 | 75 | def process_data(self, names, meta_data, home_dir=None, relabel=False, min_seq_len=0): 76 | assert home_dir in ['bbox_train', 'bbox_test'] 77 | num_tracklets = meta_data.shape[0] 78 | pid_list = list(set(meta_data[:,2].tolist())) 79 | num_pids = len(pid_list) 80 | 81 | if relabel: pid2label = {pid:label for label, pid in enumerate(pid_list)} 82 | tracklets = [] 83 | 84 | for tracklet_idx in range(num_tracklets): 85 | data = meta_data[tracklet_idx,...] 86 | start_index, end_index, pid, camid = data 87 | if pid == -1: 88 | continue # junk images are just ignored 89 | assert 1 <= camid <= 6 90 | if relabel: pid = pid2label[pid] 91 | camid -= 1 # index starts from 0 92 | img_names = names[start_index - 1:end_index] 93 | 94 | # make sure image names correspond to the same person 95 | pnames = [img_name[:4] for img_name in img_names] 96 | assert len(set(pnames)) == 1, 'Error: a single tracklet contains different person images' 97 | 98 | # make sure all images are captured under the same camera 99 | camnames = [img_name[5] for img_name in img_names] 100 | assert len(set(camnames)) == 1, 'Error: images are captured under different cameras!' 101 | 102 | # append image names with directory information 103 | img_paths = [osp.join(self.dataset_dir, home_dir, img_name[:4], img_name) for img_name in img_names] 104 | if len(img_paths) >= min_seq_len: 105 | img_paths = tuple(img_paths) 106 | tracklets.append((img_paths, pid, camid)) 107 | 108 | return tracklets 109 | 110 | def combine_all(self): 111 | warnings.warn('Some query IDs do not appear in gallery. Therefore, combineall ' 112 | 'does not make any difference to Mars') -------------------------------------------------------------------------------- /torchreid/data/datasets/video/prid2011.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | import sys 6 | import os 7 | import os.path as osp 8 | import glob 9 | 10 | from torchreid.data.datasets import VideoDataset 11 | from torchreid.utils import read_json, write_json 12 | 13 | 14 | class PRID2011(VideoDataset): 15 | """PRID2011. 16 | 17 | Reference: 18 | Hirzer et al. Person Re-Identification by Descriptive and 19 | Discriminative Classification. SCIA 2011. 20 | 21 | URL: ``_ 22 | 23 | Dataset statistics: 24 | - identities: 200. 25 | - tracklets: 400. 26 | - cameras: 2. 27 | """ 28 | dataset_dir = 'prid2011' 29 | dataset_url = None 30 | 31 | def __init__(self, root='', split_id=0, **kwargs): 32 | self.root = osp.abspath(osp.expanduser(root)) 33 | self.dataset_dir = osp.join(self.root, self.dataset_dir) 34 | self.download_dataset(self.dataset_dir, self.dataset_url) 35 | 36 | self.split_path = osp.join(self.dataset_dir, 'splits_prid2011.json') 37 | self.cam_a_dir = osp.join(self.dataset_dir, 'prid_2011', 'multi_shot', 'cam_a') 38 | self.cam_b_dir = osp.join(self.dataset_dir, 'prid_2011', 'multi_shot', 'cam_b') 39 | 40 | required_files = [ 41 | self.dataset_dir, 42 | self.cam_a_dir, 43 | self.cam_b_dir 44 | ] 45 | self.check_before_run(required_files) 46 | 47 | splits = read_json(self.split_path) 48 | if split_id >= len(splits): 49 | raise ValueError('split_id exceeds range, received {}, but expected between 0 and {}'.format(split_id, len(splits)-1)) 50 | split = splits[split_id] 51 | train_dirs, test_dirs = split['train'], split['test'] 52 | 53 | train = self.process_dir(train_dirs, cam1=True, cam2=True) 54 | query = self.process_dir(test_dirs, cam1=True, cam2=False) 55 | gallery = self.process_dir(test_dirs, cam1=False, cam2=True) 56 | 57 | super(PRID2011, self).__init__(train, query, gallery, **kwargs) 58 | 59 | def process_dir(self, dirnames, cam1=True, cam2=True): 60 | tracklets = [] 61 | dirname2pid = {dirname:i for i, dirname in enumerate(dirnames)} 62 | 63 | for dirname in dirnames: 64 | if cam1: 65 | person_dir = osp.join(self.cam_a_dir, dirname) 66 | img_names = glob.glob(osp.join(person_dir, '*.png')) 67 | assert len(img_names) > 0 68 | img_names = tuple(img_names) 69 | pid = dirname2pid[dirname] 70 | tracklets.append((img_names, pid, 0)) 71 | 72 | if cam2: 73 | person_dir = osp.join(self.cam_b_dir, dirname) 74 | img_names = glob.glob(osp.join(person_dir, '*.png')) 75 | assert len(img_names) > 0 76 | img_names = tuple(img_names) 77 | pid = dirname2pid[dirname] 78 | tracklets.append((img_names, pid, 1)) 79 | 80 | return tracklets -------------------------------------------------------------------------------- /torchreid/data/sampler.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | 4 | from collections import defaultdict 5 | import numpy as np 6 | import copy 7 | import random 8 | 9 | import torch 10 | from torch.utils.data.sampler import Sampler, RandomSampler 11 | 12 | 13 | class RandomIdentitySampler(Sampler): 14 | """Randomly samples N identities each with K instances. 15 | 16 | Args: 17 | data_source (list): contains tuples of (img_path(s), pid, camid). 18 | batch_size (int): batch size. 19 | num_instances (int): number of instances per identity in a batch. 20 | """ 21 | def __init__(self, data_source, batch_size, num_instances): 22 | if batch_size < num_instances: 23 | raise ValueError('batch_size={} must be no less ' 24 | 'than num_instances={}'.format(batch_size, num_instances)) 25 | 26 | self.data_source = data_source 27 | self.batch_size = batch_size 28 | self.num_instances = num_instances 29 | self.num_pids_per_batch = self.batch_size // self.num_instances 30 | self.index_dic = defaultdict(list) 31 | for index, (_, pid, _) in enumerate(self.data_source): 32 | self.index_dic[pid].append(index) 33 | self.pids = list(self.index_dic.keys()) 34 | 35 | # estimate number of examples in an epoch 36 | # TODO: improve precision 37 | self.length = 0 38 | for pid in self.pids: 39 | idxs = self.index_dic[pid] 40 | num = len(idxs) 41 | if num < self.num_instances: 42 | num = self.num_instances 43 | self.length += num - num % self.num_instances 44 | 45 | def __iter__(self): 46 | batch_idxs_dict = defaultdict(list) 47 | 48 | for pid in self.pids: 49 | idxs = copy.deepcopy(self.index_dic[pid]) 50 | if len(idxs) < self.num_instances: 51 | idxs = np.random.choice(idxs, size=self.num_instances, replace=True) 52 | random.shuffle(idxs) 53 | batch_idxs = [] 54 | for idx in idxs: 55 | batch_idxs.append(idx) 56 | if len(batch_idxs) == self.num_instances: 57 | batch_idxs_dict[pid].append(batch_idxs) 58 | batch_idxs = [] 59 | 60 | avai_pids = copy.deepcopy(self.pids) 61 | final_idxs = [] 62 | 63 | while len(avai_pids) >= self.num_pids_per_batch: 64 | selected_pids = random.sample(avai_pids, self.num_pids_per_batch) 65 | for pid in selected_pids: 66 | batch_idxs = batch_idxs_dict[pid].pop(0) 67 | final_idxs.extend(batch_idxs) 68 | if len(batch_idxs_dict[pid]) == 0: 69 | avai_pids.remove(pid) 70 | 71 | ''' 20191114_1 72 | while self.length-len(final_idxs) >= self.batch_size: 73 | for i in range(self.num_pids_per_batch): 74 | selected_pid = random.sample(avai_pids, 1) 75 | batch_idxs = batch_idxs_dict[selected_pid[0]].pop(0) 76 | final_idxs.extend(batch_idxs) 77 | if len(batch_idxs_dict[selected_pid[0]]) == 0: 78 | avai_pids.remove(selected_pid[0]) 79 | ''' 80 | 81 | return iter(final_idxs) 82 | 83 | def __len__(self): 84 | return self.length 85 | 86 | 87 | class TripletSampler(Sampler): 88 | """Randomly samples N identities each with K instances. Each identity should at least have K samples, or it will be omitted. 89 | 90 | Args: 91 | data_source (list): contains tuples of (img_path(s), pid, camid). 92 | batch_size (int): batch size. 93 | num_instances (int): number of instances per identity in a batch. 94 | """ 95 | def __init__(self, data_source, batch_size, num_instances): 96 | if batch_size < num_instances: 97 | raise ValueError('batch_size={} must be no less ' 98 | 'than num_instances={}'.format(batch_size, num_instances)) 99 | 100 | self.data_source = data_source 101 | self.batch_size = batch_size 102 | self.num_instances = num_instances 103 | self.num_pids_per_batch = self.batch_size // self.num_instances 104 | self.index_dic = defaultdict(list) 105 | for index, (_, pid, _) in enumerate(self.data_source): 106 | self.index_dic[pid].append(index) 107 | self.pids = list(self.index_dic.keys()) 108 | 109 | index_dic_temp = copy.deepcopy(self.index_dic) 110 | pids_temp = copy.deepcopy(self.pids) 111 | 112 | self.length = 0 113 | for pid in pids_temp: 114 | idxs = index_dic_temp[pid] 115 | num = len(idxs) 116 | if num < self.num_instances: 117 | self.pids.remove(pid) 118 | del self.index_dic[pid] 119 | else: 120 | self.length += num - num % self.num_instances 121 | 122 | 123 | def __iter__(self): 124 | batch_idxs_dict = defaultdict(list) 125 | 126 | for pid in self.pids: 127 | idxs = copy.deepcopy(self.index_dic[pid]) 128 | # if len(idxs) < self.num_instances: 129 | # idxs = np.random.choice(idxs, size=self.num_instances, replace=True) 130 | random.shuffle(idxs) 131 | batch_idxs = [] 132 | for idx in idxs: 133 | batch_idxs.append(idx) 134 | if len(batch_idxs) == self.num_instances: 135 | batch_idxs_dict[pid].append(batch_idxs) 136 | batch_idxs = [] 137 | 138 | avai_pids = copy.deepcopy(self.pids) 139 | final_idxs = [] 140 | 141 | while len(avai_pids) >= self.num_pids_per_batch: 142 | selected_pids = random.sample(avai_pids, self.num_pids_per_batch) 143 | for pid in selected_pids: 144 | batch_idxs = batch_idxs_dict[pid].pop(0) 145 | final_idxs.extend(batch_idxs) 146 | if len(batch_idxs_dict[pid]) == 0: 147 | avai_pids.remove(pid) 148 | 149 | while self.length-len(final_idxs) >= self.batch_size: 150 | for i in range(self.num_pids_per_batch): 151 | selected_pid = random.sample(avai_pids, 1) 152 | batch_idxs = batch_idxs_dict[selected_pid[0]].pop(0) 153 | final_idxs.extend(batch_idxs) 154 | if len(batch_idxs_dict[selected_pid[0]]) == 0: 155 | avai_pids.remove(selected_pid[0]) 156 | 157 | return iter(final_idxs) 158 | 159 | def __len__(self): 160 | return self.length 161 | 162 | 163 | def build_train_sampler(data_source, train_sampler, batch_size=32, num_instances=4, **kwargs): 164 | """Builds a training sampler. 165 | 166 | Args: 167 | data_source (list): contains tuples of (img_path(s), pid, camid). 168 | train_sampler (str): sampler name (default: ``RandomSampler``). 169 | batch_size (int, optional): batch size. Default is 32. 170 | num_instances (int, optional): number of instances per identity in a 171 | batch (for ``RandomIdentitySampler``). Default is 4. 172 | """ 173 | if train_sampler == 'RandomIdentitySampler': 174 | sampler = RandomIdentitySampler(data_source, batch_size, num_instances) 175 | elif train_sampler == 'TripletSampler': 176 | sampler = TripletSampler(data_source, batch_size, num_instances) 177 | else: 178 | sampler = RandomSampler(data_source) 179 | 180 | return sampler -------------------------------------------------------------------------------- /torchreid/engine/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | 4 | from .engine import Engine 5 | 6 | from .image import ImageSoftmaxEngine 7 | from .image import ImageTripletEngine 8 | from .image import ImageFPBEngine 9 | 10 | from .video import VideoSoftmaxEngine 11 | from .video import VideoTripletEngine -------------------------------------------------------------------------------- /torchreid/engine/image/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .softmax import ImageSoftmaxEngine 4 | from .triplet import ImageTripletEngine 5 | from .engine_FPB import ImageFPBEngine -------------------------------------------------------------------------------- /torchreid/engine/image/engine_FPB.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | import time 6 | import datetime 7 | 8 | import torch 9 | 10 | import torchreid 11 | from torchreid.engine import engine 12 | from torchreid.losses import CrossEntropyLoss, TripletLoss, CenterLoss 13 | from torchreid.utils import AverageMeter, open_specified_layers, open_all_layers 14 | from torchreid import metrics 15 | 16 | 17 | class ImageFPBEngine(engine.Engine): 18 | r"""PcbAttentionEngine 19 | """ 20 | 21 | def __init__(self, datamanager, model, optimizer, margin=0.3, 22 | weight_t=1, weight_x=1, scheduler=None, use_gpu=True, 23 | label_smooth=True, div_penalty=None, div_start=0): 24 | super(ImageFPBEngine, self).__init__(datamanager, model, optimizer, scheduler, use_gpu) 25 | 26 | self.weight_t = weight_t 27 | self.weight_x = weight_x 28 | self.num_parts = 3 29 | self.feature_dim = 2048+(1024)*self.num_parts 30 | self.centloss_weight = 1.0+self.num_parts*(1.0) 31 | 32 | self.div_penalty = div_penalty 33 | self.div_start = div_start 34 | 35 | self.criterion_t = TripletLoss(margin=margin) 36 | self.criterion_x = CrossEntropyLoss( 37 | num_classes=self.datamanager.num_train_pids, 38 | use_gpu=self.use_gpu, 39 | label_smooth=label_smooth 40 | ) 41 | self.criterion_c = CenterLoss(num_classes=751, feat_dim=self.feature_dim) 42 | 43 | 44 | def train(self, epoch, max_epoch, trainloader, fixbase_epoch=0, open_layers=None, print_freq=10): 45 | losses_t = AverageMeter() 46 | losses_x = AverageMeter() 47 | losses_c = AverageMeter() 48 | losses_p = AverageMeter() 49 | accs = AverageMeter() 50 | batch_time = AverageMeter() 51 | data_time = AverageMeter() 52 | 53 | 54 | if self.div_penalty is not None and epoch >= self.div_start: 55 | print("Using div penalty!") 56 | 57 | self.model.train() 58 | if (epoch+1)<=fixbase_epoch and open_layers is not None: 59 | print('* Only train {} (epoch: {}/{})'.format(open_layers, epoch+1, fixbase_epoch)) 60 | open_specified_layers(self.model, open_layers) 61 | else: 62 | open_all_layers(self.model) 63 | 64 | num_batches = len(trainloader) 65 | end = time.time() 66 | for batch_idx, data in enumerate(trainloader): 67 | data_time.update(time.time() - end) 68 | 69 | imgs, pids = self._parse_data_for_train(data) 70 | if self.use_gpu: 71 | imgs = imgs.cuda() 72 | pids = pids.cuda() 73 | 74 | self.optimizer.zero_grad() 75 | output, fea, reg_feat = self.model(imgs) 76 | 77 | b = output[0].size(0) # 78 | loss_c = self._compute_loss(self.criterion_c, fea, pids) 79 | loss_t = self._compute_loss(self.criterion_t, fea, pids) 80 | loss_x = self._compute_loss(self.criterion_x, output, pids) 81 | loss = self.weight_x * loss_x + self.weight_t * loss_t + 0.0005 / self.centloss_weight * loss_c 82 | 83 | if self.div_penalty is not None: 84 | penalty = self.div_penalty(reg_feat) 85 | 86 | if epoch >= self.div_start: 87 | loss += penalty 88 | 89 | losses_p.update(penalty.item(), b) 90 | 91 | loss.backward() 92 | self.optimizer.step() 93 | 94 | batch_time.update(time.time() - end) 95 | 96 | losses_t.update(loss_t.item(), b) 97 | losses_x.update(loss_x.item(), b) 98 | losses_c.update(loss_c.item(), b) 99 | 100 | accs.update(metrics.accuracy(output, pids)[0].item()) 101 | 102 | if (batch_idx+1) % print_freq == 0: 103 | # estimate remaining time 104 | eta_seconds = batch_time.avg * (num_batches-(batch_idx+1) + (max_epoch-(epoch+1))*num_batches) 105 | eta_str = str(datetime.timedelta(seconds=int(eta_seconds))) 106 | print('Epoch: [{0}/{1}][{2}/{3}]\t' 107 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 108 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 109 | 'Loss_t {loss_t.val:.4f} ({loss_t.avg:.4f})\t' 110 | 'Loss_x {loss_x.val:.4f} ({loss_x.avg:.4f})\t' 111 | 'Loss_c {loss_c.val:.4f} ({loss_c.avg:.4f})\t' 112 | 'Acc {acc.val:.2f} ({acc.avg:.2f})\t' 113 | 'Lr {lr:.6f}\t' 114 | 'eta {eta}'.format( 115 | epoch+1, max_epoch, batch_idx+1, num_batches, 116 | batch_time=batch_time, 117 | data_time=data_time, 118 | loss_t=losses_t, 119 | loss_x=losses_x, 120 | loss_c=losses_c, 121 | acc=accs, 122 | lr=self.optimizer.param_groups[0]['lr'], 123 | eta=eta_str 124 | ) 125 | ) 126 | 127 | if self.writer is not None: 128 | n_iter = epoch * num_batches + batch_idx 129 | self.writer.add_scalar('Train/Time', batch_time.avg, n_iter) 130 | self.writer.add_scalar('Train/Data', data_time.avg, n_iter) 131 | self.writer.add_scalar('Train/Loss_t', losses_t.val, n_iter) 132 | self.writer.add_scalar('Train/Loss_x', losses_x.val, n_iter) 133 | self.writer.add_scalar('Train/Loss_c', losses_c.val, n_iter) 134 | self.writer.add_scalar('Train/Acc1', accs.val, n_iter) 135 | self.writer.add_scalar('Train/Lr', self.optimizer.param_groups[0]['lr'], n_iter) 136 | if self.div_penalty is not None: 137 | self.writer.add_scalar('Train/Loss_p', losses_p.val, n_iter) 138 | 139 | end = time.time() 140 | 141 | if self.scheduler is not None: 142 | self.scheduler.step() 143 | -------------------------------------------------------------------------------- /torchreid/engine/image/softmax.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | import time 6 | import datetime 7 | 8 | import torch 9 | 10 | import torchreid 11 | from torchreid.engine import engine 12 | from torchreid.losses import CrossEntropyLoss 13 | from torchreid.utils import AverageMeter, open_specified_layers, open_all_layers 14 | from torchreid import metrics 15 | 16 | 17 | class ImageSoftmaxEngine(engine.Engine): 18 | r"""Softmax-loss engine for image-reid. 19 | 20 | Args: 21 | datamanager (DataManager): an instance of ``torchreid.data.ImageDataManager`` 22 | or ``torchreid.data.VideoDataManager``. 23 | model (nn.Module): model instance. 24 | optimizer (Optimizer): an Optimizer. 25 | scheduler (LRScheduler, optional): if None, no learning rate decay will be performed. 26 | use_gpu (bool, optional): use gpu. Default is True. 27 | label_smooth (bool, optional): use label smoothing regularizer. Default is True. 28 | 29 | Examples:: 30 | 31 | import torch 32 | import torchreid 33 | datamanager = torchreid.data.ImageDataManager( 34 | root='path/to/reid-data', 35 | sources='market1501', 36 | height=256, 37 | width=128, 38 | combineall=False, 39 | batch_size=32 40 | ) 41 | model = torchreid.models.build_model( 42 | name='resnet50', 43 | num_classes=datamanager.num_train_pids, 44 | loss='softmax' 45 | ) 46 | model = model.cuda() 47 | optimizer = torchreid.optim.build_optimizer( 48 | model, optim='adam', lr=0.0003 49 | ) 50 | scheduler = torchreid.optim.build_lr_scheduler( 51 | optimizer, 52 | lr_scheduler='single_step', 53 | stepsize=20 54 | ) 55 | engine = torchreid.engine.ImageSoftmaxEngine( 56 | datamanager, model, optimizer, scheduler=scheduler 57 | ) 58 | engine.run( 59 | max_epoch=60, 60 | save_dir='log/resnet50-softmax-market1501', 61 | print_freq=10 62 | ) 63 | """ 64 | 65 | def __init__(self, datamanager, model, optimizer, scheduler=None, use_gpu=True, 66 | label_smooth=True): 67 | super(ImageSoftmaxEngine, self).__init__(datamanager, model, optimizer, scheduler, use_gpu) 68 | 69 | self.criterion = CrossEntropyLoss( 70 | num_classes=self.datamanager.num_train_pids, 71 | use_gpu=self.use_gpu, 72 | label_smooth=label_smooth 73 | ) 74 | 75 | def train(self, epoch, max_epoch, trainloader, fixbase_epoch=0, open_layers=None, print_freq=10): 76 | losses = AverageMeter() 77 | accs = AverageMeter() 78 | batch_time = AverageMeter() 79 | data_time = AverageMeter() 80 | 81 | self.model.train() 82 | if (epoch+1)<=fixbase_epoch and open_layers is not None: 83 | print('* Only train {} (epoch: {}/{})'.format(open_layers, epoch+1, fixbase_epoch)) 84 | open_specified_layers(self.model, open_layers) 85 | else: 86 | open_all_layers(self.model) 87 | 88 | num_batches = len(trainloader) 89 | end = time.time() 90 | for batch_idx, data in enumerate(trainloader): 91 | data_time.update(time.time() - end) 92 | 93 | imgs, pids = self._parse_data_for_train(data) 94 | if self.use_gpu: 95 | imgs = imgs.cuda() 96 | pids = pids.cuda() 97 | 98 | self.optimizer.zero_grad() 99 | output1 = self.model(imgs) 100 | loss1 = self._compute_loss(self.criterion, output1, pids) 101 | loss = loss1 102 | loss.backward() 103 | self.optimizer.step() 104 | 105 | batch_time.update(time.time() - end) 106 | 107 | losses.update(loss.item(), pids.size(0)) 108 | accs.update(metrics.accuracy(output1, pids)[0].item()) 109 | 110 | if (batch_idx+1) % print_freq == 0: 111 | # estimate remaining time 112 | eta_seconds = batch_time.avg * (num_batches-(batch_idx+1) + (max_epoch-(epoch+1))*num_batches) 113 | eta_str = str(datetime.timedelta(seconds=int(eta_seconds))) 114 | print('Epoch: [{0}/{1}][{2}/{3}]\t' 115 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 116 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 117 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 118 | 'Acc {acc.val:.2f} ({acc.avg:.2f})\t' 119 | 'Lr {lr:.6f}\t' 120 | 'eta {eta}'.format( 121 | epoch+1, max_epoch, batch_idx+1, num_batches, 122 | batch_time=batch_time, 123 | data_time=data_time, 124 | loss=losses, 125 | acc=accs, 126 | lr=self.optimizer.param_groups[0]['lr'], 127 | eta=eta_str 128 | ) 129 | ) 130 | 131 | if self.writer is not None: 132 | n_iter = epoch * num_batches + batch_idx 133 | self.writer.add_scalar('Train/Time', batch_time.avg, n_iter) 134 | self.writer.add_scalar('Train/Data', data_time.avg, n_iter) 135 | self.writer.add_scalar('Train/Loss', losses.avg, n_iter) 136 | self.writer.add_scalar('Train/Acc', accs.avg, n_iter) 137 | self.writer.add_scalar('Train/Lr', self.optimizer.param_groups[0]['lr'], n_iter) 138 | 139 | end = time.time() 140 | 141 | if self.scheduler is not None: 142 | self.scheduler.step() 143 | -------------------------------------------------------------------------------- /torchreid/engine/image/triplet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | import time 6 | import datetime 7 | 8 | import torch 9 | 10 | import torchreid 11 | from torchreid.engine import engine 12 | from torchreid.losses import CrossEntropyLoss, TripletLoss 13 | from torchreid.utils import AverageMeter, open_specified_layers, open_all_layers 14 | from torchreid import metrics 15 | 16 | 17 | class ImageTripletEngine(engine.Engine): 18 | r"""Triplet-loss engine for image-reid. 19 | 20 | Args: 21 | datamanager (DataManager): an instance of ``torchreid.data.ImageDataManager`` 22 | or ``torchreid.data.VideoDataManager``. 23 | model (nn.Module): model instance. 24 | optimizer (Optimizer): an Optimizer. 25 | margin (float, optional): margin for triplet loss. Default is 0.3. 26 | weight_t (float, optional): weight for triplet loss. Default is 1. 27 | weight_x (float, optional): weight for softmax loss. Default is 1. 28 | scheduler (LRScheduler, optional): if None, no learning rate decay will be performed. 29 | use_gpu (bool, optional): use gpu. Default is True. 30 | label_smooth (bool, optional): use label smoothing regularizer. Default is True. 31 | 32 | Examples:: 33 | 34 | import torch 35 | import torchreid 36 | datamanager = torchreid.data.ImageDataManager( 37 | root='path/to/reid-data', 38 | sources='market1501', 39 | height=256, 40 | width=128, 41 | combineall=False, 42 | batch_size=32, 43 | num_instances=4, 44 | train_sampler='RandomIdentitySampler' # this is important 45 | ) 46 | model = torchreid.models.build_model( 47 | name='resnet50', 48 | num_classes=datamanager.num_train_pids, 49 | loss='triplet' 50 | ) 51 | model = model.cuda() 52 | optimizer = torchreid.optim.build_optimizer( 53 | model, optim='adam', lr=0.0003 54 | ) 55 | scheduler = torchreid.optim.build_lr_scheduler( 56 | optimizer, 57 | lr_scheduler='single_step', 58 | stepsize=20 59 | ) 60 | engine = torchreid.engine.ImageTripletEngine( 61 | datamanager, model, optimizer, margin=0.3, 62 | weight_t=0.7, weight_x=1, scheduler=scheduler 63 | ) 64 | engine.run( 65 | max_epoch=60, 66 | save_dir='log/resnet50-triplet-market1501', 67 | print_freq=10 68 | ) 69 | """ 70 | 71 | def __init__(self, datamanager, model, optimizer, margin=0.3, 72 | weight_t=1, weight_x=1, scheduler=None, use_gpu=True, 73 | label_smooth=True): 74 | super(ImageTripletEngine, self).__init__(datamanager, model, optimizer, scheduler, use_gpu) 75 | 76 | self.weight_t = weight_t 77 | self.weight_x = weight_x 78 | 79 | self.criterion_t = TripletLoss(margin=margin) 80 | self.criterion_x = CrossEntropyLoss( 81 | num_classes=self.datamanager.num_train_pids, 82 | use_gpu=self.use_gpu, 83 | label_smooth=label_smooth 84 | ) 85 | 86 | def train(self, epoch, max_epoch, trainloader, fixbase_epoch=0, open_layers=None, print_freq=10): 87 | losses_t = AverageMeter() 88 | losses_x = AverageMeter() 89 | accs = AverageMeter() 90 | batch_time = AverageMeter() 91 | data_time = AverageMeter() 92 | 93 | self.model.train() 94 | if (epoch+1)<=fixbase_epoch and open_layers is not None: 95 | print('* Only train {} (epoch: {}/{})'.format(open_layers, epoch+1, fixbase_epoch)) 96 | open_specified_layers(self.model, open_layers) 97 | else: 98 | open_all_layers(self.model) 99 | 100 | num_batches = len(trainloader) 101 | end = time.time() 102 | for batch_idx, data in enumerate(trainloader): 103 | data_time.update(time.time() - end) 104 | 105 | imgs, pids = self._parse_data_for_train(data) 106 | if self.use_gpu: 107 | imgs = imgs.cuda() 108 | pids = pids.cuda() 109 | 110 | self.optimizer.zero_grad() 111 | outputs, features = self.model(imgs) 112 | loss_t = self._compute_loss(self.criterion_t, features, pids) 113 | loss_x = self._compute_loss(self.criterion_x, outputs, pids) 114 | loss = self.weight_t * loss_t + self.weight_x * loss_x 115 | loss.backward() 116 | self.optimizer.step() 117 | 118 | batch_time.update(time.time() - end) 119 | 120 | losses_t.update(loss_t.item(), pids.size(0)) 121 | losses_x.update(loss_x.item(), pids.size(0)) 122 | accs.update(metrics.accuracy(outputs, pids)[0].item()) 123 | 124 | if (batch_idx+1) % print_freq == 0: 125 | # estimate remaining time 126 | eta_seconds = batch_time.avg * (num_batches-(batch_idx+1) + (max_epoch-(epoch+1))*num_batches) 127 | eta_str = str(datetime.timedelta(seconds=int(eta_seconds))) 128 | print('Epoch: [{0}/{1}][{2}/{3}]\t' 129 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 130 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 131 | 'Loss_t {loss_t.val:.4f} ({loss_t.avg:.4f})\t' 132 | 'Loss_x {loss_x.val:.4f} ({loss_x.avg:.4f})\t' 133 | 'Acc {acc.val:.2f} ({acc.avg:.2f})\t' 134 | 'Lr {lr:.6f}\t' 135 | 'eta {eta}'.format( 136 | epoch+1, max_epoch, batch_idx+1, num_batches, 137 | batch_time=batch_time, 138 | data_time=data_time, 139 | loss_t=losses_t, 140 | loss_x=losses_x, 141 | acc=accs, 142 | lr=self.optimizer.param_groups[0]['lr'], 143 | eta=eta_str 144 | ) 145 | ) 146 | 147 | if self.writer is not None: 148 | n_iter = epoch * num_batches + batch_idx 149 | self.writer.add_scalar('Train/Time', batch_time.avg, n_iter) 150 | self.writer.add_scalar('Train/Data', data_time.avg, n_iter) 151 | self.writer.add_scalar('Train/Loss_t', losses_t.avg, n_iter) 152 | self.writer.add_scalar('Train/Loss_x', losses_x.avg, n_iter) 153 | self.writer.add_scalar('Train/Acc', accs.avg, n_iter) 154 | self.writer.add_scalar('Train/Lr', self.optimizer.param_groups[0]['lr'], n_iter) 155 | 156 | end = time.time() 157 | 158 | if self.scheduler is not None: 159 | self.scheduler.step() 160 | -------------------------------------------------------------------------------- /torchreid/engine/video/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .softmax import VideoSoftmaxEngine 4 | from .triplet import VideoTripletEngine -------------------------------------------------------------------------------- /torchreid/engine/video/softmax.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | import time 6 | import datetime 7 | 8 | import torch 9 | 10 | import torchreid 11 | from torchreid.engine.image import ImageSoftmaxEngine 12 | 13 | 14 | class VideoSoftmaxEngine(ImageSoftmaxEngine): 15 | """Softmax-loss engine for video-reid. 16 | 17 | Args: 18 | datamanager (DataManager): an instance of ``torchreid.data.ImageDataManager`` 19 | or ``torchreid.data.VideoDataManager``. 20 | model (nn.Module): model instance. 21 | optimizer (Optimizer): an Optimizer. 22 | scheduler (LRScheduler, optional): if None, no learning rate decay will be performed. 23 | use_gpu (bool, optional): use gpu. Default is True. 24 | label_smooth (bool, optional): use label smoothing regularizer. Default is True. 25 | pooling_method (str, optional): how to pool features for a tracklet. 26 | Default is "avg" (average). Choices are ["avg", "max"]. 27 | 28 | Examples:: 29 | 30 | import torch 31 | import torchreid 32 | # Each batch contains batch_size*seq_len images 33 | datamanager = torchreid.data.VideoDataManager( 34 | root='path/to/reid-data', 35 | sources='mars', 36 | height=256, 37 | width=128, 38 | combineall=False, 39 | batch_size=8, # number of tracklets 40 | seq_len=15 # number of images in each tracklet 41 | ) 42 | model = torchreid.models.build_model( 43 | name='resnet50', 44 | num_classes=datamanager.num_train_pids, 45 | loss='softmax' 46 | ) 47 | model = model.cuda() 48 | optimizer = torchreid.optim.build_optimizer( 49 | model, optim='adam', lr=0.0003 50 | ) 51 | scheduler = torchreid.optim.build_lr_scheduler( 52 | optimizer, 53 | lr_scheduler='single_step', 54 | stepsize=20 55 | ) 56 | engine = torchreid.engine.VideoSoftmaxEngine( 57 | datamanager, model, optimizer, scheduler=scheduler, 58 | pooling_method='avg' 59 | ) 60 | engine.run( 61 | max_epoch=60, 62 | save_dir='log/resnet50-softmax-mars', 63 | print_freq=10 64 | ) 65 | """ 66 | 67 | def __init__(self, datamanager, model, optimizer, scheduler=None, 68 | use_gpu=True, label_smooth=True, pooling_method='avg'): 69 | super(VideoSoftmaxEngine, self).__init__(datamanager, model, optimizer, scheduler=scheduler, 70 | use_gpu=use_gpu, label_smooth=label_smooth) 71 | self.pooling_method = pooling_method 72 | 73 | def _parse_data_for_train(self, data): 74 | imgs = data[0] 75 | pids = data[1] 76 | if imgs.dim() == 5: 77 | # b: batch size 78 | # s: sqeuence length 79 | # c: channel depth 80 | # h: height 81 | # w: width 82 | b, s, c, h, w = imgs.size() 83 | imgs = imgs.view(b*s, c, h, w) 84 | pids = pids.view(b, 1).expand(b, s) 85 | pids = pids.contiguous().view(b*s) 86 | return imgs, pids 87 | 88 | def _extract_features(self, input): 89 | self.model.eval() 90 | # b: batch size 91 | # s: sqeuence length 92 | # c: channel depth 93 | # h: height 94 | # w: width 95 | b, s, c, h, w = input.size() 96 | input = input.view(b*s, c, h, w) 97 | features = self.model(input) 98 | features = features.view(b, s, -1) 99 | if self.pooling_method == 'avg': 100 | features = torch.mean(features, 1) 101 | else: 102 | features = torch.max(features, 1)[0] 103 | return features -------------------------------------------------------------------------------- /torchreid/engine/video/triplet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | import time 6 | import datetime 7 | 8 | import torch 9 | 10 | import torchreid 11 | from torchreid.engine.image import ImageTripletEngine 12 | from torchreid.engine.video import VideoSoftmaxEngine 13 | 14 | 15 | class VideoTripletEngine(ImageTripletEngine, VideoSoftmaxEngine): 16 | """Triplet-loss engine for video-reid. 17 | 18 | Args: 19 | datamanager (DataManager): an instance of ``torchreid.data.ImageDataManager`` 20 | or ``torchreid.data.VideoDataManager``. 21 | model (nn.Module): model instance. 22 | optimizer (Optimizer): an Optimizer. 23 | margin (float, optional): margin for triplet loss. Default is 0.3. 24 | weight_t (float, optional): weight for triplet loss. Default is 1. 25 | weight_x (float, optional): weight for softmax loss. Default is 1. 26 | scheduler (LRScheduler, optional): if None, no learning rate decay will be performed. 27 | use_gpu (bool, optional): use gpu. Default is True. 28 | label_smooth (bool, optional): use label smoothing regularizer. Default is True. 29 | pooling_method (str, optional): how to pool features for a tracklet. 30 | Default is "avg" (average). Choices are ["avg", "max"]. 31 | 32 | Examples:: 33 | 34 | import torch 35 | import torchreid 36 | # Each batch contains batch_size*seq_len images 37 | # Each identity is sampled with num_instances tracklets 38 | datamanager = torchreid.data.VideoDataManager( 39 | root='path/to/reid-data', 40 | sources='mars', 41 | height=256, 42 | width=128, 43 | combineall=False, 44 | num_instances=4, 45 | train_sampler='RandomIdentitySampler' 46 | batch_size=8, # number of tracklets 47 | seq_len=15 # number of images in each tracklet 48 | ) 49 | model = torchreid.models.build_model( 50 | name='resnet50', 51 | num_classes=datamanager.num_train_pids, 52 | loss='triplet' 53 | ) 54 | model = model.cuda() 55 | optimizer = torchreid.optim.build_optimizer( 56 | model, optim='adam', lr=0.0003 57 | ) 58 | scheduler = torchreid.optim.build_lr_scheduler( 59 | optimizer, 60 | lr_scheduler='single_step', 61 | stepsize=20 62 | ) 63 | engine = torchreid.engine.VideoTripletEngine( 64 | datamanager, model, optimizer, margin=0.3, 65 | weight_t=0.7, weight_x=1, scheduler=scheduler, 66 | pooling_method='avg' 67 | ) 68 | engine.run( 69 | max_epoch=60, 70 | save_dir='log/resnet50-triplet-mars', 71 | print_freq=10 72 | ) 73 | """ 74 | 75 | def __init__(self, datamanager, model, optimizer, margin=0.3, 76 | weight_t=1, weight_x=1, scheduler=None, use_gpu=False, 77 | label_smooth=True, pooling_method='avg'): 78 | super(VideoTripletEngine, self).__init__(datamanager, model, optimizer, margin=margin, 79 | weight_t=weight_t, weight_x=weight_x, 80 | scheduler=scheduler, use_gpu=use_gpu, 81 | label_smooth=label_smooth) 82 | self.pooling_method = pooling_method 83 | -------------------------------------------------------------------------------- /torchreid/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from .cross_entropy_loss import CrossEntropyLoss 6 | from .hard_mine_triplet_loss import TripletLoss 7 | from .center_loss import CenterLoss 8 | from .ranked_loss import RankedLoss 9 | 10 | def DeepSupervision(criterion, xs, y): 11 | """DeepSupervision 12 | 13 | Applies criterion to each element in a list. 14 | 15 | Args: 16 | criterion: loss function 17 | xs: tuple of inputs 18 | y: ground truth 19 | """ 20 | loss = 0. 21 | for x in xs: 22 | loss += criterion(x, y) 23 | loss /= len(xs) 24 | return loss -------------------------------------------------------------------------------- /torchreid/losses/center_loss.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | 4 | import warnings 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | 10 | class CenterLoss(nn.Module): 11 | """Center loss. 12 | 13 | Reference: 14 | Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016. 15 | 16 | Args: 17 | - num_classes (int): number of classes. 18 | - feat_dim (int): feature dimension. 19 | """ 20 | def __init__(self, num_classes=751, feat_dim=2048, use_gpu=True): 21 | super(CenterLoss, self).__init__() 22 | warnings.warn("This method is deprecated") 23 | self.num_classes = num_classes 24 | self.feat_dim = feat_dim 25 | self.use_gpu = use_gpu 26 | 27 | if self.use_gpu: 28 | self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim).cuda()) 29 | else: 30 | self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim)) 31 | 32 | def forward(self, x, labels): 33 | """ 34 | Args: 35 | - x: feature matrix with shape (batch_size, feat_dim). 36 | - labels: ground truth labels with shape (num_classes). 37 | """ 38 | batch_size = x.size(0) 39 | distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \ 40 | torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t() 41 | distmat.addmm_(1, -2, x, self.centers.t()) 42 | 43 | classes = torch.arange(self.num_classes).long() 44 | if self.use_gpu: classes = classes.cuda() 45 | labels = labels.unsqueeze(1).expand(batch_size, self.num_classes) 46 | mask = labels.eq(classes.expand(batch_size, self.num_classes)) 47 | 48 | dist = [] 49 | for i in range(batch_size): 50 | value = distmat[i][mask[i]] 51 | value = value.clamp(min=1e-12, max=1e+12) # for numerical stability 52 | dist.append(value) 53 | dist = torch.cat(dist) 54 | loss = dist.mean() 55 | 56 | return loss 57 | -------------------------------------------------------------------------------- /torchreid/losses/cross_entropy_loss.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | class CrossEntropyLoss(nn.Module): 9 | r"""Cross entropy loss with label smoothing regularizer. 10 | 11 | Reference: 12 | Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016. 13 | 14 | With label smoothing, the label :math:`y` for a class is computed by 15 | 16 | .. math:: 17 | \begin{equation} 18 | (1 - \epsilon) \times y + \frac{\epsilon}{K}, 19 | \end{equation} 20 | 21 | where :math:`K` denotes the number of classes and :math:`\epsilon` is a weight. When 22 | :math:`\epsilon = 0`, the loss function reduces to the normal cross entropy. 23 | 24 | Args: 25 | num_classes (int): number of classes. 26 | epsilon (float, optional): weight. Default is 0.1. 27 | use_gpu (bool, optional): whether to use gpu devices. Default is True. 28 | label_smooth (bool, optional): whether to apply label smoothing. Default is True. 29 | """ 30 | 31 | def __init__(self, num_classes, epsilon=0.1, use_gpu=True, label_smooth=True): 32 | super(CrossEntropyLoss, self).__init__() 33 | self.num_classes = num_classes 34 | self.epsilon = epsilon if label_smooth else 0 35 | self.use_gpu = use_gpu 36 | self.logsoftmax = nn.LogSoftmax(dim=1) 37 | 38 | def forward(self, inputs, targets): 39 | """ 40 | Args: 41 | inputs (torch.Tensor): prediction matrix (before softmax) with 42 | shape (batch_size, num_classes). 43 | targets (torch.LongTensor): ground truth labels with shape (batch_size). 44 | Each position contains the label index. 45 | """ 46 | log_probs = self.logsoftmax(inputs) 47 | targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).data.cpu(), 1) 48 | if self.use_gpu: targets = targets.cuda() 49 | targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes 50 | return (- targets * log_probs).mean(0).sum() -------------------------------------------------------------------------------- /torchreid/losses/hard_mine_triplet_loss.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | class TripletLoss(nn.Module): 9 | """Triplet loss with hard positive/negative mining. 10 | 11 | Reference: 12 | Hermans et al. In Defense of the Triplet Loss for Person Re-Identification. arXiv:1703.07737. 13 | 14 | Imported from ``_. 15 | 16 | Args: 17 | margin (float, optional): margin for triplet. Default is 0.3. 18 | """ 19 | 20 | def __init__(self, margin=0.3): 21 | super(TripletLoss, self).__init__() 22 | self.margin = margin 23 | if margin == 0.: 24 | self.ranking_loss = nn.SoftMarginLoss() 25 | else: 26 | self.ranking_loss = nn.MarginRankingLoss(margin=margin) 27 | 28 | def forward(self, inputs, targets): 29 | """ 30 | Args: 31 | inputs (torch.Tensor): feature matrix with shape (batch_size, feat_dim). 32 | targets (torch.LongTensor): ground truth labels with shape (num_classes). 33 | """ 34 | n = inputs.size(0) 35 | 36 | # Compute pairwise distance, replace by the official when merged 37 | dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n) 38 | dist = dist + dist.t() 39 | dist.addmm_(1, -2, inputs, inputs.t()) 40 | dist = dist.clamp(min=1e-12).sqrt() # for numerical stability 41 | 42 | # For each anchor, find the hardest positive and negative 43 | mask = targets.expand(n, n).eq(targets.expand(n, n).t()) 44 | dist_ap, dist_an = [], [] 45 | for i in range(n): 46 | dist_ap.append(dist[i][mask[i]].max().unsqueeze(0)) 47 | dist_an.append(dist[i][mask[i] == 0].min().unsqueeze(0)) 48 | dist_ap = torch.cat(dist_ap) 49 | dist_an = torch.cat(dist_an) 50 | 51 | # Compute ranking hinge loss 52 | y = torch.ones_like(dist_an) 53 | if self.margin == 0.: 54 | loss = self.ranking_loss(dist_an - dist_ap, y) 55 | else: 56 | loss = self.ranking_loss(dist_an, dist_ap, y) 57 | return loss 58 | -------------------------------------------------------------------------------- /torchreid/losses/ranked_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | def normalize_rank(x, axis=-1): 5 | """Normalizing to unit length along the specified dimension. 6 | Args: 7 | x: pytorch Variable 8 | Returns: 9 | x: pytorch Variable, same shape as input 10 | """ 11 | x = 1. * x / (torch.norm(x, 2, axis, keepdim=True).expand_as(x) + 1e-12) 12 | return x 13 | 14 | def euclidean_dist_rank(x, y): 15 | """ 16 | Args: 17 | x: pytorch Variable, with shape [m, d] 18 | y: pytorch Variable, with shape [n, d] 19 | Returns: 20 | dist: pytorch Variable, with shape [m, n] 21 | """ 22 | m, n = x.size(0), y.size(0) 23 | xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n) 24 | yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t() 25 | dist = xx + yy 26 | dist.addmm_(1, -2, x, y.t()) 27 | dist = dist.clamp(min=1e-12).sqrt() # for numerical stability 28 | return dist 29 | 30 | def rank_loss(dist_mat, labels, margin, alpha, tval): 31 | """ 32 | Args: 33 | dist_mat: pytorch Variable, pair wise distance between samples, shape [N, N] 34 | labels: pytorch LongTensor, with shape [N] 35 | 36 | """ 37 | assert len(dist_mat.size()) == 2 38 | assert dist_mat.size(0) == dist_mat.size(1) 39 | N = dist_mat.size(0) 40 | 41 | total_loss = 0.0 42 | for ind in range(N): 43 | is_pos = labels.eq(labels[ind]) 44 | is_pos[ind] = 0 45 | is_neg = labels.ne(labels[ind]) 46 | 47 | dist_ap = dist_mat[ind][is_pos] 48 | dist_an = dist_mat[ind][is_neg] 49 | 50 | ap_is_pos = torch.clamp(torch.add(dist_ap,margin-alpha),min=0.0) 51 | ap_pos_num = ap_is_pos.size(0) +1e-5 52 | ap_pos_val_sum = torch.sum(ap_is_pos) 53 | loss_ap = torch.div(ap_pos_val_sum,float(ap_pos_num)) 54 | 55 | an_is_pos = torch.lt(dist_an,alpha) 56 | an_less_alpha = dist_an[an_is_pos] 57 | an_weight = torch.exp(tval*(-1*an_less_alpha+alpha)) 58 | an_weight_sum = torch.sum(an_weight) +1e-5 59 | an_dist_lm = alpha - an_less_alpha 60 | an_ln_sum = torch.sum(torch.mul(an_dist_lm,an_weight)) 61 | loss_an = torch.div(an_ln_sum,an_weight_sum) 62 | 63 | total_loss = total_loss+loss_ap+loss_an 64 | total_loss = total_loss*1.0/N 65 | return total_loss 66 | 67 | class RankedLoss(object): 68 | "Ranked_List_Loss_for_Deep_Metric_Learning_CVPR_2019_paper" 69 | 70 | def __init__(self, margin=1.3, alpha=2.0, tval=1.0): 71 | self.margin = margin 72 | self.alpha = alpha 73 | self.tval = tval 74 | 75 | def __call__(self, global_feat, labels, normalize_feature=False): 76 | if normalize_feature: 77 | global_feat = normalize_rank(global_feat, axis=-1) 78 | dist_mat = euclidean_dist_rank(global_feat, global_feat) 79 | total_loss = rank_loss(dist_mat,labels,self.margin,self.alpha,self.tval) 80 | 81 | return total_loss 82 | -------------------------------------------------------------------------------- /torchreid/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .accuracy import accuracy 4 | from .rank import evaluate_rank 5 | from .distance import compute_distance_matrix -------------------------------------------------------------------------------- /torchreid/metrics/accuracy.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | 6 | def accuracy(output, target, topk=(1,)): 7 | """Computes the accuracy over the k top predictions for 8 | the specified values of k. 9 | 10 | Args: 11 | output (torch.Tensor): prediction matrix with shape (batch_size, num_classes). 12 | target (torch.LongTensor): ground truth labels with shape (batch_size). 13 | topk (tuple, optional): accuracy at top-k will be computed. For example, 14 | topk=(1, 5) means accuracy at top-1 and top-5 will be computed. 15 | 16 | Returns: 17 | list: accuracy at top-k. 18 | 19 | Examples:: 20 | >>> from torchreid import metrics 21 | >>> metrics.accuracy(output, target) 22 | """ 23 | maxk = max(topk) 24 | batch_size = target.size(0) 25 | 26 | if isinstance(output, (tuple, list)): 27 | output = output[0] 28 | 29 | _, pred = output.topk(maxk, 1, True, True) 30 | pred = pred.t() 31 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 32 | 33 | res = [] 34 | for k in topk: 35 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 36 | acc = correct_k.mul_(100.0 / batch_size) 37 | res.append(acc) 38 | 39 | return res -------------------------------------------------------------------------------- /torchreid/metrics/distance.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | import numpy as np 6 | 7 | import torch 8 | from torch.nn import functional as F 9 | 10 | 11 | def compute_distance_matrix(input1, input2, metric='euclidean'): 12 | """A wrapper function for computing distance matrix. 13 | 14 | Args: 15 | input1 (torch.Tensor): 2-D feature matrix. 16 | input2 (torch.Tensor): 2-D feature matrix. 17 | metric (str, optional): "euclidean" or "cosine". 18 | Default is "euclidean". 19 | 20 | Returns: 21 | torch.Tensor: distance matrix. 22 | 23 | Examples:: 24 | >>> from torchreid import metrics 25 | >>> input1 = torch.rand(10, 2048) 26 | >>> input2 = torch.rand(100, 2048) 27 | >>> distmat = metrics.compute_distance_matrix(input1, input2) 28 | >>> distmat.size() # (10, 100) 29 | """ 30 | # check input 31 | assert isinstance(input1, torch.Tensor) 32 | assert isinstance(input2, torch.Tensor) 33 | assert input1.dim() == 2, 'Expected 2-D tensor, but got {}-D'.format(input1.dim()) 34 | assert input2.dim() == 2, 'Expected 2-D tensor, but got {}-D'.format(input2.dim()) 35 | assert input1.size(1) == input2.size(1) 36 | 37 | if metric == 'euclidean': 38 | distmat = euclidean_squared_distance(input1, input2) 39 | elif metric == 'cosine': 40 | distmat = cosine_distance(input1, input2) 41 | else: 42 | raise ValueError( 43 | 'Unknown distance metric: {}. ' 44 | 'Please choose either "euclidean" or "cosine"'.format(metric) 45 | ) 46 | 47 | return distmat 48 | 49 | 50 | def euclidean_squared_distance(input1, input2): 51 | """Computes euclidean squared distance. 52 | 53 | Args: 54 | input1 (torch.Tensor): 2-D feature matrix. 55 | input2 (torch.Tensor): 2-D feature matrix. 56 | 57 | Returns: 58 | torch.Tensor: distance matrix. 59 | """ 60 | m, n = input1.size(0), input2.size(0) 61 | distmat = torch.pow(input1, 2).sum(dim=1, keepdim=True).expand(m, n) + \ 62 | torch.pow(input2, 2).sum(dim=1, keepdim=True).expand(n, m).t() 63 | distmat.addmm_(1, -2, input1, input2.t()) 64 | return distmat 65 | 66 | 67 | def cosine_distance(input1, input2): 68 | """Computes cosine distance. 69 | 70 | Args: 71 | input1 (torch.Tensor): 2-D feature matrix. 72 | input2 (torch.Tensor): 2-D feature matrix. 73 | 74 | Returns: 75 | torch.Tensor: distance matrix. 76 | """ 77 | input1_normed = F.normalize(input1, p=2, dim=1) 78 | input2_normed = F.normalize(input2, p=2, dim=1) 79 | distmat = 1 - torch.mm(input1_normed, input2_normed.t()) 80 | return distmat -------------------------------------------------------------------------------- /torchreid/metrics/rank.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | import numpy as np 6 | import copy 7 | from collections import defaultdict 8 | import sys 9 | import warnings 10 | 11 | try: 12 | from torchreid.metrics.rank_cylib.rank_cy import evaluate_cy 13 | IS_CYTHON_AVAI = True 14 | except ImportError: 15 | IS_CYTHON_AVAI = False 16 | warnings.warn( 17 | 'Cython evaluation (very fast so highly recommended) is ' 18 | 'unavailable, now use python evaluation.' 19 | ) 20 | 21 | 22 | def eval_cuhk03(distmat, q_pids, g_pids, q_camids, g_camids, max_rank): 23 | """Evaluation with cuhk03 metric 24 | Key: one image for each gallery identity is randomly sampled for each query identity. 25 | Random sampling is performed num_repeats times. 26 | """ 27 | num_repeats = 10 28 | num_q, num_g = distmat.shape 29 | 30 | if num_g < max_rank: 31 | max_rank = num_g 32 | print('Note: number of gallery samples is quite small, got {}'.format(num_g)) 33 | 34 | indices = np.argsort(distmat, axis=1) 35 | matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32) 36 | 37 | # compute cmc curve for each query 38 | all_cmc = [] 39 | all_AP = [] 40 | num_valid_q = 0. # number of valid query 41 | 42 | for q_idx in range(num_q): 43 | # get query pid and camid 44 | q_pid = q_pids[q_idx] 45 | q_camid = q_camids[q_idx] 46 | 47 | # remove gallery samples that have the same pid and camid with query 48 | order = indices[q_idx] 49 | remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid) 50 | keep = np.invert(remove) 51 | 52 | # compute cmc curve 53 | raw_cmc = matches[q_idx][keep] # binary vector, positions with value 1 are correct matches 54 | if not np.any(raw_cmc): 55 | # this condition is true when query identity does not appear in gallery 56 | continue 57 | 58 | kept_g_pids = g_pids[order][keep] 59 | g_pids_dict = defaultdict(list) 60 | for idx, pid in enumerate(kept_g_pids): 61 | g_pids_dict[pid].append(idx) 62 | 63 | cmc = 0. 64 | for repeat_idx in range(num_repeats): 65 | mask = np.zeros(len(raw_cmc), dtype=np.bool) 66 | for _, idxs in g_pids_dict.items(): 67 | # randomly sample one image for each gallery person 68 | rnd_idx = np.random.choice(idxs) 69 | mask[rnd_idx] = True 70 | masked_raw_cmc = raw_cmc[mask] 71 | _cmc = masked_raw_cmc.cumsum() 72 | _cmc[_cmc > 1] = 1 73 | cmc += _cmc[:max_rank].astype(np.float32) 74 | 75 | cmc /= num_repeats 76 | all_cmc.append(cmc) 77 | # compute AP 78 | num_rel = raw_cmc.sum() 79 | tmp_cmc = raw_cmc.cumsum() 80 | tmp_cmc = [x / (i+1.) for i, x in enumerate(tmp_cmc)] 81 | tmp_cmc = np.asarray(tmp_cmc) * raw_cmc 82 | AP = tmp_cmc.sum() / num_rel 83 | all_AP.append(AP) 84 | num_valid_q += 1. 85 | 86 | assert num_valid_q > 0, 'Error: all query identities do not appear in gallery' 87 | 88 | all_cmc = np.asarray(all_cmc).astype(np.float32) 89 | all_cmc = all_cmc.sum(0) / num_valid_q 90 | mAP = np.mean(all_AP) 91 | 92 | return all_cmc, mAP 93 | 94 | 95 | def eval_market1501(distmat, q_pids, g_pids, q_camids, g_camids, max_rank): 96 | """Evaluation with market1501 metric 97 | Key: for each query identity, its gallery images from the same camera view are discarded. 98 | """ 99 | num_q, num_g = distmat.shape 100 | 101 | if num_g < max_rank: 102 | max_rank = num_g 103 | print('Note: number of gallery samples is quite small, got {}'.format(num_g)) 104 | 105 | indices = np.argsort(distmat, axis=1) 106 | matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32) 107 | 108 | # compute cmc curve for each query 109 | all_cmc = [] 110 | all_AP = [] 111 | num_valid_q = 0. # number of valid query 112 | 113 | for q_idx in range(num_q): 114 | # get query pid and camid 115 | q_pid = q_pids[q_idx] 116 | q_camid = q_camids[q_idx] 117 | 118 | # remove gallery samples that have the same pid and camid with query 119 | order = indices[q_idx] 120 | remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid) 121 | keep = np.invert(remove) 122 | 123 | # compute cmc curve 124 | raw_cmc = matches[q_idx][keep] # binary vector, positions with value 1 are correct matches 125 | if not np.any(raw_cmc): 126 | # this condition is true when query identity does not appear in gallery 127 | continue 128 | 129 | cmc = raw_cmc.cumsum() 130 | cmc[cmc > 1] = 1 131 | 132 | all_cmc.append(cmc[:max_rank]) 133 | num_valid_q += 1. 134 | 135 | # compute average precision 136 | # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision 137 | num_rel = raw_cmc.sum() 138 | tmp_cmc = raw_cmc.cumsum() 139 | tmp_cmc = [x / (i+1.) for i, x in enumerate(tmp_cmc)] 140 | tmp_cmc = np.asarray(tmp_cmc) * raw_cmc 141 | AP = tmp_cmc.sum() / num_rel 142 | all_AP.append(AP) 143 | 144 | assert num_valid_q > 0, 'Error: all query identities do not appear in gallery' 145 | 146 | all_cmc = np.asarray(all_cmc).astype(np.float32) 147 | all_cmc = all_cmc.sum(0) / num_valid_q 148 | mAP = np.mean(all_AP) 149 | 150 | return all_cmc, mAP 151 | 152 | 153 | def evaluate_py(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_metric_cuhk03): 154 | if use_metric_cuhk03: 155 | return eval_cuhk03(distmat, q_pids, g_pids, q_camids, g_camids, max_rank) 156 | else: 157 | return eval_market1501(distmat, q_pids, g_pids, q_camids, g_camids, max_rank) 158 | 159 | 160 | def evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=50, 161 | use_metric_cuhk03=False, use_cython=True): 162 | """Evaluates CMC rank. 163 | 164 | Args: 165 | distmat (numpy.ndarray): distance matrix of shape (num_query, num_gallery). 166 | q_pids (numpy.ndarray): 1-D array containing person identities 167 | of each query instance. 168 | g_pids (numpy.ndarray): 1-D array containing person identities 169 | of each gallery instance. 170 | q_camids (numpy.ndarray): 1-D array containing camera views under 171 | which each query instance is captured. 172 | g_camids (numpy.ndarray): 1-D array containing camera views under 173 | which each gallery instance is captured. 174 | max_rank (int, optional): maximum CMC rank to be computed. Default is 50. 175 | use_metric_cuhk03 (bool, optional): use single-gallery-shot setting for cuhk03. 176 | Default is False. This should be enabled when using cuhk03 classic split. 177 | use_cython (bool, optional): use cython code for evaluation. Default is True. 178 | This is highly recommended as the cython code can speed up the cmc computation 179 | by more than 10x. This requires Cython to be installed. 180 | """ 181 | if use_cython and IS_CYTHON_AVAI: 182 | return evaluate_cy(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_metric_cuhk03) 183 | else: 184 | return evaluate_py(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_metric_cuhk03) -------------------------------------------------------------------------------- /torchreid/metrics/rank_cylib/Makefile: -------------------------------------------------------------------------------- 1 | all: 2 | $(PYTHON) setup.py build_ext --inplace 3 | rm -rf build 4 | clean: 5 | rm -rf build 6 | rm -f rank_cy.c *.so -------------------------------------------------------------------------------- /torchreid/metrics/rank_cylib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anocodetest1/FPB/297415164c3ab648c5ddd0f0595733657700c612/torchreid/metrics/rank_cylib/__init__.py -------------------------------------------------------------------------------- /torchreid/metrics/rank_cylib/rank_cy.pyx: -------------------------------------------------------------------------------- 1 | # cython: boundscheck=False, wraparound=False, nonecheck=False, cdivision=True 2 | 3 | from __future__ import print_function 4 | 5 | import cython 6 | import numpy as np 7 | cimport numpy as np 8 | from collections import defaultdict 9 | import random 10 | 11 | 12 | """ 13 | Compiler directives: 14 | https://github.com/cython/cython/wiki/enhancements-compilerdirectives 15 | 16 | Cython tutorial: 17 | https://cython.readthedocs.io/en/latest/src/userguide/numpy_tutorial.html 18 | 19 | Credit to https://github.com/luzai 20 | """ 21 | 22 | 23 | # Main interface 24 | cpdef evaluate_cy(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_metric_cuhk03=False): 25 | distmat = np.asarray(distmat, dtype=np.float32) 26 | q_pids = np.asarray(q_pids, dtype=np.int64) 27 | g_pids = np.asarray(g_pids, dtype=np.int64) 28 | q_camids = np.asarray(q_camids, dtype=np.int64) 29 | g_camids = np.asarray(g_camids, dtype=np.int64) 30 | if use_metric_cuhk03: 31 | return eval_cuhk03_cy(distmat, q_pids, g_pids, q_camids, g_camids, max_rank) 32 | return eval_market1501_cy(distmat, q_pids, g_pids, q_camids, g_camids, max_rank) 33 | 34 | 35 | cpdef eval_cuhk03_cy(float[:,:] distmat, long[:] q_pids, long[:]g_pids, 36 | long[:]q_camids, long[:]g_camids, long max_rank): 37 | 38 | cdef long num_q = distmat.shape[0] 39 | cdef long num_g = distmat.shape[1] 40 | 41 | if num_g < max_rank: 42 | max_rank = num_g 43 | print('Note: number of gallery samples is quite small, got {}'.format(num_g)) 44 | 45 | cdef: 46 | long num_repeats = 10 47 | long[:,:] indices = np.argsort(distmat, axis=1) 48 | long[:,:] matches = (np.asarray(g_pids)[np.asarray(indices)] == np.asarray(q_pids)[:, np.newaxis]).astype(np.int64) 49 | 50 | float[:,:] all_cmc = np.zeros((num_q, max_rank), dtype=np.float32) 51 | float[:] all_AP = np.zeros(num_q, dtype=np.float32) 52 | float num_valid_q = 0. # number of valid query 53 | 54 | long q_idx, q_pid, q_camid, g_idx 55 | long[:] order = np.zeros(num_g, dtype=np.int64) 56 | long keep 57 | 58 | float[:] raw_cmc = np.zeros(num_g, dtype=np.float32) # binary vector, positions with value 1 are correct matches 59 | float[:] masked_raw_cmc = np.zeros(num_g, dtype=np.float32) 60 | float[:] cmc, masked_cmc 61 | long num_g_real, num_g_real_masked, rank_idx, rnd_idx 62 | unsigned long meet_condition 63 | float AP 64 | long[:] kept_g_pids, mask 65 | 66 | float num_rel 67 | float[:] tmp_cmc = np.zeros(num_g, dtype=np.float32) 68 | float tmp_cmc_sum 69 | 70 | for q_idx in range(num_q): 71 | # get query pid and camid 72 | q_pid = q_pids[q_idx] 73 | q_camid = q_camids[q_idx] 74 | 75 | # remove gallery samples that have the same pid and camid with query 76 | for g_idx in range(num_g): 77 | order[g_idx] = indices[q_idx, g_idx] 78 | num_g_real = 0 79 | meet_condition = 0 80 | kept_g_pids = np.zeros(num_g, dtype=np.int64) 81 | 82 | for g_idx in range(num_g): 83 | if (g_pids[order[g_idx]] != q_pid) or (g_camids[order[g_idx]] != q_camid): 84 | raw_cmc[num_g_real] = matches[q_idx][g_idx] 85 | kept_g_pids[num_g_real] = g_pids[order[g_idx]] 86 | num_g_real += 1 87 | if matches[q_idx][g_idx] > 1e-31: 88 | meet_condition = 1 89 | 90 | if not meet_condition: 91 | # this condition is true when query identity does not appear in gallery 92 | continue 93 | 94 | # cuhk03-specific setting 95 | g_pids_dict = defaultdict(list) # overhead! 96 | for g_idx in range(num_g_real): 97 | g_pids_dict[kept_g_pids[g_idx]].append(g_idx) 98 | 99 | cmc = np.zeros(max_rank, dtype=np.float32) 100 | for _ in range(num_repeats): 101 | mask = np.zeros(num_g_real, dtype=np.int64) 102 | 103 | for _, idxs in g_pids_dict.items(): 104 | # randomly sample one image for each gallery person 105 | rnd_idx = np.random.choice(idxs) 106 | #rnd_idx = idxs[0] # use deterministic for debugging 107 | mask[rnd_idx] = 1 108 | 109 | num_g_real_masked = 0 110 | for g_idx in range(num_g_real): 111 | if mask[g_idx] == 1: 112 | masked_raw_cmc[num_g_real_masked] = raw_cmc[g_idx] 113 | num_g_real_masked += 1 114 | 115 | masked_cmc = np.zeros(num_g, dtype=np.float32) 116 | function_cumsum(masked_raw_cmc, masked_cmc, num_g_real_masked) 117 | for g_idx in range(num_g_real_masked): 118 | if masked_cmc[g_idx] > 1: 119 | masked_cmc[g_idx] = 1 120 | 121 | for rank_idx in range(max_rank): 122 | cmc[rank_idx] += masked_cmc[rank_idx] / num_repeats 123 | 124 | for rank_idx in range(max_rank): 125 | all_cmc[q_idx, rank_idx] = cmc[rank_idx] 126 | # compute average precision 127 | # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision 128 | function_cumsum(raw_cmc, tmp_cmc, num_g_real) 129 | num_rel = 0 130 | tmp_cmc_sum = 0 131 | for g_idx in range(num_g_real): 132 | tmp_cmc_sum += (tmp_cmc[g_idx] / (g_idx + 1.)) * raw_cmc[g_idx] 133 | num_rel += raw_cmc[g_idx] 134 | all_AP[q_idx] = tmp_cmc_sum / num_rel 135 | num_valid_q += 1. 136 | 137 | assert num_valid_q > 0, 'Error: all query identities do not appear in gallery' 138 | 139 | # compute averaged cmc 140 | cdef float[:] avg_cmc = np.zeros(max_rank, dtype=np.float32) 141 | for rank_idx in range(max_rank): 142 | for q_idx in range(num_q): 143 | avg_cmc[rank_idx] += all_cmc[q_idx, rank_idx] 144 | avg_cmc[rank_idx] /= num_valid_q 145 | 146 | cdef float mAP = 0 147 | for q_idx in range(num_q): 148 | mAP += all_AP[q_idx] 149 | mAP /= num_valid_q 150 | 151 | return np.asarray(avg_cmc).astype(np.float32), mAP 152 | 153 | 154 | cpdef eval_market1501_cy(float[:,:] distmat, long[:] q_pids, long[:]g_pids, 155 | long[:]q_camids, long[:]g_camids, long max_rank): 156 | 157 | cdef long num_q = distmat.shape[0] 158 | cdef long num_g = distmat.shape[1] 159 | 160 | if num_g < max_rank: 161 | max_rank = num_g 162 | print('Note: number of gallery samples is quite small, got {}'.format(num_g)) 163 | 164 | cdef: 165 | long[:,:] indices = np.argsort(distmat, axis=1) 166 | long[:,:] matches = (np.asarray(g_pids)[np.asarray(indices)] == np.asarray(q_pids)[:, np.newaxis]).astype(np.int64) 167 | 168 | float[:,:] all_cmc = np.zeros((num_q, max_rank), dtype=np.float32) 169 | float[:] all_AP = np.zeros(num_q, dtype=np.float32) 170 | float num_valid_q = 0. # number of valid query 171 | 172 | long q_idx, q_pid, q_camid, g_idx 173 | long[:] order = np.zeros(num_g, dtype=np.int64) 174 | long keep 175 | 176 | float[:] raw_cmc = np.zeros(num_g, dtype=np.float32) # binary vector, positions with value 1 are correct matches 177 | float[:] cmc = np.zeros(num_g, dtype=np.float32) 178 | long num_g_real, rank_idx 179 | unsigned long meet_condition 180 | 181 | float num_rel 182 | float[:] tmp_cmc = np.zeros(num_g, dtype=np.float32) 183 | float tmp_cmc_sum 184 | 185 | for q_idx in range(num_q): 186 | # get query pid and camid 187 | q_pid = q_pids[q_idx] 188 | q_camid = q_camids[q_idx] 189 | 190 | # remove gallery samples that have the same pid and camid with query 191 | for g_idx in range(num_g): 192 | order[g_idx] = indices[q_idx, g_idx] 193 | num_g_real = 0 194 | meet_condition = 0 195 | 196 | for g_idx in range(num_g): 197 | if (g_pids[order[g_idx]] != q_pid) or (g_camids[order[g_idx]] != q_camid): 198 | raw_cmc[num_g_real] = matches[q_idx][g_idx] 199 | num_g_real += 1 200 | if matches[q_idx][g_idx] > 1e-31: 201 | meet_condition = 1 202 | 203 | if not meet_condition: 204 | # this condition is true when query identity does not appear in gallery 205 | continue 206 | 207 | # compute cmc 208 | function_cumsum(raw_cmc, cmc, num_g_real) 209 | for g_idx in range(num_g_real): 210 | if cmc[g_idx] > 1: 211 | cmc[g_idx] = 1 212 | 213 | for rank_idx in range(max_rank): 214 | all_cmc[q_idx, rank_idx] = cmc[rank_idx] 215 | num_valid_q += 1. 216 | 217 | # compute average precision 218 | # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision 219 | function_cumsum(raw_cmc, tmp_cmc, num_g_real) 220 | num_rel = 0 221 | tmp_cmc_sum = 0 222 | for g_idx in range(num_g_real): 223 | tmp_cmc_sum += (tmp_cmc[g_idx] / (g_idx + 1.)) * raw_cmc[g_idx] 224 | num_rel += raw_cmc[g_idx] 225 | all_AP[q_idx] = tmp_cmc_sum / num_rel 226 | 227 | assert num_valid_q > 0, 'Error: all query identities do not appear in gallery' 228 | 229 | # compute averaged cmc 230 | cdef float[:] avg_cmc = np.zeros(max_rank, dtype=np.float32) 231 | for rank_idx in range(max_rank): 232 | for q_idx in range(num_q): 233 | avg_cmc[rank_idx] += all_cmc[q_idx, rank_idx] 234 | avg_cmc[rank_idx] /= num_valid_q 235 | 236 | cdef float mAP = 0 237 | for q_idx in range(num_q): 238 | mAP += all_AP[q_idx] 239 | mAP /= num_valid_q 240 | 241 | return np.asarray(avg_cmc).astype(np.float32), mAP 242 | 243 | 244 | # Compute the cumulative sum 245 | cdef void function_cumsum(cython.numeric[:] src, cython.numeric[:] dst, long n): 246 | cdef long i 247 | dst[0] = src[0] 248 | for i in range(1, n): 249 | dst[i] = src[i] + dst[i - 1] -------------------------------------------------------------------------------- /torchreid/metrics/rank_cylib/setup.py: -------------------------------------------------------------------------------- 1 | from distutils.core import setup 2 | from distutils.extension import Extension 3 | from Cython.Build import cythonize 4 | import numpy as np 5 | 6 | 7 | def numpy_include(): 8 | try: 9 | numpy_include = np.get_include() 10 | except AttributeError: 11 | numpy_include = np.get_numpy_include() 12 | return numpy_include 13 | 14 | 15 | ext_modules = [ 16 | Extension( 17 | 'rank_cy', 18 | ['rank_cy.pyx'], 19 | include_dirs=[numpy_include()], 20 | ) 21 | ] 22 | 23 | setup( 24 | name='Cython-based reid evaluation code', 25 | ext_modules=cythonize(ext_modules) 26 | ) -------------------------------------------------------------------------------- /torchreid/metrics/rank_cylib/test_cython.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import sys 4 | import os.path as osp 5 | import timeit 6 | import numpy as np 7 | 8 | sys.path.insert(0, osp.dirname(osp.abspath(__file__)) + '/../../..') 9 | from torchreid import metrics 10 | 11 | """ 12 | Test the speed of cython-based evaluation code. The speed improvements 13 | can be much bigger when using the real reid data, which contains a larger 14 | amount of query and gallery images. 15 | 16 | Note: you might encounter the following error: 17 | 'AssertionError: Error: all query identities do not appear in gallery'. 18 | This is normal because the inputs are random numbers. Just try again. 19 | """ 20 | 21 | print('*** Compare running time ***') 22 | 23 | setup = ''' 24 | import sys 25 | import os.path as osp 26 | import numpy as np 27 | sys.path.insert(0, osp.dirname(osp.abspath(__file__)) + '/../../..') 28 | from torchreid import metrics 29 | num_q = 30 30 | num_g = 300 31 | max_rank = 5 32 | distmat = np.random.rand(num_q, num_g) * 20 33 | q_pids = np.random.randint(0, num_q, size=num_q) 34 | g_pids = np.random.randint(0, num_g, size=num_g) 35 | q_camids = np.random.randint(0, 5, size=num_q) 36 | g_camids = np.random.randint(0, 5, size=num_g) 37 | ''' 38 | 39 | print('=> Using market1501\'s metric') 40 | pytime = timeit.timeit('metrics.evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_cython=False)', setup=setup, number=20) 41 | cytime = timeit.timeit('metrics.evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_cython=True)', setup=setup, number=20) 42 | print('Python time: {} s'.format(pytime)) 43 | print('Cython time: {} s'.format(cytime)) 44 | print('Cython is {} times faster than python\n'.format(pytime / cytime)) 45 | 46 | print('=> Using cuhk03\'s metric') 47 | pytime = timeit.timeit('metrics.evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_metric_cuhk03=True, use_cython=False)', setup=setup, number=20) 48 | cytime = timeit.timeit('metrics.evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_metric_cuhk03=True, use_cython=True)', setup=setup, number=20) 49 | print('Python time: {} s'.format(pytime)) 50 | print('Cython time: {} s'.format(cytime)) 51 | print('Cython is {} times faster than python\n'.format(pytime / cytime)) 52 | 53 | """ 54 | print("=> Check precision") 55 | 56 | num_q = 30 57 | num_g = 300 58 | max_rank = 5 59 | distmat = np.random.rand(num_q, num_g) * 20 60 | q_pids = np.random.randint(0, num_q, size=num_q) 61 | g_pids = np.random.randint(0, num_g, size=num_g) 62 | q_camids = np.random.randint(0, 5, size=num_q) 63 | g_camids = np.random.randint(0, 5, size=num_g) 64 | 65 | cmc, mAP = evaluate(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_cython=False) 66 | print("Python:\nmAP = {} \ncmc = {}\n".format(mAP, cmc)) 67 | cmc, mAP = evaluate(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_cython=True) 68 | print("Cython:\nmAP = {} \ncmc = {}\n".format(mAP, cmc)) 69 | """ -------------------------------------------------------------------------------- /torchreid/models/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | 5 | from .fpb import * 6 | 7 | 8 | __model_factory = { 9 | # image classification models 10 | # lightweight models 11 | # reid-specific models 12 | 'fpb': fpb 13 | } 14 | 15 | 16 | def show_avai_models(): 17 | """Displays available models. 18 | 19 | Examples:: 20 | >>> from torchreid import models 21 | >>> models.show_avai_models() 22 | """ 23 | print(list(__model_factory.keys())) 24 | 25 | 26 | def build_model(name, num_classes, loss='softmax', pretrained=True, use_gpu=True): 27 | """A function wrapper for building a model. 28 | 29 | Args: 30 | name (str): model name. 31 | num_classes (int): number of training identities. 32 | loss (str, optional): loss function to optimize the model. Currently 33 | supports "softmax" and "triplet". Default is "softmax". 34 | pretrained (bool, optional): whether to load ImageNet-pretrained weights. 35 | Default is True. 36 | use_gpu (bool, optional): whether to use gpu. Default is True. 37 | 38 | Returns: 39 | nn.Module 40 | 41 | Examples:: 42 | >>> from torchreid import models 43 | >>> model = models.build_model('resnet50', 751, loss='softmax') 44 | """ 45 | avai_models = list(__model_factory.keys()) 46 | if name not in avai_models: 47 | raise KeyError('Unknown model: {}. Must be one of {}'.format(name, avai_models)) 48 | return __model_factory[name]( 49 | num_classes=num_classes, 50 | loss=loss, 51 | pretrained=pretrained, 52 | use_gpu=use_gpu 53 | ) 54 | 55 | -------------------------------------------------------------------------------- /torchreid/models/fpb.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | 4 | __all__ = ['fpb'] 5 | 6 | import torch 7 | from torch import nn 8 | from torch.nn import functional as F 9 | import torch.utils.model_zoo as model_zoo 10 | from torchvision.models.resnet import resnet50, resnet101, Bottleneck 11 | from .nn_utils import * 12 | from .pc import * 13 | 14 | class FPNModule(nn.Module): 15 | def __init__(self, num_layers, num_channels): 16 | super(FPNModule, self).__init__() 17 | self.num_layers = num_layers 18 | self.num_channels = num_channels 19 | self.eps = 0.0001 20 | 21 | self.convs = nn.ModuleList() 22 | self.downsamples = nn.ModuleList() 23 | for _ in range(4): 24 | conv = Conv_Bn_Relu(self.num_channels, self.num_channels, k=3, p=1) 25 | self.convs.append(conv) 26 | 27 | for _ in range(2): 28 | downsample = nn.Sequential(nn.Conv2d(self.num_channels, self.num_channels, 1, bias=False), nn.BatchNorm2d(self.num_channels), nn.ReLU(inplace=True)) 29 | self.downsamples.append(downsample) 30 | 31 | self.pc1 = PC_Module(self.num_channels, dropout=True) 32 | 33 | self._init_params() 34 | 35 | def _init_params(self): 36 | for downsample in self.downsamples: 37 | init_struct(downsample) 38 | 39 | return 40 | 41 | def forward(self, x): 42 | 43 | reg_feats = [] 44 | 45 | y = x 46 | x_clone = [] 47 | for t in x: 48 | x_clone.append(t.clone()) 49 | 50 | reg_feats.append(y[0]) 51 | y[0] = self.convs[0](y[0]) 52 | reg_feat = self.pc1(y[1]) 53 | reg_feats.append(reg_feat) 54 | y[1] = self.convs[1](reg_feat+F.interpolate(y[0], scale_factor=2, mode='nearest')) 55 | y[1] = self.convs[2](y[1])+self.downsamples[0](x_clone[1]) 56 | y[0] = self.convs[3](y[0]+F.max_pool2d(y[1], kernel_size=2))+self.downsamples[1](x_clone[0]) 57 | 58 | return y, reg_feats 59 | 60 | 61 | class FPN(nn.Module): 62 | def __init__(self, num_layers, in_channels): 63 | super(FPN, self).__init__() 64 | 65 | self.num_layers = num_layers 66 | self.in_channels = in_channels 67 | self.num_neck_channel = 256 68 | 69 | self.lateral_convs = nn.ModuleList() 70 | for i in range(1, self.num_layers+1): 71 | conv = Conv_Bn_Relu(in_channels[i], self.num_neck_channel, k=1) 72 | self.lateral_convs.append(conv) 73 | 74 | self.fpn_module1 = FPNModule(self.num_layers, self.num_neck_channel) 75 | 76 | self.conv = Conv_Bn_Relu(self.num_neck_channel, self.in_channels[1], k=1, activation_cfg=False) 77 | self.relu = nn.ReLU(inplace=True) 78 | 79 | self._init_params() 80 | 81 | def _init_params(self): 82 | 83 | return 84 | 85 | def forward(self, x): 86 | y = [] 87 | for i in range(self.num_layers): 88 | y.append(self.lateral_convs[i](x[i+1])) 89 | 90 | y, reg_feat = self.fpn_module1(y) 91 | 92 | y = self.conv(y[0])# 93 | y = self.relu(y) 94 | 95 | return y, reg_feat 96 | 97 | 98 | class FPB(nn.Module): 99 | def __init__(self, num_classes, loss=None, **kwargs): 100 | super(FPB, self).__init__() 101 | 102 | resnet_ = resnet50(pretrained=True) 103 | 104 | self.num_parts = 3 105 | self.branch_layers = 2 106 | self.loss = loss 107 | 108 | self.layer0 = nn.Sequential(resnet_.conv1, resnet_.bn1, resnet_.relu, 109 | resnet_.maxpool) 110 | 111 | self.layer1 = resnet_.layer1 112 | self.layer2 = resnet_.layer2 113 | self.pc1 = PC_Module(512, dropout=True) 114 | self.layer3 = resnet_.layer3 115 | 116 | layer4 = nn.Sequential( 117 | Bottleneck(1024, 118 | 512, 119 | downsample=nn.Sequential( 120 | nn.Conv2d(1024, 2048, 1, bias=False), 121 | nn.BatchNorm2d(2048))), Bottleneck(2048, 512), 122 | Bottleneck(2048, 512)) 123 | layer4.load_state_dict(resnet_.layer4.state_dict()) 124 | 125 | self.layer4 = layer4 126 | 127 | self.global_avgpool = nn.AdaptiveAvgPool2d((1, 1)) 128 | self.bn_l4 = nn.BatchNorm2d(2048) 129 | self.classifier_l4 = nn.Linear(2048, num_classes) 130 | 131 | self.in_channels = [2048, 1024, 512, 256] 132 | self.neck = FPN(self.branch_layers, self.in_channels) 133 | 134 | self.part_pools = nn.AdaptiveAvgPool2d((self.num_parts, 1)) 135 | self.dim_reds = DimReduceLayer(self.in_channels[1], 256) 136 | self.classifiers = nn.ModuleList([nn.Linear(256, num_classes) for _ in range(self.num_parts)]) 137 | 138 | self._init_params() 139 | 140 | def _init_params(self): 141 | init_bn(self.bn_l4) 142 | 143 | nn.init.normal_(self.classifier_l4.weight, 0, 0.01) 144 | if self.classifier_l4.bias is not None: 145 | nn.init.constant_(self.classifier_l4.bias, 0) 146 | 147 | init_struct(self.dim_reds) 148 | for c in self.classifiers: 149 | nn.init.normal_(c.weight, 0, 0.01) 150 | if c.bias is not None: 151 | nn.init.constant_(c.bias, 0) 152 | 153 | def featuremaps(self, x): 154 | fs = [] 155 | y_l0 = self.layer0(x) 156 | y_l1 = self.layer1(y_l0) 157 | y_l2 = self.layer2(y_l1) 158 | y_l2_1 = self.pc1(y_l2) 159 | y_l3 = self.layer3(y_l2_1) 160 | y_l4 = self.layer4(y_l3) 161 | 162 | fs.append(y_l4) 163 | fs.append(y_l3) 164 | fs.append(y_l2) 165 | fs.append(y_l1) 166 | 167 | return fs, y_l2_1 168 | 169 | def cross_ofp(self, x): 170 | x[1] = F.max_pool2d(x[1], kernel_size=2) 171 | 172 | y = torch.cat(x, 1) 173 | return y 174 | 175 | 176 | def forward(self, x): 177 | bs = x.size(0) 178 | reg_feat_re = [] 179 | 180 | fs, y_l2_1 = self.featuremaps(x) 181 | 182 | f_branch, reg_feats = self.neck(fs) 183 | 184 | f_l4_train = self.global_avgpool(fs[0]) 185 | 186 | f_parts = self.part_pools(f_branch) 187 | 188 | f_l4 = self.bn_l4(f_l4_train).view(bs, -1) 189 | 190 | if not self.training: 191 | f = [] 192 | f.append(F.normalize(f_l4, p=2, dim=1)) 193 | f.append(F.normalize(f_parts, p=2, dim=1).view(bs, -1)) 194 | 195 | f = torch.cat(f, 1) 196 | 197 | return f 198 | 199 | y = [] 200 | 201 | y_l4 = self.classifier_l4(f_l4) 202 | y.append(y_l4) 203 | 204 | f_short = self.dim_reds(f_parts) 205 | for j in range(self.num_parts): 206 | f_j = f_short[:, :, j, :].view(bs, -1) 207 | y_j = self.classifiers[j](f_j) 208 | 209 | y.append(y_j) 210 | 211 | reg_feat_re.append(self.cross_ofp(reg_feats)) 212 | 213 | if self.loss == 'softmax': 214 | return y 215 | elif self.loss == 'engine_FPB': 216 | f = [] 217 | f.append(F.normalize(f_l4_train, p=2, dim=1).view(bs, -1)) 218 | f.append(F.normalize(f_parts, p=2, dim=1).view(bs, -1)) 219 | 220 | f = torch.cat(f, 1) 221 | 222 | return y, f, reg_feat_re 223 | else: 224 | raise KeyError("Unsupported loss: {}".format(self.loss)) 225 | 226 | 227 | def fpb(num_classes, loss='softmax', pretrained=True, **kwargs): 228 | model = FPB(num_classes=num_classes, loss=loss, **kwargs) 229 | 230 | return model 231 | 232 | 233 | 234 | 235 | 236 | -------------------------------------------------------------------------------- /torchreid/models/nn_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | 4 | import torch 5 | from torch.nn import Softmax 6 | from torch import nn 7 | from torch.nn import functional as F 8 | 9 | 10 | def init_bn(bn_struct): 11 | nn.init.constant_(bn_struct.weight, 1.0) 12 | nn.init.constant_(bn_struct.bias, 0.0) 13 | 14 | 15 | def init_struct(nn_struct): 16 | if not isinstance(nn_struct, nn.Module): 17 | return KeyError("Only nn.Module can be initialized!") 18 | for m in nn_struct.modules(): 19 | if isinstance(m, nn.Conv2d): 20 | nn.init.kaiming_normal_(m.weight, 21 | mode='fan_out', 22 | nonlinearity='relu') 23 | if m.bias is not None: 24 | nn.init.constant_(m.bias, 0) 25 | elif isinstance(m, nn.BatchNorm2d): 26 | nn.init.constant_(m.weight, 1) 27 | nn.init.constant_(m.bias, 0) 28 | elif isinstance(m, nn.BatchNorm1d): 29 | nn.init.constant_(m.weight, 1) 30 | nn.init.constant_(m.bias, 0) 31 | elif isinstance(m, nn.Linear): 32 | nn.init.normal_(m.weight, 0, 0.01) 33 | if m.bias is not None: 34 | nn.init.constant_(m.bias, 0) 35 | 36 | 37 | def init_conv(layer): 38 | if not (isinstance(layer, nn.Conv1d) or isinstance(layer, nn.Conv2d)): 39 | return KeyError("Only nn.Conv1d or nn.Conv2d can be initialized!") 40 | 41 | nn.init.kaiming_normal_(layer.weight, mode="fan_out", nonlinearity='relu') 42 | if layer.bias is not None: 43 | nn.init.constant_(layer.bias, 0) 44 | 45 | 46 | def init_conv2d(layer): 47 | if not isinstance(layer, nn.Conv2d): 48 | return KeyError("Only nn.Conv2d can be initialized!") 49 | 50 | nn.init.kaiming_normal_(layer.weight, mode="fan_out", nonlinearity='relu') 51 | # nn.init.xavier_uniform_(layer.weight, gain=1) #20200421_3 52 | if layer.bias is not None: 53 | nn.init.constant_(layer.bias, 0) 54 | 55 | 56 | def init_fc(layer): 57 | if not isinstance(layer, nn.Linear): 58 | return KeyError("Only nn.Linear can be initialized!") 59 | nn.init.normal_(layer.weight, 0, 0.01) 60 | if m.bias is not None: 61 | nn.init.constant_(layer.bias, 0) 62 | 63 | 64 | class DimReduceLayer(nn.Module): 65 | def __init__(self, in_channels, out_channels): 66 | super(DimReduceLayer, self).__init__() 67 | layers = [] 68 | layers.append( 69 | nn.Conv2d(in_channels, 70 | out_channels, 71 | 1, 72 | stride=1, 73 | padding=0, 74 | bias=False)) 75 | layers.append(nn.BatchNorm2d(out_channels)) 76 | 77 | layers.append(nn.ReLU(inplace=True)) 78 | 79 | self.layers = nn.Sequential(*layers) 80 | 81 | def forward(self, x): 82 | return self.layers(x) 83 | 84 | 85 | class Conv_Bn_Relu(nn.Module): 86 | def __init__(self, in_c, out_c, k, s=1, p=0, norm_cfg=True, activation_cfg=True): 87 | super(Conv_Bn_Relu, self).__init__() 88 | if norm_cfg: 89 | self.conv = nn.Conv2d(in_c, out_c, k, stride=s, padding=p, bias=False) 90 | else: 91 | self.conv = nn.Conv2d(in_c, out_c, k, stride=s, padding=p) 92 | self.bn = nn.BatchNorm2d(out_c) 93 | self.norm_cfg = norm_cfg 94 | self.activation_cfg = activation_cfg 95 | 96 | self.init_params() 97 | 98 | def init_params(self): 99 | init_bn(self.bn) 100 | init_conv2d(self.conv) 101 | 102 | def forward(self, x): 103 | y = self.conv(x) 104 | 105 | if self.norm_cfg: 106 | y = self.bn(y) 107 | 108 | if self.activation_cfg: 109 | y = F.relu(y) 110 | 111 | return y 112 | 113 | 114 | class OFPenalty(nn.Module): 115 | def __init__(self, of_beta): 116 | super(OFPenalty, self).__init__() 117 | 118 | self.beta = of_beta 119 | # self.softmax = Softmax(dim=-1) 120 | 121 | def dominant_eigenvalue(self, A): 122 | B, N, _ = A.size() 123 | x = torch.randn(B, N, 1, device='cuda') 124 | 125 | for _ in range(1): 126 | x = torch.bmm(A, x) 127 | # x: 'B x N x 1' 128 | numerator = torch.bmm( 129 | torch.bmm(A, x).view(B, 1, N), 130 | x 131 | ).squeeze() 132 | denominator = (torch.norm(x.view(B, N), p=2, dim=1) ** 2).squeeze() 133 | 134 | return numerator / denominator 135 | 136 | def get_singular_values(self, A): 137 | AAT = torch.bmm(A, A.permute(0, 2, 1)) # C*S*S*C=C*C 138 | # AAT = self.softmax(AAT) 139 | B, N, _ = AAT.size() 140 | largest = self.dominant_eigenvalue(AAT) 141 | I = torch.eye(N, device='cuda').expand(B, N, N) # noqa 142 | I = I * largest.view(B, 1, 1).repeat(1, N, N) # noqa 143 | tmp = self.dominant_eigenvalue(AAT - I) 144 | return tmp + largest, largest 145 | 146 | def apply_penalty(self, x): 147 | batches, channels, height, width = x.size() 148 | W = x.view(batches, channels, -1) 149 | smallest, largest = self.get_singular_values(W) 150 | singular_penalty = (largest - smallest) * self.beta 151 | 152 | singular_penalty *= 0.01 153 | 154 | return singular_penalty.sum() / (x.size(0) / 32.) # Quirk: normalize to 32-batch case 155 | 156 | def forward(self, reg_feats): 157 | singular_penalty = sum([self.apply_penalty(x) for x in reg_feats]) 158 | 159 | return singular_penalty -------------------------------------------------------------------------------- /torchreid/optim/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .optimizer import build_optimizer 4 | from .lr_scheduler import build_lr_scheduler 5 | from .warm_up import GradualWarmupScheduler 6 | -------------------------------------------------------------------------------- /torchreid/optim/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | 4 | import torch 5 | #from warm_up import GradualWarmupScheduler 6 | from .warm_up import * 7 | 8 | AVAI_SCH = ['single_step', 'multi_step', 'cosine', 'warmup'] 9 | 10 | 11 | def build_lr_scheduler(optimizer, lr_scheduler='single_step', stepsize=1, gamma=0.1, max_epoch=1, warmup_multiplier=100, warmup_total_epoch=9): 12 | """A function wrapper for building a learning rate scheduler. 13 | 14 | Args: 15 | optimizer (Optimizer): an Optimizer. 16 | lr_scheduler (str, optional): learning rate scheduler method. Default is single_step. 17 | stepsize (int or list, optional): step size to decay learning rate. When ``lr_scheduler`` 18 | is "single_step", ``stepsize`` should be an integer. When ``lr_scheduler`` is 19 | "multi_step", ``stepsize`` is a list. Default is 1. 20 | gamma (float, optional): decay rate. Default is 0.1. 21 | max_epoch (int, optional): maximum epoch (for cosine annealing). Default is 1. 22 | 23 | Examples:: 24 | >>> # Decay learning rate by every 20 epochs. 25 | >>> scheduler = torchreid.optim.build_lr_scheduler( 26 | >>> optimizer, lr_scheduler='single_step', stepsize=20 27 | >>> ) 28 | >>> # Decay learning rate at 30, 50 and 55 epochs. 29 | >>> scheduler = torchreid.optim.build_lr_scheduler( 30 | >>> optimizer, lr_scheduler='multi_step', stepsize=[30, 50, 55] 31 | >>> ) 32 | """ 33 | if lr_scheduler not in AVAI_SCH: 34 | raise ValueError('Unsupported scheduler: {}. Must be one of {}'.format(lr_scheduler, AVAI_SCH)) 35 | 36 | if lr_scheduler == 'single_step': 37 | if isinstance(stepsize, list): 38 | stepsize = stepsize[-1] 39 | 40 | if not isinstance(stepsize, int): 41 | raise TypeError( 42 | 'For single_step lr_scheduler, stepsize must ' 43 | 'be an integer, but got {}'.format(type(stepsize)) 44 | ) 45 | 46 | scheduler = torch.optim.lr_scheduler.StepLR( 47 | optimizer, step_size=stepsize, gamma=gamma 48 | ) 49 | 50 | elif lr_scheduler == 'multi_step': 51 | if not isinstance(stepsize, list): 52 | raise TypeError( 53 | 'For multi_step lr_scheduler, stepsize must ' 54 | 'be a list, but got {}'.format(type(stepsize)) 55 | ) 56 | 57 | scheduler = torch.optim.lr_scheduler.MultiStepLR( 58 | optimizer, milestones=stepsize, gamma=gamma 59 | ) 60 | 61 | elif lr_scheduler == 'cosine': 62 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 63 | optimizer, float(max_epoch) 64 | ) 65 | 66 | elif lr_scheduler == 'warmup': 67 | scheduler_multistep = torch.optim.lr_scheduler.MultiStepLR( 68 | optimizer, milestones=stepsize, gamma=gamma 69 | ) 70 | scheduler = GradualWarmupScheduler( 71 | optimizer, multiplier=warmup_multiplier, total_epoch=warmup_total_epoch, after_scheduler=scheduler_multistep 72 | ) 73 | 74 | return scheduler 75 | -------------------------------------------------------------------------------- /torchreid/optim/optimizer.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | 4 | import warnings 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | 10 | AVAI_OPTIMS = ['adam', 'amsgrad', 'sgd', 'rmsprop'] 11 | 12 | 13 | def build_optimizer( 14 | model, 15 | optim='adam', 16 | lr=0.0003, 17 | weight_decay=5e-04, 18 | momentum=0.9, 19 | sgd_dampening=0, 20 | sgd_nesterov=False, 21 | rmsprop_alpha=0.99, 22 | adam_beta1=0.9, 23 | adam_beta2=0.99, 24 | staged_lr=False, 25 | new_layers='', 26 | base_lr_mult=0.1 27 | ): 28 | """A function wrapper for building an optimizer. 29 | 30 | Args: 31 | model (nn.Module): model. 32 | optim (str, optional): optimizer. Default is "adam". 33 | lr (float, optional): learning rate. Default is 0.0003. 34 | weight_decay (float, optional): weight decay (L2 penalty). Default is 5e-04. 35 | momentum (float, optional): momentum factor in sgd. Default is 0.9, 36 | sgd_dampening (float, optional): dampening for momentum. Default is 0. 37 | sgd_nesterov (bool, optional): enables Nesterov momentum. Default is False. 38 | rmsprop_alpha (float, optional): smoothing constant for rmsprop. Default is 0.99. 39 | adam_beta1 (float, optional): beta-1 value in adam. Default is 0.9. 40 | adam_beta2 (float, optional): beta-2 value in adam. Default is 0.99, 41 | staged_lr (bool, optional): uses different learning rates for base and new layers. Base 42 | layers are pretrained layers while new layers are randomly initialized, e.g. the 43 | identity classification layer. Enabling ``staged_lr`` can allow the base layers to 44 | be trained with a smaller learning rate determined by ``base_lr_mult``, while the new 45 | layers will take the ``lr``. Default is False. 46 | new_layers (str or list): attribute names in ``model``. Default is empty. 47 | base_lr_mult (float, optional): learning rate multiplier for base layers. Default is 0.1. 48 | 49 | Examples:: 50 | >>> # A normal optimizer can be built by 51 | >>> optimizer = torchreid.optim.build_optimizer(model, optim='sgd', lr=0.01) 52 | >>> # If you want to use a smaller learning rate for pretrained layers 53 | >>> # and the attribute name for the randomly initialized layer is 'classifier', 54 | >>> # you can do 55 | >>> optimizer = torchreid.optim.build_optimizer( 56 | >>> model, optim='sgd', lr=0.01, staged_lr=True, 57 | >>> new_layers='classifier', base_lr_mult=0.1 58 | >>> ) 59 | >>> # Now the `classifier` has learning rate 0.01 but the base layers 60 | >>> # have learning rate 0.01 * 0.1. 61 | >>> # new_layers can also take multiple attribute names. Say the new layers 62 | >>> # are 'fc' and 'classifier', you can do 63 | >>> optimizer = torchreid.optim.build_optimizer( 64 | >>> model, optim='sgd', lr=0.01, staged_lr=True, 65 | >>> new_layers=['fc', 'classifier'], base_lr_mult=0.1 66 | >>> ) 67 | """ 68 | if optim not in AVAI_OPTIMS: 69 | raise ValueError('Unsupported optim: {}. Must be one of {}'.format(optim, AVAI_OPTIMS)) 70 | 71 | if not isinstance(model, nn.Module): 72 | raise TypeError('model given to build_optimizer must be an instance of nn.Module') 73 | 74 | if staged_lr: 75 | if isinstance(new_layers, str): 76 | if new_layers is None: 77 | warnings.warn('new_layers is empty, therefore, staged_lr is useless') 78 | new_layers = [new_layers] 79 | 80 | if isinstance(model, nn.DataParallel): 81 | model = model.module 82 | 83 | base_params = [] 84 | base_layers = [] 85 | new_params = [] 86 | 87 | for name, module in model.named_children(): 88 | if name in new_layers: 89 | new_params += [p for p in module.parameters()] 90 | else: 91 | base_params += [p for p in module.parameters()] 92 | base_layers.append(name) 93 | 94 | param_groups = [ 95 | {'params': base_params, 'lr': lr * base_lr_mult}, 96 | {'params': new_params}, 97 | ] 98 | 99 | else: 100 | param_groups = model.parameters() 101 | 102 | if optim == 'adam': 103 | optimizer = torch.optim.Adam( 104 | param_groups, 105 | lr=lr, 106 | weight_decay=weight_decay, 107 | betas=(adam_beta1, adam_beta2), 108 | ) 109 | 110 | elif optim == 'amsgrad': 111 | optimizer = torch.optim.Adam( 112 | param_groups, 113 | lr=lr, 114 | weight_decay=weight_decay, 115 | betas=(adam_beta1, adam_beta2), 116 | amsgrad=True, 117 | ) 118 | 119 | elif optim == 'sgd': 120 | optimizer = torch.optim.SGD( 121 | param_groups, 122 | lr=lr, 123 | momentum=momentum, 124 | weight_decay=weight_decay, 125 | dampening=sgd_dampening, 126 | nesterov=sgd_nesterov, 127 | ) 128 | 129 | elif optim == 'rmsprop': 130 | optimizer = torch.optim.RMSprop( 131 | param_groups, 132 | lr=lr, 133 | momentum=momentum, 134 | weight_decay=weight_decay, 135 | alpha=rmsprop_alpha, 136 | ) 137 | 138 | return optimizer -------------------------------------------------------------------------------- /torchreid/optim/warm_up.py: -------------------------------------------------------------------------------- 1 | # from:https://github.com/ildoonet/pytorch-gradual-warmup-lr 2 | from torch.optim.lr_scheduler import _LRScheduler 3 | from torch.optim.lr_scheduler import ReduceLROnPlateau 4 | 5 | class GradualWarmupScheduler(_LRScheduler): 6 | """ Gradually warm-up(increasing) learning rate in optimizer. 7 | Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'. 8 | Args: 9 | optimizer (Optimizer): Wrapped optimizer. 10 | multiplier: target learning rate = base lr * multiplier 11 | total_epoch: target learning rate is reached at total_epoch, gradually 12 | after_scheduler: after target_epoch, use this scheduler(eg. ReduceLROnPlateau) 13 | """ 14 | 15 | def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None): 16 | self.multiplier = multiplier 17 | if self.multiplier <= 1.: 18 | raise ValueError('multiplier should be greater than 1.') 19 | self.total_epoch = total_epoch 20 | self.after_scheduler = after_scheduler 21 | self.finished = False 22 | super().__init__(optimizer) 23 | 24 | def get_lr(self): 25 | if self.last_epoch > self.total_epoch: 26 | if self.after_scheduler: 27 | if not self.finished: 28 | self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs] 29 | self.finished = True 30 | return self.after_scheduler.get_lr() 31 | return [base_lr * self.multiplier for base_lr in self.base_lrs] 32 | 33 | return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs] 34 | 35 | def step_ReduceLROnPlateau(self, metrics, epoch=None): 36 | if epoch is None: 37 | epoch = self.last_epoch + 1 38 | self.last_epoch = epoch if epoch != 0 else 1 # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning 39 | if self.last_epoch <= self.total_epoch: 40 | warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs] 41 | for param_group, lr in zip(self.optimizer.param_groups, warmup_lr): 42 | param_group['lr'] = lr 43 | else: 44 | if epoch is None: 45 | self.after_scheduler.step(metrics, None) 46 | else: 47 | self.after_scheduler.step(metrics, epoch - self.total_epoch) 48 | 49 | def step(self, epoch=None, metrics=None): 50 | if type(self.after_scheduler) != ReduceLROnPlateau: 51 | if self.finished and self.after_scheduler: 52 | if epoch is None: 53 | self.after_scheduler.step(None) 54 | else: 55 | self.after_scheduler.step(epoch - self.total_epoch) 56 | else: 57 | return super(GradualWarmupScheduler, self).step(epoch) 58 | else: 59 | self.step_ReduceLROnPlateau(metrics, epoch) 60 | 61 | if __name__ == '__main__': 62 | scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, max_epoch) 63 | scheduler_warmup = GradualWarmupScheduler(optimizer, multiplier=8, total_epoch=10, after_scheduler=scheduler_cosine) 64 | 65 | 66 | -------------------------------------------------------------------------------- /torchreid/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .avgmeter import * 4 | from .loggers import * 5 | from .tools import * 6 | from .reidtools import * 7 | from .torchtools import * 8 | from .rerank import re_ranking 9 | from .model_complexity import compute_model_complexity, calc_model_params 10 | -------------------------------------------------------------------------------- /torchreid/utils/avgmeter.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | 4 | __all__ = ['AverageMeter'] 5 | 6 | 7 | class AverageMeter(object): 8 | """Computes and stores the average and current value. 9 | 10 | Examples:: 11 | >>> # Initialize a meter to record loss 12 | >>> losses = AverageMeter() 13 | >>> # Update meter after every minibatch update 14 | >>> losses.update(loss_value, batch_size) 15 | """ 16 | def __init__(self): 17 | self.reset() 18 | 19 | def reset(self): 20 | self.val = 0 21 | self.avg = 0 22 | self.sum = 0 23 | self.count = 0 24 | 25 | def update(self, val, n=1): 26 | self.val = val 27 | self.sum += val * n 28 | self.count += n 29 | self.avg = self.sum / self.count -------------------------------------------------------------------------------- /torchreid/utils/loggers.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | __all__ = ['Logger', 'RankLogger'] 4 | 5 | import sys 6 | import os 7 | import os.path as osp 8 | 9 | from .tools import mkdir_if_missing 10 | 11 | 12 | class Logger(object): 13 | """Writes console output to external text file. 14 | 15 | Imported from ``_ 16 | 17 | Args: 18 | fpath (str): directory to save logging file. 19 | 20 | Examples:: 21 | >>> import sys 22 | >>> import os 23 | >>> import os.path as osp 24 | >>> from torchreid.utils import Logger 25 | >>> save_dir = 'log/resnet50-softmax-market1501' 26 | >>> log_name = 'train.log' 27 | >>> sys.stdout = Logger(osp.join(args.save_dir, log_name)) 28 | """ 29 | def __init__(self, fpath=None): 30 | self.console = sys.stdout 31 | self.file = None 32 | if fpath is not None: 33 | mkdir_if_missing(osp.dirname(fpath)) 34 | self.file = open(fpath, 'w') 35 | 36 | def __del__(self): 37 | self.close() 38 | 39 | def __enter__(self): 40 | pass 41 | 42 | def __exit__(self, *args): 43 | self.close() 44 | 45 | def write(self, msg): 46 | self.console.write(msg) 47 | if self.file is not None: 48 | self.file.write(msg) 49 | 50 | def flush(self): 51 | self.console.flush() 52 | if self.file is not None: 53 | self.file.flush() 54 | os.fsync(self.file.fileno()) 55 | 56 | def close(self): 57 | self.console.close() 58 | if self.file is not None: 59 | self.file.close() 60 | 61 | 62 | class RankLogger(object): 63 | """Records the rank1 matching accuracy obtained for each 64 | test dataset at specified evaluation steps and provides a function 65 | to show the summarized results, which are convenient for analysis. 66 | 67 | Args: 68 | sources (str or list): source dataset name(s). 69 | targets (str or list): target dataset name(s). 70 | 71 | Examples:: 72 | >>> from torchreid.utils import RankLogger 73 | >>> s = 'market1501' 74 | >>> t = 'market1501' 75 | >>> ranklogger = RankLogger(s, t) 76 | >>> ranklogger.write(t, 10, 0.5) 77 | >>> ranklogger.write(t, 20, 0.7) 78 | >>> ranklogger.write(t, 30, 0.9) 79 | >>> ranklogger.show_summary() 80 | >>> # You will see: 81 | >>> # => Show performance summary 82 | >>> # market1501 (source) 83 | >>> # - epoch 10 rank1 50.0% 84 | >>> # - epoch 20 rank1 70.0% 85 | >>> # - epoch 30 rank1 90.0% 86 | >>> # If there are multiple test datasets 87 | >>> t = ['market1501', 'dukemtmcreid'] 88 | >>> ranklogger = RankLogger(s, t) 89 | >>> ranklogger.write(t[0], 10, 0.5) 90 | >>> ranklogger.write(t[0], 20, 0.7) 91 | >>> ranklogger.write(t[0], 30, 0.9) 92 | >>> ranklogger.write(t[1], 10, 0.1) 93 | >>> ranklogger.write(t[1], 20, 0.2) 94 | >>> ranklogger.write(t[1], 30, 0.3) 95 | >>> ranklogger.show_summary() 96 | >>> # You can see: 97 | >>> # => Show performance summary 98 | >>> # market1501 (source) 99 | >>> # - epoch 10 rank1 50.0% 100 | >>> # - epoch 20 rank1 70.0% 101 | >>> # - epoch 30 rank1 90.0% 102 | >>> # dukemtmcreid (target) 103 | >>> # - epoch 10 rank1 10.0% 104 | >>> # - epoch 20 rank1 20.0% 105 | >>> # - epoch 30 rank1 30.0% 106 | """ 107 | def __init__(self, sources, targets): 108 | self.sources = sources 109 | self.targets = targets 110 | 111 | if isinstance(self.sources, str): 112 | self.sources = [self.sources] 113 | 114 | if isinstance(self.targets, str): 115 | self.targets = [self.targets] 116 | 117 | self.logger = {name: {'epoch': [], 'rank1': []} for name in self.targets} 118 | 119 | def write(self, name, epoch, rank1): 120 | """Writes result. 121 | 122 | Args: 123 | name (str): dataset name. 124 | epoch (int): current epoch. 125 | rank1 (float): rank1 result. 126 | """ 127 | self.logger[name]['epoch'].append(epoch) 128 | self.logger[name]['rank1'].append(rank1) 129 | 130 | def show_summary(self): 131 | """Shows saved results.""" 132 | print('=> Show performance summary') 133 | for name in self.targets: 134 | from_where = 'source' if name in self.sources else 'target' 135 | print('{} ({})'.format(name, from_where)) 136 | for epoch, rank1 in zip(self.logger[name]['epoch'], self.logger[name]['rank1']): 137 | print('- epoch {}\t rank1 {:.1%}'.format(epoch, rank1)) -------------------------------------------------------------------------------- /torchreid/utils/reidtools.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | 4 | __all__ = ['visualize_ranked_results'] 5 | 6 | import numpy as np 7 | import os 8 | import os.path as osp 9 | import shutil 10 | import cv2 11 | from matplotlib import pyplot as plt 12 | 13 | from .tools import mkdir_if_missing 14 | 15 | 16 | GRID_SPACING = 10 17 | QUERY_EXTRA_SPACING = 90 18 | BW = 5 # border width 19 | GREEN = (0, 255, 0) 20 | RED = (0, 0, 255) 21 | 22 | 23 | def visualize_ranked_results(distmat, dataset, data_type, width=128, height=256, save_dir='', topk=10): 24 | """Visualizes ranked results. 25 | 26 | Supports both image-reid and video-reid. 27 | 28 | For image-reid, ranks will be plotted in a single figure. For video-reid, ranks will be 29 | saved in folders each containing a tracklet. 30 | 31 | Args: 32 | distmat (numpy.ndarray): distance matrix of shape (num_query, num_gallery). 33 | dataset (tuple): a 2-tuple containing (query, gallery), each of which contains 34 | tuples of (img_path(s), pid, camid). 35 | data_type (str): "image" or "video". 36 | width (int, optional): resized image width. Default is 128. 37 | height (int, optional): resized image height. Default is 256. 38 | save_dir (str): directory to save output images. 39 | topk (int, optional): denoting top-k images in the rank list to be visualized. 40 | Default is 10. 41 | """ 42 | num_q, num_g = distmat.shape 43 | mkdir_if_missing(save_dir) 44 | 45 | print('# query: {}\n# gallery {}'.format(num_q, num_g)) 46 | print('Visualizing top-{} ranks ...'.format(topk)) 47 | 48 | query, gallery = dataset 49 | assert num_q == len(query) 50 | assert num_g == len(gallery) 51 | 52 | indices = np.argsort(distmat, axis=1) 53 | 54 | def _cp_img_to(src, dst, rank, prefix, matched=False): 55 | """ 56 | Args: 57 | src: image path or tuple (for vidreid) 58 | dst: target directory 59 | rank: int, denoting ranked position, starting from 1 60 | prefix: string 61 | matched: bool 62 | """ 63 | if isinstance(src, (tuple, list)): 64 | if prefix == 'gallery': 65 | suffix = 'TRUE' if matched else 'FALSE' 66 | dst = osp.join(dst, prefix + '_top' + str(rank).zfill(3)) + '_' + suffix 67 | else: 68 | dst = osp.join(dst, prefix + '_top' + str(rank).zfill(3)) 69 | mkdir_if_missing(dst) 70 | for img_path in src: 71 | shutil.copy(img_path, dst) 72 | else: 73 | dst = osp.join(dst, prefix + '_top' + str(rank).zfill(3) + '_name_' + osp.basename(src)) 74 | shutil.copy(src, dst) 75 | 76 | for q_idx in range(num_q): 77 | qimg_path, qpid, qcamid = query[q_idx] 78 | qimg_path_name = qimg_path[0] if isinstance(qimg_path, (tuple, list)) else qimg_path 79 | 80 | if data_type == 'image': 81 | qimg = cv2.imread(qimg_path) 82 | qimg = cv2.resize(qimg, (width, height)) 83 | qimg = cv2.copyMakeBorder(qimg, BW, BW, BW, BW, cv2.BORDER_CONSTANT, value=(0, 0, 0)) 84 | # resize twice to ensure that the border width is consistent across images 85 | qimg = cv2.resize(qimg, (width, height)) 86 | num_cols = topk + 1 87 | grid_img = 255 * np.ones((height, num_cols*width+topk*GRID_SPACING+QUERY_EXTRA_SPACING, 3), dtype=np.uint8) 88 | grid_img[:, :width, :] = qimg 89 | else: 90 | qdir = osp.join(save_dir, osp.basename(osp.splitext(qimg_path_name)[0])) 91 | mkdir_if_missing(qdir) 92 | _cp_img_to(qimg_path, qdir, rank=0, prefix='query') 93 | 94 | rank_idx = 1 95 | for g_idx in indices[q_idx,:]: 96 | gimg_path, gpid, gcamid = gallery[g_idx] 97 | invalid = (qpid == gpid) & (qcamid == gcamid) 98 | 99 | if not invalid: 100 | matched = gpid==qpid 101 | if data_type == 'image': 102 | border_color = GREEN if matched else RED 103 | gimg = cv2.imread(gimg_path) 104 | gimg = cv2.resize(gimg, (width, height)) 105 | gimg = cv2.copyMakeBorder(gimg, BW, BW, BW, BW, cv2.BORDER_CONSTANT, value=border_color) 106 | gimg = cv2.resize(gimg, (width, height)) 107 | start = rank_idx*width + rank_idx*GRID_SPACING + QUERY_EXTRA_SPACING 108 | end = (rank_idx+1)*width + rank_idx*GRID_SPACING + QUERY_EXTRA_SPACING 109 | grid_img[:, start: end, :] = gimg 110 | else: 111 | _cp_img_to(gimg_path, qdir, rank=rank_idx, prefix='gallery', matched=matched) 112 | 113 | rank_idx += 1 114 | if rank_idx > topk: 115 | break 116 | 117 | if data_type == 'image': 118 | imname = osp.basename(osp.splitext(qimg_path_name)[0]) 119 | cv2.imwrite(osp.join(save_dir, imname+'.jpg'), grid_img) 120 | 121 | if (q_idx+1) % 100 == 0: 122 | print('- done {}/{}'.format(q_idx+1, num_q)) 123 | 124 | print('Done. Images have been saved to "{}" ...'.format(save_dir)) 125 | -------------------------------------------------------------------------------- /torchreid/utils/rerank.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python2/python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Source: https://github.com/zhunzhong07/person-re-ranking 5 | 6 | Created on Mon Jun 26 14:46:56 2017 7 | @author: luohao 8 | Modified by Houjing Huang, 2017-12-22. 9 | - This version accepts distance matrix instead of raw features. 10 | - The difference of `/` division between python 2 and 3 is handled. 11 | - numpy.float16 is replaced by numpy.float32 for numerical precision. 12 | 13 | CVPR2017 paper:Zhong Z, Zheng L, Cao D, et al. Re-ranking Person Re-identification with k-reciprocal Encoding[J]. 2017. 14 | url:http://openaccess.thecvf.com/content_cvpr_2017/papers/Zhong_Re-Ranking_Person_Re-Identification_CVPR_2017_paper.pdf 15 | Matlab version: https://github.com/zhunzhong07/person-re-ranking 16 | 17 | API 18 | q_g_dist: query-gallery distance matrix, numpy array, shape [num_query, num_gallery] 19 | q_q_dist: query-query distance matrix, numpy array, shape [num_query, num_query] 20 | g_g_dist: gallery-gallery distance matrix, numpy array, shape [num_gallery, num_gallery] 21 | k1, k2, lambda_value: parameters, the original paper is (k1=20, k2=6, lambda_value=0.3) 22 | Returns: 23 | final_dist: re-ranked distance, numpy array, shape [num_query, num_gallery] 24 | """ 25 | from __future__ import absolute_import 26 | from __future__ import print_function 27 | from __future__ import division 28 | 29 | __all__ = ['re_ranking'] 30 | 31 | import numpy as np 32 | 33 | 34 | def re_ranking(q_g_dist, q_q_dist, g_g_dist, k1=20, k2=6, lambda_value=0.3): 35 | 36 | # The following naming, e.g. gallery_num, is different from outer scope. 37 | # Don't care about it. 38 | 39 | original_dist = np.concatenate( 40 | [np.concatenate([q_q_dist, q_g_dist], axis=1), 41 | np.concatenate([q_g_dist.T, g_g_dist], axis=1)], 42 | axis=0) 43 | original_dist = np.power(original_dist, 2).astype(np.float32) 44 | original_dist = np.transpose(1. * original_dist/np.max(original_dist,axis = 0)) 45 | V = np.zeros_like(original_dist).astype(np.float32) 46 | initial_rank = np.argsort(original_dist).astype(np.int32) 47 | 48 | query_num = q_g_dist.shape[0] 49 | gallery_num = q_g_dist.shape[0] + q_g_dist.shape[1] 50 | all_num = gallery_num 51 | 52 | for i in range(all_num): 53 | # k-reciprocal neighbors 54 | forward_k_neigh_index = initial_rank[i,:k1+1] 55 | backward_k_neigh_index = initial_rank[forward_k_neigh_index,:k1+1] 56 | fi = np.where(backward_k_neigh_index==i)[0] 57 | k_reciprocal_index = forward_k_neigh_index[fi] 58 | k_reciprocal_expansion_index = k_reciprocal_index 59 | for j in range(len(k_reciprocal_index)): 60 | candidate = k_reciprocal_index[j] 61 | candidate_forward_k_neigh_index = initial_rank[candidate,:int(np.around(k1/2.))+1] 62 | candidate_backward_k_neigh_index = initial_rank[candidate_forward_k_neigh_index,:int(np.around(k1/2.))+1] 63 | fi_candidate = np.where(candidate_backward_k_neigh_index == candidate)[0] 64 | candidate_k_reciprocal_index = candidate_forward_k_neigh_index[fi_candidate] 65 | if len(np.intersect1d(candidate_k_reciprocal_index,k_reciprocal_index))> 2./3*len(candidate_k_reciprocal_index): 66 | k_reciprocal_expansion_index = np.append(k_reciprocal_expansion_index,candidate_k_reciprocal_index) 67 | 68 | k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index) 69 | weight = np.exp(-original_dist[i,k_reciprocal_expansion_index]) 70 | V[i,k_reciprocal_expansion_index] = 1.*weight/np.sum(weight) 71 | original_dist = original_dist[:query_num,] 72 | if k2 != 1: 73 | V_qe = np.zeros_like(V,dtype=np.float32) 74 | for i in range(all_num): 75 | V_qe[i,:] = np.mean(V[initial_rank[i,:k2],:],axis=0) 76 | V = V_qe 77 | del V_qe 78 | del initial_rank 79 | invIndex = [] 80 | for i in range(gallery_num): 81 | invIndex.append(np.where(V[:,i] != 0)[0]) 82 | 83 | jaccard_dist = np.zeros_like(original_dist,dtype = np.float32) 84 | 85 | 86 | for i in range(query_num): 87 | temp_min = np.zeros(shape=[1,gallery_num],dtype=np.float32) 88 | indNonZero = np.where(V[i,:] != 0)[0] 89 | indImages = [] 90 | indImages = [invIndex[ind] for ind in indNonZero] 91 | for j in range(len(indNonZero)): 92 | temp_min[0,indImages[j]] = temp_min[0,indImages[j]]+ np.minimum(V[i,indNonZero[j]],V[indImages[j],indNonZero[j]]) 93 | jaccard_dist[i] = 1-temp_min/(2.-temp_min) 94 | 95 | final_dist = jaccard_dist*(1-lambda_value) + original_dist*lambda_value 96 | del original_dist 97 | del V 98 | del jaccard_dist 99 | final_dist = final_dist[:query_num,query_num:] 100 | return final_dist 101 | -------------------------------------------------------------------------------- /torchreid/utils/tools.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | __all__ = ['mkdir_if_missing', 'check_isfile', 'read_json', 'write_json', 6 | 'set_random_seed', 'download_url', 'read_image', 'collect_env_info'] 7 | 8 | import sys 9 | import os 10 | import os.path as osp 11 | import time 12 | import errno 13 | import json 14 | from collections import OrderedDict 15 | import warnings 16 | import random 17 | import numpy as np 18 | import PIL 19 | from PIL import Image 20 | 21 | import torch 22 | 23 | 24 | def mkdir_if_missing(dirname): 25 | """Creates dirname if it is missing.""" 26 | if not osp.exists(dirname): 27 | try: 28 | os.makedirs(dirname) 29 | except OSError as e: 30 | if e.errno != errno.EEXIST: 31 | raise 32 | 33 | 34 | def check_isfile(fpath): 35 | """Checks if the given path is a file. 36 | 37 | Args: 38 | fpath (str): file path. 39 | 40 | Returns: 41 | bool 42 | """ 43 | isfile = osp.isfile(fpath) 44 | if not isfile: 45 | warnings.warn('No file found at "{}"'.format(fpath)) 46 | return isfile 47 | 48 | 49 | def read_json(fpath): 50 | """Reads json file from a path.""" 51 | with open(fpath, 'r') as f: 52 | obj = json.load(f) 53 | return obj 54 | 55 | 56 | def write_json(obj, fpath): 57 | """Writes to a json file.""" 58 | mkdir_if_missing(osp.dirname(fpath)) 59 | with open(fpath, 'w') as f: 60 | json.dump(obj, f, indent=4, separators=(',', ': ')) 61 | 62 | 63 | def set_random_seed(seed): 64 | random.seed(seed) 65 | np.random.seed(seed) 66 | torch.manual_seed(seed) 67 | torch.cuda.manual_seed_all(seed) 68 | 69 | 70 | def download_url(url, dst): 71 | """Downloads file from a url to a destination. 72 | 73 | Args: 74 | url (str): url to download file. 75 | dst (str): destination path. 76 | """ 77 | from six.moves import urllib 78 | print('* url="{}"'.format(url)) 79 | print('* destination="{}"'.format(dst)) 80 | 81 | def _reporthook(count, block_size, total_size): 82 | global start_time 83 | if count == 0: 84 | start_time = time.time() 85 | return 86 | duration = time.time() - start_time 87 | progress_size = int(count * block_size) 88 | speed = int(progress_size / (1024 * duration)) 89 | percent = int(count * block_size * 100 / total_size) 90 | sys.stdout.write('\r...%d%%, %d MB, %d KB/s, %d seconds passed' % 91 | (percent, progress_size / (1024 * 1024), speed, duration)) 92 | sys.stdout.flush() 93 | 94 | urllib.request.urlretrieve(url, dst, _reporthook) 95 | sys.stdout.write('\n') 96 | 97 | 98 | def read_image(path): 99 | """Reads image from path using ``PIL.Image``. 100 | 101 | Args: 102 | path (str): path to an image. 103 | 104 | Returns: 105 | PIL image 106 | """ 107 | got_img = False 108 | if not osp.exists(path): 109 | raise IOError('"{}" does not exist'.format(path)) 110 | while not got_img: 111 | try: 112 | img = Image.open(path).convert('RGB') 113 | got_img = True 114 | except IOError: 115 | print('IOError incurred when reading "{}". Will redo. Don\'t worry. Just chill.'.format(img_path)) 116 | pass 117 | return img 118 | 119 | 120 | def collect_env_info(): 121 | """Returns env info as a string. 122 | 123 | Code source: github.com/facebookresearch/maskrcnn-benchmark 124 | """ 125 | from torch.utils.collect_env import get_pretty_env_info 126 | env_str = get_pretty_env_info() 127 | env_str += '\n Pillow ({})'.format(PIL.__version__) 128 | return env_str 129 | --------------------------------------------------------------------------------