├── .gitignore ├── INSTALL.md ├── README.md ├── configs ├── Base-AGW.yml ├── Base-MGN.yml ├── Base-SBS.yml ├── Base-bagtricks.yml ├── DukeMTMC │ ├── AGW_R101-ibn.yml │ ├── AGW_R50-ibn.yml │ ├── AGW_R50.yml │ ├── AGW_S50.yml │ ├── bagtricks_R101-ibn.yml │ ├── bagtricks_R50-ibn.yml │ ├── bagtricks_R50.yml │ ├── bagtricks_S50.yml │ ├── mgn_R50-ibn.yml │ ├── sbs_R101-ibn.yml │ ├── sbs_R50-ibn.yml │ ├── sbs_R50.yml │ └── sbs_S50.yml ├── MSMT17 │ ├── AGW_R101-ibn.yml │ ├── AGW_R50-ibn.yml │ ├── AGW_R50.yml │ ├── AGW_S50.yml │ ├── bagtricks_R101-ibn.yml │ ├── bagtricks_R50-ibn.yml │ ├── bagtricks_R50.yml │ ├── bagtricks_S50.yml │ ├── mgn_R50-ibn.yml │ ├── sbs_R101-ibn.yml │ ├── sbs_R50-ibn.yml │ ├── sbs_R50.yml │ └── sbs_S50.yml ├── Market1501 │ ├── AGW_R101-ibn.yml │ ├── AGW_R50-ibn.yml │ ├── AGW_R50.yml │ ├── AGW_S50.yml │ ├── bagtricks_R101-ibn.yml │ ├── bagtricks_R50-ibn.yml │ ├── bagtricks_R50.yml │ ├── bagtricks_S50.yml │ ├── bagtricks_vit.yml │ ├── mgn_R50-ibn.yml │ ├── sbs_R101-ibn.yml │ ├── sbs_R50-ibn.yml │ ├── sbs_R50.yml │ └── sbs_S50.yml ├── VERIWild │ └── bagtricks_R50-ibn.yml ├── VeRi │ └── sbs_R50-ibn.yml ├── VehicleID │ └── bagtricks_R50-ibn.yml └── bagtricks_DR50_mix.yml ├── copy_launch.py ├── fastreid ├── __init__.py ├── config │ ├── __init__.py │ ├── config.py │ └── defaults.py ├── data │ ├── __init__.py │ ├── build.py │ ├── common.py │ ├── data_utils.py │ ├── datasets │ │ ├── AirportALERT.py │ │ ├── __init__.py │ │ ├── bases.py │ │ ├── caviara.py │ │ ├── cuhk03.py │ │ ├── cuhk03_full.py │ │ ├── cuhk_sysu.py │ │ ├── dukemtmcreid.py │ │ ├── grid.py │ │ ├── iLIDS.py │ │ ├── lpw.py │ │ ├── market1501.py │ │ ├── msmt17.py │ │ ├── pes3d.py │ │ ├── pku.py │ │ ├── prai.py │ │ ├── prid.py │ │ ├── saivt.py │ │ ├── sensereid.py │ │ ├── shinpuhkan.py │ │ ├── sysu_mm.py │ │ ├── thermalworld.py │ │ ├── vehicleid.py │ │ ├── veri.py │ │ ├── veriwild.py │ │ ├── viper.py │ │ └── wildtracker.py │ ├── samplers │ │ ├── __init__.py │ │ ├── data_sampler.py │ │ ├── 
imbalance_sampler.py │ │ └── triplet_sampler.py │ └── transforms │ │ ├── __init__.py │ │ ├── autoaugment.py │ │ ├── build.py │ │ ├── functional.py │ │ └── transforms.py ├── engine │ ├── __init__.py │ ├── defaults.py │ ├── hooks.py │ ├── launch.py │ └── train_loop.py ├── evaluation │ ├── __init__.py │ ├── clas_evaluator.py │ ├── evaluator.py │ ├── query_expansion.py │ ├── rank.py │ ├── rank_cylib │ │ ├── Makefile │ │ ├── __init__.py │ │ ├── rank_cy.pyx │ │ ├── roc_cy.pyx │ │ ├── setup.py │ │ └── test_cython.py │ ├── reid_evaluation.py │ ├── rerank.py │ ├── roc.py │ └── testing.py ├── layers │ ├── __init__.py │ ├── activation.py │ ├── any_softmax.py │ ├── batch_norm.py │ ├── context_block.py │ ├── drop.py │ ├── frn.py │ ├── gather_layer.py │ ├── helpers.py │ ├── non_local.py │ ├── pooling.py │ ├── se_layer.py │ ├── splat.py │ └── weight_init.py ├── modeling │ ├── __init__.py │ ├── backbones │ │ ├── __init__.py │ │ ├── build.py │ │ ├── meta_dynamic_router_resnet.py │ │ ├── mobilenet.py │ │ ├── mobilenetv3.py │ │ ├── osnet.py │ │ ├── regnet │ │ │ ├── __init__.py │ │ │ ├── config.py │ │ │ ├── effnet.py │ │ │ ├── effnet │ │ │ │ ├── EN-B0_dds_8gpu.yaml │ │ │ │ ├── EN-B1_dds_8gpu.yaml │ │ │ │ ├── EN-B2_dds_8gpu.yaml │ │ │ │ ├── EN-B3_dds_8gpu.yaml │ │ │ │ ├── EN-B4_dds_8gpu.yaml │ │ │ │ └── EN-B5_dds_8gpu.yaml │ │ │ ├── regnet.py │ │ │ ├── regnetx │ │ │ │ ├── RegNetX-1.6GF_dds_8gpu.yaml │ │ │ │ ├── RegNetX-12GF_dds_8gpu.yaml │ │ │ │ ├── RegNetX-16GF_dds_8gpu.yaml │ │ │ │ ├── RegNetX-200MF_dds_8gpu.yaml │ │ │ │ ├── RegNetX-3.2GF_dds_8gpu.yaml │ │ │ │ ├── RegNetX-32GF_dds_8gpu.yaml │ │ │ │ ├── RegNetX-4.0GF_dds_8gpu.yaml │ │ │ │ ├── RegNetX-400MF_dds_8gpu.yaml │ │ │ │ ├── RegNetX-6.4GF_dds_8gpu.yaml │ │ │ │ ├── RegNetX-600MF_dds_8gpu.yaml │ │ │ │ ├── RegNetX-8.0GF_dds_8gpu.yaml │ │ │ │ └── RegNetX-800MF_dds_8gpu.yaml │ │ │ └── regnety │ │ │ │ ├── RegNetY-1.6GF_dds_8gpu.yaml │ │ │ │ ├── RegNetY-12GF_dds_8gpu.yaml │ │ │ │ ├── RegNetY-16GF_dds_8gpu.yaml │ │ │ │ ├── 
RegNetY-200MF_dds_8gpu.yaml │ │ │ │ ├── RegNetY-3.2GF_dds_8gpu.yaml │ │ │ │ ├── RegNetY-32GF_dds_8gpu.yaml │ │ │ │ ├── RegNetY-4.0GF_dds_8gpu.yaml │ │ │ │ ├── RegNetY-400MF_dds_8gpu.yaml │ │ │ │ ├── RegNetY-6.4GF_dds_8gpu.yaml │ │ │ │ ├── RegNetY-600MF_dds_8gpu.yaml │ │ │ │ ├── RegNetY-8.0GF_dds_8gpu.yaml │ │ │ │ └── RegNetY-800MF_dds_8gpu.yaml │ │ ├── repvgg.py │ │ ├── resnest.py │ │ ├── resnet.py │ │ ├── resnext.py │ │ ├── shufflenet.py │ │ └── vision_transformer.py │ ├── heads │ │ ├── __init__.py │ │ ├── build.py │ │ ├── clas_head.py │ │ ├── embedding_head.py │ │ └── meta_embedding_head.py │ ├── losses │ │ ├── __init__.py │ │ ├── center_loss.py │ │ ├── circle_loss.py │ │ ├── cluster_loss.py │ │ ├── cross_entroy_loss.py │ │ ├── domain_SCT_loss.py │ │ ├── focal_loss.py │ │ ├── svmo.py │ │ ├── triplet_loss.py │ │ ├── triplet_loss_MetaIBN.py │ │ └── utils.py │ ├── meta_arch │ │ ├── __init__.py │ │ ├── baseline.py │ │ ├── build.py │ │ ├── distiller.py │ │ ├── mgn.py │ │ └── moco.py │ └── ops.py ├── solver │ ├── __init__.py │ ├── build.py │ ├── lr_scheduler.py │ └── optim │ │ ├── __init__.py │ │ ├── lamb.py │ │ ├── radam.py │ │ └── swa.py └── utils │ ├── __init__.py │ ├── checkpoint.py │ ├── collect_env.py │ ├── comm.py │ ├── compute_dist.py │ ├── env.py │ ├── events.py │ ├── faiss_utils.py │ ├── file_io.py │ ├── history_buffer.py │ ├── logger.py │ ├── params.py │ ├── precision_bn.py │ ├── registry.py │ ├── summary.py │ ├── timer.py │ └── visualizer.py ├── fig ├── pipeline.pdf └── pipeline.png ├── launch.sh ├── requirements.txt └── tools ├── deploy ├── Caffe │ ├── ReadMe.md │ ├── __init__.py │ ├── caffe.proto │ ├── caffe_lmdb.py │ ├── caffe_net.py │ ├── caffe_pb2.py │ ├── layer_param.py │ └── net.py ├── README.md ├── caffe_export.py ├── caffe_inference.py ├── onnx_export.py ├── onnx_inference.py ├── pytorch_to_caffe.py ├── test_data │ ├── 0022_c6s1_002976_01.jpg │ ├── 0027_c2s2_091032_02.jpg │ ├── 0032_c6s1_002851_01.jpg │ ├── 0048_c1s1_005351_01.jpg │ └── 
0065_c6s1_009501_02.jpg ├── trt_calibrator.py ├── trt_export.py └── trt_inference.py ├── plain_train_net.py └── train_net.py /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | logs 3 | pkl_zoo 4 | 5 | # compilation and distribution 6 | __pycache__ 7 | _ext 8 | *.pyc 9 | *.pyd 10 | *.so 11 | *.dll 12 | *.egg-info/ 13 | build/ 14 | dist/ 15 | wheels/ 16 | 17 | # pytorch/python/numpy formats 18 | *.pth 19 | *.pkl 20 | *.npy 21 | *.ts 22 | model_ts*.txt 23 | 24 | # ipython/jupyter notebooks 25 | *.ipynb 26 | **/.ipynb_checkpoints/ 27 | 28 | # Editor temporaries 29 | *.swn 30 | *.swo 31 | *.swp 32 | *~ 33 | 34 | # editor settings 35 | .idea 36 | .vscode 37 | _darcs 38 | .DS_Store 39 | -------------------------------------------------------------------------------- /INSTALL.md: -------------------------------------------------------------------------------- 1 | # Installation 2 | 3 | ## Requirements 4 | 5 | - Linux or macOS with python ≥ 3.6 6 | - PyTorch ≥ 1.6 7 | - torchvision that matches the Pytorch installation. You can install them together at [pytorch.org](https://pytorch.org/) to make sure of this. 
8 | - [yacs](https://github.com/rbgirshick/yacs) 9 | - Cython (optional, to compile the evaluation code) 10 | - tensorboard (needed for visualization): `pip install tensorboard` 11 | - gdown (for automatically downloading pre-trained models) 12 | - sklearn 13 | - termcolor 14 | - tabulate 15 | - [faiss](https://github.com/facebookresearch/faiss) `pip install faiss-cpu` 16 | 17 | 18 | 19 | # Set up with Conda 20 | ```shell script 21 | conda create -n fastreid python=3.7 22 | conda activate fastreid 23 | conda install pytorch==1.6.0 torchvision tensorboard -c pytorch 24 | pip install -r requirements.txt 25 | ``` 26 | 27 | # Set up with Docker 28 | Coming soon 29 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Adaptive Cross-Domain learning for Generalizable Person Re-Identification 2 | ![](fig/pipeline.png) 3 | 4 | ## Requirements 5 | + CUDA>=10.0 6 | + At least four 1080-Ti GPUs 7 | + For setup, refer to [INSTALL.md](INSTALL.md) 8 | + Other necessary packages are listed in [requirements.txt](requirements.txt) 9 | + Training Data \ 10 | The model is trained and evaluated on Market-1501, MSMT17, cuhkSYSU, CUHK03. To download these datasets, please refer to [fast-reid](https://github.com/JDAI-CV/fast-reid). 11 | ## Run 12 | 13 | ``` 14 | # train 15 | python copy_launch.py 16 | 17 | # test 18 | python3 tools/train_net.py --config-file ./configs/bagtricks_DR50_mix.yml --eval-only MODEL.WEIGHTS /path/to/checkpoint_file MODEL.DEVICE "cuda:0" 19 | ``` 20 | ## Acknowledgments 21 | This repo borrows partially from [fast-reid](https://github.com/JDAI-CV/fast-reid) and [MetaBIN](https://github.com/bismex/MetaBIN). 
-------------------------------------------------------------------------------- /configs/Base-AGW.yml: -------------------------------------------------------------------------------- 1 | _BASE_: Base-bagtricks.yml 2 | 3 | MODEL: 4 | BACKBONE: 5 | WITH_NL: True 6 | 7 | HEADS: 8 | POOL_LAYER: GeneralizedMeanPooling 9 | 10 | LOSSES: 11 | NAME: ("CrossEntropyLoss", "TripletLoss") 12 | CE: 13 | EPSILON: 0.1 14 | SCALE: 1.0 15 | 16 | TRI: 17 | MARGIN: 0.0 18 | HARD_MINING: False 19 | SCALE: 1.0 20 | -------------------------------------------------------------------------------- /configs/Base-MGN.yml: -------------------------------------------------------------------------------- 1 | _BASE_: Base-SBS.yml 2 | 3 | MODEL: 4 | META_ARCHITECTURE: MGN 5 | 6 | FREEZE_LAYERS: [backbone, b1, b2, b3,] 7 | 8 | BACKBONE: 9 | WITH_NL: False 10 | 11 | HEADS: 12 | EMBEDDING_DIM: 256 13 | -------------------------------------------------------------------------------- /configs/Base-SBS.yml: -------------------------------------------------------------------------------- 1 | _BASE_: Base-bagtricks.yml 2 | 3 | MODEL: 4 | FREEZE_LAYERS: [ backbone ] 5 | 6 | BACKBONE: 7 | WITH_NL: True 8 | 9 | HEADS: 10 | NECK_FEAT: after 11 | POOL_LAYER: GeneralizedMeanPoolingP 12 | CLS_LAYER: CircleSoftmax 13 | SCALE: 64 14 | MARGIN: 0.35 15 | 16 | LOSSES: 17 | NAME: ("CrossEntropyLoss", "TripletLoss",) 18 | CE: 19 | EPSILON: 0.1 20 | SCALE: 1.0 21 | 22 | TRI: 23 | MARGIN: 0.0 24 | HARD_MINING: True 25 | NORM_FEAT: False 26 | SCALE: 1.0 27 | 28 | INPUT: 29 | SIZE_TRAIN: [ 384, 128 ] 30 | SIZE_TEST: [ 384, 128 ] 31 | 32 | AUTOAUG: 33 | ENABLED: True 34 | PROB: 0.1 35 | 36 | DATALOADER: 37 | NUM_INSTANCE: 16 38 | 39 | SOLVER: 40 | AMP: 41 | ENABLED: True 42 | OPT: Adam 43 | MAX_EPOCH: 60 44 | BASE_LR: 0.00035 45 | WEIGHT_DECAY: 0.0005 46 | IMS_PER_BATCH: 64 47 | 48 | SCHED: CosineAnnealingLR 49 | DELAY_EPOCHS: 30 50 | ETA_MIN_LR: 0.0000007 51 | 52 | WARMUP_FACTOR: 0.1 53 | WARMUP_ITERS: 2000 54 | 55 | 
FREEZE_ITERS: 1000 56 | 57 | CHECKPOINT_PERIOD: 20 58 | 59 | TEST: 60 | EVAL_PERIOD: 10 61 | IMS_PER_BATCH: 128 62 | 63 | CUDNN_BENCHMARK: True 64 | -------------------------------------------------------------------------------- /configs/Base-bagtricks.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | META_ARCHITECTURE: Baseline 3 | 4 | BACKBONE: 5 | NAME: build_resnet_backbone 6 | NORM: BN 7 | DEPTH: 50x 8 | LAST_STRIDE: 1 9 | FEAT_DIM: 2048 10 | WITH_IBN: False 11 | PRETRAIN: True 12 | 13 | HEADS: 14 | NAME: EmbeddingHead 15 | NORM: BN 16 | WITH_BNNECK: True 17 | POOL_LAYER: GlobalAvgPool 18 | NECK_FEAT: before 19 | CLS_LAYER: Linear 20 | 21 | LOSSES: 22 | NAME: ("CrossEntropyLoss", "TripletLoss",) 23 | 24 | CE: 25 | EPSILON: 0.1 26 | SCALE: 1. 27 | 28 | TRI: 29 | MARGIN: 0.3 30 | HARD_MINING: True 31 | NORM_FEAT: False 32 | SCALE: 1. 33 | 34 | INPUT: 35 | SIZE_TRAIN: [ 256, 128 ] 36 | SIZE_TEST: [ 256, 128 ] 37 | 38 | REA: 39 | ENABLED: True 40 | PROB: 0.5 41 | 42 | FLIP: 43 | ENABLED: True 44 | 45 | PADDING: 46 | ENABLED: True 47 | 48 | DATALOADER: 49 | SAMPLER_TRAIN: NaiveIdentitySampler 50 | NUM_INSTANCE: 4 51 | NUM_WORKERS: 8 52 | 53 | SOLVER: 54 | AMP: 55 | ENABLED: True 56 | OPT: Adam 57 | MAX_EPOCH: 120 58 | BASE_LR: 0.00035 59 | WEIGHT_DECAY: 0.0005 60 | WEIGHT_DECAY_NORM: 0.0005 61 | IMS_PER_BATCH: 64 62 | 63 | SCHED: MultiStepLR 64 | STEPS: [ 40, 90 ] 65 | GAMMA: 0.1 66 | 67 | WARMUP_FACTOR: 0.1 68 | WARMUP_ITERS: 2000 69 | 70 | CHECKPOINT_PERIOD: 30 71 | 72 | TEST: 73 | EVAL_PERIOD: 30 74 | IMS_PER_BATCH: 128 75 | 76 | CUDNN_BENCHMARK: True 77 | -------------------------------------------------------------------------------- /configs/DukeMTMC/AGW_R101-ibn.yml: -------------------------------------------------------------------------------- 1 | _BASE_: ../Base-AGW.yml 2 | 3 | MODEL: 4 | BACKBONE: 5 | DEPTH: 101x 6 | WITH_IBN: True 7 | 8 | DATASETS: 9 | NAMES: ("DukeMTMC",) 10 | TESTS: ("DukeMTMC",) 11 
| 12 | OUTPUT_DIR: logs/dukemtmc/agw_R101-ibn 13 | -------------------------------------------------------------------------------- /configs/DukeMTMC/AGW_R50-ibn.yml: -------------------------------------------------------------------------------- 1 | _BASE_: ../Base-AGW.yml 2 | 3 | MODEL: 4 | BACKBONE: 5 | WITH_IBN: True 6 | 7 | DATASETS: 8 | NAMES: ("DukeMTMC",) 9 | TESTS: ("DukeMTMC",) 10 | 11 | OUTPUT_DIR: logs/dukemtmc/agw_R50-ibn 12 | -------------------------------------------------------------------------------- /configs/DukeMTMC/AGW_R50.yml: -------------------------------------------------------------------------------- 1 | _BASE_: ../Base-AGW.yml 2 | 3 | DATASETS: 4 | NAMES: ("DukeMTMC",) 5 | TESTS: ("DukeMTMC",) 6 | 7 | OUTPUT_DIR: logs/dukemtmc/agw_R50 8 | -------------------------------------------------------------------------------- /configs/DukeMTMC/AGW_S50.yml: -------------------------------------------------------------------------------- 1 | _BASE_: ../Base-AGW.yml 2 | 3 | MODEL: 4 | BACKBONE: 5 | NAME: build_resnest_backbone 6 | 7 | DATASETS: 8 | NAMES: ("DukeMTMC",) 9 | TESTS: ("DukeMTMC",) 10 | 11 | OUTPUT_DIR: logs/dukemtmc/agw_S50 12 | -------------------------------------------------------------------------------- /configs/DukeMTMC/bagtricks_R101-ibn.yml: -------------------------------------------------------------------------------- 1 | _BASE_: ../Base-bagtricks.yml 2 | 3 | MODEL: 4 | BACKBONE: 5 | DEPTH: 101x 6 | WITH_IBN: True 7 | 8 | DATASETS: 9 | NAMES: ("DukeMTMC",) 10 | TESTS: ("DukeMTMC",) 11 | 12 | OUTPUT_DIR: logs/dukemtmc/bagtricks_R101-ibn 13 | -------------------------------------------------------------------------------- /configs/DukeMTMC/bagtricks_R50-ibn.yml: -------------------------------------------------------------------------------- 1 | _BASE_: ../Base-bagtricks.yml 2 | 3 | MODEL: 4 | BACKBONE: 5 | WITH_IBN: True 6 | 7 | DATASETS: 8 | NAMES: ("DukeMTMC",) 9 | TESTS: ("DukeMTMC",) 10 | 11 | OUTPUT_DIR: 
logs/dukemtmc/bagtricks_R50-ibn 12 | -------------------------------------------------------------------------------- /configs/DukeMTMC/bagtricks_R50.yml: -------------------------------------------------------------------------------- 1 | _BASE_: ../Base-bagtricks.yml 2 | 3 | DATASETS: 4 | NAMES: ("DukeMTMC",) 5 | TESTS: ("DukeMTMC",) 6 | 7 | OUTPUT_DIR: logs/dukemtmc/bagtricks_R50 8 | -------------------------------------------------------------------------------- /configs/DukeMTMC/bagtricks_S50.yml: -------------------------------------------------------------------------------- 1 | _BASE_: ../Base-bagtricks.yml 2 | 3 | MODEL: 4 | BACKBONE: 5 | NAME: build_resnest_backbone 6 | 7 | DATASETS: 8 | NAMES: ("DukeMTMC",) 9 | TESTS: ("DukeMTMC",) 10 | 11 | OUTPUT_DIR: logs/dukemtmc/bagtricks_S50 12 | -------------------------------------------------------------------------------- /configs/DukeMTMC/mgn_R50-ibn.yml: -------------------------------------------------------------------------------- 1 | _BASE_: ../Base-MGN.yml 2 | 3 | MODEL: 4 | BACKBONE: 5 | WITH_IBN: True 6 | 7 | DATASETS: 8 | NAMES: ("DukeMTMC",) 9 | TESTS: ("DukeMTMC",) 10 | 11 | OUTPUT_DIR: logs/dukemtmc/mgn_R50-ibn 12 | -------------------------------------------------------------------------------- /configs/DukeMTMC/sbs_R101-ibn.yml: -------------------------------------------------------------------------------- 1 | _BASE_: ../Base-SBS.yml 2 | 3 | MODEL: 4 | BACKBONE: 5 | DEPTH: 101x 6 | WITH_IBN: True 7 | 8 | DATASETS: 9 | NAMES: ("DukeMTMC",) 10 | TESTS: ("DukeMTMC",) 11 | 12 | OUTPUT_DIR: logs/dukemtmc/sbs_R101-ibn 13 | -------------------------------------------------------------------------------- /configs/DukeMTMC/sbs_R50-ibn.yml: -------------------------------------------------------------------------------- 1 | _BASE_: ../Base-SBS.yml 2 | 3 | MODEL: 4 | BACKBONE: 5 | WITH_IBN: True 6 | 7 | DATASETS: 8 | NAMES: ("DukeMTMC",) 9 | TESTS: ("DukeMTMC",) 10 | 11 | OUTPUT_DIR: 
logs/dukemtmc/sbs_R50-ibn 12 | -------------------------------------------------------------------------------- /configs/DukeMTMC/sbs_R50.yml: -------------------------------------------------------------------------------- 1 | _BASE_: ../Base-SBS.yml 2 | 3 | DATASETS: 4 | NAMES: ("DukeMTMC",) 5 | TESTS: ("DukeMTMC",) 6 | 7 | OUTPUT_DIR: logs/dukemtmc/sbs_R50 8 | -------------------------------------------------------------------------------- /configs/DukeMTMC/sbs_S50.yml: -------------------------------------------------------------------------------- 1 | _BASE_: ../Base-SBS.yml 2 | 3 | MODEL: 4 | BACKBONE: 5 | NAME: build_resnest_backbone 6 | 7 | DATASETS: 8 | NAMES: ("DukeMTMC",) 9 | TESTS: ("DukeMTMC",) 10 | 11 | OUTPUT_DIR: logs/dukemtmc/sbs_S50 12 | -------------------------------------------------------------------------------- /configs/MSMT17/AGW_R101-ibn.yml: -------------------------------------------------------------------------------- 1 | _BASE_: ../Base-AGW.yml 2 | 3 | MODEL: 4 | BACKBONE: 5 | DEPTH: 101x 6 | WITH_IBN: True 7 | 8 | DATASETS: 9 | NAMES: ("MSMT17",) 10 | TESTS: ("MSMT17",) 11 | 12 | OUTPUT_DIR: logs/msmt17/agw_R101-ibn 13 | -------------------------------------------------------------------------------- /configs/MSMT17/AGW_R50-ibn.yml: -------------------------------------------------------------------------------- 1 | _BASE_: ../Base-AGW.yml 2 | 3 | MODEL: 4 | BACKBONE: 5 | WITH_IBN: True 6 | 7 | DATASETS: 8 | NAMES: ("MSMT17",) 9 | TESTS: ("MSMT17",) 10 | 11 | OUTPUT_DIR: logs/msmt17/agw_R50-ibn 12 | -------------------------------------------------------------------------------- /configs/MSMT17/AGW_R50.yml: -------------------------------------------------------------------------------- 1 | _BASE_: ../Base-AGW.yml 2 | 3 | DATASETS: 4 | NAMES: ("MSMT17",) 5 | TESTS: ("MSMT17",) 6 | 7 | OUTPUT_DIR: logs/msmt17/agw_R50 8 | -------------------------------------------------------------------------------- /configs/MSMT17/AGW_S50.yml: 
-------------------------------------------------------------------------------- 1 | _BASE_: ../Base-AGW.yml 2 | 3 | MODEL: 4 | BACKBONE: 5 | NAME: build_resnest_backbone 6 | 7 | DATASETS: 8 | NAMES: ("MSMT17",) 9 | TESTS: ("MSMT17",) 10 | 11 | OUTPUT_DIR: logs/msmt17/agw_S50 12 | -------------------------------------------------------------------------------- /configs/MSMT17/bagtricks_R101-ibn.yml: -------------------------------------------------------------------------------- 1 | _BASE_: ../Base-bagtricks.yml 2 | 3 | MODEL: 4 | BACKBONE: 5 | DEPTH: 101x 6 | WITH_IBN: True 7 | 8 | DATASETS: 9 | NAMES: ("MSMT17",) 10 | TESTS: ("MSMT17",) 11 | 12 | OUTPUT_DIR: logs/msmt17/bagtricks_R101-ibn 13 | 14 | -------------------------------------------------------------------------------- /configs/MSMT17/bagtricks_R50-ibn.yml: -------------------------------------------------------------------------------- 1 | _BASE_: ../Base-bagtricks.yml 2 | 3 | MODEL: 4 | BACKBONE: 5 | WITH_IBN: True 6 | 7 | DATASETS: 8 | NAMES: ("MSMT17",) 9 | TESTS: ("MSMT17",) 10 | 11 | OUTPUT_DIR: logs/msmt17/bagtricks_R50-ibn 12 | 13 | -------------------------------------------------------------------------------- /configs/MSMT17/bagtricks_R50.yml: -------------------------------------------------------------------------------- 1 | _BASE_: ../Base-bagtricks.yml 2 | 3 | DATASETS: 4 | NAMES: ("MSMT17",) 5 | TESTS: ("MSMT17",) 6 | 7 | OUTPUT_DIR: logs/msmt17/bagtricks_R50 8 | -------------------------------------------------------------------------------- /configs/MSMT17/bagtricks_S50.yml: -------------------------------------------------------------------------------- 1 | _BASE_: ../Base-bagtricks.yml 2 | 3 | MODEL: 4 | BACKBONE: 5 | NAME: build_resnest_backbone 6 | 7 | DATASETS: 8 | NAMES: ("MSMT17",) 9 | TESTS: ("MSMT17",) 10 | 11 | OUTPUT_DIR: logs/msmt17/bagtricks_S50 12 | 13 | -------------------------------------------------------------------------------- /configs/MSMT17/mgn_R50-ibn.yml: 
-------------------------------------------------------------------------------- 1 | _BASE_: ../Base-MGN.yml 2 | 3 | MODEL: 4 | BACKBONE: 5 | WITH_IBN: True 6 | 7 | DATASETS: 8 | NAMES: ("MSMT17",) 9 | TESTS: ("MSMT17",) 10 | 11 | OUTPUT_DIR: logs/msmt17/mgn_R50-ibn 12 | -------------------------------------------------------------------------------- /configs/MSMT17/sbs_R101-ibn.yml: -------------------------------------------------------------------------------- 1 | _BASE_: ../Base-SBS.yml 2 | 3 | MODEL: 4 | BACKBONE: 5 | DEPTH: 101x 6 | WITH_IBN: True 7 | 8 | DATASETS: 9 | NAMES: ("MSMT17",) 10 | TESTS: ("MSMT17",) 11 | 12 | OUTPUT_DIR: logs/msmt17/sbs_R101-ibn 13 | -------------------------------------------------------------------------------- /configs/MSMT17/sbs_R50-ibn.yml: -------------------------------------------------------------------------------- 1 | _BASE_: ../Base-SBS.yml 2 | 3 | MODEL: 4 | BACKBONE: 5 | WITH_IBN: True 6 | 7 | DATASETS: 8 | NAMES: ("MSMT17",) 9 | TESTS: ("MSMT17",) 10 | 11 | OUTPUT_DIR: logs/msmt17/sbs_R50-ibn 12 | -------------------------------------------------------------------------------- /configs/MSMT17/sbs_R50.yml: -------------------------------------------------------------------------------- 1 | _BASE_: ../Base-SBS.yml 2 | 3 | DATASETS: 4 | NAMES: ("MSMT17",) 5 | TESTS: ("MSMT17",) 6 | 7 | OUTPUT_DIR: logs/msmt17/sbs_R50 8 | -------------------------------------------------------------------------------- /configs/MSMT17/sbs_S50.yml: -------------------------------------------------------------------------------- 1 | _BASE_: ../Base-SBS.yml 2 | 3 | MODEL: 4 | BACKBONE: 5 | NAME: build_resnest_backbone 6 | 7 | DATASETS: 8 | NAMES: ("MSMT17",) 9 | TESTS: ("MSMT17",) 10 | 11 | OUTPUT_DIR: logs/msmt17/sbs_S50 12 | -------------------------------------------------------------------------------- /configs/Market1501/AGW_R101-ibn.yml: -------------------------------------------------------------------------------- 1 | _BASE_: 
../Base-AGW.yml 2 | 3 | MODEL: 4 | BACKBONE: 5 | DEPTH: 101x 6 | WITH_IBN: True 7 | 8 | DATASETS: 9 | NAMES: ("Market1501",) 10 | TESTS: ("Market1501",) 11 | 12 | OUTPUT_DIR: logs/market1501/agw_R101-ibn 13 | -------------------------------------------------------------------------------- /configs/Market1501/AGW_R50-ibn.yml: -------------------------------------------------------------------------------- 1 | _BASE_: ../Base-AGW.yml 2 | 3 | MODEL: 4 | BACKBONE: 5 | WITH_IBN: True 6 | 7 | DATASETS: 8 | NAMES: ("Market1501",) 9 | TESTS: ("Market1501",) 10 | 11 | OUTPUT_DIR: logs/market1501/agw_R50-ibn 12 | -------------------------------------------------------------------------------- /configs/Market1501/AGW_R50.yml: -------------------------------------------------------------------------------- 1 | _BASE_: ../Base-AGW.yml 2 | 3 | DATASETS: 4 | NAMES: ("Market1501",) 5 | TESTS: ("Market1501",) 6 | 7 | OUTPUT_DIR: logs/market1501/agw_R50 8 | -------------------------------------------------------------------------------- /configs/Market1501/AGW_S50.yml: -------------------------------------------------------------------------------- 1 | _BASE_: ../Base-AGW.yml 2 | 3 | MODEL: 4 | BACKBONE: 5 | NAME: build_resnest_backbone 6 | 7 | DATASETS: 8 | NAMES: ("Market1501",) 9 | TESTS: ("Market1501",) 10 | 11 | OUTPUT_DIR: logs/market1501/agw_S50 12 | -------------------------------------------------------------------------------- /configs/Market1501/bagtricks_R101-ibn.yml: -------------------------------------------------------------------------------- 1 | _BASE_: ../Base-bagtricks.yml 2 | 3 | MODEL: 4 | BACKBONE: 5 | DEPTH: 101x 6 | WITH_IBN: True 7 | 8 | DATASETS: 9 | NAMES: ("Market1501",) 10 | TESTS: ("Market1501",) 11 | 12 | OUTPUT_DIR: logs/market1501/bagtricks_R101-ibn 13 | -------------------------------------------------------------------------------- /configs/Market1501/bagtricks_R50-ibn.yml: 
-------------------------------------------------------------------------------- 1 | _BASE_: ../Base-bagtricks.yml 2 | 3 | MODEL: 4 | BACKBONE: 5 | WITH_IBN: True 6 | 7 | DATASETS: 8 | NAMES: ("Market1501",) 9 | TESTS: ("Market1501",) 10 | 11 | OUTPUT_DIR: logs/market1501/bagtricks_R50-ibn 12 | -------------------------------------------------------------------------------- /configs/Market1501/bagtricks_R50.yml: -------------------------------------------------------------------------------- 1 | _BASE_: ../Base-bagtricks.yml 2 | 3 | DATASETS: 4 | NAMES: ("Market1501",) 5 | TESTS: ("Market1501",) 6 | 7 | OUTPUT_DIR: logs/market1501/bagtricks_R50 8 | -------------------------------------------------------------------------------- /configs/Market1501/bagtricks_S50.yml: -------------------------------------------------------------------------------- 1 | _BASE_: ../Base-bagtricks.yml 2 | 3 | MODEL: 4 | BACKBONE: 5 | NAME: build_resnest_backbone 6 | 7 | DATASETS: 8 | NAMES: ("Market1501",) 9 | TESTS: ("Market1501",) 10 | 11 | OUTPUT_DIR: logs/market1501/bagtricks_S50 12 | -------------------------------------------------------------------------------- /configs/Market1501/bagtricks_vit.yml: -------------------------------------------------------------------------------- 1 | 2 | MODEL: 3 | META_ARCHITECTURE: Baseline 4 | PIXEL_MEAN: [127.5, 127.5, 127.5] 5 | PIXEL_STD: [127.5, 127.5, 127.5] 6 | 7 | BACKBONE: 8 | NAME: build_vit_backbone 9 | DEPTH: base 10 | FEAT_DIM: 768 11 | PRETRAIN: True 12 | PRETRAIN_PATH: /export/home/lxy/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth 13 | STRIDE_SIZE: (16, 16) 14 | DROP_PATH_RATIO: 0.1 15 | DROP_RATIO: 0.0 16 | ATT_DROP_RATE: 0.0 17 | 18 | HEADS: 19 | NAME: EmbeddingHead 20 | NORM: BN 21 | WITH_BNNECK: True 22 | POOL_LAYER: Identity 23 | NECK_FEAT: before 24 | CLS_LAYER: Linear 25 | 26 | LOSSES: 27 | NAME: ("CrossEntropyLoss", "TripletLoss",) 28 | 29 | CE: 30 | EPSILON: 0. # no smooth 31 | SCALE: 1. 
32 | 33 | TRI: 34 | MARGIN: 0.0 35 | HARD_MINING: True 36 | NORM_FEAT: False 37 | SCALE: 1. 38 | 39 | INPUT: 40 | SIZE_TRAIN: [ 256, 128 ] 41 | SIZE_TEST: [ 256, 128 ] 42 | 43 | REA: 44 | ENABLED: True 45 | PROB: 0.5 46 | 47 | FLIP: 48 | ENABLED: True 49 | 50 | PADDING: 51 | ENABLED: True 52 | 53 | DATALOADER: 54 | SAMPLER_TRAIN: NaiveIdentitySampler 55 | NUM_INSTANCE: 4 56 | NUM_WORKERS: 8 57 | 58 | SOLVER: 59 | AMP: 60 | ENABLED: False 61 | OPT: SGD 62 | MAX_EPOCH: 120 63 | BASE_LR: 0.008 64 | WEIGHT_DECAY: 0.0001 65 | IMS_PER_BATCH: 64 66 | 67 | SCHED: CosineAnnealingLR 68 | ETA_MIN_LR: 0.000016 69 | 70 | WARMUP_FACTOR: 0.01 71 | WARMUP_ITERS: 1000 72 | 73 | CLIP_GRADIENTS: 74 | ENABLED: True 75 | 76 | CHECKPOINT_PERIOD: 30 77 | 78 | TEST: 79 | EVAL_PERIOD: 5 80 | IMS_PER_BATCH: 128 81 | 82 | CUDNN_BENCHMARK: True 83 | 84 | DATASETS: 85 | NAMES: ("Market1501",) 86 | TESTS: ("Market1501",) 87 | 88 | OUTPUT_DIR: logs/market1501/sbs_vit_base 89 | -------------------------------------------------------------------------------- /configs/Market1501/mgn_R50-ibn.yml: -------------------------------------------------------------------------------- 1 | _BASE_: ../Base-MGN.yml 2 | 3 | MODEL: 4 | BACKBONE: 5 | WITH_IBN: True 6 | 7 | DATASETS: 8 | NAMES: ("Market1501",) 9 | TESTS: ("Market1501",) 10 | 11 | OUTPUT_DIR: logs/market1501/mgn_R50-ibn 12 | -------------------------------------------------------------------------------- /configs/Market1501/sbs_R101-ibn.yml: -------------------------------------------------------------------------------- 1 | _BASE_: ../Base-SBS.yml 2 | 3 | MODEL: 4 | BACKBONE: 5 | DEPTH: 101x 6 | WITH_IBN: True 7 | 8 | DATASETS: 9 | NAMES: ("Market1501",) 10 | TESTS: ("Market1501",) 11 | 12 | OUTPUT_DIR: logs/market1501/sbs_R101-ibn 13 | -------------------------------------------------------------------------------- /configs/Market1501/sbs_R50-ibn.yml: -------------------------------------------------------------------------------- 1 | _BASE_: 
../Base-SBS.yml 2 | 3 | MODEL: 4 | BACKBONE: 5 | WITH_IBN: True 6 | 7 | DATASETS: 8 | NAMES: ("Market1501",) 9 | TESTS: ("Market1501",) 10 | 11 | OUTPUT_DIR: logs/market1501/sbs_R50-ibn 12 | -------------------------------------------------------------------------------- /configs/Market1501/sbs_R50.yml: -------------------------------------------------------------------------------- 1 | _BASE_: ../Base-SBS.yml 2 | 3 | DATASETS: 4 | NAMES: ("Market1501",) 5 | TESTS: ("Market1501",) 6 | 7 | OUTPUT_DIR: logs/market1501/sbs_R50 8 | -------------------------------------------------------------------------------- /configs/Market1501/sbs_S50.yml: -------------------------------------------------------------------------------- 1 | _BASE_: ../Base-SBS.yml 2 | 3 | MODEL: 4 | BACKBONE: 5 | NAME: build_resnest_backbone 6 | 7 | DATASETS: 8 | NAMES: ("Market1501",) 9 | TESTS: ("Market1501",) 10 | 11 | OUTPUT_DIR: logs/market1501/sbs_S50 12 | -------------------------------------------------------------------------------- /configs/VERIWild/bagtricks_R50-ibn.yml: -------------------------------------------------------------------------------- 1 | _BASE_: ../Base-bagtricks.yml 2 | 3 | INPUT: 4 | SIZE_TRAIN: [256, 256] 5 | SIZE_TEST: [256, 256] 6 | 7 | MODEL: 8 | BACKBONE: 9 | WITH_IBN: True 10 | 11 | HEADS: 12 | POOL_LAYER: GeneralizedMeanPooling 13 | 14 | LOSSES: 15 | TRI: 16 | HARD_MINING: False 17 | MARGIN: 0.0 18 | 19 | DATASETS: 20 | NAMES: ("VeRiWild",) 21 | TESTS: ("SmallVeRiWild", "MediumVeRiWild", "LargeVeRiWild",) 22 | 23 | SOLVER: 24 | IMS_PER_BATCH: 512 # 512 For 4 GPUs 25 | MAX_EPOCH: 120 26 | STEPS: [30, 70, 90] 27 | WARMUP_ITERS: 5000 28 | 29 | CHECKPOINT_PERIOD: 20 30 | 31 | TEST: 32 | EVAL_PERIOD: 10 33 | IMS_PER_BATCH: 128 34 | 35 | OUTPUT_DIR: logs/veriwild/bagtricks_R50-ibn_4gpu 36 | -------------------------------------------------------------------------------- /configs/VeRi/sbs_R50-ibn.yml: 
-------------------------------------------------------------------------------- 1 | _BASE_: ../Base-SBS.yml 2 | 3 | INPUT: 4 | SIZE_TRAIN: [256, 256] 5 | SIZE_TEST: [256, 256] 6 | 7 | MODEL: 8 | BACKBONE: 9 | WITH_IBN: True 10 | WITH_NL: True 11 | 12 | SOLVER: 13 | OPT: SGD 14 | BASE_LR: 0.01 15 | ETA_MIN_LR: 7.7e-5 16 | 17 | IMS_PER_BATCH: 64 18 | MAX_EPOCH: 60 19 | WARMUP_ITERS: 3000 20 | FREEZE_ITERS: 3000 21 | 22 | CHECKPOINT_PERIOD: 10 23 | 24 | DATASETS: 25 | NAMES: ("VeRi",) 26 | TESTS: ("VeRi",) 27 | 28 | DATALOADER: 29 | SAMPLER_TRAIN: BalancedIdentitySampler 30 | 31 | TEST: 32 | EVAL_PERIOD: 10 33 | IMS_PER_BATCH: 256 34 | 35 | OUTPUT_DIR: logs/veri/sbs_R50-ibn 36 | -------------------------------------------------------------------------------- /configs/VehicleID/bagtricks_R50-ibn.yml: -------------------------------------------------------------------------------- 1 | _BASE_: ../Base-bagtricks.yml 2 | 3 | INPUT: 4 | SIZE_TRAIN: [256, 256] 5 | SIZE_TEST: [256, 256] 6 | 7 | MODEL: 8 | BACKBONE: 9 | WITH_IBN: True 10 | HEADS: 11 | POOL_LAYER: GeneralizedMeanPooling 12 | 13 | LOSSES: 14 | TRI: 15 | HARD_MINING: False 16 | MARGIN: 0.0 17 | 18 | DATASETS: 19 | NAMES: ("VehicleID",) 20 | TESTS: ("SmallVehicleID", "MediumVehicleID", "LargeVehicleID",) 21 | 22 | SOLVER: 23 | BIAS_LR_FACTOR: 1. 
24 | 25 | IMS_PER_BATCH: 512 26 | MAX_EPOCH: 60 27 | STEPS: [30, 50] 28 | WARMUP_ITERS: 2000 29 | 30 | CHECKPOINT_PERIOD: 20 31 | 32 | TEST: 33 | EVAL_PERIOD: 20 34 | IMS_PER_BATCH: 128 35 | 36 | OUTPUT_DIR: logs/vehicleid/bagtricks_R50-ibn_4gpu 37 | -------------------------------------------------------------------------------- /configs/bagtricks_DR50_mix.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | META_ARCHITECTURE: Baseline 3 | 4 | BACKBONE: 5 | NAME: build_meta_dynamic_router_resnet_backbone 6 | NORM: BN 7 | DEPTH: 50x 8 | LAST_STRIDE: 1 9 | FEAT_DIM: 2048 10 | WITH_IBN: True 11 | PRETRAIN: True 12 | 13 | HEADS: 14 | NAME: MetaEmbeddingHead 15 | NORM: BN 16 | WITH_BNNECK: True 17 | POOL_LAYER: GeneralizedMeanPooling 18 | NECK_FEAT: after 19 | CLS_LAYER: Linear 20 | 21 | LOSSES: 22 | NAME: ("CrossEntropyLoss", "TripletLoss") 23 | 24 | CE: 25 | EPSILON: 0.1 26 | SCALE: .5 27 | 28 | TRI: 29 | MARGIN: 0.3 30 | HARD_MINING: True 31 | NORM_FEAT: False 32 | SCALE: 1. 
"""Snapshot the working tree into the run's output directory and launch it.

The script reads the last non-blank line of ``launch.sh`` to find the config
path (the last word containing ``configs``), reads the last non-blank line of
that config to find the output directory (its second whitespace-separated
field, e.g. ``OUTPUT_DIR: logs/...``), copies the current working directory to
``<output dir>/code`` and runs ``sh launch.sh`` from inside the copy.
"""
import os
import shutil
import subprocess
import time


def _last_nonblank_line(path):
    """Return the last line of *path* that contains non-whitespace text."""
    with open(path, 'r') as handle:
        lines = [line for line in handle if line.strip()]
    if not lines:
        raise SystemExit('{} contains no non-blank lines'.format(path))
    return lines[-1]


# The launch command is expected on the last line of launch.sh; when several
# words contain 'configs' the last one wins.
config = None
for word in _last_nonblank_line('launch.sh').split(' '):
    if 'configs' in word:
        config = word.strip()
if config is None:
    raise SystemExit("No word containing 'configs' on the last line of launch.sh")
print('Load config from ', config)

# The config is expected to end with a "KEY: value" line naming the output dir.
fields = _last_nonblank_line(config).split()
if len(fields) < 2:
    raise SystemExit('Last line of {} does not look like "KEY: value"'.format(config))
save_path = fields[1]
print('Save code to ', save_path)

os.makedirs(save_path, exist_ok=True)
code_dir = os.path.abspath(os.path.join(save_path, 'code'))


def _skip_snapshot_dir(directory, names):
    """Exclude the snapshot destination itself from the copy.

    Without this, an output directory located inside the working tree makes
    ``copytree`` descend into the directory it is currently writing.
    """
    return [name for name in names
            if os.path.abspath(os.path.join(directory, name)) == code_dir]


shutil.copytree(os.getcwd(), code_dir, ignore=_skip_snapshot_dir)
print('Start process')
os.chdir(code_dir)
# The exit status of the launched job is intentionally not checked.
subprocess.run(['sh', 'launch.sh'], check=False)
print('Done')
-------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | 8 | __version__ = "1.3" 9 | -------------------------------------------------------------------------------- /fastreid/config/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: l1aoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | from .config import CfgNode, get_cfg, global_cfg, set_global_cfg, configurable 8 | 9 | __all__ = [ 10 | 'CfgNode', 11 | 'get_cfg', 12 | 'global_cfg', 13 | 'set_global_cfg', 14 | 'configurable' 15 | ] 16 | -------------------------------------------------------------------------------- /fastreid/data/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | from . import transforms # isort:skip 8 | from .build import ( 9 | build_reid_train_loader, 10 | build_reid_test_loader 11 | ) 12 | from .common import CommDataset 13 | 14 | # ensure the builtin datasets are registered 15 | from . 
class CommDataset(Dataset):
    """Image re-identification dataset that also reports a domain id per sample.

    Args:
        img_items: sequence of records indexed as ``(img_path, pid, camid)``.
        transform: ``None`` or an indexable pair of callables; element 0
            produces ``"images0"`` and element 1 produces ``"images"``, both
            applied to the freshly loaded image.
        relabel: when true, pids and camids are replaced by contiguous
            integer indices (pids shifted by ``offset``).
        mapping: required. Either a dict mapping the pid prefix (the text
            before the first ``"_"``) to a domain index in
            ``range(len(mapping))``, or any other value, which is then
            returned unchanged as the domain id of every sample.
        offset: integer added to every relabelled pid.
    """

    def __init__(self, img_items, transform=None, relabel=True, mapping=None, offset=0):
        self.img_items = img_items
        self.transform = transform
        self.relabel = relabel
        self.mapping = mapping

        assert self.mapping is not None, 'mapping must be initialized!!!'

        if isinstance(self.mapping, dict):
            # Group ids per domain so that the final ordering is
            # domain-major and sorted inside each domain.
            num_domains = len(self.mapping)
            pid_sets = [set() for _ in range(num_domains)]
            cam_sets = [set() for _ in range(num_domains)]
            for item in img_items:
                domain_id = self.mapping[item[1].split("_")[0]]
                pid_sets[domain_id].add(item[1])
                cam_sets[domain_id].add(item[2])

            self.pids = []
            self.cams = []
            for domain_pids, domain_cams in zip(pid_sets, cam_sets):
                self.pids += sorted(domain_pids)
                self.cams += sorted(domain_cams)
        else:
            self.pids = sorted({item[1] for item in img_items})
            self.cams = sorted({item[2] for item in img_items})

        if relabel:
            self.pid_dict = {pid: idx + offset for idx, pid in enumerate(self.pids)}
            self.cam_dict = {cam: idx for idx, cam in enumerate(self.cams)}

    def __len__(self):
        return len(self.img_items)

    def __getitem__(self, index):
        """Load one sample and return it as a dict of image(s) and labels."""
        img_item = self.img_items[index]
        img_path = img_item[0]
        pid = img_item[1]
        camid = img_item[2]
        img = read_image(img_path)
        if self.transform is not None:
            img0 = self.transform[0](img)
            img = self.transform[1](img)
        else:
            # Without transforms both views are the untouched image; this
            # previously left ``img0`` undefined and crashed on return.
            img0 = img
        if self.mapping and isinstance(self.mapping, dict):
            domain_id = self.mapping[pid.split("_")[0]]
        else:
            domain_id = self.mapping
        if self.relabel:
            pid = self.pid_dict[pid]
            camid = self.cam_dict[camid]
        return {
            "images0": img0,
            "images": img,
            "targets": pid,
            "camids": camid,
            "domainids": domain_id,
            "img_paths": img_path,
        }

    @property
    def num_classes(self):
        """Number of distinct person ids seen at construction time."""
        return len(self.pids)

    @property
    def num_cameras(self):
        """Number of distinct camera ids seen at construction time."""
        return len(self.cams)
path in img_paths: 41 | split_path = path.split('\\') 42 | img_path = '/'.join(split_path) 43 | camid = self.dataset_name + "_" + split_path[0] 44 | pid = self.dataset_name + "_" + split_path[1] 45 | img_path = os.path.join(dir_path, img_path) 46 | # if 11001 <= int(split_path[1]) <= 401999: 47 | if 11001 <= int(split_path[1]): 48 | data.append([img_path, pid, camid]) 49 | 50 | return data 51 | -------------------------------------------------------------------------------- /fastreid/data/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | from ...utils.registry import Registry 8 | 9 | DATASET_REGISTRY = Registry("DATASET") 10 | DATASET_REGISTRY.__doc__ = """ 11 | Registry for datasets 12 | It must returns an instance of :class:`Backbone`. 13 | """ 14 | 15 | # Person re-id datasets 16 | from .cuhk03 import CUHK03 17 | from .dukemtmcreid import DukeMTMC 18 | from .market1501 import Market1501 19 | from .msmt17 import MSMT17 20 | from .AirportALERT import AirportALERT 21 | from .iLIDS import iLIDS 22 | from .pku import PKU 23 | from .prai import PRAI 24 | from .prid import PRID 25 | from .grid import GRID 26 | from .saivt import SAIVT 27 | from .sensereid import SenseReID 28 | from .sysu_mm import SYSU_mm 29 | from .thermalworld import Thermalworld 30 | from .pes3d import PeS3D 31 | from .caviara import CAVIARa 32 | from .viper import VIPeR 33 | from .lpw import LPW 34 | from .shinpuhkan import Shinpuhkan 35 | from .wildtracker import WildTrackCrop 36 | from .cuhk_sysu import cuhkSYSU 37 | from .cuhk03_full import CUHK03FULL 38 | 39 | # Vehicle re-id datasets 40 | from .veri import VeRi 41 | from .vehicleid import VehicleID, SmallVehicleID, MediumVehicleID, LargeVehicleID 42 | from .veriwild import VeRiWild, SmallVeRiWild, MediumVeRiWild, LargeVeRiWild 43 | 44 | __all__ = [k for k in globals().keys() if 
@DATASET_REGISTRY.register()
class cuhkSYSU(ImageDataset):
    """CUHK SYSU datasets.

    The dataset is collected from two sources: street snap and movie.
    In street snap, 12,490 images and 6,057 query persons were collected
    with movable cameras across hundreds of scenes while 5,694 images and
    2,375 query persons were selected from movies and TV dramas.

    Dataset statistics:
        - identities: 11,934
        - images: 34,574

    Only a training split is produced; query and gallery are empty.
    """
    dataset_dir = 'cuhk_sysu'
    dataset_name = "cuhkSYSU"

    def __init__(self, root='datasets', **kwargs):
        # NOTE(review): the ``root`` argument is ignored in favour of this
        # hard-coded location -- confirm whether that is still intended.
        self.root = '/data/pengyi/datasets/reid_data/'
        self.dataset_dir = osp.join(self.root, self.dataset_dir)

        self.data_dir = osp.join(self.dataset_dir, "cropped_images")

        required_files = [self.data_dir]
        self.check_before_run(required_files)

        train = self.process_dir(self.data_dir)
        query = []
        gallery = []

        super(cuhkSYSU, self).__init__(train, query, gallery, **kwargs)

    def process_dir(self, dirname):
        """Return ``(img_path, pid, camid)`` tuples for every ``*.jpg`` in *dirname*.

        The raw identity is the file-name prefix before the first ``"_"``.
        Raw identities are numbered in sorted order so that the resulting
        labels are identical across runs and across worker processes
        (iterating a set of strings is not order-stable between interpreter
        invocations). Every image gets the same dummy camera id.
        """
        img_paths = glob.glob(osp.join(dirname, '*.jpg'))

        # Raw identity of each image, computed once.
        raw_pids = [osp.basename(img_path).split('_')[0] for img_path in img_paths]
        pid2label = {pid: label for label, pid in enumerate(sorted(set(raw_pids)))}

        camid = self.dataset_name + "_0"  # dummy camera id
        data = []
        for img_path, raw_pid in zip(img_paths, raw_pids):
            label = self.dataset_name + "_" + str(pid2label[raw_pid])
            data.append((img_path, label, camid))

        return data
@DATASET_REGISTRY.register()
class DukeMTMC(ImageDataset):
    """DukeMTMC-reID.

    Reference:
        - Ristani et al. Performance Measures and a Data Set for Multi-Target, Multi-Camera Tracking. ECCVW 2016.
        - Zheng et al. Unlabeled Samples Generated by GAN Improve the Person Re-identification Baseline in vitro. ICCV 2017.

    Dataset statistics:
        - identities: 1404 (train + query).
        - images:16522 (train) + 2228 (query) + 17661 (gallery).
        - cameras: 8.
    """
    dataset_dir = 'DukeMTMC-reID'
    dataset_url = 'http://vision.cs.duke.edu/DukeMTMC/data/misc/DukeMTMC-reID.zip'
    dataset_name = "DukeMTMC"

    def __init__(self, root='datasets', **kwargs):
        # The ``root`` argument is overridden by this fixed location.
        root = '/data/pengyi/datasets/reid_data/dukemtmc/'
        self.root = root
        self.dataset_dir = osp.join(self.root, self.dataset_dir)

        base_dir = self.dataset_dir
        self.train_dir = osp.join(base_dir, 'bounding_box_train')
        self.query_dir = osp.join(base_dir, 'query')
        self.gallery_dir = osp.join(base_dir, 'bounding_box_test')

        self.check_before_run(
            [self.dataset_dir, self.train_dir, self.query_dir, self.gallery_dir]
        )

        train_items = self.process_dir(self.train_dir)
        query_items = self.process_dir(self.query_dir, is_train=False)
        gallery_items = self.process_dir(self.gallery_dir, is_train=False)

        super().__init__(train_items, query_items, gallery_items, **kwargs)

    def process_dir(self, dir_path, is_train=True):
        """Collect ``(img_path, pid, camid)`` tuples for every ``*.jpg`` in *dir_path*.

        Person and camera ids are parsed from the ``<pid>_c<cam>`` part of the
        path. Camera ids become zero-based. For training data both ids are
        turned into strings prefixed with the dataset name; otherwise they
        stay integers.
        """
        id_pattern = re.compile(r'([-\d]+)_c(\d)')

        samples = []
        for path in glob.glob(osp.join(dir_path, '*.jpg')):
            pid_text, cam_text = id_pattern.search(path).groups()
            pid = int(pid_text)
            camid = int(cam_text)
            assert 1 <= camid <= 8
            camid = camid - 1  # index starts from 0
            if is_train:
                pid = self.dataset_name + "_" + str(pid)
                camid = self.dataset_name + "_" + str(camid)
            samples.append((path, pid, camid))

        return samples
@DATASET_REGISTRY.register()
class Market1501(ImageDataset):
    """Market1501.

    Reference:
        Zheng et al. Scalable Person Re-identification: A Benchmark. ICCV 2015.

    Dataset statistics:
        - identities: 1501 (+1 for background).
        - images: 12936 (train) + 3368 (query) + 15913 (gallery).
    """
    _junk_pids = [0, -1]
    dataset_dir = ''
    dataset_url = 'http://188.138.127.15:81/Datasets/Market-1501-v15.09.15.zip'
    dataset_name = "Market1501"

    def __init__(self, root='datasets', market1501_500k=False, **kwargs):
        # The ``root`` argument is overridden by this fixed location.
        root = '/data/pengyi/datasets/reid_data/market1501/Market-1501-v15.09.15'
        self.root = root
        self.dataset_dir = osp.join(self.root, self.dataset_dir)

        # Prefer the nested "Market-1501-v15.09.15" folder when it exists;
        # otherwise fall back to the dataset directory itself and warn.
        nested_dir = osp.join(self.dataset_dir, 'Market-1501-v15.09.15')
        if osp.isdir(nested_dir):
            self.data_dir = nested_dir
        else:
            self.data_dir = self.dataset_dir
            warnings.warn('The current data structure is deprecated. Please '
                          'put data folders such as "bounding_box_train" under '
                          '"Market-1501-v15.09.15".')

        self.train_dir = osp.join(self.data_dir, 'bounding_box_train')
        self.query_dir = osp.join(self.data_dir, 'query')
        self.gallery_dir = osp.join(self.data_dir, 'bounding_box_test')
        self.extra_gallery_dir = osp.join(self.data_dir, 'images')
        self.market1501_500k = market1501_500k

        required_files = [self.data_dir, self.train_dir, self.query_dir, self.gallery_dir]
        if self.market1501_500k:
            required_files.append(self.extra_gallery_dir)
        self.check_before_run(required_files)

        # The three splits are handed to the base class as zero-argument
        # callables rather than as already-built lists.
        def load_gallery():
            items = self.process_dir(self.gallery_dir, is_train=False)
            if self.market1501_500k:
                items = items + self.process_dir(self.extra_gallery_dir, is_train=False)
            return items

        train = lambda: self.process_dir(self.train_dir)
        query = lambda: self.process_dir(self.query_dir, is_train=False)

        super().__init__(train, query, load_gallery, **kwargs)

    def process_dir(self, dir_path, is_train=True):
        """Collect ``(img_path, pid, camid)`` tuples for every ``*.jpg`` in *dir_path*.

        Ids are parsed from the ``<pid>_c<cam>`` part of the path; images with
        pid ``-1`` are skipped. Camera ids become zero-based. For training
        data both ids are turned into strings prefixed with the dataset name.
        """
        id_pattern = re.compile(r'([-\d]+)_c(\d)')

        samples = []
        for path in glob.glob(osp.join(dir_path, '*.jpg')):
            pid_text, cam_text = id_pattern.search(path).groups()
            pid = int(pid_text)
            camid = int(cam_text)
            if pid == -1:
                continue  # junk images are just ignored
            assert 0 <= pid <= 1501  # pid == 0 means background
            assert 1 <= camid <= 6
            camid = camid - 1  # index starts from 0
            if is_train:
                pid = self.dataset_name + "_" + str(pid)
                camid = self.dataset_name + "_" + str(camid)
            samples.append((path, pid, camid))

        return samples
@DATASET_REGISTRY.register()
class MSMT17(ImageDataset):
    """MSMT17.

    Reference:
        Wei et al. Person Transfer GAN to Bridge Domain Gap for Person Re-Identification. CVPR 2018.

    Dataset statistics:
        - identities: 4101.
        - images: 32621 (train) + 11659 (query) + 82161 (gallery).
        - cameras: 15.
    """
    # dataset_dir = 'MSMT17_V2'
    dataset_url = None
    dataset_name = 'MSMT17'

    def __init__(self, root='datasets', **kwargs):
        # The ``root`` argument is overridden by this fixed location.
        root = '/data/pengyi/datasets/reid_data/msmt17v2/'
        self.dataset_dir = root

        # Use the first known version folder that exists on disk. The error
        # is raised explicitly so the check also runs under ``python -O``.
        for main_dir, sub_dirs in VERSION_DICT.items():
            if osp.exists(osp.join(self.dataset_dir, main_dir)):
                train_dir = sub_dirs[TRAIN_DIR_KEY]
                test_dir = sub_dirs[TEST_DIR_KEY]
                break
        else:
            raise AssertionError('Dataset folder not found')

        version_dir = osp.join(self.dataset_dir, main_dir)
        self.train_dir = osp.join(version_dir, train_dir)
        self.test_dir = osp.join(version_dir, test_dir)
        self.list_train_path = osp.join(version_dir, 'list_train.txt')
        self.list_val_path = osp.join(version_dir, 'list_val.txt')
        self.list_query_path = osp.join(version_dir, 'list_query.txt')
        self.list_gallery_path = osp.join(version_dir, 'list_gallery.txt')

        required_files = [
            self.dataset_dir,
            self.train_dir,
            self.test_dir
        ]
        self.check_before_run(required_files)

        train = self.process_dir(self.train_dir, self.list_train_path)
        val = self.process_dir(self.train_dir, self.list_val_path)
        query = self.process_dir(self.test_dir, self.list_query_path, is_train=False)
        gallery = self.process_dir(self.test_dir, self.list_gallery_path, is_train=False)

        # Shift the integer test ids past the number of training ids.
        num_train_pids = self.get_num_pids(train)
        query = [(img_path, pid + num_train_pids, camid)
                 for img_path, pid, camid in query]
        gallery = [(img_path, pid + num_train_pids, camid)
                   for img_path, pid, camid in gallery]

        # Note: to fairly compare with published methods on the conventional ReID setting,
        # do not add val images to the training set.
        # The val split is only merged into train when ``combineall`` is set.
        if kwargs.get('combineall'):
            train += val
        super(MSMT17, self).__init__(train, query, gallery, **kwargs)

    def process_dir(self, dir_path, list_path, is_train=True):
        """Read a ``"<relative image path> <pid>"`` list file.

        Blank lines are skipped and surrounding whitespace is ignored. The
        camera id is the third ``"_"``-separated field of the relative path,
        made zero-based. For training data both ids are turned into strings
        prefixed with the dataset name; otherwise they stay integers.

        Returns:
            list of ``(img_path, pid, camid)`` tuples, with ``img_path``
            joined onto *dir_path*.
        """
        with open(list_path, 'r') as txt:
            lines = txt.readlines()

        data = []

        for img_info in lines:
            if not img_info.strip():
                continue  # tolerate blank lines, e.g. at the end of the file
            img_path, pid = img_info.split()
            pid = int(pid)  # no need to relabel
            camid = int(img_path.split('_')[2]) - 1  # index starts from 0
            img_path = osp.join(dir_path, img_path)
            if is_train:
                pid = self.dataset_name + "_" + str(pid)
                camid = self.dataset_name + "_" + str(camid)
            data.append((img_path, pid, camid))

        return data
@DATASET_REGISTRY.register()
class PeS3D(ImageDataset):
    """3DPeS, exposed as a training-only dataset (no query or gallery)."""
    dataset_dir = "3DPeS"
    dataset_name = "pes3d"

    def __init__(self, root='datasets', **kwargs):
        self.root = root
        self.train_path = os.path.join(self.root, self.dataset_dir)

        self.check_before_run([self.train_path])

        train_items = self.process_train(self.train_path)

        super().__init__(train_items, [], [], **kwargs)

    def process_train(self, train_path):
        """Return ``[img_path, pid, camid]`` records for every ``*.bmp`` image.

        Each entry of *train_path* is treated as one identity directory; the
        pid is the dataset name joined with that entry's name, and every
        image gets the same camera id.
        """
        shared_camid = self.dataset_name + "_cam0"

        records = []
        for pid_dir in os.listdir(train_path):
            pid = self.dataset_name + "_" + pid_dir
            bmp_paths = glob(os.path.join(train_path, pid_dir, "*.bmp"))
            records.extend([img_path, pid, shared_camid] for img_path in bmp_paths)
        return records
@DATASET_REGISTRY.register()
class PRAI(ImageDataset):
    """PRAI-1581, exposed as a training-only dataset (no query or gallery)."""
    dataset_dir = "PRAI-1581"
    dataset_name = 'prai'

    def __init__(self, root='datasets', **kwargs):
        self.root = root
        self.train_path = os.path.join(self.root, self.dataset_dir, 'images')

        required_files = [self.train_path]
        self.check_before_run(required_files)

        train = self.process_train(self.train_path)

        super().__init__(train, [], [], **kwargs)

    def process_train(self, train_path):
        """Return ``[img_path, pid, camid]`` records for every ``*.jpg`` image.

        The first two ``"_"``-separated fields of the file name give the
        person id and the camera id; both are prefixed with the dataset name.
        """
        data = []
        img_paths = glob(os.path.join(train_path, "*.jpg"))
        for img_path in img_paths:
            # basename() honours the platform's path separator, unlike a
            # manual split on '/'.
            img_info = os.path.basename(img_path).split('_')
            pid = self.dataset_name + "_" + img_info[0]
            camid = self.dataset_name + "_" + img_info[1]
            data.append([img_path, pid, camid])
        return data
@DATASET_REGISTRY.register()
class SAIVT(ImageDataset):
    """SAIVT-SoftBio, exposed as a training-only dataset (no query or gallery)."""
    dataset_dir = "SAIVT-SoftBio"
    dataset_name = "saivt"

    def __init__(self, root='datasets', **kwargs):
        self.root = root
        self.train_path = os.path.join(self.root, self.dataset_dir)

        self.check_before_run([self.train_path])

        train_items = self.process_train(self.train_path)

        super().__init__(train_items, [], [], **kwargs)

    def process_train(self, train_path):
        """Return ``[img_path, pid, camid]`` records for every ``*.jpeg`` image.

        Images are looked up under ``cropped_images/<pid dir>/``. The pid is
        the directory name and the camera id is the third ``"-"``-separated
        field of the file name, both prefixed with the dataset name.
        """
        prefix = self.dataset_name + '_'
        crops_root = os.path.join(train_path, "cropped_images")

        records = []
        for pid_name in os.listdir(crops_root):
            pid = prefix + pid_name
            for img_path in glob(os.path.join(crops_root, pid_name, "*.jpeg")):
                cam_field = os.path.basename(img_path).split('-')[2]
                records.append([img_path, pid, prefix + cam_field])
        return records
@DATASET_REGISTRY.register()
class Shinpuhkan(ImageDataset):
    """Shinpuhkan, exposed as a training-only dataset (no query or gallery)."""
    dataset_dir = "shinpuhkan"
    dataset_name = 'shinpuhkan'

    def __init__(self, root='datasets', **kwargs):
        self.root = root
        self.train_path = os.path.join(self.root, self.dataset_dir)

        self.check_before_run([self.train_path])

        train_items = self.process_train(self.train_path)

        super().__init__(train_items, [], [], **kwargs)

    def process_train(self, train_path):
        """Return ``(img_path, pid, camid)`` tuples for every ``.jpg`` below *train_path*.

        The directory tree is walked recursively. The first and third
        ``"_"``-separated fields of the file name give the person id and the
        camera id; both are prefixed with the dataset name.
        """
        prefix = self.dataset_name + "_"

        records = []
        for folder, _subdirs, filenames in os.walk(train_path):
            for img_name in filenames:
                if not img_name.endswith(".jpg"):
                    continue
                name_fields = img_name.split('_')
                records.append((
                    os.path.join(folder, img_name),
                    prefix + name_fields[0],
                    prefix + name_fields[2],
                ))

        return records
# --------------------------------------------------------------------------------
# /fastreid/data/datasets/thermalworld.py:
# --------------------------------------------------------------------------------
# encoding: utf-8
"""
@author:  xingyu liao
@contact: sherlockliao01@gmail.com
"""

import os
from glob import glob

from fastreid.data.datasets import DATASET_REGISTRY
from fastreid.data.datasets.bases import ImageDataset

__all__ = ['Thermalworld', ]


@DATASET_REGISTRY.register()
class Thermalworld(ImageDataset):
    """thermal world

    Train-only dataset read from ``<root>/thermalworld_rgb``.
    """
    dataset_dir = "thermalworld_rgb"
    dataset_name = "thermalworld"

    def __init__(self, root='datasets', **kwargs):
        """
        Args:
            root (str): directory that contains the ``thermalworld_rgb`` folder.
            **kwargs: forwarded unchanged to ``ImageDataset``.
        """
        self.root = root
        self.train_path = os.path.join(self.root, self.dataset_dir)

        required_files = [self.train_path]
        self.check_before_run(required_files)

        train = self.process_train(self.train_path)

        super().__init__(train, [], [], **kwargs)

    def process_train(self, train_path):
        """Return ``[img_path, pid, camid]`` entries, one identity per sub-folder.

        Every image gets the same camera id ``<dataset_name>_cam0``.
        """
        data = []
        # The camera id is constant for the whole dataset.
        camid = self.dataset_name + "_cam0"
        for pid_dir in os.listdir(train_path):
            pid = self.dataset_name + "_" + pid_dir
            for img_path in glob(os.path.join(train_path, pid_dir, "*.jpg")):
                data.append([img_path, pid, camid])
        return data

# --------------------------------------------------------------------------------
# /fastreid/data/datasets/vehicleid.py:
# --------------------------------------------------------------------------------
# encoding: utf-8
"""
@author:  Jinkai Zheng
@contact: 1315673509@qq.com
"""

import os.path as osp
import random

from .bases import ImageDataset
from ..datasets import DATASET_REGISTRY


@DATASET_REGISTRY.register()
class VehicleID(ImageDataset):
    """VehicleID.

    Reference:
        Liu et al. Deep relative distance learning: Tell the difference between similar vehicles. CVPR 2016.

    URL: ``_

    Train dataset statistics:
        - identities: 13164.
        - images: 113346.
    """
    dataset_dir = "vehicleid"
    dataset_name = "vehicleid"

    def __init__(self, root='datasets', test_list='', **kwargs):
        """
        Args:
            root (str): directory that contains the ``vehicleid`` folder.
            test_list (str): path of the test split list; when empty,
                ``train_test_split/test_list_13164.txt`` is used.
            **kwargs: forwarded unchanged to ``ImageDataset``.
        """
        self.dataset_dir = osp.join(root, self.dataset_dir)

        self.image_dir = osp.join(self.dataset_dir, 'image')
        self.train_list = osp.join(self.dataset_dir, 'train_test_split/train_list.txt')
        if test_list:
            self.test_list = test_list
        else:
            self.test_list = osp.join(self.dataset_dir, 'train_test_split/test_list_13164.txt')

        required_files = [
            self.dataset_dir,
            self.image_dir,
            self.train_list,
            self.test_list,
        ]
        self.check_before_run(required_files)

        train = self.process_dir(self.train_list, is_train=True)
        query, gallery = self.process_dir(self.test_list, is_train=False)

        super(VehicleID, self).__init__(train, query, gallery, **kwargs)

    def process_dir(self, list_file, is_train=True):
        """Parse a ``"<image id> <vehicle id>"`` list file.

        Args:
            list_file (str): path of the split file, one sample per line.
            is_train (bool): when True, ids are prefixed with ``dataset_name``
                and the full list is returned.

        Returns:
            list of ``(img_path, vid, imgid)`` when ``is_train`` is True,
            otherwise a ``(query, gallery)`` pair: after a random shuffle the
            first sample seen for each vehicle id goes to the gallery and all
            remaining samples go to the query set.
        """
        dataset = []
        # `with` guarantees the list file is closed (the handle used to leak).
        with open(list_file, 'r') as f:
            for line in f:
                line = line.strip()
                if not line:
                    # tolerate blank / trailing empty lines
                    continue
                fields = line.split(' ')
                imgid = fields[0]
                vid = int(fields[1])
                # The path is built from the raw (string) image id before it
                # is converted to an integer.
                img_path = osp.join(self.image_dir, f"{imgid}.jpg")
                imgid = int(imgid)
                if is_train:
                    vid = f"{self.dataset_name}_{vid}"
                    imgid = f"{self.dataset_name}_{imgid}"
                dataset.append((img_path, vid, imgid))

        if is_train:
            return dataset

        random.shuffle(dataset)
        vid_container = set()
        query = []
        gallery = []
        for sample in dataset:
            if sample[1] not in vid_container:
                vid_container.add(sample[1])
                gallery.append(sample)
            else:
                query.append(sample)

        return query, gallery


@DATASET_REGISTRY.register()
class SmallVehicleID(VehicleID):
    """VehicleID.
    Small test dataset statistics:
        - identities: 800.
        - images: 6493.
    """

    def __init__(self, root='datasets', **kwargs):
        dataset_dir = osp.join(root, self.dataset_dir)
        self.test_list = osp.join(dataset_dir, 'train_test_split/test_list_800.txt')

        super(SmallVehicleID, self).__init__(root, self.test_list, **kwargs)


@DATASET_REGISTRY.register()
class MediumVehicleID(VehicleID):
    """VehicleID.
    Medium test dataset statistics:
        - identities: 1600.
        - images: 13377.
    """

    def __init__(self, root='datasets', **kwargs):
        dataset_dir = osp.join(root, self.dataset_dir)
        self.test_list = osp.join(dataset_dir, 'train_test_split/test_list_1600.txt')

        super(MediumVehicleID, self).__init__(root, self.test_list, **kwargs)


@DATASET_REGISTRY.register()
class LargeVehicleID(VehicleID):
    """VehicleID.
    Large test dataset statistics:
        - identities: 2400.
        - images: 19777.
    """

    def __init__(self, root='datasets', **kwargs):
        dataset_dir = osp.join(root, self.dataset_dir)
        self.test_list = osp.join(dataset_dir, 'train_test_split/test_list_2400.txt')

        super(LargeVehicleID, self).__init__(root, self.test_list, **kwargs)

# --------------------------------------------------------------------------------
# /fastreid/data/datasets/veri.py:
# --------------------------------------------------------------------------------
# encoding: utf-8
"""
@author:  Jinkai Zheng
@contact: 1315673509@qq.com
"""

import glob
import os.path as osp
import re

from .bases import ImageDataset
from ..datasets import DATASET_REGISTRY


@DATASET_REGISTRY.register()
class VeRi(ImageDataset):
    """VeRi.

    Reference:
        Xinchen Liu et al. A Deep Learning based Approach for Progressive Vehicle Re-Identification. ECCV 2016.
        Xinchen Liu et al. PROVID: Progressive and Multimodal Vehicle Reidentification for Large-Scale Urban Surveillance. IEEE TMM 2018.

    URL: ``_

    Dataset statistics:
        - identities: 775.
        - images: 37778 (train) + 1678 (query) + 11579 (gallery).
    """
    dataset_dir = "veri"
    dataset_name = "veri"

    def __init__(self, root='datasets', **kwargs):
        """
        Args:
            root (str): directory that contains the ``veri`` folder.
            **kwargs: forwarded unchanged to ``ImageDataset``.
        """
        self.dataset_dir = osp.join(root, self.dataset_dir)

        self.train_dir = osp.join(self.dataset_dir, 'image_train')
        self.query_dir = osp.join(self.dataset_dir, 'image_query')
        self.gallery_dir = osp.join(self.dataset_dir, 'image_test')

        required_files = [
            self.dataset_dir,
            self.train_dir,
            self.query_dir,
            self.gallery_dir,
        ]
        self.check_before_run(required_files)

        train = self.process_dir(self.train_dir)
        query = self.process_dir(self.query_dir, is_train=False)
        gallery = self.process_dir(self.gallery_dir, is_train=False)

        super(VeRi, self).__init__(train, query, gallery, **kwargs)

    def process_dir(self, dir_path, is_train=True):
        """Return ``(img_path, pid, camid)`` tuples for the ``*.jpg`` files in ``dir_path``.

        Ids are parsed from the ``<pid>_c<3-digit cam>`` part of the file
        name. Camera ids are shifted to start at 0. For training data both
        ids are converted to strings prefixed with ``dataset_name``.

        Raises:
            ValueError: if a file name does not contain the id pattern.
        """
        img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
        pattern = re.compile(r'([\d]+)_c(\d\d\d)')

        data = []
        for img_path in img_paths:
            # Match against the file name only: a parent directory whose name
            # happens to look like "<digits>_cNNN" must not be picked up.
            match = pattern.search(osp.basename(img_path))
            if match is None:
                raise ValueError(f"Cannot parse pid/camid from file name: {img_path}")
            pid, camid = map(int, match.groups())
            # Kept from the original; the pattern only captures digits, so
            # this branch cannot trigger with the current regex.
            if pid == -1:
                continue  # junk images are just ignored
            assert 0 <= pid <= 776
            assert 1 <= camid <= 20
            camid -= 1  # index starts from 0
            if is_train:
                pid = self.dataset_name + "_" + str(pid)
                camid = self.dataset_name + "_" + str(camid)
            data.append((img_path, pid, camid))

        return data

# --------------------------------------------------------------------------------
# /fastreid/data/datasets/wildtracker.py:
# --------------------------------------------------------------------------------
# encoding: utf-8
"""
@author:  wangguanan
@contact: guan.wang0706@gmail.com
"""

import glob
import os

from .bases import ImageDataset
from ..datasets import DATASET_REGISTRY


@DATASET_REGISTRY.register()
class WildTrackCrop(ImageDataset):
    """WildTrack.
    Reference:
        WILDTRACK: A Multi-camera HD Dataset for Dense Unscripted Pedestrian Detection
            T. Chavdarova; P. Baqué; A. Maksai; S. Bouquet; C. Jose et al.
    URL: ``_
    Dataset statistics:
        - identities: 313
        - images: 33979 (train only)
        - cameras: 7
    Args:
        data_path(str): path to WildTrackCrop dataset
        combineall(bool): combine train and test sets as train set if True
    """
    dataset_url = None
    dataset_dir = 'Wildtrack_crop_dataset'
    dataset_name = 'wildtrack'

    def __init__(self, root='datasets', **kwargs):
        self.root = root
        self.dataset_dir = os.path.join(self.root, self.dataset_dir)

        self.train_dir = os.path.join(self.dataset_dir, "crop")

        train = self.process_dir(self.train_dir)
        query = []
        gallery = []

        super(WildTrackCrop, self).__init__(train, query, gallery, **kwargs)

    def process_dir(self, dir_path):
        r"""
        :param dir_path: directory path saving images
        Returns
            data(list) = [img_path, pid, camid]

        Each sub-directory name is used as person id; the part of the file
        name before the first ``_`` is used as camera id. Both are prefixed
        with ``dataset_name``.
        """
        data = []
        for dir_name in os.listdir(dir_path):
            pid = self.dataset_name + "_" + dir_name
            img_lists = glob.glob(os.path.join(dir_path, dir_name, "*.png"))
            for img_path in img_lists:
                # basename() instead of split('/') so that paths using the
                # platform separator are parsed correctly.
                camid = os.path.basename(img_path).split('_')[0]
                camid = self.dataset_name + "_" + camid
                data.append([img_path, pid, camid])
        return data

# --------------------------------------------------------------------------------
# /fastreid/data/samplers/__init__.py:
# --------------------------------------------------------------------------------
# encoding: utf-8
"""
@author:  liaoxingyu
@contact: sherlockliao01@gmail.com
"""

from .triplet_sampler import BalancedIdentitySampler, NaiveIdentitySampler, SetReWeightSampler
from .data_sampler import TrainingSampler, InferenceSampler
from .imbalance_sampler import ImbalancedDatasetSampler

__all__ = [
    "BalancedIdentitySampler",
    "NaiveIdentitySampler",
    "SetReWeightSampler",
    "TrainingSampler",
    "InferenceSampler",
    "ImbalancedDatasetSampler",
]

# --------------------------------------------------------------------------------
# /fastreid/data/samplers/data_sampler.py:
# --------------------------------------------------------------------------------
# encoding: utf-8
"""
@author:  l1aoxingyu
@contact: sherlockliao01@gmail.com
"""
import itertools
from typing import Optional

import numpy as np
from torch.utils.data import Sampler

from fastreid.utils import comm


class TrainingSampler(Sampler):
    """
    In training, we only care about the "infinite stream" of training data.
    So this sampler produces an infinite stream of indices and
    all workers cooperate to correctly shuffle the indices and sample different indices.
    The samplers in each worker effectively produces `indices[worker_id::num_workers]`
    where `indices` is an infinite stream of indices consisting of
    `shuffle(range(size)) + shuffle(range(size)) + ...` (if shuffle is True)
    or `range(size) + range(size) + ...` (if shuffle is False)
    """

    def __init__(self, size: int, shuffle: bool = True, seed: Optional[int] = None):
        """
        Args:
            size (int): the total number of data of the underlying dataset to sample from
            shuffle (bool): whether to shuffle the indices or not
            seed (int): the initial seed of the shuffle. Must be the same
                across all workers. If None, will use a random seed shared
                among workers (require synchronization among all workers).
        """
        self._size = size
        assert size > 0
        self._shuffle = shuffle
        if seed is None:
            seed = comm.shared_random_seed()
        self._seed = int(seed)

        self._rank = comm.get_rank()
        self._world_size = comm.get_world_size()

    def __iter__(self):
        # Every rank walks the same stream and keeps every world_size-th
        # element, starting at its own rank.
        start = self._rank
        yield from itertools.islice(self._infinite_indices(), start, None, self._world_size)

    def _infinite_indices(self):
        """Yield the endless index stream described in the class docstring."""
        # A private RandomState keeps the stream identical on every worker
        # even when other code draws from (or reseeds) the global NumPy RNG,
        # and avoids clobbering that global state. For a given seed it yields
        # the same permutations as seeding the global RNG did.
        rng = np.random.RandomState(self._seed)
        while True:
            if self._shuffle:
                yield from rng.permutation(self._size)
            else:
                yield from np.arange(self._size)


class InferenceSampler(Sampler):
    """
    Produce indices for inference.
    Inference needs to run on the __exact__ set of samples,
    therefore when the total number of samples is not divisible by the number of workers,
    this sampler produces different number of samples on different workers.
    """

    def __init__(self, size: int):
        """
        Args:
            size (int): the total number of data of the underlying dataset to sample from
        """
        self._size = size
        assert size > 0
        self._rank = comm.get_rank()
        self._world_size = comm.get_world_size()

        # ceil(size / world_size): contiguous shard per rank, the last
        # shard(s) may be shorter or empty.
        shard_size = (self._size - 1) // self._world_size + 1
        begin = shard_size * self._rank
        end = min(shard_size * (self._rank + 1), self._size)
        self._local_indices = range(begin, end)

    def __iter__(self):
        yield from self._local_indices

    def __len__(self):
        return len(self._local_indices)

# --------------------------------------------------------------------------------
# /fastreid/data/samplers/imbalance_sampler.py:
# --------------------------------------------------------------------------------
# encoding: utf-8
"""
@author:  xingyu liao
@contact: sherlockliao01@gmail.com
"""

# based on:
# https://github.com/ufoym/imbalanced-dataset-sampler/blob/master/torchsampler/imbalanced.py


import itertools
from typing import Optional, List, Callable

import numpy as np
import torch
from torch.utils.data.sampler import Sampler

from fastreid.utils import comm


class ImbalancedDatasetSampler(Sampler):
    """Samples elements randomly from a given list of indices for imbalanced dataset
    Arguments:
        data_source: a list of data items
        size: number of samples to draw
        seed: seed of the sampling stream; when None a seed shared among
            workers is requested from ``comm.shared_random_seed()``
        callback_get_label: optional ``f(dataset, idx) -> label``; when not
            given, ``dataset[idx][1]`` is used as the label
    """

    def __init__(self, data_source: List, size: Optional[int] = None, seed: Optional[int] = None,
                 callback_get_label: Optional[Callable] = None):
        self.data_source = data_source
        # consider all elements in the dataset
        self.indices = list(range(len(data_source)))
        # if num_samples is not provided, draw `len(indices)` samples in each iteration
        self._size = len(self.indices) if size is None else size
        self.callback_get_label = callback_get_label

        # Look every label up once and reuse it for counting and weighting.
        labels = [self._get_label(data_source, idx) for idx in self.indices]

        # distribution of classes in the dataset
        label_to_count = {}
        for label in labels:
            label_to_count[label] = label_to_count.get(label, 0) + 1

        # weight for each sample: inverse frequency of its label
        weights = [1.0 / label_to_count[label] for label in labels]
        self.weights = torch.DoubleTensor(weights)

        if seed is None:
            seed = comm.shared_random_seed()
        self._seed = int(seed)
        self._rank = comm.get_rank()
        self._world_size = comm.get_world_size()

    def _get_label(self, dataset, idx):
        """Return the label of item ``idx`` (via the callback when one was given)."""
        if self.callback_get_label:
            return self.callback_get_label(dataset, idx)
        else:
            return dataset[idx][1]

    def __iter__(self):
        # Every rank walks the same stream and keeps every world_size-th
        # element, starting at its own rank.
        start = self._rank
        yield from itertools.islice(self._infinite_indices(), start, None, self._world_size)

    def _infinite_indices(self):
        """Yield an endless stream of indices drawn with replacement by weight."""
        # torch.multinomial does not use NumPy's RNG, so seeding NumPy (as was
        # done before) had no effect on the draws. A dedicated, seeded torch
        # generator makes the stream reproducible and identical on all ranks,
        # which the strided slicing in __iter__ relies on.
        generator = torch.Generator()
        generator.manual_seed(self._seed)
        while True:
            drawn = torch.multinomial(self.weights, self._size, replacement=True,
                                      generator=generator)
            for i in drawn.tolist():
                yield self.indices[i]
# --------------------------------------------------------------------------------
# /fastreid/data/transforms/__init__.py:
# --------------------------------------------------------------------------------
# encoding: utf-8
"""
@author:  sherlock
@contact: sherlockliao01@gmail.com
"""

from .autoaugment import AutoAugment
from .build import build_transforms
from .transforms import *

__all__ = [k for k in globals().keys() if not k.startswith("_")]

# --------------------------------------------------------------------------------
# /fastreid/data/transforms/build.py:
# --------------------------------------------------------------------------------
# encoding: utf-8
"""
@author:  liaoxingyu
@contact: sherlockliao01@gmail.com
"""

import copy
import torchvision.transforms as T

from .transforms import *
from .autoaugment import AutoAugment


def build_transforms(cfg, is_train=True):
    """Build the image transform pipelines described by ``cfg.INPUT``.

    Args:
        cfg: config node; the ``INPUT.*`` entries read below select and
            parameterize the individual transforms.
        is_train (bool): build the training pipelines when True, otherwise
            the test pipelines.

    Returns:
        list: two ``T.Compose`` objects.
        In training mode both share the geometric steps (auto-augment, resize,
        crop, pad, flip), the tensor conversion and random erasing. The first
        additionally applies color jitter only if ``CJ.ENABLED``, plus the
        optional affine, AugMix and random-patch steps. The second always
        applies color jitter and none of those extra steps.
        In test mode both entries are built from the same transform list.
    """
    res = []

    if is_train:
        size_train = cfg.INPUT.SIZE_TRAIN

        # crop
        do_crop = cfg.INPUT.CROP.ENABLED
        crop_size = cfg.INPUT.CROP.SIZE
        crop_scale = cfg.INPUT.CROP.SCALE
        crop_ratio = cfg.INPUT.CROP.RATIO

        # augmix augmentation
        do_augmix = cfg.INPUT.AUGMIX.ENABLED
        augmix_prob = cfg.INPUT.AUGMIX.PROB

        # auto augmentation
        do_autoaug = cfg.INPUT.AUTOAUG.ENABLED
        autoaug_prob = cfg.INPUT.AUTOAUG.PROB

        # horizontal flip
        do_flip = cfg.INPUT.FLIP.ENABLED
        flip_prob = cfg.INPUT.FLIP.PROB

        # padding
        do_pad = cfg.INPUT.PADDING.ENABLED
        padding_size = cfg.INPUT.PADDING.SIZE
        padding_mode = cfg.INPUT.PADDING.MODE

        # color jitter
        do_cj = cfg.INPUT.CJ.ENABLED
        cj_prob = cfg.INPUT.CJ.PROB
        cj_brightness = cfg.INPUT.CJ.BRIGHTNESS
        cj_contrast = cfg.INPUT.CJ.CONTRAST
        cj_saturation = cfg.INPUT.CJ.SATURATION
        cj_hue = cfg.INPUT.CJ.HUE

        # random affine
        do_affine = cfg.INPUT.AFFINE.ENABLED

        # random erasing
        do_rea = cfg.INPUT.REA.ENABLED
        rea_prob = cfg.INPUT.REA.PROB
        rea_value = cfg.INPUT.REA.VALUE

        # random patch
        do_rpt = cfg.INPUT.RPT.ENABLED
        rpt_prob = cfg.INPUT.RPT.PROB

        if do_autoaug:
            res.append(T.RandomApply([AutoAugment()], p=autoaug_prob))

        # A one-element size is passed as a scalar, otherwise as given.
        if size_train[0] > 0:
            res.append(T.Resize(size_train[0] if len(size_train) == 1 else size_train, interpolation=3))

        if do_crop:
            res.append(T.RandomResizedCrop(size=crop_size[0] if len(crop_size) == 1 else crop_size,
                                           interpolation=3,
                                           scale=crop_scale, ratio=crop_ratio))
        if do_pad:
            res.extend([T.Pad(padding_size, padding_mode=padding_mode),
                        T.RandomCrop(size_train[0] if len(size_train) == 1 else size_train)])
        if do_flip:
            res.append(T.RandomHorizontalFlip(p=flip_prob))

        # Second pipeline: branches off after the geometric steps.
        # NOTE(review): color jitter is appended here regardless of
        # CJ.ENABLED - confirm this is intended.
        res1 = copy.deepcopy(res)
        res1.append(T.RandomApply([T.ColorJitter(cj_brightness, cj_contrast, cj_saturation, cj_hue)], p=cj_prob))
        if do_cj:
            res.append(T.RandomApply([T.ColorJitter(cj_brightness, cj_contrast, cj_saturation, cj_hue)], p=cj_prob))
        if do_affine:
            res.append(T.RandomAffine(degrees=10, translate=None, scale=[0.9, 1.1], shear=0.1, resample=False,
                                      fillcolor=0))
        if do_augmix:
            res.append(AugMix(prob=augmix_prob))
        res1.append(ToTensor())
        res.append(ToTensor())
        if do_rea:
            res1.append(T.RandomErasing(p=rea_prob, value=rea_value))
            res.append(T.RandomErasing(p=rea_prob, value=rea_value))
        if do_rpt:
            res.append(RandomPatch(prob_happen=rpt_prob))

        return [T.Compose(res), T.Compose(res1)]
    else:
        size_test = cfg.INPUT.SIZE_TEST
        do_crop = cfg.INPUT.CROP.ENABLED
        crop_size = cfg.INPUT.CROP.SIZE

        if size_test[0] > 0:
            res.append(T.Resize(size_test[0] if len(size_test) == 1 else size_test, interpolation=3))
        if do_crop:
            res.append(T.CenterCrop(size=crop_size[0] if len(crop_size) == 1 else crop_size))
        res.append(ToTensor())

        return [T.Compose(res), T.Compose(res)]

# --------------------------------------------------------------------------------
# /fastreid/engine/__init__.py:
# --------------------------------------------------------------------------------
# encoding: utf-8
"""
@author:  liaoxingyu
@contact: sherlockliao01@gmail.com
"""
from .train_loop import *

__all__ = [k for k in globals().keys() if not k.startswith("_")]


# prefer to let hooks and defaults live in separate namespaces (therefore not in __all__)
# but still make them available here
from .hooks import *
from .defaults import *
from .launch import *

# --------------------------------------------------------------------------------
# /fastreid/engine/launch.py:
# --------------------------------------------------------------------------------
# encoding: utf-8
"""
@author:  xingyu liao
@contact: sherlockliao01@gmail.com
"""

# based on:
# https://github.com/facebookresearch/detectron2/blob/master/detectron2/engine/launch.py


import logging

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from fastreid.utils import comm

__all__ = ["launch"]


def _find_free_port():
    """Ask the OS for a currently unused TCP port and return its number."""
    import socket

    # The context manager closes the socket even if bind() fails.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        # Binding to port 0 will cause the OS to find an available port for us
        sock.bind(("", 0))
        port = sock.getsockname()[1]
    # NOTE: there is still a chance the port could be taken by other processes.
    return port


def launch(main_func, num_gpus_per_machine, num_machines=1, machine_rank=0, dist_url=None, args=()):
    """
    Launch multi-gpu or distributed training.
    This function must be called on all machines involved in the training.

    NOTE: the multi-process path (spawning ``num_gpus_per_machine`` child
    processes per machine) is currently disabled - see the commented-out code
    below. As written, this function only calls ``main_func(*args)`` in the
    current process and ignores every other argument.

    Args:
        main_func: a function that will be called by `main_func(*args)`
        num_gpus_per_machine (int): number of GPUs per machine
        num_machines (int): the total number of machines
        machine_rank (int): the rank of this machine
        dist_url (str): url to connect to for distributed jobs, including protocol
                       e.g. "tcp://127.0.0.1:8686".
                       Can be set to "auto" to automatically select a free port on localhost
        args (tuple): arguments passed to main_func
    """
    main_func(*args)
    # world_size = num_machines * num_gpus_per_machine
    # if world_size > 1:
    #     # https://github.com/pytorch/pytorch/pull/14391
    #     # TODO prctl in spawned processes

    #     if dist_url == "auto":
    #         assert num_machines == 1, "dist_url=auto not supported in multi-machine jobs."
    #         port = _find_free_port()
    #         dist_url = f"tcp://127.0.0.1:{port}"
    #     if num_machines > 1 and dist_url.startswith("file://"):
    #         logger = logging.getLogger(__name__)
    #         logger.warning(
    #             "file:// is not a reliable init_method in multi-machine jobs. Prefer tcp://"
    #         )

    #     mp.spawn(
    #         _distributed_worker,
    #         nprocs=num_gpus_per_machine,
    #         args=(main_func, world_size, num_gpus_per_machine, machine_rank, dist_url, args),
    #         daemon=False,
    #     )
    # else:
    #     main_func(*args)


def _distributed_worker(
        local_rank, main_func, world_size, num_gpus_per_machine, machine_rank, dist_url, args
):
    """Per-process entry point: join the process group, pick the GPU, run ``main_func``.

    Args:
        local_rank (int): index of this process on its machine; also used as
            the CUDA device index.
        main_func: function called as ``main_func(*args)`` once setup is done.
        world_size (int): total number of processes.
        num_gpus_per_machine (int): number of processes per machine.
        machine_rank (int): rank of this machine.
        dist_url (str): ``init_method`` passed to ``init_process_group``.
        args (tuple): arguments for ``main_func``.
    """
    assert torch.cuda.is_available(), "cuda is not available. Please check your installation."
    global_rank = machine_rank * num_gpus_per_machine + local_rank
    try:
        dist.init_process_group(
            backend="NCCL", init_method=dist_url, world_size=world_size, rank=global_rank
        )
    except Exception:
        logger = logging.getLogger(__name__)
        logger.error("Process group URL: %s", dist_url)
        # bare raise keeps the original traceback
        raise
    # synchronize is needed here to prevent a possible timeout after calling init_process_group
    # See: https://github.com/facebookresearch/maskrcnn-benchmark/issues/172
    comm.synchronize()

    assert num_gpus_per_machine <= torch.cuda.device_count()
    torch.cuda.set_device(local_rank)

    # Setup the local process group (which contains ranks within the same machine)
    assert comm._LOCAL_PROCESS_GROUP is None
    num_machines = world_size // num_gpus_per_machine
    # new_group is called for every machine's rank list on every process;
    # only the group of this machine is kept.
    for i in range(num_machines):
        ranks_on_i = list(range(i * num_gpus_per_machine, (i + 1) * num_gpus_per_machine))
        pg = dist.new_group(ranks_on_i)
        if i == machine_rank:
            comm._LOCAL_PROCESS_GROUP = pg

    main_func(*args)

# --------------------------------------------------------------------------------
# /fastreid/evaluation/__init__.py:
# --------------------------------------------------------------------------------
from .evaluator import DatasetEvaluator, inference_context, inference_on_dataset
from .rank import evaluate_rank
from .reid_evaluation import ReidEvaluator
from .clas_evaluator import ClasEvaluator
from .roc import evaluate_roc
from .testing import print_csv_format, verify_results

__all__ = [k for k in globals().keys() if not k.startswith("_")]

# --------------------------------------------------------------------------------
# /fastreid/evaluation/clas_evaluator.py:
# --------------------------------------------------------------------------------
# encoding: utf-8
"""
@author:  xingyu liao
@contact: sherlockliao01@gmail.com
"""

import copy
import itertools
import logging
from collections import OrderedDict

import torch

from fastreid.utils import comm
from .evaluator import DatasetEvaluator

logger = logging.getLogger(__name__)


def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k

    Returns a list with one single-element tensor per entry of ``topk``,
    each holding the accuracy in percent over the batch.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        # indices of the maxk largest scores per row, transposed so that
        # row k holds every sample's (k+1)-th best prediction
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res


class ClasEvaluator(DatasetEvaluator):
    """Accumulates top-1 accuracy of classification logits over a dataset."""

    def __init__(self, cfg, output_dir=None):
        self.cfg = cfg
        self._output_dir = output_dir
        self._cpu_device = torch.device('cpu')

        self._predictions = []

    def reset(self):
        """Drop all accumulated per-batch statistics."""
        self._predictions = []

    def process(self, inputs, outputs):
        """Record the number of correct top-1 predictions of one batch.

        ``outputs`` are the logits; labels are read from ``inputs["targets"]``.
        """
        pred_logits = outputs.to(self._cpu_device, torch.float32)
        labels = inputs["targets"].to(self._cpu_device)

        # measure accuracy
        acc1, = accuracy(pred_logits, labels, topk=(1,))
        # convert the percentage back to a count so batches can be summed
        num_correct_acc1 = acc1 * labels.size(0) / 100

        self._predictions.append({"num_correct": num_correct_acc1, "num_samples": labels.size(0)})

    def evaluate(self):
        """Return ``{"Acc@1": ..., "metric": ...}`` over everything processed so far.

        With more than one worker the per-batch records are gathered on
        rank 0 and all other ranks return an empty dict.
        """
        if comm.get_world_size() > 1:
            comm.synchronize()
            predictions = comm.gather(self._predictions, dst=0)
            predictions = list(itertools.chain(*predictions))

            if not comm.is_main_process(): return {}

        else:
            predictions = self._predictions

        total_correct_num = 0
        total_samples = 0
        for prediction in predictions:
            total_correct_num += prediction["num_correct"]
            total_samples += prediction["num_samples"]

        acc1 = total_correct_num / total_samples * 100

        self._results = OrderedDict()
        self._results["Acc@1"] = acc1
        self._results["metric"] = acc1

        return copy.deepcopy(self._results)

# --------------------------------------------------------------------------------
# /fastreid/evaluation/query_expansion.py:
# --------------------------------------------------------------------------------
# encoding: utf-8
"""
@author:  xingyu liao
@contact: sherlockliao01@gmail.com
"""

# based on
# https://github.com/PyRetri/PyRetri/blob/master/pyretri/index/re_ranker/re_ranker_impl/query_expansion.py

import numpy as np
import torch
import torch.nn.functional as F


def aqe(query_feat: torch.Tensor, gallery_feat: torch.Tensor,
        qe_times: int = 1, qe_k: int = 10, alpha: float = 3.0):
    """
    Combining the retrieved topk nearest neighbors with the original query and doing another retrieval.
    c.f. https://www.robots.ox.ac.uk/~vgg/publications/papers/chum07b.pdf
    Args :
        query_feat (torch.Tensor): query features, one row per sample.
        gallery_feat (torch.Tensor): gallery features, one row per sample.
        qe_times (int): number of query expansion times.
        qe_k (int): number of the neighbors to be combined. Values larger
            than the total number of features use all of them.
        alpha (float): exponent applied to the cosine similarities that are
            used as neighbor weights.
    Returns:
        (query_feat, gallery_feat): expanded features as CPU tensors.
    """
    num_query = query_feat.shape[0]
    # Tensor.numpy() only works on CPU tensors that do not require grad, so
    # move/detach once up front (a no-op for plain CPU inputs).
    all_feat = torch.cat((query_feat, gallery_feat), dim=0).detach().cpu()
    norm_feat = F.normalize(all_feat, p=2, dim=1)

    all_feat = all_feat.numpy()
    num_feat = all_feat.shape[0]
    for _ in range(qe_times):
        all_feat_list = []
        # cosine similarity between every pair of features
        sims = torch.mm(norm_feat, norm_feat.t())
        sims = sims.data.cpu().numpy()
        for sim in sims:
            if qe_k < num_feat:
                # puts the qe_k most similar entries (in order) at the front
                init_rank = np.argpartition(-sim, range(1, qe_k + 1))
            else:
                # argpartition needs kth < len(sim); fall back to a full sort
                init_rank = np.argsort(-sim)
            weights = sim[init_rank[:qe_k]].reshape((-1, 1))
            weights = np.power(weights, alpha)
            all_feat_list.append(np.mean(all_feat[init_rank[:qe_k], :] * weights, axis=0))
        all_feat = np.stack(all_feat_list, axis=0)
        norm_feat = F.normalize(torch.from_numpy(all_feat), p=2, dim=1)

    query_feat = torch.from_numpy(all_feat[:num_query])
    gallery_feat = torch.from_numpy(all_feat[num_query:])
    return query_feat, gallery_feat

# --------------------------------------------------------------------------------
# /fastreid/evaluation/rank_cylib/Makefile:
# --------------------------------------------------------------------------------
# (Makefile content, reproduced as comment lines)
# all:
# 	python3 setup.py build_ext --inplace
# 	rm -rf build
# 	python3 test_cython.py
# clean:
# 	rm -rf build
# 	rm -f rank_cy.c *.so

# --------------------------------------------------------------------------------
# /fastreid/evaluation/rank_cylib/__init__.py:
# --------------------------------------------------------------------------------
# encoding: utf-8
"""
@author:  liaoxingyu
@contact: sherlockliao01@gmail.com
"""

# --------------------------------------------------------------------------------
# /fastreid/evaluation/rank_cylib/roc_cy.pyx:
# --------------------------------------------------------------------------------
# cython: boundscheck=False, wraparound=False, nonecheck=False, cdivision=True
# credits: https://github.com/KaiyangZhou/deep-person-reid/blob/master/torchreid/metrics/rank_cylib/rank_cy.pyx

import cython
import faiss
# Main interface
cpdef evaluate_roc_cy(float[:,:] distmat, long[:] q_pids, long[:]g_pids,
                      long[:]q_camids, long[:]g_camids):
    """Collect the distances of matching / non-matching query-gallery pairs.

    For every query, gallery entries that share BOTH the person id and the
    camera id with the query are dropped; every remaining pair contributes
    its distance to either the positive pool (same pid) or the negative pool.

    Args:
        distmat: (num_q, num_g) distance matrix.
        q_pids, g_pids: person ids of the query / gallery samples.
        q_camids, g_camids: camera ids of the query / gallery samples.

    Returns:
        (scores, labels) as numpy arrays: all positive distances followed by
        all negative distances, with label 0 for positives and 1 for negatives.
    """

    # Re-wrap the inputs as numpy arrays of the exact dtypes the typed
    # memoryviews below expect.
    distmat = np.asarray(distmat, dtype=np.float32)
    q_pids = np.asarray(q_pids, dtype=np.int64)
    g_pids = np.asarray(g_pids, dtype=np.int64)
    q_camids = np.asarray(q_camids, dtype=np.int64)
    g_camids = np.asarray(g_camids, dtype=np.int64)

    cdef long num_q = distmat.shape[0]
    cdef long num_g = distmat.shape[1]

    cdef:
        # gallery indices sorted by ascending distance, one row per query
        long[:,:] indices = np.argsort(distmat, axis=1)
        # matches[q, r] == 1 when the r-th ranked gallery sample has the query's pid
        long[:,:] matches = (np.asarray(g_pids)[np.asarray(indices)] == np.asarray(q_pids)[:, np.newaxis]).astype(np.int64)

        # output pools sized for the worst case (every pair kept)
        float[:] pos = np.zeros(num_q*num_g, dtype=np.float32)
        float[:] neg = np.zeros(num_q*num_g, dtype=np.float32)

        # number of entries written so far into pos / neg
        long valid_pos = 0
        long valid_neg = 0
        long ind  # declared but not used in this function

        long q_idx, q_pid, q_camid, g_idx
        long[:] order = np.zeros(num_g, dtype=np.int64)

        float[:] raw_cmc = np.zeros(num_g, dtype=np.float32) # binary vector, positions with value 1 are correct matches
        long[:] sort_idx = np.zeros(num_g, dtype=np.int64)

        long idx  # declared but not used in this function

    for q_idx in range(num_q):
        # get query pid and camid
        q_pid = q_pids[q_idx]
        q_camid = q_camids[q_idx]

        # copy this query's ranking into a 1-D scratch buffer
        for g_idx in range(num_g):
            order[g_idx] = indices[q_idx, g_idx]
        # NOTE: num_g_real is not cdef-typed, so it is a Python object
        num_g_real = 0

        # remove gallery samples that have the same pid and camid with query
        for g_idx in range(num_g):
            if (g_pids[order[g_idx]] != q_pid) or (g_camids[order[g_idx]] != q_camid):
                # compact the kept entries to the front of raw_cmc / sort_idx
                raw_cmc[num_g_real] = matches[q_idx][g_idx]
                sort_idx[num_g_real] = order[g_idx]
                num_g_real += 1

        q_dist = distmat[q_idx]

        # route each kept gallery distance to the positive or negative pool
        for valid_idx in range(num_g_real):
            if raw_cmc[valid_idx] == 1:
                pos[valid_pos] = q_dist[sort_idx[valid_idx]]
                valid_pos += 1
            elif raw_cmc[valid_idx] == 0:
                neg[valid_neg] = q_dist[sort_idx[valid_idx]]
                valid_neg += 1

    # only the filled prefixes of pos / neg are returned
    cdef float[:] scores = np.hstack((pos[:valid_pos], neg[:valid_neg]))
    cdef float[:] labels = np.hstack((np.zeros(valid_pos, dtype=np.float32),
                                      np.ones(valid_neg, dtype=np.float32)))
    return np.asarray(scores), np.asarray(labels)
osp.dirname(osp.abspath(__file__)) + '/../../..') 7 | 8 | from fastreid.evaluation import evaluate_rank 9 | from fastreid.evaluation import evaluate_roc 10 | 11 | """ 12 | Test the speed of cython-based evaluation code. The speed improvements 13 | can be much bigger when using the real reid data, which contains a larger 14 | amount of query and gallery images. 15 | Note: you might encounter the following error: 16 | 'AssertionError: Error: all query identities do not appear in gallery'. 17 | This is normal because the inputs are random numbers. Just try again. 18 | """ 19 | 20 | print('*** Compare running time ***') 21 | 22 | setup = ''' 23 | import sys 24 | import os.path as osp 25 | import numpy as np 26 | sys.path.insert(0, osp.dirname(osp.abspath(__file__)) + '/../../..') 27 | from fastreid.evaluation import evaluate_rank 28 | from fastreid.evaluation import evaluate_roc 29 | num_q = 30 30 | num_g = 300 31 | dim = 512 32 | max_rank = 5 33 | q_feats = np.random.rand(num_q, dim).astype(np.float32) * 20 34 | q_feats = q_feats / np.linalg.norm(q_feats, ord=2, axis=1, keepdims=True) 35 | g_feats = np.random.rand(num_g, dim).astype(np.float32) * 20 36 | g_feats = g_feats / np.linalg.norm(g_feats, ord=2, axis=1, keepdims=True) 37 | distmat = 1 - np.dot(q_feats, g_feats.transpose()) 38 | q_pids = np.random.randint(0, num_q, size=num_q) 39 | g_pids = np.random.randint(0, num_g, size=num_g) 40 | q_camids = np.random.randint(0, 5, size=num_q) 41 | g_camids = np.random.randint(0, 5, size=num_g) 42 | ''' 43 | 44 | print('=> Using CMC metric') 45 | pytime = timeit.timeit( 46 | 'evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_cython=False)', 47 | setup=setup, 48 | number=20 49 | ) 50 | cytime = timeit.timeit( 51 | 'evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_cython=True)', 52 | setup=setup, 53 | number=20 54 | ) 55 | print('Python time: {} s'.format(pytime)) 56 | print('Cython time: {} s'.format(cytime)) 57 | print('CMC 
Cython is {} times faster than python\n'.format(pytime / cytime)) 58 | 59 | print('=> Using ROC metric') 60 | pytime = timeit.timeit( 61 | 'evaluate_roc(distmat, q_pids, g_pids, q_camids, g_camids, use_cython=False)', 62 | setup=setup, 63 | number=20 64 | ) 65 | cytime = timeit.timeit( 66 | 'evaluate_roc(distmat, q_pids, g_pids, q_camids, g_camids, use_cython=True)', 67 | setup=setup, 68 | number=20 69 | ) 70 | print('Python time: {} s'.format(pytime)) 71 | print('Cython time: {} s'.format(cytime)) 72 | print('ROC Cython is {} times faster than python\n'.format(pytime / cytime)) 73 | 74 | print("=> Check precision") 75 | num_q = 30 76 | num_g = 300 77 | dim = 512 78 | max_rank = 5 79 | q_feats = np.random.rand(num_q, dim).astype(np.float32) * 20 80 | q_feats = q_feats / np.linalg.norm(q_feats, ord=2, axis=1, keepdims=True) 81 | g_feats = np.random.rand(num_g, dim).astype(np.float32) * 20 82 | g_feats = g_feats / np.linalg.norm(g_feats, ord=2, axis=1, keepdims=True) 83 | distmat = 1 - np.dot(q_feats, g_feats.transpose()) 84 | q_pids = np.random.randint(0, num_q, size=num_q) 85 | g_pids = np.random.randint(0, num_g, size=num_g) 86 | q_camids = np.random.randint(0, 5, size=num_q) 87 | g_camids = np.random.randint(0, 5, size=num_g) 88 | 89 | cmc_py, mAP_py, mINP_py = evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_cython=False) 90 | 91 | cmc_cy, mAP_cy, mINP_cy = evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_cython=True) 92 | 93 | np.testing.assert_allclose(cmc_py, cmc_cy, rtol=1e-3, atol=1e-6) 94 | np.testing.assert_allclose(mAP_py, mAP_cy, rtol=1e-3, atol=1e-6) 95 | np.testing.assert_allclose(mINP_py, mINP_cy, rtol=1e-3, atol=1e-6) 96 | print('Rank results between python and cython are the same!') 97 | 98 | scores_cy, labels_cy = evaluate_roc(distmat, q_pids, g_pids, q_camids, g_camids, use_cython=True) 99 | scores_py, labels_py = evaluate_roc(distmat, q_pids, g_pids, q_camids, g_camids, use_cython=False) 100 | 101 
def re_ranking(q_g_dist: np.ndarray, q_q_dist: np.ndarray, g_g_dist: np.ndarray,
               k1: int = 20, k2: int = 6, lambda_value: float = 0.3) -> np.ndarray:
    """Re-rank query-gallery distances with k-reciprocal neighbours.

    The three input matrices are stitched into one square distance matrix over
    all query and gallery samples. Each sample is encoded by a weight vector
    over its (expanded) k-reciprocal neighbours, a Jaccard distance is derived
    from those vectors, and the result is blended with the normalised original
    distance.

    Args:
        q_g_dist: query-to-gallery distances, shape (num_query, num_gallery).
        q_q_dist: query-to-query distances, shape (num_query, num_query).
        g_g_dist: gallery-to-gallery distances, shape (num_gallery, num_gallery).
        k1: size of the neighbourhood used to find k-reciprocal neighbours.
        k2: number of top-ranked neighbours whose weight vectors are averaged
            (local query expansion); ``k2 == 1`` skips this step.
        lambda_value: weight of the original distance in the final blend; the
            Jaccard distance gets ``1 - lambda_value``.

    Returns:
        Re-ranked distance matrix of shape (num_query, num_gallery).
    """
    # Assemble the full (query + gallery) x (query + gallery) distance matrix.
    original_dist = np.concatenate(
        [np.concatenate([q_q_dist, q_g_dist], axis=1),
         np.concatenate([q_g_dist.T, g_g_dist], axis=1)],
        axis=0)
    original_dist = np.power(original_dist, 2).astype(np.float32)
    # Scale every column by its maximum, then transpose.
    original_dist = np.transpose(1. * original_dist / np.max(original_dist, axis=0))
    V = np.zeros_like(original_dist).astype(np.float32)
    # initial_rank[i] lists all samples ordered by ascending distance to sample i.
    initial_rank = np.argsort(original_dist).astype(np.int32)

    query_num = q_g_dist.shape[0]
    # NOTE: despite its name, gallery_num counts query + gallery samples.
    gallery_num = q_g_dist.shape[0] + q_g_dist.shape[1]
    all_num = gallery_num

    for i in range(all_num):
        # k-reciprocal neighbors
        forward_k_neigh_index = initial_rank[i, :k1 + 1]
        backward_k_neigh_index = initial_rank[forward_k_neigh_index, :k1 + 1]
        # keep the neighbours that also have i among their own top k1 + 1
        fi = np.where(backward_k_neigh_index == i)[0]
        k_reciprocal_index = forward_k_neigh_index[fi]
        k_reciprocal_expansion_index = k_reciprocal_index
        # Expand the set with the (k1 / 2)-reciprocal neighbours of each member
        # whenever more than 2/3 of them already belong to the set.
        for j in range(len(k_reciprocal_index)):
            candidate = k_reciprocal_index[j]
            candidate_forward_k_neigh_index = initial_rank[candidate,
                                              :int(np.around(k1 / 2.)) + 1]
            candidate_backward_k_neigh_index = initial_rank[candidate_forward_k_neigh_index,
                                               :int(np.around(k1 / 2.)) + 1]
            fi_candidate = np.where(candidate_backward_k_neigh_index == candidate)[0]
            candidate_k_reciprocal_index = candidate_forward_k_neigh_index[fi_candidate]
            if len(np.intersect1d(candidate_k_reciprocal_index, k_reciprocal_index)) > 2. / 3 * len(
                    candidate_k_reciprocal_index):
                k_reciprocal_expansion_index = np.append(k_reciprocal_expansion_index, candidate_k_reciprocal_index)

        k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index)
        # Row i of V: exp(-distance) over the expanded set, normalised to sum to 1.
        weight = np.exp(-original_dist[i, k_reciprocal_expansion_index])
        V[i, k_reciprocal_expansion_index] = 1. * weight / np.sum(weight)
    # From here on only the query rows of the original distance are needed.
    original_dist = original_dist[:query_num, ]
    if k2 != 1:
        # Local query expansion: average the weight vectors of the top-k2 neighbours.
        V_qe = np.zeros_like(V, dtype=np.float32)
        for i in range(all_num):
            V_qe[i, :] = np.mean(V[initial_rank[i, :k2], :], axis=0)
        V = V_qe
        del V_qe
    del initial_rank
    # invIndex[i]: rows of V that have a non-zero entry in column i.
    invIndex = []
    for i in range(gallery_num):
        invIndex.append(np.where(V[:, i] != 0)[0])

    jaccard_dist = np.zeros_like(original_dist, dtype=np.float32)

    for i in range(query_num):
        # temp_min accumulates sum_j min(V[i, j], V[other, j]) for every other sample.
        temp_min = np.zeros(shape=[1, gallery_num], dtype=np.float32)
        indNonZero = np.where(V[i, :] != 0)[0]
        indImages = [invIndex[ind] for ind in indNonZero]
        for j in range(len(indNonZero)):
            temp_min[0, indImages[j]] = temp_min[0, indImages[j]] + np.minimum(V[i, indNonZero[j]],
                                                                               V[indImages[j], indNonZero[j]])
        jaccard_dist[i] = 1 - temp_min / (2. - temp_min)

    # Blend the Jaccard distance with the normalised original distance.
    final_dist = jaccard_dist * (1 - lambda_value) + original_dist * lambda_value
    del original_dist, V, jaccard_dist
    # Drop the query-to-query columns so only query-to-gallery entries remain.
    final_dist = final_dist[:query_num, query_num:]
    return final_dist
