├── LICENSE ├── README.md ├── config ├── __init__.py └── defaults.py ├── configs ├── DukeMTMC │ ├── deit_transreid_stride.yml │ ├── vit_base.yml │ ├── vit_jpm.yml │ ├── vit_sie.yml │ ├── vit_transreid.yml │ ├── vit_transreid_384.yml │ ├── vit_transreid_stride.yml │ └── vit_transreid_stride_384.yml ├── MSMT17 │ ├── deit_small.yml │ ├── deit_transreid_stride.yml │ ├── vit_base.yml │ ├── vit_jpm.yml │ ├── vit_sie.yml │ ├── vit_small.yml │ ├── vit_transreid.yml │ ├── vit_transreid_384.yml │ ├── vit_transreid_stride.yml │ └── vit_transreid_stride_384.yml ├── Market │ ├── deit_transreid_stride.yml │ ├── vit_base.yml │ ├── vit_jpm.yml │ ├── vit_sie.yml │ ├── vit_transreid.yml │ ├── vit_transreid_384.yml │ ├── vit_transreid_stride.yml │ └── vit_transreid_stride_384.yml ├── OCC_Duke │ ├── deit_transreid_stride.yml │ ├── vit_base.yml │ ├── vit_jpm.yml │ ├── vit_sie.yml │ ├── vit_transreid.yml │ └── vit_transreid_stride.yml ├── VeRi │ ├── deit_transreid.yml │ ├── deit_transreid_stride.yml │ ├── vit_base.yml │ ├── vit_transreid.yml │ └── vit_transreid_stride.yml ├── VehicleID │ ├── deit_transreid.yml │ ├── deit_transreid_stride.yml │ ├── vit_base.yml │ ├── vit_transreid.yml │ └── vit_transreid_stride.yml └── transformer_base.yml ├── datasets ├── __init__.py ├── bases.py ├── cuhk03.py ├── dukemtmcreid.py ├── keypoint_test.txt ├── keypoint_train.txt ├── make_dataloader.py ├── market1501.py ├── msmt17.py ├── occ_duke.py ├── preprocessing.py ├── sampler.py ├── sampler_ddp.py ├── vehicleid.py └── veri.py ├── figs ├── exp.png └── framework.png ├── loss ├── __init__.py ├── arcface.py ├── center_loss.py ├── dissimilar_loss.py ├── make_loss.py ├── metric_learning.py ├── softmax_loss.py └── triplet_loss.py ├── model ├── __init__.py ├── backbones │ ├── __init__.py │ ├── resnet.py │ └── vit_pytorch.py └── make_model.py ├── processor ├── __init__.py └── processor.py ├── solver ├── __init__.py ├── cosine_lr.py ├── lr_scheduler.py ├── make_optimizer.py ├── scheduler.py └── 
scheduler_factory.py ├── test.py ├── train.py └── utils ├── __init__.py ├── iotools.py ├── logger.py ├── meter.py ├── metrics.py └── reranking.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Ant Group 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # [AAAI2023] DC-Former: Diverse and Compact Transformer for Person Re-Identification 3 | 4 | This is the repository for [DC-Former: Diverse and Compact Transformer for Person Re-Identification], which achieves state-of-the-art performance on 3 commonly used person re-ID benchmarks, including MSMT17, Market-1501 and CUHK03. 
5 | 6 | 7 | ## Overview 8 | In the person re-identification (re-ID) task, it is still challenging to learn discriminative representations with deep learning due to limited data. Generally speaking, a model performs better as the amount of data increases. The addition of similar classes strengthens the ability of the classifier to identify similar identities, thereby improving the discrimination of the representation. In this paper, we propose a Diverse and Compact Transformer (DC-Former) that can achieve a similar effect by splitting the embedding space into multiple diverse and compact subspaces. A compact embedding subspace helps the model learn more robust and discriminative embeddings to identify similar classes. Moreover, the fusion of these diverse embeddings, which contain more fine-grained information, can further improve the effect of re-ID. Specifically, multiple class tokens are used in the vision transformer to represent multiple embedding spaces. Then, a self-diverse constraint (SDC) is applied to these spaces to push them away from each other, which makes each embedding space diverse and compact. Further, a dynamic weight controller (DWC) is designed to balance the relative importance among them during training. The experimental results of our method are promising and surpass previous state-of-the-art methods on several commonly used person re-ID benchmarks. The code is publicly available in this repository. 9 | 10 | ![framework](figs/framework.png) 11 | 12 | ## Performance 13 | 14 | ![performance](figs/exp.png) 15 | 16 | 17 | ## Training 18 | 19 | We utilize 4 GPUs for training. 
20 | 21 | ```bash 22 | CUDA_VISIBLE_DEVICES=0,1,2,3 \ 23 | python -m torch.distributed.launch --nproc_per_node=4 --master_port 29500 train.py --config_file configs/MSMT17/vit_transreid_stride_384.yml \ 24 | MODEL.DIST_TRAIN True \ 25 | INPUT.GS_PROB 1.0 \ 26 | MODEL.JPM False \ 27 | MODEL.PRETRAIN_PATH $PRETRAIN_PATH \ 28 | MODEL.CLS_TOKEN_NUM 2 \ 29 | MODEL.CLS_TOKENS_LOSS True \ 30 | MODEL.DYNAMIC_BALANCER False \ 31 | DATASETS.ROOT_DIR $DATAROOT \ 32 | SOLVER.BASE_LR 0.032 \ 33 | SOLVER.IMS_PER_BATCH 256 \ 34 | SOLVER.MAX_EPOCHS 180 \ 35 | SOLVER.CHECKPOINT_PERIOD 180 \ 36 | SOLVER.EVAL_PERIOD 180 \ 37 | TEST.MEAN_FEAT False \ 38 | OUTPUT_DIR $OUTPUT_DIR 39 | ``` 40 | 41 | ## Evaluation 42 | 43 | ```bash 44 | python test.py --config_file configs/MSMT17/vit_transreid_stride_384.yml \ 45 | INPUT.GS_PROB 1.0 \ 46 | MODEL.DEVICE_ID "('0')" \ 47 | MODEL.JPM False \ 48 | MODEL.PRETRAIN_PATH $PRETRAIN_PATH \ 49 | MODEL.PRETRAIN_CHOICE self \ 50 | MODEL.CLS_TOKEN_NUM 2 \ 51 | MODEL.CLS_TOKENS_LOSS True \ 52 | DATASETS.ROOT_DIR $DATAROOT \ 53 | TEST.MEAN_FEAT False \ 54 | TEST.WEIGHT $CHECKPOINT_PATH \ 55 | OUTPUT_DIR $OUTPUT_DIR 56 | ``` 57 | 58 | ## Acknowledgement 59 | 60 | Codebase from [TransReID](https://github.com/damo-cv/TransReID) 61 | 62 | -------------------------------------------------------------------------------- /config/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | from .defaults import _C as cfg 4 | from .defaults import _C as cfg_test 5 | -------------------------------------------------------------------------------- /config/defaults.py: -------------------------------------------------------------------------------- 1 | from yacs.config import CfgNode as CN 2 | 3 | # ----------------------------------------------------------------------------- 4 | # Convention about Training / Test specific parameters 5 | # 
----------------------------------------------------------------------------- 6 | # Whenever an argument can be either used for training or for testing, the 7 | # corresponding name will be post-fixed by _TRAIN for a training parameter, or by _TEST for a test parameter (e.g. INPUT.SIZE_TRAIN / INPUT.SIZE_TEST). 8 | 9 | # ----------------------------------------------------------------------------- 10 | # Config definition 11 | # ----------------------------------------------------------------------------- 12 | 13 | _C = CN() 14 | # ----------------------------------------------------------------------------- 15 | # MODEL 16 | # ----------------------------------------------------------------------------- 17 | _C.MODEL = CN() 18 | # Using cuda or cpu for training 19 | _C.MODEL.DEVICE = "cuda" 20 | # ID number of GPU 21 | _C.MODEL.DEVICE_ID = '0' 22 | # Name of backbone 23 | _C.MODEL.NAME = 'resnet50' 24 | # Last stride of backbone 25 | _C.MODEL.LAST_STRIDE = 1 26 | # Path to pretrained model of backbone 27 | _C.MODEL.PRETRAIN_PATH = '' 28 | 29 | # Use ImageNet pretrained model to initialize backbone or use self trained model to initialize the whole model 30 | # Options: 'imagenet' , 'self' , 'finetune' 31 | _C.MODEL.PRETRAIN_CHOICE = 'imagenet' 32 | 33 | # If train with BNNeck, options: 'bnneck' or 'no' 34 | _C.MODEL.NECK = 'bnneck' 35 | # If train loss include center loss, options: 'yes' or 'no'. 
Loss with center loss has different optimizer configuration 36 | _C.MODEL.IF_WITH_CENTER = 'no' 37 | 38 | _C.MODEL.ID_LOSS_TYPE = 'softmax' 39 | _C.MODEL.ID_LOSS_WEIGHT = 1.0 40 | _C.MODEL.TRIPLET_LOSS_WEIGHT = 1.0 41 | _C.MODEL.CLS_TOKEN_NUM = 1 42 | _C.MODEL.CLS_TOKENS_LOSS = False 43 | _C.MODEL.DIVERSE_CLS_WEIGHT = 1.0 44 | _C.MODEL.REWEIGHT_CLS_TOKEN = False 45 | _C.MODEL.CLS_TOKEN_WEIGHT = 2 46 | _C.MODEL.DYNAMIC_BALANCER = False 47 | _C.MODEL.PART_ID_RATIO = 1.0 48 | 49 | _C.MODEL.METRIC_LOSS_TYPE = 'triplet' 50 | # If train with multi-gpu ddp mode, options: 'True', 'False' 51 | _C.MODEL.DIST_TRAIN = False 52 | # If train with soft triplet loss, options: 'True', 'False' 53 | _C.MODEL.NO_MARGIN = False 54 | # If train with label smooth, options: 'on', 'off' 55 | _C.MODEL.IF_LABELSMOOTH = 'on' 56 | # If train with temperature, options: 'on', 'off' 57 | _C.MODEL.IF_TEMPERATURE_SOFTMAX = 'off' 58 | # temperature 59 | _C.MODEL.TEMPERATURE = [1.0, 0.5] 60 | # If train with arcface loss, options: 'True', 'False' 61 | _C.MODEL.COS_LAYER = False 62 | 63 | # Transformer setting 64 | _C.MODEL.DROP_PATH = 0.1 65 | _C.MODEL.DROP_OUT = 0.0 66 | _C.MODEL.ATT_DROP_RATE = 0.0 67 | _C.MODEL.TRANSFORMER_TYPE = 'None' 68 | _C.MODEL.STRIDE_SIZE = [16, 16] 69 | 70 | # JPM Parameter 71 | _C.MODEL.JPM = False 72 | _C.MODEL.SHIFT_NUM = 5 73 | _C.MODEL.SHUFFLE_GROUP = 2 74 | _C.MODEL.DEVIDE_LENGTH = 4 75 | _C.MODEL.RE_ARRANGE = True 76 | 77 | # SIE Parameter 78 | _C.MODEL.SIE_COE = 3.0 79 | _C.MODEL.SIE_CAMERA = False 80 | _C.MODEL.SIE_VIEW = False 81 | 82 | # ID hard mining per batch 83 | _C.MODEL.ID_HARD_MINING = False 84 | # ----------------------------------------------------------------------------- 85 | # INPUT 86 | # ----------------------------------------------------------------------------- 87 | _C.INPUT = CN() 88 | # Size of the image during training 89 | _C.INPUT.SIZE_TRAIN = [384, 128] 90 | # Size of the image during test 91 | _C.INPUT.SIZE_TEST = [384, 128] 92 | # Random 
probability for image horizontal flip 93 | _C.INPUT.PROB = 0.5 94 | # Random probability for random erasing 95 | _C.INPUT.RE_PROB = 0.5 96 | # Random probability for grayscale patch replacement 97 | _C.INPUT.GS_PROB = 0.0 98 | # Values to be used for image normalization 99 | _C.INPUT.PIXEL_MEAN = [0.485, 0.456, 0.406] 100 | # Values to be used for image normalization 101 | _C.INPUT.PIXEL_STD = [0.229, 0.224, 0.225] 102 | # Value of padding size 103 | _C.INPUT.PADDING = 10 104 | 105 | # ----------------------------------------------------------------------------- 106 | # Dataset 107 | # ----------------------------------------------------------------------------- 108 | _C.DATASETS = CN() 109 | # Name of the dataset used for training (a plain string: the parentheses do not create a tuple) 110 | _C.DATASETS.NAMES = ('market1501') 111 | # Root directory where datasets should be used (and downloaded if not found) 112 | _C.DATASETS.ROOT_DIR = ('../data') 113 | 114 | 115 | # ----------------------------------------------------------------------------- 116 | # DataLoader 117 | # ----------------------------------------------------------------------------- 118 | _C.DATALOADER = CN() 119 | # Number of data loading threads 120 | _C.DATALOADER.NUM_WORKERS = 8 121 | # Sampler for data loading 122 | _C.DATALOADER.SAMPLER = 'softmax' 123 | # Number of instance for one batch 124 | _C.DATALOADER.NUM_INSTANCE = 16 125 | 126 | # ---------------------------------------------------------------------------- # 127 | # Solver 128 | # ---------------------------------------------------------------------------- # 129 | _C.SOLVER = CN() 130 | # Name of optimizer 131 | _C.SOLVER.OPTIMIZER_NAME = "Adam" 132 | # Maximum number of epochs 133 | _C.SOLVER.MAX_EPOCHS = 100 134 | # Base learning rate 135 | _C.SOLVER.BASE_LR = 3e-4 136 | # Whether using larger learning rate for fc layer 137 | _C.SOLVER.LARGE_FC_LR = False 138 | # Factor of learning bias 139 | _C.SOLVER.BIAS_LR_FACTOR = 1 140 | # Random seed (presumably used to seed the RNGs for reproducible runs; confirm in train.py) 141 | 
_C.SOLVER.SEED = 1234 142 | # Momentum 143 | _C.SOLVER.MOMENTUM = 0.9 144 | # Margin of triplet loss 145 | _C.SOLVER.MARGIN = 0.3 146 | # Learning rate of SGD to learn the centers of center loss 147 | _C.SOLVER.CENTER_LR = 0.5 148 | # Balanced weight of center loss 149 | _C.SOLVER.CENTER_LOSS_WEIGHT = 0.0005 150 | 151 | # Settings of weight decay 152 | _C.SOLVER.WEIGHT_DECAY = 0.0005 153 | _C.SOLVER.WEIGHT_DECAY_BIAS = 0.0005 154 | 155 | # decay rate of learning rate 156 | _C.SOLVER.GAMMA = 0.1 157 | # decay step of learning rate 158 | _C.SOLVER.STEPS = (40, 70) 159 | # warm up factor 160 | _C.SOLVER.WARMUP_FACTOR = 0.01 161 | # warm up epochs 162 | _C.SOLVER.WARMUP_EPOCHS = 5 163 | # method of warm up, option: 'constant','linear' 164 | _C.SOLVER.WARMUP_METHOD = "linear" 165 | 166 | _C.SOLVER.COSINE_MARGIN = 0.5 167 | _C.SOLVER.COSINE_SCALE = 30 168 | 169 | # epoch number of saving checkpoints 170 | _C.SOLVER.CHECKPOINT_PERIOD = 10 171 | # iteration of display training log 172 | _C.SOLVER.LOG_PERIOD = 100 173 | # epoch number of validation 174 | _C.SOLVER.EVAL_PERIOD = 10 175 | # Number of images per batch 176 | # This is global, so if we have 8 GPUs and IMS_PER_BATCH = 128, each GPU will 177 | # contain 16 images per batch 178 | _C.SOLVER.IMS_PER_BATCH = 64 179 | 180 | # ---------------------------------------------------------------------------- # 181 | # TEST 182 | # ---------------------------------------------------------------------------- # 183 | 184 | _C.TEST = CN() 185 | # Number of images per batch during test 186 | _C.TEST.IMS_PER_BATCH = 128 187 | # If test with re-ranking, options: 'True','False' 188 | _C.TEST.RE_RANKING = False 189 | # Path to trained model 190 | _C.TEST.WEIGHT = "" 191 | # Which feature of BNNeck to be used for test, before or after BNNeck, options: 'before' or 'after' 192 | _C.TEST.NECK_FEAT = 'after' 193 | # Whether feature is normalized before test, if yes, it is equivalent to cosine distance 194 | _C.TEST.FEAT_NORM = 'yes' 195 | 
196 | # Whether mean multi cls token, if False, concat 197 | _C.TEST.MEAN_FEAT = False 198 | 199 | # Name for saving the distmat after testing. 200 | _C.TEST.DIST_MAT = "dist_mat.npy" 201 | # Whether calculate the eval score option: 'True', 'False' 202 | _C.TEST.EVAL = False 203 | # ---------------------------------------------------------------------------- # 204 | # Misc options 205 | # ---------------------------------------------------------------------------- # 206 | # Path to checkpoint and saved log of trained model 207 | _C.OUTPUT_DIR = "" 208 | -------------------------------------------------------------------------------- /configs/DukeMTMC/deit_transreid_stride.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/xxx/.cache/torch/checkpoints/deit_base_distilled_patch16_224-df68dfff.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('2') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [12, 12] 12 | SIE_CAMERA: True 13 | SIE_COE: 3.0 14 | JPM: True 15 | RE_ARRANGE: True 16 | 17 | INPUT: 18 | SIZE_TRAIN: [256, 128] 19 | SIZE_TEST: [256, 128] 20 | PROB: 0.5 # random horizontal flip 21 | RE_PROB: 0.8 # random erasing 22 | PADDING: 10 23 | PIXEL_MEAN: [0.5, 0.5, 0.5] 24 | PIXEL_STD: [0.5, 0.5, 0.5] 25 | 26 | DATASETS: 27 | NAMES: ('dukemtmc') 28 | ROOT_DIR: ('../../data') 29 | 30 | DATALOADER: 31 | SAMPLER: 'softmax_triplet' 32 | NUM_INSTANCE: 4 33 | NUM_WORKERS: 8 34 | 35 | SOLVER: 36 | OPTIMIZER_NAME: 'SGD' 37 | MAX_EPOCHS: 120 38 | BASE_LR: 0.008 39 | IMS_PER_BATCH: 64 40 | WARMUP_METHOD: 'linear' 41 | LARGE_FC_LR: False 42 | CHECKPOINT_PERIOD: 120 43 | LOG_PERIOD: 50 44 | EVAL_PERIOD: 120 45 | WEIGHT_DECAY: 1e-4 46 | WEIGHT_DECAY_BIAS: 1e-4 47 | BIAS_LR_FACTOR: 2 48 | 49 | TEST: 50 | EVAL: True 51 | IMS_PER_BATCH: 256 52 | RE_RANKING: 
False 53 | WEIGHT: '../logs/dukemtmc_deit_transreid/transformer_120.pth' 54 | NECK_FEAT: 'before' 55 | FEAT_NORM: 'yes' 56 | 57 | OUTPUT_DIR: '../logs/dukemtmc_deit_transreid' 58 | 59 | 60 | -------------------------------------------------------------------------------- /configs/DukeMTMC/vit_base.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/xxx/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('6') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [16, 16] 12 | 13 | INPUT: 14 | SIZE_TRAIN: [256, 128] 15 | SIZE_TEST: [256, 128] 16 | PROB: 0.5 # random horizontal flip 17 | RE_PROB: 0.5 # random erasing 18 | PADDING: 10 19 | PIXEL_MEAN: [0.5, 0.5, 0.5] 20 | PIXEL_STD: [0.5, 0.5, 0.5] 21 | 22 | DATASETS: 23 | NAMES: ('dukemtmc') 24 | ROOT_DIR: ('../../data') 25 | 26 | DATALOADER: 27 | SAMPLER: 'softmax_triplet' 28 | NUM_INSTANCE: 4 29 | NUM_WORKERS: 8 30 | 31 | SOLVER: 32 | OPTIMIZER_NAME: 'SGD' 33 | MAX_EPOCHS: 120 34 | BASE_LR: 0.008 35 | IMS_PER_BATCH: 64 36 | WARMUP_METHOD: 'linear' 37 | LARGE_FC_LR: False 38 | CHECKPOINT_PERIOD: 120 39 | LOG_PERIOD: 50 40 | EVAL_PERIOD: 120 41 | WEIGHT_DECAY: 1e-4 42 | WEIGHT_DECAY_BIAS: 1e-4 43 | BIAS_LR_FACTOR: 2 44 | 45 | TEST: 46 | EVAL: True 47 | IMS_PER_BATCH: 256 48 | RE_RANKING: False 49 | WEIGHT: '' 50 | NECK_FEAT: 'before' 51 | FEAT_NORM: 'yes' 52 | 53 | OUTPUT_DIR: '../logs/duke_vit_base' 54 | 55 | 56 | -------------------------------------------------------------------------------- /configs/DukeMTMC/vit_jpm.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/xxx/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth' 4 | 
METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('1') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [16, 16] 12 | JPM: True 13 | RE_ARRANGE: True 14 | 15 | INPUT: 16 | SIZE_TRAIN: [256, 128] 17 | SIZE_TEST: [256, 128] 18 | PROB: 0.5 # random horizontal flip 19 | RE_PROB: 0.5 # random erasing 20 | PADDING: 10 21 | PIXEL_MEAN: [0.5, 0.5, 0.5] 22 | PIXEL_STD: [0.5, 0.5, 0.5] 23 | 24 | DATASETS: 25 | NAMES: ('dukemtmc') 26 | ROOT_DIR: ('../../data') 27 | 28 | DATALOADER: 29 | SAMPLER: 'softmax_triplet' 30 | NUM_INSTANCE: 4 31 | NUM_WORKERS: 8 32 | 33 | SOLVER: 34 | OPTIMIZER_NAME: 'SGD' 35 | MAX_EPOCHS: 120 36 | BASE_LR: 0.008 37 | IMS_PER_BATCH: 64 38 | WARMUP_METHOD: 'linear' 39 | LARGE_FC_LR: False 40 | CHECKPOINT_PERIOD: 120 41 | LOG_PERIOD: 50 42 | EVAL_PERIOD: 120 43 | WEIGHT_DECAY: 1e-4 44 | WEIGHT_DECAY_BIAS: 1e-4 45 | BIAS_LR_FACTOR: 2 46 | 47 | TEST: 48 | EVAL: True 49 | IMS_PER_BATCH: 256 50 | RE_RANKING: False 51 | WEIGHT: '../logs/duke_vit_jpm/transformer_120.pth' 52 | NECK_FEAT: 'before' 53 | FEAT_NORM: 'yes' 54 | 55 | OUTPUT_DIR: '../logs/duke_vit_jpm' 56 | 57 | 58 | -------------------------------------------------------------------------------- /configs/DukeMTMC/vit_sie.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/xxx/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('2') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [16, 16] 12 | SIE_CAMERA: True 13 | SIE_COE: 3.0 14 | 15 | INPUT: 16 | SIZE_TRAIN: [256, 128] 17 | SIZE_TEST: [256, 128] 18 | PROB: 0.5 # random horizontal flip 19 | RE_PROB: 0.5 # random erasing 20 | PADDING: 10 21 | PIXEL_MEAN: [0.5, 
0.5, 0.5] 22 | PIXEL_STD: [0.5, 0.5, 0.5] 23 | 24 | DATASETS: 25 | NAMES: ('dukemtmc') 26 | ROOT_DIR: ('../../data') 27 | 28 | DATALOADER: 29 | SAMPLER: 'softmax_triplet' 30 | NUM_INSTANCE: 4 31 | NUM_WORKERS: 8 32 | 33 | SOLVER: 34 | OPTIMIZER_NAME: 'SGD' 35 | MAX_EPOCHS: 120 36 | BASE_LR: 0.008 37 | IMS_PER_BATCH: 64 38 | WARMUP_METHOD: 'linear' 39 | LARGE_FC_LR: False 40 | CHECKPOINT_PERIOD: 120 41 | LOG_PERIOD: 50 42 | EVAL_PERIOD: 120 43 | WEIGHT_DECAY: 1e-4 44 | WEIGHT_DECAY_BIAS: 1e-4 45 | BIAS_LR_FACTOR: 2 46 | 47 | TEST: 48 | EVAL: True 49 | IMS_PER_BATCH: 256 50 | RE_RANKING: False 51 | WEIGHT: '../logs/duke_vit_sie/transformer_120.pth' 52 | NECK_FEAT: 'before' 53 | FEAT_NORM: 'yes' 54 | 55 | OUTPUT_DIR: '../logs/duke_vit_sie' 56 | 57 | 58 | -------------------------------------------------------------------------------- /configs/DukeMTMC/vit_transreid.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/xxx/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('3') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [16, 16] 12 | SIE_CAMERA: True 13 | SIE_COE: 3.0 14 | JPM: True 15 | RE_ARRANGE: True 16 | 17 | INPUT: 18 | SIZE_TRAIN: [256, 128] 19 | SIZE_TEST: [256, 128] 20 | PROB: 0.5 # random horizontal flip 21 | RE_PROB: 0.5 # random erasing 22 | PADDING: 10 23 | PIXEL_MEAN: [0.5, 0.5, 0.5] 24 | PIXEL_STD: [0.5, 0.5, 0.5] 25 | 26 | DATASETS: 27 | NAMES: ('dukemtmc') 28 | ROOT_DIR: ('../../data') 29 | 30 | DATALOADER: 31 | SAMPLER: 'softmax_triplet' 32 | NUM_INSTANCE: 4 33 | NUM_WORKERS: 8 34 | 35 | SOLVER: 36 | OPTIMIZER_NAME: 'SGD' 37 | MAX_EPOCHS: 120 38 | BASE_LR: 0.008 39 | IMS_PER_BATCH: 64 40 | WARMUP_METHOD: 'linear' 41 | LARGE_FC_LR: False 42 | CHECKPOINT_PERIOD: 120 43 | 
LOG_PERIOD: 50 44 | EVAL_PERIOD: 120 45 | WEIGHT_DECAY: 1e-4 46 | WEIGHT_DECAY_BIAS: 1e-4 47 | BIAS_LR_FACTOR: 2 48 | 49 | TEST: 50 | EVAL: True 51 | IMS_PER_BATCH: 256 52 | RE_RANKING: False 53 | WEIGHT: '../logs/duke_vit_transreid/transformer_120.pth' 54 | NECK_FEAT: 'before' 55 | FEAT_NORM: 'yes' 56 | 57 | OUTPUT_DIR: '../logs/duke_vit_transreid' 58 | 59 | 60 | -------------------------------------------------------------------------------- /configs/DukeMTMC/vit_transreid_384.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/xxx/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('3') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [16, 16] 12 | SIE_CAMERA: True 13 | SIE_COE: 3.0 14 | JPM: True 15 | RE_ARRANGE: True 16 | 17 | INPUT: 18 | SIZE_TRAIN: [384, 128] 19 | SIZE_TEST: [384, 128] 20 | PROB: 0.5 # random horizontal flip 21 | RE_PROB: 0.5 # random erasing 22 | PADDING: 10 23 | PIXEL_MEAN: [0.5, 0.5, 0.5] 24 | PIXEL_STD: [0.5, 0.5, 0.5] 25 | 26 | DATASETS: 27 | NAMES: ('dukemtmc') 28 | ROOT_DIR: ('../../data') 29 | 30 | DATALOADER: 31 | SAMPLER: 'softmax_triplet' 32 | NUM_INSTANCE: 4 33 | NUM_WORKERS: 8 34 | 35 | SOLVER: 36 | OPTIMIZER_NAME: 'SGD' 37 | MAX_EPOCHS: 120 38 | BASE_LR: 0.008 39 | IMS_PER_BATCH: 64 40 | WARMUP_METHOD: 'linear' 41 | LARGE_FC_LR: False 42 | CHECKPOINT_PERIOD: 120 43 | LOG_PERIOD: 50 44 | EVAL_PERIOD: 120 45 | WEIGHT_DECAY: 1e-4 46 | WEIGHT_DECAY_BIAS: 1e-4 47 | BIAS_LR_FACTOR: 2 48 | 49 | TEST: 50 | EVAL: True 51 | IMS_PER_BATCH: 256 52 | RE_RANKING: False 53 | WEIGHT: '../logs/duke_vit_transreid_384/transformer_120.pth' 54 | NECK_FEAT: 'before' 55 | FEAT_NORM: 'yes' 56 | 57 | OUTPUT_DIR: '../logs/duke_vit_transreid_384' 58 | 59 | 60 | 
-------------------------------------------------------------------------------- /configs/DukeMTMC/vit_transreid_stride.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/xxx/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('4') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [12, 12] 12 | SIE_CAMERA: True 13 | SIE_COE: 3.0 14 | JPM: True 15 | RE_ARRANGE: True 16 | 17 | INPUT: 18 | SIZE_TRAIN: [256, 128] 19 | SIZE_TEST: [256, 128] 20 | PROB: 0.5 # random horizontal flip 21 | RE_PROB: 0.5 # random erasing 22 | PADDING: 10 23 | PIXEL_MEAN: [0.5, 0.5, 0.5] 24 | PIXEL_STD: [0.5, 0.5, 0.5] 25 | 26 | DATASETS: 27 | NAMES: ('dukemtmc') 28 | ROOT_DIR: ('../../data') 29 | 30 | DATALOADER: 31 | SAMPLER: 'softmax_triplet' 32 | NUM_INSTANCE: 4 33 | NUM_WORKERS: 8 34 | 35 | SOLVER: 36 | OPTIMIZER_NAME: 'SGD' 37 | MAX_EPOCHS: 120 38 | BASE_LR: 0.008 39 | IMS_PER_BATCH: 64 40 | WARMUP_METHOD: 'linear' 41 | LARGE_FC_LR: False 42 | CHECKPOINT_PERIOD: 120 43 | LOG_PERIOD: 50 44 | EVAL_PERIOD: 120 45 | WEIGHT_DECAY: 1e-4 46 | WEIGHT_DECAY_BIAS: 1e-4 47 | BIAS_LR_FACTOR: 2 48 | 49 | TEST: 50 | EVAL: True 51 | IMS_PER_BATCH: 256 52 | RE_RANKING: False 53 | WEIGHT: '../logs/duke_vit_transreid_stride/transformer_120.pth' 54 | NECK_FEAT: 'before' 55 | FEAT_NORM: 'yes' 56 | 57 | OUTPUT_DIR: '../logs/duke_vit_transreid_stride' 58 | 59 | 60 | -------------------------------------------------------------------------------- /configs/DukeMTMC/vit_transreid_stride_384.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/xxx/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | 
IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('3') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [12, 12] 12 | SIE_CAMERA: True 13 | SIE_COE: 3.0 14 | JPM: True 15 | RE_ARRANGE: True 16 | 17 | INPUT: 18 | SIZE_TRAIN: [384, 128] 19 | SIZE_TEST: [384, 128] 20 | PROB: 0.5 # random horizontal flip 21 | RE_PROB: 0.5 # random erasing 22 | PADDING: 10 23 | PIXEL_MEAN: [0.5, 0.5, 0.5] 24 | PIXEL_STD: [0.5, 0.5, 0.5] 25 | 26 | DATASETS: 27 | NAMES: ('dukemtmc') 28 | ROOT_DIR: ('../../data') 29 | 30 | DATALOADER: 31 | SAMPLER: 'softmax_triplet' 32 | NUM_INSTANCE: 4 33 | NUM_WORKERS: 8 34 | 35 | SOLVER: 36 | OPTIMIZER_NAME: 'SGD' 37 | MAX_EPOCHS: 120 38 | BASE_LR: 0.008 39 | IMS_PER_BATCH: 64 40 | WARMUP_METHOD: 'linear' 41 | LARGE_FC_LR: False 42 | CHECKPOINT_PERIOD: 120 43 | LOG_PERIOD: 50 44 | EVAL_PERIOD: 120 45 | WEIGHT_DECAY: 1e-4 46 | WEIGHT_DECAY_BIAS: 1e-4 47 | BIAS_LR_FACTOR: 2 48 | 49 | TEST: 50 | EVAL: True 51 | IMS_PER_BATCH: 256 52 | RE_RANKING: False 53 | WEIGHT: '../logs/duke_vit_transreid_stride_384/transformer_120.pth' 54 | NECK_FEAT: 'before' 55 | FEAT_NORM: 'yes' 56 | 57 | OUTPUT_DIR: '../logs/duke_vit_transreid_stride_384' 58 | 59 | 60 | -------------------------------------------------------------------------------- /configs/MSMT17/deit_small.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/xxx/.cache/torch/checkpoints/deit_small_distilled_patch16_224-649709d9.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('1') 10 | TRANSFORMER_TYPE: 'deit_small_patch16_224_TransReID' 11 | STRIDE_SIZE: [16, 16] 12 | 13 | INPUT: 14 | SIZE_TRAIN: [256, 128] 15 | SIZE_TEST: [256, 128] 16 | PROB: 0.5 # random horizontal flip 17 | RE_PROB: 0.8 # random erasing 18 | PADDING: 10 19 | 
PIXEL_MEAN: [0.5, 0.5, 0.5] 20 | PIXEL_STD: [0.5, 0.5, 0.5] 21 | 22 | DATASETS: 23 | NAMES: ('msmt17') 24 | ROOT_DIR: ('../../data') 25 | 26 | DATALOADER: 27 | SAMPLER: 'softmax_triplet' 28 | NUM_INSTANCE: 4 29 | NUM_WORKERS: 8 30 | 31 | SOLVER: 32 | OPTIMIZER_NAME: 'SGD' 33 | MAX_EPOCHS: 120 34 | BASE_LR: 0.005 35 | IMS_PER_BATCH: 64 36 | LARGE_FC_LR: False 37 | CHECKPOINT_PERIOD: 120 38 | LOG_PERIOD: 50 39 | EVAL_PERIOD: 120 40 | WEIGHT_DECAY: 1e-4 41 | WEIGHT_DECAY_BIAS: 1e-4 42 | BIAS_LR_FACTOR: 2 43 | 44 | TEST: 45 | EVAL: True 46 | IMS_PER_BATCH: 256 47 | RE_RANKING: False 48 | WEIGHT: '' 49 | NECK_FEAT: 'before' 50 | FEAT_NORM: 'yes' 51 | 52 | OUTPUT_DIR: '../logs/msmt17_deit_small_try' 53 | 54 | 55 | -------------------------------------------------------------------------------- /configs/MSMT17/deit_transreid_stride.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/xxx/.cache/torch/checkpoints/deit_base_distilled_patch16_224-df68dfff.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('5') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [12, 12] 12 | SIE_CAMERA: True 13 | SIE_COE: 3.0 14 | JPM: True 15 | RE_ARRANGE: True 16 | 17 | INPUT: 18 | SIZE_TRAIN: [256, 128] 19 | SIZE_TEST: [256, 128] 20 | PROB: 0.5 # random horizontal flip 21 | RE_PROB: 0.8 # random erasing 22 | PADDING: 10 23 | PIXEL_MEAN: [0.5, 0.5, 0.5] 24 | PIXEL_STD: [0.5, 0.5, 0.5] 25 | 26 | DATASETS: 27 | NAMES: ('msmt17') 28 | ROOT_DIR: ('../../data') 29 | 30 | DATALOADER: 31 | SAMPLER: 'softmax_triplet' 32 | NUM_INSTANCE: 4 33 | NUM_WORKERS: 8 34 | 35 | SOLVER: 36 | OPTIMIZER_NAME: 'SGD' 37 | MAX_EPOCHS: 120 38 | BASE_LR: 0.005 39 | IMS_PER_BATCH: 64 40 | WARMUP_METHOD: 'linear' 41 | LARGE_FC_LR: False 42 | CHECKPOINT_PERIOD: 120 43 | LOG_PERIOD: 50 44 | EVAL_PERIOD: 
120 45 | WEIGHT_DECAY: 1e-4 46 | WEIGHT_DECAY_BIAS: 1e-4 47 | BIAS_LR_FACTOR: 2 48 | 49 | TEST: 50 | EVAL: True 51 | IMS_PER_BATCH: 256 52 | RE_RANKING: False 53 | WEIGHT: '' 54 | NECK_FEAT: 'before' 55 | FEAT_NORM: 'yes' 56 | 57 | OUTPUT_DIR: '../logs/msmt17_deit_transreid' 58 | 59 | 60 | -------------------------------------------------------------------------------- /configs/MSMT17/vit_base.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/xxx/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('4') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [16, 16] 12 | 13 | INPUT: 14 | SIZE_TRAIN: [256, 128] 15 | SIZE_TEST: [256, 128] 16 | PROB: 0.5 # random horizontal flip 17 | RE_PROB: 0.5 # random erasing 18 | PADDING: 10 19 | PIXEL_MEAN: [0.5, 0.5, 0.5] 20 | PIXEL_STD: [0.5, 0.5, 0.5] 21 | 22 | DATASETS: 23 | NAMES: ('msmt17') 24 | ROOT_DIR: ('../../data') 25 | 26 | DATALOADER: 27 | SAMPLER: 'softmax_triplet' 28 | NUM_INSTANCE: 4 29 | NUM_WORKERS: 8 30 | 31 | SOLVER: 32 | OPTIMIZER_NAME: 'SGD' 33 | MAX_EPOCHS: 120 34 | BASE_LR: 0.008 35 | IMS_PER_BATCH: 64 36 | WARMUP_METHOD: 'linear' 37 | LARGE_FC_LR: False 38 | CHECKPOINT_PERIOD: 120 39 | LOG_PERIOD: 50 40 | EVAL_PERIOD: 120 41 | WEIGHT_DECAY: 1e-4 42 | WEIGHT_DECAY_BIAS: 1e-4 43 | BIAS_LR_FACTOR: 2 44 | 45 | TEST: 46 | EVAL: True 47 | IMS_PER_BATCH: 256 48 | RE_RANKING: False 49 | WEIGHT: '' 50 | NECK_FEAT: 'before' 51 | FEAT_NORM: 'yes' 52 | 53 | OUTPUT_DIR: '../logs/msmt17_vit_base' 54 | 55 | 56 | -------------------------------------------------------------------------------- /configs/MSMT17/vit_jpm.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | 
PRETRAIN_PATH: '/home/xxx/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('1') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [16, 16] 12 | JPM: True 13 | RE_ARRANGE: True 14 | 15 | INPUT: 16 | SIZE_TRAIN: [256, 128] 17 | SIZE_TEST: [256, 128] 18 | PROB: 0.5 # random horizontal flip 19 | RE_PROB: 0.5 # random erasing 20 | PADDING: 10 21 | PIXEL_MEAN: [0.5, 0.5, 0.5] 22 | PIXEL_STD: [0.5, 0.5, 0.5] 23 | 24 | DATASETS: 25 | NAMES: ('msmt17') 26 | ROOT_DIR: ('../../data') 27 | 28 | DATALOADER: 29 | SAMPLER: 'softmax_triplet' 30 | NUM_INSTANCE: 4 31 | NUM_WORKERS: 8 32 | 33 | SOLVER: 34 | OPTIMIZER_NAME: 'SGD' 35 | MAX_EPOCHS: 120 36 | BASE_LR: 0.008 37 | IMS_PER_BATCH: 64 38 | WARMUP_METHOD: 'linear' 39 | LARGE_FC_LR: False 40 | CHECKPOINT_PERIOD: 120 41 | LOG_PERIOD: 50 42 | EVAL_PERIOD: 120 43 | WEIGHT_DECAY: 1e-4 44 | WEIGHT_DECAY_BIAS: 1e-4 45 | BIAS_LR_FACTOR: 2 46 | 47 | TEST: 48 | EVAL: True 49 | IMS_PER_BATCH: 256 50 | RE_RANKING: False 51 | WEIGHT: '' 52 | NECK_FEAT: 'before' 53 | FEAT_NORM: 'yes' 54 | 55 | OUTPUT_DIR: '../logs/msmt17_vit_jpm' 56 | 57 | 58 | -------------------------------------------------------------------------------- /configs/MSMT17/vit_sie.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/xxx/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('2') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [16, 16] 12 | SIE_CAMERA: True 13 | SIE_COE: 3.0 14 | 15 | INPUT: 16 | SIZE_TRAIN: [256, 128] 17 | SIZE_TEST: [256, 128] 18 | PROB: 0.5 # random horizontal flip 19 | RE_PROB: 0.5 # random 
erasing 20 | PADDING: 10 21 | PIXEL_MEAN: [0.5, 0.5, 0.5] 22 | PIXEL_STD: [0.5, 0.5, 0.5] 23 | 24 | DATASETS: 25 | NAMES: ('msmt17') 26 | ROOT_DIR: ('../../data') 27 | 28 | DATALOADER: 29 | SAMPLER: 'softmax_triplet' 30 | NUM_INSTANCE: 4 31 | NUM_WORKERS: 8 32 | 33 | SOLVER: 34 | OPTIMIZER_NAME: 'SGD' 35 | MAX_EPOCHS: 120 36 | BASE_LR: 0.008 37 | IMS_PER_BATCH: 64 38 | WARMUP_METHOD: 'linear' 39 | LARGE_FC_LR: False 40 | CHECKPOINT_PERIOD: 120 41 | LOG_PERIOD: 50 42 | EVAL_PERIOD: 120 43 | WEIGHT_DECAY: 1e-4 44 | WEIGHT_DECAY_BIAS: 1e-4 45 | BIAS_LR_FACTOR: 2 46 | 47 | TEST: 48 | EVAL: True 49 | IMS_PER_BATCH: 256 50 | RE_RANKING: False 51 | WEIGHT: '' 52 | NECK_FEAT: 'before' 53 | FEAT_NORM: 'yes' 54 | 55 | OUTPUT_DIR: '../logs/msmt17_vit_sie' 56 | 57 | 58 | -------------------------------------------------------------------------------- /configs/MSMT17/vit_small.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/xxx/.cache/torch/checkpoints/vit_small_p16_224-15ec54c9.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('0') 10 | TRANSFORMER_TYPE: 'vit_small_patch16_224_TransReID' 11 | STRIDE_SIZE: [16, 16] 12 | 13 | INPUT: 14 | SIZE_TRAIN: [256, 128] 15 | SIZE_TEST: [256, 128] 16 | PROB: 0.5 # random horizontal flip 17 | RE_PROB: 0.8 # random erasing 18 | PADDING: 10 19 | PIXEL_MEAN: [0.5, 0.5, 0.5] 20 | PIXEL_STD: [0.5, 0.5, 0.5] 21 | 22 | DATASETS: 23 | NAMES: ('msmt17') 24 | ROOT_DIR: ('../../data') 25 | 26 | DATALOADER: 27 | SAMPLER: 'softmax_triplet' 28 | NUM_INSTANCE: 4 29 | NUM_WORKERS: 8 30 | 31 | SOLVER: 32 | OPTIMIZER_NAME: 'SGD' 33 | MAX_EPOCHS: 120 34 | BASE_LR: 0.005 35 | IMS_PER_BATCH: 64 36 | WARMUP_METHOD: 'linear' 37 | LARGE_FC_LR: False 38 | CHECKPOINT_PERIOD: 120 39 | LOG_PERIOD: 50 40 | EVAL_PERIOD: 120 41 | WEIGHT_DECAY: 1e-4 42 | 
WEIGHT_DECAY_BIAS: 1e-4 43 | BIAS_LR_FACTOR: 2 44 | 45 | TEST: 46 | EVAL: True 47 | IMS_PER_BATCH: 256 48 | RE_RANKING: False 49 | WEIGHT: '' 50 | NECK_FEAT: 'before' 51 | FEAT_NORM: 'yes' 52 | 53 | OUTPUT_DIR: '../logs/msmt17_vit_small' 54 | 55 | 56 | -------------------------------------------------------------------------------- /configs/MSMT17/vit_transreid.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/xxx/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('3') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [16, 16] 12 | SIE_CAMERA: True 13 | SIE_COE: 3.0 14 | JPM: True 15 | RE_ARRANGE: True 16 | 17 | INPUT: 18 | SIZE_TRAIN: [256, 128] 19 | SIZE_TEST: [256, 128] 20 | PROB: 0.5 # random horizontal flip 21 | RE_PROB: 0.5 # random erasing 22 | PADDING: 10 23 | PIXEL_MEAN: [0.5, 0.5, 0.5] 24 | PIXEL_STD: [0.5, 0.5, 0.5] 25 | 26 | DATASETS: 27 | NAMES: ('msmt17') 28 | ROOT_DIR: ('../../data') 29 | 30 | DATALOADER: 31 | SAMPLER: 'softmax_triplet' 32 | NUM_INSTANCE: 4 33 | NUM_WORKERS: 8 34 | 35 | SOLVER: 36 | OPTIMIZER_NAME: 'SGD' 37 | MAX_EPOCHS: 120 38 | BASE_LR: 0.008 39 | IMS_PER_BATCH: 64 40 | WARMUP_METHOD: 'linear' 41 | LARGE_FC_LR: False 42 | CHECKPOINT_PERIOD: 120 43 | LOG_PERIOD: 50 44 | EVAL_PERIOD: 120 45 | WEIGHT_DECAY: 1e-4 46 | WEIGHT_DECAY_BIAS: 1e-4 47 | BIAS_LR_FACTOR: 2 48 | 49 | TEST: 50 | EVAL: True 51 | IMS_PER_BATCH: 256 52 | RE_RANKING: False 53 | WEIGHT: '' 54 | NECK_FEAT: 'before' 55 | FEAT_NORM: 'yes' 56 | 57 | OUTPUT_DIR: '../logs/msmt17_vit_transreid' 58 | 59 | 60 | -------------------------------------------------------------------------------- /configs/MSMT17/vit_transreid_384.yml: 
-------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/xxx/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('3') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [16, 16] 12 | SIE_CAMERA: True 13 | SIE_COE: 3.0 14 | JPM: True 15 | RE_ARRANGE: True 16 | 17 | INPUT: 18 | SIZE_TRAIN: [384, 128] 19 | SIZE_TEST: [384, 128] 20 | PROB: 0.5 # random horizontal flip 21 | RE_PROB: 0.5 # random erasing 22 | PADDING: 10 23 | PIXEL_MEAN: [0.5, 0.5, 0.5] 24 | PIXEL_STD: [0.5, 0.5, 0.5] 25 | 26 | DATASETS: 27 | NAMES: ('msmt17') 28 | ROOT_DIR: ('../../data') 29 | 30 | DATALOADER: 31 | SAMPLER: 'softmax_triplet' 32 | NUM_INSTANCE: 4 33 | NUM_WORKERS: 8 34 | 35 | SOLVER: 36 | OPTIMIZER_NAME: 'SGD' 37 | MAX_EPOCHS: 120 38 | BASE_LR: 0.008 39 | IMS_PER_BATCH: 64 40 | WARMUP_METHOD: 'linear' 41 | LARGE_FC_LR: False 42 | CHECKPOINT_PERIOD: 120 43 | LOG_PERIOD: 50 44 | EVAL_PERIOD: 120 45 | WEIGHT_DECAY: 1e-4 46 | WEIGHT_DECAY_BIAS: 1e-4 47 | BIAS_LR_FACTOR: 2 48 | 49 | TEST: 50 | EVAL: True 51 | IMS_PER_BATCH: 256 52 | RE_RANKING: False 53 | WEIGHT: '' 54 | NECK_FEAT: 'before' 55 | FEAT_NORM: 'yes' 56 | 57 | OUTPUT_DIR: '../logs/msmt17_vit_transreid_384' 58 | 59 | 60 | -------------------------------------------------------------------------------- /configs/MSMT17/vit_transreid_stride.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/xxx/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('6') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: 
[12, 12] 12 | SIE_CAMERA: True 13 | SIE_COE: 3.0 14 | JPM: True 15 | RE_ARRANGE: True 16 | 17 | INPUT: 18 | SIZE_TRAIN: [256, 128] 19 | SIZE_TEST: [256, 128] 20 | PROB: 0.5 # random horizontal flip 21 | RE_PROB: 0.5 # random erasing 22 | PADDING: 10 23 | PIXEL_MEAN: [0.5, 0.5, 0.5] 24 | PIXEL_STD: [0.5, 0.5, 0.5] 25 | 26 | DATASETS: 27 | NAMES: ('msmt17') 28 | ROOT_DIR: ('../../data') 29 | 30 | DATALOADER: 31 | SAMPLER: 'softmax_triplet' 32 | NUM_INSTANCE: 4 33 | NUM_WORKERS: 8 34 | 35 | SOLVER: 36 | OPTIMIZER_NAME: 'SGD' 37 | MAX_EPOCHS: 120 38 | BASE_LR: 0.008 39 | IMS_PER_BATCH: 64 40 | WARMUP_METHOD: 'linear' 41 | LARGE_FC_LR: False 42 | CHECKPOINT_PERIOD: 120 43 | LOG_PERIOD: 50 44 | EVAL_PERIOD: 120 45 | WEIGHT_DECAY: 1e-4 46 | WEIGHT_DECAY_BIAS: 1e-4 47 | BIAS_LR_FACTOR: 2 48 | 49 | TEST: 50 | EVAL: True 51 | IMS_PER_BATCH: 256 52 | RE_RANKING: False 53 | WEIGHT: '' 54 | NECK_FEAT: 'before' 55 | FEAT_NORM: 'yes' 56 | 57 | OUTPUT_DIR: '../logs/msmt17_vit_transreid_stride' 58 | 59 | 60 | -------------------------------------------------------------------------------- /configs/MSMT17/vit_transreid_stride_384.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/xxx/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('3') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [12, 12] 12 | SIE_CAMERA: True 13 | SIE_COE: 3.0 14 | JPM: True 15 | RE_ARRANGE: True 16 | 17 | INPUT: 18 | SIZE_TRAIN: [384, 128] 19 | SIZE_TEST: [384, 128] 20 | PROB: 0.5 # random horizontal flip 21 | RE_PROB: 0.5 # random erasing 22 | PADDING: 10 23 | PIXEL_MEAN: [0.5, 0.5, 0.5] 24 | PIXEL_STD: [0.5, 0.5, 0.5] 25 | 26 | DATASETS: 27 | NAMES: ('msmt17') 28 | ROOT_DIR: ('../../data') 29 | 30 | DATALOADER: 31 | SAMPLER: 
'softmax_triplet' 32 | NUM_INSTANCE: 4 33 | NUM_WORKERS: 8 34 | 35 | SOLVER: 36 | OPTIMIZER_NAME: 'SGD' 37 | MAX_EPOCHS: 120 38 | BASE_LR: 0.008 39 | IMS_PER_BATCH: 64 40 | WARMUP_METHOD: 'linear' 41 | LARGE_FC_LR: False 42 | CHECKPOINT_PERIOD: 120 43 | LOG_PERIOD: 50 44 | EVAL_PERIOD: 120 45 | WEIGHT_DECAY: 1e-4 46 | WEIGHT_DECAY_BIAS: 1e-4 47 | BIAS_LR_FACTOR: 2 48 | 49 | TEST: 50 | EVAL: True 51 | IMS_PER_BATCH: 256 52 | RE_RANKING: False 53 | WEIGHT: '' 54 | NECK_FEAT: 'before' 55 | FEAT_NORM: 'yes' 56 | 57 | OUTPUT_DIR: '../logs/msmt17_vit_transreid_stride_384' 58 | 59 | 60 | -------------------------------------------------------------------------------- /configs/Market/deit_transreid_stride.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/xxx/.cache/torch/checkpoints/deit_base_distilled_patch16_224-df68dfff.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('4') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [12, 12] 12 | SIE_CAMERA: True 13 | SIE_COE: 3.0 14 | JPM: True 15 | RE_ARRANGE: True 16 | 17 | INPUT: 18 | SIZE_TRAIN: [256, 128] 19 | SIZE_TEST: [256, 128] 20 | PROB: 0.5 # random horizontal flip 21 | RE_PROB: 0.8 # random erasing 22 | PADDING: 10 23 | PIXEL_MEAN: [0.5, 0.5, 0.5] 24 | PIXEL_STD: [0.5, 0.5, 0.5] 25 | 26 | DATASETS: 27 | NAMES: ('market1501') 28 | ROOT_DIR: ('../../data') 29 | 30 | DATALOADER: 31 | SAMPLER: 'softmax_triplet' 32 | NUM_INSTANCE: 4 33 | NUM_WORKERS: 8 34 | 35 | SOLVER: 36 | OPTIMIZER_NAME: 'SGD' 37 | MAX_EPOCHS: 120 38 | BASE_LR: 0.008 39 | IMS_PER_BATCH: 64 40 | WARMUP_METHOD: 'linear' 41 | LARGE_FC_LR: False 42 | CHECKPOINT_PERIOD: 120 43 | LOG_PERIOD: 50 44 | EVAL_PERIOD: 120 45 | WEIGHT_DECAY: 1e-4 46 | WEIGHT_DECAY_BIAS: 1e-4 47 | BIAS_LR_FACTOR: 2 48 | 49 | TEST: 50 | EVAL: True 51 | 
IMS_PER_BATCH: 256 52 | RE_RANKING: False 53 | WEIGHT: '../logs/0321_market_deit_transreie/transformer_120.pth' 54 | NECK_FEAT: 'before' 55 | FEAT_NORM: 'yes' 56 | 57 | OUTPUT_DIR: '../logs/0321_market_deit_transreie' 58 | 59 | 60 | -------------------------------------------------------------------------------- /configs/Market/vit_base.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/xxx/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('7') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [16, 16] 12 | 13 | INPUT: 14 | SIZE_TRAIN: [256, 128] 15 | SIZE_TEST: [256, 128] 16 | PROB: 0.5 # random horizontal flip 17 | RE_PROB: 0.5 # random erasing 18 | PADDING: 10 19 | PIXEL_MEAN: [0.5, 0.5, 0.5] 20 | PIXEL_STD: [0.5, 0.5, 0.5] 21 | 22 | DATASETS: 23 | NAMES: ('market1501') 24 | ROOT_DIR: ('../../data') 25 | 26 | DATALOADER: 27 | SAMPLER: 'softmax_triplet' 28 | NUM_INSTANCE: 4 29 | NUM_WORKERS: 8 30 | 31 | SOLVER: 32 | OPTIMIZER_NAME: 'SGD' 33 | MAX_EPOCHS: 120 34 | BASE_LR: 0.008 35 | IMS_PER_BATCH: 64 36 | WARMUP_METHOD: 'linear' 37 | LARGE_FC_LR: False 38 | CHECKPOINT_PERIOD: 120 39 | LOG_PERIOD: 50 40 | EVAL_PERIOD: 120 41 | WEIGHT_DECAY: 1e-4 42 | WEIGHT_DECAY_BIAS: 1e-4 43 | BIAS_LR_FACTOR: 2 44 | 45 | TEST: 46 | EVAL: True 47 | IMS_PER_BATCH: 256 48 | RE_RANKING: False 49 | WEIGHT: '../logs/0321_market_vit_base/transformer_120.pth' 50 | NECK_FEAT: 'before' 51 | FEAT_NORM: 'yes' 52 | 53 | OUTPUT_DIR: '../logs/0321_market_vit_base' 54 | 55 | 56 | -------------------------------------------------------------------------------- /configs/Market/vit_jpm.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | 
PRETRAIN_PATH: '/home/xxx/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('1') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [16, 16] 12 | JPM: True 13 | RE_ARRANGE: True 14 | 15 | INPUT: 16 | SIZE_TRAIN: [256, 128] 17 | SIZE_TEST: [256, 128] 18 | PROB: 0.5 # random horizontal flip 19 | RE_PROB: 0.5 # random erasing 20 | PADDING: 10 21 | PIXEL_MEAN: [0.5, 0.5, 0.5] 22 | PIXEL_STD: [0.5, 0.5, 0.5] 23 | 24 | DATASETS: 25 | NAMES: ('market1501') 26 | ROOT_DIR: ('../../data') 27 | 28 | DATALOADER: 29 | SAMPLER: 'softmax_triplet' 30 | NUM_INSTANCE: 4 31 | NUM_WORKERS: 8 32 | 33 | SOLVER: 34 | OPTIMIZER_NAME: 'SGD' 35 | MAX_EPOCHS: 120 36 | BASE_LR: 0.008 37 | IMS_PER_BATCH: 64 38 | WARMUP_METHOD: 'linear' 39 | LARGE_FC_LR: False 40 | CHECKPOINT_PERIOD: 120 41 | LOG_PERIOD: 50 42 | EVAL_PERIOD: 120 43 | WEIGHT_DECAY: 1e-4 44 | WEIGHT_DECAY_BIAS: 1e-4 45 | BIAS_LR_FACTOR: 2 46 | 47 | TEST: 48 | EVAL: True 49 | IMS_PER_BATCH: 256 50 | RE_RANKING: False 51 | WEIGHT: '../logs/0321_market_vit_jpm/transformer_120.pth' 52 | NECK_FEAT: 'before' 53 | FEAT_NORM: 'yes' 54 | 55 | OUTPUT_DIR: '../logs/0321_market_vit_jpm' 56 | 57 | 58 | -------------------------------------------------------------------------------- /configs/Market/vit_sie.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/xxx/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('7') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [16, 16] 12 | SIE_CAMERA: True 13 | SIE_COE: 3.0 14 | 15 | INPUT: 16 | SIZE_TRAIN: [256, 128] 17 | SIZE_TEST: [256, 128] 18 | PROB: 0.5 
# random horizontal flip 19 | RE_PROB: 0.5 # random erasing 20 | PADDING: 10 21 | PIXEL_MEAN: [0.5, 0.5, 0.5] 22 | PIXEL_STD: [0.5, 0.5, 0.5] 23 | 24 | DATASETS: 25 | NAMES: ('market1501') 26 | ROOT_DIR: ('../../data') 27 | 28 | DATALOADER: 29 | SAMPLER: 'softmax_triplet' 30 | NUM_INSTANCE: 4 31 | NUM_WORKERS: 8 32 | 33 | SOLVER: 34 | OPTIMIZER_NAME: 'SGD' 35 | MAX_EPOCHS: 120 36 | BASE_LR: 0.008 37 | IMS_PER_BATCH: 64 38 | WARMUP_METHOD: 'linear' 39 | LARGE_FC_LR: False 40 | CHECKPOINT_PERIOD: 120 41 | LOG_PERIOD: 50 42 | EVAL_PERIOD: 120 43 | WEIGHT_DECAY: 1e-4 44 | WEIGHT_DECAY_BIAS: 1e-4 45 | BIAS_LR_FACTOR: 2 46 | 47 | TEST: 48 | EVAL: True 49 | IMS_PER_BATCH: 256 50 | RE_RANKING: False 51 | WEIGHT: '' 52 | NECK_FEAT: 'before' 53 | FEAT_NORM: 'yes' 54 | 55 | OUTPUT_DIR: '../logs/market_vit_sie' 56 | 57 | 58 | -------------------------------------------------------------------------------- /configs/Market/vit_transreid.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/xxx/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('5') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [16, 16] 12 | SIE_CAMERA: True 13 | SIE_COE: 3.0 14 | JPM: True 15 | RE_ARRANGE: True 16 | 17 | INPUT: 18 | SIZE_TRAIN: [256, 128] 19 | SIZE_TEST: [256, 128] 20 | PROB: 0.5 # random horizontal flip 21 | RE_PROB: 0.5 # random erasing 22 | PADDING: 10 23 | PIXEL_MEAN: [0.5, 0.5, 0.5] 24 | PIXEL_STD: [0.5, 0.5, 0.5] 25 | 26 | DATASETS: 27 | NAMES: ('market1501') 28 | ROOT_DIR: ('../../data') 29 | 30 | DATALOADER: 31 | SAMPLER: 'softmax_triplet' 32 | NUM_INSTANCE: 4 33 | NUM_WORKERS: 8 34 | 35 | SOLVER: 36 | OPTIMIZER_NAME: 'SGD' 37 | MAX_EPOCHS: 120 38 | BASE_LR: 0.008 39 | IMS_PER_BATCH: 64 40 | WARMUP_METHOD: 
'linear' 41 | LARGE_FC_LR: False 42 | CHECKPOINT_PERIOD: 120 43 | LOG_PERIOD: 50 44 | EVAL_PERIOD: 120 45 | WEIGHT_DECAY: 1e-4 46 | WEIGHT_DECAY_BIAS: 1e-4 47 | BIAS_LR_FACTOR: 2 48 | 49 | TEST: 50 | EVAL: True 51 | IMS_PER_BATCH: 256 52 | RE_RANKING: False 53 | WEIGHT: '../logs/market_vit_transreid/transformer_120.pth' 54 | NECK_FEAT: 'before' 55 | FEAT_NORM: 'yes' 56 | 57 | OUTPUT_DIR: '../logs/market_vit_transreid' 58 | 59 | 60 | -------------------------------------------------------------------------------- /configs/Market/vit_transreid_384.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/xxx/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('5') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [16, 16] 12 | SIE_CAMERA: True 13 | SIE_COE: 3.0 14 | JPM: True 15 | RE_ARRANGE: True 16 | 17 | INPUT: 18 | SIZE_TRAIN: [384, 128] 19 | SIZE_TEST: [384, 128] 20 | PROB: 0.5 # random horizontal flip 21 | RE_PROB: 0.5 # random erasing 22 | PADDING: 10 23 | PIXEL_MEAN: [0.5, 0.5, 0.5] 24 | PIXEL_STD: [0.5, 0.5, 0.5] 25 | 26 | DATASETS: 27 | NAMES: ('market1501') 28 | ROOT_DIR: ('../../data') 29 | 30 | DATALOADER: 31 | SAMPLER: 'softmax_triplet' 32 | NUM_INSTANCE: 4 33 | NUM_WORKERS: 8 34 | 35 | SOLVER: 36 | OPTIMIZER_NAME: 'SGD' 37 | MAX_EPOCHS: 120 38 | BASE_LR: 0.008 39 | IMS_PER_BATCH: 64 40 | WARMUP_METHOD: 'linear' 41 | LARGE_FC_LR: False 42 | CHECKPOINT_PERIOD: 120 43 | LOG_PERIOD: 50 44 | EVAL_PERIOD: 120 45 | WEIGHT_DECAY: 1e-4 46 | WEIGHT_DECAY_BIAS: 1e-4 47 | BIAS_LR_FACTOR: 2 48 | 49 | TEST: 50 | EVAL: True 51 | IMS_PER_BATCH: 256 52 | RE_RANKING: False 53 | WEIGHT: '../logs/market_vit_transreid_384/transformer_120.pth' 54 | NECK_FEAT: 'before' 55 | FEAT_NORM: 'yes' 56 | 57 | 
OUTPUT_DIR: '../logs/market_vit_transreid_384' 58 | 59 | 60 | -------------------------------------------------------------------------------- /configs/Market/vit_transreid_stride.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/xxx/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('5') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [12, 12] 12 | SIE_CAMERA: True 13 | SIE_COE: 3.0 14 | JPM: True 15 | RE_ARRANGE: True 16 | 17 | INPUT: 18 | SIZE_TRAIN: [256, 128] 19 | SIZE_TEST: [256, 128] 20 | PROB: 0.5 # random horizontal flip 21 | RE_PROB: 0.5 # random erasing 22 | PADDING: 10 23 | PIXEL_MEAN: [0.5, 0.5, 0.5] 24 | PIXEL_STD: [0.5, 0.5, 0.5] 25 | 26 | DATASETS: 27 | NAMES: ('market1501') 28 | ROOT_DIR: ('../../data') 29 | 30 | DATALOADER: 31 | SAMPLER: 'softmax_triplet' 32 | NUM_INSTANCE: 4 33 | NUM_WORKERS: 8 34 | 35 | SOLVER: 36 | OPTIMIZER_NAME: 'SGD' 37 | MAX_EPOCHS: 120 38 | BASE_LR: 0.008 39 | IMS_PER_BATCH: 64 40 | WARMUP_METHOD: 'linear' 41 | LARGE_FC_LR: False 42 | CHECKPOINT_PERIOD: 120 43 | LOG_PERIOD: 50 44 | EVAL_PERIOD: 120 45 | WEIGHT_DECAY: 1e-4 46 | WEIGHT_DECAY_BIAS: 1e-4 47 | BIAS_LR_FACTOR: 2 48 | 49 | TEST: 50 | EVAL: True 51 | IMS_PER_BATCH: 256 52 | RE_RANKING: False 53 | WEIGHT: '' 54 | NECK_FEAT: 'before' 55 | FEAT_NORM: 'yes' 56 | 57 | OUTPUT_DIR: '../logs/market_vit_transreid_stride' 58 | 59 | 60 | -------------------------------------------------------------------------------- /configs/Market/vit_transreid_stride_384.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/xxx/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth' 4 | METRIC_LOSS_TYPE: 
'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('3') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [12, 12] 12 | SIE_CAMERA: True 13 | SIE_COE: 3.0 14 | JPM: True 15 | RE_ARRANGE: True 16 | 17 | INPUT: 18 | SIZE_TRAIN: [384, 128] 19 | SIZE_TEST: [384, 128] 20 | PROB: 0.5 # random horizontal flip 21 | RE_PROB: 0.5 # random erasing 22 | PADDING: 10 23 | PIXEL_MEAN: [0.5, 0.5, 0.5] 24 | PIXEL_STD: [0.5, 0.5, 0.5] 25 | 26 | DATASETS: 27 | NAMES: ('market1501') 28 | ROOT_DIR: ('../../data') 29 | 30 | DATALOADER: 31 | SAMPLER: 'softmax_triplet' 32 | NUM_INSTANCE: 4 33 | NUM_WORKERS: 8 34 | 35 | SOLVER: 36 | OPTIMIZER_NAME: 'SGD' 37 | MAX_EPOCHS: 120 38 | BASE_LR: 0.008 39 | IMS_PER_BATCH: 64 40 | WARMUP_METHOD: 'linear' 41 | LARGE_FC_LR: False 42 | CHECKPOINT_PERIOD: 120 43 | LOG_PERIOD: 50 44 | EVAL_PERIOD: 120 45 | WEIGHT_DECAY: 1e-4 46 | WEIGHT_DECAY_BIAS: 1e-4 47 | BIAS_LR_FACTOR: 2 48 | 49 | TEST: 50 | EVAL: True 51 | IMS_PER_BATCH: 256 52 | RE_RANKING: False 53 | WEIGHT: '' 54 | NECK_FEAT: 'before' 55 | FEAT_NORM: 'yes' 56 | 57 | OUTPUT_DIR: '../logs/0321_market_vit_transreid_stride_384' 58 | 59 | 60 | -------------------------------------------------------------------------------- /configs/OCC_Duke/deit_transreid_stride.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/xxx/.cache/torch/checkpoints/deit_base_distilled_patch16_224-df68dfff.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('2') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [11, 11] 12 | SIE_CAMERA: True 13 | SIE_COE: 3.0 14 | JPM: True 15 | RE_ARRANGE: True 16 | 17 | INPUT: 18 | SIZE_TRAIN: [256, 128] 19 | SIZE_TEST: [256, 128] 20 | PROB: 0.5 # random horizontal flip 21 | 
RE_PROB: 0.8 # random erasing 22 | PADDING: 10 23 | PIXEL_MEAN: [0.5, 0.5, 0.5] 24 | PIXEL_STD: [0.5, 0.5, 0.5] 25 | 26 | DATASETS: 27 | NAMES: ('occ_duke') 28 | ROOT_DIR: ('../../data') 29 | 30 | DATALOADER: 31 | SAMPLER: 'softmax_triplet' 32 | NUM_INSTANCE: 4 33 | NUM_WORKERS: 8 34 | 35 | SOLVER: 36 | OPTIMIZER_NAME: 'SGD' 37 | MAX_EPOCHS: 120 38 | BASE_LR: 0.008 39 | IMS_PER_BATCH: 64 40 | WARMUP_METHOD: 'linear' 41 | LARGE_FC_LR: False 42 | CHECKPOINT_PERIOD: 120 43 | LOG_PERIOD: 50 44 | EVAL_PERIOD: 120 45 | WEIGHT_DECAY: 1e-4 46 | WEIGHT_DECAY_BIAS: 1e-4 47 | BIAS_LR_FACTOR: 2 48 | 49 | TEST: 50 | EVAL: True 51 | IMS_PER_BATCH: 256 52 | RE_RANKING: False 53 | WEIGHT: '' 54 | NECK_FEAT: 'before' 55 | FEAT_NORM: 'yes' 56 | 57 | OUTPUT_DIR: '../logs/occ_duke_deit_transreid_stride11' 58 | 59 | 60 | -------------------------------------------------------------------------------- /configs/OCC_Duke/vit_base.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/xxx/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('5') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [16, 16] 12 | 13 | INPUT: 14 | SIZE_TRAIN: [256, 128] 15 | SIZE_TEST: [256, 128] 16 | PROB: 0.5 # random horizontal flip 17 | RE_PROB: 0.5 # random erasing 18 | PADDING: 10 19 | PIXEL_MEAN: [0.5, 0.5, 0.5] 20 | PIXEL_STD: [0.5, 0.5, 0.5] 21 | 22 | DATASETS: 23 | NAMES: ('occ_duke') 24 | ROOT_DIR: ('../../data') 25 | 26 | DATALOADER: 27 | SAMPLER: 'softmax_triplet' 28 | NUM_INSTANCE: 4 29 | NUM_WORKERS: 8 30 | 31 | SOLVER: 32 | OPTIMIZER_NAME: 'SGD' 33 | MAX_EPOCHS: 120 34 | BASE_LR: 0.008 35 | IMS_PER_BATCH: 64 36 | WARMUP_METHOD: 'linear' 37 | LARGE_FC_LR: False 38 | CHECKPOINT_PERIOD: 120 39 | LOG_PERIOD: 50 40 | EVAL_PERIOD: 120 
41 | WEIGHT_DECAY: 1e-4 42 | WEIGHT_DECAY_BIAS: 1e-4 43 | BIAS_LR_FACTOR: 2 44 | 45 | TEST: 46 | EVAL: True 47 | IMS_PER_BATCH: 256 48 | RE_RANKING: False 49 | WEIGHT: '' 50 | NECK_FEAT: 'before' 51 | FEAT_NORM: 'yes' 52 | 53 | OUTPUT_DIR: '../logs/occ_duke_vit_base' 54 | 55 | 56 | -------------------------------------------------------------------------------- /configs/OCC_Duke/vit_jpm.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/xxx/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('1') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [16, 16] 12 | JPM: True 13 | RE_ARRANGE: True 14 | 15 | INPUT: 16 | SIZE_TRAIN: [256, 128] 17 | SIZE_TEST: [256, 128] 18 | PROB: 0.5 # random horizontal flip 19 | RE_PROB: 0.5 # random erasing 20 | PADDING: 10 21 | PIXEL_MEAN: [0.5, 0.5, 0.5] 22 | PIXEL_STD: [0.5, 0.5, 0.5] 23 | 24 | DATASETS: 25 | NAMES: ('occ_duke') 26 | ROOT_DIR: ('../../data') 27 | 28 | DATALOADER: 29 | SAMPLER: 'softmax_triplet' 30 | NUM_INSTANCE: 4 31 | NUM_WORKERS: 8 32 | 33 | SOLVER: 34 | OPTIMIZER_NAME: 'SGD' 35 | MAX_EPOCHS: 120 36 | BASE_LR: 0.008 37 | IMS_PER_BATCH: 64 38 | WARMUP_METHOD: 'linear' 39 | LARGE_FC_LR: False 40 | CHECKPOINT_PERIOD: 120 41 | LOG_PERIOD: 50 42 | EVAL_PERIOD: 120 43 | WEIGHT_DECAY: 1e-4 44 | WEIGHT_DECAY_BIAS: 1e-4 45 | BIAS_LR_FACTOR: 2 46 | 47 | TEST: 48 | EVAL: True 49 | IMS_PER_BATCH: 256 50 | RE_RANKING: False 51 | WEIGHT: '' 52 | NECK_FEAT: 'before' 53 | FEAT_NORM: 'yes' 54 | 55 | OUTPUT_DIR: '../logs/occ_duke_vit_jpm' 56 | 57 | 58 | -------------------------------------------------------------------------------- /configs/OCC_Duke/vit_sie.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | 
PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/xxx/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('2') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [16, 16] 12 | SIE_CAMERA: True 13 | SIE_COE: 3.0 14 | 15 | INPUT: 16 | SIZE_TRAIN: [256, 128] 17 | SIZE_TEST: [256, 128] 18 | PROB: 0.5 # random horizontal flip 19 | RE_PROB: 0.5 # random erasing 20 | PADDING: 10 21 | PIXEL_MEAN: [0.5, 0.5, 0.5] 22 | PIXEL_STD: [0.5, 0.5, 0.5] 23 | 24 | DATASETS: 25 | NAMES: ('occ_duke') 26 | ROOT_DIR: ('../../data') 27 | 28 | DATALOADER: 29 | SAMPLER: 'softmax_triplet' 30 | NUM_INSTANCE: 4 31 | NUM_WORKERS: 8 32 | 33 | SOLVER: 34 | OPTIMIZER_NAME: 'SGD' 35 | MAX_EPOCHS: 120 36 | BASE_LR: 0.008 37 | IMS_PER_BATCH: 64 38 | WARMUP_METHOD: 'linear' 39 | LARGE_FC_LR: False 40 | CHECKPOINT_PERIOD: 120 41 | LOG_PERIOD: 50 42 | EVAL_PERIOD: 120 43 | WEIGHT_DECAY: 1e-4 44 | WEIGHT_DECAY_BIAS: 1e-4 45 | BIAS_LR_FACTOR: 2 46 | 47 | TEST: 48 | EVAL: True 49 | IMS_PER_BATCH: 256 50 | RE_RANKING: False 51 | WEIGHT: '' 52 | NECK_FEAT: 'before' 53 | FEAT_NORM: 'yes' 54 | 55 | OUTPUT_DIR: '../logs/occ_duke_vit_sie' 56 | 57 | 58 | -------------------------------------------------------------------------------- /configs/OCC_Duke/vit_transreid.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/xxx/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('3') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [16, 16] 12 | SIE_CAMERA: True 13 | SIE_COE: 3.0 14 | JPM: True 15 | RE_ARRANGE: True 16 | 17 | INPUT: 18 | SIZE_TRAIN: [256, 128] 19 | 
SIZE_TEST: [256, 128] 20 | PROB: 0.5 # random horizontal flip 21 | RE_PROB: 0.5 # random erasing 22 | PADDING: 10 23 | PIXEL_MEAN: [0.5, 0.5, 0.5] 24 | PIXEL_STD: [0.5, 0.5, 0.5] 25 | 26 | DATASETS: 27 | NAMES: ('occ_duke') 28 | ROOT_DIR: ('../../data') 29 | 30 | DATALOADER: 31 | SAMPLER: 'softmax_triplet' 32 | NUM_INSTANCE: 4 33 | NUM_WORKERS: 8 34 | 35 | SOLVER: 36 | OPTIMIZER_NAME: 'SGD' 37 | MAX_EPOCHS: 120 38 | BASE_LR: 0.008 39 | IMS_PER_BATCH: 64 40 | WARMUP_METHOD: 'linear' 41 | LARGE_FC_LR: False 42 | CHECKPOINT_PERIOD: 120 43 | LOG_PERIOD: 50 44 | EVAL_PERIOD: 120 45 | WEIGHT_DECAY: 1e-4 46 | WEIGHT_DECAY_BIAS: 1e-4 47 | BIAS_LR_FACTOR: 2 48 | 49 | TEST: 50 | EVAL: True 51 | IMS_PER_BATCH: 256 52 | RE_RANKING: False 53 | WEIGHT: '' 54 | NECK_FEAT: 'before' 55 | FEAT_NORM: 'yes' 56 | 57 | OUTPUT_DIR: '../logs/occ_duke_vit_transreid' 58 | 59 | 60 | -------------------------------------------------------------------------------- /configs/OCC_Duke/vit_transreid_stride.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/xxx/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('3') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [11, 11] 12 | SIE_CAMERA: True 13 | SIE_COE: 3.0 14 | JPM: True 15 | RE_ARRANGE: True 16 | 17 | INPUT: 18 | SIZE_TRAIN: [256, 128] 19 | SIZE_TEST: [256, 128] 20 | PROB: 0.5 # random horizontal flip 21 | RE_PROB: 0.5 # random erasing 22 | PADDING: 10 23 | PIXEL_MEAN: [0.5, 0.5, 0.5] 24 | PIXEL_STD: [0.5, 0.5, 0.5] 25 | 26 | DATASETS: 27 | NAMES: ('occ_duke') 28 | ROOT_DIR: ('../../data') 29 | 30 | DATALOADER: 31 | SAMPLER: 'softmax_triplet' 32 | NUM_INSTANCE: 4 33 | NUM_WORKERS: 8 34 | 35 | SOLVER: 36 | OPTIMIZER_NAME: 'SGD' 37 | MAX_EPOCHS: 120 38 | BASE_LR: 0.008 
39 | IMS_PER_BATCH: 64 40 | WARMUP_METHOD: 'linear' 41 | LARGE_FC_LR: False 42 | CHECKPOINT_PERIOD: 120 43 | LOG_PERIOD: 50 44 | EVAL_PERIOD: 120 45 | WEIGHT_DECAY: 1e-4 46 | WEIGHT_DECAY_BIAS: 1e-4 47 | BIAS_LR_FACTOR: 2 48 | 49 | TEST: 50 | EVAL: True 51 | IMS_PER_BATCH: 256 52 | RE_RANKING: False 53 | WEIGHT: '' 54 | NECK_FEAT: 'before' 55 | FEAT_NORM: 'yes' 56 | 57 | OUTPUT_DIR: '../logs/occ_duke_vit_transreid_stride' 58 | 59 | 60 | -------------------------------------------------------------------------------- /configs/VeRi/deit_transreid.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/xxx/.cache/torch/checkpoints/deit_base_distilled_patch16_224-df68dfff.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('4') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [16, 16] 12 | SIE_CAMERA: True 13 | SIE_VIEW: True 14 | SIE_COE: 3.0 15 | JPM: True 16 | SHIFT_NUM: 8 17 | RE_ARRANGE: True 18 | 19 | INPUT: 20 | SIZE_TRAIN: [256, 256] 21 | SIZE_TEST: [256, 256] 22 | PROB: 0.5 # random horizontal flip 23 | RE_PROB: 0.8 # random erasing 24 | PADDING: 10 25 | 26 | DATASETS: 27 | NAMES: ('veri') 28 | ROOT_DIR: ('../../data') 29 | 30 | DATALOADER: 31 | SAMPLER: 'softmax_triplet' 32 | NUM_INSTANCE: 4 33 | NUM_WORKERS: 8 34 | 35 | SOLVER: 36 | OPTIMIZER_NAME: 'SGD' 37 | MAX_EPOCHS: 120 38 | BASE_LR: 0.01 39 | IMS_PER_BATCH: 64 40 | WARMUP_METHOD: 'linear' 41 | LARGE_FC_LR: False 42 | CHECKPOINT_PERIOD: 120 43 | LOG_PERIOD: 50 44 | EVAL_PERIOD: 120 45 | WEIGHT_DECAY: 1e-4 46 | WEIGHT_DECAY_BIAS: 1e-4 47 | BIAS_LR_FACTOR: 2 48 | 49 | TEST: 50 | EVAL: True 51 | IMS_PER_BATCH: 256 52 | RE_RANKING: False 53 | WEIGHT: '' 54 | NECK_FEAT: 'before' 55 | FEAT_NORM: 'yes' 56 | 57 | OUTPUT_DIR: '../logs/veri_deit_transreid' 58 | 59 | 60 | 
-------------------------------------------------------------------------------- /configs/VeRi/deit_transreid_stride.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/xxx/.cache/torch/checkpoints/deit_base_distilled_patch16_224-df68dfff.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('0') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [12, 12] 12 | SIE_CAMERA: True 13 | SIE_VIEW: True 14 | SIE_COE: 3.0 15 | JPM: True 16 | SHIFT_NUM: 8 17 | RE_ARRANGE: True 18 | 19 | INPUT: 20 | SIZE_TRAIN: [256, 256] 21 | SIZE_TEST: [256, 256] 22 | PROB: 0.5 # random horizontal flip 23 | RE_PROB: 0.8 # random erasing 24 | PADDING: 10 25 | 26 | DATASETS: 27 | NAMES: ('veri') 28 | ROOT_DIR: ('../../datasets') 29 | 30 | DATALOADER: 31 | SAMPLER: 'softmax_triplet' 32 | NUM_INSTANCE: 4 33 | NUM_WORKERS: 8 34 | 35 | SOLVER: 36 | OPTIMIZER_NAME: 'SGD' 37 | MAX_EPOCHS: 120 38 | BASE_LR: 0.01 39 | IMS_PER_BATCH: 64 40 | WARMUP_METHOD: 'linear' 41 | LARGE_FC_LR: False 42 | CHECKPOINT_PERIOD: 120 43 | LOG_PERIOD: 50 44 | EVAL_PERIOD: 120 45 | WEIGHT_DECAY: 1e-4 46 | WEIGHT_DECAY_BIAS: 1e-4 47 | BIAS_LR_FACTOR: 2 48 | 49 | TEST: 50 | EVAL: True 51 | IMS_PER_BATCH: 256 52 | RE_RANKING: False 53 | WEIGHT: '' 54 | NECK_FEAT: 'before' 55 | FEAT_NORM: 'yes' 56 | 57 | OUTPUT_DIR: '../logs/veri_deit_transreid_stride' 58 | 59 | 60 | -------------------------------------------------------------------------------- /configs/VeRi/vit_base.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/xxx/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 
9 | DEVICE_ID: ('4') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [16, 16] 12 | 13 | INPUT: 14 | SIZE_TRAIN: [256, 256] 15 | SIZE_TEST: [256, 256] 16 | PROB: 0.5 # random horizontal flip 17 | RE_PROB: 0.5 # random erasing 18 | PADDING: 10 19 | PIXEL_MEAN: [0.5, 0.5, 0.5] 20 | PIXEL_STD: [0.5, 0.5, 0.5] 21 | 22 | DATASETS: 23 | NAMES: ('veri') 24 | ROOT_DIR: ('../../data') 25 | 26 | DATALOADER: 27 | SAMPLER: 'softmax_triplet' 28 | NUM_INSTANCE: 4 29 | NUM_WORKERS: 8 30 | 31 | SOLVER: 32 | OPTIMIZER_NAME: 'SGD' 33 | MAX_EPOCHS: 120 34 | BASE_LR: 0.008 35 | IMS_PER_BATCH: 64 36 | WARMUP_METHOD: 'linear' 37 | LARGE_FC_LR: False 38 | CHECKPOINT_PERIOD: 120 39 | LOG_PERIOD: 50 40 | EVAL_PERIOD: 120 41 | WEIGHT_DECAY: 1e-4 42 | WEIGHT_DECAY_BIAS: 1e-4 43 | BIAS_LR_FACTOR: 2 44 | 45 | TEST: 46 | EVAL: True 47 | IMS_PER_BATCH: 256 48 | RE_RANKING: False 49 | WEIGHT: '' 50 | NECK_FEAT: 'before' 51 | FEAT_NORM: 'yes' 52 | 53 | OUTPUT_DIR: '../logs/veri_vit_base' 54 | 55 | 56 | -------------------------------------------------------------------------------- /configs/VeRi/vit_transreid.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/xxx/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('4') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [16, 16] 12 | SIE_CAMERA: True 13 | SIE_VIEW: True 14 | SIE_COE: 3.0 15 | JPM: True 16 | SHIFT_NUM: 8 17 | RE_ARRANGE: True 18 | 19 | INPUT: 20 | SIZE_TRAIN: [256, 256] 21 | SIZE_TEST: [256, 256] 22 | PROB: 0.5 # random horizontal flip 23 | RE_PROB: 0.5 # random erasing 24 | PADDING: 10 25 | PIXEL_MEAN: [0.5, 0.5, 0.5] 26 | PIXEL_STD: [0.5, 0.5, 0.5] 27 | 28 | DATASETS: 29 | NAMES: ('veri') 30 | ROOT_DIR: ('../../data') 31 | 32 | 
DATALOADER: 33 | SAMPLER: 'softmax_triplet' 34 | NUM_INSTANCE: 4 35 | NUM_WORKERS: 8 36 | 37 | SOLVER: 38 | OPTIMIZER_NAME: 'SGD' 39 | MAX_EPOCHS: 120 40 | BASE_LR: 0.01 41 | IMS_PER_BATCH: 64 42 | WARMUP_METHOD: 'linear' 43 | LARGE_FC_LR: False 44 | CHECKPOINT_PERIOD: 120 45 | LOG_PERIOD: 50 46 | EVAL_PERIOD: 120 47 | WEIGHT_DECAY: 1e-4 48 | WEIGHT_DECAY_BIAS: 1e-4 49 | BIAS_LR_FACTOR: 2 50 | 51 | TEST: 52 | EVAL: True 53 | IMS_PER_BATCH: 256 54 | RE_RANKING: False 55 | WEIGHT: '' 56 | NECK_FEAT: 'before' 57 | FEAT_NORM: 'yes' 58 | 59 | OUTPUT_DIR: '../logs/veri_vit_transreid' 60 | 61 | 62 | -------------------------------------------------------------------------------- /configs/VeRi/vit_transreid_stride.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/xxx/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('2') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [12, 12] 12 | SIE_CAMERA: True 13 | SIE_VIEW: True 14 | SIE_COE: 3.0 15 | JPM: True 16 | SHIFT_NUM: 8 17 | RE_ARRANGE: True 18 | 19 | INPUT: 20 | SIZE_TRAIN: [256, 256] 21 | SIZE_TEST: [256, 256] 22 | PROB: 0.5 # random horizontal flip 23 | RE_PROB: 0.5 # random erasing 24 | PADDING: 10 25 | PIXEL_MEAN: [0.5, 0.5, 0.5] 26 | PIXEL_STD: [0.5, 0.5, 0.5] 27 | 28 | DATASETS: 29 | NAMES: ('veri') 30 | ROOT_DIR: ('../../data') 31 | 32 | DATALOADER: 33 | SAMPLER: 'softmax_triplet' 34 | NUM_INSTANCE: 4 35 | NUM_WORKERS: 8 36 | 37 | SOLVER: 38 | OPTIMIZER_NAME: 'SGD' 39 | MAX_EPOCHS: 120 40 | BASE_LR: 0.01 41 | IMS_PER_BATCH: 64 42 | WARMUP_METHOD: 'linear' 43 | LARGE_FC_LR: False 44 | CHECKPOINT_PERIOD: 120 45 | LOG_PERIOD: 50 46 | EVAL_PERIOD: 120 47 | WEIGHT_DECAY: 1e-4 48 | WEIGHT_DECAY_BIAS: 1e-4 49 | BIAS_LR_FACTOR: 2 50 | 51 | TEST: 
52 | EVAL: True 53 | IMS_PER_BATCH: 256 54 | RE_RANKING: False 55 | WEIGHT: '' 56 | NECK_FEAT: 'before' 57 | FEAT_NORM: 'yes' 58 | 59 | OUTPUT_DIR: '../logs/veri_vit_transreid_stride' 60 | 61 | 62 | -------------------------------------------------------------------------------- /configs/VehicleID/deit_transreid.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/xxx/.cache/torch/checkpoints/deit_base_distilled_patch16_224-df68dfff.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | # DEVICE_ID: ('0') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [16, 16] 12 | DIST_TRAIN: True 13 | JPM: True 14 | SHIFT_NUM: 8 15 | RE_ARRANGE: True 16 | 17 | INPUT: 18 | SIZE_TRAIN: [256, 256] 19 | SIZE_TEST: [256, 256] 20 | PROB: 0.5 # random horizontal flip 21 | RE_PROB: 0.8 # random erasing 22 | PADDING: 10 23 | 24 | DATASETS: 25 | NAMES: ('VehicleID') 26 | ROOT_DIR: ('../../data') 27 | 28 | DATALOADER: 29 | SAMPLER: 'softmax_triplet' 30 | NUM_INSTANCE: 4 31 | NUM_WORKERS: 8 32 | 33 | SOLVER: 34 | OPTIMIZER_NAME: 'SGD' 35 | MAX_EPOCHS: 120 36 | BASE_LR: 0.03 37 | IMS_PER_BATCH: 256 38 | WARMUP_METHOD: 'linear' 39 | LARGE_FC_LR: False 40 | CHECKPOINT_PERIOD: 120 41 | LOG_PERIOD: 50 42 | EVAL_PERIOD: 120 43 | WEIGHT_DECAY: 1e-4 44 | WEIGHT_DECAY_BIAS: 1e-4 45 | BIAS_LR_FACTOR: 2 46 | 47 | TEST: 48 | EVAL: True 49 | IMS_PER_BATCH: 256 50 | RE_RANKING: False 51 | WEIGHT: '' 52 | NECK_FEAT: 'before' 53 | FEAT_NORM: 'yes' 54 | 55 | OUTPUT_DIR: '../logs/vehicleID_deit_transreid' 56 | 57 | 58 | -------------------------------------------------------------------------------- /configs/VehicleID/deit_transreid_stride.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: 
'/home/xxx/.cache/torch/checkpoints/deit_base_distilled_patch16_224-df68dfff.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | # DEVICE_ID: ('0') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [12, 12] 12 | DIST_TRAIN: True 13 | JPM: True 14 | SHIFT_NUM: 8 15 | RE_ARRANGE: True 16 | 17 | INPUT: 18 | SIZE_TRAIN: [256, 256] 19 | SIZE_TEST: [256, 256] 20 | PROB: 0.5 # random horizontal flip 21 | RE_PROB: 0.8 # random erasing 22 | PADDING: 10 23 | 24 | DATASETS: 25 | NAMES: ('VehicleID') 26 | ROOT_DIR: ('../../data') 27 | 28 | DATALOADER: 29 | SAMPLER: 'softmax_triplet' 30 | NUM_INSTANCE: 4 31 | NUM_WORKERS: 8 32 | 33 | SOLVER: 34 | OPTIMIZER_NAME: 'SGD' 35 | MAX_EPOCHS: 120 36 | BASE_LR: 0.03 37 | IMS_PER_BATCH: 256 38 | WARMUP_METHOD: 'linear' 39 | LARGE_FC_LR: False 40 | CHECKPOINT_PERIOD: 120 41 | LOG_PERIOD: 50 42 | EVAL_PERIOD: 120 43 | WEIGHT_DECAY: 1e-4 44 | WEIGHT_DECAY_BIAS: 1e-4 45 | BIAS_LR_FACTOR: 2 46 | 47 | TEST: 48 | EVAL: True 49 | IMS_PER_BATCH: 256 50 | RE_RANKING: False 51 | WEIGHT: '' 52 | NECK_FEAT: 'before' 53 | FEAT_NORM: 'yes' 54 | 55 | OUTPUT_DIR: '../logs/vehicleID_deit_transreid_stride' 56 | 57 | 58 | -------------------------------------------------------------------------------- /configs/VehicleID/vit_base.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/xxx/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | # DEVICE_ID: ('0') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [16, 16] 12 | 13 | INPUT: 14 | SIZE_TRAIN: [256, 256] 15 | SIZE_TEST: [256, 256] 16 | PROB: 0.5 # random horizontal flip 17 | RE_PROB: 0.5 # random erasing 18 | PADDING: 10 19 | PIXEL_MEAN: 
[0.5, 0.5, 0.5] 20 | PIXEL_STD: [0.5, 0.5, 0.5] 21 | 22 | DATASETS: 23 | NAMES: ('VehicleID') 24 | ROOT_DIR: ('../../data') 25 | 26 | DATALOADER: 27 | SAMPLER: 'softmax_triplet' 28 | NUM_INSTANCE: 4 29 | NUM_WORKERS: 8 30 | 31 | SOLVER: 32 | OPTIMIZER_NAME: 'SGD' 33 | MAX_EPOCHS: 120 34 | BASE_LR: 0.04 35 | IMS_PER_BATCH: 224 36 | WARMUP_METHOD: 'linear' 37 | LARGE_FC_LR: False 38 | CHECKPOINT_PERIOD: 120 39 | LOG_PERIOD: 50 40 | EVAL_PERIOD: 120 41 | WEIGHT_DECAY: 1e-4 42 | WEIGHT_DECAY_BIAS: 1e-4 43 | BIAS_LR_FACTOR: 2 44 | 45 | TEST: 46 | EVAL: True 47 | IMS_PER_BATCH: 256 48 | RE_RANKING: False 49 | WEIGHT: '../logs/vehicleID_vit_base/transformer_120.pth' 50 | NECK_FEAT: 'before' 51 | FEAT_NORM: 'yes' 52 | 53 | OUTPUT_DIR: '../logs/vehicleID_vit_base' 54 | 55 | 56 | -------------------------------------------------------------------------------- /configs/VehicleID/vit_transreid.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/xxx/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | # DEVICE_ID: ('0') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [16, 16] 12 | # DIST_TRAIN: True 13 | JPM: True 14 | SHIFT_NUM: 8 15 | RE_ARRANGE: True 16 | 17 | INPUT: 18 | SIZE_TRAIN: [256, 256] 19 | SIZE_TEST: [256, 256] 20 | PROB: 0.5 # random horizontal flip 21 | RE_PROB: 0.5 # random erasing 22 | PADDING: 10 23 | PIXEL_MEAN: [0.5, 0.5, 0.5] 24 | PIXEL_STD: [0.5, 0.5, 0.5] 25 | 26 | DATASETS: 27 | NAMES: ('VehicleID') 28 | ROOT_DIR: ('../../data') 29 | 30 | DATALOADER: 31 | SAMPLER: 'softmax_triplet' 32 | NUM_INSTANCE: 4 33 | NUM_WORKERS: 8 34 | 35 | SOLVER: 36 | OPTIMIZER_NAME: 'SGD' 37 | MAX_EPOCHS: 120 38 | BASE_LR: 0.045 39 | IMS_PER_BATCH: 224 40 | WARMUP_METHOD: 'linear' 41 | LARGE_FC_LR: False 42 | 
CHECKPOINT_PERIOD: 120 43 | LOG_PERIOD: 50 44 | EVAL_PERIOD: 120 45 | WEIGHT_DECAY: 1e-4 46 | WEIGHT_DECAY_BIAS: 1e-4 47 | BIAS_LR_FACTOR: 2 48 | 49 | TEST: 50 | EVAL: True 51 | IMS_PER_BATCH: 256 52 | RE_RANKING: False 53 | WEIGHT: '../logs/vehicleID_vit_transreid/transformer_120.pth' 54 | NECK_FEAT: 'before' 55 | FEAT_NORM: 'yes' 56 | 57 | OUTPUT_DIR: '../logs/vehicleID_vit_transreid' 58 | 59 | 60 | -------------------------------------------------------------------------------- /configs/VehicleID/vit_transreid_stride.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/xxx/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | # DEVICE_ID: ('0') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [12, 12] 12 | # DIST_TRAIN: True 13 | JPM: True 14 | SHIFT_NUM: 8 15 | RE_ARRANGE: True 16 | 17 | INPUT: 18 | SIZE_TRAIN: [256, 256] 19 | SIZE_TEST: [256, 256] 20 | PROB: 0.5 # random horizontal flip 21 | RE_PROB: 0.5 # random erasing 22 | PADDING: 10 23 | PIXEL_MEAN: [0.5, 0.5, 0.5] 24 | PIXEL_STD: [0.5, 0.5, 0.5] 25 | 26 | DATASETS: 27 | NAMES: ('VehicleID') 28 | ROOT_DIR: ('../../data') 29 | 30 | DATALOADER: 31 | SAMPLER: 'softmax_triplet' 32 | NUM_INSTANCE: 4 33 | NUM_WORKERS: 8 34 | 35 | SOLVER: 36 | OPTIMIZER_NAME: 'SGD' 37 | MAX_EPOCHS: 120 38 | BASE_LR: 0.045 39 | IMS_PER_BATCH: 256 40 | WARMUP_METHOD: 'linear' 41 | LARGE_FC_LR: False 42 | CHECKPOINT_PERIOD: 120 43 | LOG_PERIOD: 50 44 | EVAL_PERIOD: 120 45 | WEIGHT_DECAY: 1e-4 46 | WEIGHT_DECAY_BIAS: 1e-4 47 | BIAS_LR_FACTOR: 2 48 | 49 | TEST: 50 | EVAL: True 51 | IMS_PER_BATCH: 256 52 | RE_RANKING: False 53 | WEIGHT: '../logs/vehicleID_vit_transreid_stride/transformer_120.pth' 54 | NECK_FEAT: 'before' 55 | FEAT_NORM: 'yes' 56 | 57 | OUTPUT_DIR: 
'../logs/vehicleID_vit_transreid_stride' 58 | 59 | 60 | -------------------------------------------------------------------------------- /configs/transformer_base.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/xxx/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('7') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [16, 16] 12 | 13 | INPUT: 14 | SIZE_TRAIN: [256, 256] 15 | SIZE_TEST: [256, 256] 16 | PROB: 0.5 # random horizontal flip 17 | RE_PROB: 0.5 # random erasing 18 | PADDING: 10 19 | PIXEL_MEAN: [0.5, 0.5, 0.5] 20 | PIXEL_STD: [0.5, 0.5, 0.5] 21 | 22 | DATASETS: 23 | NAMES: ('dukemtmc') 24 | ROOT_DIR: ('../../data') 25 | 26 | DATALOADER: 27 | SAMPLER: 'softmax_triplet' 28 | NUM_INSTANCE: 4 29 | NUM_WORKERS: 8 30 | 31 | SOLVER: 32 | OPTIMIZER_NAME: 'SGD' 33 | MAX_EPOCHS: 120 34 | BASE_LR: 0.008 35 | IMS_PER_BATCH: 64 36 | WARMUP_METHOD: 'linear' 37 | LARGE_FC_LR: False 38 | CHECKPOINT_PERIOD: 120 39 | LOG_PERIOD: 50 40 | EVAL_PERIOD: 120 41 | WEIGHT_DECAY: 1e-4 42 | WEIGHT_DECAY_BIAS: 1e-4 43 | BIAS_LR_FACTOR: 2 44 | 45 | TEST: 46 | EVAL: True 47 | IMS_PER_BATCH: 256 48 | RE_RANKING: False 49 | WEIGHT: '' 50 | NECK_FEAT: 'before' 51 | FEAT_NORM: 'yes' 52 | 53 | OUTPUT_DIR: '../logs/' 54 | 55 | 56 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .make_dataloader import make_dataloader -------------------------------------------------------------------------------- /datasets/bases.py: -------------------------------------------------------------------------------- 1 | from PIL import Image, ImageFile 2 | 3 | from torch.utils.data import 
Dataset 4 | import os.path as osp 5 | import random 6 | import torch 7 | ImageFile.LOAD_TRUNCATED_IMAGES = True 8 | 9 | 10 | def read_image(img_path): 11 | """Keep reading image until succeed. 12 | This can avoid IOError incurred by heavy IO process.""" 13 | got_img = False 14 | if not osp.exists(img_path): 15 | raise IOError("{} does not exist".format(img_path)) 16 | while not got_img: 17 | try: 18 | img = Image.open(img_path).convert('RGB') 19 | got_img = True 20 | except IOError: 21 | print("IOError incurred when reading '{}'. Will redo. Don't worry. Just chill.".format(img_path)) 22 | pass 23 | return img 24 | 25 | 26 | class BaseDataset(object): 27 | """ 28 | Base class of reid dataset 29 | """ 30 | 31 | def get_imagedata_info(self, data): 32 | pids, cams, tracks = [], [], [] 33 | 34 | for _, pid, camid, trackid in data: 35 | pids += [pid] 36 | cams += [camid] 37 | tracks += [trackid] 38 | pids = set(pids) 39 | cams = set(cams) 40 | tracks = set(tracks) 41 | num_pids = len(pids) 42 | num_cams = len(cams) 43 | num_imgs = len(data) 44 | num_views = len(tracks) 45 | return num_pids, num_imgs, num_cams, num_views 46 | 47 | def print_dataset_statistics(self): 48 | raise NotImplementedError 49 | 50 | 51 | class BaseImageDataset(BaseDataset): 52 | """ 53 | Base class of image reid dataset 54 | """ 55 | 56 | def print_dataset_statistics(self, train, query, gallery): 57 | num_train_pids, num_train_imgs, num_train_cams, num_train_views = self.get_imagedata_info(train) 58 | num_query_pids, num_query_imgs, num_query_cams, num_train_views = self.get_imagedata_info(query) 59 | num_gallery_pids, num_gallery_imgs, num_gallery_cams, num_train_views = self.get_imagedata_info(gallery) 60 | 61 | print("Dataset statistics:") 62 | print(" ----------------------------------------") 63 | print(" subset | # ids | # images | # cameras") 64 | print(" ----------------------------------------") 65 | print(" train | {:5d} | {:8d} | {:9d}".format(num_train_pids, num_train_imgs, 
num_train_cams)) 66 | print(" query | {:5d} | {:8d} | {:9d}".format(num_query_pids, num_query_imgs, num_query_cams)) 67 | print(" gallery | {:5d} | {:8d} | {:9d}".format(num_gallery_pids, num_gallery_imgs, num_gallery_cams)) 68 | print(" ----------------------------------------") 69 | 70 | 71 | class ImageDataset(Dataset): 72 | def __init__(self, dataset, transform=None): 73 | self.dataset = dataset 74 | self.transform = transform 75 | 76 | def __len__(self): 77 | return len(self.dataset) 78 | 79 | def __getitem__(self, index): 80 | img_path, pid, camid, trackid = self.dataset[index] 81 | img = read_image(img_path) 82 | 83 | if self.transform is not None: 84 | img = self.transform(img) 85 | 86 | return img, pid, camid, trackid,img_path.split('/')[-1] -------------------------------------------------------------------------------- /datasets/cuhk03.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | import glob 3 | import re 4 | 5 | import os.path as osp 6 | 7 | from .bases import BaseImageDataset 8 | from collections import defaultdict 9 | import pickle 10 | class CUHK03(BaseImageDataset): 11 | """ 12 | Dataset statistics: 13 | # identities: 1467 14 | # images: 7368 (train) + 1400 (query) + 5328 (gallery) 15 | """ 16 | dataset_dir = 'cuhk03-np/detected' 17 | 18 | def __init__(self, root='', verbose=True, pid_begin = 0, **kwargs): 19 | super(CUHK03, self).__init__() 20 | self.dataset_dir = osp.join(root, self.dataset_dir) 21 | self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train') 22 | self.query_dir = osp.join(self.dataset_dir, 'query') 23 | self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test') 24 | 25 | self._check_before_run() 26 | self.pid_begin = pid_begin 27 | train = self._process_dir(self.train_dir, relabel=True) 28 | query = self._process_dir(self.query_dir, relabel=False) 29 | gallery = self._process_dir(self.gallery_dir, relabel=False) 30 | 31 | if verbose: 32 | print("=> 
CUHK03 loaded") 33 | self.print_dataset_statistics(train, query, gallery) 34 | 35 | self.train = train 36 | self.query = query 37 | self.gallery = gallery 38 | 39 | self.num_train_pids, self.num_train_imgs, self.num_train_cams, self.num_train_vids = self.get_imagedata_info(self.train) 40 | self.num_query_pids, self.num_query_imgs, self.num_query_cams, self.num_query_vids = self.get_imagedata_info(self.query) 41 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams, self.num_gallery_vids = self.get_imagedata_info(self.gallery) 42 | 43 | def _check_before_run(self): 44 | """Check if all files are available before going deeper""" 45 | if not osp.exists(self.dataset_dir): 46 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 47 | if not osp.exists(self.train_dir): 48 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 49 | if not osp.exists(self.query_dir): 50 | raise RuntimeError("'{}' is not available".format(self.query_dir)) 51 | if not osp.exists(self.gallery_dir): 52 | raise RuntimeError("'{}' is not available".format(self.gallery_dir)) 53 | 54 | def _process_dir(self, dir_path, relabel=False): 55 | img_paths = glob.glob(osp.join(dir_path, '*.png')) 56 | pattern = re.compile(r'([-\d]+)_c(\d)') 57 | 58 | pid_container = set() 59 | for img_path in sorted(img_paths): 60 | pid, _ = map(int, pattern.search(img_path).groups()) 61 | if pid == -1: continue # junk images are just ignored 62 | pid_container.add(pid) 63 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 64 | dataset = [] 65 | for img_path in sorted(img_paths): 66 | pid, camid = map(int, pattern.search(img_path).groups()) 67 | if pid == -1: continue # junk images are just ignored 68 | assert 1 <= pid <= 1467 69 | assert 1 <= camid <= 2 70 | camid -= 1 # index starts from 0 71 | if relabel: pid = pid2label[pid] 72 | 73 | dataset.append((img_path, self.pid_begin + pid, camid, 1)) 74 | return dataset 75 | 
-------------------------------------------------------------------------------- /datasets/dukemtmcreid.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | import glob 4 | import re 5 | import urllib 6 | import zipfile 7 | 8 | import os.path as osp 9 | 10 | from utils.iotools import mkdir_if_missing 11 | from .bases import BaseImageDataset 12 | 13 | 14 | class DukeMTMCreID(BaseImageDataset): 15 | """ 16 | DukeMTMC-reID 17 | Reference: 18 | 1. Ristani et al. Performance Measures and a Data Set for Multi-Target, Multi-Camera Tracking. ECCVW 2016. 19 | 2. Zheng et al. Unlabeled Samples Generated by GAN Improve the Person Re-identification Baseline in vitro. ICCV 2017. 20 | URL: https://github.com/layumi/DukeMTMC-reID_evaluation 21 | 22 | Dataset statistics: 23 | # identities: 1404 (train + query) 24 | # images:16522 (train) + 2228 (query) + 17661 (gallery) 25 | # cameras: 8 26 | """ 27 | dataset_dir = 'DukeMTMC-reID' 28 | 29 | def __init__(self, root='', verbose=True, pid_begin=0, **kwargs): 30 | super(DukeMTMCreID, self).__init__() 31 | self.dataset_dir = osp.join(root, self.dataset_dir) 32 | self.dataset_url = 'http://vision.cs.duke.edu/DukeMTMC/data/misc/DukeMTMC-reID.zip' 33 | self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train') 34 | self.query_dir = osp.join(self.dataset_dir, 'query') 35 | self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test') 36 | self.pid_begin = pid_begin 37 | self._download_data() 38 | self._check_before_run() 39 | 40 | train = self._process_dir(self.train_dir, relabel=True) 41 | query = self._process_dir(self.query_dir, relabel=False) 42 | gallery = self._process_dir(self.gallery_dir, relabel=False) 43 | 44 | if verbose: 45 | print("=> DukeMTMC-reID loaded") 46 | self.print_dataset_statistics(train, query, gallery) 47 | 48 | self.train = train 49 | self.query = query 50 | self.gallery = gallery 51 | 52 | self.num_train_pids, self.num_train_imgs, 
self.num_train_cams, self.num_train_vids = self.get_imagedata_info(self.train) 53 | self.num_query_pids, self.num_query_imgs, self.num_query_cams, self.num_query_vids = self.get_imagedata_info(self.query) 54 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams, self.num_gallery_vids = self.get_imagedata_info(self.gallery) 55 | 56 | def _download_data(self): 57 | if osp.exists(self.dataset_dir): 58 | print("This dataset has been downloaded.") 59 | return 60 | 61 | print("Creating directory {}".format(self.dataset_dir)) 62 | mkdir_if_missing(self.dataset_dir) 63 | fpath = osp.join(self.dataset_dir, osp.basename(self.dataset_url)) 64 | 65 | print("Downloading DukeMTMC-reID dataset") 66 | urllib.request.urlretrieve(self.dataset_url, fpath) 67 | 68 | print("Extracting files") 69 | zip_ref = zipfile.ZipFile(fpath, 'r') 70 | zip_ref.extractall(self.dataset_dir) 71 | zip_ref.close() 72 | 73 | def _check_before_run(self): 74 | """Check if all files are available before going deeper""" 75 | if not osp.exists(self.dataset_dir): 76 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 77 | if not osp.exists(self.train_dir): 78 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 79 | if not osp.exists(self.query_dir): 80 | raise RuntimeError("'{}' is not available".format(self.query_dir)) 81 | if not osp.exists(self.gallery_dir): 82 | raise RuntimeError("'{}' is not available".format(self.gallery_dir)) 83 | 84 | def _process_dir(self, dir_path, relabel=False): 85 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 86 | pattern = re.compile(r'([-\d]+)_c(\d)') 87 | 88 | pid_container = set() 89 | for img_path in img_paths: 90 | pid, _ = map(int, pattern.search(img_path).groups()) 91 | pid_container.add(pid) 92 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 93 | 94 | dataset = [] 95 | cam_container = set() 96 | for img_path in img_paths: 97 | pid, camid = map(int, pattern.search(img_path).groups()) 98 | 
assert 1 <= camid <= 8 99 | camid -= 1 # index starts from 0 100 | if relabel: pid = pid2label[pid] 101 | dataset.append((img_path, self.pid_begin + pid, camid, 1)) 102 | cam_container.add(camid) 103 | print(cam_container, 'cam_container') 104 | return dataset 105 | -------------------------------------------------------------------------------- /datasets/make_dataloader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.transforms as T 3 | from .preprocessing import RandomGrayscalePatchReplacement 4 | from torch.utils.data import DataLoader 5 | 6 | from .bases import ImageDataset 7 | from timm.data.random_erasing import RandomErasing 8 | from .sampler import RandomIdentitySampler, RandomIdentityBatchSampler 9 | from .dukemtmcreid import DukeMTMCreID 10 | from .market1501 import Market1501 11 | from .msmt17 import MSMT17 12 | from .cuhk03 import CUHK03 13 | from .sampler_ddp import RandomIdentitySampler_DDP, RandomIdentityBatchSampler_DDP 14 | import torch.distributed as dist 15 | from .occ_duke import OCC_DukeMTMCreID 16 | from .vehicleid import VehicleID 17 | from .veri import VeRi 18 | __factory = { 19 | 'market1501': Market1501, 20 | 'dukemtmc': DukeMTMCreID, 21 | 'msmt17': MSMT17, 22 | 'occ_duke': OCC_DukeMTMCreID, 23 | 'veri': VeRi, 24 | 'VehicleID': VehicleID, 25 | 'cuhk03':CUHK03, 26 | } 27 | 28 | def train_collate_fn(batch): 29 | """ 30 | # The input to collate_fn is a list whose length is the batch size; each element of the list is the result returned by __getitem__ 31 | """ 32 | imgs, pids, camids, viewids , _ = zip(*batch) 33 | pids = torch.tensor(pids, dtype=torch.int64) 34 | viewids = torch.tensor(viewids, dtype=torch.int64) 35 | camids = torch.tensor(camids, dtype=torch.int64) 36 | return torch.stack(imgs, dim=0), pids, camids, viewids, 37 | 38 | def val_collate_fn(batch): 39 | imgs, pids, camids, viewids, img_paths = zip(*batch) 40 | viewids = torch.tensor(viewids, dtype=torch.int64) 41 | camids_batch = torch.tensor(camids,
dtype=torch.int64) 42 | return torch.stack(imgs, dim=0), pids, camids, camids_batch, viewids, img_paths 43 | 44 | def make_dataloader(cfg): 45 | train_transforms = T.Compose([ 46 | T.Resize(cfg.INPUT.SIZE_TRAIN, interpolation=3), 47 | RandomGrayscalePatchReplacement(probability=cfg.INPUT.GS_PROB), 48 | T.RandomHorizontalFlip(p=cfg.INPUT.PROB), 49 | T.Pad(cfg.INPUT.PADDING), 50 | T.RandomCrop(cfg.INPUT.SIZE_TRAIN), 51 | T.ToTensor(), 52 | T.Normalize(mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD), 53 | RandomErasing(probability=cfg.INPUT.RE_PROB, mode='pixel', max_count=1, device='cpu'), 54 | # RandomErasing(probability=cfg.INPUT.RE_PROB, mean=cfg.INPUT.PIXEL_MEAN) 55 | ]) 56 | 57 | val_transforms = T.Compose([ 58 | T.Resize(cfg.INPUT.SIZE_TEST), 59 | T.ToTensor(), 60 | T.Normalize(mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD) 61 | ]) 62 | 63 | num_workers = cfg.DATALOADER.NUM_WORKERS 64 | 65 | dataset = __factory[cfg.DATASETS.NAMES](root=cfg.DATASETS.ROOT_DIR) 66 | 67 | train_set = ImageDataset(dataset.train, train_transforms) 68 | train_set_normal = ImageDataset(dataset.train, val_transforms) 69 | num_classes = dataset.num_train_pids 70 | cam_num = dataset.num_train_cams 71 | view_num = dataset.num_train_vids 72 | 73 | if 'triplet' in cfg.DATALOADER.SAMPLER: 74 | if cfg.MODEL.DIST_TRAIN: 75 | print('DIST_TRAIN START') 76 | mini_batch_size = cfg.SOLVER.IMS_PER_BATCH // dist.get_world_size() 77 | if cfg.MODEL.ID_HARD_MINING: 78 | data_sampler = RandomIdentityBatchSampler_DDP(dataset.train, cfg.SOLVER.IMS_PER_BATCH, cfg.DATALOADER.NUM_INSTANCE, num_classes) 79 | else: 80 | data_sampler = RandomIdentitySampler_DDP(dataset.train, cfg.SOLVER.IMS_PER_BATCH, cfg.DATALOADER.NUM_INSTANCE) 81 | batch_sampler = torch.utils.data.sampler.BatchSampler(data_sampler, mini_batch_size, True) 82 | train_loader = torch.utils.data.DataLoader( 83 | train_set, 84 | num_workers=num_workers, 85 | batch_sampler=batch_sampler, 86 | collate_fn=train_collate_fn, 87 | 
pin_memory=True, 88 | ) 89 | else: 90 | if cfg.MODEL.ID_HARD_MINING: 91 | train_loader = DataLoader( 92 | train_set, batch_size=cfg.SOLVER.IMS_PER_BATCH, 93 | sampler=RandomIdentityBatchSampler(dataset.train, cfg.SOLVER.IMS_PER_BATCH, cfg.DATALOADER.NUM_INSTANCE, num_classes), 94 | num_workers=num_workers, collate_fn=train_collate_fn 95 | ) 96 | else: 97 | train_loader = DataLoader( 98 | train_set, batch_size=cfg.SOLVER.IMS_PER_BATCH, 99 | sampler=RandomIdentitySampler(dataset.train, cfg.SOLVER.IMS_PER_BATCH, cfg.DATALOADER.NUM_INSTANCE), 100 | num_workers=num_workers, collate_fn=train_collate_fn 101 | ) 102 | elif cfg.DATALOADER.SAMPLER == 'softmax': 103 | print('using softmax sampler') 104 | train_loader = DataLoader( 105 | train_set, batch_size=cfg.SOLVER.IMS_PER_BATCH, shuffle=True, num_workers=num_workers, 106 | collate_fn=train_collate_fn 107 | ) 108 | else: 109 | print('unsupported sampler! expected softmax or triplet but got {}'.format(cfg.SAMPLER)) 110 | 111 | val_set = ImageDataset(dataset.query + dataset.gallery, val_transforms) 112 | 113 | val_loader = DataLoader( 114 | val_set, batch_size=cfg.TEST.IMS_PER_BATCH, shuffle=False, num_workers=num_workers, 115 | collate_fn=val_collate_fn 116 | ) 117 | train_loader_normal = DataLoader( 118 | train_set_normal, batch_size=cfg.TEST.IMS_PER_BATCH, shuffle=False, num_workers=num_workers, 119 | collate_fn=val_collate_fn 120 | ) 121 | return train_loader, train_loader_normal, val_loader, len(dataset.query), num_classes, cam_num, view_num 122 | -------------------------------------------------------------------------------- /datasets/market1501.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | 3 | import glob 4 | import re 5 | 6 | import os.path as osp 7 | 8 | from .bases import BaseImageDataset 9 | from collections import defaultdict 10 | import pickle 11 | class Market1501(BaseImageDataset): 12 | """ 13 | Market1501 14 | Reference: 15 | Zheng et al. 
Scalable Person Re-identification: A Benchmark. ICCV 2015. 16 | URL: http://www.liangzheng.org/Project/project_reid.html 17 | 18 | Dataset statistics: 19 | # identities: 1501 (+1 for background) 20 | # images: 12936 (train) + 3368 (query) + 15913 (gallery) 21 | """ 22 | dataset_dir = 'market1501' 23 | 24 | def __init__(self, root='', verbose=True, pid_begin = 0, **kwargs): 25 | super(Market1501, self).__init__() 26 | self.dataset_dir = osp.join(root, self.dataset_dir) 27 | self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train') 28 | self.query_dir = osp.join(self.dataset_dir, 'query') 29 | self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test') 30 | 31 | self._check_before_run() 32 | self.pid_begin = pid_begin 33 | train = self._process_dir(self.train_dir, relabel=True) 34 | query = self._process_dir(self.query_dir, relabel=False) 35 | gallery = self._process_dir(self.gallery_dir, relabel=False) 36 | 37 | if verbose: 38 | print("=> Market1501 loaded") 39 | self.print_dataset_statistics(train, query, gallery) 40 | 41 | self.train = train 42 | self.query = query 43 | self.gallery = gallery 44 | 45 | self.num_train_pids, self.num_train_imgs, self.num_train_cams, self.num_train_vids = self.get_imagedata_info(self.train) 46 | self.num_query_pids, self.num_query_imgs, self.num_query_cams, self.num_query_vids = self.get_imagedata_info(self.query) 47 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams, self.num_gallery_vids = self.get_imagedata_info(self.gallery) 48 | 49 | def _check_before_run(self): 50 | """Check if all files are available before going deeper""" 51 | if not osp.exists(self.dataset_dir): 52 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 53 | if not osp.exists(self.train_dir): 54 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 55 | if not osp.exists(self.query_dir): 56 | raise RuntimeError("'{}' is not available".format(self.query_dir)) 57 | if not 
def _process_dir(self, dir_path, relabel=False):
    """Build the ``(img_path, pid, camid, viewid)`` records of one split directory.

    File names are expected to look like ``<pid>_c<camid>...jpg``.  Images
    whose pid is ``-1`` (junk detections) are dropped.

    Args:
        dir_path: Directory that holds the ``*.jpg`` files of the split.
        relabel: When true, raw pids are mapped to contiguous labels
            ``0..N-1`` in ascending pid order.

    Returns:
        A list of ``(img_path, self.pid_begin + pid, camid - 1, 1)`` tuples,
        ordered by image path.

    Raises:
        RuntimeError: If a ``.jpg`` file name does not carry a pid / camera id.
    """
    pattern = re.compile(r'([-\d]+)_c(\d)')

    # Parse every file exactly once.  The pattern is applied to the file name
    # only: searching the whole path would pick up a parent directory that
    # happens to look like "<digits>_c<digit>" and mislabel every image.
    parsed = []
    for img_path in sorted(glob.glob(osp.join(dir_path, '*.jpg'))):
        match = pattern.search(osp.basename(img_path))
        if match is None:
            raise RuntimeError(
                "'{}' does not follow the '<pid>_c<camid>' naming scheme".format(img_path))
        pid, camid = map(int, match.groups())
        if pid == -1:
            continue  # junk images are just ignored
        parsed.append((img_path, pid, camid))

    # Sorting makes the pid -> label mapping independent of set iteration order.
    pid2label = {pid: label
                 for label, pid in enumerate(sorted({pid for _, pid, _ in parsed}))}

    dataset = []
    for img_path, pid, camid in parsed:
        assert 0 <= pid <= 1501  # pid == 0 means background
        assert 1 <= camid <= 6
        if relabel:
            pid = pid2label[pid]
        # camera index starts from 0 in the emitted record
        dataset.append((img_path, self.pid_begin + pid, camid - 1, 1))
    return dataset
