├── README.md ├── Rename.py ├── config ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-37.pyc │ ├── __init__.cpython-38.pyc │ ├── defaults.cpython-37.pyc │ └── defaults.cpython-38.pyc └── defaults.py ├── configs ├── DukeMTMC │ └── dpm.yml ├── Market │ └── dpm.yml └── OCC_Duke │ └── dpm.yml ├── datasets ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-37.pyc │ ├── __init__.cpython-38.pyc │ ├── bases.cpython-37.pyc │ ├── dukemtmcreid.cpython-37.pyc │ ├── make_dataloader.cpython-37.pyc │ ├── make_dataloader.cpython-38.pyc │ ├── market1501.cpython-37.pyc │ ├── msmt17.cpython-37.pyc │ ├── occ_duke.cpython-37.pyc │ ├── sampler.cpython-37.pyc │ ├── sampler_ddp.cpython-37.pyc │ ├── vehicleid.cpython-37.pyc │ └── veri.cpython-37.pyc ├── bases.py ├── dukemtmcreid.py ├── make_dataloader.py ├── market1501.py ├── occ_duke.py ├── preprocessing.py ├── sampler.py └── sampler_ddp.py ├── loss ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-37.pyc │ ├── arcface.cpython-37.pyc │ ├── center_loss.cpython-37.pyc │ ├── make_loss.cpython-37.pyc │ ├── metric_learning.cpython-37.pyc │ ├── softmax_loss.cpython-37.pyc │ └── triplet_loss.cpython-37.pyc ├── arcface.py ├── center_loss.py ├── make_loss.py ├── metric_learning.py ├── softmax_loss.py └── triplet_loss.py ├── model ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-37.pyc │ └── make_model.cpython-37.pyc ├── backbones │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── resnet.cpython-37.pyc │ │ └── vit_pytorch.cpython-37.pyc │ ├── resnet.py │ └── vit_pytorch.py └── make_model.py ├── processor ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-37.pyc │ └── processor.cpython-37.pyc └── processor.py ├── requirements.txt ├── solver ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-37.pyc │ ├── cosine_lr.cpython-37.pyc │ ├── lr_scheduler.cpython-37.pyc │ ├── make_optimizer.cpython-37.pyc │ ├── scheduler.cpython-37.pyc │ └── scheduler_factory.cpython-37.pyc ├── cosine_lr.py ├── lr_scheduler.py ├── make_optimizer.py ├── scheduler.py └── scheduler_factory.py ├── test.py ├── train.py └── utils ├── __init__.py ├── __pycache__ ├── __init__.cpython-37.pyc ├── __init__.cpython-38.pyc ├── iotools.cpython-37.pyc ├── logger.cpython-37.pyc ├── logger.cpython-38.pyc ├── meter.cpython-37.pyc ├── metrics.cpython-37.pyc └── reranking.cpython-37.pyc ├── iotools.py ├── logger.py ├── meter.py ├── metrics.py └── reranking.py /README.md: -------------------------------------------------------------------------------- 1 | ![Python >=3.6](https://img.shields.io/badge/Python->=3.6-yellow.svg) 2 | ![PyTorch >=1.10](https://img.shields.io/badge/PyTorch->=1.10-blue.svg) 3 | 4 | # [ACMMM2022] Dynamic Prototype Mask for Occluded Person Re-Identification 5 | The official repository for Dynamic Prototype Mask for Occluded Person Re-Identification [[pdf]](https://arxiv.org/pdf/2207.09046.pdf) 6 | 7 | ### Prepare Datasets 8 | 9 | ```bash 10 | mkdir data 11 | ``` 12 | Download the person datasets [Market-1501](https://drive.google.com/file/d/0B8-rUzbwVRk0c054eEozWG9COHM/view), [DukeMTMC-reID](https://arxiv.org/abs/1609.01775), [Occluded-Duke](https://github.com/lightas/Occluded-DukeMTMC-Dataset), and the [Occluded_REID](https://github.com/wangguanan/light-reid/blob/master/reid_datasets.md), 13 | Then unzip them and rename them under the directory like 14 | 15 | ``` 16 | data 17 | ├── Occluded_Duke 18 | │ └── images .. 19 | ├── Occluded_REID 20 | │ └── images .. 21 | ├── market1501 22 | │ └── images .. 23 | └── dukemtmcreid 24 | └── images .. 
25 | ``` 26 | 27 | ### Installation 28 | 29 | ```bash 30 | pip install -r requirements.txt 31 | ``` 32 | 33 | ### Prepare ViT Pre-trained Models 34 | 35 | You need to download the ImageNet-pretrained transformer model: [ViT-Base](https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth) 36 | 37 | ## Training 38 | 39 | We use a single RTX 3090 GPU for training, which takes around 14 GB of GPU memory. 40 | 41 | You can train DPM with: 42 | 43 | ```bash 44 | python train.py --config_file configs/<dataset>/dpm.yml MODEL.DEVICE_ID "('your device id')" 45 | ``` 46 | **Some examples:** 47 | ```bash 48 | # Occluded_Duke 49 | python train.py --config_file configs/OCC_Duke/dpm.yml MODEL.DEVICE_ID "('0')" 50 | ``` 51 | 52 | 1. When training on Market-1501, the validation set is set to Occluded-REID. If you want to validate on Market-1501 itself, please modify 'datasets/market1501.py' accordingly. 53 | 54 | 2. Before training on Occluded-REID, put Rename.py under the dataset directory and run it to rename the images into Market-1501-style filenames. 55 | 56 | 57 | ## Evaluation 58 | 59 | ```bash 60 | python test.py --config_file 'choose which config to test' MODEL.DEVICE_ID "('your device id')" TEST.WEIGHT "('your path of trained checkpoints')" 61 | ``` 62 | 63 | **Some examples:** 64 | ```bash 65 | # OCC_Duke 66 | python test.py --config_file configs/OCC_Duke/dpm.yml MODEL.DEVICE_ID "('0')" TEST.WEIGHT './logs/occ_duke_dpm/transformer_150.pth' 67 | ``` 68 | 69 | #### Results 70 | | Dataset | Rank@1 | mAP | Model | 71 | | :------: | :------: | :------: | :------: | 72 | | Occluded-Duke | 71.4 (72.0) | 61.8 (61.9) | [model](https://drive.google.com/file/d/12rTyilUnwOy-lsaM65Y_ce_6AmOivgm1/view?usp=sharing) | 73 | | Occluded-REID | 85.5 (86.2) | 79.7 (80.0) | [model](https://drive.google.com/file/d/1J86byKnQocDK9XZeQuMvg-qvS_gN_zAd/view?usp=sharing) | 74 | 75 | We have reorganized the code; the numbers in parentheses are the results of the reorganized code, which are slightly higher than those reported in the paper. 76 | 77 | ## Citation 78 | Please kindly cite this paper in your publications if it helps your research: 79 | ```bash 80 | @inproceedings{tan2022dynamic, 81 | title={Dynamic prototype mask for occluded person re-identification}, 82 | author={Tan, Lei and Dai, Pingyang and Ji, Rongrong and Wu, Yongjian}, 83 | booktitle={Proceedings of the 30th ACM international conference on multimedia}, 84 | pages={531--540}, 85 | year={2022} 86 | } 87 | ``` 88 | 89 | ## Acknowledgement 90 | Our code is based on [TransReID](https://github.com/damo-cv/TransReID) [1]. 91 | 92 | ## References 93 | [1] Shuting He, Hao Luo, Pichao Wang, Fan Wang, Hao Li, and Wei Jiang. 2021. Transreid: Transformer-based object re-identification. In Proceedings of the IEEE/CVF 94 | International Conference on Computer Vision. 15013–15022. 95 | 96 | ## Contact 97 | 98 | If you have any questions, please feel free to contact us.
E-mail: [tanlei@stu.xmu.edu.cn](mailto:tanlei@stu.xmu.edu.cn) 99 | -------------------------------------------------------------------------------- /Rename.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import os.path as osp 4 | 5 | query_path = 'occluded_body_images/' 6 | gallery_path = 'whole_body_images/' 7 | tquery_path = 'query/' 8 | tgallery_path = 'bounding_box_test/' 9 | 10 | if not osp.exists(tquery_path): 11 | os.makedirs(tquery_path) 12 | 13 | if not osp.exists(tgallery_path): 14 | os.makedirs(tgallery_path) 15 | 16 | for i in range(200): 17 | i = i + 1 18 | filename = str(i).zfill(3) 19 | tfilename = str(i).zfill(4) 20 | for j in range(5): 21 | j = j + 1 22 | img_num = str(j).zfill(2) 23 | 24 | qimg_path = query_path + filename + '/' + filename + '_' + img_num + '.tif' 25 | qout_path = tquery_path + tfilename + '_c1s1' + '_' + str(j * 50).zfill(6) + '_00' + '.tif' 26 | shutil.copyfile(qimg_path, qout_path) 27 | 28 | gimg_path = gallery_path + filename + '/' + filename + '_' + img_num + '.tif' 29 | gout_path = tgallery_path + tfilename + '_c3s3' + '_' + str(j * 50).zfill(6) + '_00' + '.tif' 30 | shutil.copyfile(gimg_path, gout_path) -------------------------------------------------------------------------------- /config/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | from .defaults import _C as cfg 8 | from .defaults import _C as cfg_test 9 | -------------------------------------------------------------------------------- /config/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stone96123/DPM/a7bb2267b22de7cebfcabae5fd998c71b0f94b13/config/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /config/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stone96123/DPM/a7bb2267b22de7cebfcabae5fd998c71b0f94b13/config/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /config/__pycache__/defaults.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stone96123/DPM/a7bb2267b22de7cebfcabae5fd998c71b0f94b13/config/__pycache__/defaults.cpython-37.pyc -------------------------------------------------------------------------------- /config/__pycache__/defaults.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stone96123/DPM/a7bb2267b22de7cebfcabae5fd998c71b0f94b13/config/__pycache__/defaults.cpython-38.pyc -------------------------------------------------------------------------------- /config/defaults.py: -------------------------------------------------------------------------------- 1 | from yacs.config import CfgNode as CN 2 | 3 | # ----------------------------------------------------------------------------- 4 | # Convention about Training / Test specific parameters 5 | # ----------------------------------------------------------------------------- 6 | # Whenever an argument can be either used for training or for testing, the 7 | # corresponding name will be post-fixed by a _TRAIN 
for a training parameter, 8 | 9 | # ----------------------------------------------------------------------------- 10 | # Config definition 11 | # ----------------------------------------------------------------------------- 12 | 13 | _C = CN() 14 | # ----------------------------------------------------------------------------- 15 | # MODEL 16 | # ----------------------------------------------------------------------------- 17 | _C.MODEL = CN() 18 | # Using cuda or cpu for training 19 | _C.MODEL.DEVICE = "cuda" 20 | # ID number of GPU 21 | _C.MODEL.DEVICE_ID = '0' 22 | # Name of backbone 23 | _C.MODEL.NAME = 'resnet50' 24 | # Last stride of backbone 25 | _C.MODEL.LAST_STRIDE = 1 26 | # Path to pretrained model of backbone 27 | _C.MODEL.PRETRAIN_PATH = '' 28 | 29 | # Use ImageNet pretrained model to initialize backbone or use self trained model to initialize the whole model 30 | # Options: 'imagenet' , 'self' , 'finetune' 31 | _C.MODEL.PRETRAIN_CHOICE = 'imagenet' 32 | 33 | # If train with BNNeck, options: 'bnneck' or 'no' 34 | _C.MODEL.NECK = 'bnneck' 35 | # If train loss include center loss, options: 'yes' or 'no'. Loss with center loss has different optimizer configuration 36 | _C.MODEL.IF_WITH_CENTER = 'no' 37 | 38 | _C.MODEL.ID_LOSS_TYPE = 'softmax' 39 | _C.MODEL.ID_LOSS_WEIGHT = 1.0 40 | _C.MODEL.TRIPLET_LOSS_WEIGHT = 1.0 41 | 42 | _C.MODEL.METRIC_LOSS_TYPE = 'triplet' 43 | # If train with multi-gpu ddp mode, options: 'True', 'False' 44 | _C.MODEL.DIST_TRAIN = False 45 | # If train with soft triplet loss, options: 'True', 'False' 46 | _C.MODEL.NO_MARGIN = False 47 | # If train with label smooth, options: 'on', 'off' 48 | _C.MODEL.IF_LABELSMOOTH = 'on' 49 | # If train with arcface loss, options: 'True', 'False' 50 | _C.MODEL.COS_LAYER = False 51 | 52 | # Transformer setting 53 | _C.MODEL.DROP_PATH = 0.1 54 | _C.MODEL.DROP_OUT = 0.0 55 | _C.MODEL.ATT_DROP_RATE = 0.0 56 | _C.MODEL.TRANSFORMER_TYPE = 'None' 57 | _C.MODEL.STRIDE_SIZE = [16, 16] 58 | 59 | # JPM Parameter 60 | _C.MODEL.JPM = False 61 | _C.MODEL.SHIFT_NUM = 5 62 | _C.MODEL.SHUFFLE_GROUP = 2 63 | _C.MODEL.DEVIDE_LENGTH = 4 64 | _C.MODEL.RE_ARRANGE = True 65 | 66 | # SIE Parameter 67 | _C.MODEL.SIE_COE = 3.0 68 | _C.MODEL.SIE_CAMERA = False 69 | _C.MODEL.SIE_VIEW = False 70 | 71 | # ----------------------------------------------------------------------------- 72 | # INPUT 73 | # ----------------------------------------------------------------------------- 74 | _C.INPUT = CN() 75 | # Size of the image during training 76 | _C.INPUT.SIZE_TRAIN = [384, 128] 77 | # Size of the image during test 78 | _C.INPUT.SIZE_TEST = [384, 128] 79 | # Random probability for image horizontal flip 80 | _C.INPUT.PROB = 0.5 81 | # Random probability for random erasing 82 | _C.INPUT.RE_PROB = 0.5 83 | # Values to be used for image normalization 84 | _C.INPUT.PIXEL_MEAN = [0.485, 0.456, 0.406] 85 | # Values to be used for image normalization 86 | _C.INPUT.PIXEL_STD = [0.229, 0.224, 0.225] 87 | # Value of padding size 88 | _C.INPUT.PADDING = 10 89 | 90 | # ----------------------------------------------------------------------------- 91 | # Dataset 92 | # ----------------------------------------------------------------------------- 93 | _C.DATASETS = CN() 94 | # List of the dataset names for training, as present in paths_catalog.py 95 | _C.DATASETS.NAMES = ('market1501') 96 | # Root directory where datasets should be used (and downloaded if not found) 97 | _C.DATASETS.ROOT_DIR = ('../data') 98 | 99 | 100 | # 
----------------------------------------------------------------------------- 101 | # DataLoader 102 | # ----------------------------------------------------------------------------- 103 | _C.DATALOADER = CN() 104 | # Number of data loading threads 105 | _C.DATALOADER.NUM_WORKERS = 8 106 | # Sampler for data loading 107 | _C.DATALOADER.SAMPLER = 'softmax' 108 | # Number of instance for one batch 109 | _C.DATALOADER.NUM_INSTANCE = 16 110 | 111 | # ---------------------------------------------------------------------------- # 112 | # Solver 113 | # ---------------------------------------------------------------------------- # 114 | _C.SOLVER = CN() 115 | # Name of optimizer 116 | _C.SOLVER.OPTIMIZER_NAME = "Adam" 117 | # Number of max epoches 118 | _C.SOLVER.MAX_EPOCHS = 100 119 | # Base learning rate 120 | _C.SOLVER.BASE_LR = 3e-4 121 | # Whether using larger learning rate for fc layer 122 | _C.SOLVER.LARGE_FC_LR = False 123 | # Factor of learning bias 124 | _C.SOLVER.BIAS_LR_FACTOR = 1 125 | # Factor of learning bias 126 | _C.SOLVER.SEED = 1234 127 | # Momentum 128 | _C.SOLVER.MOMENTUM = 0.9 129 | # Margin of triplet loss 130 | _C.SOLVER.MARGIN = 0.3 131 | # Learning rate of SGD to learn the centers of center loss 132 | _C.SOLVER.CENTER_LR = 0.5 133 | # Balanced weight of center loss 134 | _C.SOLVER.CENTER_LOSS_WEIGHT = 0.0005 135 | 136 | # Settings of weight decay 137 | _C.SOLVER.WEIGHT_DECAY = 0.0005 138 | _C.SOLVER.WEIGHT_DECAY_BIAS = 0.0005 139 | 140 | # decay rate of learning rate 141 | _C.SOLVER.GAMMA = 0.1 142 | # decay step of learning rate 143 | _C.SOLVER.STEPS = (40, 70) 144 | # warm up factor 145 | _C.SOLVER.WARMUP_FACTOR = 0.01 146 | # warm up epochs 147 | _C.SOLVER.WARMUP_EPOCHS = 5 148 | # method of warm up, option: 'constant','linear' 149 | _C.SOLVER.WARMUP_METHOD = "linear" 150 | 151 | _C.SOLVER.COSINE_MARGIN = 0.5 152 | _C.SOLVER.COSINE_SCALE = 30 153 | 154 | # epoch number of saving checkpoints 155 | _C.SOLVER.CHECKPOINT_PERIOD = 10 156 | # iteration of display training log 157 | _C.SOLVER.LOG_PERIOD = 100 158 | # epoch number of validation 159 | _C.SOLVER.EVAL_PERIOD = 10 160 | # Number of images per batch 161 | # This is global, so if we have 8 GPUs and IMS_PER_BATCH = 128, each GPU will 162 | # contain 16 images per batch 163 | _C.SOLVER.IMS_PER_BATCH = 64 164 | 165 | # ---------------------------------------------------------------------------- # 166 | # TEST 167 | # ---------------------------------------------------------------------------- # 168 | 169 | _C.TEST = CN() 170 | # Number of images per batch during test 171 | _C.TEST.IMS_PER_BATCH = 128 172 | # If test with re-ranking, options: 'True','False' 173 | _C.TEST.RE_RANKING = False 174 | # Path to trained model 175 | _C.TEST.WEIGHT = "" 176 | # Which feature of BNNeck to be used for test, before or after BNNneck, options: 'before' or 'after' 177 | _C.TEST.NECK_FEAT = 'after' 178 | # Whether feature is nomalized before test, if yes, it is equivalent to cosine distance 179 | _C.TEST.FEAT_NORM = 'yes' 180 | 181 | # Name for saving the distmat after testing. 
182 | _C.TEST.DIST_MAT = "dist_mat.npy" 183 | # Whether calculate the eval score option: 'True', 'False' 184 | _C.TEST.EVAL = False 185 | # ---------------------------------------------------------------------------- # 186 | # Misc options 187 | # ---------------------------------------------------------------------------- # 188 | # Path to checkpoint and saved log of trained model 189 | _C.OUTPUT_DIR = "" 190 | -------------------------------------------------------------------------------- /configs/DukeMTMC/dpm.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/tan/data/TransReID-main/jx_vit_base_p16_224-80ecf9dd.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('4') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [12, 12] 12 | SIE_CAMERA: True 13 | SIE_COE: 3.0 14 | JPM: False 15 | RE_ARRANGE: True 16 | 17 | INPUT: 18 | SIZE_TRAIN: [256, 128] 19 | SIZE_TEST: [256, 128] 20 | PROB: 0.5 # random horizontal flip 21 | RE_PROB: 0.5 # random erasing 22 | PADDING: 10 23 | PIXEL_MEAN: [0.5, 0.5, 0.5] 24 | PIXEL_STD: [0.5, 0.5, 0.5] 25 | 26 | DATASETS: 27 | NAMES: ('dukemtmc') 28 | ROOT_DIR: ('../../data') 29 | 30 | DATALOADER: 31 | SAMPLER: 'softmax_triplet' 32 | NUM_INSTANCE: 4 33 | NUM_WORKERS: 8 34 | 35 | SOLVER: 36 | OPTIMIZER_NAME: 'SGD' 37 | MAX_EPOCHS: 150 38 | BASE_LR: 0.008 39 | IMS_PER_BATCH: 64 40 | WARMUP_METHOD: 'linear' 41 | LARGE_FC_LR: False 42 | CHECKPOINT_PERIOD: 120 43 | LOG_PERIOD: 50 44 | EVAL_PERIOD: 10 45 | WEIGHT_DECAY: 1e-4 46 | WEIGHT_DECAY_BIAS: 1e-4 47 | BIAS_LR_FACTOR: 2 48 | COSINE_MARGIN: 0.3 49 | 50 | TEST: 51 | EVAL: True 52 | IMS_PER_BATCH: 256 53 | RE_RANKING: False 54 | WEIGHT: '../logs/duke_vit_transreid_stride/transformer_150.pth' 55 | NECK_FEAT: 'before' 56 | FEAT_NORM: 'yes' 57 | 58 | OUTPUT_DIR: './logs/duke_dpm' 59 | 60 | 61 | -------------------------------------------------------------------------------- /configs/Market/dpm.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/tan/data/TransReID-main/jx_vit_base_p16_224-80ecf9dd.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('5') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [12, 12] 12 | SIE_CAMERA: True 13 | SIE_COE: 3.0 14 | JPM: False 15 | RE_ARRANGE: True 16 | 17 | INPUT: 18 | SIZE_TRAIN: [256, 128] 19 | SIZE_TEST: [256, 128] 20 | PROB: 0.5 # random horizontal flip 21 | RE_PROB: 0.5 # random erasing 22 | PADDING: 10 23 | PIXEL_MEAN: [0.5, 0.5, 0.5] 24 | PIXEL_STD: [0.5, 0.5, 0.5] 25 | 26 | DATASETS: 27 | NAMES: ('market1501') 28 | ROOT_DIR: ('../../data') 29 | 30 | DATALOADER: 31 | SAMPLER: 'softmax_triplet' 32 | NUM_INSTANCE: 4 33 | NUM_WORKERS: 8 34 | 35 | SOLVER: 36 | OPTIMIZER_NAME: 'SGD' 37 | MAX_EPOCHS: 50 38 | BASE_LR: 0.008 39 | IMS_PER_BATCH: 64 40 | WARMUP_METHOD: 'linear' 41 | LARGE_FC_LR: False 42 | CHECKPOINT_PERIOD: 120 43 | LOG_PERIOD: 50 44 | EVAL_PERIOD: 1 45 | WEIGHT_DECAY: 1e-4 46 | WEIGHT_DECAY_BIAS: 1e-4 47 | BIAS_LR_FACTOR: 2 48 | COSINE_MARGIN: 0.3 49 | 50 | TEST: 51 | EVAL: True 52 | IMS_PER_BATCH: 256 53 | RE_RANKING: False 54 | WEIGHT: '' 55 | NECK_FEAT: 'before' 56 | FEAT_NORM: 'yes' 57 | 58 | OUTPUT_DIR: 
'./logs/market_dpm' 59 | 60 | 61 | -------------------------------------------------------------------------------- /configs/OCC_Duke/dpm.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/tan/data/TransReID-main/jx_vit_base_p16_224-80ecf9dd.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('3') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_TransReID' 11 | STRIDE_SIZE: [11, 11] 12 | SIE_CAMERA: True 13 | SIE_COE: 3.0 14 | JPM: False 15 | RE_ARRANGE: True 16 | 17 | INPUT: 18 | SIZE_TRAIN: [256, 128] 19 | SIZE_TEST: [256, 128] 20 | PROB: 0.5 # random horizontal flip 21 | RE_PROB: 0.5 # random erasing 22 | PADDING: 10 23 | PIXEL_MEAN: [0.5, 0.5, 0.5] 24 | PIXEL_STD: [0.5, 0.5, 0.5] 25 | 26 | DATASETS: 27 | NAMES: ('occ_duke') 28 | ROOT_DIR: ('../../data') 29 | 30 | DATALOADER: 31 | SAMPLER: 'softmax_triplet' 32 | NUM_INSTANCE: 4 33 | NUM_WORKERS: 8 34 | 35 | SOLVER: 36 | OPTIMIZER_NAME: 'SGD' 37 | MAX_EPOCHS: 150 38 | BASE_LR: 0.008 39 | IMS_PER_BATCH: 64 40 | WARMUP_METHOD: 'linear' 41 | LARGE_FC_LR: False 42 | CHECKPOINT_PERIOD: 150 43 | LOG_PERIOD: 50 44 | EVAL_PERIOD: 10 45 | WEIGHT_DECAY: 1e-4 46 | WEIGHT_DECAY_BIAS: 1e-4 47 | BIAS_LR_FACTOR: 2 48 | 49 | TEST: 50 | EVAL: True 51 | IMS_PER_BATCH: 256 52 | RE_RANKING: False 53 | WEIGHT: '' 54 | NECK_FEAT: 'before' 55 | FEAT_NORM: 'yes' 56 | 57 | OUTPUT_DIR: './logs/occ_duke_dpm' 58 | 59 | 60 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .make_dataloader import make_dataloader -------------------------------------------------------------------------------- /datasets/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stone96123/DPM/a7bb2267b22de7cebfcabae5fd998c71b0f94b13/datasets/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stone96123/DPM/a7bb2267b22de7cebfcabae5fd998c71b0f94b13/datasets/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/bases.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stone96123/DPM/a7bb2267b22de7cebfcabae5fd998c71b0f94b13/datasets/__pycache__/bases.cpython-37.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/dukemtmcreid.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stone96123/DPM/a7bb2267b22de7cebfcabae5fd998c71b0f94b13/datasets/__pycache__/dukemtmcreid.cpython-37.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/make_dataloader.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stone96123/DPM/a7bb2267b22de7cebfcabae5fd998c71b0f94b13/datasets/__pycache__/make_dataloader.cpython-37.pyc 
-------------------------------------------------------------------------------- /datasets/__pycache__/make_dataloader.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stone96123/DPM/a7bb2267b22de7cebfcabae5fd998c71b0f94b13/datasets/__pycache__/make_dataloader.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/market1501.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stone96123/DPM/a7bb2267b22de7cebfcabae5fd998c71b0f94b13/datasets/__pycache__/market1501.cpython-37.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/msmt17.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stone96123/DPM/a7bb2267b22de7cebfcabae5fd998c71b0f94b13/datasets/__pycache__/msmt17.cpython-37.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/occ_duke.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stone96123/DPM/a7bb2267b22de7cebfcabae5fd998c71b0f94b13/datasets/__pycache__/occ_duke.cpython-37.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/sampler.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stone96123/DPM/a7bb2267b22de7cebfcabae5fd998c71b0f94b13/datasets/__pycache__/sampler.cpython-37.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/sampler_ddp.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stone96123/DPM/a7bb2267b22de7cebfcabae5fd998c71b0f94b13/datasets/__pycache__/sampler_ddp.cpython-37.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/vehicleid.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stone96123/DPM/a7bb2267b22de7cebfcabae5fd998c71b0f94b13/datasets/__pycache__/vehicleid.cpython-37.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/veri.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stone96123/DPM/a7bb2267b22de7cebfcabae5fd998c71b0f94b13/datasets/__pycache__/veri.cpython-37.pyc -------------------------------------------------------------------------------- /datasets/bases.py: -------------------------------------------------------------------------------- 1 | from PIL import Image, ImageFile 2 | 3 | from torch.utils.data import Dataset 4 | import os.path as osp 5 | import random 6 | import torch 7 | ImageFile.LOAD_TRUNCATED_IMAGES = True 8 | 9 | 10 | def read_image(img_path): 11 | """Keep reading image until succeed. 
12 | This can avoid IOError incurred by heavy IO process.""" 13 | got_img = False 14 | if not osp.exists(img_path): 15 | raise IOError("{} does not exist".format(img_path)) 16 | while not got_img: 17 | try: 18 | img = Image.open(img_path).convert('RGB') 19 | got_img = True 20 | except IOError: 21 | print("IOError incurred when reading '{}'. Will redo. Don't worry. Just chill.".format(img_path)) 22 | pass 23 | return img 24 | 25 | 26 | class BaseDataset(object): 27 | """ 28 | Base class of reid dataset 29 | """ 30 | 31 | def get_imagedata_info(self, data): 32 | pids, cams, tracks = [], [], [] 33 | 34 | for _, pid, camid, trackid in data: 35 | pids += [pid] 36 | cams += [camid] 37 | tracks += [trackid] 38 | pids = set(pids) 39 | cams = set(cams) 40 | tracks = set(tracks) 41 | num_pids = len(pids) 42 | num_cams = len(cams) 43 | num_imgs = len(data) 44 | num_views = len(tracks) 45 | return num_pids, num_imgs, num_cams, num_views 46 | 47 | def print_dataset_statistics(self): 48 | raise NotImplementedError 49 | 50 | 51 | class BaseImageDataset(BaseDataset): 52 | """ 53 | Base class of image reid dataset 54 | """ 55 | 56 | def print_dataset_statistics(self, train, query, gallery): 57 | num_train_pids, num_train_imgs, num_train_cams, num_train_views = self.get_imagedata_info(train) 58 | num_query_pids, num_query_imgs, num_query_cams, num_train_views = self.get_imagedata_info(query) 59 | num_gallery_pids, num_gallery_imgs, num_gallery_cams, num_train_views = self.get_imagedata_info(gallery) 60 | 61 | print("Dataset statistics:") 62 | print(" ----------------------------------------") 63 | print(" subset | # ids | # images | # cameras") 64 | print(" ----------------------------------------") 65 | print(" train | {:5d} | {:8d} | {:9d}".format(num_train_pids, num_train_imgs, num_train_cams)) 66 | print(" query | {:5d} | {:8d} | {:9d}".format(num_query_pids, num_query_imgs, num_query_cams)) 67 | print(" gallery | {:5d} | {:8d} | {:9d}".format(num_gallery_pids, num_gallery_imgs, num_gallery_cams)) 68 | print(" ----------------------------------------") 69 | 70 | 71 | class ImageDataset(Dataset): 72 | def __init__(self, dataset, transform=None): 73 | self.dataset = dataset 74 | self.transform = transform 75 | 76 | def __len__(self): 77 | return len(self.dataset) 78 | 79 | def __getitem__(self, index): 80 | img_path, pid, camid, trackid = self.dataset[index] 81 | img = read_image(img_path) 82 | 83 | if self.transform is not None: 84 | img = self.transform(img) 85 | 86 | return img, pid, camid, trackid,img_path.split('/')[-1] -------------------------------------------------------------------------------- /datasets/dukemtmcreid.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: liaoxingyu2@jd.com 5 | """ 6 | 7 | import glob 8 | import re 9 | import urllib 10 | import zipfile 11 | 12 | import os.path as osp 13 | 14 | from utils.iotools import mkdir_if_missing 15 | from .bases import BaseImageDataset 16 | 17 | 18 | class DukeMTMCreID(BaseImageDataset): 19 | """ 20 | DukeMTMC-reID 21 | Reference: 22 | 1. Ristani et al. Performance Measures and a Data Set for Multi-Target, Multi-Camera Tracking. ECCVW 2016. 23 | 2. Zheng et al. Unlabeled Samples Generated by GAN Improve the Person Re-identification Baseline in vitro. ICCV 2017. 
24 | URL: https://github.com/layumi/DukeMTMC-reID_evaluation 25 | 26 | Dataset statistics: 27 | # identities: 1404 (train + query) 28 | # images:16522 (train) + 2228 (query) + 17661 (gallery) 29 | # cameras: 8 30 | """ 31 | dataset_dir = 'dukemtmcreid' 32 | 33 | def __init__(self, root='', verbose=True, pid_begin=0, **kwargs): 34 | super(DukeMTMCreID, self).__init__() 35 | self.dataset_dir = osp.join(root, self.dataset_dir) 36 | self.dataset_url = 'http://vision.cs.duke.edu/DukeMTMC/data/misc/DukeMTMC-reID.zip' 37 | self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train') 38 | self.query_dir = osp.join(self.dataset_dir, 'query') 39 | self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test') 40 | self.pid_begin = pid_begin 41 | self._download_data() 42 | self._check_before_run() 43 | 44 | train = self._process_dir(self.train_dir, relabel=True) 45 | query = self._process_dir(self.query_dir, relabel=False) 46 | gallery = self._process_dir(self.gallery_dir, relabel=False) 47 | 48 | if verbose: 49 | print("=> DukeMTMC-reID loaded") 50 | self.print_dataset_statistics(train, query, gallery) 51 | 52 | self.train = train 53 | self.query = query 54 | self.gallery = gallery 55 | 56 | self.num_train_pids, self.num_train_imgs, self.num_train_cams, self.num_train_vids = self.get_imagedata_info(self.train) 57 | self.num_query_pids, self.num_query_imgs, self.num_query_cams, self.num_query_vids = self.get_imagedata_info(self.query) 58 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams, self.num_gallery_vids = self.get_imagedata_info(self.gallery) 59 | 60 | def _download_data(self): 61 | if osp.exists(self.dataset_dir): 62 | print("This dataset has been downloaded.") 63 | return 64 | 65 | print("Creating directory {}".format(self.dataset_dir)) 66 | mkdir_if_missing(self.dataset_dir) 67 | fpath = osp.join(self.dataset_dir, osp.basename(self.dataset_url)) 68 | 69 | print("Downloading DukeMTMC-reID dataset") 70 | urllib.request.urlretrieve(self.dataset_url, fpath) 71 | 72 | print("Extracting files") 73 | zip_ref = zipfile.ZipFile(fpath, 'r') 74 | zip_ref.extractall(self.dataset_dir) 75 | zip_ref.close() 76 | 77 | def _check_before_run(self): 78 | """Check if all files are available before going deeper""" 79 | if not osp.exists(self.dataset_dir): 80 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 81 | if not osp.exists(self.train_dir): 82 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 83 | if not osp.exists(self.query_dir): 84 | raise RuntimeError("'{}' is not available".format(self.query_dir)) 85 | if not osp.exists(self.gallery_dir): 86 | raise RuntimeError("'{}' is not available".format(self.gallery_dir)) 87 | 88 | def _process_dir(self, dir_path, relabel=False): 89 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 90 | pattern = re.compile(r'([-\d]+)_c(\d)') 91 | 92 | pid_container = set() 93 | for img_path in img_paths: 94 | pid, _ = map(int, pattern.search(img_path).groups()) 95 | pid_container.add(pid) 96 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 97 | 98 | dataset = [] 99 | cam_container = set() 100 | for img_path in img_paths: 101 | pid, camid = map(int, pattern.search(img_path).groups()) 102 | assert 1 <= camid <= 8 103 | camid -= 1 # index starts from 0 104 | if relabel: pid = pid2label[pid] 105 | dataset.append((img_path, self.pid_begin + pid, camid, 1)) 106 | cam_container.add(camid) 107 | print(cam_container, 'cam_container') 108 | return dataset 109 | 
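The `_process_dir` method above recovers the person ID and camera ID directly from each image filename with the regular expression `([-\d]+)_c(\d)`; `market1501.py` and `occ_duke.py` below reuse the same pattern. A minimal sketch (not code from this repository) of how that pattern behaves on a Market-1501-style name and on a filename produced by `Rename.py`:

```python
import re

# The same pattern used by _process_dir in dukemtmcreid.py, market1501.py and occ_duke.py.
pattern = re.compile(r'([-\d]+)_c(\d)')

# Hypothetical filenames: a Market-1501-style name and one written by Rename.py.
for name in ['0002_c1s1_000451_03.jpg', '0001_c1s1_000050_00.tif']:
    pid, camid = map(int, pattern.search(name).groups())
    print(name, '-> pid', pid, 'camid', camid)
# 0002_c1s1_000451_03.jpg -> pid 2 camid 1
# 0001_c1s1_000050_00.tif -> pid 1 camid 1
```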
-------------------------------------------------------------------------------- /datasets/make_dataloader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.transforms as T 3 | from torch.utils.data import DataLoader 4 | 5 | from .bases import ImageDataset 6 | from timm.data.random_erasing import RandomErasing 7 | from .sampler import RandomIdentitySampler 8 | from .dukemtmcreid import DukeMTMCreID 9 | from .market1501 import Market1501 10 | from .sampler_ddp import RandomIdentitySampler_DDP 11 | import torch.distributed as dist 12 | from .occ_duke import OCC_DukeMTMCreID 13 | 14 | __factory = { 15 | 'market1501': Market1501, 16 | 'dukemtmc': DukeMTMCreID, 17 | 'occ_duke': OCC_DukeMTMCreID, 18 | } 19 | 20 | def train_collate_fn(batch): 21 | imgs, pids, camids, viewids , _ = zip(*batch) 22 | pids = torch.tensor(pids, dtype=torch.int64) 23 | viewids = torch.tensor(viewids, dtype=torch.int64) 24 | camids = torch.tensor(camids, dtype=torch.int64) 25 | return torch.stack(imgs, dim=0), pids, camids, viewids, 26 | 27 | def val_collate_fn(batch): 28 | imgs, pids, camids, viewids, img_paths = zip(*batch) 29 | viewids = torch.tensor(viewids, dtype=torch.int64) 30 | camids_batch = torch.tensor(camids, dtype=torch.int64) 31 | return torch.stack(imgs, dim=0), pids, camids, camids_batch, viewids, img_paths 32 | 33 | def make_dataloader(cfg): 34 | train_transforms = T.Compose([ 35 | T.Resize(cfg.INPUT.SIZE_TRAIN, interpolation=3), 36 | #T.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.2, hue=0.2), # Release this part when training for the Occluded-REID 37 | T.RandomHorizontalFlip(p=cfg.INPUT.PROB), 38 | T.Pad(cfg.INPUT.PADDING), 39 | T.RandomCrop(cfg.INPUT.SIZE_TRAIN), 40 | T.ToTensor(), 41 | T.Normalize(mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD), 42 | RandomErasing(probability=cfg.INPUT.RE_PROB, mode='pixel', max_count=1, device='cpu'), 43 | # RandomErasing(probability=cfg.INPUT.RE_PROB, mean=cfg.INPUT.PIXEL_MEAN) 44 | ]) 45 | 46 | val_transforms = T.Compose([ 47 | T.Resize(cfg.INPUT.SIZE_TEST), 48 | T.ToTensor(), 49 | T.Normalize(mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD) 50 | ]) 51 | 52 | num_workers = cfg.DATALOADER.NUM_WORKERS 53 | 54 | dataset = __factory[cfg.DATASETS.NAMES](root=cfg.DATASETS.ROOT_DIR) 55 | 56 | train_set = ImageDataset(dataset.train, train_transforms) 57 | train_set_normal = ImageDataset(dataset.train, val_transforms) 58 | num_classes = dataset.num_train_pids 59 | cam_num = dataset.num_train_cams 60 | view_num = dataset.num_train_vids 61 | 62 | if 'triplet' in cfg.DATALOADER.SAMPLER: 63 | if cfg.MODEL.DIST_TRAIN: 64 | print('DIST_TRAIN START') 65 | mini_batch_size = cfg.SOLVER.IMS_PER_BATCH // dist.get_world_size() 66 | data_sampler = RandomIdentitySampler_DDP(dataset.train, cfg.SOLVER.IMS_PER_BATCH, cfg.DATALOADER.NUM_INSTANCE) 67 | batch_sampler = torch.utils.data.sampler.BatchSampler(data_sampler, mini_batch_size, True) 68 | train_loader = torch.utils.data.DataLoader( 69 | train_set, 70 | num_workers=num_workers, 71 | batch_sampler=batch_sampler, 72 | collate_fn=train_collate_fn, 73 | pin_memory=True, 74 | ) 75 | else: 76 | train_loader = DataLoader( 77 | train_set, batch_size=cfg.SOLVER.IMS_PER_BATCH, 78 | sampler=RandomIdentitySampler(dataset.train, cfg.SOLVER.IMS_PER_BATCH, cfg.DATALOADER.NUM_INSTANCE), 79 | num_workers=num_workers, collate_fn=train_collate_fn 80 | ) 81 | elif cfg.DATALOADER.SAMPLER == 'softmax': 82 | print('using softmax sampler') 83 | train_loader = 
DataLoader( 84 | train_set, batch_size=cfg.SOLVER.IMS_PER_BATCH, shuffle=True, num_workers=num_workers, 85 | collate_fn=train_collate_fn 86 | ) 87 | else: 88 | print('unsupported sampler! expected softmax or triplet but got {}'.format(cfg.SAMPLER)) 89 | 90 | val_set = ImageDataset(dataset.query + dataset.gallery, val_transforms) 91 | 92 | val_loader = DataLoader( 93 | val_set, batch_size=cfg.TEST.IMS_PER_BATCH, shuffle=False, num_workers=num_workers, 94 | collate_fn=val_collate_fn 95 | ) 96 | train_loader_normal = DataLoader( 97 | train_set_normal, batch_size=cfg.TEST.IMS_PER_BATCH, shuffle=False, num_workers=num_workers, 98 | collate_fn=val_collate_fn 99 | ) 100 | return train_loader, train_loader_normal, val_loader, len(dataset.query), num_classes, cam_num, view_num 101 | -------------------------------------------------------------------------------- /datasets/market1501.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import glob 8 | import re 9 | 10 | import os.path as osp 11 | 12 | from .bases import BaseImageDataset 13 | from collections import defaultdict 14 | import pickle 15 | class Market1501(BaseImageDataset): 16 | """ 17 | Market1501 18 | Reference: 19 | Zheng et al. Scalable Person Re-identification: A Benchmark. ICCV 2015. 20 | URL: http://www.liangzheng.org/Project/project_reid.html 21 | 22 | Dataset statistics: 23 | # identities: 1501 (+1 for background) 24 | # images: 12936 (train) + 3368 (query) + 15913 (gallery) 25 | """ 26 | dataset_dir = 'market1501' 27 | val_dir = 'Occluded_REID' 28 | 29 | def __init__(self, root='', verbose=True, pid_begin = 0, **kwargs): 30 | super(Market1501, self).__init__() 31 | self.dataset_dir = osp.join(root, self.dataset_dir) 32 | self.val_dir = osp.join(root, self.val_dir) 33 | self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train') 34 | self.query_dir = osp.join(self.val_dir, 'query') 35 | self.gallery_dir = osp.join(self.val_dir, 'bounding_box_test') 36 | 37 | self._check_before_run() 38 | self.pid_begin = pid_begin 39 | train = self._process_dir(self.train_dir, relabel=True) 40 | query = self._process_valdir(self.query_dir, relabel=False) 41 | gallery = self._process_valdir(self.gallery_dir, relabel=False) 42 | 43 | if verbose: 44 | print("=> Market1501 loaded") 45 | self.print_dataset_statistics(train, query, gallery) 46 | 47 | self.train = train 48 | self.query = query 49 | self.gallery = gallery 50 | 51 | self.num_train_pids, self.num_train_imgs, self.num_train_cams, self.num_train_vids = self.get_imagedata_info(self.train) 52 | self.num_query_pids, self.num_query_imgs, self.num_query_cams, self.num_query_vids = self.get_imagedata_info(self.query) 53 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams, self.num_gallery_vids = self.get_imagedata_info(self.gallery) 54 | 55 | def _check_before_run(self): 56 | """Check if all files are available before going deeper""" 57 | if not osp.exists(self.dataset_dir): 58 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 59 | if not osp.exists(self.train_dir): 60 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 61 | if not osp.exists(self.query_dir): 62 | raise RuntimeError("'{}' is not available".format(self.query_dir)) 63 | if not osp.exists(self.gallery_dir): 64 | raise RuntimeError("'{}' is not available".format(self.gallery_dir)) 65 | 66 | def _process_dir(self, dir_path, 
relabel=False): 67 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 68 | pattern = re.compile(r'([-\d]+)_c(\d)') 69 | 70 | pid_container = set() 71 | for img_path in sorted(img_paths): 72 | pid, _ = map(int, pattern.search(img_path).groups()) 73 | if pid == -1: continue # junk images are just ignored 74 | pid_container.add(pid) 75 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 76 | dataset = [] 77 | for img_path in sorted(img_paths): 78 | pid, camid = map(int, pattern.search(img_path).groups()) 79 | if pid == -1: continue # junk images are just ignored 80 | assert 0 <= pid <= 1501 # pid == 0 means background 81 | assert 1 <= camid <= 6 82 | camid -= 1 # index starts from 0 83 | if relabel: pid = pid2label[pid] 84 | 85 | dataset.append((img_path, self.pid_begin + pid, camid, 1)) 86 | return dataset 87 | 88 | def _process_valdir(self, dir_path, relabel=False): 89 | img_paths = glob.glob(osp.join(dir_path, '*.tif')) 90 | pattern = re.compile(r'([-\d]+)_c(\d)') 91 | 92 | pid_container = set() 93 | for img_path in sorted(img_paths): 94 | pid, _ = map(int, pattern.search(img_path).groups()) 95 | if pid == -1: continue # junk images are just ignored 96 | pid_container.add(pid) 97 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 98 | dataset = [] 99 | for img_path in sorted(img_paths): 100 | pid, camid = map(int, pattern.search(img_path).groups()) 101 | if pid == -1: continue # junk images are just ignored 102 | assert 0 <= pid <= 1501 # pid == 0 means background 103 | assert 1 <= camid <= 6 104 | camid -= 1 # index starts from 0 105 | if relabel: pid = pid2label[pid] 106 | 107 | dataset.append((img_path, self.pid_begin + pid, camid, 1)) 108 | return dataset 109 | -------------------------------------------------------------------------------- /datasets/occ_duke.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: liaoxingyu2@jd.com 5 | """ 6 | 7 | import glob 8 | import re 9 | import urllib 10 | import zipfile 11 | 12 | import os.path as osp 13 | 14 | from utils.iotools import mkdir_if_missing 15 | from .bases import BaseImageDataset 16 | 17 | 18 | class OCC_DukeMTMCreID(BaseImageDataset): 19 | """ 20 | DukeMTMC-reID 21 | Reference: 22 | 1. Ristani et al. Performance Measures and a Data Set for Multi-Target, Multi-Camera Tracking. ECCVW 2016. 23 | 2. Zheng et al. Unlabeled Samples Generated by GAN Improve the Person Re-identification Baseline in vitro. ICCV 2017. 
24 | URL: https://github.com/layumi/DukeMTMC-reID_evaluation 25 | 26 | Dataset statistics: 27 | # identities: 1404 (train + query) 28 | # images:16522 (train) + 2228 (query) + 17661 (gallery) 29 | # cameras: 8 30 | """ 31 | dataset_dir = 'Occluded_Duke' 32 | 33 | def __init__(self, root='', verbose=True, pid_begin=0, **kwargs): 34 | super(OCC_DukeMTMCreID, self).__init__() 35 | self.dataset_dir = osp.join(root, self.dataset_dir) 36 | self.dataset_url = 'http://vision.cs.duke.edu/DukeMTMC/data/misc/DukeMTMC-reID.zip' 37 | self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train') 38 | self.query_dir = osp.join(self.dataset_dir, 'query') 39 | self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test') 40 | self.pid_begin = pid_begin 41 | self._download_data() 42 | self._check_before_run() 43 | 44 | train = self._process_dir(self.train_dir, relabel=True) 45 | query = self._process_dir(self.query_dir, relabel=False) 46 | gallery = self._process_dir(self.gallery_dir, relabel=False) 47 | 48 | if verbose: 49 | print("=> DukeMTMC-reID loaded") 50 | self.print_dataset_statistics(train, query, gallery) 51 | 52 | self.train = train 53 | self.query = query 54 | self.gallery = gallery 55 | 56 | self.num_train_pids, self.num_train_imgs, self.num_train_cams, self.num_train_vids = self.get_imagedata_info(self.train) 57 | self.num_query_pids, self.num_query_imgs, self.num_query_cams, self.num_query_vids = self.get_imagedata_info(self.query) 58 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams, self.num_gallery_vids = self.get_imagedata_info(self.gallery) 59 | 60 | def _download_data(self): 61 | if osp.exists(self.dataset_dir): 62 | print("This dataset has been downloaded.") 63 | return 64 | 65 | print("Creating directory {}".format(self.dataset_dir)) 66 | mkdir_if_missing(self.dataset_dir) 67 | fpath = osp.join(self.dataset_dir, osp.basename(self.dataset_url)) 68 | 69 | print("Downloading DukeMTMC-reID dataset") 70 | urllib.request.urlretrieve(self.dataset_url, fpath) 71 | 72 | print("Extracting files") 73 | zip_ref = zipfile.ZipFile(fpath, 'r') 74 | zip_ref.extractall(self.dataset_dir) 75 | zip_ref.close() 76 | 77 | def _check_before_run(self): 78 | """Check if all files are available before going deeper""" 79 | if not osp.exists(self.dataset_dir): 80 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 81 | if not osp.exists(self.train_dir): 82 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 83 | if not osp.exists(self.query_dir): 84 | raise RuntimeError("'{}' is not available".format(self.query_dir)) 85 | if not osp.exists(self.gallery_dir): 86 | raise RuntimeError("'{}' is not available".format(self.gallery_dir)) 87 | 88 | def _process_dir(self, dir_path, relabel=False): 89 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 90 | pattern = re.compile(r'([-\d]+)_c(\d)') 91 | 92 | pid_container = set() 93 | for img_path in img_paths: 94 | pid, _ = map(int, pattern.search(img_path).groups()) 95 | pid_container.add(pid) 96 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 97 | 98 | dataset = [] 99 | cam_container = set() 100 | for img_path in img_paths: 101 | pid, camid = map(int, pattern.search(img_path).groups()) 102 | assert 1 <= camid <= 8 103 | camid -= 1 # index starts from 0 104 | if relabel: pid = pid2label[pid] 105 | dataset.append((img_path, self.pid_begin + pid, camid, 1)) 106 | cam_container.add(camid) 107 | print(cam_container, 'cam_container') 108 | return dataset 109 | 
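The dataset classes above are not instantiated directly; `make_dataloader` (shown earlier) looks them up in `__factory` via `cfg.DATASETS.NAMES` and wraps them in the train/val loaders. A minimal usage sketch, assuming the repository root as working directory and the Occluded-Duke data prepared as in the README; it mirrors how `train.py`/`test.py` are expected to call it and is not code taken from them:

```python
# Minimal sketch: load the OCC_Duke config and build the loaders (assumes data/ is prepared).
from config import cfg
from datasets import make_dataloader

cfg.merge_from_file('configs/OCC_Duke/dpm.yml')   # sets DATASETS.NAMES to 'occ_duke'
cfg.freeze()

(train_loader, train_loader_normal, val_loader,
 num_query, num_classes, cam_num, view_num) = make_dataloader(cfg)
print(num_classes, cam_num, view_num)  # identity, camera and view counts reported by the dataset
```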
-------------------------------------------------------------------------------- /datasets/preprocessing.py: -------------------------------------------------------------------------------- 1 | import random 2 | import math 3 | 4 | 5 | class RandomErasing(object): 6 | """ Randomly selects a rectangle region in an image and erases its pixels. 7 | 'Random Erasing Data Augmentation' by Zhong et al. 8 | See https://arxiv.org/pdf/1708.04896.pdf 9 | Args: 10 | probability: The probability that the Random Erasing operation will be performed. 11 | sl: Minimum proportion of erased area against input image. 12 | sh: Maximum proportion of erased area against input image. 13 | r1: Minimum aspect ratio of erased area. 14 | mean: Erasing value. 15 | """ 16 | 17 | def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=(0.4914, 0.4822, 0.4465)): 18 | self.probability = probability 19 | self.mean = mean 20 | self.sl = sl 21 | self.sh = sh 22 | self.r1 = r1 23 | 24 | def __call__(self, img): 25 | 26 | if random.uniform(0, 1) >= self.probability: 27 | return img 28 | 29 | for attempt in range(100): 30 | area = img.size()[1] * img.size()[2] 31 | 32 | target_area = random.uniform(self.sl, self.sh) * area 33 | aspect_ratio = random.uniform(self.r1, 1 / self.r1) 34 | 35 | h = int(round(math.sqrt(target_area * aspect_ratio))) 36 | w = int(round(math.sqrt(target_area / aspect_ratio))) 37 | 38 | if w < img.size()[2] and h < img.size()[1]: 39 | x1 = random.randint(0, img.size()[1] - h) 40 | y1 = random.randint(0, img.size()[2] - w) 41 | if img.size()[0] == 3: 42 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0] 43 | img[1, x1:x1 + h, y1:y1 + w] = self.mean[1] 44 | img[2, x1:x1 + h, y1:y1 + w] = self.mean[2] 45 | else: 46 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0] 47 | return img 48 | 49 | return img 50 | 51 | -------------------------------------------------------------------------------- /datasets/sampler.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data.sampler import Sampler 2 | from collections import defaultdict 3 | import copy 4 | import random 5 | import numpy as np 6 | 7 | class RandomIdentitySampler(Sampler): 8 | """ 9 | Randomly sample N identities, then for each identity, 10 | randomly sample K instances, therefore batch size is N*K. 11 | Args: 12 | - data_source (list): list of (img_path, pid, camid). 13 | - num_instances (int): number of instances per identity in a batch. 14 | - batch_size (int): number of examples in a batch. 
15 | """ 16 | 17 | def __init__(self, data_source, batch_size, num_instances): 18 | self.data_source = data_source 19 | self.batch_size = batch_size 20 | self.num_instances = num_instances 21 | self.num_pids_per_batch = self.batch_size // self.num_instances 22 | self.index_dic = defaultdict(list) #dict with list value 23 | #{783: [0, 5, 116, 876, 1554, 2041],...,} 24 | for index, (_, pid, _, _) in enumerate(self.data_source): 25 | self.index_dic[pid].append(index) 26 | self.pids = list(self.index_dic.keys()) 27 | 28 | # estimate number of examples in an epoch 29 | self.length = 0 30 | for pid in self.pids: 31 | idxs = self.index_dic[pid] 32 | num = len(idxs) 33 | if num < self.num_instances: 34 | num = self.num_instances 35 | self.length += num - num % self.num_instances 36 | 37 | def __iter__(self): 38 | batch_idxs_dict = defaultdict(list) 39 | 40 | for pid in self.pids: 41 | idxs = copy.deepcopy(self.index_dic[pid]) 42 | if len(idxs) < self.num_instances: 43 | idxs = np.random.choice(idxs, size=self.num_instances, replace=True) 44 | random.shuffle(idxs) 45 | batch_idxs = [] 46 | for idx in idxs: 47 | batch_idxs.append(idx) 48 | if len(batch_idxs) == self.num_instances: 49 | batch_idxs_dict[pid].append(batch_idxs) 50 | batch_idxs = [] 51 | 52 | avai_pids = copy.deepcopy(self.pids) 53 | final_idxs = [] 54 | 55 | while len(avai_pids) >= self.num_pids_per_batch: 56 | selected_pids = random.sample(avai_pids, self.num_pids_per_batch) 57 | for pid in selected_pids: 58 | batch_idxs = batch_idxs_dict[pid].pop(0) 59 | final_idxs.extend(batch_idxs) 60 | if len(batch_idxs_dict[pid]) == 0: 61 | avai_pids.remove(pid) 62 | 63 | return iter(final_idxs) 64 | 65 | def __len__(self): 66 | return self.length 67 | 68 | -------------------------------------------------------------------------------- /datasets/sampler_ddp.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data.sampler import Sampler 2 | from collections import defaultdict 3 | import copy 4 | import random 5 | import numpy as np 6 | import math 7 | import torch.distributed as dist 8 | _LOCAL_PROCESS_GROUP = None 9 | import torch 10 | import pickle 11 | 12 | def _get_global_gloo_group(): 13 | """ 14 | Return a process group based on gloo backend, containing all the ranks 15 | The result is cached. 16 | """ 17 | if dist.get_backend() == "nccl": 18 | return dist.new_group(backend="gloo") 19 | else: 20 | return dist.group.WORLD 21 | 22 | def _serialize_to_tensor(data, group): 23 | backend = dist.get_backend(group) 24 | assert backend in ["gloo", "nccl"] 25 | device = torch.device("cpu" if backend == "gloo" else "cuda") 26 | 27 | buffer = pickle.dumps(data) 28 | if len(buffer) > 1024 ** 3: 29 | print( 30 | "Rank {} trying to all-gather {:.2f} GB of data on device {}".format( 31 | dist.get_rank(), len(buffer) / (1024 ** 3), device 32 | ) 33 | ) 34 | storage = torch.ByteStorage.from_buffer(buffer) 35 | tensor = torch.ByteTensor(storage).to(device=device) 36 | return tensor 37 | 38 | def _pad_to_largest_tensor(tensor, group): 39 | """ 40 | Returns: 41 | list[int]: size of the tensor, on each rank 42 | Tensor: padded tensor that has the max size 43 | """ 44 | world_size = dist.get_world_size(group=group) 45 | assert ( 46 | world_size >= 1 47 | ), "comm.gather/all_gather must be called from ranks within the given group!" 
48 | local_size = torch.tensor([tensor.numel()], dtype=torch.int64, device=tensor.device) 49 | size_list = [ 50 | torch.zeros([1], dtype=torch.int64, device=tensor.device) for _ in range(world_size) 51 | ] 52 | dist.all_gather(size_list, local_size, group=group) 53 | size_list = [int(size.item()) for size in size_list] 54 | 55 | max_size = max(size_list) 56 | 57 | # we pad the tensor because torch all_gather does not support 58 | # gathering tensors of different shapes 59 | if local_size != max_size: 60 | padding = torch.zeros((max_size - local_size,), dtype=torch.uint8, device=tensor.device) 61 | tensor = torch.cat((tensor, padding), dim=0) 62 | return size_list, tensor 63 | 64 | def all_gather(data, group=None): 65 | """ 66 | Run all_gather on arbitrary picklable data (not necessarily tensors). 67 | Args: 68 | data: any picklable object 69 | group: a torch process group. By default, will use a group which 70 | contains all ranks on gloo backend. 71 | Returns: 72 | list[data]: list of data gathered from each rank 73 | """ 74 | if dist.get_world_size() == 1: 75 | return [data] 76 | if group is None: 77 | group = _get_global_gloo_group() 78 | if dist.get_world_size(group) == 1: 79 | return [data] 80 | 81 | tensor = _serialize_to_tensor(data, group) 82 | 83 | size_list, tensor = _pad_to_largest_tensor(tensor, group) 84 | max_size = max(size_list) 85 | 86 | # receiving Tensor from all ranks 87 | tensor_list = [ 88 | torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list 89 | ] 90 | dist.all_gather(tensor_list, tensor, group=group) 91 | 92 | data_list = [] 93 | for size, tensor in zip(size_list, tensor_list): 94 | buffer = tensor.cpu().numpy().tobytes()[:size] 95 | data_list.append(pickle.loads(buffer)) 96 | 97 | return data_list 98 | 99 | def shared_random_seed(): 100 | """ 101 | Returns: 102 | int: a random number that is the same across all workers. 103 | If workers need a shared RNG, they can use this shared seed to 104 | create one. 105 | All workers must call this function, otherwise it will deadlock. 106 | """ 107 | ints = np.random.randint(2 ** 31) 108 | all_ints = all_gather(ints) 109 | return all_ints[0] 110 | 111 | class RandomIdentitySampler_DDP(Sampler): 112 | """ 113 | Randomly sample N identities, then for each identity, 114 | randomly sample K instances, therefore batch size is N*K. 115 | Args: 116 | - data_source (list): list of (img_path, pid, camid). 117 | - num_instances (int): number of instances per identity in a batch. 118 | - batch_size (int): number of examples in a batch. 
119 | """ 120 | 121 | def __init__(self, data_source, batch_size, num_instances): 122 | self.data_source = data_source 123 | self.batch_size = batch_size 124 | self.world_size = dist.get_world_size() 125 | self.num_instances = num_instances 126 | self.mini_batch_size = self.batch_size // self.world_size 127 | self.num_pids_per_batch = self.mini_batch_size // self.num_instances 128 | self.index_dic = defaultdict(list) 129 | 130 | for index, (_, pid, _, _) in enumerate(self.data_source): 131 | self.index_dic[pid].append(index) 132 | self.pids = list(self.index_dic.keys()) 133 | 134 | # estimate number of examples in an epoch 135 | self.length = 0 136 | for pid in self.pids: 137 | idxs = self.index_dic[pid] 138 | num = len(idxs) 139 | if num < self.num_instances: 140 | num = self.num_instances 141 | self.length += num - num % self.num_instances 142 | 143 | self.rank = dist.get_rank() 144 | #self.world_size = dist.get_world_size() 145 | self.length //= self.world_size 146 | 147 | def __iter__(self): 148 | seed = shared_random_seed() 149 | np.random.seed(seed) 150 | self._seed = int(seed) 151 | final_idxs = self.sample_list() 152 | length = int(math.ceil(len(final_idxs) * 1.0 / self.world_size)) 153 | #final_idxs = final_idxs[self.rank * length:(self.rank + 1) * length] 154 | final_idxs = self.__fetch_current_node_idxs(final_idxs, length) 155 | self.length = len(final_idxs) 156 | return iter(final_idxs) 157 | 158 | 159 | def __fetch_current_node_idxs(self, final_idxs, length): 160 | total_num = len(final_idxs) 161 | block_num = (length // self.mini_batch_size) 162 | index_target = [] 163 | for i in range(0, block_num * self.world_size, self.world_size): 164 | index = range(self.mini_batch_size * self.rank + self.mini_batch_size * i, min(self.mini_batch_size * self.rank + self.mini_batch_size * (i+1), total_num)) 165 | index_target.extend(index) 166 | index_target_npy = np.array(index_target) 167 | final_idxs = list(np.array(final_idxs)[index_target_npy]) 168 | return final_idxs 169 | 170 | 171 | def sample_list(self): 172 | #np.random.seed(self._seed) 173 | avai_pids = copy.deepcopy(self.pids) 174 | batch_idxs_dict = {} 175 | 176 | batch_indices = [] 177 | while len(avai_pids) >= self.num_pids_per_batch: 178 | selected_pids = np.random.choice(avai_pids, self.num_pids_per_batch, replace=False).tolist() 179 | for pid in selected_pids: 180 | if pid not in batch_idxs_dict: 181 | idxs = copy.deepcopy(self.index_dic[pid]) 182 | if len(idxs) < self.num_instances: 183 | idxs = np.random.choice(idxs, size=self.num_instances, replace=True).tolist() 184 | np.random.shuffle(idxs) 185 | batch_idxs_dict[pid] = idxs 186 | 187 | avai_idxs = batch_idxs_dict[pid] 188 | for _ in range(self.num_instances): 189 | batch_indices.append(avai_idxs.pop(0)) 190 | 191 | if len(avai_idxs) < self.num_instances: avai_pids.remove(pid) 192 | 193 | return batch_indices 194 | 195 | def __len__(self): 196 | return self.length 197 | 198 | -------------------------------------------------------------------------------- /loss/__init__.py: -------------------------------------------------------------------------------- 1 | from .make_loss import make_loss 2 | from .arcface import ArcFace -------------------------------------------------------------------------------- /loss/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stone96123/DPM/a7bb2267b22de7cebfcabae5fd998c71b0f94b13/loss/__pycache__/__init__.cpython-37.pyc 
-------------------------------------------------------------------------------- /loss/__pycache__/arcface.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stone96123/DPM/a7bb2267b22de7cebfcabae5fd998c71b0f94b13/loss/__pycache__/arcface.cpython-37.pyc -------------------------------------------------------------------------------- /loss/__pycache__/center_loss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stone96123/DPM/a7bb2267b22de7cebfcabae5fd998c71b0f94b13/loss/__pycache__/center_loss.cpython-37.pyc -------------------------------------------------------------------------------- /loss/__pycache__/make_loss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stone96123/DPM/a7bb2267b22de7cebfcabae5fd998c71b0f94b13/loss/__pycache__/make_loss.cpython-37.pyc -------------------------------------------------------------------------------- /loss/__pycache__/metric_learning.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stone96123/DPM/a7bb2267b22de7cebfcabae5fd998c71b0f94b13/loss/__pycache__/metric_learning.cpython-37.pyc -------------------------------------------------------------------------------- /loss/__pycache__/softmax_loss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stone96123/DPM/a7bb2267b22de7cebfcabae5fd998c71b0f94b13/loss/__pycache__/softmax_loss.cpython-37.pyc -------------------------------------------------------------------------------- /loss/__pycache__/triplet_loss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stone96123/DPM/a7bb2267b22de7cebfcabae5fd998c71b0f94b13/loss/__pycache__/triplet_loss.cpython-37.pyc -------------------------------------------------------------------------------- /loss/arcface.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn import Parameter 5 | import math 6 | 7 | 8 | class ArcFace(nn.Module): 9 | def __init__(self, in_features, out_features, s=30.0, m=0.50, bias=False): 10 | super(ArcFace, self).__init__() 11 | self.in_features = in_features 12 | self.out_features = out_features 13 | self.s = s 14 | self.m = m 15 | self.cos_m = math.cos(m) 16 | self.sin_m = math.sin(m) 17 | 18 | self.th = math.cos(math.pi - m) 19 | self.mm = math.sin(math.pi - m) * m 20 | 21 | self.weight = Parameter(torch.Tensor(out_features, in_features)) 22 | if bias: 23 | self.bias = Parameter(torch.Tensor(out_features)) 24 | else: 25 | self.register_parameter('bias', None) 26 | self.reset_parameters() 27 | 28 | def reset_parameters(self): 29 | nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) 30 | if self.bias is not None: 31 | fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) 32 | bound = 1 / math.sqrt(fan_in) 33 | nn.init.uniform_(self.bias, -bound, bound) 34 | 35 | def forward(self, input, label): 36 | cosine = F.linear(F.normalize(input), F.normalize(self.weight)) 37 | sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(0, 1)) 38 | phi = cosine * self.cos_m - sine * self.sin_m 39 | phi = torch.where(cosine > self.th, phi, cosine - 
self.mm) 40 | # --------------------------- convert label to one-hot --------------------------- 41 | # one_hot = torch.zeros(cosine.size(), requires_grad=True, device='cuda') 42 | one_hot = torch.zeros(cosine.size(), device='cuda') 43 | one_hot.scatter_(1, label.view(-1, 1).long(), 1) 44 | # -------------torch.where(out_i = {x_i if condition_i else y_i) ------------- 45 | output = (one_hot * phi) + ( 46 | (1.0 - one_hot) * cosine) # you can use torch.where if your torch.__version__ is 0.4 47 | output *= self.s 48 | # print(output) 49 | 50 | return output 51 | 52 | class CircleLoss(nn.Module): 53 | def __init__(self, in_features, num_classes, s=256, m=0.25): 54 | super(CircleLoss, self).__init__() 55 | self.weight = Parameter(torch.Tensor(num_classes, in_features)) 56 | self.s = s 57 | self.m = m 58 | self._num_classes = num_classes 59 | self.reset_parameters() 60 | 61 | 62 | def reset_parameters(self): 63 | nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) 64 | 65 | def __call__(self, bn_feat, targets): 66 | 67 | sim_mat = F.linear(F.normalize(bn_feat), F.normalize(self.weight)) 68 | alpha_p = torch.clamp_min(-sim_mat.detach() + 1 + self.m, min=0.) 69 | alpha_n = torch.clamp_min(sim_mat.detach() + self.m, min=0.) 70 | delta_p = 1 - self.m 71 | delta_n = self.m 72 | 73 | s_p = self.s * alpha_p * (sim_mat - delta_p) 74 | s_n = self.s * alpha_n * (sim_mat - delta_n) 75 | 76 | targets = F.one_hot(targets, num_classes=self._num_classes) 77 | 78 | pred_class_logits = targets * s_p + (1.0 - targets) * s_n 79 | 80 | return pred_class_logits -------------------------------------------------------------------------------- /loss/center_loss.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | from torch import nn 5 | 6 | 7 | class CenterLoss(nn.Module): 8 | """Center loss. 9 | 10 | Reference: 11 | Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016. 12 | 13 | Args: 14 | num_classes (int): number of classes. 15 | feat_dim (int): feature dimension. 16 | """ 17 | 18 | def __init__(self, num_classes=751, feat_dim=2048, use_gpu=True): 19 | super(CenterLoss, self).__init__() 20 | self.num_classes = num_classes 21 | self.feat_dim = feat_dim 22 | self.use_gpu = use_gpu 23 | 24 | if self.use_gpu: 25 | self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim).cuda()) 26 | else: 27 | self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim)) 28 | 29 | def forward(self, x, labels): 30 | """ 31 | Args: 32 | x: feature matrix with shape (batch_size, feat_dim). 33 | labels: ground truth labels with shape (num_classes). 
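Note: in practice labels is a LongTensor of shape (batch_size,), one class index per sample (the assert below checks x.size(0) == labels.size(0)).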
34 | """ 35 | assert x.size(0) == labels.size(0), "features.size(0) is not equal to labels.size(0)" 36 | 37 | batch_size = x.size(0) 38 | distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \ 39 | torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t() 40 | distmat.addmm_(1, -2, x, self.centers.t()) 41 | 42 | classes = torch.arange(self.num_classes).long() 43 | if self.use_gpu: classes = classes.cuda() 44 | labels = labels.unsqueeze(1).expand(batch_size, self.num_classes) 45 | mask = labels.eq(classes.expand(batch_size, self.num_classes)) 46 | 47 | dist = [] 48 | for i in range(batch_size): 49 | value = distmat[i][mask[i]] 50 | value = value.clamp(min=1e-12, max=1e+12) # for numerical stability 51 | dist.append(value) 52 | dist = torch.cat(dist) 53 | loss = dist.mean() 54 | return loss 55 | 56 | 57 | if __name__ == '__main__': 58 | use_gpu = False 59 | center_loss = CenterLoss(use_gpu=use_gpu) 60 | features = torch.rand(16, 2048) 61 | targets = torch.Tensor([0, 1, 2, 3, 2, 3, 1, 4, 5, 3, 2, 1, 0, 0, 5, 4]).long() 62 | if use_gpu: 63 | features = torch.rand(16, 2048).cuda() 64 | targets = torch.Tensor([0, 1, 2, 3, 2, 3, 1, 4, 5, 3, 2, 1, 0, 0, 5, 4]).cuda() 65 | 66 | loss = center_loss(features, targets) 67 | print(loss) 68 | -------------------------------------------------------------------------------- /loss/make_loss.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | import torch 7 | import torch.nn.functional as F 8 | from .softmax_loss import CrossEntropyLabelSmooth, LabelSmoothingCrossEntropy 9 | from .triplet_loss import TripletLoss 10 | from .center_loss import CenterLoss 11 | 12 | def featureL2Norm(feature): 13 | epsilon = 1e-6 14 | norm = torch.pow(torch.sum(torch.pow(feature, 2), 2) + 15 | epsilon, 0.5).unsqueeze(2).expand_as(feature) 16 | return torch.div(feature, norm) 17 | 18 | def orthonomal_loss(w): 19 | B, K, C = w.shape 20 | w_norm = featureL2Norm(w) 21 | WWT = torch.matmul(w_norm, w_norm.transpose(1, 2)) 22 | return F.mse_loss(WWT - torch.eye(K).unsqueeze(0).cuda(), torch.zeros(B, K, K).cuda(), size_average=False) / (K*K) 23 | 24 | def make_loss(cfg, num_classes): # modified by gu 25 | sampler = cfg.DATALOADER.SAMPLER 26 | feat_dim = 2048 27 | center_criterion = CenterLoss(num_classes=num_classes, feat_dim=feat_dim, use_gpu=True) # center loss 28 | if 'triplet' in cfg.MODEL.METRIC_LOSS_TYPE: 29 | if cfg.MODEL.NO_MARGIN: 30 | triplet = TripletLoss() 31 | print("using soft triplet loss for training") 32 | else: 33 | triplet = TripletLoss(cfg.SOLVER.MARGIN) # triplet loss 34 | print("using triplet loss with margin:{}".format(cfg.SOLVER.MARGIN)) 35 | else: 36 | print('expected METRIC_LOSS_TYPE should be triplet' 37 | 'but got {}'.format(cfg.MODEL.METRIC_LOSS_TYPE)) 38 | 39 | if cfg.MODEL.IF_LABELSMOOTH == 'on': 40 | xent = CrossEntropyLabelSmooth(num_classes=num_classes) 41 | print("label smooth on, numclasses:", num_classes) 42 | 43 | if sampler == 'softmax': 44 | def loss_func(score, feat, target): 45 | return F.cross_entropy(score, target) 46 | 47 | elif cfg.DATALOADER.SAMPLER == 'softmax_triplet': 48 | def loss_func(score, Mscore, feat, orth_proto, epoch, target, target_cam): 49 | if cfg.MODEL.METRIC_LOSS_TYPE == 'triplet': 50 | if cfg.MODEL.IF_LABELSMOOTH == 'on': 51 | if isinstance(score, list): 52 | ID_LOSS = [xent(scor, target) for scor in score[1:]] 53 | 
ID_LOSS = sum(ID_LOSS) / len(ID_LOSS) 54 | ID_LOSS = 0.5 * ID_LOSS + 0.5 * xent(score[0], target) 55 | else: 56 | ID_LOSS = xent(score, target) 57 | 58 | if isinstance(feat, list): 59 | TRI_LOSS = [triplet(feats, target)[0] for feats in feat[1:]] 60 | TRI_LOSS = sum(TRI_LOSS) / len(TRI_LOSS) 61 | TRI_LOSS = 0.5 * TRI_LOSS + 0.5 * triplet(feat[0], target)[0] 62 | else: 63 | TRI_LOSS = triplet(feat, target)[0] 64 | 65 | return cfg.MODEL.ID_LOSS_WEIGHT * ID_LOSS + \ 66 | cfg.MODEL.TRIPLET_LOSS_WEIGHT * TRI_LOSS 67 | else: 68 | if isinstance(score, list): 69 | ID_LOSS = [F.cross_entropy(scor, target) for scor in score[1:]] 70 | ID_LOSS = sum(ID_LOSS) / len(ID_LOSS) 71 | ID_LOSS = 0.5 * ID_LOSS + 0.5 * F.cross_entropy(score[0], target) 72 | else: 73 | ID_LOSS = 0.5 * F.cross_entropy(score, target) + 0.5 * F.cross_entropy(Mscore, target) + 0.1 * orthonomal_loss(orth_proto) # Set the orthonomal_loss to 0.01 when testing on the Occluded-REID 74 | 75 | if isinstance(feat, list): 76 | TRI_LOSS = [triplet(feats, target)[0] for feats in feat[1:]] 77 | TRI_LOSS = sum(TRI_LOSS) / len(TRI_LOSS) 78 | TRI_LOSS = 0.5 * TRI_LOSS + 0.5 * triplet(feat[0], target)[0] 79 | else: 80 | TRI_LOSS = triplet(feat, target)[0] 81 | 82 | return cfg.MODEL.ID_LOSS_WEIGHT * ID_LOSS + \ 83 | cfg.MODEL.TRIPLET_LOSS_WEIGHT * TRI_LOSS 84 | else: 85 | print('expected METRIC_LOSS_TYPE should be triplet' 86 | 'but got {}'.format(cfg.MODEL.METRIC_LOSS_TYPE)) 87 | 88 | else: 89 | print('expected sampler should be softmax, triplet, softmax_triplet or softmax_triplet_center' 90 | 'but got {}'.format(cfg.DATALOADER.SAMPLER)) 91 | return loss_func, center_criterion 92 | 93 | 94 | -------------------------------------------------------------------------------- /loss/metric_learning.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.autograd 5 | from torch.nn import Parameter 6 | import math 7 | 8 | 9 | class ContrastiveLoss(nn.Module): 10 | def __init__(self, margin=0.3, **kwargs): 11 | super(ContrastiveLoss, self).__init__() 12 | self.margin = margin 13 | 14 | def forward(self, inputs, targets): 15 | n = inputs.size(0) 16 | # Compute similarity matrix 17 | sim_mat = torch.matmul(inputs, inputs.t()) 18 | targets = targets 19 | loss = list() 20 | c = 0 21 | 22 | for i in range(n): 23 | pos_pair_ = torch.masked_select(sim_mat[i], targets == targets[i]) 24 | 25 | # move itself 26 | pos_pair_ = torch.masked_select(pos_pair_, pos_pair_ < 1) 27 | neg_pair_ = torch.masked_select(sim_mat[i], targets != targets[i]) 28 | 29 | pos_pair_ = torch.sort(pos_pair_)[0] 30 | neg_pair_ = torch.sort(neg_pair_)[0] 31 | 32 | neg_pair = torch.masked_select(neg_pair_, neg_pair_ > self.margin) 33 | 34 | neg_loss = 0 35 | 36 | pos_loss = torch.sum(-pos_pair_ + 1) 37 | if len(neg_pair) > 0: 38 | neg_loss = torch.sum(neg_pair) 39 | loss.append(pos_loss + neg_loss) 40 | 41 | loss = sum(loss) / n 42 | return loss 43 | 44 | 45 | class CircleLoss(nn.Module): 46 | def __init__(self, in_features, num_classes, s=256, m=0.25): 47 | super(CircleLoss, self).__init__() 48 | self.weight = Parameter(torch.Tensor(num_classes, in_features)) 49 | self.s = s 50 | self.m = m 51 | self._num_classes = num_classes 52 | self.reset_parameters() 53 | 54 | 55 | def reset_parameters(self): 56 | nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) 57 | 58 | def __call__(self, bn_feat, targets): 59 | 60 | sim_mat = F.linear(F.normalize(bn_feat), 
F.normalize(self.weight)) 61 | alpha_p = torch.clamp_min(-sim_mat.detach() + 1 + self.m, min=0.) 62 | alpha_n = torch.clamp_min(sim_mat.detach() + self.m, min=0.) 63 | delta_p = 1 - self.m 64 | delta_n = self.m 65 | 66 | s_p = self.s * alpha_p * (sim_mat - delta_p) 67 | s_n = self.s * alpha_n * (sim_mat - delta_n) 68 | 69 | targets = F.one_hot(targets, num_classes=self._num_classes) 70 | 71 | pred_class_logits = targets * s_p + (1.0 - targets) * s_n 72 | 73 | return pred_class_logits 74 | 75 | 76 | class Arcface(nn.Module): 77 | r"""Implement of large margin arc distance: : 78 | Args: 79 | in_features: size of each input sample 80 | out_features: size of each output sample 81 | s: norm of input feature 82 | m: margin 83 | cos(theta + m) 84 | """ 85 | def __init__(self, in_features, out_features, s=30.0, m=0.30, easy_margin=False, ls_eps=0.0): 86 | super(Arcface, self).__init__() 87 | self.in_features = in_features 88 | self.out_features = out_features 89 | self.s = s 90 | self.m = m 91 | self.ls_eps = ls_eps # label smoothing 92 | self.weight = Parameter(torch.FloatTensor(out_features, in_features)) 93 | nn.init.xavier_uniform_(self.weight) 94 | 95 | self.easy_margin = easy_margin 96 | self.cos_m = math.cos(m) 97 | self.sin_m = math.sin(m) 98 | self.th = math.cos(math.pi - m) 99 | self.mm = math.sin(math.pi - m) * m 100 | 101 | def forward(self, input, label): 102 | # --------------------------- cos(theta) & phi(theta) --------------------------- 103 | cosine = F.linear(F.normalize(input), F.normalize(self.weight)) 104 | sine = torch.sqrt(1.0 - torch.pow(cosine, 2)) 105 | phi = cosine * self.cos_m - sine * self.sin_m 106 | phi = phi.type_as(cosine) 107 | if self.easy_margin: 108 | phi = torch.where(cosine > 0, phi, cosine) 109 | else: 110 | phi = torch.where(cosine > self.th, phi, cosine - self.mm) 111 | # --------------------------- convert label to one-hot --------------------------- 112 | # one_hot = torch.zeros(cosine.size(), requires_grad=True, device='cuda') 113 | one_hot = torch.zeros(cosine.size(), device='cuda') 114 | one_hot.scatter_(1, label.view(-1, 1).long(), 1) 115 | if self.ls_eps > 0: 116 | one_hot = (1 - self.ls_eps) * one_hot + self.ls_eps / self.out_features 117 | # -------------torch.where(out_i = {x_i if condition_i else y_i) ------------- 118 | output = (one_hot * phi) + ((1.0 - one_hot) * cosine) 119 | output *= self.s 120 | 121 | return output 122 | 123 | 124 | class MArcface(nn.Module): 125 | r"""Implement of large margin arc distance: : 126 | Args: 127 | in_features: size of each input sample 128 | out_features: size of each output sample 129 | s: norm of input feature 130 | m: margin 131 | cos(theta + m) 132 | """ 133 | def __init__(self, in_features, out_features, s=30.0, m=0.30, easy_margin=False, ls_eps=0.0): 134 | super(MArcface, self).__init__() 135 | self.in_features = in_features 136 | self.out_features = out_features 137 | self.s = s 138 | self.m = m 139 | self.ls_eps = ls_eps # label smoothing 140 | self.weight = Parameter(torch.FloatTensor(out_features, in_features)) 141 | nn.init.xavier_uniform_(self.weight) 142 | 143 | self.easy_margin = easy_margin 144 | self.cos_m = math.cos(m) 145 | self.sin_m = math.sin(m) 146 | self.th = math.cos(math.pi - m) 147 | self.mm = math.sin(math.pi - m) * m 148 | 149 | def forward(self, input, mask, label): 150 | output_ori = F.linear(input, self.weight) 151 | # --------------------------- cos(theta) & phi(theta) --------------------------- 152 | for i in range(input.shape[0]): 153 | if i == 0: 154 | query_vector = 
F.normalize(input[i,:].unsqueeze(0)) 155 | gallery_prototype = mask[i,:].unsqueeze(0) * F.normalize(self.weight) 156 | cosine = F.linear(F.normalize(query_vector), F.normalize(gallery_prototype)) 157 | else: 158 | query_vector = F.normalize(input[i,:].unsqueeze(0)) 159 | gallery_prototype = mask[i,:].unsqueeze(0) * F.normalize(self.weight) 160 | cosine_single = F.linear(F.normalize(query_vector), F.normalize(gallery_prototype)) 161 | cosine = torch.cat((cosine, cosine_single), 0) 162 | sine = torch.sqrt(1.0 - torch.pow(cosine, 2)) 163 | phi = cosine * self.cos_m - sine * self.sin_m 164 | phi = phi.type_as(cosine) 165 | if self.easy_margin: 166 | phi = torch.where(cosine > 0, phi, cosine) 167 | else: 168 | phi = torch.where(cosine > self.th, phi, cosine - self.mm) 169 | # --------------------------- convert label to one-hot --------------------------- 170 | # one_hot = torch.zeros(cosine.size(), requires_grad=True, device='cuda') 171 | one_hot = torch.zeros(cosine.size(), device='cuda') 172 | one_hot.scatter_(1, label.view(-1, 1).long(), 1) 173 | if self.ls_eps > 0: 174 | one_hot = (1 - self.ls_eps) * one_hot + self.ls_eps / self.out_features 175 | # -------------torch.where(out_i = {x_i if condition_i else y_i) ------------- 176 | output = (one_hot * phi) + ((1.0 - one_hot) * cosine) 177 | output *= self.s 178 | 179 | return output_ori, output 180 | 181 | 182 | class Cosface(nn.Module): 183 | r"""Implement of large margin cosine distance: : 184 | Args: 185 | in_features: size of each input sample 186 | out_features: size of each output sample 187 | s: norm of input feature 188 | m: margin 189 | cos(theta) - m 190 | """ 191 | 192 | def __init__(self, in_features, out_features, s=30.0, m=0.30): 193 | super(Cosface, self).__init__() 194 | self.in_features = in_features 195 | self.out_features = out_features 196 | self.s = s 197 | self.m = m 198 | self.weight = Parameter(torch.FloatTensor(out_features, in_features)) 199 | nn.init.xavier_uniform_(self.weight) 200 | 201 | def forward(self, input, label): 202 | # --------------------------- cos(theta) & phi(theta) --------------------------- 203 | cosine = F.linear(F.normalize(input), F.normalize(self.weight)) 204 | phi = cosine - self.m 205 | # --------------------------- convert label to one-hot --------------------------- 206 | one_hot = torch.zeros(cosine.size(), device='cuda') 207 | # one_hot = one_hot.cuda() if cosine.is_cuda else one_hot 208 | one_hot.scatter_(1, label.view(-1, 1).long(), 1) 209 | # -------------torch.where(out_i = {x_i if condition_i else y_i) ------------- 210 | output = (one_hot * phi) + ((1.0 - one_hot) * cosine) # you can use torch.where if your torch.__version__ is 0.4 211 | output *= self.s 212 | # print(output) 213 | 214 | return output 215 | 216 | def __repr__(self): 217 | return self.__class__.__name__ + '(' \ 218 | + 'in_features=' + str(self.in_features) \ 219 | + ', out_features=' + str(self.out_features) \ 220 | + ', s=' + str(self.s) \ 221 | + ', m=' + str(self.m) + ')' 222 | 223 | 224 | class AMSoftmax(nn.Module): 225 | def __init__(self, in_features, out_features, s=30.0, m=0.30): 226 | super(AMSoftmax, self).__init__() 227 | self.m = m 228 | self.s = s 229 | self.in_feats = in_features 230 | self.W = torch.nn.Parameter(torch.randn(in_features, out_features), requires_grad=True) 231 | self.ce = nn.CrossEntropyLoss() 232 | nn.init.xavier_normal_(self.W, gain=1) 233 | 234 | def forward(self, x, lb): 235 | assert x.size()[0] == lb.size()[0] 236 | assert x.size()[1] == self.in_feats 237 | x_norm = 
torch.norm(x, p=2, dim=1, keepdim=True).clamp(min=1e-12) 238 | x_norm = torch.div(x, x_norm) 239 | w_norm = torch.norm(self.W, p=2, dim=0, keepdim=True).clamp(min=1e-12) 240 | w_norm = torch.div(self.W, w_norm) 241 | costh = torch.mm(x_norm, w_norm) 242 | # print(x_norm.shape, w_norm.shape, costh.shape) 243 | lb_view = lb.view(-1, 1) 244 | delt_costh = torch.zeros(costh.size(), device='cuda').scatter_(1, lb_view, self.m) 245 | costh_m = costh - delt_costh 246 | costh_m_s = self.s * costh_m 247 | return costh_m_s -------------------------------------------------------------------------------- /loss/softmax_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | class CrossEntropyLabelSmooth(nn.Module): 5 | """Cross entropy loss with label smoothing regularizer. 6 | 7 | Reference: 8 | Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016. 9 | Equation: y = (1 - epsilon) * y + epsilon / K. 10 | 11 | Args: 12 | num_classes (int): number of classes. 13 | epsilon (float): weight. 14 | """ 15 | 16 | def __init__(self, num_classes, epsilon=0.1, use_gpu=True): 17 | super(CrossEntropyLabelSmooth, self).__init__() 18 | self.num_classes = num_classes 19 | self.epsilon = epsilon 20 | self.use_gpu = use_gpu 21 | self.logsoftmax = nn.LogSoftmax(dim=1) 22 | 23 | def forward(self, inputs, targets): 24 | """ 25 | Args: 26 | inputs: prediction matrix (before softmax) with shape (batch_size, num_classes) 27 | targets: ground truth labels with shape (num_classes) 28 | """ 29 | log_probs = self.logsoftmax(inputs) 30 | targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).data.cpu(), 1) 31 | if self.use_gpu: targets = targets.cuda() 32 | targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes 33 | loss = (- targets * log_probs).mean(0).sum() 34 | return loss 35 | 36 | class LabelSmoothingCrossEntropy(nn.Module): 37 | """ 38 | NLL loss with label smoothing. 39 | """ 40 | def __init__(self, smoothing=0.1): 41 | """ 42 | Constructor for the LabelSmoothing module. 43 | :param smoothing: label smoothing factor 44 | """ 45 | super(LabelSmoothingCrossEntropy, self).__init__() 46 | assert smoothing < 1.0 47 | self.smoothing = smoothing 48 | self.confidence = 1. - smoothing 49 | 50 | def forward(self, x, target): 51 | logprobs = F.log_softmax(x, dim=-1) 52 | nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1)) 53 | nll_loss = nll_loss.squeeze(1) 54 | smooth_loss = -logprobs.mean(dim=-1) 55 | loss = self.confidence * nll_loss + self.smoothing * smooth_loss 56 | return loss.mean() -------------------------------------------------------------------------------- /loss/triplet_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | def normalize(x, axis=-1): 6 | """Normalizing to unit length along the specified dimension. 7 | Args: 8 | x: pytorch Variable 9 | Returns: 10 | x: pytorch Variable, same shape as input 11 | """ 12 | x = 1. 
* x / (torch.norm(x, 2, axis, keepdim=True).expand_as(x) + 1e-12) 13 | return x 14 | 15 | 16 | def euclidean_dist(x, y): 17 | """ 18 | Args: 19 | x: pytorch Variable, with shape [m, d] 20 | y: pytorch Variable, with shape [n, d] 21 | Returns: 22 | dist: pytorch Variable, with shape [m, n] 23 | """ 24 | m, n = x.size(0), y.size(0) 25 | xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n) 26 | yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t() 27 | dist = xx + yy 28 | dist = dist - 2 * torch.matmul(x, y.t()) 29 | # dist.addmm_(1, -2, x, y.t()) 30 | dist = dist.clamp(min=1e-12).sqrt() # for numerical stability 31 | return dist 32 | 33 | 34 | def cosine_dist(x, y): 35 | """ 36 | Args: 37 | x: pytorch Variable, with shape [m, d] 38 | y: pytorch Variable, with shape [n, d] 39 | Returns: 40 | dist: pytorch Variable, with shape [m, n] 41 | """ 42 | m, n = x.size(0), y.size(0) 43 | x_norm = torch.pow(x, 2).sum(1, keepdim=True).sqrt().expand(m, n) 44 | y_norm = torch.pow(y, 2).sum(1, keepdim=True).sqrt().expand(n, m).t() 45 | xy_intersection = torch.mm(x, y.t()) 46 | dist = xy_intersection/(x_norm * y_norm) 47 | dist = (1. - dist) / 2 48 | return dist 49 | 50 | 51 | def hard_example_mining(dist_mat, labels, return_inds=False): 52 | """For each anchor, find the hardest positive and negative sample. 53 | Args: 54 | dist_mat: pytorch Variable, pair wise distance between samples, shape [N, N] 55 | labels: pytorch LongTensor, with shape [N] 56 | return_inds: whether to return the indices. Save time if `False`(?) 57 | Returns: 58 | dist_ap: pytorch Variable, distance(anchor, positive); shape [N] 59 | dist_an: pytorch Variable, distance(anchor, negative); shape [N] 60 | p_inds: pytorch LongTensor, with shape [N]; 61 | indices of selected hard positive samples; 0 <= p_inds[i] <= N - 1 62 | n_inds: pytorch LongTensor, with shape [N]; 63 | indices of selected hard negative samples; 0 <= n_inds[i] <= N - 1 64 | NOTE: Only consider the case in which all labels have same num of samples, 65 | thus we can cope with all anchors in parallel. 
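For example, with a P x K identity-balanced batch each row of dist_mat[is_pos] has exactly K entries (including the anchor itself, whose zero self-distance does not affect the max), so the .view(N, -1) calls below are well defined.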
66 | """ 67 | 68 | assert len(dist_mat.size()) == 2 69 | assert dist_mat.size(0) == dist_mat.size(1) 70 | N = dist_mat.size(0) 71 | 72 | # shape [N, N] 73 | is_pos = labels.expand(N, N).eq(labels.expand(N, N).t()) 74 | is_neg = labels.expand(N, N).ne(labels.expand(N, N).t()) 75 | 76 | # `dist_ap` means distance(anchor, positive) 77 | # both `dist_ap` and `relative_p_inds` with shape [N, 1] 78 | dist_ap, relative_p_inds = torch.max( 79 | dist_mat[is_pos].contiguous().view(N, -1), 1, keepdim=True) 80 | # print(dist_mat[is_pos].shape) 81 | # `dist_an` means distance(anchor, negative) 82 | # both `dist_an` and `relative_n_inds` with shape [N, 1] 83 | dist_an, relative_n_inds = torch.min( 84 | dist_mat[is_neg].contiguous().view(N, -1), 1, keepdim=True) 85 | # shape [N] 86 | dist_ap = dist_ap.squeeze(1) 87 | dist_an = dist_an.squeeze(1) 88 | 89 | if return_inds: 90 | # shape [N, N] 91 | ind = (labels.new().resize_as_(labels) 92 | .copy_(torch.arange(0, N).long()) 93 | .unsqueeze(0).expand(N, N)) 94 | # shape [N, 1] 95 | p_inds = torch.gather( 96 | ind[is_pos].contiguous().view(N, -1), 1, relative_p_inds.data) 97 | n_inds = torch.gather( 98 | ind[is_neg].contiguous().view(N, -1), 1, relative_n_inds.data) 99 | # shape [N] 100 | p_inds = p_inds.squeeze(1) 101 | n_inds = n_inds.squeeze(1) 102 | return dist_ap, dist_an, p_inds, n_inds 103 | 104 | return dist_ap, dist_an 105 | 106 | 107 | class TripletLoss(object): 108 | """ 109 | Triplet loss using HARDER example mining, 110 | modified based on original triplet loss using hard example mining 111 | """ 112 | 113 | def __init__(self, margin=None, hard_factor=0.0): 114 | self.margin = margin 115 | self.hard_factor = hard_factor 116 | if margin is not None: 117 | self.ranking_loss = nn.MarginRankingLoss(margin=margin) 118 | else: 119 | self.ranking_loss = nn.SoftMarginLoss() 120 | 121 | def __call__(self, global_feat, labels, normalize_feature=False): 122 | if normalize_feature: 123 | global_feat = normalize(global_feat, axis=-1) 124 | dist_mat = euclidean_dist(global_feat, global_feat) 125 | dist_ap, dist_an = hard_example_mining(dist_mat, labels) 126 | 127 | dist_ap *= (1.0 + self.hard_factor) 128 | dist_an *= (1.0 - self.hard_factor) 129 | 130 | y = dist_an.new().resize_as_(dist_an).fill_(1) 131 | if self.margin is not None: 132 | loss = self.ranking_loss(dist_an, dist_ap, y) 133 | else: 134 | loss = self.ranking_loss(dist_an - dist_ap, y) 135 | return loss, dist_ap, dist_an 136 | 137 | 138 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from .make_model import make_model -------------------------------------------------------------------------------- /model/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stone96123/DPM/a7bb2267b22de7cebfcabae5fd998c71b0f94b13/model/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /model/__pycache__/make_model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stone96123/DPM/a7bb2267b22de7cebfcabae5fd998c71b0f94b13/model/__pycache__/make_model.cpython-37.pyc -------------------------------------------------------------------------------- /model/backbones/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/stone96123/DPM/a7bb2267b22de7cebfcabae5fd998c71b0f94b13/model/backbones/__init__.py -------------------------------------------------------------------------------- /model/backbones/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stone96123/DPM/a7bb2267b22de7cebfcabae5fd998c71b0f94b13/model/backbones/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /model/backbones/__pycache__/resnet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stone96123/DPM/a7bb2267b22de7cebfcabae5fd998c71b0f94b13/model/backbones/__pycache__/resnet.cpython-37.pyc -------------------------------------------------------------------------------- /model/backbones/__pycache__/vit_pytorch.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stone96123/DPM/a7bb2267b22de7cebfcabae5fd998c71b0f94b13/model/backbones/__pycache__/vit_pytorch.cpython-37.pyc -------------------------------------------------------------------------------- /model/backbones/resnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch import nn 5 | 6 | 7 | def conv3x3(in_planes, out_planes, stride=1): 8 | """3x3 convolution with padding""" 9 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 10 | padding=1, bias=False) 11 | 12 | 13 | class BasicBlock(nn.Module): 14 | expansion = 1 15 | 16 | def __init__(self, inplanes, planes, stride=1, downsample=None): 17 | super(BasicBlock, self).__init__() 18 | self.conv1 = conv3x3(inplanes, planes, stride) 19 | self.bn1 = nn.BatchNorm2d(planes) 20 | self.relu = nn.ReLU(inplace=True) 21 | self.conv2 = conv3x3(planes, planes) 22 | self.bn2 = nn.BatchNorm2d(planes) 23 | self.downsample = downsample 24 | self.stride = stride 25 | 26 | def forward(self, x): 27 | residual = x 28 | 29 | out = self.conv1(x) 30 | out = self.bn1(out) 31 | out = self.relu(out) 32 | 33 | out = self.conv2(out) 34 | out = self.bn2(out) 35 | 36 | if self.downsample is not None: 37 | residual = self.downsample(x) 38 | 39 | out += residual 40 | out = self.relu(out) 41 | 42 | return out 43 | 44 | 45 | class Bottleneck(nn.Module): 46 | expansion = 4 47 | 48 | def __init__(self, inplanes, planes, stride=1, downsample=None): 49 | super(Bottleneck, self).__init__() 50 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 51 | self.bn1 = nn.BatchNorm2d(planes) 52 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 53 | padding=1, bias=False) 54 | self.bn2 = nn.BatchNorm2d(planes) 55 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 56 | self.bn3 = nn.BatchNorm2d(planes * 4) 57 | self.relu = nn.ReLU(inplace=True) 58 | self.downsample = downsample 59 | self.stride = stride 60 | 61 | def forward(self, x): 62 | residual = x 63 | 64 | out = self.conv1(x) 65 | out = self.bn1(out) 66 | out = self.relu(out) 67 | 68 | out = self.conv2(out) 69 | out = self.bn2(out) 70 | out = self.relu(out) 71 | 72 | out = self.conv3(out) 73 | out = self.bn3(out) 74 | 75 | if self.downsample is not None: 76 | residual = self.downsample(x) 77 | 78 | out += residual 79 | out = 
self.relu(out) 80 | 81 | return out 82 | 83 | 84 | class ResNet(nn.Module): 85 | def __init__(self, last_stride=2, block=Bottleneck,layers=[3, 4, 6, 3]): 86 | self.inplanes = 64 87 | super().__init__() 88 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 89 | bias=False) 90 | self.bn1 = nn.BatchNorm2d(64) 91 | # self.relu = nn.ReLU(inplace=True) # add missed relu 92 | self.maxpool = nn.MaxPool2d(kernel_size=2, stride=None, padding=0) 93 | self.layer1 = self._make_layer(block, 64, layers[0]) 94 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 95 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 96 | self.layer4 = self._make_layer(block, 512, layers[3], stride=last_stride) 97 | 98 | def _make_layer(self, block, planes, blocks, stride=1): 99 | downsample = None 100 | if stride != 1 or self.inplanes != planes * block.expansion: 101 | downsample = nn.Sequential( 102 | nn.Conv2d(self.inplanes, planes * block.expansion, 103 | kernel_size=1, stride=stride, bias=False), 104 | nn.BatchNorm2d(planes * block.expansion), 105 | ) 106 | 107 | layers = [] 108 | layers.append(block(self.inplanes, planes, stride, downsample)) 109 | self.inplanes = planes * block.expansion 110 | for i in range(1, blocks): 111 | layers.append(block(self.inplanes, planes)) 112 | 113 | return nn.Sequential(*layers) 114 | 115 | def forward(self, x, cam_label=None): 116 | x = self.conv1(x) 117 | x = self.bn1(x) 118 | # x = self.relu(x) # add missed relu 119 | x = self.maxpool(x) 120 | x = self.layer1(x) 121 | x = self.layer2(x) 122 | x = self.layer3(x) 123 | x = self.layer4(x) 124 | 125 | return x 126 | 127 | def load_param(self, model_path): 128 | param_dict = torch.load(model_path) 129 | for i in param_dict: 130 | if 'fc' in i: 131 | continue 132 | self.state_dict()[i].copy_(param_dict[i]) 133 | 134 | def random_init(self): 135 | for m in self.modules(): 136 | if isinstance(m, nn.Conv2d): 137 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 138 | m.weight.data.normal_(0, math.sqrt(2. / n)) 139 | elif isinstance(m, nn.BatchNorm2d): 140 | m.weight.data.fill_(1) 141 | m.bias.data.zero_() -------------------------------------------------------------------------------- /model/backbones/vit_pytorch.py: -------------------------------------------------------------------------------- 1 | """ Vision Transformer (ViT) in PyTorch 2 | 3 | A PyTorch implement of Vision Transformers as described in 4 | 'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' - https://arxiv.org/abs/2010.11929 5 | 6 | The official jax code is released and available at https://github.com/google-research/vision_transformer 7 | 8 | Status/TODO: 9 | * Models updated to be compatible with official impl. Args added to support backward compat for old PyTorch weights. 10 | * Weights ported from official jax impl for 384x384 base and small models, 16x16 and 32x32 patches. 11 | * Trained (supervised on ImageNet-1k) my custom 'small' patch model to 77.9, 'base' to 79.4 top-1 with this code. 12 | * Hopefully find time and GPUs for SSL or unsupervised pretraining on OpenImages w/ ImageNet fine-tune in future. 13 | 14 | Acknowledgments: 15 | * The paper authors for releasing code and weights, thanks! 16 | * I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... 
check it out 17 | for some einops/einsum fun 18 | * Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT 19 | * Bert reference code checks against Huggingface Transformers and Tensorflow Bert 20 | 21 | Hacked together by / Copyright 2020 Ross Wightman 22 | """ 23 | import math 24 | from functools import partial 25 | from itertools import repeat 26 | 27 | import torch 28 | import torch.nn as nn 29 | import torch.nn.functional as F 30 | import collections.abc as container_abcs 31 | 32 | 33 | # From PyTorch internals 34 | def _ntuple(n): 35 | def parse(x): 36 | if isinstance(x, container_abcs.Iterable): 37 | return x 38 | return tuple(repeat(x, n)) 39 | return parse 40 | 41 | IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) 42 | IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) 43 | to_2tuple = _ntuple(2) 44 | 45 | def drop_path(x, drop_prob: float = 0., training: bool = False): 46 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 47 | 48 | This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, 49 | the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... 50 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for 51 | changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 52 | 'survival rate' as the argument. 53 | 54 | """ 55 | if drop_prob == 0. or not training: 56 | return x 57 | keep_prob = 1 - drop_prob 58 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 59 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) 60 | random_tensor.floor_() # binarize 61 | output = x.div(keep_prob) * random_tensor 62 | return output 63 | 64 | 65 | def conv3x3_block(in_planes, out_planes, stride=1): 66 | """3x3 convolution with padding""" 67 | conv_layer = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=1, padding=1) 68 | 69 | block = nn.Sequential( 70 | conv_layer, 71 | nn.BatchNorm2d(out_planes), 72 | nn.ReLU(inplace=True), 73 | ) 74 | return block 75 | 76 | 77 | class DropPath(nn.Module): 78 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
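Module wrapper around drop_path() above; forward simply applies drop_path(x, self.drop_prob, self.training).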
79 | """ 80 | def __init__(self, drop_prob=None): 81 | super(DropPath, self).__init__() 82 | self.drop_prob = drop_prob 83 | 84 | def forward(self, x): 85 | return drop_path(x, self.drop_prob, self.training) 86 | 87 | 88 | def _cfg(url='', **kwargs): 89 | return { 90 | 'url': url, 91 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 92 | 'crop_pct': .9, 'interpolation': 'bicubic', 93 | 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 94 | 'first_conv': 'patch_embed.proj', 'classifier': 'head', 95 | **kwargs 96 | } 97 | 98 | 99 | default_cfgs = { 100 | # patch models 101 | 'vit_small_patch16_224': _cfg( 102 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth', 103 | ), 104 | 'vit_base_patch16_224': _cfg( 105 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth', 106 | mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), 107 | ), 108 | 'vit_base_patch16_384': _cfg( 109 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_384-83fb41ba.pth', 110 | input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0), 111 | 'vit_base_patch32_384': _cfg( 112 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p32_384-830016f5.pth', 113 | input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0), 114 | 'vit_large_patch16_224': _cfg( 115 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_224-4ee7a4dc.pth', 116 | mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), 117 | 'vit_large_patch16_384': _cfg( 118 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_384-b3be5167.pth', 119 | input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0), 120 | 'vit_large_patch32_384': _cfg( 121 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth', 122 | input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0), 123 | 'vit_huge_patch16_224': _cfg(), 124 | 'vit_huge_patch32_384': _cfg(input_size=(3, 384, 384)), 125 | # hybrid models 126 | 'vit_small_resnet26d_224': _cfg(), 127 | 'vit_small_resnet50d_s3_224': _cfg(), 128 | 'vit_base_resnet26d_224': _cfg(), 129 | 'vit_base_resnet50d_224': _cfg(), 130 | } 131 | 132 | 133 | class Mlp(nn.Module): 134 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 135 | super().__init__() 136 | out_features = out_features or in_features 137 | hidden_features = hidden_features or in_features 138 | self.fc1 = nn.Linear(in_features, hidden_features) 139 | self.act = act_layer() 140 | self.fc2 = nn.Linear(hidden_features, out_features) 141 | self.drop = nn.Dropout(drop) 142 | 143 | def forward(self, x): 144 | x = self.fc1(x) 145 | x = self.act(x) 146 | x = self.drop(x) 147 | x = self.fc2(x) 148 | x = self.drop(x) 149 | return x 150 | 151 | 152 | class Attention(nn.Module): 153 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 154 | super().__init__() 155 | self.num_heads = num_heads 156 | head_dim = dim // num_heads 157 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 158 | self.scale = qk_scale or head_dim ** -0.5 159 | 160 | self.qkv = 
nn.Linear(dim, dim * 3, bias=qkv_bias) 161 | self.attn_drop = nn.Dropout(attn_drop) 162 | self.proj = nn.Linear(dim, dim) 163 | self.proj_drop = nn.Dropout(proj_drop) 164 | 165 | def forward(self, x): 166 | B, N, C = x.shape 167 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 168 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 169 | 170 | attn = (q @ k.transpose(-2, -1)) * self.scale 171 | orth_proto = attn[:, :, 0, 1:] 172 | attn = attn.softmax(dim=-1) 173 | attn = self.attn_drop(attn) 174 | 175 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 176 | x = self.proj(x) 177 | x = self.proj_drop(x) 178 | return x, orth_proto 179 | 180 | 181 | class Block(nn.Module): 182 | 183 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 184 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 185 | super().__init__() 186 | self.norm1 = norm_layer(dim) 187 | self.attn = Attention( 188 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 189 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 190 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 191 | self.norm2 = norm_layer(dim) 192 | mlp_hidden_dim = int(dim * mlp_ratio) 193 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 194 | 195 | def forward(self, x): 196 | att_x, orth_proto = self.attn(self.norm1(x)) 197 | x = x + self.drop_path(att_x) 198 | x = x + self.drop_path(self.mlp(self.norm2(x))) 199 | return x, orth_proto 200 | 201 | 202 | class PatchEmbed(nn.Module): 203 | """ Image to Patch Embedding 204 | """ 205 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): 206 | super().__init__() 207 | img_size = to_2tuple(img_size) 208 | patch_size = to_2tuple(patch_size) 209 | num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) 210 | self.img_size = img_size 211 | self.patch_size = patch_size 212 | self.num_patches = num_patches 213 | 214 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 215 | 216 | def forward(self, x): 217 | B, C, H, W = x.shape 218 | # FIXME look at relaxing size constraints 219 | assert H == self.img_size[0] and W == self.img_size[1], \ 220 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 221 | x = self.proj(x).flatten(2).transpose(1, 2) 222 | return x 223 | 224 | 225 | class HybridEmbed(nn.Module): 226 | """ CNN Feature Map Embedding 227 | Extract feature map from CNN, flatten, project to embedding dim. 228 | """ 229 | def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768): 230 | super().__init__() 231 | assert isinstance(backbone, nn.Module) 232 | img_size = to_2tuple(img_size) 233 | self.img_size = img_size 234 | self.backbone = backbone 235 | if feature_size is None: 236 | with torch.no_grad(): 237 | # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature 238 | # map for all networks, the feature metadata has reliable channel and stride info, but using 239 | # stride to calc feature dim requires info about padding of each stage that isn't captured. 
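# Dry-run trick: temporarily switch the backbone to eval mode, push one zero image through it under no_grad to read off the last feature map's spatial size and channel count, then restore the original train/eval state.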
240 | training = backbone.training 241 | if training: 242 | backbone.eval() 243 | o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1])) 244 | if isinstance(o, (list, tuple)): 245 | o = o[-1] # last feature if backbone outputs list/tuple of features 246 | feature_size = o.shape[-2:] 247 | feature_dim = o.shape[1] 248 | backbone.train(training) 249 | else: 250 | feature_size = to_2tuple(feature_size) 251 | if hasattr(self.backbone, 'feature_info'): 252 | feature_dim = self.backbone.feature_info.channels()[-1] 253 | else: 254 | feature_dim = self.backbone.num_features 255 | self.num_patches = feature_size[0] * feature_size[1] 256 | self.proj = nn.Conv2d(feature_dim, embed_dim, 1) 257 | 258 | def forward(self, x): 259 | x = self.backbone(x) 260 | if isinstance(x, (list, tuple)): 261 | x = x[-1] # last feature if backbone outputs list/tuple of features 262 | x = self.proj(x).flatten(2).transpose(1, 2) 263 | return x 264 | 265 | 266 | class PatchEmbed_overlap(nn.Module): 267 | """ Image to Patch Embedding with overlapping patches 268 | """ 269 | def __init__(self, img_size=224, patch_size=16, stride_size=20, in_chans=3, embed_dim=768): 270 | super().__init__() 271 | img_size = to_2tuple(img_size) 272 | patch_size = to_2tuple(patch_size) 273 | stride_size_tuple = to_2tuple(stride_size) 274 | self.num_x = (img_size[1] - patch_size[1]) // stride_size_tuple[1] + 1 275 | self.num_y = (img_size[0] - patch_size[0]) // stride_size_tuple[0] + 1 276 | print('using stride: {}, and patch number is num_y{} * num_x{}'.format(stride_size, self.num_y, self.num_x)) 277 | num_patches = self.num_x * self.num_y 278 | self.img_size = img_size 279 | self.patch_size = patch_size 280 | self.num_patches = num_patches 281 | 282 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride_size) 283 | for m in self.modules(): 284 | if isinstance(m, nn.Conv2d): 285 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 286 | m.weight.data.normal_(0, math.sqrt(2. / n)) 287 | elif isinstance(m, nn.BatchNorm2d): 288 | m.weight.data.fill_(1) 289 | m.bias.data.zero_() 290 | elif isinstance(m, nn.InstanceNorm2d): 291 | m.weight.data.fill_(1) 292 | m.bias.data.zero_() 293 | 294 | def forward(self, x): 295 | B, C, H, W = x.shape 296 | 297 | # FIXME look at relaxing size constraints 298 | assert H == self.img_size[0] and W == self.img_size[1], \ 299 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
300 | x = self.proj(x) 301 | 302 | x = x.flatten(2).transpose(1, 2) # [64, 8, 768] 303 | return x 304 | 305 | 306 | class TransReID(nn.Module): 307 | """ Transformer-based Object Re-Identification 308 | """ 309 | def __init__(self, img_size=224, patch_size=16, stride_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, 310 | num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., camera=0, view=0, 311 | drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm, local_feature=False, sie_xishu =1.0): 312 | super().__init__() 313 | self.num_classes = num_classes 314 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 315 | self.local_feature = local_feature 316 | if hybrid_backbone is not None: 317 | self.patch_embed = HybridEmbed( 318 | hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim) 319 | else: 320 | self.patch_embed = PatchEmbed_overlap( 321 | img_size=img_size, patch_size=patch_size, stride_size=stride_size, in_chans=in_chans, 322 | embed_dim=embed_dim) 323 | 324 | num_patches = self.patch_embed.num_patches 325 | 326 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 327 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) 328 | self.cam_num = camera 329 | self.view_num = view 330 | self.sie_xishu = sie_xishu 331 | # Initialize SIE Embedding 332 | if camera > 1 and view > 1: 333 | self.sie_embed = nn.Parameter(torch.zeros(camera * view, 1, embed_dim)) 334 | trunc_normal_(self.sie_embed, std=.02) 335 | print('camera number is : {} and viewpoint number is : {}'.format(camera, view)) 336 | print('using SIE_Lambda is : {}'.format(sie_xishu)) 337 | elif camera > 1: 338 | self.sie_embed = nn.Parameter(torch.zeros(camera, 1, embed_dim)) 339 | trunc_normal_(self.sie_embed, std=.02) 340 | print('camera number is : {}'.format(camera)) 341 | print('using SIE_Lambda is : {}'.format(sie_xishu)) 342 | elif view > 1: 343 | self.sie_embed = nn.Parameter(torch.zeros(view, 1, embed_dim)) 344 | trunc_normal_(self.sie_embed, std=.02) 345 | print('viewpoint number is : {}'.format(view)) 346 | print('using SIE_Lambda is : {}'.format(sie_xishu)) 347 | 348 | print('using drop_out rate is : {}'.format(drop_rate)) 349 | print('using attn_drop_out rate is : {}'.format(attn_drop_rate)) 350 | print('using drop_path rate is : {}'.format(drop_path_rate)) 351 | 352 | self.pos_drop = nn.Dropout(p=drop_rate) 353 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 354 | 355 | self.blocks = nn.ModuleList([ 356 | Block( 357 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 358 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) 359 | for i in range(depth)]) 360 | 361 | self.norm = norm_layer(embed_dim) 362 | 363 | self.ConvLayer = nn.Sequential( 364 | conv3x3_block(768 * 4, 768 * 2), 365 | nn.MaxPool2d(kernel_size=2, stride=2), 366 | conv3x3_block(768 * 2, 768), 367 | nn.MaxPool2d(kernel_size=2, stride=2), 368 | conv3x3_block(768, 768), 369 | nn.AdaptiveMaxPool2d((1, 1))) 370 | 371 | self.attfc = nn.Linear(768, 768) 372 | self.sigmoid = nn.Sigmoid() 373 | # Classifier head 374 | self.fc = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() 375 | trunc_normal_(self.cls_token, std=.02) 376 | trunc_normal_(self.pos_embed, std=.02) 377 | 378 | self.apply(self._init_weights) 379 | self.attfc.weight.data.zero_() 380 | 381 | 382 
| def _init_weights(self, m): 383 | if isinstance(m, nn.Linear): 384 | trunc_normal_(m.weight, std=.02) 385 | if isinstance(m, nn.Linear) and m.bias is not None: 386 | nn.init.constant_(m.bias, 0) 387 | elif isinstance(m, nn.LayerNorm): 388 | nn.init.constant_(m.bias, 0) 389 | nn.init.constant_(m.weight, 1.0) 390 | 391 | @torch.jit.ignore 392 | def no_weight_decay(self): 393 | return {'pos_embed', 'cls_token'} 394 | 395 | def get_classifier(self): 396 | return self.head 397 | 398 | def reset_classifier(self, num_classes, global_pool=''): 399 | self.num_classes = num_classes 400 | self.fc = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 401 | 402 | def forward_features(self, x, camera_id, view_id): 403 | B = x.shape[0] 404 | x = self.patch_embed(x) 405 | 406 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 407 | x = torch.cat((cls_tokens, x), dim=1) 408 | 409 | if self.cam_num > 0 and self.view_num > 0: 410 | x = x + self.pos_embed + self.sie_xishu * self.sie_embed[camera_id * self.view_num + view_id] 411 | elif self.cam_num > 0: 412 | x = x + self.pos_embed + self.sie_xishu * self.sie_embed[camera_id] 413 | elif self.view_num > 0: 414 | x = x + self.pos_embed + self.sie_xishu * self.sie_embed[view_id] 415 | else: 416 | x = x + self.pos_embed 417 | 418 | W = int(((x.shape[1] - 1) / 2.0) ** 0.5) 419 | H = int((x.shape[1] - 1) / W) 420 | x = self.pos_drop(x) 421 | 422 | flag = 0 423 | 424 | for blk in self.blocks: 425 | x, orth_proto = blk(x) 426 | if flag == 1: 427 | x1 = x[:, 1:].transpose(1, 2).reshape(B, -1, H, W) 428 | elif flag == 3: 429 | x3 = x[:, 1:].transpose(1, 2).reshape(B, -1, H, W) 430 | elif flag == 9: 431 | x9 = x[:, 1:].transpose(1, 2).reshape(B, -1, H, W) 432 | flag += 1 433 | 434 | mask = x[:, 1:].transpose(1, 2).reshape(B, -1, H, W) 435 | mask = torch.cat((x1, x3, x9, mask), 1) 436 | mask = self.ConvLayer(mask) 437 | mask = self.attfc(torch.squeeze(mask)) 438 | x = self.norm(x) 439 | return x[:, 0], mask, orth_proto 440 | 441 | def forward(self, x, cam_label=None, view_label=None): 442 | x, mask, orth_proto = self.forward_features(x, cam_label, view_label) 443 | return x, self.sigmoid(mask), orth_proto 444 | 445 | def load_param(self, model_path): 446 | param_dict = torch.load(model_path, map_location='cpu') 447 | if 'model' in param_dict: 448 | param_dict = param_dict['model'] 449 | if 'state_dict' in param_dict: 450 | param_dict = param_dict['state_dict'] 451 | for k, v in param_dict.items(): 452 | if 'head' in k or 'dist' in k: 453 | continue 454 | if 'patch_embed.proj.weight' in k and len(v.shape) < 4: 455 | # For old models that I trained prior to conv based patchification 456 | O, I, H, W = self.patch_embed.proj.weight.shape 457 | v = v.reshape(O, -1, H, W) 458 | elif k == 'pos_embed' and v.shape != self.pos_embed.shape: 459 | # To resize pos embedding when using model at different size from pretrained weights 460 | if 'distilled' in model_path: 461 | print('distill need to choose right cls token in the pth') 462 | v = torch.cat([v[:, 0:1], v[:, 2:]], dim=1) 463 | v = resize_pos_embed(v, self.pos_embed, self.patch_embed.num_y, self.patch_embed.num_x) 464 | try: 465 | self.state_dict()[k].copy_(v) 466 | except: 467 | print('===========================ERROR=========================') 468 | print('shape do not match in k :{}: param_dict{} vs self.state_dict(){}'.format(k, v.shape, self.state_dict()[k].shape)) 469 | 470 | 471 | def resize_pos_embed(posemb, posemb_new, hight, width): 472 | # 
Rescale the grid of position embeddings when loading from state_dict. Adapted from 473 | # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 474 | ntok_new = posemb_new.shape[1] 475 | 476 | posemb_token, posemb_grid = posemb[:, :1], posemb[0, 1:] 477 | ntok_new -= 1 478 | 479 | gs_old = int(math.sqrt(len(posemb_grid))) 480 | print('Resized position embedding from size:{} to size: {} with height:{} width: {}'.format(posemb.shape, posemb_new.shape, hight, width)) 481 | posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) 482 | posemb_grid = F.interpolate(posemb_grid, size=(hight, width), mode='bilinear') 483 | posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, hight * width, -1) 484 | posemb = torch.cat([posemb_token, posemb_grid], dim=1) 485 | return posemb 486 | 487 | 488 | def vit_base_patch16_224_TransReID(img_size=(256, 128), stride_size=16, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.1, camera=0, view=0,local_feature=False,sie_xishu=1.5, **kwargs): 489 | model = TransReID( 490 | img_size=img_size, patch_size=16, stride_size=stride_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,\ 491 | camera=camera, view=view, drop_path_rate=drop_path_rate, drop_rate=drop_rate, attn_drop_rate=attn_drop_rate, 492 | norm_layer=partial(nn.LayerNorm, eps=1e-6), sie_xishu=sie_xishu, local_feature=local_feature, **kwargs) 493 | 494 | return model 495 | 496 | def vit_small_patch16_224_TransReID(img_size=(256, 128), stride_size=16, drop_rate=0., attn_drop_rate=0.,drop_path_rate=0.1, camera=0, view=0, local_feature=False, sie_xishu=1.5, **kwargs): 497 | kwargs.setdefault('qk_scale', 768 ** -0.5) 498 | model = TransReID( 499 | img_size=img_size, patch_size=16, stride_size=stride_size, embed_dim=768, depth=8, num_heads=8, mlp_ratio=3., qkv_bias=False, drop_path_rate = drop_path_rate,\ 500 | camera=camera, view=view, drop_rate=drop_rate, attn_drop_rate=attn_drop_rate, 501 | norm_layer=partial(nn.LayerNorm, eps=1e-6), sie_xishu=sie_xishu, local_feature=local_feature, **kwargs) 502 | 503 | return model 504 | 505 | def deit_small_patch16_224_TransReID(img_size=(256, 128), stride_size=16, drop_path_rate=0.1, drop_rate=0.0, attn_drop_rate=0.0, camera=0, view=0, local_feature=False, sie_xishu=1.5, **kwargs): 506 | model = TransReID( 507 | img_size=img_size, patch_size=16, stride_size=stride_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True, 508 | drop_path_rate=drop_path_rate, drop_rate=drop_rate, attn_drop_rate=attn_drop_rate, camera=camera, view=view, sie_xishu=sie_xishu, local_feature=local_feature, 509 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 510 | 511 | return model 512 | 513 | 514 | def _no_grad_trunc_normal_(tensor, mean, std, a, b): 515 | # Cut & paste from PyTorch official master until it's in a few official releases - RW 516 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf 517 | def norm_cdf(x): 518 | # Computes standard normal cumulative distribution function 519 | return (1. + math.erf(x / math.sqrt(2.))) / 2. 520 | 521 | if (mean < a - 2 * std) or (mean > b + 2 * std): 522 | print("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " 523 | "The distribution of values may be incorrect.",) 524 | 525 | with torch.no_grad(): 526 | # Values are generated by using a truncated uniform distribution and 527 | # then using the inverse CDF for the normal distribution. 
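# Concretely: draw c ~ U(cdf((a - mean)/std), cdf((b - mean)/std)) and return mean + std * sqrt(2) * erfinv(2c - 1), which equals mean + std * Phi^{-1}(c).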
528 | # Get upper and lower cdf values 529 | l = norm_cdf((a - mean) / std) 530 | u = norm_cdf((b - mean) / std) 531 | 532 | # Uniformly fill tensor with values from [l, u], then translate to 533 | # [2l-1, 2u-1]. 534 | tensor.uniform_(2 * l - 1, 2 * u - 1) 535 | 536 | # Use inverse cdf transform for normal distribution to get truncated 537 | # standard normal 538 | tensor.erfinv_() 539 | 540 | # Transform to proper mean, std 541 | tensor.mul_(std * math.sqrt(2.)) 542 | tensor.add_(mean) 543 | 544 | # Clamp to ensure it's in the proper range 545 | tensor.clamp_(min=a, max=b) 546 | return tensor 547 | 548 | 549 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): 550 | # type: (Tensor, float, float, float, float) -> Tensor 551 | r"""Fills the input Tensor with values drawn from a truncated 552 | normal distribution. The values are effectively drawn from the 553 | normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` 554 | with values outside :math:`[a, b]` redrawn until they are within 555 | the bounds. The method used for generating the random values works 556 | best when :math:`a \leq \text{mean} \leq b`. 557 | Args: 558 | tensor: an n-dimensional `torch.Tensor` 559 | mean: the mean of the normal distribution 560 | std: the standard deviation of the normal distribution 561 | a: the minimum cutoff value 562 | b: the maximum cutoff value 563 | Examples: 564 | >>> w = torch.empty(3, 5) 565 | >>> nn.init.trunc_normal_(w) 566 | """ 567 | return _no_grad_trunc_normal_(tensor, mean, std, a, b) 568 | -------------------------------------------------------------------------------- /model/make_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .backbones.resnet import ResNet, Bottleneck 4 | import copy 5 | from .backbones.vit_pytorch import vit_base_patch16_224_TransReID, vit_small_patch16_224_TransReID, deit_small_patch16_224_TransReID 6 | from loss.metric_learning import Arcface, Cosface, AMSoftmax, CircleLoss, MArcface 7 | 8 | def shuffle_unit(features, shift, group, begin=1): 9 | 10 | batchsize = features.size(0) 11 | dim = features.size(-1) 12 | # Shift Operation 13 | feature_random = torch.cat([features[:, begin-1+shift:], features[:, begin:begin-1+shift]], dim=1) 14 | x = feature_random 15 | # Patch Shuffle Operation 16 | try: 17 | x = x.view(batchsize, group, -1, dim) 18 | except: 19 | x = torch.cat([x, x[:, -2:-1, :]], dim=1) 20 | x = x.view(batchsize, group, -1, dim) 21 | 22 | x = torch.transpose(x, 1, 2).contiguous() 23 | x = x.view(batchsize, -1, dim) 24 | 25 | return x 26 | 27 | def weights_init_kaiming(m): 28 | classname = m.__class__.__name__ 29 | if classname.find('Linear') != -1: 30 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out') 31 | nn.init.constant_(m.bias, 0.0) 32 | 33 | elif classname.find('Conv') != -1: 34 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in') 35 | if m.bias is not None: 36 | nn.init.constant_(m.bias, 0.0) 37 | elif classname.find('BatchNorm') != -1: 38 | if m.affine: 39 | nn.init.constant_(m.weight, 1.0) 40 | nn.init.constant_(m.bias, 0.0) 41 | 42 | def weights_init_classifier(m): 43 | classname = m.__class__.__name__ 44 | if classname.find('Linear') != -1: 45 | nn.init.normal_(m.weight, std=0.001) 46 | if m.bias: 47 | nn.init.constant_(m.bias, 0.0) 48 | 49 | 50 | class Backbone(nn.Module): 51 | def __init__(self, num_classes, cfg): 52 | super(Backbone, self).__init__() 53 | last_stride = cfg.MODEL.LAST_STRIDE 54 | model_path = 
cfg.MODEL.PRETRAIN_PATH 55 | model_name = cfg.MODEL.NAME 56 | pretrain_choice = cfg.MODEL.PRETRAIN_CHOICE 57 | self.cos_layer = cfg.MODEL.COS_LAYER 58 | self.neck = cfg.MODEL.NECK 59 | self.neck_feat = cfg.TEST.NECK_FEAT 60 | 61 | if model_name == 'resnet50': 62 | self.in_planes = 2048 63 | self.base = ResNet(last_stride=last_stride, 64 | block=Bottleneck, 65 | layers=[3, 4, 6, 3]) 66 | print('using resnet50 as a backbone') 67 | else: 68 | print('unsupported backbone! but got {}'.format(model_name)) 69 | 70 | if pretrain_choice == 'imagenet': 71 | self.base.load_param(model_path) 72 | print('Loading pretrained ImageNet model......from {}'.format(model_path)) 73 | 74 | self.gap = nn.AdaptiveAvgPool2d(1) 75 | self.num_classes = num_classes 76 | 77 | self.classifier = nn.Linear(self.in_planes, self.num_classes, bias=False) 78 | self.classifier.apply(weights_init_classifier) 79 | 80 | self.bottleneck = nn.BatchNorm1d(self.in_planes) 81 | self.bottleneck.bias.requires_grad_(False) 82 | self.bottleneck.apply(weights_init_kaiming) 83 | 84 | def forward(self, x, label=None): # label is unused if self.cos_layer == 'no' 85 | x = self.base(x) 86 | global_feat = nn.functional.avg_pool2d(x, x.shape[2:4]) 87 | global_feat = global_feat.view(global_feat.shape[0], -1) # flatten to (bs, 2048) 88 | 89 | if self.neck == 'no': 90 | feat = global_feat 91 | elif self.neck == 'bnneck': 92 | feat = self.bottleneck(global_feat) 93 | 94 | if self.training: 95 | if self.cos_layer: 96 | cls_score = self.arcface(feat, label) 97 | else: 98 | cls_score = self.classifier(feat) 99 | return cls_score, global_feat 100 | else: 101 | if self.neck_feat == 'after': 102 | return feat 103 | else: 104 | return global_feat 105 | 106 | def load_param(self, trained_path): 107 | param_dict = torch.load(trained_path) 108 | if 'state_dict' in param_dict: 109 | param_dict = param_dict['state_dict'] 110 | for i in param_dict: 111 | self.state_dict()[i].copy_(param_dict[i]) 112 | print('Loading pretrained model from {}'.format(trained_path)) 113 | 114 | def load_param_finetune(self, model_path): 115 | param_dict = torch.load(model_path) 116 | for i in param_dict: 117 | self.state_dict()[i].copy_(param_dict[i]) 118 | print('Loading pretrained model for finetuning from {}'.format(model_path)) 119 | 120 | 121 | class build_transformer(nn.Module): 122 | def __init__(self, num_classes, camera_num, view_num, cfg, factory): 123 | super(build_transformer, self).__init__() 124 | last_stride = cfg.MODEL.LAST_STRIDE 125 | model_path = cfg.MODEL.PRETRAIN_PATH 126 | model_name = cfg.MODEL.NAME 127 | pretrain_choice = cfg.MODEL.PRETRAIN_CHOICE 128 | self.cos_layer = cfg.MODEL.COS_LAYER 129 | self.neck = cfg.MODEL.NECK 130 | self.neck_feat = cfg.TEST.NECK_FEAT 131 | self.in_planes = 768 132 | 133 | print('using Transformer_type: {} as a backbone'.format(cfg.MODEL.TRANSFORMER_TYPE)) 134 | 135 | if cfg.MODEL.SIE_CAMERA: 136 | camera_num = camera_num 137 | else: 138 | camera_num = 0 139 | if cfg.MODEL.SIE_VIEW: 140 | view_num = view_num 141 | else: 142 | view_num = 0 143 | 144 | self.base = factory[cfg.MODEL.TRANSFORMER_TYPE](img_size=cfg.INPUT.SIZE_TRAIN, sie_xishu=cfg.MODEL.SIE_COE, 145 | camera=camera_num, view=view_num, stride_size=cfg.MODEL.STRIDE_SIZE, drop_path_rate=cfg.MODEL.DROP_PATH, 146 | drop_rate= cfg.MODEL.DROP_OUT, 147 | attn_drop_rate=cfg.MODEL.ATT_DROP_RATE) 148 | if cfg.MODEL.TRANSFORMER_TYPE == 'deit_small_patch16_224_TransReID': 149 | self.in_planes = 384 150 | if pretrain_choice == 'imagenet': 151 | self.base.load_param(model_path) 152 
| print('Loading pretrained ImageNet model......from {}'.format(model_path)) 153 | 154 | self.gap = nn.AdaptiveAvgPool2d(1) 155 | 156 | self.num_classes = num_classes 157 | self.ID_LOSS_TYPE = cfg.MODEL.ID_LOSS_TYPE 158 | if self.ID_LOSS_TYPE == 'arcface': 159 | print('using {} with s:{}, m: {}'.format(self.ID_LOSS_TYPE,cfg.SOLVER.COSINE_SCALE,cfg.SOLVER.COSINE_MARGIN)) 160 | self.classifier = Arcface(self.in_planes, self.num_classes, 161 | s=cfg.SOLVER.COSINE_SCALE, m=cfg.SOLVER.COSINE_MARGIN) 162 | elif self.ID_LOSS_TYPE == 'cosface': 163 | print('using {} with s:{}, m: {}'.format(self.ID_LOSS_TYPE,cfg.SOLVER.COSINE_SCALE,cfg.SOLVER.COSINE_MARGIN)) 164 | self.classifier = Cosface(self.in_planes, self.num_classes, 165 | s=cfg.SOLVER.COSINE_SCALE, m=cfg.SOLVER.COSINE_MARGIN) 166 | elif self.ID_LOSS_TYPE == 'amsoftmax': 167 | print('using {} with s:{}, m: {}'.format(self.ID_LOSS_TYPE,cfg.SOLVER.COSINE_SCALE,cfg.SOLVER.COSINE_MARGIN)) 168 | self.classifier = AMSoftmax(self.in_planes, self.num_classes, 169 | s=cfg.SOLVER.COSINE_SCALE, m=cfg.SOLVER.COSINE_MARGIN) 170 | elif self.ID_LOSS_TYPE == 'circle': 171 | print('using {} with s:{}, m: {}'.format(self.ID_LOSS_TYPE, cfg.SOLVER.COSINE_SCALE, cfg.SOLVER.COSINE_MARGIN)) 172 | self.classifier = CircleLoss(self.in_planes, self.num_classes, 173 | s=cfg.SOLVER.COSINE_SCALE, m=cfg.SOLVER.COSINE_MARGIN) 174 | else: 175 | self.classifier = MArcface(self.in_planes, self.num_classes, 176 | s=cfg.SOLVER.COSINE_SCALE, m=cfg.SOLVER.COSINE_MARGIN) 177 | '''self.classifier = nn.Linear(self.in_planes, self.num_classes, bias=False) 178 | self.classifier.apply(weights_init_classifier)''' 179 | 180 | self.bottleneck = nn.BatchNorm1d(self.in_planes) 181 | self.bottleneck.bias.requires_grad_(False) 182 | self.bottleneck.apply(weights_init_kaiming) 183 | 184 | def forward(self, x, label=None, cam_label= None, view_label=None): 185 | global_feat, mask, orth_proto = self.base(x, cam_label=cam_label, view_label=view_label) 186 | 187 | feat = self.bottleneck(global_feat) 188 | 189 | if self.training: 190 | cls_score, Mcls_score = self.classifier(feat, mask, label) 191 | 192 | return cls_score, Mcls_score, global_feat, orth_proto # global feature for triplet loss 193 | else: 194 | if self.neck_feat == 'after': 195 | # print("Test with feature after BN") 196 | return feat, mask 197 | else: 198 | # print("Test with feature before BN") 199 | return global_feat, mask 200 | 201 | def load_param(self, trained_path): 202 | param_dict = torch.load(trained_path) 203 | for i in param_dict: 204 | self.state_dict()[i.replace('module.', '')].copy_(param_dict[i]) 205 | print('Loading pretrained model from {}'.format(trained_path)) 206 | 207 | def load_param_finetune(self, model_path): 208 | param_dict = torch.load(model_path) 209 | for i in param_dict: 210 | self.state_dict()[i].copy_(param_dict[i]) 211 | print('Loading pretrained model for finetuning from {}'.format(model_path)) 212 | 213 | 214 | class build_transformer_local(nn.Module): 215 | def __init__(self, num_classes, camera_num, view_num, cfg, factory, rearrange): 216 | super(build_transformer_local, self).__init__() 217 | model_path = cfg.MODEL.PRETRAIN_PATH 218 | pretrain_choice = cfg.MODEL.PRETRAIN_CHOICE 219 | self.cos_layer = cfg.MODEL.COS_LAYER 220 | self.neck = cfg.MODEL.NECK 221 | self.neck_feat = cfg.TEST.NECK_FEAT 222 | self.in_planes = 768 223 | 224 | print('using Transformer_type: {} as a backbone'.format(cfg.MODEL.TRANSFORMER_TYPE)) 225 | 226 | if cfg.MODEL.SIE_CAMERA: 227 | camera_num = camera_num 228 | else: 
229 | camera_num = 0 230 | 231 | if cfg.MODEL.SIE_VIEW: 232 | view_num = view_num 233 | else: 234 | view_num = 0 235 | 236 | self.base = factory[cfg.MODEL.TRANSFORMER_TYPE](img_size=cfg.INPUT.SIZE_TRAIN, sie_xishu=cfg.MODEL.SIE_COE, local_feature=cfg.MODEL.JPM, camera=camera_num, view=view_num, stride_size=cfg.MODEL.STRIDE_SIZE, drop_path_rate=cfg.MODEL.DROP_PATH) 237 | 238 | if pretrain_choice == 'imagenet': 239 | self.base.load_param(model_path) 240 | print('Loading pretrained ImageNet model......from {}'.format(model_path)) 241 | 242 | block = self.base.blocks[-1] 243 | layer_norm = self.base.norm 244 | self.b1 = nn.Sequential( 245 | copy.deepcopy(block), 246 | copy.deepcopy(layer_norm) 247 | ) 248 | self.b2 = nn.Sequential( 249 | copy.deepcopy(block), 250 | copy.deepcopy(layer_norm) 251 | ) 252 | 253 | self.num_classes = num_classes 254 | self.ID_LOSS_TYPE = cfg.MODEL.ID_LOSS_TYPE 255 | if self.ID_LOSS_TYPE == 'arcface': 256 | print('using {} with s:{}, m: {}'.format(self.ID_LOSS_TYPE,cfg.SOLVER.COSINE_SCALE,cfg.SOLVER.COSINE_MARGIN)) 257 | self.classifier = Arcface(self.in_planes, self.num_classes, 258 | s=cfg.SOLVER.COSINE_SCALE, m=cfg.SOLVER.COSINE_MARGIN) 259 | elif self.ID_LOSS_TYPE == 'cosface': 260 | print('using {} with s:{}, m: {}'.format(self.ID_LOSS_TYPE,cfg.SOLVER.COSINE_SCALE,cfg.SOLVER.COSINE_MARGIN)) 261 | self.classifier = Cosface(self.in_planes, self.num_classes, 262 | s=cfg.SOLVER.COSINE_SCALE, m=cfg.SOLVER.COSINE_MARGIN) 263 | elif self.ID_LOSS_TYPE == 'amsoftmax': 264 | print('using {} with s:{}, m: {}'.format(self.ID_LOSS_TYPE,cfg.SOLVER.COSINE_SCALE,cfg.SOLVER.COSINE_MARGIN)) 265 | self.classifier = AMSoftmax(self.in_planes, self.num_classes, 266 | s=cfg.SOLVER.COSINE_SCALE, m=cfg.SOLVER.COSINE_MARGIN) 267 | elif self.ID_LOSS_TYPE == 'circle': 268 | print('using {} with s:{}, m: {}'.format(self.ID_LOSS_TYPE, cfg.SOLVER.COSINE_SCALE, cfg.SOLVER.COSINE_MARGIN)) 269 | self.classifier = CircleLoss(self.in_planes, self.num_classes, 270 | s=cfg.SOLVER.COSINE_SCALE, m=cfg.SOLVER.COSINE_MARGIN) 271 | else: 272 | self.classifier = nn.Linear(self.in_planes, self.num_classes, bias=False) 273 | self.classifier.apply(weights_init_classifier) 274 | self.classifier_1 = nn.Linear(self.in_planes, self.num_classes, bias=False) 275 | self.classifier_1.apply(weights_init_classifier) 276 | self.classifier_2 = nn.Linear(self.in_planes, self.num_classes, bias=False) 277 | self.classifier_2.apply(weights_init_classifier) 278 | self.classifier_3 = nn.Linear(self.in_planes, self.num_classes, bias=False) 279 | self.classifier_3.apply(weights_init_classifier) 280 | self.classifier_4 = nn.Linear(self.in_planes, self.num_classes, bias=False) 281 | self.classifier_4.apply(weights_init_classifier) 282 | 283 | self.bottleneck = nn.BatchNorm1d(self.in_planes) 284 | self.bottleneck.bias.requires_grad_(False) 285 | self.bottleneck.apply(weights_init_kaiming) 286 | self.bottleneck_1 = nn.BatchNorm1d(self.in_planes) 287 | self.bottleneck_1.bias.requires_grad_(False) 288 | self.bottleneck_1.apply(weights_init_kaiming) 289 | self.bottleneck_2 = nn.BatchNorm1d(self.in_planes) 290 | self.bottleneck_2.bias.requires_grad_(False) 291 | self.bottleneck_2.apply(weights_init_kaiming) 292 | self.bottleneck_3 = nn.BatchNorm1d(self.in_planes) 293 | self.bottleneck_3.bias.requires_grad_(False) 294 | self.bottleneck_3.apply(weights_init_kaiming) 295 | self.bottleneck_4 = nn.BatchNorm1d(self.in_planes) 296 | self.bottleneck_4.bias.requires_grad_(False) 297 | self.bottleneck_4.apply(weights_init_kaiming) 298 | 299 | 
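        # JPM settings read from the config: patch tokens are cyclically shifted by SHIFT_NUM
        # and shuffled into SHUFFLE_GROUP groups (see shuffle_unit above); in forward() they are
        # split into parts of length (num_patches // DEVIDE_LENGTH), one per local branch below.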
self.shuffle_groups = cfg.MODEL.SHUFFLE_GROUP 300 | print('using shuffle_groups size:{}'.format(self.shuffle_groups)) 301 | self.shift_num = cfg.MODEL.SHIFT_NUM 302 | print('using shift_num size:{}'.format(self.shift_num)) 303 | self.divide_length = cfg.MODEL.DEVIDE_LENGTH 304 | print('using divide_length size:{}'.format(self.divide_length)) 305 | self.rearrange = rearrange 306 | 307 | def forward(self, x, label=None, cam_label= None, view_label=None): # label is unused if self.cos_layer == 'no' 308 | 309 | features = self.base(x, cam_label=cam_label, view_label=view_label) 310 | 311 | # global branch 312 | b1_feat = self.b1(features) # [64, 129, 768] 313 | global_feat = b1_feat[:, 0] 314 | 315 | # JPM branch 316 | feature_length = features.size(1) - 1 317 | patch_length = feature_length // self.divide_length 318 | token = features[:, 0:1] 319 | 320 | if self.rearrange: 321 | x = shuffle_unit(features, self.shift_num, self.shuffle_groups) 322 | else: 323 | x = features[:, 1:] 324 | # lf_1 325 | b1_local_feat = x[:, :patch_length] 326 | b1_local_feat = self.b2(torch.cat((token, b1_local_feat), dim=1)) 327 | local_feat_1 = b1_local_feat[:, 0] 328 | 329 | # lf_2 330 | b2_local_feat = x[:, patch_length:patch_length*2] 331 | b2_local_feat = self.b2(torch.cat((token, b2_local_feat), dim=1)) 332 | local_feat_2 = b2_local_feat[:, 0] 333 | 334 | # lf_3 335 | b3_local_feat = x[:, patch_length*2:patch_length*3] 336 | b3_local_feat = self.b2(torch.cat((token, b3_local_feat), dim=1)) 337 | local_feat_3 = b3_local_feat[:, 0] 338 | 339 | # lf_4 340 | b4_local_feat = x[:, patch_length*3:patch_length*4] 341 | b4_local_feat = self.b2(torch.cat((token, b4_local_feat), dim=1)) 342 | local_feat_4 = b4_local_feat[:, 0] 343 | 344 | feat = self.bottleneck(global_feat) 345 | 346 | local_feat_1_bn = self.bottleneck_1(local_feat_1) 347 | local_feat_2_bn = self.bottleneck_2(local_feat_2) 348 | local_feat_3_bn = self.bottleneck_3(local_feat_3) 349 | local_feat_4_bn = self.bottleneck_4(local_feat_4) 350 | 351 | if self.training: 352 | if self.ID_LOSS_TYPE in ('arcface', 'cosface', 'amsoftmax', 'circle'): 353 | cls_score = self.classifier(feat, label) 354 | else: 355 | cls_score = self.classifier(feat) 356 | cls_score_1 = self.classifier_1(local_feat_1_bn) 357 | cls_score_2 = self.classifier_2(local_feat_2_bn) 358 | cls_score_3 = self.classifier_3(local_feat_3_bn) 359 | cls_score_4 = self.classifier_4(local_feat_4_bn) 360 | return [cls_score, cls_score_1, cls_score_2, cls_score_3, 361 | cls_score_4 362 | ], [global_feat, local_feat_1, local_feat_2, local_feat_3, 363 | local_feat_4] # global feature for triplet loss 364 | else: 365 | if self.neck_feat == 'after': 366 | return torch.cat( 367 | [feat, local_feat_1_bn / 4, local_feat_2_bn / 4, local_feat_3_bn / 4, local_feat_4_bn / 4], dim=1) 368 | else: 369 | return torch.cat( 370 | [global_feat, local_feat_1 / 4, local_feat_2 / 4, local_feat_3 / 4, local_feat_4 / 4], dim=1) 371 | 372 | def load_param(self, trained_path): 373 | param_dict = torch.load(trained_path) 374 | for i in param_dict: 375 | self.state_dict()[i.replace('module.', '')].copy_(param_dict[i]) 376 | print('Loading pretrained model from {}'.format(trained_path)) 377 | 378 | def load_param_finetune(self, model_path): 379 | param_dict = torch.load(model_path) 380 | for i in param_dict: 381 | self.state_dict()[i].copy_(param_dict[i]) 382 | print('Loading pretrained model for finetuning from {}'.format(model_path)) 383 | 384 | 385 | __factory_T_type = { 386 | 'vit_base_patch16_224_TransReID': 
vit_base_patch16_224_TransReID, 387 | 'deit_base_patch16_224_TransReID': vit_base_patch16_224_TransReID, 388 | 'vit_small_patch16_224_TransReID': vit_small_patch16_224_TransReID, 389 | 'deit_small_patch16_224_TransReID': deit_small_patch16_224_TransReID 390 | } 391 | 392 | def make_model(cfg, num_class, camera_num, view_num): 393 | if cfg.MODEL.NAME == 'transformer': 394 | if cfg.MODEL.JPM: 395 | model = build_transformer_local(num_class, camera_num, view_num, cfg, __factory_T_type, rearrange=cfg.MODEL.RE_ARRANGE) 396 | print('===========building transformer with JPM module ===========') 397 | else: 398 | model = build_transformer(num_class, camera_num, view_num, cfg, __factory_T_type) 399 | print('===========building transformer===========') 400 | else: 401 | model = Backbone(num_class, cfg) 402 | print('===========building ResNet===========') 403 | return model 404 | -------------------------------------------------------------------------------- /processor/__init__.py: -------------------------------------------------------------------------------- 1 | from .processor import do_train, do_inference -------------------------------------------------------------------------------- /processor/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stone96123/DPM/a7bb2267b22de7cebfcabae5fd998c71b0f94b13/processor/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /processor/__pycache__/processor.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stone96123/DPM/a7bb2267b22de7cebfcabae5fd998c71b0f94b13/processor/__pycache__/processor.cpython-37.pyc -------------------------------------------------------------------------------- /processor/processor.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import time 4 | import torch 5 | import torch.nn as nn 6 | from utils.meter import AverageMeter 7 | from utils.metrics import R1_mAP_eval 8 | from torch.cuda import amp 9 | import torch.distributed as dist 10 | 11 | def do_train(cfg, 12 | model, 13 | center_criterion, 14 | train_loader, 15 | val_loader, 16 | optimizer, 17 | Moptimizer, 18 | scheduler, 19 | loss_fn, 20 | num_query, local_rank): 21 | log_period = cfg.SOLVER.LOG_PERIOD 22 | checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD 23 | eval_period = cfg.SOLVER.EVAL_PERIOD 24 | 25 | device = "cuda" 26 | epochs = cfg.SOLVER.MAX_EPOCHS 27 | 28 | logger = logging.getLogger("transreid.train") 29 | logger.info('start training') 30 | _LOCAL_PROCESS_GROUP = None 31 | if device: 32 | model.to(local_rank) 33 | if torch.cuda.device_count() > 1 and cfg.MODEL.DIST_TRAIN: 34 | print('Using {} GPUs for training'.format(torch.cuda.device_count())) 35 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank], find_unused_parameters=True) 36 | 37 | loss_meter = AverageMeter() 38 | acc_meter = AverageMeter() 39 | 40 | evaluator = R1_mAP_eval(num_query, max_rank=50, feat_norm=cfg.TEST.FEAT_NORM) 41 | scaler = amp.GradScaler() 42 | # train 43 | for epoch in range(1, epochs + 1): 44 | start_time = time.time() 45 | loss_meter.reset() 46 | acc_meter.reset() 47 | evaluator.reset() 48 | scheduler.step(epoch) 49 | model.train() 50 | for n_iter, (img, vid, target_cam, target_view) in enumerate(train_loader): 51 | img = img.to(device) 52 | 
target = vid.to(device) 53 | target_cam = target_cam.to(device) 54 | target_view = target_view.to(device) 55 | ############ Feature extractor ########################### 56 | optimizer.zero_grad() 57 | with amp.autocast(enabled=True): 58 | score, Mscore, feat, orth_proto = model(img, target, cam_label=target_cam, view_label=target_view ) 59 | loss = loss_fn(score, Mscore, feat, orth_proto, epoch, target, target_cam) 60 | 61 | scaler.scale(loss).backward() 62 | 63 | scaler.step(optimizer) 64 | scaler.update() 65 | ############ Mask Generator ############################### 66 | Moptimizer.zero_grad() 67 | with amp.autocast(enabled=True): 68 | score, Mscore, feat, orth_proto = model(img, target, cam_label=target_cam, view_label=target_view ) 69 | loss = loss_fn(score, Mscore, feat, orth_proto, epoch, target, target_cam) 70 | 71 | scaler.scale(loss).backward() 72 | 73 | scaler.step(Moptimizer) 74 | scaler.update() 75 | ########################################################### 76 | '''if 'center' in cfg.MODEL.METRIC_LOSS_TYPE: 77 | for param in center_criterion.parameters(): 78 | param.grad.data *= (1. / cfg.SOLVER.CENTER_LOSS_WEIGHT) 79 | scaler.step(optimizer_center) 80 | scaler.update()''' 81 | if isinstance(score, list): 82 | acc = (score[0].max(1)[1] == target).float().mean() 83 | else: 84 | acc = (score.max(1)[1] == target).float().mean() 85 | 86 | loss_meter.update(loss.item(), img.shape[0]) 87 | acc_meter.update(acc, 1) 88 | 89 | torch.cuda.synchronize() 90 | if (n_iter + 1) % log_period == 0: 91 | logger.info("Epoch[{}] Iteration[{}/{}] Loss: {:.3f}, Acc: {:.3f}, Base Lr: {:.2e}" 92 | .format(epoch, (n_iter + 1), len(train_loader), 93 | loss_meter.avg, acc_meter.avg, scheduler._get_lr(epoch)[0])) 94 | 95 | end_time = time.time() 96 | time_per_batch = (end_time - start_time) / (n_iter + 1) 97 | if cfg.MODEL.DIST_TRAIN: 98 | pass 99 | else: 100 | logger.info("Epoch {} done. 
Time per batch: {:.3f}[s] Speed: {:.1f}[samples/s]" 101 | .format(epoch, time_per_batch, train_loader.batch_size / time_per_batch)) 102 | 103 | if epoch % checkpoint_period == 0: 104 | if cfg.MODEL.DIST_TRAIN: 105 | if dist.get_rank() == 0: 106 | torch.save(model.state_dict(), 107 | os.path.join(cfg.OUTPUT_DIR, cfg.MODEL.NAME + '_{}.pth'.format(epoch))) 108 | else: 109 | torch.save(model.state_dict(), 110 | os.path.join(cfg.OUTPUT_DIR, cfg.MODEL.NAME + '_{}.pth'.format(epoch))) 111 | 112 | if epoch % eval_period == 0: 113 | if cfg.MODEL.DIST_TRAIN: 114 | if dist.get_rank() == 0: 115 | model.eval() 116 | for n_iter, (img, vid, camid, camids, target_view, _) in enumerate(val_loader): 117 | with torch.no_grad(): 118 | img = img.to(device) 119 | camids = camids.to(device) 120 | target_view = target_view.to(device) 121 | feat = model(img, cam_label=camids, view_label=target_view) 122 | evaluator.update((feat, vid, camid)) 123 | cmc, mAP, _, _, _, _, _ = evaluator.compute() 124 | logger.info("Validation Results - Epoch: {}".format(epoch)) 125 | logger.info("mAP: {:.1%}".format(mAP)) 126 | for r in [1, 5, 10]: 127 | logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(r, cmc[r - 1])) 128 | torch.cuda.empty_cache() 129 | else: 130 | model.eval() 131 | for n_iter, (img, vid, camid, camids, target_view, _) in enumerate(val_loader): 132 | with torch.no_grad(): 133 | img = img.to(device) 134 | camids = camids.to(device) 135 | target_view = target_view.to(device) 136 | feat, mask = model(img, cam_label=camids, view_label=target_view) 137 | evaluator.update((feat, mask, vid, camid)) 138 | cmc, mAP, _, _, _, _, _ = evaluator.compute() 139 | logger.info("Validation Results - Epoch: {}".format(epoch)) 140 | logger.info("mAP: {:.1%}".format(mAP)) 141 | for r in [1, 5, 10]: 142 | logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(r, cmc[r - 1])) 143 | torch.cuda.empty_cache() 144 | 145 | 146 | def do_inference(cfg, 147 | model, 148 | val_loader, 149 | num_query): 150 | device = "cuda" 151 | logger = logging.getLogger("transreid.test") 152 | logger.info("Enter inferencing") 153 | 154 | evaluator = R1_mAP_eval(num_query, max_rank=50, feat_norm=cfg.TEST.FEAT_NORM) 155 | 156 | evaluator.reset() 157 | 158 | if device: 159 | if torch.cuda.device_count() > 1: 160 | print('Using {} GPUs for inference'.format(torch.cuda.device_count())) 161 | model = nn.DataParallel(model) 162 | model.to(device) 163 | 164 | model.eval() 165 | img_path_list = [] 166 | 167 | for n_iter, (img, pid, camid, camids, target_view, imgpath) in enumerate(val_loader): 168 | with torch.no_grad(): 169 | img = img.to(device) 170 | camids = camids.to(device) 171 | target_view = target_view.to(device) 172 | feat, mask = model(img, cam_label=camids, view_label=target_view) 173 | evaluator.update((feat, mask, pid, camid)) 174 | img_path_list.extend(imgpath) 175 | 176 | cmc, mAP, _, _, _, _, _ = evaluator.compute() 177 | logger.info("Validation Results ") 178 | logger.info("mAP: {:.1%}".format(mAP)) 179 | for r in [1, 5, 10]: 180 | logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(r, cmc[r - 1])) 181 | return cmc[0], cmc[4] 182 | 183 | 184 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.10.1 2 | torchvision==0.11.2 3 | timm==0.4.9 4 | yacs==0.1.6 5 | opencv-python==4.5.5.62 -------------------------------------------------------------------------------- /solver/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .lr_scheduler import WarmupMultiStepLR 2 | from .make_optimizer import make_optimizer -------------------------------------------------------------------------------- /solver/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stone96123/DPM/a7bb2267b22de7cebfcabae5fd998c71b0f94b13/solver/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /solver/__pycache__/cosine_lr.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stone96123/DPM/a7bb2267b22de7cebfcabae5fd998c71b0f94b13/solver/__pycache__/cosine_lr.cpython-37.pyc -------------------------------------------------------------------------------- /solver/__pycache__/lr_scheduler.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stone96123/DPM/a7bb2267b22de7cebfcabae5fd998c71b0f94b13/solver/__pycache__/lr_scheduler.cpython-37.pyc -------------------------------------------------------------------------------- /solver/__pycache__/make_optimizer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stone96123/DPM/a7bb2267b22de7cebfcabae5fd998c71b0f94b13/solver/__pycache__/make_optimizer.cpython-37.pyc -------------------------------------------------------------------------------- /solver/__pycache__/scheduler.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stone96123/DPM/a7bb2267b22de7cebfcabae5fd998c71b0f94b13/solver/__pycache__/scheduler.cpython-37.pyc -------------------------------------------------------------------------------- /solver/__pycache__/scheduler_factory.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stone96123/DPM/a7bb2267b22de7cebfcabae5fd998c71b0f94b13/solver/__pycache__/scheduler_factory.cpython-37.pyc -------------------------------------------------------------------------------- /solver/cosine_lr.py: -------------------------------------------------------------------------------- 1 | """ Cosine Scheduler 2 | 3 | Cosine LR schedule with warmup, cycle/restarts, noise. 4 | 5 | Hacked together by / Copyright 2020 Ross Wightman 6 | """ 7 | import logging 8 | import math 9 | import torch 10 | 11 | from .scheduler import Scheduler 12 | 13 | 14 | _logger = logging.getLogger(__name__) 15 | 16 | 17 | class CosineLRScheduler(Scheduler): 18 | """ 19 | Cosine decay with restarts. 20 | This is described in the paper https://arxiv.org/abs/1608.03983. 
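    Within cycle i of length t_i the learning rate follows
        lr = lr_min * g + 0.5 * (lr_max * g - lr_min * g) * (1 + cos(pi * t_curr / t_i)),  g = decay_rate ** i,
    so both the peak and the floor shrink after every restart (see _get_lr below).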
21 | 22 | Inspiration from 23 | https://github.com/allenai/allennlp/blob/master/allennlp/training/learning_rate_schedulers/cosine.py 24 | """ 25 | 26 | def __init__(self, 27 | optimizer: torch.optim.Optimizer, 28 | t_initial: int, 29 | t_mul: float = 1., 30 | lr_min: float = 0., 31 | decay_rate: float = 1., 32 | warmup_t=0, 33 | warmup_lr_init=0, 34 | warmup_prefix=False, 35 | cycle_limit=0, 36 | t_in_epochs=True, 37 | noise_range_t=None, 38 | noise_pct=0.67, 39 | noise_std=1.0, 40 | noise_seed=42, 41 | initialize=True) -> None: 42 | super().__init__( 43 | optimizer, param_group_field="lr", 44 | noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed, 45 | initialize=initialize) 46 | 47 | assert t_initial > 0 48 | assert lr_min >= 0 49 | if t_initial == 1 and t_mul == 1 and decay_rate == 1: 50 | _logger.warning("Cosine annealing scheduler will have no effect on the learning " 51 | "rate since t_initial = t_mul = eta_mul = 1.") 52 | self.t_initial = t_initial 53 | self.t_mul = t_mul 54 | self.lr_min = lr_min 55 | self.decay_rate = decay_rate 56 | self.cycle_limit = cycle_limit 57 | self.warmup_t = warmup_t 58 | self.warmup_lr_init = warmup_lr_init 59 | self.warmup_prefix = warmup_prefix 60 | self.t_in_epochs = t_in_epochs 61 | if self.warmup_t: 62 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] 63 | super().update_groups(self.warmup_lr_init) 64 | else: 65 | self.warmup_steps = [1 for _ in self.base_values] 66 | 67 | def _get_lr(self, t): 68 | if t < self.warmup_t: 69 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] 70 | else: 71 | if self.warmup_prefix: 72 | t = t - self.warmup_t 73 | 74 | if self.t_mul != 1: 75 | i = math.floor(math.log(1 - t / self.t_initial * (1 - self.t_mul), self.t_mul)) 76 | t_i = self.t_mul ** i * self.t_initial 77 | t_curr = t - (1 - self.t_mul ** i) / (1 - self.t_mul) * self.t_initial 78 | else: 79 | i = t // self.t_initial 80 | t_i = self.t_initial 81 | t_curr = t - (self.t_initial * i) 82 | 83 | gamma = self.decay_rate ** i 84 | lr_min = self.lr_min * gamma 85 | lr_max_values = [v * gamma for v in self.base_values] 86 | 87 | if self.cycle_limit == 0 or (self.cycle_limit > 0 and i < self.cycle_limit): 88 | lrs = [ 89 | lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * t_curr / t_i)) for lr_max in lr_max_values 90 | ] 91 | else: 92 | lrs = [self.lr_min for _ in self.base_values] 93 | 94 | return lrs 95 | 96 | def get_epoch_values(self, epoch: int): 97 | if self.t_in_epochs: 98 | return self._get_lr(epoch) 99 | else: 100 | return None 101 | 102 | def get_update_values(self, num_updates: int): 103 | if not self.t_in_epochs: 104 | return self._get_lr(num_updates) 105 | else: 106 | return None 107 | 108 | def get_cycle_length(self, cycles=0): 109 | if not cycles: 110 | cycles = self.cycle_limit 111 | cycles = max(1, cycles) 112 | if self.t_mul == 1.0: 113 | return self.t_initial * cycles 114 | else: 115 | return int(math.floor(-self.t_initial * (self.t_mul ** cycles - 1) / (1 - self.t_mul))) 116 | -------------------------------------------------------------------------------- /solver/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | from bisect import bisect_right 7 | import torch 8 | 9 | 10 | # FIXME ideally this would be achieved with a CombinedLRScheduler, 11 | # separating MultiStepLR with WarmupLR 12 | # but the 
current LRScheduler design doesn't allow it 13 | 14 | class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler): 15 | def __init__( 16 | self, 17 | optimizer, 18 | milestones, # steps 19 | gamma=0.1, 20 | warmup_factor=1.0 / 3, 21 | warmup_iters=500, 22 | warmup_method="linear", 23 | last_epoch=-1, 24 | ): 25 | if not list(milestones) == sorted(milestones): 26 | raise ValueError( 27 | "Milestones should be a list of" " increasing integers. Got {}", 28 | milestones, 29 | ) 30 | 31 | if warmup_method not in ("constant", "linear"): 32 | raise ValueError( 33 | "Only 'constant' or 'linear' warmup_method accepted" 34 | "got {}".format(warmup_method) 35 | ) 36 | self.milestones = milestones 37 | self.gamma = gamma 38 | self.warmup_factor = warmup_factor 39 | self.warmup_iters = warmup_iters 40 | self.warmup_method = warmup_method 41 | super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch) 42 | 43 | def _get_lr(self): 44 | warmup_factor = 1 45 | if self.last_epoch < self.warmup_iters: 46 | if self.warmup_method == "constant": 47 | warmup_factor = self.warmup_factor 48 | elif self.warmup_method == "linear": 49 | alpha = self.last_epoch / self.warmup_iters 50 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha 51 | return [ 52 | base_lr 53 | * warmup_factor 54 | * self.gamma ** bisect_right(self.milestones, self.last_epoch) 55 | for base_lr in self.base_lrs 56 | ] 57 | -------------------------------------------------------------------------------- /solver/make_optimizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def make_optimizer(cfg, model, center_criterion): 4 | params = [] 5 | Mparams = [] 6 | for key, value in model.named_parameters(): 7 | if "ConvLayer" in key or "attfc" in key: 8 | print(key) 9 | lr = cfg.SOLVER.BASE_LR 10 | if "ConvLayer" in key and "weight" in key: 11 | lr = cfg.SOLVER.BASE_LR 12 | weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS 13 | if "attfc" in key and "weight" in key: 14 | lr = cfg.SOLVER.BASE_LR 15 | weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS 16 | if "bias" in key: 17 | lr = cfg.SOLVER.BASE_LR * cfg.SOLVER.BIAS_LR_FACTOR 18 | weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS 19 | Mparams += [{"params": [value], "lr": lr, "weight_decay": weight_decay}] 20 | else: 21 | if not value.requires_grad: 22 | continue 23 | lr = cfg.SOLVER.BASE_LR 24 | weight_decay = cfg.SOLVER.WEIGHT_DECAY 25 | if "bias" in key: 26 | lr = cfg.SOLVER.BASE_LR * cfg.SOLVER.BIAS_LR_FACTOR 27 | weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS 28 | if cfg.SOLVER.LARGE_FC_LR: 29 | if "classifier" in key or "arcface" in key: 30 | lr = cfg.SOLVER.BASE_LR * 2 31 | print('Using two times learning rate for fc ') 32 | 33 | params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}] 34 | 35 | if cfg.SOLVER.OPTIMIZER_NAME == 'SGD': 36 | optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(params, momentum=cfg.SOLVER.MOMENTUM) 37 | Moptimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(Mparams, momentum=cfg.SOLVER.MOMENTUM) 38 | elif cfg.SOLVER.OPTIMIZER_NAME == 'AdamW': 39 | optimizer = torch.optim.AdamW(params, lr=cfg.SOLVER.BASE_LR, weight_decay=cfg.SOLVER.WEIGHT_DECAY) 40 | else: 41 | optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(params) 42 | #optimizer_center = torch.optim.SGD(center_criterion.parameters(), lr=cfg.SOLVER.CENTER_LR) 43 | 44 | return optimizer, Moptimizer 45 | -------------------------------------------------------------------------------- /solver/scheduler.py: 
-------------------------------------------------------------------------------- 1 | from typing import Dict, Any 2 | 3 | import torch 4 | 5 | 6 | class Scheduler: 7 | """ Parameter Scheduler Base Class 8 | A scheduler base class that can be used to schedule any optimizer parameter groups. 9 | 10 | Unlike the builtin PyTorch schedulers, this is intended to be consistently called 11 | * At the END of each epoch, before incrementing the epoch count, to calculate next epoch's value 12 | * At the END of each optimizer update, after incrementing the update count, to calculate next update's value 13 | 14 | The schedulers built on this should try to remain as stateless as possible (for simplicity). 15 | 16 | This family of schedulers is attempting to avoid the confusion of the meaning of 'last_epoch' 17 | and -1 values for special behaviour. All epoch and update counts must be tracked in the training 18 | code and explicitly passed in to the schedulers on the corresponding step or step_update call. 19 | 20 | Based on ideas from: 21 | * https://github.com/pytorch/fairseq/tree/master/fairseq/optim/lr_scheduler 22 | * https://github.com/allenai/allennlp/tree/master/allennlp/training/learning_rate_schedulers 23 | """ 24 | 25 | def __init__(self, 26 | optimizer: torch.optim.Optimizer, 27 | param_group_field: str, 28 | noise_range_t=None, 29 | noise_type='normal', 30 | noise_pct=0.67, 31 | noise_std=1.0, 32 | noise_seed=None, 33 | initialize: bool = True) -> None: 34 | self.optimizer = optimizer 35 | self.param_group_field = param_group_field 36 | self._initial_param_group_field = f"initial_{param_group_field}" 37 | if initialize: 38 | for i, group in enumerate(self.optimizer.param_groups): 39 | if param_group_field not in group: 40 | raise KeyError(f"{param_group_field} missing from param_groups[{i}]") 41 | group.setdefault(self._initial_param_group_field, group[param_group_field]) 42 | else: 43 | for i, group in enumerate(self.optimizer.param_groups): 44 | if self._initial_param_group_field not in group: 45 | raise KeyError(f"{self._initial_param_group_field} missing from param_groups[{i}]") 46 | self.base_values = [group[self._initial_param_group_field] for group in self.optimizer.param_groups] 47 | self.metric = None # any point to having this for all? 
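        # Optional LR noise: _add_noise() below perturbs the schedule only for steps inside
        # noise_range_t, drawing from a generator seeded with noise_seed + t for reproducibility.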
48 | self.noise_range_t = noise_range_t 49 | self.noise_pct = noise_pct 50 | self.noise_type = noise_type 51 | self.noise_std = noise_std 52 | self.noise_seed = noise_seed if noise_seed is not None else 42 53 | self.update_groups(self.base_values) 54 | 55 | def state_dict(self) -> Dict[str, Any]: 56 | return {key: value for key, value in self.__dict__.items() if key != 'optimizer'} 57 | 58 | def load_state_dict(self, state_dict: Dict[str, Any]) -> None: 59 | self.__dict__.update(state_dict) 60 | 61 | def get_epoch_values(self, epoch: int): 62 | return None 63 | 64 | def get_update_values(self, num_updates: int): 65 | return None 66 | 67 | def step(self, epoch: int, metric: float = None) -> None: 68 | self.metric = metric 69 | values = self.get_epoch_values(epoch) 70 | if values is not None: 71 | values = self._add_noise(values, epoch) 72 | self.update_groups(values) 73 | 74 | def step_update(self, num_updates: int, metric: float = None): 75 | self.metric = metric 76 | values = self.get_update_values(num_updates) 77 | if values is not None: 78 | values = self._add_noise(values, num_updates) 79 | self.update_groups(values) 80 | 81 | def update_groups(self, values): 82 | if not isinstance(values, (list, tuple)): 83 | values = [values] * len(self.optimizer.param_groups) 84 | for param_group, value in zip(self.optimizer.param_groups, values): 85 | param_group[self.param_group_field] = value 86 | 87 | def _add_noise(self, lrs, t): 88 | if self.noise_range_t is not None: 89 | if isinstance(self.noise_range_t, (list, tuple)): 90 | apply_noise = self.noise_range_t[0] <= t < self.noise_range_t[1] 91 | else: 92 | apply_noise = t >= self.noise_range_t 93 | if apply_noise: 94 | g = torch.Generator() 95 | g.manual_seed(self.noise_seed + t) 96 | if self.noise_type == 'normal': 97 | while True: 98 | # resample if noise out of percent limit, brute force but shouldn't spin much 99 | noise = torch.randn(1, generator=g).item() 100 | if abs(noise) < self.noise_pct: 101 | break 102 | else: 103 | noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct 104 | lrs = [v + v * noise for v in lrs] 105 | return lrs 106 | -------------------------------------------------------------------------------- /solver/scheduler_factory.py: -------------------------------------------------------------------------------- 1 | """ Scheduler Factory 2 | Hacked together by / Copyright 2020 Ross Wightman 3 | """ 4 | from .cosine_lr import CosineLRScheduler 5 | 6 | 7 | def create_scheduler(cfg, optimizer): 8 | num_epochs = cfg.SOLVER.MAX_EPOCHS 9 | # type 1 10 | # lr_min = 0.01 * cfg.SOLVER.BASE_LR 11 | # warmup_lr_init = 0.001 * cfg.SOLVER.BASE_LR 12 | # type 2 13 | lr_min = 0.002 * cfg.SOLVER.BASE_LR 14 | warmup_lr_init = 0.01 * cfg.SOLVER.BASE_LR 15 | # type 3 16 | # lr_min = 0.001 * cfg.SOLVER.BASE_LR 17 | # warmup_lr_init = 0.01 * cfg.SOLVER.BASE_LR 18 | 19 | warmup_t = cfg.SOLVER.WARMUP_EPOCHS 20 | noise_range = None 21 | 22 | lr_scheduler = CosineLRScheduler( 23 | optimizer, 24 | t_initial=num_epochs, 25 | lr_min=lr_min, 26 | t_mul= 1., 27 | decay_rate=0.1, 28 | warmup_lr_init=warmup_lr_init, 29 | warmup_t=warmup_t, 30 | cycle_limit=1, 31 | t_in_epochs=True, 32 | noise_range_t=noise_range, 33 | noise_pct= 0.67, 34 | noise_std= 1., 35 | noise_seed=42, 36 | ) 37 | 38 | return lr_scheduler 39 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os 2 | from config import cfg 3 | import 
argparse 4 | from datasets import make_dataloader 5 | from model import make_model 6 | from processor import do_inference 7 | from utils.logger import setup_logger 8 | 9 | 10 | if __name__ == "__main__": 11 | parser = argparse.ArgumentParser(description="ReID Baseline Training") 12 | parser.add_argument( 13 | "--config_file", default="", help="path to config file", type=str 14 | ) 15 | parser.add_argument("opts", help="Modify config options using the command-line", default=None, 16 | nargs=argparse.REMAINDER) 17 | 18 | args = parser.parse_args() 19 | 20 | 21 | 22 | if args.config_file != "": 23 | cfg.merge_from_file(args.config_file) 24 | cfg.merge_from_list(args.opts) 25 | cfg.freeze() 26 | 27 | output_dir = cfg.OUTPUT_DIR 28 | if output_dir and not os.path.exists(output_dir): 29 | os.makedirs(output_dir) 30 | 31 | logger = setup_logger("transreid", output_dir, if_train=False) 32 | logger.info(args) 33 | 34 | if args.config_file != "": 35 | logger.info("Loaded configuration file {}".format(args.config_file)) 36 | with open(args.config_file, 'r') as cf: 37 | config_str = "\n" + cf.read() 38 | logger.info(config_str) 39 | logger.info("Running with config:\n{}".format(cfg)) 40 | 41 | os.environ['CUDA_VISIBLE_DEVICES'] = cfg.MODEL.DEVICE_ID 42 | 43 | train_loader, train_loader_normal, val_loader, num_query, num_classes, camera_num, view_num = make_dataloader(cfg) 44 | 45 | model = make_model(cfg, num_class=num_classes, camera_num=camera_num, view_num = view_num) 46 | model.load_param(cfg.TEST.WEIGHT) 47 | 48 | if cfg.DATASETS.NAMES == 'VehicleID': 49 | for trial in range(10): 50 | train_loader, train_loader_normal, val_loader, num_query, num_classes, camera_num, view_num = make_dataloader(cfg) 51 | rank_1, rank5 = do_inference(cfg, 52 | model, 53 | val_loader, 54 | num_query) 55 | if trial == 0: 56 | all_rank_1 = rank_1 57 | all_rank_5 = rank5 58 | else: 59 | all_rank_1 = all_rank_1 + rank_1 60 | all_rank_5 = all_rank_5 + rank5 61 | 62 | logger.info("rank_1:{}, rank_5 {} : trial : {}".format(rank_1, rank5, trial)) 63 | logger.info("sum_rank_1:{:.1%}, sum_rank_5 {:.1%}".format(all_rank_1.sum()/10.0, all_rank_5.sum()/10.0)) 64 | else: 65 | do_inference(cfg, 66 | model, 67 | val_loader, 68 | num_query) 69 | 70 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from utils.logger import setup_logger 2 | from datasets import make_dataloader 3 | from model import make_model 4 | from solver import make_optimizer 5 | from solver.scheduler_factory import create_scheduler 6 | from loss import make_loss 7 | from processor import do_train 8 | import random 9 | import torch 10 | import numpy as np 11 | import os 12 | import argparse 13 | # from timm.scheduler import create_scheduler 14 | from config import cfg 15 | 16 | def set_seed(seed): 17 | torch.manual_seed(seed) 18 | torch.cuda.manual_seed(seed) 19 | torch.cuda.manual_seed_all(seed) 20 | np.random.seed(seed) 21 | random.seed(seed) 22 | torch.backends.cudnn.deterministic = True 23 | torch.backends.cudnn.benchmark = True 24 | 25 | if __name__ == '__main__': 26 | 27 | parser = argparse.ArgumentParser(description="ReID Baseline Training") 28 | parser.add_argument( 29 | "--config_file", default="", help="path to config file", type=str 30 | ) 31 | 32 | parser.add_argument("opts", help="Modify config options using the command-line", default=None, 33 | nargs=argparse.REMAINDER) 34 | parser.add_argument("--local_rank", default=0, 
type=int) 35 | args = parser.parse_args() 36 | 37 | if args.config_file != "": 38 | cfg.merge_from_file(args.config_file) 39 | cfg.merge_from_list(args.opts) 40 | cfg.freeze() 41 | 42 | set_seed(cfg.SOLVER.SEED) 43 | 44 | if cfg.MODEL.DIST_TRAIN: 45 | torch.cuda.set_device(args.local_rank) 46 | 47 | output_dir = cfg.OUTPUT_DIR 48 | if output_dir and not os.path.exists(output_dir): 49 | os.makedirs(output_dir) 50 | 51 | logger = setup_logger("transreid", output_dir, if_train=True) 52 | logger.info("Saving model in the path :{}".format(cfg.OUTPUT_DIR)) 53 | logger.info(args) 54 | 55 | if args.config_file != "": 56 | logger.info("Loaded configuration file {}".format(args.config_file)) 57 | with open(args.config_file, 'r') as cf: 58 | config_str = "\n" + cf.read() 59 | logger.info(config_str) 60 | logger.info("Running with config:\n{}".format(cfg)) 61 | 62 | if cfg.MODEL.DIST_TRAIN: 63 | torch.distributed.init_process_group(backend='nccl', init_method='env://') 64 | 65 | os.environ['CUDA_VISIBLE_DEVICES'] = cfg.MODEL.DEVICE_ID 66 | train_loader, train_loader_normal, val_loader, num_query, num_classes, camera_num, view_num = make_dataloader(cfg) 67 | 68 | model = make_model(cfg, num_class=num_classes, camera_num=camera_num, view_num = view_num) 69 | 70 | loss_func, center_criterion = make_loss(cfg, num_classes=num_classes) 71 | 72 | optimizer, Moptimizer = make_optimizer(cfg, model, center_criterion) 73 | 74 | scheduler = create_scheduler(cfg, optimizer) 75 | 76 | do_train( 77 | cfg, 78 | model, 79 | center_criterion, 80 | train_loader, 81 | val_loader, 82 | optimizer, 83 | Moptimizer, 84 | scheduler, 85 | loss_func, 86 | num_query, args.local_rank 87 | ) 88 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stone96123/DPM/a7bb2267b22de7cebfcabae5fd998c71b0f94b13/utils/__init__.py -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stone96123/DPM/a7bb2267b22de7cebfcabae5fd998c71b0f94b13/utils/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stone96123/DPM/a7bb2267b22de7cebfcabae5fd998c71b0f94b13/utils/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/iotools.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stone96123/DPM/a7bb2267b22de7cebfcabae5fd998c71b0f94b13/utils/__pycache__/iotools.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/logger.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stone96123/DPM/a7bb2267b22de7cebfcabae5fd998c71b0f94b13/utils/__pycache__/logger.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/logger.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/stone96123/DPM/a7bb2267b22de7cebfcabae5fd998c71b0f94b13/utils/__pycache__/logger.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/meter.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stone96123/DPM/a7bb2267b22de7cebfcabae5fd998c71b0f94b13/utils/__pycache__/meter.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/metrics.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stone96123/DPM/a7bb2267b22de7cebfcabae5fd998c71b0f94b13/utils/__pycache__/metrics.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/reranking.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stone96123/DPM/a7bb2267b22de7cebfcabae5fd998c71b0f94b13/utils/__pycache__/reranking.cpython-37.pyc -------------------------------------------------------------------------------- /utils/iotools.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import errno 8 | import json 9 | import os 10 | 11 | import os.path as osp 12 | 13 | 14 | def mkdir_if_missing(directory): 15 | if not osp.exists(directory): 16 | try: 17 | os.makedirs(directory) 18 | except OSError as e: 19 | if e.errno != errno.EEXIST: 20 | raise 21 | 22 | 23 | def check_isfile(path): 24 | isfile = osp.isfile(path) 25 | if not isfile: 26 | print("=> Warning: no file found at '{}' (ignored)".format(path)) 27 | return isfile 28 | 29 | 30 | def read_json(fpath): 31 | with open(fpath, 'r') as f: 32 | obj = json.load(f) 33 | return obj 34 | 35 | 36 | def write_json(obj, fpath): 37 | mkdir_if_missing(osp.dirname(fpath)) 38 | with open(fpath, 'w') as f: 39 | json.dump(obj, f, indent=4, separators=(',', ': ')) 40 | -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import sys 4 | import os.path as osp 5 | def setup_logger(name, save_dir, if_train): 6 | logger = logging.getLogger(name) 7 | logger.setLevel(logging.DEBUG) 8 | 9 | ch = logging.StreamHandler(stream=sys.stdout) 10 | ch.setLevel(logging.DEBUG) 11 | formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s") 12 | ch.setFormatter(formatter) 13 | logger.addHandler(ch) 14 | 15 | if save_dir: 16 | if not osp.exists(save_dir): 17 | os.makedirs(save_dir) 18 | if if_train: 19 | fh = logging.FileHandler(os.path.join(save_dir, "train_log.txt"), mode='w') 20 | else: 21 | fh = logging.FileHandler(os.path.join(save_dir, "test_log.txt"), mode='w') 22 | fh.setLevel(logging.DEBUG) 23 | fh.setFormatter(formatter) 24 | logger.addHandler(fh) 25 | 26 | return logger -------------------------------------------------------------------------------- /utils/meter.py: -------------------------------------------------------------------------------- 1 | class AverageMeter(object): 2 | """Computes and stores the average and current value""" 3 | 4 | def __init__(self): 5 | self.val = 0 6 | self.avg = 0 7 | self.sum = 0 8 | self.count = 0 9 | 10 | def reset(self): 11 
| self.val = 0 12 | self.avg = 0 13 | self.sum = 0 14 | self.count = 0 15 | 16 | def update(self, val, n=1): 17 | self.val = val 18 | self.sum += val * n 19 | self.count += n 20 | self.avg = self.sum / self.count -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import os 4 | from utils.reranking import re_ranking 5 | import torch.nn.functional as F 6 | 7 | def euclidean_distance(qf, gf): 8 | m = qf.shape[0] 9 | n = gf.shape[0] 10 | dist_mat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \ 11 | torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t() 12 | dist_mat.addmm_(1, -2, qf, gf.t()) 13 | return dist_mat.cpu().numpy() 14 | 15 | def cosine_similarity(qf, gf): 16 | epsilon = 0.00001 17 | dist_mat = qf.mm(gf.t()) 18 | qf_norm = torch.norm(qf, p=2, dim=1, keepdim=True) # mx1 19 | gf_norm = torch.norm(gf, p=2, dim=1, keepdim=True) # nx1 20 | qg_normdot = qf_norm.mm(gf_norm.t()) 21 | 22 | dist_mat = dist_mat.mul(1 / qg_normdot).cpu().numpy() 23 | dist_mat = np.clip(dist_mat, -1 + epsilon, 1 - epsilon) 24 | dist_mat = np.arccos(dist_mat) 25 | return dist_mat 26 | 27 | def cosine_single(qf, gf, qm, gm): 28 | epsilon = 0.00001 29 | for i in range (qf.shape[0]): 30 | if i == 0: 31 | query_vector = F.normalize(qf[i,:].unsqueeze(0)) 32 | gallery_mrtrix = F.normalize(qm[i,:].unsqueeze(0) * F.normalize(gf)) 33 | dist_mat = query_vector.mm(gallery_mrtrix.t()) 34 | else: 35 | query_vector = F.normalize(qf[i,:].unsqueeze(0)) 36 | gallery_mrtrix = F.normalize(qm[i,:].unsqueeze(0) * F.normalize(gf)) 37 | dist_single = query_vector.mm(gallery_mrtrix.t()) 38 | dist_mat = torch.cat((dist_mat, dist_single), 0) 39 | dist_mat = dist_mat.cpu().numpy() 40 | dist_mat = np.clip(dist_mat, -1 + epsilon, 1 - epsilon) 41 | dist_mat = np.arccos(dist_mat) 42 | return dist_mat 43 | 44 | def eval_func(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=50): 45 | """Evaluation with market1501 metric 46 | Key: for each query identity, its gallery images from the same camera view are discarded. 47 | """ 48 | num_q, num_g = distmat.shape 49 | # distmat g 50 | # q 1 3 2 4 51 | # 4 1 2 3 52 | if num_g < max_rank: 53 | max_rank = num_g 54 | print("Note: number of gallery samples is quite small, got {}".format(num_g)) 55 | indices = np.argsort(distmat, axis=1) 56 | # 0 2 1 3 57 | # 1 2 3 0 58 | matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32) 59 | # compute cmc curve for each query 60 | all_cmc = [] 61 | all_AP = [] 62 | num_valid_q = 0. # number of valid query 63 | for q_idx in range(num_q): 64 | # get query pid and camid 65 | q_pid = q_pids[q_idx] 66 | q_camid = q_camids[q_idx] 67 | 68 | # remove gallery samples that have the same pid and camid with query 69 | order = indices[q_idx] # select one row 70 | remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid) 71 | keep = np.invert(remove) 72 | 73 | # compute cmc curve 74 | # binary vector, positions with value 1 are correct matches 75 | orig_cmc = matches[q_idx][keep] 76 | if not np.any(orig_cmc): 77 | # this condition is true when query identity does not appear in gallery 78 | continue 79 | 80 | cmc = orig_cmc.cumsum() 81 | cmc[cmc > 1] = 1 82 | 83 | all_cmc.append(cmc[:max_rank]) 84 | num_valid_q += 1. 
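        # cmc[k] above is 1 iff a correct match appears within the top-(k+1) results for this
        # query; averaging these binary curves over num_valid_q queries gives the Rank-k scores.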
85 | 86 | # compute average precision 87 | # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision 88 | num_rel = orig_cmc.sum() 89 | tmp_cmc = orig_cmc.cumsum() 90 | #tmp_cmc = [x / (i + 1.) for i, x in enumerate(tmp_cmc)] 91 | y = np.arange(1, tmp_cmc.shape[0] + 1) * 1.0 92 | tmp_cmc = tmp_cmc / y 93 | tmp_cmc = np.asarray(tmp_cmc) * orig_cmc 94 | AP = tmp_cmc.sum() / num_rel 95 | all_AP.append(AP) 96 | 97 | assert num_valid_q > 0, "Error: all query identities do not appear in gallery" 98 | 99 | all_cmc = np.asarray(all_cmc).astype(np.float32) 100 | all_cmc = all_cmc.sum(0) / num_valid_q 101 | mAP = np.mean(all_AP) 102 | 103 | return all_cmc, mAP 104 | 105 | 106 | class R1_mAP_eval(): 107 | def __init__(self, num_query, max_rank=50, feat_norm=True, reranking=False): 108 | super(R1_mAP_eval, self).__init__() 109 | self.num_query = num_query 110 | self.max_rank = max_rank 111 | self.feat_norm = feat_norm 112 | self.reranking = reranking 113 | 114 | def reset(self): 115 | self.feats = [] 116 | self.masks = [] 117 | self.pids = [] 118 | self.camids = [] 119 | 120 | def update(self, output): # called once for each batch 121 | feat, mask, pid, camid = output 122 | self.feats.append(feat.cpu()) 123 | self.masks.append(mask.cpu()) 124 | self.pids.extend(np.asarray(pid)) 125 | self.camids.extend(np.asarray(camid)) 126 | 127 | def compute(self): # called after each epoch 128 | feats = torch.cat(self.feats, dim=0) 129 | masks = torch.cat(self.masks, dim=0) 130 | if self.feat_norm: 131 | print("The test feature is normalized") 132 | feats = torch.nn.functional.normalize(feats, dim=1, p=2) # along channel 133 | # query 134 | qf = feats[:self.num_query] 135 | qm = masks[:self.num_query] 136 | q_pids = np.asarray(self.pids[:self.num_query]) 137 | q_camids = np.asarray(self.camids[:self.num_query]) 138 | # gallery 139 | gf = feats[self.num_query:] 140 | gm = masks[self.num_query:] 141 | g_pids = np.asarray(self.pids[self.num_query:]) 142 | 143 | g_camids = np.asarray(self.camids[self.num_query:]) 144 | if self.reranking: 145 | print('=> Enter reranking') 146 | # distmat = re_ranking(qf, gf, k1=20, k2=6, lambda_value=0.3) 147 | distmat = re_ranking(qf, gf, k1=50, k2=15, lambda_value=0.3) 148 | 149 | else: 150 | print('=> Computing DistMat with euclidean_distance') 151 | distmat = cosine_single(qf,gf,qm,gm) 152 | cmc, mAP = eval_func(distmat, q_pids, g_pids, q_camids, g_camids) 153 | 154 | return cmc, mAP, distmat, self.pids, self.camids, qf, gf 155 | 156 | 157 | 158 | -------------------------------------------------------------------------------- /utils/reranking.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Fri, 25 May 2018 20:29:09 5 | 6 | 7 | """ 8 | 9 | """ 10 | CVPR2017 paper:Zhong Z, Zheng L, Cao D, et al. Re-ranking Person Re-identification with k-reciprocal Encoding[J]. 2017. 
11 | url: http://openaccess.thecvf.com/content_cvpr_2017/papers/Zhong_Re-Ranking_Person_Re-Identification_CVPR_2017_paper.pdf
12 | Matlab version: https://github.com/zhunzhong07/person-re-ranking
13 | """
14 | 
15 | """
16 | API
17 | 
18 | probFea: all feature vectors of the query set (torch tensor)
19 | galFea: all feature vectors of the gallery set (torch tensor)
20 | k1, k2, lambda_value: parameters; the original paper uses (k1=20, k2=6, lambda=0.3). A short usage sketch is given after this file.
21 | MemorySave: set to 'True' when using MemorySave mode
22 | Minibatch: available when 'MemorySave' is 'True'
23 | """
24 | 
25 | import numpy as np
26 | import torch
27 | 
28 | 
29 | def re_ranking(probFea, galFea, k1, k2, lambda_value, local_distmat=None, only_local=False):
30 |     # if the features are numpy arrays, convert them to tensors with 'torch.tensor' first
31 |     query_num = probFea.size(0)
32 |     all_num = query_num + galFea.size(0)
33 |     if only_local:
34 |         original_dist = local_distmat
35 |     else:
36 |         feat = torch.cat([probFea, galFea])
37 |         # print('using GPU to compute original distance')
38 |         distmat = torch.pow(feat, 2).sum(dim=1, keepdim=True).expand(all_num, all_num) + \
39 |                   torch.pow(feat, 2).sum(dim=1, keepdim=True).expand(all_num, all_num).t()
40 |         distmat.addmm_(feat, feat.t(), beta=1, alpha=-2)
41 |         original_dist = distmat.cpu().numpy()
42 |         del feat
43 |         if local_distmat is not None:
44 |             original_dist = original_dist + local_distmat
45 |     gallery_num = original_dist.shape[0]
46 |     original_dist = np.transpose(original_dist / np.max(original_dist, axis=0))
47 |     V = np.zeros_like(original_dist).astype(np.float16)
48 |     initial_rank = np.argsort(original_dist).astype(np.int32)
49 | 
50 |     # print('starting re_ranking')
51 |     for i in range(all_num):
52 |         # k-reciprocal neighbors
53 |         forward_k_neigh_index = initial_rank[i, :k1 + 1]
54 |         backward_k_neigh_index = initial_rank[forward_k_neigh_index, :k1 + 1]
55 |         fi = np.where(backward_k_neigh_index == i)[0]
56 |         k_reciprocal_index = forward_k_neigh_index[fi]
57 |         k_reciprocal_expansion_index = k_reciprocal_index
58 |         for j in range(len(k_reciprocal_index)):
59 |             candidate = k_reciprocal_index[j]
60 |             candidate_forward_k_neigh_index = initial_rank[candidate, :int(np.around(k1 / 2)) + 1]
61 |             candidate_backward_k_neigh_index = initial_rank[candidate_forward_k_neigh_index,
62 |                                                :int(np.around(k1 / 2)) + 1]
63 |             fi_candidate = np.where(candidate_backward_k_neigh_index == candidate)[0]
64 |             candidate_k_reciprocal_index = candidate_forward_k_neigh_index[fi_candidate]
65 |             if len(np.intersect1d(candidate_k_reciprocal_index, k_reciprocal_index)) > 2 / 3 * len(
66 |                     candidate_k_reciprocal_index):
67 |                 k_reciprocal_expansion_index = np.append(k_reciprocal_expansion_index, candidate_k_reciprocal_index)
68 | 
69 |         k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index)
70 |         weight = np.exp(-original_dist[i, k_reciprocal_expansion_index])
71 |         V[i, k_reciprocal_expansion_index] = weight / np.sum(weight)
72 |     original_dist = original_dist[:query_num, :]
73 |     if k2 != 1:
74 |         V_qe = np.zeros_like(V, dtype=np.float16)
75 |         for i in range(all_num):
76 |             V_qe[i, :] = np.mean(V[initial_rank[i, :k2], :], axis=0)
77 |         V = V_qe
78 |         del V_qe
79 |     del initial_rank
80 |     invIndex = []
81 |     for i in range(gallery_num):
82 |         invIndex.append(np.where(V[:, i] != 0)[0])
83 | 
84 |     jaccard_dist = np.zeros_like(original_dist, dtype=np.float16)
85 | 
86 |     for i in range(query_num):
87 |         temp_min = np.zeros(shape=[1, gallery_num], dtype=np.float16)
88 |         indNonZero = np.where(V[i, :] != 0)[0]
89 |         indImages = [invIndex[ind] for ind in indNonZero]
90 |         for j in range(len(indNonZero)):
91 |             temp_min[0, indImages[j]] = temp_min[0, indImages[j]] + np.minimum(V[i, indNonZero[j]],
92 |                                                                               V[indImages[j], indNonZero[j]])
93 |         jaccard_dist[i] = 1 - temp_min / (2 - temp_min)
94 | 
95 |     final_dist = jaccard_dist * (1 - lambda_value) + original_dist * lambda_value
96 |     del original_dist
97 |     del V
98 |     del jaccard_dist
99 |     final_dist = final_dist[:query_num, query_num:]
100 |     return final_dist
101 | 
102 | 
--------------------------------------------------------------------------------
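For reference, a minimal usage sketch of the evaluation utilities above. The `model`, `val_loader`, and `num_query` names here are placeholders rather than objects defined in these files; the repository's own entry point for evaluation is `test.py` (see the Evaluation section of the README). Setting `reranking=True` makes `compute()` call `re_ranking(qf, gf, k1=50, k2=15, lambda_value=0.3)` instead of the mask-guided cosine distance:

```python
import torch
from utils.metrics import R1_mAP_eval

# num_query: number of query images placed at the front of val_loader (placeholder value)
evaluator = R1_mAP_eval(num_query, max_rank=50, feat_norm=True, reranking=False)
evaluator.reset()

model.eval()
with torch.no_grad():
    for imgs, pids, camids in val_loader:      # placeholder loader yielding images, person ids, camera ids
        feat, mask = model(imgs.cuda())        # assumed to return a feature and its mask, matching update()
        evaluator.update((feat, mask, pids, camids))

cmc, mAP, distmat, pids, camids, qf, gf = evaluator.compute()
print("mAP: {:.1%}  Rank-1: {:.1%}".format(mAP, cmc[0]))
```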