def evaluate_roc(
        distmat,
        q_pids,
        g_pids,
        q_camids,
        g_camids,
        use_cython=True
):
    """Gather ROC scores and labels, preferring the Cython implementation.

    Args:
        distmat (numpy.ndarray): distance matrix of shape (num_query, num_gallery).
        q_pids (numpy.ndarray): 1-D array with the person identity of each
            query instance.
        g_pids (numpy.ndarray): 1-D array with the person identity of each
            gallery instance.
        q_camids (numpy.ndarray): 1-D array with the camera view of each
            query instance.
        g_camids (numpy.ndarray): 1-D array with the camera view of each
            gallery instance.
        use_cython (bool, optional): request the compiled implementation.
            Default is True. It is only used when the extension module could
            be imported; otherwise the pure-Python version runs.

    Returns:
        Whatever the selected implementation returns: a ``(scores, labels)``
        pair.
    """
    # Guard clause: fall back to Python when Cython is not wanted or not built.
    if not (use_cython and IS_CYTHON_AVAI):
        return evaluate_roc_py(distmat, q_pids, g_pids, q_camids, g_camids)
    return evaluate_roc_cy(distmat, q_pids, g_pids, q_camids, g_camids)
def flatten_results_dict(results):
    """
    Expand a hierarchical dict of scalars into a flat dict of scalars.
    If results[k1][k2][k3] = v, the returned dict will have the entry
    {"k1/k2/k3": v}.

    Args:
        results (dict): possibly nested mapping whose keys are strings.

    Returns:
        dict: a new flat dict; nested keys are joined with "/". Empty nested
        mappings contribute no entries. The input is not modified.
    """
    # `collections.Mapping` was removed in Python 3.10; the ABC lives in
    # `collections.abc`. Importing it here keeps this function working
    # regardless of which alias the module header provides.
    from collections.abc import Mapping

    r = {}
    for k, v in results.items():
        if isinstance(v, Mapping):
            # Recurse, then prefix every flattened sub-key with this key.
            for kk, vv in flatten_results_dict(v).items():
                r[k + "/" + kk] = vv
        else:
            r[k] = v
    return r