def __init__(self, root='', verbose=True, pid_begin=0, **kwargs):
    """Index the MSMT17 train (+val) / query / gallery splits found under ``root``.

    Args:
        root: Directory that contains the ``MSMT17_V1`` dataset folder.
        verbose: When true, print a banner followed by the dataset statistics.
        pid_begin: Stored on the instance before the list files are processed.
        **kwargs: Accepted for signature compatibility; not used here.
    """
    super(MSMT17, self).__init__()
    self.pid_begin = pid_begin

    base_dir = osp.join(root, self.dataset_dir)
    self.dataset_dir = base_dir
    self.train_dir = osp.join(base_dir, 'train')
    self.test_dir = osp.join(base_dir, 'test')
    self.list_train_path = osp.join(base_dir, 'list_train.txt')
    self.list_val_path = osp.join(base_dir, 'list_val.txt')
    self.list_query_path = osp.join(base_dir, 'list_query.txt')
    self.list_gallery_path = osp.join(base_dir, 'list_gallery.txt')

    self._check_before_run()

    # The validation list lives in the train directory and is folded into
    # the training records.
    train_records = self._process_dir(self.train_dir, self.list_train_path)
    val_records = self._process_dir(self.train_dir, self.list_val_path)
    train_records += val_records
    query_records = self._process_dir(self.test_dir, self.list_query_path)
    gallery_records = self._process_dir(self.test_dir, self.list_gallery_path)

    if verbose:
        print("=> MSMT17 loaded")
        self.print_dataset_statistics(train_records, query_records, gallery_records)

    self.train, self.query, self.gallery = train_records, query_records, gallery_records

    # Publishes num_<split>_pids / _imgs / _cams / _vids for every split.
    for split in ('train', 'query', 'gallery'):
        n_pids, n_imgs, n_cams, n_vids = self.get_imagedata_info(getattr(self, split))
        setattr(self, 'num_{}_pids'.format(split), n_pids)
        setattr(self, 'num_{}_imgs'.format(split), n_imgs)
        setattr(self, 'num_{}_cams'.format(split), n_cams)
        setattr(self, 'num_{}_vids'.format(split), n_vids)
