├── Experiment-all_tricks-tri_center-market.sh ├── LICENCE.md ├── README.md ├── Test-all_tricks-tri_center-feat_after_bn-cos-market.sh ├── config ├── __init__.py └── defaults.py ├── configs ├── baseline.yml ├── softmax.yml ├── softmax_triplet.yml └── softmax_triplet_with_center.yml ├── data ├── __init__.py ├── build.py ├── collate_batch.py ├── datasets │ ├── __init__.py │ ├── bases.py │ ├── cuhk03.py │ ├── dataset_loader.py │ ├── dukemtmcreid.py │ ├── eval_reid.py │ ├── market1501.py │ ├── msmt17.py │ ├── nformer.py │ └── veri.py ├── samplers │ ├── __init__.py │ └── triplet_sampler.py └── transforms │ ├── __init__.py │ ├── build.py │ └── transforms.py ├── engine ├── inference.py └── trainer.py ├── layers ├── __init__.py ├── center_loss.py └── triplet_loss.py ├── modeling ├── __init__.py ├── backbones │ ├── __init__.py │ ├── resnet.py │ ├── resnet_ibn_a.py │ └── senet.py ├── baseline.py ├── model.py └── nformer.py ├── pipeline.jpg ├── solver ├── __init__.py ├── build.py └── lr_scheduler.py ├── tools ├── __init__.py ├── nformer_train.py ├── test.py └── train.py └── utils ├── __init__.py ├── iotools.py ├── logger.py ├── re_ranking.py └── reid_metric.py /Experiment-all_tricks-tri_center-market.sh: -------------------------------------------------------------------------------- 1 | # Experiment all tricks with center loss : 256x128-bs16x4-warmup10-erase0_5-labelsmooth_on-laststride1-bnneck_on-triplet_centerloss0_0005 2 | # Dataset 1: market1501 3 | # imagesize: 256x128 4 | # batchsize: 16x4 5 | # warmup_step 10 6 | # random erase prob 0.5 7 | # labelsmooth: on 8 | # last stride 1 9 | # bnneck on 10 | # with center loss 11 | python3 tools/train.py --config_file='configs/softmax_triplet_with_center.yml' MODEL.DEVICE_ID "('0')" DATASETS.NAMES "('market1501')" DATASETS.ROOT_DIR "('/home/haochen/workspace/project/NFORMER/')" OUTPUT_DIR "('test')" 12 | -------------------------------------------------------------------------------- /LICENCE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) [2019] [HaoLuo] 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # NFormer 2 | 3 | Implementation of NFormer: Robust Person Re-identification with Neighbor Transformer. CVPR2022 4 | 5 | ## Pipeline 6 |
7 | <img src="pipeline.jpg"/> 8 |
9 | 10 | ## Requirements 11 | - Python3 12 | - pytorch>=0.4 13 | - torchvision 14 | - pytorch-ignite=0.1.2 (Note: V0.2.0 may result in an error) 15 | - yacs 16 | ## Hardware 17 | - 1 NVIDIA 3090 Ti 18 | 19 | ## Dataset 20 | Create a directory for the re-ID datasets, either under this repo or outside it. Set the dataset root path in `config/defaults.py`, or override it in the scripts `Experiment-all_tricks-tri_center-market.sh` and `Test-all_tricks-tri_center-feat_after_bn-cos-market.sh`. 21 | #### Market1501 22 | * Download the dataset to `data/` from https://zheng-lab.cecs.anu.edu.au/Project/project_reid.html 23 | * Extract the dataset and rename the folder to `market1501`. The data structure should look like: 24 | 25 | ```bash 26 | |- data 27 | |- market1501 # this folder contains 6 files. 28 | |- bounding_box_test/ 29 | |- bounding_box_train/ 30 | ...... 31 | ``` 32 | 33 | 34 | 35 | ## Training 36 | Download the pretrained [resnet50](https://download.pytorch.org/models/resnet50-19c8e357.pth) model and set its path at [line 3](configs/softmax_triplet_with_center.yml) of the config file. 37 | 38 | Run `Experiment-all_tricks-tri_center-market.sh` to train NFormer on the Market-1501 dataset: 39 | ``` 40 | sh Experiment-all_tricks-tri_center-market.sh 41 | ``` 42 | or 43 | ``` 44 | python3 tools/train.py --config_file='configs/softmax_triplet_with_center.yml' MODEL.DEVICE_ID "('0')" DATASETS.NAMES "('market1501')" DATASETS.ROOT_DIR "('/home/haochen/workspace/project/NFORMER/')" OUTPUT_DIR "('work_dirs')" 45 | ``` 46 | 47 | ## Evaluation 48 | Run `Test-all_tricks-tri_center-feat_after_bn-cos-market.sh` to evaluate NFormer on the Market-1501 dataset. Set `TEST.TEST_NFORMER` to `'yes'` to evaluate the NFormer model, or to `'no'` to evaluate the CNN encoder only. 49 | 50 | ``` 51 | sh Test-all_tricks-tri_center-feat_after_bn-cos-market.sh 52 | ``` 53 | or 54 | ``` 55 | python3 tools/test.py --config_file='configs/softmax_triplet_with_center.yml' MODEL.DEVICE_ID "('0')" DATASETS.NAMES "('market1501')" DATASETS.ROOT_DIR "('/home/haochen/workspace/project/NFORMER')" MODEL.PRETRAIN_CHOICE "('self')" TEST.WEIGHT "('test/nformer_model.pth')" TEST.TEST_NFORMER "('no')" 56 | ``` 57 | 58 | 59 | 60 | ## Acknowledgement 61 | This repo is heavily based on [reid-strong-baseline](https://github.com/michuanhaohao/reid-strong-baseline); thanks for their excellent work.
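
## Note on command-line overrides

The `KEY "value"` pairs passed to `tools/train.py` and `tools/test.py` in the commands above override entries from `config/defaults.py` and the selected YAML file. The sketch below shows how such overrides are typically merged with yacs; it assumes the standard pattern inherited from reid-strong-baseline, so the exact argument parsing in `tools/train.py` may differ slightly.

```python
# Minimal sketch (assumed yacs merge order): defaults.py -> YAML config -> command-line pairs.
from config import cfg  # config/__init__.py exposes the defaults as a yacs CfgNode

cfg.merge_from_file('configs/softmax_triplet_with_center.yml')  # YAML overrides the defaults
cfg.merge_from_list([                                           # CLI pairs override the YAML
    'MODEL.DEVICE_ID', "('0')",
    'DATASETS.NAMES', "('market1501')",
    'OUTPUT_DIR', "('work_dirs')",
])
cfg.freeze()
print(cfg.DATASETS.NAMES, cfg.OUTPUT_DIR)  # -> market1501 work_dirs
```

The `"('value')"` quoting style matches the shell scripts in this repo: yacs decodes each value with `literal_eval`, so `('work_dirs')` simply becomes the string `work_dirs`.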
62 | 63 | ## Citation 64 | ``` 65 | @article{wang2022nformer, 66 | title={NFormer: Robust Person Re-identification with Neighbor Transformer}, 67 | author={Wang, Haochen and Shen, Jiayi and Liu, Yongtuo and Gao, Yan and Gavves, Efstratios}, 68 | journal={arXiv preprint arXiv:2204.09331}, 69 | year={2022} 70 | } 71 | 72 | @InProceedings{Luo_2019_CVPR_Workshops, 73 | author = {Luo, Hao and Gu, Youzhi and Liao, Xingyu and Lai, Shenqi and Jiang, Wei}, 74 | title = {Bag of Tricks and a Strong Baseline for Deep Person Re-Identification}, 75 | booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) Workshops}, 76 | month = {June}, 77 | year = {2019} 78 | } 79 | ``` 80 | 81 | 82 | 83 | -------------------------------------------------------------------------------- /Test-all_tricks-tri_center-feat_after_bn-cos-market.sh: -------------------------------------------------------------------------------- 1 | # Dataset 1: market1501 2 | # imagesize: 256x128 3 | # batchsize: 16x4 4 | # warmup_step 10 5 | # random erase prob 0.5 6 | # labelsmooth: on 7 | # last stride 1 8 | # bnneck on 9 | # with center loss 10 | # without re-ranking 11 | python3 tools/test.py --config_file='configs/softmax_triplet_with_center.yml' MODEL.DEVICE_ID "('0')" DATASETS.NAMES "('market1501')" DATASETS.ROOT_DIR "('/home/haochen/workspace/project/NFORMER')" MODEL.PRETRAIN_CHOICE "('self')" TEST.WEIGHT "('test/nformer_model.pth')" TEST.TEST_NFORMER "('no')" 12 | 13 | -------------------------------------------------------------------------------- /config/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | from .defaults import _C as cfg 8 | -------------------------------------------------------------------------------- /config/defaults.py: -------------------------------------------------------------------------------- 1 | from yacs.config import CfgNode as CN 2 | 3 | # ----------------------------------------------------------------------------- 4 | # Convention about Training / Test specific parameters 5 | # ----------------------------------------------------------------------------- 6 | # Whenever an argument can be either used for training or for testing, the 7 | # corresponding name will be post-fixed by a _TRAIN for a training parameter, 8 | # or _TEST for a test-specific parameter. 9 | # For example, the number of images during training will be 10 | # IMAGES_PER_BATCH_TRAIN, while the number of images for testing will be 11 | # IMAGES_PER_BATCH_TEST 12 | 13 | # ----------------------------------------------------------------------------- 14 | # Config definition 15 | # ----------------------------------------------------------------------------- 16 | 17 | _C = CN() 18 | 19 | _C.MODEL = CN() 20 | # Using cuda or cpu for training 21 | _C.MODEL.DEVICE = "cuda" 22 | # ID number of GPU 23 | _C.MODEL.DEVICE_ID = '0' 24 | # Name of backbone 25 | _C.MODEL.NAME = 'resnet50' 26 | # Last stride of backbone 27 | _C.MODEL.LAST_STRIDE = 1 28 | # Path to pretrained model of backbone 29 | _C.MODEL.PRETRAIN_PATH = '' 30 | # Use ImageNet pretrained model to initialize backbone or use self trained model to initialize the whole model 31 | # Options: 'imagenet' or 'self' 32 | _C.MODEL.PRETRAIN_CHOICE = 'imagenet' 33 | # If train with BNNeck, options: 'bnneck' or 'no' 34 | _C.MODEL.NECK = 'bnneck' 35 | # If train loss include center loss, options: 'yes' or 'no'. 
Loss with center loss has a different optimizer configuration. 36 | _C.MODEL.IF_WITH_CENTER = 'no' 37 | # The loss type of metric loss 38 | # options:['triplet'](without center loss) or ['center','triplet_center'](with center loss) 39 | _C.MODEL.METRIC_LOSS_TYPE = 'triplet' 40 | # For example, if the loss type is cross entropy loss + triplet loss + center loss, 41 | # the setting should be: _C.MODEL.METRIC_LOSS_TYPE = 'triplet_center' and _C.MODEL.IF_WITH_CENTER = 'yes' 42 | 43 | # If train with label smooth, options: 'on', 'off' 44 | _C.MODEL.IF_LABELSMOOTH = 'on' 45 | _C.MODEL.N_EMBD = 256 46 | _C.MODEL.N_HEAD = 2 47 | _C.MODEL.N_LAYER = 4 48 | _C.MODEL.EMBD_PDROP = 0.1 49 | _C.MODEL.ATTN_PDROP = 0.1 50 | _C.MODEL.RESID_PDROP = 0.1 51 | _C.MODEL.AFN = 'gelu' 52 | _C.MODEL.CLF_PDROP = 0.1 53 | _C.MODEL.TOPK = 20 54 | _C.MODEL.LANDMARK = 10 55 | 56 | 57 | # ----------------------------------------------------------------------------- 58 | # INPUT 59 | # ----------------------------------------------------------------------------- 60 | _C.INPUT = CN() 61 | # Size of the image during training 62 | _C.INPUT.SIZE_TRAIN = [384, 128] 63 | # Size of the image during test 64 | _C.INPUT.SIZE_TEST = [384, 128] 65 | # Random probability for image horizontal flip 66 | _C.INPUT.PROB = 0.5 67 | # Random probability for random erasing 68 | _C.INPUT.RE_PROB = 0.5 69 | # Values to be used for image normalization 70 | _C.INPUT.PIXEL_MEAN = [0.485, 0.456, 0.406] 71 | # Values to be used for image normalization 72 | _C.INPUT.PIXEL_STD = [0.229, 0.224, 0.225] 73 | # Value of padding size 74 | _C.INPUT.PADDING = 10 75 | 76 | # ----------------------------------------------------------------------------- 77 | # Dataset 78 | # ----------------------------------------------------------------------------- 79 | _C.DATASETS = CN() 80 | # List of the dataset names for training, as present in paths_catalog.py 81 | _C.DATASETS.NAMES = ('market1501') 82 | # Root directory where datasets are stored (and downloaded if not found) 83 | _C.DATASETS.ROOT_DIR = ('./data') 84 | 85 | # ----------------------------------------------------------------------------- 86 | # DataLoader 87 | # ----------------------------------------------------------------------------- 88 | _C.DATALOADER = CN() 89 | # Number of data loading threads 90 | _C.DATALOADER.NUM_WORKERS = 8 91 | # Sampler for data loading 92 | _C.DATALOADER.SAMPLER = 'softmax' 93 | # Number of instances per identity in one batch 94 | _C.DATALOADER.NUM_INSTANCE = 16 95 | 96 | # ---------------------------------------------------------------------------- # 97 | # Solver 98 | # ---------------------------------------------------------------------------- # 99 | _C.SOLVER = CN() 100 | # Name of optimizer 101 | _C.SOLVER.OPTIMIZER_NAME = "Adam" 102 | # Maximum number of epochs 103 | _C.SOLVER.MAX_EPOCHS = 50 104 | # Maximum number of NFormer epochs 105 | _C.SOLVER.NFORMER_MAX_EPOCHS = 20 106 | # Base learning rate 107 | _C.SOLVER.BASE_LR = 3e-4 108 | # Learning rate factor for bias parameters 109 | _C.SOLVER.BIAS_LR_FACTOR = 2 110 | # Momentum 111 | _C.SOLVER.MOMENTUM = 0.9 112 | # Margin of triplet loss 113 | _C.SOLVER.MARGIN = 0.3 114 | # Margin of cluster loss 115 | _C.SOLVER.CLUSTER_MARGIN = 0.3 116 | # Learning rate of SGD to learn the centers of center loss 117 | _C.SOLVER.CENTER_LR = 0.5 118 | # Balanced weight of center loss 119 | _C.SOLVER.CENTER_LOSS_WEIGHT = 0.0005 120 | # Settings of range loss 121 | _C.SOLVER.RANGE_K = 2 122 | _C.SOLVER.RANGE_MARGIN = 0.3 123 | _C.SOLVER.RANGE_ALPHA = 0 124 |
_C.SOLVER.RANGE_BETA = 1 125 | _C.SOLVER.RANGE_LOSS_WEIGHT = 1 126 | 127 | # Settings of weight decay 128 | _C.SOLVER.WEIGHT_DECAY = 0.0005 129 | _C.SOLVER.WEIGHT_DECAY_BIAS = 0. 130 | 131 | # decay rate of learning rate 132 | _C.SOLVER.GAMMA = 0.1 133 | # decay step of learning rate 134 | _C.SOLVER.STEPS = (30, 55) 135 | 136 | # warm up factor 137 | _C.SOLVER.WARMUP_FACTOR = 1.0 / 3 138 | # iterations of warm up 139 | _C.SOLVER.WARMUP_ITERS = 500 140 | # method of warm up, option: 'constant','linear' 141 | _C.SOLVER.WARMUP_METHOD = "linear" 142 | 143 | # epoch number of saving checkpoints 144 | _C.SOLVER.CHECKPOINT_PERIOD = 50 145 | # iteration of display training log 146 | _C.SOLVER.LOG_PERIOD = 100 147 | # epoch number of validation 148 | _C.SOLVER.EVAL_PERIOD = 50 149 | 150 | # Number of images per batch 151 | # This is global, so if we have 8 GPUs and IMS_PER_BATCH = 16, each GPU will 152 | # see 2 images per batch 153 | _C.SOLVER.IMS_PER_BATCH = 64 154 | 155 | # This is global, so if we have 8 GPUs and IMS_PER_BATCH = 16, each GPU will 156 | # see 2 images per batch 157 | _C.TEST = CN() 158 | # Number of images per batch during test 159 | _C.TEST.IMS_PER_BATCH = 128 160 | # If test with re-ranking, options: 'yes','no' 161 | _C.TEST.RE_RANKING = 'no' 162 | # Path to trained model 163 | _C.TEST.WEIGHT = "" 164 | # Which feature of BNNeck to be used for test, before or after BNNneck, options: 'before' or 'after' 165 | _C.TEST.NECK_FEAT = 'after' 166 | # Whether feature is nomalized before test, if yes, it is equivalent to cosine distance 167 | _C.TEST.FEAT_NORM = 'yes' 168 | # Whether test nformer or encoder only 169 | _C.TEST.TEST_NFORMER = 'yes' 170 | 171 | # Data type durining test 172 | 173 | # ---------------------------------------------------------------------------- # 174 | # Misc options 175 | # ---------------------------------------------------------------------------- # 176 | # Path to checkpoint and saved log of trained model 177 | _C.OUTPUT_DIR = "" 178 | -------------------------------------------------------------------------------- /configs/baseline.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/haoluo/.torch/models/resnet50-19c8e357.pth' 4 | LAST_STRIDE: 2 5 | NECK: 'no' 6 | METRIC_LOSS_TYPE: 'triplet' 7 | IF_LABELSMOOTH: 'off' 8 | IF_WITH_CENTER: 'no' 9 | 10 | 11 | INPUT: 12 | SIZE_TRAIN: [256, 128] 13 | SIZE_TEST: [256, 128] 14 | PROB: 0.5 # random horizontal flip 15 | RE_PROB: 0.0 # random erasing 16 | PADDING: 10 17 | 18 | DATASETS: 19 | NAMES: ('market1501') 20 | 21 | DATALOADER: 22 | SAMPLER: 'softmax_triplet' 23 | NUM_INSTANCE: 4 24 | NUM_WORKERS: 8 25 | 26 | SOLVER: 27 | OPTIMIZER_NAME: 'Adam' 28 | MAX_EPOCHS: 120 29 | BASE_LR: 0.00035 30 | 31 | CLUSTER_MARGIN: 0.3 32 | 33 | CENTER_LR: 0.5 34 | CENTER_LOSS_WEIGHT: 0.0005 35 | 36 | RANGE_K: 2 37 | RANGE_MARGIN: 0.3 38 | RANGE_ALPHA: 0 39 | RANGE_BETA: 1 40 | RANGE_LOSS_WEIGHT: 1 41 | 42 | BIAS_LR_FACTOR: 1 43 | WEIGHT_DECAY: 0.0005 44 | WEIGHT_DECAY_BIAS: 0.0005 45 | IMS_PER_BATCH: 64 46 | 47 | STEPS: [40, 70] 48 | GAMMA: 0.1 49 | 50 | WARMUP_FACTOR: 0.01 51 | WARMUP_ITERS: 0 52 | WARMUP_METHOD: 'linear' 53 | 54 | CHECKPOINT_PERIOD: 40 55 | LOG_PERIOD: 20 56 | EVAL_PERIOD: 40 57 | 58 | TEST: 59 | IMS_PER_BATCH: 128 60 | RE_RANKING: 'no' 61 | WEIGHT: "path" 62 | NECK_FEAT: 'after' 63 | FEAT_NORM: 'yes' 64 | 65 | OUTPUT_DIR: 
"/home/haoluo/log/gu/reid_baseline_review/Opensource_test/market1501/Experiment-all-tricks-256x128-bs16x4-warmup10-erase0_5-labelsmooth_on-laststride1-bnneck_on" 66 | 67 | 68 | -------------------------------------------------------------------------------- /configs/softmax.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_PATH: '/home/haoluo/.torch/models/resnet50-19c8e357.pth' 3 | 4 | 5 | INPUT: 6 | SIZE_TRAIN: [256, 128] 7 | SIZE_TEST: [256, 128] 8 | PROB: 0.5 # random horizontal flip 9 | RE_PROB: 0.5 # random erasing 10 | PADDING: 10 11 | 12 | DATASETS: 13 | NAMES: ('market1501') 14 | 15 | DATALOADER: 16 | SAMPLER: 'softmax' 17 | NUM_WORKERS: 8 18 | 19 | SOLVER: 20 | OPTIMIZER_NAME: 'Adam' 21 | MAX_EPOCHS: 120 22 | BASE_LR: 0.00035 23 | BIAS_LR_FACTOR: 1 24 | WEIGHT_DECAY: 0.0005 25 | WEIGHT_DECAY_BIAS: 0.0005 26 | IMS_PER_BATCH: 64 27 | 28 | STEPS: [30, 55] 29 | GAMMA: 0.1 30 | 31 | WARMUP_FACTOR: 0.01 32 | WARMUP_ITERS: 5 33 | WARMUP_METHOD: 'linear' 34 | 35 | CHECKPOINT_PERIOD: 20 36 | LOG_PERIOD: 20 37 | EVAL_PERIOD: 20 38 | 39 | TEST: 40 | IMS_PER_BATCH: 128 41 | 42 | OUTPUT_DIR: "/home/haoluo/log/reid/market1501/softmax_bs64_256x128" 43 | 44 | 45 | -------------------------------------------------------------------------------- /configs/softmax_triplet.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/haoluo/.torch/models/resnet50-19c8e357.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'on' 6 | IF_WITH_CENTER: 'no' 7 | 8 | 9 | 10 | 11 | INPUT: 12 | SIZE_TRAIN: [256, 128] 13 | SIZE_TEST: [256, 128] 14 | PROB: 0.5 # random horizontal flip 15 | RE_PROB: 0.5 # random erasing 16 | PADDING: 10 17 | 18 | DATASETS: 19 | NAMES: ('market1501') 20 | 21 | DATALOADER: 22 | SAMPLER: 'softmax_triplet' 23 | NUM_INSTANCE: 4 24 | NUM_WORKERS: 8 25 | 26 | SOLVER: 27 | OPTIMIZER_NAME: 'Adam' 28 | MAX_EPOCHS: 120 29 | BASE_LR: 0.00035 30 | 31 | CLUSTER_MARGIN: 0.3 32 | 33 | CENTER_LR: 0.5 34 | CENTER_LOSS_WEIGHT: 0.0005 35 | 36 | RANGE_K: 2 37 | RANGE_MARGIN: 0.3 38 | RANGE_ALPHA: 0 39 | RANGE_BETA: 1 40 | RANGE_LOSS_WEIGHT: 1 41 | 42 | BIAS_LR_FACTOR: 1 43 | WEIGHT_DECAY: 0.0005 44 | WEIGHT_DECAY_BIAS: 0.0005 45 | IMS_PER_BATCH: 64 46 | 47 | STEPS: [40, 70] 48 | GAMMA: 0.1 49 | 50 | WARMUP_FACTOR: 0.01 51 | WARMUP_ITERS: 10 52 | WARMUP_METHOD: 'linear' 53 | 54 | CHECKPOINT_PERIOD: 40 55 | LOG_PERIOD: 20 56 | EVAL_PERIOD: 40 57 | 58 | TEST: 59 | IMS_PER_BATCH: 128 60 | RE_RANKING: 'no' 61 | WEIGHT: "path" 62 | NECK_FEAT: 'after' 63 | FEAT_NORM: 'yes' 64 | 65 | OUTPUT_DIR: "/home/haoluo/log/gu/reid_baseline_review/Opensource_test/market1501/Experiment-all-tricks-256x128-bs16x4-warmup10-erase0_5-labelsmooth_on-laststride1-bnneck_on" 66 | 67 | 68 | -------------------------------------------------------------------------------- /configs/softmax_triplet_with_center.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/haochen/.cache/torch/hub/checkpoints/resnet50-19c8e357.pth' 4 | METRIC_LOSS_TYPE: 'triplet_center' 5 | IF_LABELSMOOTH: 'on' 6 | IF_WITH_CENTER: 'yes' 7 | N_EMBD: 256 8 | N_HEAD: 2 9 | N_LAYER: 4 10 | EMBD_PDROP: 0.1 11 | ATTN_PDROP: 0.1 12 | RESID_PDROP: 0.1 13 | AFN: 'gelu' 14 | CLF_PDROP: 0.1 15 | TOPK: 20 16 | LANDMARK: 5 17 | 18 | 19 | 20 | 21 | INPUT: 22 | SIZE_TRAIN: [256, 128] 23 | SIZE_TEST: [256, 128] 24 | 
PROB: 0.5 # random horizontal flip 25 | RE_PROB: 0.5 # random erasing 26 | PADDING: 10 27 | 28 | DATASETS: 29 | NAMES: ('market1501') 30 | 31 | DATALOADER: 32 | SAMPLER: 'softmax_triplet' 33 | NUM_INSTANCE: 4 34 | NUM_WORKERS: 8 35 | 36 | SOLVER: 37 | OPTIMIZER_NAME: 'Adam' 38 | MAX_EPOCHS: 120 39 | NFORMER_MAX_EPOCHS: 20 40 | BASE_LR: 0.00035 41 | 42 | CLUSTER_MARGIN: 0.3 43 | 44 | CENTER_LR: 0.5 45 | CENTER_LOSS_WEIGHT: 0.0005 46 | 47 | RANGE_K: 2 48 | RANGE_MARGIN: 0.3 49 | RANGE_ALPHA: 0 50 | RANGE_BETA: 1 51 | RANGE_LOSS_WEIGHT: 1 52 | 53 | BIAS_LR_FACTOR: 1 54 | WEIGHT_DECAY: 0.0005 55 | WEIGHT_DECAY_BIAS: 0.0005 56 | IMS_PER_BATCH: 64 57 | 58 | STEPS: [40, 70] 59 | GAMMA: 0.1 60 | 61 | WARMUP_FACTOR: 0.01 62 | WARMUP_ITERS: 10 63 | WARMUP_METHOD: 'linear' 64 | 65 | CHECKPOINT_PERIOD: 40 66 | LOG_PERIOD: 20 67 | EVAL_PERIOD: 40 68 | 69 | TEST: 70 | IMS_PER_BATCH: 128 71 | RE_RANKING: 'no' 72 | WEIGHT: "path" 73 | NECK_FEAT: 'after' 74 | FEAT_NORM: 'yes' 75 | TEST_NFORMER: 'NO' 76 | 77 | OUTPUT_DIR: "work_dirs" 78 | 79 | 80 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | from .build import make_data_loader 8 | -------------------------------------------------------------------------------- /data/build.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | from torch.utils.data import DataLoader 8 | 9 | from .collate_batch import train_collate_fn, val_collate_fn 10 | from .datasets import init_dataset, ImageDataset 11 | from .samplers import RandomIdentitySampler, RandomIdentitySampler_alignedreid # New add by gu 12 | from .transforms import build_transforms 13 | 14 | 15 | def make_data_loader(cfg): 16 | train_transforms = build_transforms(cfg, is_train=True) 17 | val_transforms = build_transforms(cfg, is_train=False) 18 | num_workers = cfg.DATALOADER.NUM_WORKERS 19 | if len(cfg.DATASETS.NAMES) == 1: 20 | dataset = init_dataset(cfg.DATASETS.NAMES, root=cfg.DATASETS.ROOT_DIR) 21 | else: 22 | # TODO: add multi dataset to train 23 | dataset = init_dataset(cfg.DATASETS.NAMES, root=cfg.DATASETS.ROOT_DIR) 24 | 25 | num_classes = dataset.num_train_pids 26 | train_set = ImageDataset(dataset.train, train_transforms) 27 | if cfg.DATALOADER.SAMPLER == 'softmax': 28 | train_loader = DataLoader( 29 | train_set, batch_size=cfg.SOLVER.IMS_PER_BATCH, shuffle=True, num_workers=num_workers, 30 | collate_fn=train_collate_fn 31 | ) 32 | else: 33 | train_loader = DataLoader( 34 | train_set, batch_size=cfg.SOLVER.IMS_PER_BATCH, 35 | sampler=RandomIdentitySampler(dataset.train, cfg.SOLVER.IMS_PER_BATCH, cfg.DATALOADER.NUM_INSTANCE), 36 | # sampler=RandomIdentitySampler_alignedreid(dataset.train, cfg.DATALOADER.NUM_INSTANCE), # new add by gu 37 | num_workers=num_workers, collate_fn=train_collate_fn 38 | ) 39 | 40 | val_set = ImageDataset(dataset.query + dataset.gallery, val_transforms) 41 | val_loader = DataLoader( 42 | val_set, batch_size=cfg.TEST.IMS_PER_BATCH, shuffle=False, num_workers=num_workers, 43 | collate_fn=val_collate_fn 44 | ) 45 | return train_loader, val_loader, len(dataset.query), num_classes 46 | -------------------------------------------------------------------------------- /data/collate_batch.py: 
-------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import torch 8 | 9 | 10 | def train_collate_fn(batch): 11 | imgs, pids, _, _, = zip(*batch) 12 | pids = torch.tensor(pids, dtype=torch.int64) 13 | return torch.stack(imgs, dim=0), pids 14 | 15 | 16 | def val_collate_fn(batch): 17 | imgs, pids, camids, _ = zip(*batch) 18 | return torch.stack(imgs, dim=0), pids, camids 19 | -------------------------------------------------------------------------------- /data/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | # from .cuhk03 import CUHK03 7 | from .dukemtmcreid import DukeMTMCreID 8 | from .market1501 import Market1501 9 | from .msmt17 import MSMT17 10 | from .veri import VeRi 11 | from .dataset_loader import ImageDataset 12 | 13 | __factory = { 14 | 'market1501': Market1501, 15 | # 'cuhk03': CUHK03, 16 | 'dukemtmc': DukeMTMCreID, 17 | 'msmt17': MSMT17, 18 | 'veri': VeRi, 19 | } 20 | 21 | 22 | def get_names(): 23 | return __factory.keys() 24 | 25 | 26 | def init_dataset(name, *args, **kwargs): 27 | if name not in __factory.keys(): 28 | raise KeyError("Unknown datasets: {}".format(name)) 29 | return __factory[name](*args, **kwargs) 30 | -------------------------------------------------------------------------------- /data/datasets/bases.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import numpy as np 8 | 9 | 10 | class BaseDataset(object): 11 | """ 12 | Base class of reid dataset 13 | """ 14 | 15 | def get_imagedata_info(self, data): 16 | pids, cams = [], [] 17 | for _, pid, camid in data: 18 | pids += [pid] 19 | cams += [camid] 20 | pids = set(pids) 21 | cams = set(cams) 22 | num_pids = len(pids) 23 | num_cams = len(cams) 24 | num_imgs = len(data) 25 | return num_pids, num_imgs, num_cams 26 | 27 | def get_videodata_info(self, data, return_tracklet_stats=False): 28 | pids, cams, tracklet_stats = [], [], [] 29 | for img_paths, pid, camid in data: 30 | pids += [pid] 31 | cams += [camid] 32 | tracklet_stats += [len(img_paths)] 33 | pids = set(pids) 34 | cams = set(cams) 35 | num_pids = len(pids) 36 | num_cams = len(cams) 37 | num_tracklets = len(data) 38 | if return_tracklet_stats: 39 | return num_pids, num_tracklets, num_cams, tracklet_stats 40 | return num_pids, num_tracklets, num_cams 41 | 42 | def print_dataset_statistics(self): 43 | raise NotImplementedError 44 | 45 | 46 | class BaseImageDataset(BaseDataset): 47 | """ 48 | Base class of image reid dataset 49 | """ 50 | 51 | def print_dataset_statistics(self, train, query, gallery): 52 | num_train_pids, num_train_imgs, num_train_cams = self.get_imagedata_info(train) 53 | num_query_pids, num_query_imgs, num_query_cams = self.get_imagedata_info(query) 54 | num_gallery_pids, num_gallery_imgs, num_gallery_cams = self.get_imagedata_info(gallery) 55 | 56 | print("Dataset statistics:") 57 | print(" ----------------------------------------") 58 | print(" subset | # ids | # images | # cameras") 59 | print(" ----------------------------------------") 60 | print(" train | {:5d} | {:8d} | {:9d}".format(num_train_pids, num_train_imgs, num_train_cams)) 61 | print(" query | {:5d} | {:8d} | {:9d}".format(num_query_pids, 
num_query_imgs, num_query_cams)) 62 | print(" gallery | {:5d} | {:8d} | {:9d}".format(num_gallery_pids, num_gallery_imgs, num_gallery_cams)) 63 | print(" ----------------------------------------") 64 | 65 | 66 | class BaseVideoDataset(BaseDataset): 67 | """ 68 | Base class of video reid dataset 69 | """ 70 | 71 | def print_dataset_statistics(self, train, query, gallery): 72 | num_train_pids, num_train_tracklets, num_train_cams, train_tracklet_stats = \ 73 | self.get_videodata_info(train, return_tracklet_stats=True) 74 | 75 | num_query_pids, num_query_tracklets, num_query_cams, query_tracklet_stats = \ 76 | self.get_videodata_info(query, return_tracklet_stats=True) 77 | 78 | num_gallery_pids, num_gallery_tracklets, num_gallery_cams, gallery_tracklet_stats = \ 79 | self.get_videodata_info(gallery, return_tracklet_stats=True) 80 | 81 | tracklet_stats = train_tracklet_stats + query_tracklet_stats + gallery_tracklet_stats 82 | min_num = np.min(tracklet_stats) 83 | max_num = np.max(tracklet_stats) 84 | avg_num = np.mean(tracklet_stats) 85 | 86 | print("Dataset statistics:") 87 | print(" -------------------------------------------") 88 | print(" subset | # ids | # tracklets | # cameras") 89 | print(" -------------------------------------------") 90 | print(" train | {:5d} | {:11d} | {:9d}".format(num_train_pids, num_train_tracklets, num_train_cams)) 91 | print(" query | {:5d} | {:11d} | {:9d}".format(num_query_pids, num_query_tracklets, num_query_cams)) 92 | print(" gallery | {:5d} | {:11d} | {:9d}".format(num_gallery_pids, num_gallery_tracklets, num_gallery_cams)) 93 | print(" -------------------------------------------") 94 | print(" number of images per tracklet: {} ~ {}, average {:.2f}".format(min_num, max_num, avg_num)) 95 | print(" -------------------------------------------") 96 | -------------------------------------------------------------------------------- /data/datasets/cuhk03.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: liaoxingyu2@jd.com 5 | """ 6 | 7 | import h5py 8 | import os.path as osp 9 | from scipy.io import loadmat 10 | from scipy.misc import imsave 11 | 12 | from utils.iotools import mkdir_if_missing, write_json, read_json 13 | from .bases import BaseImageDataset 14 | 15 | 16 | class CUHK03(BaseImageDataset): 17 | """ 18 | CUHK03 19 | Reference: 20 | Li et al. DeepReID: Deep Filter Pairing Neural Network for Person Re-identification. CVPR 2014. 21 | URL: http://www.ee.cuhk.edu.hk/~xgwang/CUHK_identification.html#! 
22 | 23 | Dataset statistics: 24 | # identities: 1360 25 | # images: 13164 26 | # cameras: 6 27 | # splits: 20 (classic) 28 | Args: 29 | split_id (int): split index (default: 0) 30 | cuhk03_labeled (bool): whether to load labeled images; if false, detected images are loaded (default: False) 31 | """ 32 | dataset_dir = 'cuhk03' 33 | 34 | def __init__(self, root='/home/haoluo/data', split_id=0, cuhk03_labeled=False, 35 | cuhk03_classic_split=False, verbose=True, 36 | **kwargs): 37 | super(CUHK03, self).__init__() 38 | self.dataset_dir = osp.join(root, self.dataset_dir) 39 | self.data_dir = osp.join(self.dataset_dir, 'cuhk03_release') 40 | self.raw_mat_path = osp.join(self.data_dir, 'cuhk-03.mat') 41 | 42 | self.imgs_detected_dir = osp.join(self.dataset_dir, 'images_detected') 43 | self.imgs_labeled_dir = osp.join(self.dataset_dir, 'images_labeled') 44 | 45 | self.split_classic_det_json_path = osp.join(self.dataset_dir, 'splits_classic_detected.json') 46 | self.split_classic_lab_json_path = osp.join(self.dataset_dir, 'splits_classic_labeled.json') 47 | 48 | self.split_new_det_json_path = osp.join(self.dataset_dir, 'splits_new_detected.json') 49 | self.split_new_lab_json_path = osp.join(self.dataset_dir, 'splits_new_labeled.json') 50 | 51 | self.split_new_det_mat_path = osp.join(self.dataset_dir, 'cuhk03_new_protocol_config_detected.mat') 52 | self.split_new_lab_mat_path = osp.join(self.dataset_dir, 'cuhk03_new_protocol_config_labeled.mat') 53 | 54 | self._check_before_run() 55 | self._preprocess() 56 | 57 | if cuhk03_labeled: 58 | image_type = 'labeled' 59 | split_path = self.split_classic_lab_json_path if cuhk03_classic_split else self.split_new_lab_json_path 60 | else: 61 | image_type = 'detected' 62 | split_path = self.split_classic_det_json_path if cuhk03_classic_split else self.split_new_det_json_path 63 | 64 | splits = read_json(split_path) 65 | assert split_id < len(splits), "Condition split_id ({}) < len(splits) ({}) is false".format(split_id, 66 | len(splits)) 67 | split = splits[split_id] 68 | print("Split index = {}".format(split_id)) 69 | 70 | train = split['train'] 71 | query = split['query'] 72 | gallery = split['gallery'] 73 | 74 | if verbose: 75 | print("=> CUHK03 ({}) loaded".format(image_type)) 76 | self.print_dataset_statistics(train, query, gallery) 77 | 78 | self.train = train 79 | self.query = query 80 | self.gallery = gallery 81 | 82 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 83 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 84 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 85 | 86 | def _check_before_run(self): 87 | """Check if all files are available before going deeper""" 88 | if not osp.exists(self.dataset_dir): 89 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 90 | if not osp.exists(self.data_dir): 91 | raise RuntimeError("'{}' is not available".format(self.data_dir)) 92 | if not osp.exists(self.raw_mat_path): 93 | raise RuntimeError("'{}' is not available".format(self.raw_mat_path)) 94 | if not osp.exists(self.split_new_det_mat_path): 95 | raise RuntimeError("'{}' is not available".format(self.split_new_det_mat_path)) 96 | if not osp.exists(self.split_new_lab_mat_path): 97 | raise RuntimeError("'{}' is not available".format(self.split_new_lab_mat_path)) 98 | 99 | def _preprocess(self): 100 | """ 101 | This function is a bit complex and ugly, what it does is 102 | 1. 
Extract data from cuhk-03.mat and save as png images. 103 | 2. Create 20 classic splits. (Li et al. CVPR'14) 104 | 3. Create new split. (Zhong et al. CVPR'17) 105 | """ 106 | print( 107 | "Note: if root path is changed, the previously generated json files need to be re-generated (delete them first)") 108 | if osp.exists(self.imgs_labeled_dir) and \ 109 | osp.exists(self.imgs_detected_dir) and \ 110 | osp.exists(self.split_classic_det_json_path) and \ 111 | osp.exists(self.split_classic_lab_json_path) and \ 112 | osp.exists(self.split_new_det_json_path) and \ 113 | osp.exists(self.split_new_lab_json_path): 114 | return 115 | 116 | mkdir_if_missing(self.imgs_detected_dir) 117 | mkdir_if_missing(self.imgs_labeled_dir) 118 | 119 | print("Extract image data from {} and save as png".format(self.raw_mat_path)) 120 | mat = h5py.File(self.raw_mat_path, 'r') 121 | 122 | def _deref(ref): 123 | return mat[ref][:].T 124 | 125 | def _process_images(img_refs, campid, pid, save_dir): 126 | img_paths = [] # Note: some persons only have images for one view 127 | for imgid, img_ref in enumerate(img_refs): 128 | img = _deref(img_ref) 129 | # skip empty cell 130 | if img.size == 0 or img.ndim < 3: continue 131 | # images are saved with the following format, index-1 (ensure uniqueness) 132 | # campid: index of camera pair (1-5) 133 | # pid: index of person in 'campid'-th camera pair 134 | # viewid: index of view, {1, 2} 135 | # imgid: index of image, (1-10) 136 | viewid = 1 if imgid < 5 else 2 137 | img_name = '{:01d}_{:03d}_{:01d}_{:02d}.png'.format(campid + 1, pid + 1, viewid, imgid + 1) 138 | img_path = osp.join(save_dir, img_name) 139 | if not osp.isfile(img_path): 140 | imsave(img_path, img) 141 | img_paths.append(img_path) 142 | return img_paths 143 | 144 | def _extract_img(name): 145 | print("Processing {} images (extract and save) ...".format(name)) 146 | meta_data = [] 147 | imgs_dir = self.imgs_detected_dir if name == 'detected' else self.imgs_labeled_dir 148 | for campid, camp_ref in enumerate(mat[name][0]): 149 | camp = _deref(camp_ref) 150 | num_pids = camp.shape[0] 151 | for pid in range(num_pids): 152 | img_paths = _process_images(camp[pid, :], campid, pid, imgs_dir) 153 | assert len(img_paths) > 0, "campid{}-pid{} has no images".format(campid, pid) 154 | meta_data.append((campid + 1, pid + 1, img_paths)) 155 | print("- done camera pair {} with {} identities".format(campid + 1, num_pids)) 156 | return meta_data 157 | 158 | meta_detected = _extract_img('detected') 159 | meta_labeled = _extract_img('labeled') 160 | 161 | def _extract_classic_split(meta_data, test_split): 162 | train, test = [], [] 163 | num_train_pids, num_test_pids = 0, 0 164 | num_train_imgs, num_test_imgs = 0, 0 165 | for i, (campid, pid, img_paths) in enumerate(meta_data): 166 | 167 | if [campid, pid] in test_split: 168 | for img_path in img_paths: 169 | camid = int(osp.basename(img_path).split('_')[2]) - 1 # make it 0-based 170 | test.append((img_path, num_test_pids, camid)) 171 | num_test_pids += 1 172 | num_test_imgs += len(img_paths) 173 | else: 174 | for img_path in img_paths: 175 | camid = int(osp.basename(img_path).split('_')[2]) - 1 # make it 0-based 176 | train.append((img_path, num_train_pids, camid)) 177 | num_train_pids += 1 178 | num_train_imgs += len(img_paths) 179 | return train, num_train_pids, num_train_imgs, test, num_test_pids, num_test_imgs 180 | 181 | print("Creating classic splits (# = 20) ...") 182 | splits_classic_det, splits_classic_lab = [], [] 183 | for split_ref in mat['testsets'][0]: 184 | 
test_split = _deref(split_ref).tolist() 185 | 186 | # create split for detected images 187 | train, num_train_pids, num_train_imgs, test, num_test_pids, num_test_imgs = \ 188 | _extract_classic_split(meta_detected, test_split) 189 | splits_classic_det.append({ 190 | 'train': train, 'query': test, 'gallery': test, 191 | 'num_train_pids': num_train_pids, 'num_train_imgs': num_train_imgs, 192 | 'num_query_pids': num_test_pids, 'num_query_imgs': num_test_imgs, 193 | 'num_gallery_pids': num_test_pids, 'num_gallery_imgs': num_test_imgs, 194 | }) 195 | 196 | # create split for labeled images 197 | train, num_train_pids, num_train_imgs, test, num_test_pids, num_test_imgs = \ 198 | _extract_classic_split(meta_labeled, test_split) 199 | splits_classic_lab.append({ 200 | 'train': train, 'query': test, 'gallery': test, 201 | 'num_train_pids': num_train_pids, 'num_train_imgs': num_train_imgs, 202 | 'num_query_pids': num_test_pids, 'num_query_imgs': num_test_imgs, 203 | 'num_gallery_pids': num_test_pids, 'num_gallery_imgs': num_test_imgs, 204 | }) 205 | 206 | write_json(splits_classic_det, self.split_classic_det_json_path) 207 | write_json(splits_classic_lab, self.split_classic_lab_json_path) 208 | 209 | def _extract_set(filelist, pids, pid2label, idxs, img_dir, relabel): 210 | tmp_set = [] 211 | unique_pids = set() 212 | for idx in idxs: 213 | img_name = filelist[idx][0] 214 | camid = int(img_name.split('_')[2]) - 1 # make it 0-based 215 | pid = pids[idx] 216 | if relabel: pid = pid2label[pid] 217 | img_path = osp.join(img_dir, img_name) 218 | tmp_set.append((img_path, int(pid), camid)) 219 | unique_pids.add(pid) 220 | return tmp_set, len(unique_pids), len(idxs) 221 | 222 | def _extract_new_split(split_dict, img_dir): 223 | train_idxs = split_dict['train_idx'].flatten() - 1 # index-0 224 | pids = split_dict['labels'].flatten() 225 | train_pids = set(pids[train_idxs]) 226 | pid2label = {pid: label for label, pid in enumerate(train_pids)} 227 | query_idxs = split_dict['query_idx'].flatten() - 1 228 | gallery_idxs = split_dict['gallery_idx'].flatten() - 1 229 | filelist = split_dict['filelist'].flatten() 230 | train_info = _extract_set(filelist, pids, pid2label, train_idxs, img_dir, relabel=True) 231 | query_info = _extract_set(filelist, pids, pid2label, query_idxs, img_dir, relabel=False) 232 | gallery_info = _extract_set(filelist, pids, pid2label, gallery_idxs, img_dir, relabel=False) 233 | return train_info, query_info, gallery_info 234 | 235 | print("Creating new splits for detected images (767/700) ...") 236 | train_info, query_info, gallery_info = _extract_new_split( 237 | loadmat(self.split_new_det_mat_path), 238 | self.imgs_detected_dir, 239 | ) 240 | splits = [{ 241 | 'train': train_info[0], 'query': query_info[0], 'gallery': gallery_info[0], 242 | 'num_train_pids': train_info[1], 'num_train_imgs': train_info[2], 243 | 'num_query_pids': query_info[1], 'num_query_imgs': query_info[2], 244 | 'num_gallery_pids': gallery_info[1], 'num_gallery_imgs': gallery_info[2], 245 | }] 246 | write_json(splits, self.split_new_det_json_path) 247 | 248 | print("Creating new splits for labeled images (767/700) ...") 249 | train_info, query_info, gallery_info = _extract_new_split( 250 | loadmat(self.split_new_lab_mat_path), 251 | self.imgs_labeled_dir, 252 | ) 253 | splits = [{ 254 | 'train': train_info[0], 'query': query_info[0], 'gallery': gallery_info[0], 255 | 'num_train_pids': train_info[1], 'num_train_imgs': train_info[2], 256 | 'num_query_pids': query_info[1], 'num_query_imgs': query_info[2], 257 | 
'num_gallery_pids': gallery_info[1], 'num_gallery_imgs': gallery_info[2], 258 | }] 259 | write_json(splits, self.split_new_lab_json_path) 260 | -------------------------------------------------------------------------------- /data/datasets/dataset_loader.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import os.path as osp 8 | from PIL import Image 9 | from torch.utils.data import Dataset 10 | 11 | 12 | def read_image(img_path): 13 | """Keep reading image until succeed. 14 | This can avoid IOError incurred by heavy IO process.""" 15 | got_img = False 16 | if not osp.exists(img_path): 17 | raise IOError("{} does not exist".format(img_path)) 18 | while not got_img: 19 | try: 20 | img = Image.open(img_path).convert('RGB') 21 | got_img = True 22 | except IOError: 23 | print("IOError incurred when reading '{}'. Will redo. Don't worry. Just chill.".format(img_path)) 24 | pass 25 | return img 26 | 27 | 28 | class ImageDataset(Dataset): 29 | """Image Person ReID Dataset""" 30 | 31 | def __init__(self, dataset, transform=None): 32 | self.dataset = dataset 33 | self.transform = transform 34 | 35 | def __len__(self): 36 | return len(self.dataset) 37 | 38 | def __getitem__(self, index): 39 | img_path, pid, camid = self.dataset[index] 40 | img = read_image(img_path) 41 | 42 | if self.transform is not None: 43 | img = self.transform(img) 44 | 45 | return img, pid, camid, img_path 46 | -------------------------------------------------------------------------------- /data/datasets/dukemtmcreid.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: liaoxingyu2@jd.com 5 | """ 6 | 7 | import glob 8 | import re 9 | import urllib 10 | import zipfile 11 | 12 | import os.path as osp 13 | 14 | from utils.iotools import mkdir_if_missing 15 | from .bases import BaseImageDataset 16 | 17 | 18 | class DukeMTMCreID(BaseImageDataset): 19 | """ 20 | DukeMTMC-reID 21 | Reference: 22 | 1. Ristani et al. Performance Measures and a Data Set for Multi-Target, Multi-Camera Tracking. ECCVW 2016. 23 | 2. Zheng et al. Unlabeled Samples Generated by GAN Improve the Person Re-identification Baseline in vitro. ICCV 2017. 
24 | URL: https://github.com/layumi/DukeMTMC-reID_evaluation 25 | 26 | Dataset statistics: 27 | # identities: 1404 (train + query) 28 | # images:16522 (train) + 2228 (query) + 17661 (gallery) 29 | # cameras: 8 30 | """ 31 | dataset_dir = 'dukemtmc-reid' 32 | 33 | def __init__(self, root='/home/haoluo/data', verbose=True, **kwargs): 34 | super(DukeMTMCreID, self).__init__() 35 | self.dataset_dir = osp.join(root, self.dataset_dir) 36 | self.dataset_url = 'http://vision.cs.duke.edu/DukeMTMC/data/misc/DukeMTMC-reID.zip' 37 | self.train_dir = osp.join(self.dataset_dir, 'DukeMTMC-reID/bounding_box_train') 38 | self.query_dir = osp.join(self.dataset_dir, 'DukeMTMC-reID/query') 39 | self.gallery_dir = osp.join(self.dataset_dir, 'DukeMTMC-reID/bounding_box_test') 40 | 41 | self._download_data() 42 | self._check_before_run() 43 | 44 | train = self._process_dir(self.train_dir, relabel=True) 45 | query = self._process_dir(self.query_dir, relabel=False) 46 | gallery = self._process_dir(self.gallery_dir, relabel=False) 47 | 48 | if verbose: 49 | print("=> DukeMTMC-reID loaded") 50 | self.print_dataset_statistics(train, query, gallery) 51 | 52 | self.train = train 53 | self.query = query 54 | self.gallery = gallery 55 | 56 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 57 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 58 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 59 | 60 | def _download_data(self): 61 | if osp.exists(self.dataset_dir): 62 | print("This dataset has been downloaded.") 63 | return 64 | 65 | print("Creating directory {}".format(self.dataset_dir)) 66 | mkdir_if_missing(self.dataset_dir) 67 | fpath = osp.join(self.dataset_dir, osp.basename(self.dataset_url)) 68 | 69 | print("Downloading DukeMTMC-reID dataset") 70 | urllib.request.urlretrieve(self.dataset_url, fpath) 71 | 72 | print("Extracting files") 73 | zip_ref = zipfile.ZipFile(fpath, 'r') 74 | zip_ref.extractall(self.dataset_dir) 75 | zip_ref.close() 76 | 77 | def _check_before_run(self): 78 | """Check if all files are available before going deeper""" 79 | if not osp.exists(self.dataset_dir): 80 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 81 | if not osp.exists(self.train_dir): 82 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 83 | if not osp.exists(self.query_dir): 84 | raise RuntimeError("'{}' is not available".format(self.query_dir)) 85 | if not osp.exists(self.gallery_dir): 86 | raise RuntimeError("'{}' is not available".format(self.gallery_dir)) 87 | 88 | def _process_dir(self, dir_path, relabel=False): 89 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 90 | pattern = re.compile(r'([-\d]+)_c(\d)') 91 | 92 | pid_container = set() 93 | for img_path in img_paths: 94 | pid, _ = map(int, pattern.search(img_path).groups()) 95 | pid_container.add(pid) 96 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 97 | 98 | dataset = [] 99 | for img_path in img_paths: 100 | pid, camid = map(int, pattern.search(img_path).groups()) 101 | assert 1 <= camid <= 8 102 | camid -= 1 # index starts from 0 103 | if relabel: pid = pid2label[pid] 104 | dataset.append((img_path, pid, camid)) 105 | 106 | return dataset 107 | -------------------------------------------------------------------------------- /data/datasets/eval_reid.py: -------------------------------------------------------------------------------- 1 | # 
encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import numpy as np 8 | 9 | 10 | def eval_func(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=50): 11 | """Evaluation with market1501 metric 12 | Key: for each query identity, its gallery images from the same camera view are discarded. 13 | """ 14 | num_q, num_g = distmat.shape 15 | if num_g < max_rank: 16 | max_rank = num_g 17 | print("Note: number of gallery samples is quite small, got {}".format(num_g)) 18 | indices = np.argsort(distmat, axis=1) 19 | matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32) 20 | 21 | # compute cmc curve for each query 22 | all_cmc = [] 23 | all_AP = [] 24 | num_valid_q = 0. # number of valid query 25 | for q_idx in range(num_q): 26 | # get query pid and camid 27 | q_pid = q_pids[q_idx] 28 | q_camid = q_camids[q_idx] 29 | 30 | # remove gallery samples that have the same pid and camid with query 31 | order = indices[q_idx] 32 | remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid) 33 | keep = np.invert(remove) 34 | 35 | # compute cmc curve 36 | # binary vector, positions with value 1 are correct matches 37 | orig_cmc = matches[q_idx][keep] 38 | if not np.any(orig_cmc): 39 | # this condition is true when query identity does not appear in gallery 40 | continue 41 | 42 | cmc = orig_cmc.cumsum() 43 | cmc[cmc > 1] = 1 44 | 45 | all_cmc.append(cmc[:max_rank]) 46 | num_valid_q += 1. 47 | 48 | # compute average precision 49 | # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision 50 | num_rel = orig_cmc.sum() 51 | tmp_cmc = orig_cmc.cumsum() 52 | tmp_cmc = [x / (i + 1.) for i, x in enumerate(tmp_cmc)] 53 | tmp_cmc = np.asarray(tmp_cmc) * orig_cmc 54 | AP = tmp_cmc.sum() / num_rel 55 | all_AP.append(AP) 56 | 57 | assert num_valid_q > 0, "Error: all query identities do not appear in gallery" 58 | 59 | all_cmc = np.asarray(all_cmc).astype(np.float32) 60 | all_cmc = all_cmc.sum(0) / num_valid_q 61 | mAP = np.mean(all_AP) 62 | 63 | return all_cmc, mAP 64 | -------------------------------------------------------------------------------- /data/datasets/market1501.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import glob 8 | import re 9 | 10 | import os.path as osp 11 | 12 | from .bases import BaseImageDataset 13 | 14 | 15 | class Market1501(BaseImageDataset): 16 | """ 17 | Market1501 18 | Reference: 19 | Zheng et al. Scalable Person Re-identification: A Benchmark. ICCV 2015. 
20 | URL: http://www.liangzheng.org/Project/project_reid.html 21 | 22 | Dataset statistics: 23 | # identities: 1501 (+1 for background) 24 | # images: 12936 (train) + 3368 (query) + 15913 (gallery) 25 | """ 26 | dataset_dir = 'market1501' 27 | 28 | def __init__(self, root='/home/haoluo/data', verbose=True, **kwargs): 29 | super(Market1501, self).__init__() 30 | self.dataset_dir = osp.join(root, self.dataset_dir) 31 | self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train') 32 | self.query_dir = osp.join(self.dataset_dir, 'query') 33 | self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test') 34 | 35 | self._check_before_run() 36 | 37 | train = self._process_dir(self.train_dir, relabel=True) 38 | query = self._process_dir(self.query_dir, relabel=False) 39 | gallery = self._process_dir(self.gallery_dir, relabel=False) 40 | 41 | if verbose: 42 | print("=> Market1501 loaded") 43 | self.print_dataset_statistics(train, query, gallery) 44 | 45 | self.train = train 46 | self.query = query 47 | self.gallery = gallery 48 | 49 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 50 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 51 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 52 | 53 | def _check_before_run(self): 54 | """Check if all files are available before going deeper""" 55 | if not osp.exists(self.dataset_dir): 56 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 57 | if not osp.exists(self.train_dir): 58 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 59 | if not osp.exists(self.query_dir): 60 | raise RuntimeError("'{}' is not available".format(self.query_dir)) 61 | if not osp.exists(self.gallery_dir): 62 | raise RuntimeError("'{}' is not available".format(self.gallery_dir)) 63 | 64 | def _process_dir(self, dir_path, relabel=False): 65 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 66 | pattern = re.compile(r'([-\d]+)_c(\d)') 67 | 68 | pid_container = set() 69 | for img_path in img_paths: 70 | pid, _ = map(int, pattern.search(img_path).groups()) 71 | if pid == -1: continue # junk images are just ignored 72 | pid_container.add(pid) 73 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 74 | 75 | dataset = [] 76 | for img_path in img_paths: 77 | pid, camid = map(int, pattern.search(img_path).groups()) 78 | if pid == -1: continue # junk images are just ignored 79 | assert 0 <= pid <= 1501 # pid == 0 means background 80 | assert 1 <= camid <= 6 81 | camid -= 1 # index starts from 0 82 | if relabel: pid = pid2label[pid] 83 | dataset.append((img_path, pid, camid)) 84 | 85 | return dataset 86 | -------------------------------------------------------------------------------- /data/datasets/msmt17.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 2019/1/17 15:00 4 | # @Author : Hao Luo 5 | # @File : msmt17.py 6 | 7 | import glob 8 | import re 9 | 10 | import os.path as osp 11 | 12 | from .bases import BaseImageDataset 13 | 14 | 15 | class MSMT17(BaseImageDataset): 16 | """ 17 | MSMT17 18 | 19 | Reference: 20 | Wei et al. Person Transfer GAN to Bridge Domain Gap for Person Re-Identification. CVPR 2018. 
21 | 22 | URL: http://www.pkuvmc.com/publications/msmt17.html 23 | 24 | Dataset statistics: 25 | # identities: 4101 26 | # images: 32621 (train) + 11659 (query) + 82161 (gallery) 27 | # cameras: 15 28 | """ 29 | dataset_dir = 'msmt17' 30 | 31 | def __init__(self,root='/home/haoluo/data', verbose=True, **kwargs): 32 | super(MSMT17, self).__init__() 33 | self.dataset_dir = osp.join(root, self.dataset_dir) 34 | self.train_dir = osp.join(self.dataset_dir, 'MSMT17_V2/mask_train_v2') 35 | self.test_dir = osp.join(self.dataset_dir, 'MSMT17_V2/mask_test_v2') 36 | self.list_train_path = osp.join(self.dataset_dir, 'MSMT17_V2/list_train.txt') 37 | self.list_val_path = osp.join(self.dataset_dir, 'MSMT17_V2/list_val.txt') 38 | self.list_query_path = osp.join(self.dataset_dir, 'MSMT17_V2/list_query.txt') 39 | self.list_gallery_path = osp.join(self.dataset_dir, 'MSMT17_V2/list_gallery.txt') 40 | 41 | self._check_before_run() 42 | train = self._process_dir(self.train_dir, self.list_train_path) 43 | #val, num_val_pids, num_val_imgs = self._process_dir(self.train_dir, self.list_val_path) 44 | query = self._process_dir(self.test_dir, self.list_query_path) 45 | gallery = self._process_dir(self.test_dir, self.list_gallery_path) 46 | if verbose: 47 | print("=> MSMT17 loaded") 48 | self.print_dataset_statistics(train, query, gallery) 49 | 50 | self.train = train 51 | self.query = query 52 | self.gallery = gallery 53 | 54 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 55 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 56 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 57 | 58 | def _check_before_run(self): 59 | """Check if all files are available before going deeper""" 60 | if not osp.exists(self.dataset_dir): 61 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 62 | if not osp.exists(self.train_dir): 63 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 64 | if not osp.exists(self.test_dir): 65 | raise RuntimeError("'{}' is not available".format(self.test_dir)) 66 | 67 | def _process_dir(self, dir_path, list_path): 68 | with open(list_path, 'r') as txt: 69 | lines = txt.readlines() 70 | dataset = [] 71 | pid_container = set() 72 | for img_idx, img_info in enumerate(lines): 73 | img_path, pid = img_info.split(' ') 74 | pid = int(pid) # no need to relabel 75 | camid = int(img_path.split('_')[2]) 76 | img_path = osp.join(dir_path, img_path) 77 | dataset.append((img_path, pid, camid)) 78 | pid_container.add(pid) 79 | 80 | # check if pid starts from 0 and increments with 1 81 | for idx, pid in enumerate(pid_container): 82 | assert idx == pid, "See code comment for explanation" 83 | return dataset -------------------------------------------------------------------------------- /data/datasets/nformer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import numpy as np 4 | 5 | import torch 6 | import torchvision 7 | from torch.utils import data 8 | 9 | import glob 10 | from sklearn.preprocessing import normalize 11 | import random 12 | 13 | class NFormerDataset(data.Dataset): 14 | def __init__(self, data, data_length = 7000): 15 | self.data_length = data_length 16 | self.feats = data[0] 17 | self.ids = data[1] 18 | self.data_num = self.feats.shape[0] 19 | 20 | def __len__(self): 21 | return self.feats.shape[0]//30 22 | 23 | def 
__getitem__(self, index): 24 | center_index = random.randint(0, self.data_num - 1) 25 | center_feat = self.feats[center_index].unsqueeze(0) 26 | center_pid = self.ids[center_index] 27 | 28 | selected_flags = torch.zeros(self.data_num) 29 | selected_flags[center_index] = 1 30 | distmat = 1 - torch.mm(center_feat, self.feats.transpose(0,1)) 31 | indices = torch.argsort(distmat, dim=1).numpy() 32 | indices = indices[0,:int(self.data_length * (1 + random.random()))].tolist() 33 | indices = random.sample(indices,self.data_length) 34 | 35 | random.shuffle(indices) 36 | feat_ = self.feats[indices] 37 | id_ = self.ids[indices] 38 | 39 | 40 | return feat_, id_ 41 | 42 | 43 | -------------------------------------------------------------------------------- /data/datasets/veri.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import re 3 | 4 | import os.path as osp 5 | 6 | from .bases import BaseImageDataset 7 | 8 | 9 | class VeRi(BaseImageDataset): 10 | """ 11 | VeRi-776 12 | Reference: 13 | Liu, Xinchen, et al. "Large-scale vehicle re-identification in urban surveillance videos." ICME 2016. 14 | 15 | URL:https://vehiclereid.github.io/VeRi/ 16 | 17 | Dataset statistics: 18 | # identities: 776 19 | # images: 37778 (train) + 1678 (query) + 11579 (gallery) 20 | # cameras: 20 21 | """ 22 | 23 | dataset_dir = 'veri' 24 | 25 | def __init__(self, root='../', verbose=True, **kwargs): 26 | super(VeRi, self).__init__() 27 | self.dataset_dir = osp.join(root, self.dataset_dir) 28 | self.train_dir = osp.join(self.dataset_dir, 'image_train') 29 | self.query_dir = osp.join(self.dataset_dir, 'image_query') 30 | self.gallery_dir = osp.join(self.dataset_dir, 'image_test') 31 | 32 | self._check_before_run() 33 | 34 | train = self._process_dir(self.train_dir, relabel=True) 35 | query = self._process_dir(self.query_dir, relabel=False) 36 | gallery = self._process_dir(self.gallery_dir, relabel=False) 37 | 38 | if verbose: 39 | print("=> VeRi-776 loaded") 40 | self.print_dataset_statistics(train, query, gallery) 41 | 42 | self.train = train 43 | self.query = query 44 | self.gallery = gallery 45 | 46 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 47 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 48 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 49 | 50 | def _check_before_run(self): 51 | """Check if all files are available before going deeper""" 52 | if not osp.exists(self.dataset_dir): 53 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 54 | if not osp.exists(self.train_dir): 55 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 56 | if not osp.exists(self.query_dir): 57 | raise RuntimeError("'{}' is not available".format(self.query_dir)) 58 | if not osp.exists(self.gallery_dir): 59 | raise RuntimeError("'{}' is not available".format(self.gallery_dir)) 60 | 61 | def _process_dir(self, dir_path, relabel=False): 62 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 63 | pattern = re.compile(r'([-\d]+)_c(\d+)') 64 | 65 | pid_container = set() 66 | for img_path in img_paths: 67 | pid, _ = map(int, pattern.search(img_path).groups()) 68 | if pid == -1: continue # junk images are just ignored 69 | pid_container.add(pid) 70 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 71 | 72 | dataset = [] 73 | for img_path in img_paths: 74 | pid, camid = 
map(int, pattern.search(img_path).groups()) 75 | if pid == -1: continue # junk images are just ignored 76 | assert 0 <= pid <= 776 # pid == 0 means background 77 | assert 1 <= camid <= 20 78 | camid -= 1 # index starts from 0 79 | if relabel: pid = pid2label[pid] 80 | dataset.append((img_path, pid, camid)) 81 | 82 | return dataset 83 | 84 | -------------------------------------------------------------------------------- /data/samplers/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | from .triplet_sampler import RandomIdentitySampler, RandomIdentitySampler_alignedreid # new add by gu 8 | -------------------------------------------------------------------------------- /data/samplers/triplet_sampler.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: liaoxingyu2@jd.com 5 | """ 6 | 7 | import copy 8 | import random 9 | import torch 10 | from collections import defaultdict 11 | 12 | import numpy as np 13 | from torch.utils.data.sampler import Sampler 14 | 15 | 16 | class RandomIdentitySampler(Sampler): 17 | """ 18 | Randomly sample N identities, then for each identity, 19 | randomly sample K instances, therefore batch size is N*K. 20 | Args: 21 | - data_source (list): list of (img_path, pid, camid). 22 | - num_instances (int): number of instances per identity in a batch. 23 | - batch_size (int): number of examples in a batch. 24 | """ 25 | 26 | def __init__(self, data_source, batch_size, num_instances): 27 | self.data_source = data_source 28 | self.batch_size = batch_size 29 | self.num_instances = num_instances 30 | self.num_pids_per_batch = self.batch_size // self.num_instances 31 | self.index_dic = defaultdict(list) 32 | for index, (_, pid, _) in enumerate(self.data_source): 33 | self.index_dic[pid].append(index) 34 | self.pids = list(self.index_dic.keys()) 35 | 36 | # estimate number of examples in an epoch 37 | self.length = 0 38 | for pid in self.pids: 39 | idxs = self.index_dic[pid] 40 | num = len(idxs) 41 | if num < self.num_instances: 42 | num = self.num_instances 43 | self.length += num - num % self.num_instances 44 | 45 | def __iter__(self): 46 | batch_idxs_dict = defaultdict(list) 47 | 48 | for pid in self.pids: 49 | idxs = copy.deepcopy(self.index_dic[pid]) 50 | if len(idxs) < self.num_instances: 51 | idxs = np.random.choice(idxs, size=self.num_instances, replace=True) 52 | random.shuffle(idxs) 53 | batch_idxs = [] 54 | for idx in idxs: 55 | batch_idxs.append(idx) 56 | if len(batch_idxs) == self.num_instances: 57 | batch_idxs_dict[pid].append(batch_idxs) 58 | batch_idxs = [] 59 | 60 | avai_pids = copy.deepcopy(self.pids) 61 | final_idxs = [] 62 | 63 | while len(avai_pids) >= self.num_pids_per_batch: 64 | selected_pids = random.sample(avai_pids, self.num_pids_per_batch) 65 | for pid in selected_pids: 66 | batch_idxs = batch_idxs_dict[pid].pop(0) 67 | final_idxs.extend(batch_idxs) 68 | if len(batch_idxs_dict[pid]) == 0: 69 | avai_pids.remove(pid) 70 | 71 | self.length = len(final_idxs) 72 | return iter(final_idxs) 73 | 74 | def __len__(self): 75 | return self.length 76 | 77 | 78 | # New add by gu 79 | class RandomIdentitySampler_alignedreid(Sampler): 80 | """ 81 | Randomly sample N identities, then for each identity, 82 | randomly sample K instances, therefore batch size is N*K. 
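    For illustration (this example is not part of the original docstring): with
    num_instances=4, a DataLoader batch_size of 64 covers 16 identities per batch.
    A minimal sketch, assuming `train` is the list of (img_path, pid, camid)
    tuples produced by the dataset classes above:

        sampler = RandomIdentitySampler_alignedreid(train, num_instances=4)
        # wrap `train` in an image dataset and pass sampler=sampler to DataLoader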
83 | 84 | Code imported from https://github.com/Cysu/open-reid/blob/master/reid/utils/data/sampler.py. 85 | 86 | Args: 87 | data_source (Dataset): dataset to sample from. 88 | num_instances (int): number of instances per identity. 89 | """ 90 | def __init__(self, data_source, num_instances): 91 | self.data_source = data_source 92 | self.num_instances = num_instances 93 | self.index_dic = defaultdict(list) 94 | for index, (_, pid, _) in enumerate(data_source): 95 | self.index_dic[pid].append(index) 96 | self.pids = list(self.index_dic.keys()) 97 | self.num_identities = len(self.pids) 98 | 99 | def __iter__(self): 100 | indices = torch.randperm(self.num_identities) 101 | ret = [] 102 | for i in indices: 103 | pid = self.pids[i] 104 | t = self.index_dic[pid] 105 | replace = False if len(t) >= self.num_instances else True 106 | t = np.random.choice(t, size=self.num_instances, replace=replace) 107 | ret.extend(t) 108 | return iter(ret) 109 | 110 | def __len__(self): 111 | return self.num_identities * self.num_instances 112 | -------------------------------------------------------------------------------- /data/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | from .build import build_transforms 8 | -------------------------------------------------------------------------------- /data/transforms/build.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: liaoxingyu2@jd.com 5 | """ 6 | 7 | import torchvision.transforms as T 8 | 9 | from .transforms import RandomErasing 10 | 11 | 12 | def build_transforms(cfg, is_train=True): 13 | normalize_transform = T.Normalize(mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD) 14 | if is_train: 15 | transform = T.Compose([ 16 | T.Resize(cfg.INPUT.SIZE_TRAIN), 17 | T.RandomHorizontalFlip(p=cfg.INPUT.PROB), 18 | T.Pad(cfg.INPUT.PADDING), 19 | T.RandomCrop(cfg.INPUT.SIZE_TRAIN), 20 | T.ToTensor(), 21 | normalize_transform, 22 | RandomErasing(probability=cfg.INPUT.RE_PROB, mean=cfg.INPUT.PIXEL_MEAN) 23 | ]) 24 | else: 25 | transform = T.Compose([ 26 | T.Resize(cfg.INPUT.SIZE_TEST), 27 | T.ToTensor(), 28 | normalize_transform 29 | ]) 30 | 31 | return transform 32 | -------------------------------------------------------------------------------- /data/transforms/transforms.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: liaoxingyu2@jd.com 5 | """ 6 | 7 | import math 8 | import random 9 | 10 | 11 | class RandomErasing(object): 12 | """ Randomly selects a rectangle region in an image and erases its pixels. 13 | 'Random Erasing Data Augmentation' by Zhong et al. 14 | See https://arxiv.org/pdf/1708.04896.pdf 15 | Args: 16 | probability: The probability that the Random Erasing operation will be performed. 17 | sl: Minimum proportion of erased area against input image. 18 | sh: Maximum proportion of erased area against input image. 19 | r1: Minimum aspect ratio of erased area. 20 | mean: Erasing value. 
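        Usage sketch (illustrative; mirrors the call in build_transforms above):
            RandomErasing(probability=cfg.INPUT.RE_PROB, mean=cfg.INPUT.PIXEL_MEAN)
        The transform runs after ToTensor, so `img` is expected to be a CxHxW tensor.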
21 | """ 22 | 23 | def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=(0.4914, 0.4822, 0.4465)): 24 | self.probability = probability 25 | self.mean = mean 26 | self.sl = sl 27 | self.sh = sh 28 | self.r1 = r1 29 | 30 | def __call__(self, img): 31 | 32 | if random.uniform(0, 1) >= self.probability: 33 | return img 34 | 35 | for attempt in range(100): 36 | area = img.size()[1] * img.size()[2] 37 | 38 | target_area = random.uniform(self.sl, self.sh) * area 39 | aspect_ratio = random.uniform(self.r1, 1 / self.r1) 40 | 41 | h = int(round(math.sqrt(target_area * aspect_ratio))) 42 | w = int(round(math.sqrt(target_area / aspect_ratio))) 43 | 44 | if w < img.size()[2] and h < img.size()[1]: 45 | x1 = random.randint(0, img.size()[1] - h) 46 | y1 = random.randint(0, img.size()[2] - w) 47 | if img.size()[0] == 3: 48 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0] 49 | img[1, x1:x1 + h, y1:y1 + w] = self.mean[1] 50 | img[2, x1:x1 + h, y1:y1 + w] = self.mean[2] 51 | else: 52 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0] 53 | return img 54 | 55 | return img 56 | -------------------------------------------------------------------------------- /engine/inference.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | import logging 7 | 8 | import torch 9 | import torch.nn as nn 10 | from ignite.engine import Engine 11 | 12 | from utils.reid_metric import R1_mAP, NFormer_R1_mAP 13 | 14 | 15 | def create_supervised_evaluator(model, metrics, 16 | device=None): 17 | """ 18 | Factory function for creating an evaluator for supervised models 19 | 20 | Args: 21 | model (`torch.nn.Module`): the model to train 22 | metrics (dict of str - :class:`ignite.metrics.Metric`): a map of metric names to Metrics 23 | device (str, optional): device type specification (default: None). 24 | Applies to both model and batches. 
25 | Returns: 26 | Engine: an evaluator engine with supervised inference function 27 | """ 28 | if device: 29 | if torch.cuda.device_count() > 1: 30 | model = nn.DataParallel(model) 31 | model.to(device) 32 | 33 | def _inference(engine, batch): 34 | model.eval() 35 | with torch.no_grad(): 36 | data, pids, camids = batch 37 | data = data.to(device) if torch.cuda.device_count() >= 1 else data 38 | feat = model(data) 39 | return feat, pids, camids 40 | 41 | engine = Engine(_inference) 42 | 43 | for name, metric in metrics.items(): 44 | metric.attach(engine, name) 45 | 46 | return engine 47 | 48 | 49 | def inference( 50 | cfg, 51 | model, 52 | val_loader, 53 | num_query 54 | ): 55 | device = cfg.MODEL.DEVICE 56 | 57 | logger = logging.getLogger("reid_baseline.inference") 58 | logger.info("Enter inferencing") 59 | if cfg.TEST.TEST_NFORMER != 'yes': 60 | evaluator = create_supervised_evaluator(model, metrics={'r1_mAP': R1_mAP(num_query, max_rank=50, feat_norm=cfg.TEST.FEAT_NORM)},device=device) 61 | else: 62 | evaluator = create_supervised_evaluator(model, metrics={'r1_mAP': NFormer_R1_mAP(model, num_query, max_rank=50, feat_norm=cfg.TEST.FEAT_NORM)},device=device) 63 | 64 | evaluator.run(val_loader) 65 | cmc, mAP = evaluator.state.metrics['r1_mAP'] 66 | logger.info('Validation Results') 67 | logger.info("mAP: {:.1%}".format(mAP)) 68 | for r in [1, 5, 10]: 69 | logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(r, cmc[r - 1])) 70 | -------------------------------------------------------------------------------- /engine/trainer.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | import os 7 | import logging 8 | 9 | import torch 10 | import torch.nn as nn 11 | from torch.utils import data 12 | from ignite.engine import Engine, Events 13 | from ignite.handlers import ModelCheckpoint, Timer 14 | from ignite.metrics import RunningAverage, Metric 15 | 16 | from utils.reid_metric import R1_mAP, NFormer_R1_mAP 17 | from data.datasets.nformer import NFormerDataset 18 | 19 | global ITER 20 | ITER = 0 21 | 22 | def create_supervised_trainer(model, optimizer, loss_fn, 23 | device=None): 24 | """ 25 | Factory function for creating a trainer for supervised models 26 | 27 | Args: 28 | model (`torch.nn.Module`): the model to train 29 | optimizer (`torch.optim.Optimizer`): the optimizer to use 30 | loss_fn (torch.nn loss function): the loss function to use 31 | device (str, optional): device type specification (default: None). 32 | Applies to both model and batches. 
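    Example (illustrative; mirrors the call in `do_train` below):
        trainer = create_supervised_trainer(model, optimizer, loss_fn, device=cfg.MODEL.DEVICE)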
33 | 34 | Returns: 35 | Engine: a trainer engine with supervised update function 36 | """ 37 | if device: 38 | if torch.cuda.device_count() > 1: 39 | model = nn.DataParallel(model) 40 | model.to(device) 41 | 42 | def _update(engine, batch): 43 | model.train() 44 | optimizer.zero_grad() 45 | img, target = batch 46 | img = img.to(device) if torch.cuda.device_count() >= 1 else img 47 | target = target.to(device) if torch.cuda.device_count() >= 1 else target 48 | score, feat = model(img, stage='encoder') 49 | loss = loss_fn(score, feat, target) 50 | loss.backward() 51 | optimizer.step() 52 | # compute acc 53 | acc = (score.max(1)[1] == target).float().mean() 54 | return loss.item(), acc.item() 55 | 56 | return Engine(_update) 57 | 58 | 59 | def create_supervised_trainer_with_center(model, center_criterion, optimizer, optimizer_center, loss_fn, cetner_loss_weight, 60 | device=None): 61 | """ 62 | Factory function for creating a trainer for supervised models 63 | 64 | Args: 65 | model (`torch.nn.Module`): the model to train 66 | optimizer (`torch.optim.Optimizer`): the optimizer to use 67 | loss_fn (torch.nn loss function): the loss function to use 68 | device (str, optional): device type specification (default: None). 69 | Applies to both model and batches. 70 | 71 | Returns: 72 | Engine: a trainer engine with supervised update function 73 | """ 74 | if device: 75 | if torch.cuda.device_count() > 1: 76 | model = nn.DataParallel(model) 77 | model.to(device) 78 | 79 | def _update(engine, batch): 80 | model.train() 81 | optimizer.zero_grad() 82 | optimizer_center.zero_grad() 83 | img, target = batch 84 | img = img.to(device) if torch.cuda.device_count() >= 1 else img 85 | target = target.to(device) if torch.cuda.device_count() >= 1 else target 86 | score, feat = model(img, stage='encoder') 87 | loss = loss_fn(score, feat, target) 88 | # print("Total loss is {}, center loss is {}".format(loss, center_criterion(feat, target))) 89 | loss.backward() 90 | optimizer.step() 91 | for param in center_criterion.parameters(): 92 | param.grad.data *= (1. / cetner_loss_weight) 93 | optimizer_center.step() 94 | 95 | # compute acc 96 | acc = (score.max(1)[1] == target).float().mean() 97 | return loss.item(), acc.item() 98 | 99 | return Engine(_update) 100 | 101 | def create_supervised_nformer_trainer(model, nformer_center_criterion, optimizer, optimizer_nformer_center, nformer_loss_fn, cetner_loss_weight, device=None): 102 | if device: 103 | if torch.cuda.device_count() > 1: 104 | model = nn.DataParallel(model) 105 | model.to(device) 106 | 107 | def _update(engine, batch): 108 | model.train() 109 | optimizer.zero_grad() 110 | optimizer_nformer_center.zero_grad() 111 | feat, target = batch 112 | feat = feat.to(device) if torch.cuda.device_count() >= 1 else feat 113 | target = target.to(device) if torch.cuda.device_count() >= 1 else target 114 | score, feat = model(feat, stage='nformer') 115 | bs,dl,d = feat.shape 116 | score = score.reshape(bs * dl, -1) 117 | feat = feat.reshape(bs * dl, d) 118 | target = target.reshape(bs * dl) 119 | loss = nformer_loss_fn(score, feat, target) 120 | # print("Total loss is {}, center loss is {}".format(loss, center_criterion(feat, target))) 121 | loss.backward() 122 | optimizer.step() 123 | for param in nformer_center_criterion.parameters(): 124 | param.grad.data *= (1. 
/ cetner_loss_weight) 125 | optimizer_nformer_center.step() 126 | 127 | # compute acc 128 | acc = (score.max(1)[1] == target).float().mean() 129 | return loss.item(), acc.item() 130 | 131 | 132 | return Engine(_update) 133 | 134 | class data_collector(Metric): 135 | def reset(self): 136 | self.feat = [] 137 | self.target = [] 138 | def update(self,output): 139 | feat, target = output 140 | self.feat.append(feat) 141 | self.target.append(target) 142 | def compute(self): 143 | feat = torch.cat(self.feat, dim=0) 144 | #feat = torch.nn.functional.normalize(feat, dim=1, p=2) 145 | target = torch.cat(self.target, dim=0) 146 | return feat, target 147 | 148 | def create_nformer_data_generator(model, metrics, device=None): 149 | if device: 150 | if torch.cuda.device_count() > 1: 151 | model = nn.DataParallel(model) 152 | model.to(device) 153 | def _inference(engine, batch): 154 | model.eval() 155 | img, target = batch 156 | img = img.to(device) if torch.cuda.device_count() >= 1 else img 157 | with torch.no_grad(): 158 | feat = model(img, stage='encoder') 159 | return feat.cpu(), target 160 | 161 | engine = Engine(_inference) 162 | 163 | for name, metric in metrics.items(): 164 | metric.attach(engine, name) 165 | 166 | return engine 167 | 168 | def create_supervised_evaluator(model, metrics, 169 | device=None): 170 | """ 171 | Factory function for creating an evaluator for supervised models 172 | 173 | Args: 174 | model (`torch.nn.Module`): the model to train 175 | metrics (dict of str - :class:`ignite.metrics.Metric`): a map of metric names to Metrics 176 | device (str, optional): device type specification (default: None). 177 | Applies to both model and batches. 178 | Returns: 179 | Engine: an evaluator engine with supervised inference function 180 | """ 181 | if device: 182 | if torch.cuda.device_count() > 1: 183 | model = nn.DataParallel(model) 184 | model.to(device) 185 | 186 | def _inference(engine, batch): 187 | model.eval() 188 | with torch.no_grad(): 189 | data, pids, camids = batch 190 | data = data.to(device) if torch.cuda.device_count() >= 1 else data 191 | feat = model(data) 192 | return feat, pids, camids 193 | 194 | engine = Engine(_inference) 195 | 196 | for name, metric in metrics.items(): 197 | metric.attach(engine, name) 198 | 199 | return engine 200 | 201 | 202 | def do_train( 203 | cfg, 204 | model, 205 | train_loader, 206 | val_loader, 207 | optimizer, 208 | scheduler, 209 | loss_fn, 210 | num_query, 211 | start_epoch 212 | ): 213 | log_period = cfg.SOLVER.LOG_PERIOD 214 | checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD 215 | eval_period = cfg.SOLVER.EVAL_PERIOD 216 | output_dir = cfg.OUTPUT_DIR 217 | device = cfg.MODEL.DEVICE 218 | epochs = cfg.SOLVER.MAX_EPOCHS 219 | 220 | logger = logging.getLogger("reid_baseline.train") 221 | logger.info("Start training") 222 | trainer = create_supervised_trainer(model, optimizer, loss_fn, device=device) 223 | evaluator = create_supervised_evaluator(model, metrics={'r1_mAP': R1_mAP(num_query, max_rank=50, feat_norm=cfg.TEST.FEAT_NORM)}, device=device) 224 | data_generator = create_nformer_data_generator(model, metrics={'data':data_collector()}, device=device) 225 | checkpointer = ModelCheckpoint(output_dir, cfg.MODEL.NAME, checkpoint_period, n_saved=10, require_empty=False) 226 | timer = Timer(average=True) 227 | 228 | trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer, {'model': model, 229 | 'optimizer': optimizer}) 230 | timer.attach(trainer, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED, 231 | 
pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED) 232 | 233 | # average metric to attach on trainer 234 | RunningAverage(output_transform=lambda x: x[0]).attach(trainer, 'avg_loss') 235 | RunningAverage(output_transform=lambda x: x[1]).attach(trainer, 'avg_acc') 236 | 237 | @trainer.on(Events.STARTED) 238 | def start_training(engine): 239 | engine.state.epoch = start_epoch 240 | 241 | @trainer.on(Events.EPOCH_STARTED) 242 | def adjust_learning_rate(engine): 243 | scheduler.step() 244 | 245 | @trainer.on(Events.ITERATION_COMPLETED) 246 | def log_training_loss(engine): 247 | global ITER 248 | ITER += 1 249 | 250 | if ITER % log_period == 0: 251 | logger.info("Epoch[{}] Iteration[{}/{}] Loss: {:.3f}, Acc: {:.3f}, Base Lr: {:.2e}" 252 | .format(engine.state.epoch, ITER, len(train_loader), 253 | engine.state.metrics['avg_loss'], engine.state.metrics['avg_acc'], 254 | scheduler.get_lr()[0])) 255 | if len(train_loader) == ITER: 256 | ITER = 0 257 | 258 | @trainer.on(Events.EPOCH_COMPLETED) 259 | def tarin_nfofmer(engine): 260 | data_generator.run(train_loader) 261 | 262 | 263 | # adding handlers using `trainer.on` decorator API 264 | @trainer.on(Events.EPOCH_COMPLETED) 265 | def print_times(engine): 266 | logger.info('Epoch {} done. Time per batch: {:.3f}[s] Speed: {:.1f}[samples/s]' 267 | .format(engine.state.epoch, timer.value() * timer.step_count, 268 | train_loader.batch_size / timer.value())) 269 | logger.info('-' * 10) 270 | timer.reset() 271 | 272 | @trainer.on(Events.EPOCH_COMPLETED) 273 | def log_validation_results(engine): 274 | if engine.state.epoch % eval_period == 0: 275 | evaluator.run(val_loader) 276 | cmc, mAP = evaluator.state.metrics['r1_mAP'] 277 | logger.info("Validation Results - Epoch: {}".format(engine.state.epoch)) 278 | logger.info("mAP: {:.1%}".format(mAP)) 279 | for r in [1, 5, 10]: 280 | logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(r, cmc[r - 1])) 281 | 282 | trainer.run(train_loader, max_epochs=epochs) 283 | 284 | 285 | def do_train_with_center( 286 | cfg, 287 | model, 288 | center_criterion, 289 | nformer_center_criterion, 290 | train_loader, 291 | val_loader, 292 | optimizer, 293 | optimizer_center, 294 | optimizer_nformer, 295 | optimizer_nformer_center, 296 | scheduler, 297 | loss_fn, 298 | nformer_loss_fn, 299 | num_query, 300 | start_epoch 301 | ): 302 | log_period = cfg.SOLVER.LOG_PERIOD 303 | checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD 304 | eval_period = cfg.SOLVER.EVAL_PERIOD 305 | output_dir = cfg.OUTPUT_DIR 306 | device = cfg.MODEL.DEVICE 307 | epochs = cfg.SOLVER.MAX_EPOCHS 308 | nformer_epochs = cfg.SOLVER.NFORMER_MAX_EPOCHS 309 | 310 | logger = logging.getLogger("reid_baseline.train") 311 | logger.info("Start training") 312 | trainer = create_supervised_trainer_with_center(model, center_criterion, optimizer, optimizer_center, loss_fn, cfg.SOLVER.CENTER_LOSS_WEIGHT, device=device) 313 | nformer_trainer = create_supervised_nformer_trainer(model, nformer_center_criterion, optimizer_nformer, optimizer_nformer_center, nformer_loss_fn, cfg.SOLVER.CENTER_LOSS_WEIGHT, device=device) 314 | evaluator = create_supervised_evaluator(model, metrics={'r1_mAP': R1_mAP(num_query, max_rank=50, feat_norm=cfg.TEST.FEAT_NORM)}, device=device) 315 | nformer_evaluator = create_supervised_evaluator(model, metrics={'r1_mAP': NFormer_R1_mAP(model, num_query, max_rank=50, feat_norm=cfg.TEST.FEAT_NORM)}, device=device) 316 | data_generator = create_nformer_data_generator(model, metrics={'data':data_collector()}, device=device) 317 | checkpointer = 
ModelCheckpoint(output_dir, cfg.MODEL.NAME, checkpoint_period, n_saved=10, require_empty=False) 318 | timer = Timer(average=True) 319 | 320 | 321 | timer.attach(trainer, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED, 322 | pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED) 323 | 324 | # average metric to attach on trainer 325 | RunningAverage(output_transform=lambda x: x[0]).attach(trainer, 'avg_loss') 326 | RunningAverage(output_transform=lambda x: x[1]).attach(trainer, 'avg_acc') 327 | 328 | @trainer.on(Events.STARTED) 329 | def start_training(engine): 330 | engine.state.epoch = start_epoch 331 | 332 | @trainer.on(Events.EPOCH_STARTED) 333 | def adjust_learning_rate(engine): 334 | scheduler.step() 335 | 336 | @trainer.on(Events.ITERATION_COMPLETED) 337 | def log_training_loss(engine): 338 | global ITER 339 | ITER += 1 340 | 341 | if ITER % log_period == 0: 342 | logger.info("Epoch[{}] Iteration[{}/{}] Loss: {:.3f}, Acc: {:.3f}, Base Lr: {:.2e}" 343 | .format(engine.state.epoch, ITER, len(train_loader), 344 | engine.state.metrics['avg_loss'], engine.state.metrics['avg_acc'], 345 | scheduler.get_lr()[0])) 346 | if len(train_loader) == ITER: 347 | ITER = 0 348 | 349 | 350 | # adding handlers using `trainer.on` decorator API 351 | @trainer.on(Events.EPOCH_COMPLETED) 352 | def print_times(engine): 353 | logger.info('Epoch {} done. Time per batch: {:.3f}[s] Speed: {:.1f}[samples/s]' 354 | .format(engine.state.epoch, timer.value() * timer.step_count, 355 | train_loader.batch_size / timer.value())) 356 | logger.info('-' * 10) 357 | timer.reset() 358 | 359 | @trainer.on(Events.EPOCH_COMPLETED) 360 | def log_validation_results(engine): 361 | if engine.state.epoch % eval_period == 0: 362 | evaluator.run(val_loader) 363 | cmc, mAP = evaluator.state.metrics['r1_mAP'] 364 | logger.info("Validation Results - Epoch: {}".format(engine.state.epoch)) 365 | logger.info("mAP: {:.1%}".format(mAP)) 366 | for r in [1, 5, 10]: 367 | logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(r, cmc[r - 1])) 368 | 369 | @trainer.on(Events.EPOCH_COMPLETED) 370 | def tarin_nformer(engine): 371 | if engine.state.epoch < epochs: 372 | return 373 | for n_epoch in range(nformer_epochs): 374 | data_generator.run(train_loader) 375 | nformer_dataset = NFormerDataset(data_generator.state.metrics['data']) 376 | nformer_trainloader = data.DataLoader(nformer_dataset, batch_size=2, num_workers=1,shuffle = True, pin_memory=True) 377 | nformer_trainer.run(nformer_trainloader, max_epochs=1) 378 | if (n_epoch+1)%5 == 0: 379 | print('evaluate nformer at epoch {}'.format(n_epoch)) 380 | nformer_evaluator.run(val_loader) 381 | cmc, mAP = nformer_evaluator.state.metrics['r1_mAP'] 382 | logger.info("mAP: {:.1%}".format(mAP)) 383 | for r in [1, 5, 10]: 384 | logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(r, cmc[r - 1])) 385 | 386 | 387 | 388 | trainer.run(train_loader, max_epochs=epochs) 389 | if not os.path.exists(output_dir): 390 | os.makedirs(output_dir) 391 | torch.save(model.state_dict(), os.path.join(output_dir, 'nformer_model.pth')) 392 | -------------------------------------------------------------------------------- /layers/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import torch.nn.functional as F 8 | 9 | from .triplet_loss import TripletLoss, CrossEntropyLabelSmooth 10 | from .center_loss import CenterLoss 11 | 12 | 13 | def make_loss(cfg, 
num_classes): # modified by gu 14 | sampler = cfg.DATALOADER.SAMPLER 15 | if cfg.MODEL.METRIC_LOSS_TYPE == 'triplet': 16 | triplet = TripletLoss(cfg.SOLVER.MARGIN) # triplet loss 17 | else: 18 | print('expected METRIC_LOSS_TYPE should be triplet' 19 | 'but got {}'.format(cfg.MODEL.METRIC_LOSS_TYPE)) 20 | 21 | if cfg.MODEL.IF_LABELSMOOTH == 'on': 22 | xent = CrossEntropyLabelSmooth(num_classes=num_classes) # new add by luo 23 | print("label smooth on, numclasses:", num_classes) 24 | 25 | if sampler == 'softmax': 26 | def loss_func(score, feat, target): 27 | return F.cross_entropy(score, target) 28 | elif cfg.DATALOADER.SAMPLER == 'triplet': 29 | def loss_func(score, feat, target): 30 | return triplet(feat, target)[0] 31 | elif cfg.DATALOADER.SAMPLER == 'softmax_triplet': 32 | def loss_func(score, feat, target): 33 | if cfg.MODEL.METRIC_LOSS_TYPE == 'triplet': 34 | if cfg.MODEL.IF_LABELSMOOTH == 'on': 35 | return xent(score, target) + triplet(feat, target)[0] 36 | else: 37 | return F.cross_entropy(score, target) + triplet(feat, target)[0] 38 | else: 39 | print('expected METRIC_LOSS_TYPE should be triplet' 40 | 'but got {}'.format(cfg.MODEL.METRIC_LOSS_TYPE)) 41 | else: 42 | print('expected sampler should be softmax, triplet or softmax_triplet, ' 43 | 'but got {}'.format(cfg.DATALOADER.SAMPLER)) 44 | return loss_func 45 | 46 | 47 | def make_loss_with_center(cfg, num_classes): # modified by gu 48 | feat_dim = 256 49 | 50 | if cfg.MODEL.METRIC_LOSS_TYPE == 'center': 51 | center_criterion = CenterLoss(num_classes=num_classes, feat_dim=feat_dim, use_gpu=True) # center loss 52 | 53 | elif cfg.MODEL.METRIC_LOSS_TYPE == 'triplet_center': 54 | triplet = TripletLoss(cfg.SOLVER.MARGIN) # triplet loss 55 | center_criterion = CenterLoss(num_classes=num_classes, feat_dim=feat_dim, use_gpu=True) # center loss 56 | 57 | else: 58 | print('expected METRIC_LOSS_TYPE with center should be center, triplet_center' 59 | 'but got {}'.format(cfg.MODEL.METRIC_LOSS_TYPE)) 60 | 61 | if cfg.MODEL.IF_LABELSMOOTH == 'on': 62 | xent = CrossEntropyLabelSmooth(num_classes=num_classes) # new add by luo 63 | print("label smooth on, numclasses:", num_classes) 64 | 65 | def loss_func(score, feat, target): 66 | if cfg.MODEL.METRIC_LOSS_TYPE == 'center': 67 | if cfg.MODEL.IF_LABELSMOOTH == 'on': 68 | return xent(score, target) + \ 69 | cfg.SOLVER.CENTER_LOSS_WEIGHT * center_criterion(feat, target) 70 | else: 71 | return F.cross_entropy(score, target) + \ 72 | cfg.SOLVER.CENTER_LOSS_WEIGHT * center_criterion(feat, target) 73 | 74 | elif cfg.MODEL.METRIC_LOSS_TYPE == 'triplet_center': 75 | if cfg.MODEL.IF_LABELSMOOTH == 'on': 76 | return xent(score, target) + \ 77 | triplet(feat, target)[0] + \ 78 | cfg.SOLVER.CENTER_LOSS_WEIGHT * center_criterion(feat, target) 79 | else: 80 | return F.cross_entropy(score, target) + \ 81 | triplet(feat, target)[0] + \ 82 | cfg.SOLVER.CENTER_LOSS_WEIGHT * center_criterion(feat, target) 83 | 84 | else: 85 | print('expected METRIC_LOSS_TYPE with center should be center, triplet_center' 86 | 'but got {}'.format(cfg.MODEL.METRIC_LOSS_TYPE)) 87 | return loss_func, center_criterion 88 | 89 | def make_nformer_loss_with_center(cfg, num_classes): 90 | feat_dim = 256 91 | triplet = TripletLoss(cfg.SOLVER.MARGIN) # triplet loss 92 | center_criterion = CenterLoss(num_classes=num_classes, feat_dim=feat_dim, use_gpu=True) 93 | xent = CrossEntropyLabelSmooth(num_classes=num_classes) # new add by luo 94 | def loss_func(score, feat, target): 95 | if cfg.MODEL.IF_LABELSMOOTH == 'on': 96 | return xent(score, target) + 
\ 97 | triplet(feat, target)[0] + \ 98 | cfg.SOLVER.CENTER_LOSS_WEIGHT * center_criterion(feat, target) 99 | else: 100 | return F.cross_entropy(score, target) + \ 101 | triplet(feat, target)[0] + \ 102 | cfg.SOLVER.CENTER_LOSS_WEIGHT * center_criterion(feat, target) 103 | return loss_func, center_criterion 104 | 105 | -------------------------------------------------------------------------------- /layers/center_loss.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | from torch import nn 5 | 6 | 7 | class CenterLoss(nn.Module): 8 | """Center loss. 9 | 10 | Reference: 11 | Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016. 12 | 13 | Args: 14 | num_classes (int): number of classes. 15 | feat_dim (int): feature dimension. 16 | """ 17 | 18 | def __init__(self, num_classes=751, feat_dim=2048, use_gpu=True): 19 | super(CenterLoss, self).__init__() 20 | self.num_classes = num_classes 21 | self.feat_dim = feat_dim 22 | self.use_gpu = use_gpu 23 | 24 | if self.use_gpu: 25 | self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim).cuda()) 26 | else: 27 | self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim)) 28 | 29 | def forward(self, x, labels): 30 | """ 31 | Args: 32 | x: feature matrix with shape (batch_size, feat_dim). 33 | labels: ground truth labels with shape (num_classes). 34 | """ 35 | assert x.size(0) == labels.size(0), "features.size(0) is not equal to labels.size(0)" 36 | 37 | batch_size = x.size(0) 38 | distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \ 39 | torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t() 40 | distmat.addmm_(1, -2, x, self.centers.t()) 41 | 42 | classes = torch.arange(self.num_classes).long() 43 | if self.use_gpu: classes = classes.cuda() 44 | labels = labels.unsqueeze(1).expand(batch_size, self.num_classes) 45 | mask = labels.eq(classes.expand(batch_size, self.num_classes)) 46 | 47 | dist = distmat * mask.float() 48 | loss = dist.clamp(min=1e-12, max=1e+12).sum() / batch_size 49 | #dist = [] 50 | #for i in range(batch_size): 51 | # value = distmat[i][mask[i]] 52 | # value = value.clamp(min=1e-12, max=1e+12) # for numerical stability 53 | # dist.append(value) 54 | #dist = torch.cat(dist) 55 | #loss = dist.mean() 56 | return loss 57 | 58 | 59 | if __name__ == '__main__': 60 | use_gpu = False 61 | center_loss = CenterLoss(use_gpu=use_gpu) 62 | features = torch.rand(16, 2048) 63 | targets = torch.Tensor([0, 1, 2, 3, 2, 3, 1, 4, 5, 3, 2, 1, 0, 0, 5, 4]).long() 64 | if use_gpu: 65 | features = torch.rand(16, 2048).cuda() 66 | targets = torch.Tensor([0, 1, 2, 3, 2, 3, 1, 4, 5, 3, 2, 1, 0, 0, 5, 4]).cuda() 67 | 68 | loss = center_loss(features, targets) 69 | print(loss) 70 | -------------------------------------------------------------------------------- /layers/triplet_loss.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | import torch 7 | from torch import nn 8 | 9 | 10 | def normalize(x, axis=-1): 11 | """Normalizing to unit length along the specified dimension. 12 | Args: 13 | x: pytorch Variable 14 | Returns: 15 | x: pytorch Variable, same shape as input 16 | """ 17 | x = 1. 
* x / (torch.norm(x, 2, axis, keepdim=True).expand_as(x) + 1e-12) 18 | return x 19 | 20 | 21 | def euclidean_dist(x, y): 22 | """ 23 | Args: 24 | x: pytorch Variable, with shape [m, d] 25 | y: pytorch Variable, with shape [n, d] 26 | Returns: 27 | dist: pytorch Variable, with shape [m, n] 28 | """ 29 | m, n = x.size(0), y.size(0) 30 | xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n) 31 | yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t() 32 | dist = xx + yy 33 | dist.addmm_(1, -2, x, y.t()) 34 | dist = dist.clamp(min=1e-12).sqrt() # for numerical stability 35 | return dist 36 | 37 | 38 | def hard_example_mining(dist_mat, labels, return_inds=False): 39 | """For each anchor, find the hardest positive and negative sample. 40 | Args: 41 | dist_mat: pytorch Variable, pair wise distance between samples, shape [N, N] 42 | labels: pytorch LongTensor, with shape [N] 43 | return_inds: whether to return the indices. Save time if `False`(?) 44 | Returns: 45 | dist_ap: pytorch Variable, distance(anchor, positive); shape [N] 46 | dist_an: pytorch Variable, distance(anchor, negative); shape [N] 47 | p_inds: pytorch LongTensor, with shape [N]; 48 | indices of selected hard positive samples; 0 <= p_inds[i] <= N - 1 49 | n_inds: pytorch LongTensor, with shape [N]; 50 | indices of selected hard negative samples; 0 <= n_inds[i] <= N - 1 51 | NOTE: Only consider the case in which all labels have same num of samples, 52 | thus we can cope with all anchors in parallel. 53 | """ 54 | 55 | assert len(dist_mat.size()) == 2 56 | assert dist_mat.size(0) == dist_mat.size(1) 57 | N = dist_mat.size(0) 58 | 59 | # shape [N, N] 60 | is_pos = labels.expand(N, N).eq(labels.expand(N, N).t()) 61 | is_neg = labels.expand(N, N).ne(labels.expand(N, N).t()) 62 | 63 | # `dist_ap` means distance(anchor, positive) 64 | # both `dist_ap` and `relative_p_inds` with shape [N, 1] 65 | dist_ap, relative_p_inds = torch.max( 66 | dist_mat * is_pos, 1, keepdim=True) 67 | #dist_ap, relative_p_inds = torch.max( 68 | # dist_mat[is_pos].contiguous().view(N, -1), 1, keepdim=True) 69 | # `dist_an` means distance(anchor, negative) 70 | # both `dist_an` and `relative_n_inds` with shape [N, 1] 71 | dist_an, relative_n_inds = torch.min( 72 | dist_mat * is_neg + is_pos * 1e8, 1, keepdim=True) 73 | #dist_an, relative_n_inds = torch.min( 74 | # dist_mat[is_neg].contiguous().view(N, -1), 1, keepdim=True) 75 | # shape [N] 76 | dist_ap = dist_ap.squeeze(1) 77 | dist_an = dist_an.squeeze(1) 78 | 79 | if return_inds: 80 | # shape [N, N] 81 | ind = (labels.new().resize_as_(labels) 82 | .copy_(torch.arange(0, N).long()) 83 | .unsqueeze(0).expand(N, N)) 84 | # shape [N, 1] 85 | p_inds = torch.gather( 86 | ind[is_pos].contiguous().view(N, -1), 1, relative_p_inds.data) 87 | n_inds = torch.gather( 88 | ind[is_neg].contiguous().view(N, -1), 1, relative_n_inds.data) 89 | # shape [N] 90 | p_inds = p_inds.squeeze(1) 91 | n_inds = n_inds.squeeze(1) 92 | return dist_ap, dist_an, p_inds, n_inds 93 | 94 | return dist_ap, dist_an 95 | 96 | 97 | class TripletLoss(object): 98 | """Modified from Tong Xiao's open-reid (https://github.com/Cysu/open-reid). 
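    A minimal usage sketch (the margin value here is illustrative; `make_loss` above
    passes cfg.SOLVER.MARGIN):

        triplet = TripletLoss(margin=0.3)
        loss, dist_ap, dist_an = triplet(global_feat, labels)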
99 | Related Triplet Loss theory can be found in paper 'In Defense of the Triplet 100 | Loss for Person Re-Identification'.""" 101 | 102 | def __init__(self, margin=None): 103 | self.margin = margin 104 | if margin is not None: 105 | self.ranking_loss = nn.MarginRankingLoss(margin=margin) 106 | else: 107 | self.ranking_loss = nn.SoftMarginLoss() 108 | 109 | def __call__(self, global_feat, labels, normalize_feature=False): 110 | if normalize_feature: 111 | global_feat = normalize(global_feat, axis=-1) 112 | dist_mat = euclidean_dist(global_feat, global_feat) 113 | dist_ap, dist_an = hard_example_mining( 114 | dist_mat, labels) 115 | y = dist_an.new().resize_as_(dist_an).fill_(1) 116 | if self.margin is not None: 117 | loss = self.ranking_loss(dist_an, dist_ap, y) 118 | else: 119 | loss = self.ranking_loss(dist_an - dist_ap, y) 120 | return loss, dist_ap, dist_an 121 | 122 | class CrossEntropyLabelSmooth(nn.Module): 123 | """Cross entropy loss with label smoothing regularizer. 124 | 125 | Reference: 126 | Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016. 127 | Equation: y = (1 - epsilon) * y + epsilon / K. 128 | 129 | Args: 130 | num_classes (int): number of classes. 131 | epsilon (float): weight. 132 | """ 133 | def __init__(self, num_classes, epsilon=0.1, use_gpu=True): 134 | super(CrossEntropyLabelSmooth, self).__init__() 135 | self.num_classes = num_classes 136 | self.epsilon = epsilon 137 | self.use_gpu = use_gpu 138 | self.logsoftmax = nn.LogSoftmax(dim=1) 139 | 140 | def forward(self, inputs, targets): 141 | """ 142 | Args: 143 | inputs: prediction matrix (before softmax) with shape (batch_size, num_classes) 144 | targets: ground truth labels with shape (num_classes) 145 | """ 146 | log_probs = self.logsoftmax(inputs) 147 | targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).data.cpu(), 1) 148 | if self.use_gpu: targets = targets.cuda() 149 | targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes 150 | loss = (- targets * log_probs).mean(0).sum() 151 | return loss 152 | -------------------------------------------------------------------------------- /modeling/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | from .baseline import Baseline 8 | from .model import nformer_model 9 | 10 | 11 | def build_model(cfg, num_classes): 12 | # if cfg.MODEL.NAME == 'resnet50': 13 | # model = Baseline(num_classes, cfg.MODEL.LAST_STRIDE, cfg.MODEL.PRETRAIN_PATH, cfg.MODEL.NECK, cfg.TEST.NECK_FEAT) 14 | model = Baseline(num_classes, cfg.MODEL.LAST_STRIDE, cfg.MODEL.PRETRAIN_PATH, cfg.MODEL.NECK, cfg.TEST.NECK_FEAT, cfg.MODEL.NAME, cfg.MODEL.PRETRAIN_CHOICE) 15 | return model 16 | 17 | def build_nformer_model(cfg, num_classes): 18 | model = nformer_model(cfg, num_classes) 19 | return model 20 | -------------------------------------------------------------------------------- /modeling/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | -------------------------------------------------------------------------------- /modeling/backbones/resnet.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 
| 7 | import math 8 | 9 | import torch 10 | from torch import nn 11 | 12 | 13 | def conv3x3(in_planes, out_planes, stride=1): 14 | """3x3 convolution with padding""" 15 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 16 | padding=1, bias=False) 17 | 18 | 19 | class BasicBlock(nn.Module): 20 | expansion = 1 21 | 22 | def __init__(self, inplanes, planes, stride=1, downsample=None): 23 | super(BasicBlock, self).__init__() 24 | self.conv1 = conv3x3(inplanes, planes, stride) 25 | self.bn1 = nn.BatchNorm2d(planes) 26 | self.relu = nn.ReLU(inplace=True) 27 | self.conv2 = conv3x3(planes, planes) 28 | self.bn2 = nn.BatchNorm2d(planes) 29 | self.downsample = downsample 30 | self.stride = stride 31 | 32 | def forward(self, x): 33 | residual = x 34 | 35 | out = self.conv1(x) 36 | out = self.bn1(out) 37 | out = self.relu(out) 38 | 39 | out = self.conv2(out) 40 | out = self.bn2(out) 41 | 42 | if self.downsample is not None: 43 | residual = self.downsample(x) 44 | 45 | out += residual 46 | out = self.relu(out) 47 | 48 | return out 49 | 50 | 51 | class Bottleneck(nn.Module): 52 | expansion = 4 53 | 54 | def __init__(self, inplanes, planes, stride=1, downsample=None): 55 | super(Bottleneck, self).__init__() 56 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 57 | self.bn1 = nn.BatchNorm2d(planes) 58 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 59 | padding=1, bias=False) 60 | self.bn2 = nn.BatchNorm2d(planes) 61 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 62 | self.bn3 = nn.BatchNorm2d(planes * 4) 63 | self.relu = nn.ReLU(inplace=True) 64 | self.downsample = downsample 65 | self.stride = stride 66 | 67 | def forward(self, x): 68 | residual = x 69 | 70 | out = self.conv1(x) 71 | out = self.bn1(out) 72 | out = self.relu(out) 73 | 74 | out = self.conv2(out) 75 | out = self.bn2(out) 76 | out = self.relu(out) 77 | 78 | out = self.conv3(out) 79 | out = self.bn3(out) 80 | 81 | if self.downsample is not None: 82 | residual = self.downsample(x) 83 | 84 | out += residual 85 | out = self.relu(out) 86 | 87 | return out 88 | 89 | 90 | class ResNet(nn.Module): 91 | def __init__(self, last_stride=2, block=Bottleneck, layers=[3, 4, 6, 3]): 92 | self.inplanes = 64 93 | super().__init__() 94 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 95 | bias=False) 96 | self.bn1 = nn.BatchNorm2d(64) 97 | # self.relu = nn.ReLU(inplace=True) # add missed relu 98 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 99 | self.layer1 = self._make_layer(block, 64, layers[0]) 100 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 101 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 102 | self.layer4 = self._make_layer( 103 | block, 512, layers[3], stride=last_stride) 104 | 105 | def _make_layer(self, block, planes, blocks, stride=1): 106 | downsample = None 107 | if stride != 1 or self.inplanes != planes * block.expansion: 108 | downsample = nn.Sequential( 109 | nn.Conv2d(self.inplanes, planes * block.expansion, 110 | kernel_size=1, stride=stride, bias=False), 111 | nn.BatchNorm2d(planes * block.expansion), 112 | ) 113 | 114 | layers = [] 115 | layers.append(block(self.inplanes, planes, stride, downsample)) 116 | self.inplanes = planes * block.expansion 117 | for i in range(1, blocks): 118 | layers.append(block(self.inplanes, planes)) 119 | 120 | return nn.Sequential(*layers) 121 | 122 | def forward(self, x): 123 | x = self.conv1(x) 124 | x = self.bn1(x) 125 | # x = self.relu(x) 
# add missed relu 126 | x = self.maxpool(x) 127 | 128 | x = self.layer1(x) 129 | x = self.layer2(x) 130 | x = self.layer3(x) 131 | x = self.layer4(x) 132 | 133 | return x 134 | 135 | def load_param(self, model_path): 136 | param_dict = torch.load(model_path) 137 | for i in param_dict: 138 | if 'fc' in i: 139 | continue 140 | self.state_dict()[i].copy_(param_dict[i]) 141 | 142 | def random_init(self): 143 | for m in self.modules(): 144 | if isinstance(m, nn.Conv2d): 145 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 146 | m.weight.data.normal_(0, math.sqrt(2. / n)) 147 | elif isinstance(m, nn.BatchNorm2d): 148 | m.weight.data.fill_(1) 149 | m.bias.data.zero_() 150 | 151 | -------------------------------------------------------------------------------- /modeling/backbones/resnet_ibn_a.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | import torch.utils.model_zoo as model_zoo 5 | 6 | 7 | __all__ = ['ResNet_IBN', 'resnet50_ibn_a', 'resnet101_ibn_a', 8 | 'resnet152_ibn_a'] 9 | 10 | 11 | model_urls = { 12 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 13 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 14 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 15 | } 16 | 17 | 18 | class IBN(nn.Module): 19 | def __init__(self, planes): 20 | super(IBN, self).__init__() 21 | half1 = int(planes/2) 22 | self.half = half1 23 | half2 = planes - half1 24 | self.IN = nn.InstanceNorm2d(half1, affine=True) 25 | self.BN = nn.BatchNorm2d(half2) 26 | 27 | def forward(self, x): 28 | split = torch.split(x, self.half, 1) 29 | out1 = self.IN(split[0].contiguous()) 30 | out2 = self.BN(split[1].contiguous()) 31 | out = torch.cat((out1, out2), 1) 32 | return out 33 | 34 | 35 | class Bottleneck_IBN(nn.Module): 36 | expansion = 4 37 | 38 | def __init__(self, inplanes, planes, ibn=False, stride=1, downsample=None): 39 | super(Bottleneck_IBN, self).__init__() 40 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 41 | if ibn: 42 | self.bn1 = IBN(planes) 43 | else: 44 | self.bn1 = nn.BatchNorm2d(planes) 45 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 46 | padding=1, bias=False) 47 | self.bn2 = nn.BatchNorm2d(planes) 48 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) 49 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 50 | self.relu = nn.ReLU(inplace=True) 51 | self.downsample = downsample 52 | self.stride = stride 53 | 54 | def forward(self, x): 55 | residual = x 56 | 57 | out = self.conv1(x) 58 | out = self.bn1(out) 59 | out = self.relu(out) 60 | 61 | out = self.conv2(out) 62 | out = self.bn2(out) 63 | out = self.relu(out) 64 | 65 | out = self.conv3(out) 66 | out = self.bn3(out) 67 | 68 | if self.downsample is not None: 69 | residual = self.downsample(x) 70 | 71 | out += residual 72 | out = self.relu(out) 73 | 74 | return out 75 | 76 | 77 | class ResNet_IBN(nn.Module): 78 | 79 | def __init__(self, last_stride, block, layers, num_classes=1000): 80 | scale = 64 81 | self.inplanes = scale 82 | super(ResNet_IBN, self).__init__() 83 | self.conv1 = nn.Conv2d(3, scale, kernel_size=7, stride=2, padding=3, 84 | bias=False) 85 | self.bn1 = nn.BatchNorm2d(scale) 86 | self.relu = nn.ReLU(inplace=True) 87 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 88 | self.layer1 = self._make_layer(block, scale, layers[0]) 89 | self.layer2 = 
self._make_layer(block, scale*2, layers[1], stride=2) 90 | self.layer3 = self._make_layer(block, scale*4, layers[2], stride=2) 91 | self.layer4 = self._make_layer(block, scale*8, layers[3], stride=last_stride) 92 | self.avgpool = nn.AvgPool2d(7) 93 | self.fc = nn.Linear(scale * 8 * block.expansion, num_classes) 94 | 95 | for m in self.modules(): 96 | if isinstance(m, nn.Conv2d): 97 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 98 | m.weight.data.normal_(0, math.sqrt(2. / n)) 99 | elif isinstance(m, nn.BatchNorm2d): 100 | m.weight.data.fill_(1) 101 | m.bias.data.zero_() 102 | elif isinstance(m, nn.InstanceNorm2d): 103 | m.weight.data.fill_(1) 104 | m.bias.data.zero_() 105 | 106 | def _make_layer(self, block, planes, blocks, stride=1): 107 | downsample = None 108 | if stride != 1 or self.inplanes != planes * block.expansion: 109 | downsample = nn.Sequential( 110 | nn.Conv2d(self.inplanes, planes * block.expansion, 111 | kernel_size=1, stride=stride, bias=False), 112 | nn.BatchNorm2d(planes * block.expansion), 113 | ) 114 | 115 | layers = [] 116 | ibn = True 117 | if planes == 512: 118 | ibn = False 119 | layers.append(block(self.inplanes, planes, ibn, stride, downsample)) 120 | self.inplanes = planes * block.expansion 121 | for i in range(1, blocks): 122 | layers.append(block(self.inplanes, planes, ibn)) 123 | 124 | return nn.Sequential(*layers) 125 | 126 | def forward(self, x): 127 | x = self.conv1(x) 128 | x = self.bn1(x) 129 | x = self.relu(x) 130 | x = self.maxpool(x) 131 | 132 | x = self.layer1(x) 133 | x = self.layer2(x) 134 | x = self.layer3(x) 135 | x = self.layer4(x) 136 | 137 | # x = self.avgpool(x) 138 | # x = x.view(x.size(0), -1) 139 | # x = self.fc(x) 140 | 141 | return x 142 | 143 | def load_param(self, model_path): 144 | param_dict = torch.load(model_path) 145 | for i in param_dict: 146 | if 'fc' in i: 147 | continue 148 | self.state_dict()[i].copy_(param_dict[i]) 149 | 150 | 151 | def resnet50_ibn_a(last_stride, pretrained=False, **kwargs): 152 | """Constructs a ResNet-50 model. 153 | Args: 154 | pretrained (bool): If True, returns a model pre-trained on ImageNet 155 | """ 156 | model = ResNet_IBN(last_stride, Bottleneck_IBN, [3, 4, 6, 3], **kwargs) 157 | if pretrained: 158 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 159 | return model 160 | 161 | 162 | def resnet101_ibn_a(last_stride, pretrained=False, **kwargs): 163 | """Constructs a ResNet-101 model. 164 | Args: 165 | pretrained (bool): If True, returns a model pre-trained on ImageNet 166 | """ 167 | model = ResNet_IBN(last_stride, Bottleneck_IBN, [3, 4, 23, 3], **kwargs) 168 | if pretrained: 169 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 170 | return model 171 | 172 | 173 | def resnet152_ibn_a(last_stride, pretrained=False, **kwargs): 174 | """Constructs a ResNet-152 model. 
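    Usage sketch (illustrative): model = resnet152_ibn_a(last_stride=1, pretrained=True)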
175 | Args: 176 | pretrained (bool): If True, returns a model pre-trained on ImageNet 177 | """ 178 | model = ResNet_IBN(last_stride, Bottleneck_IBN, [3, 8, 36, 3], **kwargs) 179 | if pretrained: 180 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 181 | return model -------------------------------------------------------------------------------- /modeling/backbones/senet.py: -------------------------------------------------------------------------------- 1 | """ 2 | ResNet code gently borrowed from 3 | https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 4 | """ 5 | from __future__ import print_function, division, absolute_import 6 | from collections import OrderedDict 7 | import math 8 | import torch 9 | import torch.nn as nn 10 | from torch.utils import model_zoo 11 | 12 | __all__ = ['SENet', 'senet154', 'se_resnet50', 'se_resnet101', 'se_resnet152', 13 | 'se_resnext50_32x4d', 'se_resnext101_32x4d'] 14 | 15 | pretrained_settings = { 16 | 'senet154': { 17 | 'imagenet': { 18 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/senet154-c7b49a05.pth', 19 | 'input_space': 'RGB', 20 | 'input_size': [3, 224, 224], 21 | 'input_range': [0, 1], 22 | 'mean': [0.485, 0.456, 0.406], 23 | 'std': [0.229, 0.224, 0.225], 24 | 'num_classes': 1000 25 | } 26 | }, 27 | 'se_resnet50': { 28 | 'imagenet': { 29 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet50-ce0d4300.pth', 30 | 'input_space': 'RGB', 31 | 'input_size': [3, 224, 224], 32 | 'input_range': [0, 1], 33 | 'mean': [0.485, 0.456, 0.406], 34 | 'std': [0.229, 0.224, 0.225], 35 | 'num_classes': 1000 36 | } 37 | }, 38 | 'se_resnet101': { 39 | 'imagenet': { 40 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet101-7e38fcc6.pth', 41 | 'input_space': 'RGB', 42 | 'input_size': [3, 224, 224], 43 | 'input_range': [0, 1], 44 | 'mean': [0.485, 0.456, 0.406], 45 | 'std': [0.229, 0.224, 0.225], 46 | 'num_classes': 1000 47 | } 48 | }, 49 | 'se_resnet152': { 50 | 'imagenet': { 51 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet152-d17c99b7.pth', 52 | 'input_space': 'RGB', 53 | 'input_size': [3, 224, 224], 54 | 'input_range': [0, 1], 55 | 'mean': [0.485, 0.456, 0.406], 56 | 'std': [0.229, 0.224, 0.225], 57 | 'num_classes': 1000 58 | } 59 | }, 60 | 'se_resnext50_32x4d': { 61 | 'imagenet': { 62 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnext50_32x4d-a260b3a4.pth', 63 | 'input_space': 'RGB', 64 | 'input_size': [3, 224, 224], 65 | 'input_range': [0, 1], 66 | 'mean': [0.485, 0.456, 0.406], 67 | 'std': [0.229, 0.224, 0.225], 68 | 'num_classes': 1000 69 | } 70 | }, 71 | 'se_resnext101_32x4d': { 72 | 'imagenet': { 73 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnext101_32x4d-3b2fe3d8.pth', 74 | 'input_space': 'RGB', 75 | 'input_size': [3, 224, 224], 76 | 'input_range': [0, 1], 77 | 'mean': [0.485, 0.456, 0.406], 78 | 'std': [0.229, 0.224, 0.225], 79 | 'num_classes': 1000 80 | } 81 | }, 82 | } 83 | 84 | 85 | class SEModule(nn.Module): 86 | 87 | def __init__(self, channels, reduction): 88 | super(SEModule, self).__init__() 89 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 90 | self.fc1 = nn.Conv2d(channels, channels // reduction, kernel_size=1, 91 | padding=0) 92 | self.relu = nn.ReLU(inplace=True) 93 | self.fc2 = nn.Conv2d(channels // reduction, channels, kernel_size=1, 94 | padding=0) 95 | self.sigmoid = nn.Sigmoid() 96 | 97 | def forward(self, x): 98 | module_input = x 99 | x = self.avg_pool(x) 100 | x = self.fc1(x) 101 | x = self.relu(x) 102 | x = self.fc2(x) 
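        # the excitation branch ends with a sigmoid, so each channel gets a gate in [0, 1];
        # these gates rescale the original input feature map channel-wise in the return below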
103 | x = self.sigmoid(x) 104 | return module_input * x 105 | 106 | 107 | class Bottleneck(nn.Module): 108 | """ 109 | Base class for bottlenecks that implements `forward()` method. 110 | """ 111 | def forward(self, x): 112 | residual = x 113 | 114 | out = self.conv1(x) 115 | out = self.bn1(out) 116 | out = self.relu(out) 117 | 118 | out = self.conv2(out) 119 | out = self.bn2(out) 120 | out = self.relu(out) 121 | 122 | out = self.conv3(out) 123 | out = self.bn3(out) 124 | 125 | if self.downsample is not None: 126 | residual = self.downsample(x) 127 | 128 | out = self.se_module(out) + residual 129 | out = self.relu(out) 130 | 131 | return out 132 | 133 | 134 | class SEBottleneck(Bottleneck): 135 | """ 136 | Bottleneck for SENet154. 137 | """ 138 | expansion = 4 139 | 140 | def __init__(self, inplanes, planes, groups, reduction, stride=1, 141 | downsample=None): 142 | super(SEBottleneck, self).__init__() 143 | self.conv1 = nn.Conv2d(inplanes, planes * 2, kernel_size=1, bias=False) 144 | self.bn1 = nn.BatchNorm2d(planes * 2) 145 | self.conv2 = nn.Conv2d(planes * 2, planes * 4, kernel_size=3, 146 | stride=stride, padding=1, groups=groups, 147 | bias=False) 148 | self.bn2 = nn.BatchNorm2d(planes * 4) 149 | self.conv3 = nn.Conv2d(planes * 4, planes * 4, kernel_size=1, 150 | bias=False) 151 | self.bn3 = nn.BatchNorm2d(planes * 4) 152 | self.relu = nn.ReLU(inplace=True) 153 | self.se_module = SEModule(planes * 4, reduction=reduction) 154 | self.downsample = downsample 155 | self.stride = stride 156 | 157 | 158 | class SEResNetBottleneck(Bottleneck): 159 | """ 160 | ResNet bottleneck with a Squeeze-and-Excitation module. It follows Caffe 161 | implementation and uses `stride=stride` in `conv1` and not in `conv2` 162 | (the latter is used in the torchvision implementation of ResNet). 163 | """ 164 | expansion = 4 165 | 166 | def __init__(self, inplanes, planes, groups, reduction, stride=1, 167 | downsample=None): 168 | super(SEResNetBottleneck, self).__init__() 169 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False, 170 | stride=stride) 171 | self.bn1 = nn.BatchNorm2d(planes) 172 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, 173 | groups=groups, bias=False) 174 | self.bn2 = nn.BatchNorm2d(planes) 175 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 176 | self.bn3 = nn.BatchNorm2d(planes * 4) 177 | self.relu = nn.ReLU(inplace=True) 178 | self.se_module = SEModule(planes * 4, reduction=reduction) 179 | self.downsample = downsample 180 | self.stride = stride 181 | 182 | 183 | class SEResNeXtBottleneck(Bottleneck): 184 | """ 185 | ResNeXt bottleneck type C with a Squeeze-and-Excitation module. 
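    The grouped 3x3 convolution runs on width = floor(planes * (base_width / 64)) * groups
    channels (e.g. 128 channels for planes=64, base_width=4, groups=32).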
186 | """ 187 | expansion = 4 188 | 189 | def __init__(self, inplanes, planes, groups, reduction, stride=1, 190 | downsample=None, base_width=4): 191 | super(SEResNeXtBottleneck, self).__init__() 192 | width = math.floor(planes * (base_width / 64)) * groups 193 | self.conv1 = nn.Conv2d(inplanes, width, kernel_size=1, bias=False, 194 | stride=1) 195 | self.bn1 = nn.BatchNorm2d(width) 196 | self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride, 197 | padding=1, groups=groups, bias=False) 198 | self.bn2 = nn.BatchNorm2d(width) 199 | self.conv3 = nn.Conv2d(width, planes * 4, kernel_size=1, bias=False) 200 | self.bn3 = nn.BatchNorm2d(planes * 4) 201 | self.relu = nn.ReLU(inplace=True) 202 | self.se_module = SEModule(planes * 4, reduction=reduction) 203 | self.downsample = downsample 204 | self.stride = stride 205 | 206 | 207 | class SENet(nn.Module): 208 | 209 | def __init__(self, block, layers, groups, reduction, dropout_p=0.2, 210 | inplanes=128, input_3x3=True, downsample_kernel_size=3, 211 | downsample_padding=1, last_stride=2): 212 | """ 213 | Parameters 214 | ---------- 215 | block (nn.Module): Bottleneck class. 216 | - For SENet154: SEBottleneck 217 | - For SE-ResNet models: SEResNetBottleneck 218 | - For SE-ResNeXt models: SEResNeXtBottleneck 219 | layers (list of ints): Number of residual blocks for 4 layers of the 220 | network (layer1...layer4). 221 | groups (int): Number of groups for the 3x3 convolution in each 222 | bottleneck block. 223 | - For SENet154: 64 224 | - For SE-ResNet models: 1 225 | - For SE-ResNeXt models: 32 226 | reduction (int): Reduction ratio for Squeeze-and-Excitation modules. 227 | - For all models: 16 228 | dropout_p (float or None): Drop probability for the Dropout layer. 229 | If `None` the Dropout layer is not used. 230 | - For SENet154: 0.2 231 | - For SE-ResNet models: None 232 | - For SE-ResNeXt models: None 233 | inplanes (int): Number of input channels for layer1. 234 | - For SENet154: 128 235 | - For SE-ResNet models: 64 236 | - For SE-ResNeXt models: 64 237 | input_3x3 (bool): If `True`, use three 3x3 convolutions instead of 238 | a single 7x7 convolution in layer0. 239 | - For SENet154: True 240 | - For SE-ResNet models: False 241 | - For SE-ResNeXt models: False 242 | downsample_kernel_size (int): Kernel size for downsampling convolutions 243 | in layer2, layer3 and layer4. 244 | - For SENet154: 3 245 | - For SE-ResNet models: 1 246 | - For SE-ResNeXt models: 1 247 | downsample_padding (int): Padding for downsampling convolutions in 248 | layer2, layer3 and layer4. 249 | - For SENet154: 1 250 | - For SE-ResNet models: 0 251 | - For SE-ResNeXt models: 0 252 | num_classes (int): Number of outputs in `last_linear` layer. 
253 | - For all models: 1000 254 | """ 255 | super(SENet, self).__init__() 256 | self.inplanes = inplanes 257 | if input_3x3: 258 | layer0_modules = [ 259 | ('conv1', nn.Conv2d(3, 64, 3, stride=2, padding=1, 260 | bias=False)), 261 | ('bn1', nn.BatchNorm2d(64)), 262 | ('relu1', nn.ReLU(inplace=True)), 263 | ('conv2', nn.Conv2d(64, 64, 3, stride=1, padding=1, 264 | bias=False)), 265 | ('bn2', nn.BatchNorm2d(64)), 266 | ('relu2', nn.ReLU(inplace=True)), 267 | ('conv3', nn.Conv2d(64, inplanes, 3, stride=1, padding=1, 268 | bias=False)), 269 | ('bn3', nn.BatchNorm2d(inplanes)), 270 | ('relu3', nn.ReLU(inplace=True)), 271 | ] 272 | else: 273 | layer0_modules = [ 274 | ('conv1', nn.Conv2d(3, inplanes, kernel_size=7, stride=2, 275 | padding=3, bias=False)), 276 | ('bn1', nn.BatchNorm2d(inplanes)), 277 | ('relu1', nn.ReLU(inplace=True)), 278 | ] 279 | # To preserve compatibility with Caffe weights `ceil_mode=True` 280 | # is used instead of `padding=1`. 281 | layer0_modules.append(('pool', nn.MaxPool2d(3, stride=2, 282 | ceil_mode=True))) 283 | self.layer0 = nn.Sequential(OrderedDict(layer0_modules)) 284 | self.layer1 = self._make_layer( 285 | block, 286 | planes=64, 287 | blocks=layers[0], 288 | groups=groups, 289 | reduction=reduction, 290 | downsample_kernel_size=1, 291 | downsample_padding=0 292 | ) 293 | self.layer2 = self._make_layer( 294 | block, 295 | planes=128, 296 | blocks=layers[1], 297 | stride=2, 298 | groups=groups, 299 | reduction=reduction, 300 | downsample_kernel_size=downsample_kernel_size, 301 | downsample_padding=downsample_padding 302 | ) 303 | self.layer3 = self._make_layer( 304 | block, 305 | planes=256, 306 | blocks=layers[2], 307 | stride=2, 308 | groups=groups, 309 | reduction=reduction, 310 | downsample_kernel_size=downsample_kernel_size, 311 | downsample_padding=downsample_padding 312 | ) 313 | self.layer4 = self._make_layer( 314 | block, 315 | planes=512, 316 | blocks=layers[3], 317 | stride=last_stride, 318 | groups=groups, 319 | reduction=reduction, 320 | downsample_kernel_size=downsample_kernel_size, 321 | downsample_padding=downsample_padding 322 | ) 323 | self.avg_pool = nn.AvgPool2d(7, stride=1) 324 | self.dropout = nn.Dropout(dropout_p) if dropout_p is not None else None 325 | 326 | def _make_layer(self, block, planes, blocks, groups, reduction, stride=1, 327 | downsample_kernel_size=1, downsample_padding=0): 328 | downsample = None 329 | if stride != 1 or self.inplanes != planes * block.expansion: 330 | downsample = nn.Sequential( 331 | nn.Conv2d(self.inplanes, planes * block.expansion, 332 | kernel_size=downsample_kernel_size, stride=stride, 333 | padding=downsample_padding, bias=False), 334 | nn.BatchNorm2d(planes * block.expansion), 335 | ) 336 | 337 | layers = [] 338 | layers.append(block(self.inplanes, planes, groups, reduction, stride, 339 | downsample)) 340 | self.inplanes = planes * block.expansion 341 | for i in range(1, blocks): 342 | layers.append(block(self.inplanes, planes, groups, reduction)) 343 | 344 | return nn.Sequential(*layers) 345 | 346 | def load_param(self, model_path): 347 | param_dict = torch.load(model_path) 348 | for i in param_dict: 349 | if 'last_linear' in i: 350 | continue 351 | self.state_dict()[i].copy_(param_dict[i]) 352 | 353 | def forward(self, x): 354 | x = self.layer0(x) 355 | x = self.layer1(x) 356 | x = self.layer2(x) 357 | x = self.layer3(x) 358 | x = self.layer4(x) 359 | return x -------------------------------------------------------------------------------- /modeling/baseline.py: 
-------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import torch 8 | from torch import nn 9 | 10 | from .backbones.resnet import ResNet, BasicBlock, Bottleneck 11 | from .backbones.senet import SENet, SEResNetBottleneck, SEBottleneck, SEResNeXtBottleneck 12 | from .backbones.resnet_ibn_a import resnet50_ibn_a 13 | 14 | 15 | def weights_init_kaiming(m): 16 | classname = m.__class__.__name__ 17 | if classname.find('Linear') != -1: 18 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out') 19 | nn.init.constant_(m.bias, 0.0) 20 | elif classname.find('Conv') != -1: 21 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in') 22 | if m.bias is not None: 23 | nn.init.constant_(m.bias, 0.0) 24 | elif classname.find('BatchNorm') != -1: 25 | if m.affine: 26 | nn.init.constant_(m.weight, 1.0) 27 | nn.init.constant_(m.bias, 0.0) 28 | 29 | 30 | def weights_init_classifier(m): 31 | classname = m.__class__.__name__ 32 | if classname.find('Linear') != -1: 33 | nn.init.normal_(m.weight, std=0.001) 34 | if m.bias: 35 | nn.init.constant_(m.bias, 0.0) 36 | 37 | 38 | class Baseline(nn.Module): 39 | in_planes = 2048 40 | 41 | def __init__(self, num_classes, last_stride, model_path, neck, neck_feat, model_name, pretrain_choice): 42 | super(Baseline, self).__init__() 43 | self.feature_dim = 256 44 | if model_name == 'resnet18': 45 | self.in_planes = 512 46 | self.base = ResNet(last_stride=last_stride, 47 | block=BasicBlock, 48 | layers=[2, 2, 2, 2]) 49 | elif model_name == 'resnet34': 50 | self.in_planes = 512 51 | self.base = ResNet(last_stride=last_stride, 52 | block=BasicBlock, 53 | layers=[3, 4, 6, 3]) 54 | elif model_name == 'resnet50': 55 | self.base = ResNet(last_stride=last_stride, 56 | block=Bottleneck, 57 | layers=[3, 4, 6, 3]) 58 | elif model_name == 'resnet101': 59 | self.base = ResNet(last_stride=last_stride, 60 | block=Bottleneck, 61 | layers=[3, 4, 23, 3]) 62 | elif model_name == 'resnet152': 63 | self.base = ResNet(last_stride=last_stride, 64 | block=Bottleneck, 65 | layers=[3, 8, 36, 3]) 66 | 67 | elif model_name == 'se_resnet50': 68 | self.base = SENet(block=SEResNetBottleneck, 69 | layers=[3, 4, 6, 3], 70 | groups=1, 71 | reduction=16, 72 | dropout_p=None, 73 | inplanes=64, 74 | input_3x3=False, 75 | downsample_kernel_size=1, 76 | downsample_padding=0, 77 | last_stride=last_stride) 78 | elif model_name == 'se_resnet101': 79 | self.base = SENet(block=SEResNetBottleneck, 80 | layers=[3, 4, 23, 3], 81 | groups=1, 82 | reduction=16, 83 | dropout_p=None, 84 | inplanes=64, 85 | input_3x3=False, 86 | downsample_kernel_size=1, 87 | downsample_padding=0, 88 | last_stride=last_stride) 89 | elif model_name == 'se_resnet152': 90 | self.base = SENet(block=SEResNetBottleneck, 91 | layers=[3, 8, 36, 3], 92 | groups=1, 93 | reduction=16, 94 | dropout_p=None, 95 | inplanes=64, 96 | input_3x3=False, 97 | downsample_kernel_size=1, 98 | downsample_padding=0, 99 | last_stride=last_stride) 100 | elif model_name == 'se_resnext50': 101 | self.base = SENet(block=SEResNeXtBottleneck, 102 | layers=[3, 4, 6, 3], 103 | groups=32, 104 | reduction=16, 105 | dropout_p=None, 106 | inplanes=64, 107 | input_3x3=False, 108 | downsample_kernel_size=1, 109 | downsample_padding=0, 110 | last_stride=last_stride) 111 | elif model_name == 'se_resnext101': 112 | self.base = SENet(block=SEResNeXtBottleneck, 113 | layers=[3, 4, 23, 3], 114 | groups=32, 115 | reduction=16, 116 | dropout_p=None, 117 | 
inplanes=64, 118 | input_3x3=False, 119 | downsample_kernel_size=1, 120 | downsample_padding=0, 121 | last_stride=last_stride) 122 | elif model_name == 'senet154': 123 | self.base = SENet(block=SEBottleneck, 124 | layers=[3, 8, 36, 3], 125 | groups=64, 126 | reduction=16, 127 | dropout_p=0.2, 128 | last_stride=last_stride) 129 | elif model_name == 'resnet50_ibn_a': 130 | self.base = resnet50_ibn_a(last_stride) 131 | 132 | if pretrain_choice == 'imagenet': 133 | self.base.load_param(model_path) 134 | print('Loading pretrained ImageNet model......') 135 | 136 | self.gap = nn.AdaptiveAvgPool2d(1) 137 | # self.gap = nn.AdaptiveMaxPool2d(1) 138 | self.num_classes = num_classes 139 | self.neck = neck 140 | self.neck_feat = neck_feat 141 | self.projection = nn.Linear(self.in_planes, self.feature_dim) 142 | 143 | if self.neck == 'no': 144 | self.classifier = nn.Linear(self.feature_dim, self.num_classes) 145 | # self.classifier = nn.Linear(self.in_planes, self.num_classes, bias=False) # new add by luo 146 | # self.classifier.apply(weights_init_classifier) # new add by luo 147 | elif self.neck == 'bnneck': 148 | self.bottleneck = nn.BatchNorm1d(self.feature_dim) 149 | self.bottleneck.bias.requires_grad_(False) # no shift 150 | self.classifier = nn.Linear(self.feature_dim, self.num_classes, bias=False) 151 | 152 | self.bottleneck.apply(weights_init_kaiming) 153 | self.classifier.apply(weights_init_classifier) 154 | 155 | def forward(self, x): 156 | 157 | global_feat = self.gap(self.base(x)) # (b, 2048, 1, 1) 158 | global_feat = global_feat.view(global_feat.shape[0], -1) # flatten to (bs, 2048) 159 | global_feat = self.projection(global_feat) 160 | 161 | if self.neck == 'no': 162 | feat = global_feat 163 | elif self.neck == 'bnneck': 164 | feat = self.bottleneck(global_feat) # normalize for angular softmax 165 | 166 | if self.training: 167 | cls_score = self.classifier(feat) 168 | return cls_score, global_feat # global feature for triplet loss 169 | else: 170 | if self.neck_feat == 'after': 171 | # print("Test with feature after BN") 172 | return feat 173 | else: 174 | # print("Test with feature before BN") 175 | return global_feat 176 | 177 | def load_param(self, trained_path): 178 | param_dict = torch.load(trained_path).state_dict() 179 | for i in param_dict: 180 | if 'classifier' in i: 181 | continue 182 | self.state_dict()[i].copy_(param_dict[i]) 183 | -------------------------------------------------------------------------------- /modeling/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .baseline import Baseline 3 | from .nformer import NFormer 4 | import torch.nn as nn 5 | class nformer_model(nn.Module): 6 | def __init__(self, cfg, num_classes): 7 | super(nformer_model, self).__init__() 8 | self.backbone = Baseline(num_classes, cfg.MODEL.LAST_STRIDE, cfg.MODEL.PRETRAIN_PATH, cfg.MODEL.NECK, cfg.TEST.NECK_FEAT, cfg.MODEL.NAME, cfg.MODEL.PRETRAIN_CHOICE) 9 | self.nformer = NFormer(cfg, num_classes) 10 | 11 | def forward(self,x,stage = 'encoder'): 12 | if stage == 'encoder': 13 | if self.training: 14 | score, feat = self.backbone(x) 15 | return score, feat 16 | else: 17 | feat = self.backbone(x) 18 | return feat 19 | 20 | elif stage == 'nformer': 21 | feat = self.nformer(x) 22 | return feat 23 | -------------------------------------------------------------------------------- /modeling/nformer.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import json 3 | import math 4 | 
import re 5 | import collections 6 | 7 | import random 8 | import numpy as np 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from torch.nn.parameter import Parameter 13 | import random 14 | 15 | def weights_init_kaiming(m): 16 | classname = m.__class__.__name__ 17 | if classname.find('Linear') != -1: 18 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out') 19 | nn.init.constant_(m.bias, 0.0) 20 | elif classname.find('Conv') != -1: 21 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in') 22 | if m.bias is not None: 23 | nn.init.constant_(m.bias, 0.0) 24 | elif classname.find('BatchNorm') != -1: 25 | if m.affine: 26 | nn.init.constant_(m.weight, 1.0) 27 | nn.init.constant_(m.bias, 0.0) 28 | 29 | def weights_init_classifier(m): 30 | classname = m.__class__.__name__ 31 | if classname.find('Linear') != -1: 32 | nn.init.normal_(m.weight, std=0.001) 33 | if m.bias: 34 | nn.init.constant_(m.bias, 0.0) 35 | 36 | def gelu(x): 37 | return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) 38 | 39 | def swish(x): 40 | return x * torch.sigmoid(x) 41 | 42 | ACT_FNS = { 43 | 'relu': nn.ReLU, 44 | 'swish': swish, 45 | 'gelu': gelu 46 | } 47 | 48 | 49 | class LayerNorm(nn.Module): 50 | "Construct a layernorm module in the OpenAI style (epsilon inside the square root)." 51 | 52 | def __init__(self, n_state, e=1e-5): 53 | super(LayerNorm, self).__init__() 54 | self.g = nn.Parameter(torch.ones(n_state)) 55 | self.b = nn.Parameter(torch.zeros(n_state)) 56 | self.e = e 57 | 58 | def forward(self, x): 59 | u = x.mean(-1, keepdim=True) 60 | s = (x - u).pow(2).mean(-1, keepdim=True) 61 | x = (x - u) / torch.sqrt(s + self.e) 62 | return self.g * x + self.b 63 | 64 | 65 | class Conv1D(nn.Module): 66 | def __init__(self, nf, rf, nx): 67 | super(Conv1D, self).__init__() 68 | self.rf = rf 69 | self.nf = nf 70 | if rf == 1: # faster 1x1 conv 71 | w = torch.empty(nx, nf) 72 | nn.init.normal_(w, std=0.02) 73 | self.w = Parameter(w) 74 | self.b = Parameter(torch.zeros(nf)) 75 | else: # was used to train LM 76 | raise NotImplementedError 77 | 78 | def forward(self, x): 79 | if self.rf == 1: 80 | size_out = x.size()[:-1] + (self.nf,) 81 | x = torch.addmm(self.b, x.view(-1, x.size(-1)), self.w) 82 | x = x.view(*size_out) 83 | else: 84 | raise NotImplementedError 85 | return x 86 | 87 | 88 | class Attention(nn.Module): 89 | def __init__(self, nx, n_ctx, cfg, scale=False): 90 | super(Attention, self).__init__() 91 | n_state = nx 92 | assert n_state % cfg.MODEL.N_HEAD == 0 93 | self.n_head = cfg.MODEL.N_HEAD 94 | self.split_size = n_state 95 | self.scale = scale 96 | self.c_attn = Conv1D(n_state * 3, 1, nx) 97 | self.c_proj = Conv1D(n_state, 1, nx) 98 | 99 | self.resid_dropout = nn.Dropout(cfg.MODEL.RESID_PDROP) 100 | 101 | def _attn(self, q, k, v, num_landmark, rns_indices): 102 | data_length = q.shape[2] 103 | landmark = torch.Tensor(random.sample(range(data_length),num_landmark)).long() 104 | 105 | sq = q[:,:,landmark,:].contiguous() 106 | sk = k[:,:,:,landmark].contiguous() 107 | 108 | w1 = torch.matmul(q, sk) 109 | w2 = torch.matmul(sq, k) 110 | w = torch.matmul(w1, w2) 111 | 112 | if self.scale: 113 | w = w / math.sqrt(v.size(-1)) 114 | return self.rns(w, v, rns_indices) 115 | 116 | def rns(self, w, v, rns_indices): 117 | bs,hn,dl,_ = w.shape 118 | rns_indices = rns_indices.unsqueeze(1).repeat(1,hn,1,1) 119 | mask = torch.zeros_like(w).scatter_(3, rns_indices,torch.ones_like(rns_indices, dtype=w.dtype)) 120 | mask = mask * mask.transpose(2,3) 
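# Note: at this point `mask` is a 0/1 attention mask derived from `rns_indices`, i.e. each
# sample's top-k cosine neighbours computed in NFormer.forward: the scatter above marks the
# neighbours of every sample, and multiplying the mask by its own transpose keeps only mutual
# neighbour pairs. The branches below apply this mask to the attention logits -- densely with a
# -1e9 fill during training, and through a sparse softmax at inference to save memory.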
121 | if 'cuda' in str(w.device): 122 | mask = mask.cuda() 123 | else: 124 | mask = mask.cpu() 125 | if self.training: 126 | w = w * mask + -1e9 * (1 - mask) 127 | w = F.softmax(w,dim=3) 128 | a_v = torch.matmul(w, v) 129 | else: 130 | w = (w * mask).reshape(bs*hn,dl,dl).to_sparse() 131 | w = torch.sparse.softmax(w,2) 132 | v = v.reshape(bs*hn,dl,-1) 133 | a_v = torch.bmm(w,v).reshape(bs,hn,dl,-1) 134 | return a_v 135 | 136 | 137 | def merge_heads(self, x): 138 | x = x.permute(0, 2, 1, 3).contiguous() 139 | new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),) 140 | return x.view(*new_x_shape) 141 | 142 | def split_heads(self, x, k=False): 143 | new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head) 144 | x = x.view(*new_x_shape) 145 | if k: 146 | return x.permute(0, 2, 3, 1) 147 | else: 148 | return x.permute(0, 2, 1, 3) 149 | 150 | 151 | def forward(self, x, num_landmark, rns_indices): 152 | x = self.c_attn(x) 153 | query, key, value = x.split(self.split_size, dim=2) 154 | query = self.split_heads(query) 155 | key = self.split_heads(key, k=True) 156 | value = self.split_heads(value) 157 | mask = None 158 | a = self._attn(query, key, value, num_landmark, rns_indices) 159 | a = self.merge_heads(a) 160 | a = self.c_proj(a) 161 | a = self.resid_dropout(a) 162 | return a 163 | 164 | 165 | class MLP(nn.Module): 166 | def __init__(self, n_state, cfg): 167 | super(MLP, self).__init__() 168 | nx = cfg.MODEL.N_EMBD 169 | self.c_fc = Conv1D(n_state, 1, nx) 170 | self.c_proj = Conv1D(nx, 1, n_state) 171 | self.act = ACT_FNS[cfg.MODEL.AFN] 172 | self.dropout = nn.Dropout(cfg.MODEL.RESID_PDROP) 173 | 174 | def forward(self, x): 175 | h = self.act(self.c_fc(x)) 176 | h2 = self.c_proj(h) 177 | return self.dropout(h2) 178 | 179 | 180 | class Block(nn.Module): 181 | def __init__(self, n_ctx, cfg, scale=False): 182 | super(Block, self).__init__() 183 | nx = cfg.MODEL.N_EMBD 184 | self.attn = Attention(nx, n_ctx, cfg, scale) 185 | self.ln_1 = LayerNorm(nx) 186 | self.mlp = MLP(4 * nx, cfg) 187 | self.ln_2 = LayerNorm(nx) 188 | 189 | def forward(self, x, num_landmark, rns_indices): 190 | a = self.attn(x, num_landmark, rns_indices) 191 | n = self.ln_1(x + a) 192 | m = self.mlp(n) 193 | h = self.ln_2(n + m) 194 | return h 195 | 196 | 197 | class NFormer(nn.Module): 198 | """ NFormer model """ 199 | 200 | def __init__(self, cfg, vocab=40990, n_ctx=1024, num_classes = 751): 201 | super(NFormer, self).__init__() 202 | self.num_classes = num_classes 203 | 204 | block = Block(n_ctx, cfg, scale=True) 205 | self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(cfg.MODEL.N_LAYER)]) 206 | 207 | self.bottleneck = nn.BatchNorm1d(cfg.MODEL.N_EMBD) 208 | self.bottleneck.bias.requires_grad_(False) # no shift 209 | self.bottleneck.apply(weights_init_kaiming) 210 | 211 | self.classifier = nn.Linear(cfg.MODEL.N_EMBD, self.num_classes, bias=False) 212 | self.classifier.apply(weights_init_classifier) 213 | self.topk = cfg.MODEL.TOPK 214 | self.num_landmark = cfg.MODEL.LANDMARK 215 | 216 | def forward(self, x): 217 | _, rns_indices = torch.topk(torch.bmm(x/torch.norm(x,p=2,dim=2,keepdim=True),(x/torch.norm(x,p=2,dim=2,keepdim=True)).transpose(1,2)), self.topk, dim=2) 218 | for block in self.h: 219 | x = block(x, self.num_landmark, rns_indices) 220 | 221 | bs,dl,d = x.shape 222 | x = x.reshape(bs*dl,d) 223 | feat = self.bottleneck(x) 224 | cls_score = self.classifier(feat) 225 | x = x.reshape(bs,dl,d) 226 | feat = feat.reshape(bs,dl,d) 227 | cls_score = cls_score.reshape(bs,dl,-1) 228 | 229 | if 
self.training: 230 | return cls_score, x 231 | else: 232 | return feat 233 | 234 | 235 | 236 | -------------------------------------------------------------------------------- /pipeline.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/haochenheheda/NFormer/c78cb848c6b8cf64e973a1ee0ce14488d4904f8f/pipeline.jpg -------------------------------------------------------------------------------- /solver/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | from .build import make_optimizer, make_optimizer_with_center, make_nformer_optimizer_with_center 8 | from .lr_scheduler import WarmupMultiStepLR 9 | -------------------------------------------------------------------------------- /solver/build.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import torch 8 | 9 | 10 | def make_optimizer(cfg, model): 11 | params = [] 12 | for key, value in model.named_parameters(): 13 | if not value.requires_grad: 14 | continue 15 | lr = cfg.SOLVER.BASE_LR 16 | weight_decay = cfg.SOLVER.WEIGHT_DECAY 17 | if "bias" in key: 18 | lr = cfg.SOLVER.BASE_LR * cfg.SOLVER.BIAS_LR_FACTOR 19 | weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS 20 | params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}] 21 | if cfg.SOLVER.OPTIMIZER_NAME == 'SGD': 22 | optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(params, momentum=cfg.SOLVER.MOMENTUM) 23 | else: 24 | optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(params) 25 | return optimizer 26 | 27 | 28 | def make_optimizer_with_center(cfg, model, center_criterion): 29 | params = [] 30 | for key, value in model.named_parameters(): 31 | if not value.requires_grad: 32 | continue 33 | lr = cfg.SOLVER.BASE_LR 34 | weight_decay = cfg.SOLVER.WEIGHT_DECAY 35 | if "bias" in key: 36 | lr = cfg.SOLVER.BASE_LR * cfg.SOLVER.BIAS_LR_FACTOR 37 | weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS 38 | params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}] 39 | if cfg.SOLVER.OPTIMIZER_NAME == 'SGD': 40 | optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(params, momentum=cfg.SOLVER.MOMENTUM) 41 | else: 42 | optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(params) 43 | optimizer_center = torch.optim.SGD(center_criterion.parameters(), lr=cfg.SOLVER.CENTER_LR) 44 | return optimizer, optimizer_center 45 | 46 | 47 | def make_nformer_optimizer_with_center(cfg, model, center_criterion, nformer_center_criterion): 48 | params = [] 49 | for key, value in model.named_parameters(): 50 | if not value.requires_grad: 51 | continue 52 | lr = cfg.SOLVER.BASE_LR 53 | weight_decay = cfg.SOLVER.WEIGHT_DECAY 54 | if "bias" in key: 55 | lr = cfg.SOLVER.BASE_LR * cfg.SOLVER.BIAS_LR_FACTOR 56 | weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS 57 | params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}] 58 | if cfg.SOLVER.OPTIMIZER_NAME == 'SGD': 59 | optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(params, momentum=cfg.SOLVER.MOMENTUM) 60 | else: 61 | optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(params) 62 | nformer_optimizer = torch.optim.Adam(model.nformer.parameters(),lr = 1e-5,eps=1e-8, betas=[0.9,0.999]) 63 | optimizer_center = 
torch.optim.SGD(center_criterion.parameters(), lr=cfg.SOLVER.CENTER_LR) 64 | nformer_optimizer_center = torch.optim.SGD(nformer_center_criterion.parameters(), lr=cfg.SOLVER.CENTER_LR) 65 | return optimizer, optimizer_center, nformer_optimizer, nformer_optimizer_center 66 | -------------------------------------------------------------------------------- /solver/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | from bisect import bisect_right 7 | import torch 8 | 9 | 10 | # FIXME ideally this would be achieved with a CombinedLRScheduler, 11 | # separating MultiStepLR with WarmupLR 12 | # but the current LRScheduler design doesn't allow it 13 | 14 | class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler): 15 | def __init__( 16 | self, 17 | optimizer, 18 | milestones, 19 | gamma=0.1, 20 | warmup_factor=1.0 / 3, 21 | warmup_iters=500, 22 | warmup_method="linear", 23 | last_epoch=-1, 24 | ): 25 | if not list(milestones) == sorted(milestones): 26 | raise ValueError( 27 | "Milestones should be a list of" " increasing integers. Got {}", 28 | milestones, 29 | ) 30 | 31 | if warmup_method not in ("constant", "linear"): 32 | raise ValueError( 33 | "Only 'constant' or 'linear' warmup_method accepted" 34 | "got {}".format(warmup_method) 35 | ) 36 | self.milestones = milestones 37 | self.gamma = gamma 38 | self.warmup_factor = warmup_factor 39 | self.warmup_iters = warmup_iters 40 | self.warmup_method = warmup_method 41 | super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch) 42 | 43 | def get_lr(self): 44 | warmup_factor = 1 45 | if self.last_epoch < self.warmup_iters: 46 | if self.warmup_method == "constant": 47 | warmup_factor = self.warmup_factor 48 | elif self.warmup_method == "linear": 49 | alpha = self.last_epoch / self.warmup_iters 50 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha 51 | return [ 52 | base_lr 53 | * warmup_factor 54 | * self.gamma ** bisect_right(self.milestones, self.last_epoch) 55 | for base_lr in self.base_lrs 56 | ] 57 | -------------------------------------------------------------------------------- /tools/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | -------------------------------------------------------------------------------- /tools/nformer_train.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import argparse 8 | import os 9 | import sys 10 | import torch 11 | 12 | from torch.backends import cudnn 13 | 14 | sys.path.append('.') 15 | from config import cfg 16 | from data import make_data_loader 17 | from engine.trainer import do_train, do_train_with_center 18 | from modeling import build_model 19 | from layers import make_loss, make_loss_with_center 20 | from solver import make_optimizer, make_optimizer_with_center, WarmupMultiStepLR 21 | 22 | from utils.logger import setup_logger 23 | 24 | 25 | def train(cfg): 26 | # prepare dataset 27 | train_loader, val_loader, num_query, num_classes = make_data_loader(cfg) 28 | 29 | # prepare model 30 | model = build_model(cfg, num_classes) 31 | 32 | if cfg.MODEL.IF_WITH_CENTER == 'no': 33 | print('Train without center loss, the loss type is', 
cfg.MODEL.METRIC_LOSS_TYPE) 34 | optimizer = make_optimizer(cfg, model) 35 | # scheduler = WarmupMultiStepLR(optimizer, cfg.SOLVER.STEPS, cfg.SOLVER.GAMMA, cfg.SOLVER.WARMUP_FACTOR, 36 | # cfg.SOLVER.WARMUP_ITERS, cfg.SOLVER.WARMUP_METHOD) 37 | 38 | loss_func = make_loss(cfg, num_classes) # modified by gu 39 | 40 | # Add for using self trained model 41 | if cfg.MODEL.PRETRAIN_CHOICE == 'self': 42 | start_epoch = eval(cfg.MODEL.PRETRAIN_PATH.split('/')[-1].split('.')[0].split('_')[-1]) 43 | print('Start epoch:', start_epoch) 44 | path_to_optimizer = cfg.MODEL.PRETRAIN_PATH.replace('model', 'optimizer') 45 | print('Path to the checkpoint of optimizer:', path_to_optimizer) 46 | model.load_state_dict(torch.load(cfg.MODEL.PRETRAIN_PATH)) 47 | optimizer.load_state_dict(torch.load(path_to_optimizer)) 48 | scheduler = WarmupMultiStepLR(optimizer, cfg.SOLVER.STEPS, cfg.SOLVER.GAMMA, cfg.SOLVER.WARMUP_FACTOR, 49 | cfg.SOLVER.WARMUP_ITERS, cfg.SOLVER.WARMUP_METHOD, start_epoch) 50 | elif cfg.MODEL.PRETRAIN_CHOICE == 'imagenet': 51 | start_epoch = 0 52 | scheduler = WarmupMultiStepLR(optimizer, cfg.SOLVER.STEPS, cfg.SOLVER.GAMMA, cfg.SOLVER.WARMUP_FACTOR, 53 | cfg.SOLVER.WARMUP_ITERS, cfg.SOLVER.WARMUP_METHOD) 54 | else: 55 | print('Only support pretrain_choice for imagenet and self, but got {}'.format(cfg.MODEL.PRETRAIN_CHOICE)) 56 | 57 | arguments = {} 58 | 59 | do_train( 60 | cfg, 61 | model, 62 | train_loader, 63 | val_loader, 64 | optimizer, 65 | scheduler, # modify for using self trained model 66 | loss_func, 67 | num_query, 68 | start_epoch # add for using self trained model 69 | ) 70 | elif cfg.MODEL.IF_WITH_CENTER == 'yes': 71 | print('Train with center loss, the loss type is', cfg.MODEL.METRIC_LOSS_TYPE) 72 | loss_func, center_criterion = make_loss_with_center(cfg, num_classes) # modified by gu 73 | optimizer, optimizer_center = make_optimizer_with_center(cfg, model, center_criterion) 74 | # scheduler = WarmupMultiStepLR(optimizer, cfg.SOLVER.STEPS, cfg.SOLVER.GAMMA, cfg.SOLVER.WARMUP_FACTOR, 75 | # cfg.SOLVER.WARMUP_ITERS, cfg.SOLVER.WARMUP_METHOD) 76 | 77 | arguments = {} 78 | 79 | # Add for using self trained model 80 | if cfg.MODEL.PRETRAIN_CHOICE == 'self': 81 | start_epoch = eval(cfg.MODEL.PRETRAIN_PATH.split('/')[-1].split('.')[0].split('_')[-1]) 82 | print('Start epoch:', start_epoch) 83 | path_to_optimizer = cfg.MODEL.PRETRAIN_PATH.replace('model', 'optimizer') 84 | print('Path to the checkpoint of optimizer:', path_to_optimizer) 85 | path_to_center_param = cfg.MODEL.PRETRAIN_PATH.replace('model', 'center_param') 86 | print('Path to the checkpoint of center_param:', path_to_center_param) 87 | path_to_optimizer_center = cfg.MODEL.PRETRAIN_PATH.replace('model', 'optimizer_center') 88 | print('Path to the checkpoint of optimizer_center:', path_to_optimizer_center) 89 | model.load_state_dict(torch.load(cfg.MODEL.PRETRAIN_PATH)) 90 | optimizer.load_state_dict(torch.load(path_to_optimizer)) 91 | center_criterion.load_state_dict(torch.load(path_to_center_param)) 92 | optimizer_center.load_state_dict(torch.load(path_to_optimizer_center)) 93 | scheduler = WarmupMultiStepLR(optimizer, cfg.SOLVER.STEPS, cfg.SOLVER.GAMMA, cfg.SOLVER.WARMUP_FACTOR, 94 | cfg.SOLVER.WARMUP_ITERS, cfg.SOLVER.WARMUP_METHOD, start_epoch) 95 | elif cfg.MODEL.PRETRAIN_CHOICE == 'imagenet': 96 | start_epoch = 0 97 | scheduler = WarmupMultiStepLR(optimizer, cfg.SOLVER.STEPS, cfg.SOLVER.GAMMA, cfg.SOLVER.WARMUP_FACTOR, 98 | cfg.SOLVER.WARMUP_ITERS, cfg.SOLVER.WARMUP_METHOD) 99 | else: 100 | print('Only support 
pretrain_choice for imagenet and self, but got {}'.format(cfg.MODEL.PRETRAIN_CHOICE)) 101 | 102 | do_train_with_center( 103 | cfg, 104 | model, 105 | center_criterion, 106 | train_loader, 107 | val_loader, 108 | optimizer, 109 | optimizer_center, 110 | scheduler, # modify for using self trained model 111 | loss_func, 112 | num_query, 113 | start_epoch # add for using self trained model 114 | ) 115 | else: 116 | print("Unsupported value for cfg.MODEL.IF_WITH_CENTER {}, only support yes or no!\n".format(cfg.MODEL.IF_WITH_CENTER)) 117 | 118 | 119 | def main(): 120 | parser = argparse.ArgumentParser(description="ReID Baseline Training") 121 | parser.add_argument( 122 | "--config_file", default="", help="path to config file", type=str 123 | ) 124 | parser.add_argument("opts", help="Modify config options using the command-line", default=None, 125 | nargs=argparse.REMAINDER) 126 | 127 | args = parser.parse_args() 128 | 129 | num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1 130 | 131 | if args.config_file != "": 132 | cfg.merge_from_file(args.config_file) 133 | cfg.merge_from_list(args.opts) 134 | cfg.freeze() 135 | 136 | output_dir = cfg.OUTPUT_DIR 137 | if output_dir and not os.path.exists(output_dir): 138 | os.makedirs(output_dir) 139 | 140 | logger = setup_logger("reid_baseline", output_dir, 0) 141 | logger.info("Using {} GPUS".format(num_gpus)) 142 | logger.info(args) 143 | 144 | if args.config_file != "": 145 | logger.info("Loaded configuration file {}".format(args.config_file)) 146 | with open(args.config_file, 'r') as cf: 147 | config_str = "\n" + cf.read() 148 | logger.info(config_str) 149 | logger.info("Running with config:\n{}".format(cfg)) 150 | 151 | if cfg.MODEL.DEVICE == "cuda": 152 | os.environ['CUDA_VISIBLE_DEVICES'] = cfg.MODEL.DEVICE_ID # new add by gu 153 | cudnn.benchmark = True 154 | train(cfg) 155 | 156 | 157 | if __name__ == '__main__': 158 | main() 159 | -------------------------------------------------------------------------------- /tools/test.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import argparse 8 | import os 9 | import sys 10 | from os import mkdir 11 | 12 | import torch 13 | from torch.backends import cudnn 14 | 15 | sys.path.append('.') 16 | from config import cfg 17 | from data import make_data_loader 18 | from engine.inference import inference 19 | from modeling import build_nformer_model 20 | from utils.logger import setup_logger 21 | 22 | 23 | def main(): 24 | parser = argparse.ArgumentParser(description="ReID Baseline Inference") 25 | parser.add_argument( 26 | "--config_file", default="", help="path to config file", type=str 27 | ) 28 | parser.add_argument("opts", help="Modify config options using the command-line", default=None, 29 | nargs=argparse.REMAINDER) 30 | 31 | args = parser.parse_args() 32 | 33 | num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1 34 | 35 | if args.config_file != "": 36 | cfg.merge_from_file(args.config_file) 37 | cfg.merge_from_list(args.opts) 38 | cfg.freeze() 39 | 40 | output_dir = cfg.OUTPUT_DIR 41 | if output_dir and not os.path.exists(output_dir): 42 | mkdir(output_dir) 43 | 44 | logger = setup_logger("reid_baseline", output_dir, 0) 45 | logger.info("Using {} GPUS".format(num_gpus)) 46 | logger.info(args) 47 | 48 | if args.config_file != "": 49 | logger.info("Loaded configuration file {}".format(args.config_file)) 50 | with 
open(args.config_file, 'r') as cf: 51 | config_str = "\n" + cf.read() 52 | logger.info(config_str) 53 | logger.info("Running with config:\n{}".format(cfg)) 54 | 55 | if cfg.MODEL.DEVICE == "cuda": 56 | os.environ['CUDA_VISIBLE_DEVICES'] = cfg.MODEL.DEVICE_ID 57 | cudnn.benchmark = True 58 | 59 | train_loader, val_loader, num_query, num_classes = make_data_loader(cfg) 60 | model = build_nformer_model(cfg, num_classes) 61 | model.load_state_dict(torch.load(cfg.TEST.WEIGHT)) 62 | 63 | inference(cfg, model, val_loader, num_query) 64 | 65 | 66 | if __name__ == '__main__': 67 | main() 68 | -------------------------------------------------------------------------------- /tools/train.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import argparse 8 | import os 9 | import sys 10 | import torch 11 | 12 | from torch.backends import cudnn 13 | 14 | sys.path.append('.') 15 | from config import cfg 16 | from data import make_data_loader 17 | from engine.trainer import do_train, do_train_with_center 18 | from modeling import build_model, build_nformer_model 19 | from layers import make_loss, make_loss_with_center, make_nformer_loss_with_center 20 | from solver import make_optimizer, make_optimizer_with_center, make_nformer_optimizer_with_center, WarmupMultiStepLR 21 | 22 | from utils.logger import setup_logger 23 | 24 | 25 | def train(cfg): 26 | # prepare dataset 27 | train_loader, val_loader, num_query, num_classes = make_data_loader(cfg) 28 | 29 | # prepare model 30 | model = build_nformer_model(cfg, num_classes) 31 | 32 | if cfg.MODEL.IF_WITH_CENTER == 'no': 33 | print('Train without center loss, the loss type is', cfg.MODEL.METRIC_LOSS_TYPE) 34 | optimizer = make_optimizer(cfg, model) 35 | # scheduler = WarmupMultiStepLR(optimizer, cfg.SOLVER.STEPS, cfg.SOLVER.GAMMA, cfg.SOLVER.WARMUP_FACTOR, 36 | # cfg.SOLVER.WARMUP_ITERS, cfg.SOLVER.WARMUP_METHOD) 37 | 38 | loss_func = make_loss(cfg, num_classes) # modified by gu 39 | 40 | # Add for using self trained model 41 | if cfg.MODEL.PRETRAIN_CHOICE == 'self': 42 | start_epoch = eval(cfg.MODEL.PRETRAIN_PATH.split('/')[-1].split('.')[0].split('_')[-1]) 43 | print('Start epoch:', start_epoch) 44 | path_to_optimizer = cfg.MODEL.PRETRAIN_PATH.replace('model', 'optimizer') 45 | print('Path to the checkpoint of optimizer:', path_to_optimizer) 46 | model.load_state_dict(torch.load(cfg.MODEL.PRETRAIN_PATH)) 47 | optimizer.load_state_dict(torch.load(path_to_optimizer)) 48 | scheduler = WarmupMultiStepLR(optimizer, cfg.SOLVER.STEPS, cfg.SOLVER.GAMMA, cfg.SOLVER.WARMUP_FACTOR, 49 | cfg.SOLVER.WARMUP_ITERS, cfg.SOLVER.WARMUP_METHOD, start_epoch) 50 | elif cfg.MODEL.PRETRAIN_CHOICE == 'imagenet': 51 | start_epoch = 0 52 | scheduler = WarmupMultiStepLR(optimizer, cfg.SOLVER.STEPS, cfg.SOLVER.GAMMA, cfg.SOLVER.WARMUP_FACTOR, 53 | cfg.SOLVER.WARMUP_ITERS, cfg.SOLVER.WARMUP_METHOD) 54 | else: 55 | print('Only support pretrain_choice for imagenet and self, but got {}'.format(cfg.MODEL.PRETRAIN_CHOICE)) 56 | 57 | arguments = {} 58 | 59 | do_train( 60 | cfg, 61 | model, 62 | train_loader, 63 | val_loader, 64 | optimizer, 65 | scheduler, # modify for using self trained model 66 | loss_func, 67 | num_query, 68 | start_epoch # add for using self trained model 69 | ) 70 | elif cfg.MODEL.IF_WITH_CENTER == 'yes': 71 | print('Train with center loss, the loss type is', cfg.MODEL.METRIC_LOSS_TYPE) 72 | loss_func, center_criterion = 
make_loss_with_center(cfg, num_classes) # modified by gu 73 | nformer_loss_func, nformer_center_criterion = make_nformer_loss_with_center(cfg, num_classes) # modified by gu 74 | optimizer, optimizer_center, optimizer_nformer, optimizer_nformer_center = make_nformer_optimizer_with_center(cfg, model, center_criterion, nformer_center_criterion) 75 | # scheduler = WarmupMultiStepLR(optimizer, cfg.SOLVER.STEPS, cfg.SOLVER.GAMMA, cfg.SOLVER.WARMUP_FACTOR, 76 | # cfg.SOLVER.WARMUP_ITERS, cfg.SOLVER.WARMUP_METHOD) 77 | 78 | arguments = {} 79 | 80 | # Add for using self trained model 81 | if cfg.MODEL.PRETRAIN_CHOICE == 'self': 82 | start_epoch = eval(cfg.MODEL.PRETRAIN_PATH.split('/')[-1].split('.')[0].split('_')[-1]) 83 | print('Start epoch:', start_epoch) 84 | path_to_optimizer = cfg.MODEL.PRETRAIN_PATH.replace('model', 'optimizer') 85 | print('Path to the checkpoint of optimizer:', path_to_optimizer) 86 | path_to_center_param = cfg.MODEL.PRETRAIN_PATH.replace('model', 'center_param') 87 | print('Path to the checkpoint of center_param:', path_to_center_param) 88 | path_to_optimizer_center = cfg.MODEL.PRETRAIN_PATH.replace('model', 'optimizer_center') 89 | print('Path to the checkpoint of optimizer_center:', path_to_optimizer_center) 90 | model.load_state_dict(torch.load(cfg.MODEL.PRETRAIN_PATH)) 91 | optimizer.load_state_dict(torch.load(path_to_optimizer)) 92 | center_criterion.load_state_dict(torch.load(path_to_center_param)) 93 | optimizer_center.load_state_dict(torch.load(path_to_optimizer_center)) 94 | scheduler = WarmupMultiStepLR(optimizer, cfg.SOLVER.STEPS, cfg.SOLVER.GAMMA, cfg.SOLVER.WARMUP_FACTOR, 95 | cfg.SOLVER.WARMUP_ITERS, cfg.SOLVER.WARMUP_METHOD, start_epoch) 96 | elif cfg.MODEL.PRETRAIN_CHOICE == 'imagenet': 97 | start_epoch = 0 98 | scheduler = WarmupMultiStepLR(optimizer, cfg.SOLVER.STEPS, cfg.SOLVER.GAMMA, cfg.SOLVER.WARMUP_FACTOR, 99 | cfg.SOLVER.WARMUP_ITERS, cfg.SOLVER.WARMUP_METHOD) 100 | else: 101 | print('Only support pretrain_choice for imagenet and self, but got {}'.format(cfg.MODEL.PRETRAIN_CHOICE)) 102 | 103 | do_train_with_center( 104 | cfg, 105 | model, 106 | center_criterion, 107 | nformer_center_criterion, 108 | train_loader, 109 | val_loader, 110 | optimizer, 111 | optimizer_center, 112 | optimizer_nformer, 113 | optimizer_nformer_center, 114 | scheduler, # modify for using self trained model 115 | loss_func, 116 | nformer_loss_func, 117 | num_query, 118 | start_epoch # add for using self trained model 119 | ) 120 | else: 121 | print("Unsupported value for cfg.MODEL.IF_WITH_CENTER {}, only support yes or no!\n".format(cfg.MODEL.IF_WITH_CENTER)) 122 | 123 | 124 | def main(): 125 | parser = argparse.ArgumentParser(description="ReID Baseline Training") 126 | parser.add_argument( 127 | "--config_file", default="", help="path to config file", type=str 128 | ) 129 | parser.add_argument("opts", help="Modify config options using the command-line", default=None, 130 | nargs=argparse.REMAINDER) 131 | 132 | args = parser.parse_args() 133 | 134 | num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1 135 | 136 | if args.config_file != "": 137 | cfg.merge_from_file(args.config_file) 138 | cfg.merge_from_list(args.opts) 139 | cfg.freeze() 140 | 141 | output_dir = cfg.OUTPUT_DIR 142 | if output_dir and not os.path.exists(output_dir): 143 | os.makedirs(output_dir) 144 | 145 | logger = setup_logger("reid_baseline", output_dir, 0) 146 | logger.info("Using {} GPUS".format(num_gpus)) 147 | logger.info(args) 148 | 149 | if args.config_file != "": 150 | 
logger.info("Loaded configuration file {}".format(args.config_file)) 151 | with open(args.config_file, 'r') as cf: 152 | config_str = "\n" + cf.read() 153 | logger.info(config_str) 154 | logger.info("Running with config:\n{}".format(cfg)) 155 | 156 | if cfg.MODEL.DEVICE == "cuda": 157 | os.environ['CUDA_VISIBLE_DEVICES'] = cfg.MODEL.DEVICE_ID # new add by gu 158 | cudnn.benchmark = True 159 | train(cfg) 160 | 161 | 162 | if __name__ == '__main__': 163 | main() 164 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | -------------------------------------------------------------------------------- /utils/iotools.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import errno 8 | import json 9 | import os 10 | 11 | import os.path as osp 12 | 13 | 14 | def mkdir_if_missing(directory): 15 | if not osp.exists(directory): 16 | try: 17 | os.makedirs(directory) 18 | except OSError as e: 19 | if e.errno != errno.EEXIST: 20 | raise 21 | 22 | 23 | def check_isfile(path): 24 | isfile = osp.isfile(path) 25 | if not isfile: 26 | print("=> Warning: no file found at '{}' (ignored)".format(path)) 27 | return isfile 28 | 29 | 30 | def read_json(fpath): 31 | with open(fpath, 'r') as f: 32 | obj = json.load(f) 33 | return obj 34 | 35 | 36 | def write_json(obj, fpath): 37 | mkdir_if_missing(osp.dirname(fpath)) 38 | with open(fpath, 'w') as f: 39 | json.dump(obj, f, indent=4, separators=(',', ': ')) 40 | -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import logging 8 | import os 9 | import sys 10 | 11 | 12 | def setup_logger(name, save_dir, distributed_rank): 13 | logger = logging.getLogger(name) 14 | logger.setLevel(logging.DEBUG) 15 | # don't log results for the non-master process 16 | if distributed_rank > 0: 17 | return logger 18 | ch = logging.StreamHandler(stream=sys.stdout) 19 | ch.setLevel(logging.DEBUG) 20 | formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s") 21 | ch.setFormatter(formatter) 22 | logger.addHandler(ch) 23 | 24 | if save_dir: 25 | fh = logging.FileHandler(os.path.join(save_dir, "log.txt"), mode='w') 26 | fh.setLevel(logging.DEBUG) 27 | fh.setFormatter(formatter) 28 | logger.addHandler(fh) 29 | 30 | return logger 31 | -------------------------------------------------------------------------------- /utils/re_ranking.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Fri, 25 May 2018 20:29:09 5 | 6 | @author: luohao 7 | """ 8 | 9 | """ 10 | CVPR2017 paper:Zhong Z, Zheng L, Cao D, et al. Re-ranking Person Re-identification with k-reciprocal Encoding[J]. 2017. 
11 | url:http://openaccess.thecvf.com/content_cvpr_2017/papers/Zhong_Re-Ranking_Person_Re-Identification_CVPR_2017_paper.pdf 12 | Matlab version: https://github.com/zhunzhong07/person-re-ranking 13 | """ 14 | 15 | """ 16 | API 17 | 18 | probFea: all feature vectors of the query set (torch tensor) 19 | galFea: all feature vectors of the gallery set (torch tensor) 20 | k1,k2,lambda: parameters; the original paper uses (k1=20,k2=6,lambda=0.3) 21 | MemorySave: set to 'True' when using MemorySave mode 22 | Minibatch: available when 'MemorySave' is 'True' 23 | """ 24 | 25 | import numpy as np 26 | import torch 27 | 28 | 29 | def re_ranking(probFea, galFea, k1, k2, lambda_value, local_distmat=None, only_local=False): 30 | # if the feature vectors are numpy arrays, use 'torch.tensor' to convert them to tensors 31 | query_num = probFea.size(0) 32 | all_num = query_num + galFea.size(0) 33 | if only_local: 34 | original_dist = local_distmat 35 | else: 36 | feat = torch.cat([probFea,galFea]) 37 | print('using GPU to compute original distance') 38 | distmat = torch.pow(feat,2).sum(dim=1, keepdim=True).expand(all_num,all_num) + \ 39 | torch.pow(feat, 2).sum(dim=1, keepdim=True).expand(all_num, all_num).t() 40 | distmat.addmm_(1,-2,feat,feat.t()) 41 | original_dist = distmat.cpu().numpy() 42 | del feat 43 | if local_distmat is not None: 44 | original_dist = original_dist + local_distmat 45 | gallery_num = original_dist.shape[0] 46 | original_dist = np.transpose(original_dist / np.max(original_dist, axis=0)) 47 | V = np.zeros_like(original_dist).astype(np.float16) 48 | initial_rank = np.argsort(original_dist).astype(np.int32) 49 | 50 | print('starting re_ranking') 51 | for i in range(all_num): 52 | # k-reciprocal neighbors 53 | forward_k_neigh_index = initial_rank[i, :k1 + 1] 54 | backward_k_neigh_index = initial_rank[forward_k_neigh_index, :k1 + 1] 55 | fi = np.where(backward_k_neigh_index == i)[0] 56 | k_reciprocal_index = forward_k_neigh_index[fi] 57 | k_reciprocal_expansion_index = k_reciprocal_index 58 | for j in range(len(k_reciprocal_index)): 59 | candidate = k_reciprocal_index[j] 60 | candidate_forward_k_neigh_index = initial_rank[candidate, :int(np.around(k1 / 2)) + 1] 61 | candidate_backward_k_neigh_index = initial_rank[candidate_forward_k_neigh_index, 62 | :int(np.around(k1 / 2)) + 1] 63 | fi_candidate = np.where(candidate_backward_k_neigh_index == candidate)[0] 64 | candidate_k_reciprocal_index = candidate_forward_k_neigh_index[fi_candidate] 65 | if len(np.intersect1d(candidate_k_reciprocal_index, k_reciprocal_index)) > 2 / 3 * len( 66 | candidate_k_reciprocal_index): 67 | k_reciprocal_expansion_index = np.append(k_reciprocal_expansion_index, candidate_k_reciprocal_index) 68 | 69 | k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index) 70 | weight = np.exp(-original_dist[i, k_reciprocal_expansion_index]) 71 | V[i, k_reciprocal_expansion_index] = weight / np.sum(weight) 72 | original_dist = original_dist[:query_num, ] 73 | if k2 != 1: 74 | V_qe = np.zeros_like(V, dtype=np.float16) 75 | for i in range(all_num): 76 | V_qe[i, :] = np.mean(V[initial_rank[i, :k2], :], axis=0) 77 | V = V_qe 78 | del V_qe 79 | del initial_rank 80 | invIndex = [] 81 | for i in range(gallery_num): 82 | invIndex.append(np.where(V[:, i] != 0)[0]) 83 | 84 | jaccard_dist = np.zeros_like(original_dist, dtype=np.float16) 85 | 86 | for i in range(query_num): 87 | temp_min = np.zeros(shape=[1, gallery_num], dtype=np.float16) 88 | indNonZero = np.where(V[i, :] != 0)[0] 89 | indImages = [invIndex[ind] for ind in
indNonZero] 90 | for j in range(len(indNonZero)): 91 | temp_min[0, indImages[j]] = temp_min[0, indImages[j]] + np.minimum(V[i, indNonZero[j]], 92 | V[indImages[j], indNonZero[j]]) 93 | jaccard_dist[i] = 1 - temp_min / (2 - temp_min) 94 | 95 | final_dist = jaccard_dist * (1 - lambda_value) + original_dist * lambda_value 96 | del original_dist 97 | del V 98 | del jaccard_dist 99 | final_dist = final_dist[:query_num, query_num:] 100 | return final_dist 101 | 102 | -------------------------------------------------------------------------------- /utils/reid_metric.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import numpy as np 8 | import torch 9 | from ignite.metrics import Metric 10 | 11 | from data.datasets.eval_reid import eval_func 12 | from .re_ranking import re_ranking 13 | 14 | 15 | class R1_mAP(Metric): 16 | def __init__(self, num_query, max_rank=50, feat_norm='yes'): 17 | super(R1_mAP, self).__init__() 18 | self.num_query = num_query 19 | self.max_rank = max_rank 20 | self.feat_norm = feat_norm 21 | 22 | def reset(self): 23 | self.feats = [] 24 | self.pids = [] 25 | self.camids = [] 26 | 27 | def update(self, output): 28 | feat, pid, camid = output 29 | self.feats.append(feat) 30 | self.pids.extend(np.asarray(pid)) 31 | self.camids.extend(np.asarray(camid)) 32 | 33 | def compute(self): 34 | feats = torch.cat(self.feats, dim=0) 35 | if self.feat_norm == 'yes': 36 | print("The test feature is normalized") 37 | feats = torch.nn.functional.normalize(feats, dim=1, p=2) 38 | # query 39 | qf = feats[:self.num_query] 40 | q_pids = np.asarray(self.pids[:self.num_query]) 41 | q_camids = np.asarray(self.camids[:self.num_query]) 42 | # gallery 43 | gf = feats[self.num_query:] 44 | g_pids = np.asarray(self.pids[self.num_query:]) 45 | g_camids = np.asarray(self.camids[self.num_query:]) 46 | m, n = qf.shape[0], gf.shape[0] 47 | distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \ 48 | torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t() 49 | distmat.addmm_(1, -2, qf, gf.t()) 50 | distmat = distmat.cpu().numpy() 51 | cmc, mAP = eval_func(distmat, q_pids, g_pids, q_camids, g_camids) 52 | 53 | return cmc, mAP 54 | 55 | class NFormer_R1_mAP(Metric): 56 | def __init__(self, model, num_query, max_rank=50, feat_norm='yes'): 57 | super(NFormer_R1_mAP, self).__init__() 58 | self.model = model 59 | self.num_query = num_query 60 | self.max_rank = max_rank 61 | self.feat_norm = feat_norm 62 | 63 | def reset(self): 64 | self.feats = [] 65 | self.pids = [] 66 | self.camids = [] 67 | 68 | def update(self, output): 69 | feat, pid, camid = output 70 | self.feats.append(feat) 71 | self.pids.extend(np.asarray(pid)) 72 | self.camids.extend(np.asarray(camid)) 73 | 74 | def compute(self): 75 | feats = torch.cat(self.feats, dim=0) 76 | feats = torch.nn.functional.normalize(feats, dim=1, p=2) 77 | self.model.eval() 78 | with torch.no_grad(): 79 | feats = self.model(feats.unsqueeze(0), stage='nformer')[0] 80 | self.model.train() 81 | if self.feat_norm == 'yes': 82 | print("The test feature is normalized") 83 | feats = torch.nn.functional.normalize(feats, dim=1, p=2) 84 | # query 85 | qf = feats[:self.num_query] 86 | q_pids = np.asarray(self.pids[:self.num_query]) 87 | q_camids = np.asarray(self.camids[:self.num_query]) 88 | # gallery 89 | gf = feats[self.num_query:] 90 | g_pids = np.asarray(self.pids[self.num_query:]) 91 | g_camids = 
np.asarray(self.camids[self.num_query:]) 92 | m, n = qf.shape[0], gf.shape[0] 93 | distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \ 94 | torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t() 95 | distmat.addmm_(1, -2, qf, gf.t()) 96 | distmat = distmat.cpu().numpy() 97 | cmc, mAP = eval_func(distmat, q_pids, g_pids, q_camids, g_camids) 98 | 99 | return cmc, mAP 100 | 101 | 102 | class R1_mAP_reranking(Metric): 103 | def __init__(self, num_query, max_rank=50, feat_norm='yes'): 104 | super(R1_mAP_reranking, self).__init__() 105 | self.num_query = num_query 106 | self.max_rank = max_rank 107 | self.feat_norm = feat_norm 108 | 109 | def reset(self): 110 | self.feats = [] 111 | self.pids = [] 112 | self.camids = [] 113 | 114 | def update(self, output): 115 | feat, pid, camid = output 116 | self.feats.append(feat) 117 | self.pids.extend(np.asarray(pid)) 118 | self.camids.extend(np.asarray(camid)) 119 | 120 | def compute(self): 121 | feats = torch.cat(self.feats, dim=0) 122 | if self.feat_norm == 'yes': 123 | print("The test feature is normalized") 124 | feats = torch.nn.functional.normalize(feats, dim=1, p=2) 125 | 126 | # query 127 | qf = feats[:self.num_query] 128 | q_pids = np.asarray(self.pids[:self.num_query]) 129 | q_camids = np.asarray(self.camids[:self.num_query]) 130 | # gallery 131 | gf = feats[self.num_query:] 132 | g_pids = np.asarray(self.pids[self.num_query:]) 133 | g_camids = np.asarray(self.camids[self.num_query:]) 134 | # m, n = qf.shape[0], gf.shape[0] 135 | # distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \ 136 | # torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t() 137 | # distmat.addmm_(1, -2, qf, gf.t()) 138 | # distmat = distmat.cpu().numpy() 139 | print("Enter reranking") 140 | distmat = re_ranking(qf, gf, k1=20, k2=6, lambda_value=0.3) 141 | cmc, mAP = eval_func(distmat, q_pids, g_pids, q_camids, g_camids) 142 | 143 | return cmc, mAP 144 | --------------------------------------------------------------------------------
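For orientation, `NFormer_R1_mAP` above implements the two-stage evaluation path: extract per-image embeddings with the CNN encoder (`stage='encoder'`), L2-normalize and stack them, then refine the whole query-plus-gallery set jointly with one `stage='nformer'` pass before computing the distance matrix. Below is a minimal sketch of that flow outside the ignite engine; it reuses the setup calls from `tools/test.py`, assumes `cfg` has already been merged from a YAML config, assumes `val_loader` yields `(image, person id, camera id)` batches, and omits device placement.

```python
import sys
sys.path.append('.')                                  # run from the repo root, like the tools/ scripts

import torch
import torch.nn.functional as F

from config import cfg
from data import make_data_loader
from modeling import build_nformer_model

# Same setup as tools/test.py; cfg is assumed to be merged from a YAML config beforehand.
train_loader, val_loader, num_query, num_classes = make_data_loader(cfg)
model = build_nformer_model(cfg, num_classes)
model.load_state_dict(torch.load(cfg.TEST.WEIGHT))
model.eval()

feats = []
with torch.no_grad():
    for img, pid, camid in val_loader:                # assumed (image, person id, camera id) batches
        feats.append(model(img, stage='encoder'))     # per-image embeddings from the CNN encoder
    feats = F.normalize(torch.cat(feats, dim=0), dim=1, p=2)
    # NFormer consumes a (batch, set_size, dim) tensor and refines the whole set jointly;
    # the [0] indexing mirrors NFormer_R1_mAP.compute above.
    refined = model(feats.unsqueeze(0), stage='nformer')[0]

print(refined.shape)                                  # (num_images, cfg.MODEL.N_EMBD)
```

The refined features then take the place of the raw encoder features when building the query/gallery distance matrix, exactly as in `NFormer_R1_mAP.compute`.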