class SwishImplementation(torch.autograd.Function):
    """Swish (``x * sigmoid(x)``) with a hand-written backward pass.

    Only the input is saved for the backward pass; the sigmoid is recomputed
    there instead of being kept alive between forward and backward.
    """

    @staticmethod
    def forward(ctx, i):
        """Return ``i * sigmoid(i)`` and stash ``i`` for the backward pass."""
        result = i * torch.sigmoid(i)
        ctx.save_for_backward(i)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradient of the loss with respect to the forward input."""
        # `ctx.saved_variables` is deprecated; `saved_tensors` is the supported
        # accessor for what `save_for_backward` stored.
        i, = ctx.saved_tensors
        sigmoid_i = torch.sigmoid(i)
        # d/dx [x * s(x)] = s(x) * (1 + x * (1 - s(x)))
        return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))
class ArcSoftmax(Linear):
    """Additive angular margin on cosine logits.

    The margin ``self.m`` is added to the angle of the target class and the
    result is mapped back through cosine and multiplied by the scale
    ``self.s``. ``logits`` is modified in place and returned.
    """

    def forward(self, logits, targets):
        # Rows whose target is -1 have no target class here and get no margin.
        index = torch.where(targets != -1)[0]
        m_hot = torch.zeros(index.size()[0], logits.size()[1], device=logits.device, dtype=logits.dtype)
        m_hot.scatter_(1, targets[index, None], self.m)
        # acos is only defined on [-1, 1]; cosines that drift just outside this
        # range through floating-point round-off would otherwise become NaN.
        logits.clamp_(-1, 1)
        logits.acos_()
        logits[index] += m_hot
        logits.cos_().mul_(self.s)
        return logits
62 | delta_p = 1 - self.m 63 | delta_n = self.m 64 | 65 | # When use model parallel, there are some targets not in class centers of local rank 66 | index = torch.where(targets != -1)[0] 67 | m_hot = torch.zeros(index.size()[0], logits.size()[1], device=logits.device, dtype=logits.dtype) 68 | m_hot.scatter_(1, targets[index, None], 1) 69 | 70 | logits_p = alpha_p * (logits - delta_p) 71 | logits_n = alpha_n * (logits - delta_n) 72 | 73 | logits[index] = logits_p[index] * m_hot + logits_n[index] * (1 - m_hot) 74 | 75 | neg_index = torch.where(targets == -1)[0] 76 | logits[neg_index] = logits_n[neg_index] 77 | 78 | logits.mul_(self.s) 79 | 80 | return logits 81 | -------------------------------------------------------------------------------- /fastreid/layers/gather_layer.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: xingyu liao 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | # based on: https://github.com/open-mmlab/OpenSelfSup/blob/master/openselfsup/models/utils/gather_layer.py 8 | 9 | import torch 10 | import torch.distributed as dist 11 | 12 | 13 | class GatherLayer(torch.autograd.Function): 14 | """Gather tensors from all process, supporting backward propagation. 
# From PyTorch internals
def _ntuple(n):
    """Build a converter that repeats a non-iterable value ``n`` times.

    The returned callable hands any iterable argument (strings included) back
    unchanged and turns everything else into an ``n``-element tuple.
    """

    def parse(x):
        is_iterable = isinstance(x, collections.abc.Iterable)
        return x if is_iterable else tuple(repeat(x, n))

    return parse