def _process_dir(self, dir_path, list_path):
    """Build the ``(img_path, pid, camid, viewid)`` records listed in ``list_path``.

    Every non-blank line of the list file is ``<relative image path> <pid>``;
    the camera id is the third ``_``-separated field of the image file name.

    Args:
        dir_path: Directory the relative image paths are joined onto.
        list_path: Text file with one ``<path> <pid>`` entry per line.

    Returns:
        A list of ``(img_path, self.pid_begin + pid, camid - 1, 1)`` tuples in
        list-file order.

    Raises:
        AssertionError: If the pids of the list are not exactly ``0..N-1``.
    """
    with open(list_path, 'r') as txt:
        lines = txt.readlines()

    dataset = []
    pid_container = set()
    cam_container = set()
    for img_info in lines:
        fields = img_info.split()
        if not fields:
            continue  # tolerate blank / trailing empty lines
        img_path, pid = fields
        pid = int(pid)  # no need to relabel
        camid = int(osp.basename(img_path).split('_')[2])
        dataset.append((osp.join(dir_path, img_path), self.pid_begin + pid, camid - 1, 1))
        pid_container.add(pid)
        cam_container.add(camid)
    print(cam_container, 'cam_container')

    # pids are used as class labels without relabelling, so they must start at
    # 0 and increase by 1.  Compare against the sorted pids: the iteration
    # order of a set is not a guaranteed ascending order.
    if sorted(pid_container) != list(range(len(pid_container))):
        raise AssertionError(
            "pids in '{}' must be contiguous and start at 0".format(list_path))
    return dataset
