├── LICENCE.md ├── README.md ├── config ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ └── defaults.cpython-36.pyc └── defaults.py ├── configs ├── baseline.yml ├── softmax.yml ├── softmax_triplet.yml ├── softmax_triplet_ft.yml ├── softmax_triplet_ftc.yml └── softmax_triplet_with_center.yml ├── data ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── build.cpython-36.pyc │ └── collate_batch.cpython-36.pyc ├── build.py ├── collate_batch.py ├── datasets │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── bases.cpython-36.pyc │ │ ├── cuhk.cpython-36.pyc │ │ ├── dataset_loader.cpython-36.pyc │ │ ├── dukemtmcreid.cpython-36.pyc │ │ ├── eval_reid.cpython-36.pyc │ │ ├── market1501.cpython-36.pyc │ │ ├── msmt17.cpython-36.pyc │ │ ├── prw.cpython-36.pyc │ │ └── veri.cpython-36.pyc │ ├── bases.py │ ├── cuhk.py │ ├── cuhk03.py │ ├── dataset_loader.py │ ├── dukemtmcreid.py │ ├── eval_reid.py │ ├── market1501.py │ ├── msmt17.py │ ├── prw.py │ └── veri.py ├── samplers │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ └── triplet_sampler.cpython-36.pyc │ └── triplet_sampler.py └── transforms │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── build.cpython-36.pyc │ └── transforms.cpython-36.pyc │ ├── build.py │ └── transforms.py ├── engine ├── __pycache__ │ ├── inference.cpython-36.pyc │ └── trainer.cpython-36.pyc ├── inference.py └── trainer.py ├── image └── examples.png ├── layers ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── center_loss.cpython-36.pyc │ └── triplet_loss.cpython-36.pyc ├── center_loss.py └── triplet_loss.py ├── modeling ├── PISNet.py ├── Pre_Selection_Model.py ├── __init__.py ├── __pycache__ │ └── __init__.cpython-36.pyc ├── backbones │ ├── Query_Guided_Attention.py │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ └── resnet.cpython-36.pyc │ ├── pisnet.py │ └── resnet.py └── baseline.py ├── pi_cuhk.sh ├── pi_prw.sh ├── pre_select_cuhk.sh ├── solver ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── build.cpython-36.pyc │ └── lr_scheduler.cpython-36.pyc ├── build.py └── lr_scheduler.py ├── tests ├── __init__.py └── lr_scheduler_test.py ├── tools ├── __init__.py ├── pre_selection.py └── train.py └── utils ├── __init__.py ├── __pycache__ ├── __init__.cpython-36.pyc ├── iotools.cpython-36.pyc ├── logger.cpython-36.pyc ├── re_ranking.cpython-36.pyc └── reid_metric.cpython-36.pyc ├── iotools.py ├── logger.py ├── re_ranking.py └── reid_metric.py /LICENCE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) [2019] [HaoLuo] 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Do Not Disturb Me: Person Re-identification Under the Interference of Other Pedestrians (ECCV 2020) 2 | 3 | Official code for ECCV 2020 paper [Do Not Disturb Me: Person Re-identification Under the Interference of Other Pedestrians](https://arxiv.org/abs/2008.06963). 4 | 5 |

6 | ![PISNet examples](image/examples.png) 7 |

8 | 9 | ## Introduction 10 | 11 | In the conventional person Re-ID setting, each cropped image is assumed to contain only the person inside that individual's bounding box. In a crowded scene, however, off-the-shelf detectors may generate bounding boxes that involve multiple people, with a large proportion of background pedestrians or human occlusion. The representation extracted from such cropped images, which contain both the target and the interfering pedestrians, may include distracting information and lead to wrong retrieval results. To address this problem, this paper presents a novel deep network termed the Pedestrian-Interference Suppression Network (PISNet). PISNet leverages a Query-Guided Attention Block (QGAB) to enhance the feature of the target in the gallery under the guidance of the query. Furthermore, the accompanying Guidance Reversed Attention Module and Multi-Person Separation Loss promote QGAB to suppress the interference of other pedestrians. Our method is evaluated on two new pedestrian-interference datasets, and the results show that it performs favorably against existing Re-ID methods. 12 | 13 | 14 |
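The core idea of QGAB is correlation-style attention: the query's global feature scores every spatial position of the gallery feature map, so positions resembling the query are amplified and interfering pedestrians are suppressed. Below is a minimal PyTorch sketch of this idea (illustrative only; the actual module lives in `modeling/backbones/Query_Guided_Attention.py` and differs in detail):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class QueryGuidedAttentionSketch(nn.Module):
    """Toy query-guided attention: the query's pooled feature scores every
    spatial position of the gallery feature map, and the resulting map
    re-weights the gallery features."""
    def __init__(self, channels=2048, reduction=8):
        super().__init__()
        self.key = nn.Conv2d(channels, channels // reduction, kernel_size=1)
        self.query_proj = nn.Linear(channels, channels // reduction)

    def forward(self, gallery_feat, query_feat):
        # gallery_feat: (B, C, H, W) feature map of the (multi-person) gallery crop
        # query_feat:   (B, C, h, w) feature map of the single-person query crop
        b, _, h, w = gallery_feat.shape
        q = self.query_proj(F.adaptive_avg_pool2d(query_feat, 1).flatten(1))  # (B, C/r)
        k = self.key(gallery_feat).flatten(2)                                 # (B, C/r, H*W)
        attn = torch.sigmoid(torch.einsum('bc,bcn->bn', q, k))                # (B, H*W)
        return gallery_feat * attn.view(b, 1, h, w)  # target enhanced, others suppressed
```

In the full model, the Guidance Reversed Attention Module and the Multi-Person Separation Loss regularize this attention during training so that it does not simply highlight every person in the crop.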

## Resources

15 | 16 | 1. Pretrained Models: 17 | 18 | [Baidu NetDisk](https://pan.baidu.com/s/1O08TssJcASsTh8veIBimzA), Password: 6x4x. The models are trained using the ground-truth boxes from [CUHK-SYSU](https://github.com/ShuangLI59/person_search) and [PRW](https://github.com/liangzheng06/PRW-baseline), respectively. 19 | 20 | 2. Datasets: 21 | 22 | Request the datasets from xbrainzsz@gmail.com (academic use only). 23 | Due to licensing issues, please send your request from your university email address. 24 | 25 | ## Citation 26 | 27 | If you find this code useful in your research, please consider citing: 28 | ``` 29 | @inproceedings{zhao2020pireid, 30 | title={Do Not Disturb Me: Person Re-identification Under the Interference of Other Pedestrians}, 31 | author={Zhao, Shizhen and Gao, Changxin and Zhang, Jun and Cheng, Hao and Han, Chuchu and Jiang, Xinyang and Guo, Xiaowei and Zheng, Wei-Shi and Sang, Nong and Sun, Xing}, 32 | booktitle={European Conference on Computer Vision (ECCV)}, 33 | year={2020} 34 | } 35 | ``` 36 | 37 | ## Contact 38 | 39 | Shizhen Zhao: xbrainzsz@gmail.com 40 | -------------------------------------------------------------------------------- /config/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | from .defaults import _C as cfg 8 | -------------------------------------------------------------------------------- /config/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/X-BrainLab/PI-ReID/76755f094fe6d6a7f195cd4aaaddc1b6beff8d0f/config/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /config/__pycache__/defaults.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/X-BrainLab/PI-ReID/76755f094fe6d6a7f195cd4aaaddc1b6beff8d0f/config/__pycache__/defaults.cpython-36.pyc -------------------------------------------------------------------------------- /config/defaults.py: -------------------------------------------------------------------------------- 1 | from yacs.config import CfgNode as CN 2 | 3 | # ----------------------------------------------------------------------------- 4 | # Convention about Training / Test specific parameters 5 | # ----------------------------------------------------------------------------- 6 | # Whenever an argument can be used either for training or for testing, the 7 | # corresponding name will be suffixed with _TRAIN for a training parameter, 8 | # or _TEST for a test-specific parameter.
9 | # For example, the number of images during training will be 10 | # IMAGES_PER_BATCH_TRAIN, while the number of images for testing will be 11 | # IMAGES_PER_BATCH_TEST 12 | 13 | # ----------------------------------------------------------------------------- 14 | # Config definition 15 | # ----------------------------------------------------------------------------- 16 | 17 | _C = CN() 18 | 19 | _C.MODEL = CN() 20 | # Using cuda or cpu for training 21 | _C.MODEL.DEVICE = "cuda" 22 | # ID number of GPU 23 | _C.MODEL.DEVICE_ID = '0' 24 | # Name of backbone 25 | _C.MODEL.NAME = 'resnet50' 26 | # Last stride of backbone 27 | _C.MODEL.LAST_STRIDE = 1 28 | # Path to pretrained model of backbone 29 | _C.MODEL.PRETRAIN_PATH = '' 30 | # Use an ImageNet-pretrained model to initialize the backbone, or a self-trained model to initialize the whole model 31 | # Options: 'imagenet' or 'self' 32 | _C.MODEL.PRETRAIN_CHOICE = 'imagenet' 33 | # If training with BNNeck, options: 'bnneck' or 'no' 34 | _C.MODEL.NECK = 'bnneck' 35 | # If the training loss includes center loss, options: 'yes' or 'no'. Loss with center loss has a different optimizer configuration 36 | _C.MODEL.IF_WITH_CENTER = 'no' 37 | # The loss type of metric loss 38 | # options: ['triplet'] (without center loss) or ['center', 'triplet_center'] (with center loss) 39 | _C.MODEL.METRIC_LOSS_TYPE = 'triplet' 40 | # For example, if the loss type is cross entropy loss + triplet loss + center loss, 41 | # the setting should be: _C.MODEL.METRIC_LOSS_TYPE = 'triplet_center' and _C.MODEL.IF_WITH_CENTER = 'yes' 42 | 43 | # If training with label smoothing, options: 'on', 'off' 44 | _C.MODEL.IF_LABELSMOOTH = 'on' 45 | 46 | # Whether to use the non-local (query-guided attention) block 47 | _C.MODEL.HAS_NON_LOCAL = "no" 48 | 49 | # Whether to train the whole model 50 | _C.MODEL.WHOLE_MODEL_TRAIN = "no" 51 | 52 | # Siamese regularization 53 | _C.MODEL.SIA_REG = "no" 54 | 55 | # Pyramid attention 56 | _C.MODEL.PYRAMID = "no" 57 | 58 | 59 | 60 | 61 | # GAMMA 62 | _C.MODEL.GAMMA = 1.0 63 | 64 | # BETA 65 | _C.MODEL.BETA = 1.0 66 | 67 | # ----------------------------------------------------------------------------- 68 | # INPUT 69 | # ----------------------------------------------------------------------------- 70 | _C.INPUT = CN() 71 | # Size of the image during training 72 | _C.INPUT.SIZE_TRAIN = [384, 128] 73 | # Size of the image during test 74 | _C.INPUT.SIZE_TEST = [384, 128] 75 | # Probability of random horizontal flip 76 | _C.INPUT.PROB = 0.5 77 | # Probability of random erasing 78 | _C.INPUT.RE_PROB = 0.5 79 | # Mean values used for image normalization 80 | _C.INPUT.PIXEL_MEAN = [0.485, 0.456, 0.406] 81 | # Std values used for image normalization 82 | _C.INPUT.PIXEL_STD = [0.229, 0.224, 0.225] 83 | # Value of padding size 84 | _C.INPUT.PADDING = 10 85 | 86 | # ----------------------------------------------------------------------------- 87 | # Dataset 88 | # ----------------------------------------------------------------------------- 89 | _C.DATASETS = CN() 90 | # List of the dataset names for training 91 | _C.DATASETS.NAMES = ('market1501') 92 | # Root directory where datasets should be used (and downloaded if not found) 93 | _C.DATASETS.ROOT_DIR = '/root/person_search/dataset/multi_person' 94 | # Suffix index of the generated pair-annotation file, pair_pos_unary<N>.json (see data/build.py) 95 | _C.DATASETS.TRAIN_ANNO = 1 96 |
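# A typical way these defaults are consumed (a minimal sketch assuming the
# standard yacs workflow; the exact CLI of tools/train.py may differ):
#
#     from config import cfg
#     cfg.merge_from_file("configs/softmax_triplet_ft.yml")   # YAML overrides defaults
#     cfg.merge_from_list(["MODEL.DEVICE_ID", "('0')"])       # key/value overrides
#     cfg.freeze()                                            # lock the config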
97 | # ----------------------------------------------------------------------------- 98 | # DataLoader 99 | # ----------------------------------------------------------------------------- 100 | _C.DATALOADER = CN() 101 | # Number of data loading threads 102 | _C.DATALOADER.NUM_WORKERS = 8 103 | # Sampler for data loading 104 | _C.DATALOADER.SAMPLER = 'softmax' 105 | # Number of instances per identity in one batch 106 | _C.DATALOADER.NUM_INSTANCE = 16 107 | 108 | # ---------------------------------------------------------------------------- # 109 | # Solver 110 | # ---------------------------------------------------------------------------- # 111 | _C.SOLVER = CN() 112 | # Name of optimizer 113 | _C.SOLVER.OPTIMIZER_NAME = "Adam" 114 | # Maximum number of epochs 115 | _C.SOLVER.MAX_EPOCHS = 50 116 | # Base learning rate 117 | _C.SOLVER.BASE_LR = 3e-4 118 | # Learning-rate factor for bias parameters 119 | _C.SOLVER.BIAS_LR_FACTOR = 2 120 | # Momentum 121 | _C.SOLVER.MOMENTUM = 0.9 122 | # Margin of triplet loss 123 | _C.SOLVER.MARGIN = 0.3 124 | # Margin of cluster loss 125 | _C.SOLVER.CLUSTER_MARGIN = 0.3 126 | # Learning rate of SGD to learn the centers of center loss 127 | _C.SOLVER.CENTER_LR = 0.5 128 | # Balanced weight of center loss 129 | _C.SOLVER.CENTER_LOSS_WEIGHT = 0.0005 130 | # Settings of range loss 131 | _C.SOLVER.RANGE_K = 2 132 | _C.SOLVER.RANGE_MARGIN = 0.3 133 | _C.SOLVER.RANGE_ALPHA = 0 134 | _C.SOLVER.RANGE_BETA = 1 135 | _C.SOLVER.RANGE_LOSS_WEIGHT = 1 136 | 137 | # Settings of weight decay 138 | _C.SOLVER.WEIGHT_DECAY = 0.0005 139 | _C.SOLVER.WEIGHT_DECAY_BIAS = 0. 140 | 141 | # Decay rate of learning rate 142 | _C.SOLVER.GAMMA = 0.1 143 | # Decay steps of learning rate 144 | _C.SOLVER.STEPS = (30, 55) 145 | 146 | # Warm-up factor 147 | _C.SOLVER.WARMUP_FACTOR = 1.0 / 3 148 | # Iterations of warm up 149 | _C.SOLVER.WARMUP_ITERS = 500 150 | # Method of warm up, options: 'constant', 'linear' 151 | _C.SOLVER.WARMUP_METHOD = "linear" 152 | 153 | # Epoch period for saving checkpoints 154 | _C.SOLVER.CHECKPOINT_PERIOD = 50 155 | # Iteration period for printing the training log 156 | _C.SOLVER.LOG_PERIOD = 100 157 | # Epoch period for validation 158 | _C.SOLVER.EVAL_PERIOD = 50 159 | 160 | # Number of images per batch 161 | # This is global, so if we have 8 GPUs and IMS_PER_BATCH = 16, each GPU will 162 | # see 2 images per batch 163 | _C.SOLVER.IMS_PER_BATCH = 64 164 | 165 | 166 | 167 | 168 | _C.TEST = CN() 169 | # Number of images per batch during test 170 | _C.TEST.IMS_PER_BATCH = 128 171 | # If testing with re-ranking, options: 'yes', 'no' 172 | _C.TEST.RE_RANKING = 'no' 173 | # Path to trained model 174 | _C.TEST.WEIGHT = "" 175 | # Which feature of BNNeck to use for test, options: 'before' or 'after' 176 | _C.TEST.NECK_FEAT = 'after' 177 | # Whether the feature is normalized before test; if yes, it is equivalent to cosine distance 178 | _C.TEST.FEAT_NORM = 'yes' 179 | # Whether to run the pairwise (query-guided) test, options: 'yes', 'no' 180 | _C.TEST.PAIR = "no" 181 | 182 | # ---------------------------------------------------------------------------- # 183 | # Misc options 184 | # ---------------------------------------------------------------------------- # 185 | # Path to checkpoint and saved log of trained model 186 | _C.OUTPUT_DIR = "" 187 | _C.Pre_Index_DIR = ""  # path to the pre-selection index JSON (see data/build.py) 188 | -------------------------------------------------------------------------------- /configs/baseline.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/haoluo/.torch/models/resnet50-19c8e357.pth' 4 | LAST_STRIDE: 2 5 | NECK: 'no' 6 | METRIC_LOSS_TYPE: 'triplet' 7 | IF_LABELSMOOTH: 'off' 8 | IF_WITH_CENTER: 
'no' 9 | 10 | 11 | INPUT: 12 | SIZE_TRAIN: [256, 128] 13 | SIZE_TEST: [256, 128] 14 | PROB: 0.5 # random horizontal flip 15 | RE_PROB: 0.0 # random erasing 16 | PADDING: 10 17 | 18 | DATASETS: 19 | NAMES: ('market1501') 20 | 21 | DATALOADER: 22 | SAMPLER: 'softmax_triplet' 23 | NUM_INSTANCE: 4 24 | NUM_WORKERS: 8 25 | 26 | SOLVER: 27 | OPTIMIZER_NAME: 'Adam' 28 | MAX_EPOCHS: 120 29 | BASE_LR: 0.00035 30 | 31 | CLUSTER_MARGIN: 0.3 32 | 33 | CENTER_LR: 0.5 34 | CENTER_LOSS_WEIGHT: 0.0005 35 | 36 | RANGE_K: 2 37 | RANGE_MARGIN: 0.3 38 | RANGE_ALPHA: 0 39 | RANGE_BETA: 1 40 | RANGE_LOSS_WEIGHT: 1 41 | 42 | BIAS_LR_FACTOR: 1 43 | WEIGHT_DECAY: 0.0005 44 | WEIGHT_DECAY_BIAS: 0.0005 45 | IMS_PER_BATCH: 64 46 | 47 | STEPS: [40, 70] 48 | GAMMA: 0.1 49 | 50 | WARMUP_FACTOR: 0.01 51 | WARMUP_ITERS: 0 52 | WARMUP_METHOD: 'linear' 53 | 54 | CHECKPOINT_PERIOD: 40 55 | LOG_PERIOD: 20 56 | EVAL_PERIOD: 40 57 | 58 | TEST: 59 | IMS_PER_BATCH: 128 60 | RE_RANKING: 'no' 61 | WEIGHT: "path" 62 | NECK_FEAT: 'after' 63 | FEAT_NORM: 'yes' 64 | 65 | OUTPUT_DIR: "/home/haoluo/log/gu/reid_baseline_review/Opensource_test/market1501/Experiment-all-tricks-256x128-bs16x4-warmup10-erase0_5-labelsmooth_on-laststride1-bnneck_on" 66 | 67 | 68 | -------------------------------------------------------------------------------- /configs/softmax.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_PATH: '/home/haoluo/.torch/models/resnet50-19c8e357.pth' 3 | 4 | 5 | INPUT: 6 | SIZE_TRAIN: [256, 128] 7 | SIZE_TEST: [256, 128] 8 | PROB: 0.5 # random horizontal flip 9 | RE_PROB: 0.5 # random erasing 10 | PADDING: 10 11 | 12 | DATASETS: 13 | NAMES: ('market1501') 14 | 15 | DATALOADER: 16 | SAMPLER: 'softmax' 17 | NUM_WORKERS: 8 18 | 19 | SOLVER: 20 | OPTIMIZER_NAME: 'Adam' 21 | MAX_EPOCHS: 120 22 | BASE_LR: 0.00035 23 | BIAS_LR_FACTOR: 1 24 | WEIGHT_DECAY: 0.0005 25 | WEIGHT_DECAY_BIAS: 0.0005 26 | IMS_PER_BATCH: 64 27 | 28 | STEPS: [30, 55] 29 | GAMMA: 0.1 30 | 31 | WARMUP_FACTOR: 0.01 32 | WARMUP_ITERS: 5 33 | WARMUP_METHOD: 'linear' 34 | 35 | CHECKPOINT_PERIOD: 20 36 | LOG_PERIOD: 20 37 | EVAL_PERIOD: 20 38 | 39 | TEST: 40 | IMS_PER_BATCH: 128 41 | 42 | OUTPUT_DIR: "/home/haoluo/log/reid/market1501/softmax_bs64_256x128" 43 | 44 | 45 | -------------------------------------------------------------------------------- /configs/softmax_triplet.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/haoluo/.torch/models/resnet50-19c8e357.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'on' 6 | IF_WITH_CENTER: 'no' 7 | 8 | 9 | 10 | 11 | INPUT: 12 | SIZE_TRAIN: [256, 128] 13 | SIZE_TEST: [256, 128] 14 | PROB: 0.5 # random horizontal flip 15 | RE_PROB: 0.5 # random erasing 16 | PADDING: 10 17 | 18 | DATASETS: 19 | NAMES: ('market1501') 20 | 21 | DATALOADER: 22 | SAMPLER: 'softmax_triplet' 23 | NUM_INSTANCE: 4 24 | NUM_WORKERS: 8 25 | 26 | SOLVER: 27 | OPTIMIZER_NAME: 'Adam' 28 | MAX_EPOCHS: 120 29 | BASE_LR: 0.00035 30 | 31 | CLUSTER_MARGIN: 0.3 32 | 33 | CENTER_LR: 0.5 34 | CENTER_LOSS_WEIGHT: 0.0005 35 | 36 | RANGE_K: 2 37 | RANGE_MARGIN: 0.3 38 | RANGE_ALPHA: 0 39 | RANGE_BETA: 1 40 | RANGE_LOSS_WEIGHT: 1 41 | 42 | BIAS_LR_FACTOR: 1 43 | WEIGHT_DECAY: 0.0005 44 | WEIGHT_DECAY_BIAS: 0.0005 45 | IMS_PER_BATCH: 64 46 | 47 | STEPS: [40, 70] 48 | GAMMA: 0.1 49 | 50 | WARMUP_FACTOR: 0.01 51 | WARMUP_ITERS: 10 52 | WARMUP_METHOD: 'linear' 53 | 54 | CHECKPOINT_PERIOD: 40 55 | 
LOG_PERIOD: 20 56 | EVAL_PERIOD: 40 57 | 58 | TEST: 59 | IMS_PER_BATCH: 128 60 | RE_RANKING: 'no' 61 | WEIGHT: "path" 62 | NECK_FEAT: 'after' 63 | FEAT_NORM: 'yes' 64 | 65 | OUTPUT_DIR: "/home/haoluo/log/gu/reid_baseline_review/Opensource_test/market1501/Experiment-all-tricks-256x128-bs16x4-warmup10-erase0_5-labelsmooth_on-laststride1-bnneck_on" 66 | 67 | 68 | -------------------------------------------------------------------------------- /configs/softmax_triplet_ft.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'self' 3 | PRETRAIN_PATH: '/root/person_search/trained/strong_baseline/prw_all_trick_10/resnet50_model_120.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'no' 6 | IF_WITH_CENTER: 'no' 7 | HAS_NON_LOCAL: "yes" 8 | WHOLE_MODEL_TRAIN: "no" 9 | SIA_REG: "no" 10 | PYRAMID: "no" 11 | GAMMA: 1.0 12 | BETA: 1.0 13 | 14 | 15 | INPUT: 16 | SIZE_TRAIN: [256, 128] 17 | SIZE_TEST: [256, 128] 18 | PROB: 0.5 # random horizontal flip 19 | RE_PROB: 0.5 # random erasing 20 | PADDING: 10 21 | 22 | DATASETS: 23 | NAMES: ('market1501') 24 | TRAIN_ANNO: 1 25 | 26 | DATALOADER: 27 | SAMPLER: 'softmax_triplet' 28 | NUM_INSTANCE: 4 29 | NUM_WORKERS: 8 30 | 31 | SOLVER: 32 | OPTIMIZER_NAME: 'Adam' 33 | MAX_EPOCHS: 50 34 | BASE_LR: 0.00035 35 | 36 | CLUSTER_MARGIN: 0.3 37 | 38 | CENTER_LR: 0.5 39 | CENTER_LOSS_WEIGHT: 0.0005 40 | 41 | RANGE_K: 2 42 | RANGE_MARGIN: 0.3 43 | RANGE_ALPHA: 0 44 | RANGE_BETA: 1 45 | RANGE_LOSS_WEIGHT: 1 46 | 47 | BIAS_LR_FACTOR: 1 48 | WEIGHT_DECAY: 0.0005 49 | WEIGHT_DECAY_BIAS: 0.0005 50 | IMS_PER_BATCH: 64 51 | 52 | STEPS: [20, 40] 53 | GAMMA: 0.1 54 | 55 | WARMUP_FACTOR: 0.01 56 | WARMUP_ITERS: 10 57 | WARMUP_METHOD: 'linear' 58 | 59 | CHECKPOINT_PERIOD: 20 60 | LOG_PERIOD: 20 61 | EVAL_PERIOD: 20 62 | 63 | 64 | TEST: 65 | IMS_PER_BATCH: 128 66 | RE_RANKING: 'no' 67 | WEIGHT: "path" 68 | NECK_FEAT: 'after' 69 | FEAT_NORM: 'yes' 70 | PAIR: "no" 71 | 72 | OUTPUT_DIR: "/home/haoluo/log/gu/reid_baseline_review/Opensource_test/market1501/Experiment-all-tricks-256x128-bs16x4-warmup10-erase0_5-labelsmooth_on-laststride1-bnneck_on" 73 | Pre_Index_DIR: "/root/person_search/multi-personReid/pre_index_dir/prw_pre_index.json" 74 | 75 | 76 | -------------------------------------------------------------------------------- /configs/softmax_triplet_ftc.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'self' 3 | PRETRAIN_PATH: '/root/person_search/trained/strong_baseline/cuhk_all_trick_1/resnet50_model_120.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'no' 6 | IF_WITH_CENTER: 'no' 7 | HAS_NON_LOCAL: "yes" 8 | WHOLE_MODEL_TRAIN: "no" 9 | SIA_REG: "no" 10 | PYRAMID: "no" 11 | GAMMA: 1.0 12 | BETA: 1.0 13 | 14 | 15 | INPUT: 16 | SIZE_TRAIN: [256, 128] 17 | SIZE_TEST: [256, 128] 18 | PROB: 0.5 # random horizontal flip 19 | RE_PROB: 0.5 # random erasing 20 | PADDING: 10 21 | 22 | DATASETS: 23 | NAMES: ('market1501') 24 | TRAIN_ANNO: 1 25 | 26 | DATALOADER: 27 | SAMPLER: 'softmax_triplet' 28 | NUM_INSTANCE: 4 29 | NUM_WORKERS: 8 30 | 31 | SOLVER: 32 | OPTIMIZER_NAME: 'Adam' 33 | MAX_EPOCHS: 50 34 | BASE_LR: 0.00035 35 | 36 | CLUSTER_MARGIN: 0.3 37 | 38 | CENTER_LR: 0.5 39 | CENTER_LOSS_WEIGHT: 0.0005 40 | 41 | RANGE_K: 2 42 | RANGE_MARGIN: 0.3 43 | RANGE_ALPHA: 0 44 | RANGE_BETA: 1 45 | RANGE_LOSS_WEIGHT: 1 46 | 47 | BIAS_LR_FACTOR: 1 48 | WEIGHT_DECAY: 0.0005 49 | WEIGHT_DECAY_BIAS: 0.0005 50 | IMS_PER_BATCH: 64 51 | 52 | STEPS: [20, 40] 
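  # Worked example of the resulting schedule (a sketch, assuming the
  # WarmupMultiStepLR in solver/lr_scheduler.py keeps the usual
  # reid-strong-baseline semantics): with BASE_LR 0.00035, the WARMUP settings
  # below and GAMMA 0.1, the LR ramps linearly from 0.00035 * 0.01 = 3.5e-6 to
  # 3.5e-4 over the first 10 epochs, then decays to 3.5e-5 at epoch 20 and to
  # 3.5e-6 at epoch 40.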
53 | GAMMA: 0.1 54 | 55 | WARMUP_FACTOR: 0.01 56 | WARMUP_ITERS: 10 57 | WARMUP_METHOD: 'linear' 58 | 59 | CHECKPOINT_PERIOD: 20 60 | LOG_PERIOD: 20 61 | EVAL_PERIOD: 20 62 | 63 | 64 | TEST: 65 | IMS_PER_BATCH: 128 66 | RE_RANKING: 'no' 67 | WEIGHT: "path" 68 | NECK_FEAT: 'after' 69 | FEAT_NORM: 'yes' 70 | PAIR: "no" 71 | 72 | 73 | 74 | OUTPUT_DIR: "/home/haoluo/log/gu/reid_baseline_review/Opensource_test/market1501/Experiment-all-tricks-256x128-bs16x4-warmup10-erase0_5-labelsmooth_on-laststride1-bnneck_on" 75 | 76 | Pre_Index_DIR: "/root/person_search/multi-personReid/pre_index_dir/cuhk_pre_index.json" 77 | 78 | 79 | -------------------------------------------------------------------------------- /configs/softmax_triplet_with_center.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/haoluo/.torch/models/resnet50-19c8e357.pth' 4 | METRIC_LOSS_TYPE: 'triplet_center' 5 | IF_LABELSMOOTH: 'on' 6 | IF_WITH_CENTER: 'yes' 7 | 8 | 9 | 10 | 11 | INPUT: 12 | SIZE_TRAIN: [256, 128] 13 | SIZE_TEST: [256, 128] 14 | PROB: 0.5 # random horizontal flip 15 | RE_PROB: 0.5 # random erasing 16 | PADDING: 10 17 | 18 | DATASETS: 19 | NAMES: ('market1501') 20 | 21 | DATALOADER: 22 | SAMPLER: 'softmax_triplet' 23 | NUM_INSTANCE: 4 24 | NUM_WORKERS: 8 25 | 26 | SOLVER: 27 | OPTIMIZER_NAME: 'Adam' 28 | MAX_EPOCHS: 120 29 | BASE_LR: 0.00035 30 | 31 | CLUSTER_MARGIN: 0.3 32 | 33 | CENTER_LR: 0.5 34 | CENTER_LOSS_WEIGHT: 0.0005 35 | 36 | RANGE_K: 2 37 | RANGE_MARGIN: 0.3 38 | RANGE_ALPHA: 0 39 | RANGE_BETA: 1 40 | RANGE_LOSS_WEIGHT: 1 41 | 42 | BIAS_LR_FACTOR: 1 43 | WEIGHT_DECAY: 0.0005 44 | WEIGHT_DECAY_BIAS: 0.0005 45 | IMS_PER_BATCH: 64 46 | 47 | STEPS: [40, 70] 48 | GAMMA: 0.1 49 | 50 | WARMUP_FACTOR: 0.01 51 | WARMUP_ITERS: 10 52 | WARMUP_METHOD: 'linear' 53 | 54 | CHECKPOINT_PERIOD: 40 55 | LOG_PERIOD: 20 56 | EVAL_PERIOD: 40 57 | 58 | TEST: 59 | IMS_PER_BATCH: 128 60 | RE_RANKING: 'no' 61 | WEIGHT: "path" 62 | NECK_FEAT: 'after' 63 | FEAT_NORM: 'yes' 64 | 65 | OUTPUT_DIR: "/home/haoluo/log/gu/reid_baseline_review/Opensource_test/market1501/Experiment-all-tricks-tri_center-256x128-bs16x4-warmup10-erase0_5-labelsmooth_on-laststride1-bnneck_on-triplet_centerloss0_0005" 66 | 67 | 68 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | from .build import make_data_loader, make_data_loader_train, make_data_loader_val 8 | -------------------------------------------------------------------------------- /data/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/X-BrainLab/PI-ReID/76755f094fe6d6a7f195cd4aaaddc1b6beff8d0f/data/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /data/__pycache__/build.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/X-BrainLab/PI-ReID/76755f094fe6d6a7f195cd4aaaddc1b6beff8d0f/data/__pycache__/build.cpython-36.pyc -------------------------------------------------------------------------------- /data/__pycache__/collate_batch.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/X-BrainLab/PI-ReID/76755f094fe6d6a7f195cd4aaaddc1b6beff8d0f/data/__pycache__/collate_batch.cpython-36.pyc -------------------------------------------------------------------------------- /data/build.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | from torch.utils.data import DataLoader 8 | 9 | from .collate_batch import train_collate_fn, val_collate_fn, train_collate_fn_pair, val_collate_fn_pair, train_collate_fn_pair3 10 | from .datasets import init_dataset, ImageDataset, ImageDataset_pair, ImageDataset_pair_val, ImageDataset_pair3 11 | from .samplers import RandomIdentitySampler, RandomIdentitySampler_alignedreid # New add by gu 12 | from .transforms import build_transforms 13 | import json 14 | import numpy as np 15 | import random 16 | import os 17 | 18 | import time 19 | 20 | 21 | def multi_person_training_info_prw(train_anno, root): 22 | root = os.path.join(root, 'prw') 23 | path_gt = os.path.join(root, 'each_pid_info.json') 24 | 25 | with open(path_gt, 'r') as f: 26 | each_pid_info = json.load(f) 27 | # print(each_pid_info) 28 | path_hard = os.path.join(root, 'hard_gallery_train/gallery.json') 29 | 30 | with open(path_hard, 'r') as f: 31 | path_hard = json.load(f) 32 | # print(path_hard) 33 | path_hard_camera_id = os.path.join(root, 'hard_gallery_train/camera_id.json') 34 | with open(path_hard_camera_id, 'r') as f: 35 | path_hard_camera_id = json.load(f) 36 | # print(path_hard_camera_id) 37 | 38 | pairs_anno = [] 39 | for img, pids in path_hard.items(): 40 | camera_id = path_hard_camera_id[img] 41 | if len(pids) < 2: 42 | continue 43 | one_pair = [img] 44 | for index, pid in enumerate(pids): 45 | pid_info = each_pid_info[str(pid)] 46 | pid_info_camera_id = np.array(pid_info[0]) 47 | pos_index = np.where(pid_info_camera_id != camera_id)[0] 48 | if len(pos_index) == 0: 49 | continue 50 | query_img = pid_info[1][random.choice(pos_index)] 51 | one_pair = one_pair + [query_img, pid] 52 | 53 | one_pair = one_pair + [camera_id] 54 | if len(one_pair) > 5: 55 | second_pair = [one_pair[0], one_pair[3], one_pair[4], one_pair[1], one_pair[2], one_pair[5]] 56 | pairs_anno.append(one_pair) 57 | pairs_anno.append(second_pair) 58 | # print(len(pairs_anno)) 59 | anno_save_path = os.path.join(root, "pair_pos_unary" + str(train_anno) + ".json") 60 | with open(anno_save_path, 'w+') as f: 61 | json.dump(pairs_anno, f) 62 | 63 | def multi_person_training_info_cuhk(train_anno, root): 64 | root = os.path.join(root, 'cuhk') 65 | path_gt = os.path.join(root, 'each_pid_info.json') 66 | with open(path_gt, 'r') as f: 67 | each_pid_info = json.load(f) 68 | # print(each_pid_info) 69 | 70 | path_hard = os.path.join(root, 'hard_gallery_train/gallery.json') 71 | with open(path_hard, 'r') as f: 72 | path_hard = json.load(f) 73 | # print(path_hard) 74 | 75 | path_hard_camera_id = os.path.join(root, 'hard_gallery_train/camera_id.json') 76 | with open(path_hard_camera_id, 'r') as f: 77 | path_hard_camera_id = json.load(f) 78 | # print(path_hard_camera_id) 79 | 80 | 81 | pairs_anno = [] 82 | count2 = 0 83 | for img, pids in path_hard.items(): 84 | # camera_id = path_hard_camera_id[img] 85 | if len(pids) < 2: 86 | continue 87 | count2+=1 88 | # else: 89 | # continue 90 | one_pair = [img] 91 | camera_id = 0 92 | for index, pid in enumerate(pids): 93 | pid_info = 
each_pid_info[str(pid)] 94 | # pid_info_camera_id = np.array(pid_info[0]) 95 | # pos_index = np.where(pid_info_camera_id != camera_id)[0] 96 | # if len(pos_index) == 0: 97 | # continue 98 | # query_img = pid_info[1][random.choice(pos_index)] 99 | query_img = random.choice(pid_info[1]) 100 | one_pair = one_pair + [query_img, pid] 101 | 102 | one_pair = one_pair + [camera_id] 103 | if len(one_pair) > 5: 104 | second_pair = [one_pair[0], one_pair[3], one_pair[4], one_pair[1], one_pair[2], one_pair[5]] 105 | pairs_anno.append(one_pair) 106 | pairs_anno.append(second_pair) 107 | 108 | anno_save_path = os.path.join(root, "pair_pos_unary" + str(train_anno) + ".json") 109 | with open(anno_save_path, 'w+') as f: 110 | json.dump(pairs_anno, f) 111 | 112 | def make_data_loader(cfg): 113 | train_transforms = build_transforms(cfg, is_train=True) 114 | val_transforms = build_transforms(cfg, is_train=False) 115 | num_workers = cfg.DATALOADER.NUM_WORKERS 116 | if len(cfg.DATASETS.NAMES) == 1: 117 | dataset = init_dataset(cfg.DATASETS.NAMES, root=cfg.DATASETS.ROOT_DIR) 118 | else: 119 | # TODO: add multi dataset to train 120 | dataset = init_dataset(cfg.DATASETS.NAMES, root=cfg.DATASETS.ROOT_DIR) 121 | 122 | num_classes = dataset.num_train_pids 123 | train_set = ImageDataset(dataset.train, train_transforms) 124 | if cfg.DATALOADER.SAMPLER == 'softmax': 125 | train_loader = DataLoader( 126 | train_set, batch_size=cfg.SOLVER.IMS_PER_BATCH, shuffle=True, num_workers=num_workers, 127 | collate_fn=train_collate_fn 128 | ) 129 | else: 130 | train_loader = DataLoader( 131 | train_set, batch_size=cfg.SOLVER.IMS_PER_BATCH, 132 | sampler=RandomIdentitySampler(dataset.train, cfg.SOLVER.IMS_PER_BATCH, cfg.DATALOADER.NUM_INSTANCE), 133 | # sampler=RandomIdentitySampler_alignedreid(dataset.train, cfg.DATALOADER.NUM_INSTANCE), # new add by gu 134 | num_workers=num_workers, collate_fn=train_collate_fn 135 | ) 136 | 137 | val_set = ImageDataset(dataset.query + dataset.gallery, val_transforms) 138 | val_loader = DataLoader( 139 | val_set, batch_size=cfg.TEST.IMS_PER_BATCH, shuffle=False, num_workers=num_workers, 140 | collate_fn=val_collate_fn 141 | ) 142 | return train_loader, val_loader, len(dataset.query), num_classes 143 | 144 | def make_data_loader_train(cfg): 145 | # multi_person_training_info2(cfg.DATASETS.TRAIN_ANNO) 146 | 147 | if "cuhk" in cfg.DATASETS.NAMES: 148 | multi_person_training_info_cuhk(cfg.DATASETS.TRAIN_ANNO, cfg.DATASETS.ROOT_DIR) 149 | else: 150 | multi_person_training_info_prw(cfg.DATASETS.TRAIN_ANNO, cfg.DATASETS.ROOT_DIR) 151 | 152 | train_transforms = build_transforms(cfg, is_train=True) 153 | val_transforms = build_transforms(cfg, is_train=False) 154 | num_workers = cfg.DATALOADER.NUM_WORKERS 155 | if len(cfg.DATASETS.NAMES) == 1: 156 | dataset = init_dataset(cfg.DATASETS.NAMES, root=cfg.DATASETS.ROOT_DIR, train_anno=cfg.DATASETS.TRAIN_ANNO) 157 | else: 158 | # TODO: add multi dataset to train 159 | dataset = init_dataset(cfg.DATASETS.NAMES, root=cfg.DATASETS.ROOT_DIR, train_anno=cfg.DATASETS.TRAIN_ANNO) 160 | 161 | train_set = ImageDataset_pair3(dataset.train, train_transforms) 162 | num_classes = dataset.num_train_pids 163 | 164 | if cfg.DATALOADER.SAMPLER == 'softmax': 165 | train_loader = DataLoader( 166 | train_set, batch_size=cfg.SOLVER.IMS_PER_BATCH, shuffle=True, num_workers=num_workers, 167 | collate_fn=train_collate_fn_pair3 168 | ) 169 | else: 170 | train_loader = DataLoader( 171 | train_set, batch_size=cfg.SOLVER.IMS_PER_BATCH, 172 | sampler=RandomIdentitySampler(dataset.train, 
cfg.SOLVER.IMS_PER_BATCH, cfg.DATALOADER.NUM_INSTANCE), 173 | # sampler=RandomIdentitySampler_alignedreid(dataset.train, cfg.DATALOADER.NUM_INSTANCE), # new add by gu 174 | num_workers=num_workers, collate_fn=train_collate_fn_pair3 175 | ) 176 | 177 | val_set = ImageDataset(dataset.query + dataset.gallery, val_transforms) 178 | val_loader = DataLoader( 179 | val_set, batch_size=cfg.TEST.IMS_PER_BATCH, shuffle=False, num_workers=num_workers, 180 | collate_fn=val_collate_fn 181 | ) 182 | 183 | return train_loader, val_loader, len(dataset.query), num_classes 184 | 185 | def make_data_loader_val(cfg, index, dataset): 186 | 187 | indice_path = cfg.Pre_Index_DIR 188 | with open(indice_path, 'r') as f: 189 | indices = json.load(f) 190 | indice = indices[index][:100] 191 | 192 | val_transforms = build_transforms(cfg, is_train=False) 193 | num_workers = cfg.DATALOADER.NUM_WORKERS 194 | 195 | # if len(cfg.DATASETS.NAMES) == 1: 196 | # dataset = init_dataset(cfg.DATASETS.NAMES, root=cfg.DATASETS.ROOT_DIR) 197 | # else: 198 | # # TODO: add multi dataset to train 199 | # print(cfg.DATASETS.NAMES) 200 | # dataset = init_dataset(cfg.DATASETS.NAMES, root=cfg.DATASETS.ROOT_DIR) 201 | 202 | query = dataset.query[index] 203 | gallery = [dataset.gallery[ind] for ind in indice] 204 | gallery = [query] + gallery 205 | 206 | val_set = ImageDataset_pair_val(query, gallery, val_transforms) 207 | 208 | val_loader = DataLoader( 209 | val_set, batch_size=cfg.TEST.IMS_PER_BATCH, shuffle=False, num_workers=num_workers, 210 | collate_fn=val_collate_fn_pair 211 | ) 212 | 213 | return val_loader 214 | 215 | 216 | -------------------------------------------------------------------------------- /data/collate_batch.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import torch 8 | 9 | 10 | def train_collate_fn(batch): 11 | imgs, pids, _, _, = zip(*batch) 12 | pids = torch.tensor(pids, dtype=torch.int64) 13 | return torch.stack(imgs, dim=0), pids 14 | 15 | def train_collate_fn_pair(batch): 16 | imgs_query, img_gallery, pids, _, _ , _, pids2, pos_neg = zip(*batch) 17 | pids = torch.tensor(pids, dtype=torch.int64) 18 | pids2 = torch.tensor(pids2, dtype=torch.int64) 19 | pos_neg = torch.FloatTensor(pos_neg) 20 | # pos_neg = torch.tensor(pos_neg) 21 | 22 | return torch.stack(imgs_query, dim=0), torch.stack(img_gallery, dim=0), pids, pids2, pos_neg 23 | 24 | def train_collate_fn_pair3(batch): 25 | img_gallery, imgs_query1, pids1, imgs_query2, pids2, _ = zip(*batch) 26 | pids1 = torch.tensor(pids1, dtype=torch.int64) 27 | pids2 = torch.tensor(pids2, dtype=torch.int64) 28 | return torch.stack(img_gallery, dim=0), torch.stack(imgs_query1, dim=0), torch.stack(imgs_query2, dim=0), pids1, pids2 29 | 30 | def val_collate_fn(batch): 31 | imgs, pids, camids, _ = zip(*batch) 32 | return torch.stack(imgs, dim=0), pids, camids 33 | 34 | def val_collate_fn_pair(batch): 35 | imgs_query, imgs_gallery, pids, camids, _ , _, is_first = zip(*batch) 36 | return torch.stack(imgs_query, dim=0), torch.stack(imgs_gallery, dim=0), pids, camids, is_first 37 | -------------------------------------------------------------------------------- /data/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | # from .cuhk03 import CUHK03 7 | from .dukemtmcreid import 
DukeMTMCreID 8 | from .market1501 import Market1501 9 | from .msmt17 import MSMT17 10 | from .veri import VeRi 11 | from .dataset_loader import ImageDataset, ImageDataset_pair, ImageDataset_pair_val, ImageDataset_pair3 12 | from .prw import PRW 13 | from .cuhk import CUHK 14 | 15 | __factory = { 16 | 'market1501': Market1501, 17 | # 'cuhk03': CUHK03, 18 | 'dukemtmc': DukeMTMCreID, 19 | 'msmt17': MSMT17, 20 | 'veri': VeRi, 21 | 'prw': PRW, 22 | 'cuhk':CUHK 23 | } 24 | 25 | 26 | def get_names(): 27 | return __factory.keys() 28 | 29 | 30 | def init_dataset(name, *args, **kwargs): 31 | if name not in __factory.keys(): 32 | raise KeyError("Unknown datasets: {}".format(name)) 33 | return __factory[name](*args, **kwargs) 34 | -------------------------------------------------------------------------------- /data/datasets/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/X-BrainLab/PI-ReID/76755f094fe6d6a7f195cd4aaaddc1b6beff8d0f/data/datasets/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/bases.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/X-BrainLab/PI-ReID/76755f094fe6d6a7f195cd4aaaddc1b6beff8d0f/data/datasets/__pycache__/bases.cpython-36.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/cuhk.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/X-BrainLab/PI-ReID/76755f094fe6d6a7f195cd4aaaddc1b6beff8d0f/data/datasets/__pycache__/cuhk.cpython-36.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/dataset_loader.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/X-BrainLab/PI-ReID/76755f094fe6d6a7f195cd4aaaddc1b6beff8d0f/data/datasets/__pycache__/dataset_loader.cpython-36.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/dukemtmcreid.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/X-BrainLab/PI-ReID/76755f094fe6d6a7f195cd4aaaddc1b6beff8d0f/data/datasets/__pycache__/dukemtmcreid.cpython-36.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/eval_reid.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/X-BrainLab/PI-ReID/76755f094fe6d6a7f195cd4aaaddc1b6beff8d0f/data/datasets/__pycache__/eval_reid.cpython-36.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/market1501.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/X-BrainLab/PI-ReID/76755f094fe6d6a7f195cd4aaaddc1b6beff8d0f/data/datasets/__pycache__/market1501.cpython-36.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/msmt17.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/X-BrainLab/PI-ReID/76755f094fe6d6a7f195cd4aaaddc1b6beff8d0f/data/datasets/__pycache__/msmt17.cpython-36.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/prw.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/X-BrainLab/PI-ReID/76755f094fe6d6a7f195cd4aaaddc1b6beff8d0f/data/datasets/__pycache__/prw.cpython-36.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/veri.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/X-BrainLab/PI-ReID/76755f094fe6d6a7f195cd4aaaddc1b6beff8d0f/data/datasets/__pycache__/veri.cpython-36.pyc -------------------------------------------------------------------------------- /data/datasets/bases.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import numpy as np 8 | 9 | 10 | class BaseDataset(object): 11 | """ 12 | Base class of reid dataset 13 | """ 14 | 15 | def get_imagedata_info(self, data): 16 | pids, cams = [], [] 17 | for _, pid, camid in data: 18 | pids += [pid] 19 | cams += [camid] 20 | pids = set(pids) 21 | cams = set(cams) 22 | num_pids = len(pids) 23 | num_cams = len(cams) 24 | num_imgs = len(data) 25 | return num_pids, num_imgs, num_cams 26 | 27 | def get_videodata_info(self, data, return_tracklet_stats=False): 28 | pids, cams, tracklet_stats = [], [], [] 29 | for img_paths, pid, camid in data: 30 | pids += [pid] 31 | cams += [camid] 32 | tracklet_stats += [len(img_paths)] 33 | pids = set(pids) 34 | cams = set(cams) 35 | num_pids = len(pids) 36 | num_cams = len(cams) 37 | num_tracklets = len(data) 38 | if return_tracklet_stats: 39 | return num_pids, num_tracklets, num_cams, tracklet_stats 40 | return num_pids, num_tracklets, num_cams 41 | 42 | def print_dataset_statistics(self): 43 | raise NotImplementedError 44 | 45 | 46 | class BaseImageDataset(BaseDataset): 47 | """ 48 | Base class of image reid dataset 49 | """ 50 | 51 | def print_dataset_statistics(self, train, query, gallery): 52 | num_train_pids, num_train_imgs, num_train_cams = self.get_imagedata_info(train) 53 | num_query_pids, num_query_imgs, num_query_cams = self.get_imagedata_info(query) 54 | num_gallery_pids, num_gallery_imgs, num_gallery_cams = self.get_imagedata_info(gallery) 55 | 56 | print("Dataset statistics:") 57 | print(" ----------------------------------------") 58 | print(" subset | # ids | # images | # cameras") 59 | print(" ----------------------------------------") 60 | print(" train | {:5d} | {:8d} | {:9d}".format(num_train_pids, num_train_imgs, num_train_cams)) 61 | print(" query | {:5d} | {:8d} | {:9d}".format(num_query_pids, num_query_imgs, num_query_cams)) 62 | print(" gallery | {:5d} | {:8d} | {:9d}".format(num_gallery_pids, num_gallery_imgs, num_gallery_cams)) 63 | print(" ----------------------------------------") 64 | 65 | 66 | class BaseVideoDataset(BaseDataset): 67 | """ 68 | Base class of video reid dataset 69 | """ 70 | 71 | def print_dataset_statistics(self, train, query, gallery): 72 | num_train_pids, num_train_tracklets, num_train_cams, train_tracklet_stats = \ 73 | self.get_videodata_info(train, return_tracklet_stats=True) 74 | 75 | num_query_pids, num_query_tracklets, num_query_cams, query_tracklet_stats = \ 76 
| self.get_videodata_info(query, return_tracklet_stats=True) 77 | 78 | num_gallery_pids, num_gallery_tracklets, num_gallery_cams, gallery_tracklet_stats = \ 79 | self.get_videodata_info(gallery, return_tracklet_stats=True) 80 | 81 | tracklet_stats = train_tracklet_stats + query_tracklet_stats + gallery_tracklet_stats 82 | min_num = np.min(tracklet_stats) 83 | max_num = np.max(tracklet_stats) 84 | avg_num = np.mean(tracklet_stats) 85 | 86 | print("Dataset statistics:") 87 | print(" -------------------------------------------") 88 | print(" subset | # ids | # tracklets | # cameras") 89 | print(" -------------------------------------------") 90 | print(" train | {:5d} | {:11d} | {:9d}".format(num_train_pids, num_train_tracklets, num_train_cams)) 91 | print(" query | {:5d} | {:11d} | {:9d}".format(num_query_pids, num_query_tracklets, num_query_cams)) 92 | print(" gallery | {:5d} | {:11d} | {:9d}".format(num_gallery_pids, num_gallery_tracklets, num_gallery_cams)) 93 | print(" -------------------------------------------") 94 | print(" number of images per tracklet: {} ~ {}, average {:.2f}".format(min_num, max_num, avg_num)) 95 | print(" -------------------------------------------") 96 | -------------------------------------------------------------------------------- /data/datasets/cuhk.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import re 3 | 4 | import os.path as osp 5 | 6 | from .bases import BaseImageDataset 7 | import warnings 8 | import json 9 | import cv2 10 | from tqdm import tqdm 11 | 12 | import random 13 | import numpy as np 14 | import os 15 | 16 | import time 17 | 18 | 19 | class CUHK(BaseImageDataset): 20 | """CUHK-SYSU (person search) dataset, used here as the pedestrian-interference benchmark. 21 | 22 | Reference: 23 | Xiao et al. Joint Detection and Identification Feature Learning for Person Search. CVPR 2017. 24 | 25 | URL: `<https://github.com/ShuangLI59/person_search>`_ 26 | 27 | Loads the query/gallery JSON annotations and the generated 28 | pair_pos_unary<N>.json training pairs (see data/build.py). 29 |
30 | """ 31 | _junk_pids = [0, -1] 32 | dataset_dir = '' 33 | dataset_url = 'http://188.138.127.15:81/Datasets/Market-1501-v15.09.15.zip' 34 | 35 | def __init__(self, root='datasets', market1501_500k=False, train_anno=1, **kwargs): 36 | 37 | # root = "/root/person_search/dataset/multi_person" 38 | self.root = os.path.join(root, 'cuhk') 39 | self.train_anno = train_anno 40 | 41 | self.pid_container = set() 42 | 43 | self.gallery_id = [] 44 | 45 | # train = self.process_dir("train", relabel=True) 46 | train = self.process_dir_train(relabel=True) 47 | query = self.process_dir("query", relabel=False) 48 | gallery = self.process_dir("gallery", relabel=False) 49 | 50 | query = sorted(query) 51 | gallery = sorted(gallery) 52 | 53 | self.train = train 54 | self.query = query 55 | self.gallery = gallery 56 | # 57 | 58 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info_train(self.train) 59 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 60 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info_gallery(self.gallery) 61 | 62 | print("Dataset statistics:") 63 | print(" ----------------------------------------") 64 | print(" subset | # ids | # images | # cameras") 65 | print(" ----------------------------------------") 66 | print(" train | {:5d} | {:8d} | {:9d}".format(self.num_train_pids, self.num_train_imgs, self.num_train_cams)) 67 | print(" query | {:5d} | {:8d} | {:9d}".format(self.num_query_pids, self.num_query_imgs, self.num_query_cams)) 68 | print(" gallery | {:5d} | {:8d} | {:9d}".format(self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams)) 69 | print(" ----------------------------------------") 70 | 71 | 72 | def get_imagedata_info_train(self, data): 73 | 74 | pids, cams = [], [] 75 | for _, _, pid, camid, pid2, pos_neg in data: 76 | pids += [pid] 77 | pids += [pid2] 78 | cams += [camid] 79 | pids = set(pids) 80 | cams = set(cams) 81 | num_pids = len(pids) 82 | num_cams = len(cams) 83 | num_imgs = len(data) 84 | return num_pids, num_imgs, num_cams 85 | 86 | def get_imagedata_info_gallery(self, data): 87 | pids, cams = [], [] 88 | for _, pid, camid in data: 89 | if isinstance(pid, list): 90 | for one_pid in pid: 91 | pids += [one_pid] 92 | cams += [camid] 93 | pids = set(pids) 94 | cams = set(cams) 95 | num_pids = len(pids) 96 | num_cams = len(cams) 97 | num_imgs = len(data) 98 | return num_pids, num_imgs, num_cams 99 | 100 | def process_dir_train(self, relabel=True): 101 | # root = "/root/person_search/dataset/person_search/cuhk" 102 | anno_path = osp.join(self.root, "gt_training_box.json") 103 | with open(anno_path, 'r+') as f: 104 | all_anno = json.load(f) 105 | 106 | pid_container = set() 107 | for img_name, pid in all_anno.items(): 108 | pid_container.add(int(pid)) 109 | # print(pid_container) 110 | # print("pid_container: " + str(len(pid_container))) 111 | pid2label = {int(pid): label for label, pid in enumerate(pid_container)} 112 | # print(pid2label) 113 | # print("pid_container: " + str(len(pid_container))) 114 | 115 | 116 | new_anno_path = osp.join(self.root, "pair_pos_unary" + str(self.train_anno) + ".json") 117 | with open(new_anno_path, 'r+') as f: 118 | all_anno = json.load(f) 119 | data = [] 120 | 121 | # img_root1 = "/root/person_search/dataset/multi_person/cuhk/hard_gallery_train/image" 122 | # img_root2 = "/root/person_search/dataset/multi_person/cuhk/train_gt/image" 123 | 124 | img_root1 = os.path.join(self.root, 
'hard_gallery_train/image') 125 | img_root2 = os.path.join(self.root, 'train_gt/image') 126 | 127 | file_index = 0 128 | for one_pair in all_anno: 129 | hard_imgname = one_pair[0] 130 | query_train_imgname1 = one_pair[1] 131 | pid1 = one_pair[2] 132 | query_train_imgname2 = one_pair[3] 133 | pid2 = one_pair[4] 134 | camera_id = one_pair[5] 135 | if relabel: 136 | pid1 = pid2label[pid1] 137 | pid2 = pid2label[pid2] 138 | hard_imgname_path = osp.join(img_root1, hard_imgname) 139 | query_train_path1 = osp.join(img_root2, query_train_imgname1) 140 | query_train_path2 = osp.join(img_root2, query_train_imgname2) 141 | new_anno = [hard_imgname_path, query_train_path1, pid1, query_train_path2, pid2, camera_id] 142 | # print(new_anno) 143 | data.append(new_anno) 144 | 145 | return data 146 | 147 | def process_dir(self, dataset, relabel=False): 148 | 149 | if dataset == "query": 150 | anno_path = osp.join(self.root, "query", "query.json") 151 | img_root = osp.join(self.root, "query", "query_image") 152 | elif dataset == "gallery": 153 | gallery_name = "hard_gallery_test" 154 | anno_path = osp.join(self.root, gallery_name, "gallery.json") 155 | img_root = osp.join(self.root, gallery_name, "image") 156 | 157 | with open(anno_path, 'r+') as f: 158 | all_anno = json.load(f) 159 | 160 | valid_pid_path = os.path.join(self.root, 'valid_q_pid.json') 161 | with open(valid_pid_path, 'r+') as f: 162 | valid_pid = json.load(f) 163 | 164 | 165 | data = [] 166 | for img_name, pid in all_anno.items(): 167 | image_path = osp.join(img_root, img_name) 168 | if dataset == "query": 169 | camid = 1 170 | elif dataset == "gallery": 171 | camid = 2 172 | if isinstance(pid, str): 173 | pid = int(pid) 174 | if dataset == "query": 175 | if pid not in valid_pid: 176 | continue 177 | data.append((image_path, pid, int(camid))) 178 | return data 179 | 180 | -------------------------------------------------------------------------------- /data/datasets/cuhk03.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: liaoxingyu2@jd.com 5 | """ 6 | 7 | import h5py 8 | import os.path as osp 9 | from scipy.io import loadmat 10 | from scipy.misc import imsave 11 | 12 | from utils.iotools import mkdir_if_missing, write_json, read_json 13 | from .bases import BaseImageDataset 14 | 15 | 16 | class CUHK03(BaseImageDataset): 17 | """ 18 | CUHK03 19 | Reference: 20 | Li et al. DeepReID: Deep Filter Pairing Neural Network for Person Re-identification. CVPR 2014. 21 | URL: http://www.ee.cuhk.edu.hk/~xgwang/CUHK_identification.html#! 
22 | 23 | Dataset statistics: 24 | # identities: 1360 25 | # images: 13164 26 | # cameras: 6 27 | # splits: 20 (classic) 28 | Args: 29 | split_id (int): split index (default: 0) 30 | cuhk03_labeled (bool): whether to load labeled images; if false, detected images are loaded (default: False) 31 | """ 32 | dataset_dir = 'cuhk03' 33 | 34 | def __init__(self, root='/home/haoluo/data', split_id=0, cuhk03_labeled=False, 35 | cuhk03_classic_split=False, verbose=True, 36 | **kwargs): 37 | super(CUHK03, self).__init__() 38 | self.dataset_dir = osp.join(root, self.dataset_dir) 39 | self.data_dir = osp.join(self.dataset_dir, 'cuhk03_release') 40 | self.raw_mat_path = osp.join(self.data_dir, 'cuhk-03.mat') 41 | 42 | self.imgs_detected_dir = osp.join(self.dataset_dir, 'images_detected') 43 | self.imgs_labeled_dir = osp.join(self.dataset_dir, 'images_labeled') 44 | 45 | self.split_classic_det_json_path = osp.join(self.dataset_dir, 'splits_classic_detected.json') 46 | self.split_classic_lab_json_path = osp.join(self.dataset_dir, 'splits_classic_labeled.json') 47 | 48 | self.split_new_det_json_path = osp.join(self.dataset_dir, 'splits_new_detected.json') 49 | self.split_new_lab_json_path = osp.join(self.dataset_dir, 'splits_new_labeled.json') 50 | 51 | self.split_new_det_mat_path = osp.join(self.dataset_dir, 'cuhk03_new_protocol_config_detected.mat') 52 | self.split_new_lab_mat_path = osp.join(self.dataset_dir, 'cuhk03_new_protocol_config_labeled.mat') 53 | 54 | self._check_before_run() 55 | self._preprocess() 56 | 57 | if cuhk03_labeled: 58 | image_type = 'labeled' 59 | split_path = self.split_classic_lab_json_path if cuhk03_classic_split else self.split_new_lab_json_path 60 | else: 61 | image_type = 'detected' 62 | split_path = self.split_classic_det_json_path if cuhk03_classic_split else self.split_new_det_json_path 63 | 64 | splits = read_json(split_path) 65 | assert split_id < len(splits), "Condition split_id ({}) < len(splits) ({}) is false".format(split_id, 66 | len(splits)) 67 | split = splits[split_id] 68 | print("Split index = {}".format(split_id)) 69 | 70 | train = split['train'] 71 | query = split['query'] 72 | gallery = split['gallery'] 73 | 74 | if verbose: 75 | print("=> CUHK03 ({}) loaded".format(image_type)) 76 | self.print_dataset_statistics(train, query, gallery) 77 | 78 | self.train = train 79 | self.query = query 80 | self.gallery = gallery 81 | 82 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 83 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 84 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 85 | 86 | def _check_before_run(self): 87 | """Check if all files are available before going deeper""" 88 | if not osp.exists(self.dataset_dir): 89 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 90 | if not osp.exists(self.data_dir): 91 | raise RuntimeError("'{}' is not available".format(self.data_dir)) 92 | if not osp.exists(self.raw_mat_path): 93 | raise RuntimeError("'{}' is not available".format(self.raw_mat_path)) 94 | if not osp.exists(self.split_new_det_mat_path): 95 | raise RuntimeError("'{}' is not available".format(self.split_new_det_mat_path)) 96 | if not osp.exists(self.split_new_lab_mat_path): 97 | raise RuntimeError("'{}' is not available".format(self.split_new_lab_mat_path)) 98 | 99 | def _preprocess(self): 100 | """ 101 | This function is a bit complex and ugly, what it does is 102 | 1. 
Extract data from cuhk-03.mat and save as png images. 103 | 2. Create 20 classic splits. (Li et al. CVPR'14) 104 | 3. Create new split. (Zhong et al. CVPR'17) 105 | """ 106 | print( 107 | "Note: if root path is changed, the previously generated json files need to be re-generated (delete them first)") 108 | if osp.exists(self.imgs_labeled_dir) and \ 109 | osp.exists(self.imgs_detected_dir) and \ 110 | osp.exists(self.split_classic_det_json_path) and \ 111 | osp.exists(self.split_classic_lab_json_path) and \ 112 | osp.exists(self.split_new_det_json_path) and \ 113 | osp.exists(self.split_new_lab_json_path): 114 | return 115 | 116 | mkdir_if_missing(self.imgs_detected_dir) 117 | mkdir_if_missing(self.imgs_labeled_dir) 118 | 119 | print("Extract image data from {} and save as png".format(self.raw_mat_path)) 120 | mat = h5py.File(self.raw_mat_path, 'r') 121 | 122 | def _deref(ref): 123 | return mat[ref][:].T 124 | 125 | def _process_images(img_refs, campid, pid, save_dir): 126 | img_paths = [] # Note: some persons only have images for one view 127 | for imgid, img_ref in enumerate(img_refs): 128 | img = _deref(img_ref) 129 | # skip empty cell 130 | if img.size == 0 or img.ndim < 3: continue 131 | # images are saved with the following format, index-1 (ensure uniqueness) 132 | # campid: index of camera pair (1-5) 133 | # pid: index of person in 'campid'-th camera pair 134 | # viewid: index of view, {1, 2} 135 | # imgid: index of image, (1-10) 136 | viewid = 1 if imgid < 5 else 2 137 | img_name = '{:01d}_{:03d}_{:01d}_{:02d}.png'.format(campid + 1, pid + 1, viewid, imgid + 1) 138 | img_path = osp.join(save_dir, img_name) 139 | if not osp.isfile(img_path): 140 | imsave(img_path, img) 141 | img_paths.append(img_path) 142 | return img_paths 143 | 144 | def _extract_img(name): 145 | print("Processing {} images (extract and save) ...".format(name)) 146 | meta_data = [] 147 | imgs_dir = self.imgs_detected_dir if name == 'detected' else self.imgs_labeled_dir 148 | for campid, camp_ref in enumerate(mat[name][0]): 149 | camp = _deref(camp_ref) 150 | num_pids = camp.shape[0] 151 | for pid in range(num_pids): 152 | img_paths = _process_images(camp[pid, :], campid, pid, imgs_dir) 153 | assert len(img_paths) > 0, "campid{}-pid{} has no images".format(campid, pid) 154 | meta_data.append((campid + 1, pid + 1, img_paths)) 155 | print("- done camera pair {} with {} identities".format(campid + 1, num_pids)) 156 | return meta_data 157 | 158 | meta_detected = _extract_img('detected') 159 | meta_labeled = _extract_img('labeled') 160 | 161 | def _extract_classic_split(meta_data, test_split): 162 | train, test = [], [] 163 | num_train_pids, num_test_pids = 0, 0 164 | num_train_imgs, num_test_imgs = 0, 0 165 | for i, (campid, pid, img_paths) in enumerate(meta_data): 166 | 167 | if [campid, pid] in test_split: 168 | for img_path in img_paths: 169 | camid = int(osp.basename(img_path).split('_')[2]) - 1 # make it 0-based 170 | test.append((img_path, num_test_pids, camid)) 171 | num_test_pids += 1 172 | num_test_imgs += len(img_paths) 173 | else: 174 | for img_path in img_paths: 175 | camid = int(osp.basename(img_path).split('_')[2]) - 1 # make it 0-based 176 | train.append((img_path, num_train_pids, camid)) 177 | num_train_pids += 1 178 | num_train_imgs += len(img_paths) 179 | return train, num_train_pids, num_train_imgs, test, num_test_pids, num_test_imgs 180 | 181 | print("Creating classic splits (# = 20) ...") 182 | splits_classic_det, splits_classic_lab = [], [] 183 | for split_ref in mat['testsets'][0]: 184 | 
test_split = _deref(split_ref).tolist() 185 | 186 | # create split for detected images 187 | train, num_train_pids, num_train_imgs, test, num_test_pids, num_test_imgs = \ 188 | _extract_classic_split(meta_detected, test_split) 189 | splits_classic_det.append({ 190 | 'train': train, 'query': test, 'gallery': test, 191 | 'num_train_pids': num_train_pids, 'num_train_imgs': num_train_imgs, 192 | 'num_query_pids': num_test_pids, 'num_query_imgs': num_test_imgs, 193 | 'num_gallery_pids': num_test_pids, 'num_gallery_imgs': num_test_imgs, 194 | }) 195 | 196 | # create split for labeled images 197 | train, num_train_pids, num_train_imgs, test, num_test_pids, num_test_imgs = \ 198 | _extract_classic_split(meta_labeled, test_split) 199 | splits_classic_lab.append({ 200 | 'train': train, 'query': test, 'gallery': test, 201 | 'num_train_pids': num_train_pids, 'num_train_imgs': num_train_imgs, 202 | 'num_query_pids': num_test_pids, 'num_query_imgs': num_test_imgs, 203 | 'num_gallery_pids': num_test_pids, 'num_gallery_imgs': num_test_imgs, 204 | }) 205 | 206 | write_json(splits_classic_det, self.split_classic_det_json_path) 207 | write_json(splits_classic_lab, self.split_classic_lab_json_path) 208 | 209 | def _extract_set(filelist, pids, pid2label, idxs, img_dir, relabel): 210 | tmp_set = [] 211 | unique_pids = set() 212 | for idx in idxs: 213 | img_name = filelist[idx][0] 214 | camid = int(img_name.split('_')[2]) - 1 # make it 0-based 215 | pid = pids[idx] 216 | if relabel: pid = pid2label[pid] 217 | img_path = osp.join(img_dir, img_name) 218 | tmp_set.append((img_path, int(pid), camid)) 219 | unique_pids.add(pid) 220 | return tmp_set, len(unique_pids), len(idxs) 221 | 222 | def _extract_new_split(split_dict, img_dir): 223 | train_idxs = split_dict['train_idx'].flatten() - 1 # index-0 224 | pids = split_dict['labels'].flatten() 225 | train_pids = set(pids[train_idxs]) 226 | pid2label = {pid: label for label, pid in enumerate(train_pids)} 227 | query_idxs = split_dict['query_idx'].flatten() - 1 228 | gallery_idxs = split_dict['gallery_idx'].flatten() - 1 229 | filelist = split_dict['filelist'].flatten() 230 | train_info = _extract_set(filelist, pids, pid2label, train_idxs, img_dir, relabel=True) 231 | query_info = _extract_set(filelist, pids, pid2label, query_idxs, img_dir, relabel=False) 232 | gallery_info = _extract_set(filelist, pids, pid2label, gallery_idxs, img_dir, relabel=False) 233 | return train_info, query_info, gallery_info 234 | 235 | print("Creating new splits for detected images (767/700) ...") 236 | train_info, query_info, gallery_info = _extract_new_split( 237 | loadmat(self.split_new_det_mat_path), 238 | self.imgs_detected_dir, 239 | ) 240 | splits = [{ 241 | 'train': train_info[0], 'query': query_info[0], 'gallery': gallery_info[0], 242 | 'num_train_pids': train_info[1], 'num_train_imgs': train_info[2], 243 | 'num_query_pids': query_info[1], 'num_query_imgs': query_info[2], 244 | 'num_gallery_pids': gallery_info[1], 'num_gallery_imgs': gallery_info[2], 245 | }] 246 | write_json(splits, self.split_new_det_json_path) 247 | 248 | print("Creating new splits for labeled images (767/700) ...") 249 | train_info, query_info, gallery_info = _extract_new_split( 250 | loadmat(self.split_new_lab_mat_path), 251 | self.imgs_labeled_dir, 252 | ) 253 | splits = [{ 254 | 'train': train_info[0], 'query': query_info[0], 'gallery': gallery_info[0], 255 | 'num_train_pids': train_info[1], 'num_train_imgs': train_info[2], 256 | 'num_query_pids': query_info[1], 'num_query_imgs': query_info[2], 257 | 
'num_gallery_pids': gallery_info[1], 'num_gallery_imgs': gallery_info[2], 258 | }] 259 | write_json(splits, self.split_new_lab_json_path) 260 | -------------------------------------------------------------------------------- /data/datasets/dataset_loader.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import os.path as osp 8 | from PIL import Image 9 | from torch.utils.data import Dataset 10 | 11 | 12 | def read_image(img_path): 13 | """Keep reading image until succeed. 14 | This can avoid IOError incurred by heavy IO process.""" 15 | got_img = False 16 | if not osp.exists(img_path): 17 | raise IOError("{} does not exist".format(img_path)) 18 | while not got_img: 19 | try: 20 | img = Image.open(img_path).convert('RGB') 21 | got_img = True 22 | except IOError: 23 | print("IOError incurred when reading '{}'. Will redo. Don't worry. Just chill.".format(img_path)) 24 | pass 25 | return img 26 | 27 | 28 | class ImageDataset(Dataset): 29 | """Image Person ReID Dataset""" 30 | 31 | def __init__(self, dataset, transform=None): 32 | self.dataset = dataset 33 | self.transform = transform 34 | 35 | def __len__(self): 36 | return len(self.dataset) 37 | 38 | def __getitem__(self, index): 39 | img_path, pid, camid = self.dataset[index] 40 | img = read_image(img_path) 41 | 42 | if self.transform is not None: 43 | img = self.transform(img) 44 | 45 | return img, pid, camid, img_path 46 | 47 | class ImageDataset_pair(Dataset): 48 | """Image Person ReID Dataset""" 49 | 50 | def __init__(self, dataset, transform=None): 51 | self.dataset = dataset 52 | self.transform = transform 53 | 54 | def __len__(self): 55 | return len(self.dataset) 56 | 57 | def __getitem__(self, index): 58 | query_path, gallery_path, pid, camid, pid2, pos_neg = self.dataset[index] 59 | query_img = read_image(query_path) 60 | gallery_img = read_image(gallery_path) 61 | 62 | if self.transform is not None: 63 | query_img = self.transform(query_img) 64 | gallery_img = self.transform(gallery_img) 65 | 66 | return query_img, gallery_img, pid, camid, query_path, gallery_path, pid2, pos_neg 67 | 68 | class ImageDataset_pair3(Dataset): 69 | """Image Person ReID Dataset""" 70 | 71 | def __init__(self, dataset, transform=None): 72 | self.dataset = dataset 73 | self.transform = transform 74 | 75 | def __len__(self): 76 | return len(self.dataset) 77 | 78 | def __getitem__(self, index): 79 | gallery_path, query_path1, pid1, query_path2, pid2, camera_id = self.dataset[index] 80 | query_img1 = read_image(query_path1) 81 | query_img2 = read_image(query_path2) 82 | gallery_img = read_image(gallery_path) 83 | 84 | if self.transform is not None: 85 | query_img1 = self.transform(query_img1) 86 | query_img2 = self.transform(query_img2) 87 | gallery_img = self.transform(gallery_img) 88 | 89 | return gallery_img, query_img1, pid1, query_img2, pid2, camera_id 90 | 91 | 92 | class ImageDataset_pair_val(Dataset): 93 | """Image Person ReID Dataset""" 94 | 95 | def __init__(self, query, gallery, transform=None): 96 | self.query = query 97 | self.gallery = gallery 98 | self.transform = transform 99 | 100 | def __len__(self): 101 | return len(self.gallery) 102 | 103 | def __getitem__(self, index): 104 | 105 | query_path, pid, camid = self.query 106 | gallery_path, pid, camid = self.gallery[index] 107 | 108 | is_first = query_path == gallery_path 109 | 110 | query_img = read_image(query_path) 111 | gallery_img = 
read_image(gallery_path) 112 | 113 | if self.transform is not None: 114 | query_img = self.transform(query_img) 115 | gallery_img = self.transform(gallery_img) 116 | 117 | return query_img, gallery_img, pid, camid, query_path, gallery_path, is_first 118 | -------------------------------------------------------------------------------- /data/datasets/dukemtmcreid.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: liaoxingyu2@jd.com 5 | """ 6 | 7 | import glob 8 | import re 9 | import urllib 10 | import zipfile 11 | 12 | import os.path as osp 13 | 14 | from utils.iotools import mkdir_if_missing 15 | from .bases import BaseImageDataset 16 | import json 17 | 18 | 19 | class DukeMTMCreID(BaseImageDataset): 20 | """ 21 | DukeMTMC-reID 22 | Reference: 23 | 1. Ristani et al. Performance Measures and a Data Set for Multi-Target, Multi-Camera Tracking. ECCVW 2016. 24 | 2. Zheng et al. Unlabeled Samples Generated by GAN Improve the Person Re-identification Baseline in vitro. ICCV 2017. 25 | URL: https://github.com/layumi/DukeMTMC-reID_evaluation 26 | 27 | Dataset statistics: 28 | # identities: 1404 (train + query) 29 | # images:16522 (train) + 2228 (query) + 17661 (gallery) 30 | # cameras: 8 31 | """ 32 | dataset_dir = 'dukemtmc-reid' 33 | 34 | def __init__(self, root='/home/haoluo/data', train_anno = 1, verbose=True, **kwargs): 35 | super(DukeMTMCreID, self).__init__() 36 | self.dataset_dir = root 37 | self.dataset_url = 'http://vision.cs.duke.edu/DukeMTMC/data/misc/DukeMTMC-reID.zip' 38 | self.train_dir = osp.join(self.dataset_dir, 'DukeMTMC-reID/bounding_box_train') 39 | self.query_dir = osp.join(self.dataset_dir, 'DukeMTMC-reID/query') 40 | self.gallery_dir = osp.join(self.dataset_dir, 'DukeMTMC-reID/bounding_box_test') 41 | 42 | # self._download_data() 43 | # self._check_before_run() 44 | 45 | self.root = "/raid/home/henrayzhao/person_search/dataset/multi_person/prw" 46 | # self.multi_person_training_info2() 47 | self.train_anno = train_anno 48 | 49 | train = self.process_dir_train(relabel=True) 50 | query = self._process_dir(self.query_dir, relabel=False) 51 | gallery = self._process_dir(self.gallery_dir, relabel=False) 52 | 53 | # if verbose: 54 | # print("=> DukeMTMC-reID loaded") 55 | # self.print_dataset_statistics(train, query, gallery) 56 | 57 | self.train = train 58 | self.query = query 59 | self.gallery = gallery 60 | 61 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info_train(self.train) 62 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 63 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 64 | 65 | def _download_data(self): 66 | if osp.exists(self.dataset_dir): 67 | print("This dataset has been downloaded.") 68 | return 69 | 70 | print("Creating directory {}".format(self.dataset_dir)) 71 | mkdir_if_missing(self.dataset_dir) 72 | fpath = osp.join(self.dataset_dir, osp.basename(self.dataset_url)) 73 | 74 | print("Downloading DukeMTMC-reID dataset") 75 | urllib.request.urlretrieve(self.dataset_url, fpath) 76 | 77 | print("Extracting files") 78 | zip_ref = zipfile.ZipFile(fpath, 'r') 79 | zip_ref.extractall(self.dataset_dir) 80 | zip_ref.close() 81 | 82 | def _check_before_run(self): 83 | """Check if all files are available before going deeper""" 84 | if not osp.exists(self.dataset_dir): 85 | raise RuntimeError("'{}' is not 
available".format(self.dataset_dir)) 86 | if not osp.exists(self.train_dir): 87 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 88 | if not osp.exists(self.query_dir): 89 | raise RuntimeError("'{}' is not available".format(self.query_dir)) 90 | if not osp.exists(self.gallery_dir): 91 | raise RuntimeError("'{}' is not available".format(self.gallery_dir)) 92 | 93 | def _process_dir(self, dir_path, relabel=False): 94 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 95 | pattern = re.compile(r'([-\d]+)_c(\d)') 96 | 97 | pid_container = set() 98 | for img_path in img_paths: 99 | pid, _ = map(int, pattern.search(img_path).groups()) 100 | pid_container.add(pid) 101 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 102 | 103 | dataset = [] 104 | for img_path in img_paths: 105 | pid, camid = map(int, pattern.search(img_path).groups()) 106 | assert 1 <= camid <= 8 107 | camid -= 1 # index starts from 0 108 | if relabel: pid = pid2label[pid] 109 | dataset.append((img_path, pid, camid)) 110 | if 'query' in dir_path and len(dataset) >= 300: 111 | break 112 | 113 | return dataset 114 | 115 | def get_imagedata_info_train(self, data): 116 | 117 | pids, cams = [], [] 118 | for _, _, pid, camid, pid2, pos_neg in data: 119 | pids += [pid] 120 | pids += [pid2] 121 | cams += [camid] 122 | pids = set(pids) 123 | cams = set(cams) 124 | num_pids = len(pids) 125 | num_cams = len(cams) 126 | num_imgs = len(data) 127 | return num_pids, num_imgs, num_cams 128 | 129 | def process_dir_train(self, relabel=True): 130 | root = "/raid/home/henrayzhao/person_search/dataset/person_search/prw" 131 | anno_path = osp.join(root, "training_box", "training_box.json") 132 | with open(anno_path, 'r+') as f: 133 | all_anno = json.load(f) 134 | 135 | pid_container = set() 136 | for img_name, pid in all_anno.items(): 137 | pid_container.add(pid) 138 | pid2label = {int(pid): label for label, pid in enumerate(pid_container)} 139 | 140 | new_anno_path = osp.join(self.root, "pair_pos_unary" + str(self.train_anno) + ".json") 141 | with open(new_anno_path, 'r+') as f: 142 | all_anno = json.load(f) 143 | data = [] 144 | 145 | img_root1 = "/raid/home/henrayzhao/person_search/dataset/multi_person/prw/hard_gallery_train/image" 146 | img_root2 = "/raid/home/henrayzhao/person_search/dataset/multi_person/prw/train_gt/image" 147 | 148 | for one_pair in all_anno: 149 | # print(one_pair) 150 | hard_imgname = one_pair[0] 151 | query_train_imgname1 = one_pair[1] 152 | pid1 = one_pair[2] 153 | query_train_imgname2 = one_pair[3] 154 | pid2 = one_pair[4] 155 | camera_id = one_pair[5] 156 | if relabel: 157 | pid1 = pid2label[pid1] 158 | pid2 = pid2label[pid2] 159 | hard_imgname_path = osp.join(img_root1, hard_imgname) 160 | query_train_path1 = osp.join(img_root2, query_train_imgname1) 161 | query_train_path2 = osp.join(img_root2, query_train_imgname2) 162 | new_anno = [hard_imgname_path, query_train_path1, pid1, query_train_path2, pid2, camera_id] 163 | data.append(new_anno) 164 | return data 165 | -------------------------------------------------------------------------------- /data/datasets/eval_reid.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import numpy as np 8 | import json 9 | import os 10 | 11 | 12 | def process_g_pids(q_pid, g_pid_lists): 13 | g_pids = [] 14 | for g_pid_list in g_pid_lists: 15 | if len(g_pid_list) <= 1: 16 | g_pids.append(g_pid_list[0]) 17 
| else: 18 | if q_pid in g_pid_list: 19 | g_pids.append(q_pid) 20 | else: 21 | g_pids.append(g_pid_list[0]) 22 | return np.array(g_pids) 23 | 24 | 25 | def eval_func(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=50): 26 | 27 | # print(list(g_pids)) 28 | 29 | """Evaluation with market1501 metric 30 | Key: for each query identity, its gallery images from the same camera view are discarded. 31 | """ 32 | num_q, num_g = distmat.shape 33 | if num_g < max_rank: 34 | max_rank = num_g 35 | print("Note: number of gallery samples is quite small, got {}".format(num_g)) 36 | indices = np.argsort(distmat, axis=1) 37 | 38 | # compute cmc curve for each query 39 | all_cmc = [] 40 | all_AP = [] 41 | 42 | flag = 0 43 | 44 | if not isinstance(g_pids[0], (int, str)): 45 | list_g_pids = g_pids 46 | flag = 1 47 | 48 | num_valid_q = 0. # number of valid query 49 | 50 | q_pid_return = -88 51 | 52 | for q_idx in range(num_q): 53 | # get query pid and camid 54 | q_pid = q_pids[q_idx] 55 | q_camid = q_camids[q_idx] 56 | 57 | # print(flag) 58 | if flag == 1: 59 | g_pids = process_g_pids(q_pid, list_g_pids) 60 | matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32) 61 | 62 | # remove gallery samples that have the same pid and camid with query 63 | order = indices[q_idx] 64 | remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid) 65 | keep = np.invert(remove) 66 | 67 | # compute cmc curve 68 | # binary vector, positions with value 1 are correct matches 69 | orig_cmc = matches[q_idx][keep] 70 | if not np.any(orig_cmc): 71 | # this condition is true when query identity does not appear in gallery 72 | continue 73 | 74 | cmc = orig_cmc.cumsum() 75 | cmc[cmc > 1] = 1 76 | all_cmc.append(cmc[:max_rank]) 77 | num_valid_q += 1. 78 | 79 | # compute average precision 80 | # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision 81 | num_rel = orig_cmc.sum() 82 | tmp_cmc = orig_cmc.cumsum() 83 | tmp_cmc = [x / (i + 1.) for i, x in enumerate(tmp_cmc)] 84 | tmp_cmc = np.asarray(tmp_cmc) * orig_cmc 85 | AP = tmp_cmc.sum() / num_rel 86 | all_AP.append(AP) 87 | q_pid_return = q_pid 88 | 89 | if num_valid_q == 0: 90 | return -1, -1, q_pid_return 91 | assert num_valid_q > 0, "Error: all query identities do not appear in gallery" 92 | 93 | all_cmc = np.asarray(all_cmc).astype(np.float32) 94 | 95 | all_cmc = all_cmc.sum(0) / num_valid_q 96 | 97 | mAP = np.mean(all_AP) 98 | 99 | return all_cmc, mAP, q_pid_return 100 | -------------------------------------------------------------------------------- /data/datasets/market1501.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import glob 8 | import re 9 | 10 | import os.path as osp 11 | 12 | from .bases import BaseImageDataset 13 | import json 14 | 15 | 16 | class Market1501(BaseImageDataset): 17 | """ 18 | Market1501 19 | Reference: 20 | Zheng et al. Scalable Person Re-identification: A Benchmark. ICCV 2015. 
21 | URL: http://www.liangzheng.org/Project/project_reid.html 22 | 23 | Dataset statistics: 24 | # identities: 1501 (+1 for background) 25 | # images: 12936 (train) + 3368 (query) + 15913 (gallery) 26 | """ 27 | dataset_dir = 'market1501' 28 | 29 | def __init__(self, root='/home/haoluo/data', train_anno=1, verbose=True, **kwargs): 30 | super(Market1501, self).__init__() 31 | self.dataset_dir = osp.join(root, self.dataset_dir) 32 | self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train') 33 | self.query_dir = osp.join(self.dataset_dir, 'query') 34 | self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test') 35 | 36 | # self._check_before_run() 37 | 38 | self.root = "/raid/home/henrayzhao/person_search/dataset/multi_person/prw" 39 | # self.multi_person_training_info2() 40 | self.train_anno = train_anno 41 | 42 | 43 | train = self.process_dir_train(relabel=True) 44 | query = self._process_dir(self.query_dir, relabel=False) 45 | gallery = self._process_dir(self.gallery_dir, relabel=False) 46 | 47 | # if verbose: 48 | # print("=> Market1501 loaded") 49 | # self.print_dataset_statistics(train, query, gallery) 50 | 51 | self.train = train 52 | self.query = query 53 | self.gallery = gallery 54 | 55 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info_train(self.train) 56 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 57 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 58 | # 59 | # def _check_before_run(self): 60 | # """Check if all files are available before going deeper""" 61 | # if not osp.exists(self.dataset_dir): 62 | # raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 63 | # if not osp.exists(self.train_dir): 64 | # raise RuntimeError("'{}' is not available".format(self.train_dir)) 65 | # if not osp.exists(self.query_dir): 66 | # raise RuntimeError("'{}' is not available".format(self.query_dir)) 67 | # if not osp.exists(self.gallery_dir): 68 | # raise RuntimeError("'{}' is not available".format(self.gallery_dir)) 69 | 70 | 71 | def get_imagedata_info_train(self, data): 72 | 73 | pids, cams = [], [] 74 | for _, _, pid, camid, pid2, pos_neg in data: 75 | pids += [pid] 76 | pids += [pid2] 77 | cams += [camid] 78 | pids = set(pids) 79 | cams = set(cams) 80 | num_pids = len(pids) 81 | num_cams = len(cams) 82 | num_imgs = len(data) 83 | return num_pids, num_imgs, num_cams 84 | 85 | def _process_dir(self, dir_path, relabel=False): 86 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 87 | pattern = re.compile(r'([-\d]+)_c(\d)') 88 | 89 | pid_container = set() 90 | for img_path in img_paths: 91 | pid, _ = map(int, pattern.search(img_path).groups()) 92 | if pid == -1: continue # junk images are just ignored 93 | pid_container.add(pid) 94 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 95 | 96 | dataset = [] 97 | for img_path in img_paths: 98 | pid, camid = map(int, pattern.search(img_path).groups()) 99 | if pid == -1: continue # junk images are just ignored 100 | assert 0 <= pid <= 1501 # pid == 0 means background 101 | assert 1 <= camid <= 6 102 | camid -= 1 # index starts from 0 103 | if relabel: pid = pid2label[pid] 104 | dataset.append((img_path, pid, camid)) 105 | if 'query' in dir_path and len(dataset) >= 300: 106 | break 107 | 108 | return dataset 109 | 110 | def process_dir_train(self, relabel=True): 111 | root = "/raid/home/henrayzhao/person_search/dataset/person_search/prw" 112 | 
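# training_box.json maps image names to raw person ids; those ids are
# collected first so they can be relabeled to contiguous training labels
# via pid2label. pair_pos_unary<train_anno>.json then lists sextuples
# [hard_gallery_img, query_img1, pid1, query_img2, pid2, camera_id],
# whose image names are resolved against the hard-coded PRW roots below.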
anno_path = osp.join(root, "training_box", "training_box.json") 113 | with open(anno_path, 'r+') as f: 114 | all_anno = json.load(f) 115 | 116 | pid_container = set() 117 | for img_name, pid in all_anno.items(): 118 | pid_container.add(pid) 119 | pid2label = {int(pid): label for label, pid in enumerate(pid_container)} 120 | 121 | new_anno_path = osp.join(self.root, "pair_pos_unary" + str(self.train_anno) + ".json") 122 | with open(new_anno_path, 'r+') as f: 123 | all_anno = json.load(f) 124 | data = [] 125 | 126 | img_root1 = "/raid/home/henrayzhao/person_search/dataset/multi_person/prw/hard_gallery_train/image" 127 | img_root2 = "/raid/home/henrayzhao/person_search/dataset/multi_person/prw/train_gt/image" 128 | 129 | for one_pair in all_anno: 130 | # print(one_pair) 131 | hard_imgname = one_pair[0] 132 | query_train_imgname1 = one_pair[1] 133 | pid1 = one_pair[2] 134 | query_train_imgname2 = one_pair[3] 135 | pid2 = one_pair[4] 136 | camera_id = one_pair[5] 137 | if relabel: 138 | pid1 = pid2label[pid1] 139 | pid2 = pid2label[pid2] 140 | hard_imgname_path = osp.join(img_root1, hard_imgname) 141 | query_train_path1 = osp.join(img_root2, query_train_imgname1) 142 | query_train_path2 = osp.join(img_root2, query_train_imgname2) 143 | new_anno = [hard_imgname_path, query_train_path1, pid1, query_train_path2, pid2, camera_id] 144 | data.append(new_anno) 145 | return data 146 | -------------------------------------------------------------------------------- /data/datasets/msmt17.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2019/1/17 15:00 4 | # @Author : Hao Luo 5 | # @File : msmt17.py 6 | 7 | import glob 8 | import re 9 | 10 | import os.path as osp 11 | 12 | from .bases import BaseImageDataset 13 | 14 | 15 | class MSMT17(BaseImageDataset): 16 | """ 17 | MSMT17 18 | 19 | Reference: 20 | Wei et al. Person Transfer GAN to Bridge Domain Gap for Person Re-Identification. CVPR 2018. 
21 | 22 | URL: http://www.pkuvmc.com/publications/msmt17.html 23 | 24 | Dataset statistics: 25 | # identities: 4101 26 | # images: 32621 (train) + 11659 (query) + 82161 (gallery) 27 | # cameras: 15 28 | """ 29 | dataset_dir = 'msmt17' 30 | 31 | def __init__(self,root='/home/haoluo/data', verbose=True, **kwargs): 32 | super(MSMT17, self).__init__() 33 | self.dataset_dir = osp.join(root, self.dataset_dir) 34 | self.train_dir = osp.join(self.dataset_dir, 'MSMT17_V2/mask_train_v2') 35 | self.test_dir = osp.join(self.dataset_dir, 'MSMT17_V2/mask_test_v2') 36 | self.list_train_path = osp.join(self.dataset_dir, 'MSMT17_V2/list_train.txt') 37 | self.list_val_path = osp.join(self.dataset_dir, 'MSMT17_V2/list_val.txt') 38 | self.list_query_path = osp.join(self.dataset_dir, 'MSMT17_V2/list_query.txt') 39 | self.list_gallery_path = osp.join(self.dataset_dir, 'MSMT17_V2/list_gallery.txt') 40 | 41 | self._check_before_run() 42 | train = self._process_dir(self.train_dir, self.list_train_path) 43 | #val, num_val_pids, num_val_imgs = self._process_dir(self.train_dir, self.list_val_path) 44 | query = self._process_dir(self.test_dir, self.list_query_path) 45 | gallery = self._process_dir(self.test_dir, self.list_gallery_path) 46 | if verbose: 47 | print("=> MSMT17 loaded") 48 | self.print_dataset_statistics(train, query, gallery) 49 | 50 | self.train = train 51 | self.query = query 52 | self.gallery = gallery 53 | 54 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 55 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 56 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 57 | 58 | def _check_before_run(self): 59 | """Check if all files are available before going deeper""" 60 | if not osp.exists(self.dataset_dir): 61 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 62 | if not osp.exists(self.train_dir): 63 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 64 | if not osp.exists(self.test_dir): 65 | raise RuntimeError("'{}' is not available".format(self.test_dir)) 66 | 67 | def _process_dir(self, dir_path, list_path): 68 | with open(list_path, 'r') as txt: 69 | lines = txt.readlines() 70 | dataset = [] 71 | pid_container = set() 72 | for img_idx, img_info in enumerate(lines): 73 | img_path, pid = img_info.split(' ') 74 | pid = int(pid) # no need to relabel 75 | camid = int(img_path.split('_')[2]) 76 | img_path = osp.join(dir_path, img_path) 77 | dataset.append((img_path, pid, camid)) 78 | pid_container.add(pid) 79 | 80 | # check if pid starts from 0 and increments with 1 81 | for idx, pid in enumerate(pid_container): 82 | assert idx == pid, "See code comment for explanation" 83 | return dataset -------------------------------------------------------------------------------- /data/datasets/prw.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import re 3 | 4 | import os.path as osp 5 | 6 | from .bases import BaseImageDataset 7 | import warnings 8 | import json 9 | import cv2 10 | from tqdm import tqdm 11 | import json 12 | import random 13 | import numpy as np 14 | import os 15 | 16 | class PRW(BaseImageDataset): 17 | """Market1501. 18 | 19 | Reference: 20 | Zheng et al. Scalable Person Re-identification: A Benchmark. ICCV 2015. 21 | 22 | URL: ``_ 23 | 24 | Dataset statistics: 25 | - identities: 1501 (+1 for background). 
26 | - images: 12936 (train) + 3368 (query) + 15913 (gallery). 27 | """ 28 | _junk_pids = [0, -1] 29 | dataset_dir = '' 30 | dataset_url = 'http://188.138.127.15:81/Datasets/Market-1501-v15.09.15.zip' 31 | 32 | def __init__(self, root='datasets', market1501_500k=False, train_anno=1, **kwargs): 33 | 34 | # root = "/root/person_search/dataset/multi_person" 35 | self.root = osp.join(root, 'prw') 36 | 37 | # self.root = "/root/person_search/dataset/multi_person/prw" 38 | self.train_anno = train_anno 39 | self.pid_container = set() 40 | 41 | self.gallery_id = [] 42 | 43 | # train = self.process_dir("train", relabel=True) 44 | train = self.process_dir_train(relabel=True) 45 | query = self.process_dir("query", relabel=False) 46 | gallery = self.process_dir("gallery", relabel=False) 47 | 48 | query = sorted(query) 49 | gallery = sorted(gallery) 50 | 51 | # print(query) 52 | # print(len(query)) 53 | 54 | 55 | self.train = train 56 | self.query = query 57 | self.gallery = gallery 58 | # 59 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info_train(self.train) 60 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 61 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info_gallery(self.gallery) 62 | 63 | print("Dataset statistics:") 64 | print(" ----------------------------------------") 65 | print(" subset | # ids | # images | # cameras") 66 | print(" ----------------------------------------") 67 | print( 68 | " train | {:5d} | {:8d} | {:9d}".format(self.num_train_pids, self.num_train_imgs, self.num_train_cams)) 69 | print( 70 | " query | {:5d} | {:8d} | {:9d}".format(self.num_query_pids, self.num_query_imgs, self.num_query_cams)) 71 | print(" gallery | {:5d} | {:8d} | {:9d}".format(self.num_gallery_pids, self.num_gallery_imgs, 72 | self.num_gallery_cams)) 73 | print(" ----------------------------------------") 74 | 75 | 76 | def get_imagedata_info_train(self, data): 77 | 78 | pids, cams = [], [] 79 | for _, _, pid, camid, pid2, pos_neg in data: 80 | pids += [pid] 81 | pids += [pid2] 82 | cams += [camid] 83 | pids = set(pids) 84 | cams = set(cams) 85 | num_pids = len(pids) 86 | num_cams = len(cams) 87 | num_imgs = len(data) 88 | return num_pids, num_imgs, num_cams 89 | 90 | def get_imagedata_info_gallery(self, data): 91 | pids, cams = [], [] 92 | for _, pid, camid in data: 93 | if isinstance(pid, list): 94 | for one_pid in pid: 95 | pids += [one_pid] 96 | cams += [camid] 97 | pids = set(pids) 98 | cams = set(cams) 99 | num_pids = len(pids) 100 | num_cams = len(cams) 101 | num_imgs = len(data) 102 | return num_pids, num_imgs, num_cams 103 | 104 | def process_dir_train(self, relabel=True): 105 | # # root = "/root/person_search/dataset/person_search/prw" 106 | anno_path = osp.join(self.root, "gt_training_box.json") 107 | 108 | with open(anno_path, 'r+') as f: 109 | all_anno = json.load(f) 110 | 111 | pid_container = set() 112 | for img_name, pid in all_anno.items(): 113 | pid_container.add(pid) 114 | pid2label = {int(pid): label for label, pid in enumerate(pid_container)} 115 | # print(pid2label) 116 | 117 | 118 | new_anno_path = osp.join(self.root, "pair_pos_unary" + str(self.train_anno) + ".json") 119 | with open(new_anno_path, 'r+') as f: 120 | all_anno = json.load(f) 121 | data = [] 122 | 123 | # img_root1 = "/root/person_search/dataset/multi_person/prw/hard_gallery_train/image" 124 | # img_root2 = "/root/person_search/dataset/multi_person/prw/train_gt/image" 125 | 126 | 
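# Unlike the Market1501/DukeMTMC variants of process_dir_train above, the
# image roots here are derived from self.root rather than hard-coded
# absolute paths (see the commented-out /root/... lines), which keeps the
# PRW loader portable across machines.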
img_root1 = os.path.join(self.root, 'hard_gallery_train/image') 127 | img_root2 = os.path.join(self.root, 'train_gt/image') 128 | 129 | for one_pair in all_anno: 130 | # print(one_pair) 131 | hard_imgname = one_pair[0] 132 | query_train_imgname1 = one_pair[1] 133 | pid1 = one_pair[2] 134 | query_train_imgname2 = one_pair[3] 135 | pid2 = one_pair[4] 136 | camera_id = one_pair[5] 137 | if relabel: 138 | pid1 = pid2label[pid1] 139 | pid2 = pid2label[pid2] 140 | hard_imgname_path = osp.join(img_root1, hard_imgname) 141 | query_train_path1 = osp.join(img_root2, query_train_imgname1) 142 | query_train_path2 = osp.join(img_root2, query_train_imgname2) 143 | new_anno = [hard_imgname_path, query_train_path1, pid1, query_train_path2, pid2, camera_id] 144 | data.append(new_anno) 145 | return data 146 | 147 | def process_dir(self, dataset, relabel=False): 148 | 149 | if dataset == "query": 150 | anno_path = osp.join(self.root, "query", "query.json") 151 | img_root = osp.join(self.root, "query", "query_image") 152 | camid_path = osp.join(self.root, "query", "camera_id.json") 153 | elif dataset == "gallery": 154 | gallery_name = "hard_gallery_test" 155 | anno_path = osp.join(self.root, gallery_name, "gallery.json") 156 | img_root = osp.join(self.root, gallery_name, "image") 157 | camid_path = osp.join(self.root, gallery_name, "camera_id.json") 158 | 159 | with open(anno_path, 'r+') as f: 160 | all_anno = json.load(f) 161 | 162 | 163 | 164 | if dataset == "query" or dataset == "gallery": 165 | with open(camid_path, 'r+') as f: 166 | camid_dic = json.load(f) 167 | 168 | # valid_pid_path = "/root/person_search/dataset/multi_person/prw/valid_q_pid_3.json" 169 | valid_pid_path = os.path.join(self.root, 'valid_q_pid.json') 170 | 171 | with open(valid_pid_path, 'r+') as f: 172 | valid_pid = json.load(f) 173 | 174 | 175 | data = [] 176 | pid_set = set() 177 | for img_name, pid in all_anno.items(): 178 | image_path = osp.join(img_root, img_name) 179 | if dataset == "query": 180 | print({img_name: pid}) 181 | if dataset == "query" or dataset == "gallery": 182 | camid = camid_dic[img_name] 183 | if isinstance(pid, str): 184 | pid = int(pid) 185 | if dataset == "query": 186 | if pid in valid_pid: 187 | if pid in pid_set: 188 | continue 189 | else: 190 | pid_set.add(pid) 191 | else: 192 | continue 193 | 194 | data.append((image_path, pid, int(camid))) 195 | 196 | return data 197 | 198 | -------------------------------------------------------------------------------- /data/datasets/veri.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import re 3 | 4 | import os.path as osp 5 | 6 | from .bases import BaseImageDataset 7 | 8 | 9 | class VeRi(BaseImageDataset): 10 | """ 11 | VeRi-776 12 | Reference: 13 | Liu, Xinchen, et al. "Large-scale vehicle re-identification in urban surveillance videos." ICME 2016. 
14 | 15 | URL:https://vehiclereid.github.io/VeRi/ 16 | 17 | Dataset statistics: 18 | # identities: 776 19 | # images: 37778 (train) + 1678 (query) + 11579 (gallery) 20 | # cameras: 20 21 | """ 22 | 23 | dataset_dir = 'veri' 24 | 25 | def __init__(self, root='../', verbose=True, **kwargs): 26 | super(VeRi, self).__init__() 27 | self.dataset_dir = osp.join(root, self.dataset_dir) 28 | self.train_dir = osp.join(self.dataset_dir, 'image_train') 29 | self.query_dir = osp.join(self.dataset_dir, 'image_query') 30 | self.gallery_dir = osp.join(self.dataset_dir, 'image_test') 31 | 32 | self._check_before_run() 33 | 34 | train = self._process_dir(self.train_dir, relabel=True) 35 | query = self._process_dir(self.query_dir, relabel=False) 36 | gallery = self._process_dir(self.gallery_dir, relabel=False) 37 | 38 | if verbose: 39 | print("=> VeRi-776 loaded") 40 | self.print_dataset_statistics(train, query, gallery) 41 | 42 | self.train = train 43 | self.query = query 44 | self.gallery = gallery 45 | 46 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 47 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 48 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 49 | 50 | def _check_before_run(self): 51 | """Check if all files are available before going deeper""" 52 | if not osp.exists(self.dataset_dir): 53 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 54 | if not osp.exists(self.train_dir): 55 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 56 | if not osp.exists(self.query_dir): 57 | raise RuntimeError("'{}' is not available".format(self.query_dir)) 58 | if not osp.exists(self.gallery_dir): 59 | raise RuntimeError("'{}' is not available".format(self.gallery_dir)) 60 | 61 | def _process_dir(self, dir_path, relabel=False): 62 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 63 | pattern = re.compile(r'([-\d]+)_c(\d+)') 64 | 65 | pid_container = set() 66 | for img_path in img_paths: 67 | pid, _ = map(int, pattern.search(img_path).groups()) 68 | if pid == -1: continue # junk images are just ignored 69 | pid_container.add(pid) 70 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 71 | 72 | dataset = [] 73 | for img_path in img_paths: 74 | pid, camid = map(int, pattern.search(img_path).groups()) 75 | if pid == -1: continue # junk images are just ignored 76 | assert 0 <= pid <= 776 # pid == 0 means background 77 | assert 1 <= camid <= 20 78 | camid -= 1 # index starts from 0 79 | if relabel: pid = pid2label[pid] 80 | dataset.append((img_path, pid, camid)) 81 | 82 | return dataset 83 | 84 | -------------------------------------------------------------------------------- /data/samplers/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | from .triplet_sampler import RandomIdentitySampler, RandomIdentitySampler_alignedreid # new add by gu 8 | -------------------------------------------------------------------------------- /data/samplers/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/X-BrainLab/PI-ReID/76755f094fe6d6a7f195cd4aaaddc1b6beff8d0f/data/samplers/__pycache__/__init__.cpython-36.pyc 
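The triplet_sampler.py module below implements P×K identity sampling: every batch draws batch_size // num_instances identities and num_instances images per identity, the layout that the batch-hard triplet loss expects. A minimal usage sketch follows; it is illustrative only — the dataset registry key 'prw', the batch sizes, and the bare transform are assumptions, not values taken from the repository's configs.

# Hypothetical usage sketch (not part of the repository).
# Assumes the PRW dataset is registered under the key 'prw' and that
# dataset.train holds the pair sextuples produced by process_dir_train,
# with the pid at index 2 — the position RandomIdentitySampler unpacks.
from torch.utils.data import DataLoader

from data.datasets import init_dataset
from data.datasets.dataset_loader import ImageDataset_pair3
from data.samplers import RandomIdentitySampler

dataset = init_dataset('prw', root='datasets')  # registry key assumed
sampler = RandomIdentitySampler(dataset.train, batch_size=64, num_instances=4)
train_set = ImageDataset_pair3(dataset.train, transform=None)  # pass build_transforms(cfg, is_train=True) in practice
loader = DataLoader(train_set, batch_size=64, sampler=sampler, num_workers=4)
# Each batch of 64 indices now covers 16 identities with 4 instances apiece.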
-------------------------------------------------------------------------------- /data/samplers/__pycache__/triplet_sampler.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/X-BrainLab/PI-ReID/76755f094fe6d6a7f195cd4aaaddc1b6beff8d0f/data/samplers/__pycache__/triplet_sampler.cpython-36.pyc -------------------------------------------------------------------------------- /data/samplers/triplet_sampler.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: liaoxingyu2@jd.com 5 | """ 6 | 7 | import copy 8 | import random 9 | import torch 10 | from collections import defaultdict 11 | 12 | import numpy as np 13 | from torch.utils.data.sampler import Sampler 14 | 15 | 16 | class RandomIdentitySampler(Sampler): 17 | """ 18 | Randomly sample N identities, then for each identity, 19 | randomly sample K instances, therefore batch size is N*K. 20 | Args: 21 | - data_source (list): list of (img_path, pid, camid). 22 | - num_instances (int): number of instances per identity in a batch. 23 | - batch_size (int): number of examples in a batch. 24 | """ 25 | 26 | def __init__(self, data_source, batch_size, num_instances): 27 | self.data_source = data_source 28 | self.batch_size = batch_size 29 | self.num_instances = num_instances 30 | self.num_pids_per_batch = self.batch_size // self.num_instances 31 | self.index_dic = defaultdict(list) 32 | # for index, (_, _, pid, _, _, _) in enumerate(self.data_source): 33 | for index, (_, _, pid, _, _, _) in enumerate(self.data_source): 34 | self.index_dic[pid].append(index) 35 | self.pids = list(self.index_dic.keys()) 36 | 37 | # estimate number of examples in an epoch 38 | self.length = 0 39 | for pid in self.pids: 40 | idxs = self.index_dic[pid] 41 | num = len(idxs) 42 | if num < self.num_instances: 43 | num = self.num_instances 44 | self.length += num - num % self.num_instances 45 | 46 | def __iter__(self): 47 | batch_idxs_dict = defaultdict(list) 48 | 49 | for pid in self.pids: 50 | idxs = copy.deepcopy(self.index_dic[pid]) 51 | if len(idxs) < self.num_instances: 52 | idxs = np.random.choice(idxs, size=self.num_instances, replace=True) 53 | random.shuffle(idxs) 54 | batch_idxs = [] 55 | for idx in idxs: 56 | batch_idxs.append(idx) 57 | if len(batch_idxs) == self.num_instances: 58 | batch_idxs_dict[pid].append(batch_idxs) 59 | batch_idxs = [] 60 | 61 | avai_pids = copy.deepcopy(self.pids) 62 | final_idxs = [] 63 | 64 | while len(avai_pids) >= self.num_pids_per_batch: 65 | selected_pids = random.sample(avai_pids, self.num_pids_per_batch) 66 | for pid in selected_pids: 67 | batch_idxs = batch_idxs_dict[pid].pop(0) 68 | final_idxs.extend(batch_idxs) 69 | if len(batch_idxs_dict[pid]) == 0: 70 | avai_pids.remove(pid) 71 | 72 | self.length = len(final_idxs) 73 | return iter(final_idxs) 74 | 75 | def __len__(self): 76 | return self.length 77 | 78 | 79 | # New add by gu 80 | class RandomIdentitySampler_alignedreid(Sampler): 81 | """ 82 | Randomly sample N identities, then for each identity, 83 | randomly sample K instances, therefore batch size is N*K. 84 | 85 | Code imported from https://github.com/Cysu/open-reid/blob/master/reid/utils/data/sampler.py. 86 | 87 | Args: 88 | data_source (Dataset): dataset to sample from. 89 | num_instances (int): number of instances per identity. 
90 | """ 91 | def __init__(self, data_source, num_instances): 92 | self.data_source = data_source 93 | self.num_instances = num_instances 94 | self.index_dic = defaultdict(list) 95 | for index, (_, pid, _) in enumerate(data_source): 96 | self.index_dic[pid].append(index) 97 | self.pids = list(self.index_dic.keys()) 98 | self.num_identities = len(self.pids) 99 | 100 | def __iter__(self): 101 | indices = torch.randperm(self.num_identities) 102 | ret = [] 103 | for i in indices: 104 | pid = self.pids[i] 105 | t = self.index_dic[pid] 106 | replace = False if len(t) >= self.num_instances else True 107 | t = np.random.choice(t, size=self.num_instances, replace=replace) 108 | ret.extend(t) 109 | return iter(ret) 110 | 111 | def __len__(self): 112 | return self.num_identities * self.num_instances 113 | -------------------------------------------------------------------------------- /data/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | from .build import build_transforms 8 | -------------------------------------------------------------------------------- /data/transforms/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/X-BrainLab/PI-ReID/76755f094fe6d6a7f195cd4aaaddc1b6beff8d0f/data/transforms/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /data/transforms/__pycache__/build.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/X-BrainLab/PI-ReID/76755f094fe6d6a7f195cd4aaaddc1b6beff8d0f/data/transforms/__pycache__/build.cpython-36.pyc -------------------------------------------------------------------------------- /data/transforms/__pycache__/transforms.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/X-BrainLab/PI-ReID/76755f094fe6d6a7f195cd4aaaddc1b6beff8d0f/data/transforms/__pycache__/transforms.cpython-36.pyc -------------------------------------------------------------------------------- /data/transforms/build.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: liaoxingyu2@jd.com 5 | """ 6 | 7 | import torchvision.transforms as T 8 | 9 | from .transforms import RandomErasing 10 | 11 | 12 | def build_transforms(cfg, is_train=True): 13 | normalize_transform = T.Normalize(mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD) 14 | if is_train: 15 | transform = T.Compose([ 16 | T.Resize(cfg.INPUT.SIZE_TRAIN), 17 | T.RandomHorizontalFlip(p=cfg.INPUT.PROB), 18 | T.Pad(cfg.INPUT.PADDING), 19 | T.RandomCrop(cfg.INPUT.SIZE_TRAIN), 20 | T.ToTensor(), 21 | normalize_transform, 22 | RandomErasing(probability=cfg.INPUT.RE_PROB, mean=cfg.INPUT.PIXEL_MEAN) 23 | ]) 24 | else: 25 | transform = T.Compose([ 26 | T.Resize(cfg.INPUT.SIZE_TEST), 27 | T.ToTensor(), 28 | normalize_transform 29 | ]) 30 | 31 | return transform 32 | -------------------------------------------------------------------------------- /data/transforms/transforms.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: liaoxingyu2@jd.com 5 | """ 6 | 7 | import math 8 
| import random 9 | 10 | 11 | class RandomErasing(object): 12 | """ Randomly selects a rectangle region in an image and erases its pixels. 13 | 'Random Erasing Data Augmentation' by Zhong et al. 14 | See https://arxiv.org/pdf/1708.04896.pdf 15 | Args: 16 | probability: The probability that the Random Erasing operation will be performed. 17 | sl: Minimum proportion of erased area against input image. 18 | sh: Maximum proportion of erased area against input image. 19 | r1: Minimum aspect ratio of erased area. 20 | mean: Erasing value. 21 | """ 22 | 23 | def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=(0.4914, 0.4822, 0.4465)): 24 | self.probability = probability 25 | self.mean = mean 26 | self.sl = sl 27 | self.sh = sh 28 | self.r1 = r1 29 | 30 | def __call__(self, img): 31 | 32 | if random.uniform(0, 1) >= self.probability: 33 | return img 34 | 35 | for attempt in range(100): 36 | area = img.size()[1] * img.size()[2] 37 | 38 | target_area = random.uniform(self.sl, self.sh) * area 39 | aspect_ratio = random.uniform(self.r1, 1 / self.r1) 40 | 41 | h = int(round(math.sqrt(target_area * aspect_ratio))) 42 | w = int(round(math.sqrt(target_area / aspect_ratio))) 43 | 44 | if w < img.size()[2] and h < img.size()[1]: 45 | x1 = random.randint(0, img.size()[1] - h) 46 | y1 = random.randint(0, img.size()[2] - w) 47 | if img.size()[0] == 3: 48 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0] 49 | img[1, x1:x1 + h, y1:y1 + w] = self.mean[1] 50 | img[2, x1:x1 + h, y1:y1 + w] = self.mean[2] 51 | else: 52 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0] 53 | return img 54 | 55 | return img 56 | -------------------------------------------------------------------------------- /engine/__pycache__/inference.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/X-BrainLab/PI-ReID/76755f094fe6d6a7f195cd4aaaddc1b6beff8d0f/engine/__pycache__/inference.cpython-36.pyc -------------------------------------------------------------------------------- /engine/__pycache__/trainer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/X-BrainLab/PI-ReID/76755f094fe6d6a7f195cd4aaaddc1b6beff8d0f/engine/__pycache__/trainer.cpython-36.pyc -------------------------------------------------------------------------------- /engine/inference.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | import logging 7 | 8 | import torch 9 | import torch.nn as nn 10 | from ignite.engine import Engine 11 | 12 | from utils.reid_metric import R1_mAP, R1_mAP_reranking 13 | 14 | 15 | def create_supervised_evaluator(model, metrics, 16 | device=None): 17 | """ 18 | Factory function for creating an evaluator for supervised models 19 | 20 | Args: 21 | model (`torch.nn.Module`): the model to train 22 | metrics (dict of str - :class:`ignite.metrics.Metric`): a map of metric names to Metrics 23 | device (str, optional): device type specification (default: None). 24 | Applies to both model and batches. 
25 | Returns: 26 | Engine: an evaluator engine with supervised inference function 27 | """ 28 | if device: 29 | if torch.cuda.device_count() > 1: 30 | model = nn.DataParallel(model) 31 | model.to(device) 32 | 33 | def _inference(engine, batch): 34 | model.eval() 35 | with torch.no_grad(): 36 | data, pids, camids = batch 37 | data = data.to(device) if torch.cuda.device_count() >= 1 else data 38 | feat = model(data) 39 | return feat, pids, camids 40 | 41 | engine = Engine(_inference) 42 | 43 | for name, metric in metrics.items(): 44 | metric.attach(engine, name) 45 | 46 | return engine 47 | 48 | 49 | def inference( 50 | cfg, 51 | model, 52 | val_loader, 53 | num_query 54 | ): 55 | device = cfg.MODEL.DEVICE 56 | 57 | logger = logging.getLogger("reid_baseline.inference") 58 | logger.info("Enter inferencing") 59 | if cfg.TEST.RE_RANKING == 'no': 60 | print("Create evaluator") 61 | evaluator = create_supervised_evaluator(model, metrics={'r1_mAP': R1_mAP(num_query, max_rank=100, feat_norm=cfg.TEST.FEAT_NORM)}, 62 | device=device) 63 | elif cfg.TEST.RE_RANKING == 'yes': 64 | print("Create evaluator for reranking") 65 | evaluator = create_supervised_evaluator(model, metrics={'r1_mAP': R1_mAP_reranking(num_query, max_rank=100, feat_norm=cfg.TEST.FEAT_NORM)}, 66 | device=device) 67 | else: 68 | print("Unsupported re_ranking config. Only support for no or yes, but got {}.".format(cfg.TEST.RE_RANKING)) 69 | 70 | evaluator.run(val_loader) 71 | cmc, mAP, _ = evaluator.state.metrics['r1_mAP'] 72 | logger.info('Validation Results') 73 | logger.info("mAP: {:.1%}".format(mAP)) 74 | for r in [1, 5, 10, 100]: 75 | logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(r, cmc[r - 1])) 76 | -------------------------------------------------------------------------------- /engine/trainer.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import logging 8 | 9 | import torch 10 | import torch.nn as nn 11 | from ignite.engine import Engine, Events 12 | from ignite.handlers import ModelCheckpoint, Timer 13 | from ignite.metrics import RunningAverage 14 | from data import make_data_loader_val 15 | 16 | from utils.reid_metric import R1_mAP, R1_mAP_pair 17 | from tqdm import tqdm 18 | import time 19 | import json 20 | import numpy as np 21 | from layers.triplet_loss import TripletLoss 22 | 23 | from data.datasets import init_dataset 24 | 25 | import copy 26 | 27 | import torch.nn.functional as F 28 | 29 | from data import make_data_loader, make_data_loader_train 30 | import random 31 | import os 32 | 33 | global ITER 34 | ITER = 0 35 | 36 | 37 | def euclidean_dist(gallery_feature1, gallery_feature2): 38 | 39 | xx = torch.pow(gallery_feature1, 2).sum(1, keepdim=True) 40 | yy = torch.pow(gallery_feature2, 2).sum(1, keepdim=True) 41 | dist1 = xx + yy 42 | dist2 = gallery_feature1 * gallery_feature2 43 | dist2 = dist2.sum(1, keepdim=True) 44 | dist = dist1 - 2 * dist2 45 | dist = dist.clamp(min=1e-12).sqrt() 46 | return dist 47 | 48 | def loss1(gallery_feature1, gallery_feature2, query_feature, margin=0.3): 49 | 50 | ranking_loss = nn.MarginRankingLoss(margin=margin) 51 | y = gallery_feature1.new((gallery_feature1.shape[0], 1)).fill_(1) 52 | dist_neg = euclidean_dist(gallery_feature1, gallery_feature2) 53 | dist_pos = euclidean_dist(gallery_feature1, query_feature) 54 | loss = ranking_loss(dist_neg, dist_pos, y) 55 | 56 | return loss 57 | 58 | 59 | def create_supervised_trainer(model, 
optimizer, loss_fn, 60 | device=None, gamma=1.0, margin=0.3, beta=1.0): 61 | """ 62 | Factory function for creating a trainer for supervised models 63 | 64 | Args: 65 | model (`torch.nn.Module`): the model to train 66 | optimizer (`torch.optim.Optimizer`): the optimizer to use 67 | loss_fn (torch.nn loss function): the loss function to use 68 | device (str, optional): device type specification (default: None). 69 | Applies to both model and batches. 70 | 71 | Returns: 72 | Engine: a trainer engine with supervised update function 73 | """ 74 | if device: 75 | if torch.cuda.device_count() > 1: 76 | model = nn.DataParallel(model) 77 | model.to(device) 78 | 79 | def _update(engine, batch): 80 | model.train() 81 | optimizer.zero_grad() 82 | # guiding, img, target, target2, pos_neg = batch 83 | 84 | img, guiding1, guiding2, target1, target2 = batch 85 | 86 | img = img.to(device) if torch.cuda.device_count() >= 1 else img 87 | 88 | guiding1 = guiding1.to(device) if torch.cuda.device_count() >= 1 else guiding1 89 | target1 = target1.to(device) if torch.cuda.device_count() >= 1 else target1 90 | 91 | guiding2 = guiding2.to(device) if torch.cuda.device_count() >= 1 else guiding2 92 | target2 = target2.to(device) if torch.cuda.device_count() >= 1 else target2 93 | 94 | # score, feat, score_guiding, feature_guiding, gallery_attention, score_pos_neg = model(guiding1, img, x_g2=guiding2) 95 | score, feat, score1, feat1, feat_query, score2, feat2 = model(guiding1, img, x_g2=guiding2) 96 | 97 | loss = loss_fn(score, feat, target1) + gamma * loss1(feat, feat1.detach(), feat_query, margin=margin) + beta * loss_fn(score2, feat, target1) 98 | 99 | loss.backward() 100 | optimizer.step() 101 | 102 | acc = (score.max(1)[1] == target1).float().mean() 103 | return loss.item(), acc.item() 104 | 105 | return Engine(_update) 106 | 107 | 108 | def create_supervised_trainer_with_center(model, center_criterion, optimizer, optimizer_center, loss_fn, cetner_loss_weight, 109 | device=None): 110 | """ 111 | Factory function for creating a trainer for supervised models 112 | 113 | Args: 114 | model (`torch.nn.Module`): the model to train 115 | optimizer (`torch.optim.Optimizer`): the optimizer to use 116 | loss_fn (torch.nn loss function): the loss function to use 117 | device (str, optional): device type specification (default: None). 118 | Applies to both model and batches. 119 | 120 | Returns: 121 | Engine: a trainer engine with supervised update function 122 | """ 123 | if device: 124 | if torch.cuda.device_count() > 1: 125 | model = nn.DataParallel(model) 126 | model.to(device) 127 | 128 | def _update(engine, batch): 129 | model.train() 130 | optimizer.zero_grad() 131 | optimizer_center.zero_grad() 132 | img, target = batch 133 | img = img.to(device) if torch.cuda.device_count() >= 1 else img 134 | target = target.to(device) if torch.cuda.device_count() >= 1 else target 135 | score, feat = model(img) 136 | loss = loss_fn(score, feat, target) 137 | # print("Total loss is {}, center loss is {}".format(loss, center_criterion(feat, target))) 138 | loss.backward() 139 | optimizer.step() 140 | for param in center_criterion.parameters(): 141 | param.grad.data *= (1. 
/ cetner_loss_weight) 142 | optimizer_center.step() 143 | 144 | # compute acc 145 | acc = (score.max(1)[1] == target).float().mean() 146 | return loss.item(), acc.item() 147 | 148 | return Engine(_update) 149 | 150 | 151 | def create_supervised_evaluator(model, metrics, 152 | device=None): 153 | """ 154 | Factory function for creating an evaluator for supervised models 155 | 156 | Args: 157 | model (`torch.nn.Module`): the model to train 158 | metrics (dict of str - :class:`ignite.metrics.Metric`): a map of metric names to Metrics 159 | device (str, optional): device type specification (default: None). 160 | Applies to both model and batches. 161 | Returns: 162 | Engine: an evaluator engine with supervised inference function 163 | """ 164 | if device: 165 | if torch.cuda.device_count() > 1: 166 | model = nn.DataParallel(model) 167 | model.to(device) 168 | 169 | def _inference(engine, batch): 170 | model.eval() 171 | with torch.no_grad(): 172 | guiding, data, pids, camids, is_first = batch 173 | 174 | data = data.to(device) if torch.cuda.device_count() >= 1 else data 175 | guiding = guiding.to(device) if torch.cuda.device_count() >= 1 else guiding 176 | feat = model(guiding, data, is_first=is_first) 177 | 178 | return feat, pids, camids 179 | 180 | engine = Engine(_inference) 181 | 182 | for name, metric in metrics.items(): 183 | metric.attach(engine, name) 184 | 185 | return engine 186 | 187 | 188 | def do_train( 189 | cfg, 190 | model, 191 | train_loader, 192 | val_loader, 193 | optimizer, 194 | scheduler, 195 | loss_fn, 196 | num_query, 197 | start_epoch 198 | ): 199 | log_period = cfg.SOLVER.LOG_PERIOD 200 | checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD 201 | eval_period = cfg.SOLVER.EVAL_PERIOD 202 | output_dir = cfg.OUTPUT_DIR 203 | device = cfg.MODEL.DEVICE 204 | epochs = cfg.SOLVER.MAX_EPOCHS 205 | 206 | logger = logging.getLogger("reid_baseline.train") 207 | logger.info("Start training") 208 | trainer = create_supervised_trainer(model, optimizer, loss_fn, device=device, gamma=cfg.MODEL.GAMMA, margin=cfg.SOLVER.MARGIN, beta=cfg.MODEL.BETA) 209 | if cfg.TEST.PAIR == "no": 210 | evaluator = create_supervised_evaluator(model, metrics={'r1_mAP': R1_mAP(1, max_rank=50, feat_norm=cfg.TEST.FEAT_NORM)}, device=device) 211 | elif cfg.TEST.PAIR == "yes": 212 | evaluator = create_supervised_evaluator(model, metrics={ 213 | 'r1_mAP': R1_mAP_pair(1, max_rank=50, feat_norm=cfg.TEST.FEAT_NORM)}, device=device) 214 | checkpointer = ModelCheckpoint(output_dir, cfg.MODEL.NAME, checkpoint_period, n_saved=10, require_empty=False) 215 | # checkpointer = ModelCheckpoint(output_dir, cfg.MODEL.NAME, n_saved=10, require_empty=False) 216 | timer = Timer(average=True) 217 | 218 | trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer, {'model': model, 219 | 'optimizer': optimizer}) 220 | timer.attach(trainer, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED, 221 | pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED) 222 | 223 | # average metric to attach on trainer 224 | RunningAverage(output_transform=lambda x: x[0]).attach(trainer, 'avg_loss') 225 | RunningAverage(output_transform=lambda x: x[1]).attach(trainer, 'avg_acc') 226 | 227 | dataset = init_dataset(cfg.DATASETS.NAMES, root=cfg.DATASETS.ROOT_DIR) 228 | 229 | @trainer.on(Events.STARTED) 230 | def start_training(engine): 231 | engine.state.epoch = start_epoch 232 | 233 | @trainer.on(Events.EPOCH_STARTED) 234 | def adjust_learning_rate(engine): 235 | scheduler.step() 236 | 237 | 
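# The logging handler below tracks progress with the module-level ITER
# counter rather than engine.state.iteration: ITER advances once per batch,
# emits a log line every log_period batches, and resets to 0 once it
# reaches len(train_loader), so the Iteration[i/total] readout restarts at
# each epoch.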
@trainer.on(Events.ITERATION_COMPLETED) 238 | def log_training_loss(engine): 239 | global ITER 240 | ITER += 1 241 | if ITER % log_period == 0: 242 | logger.info("Epoch[{}] Iteration[{}/{}] Loss: {:.3f}, Acc: {:.3f}, Base Lr: {:.2e}" 243 | .format(engine.state.epoch, ITER, len(train_loader), 244 | engine.state.metrics['avg_loss'], engine.state.metrics['avg_acc'], 245 | scheduler.get_lr()[0])) 246 | if len(train_loader) == ITER: 247 | ITER = 0 248 | 249 | # adding handlers using `trainer.on` decorator API 250 | @trainer.on(Events.EPOCH_COMPLETED) 251 | def print_times(engine): 252 | # multi_person_training_info2() 253 | train_loader, val_loader, num_query, num_classes = make_data_loader_train(cfg) 254 | logger.info('Epoch {} done. Time per batch: {:.3f}[s] Speed: {:.1f}[samples/s]' 255 | .format(engine.state.epoch, timer.value() * timer.step_count, 256 | train_loader.batch_size / timer.value())) 257 | logger.info('-' * 10) 258 | timer.reset() 259 | 260 | @trainer.on(Events.EPOCH_COMPLETED) 261 | def log_validation_results(engine): 262 | # if engine.state.epoch % eval_period == 0: 263 | if engine.state.epoch >= eval_period: 264 | all_cmc = [] 265 | all_AP = [] 266 | num_valid_q = 0 267 | q_pids = [] 268 | for query_index in tqdm(range(num_query)): 269 | 270 | val_loader = make_data_loader_val(cfg, query_index, dataset) 271 | evaluator.run(val_loader) 272 | cmc, AP, q_pid = evaluator.state.metrics['r1_mAP'] 273 | 274 | if AP >= 0: 275 | if cmc.shape[0] < 50: 276 | continue 277 | num_valid_q += 1 278 | 279 | all_cmc.append(cmc) 280 | all_AP.append(AP) 281 | q_pids.append(int(q_pid)) 282 | else: 283 | continue 284 | 285 | all_cmc = np.asarray(all_cmc).astype(np.float32) 286 | cmc = all_cmc.sum(0) / num_valid_q 287 | mAP = np.mean(all_AP) 288 | logger.info("Validation Results - Epoch: {}".format(engine.state.epoch)) 289 | logger.info("mAP: {:.1%}".format(mAP)) 290 | for r in [1, 5, 10]: 291 | logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(r, cmc[r - 1])) 292 | 293 | 294 | trainer.run(train_loader, max_epochs=epochs) 295 | 296 | 297 | def do_train_with_center( 298 | cfg, 299 | model, 300 | center_criterion, 301 | train_loader, 302 | val_loader, 303 | optimizer, 304 | optimizer_center, 305 | scheduler, 306 | loss_fn, 307 | num_query, 308 | start_epoch 309 | ): 310 | log_period = cfg.SOLVER.LOG_PERIOD 311 | checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD 312 | eval_period = cfg.SOLVER.EVAL_PERIOD 313 | output_dir = cfg.OUTPUT_DIR 314 | device = cfg.MODEL.DEVICE 315 | epochs = cfg.SOLVER.MAX_EPOCHS 316 | 317 | logger = logging.getLogger("reid_baseline.train") 318 | logger.info("Start training") 319 | trainer = create_supervised_trainer_with_center(model, center_criterion, optimizer, optimizer_center, loss_fn, cfg.SOLVER.CENTER_LOSS_WEIGHT, device=device) 320 | evaluator = create_supervised_evaluator(model, metrics={'r1_mAP': R1_mAP(num_query, max_rank=50, feat_norm=cfg.TEST.FEAT_NORM)}, device=device) 321 | checkpointer = ModelCheckpoint(output_dir, cfg.MODEL.NAME, checkpoint_period, n_saved=10, require_empty=False) 322 | timer = Timer(average=True) 323 | 324 | trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer, {'model': model, 325 | 'optimizer': optimizer, 326 | 'center_param': center_criterion, 327 | 'optimizer_center': optimizer_center}) 328 | 329 | timer.attach(trainer, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED, 330 | pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED) 331 | 332 | # average metric to attach on trainer 333 | 
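# RunningAverage exposes an exponentially smoothed view of the (loss, acc)
# pair returned by the update function; the names 'avg_loss' and 'avg_acc'
# attached here are the keys that log_training_loss reads back from
# engine.state.metrics.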
RunningAverage(output_transform=lambda x: x[0]).attach(trainer, 'avg_loss') 334 | RunningAverage(output_transform=lambda x: x[1]).attach(trainer, 'avg_acc') 335 | 336 | @trainer.on(Events.STARTED) 337 | def start_training(engine): 338 | engine.state.epoch = start_epoch 339 | 340 | @trainer.on(Events.EPOCH_STARTED) 341 | def adjust_learning_rate(engine): 342 | scheduler.step() 343 | 344 | @trainer.on(Events.ITERATION_COMPLETED) 345 | def log_training_loss(engine): 346 | global ITER 347 | ITER += 1 348 | 349 | if ITER % log_period == 0: 350 | logger.info("Epoch[{}] Iteration[{}/{}] Loss: {:.3f}, Acc: {:.3f}, Base Lr: {:.2e}" 351 | .format(engine.state.epoch, ITER, len(train_loader), 352 | engine.state.metrics['avg_loss'], engine.state.metrics['avg_acc'], 353 | scheduler.get_lr()[0])) 354 | if len(train_loader) == ITER: 355 | ITER = 0 356 | 357 | # adding handlers using `trainer.on` decorator API 358 | @trainer.on(Events.EPOCH_COMPLETED) 359 | def print_times(engine): 360 | logger.info('Epoch {} done. Time per batch: {:.3f}[s] Speed: {:.1f}[samples/s]' 361 | .format(engine.state.epoch, timer.value() * timer.step_count, 362 | train_loader.batch_size / timer.value())) 363 | logger.info('-' * 10) 364 | timer.reset() 365 | 366 | @trainer.on(Events.EPOCH_COMPLETED) 367 | def log_validation_results(engine): 368 | if engine.state.epoch % eval_period == 0: 369 | evaluator.run(val_loader) 370 | cmc, mAP = evaluator.state.metrics['r1_mAP'] 371 | logger.info("Validation Results - Epoch: {}".format(engine.state.epoch)) 372 | logger.info("mAP: {:.1%}".format(mAP)) 373 | for r in [1, 5, 10]: 374 | logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(r, cmc[r - 1])) 375 | 376 | trainer.run(train_loader, max_epochs=epochs) 377 | -------------------------------------------------------------------------------- /image/examples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/X-BrainLab/PI-ReID/76755f094fe6d6a7f195cd4aaaddc1b6beff8d0f/image/examples.png -------------------------------------------------------------------------------- /layers/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import torch.nn.functional as F 8 | 9 | from .triplet_loss import TripletLoss, CrossEntropyLabelSmooth 10 | from .center_loss import CenterLoss 11 | 12 | 13 | def make_loss(cfg, num_classes): # modified by gu 14 | sampler = cfg.DATALOADER.SAMPLER 15 | if cfg.MODEL.METRIC_LOSS_TYPE == 'triplet': 16 | triplet = TripletLoss(cfg.SOLVER.MARGIN) # triplet loss 17 | else: 18 | print('expected METRIC_LOSS_TYPE should be triplet' 19 | 'but got {}'.format(cfg.MODEL.METRIC_LOSS_TYPE)) 20 | 21 | if cfg.MODEL.IF_LABELSMOOTH == 'on': 22 | xent = CrossEntropyLabelSmooth(num_classes=num_classes) # new add by luo 23 | print("label smooth on, numclasses:", num_classes) 24 | 25 | if sampler == 'softmax': 26 | def loss_func(score, feat, target): 27 | return F.cross_entropy(score, target) 28 | elif cfg.DATALOADER.SAMPLER == 'triplet': 29 | def loss_func(score, feat, target): 30 | return triplet(feat, target)[0] 31 | elif cfg.DATALOADER.SAMPLER == 'softmax_triplet': 32 | def loss_func(score, feat, target): 33 | if cfg.MODEL.METRIC_LOSS_TYPE == 'triplet': 34 | if cfg.MODEL.IF_LABELSMOOTH == 'on': 35 | return xent(score, target) + triplet(feat, target)[0] 36 | else: 37 | return F.cross_entropy(score, target) + triplet(feat, 
target)[0] 38 | else: 39 | print('expected METRIC_LOSS_TYPE should be triplet' 40 | 'but got {}'.format(cfg.MODEL.METRIC_LOSS_TYPE)) 41 | else: 42 | print('expected sampler should be softmax, triplet or softmax_triplet, ' 43 | 'but got {}'.format(cfg.DATALOADER.SAMPLER)) 44 | return loss_func 45 | 46 | 47 | def make_loss_with_center(cfg, num_classes): # modified by gu 48 | if cfg.MODEL.NAME == 'resnet18' or cfg.MODEL.NAME == 'resnet34': 49 | feat_dim = 512 50 | else: 51 | feat_dim = 2048 52 | 53 | if cfg.MODEL.METRIC_LOSS_TYPE == 'center': 54 | center_criterion = CenterLoss(num_classes=num_classes, feat_dim=feat_dim, use_gpu=True) # center loss 55 | 56 | elif cfg.MODEL.METRIC_LOSS_TYPE == 'triplet_center': 57 | triplet = TripletLoss(cfg.SOLVER.MARGIN) # triplet loss 58 | center_criterion = CenterLoss(num_classes=num_classes, feat_dim=feat_dim, use_gpu=True) # center loss 59 | 60 | else: 61 | print('expected METRIC_LOSS_TYPE with center should be center, triplet_center' 62 | 'but got {}'.format(cfg.MODEL.METRIC_LOSS_TYPE)) 63 | 64 | if cfg.MODEL.IF_LABELSMOOTH == 'on': 65 | xent = CrossEntropyLabelSmooth(num_classes=num_classes) # new add by luo 66 | print("label smooth on, numclasses:", num_classes) 67 | 68 | def loss_func(score, feat, target): 69 | if cfg.MODEL.METRIC_LOSS_TYPE == 'center': 70 | if cfg.MODEL.IF_LABELSMOOTH == 'on': 71 | return xent(score, target) + \ 72 | cfg.SOLVER.CENTER_LOSS_WEIGHT * center_criterion(feat, target) 73 | else: 74 | return F.cross_entropy(score, target) + \ 75 | cfg.SOLVER.CENTER_LOSS_WEIGHT * center_criterion(feat, target) 76 | 77 | elif cfg.MODEL.METRIC_LOSS_TYPE == 'triplet_center': 78 | if cfg.MODEL.IF_LABELSMOOTH == 'on': 79 | return xent(score, target) + \ 80 | triplet(feat, target)[0] + \ 81 | cfg.SOLVER.CENTER_LOSS_WEIGHT * center_criterion(feat, target) 82 | else: 83 | return F.cross_entropy(score, target) + \ 84 | triplet(feat, target)[0] + \ 85 | cfg.SOLVER.CENTER_LOSS_WEIGHT * center_criterion(feat, target) 86 | 87 | else: 88 | print('expected METRIC_LOSS_TYPE with center should be center, triplet_center' 89 | 'but got {}'.format(cfg.MODEL.METRIC_LOSS_TYPE)) 90 | return loss_func, center_criterion -------------------------------------------------------------------------------- /layers/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/X-BrainLab/PI-ReID/76755f094fe6d6a7f195cd4aaaddc1b6beff8d0f/layers/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /layers/__pycache__/center_loss.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/X-BrainLab/PI-ReID/76755f094fe6d6a7f195cd4aaaddc1b6beff8d0f/layers/__pycache__/center_loss.cpython-36.pyc -------------------------------------------------------------------------------- /layers/__pycache__/triplet_loss.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/X-BrainLab/PI-ReID/76755f094fe6d6a7f195cd4aaaddc1b6beff8d0f/layers/__pycache__/triplet_loss.cpython-36.pyc -------------------------------------------------------------------------------- /layers/center_loss.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | from torch import nn 5 | 6 | 7 | class CenterLoss(nn.Module): 8 | 
"""Center loss. 9 | 10 | Reference: 11 | Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016. 12 | 13 | Args: 14 | num_classes (int): number of classes. 15 | feat_dim (int): feature dimension. 16 | """ 17 | 18 | def __init__(self, num_classes=751, feat_dim=2048, use_gpu=True): 19 | super(CenterLoss, self).__init__() 20 | self.num_classes = num_classes 21 | self.feat_dim = feat_dim 22 | self.use_gpu = use_gpu 23 | 24 | if self.use_gpu: 25 | self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim).cuda()) 26 | else: 27 | self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim)) 28 | 29 | def forward(self, x, labels): 30 | """ 31 | Args: 32 | x: feature matrix with shape (batch_size, feat_dim). 33 | labels: ground truth labels with shape (num_classes). 34 | """ 35 | assert x.size(0) == labels.size(0), "features.size(0) is not equal to labels.size(0)" 36 | 37 | batch_size = x.size(0) 38 | distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \ 39 | torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t() 40 | distmat.addmm_(1, -2, x, self.centers.t()) 41 | 42 | classes = torch.arange(self.num_classes).long() 43 | if self.use_gpu: classes = classes.cuda() 44 | labels = labels.unsqueeze(1).expand(batch_size, self.num_classes) 45 | mask = labels.eq(classes.expand(batch_size, self.num_classes)) 46 | 47 | dist = distmat * mask.float() 48 | loss = dist.clamp(min=1e-12, max=1e+12).sum() / batch_size 49 | #dist = [] 50 | #for i in range(batch_size): 51 | # value = distmat[i][mask[i]] 52 | # value = value.clamp(min=1e-12, max=1e+12) # for numerical stability 53 | # dist.append(value) 54 | #dist = torch.cat(dist) 55 | #loss = dist.mean() 56 | return loss 57 | 58 | 59 | if __name__ == '__main__': 60 | use_gpu = False 61 | center_loss = CenterLoss(use_gpu=use_gpu) 62 | features = torch.rand(16, 2048) 63 | targets = torch.Tensor([0, 1, 2, 3, 2, 3, 1, 4, 5, 3, 2, 1, 0, 0, 5, 4]).long() 64 | if use_gpu: 65 | features = torch.rand(16, 2048).cuda() 66 | targets = torch.Tensor([0, 1, 2, 3, 2, 3, 1, 4, 5, 3, 2, 1, 0, 0, 5, 4]).cuda() 67 | 68 | loss = center_loss(features, targets) 69 | print(loss) 70 | -------------------------------------------------------------------------------- /layers/triplet_loss.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | import torch 7 | from torch import nn 8 | 9 | 10 | def normalize(x, axis=-1): 11 | """Normalizing to unit length along the specified dimension. 12 | Args: 13 | x: pytorch Variable 14 | Returns: 15 | x: pytorch Variable, same shape as input 16 | """ 17 | x = 1. 
* x / (torch.norm(x, 2, axis, keepdim=True).expand_as(x) + 1e-12) 18 | return x 19 | 20 | 21 | def euclidean_dist(x, y): 22 | """ 23 | Args: 24 | x: pytorch Variable, with shape [m, d] 25 | y: pytorch Variable, with shape [n, d] 26 | Returns: 27 | dist: pytorch Variable, with shape [m, n] 28 | """ 29 | m, n = x.size(0), y.size(0) 30 | xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n) 31 | yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t() 32 | dist = xx + yy 33 | dist.addmm_(1, -2, x, y.t()) 34 | dist = dist.clamp(min=1e-12).sqrt() # for numerical stability 35 | return dist 36 | 37 | 38 | def hard_example_mining(dist_mat, labels, return_inds=False): 39 | """For each anchor, find the hardest positive and negative sample. 40 | Args: 41 | dist_mat: pytorch Variable, pair wise distance between samples, shape [N, N] 42 | labels: pytorch LongTensor, with shape [N] 43 | return_inds: whether to return the indices. Save time if `False`(?) 44 | Returns: 45 | dist_ap: pytorch Variable, distance(anchor, positive); shape [N] 46 | dist_an: pytorch Variable, distance(anchor, negative); shape [N] 47 | p_inds: pytorch LongTensor, with shape [N]; 48 | indices of selected hard positive samples; 0 <= p_inds[i] <= N - 1 49 | n_inds: pytorch LongTensor, with shape [N]; 50 | indices of selected hard negative samples; 0 <= n_inds[i] <= N - 1 51 | NOTE: Only consider the case in which all labels have same num of samples, 52 | thus we can cope with all anchors in parallel. 53 | """ 54 | 55 | assert len(dist_mat.size()) == 2 56 | assert dist_mat.size(0) == dist_mat.size(1) 57 | N = dist_mat.size(0) 58 | 59 | # shape [N, N] 60 | is_pos = labels.expand(N, N).eq(labels.expand(N, N).t()) 61 | is_neg = labels.expand(N, N).ne(labels.expand(N, N).t()) 62 | 63 | # `dist_ap` means distance(anchor, positive) 64 | # both `dist_ap` and `relative_p_inds` with shape [N, 1] 65 | 66 | dist_ap, relative_p_inds = torch.max( 67 | dist_mat[is_pos].contiguous().view(N, -1), 1, keepdim=True) 68 | 69 | 70 | # `dist_an` means distance(anchor, negative) 71 | # both `dist_an` and `relative_n_inds` with shape [N, 1] 72 | dist_an, relative_n_inds = torch.min( 73 | dist_mat[is_neg].contiguous().view(N, -1), 1, keepdim=True) 74 | # shape [N] 75 | dist_ap = dist_ap.squeeze(1) 76 | dist_an = dist_an.squeeze(1) 77 | 78 | if return_inds: 79 | # shape [N, N] 80 | ind = (labels.new().resize_as_(labels) 81 | .copy_(torch.arange(0, N).long()) 82 | .unsqueeze(0).expand(N, N)) 83 | # shape [N, 1] 84 | p_inds = torch.gather( 85 | ind[is_pos].contiguous().view(N, -1), 1, relative_p_inds.data) 86 | n_inds = torch.gather( 87 | ind[is_neg].contiguous().view(N, -1), 1, relative_n_inds.data) 88 | # shape [N] 89 | p_inds = p_inds.squeeze(1) 90 | n_inds = n_inds.squeeze(1) 91 | return dist_ap, dist_an, p_inds, n_inds 92 | 93 | return dist_ap, dist_an 94 | 95 | 96 | class TripletLoss(object): 97 | """Modified from Tong Xiao's open-reid (https://github.com/Cysu/open-reid). 
98 | Related Triplet Loss theory can be found in paper 'In Defense of the Triplet 99 | Loss for Person Re-Identification'.""" 100 | 101 | def __init__(self, margin=None): 102 | self.margin = margin 103 | if margin is not None: 104 | self.ranking_loss = nn.MarginRankingLoss(margin=margin) 105 | else: 106 | self.ranking_loss = nn.SoftMarginLoss() 107 | 108 | def __call__(self, global_feat, labels, normalize_feature=False): 109 | if normalize_feature: 110 | global_feat = normalize(global_feat, axis=-1) 111 | dist_mat = euclidean_dist(global_feat, global_feat) 112 | dist_ap, dist_an = hard_example_mining( 113 | dist_mat, labels) 114 | y = dist_an.new().resize_as_(dist_an).fill_(1) 115 | if self.margin is not None: 116 | loss = self.ranking_loss(dist_an, dist_ap, y) 117 | else: 118 | loss = self.ranking_loss(dist_an - dist_ap, y) 119 | return loss, dist_ap, dist_an 120 | 121 | class CrossEntropyLabelSmooth(nn.Module): 122 | """Cross entropy loss with label smoothing regularizer. 123 | 124 | Reference: 125 | Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016. 126 | Equation: y = (1 - epsilon) * y + epsilon / K. 127 | 128 | Args: 129 | num_classes (int): number of classes. 130 | epsilon (float): weight. 131 | """ 132 | def __init__(self, num_classes, epsilon=0.1, use_gpu=True): 133 | super(CrossEntropyLabelSmooth, self).__init__() 134 | self.num_classes = num_classes 135 | self.epsilon = epsilon 136 | self.use_gpu = use_gpu 137 | self.logsoftmax = nn.LogSoftmax(dim=1) 138 | 139 | def forward(self, inputs, targets): 140 | """ 141 | Args: 142 | inputs: prediction matrix (before softmax) with shape (batch_size, num_classes) 143 | targets: ground truth labels with shape (num_classes) 144 | """ 145 | log_probs = self.logsoftmax(inputs) 146 | targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).data.cpu(), 1) 147 | if self.use_gpu: targets = targets.cuda() 148 | targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes 149 | loss = (- targets * log_probs).mean(0).sum() 150 | return loss -------------------------------------------------------------------------------- /modeling/PISNet.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import torch 8 | from torch import nn 9 | 10 | from .backbones.pisnet import pisnet, BasicBlock, Bottleneck 11 | 12 | def weights_init_kaiming(m): 13 | classname = m.__class__.__name__ 14 | if classname.find('Linear') != -1: 15 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out') 16 | nn.init.constant_(m.bias, 0.0) 17 | elif classname.find('Conv') != -1: 18 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in') 19 | if m.bias is not None: 20 | nn.init.constant_(m.bias, 0.0) 21 | elif classname.find('BatchNorm') != -1: 22 | if m.affine: 23 | nn.init.constant_(m.weight, 1.0) 24 | nn.init.constant_(m.bias, 0.0) 25 | 26 | 27 | def weights_init_classifier(m): 28 | classname = m.__class__.__name__ 29 | if classname.find('Linear') != -1: 30 | nn.init.normal_(m.weight, std=0.001) 31 | if m.bias is not None: 32 | nn.init.constant_(m.bias, 0.0) 33 | 34 | 35 | class PISNet(nn.Module): 36 | in_planes = 2048 37 | 38 | def __init__(self, num_classes, last_stride, model_path, neck, neck_feat, model_name, pretrain_choice, has_non_local="no", sia_reg="no", pyramid="no", test_pair="no"): 39 | super(PISNet, self).__init__() 40 | 41 | self.base = pisnet(last_stride=last_stride, 42 
| block=Bottleneck, 43 | layers=[3, 4, 6, 3], has_non_local=has_non_local, sia_reg=sia_reg, pyramid=pyramid) 44 | 45 | if pretrain_choice == 'imagenet': 46 | self.base.load_param(model_path) 47 | print('Loading pretrained ImageNet model......') 48 | 49 | self.gap = nn.AdaptiveAvgPool2d(1) 50 | # self.gap = nn.AdaptiveMaxPool2d(1) 51 | self.num_classes = num_classes 52 | self.neck = neck 53 | self.neck_feat = neck_feat 54 | self.test_pair = test_pair 55 | self.sia_reg = sia_reg 56 | 57 | if self.neck == 'no': 58 | self.classifier = nn.Linear(self.in_planes, self.num_classes) 59 | # self.classifier = nn.Linear(self.in_planes, self.num_classes, bias=False) # new add by luo 60 | # self.classifier.apply(weights_init_classifier) # new add by luo 61 | elif self.neck == 'bnneck': 62 | self.bottleneck = nn.BatchNorm1d(self.in_planes) 63 | self.bottleneck.bias.requires_grad_(False) # no shift 64 | self.classifier = nn.Linear(self.in_planes, self.num_classes, bias=False) 65 | 66 | self.bottleneck.apply(weights_init_kaiming) 67 | self.classifier.apply(weights_init_classifier) 68 | 69 | def forward(self, x_g, x, x_g2=[], is_first=False): 70 | 71 | feature_gallery, gallery_attention, feature_gallery1, gallery_attention1, feature_query, reg_feature_query, reg_query_attention = self.base(x_g, x, x_g2=x_g2, is_first=is_first) 72 | 73 | global_feat = self.gap(feature_gallery) 74 | global_feat = global_feat.view(global_feat.shape[0], -1) 75 | # gallery_attention = gallery_attention.view(gallery_attention.shape[0], -1) 76 | 77 | if self.training: 78 | global_feat1 = self.gap(feature_gallery1) 79 | global_feat1 = global_feat1.view(global_feat1.shape[0], -1) # fixed: flatten global_feat1 itself (was global_feat.view(...), which silently overwrote this branch with the first gallery feature) 80 | gallery_attention1 = gallery_attention1.view(gallery_attention1.shape[0], -1) # fixed: was gallery_attention.view(...) 81 | 82 | global_feature_query = self.gap(feature_query) 83 | global_feature_query = global_feature_query.view(global_feature_query.shape[0], -1) # fixed: was global_feat.view(...) 84 | 85 | if self.sia_reg == "yes": 86 | global_reg_query = self.gap(reg_feature_query) 87 | global_reg_query = global_reg_query.view(global_reg_query.shape[0], -1) # fixed: was global_feat.view(...) 88 | reg_query_attention = reg_query_attention.view(reg_query_attention.shape[0], -1) # fixed: was gallery_attention.view(...) 89 | 90 | # cls_score_pos_neg = self.classifier_attention(gallery_attention) 91 | # cls_score_pos_neg = self.sigmoid(cls_score_pos_neg) 92 | 93 | if self.neck == 'no': 94 | feat = global_feat 95 | if self.training: 96 | feat1 = global_feat1 97 | if self.sia_reg == "yes": 98 | feat2 = global_reg_query 99 | # feat_query = global_feature_query 100 | 101 | # feat_guiding = global_feat_guiding 102 | elif self.neck == 'bnneck': 103 | feat = self.bottleneck(global_feat) # normalize for angular softmax 104 | if self.training: 105 | feat1 = self.bottleneck(global_feat1) # normalize for angular softmax 106 | if self.sia_reg == "yes": 107 | feat2 = self.bottleneck(global_reg_query) 108 | # feat_query = self.bottleneck(global_feature_query) 109 | 110 | # feat_guiding = self.bottleneck(global_feat_guiding) 111 | if self.training: 112 | cls_score = self.classifier(feat) 113 | cls_score1 = self.classifier(feat1) 114 | cls_score2 = self.classifier(feat2) # NOTE: feat2 (and cls_score2) require sia_reg == "yes"; training with sia_reg == "no" would raise a NameError here 115 | # cls_score_guiding = self.classifier(feat_guiding) 116 | return cls_score, global_feat, cls_score1, global_feat1, global_feature_query, cls_score2, global_reg_query # global feature for triplet loss 117 | else: 118 | if self.neck_feat == 'after': 119 | # print("Test with feature after BN") 120 | return feat 121 | else: 122 | return global_feat 123 | 124 | def load_param(self, trained_path): 125 | param_dict = torch.load(trained_path).state_dict() 126 | for i in 
param_dict: 127 | if 'classifier' in i: 128 | continue 129 | self.state_dict()[i].copy_(param_dict[i]) 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | -------------------------------------------------------------------------------- /modeling/Pre_Selection_Model.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import torch 8 | from torch import nn 9 | 10 | from .backbones.resnet import ResNet, BasicBlock, Bottleneck 11 | 12 | 13 | def weights_init_kaiming(m): 14 | classname = m.__class__.__name__ 15 | if classname.find('Linear') != -1: 16 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out') 17 | nn.init.constant_(m.bias, 0.0) 18 | elif classname.find('Conv') != -1: 19 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in') 20 | if m.bias is not None: 21 | nn.init.constant_(m.bias, 0.0) 22 | elif classname.find('BatchNorm') != -1: 23 | if m.affine: 24 | nn.init.constant_(m.weight, 1.0) 25 | nn.init.constant_(m.bias, 0.0) 26 | 27 | 28 | def weights_init_classifier(m): 29 | classname = m.__class__.__name__ 30 | if classname.find('Linear') != -1: 31 | nn.init.normal_(m.weight, std=0.001) 32 | if m.bias is not None: # fixed: `if m.bias:` raises for a multi-element bias tensor; matches the check used in PISNet.py 33 | nn.init.constant_(m.bias, 0.0) 34 | 35 | 36 | class Pre_Selection_Model(nn.Module): 37 | in_planes = 2048 38 | 39 | def __init__(self, num_classes, last_stride, model_path, neck, neck_feat, model_name, pretrain_choice): 40 | super(Pre_Selection_Model, self).__init__() 41 | 42 | self.base = ResNet(last_stride=last_stride, 43 | block=Bottleneck, 44 | layers=[3, 4, 6, 3]) 45 | 46 | if pretrain_choice == 'imagenet': 47 | self.base.load_param(model_path) 48 | print('Loading pretrained ImageNet model......') 49 | 50 | self.gap = nn.AdaptiveAvgPool2d(1) 51 | # self.gap = nn.AdaptiveMaxPool2d(1) 52 | self.num_classes = num_classes 53 | self.neck = neck 54 | self.neck_feat = neck_feat 55 | 56 | if self.neck == 'no': 57 | self.classifier = nn.Linear(self.in_planes, self.num_classes) 58 | # self.classifier = nn.Linear(self.in_planes, self.num_classes, bias=False) # new add by luo 59 | # self.classifier.apply(weights_init_classifier) # new add by luo 60 | elif self.neck == 'bnneck': 61 | self.bottleneck = nn.BatchNorm1d(self.in_planes) 62 | self.bottleneck.bias.requires_grad_(False) # no shift 63 | self.classifier = nn.Linear(self.in_planes, self.num_classes, bias=False) 64 | 65 | self.bottleneck.apply(weights_init_kaiming) 66 | self.classifier.apply(weights_init_classifier) 67 | 68 | def forward(self, x): 69 | 70 | global_feat = self.gap(self.base(x)) # (b, 2048, 1, 1) 71 | global_feat = global_feat.view(global_feat.shape[0], -1) # flatten to (bs, 2048) 72 | 73 | if self.neck == 'no': 74 | feat = global_feat 75 | elif self.neck == 'bnneck': 76 | feat = self.bottleneck(global_feat) # normalize for angular softmax 77 | 78 | if self.training: 79 | cls_score = self.classifier(feat) 80 | return cls_score, global_feat # global feature for triplet loss 81 | else: 82 | if self.neck_feat == 'after': 83 | # print("Test with feature after BN") 84 | return feat 85 | else: 86 | # print("Test with feature before BN") 87 | return global_feat 88 | 89 | def load_param(self, trained_path): 90 | param_dict = torch.load(trained_path).state_dict() 91 | # param_dict = torch.load(trained_path)['model'] 92 | for i in param_dict: 93 | if 'classifier' in i: 94 | continue 95 | self.state_dict()[i].copy_(param_dict[i]) 96 | 97 | 
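# --- Editor's addition: a minimal usage sketch for Pre_Selection_Model. The
# constructor values and shapes below are illustrative assumptions (a standard
# 256x128 re-ID crop, Market-1501's 751 identities), not part of the original repo.
import torch

def _pre_selection_demo():
    # any pretrain_choice other than 'imagenet' skips load_param, so no checkpoint file is needed
    model = Pre_Selection_Model(num_classes=751, last_stride=1, model_path='',
                                neck='bnneck', neck_feat='after',
                                model_name='resnet50', pretrain_choice='self')
    model.eval()  # eval mode: forward returns features instead of (cls_score, global_feat)
    with torch.no_grad():
        feats = model(torch.randn(4, 3, 256, 128))
    assert feats.shape == (4, 2048)  # one 2048-d global feature per image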
-------------------------------------------------------------------------------- /modeling/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | from .PISNet import PISNet 8 | from .Pre_Selection_Model import Pre_Selection_Model 9 | 10 | 11 | def build_model(cfg, num_classes): 12 | model = PISNet(num_classes, cfg.MODEL.LAST_STRIDE, cfg.MODEL.PRETRAIN_PATH, cfg.MODEL.NECK, cfg.TEST.NECK_FEAT, cfg.MODEL.NAME, cfg.MODEL.PRETRAIN_CHOICE, has_non_local=cfg.MODEL.HAS_NON_LOCAL, sia_reg=cfg.MODEL.SIA_REG, pyramid=cfg.MODEL.PYRAMID, test_pair=cfg.TEST.PAIR) 13 | return model 14 | 15 | def build_model_pre(cfg, num_classes): 16 | model = Pre_Selection_Model(num_classes, cfg.MODEL.LAST_STRIDE, cfg.MODEL.PRETRAIN_PATH, cfg.MODEL.NECK, cfg.TEST.NECK_FEAT, cfg.MODEL.NAME, cfg.MODEL.PRETRAIN_CHOICE) 17 | return model 18 | 19 | -------------------------------------------------------------------------------- /modeling/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/X-BrainLab/PI-ReID/76755f094fe6d6a7f195cd4aaaddc1b6beff8d0f/modeling/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /modeling/backbones/Query_Guided_Attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from time import time 5 | 6 | 7 | class _Query_Guided_Attention(nn.Module): 8 | def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True): 9 | super(_Query_Guided_Attention, self).__init__() 10 | 11 | assert dimension in [1, 2, 3] 12 | 13 | self.dimension = dimension 14 | self.sub_sample = sub_sample 15 | 16 | self.in_channels = in_channels 17 | self.inter_channels = inter_channels 18 | 19 | if self.inter_channels is None: 20 | self.inter_channels = in_channels // 2 21 | if self.inter_channels == 0: 22 | self.inter_channels = 1 23 | 24 | if dimension == 3: 25 | conv_nd = nn.Conv3d 26 | self.max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2)) 27 | bn = nn.BatchNorm3d 28 | elif dimension == 2: 29 | conv_nd = nn.Conv2d 30 | self.max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2)) 31 | self.max_pool_layer1 = nn.MaxPool2d(kernel_size=(4, 4)) 32 | self.max_pool_layer2 = nn.MaxPool2d(kernel_size=(8, 8)) 33 | self.gmp = nn.AdaptiveMaxPool2d(1) 34 | 35 | bn = nn.BatchNorm2d 36 | else: 37 | conv_nd = nn.Conv1d 38 | self.max_pool_layer = nn.MaxPool1d(kernel_size=(2)) 39 | bn = nn.BatchNorm1d 40 | 41 | # self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 42 | # kernel_size=1, stride=1, padding=0) 43 | 44 | if bn_layer: 45 | self.W = nn.Sequential( 46 | conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, 47 | kernel_size=1, stride=1, padding=0), 48 | bn(self.in_channels) 49 | ) 50 | nn.init.constant_(self.W[1].weight, 0) 51 | nn.init.constant_(self.W[1].bias, 0) 52 | else: 53 | self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, 54 | kernel_size=1, stride=1, padding=0) 55 | nn.init.constant_(self.W.weight, 0) 56 | nn.init.constant_(self.W.bias, 0) 57 | 58 | self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 59 | kernel_size=1, stride=1, padding=0) 60 | 61 | self.phi = 
conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 62 | kernel_size=1, stride=1, padding=0) 63 | 64 | # self.compress_attention = conv_nd(in_channels=3, out_channels=1, 65 | # kernel_size=1, stride=1, padding=0) 66 | # if sub_sample: 67 | # # self.g = nn.Sequential(self.g, max_pool_layer) 68 | # self.phi = nn.Sequential(self.phi, max_pool_layer) 69 | 70 | self.relu = nn.ReLU() 71 | 72 | # self.gmp = nn.AdaptiveMaxPool1d(1, return_indices=True) 73 | 74 | def forward(self, x, x_g, attention="x", pyramid="no"): 75 | ''' 76 | :param x: (b, c, t, h, w) 77 | :return: 78 | ''' 79 | 80 | batch_size = x.size(0) 81 | theta_x = self.theta(x) 82 | phi_x = self.phi(x_g) 83 | 84 | if attention == "x": 85 | theta_x = theta_x.view(batch_size, self.inter_channels, -1) 86 | theta_x = theta_x.permute(0, 2, 1) 87 | 88 | if pyramid == "yes": 89 | phi_x1 = self.max_pool_layer(phi_x).view(batch_size, self.inter_channels, -1) 90 | f = torch.matmul(theta_x, phi_x1) 91 | N = f.size(-1) 92 | f_div_C1 = f / N 93 | 94 | phi_x2 = phi_x.view(batch_size, self.inter_channels, -1) 95 | f = torch.matmul(theta_x, phi_x2) 96 | f_div_C2 = f / N 97 | 98 | phi_x3 = self.max_pool_layer1(phi_x).view(batch_size, self.inter_channels, -1) 99 | f = torch.matmul(theta_x, phi_x3) 100 | f_div_C3 = f / N 101 | 102 | phi_x4 = self.max_pool_layer1(phi_x).view(batch_size, self.inter_channels, -1) 103 | f = torch.matmul(theta_x, phi_x4) 104 | f_div_C4 = f / N 105 | 106 | phi_x5 = self.gmp(phi_x).view(batch_size, self.inter_channels, -1) 107 | f = torch.matmul(theta_x, phi_x5) 108 | f_div_C5 = f / N 109 | 110 | f_div_C = torch.cat((f_div_C1, f_div_C2, f_div_C3, f_div_C4, f_div_C5), 2) 111 | elif pyramid == "no": 112 | phi_x1 = phi_x.view(batch_size, self.inter_channels, -1) 113 | f = torch.matmul(theta_x, phi_x1) 114 | N = f.size(-1) 115 | f_div_C = f / N 116 | elif pyramid == "s2": 117 | phi_x1 = self.max_pool_layer(phi_x).view(batch_size, self.inter_channels, -1) 118 | f = torch.matmul(theta_x, phi_x1) 119 | N = f.size(-1) 120 | f_div_C = f / N 121 | 122 | f, max_index = torch.max(f_div_C, 2) 123 | f = f.view(batch_size, *x.size()[2:]).unsqueeze(1) 124 | 125 | W_y = x * f 126 | z = W_y + x 127 | 128 | return z, f.squeeze() 129 | 130 | elif attention == "x_g": 131 | phi_x = phi_x.view(batch_size, self.inter_channels, -1).permute(0, 2, 1) 132 | theta_x = theta_x.view(batch_size, self.inter_channels, -1) 133 | f = torch.matmul(phi_x, theta_x) 134 | N = f.size(-1) 135 | f_div_C = f / N 136 | f, max_index = torch.max(f_div_C, 2) 137 | f = f.view(batch_size, *x_g.size()[2:]).unsqueeze(1) 138 | 139 | W_y = x_g * f 140 | z = W_y + x_g 141 | 142 | return z, f 143 | 144 | class Query_Guided_Attention(_Query_Guided_Attention): 145 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 146 | super(Query_Guided_Attention, self).__init__(in_channels, 147 | inter_channels=inter_channels, 148 | dimension=2, sub_sample=sub_sample, 149 | bn_layer=bn_layer) -------------------------------------------------------------------------------- /modeling/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | -------------------------------------------------------------------------------- /modeling/backbones/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/X-BrainLab/PI-ReID/76755f094fe6d6a7f195cd4aaaddc1b6beff8d0f/modeling/backbones/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /modeling/backbones/__pycache__/resnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/X-BrainLab/PI-ReID/76755f094fe6d6a7f195cd4aaaddc1b6beff8d0f/modeling/backbones/__pycache__/resnet.cpython-36.pyc -------------------------------------------------------------------------------- /modeling/backbones/pisnet.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import math 8 | 9 | import torch 10 | from torch import nn 11 | from time import time 12 | from torch.nn import functional as F 13 | 14 | from modeling.backbones.Query_Guided_Attention import Query_Guided_Attention 15 | import numpy as np 16 | 17 | def weights_init_kaiming(m): 18 | classname = m.__class__.__name__ 19 | if classname.find('Linear') != -1: 20 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out') 21 | nn.init.constant_(m.bias, 0.0) 22 | elif classname.find('Conv') != -1: 23 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in') 24 | if m.bias is not None: 25 | nn.init.constant_(m.bias, 0.0) 26 | elif classname.find('BatchNorm') != -1: 27 | if m.affine: 28 | nn.init.constant_(m.weight, 1.0) 29 | nn.init.constant_(m.bias, 0.0) 30 | 31 | def conv3x3(in_planes, out_planes, stride=1): 32 | """3x3 convolution with padding""" 33 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 34 | padding=1, bias=False) 35 | 36 | 37 | class BasicBlock(nn.Module): 38 | expansion = 1 39 | 40 | def __init__(self, inplanes, planes, stride=1, downsample=None): 41 | super(BasicBlock, self).__init__() 42 | self.conv1 = conv3x3(inplanes, planes, stride) 43 | self.bn1 = nn.BatchNorm2d(planes) 44 | self.relu = nn.ReLU(inplace=True) 45 | self.conv2 = conv3x3(planes, planes) 46 | self.bn2 = nn.BatchNorm2d(planes) 47 | self.downsample = downsample 48 | self.stride = stride 49 | 50 | def forward(self, x): 51 | residual = x 52 | 53 | out = self.conv1(x) 54 | out = self.bn1(out) 55 | out = self.relu(out) 56 | 57 | out = self.conv2(out) 58 | out = self.bn2(out) 59 | 60 | if self.downsample is not None: 61 | residual = self.downsample(x) 62 | 63 | out += residual 64 | out = self.relu(out) 65 | 66 | return out 67 | 68 | 69 | def feature_corruption(x_g, x_g2): 70 | # We ABANDON the standard feature corruption in the paper. 71 | # The simple concat yields the comparable performance. 
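    # (editor's note) dim=3 is the width axis of the NCHW feature maps, so the two
    # guiding features are simply placed side by side: (B, C, H, W_g + W_g2).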
72 | corrupted_x = torch.cat((x_g, x_g2), 3) 73 | return corrupted_x 74 | 75 | 76 | class Bottleneck(nn.Module): 77 | expansion = 4 78 | 79 | def __init__(self, inplanes, planes, stride=1, downsample=None): 80 | super(Bottleneck, self).__init__() 81 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 82 | self.bn1 = nn.BatchNorm2d(planes) 83 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 84 | padding=1, bias=False) 85 | self.bn2 = nn.BatchNorm2d(planes) 86 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 87 | self.bn3 = nn.BatchNorm2d(planes * 4) 88 | self.relu = nn.ReLU(inplace=True) 89 | self.downsample = downsample 90 | self.stride = stride 91 | 92 | def forward(self, x): 93 | residual = x 94 | 95 | out = self.conv1(x) 96 | out = self.bn1(out) 97 | out = self.relu(out) 98 | 99 | out = self.conv2(out) 100 | out = self.bn2(out) 101 | out = self.relu(out) 102 | 103 | out = self.conv3(out) 104 | out = self.bn3(out) 105 | 106 | if self.downsample is not None: 107 | residual = self.downsample(x) 108 | 109 | out += residual 110 | out = self.relu(out) 111 | 112 | return out 113 | 114 | 115 | class pisnet(nn.Module): 116 | def __init__(self, last_stride=2, block=Bottleneck, layers=[3, 4, 6, 3], has_non_local="no", sia_reg="no", pyramid="no"): 117 | self.inplanes = 64 118 | super().__init__() 119 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 120 | bias=False) 121 | self.bn1 = nn.BatchNorm2d(64) 122 | # self.relu = nn.ReLU(inplace=True) # add missed relu 123 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 124 | self.layer1 = self._make_layer(block, 64, layers[0]) 125 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 126 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 127 | self.layer4 = self._make_layer( 128 | block, 512, layers[3], stride=last_stride) 129 | print("has_non_local:" + has_non_local) 130 | self.has_non_local = has_non_local 131 | self.pyramid = pyramid 132 | self.Query_Guided_Attention = Query_Guided_Attention(in_channels=2048) 133 | self.Query_Guided_Attention.apply(weights_init_kaiming) 134 | self.sia_reg = sia_reg 135 | 136 | 137 | def _make_layer(self, block, planes, blocks, stride=1): 138 | downsample = None 139 | if stride != 1 or self.inplanes != planes * block.expansion: 140 | downsample = nn.Sequential( 141 | nn.Conv2d(self.inplanes, planes * block.expansion, 142 | kernel_size=1, stride=stride, bias=False), 143 | nn.BatchNorm2d(planes * block.expansion), 144 | ) 145 | 146 | layers = [] 147 | layers.append(block(self.inplanes, planes, stride, downsample)) 148 | self.inplanes = planes * block.expansion 149 | for i in range(1, blocks): 150 | layers.append(block(self.inplanes, planes)) 151 | 152 | return nn.Sequential(*layers) 153 | 154 | def forward(self, x_g, x, x_g2=[], is_first=False): 155 | 156 | 157 | x = self.conv1(x) 158 | x_g = self.conv1(x_g) 159 | 160 | x = self.bn1(x) 161 | x_g = self.bn1(x_g) 162 | 163 | x = self.maxpool(x) 164 | x_g = self.maxpool(x_g) 165 | 166 | x = self.layer1(x) 167 | x_g = self.layer1(x_g) 168 | 169 | x = self.layer2(x) 170 | x_g = self.layer2(x_g) 171 | 172 | x = self.layer3(x) 173 | x_g = self.layer3(x_g) 174 | 175 | x = self.layer4(x) 176 | x_g = self.layer4(x_g) 177 | 178 | if not isinstance(x_g2, list): 179 | 180 | x_g2 = self.conv1(x_g2) 181 | x_g2 = self.bn1(x_g2) 182 | x_g2 = self.maxpool(x_g2) 183 | x_g2 = self.layer1(x_g2) 184 | x_g2 = self.layer2(x_g2) 185 | x_g2 = self.layer3(x_g2) 186 | x_g2 = 
self.layer4(x_g2) 187 | 188 | x1, attention1 = self.Query_Guided_Attention(x, x_g, attention='x', pyramid=self.pyramid) 189 | 190 | if not isinstance(x_g2, list): 191 | x2, attention2 = self.Query_Guided_Attention(x, x_g2, attention='x', pyramid=self.pyramid) 192 | if self.sia_reg == "yes": 193 | rec_x_g = feature_corruption(x_g, x_g2.detach()) 194 | x3, attention3 = self.Query_Guided_Attention(x1, rec_x_g, attention='x_g', pyramid=self.pyramid) 195 | else: 196 | x2 = [] 197 | attention2 = [] 198 | x3 = [] 199 | attention3 = [] 200 | 201 | if isinstance(is_first, tuple): 202 | x1[0, :, :, :] = x_g[0, :, :, :] 203 | 204 | return x1, attention1, x2, attention2, x_g, x3, attention3 205 | 206 | def load_param(self, model_path): 207 | param_dict = torch.load(model_path) 208 | for i in param_dict: 209 | if 'fc' in i: 210 | continue 211 | self.state_dict()[i].copy_(param_dict[i]) 212 | 213 | def random_init(self): 214 | for m in self.modules(): 215 | if isinstance(m, nn.Conv2d): 216 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 217 | m.weight.data.normal_(0, math.sqrt(2. / n)) 218 | elif isinstance(m, nn.BatchNorm2d): 219 | m.weight.data.fill_(1) 220 | m.bias.data.zero_() -------------------------------------------------------------------------------- /modeling/backbones/resnet.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import math 8 | 9 | import torch 10 | from torch import nn 11 | 12 | 13 | def conv3x3(in_planes, out_planes, stride=1): 14 | """3x3 convolution with padding""" 15 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 16 | padding=1, bias=False) 17 | 18 | 19 | class BasicBlock(nn.Module): 20 | expansion = 1 21 | 22 | def __init__(self, inplanes, planes, stride=1, downsample=None): 23 | super(BasicBlock, self).__init__() 24 | self.conv1 = conv3x3(inplanes, planes, stride) 25 | self.bn1 = nn.BatchNorm2d(planes) 26 | self.relu = nn.ReLU(inplace=True) 27 | self.conv2 = conv3x3(planes, planes) 28 | self.bn2 = nn.BatchNorm2d(planes) 29 | self.downsample = downsample 30 | self.stride = stride 31 | 32 | def forward(self, x): 33 | residual = x 34 | 35 | out = self.conv1(x) 36 | out = self.bn1(out) 37 | out = self.relu(out) 38 | 39 | out = self.conv2(out) 40 | out = self.bn2(out) 41 | 42 | if self.downsample is not None: 43 | residual = self.downsample(x) 44 | 45 | out += residual 46 | out = self.relu(out) 47 | 48 | return out 49 | 50 | 51 | class Bottleneck(nn.Module): 52 | expansion = 4 53 | 54 | def __init__(self, inplanes, planes, stride=1, downsample=None): 55 | super(Bottleneck, self).__init__() 56 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 57 | self.bn1 = nn.BatchNorm2d(planes) 58 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 59 | padding=1, bias=False) 60 | self.bn2 = nn.BatchNorm2d(planes) 61 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 62 | self.bn3 = nn.BatchNorm2d(planes * 4) 63 | self.relu = nn.ReLU(inplace=True) 64 | self.downsample = downsample 65 | self.stride = stride 66 | 67 | def forward(self, x): 68 | residual = x 69 | 70 | out = self.conv1(x) 71 | out = self.bn1(out) 72 | out = self.relu(out) 73 | 74 | out = self.conv2(out) 75 | out = self.bn2(out) 76 | out = self.relu(out) 77 | 78 | out = self.conv3(out) 79 | out = self.bn3(out) 80 | 81 | if self.downsample is not None: 82 | residual = self.downsample(x) 83 | 84 | 
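        # shortcut connection: identity, or the 1x1-conv downsample when the shape changes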
out += residual 85 | out = self.relu(out) 86 | 87 | return out 88 | 89 | 90 | class ResNet(nn.Module): 91 | def __init__(self, last_stride=2, block=Bottleneck, layers=[3, 4, 6, 3]): 92 | self.inplanes = 64 93 | super().__init__() 94 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 95 | bias=False) 96 | self.bn1 = nn.BatchNorm2d(64) 97 | # self.relu = nn.ReLU(inplace=True) # add missed relu 98 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 99 | self.layer1 = self._make_layer(block, 64, layers[0]) 100 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 101 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 102 | self.layer4 = self._make_layer( 103 | block, 512, layers[3], stride=last_stride) 104 | 105 | def _make_layer(self, block, planes, blocks, stride=1): 106 | downsample = None 107 | if stride != 1 or self.inplanes != planes * block.expansion: 108 | downsample = nn.Sequential( 109 | nn.Conv2d(self.inplanes, planes * block.expansion, 110 | kernel_size=1, stride=stride, bias=False), 111 | nn.BatchNorm2d(planes * block.expansion), 112 | ) 113 | 114 | layers = [] 115 | layers.append(block(self.inplanes, planes, stride, downsample)) 116 | self.inplanes = planes * block.expansion 117 | for i in range(1, blocks): 118 | layers.append(block(self.inplanes, planes)) 119 | 120 | return nn.Sequential(*layers) 121 | 122 | def forward(self, x): 123 | x = self.conv1(x) 124 | x = self.bn1(x) 125 | # x = self.relu(x) # add missed relu 126 | x = self.maxpool(x) 127 | 128 | x = self.layer1(x) 129 | x = self.layer2(x) 130 | x = self.layer3(x) 131 | x = self.layer4(x) 132 | 133 | return x 134 | 135 | def load_param(self, model_path): 136 | param_dict = torch.load(model_path) 137 | for i in param_dict: 138 | if 'fc' in i: 139 | continue 140 | self.state_dict()[i].copy_(param_dict[i]) 141 | 142 | def random_init(self): 143 | for m in self.modules(): 144 | if isinstance(m, nn.Conv2d): 145 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 146 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 147 | elif isinstance(m, nn.BatchNorm2d): 148 | m.weight.data.fill_(1) 149 | m.bias.data.zero_() 150 | 151 | -------------------------------------------------------------------------------- /modeling/baseline.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import torch 8 | from torch import nn 9 | 10 | from .backbones.resnet import ResNet, BasicBlock, Bottleneck 11 | 12 | def weights_init_kaiming(m): 13 | classname = m.__class__.__name__ 14 | if classname.find('Linear') != -1: 15 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out') 16 | nn.init.constant_(m.bias, 0.0) 17 | elif classname.find('Conv') != -1: 18 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in') 19 | if m.bias is not None: 20 | nn.init.constant_(m.bias, 0.0) 21 | elif classname.find('BatchNorm') != -1: 22 | if m.affine: 23 | nn.init.constant_(m.weight, 1.0) 24 | nn.init.constant_(m.bias, 0.0) 25 | 26 | 27 | def weights_init_classifier(m): 28 | classname = m.__class__.__name__ 29 | if classname.find('Linear') != -1: 30 | nn.init.normal_(m.weight, std=0.001) 31 | if m.bias: 32 | nn.init.constant_(m.bias, 0.0) 33 | 34 | 35 | class Baseline(nn.Module): 36 | in_planes = 2048 37 | 38 | def __init__(self, num_classes, last_stride, model_path, neck, neck_feat, model_name, pretrain_choice): 39 | super(Baseline, self).__init__() 40 | # 41 | -------------------------------------------------------------------------------- /pi_cuhk.sh: -------------------------------------------------------------------------------- 1 | PROJECT_ROOT_DIR=/root/PI-ReID 2 | DATASETS_ROOT_DIR=$PROJECT_ROOT_DIR/datasets 3 | PRETRAINED_PATH=$PROJECT_ROOT_DIR/pretrained/cuhk/resnet50_model_120.pth 4 | OUTPUT=$PROJECT_ROOT_DIR/output/cuhk 5 | Pre_Index_DIR=$PROJECT_ROOT_DIR/pre_index_dir/cuhk_pre_index.json 6 | 7 | python3 tools/pre_selection.py --config_file='configs/softmax_triplet_ftc.yml' MODEL.DEVICE_ID "('1')" DATASETS.NAMES "('cuhk')" \ 8 | DATASETS.ROOT_DIR $DATASETS_ROOT_DIR MODEL.PRETRAIN_CHOICE "('self')" \ 9 | TEST.WEIGHT $PRETRAINED_PATH \ 10 | OUTPUT_DIR $OUTPUT \ 11 | Pre_Index_DIR $Pre_Index_DIR 12 | 13 | python3 tools/train.py --config_file='configs/softmax_triplet_ftc.yml' MODEL.DEVICE_ID "('1')" DATASETS.NAMES "('cuhk')" DATASETS.ROOT_DIR $DATASETS_ROOT_DIR \ 14 | OUTPUT_DIR $OUTPUT SOLVER.BASE_LR 0.000035 TEST.PAIR "no" SOLVER.IMS_PER_BATCH 64 \ 15 | MODEL.WHOLE_MODEL_TRAIN "no" MODEL.PYRAMID "s2" MODEL.SIA_REG "yes" MODEL.GAMMA 1.0 SOLVER.MARGIN 0.1 MODEL.BETA 0.5 DATASETS.TRAIN_ANNO 1 SOLVER.EVAL_PERIOD 10 SOLVER.MAX_EPOCHS 50 \ 16 | MODEL.PRETRAIN_PATH $PRETRAINED_PATH \ 17 | Pre_Index_DIR $Pre_Index_DIR 18 | 19 | -------------------------------------------------------------------------------- /pi_prw.sh: -------------------------------------------------------------------------------- 1 | PROJECT_ROOT_DIR=/root/PI-ReID 2 | DATASETS_ROOT_DIR=$PROJECT_ROOT_DIR/datasets 3 | PRETRAINED_PATH=$PROJECT_ROOT_DIR/pretrained/prw/resnet50_model_120.pth 4 | OUTPUT=$PROJECT_ROOT_DIR/output/prw 5 | Pre_Index_DIR=$PROJECT_ROOT_DIR/pre_index_dir/prw_pre_index.json 6 | 7 | python tools/pre_selection.py --config_file='configs/softmax_triplet_ft.yml' MODEL.DEVICE_ID "('3')" DATASETS.NAMES "('prw')" \ 8 | DATASETS.ROOT_DIR $DATASETS_ROOT_DIR MODEL.PRETRAIN_CHOICE "('self')" \ 9 | TEST.WEIGHT $PRETRAINED_PATH \ 10 | OUTPUT_DIR $OUTPUT \ 11 | Pre_Index_DIR $Pre_Index_DIR 12 | 13 | python3 tools/train.py 
--config_file='configs/softmax_triplet_ft.yml' MODEL.DEVICE_ID "('3')" DATASETS.NAMES "('prw')" DATASETS.ROOT_DIR $DATASETS_ROOT_DIR \ 14 | OUTPUT_DIR $OUTPUT SOLVER.BASE_LR 0.00035 TEST.PAIR "no" SOLVER.IMS_PER_BATCH 64 \ 15 | MODEL.WHOLE_MODEL_TRAIN "no" MODEL.PYRAMID "s2" MODEL.SIA_REG "yes" MODEL.GAMMA 1.0 SOLVER.MARGIN 0.1 MODEL.BETA 0.5 DATASETS.TRAIN_ANNO 1 SOLVER.EVAL_PERIOD 15 SOLVER.MAX_EPOCHS 50 \ 16 | MODEL.PRETRAIN_PATH $PRETRAINED_PATH \ 17 | Pre_Index_DIR $Pre_Index_DIR 18 | -------------------------------------------------------------------------------- /pre_select_cuhk.sh: -------------------------------------------------------------------------------- 1 | 2 | python3 tools/pre_selection.py --config_file='configs/softmax_triplet_ftc.yml' MODEL.DEVICE_ID "('0')" DATASETS.NAMES "('cuhk')" \ 3 | DATASETS.ROOT_DIR "('/root/person_search/dataset')" MODEL.PRETRAIN_CHOICE "('self')" \ 4 | TEST.WEIGHT "('/root/person_search/trained/strong_baseline/cuhk_all_trick_1/resnet50_model_120.pth')" \ 5 | OUTPUT_DIR "('/root/person_search/trained/multi_person/cuhk_all_trick_1')" 6 | -------------------------------------------------------------------------------- /solver/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | from .build import make_optimizer, make_optimizer_with_center 8 | from .lr_scheduler import WarmupMultiStepLR -------------------------------------------------------------------------------- /solver/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/X-BrainLab/PI-ReID/76755f094fe6d6a7f195cd4aaaddc1b6beff8d0f/solver/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /solver/__pycache__/build.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/X-BrainLab/PI-ReID/76755f094fe6d6a7f195cd4aaaddc1b6beff8d0f/solver/__pycache__/build.cpython-36.pyc -------------------------------------------------------------------------------- /solver/__pycache__/lr_scheduler.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/X-BrainLab/PI-ReID/76755f094fe6d6a7f195cd4aaaddc1b6beff8d0f/solver/__pycache__/lr_scheduler.cpython-36.pyc -------------------------------------------------------------------------------- /solver/build.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import torch 8 | 9 | 10 | def make_optimizer(cfg, model): 11 | params = [] 12 | for key, value in model.named_parameters(): 13 | if not value.requires_grad: 14 | continue 15 | lr = cfg.SOLVER.BASE_LR 16 | weight_decay = cfg.SOLVER.WEIGHT_DECAY 17 | if "bias" in key: 18 | lr = cfg.SOLVER.BASE_LR * cfg.SOLVER.BIAS_LR_FACTOR 19 | weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS 20 | params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}] 21 | if cfg.SOLVER.OPTIMIZER_NAME == 'SGD': 22 | optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(params, momentum=cfg.SOLVER.MOMENTUM) 23 | else: 24 | optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(params) 25 | return optimizer 26 | 27 | 28 | def 
make_optimizer_with_center(cfg, model, center_criterion): 29 | params = [] 30 | for key, value in model.named_parameters(): 31 | if not value.requires_grad: 32 | continue 33 | lr = cfg.SOLVER.BASE_LR 34 | weight_decay = cfg.SOLVER.WEIGHT_DECAY 35 | if "bias" in key: 36 | lr = cfg.SOLVER.BASE_LR * cfg.SOLVER.BIAS_LR_FACTOR 37 | weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS 38 | params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}] 39 | if cfg.SOLVER.OPTIMIZER_NAME == 'SGD': 40 | optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(params, momentum=cfg.SOLVER.MOMENTUM) 41 | else: 42 | optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(params) 43 | optimizer_center = torch.optim.SGD(center_criterion.parameters(), lr=cfg.SOLVER.CENTER_LR) 44 | return optimizer, optimizer_center 45 | -------------------------------------------------------------------------------- /solver/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | from bisect import bisect_right 7 | import torch 8 | 9 | 10 | # FIXME ideally this would be achieved with a CombinedLRScheduler, 11 | # separating MultiStepLR with WarmupLR 12 | # but the current LRScheduler design doesn't allow it 13 | 14 | class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler): 15 | def __init__( 16 | self, 17 | optimizer, 18 | milestones, 19 | gamma=0.1, 20 | warmup_factor=1.0 / 3, 21 | warmup_iters=500, 22 | warmup_method="linear", 23 | last_epoch=-1, 24 | ): 25 | if not list(milestones) == sorted(milestones): 26 | raise ValueError( 27 | "Milestones should be a list of" " increasing integers. Got {}", 28 | milestones, 29 | ) 30 | 31 | if warmup_method not in ("constant", "linear"): 32 | raise ValueError( 33 | "Only 'constant' or 'linear' warmup_method accepted" 34 | "got {}".format(warmup_method) 35 | ) 36 | self.milestones = milestones 37 | self.gamma = gamma 38 | self.warmup_factor = warmup_factor 39 | self.warmup_iters = warmup_iters 40 | self.warmup_method = warmup_method 41 | super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch) 42 | 43 | def get_lr(self): 44 | warmup_factor = 1 45 | if self.last_epoch < self.warmup_iters: 46 | if self.warmup_method == "constant": 47 | warmup_factor = self.warmup_factor 48 | elif self.warmup_method == "linear": 49 | alpha = self.last_epoch / self.warmup_iters 50 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha 51 | return [ 52 | base_lr 53 | * warmup_factor 54 | * self.gamma ** bisect_right(self.milestones, self.last_epoch) 55 | for base_lr in self.base_lrs 56 | ] 57 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | -------------------------------------------------------------------------------- /tests/lr_scheduler_test.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import unittest 3 | 4 | import torch 5 | from torch import nn 6 | 7 | sys.path.append('.') 8 | from solver.lr_scheduler import WarmupMultiStepLR 9 | from solver.build import make_optimizer 10 | from config import cfg 11 | 12 | 13 | class MyTestCase(unittest.TestCase): 14 | def test_something(self): 15 | net = nn.Linear(10, 10) 16 | optimizer = 
make_optimizer(cfg, net) 17 | lr_scheduler = WarmupMultiStepLR(optimizer, [20, 40], warmup_iters=10) 18 | for i in range(50): 19 | lr_scheduler.step() 20 | for j in range(3): 21 | print(i, lr_scheduler.get_lr()[0]) 22 | optimizer.step() 23 | 24 | 25 | if __name__ == '__main__': 26 | unittest.main() 27 | -------------------------------------------------------------------------------- /tools/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | -------------------------------------------------------------------------------- /tools/pre_selection.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import argparse 8 | import os 9 | import sys 10 | from os import mkdir 11 | 12 | import torch 13 | from torch.backends import cudnn 14 | import json 15 | 16 | sys.path.append('.') 17 | from config import cfg 18 | from data import make_data_loader 19 | from engine.inference import inference 20 | from modeling import build_model_pre 21 | from utils.logger import setup_logger 22 | 23 | import torch 24 | import torch.nn as nn 25 | from ignite.engine import Engine 26 | 27 | from utils.reid_metric import R1_mAP, pre_selection_index 28 | 29 | def create_supervised_evaluator(model, metrics, 30 | device=None): 31 | """ 32 | Factory function for creating an evaluator for supervised models 33 | 34 | Args: 35 | model (`torch.nn.Module`): the model to train 36 | metrics (dict of str - :class:`ignite.metrics.Metric`): a map of metric names to Metrics 37 | device (str, optional): device type specification (default: None). 38 | Applies to both model and batches. 39 | Returns: 40 | Engine: an evaluator engine with supervised inference function 41 | """ 42 | if device: 43 | if torch.cuda.device_count() > 1: 44 | model = nn.DataParallel(model) 45 | model.to(device) 46 | 47 | def _inference(engine, batch): 48 | model.eval() 49 | with torch.no_grad(): 50 | data, pids, camids = batch 51 | data = data.to(device) if torch.cuda.device_count() >= 1 else data 52 | feat = model(data) 53 | return feat, pids, camids 54 | 55 | engine = Engine(_inference) 56 | 57 | for name, metric in metrics.items(): 58 | metric.attach(engine, name) 59 | 60 | return engine 61 | 62 | # def inference( 63 | # cfg, 64 | # model, 65 | # val_loader, 66 | # num_query 67 | # ): 68 | # device = cfg.MODEL.DEVICE 69 | # 70 | # logger = logging.getLogger("reid_baseline.inference") 71 | # logger.info("Enter inferencing") 72 | # if cfg.TEST.RE_RANKING == 'no': 73 | # print("Create evaluator") 74 | # evaluator = create_supervised_evaluator(model, metrics={'r1_mAP': R1_mAP(num_query, max_rank=100, feat_norm=cfg.TEST.FEAT_NORM)}, 75 | # device=device) 76 | # elif cfg.TEST.RE_RANKING == 'yes': 77 | # print("Create evaluator for reranking") 78 | # evaluator = create_supervised_evaluator(model, metrics={'r1_mAP': R1_mAP_reranking(num_query, max_rank=100, feat_norm=cfg.TEST.FEAT_NORM)}, 79 | # device=device) 80 | # else: 81 | # print("Unsupported re_ranking config. 
Only support for no or yes, but got {}.".format(cfg.TEST.RE_RANKING)) 82 | # 83 | # evaluator.run(val_loader) 84 | # cmc, mAP, _ = evaluator.state.metrics['r1_mAP'] 85 | # logger.info('Validation Results') 86 | # logger.info("mAP: {:.1%}".format(mAP)) 87 | # for r in [1, 5, 10, 100]: 88 | # logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(r, cmc[r - 1])) 89 | 90 | 91 | 92 | def main(): 93 | parser = argparse.ArgumentParser(description="ReID Baseline Inference") 94 | parser.add_argument( 95 | "--config_file", default="", help="path to config file", type=str 96 | ) 97 | parser.add_argument("opts", help="Modify config options using the command-line", default=None, 98 | nargs=argparse.REMAINDER) 99 | 100 | args = parser.parse_args() 101 | 102 | num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1 103 | 104 | if args.config_file != "": 105 | cfg.merge_from_file(args.config_file) 106 | cfg.merge_from_list(args.opts) 107 | cfg.freeze() 108 | 109 | output_dir = cfg.OUTPUT_DIR 110 | if output_dir and not os.path.exists(output_dir): 111 | mkdir(output_dir) 112 | 113 | if cfg.MODEL.DEVICE == "cuda": 114 | os.environ['CUDA_VISIBLE_DEVICES'] = cfg.MODEL.DEVICE_ID 115 | cudnn.benchmark = True 116 | 117 | train_loader, val_loader, num_query, num_classes = make_data_loader(cfg) 118 | model = build_model_pre(cfg, num_classes) 119 | model.load_param(cfg.TEST.WEIGHT) 120 | 121 | # inference(cfg, model, val_loader, num_query) 122 | device = cfg.MODEL.DEVICE 123 | 124 | evaluator = create_supervised_evaluator(model, metrics={ 125 | 'pre_selection_index': pre_selection_index(num_query, max_rank=100, feat_norm=cfg.TEST.FEAT_NORM)}, 126 | device=device) 127 | 128 | evaluator.run(val_loader) 129 | 130 | index = evaluator.state.metrics['pre_selection_index'] 131 | 132 | with open(cfg.Pre_Index_DIR, 'w+') as f: 133 | json.dump(index.tolist(), f) 134 | 135 | print("Pre_Selection_Done") 136 | 137 | if __name__ == '__main__': 138 | main() 139 | -------------------------------------------------------------------------------- /tools/train.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import argparse 8 | import os 9 | import sys 10 | import torch 11 | 12 | from torch.backends import cudnn 13 | 14 | sys.path.append('.') 15 | from config import cfg 16 | from data import make_data_loader, make_data_loader_train 17 | from engine.trainer import do_train, do_train_with_center 18 | from modeling import build_model 19 | from layers import make_loss, make_loss_with_center 20 | from solver import make_optimizer, make_optimizer_with_center, WarmupMultiStepLR 21 | 22 | from utils.logger import setup_logger 23 | 24 | 25 | 26 | 27 | 28 | 29 | def train(cfg): 30 | # prepare dataset 31 | # train_loader, val_loader, num_query, num_classes = make_data_loader(cfg) 32 | train_loader, val_loader, num_query, num_classes = make_data_loader_train(cfg) 33 | 34 | 35 | # prepare model 36 | if 'prw' in cfg.DATASETS.NAMES: 37 | num_classes = 483 38 | elif "market1501" in cfg.DATASETS.NAMES: 39 | num_classes = 751 40 | elif "duke" in cfg.DATASETS.NAMES: 41 | num_classes = 702 42 | elif "cuhk" in cfg.DATASETS.NAMES: 43 | num_classes = 5532 44 | 45 | 46 | model = build_model(cfg, num_classes) 47 | 48 | if cfg.MODEL.IF_WITH_CENTER == 'no': 49 | print('Train without center loss, the loss type is', cfg.MODEL.METRIC_LOSS_TYPE) 50 | optimizer = make_optimizer(cfg, model) 51 | # scheduler = 
WarmupMultiStepLR(optimizer, cfg.SOLVER.STEPS, cfg.SOLVER.GAMMA, cfg.SOLVER.WARMUP_FACTOR, 52 | # cfg.SOLVER.WARMUP_ITERS, cfg.SOLVER.WARMUP_METHOD) 53 | 54 | loss_func = make_loss(cfg, num_classes) # modified by gu 55 | 56 | # Add for using self trained model 57 | if cfg.MODEL.PRETRAIN_CHOICE == 'self': 58 | # start_epoch = eval(cfg.MODEL.PRETRAIN_PATH.split('/')[-1].split('.')[0].split('_')[-1]) 59 | start_epoch = 0 60 | print('Start epoch:', start_epoch) 61 | path_to_optimizer = cfg.MODEL.PRETRAIN_PATH.replace('model', 'optimizer') 62 | print('Path to the checkpoint of optimizer:', path_to_optimizer) 63 | 64 | 65 | pretrained_dic = torch.load(cfg.MODEL.PRETRAIN_PATH).state_dict() 66 | model_dict = model.state_dict() 67 | 68 | model_dict.update(pretrained_dic) 69 | model.load_state_dict(model_dict) 70 | 71 | if cfg.MODEL.WHOLE_MODEL_TRAIN == "no": 72 | for name, value in model.named_parameters(): 73 | if "Query_Guided_Attention" not in name and "non_local" not in name and "classifier_attention" not in name: 74 | value.requires_grad = False 75 | optimizer = make_optimizer(cfg, model) 76 | # else: 77 | # cfg.SOLVER.BASE_LR = 0.0000035 78 | 79 | # optimizer.load_state_dict(torch.load(path_to_optimizer)) 80 | # ##### 81 | # for state in optimizer.state.values(): 82 | # for k, v in state.items(): 83 | # if isinstance(v, torch.Tensor): 84 | # state[k] = v.cuda() 85 | # ##### 86 | scheduler = WarmupMultiStepLR(optimizer, cfg.SOLVER.STEPS, cfg.SOLVER.GAMMA, cfg.SOLVER.WARMUP_FACTOR, 87 | cfg.SOLVER.WARMUP_ITERS, cfg.SOLVER.WARMUP_METHOD) 88 | elif cfg.MODEL.PRETRAIN_CHOICE == 'imagenet': 89 | start_epoch = 0 90 | scheduler = WarmupMultiStepLR(optimizer, cfg.SOLVER.STEPS, cfg.SOLVER.GAMMA, cfg.SOLVER.WARMUP_FACTOR, 91 | cfg.SOLVER.WARMUP_ITERS, cfg.SOLVER.WARMUP_METHOD) 92 | else: 93 | print('Only support pretrain_choice for imagenet and self, but got {}'.format(cfg.MODEL.PRETRAIN_CHOICE)) 94 | 95 | arguments = {} 96 | 97 | do_train( 98 | cfg, 99 | model, 100 | train_loader, 101 | val_loader, 102 | optimizer, 103 | scheduler, # modify for using self trained model 104 | loss_func, 105 | num_query, 106 | start_epoch # add for using self trained model 107 | ) 108 | elif cfg.MODEL.IF_WITH_CENTER == 'yes': 109 | print('Train with center loss, the loss type is', cfg.MODEL.METRIC_LOSS_TYPE) 110 | loss_func, center_criterion = make_loss_with_center(cfg, num_classes) # modified by gu 111 | optimizer, optimizer_center = make_optimizer_with_center(cfg, model, center_criterion) 112 | # scheduler = WarmupMultiStepLR(optimizer, cfg.SOLVER.STEPS, cfg.SOLVER.GAMMA, cfg.SOLVER.WARMUP_FACTOR, 113 | # cfg.SOLVER.WARMUP_ITERS, cfg.SOLVER.WARMUP_METHOD) 114 | 115 | arguments = {} 116 | 117 | # Add for using self trained model 118 | if cfg.MODEL.PRETRAIN_CHOICE == 'self': 119 | start_epoch = eval(cfg.MODEL.PRETRAIN_PATH.split('/')[-1].split('.')[0].split('_')[-1]) 120 | print('Start epoch:', start_epoch) 121 | path_to_optimizer = cfg.MODEL.PRETRAIN_PATH.replace('model', 'optimizer') 122 | print('Path to the checkpoint of optimizer:', path_to_optimizer) 123 | path_to_center_param = cfg.MODEL.PRETRAIN_PATH.replace('model', 'center_param') 124 | print('Path to the checkpoint of center_param:', path_to_center_param) 125 | path_to_optimizer_center = cfg.MODEL.PRETRAIN_PATH.replace('model', 'optimizer_center') 126 | print('Path to the checkpoint of optimizer_center:', path_to_optimizer_center) 127 | model.load_state_dict(torch.load(cfg.MODEL.PRETRAIN_PATH)) 128 | optimizer.load_state_dict(torch.load(path_to_optimizer)) 129 
| ##### 130 | for state in optimizer.state.values(): 131 | for k, v in state.items(): 132 | if isinstance(v, torch.Tensor): 133 | state[k] = v.cuda() 134 | ##### 135 | center_criterion.load_state_dict(torch.load(path_to_center_param)) 136 | optimizer_center.load_state_dict(torch.load(path_to_optimizer_center)) 137 | scheduler = WarmupMultiStepLR(optimizer, cfg.SOLVER.STEPS, cfg.SOLVER.GAMMA, cfg.SOLVER.WARMUP_FACTOR, 138 | cfg.SOLVER.WARMUP_ITERS, cfg.SOLVER.WARMUP_METHOD, start_epoch) 139 | elif cfg.MODEL.PRETRAIN_CHOICE == 'imagenet': 140 | start_epoch = 0 141 | scheduler = WarmupMultiStepLR(optimizer, cfg.SOLVER.STEPS, cfg.SOLVER.GAMMA, cfg.SOLVER.WARMUP_FACTOR, 142 | cfg.SOLVER.WARMUP_ITERS, cfg.SOLVER.WARMUP_METHOD) 143 | else: 144 | print('Only support pretrain_choice for imagenet and self, but got {}'.format(cfg.MODEL.PRETRAIN_CHOICE)) 145 | 146 | do_train_with_center( 147 | cfg, 148 | model, 149 | center_criterion, 150 | train_loader, 151 | val_loader, 152 | optimizer, 153 | optimizer_center, 154 | scheduler, # modify for using self trained model 155 | loss_func, 156 | num_query, 157 | start_epoch # add for using self trained model 158 | ) 159 | else: 160 | print("Unsupported value for cfg.MODEL.IF_WITH_CENTER {}, only support yes or no!\n".format(cfg.MODEL.IF_WITH_CENTER)) 161 | 162 | 163 | def main(): 164 | parser = argparse.ArgumentParser(description="ReID Baseline Training") 165 | parser.add_argument( 166 | "--config_file", default="", help="path to config file", type=str 167 | ) 168 | parser.add_argument("opts", help="Modify config options using the command-line", default=None, 169 | nargs=argparse.REMAINDER) 170 | 171 | args = parser.parse_args() 172 | 173 | num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1 174 | 175 | if args.config_file != "": 176 | cfg.merge_from_file(args.config_file) 177 | cfg.merge_from_list(args.opts) 178 | cfg.freeze() 179 | 180 | output_dir = cfg.OUTPUT_DIR 181 | if output_dir and not os.path.exists(output_dir): 182 | os.makedirs(output_dir) 183 | 184 | logger = setup_logger("reid_baseline", output_dir, 0) 185 | logger.info("Using {} GPUS".format(num_gpus)) 186 | logger.info(args) 187 | 188 | if args.config_file != "": 189 | logger.info("Loaded configuration file {}".format(args.config_file)) 190 | with open(args.config_file, 'r') as cf: 191 | config_str = "\n" + cf.read() 192 | logger.info(config_str) 193 | logger.info("Running with config:\n{}".format(cfg)) 194 | 195 | if cfg.MODEL.DEVICE == "cuda": 196 | os.environ['CUDA_VISIBLE_DEVICES'] = cfg.MODEL.DEVICE_ID # new add by gu 197 | cudnn.benchmark = True 198 | train(cfg) 199 | 200 | 201 | if __name__ == '__main__': 202 | 203 | # model_path = "/raid/home/henrayzhao/person_search/trained/strong_baseline/prw_non_local/resnet50_model_40.pth" 204 | # model_dic = torch.load(model_path) 205 | # print(model_dic) 206 | main() 207 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/X-BrainLab/PI-ReID/76755f094fe6d6a7f195cd4aaaddc1b6beff8d0f/utils/__pycache__/__init__.cpython-36.pyc 
-------------------------------------------------------------------------------- /utils/__pycache__/iotools.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/X-BrainLab/PI-ReID/76755f094fe6d6a7f195cd4aaaddc1b6beff8d0f/utils/__pycache__/iotools.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/logger.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/X-BrainLab/PI-ReID/76755f094fe6d6a7f195cd4aaaddc1b6beff8d0f/utils/__pycache__/logger.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/re_ranking.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/X-BrainLab/PI-ReID/76755f094fe6d6a7f195cd4aaaddc1b6beff8d0f/utils/__pycache__/re_ranking.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/reid_metric.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/X-BrainLab/PI-ReID/76755f094fe6d6a7f195cd4aaaddc1b6beff8d0f/utils/__pycache__/reid_metric.cpython-36.pyc -------------------------------------------------------------------------------- /utils/iotools.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import errno 8 | import json 9 | import os 10 | 11 | import os.path as osp 12 | 13 | 14 | def mkdir_if_missing(directory): 15 | if not osp.exists(directory): 16 | try: 17 | os.makedirs(directory) 18 | except OSError as e: 19 | if e.errno != errno.EEXIST: 20 | raise 21 | 22 | 23 | def check_isfile(path): 24 | isfile = osp.isfile(path) 25 | if not isfile: 26 | print("=> Warning: no file found at '{}' (ignored)".format(path)) 27 | return isfile 28 | 29 | 30 | def read_json(fpath): 31 | with open(fpath, 'r') as f: 32 | obj = json.load(f) 33 | return obj 34 | 35 | 36 | def write_json(obj, fpath): 37 | mkdir_if_missing(osp.dirname(fpath)) 38 | with open(fpath, 'w') as f: 39 | json.dump(obj, f, indent=4, separators=(',', ': ')) 40 | -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import logging 8 | import os 9 | import sys 10 | 11 | 12 | def setup_logger(name, save_dir, distributed_rank): 13 | logger = logging.getLogger(name) 14 | logger.setLevel(logging.DEBUG) 15 | # don't log results for the non-master process 16 | if distributed_rank > 0: 17 | return logger 18 | ch = logging.StreamHandler(stream=sys.stdout) 19 | ch.setLevel(logging.DEBUG) 20 | formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s") 21 | ch.setFormatter(formatter) 22 | logger.addHandler(ch) 23 | 24 | if save_dir: 25 | fh = logging.FileHandler(os.path.join(save_dir, "log.txt"), mode='w') 26 | fh.setLevel(logging.DEBUG) 27 | fh.setFormatter(formatter) 28 | logger.addHandler(fh) 29 | 30 | return logger 31 | -------------------------------------------------------------------------------- /utils/re_ranking.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Fri, 25 May 2018 20:29:09 5 | 6 | @author: luohao 7 | """ 8 | 9 | """ 10 | CVPR2017 paper: Zhong Z, Zheng L, Cao D, et al. Re-ranking Person Re-identification with k-reciprocal Encoding[J]. 2017. 11 | url:http://openaccess.thecvf.com/content_cvpr_2017/papers/Zhong_Re-Ranking_Person_Re-Identification_CVPR_2017_paper.pdf 12 | Matlab version: https://github.com/zhunzhong07/person-re-ranking 13 | """ 14 | 15 | """ 16 | API 17 | 18 | probFea: all feature vectors of the query set (torch tensor) 19 | galFea: all feature vectors of the gallery set (torch tensor) 20 | k1, k2, lambda_value: parameters; the original paper uses (k1=20, k2=6, lambda=0.3) 21 | local_distmat: optional precomputed local distance matrix (numpy array) 22 | only_local: if True, use local_distmat alone as the original distance 23 | """ 24 | 25 | import numpy as np 26 | import torch 27 | 28 | 29 | def re_ranking(probFea, galFea, k1, k2, lambda_value, local_distmat=None, only_local=False): 30 | # if the feature vectors are numpy arrays, convert them to torch tensors first 31 | query_num = probFea.size(0) 32 | all_num = query_num + galFea.size(0) 33 | if only_local: 34 | original_dist = local_distmat 35 | else: 36 | feat = torch.cat([probFea, galFea]) 37 | print('computing original distance') 38 | distmat = torch.pow(feat, 2).sum(dim=1, keepdim=True).expand(all_num, all_num) + \ 39 | torch.pow(feat, 2).sum(dim=1, keepdim=True).expand(all_num, all_num).t() 40 | distmat.addmm_(feat, feat.t(), beta=1, alpha=-2) 41 | original_dist = distmat.cpu().numpy() 42 | del feat 43 | if local_distmat is not None: 44 | original_dist = original_dist + local_distmat 45 | gallery_num = original_dist.shape[0] 46 | original_dist = np.transpose(original_dist / np.max(original_dist, axis=0)) 47 | V = np.zeros_like(original_dist).astype(np.float16) 48 | initial_rank = np.argsort(original_dist).astype(np.int32) 49 | 50 | print('starting re_ranking') 51 | for i in range(all_num): 52 | # k-reciprocal neighbors 53 | forward_k_neigh_index = initial_rank[i, :k1 + 1] 54 | backward_k_neigh_index = initial_rank[forward_k_neigh_index, :k1 + 1] 55 | fi = np.where(backward_k_neigh_index == i)[0] 56 | k_reciprocal_index = forward_k_neigh_index[fi] 57 | k_reciprocal_expansion_index = k_reciprocal_index 58 | for j in range(len(k_reciprocal_index)): 59 | candidate = k_reciprocal_index[j] 60 | candidate_forward_k_neigh_index = initial_rank[candidate, :int(np.around(k1 / 2)) + 1] 61 | candidate_backward_k_neigh_index = initial_rank[candidate_forward_k_neigh_index, 62 | :int(np.around(k1 / 2)) + 1] 63 | fi_candidate = np.where(candidate_backward_k_neigh_index == candidate)[0] 64 | candidate_k_reciprocal_index = candidate_forward_k_neigh_index[fi_candidate] 65 | if len(np.intersect1d(candidate_k_reciprocal_index, k_reciprocal_index)) > 2 / 3 * len( 66 | candidate_k_reciprocal_index): 67 | k_reciprocal_expansion_index = np.append(k_reciprocal_expansion_index, candidate_k_reciprocal_index) 68 | 69 | k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index) 70 | weight = np.exp(-original_dist[i, k_reciprocal_expansion_index]) 71 | V[i, k_reciprocal_expansion_index] = weight / np.sum(weight) 72 | original_dist = original_dist[:query_num, ] 73 | if k2 != 1: 74 | V_qe = np.zeros_like(V, dtype=np.float16) 75 | for i in range(all_num): 76 | V_qe[i, :] = np.mean(V[initial_rank[i, :k2], :], axis=0) 77 | V = V_qe 78 | del V_qe 79 | del initial_rank 80 | invIndex = [] 
81 | for i in range(gallery_num): 82 | invIndex.append(np.where(V[:, i] != 0)[0]) 83 | 84 | jaccard_dist = np.zeros_like(original_dist, dtype=np.float16) 85 | 86 | for i in range(query_num): 87 | temp_min = np.zeros(shape=[1, gallery_num], dtype=np.float16) 88 | indNonZero = np.where(V[i, :] != 0)[0] 89 | indImages = [invIndex[ind] for ind in indNonZero] 90 | for j in range(len(indNonZero)): 91 | temp_min[0, indImages[j]] = temp_min[0, indImages[j]] + np.minimum(V[i, indNonZero[j]], 92 | V[indImages[j], indNonZero[j]]) 93 | jaccard_dist[i] = 1 - temp_min / (2 - temp_min) 94 | 95 | final_dist = jaccard_dist * (1 - lambda_value) + original_dist * lambda_value 96 | del original_dist 97 | del V 98 | del jaccard_dist 99 | final_dist = final_dist[:query_num, query_num:] 100 | return final_dist 101 | 102 | -------------------------------------------------------------------------------- /utils/reid_metric.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import numpy as np 8 | import torch 9 | from ignite.metrics import Metric 10 | 11 | from data.datasets.eval_reid import eval_func 12 | from .re_ranking import re_ranking 13 | 14 | class pre_selection_index(Metric): 15 | def __init__(self, num_query, max_rank=100, feat_norm='yes'): 16 | super(pre_selection_index, self).__init__() 17 | self.num_query = num_query 18 | self.max_rank = max_rank 19 | self.feat_norm = feat_norm 20 | 21 | def reset(self): 22 | self.feats = [] 23 | self.pids = [] 24 | self.camids = [] 25 | 26 | def update(self, output): 27 | feat, pid, camid = output 28 | self.feats.append(feat) 29 | self.pids.extend(np.asarray(pid)) 30 | self.camids.extend(np.asarray(camid)) 31 | 32 | def compute(self): 33 | feats = torch.cat(self.feats, dim=0) 34 | if self.feat_norm == 'yes': 35 | # print("The test feature is normalized") 36 | feats = torch.nn.functional.normalize(feats, dim=1, p=2) 37 | # query 38 | qf = feats[:self.num_query] 39 | # gallery 40 | gf = feats[self.num_query:] 41 | m, n = qf.shape[0], gf.shape[0] 42 | distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \ 43 | torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t() 44 | distmat.addmm_(qf, gf.t(), beta=1, alpha=-2) 45 | distmat = distmat.cpu().numpy() 46 | return np.argsort(distmat, axis=1) 47 | 48 | 49 | class R1_mAP(Metric): 50 | def __init__(self, num_query, max_rank=100, feat_norm='yes'): 51 | super(R1_mAP, self).__init__() 52 | self.num_query = num_query 53 | self.max_rank = max_rank 54 | self.feat_norm = feat_norm 55 | 56 | def reset(self): 57 | self.feats = [] 58 | self.pids = [] 59 | self.camids = [] 60 | 61 | def update(self, output): 62 | feat, pid, camid = output 63 | self.feats.append(feat) 64 | self.pids.extend(np.asarray(pid)) 65 | self.camids.extend(np.asarray(camid)) 66 | 67 | def compute(self): 68 | feats = torch.cat(self.feats, dim=0) 69 | if self.feat_norm == 'yes': 70 | # print("The test feature is normalized") 71 | feats = torch.nn.functional.normalize(feats, dim=1, p=2) 72 | # query 73 | qf = feats[:self.num_query] 74 | q_pids = np.asarray(self.pids[:self.num_query]) 75 | q_camids = np.asarray(self.camids[:self.num_query]) 76 | # gallery 77 | gf = feats[self.num_query:] 78 | g_pids = np.asarray(self.pids[self.num_query:]) 79 | g_camids = np.asarray(self.camids[self.num_query:]) 80 | m, n = qf.shape[0], gf.shape[0] 81 | distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \ 82 | torch.pow(gf, 
2).sum(dim=1, keepdim=True).expand(n, m).t() 83 | distmat.addmm_(qf, gf.t(), beta=1, alpha=-2) 84 | distmat = distmat.cpu().numpy() 85 | 86 | cmc, mAP, q_pid_return = eval_func(distmat, q_pids, g_pids, q_camids, g_camids) 87 | 88 | return cmc, mAP, q_pid_return 89 | 90 | class R1_mAP_pair(Metric): 91 | def __init__(self, num_query, max_rank=100, feat_norm='yes'): 92 | super(R1_mAP_pair, self).__init__() 93 | self.num_query = num_query 94 | self.max_rank = max_rank 95 | self.feat_norm = feat_norm 96 | 97 | def reset(self): 98 | self.scores = [] 99 | self.pids = [] 100 | self.camids = [] 101 | 102 | def update(self, output): 103 | score, pid, camid = output 104 | self.scores.append(score) 105 | self.pids.extend(np.asarray(pid)) 106 | self.camids.extend(np.asarray(camid)) 107 | 108 | def compute(self): 109 | scores = torch.cat(self.scores, dim=0).view(1, -1) 110 | distmat = scores.cpu().numpy() 111 | # print(distmat.shape) 112 | 113 | if distmat.shape[1] == 101: 114 | distmat = distmat[:, 1:] 115 | # query 116 | q_pids = np.asarray(self.pids[:self.num_query]) 117 | q_camids = np.asarray(self.camids[:self.num_query]) 118 | # gallery 119 | g_pids = np.asarray(self.pids[self.num_query:]) 120 | g_camids = np.asarray(self.camids[self.num_query:]) 121 | cmc, mAP = eval_func(distmat, q_pids, g_pids, q_camids, g_camids) 122 | 123 | return cmc, mAP 124 | 125 | 126 | class R1_mAP_reranking(Metric): 127 | def __init__(self, num_query, max_rank=100, feat_norm='yes'): 128 | super(R1_mAP_reranking, self).__init__() 129 | self.num_query = num_query 130 | self.max_rank = max_rank 131 | self.feat_norm = feat_norm 132 | 133 | def reset(self): 134 | self.feats = [] 135 | self.pids = [] 136 | self.camids = [] 137 | 138 | def update(self, output): 139 | feat, pid, camid = output 140 | self.feats.append(feat) 141 | self.pids.extend(np.asarray(pid)) 142 | self.camids.extend(np.asarray(camid)) 143 | 144 | def compute(self): 145 | feats = torch.cat(self.feats, dim=0) 146 | if self.feat_norm == 'yes': 147 | # print("The test feature is normalized") 148 | feats = torch.nn.functional.normalize(feats, dim=1, p=2) 149 | 150 | # query 151 | qf = feats[:self.num_query] 152 | q_pids = np.asarray(self.pids[:self.num_query]) 153 | q_camids = np.asarray(self.camids[:self.num_query]) 154 | # gallery 155 | gf = feats[self.num_query:] 156 | g_pids = np.asarray(self.pids[self.num_query:]) 157 | g_camids = np.asarray(self.camids[self.num_query:]) 158 | # m, n = qf.shape[0], gf.shape[0] 159 | # distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \ 160 | # torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t() 161 | # distmat.addmm_(1, -2, qf, gf.t()) 162 | # distmat = distmat.cpu().numpy() 163 | print("Enter reranking") 164 | distmat = re_ranking(qf, gf, k1=20, k2=6, lambda_value=0.3) 165 | cmc, mAP = eval_func(distmat, q_pids, g_pids, q_camids, g_camids) 166 | 167 | return cmc, mAP --------------------------------------------------------------------------------
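Note (not part of the repository): a minimal usage sketch for utils/re_ranking.py. The random tensors below are hypothetical stand-ins for real query/gallery embeddings, the repo root is assumed to be on sys.path, and k1/k2/lambda_value follow the defaults cited in the file's docstring.

import torch
from utils.re_ranking import re_ranking

qf = torch.randn(10, 2048)   # hypothetical query features (num_query x feat_dim)
gf = torch.randn(100, 2048)  # hypothetical gallery features (num_gallery x feat_dim)
qf = torch.nn.functional.normalize(qf, dim=1, p=2)
gf = torch.nn.functional.normalize(gf, dim=1, p=2)

# final_dist has shape (num_query, num_gallery); smaller means more similar.
final_dist = re_ranking(qf, gf, k1=20, k2=6, lambda_value=0.3)
ranked = final_dist.argsort(axis=1)  # gallery indices per query, best match first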
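Note (not part of the repository): a short sketch of the all-pairs squared-Euclidean-distance trick that pre_selection_index, R1_mAP, and re_ranking above all rely on, namely ||q - g||^2 = ||q||^2 + ||g||^2 - 2*(q . g), computed for every query/gallery pair with one matrix multiply; torch.cdist appears here only to verify the result.

import torch

qf, gf = torch.randn(4, 8), torch.randn(6, 8)  # hypothetical small feature sets
m, n = qf.shape[0], gf.shape[0]

# ||q||^2 replicated across columns plus ||g||^2 replicated across rows...
distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \
          torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t()
# ...minus twice the inner products, applied in place by addmm_.
distmat.addmm_(qf, gf.t(), beta=1, alpha=-2)

# Cross-check against the direct pairwise computation.
assert torch.allclose(distmat, torch.cdist(qf, gf).pow(2), atol=1e-4)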