# Ready-made converters for the common sizes, plus the factory itself.
to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = _ntuple
class Non_local(nn.Module):
    """Non-local block with a residual connection.

    Three 1x1 convolutions (``theta``, ``phi``, ``g``) project the input to
    ``inter_channels``; the pairwise product of ``theta`` and ``phi`` is divided
    by the number of spatial positions and used to aggregate ``g``. ``W``
    (1x1 conv + norm) maps the result back to ``in_channels`` before it is
    added to the input. The norm layer in ``W`` starts with zero weight and
    bias.
    """

    def __init__(self, in_channels, bn_norm, reduc_ratio=2):
        super().__init__()

        def conv1x1(c_in, c_out):
            return nn.Conv2d(in_channels=c_in, out_channels=c_out,
                             kernel_size=1, stride=1, padding=0)

        self.in_channels = in_channels
        # NOTE(review): this evaluates to 1 for every non-zero reduc_ratio;
        # `in_channels // reduc_ratio` looks like the intent. Left untouched
        # because changing it alters parameter shapes of existing checkpoints.
        self.inter_channels = reduc_ratio // reduc_ratio

        # Sub-modules are created in the same order as before: g, W, theta, phi.
        self.g = conv1x1(self.in_channels, self.inter_channels)

        self.W = nn.Sequential(
            conv1x1(self.inter_channels, self.in_channels),
            get_norm(bn_norm, self.in_channels),
        )
        for norm_param in (self.W[1].weight, self.W[1].bias):
            nn.init.constant_(norm_param, 0.0)

        self.theta = conv1x1(self.in_channels, self.inter_channels)

        self.phi = conv1x1(self.in_channels, self.inter_channels)

    def forward(self, x):
        """
        :param x: 4-D input whose dim 1 has ``in_channels`` entries
        :return: tensor of the same shape as ``x``
        """
        n_batch = x.size(0)
        n_inter = self.inter_channels

        # (batch, positions, inter_channels)
        g_x = self.g(x).view(n_batch, n_inter, -1).permute(0, 2, 1)
        theta_x = self.theta(x).view(n_batch, n_inter, -1).permute(0, 2, 1)
        # (batch, inter_channels, positions)
        phi_x = self.phi(x).view(n_batch, n_inter, -1)

        # Pairwise affinities, divided by the number of positions.
        affinity = torch.matmul(theta_x, phi_x)
        affinity = affinity / affinity.size(-1)

        aggregated = torch.matmul(affinity, g_x).permute(0, 2, 1).contiguous()
        aggregated = aggregated.view(n_batch, n_inter, *x.size()[2:])
        return self.W(aggregated) + x