def __init__(self, root='', verbose=True, pid_begin=0, **kwargs):
    """Index the Occluded-Duke train / query / gallery splits found under ``root``.

    Args:
        root: Directory that contains the ``Occluded_Duke`` dataset folder.
        verbose: When true, print a banner followed by the dataset statistics.
        pid_begin: Stored on the instance before the splits are processed.
        **kwargs: Accepted for signature compatibility; not used here.
    """
    super(OCC_DukeMTMCreID, self).__init__()

    base_dir = osp.join(root, self.dataset_dir)
    self.dataset_dir = base_dir
    self.dataset_url = 'http://vision.cs.duke.edu/DukeMTMC/data/misc/DukeMTMC-reID.zip'
    self.train_dir = osp.join(base_dir, 'bounding_box_train')
    self.query_dir = osp.join(base_dir, 'query')
    self.gallery_dir = osp.join(base_dir, 'bounding_box_test')
    self.pid_begin = pid_begin

    # Download step first, then the directory check.
    self._download_data()
    self._check_before_run()

    # Only the training split is relabelled.
    train_records = self._process_dir(self.train_dir, relabel=True)
    query_records = self._process_dir(self.query_dir, relabel=False)
    gallery_records = self._process_dir(self.gallery_dir, relabel=False)

    if verbose:
        print("=> DukeMTMC-reID loaded")
        self.print_dataset_statistics(train_records, query_records, gallery_records)

    self.train, self.query, self.gallery = train_records, query_records, gallery_records

    # Publishes num_<split>_pids / _imgs / _cams / _vids for every split.
    for split in ('train', 'query', 'gallery'):
        n_pids, n_imgs, n_cams, n_vids = self.get_imagedata_info(getattr(self, split))
        setattr(self, 'num_{}_pids'.format(split), n_pids)
        setattr(self, 'num_{}_imgs'.format(split), n_imgs)
        setattr(self, 'num_{}_cams'.format(split), n_cams)
        setattr(self, 'num_{}_vids'.format(split), n_vids)
