├── LICENSE ├── README.md ├── config ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-37.pyc │ ├── defaults.cpython-36.pyc │ └── defaults.cpython-37.pyc └── defaults.py ├── configs ├── DukeMTMC │ ├── deit_transreid_stride.yml │ ├── vit_base.yml │ ├── vit_jpm.yml │ ├── vit_sie.yml │ ├── vit_transreid.yml │ ├── vit_transreid_384.yml │ ├── vit_transreid_stride.yml │ └── vit_transreid_stride_384.yml ├── MSMT17 │ ├── deit_small.yml │ ├── deit_transreid_stride.yml │ ├── vit_base.yml │ ├── vit_jpm.yml │ ├── vit_sie.yml │ ├── vit_small.yml │ ├── vit_transreid.yml │ ├── vit_transreid_384.yml │ ├── vit_transreid_stride.yml │ └── vit_transreid_stride_384.yml ├── Market │ ├── deit_transreid_stride.yml │ ├── vit_base.yml │ ├── vit_jpm.yml │ ├── vit_sie.yml │ ├── vit_transreid.yml │ ├── vit_transreid_384.yml │ ├── vit_transreid_stride.yml │ └── vit_transreid_stride_384.yml ├── OCC_Duke │ ├── deit_transreid_stride.yml │ ├── osnet.yml │ ├── resnet.yml │ ├── vit_base.yml │ ├── vit_jpm.yml │ ├── vit_sie.yml │ ├── vit_transreid.yml │ └── vit_transreid_stride.yml ├── OCC_ReID │ ├── vit_base.yml │ ├── vit_local.yml │ ├── vit_small.yml │ └── vit_transreid_stride.yml ├── Partial_ReID │ ├── vit_base.yml │ ├── vit_local.yml │ └── vit_transreid_stride.yml ├── VeRi │ ├── deit_transreid.yml │ ├── deit_transreid_stride.yml │ ├── vit_base.yml │ ├── vit_transreid.yml │ └── vit_transreid_stride.yml ├── VehicleID │ ├── deit_transreid.yml │ ├── deit_transreid_stride.yml │ ├── vit_base.yml │ ├── vit_transreid.yml │ └── vit_transreid_stride.yml └── transformer_base.yml ├── datasets ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-37.pyc │ ├── bases.cpython-36.pyc │ ├── dukemtmcreid.cpython-36.pyc │ ├── make_dataloader.cpython-36.pyc │ ├── make_dataloader.cpython-37.pyc │ ├── market1501.cpython-36.pyc │ ├── msmt17.cpython-36.pyc │ ├── occ_duke.cpython-36.pyc │ ├── occ_reid.cpython-36.pyc │ ├── partial_reid.cpython-36.pyc │ ├── 
sampler.cpython-36.pyc │ ├── sampler_ddp.cpython-36.pyc │ ├── vehicleid.cpython-36.pyc │ └── veri.cpython-36.pyc ├── bases.py ├── dukemtmcreid.py ├── keypoint_test.txt ├── keypoint_train.txt ├── make_dataloader.py ├── make_dataloader_allOCC.py ├── market1501.py ├── msmt17.py ├── occ_duke.py ├── occ_reid.py ├── partial_reid.py ├── preprocessing.py ├── sampler.py ├── sampler_ddp.py ├── vehicleid.py └── veri.py ├── dist_test.sh ├── dist_train.sh ├── dist_train_occReID.sh ├── fig ├── 1 ├── OccludedREID_gallery.jpg ├── OccludedREID_query.jpg ├── RankingList-partial.png ├── image-20221018171750395.png ├── image-20221018171831853.png ├── image-20221018171840117.png ├── market_train.jpg ├── partial_gallery.jpg └── partial_query.jpg ├── loss ├── HCloss.py ├── KLloss.py ├── MSEloss.py ├── __init__.py ├── __pycache__ │ ├── HCloss.cpython-36.pyc │ ├── KLloss.cpython-36.pyc │ ├── MSEloss.cpython-36.pyc │ ├── __init__.cpython-36.pyc │ ├── arcface.cpython-36.pyc │ ├── center_loss.cpython-36.pyc │ ├── make_loss.cpython-36.pyc │ ├── metric_learning.cpython-36.pyc │ ├── softmax_loss.cpython-36.pyc │ └── triplet_loss.cpython-36.pyc ├── arcface.py ├── center_loss.py ├── make_loss.py ├── make_loss1_l2norm.py ├── make_loss1_vitbase_resnet.py ├── make_loss_onlyOneAugmentation.py ├── metric_learning.py ├── softmax_loss.py └── triplet_loss.py ├── model ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ └── make_model.cpython-36.pyc ├── backbones │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── osnet.cpython-36.pyc │ │ ├── resnet.cpython-36.pyc │ │ └── vit_pytorch.cpython-36.pyc │ ├── osnet.py │ ├── resnet.py │ └── vit_pytorch.py ├── make_model.py ├── make_model1_onlyOneAugmentation.py ├── make_model1_osnet.py ├── make_model1_resnet.py └── make_model1_vitbase.py ├── nohup.out ├── processor ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ └── processor.cpython-36.pyc ├── processor.py ├── processor_onlyOneAugmentation.py ├── 
processor_resnet_osnet.py └── processor_vitbase_transreid.py ├── requirements.txt ├── solver ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── cosine_lr.cpython-36.pyc │ ├── lr_scheduler.cpython-36.pyc │ ├── make_optimizer.cpython-36.pyc │ ├── scheduler.cpython-36.pyc │ └── scheduler_factory.cpython-36.pyc ├── cosine_lr.py ├── lr_scheduler.py ├── make_optimizer.py ├── scheduler.py └── scheduler_factory.py ├── test.py ├── train.py └── utils ├── __init__.py ├── __pycache__ ├── __init__.cpython-36.pyc ├── __init__.cpython-37.pyc ├── iotools.cpython-36.pyc ├── logger.cpython-36.pyc ├── logger.cpython-37.pyc ├── meter.cpython-36.pyc ├── metrics.cpython-36.pyc └── reranking.cpython-36.pyc ├── iotools.py ├── logger.py ├── meter.py ├── metrics.py └── reranking.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 heshuting555 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PADE 2 | 3 | ## Exciting News!! This paper has been accepted by ICASSP 2024! 4 | 5 | Code of paper "Parallel Augmentation and Dual Enhancement for Occluded Person Re-identification" 6 | 7 | A simple but effective method for both Occluded Person Re-identification and Normal Person Re-identification (with few occlusions) 8 | 9 | Paper link: Parallel Augmentation and Dual Enhancement for Occluded Person Re-identification. [PDF](https://arxiv.org/pdf/2210.05438.pdf) 10 | 11 | The codes are based on the **TransReID (ICCV 2021)**, the basic preparation and environment installation, please refer to [TransReID](https://github.com/damo-cv/TransReID). 12 | 13 | ## Structure of PADE 14 | 15 | ![image-20221018171750395](./fig/image-20221018171750395.png) 16 | 17 | ## Results 18 | 19 | ![image-20221018171831853](./fig/image-20221018171831853.png) 20 | 21 | ## Visualization 22 | 23 | We visualized some of the ranking list results (rank-10) on the Partial ReID dataset. The above is the result of the baseline, and below is our method. 24 | 25 | ![image-ranking](./fig/RankingList-partial.png) 26 | 27 | ## Training 28 | 29 | We will evaluate the model every few epochs. 30 | 31 | **Note:** Since the Partial-REID and Occluded-ReID datasets have few samples in the test set and are easy to overfit, we adopted the "early stopping" strategy and manually selected the results for better accuracy. 
32 | 33 | ```python 34 | # Training on Occluded-Duke 35 | python train.py --config_file configs/OCC_Duke/vit_transreid_stride.yml MODEL.DEVICE_ID "('2')" 36 | 37 | # Training on Partial-REID 38 | python train.py --config_file configs/Partial_ReID/vit_transreid_stride.yml MODEL.DEVICE_ID "('2')" 39 | 40 | # Training on Occluded-ReID 41 | python train.py --config_file configs/OCC_ReID/vit_transreid_stride.yml MODEL.DEVICE_ID "('2')" 42 | 43 | # Training on Market-1501 44 | python train.py --config_file configs/Market/vit_transreid_stride.yml MODEL.DEVICE_ID "('2')" 45 | 46 | # Training on DukeMTMC-reID 47 | python train.py --config_file configs/DukeMTMC/vit_transreid_stride.yml MODEL.DEVICE_ID "('2')" 48 | ``` 49 | 50 | ## Test 51 | 52 | The pre-trained models will come soon... 53 | 54 | ## Dataset Comparison 55 | 56 | We demonstrate the training and test data imbalance problem of occluded ReID by displaying samples in the training set and test set. Note: Pick one image for each ID as a representative. 57 | 58 | Training data (Market 1501): only a few IDs are obscured 59 | 60 | 61 | 62 | Testing data (Partial-REID): query (left, 100% occluded), gallery (right, ~ 100% non-occluded) 63 |
64 | 65 | 66 | 67 | Testing data (Occluded-REID): query (left, 100% occluded), gallery (right, ~ 100% non-occluded) 68 |
69 | 70 | 71 | 72 | 73 | ## Contact 74 | 75 | Please contact Zi Wang (email address: [ziwang1121@foxmail.com](mailto:ziwang1121@foxmail.com)). Feel free to drop me an email if you have any questions. 76 | 77 | -------------------------------------------------------------------------------- /config/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | from .defaults import _C as cfg 8 | from .defaults import _C as cfg_test 9 | -------------------------------------------------------------------------------- /config/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/config/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /config/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/config/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /config/__pycache__/defaults.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/config/__pycache__/defaults.cpython-36.pyc -------------------------------------------------------------------------------- /config/__pycache__/defaults.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/config/__pycache__/defaults.cpython-37.pyc 
-------------------------------------------------------------------------------- /configs/DukeMTMC/deit_transreid_stride.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/heshuting/.cache/torch/checkpoints/deit_base_distilled_patch16_224-df68dfff.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('2') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [12, 12] 12 | SIE_CAMERA: True 13 | SIE_COE: 3.0 14 | JPM: True 15 | RE_ARRANGE: True 16 | 17 | INPUT: 18 | SIZE_TRAIN: [256, 128] 19 | SIZE_TEST: [256, 128] 20 | PROB: 0.5 # random horizontal flip 21 | RE_PROB: 0.8 # random erasing 22 | PADDING: 10 23 | PIXEL_MEAN: [0.5, 0.5, 0.5] 24 | PIXEL_STD: [0.5, 0.5, 0.5] 25 | 26 | DATASETS: 27 | NAMES: ('dukemtmc') 28 | ROOT_DIR: ('../../data') 29 | 30 | DATALOADER: 31 | SAMPLER: 'softmax_triplet' 32 | NUM_INSTANCE: 4 33 | NUM_WORKERS: 8 34 | 35 | SOLVER: 36 | OPTIMIZER_NAME: 'SGD' 37 | MAX_EPOCHS: 120 38 | BASE_LR: 0.008 39 | IMS_PER_BATCH: 64 40 | WARMUP_METHOD: 'linear' 41 | LARGE_FC_LR: False 42 | CHECKPOINT_PERIOD: 120 43 | LOG_PERIOD: 50 44 | EVAL_PERIOD: 120 45 | WEIGHT_DECAY: 1e-4 46 | WEIGHT_DECAY_BIAS: 1e-4 47 | BIAS_LR_FACTOR: 2 48 | 49 | TEST: 50 | EVAL: True 51 | IMS_PER_BATCH: 256 52 | RE_RANKING: False 53 | WEIGHT: '../logs/dukemtmc_deit_transreid/transformer_120.pth' 54 | NECK_FEAT: 'before' 55 | FEAT_NORM: 'yes' 56 | 57 | OUTPUT_DIR: '../logs/dukemtmc_deit_transreid' 58 | 59 | 60 | -------------------------------------------------------------------------------- /configs/DukeMTMC/vit_base.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/zi.wang/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 
| IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('6') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [16, 16] 12 | 13 | INPUT: 14 | SIZE_TRAIN: [256, 128] 15 | SIZE_TEST: [256, 128] 16 | PROB: 0.5 # random horizontal flip 17 | RE_PROB: 0.5 # random erasing 18 | PADDING: 10 19 | PIXEL_MEAN: [0.5, 0.5, 0.5] 20 | PIXEL_STD: [0.5, 0.5, 0.5] 21 | 22 | DATASETS: 23 | NAMES: ('dukemtmc') 24 | ROOT_DIR: ('/data2/zi.wang/') 25 | 26 | DATALOADER: 27 | SAMPLER: 'softmax_triplet' 28 | NUM_INSTANCE: 4 29 | NUM_WORKERS: 8 30 | 31 | SOLVER: 32 | OPTIMIZER_NAME: 'SGD' 33 | MAX_EPOCHS: 120 34 | BASE_LR: 0.008 35 | IMS_PER_BATCH: 64 36 | WARMUP_METHOD: 'linear' 37 | LARGE_FC_LR: False 38 | CHECKPOINT_PERIOD: 120 39 | LOG_PERIOD: 50 40 | EVAL_PERIOD: 120 41 | WEIGHT_DECAY: 1e-4 42 | WEIGHT_DECAY_BIAS: 1e-4 43 | BIAS_LR_FACTOR: 2 44 | 45 | TEST: 46 | EVAL: True 47 | IMS_PER_BATCH: 256 48 | RE_RANKING: False 49 | WEIGHT: '' 50 | NECK_FEAT: 'before' 51 | FEAT_NORM: 'yes' 52 | 53 | OUTPUT_DIR: '../logs/duke_vit_base' 54 | 55 | 56 | -------------------------------------------------------------------------------- /configs/DukeMTMC/vit_jpm.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/xxx/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('1') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [16, 16] 12 | JPM: True 13 | RE_ARRANGE: True 14 | 15 | INPUT: 16 | SIZE_TRAIN: [256, 128] 17 | SIZE_TEST: [256, 128] 18 | PROB: 0.5 # random horizontal flip 19 | RE_PROB: 0.5 # random erasing 20 | PADDING: 10 21 | PIXEL_MEAN: [0.5, 0.5, 0.5] 22 | PIXEL_STD: [0.5, 0.5, 0.5] 23 | 24 | DATASETS: 25 | NAMES: ('dukemtmc') 26 | ROOT_DIR: 
('../../data') 27 | 28 | DATALOADER: 29 | SAMPLER: 'softmax_triplet' 30 | NUM_INSTANCE: 4 31 | NUM_WORKERS: 8 32 | 33 | SOLVER: 34 | OPTIMIZER_NAME: 'SGD' 35 | MAX_EPOCHS: 120 36 | BASE_LR: 0.008 37 | IMS_PER_BATCH: 64 38 | WARMUP_METHOD: 'linear' 39 | LARGE_FC_LR: False 40 | CHECKPOINT_PERIOD: 120 41 | LOG_PERIOD: 50 42 | EVAL_PERIOD: 120 43 | WEIGHT_DECAY: 1e-4 44 | WEIGHT_DECAY_BIAS: 1e-4 45 | BIAS_LR_FACTOR: 2 46 | 47 | TEST: 48 | EVAL: True 49 | IMS_PER_BATCH: 256 50 | RE_RANKING: False 51 | WEIGHT: '../logs/duke_vit_jpm/transformer_120.pth' 52 | NECK_FEAT: 'before' 53 | FEAT_NORM: 'yes' 54 | 55 | OUTPUT_DIR: '../logs/duke_vit_jpm' 56 | 57 | 58 | -------------------------------------------------------------------------------- /configs/DukeMTMC/vit_sie.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/xxx/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('2') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [16, 16] 12 | SIE_CAMERA: True 13 | SIE_COE: 3.0 14 | 15 | INPUT: 16 | SIZE_TRAIN: [256, 128] 17 | SIZE_TEST: [256, 128] 18 | PROB: 0.5 # random horizontal flip 19 | RE_PROB: 0.5 # random erasing 20 | PADDING: 10 21 | PIXEL_MEAN: [0.5, 0.5, 0.5] 22 | PIXEL_STD: [0.5, 0.5, 0.5] 23 | 24 | DATASETS: 25 | NAMES: ('dukemtmc') 26 | ROOT_DIR: ('../../data') 27 | 28 | DATALOADER: 29 | SAMPLER: 'softmax_triplet' 30 | NUM_INSTANCE: 4 31 | NUM_WORKERS: 8 32 | 33 | SOLVER: 34 | OPTIMIZER_NAME: 'SGD' 35 | MAX_EPOCHS: 120 36 | BASE_LR: 0.008 37 | IMS_PER_BATCH: 64 38 | WARMUP_METHOD: 'linear' 39 | LARGE_FC_LR: False 40 | CHECKPOINT_PERIOD: 120 41 | LOG_PERIOD: 50 42 | EVAL_PERIOD: 120 43 | WEIGHT_DECAY: 1e-4 44 | WEIGHT_DECAY_BIAS: 1e-4 45 | BIAS_LR_FACTOR: 2 46 | 47 | TEST: 48 | EVAL: True 49 | 
IMS_PER_BATCH: 256 50 | RE_RANKING: False 51 | WEIGHT: '../logs/duke_vit_sie/transformer_120.pth' 52 | NECK_FEAT: 'before' 53 | FEAT_NORM: 'yes' 54 | 55 | OUTPUT_DIR: '../logs/duke_vit_sie' 56 | 57 | 58 | -------------------------------------------------------------------------------- /configs/DukeMTMC/vit_transreid.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/xxx/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('3') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [16, 16] 12 | SIE_CAMERA: True 13 | SIE_COE: 3.0 14 | JPM: True 15 | RE_ARRANGE: True 16 | 17 | INPUT: 18 | SIZE_TRAIN: [256, 128] 19 | SIZE_TEST: [256, 128] 20 | PROB: 0.5 # random horizontal flip 21 | RE_PROB: 0.5 # random erasing 22 | PADDING: 10 23 | PIXEL_MEAN: [0.5, 0.5, 0.5] 24 | PIXEL_STD: [0.5, 0.5, 0.5] 25 | 26 | DATASETS: 27 | NAMES: ('dukemtmc') 28 | ROOT_DIR: ('../../data') 29 | 30 | DATALOADER: 31 | SAMPLER: 'softmax_triplet' 32 | NUM_INSTANCE: 4 33 | NUM_WORKERS: 8 34 | 35 | SOLVER: 36 | OPTIMIZER_NAME: 'SGD' 37 | MAX_EPOCHS: 120 38 | BASE_LR: 0.008 39 | IMS_PER_BATCH: 64 40 | WARMUP_METHOD: 'linear' 41 | LARGE_FC_LR: False 42 | CHECKPOINT_PERIOD: 120 43 | LOG_PERIOD: 50 44 | EVAL_PERIOD: 120 45 | WEIGHT_DECAY: 1e-4 46 | WEIGHT_DECAY_BIAS: 1e-4 47 | BIAS_LR_FACTOR: 2 48 | 49 | TEST: 50 | EVAL: True 51 | IMS_PER_BATCH: 256 52 | RE_RANKING: False 53 | WEIGHT: '../logs/duke_vit_transreid/transformer_120.pth' 54 | NECK_FEAT: 'before' 55 | FEAT_NORM: 'yes' 56 | 57 | OUTPUT_DIR: '../logs/duke_vit_transreid' 58 | 59 | 60 | -------------------------------------------------------------------------------- /configs/DukeMTMC/vit_transreid_384.yml: 
-------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/xxx/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('3') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [16, 16] 12 | SIE_CAMERA: True 13 | SIE_COE: 3.0 14 | JPM: True 15 | RE_ARRANGE: True 16 | 17 | INPUT: 18 | SIZE_TRAIN: [384, 128] 19 | SIZE_TEST: [384, 128] 20 | PROB: 0.5 # random horizontal flip 21 | RE_PROB: 0.5 # random erasing 22 | PADDING: 10 23 | PIXEL_MEAN: [0.5, 0.5, 0.5] 24 | PIXEL_STD: [0.5, 0.5, 0.5] 25 | 26 | DATASETS: 27 | NAMES: ('dukemtmc') 28 | ROOT_DIR: ('../../data') 29 | 30 | DATALOADER: 31 | SAMPLER: 'softmax_triplet' 32 | NUM_INSTANCE: 4 33 | NUM_WORKERS: 8 34 | 35 | SOLVER: 36 | OPTIMIZER_NAME: 'SGD' 37 | MAX_EPOCHS: 120 38 | BASE_LR: 0.008 39 | IMS_PER_BATCH: 64 40 | WARMUP_METHOD: 'linear' 41 | LARGE_FC_LR: False 42 | CHECKPOINT_PERIOD: 120 43 | LOG_PERIOD: 50 44 | EVAL_PERIOD: 120 45 | WEIGHT_DECAY: 1e-4 46 | WEIGHT_DECAY_BIAS: 1e-4 47 | BIAS_LR_FACTOR: 2 48 | 49 | TEST: 50 | EVAL: True 51 | IMS_PER_BATCH: 256 52 | RE_RANKING: False 53 | WEIGHT: '../logs/duke_vit_transreid_384/transformer_120.pth' 54 | NECK_FEAT: 'before' 55 | FEAT_NORM: 'yes' 56 | 57 | OUTPUT_DIR: '../logs/duke_vit_transreid_384' 58 | 59 | 60 | -------------------------------------------------------------------------------- /configs/DukeMTMC/vit_transreid_stride.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/zi.wang/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('4') 10 | 
TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [12, 12] 12 | SIE_CAMERA: True 13 | SIE_COE: 3.0 14 | JPM: True 15 | RE_ARRANGE: True 16 | 17 | INPUT: 18 | SIZE_TRAIN: [256, 128] 19 | SIZE_TEST: [256, 128] 20 | PROB: 0.5 # random horizontal flip 21 | RE_PROB: 0.5 # random erasing 22 | PADDING: 10 23 | PIXEL_MEAN: [0.5, 0.5, 0.5] 24 | PIXEL_STD: [0.5, 0.5, 0.5] 25 | 26 | DATASETS: 27 | NAMES: ('dukemtmc') 28 | ROOT_DIR: ('/data2/zi.wang/') 29 | 30 | DATALOADER: 31 | SAMPLER: 'softmax_triplet' 32 | NUM_INSTANCE: 4 33 | NUM_WORKERS: 8 34 | 35 | SOLVER: 36 | OPTIMIZER_NAME: 'SGD' 37 | MAX_EPOCHS: 300 38 | BASE_LR: 0.01 39 | IMS_PER_BATCH: 32 40 | WARMUP_METHOD: 'linear' 41 | LARGE_FC_LR: False 42 | CHECKPOINT_PERIOD: 10000 43 | LOG_PERIOD: 100 44 | EVAL_PERIOD: 10 45 | WEIGHT_DECAY: 1e-4 46 | WEIGHT_DECAY_BIAS: 1e-4 47 | BIAS_LR_FACTOR: 2 48 | 49 | TEST: 50 | EVAL: True 51 | IMS_PER_BATCH: 256 52 | RE_RANKING: False 53 | WEIGHT: '' 54 | NECK_FEAT: 'before' 55 | FEAT_NORM: 'yes' 56 | 57 | OUTPUT_DIR: './logs_duke/lr001_b32_Process1_Model12_loss1_AO' 58 | 59 | 60 | -------------------------------------------------------------------------------- /configs/DukeMTMC/vit_transreid_stride_384.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/xxx/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('3') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [12, 12] 12 | SIE_CAMERA: True 13 | SIE_COE: 3.0 14 | JPM: True 15 | RE_ARRANGE: True 16 | 17 | INPUT: 18 | SIZE_TRAIN: [384, 128] 19 | SIZE_TEST: [384, 128] 20 | PROB: 0.5 # random horizontal flip 21 | RE_PROB: 0.5 # random erasing 22 | PADDING: 10 23 | PIXEL_MEAN: [0.5, 0.5, 0.5] 24 | PIXEL_STD: [0.5, 0.5, 0.5] 25 | 26 | DATASETS: 27 
| NAMES: ('dukemtmc') 28 | ROOT_DIR: ('../../data') 29 | 30 | DATALOADER: 31 | SAMPLER: 'softmax_triplet' 32 | NUM_INSTANCE: 4 33 | NUM_WORKERS: 8 34 | 35 | SOLVER: 36 | OPTIMIZER_NAME: 'SGD' 37 | MAX_EPOCHS: 120 38 | BASE_LR: 0.008 39 | IMS_PER_BATCH: 64 40 | WARMUP_METHOD: 'linear' 41 | LARGE_FC_LR: False 42 | CHECKPOINT_PERIOD: 120 43 | LOG_PERIOD: 50 44 | EVAL_PERIOD: 120 45 | WEIGHT_DECAY: 1e-4 46 | WEIGHT_DECAY_BIAS: 1e-4 47 | BIAS_LR_FACTOR: 2 48 | 49 | TEST: 50 | EVAL: True 51 | IMS_PER_BATCH: 256 52 | RE_RANKING: False 53 | WEIGHT: '../logs/duke_vit_transreid_stride_384/transformer_120.pth' 54 | NECK_FEAT: 'before' 55 | FEAT_NORM: 'yes' 56 | 57 | OUTPUT_DIR: '../logs/duke_vit_transreid_stride_384' 58 | 59 | 60 | -------------------------------------------------------------------------------- /configs/MSMT17/deit_small.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/heshuting/.cache/torch/checkpoints/deit_small_distilled_patch16_224-649709d9.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('1') 10 | TRANSFORMER_TYPE: 'deit_small_patch16_224_TransReID' 11 | STRIDE_SIZE: [16, 16] 12 | 13 | INPUT: 14 | SIZE_TRAIN: [256, 128] 15 | SIZE_TEST: [256, 128] 16 | PROB: 0.5 # random horizontal flip 17 | RE_PROB: 0.8 # random erasing 18 | PADDING: 10 19 | PIXEL_MEAN: [0.5, 0.5, 0.5] 20 | PIXEL_STD: [0.5, 0.5, 0.5] 21 | 22 | DATASETS: 23 | NAMES: ('msmt17') 24 | ROOT_DIR: ('../../data') 25 | 26 | DATALOADER: 27 | SAMPLER: 'softmax_triplet' 28 | NUM_INSTANCE: 4 29 | NUM_WORKERS: 8 30 | 31 | SOLVER: 32 | OPTIMIZER_NAME: 'SGD' 33 | MAX_EPOCHS: 120 34 | BASE_LR: 0.005 35 | IMS_PER_BATCH: 64 36 | LARGE_FC_LR: False 37 | CHECKPOINT_PERIOD: 120 38 | LOG_PERIOD: 50 39 | EVAL_PERIOD: 120 40 | WEIGHT_DECAY: 1e-4 41 | WEIGHT_DECAY_BIAS: 1e-4 42 | BIAS_LR_FACTOR: 2 43 | 44 | 
TEST: 45 | EVAL: True 46 | IMS_PER_BATCH: 256 47 | RE_RANKING: False 48 | WEIGHT: '' 49 | NECK_FEAT: 'before' 50 | FEAT_NORM: 'yes' 51 | 52 | OUTPUT_DIR: '../logs/msmt17_deit_small_try' 53 | 54 | 55 | -------------------------------------------------------------------------------- /configs/MSMT17/deit_transreid_stride.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/heshuting/.cache/torch/checkpoints/deit_base_distilled_patch16_224-df68dfff.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('5') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [12, 12] 12 | SIE_CAMERA: True 13 | SIE_COE: 3.0 14 | JPM: True 15 | RE_ARRANGE: True 16 | 17 | INPUT: 18 | SIZE_TRAIN: [256, 128] 19 | SIZE_TEST: [256, 128] 20 | PROB: 0.5 # random horizontal flip 21 | RE_PROB: 0.8 # random erasing 22 | PADDING: 10 23 | PIXEL_MEAN: [0.5, 0.5, 0.5] 24 | PIXEL_STD: [0.5, 0.5, 0.5] 25 | 26 | DATASETS: 27 | NAMES: ('msmt17') 28 | ROOT_DIR: ('../../data') 29 | 30 | DATALOADER: 31 | SAMPLER: 'softmax_triplet' 32 | NUM_INSTANCE: 4 33 | NUM_WORKERS: 8 34 | 35 | SOLVER: 36 | OPTIMIZER_NAME: 'SGD' 37 | MAX_EPOCHS: 120 38 | BASE_LR: 0.005 39 | IMS_PER_BATCH: 64 40 | WARMUP_METHOD: 'linear' 41 | LARGE_FC_LR: False 42 | CHECKPOINT_PERIOD: 120 43 | LOG_PERIOD: 50 44 | EVAL_PERIOD: 120 45 | WEIGHT_DECAY: 1e-4 46 | WEIGHT_DECAY_BIAS: 1e-4 47 | BIAS_LR_FACTOR: 2 48 | 49 | TEST: 50 | EVAL: True 51 | IMS_PER_BATCH: 256 52 | RE_RANKING: False 53 | WEIGHT: '' 54 | NECK_FEAT: 'before' 55 | FEAT_NORM: 'yes' 56 | 57 | OUTPUT_DIR: '../logs/msmt17_deit_transreid' 58 | 59 | 60 | -------------------------------------------------------------------------------- /configs/MSMT17/vit_base.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | 
PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/heshuting/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('4') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [16, 16] 12 | 13 | INPUT: 14 | SIZE_TRAIN: [256, 128] 15 | SIZE_TEST: [256, 128] 16 | PROB: 0.5 # random horizontal flip 17 | RE_PROB: 0.5 # random erasing 18 | PADDING: 10 19 | PIXEL_MEAN: [0.5, 0.5, 0.5] 20 | PIXEL_STD: [0.5, 0.5, 0.5] 21 | 22 | DATASETS: 23 | NAMES: ('msmt17') 24 | ROOT_DIR: ('../../data') 25 | 26 | DATALOADER: 27 | SAMPLER: 'softmax_triplet' 28 | NUM_INSTANCE: 4 29 | NUM_WORKERS: 8 30 | 31 | SOLVER: 32 | OPTIMIZER_NAME: 'SGD' 33 | MAX_EPOCHS: 120 34 | BASE_LR: 0.008 35 | IMS_PER_BATCH: 64 36 | WARMUP_METHOD: 'linear' 37 | LARGE_FC_LR: False 38 | CHECKPOINT_PERIOD: 120 39 | LOG_PERIOD: 50 40 | EVAL_PERIOD: 120 41 | WEIGHT_DECAY: 1e-4 42 | WEIGHT_DECAY_BIAS: 1e-4 43 | BIAS_LR_FACTOR: 2 44 | 45 | TEST: 46 | EVAL: True 47 | IMS_PER_BATCH: 256 48 | RE_RANKING: False 49 | WEIGHT: '' 50 | NECK_FEAT: 'before' 51 | FEAT_NORM: 'yes' 52 | 53 | OUTPUT_DIR: '../logs/msmt17_vit_base' 54 | 55 | 56 | -------------------------------------------------------------------------------- /configs/MSMT17/vit_jpm.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/xxx/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('1') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [16, 16] 12 | JPM: True 13 | RE_ARRANGE: True 14 | 15 | INPUT: 16 | SIZE_TRAIN: [256, 128] 17 | SIZE_TEST: [256, 128] 18 | PROB: 0.5 # random horizontal flip 19 | RE_PROB: 0.5 # random 
erasing 20 | PADDING: 10 21 | PIXEL_MEAN: [0.5, 0.5, 0.5] 22 | PIXEL_STD: [0.5, 0.5, 0.5] 23 | 24 | DATASETS: 25 | NAMES: ('msmt17') 26 | ROOT_DIR: ('../../data') 27 | 28 | DATALOADER: 29 | SAMPLER: 'softmax_triplet' 30 | NUM_INSTANCE: 4 31 | NUM_WORKERS: 8 32 | 33 | SOLVER: 34 | OPTIMIZER_NAME: 'SGD' 35 | MAX_EPOCHS: 120 36 | BASE_LR: 0.008 37 | IMS_PER_BATCH: 64 38 | WARMUP_METHOD: 'linear' 39 | LARGE_FC_LR: False 40 | CHECKPOINT_PERIOD: 120 41 | LOG_PERIOD: 50 42 | EVAL_PERIOD: 120 43 | WEIGHT_DECAY: 1e-4 44 | WEIGHT_DECAY_BIAS: 1e-4 45 | BIAS_LR_FACTOR: 2 46 | 47 | TEST: 48 | EVAL: True 49 | IMS_PER_BATCH: 256 50 | RE_RANKING: False 51 | WEIGHT: '' 52 | NECK_FEAT: 'before' 53 | FEAT_NORM: 'yes' 54 | 55 | OUTPUT_DIR: '../logs/msmt17_vit_jpm' 56 | 57 | 58 | -------------------------------------------------------------------------------- /configs/MSMT17/vit_sie.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/xxx/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('2') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [16, 16] 12 | SIE_CAMERA: True 13 | SIE_COE: 3.0 14 | 15 | INPUT: 16 | SIZE_TRAIN: [256, 128] 17 | SIZE_TEST: [256, 128] 18 | PROB: 0.5 # random horizontal flip 19 | RE_PROB: 0.5 # random erasing 20 | PADDING: 10 21 | PIXEL_MEAN: [0.5, 0.5, 0.5] 22 | PIXEL_STD: [0.5, 0.5, 0.5] 23 | 24 | DATASETS: 25 | NAMES: ('msmt17') 26 | ROOT_DIR: ('../../data') 27 | 28 | DATALOADER: 29 | SAMPLER: 'softmax_triplet' 30 | NUM_INSTANCE: 4 31 | NUM_WORKERS: 8 32 | 33 | SOLVER: 34 | OPTIMIZER_NAME: 'SGD' 35 | MAX_EPOCHS: 120 36 | BASE_LR: 0.008 37 | IMS_PER_BATCH: 64 38 | WARMUP_METHOD: 'linear' 39 | LARGE_FC_LR: False 40 | CHECKPOINT_PERIOD: 120 41 | LOG_PERIOD: 50 42 | EVAL_PERIOD: 120 43 | 
WEIGHT_DECAY: 1e-4 44 | WEIGHT_DECAY_BIAS: 1e-4 45 | BIAS_LR_FACTOR: 2 46 | 47 | TEST: 48 | EVAL: True 49 | IMS_PER_BATCH: 256 50 | RE_RANKING: False 51 | WEIGHT: '' 52 | NECK_FEAT: 'before' 53 | FEAT_NORM: 'yes' 54 | 55 | OUTPUT_DIR: '../logs/msmt17_vit_sie' 56 | 57 | 58 | -------------------------------------------------------------------------------- /configs/MSMT17/vit_small.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/heshuting/.cache/torch/checkpoints/vit_small_p16_224-15ec54c9.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('0') 10 | TRANSFORMER_TYPE: 'vit_small_patch16_224_TransReID' 11 | STRIDE_SIZE: [16, 16] 12 | 13 | INPUT: 14 | SIZE_TRAIN: [256, 128] 15 | SIZE_TEST: [256, 128] 16 | PROB: 0.5 # random horizontal flip 17 | RE_PROB: 0.8 # random erasing 18 | PADDING: 10 19 | PIXEL_MEAN: [0.5, 0.5, 0.5] 20 | PIXEL_STD: [0.5, 0.5, 0.5] 21 | 22 | DATASETS: 23 | NAMES: ('msmt17') 24 | ROOT_DIR: ('../../data') 25 | 26 | DATALOADER: 27 | SAMPLER: 'softmax_triplet' 28 | NUM_INSTANCE: 4 29 | NUM_WORKERS: 8 30 | 31 | SOLVER: 32 | OPTIMIZER_NAME: 'SGD' 33 | MAX_EPOCHS: 120 34 | BASE_LR: 0.005 35 | IMS_PER_BATCH: 64 36 | WARMUP_METHOD: 'linear' 37 | LARGE_FC_LR: False 38 | CHECKPOINT_PERIOD: 120 39 | LOG_PERIOD: 50 40 | EVAL_PERIOD: 120 41 | WEIGHT_DECAY: 1e-4 42 | WEIGHT_DECAY_BIAS: 1e-4 43 | BIAS_LR_FACTOR: 2 44 | 45 | TEST: 46 | EVAL: True 47 | IMS_PER_BATCH: 256 48 | RE_RANKING: False 49 | WEIGHT: '' 50 | NECK_FEAT: 'before' 51 | FEAT_NORM: 'yes' 52 | 53 | OUTPUT_DIR: '../logs/msmt17_vit_small' 54 | 55 | 56 | -------------------------------------------------------------------------------- /configs/MSMT17/vit_transreid.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | 
PRETRAIN_PATH: '/home/xxx/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('3') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [16, 16] 12 | SIE_CAMERA: True 13 | SIE_COE: 3.0 14 | JPM: True 15 | RE_ARRANGE: True 16 | 17 | INPUT: 18 | SIZE_TRAIN: [256, 128] 19 | SIZE_TEST: [256, 128] 20 | PROB: 0.5 # random horizontal flip 21 | RE_PROB: 0.5 # random erasing 22 | PADDING: 10 23 | PIXEL_MEAN: [0.5, 0.5, 0.5] 24 | PIXEL_STD: [0.5, 0.5, 0.5] 25 | 26 | DATASETS: 27 | NAMES: ('msmt17') 28 | ROOT_DIR: ('../../data') 29 | 30 | DATALOADER: 31 | SAMPLER: 'softmax_triplet' 32 | NUM_INSTANCE: 4 33 | NUM_WORKERS: 8 34 | 35 | SOLVER: 36 | OPTIMIZER_NAME: 'SGD' 37 | MAX_EPOCHS: 120 38 | BASE_LR: 0.008 39 | IMS_PER_BATCH: 64 40 | WARMUP_METHOD: 'linear' 41 | LARGE_FC_LR: False 42 | CHECKPOINT_PERIOD: 120 43 | LOG_PERIOD: 50 44 | EVAL_PERIOD: 120 45 | WEIGHT_DECAY: 1e-4 46 | WEIGHT_DECAY_BIAS: 1e-4 47 | BIAS_LR_FACTOR: 2 48 | 49 | TEST: 50 | EVAL: True 51 | IMS_PER_BATCH: 256 52 | RE_RANKING: False 53 | WEIGHT: '' 54 | NECK_FEAT: 'before' 55 | FEAT_NORM: 'yes' 56 | 57 | OUTPUT_DIR: '../logs/msmt17_vit_transreid' 58 | 59 | 60 | -------------------------------------------------------------------------------- /configs/MSMT17/vit_transreid_384.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/xxx/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('3') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [16, 16] 12 | SIE_CAMERA: True 13 | SIE_COE: 3.0 14 | JPM: True 15 | RE_ARRANGE: True 16 | 17 | INPUT: 18 | SIZE_TRAIN: [384, 128] 19 | 
SIZE_TEST: [384, 128] 20 | PROB: 0.5 # random horizontal flip 21 | RE_PROB: 0.5 # random erasing 22 | PADDING: 10 23 | PIXEL_MEAN: [0.5, 0.5, 0.5] 24 | PIXEL_STD: [0.5, 0.5, 0.5] 25 | 26 | DATASETS: 27 | NAMES: ('msmt17') 28 | ROOT_DIR: ('../../data') 29 | 30 | DATALOADER: 31 | SAMPLER: 'softmax_triplet' 32 | NUM_INSTANCE: 4 33 | NUM_WORKERS: 8 34 | 35 | SOLVER: 36 | OPTIMIZER_NAME: 'SGD' 37 | MAX_EPOCHS: 120 38 | BASE_LR: 0.008 39 | IMS_PER_BATCH: 64 40 | WARMUP_METHOD: 'linear' 41 | LARGE_FC_LR: False 42 | CHECKPOINT_PERIOD: 120 43 | LOG_PERIOD: 50 44 | EVAL_PERIOD: 120 45 | WEIGHT_DECAY: 1e-4 46 | WEIGHT_DECAY_BIAS: 1e-4 47 | BIAS_LR_FACTOR: 2 48 | 49 | TEST: 50 | EVAL: True 51 | IMS_PER_BATCH: 256 52 | RE_RANKING: False 53 | WEIGHT: '' 54 | NECK_FEAT: 'before' 55 | FEAT_NORM: 'yes' 56 | 57 | OUTPUT_DIR: '../logs/msmt17_vit_transreid_384' 58 | 59 | 60 | -------------------------------------------------------------------------------- /configs/MSMT17/vit_transreid_stride.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/heshuting/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('6') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [12, 12] 12 | SIE_CAMERA: True 13 | SIE_COE: 3.0 14 | JPM: True 15 | RE_ARRANGE: True 16 | 17 | INPUT: 18 | SIZE_TRAIN: [256, 128] 19 | SIZE_TEST: [256, 128] 20 | PROB: 0.5 # random horizontal flip 21 | RE_PROB: 0.5 # random erasing 22 | PADDING: 10 23 | PIXEL_MEAN: [0.5, 0.5, 0.5] 24 | PIXEL_STD: [0.5, 0.5, 0.5] 25 | 26 | DATASETS: 27 | NAMES: ('msmt17') 28 | ROOT_DIR: ('/data/zi.wang/') 29 | 30 | DATALOADER: 31 | SAMPLER: 'softmax_triplet' 32 | NUM_INSTANCE: 4 33 | NUM_WORKERS: 8 34 | 35 | SOLVER: 36 | OPTIMIZER_NAME: 'SGD' 37 | MAX_EPOCHS: 120 38 | BASE_LR: 
0.008 39 | IMS_PER_BATCH: 32 40 | WARMUP_METHOD: 'linear' 41 | LARGE_FC_LR: False 42 | CHECKPOINT_PERIOD: 1000 43 | LOG_PERIOD: 1000 44 | EVAL_PERIOD: 120 45 | WEIGHT_DECAY: 1e-4 46 | WEIGHT_DECAY_BIAS: 1e-4 47 | BIAS_LR_FACTOR: 2 48 | 49 | TEST: 50 | EVAL: True 51 | IMS_PER_BATCH: 256 52 | RE_RANKING: False 53 | WEIGHT: '' 54 | NECK_FEAT: 'before' 55 | FEAT_NORM: 'yes' 56 | 57 | OUTPUT_DIR: './logs_msmt17/lr0008_b32_Process1_Model1_loss1' 58 | 59 | 60 | -------------------------------------------------------------------------------- /configs/MSMT17/vit_transreid_stride_384.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/xxx/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('3') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [12, 12] 12 | SIE_CAMERA: True 13 | SIE_COE: 3.0 14 | JPM: True 15 | RE_ARRANGE: True 16 | 17 | INPUT: 18 | SIZE_TRAIN: [384, 128] 19 | SIZE_TEST: [384, 128] 20 | PROB: 0.5 # random horizontal flip 21 | RE_PROB: 0.5 # random erasing 22 | PADDING: 10 23 | PIXEL_MEAN: [0.5, 0.5, 0.5] 24 | PIXEL_STD: [0.5, 0.5, 0.5] 25 | 26 | DATASETS: 27 | NAMES: ('msmt17') 28 | ROOT_DIR: ('../../data') 29 | 30 | DATALOADER: 31 | SAMPLER: 'softmax_triplet' 32 | NUM_INSTANCE: 4 33 | NUM_WORKERS: 8 34 | 35 | SOLVER: 36 | OPTIMIZER_NAME: 'SGD' 37 | MAX_EPOCHS: 120 38 | BASE_LR: 0.008 39 | IMS_PER_BATCH: 64 40 | WARMUP_METHOD: 'linear' 41 | LARGE_FC_LR: False 42 | CHECKPOINT_PERIOD: 120 43 | LOG_PERIOD: 50 44 | EVAL_PERIOD: 120 45 | WEIGHT_DECAY: 1e-4 46 | WEIGHT_DECAY_BIAS: 1e-4 47 | BIAS_LR_FACTOR: 2 48 | 49 | TEST: 50 | EVAL: True 51 | IMS_PER_BATCH: 256 52 | RE_RANKING: False 53 | WEIGHT: '' 54 | NECK_FEAT: 'before' 55 | FEAT_NORM: 'yes' 56 | 57 | OUTPUT_DIR: 
'../logs/msmt17_vit_transreid_stride_384' 58 | 59 | 60 | -------------------------------------------------------------------------------- /configs/Market/deit_transreid_stride.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/heshuting/.cache/torch/checkpoints/deit_base_distilled_patch16_224-df68dfff.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('4') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [12, 12] 12 | SIE_CAMERA: True 13 | SIE_COE: 3.0 14 | JPM: True 15 | RE_ARRANGE: True 16 | 17 | INPUT: 18 | SIZE_TRAIN: [256, 128] 19 | SIZE_TEST: [256, 128] 20 | PROB: 0.5 # random horizontal flip 21 | RE_PROB: 0.8 # random erasing 22 | PADDING: 10 23 | PIXEL_MEAN: [0.5, 0.5, 0.5] 24 | PIXEL_STD: [0.5, 0.5, 0.5] 25 | 26 | DATASETS: 27 | NAMES: ('market1501') 28 | ROOT_DIR: ('../../data') 29 | 30 | DATALOADER: 31 | SAMPLER: 'softmax_triplet' 32 | NUM_INSTANCE: 4 33 | NUM_WORKERS: 8 34 | 35 | SOLVER: 36 | OPTIMIZER_NAME: 'SGD' 37 | MAX_EPOCHS: 120 38 | BASE_LR: 0.008 39 | IMS_PER_BATCH: 64 40 | WARMUP_METHOD: 'linear' 41 | LARGE_FC_LR: False 42 | CHECKPOINT_PERIOD: 120 43 | LOG_PERIOD: 50 44 | EVAL_PERIOD: 120 45 | WEIGHT_DECAY: 1e-4 46 | WEIGHT_DECAY_BIAS: 1e-4 47 | BIAS_LR_FACTOR: 2 48 | 49 | TEST: 50 | EVAL: True 51 | IMS_PER_BATCH: 256 52 | RE_RANKING: False 53 | WEIGHT: '../logs/0321_market_deit_transreie/transformer_120.pth' 54 | NECK_FEAT: 'before' 55 | FEAT_NORM: 'yes' 56 | 57 | OUTPUT_DIR: '../logs/0321_market_deit_transreie' 58 | 59 | 60 | -------------------------------------------------------------------------------- /configs/Market/vit_base.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: 
'/home/heshuting/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('7') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [16, 16] 12 | 13 | INPUT: 14 | SIZE_TRAIN: [256, 128] 15 | SIZE_TEST: [256, 128] 16 | PROB: 0.5 # random horizontal flip 17 | RE_PROB: 0.5 # random erasing 18 | PADDING: 10 19 | PIXEL_MEAN: [0.5, 0.5, 0.5] 20 | PIXEL_STD: [0.5, 0.5, 0.5] 21 | 22 | DATASETS: 23 | NAMES: ('market1501') 24 | ROOT_DIR: ('../../data') 25 | 26 | DATALOADER: 27 | SAMPLER: 'softmax_triplet' 28 | NUM_INSTANCE: 4 29 | NUM_WORKERS: 8 30 | 31 | SOLVER: 32 | OPTIMIZER_NAME: 'SGD' 33 | MAX_EPOCHS: 120 34 | BASE_LR: 0.008 35 | IMS_PER_BATCH: 64 36 | WARMUP_METHOD: 'linear' 37 | LARGE_FC_LR: False 38 | CHECKPOINT_PERIOD: 120 39 | LOG_PERIOD: 50 40 | EVAL_PERIOD: 120 41 | WEIGHT_DECAY: 1e-4 42 | WEIGHT_DECAY_BIAS: 1e-4 43 | BIAS_LR_FACTOR: 2 44 | 45 | TEST: 46 | EVAL: True 47 | IMS_PER_BATCH: 256 48 | RE_RANKING: False 49 | WEIGHT: '../logs/0321_market_vit_base/transformer_120.pth' 50 | NECK_FEAT: 'before' 51 | FEAT_NORM: 'yes' 52 | 53 | OUTPUT_DIR: '../logs/0321_market_vit_base' 54 | 55 | 56 | -------------------------------------------------------------------------------- /configs/Market/vit_jpm.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/xxx/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('1') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [16, 16] 12 | JPM: True 13 | RE_ARRANGE: True 14 | 15 | INPUT: 16 | SIZE_TRAIN: [256, 128] 17 | SIZE_TEST: [256, 128] 18 | PROB: 0.5 # random horizontal flip 19 | RE_PROB: 0.5 # 
random erasing 20 | PADDING: 10 21 | PIXEL_MEAN: [0.5, 0.5, 0.5] 22 | PIXEL_STD: [0.5, 0.5, 0.5] 23 | 24 | DATASETS: 25 | NAMES: ('market1501') 26 | ROOT_DIR: ('../../data') 27 | 28 | DATALOADER: 29 | SAMPLER: 'softmax_triplet' 30 | NUM_INSTANCE: 4 31 | NUM_WORKERS: 8 32 | 33 | SOLVER: 34 | OPTIMIZER_NAME: 'SGD' 35 | MAX_EPOCHS: 120 36 | BASE_LR: 0.008 37 | IMS_PER_BATCH: 64 38 | WARMUP_METHOD: 'linear' 39 | LARGE_FC_LR: False 40 | CHECKPOINT_PERIOD: 120 41 | LOG_PERIOD: 50 42 | EVAL_PERIOD: 120 43 | WEIGHT_DECAY: 1e-4 44 | WEIGHT_DECAY_BIAS: 1e-4 45 | BIAS_LR_FACTOR: 2 46 | 47 | TEST: 48 | EVAL: True 49 | IMS_PER_BATCH: 256 50 | RE_RANKING: False 51 | WEIGHT: '../logs/0321_market_vit_jpm/transformer_120.pth' 52 | NECK_FEAT: 'before' 53 | FEAT_NORM: 'yes' 54 | 55 | OUTPUT_DIR: '../logs/0321_market_vit_jpm' 56 | 57 | 58 | -------------------------------------------------------------------------------- /configs/Market/vit_sie.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/xxx/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('7') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [16, 16] 12 | SIE_CAMERA: True 13 | SIE_COE: 3.0 14 | 15 | INPUT: 16 | SIZE_TRAIN: [256, 128] 17 | SIZE_TEST: [256, 128] 18 | PROB: 0.5 # random horizontal flip 19 | RE_PROB: 0.5 # random erasing 20 | PADDING: 10 21 | PIXEL_MEAN: [0.5, 0.5, 0.5] 22 | PIXEL_STD: [0.5, 0.5, 0.5] 23 | 24 | DATASETS: 25 | NAMES: ('market1501') 26 | ROOT_DIR: ('../../data') 27 | 28 | DATALOADER: 29 | SAMPLER: 'softmax_triplet' 30 | NUM_INSTANCE: 4 31 | NUM_WORKERS: 8 32 | 33 | SOLVER: 34 | OPTIMIZER_NAME: 'SGD' 35 | MAX_EPOCHS: 120 36 | BASE_LR: 0.008 37 | IMS_PER_BATCH: 64 38 | WARMUP_METHOD: 'linear' 39 | LARGE_FC_LR: False 40 | 
CHECKPOINT_PERIOD: 120 41 | LOG_PERIOD: 50 42 | EVAL_PERIOD: 120 43 | WEIGHT_DECAY: 1e-4 44 | WEIGHT_DECAY_BIAS: 1e-4 45 | BIAS_LR_FACTOR: 2 46 | 47 | TEST: 48 | EVAL: True 49 | IMS_PER_BATCH: 256 50 | RE_RANKING: False 51 | WEIGHT: '' 52 | NECK_FEAT: 'before' 53 | FEAT_NORM: 'yes' 54 | 55 | OUTPUT_DIR: '../logs/market_vit_sie' 56 | 57 | 58 | -------------------------------------------------------------------------------- /configs/Market/vit_transreid.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/xxx/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('5') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [16, 16] 12 | SIE_CAMERA: True 13 | SIE_COE: 3.0 14 | JPM: True 15 | RE_ARRANGE: True 16 | 17 | INPUT: 18 | SIZE_TRAIN: [256, 128] 19 | SIZE_TEST: [256, 128] 20 | PROB: 0.5 # random horizontal flip 21 | RE_PROB: 0.5 # random erasing 22 | PADDING: 10 23 | PIXEL_MEAN: [0.5, 0.5, 0.5] 24 | PIXEL_STD: [0.5, 0.5, 0.5] 25 | 26 | DATASETS: 27 | NAMES: ('market1501') 28 | ROOT_DIR: ('../../data') 29 | 30 | DATALOADER: 31 | SAMPLER: 'softmax_triplet' 32 | NUM_INSTANCE: 4 33 | NUM_WORKERS: 8 34 | 35 | SOLVER: 36 | OPTIMIZER_NAME: 'SGD' 37 | MAX_EPOCHS: 120 38 | BASE_LR: 0.008 39 | IMS_PER_BATCH: 64 40 | WARMUP_METHOD: 'linear' 41 | LARGE_FC_LR: False 42 | CHECKPOINT_PERIOD: 120 43 | LOG_PERIOD: 50 44 | EVAL_PERIOD: 120 45 | WEIGHT_DECAY: 1e-4 46 | WEIGHT_DECAY_BIAS: 1e-4 47 | BIAS_LR_FACTOR: 2 48 | 49 | TEST: 50 | EVAL: True 51 | IMS_PER_BATCH: 256 52 | RE_RANKING: False 53 | WEIGHT: '../logs/market_vit_transreid/transformer_120.pth' 54 | NECK_FEAT: 'before' 55 | FEAT_NORM: 'yes' 56 | 57 | OUTPUT_DIR: '../logs/market_vit_transreid' 58 | 59 | 60 | 
-------------------------------------------------------------------------------- /configs/Market/vit_transreid_384.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/xxx/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('5') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [16, 16] 12 | SIE_CAMERA: True 13 | SIE_COE: 3.0 14 | JPM: True 15 | RE_ARRANGE: True 16 | 17 | INPUT: 18 | SIZE_TRAIN: [384, 128] 19 | SIZE_TEST: [384, 128] 20 | PROB: 0.5 # random horizontal flip 21 | RE_PROB: 0.5 # random erasing 22 | PADDING: 10 23 | PIXEL_MEAN: [0.5, 0.5, 0.5] 24 | PIXEL_STD: [0.5, 0.5, 0.5] 25 | 26 | DATASETS: 27 | NAMES: ('market1501') 28 | ROOT_DIR: ('../../data') 29 | 30 | DATALOADER: 31 | SAMPLER: 'softmax_triplet' 32 | NUM_INSTANCE: 4 33 | NUM_WORKERS: 8 34 | 35 | SOLVER: 36 | OPTIMIZER_NAME: 'SGD' 37 | MAX_EPOCHS: 120 38 | BASE_LR: 0.008 39 | IMS_PER_BATCH: 64 40 | WARMUP_METHOD: 'linear' 41 | LARGE_FC_LR: False 42 | CHECKPOINT_PERIOD: 120 43 | LOG_PERIOD: 50 44 | EVAL_PERIOD: 120 45 | WEIGHT_DECAY: 1e-4 46 | WEIGHT_DECAY_BIAS: 1e-4 47 | BIAS_LR_FACTOR: 2 48 | 49 | TEST: 50 | EVAL: True 51 | IMS_PER_BATCH: 256 52 | RE_RANKING: False 53 | WEIGHT: '../logs/market_vit_transreid_384/transformer_120.pth' 54 | NECK_FEAT: 'before' 55 | FEAT_NORM: 'yes' 56 | 57 | OUTPUT_DIR: '../logs/market_vit_transreid_384' 58 | 59 | 60 | -------------------------------------------------------------------------------- /configs/Market/vit_transreid_stride.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/zi.wang/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | 
IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('5') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [12, 12] 12 | SIE_CAMERA: True 13 | SIE_COE: 3.0 14 | JPM: True 15 | RE_ARRANGE: True 16 | 17 | INPUT: 18 | SIZE_TRAIN: [256, 128] 19 | SIZE_TEST: [256, 128] 20 | PROB: 0.5 # random horizontal flip 21 | RE_PROB: 0.5 # random erasing 22 | PADDING: 10 23 | PIXEL_MEAN: [0.5, 0.5, 0.5] 24 | PIXEL_STD: [0.5, 0.5, 0.5] 25 | 26 | DATASETS: 27 | NAMES: ('market1501') 28 | ROOT_DIR: ('/data2/zi.wang/') 29 | 30 | DATALOADER: 31 | SAMPLER: 'softmax_triplet' 32 | NUM_INSTANCE: 4 33 | NUM_WORKERS: 8 34 | 35 | SOLVER: 36 | OPTIMIZER_NAME: 'SGD' 37 | MAX_EPOCHS: 300 38 | BASE_LR: 0.008 39 | IMS_PER_BATCH: 32 40 | WARMUP_METHOD: 'linear' 41 | LARGE_FC_LR: False 42 | CHECKPOINT_PERIOD: 10000 43 | LOG_PERIOD: 100 44 | EVAL_PERIOD: 5 45 | WEIGHT_DECAY: 1e-4 46 | WEIGHT_DECAY_BIAS: 1e-4 47 | BIAS_LR_FACTOR: 2 48 | 49 | TEST: 50 | EVAL: True 51 | IMS_PER_BATCH: 256 52 | RE_RANKING: False 53 | WEIGHT: '' 54 | NECK_FEAT: 'before' 55 | FEAT_NORM: 'yes' 56 | 57 | OUTPUT_DIR: './logs_market/lr0008_b32_Process1_Model12_loss1_AO' 58 | 59 | 60 | -------------------------------------------------------------------------------- /configs/Market/vit_transreid_stride_384.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/xxx/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('3') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [12, 12] 12 | SIE_CAMERA: True 13 | SIE_COE: 3.0 14 | JPM: True 15 | RE_ARRANGE: True 16 | 17 | INPUT: 18 | SIZE_TRAIN: [384, 128] 19 | SIZE_TEST: [384, 128] 20 | PROB: 0.5 # random horizontal flip 21 | RE_PROB: 0.5 # 
random erasing 22 | PADDING: 10 23 | PIXEL_MEAN: [0.5, 0.5, 0.5] 24 | PIXEL_STD: [0.5, 0.5, 0.5] 25 | 26 | DATASETS: 27 | NAMES: ('market1501') 28 | ROOT_DIR: ('../../data') 29 | 30 | DATALOADER: 31 | SAMPLER: 'softmax_triplet' 32 | NUM_INSTANCE: 4 33 | NUM_WORKERS: 8 34 | 35 | SOLVER: 36 | OPTIMIZER_NAME: 'SGD' 37 | MAX_EPOCHS: 120 38 | BASE_LR: 0.008 39 | IMS_PER_BATCH: 64 40 | WARMUP_METHOD: 'linear' 41 | LARGE_FC_LR: False 42 | CHECKPOINT_PERIOD: 120 43 | LOG_PERIOD: 50 44 | EVAL_PERIOD: 120 45 | WEIGHT_DECAY: 1e-4 46 | WEIGHT_DECAY_BIAS: 1e-4 47 | BIAS_LR_FACTOR: 2 48 | 49 | TEST: 50 | EVAL: True 51 | IMS_PER_BATCH: 256 52 | RE_RANKING: False 53 | WEIGHT: '' 54 | NECK_FEAT: 'before' 55 | FEAT_NORM: 'yes' 56 | 57 | OUTPUT_DIR: '../logs/0321_market_vit_transreid_stride_384' 58 | 59 | 60 | -------------------------------------------------------------------------------- /configs/OCC_Duke/deit_transreid_stride.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/heshuting/.cache/torch/checkpoints/deit_base_distilled_patch16_224-df68dfff.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('2') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [11, 11] 12 | SIE_CAMERA: True 13 | SIE_COE: 3.0 14 | JPM: True 15 | RE_ARRANGE: True 16 | 17 | INPUT: 18 | SIZE_TRAIN: [256, 128] 19 | SIZE_TEST: [256, 128] 20 | PROB: 0.5 # random horizontal flip 21 | RE_PROB: 0.8 # random erasing 22 | PADDING: 10 23 | PIXEL_MEAN: [0.5, 0.5, 0.5] 24 | PIXEL_STD: [0.5, 0.5, 0.5] 25 | 26 | DATASETS: 27 | NAMES: ('occ_duke') 28 | ROOT_DIR: ('../../data') 29 | 30 | DATALOADER: 31 | SAMPLER: 'softmax_triplet' 32 | NUM_INSTANCE: 4 33 | NUM_WORKERS: 8 34 | 35 | SOLVER: 36 | OPTIMIZER_NAME: 'SGD' 37 | MAX_EPOCHS: 120 38 | BASE_LR: 0.008 39 | IMS_PER_BATCH: 64 40 | WARMUP_METHOD: 
'linear' 41 | LARGE_FC_LR: False 42 | CHECKPOINT_PERIOD: 120 43 | LOG_PERIOD: 50 44 | EVAL_PERIOD: 120 45 | WEIGHT_DECAY: 1e-4 46 | WEIGHT_DECAY_BIAS: 1e-4 47 | BIAS_LR_FACTOR: 2 48 | 49 | TEST: 50 | EVAL: True 51 | IMS_PER_BATCH: 256 52 | RE_RANKING: False 53 | WEIGHT: '' 54 | NECK_FEAT: 'before' 55 | FEAT_NORM: 'yes' 56 | 57 | OUTPUT_DIR: '../logs/occ_duke_deit_transreid_stride11' 58 | 59 | 60 | -------------------------------------------------------------------------------- /configs/OCC_Duke/osnet.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/zi.wang/.cache/torch/checkpoints/osnet_x0_5_imagenet.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'osnet' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('5') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [16, 16] 12 | 13 | INPUT: 14 | SIZE_TRAIN: [256, 128] 15 | SIZE_TEST: [256, 128] 16 | PROB: 0.5 # random horizontal flip 17 | RE_PROB: 0.5 # random erasing 18 | PADDING: 10 19 | PIXEL_MEAN: [0.5, 0.5, 0.5] 20 | PIXEL_STD: [0.5, 0.5, 0.5] 21 | 22 | DATASETS: 23 | NAMES: ('occ_duke') 24 | ROOT_DIR: ('/data2/zi.wang/') 25 | 26 | DATALOADER: 27 | SAMPLER: 'softmax_triplet' 28 | NUM_INSTANCE: 4 29 | NUM_WORKERS: 8 30 | 31 | SOLVER: 32 | OPTIMIZER_NAME: 'SGD' 33 | MAX_EPOCHS: 170 34 | BASE_LR: 0.008 35 | IMS_PER_BATCH: 32 36 | WARMUP_METHOD: 'linear' 37 | LARGE_FC_LR: False 38 | CHECKPOINT_PERIOD: 1000 39 | LOG_PERIOD: 200 40 | EVAL_PERIOD: 5 41 | WEIGHT_DECAY: 1e-4 42 | WEIGHT_DECAY_BIAS: 1e-4 43 | BIAS_LR_FACTOR: 2 44 | 45 | TEST: 46 | EVAL: True 47 | IMS_PER_BATCH: 256 48 | RE_RANKING: False 49 | WEIGHT: '' 50 | NECK_FEAT: 'before' 51 | FEAT_NORM: 'yes' 52 | 53 | OUTPUT_DIR: './logs_occ_duke/osnet/lr0008_b32_Process1_Model1_loss1' 54 | 55 | 56 | -------------------------------------------------------------------------------- 
/configs/OCC_Duke/resnet.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/zi.wang/.cache/torch/hub/checkpoints/resnet50-19c8e357.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'resnet50' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('5') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [16, 16] 12 | 13 | INPUT: 14 | SIZE_TRAIN: [256, 128] 15 | SIZE_TEST: [256, 128] 16 | PROB: 0.5 # random horizontal flip 17 | RE_PROB: 0.5 # random erasing 18 | PADDING: 10 19 | PIXEL_MEAN: [0.5, 0.5, 0.5] 20 | PIXEL_STD: [0.5, 0.5, 0.5] 21 | 22 | DATASETS: 23 | NAMES: ('occ_duke') 24 | ROOT_DIR: ('/data2/zi.wang/') 25 | 26 | DATALOADER: 27 | SAMPLER: 'softmax_triplet' 28 | NUM_INSTANCE: 4 29 | NUM_WORKERS: 8 30 | 31 | SOLVER: 32 | OPTIMIZER_NAME: 'SGD' 33 | MAX_EPOCHS: 170 34 | BASE_LR: 0.008 35 | IMS_PER_BATCH: 32 36 | WARMUP_METHOD: 'linear' 37 | LARGE_FC_LR: False 38 | CHECKPOINT_PERIOD: 1000 39 | LOG_PERIOD: 200 40 | EVAL_PERIOD: 5 41 | WEIGHT_DECAY: 1e-4 42 | WEIGHT_DECAY_BIAS: 1e-4 43 | BIAS_LR_FACTOR: 2 44 | 45 | TEST: 46 | EVAL: True 47 | IMS_PER_BATCH: 256 48 | RE_RANKING: False 49 | WEIGHT: '' 50 | NECK_FEAT: 'before' 51 | FEAT_NORM: 'yes' 52 | 53 | OUTPUT_DIR: './logs_occ_duke/resnet/lr0008_b32_Process1_Model1_loss1' 54 | 55 | 56 | -------------------------------------------------------------------------------- /configs/OCC_Duke/vit_base.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/zi.wang/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('5') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [16, 16] 12 | 13 | INPUT: 14 | 
SIZE_TRAIN: [256, 128] 15 | SIZE_TEST: [256, 128] 16 | PROB: 0.5 # random horizontal flip 17 | RE_PROB: 0.5 # random erasing 18 | PADDING: 10 19 | PIXEL_MEAN: [0.5, 0.5, 0.5] 20 | PIXEL_STD: [0.5, 0.5, 0.5] 21 | 22 | DATASETS: 23 | NAMES: ('occ_duke') 24 | ROOT_DIR: ('/data2/zi.wang/') 25 | 26 | DATALOADER: 27 | SAMPLER: 'softmax_triplet' 28 | NUM_INSTANCE: 4 29 | NUM_WORKERS: 8 30 | 31 | SOLVER: 32 | OPTIMIZER_NAME: 'SGD' 33 | MAX_EPOCHS: 120 34 | BASE_LR: 0.008 35 | IMS_PER_BATCH: 32 36 | WARMUP_METHOD: 'linear' 37 | LARGE_FC_LR: False 38 | CHECKPOINT_PERIOD: 1000 39 | LOG_PERIOD: 50 40 | EVAL_PERIOD: 5 41 | WEIGHT_DECAY: 1e-4 42 | WEIGHT_DECAY_BIAS: 1e-4 43 | BIAS_LR_FACTOR: 2 44 | 45 | TEST: 46 | EVAL: True 47 | IMS_PER_BATCH: 256 48 | RE_RANKING: False 49 | WEIGHT: '' 50 | NECK_FEAT: 'before' 51 | FEAT_NORM: 'yes' 52 | 53 | OUTPUT_DIR: './logs_occ_duke/vit_base/lr0008_b32_Process1_Model1_loss4' 54 | 55 | 56 | -------------------------------------------------------------------------------- /configs/OCC_Duke/vit_jpm.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/xxx/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('1') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [16, 16] 12 | JPM: True 13 | RE_ARRANGE: True 14 | 15 | INPUT: 16 | SIZE_TRAIN: [256, 128] 17 | SIZE_TEST: [256, 128] 18 | PROB: 0.5 # random horizontal flip 19 | RE_PROB: 0.5 # random erasing 20 | PADDING: 10 21 | PIXEL_MEAN: [0.5, 0.5, 0.5] 22 | PIXEL_STD: [0.5, 0.5, 0.5] 23 | 24 | DATASETS: 25 | NAMES: ('occ_duke') 26 | ROOT_DIR: ('../../data') 27 | 28 | DATALOADER: 29 | SAMPLER: 'softmax_triplet' 30 | NUM_INSTANCE: 4 31 | NUM_WORKERS: 8 32 | 33 | SOLVER: 34 | OPTIMIZER_NAME: 'SGD' 35 | MAX_EPOCHS: 120 36 | BASE_LR: 
0.008 37 | IMS_PER_BATCH: 64 38 | WARMUP_METHOD: 'linear' 39 | LARGE_FC_LR: False 40 | CHECKPOINT_PERIOD: 120 41 | LOG_PERIOD: 50 42 | EVAL_PERIOD: 120 43 | WEIGHT_DECAY: 1e-4 44 | WEIGHT_DECAY_BIAS: 1e-4 45 | BIAS_LR_FACTOR: 2 46 | 47 | TEST: 48 | EVAL: True 49 | IMS_PER_BATCH: 256 50 | RE_RANKING: False 51 | WEIGHT: '' 52 | NECK_FEAT: 'before' 53 | FEAT_NORM: 'yes' 54 | 55 | OUTPUT_DIR: '../logs/occ_duke_vit_jpm' 56 | 57 | 58 | -------------------------------------------------------------------------------- /configs/OCC_Duke/vit_sie.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/xxx/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('2') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [16, 16] 12 | SIE_CAMERA: True 13 | SIE_COE: 3.0 14 | 15 | INPUT: 16 | SIZE_TRAIN: [256, 128] 17 | SIZE_TEST: [256, 128] 18 | PROB: 0.5 # random horizontal flip 19 | RE_PROB: 0.5 # random erasing 20 | PADDING: 10 21 | PIXEL_MEAN: [0.5, 0.5, 0.5] 22 | PIXEL_STD: [0.5, 0.5, 0.5] 23 | 24 | DATASETS: 25 | NAMES: ('occ_duke') 26 | ROOT_DIR: ('../../data') 27 | 28 | DATALOADER: 29 | SAMPLER: 'softmax_triplet' 30 | NUM_INSTANCE: 4 31 | NUM_WORKERS: 8 32 | 33 | SOLVER: 34 | OPTIMIZER_NAME: 'SGD' 35 | MAX_EPOCHS: 120 36 | BASE_LR: 0.008 37 | IMS_PER_BATCH: 64 38 | WARMUP_METHOD: 'linear' 39 | LARGE_FC_LR: False 40 | CHECKPOINT_PERIOD: 120 41 | LOG_PERIOD: 50 42 | EVAL_PERIOD: 120 43 | WEIGHT_DECAY: 1e-4 44 | WEIGHT_DECAY_BIAS: 1e-4 45 | BIAS_LR_FACTOR: 2 46 | 47 | TEST: 48 | EVAL: True 49 | IMS_PER_BATCH: 256 50 | RE_RANKING: False 51 | WEIGHT: '' 52 | NECK_FEAT: 'before' 53 | FEAT_NORM: 'yes' 54 | 55 | OUTPUT_DIR: '../logs/occ_duke_vit_sie' 56 | 57 | 58 | 
-------------------------------------------------------------------------------- /configs/OCC_Duke/vit_transreid.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/xxx/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('3') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [16, 16] 12 | SIE_CAMERA: True 13 | SIE_COE: 3.0 14 | JPM: True 15 | RE_ARRANGE: True 16 | 17 | INPUT: 18 | SIZE_TRAIN: [256, 128] 19 | SIZE_TEST: [256, 128] 20 | PROB: 0.5 # random horizontal flip 21 | RE_PROB: 0.5 # random erasing 22 | PADDING: 10 23 | PIXEL_MEAN: [0.5, 0.5, 0.5] 24 | PIXEL_STD: [0.5, 0.5, 0.5] 25 | 26 | DATASETS: 27 | NAMES: ('occ_duke') 28 | ROOT_DIR: ('../../data') 29 | 30 | DATALOADER: 31 | SAMPLER: 'softmax_triplet' 32 | NUM_INSTANCE: 4 33 | NUM_WORKERS: 8 34 | 35 | SOLVER: 36 | OPTIMIZER_NAME: 'SGD' 37 | MAX_EPOCHS: 120 38 | BASE_LR: 0.008 39 | IMS_PER_BATCH: 64 40 | WARMUP_METHOD: 'linear' 41 | LARGE_FC_LR: False 42 | CHECKPOINT_PERIOD: 120 43 | LOG_PERIOD: 50 44 | EVAL_PERIOD: 120 45 | WEIGHT_DECAY: 1e-4 46 | WEIGHT_DECAY_BIAS: 1e-4 47 | BIAS_LR_FACTOR: 2 48 | 49 | TEST: 50 | EVAL: True 51 | IMS_PER_BATCH: 256 52 | RE_RANKING: False 53 | WEIGHT: '' 54 | NECK_FEAT: 'before' 55 | FEAT_NORM: 'yes' 56 | 57 | OUTPUT_DIR: '../logs/occ_duke_vit_transreid' 58 | 59 | 60 | -------------------------------------------------------------------------------- /configs/OCC_Duke/vit_transreid_stride.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/zi.wang/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 
'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('3') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [11, 11] 12 | SIE_CAMERA: True 13 | SIE_COE: 3.0 14 | JPM: True 15 | RE_ARRANGE: True 16 | 17 | INPUT: 18 | SIZE_TRAIN: [256, 128] 19 | SIZE_TEST: [256, 128] 20 | PROB: 0.5 # random horizontal flip 21 | RE_PROB: 0.5 # random erasing 22 | PADDING: 10 23 | PIXEL_MEAN: [0.5, 0.5, 0.5] 24 | PIXEL_STD: [0.5, 0.5, 0.5] 25 | 26 | DATASETS: 27 | NAMES: ('occ_duke') 28 | ROOT_DIR: ('/data2/zi.wang/') 29 | 30 | DATALOADER: 31 | SAMPLER: 'softmax_triplet' 32 | NUM_INSTANCE: 4 33 | NUM_WORKERS: 8 34 | 35 | SOLVER: 36 | OPTIMIZER_NAME: 'SGD' 37 | MAX_EPOCHS: 170 38 | BASE_LR: 0.008 39 | IMS_PER_BATCH: 32 40 | WARMUP_METHOD: 'linear' 41 | LARGE_FC_LR: False 42 | CHECKPOINT_PERIOD: 1000 43 | LOG_PERIOD: 200 44 | EVAL_PERIOD: 5 45 | WEIGHT_DECAY: 1e-4 46 | WEIGHT_DECAY_BIAS: 1e-4 47 | BIAS_LR_FACTOR: 2 48 | STEPS: (40, 70) 49 | 50 | TEST: 51 | EVAL: True 52 | IMS_PER_BATCH: 256 53 | RE_RANKING: False 54 | WEIGHT: '' 55 | NECK_FEAT: 'before' 56 | FEAT_NORM: 'yes' 57 | 58 | OUTPUT_DIR: './logs_occ_duke/lr0008_b32' 59 | 60 | 61 | -------------------------------------------------------------------------------- /configs/OCC_ReID/vit_base.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/zi.wang/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('5') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [16, 16] 12 | 13 | INPUT: 14 | SIZE_TRAIN: [256, 128] 15 | SIZE_TEST: [256, 128] 16 | PROB: 0.5 # random horizontal flip 17 | RE_PROB: 0.5 # random erasing 18 | PADDING: 10 19 | PIXEL_MEAN: [0.5, 0.5, 0.5] 20 | PIXEL_STD: [0.5, 0.5, 0.5] 21 | 22 | DATASETS: 23 | NAMES: ('occ_reid') 24 | 
ROOT_DIR: ('/data/zi.wang/') 25 | 26 | DATALOADER: 27 | SAMPLER: 'softmax_triplet' 28 | NUM_INSTANCE: 4 29 | NUM_WORKERS: 8 30 | 31 | SOLVER: 32 | OPTIMIZER_NAME: 'SGD' 33 | MAX_EPOCHS: 120 34 | BASE_LR: 0.004 35 | IMS_PER_BATCH: 32 36 | WARMUP_METHOD: 'linear' 37 | LARGE_FC_LR: False 38 | CHECKPOINT_PERIOD: 1000 39 | LOG_PERIOD: 100 40 | EVAL_PERIOD: 1 41 | WEIGHT_DECAY: 1e-4 42 | WEIGHT_DECAY_BIAS: 1e-4 43 | BIAS_LR_FACTOR: 2 44 | 45 | TEST: 46 | EVAL: True 47 | IMS_PER_BATCH: 256 48 | RE_RANKING: False 49 | WEIGHT: '' 50 | NECK_FEAT: 'before' 51 | FEAT_NORM: 'yes' 52 | 53 | OUTPUT_DIR: './logs_occ_reid/vit_base_load_lr0004_b32_trainAll' 54 | 55 | 56 | -------------------------------------------------------------------------------- /configs/OCC_ReID/vit_local.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/zi.wang/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('3') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [11, 11] 12 | JPM: True 13 | 14 | INPUT: 15 | SIZE_TRAIN: [256, 128] 16 | SIZE_TEST: [256, 128] 17 | PROB: 0.5 # random horizontal flip 18 | RE_PROB: 0.5 # random erasing 19 | PADDING: 10 20 | PIXEL_MEAN: [0.5, 0.5, 0.5] 21 | PIXEL_STD: [0.5, 0.5, 0.5] 22 | 23 | DATASETS: 24 | NAMES: ('occ_reid') 25 | ROOT_DIR: ('/data2/zi.wang/') 26 | 27 | DATALOADER: 28 | SAMPLER: 'softmax_triplet' 29 | NUM_INSTANCE: 4 30 | NUM_WORKERS: 8 31 | 32 | SOLVER: 33 | OPTIMIZER_NAME: 'SGD' 34 | MAX_EPOCHS: 120 35 | BASE_LR: 0.008 36 | IMS_PER_BATCH: 32 37 | WARMUP_METHOD: 'linear' 38 | LARGE_FC_LR: False 39 | CHECKPOINT_PERIOD: 120 40 | LOG_PERIOD: 100 41 | EVAL_PERIOD: 1 42 | WEIGHT_DECAY: 1e-4 43 | WEIGHT_DECAY_BIAS: 1e-4 44 | BIAS_LR_FACTOR: 2 45 | 46 | TEST: 47 | EVAL: True 48 | 
IMS_PER_BATCH: 256 49 | RE_RANKING: False 50 | WEIGHT: '' 51 | NECK_FEAT: 'before' 52 | FEAT_NORM: 'yes' 53 | 54 | OUTPUT_DIR: './logs_occ_reid/vit_local_lr0008_b32_trainAll' 55 | 56 | 57 | -------------------------------------------------------------------------------- /configs/OCC_ReID/vit_small.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/zi.wang/.cache/torch/checkpoints/vit_small_p16_224-15ec54c9.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('5') 10 | TRANSFORMER_TYPE: 'vit_small_patch16_224_TransReID' 11 | STRIDE_SIZE: [16, 16] 12 | 13 | INPUT: 14 | SIZE_TRAIN: [256, 128] 15 | SIZE_TEST: [256, 128] 16 | PROB: 0.5 # random horizontal flip 17 | RE_PROB: 0.5 # random erasing 18 | PADDING: 10 19 | PIXEL_MEAN: [0.5, 0.5, 0.5] 20 | PIXEL_STD: [0.5, 0.5, 0.5] 21 | 22 | DATASETS: 23 | NAMES: ('occ_reid') 24 | ROOT_DIR: ('/data/zi.wang/') 25 | 26 | DATALOADER: 27 | SAMPLER: 'softmax_triplet' 28 | NUM_INSTANCE: 4 29 | NUM_WORKERS: 8 30 | 31 | SOLVER: 32 | OPTIMIZER_NAME: 'SGD' 33 | MAX_EPOCHS: 80 34 | BASE_LR: 0.004 35 | IMS_PER_BATCH: 32 36 | WARMUP_METHOD: 'linear' 37 | LARGE_FC_LR: False 38 | CHECKPOINT_PERIOD: 120 39 | LOG_PERIOD: 100 40 | EVAL_PERIOD: 1 41 | WEIGHT_DECAY: 1e-4 42 | WEIGHT_DECAY_BIAS: 1e-4 43 | BIAS_LR_FACTOR: 2 44 | 45 | TEST: 46 | EVAL: True 47 | IMS_PER_BATCH: 256 48 | RE_RANKING: False 49 | WEIGHT: '' 50 | NECK_FEAT: 'before' 51 | FEAT_NORM: 'yes' 52 | 53 | OUTPUT_DIR: './logs_occ_reid/vit_small_load_lr0004_b32_KL1_ori_and_eraser' 54 | 55 | 56 | -------------------------------------------------------------------------------- /configs/OCC_ReID/vit_transreid_stride.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: 
'/home/zi.wang/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('3') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [11, 11] 12 | SIE_CAMERA: True 13 | SIE_COE: 3.0 14 | JPM: True 15 | RE_ARRANGE: True 16 | 17 | INPUT: 18 | SIZE_TRAIN: [256, 128] 19 | SIZE_TEST: [256, 128] 20 | PROB: 0.5 # random horizontal flip 21 | RE_PROB: 0.5 # random erasing 22 | PADDING: 10 23 | PIXEL_MEAN: [0.5, 0.5, 0.5] 24 | PIXEL_STD: [0.5, 0.5, 0.5] 25 | 26 | DATASETS: 27 | NAMES: ('occ_reid') 28 | ROOT_DIR: ('/data2/zi.wang/') 29 | 30 | DATALOADER: 31 | SAMPLER: 'softmax_triplet' 32 | NUM_INSTANCE: 4 33 | NUM_WORKERS: 8 34 | 35 | SOLVER: 36 | OPTIMIZER_NAME: 'SGD' 37 | MAX_EPOCHS: 8 38 | BASE_LR: 0.0001 39 | IMS_PER_BATCH: 32 40 | WARMUP_METHOD: 'linear' 41 | LARGE_FC_LR: False 42 | CHECKPOINT_PERIOD: 1000 43 | LOG_PERIOD: 100 44 | EVAL_PERIOD: 1 45 | WEIGHT_DECAY: 1e-4 46 | WEIGHT_DECAY_BIAS: 1e-4 47 | BIAS_LR_FACTOR: 2 48 | 49 | TEST: 50 | EVAL: True 51 | IMS_PER_BATCH: 256 52 | RE_RANKING: False 53 | WEIGHT: '' 54 | NECK_FEAT: 'before' 55 | FEAT_NORM: 'yes' 56 | 57 | OUTPUT_DIR: './logs_occ_reid/lr00001_b32_Process1_Model1_loss1' 58 | 59 | 60 | -------------------------------------------------------------------------------- /configs/Partial_ReID/vit_base.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/zi.wang/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('5') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [16, 16] 12 | 13 | INPUT: 14 | SIZE_TRAIN: [256, 128] 15 | SIZE_TEST: [256, 128] 16 | PROB: 0.5 # random horizontal 
flip 17 | RE_PROB: 0.5 # random erasing 18 | PADDING: 10 19 | PIXEL_MEAN: [0.5, 0.5, 0.5] 20 | PIXEL_STD: [0.5, 0.5, 0.5] 21 | 22 | DATASETS: 23 | NAMES: ('partial_reid') 24 | ROOT_DIR: ('/data/zi.wang/') 25 | 26 | DATALOADER: 27 | SAMPLER: 'softmax_triplet' 28 | NUM_INSTANCE: 4 29 | NUM_WORKERS: 8 30 | 31 | SOLVER: 32 | OPTIMIZER_NAME: 'SGD' 33 | MAX_EPOCHS: 80 34 | BASE_LR: 0.008 35 | IMS_PER_BATCH: 32 36 | WARMUP_METHOD: 'linear' 37 | LARGE_FC_LR: False 38 | CHECKPOINT_PERIOD: 1000 39 | LOG_PERIOD: 50 40 | EVAL_PERIOD: 1 41 | WEIGHT_DECAY: 0.0001 42 | WEIGHT_DECAY_BIAS: 0.0001 43 | BIAS_LR_FACTOR: 2 44 | 45 | TEST: 46 | EVAL: True 47 | IMS_PER_BATCH: 256 48 | RE_RANKING: False 49 | WEIGHT: '' 50 | NECK_FEAT: 'before' 51 | FEAT_NORM: 'yes' 52 | 53 | OUTPUT_DIR: './logs_partial_reid/vit_base_load_lr0008_b32_trainAll' 54 | 55 | 56 | -------------------------------------------------------------------------------- /configs/Partial_ReID/vit_local.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/zi.wang/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('3') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [11, 11] 12 | JPM: True 13 | 14 | INPUT: 15 | SIZE_TRAIN: [256, 128] 16 | SIZE_TEST: [256, 128] 17 | PROB: 0.5 # random horizontal flip 18 | RE_PROB: 0.5 # random erasing 19 | PADDING: 10 20 | PIXEL_MEAN: [0.5, 0.5, 0.5] 21 | PIXEL_STD: [0.5, 0.5, 0.5] 22 | 23 | DATASETS: 24 | NAMES: ('partial_reid') 25 | ROOT_DIR: ('/data2/zi.wang/') 26 | 27 | DATALOADER: 28 | SAMPLER: 'softmax_triplet' 29 | NUM_INSTANCE: 4 30 | NUM_WORKERS: 8 31 | 32 | SOLVER: 33 | OPTIMIZER_NAME: 'SGD' 34 | MAX_EPOCHS: 80 35 | BASE_LR: 0.008 36 | IMS_PER_BATCH: 32 37 | WARMUP_METHOD: 'linear' 38 | LARGE_FC_LR: False 39 | 
CHECKPOINT_PERIOD: 120 40 | LOG_PERIOD: 100 41 | EVAL_PERIOD: 1 42 | WEIGHT_DECAY: 1e-4 43 | WEIGHT_DECAY_BIAS: 1e-4 44 | BIAS_LR_FACTOR: 2 45 | 46 | TEST: 47 | EVAL: True 48 | IMS_PER_BATCH: 256 49 | RE_RANKING: False 50 | WEIGHT: '' 51 | NECK_FEAT: 'before' 52 | FEAT_NORM: 'yes' 53 | 54 | OUTPUT_DIR: './logs_partial_reid/vit_local_stride' 55 | 56 | 57 | -------------------------------------------------------------------------------- /configs/Partial_ReID/vit_transreid_stride.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/zi.wang/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('3') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [11, 11] 12 | SIE_CAMERA: True 13 | SIE_COE: 3.0 14 | JPM: True 15 | RE_ARRANGE: True 16 | 17 | INPUT: 18 | SIZE_TRAIN: [256, 128] 19 | SIZE_TEST: [256, 128] 20 | PROB: 0.5 # random horizontal flip 21 | RE_PROB: 0.5 # random erasing 22 | PADDING: 10 23 | PIXEL_MEAN: [0.5, 0.5, 0.5] 24 | PIXEL_STD: [0.5, 0.5, 0.5] 25 | 26 | DATASETS: 27 | NAMES: ('partial_reid') 28 | ROOT_DIR: ('/data2/zi.wang/') 29 | 30 | DATALOADER: 31 | SAMPLER: 'softmax_triplet' 32 | NUM_INSTANCE: 4 33 | NUM_WORKERS: 8 34 | 35 | SOLVER: 36 | OPTIMIZER_NAME: 'SGD' 37 | MAX_EPOCHS: 8 38 | BASE_LR: 0.0001 39 | IMS_PER_BATCH: 32 40 | WARMUP_METHOD: 'linear' 41 | LARGE_FC_LR: False 42 | CHECKPOINT_PERIOD: 1000 43 | LOG_PERIOD: 100 44 | EVAL_PERIOD: 1 45 | WEIGHT_DECAY: 1e-4 46 | WEIGHT_DECAY_BIAS: 1e-4 47 | BIAS_LR_FACTOR: 2 48 | 49 | TEST: 50 | EVAL: True 51 | IMS_PER_BATCH: 256 52 | RE_RANKING: False 53 | WEIGHT: '' 54 | NECK_FEAT: 'before' 55 | FEAT_NORM: 'yes' 56 | 57 | OUTPUT_DIR: './logs_partial_reid/lr00001_b32_Process1_Model12_loss1' 58 | 59 | 
-------------------------------------------------------------------------------- /configs/VeRi/deit_transreid.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/heshuting/.cache/torch/checkpoints/deit_base_distilled_patch16_224-df68dfff.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('4') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [16, 16] 12 | SIE_CAMERA: True 13 | SIE_VIEW: True 14 | SIE_COE: 3.0 15 | JPM: True 16 | SHIFT_NUM: 8 17 | RE_ARRANGE: True 18 | 19 | INPUT: 20 | SIZE_TRAIN: [256, 256] 21 | SIZE_TEST: [256, 256] 22 | PROB: 0.5 # random horizontal flip 23 | RE_PROB: 0.8 # random erasing 24 | PADDING: 10 25 | 26 | DATASETS: 27 | NAMES: ('veri') 28 | ROOT_DIR: ('../../data') 29 | 30 | DATALOADER: 31 | SAMPLER: 'softmax_triplet' 32 | NUM_INSTANCE: 4 33 | NUM_WORKERS: 8 34 | 35 | SOLVER: 36 | OPTIMIZER_NAME: 'SGD' 37 | MAX_EPOCHS: 120 38 | BASE_LR: 0.01 39 | IMS_PER_BATCH: 64 40 | WARMUP_METHOD: 'linear' 41 | LARGE_FC_LR: False 42 | CHECKPOINT_PERIOD: 120 43 | LOG_PERIOD: 50 44 | EVAL_PERIOD: 120 45 | WEIGHT_DECAY: 1e-4 46 | WEIGHT_DECAY_BIAS: 1e-4 47 | BIAS_LR_FACTOR: 2 48 | 49 | TEST: 50 | EVAL: True 51 | IMS_PER_BATCH: 256 52 | RE_RANKING: False 53 | WEIGHT: '' 54 | NECK_FEAT: 'before' 55 | FEAT_NORM: 'yes' 56 | 57 | OUTPUT_DIR: '../logs/veri_deit_transreid' 58 | 59 | 60 | -------------------------------------------------------------------------------- /configs/VeRi/deit_transreid_stride.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/heshuting/.cache/torch/checkpoints/deit_base_distilled_patch16_224-df68dfff.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 
| NO_MARGIN: True 9 | DEVICE_ID: ('0') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [12, 12] 12 | SIE_CAMERA: True 13 | SIE_VIEW: True 14 | SIE_COE: 3.0 15 | JPM: True 16 | SHIFT_NUM: 8 17 | RE_ARRANGE: True 18 | 19 | INPUT: 20 | SIZE_TRAIN: [256, 256] 21 | SIZE_TEST: [256, 256] 22 | PROB: 0.5 # random horizontal flip 23 | RE_PROB: 0.8 # random erasing 24 | PADDING: 10 25 | 26 | DATASETS: 27 | NAMES: ('veri') 28 | ROOT_DIR: ('../../datasets') 29 | 30 | DATALOADER: 31 | SAMPLER: 'softmax_triplet' 32 | NUM_INSTANCE: 4 33 | NUM_WORKERS: 8 34 | 35 | SOLVER: 36 | OPTIMIZER_NAME: 'SGD' 37 | MAX_EPOCHS: 120 38 | BASE_LR: 0.01 39 | IMS_PER_BATCH: 64 40 | WARMUP_METHOD: 'linear' 41 | LARGE_FC_LR: False 42 | CHECKPOINT_PERIOD: 120 43 | LOG_PERIOD: 50 44 | EVAL_PERIOD: 120 45 | WEIGHT_DECAY: 1e-4 46 | WEIGHT_DECAY_BIAS: 1e-4 47 | BIAS_LR_FACTOR: 2 48 | 49 | TEST: 50 | EVAL: True 51 | IMS_PER_BATCH: 256 52 | RE_RANKING: False 53 | WEIGHT: '' 54 | NECK_FEAT: 'before' 55 | FEAT_NORM: 'yes' 56 | 57 | OUTPUT_DIR: '../logs/veri_deit_transreid_stride' 58 | 59 | 60 | -------------------------------------------------------------------------------- /configs/VeRi/vit_base.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/heshuting/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('4') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [16, 16] 12 | 13 | INPUT: 14 | SIZE_TRAIN: [256, 256] 15 | SIZE_TEST: [256, 256] 16 | PROB: 0.5 # random horizontal flip 17 | RE_PROB: 0.5 # random erasing 18 | PADDING: 10 19 | PIXEL_MEAN: [0.5, 0.5, 0.5] 20 | PIXEL_STD: [0.5, 0.5, 0.5] 21 | 22 | DATASETS: 23 | NAMES: ('veri') 24 | ROOT_DIR: ('../../data') 25 | 26 | DATALOADER: 27 | SAMPLER: 
'softmax_triplet' 28 | NUM_INSTANCE: 4 29 | NUM_WORKERS: 8 30 | 31 | SOLVER: 32 | OPTIMIZER_NAME: 'SGD' 33 | MAX_EPOCHS: 120 34 | BASE_LR: 0.008 35 | IMS_PER_BATCH: 64 36 | WARMUP_METHOD: 'linear' 37 | LARGE_FC_LR: False 38 | CHECKPOINT_PERIOD: 120 39 | LOG_PERIOD: 50 40 | EVAL_PERIOD: 120 41 | WEIGHT_DECAY: 1e-4 42 | WEIGHT_DECAY_BIAS: 1e-4 43 | BIAS_LR_FACTOR: 2 44 | 45 | TEST: 46 | EVAL: True 47 | IMS_PER_BATCH: 256 48 | RE_RANKING: False 49 | WEIGHT: '' 50 | NECK_FEAT: 'before' 51 | FEAT_NORM: 'yes' 52 | 53 | OUTPUT_DIR: '../logs/veri_vit_base' 54 | 55 | 56 | -------------------------------------------------------------------------------- /configs/VeRi/vit_transreid.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/heshuting/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('4') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [16, 16] 12 | SIE_CAMERA: True 13 | SIE_VIEW: True 14 | SIE_COE: 3.0 15 | JPM: True 16 | SHIFT_NUM: 8 17 | RE_ARRANGE: True 18 | 19 | INPUT: 20 | SIZE_TRAIN: [256, 256] 21 | SIZE_TEST: [256, 256] 22 | PROB: 0.5 # random horizontal flip 23 | RE_PROB: 0.5 # random erasing 24 | PADDING: 10 25 | PIXEL_MEAN: [0.5, 0.5, 0.5] 26 | PIXEL_STD: [0.5, 0.5, 0.5] 27 | 28 | DATASETS: 29 | NAMES: ('veri') 30 | ROOT_DIR: ('../../data') 31 | 32 | DATALOADER: 33 | SAMPLER: 'softmax_triplet' 34 | NUM_INSTANCE: 4 35 | NUM_WORKERS: 8 36 | 37 | SOLVER: 38 | OPTIMIZER_NAME: 'SGD' 39 | MAX_EPOCHS: 120 40 | BASE_LR: 0.01 41 | IMS_PER_BATCH: 64 42 | WARMUP_METHOD: 'linear' 43 | LARGE_FC_LR: False 44 | CHECKPOINT_PERIOD: 120 45 | LOG_PERIOD: 50 46 | EVAL_PERIOD: 120 47 | WEIGHT_DECAY: 1e-4 48 | WEIGHT_DECAY_BIAS: 1e-4 49 | BIAS_LR_FACTOR: 2 50 | 51 | TEST: 52 | EVAL: True 53 | 
IMS_PER_BATCH: 256 54 | RE_RANKING: False 55 | WEIGHT: '' 56 | NECK_FEAT: 'before' 57 | FEAT_NORM: 'yes' 58 | 59 | OUTPUT_DIR: '../logs/veri_vit_transreid' 60 | 61 | 62 | -------------------------------------------------------------------------------- /configs/VeRi/vit_transreid_stride.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/zi.wang/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('2') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [12, 12] 12 | SIE_CAMERA: True 13 | SIE_VIEW: True 14 | SIE_COE: 3.0 15 | JPM: True 16 | SHIFT_NUM: 8 17 | RE_ARRANGE: True 18 | 19 | INPUT: 20 | SIZE_TRAIN: [256, 128] 21 | SIZE_TEST: [256, 128] 22 | PROB: 0.5 # random horizontal flip 23 | RE_PROB: 0.5 # random erasing 24 | PADDING: 10 25 | PIXEL_MEAN: [0.5, 0.5, 0.5] 26 | PIXEL_STD: [0.5, 0.5, 0.5] 27 | 28 | DATASETS: 29 | NAMES: ('veri') 30 | ROOT_DIR: ('/data/zi.wang/') 31 | 32 | DATALOADER: 33 | SAMPLER: 'softmax_triplet' 34 | NUM_INSTANCE: 4 35 | NUM_WORKERS: 8 36 | 37 | SOLVER: 38 | OPTIMIZER_NAME: 'SGD' 39 | MAX_EPOCHS: 150 40 | BASE_LR: 0.01 41 | IMS_PER_BATCH: 32 42 | WARMUP_METHOD: 'linear' 43 | LARGE_FC_LR: False 44 | CHECKPOINT_PERIOD: 1000 45 | LOG_PERIOD: 100 46 | EVAL_PERIOD: 10 47 | WEIGHT_DECAY: 1e-4 48 | WEIGHT_DECAY_BIAS: 1e-4 49 | BIAS_LR_FACTOR: 2 50 | 51 | TEST: 52 | EVAL: True 53 | IMS_PER_BATCH: 256 54 | RE_RANKING: False 55 | WEIGHT: '' 56 | NECK_FEAT: 'before' 57 | FEAT_NORM: 'yes' 58 | 59 | OUTPUT_DIR: './logs_veri/lr001_b32_Process1_Model1_loss1' 60 | 61 | 62 | -------------------------------------------------------------------------------- /configs/VehicleID/deit_transreid.yml: -------------------------------------------------------------------------------- 1 | MODEL: 
2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/heshuting/.cache/torch/checkpoints/deit_base_distilled_patch16_224-df68dfff.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | # DEVICE_ID: ('0') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [16, 16] 12 | DIST_TRAIN: True 13 | JPM: True 14 | SHIFT_NUM: 8 15 | RE_ARRANGE: True 16 | 17 | INPUT: 18 | SIZE_TRAIN: [256, 256] 19 | SIZE_TEST: [256, 256] 20 | PROB: 0.5 # random horizontal flip 21 | RE_PROB: 0.8 # random erasing 22 | PADDING: 10 23 | 24 | DATASETS: 25 | NAMES: ('VehicleID') 26 | ROOT_DIR: ('../../data') 27 | 28 | DATALOADER: 29 | SAMPLER: 'softmax_triplet' 30 | NUM_INSTANCE: 4 31 | NUM_WORKERS: 8 32 | 33 | SOLVER: 34 | OPTIMIZER_NAME: 'SGD' 35 | MAX_EPOCHS: 120 36 | BASE_LR: 0.03 37 | IMS_PER_BATCH: 256 38 | WARMUP_METHOD: 'linear' 39 | LARGE_FC_LR: False 40 | CHECKPOINT_PERIOD: 120 41 | LOG_PERIOD: 50 42 | EVAL_PERIOD: 120 43 | WEIGHT_DECAY: 1e-4 44 | WEIGHT_DECAY_BIAS: 1e-4 45 | BIAS_LR_FACTOR: 2 46 | 47 | TEST: 48 | EVAL: True 49 | IMS_PER_BATCH: 256 50 | RE_RANKING: False 51 | WEIGHT: '' 52 | NECK_FEAT: 'before' 53 | FEAT_NORM: 'yes' 54 | 55 | OUTPUT_DIR: '../logs/vehicleID_deit_transreid' 56 | 57 | 58 | -------------------------------------------------------------------------------- /configs/VehicleID/deit_transreid_stride.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/heshuting/.cache/torch/checkpoints/deit_base_distilled_patch16_224-df68dfff.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | # DEVICE_ID: ('0') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [12, 12] 12 | DIST_TRAIN: True 13 | JPM: True 14 | SHIFT_NUM: 8 15 | RE_ARRANGE: True 16 | 17 | INPUT: 18 | 
SIZE_TRAIN: [256, 256] 19 | SIZE_TEST: [256, 256] 20 | PROB: 0.5 # random horizontal flip 21 | RE_PROB: 0.8 # random erasing 22 | PADDING: 10 23 | 24 | DATASETS: 25 | NAMES: ('VehicleID') 26 | ROOT_DIR: ('../../data') 27 | 28 | DATALOADER: 29 | SAMPLER: 'softmax_triplet' 30 | NUM_INSTANCE: 4 31 | NUM_WORKERS: 8 32 | 33 | SOLVER: 34 | OPTIMIZER_NAME: 'SGD' 35 | MAX_EPOCHS: 120 36 | BASE_LR: 0.03 37 | IMS_PER_BATCH: 256 38 | WARMUP_METHOD: 'linear' 39 | LARGE_FC_LR: False 40 | CHECKPOINT_PERIOD: 120 41 | LOG_PERIOD: 50 42 | EVAL_PERIOD: 120 43 | WEIGHT_DECAY: 1e-4 44 | WEIGHT_DECAY_BIAS: 1e-4 45 | BIAS_LR_FACTOR: 2 46 | 47 | TEST: 48 | EVAL: True 49 | IMS_PER_BATCH: 256 50 | RE_RANKING: False 51 | WEIGHT: '' 52 | NECK_FEAT: 'before' 53 | FEAT_NORM: 'yes' 54 | 55 | OUTPUT_DIR: '../logs/vehicleID_deit_transreid_stride' 56 | 57 | 58 | -------------------------------------------------------------------------------- /configs/VehicleID/vit_base.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/heshuting/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | # DEVICE_ID: ('0') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [16, 16] 12 | 13 | INPUT: 14 | SIZE_TRAIN: [256, 256] 15 | SIZE_TEST: [256, 256] 16 | PROB: 0.5 # random horizontal flip 17 | RE_PROB: 0.5 # random erasing 18 | PADDING: 10 19 | PIXEL_MEAN: [0.5, 0.5, 0.5] 20 | PIXEL_STD: [0.5, 0.5, 0.5] 21 | 22 | DATASETS: 23 | NAMES: ('VehicleID') 24 | ROOT_DIR: ('../../data') 25 | 26 | DATALOADER: 27 | SAMPLER: 'softmax_triplet' 28 | NUM_INSTANCE: 4 29 | NUM_WORKERS: 8 30 | 31 | SOLVER: 32 | OPTIMIZER_NAME: 'SGD' 33 | MAX_EPOCHS: 120 34 | BASE_LR: 0.04 35 | IMS_PER_BATCH: 224 36 | WARMUP_METHOD: 'linear' 37 | LARGE_FC_LR: False 38 | CHECKPOINT_PERIOD: 120 
39 | LOG_PERIOD: 50 40 | EVAL_PERIOD: 120 41 | WEIGHT_DECAY: 1e-4 42 | WEIGHT_DECAY_BIAS: 1e-4 43 | BIAS_LR_FACTOR: 2 44 | 45 | TEST: 46 | EVAL: True 47 | IMS_PER_BATCH: 256 48 | RE_RANKING: False 49 | WEIGHT: '../logs/vehicleID_vit_base/transformer_120.pth' 50 | NECK_FEAT: 'before' 51 | FEAT_NORM: 'yes' 52 | 53 | OUTPUT_DIR: '../logs/vehicleID_vit_base' 54 | 55 | 56 | -------------------------------------------------------------------------------- /configs/VehicleID/vit_transreid.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/heshuting/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | # DEVICE_ID: ('0') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [16, 16] 12 | # DIST_TRAIN: True 13 | JPM: True 14 | SHIFT_NUM: 8 15 | RE_ARRANGE: True 16 | 17 | INPUT: 18 | SIZE_TRAIN: [256, 256] 19 | SIZE_TEST: [256, 256] 20 | PROB: 0.5 # random horizontal flip 21 | RE_PROB: 0.5 # random erasing 22 | PADDING: 10 23 | PIXEL_MEAN: [0.5, 0.5, 0.5] 24 | PIXEL_STD: [0.5, 0.5, 0.5] 25 | 26 | DATASETS: 27 | NAMES: ('VehicleID') 28 | ROOT_DIR: ('../../data') 29 | 30 | DATALOADER: 31 | SAMPLER: 'softmax_triplet' 32 | NUM_INSTANCE: 4 33 | NUM_WORKERS: 8 34 | 35 | SOLVER: 36 | OPTIMIZER_NAME: 'SGD' 37 | MAX_EPOCHS: 120 38 | BASE_LR: 0.045 39 | IMS_PER_BATCH: 224 40 | WARMUP_METHOD: 'linear' 41 | LARGE_FC_LR: False 42 | CHECKPOINT_PERIOD: 120 43 | LOG_PERIOD: 50 44 | EVAL_PERIOD: 120 45 | WEIGHT_DECAY: 1e-4 46 | WEIGHT_DECAY_BIAS: 1e-4 47 | BIAS_LR_FACTOR: 2 48 | 49 | TEST: 50 | EVAL: True 51 | IMS_PER_BATCH: 256 52 | RE_RANKING: False 53 | WEIGHT: '../logs/vehicleID_vit_transreid/transformer_120.pth' 54 | NECK_FEAT: 'before' 55 | FEAT_NORM: 'yes' 56 | 57 | OUTPUT_DIR: '../logs/vehicleID_vit_transreid' 58 | 59 | 60 | 
-------------------------------------------------------------------------------- /configs/VehicleID/vit_transreid_stride.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/ziwang/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | # DEVICE_ID: ('0') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [12, 12] 12 | # DIST_TRAIN: True 13 | JPM: True 14 | SHIFT_NUM: 8 15 | RE_ARRANGE: True 16 | 17 | INPUT: 18 | SIZE_TRAIN: [256, 256] 19 | SIZE_TEST: [256, 256] 20 | PROB: 0.5 # random horizontal flip 21 | RE_PROB: 0.5 # random erasing 22 | PADDING: 10 23 | PIXEL_MEAN: [0.5, 0.5, 0.5] 24 | PIXEL_STD: [0.5, 0.5, 0.5] 25 | 26 | DATASETS: 27 | NAMES: ('VehicleID') 28 | ROOT_DIR: ('/data/zi.wang/') 29 | 30 | DATALOADER: 31 | SAMPLER: 'softmax_triplet' 32 | NUM_INSTANCE: 4 33 | NUM_WORKERS: 8 34 | 35 | SOLVER: 36 | OPTIMIZER_NAME: 'SGD' 37 | MAX_EPOCHS: 120 38 | BASE_LR: 0.045 39 | IMS_PER_BATCH: 256 40 | WARMUP_METHOD: 'linear' 41 | LARGE_FC_LR: False 42 | CHECKPOINT_PERIOD: 120 43 | LOG_PERIOD: 50 44 | EVAL_PERIOD: 120 45 | WEIGHT_DECAY: 1e-4 46 | WEIGHT_DECAY_BIAS: 1e-4 47 | BIAS_LR_FACTOR: 2 48 | 49 | TEST: 50 | EVAL: True 51 | IMS_PER_BATCH: 256 52 | RE_RANKING: False 53 | WEIGHT: '../logs/vehicleID_vit_transreid_stride/transformer_120.pth' 54 | NECK_FEAT: 'before' 55 | FEAT_NORM: 'yes' 56 | 57 | OUTPUT_DIR: './logs_vehicleID/vit_transreid_stride' 58 | 59 | 60 | -------------------------------------------------------------------------------- /configs/transformer_base.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/heshuting/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth' 4 | METRIC_LOSS_TYPE: 
'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('7') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [16, 16] 12 | 13 | INPUT: 14 | SIZE_TRAIN: [256, 256] 15 | SIZE_TEST: [256, 256] 16 | PROB: 0.5 # random horizontal flip 17 | RE_PROB: 0.5 # random erasing 18 | PADDING: 10 19 | PIXEL_MEAN: [0.5, 0.5, 0.5] 20 | PIXEL_STD: [0.5, 0.5, 0.5] 21 | 22 | DATASETS: 23 | NAMES: ('dukemtmc') 24 | ROOT_DIR: ('../../data') 25 | 26 | DATALOADER: 27 | SAMPLER: 'softmax_triplet' 28 | NUM_INSTANCE: 4 29 | NUM_WORKERS: 8 30 | 31 | SOLVER: 32 | OPTIMIZER_NAME: 'SGD' 33 | MAX_EPOCHS: 120 34 | BASE_LR: 0.008 35 | IMS_PER_BATCH: 64 36 | WARMUP_METHOD: 'linear' 37 | LARGE_FC_LR: False 38 | CHECKPOINT_PERIOD: 120 39 | LOG_PERIOD: 50 40 | EVAL_PERIOD: 120 41 | WEIGHT_DECAY: 1e-4 42 | WEIGHT_DECAY_BIAS: 1e-4 43 | BIAS_LR_FACTOR: 2 44 | 45 | TEST: 46 | EVAL: True 47 | IMS_PER_BATCH: 256 48 | RE_RANKING: False 49 | WEIGHT: '' 50 | NECK_FEAT: 'before' 51 | FEAT_NORM: 'yes' 52 | 53 | OUTPUT_DIR: '../logs/' 54 | 55 | 56 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .make_dataloader import make_dataloader -------------------------------------------------------------------------------- /datasets/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/datasets/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/datasets/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/bases.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/datasets/__pycache__/bases.cpython-36.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/dukemtmcreid.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/datasets/__pycache__/dukemtmcreid.cpython-36.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/make_dataloader.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/datasets/__pycache__/make_dataloader.cpython-36.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/make_dataloader.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/datasets/__pycache__/make_dataloader.cpython-37.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/market1501.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/datasets/__pycache__/market1501.cpython-36.pyc -------------------------------------------------------------------------------- 
/datasets/__pycache__/msmt17.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/datasets/__pycache__/msmt17.cpython-36.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/occ_duke.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/datasets/__pycache__/occ_duke.cpython-36.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/occ_reid.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/datasets/__pycache__/occ_reid.cpython-36.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/partial_reid.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/datasets/__pycache__/partial_reid.cpython-36.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/sampler.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/datasets/__pycache__/sampler.cpython-36.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/sampler_ddp.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/datasets/__pycache__/sampler_ddp.cpython-36.pyc 
-------------------------------------------------------------------------------- /datasets/__pycache__/vehicleid.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/datasets/__pycache__/vehicleid.cpython-36.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/veri.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/datasets/__pycache__/veri.cpython-36.pyc -------------------------------------------------------------------------------- /datasets/bases.py: -------------------------------------------------------------------------------- 1 | from PIL import Image, ImageFile 2 | 3 | from torch.utils.data import Dataset 4 | import os.path as osp 5 | import random 6 | import torch 7 | ImageFile.LOAD_TRUNCATED_IMAGES = True 8 | 9 | 10 | def read_image(img_path): 11 | """Keep reading image until succeed. 12 | This can avoid IOError incurred by heavy IO process.""" 13 | got_img = False 14 | if not osp.exists(img_path): 15 | raise IOError("{} does not exist".format(img_path)) 16 | while not got_img: 17 | try: 18 | img = Image.open(img_path).convert('RGB') 19 | got_img = True 20 | except IOError: 21 | print("IOError incurred when reading '{}'. Will redo. Don't worry. 
class BaseDataset(object):
    """
    Base class of reid dataset
    """

    def get_imagedata_info(self, data):
        """Summarize one dataset split.

        Args:
            data: list of (img_path, pid, camid, trackid) tuples.

        Returns:
            (num_pids, num_imgs, num_cams, num_views) where ``num_views``
            is the number of distinct trackids.
        """
        # Sets deduplicate directly; no need to build lists first.
        pids, cams, tracks = set(), set(), set()
        for _, pid, camid, trackid in data:
            pids.add(pid)
            cams.add(camid)
            tracks.add(trackid)
        return len(pids), len(data), len(cams), len(tracks)

    def print_dataset_statistics(self):
        raise NotImplementedError


class BaseImageDataset(BaseDataset):
    """
    Base class of image reid dataset
    """

    def print_dataset_statistics(self, train, query, gallery):
        """Print an id/image/camera summary table for the three splits."""
        # BUG FIX: the original assigned the query and gallery view counts to
        # ``num_train_views`` (variable reuse); the view counts are unused in
        # the table, so discard them explicitly.
        num_train_pids, num_train_imgs, num_train_cams, _ = self.get_imagedata_info(train)
        num_query_pids, num_query_imgs, num_query_cams, _ = self.get_imagedata_info(query)
        num_gallery_pids, num_gallery_imgs, num_gallery_cams, _ = self.get_imagedata_info(gallery)

        print("Dataset statistics:")
        print("  ----------------------------------------")
        print("  subset   | # ids | # images | # cameras")
        print("  ----------------------------------------")
        print("  train    | {:5d} | {:8d} | {:9d}".format(num_train_pids, num_train_imgs, num_train_cams))
        print("  query    | {:5d} | {:8d} | {:9d}".format(num_query_pids, num_query_imgs, num_query_cams))
        print("  gallery  | {:5d} | {:8d} | {:9d}".format(num_gallery_pids, num_gallery_imgs, num_gallery_cams))
        print("  ----------------------------------------")


class ImageDataset(Dataset):
    """Torch dataset yielding (img1, img2, img3, pid, camid, trackid, filename).

    ``img1`` is the base-transformed image; ``img2``/``img3`` are the crop-
    and eraser-augmented views when both of those transforms are supplied,
    otherwise ``img1`` is repeated so the tuple shape stays constant.
    """

    def __init__(self, dataset, transform=None, crop_transform=None, eraser_transform=None):
        self.dataset = dataset                      # list of (path, pid, camid, trackid)
        self.transform = transform                  # base augmentation pipeline
        self.crop_transform = crop_transform        # crop-based augmentation (train only)
        self.eraser_transform = eraser_transform    # random-erasing augmentation (train only)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        img_path, pid, camid, trackid = self.dataset[index]
        img = read_image(img_path)

        # BUG FIX: the original left ``img1`` unbound (UnboundLocalError)
        # whenever ``transform`` was None; fall back to the raw image instead.
        img1 = self.transform(img) if self.transform is not None else img
        if self.crop_transform is not None and self.eraser_transform is not None:
            img2 = self.crop_transform(img)
            img3 = self.eraser_transform(img)
        else:
            img2 = img3 = img1
        return img1, img2, img3, pid, camid, trackid, img_path.split('/')[-1]
class DukeMTMCreID(BaseImageDataset):
    """
    DukeMTMC-reID
    Reference:
    1. Ristani et al. Performance Measures and a Data Set for Multi-Target, Multi-Camera Tracking. ECCVW 2016.
    2. Zheng et al. Unlabeled Samples Generated by GAN Improve the Person Re-identification Baseline in vitro. ICCV 2017.
    URL: https://github.com/layumi/DukeMTMC-reID_evaluation

    Dataset statistics:
    # identities: 1404 (train + query)
    # images:16522 (train) + 2228 (query) + 17661 (gallery)
    # cameras: 8
    """
    dataset_dir = 'DukeMTMC-reID'

    def __init__(self, root='', verbose=True, pid_begin=0, **kwargs):
        super(DukeMTMCreID, self).__init__()
        self.dataset_dir = osp.join(root, self.dataset_dir)
        self.dataset_url = 'http://vision.cs.duke.edu/DukeMTMC/data/misc/DukeMTMC-reID.zip'
        self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train')
        self.query_dir = osp.join(self.dataset_dir, 'query')
        self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test')
        self.pid_begin = pid_begin
        self._download_data()
        self._check_before_run()

        train = self._process_dir(self.train_dir, relabel=True)
        query = self._process_dir(self.query_dir, relabel=False)
        gallery = self._process_dir(self.gallery_dir, relabel=False)

        if verbose:
            print("=> DukeMTMC-reID loaded")
            self.print_dataset_statistics(train, query, gallery)

        self.train = train
        self.query = query
        self.gallery = gallery

        self.num_train_pids, self.num_train_imgs, self.num_train_cams, self.num_train_vids = self.get_imagedata_info(self.train)
        self.num_query_pids, self.num_query_imgs, self.num_query_cams, self.num_query_vids = self.get_imagedata_info(self.query)
        self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams, self.num_gallery_vids = self.get_imagedata_info(self.gallery)

    def _download_data(self):
        """Download and unzip the dataset archive unless it already exists."""
        if osp.exists(self.dataset_dir):
            print("This dataset has been downloaded.")
            return

        print("Creating directory {}".format(self.dataset_dir))
        mkdir_if_missing(self.dataset_dir)
        fpath = osp.join(self.dataset_dir, osp.basename(self.dataset_url))

        print("Downloading DukeMTMC-reID dataset")
        # BUG FIX: ``import urllib`` alone does not guarantee the ``request``
        # submodule is bound on Python 3; import it explicitly before use.
        import urllib.request
        urllib.request.urlretrieve(self.dataset_url, fpath)

        print("Extracting files")
        # Context manager guarantees the archive handle is closed even on error.
        with zipfile.ZipFile(fpath, 'r') as zip_ref:
            zip_ref.extractall(self.dataset_dir)

    def _check_before_run(self):
        """Check if all files are available before going deeper"""
        if not osp.exists(self.dataset_dir):
            raise RuntimeError("'{}' is not available".format(self.dataset_dir))
        if not osp.exists(self.train_dir):
            raise RuntimeError("'{}' is not available".format(self.train_dir))
        if not osp.exists(self.query_dir):
            raise RuntimeError("'{}' is not available".format(self.query_dir))
        if not osp.exists(self.gallery_dir):
            raise RuntimeError("'{}' is not available".format(self.gallery_dir))

    def _process_dir(self, dir_path, relabel=False):
        """Parse ``<pid>_c<camid>...jpg`` files into (path, pid, camid, 1) tuples.

        With ``relabel=True`` pids are remapped to a dense 0-based range for
        classifier training; ``pid_begin`` is added to every pid. The trailing
        ``1`` is a placeholder trackid/viewid (Duke has none).
        """
        img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
        pattern = re.compile(r'([-\d]+)_c(\d)')

        pid_container = set()
        for img_path in img_paths:
            pid, _ = map(int, pattern.search(img_path).groups())
            pid_container.add(pid)
        pid2label = {pid: label for label, pid in enumerate(pid_container)}

        dataset = []
        cam_container = set()
        for img_path in img_paths:
            pid, camid = map(int, pattern.search(img_path).groups())
            assert 1 <= camid <= 8
            camid -= 1  # index starts from 0
            if relabel: pid = pid2label[pid]
            dataset.append((img_path, self.pid_begin + pid, camid, 1))
            cam_container.add(camid)
        print(cam_container, 'cam_container')
        return dataset
def train_collate_fn(batch):
    """Collate training samples.

    Each element of ``batch`` is the 7-tuple produced by
    ``ImageDataset.__getitem__``; the file name is dropped, the id fields are
    converted to int64 tensors, and the three image views are stacked into
    batched tensors.
    """
    imgs1, imgs2, imgs3, pids, camids, viewids, _ = zip(*batch)
    pids = torch.tensor(pids, dtype=torch.int64)
    viewids = torch.tensor(viewids, dtype=torch.int64)
    camids = torch.tensor(camids, dtype=torch.int64)
    return torch.stack(imgs1, dim=0), torch.stack(imgs2, dim=0), torch.stack(imgs3, dim=0), pids, camids, viewids


def val_collate_fn(batch):
    """Collate evaluation samples.

    Keeps ``pids``/``camids`` as plain tuples (the evaluation code indexes
    them directly) and additionally returns camids as a tensor plus the image
    paths for visualization.
    """
    imgs1, imgs2, imgs3, pids, camids, viewids, img_paths = zip(*batch)
    viewids = torch.tensor(viewids, dtype=torch.int64)
    camids_batch = torch.tensor(camids, dtype=torch.int64)
    return torch.stack(imgs1, dim=0), torch.stack(imgs2, dim=0), torch.stack(imgs3, dim=0), pids, camids, camids_batch, viewids, img_paths


def make_dataloader(cfg):
    """Build all dataloaders from the config.

    Returns:
        (train_loader, train_loader_normal, val_loader, num_query,
         num_classes, cam_num, view_num)
    """
    # Base pipeline: resize + normalize only (flip/pad/crop variants are
    # handled by the dedicated crop/eraser pipelines below).
    train_transforms = T.Compose([
        T.Resize(cfg.INPUT.SIZE_TRAIN, interpolation=3),
        T.ToTensor(),
        T.Normalize(mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD),
    ])
    # Crop-augmented view: pad then random-resized-crop.
    # NOTE(review): the crop size is hard-coded to (256, 128) — looks like it
    # should track cfg.INPUT.SIZE_TRAIN for the 384-height configs; confirm
    # before changing, since it alters training behavior.
    crop_transforms = T.Compose([
        T.Resize(cfg.INPUT.SIZE_TRAIN, interpolation=3),
        T.Pad(30),
        T.ToTensor(),
        T.Normalize(mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD),
        T.RandomResizedCrop(size=(256, 128)),
    ])
    # Eraser-augmented view: random erasing is always applied (probability=1).
    eraser_transforms = T.Compose([
        T.Resize(cfg.INPUT.SIZE_TRAIN, interpolation=3),
        T.ToTensor(),
        T.Normalize(mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD),
        RandomErasing(probability=1, mode='pixel', max_count=1, device='cpu'),
    ])

    val_transforms = T.Compose([
        T.Resize(cfg.INPUT.SIZE_TEST),
        T.ToTensor(),
        T.Normalize(mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD)
    ])

    num_workers = cfg.DATALOADER.NUM_WORKERS

    dataset = __factory[cfg.DATASETS.NAMES](root=cfg.DATASETS.ROOT_DIR)

    train_set = ImageDataset(dataset.train, train_transforms, crop_transform=crop_transforms, eraser_transform=eraser_transforms)
    train_set_normal = ImageDataset(dataset.train, val_transforms)
    num_classes = dataset.num_train_pids
    cam_num = dataset.num_train_cams
    view_num = dataset.num_train_vids

    if 'triplet' in cfg.DATALOADER.SAMPLER:
        if cfg.MODEL.DIST_TRAIN:
            print('DIST_TRAIN START')
            # Each process gets an equal share of the global batch.
            mini_batch_size = cfg.SOLVER.IMS_PER_BATCH // dist.get_world_size()
            data_sampler = RandomIdentitySampler_DDP(dataset.train, cfg.SOLVER.IMS_PER_BATCH, cfg.DATALOADER.NUM_INSTANCE)
            batch_sampler = torch.utils.data.sampler.BatchSampler(data_sampler, mini_batch_size, True)
            train_loader = torch.utils.data.DataLoader(
                train_set,
                num_workers=num_workers,
                batch_sampler=batch_sampler,
                collate_fn=train_collate_fn,
                pin_memory=True,
            )
        else:
            train_loader = DataLoader(
                train_set, batch_size=cfg.SOLVER.IMS_PER_BATCH,
                sampler=RandomIdentitySampler(dataset.train, cfg.SOLVER.IMS_PER_BATCH, cfg.DATALOADER.NUM_INSTANCE),
                num_workers=num_workers, collate_fn=train_collate_fn
            )
    elif cfg.DATALOADER.SAMPLER == 'softmax':
        print('using softmax sampler')
        train_loader = DataLoader(
            train_set, batch_size=cfg.SOLVER.IMS_PER_BATCH, shuffle=True, num_workers=num_workers,
            collate_fn=train_collate_fn
        )
    else:
        # BUG FIX: the original printed the non-existent ``cfg.SAMPLER``
        # attribute and fell through, leaving ``train_loader`` unbound;
        # fail fast with the actual config value instead.
        raise ValueError('unsupported sampler! expected softmax or triplet but got {}'.format(cfg.DATALOADER.SAMPLER))

    val_set = ImageDataset(dataset.query + dataset.gallery, val_transforms)

    val_loader = DataLoader(
        val_set, batch_size=cfg.TEST.IMS_PER_BATCH, shuffle=False, num_workers=num_workers,
        collate_fn=val_collate_fn
    )
    train_loader_normal = DataLoader(
        train_set_normal, batch_size=cfg.TEST.IMS_PER_BATCH, shuffle=False, num_workers=num_workers,
        collate_fn=val_collate_fn
    )
    return train_loader, train_loader_normal, val_loader, len(dataset.query), num_classes, cam_num, view_num
class Market1501(BaseImageDataset):
    """
    Market1501
    Reference:
    Zheng et al. Scalable Person Re-identification: A Benchmark. ICCV 2015.
    URL: http://www.liangzheng.org/Project/project_reid.html

    Dataset statistics:
    # identities: 1501 (+1 for background)
    # images: 12936 (train) + 3368 (query) + 15913 (gallery)
    """
    dataset_dir = 'market1501'

    def __init__(self, root='', verbose=True, pid_begin=0, **kwargs):
        super(Market1501, self).__init__()
        self.dataset_dir = osp.join(root, self.dataset_dir)
        self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train')
        self.query_dir = osp.join(self.dataset_dir, 'query')
        self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test')

        self._check_before_run()
        self.pid_begin = pid_begin

        train = self._process_dir(self.train_dir, relabel=True)
        query = self._process_dir(self.query_dir, relabel=False)
        gallery = self._process_dir(self.gallery_dir, relabel=False)

        if verbose:
            print("=> Market1501 loaded")
            self.print_dataset_statistics(train, query, gallery)

        self.train = train
        self.query = query
        self.gallery = gallery

        (self.num_train_pids, self.num_train_imgs,
         self.num_train_cams, self.num_train_vids) = self.get_imagedata_info(self.train)
        (self.num_query_pids, self.num_query_imgs,
         self.num_query_cams, self.num_query_vids) = self.get_imagedata_info(self.query)
        (self.num_gallery_pids, self.num_gallery_imgs,
         self.num_gallery_cams, self.num_gallery_vids) = self.get_imagedata_info(self.gallery)

    def _check_before_run(self):
        """Check if all files are available before going deeper"""
        for required_path in (self.dataset_dir, self.train_dir, self.query_dir, self.gallery_dir):
            if not osp.exists(required_path):
                raise RuntimeError("'{}' is not available".format(required_path))

    def _process_dir(self, dir_path, relabel=False):
        """Parse ``<pid>_c<camid>...jpg`` files into (path, pid, camid, 1) tuples.

        Junk images (pid == -1) are skipped; with ``relabel=True`` the
        surviving pids are remapped to a dense 0-based range.
        """
        all_paths = sorted(glob.glob(osp.join(dir_path, '*.jpg')))
        name_pattern = re.compile(r'([-\d]+)_c(\d)')

        seen_pids = set()
        for path in all_paths:
            person_id, _ = map(int, name_pattern.search(path).groups())
            if person_id == -1:
                continue  # junk images are just ignored
            seen_pids.add(person_id)
        pid2label = {person_id: label for label, person_id in enumerate(seen_pids)}

        records = []
        for path in all_paths:
            person_id, cam_id = map(int, name_pattern.search(path).groups())
            if person_id == -1:
                continue  # junk images are just ignored
            assert 0 <= person_id <= 1501  # pid == 0 means background
            assert 1 <= cam_id <= 6
            cam_id -= 1  # index starts from 0
            if relabel:
                person_id = pid2label[person_id]
            records.append((path, self.pid_begin + person_id, cam_id, 1))
        return records
class MSMT17(BaseImageDataset):
    """
    MSMT17

    Reference:
    Wei et al. Person Transfer GAN to Bridge Domain Gap for Person Re-Identification. CVPR 2018.

    URL: http://www.pkuvmc.com/publications/msmt17.html

    Dataset statistics:
    # identities: 4101
    # images: 32621 (train) + 11659 (query) + 82161 (gallery)
    # cameras: 15
    """
    dataset_dir = 'MSMT17'

    def __init__(self, root='', verbose=True, pid_begin=0, **kwargs):
        super(MSMT17, self).__init__()
        self.pid_begin = pid_begin
        self.dataset_dir = osp.join(root, self.dataset_dir)
        self.train_dir = osp.join(self.dataset_dir, 'train')
        self.test_dir = osp.join(self.dataset_dir, 'test')
        self.list_train_path = osp.join(self.dataset_dir, 'list_train.txt')
        self.list_val_path = osp.join(self.dataset_dir, 'list_val.txt')
        self.list_query_path = osp.join(self.dataset_dir, 'list_query.txt')
        self.list_gallery_path = osp.join(self.dataset_dir, 'list_gallery.txt')

        self._check_before_run()
        # MSMT17 ships an official train/val split; both are used for training.
        train = self._process_dir(self.train_dir, self.list_train_path)
        val = self._process_dir(self.train_dir, self.list_val_path)
        train += val
        query = self._process_dir(self.test_dir, self.list_query_path)
        gallery = self._process_dir(self.test_dir, self.list_gallery_path)
        if verbose:
            print("=> MSMT17 loaded")
            self.print_dataset_statistics(train, query, gallery)

        self.train = train
        self.query = query
        self.gallery = gallery

        self.num_train_pids, self.num_train_imgs, self.num_train_cams, self.num_train_vids = self.get_imagedata_info(self.train)
        self.num_query_pids, self.num_query_imgs, self.num_query_cams, self.num_query_vids = self.get_imagedata_info(self.query)
        self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams, self.num_gallery_vids = self.get_imagedata_info(self.gallery)

    def _check_before_run(self):
        """Check if all files are available before going deeper"""
        if not osp.exists(self.dataset_dir):
            raise RuntimeError("'{}' is not available".format(self.dataset_dir))
        if not osp.exists(self.train_dir):
            raise RuntimeError("'{}' is not available".format(self.train_dir))
        if not osp.exists(self.test_dir):
            raise RuntimeError("'{}' is not available".format(self.test_dir))

    def _process_dir(self, dir_path, list_path):
        """Read ``<relative_path> <pid>`` lines from a split-list file.

        The camera id is taken from the third ``_``-separated field of the
        file name and converted to 0-based indexing; the trailing ``1`` is a
        placeholder trackid/viewid.
        """
        with open(list_path, 'r') as txt:
            lines = txt.readlines()
        dataset = []
        pid_container = set()
        cam_container = set()
        for img_info in lines:
            img_path, pid = img_info.split(' ')
            pid = int(pid)  # no need to relabel
            camid = int(img_path.split('_')[2])
            img_path = osp.join(dir_path, img_path)
            dataset.append((img_path, self.pid_begin + pid, camid - 1, 1))
            pid_container.add(pid)
            cam_container.add(camid)
        print(cam_container, 'cam_container')
        # BUG FIX: the original compared ``enumerate`` order over a set against
        # the pid value, which relies on unspecified set iteration order;
        # check the intended invariant (contiguous 0-based pids) directly.
        assert pid_container == set(range(len(pid_container))), \
            "pids are expected to form a contiguous range starting at 0"
        return dataset
class OCC_DukeMTMCreID(BaseImageDataset):
    """
    Occluded-DukeMTMC (directory layout identical to DukeMTMC-reID)
    Reference:
    1. Ristani et al. Performance Measures and a Data Set for Multi-Target, Multi-Camera Tracking. ECCVW 2016.
    2. Zheng et al. Unlabeled Samples Generated by GAN Improve the Person Re-identification Baseline in vitro. ICCV 2017.
    URL: https://github.com/layumi/DukeMTMC-reID_evaluation

    Dataset statistics:
    # identities: 1404 (train + query)
    # images:16522 (train) + 2228 (query) + 17661 (gallery)
    # cameras: 8
    """
    dataset_dir = 'Occluded_Duke'

    def __init__(self, root='', verbose=True, pid_begin=0, **kwargs):
        super(OCC_DukeMTMCreID, self).__init__()
        self.dataset_dir = osp.join(root, self.dataset_dir)
        self.dataset_url = 'http://vision.cs.duke.edu/DukeMTMC/data/misc/DukeMTMC-reID.zip'
        self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train')
        self.query_dir = osp.join(self.dataset_dir, 'query')
        self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test')
        self.pid_begin = pid_begin
        self._download_data()
        self._check_before_run()

        train = self._process_dir(self.train_dir, relabel=True)
        query = self._process_dir(self.query_dir, relabel=False)
        gallery = self._process_dir(self.gallery_dir, relabel=False)

        if verbose:
            print("=> DukeMTMC-reID loaded")
            self.print_dataset_statistics(train, query, gallery)

        self.train = train
        self.query = query
        self.gallery = gallery

        self.num_train_pids, self.num_train_imgs, self.num_train_cams, self.num_train_vids = self.get_imagedata_info(self.train)
        self.num_query_pids, self.num_query_imgs, self.num_query_cams, self.num_query_vids = self.get_imagedata_info(self.query)
        self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams, self.num_gallery_vids = self.get_imagedata_info(self.gallery)

    def _download_data(self):
        """Download and unzip the archive unless the directory already exists.

        NOTE(review): the URL points at the plain DukeMTMC-reID archive while
        ``dataset_dir`` is 'Occluded_Duke' — the occluded split is normally
        prepared separately; verify before relying on auto-download.
        """
        if osp.exists(self.dataset_dir):
            print("This dataset has been downloaded.")
            return

        print("Creating directory {}".format(self.dataset_dir))
        mkdir_if_missing(self.dataset_dir)
        fpath = osp.join(self.dataset_dir, osp.basename(self.dataset_url))

        print("Downloading DukeMTMC-reID dataset")
        # BUG FIX: ``import urllib`` alone does not guarantee the ``request``
        # submodule is bound on Python 3; import it explicitly before use.
        import urllib.request
        urllib.request.urlretrieve(self.dataset_url, fpath)

        print("Extracting files")
        # Context manager guarantees the archive handle is closed even on error.
        with zipfile.ZipFile(fpath, 'r') as zip_ref:
            zip_ref.extractall(self.dataset_dir)

    def _check_before_run(self):
        """Check if all files are available before going deeper"""
        if not osp.exists(self.dataset_dir):
            raise RuntimeError("'{}' is not available".format(self.dataset_dir))
        if not osp.exists(self.train_dir):
            raise RuntimeError("'{}' is not available".format(self.train_dir))
        if not osp.exists(self.query_dir):
            raise RuntimeError("'{}' is not available".format(self.query_dir))
        if not osp.exists(self.gallery_dir):
            raise RuntimeError("'{}' is not available".format(self.gallery_dir))

    def _process_dir(self, dir_path, relabel=False):
        """Parse ``<pid>_c<camid>...jpg`` files into (path, pid, camid, 1) tuples."""
        img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
        pattern = re.compile(r'([-\d]+)_c(\d)')

        pid_container = set()
        for img_path in img_paths:
            pid, _ = map(int, pattern.search(img_path).groups())
            pid_container.add(pid)
        pid2label = {pid: label for label, pid in enumerate(pid_container)}

        dataset = []
        cam_container = set()
        for img_path in img_paths:
            pid, camid = map(int, pattern.search(img_path).groups())
            assert 1 <= camid <= 8
            camid -= 1  # index starts from 0
            if relabel: pid = pid2label[pid]
            dataset.append((img_path, self.pid_begin + pid, camid, 1))
            cam_container.add(camid)
        print(cam_container, 'cam_container')
        return dataset
class Occ_ReID(BaseImageDataset):
    """Occluded-REID: Market-1501 training set, OccludedREID query/gallery.

    Query images are assigned camera id 1 and gallery images camera id 2
    (stored 0-based) so the evaluator treats them as cross-camera matches.
    """

    dataset_dir_train = 'market1501'
    dataset_dir_test = 'OccludedREID'

    def __init__(self, root='', verbose=True, pid_begin=0, **kwargs):
        super(Occ_ReID, self).__init__()
        self.dataset_dir_train = osp.join(root, self.dataset_dir_train)
        self.dataset_dir_test = osp.join(root, self.dataset_dir_test)

        self.train_dir = osp.join(self.dataset_dir_train, 'bounding_box_train')
        self.query_dir = osp.join(self.dataset_dir_test, 'query')
        self.gallery_dir = osp.join(self.dataset_dir_test, 'gallery')
        self.pid_begin = pid_begin
        self._check_before_run()

        train = self._process_dir_train(self.train_dir, relabel=True)
        query = self._process_dir_test(self.query_dir, camera_id=1, relabel=False)
        gallery = self._process_dir_test(self.gallery_dir, camera_id=2, relabel=False)

        if verbose:
            print("=> Occ_ReID loaded")
            self.print_dataset_statistics(train, query, gallery)

        self.train = train
        self.query = query
        self.gallery = gallery

        self.num_train_pids, self.num_train_imgs, self.num_train_cams, self.num_train_vids = self.get_imagedata_info(self.train)
        self.num_query_pids, self.num_query_imgs, self.num_query_cams, self.num_query_vids = self.get_imagedata_info(self.query)
        self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams, self.num_gallery_vids = self.get_imagedata_info(self.gallery)

    def _check_before_run(self):
        """Check if all files are available before going deeper"""
        if not osp.exists(self.dataset_dir_train):
            raise RuntimeError("'{}' is not available".format(self.dataset_dir_train))
        if not osp.exists(self.train_dir):
            raise RuntimeError("'{}' is not available".format(self.train_dir))
        if not osp.exists(self.query_dir):
            raise RuntimeError("'{}' is not available".format(self.query_dir))
        if not osp.exists(self.gallery_dir):
            raise RuntimeError("'{}' is not available".format(self.gallery_dir))

    def _process_dir_train(self, dir_path, relabel=False):
        """Parse Market-1501 style ``<pid>_c<camid>...jpg`` training files."""
        img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
        pattern = re.compile(r'([-\d]+)_c(\d)')

        pid_container = set()
        for img_path in sorted(img_paths):
            pid, _ = map(int, pattern.search(img_path).groups())
            if pid == -1: continue  # junk images are just ignored
            pid_container.add(pid)
        pid2label = {pid: label for label, pid in enumerate(pid_container)}
        dataset = []
        for img_path in sorted(img_paths):
            pid, camid = map(int, pattern.search(img_path).groups())
            if pid == -1: continue  # junk images are just ignored
            assert 0 <= pid <= 1501  # pid == 0 means background
            assert 1 <= camid <= 6
            camid -= 1  # index starts from 0
            if relabel: pid = pid2label[pid]

            dataset.append((img_path, self.pid_begin + pid, camid, 1))
        return dataset

    def _process_dir_test(self, dir_path, camera_id=1, relabel=False):
        """Parse OccludedREID test files named ``<pid>_<...>.jpg``.

        Every image in ``dir_path`` is assigned the fixed ``camera_id``
        (converted to 0-based).
        """
        img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
        pid_container = set()
        for img_path in img_paths:
            # BUG FIX: splitting on '/' breaks on Windows paths; use the
            # portable osp.basename (identical result on POSIX).
            jpg_name = osp.basename(img_path)
            pid = int(jpg_name.split('_')[0])
            pid_container.add(pid)
        pid2label = {pid: label for label, pid in enumerate(pid_container)}

        data = []
        for img_path in img_paths:
            jpg_name = osp.basename(img_path)
            pid = int(jpg_name.split('_')[0])
            camid = camera_id - 1  # index starts from 0
            if relabel:
                pid = pid2label[pid]
            data.append((img_path, pid, camid, 1))
        return data
class Partial_REID(BaseImageDataset):
    """Partial-REID: Market-1501 training set, Partial_REID query/gallery.

    Partial-body images form the query set (camera id 1) and whole-body
    images the gallery (camera id 2), stored 0-based so the evaluator treats
    them as cross-camera matches.
    """

    dataset_dir_train = 'market1501'
    dataset_dir_test = 'Partial_REID'

    def __init__(self, root='', verbose=True, pid_begin=0, **kwargs):
        super(Partial_REID, self).__init__()
        self.dataset_dir_train = osp.join(root, self.dataset_dir_train)
        self.dataset_dir_test = osp.join(root, self.dataset_dir_test)

        self.train_dir = osp.join(self.dataset_dir_train, 'bounding_box_train')
        self.query_dir = osp.join(self.dataset_dir_test, 'partial_body_images')
        self.gallery_dir = osp.join(self.dataset_dir_test, 'whole_body_images')
        self.pid_begin = pid_begin
        self._check_before_run()

        train = self._process_dir_train(self.train_dir, relabel=True)
        query = self._process_dir_test(self.query_dir, camera_id=1, relabel=False)
        gallery = self._process_dir_test(self.gallery_dir, camera_id=2, relabel=False)

        if verbose:
            print("=> Partial REID loaded")
            self.print_dataset_statistics(train, query, gallery)

        self.train = train
        self.query = query
        self.gallery = gallery

        self.num_train_pids, self.num_train_imgs, self.num_train_cams, self.num_train_vids = self.get_imagedata_info(self.train)
        self.num_query_pids, self.num_query_imgs, self.num_query_cams, self.num_query_vids = self.get_imagedata_info(self.query)
        self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams, self.num_gallery_vids = self.get_imagedata_info(self.gallery)

    def _check_before_run(self):
        """Check if all files are available before going deeper"""
        if not osp.exists(self.dataset_dir_train):
            raise RuntimeError("'{}' is not available".format(self.dataset_dir_train))
        if not osp.exists(self.train_dir):
            raise RuntimeError("'{}' is not available".format(self.train_dir))
        if not osp.exists(self.query_dir):
            raise RuntimeError("'{}' is not available".format(self.query_dir))
        if not osp.exists(self.gallery_dir):
            raise RuntimeError("'{}' is not available".format(self.gallery_dir))

    def _process_dir_train(self, dir_path, relabel=False):
        """Parse Market-1501 style ``<pid>_c<camid>...jpg`` training files."""
        img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
        pattern = re.compile(r'([-\d]+)_c(\d)')

        pid_container = set()
        for img_path in sorted(img_paths):
            pid, _ = map(int, pattern.search(img_path).groups())
            if pid == -1: continue  # junk images are just ignored
            pid_container.add(pid)
        pid2label = {pid: label for label, pid in enumerate(pid_container)}
        dataset = []
        for img_path in sorted(img_paths):
            pid, camid = map(int, pattern.search(img_path).groups())
            if pid == -1: continue  # junk images are just ignored
            assert 0 <= pid <= 1501  # pid == 0 means background
            assert 1 <= camid <= 6
            camid -= 1  # index starts from 0
            if relabel: pid = pid2label[pid]

            dataset.append((img_path, self.pid_begin + pid, camid, 1))
        return dataset

    def _process_dir_test(self, dir_path, camera_id=1, relabel=False):
        """Parse Partial_REID test files named ``<pid>_<...>.jpg``.

        Every image in ``dir_path`` is assigned the fixed ``camera_id``
        (converted to 0-based).
        """
        img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
        pid_container = set()
        for img_path in img_paths:
            # BUG FIX: splitting on '/' breaks on Windows paths; use the
            # portable osp.basename (identical result on POSIX).
            jpg_name = osp.basename(img_path)
            pid = int(jpg_name.split('_')[0])
            pid_container.add(pid)
        pid2label = {pid: label for label, pid in enumerate(pid_container)}

        data = []
        for img_path in img_paths:
            jpg_name = osp.basename(img_path)
            pid = int(jpg_name.split('_')[0])
            camid = camera_id - 1  # index starts from 0
            if relabel:
                pid = pid2label[pid]
            data.append((img_path, pid, camid, 1))
        return data
class RandomErasing(object):
    """Randomly select a rectangle region in an image and erase its pixels.

    'Random Erasing Data Augmentation' by Zhong et al.
    See https://arxiv.org/pdf/1708.04896.pdf

    Args:
        probability: The probability that the Random Erasing operation will be performed.
        sl: Minimum proportion of erased area against input image.
        sh: Maximum proportion of erased area against input image.
        r1: Minimum aspect ratio of erased area.
        mean: Erasing value.
    """

    def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=(0.4914, 0.4822, 0.4465)):
        self.probability = probability
        self.mean = mean
        self.sl = sl
        self.sh = sh
        self.r1 = r1

    def __call__(self, img):
        """Erase one random patch of ``img`` in place and return it.

        Assumes ``img`` is a C x H x W tensor (``img.size()`` indexable) —
        i.e. it is applied after ``ToTensor``.
        """
        # Skip erasing entirely with probability (1 - self.probability).
        if random.uniform(0, 1) >= self.probability:
            return img

        channels = img.size()[0]
        height = img.size()[1]
        width = img.size()[2]
        area = height * width

        # Try up to 100 random rectangles until one fits inside the image.
        for _ in range(100):
            target_area = random.uniform(self.sl, self.sh) * area
            aspect_ratio = random.uniform(self.r1, 1 / self.r1)

            patch_h = int(round(math.sqrt(target_area * aspect_ratio)))
            patch_w = int(round(math.sqrt(target_area / aspect_ratio)))

            if patch_w < width and patch_h < height:
                top = random.randint(0, height - patch_h)
                left = random.randint(0, width - patch_w)
                if channels == 3:
                    for ch in range(3):
                        img[ch, top:top + patch_h, left:left + patch_w] = self.mean[ch]
                else:
                    img[0, top:top + patch_h, left:left + patch_w] = self.mean[0]
                return img

        # No rectangle fit within the attempt budget; return the image untouched.
        return img