class GeneralizedMeanPooling(nn.Module):
    r"""Adaptive 2D power-mean (GeM) pooling.

    The input is clamped from below at ``eps``, raised to the power ``p``,
    average-pooled to ``output_size`` and raised to ``1 / p``:
    :math:`f(X) = pow(mean(pow(X, p)), 1/p)`

        - At p = infinity, one gets Max Pooling
        - At p = 1, one gets Average Pooling

    The number of output features is equal to the number of input planes.

    Args:
        norm: the exponent ``p``; must be positive.
        output_size: the target output size of the form H x W.
                     Can be a tuple (H, W) or a single H for a square output H x H.
                     H and W can be either a ``int``, or ``None`` which means the size will
                     be the same as that of the input.
        eps: lower bound applied to the input before the power is taken.
    """

    def __init__(self, norm=3, output_size=(1, 1), eps=1e-6, *args, **kwargs):
        super().__init__()
        assert norm > 0
        self.p = float(norm)
        self.output_size = output_size
        self.eps = eps

    def forward(self, x):
        powered = x.clamp(min=self.eps).pow(self.p)
        pooled = F.adaptive_avg_pool2d(powered, self.output_size)
        return pooled.pow(1. / self.p)

    def __repr__(self):
        return f"{self.__class__.__name__}({self.p!s}, output_size={self.output_size!s})"