def _process_dir(self, dir_path, relabel=False):
    """Build the ``(img_path, pid, camid, viewid)`` records of one split directory.

    File names are expected to look like ``<pid>_c<camid>...jpg``.

    Args:
        dir_path: Directory that holds the ``*.jpg`` files of the split.
        relabel: When true, raw pids are mapped to contiguous labels
            ``0..N-1`` in ascending pid order.

    Returns:
        A list of ``(img_path, self.pid_begin + pid, camid - 1, 1)`` tuples,
        ordered by image path.

    Raises:
        RuntimeError: If a ``.jpg`` file name does not carry a pid / camera id.
    """
    pattern = re.compile(r'([-\d]+)_c(\d)')

    # Parse every file exactly once, in sorted order so the output does not
    # depend on the file-system listing order.  The pattern is applied to the
    # file name only: a parent directory that looks like "<digits>_c<digit>"
    # would otherwise be matched first and mislabel every image.
    parsed = []
    for img_path in sorted(glob.glob(osp.join(dir_path, '*.jpg'))):
        match = pattern.search(osp.basename(img_path))
        if match is None:
            raise RuntimeError(
                "'{}' does not follow the '<pid>_c<camid>' naming scheme".format(img_path))
        pid, camid = map(int, match.groups())
        parsed.append((img_path, pid, camid))

    # Sorting makes the pid -> label mapping independent of set iteration order.
    pid2label = {pid: label
                 for label, pid in enumerate(sorted({pid for _, pid, _ in parsed}))}

    dataset = []
    cam_container = set()
    for img_path, pid, camid in parsed:
        assert 1 <= camid <= 8
        camid -= 1  # index starts from 0
        if relabel:
            pid = pid2label[pid]
        dataset.append((img_path, self.pid_begin + pid, camid, 1))
        cam_container.add(camid)
    print(cam_container, 'cam_container')
    return dataset