class RandomIdentitySampler(Sampler):
    """
    Randomly sample N identities, then for each identity,
    randomly sample K instances, therefore batch size is N*K.
    Args:
    - data_source (list): list of (img_path, pid, camid, viewid).
    - batch_size (int): number of examples in a batch.
    - num_instances (int): number of instances per identity in a batch.
    """

    def __init__(self, data_source, batch_size, num_instances):
        self.data_source = data_source
        self.batch_size = batch_size
        self.num_instances = num_instances
        self.num_pids_per_batch = batch_size // num_instances
        # Map each pid to the dataset indices of its samples,
        # e.g. {783: [0, 5, 116, 876, 1554, 2041], ...}
        self.index_dic = defaultdict(list)
        for index, (_, pid, _, _) in enumerate(data_source):
            self.index_dic[pid].append(index)
        self.pids = list(self.index_dic.keys())

        # Estimated epoch length: each pid contributes the largest multiple
        # of num_instances that its (possibly oversampled) images allow.
        self.length = 0
        for pid in self.pids:
            available = len(self.index_dic[pid])
            if available < self.num_instances:
                available = self.num_instances
            self.length += available - available % self.num_instances

    def __iter__(self):
        # Pre-chunk each identity's (shuffled, oversampled if needed)
        # indices into groups of num_instances.
        batch_idxs_dict = defaultdict(list)
        for pid in self.pids:
            indices = copy.deepcopy(self.index_dic[pid])
            if len(indices) < self.num_instances:
                indices = np.random.choice(indices, size=self.num_instances, replace=True)
            random.shuffle(indices)
            chunk = []
            for idx in indices:
                chunk.append(idx)
                if len(chunk) == self.num_instances:
                    batch_idxs_dict[pid].append(chunk)
                    chunk = []

        # Repeatedly draw num_pids_per_batch identities and emit one chunk
        # each, retiring identities once their chunks are exhausted.
        remaining_pids = copy.deepcopy(self.pids)
        final_idxs = []
        while len(remaining_pids) >= self.num_pids_per_batch:
            for pid in random.sample(remaining_pids, self.num_pids_per_batch):
                final_idxs.extend(batch_idxs_dict[pid].pop(0))
                if not batch_idxs_dict[pid]:
                    remaining_pids.remove(pid)

        return iter(final_idxs)

    def __len__(self):
        return self.length