class SELayer(nn.Module):
    """Squeeze-and-excitation channel gate.

    Each channel is globally average-pooled, passed through a two-layer
    bias-free bottleneck (``channel -> channel / reduction -> channel``) ending
    in a sigmoid, and the resulting per-channel gate rescales the input.
    """

    def __init__(self, channel, reduction=16):
        super().__init__()
        bottleneck = int(channel / reduction)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, bottleneck, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(bottleneck, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        # Unpacking four sizes means the input must be 4-D.
        n_batch, n_channel, _, _ = x.size()
        squeezed = self.avg_pool(x).view(n_batch, n_channel)
        gate = self.fc(squeezed).view(n_batch, n_channel, 1, 1)
        return x * gate.expand_as(x)
self.radix = radix 30 | self.cardinality = groups 31 | self.channels = channels 32 | self.dropblock_prob = dropblock_prob 33 | if self.rectify: 34 | from rfconv import RFConv2d 35 | self.conv = RFConv2d(in_channels, channels * radix, kernel_size, stride, padding, dilation, 36 | groups=groups * radix, bias=bias, average_mode=rectify_avg, **kwargs) 37 | else: 38 | self.conv = Conv2d(in_channels, channels * radix, kernel_size, stride, padding, dilation, 39 | groups=groups * radix, bias=bias, **kwargs) 40 | self.use_bn = norm_layer is not None 41 | if self.use_bn: 42 | self.bn0 = get_norm(norm_layer, channels * radix) 43 | self.relu = ReLU(inplace=True) 44 | self.fc1 = Conv2d(channels, inter_channels, 1, groups=self.cardinality) 45 | if self.use_bn: 46 | self.bn1 = get_norm(norm_layer, inter_channels) 47 | self.fc2 = Conv2d(inter_channels, channels * radix, 1, groups=self.cardinality) 48 | if dropblock_prob > 0.0: 49 | self.dropblock = DropBlock2D(dropblock_prob, 3) 50 | self.rsoftmax = rSoftMax(radix, groups) 51 | 52 | def forward(self, x): 53 | x = self.conv(x) 54 | if self.use_bn: 55 | x = self.bn0(x) 56 | if self.dropblock_prob > 0.0: 57 | x = self.dropblock(x) 58 | x = self.relu(x) 59 | 60 | batch, rchannel = x.shape[:2] 61 | if self.radix > 1: 62 | if torch.__version__ < '1.5': 63 | splited = torch.split(x, int(rchannel // self.radix), dim=1) 64 | else: 65 | splited = torch.split(x, rchannel // self.radix, dim=1) 66 | gap = sum(splited) 67 | else: 68 | gap = x 69 | gap = F.adaptive_avg_pool2d(gap, 1) 70 | gap = self.fc1(gap) 71 | 72 | if self.use_bn: 73 | gap = self.bn1(gap) 74 | gap = self.relu(gap) 75 | 76 | atten = self.fc2(gap) 77 | atten = self.rsoftmax(atten).view(batch, -1, 1, 1) 78 | 79 | if self.radix > 1: 80 | if torch.__version__ < '1.5': 81 | attens = torch.split(atten, int(rchannel // self.radix), dim=1) 82 | else: 83 | attens = torch.split(atten, rchannel // self.radix, dim=1) 84 | out = sum([att * split for (att, split) in zip(attens, splited)]) 85 
| else: 86 | out = atten * x 87 | return out.contiguous() 88 | 89 | 90 | class rSoftMax(nn.Module): 91 | def __init__(self, radix, cardinality): 92 | super().__init__() 93 | self.radix = radix 94 | self.cardinality = cardinality 95 | 96 | def forward(self, x): 97 | batch = x.size(0) 98 | if self.radix > 1: 99 | x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2) 100 | x = F.softmax(x, dim=1) 101 | x = x.reshape(batch, -1) 102 | else: 103 | x = torch.sigmoid(x) 104 | return x 105 | 106 | 107 | class DropBlock2D(object): 108 | def __init__(self, *args, **kwargs): 109 | raise NotImplementedError 110 | -------------------------------------------------------------------------------- /fastreid/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | from . import losses 8 | from .backbones import ( 9 | BACKBONE_REGISTRY, 10 | build_resnet_backbone, 11 | build_backbone, 12 | ) 13 | from .heads import ( 14 | REID_HEADS_REGISTRY, 15 | build_heads, 16 | EmbeddingHead, 17 | ) 18 | from .meta_arch import ( 19 | build_model, 20 | META_ARCH_REGISTRY, 21 | ) 22 | 23 | __all__ = [k for k in globals().keys() if not k.startswith("_")] -------------------------------------------------------------------------------- /fastreid/modeling/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | from .build import build_backbone, BACKBONE_REGISTRY 8 | 9 | from .resnet import build_resnet_backbone 10 | from .osnet import build_osnet_backbone 11 | from .resnest import build_resnest_backbone 12 | from .resnext import build_resnext_backbone 13 | from .regnet import build_regnet_backbone, build_effnet_backbone 14 | from .shufflenet import build_shufflenetv2_backbone 15 | from 
def build_backbone(cfg):
    """
    Build the backbone named by `cfg.MODEL.BACKBONE.NAME`.

    The name is looked up in `BACKBONE_REGISTRY` and the registered callable
    is invoked with `cfg` as its only argument.

    Returns:
        the object produced by the registered callable
        (expected to be an instance of :class:`Backbone`)
    """
    builder = BACKBONE_REGISTRY.get(cfg.MODEL.BACKBONE.NAME)
    return builder(cfg)
28 | -------------------------------------------------------------------------------- /fastreid/modeling/backbones/regnet/effnet/EN-B1_dds_8gpu.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: effnet 3 | NUM_CLASSES: 1000 4 | EN: 5 | STEM_W: 32 6 | STRIDES: [1, 2, 2, 2, 1, 2, 1] 7 | DEPTHS: [2, 3, 3, 4, 4, 5, 2] 8 | WIDTHS: [16, 24, 40, 80, 112, 192, 320] 9 | EXP_RATIOS: [1, 6, 6, 6, 6, 6, 6] 10 | KERNELS: [3, 3, 5, 3, 5, 5, 3] 11 | HEAD_W: 1280 12 | OPTIM: 13 | LR_POLICY: cos 14 | BASE_LR: 0.4 15 | MAX_EPOCH: 100 16 | MOMENTUM: 0.9 17 | WEIGHT_DECAY: 1e-5 18 | TRAIN: 19 | DATASET: imagenet 20 | IM_SIZE: 240 21 | BATCH_SIZE: 256 22 | TEST: 23 | DATASET: imagenet 24 | IM_SIZE: 274 25 | BATCH_SIZE: 200 26 | NUM_GPUS: 8 27 | OUT_DIR: . 28 | -------------------------------------------------------------------------------- /fastreid/modeling/backbones/regnet/effnet/EN-B2_dds_8gpu.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: effnet 3 | NUM_CLASSES: 1000 4 | EN: 5 | STEM_W: 32 6 | STRIDES: [1, 2, 2, 2, 1, 2, 1] 7 | DEPTHS: [2, 3, 3, 4, 4, 5, 2] 8 | WIDTHS: [16, 24, 48, 88, 120, 208, 352] 9 | EXP_RATIOS: [1, 6, 6, 6, 6, 6, 6] 10 | KERNELS: [3, 3, 5, 3, 5, 5, 3] 11 | HEAD_W: 1408 12 | OPTIM: 13 | LR_POLICY: cos 14 | BASE_LR: 0.4 15 | MAX_EPOCH: 100 16 | MOMENTUM: 0.9 17 | WEIGHT_DECAY: 1e-5 18 | TRAIN: 19 | DATASET: imagenet 20 | IM_SIZE: 260 21 | BATCH_SIZE: 256 22 | TEST: 23 | DATASET: imagenet 24 | IM_SIZE: 298 25 | BATCH_SIZE: 200 26 | NUM_GPUS: 8 27 | OUT_DIR: . 
28 | -------------------------------------------------------------------------------- /fastreid/modeling/backbones/regnet/effnet/EN-B3_dds_8gpu.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: effnet 3 | NUM_CLASSES: 1000 4 | EN: 5 | STEM_W: 40 6 | STRIDES: [1, 2, 2, 2, 1, 2, 1] 7 | DEPTHS: [2, 3, 3, 5, 5, 6, 2] 8 | WIDTHS: [24, 32, 48, 96, 136, 232, 384] 9 | EXP_RATIOS: [1, 6, 6, 6, 6, 6, 6] 10 | KERNELS: [3, 3, 5, 3, 5, 5, 3] 11 | HEAD_W: 1536 12 | OPTIM: 13 | LR_POLICY: cos 14 | BASE_LR: 0.4 15 | MAX_EPOCH: 100 16 | MOMENTUM: 0.9 17 | WEIGHT_DECAY: 1e-5 18 | TRAIN: 19 | DATASET: imagenet 20 | IM_SIZE: 300 21 | BATCH_SIZE: 256 22 | TEST: 23 | DATASET: imagenet 24 | IM_SIZE: 342 25 | BATCH_SIZE: 200 26 | NUM_GPUS: 8 27 | OUT_DIR: . 28 | -------------------------------------------------------------------------------- /fastreid/modeling/backbones/regnet/effnet/EN-B4_dds_8gpu.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: effnet 3 | NUM_CLASSES: 1000 4 | EN: 5 | STEM_W: 48 6 | STRIDES: [1, 2, 2, 2, 1, 2, 1] 7 | DEPTHS: [2, 4, 4, 6, 6, 8, 2] 8 | WIDTHS: [24, 32, 56, 112, 160, 272, 448] 9 | EXP_RATIOS: [1, 6, 6, 6, 6, 6, 6] 10 | KERNELS: [3, 3, 5, 3, 5, 5, 3] 11 | HEAD_W: 1792 12 | OPTIM: 13 | LR_POLICY: cos 14 | BASE_LR: 0.2 15 | MAX_EPOCH: 100 16 | MOMENTUM: 0.9 17 | WEIGHT_DECAY: 1e-5 18 | TRAIN: 19 | DATASET: imagenet 20 | IM_SIZE: 380 21 | BATCH_SIZE: 128 22 | TEST: 23 | DATASET: imagenet 24 | IM_SIZE: 434 25 | BATCH_SIZE: 104 26 | NUM_GPUS: 8 27 | OUT_DIR: . 
28 | -------------------------------------------------------------------------------- /fastreid/modeling/backbones/regnet/effnet/EN-B5_dds_8gpu.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: effnet 3 | NUM_CLASSES: 1000 4 | EN: 5 | STEM_W: 48 6 | STRIDES: [1, 2, 2, 2, 1, 2, 1] 7 | DEPTHS: [3, 5, 5, 7, 7, 9, 3] 8 | WIDTHS: [24, 40, 64, 128, 176, 304, 512] 9 | EXP_RATIOS: [1, 6, 6, 6, 6, 6, 6] 10 | KERNELS: [3, 3, 5, 3, 5, 5, 3] 11 | HEAD_W: 2048 12 | OPTIM: 13 | LR_POLICY: cos 14 | BASE_LR: 0.1 15 | MAX_EPOCH: 100 16 | MOMENTUM: 0.9 17 | WEIGHT_DECAY: 1e-5 18 | TRAIN: 19 | DATASET: imagenet 20 | IM_SIZE: 456 21 | BATCH_SIZE: 64 22 | TEST: 23 | DATASET: imagenet 24 | IM_SIZE: 522 25 | BATCH_SIZE: 48 26 | NUM_GPUS: 8 27 | OUT_DIR: . 28 | -------------------------------------------------------------------------------- /fastreid/modeling/backbones/regnet/regnetx/RegNetX-1.6GF_dds_8gpu.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: regnet 3 | NUM_CLASSES: 1000 4 | REGNET: 5 | DEPTH: 18 6 | W0: 80 7 | WA: 34.01 8 | WM: 2.25 9 | GROUP_W: 24 10 | OPTIM: 11 | LR_POLICY: cos 12 | BASE_LR: 0.8 13 | MAX_EPOCH: 100 14 | MOMENTUM: 0.9 15 | WEIGHT_DECAY: 5e-5 16 | WARMUP_ITERS: 5 17 | TRAIN: 18 | DATASET: imagenet 19 | IM_SIZE: 224 20 | BATCH_SIZE: 1024 21 | TEST: 22 | DATASET: imagenet 23 | IM_SIZE: 256 24 | BATCH_SIZE: 800 25 | NUM_GPUS: 8 26 | OUT_DIR: . 
27 | -------------------------------------------------------------------------------- /fastreid/modeling/backbones/regnet/regnetx/RegNetX-12GF_dds_8gpu.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: regnet 3 | NUM_CLASSES: 1000 4 | REGNET: 5 | DEPTH: 19 6 | W0: 168 7 | WA: 73.36 8 | WM: 2.37 9 | GROUP_W: 112 10 | OPTIM: 11 | LR_POLICY: cos 12 | BASE_LR: 0.4 13 | MAX_EPOCH: 100 14 | MOMENTUM: 0.9 15 | WEIGHT_DECAY: 5e-5 16 | WARMUP_ITERS: 5 17 | TRAIN: 18 | DATASET: imagenet 19 | IM_SIZE: 224 20 | BATCH_SIZE: 512 21 | TEST: 22 | DATASET: imagenet 23 | IM_SIZE: 256 24 | BATCH_SIZE: 400 25 | NUM_GPUS: 8 26 | OUT_DIR: . 27 | -------------------------------------------------------------------------------- /fastreid/modeling/backbones/regnet/regnetx/RegNetX-16GF_dds_8gpu.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: regnet 3 | NUM_CLASSES: 1000 4 | REGNET: 5 | DEPTH: 22 6 | W0: 216 7 | WA: 55.59 8 | WM: 2.1 9 | GROUP_W: 128 10 | OPTIM: 11 | LR_POLICY: cos 12 | BASE_LR: 0.4 13 | MAX_EPOCH: 100 14 | MOMENTUM: 0.9 15 | WEIGHT_DECAY: 5e-5 16 | WARMUP_ITERS: 5 17 | TRAIN: 18 | DATASET: imagenet 19 | IM_SIZE: 224 20 | BATCH_SIZE: 512 21 | TEST: 22 | DATASET: imagenet 23 | IM_SIZE: 256 24 | BATCH_SIZE: 400 25 | NUM_GPUS: 8 26 | OUT_DIR: . 
27 | -------------------------------------------------------------------------------- /fastreid/modeling/backbones/regnet/regnetx/RegNetX-200MF_dds_8gpu.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: regnet 3 | NUM_CLASSES: 1000 4 | REGNET: 5 | DEPTH: 13 6 | W0: 24 7 | WA: 36.44 8 | WM: 2.49 9 | GROUP_W: 8 10 | OPTIM: 11 | LR_POLICY: cos 12 | BASE_LR: 0.8 13 | MAX_EPOCH: 100 14 | MOMENTUM: 0.9 15 | WEIGHT_DECAY: 5e-5 16 | WARMUP_ITERS: 5 17 | TRAIN: 18 | DATASET: imagenet 19 | IM_SIZE: 224 20 | BATCH_SIZE: 1024 21 | TEST: 22 | DATASET: imagenet 23 | IM_SIZE: 256 24 | BATCH_SIZE: 800 25 | NUM_GPUS: 8 26 | OUT_DIR: . 27 | -------------------------------------------------------------------------------- /fastreid/modeling/backbones/regnet/regnetx/RegNetX-3.2GF_dds_8gpu.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: regnet 3 | NUM_CLASSES: 1000 4 | REGNET: 5 | DEPTH: 25 6 | W0: 88 7 | WA: 26.31 8 | WM: 2.25 9 | GROUP_W: 48 10 | OPTIM: 11 | LR_POLICY: cos 12 | BASE_LR: 0.4 13 | MAX_EPOCH: 100 14 | MOMENTUM: 0.9 15 | WEIGHT_DECAY: 5e-5 16 | WARMUP_ITERS: 5 17 | TRAIN: 18 | DATASET: imagenet 19 | IM_SIZE: 224 20 | BATCH_SIZE: 512 21 | TEST: 22 | DATASET: imagenet 23 | IM_SIZE: 256 24 | BATCH_SIZE: 400 25 | NUM_GPUS: 8 26 | OUT_DIR: . 
27 | -------------------------------------------------------------------------------- /fastreid/modeling/backbones/regnet/regnetx/RegNetX-32GF_dds_8gpu.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: regnet 3 | NUM_CLASSES: 1000 4 | REGNET: 5 | DEPTH: 23 6 | W0: 320 7 | WA: 69.86 8 | WM: 2.0 9 | GROUP_W: 168 10 | OPTIM: 11 | LR_POLICY: cos 12 | BASE_LR: 0.2 13 | MAX_EPOCH: 100 14 | MOMENTUM: 0.9 15 | WEIGHT_DECAY: 5e-5 16 | WARMUP_ITERS: 5 17 | TRAIN: 18 | DATASET: imagenet 19 | IM_SIZE: 224 20 | BATCH_SIZE: 256 21 | TEST: 22 | DATASET: imagenet 23 | IM_SIZE: 256 24 | BATCH_SIZE: 200 25 | NUM_GPUS: 8 26 | OUT_DIR: . 27 | -------------------------------------------------------------------------------- /fastreid/modeling/backbones/regnet/regnetx/RegNetX-4.0GF_dds_8gpu.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: regnet 3 | NUM_CLASSES: 1000 4 | REGNET: 5 | DEPTH: 23 6 | W0: 96 7 | WA: 38.65 8 | WM: 2.43 9 | GROUP_W: 40 10 | OPTIM: 11 | LR_POLICY: cos 12 | BASE_LR: 0.4 13 | MAX_EPOCH: 100 14 | MOMENTUM: 0.9 15 | WEIGHT_DECAY: 5e-5 16 | WARMUP_ITERS: 5 17 | TRAIN: 18 | DATASET: imagenet 19 | IM_SIZE: 224 20 | BATCH_SIZE: 512 21 | TEST: 22 | DATASET: imagenet 23 | IM_SIZE: 256 24 | BATCH_SIZE: 400 25 | NUM_GPUS: 8 26 | OUT_DIR: . 
27 | -------------------------------------------------------------------------------- /fastreid/modeling/backbones/regnet/regnetx/RegNetX-400MF_dds_8gpu.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: regnet 3 | NUM_CLASSES: 1000 4 | REGNET: 5 | DEPTH: 22 6 | W0: 24 7 | WA: 24.48 8 | WM: 2.54 9 | GROUP_W: 16 10 | OPTIM: 11 | LR_POLICY: cos 12 | BASE_LR: 0.8 13 | MAX_EPOCH: 100 14 | MOMENTUM: 0.9 15 | WEIGHT_DECAY: 5e-5 16 | WARMUP_ITERS: 5 17 | TRAIN: 18 | DATASET: imagenet 19 | IM_SIZE: 224 20 | BATCH_SIZE: 1024 21 | TEST: 22 | DATASET: imagenet 23 | IM_SIZE: 256 24 | BATCH_SIZE: 800 25 | NUM_GPUS: 8 26 | OUT_DIR: . 27 | -------------------------------------------------------------------------------- /fastreid/modeling/backbones/regnet/regnetx/RegNetX-6.4GF_dds_8gpu.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: regnet 3 | NUM_CLASSES: 1000 4 | REGNET: 5 | DEPTH: 17 6 | W0: 184 7 | WA: 60.83 8 | WM: 2.07 9 | GROUP_W: 56 10 | OPTIM: 11 | LR_POLICY: cos 12 | BASE_LR: 0.4 13 | MAX_EPOCH: 100 14 | MOMENTUM: 0.9 15 | WEIGHT_DECAY: 5e-5 16 | WARMUP_ITERS: 5 17 | TRAIN: 18 | DATASET: imagenet 19 | IM_SIZE: 224 20 | BATCH_SIZE: 512 21 | TEST: 22 | DATASET: imagenet 23 | IM_SIZE: 256 24 | BATCH_SIZE: 400 25 | NUM_GPUS: 8 26 | OUT_DIR: . 
27 | -------------------------------------------------------------------------------- /fastreid/modeling/backbones/regnet/regnetx/RegNetX-600MF_dds_8gpu.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: regnet 3 | NUM_CLASSES: 1000 4 | REGNET: 5 | DEPTH: 16 6 | W0: 48 7 | WA: 36.97 8 | WM: 2.24 9 | GROUP_W: 24 10 | OPTIM: 11 | LR_POLICY: cos 12 | BASE_LR: 0.8 13 | MAX_EPOCH: 100 14 | MOMENTUM: 0.9 15 | WEIGHT_DECAY: 5e-5 16 | WARMUP_ITERS: 5 17 | TRAIN: 18 | DATASET: imagenet 19 | IM_SIZE: 224 20 | BATCH_SIZE: 1024 21 | TEST: 22 | DATASET: imagenet 23 | IM_SIZE: 256 24 | BATCH_SIZE: 800 25 | NUM_GPUS: 8 26 | OUT_DIR: . 27 | -------------------------------------------------------------------------------- /fastreid/modeling/backbones/regnet/regnetx/RegNetX-8.0GF_dds_8gpu.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: regnet 3 | NUM_CLASSES: 1000 4 | REGNET: 5 | DEPTH: 23 6 | W0: 80 7 | WA: 49.56 8 | WM: 2.88 9 | GROUP_W: 120 10 | OPTIM: 11 | LR_POLICY: cos 12 | BASE_LR: 0.4 13 | MAX_EPOCH: 100 14 | MOMENTUM: 0.9 15 | WEIGHT_DECAY: 5e-5 16 | WARMUP_ITERS: 5 17 | TRAIN: 18 | DATASET: imagenet 19 | IM_SIZE: 224 20 | BATCH_SIZE: 512 21 | TEST: 22 | DATASET: imagenet 23 | IM_SIZE: 256 24 | BATCH_SIZE: 400 25 | NUM_GPUS: 8 26 | OUT_DIR: . 
27 | -------------------------------------------------------------------------------- /fastreid/modeling/backbones/regnet/regnetx/RegNetX-800MF_dds_8gpu.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: regnet 3 | NUM_CLASSES: 1000 4 | REGNET: 5 | DEPTH: 16 6 | W0: 56 7 | WA: 35.73 8 | WM: 2.28 9 | GROUP_W: 16 10 | OPTIM: 11 | LR_POLICY: cos 12 | BASE_LR: 0.8 13 | MAX_EPOCH: 100 14 | MOMENTUM: 0.9 15 | WEIGHT_DECAY: 5e-5 16 | WARMUP_ITERS: 5 17 | TRAIN: 18 | DATASET: imagenet 19 | IM_SIZE: 224 20 | BATCH_SIZE: 1024 21 | TEST: 22 | DATASET: imagenet 23 | IM_SIZE: 256 24 | BATCH_SIZE: 800 25 | NUM_GPUS: 8 26 | OUT_DIR: . 27 | -------------------------------------------------------------------------------- /fastreid/modeling/backbones/regnet/regnety/RegNetY-1.6GF_dds_8gpu.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: regnet 3 | NUM_CLASSES: 1000 4 | REGNET: 5 | SE_ON: True 6 | DEPTH: 27 7 | W0: 48 8 | WA: 20.71 9 | WM: 2.65 10 | GROUP_W: 24 11 | OPTIM: 12 | LR_POLICY: cos 13 | BASE_LR: 0.8 14 | MAX_EPOCH: 100 15 | MOMENTUM: 0.9 16 | WEIGHT_DECAY: 5e-5 17 | WARMUP_ITERS: 5 18 | TRAIN: 19 | DATASET: imagenet 20 | IM_SIZE: 224 21 | BATCH_SIZE: 1024 22 | TEST: 23 | DATASET: imagenet 24 | IM_SIZE: 256 25 | BATCH_SIZE: 800 26 | NUM_GPUS: 8 27 | OUT_DIR: . 
28 | -------------------------------------------------------------------------------- /fastreid/modeling/backbones/regnet/regnety/RegNetY-12GF_dds_8gpu.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: regnet 3 | NUM_CLASSES: 1000 4 | REGNET: 5 | SE_ON: True 6 | DEPTH: 19 7 | W0: 168 8 | WA: 73.36 9 | WM: 2.37 10 | GROUP_W: 112 11 | OPTIM: 12 | LR_POLICY: cos 13 | BASE_LR: 0.4 14 | MAX_EPOCH: 100 15 | MOMENTUM: 0.9 16 | WEIGHT_DECAY: 5e-5 17 | WARMUP_ITERS: 5 18 | TRAIN: 19 | DATASET: imagenet 20 | IM_SIZE: 224 21 | BATCH_SIZE: 512 22 | TEST: 23 | DATASET: imagenet 24 | IM_SIZE: 256 25 | BATCH_SIZE: 400 26 | NUM_GPUS: 8 27 | OUT_DIR: . 28 | -------------------------------------------------------------------------------- /fastreid/modeling/backbones/regnet/regnety/RegNetY-16GF_dds_8gpu.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: regnet 3 | NUM_CLASSES: 1000 4 | REGNET: 5 | SE_ON: True 6 | DEPTH: 18 7 | W0: 200 8 | WA: 106.23 9 | WM: 2.48 10 | GROUP_W: 112 11 | OPTIM: 12 | LR_POLICY: cos 13 | BASE_LR: 0.2 14 | MAX_EPOCH: 100 15 | MOMENTUM: 0.9 16 | WEIGHT_DECAY: 5e-5 17 | WARMUP_ITERS: 5 18 | TRAIN: 19 | DATASET: imagenet 20 | IM_SIZE: 224 21 | BATCH_SIZE: 256 22 | TEST: 23 | DATASET: imagenet 24 | IM_SIZE: 256 25 | BATCH_SIZE: 200 26 | NUM_GPUS: 8 27 | OUT_DIR: . 
28 | -------------------------------------------------------------------------------- /fastreid/modeling/backbones/regnet/regnety/RegNetY-200MF_dds_8gpu.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: regnet 3 | NUM_CLASSES: 1000 4 | REGNET: 5 | SE_ON: True 6 | DEPTH: 13 7 | W0: 24 8 | WA: 36.44 9 | WM: 2.49 10 | GROUP_W: 8 11 | OPTIM: 12 | LR_POLICY: cos 13 | BASE_LR: 0.8 14 | MAX_EPOCH: 100 15 | MOMENTUM: 0.9 16 | WEIGHT_DECAY: 5e-5 17 | TRAIN: 18 | DATASET: imagenet 19 | IM_SIZE: 224 20 | BATCH_SIZE: 1024 21 | TEST: 22 | DATASET: imagenet 23 | IM_SIZE: 256 24 | BATCH_SIZE: 800 25 | NUM_GPUS: 8 26 | OUT_DIR: . 27 | -------------------------------------------------------------------------------- /fastreid/modeling/backbones/regnet/regnety/RegNetY-3.2GF_dds_8gpu.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: regnet 3 | NUM_CLASSES: 1000 4 | REGNET: 5 | SE_ON: True 6 | DEPTH: 21 7 | W0: 80 8 | WA: 42.63 9 | WM: 2.66 10 | GROUP_W: 24 11 | OPTIM: 12 | LR_POLICY: cos 13 | BASE_LR: 0.4 14 | MAX_EPOCH: 100 15 | MOMENTUM: 0.9 16 | WEIGHT_DECAY: 5e-5 17 | WARMUP_ITERS: 5 18 | TRAIN: 19 | DATASET: imagenet 20 | IM_SIZE: 224 21 | BATCH_SIZE: 512 22 | TEST: 23 | DATASET: imagenet 24 | IM_SIZE: 256 25 | BATCH_SIZE: 400 26 | NUM_GPUS: 8 27 | OUT_DIR: . 
28 | -------------------------------------------------------------------------------- /fastreid/modeling/backbones/regnet/regnety/RegNetY-32GF_dds_8gpu.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: regnet 3 | NUM_CLASSES: 1000 4 | REGNET: 5 | SE_ON: True 6 | DEPTH: 20 7 | W0: 232 8 | WA: 115.89 9 | WM: 2.53 10 | GROUP_W: 232 11 | OPTIM: 12 | LR_POLICY: cos 13 | BASE_LR: 0.2 14 | MAX_EPOCH: 100 15 | MOMENTUM: 0.9 16 | WEIGHT_DECAY: 5e-5 17 | WARMUP_ITERS: 5 18 | TRAIN: 19 | DATASET: imagenet 20 | IM_SIZE: 224 21 | BATCH_SIZE: 256 22 | TEST: 23 | DATASET: imagenet 24 | IM_SIZE: 256 25 | BATCH_SIZE: 200 26 | NUM_GPUS: 8 27 | OUT_DIR: . 28 | -------------------------------------------------------------------------------- /fastreid/modeling/backbones/regnet/regnety/RegNetY-4.0GF_dds_8gpu.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: regnet 3 | NUM_CLASSES: 1000 4 | REGNET: 5 | SE_ON: True 6 | DEPTH: 22 7 | W0: 96 8 | WA: 31.41 9 | WM: 2.24 10 | GROUP_W: 64 11 | OPTIM: 12 | LR_POLICY: cos 13 | BASE_LR: 0.4 14 | MAX_EPOCH: 100 15 | MOMENTUM: 0.9 16 | WEIGHT_DECAY: 5e-5 17 | WARMUP_ITERS: 5 18 | TRAIN: 19 | DATASET: imagenet 20 | IM_SIZE: 224 21 | BATCH_SIZE: 512 22 | TEST: 23 | DATASET: imagenet 24 | IM_SIZE: 256 25 | BATCH_SIZE: 400 26 | NUM_GPUS: 8 27 | OUT_DIR: . 
28 | -------------------------------------------------------------------------------- /fastreid/modeling/backbones/regnet/regnety/RegNetY-400MF_dds_8gpu.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: regnet 3 | NUM_CLASSES: 1000 4 | REGNET: 5 | SE_ON: True 6 | DEPTH: 16 7 | W0: 48 8 | WA: 27.89 9 | WM: 2.09 10 | GROUP_W: 8 11 | OPTIM: 12 | LR_POLICY: cos 13 | BASE_LR: 0.8 14 | MAX_EPOCH: 100 15 | MOMENTUM: 0.9 16 | WEIGHT_DECAY: 5e-5 17 | WARMUP_ITERS: 5 18 | TRAIN: 19 | DATASET: imagenet 20 | IM_SIZE: 224 21 | BATCH_SIZE: 1024 22 | TEST: 23 | DATASET: imagenet 24 | IM_SIZE: 256 25 | BATCH_SIZE: 800 26 | NUM_GPUS: 8 27 | OUT_DIR: . 28 | -------------------------------------------------------------------------------- /fastreid/modeling/backbones/regnet/regnety/RegNetY-6.4GF_dds_8gpu.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: regnet 3 | NUM_CLASSES: 1000 4 | REGNET: 5 | SE_ON: True 6 | DEPTH: 25 7 | W0: 112 8 | WA: 33.22 9 | WM: 2.27 10 | GROUP_W: 72 11 | OPTIM: 12 | LR_POLICY: cos 13 | BASE_LR: 0.4 14 | MAX_EPOCH: 100 15 | MOMENTUM: 0.9 16 | WEIGHT_DECAY: 5e-5 17 | WARMUP_ITERS: 5 18 | TRAIN: 19 | DATASET: imagenet 20 | IM_SIZE: 224 21 | BATCH_SIZE: 512 22 | TEST: 23 | DATASET: imagenet 24 | IM_SIZE: 256 25 | BATCH_SIZE: 400 26 | NUM_GPUS: 8 27 | OUT_DIR: . 
28 | -------------------------------------------------------------------------------- /fastreid/modeling/backbones/regnet/regnety/RegNetY-600MF_dds_8gpu.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: regnet 3 | NUM_CLASSES: 1000 4 | REGNET: 5 | SE_ON: True 6 | DEPTH: 15 7 | W0: 48 8 | WA: 32.54 9 | WM: 2.32 10 | GROUP_W: 16 11 | OPTIM: 12 | LR_POLICY: cos 13 | BASE_LR: 0.8 14 | MAX_EPOCH: 100 15 | MOMENTUM: 0.9 16 | WEIGHT_DECAY: 5e-5 17 | WARMUP_ITERS: 5 18 | TRAIN: 19 | DATASET: imagenet 20 | IM_SIZE: 224 21 | BATCH_SIZE: 1024 22 | TEST: 23 | DATASET: imagenet 24 | IM_SIZE: 256 25 | BATCH_SIZE: 800 26 | NUM_GPUS: 8 27 | OUT_DIR: . 28 | -------------------------------------------------------------------------------- /fastreid/modeling/backbones/regnet/regnety/RegNetY-8.0GF_dds_8gpu.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: regnet 3 | NUM_CLASSES: 1000 4 | REGNET: 5 | SE_ON: true 6 | DEPTH: 17 7 | W0: 192 8 | WA: 76.82 9 | WM: 2.19 10 | GROUP_W: 56 11 | OPTIM: 12 | LR_POLICY: cos 13 | BASE_LR: 0.4 14 | MAX_EPOCH: 100 15 | MOMENTUM: 0.9 16 | WEIGHT_DECAY: 5e-5 17 | WARMUP_ITERS: 5 18 | TRAIN: 19 | DATASET: imagenet 20 | IM_SIZE: 224 21 | BATCH_SIZE: 512 22 | TEST: 23 | DATASET: imagenet 24 | IM_SIZE: 256 25 | BATCH_SIZE: 400 26 | NUM_GPUS: 8 27 | OUT_DIR: . 
28 | -------------------------------------------------------------------------------- /fastreid/modeling/backbones/regnet/regnety/RegNetY-800MF_dds_8gpu.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: regnet 3 | NUM_CLASSES: 1000 4 | REGNET: 5 | SE_ON: True 6 | DEPTH: 14 7 | W0: 56 8 | WA: 38.84 9 | WM: 2.4 10 | GROUP_W: 16 11 | OPTIM: 12 | LR_POLICY: cos 13 | BASE_LR: 0.8 14 | MAX_EPOCH: 100 15 | MOMENTUM: 0.9 16 | WEIGHT_DECAY: 5e-5 17 | WARMUP_ITERS: 5 18 | TRAIN: 19 | DATASET: imagenet 20 | IM_SIZE: 224 21 | BATCH_SIZE: 1024 22 | TEST: 23 | DATASET: imagenet 24 | IM_SIZE: 256 25 | BATCH_SIZE: 800 26 | NUM_GPUS: 8 27 | OUT_DIR: . 28 | -------------------------------------------------------------------------------- /fastreid/modeling/heads/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | from .build import REID_HEADS_REGISTRY, build_heads 8 | 9 | # import all the meta_arch, so they will be registered 10 | from .embedding_head import EmbeddingHead 11 | from .meta_embedding_head import MetaEmbeddingHead 12 | from .clas_head import ClasHead 13 | -------------------------------------------------------------------------------- /fastreid/modeling/heads/build.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | from ...utils.registry import Registry 8 | 9 | REID_HEADS_REGISTRY = Registry("HEADS") 10 | REID_HEADS_REGISTRY.__doc__ = """ 11 | Registry for reid heads in a baseline model. 12 | 13 | ROIHeads take feature maps and region proposals, and 14 | perform per-region computation. 15 | The registered object will be called with `obj(cfg, input_shape)`. 16 | The call is expected to return an :class:`ROIHeads`. 
@REID_HEADS_REGISTRY.register()
class ClasHead(EmbeddingHead):
    """Classification head built on the layers of :class:`EmbeddingHead`."""

    def forward(self, features, targets=None):
        """
        Pool and project `features`, then compute class logits.

        In evaluation mode the scaled logits are returned directly. In
        training mode a dict is returned with the output of `cls_layer`
        ("cls_outputs"), the scaled logits ("pred_class_logits") and the
        flattened bottleneck features ("features").
        """
        embedding = self.bottleneck(self.pool_layer(features))
        embedding = embedding.view(embedding.size(0), -1)

        # A cls_layer whose class is named 'Linear' gets raw dot products;
        # any other layer gets L2-normalised features and weights.
        if self.cls_layer.__class__.__name__ == 'Linear':
            lhs, rhs = embedding, self.weight
        else:
            lhs, rhs = F.normalize(embedding), F.normalize(self.weight)
        logits = F.linear(lhs, rhs)

        if self.training:
            # cls_layer receives a copy because `logits` is scaled in place below.
            cls_outputs = self.cls_layer(logits.clone(), targets)
            return {
                "cls_outputs": cls_outputs,
                "pred_class_logits": logits.mul_(self.cls_layer.s),
                "features": embedding,
            }

        # Evaluation
        return logits.mul_(self.cls_layer.s)
import SVMORegularizer 13 | from .domain_SCT_loss import domain_SCT_loss 14 | from .triplet_loss_MetaIBN import triplet_loss_Meta 15 | 16 | __all__ = [k for k in globals().keys() if not k.startswith("_")] -------------------------------------------------------------------------------- /fastreid/modeling/losses/center_loss.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | from torch import nn 5 | 6 | 7 | # class CenterLoss(nn.Module): 8 | # """Center loss. 9 | 10 | # Reference: 11 | # Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016. 12 | 13 | # Args: 14 | # num_classes (int): number of classes. 15 | # feat_dim (int): feature dimension. 16 | # """ 17 | 18 | # def __init__(self, num_classes=751, feat_dim=2048, use_gpu=True): 19 | # super(CenterLoss, self).__init__() 20 | # self.num_classes = num_classes 21 | # self.feat_dim = feat_dim 22 | # self.use_gpu = use_gpu 23 | 24 | # if self.use_gpu: 25 | # self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim).cuda()) 26 | # else: 27 | # self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim)) 28 | 29 | # def forward(self, x, labels): 30 | # """ 31 | # Args: 32 | # x: feature matrix with shape (batch_size, feat_dim). 33 | # labels: ground truth labels with shape (num_classes). 
def centerLoss(distmat, labels):
    """
    Center loss computed from a precomputed sample-to-center distance matrix.

    Args:
        distmat: tensor with shape (batch_size, num_classes). Entry
            ``distmat[i, c]`` is taken as the distance between sample ``i``
            and the center of class ``c`` (this mirrors the commented-out
            ``CenterLoss`` module earlier in this file, which built the
            matrix from features and learnable centers).
        labels: ground-truth class indices with shape (batch_size,).

    Returns:
        Scalar tensor: the sum over the batch of each sample's distance to
        its own class center, clamped to [1e-12, 1e+12], divided by
        ``batch_size``.
    """
    batch_size = distmat.size(0)
    num_classes = distmat.size(1)

    # Boolean one-hot mask selecting, for every row, the column of its label.
    classes = torch.arange(num_classes, device=distmat.device)
    mask = labels.unsqueeze(1) == classes.unsqueeze(0)

    # Select the own-center distances *before* clamping: clamping the masked
    # matrix would add the 1e-12 floor once for every non-target class.
    own_center_dist = distmat[mask]
    loss = own_center_dist.clamp(min=1e-12, max=1e+12).sum() / batch_size

    return loss
torch.clamp_min(-s_p.detach() + 1 + margin, min=0.) 34 | alpha_n = torch.clamp_min(s_n.detach() + margin, min=0.) 35 | delta_p = 1 - margin 36 | delta_n = margin 37 | 38 | logit_p = - gamma * alpha_p * (s_p - delta_p) + (-99999999.) * (1 - is_pos) 39 | logit_n = gamma * alpha_n * (s_n - delta_n) + (-99999999.) * (1 - is_neg) 40 | 41 | loss = F.softplus(torch.logsumexp(logit_p, dim=1) + torch.logsumexp(logit_n, dim=1)).mean() 42 | 43 | return loss 44 | 45 | 46 | def pairwise_cosface( 47 | embedding: torch.Tensor, 48 | targets: torch.Tensor, 49 | margin: float, 50 | gamma: float, ) -> torch.Tensor: 51 | # Normalize embedding features 52 | embedding = F.normalize(embedding, dim=1) 53 | 54 | dist_mat = torch.matmul(embedding, embedding.t()) 55 | 56 | N = dist_mat.size(0) 57 | is_pos = targets.view(N, 1).expand(N, N).eq(targets.view(N, 1).expand(N, N).t()).float() 58 | is_neg = targets.view(N, 1).expand(N, N).ne(targets.view(N, 1).expand(N, N).t()).float() 59 | 60 | # Mask scores related to itself 61 | is_pos = is_pos - torch.eye(N, N, device=is_pos.device) 62 | 63 | s_p = dist_mat * is_pos 64 | s_n = dist_mat * is_neg 65 | 66 | logit_p = -gamma * s_p + (-99999999.) * (1 - is_pos) 67 | logit_n = gamma * (s_n + margin) + (-99999999.) 
def intraCluster(feature, domain_ids, margin=0.1, num_domains=3):
    """
    Penalise samples that lie farther than ``margin`` from their domain centroid.

    Args:
        feature: tensor with shape (batch_size, feat_dim); rows are
            L2-normalised inside this function.
        domain_ids: tensor with shape (batch_size,) holding a domain index
            per sample.
        margin: distance to the centroid that is tolerated without penalty.
        num_domains: domain indices ``0 .. num_domains - 1`` are considered
            (default 3, the value that used to be hard-coded).

    Returns:
        Scalar tensor: mean over the non-empty domains of the mean squared
        hinge ``relu(||f - centroid|| - margin) ** 2``. A zero that is still
        attached to ``feature`` is returned when no sample belongs to any of
        the considered domains.
    """
    feature = F.normalize(feature, 2, 1)
    per_domain = []
    for i in range(num_domains):
        members = feature[domain_ids == i]
        if members.size(0) == 0:
            continue
        # Per-dimension mean (dim 0): a plain ``.mean()`` would collapse the
        # centroid to a single scalar over all elements.
        centroid = members.mean(0, keepdim=True)
        spread = torch.norm(members - centroid, 2, 1)
        per_domain.append((F.relu(spread - margin) ** 2).mean())

    if not per_domain:
        # Nothing to penalise; avoid dividing by a zero domain count.
        return feature.sum() * 0.0
    return torch.stack(per_domain).mean()


def interCluster(feature, domain_ids, margin=0.3, num_domains=3):
    """
    Penalise pairs of domain centroids that are closer than ``margin``.

    Args:
        feature: tensor with shape (batch_size, feat_dim); rows are
            L2-normalised inside this function.
        domain_ids: tensor with shape (batch_size,) holding a domain index
            per sample.
        margin: minimum centroid-to-centroid distance that goes unpenalised.
        num_domains: domain indices ``0 .. num_domains - 1`` are considered
            (default 3, the value that used to be hard-coded).

    Returns:
        Scalar tensor: mean over all pairs of non-empty domains of
        ``relu(margin - ||c_i - c_j||) ** 2``. A zero that is still attached
        to ``feature`` is returned when fewer than two domains are present.
    """
    feature = F.normalize(feature, 2, 1)
    centroids = []
    for i in range(num_domains):
        members = feature[domain_ids == i]
        if members.size(0) == 0:
            continue
        # Per-dimension mean so that each centroid is a feat_dim vector.
        centroids.append(members.mean(0))

    pair_losses = []
    for i in range(len(centroids)):
        for j in range(i + 1, len(centroids)):
            gap = torch.norm(centroids[i] - centroids[j], 2, 0)
            pair_losses.append(F.relu(margin - gap) ** 2)

    if not pair_losses:
        # No pair to compare: contribute nothing instead of leaking the
        # mean of the domain ids into the loss value.
        return feature.sum() * 0.0
    return torch.stack(pair_losses).mean()
def cross_entropy_loss(pred_class_outputs, gt_classes, eps, alpha=0.2):
    """
    Cross entropy with fixed or adaptive label smoothing.

    Args:
        pred_class_outputs: logits with shape (batch_size, num_classes).
        gt_classes: ground-truth class indices with shape (batch_size,).
        eps: smoothing amount when ``eps >= 0``; otherwise the smoothing is
            adaptive and equals ``alpha`` times the softmax probability of
            the ground-truth class of each sample.
        alpha: scale of the adaptive smoothing (only used when ``eps < 0``).

    Returns:
        Scalar tensor: the summed per-sample loss divided by the number of
        samples whose loss is non-zero (at least 1).
    """
    class_count = pred_class_outputs.size(1)
    log_probs = F.log_softmax(pred_class_outputs, dim=1)

    if eps >= 0:
        smooth_param = eps
    else:
        # Adaptive label smooth regularization
        probs = F.softmax(pred_class_outputs, dim=1)
        row_index = torch.arange(probs.size(0))
        smooth_param = alpha * probs[row_index, gt_classes].unsqueeze(1)

    # The soft targets are constants: no gradient flows through them.
    with torch.no_grad():
        off_target = smooth_param / (class_count - 1)
        targets = torch.ones_like(log_probs)
        targets *= off_target
        targets.scatter_(1, gt_classes.data.unsqueeze(1), (1 - smooth_param))

    per_sample = torch.sum(-targets * log_probs, dim=1)

    with torch.no_grad():
        contributing = per_sample.nonzero(as_tuple=False).size(0)
        denominator = max(contributing, 1)

    return per_sample.sum() / denominator
def focal_loss(
        input: torch.Tensor,
        target: torch.Tensor,
        alpha: float,
        gamma: float = 2.0,
        reduction: str = 'mean') -> torch.Tensor:
    r"""Compute the focal loss between class logits and integer targets.

    According to [1], the focal loss is computed as follows:

    .. math::
        \text{FL}(p_t) = -\alpha (1 - p_t)^{\gamma} \, \text{log}(p_t)

    where :math:`p_t` is the softmax probability assigned to the
    ground-truth class.

    Arguments:
        input (torch.Tensor): unnormalized class scores (logits).
        target (torch.Tensor): ground-truth class indices (``torch.long``).
        alpha (float): Weighting factor :math:`\alpha \in [0, 1]`.
        gamma (float): Focusing parameter :math:`\gamma >= 0`.
        reduction (str, optional): Specifies the reduction to apply to the
            output: 'none' | 'mean' | 'sum'. 'none': no reduction will be
            applied, 'mean': the sum of the output will be divided by the
            number of elements in the output, 'sum': the output will be
            summed. Default: 'mean'.

    Shape:
        - Input: :math:`(N, C, *)` where C = number of classes.
        - Target: :math:`(N, *)` where each value is
          :math:`0 <= targets[i] <= C - 1`.
        - Output: :math:`(N, *)` for ``reduction='none'``, otherwise a scalar.

    Raises:
        TypeError: if ``input`` is not a tensor.
        ValueError: if the shapes or devices of ``input`` and ``target``
            do not match.
        NotImplementedError: if ``reduction`` is not a supported mode.

    Examples:
        >>> N = 5  # num_classes
        >>> input = torch.randn(1, N, 3, 5, requires_grad=True)
        >>> target = torch.empty(1, 3, 5, dtype=torch.long).random_(N)
        >>> output = focal_loss(input, target, alpha=0.25)
        >>> output.backward()

    References:
        [1] https://arxiv.org/abs/1708.02002
    """
    if not torch.is_tensor(input):
        raise TypeError("Input type is not a torch.Tensor. Got {}"
                        .format(type(input)))

    if not len(input.shape) >= 2:
        raise ValueError("Invalid input shape, we expect BxCx*. Got: {}"
                         .format(input.shape))

    if input.size(0) != target.size(0):
        raise ValueError('Expected input batch_size ({}) to match target batch_size ({}).'
                         .format(input.size(0), target.size(0)))

    n = input.size(0)
    out_size = (n,) + input.size()[2:]
    if target.size()[1:] != input.size()[2:]:
        raise ValueError('Expected target size {}, got {}'.format(
            out_size, target.size()))

    if not input.device == target.device:
        # Two placeholders: the message previously dropped the target device.
        raise ValueError(
            "input and target must be in the same device. Got: {} and {}".format(
                input.device, target.device))

    # log-softmax over the class axis; taking log(softmax(x)) instead turns
    # an underflowed probability into -inf and the loss into inf/NaN.
    log_probs = F.log_softmax(input, dim=1)

    # Pick the ground-truth class along dim 1. Unlike a one-hot product this
    # stays aligned with the class axis for (N, C, *) inputs as well.
    log_pt = log_probs.gather(1, target.unsqueeze(1)).squeeze(1)
    pt = log_pt.exp()

    # compute the actual focal loss
    weight = torch.pow(-pt + 1., gamma)
    loss_tmp = -alpha * weight * log_pt

    if reduction == 'none':
        loss = loss_tmp
    elif reduction == 'mean':
        loss = torch.mean(loss_tmp)
    elif reduction == 'sum':
        loss = torch.sum(loss_tmp)
    else:
        raise NotImplementedError("Invalid reduction mode: {}"
                                  .format(reduction))
    return loss
def softmax_weights(dist, mask):
    """
    Masked softmax over each row of ``dist``.

    Args:
        dist: tensor with shape [N, M].
        mask: tensor with shape [N, M]; entries with mask 0 get weight 0.
    Returns:
        Tensor with shape [N, M] of weights. The row maximum of
        ``dist * mask`` is subtracted before exponentiation and 1e-6 is
        added to the normaliser to avoid division by zero.
    """
    row_max = (dist * mask).max(dim=1, keepdim=True)[0]
    shifted = dist - row_max
    normaliser = torch.sum(torch.exp(shifted) * mask, dim=1, keepdim=True) + 1e-6  # avoid division by zero
    return torch.exp(shifted) * mask / normaliser


def hard_example_mining(dist_mat, is_pos, is_neg):
    """For each anchor, find the hardest positive and negative sample.
    Args:
      dist_mat: pair wise distance between samples, shape [N, M]
      is_pos: positive mask with shape [N, M]
      is_neg: negative mask with shape [N, M]
    Returns:
      dist_ap: distance(anchor, hardest positive); shape [N]
      dist_an: distance(anchor, hardest negative); shape [N]
    """
    assert dist_mat.dim() == 2

    # Hardest positive: the largest distance among the positives of the row.
    positives_only = dist_mat * is_pos
    dist_ap = positives_only.max(dim=1)[0]

    # Hardest negative: the smallest distance among the negatives; positives
    # are pushed out of the way with a 1e9 offset.
    negatives_only = dist_mat * is_neg + is_pos * 1e9
    dist_an = negatives_only.min(dim=1)[0]

    return dist_ap, dist_an


def weighted_example_mining(dist_mat, is_pos, is_neg):
    """For each anchor, compute softmax-weighted positive and negative distances.
    Args:
      dist_mat: pair wise distance between samples, shape [N, N]
      is_pos: positive mask, same shape as `dist_mat`
      is_neg: negative mask, same shape as `dist_mat`
    Returns:
      dist_ap: weighted distance(anchor, positive); shape [N]
      dist_an: weighted distance(anchor, negative); shape [N]
    """
    assert dist_mat.dim() == 2

    pos_dist = dist_mat * is_pos
    neg_dist = dist_mat * is_neg

    # Far positives and near negatives receive the larger weights.
    pos_weights = softmax_weights(pos_dist, is_pos)
    neg_weights = softmax_weights(-neg_dist, is_neg)

    dist_ap = (pos_dist * pos_weights).sum(dim=1)
    dist_an = (neg_dist * neg_weights).sum(dim=1)

    return dist_ap, dist_an
def concat_all_gather(tensor):
    """
    Gather ``tensor`` from every process and concatenate along dim 0.
    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    world_size = torch.distributed.get_world_size()
    # One receive buffer per rank, shaped like the local tensor.
    gathered = [torch.ones_like(tensor) for _ in range(world_size)]
    torch.distributed.all_gather(gathered, tensor, async_op=False)

    return torch.cat(gathered, dim=0)
def euclidean_dist(x, y):
    """
    Pairwise Euclidean distance between the rows of ``x`` and ``y``.

    Args:
        x: tensor with shape (m, d).
        y: tensor with shape (n, d).
    Returns:
        Tensor with shape (m, n). Squared distances are clamped to at
        least 1e-12 before the square root, so the smallest value is 1e-6.
    """
    m, n = x.size(0), y.size(0)
    xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n)
    yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t()
    dist = xx + yy - 2 * torch.matmul(x, y.t())
    dist = dist.clamp(min=1e-12).sqrt()  # for numerical stability
    return dist


def cosine_dist(x, y):
    """
    Pairwise ``2 - 2 * cos`` between the rows of ``x`` and ``y``.

    Args:
        x: tensor with shape (m, d).
        y: tensor with shape (n, d).
    Returns:
        Tensor with shape (m, n); rows are L2-normalised first.
    """
    x = F.normalize(x, dim=1)
    y = F.normalize(y, dim=1)
    dist = 2 - 2 * torch.mm(x, y.t())
    return dist


def cosine_sim(x, y):
    """
    Pairwise cosine similarity between the rows of ``x`` and ``y``.

    Args:
        x: tensor with shape (bs1, d).
        y: tensor with shape (bs2, d).
    Returns:
        Tensor with shape (bs1, bs2). The product of norms is clamped to at
        least 1e-12, so an all-zero row gives similarity 0 instead of NaN.
    """
    frac_up = torch.matmul(x, y.transpose(0, 1))
    x_norm = torch.sqrt(torch.sum(torch.pow(x, 2), 1)).view(-1, 1)
    y_norm = torch.sqrt(torch.sum(torch.pow(y, 2), 1)).view(1, -1)
    # Broadcasting (bs1, 1) * (1, bs2) replaces the two explicit `repeat`
    # copies; the clamp guards against division by zero.
    frac_down = (x_norm * y_norm).clamp(min=1e-12)
    cosine = frac_up / frac_down
    return cosine
def build_model(cfg):
    """
    Build the whole model architecture, defined by ``cfg.MODEL.META_ARCHITECTURE``.

    The architecture is looked up in ``META_ARCH_REGISTRY``, called with
    ``cfg`` and moved to the device named by ``cfg.MODEL.DEVICE``.
    Note that it does not load any weights from ``cfg``.
    """
    arch_name = cfg.MODEL.META_ARCHITECTURE
    arch_factory = META_ARCH_REGISTRY.get(arch_name)
    model = arch_factory(cfg)

    target_device = torch.device(cfg.MODEL.DEVICE)
    model.to(target_device)
    return model
def _get_warmup_factor_at_epoch(
    method: str, iter: int, warmup_iters: int, warmup_factor: float
) -> float:
    """
    Return the learning rate warmup factor at a specific iteration.
    See https://arxiv.org/abs/1706.02677 for more details.
    Args:
        method (str): warmup method; one of "constant", "linear" or "exp".
        iter (int): iter at which to calculate the warmup factor.
        warmup_iters (int): the length of the warmup; from ``iter >=
            warmup_iters`` on the factor is 1.0 whatever the method.
        warmup_factor (float): the base warmup factor (the meaning changes
            according to the method used).
    Returns:
        float: the effective warmup factor at the given iteration.
    Raises:
        ValueError: if warmup is still active and ``method`` is unknown.
    """
    # Warmup is over: the method is not even inspected.
    if iter >= warmup_iters:
        return 1.0

    if method == "constant":
        return warmup_factor

    if method not in ("linear", "exp"):
        raise ValueError("Unknown warmup method: {}".format(method))

    if method == "linear":
        # Interpolate from warmup_factor (iter 0) towards 1.
        alpha = iter / warmup_iters
        return warmup_factor * (1 - alpha) + alpha

    # "exp": geometric interpolation from warmup_factor towards 1.
    return warmup_factor ** (1 - iter / warmup_iters)
def seed_all_rng(seed=None):
    """
    Set the random seed for the RNG in torch, numpy and python.
    Args:
        seed (int): if None, a seed is generated from the process id, the
            current seconds/microseconds and two bytes of OS randomness,
            and logged.
    """
    if seed is None:
        pid_part = os.getpid()
        time_part = int(datetime.now().strftime("%S%f"))
        urandom_part = int.from_bytes(os.urandom(2), "big")
        seed = pid_part + time_part + urandom_part
        logging.getLogger(__name__).info(
            "Using a generated random seed {}".format(seed)
        )

    np.random.seed(seed)
    generator = torch.manual_seed(seed)
    torch.set_rng_state(generator.get_state())
    random.seed(seed)


# from https://stackoverflow.com/questions/67631/how-to-import-a-module-given-the-full-path
def _import_file(module_name, file_path, make_importable=False):
    """
    Load the Python source file at ``file_path`` as module ``module_name``.

    The module is executed and returned; it is additionally registered in
    ``sys.modules`` under ``module_name`` when ``make_importable`` is true.
    """
    module_spec = importlib.util.spec_from_file_location(module_name, file_path)
    loaded = importlib.util.module_from_spec(module_spec)
    module_spec.loader.exec_module(loaded)
    if make_importable:
        sys.modules[module_name] = loaded
    return loaded
def setup_environment():
    """Perform environment setup work, at most once per process.

    Library configuration always runs on the first call. In addition, if the
    $FASTREID_ENV_MODULE environment variable names a Python source file or
    a module, that custom setup is loaded and run; otherwise nothing more
    is done.
    """
    global _ENV_SETUP_DONE
    if _ENV_SETUP_DONE:
        return
    # Flag first, so the work below is never started twice.
    _ENV_SETUP_DONE = True

    _configure_libraries()

    custom_module_path = os.environ.get("FASTREID_ENV_MODULE")
    if not custom_module_path:
        # The default setup is a no-op
        return
    setup_custom_environment(custom_module_path)
def swig_ptr_from_FloatTensor(x):
    """Return a faiss float pointer to the first element of the float32 tensor ``x``.

    ``x`` must be contiguous. The address is the storage base pointer plus
    the storage offset times 4 bytes per element.
    """
    assert x.is_contiguous()
    assert x.dtype == torch.float32
    base_address = x.storage().data_ptr()
    byte_offset = x.storage_offset() * 4
    return faiss.cast_integer_to_float_ptr(base_address + byte_offset)


def swig_ptr_from_LongTensor(x):
    """Return a faiss long pointer to the first element of the int64 tensor ``x``.

    ``x`` must be contiguous. The address is the storage base pointer plus
    the storage offset times 8 bytes per element.
    """
    assert x.is_contiguous()
    assert x.dtype == torch.int64, "dtype=%s" % x.dtype
    base_address = x.storage().data_ptr()
    byte_offset = x.storage_offset() * 8
    return faiss.cast_integer_to_long_ptr(base_address + byte_offset)
def search_raw_array_pytorch(res, xb, xq, k, D=None, I=None, metric=faiss.METRIC_L2):
    """Run ``faiss.bruteForceKnn`` directly on torch tensors.

    Args:
        res: forwarded unchanged as the first argument of
            ``faiss.bruteForceKnn`` (presumably a faiss GPU resources
            object -- confirm against the caller).
        xb: 2-D tensor with shape (nb, d); must be row- or column-major.
        xq: 2-D tensor with shape (nq, d) on the same device as ``xb``;
            must be row- or column-major.
        k: number of results written per query row.
        D: optional preallocated float32 output with shape (nq, k) on
            ``xb``'s device; allocated here when None.
        I: optional preallocated int64 output with shape (nq, k) on
            ``xb``'s device; allocated here when None.
        metric: forwarded to ``faiss.bruteForceKnn``.

    Returns:
        The pair ``(D, I)`` after faiss has filled them (presumably
        distances and neighbor indices -- confirm with the faiss docs).

    Raises:
        TypeError: if ``xb`` or ``xq`` is neither contiguous nor the
            transpose of a contiguous matrix.
    """
    assert xb.device == xq.device

    # nq and d are read before the possible transpose below, so they keep
    # describing the logical (nq, d) layout of the queries.
    nq, d = xq.size()
    if xq.is_contiguous():
        xq_row_major = True
    elif xq.t().is_contiguous():
        xq = xq.t()  # I initially wrote xq:t(), Lua is still haunting me :-)
        xq_row_major = False
    else:
        raise TypeError("matrix should be row or column-major")

    xq_ptr = swig_ptr_from_FloatTensor(xq)

    # Same layout detection for the database vectors.
    nb, d2 = xb.size()
    assert d2 == d
    if xb.is_contiguous():
        xb_row_major = True
    elif xb.t().is_contiguous():
        xb = xb.t()
        xb_row_major = False
    else:
        raise TypeError("matrix should be row or column-major")
    xb_ptr = swig_ptr_from_FloatTensor(xb)

    # Allocate the float32 output unless the caller supplied a buffer.
    if D is None:
        D = torch.empty(nq, k, device=xb.device, dtype=torch.float32)
    else:
        assert D.shape == (nq, k)
        assert D.device == xb.device

    # Allocate the int64 output unless the caller supplied a buffer.
    if I is None:
        I = torch.empty(nq, k, device=xb.device, dtype=torch.int64)
    else:
        assert I.shape == (nq, k)
        assert I.device == xb.device

    D_ptr = swig_ptr_from_FloatTensor(D)
    I_ptr = swig_ptr_from_LongTensor(I)

    # faiss writes its results through D_ptr / I_ptr, i.e. into D and I.
    faiss.bruteForceKnn(
        res,
        metric,
        xb_ptr,
        xb_row_major,
        nb,
        xq_ptr,
        xq_row_major,
        nq,
        d,
        k,
        D_ptr,
        I_ptr,
    )

    return D, I
class HistoryBuffer:
    """
    Bounded record of scalar values, each tagged with an iteration, that can
    report windowed statistics as well as the running mean of every value
    ever added.
    """

    def __init__(self, max_length: int = 1000000):
        """
        Args:
            max_length: capacity of the buffer. Once it is full, adding a
                value discards the oldest stored one.
        """
        self._max_length: int = max_length
        self._data: List[Tuple[float, float]] = []  # (value, iteration) pairs
        self._count: int = 0
        self._global_avg: float = 0

    def update(self, value: float, iteration: float = None):
        """
        Record `value` as produced at `iteration`. When `iteration` is not
        given, the number of values added so far is used instead. The oldest
        entry is evicted first if the buffer is already at capacity.
        """
        stamp = self._count if iteration is None else iteration
        if len(self._data) == self._max_length:
            self._data.pop(0)
        self._data.append((value, stamp))

        # Incremental mean over all values seen, evicted ones included.
        self._count += 1
        self._global_avg += (value - self._global_avg) / self._count

    def _recent_values(self, window_size: int):
        # Values (without iterations) of the last `window_size` entries.
        return [entry[0] for entry in self._data[-window_size:]]

    def latest(self):
        """
        Return the value that was added most recently.
        """
        newest = self._data[-1]
        return newest[0]

    def median(self, window_size: int):
        """
        Return the median over the last `window_size` stored values.
        """
        return np.median(self._recent_values(window_size))

    def avg(self, window_size: int):
        """
        Return the mean over the last `window_size` stored values.
        """
        return np.mean(self._recent_values(window_size))

    def global_avg(self):
        """
        Return the mean of every value ever added, including values that
        were already evicted because of the limited capacity.
        """
        return self._global_avg

    def values(self):
        """
        Returns:
            list[(number, iteration)]: the entries currently stored.
        """
        return self._data
33 | Users are responsible for setting the layers that needs 34 | precise-BN to training mode, prior to calling this function. 35 | 2. Be careful if your models contain other stateful layers in 36 | addition to BN, i.e. layers whose state can change in forward 37 | iterations. This function will alter their state. If you wish 38 | them unchanged, you need to either pass in a submodule without 39 | those layers, or backup the states. 40 | data_loader (iterator): an iterator. Produce data as inputs to the model. 41 | num_iters (int): number of iterations to compute the stats. 42 | """ 43 | bn_layers = get_bn_modules(model) 44 | if len(bn_layers) == 0: 45 | return 46 | 47 | # In order to make the running stats only reflect the current batch, the 48 | # momentum is disabled. 49 | # bn.running_mean = (1 - momentum) * bn.running_mean + momentum * batch_mean 50 | # Setting the momentum to 1.0 to compute the stats without momentum. 51 | momentum_actual = [bn.momentum for bn in bn_layers] 52 | for bn in bn_layers: 53 | bn.momentum = 1.0 54 | 55 | # Note that running_var actually means "running average of variance" 56 | running_mean = [torch.zeros_like(bn.running_mean) for bn in bn_layers] 57 | running_var = [torch.zeros_like(bn.running_var) for bn in bn_layers] 58 | 59 | for ind, inputs in enumerate(itertools.islice(data_loader, num_iters)): 60 | inputs['targets'].fill_(-1) 61 | with torch.no_grad(): # No need to backward 62 | model(inputs) 63 | for i, bn in enumerate(bn_layers): 64 | # Accumulates the bn stats. 65 | running_mean[i] += (bn.running_mean - running_mean[i]) / (ind + 1) 66 | running_var[i] += (bn.running_var - running_var[i]) / (ind + 1) 67 | # We compute the "average of variance" across iterations. 68 | assert ind == num_iters - 1, ( 69 | "update_bn_stats is meant to run for {} iterations, " 70 | "but the dataloader stops at {} iterations.".format(num_iters, ind) 71 | ) 72 | 73 | for i, bn in enumerate(bn_layers): 74 | # Sets the precise bn stats. 
class Registry(object):
    """
    The registry that provides name -> object mapping, to support third-party
    users' custom modules.

    To create a registry (e.g. a backbone registry):

    .. code-block:: python

        BACKBONE_REGISTRY = Registry('BACKBONE')

    To register an object:

    .. code-block:: python

        @BACKBONE_REGISTRY.register()
        class MyBackbone():
            ...

    Or:

    .. code-block:: python

        BACKBONE_REGISTRY.register(MyBackbone)
    """

    def __init__(self, name: str) -> None:
        """
        Args:
            name (str): the name of this registry
        """
        self._name: str = name
        self._obj_map: Dict[str, object] = {}

    def _do_register(self, name: str, obj: object) -> None:
        """
        Store `obj` under `name`.

        Raises:
            AssertionError: if `name` is already taken in this registry.
        """
        assert (
            name not in self._obj_map
        ), "An object named '{}' was already registered in '{}' registry!".format(
            name, self._name
        )
        self._obj_map[name] = obj

    def register(self, obj: object = None) -> Optional[object]:
        """
        Register the given object under the name `obj.__name__`.
        Can be used as either a decorator or not. See docstring of this class for usage.

        Args:
            obj: the function or class to register, or None when the method
                is used as a decorator factory (``@REGISTRY.register()``).

        Returns:
            The decorator when `obj` is None, otherwise `obj` itself.
        """
        if obj is None:
            # used as a decorator
            def deco(func_or_class: object) -> object:
                name = func_or_class.__name__  # pyre-ignore
                self._do_register(name, func_or_class)
                return func_or_class

            return deco

        # used as a function call
        name = obj.__name__  # pyre-ignore
        self._do_register(name, obj)
        # Hand the object back: with an implicit None, `@REGISTRY.register`
        # (no parentheses) would rebind the decorated name to None.
        return obj

    def get(self, name: str) -> object:
        """
        Look up a registered object.

        Raises:
            KeyError: if nothing is registered under `name`.
        """
        ret = self._obj_map.get(name)
        if ret is None:
            raise KeyError(
                "No object named '{}' found in '{}' registry!".format(
                    name, self._name
                )
            )
        return ret
19 | """ 20 | self._start = perf_counter() 21 | self._paused: Optional[float] = None 22 | self._total_paused = 0 23 | self._count_start = 1 24 | 25 | def pause(self): 26 | """ 27 | Pause the timer. 28 | """ 29 | if self._paused is not None: 30 | raise ValueError("Trying to pause a Timer that is already paused!") 31 | self._paused = perf_counter() 32 | 33 | def is_paused(self) -> bool: 34 | """ 35 | Returns: 36 | bool: whether the timer is currently paused 37 | """ 38 | return self._paused is not None 39 | 40 | def resume(self): 41 | """ 42 | Resume the timer. 43 | """ 44 | if self._paused is None: 45 | raise ValueError("Trying to resume a Timer that is not paused!") 46 | self._total_paused += perf_counter() - self._paused 47 | self._paused = None 48 | self._count_start += 1 49 | 50 | def seconds(self) -> float: 51 | """ 52 | Returns: 53 | (float): the total number of seconds since the start/reset of the 54 | timer, excluding the time when the timer is paused. 55 | """ 56 | if self._paused is not None: 57 | end_time: float = self._paused # type: ignore 58 | else: 59 | end_time = perf_counter() 60 | return end_time - self._start - self._total_paused 61 | 62 | def avg_seconds(self) -> float: 63 | """ 64 | Returns: 65 | (float): the average number of seconds between every start/reset and 66 | pause. 
# The Caffe module in nn_tools provides some convenient APIs
If there is a problem parsing your prototxt or caffemodel, please replace
the caffe.proto with your own version and compile it with the command
`protoc --python_out ./ caffe.proto`

## caffe_net.py
Use `from nn_tools.Caffe import caffe_net` to import this module
### Prototxt
`net=caffe_net.Prototxt(file_name)` to open a prototxt file 10 | + `net.init_caffemodel(caffe_cmd_path='caffe')` to generate a caffemodel file in the current work directory \ 11 | if your `caffe` cmd not in the $PATH, specify your caffe cmd path by the `caffe_cmd_path` kwargs. 12 | ### Caffemodel 13 | + `net=caffe_net.Caffemodel(file_name)` to open a caffemodel 14 | + `net.save_prototxt(path)` to save the caffemodel to a prototxt file (not containing the weight data) 15 | + `net.get_layer_data(layer_name)` return the numpy ndarray data of the layer 16 | + `net.set_layer_date(layer_name, datas)` specify the data of one layer in the caffemodel .`datas` is normally a list of numpy ndarray `[weights,bias]` 17 | + `net.save(path)` save the changed caffemodel 18 | ### Functions for both Prototxt and Caffemodel 19 | + `net.add_layer(layer_params,before='',after='')` add a new layer with `Layer_Param` object 20 | + `net.remove_layer_by_name(layer_name)` 21 | + `net.get_layer_by_name(layer_name)` or `net.layer(layer_name)` get the raw Layer object defined in caffe_pb2 22 | -------------------------------------------------------------------------------- /tools/deploy/Caffe/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/peterzpy/ACL-DGReID/d84c75dfa7e5b502a094c084ba137b4fac42e82b/tools/deploy/Caffe/__init__.py -------------------------------------------------------------------------------- /tools/deploy/Caffe/caffe_lmdb.py: -------------------------------------------------------------------------------- 1 | import lmdb 2 | from Caffe import caffe_pb2 as pb2 3 | import numpy as np 4 | 5 | class Read_Caffe_LMDB(): 6 | def __init__(self,path,dtype=np.uint8): 7 | 8 | self.env=lmdb.open(path, readonly=True) 9 | self.dtype=dtype 10 | self.txn=self.env.begin() 11 | self.cursor=self.txn.cursor() 12 | 13 | @staticmethod 14 | def to_numpy(value,dtype=np.uint8): 15 | datum = pb2.Datum() 16 | 
datum.ParseFromString(value) 17 | flat_x = np.fromstring(datum.data, dtype=dtype) 18 | data = flat_x.reshape(datum.channels, datum.height, datum.width) 19 | label=flat_x = datum.label 20 | return data,label 21 | 22 | def iterator(self): 23 | while True: 24 | key,value=self.cursor.key(),self.cursor.value() 25 | yield self.to_numpy(value,self.dtype) 26 | if not self.cursor.next(): 27 | return 28 | 29 | def __iter__(self): 30 | self.cursor.first() 31 | it = self.iterator() 32 | return it 33 | 34 | def __len__(self): 35 | return int(self.env.stat()['entries']) 36 | -------------------------------------------------------------------------------- /tools/deploy/Caffe/net.py: -------------------------------------------------------------------------------- 1 | raise ImportError("the nn_tools.Caffe.net is no longer used, please use nn_tools.Caffe.caffe_net") 2 | 3 | -------------------------------------------------------------------------------- /tools/deploy/caffe_export.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: xingyu liao 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import argparse 8 | import logging 9 | import sys 10 | 11 | import torch 12 | 13 | sys.path.append('.') 14 | 15 | import pytorch_to_caffe 16 | from fastreid.config import get_cfg 17 | from fastreid.modeling.meta_arch import build_model 18 | from fastreid.utils.file_io import PathManager 19 | from fastreid.utils.checkpoint import Checkpointer 20 | from fastreid.utils.logger import setup_logger 21 | 22 | # import some modules added in project like this below 23 | # sys.path.append("projects/PartialReID") 24 | # from partialreid import * 25 | 26 | setup_logger(name='fastreid') 27 | logger = logging.getLogger("fastreid.caffe_export") 28 | 29 | 30 | def setup_cfg(args): 31 | cfg = get_cfg() 32 | cfg.merge_from_file(args.config_file) 33 | cfg.merge_from_list(args.opts) 34 | cfg.freeze() 35 | return cfg 36 | 37 | 38 | def 
def get_parser():
    """
    Build the command-line parser of the Caffe inference script.

    Returns:
        argparse.ArgumentParser: parser exposing ``--model-def``,
        ``--model-weights``, ``--input``, ``--output``, ``--height`` and
        ``--width``.
    """
    parser = argparse.ArgumentParser(description="Caffe model inference")

    parser.add_argument(
        "--model-def",
        default="logs/test_caffe/baseline_R50.prototxt",
        help="caffe model prototxt"
    )
    parser.add_argument(
        "--model-weights",
        default="logs/test_caffe/baseline_R50.caffemodel",
        help="caffe model weights"
    )
    parser.add_argument(
        "--input",
        nargs="+",
        help="A list of space separated input images; "
             "or a single glob pattern such as 'directory/*.jpg'",
    )
    parser.add_argument(
        "--output",
        default='caffe_output',
        # This script stores one feature file per input image here; the old
        # help text was copied from the export tool and described a model path.
        help='directory to save the extracted features (.npy files)'
    )
    parser.add_argument(
        "--height",
        type=int,
        default=256,
        help="height of image"
    )
    parser.add_argument(
        "--width",
        type=int,
        default=128,
        help="width of image"
    )
    return parser
def normalize(nparray, order=2, axis=-1):
    """Scale an N-D numpy array to unit norm along `axis`.

    The norm of the given `order` is computed along `axis`; the float32
    machine epsilon is added to it before dividing so an all-zero slice
    does not divide by zero.
    """
    eps = np.finfo(np.float32).eps
    denom = np.linalg.norm(nparray, ord=order, axis=axis, keepdims=True) + eps
    return nparray / denom
def get_parser():
    """
    Build the command-line parser of the ONNX inference script.

    Returns:
        argparse.ArgumentParser: parser exposing ``--model-path``,
        ``--input``, ``--output``, ``--height`` and ``--width``.
    """
    parser = argparse.ArgumentParser(description="onnx model inference")

    parser.add_argument(
        "--model-path",
        default="onnx_model/baseline.onnx",
        help="onnx model path"
    )
    parser.add_argument(
        "--input",
        nargs="+",
        help="A list of space separated input images; "
             "or a single glob pattern such as 'directory/*.jpg'",
    )
    parser.add_argument(
        "--output",
        default='onnx_output',
        # This script stores one feature file per input image here; the old
        # help text was copied from the export tool and described a model path.
        help='directory to save the extracted features (.npy files)'
    )
    parser.add_argument(
        "--height",
        type=int,
        default=256,
        help="height of image"
    )
    parser.add_argument(
        "--width",
        type=int,
        default=128,
        help="width of image"
    )
    return parser
if __name__ == "__main__":
    # Extract one L2-normalized feature per input image with the ONNX model
    # and store it as <image stem>.npy inside the output directory.
    args = get_parser().parse_args()

    ort_sess = onnxruntime.InferenceSession(args.model_path)

    input_name = ort_sess.get_inputs()[0].name

    os.makedirs(args.output, exist_ok=True)

    if args.input:
        if len(args.input) == 1:
            # A single argument is either a directory or a (quoted) glob
            # pattern. Globbing a bare directory path only returns the
            # directory itself, so expand it to the files it contains.
            pattern = os.path.expanduser(args.input[0])
            if os.path.isdir(pattern):
                pattern = os.path.join(pattern, "*")
            args.input = sorted(p for p in glob.glob(pattern) if os.path.isfile(p))
            assert args.input, "The input path(s) was not found"
        for path in tqdm.tqdm(args.input):
            image = preprocess(path, args.height, args.width)
            feat = ort_sess.run(None, {input_name: image})[0]
            feat = normalize(feat, axis=1)
            # Name the feature file after the image stem, whatever the image
            # extension or the platform's path separator.
            stem = os.path.splitext(os.path.basename(path))[0]
            np.save(os.path.join(args.output, stem + '.npy'), feat)
/tools/deploy/test_data/0032_c6s1_002851_01.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/peterzpy/ACL-DGReID/d84c75dfa7e5b502a094c084ba137b4fac42e82b/tools/deploy/test_data/0032_c6s1_002851_01.jpg -------------------------------------------------------------------------------- /tools/deploy/test_data/0048_c1s1_005351_01.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/peterzpy/ACL-DGReID/d84c75dfa7e5b502a094c084ba137b4fac42e82b/tools/deploy/test_data/0048_c1s1_005351_01.jpg -------------------------------------------------------------------------------- /tools/deploy/test_data/0065_c6s1_009501_02.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/peterzpy/ACL-DGReID/d84c75dfa7e5b502a094c084ba137b4fac42e82b/tools/deploy/test_data/0065_c6s1_009501_02.jpg -------------------------------------------------------------------------------- /tools/deploy/trt_calibrator.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: xingyu liao 4 | @contact: sherlockliao01@gmail.com 5 | 6 | Create custom calibrator, use to calibrate int8 TensorRT model. 7 | Need to override some methods of trt.IInt8EntropyCalibrator2, such as get_batch_size, get_batch, 8 | read_calibration_cache, write_calibration_cache. 
class FeatEntropyCalibrator(trt.IInt8EntropyCalibrator2):
    """
    INT8 calibrator that serves batches of dataset images to TensorRT.

    Images of the dataset named by ``args.calib_data`` (train + query +
    gallery) are shuffled once, resized to ``(height, width)`` and copied
    batch by batch into a single device buffer allocated up front.
    """

    def __init__(self, args):
        """
        Args:
            args: namespace providing ``batch_size``, ``channel``, ``height``,
                ``width`` and ``calib_data`` (a name registered in
                ``DATASET_REGISTRY``).
        """
        trt.IInt8EntropyCalibrator2.__init__(self)

        self.cache_file = 'reid_feat.cache'

        self.batch_size = args.batch_size
        self.channel = args.channel
        self.height = args.height
        self.width = args.width
        self.transform = T.Compose([
            T.Resize((self.height, self.width), interpolation=3),  # [h,w]
            ToTensor(),
        ])

        dataset = DATASET_REGISTRY.get(args.calib_data)(root=_root)
        self._data_items = dataset.train + dataset.query + dataset.gallery
        np.random.shuffle(self._data_items)
        # The first field of every item is used as the image path.
        self.imgs = [item[0] for item in self._data_items]

        self.batch_idx = 0
        # Only full batches are served; a trailing partial batch is dropped.
        self.max_batch_idx = len(self.imgs) // self.batch_size

        # Size in bytes of one float32 batch, used for the device allocation.
        self.data_size = self.batch_size * self.channel * self.height * self.width * trt.float32.itemsize
        self.device_input = cuda.mem_alloc(self.data_size)

    def next_batch(self):
        """
        Load the next batch of images.

        Returns:
            numpy.ndarray: a contiguous float32 array of shape
            ``(batch_size, channel, height, width)``, or an empty array once
            every full batch has been served.
        """
        if self.batch_idx < self.max_batch_idx:
            batch_files = self.imgs[self.batch_idx * self.batch_size:(self.batch_idx + 1) * self.batch_size]
            batch_imgs = np.zeros((self.batch_size, self.channel, self.height, self.width),
                                  dtype=np.float32)
            for i, f in enumerate(batch_files):
                img = read_image(f)
                img = self.transform(img).numpy()
                assert (img.nbytes == self.data_size // self.batch_size), 'not valid img!' + f
                batch_imgs[i] = img
            self.batch_idx += 1
            logger.info("batch:[{}/{}]".format(self.batch_idx, self.max_batch_idx))
            return np.ascontiguousarray(batch_imgs)
        else:
            return np.array([])

    def get_batch_size(self):
        """Return the batch size used for calibration."""
        return self.batch_size

    def get_batch(self, names, p_str=None):
        """
        Copy the next batch to the device buffer.

        Returns:
            list[int] | None: a one-element list holding the device buffer
            address, or None when no full batch is left or preparing the
            batch failed.
        """
        try:
            batch_imgs = self.next_batch()
            batch_imgs = batch_imgs.ravel()
            if batch_imgs.size == 0 or batch_imgs.size != self.batch_size * self.channel * self.height * self.width:
                return None
            cuda.memcpy_htod(self.device_input, batch_imgs.astype(np.float32))
            return [int(self.device_input)]
        except Exception:
            # Still end calibration on failure, but leave a trace of the cause
            # and let KeyboardInterrupt/SystemExit propagate.
            logger.exception("Failed to prepare calibration batch %d", self.batch_idx)
            return None

    def read_calibration_cache(self):
        """
        Return the content of the calibration cache file if it exists, so
        calibration does not have to run again; otherwise return None.
        """
        if os.path.exists(self.cache_file):
            with open(self.cache_file, "rb") as f:
                return f.read()
        return None

    def write_calibration_cache(self, cache):
        """Write `cache` to the calibration cache file."""
        with open(self.cache_file, "wb") as f:
            f.write(cache)
def main(args):
    """
    Entry point run by ``launch``: train a model, or only evaluate one when
    ``args.eval_only`` is set.

    Returns:
        The result of ``DefaultTrainer.test`` in evaluation mode, otherwise
        the result of ``trainer.train()``.
    """
    cfg = setup(args)

    if not args.eval_only:
        trainer = DefaultTrainer(cfg)
        trainer.resume_or_load(resume=args.resume)
        return trainer.train()

    # Evaluation only: the weights come from cfg.MODEL.WEIGHTS, so the
    # backbone PRETRAIN option is switched off before building the model.
    cfg.defrost()
    cfg.MODEL.BACKBONE.PRETRAIN = False
    model = DefaultTrainer.build_model(cfg)

    Checkpointer(model).load(cfg.MODEL.WEIGHTS)  # load trained model

    return DefaultTrainer.test(cfg, model, 0)