def __call__(self, img):
    """Overwrite one random rectangle of ``img`` in place with ``self.mean``.

    With probability ``1 - self.probability`` the image is returned untouched.
    Otherwise up to 100 rectangles are drawn (area fraction in
    ``[self.sl, self.sh]``, aspect ratio in ``[self.r1, 1 / self.r1]``); the
    first one that fits strictly inside the image is filled and the image is
    returned.  If none fits, the image is returned unchanged.

    Args:
        img: Object whose ``size()`` is indexed as ``[0]``, ``[1]``, ``[2]``
            and that supports slice assignment -- presumably a
            ``(channels, height, width)`` tensor.

    Returns:
        The same ``img`` object.
    """
    if random.uniform(0, 1) >= self.probability:
        return img

    channels, height, width = img.size()[0], img.size()[1], img.size()[2]
    full_area = height * width

    for _ in range(100):
        erase_area = random.uniform(self.sl, self.sh) * full_area
        ratio = random.uniform(self.r1, 1 / self.r1)
        erase_h = int(round(math.sqrt(erase_area * ratio)))
        erase_w = int(round(math.sqrt(erase_area / ratio)))

        # Rectangle must be strictly smaller than the image; otherwise retry.
        if erase_w >= width or erase_h >= height:
            continue

        top = random.randint(0, height - erase_h)
        left = random.randint(0, width - erase_w)
        # Three-channel inputs get a per-channel fill value; anything else
        # only has its first channel filled.
        touched = (0, 1, 2) if channels == 3 else (0,)
        for channel in touched:
            img[channel, top:top + erase_h, left:left + erase_w] = self.mean[channel]
        return img

    return img
class RandomIdentitySampler(Sampler):
    """Yield dataset indices grouped as N identities x K instances per batch.

    Args:
        data_source (list): records whose second field is the identity (pid);
            each record is unpacked as a 4-tuple.
        batch_size (int): number of examples in a batch (N * K).
        num_instances (int): number of instances per identity in a batch (K).
    """

    def __init__(self, data_source, batch_size, num_instances):
        self.data_source = data_source
        self.batch_size = batch_size
        self.num_instances = num_instances
        self.num_pids_per_batch = self.batch_size // self.num_instances

        # pid -> positions of its records, e.g. {783: [0, 5, 116, 876], ...}
        self.index_dic = defaultdict(list)
        for position, record in enumerate(self.data_source):
            _, pid, _, _ = record
            self.index_dic[pid].append(position)
        self.pids = list(self.index_dic.keys())

        # Estimated epoch size: every identity contributes its image count,
        # raised to at least K and then rounded down to a multiple of K.
        self.length = 0
        for pid in self.pids:
            count = max(len(self.index_dic[pid]), self.num_instances)
            self.length += count - count % self.num_instances

    def __iter__(self):
        per_identity = self.num_instances
        chunks_by_pid = defaultdict(list)

        for pid in self.pids:
            pool = copy.deepcopy(self.index_dic[pid])
            if len(pool) < per_identity:
                # Too few images: sample with replacement up to K.
                pool = np.random.choice(pool, size=per_identity, replace=True)
            random.shuffle(pool)
            # Cut the shuffled pool into full groups of K; the remainder is dropped.
            usable = len(pool) - len(pool) % per_identity
            for start in range(0, usable, per_identity):
                chunks_by_pid[pid].append(list(pool[start:start + per_identity]))

        remaining_pids = copy.deepcopy(self.pids)
        ordered = []
        while len(remaining_pids) >= self.num_pids_per_batch:
            for pid in random.sample(remaining_pids, self.num_pids_per_batch):
                ordered.extend(chunks_by_pid[pid].pop(0))
                if not chunks_by_pid[pid]:
                    remaining_pids.remove(pid)

        return iter(ordered)

    def __len__(self):
        return self.length
def __init__(self, data_source, batch_size, num_instances, num_classes, feat_dim=768):
    """Group the records by identity and allocate the per-class weight table.

    Args:
        data_source (list): records whose second field is the identity (pid);
            each record is unpacked as a 4-tuple.
        batch_size (int): number of examples in a batch.
        num_instances (int): number of instances per identity in a batch.
        num_classes (int): number of rows of the initial weight table.
        feat_dim (int): number of columns of the initial weight table.
            Defaults to 768, the value that used to be hard-coded.
    """
    self.data_source = data_source
    self.batch_size = batch_size
    self.num_instances = num_instances
    self.num_pids_per_batch = self.batch_size // self.num_instances

    # pid -> positions of its records, e.g. {783: [0, 5, 116, 876], ...}
    self.index_dic = defaultdict(list)
    for index, (_, pid, _, _) in enumerate(self.data_source):
        self.index_dic[pid].append(index)
    self.pids = list(self.index_dic.keys())

    # All-zero placeholder; update_weight() replaces it.
    self.weight = torch.zeros((num_classes, feat_dim))

    # Estimated epoch size: every identity contributes its image count,
    # raised to at least num_instances and rounded down to a multiple of it.
    self.length = 0
    for pid in self.pids:
        num = max(len(self.index_dic[pid]), self.num_instances)
        self.length += num - num % self.num_instances
def _get_global_gloo_group():
    """Return a gloo-backed process group that contains all ranks.

    Under the NCCL backend a dedicated gloo group is created; otherwise the
    default world group is returned.  The result is cached on the function
    object, so repeated calls reuse one group instead of creating a new
    process group every time.

    NOTE(review): the cache is never invalidated; if the default process group
    is destroyed and re-initialised within one process the cached group is
    stale -- confirm that never happens in this project.
    """
    group = getattr(_get_global_gloo_group, "_cached_group", None)
    if group is None:
        if dist.get_backend() == "nccl":
            group = dist.new_group(backend="gloo")
        else:
            group = dist.group.WORLD
        _get_global_gloo_group._cached_group = group
    return group
def all_gather(data, group=None):
    """Gather an arbitrary picklable object from every rank.

    Args:
        data: any picklable object.
        group: a torch process group.  When omitted, the group returned by
            ``_get_global_gloo_group()`` is used.

    Returns:
        list: one entry per rank; ``[data]`` when only one rank takes part.
    """
    # Single-process shortcuts: nothing to exchange.
    if dist.get_world_size() == 1:
        return [data]
    if group is None:
        group = _get_global_gloo_group()
    if dist.get_world_size(group) == 1:
        return [data]

    payload = _serialize_to_tensor(data, group)
    sizes, payload = _pad_to_largest_tensor(payload, group)
    widest = max(sizes)

    # One receive buffer per rank, all sized to the widest payload.
    received = [
        torch.empty((widest,), dtype=torch.uint8, device=payload.device)
        for _ in sizes
    ]
    dist.all_gather(received, payload, group=group)

    # Strip the padding of every buffer before unpickling it.
    return [
        pickle.loads(chunk.cpu().numpy().tobytes()[:size])
        for size, chunk in zip(sizes, received)
    ]
class RandomIdentitySampler_DDP(Sampler):
    """Identity-balanced sampler for distributed training.

    A global index list made of groups of ``num_instances`` images per
    identity is built identically on every rank (NumPy is seeded with a seed
    shared across ranks), then each rank keeps its own mini-batches.

    Args:
        data_source (list): records whose second field is the identity (pid);
            each record is unpacked as a 4-tuple.
        batch_size (int): number of examples in a global batch (all ranks).
        num_instances (int): number of instances per identity in a batch.
    """

    def __init__(self, data_source, batch_size, num_instances):
        self.data_source = data_source
        self.batch_size = batch_size
        self.world_size = dist.get_world_size()
        self.num_instances = num_instances
        self.mini_batch_size = self.batch_size // self.world_size
        self.num_pids_per_batch = self.mini_batch_size // self.num_instances
        self.index_dic = defaultdict(list)

        for index, (_, pid, _, _) in enumerate(self.data_source):
            self.index_dic[pid].append(index)
        self.pids = list(self.index_dic.keys())

        # Estimated epoch size: every identity contributes its image count,
        # raised to at least num_instances and rounded down to a multiple of it.
        self.length = 0
        for pid in self.pids:
            num = max(len(self.index_dic[pid]), self.num_instances)
            self.length += num - num % self.num_instances

        self.rank = dist.get_rank()
        self.length //= self.world_size

    def __iter__(self):
        # Same seed on every rank, so sample_list() builds the same global list.
        seed = shared_random_seed()
        np.random.seed(seed)
        self._seed = int(seed)
        final_idxs = self.sample_list()
        length = int(math.ceil(len(final_idxs) * 1.0 / self.world_size))
        final_idxs = self.__fetch_current_node_idxs(final_idxs, length)
        self.length = len(final_idxs)
        return iter(final_idxs)

    def __fetch_current_node_idxs(self, final_idxs, length):
        """Keep the mini-batches of the global list that belong to this rank.

        The global list is treated as consecutive blocks of
        ``mini_batch_size``; this rank takes block ``rank``, then every
        ``world_size``-th block after it, for ``length // mini_batch_size``
        blocks.
        """
        total_num = len(final_idxs)
        block_num = length // self.mini_batch_size
        index_target = []
        for i in range(0, block_num * self.world_size, self.world_size):
            start = self.mini_batch_size * (self.rank + i)
            stop = min(start + self.mini_batch_size, total_num)
            index_target.extend(range(start, stop))
        # Explicit integer dtype: an empty selection would otherwise become a
        # float array, which NumPy refuses to use as an index.
        index_target_npy = np.array(index_target, dtype=np.int64)
        return list(np.array(final_idxs)[index_target_npy])

    def sample_list(self):
        """Build the global index list from NumPy's global RNG."""
        avai_pids = copy.deepcopy(self.pids)
        batch_idxs_dict = {}

        batch_indices = []
        while len(avai_pids) >= self.num_pids_per_batch:
            selected_pids = np.random.choice(
                avai_pids, self.num_pids_per_batch, replace=False).tolist()
            for pid in selected_pids:
                if pid not in batch_idxs_dict:
                    # First visit: shuffle this identity's images, sampling
                    # with replacement when it has fewer than num_instances.
                    idxs = copy.deepcopy(self.index_dic[pid])
                    if len(idxs) < self.num_instances:
                        idxs = np.random.choice(
                            idxs, size=self.num_instances, replace=True).tolist()
                    np.random.shuffle(idxs)
                    batch_idxs_dict[pid] = idxs

                avai_idxs = batch_idxs_dict[pid]
                for _ in range(self.num_instances):
                    batch_indices.append(avai_idxs.pop(0))

                # Retire the identity once it cannot fill another group.
                if len(avai_idxs) < self.num_instances:
                    avai_pids.remove(pid)

        return batch_indices

    def __len__(self):
        return self.length
def __fetch_current_node_idxs(self, final_idxs, length):
    """Keep the mini-batches of the global index list that belong to this rank.

    The global list is treated as consecutive blocks of ``mini_batch_size``;
    this rank takes block ``rank``, then every ``world_size``-th block after
    it, for ``length // mini_batch_size`` blocks.

    Args:
        final_idxs: Global list of sample indices, identical on every rank.
        length: Number of indices intended for one rank.

    Returns:
        list: the selected indices; empty when not even one full mini-batch
        fits into ``length``.
    """
    total_num = len(final_idxs)
    block_num = length // self.mini_batch_size
    index_target = []
    for i in range(0, block_num * self.world_size, self.world_size):
        start = self.mini_batch_size * (self.rank + i)
        stop = min(start + self.mini_batch_size, total_num)
        index_target.extend(range(start, stop))
    # Explicit integer dtype: an empty selection would otherwise become a
    # float array, which NumPy refuses to use as an index.
    index_target_npy = np.array(index_target, dtype=np.int64)
    return list(np.array(final_idxs)[index_target_npy])
def sample_list(self):
    """Build the global index list, visiting identities by similarity to an anchor.

    A random anchor class is drawn, all classes are ranked by the cosine
    similarity of their row in ``self.weight`` to the anchor's row, and
    batches are filled from the most similar identities first, taking
    ``num_instances`` images per identity and batch.

    All randomness comes from NumPy's global RNG.  ``__iter__`` seeds that RNG
    with a seed shared across ranks, so every rank builds the same list; the
    anchor used to come from Python's ``random`` module, which is not covered
    by that seed and let ranks pick different anchors.

    Returns:
        list: sample indices in batch order.
    """
    anchor = int(np.random.randint(0, len(self.weight)))
    similarity = torch.cosine_similarity(self.weight[anchor], self.weight, dim=-1)
    _, simi_idx = torch.sort(similarity, dim=-1, descending=True)
    # tolist() works for tensors on any device (numpy() does not).  Class
    # indices that own no images are skipped: they cannot fill a batch.
    avai_pids = [pid for pid in simi_idx.tolist() if pid in self.index_dic]

    batch_idxs_dict = {}
    batch_indices = []
    while len(avai_pids) >= self.num_pids_per_batch:
        # Slice copy, so removing from avai_pids below is safe.
        selected_pids = avai_pids[:self.num_pids_per_batch]
        for pid in selected_pids:
            if pid not in batch_idxs_dict:
                # First visit: shuffle this identity's images, sampling with
                # replacement when it has fewer than num_instances.
                idxs = copy.deepcopy(self.index_dic[pid])
                if len(idxs) < self.num_instances:
                    idxs = np.random.choice(
                        idxs, size=self.num_instances, replace=True).tolist()
                np.random.shuffle(idxs)
                batch_idxs_dict[pid] = idxs

            avai_idxs = batch_idxs_dict[pid]
            for _ in range(self.num_instances):
                batch_indices.append(avai_idxs.pop(0))

            # Retire the identity once it cannot fill another group.
            if len(avai_idxs) < self.num_instances:
                avai_pids.remove(pid)

    return batch_indices
def __init__(self, root='', verbose=True, test_size=800, **kwargs):
    """Index the VehicleID train / query / gallery splits found under ``root``.

    Args:
        root: Directory that contains the ``VehicleID_V1.0`` dataset folder.
        verbose: When true, print a banner followed by the dataset statistics.
        test_size: Which ``test_list_<size>.txt`` split file to evaluate on.
        **kwargs: Accepted for signature compatibility; not used here.
    """
    super(VehicleID, self).__init__()
    self.dataset_dir = osp.join(root, self.dataset_dir)
    self.img_dir = osp.join(self.dataset_dir, 'image')
    self.split_dir = osp.join(self.dataset_dir, 'train_test_split')
    self.train_list = osp.join(self.split_dir, 'train_list.txt')
    self.test_size = test_size

    # Always define test_list.  Previously it was only set for 800 / 1600 /
    # 2400, so any other size failed with AttributeError on the print below
    # instead of reaching the explicit checks in check_before_run().
    self.test_list = osp.join(
        self.split_dir, 'test_list_{}.txt'.format(self.test_size))

    print(self.test_list)

    self.check_before_run()

    train, query, gallery = self.process_split(relabel=True)
    self.train = train
    self.query = query
    self.gallery = gallery

    if verbose:
        print('=> VehicleID loaded')
        self.print_dataset_statistics(train, query, gallery)

    # Publishes num_<split>_pids / _imgs / _cams / _vids for every split.
    (self.num_train_pids, self.num_train_imgs,
     self.num_train_cams, self.num_train_vids) = self.get_imagedata_info(self.train)
    (self.num_query_pids, self.num_query_imgs,
     self.num_query_cams, self.num_query_vids) = self.get_imagedata_info(self.query)
    (self.num_gallery_pids, self.num_gallery_imgs,
     self.num_gallery_cams, self.num_gallery_vids) = self.get_imagedata_info(self.gallery)
def parse_img_pids(self, nl_pairs, pid2label=None, cam=0):
    """Turn ``(image name, pid)`` pairs into ``(img_path, pid, camid, viewid)`` records.

    Args:
        nl_pairs: Iterable of pairs; item 0 is the image name without
            extension, item 1 is the pid.
        pid2label: Optional mapping applied to every pid.
        cam: Camera id written into every record.

    Returns:
        list: one ``(img_path, pid, cam, 1)`` tuple per input pair, where
        ``img_path`` is ``<self.img_dir>/<name>.jpg``.
    """
    records = []
    for pair in nl_pairs:
        name, raw_pid = pair[0], pair[1]
        label = raw_pid if pid2label is None else pid2label[raw_pid]
        # The view id is the constant 1 for every record.
        records.append((osp.join(self.img_dir, name + '.jpg'), label, cam, 1))
    return records
class VeRi(BaseImageDataset):
    """
    VeRi-776
    Reference:
    Liu, Xinchen, et al. "Large-scale vehicle re-identification in urban surveillance videos." ICME 2016.

    URL:https://vehiclereid.github.io/VeRi/

    Dataset statistics:
    # identities: 776
    # images: 37778 (train) + 1678 (query) + 11579 (gallery)
    # cameras: 20

    Every sample is a tuple ``(img_path, pid, camid, viewid)``. Images that
    have no entry in either keypoint annotation file are dropped.
    """

    dataset_dir = 'VeRi'

    def __init__(self, root='', verbose=True, **kwargs):
        """Index the train/query/gallery folders found under ``root``.

        Args:
            root: directory that contains the ``VeRi`` folder.
            verbose: when True, print the dataset statistics after loading.
            **kwargs: accepted for signature compatibility and ignored.

        Raises:
            RuntimeError: if one of the expected dataset folders is missing.
        """
        super(VeRi, self).__init__()
        self.dataset_dir = osp.join(root, self.dataset_dir)
        self.train_dir = osp.join(self.dataset_dir, 'image_train')
        self.query_dir = osp.join(self.dataset_dir, 'image_query')
        self.gallery_dir = osp.join(self.dataset_dir, 'image_test')

        self._check_before_run()

        # The annotation paths are relative to the current working directory.
        self.image_map_view_train = self._load_view_map('datasets/keypoint_train.txt')
        self.image_map_view_test = self._load_view_map('datasets/keypoint_test.txt')

        train = self._process_dir(self.train_dir, relabel=True)
        query = self._process_dir(self.query_dir, relabel=False)
        gallery = self._process_dir(self.gallery_dir, relabel=False)

        if verbose:
            print("=> VeRi-776 loaded")
            self.print_dataset_statistics(train, query, gallery)

        self.train = train
        self.query = query
        self.gallery = gallery

        self.num_train_pids, self.num_train_imgs, self.num_train_cams, self.num_train_vids = self.get_imagedata_info(
            self.train)
        self.num_query_pids, self.num_query_imgs, self.num_query_cams, self.num_query_vids = self.get_imagedata_info(
            self.query)
        self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams, self.num_gallery_vids = self.get_imagedata_info(
            self.gallery)

    @staticmethod
    def _load_view_map(path):
        """Read an annotation file into ``{image file name: view id}``.

        Each non-blank line is split on single spaces; the first field is an
        image path (only its base name is kept) and the last field is the
        integer view id.
        """
        view_map = {}
        with open(path, 'r') as txt:
            for line in txt:
                if not line.strip():
                    continue  # tolerate blank lines, e.g. at the end of the file
                content = line.split(' ')
                view_map[osp.basename(content[0])] = int(content[-1])
        return view_map

    def _check_before_run(self):
        """Check if all files are available before going deeper"""
        if not osp.exists(self.dataset_dir):
            raise RuntimeError("'{}' is not available".format(self.dataset_dir))
        if not osp.exists(self.train_dir):
            raise RuntimeError("'{}' is not available".format(self.train_dir))
        if not osp.exists(self.query_dir):
            raise RuntimeError("'{}' is not available".format(self.query_dir))
        if not osp.exists(self.gallery_dir):
            raise RuntimeError("'{}' is not available".format(self.gallery_dir))

    def _process_dir(self, dir_path, relabel=False):
        """Collect ``(img_path, pid, camid, viewid)`` for every usable image.

        Args:
            dir_path: folder scanned for ``*.jpg`` files.
            relabel: when True, replace each pid by an index assigned by
                enumerating the set of pids found in this folder.

        Returns:
            list of sample tuples; images with pid ``-1`` or without a
            viewpoint annotation are skipped.
        """
        img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
        pattern = re.compile(r'([-\d]+)_c(\d+)')

        # Match on the file name only, so that digits in a parent directory
        # (e.g. ".../run_1_c2/...") cannot be mistaken for the pid/camid.
        parsed = []
        for img_path in img_paths:
            pid, camid = map(int, pattern.search(osp.basename(img_path)).groups())
            if pid == -1:
                continue  # junk images are just ignored
            parsed.append((img_path, pid, camid))

        pid_container = {pid for _, pid, _ in parsed}
        pid2label = {pid: label for label, pid in enumerate(pid_container)}

        view_container = set()
        dataset = []
        count = 0
        for img_path, pid, camid in parsed:
            assert 0 <= pid <= 776  # pid == 0 means background
            assert 1 <= camid <= 20
            camid -= 1  # index starts from 0
            if relabel:
                pid = pid2label[pid]

            # Train annotations take precedence, then test annotations; an
            # explicit lookup replaces the former bare ``except:``.
            img_name = osp.basename(img_path)
            if img_name in self.image_map_view_train:
                viewid = self.image_map_view_train[img_name]
            elif img_name in self.image_map_view_test:
                viewid = self.image_map_view_test[img_name]
            else:
                count += 1
                continue
            view_container.add(viewid)
            dataset.append((img_path, pid, camid, viewid))
        print(view_container, 'view_container')
        print(count, 'samples without viewpoint annotations')
        return dataset