class VeRi(BaseImageDataset):
    """
    VeRi-776
    Reference:
    Liu, Xinchen, et al. "Large-scale vehicle re-identification in urban surveillance videos." ICME 2016.

    URL:https://vehiclereid.github.io/VeRi/

    Dataset statistics:
    # identities: 776
    # images: 37778 (train) + 1678 (query) + 11579 (gallery)
    # cameras: 20
    """

    dataset_dir = 'VeRi'

    def __init__(self, root='', verbose=True, **kwargs):
        super(VeRi, self).__init__()
        self.dataset_dir = osp.join(root, self.dataset_dir)
        self.train_dir = osp.join(self.dataset_dir, 'image_train')
        self.query_dir = osp.join(self.dataset_dir, 'image_query')
        self.gallery_dir = osp.join(self.dataset_dir, 'image_test')

        self._check_before_run()

        # Viewpoint annotations are read from repo-relative paths, so this
        # class assumes the process is launched from the project root.
        self.image_map_view_train = self._load_view_map('datasets/keypoint_train.txt')
        self.image_map_view_test = self._load_view_map('datasets/keypoint_test.txt')

        train = self._process_dir(self.train_dir, relabel=True)
        query = self._process_dir(self.query_dir, relabel=False)
        gallery = self._process_dir(self.gallery_dir, relabel=False)

        if verbose:
            print("=> VeRi-776 loaded")
            self.print_dataset_statistics(train, query, gallery)

        self.train = train
        self.query = query
        self.gallery = gallery

        self.num_train_pids, self.num_train_imgs, self.num_train_cams, self.num_train_vids = self.get_imagedata_info(
            self.train)
        self.num_query_pids, self.num_query_imgs, self.num_query_cams, self.num_query_vids = self.get_imagedata_info(
            self.query)
        self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams, self.num_gallery_vids = self.get_imagedata_info(
            self.gallery)

    @staticmethod
    def _load_view_map(path):
        """Parse a keypoint annotation file into {image basename: viewpoint id}.

        Each line is space-separated with the image path as the first field
        and the viewpoint id as the last field.
        """
        view_map = {}
        with open(path, 'r') as txt:
            for img_info in txt.readlines():
                content = img_info.split(' ')
                view_map[osp.basename(content[0])] = int(content[-1])
        return view_map

    def _check_before_run(self):
        """Check if all files are available before going deeper"""
        if not osp.exists(self.dataset_dir):
            raise RuntimeError("'{}' is not available".format(self.dataset_dir))
        if not osp.exists(self.train_dir):
            raise RuntimeError("'{}' is not available".format(self.train_dir))
        if not osp.exists(self.query_dir):
            raise RuntimeError("'{}' is not available".format(self.query_dir))
        if not osp.exists(self.gallery_dir):
            raise RuntimeError("'{}' is not available".format(self.gallery_dir))

    def _process_dir(self, dir_path, relabel=False):
        """Collect (img_path, pid, camid, viewid) tuples from one image dir.

        Junk images (pid == -1) are ignored; images with no viewpoint
        annotation in either keypoint file are counted and skipped.
        """
        img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
        pattern = re.compile(r'([-\d]+)_c(\d+)')

        pid_container = set()
        for img_path in img_paths:
            pid, _ = map(int, pattern.search(img_path).groups())
            if pid == -1: continue  # junk images are just ignored
            pid_container.add(pid)
        pid2label = {pid: label for label, pid in enumerate(pid_container)}

        view_container = set()
        dataset = []
        count = 0
        for img_path in img_paths:
            pid, camid = map(int, pattern.search(img_path).groups())
            if pid == -1: continue  # junk images are just ignored
            assert 0 <= pid <= 776  # pid == 0 means background
            assert 1 <= camid <= 20
            camid -= 1  # index starts from 0
            if relabel: pid = pid2label[pid]

            base = osp.basename(img_path)
            if base in self.image_map_view_train:
                viewid = self.image_map_view_train[base]
            else:
                try:
                    viewid = self.image_map_view_test[base]
                except KeyError:
                    # Was a bare `except:` that also hid unrelated errors;
                    # only a genuinely missing annotation should skip the image.
                    count += 1
                    continue
            view_container.add(viewid)
            dataset.append((img_path, pid, camid, viewid))
        print(view_container, 'view_container')
        print(count, 'samples without viewpoint annotations')
        return dataset
MODEL.DEVICE_ID "('1')" SOLVER.BASE_LR 0.002 OUTPUT_DIR './logs_partial_reid/lr0002_b32_Process1_Model12_loss1' SOLVER.MAX_EPOCHS 20 3 | python train.py --config_file configs/Partial_ReID/vit_transreid_stride.yml MODEL.DEVICE_ID "('1')" SOLVER.BASE_LR 0.004 OUTPUT_DIR './logs_partial_reid/lr0004_b32_Process1_Model12_loss1' SOLVER.MAX_EPOCHS 15 4 | python train.py --config_file configs/Partial_ReID/vit_transreid_stride.yml MODEL.DEVICE_ID "('1')" SOLVER.BASE_LR 0.008 OUTPUT_DIR './logs_partial_reid/lr0008_b32_Process1_Model12_loss1' SOLVER.MAX_EPOCHS 15 5 | python train.py --config_file configs/Partial_ReID/vit_transreid_stride.yml MODEL.DEVICE_ID "('1')" SOLVER.BASE_LR 0.0001 OUTPUT_DIR './logs_partial_reid/lr00001_b32_Process1_Model12_loss1' SOLVER.MAX_EPOCHS 25 6 | python train.py --config_file configs/Partial_ReID/vit_transreid_stride.yml MODEL.DEVICE_ID "('1')" SOLVER.BASE_LR 0.0002 OUTPUT_DIR './logs_partial_reid/lr00002_b32_Process1_Model12_loss1' SOLVER.MAX_EPOCHS 25 7 | python train.py --config_file configs/Partial_ReID/vit_transreid_stride.yml MODEL.DEVICE_ID "('1')" SOLVER.BASE_LR 0.0004 OUTPUT_DIR './logs_partial_reid/lr00004_b32_Process1_Model12_loss1' SOLVER.MAX_EPOCHS 25 8 | python train.py --config_file configs/Partial_ReID/vit_transreid_stride.yml MODEL.DEVICE_ID "('1')" SOLVER.BASE_LR 0.0008 OUTPUT_DIR './logs_partial_reid/lr00008_b32_Process1_Model12_loss1' SOLVER.MAX_EPOCHS 25 9 | -------------------------------------------------------------------------------- /dist_train_occReID.sh: -------------------------------------------------------------------------------- 1 | python train.py --config_file configs/OCC_ReID/vit_transreid_stride.yml MODEL.DEVICE_ID "('2')" SOLVER.BASE_LR 0.001 OUTPUT_DIR './logs_occ_reid/lr0001_b32_Process1_Model12_loss1' SOLVER.MAX_EPOCHS 20 2 | python train.py --config_file configs/OCC_ReID/vit_transreid_stride.yml MODEL.DEVICE_ID "('2')" SOLVER.BASE_LR 0.002 OUTPUT_DIR './logs_occ_reid/lr0002_b32_Process1_Model12_loss1' 
SOLVER.MAX_EPOCHS 20 3 | python train.py --config_file configs/OCC_ReID/vit_transreid_stride.yml MODEL.DEVICE_ID "('2')" SOLVER.BASE_LR 0.004 OUTPUT_DIR './logs_occ_reid/lr0004_b32_Process1_Model12_loss1' SOLVER.MAX_EPOCHS 15 4 | python train.py --config_file configs/OCC_ReID/vit_transreid_stride.yml MODEL.DEVICE_ID "('2')" SOLVER.BASE_LR 0.008 OUTPUT_DIR './logs_occ_reid/lr0008_b32_Process1_Model12_loss1' SOLVER.MAX_EPOCHS 15 5 | python train.py --config_file configs/OCC_ReID/vit_transreid_stride.yml MODEL.DEVICE_ID "('2')" SOLVER.BASE_LR 0.0001 OUTPUT_DIR './logs_occ_reid/lr00001_b32_Process1_Model12_loss1' SOLVER.MAX_EPOCHS 25 6 | python train.py --config_file configs/OCC_ReID/vit_transreid_stride.yml MODEL.DEVICE_ID "('2')" SOLVER.BASE_LR 0.0002 OUTPUT_DIR './logs_occ_reid/lr00002_b32_Process1_Model12_loss1' SOLVER.MAX_EPOCHS 25 7 | python train.py --config_file configs/OCC_ReID/vit_transreid_stride.yml MODEL.DEVICE_ID "('2')" SOLVER.BASE_LR 0.0004 OUTPUT_DIR './logs_occ_reid/lr00004_b32_Process1_Model12_loss1' SOLVER.MAX_EPOCHS 25 8 | python train.py --config_file configs/OCC_ReID/vit_transreid_stride.yml MODEL.DEVICE_ID "('2')" SOLVER.BASE_LR 0.0008 OUTPUT_DIR './logs_occ_reid/lr00008_b32_Process1_Model12_loss1' SOLVER.MAX_EPOCHS 25 9 | -------------------------------------------------------------------------------- /fig/1: -------------------------------------------------------------------------------- 1 | 11 2 | -------------------------------------------------------------------------------- /fig/OccludedREID_gallery.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/fig/OccludedREID_gallery.jpg -------------------------------------------------------------------------------- /fig/OccludedREID_query.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/fig/OccludedREID_query.jpg -------------------------------------------------------------------------------- /fig/RankingList-partial.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/fig/RankingList-partial.png -------------------------------------------------------------------------------- /fig/image-20221018171750395.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/fig/image-20221018171750395.png -------------------------------------------------------------------------------- /fig/image-20221018171831853.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/fig/image-20221018171831853.png -------------------------------------------------------------------------------- /fig/image-20221018171840117.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/fig/image-20221018171840117.png -------------------------------------------------------------------------------- /fig/market_train.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/fig/market_train.jpg -------------------------------------------------------------------------------- /fig/partial_gallery.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/fig/partial_gallery.jpg 
class hetero_loss(nn.Module):
    """Margin-based center-alignment loss between two feature modalities.

    For every identity in the batch, the per-identity centers (mean feature
    vectors) of ``feat1`` and ``feat2`` are pulled together; center distances
    below ``margin`` incur no penalty.

    Args:
        margin (float): slack subtracted from each center distance.
        dist_type (str): 'l2' (summed MSE), 'l1', or 'cos' (1 - cosine).
    """

    def __init__(self, margin=0.1, dist_type='l2'):
        super(hetero_loss, self).__init__()
        self.margin = margin
        self.dist_type = dist_type
        if dist_type == 'l2':
            self.dist = nn.MSELoss(reduction='sum')
        if dist_type == 'cos':
            self.dist = nn.CosineSimilarity(dim=0)
        if dist_type == 'l1':
            self.dist = nn.L1Loss()

    def forward(self, feat1, feat2, label1):
        """Return the summed hinged center distance over identities.

        NOTE(review): chunking the batch into ``len(label1.unique())`` equal
        parts assumes samples are grouped by identity with equal group sizes
        -- confirm against the sampler used upstream.
        """
        label_num = len(label1.unique())
        # Guard: an empty batch previously crashed (chunk(0)) and left
        # `dist` unbound; return a zero loss instead.
        if label_num == 0:
            return feat1.new_zeros(())
        feat1 = feat1.chunk(label_num, 0)
        feat2 = feat2.chunk(label_num, 0)
        dist = 0
        for i in range(label_num):
            center1 = torch.mean(feat1[i], dim=0)
            center2 = torch.mean(feat2[i], dim=0)
            if self.dist_type == 'l2' or self.dist_type == 'l1':
                dist += max(0, self.dist(center1, center2) - self.margin)
            elif self.dist_type == 'cos':
                # Cosine similarity is turned into a distance via (1 - cos).
                dist += max(0, 1 - self.dist(center1, center2) - self.margin)
        return dist