nn.init._calculate_fan_in_and_fan_out(self.weight) 32 | bound = 1 / math.sqrt(fan_in) 33 | nn.init.uniform_(self.bias, -bound, bound) 34 | 35 | def forward(self, input, label): 36 | cosine = F.linear(F.normalize(input), F.normalize(self.weight)) 37 | sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(0, 1)) 38 | phi = cosine * self.cos_m - sine * self.sin_m 39 | phi = torch.where(cosine > self.th, phi, cosine - self.mm) 40 | # --------------------------- convert label to one-hot --------------------------- 41 | # one_hot = torch.zeros(cosine.size(), requires_grad=True, device='cuda') 42 | one_hot = torch.zeros(cosine.size(), device='cuda') 43 | one_hot.scatter_(1, label.view(-1, 1).long(), 1) 44 | # -------------torch.where(out_i = {x_i if condition_i else y_i) ------------- 45 | output = (one_hot * phi) + ( 46 | (1.0 - one_hot) * cosine) # you can use torch.where if your torch.__version__ is 0.4 47 | output *= self.s 48 | # print(output) 49 | 50 | return output 51 | 52 | class CircleLoss(nn.Module): 53 | def __init__(self, in_features, num_classes, s=256, m=0.25): 54 | super(CircleLoss, self).__init__() 55 | self.weight = Parameter(torch.Tensor(num_classes, in_features)) 56 | self.s = s 57 | self.m = m 58 | self._num_classes = num_classes 59 | self.reset_parameters() 60 | 61 | 62 | def reset_parameters(self): 63 | nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) 64 | 65 | def __call__(self, bn_feat, targets): 66 | 67 | sim_mat = F.linear(F.normalize(bn_feat), F.normalize(self.weight)) 68 | alpha_p = torch.clamp_min(-sim_mat.detach() + 1 + self.m, min=0.) 69 | alpha_n = torch.clamp_min(sim_mat.detach() + self.m, min=0.) 
class CenterLoss(nn.Module):
    """Center loss.

    Reference:
    Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016.

    Args:
        num_classes (int): number of classes.
        feat_dim (int): feature dimension.
        use_gpu (bool): create the class centers on the CUDA device.
    """

    def __init__(self, num_classes=751, feat_dim=2048, use_gpu=True):
        super(CenterLoss, self).__init__()
        self.num_classes = num_classes
        self.feat_dim = feat_dim
        self.use_gpu = use_gpu

        if self.use_gpu:
            self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim).cuda())
        else:
            self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim))

    def forward(self, x, labels):
        """
        Args:
            x: feature matrix with shape (batch_size, feat_dim).
            labels: ground truth labels with shape (batch_size).

        Returns:
            scalar tensor: mean squared distance between each feature and the
            center of its own class, each distance clamped to [1e-12, 1e+12].
        """
        assert x.size(0) == labels.size(0), "features.size(0) is not equal to labels.size(0)"

        batch_size = x.size(0)
        # ||x||^2 + ||c||^2 for every (sample, class) pair ...
        distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \
                  torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t()
        # ... minus 2 * x.c. Keyword beta/alpha replace the deprecated
        # positional ``addmm_(1, -2, x, centers.t())`` overload.
        distmat = torch.addmm(distmat, x, self.centers.t(), beta=1, alpha=-2)

        # Boolean mask selecting, for each row, the column of its own class.
        classes = torch.arange(self.num_classes, device=labels.device).long()
        labels = labels.unsqueeze(1).expand(batch_size, self.num_classes)
        mask = labels.eq(classes.expand(batch_size, self.num_classes))

        # A masked select returns the same values, in the same row order, as
        # the former per-sample Python loop.
        dist = distmat[mask].clamp(min=1e-12, max=1e+12)  # for numerical stability
        loss = dist.mean()
        return loss
def cosine_dist(x, y):
    """Absolute cosine similarity between every row pair of two batches.

    Args:
        x: tensor with shape [B, m, d]
        y: tensor with shape [B, n, d]
    Returns:
        tensor with shape [B, m, n] holding |cos(x_i, y_j)| per batch item
    """
    batch, rows, cols = x.size(0), x.size(1), y.size(1)
    x_len = x.pow(2).sum(2, keepdim=True).sqrt()
    y_len = y.pow(2).sum(2, keepdim=True).sqrt()
    # Spread both norm columns over the full [B, m, n] grid before multiplying.
    denom = x_len.expand(batch, rows, cols) * y_len.expand(batch, cols, rows).transpose(-2, -1)
    similarity = torch.matmul(x, y.transpose(-2, -1)) / denom
    return similarity.abs()
def make_loss(cfg, num_classes):
    """Build the training loss callable selected by ``cfg``.

    Args:
        cfg: config node; reads ``DATALOADER.SAMPLER``, ``SOLVER.MARGIN`` and
            the ``MODEL.*`` loss switches and weights used below.
        num_classes: number of identity classes.

    Returns:
        ``(loss_func, center_criterion)``. For the ``softmax`` sampler
        ``loss_func(score, feat, target)`` returns a scalar cross entropy.
        For ``softmax_triplet`` ``loss_func(score, feat, target, target_cam)``
        returns ``[id_losses, triplet_losses, dissimilar_loss]`` where the
        first two are lists of weighted losses.

    Raises:
        ValueError: if ``cfg.DATALOADER.SAMPLER`` is not supported. This
            previously surfaced as an UnboundLocalError at the return.
    """
    sampler = cfg.DATALOADER.SAMPLER
    feat_dim = 2048
    center_criterion = CenterLoss(num_classes=num_classes, feat_dim=feat_dim, use_gpu=True)  # center loss
    dissimilar = Dissimilar(dynamic_balancer=cfg.MODEL.DYNAMIC_BALANCER)
    if 'triplet' in cfg.MODEL.METRIC_LOSS_TYPE:
        if cfg.MODEL.NO_MARGIN:
            triplet = TripletLoss()
            print("using soft triplet loss for training")
        else:
            triplet = TripletLoss(cfg.SOLVER.MARGIN)  # triplet loss
            print("using triplet loss with margin:{}".format(cfg.SOLVER.MARGIN))
    else:
        print('expected METRIC_LOSS_TYPE should be triplet'
              'but got {}'.format(cfg.MODEL.METRIC_LOSS_TYPE))

    if cfg.MODEL.IF_LABELSMOOTH == 'on':
        # NOTE(review): this criterion is never used below. The only branch
        # that calls ``xent`` requires temperature softmax, which replaces it.
        xent = CrossEntropyLabelSmooth(num_classes=num_classes)
        print("label smooth on, numclasses:", num_classes)
    if cfg.MODEL.IF_TEMPERATURE_SOFTMAX == 'on':
        xent = TemperatureCrossEntropy()
        print("temperature softmax on")

    if sampler == 'softmax':
        def loss_func(score, feat, target):
            return F.cross_entropy(score, target)

    elif cfg.DATALOADER.SAMPLER == 'softmax_triplet':
        def loss_func(score, feat, target, target_cam):
            if cfg.MODEL.METRIC_LOSS_TYPE != 'triplet':
                print('expected METRIC_LOSS_TYPE should be triplet'
                      'but got {}'.format(cfg.MODEL.METRIC_LOSS_TYPE))
                return None

            if cfg.MODEL.IF_TEMPERATURE_SOFTMAX == 'on':
                temperatures = cfg.MODEL.TEMPERATURE
                if isinstance(score, list):
                    # TemperatureCrossEntropy needs a temperature per call;
                    # the (score, label) tuple form used to omit it.
                    if isinstance(score[0], tuple):
                        ID_LOSS = [xent(scor, lbl, t) for (scor, lbl), t in zip(score, temperatures)]
                    else:
                        ID_LOSS = [xent(scor, target, t) for scor, t in zip(score, temperatures)]
                    ID_LOSS = sum(ID_LOSS) / len(ID_LOSS)
                else:
                    # NOTE(review): assumes a single-head score uses the first
                    # configured temperature - confirm against the config.
                    single_t = temperatures[0] if isinstance(temperatures, (list, tuple)) else temperatures
                    ID_LOSS = xent(score, target, single_t)
                # This branch used to leave ID_LOSSes undefined (NameError at
                # the return below).
                ID_LOSSes = [ID_LOSS]
            else:
                if isinstance(score, list):
                    if isinstance(score[0], tuple):
                        ID_LOSS = [F.cross_entropy(scor, lbl) for scor, lbl in score]
                    else:
                        ID_LOSS = [F.cross_entropy(scor, target) for scor in score]
                    ID_LOSSes = ID_LOSS
                else:
                    ID_LOSS = F.cross_entropy(score, target)
                    ID_LOSSes = [ID_LOSS]

            if isinstance(feat, list):
                TRI_LOSSes = [triplet(feats, target)[0] for feats in feat]
            else:
                TRI_LOSSes = [triplet(feat, target)[0]]

            # Only a list of several feature tensors can be stacked; for a
            # single tensor ``len(feat)`` is the batch size, not a head count.
            if isinstance(feat, list) and len(feat) > 1:
                Dissimilar_LOSS = dissimilar(torch.stack(feat, dim=1))
            else:
                Dissimilar_LOSS = 0

            return [[id_loss * cfg.MODEL.ID_LOSS_WEIGHT for id_loss in ID_LOSSes],
                    [tri_loss * cfg.MODEL.TRIPLET_LOSS_WEIGHT for tri_loss in TRI_LOSSes],
                    cfg.MODEL.DIVERSE_CLS_WEIGHT * Dissimilar_LOSS]

    else:
        raise ValueError('expected sampler should be softmax, triplet, softmax_triplet or softmax_triplet_center'
                         'but got {}'.format(cfg.DATALOADER.SAMPLER))
    return loss_func, center_criterion
class CircleLoss(nn.Module):
    """Cosine classifier head that produces margin-reweighted class logits.

    Args:
        in_features: size of each input feature vector.
        num_classes: number of classes (rows of the weight matrix).
        s: scale applied to the logits.
        m: margin used for both the weighting terms and the offsets.
    """

    def __init__(self, in_features, num_classes, s=256, m=0.25):
        super(CircleLoss, self).__init__()
        self.weight = Parameter(torch.Tensor(num_classes, in_features))
        self.s = s
        self.m = m
        self._num_classes = num_classes
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the class weights with Kaiming-uniform values."""
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    # Implemented as ``forward`` rather than overriding ``__call__`` so that
    # nn.Module's own ``__call__`` machinery (hooks) still runs.
    # ``module(bn_feat, targets)`` keeps working.
    def forward(self, bn_feat, targets):
        """Compute the logits for a batch.

        Args:
            bn_feat: features with shape (batch_size, in_features).
            targets: integer class labels with shape (batch_size), in a form
                accepted by ``F.one_hot``.

        Returns:
            tensor with shape (batch_size, num_classes).
        """
        sim_mat = F.linear(F.normalize(bn_feat), F.normalize(self.weight))
        # The weighting factors are detached: they scale the gradient but do
        # not receive one themselves.
        alpha_p = torch.clamp_min(-sim_mat.detach() + 1 + self.m, min=0.)
        alpha_n = torch.clamp_min(sim_mat.detach() + self.m, min=0.)
        delta_p = 1 - self.m
        delta_n = self.m

        s_p = self.s * alpha_p * (sim_mat - delta_p)
        s_n = self.s * alpha_n * (sim_mat - delta_n)

        targets = F.one_hot(targets, num_classes=self._num_classes)

        # Target class takes the positive branch, every other class the negative.
        pred_class_logits = targets * s_p + (1.0 - targets) * s_n

        return pred_class_logits
class Cosface(nn.Module):
    r"""Large margin cosine classifier head.

    Args:
        in_features: size of each input sample
        out_features: size of each output sample
        s: scale applied to the output logits
        m: margin subtracted from the target-class cosine
        cos(theta) - m
    """

    def __init__(self, in_features, out_features, s=30.0, m=0.30):
        super(Cosface, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.s = s
        self.m = m
        self.weight = Parameter(torch.FloatTensor(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)

    def forward(self, input, label):
        """Return scaled logits of shape (batch_size, out_features).

        Args:
            input: features with shape (batch_size, in_features).
            label: class indices with shape (batch_size).
        """
        # --------------------------- cos(theta) & phi(theta) ---------------------------
        cosine = F.linear(F.normalize(input), F.normalize(self.weight))
        phi = cosine - self.m
        # --------------------------- convert label to one-hot ---------------------------
        # Allocate on the logits' own device instead of a hard-coded 'cuda',
        # so the head also runs on CPU and on non-default GPUs.
        one_hot = torch.zeros(cosine.size(), device=cosine.device)
        one_hot.scatter_(1, label.view(-1, 1).long(), 1)
        # Margin-adjusted cosine for the target class, plain cosine elsewhere.
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output *= self.s

        return output

    def __repr__(self):
        return self.__class__.__name__ + '(' \
               + 'in_features=' + str(self.in_features) \
               + ', out_features=' + str(self.out_features) \
               + ', s=' + str(self.s) \
               + ', m=' + str(self.m) + ')'
class PartSoftmax(nn.Module):
    """Linear classifier over a random subset of the identity classes.

    Args:
        in_features: size of each input feature vector.
        out_features: total number of classes to sample from.
        ratio: fraction of classes kept; ``int(out_features * ratio)``
            class ids are drawn with ``random.sample`` and stored sorted in
            ``self.pids``.
    """

    def __init__(self, in_features, out_features, ratio):
        super(PartSoftmax, self).__init__()
        self.in_feats = in_features
        num_classes = int(out_features * ratio)
        self.pids = sorted(random.sample(range(out_features), num_classes))
        self.W = torch.nn.Parameter(torch.randn(in_features, num_classes), requires_grad=True)
        nn.init.normal_(self.W, std=0.001)

    def forward(self, x, lb):
        """Score the samples whose label belongs to the kept subset.

        Args:
            x: features with shape (batch_size, in_features).
            lb: 1-D tensor of original class ids with shape (batch_size).

        Returns:
            ``(logits, new_lb)``: logits of the retained samples only, and
            their labels re-indexed to positions in ``self.pids``.
        """
        assert x.size()[0] == lb.size()[0]
        assert x.size()[1] == self.in_feats
        logits = torch.mm(x, self.W)

        # One dict lookup per sample replaces the former ``in`` + ``.index``
        # linear scans; ``tolist`` converts the labels to ints in one go.
        position = {pid: idx for idx, pid in enumerate(self.pids)}
        labels = lb.tolist()
        keep = [pid in position for pid in labels]
        new_lb = [position[pid] for pid in labels if pid in position]

        lb_idx = torch.tensor(keep, dtype=torch.bool, device=lb.device)
        logits = logits[lb_idx]
        new_lb = torch.tensor(new_lb, dtype=torch.long).to(device=x.device)
        return (logits, new_lb)
class CrossEntropyLabelSmooth(nn.Module):
    """Cross entropy loss with label smoothing regularizer.

    Reference:
    Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016.
    Equation: y = (1 - epsilon) * y + epsilon / K.

    Args:
        num_classes (int): number of classes.
        epsilon (float): weight.
        use_gpu (bool): kept for backward compatibility; the loss now runs on
            whatever device ``inputs`` lives on.
    """

    def __init__(self, num_classes, epsilon=0.1, use_gpu=True):
        super(CrossEntropyLabelSmooth, self).__init__()
        self.num_classes = num_classes
        self.epsilon = epsilon
        self.use_gpu = use_gpu
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, inputs, targets):
        """
        Args:
            inputs: prediction matrix (before softmax) with shape (batch_size, num_classes)
            targets: ground truth labels with shape (batch_size)

        Returns:
            scalar tensor: smoothed cross entropy averaged over the batch.
        """
        log_probs = self.logsoftmax(inputs)
        # Build the one-hot matrix directly on the logits' device. The old
        # code went through the CPU and then called .cuda(), which forced a
        # host round trip every step and failed on CPU-only machines.
        one_hot = torch.zeros(log_probs.size(), device=log_probs.device)
        one_hot.scatter_(1, targets.unsqueeze(1).to(log_probs.device), 1)
        smoothed = (1 - self.epsilon) * one_hot + self.epsilon / self.num_classes
        loss = (- smoothed * log_probs).mean(0).sum()
        return loss
class TemperatureCrossEntropy(nn.Module):
    """Cross entropy computed on logits divided by a temperature.

    Args:
        use_gpu (bool): kept for backward compatibility; targets are now
            moved to the device of ``inputs`` instead of unconditionally to
            CUDA.
    """

    def __init__(self, use_gpu=True):
        super(TemperatureCrossEntropy, self).__init__()
        self.use_gpu = use_gpu
        self.logsoftmax = nn.LogSoftmax(dim=1)
        self.nll = nn.NLLLoss()

    def forward(self, inputs, targets, temperature):
        """
        Args:
            inputs: logits with shape (batch_size, num_classes).
            targets: class indices with shape (batch_size).
            temperature: divisor applied to the logits before log-softmax.

        Returns:
            scalar tensor: mean negative log-likelihood of the targets.
        """
        t_log_probs = self.logsoftmax(inputs / temperature)
        # Follow the logits' device. A hard-coded .cuda() fails on CPU-only
        # hosts and can pick the wrong GPU in multi-device runs.
        targets = targets.to(inputs.device)
        loss = self.nll(t_log_probs, targets)
        return loss
def euclidean_dist(x, y):
    """Pairwise Euclidean distance between the rows of two matrices.

    Args:
        x: tensor with shape [m, d]
        y: tensor with shape [n, d]
    Returns:
        tensor with shape [m, n]; squared distances are clamped at 1e-12
        before the square root is taken
    """
    rows, cols = x.size(0), y.size(0)
    x_sq = x.pow(2).sum(1, keepdim=True).expand(rows, cols)
    y_sq = y.pow(2).sum(1, keepdim=True).expand(cols, rows).t()
    # ||x||^2 + ||y||^2 - 2 x.y
    squared = (x_sq + y_sq) - 2 * torch.matmul(x, y.t())
    # clamp before sqrt for numerical stability
    return squared.clamp(min=1e-12).sqrt()
class TripletLoss(object):
    """
    Batch-hard triplet loss with an optional "harder" scaling factor.

    With a margin the loss is ``nn.MarginRankingLoss``; without one it falls
    back to the margin-free ``nn.SoftMarginLoss``. ``hard_factor`` stretches the
    hardest-positive distances and shrinks the hardest-negative distances
    before the loss is evaluated.
    """

    def __init__(self, margin=None, hard_factor=0.0):
        self.margin = margin
        self.hard_factor = hard_factor
        use_margin = margin is not None
        self.ranking_loss = (
            nn.MarginRankingLoss(margin=margin) if use_margin else nn.SoftMarginLoss()
        )

    def __call__(self, global_feat, labels, normalize_feature=False):
        """Return ``(loss, dist_ap, dist_an)`` for one batch of features.

        Args:
            global_feat: feature matrix passed to ``euclidean_dist``.
            labels: identity labels used to pick positives and negatives.
            normalize_feature: L2-normalize the features along the last axis
                before computing distances.
        """
        feats = normalize(global_feat, axis=-1) if normalize_feature else global_feat
        pairwise = euclidean_dist(feats, feats)
        hardest_pos, hardest_neg = hard_example_mining(pairwise, labels)

        # Make positives look farther and negatives look closer.
        dist_ap = hardest_pos * (1.0 + self.hard_factor)
        dist_an = hardest_neg * (1.0 - self.hard_factor)

        # Target of +1: the first argument (negative distance) should rank higher.
        target = torch.ones_like(dist_an)
        if self.margin is None:
            loss = self.ranking_loss(dist_an - dist_ap, target)
        else:
            loss = self.ranking_loss(dist_an, dist_ap, target)
        return loss, dist_ap, dist_an
class ResNet(nn.Module):
    """ResNet feature extractor that returns the ``layer4`` feature map.

    There is no pooling or classification head: ``forward`` ends after the
    fourth residual stage.

    Args:
        last_stride: stride of the first block of ``layer4``.
        block: residual block class. It must expose an ``expansion`` attribute
            and accept ``(inplanes, planes, stride, downsample)``.
        layers: number of residual blocks in each of the four stages.
    """

    def __init__(self, last_stride=2, block=Bottleneck, layers=(3, 4, 6, 3)):
        # Initialise the Module machinery before assigning any attribute.
        super().__init__()
        # Running channel count, updated by _make_layer as stages are built.
        self.inplanes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        # NOTE: the stem deliberately has no ReLU after bn1; the original code
        # keeps it commented out ("add missed relu"), and forward() matches.
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=None, padding=0)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=last_stride)

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage made of ``blocks`` residual blocks.

        Only the first block may change resolution or width, so only it gets a
        1x1-conv + BN projection on the shortcut when shapes would not match.
        """
        out_planes = planes * block.expansion
        downsample = None
        if stride != 1 or self.inplanes != out_planes:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, out_planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = out_planes
        layers.extend(block(self.inplanes, planes) for _ in range(1, blocks))

        return nn.Sequential(*layers)

    def forward(self, x, cam_label=None):
        """Run the backbone; ``cam_label`` is accepted but not used here."""
        x = self.conv1(x)
        x = self.bn1(x)
        # No ReLU here, see the note in __init__.
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        return x

    def load_param(self, model_path):
        """Copy weights from a checkpoint, skipping every key containing 'fc'.

        The checkpoint is mapped to CPU first so files saved from a GPU run can
        be loaded on any host; ``copy_`` then moves each tensor onto whatever
        device the matching parameter lives on.
        """
        param_dict = torch.load(model_path, map_location='cpu')
        # state_dict() tensors share storage with the live parameters, so an
        # in-place copy updates the model. Build the dict once, not per key.
        own_state = self.state_dict()
        for name, value in param_dict.items():
            if 'fc' in name:
                continue
            own_state[name].copy_(value)

    def random_init(self):
        """Initialise convs with N(0, sqrt(2 / fan_out)) and BN with weight 1, bias 0."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
reranking=cfg.TEST.RE_RANKING) 46 | scaler = amp.GradScaler() 47 | # train 48 | if cfg.MODEL.ID_HARD_MINING: 49 | weight = torch.normal(0, 1, size=(num_classes, 768)) 50 | for epoch in range(1, epochs + 1): 51 | if cfg.MODEL.ID_HARD_MINING: 52 | train_loader.batch_sampler.sampler.update_weight(weight) 53 | start_time = time.time() 54 | loss_meter.reset() 55 | id_loss_meter.reset() 56 | for i in range(cfg.MODEL.CLS_TOKEN_NUM): 57 | id_losses_meters[i].reset() 58 | triplet_loss_meter.reset() 59 | dissimilar_loss_meter.reset() 60 | acc_meter.reset() 61 | evaluator.reset() 62 | scheduler.step(epoch) 63 | model.train() 64 | for n_iter, (img, vid, target_cam, target_view) in enumerate(train_loader): 65 | optimizer.zero_grad() 66 | optimizer_center.zero_grad() 67 | img = img.to(device) 68 | target = vid.to(device) 69 | target_cam = target_cam.to(device) 70 | target_view = target_view.to(device) 71 | with amp.autocast(enabled=True): 72 | score, feat, weight = model(img, target, cam_label=target_cam, view_label=target_view ) 73 | losses = loss_fn(score, feat, target, target_cam) 74 | id_loss = sum(losses[0]) / len(losses[0]) 75 | tri_loss = sum(losses[1]) / len(losses[1]) 76 | if cfg.MODEL.CLS_TOKENS_LOSS: 77 | loss = id_loss + tri_loss + losses[2] 78 | else: 79 | loss = id_loss + tri_loss 80 | 81 | scaler.scale(loss).backward() 82 | 83 | # grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5) 84 | scaler.step(optimizer) 85 | scaler.update() 86 | 87 | if 'center' in cfg.MODEL.METRIC_LOSS_TYPE: 88 | for param in center_criterion.parameters(): 89 | param.grad.data *= (1. 
/ cfg.SOLVER.CENTER_LOSS_WEIGHT) 90 | scaler.step(optimizer_center) 91 | scaler.update() 92 | try: 93 | if isinstance(score, list): 94 | if isinstance(score[0], tuple): 95 | acc = (score[0][0].max(1)[1] == score[0][1]).float().mean() 96 | else: 97 | acc = (score[0].max(1)[1] == target).float().mean() 98 | else: 99 | acc = (score.max(1)[1] == target).float().mean() 100 | except: 101 | acc = 0 102 | 103 | loss_meter.update(loss.item(), img.shape[0]) 104 | id_loss_meter.update(id_loss.item(), img.shape[0]) 105 | for i in range(cfg.MODEL.CLS_TOKEN_NUM): 106 | id_losses_meters[i].update(losses[0][i].item(), img.shape[0]) 107 | triplet_loss_meter.update(tri_loss.item(), img.shape[0]) 108 | if cfg.MODEL.CLS_TOKEN_NUM > 1: 109 | dissimilar_loss_meter.update(losses[2].item(), img.shape[0]) 110 | acc_meter.update(acc, 1) 111 | 112 | torch.cuda.synchronize() 113 | if (n_iter + 1) % log_period == 0: 114 | id_losses_avgs = [id_losses_meter.avg for id_losses_meter in id_losses_meters] 115 | id_losses_avgs_info = "{:.3f}".format(id_losses_avgs[0]) 116 | for i in range(1, cfg.MODEL.CLS_TOKEN_NUM): 117 | id_losses_avgs_info += "/{:.3f}".format(id_losses_avgs[i]) 118 | logger.info("Epoch[{}] Iteration[{}/{}] Loss: {:.3f}, ID_Loss: {:.3f}-{}, TRIPLE_Loss: {:.3f}, DISSIMILAR Loss: {:.3f}, Acc: {:.3f}, Base Lr: {:.2e}" 119 | .format(epoch, (n_iter + 1), len(train_loader), 120 | loss_meter.avg, id_loss_meter.avg, id_losses_avgs_info, triplet_loss_meter.avg, dissimilar_loss_meter.avg, acc_meter.avg, scheduler._get_lr(epoch)[0])) 121 | 122 | end_time = time.time() 123 | time_per_batch = (end_time - start_time) / (n_iter + 1) 124 | if cfg.MODEL.DIST_TRAIN: 125 | pass 126 | else: 127 | logger.info("Epoch {} done. 
Time per batch: {:.3f}[s] Speed: {:.1f}[samples/s]" 128 | .format(epoch, time_per_batch, train_loader.batch_size / time_per_batch)) 129 | 130 | if epoch % checkpoint_period == 0: 131 | if cfg.MODEL.DIST_TRAIN: 132 | if dist.get_rank() == 0: 133 | torch.save(model.state_dict(), 134 | os.path.join(cfg.OUTPUT_DIR, cfg.MODEL.NAME + '_{}.pth'.format(epoch))) 135 | else: 136 | torch.save(model.state_dict(), 137 | os.path.join(cfg.OUTPUT_DIR, cfg.MODEL.NAME + '_{}.pth'.format(epoch))) 138 | 139 | if epoch % eval_period == 0: 140 | if cfg.MODEL.DIST_TRAIN: 141 | if dist.get_rank() == 0: 142 | model.eval() 143 | for n_iter, (img, vid, camid, camids, target_view, _) in enumerate(val_loader): 144 | with torch.no_grad(): 145 | img = img.to(device) 146 | camids = camids.to(device) 147 | target_view = target_view.to(device) 148 | feat = model(img, cam_label=camids, view_label=target_view) 149 | evaluator.update((feat, vid, camid)) 150 | cmc, mAP, distmat, pids, camids, feats = evaluator.compute() 151 | with open(os.path.join(cfg.OUTPUT_DIR, cfg.TEST.DIST_MAT), 'wb') as f: 152 | np.save(f, distmat) 153 | np.save(f, pids) 154 | np.save(f, camids) 155 | np.save(f, feats) 156 | logger.info("dist_mat saved at: {}".format(os.path.join(cfg.OUTPUT_DIR, cfg.TEST.DIST_MAT))) 157 | logger.info("Validation Results - Epoch: {}".format(epoch)) 158 | logger.info("mAP: {:.1%}".format(mAP)) 159 | for r in [1, 5, 10]: 160 | logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(r, cmc[r - 1])) 161 | torch.cuda.empty_cache() 162 | else: 163 | model.eval() 164 | for n_iter, (img, vid, camid, camids, target_view, _) in enumerate(val_loader): 165 | with torch.no_grad(): 166 | img = img.to(device) 167 | camids = camids.to(device) 168 | target_view = target_view.to(device) 169 | feat = model(img, cam_label=camids, view_label=target_view) 170 | evaluator.update((feat, vid, camid)) 171 | cmc, mAP, distmat, pids, camids, feats = evaluator.compute() 172 | with open(os.path.join(cfg.OUTPUT_DIR, 
def do_inference(cfg,
                 model,
                 val_loader,
                 num_query):
    """Evaluate ``model`` on ``val_loader`` and return ``(cmc[0], cmc[4])``.

    Features are accumulated in an ``R1_mAP_eval`` evaluator. The distance
    matrix, ids, camera ids and features it computes are written back to back
    into ``cfg.OUTPUT_DIR/cfg.TEST.DIST_MAT``, and mAP plus the CMC values at
    ranks 1, 5 and 10 are logged.
    """
    device = "cuda"
    logger = logging.getLogger("transreid.test")
    logger.info("Enter inferencing")

    evaluator = R1_mAP_eval(num_query, max_rank=50, feat_norm=cfg.TEST.FEAT_NORM, reranking=cfg.TEST.RE_RANKING)
    evaluator.reset()

    if device:
        gpu_count = torch.cuda.device_count()
        if gpu_count > 1:
            print('Using {} GPUs for inference'.format(gpu_count))
            model = nn.DataParallel(model)
        model.to(device)

    model.eval()
    img_path_list = []

    # Gradients are never needed during evaluation.
    with torch.no_grad():
        for img, pid, camid, camids, target_view, imgpath in val_loader:
            feat = model(img.to(device),
                         cam_label=camids.to(device),
                         view_label=target_view.to(device))
            evaluator.update((feat, pid, camid))
            img_path_list.extend(imgpath)

    cmc, mAP, distmat, pids, camids, feats = evaluator.compute()

    dist_mat_path = os.path.join(cfg.OUTPUT_DIR, cfg.TEST.DIST_MAT)
    with open(dist_mat_path, 'wb') as f:
        # Four arrays saved sequentially into the same file, in this order.
        for array in (distmat, pids, camids, feats):
            np.save(f, array)

    logger.info("dist_mat saved at: {}".format(dist_mat_path))
    logger.info("Validation Results ")
    logger.info("mAP: {:.1%}".format(mAP))
    for rank in (1, 5, 10):
        logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(rank, cmc[rank - 1]))
    return cmc[0], cmc[4]
self.lr_min = lr_min 55 | self.decay_rate = decay_rate 56 | self.cycle_limit = cycle_limit 57 | self.warmup_t = warmup_t 58 | self.warmup_lr_init = warmup_lr_init 59 | self.warmup_prefix = warmup_prefix 60 | self.t_in_epochs = t_in_epochs 61 | if self.warmup_t: 62 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] 63 | super().update_groups(self.warmup_lr_init) 64 | else: 65 | self.warmup_steps = [1 for _ in self.base_values] 66 | 67 | def _get_lr(self, t): 68 | if t < self.warmup_t: 69 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] 70 | else: 71 | if self.warmup_prefix: 72 | t = t - self.warmup_t 73 | 74 | if self.t_mul != 1: 75 | i = math.floor(math.log(1 - t / self.t_initial * (1 - self.t_mul), self.t_mul)) 76 | t_i = self.t_mul ** i * self.t_initial 77 | t_curr = t - (1 - self.t_mul ** i) / (1 - self.t_mul) * self.t_initial 78 | else: 79 | i = t // self.t_initial 80 | t_i = self.t_initial 81 | t_curr = t - (self.t_initial * i) 82 | 83 | gamma = self.decay_rate ** i 84 | lr_min = self.lr_min * gamma 85 | lr_max_values = [v * gamma for v in self.base_values] 86 | 87 | if self.cycle_limit == 0 or (self.cycle_limit > 0 and i < self.cycle_limit): 88 | lrs = [ 89 | lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * t_curr / t_i)) for lr_max in lr_max_values 90 | ] 91 | else: 92 | lrs = [self.lr_min for _ in self.base_values] 93 | 94 | return lrs 95 | 96 | def get_epoch_values(self, epoch: int): 97 | if self.t_in_epochs: 98 | return self._get_lr(epoch) 99 | else: 100 | return None 101 | 102 | def get_update_values(self, num_updates: int): 103 | if not self.t_in_epochs: 104 | return self._get_lr(num_updates) 105 | else: 106 | return None 107 | 108 | def get_cycle_length(self, cycles=0): 109 | if not cycles: 110 | cycles = self.cycle_limit 111 | cycles = max(1, cycles) 112 | if self.t_mul == 1.0: 113 | return self.t_initial * cycles 114 | else: 115 | return int(math.floor(-self.t_initial * (self.t_mul ** 
class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler):
    """Multi-step LR decay preceded by a constant or linear warmup.

    During the first ``warmup_iters`` steps the base LR is scaled by a warmup
    factor (fixed for "constant", ramping from ``warmup_factor`` to 1 for
    "linear"). Independently, the LR is multiplied by ``gamma`` once for every
    milestone that has been reached.

    Args:
        optimizer: wrapped optimizer.
        milestones: increasing step indices at which the LR is decayed.
        gamma: multiplicative decay applied at each milestone.
        warmup_factor: LR scale at the very start of warmup.
        warmup_iters: number of warmup steps.
        warmup_method: "constant" or "linear".
        last_epoch: index of the last step, -1 to start fresh.

    Raises:
        ValueError: if ``milestones`` is not sorted or ``warmup_method`` is
            not one of the supported values.
    """

    def __init__(
        self,
        optimizer,
        milestones,  # steps
        gamma=0.1,
        warmup_factor=1.0 / 3,
        warmup_iters=500,
        warmup_method="linear",
        last_epoch=-1,
    ):
        if not list(milestones) == sorted(milestones):
            raise ValueError(
                "Milestones should be a list of increasing integers. "
                "Got {}".format(milestones)
            )

        if warmup_method not in ("constant", "linear"):
            raise ValueError(
                "Only 'constant' or 'linear' warmup_method accepted, "
                "got {}".format(warmup_method)
            )
        self.milestones = milestones
        self.gamma = gamma
        self.warmup_factor = warmup_factor
        self.warmup_iters = warmup_iters
        self.warmup_method = warmup_method
        # The base constructor performs an initial step, which calls get_lr().
        super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the learning rate of every param group for ``last_epoch``.

        This is the hook the torch base class calls from ``step()``; it has to
        be named ``get_lr`` or the base implementation raises
        ``NotImplementedError``.
        """
        warmup_factor = 1
        if self.last_epoch < self.warmup_iters:
            if self.warmup_method == "constant":
                warmup_factor = self.warmup_factor
            elif self.warmup_method == "linear":
                # alpha goes 0 -> 1 over the warmup window.
                alpha = self.last_epoch / self.warmup_iters
                warmup_factor = self.warmup_factor * (1 - alpha) + alpha
        # Number of milestones already reached decides the decay exponent.
        decay = self.gamma ** bisect_right(self.milestones, self.last_epoch)
        return [base_lr * warmup_factor * decay for base_lr in self.base_lrs]

    # Backward-compatible alias for code that used the original method name.
    _get_lr = get_lr