class KDLoss(nn.Module):
    """Knowledge-distillation KL-divergence loss with temperature scaling.

    Args:
        temp: softmax temperature applied to both teacher and student logits.
        reduction: 'sum' sums the KL over all elements; any other value
            averages the per-sample KL over the batch.
    """

    def __init__(self, temp: float, reduction: str):
        super(KDLoss, self).__init__()
        self.temp = temp
        self.reduction = reduction
        # Removed: an nn.KLDivLoss built with the caller's reduction was
        # stored here but never used (forward always reduced manually), and
        # a no-op __call__ override merely delegated to nn.Module.

    def forward(self, teacher_logits: torch.Tensor, student_logits: torch.Tensor):
        # Temperature-softened distributions; the student must be in
        # log-space for KL divergence.
        student_log_probs = F.log_softmax(student_logits / self.temp, dim=-1)
        teacher_probs = F.softmax(teacher_logits / self.temp, dim=-1)

        kl = F.kl_div(student_log_probs, teacher_probs, reduction='none')
        kl = kl.sum() if self.reduction == 'sum' else kl.sum(1).mean()
        # Scale by T^2 so gradients keep their magnitude relative to the
        # hard-label loss (Hinton et al., "Distilling the Knowledge in a
        # Neural Network").
        kl = kl * (self.temp ** 2)

        return kl
-------------------------------------------------------------------------------- 1 | from .make_loss import make_loss 2 | from .arcface import ArcFace -------------------------------------------------------------------------------- /loss/__pycache__/HCloss.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/loss/__pycache__/HCloss.cpython-36.pyc -------------------------------------------------------------------------------- /loss/__pycache__/KLloss.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/loss/__pycache__/KLloss.cpython-36.pyc -------------------------------------------------------------------------------- /loss/__pycache__/MSEloss.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/loss/__pycache__/MSEloss.cpython-36.pyc -------------------------------------------------------------------------------- /loss/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/loss/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /loss/__pycache__/arcface.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/loss/__pycache__/arcface.cpython-36.pyc -------------------------------------------------------------------------------- /loss/__pycache__/center_loss.cpython-36.pyc: 
class ArcFace(nn.Module):
    """Additive angular margin (ArcFace) classification head.

    Produces logits s * cos(theta + m) for the target class and
    s * cos(theta) for all others (Deng et al., CVPR 2019).

    Args:
        in_features: input embedding dimension.
        out_features: number of classes.
        s: feature scale applied to the final logits.
        m: additive angular margin in radians.
        bias: whether to allocate a bias parameter (not used in forward).
    """

    def __init__(self, in_features, out_features, s=30.0, m=0.50, bias=False):
        super(ArcFace, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.s = s
        self.m = m
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)

        # Beyond this threshold theta + m would exceed pi; forward falls
        # back to a linear penalty (cosine - mm) to stay monotonic.
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m

        self.weight = Parameter(torch.Tensor(out_features, in_features))
        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        """Kaiming-uniform init, matching nn.Linear defaults."""
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, input, label):
        """Return margin-adjusted, scaled logits of shape (batch, out_features)."""
        cosine = F.linear(F.normalize(input), F.normalize(self.weight))
        sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(0, 1))
        phi = cosine * self.cos_m - sine * self.sin_m  # cos(theta + m)
        phi = torch.where(cosine > self.th, phi, cosine - self.mm)
        # --------------------------- convert label to one-hot ---------------------------
        # zeros_like keeps the device and dtype of the logits; the previous
        # hard-coded device='cuda' crashed on CPU inputs.
        one_hot = torch.zeros_like(cosine)
        one_hot.scatter_(1, label.view(-1, 1).long(), 1)
        # -------------torch.where(out_i = {x_i if condition_i else y_i) -------------
        output = (one_hot * phi) + (
            (1.0 - one_hot) * cosine)  # you can use torch.where if your torch.__version__ is 0.4
        output *= self.s

        return output
class CircleLoss(nn.Module):
    """Circle-loss classification logits over normalized class weights.

    Positive (target-class) similarities are pushed toward 1 and negatives
    toward 0, each weighted by how far it currently is from its optimum.
    """

    def __init__(self, in_features, num_classes, s=256, m=0.25):
        super(CircleLoss, self).__init__()
        self.weight = Parameter(torch.Tensor(num_classes, in_features))
        self.s = s
        self.m = m
        self._num_classes = num_classes
        self.reset_parameters()

    def reset_parameters(self):
        """Kaiming-uniform initialization of the class-weight matrix."""
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def __call__(self, bn_feat, targets):
        # Cosine similarity between normalized features and class weights.
        similarity = F.linear(F.normalize(bn_feat), F.normalize(self.weight))
        detached = similarity.detach()

        # Fixed decision margins and self-paced weights (detached so they
        # act as constants w.r.t. the gradient).
        margin_p = 1 - self.m
        margin_n = self.m
        weight_p = torch.clamp_min(1 + self.m - detached, min=0.)
        weight_n = torch.clamp_min(detached + self.m, min=0.)

        logits_p = self.s * weight_p * (similarity - margin_p)
        logits_n = self.s * weight_n * (similarity - margin_n)

        one_hot = F.one_hot(targets, num_classes=self._num_classes)
        return one_hot * logits_p + (1.0 - one_hot) * logits_n