class Scheduler:
    """ Base class for scheduling one field of an optimizer's param groups.

    Subclasses override ``get_epoch_values`` and/or ``get_update_values``; the
    base versions return ``None``, which makes ``step`` / ``step_update`` leave
    the optimizer untouched. The epoch or update counter is always supplied by
    the caller, so the scheduler itself keeps no counter.

    On construction the starting value of the field is recorded per group under
    ``initial_<field>`` (or required to be present already when
    ``initialize`` is False) and kept in ``base_values``. Optional
    multiplicative noise can be applied to the scheduled values from a given
    step onwards, or inside a ``[start, end)`` window.
    """

    def __init__(self,
                 optimizer: torch.optim.Optimizer,
                 param_group_field: str,
                 noise_range_t=None,
                 noise_type='normal',
                 noise_pct=0.67,
                 noise_std=1.0,
                 noise_seed=None,
                 initialize: bool = True) -> None:
        self.optimizer = optimizer
        self.param_group_field = param_group_field
        self._initial_param_group_field = f"initial_{param_group_field}"
        for i, group in enumerate(self.optimizer.param_groups):
            if initialize:
                if param_group_field not in group:
                    raise KeyError(f"{param_group_field} missing from param_groups[{i}]")
                # Keep an already recorded initial value (e.g. on resume).
                group.setdefault(self._initial_param_group_field, group[param_group_field])
            elif self._initial_param_group_field not in group:
                raise KeyError(f"{self._initial_param_group_field} missing from param_groups[{i}]")
        self.base_values = [
            group[self._initial_param_group_field]
            for group in self.optimizer.param_groups
        ]
        self.metric = None  # any point to having this for all?
        self.noise_range_t = noise_range_t
        self.noise_pct = noise_pct
        self.noise_type = noise_type
        self.noise_std = noise_std
        self.noise_seed = 42 if noise_seed is None else noise_seed
        self.update_groups(self.base_values)

    def state_dict(self) -> Dict[str, Any]:
        """Return every instance attribute except the optimizer itself."""
        state = dict(self.__dict__)
        state.pop('optimizer', None)
        return state

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Overwrite instance attributes with the entries of ``state_dict``."""
        self.__dict__.update(state_dict)

    def get_epoch_values(self, epoch: int):
        """Values to apply for ``epoch``; ``None`` means "no change"."""
        return None

    def get_update_values(self, num_updates: int):
        """Values to apply after ``num_updates`` updates; ``None`` means "no change"."""
        return None

    def step(self, epoch: int, metric: float = None) -> None:
        """Apply the per-epoch schedule for ``epoch``."""
        self.metric = metric
        self._apply(self.get_epoch_values(epoch), epoch)

    def step_update(self, num_updates: int, metric: float = None):
        """Apply the per-update schedule for ``num_updates``."""
        self.metric = metric
        self._apply(self.get_update_values(num_updates), num_updates)

    def _apply(self, values, t):
        # Shared tail of step()/step_update(): add noise, then write the values.
        if values is None:
            return
        self.update_groups(self._add_noise(values, t))

    def update_groups(self, values):
        """Write ``values`` into the scheduled field, one per param group.

        A scalar is broadcast to all groups.
        """
        if not isinstance(values, (list, tuple)):
            values = [values] * len(self.optimizer.param_groups)
        for param_group, value in zip(self.optimizer.param_groups, values):
            param_group[self.param_group_field] = value

    def _add_noise(self, lrs, t):
        """Return ``lrs`` scaled by ``1 + noise`` when ``t`` is in the noise range."""
        if self.noise_range_t is None:
            return lrs
        if isinstance(self.noise_range_t, (list, tuple)):
            apply_noise = self.noise_range_t[0] <= t < self.noise_range_t[1]
        else:
            apply_noise = t >= self.noise_range_t
        if not apply_noise:
            return lrs

        # Seeding with the step index makes the noise reproducible per step.
        g = torch.Generator()
        g.manual_seed(self.noise_seed + t)
        if self.noise_type == 'normal':
            # Rejection-sample until the draw is inside the percentage limit.
            noise = torch.randn(1, generator=g).item()
            while not abs(noise) < self.noise_pct:
                noise = torch.randn(1, generator=g).item()
        else:
            noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct
        return [v + v * noise for v in lrs]
using the command-line", default=None, 16 | nargs=argparse.REMAINDER) 17 | 18 | args = parser.parse_args() 19 | 20 | 21 | 22 | if args.config_file != "": 23 | cfg.merge_from_file(args.config_file) 24 | cfg.merge_from_list(args.opts) 25 | cfg.freeze() 26 | 27 | output_dir = cfg.OUTPUT_DIR 28 | if output_dir and not os.path.exists(output_dir): 29 | os.makedirs(output_dir) 30 | 31 | logger = setup_logger("transreid", output_dir, if_train=False) 32 | logger.info(args) 33 | 34 | if args.config_file != "": 35 | logger.info("Loaded configuration file {}".format(args.config_file)) 36 | with open(args.config_file, 'r') as cf: 37 | config_str = "\n" + cf.read() 38 | logger.info(config_str) 39 | logger.info("Running with config:\n{}".format(cfg)) 40 | 41 | os.environ['CUDA_VISIBLE_DEVICES'] = cfg.MODEL.DEVICE_ID 42 | 43 | train_loader, train_loader_normal, val_loader, num_query, num_classes, camera_num, view_num = make_dataloader(cfg) 44 | 45 | model = make_model(cfg, num_class=num_classes, camera_num=camera_num, view_num = view_num) 46 | model.load_param(cfg.TEST.WEIGHT) 47 | 48 | if cfg.DATASETS.NAMES == 'VehicleID': 49 | for trial in range(10): 50 | train_loader, train_loader_normal, val_loader, num_query, num_classes, camera_num, view_num = make_dataloader(cfg) 51 | rank_1, rank5 = do_inference(cfg, 52 | model, 53 | val_loader, 54 | num_query) 55 | if trial == 0: 56 | all_rank_1 = rank_1 57 | all_rank_5 = rank5 58 | else: 59 | all_rank_1 = all_rank_1 + rank_1 60 | all_rank_5 = all_rank_5 + rank5 61 | 62 | logger.info("rank_1:{}, rank_5 {} : trial : {}".format(rank_1, rank5, trial)) 63 | logger.info("sum_rank_1:{:.1%}, sum_rank_5 {:.1%}".format(all_rank_1.sum()/10.0, all_rank_5.sum()/10.0)) 64 | else: 65 | do_inference(cfg, 66 | model, 67 | val_loader, 68 | num_query) 69 | 70 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from utils.logger 
def set_seed(seed):
    """Seed every random number generator used by training with ``seed``.

    Seeds PyTorch (CPU and CUDA), NumPy and the stdlib ``random`` module,
    switches cuDNN to deterministic mode with benchmarking disabled, and
    exports ``PYTHONHASHSEED``.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    os.environ['PYTHONHASHSEED'] = str(seed)
def mkdir_if_missing(directory):
    """Create ``directory`` (including parents) unless it already exists.

    An empty ``directory`` is treated as "nothing to create". This is what
    ``osp.dirname`` returns for a bare file name, and ``os.makedirs('')``
    would otherwise raise ``FileNotFoundError``.

    Raises:
        OSError: if the directory cannot be created for any reason other
            than it already existing.
    """
    if not directory or osp.exists(directory):
        return
    try:
        os.makedirs(directory)
    except FileExistsError:
        # Another process created the path between the check and the call;
        # the goal (the path exists) is met, so this is not an error.
        pass
def setup_logger(name, save_dir, if_train):
    """Configure and return the logger called ``name``.

    The logger emits DEBUG and above to stdout and, when ``save_dir`` is
    non-empty, also to ``train_log.txt`` (``if_train`` true) or
    ``test_log.txt`` inside ``save_dir``. The log file is truncated on
    every call.

    Handlers left over from an earlier call with the same ``name`` are
    removed and closed first. ``logging.getLogger`` returns the same
    object for the same name, so without this every repeated call would
    duplicate each log line and leak the previous file handle.
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)

    for stale in list(logger.handlers):
        logger.removeHandler(stale)
        stale.close()

    formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s")

    ch = logging.StreamHandler(stream=sys.stdout)
    ch.setLevel(logging.DEBUG)
    ch.setFormatter(formatter)
    logger.addHandler(ch)

    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
        log_name = "train_log.txt" if if_train else "test_log.txt"
        fh = logging.FileHandler(os.path.join(save_dir, log_name), mode='w')
        fh.setLevel(logging.DEBUG)
        fh.setFormatter(formatter)
        logger.addHandler(fh)

    return logger
def euclidean_distance(qf, gf):
    """Return the pairwise squared Euclidean distances between two feature sets.

    Computes ``||q||^2 + ||g||^2 - 2 * q . g`` for every query/gallery pair.
    No square root is taken, so the values are squared distances.

    Args:
        qf: 2-D tensor of query features, one row per sample.
        gf: 2-D tensor of gallery features, one row per sample.

    Returns:
        A NumPy array of shape ``(qf.shape[0], gf.shape[0])``.
    """
    m = qf.shape[0]
    n = gf.shape[0]
    dist_mat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \
        torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t()
    # Keyword form of addmm_: the positional (beta, alpha, mat1, mat2)
    # overload is deprecated in PyTorch.
    dist_mat.addmm_(qf, gf.t(), beta=1, alpha=-2)
    return dist_mat.cpu().numpy()
class R1_mAP_eval():
    """Accumulate test features batch by batch and compute CMC / mAP.

    The first ``num_query`` accumulated samples are treated as queries and
    the remainder as the gallery.
    """

    def __init__(self, num_query, max_rank=50, feat_norm=True, reranking=False):
        super(R1_mAP_eval, self).__init__()
        self.num_query = num_query
        self.max_rank = max_rank
        self.feat_norm = feat_norm
        self.reranking = reranking
        # Create the buffers up front so update() works on a fresh instance
        # even if the caller has not called reset() yet.
        self.reset()

    def reset(self):
        """Discard everything accumulated so far."""
        self.feats = []
        self.pids = []
        self.camids = []

    def update(self, output):  # called once for each batch
        """Append one batch given as a ``(feat, pid, camid)`` triple."""
        feat, pid, camid = output
        self.feats.append(feat.cpu())
        self.pids.extend(np.asarray(pid))
        self.camids.extend(np.asarray(camid))

    def compute(self):  # called after each epoch
        """Evaluate the accumulated samples.

        Returns:
            ``(cmc, mAP, distmat, pids, camids, feats)``, where ``feats``
            is the concatenated (and, if ``feat_norm``, L2-normalised)
            feature tensor.
        """
        feats = torch.cat(self.feats, dim=0)
        if self.feat_norm:
            print("The test feature is normalized")
            feats = torch.nn.functional.normalize(feats, dim=1, p=2)  # along channel
        # query
        qf = feats[:self.num_query]
        q_pids = np.asarray(self.pids[:self.num_query])
        q_camids = np.asarray(self.camids[:self.num_query])
        # gallery
        gf = feats[self.num_query:]
        g_pids = np.asarray(self.pids[self.num_query:])
        g_camids = np.asarray(self.camids[self.num_query:])

        if self.reranking:
            print('=> Enter reranking')
            distmat = re_ranking(qf, gf, k1=50, k2=15, lambda_value=0.3)
        else:
            print('=> Computing DistMat with euclidean_distance')
            distmat = euclidean_distance(qf, gf)
        # Forward the configured rank cut-off; it used to be stored but
        # ignored, so eval_func always fell back to its own default.
        cmc, mAP = eval_func(distmat, q_pids, g_pids, q_camids, g_camids,
                             max_rank=self.max_rank)

        return cmc, mAP, distmat, self.pids, self.camids, feats
def re_ranking(probFea, galFea, k1, k2, lambda_value, local_distmat=None, only_local=False):
    """Re-rank query/gallery distances with k-reciprocal encoding.

    Implements the re-ranking of Zhong et al., "Re-ranking Person
    Re-identification with k-reciprocal Encoding" (CVPR 2017).

    Args:
        probFea: feature vectors of the query set (torch tensor, one row
            per sample).
        galFea: feature vectors of the gallery set (torch tensor, one row
            per sample).
        k1, k2, lambda_value: algorithm parameters; the original paper
            uses ``k1=20, k2=6, lambda=0.3``.
        local_distmat: optional precomputed distance matrix over all
            query + gallery samples; added to the feature distances, or
            used alone when ``only_local`` is true.
        only_local: use ``local_distmat`` as the only distance source.

    Returns:
        NumPy array of shape ``(num_query, num_gallery)`` holding the
        re-ranked distances.

    Raises:
        ValueError: if ``only_local`` is true but ``local_distmat`` is None.
    """
    # if feature vector is numpy, you should use 'torch.tensor' transform it to tensor
    query_num = probFea.size(0)
    all_num = query_num + galFea.size(0)
    if only_local:
        if local_distmat is None:
            raise ValueError("only_local=True requires local_distmat")
        original_dist = local_distmat
    else:
        feat = torch.cat([probFea, galFea])
        sq_norm = torch.pow(feat, 2).sum(dim=1, keepdim=True).expand(all_num, all_num)
        distmat = sq_norm + sq_norm.t()
        # Keyword form of addmm_: the positional (beta, alpha, mat1, mat2)
        # overload is deprecated in PyTorch.
        distmat.addmm_(feat, feat.t(), beta=1, alpha=-2)
        original_dist = distmat.cpu().numpy()
        del feat
        if local_distmat is not None:
            original_dist = original_dist + local_distmat
    gallery_num = original_dist.shape[0]
    # Normalise every column by its maximum, then transpose.
    original_dist = np.transpose(original_dist / np.max(original_dist, axis=0))
    V = np.zeros_like(original_dist).astype(np.float16)
    initial_rank = np.argsort(original_dist).astype(np.int32)

    # Width of the neighbour list used when expanding the reciprocal set.
    half_k1 = int(np.around(k1 / 2)) + 1

    for i in range(all_num):
        # k-reciprocal neighbors
        forward_k_neigh_index = initial_rank[i, :k1 + 1]
        backward_k_neigh_index = initial_rank[forward_k_neigh_index, :k1 + 1]
        fi = np.where(backward_k_neigh_index == i)[0]
        k_reciprocal_index = forward_k_neigh_index[fi]
        k_reciprocal_expansion_index = k_reciprocal_index
        for candidate in k_reciprocal_index:
            candidate_forward_k_neigh_index = initial_rank[candidate, :half_k1]
            candidate_backward_k_neigh_index = initial_rank[candidate_forward_k_neigh_index, :half_k1]
            fi_candidate = np.where(candidate_backward_k_neigh_index == candidate)[0]
            candidate_k_reciprocal_index = candidate_forward_k_neigh_index[fi_candidate]
            # Merge a candidate's reciprocal set only when more than 2/3
            # of it overlaps the current sample's reciprocal set.
            if len(np.intersect1d(candidate_k_reciprocal_index, k_reciprocal_index)) > 2 / 3 * len(
                    candidate_k_reciprocal_index):
                k_reciprocal_expansion_index = np.append(k_reciprocal_expansion_index,
                                                         candidate_k_reciprocal_index)

        k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index)
        weight = np.exp(-original_dist[i, k_reciprocal_expansion_index])
        V[i, k_reciprocal_expansion_index] = weight / np.sum(weight)
    original_dist = original_dist[:query_num, ]
    if k2 != 1:
        # Replace each row by the mean over its k2 nearest neighbours.
        V_qe = np.zeros_like(V, dtype=np.float16)
        for i in range(all_num):
            V_qe[i, :] = np.mean(V[initial_rank[i, :k2], :], axis=0)
        V = V_qe
        del V_qe
    del initial_rank
    # invIndex[c] lists the rows whose encoding is non-zero in column c.
    invIndex = []
    for i in range(gallery_num):
        invIndex.append(np.where(V[:, i] != 0)[0])

    jaccard_dist = np.zeros_like(original_dist, dtype=np.float16)

    for i in range(query_num):
        temp_min = np.zeros(shape=[1, gallery_num], dtype=np.float16)
        indNonZero = np.where(V[i, :] != 0)[0]
        indImages = [invIndex[ind] for ind in indNonZero]
        for j in range(len(indNonZero)):
            temp_min[0, indImages[j]] = temp_min[0, indImages[j]] + np.minimum(V[i, indNonZero[j]],
                                                                               V[indImages[j], indNonZero[j]])
        jaccard_dist[i] = 1 - temp_min / (2 - temp_min)

    final_dist = jaccard_dist * (1 - lambda_value) + original_dist * lambda_value
    del original_dist
    del V
    del jaccard_dist
    # Keep only the query-vs-gallery part.
    final_dist = final_dist[:query_num, query_num:]
    return final_dist