34 | """ 35 | assert x.size(0) == labels.size(0), "features.size(0) is not equal to labels.size(0)" 36 | 37 | batch_size = x.size(0) 38 | distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \ 39 | torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t() 40 | distmat.addmm_(1, -2, x, self.centers.t()) 41 | 42 | classes = torch.arange(self.num_classes).long() 43 | if self.use_gpu: classes = classes.cuda() 44 | labels = labels.unsqueeze(1).expand(batch_size, self.num_classes) 45 | mask = labels.eq(classes.expand(batch_size, self.num_classes)) 46 | 47 | dist = [] 48 | for i in range(batch_size): 49 | value = distmat[i][mask[i]] 50 | value = value.clamp(min=1e-12, max=1e+12) # for numerical stability 51 | dist.append(value) 52 | dist = torch.cat(dist) 53 | loss = dist.mean() 54 | return loss 55 | 56 | 57 | if __name__ == '__main__': 58 | use_gpu = False 59 | center_loss = CenterLoss(use_gpu=use_gpu) 60 | features = torch.rand(16, 2048) 61 | targets = torch.Tensor([0, 1, 2, 3, 2, 3, 1, 4, 5, 3, 2, 1, 0, 0, 5, 4]).long() 62 | if use_gpu: 63 | features = torch.rand(16, 2048).cuda() 64 | targets = torch.Tensor([0, 1, 2, 3, 2, 3, 1, 4, 5, 3, 2, 1, 0, 0, 5, 4]).cuda() 65 | 66 | loss = center_loss(features, targets) 67 | print(loss) 68 | -------------------------------------------------------------------------------- /loss/make_loss.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import torch.nn.functional as F 8 | from .softmax_loss import CrossEntropyLabelSmooth, LabelSmoothingCrossEntropy 9 | from .triplet_loss import TripletLoss 10 | from .center_loss import CenterLoss 11 | 12 | 13 | def make_loss(cfg, num_classes): # modified by gu 14 | sampler = cfg.DATALOADER.SAMPLER 15 | feat_dim = 2048 16 | center_criterion = CenterLoss(num_classes=num_classes, 
def make_loss(cfg, num_classes):    # modified by gu
    """Build the combined ID + triplet training criterion.

    Returns:
        (loss_func, center_criterion): ``loss_func(score, feat, target,
        target_cam)`` combines cross-entropy and triplet losses weighted by
        ``cfg.MODEL.ID_LOSS_WEIGHT`` / ``cfg.MODEL.TRIPLET_LOSS_WEIGHT``;
        ``center_criterion`` is a CenterLoss instance created unconditionally
        (optimized separately by the caller).
    """
    sampler = cfg.DATALOADER.SAMPLER
    feat_dim = 2048
    center_criterion = CenterLoss(num_classes=num_classes, feat_dim=feat_dim, use_gpu=True)  # center loss
    if 'triplet' in cfg.MODEL.METRIC_LOSS_TYPE:
        if cfg.MODEL.NO_MARGIN:
            # Soft-margin variant: no fixed margin hyperparameter.
            triplet = TripletLoss()
            print("using soft triplet loss for training")
        else:
            triplet = TripletLoss(cfg.SOLVER.MARGIN)  # triplet loss
            print("using triplet loss with margin:{}".format(cfg.SOLVER.MARGIN))
    else:
        # Fixed: the two adjacent literals previously concatenated without a
        # separating space ("...tripletbut got...").
        print('expected METRIC_LOSS_TYPE should be triplet '
              'but got {}'.format(cfg.MODEL.METRIC_LOSS_TYPE))

    if cfg.MODEL.IF_LABELSMOOTH == 'on':
        xent = CrossEntropyLabelSmooth(num_classes=num_classes)
        print("label smooth on, numclasses:", num_classes)

    if sampler == 'softmax':
        def loss_func(score, feat, target):
            return F.cross_entropy(score, target)

    elif cfg.DATALOADER.SAMPLER == 'softmax_triplet':
        def loss_func(score, feat, target, target_cam):
            if cfg.MODEL.METRIC_LOSS_TYPE == 'triplet':
                if cfg.MODEL.IF_LABELSMOOTH == 'on':
                    # List outputs: average the auxiliary heads (index >= 1)
                    # and blend 50/50 with the primary head (index 0).
                    if isinstance(score, list):
                        ID_LOSS = [xent(scor, target) for scor in score[1:]]
                        ID_LOSS = sum(ID_LOSS) / len(ID_LOSS)
                        ID_LOSS = 0.5 * ID_LOSS + 0.5 * xent(score[0], target)
                    else:
                        ID_LOSS = xent(score, target)

                    if isinstance(feat, list):
                        TRI_LOSS = [triplet(feats, target)[0] for feats in feat[1:]]
                        TRI_LOSS = sum(TRI_LOSS) / len(TRI_LOSS)
                        TRI_LOSS = 0.5 * TRI_LOSS + 0.5 * triplet(feat[0], target)[0]
                    else:
                        TRI_LOSS = triplet(feat, target)[0]

                    return cfg.MODEL.ID_LOSS_WEIGHT * ID_LOSS + \
                           cfg.MODEL.TRIPLET_LOSS_WEIGHT * TRI_LOSS
                else:
                    # Here the first THREE outputs are averaged as the primary
                    # term and the rest (index >= 3) as auxiliaries — assumes
                    # the model emits them in that order; TODO confirm.
                    if isinstance(score, list):
                        ID_LOSS = [F.cross_entropy(scor, target) for scor in score[3:]]
                        ID_LOSS = sum(ID_LOSS) / len(ID_LOSS)
                        ID_LOSS = 0.5 * ID_LOSS + 0.5 * ((F.cross_entropy(score[0], target) + F.cross_entropy(score[1], target) + F.cross_entropy(score[2], target))/3)
                    else:
                        ID_LOSS = F.cross_entropy(score, target)

                    if isinstance(feat, list):
                        TRI_LOSS = [triplet(feats, target)[0] for feats in feat[3:]]
                        TRI_LOSS = sum(TRI_LOSS) / len(TRI_LOSS)
                        # NOTE(review): unlike ID_LOSS above, this sum of the
                        # first three terms is not divided by 3 — possibly an
                        # intentional weighting; verify.
                        TRI_LOSS = 0.5 * TRI_LOSS + 0.5 * (triplet(feat[0], target)[0] + triplet(feat[1], target)[0] + triplet(feat[2], target)[0])
                    else:
                        TRI_LOSS = triplet(feat, target)[0]

                    return cfg.MODEL.ID_LOSS_WEIGHT * ID_LOSS + \
                           cfg.MODEL.TRIPLET_LOSS_WEIGHT * TRI_LOSS
            else:
                print('expected METRIC_LOSS_TYPE should be triplet '
                      'but got {}'.format(cfg.MODEL.METRIC_LOSS_TYPE))

    else:
        print('expected sampler should be softmax, triplet, softmax_triplet or softmax_triplet_center '
              'but got {}'.format(cfg.DATALOADER.SAMPLER))
    return loss_func, center_criterion
def make_loss(cfg, num_classes):  # modified by gu
    """Build the combined ID + triplet criterion (L2-normalized-feature variant).

    Identical to loss/make_loss.py except that, in the non-label-smooth
    branch, features are passed through an L2 ``normalize`` module before the
    triplet loss.

    Returns:
        (loss_func, center_criterion) — see loss/make_loss.py.
    """
    sampler = cfg.DATALOADER.SAMPLER
    feat_dim = 2048
    l2_norm = normalize()
    center_criterion = CenterLoss(num_classes=num_classes, feat_dim=feat_dim, use_gpu=True)  # center loss
    if 'triplet' in cfg.MODEL.METRIC_LOSS_TYPE:
        if cfg.MODEL.NO_MARGIN:
            # Soft-margin variant: no fixed margin hyperparameter.
            triplet = TripletLoss()
            print("using soft triplet loss for training")
        else:
            triplet = TripletLoss(cfg.SOLVER.MARGIN)  # triplet loss
            print("using triplet loss with margin:{}".format(cfg.SOLVER.MARGIN))
    else:
        # Fixed: the adjacent literals previously concatenated without a
        # separating space ("...tripletbut got...").
        print('expected METRIC_LOSS_TYPE should be triplet '
              'but got {}'.format(cfg.MODEL.METRIC_LOSS_TYPE))

    if cfg.MODEL.IF_LABELSMOOTH == 'on':
        xent = CrossEntropyLabelSmooth(num_classes=num_classes)
        print("label smooth on, numclasses:", num_classes)

    if sampler == 'softmax':
        def loss_func(score, feat, target):
            return F.cross_entropy(score, target)

    elif cfg.DATALOADER.SAMPLER == 'softmax_triplet':
        def loss_func(score, feat, target, target_cam):
            if cfg.MODEL.METRIC_LOSS_TYPE == 'triplet':
                if cfg.MODEL.IF_LABELSMOOTH == 'on':
                    if isinstance(score, list):
                        ID_LOSS = [xent(scor, target) for scor in score[1:]]
                        ID_LOSS = sum(ID_LOSS) / len(ID_LOSS)
                        ID_LOSS = 0.5 * ID_LOSS + 0.5 * xent(score[0], target)
                    else:
                        ID_LOSS = xent(score, target)

                    if isinstance(feat, list):
                        TRI_LOSS = [triplet(feats, target)[0] for feats in feat[1:]]
                        TRI_LOSS = sum(TRI_LOSS) / len(TRI_LOSS)
                        TRI_LOSS = 0.5 * TRI_LOSS + 0.5 * triplet(feat[0], target)[0]
                    else:
                        TRI_LOSS = triplet(feat, target)[0]

                    return cfg.MODEL.ID_LOSS_WEIGHT * ID_LOSS + \
                           cfg.MODEL.TRIPLET_LOSS_WEIGHT * TRI_LOSS
                else:
                    if isinstance(score, list):
                        ID_LOSS = [F.cross_entropy(scor, target) for scor in score[3:]]
                        ID_LOSS = sum(ID_LOSS) / len(ID_LOSS)
                        ID_LOSS = 0.5 * ID_LOSS + 0.5 * ((F.cross_entropy(score[0], target) + F.cross_entropy(score[1], target) + F.cross_entropy(score[2], target))/3)
                    else:
                        ID_LOSS = F.cross_entropy(score, target)

                    if isinstance(feat, list):
                        # Features are L2-normalized before the triplet loss
                        # in this variant.
                        TRI_LOSS = [triplet(l2_norm(feats), target)[0] for feats in feat[3:]]
                        TRI_LOSS = sum(TRI_LOSS) / len(TRI_LOSS)
                        TRI_LOSS = 0.5 * TRI_LOSS + 0.5 * (triplet(l2_norm(feat[0]), target)[0] + triplet(l2_norm(feat[1]), target)[0] + triplet(l2_norm(feat[2]), target)[0])
                    else:
                        # NOTE(review): the non-list path does NOT apply
                        # l2_norm, unlike the list path above — confirm this
                        # asymmetry is intended.
                        TRI_LOSS = triplet(feat, target)[0]

                    return cfg.MODEL.ID_LOSS_WEIGHT * ID_LOSS + \
                           cfg.MODEL.TRIPLET_LOSS_WEIGHT * TRI_LOSS
            else:
                print('expected METRIC_LOSS_TYPE should be triplet '
                      'but got {}'.format(cfg.MODEL.METRIC_LOSS_TYPE))

    else:
        print('expected sampler should be softmax, triplet, softmax_triplet or softmax_triplet_center '
              'but got {}'.format(cfg.DATALOADER.SAMPLER))
    return loss_func, center_criterion
with margin:{}".format(cfg.SOLVER.MARGIN)) 24 | else: 25 | print('expected METRIC_LOSS_TYPE should be triplet' 26 | 'but got {}'.format(cfg.MODEL.METRIC_LOSS_TYPE)) 27 | 28 | if cfg.MODEL.IF_LABELSMOOTH == 'on': 29 | xent = CrossEntropyLabelSmooth(num_classes=num_classes) 30 | print("label smooth on, numclasses:", num_classes) 31 | 32 | if sampler == 'softmax': 33 | def loss_func(score, feat, target): 34 | return F.cross_entropy(score, target) 35 | 36 | elif cfg.DATALOADER.SAMPLER == 'softmax_triplet': 37 | def loss_func(score, feat, target, target_cam): 38 | if cfg.MODEL.METRIC_LOSS_TYPE == 'triplet': 39 | if cfg.MODEL.IF_LABELSMOOTH == 'on': 40 | if isinstance(score, list): 41 | ID_LOSS = [xent(scor, target) for scor in score[1:]] 42 | ID_LOSS = sum(ID_LOSS) / len(ID_LOSS) 43 | ID_LOSS = 0.5 * ID_LOSS + 0.5 * xent(score[0], target) 44 | else: 45 | ID_LOSS = xent(score, target) 46 | 47 | if isinstance(feat, list): 48 | TRI_LOSS = [triplet(feats, target)[0] for feats in feat[1:]] 49 | TRI_LOSS = sum(TRI_LOSS) / len(TRI_LOSS) 50 | TRI_LOSS = 0.5 * TRI_LOSS + 0.5 * triplet(feat[0], target)[0] 51 | else: 52 | TRI_LOSS = triplet(feat, target)[0] 53 | 54 | return cfg.MODEL.ID_LOSS_WEIGHT * ID_LOSS + \ 55 | cfg.MODEL.TRIPLET_LOSS_WEIGHT * TRI_LOSS 56 | else: 57 | if isinstance(score, list): 58 | ID_LOSS = [F.cross_entropy(scor, target) for scor in score] 59 | ID_LOSS = sum(ID_LOSS) / len(ID_LOSS) 60 | # ID_LOSS = 0.5 * ID_LOSS + 0.5 * ((F.cross_entropy(score[0], target) + F.cross_entropy(score[1], target) + F.cross_entropy(score[2], target))/3) 61 | else: 62 | ID_LOSS = F.cross_entropy(score, target) 63 | 64 | if isinstance(feat, list): 65 | TRI_LOSS = [triplet(feats, target)[0] for feats in feat] 66 | TRI_LOSS = sum(TRI_LOSS) / len(TRI_LOSS) 67 | # TRI_LOSS = 0.5 * TRI_LOSS + 0.5 * (triplet(feat[0], target)[0] + triplet(feat[1], target)[0] + triplet(feat[2], target)[0]) 68 | else: 69 | TRI_LOSS = triplet(feat, target)[0] 70 | 71 | return cfg.MODEL.ID_LOSS_WEIGHT * 
def make_loss(cfg, num_classes):    # modified by gu
    """Build the composite training loss and the center-loss criterion
    (single-augmentation variant: two main heads instead of three).

    Args:
        cfg: yacs config node; reads DATALOADER.SAMPLER, MODEL.* and SOLVER.* keys.
        num_classes (int): number of identity classes in the training set.

    Returns:
        (loss_func, center_criterion)

    Raises:
        ValueError: if cfg.DATALOADER.SAMPLER is unsupported. FIX: the original
            only printed a warning and then crashed with an UnboundLocalError
            at the return statement because loss_func was never defined.
    """
    sampler = cfg.DATALOADER.SAMPLER
    feat_dim = 2048
    center_criterion = CenterLoss(num_classes=num_classes, feat_dim=feat_dim, use_gpu=True)  # center loss
    if 'triplet' in cfg.MODEL.METRIC_LOSS_TYPE:
        if cfg.MODEL.NO_MARGIN:
            triplet = TripletLoss()  # soft-margin formulation
            print("using soft triplet loss for training")
        else:
            triplet = TripletLoss(cfg.SOLVER.MARGIN)  # triplet loss
            print("using triplet loss with margin:{}".format(cfg.SOLVER.MARGIN))
    else:
        # FIX: message previously rendered as "...tripletbut got ..." (missing space).
        print('expected METRIC_LOSS_TYPE should be triplet'
              ' but got {}'.format(cfg.MODEL.METRIC_LOSS_TYPE))

    if cfg.MODEL.IF_LABELSMOOTH == 'on':
        xent = CrossEntropyLabelSmooth(num_classes=num_classes)
        print("label smooth on, numclasses:", num_classes)

    if sampler == 'softmax':
        def loss_func(score, feat, target):
            return F.cross_entropy(score, target)

    elif cfg.DATALOADER.SAMPLER == 'softmax_triplet':
        def loss_func(score, feat, target, target_cam):
            if cfg.MODEL.METRIC_LOSS_TYPE == 'triplet':
                if cfg.MODEL.IF_LABELSMOOTH == 'on':
                    # head 0 is weighted 0.5; remaining heads share the other 0.5.
                    if isinstance(score, list):
                        ID_LOSS = [xent(scor, target) for scor in score[1:]]
                        ID_LOSS = sum(ID_LOSS) / len(ID_LOSS)
                        ID_LOSS = 0.5 * ID_LOSS + 0.5 * xent(score[0], target)
                    else:
                        ID_LOSS = xent(score, target)

                    if isinstance(feat, list):
                        TRI_LOSS = [triplet(feats, target)[0] for feats in feat[1:]]
                        TRI_LOSS = sum(TRI_LOSS) / len(TRI_LOSS)
                        TRI_LOSS = 0.5 * TRI_LOSS + 0.5 * triplet(feat[0], target)[0]
                    else:
                        TRI_LOSS = triplet(feat, target)[0]

                    return cfg.MODEL.ID_LOSS_WEIGHT * ID_LOSS + \
                           cfg.MODEL.TRIPLET_LOSS_WEIGHT * TRI_LOSS
                else:
                    # Heads 0-1 appear to be the two main branches, heads 2+
                    # auxiliary ones -- presumably matching make_model's outputs;
                    # TODO confirm against the model definition.
                    if isinstance(score, list):
                        ID_LOSS = [F.cross_entropy(scor, target) for scor in score[2:]]
                        ID_LOSS = sum(ID_LOSS) / len(ID_LOSS)
                        ID_LOSS = 0.5 * ID_LOSS + 0.5 * ((F.cross_entropy(score[0], target) + F.cross_entropy(score[1], target)) / 2)
                    else:
                        ID_LOSS = F.cross_entropy(score, target)

                    if isinstance(feat, list):
                        TRI_LOSS = [triplet(feats, target)[0] for feats in feat[2:]]
                        TRI_LOSS = sum(TRI_LOSS) / len(TRI_LOSS)
                        # NOTE(review): the two main-branch triplet terms are summed
                        # without dividing by 2 (unlike ID_LOSS) -- confirm intended.
                        TRI_LOSS = 0.5 * TRI_LOSS + 0.5 * (triplet(feat[0], target)[0] + triplet(feat[1], target)[0])
                    else:
                        TRI_LOSS = triplet(feat, target)[0]

                    return cfg.MODEL.ID_LOSS_WEIGHT * ID_LOSS + \
                           cfg.MODEL.TRIPLET_LOSS_WEIGHT * TRI_LOSS
            else:
                print('expected METRIC_LOSS_TYPE should be triplet'
                      ' but got {}'.format(cfg.MODEL.METRIC_LOSS_TYPE))
    else:
        # FIX: was a bare print that fell through to `return loss_func, ...`
        # with loss_func unbound (UnboundLocalError). Fail fast instead.
        raise ValueError(
            'expected sampler should be softmax, triplet, softmax_triplet or '
            'softmax_triplet_center but got {}'.format(cfg.DATALOADER.SAMPLER))
    return loss_func, center_criterion
class CrossEntropyLabelSmooth(nn.Module):
    """Cross-entropy loss with label-smoothing regularization.

    Reference:
        Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016.
    Smoothed target: y = (1 - epsilon) * one_hot + epsilon / K.

    Args:
        num_classes (int): size of the label space K.
        epsilon (float): smoothing weight.
        use_gpu (bool): move the smoothed targets to CUDA when True.
    """

    def __init__(self, num_classes, epsilon=0.1, use_gpu=True):
        super(CrossEntropyLabelSmooth, self).__init__()
        self.num_classes = num_classes
        self.epsilon = epsilon
        self.use_gpu = use_gpu
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, inputs, targets):
        """
        Args:
            inputs: prediction matrix (before softmax), shape (batch_size, num_classes)
            targets: ground-truth class indices, shape (batch_size,)
        """
        log_probs = self.logsoftmax(inputs)
        # Build the one-hot matrix on the CPU (scatter_ needs CPU indices here),
        # then optionally move it to the GPU.
        one_hot = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).data.cpu(), 1)
        if self.use_gpu:
            one_hot = one_hot.cuda()
        smoothed = (1 - self.epsilon) * one_hot + self.epsilon / self.num_classes
        return (- smoothed * log_probs).mean(0).sum()
def normalize(x, axis=-1):
    """L2-normalize `x` along `axis` (with a 1e-12 guard against zero norms).

    Args:
        x: pytorch tensor of any shape.
    Returns:
        Tensor of the same shape with unit L2 norm along `axis`.
    """
    denom = torch.norm(x, 2, axis, keepdim=True).expand_as(x) + 1e-12
    return 1. * x / denom


def euclidean_dist(x, y):
    """Pairwise Euclidean distance.

    Args:
        x: tensor of shape [m, d]
        y: tensor of shape [n, d]
    Returns:
        Tensor of shape [m, n].
    """
    m, n = x.size(0), y.size(0)
    sq_x = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n)
    sq_y = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t()
    pairwise = sq_x + sq_y - 2 * torch.matmul(x, y.t())
    # Clamp before sqrt for numerical stability around zero.
    return pairwise.clamp(min=1e-12).sqrt()


def cosine_dist(x, y):
    """Pairwise cosine distance, rescaled to [0, 1].

    Args:
        x: tensor of shape [m, d]
        y: tensor of shape [n, d]
    Returns:
        Tensor of shape [m, n] where 0 means identical direction.
    """
    m, n = x.size(0), y.size(0)
    x_len = torch.pow(x, 2).sum(1, keepdim=True).sqrt().expand(m, n)
    y_len = torch.pow(y, 2).sum(1, keepdim=True).sqrt().expand(n, m).t()
    cos_sim = torch.mm(x, y.t()) / (x_len * y_len)
    return (1. - cos_sim) / 2


def hard_example_mining(dist_mat, labels, return_inds=False):
    """For each anchor, pick the hardest positive and hardest negative.

    Assumes every label has the same number of samples in the batch, so the
    boolean masks reshape to [N, -1] cleanly and all anchors are processed
    in parallel.

    Args:
        dist_mat: pairwise distance matrix, shape [N, N].
        labels: LongTensor of shape [N].
        return_inds: also return the selected sample indices.
    Returns:
        (dist_ap, dist_an) of shape [N], plus (p_inds, n_inds) when requested.
    """
    assert len(dist_mat.size()) == 2
    assert dist_mat.size(0) == dist_mat.size(1)
    N = dist_mat.size(0)

    same_id = labels.expand(N, N).eq(labels.expand(N, N).t())
    diff_id = labels.expand(N, N).ne(labels.expand(N, N).t())

    # Hardest positive: the farthest sample sharing the anchor's label.
    dist_ap, rel_p = torch.max(dist_mat[same_id].contiguous().view(N, -1), 1, keepdim=True)
    # Hardest negative: the closest sample with a different label.
    dist_an, rel_n = torch.min(dist_mat[diff_id].contiguous().view(N, -1), 1, keepdim=True)
    dist_ap = dist_ap.squeeze(1)
    dist_an = dist_an.squeeze(1)

    if return_inds:
        ind = (labels.new().resize_as_(labels)
               .copy_(torch.arange(0, N).long())
               .unsqueeze(0).expand(N, N))
        p_inds = torch.gather(ind[same_id].contiguous().view(N, -1), 1, rel_p.data).squeeze(1)
        n_inds = torch.gather(ind[diff_id].contiguous().view(N, -1), 1, rel_n.data).squeeze(1)
        return dist_ap, dist_an, p_inds, n_inds

    return dist_ap, dist_an


class TripletLoss(object):
    """Triplet loss with hard example mining (optionally made harder via
    `hard_factor`, which inflates positive and deflates negative distances)."""

    def __init__(self, margin=None, hard_factor=0.0):
        self.margin = margin
        self.hard_factor = hard_factor
        # With an explicit margin use a margin ranking loss; otherwise the
        # soft-margin (log(1 + exp(-x))) formulation.
        if margin is not None:
            self.ranking_loss = nn.MarginRankingLoss(margin=margin)
        else:
            self.ranking_loss = nn.SoftMarginLoss()

    def __call__(self, global_feat, labels, normalize_feature=False):
        if normalize_feature:
            global_feat = normalize(global_feat, axis=-1)
        pairwise = euclidean_dist(global_feat, global_feat)
        dist_ap, dist_an = hard_example_mining(pairwise, labels)

        dist_ap = dist_ap * (1.0 + self.hard_factor)
        dist_an = dist_an * (1.0 - self.hard_factor)

        # Target of +1 asks the ranking loss to push dist_an above dist_ap.
        y = dist_an.new().resize_as_(dist_an).fill_(1)
        if self.margin is not None:
            loss = self.ranking_loss(dist_an, dist_ap, y)
        else:
            loss = self.ranking_loss(dist_an - dist_ap, y)
        return loss, dist_ap, dist_an
-------------------------------------------------------------------------------- /model/__pycache__/make_model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/model/__pycache__/make_model.cpython-36.pyc -------------------------------------------------------------------------------- /model/backbones/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/model/backbones/__init__.py -------------------------------------------------------------------------------- /model/backbones/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/model/backbones/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /model/backbones/__pycache__/osnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/model/backbones/__pycache__/osnet.cpython-36.pyc -------------------------------------------------------------------------------- /model/backbones/__pycache__/resnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/model/backbones/__pycache__/resnet.cpython-36.pyc -------------------------------------------------------------------------------- /model/backbones/__pycache__/vit_pytorch.cpython-36.pyc: -------------------------------------------------------------------------------- 
def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding (bias-free, batchnorm follows)."""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    """Two 3x3 convs with an identity (or projected) shortcut; expansion 1."""
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        shortcut = x if self.downsample is None else self.downsample(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = out + shortcut
        return self.relu(out)


class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 bottleneck with a residual shortcut; expansion 4."""
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        shortcut = x if self.downsample is None else self.downsample(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        return self.relu(out + shortcut)


class ResNet(nn.Module):
    """Headless ResNet feature extractor (no avgpool/fc).

    `last_stride` controls the stride of the final stage so re-ID models can
    keep a larger feature map. NOTE(review): this variant deliberately uses a
    2x2 max-pool (not the standard 3x3/stride-2) and applies NO ReLU after
    bn1 -- both deviations are marked in the original code; keep them to stay
    compatible with the pretrained weights loaded via `load_param`.
    """

    def __init__(self, last_stride=2, block=Bottleneck, layers=[3, 4, 6, 3]):
        self.inplanes = 64
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        # self.relu = nn.ReLU(inplace=True)  # intentionally absent (see class note)
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=None, padding=0)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=last_stride)

    def _make_layer(self, block, planes, blocks, stride=1):
        """Stack `blocks` residual blocks; the first may downsample/project."""
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        stages = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        stages.extend(block(self.inplanes, planes) for _ in range(1, blocks))
        return nn.Sequential(*stages)

    def forward(self, x, cam_label=None):
        """Return the stage-4 feature map; `cam_label` is accepted for
        interface compatibility with other backbones but unused here."""
        x = self.maxpool(self.bn1(self.conv1(x)))  # no ReLU here (see class note)
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        return x

    def load_param(self, model_path):
        """Copy matching weights from a checkpoint, skipping the fc head."""
        param_dict = torch.load(model_path)
        for name in param_dict:
            if 'fc' in name:
                continue
            self.state_dict()[name].copy_(param_dict[name])

    def random_init(self):
        """He-style init for convs; unit gamma / zero beta for batchnorms."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / fan_out))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
.make_optimizer import make_optimizer -------------------------------------------------------------------------------- /solver/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/solver/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /solver/__pycache__/cosine_lr.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/solver/__pycache__/cosine_lr.cpython-36.pyc -------------------------------------------------------------------------------- /solver/__pycache__/lr_scheduler.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/solver/__pycache__/lr_scheduler.cpython-36.pyc -------------------------------------------------------------------------------- /solver/__pycache__/make_optimizer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/solver/__pycache__/make_optimizer.cpython-36.pyc -------------------------------------------------------------------------------- /solver/__pycache__/scheduler.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/solver/__pycache__/scheduler.cpython-36.pyc -------------------------------------------------------------------------------- /solver/__pycache__/scheduler_factory.cpython-36.pyc: -------------------------------------------------------------------------------- 
""" Cosine Scheduler

Cosine LR schedule with warmup, cycle/restarts, noise.

Hacked together by / Copyright 2020 Ross Wightman
"""
import logging
import math
import torch

from .scheduler import Scheduler


_logger = logging.getLogger(__name__)


class CosineLRScheduler(Scheduler):
    """
    Cosine decay with restarts.
    This is described in the paper https://arxiv.org/abs/1608.03983.

    Inspiration from
    https://github.com/allenai/allennlp/blob/master/allennlp/training/learning_rate_schedulers/cosine.py
    """

    def __init__(self,
                 optimizer: torch.optim.Optimizer,
                 t_initial: int,          # length of the first cosine cycle (epochs or updates)
                 t_mul: float = 1.,       # multiplier applied to the cycle length after each restart
                 lr_min: float = 0.,      # floor of the cosine decay
                 decay_rate: float = 1.,  # per-cycle multiplicative decay of lr_max (and lr_min)
                 warmup_t=0,              # number of warmup steps before the cosine phase
                 warmup_lr_init=0,        # starting lr for the warmup ramp
                 warmup_prefix=False,     # if True, warmup steps are NOT counted inside the first cycle
                 cycle_limit=0,           # 0 = unlimited restarts; otherwise clamp to lr_min afterwards
                 t_in_epochs=True,        # drive the schedule by epoch count vs. update count
                 noise_range_t=None,
                 noise_pct=0.67,
                 noise_std=1.0,
                 noise_seed=42,
                 initialize=True) -> None:
        super().__init__(
            optimizer, param_group_field="lr",
            noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
            initialize=initialize)

        assert t_initial > 0
        assert lr_min >= 0
        # With these defaults the schedule would be a constant lr -- warn instead of failing.
        if t_initial == 1 and t_mul == 1 and decay_rate == 1:
            _logger.warning("Cosine annealing scheduler will have no effect on the learning "
                            "rate since t_initial = t_mul = eta_mul = 1.")
        self.t_initial = t_initial
        self.t_mul = t_mul
        self.lr_min = lr_min
        self.decay_rate = decay_rate
        self.cycle_limit = cycle_limit
        self.warmup_t = warmup_t
        self.warmup_lr_init = warmup_lr_init
        self.warmup_prefix = warmup_prefix
        self.t_in_epochs = t_in_epochs
        if self.warmup_t:
            # Per-group linear ramp increment so each group reaches its base lr at t == warmup_t.
            self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
            super().update_groups(self.warmup_lr_init)
        else:
            self.warmup_steps = [1 for _ in self.base_values]

    def _get_lr(self, t):
        """Return the list of learning rates (one per param group) at step ``t``.

        NOTE(review): this intentionally uses the timm-style ``_get_lr(t)`` API
        (called via get_epoch_values/get_update_values and, presumably, directly
        by the training loop) rather than torch's ``get_lr`` hook.
        """
        if t < self.warmup_t:
            # Linear warmup phase.
            lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
        else:
            if self.warmup_prefix:
                # Shift so the cosine phase starts at t == 0 after warmup.
                t = t - self.warmup_t

            if self.t_mul != 1:
                # Geometric cycle lengths: solve for the current cycle index i,
                # its length t_i, and the offset t_curr within that cycle.
                i = math.floor(math.log(1 - t / self.t_initial * (1 - self.t_mul), self.t_mul))
                t_i = self.t_mul ** i * self.t_initial
                t_curr = t - (1 - self.t_mul ** i) / (1 - self.t_mul) * self.t_initial
            else:
                # Constant cycle length.
                i = t // self.t_initial
                t_i = self.t_initial
                t_curr = t - (self.t_initial * i)

            # Decay both ends of the cosine range once per completed cycle.
            gamma = self.decay_rate ** i
            lr_min = self.lr_min * gamma
            lr_max_values = [v * gamma for v in self.base_values]

            if self.cycle_limit == 0 or (self.cycle_limit > 0 and i < self.cycle_limit):
                # Standard cosine interpolation between lr_max and lr_min within the cycle.
                lrs = [
                    lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * t_curr / t_i)) for lr_max in lr_max_values
                ]
            else:
                # Past the allowed number of cycles: hold at the (undecayed) floor.
                lrs = [self.lr_min for _ in self.base_values]

        return lrs

    def get_epoch_values(self, epoch: int):
        # Only active when the schedule is driven by epochs.
        if self.t_in_epochs:
            return self._get_lr(epoch)
        else:
            return None

    def get_update_values(self, num_updates: int):
        # Only active when the schedule is driven by optimizer updates.
        if not self.t_in_epochs:
            return self._get_lr(num_updates)
        else:
            return None

    def get_cycle_length(self, cycles=0):
        """Total step count covered by ``cycles`` restarts (default: cycle_limit)."""
        if not cycles:
            cycles = self.cycle_limit
        cycles = max(1, cycles)
        if self.t_mul == 1.0:
            return self.t_initial * cycles
        else:
            # Closed form of the geometric series of cycle lengths.
            return int(math.floor(-self.t_initial * (self.t_mul ** cycles - 1) / (1 - self.t_mul)))
class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler):
    """Multi-step LR decay with a constant or linear warmup phase.

    Args:
        optimizer: wrapped optimizer.
        milestones: increasing list of step indices at which lr is multiplied by gamma.
        gamma (float): multiplicative decay applied at each milestone.
        warmup_factor (float): starting fraction of the base lr during warmup.
        warmup_iters (int): number of warmup steps.
        warmup_method (str): 'constant' or 'linear'.
        last_epoch (int): see torch.optim.lr_scheduler._LRScheduler.
    """

    def __init__(
        self,
        optimizer,
        milestones,  # steps
        gamma=0.1,
        warmup_factor=1.0 / 3,
        warmup_iters=500,
        warmup_method="linear",
        last_epoch=-1,
    ):
        if not list(milestones) == sorted(milestones):
            # FIX: the original passed the template and the value as two separate
            # ValueError arguments, so the message was never formatted.
            raise ValueError(
                "Milestones should be a list of increasing integers. Got {}".format(milestones)
            )

        if warmup_method not in ("constant", "linear"):
            # FIX: add the missing separator ("acceptedgot" -> "accepted, got").
            raise ValueError(
                "Only 'constant' or 'linear' warmup_method accepted, "
                "got {}".format(warmup_method)
            )
        self.milestones = milestones
        self.gamma = gamma
        self.warmup_factor = warmup_factor
        self.warmup_iters = warmup_iters
        self.warmup_method = warmup_method
        super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        # FIX: renamed from `_get_lr` -- torch's _LRScheduler calls `get_lr()`
        # from step(), so the old name meant the override was never invoked and
        # the base class raised NotImplementedError during __init__.
        warmup_factor = 1
        if self.last_epoch < self.warmup_iters:
            if self.warmup_method == "constant":
                warmup_factor = self.warmup_factor
            elif self.warmup_method == "linear":
                # Linearly interpolate from warmup_factor up to 1 over warmup_iters.
                alpha = self.last_epoch / self.warmup_iters
                warmup_factor = self.warmup_factor * (1 - alpha) + alpha
        return [
            base_lr
            * warmup_factor
            * self.gamma ** bisect_right(self.milestones, self.last_epoch)
            for base_lr in self.base_lrs
        ]

    def _get_lr(self):
        """Backward-compatible alias for callers that used the old name."""
        return self.get_lr()


def make_optimizer(cfg, model, center_criterion):
    """Build the main optimizer with per-parameter lr/weight-decay overrides,
    plus a plain SGD optimizer for the center-loss parameters.

    Args:
        cfg: yacs config node; reads cfg.SOLVER.* keys.
        model: the model whose trainable parameters are optimized.
        center_criterion: CenterLoss module, optimized separately.

    Returns:
        (optimizer, optimizer_center)
    """
    params = []
    for key, value in model.named_parameters():
        if not value.requires_grad:
            continue
        lr = cfg.SOLVER.BASE_LR
        weight_decay = cfg.SOLVER.WEIGHT_DECAY
        # Biases get their own lr multiplier and (typically zero) weight decay.
        if "bias" in key:
            lr = cfg.SOLVER.BASE_LR * cfg.SOLVER.BIAS_LR_FACTOR
            weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS
        if cfg.SOLVER.LARGE_FC_LR:
            if "classifier" in key or "arcface" in key:
                lr = cfg.SOLVER.BASE_LR * 2
                print('Using two times learning rate for fc ')

        params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}]

    if cfg.SOLVER.OPTIMIZER_NAME == 'SGD':
        optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(params, momentum=cfg.SOLVER.MOMENTUM)
    elif cfg.SOLVER.OPTIMIZER_NAME == 'AdamW':
        optimizer = torch.optim.AdamW(params, lr=cfg.SOLVER.BASE_LR, weight_decay=cfg.SOLVER.WEIGHT_DECAY)
    else:
        optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(params)
    # Center-loss parameters are updated by their own SGD with a dedicated lr.
    optimizer_center = torch.optim.SGD(center_criterion.parameters(), lr=cfg.SOLVER.CENTER_LR)

    return optimizer, optimizer_center
class Scheduler:
    """Base class for parameter-group schedulers.

    Unlike the built-in PyTorch schedulers, this one is meant to be driven
    explicitly by the training loop: ``step(epoch)`` at the end of each
    epoch and ``step_update(num_updates)`` after each optimizer update.
    Subclasses override ``get_epoch_values`` / ``get_update_values`` and
    stay otherwise stateless; all epoch/update counters live in the
    training code.

    Based on ideas from the fairseq and allennlp LR schedulers.
    """

    def __init__(self,
                 optimizer: torch.optim.Optimizer,
                 param_group_field: str,
                 noise_range_t=None,
                 noise_type='normal',
                 noise_pct=0.67,
                 noise_std=1.0,
                 noise_seed=None,
                 initialize: bool = True) -> None:
        self.optimizer = optimizer
        self.param_group_field = param_group_field
        self._initial_param_group_field = f"initial_{param_group_field}"
        groups = self.optimizer.param_groups
        if initialize:
            # Snapshot the current value of each group into "initial_<field>".
            for i, group in enumerate(groups):
                if param_group_field not in group:
                    raise KeyError(f"{param_group_field} missing from param_groups[{i}]")
                group.setdefault(self._initial_param_group_field, group[param_group_field])
        else:
            # Resuming: the snapshot must already be present in every group.
            for i, group in enumerate(groups):
                if self._initial_param_group_field not in group:
                    raise KeyError(f"{self._initial_param_group_field} missing from param_groups[{i}]")
        self.base_values = [g[self._initial_param_group_field] for g in groups]
        self.metric = None  # last metric passed to step()/step_update()
        self.noise_range_t = noise_range_t
        self.noise_pct = noise_pct
        self.noise_type = noise_type
        self.noise_std = noise_std
        self.noise_seed = 42 if noise_seed is None else noise_seed
        self.update_groups(self.base_values)

    def state_dict(self) -> Dict[str, Any]:
        """Everything except the optimizer reference is (de)serializable."""
        return {k: v for k, v in self.__dict__.items() if k != 'optimizer'}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.__dict__.update(state_dict)

    def get_epoch_values(self, epoch: int):
        """Per-epoch schedule hook; base class schedules nothing."""
        return None

    def get_update_values(self, num_updates: int):
        """Per-update schedule hook; base class schedules nothing."""
        return None

    def step(self, epoch: int, metric: float = None) -> None:
        """Apply the schedule for the given epoch (no-op in the base class)."""
        self.metric = metric
        values = self.get_epoch_values(epoch)
        if values is None:
            return
        self.update_groups(self._add_noise(values, epoch))

    def step_update(self, num_updates: int, metric: float = None):
        """Apply the schedule for the given update count (no-op here)."""
        self.metric = metric
        values = self.get_update_values(num_updates)
        if values is None:
            return
        self.update_groups(self._add_noise(values, num_updates))

    def update_groups(self, values):
        """Write one value per param group (a scalar is broadcast to all)."""
        if not isinstance(values, (list, tuple)):
            values = [values] * len(self.optimizer.param_groups)
        for group, value in zip(self.optimizer.param_groups, values):
            group[self.param_group_field] = value

    def _add_noise(self, lrs, t):
        """Optionally perturb the scheduled values with seeded noise at step t."""
        if self.noise_range_t is None:
            return lrs
        if isinstance(self.noise_range_t, (list, tuple)):
            apply_noise = self.noise_range_t[0] <= t < self.noise_range_t[1]
        else:
            apply_noise = t >= self.noise_range_t
        if not apply_noise:
            return lrs
        g = torch.Generator()
        g.manual_seed(self.noise_seed + t)  # deterministic per step
        if self.noise_type == 'normal':
            while True:
                # resample until within the percent limit; brute force but
                # shouldn't spin much
                noise = torch.randn(1, generator=g).item()
                if abs(noise) < self.noise_pct:
                    break
        else:
            noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct
        return [v + v * noise for v in lrs]
""" Scheduler Factory
Hacked together by / Copyright 2020 Ross Wightman
"""
from .cosine_lr import CosineLRScheduler


def create_scheduler(cfg, optimizer):
    """Build the single-cycle cosine LR scheduler used for training.

    The min LR and warmup start LR are fixed fractions of BASE_LR
    ("type 2" below); the warmup length comes from the config.
    """
    total_epochs = cfg.SOLVER.MAX_EPOCHS
    base_lr = cfg.SOLVER.BASE_LR
    # type 1: lr_min = 0.01  * base_lr, warmup_lr_init = 0.001 * base_lr
    # type 2 (active):
    minimum_lr = 0.002 * base_lr
    warmup_start_lr = 0.01 * base_lr
    # type 3: lr_min = 0.001 * base_lr, warmup_lr_init = 0.01  * base_lr

    return CosineLRScheduler(
        optimizer,
        t_initial=total_epochs,
        lr_min=minimum_lr,
        t_mul=1.,
        decay_rate=0.1,
        warmup_lr_init=warmup_start_lr,
        warmup_t=cfg.SOLVER.WARMUP_EPOCHS,
        cycle_limit=1,
        t_in_epochs=True,
        noise_range_t=None,
        noise_pct=0.67,
        noise_std=1.,
        noise_seed=42,
    )
import os
import argparse

from config import cfg
from datasets import make_dataloader
from model import make_model
from processor import do_inference
from utils.logger import setup_logger


if __name__ == "__main__":
    # --- CLI / config -------------------------------------------------
    parser = argparse.ArgumentParser(description="ReID Baseline Training")
    parser.add_argument("--config_file", default="", help="path to config file", type=str)
    parser.add_argument("opts", help="Modify config options using the command-line",
                        default=None, nargs=argparse.REMAINDER)
    args = parser.parse_args()

    if args.config_file != "":
        cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    # --- output dir & logging ----------------------------------------
    output_dir = cfg.OUTPUT_DIR
    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir)

    logger = setup_logger("transreid", output_dir, if_train=False)
    logger.info(args)

    if args.config_file != "":
        logger.info("Loaded configuration file {}".format(args.config_file))
        with open(args.config_file, 'r') as cf:
            config_str = "\n" + cf.read()
            logger.info(config_str)
    logger.info("Running with config:\n{}".format(cfg))

    os.environ['CUDA_VISIBLE_DEVICES'] = cfg.MODEL.DEVICE_ID

    # --- data / model -------------------------------------------------
    train_loader, train_loader_normal, val_loader, num_query, num_classes, camera_num, view_num = make_dataloader(cfg)

    model = make_model(cfg, num_class=num_classes, camera_num=camera_num, view_num=view_num)
    model.load_param(cfg.TEST.WEIGHT)

    # --- inference ----------------------------------------------------
    if cfg.DATASETS.NAMES == 'VehicleID':
        # VehicleID protocol: 10 random gallery splits, ranks averaged.
        for trial in range(10):
            train_loader, train_loader_normal, val_loader, num_query, num_classes, camera_num, view_num = make_dataloader(cfg)
            rank_1, rank5 = do_inference(cfg, model, val_loader, num_query)
            if trial == 0:
                all_rank_1 = rank_1
                all_rank_5 = rank5
            else:
                all_rank_1 = all_rank_1 + rank_1
                all_rank_5 = all_rank_5 + rank5
            logger.info("rank_1:{}, rank_5 {} : trial : {}".format(rank_1, rank5, trial))
        logger.info("sum_rank_1:{:.1%}, sum_rank_5 {:.1%}".format(all_rank_1.sum() / 10.0, all_rank_5.sum() / 10.0))
    else:
        do_inference(cfg, model, val_loader, num_query)
from utils.logger import setup_logger
from datasets import make_dataloader
from model import make_model
from solver import make_optimizer
from solver.scheduler_factory import create_scheduler
from loss import make_loss
from processor import do_train
import random
import torch
import numpy as np
import os
import argparse
# from timm.scheduler import create_scheduler
from config import cfg


def set_seed(seed):
    """Seed every RNG used during training (torch CPU/GPU, numpy, random)."""
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    # NOTE(review): benchmark=True lets cuDNN auto-tune (possibly
    # non-deterministic) kernels, which partly undercuts
    # deterministic=True above — confirm speed over exact
    # reproducibility is the intended trade-off.
    torch.backends.cudnn.benchmark = True


if __name__ == '__main__':
    # --- CLI / config -------------------------------------------------
    parser = argparse.ArgumentParser(description="ReID Baseline Training")
    parser.add_argument("--config_file", default="", help="path to config file", type=str)
    parser.add_argument("opts", help="Modify config options using the command-line",
                        default=None, nargs=argparse.REMAINDER)
    parser.add_argument("--local_rank", default=0, type=int)
    args = parser.parse_args()

    if args.config_file != "":
        cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    set_seed(cfg.SOLVER.SEED)

    if cfg.MODEL.DIST_TRAIN:
        torch.cuda.set_device(args.local_rank)

    # --- output dir & logging ----------------------------------------
    output_dir = cfg.OUTPUT_DIR
    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir)

    logger = setup_logger("transreid", output_dir, if_train=True)
    logger.info("Saving model in the path :{}".format(cfg.OUTPUT_DIR))
    logger.info(args)

    if args.config_file != "":
        logger.info("Loaded configuration file {}".format(args.config_file))
        with open(args.config_file, 'r') as cf:
            config_str = "\n" + cf.read()
            logger.info(config_str)
    logger.info("Running with config:\n{}".format(cfg))

    if cfg.MODEL.DIST_TRAIN:
        torch.distributed.init_process_group(backend='nccl', init_method='env://')

    os.environ['CUDA_VISIBLE_DEVICES'] = cfg.MODEL.DEVICE_ID

    # --- data / model / loss / optimization --------------------------
    train_loader, train_loader_normal, val_loader, num_query, num_classes, camera_num, view_num = make_dataloader(cfg)

    model = make_model(cfg, num_class=num_classes, camera_num=camera_num, view_num=view_num)
    loss_func, center_criterion = make_loss(cfg, num_classes=num_classes)
    optimizer, optimizer_center = make_optimizer(cfg, model, center_criterion)
    scheduler = create_scheduler(cfg, optimizer)

    do_train(cfg, model, center_criterion, train_loader, val_loader,
             optimizer, optimizer_center, scheduler, loss_func,
             num_query, args.local_rank)
https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/utils/__pycache__/iotools.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/logger.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/utils/__pycache__/logger.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/logger.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/utils/__pycache__/logger.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/meter.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/utils/__pycache__/meter.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/metrics.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/utils/__pycache__/metrics.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/reranking.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ziwang1121/PADE/acec784f557d193f580f7a6ba7ea5f02c6101668/utils/__pycache__/reranking.cpython-36.pyc -------------------------------------------------------------------------------- /utils/iotools.py: -------------------------------------------------------------------------------- 1 | # 
# encoding: utf-8
"""
@author: sherlock
@contact: sherlockliao01@gmail.com
"""

import errno
import json
import os
import os.path as osp
import logging
import sys


def mkdir_if_missing(directory):
    """Create *directory* (and any missing parents) if it does not exist."""
    if not osp.exists(directory):
        try:
            os.makedirs(directory)
        except OSError as e:
            # Another process may have created it concurrently; only a
            # genuine failure (permissions, bad path, ...) is re-raised.
            if e.errno != errno.EEXIST:
                raise


def check_isfile(path):
    """Return True if *path* is a regular file; warn and return False otherwise."""
    isfile = osp.isfile(path)
    if not isfile:
        print("=> Warning: no file found at '{}' (ignored)".format(path))
    return isfile


def read_json(fpath):
    """Load and return the JSON object stored at *fpath*.

    BUG FIX: the file is now read as UTF-8 explicitly; the default text
    encoding is locale-dependent and breaks on non-ASCII content.
    """
    with open(fpath, 'r', encoding='utf-8') as f:
        obj = json.load(f)
    return obj


def write_json(obj, fpath):
    """Serialize *obj* as pretty-printed JSON to *fpath*, creating parent dirs.

    BUG FIX: writes UTF-8 explicitly, and no longer crashes when *fpath*
    has no directory component (osp.dirname == "" made os.makedirs fail).
    """
    parent = osp.dirname(fpath)
    if parent:
        mkdir_if_missing(parent)
    with open(fpath, 'w', encoding='utf-8') as f:
        json.dump(obj, f, indent=4, separators=(',', ': '))


def setup_logger(name, save_dir, if_train):
    """Create a DEBUG-level logger writing to stdout and, when *save_dir*
    is non-empty, to train_log.txt / test_log.txt inside it.

    :param name: logger name (also used by logging.getLogger lookups)
    :param save_dir: output directory; falsy disables the file handler
    :param if_train: selects train_log.txt (True) vs test_log.txt (False)
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)

    ch = logging.StreamHandler(stream=sys.stdout)
    ch.setLevel(logging.DEBUG)
    formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s")
    ch.setFormatter(formatter)
    logger.addHandler(ch)

    if save_dir:
        if not osp.exists(save_dir):
            os.makedirs(save_dir)
        log_name = "train_log.txt" if if_train else "test_log.txt"
        fh = logging.FileHandler(os.path.join(save_dir, log_name), mode='w')
        fh.setLevel(logging.DEBUG)
        fh.setFormatter(formatter)
        logger.addHandler(fh)

    return logger
class AverageMeter(object):
    """Computes and stores the average and current value."""

    def __init__(self):
        self.val = 0    # most recent value
        self.avg = 0    # running (weighted) average
        self.sum = 0    # weighted sum of all values
        self.count = 0  # total weight seen so far

    def reset(self):
        """Clear all statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record *val* with weight *n* (e.g. batch size)."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def euclidean_distance(qf, gf):
    """Pairwise *squared* Euclidean distances between query and gallery
    features, returned as a numpy array of shape (m, n).

    Computed as ||q||^2 + ||g||^2 - 2 q.g to avoid materializing the
    (m, n, d) difference tensor.
    """
    m = qf.shape[0]
    n = gf.shape[0]
    dist_mat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \
               torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t()
    # BUG FIX: the legacy positional form addmm_(beta, alpha, mat1, mat2)
    # was removed from PyTorch; pass beta/alpha as keywords.
    dist_mat.addmm_(qf, gf.t(), beta=1, alpha=-2)
    return dist_mat.cpu().numpy()


def cosine_similarity(qf, gf):
    """Pairwise angular distance (arccos of cosine similarity) between
    query and gallery features, returned as a numpy array of shape (m, n).
    """
    epsilon = 0.00001
    dist_mat = qf.mm(gf.t())
    qf_norm = torch.norm(qf, p=2, dim=1, keepdim=True)  # m x 1
    gf_norm = torch.norm(gf, p=2, dim=1, keepdim=True)  # n x 1
    qg_normdot = qf_norm.mm(gf_norm.t())

    dist_mat = dist_mat.mul(1 / qg_normdot).cpu().numpy()
    # clip away from +-1 so arccos stays finite under float round-off
    dist_mat = np.clip(dist_mat, -1 + epsilon, 1 - epsilon)
    dist_mat = np.arccos(dist_mat)
    return dist_mat
def eval_func(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=50):
    """Evaluation with market1501 metric.

    Key: for each query identity, gallery images taken by the same camera
    are discarded before computing the CMC curve and average precision.
    Returns (cmc, mAP) where cmc is averaged over all valid queries.
    """
    num_q, num_g = distmat.shape
    if num_g < max_rank:
        max_rank = num_g
        print("Note: number of gallery samples is quite small, got {}".format(num_g))

    # rank the gallery per query by increasing distance
    indices = np.argsort(distmat, axis=1)
    matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)

    all_cmc, all_AP = [], []
    num_valid_q = 0.  # queries that have at least one true match left
    for q_idx in range(num_q):
        q_pid, q_camid = q_pids[q_idx], q_camids[q_idx]

        # drop gallery entries sharing both identity and camera with the query
        order = indices[q_idx]
        junk = (g_pids[order] == q_pid) & (g_camids[order] == q_camid)
        orig_cmc = matches[q_idx][~junk]
        if not np.any(orig_cmc):
            # query identity never appears in the filtered gallery
            continue

        cmc = orig_cmc.cumsum()
        cmc[cmc > 1] = 1  # binary "found by rank k" vector
        all_cmc.append(cmc[:max_rank])
        num_valid_q += 1.

        # average precision over the ranked list
        num_rel = orig_cmc.sum()
        precisions = orig_cmc.cumsum() / np.arange(1, orig_cmc.shape[0] + 1)
        all_AP.append((np.asarray(precisions) * orig_cmc).sum() / num_rel)

    assert num_valid_q > 0, "Error: all query identities do not appear in gallery"

    all_cmc = np.asarray(all_cmc).astype(np.float32)
    all_cmc = all_cmc.sum(0) / num_valid_q
    mAP = np.mean(all_AP)

    return all_cmc, mAP
class R1_mAP_eval():
    """Accumulates features across validation batches, then computes
    Rank-1/CMC and mAP over the query/gallery split."""

    def __init__(self, num_query, max_rank=50, feat_norm=True, reranking=False):
        super(R1_mAP_eval, self).__init__()
        self.num_query = num_query    # first num_query features are queries
        self.max_rank = max_rank
        self.feat_norm = feat_norm    # L2-normalize features before matching
        self.reranking = reranking    # use k-reciprocal re-ranking distances

    def reset(self):
        """Drop all accumulated features/labels; call before each eval pass."""
        self.feats = []
        self.pids = []
        self.camids = []

    def update(self, output):
        """Store one batch of (features, pids, camids)."""
        feat, pid, camid = output
        self.feats.append(feat.cpu())
        self.pids.extend(np.asarray(pid))
        self.camids.extend(np.asarray(camid))

    def compute(self):
        """Return (cmc, mAP, distmat, pids, camids, qf, gf); call after the
        whole validation set has been fed through update()."""
        feats = torch.cat(self.feats, dim=0)
        if self.feat_norm:
            print("The test feature is normalized")
            feats = torch.nn.functional.normalize(feats, dim=1, p=2)  # along channel

        # split accumulated features into query / gallery halves
        qf = feats[:self.num_query]
        q_pids = np.asarray(self.pids[:self.num_query])
        q_camids = np.asarray(self.camids[:self.num_query])
        gf = feats[self.num_query:]
        g_pids = np.asarray(self.pids[self.num_query:])
        g_camids = np.asarray(self.camids[self.num_query:])

        if self.reranking:
            print('=> Enter reranking')
            # distmat = re_ranking(qf, gf, k1=20, k2=6, lambda_value=0.3)
            distmat = re_ranking(qf, gf, k1=50, k2=15, lambda_value=0.3)
        else:
            print('=> Computing DistMat with euclidean_distance')
            distmat = euclidean_distance(qf, gf)
        cmc, mAP = eval_func(distmat, q_pids, g_pids, q_camids, g_camids)

        return cmc, mAP, distmat, self.pids, self.camids, qf, gf
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
CVPR2017 paper: Zhong Z, Zheng L, Cao D, et al. Re-ranking Person
Re-identification with k-reciprocal Encoding[J]. 2017.
url:http://openaccess.thecvf.com/content_cvpr_2017/papers/Zhong_Re-Ranking_Person_Re-Identification_CVPR_2017_paper.pdf
Matlab version: https://github.com/zhunzhong07/person-re-ranking

API
    probFea: all feature vectors of the query set (torch tensor)
    galFea: all feature vectors of the gallery set (torch tensor)
    k1, k2, lambda_value: parameters; the original paper uses (20, 6, 0.3)
"""

import numpy as np
import torch


def re_ranking(probFea, galFea, k1, k2, lambda_value, local_distmat=None, only_local=False):
    """k-reciprocal encoding re-ranking.

    Returns a (num_query, num_gallery) numpy distance matrix blending the
    Jaccard distance of k-reciprocal neighbor sets with the original
    (feature-space) distance, weighted by ``lambda_value``.

    :param probFea: query features, torch tensor of shape (q, d)
    :param galFea: gallery features, torch tensor of shape (g, d)
    :param local_distmat: optional precomputed distance added to (or, with
        ``only_local``, replacing) the feature-space distance
    """
    query_num = probFea.size(0)
    all_num = query_num + galFea.size(0)
    if only_local:
        original_dist = local_distmat
    else:
        feat = torch.cat([probFea, galFea])
        # squared Euclidean distance via ||a||^2 + ||b||^2 - 2 a.b
        distmat = torch.pow(feat, 2).sum(dim=1, keepdim=True).expand(all_num, all_num) + \
                  torch.pow(feat, 2).sum(dim=1, keepdim=True).expand(all_num, all_num).t()
        # BUG FIX: the legacy positional form addmm_(beta, alpha, mat1, mat2)
        # was removed from PyTorch; pass beta/alpha as keywords.
        distmat.addmm_(feat, feat.t(), beta=1, alpha=-2)
        original_dist = distmat.cpu().numpy()
        del feat
        if local_distmat is not None:
            original_dist = original_dist + local_distmat

    gallery_num = original_dist.shape[0]
    # normalize each column by its max, then transpose (as in the paper code)
    original_dist = np.transpose(original_dist / np.max(original_dist, axis=0))
    V = np.zeros_like(original_dist).astype(np.float16)
    initial_rank = np.argsort(original_dist).astype(np.int32)

    for i in range(all_num):
        # k-reciprocal neighbors: j is kept if i is also among j's top-k1
        forward_k_neigh_index = initial_rank[i, :k1 + 1]
        backward_k_neigh_index = initial_rank[forward_k_neigh_index, :k1 + 1]
        fi = np.where(backward_k_neigh_index == i)[0]
        k_reciprocal_index = forward_k_neigh_index[fi]
        k_reciprocal_expansion_index = k_reciprocal_index
        # expand with each candidate's own (half-size) reciprocal set when
        # the overlap is large enough
        for j in range(len(k_reciprocal_index)):
            candidate = k_reciprocal_index[j]
            candidate_forward_k_neigh_index = initial_rank[candidate, :int(np.around(k1 / 2)) + 1]
            candidate_backward_k_neigh_index = initial_rank[candidate_forward_k_neigh_index,
                                               :int(np.around(k1 / 2)) + 1]
            fi_candidate = np.where(candidate_backward_k_neigh_index == candidate)[0]
            candidate_k_reciprocal_index = candidate_forward_k_neigh_index[fi_candidate]
            if len(np.intersect1d(candidate_k_reciprocal_index, k_reciprocal_index)) > 2 / 3 * len(
                    candidate_k_reciprocal_index):
                k_reciprocal_expansion_index = np.append(k_reciprocal_expansion_index, candidate_k_reciprocal_index)

        k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index)
        # Gaussian-kernel weights over the expanded reciprocal set
        weight = np.exp(-original_dist[i, k_reciprocal_expansion_index])
        V[i, k_reciprocal_expansion_index] = weight / np.sum(weight)
    original_dist = original_dist[:query_num, ]

    if k2 != 1:
        # local query expansion: average V over each sample's top-k2 list
        V_qe = np.zeros_like(V, dtype=np.float16)
        for i in range(all_num):
            V_qe[i, :] = np.mean(V[initial_rank[i, :k2], :], axis=0)
        V = V_qe
        del V_qe
    del initial_rank

    # inverted index: for each column, rows with non-zero membership
    invIndex = []
    for i in range(gallery_num):
        invIndex.append(np.where(V[:, i] != 0)[0])

    jaccard_dist = np.zeros_like(original_dist, dtype=np.float16)
    for i in range(query_num):
        temp_min = np.zeros(shape=[1, gallery_num], dtype=np.float16)
        indNonZero = np.where(V[i, :] != 0)[0]
        indImages = [invIndex[ind] for ind in indNonZero]
        for j in range(len(indNonZero)):
            temp_min[0, indImages[j]] = temp_min[0, indImages[j]] + np.minimum(V[i, indNonZero[j]],
                                                                              V[indImages[j], indNonZero[j]])
        jaccard_dist[i] = 1 - temp_min / (2 - temp_min)

    final_dist = jaccard_dist * (1 - lambda_value) + original_dist * lambda_value
    del original_dist
    del V
    del jaccard_dist
    # keep only query rows x gallery columns
    final_dist = final_dist[:query_num, query_num:]
    return final_dist