├── README.md ├── configs ├── __pycache__ │ ├── default_img.cpython-36.pyc │ ├── default_img.cpython-37.pyc │ ├── default_vid.cpython-36.pyc │ └── default_vid.cpython-37.pyc ├── c2dres50_ce_cal.yaml ├── default_img.py ├── default_vid.py ├── res50_cels_cal.yaml ├── res50_cels_cal_16x4.yaml └── res50_cels_cal_tri_16x4.yaml ├── data ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-37.pyc │ ├── dataloader.cpython-36.pyc │ ├── dataloader.cpython-37.pyc │ ├── dataset_loader.cpython-36.pyc │ ├── dataset_loader.cpython-37.pyc │ ├── dataset_loader3.cpython-36.pyc │ ├── img_transforms.cpython-36.pyc │ ├── img_transforms.cpython-37.pyc │ ├── img_transforms3.cpython-36.pyc │ ├── samplers.cpython-36.pyc │ ├── samplers.cpython-37.pyc │ ├── spatial_transforms.cpython-36.pyc │ ├── spatial_transforms.cpython-37.pyc │ ├── temporal_transforms.cpython-36.pyc │ └── temporal_transforms.cpython-37.pyc ├── dataloader.py ├── dataset_loader.py ├── datasets │ ├── __pycache__ │ │ ├── ccvid.cpython-36.pyc │ │ ├── deepchange.cpython-36.pyc │ │ ├── last.cpython-36.pyc │ │ ├── ltcc.cpython-36.pyc │ │ ├── ltcc.cpython-37.pyc │ │ ├── ltcc3.cpython-36.pyc │ │ ├── prcc.cpython-36.pyc │ │ └── vcclothes.cpython-36.pyc │ ├── ccvid.py │ ├── deepchange.py │ ├── last.py │ ├── ltcc.py │ ├── prcc.py │ └── vcclothes.py ├── img_transforms.py ├── samplers.py ├── spatial_transforms.py └── temporal_transforms.py ├── infer.py ├── losses ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── arcface_loss.cpython-36.pyc │ ├── circle_loss.cpython-36.pyc │ ├── clothes_based_adversarial_loss.cpython-36.pyc │ ├── contrastive_loss.cpython-36.pyc │ ├── cosface_loss.cpython-36.pyc │ ├── cross_entropy_loss_with_label_smooth.cpython-36.pyc │ ├── gather.cpython-36.pyc │ └── triplet_loss.cpython-36.pyc ├── arcface_loss.py ├── circle_loss.py ├── clothes_based_adversarial_loss.py ├── contrastive_loss.py ├── cosface_loss.py ├── cross_entropy_loss_with_label_smooth.py ├── gather.py └── triplet_loss.py ├── main.py ├── models ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── classifier.cpython-36.pyc │ ├── img_resnet.cpython-36.pyc │ ├── img_resnet_3.cpython-36.pyc │ ├── img_resnet_fc.cpython-36.pyc │ ├── img_resnet_sep.cpython-36.pyc │ └── vid_resnet.cpython-36.pyc ├── classifier.py ├── img_resnet.py ├── utils │ ├── __pycache__ │ │ ├── c3d_blocks.cpython-36.pyc │ │ ├── inflate.cpython-36.pyc │ │ ├── nonlocal_blocks.cpython-36.pyc │ │ └── pooling.cpython-36.pyc │ ├── c3d_blocks.py │ ├── inflate.py │ ├── nonlocal_blocks.py │ └── pooling.py └── vid_resnet.py ├── test.py ├── tools ├── __pycache__ │ ├── eval_metrics.cpython-36.pyc │ └── utils.cpython-36.pyc ├── eval_metrics.py └── utils.py └── train.py /README.md: -------------------------------------------------------------------------------- 1 | # Introduction 2 | 3 | This is the source code of our TCSVT 2023 paper "DCR-ReID: Deep Component Reconstruction for Cloth-Changing Person Re-Identification". Please cite the following paper if you use our code. 4 | 5 | Zhenyu Cui, Jiahuan Zhou, Yuxin Peng, Shiliang Zhang and Yaowei Wang, "DCR-ReID: Deep Component Reconstruction for Cloth-Changing Person Re-Identification", IEEE Transactions on Circuits and Systems for Video Technology (TCSVT), 2023. 
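If you need a BibTeX entry, something along these lines should work (the entry key is a placeholder, and the volume/issue/page fields are omitted because they are not listed here):

```bibtex
@article{cui2023dcrreid,
  author  = {Zhenyu Cui and Jiahuan Zhou and Yuxin Peng and Shiliang Zhang and Yaowei Wang},
  title   = {DCR-ReID: Deep Component Reconstruction for Cloth-Changing Person Re-Identification},
  journal = {IEEE Transactions on Circuits and Systems for Video Technology (TCSVT)},
  year    = {2023}
}
```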
6 | 7 | 8 | 9 | # Dependencies 10 | 11 | - Python 3.6 12 | 13 | - PyTorch 1.6.0 14 | 15 | - yacs 16 | 17 | - apex 18 | 19 | 20 | 21 | # Data Preparation 22 | 23 | - Download the pre-processed datasets that we used from this [link](https://pan.baidu.com/s/1LwAyB1R86P3xMZxIPm1vwQ) (password: dg1a) and unzip them to the `./datasets` folder. 24 | 25 | 26 | # Usage 27 | 28 | - Replace `_C.DATA.ROOT` and `_C.OUTPUT` in `configs/default_img.py` and `configs/default_vid.py` with your own data path and output path, respectively. 29 | 30 | - Start training by executing the following commands. 31 | 32 | 1. For the LTCC dataset: `python -m torch.distributed.launch --nproc_per_node=2 --master_port 12345 main.py --dataset ltcc --cfg configs/res50_cels_cal.yaml --gpu 0,1 --spr 0 --sacr 0.05 --rr 1.0` 33 | 34 | 2. For the PRCC dataset: `python -m torch.distributed.launch --nproc_per_node=2 --master_port 12345 main.py --dataset prcc --cfg configs/res50_cels_cal.yaml --gpu 2,3 --spr 1.0 --sacr 0.05 --rr 1.0` 35 | 36 | 3. For the CCVID dataset: `python -m torch.distributed.launch --nproc_per_node=4 --master_port 12345 main.py --dataset ccvid --cfg configs/c2dres50_ce_cal.yaml --gpu 0,1,2,3` 37 | 38 | For any questions, feel free to contact us (cuizhenyu@stu.pku.edu.cn). 39 | 40 | Visit our [Laboratory Homepage](http://www.icst.pku.edu.cn/mipl/home/) for more information about our papers, source code, and datasets. 41 | 42 | -------------------------------------------------------------------------------- /configs/__pycache__/default_img.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ICST-MIPL/DCR-ReID_TCSVT2023/b6c58c48a97feb1c2b3f062b211c50e17bac274e/configs/__pycache__/default_img.cpython-36.pyc -------------------------------------------------------------------------------- /configs/__pycache__/default_img.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ICST-MIPL/DCR-ReID_TCSVT2023/b6c58c48a97feb1c2b3f062b211c50e17bac274e/configs/__pycache__/default_img.cpython-37.pyc -------------------------------------------------------------------------------- /configs/__pycache__/default_vid.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ICST-MIPL/DCR-ReID_TCSVT2023/b6c58c48a97feb1c2b3f062b211c50e17bac274e/configs/__pycache__/default_vid.cpython-36.pyc -------------------------------------------------------------------------------- /configs/__pycache__/default_vid.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ICST-MIPL/DCR-ReID_TCSVT2023/b6c58c48a97feb1c2b3f062b211c50e17bac274e/configs/__pycache__/default_vid.cpython-37.pyc -------------------------------------------------------------------------------- /configs/c2dres50_ce_cal.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | NAME: c2dres50 3 | LOSS: 4 | CLA_LOSS: crossentropy 5 | CAL: cal 6 | TAG: c2dres50-ce-cal -------------------------------------------------------------------------------- /configs/default_img.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | from yacs.config import CfgNode as CN 4 | 5 | 6 | _C = CN() 7 | # ----------------------------------------------------------------------------- 8 | # Data settings 9
| # ----------------------------------------------------------------------------- 10 | _C.DATA = CN() 11 | # Root path for dataset directory 12 | _C.DATA.ROOT = '/data1/cuizhenyu/Simple-CCReID-main/datasets/' 13 | # Dataset for evaluation 14 | _C.DATA.DATASET = 'ltcc' 15 | # Workers for dataloader 16 | _C.DATA.NUM_WORKERS = 4 17 | # Height of input image 18 | _C.DATA.HEIGHT = 384 19 | # Width of input image 20 | _C.DATA.WIDTH = 192 21 | # Batch size for training 22 | _C.DATA.TRAIN_BATCH = 32 #32 23 | # Batch size for testing 24 | _C.DATA.TEST_BATCH = 128 #128 25 | # The number of instances per identity for training sampler 26 | _C.DATA.NUM_INSTANCES = 8 27 | # ----------------------------------------------------------------------------- 28 | # Augmentation settings 29 | # ----------------------------------------------------------------------------- 30 | _C.AUG = CN() 31 | # Random crop prob 32 | _C.AUG.RC_PROB = 0.5 33 | # Random erase prob 34 | _C.AUG.RE_PROB = 0.5 35 | # Random flip prob 36 | _C.AUG.RF_PROB = 0.5 37 | # ----------------------------------------------------------------------------- 38 | # Parameters settings 39 | # ----------------------------------------------------------------------------- 40 | _C.PARA = CN() 41 | # Random crop prob 42 | _C.PARA.SHUF_PID_RATIO = 1.0 43 | # Random erase prob 44 | _C.PARA.SHUF_ADV_CLO_RATIO = 0.1 45 | # Random flip prob 46 | _C.PARA.RECON_RATIO = 1.0 47 | # ----------------------------------------------------------------------------- 48 | # Model settings 49 | # ----------------------------------------------------------------------------- 50 | _C.MODEL = CN() 51 | # Model name 52 | _C.MODEL.NAME = 'resnet50' 53 | # The stride for laery4 in resnet 54 | _C.MODEL.RES4_STRIDE = 1 55 | # feature dim 56 | _C.MODEL.FEATURE_DIM = 4096 57 | # Model path for resuming 58 | _C.MODEL.RESUME = '' 59 | # Global pooling after the backbone 60 | _C.MODEL.POOLING = CN() 61 | # Choose in ['avg', 'max', 'gem', 'maxavg'] 62 | _C.MODEL.POOLING.NAME = 'maxavg' 63 | # Initialized power for GeM pooling 64 | _C.MODEL.POOLING.P = 3 65 | # Hidden reconstruct dim for model 66 | _C.MODEL.HID_REC_DIM = 4096 67 | # No cLoth dim for model 68 | _C.MODEL.NO_CLOTHES_DIM = 2048 69 | # Contour dim for model 70 | _C.MODEL.CONTOUR_DIM = 1024 71 | # CLoth dim for model 72 | _C.MODEL.CLOTHES_DIM = 1024 73 | # ----------------------------------------------------------------------------- 74 | # Losses for training 75 | # ----------------------------------------------------------------------------- 76 | _C.LOSS = CN() 77 | # Classification loss 78 | _C.LOSS.CLA_LOSS = 'crossentropy' 79 | # Clothes classification loss 80 | _C.LOSS.CLOTHES_CLA_LOSS = 'cosface' 81 | # Scale for classification loss 82 | _C.LOSS.CLA_S = 16. 83 | # Margin for classification loss 84 | _C.LOSS.CLA_M = 0. 85 | # Pairwise loss 86 | _C.LOSS.PAIR_LOSS = 'triplet' 87 | # The weight for pairwise loss 88 | _C.LOSS.PAIR_LOSS_WEIGHT = 0.0 89 | # Scale for pairwise loss 90 | _C.LOSS.PAIR_S = 16. 91 | # Margin for pairwise loss 92 | _C.LOSS.PAIR_M = 0.3 93 | # Clothes-based adversarial loss 94 | _C.LOSS.CAL = 'cal' 95 | # Epsilon for clothes-based adversarial loss 96 | _C.LOSS.EPSILON = 0.1 97 | # Momentum for clothes-based adversarial loss with memory bank 98 | _C.LOSS.MOMENTUM = 0. 
99 | # ----------------------------------------------------------------------------- 100 | # Training settings 101 | # ----------------------------------------------------------------------------- 102 | _C.TRAIN = CN() 103 | _C.TRAIN.START_EPOCH = 0 104 | _C.TRAIN.MAX_EPOCH = 80 105 | # Start epoch for clothes classification 106 | _C.TRAIN.START_EPOCH_CC = 25 107 | # Start epoch for adversarial training 108 | _C.TRAIN.START_EPOCH_ADV = 25 109 | # Optimizer 110 | _C.TRAIN.OPTIMIZER = CN() 111 | _C.TRAIN.OPTIMIZER.NAME = 'adam' 112 | # Learning rate 113 | _C.TRAIN.OPTIMIZER.LR = 0.00035 114 | _C.TRAIN.OPTIMIZER.WEIGHT_DECAY = 5e-4 115 | # LR scheduler 116 | _C.TRAIN.LR_SCHEDULER = CN() 117 | # Stepsize to decay learning rate 118 | _C.TRAIN.LR_SCHEDULER.STEPSIZE = [20, 40, 60] 119 | # LR decay rate, used in StepLRScheduler 120 | _C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1 121 | # Using amp for training 122 | _C.TRAIN.AMP = False 123 | # ----------------------------------------------------------------------------- 124 | # Testing settings 125 | # ----------------------------------------------------------------------------- 126 | _C.TEST = CN() 127 | # Perform evaluation after every N epochs (set to -1 to test after training) 128 | _C.TEST.EVAL_STEP = 5 129 | # Start to evaluate after specific epoch 130 | _C.TEST.START_EVAL = 0 131 | # ----------------------------------------------------------------------------- 132 | # Infering settings 133 | # ----------------------------------------------------------------------------- 134 | _C.INFER = CN() 135 | # SHOW CC or GENRAL 136 | _C.INFER.SHOW_CC = True 137 | # ----------------------------------------------------------------------------- 138 | # Misc 139 | # ----------------------------------------------------------------------------- 140 | # Fixed random seed 141 | _C.SEED = 1 142 | # Perform evaluation only 143 | _C.EVAL_MODE = False 144 | # Perform inference only 145 | _C.INFER_MODE = False 146 | # GPU device ids for CUDA_VISIBLE_DEVICES 147 | _C.GPU = '0' 148 | # Path to output folder, overwritten by command line argument 149 | _C.OUTPUT = './logs/' 150 | # Tag of experiment, overwritten by command line argument 151 | _C.TAG = 'res50-ce-cal' 152 | 153 | 154 | def update_config(config, args): 155 | config.defrost() 156 | config.merge_from_file(args.cfg) 157 | 158 | # merge from specific arguments 159 | if args.root: 160 | config.DATA.ROOT = args.root 161 | if args.output: 162 | config.OUTPUT = args.output 163 | 164 | if args.resume: 165 | config.MODEL.RESUME = args.resume 166 | if args.eval: 167 | config.EVAL_MODE = True 168 | if args.infer: 169 | config.INFER_MODE = True 170 | 171 | if args.tag: 172 | config.TAG = args.tag 173 | 174 | if args.dataset: 175 | config.DATA.DATASET = args.dataset 176 | if args.gpu: 177 | config.GPU = args.gpu 178 | if args.amp: 179 | config.TRAIN.AMP = True 180 | 181 | config.PARA.SHUF_PID_RATIO = args.spr 182 | config.PARA.SHUF_ADV_CLO_RATIO = args.sacr 183 | config.PARA.RECON_RATIO = args.rr 184 | # output folder 185 | config.OUTPUT = os.path.join(config.OUTPUT, config.DATA.DATASET, config.TAG, str(config.PARA.SHUF_PID_RATIO)+'-'+str(config.PARA.SHUF_ADV_CLO_RATIO)+'-'+str(config.PARA.RECON_RATIO)) 186 | 187 | config.freeze() 188 | 189 | 190 | def get_img_config(args): 191 | """Get a yacs CfgNode object with default values.""" 192 | config = _C.clone() 193 | update_config(config, args) 194 | 195 | return config 196 | -------------------------------------------------------------------------------- 
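As a quick orientation for `get_img_config`, here is a minimal usage sketch showing how the config is typically built from command-line arguments. It is illustrative only: the real argument parser lives in `main.py`, and only the option names that appear in the README commands (`--dataset`, `--cfg`, `--gpu`, `--spr`, `--sacr`, `--rr`) plus the attributes read by `update_config` are taken from this repository; the defaults and help comments below are assumptions.

# Illustrative sketch (not part of the repository); flag names mirror the README commands
# and the attributes read by update_config(), but the actual CLI is defined in main.py.
import argparse
from configs.default_img import get_img_config   # assumes configs/ is importable from the project root

parser = argparse.ArgumentParser()
parser.add_argument('--cfg', type=str, required=True)        # e.g. configs/res50_cels_cal.yaml
parser.add_argument('--root', type=str, default='')          # overrides _C.DATA.ROOT
parser.add_argument('--output', type=str, default='')        # overrides _C.OUTPUT
parser.add_argument('--resume', type=str, default='')        # checkpoint path for _C.MODEL.RESUME
parser.add_argument('--eval', action='store_true')           # evaluation only
parser.add_argument('--infer', action='store_true')          # inference only
parser.add_argument('--tag', type=str, default='')           # experiment tag
parser.add_argument('--dataset', type=str, default='ltcc')   # ltcc / prcc / vcclothes / last / deepchange
parser.add_argument('--gpu', type=str, default='0,1')        # ids for CUDA_VISIBLE_DEVICES
parser.add_argument('--amp', action='store_true')            # mixed-precision training
parser.add_argument('--spr', type=float, default=1.0)        # PARA.SHUF_PID_RATIO
parser.add_argument('--sacr', type=float, default=0.05)      # PARA.SHUF_ADV_CLO_RATIO
parser.add_argument('--rr', type=float, default=1.0)         # PARA.RECON_RATIO
args = parser.parse_args()

config = get_img_config(args)   # clones _C, merges the YAML, applies CLI overrides, then freezes
print(config.OUTPUT)            # e.g. ./logs/ltcc/res50-cels-cal/1.0-0.05-1.0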
/configs/default_vid.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | from yacs.config import CfgNode as CN 4 | 5 | 6 | _C = CN() 7 | # ----------------------------------------------------------------------------- 8 | # Data settings 9 | # ----------------------------------------------------------------------------- 10 | _C.DATA = CN() 11 | # Root path for dataset directory 12 | _C.DATA.ROOT = '/data1/cuizhenyu/Simple-CCReID-main/datasets/' 13 | # Dataset for evaluation 14 | _C.DATA.DATASET = 'ccvid' 15 | # Whether split each full-length video in the training set into some clips 16 | _C.DATA.DENSE_SAMPLING = True 17 | # Sampling step of dense sampling for training set 18 | _C.DATA.SAMPLING_STEP = 64 19 | # Workers for dataloader 20 | _C.DATA.NUM_WORKERS = 4 21 | # Height of input image 22 | _C.DATA.HEIGHT = 256 23 | # Width of input image 24 | _C.DATA.WIDTH = 128 25 | # Batch size for training 26 | _C.DATA.TRAIN_BATCH = 8 # 16 27 | # Batch size for testing 28 | _C.DATA.TEST_BATCH = 64 # 128 29 | # The number of instances per identity for training sampler 30 | _C.DATA.NUM_INSTANCES = 4 31 | # ----------------------------------------------------------------------------- 32 | # Augmentation settings 33 | # ----------------------------------------------------------------------------- 34 | _C.AUG = CN() 35 | # Random erase prob 36 | _C.AUG.RE_PROB = 0.0 37 | # Temporal sampling mode for training, 'tsn' or 'stride' 38 | _C.AUG.TEMPORAL_SAMPLING_MODE = 'stride' 39 | # Sequence length of each input video clip 40 | _C.AUG.SEQ_LEN = 8 41 | # Sampling stride of each input video clip 42 | _C.AUG.SAMPLING_STRIDE = 4 43 | # ----------------------------------------------------------------------------- 44 | # Parameters settings 45 | # ----------------------------------------------------------------------------- 46 | _C.PARA = CN() 47 | # Random crop prob 48 | _C.PARA.SHUF_PID_RATIO = 1.0 49 | # Random erase prob 50 | _C.PARA.SHUF_ADV_CLO_RATIO = 0.05 51 | # Random flip prob 52 | _C.PARA.RECON_RATIO = 1.0 53 | # ----------------------------------------------------------------------------- 54 | # Model settings 55 | # ----------------------------------------------------------------------------- 56 | _C.MODEL = CN() 57 | # Model name. All supported model can be seen in models/__init__.py 58 | _C.MODEL.NAME = 'c2dres50' 59 | # The stride for laery4 in resnet 60 | _C.MODEL.RES4_STRIDE = 1 61 | # feature dim 62 | _C.MODEL.FEATURE_DIM = 2048 63 | # Model path for resuming 64 | _C.MODEL.RESUME = '' 65 | # Params for AP3D 66 | _C.MODEL.AP3D = CN() 67 | # Temperature for APM 68 | _C.MODEL.AP3D.TEMPERATURE = 4 69 | # Contrastive attention 70 | _C.MODEL.AP3D.CONTRACTIVE_ATT = True 71 | # Hidden reconstruct dim for model 72 | _C.MODEL.HID_REC_DIM = 2048 73 | # No cLoth dim for model 74 | _C.MODEL.NO_CLOTHES_DIM = 1024 75 | # Contour dim for model 76 | _C.MODEL.CONTOUR_DIM = 512 77 | # CLoth dim for model 78 | _C.MODEL.CLOTHES_DIM = 512 79 | # ----------------------------------------------------------------------------- 80 | # Losses for training 81 | # ----------------------------------------------------------------------------- 82 | _C.LOSS = CN() 83 | # Classification loss 84 | _C.LOSS.CLA_LOSS = 'crossentropy' 85 | # Clothes classification loss 86 | _C.LOSS.CLOTHES_CLA_LOSS = 'cosface' 87 | # Scale for classification loss 88 | _C.LOSS.CLA_S = 16. 89 | # Margin for classification loss 90 | _C.LOSS.CLA_M = 0. 
91 | # Pairwise loss 92 | _C.LOSS.PAIR_LOSS = 'triplet' 93 | # The weight for pairwise loss 94 | _C.LOSS.PAIR_LOSS_WEIGHT = 0.0 95 | # Scale for pairwise loss 96 | _C.LOSS.PAIR_S = 16. 97 | # Margin for pairwise loss 98 | _C.LOSS.PAIR_M = 0.3 99 | # Clothes-based adversarial loss 100 | _C.LOSS.CAL = 'cal' 101 | # Epsilon for clothes-based adversarial loss 102 | _C.LOSS.EPSILON = 0.1 103 | # Momentum for clothes-based adversarial loss with memory bank 104 | _C.LOSS.MOMENTUM = 0. 105 | # ----------------------------------------------------------------------------- 106 | # Training settings 107 | # ----------------------------------------------------------------------------- 108 | _C.TRAIN = CN() 109 | _C.TRAIN.START_EPOCH = 0 110 | _C.TRAIN.MAX_EPOCH = 150 111 | # Start epoch for clothes classification 112 | _C.TRAIN.START_EPOCH_CC = 50 113 | # Start epoch for adversarial training 114 | _C.TRAIN.START_EPOCH_ADV = 50 115 | # Optimizer 116 | _C.TRAIN.OPTIMIZER = CN() 117 | _C.TRAIN.OPTIMIZER.NAME = 'adam' 118 | # Learning rate 119 | _C.TRAIN.OPTIMIZER.LR = 0.00035 120 | _C.TRAIN.OPTIMIZER.WEIGHT_DECAY = 5e-4 121 | # LR scheduler 122 | _C.TRAIN.LR_SCHEDULER = CN() 123 | # Stepsize to decay learning rate 124 | _C.TRAIN.LR_SCHEDULER.STEPSIZE = [40, 80, 120] 125 | # LR decay rate, used in StepLRScheduler 126 | _C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1 127 | # Using amp for training 128 | _C.TRAIN.AMP = False 129 | # ----------------------------------------------------------------------------- 130 | # Testing settings 131 | # ----------------------------------------------------------------------------- 132 | _C.TEST = CN() 133 | # Perform evaluation after every N epochs (set to -1 to test after training) 134 | _C.TEST.EVAL_STEP = 10 135 | # Start to evaluate after specific epoch 136 | _C.TEST.START_EVAL = 0 137 | # ----------------------------------------------------------------------------- 138 | # Infering settings 139 | # ----------------------------------------------------------------------------- 140 | _C.INFER = CN() 141 | # SHOW CC or GENRAL 142 | _C.INFER.SHOW_CC = True 143 | # ----------------------------------------------------------------------------- 144 | # Misc 145 | # ----------------------------------------------------------------------------- 146 | # Fixed random seed 147 | _C.SEED = 1 148 | # Perform evaluation only 149 | _C.EVAL_MODE = False 150 | # Perform inference only 151 | _C.INFER_MODE = False 152 | # GPU device ids for CUDA_VISIBLE_DEVICES 153 | _C.GPU = '0, 1' 154 | # Path to output folder, overwritten by command line argument 155 | _C.OUTPUT = './logs/' 156 | # Tag of experiment, overwritten by command line argument 157 | _C.TAG = 'res50-ce-cal' 158 | 159 | 160 | def update_config(config, args): 161 | config.defrost() 162 | config.merge_from_file(args.cfg) 163 | 164 | # merge from specific arguments 165 | if args.root: 166 | config.DATA.ROOT = args.root 167 | if args.output: 168 | config.OUTPUT = args.output 169 | 170 | if args.resume: 171 | config.MODEL.RESUME = args.resume 172 | if args.eval: 173 | config.EVAL_MODE = True 174 | if args.infer: 175 | config.INFER_MODE = True 176 | 177 | if args.tag: 178 | config.TAG = args.tag 179 | 180 | if args.dataset: 181 | config.DATA.DATASET = args.dataset 182 | if args.gpu: 183 | config.GPU = args.gpu 184 | if args.amp: 185 | config.TRAIN.AMP = True 186 | 187 | config.PARA.SHUF_PID_RATIO = args.spr 188 | config.PARA.SHUF_ADV_CLO_RATIO = args.sacr 189 | config.PARA.RECON_RATIO = args.rr 190 | 191 | # output folder 192 | config.OUTPUT 
= os.path.join(config.OUTPUT, config.DATA.DATASET, config.TAG, str(config.PARA.SHUF_PID_RATIO)+'-'+str(config.PARA.SHUF_ADV_CLO_RATIO)+'-'+str(config.PARA.RECON_RATIO)) 193 | 194 | config.freeze() 195 | 196 | 197 | def get_vid_config(args): 198 | """Get a yacs CfgNode object with default values.""" 199 | config = _C.clone() 200 | update_config(config, args) 201 | 202 | return config 203 | -------------------------------------------------------------------------------- /configs/res50_cels_cal.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | NAME: resnet50 3 | LOSS: 4 | CLA_LOSS: crossentropylabelsmooth 5 | CAL: cal 6 | TAG: res50-cels-cal -------------------------------------------------------------------------------- /configs/res50_cels_cal_16x4.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | NAME: resnet50 3 | DATA: 4 | NUM_INSTANCES: 4 5 | TRAIN_BATCH: 32 6 | LOSS: 7 | CLA_LOSS: crossentropylabelsmooth 8 | CAL: cal 9 | TAG: res50-cels-cal-16x4 -------------------------------------------------------------------------------- /configs/res50_cels_cal_tri_16x4.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | NAME: resnet50 3 | DATA: 4 | NUM_INSTANCES: 4 5 | TRAIN_BATCH: 32 6 | LOSS: 7 | CLA_LOSS: crossentropylabelsmooth 8 | PAIR_LOSS: triplet 9 | CAL: cal 10 | PAIR_M: 0.3 11 | PAIR_LOSS_WEIGHT: 1.0 12 | TAG: res50-cels-cal-tri-16x4 -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | import data.img_transforms as T 2 | import data.img_transforms as T 3 | import data.spatial_transforms as ST 4 | import data.temporal_transforms as TT 5 | from torch.utils.data import DataLoader 6 | from data.dataloader import DataLoaderX 7 | from data.dataset_loader import ImageDataset, VideoDataset, ImageDataset_Train 8 | from data.samplers import DistributedRandomIdentitySampler, DistributedInferenceSampler 9 | from data.datasets.ltcc import LTCC 10 | from data.datasets.prcc import PRCC 11 | from data.datasets.last import LaST 12 | from data.datasets.ccvid import CCVID 13 | from data.datasets.deepchange import DeepChange 14 | from data.datasets.vcclothes import VCClothes, VCClothesSameClothes, VCClothesClothesChanging 15 | 16 | 17 | __factory = { 18 | 'ltcc': LTCC, 19 | 'prcc': PRCC, 20 | 'vcclothes': VCClothes, 21 | 'vcclothes_sc': VCClothesSameClothes, 22 | 'vcclothes_cc': VCClothesClothesChanging, 23 | 'last': LaST, 24 | 'ccvid': CCVID, 25 | 'deepchange': DeepChange, 26 | } 27 | 28 | VID_DATASET = ['ccvid'] 29 | 30 | 31 | def get_names(): 32 | return list(__factory.keys()) 33 | 34 | 35 | def build_dataset(config): 36 | if config.DATA.DATASET not in __factory.keys(): 37 | raise KeyError("Invalid dataset, got '{}', but expected to be one of {}".format(name, __factory.keys())) 38 | 39 | if config.DATA.DATASET in VID_DATASET: 40 | dataset = __factory[config.DATA.DATASET](root=config.DATA.ROOT, 41 | sampling_step=config.DATA.SAMPLING_STEP, 42 | seq_len=config.AUG.SEQ_LEN, 43 | stride=config.AUG.SAMPLING_STRIDE) 44 | else: 45 | dataset = __factory[config.DATA.DATASET](root=config.DATA.ROOT) 46 | 47 | return dataset 48 | 49 | 50 | def build_img_transforms(config): 51 | transform_train = T.Compose([ 52 | T.Resize((config.DATA.HEIGHT, config.DATA.WIDTH)), 53 | T.RandomCroping(p=config.AUG.RC_PROB), 54 | 
T.RandomHorizontalFlip(p=config.AUG.RF_PROB), 55 | T.ToTensor(), 56 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 57 | T.RandomErasing(probability=config.AUG.RE_PROB) 58 | ]) 59 | transform_test = T.Compose([ 60 | T.Resize((config.DATA.HEIGHT, config.DATA.WIDTH)), 61 | T.ToTensor(), 62 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 63 | ]) 64 | 65 | return transform_train, transform_test 66 | 67 | 68 | def build_vid_transforms(config): 69 | spatial_transform_train = ST.Compose([ 70 | ST.Scale((config.DATA.HEIGHT, config.DATA.WIDTH), interpolation=3), 71 | ST.RandomHorizontalFlip(), 72 | ST.ToTensor(), 73 | ST.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 74 | ST.RandomErasing(height=config.DATA.HEIGHT, width=config.DATA.WIDTH, probability=config.AUG.RE_PROB) 75 | ]) 76 | spatial_transform_test = ST.Compose([ 77 | ST.Scale((config.DATA.HEIGHT, config.DATA.WIDTH), interpolation=3), 78 | ST.ToTensor(), 79 | ST.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 80 | ]) 81 | 82 | if config.AUG.TEMPORAL_SAMPLING_MODE == 'tsn': 83 | temporal_transform_train = TT.TemporalDivisionCrop(size=config.AUG.SEQ_LEN) 84 | elif config.AUG.TEMPORAL_SAMPLING_MODE == 'stride': 85 | temporal_transform_train = TT.TemporalRandomCrop(size=config.AUG.SEQ_LEN, 86 | stride=config.AUG.SAMPLING_STRIDE) 87 | else: 88 | raise KeyError("Invalid temporal sempling mode '{}'".format(config.AUG.TEMPORAL_SAMPLING_MODE)) 89 | 90 | temporal_transform_test = None 91 | 92 | return spatial_transform_train, spatial_transform_test, temporal_transform_train, temporal_transform_test 93 | 94 | 95 | def build_dataloader(config): 96 | dataset = build_dataset(config) 97 | # video dataset 98 | if config.DATA.DATASET in VID_DATASET: 99 | spatial_transform_train, spatial_transform_test, temporal_transform_train, temporal_transform_test = build_vid_transforms(config) 100 | 101 | if config.DATA.DENSE_SAMPLING: 102 | train_sampler = DistributedRandomIdentitySampler(dataset.train_dense, 103 | num_instances=config.DATA.NUM_INSTANCES, 104 | seed=config.SEED) 105 | # split each original training video into a series of short videos and sample one clip for each short video during training 106 | trainloader = DataLoaderX( 107 | dataset=VideoDataset(dataset.train_dense, spatial_transform_train, temporal_transform_train), 108 | sampler=train_sampler, 109 | batch_size=config.DATA.TRAIN_BATCH, num_workers=config.DATA.NUM_WORKERS, 110 | pin_memory=True, drop_last=True) 111 | else: 112 | train_sampler = DistributedRandomIdentitySampler(dataset.train, 113 | num_instances=config.DATA.NUM_INSTANCES, 114 | seed=config.SEED) 115 | # sample one clip for each original training video during training 116 | trainloader = DataLoaderX( 117 | dataset=VideoDataset(dataset.train, spatial_transform_train, temporal_transform_train), 118 | sampler=train_sampler, 119 | batch_size=config.DATA.TRAIN_BATCH, num_workers=config.DATA.NUM_WORKERS, 120 | pin_memory=True, drop_last=True) 121 | 122 | # split each original test video into a series of clips and use the averaged feature of all clips as its representation 123 | queryloader = DataLoaderX( 124 | dataset=VideoDataset(dataset.recombined_query, spatial_transform_test, temporal_transform_test), 125 | sampler=DistributedInferenceSampler(dataset.recombined_query), 126 | batch_size=config.DATA.TEST_BATCH, num_workers=config.DATA.NUM_WORKERS, 127 | pin_memory=True, drop_last=False, shuffle=False) 128 | galleryloader = DataLoaderX( 129 | 
dataset=VideoDataset(dataset.recombined_gallery, spatial_transform_test, temporal_transform_test), 130 | sampler=DistributedInferenceSampler(dataset.recombined_gallery), 131 | batch_size=config.DATA.TEST_BATCH, num_workers=config.DATA.NUM_WORKERS, 132 | pin_memory=True, drop_last=False, shuffle=False) 133 | 134 | return trainloader, queryloader, galleryloader, dataset, train_sampler 135 | # image dataset 136 | else: 137 | transform_train, transform_test = build_img_transforms(config) 138 | train_sampler = DistributedRandomIdentitySampler(dataset.train, 139 | num_instances=config.DATA.NUM_INSTANCES, 140 | seed=config.SEED) 141 | trainloader = DataLoaderX(dataset=ImageDataset_Train(dataset.train, transform=transform_train), 142 | sampler=train_sampler, 143 | batch_size=config.DATA.TRAIN_BATCH, num_workers=config.DATA.NUM_WORKERS, 144 | pin_memory=True, drop_last=True) 145 | 146 | galleryloader = DataLoaderX(dataset=ImageDataset(dataset.gallery, transform=transform_test), 147 | sampler=DistributedInferenceSampler(dataset.gallery), 148 | batch_size=config.DATA.TEST_BATCH, num_workers=config.DATA.NUM_WORKERS, 149 | pin_memory=True, drop_last=False, shuffle=False) 150 | 151 | if config.DATA.DATASET == 'prcc': 152 | queryloader_same = DataLoaderX(dataset=ImageDataset(dataset.query_same, transform=transform_test), 153 | sampler=DistributedInferenceSampler(dataset.query_same), 154 | batch_size=config.DATA.TEST_BATCH, num_workers=config.DATA.NUM_WORKERS, 155 | pin_memory=True, drop_last=False, shuffle=False) 156 | queryloader_diff = DataLoaderX(dataset=ImageDataset(dataset.query_diff, transform=transform_test), 157 | sampler=DistributedInferenceSampler(dataset.query_diff), 158 | batch_size=config.DATA.TEST_BATCH, num_workers=config.DATA.NUM_WORKERS, 159 | pin_memory=True, drop_last=False, shuffle=False) 160 | 161 | return trainloader, queryloader_same, queryloader_diff, galleryloader, dataset, train_sampler 162 | else: 163 | queryloader = DataLoaderX(dataset=ImageDataset(dataset.query, transform=transform_test), 164 | sampler=DistributedInferenceSampler(dataset.query), 165 | batch_size=config.DATA.TEST_BATCH, num_workers=config.DATA.NUM_WORKERS, 166 | pin_memory=True, drop_last=False, shuffle=False) 167 | 168 | return trainloader, queryloader, galleryloader, dataset, train_sampler 169 | 170 | 171 | 172 | 173 | -------------------------------------------------------------------------------- /data/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ICST-MIPL/DCR-ReID_TCSVT2023/b6c58c48a97feb1c2b3f062b211c50e17bac274e/data/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /data/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ICST-MIPL/DCR-ReID_TCSVT2023/b6c58c48a97feb1c2b3f062b211c50e17bac274e/data/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /data/__pycache__/dataloader.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ICST-MIPL/DCR-ReID_TCSVT2023/b6c58c48a97feb1c2b3f062b211c50e17bac274e/data/__pycache__/dataloader.cpython-36.pyc -------------------------------------------------------------------------------- /data/__pycache__/dataloader.cpython-37.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ICST-MIPL/DCR-ReID_TCSVT2023/b6c58c48a97feb1c2b3f062b211c50e17bac274e/data/__pycache__/dataloader.cpython-37.pyc -------------------------------------------------------------------------------- /data/__pycache__/dataset_loader.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ICST-MIPL/DCR-ReID_TCSVT2023/b6c58c48a97feb1c2b3f062b211c50e17bac274e/data/__pycache__/dataset_loader.cpython-36.pyc -------------------------------------------------------------------------------- /data/__pycache__/dataset_loader.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ICST-MIPL/DCR-ReID_TCSVT2023/b6c58c48a97feb1c2b3f062b211c50e17bac274e/data/__pycache__/dataset_loader.cpython-37.pyc -------------------------------------------------------------------------------- /data/__pycache__/dataset_loader3.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ICST-MIPL/DCR-ReID_TCSVT2023/b6c58c48a97feb1c2b3f062b211c50e17bac274e/data/__pycache__/dataset_loader3.cpython-36.pyc -------------------------------------------------------------------------------- /data/__pycache__/img_transforms.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ICST-MIPL/DCR-ReID_TCSVT2023/b6c58c48a97feb1c2b3f062b211c50e17bac274e/data/__pycache__/img_transforms.cpython-36.pyc -------------------------------------------------------------------------------- /data/__pycache__/img_transforms.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ICST-MIPL/DCR-ReID_TCSVT2023/b6c58c48a97feb1c2b3f062b211c50e17bac274e/data/__pycache__/img_transforms.cpython-37.pyc -------------------------------------------------------------------------------- /data/__pycache__/img_transforms3.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ICST-MIPL/DCR-ReID_TCSVT2023/b6c58c48a97feb1c2b3f062b211c50e17bac274e/data/__pycache__/img_transforms3.cpython-36.pyc -------------------------------------------------------------------------------- /data/__pycache__/samplers.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ICST-MIPL/DCR-ReID_TCSVT2023/b6c58c48a97feb1c2b3f062b211c50e17bac274e/data/__pycache__/samplers.cpython-36.pyc -------------------------------------------------------------------------------- /data/__pycache__/samplers.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ICST-MIPL/DCR-ReID_TCSVT2023/b6c58c48a97feb1c2b3f062b211c50e17bac274e/data/__pycache__/samplers.cpython-37.pyc -------------------------------------------------------------------------------- /data/__pycache__/spatial_transforms.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ICST-MIPL/DCR-ReID_TCSVT2023/b6c58c48a97feb1c2b3f062b211c50e17bac274e/data/__pycache__/spatial_transforms.cpython-36.pyc 
-------------------------------------------------------------------------------- /data/__pycache__/spatial_transforms.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ICST-MIPL/DCR-ReID_TCSVT2023/b6c58c48a97feb1c2b3f062b211c50e17bac274e/data/__pycache__/spatial_transforms.cpython-37.pyc -------------------------------------------------------------------------------- /data/__pycache__/temporal_transforms.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ICST-MIPL/DCR-ReID_TCSVT2023/b6c58c48a97feb1c2b3f062b211c50e17bac274e/data/__pycache__/temporal_transforms.cpython-36.pyc -------------------------------------------------------------------------------- /data/__pycache__/temporal_transforms.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ICST-MIPL/DCR-ReID_TCSVT2023/b6c58c48a97feb1c2b3f062b211c50e17bac274e/data/__pycache__/temporal_transforms.cpython-37.pyc -------------------------------------------------------------------------------- /data/dataloader.py: -------------------------------------------------------------------------------- 1 | # refer to: https://github.com/JDAI-CV/fast-reid/blob/master/fastreid/data/data_utils.py 2 | 3 | import torch 4 | import threading 5 | import queue 6 | from torch.utils.data import DataLoader 7 | from torch import distributed as dist 8 | 9 | 10 | """ 11 | #based on http://stackoverflow.com/questions/7323664/python-generator-pre-fetch 12 | This is a single-function package that transforms arbitrary generator into a background-thead generator that 13 | prefetches several batches of data in a parallel background thead. 14 | 15 | This is useful if you have a computationally heavy process (CPU or GPU) that 16 | iteratively processes minibatches from the generator while the generator 17 | consumes some other resource (disk IO / loading from database / more CPU if you have unused cores). 18 | 19 | By default these two processes will constantly wait for one another to finish. If you make generator work in 20 | prefetch mode (see examples below), they will work in parallel, potentially saving you your GPU time. 21 | We personally use the prefetch generator when iterating minibatches of data for deep learning with PyTorch etc. 22 | 23 | Quick usage example (ipython notebook) - https://github.com/justheuristic/prefetch_generator/blob/master/example.ipynb 24 | This package contains this object 25 | - BackgroundGenerator(any_other_generator[,max_prefetch = something]) 26 | """ 27 | 28 | 29 | class BackgroundGenerator(threading.Thread): 30 | """ 31 | the usage is below 32 | >> for batch in BackgroundGenerator(my_minibatch_iterator): 33 | >> doit() 34 | More details are written in the BackgroundGenerator doc 35 | >> help(BackgroundGenerator) 36 | """ 37 | 38 | def __init__(self, generator, local_rank, max_prefetch=10): 39 | """ 40 | This function transforms generator into a background-thead generator. 41 | :param generator: generator or genexp or any 42 | It can be used with any minibatch generator. 43 | 44 | It is quite lightweight, but not entirely weightless. 45 | Using global variables inside generator is not recommended (may raise GIL and zero-out the 46 | benefit of having a background thread.) 
47 | The ideal use case is when everything it requires is store inside it and everything it 48 | outputs is passed through queue. 49 | 50 | There's no restriction on doing weird stuff, reading/writing files, retrieving 51 | URLs [or whatever] wlilst iterating. 52 | 53 | :param max_prefetch: defines, how many iterations (at most) can background generator keep 54 | stored at any moment of time. 55 | Whenever there's already max_prefetch batches stored in queue, the background process will halt until 56 | one of these batches is dequeued. 57 | 58 | !Default max_prefetch=1 is okay unless you deal with some weird file IO in your generator! 59 | 60 | Setting max_prefetch to -1 lets it store as many batches as it can, which will work 61 | slightly (if any) faster, but will require storing 62 | all batches in memory. If you use infinite generator with max_prefetch=-1, it will exceed the RAM size 63 | unless dequeued quickly enough. 64 | """ 65 | super().__init__() 66 | self.queue = queue.Queue(max_prefetch) 67 | self.generator = generator 68 | self.local_rank = local_rank 69 | self.daemon = True 70 | self.exit_event = threading.Event() 71 | self.start() 72 | 73 | def run(self): 74 | torch.cuda.set_device(self.local_rank) 75 | for item in self.generator: 76 | if self.exit_event.is_set(): 77 | break 78 | self.queue.put(item) 79 | self.queue.put(None) 80 | 81 | def next(self): 82 | next_item = self.queue.get() 83 | if next_item is None: 84 | raise StopIteration 85 | return next_item 86 | 87 | # Python 3 compatibility 88 | def __next__(self): 89 | return self.next() 90 | 91 | def __iter__(self): 92 | return self 93 | 94 | 95 | class DataLoaderX(DataLoader): 96 | def __init__(self, **kwargs): 97 | super().__init__(**kwargs) 98 | local_rank = dist.get_rank() 99 | self.stream = torch.cuda.Stream(local_rank) # create a new cuda stream in each process 100 | self.local_rank = local_rank 101 | 102 | def __iter__(self): 103 | self.iter = super().__iter__() 104 | self.iter = BackgroundGenerator(self.iter, self.local_rank) 105 | self.preload() 106 | return self 107 | 108 | def _shutdown_background_thread(self): 109 | if not self.iter.is_alive(): 110 | # avoid re-entrance or ill-conditioned thread state 111 | return 112 | 113 | # Set exit event to True for background threading stopping 114 | self.iter.exit_event.set() 115 | 116 | # Exhaust all remaining elements, so that the queue becomes empty, 117 | # and the thread should quit 118 | for _ in self.iter: 119 | pass 120 | 121 | # Waiting for background thread to quit 122 | self.iter.join() 123 | 124 | def preload(self): 125 | self.batch = next(self.iter, None) 126 | if self.batch is None: 127 | return None 128 | with torch.cuda.stream(self.stream): 129 | # if isinstance(self.batch[0], torch.Tensor): 130 | # self.batch[0] = self.batch[0].to(device=self.local_rank, non_blocking=True) 131 | for k, v in enumerate(self.batch): 132 | if isinstance(self.batch[k], torch.Tensor): 133 | self.batch[k] = self.batch[k].to(device=self.local_rank, non_blocking=True) 134 | 135 | def __next__(self): 136 | torch.cuda.current_stream().wait_stream( 137 | self.stream 138 | ) # wait tensor to put on GPU 139 | batch = self.batch 140 | if batch is None: 141 | raise StopIteration 142 | self.preload() 143 | return batch 144 | 145 | # Signal for shutting down background thread 146 | def shutdown(self): 147 | # If the dataloader is to be freed, shutdown its BackgroundGenerator 148 | self._shutdown_background_thread() 149 | 
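Taken together, `BackgroundGenerator` and `DataLoaderX` act as a drop-in replacement for `torch.utils.data.DataLoader` that prefetches batches in a background thread and copies tensors to the local GPU on a side CUDA stream. Because the constructor calls `dist.get_rank()` and creates a CUDA stream, it only works after `torch.distributed` has been initialized and with a visible GPU, and all arguments must be passed as keywords (only `**kwargs` is forwarded). A minimal sketch under those assumptions; the single-process group below is for illustration only, since the project normally runs under `torch.distributed.launch`:

import torch
import torch.distributed as dist
from torch.utils.data import TensorDataset
from data.dataloader import DataLoaderX

# One-process group purely for illustration; main.py is normally started via torch.distributed.launch.
dist.init_process_group(backend='nccl', init_method='tcp://127.0.0.1:23456', rank=0, world_size=1)
torch.cuda.set_device(0)

dataset = TensorDataset(torch.randn(64, 3, 384, 192), torch.arange(64))  # dummy images and labels
loader = DataLoaderX(dataset=dataset, batch_size=16, num_workers=2,
                     pin_memory=True, drop_last=True, shuffle=True)

for imgs, labels in loader:            # each batch is already moved to cuda:0 by preload()
    assert imgs.is_cuda and labels.is_cuda
loader.shutdown()                      # stop the background prefetching thread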
-------------------------------------------------------------------------------- /data/dataset_loader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import functools 3 | import os.path as osp 4 | from PIL import Image 5 | from torch.utils.data import Dataset 6 | import numpy as np 7 | 8 | 9 | def read_image(img_path): 10 | """Keep reading image until succeed. 11 | This can avoid IOError incurred by heavy IO process.""" 12 | got_img = False 13 | if not osp.exists(img_path): 14 | raise IOError("{} does not exist".format(img_path)) 15 | while not got_img: 16 | try: 17 | img = Image.open(img_path).convert('RGB') 18 | got_img = True 19 | except IOError: 20 | print("IOError incurred when reading '{}'. Will redo. Don't worry. Just chill.".format(img_path)) 21 | pass 22 | return img 23 | 24 | def read_image_gray(img_path): 25 | """Keep reading image until succeed. 26 | This can avoid IOError incurred by heavy IO process.""" 27 | got_img = False 28 | if not osp.exists(img_path): 29 | raise IOError("{} does not exist".format(img_path)) 30 | while not got_img: 31 | try: 32 | img = Image.open(img_path).convert('L') 33 | got_img = True 34 | except IOError: 35 | print("IOError incurred when reading '{}'. Will redo. Don't worry. Just chill.".format(img_path)) 36 | pass 37 | return img 38 | 39 | class ImageDataset(Dataset): 40 | """Image Person ReID Dataset""" 41 | def __init__(self, dataset, transform=None): 42 | self.dataset = dataset 43 | self.transform = transform 44 | 45 | def __len__(self): 46 | return len(self.dataset) 47 | 48 | def __getitem__(self, index): 49 | img_path, pid, camid, clothes_id = self.dataset[index] 50 | img = read_image(img_path) 51 | 52 | if osp.exists(img_path.replace('/query/', '/query_cloth/')): 53 | cloth = read_image_gray(img_path.replace('/query/', '/query_cloth/'))######################## 54 | else: 55 | cloth = read_image_gray(img_path) 56 | 57 | if osp.exists(img_path.replace('/query/', '/query_cloth_/')): 58 | _cloth = read_image_gray(img_path.replace('/query/', '/query_cloth_/'))####################### 59 | else: 60 | _cloth = read_image_gray(img_path) 61 | 62 | if osp.exists(img_path.replace('/query/', '/query_contour/')): 63 | contour = read_image_gray(img_path.replace('/query/', '/query_contour/')) 64 | else: 65 | contour = read_image_gray(img_path) 66 | 67 | if self.transform is not None: 68 | img, cloth, _cloth, contour = self.transform(img, cloth, _cloth, contour) 69 | return img, pid, camid, clothes_id, cloth, _cloth, contour#, img_path 70 | 71 | class ImageDataset_Train(Dataset): 72 | """Image Person ReID Dataset""" 73 | def __init__(self, dataset, transform=None): 74 | self.dataset = dataset 75 | self.transform = transform 76 | 77 | def __len__(self): 78 | return len(self.dataset) 79 | 80 | def __getitem__(self, index): 81 | img_path, pid, camid, clothes_id, cloth_path, _cloth_path, contour_path = self.dataset[index] 82 | img = read_image(img_path) 83 | cloth = read_image_gray(cloth_path)################################## 84 | _cloth = read_image_gray(_cloth_path)#################################### 85 | contour = read_image_gray(contour_path) 86 | # print (np.max(cloth), np.max(_cloth), np.max(contour)) 87 | if self.transform is not None: 88 | img, cloth, _cloth, contour = self.transform(img, cloth, _cloth, contour) 89 | return img, pid, camid, clothes_id, cloth, _cloth, contour 90 | 91 | 92 | def pil_loader(path): 93 | # open path as file to avoid ResourceWarning 
(https://github.com/python-pillow/Pillow/issues/835) 94 | with open(path, 'rb') as f: 95 | with Image.open(f) as img: 96 | return img.convert('RGB') 97 | 98 | def gray_pil_loader(gray_path): 99 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 100 | with open(gray_path, 'rb') as gray_f: 101 | with Image.open(gray_f) as gray_img: 102 | return gray_img.convert('L') 103 | 104 | def accimage_loader(path): 105 | try: 106 | import accimage 107 | return accimage.Image(path) 108 | except IOError: 109 | # Potentially a decoding problem, fall back to PIL.Image 110 | return pil_loader(path) 111 | 112 | 113 | def get_default_image_loader(): 114 | from torchvision import get_image_backend 115 | if get_image_backend() == 'accimage': 116 | return accimage_loader 117 | else: 118 | return pil_loader 119 | 120 | def get_gray_image_loader(): 121 | return gray_pil_loader 122 | 123 | 124 | def image_loader(path): 125 | from torchvision import get_image_backend 126 | if get_image_backend() == 'accimage': 127 | return accimage_loader(path) 128 | else: 129 | return pil_loader(path) 130 | 131 | def gray_image_loader(gray_path): 132 | return gray_pil_loader(gray_path) 133 | 134 | 135 | def video_loader(img_paths, image_loader): 136 | video = [] 137 | for image_path in img_paths: 138 | if osp.exists(image_path): 139 | video.append(image_loader(image_path)) 140 | else: 141 | return video 142 | 143 | return video 144 | 145 | def gray_video_loader(gray_img_paths, gray_image_loader): 146 | gray_video = [] 147 | for gray_image_path in gray_img_paths: 148 | if osp.exists(gray_image_path): 149 | gray_video.append(gray_image_loader(gray_image_path)) 150 | else: 151 | return gray_video 152 | 153 | return gray_video 154 | 155 | 156 | def get_default_video_loader(): 157 | image_loader = get_default_image_loader() 158 | return functools.partial(video_loader, image_loader=image_loader) 159 | 160 | def get_gray_video_loader(): 161 | gray_image_loader = get_gray_image_loader() 162 | return functools.partial(gray_video_loader, gray_image_loader=gray_image_loader) 163 | 164 | 165 | class VideoDataset(Dataset): 166 | """Video Person ReID Dataset. 167 | Note: 168 | Batch data has shape N x C x T x H x W 169 | Args: 170 | dataset (list): List with items (img_paths, pid, camid) 171 | temporal_transform (callable, optional): A function/transform that takes in a list of frame indices 172 | and returns a transformed version 173 | target_transform (callable, optional): A function/transform that takes in the 174 | target and transforms it. 175 | loader (callable, optional): A function to load an video given its path and frame indices. 176 | """ 177 | 178 | def __init__(self, 179 | dataset, 180 | spatial_transform=None, 181 | temporal_transform=None, 182 | get_loader=get_default_video_loader, 183 | cloth_changing=True): 184 | self.dataset = dataset 185 | self.spatial_transform = spatial_transform 186 | self.temporal_transform = temporal_transform 187 | self.loader = get_loader() 188 | self.gray_loader = get_gray_video_loader() 189 | self.cloth_changing = cloth_changing 190 | 191 | def __len__(self): 192 | return len(self.dataset) 193 | 194 | def __getitem__(self, index): 195 | """ 196 | Args: 197 | index (int): Index 198 | 199 | Returns: 200 | tuple: (clip, pid, camid) where pid is identity of the clip. 
201 | """ 202 | if self.cloth_changing: 203 | img_paths, pid, camid, clothes_id = self.dataset[index] 204 | else: 205 | img_paths, pid, camid = self.dataset[index] 206 | 207 | if self.temporal_transform is not None: 208 | img_paths = self.temporal_transform(img_paths) 209 | 210 | # get the clothes images 211 | path_clip_cloth = [each.replace('session', 'cloth_session') for each in img_paths] 212 | # get the non-clothes images 213 | path_clip_cloth_ = [each.replace('session', '_cloth_session') for each in img_paths] 214 | # get the contour images 215 | path_clip_contour = [each.replace('session', 'contour_session') for each in img_paths] 216 | 217 | clip = self.loader(img_paths)##################### 218 | clip_cloth = self.gray_loader(path_clip_cloth) 219 | clip_cloth_ = self.gray_loader(path_clip_cloth_) 220 | clip_contour = self.gray_loader(path_clip_contour) 221 | 222 | # print (len(clip), ",", clip[0], "|", len(clip_cloth), ",", clip_cloth[0], "|", len(clip_contour), clip_contour[0]) 223 | # 8 , (128, 256) | 8 , (128, 256) | 8 (128, 256) 224 | 225 | # np.max(clip[0], np.max(clip_cloth), np.max(clip_contour)) 226 | 227 | 228 | if self.spatial_transform is not None: 229 | self.spatial_transform.randomize_parameters() 230 | clip_all = [self.spatial_transform(img1, img2, img3, img4) for img1, img2, img3, img4 in zip(clip, clip_cloth, clip_cloth_, clip_contour)] 231 | 232 | clip, clip_cloth, clip_cloth_, clip_contour = [], [], [], [] 233 | for (a,b,c,d) in clip_all: 234 | clip.append(a) 235 | clip_cloth.append(b) 236 | clip_cloth_.append(c) 237 | clip_contour.append(d) 238 | 239 | # print ('==========================') 240 | # exit(0) 241 | 242 | # transpose T x C x H x W to C x T x H x W 243 | clip = torch.stack(clip, 0).permute(1, 0, 2, 3) 244 | clip_cloth = torch.stack(clip_cloth, 0).permute(1, 0, 2, 3) 245 | clip_cloth_ = torch.stack(clip_cloth_, 0).permute(1, 0, 2, 3) 246 | clip_contour = torch.stack(clip_contour, 0).permute(1, 0, 2, 3) 247 | 248 | if self.cloth_changing: 249 | return clip, pid, camid, clothes_id, clip_cloth, clip_cloth_, clip_contour 250 | else: 251 | return clip, pid, camid 252 | -------------------------------------------------------------------------------- /data/datasets/__pycache__/ccvid.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ICST-MIPL/DCR-ReID_TCSVT2023/b6c58c48a97feb1c2b3f062b211c50e17bac274e/data/datasets/__pycache__/ccvid.cpython-36.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/deepchange.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ICST-MIPL/DCR-ReID_TCSVT2023/b6c58c48a97feb1c2b3f062b211c50e17bac274e/data/datasets/__pycache__/deepchange.cpython-36.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/last.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ICST-MIPL/DCR-ReID_TCSVT2023/b6c58c48a97feb1c2b3f062b211c50e17bac274e/data/datasets/__pycache__/last.cpython-36.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/ltcc.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ICST-MIPL/DCR-ReID_TCSVT2023/b6c58c48a97feb1c2b3f062b211c50e17bac274e/data/datasets/__pycache__/ltcc.cpython-36.pyc
-------------------------------------------------------------------------------- /data/datasets/__pycache__/ltcc.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ICST-MIPL/DCR-ReID_TCSVT2023/b6c58c48a97feb1c2b3f062b211c50e17bac274e/data/datasets/__pycache__/ltcc.cpython-37.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/ltcc3.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ICST-MIPL/DCR-ReID_TCSVT2023/b6c58c48a97feb1c2b3f062b211c50e17bac274e/data/datasets/__pycache__/ltcc3.cpython-36.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/prcc.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ICST-MIPL/DCR-ReID_TCSVT2023/b6c58c48a97feb1c2b3f062b211c50e17bac274e/data/datasets/__pycache__/prcc.cpython-36.pyc -------------------------------------------------------------------------------- /data/datasets/__pycache__/vcclothes.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ICST-MIPL/DCR-ReID_TCSVT2023/b6c58c48a97feb1c2b3f062b211c50e17bac274e/data/datasets/__pycache__/vcclothes.cpython-36.pyc -------------------------------------------------------------------------------- /data/datasets/ccvid.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import glob 4 | import h5py 5 | import random 6 | import math 7 | import logging 8 | import numpy as np 9 | import os.path as osp 10 | from scipy.io import loadmat 11 | from tools.utils import mkdir_if_missing, write_json, read_json 12 | 13 | 14 | class CCVID(object): 15 | """ CCVID 16 | 17 | Reference: 18 | Gu et al. Clothes-Changing Person Re-identification with RGB Modality Only. In CVPR, 2022. 19 | """ 20 | def __init__(self, root='/data/datasets/', sampling_step=64, seq_len=16, stride=4, **kwargs): 21 | self.root = osp.join(root, 'CCVID') 22 | self.train_path = osp.join(self.root, 'train.txt') 23 | self.query_path = osp.join(self.root, 'query.txt') 24 | self.gallery_path = osp.join(self.root, 'gallery.txt') 25 | self._check_before_run() 26 | 27 | train, num_train_tracklets, num_train_pids, num_train_imgs, num_train_clothes, pid2clothes, _ = \ 28 | self._process_data(self.train_path, relabel=True) 29 | clothes2label = self._clothes2label_test(self.query_path, self.gallery_path) 30 | query, num_query_tracklets, num_query_pids, num_query_imgs, num_query_clothes, _, _ = \ 31 | self._process_data(self.query_path, relabel=False, clothes2label=clothes2label) 32 | gallery, num_gallery_tracklets, num_gallery_pids, num_gallery_imgs, num_gallery_clothes, _, _ = \ 33 | self._process_data(self.gallery_path, relabel=False, clothes2label=clothes2label) 34 | 35 | # slice each full-length video in the trainingset into more video clip 36 | train_dense = self._densesampling_for_trainingset(train, sampling_step) 37 | # In the test stage, each video sample is divided into a series of equilong video clips with a pre-defined stride. 
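        # For example (illustrative numbers only): with seq_len=16 and stride=4, a 70-frame tracklet
        # yields 4 strided clips from its first 64 frames, and its 6 leftover frames are loop-padded
        # into one more 16-frame clip, giving ceil(70/16) = 5 clips in total. query_vid2clip_index and
        # gallery_vid2clip_index record the [start, end) clip indices of each original video, so the
        # clip features can later be averaged back into one representation per video.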
38 | recombined_query, query_vid2clip_index = self._recombination_for_testset(query, seq_len=seq_len, stride=stride) 39 | recombined_gallery, gallery_vid2clip_index = self._recombination_for_testset(gallery, seq_len=seq_len, stride=stride) 40 | 41 | num_imgs_per_tracklet = num_train_imgs + num_gallery_imgs + num_query_imgs 42 | min_num = np.min(num_imgs_per_tracklet) 43 | max_num = np.max(num_imgs_per_tracklet) 44 | avg_num = np.mean(num_imgs_per_tracklet) 45 | 46 | num_total_pids = num_train_pids + num_gallery_pids 47 | num_total_clothes = num_train_clothes + len(clothes2label) 48 | num_total_tracklets = num_train_tracklets + num_gallery_tracklets + num_query_tracklets 49 | 50 | logger = logging.getLogger('reid.dataset') 51 | logger.info("=> CCVID loaded") 52 | logger.info("Dataset statistics:") 53 | logger.info(" ---------------------------------------------") 54 | logger.info(" subset | # ids | # tracklets | # clothes") 55 | logger.info(" ---------------------------------------------") 56 | logger.info(" train | {:5d} | {:11d} | {:9d}".format(num_train_pids, num_train_tracklets, num_train_clothes)) 57 | logger.info(" train_dense | {:5d} | {:11d} | {:9d}".format(num_train_pids, len(train_dense), num_train_clothes)) 58 | logger.info(" query | {:5d} | {:11d} | {:9d}".format(num_query_pids, num_query_tracklets, num_query_clothes)) 59 | logger.info(" gallery | {:5d} | {:11d} | {:9d}".format(num_gallery_pids, num_gallery_tracklets, num_gallery_clothes)) 60 | logger.info(" ---------------------------------------------") 61 | logger.info(" total | {:5d} | {:11d} | {:9d}".format(num_total_pids, num_total_tracklets, num_total_clothes)) 62 | logger.info(" number of images per tracklet: {} ~ {}, average {:.1f}".format(min_num, max_num, avg_num)) 63 | logger.info(" ---------------------------------------------") 64 | 65 | self.train = train 66 | self.train_dense = train_dense 67 | self.query = query 68 | self.gallery = gallery 69 | 70 | self.recombined_query = recombined_query 71 | self.recombined_gallery = recombined_gallery 72 | self.query_vid2clip_index = query_vid2clip_index 73 | self.gallery_vid2clip_index = gallery_vid2clip_index 74 | 75 | self.num_train_pids = num_train_pids 76 | self.num_train_clothes = num_train_clothes 77 | self.pid2clothes = pid2clothes 78 | 79 | def _check_before_run(self): 80 | """Check if all files are available before going deeper""" 81 | if not osp.exists(self.root): 82 | raise RuntimeError("'{}' is not available".format(self.root)) 83 | if not osp.exists(self.train_path): 84 | raise RuntimeError("'{}' is not available".format(self.train_path)) 85 | if not osp.exists(self.query_path): 86 | raise RuntimeError("'{}' is not available".format(self.query_path)) 87 | if not osp.exists(self.gallery_path): 88 | raise RuntimeError("'{}' is not available".format(self.gallery_path)) 89 | 90 | def _clothes2label_test(self, query_path, gallery_path): 91 | pid_container = set() 92 | clothes_container = set() 93 | with open(query_path, 'r') as f: 94 | for line in f: 95 | new_line = line.rstrip() 96 | tracklet_path, pid, clothes_label = new_line.split() 97 | clothes = '{}_{}'.format(pid, clothes_label) 98 | pid_container.add(pid) 99 | clothes_container.add(clothes) 100 | with open(gallery_path, 'r') as f: 101 | for line in f: 102 | new_line = line.rstrip() 103 | tracklet_path, pid, clothes_label = new_line.split() 104 | clothes = '{}_{}'.format(pid, clothes_label) 105 | pid_container.add(pid) 106 | clothes_container.add(clothes) 107 | pid_container = sorted(pid_container) 108 | 
clothes_container = sorted(clothes_container) 109 | pid2label = {pid:label for label, pid in enumerate(pid_container)} 110 | clothes2label = {clothes:label for label, clothes in enumerate(clothes_container)} 111 | 112 | return clothes2label 113 | 114 | def _process_data(self, data_path, relabel=False, clothes2label=None): 115 | tracklet_path_list = [] 116 | pid_container = set() 117 | clothes_container = set() 118 | with open(data_path, 'r') as f: 119 | for line in f: 120 | new_line = line.rstrip() 121 | tracklet_path, pid, clothes_label = new_line.split() 122 | tracklet_path_list.append((tracklet_path, pid, clothes_label)) 123 | clothes = '{}_{}'.format(pid, clothes_label) 124 | pid_container.add(pid) 125 | clothes_container.add(clothes) 126 | pid_container = sorted(pid_container) 127 | clothes_container = sorted(clothes_container) 128 | pid2label = {pid:label for label, pid in enumerate(pid_container)} 129 | if clothes2label is None: 130 | clothes2label = {clothes:label for label, clothes in enumerate(clothes_container)} 131 | 132 | num_tracklets = len(tracklet_path_list) 133 | num_pids = len(pid_container) 134 | num_clothes = len(clothes_container) 135 | 136 | tracklets = [] 137 | num_imgs_per_tracklet = [] 138 | pid2clothes = np.zeros((num_pids, len(clothes2label))) 139 | 140 | for tracklet_path, pid, clothes_label in tracklet_path_list: 141 | img_paths = glob.glob(osp.join(self.root, tracklet_path, '*')) 142 | img_paths.sort() 143 | 144 | clothes = '{}_{}'.format(pid, clothes_label) 145 | clothes_id = clothes2label[clothes] 146 | pid2clothes[pid2label[pid], clothes_id] = 1 147 | if relabel: 148 | pid = pid2label[pid] 149 | else: 150 | pid = int(pid) 151 | session = tracklet_path.split('/')[0] 152 | cam = tracklet_path.split('_')[1] 153 | if session == 'session3': 154 | camid = int(cam) + 12 155 | else: 156 | camid = int(cam) 157 | 158 | num_imgs_per_tracklet.append(len(img_paths)) 159 | tracklets.append((img_paths, pid, camid, clothes_id)) 160 | 161 | num_tracklets = len(tracklets) 162 | 163 | return tracklets, num_tracklets, num_pids, num_imgs_per_tracklet, num_clothes, pid2clothes, clothes2label 164 | 165 | def _densesampling_for_trainingset(self, dataset, sampling_step=64): 166 | ''' Split all videos in training set into lots of clips for dense sampling. 167 | 168 | Args: 169 | dataset (list): input dataset, each video is organized as (img_paths, pid, camid, clothes_id) 170 | sampling_step (int): sampling step for dense sampling 171 | 172 | Returns: 173 | new_dataset (list): output dataset 174 | ''' 175 | new_dataset = [] 176 | for (img_paths, pid, camid, clothes_id) in dataset: 177 | if sampling_step != 0: 178 | num_sampling = len(img_paths)//sampling_step 179 | if num_sampling == 0: 180 | new_dataset.append((img_paths, pid, camid, clothes_id)) 181 | else: 182 | for idx in range(num_sampling): 183 | if idx == num_sampling - 1: 184 | new_dataset.append((img_paths[idx*sampling_step:], pid, camid, clothes_id)) 185 | else: 186 | new_dataset.append((img_paths[idx*sampling_step : (idx+1)*sampling_step], pid, camid, clothes_id)) 187 | else: 188 | new_dataset.append((img_paths, pid, camid, clothes_id)) 189 | 190 | return new_dataset 191 | 192 | def _recombination_for_testset(self, dataset, seq_len=16, stride=4): 193 | ''' Split all videos in test set into lots of equilong clips. 
194 | 195 | Args: 196 | dataset (list): input dataset, each video is organized as (img_paths, pid, camid, clothes_id) 197 | seq_len (int): sequence length of each output clip 198 | stride (int): temporal sampling stride 199 | 200 | Returns: 201 | new_dataset (list): output dataset with lots of equilong clips 202 | vid2clip_index (list): a list contains the start and end clip index of each original video 203 | ''' 204 | new_dataset = [] 205 | vid2clip_index = np.zeros((len(dataset), 2), dtype=int) 206 | for idx, (img_paths, pid, camid, clothes_id) in enumerate(dataset): 207 | # start index 208 | vid2clip_index[idx, 0] = len(new_dataset) 209 | # process the sequence that can be divisible by seq_len*stride 210 | for i in range(len(img_paths)//(seq_len*stride)): 211 | for j in range(stride): 212 | begin_idx = i * (seq_len * stride) + j 213 | end_idx = (i + 1) * (seq_len * stride) 214 | clip_paths = img_paths[begin_idx : end_idx : stride] 215 | assert(len(clip_paths) == seq_len) 216 | new_dataset.append((clip_paths, pid, camid, clothes_id)) 217 | # process the remaining sequence that can't be divisible by seq_len*stride 218 | if len(img_paths)%(seq_len*stride) != 0: 219 | # reducing stride 220 | new_stride = (len(img_paths)%(seq_len*stride)) // seq_len 221 | for i in range(new_stride): 222 | begin_idx = len(img_paths) // (seq_len*stride) * (seq_len*stride) + i 223 | end_idx = len(img_paths) // (seq_len*stride) * (seq_len*stride) + seq_len * new_stride 224 | clip_paths = img_paths[begin_idx : end_idx : new_stride] 225 | assert(len(clip_paths) == seq_len) 226 | new_dataset.append((clip_paths, pid, camid, clothes_id)) 227 | # process the remaining sequence that can't be divisible by seq_len 228 | if len(img_paths) % seq_len != 0: 229 | clip_paths = img_paths[len(img_paths)//seq_len*seq_len:] 230 | # loop padding 231 | while len(clip_paths) < seq_len: 232 | for index in clip_paths: 233 | if len(clip_paths) >= seq_len: 234 | break 235 | clip_paths.append(index) 236 | assert(len(clip_paths) == seq_len) 237 | new_dataset.append((clip_paths, pid, camid, clothes_id)) 238 | # end index 239 | vid2clip_index[idx, 1] = len(new_dataset) 240 | assert((vid2clip_index[idx, 1]-vid2clip_index[idx, 0]) == math.ceil(len(img_paths)/seq_len)) 241 | 242 | return new_dataset, vid2clip_index.tolist() 243 | 244 | -------------------------------------------------------------------------------- /data/datasets/deepchange.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import glob 4 | import h5py 5 | import random 6 | import math 7 | import logging 8 | import numpy as np 9 | import os.path as osp 10 | from scipy.io import loadmat 11 | from tools.utils import mkdir_if_missing, write_json, read_json 12 | 13 | 14 | class DeepChange(object): 15 | """ DeepChange 16 | 17 | Reference: 18 | Xu et al. DeepChange: A Long-Term Person Re-Identification Benchmark. arXiv:2105.14685, 2021. 
19 | 20 | URL: https://github.com/PengBoXiangShang/deepchange 21 | """ 22 | dataset_dir = 'DeepChangeDataset' 23 | def __init__(self, root='data', **kwargs): 24 | self.dataset_dir = osp.join(root, self.dataset_dir) 25 | self.train_dir = osp.join(self.dataset_dir, 'train-set') 26 | self.train_list = osp.join(self.dataset_dir, 'train-set-bbox.txt') 27 | self.val_query_dir = osp.join(self.dataset_dir, 'val-set-query') 28 | self.val_query_list = osp.join(self.dataset_dir, 'val-set-query-bbox.txt') 29 | self.val_gallery_dir = osp.join(self.dataset_dir, 'val-set-gallery') 30 | self.val_gallery_list = osp.join(self.dataset_dir, 'val-set-gallery-bbox.txt') 31 | self.test_query_dir = osp.join(self.dataset_dir, 'test-set-query') 32 | self.test_query_list = osp.join(self.dataset_dir, 'test-set-query-bbox.txt') 33 | self.test_gallery_dir = osp.join(self.dataset_dir, 'test-set-gallery') 34 | self.test_gallery_list = osp.join(self.dataset_dir, 'test-set-gallery-bbox.txt') 35 | self._check_before_run() 36 | 37 | train_names = self._get_names(self.train_list) 38 | val_query_names = self._get_names(self.val_query_list) 39 | val_gallery_names = self._get_names(self.val_gallery_list) 40 | test_query_names = self._get_names(self.test_query_list) 41 | test_gallery_names = self._get_names(self.test_gallery_list) 42 | 43 | pid2label, clothes2label, pid2clothes = self.get_pid2label_and_clothes2label(train_names) 44 | train, num_train_pids, num_train_clothes = self._process_dir(self.train_dir, train_names, clothes2label, pid2label=pid2label) 45 | 46 | pid2label, clothes2label = self.get_pid2label_and_clothes2label(val_query_names, val_gallery_names) 47 | val_query, num_val_query_pids, num_val_query_clothes = self._process_dir(self.val_query_dir, val_query_names, clothes2label) 48 | val_gallery, num_val_gallery_pids, num_val_gallery_clothes = self._process_dir(self.val_gallery_dir, val_gallery_names, clothes2label) 49 | num_val_pids = len(pid2label) 50 | num_val_clothes = len(clothes2label) 51 | 52 | pid2label, clothes2label = self.get_pid2label_and_clothes2label(test_query_names, test_gallery_names) 53 | test_query, num_test_query_pids, num_test_query_clothes = self._process_dir(self.test_query_dir, test_query_names, clothes2label) 54 | test_gallery, num_test_gallery_pids, num_test_gallery_clothes = self._process_dir(self.test_gallery_dir, test_gallery_names, clothes2label) 55 | num_test_pids = len(pid2label) 56 | num_test_clothes = len(clothes2label) 57 | 58 | num_total_pids = num_train_pids + num_val_pids + num_test_pids 59 | num_total_clothes = num_train_clothes + num_val_clothes + num_test_clothes 60 | num_total_imgs = len(train) + len(val_query) + len(val_gallery) + len(test_query) + len(test_gallery) 61 | 62 | logger = logging.getLogger('reid.dataset') 63 | logger.info("=> DeepChange loaded") 64 | logger.info("Dataset statistics:") 65 | logger.info(" --------------------------------------------") 66 | logger.info(" subset | # ids | # images | # clothes") 67 | logger.info(" ----------------------------------------") 68 | logger.info(" train | {:5d} | {:8d} | {:9d} ".format(num_train_pids, len(train), num_train_clothes)) 69 | logger.info(" query(val) | {:5d} | {:8d} | {:9d} ".format(num_val_query_pids, len(val_query), num_val_query_clothes)) 70 | logger.info(" gallery(val) | {:5d} | {:8d} | {:9d} ".format(num_val_gallery_pids, len(val_gallery), num_val_gallery_clothes)) 71 | logger.info(" query | {:5d} | {:8d} | {:9d} ".format(num_test_query_pids, len(test_query), num_test_query_clothes)) 72 | logger.info(" 
gallery | {:5d} | {:8d} | {:9d} ".format(num_test_gallery_pids, len(test_gallery), num_test_gallery_clothes)) 73 | logger.info(" --------------------------------------------") 74 | logger.info(" total | {:5d} | {:8d} | {:9d} ".format(num_total_pids, num_total_imgs, num_total_clothes)) 75 | logger.info(" --------------------------------------------") 76 | 77 | self.train = train 78 | self.val_query = val_query 79 | self.val_gallery = val_gallery 80 | self.query = test_query 81 | self.gallery = test_gallery 82 | 83 | self.num_train_pids = num_train_pids 84 | self.num_train_clothes = num_train_clothes 85 | self.pid2clothes = pid2clothes 86 | 87 | def _get_names(self, fpath): 88 | names = [] 89 | with open(fpath, 'r') as f: 90 | for line in f: 91 | new_line = line.rstrip() 92 | names.append(new_line) 93 | return names 94 | 95 | def get_pid2label_and_clothes2label(self, img_names1, img_names2=None): 96 | if img_names2 is not None: 97 | img_names = img_names1 + img_names2 98 | else: 99 | img_names = img_names1 100 | 101 | pid_container = set() 102 | clothes_container = set() 103 | for img_name in img_names: 104 | names = img_name.split('.')[0].split('_') 105 | clothes = names[0] + names[2] 106 | pid = int(names[0][1:]) 107 | pid_container.add(pid) 108 | clothes_container.add(clothes) 109 | pid_container = sorted(pid_container) 110 | clothes_container = sorted(clothes_container) 111 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 112 | clothes2label = {clothes:label for label, clothes in enumerate(clothes_container)} 113 | 114 | if img_names2 is not None: 115 | return pid2label, clothes2label 116 | 117 | num_pids = len(pid_container) 118 | num_clothes = len(clothes_container) 119 | pid2clothes = np.zeros((num_pids, num_clothes)) 120 | for img_name in img_names: 121 | names = img_name.split('.')[0].split('_') 122 | clothes = names[0] + names[2] 123 | pid = int(names[0][1:]) 124 | pid = pid2label[pid] 125 | clothes_id = clothes2label[clothes] 126 | pid2clothes[pid, clothes_id] = 1 127 | 128 | return pid2label, clothes2label, pid2clothes 129 | 130 | def _check_before_run(self): 131 | """Check if all files are available before going deeper""" 132 | if not osp.exists(self.dataset_dir): 133 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 134 | if not osp.exists(self.train_dir): 135 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 136 | if not osp.exists(self.val_query_dir): 137 | raise RuntimeError("'{}' is not available".format(self.val_query_dir)) 138 | if not osp.exists(self.val_gallery_dir): 139 | raise RuntimeError("'{}' is not available".format(self.val_gallery_dir)) 140 | if not osp.exists(self.test_query_dir): 141 | raise RuntimeError("'{}' is not available".format(self.test_query_dir)) 142 | if not osp.exists(self.test_gallery_dir): 143 | raise RuntimeError("'{}' is not available".format(self.test_gallery_dir)) 144 | 145 | def _process_dir(self, home_dir, img_names, clothes2label, pid2label=None): 146 | dataset = [] 147 | pid_container = set() 148 | clothes_container = set() 149 | for img_name in img_names: 150 | img_path = osp.join(home_dir, img_name.split(',')[0]) 151 | names = img_name.split('.')[0].split('_') 152 | tracklet_id = int(img_name.split(',')[1]) 153 | clothes = names[0] + names[2] 154 | clothes_id = clothes2label[clothes] 155 | clothes_container.add(clothes_id) 156 | pid = int(names[0][1:]) 157 | pid_container.add(pid) 158 | camid = int(names[1][1:]) 159 | if pid2label is not None: 160 | pid = pid2label[pid] 161 | 
# on DeepChange, we allow the true matches coming from the same camera 162 | # but different tracklets as query following the original paper. 163 | # So we use tracklet_id to replace camid for each sample. 164 | dataset.append((img_path, pid, tracklet_id, clothes_id)) 165 | num_pids = len(pid_container) 166 | num_clothes = len(clothes_container) 167 | 168 | return dataset, num_pids, num_clothes -------------------------------------------------------------------------------- /data/datasets/last.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import glob 4 | import h5py 5 | import random 6 | import math 7 | import logging 8 | import numpy as np 9 | import os.path as osp 10 | from scipy.io import loadmat 11 | from tools.utils import mkdir_if_missing, write_json, read_json 12 | 13 | 14 | class LaST(object): 15 | """ LaST 16 | 17 | Reference: 18 | Shu et al. Large-Scale Spatio-Temporal Person Re-identification: Algorithm and Benchmark. arXiv:2105.15076, 2021. 19 | 20 | URL: https://github.com/shuxjweb/last 21 | 22 | Note that LaST does not provide the clothes label for val and test set. 23 | """ 24 | dataset_dir = "last" 25 | def __init__(self, root='data', **kwargs): 26 | super(LaST, self).__init__() 27 | self.dataset_dir = osp.join(root, self.dataset_dir) 28 | self.train_dir = osp.join(self.dataset_dir, 'train') 29 | self.val_query_dir = osp.join(self.dataset_dir, 'val', 'query') 30 | self.val_gallery_dir = osp.join(self.dataset_dir, 'val', 'gallery') 31 | self.test_query_dir = osp.join(self.dataset_dir, 'test', 'query') 32 | self.test_gallery_dir = osp.join(self.dataset_dir, 'test', 'gallery') 33 | self._check_before_run() 34 | 35 | pid2label, clothes2label, pid2clothes = self.get_pid2label_and_clothes2label(self.train_dir) 36 | 37 | train, num_train_pids = self._process_dir(self.train_dir, pid2label=pid2label, clothes2label=clothes2label, relabel=True) 38 | val_query, num_val_query_pids = self._process_dir(self.val_query_dir, relabel=False) 39 | val_gallery, num_val_gallery_pids = self._process_dir(self.val_gallery_dir, relabel=False, recam=len(val_query)) 40 | test_query, num_test_query_pids = self._process_dir(self.test_query_dir, relabel=False) 41 | test_gallery, num_test_gallery_pids = self._process_dir(self.test_gallery_dir, relabel=False, recam=len(test_query)) 42 | 43 | num_total_pids = num_train_pids+num_val_gallery_pids+num_test_gallery_pids 44 | num_total_imgs = len(train) + len(val_query) + len(val_gallery) + len(test_query) + len(test_gallery) 45 | 46 | logger = logging.getLogger('reid.dataset') 47 | logger.info("=> LaST loaded") 48 | logger.info("Dataset statistics:") 49 | logger.info(" --------------------------------------------") 50 | logger.info(" subset | # ids | # images | # clothes") 51 | logger.info(" ----------------------------------------") 52 | logger.info(" train | {:5d} | {:8d} | {:9d}".format(num_train_pids, len(train), len(clothes2label))) 53 | logger.info(" query(val) | {:5d} | {:8d} |".format(num_val_query_pids, len(val_query))) 54 | logger.info(" gallery(val) | {:5d} | {:8d} |".format(num_val_gallery_pids, len(val_gallery))) 55 | logger.info(" query | {:5d} | {:8d} |".format(num_test_query_pids, len(test_query))) 56 | logger.info(" gallery | {:5d} | {:8d} |".format(num_test_gallery_pids, len(test_gallery))) 57 | logger.info(" --------------------------------------------") 58 | logger.info(" total | {:5d} | {:8d} | ".format(num_total_pids, num_total_imgs)) 59 | logger.info(" 
--------------------------------------------") 60 | 61 | self.train = train 62 | self.val_query = val_query 63 | self.val_gallery = val_gallery 64 | self.query = test_query 65 | self.gallery = test_gallery 66 | 67 | self.num_train_pids = num_train_pids 68 | self.num_train_clothes = len(clothes2label) 69 | self.pid2clothes = pid2clothes 70 | 71 | def get_pid2label_and_clothes2label(self, dir_path): 72 | img_paths = glob.glob(osp.join(dir_path, '*/*.jpg')) # [103367,] 73 | img_paths.sort() 74 | 75 | pid_container = set() 76 | clothes_container = set() 77 | for img_path in img_paths: 78 | names = osp.basename(img_path).split('.')[0].split('_') 79 | clothes = names[0] + '_' + names[-1] 80 | pid = int(names[0]) 81 | pid_container.add(pid) 82 | clothes_container.add(clothes) 83 | pid_container = sorted(pid_container) 84 | clothes_container = sorted(clothes_container) 85 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 86 | clothes2label = {clothes:label for label, clothes in enumerate(clothes_container)} 87 | 88 | num_pids = len(pid_container) 89 | num_clothes = len(clothes_container) 90 | 91 | pid2clothes = np.zeros((num_pids, num_clothes)) 92 | for img_path in img_paths: 93 | names = osp.basename(img_path).split('.')[0].split('_') 94 | clothes = names[0] + '_' + names[-1] 95 | pid = int(names[0]) 96 | pid = pid2label[pid] 97 | clothes_id = clothes2label[clothes] 98 | pid2clothes[pid, clothes_id] = 1 99 | 100 | return pid2label, clothes2label, pid2clothes 101 | 102 | def _check_before_run(self): 103 | """Check if all files are available before going deeper""" 104 | if not osp.exists(self.dataset_dir): 105 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 106 | if not osp.exists(self.train_dir): 107 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 108 | if not osp.exists(self.val_query_dir): 109 | raise RuntimeError("'{}' is not available".format(self.val_query_dir)) 110 | if not osp.exists(self.val_gallery_dir): 111 | raise RuntimeError("'{}' is not available".format(self.val_gallery_dir)) 112 | if not osp.exists(self.test_query_dir): 113 | raise RuntimeError("'{}' is not available".format(self.test_query_dir)) 114 | if not osp.exists(self.test_gallery_dir): 115 | raise RuntimeError("'{}' is not available".format(self.test_gallery_dir)) 116 | 117 | def _process_dir(self, dir_path, pid2label=None, clothes2label=None, relabel=False, recam=0): 118 | if 'query' in dir_path: 119 | img_paths = glob.glob(osp.join(dir_path, '*.jpg')) 120 | else: 121 | img_paths = glob.glob(osp.join(dir_path, '*/*.jpg')) 122 | img_paths.sort() 123 | 124 | dataset = [] 125 | pid_container = set() 126 | for ii, img_path in enumerate(img_paths): 127 | names = osp.basename(img_path).split('.')[0].split('_') 128 | clothes = names[0] + '_' + names[-1] 129 | pid = int(names[0]) 130 | pid_container.add(pid) 131 | camid = int(recam + ii) 132 | if relabel and pid2label is not None: 133 | pid = pid2label[pid] 134 | if relabel and clothes2label is not None: 135 | clothes_id = clothes2label[clothes] 136 | else: 137 | clothes_id = pid 138 | dataset.append((img_path, pid, camid, clothes_id)) 139 | num_pids = len(pid_container) 140 | 141 | return dataset, num_pids -------------------------------------------------------------------------------- /data/datasets/ltcc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import glob 4 | import h5py 5 | import random 6 | import math 7 | import logging 8 | import numpy 
as np 9 | import os.path as osp 10 | from scipy.io import loadmat 11 | from tools.utils import mkdir_if_missing, write_json, read_json 12 | 13 | 14 | class LTCC(object): 15 | """ LTCC 16 | 17 | Reference: 18 | Qian et al. Long-Term Cloth-Changing Person Re-identification. arXiv:2005.12633, 2020. 19 | 20 | URL: https://naiq.github.io/LTCC_Perosn_ReID.html# 21 | """ 22 | dataset_dir = 'LTCC_ReID' 23 | def __init__(self, root='data', **kwargs): 24 | self.dataset_dir = osp.join(root, self.dataset_dir) 25 | self.train_dir = osp.join(self.dataset_dir, 'train') 26 | self.query_dir = osp.join(self.dataset_dir, 'query') 27 | self.gallery_dir = osp.join(self.dataset_dir, 'test') 28 | self._check_before_run() 29 | 30 | train, num_train_pids, num_train_imgs, num_train_clothes, pid2clothes = \ 31 | self._process_dir_train(self.train_dir) 32 | query, gallery, num_test_pids, num_query_imgs, num_gallery_imgs, num_test_clothes = \ 33 | self._process_dir_test(self.query_dir, self.gallery_dir) 34 | num_total_pids = num_train_pids + num_test_pids 35 | num_total_imgs = num_train_imgs + num_query_imgs + num_gallery_imgs 36 | num_test_imgs = num_query_imgs + num_gallery_imgs 37 | num_total_clothes = num_train_clothes + num_test_clothes 38 | 39 | logger = logging.getLogger('reid.dataset') 40 | logger.info("=> LTCC loaded") 41 | logger.info("Dataset statistics:") 42 | logger.info(" ----------------------------------------") 43 | logger.info(" subset | # ids | # images | # clothes") 44 | logger.info(" ----------------------------------------") 45 | logger.info(" train | {:5d} | {:8d} | {:9d}".format(num_train_pids, num_train_imgs, num_train_clothes)) 46 | logger.info(" test | {:5d} | {:8d} | {:9d}".format(num_test_pids, num_test_imgs, num_test_clothes)) 47 | logger.info(" query | {:5d} | {:8d} |".format(num_test_pids, num_query_imgs)) 48 | logger.info(" gallery | {:5d} | {:8d} |".format(num_test_pids, num_gallery_imgs)) 49 | logger.info(" ----------------------------------------") 50 | logger.info(" total | {:5d} | {:8d} | {:9d}".format(num_total_pids, num_total_imgs, num_total_clothes)) 51 | logger.info(" ----------------------------------------") 52 | 53 | self.train = train 54 | self.query = query 55 | self.gallery = gallery 56 | 57 | self.num_train_pids = num_train_pids 58 | self.num_train_clothes = num_train_clothes 59 | self.pid2clothes = pid2clothes 60 | 61 | def _check_before_run(self): 62 | """Check if all files are available before going deeper""" 63 | if not osp.exists(self.dataset_dir): 64 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 65 | if not osp.exists(self.train_dir): 66 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 67 | if not osp.exists(self.query_dir): 68 | raise RuntimeError("'{}' is not available".format(self.query_dir)) 69 | if not osp.exists(self.gallery_dir): 70 | raise RuntimeError("'{}' is not available".format(self.gallery_dir)) 71 | 72 | def _process_dir_train(self, dir_path): 73 | img_paths = glob.glob(osp.join(dir_path, '*.png')) 74 | img_paths.sort() 75 | pattern1 = re.compile(r'(\d+)_(\d+)_c(\d+)') 76 | pattern2 = re.compile(r'(\w+)_c') 77 | 78 | pid_container = set() 79 | clothes_container = set() 80 | for img_path in img_paths: 81 | pid, _, _ = map(int, pattern1.search(img_path).groups()) 82 | clothes_id = pattern2.search(img_path).group(1) 83 | pid_container.add(pid) 84 | clothes_container.add(clothes_id) 85 | pid_container = sorted(pid_container) 86 | clothes_container = sorted(clothes_container) 87 | pid2label = {pid:label for 
label, pid in enumerate(pid_container)} 88 | clothes2label = {clothes_id:label for label, clothes_id in enumerate(clothes_container)} 89 | 90 | num_pids = len(pid_container) 91 | num_clothes = len(clothes_container) 92 | 93 | dataset = [] 94 | pid2clothes = np.zeros((num_pids, num_clothes)) 95 | for img_path in img_paths: 96 | pid, _, camid = map(int, pattern1.search(img_path).groups()) 97 | clothes = pattern2.search(img_path).group(1) 98 | camid -= 1 # index starts from 0 99 | pid = pid2label[pid] 100 | clothes_id = clothes2label[clothes] 101 | cloth_path = img_path.replace('/train/', '/train_cloth/') 102 | _cloth_path = img_path.replace('/train/', '/train_cloth_/') 103 | contour_path = img_path.replace('/train/', '/train_contour/') 104 | dataset.append((img_path, pid, camid, clothes_id, cloth_path, _cloth_path, contour_path)) 105 | pid2clothes[pid, clothes_id] = 1 106 | 107 | num_imgs = len(dataset) 108 | 109 | return dataset, num_pids, num_imgs, num_clothes, pid2clothes 110 | 111 | def _process_dir_test(self, query_path, gallery_path): 112 | query_img_paths = glob.glob(osp.join(query_path, '*.png')) 113 | gallery_img_paths = glob.glob(osp.join(gallery_path, '*.png')) 114 | query_img_paths.sort() 115 | gallery_img_paths.sort() 116 | pattern1 = re.compile(r'(\d+)_(\d+)_c(\d+)') 117 | pattern2 = re.compile(r'(\w+)_c') 118 | 119 | pid_container = set() 120 | clothes_container = set() 121 | for img_path in query_img_paths: 122 | pid, _, _ = map(int, pattern1.search(img_path).groups()) 123 | clothes_id = pattern2.search(img_path).group(1) 124 | pid_container.add(pid) 125 | clothes_container.add(clothes_id) 126 | for img_path in gallery_img_paths: 127 | pid, _, _ = map(int, pattern1.search(img_path).groups()) 128 | clothes_id = pattern2.search(img_path).group(1) 129 | pid_container.add(pid) 130 | clothes_container.add(clothes_id) 131 | pid_container = sorted(pid_container) 132 | clothes_container = sorted(clothes_container) 133 | pid2label = {pid:label for label, pid in enumerate(pid_container)} 134 | clothes2label = {clothes_id:label for label, clothes_id in enumerate(clothes_container)} 135 | 136 | num_pids = len(pid_container) 137 | num_clothes = len(clothes_container) 138 | 139 | query_dataset = [] 140 | gallery_dataset = [] 141 | for img_path in query_img_paths: 142 | pid, _, camid = map(int, pattern1.search(img_path).groups()) 143 | clothes_id = pattern2.search(img_path).group(1) 144 | camid -= 1 # index starts from 0 145 | clothes_id = clothes2label[clothes_id] 146 | query_dataset.append((img_path, pid, camid, clothes_id)) 147 | 148 | for img_path in gallery_img_paths: 149 | pid, _, camid = map(int, pattern1.search(img_path).groups()) 150 | clothes_id = pattern2.search(img_path).group(1) 151 | camid -= 1 # index starts from 0 152 | clothes_id = clothes2label[clothes_id] 153 | gallery_dataset.append((img_path, pid, camid, clothes_id)) 154 | 155 | num_imgs_query = len(query_dataset) 156 | num_imgs_gallery = len(gallery_dataset) 157 | 158 | return query_dataset, gallery_dataset, num_pids, num_imgs_query, num_imgs_gallery, num_clothes 159 | 160 | -------------------------------------------------------------------------------- /data/datasets/prcc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import glob 4 | import h5py 5 | import random 6 | import math 7 | import logging 8 | import numpy as np 9 | import os.path as osp 10 | from scipy.io import loadmat 11 | from tools.utils import mkdir_if_missing, write_json, read_json 
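# Note: for each training image under 'prcc/rgb/train/', _process_dir_train below also records
# cloth, cloth_ and contour paths obtained by replacing '/rgb/' in the image path, so the dataset
# root is assumed to additionally contain parallel 'prcc/cloth/', 'prcc/cloth_/' and
# 'prcc/contour/' trees mirroring the layout of 'prcc/rgb/'.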
12 | 13 | 14 | class PRCC(object): 15 | """ PRCC 16 | 17 | Reference: 18 | Yang et al. Person Re-identification by Contour Sketch under Moderate Clothing Change. TPAMI, 2019. 19 | 20 | URL: https://drive.google.com/file/d/1yTYawRm4ap3M-j0PjLQJ--xmZHseFDLz/view 21 | """ 22 | dataset_dir = 'prcc' 23 | def __init__(self, root='data', **kwargs): 24 | self.dataset_dir = osp.join(root, self.dataset_dir) 25 | self.train_dir = osp.join(self.dataset_dir, 'rgb/train') 26 | self.val_dir = osp.join(self.dataset_dir, 'rgb/val') 27 | self.test_dir = osp.join(self.dataset_dir, 'rgb/test') 28 | self._check_before_run() 29 | 30 | train, num_train_pids, num_train_imgs, num_train_clothes, pid2clothes = \ 31 | self._process_dir_train(self.train_dir) 32 | val, num_val_pids, num_val_imgs, num_val_clothes, _ = \ 33 | self._process_dir_train(self.val_dir) 34 | 35 | query_same, query_diff, gallery, num_test_pids, \ 36 | num_query_imgs_same, num_query_imgs_diff, num_gallery_imgs, \ 37 | num_test_clothes, gallery_idx = self._process_dir_test(self.test_dir) 38 | 39 | num_total_pids = num_train_pids + num_test_pids 40 | num_test_imgs = num_query_imgs_same + num_query_imgs_diff + num_gallery_imgs 41 | num_total_imgs = num_train_imgs + num_val_imgs + num_test_imgs 42 | num_total_clothes = num_train_clothes + num_test_clothes 43 | 44 | logger = logging.getLogger('reid.dataset') 45 | logger.info("=> PRCC loaded") 46 | logger.info("Dataset statistics:") 47 | logger.info(" --------------------------------------------") 48 | logger.info(" subset | # ids | # images | # clothes") 49 | logger.info(" --------------------------------------------") 50 | logger.info(" train | {:5d} | {:8d} | {:9d}".format(num_train_pids, num_train_imgs, num_train_clothes)) 51 | logger.info(" val | {:5d} | {:8d} | {:9d}".format(num_val_pids, num_val_imgs, num_val_clothes)) 52 | logger.info(" test | {:5d} | {:8d} | {:9d}".format(num_test_pids, num_test_imgs, num_test_clothes)) 53 | logger.info(" query(same) | {:5d} | {:8d} |".format(num_test_pids, num_query_imgs_same)) 54 | logger.info(" query(diff) | {:5d} | {:8d} |".format(num_test_pids, num_query_imgs_diff)) 55 | logger.info(" gallery | {:5d} | {:8d} |".format(num_test_pids, num_gallery_imgs)) 56 | logger.info(" --------------------------------------------") 57 | logger.info(" total | {:5d} | {:8d} | {:9d}".format(num_total_pids, num_total_imgs, num_total_clothes)) 58 | logger.info(" --------------------------------------------") 59 | 60 | self.train = train 61 | self.val = val 62 | self.query_same = query_same 63 | self.query_diff = query_diff 64 | self.gallery = gallery 65 | 66 | self.num_train_pids = num_train_pids 67 | self.num_train_clothes = num_train_clothes 68 | self.pid2clothes = pid2clothes 69 | self.gallery_idx = gallery_idx 70 | 71 | def _check_before_run(self): 72 | """Check if all files are available before going deeper""" 73 | if not osp.exists(self.dataset_dir): 74 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 75 | if not osp.exists(self.train_dir): 76 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 77 | if not osp.exists(self.val_dir): 78 | raise RuntimeError("'{}' is not available".format(self.val_dir)) 79 | if not osp.exists(self.test_dir): 80 | raise RuntimeError("'{}' is not available".format(self.test_dir)) 81 | 82 | def _process_dir_train(self, dir_path): 83 | pdirs = glob.glob(osp.join(dir_path, '*')) 84 | pdirs.sort() 85 | 86 | pid_container = set() 87 | clothes_container = set() 88 | for pdir in pdirs: 89 | pid = 
int(osp.basename(pdir)) 90 | pid_container.add(pid) 91 | img_dirs = glob.glob(osp.join(pdir, '*.jpg')) 92 | for img_dir in img_dirs: 93 | cam = osp.basename(img_dir)[0] # 'A' or 'B' or 'C' 94 | if cam in ['A', 'B']: 95 | clothes_container.add(osp.basename(pdir)) 96 | else: 97 | clothes_container.add(osp.basename(pdir)+osp.basename(img_dir)[0]) 98 | pid_container = sorted(pid_container) 99 | clothes_container = sorted(clothes_container) 100 | pid2label = {pid:label for label, pid in enumerate(pid_container)} 101 | clothes2label = {clothes_id:label for label, clothes_id in enumerate(clothes_container)} 102 | cam2label = {'A': 0, 'B': 1, 'C': 2} 103 | 104 | num_pids = len(pid_container) 105 | num_clothes = len(clothes_container) 106 | 107 | dataset = [] 108 | pid2clothes = np.zeros((num_pids, num_clothes)) 109 | for pdir in pdirs: 110 | pid = int(osp.basename(pdir)) 111 | img_dirs = glob.glob(osp.join(pdir, '*.jpg')) 112 | for img_dir in img_dirs: 113 | cam = osp.basename(img_dir)[0] # 'A' or 'B' or 'C' 114 | label = pid2label[pid] 115 | camid = cam2label[cam] 116 | if cam in ['A', 'B']: 117 | clothes_id = clothes2label[osp.basename(pdir)] 118 | else: 119 | clothes_id = clothes2label[osp.basename(pdir)+osp.basename(img_dir)[0]] 120 | cloth_path = img_dir.replace('/rgb/', '/cloth/') 121 | _cloth_path = img_dir.replace('/rgb/', '/cloth_/') 122 | contour_path = img_dir.replace('/rgb/', '/contour/') 123 | dataset.append((img_dir, label, camid, clothes_id, cloth_path, _cloth_path, contour_path)) 124 | pid2clothes[label, clothes_id] = 1 125 | 126 | num_imgs = len(dataset) 127 | 128 | return dataset, num_pids, num_imgs, num_clothes, pid2clothes 129 | 130 | def _process_dir_test(self, test_path): 131 | pdirs = glob.glob(osp.join(test_path, '*')) 132 | pdirs.sort() 133 | 134 | pid_container = set() 135 | for pdir in glob.glob(osp.join(test_path, 'A', '*')): 136 | pid = int(osp.basename(pdir)) 137 | pid_container.add(pid) 138 | pid_container = sorted(pid_container) 139 | pid2label = {pid:label for label, pid in enumerate(pid_container)} 140 | cam2label = {'A': 0, 'B': 1, 'C': 2} 141 | 142 | num_pids = len(pid_container) 143 | num_clothes = num_pids * 2 144 | 145 | query_dataset_same_clothes = [] 146 | query_dataset_diff_clothes = [] 147 | gallery_dataset = [] 148 | for cam in ['A', 'B', 'C']: 149 | pdirs = glob.glob(osp.join(test_path, cam, '*')) 150 | for pdir in pdirs: 151 | pid = int(osp.basename(pdir)) 152 | img_dirs = glob.glob(osp.join(pdir, '*.jpg')) 153 | for img_dir in img_dirs: 154 | # pid = pid2label[pid] 155 | camid = cam2label[cam] 156 | if cam == 'A': 157 | clothes_id = pid2label[pid] * 2 158 | gallery_dataset.append((img_dir, pid, camid, clothes_id)) 159 | elif cam == 'B': 160 | clothes_id = pid2label[pid] * 2 161 | query_dataset_same_clothes.append((img_dir, pid, camid, clothes_id)) 162 | else: 163 | clothes_id = pid2label[pid] * 2 + 1 164 | query_dataset_diff_clothes.append((img_dir, pid, camid, clothes_id)) 165 | 166 | pid2imgidx = {} 167 | for idx, (img_dir, pid, camid, clothes_id) in enumerate(gallery_dataset): 168 | if pid not in pid2imgidx: 169 | pid2imgidx[pid] = [] 170 | pid2imgidx[pid].append(idx) 171 | 172 | # get 10 gallery index to perform single-shot test 173 | gallery_idx = {} 174 | random.seed(3) 175 | for idx in range(0, 10): 176 | gallery_idx[idx] = [] 177 | for pid in pid2imgidx: 178 | gallery_idx[idx].append(random.choice(pid2imgidx[pid])) 179 | 180 | num_imgs_query_same = len(query_dataset_same_clothes) 181 | num_imgs_query_diff = len(query_dataset_diff_clothes) 182 
| num_imgs_gallery = len(gallery_dataset) 183 | 184 | return query_dataset_same_clothes, query_dataset_diff_clothes, gallery_dataset, \ 185 | num_pids, num_imgs_query_same, num_imgs_query_diff, num_imgs_gallery, \ 186 | num_clothes, gallery_idx 187 | -------------------------------------------------------------------------------- /data/datasets/vcclothes.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import glob 4 | import h5py 5 | import random 6 | import math 7 | import logging 8 | import numpy as np 9 | import os.path as osp 10 | from scipy.io import loadmat 11 | from tools.utils import mkdir_if_missing, write_json, read_json 12 | 13 | 14 | class VCClothes(object): 15 | """ VC-Clothes 16 | 17 | Reference: 18 | Wang et al. When Person Re-identification Meets Changing Clothes. In CVPR Workshop, 2020. 19 | 20 | URL: https://wanfb.github.io/dataset.html 21 | """ 22 | dataset_dir = 'VC-Clothes' 23 | def __init__(self, root='data', mode='all', **kwargs): 24 | self.dataset_dir = osp.join(root, self.dataset_dir) 25 | self.train_dir = osp.join(self.dataset_dir, 'train') 26 | self.query_dir = osp.join(self.dataset_dir, 'query') 27 | self.gallery_dir = osp.join(self.dataset_dir, 'gallery') 28 | # 'all' for all cameras; 'sc' for cam2&3; 'cc' for cam3&4 29 | self.mode = mode 30 | self._check_before_run() 31 | 32 | train, num_train_pids, num_train_imgs, num_train_clothes, pid2clothes = self._process_dir_train() 33 | query, gallery, num_test_pids, num_query_imgs, num_gallery_imgs, num_test_clothes = self._process_dir_test() 34 | num_total_pids = num_train_pids + num_test_pids 35 | num_total_imgs = num_train_imgs + num_query_imgs + num_gallery_imgs 36 | num_test_imgs = num_query_imgs + num_gallery_imgs 37 | num_total_clothes = num_train_clothes + num_test_clothes 38 | 39 | logger = logging.getLogger('reid.dataset') 40 | logger.info("=> VC-Clothes loaded") 41 | logger.info("Dataset statistics:") 42 | logger.info(" ----------------------------------------") 43 | logger.info(" subset | # ids | # images | # clothes") 44 | logger.info(" ----------------------------------------") 45 | logger.info(" train | {:5d} | {:8d} | {:9d}".format(num_train_pids, num_train_imgs, num_train_clothes)) 46 | logger.info(" test | {:5d} | {:8d} | {:9d}".format(num_test_pids, num_test_imgs, num_test_clothes)) 47 | logger.info(" query | {:5d} | {:8d} |".format(num_test_pids, num_query_imgs)) 48 | logger.info(" gallery | {:5d} | {:8d} |".format(num_test_pids, num_gallery_imgs)) 49 | logger.info(" ----------------------------------------") 50 | logger.info(" total | {:5d} | {:8d} | {:9d}".format(num_total_pids, num_total_imgs, num_total_clothes)) 51 | logger.info(" ----------------------------------------") 52 | 53 | self.train = train 54 | self.query = query 55 | self.gallery = gallery 56 | 57 | self.num_train_pids = num_train_pids 58 | self.num_train_clothes = num_train_clothes 59 | self.pid2clothes = pid2clothes 60 | 61 | def _check_before_run(self): 62 | """Check if all files are available before going deeper""" 63 | if not osp.exists(self.dataset_dir): 64 | raise RuntimeError("'{}' is not available".format(self.dataset_dir)) 65 | if not osp.exists(self.train_dir): 66 | raise RuntimeError("'{}' is not available".format(self.train_dir)) 67 | if not osp.exists(self.query_dir): 68 | raise RuntimeError("'{}' is not available".format(self.query_dir)) 69 | if not osp.exists(self.gallery_dir): 70 | raise RuntimeError("'{}' is not 
available".format(self.gallery_dir)) 71 | 72 | def _process_dir_train(self): 73 | img_paths = glob.glob(osp.join(self.train_dir, '*.jpg')) 74 | img_paths.sort() 75 | pattern = re.compile(r'(\d+)-(\d+)-(\d+)-(\d+)') 76 | 77 | pid_container = set() 78 | clothes_container = set() 79 | for img_path in img_paths: 80 | pid, camid, clothes, _ = pattern.search(img_path).groups() 81 | clothes_id = pid + clothes 82 | pid, camid = int(pid), int(camid) 83 | pid_container.add(pid) 84 | clothes_container.add(clothes_id) 85 | pid_container = sorted(pid_container) 86 | clothes_container = sorted(clothes_container) 87 | pid2label = {pid:label for label, pid in enumerate(pid_container)} 88 | clothes2label = {clothes_id:label for label, clothes_id in enumerate(clothes_container)} 89 | 90 | num_pids = len(pid_container) 91 | num_clothes = len(clothes_container) 92 | 93 | dataset = [] 94 | pid2clothes = np.zeros((num_pids, num_clothes)) 95 | for img_path in img_paths: 96 | pid, camid, clothes, _ = pattern.search(img_path).groups() 97 | clothes_id = pid + clothes 98 | pid, camid = int(pid), int(camid) 99 | camid -= 1 # index starts from 0 100 | pid = pid2label[pid] 101 | clothes_id = clothes2label[clothes_id] 102 | dataset.append((img_path, pid, camid, clothes_id)) 103 | pid2clothes[pid, clothes_id] = 1 104 | 105 | num_imgs = len(dataset) 106 | 107 | return dataset, num_pids, num_imgs, num_clothes, pid2clothes 108 | 109 | def _process_dir_test(self): 110 | query_img_paths = glob.glob(osp.join(self.query_dir, '*.jpg')) 111 | gallery_img_paths = glob.glob(osp.join(self.gallery_dir, '*.jpg')) 112 | query_img_paths.sort() 113 | gallery_img_paths.sort() 114 | pattern = re.compile(r'(\d+)-(\d+)-(\d+)-(\d+)') 115 | 116 | pid_container = set() 117 | clothes_container = set() 118 | for img_path in query_img_paths: 119 | pid, camid, clothes, _ = pattern.search(img_path).groups() 120 | clothes_id = pid + clothes 121 | pid, camid = int(pid), int(camid) 122 | if self.mode == 'sc' and camid not in [2, 3]: 123 | continue 124 | if self.mode == 'cc' and camid not in [3, 4]: 125 | continue 126 | pid_container.add(pid) 127 | clothes_container.add(clothes_id) 128 | for img_path in gallery_img_paths: 129 | pid, camid, clothes, _ = pattern.search(img_path).groups() 130 | clothes_id = pid + clothes 131 | pid, camid = int(pid), int(camid) 132 | if self.mode == 'sc' and camid not in [2, 3]: 133 | continue 134 | if self.mode == 'cc' and camid not in [3, 4]: 135 | continue 136 | pid_container.add(pid) 137 | clothes_container.add(clothes_id) 138 | pid_container = sorted(pid_container) 139 | clothes_container = sorted(clothes_container) 140 | pid2label = {pid:label for label, pid in enumerate(pid_container)} 141 | clothes2label = {clothes_id:label for label, clothes_id in enumerate(clothes_container)} 142 | 143 | num_pids = len(pid_container) 144 | num_clothes = len(clothes_container) 145 | 146 | query_dataset = [] 147 | gallery_dataset = [] 148 | for img_path in query_img_paths: 149 | pid, camid, clothes, _ = pattern.search(img_path).groups() 150 | clothes_id = pid + clothes 151 | pid, camid = int(pid), int(camid) 152 | if self.mode == 'sc' and camid not in [2, 3]: 153 | continue 154 | if self.mode == 'cc' and camid not in [3, 4]: 155 | continue 156 | camid -= 1 # index starts from 0 157 | clothes_id = clothes2label[clothes_id] 158 | query_dataset.append((img_path, pid, camid, clothes_id)) 159 | 160 | for img_path in gallery_img_paths: 161 | pid, camid, clothes, _ = pattern.search(img_path).groups() 162 | clothes_id = pid + clothes 163 | 
pid, camid = int(pid), int(camid) 164 | if self.mode == 'sc' and camid not in [2, 3]: 165 | continue 166 | if self.mode == 'cc' and camid not in [3, 4]: 167 | continue 168 | camid -= 1 # index starts from 0 169 | clothes_id = clothes2label[clothes_id] 170 | gallery_dataset.append((img_path, pid, camid, clothes_id)) 171 | 172 | num_imgs_query = len(query_dataset) 173 | num_imgs_gallery = len(gallery_dataset) 174 | 175 | return query_dataset, gallery_dataset, num_pids, num_imgs_query, num_imgs_gallery, num_clothes 176 | 177 | 178 | def VCClothesSameClothes(root='data', **kwargs): 179 | return VCClothes(root=root, mode='sc') 180 | 181 | 182 | def VCClothesClothesChanging(root='data', **kwargs): 183 | return VCClothes(root=root, mode='cc') 184 | -------------------------------------------------------------------------------- /data/img_transforms.py: -------------------------------------------------------------------------------- 1 | from torchvision.transforms import * 2 | from PIL import Image 3 | import random 4 | import math 5 | from torchvision.transforms import functional as F 6 | import collections 7 | import torch 8 | import numpy as np 9 | 10 | 11 | class ResizeWithEqualScale(object): 12 | """ 13 | Resize an image with equal scale as the original image. 14 | 15 | Args: 16 | height (int): resized height. 17 | width (int): resized width. 18 | interpolation: interpolation manner. 19 | fill_color (tuple): color for padding. 20 | """ 21 | def __init__(self, height, width, interpolation=Image.BILINEAR, fill_color=(0,0,0)): 22 | self.height = height 23 | self.width = width 24 | self.interpolation = interpolation 25 | self.fill_color = fill_color 26 | 27 | def __call__(self, img): 28 | width, height = img.size 29 | if self.height / self.width >= height / width: 30 | height = int(self.width * (height / width)) 31 | width = self.width 32 | else: 33 | width = int(self.height * (width / height)) 34 | height = self.height 35 | 36 | resized_img = img.resize((width, height), self.interpolation) 37 | new_img = Image.new('RGB', (self.width, self.height), self.fill_color) 38 | new_img.paste(resized_img, (int((self.width - width) / 2), int((self.height - height) / 2))) 39 | 40 | return new_img 41 | 42 | 43 | class RandomCroping(object): 44 | """ 45 | With a probability, first increase image size to (1 + 1/8), and then perform random crop. 46 | 47 | Args: 48 | p (float): probability of performing this transformation. Default: 0.5. 49 | """ 50 | def __init__(self, p=0.5, interpolation=Image.BILINEAR): 51 | self.p = p 52 | self.interpolation = interpolation 53 | 54 | def __call__(self, img1, img2, img3, contour): 55 | """ 56 | Args: 57 | img (PIL Image): Image to be cropped. 58 | 59 | 60 | Returns: 61 | PIL Image: Cropped image. 
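        Note:
            The same resize ratio and crop box are applied to img1, img2, img3 and contour,
            so the RGB image and its auxiliary cloth/contour inputs stay spatially aligned.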
62 | """ 63 | width, height = img1.size 64 | if random.uniform(0, 1) >= self.p: 65 | return img1, img2, img3, contour 66 | 67 | new_width, new_height = int(round(width * 1.125)), int(round(height * 1.125)) 68 | resized_img1 = img1.resize((new_width, new_height), self.interpolation) 69 | resized_img2 = img2.resize((new_width, new_height), self.interpolation) 70 | resized_img3 = img3.resize((new_width, new_height), self.interpolation) 71 | resized_contour = contour.resize((new_width, new_height), self.interpolation) 72 | x_maxrange = new_width - width 73 | y_maxrange = new_height - height 74 | x1 = int(round(random.uniform(0, x_maxrange))) 75 | y1 = int(round(random.uniform(0, y_maxrange))) 76 | croped_img1 = resized_img1.crop((x1, y1, x1 + width, y1 + height)) 77 | croped_img2 = resized_img2.crop((x1, y1, x1 + width, y1 + height)) 78 | croped_img3 = resized_img3.crop((x1, y1, x1 + width, y1 + height)) 79 | croped_contour = resized_contour.crop((x1, y1, x1 + width, y1 + height)) 80 | 81 | return croped_img1, croped_img2, croped_img3, croped_contour 82 | 83 | 84 | class RandomErasing(object): 85 | """ 86 | Randomly selects a rectangle region in an image and erases its pixels. 87 | 88 | Reference: 89 | Zhong et al. Random Erasing Data Augmentation. arxiv: 1708.04896, 2017. 90 | 91 | Args: 92 | probability: The probability that the Random Erasing operation will be performed. 93 | sl: Minimum proportion of erased area against input image. 94 | sh: Maximum proportion of erased area against input image. 95 | r1: Minimum aspect ratio of erased area. 96 | mean: Erasing value. 97 | """ 98 | 99 | def __init__(self, probability = 0.5, sl = 0.02, sh = 0.4, r1 = 0.3, mean=[0.4914, 0.4822, 0.4465]): 100 | self.probability = probability 101 | self.mean = mean 102 | self.sl = sl 103 | self.sh = sh 104 | self.r1 = r1 105 | 106 | def __call__(self, img1, img2, img3, contour): 107 | 108 | if random.uniform(0, 1) >= self.probability: 109 | return img1, img2, img3, contour 110 | 111 | for attempt in range(100): 112 | area = img1.size()[1] * img1.size()[2] 113 | 114 | target_area = random.uniform(self.sl, self.sh) * area 115 | aspect_ratio = random.uniform(self.r1, 1/self.r1) 116 | 117 | h = int(round(math.sqrt(target_area * aspect_ratio))) 118 | w = int(round(math.sqrt(target_area / aspect_ratio))) 119 | 120 | if w < img1.size()[2] and h < img1.size()[1]: 121 | x1 = random.randint(0, img1.size()[1] - h) 122 | y1 = random.randint(0, img1.size()[2] - w) 123 | if img1.size()[0] == 3: 124 | img1[0, x1:x1+h, y1:y1+w] = self.mean[0] 125 | img1[1, x1:x1+h, y1:y1+w] = self.mean[1] 126 | img1[2, x1:x1+h, y1:y1+w] = self.mean[2] 127 | else: 128 | img1[0, x1:x1+h, y1:y1+w] = self.mean[0] 129 | 130 | if img2.size()[0] == 3: 131 | img2[0, x1:x1+h, y1:y1+w] = self.mean[0] 132 | img2[1, x1:x1+h, y1:y1+w] = self.mean[1] 133 | img2[2, x1:x1+h, y1:y1+w] = self.mean[2] 134 | else: 135 | img2[0, x1:x1+h, y1:y1+w] = self.mean[0] 136 | 137 | if img3.size()[0] == 3: 138 | img3[0, x1:x1+h, y1:y1+w] = self.mean[0] 139 | img3[1, x1:x1+h, y1:y1+w] = self.mean[1] 140 | img3[2, x1:x1+h, y1:y1+w] = self.mean[2] 141 | else: 142 | img3[0, x1:x1+h, y1:y1+w] = self.mean[0] 143 | 144 | contour[0, x1:x1+h, y1:y1+w] = 1.0 145 | 146 | return img1, img2, img3, contour 147 | 148 | return img1, img2, img3, contour 149 | 150 | class Resize(object): 151 | """Resize the input PIL Image to the given size. 152 | 153 | Args: 154 | size (sequence or int): Desired output size. If size is a sequence like 155 | (h, w), output size will be matched to this. 
If size is an int, 156 | smaller edge of the image will be matched to this number. 157 | i.e, if height > width, then image will be rescaled to 158 | (size * height / width, size) 159 | interpolation (int, optional): Desired interpolation. Default is 160 | ``PIL.Image.BILINEAR`` 161 | """ 162 | 163 | def __init__(self, size, interpolation=Image.BILINEAR): 164 | assert isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2) 165 | self.size = size 166 | self.interpolation = interpolation 167 | 168 | def __call__(self, img1, img2, img3, contour): 169 | """ 170 | Args: 171 | img (PIL Image): Image to be scaled. 172 | 173 | Returns: 174 | PIL Image: Rescaled image. 175 | """ 176 | img1 = F.resize(img1, self.size, self.interpolation) 177 | img2 = F.resize(img2, self.size, self.interpolation) 178 | img3 = F.resize(img3, self.size, self.interpolation) 179 | contour = F.resize(contour, self.size, self.interpolation) 180 | return img1, img2, img3, contour 181 | 182 | class ToTensor(object): 183 | """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. 184 | 185 | Converts a PIL Image or numpy.ndarray (H x W x C) in the range 186 | [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]. 187 | """ 188 | 189 | def __call__(self, img1, img2, img3, contour): 190 | """ 191 | Args: 192 | pic (PIL Image or numpy.ndarray): Image to be converted to tensor. 193 | 194 | Returns: 195 | Tensor: Converted image. 196 | """ 197 | return F.to_tensor(img1), F.to_tensor(img2), F.to_tensor(img3), F.to_tensor(contour) 198 | 199 | class Normalize(object): 200 | """Normalize a tensor image with mean and standard deviation. 201 | Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels, this transform 202 | will normalize each channel of the input ``torch.*Tensor`` i.e. 203 | ``input[channel] = (input[channel] - mean[channel]) / std[channel]`` 204 | 205 | Args: 206 | mean (sequence): Sequence of means for each channel. 207 | std (sequence): Sequence of standard deviations for each channel. 208 | """ 209 | 210 | def __init__(self, mean, std, mean2=0.5, std2=0.5): 211 | self.mean1 = mean 212 | self.std1 = std 213 | self.mean2 = mean2 214 | self.std2 = std2 215 | self.mean_con = 0.5 # (0.5, 0.5, 0.5) 216 | self.std_con = 0.5 # (0.5, 0.5, 0.5) 217 | 218 | def __call__(self, tensor1, tensor2, tensor3, contour): 219 | """ 220 | Args: 221 | tensor (Tensor): Tensor image of size (C, H, W) to be normalized. 222 | 223 | Returns: 224 | Tensor: Normalized Tensor image. 225 | """ 226 | # print (torch.max(tensor1), torch.min(tensor1)) 227 | # print (torch.max(tensor2), torch.min(tensor2)) 228 | # print (torch.max(tensor3), torch.min(tensor3)) 229 | # print (torch.max(contour), torch.min(contour)) 230 | return F.normalize(tensor1, self.mean1, self.std1), F.normalize(tensor2, self.mean_con, self.std_con), F.normalize(tensor3, self.mean_con, self.std_con), 0.0 - F.normalize(contour, self.mean2, self.std2) 231 | 232 | class RandomHorizontalFlip(object): 233 | """Horizontally flip the given PIL Image randomly with a given probability. 234 | 235 | Args: 236 | p (float): probability of the image being flipped. Default value is 0.5 237 | """ 238 | 239 | def __init__(self, p=0.5): 240 | self.p = p 241 | 242 | def __call__(self, img1, img2, img3, contour): 243 | """ 244 | Args: 245 | img (PIL Image): Image to be flipped. 246 | 247 | Returns: 248 | PIL Image: Randomly flipped image. 
249 | """ 250 | if random.random() < self.p: 251 | return F.hflip(img1), F.hflip(img2), F.hflip(img3), F.hflip(contour) 252 | return img1, img2, img3, contour 253 | 254 | class Compose(object): 255 | def __init__(self, transforms): 256 | self.transforms = transforms 257 | 258 | def __call__(self, img1, img2, img3, contour): 259 | for idx,t in enumerate(self.transforms): 260 | img1, img2, img3, contour = t(img1, img2, img3, contour) 261 | return img1, img2, img3, contour -------------------------------------------------------------------------------- /data/samplers.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import math 3 | import random 4 | import numpy as np 5 | from torch import distributed as dist 6 | from collections import defaultdict 7 | from torch.utils.data.sampler import Sampler 8 | 9 | 10 | class RandomIdentitySampler(Sampler): 11 | """ 12 | Randomly sample N identities, then for each identity, 13 | randomly sample K instances, therefore batch size is N*K. 14 | 15 | Args: 16 | data_source (Dataset): dataset to sample from. 17 | num_instances (int): number of instances per identity. 18 | """ 19 | def __init__(self, data_source, num_instances=4): 20 | self.data_source = data_source 21 | self.num_instances = num_instances 22 | self.index_dic = defaultdict(list) 23 | for index, (_, pid, _, _) in enumerate(data_source): 24 | self.index_dic[pid].append(index) 25 | self.pids = list(self.index_dic.keys()) 26 | self.num_identities = len(self.pids) 27 | 28 | # compute number of examples in an epoch 29 | self.length = 0 30 | for pid in self.pids: 31 | idxs = self.index_dic[pid] 32 | num = len(idxs) 33 | if num < self.num_instances: 34 | num = self.num_instances 35 | self.length += num - num % self.num_instances 36 | 37 | def __iter__(self): 38 | list_container = [] 39 | 40 | for pid in self.pids: 41 | idxs = copy.deepcopy(self.index_dic[pid]) 42 | if len(idxs) < self.num_instances: 43 | idxs = np.random.choice(idxs, size=self.num_instances, replace=True) 44 | random.shuffle(idxs) 45 | batch_idxs = [] 46 | for idx in idxs: 47 | batch_idxs.append(idx) 48 | if len(batch_idxs) == self.num_instances: 49 | list_container.append(batch_idxs) 50 | batch_idxs = [] 51 | 52 | random.shuffle(list_container) 53 | 54 | ret = [] 55 | for batch_idxs in list_container: 56 | ret.extend(batch_idxs) 57 | 58 | return iter(ret) 59 | 60 | def __len__(self): 61 | return self.length 62 | 63 | 64 | class DistributedRandomIdentitySampler(Sampler): 65 | """ 66 | Randomly sample N identities, then for each identity, 67 | randomly sample K instances, therefore batch size is N*K. 68 | 69 | Args: 70 | - data_source (Dataset): dataset to sample from. 71 | - num_instances (int): number of instances per identity. 72 | - num_replicas (int, optional): Number of processes participating in 73 | distributed training. By default, :attr:`world_size` is retrieved from the 74 | current distributed group. 75 | - rank (int, optional): Rank of the current process within :attr:`num_replicas`. 76 | By default, :attr:`rank` is retrieved from the current distributed group. 77 | - seed (int, optional): random seed used to shuffle the sampler. 78 | This number should be identical across all 79 | processes in the distributed group. Default: ``0``. 
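        Example (minimal sketch; ``train_set``, ``loader`` and ``num_epochs`` are placeholders)::

            sampler = DistributedRandomIdentitySampler(train_set, num_instances=4)
            for epoch in range(num_epochs):
                sampler.set_epoch(epoch)   # re-seed the per-epoch shuffling
                for batch in loader:       # DataLoader built with sampler=sampler
                    ...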
80 | """ 81 | def __init__(self, data_source, num_instances=4, 82 | num_replicas=None, rank=None, seed=0): 83 | if num_replicas is None: 84 | if not dist.is_available(): 85 | raise RuntimeError("Requires distributed package to be available") 86 | num_replicas = dist.get_world_size() 87 | if rank is None: 88 | if not dist.is_available(): 89 | raise RuntimeError("Requires distributed package to be available") 90 | rank = dist.get_rank() 91 | if rank >= num_replicas or rank < 0: 92 | raise ValueError( 93 | "Invalid rank {}, rank should be in the interval" 94 | " [0, {}]".format(rank, num_replicas - 1)) 95 | self.num_replicas = num_replicas 96 | self.rank = rank 97 | self.seed = seed 98 | self.epoch = 0 99 | 100 | self.data_source = data_source 101 | self.num_instances = num_instances 102 | self.index_dic = defaultdict(list) 103 | for index, (batch) in enumerate(data_source): 104 | pid = batch[1] 105 | self.index_dic[pid].append(index) 106 | self.pids = list(self.index_dic.keys()) 107 | self.num_identities = len(self.pids) 108 | 109 | # compute number of examples in an epoch 110 | self.length = 0 111 | for pid in self.pids: 112 | idxs = self.index_dic[pid] 113 | num = len(idxs) 114 | if num < self.num_instances: 115 | num = self.num_instances 116 | self.length += num - num % self.num_instances 117 | assert self.length % self.num_instances == 0 118 | 119 | if self.length // self.num_instances % self.num_replicas != 0: 120 | self.num_samples = math.ceil((self.length // self.num_instances - self.num_replicas) / self.num_replicas) * self.num_instances 121 | else: 122 | self.num_samples = math.ceil(self.length / self.num_replicas) 123 | self.total_size = self.num_samples * self.num_replicas 124 | 125 | def __iter__(self): 126 | # deterministically shuffle based on epoch and seed 127 | random.seed(self.seed + self.epoch) 128 | np.random.seed(self.seed + self.epoch) 129 | 130 | list_container = [] 131 | for pid in self.pids: 132 | idxs = copy.deepcopy(self.index_dic[pid]) 133 | if len(idxs) < self.num_instances: 134 | idxs = np.random.choice(idxs, size=self.num_instances, replace=True) 135 | random.shuffle(idxs) 136 | batch_idxs = [] 137 | for idx in idxs: 138 | batch_idxs.append(idx) 139 | if len(batch_idxs) == self.num_instances: 140 | list_container.append(batch_idxs) 141 | batch_idxs = [] 142 | random.shuffle(list_container) 143 | 144 | # remove tail of data to make it evenly divisible. 145 | list_container = list_container[:self.total_size//self.num_instances] 146 | assert len(list_container) == self.total_size//self.num_instances 147 | 148 | # subsample 149 | list_container = list_container[self.rank:self.total_size//self.num_instances:self.num_replicas] 150 | assert len(list_container) == self.num_samples//self.num_instances 151 | 152 | ret = [] 153 | for batch_idxs in list_container: 154 | ret.extend(batch_idxs) 155 | 156 | return iter(ret) 157 | 158 | def __len__(self): 159 | return self.num_samples 160 | 161 | def set_epoch(self, epoch): 162 | """ 163 | Sets the epoch for this sampler. This ensures all replicas 164 | use a different random ordering for each epoch. Otherwise, the next iteration of this 165 | sampler will yield the same ordering. 166 | 167 | Args: 168 | epoch (int): Epoch number. 
169 | """ 170 | self.epoch = epoch 171 | 172 | 173 | class DistributedInferenceSampler(Sampler): 174 | """ 175 | refer to: https://github.com/huggingface/transformers/blob/447808c85f0e6d6b0aeeb07214942bf1e578f9d2/src/transformers/trainer_pt_utils.py 176 | 177 | Distributed Sampler that subsamples indicies sequentially, 178 | making it easier to collate all results at the end. 179 | Even though we only use this sampler for eval and predict (no training), 180 | which means that the model params won't have to be synced (i.e. will not hang 181 | for synchronization even if varied number of forward passes), we still add extra 182 | samples to the sampler to make it evenly divisible (like in `DistributedSampler`) 183 | to make it easy to `gather` or `reduce` resulting tensors at the end of the loop. 184 | """ 185 | def __init__(self, dataset, rank=None, num_replicas=None): 186 | if num_replicas is None: 187 | if not dist.is_available(): 188 | raise RuntimeError("Requires distributed package to be available") 189 | num_replicas = dist.get_world_size() 190 | if rank is None: 191 | if not dist.is_available(): 192 | raise RuntimeError("Requires distributed package to be available") 193 | rank = dist.get_rank() 194 | self.dataset = dataset 195 | self.num_replicas = num_replicas 196 | self.rank = rank 197 | 198 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) 199 | self.total_size = self.num_samples * self.num_replicas 200 | 201 | def __iter__(self): 202 | indices = list(range(len(self.dataset))) 203 | # add extra samples to make it evenly divisible 204 | indices += [indices[-1]] * (self.total_size - len(indices)) 205 | # subsample 206 | indices = indices[self.rank * self.num_samples : (self.rank + 1) * self.num_samples] 207 | return iter(indices) 208 | 209 | def __len__(self): 210 | return self.num_samples -------------------------------------------------------------------------------- /data/temporal_transforms.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | 4 | 5 | class TemporalRandomCrop(object): 6 | """Temporally crop the given frame indices at a random location. 7 | 8 | If the number of frames is less than the size, 9 | loop the indices as many times as necessary to satisfy the size. 10 | 11 | Args: 12 | size (int): Desired output size of the crop. 13 | stride (int): Temporal sampling stride 14 | """ 15 | 16 | def __init__(self, size=4, stride=8): 17 | self.size = size 18 | self.stride = stride 19 | 20 | def __call__(self, frame_indices): 21 | """ 22 | Args: 23 | frame_indices (list): frame indices to be cropped. 24 | Returns: 25 | list: Cropped frame indices. 
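        Example (illustrative, added for clarity): with ``size=4`` and ``stride=8``,
        a tracklet of 64 frames returns the frames at positions
        ``[s, s+8, s+16, s+24]`` for a random start ``s``; tracklets with fewer than
        ``size * stride`` frames fall back to TSN-style division, and tracklets with
        fewer than ``size`` frames are sampled with replacement.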
26 | """ 27 | frame_indices = list(frame_indices) 28 | 29 | if len(frame_indices) >= self.size * self.stride: 30 | rand_end = len(frame_indices) - (self.size - 1) * self.stride - 1 31 | begin_index = random.randint(0, rand_end) 32 | end_index = begin_index + (self.size - 1) * self.stride + 1 33 | out = frame_indices[begin_index:end_index:self.stride] 34 | elif len(frame_indices) >= self.size: 35 | clips = [] 36 | for i in range(self.size): 37 | clips.append(frame_indices[len(frame_indices)//self.size*i : len(frame_indices)//self.size*(i+1)]) 38 | out = [] 39 | for i in range(self.size): 40 | out.append(random.choice(clips[i])) 41 | else: 42 | index = np.random.choice(len(frame_indices), size=self.size, replace=True) 43 | index.sort() 44 | out = [frame_indices[index[i]] for i in range(self.size)] 45 | 46 | return out 47 | 48 | 49 | class TemporalBeginCrop(object): 50 | """Temporally crop the given frame indices at a beginning. 51 | 52 | If the number of frames is less than the size, 53 | loop the indices as many times as necessary to satisfy the size. 54 | 55 | Args: 56 | size (int): Desired output size of the crop. 57 | stride (int): Temporal sampling stride 58 | """ 59 | 60 | def __init__(self, size=8, stride=4): 61 | self.size = size 62 | self.stride = stride 63 | 64 | def __call__(self, frame_indices): 65 | frame_indices = list(frame_indices) 66 | 67 | if len(frame_indices) >= self.size * self.stride: 68 | out = frame_indices[0 : self.size * self.stride : self.stride] 69 | else: 70 | out = frame_indices[0 : self.size] 71 | while len(out) < self.size: 72 | for index in out: 73 | if len(out) >= self.size: 74 | break 75 | out.append(index) 76 | 77 | return out 78 | 79 | 80 | class TemporalDivisionCrop(object): 81 | """Temporally crop the given frame indices by TSN. 82 | 83 | Args: 84 | size (int): Desired output size of the crop. 85 | """ 86 | def __init__(self, size=4): 87 | self.size = size 88 | 89 | def __call__(self, frame_indices): 90 | """ 91 | Args: 92 | frame_indices (list): frame indices to be cropped. 93 | Returns: 94 | list: Cropped frame indices. 
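        Example (illustrative, added for clarity): with ``size=4`` and 20 frames,
        the indices are split into 4 consecutive chunks of 5 frames and one frame
        is drawn at random from each chunk; with fewer than 4 frames, positions are
        re-sampled with replacement and sorted.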
95 | """ 96 | frame_indices = list(frame_indices) 97 | 98 | if len(frame_indices) >= self.size: 99 | clips = [] 100 | for i in range(self.size): 101 | clips.append(frame_indices[len(frame_indices)//self.size*i : len(frame_indices)//self.size*(i+1)]) 102 | out = [] 103 | for i in range(self.size): 104 | out.append(random.choice(clips[i])) 105 | else: 106 | index = np.random.choice(len(frame_indices), size=self.size, replace=True) 107 | index.sort() 108 | out = [frame_indices[index[i]] for i in range(self.size)] 109 | 110 | return out 111 | -------------------------------------------------------------------------------- /losses/__init__.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from losses.cross_entropy_loss_with_label_smooth import CrossEntropyWithLabelSmooth 3 | from losses.triplet_loss import TripletLoss 4 | from losses.contrastive_loss import ContrastiveLoss 5 | from losses.arcface_loss import ArcFaceLoss 6 | from losses.cosface_loss import CosFaceLoss, PairwiseCosFaceLoss 7 | from losses.circle_loss import CircleLoss, PairwiseCircleLoss 8 | from losses.clothes_based_adversarial_loss import ClothesBasedAdversarialLoss, ClothesBasedAdversarialLossWithMemoryBank 9 | 10 | 11 | def build_losses(config, num_train_clothes): 12 | # Build identity classification loss 13 | if config.LOSS.CLA_LOSS == 'crossentropy': 14 | criterion_cla = nn.CrossEntropyLoss() 15 | elif config.LOSS.CLA_LOSS == 'crossentropylabelsmooth': 16 | criterion_cla = CrossEntropyWithLabelSmooth() 17 | elif config.LOSS.CLA_LOSS == 'arcface': 18 | criterion_cla = ArcFaceLoss(scale=config.LOSS.CLA_S, margin=config.LOSS.CLA_M) 19 | elif config.LOSS.CLA_LOSS == 'cosface': 20 | criterion_cla = CosFaceLoss(scale=config.LOSS.CLA_S, margin=config.LOSS.CLA_M) 21 | elif config.LOSS.CLA_LOSS == 'circle': 22 | criterion_cla = CircleLoss(scale=config.LOSS.CLA_S, margin=config.LOSS.CLA_M) 23 | else: 24 | raise KeyError("Invalid classification loss: '{}'".format(config.LOSS.CLA_LOSS)) 25 | 26 | # Build pairwise loss 27 | if config.LOSS.PAIR_LOSS == 'triplet': 28 | criterion_pair = TripletLoss(margin=config.LOSS.PAIR_M) 29 | elif config.LOSS.PAIR_LOSS == 'contrastive': 30 | criterion_pair = ContrastiveLoss(scale=config.LOSS.PAIR_S) 31 | elif config.LOSS.PAIR_LOSS == 'cosface': 32 | criterion_pair = PairwiseCosFaceLoss(scale=config.LOSS.PAIR_S, margin=config.LOSS.PAIR_M) 33 | elif config.LOSS.PAIR_LOSS == 'circle': 34 | criterion_pair = PairwiseCircleLoss(scale=config.LOSS.PAIR_S, margin=config.LOSS.PAIR_M) 35 | else: 36 | raise KeyError("Invalid pairwise loss: '{}'".format(config.LOSS.PAIR_LOSS)) 37 | 38 | # Build clothes classification loss 39 | if config.LOSS.CLOTHES_CLA_LOSS == 'crossentropy': 40 | criterion_clothes = nn.CrossEntropyLoss() 41 | elif config.LOSS.CLOTHES_CLA_LOSS == 'cosface': 42 | criterion_clothes = CosFaceLoss(scale=config.LOSS.CLA_S, margin=0) 43 | else: 44 | raise KeyError("Invalid clothes classification loss: '{}'".format(config.LOSS.CLOTHES_CLA_LOSS)) 45 | 46 | # Build clothes-based adversarial loss 47 | if config.LOSS.CAL == 'cal': 48 | criterion_cal = ClothesBasedAdversarialLoss(scale=config.LOSS.CLA_S, epsilon=config.LOSS.EPSILON) 49 | criterion_shuffle = ClothesBasedAdversarialLoss(scale=config.LOSS.CLA_S, epsilon=1.0) 50 | elif config.LOSS.CAL == 'calwithmemory': 51 | criterion_cal = ClothesBasedAdversarialLossWithMemoryBank(num_clothes=num_train_clothes, feat_dim=config.MODEL.FEATURE_DIM, 52 | momentum=config.LOSS.MOMENTUM, 
scale=config.LOSS.CLA_S, epsilon=config.LOSS.EPSILON) 53 | criterion_shuffle = ClothesBasedAdversarialLossWithMemoryBank(num_clothes=num_train_clothes, feat_dim=config.MODEL.FEATURE_DIM, 54 | momentum=config.LOSS.MOMENTUM, scale=config.LOSS.CLA_S, epsilon=1.0) 55 | else: 56 | raise KeyError("Invalid clothing classification loss: '{}'".format(config.LOSS.CAL)) 57 | 58 | recon_uncloth = nn.L1Loss() 59 | recon_contour = nn.L1Loss() 60 | recon_cloth = nn.L1Loss() 61 | 62 | return criterion_cla, criterion_pair, criterion_clothes, criterion_cal, criterion_shuffle, recon_uncloth, recon_contour, recon_cloth 63 | -------------------------------------------------------------------------------- /losses/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ICST-MIPL/DCR-ReID_TCSVT2023/b6c58c48a97feb1c2b3f062b211c50e17bac274e/losses/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /losses/__pycache__/arcface_loss.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ICST-MIPL/DCR-ReID_TCSVT2023/b6c58c48a97feb1c2b3f062b211c50e17bac274e/losses/__pycache__/arcface_loss.cpython-36.pyc -------------------------------------------------------------------------------- /losses/__pycache__/circle_loss.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ICST-MIPL/DCR-ReID_TCSVT2023/b6c58c48a97feb1c2b3f062b211c50e17bac274e/losses/__pycache__/circle_loss.cpython-36.pyc -------------------------------------------------------------------------------- /losses/__pycache__/clothes_based_adversarial_loss.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ICST-MIPL/DCR-ReID_TCSVT2023/b6c58c48a97feb1c2b3f062b211c50e17bac274e/losses/__pycache__/clothes_based_adversarial_loss.cpython-36.pyc -------------------------------------------------------------------------------- /losses/__pycache__/contrastive_loss.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ICST-MIPL/DCR-ReID_TCSVT2023/b6c58c48a97feb1c2b3f062b211c50e17bac274e/losses/__pycache__/contrastive_loss.cpython-36.pyc -------------------------------------------------------------------------------- /losses/__pycache__/cosface_loss.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ICST-MIPL/DCR-ReID_TCSVT2023/b6c58c48a97feb1c2b3f062b211c50e17bac274e/losses/__pycache__/cosface_loss.cpython-36.pyc -------------------------------------------------------------------------------- /losses/__pycache__/cross_entropy_loss_with_label_smooth.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ICST-MIPL/DCR-ReID_TCSVT2023/b6c58c48a97feb1c2b3f062b211c50e17bac274e/losses/__pycache__/cross_entropy_loss_with_label_smooth.cpython-36.pyc -------------------------------------------------------------------------------- /losses/__pycache__/gather.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/PKU-ICST-MIPL/DCR-ReID_TCSVT2023/b6c58c48a97feb1c2b3f062b211c50e17bac274e/losses/__pycache__/gather.cpython-36.pyc -------------------------------------------------------------------------------- /losses/__pycache__/triplet_loss.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ICST-MIPL/DCR-ReID_TCSVT2023/b6c58c48a97feb1c2b3f062b211c50e17bac274e/losses/__pycache__/triplet_loss.cpython-36.pyc -------------------------------------------------------------------------------- /losses/arcface_loss.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn.functional as F 4 | from torch import nn 5 | 6 | 7 | class ArcFaceLoss(nn.Module): 8 | """ ArcFace loss. 9 | 10 | Reference: 11 | Deng et al. ArcFace: Additive Angular Margin Loss for Deep Face Recognition. In CVPR, 2019. 12 | 13 | Args: 14 | scale (float): scaling factor. 15 | margin (float): pre-defined margin. 16 | """ 17 | def __init__(self, scale=16, margin=0.1): 18 | super().__init__() 19 | self.s = scale 20 | self.m = margin 21 | 22 | def forward(self, inputs, targets): 23 | """ 24 | Args: 25 | inputs: prediction matrix (before softmax) with shape (batch_size, num_classes) 26 | targets: ground truth labels with shape (batch_size) 27 | """ 28 | # get a one-hot index 29 | index = inputs.data * 0.0 30 | index.scatter_(1, targets.data.view(-1, 1), 1) 31 | index = index.bool() 32 | 33 | cos_m = math.cos(self.m) 34 | sin_m = math.sin(self.m) 35 | cos_t = inputs[index] 36 | sin_t = torch.sqrt(1.0 - cos_t * cos_t) 37 | cos_t_add_m = cos_t * cos_m - sin_t * sin_m 38 | 39 | cond_v = cos_t - math.cos(math.pi - self.m) 40 | cond = F.relu(cond_v) 41 | keep = cos_t - math.sin(math.pi - self.m) * self.m 42 | 43 | cos_t_add_m = torch.where(cond.bool(), cos_t_add_m, keep) 44 | 45 | output = inputs * 1.0 46 | output[index] = cos_t_add_m 47 | output = self.s * output 48 | 49 | return F.cross_entropy(output, targets) 50 | -------------------------------------------------------------------------------- /losses/circle_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | from torch import distributed as dist 5 | from losses.gather import GatherLayer 6 | 7 | 8 | class CircleLoss(nn.Module): 9 | """ Circle Loss based on the predictions of classifier. 10 | 11 | Reference: 12 | Sun et al. Circle Loss: A Unified Perspective of Pair Similarity Optimization. In CVPR, 2020. 13 | 14 | Args: 15 | scale (float): scaling factor. 16 | margin (float): pre-defined margin. 
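    Example (illustrative sketch added for clarity; ``features`` and ``pids`` are
    assumptions): this classifier-based variant expects cosine-similarity logits,
    e.g. from ``NormalizedClassifier`` in models/classifier.py:

        criterion = CircleLoss(scale=96, margin=0.3)
        logits = identity_classifier(features)   # (batch_size, num_classes), values in [-1, 1]
        loss = criterion(logits, pids)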
17 | """ 18 | def __init__(self, scale=96, margin=0.3, **kwargs): 19 | super().__init__() 20 | self.s = scale 21 | self.m = margin 22 | 23 | def forward(self, inputs, targets): 24 | """ 25 | Args: 26 | inputs: prediction matrix (before softmax) with shape (batch_size, num_classes) 27 | targets: ground truth labels with shape (batch_size) 28 | """ 29 | mask = torch.zeros_like(inputs).cuda() 30 | mask.scatter_(1, targets.view(-1, 1), 1.0) 31 | 32 | pos_scale = self.s * F.relu(1 + self.m - inputs.detach()) 33 | neg_scale = self.s * F.relu(inputs.detach() + self.m) 34 | scale_matrix = pos_scale * mask + neg_scale * (1 - mask) 35 | 36 | scores = (inputs - (1 - self.m) * mask - self.m * (1 - mask)) * scale_matrix 37 | 38 | loss = F.cross_entropy(scores, targets) 39 | 40 | return loss 41 | 42 | 43 | class PairwiseCircleLoss(nn.Module): 44 | """ Circle Loss among sample pairs. 45 | 46 | Reference: 47 | Sun et al. Circle Loss: A Unified Perspective of Pair Similarity Optimization. In CVPR, 2020. 48 | 49 | Args: 50 | scale (float): scaling factor. 51 | margin (float): pre-defined margin. 52 | """ 53 | def __init__(self, scale=48, margin=0.35, **kwargs): 54 | super().__init__() 55 | self.s = scale 56 | self.m = margin 57 | 58 | def forward(self, inputs, targets): 59 | """ 60 | Args: 61 | inputs: sample features (before classifier) with shape (batch_size, feat_dim) 62 | targets: ground truth labels with shape (batch_size) 63 | """ 64 | # l2-normalize 65 | inputs = F.normalize(inputs, p=2, dim=1) 66 | 67 | # gather all samples from different GPUs as gallery to compute pairwise loss. 68 | gallery_inputs = torch.cat(GatherLayer.apply(inputs), dim=0) 69 | gallery_targets = torch.cat(GatherLayer.apply(targets), dim=0) 70 | m, n = targets.size(0), gallery_targets.size(0) 71 | 72 | # compute cosine similarity 73 | similarities = torch.matmul(inputs, gallery_inputs.t()) 74 | 75 | # get mask for pos/neg pairs 76 | targets, gallery_targets = targets.view(-1, 1), gallery_targets.view(-1, 1) 77 | mask = torch.eq(targets, gallery_targets.T).float().cuda() 78 | mask_self = torch.zeros_like(mask) 79 | rank = dist.get_rank() 80 | mask_self[:, rank * m:(rank + 1) * m] += torch.eye(m).float().cuda() 81 | mask_pos = mask - mask_self 82 | mask_neg = 1 - mask 83 | 84 | pos_scale = self.s * F.relu(1 + self.m - similarities.detach()) 85 | neg_scale = self.s * F.relu(similarities.detach() + self.m) 86 | scale_matrix = pos_scale * mask_pos + neg_scale * mask_neg 87 | 88 | scores = (similarities - self.m) * mask_neg + (1 - self.m - similarities) * mask_pos 89 | scores = scores * scale_matrix 90 | 91 | neg_scores_LSE = torch.logsumexp(scores * mask_neg - 99999999 * (1 - mask_neg), dim=1) 92 | pos_scores_LSE = torch.logsumexp(scores * mask_pos - 99999999 * (1 - mask_pos), dim=1) 93 | 94 | loss = F.softplus(neg_scores_LSE + pos_scores_LSE).mean() 95 | 96 | return loss 97 | -------------------------------------------------------------------------------- /losses/clothes_based_adversarial_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | from losses.gather import GatherLayer 5 | 6 | 7 | class ClothesBasedAdversarialLoss(nn.Module): 8 | """ Clothes-based Adversarial Loss. 9 | 10 | Reference: 11 | Gu et al. Clothes-Changing Person Re-identification with RGB Modality Only. In CVPR, 2022. 12 | 13 | Args: 14 | scale (float): scaling factor. 15 | epsilon (float): a trade-off hyper-parameter. 
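    Example (illustrative sketch added for clarity; ``clothes_logits``, ``clothes_ids``
    and ``pids`` are assumptions, while ``pid2clothes`` is the (num_pids, num_clothes)
    matrix built in main.py):

        criterion_cal = ClothesBasedAdversarialLoss(scale=16, epsilon=0.1)
        positive_mask = pid2clothes[pids].float().cuda()  # 1 for clothes worn by the anchor's identity
        loss = criterion_cal(clothes_logits, clothes_ids, positive_mask)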
16 | """ 17 | def __init__(self, scale=16, epsilon=0.1): 18 | super().__init__() 19 | self.scale = scale 20 | self.epsilon = epsilon 21 | 22 | def forward(self, inputs, targets, positive_mask): 23 | """ 24 | Args: 25 | inputs: prediction matrix (before softmax) with shape (batch_size, num_classes) 26 | targets: ground truth labels with shape (batch_size) 27 | positive_mask: positive mask matrix with shape (batch_size, num_classes). The clothes classes with 28 | the same identity as the anchor sample are defined as positive clothes classes and their mask 29 | values are 1. The clothes classes with different identities from the anchor sample are defined 30 | as negative clothes classes and their mask values in positive_mask are 0. 31 | """ 32 | inputs = self.scale * inputs 33 | negtive_mask = 1 - positive_mask 34 | identity_mask = torch.zeros(inputs.size()).scatter_(1, targets.unsqueeze(1).data.cpu(), 1).cuda() 35 | 36 | exp_logits = torch.exp(inputs) 37 | log_sum_exp_pos_and_all_neg = torch.log((exp_logits * negtive_mask).sum(1, keepdim=True) + exp_logits) 38 | log_prob = inputs - log_sum_exp_pos_and_all_neg 39 | 40 | mask = (1 - self.epsilon) * identity_mask + self.epsilon / positive_mask.sum(1, keepdim=True) * positive_mask 41 | loss = (- mask * log_prob).sum(1).mean() 42 | 43 | return loss 44 | 45 | 46 | class ClothesBasedAdversarialLossWithMemoryBank(nn.Module): 47 | """ Clothes-based Adversarial Loss between mini batch and the samples in memory. 48 | 49 | Reference: 50 | Gu et al. Clothes-Changing Person Re-identification with RGB Modality Only. In CVPR, 2022. 51 | 52 | Args: 53 | num_clothes (int): the number of clothes classes. 54 | feat_dim (int): the dimensions of feature. 55 | momentum (float): momentum to update memory. 56 | scale (float): scaling factor. 57 | epsilon (float): a trade-off hyper-parameter. 58 | """ 59 | def __init__(self, num_clothes, feat_dim, momentum=0., scale=16, epsilon=0.1): 60 | super().__init__() 61 | self.num_clothes = num_clothes 62 | self.feat_dim = feat_dim 63 | self.momentum = momentum 64 | self.epsilon = epsilon 65 | self.scale = scale 66 | 67 | self.register_buffer('feature_memory', torch.zeros((num_clothes, feat_dim))) 68 | self.register_buffer('label_memory', torch.zeros(num_clothes, dtype=torch.int64) - 1) 69 | self.has_been_filled = False 70 | 71 | def forward(self, inputs, targets, positive_mask): 72 | """ 73 | Args: 74 | inputs: sample features (before classifier) with shape (batch_size, feat_dim) 75 | targets: ground truth labels with shape (batch_size) 76 | positive_mask: positive mask matrix with shape (batch_size, num_classes). 77 | """ 78 | # gather all samples from different GPUs to update memory. 
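        # NOTE (descriptive comment added for clarity): GatherLayer.apply returns one
        # tensor per process, so the concatenation below rebuilds the global batch and
        # every replica updates the clothes-feature memory with identical statistics.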
79 | gathered_inputs = torch.cat(GatherLayer.apply(inputs), dim=0) 80 | gathered_targets = torch.cat(GatherLayer.apply(targets), dim=0) 81 | self._update_memory(gathered_inputs.detach(), gathered_targets) 82 | 83 | inputs_norm = F.normalize(inputs, p=2, dim=1) 84 | memory_norm = F.normalize(self.feature_memory.detach(), p=2, dim=1) 85 | similarities = torch.matmul(inputs_norm, memory_norm.t()) * self.scale 86 | 87 | negtive_mask = 1 - positive_mask 88 | mask_identity = torch.zeros(positive_mask.size()).scatter_(1, targets.unsqueeze(1).data.cpu(), 1).cuda() 89 | 90 | if not self.has_been_filled: 91 | invalid_index = self.label_memory == -1 92 | positive_mask[:, invalid_index] = 0 93 | negtive_mask[:, invalid_index] = 0 94 | if sum(invalid_index.type(torch.int)) == 0: 95 | self.has_been_filled = True 96 | print('Memory bank is full') 97 | 98 | # compute log_prob 99 | exp_logits = torch.exp(similarities) 100 | log_sum_exp_pos_and_all_neg = torch.log((exp_logits * negtive_mask).sum(1, keepdim=True) + exp_logits) 101 | log_prob = similarities - log_sum_exp_pos_and_all_neg 102 | 103 | # compute mean of log-likelihood over positive 104 | mask = (1 - self.epsilon) * mask_identity + self.epsilon / positive_mask.sum(1, keepdim=True) * positive_mask 105 | loss = (- mask * log_prob).sum(1).mean() 106 | 107 | return loss 108 | 109 | def _update_memory(self, features, labels): 110 | label_to_feat = {} 111 | for x, y in zip(features, labels): 112 | if y not in label_to_feat: 113 | label_to_feat[y] = [x.unsqueeze(0)] 114 | else: 115 | label_to_feat[y].append(x.unsqueeze(0)) 116 | if not self.has_been_filled: 117 | for y in label_to_feat: 118 | feat = torch.mean(torch.cat(label_to_feat[y], dim=0), dim=0) 119 | self.feature_memory[y] = feat 120 | self.label_memory[y] = y 121 | else: 122 | for y in label_to_feat: 123 | feat = torch.mean(torch.cat(label_to_feat[y], dim=0), dim=0) 124 | self.feature_memory[y] = self.momentum * self.feature_memory[y] + (1. - self.momentum) * feat 125 | # self.embedding_memory[y] /= self.embedding_memory[y].norm() -------------------------------------------------------------------------------- /losses/contrastive_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | from torch import distributed as dist 5 | from losses.gather import GatherLayer 6 | 7 | 8 | class ContrastiveLoss(nn.Module): 9 | """ Supervised Contrastive Learning Loss among sample pairs. 10 | 11 | Args: 12 | scale (float): scaling factor. 13 | """ 14 | def __init__(self, scale=16, **kwargs): 15 | super().__init__() 16 | self.s = scale 17 | 18 | def forward(self, inputs, targets): 19 | """ 20 | Args: 21 | inputs: sample features (before classifier) with shape (batch_size, feat_dim) 22 | targets: ground truth labels with shape (batch_size) 23 | """ 24 | # l2-normalize 25 | inputs = F.normalize(inputs, p=2, dim=1) 26 | 27 | # gather all samples from different GPUs as gallery to compute pairwise loss. 
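        # NOTE (descriptive comment added for clarity): each local sample is scored
        # against the gathered global batch; `mask_self` below removes the trivial
        # self-pair from the positives before the log-likelihood is averaged.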
28 | gallery_inputs = torch.cat(GatherLayer.apply(inputs), dim=0) 29 | gallery_targets = torch.cat(GatherLayer.apply(targets), dim=0) 30 | m, n = targets.size(0), gallery_targets.size(0) 31 | 32 | # compute cosine similarity 33 | similarities = torch.matmul(inputs, gallery_inputs.t()) * self.s 34 | 35 | # get mask for pos/neg pairs 36 | targets, gallery_targets = targets.view(-1, 1), gallery_targets.view(-1, 1) 37 | mask = torch.eq(targets, gallery_targets.T).float().cuda() 38 | mask_self = torch.zeros_like(mask) 39 | rank = dist.get_rank() 40 | mask_self[:, rank * m:(rank + 1) * m] += torch.eye(m).float().cuda() 41 | mask_pos = mask - mask_self 42 | mask_neg = 1 - mask 43 | 44 | # compute log_prob 45 | exp_logits = torch.exp(similarities) * (1 - mask_self) 46 | # log_prob = similarities - torch.log(exp_logits.sum(1, keepdim=True)) 47 | log_sum_exp_pos_and_all_neg = torch.log((exp_logits * mask_neg).sum(1, keepdim=True) + exp_logits) 48 | log_prob = similarities - log_sum_exp_pos_and_all_neg 49 | 50 | # compute mean of log-likelihood over positive 51 | loss = (mask_pos * log_prob).sum(1) / mask_pos.sum(1) 52 | 53 | loss = - loss.mean() 54 | 55 | return loss -------------------------------------------------------------------------------- /losses/cosface_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | from torch import distributed as dist 5 | from losses.gather import GatherLayer 6 | 7 | 8 | class CosFaceLoss(nn.Module): 9 | """ CosFace Loss based on the predictions of classifier. 10 | 11 | Reference: 12 | Wang et al. CosFace: Large Margin Cosine Loss for Deep Face Recognition. In CVPR, 2018. 13 | 14 | Args: 15 | scale (float): scaling factor. 16 | margin (float): pre-defined margin. 17 | """ 18 | def __init__(self, scale=16, margin=0.1, **kwargs): 19 | super().__init__() 20 | self.s = scale 21 | self.m = margin 22 | 23 | def forward(self, inputs, targets): 24 | """ 25 | Args: 26 | inputs: prediction matrix (before softmax) with shape (batch_size, num_classes) 27 | targets: ground truth labels with shape (batch_size) 28 | """ 29 | one_hot = torch.zeros_like(inputs) 30 | one_hot.scatter_(1, targets.view(-1, 1), 1.0) 31 | 32 | output = self.s * (inputs - one_hot * self.m) 33 | 34 | return F.cross_entropy(output, targets) 35 | 36 | 37 | class PairwiseCosFaceLoss(nn.Module): 38 | """ CosFace Loss among sample pairs. 39 | 40 | Reference: 41 | Sun et al. Circle Loss: A Unified Perspective of Pair Similarity Optimization. In CVPR, 2020. 42 | 43 | Args: 44 | scale (float): scaling factor. 45 | margin (float): pre-defined margin. 46 | """ 47 | def __init__(self, scale=16, margin=0): 48 | super().__init__() 49 | self.s = scale 50 | self.m = margin 51 | 52 | def forward(self, inputs, targets): 53 | """ 54 | Args: 55 | inputs: sample features (before classifier) with shape (batch_size, feat_dim) 56 | targets: ground truth labels with shape (batch_size) 57 | """ 58 | # l2-normalize 59 | inputs = F.normalize(inputs, p=2, dim=1) 60 | 61 | # gather all samples from different GPUs as gallery to compute pairwise loss. 
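        # NOTE (descriptive comment added for clarity, an interpretation of the code
        # below): the two log-sum-exp terms combined by softplus are equivalent to
        # log(1 + sum over neg/pos pairs of exp(s * (sim_neg + m - sim_pos))), the
        # unified pairwise form described in the Circle Loss paper.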
62 | gallery_inputs = torch.cat(GatherLayer.apply(inputs), dim=0) 63 | gallery_targets = torch.cat(GatherLayer.apply(targets), dim=0) 64 | m, n = targets.size(0), gallery_targets.size(0) 65 | 66 | # compute cosine similarity 67 | similarities = torch.matmul(inputs, gallery_inputs.t()) 68 | 69 | # get mask for pos/neg pairs 70 | targets, gallery_targets = targets.view(-1, 1), gallery_targets.view(-1, 1) 71 | mask = torch.eq(targets, gallery_targets.T).float().cuda() 72 | mask_self = torch.zeros_like(mask) 73 | rank = dist.get_rank() 74 | mask_self[:, rank * m:(rank + 1) * m] += torch.eye(m).float().cuda() 75 | mask_pos = mask - mask_self 76 | mask_neg = 1 - mask 77 | 78 | scores = (similarities + self.m) * mask_neg - similarities * mask_pos 79 | scores = scores * self.s 80 | 81 | neg_scores_LSE = torch.logsumexp(scores * mask_neg - 99999999 * (1 - mask_neg), dim=1) 82 | pos_scores_LSE = torch.logsumexp(scores * mask_pos - 99999999 * (1 - mask_pos), dim=1) 83 | 84 | loss = F.softplus(neg_scores_LSE + pos_scores_LSE).mean() 85 | 86 | return loss -------------------------------------------------------------------------------- /losses/cross_entropy_loss_with_label_smooth.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class CrossEntropyWithLabelSmooth(nn.Module): 6 | """ Cross entropy loss with label smoothing regularization. 7 | 8 | Reference: 9 | Szegedy et al. Rethinking the Inception Architecture for Computer Vision. In CVPR, 2016. 10 | Equation: 11 | y = (1 - epsilon) * y + epsilon / K. 12 | 13 | Args: 14 | epsilon (float): a hyper-parameter in the above equation. 15 | """ 16 | def __init__(self, epsilon=0.1): 17 | super().__init__() 18 | self.epsilon = epsilon 19 | self.logsoftmax = nn.LogSoftmax(dim=1) 20 | 21 | def forward(self, inputs, targets): 22 | """ 23 | Args: 24 | inputs: prediction matrix (before softmax) with shape (batch_size, num_classes) 25 | targets: ground truth labels with shape (batch_size) 26 | """ 27 | _, num_classes = inputs.size() 28 | log_probs = self.logsoftmax(inputs) 29 | targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).data.cpu(), 1).cuda() 30 | targets = (1 - self.epsilon) * targets + self.epsilon / num_classes 31 | loss = (- targets * log_probs).mean(0).sum() 32 | 33 | return loss 34 | -------------------------------------------------------------------------------- /losses/gather.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | 4 | 5 | class GatherLayer(torch.autograd.Function): 6 | """Gather tensors from all process, supporting backward propagation.""" 7 | 8 | @staticmethod 9 | def forward(ctx, input): 10 | ctx.save_for_backward(input) 11 | output = [torch.zeros_like(input) for _ in range(dist.get_world_size())] 12 | dist.all_gather(output, input) 13 | 14 | return tuple(output) 15 | 16 | @staticmethod 17 | def backward(ctx, *grads): 18 | (input,) = ctx.saved_tensors 19 | grad_out = torch.zeros_like(input) 20 | 21 | # dist.reduce_scatter(grad_out, list(grads)) 22 | # grad_out.div_(dist.get_world_size()) 23 | 24 | grad_out[:] = grads[dist.get_rank()] 25 | 26 | return grad_out -------------------------------------------------------------------------------- /losses/triplet_loss.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn.functional as F 4 | from torch import nn 5 | 
from losses.gather import GatherLayer 6 | 7 | 8 | class TripletLoss(nn.Module): 9 | """ Triplet loss with hard example mining. 10 | 11 | Reference: 12 | Hermans et al. In Defense of the Triplet Loss for Person Re-Identification. arXiv:1703.07737. 13 | 14 | Args: 15 | margin (float): pre-defined margin. 16 | 17 | Note that we use cosine similarity, rather than Euclidean distance in the original paper. 18 | """ 19 | def __init__(self, margin=0.3): 20 | super().__init__() 21 | self.m = margin 22 | self.ranking_loss = nn.MarginRankingLoss(margin=margin) 23 | 24 | def forward(self, inputs, targets): 25 | """ 26 | Args: 27 | inputs: sample features (before classifier) with shape (batch_size, feat_dim) 28 | targets: ground truth labels with shape (batch_size) 29 | """ 30 | # l2-normlize 31 | inputs = F.normalize(inputs, p=2, dim=1) 32 | 33 | # gather all samples from different GPUs as gallery to compute pairwise loss. 34 | gallery_inputs = torch.cat(GatherLayer.apply(inputs), dim=0) 35 | gallery_targets = torch.cat(GatherLayer.apply(targets), dim=0) 36 | 37 | # compute distance 38 | dist = 1 - torch.matmul(inputs, gallery_inputs.t()) # values in [0, 2] 39 | 40 | # get positive and negative masks 41 | targets, gallery_targets = targets.view(-1,1), gallery_targets.view(-1,1) 42 | mask_pos = torch.eq(targets, gallery_targets.T).float().cuda() 43 | mask_neg = 1 - mask_pos 44 | 45 | # For each anchor, find the hardest positive and negative pairs 46 | dist_ap, _ = torch.max((dist - mask_neg * 99999999.), dim=1) 47 | dist_an, _ = torch.min((dist + mask_pos * 99999999.), dim=1) 48 | 49 | # Compute ranking hinge loss 50 | y = torch.ones_like(dist_an) 51 | loss = self.ranking_loss(dist_an, dist_ap, y) 52 | 53 | return loss -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import datetime 5 | import argparse 6 | import logging 7 | import os.path as osp 8 | import numpy as np 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.optim as optim 13 | from torch.optim import lr_scheduler 14 | from torch import distributed as dist 15 | from apex import amp 16 | 17 | from configs.default_img import get_img_config 18 | from configs.default_vid import get_vid_config 19 | from data import build_dataloader 20 | from models import build_model 21 | from losses import build_losses 22 | from tools.utils import save_checkpoint, set_seed, get_logger 23 | from train import train_cal, train_cal_with_memory 24 | from test import test, test_prcc 25 | from infer import infer, infer_prcc 26 | 27 | 28 | VID_DATASET = ['ccvid'] 29 | 30 | 31 | def parse_option(): 32 | parser = argparse.ArgumentParser(description='Train clothes-changing re-id model with clothes-based adversarial loss') 33 | parser.add_argument('--cfg', type=str, required=True, metavar="FILE", help='path to config file') 34 | # Datasets 35 | parser.add_argument('--root', type=str, help="your root path to data directory") 36 | parser.add_argument('--dataset', type=str, default='ltcc', help="ltcc, prcc, vcclothes, ccvid, last, deepchange") 37 | # Miscs 38 | parser.add_argument('--output', type=str, help="your output path to save model and logs") 39 | parser.add_argument('--resume', type=str, metavar='PATH') 40 | parser.add_argument('--amp', action='store_true', help="automatic mixed precision") 41 | parser.add_argument('--eval', action='store_true', help="evaluation only") 42 | 
parser.add_argument('--infer', action='store_true', help="inference only") 43 | parser.add_argument('--tag', type=str, help='tag for log file') 44 | parser.add_argument('--gpu', default='0', type=str, help='gpu device ids for CUDA_VISIBLE_DEVICES') 45 | # Miscs 46 | parser.add_argument('--spr', default = '1.0', type = float) 47 | parser.add_argument('--sacr', default = '0.05', type = float) 48 | parser.add_argument('--rr', default = '1.0', type = float) 49 | 50 | args, unparsed = parser.parse_known_args() 51 | if args.dataset in VID_DATASET: 52 | config = get_vid_config(args) 53 | else: 54 | config = get_img_config(args) 55 | 56 | return config 57 | 58 | 59 | def main(config): 60 | # Build dataloader 61 | if config.DATA.DATASET == 'prcc': 62 | trainloader, queryloader_same, queryloader_diff, galleryloader, dataset, train_sampler = build_dataloader(config) 63 | else: 64 | trainloader, queryloader, galleryloader, dataset, train_sampler = build_dataloader(config) 65 | # Define a matrix pid2clothes with shape (num_pids, num_clothes). 66 | # pid2clothes[i, j] = 1 when j-th clothes belongs to i-th identity. Otherwise, pid2clothes[i, j] = 0. 67 | pid2clothes = torch.from_numpy(dataset.pid2clothes) 68 | 69 | # Build model 70 | model, classifier, clothes_classifier = build_model(config, dataset.num_train_pids, dataset.num_train_clothes) 71 | # model = nn.SyncBatchNorm.convert_sync_batchnorm(model) 72 | # Build identity classification loss, pairwise loss, clothes classificaiton loss, and adversarial loss. 73 | criterion_cla, criterion_pair, criterion_clothes, criterion_adv, criterion_shuffle, recon_uncloth, recon_contour, recon_cloth = build_losses(config, dataset.num_train_clothes) 74 | # Build optimizer 75 | parameters = list(model.parameters()) + list(classifier.parameters()) 76 | if config.TRAIN.OPTIMIZER.NAME == 'adam': 77 | optimizer = optim.Adam(parameters, lr=config.TRAIN.OPTIMIZER.LR, 78 | weight_decay=config.TRAIN.OPTIMIZER.WEIGHT_DECAY) 79 | optimizer_cc = optim.Adam(clothes_classifier.parameters(), lr=config.TRAIN.OPTIMIZER.LR, 80 | weight_decay=config.TRAIN.OPTIMIZER.WEIGHT_DECAY) 81 | elif config.TRAIN.OPTIMIZER.NAME == 'adamw': 82 | optimizer = optim.AdamW(parameters, lr=config.TRAIN.OPTIMIZER.LR, 83 | weight_decay=config.TRAIN.OPTIMIZER.WEIGHT_DECAY) 84 | optimizer_cc = optim.AdamW(clothes_classifier.parameters(), lr=config.TRAIN.OPTIMIZER.LR, 85 | weight_decay=config.TRAIN.OPTIMIZER.WEIGHT_DECAY) 86 | elif config.TRAIN.OPTIMIZER.NAME == 'sgd': 87 | optimizer = optim.SGD(parameters, lr=config.TRAIN.OPTIMIZER.LR, momentum=0.9, 88 | weight_decay=config.TRAIN.OPTIMIZER.WEIGHT_DECAY, nesterov=True) 89 | optimizer_cc = optim.SGD(clothes_classifier.parameters(), lr=config.TRAIN.OPTIMIZER.LR, momentum=0.9, 90 | weight_decay=config.TRAIN.OPTIMIZER.WEIGHT_DECAY, nesterov=True) 91 | else: 92 | raise KeyError("Unknown optimizer: {}".format(config.TRAIN.OPTIMIZER.NAME)) 93 | # Build lr_scheduler 94 | scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=config.TRAIN.LR_SCHEDULER.STEPSIZE, 95 | gamma=config.TRAIN.LR_SCHEDULER.DECAY_RATE) 96 | 97 | start_epoch = config.TRAIN.START_EPOCH 98 | if config.MODEL.RESUME: 99 | logger.info("Loading checkpoint from '{}'".format(config.MODEL.RESUME)) 100 | checkpoint = torch.load(config.MODEL.RESUME) 101 | model.load_state_dict(checkpoint['model_state_dict']) 102 | classifier.load_state_dict(checkpoint['classifier_state_dict']) 103 | if config.LOSS.CAL == 'calwithmemory': 104 | criterion_adv.load_state_dict(checkpoint['clothes_classifier_state_dict']) 
105 | else: 106 | clothes_classifier.load_state_dict(checkpoint['clothes_classifier_state_dict']) 107 | start_epoch = checkpoint['epoch'] 108 | 109 | local_rank = dist.get_rank() 110 | model = model.cuda(local_rank) 111 | classifier = classifier.cuda(local_rank) 112 | if config.LOSS.CAL == 'calwithmemory': 113 | criterion_adv = criterion_adv.cuda(local_rank) 114 | else: 115 | clothes_classifier = clothes_classifier.cuda(local_rank) 116 | torch.cuda.set_device(local_rank) 117 | 118 | if config.TRAIN.AMP: 119 | [model, classifier], optimizer = amp.initialize([model, classifier], optimizer, opt_level="O1") 120 | if config.LOSS.CAL != 'calwithmemory': 121 | clothes_classifier, optimizer_cc = amp.initialize(clothes_classifier, optimizer_cc, opt_level="O1") 122 | 123 | model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank) 124 | classifier = nn.parallel.DistributedDataParallel(classifier, device_ids=[local_rank], output_device=local_rank) 125 | if config.LOSS.CAL != 'calwithmemory': 126 | clothes_classifier = nn.parallel.DistributedDataParallel(clothes_classifier, device_ids=[local_rank], output_device=local_rank) 127 | 128 | if config.EVAL_MODE: 129 | logger.info("Evaluate only") 130 | with torch.no_grad(): 131 | if config.DATA.DATASET == 'prcc': 132 | test_prcc(config, model, queryloader_same, queryloader_diff, galleryloader, dataset) 133 | else: 134 | test(config, model, queryloader, galleryloader, dataset) 135 | return 136 | 137 | if config.INFER_MODE: 138 | logger.info("Infer only") 139 | with torch.no_grad(): 140 | if config.DATA.DATASET == 'prcc': 141 | infer_prcc(config, model, queryloader_same, queryloader_diff, galleryloader, dataset) 142 | else: 143 | infer(config, model, queryloader, galleryloader, dataset) 144 | return 145 | 146 | start_time = time.time() 147 | train_time = 0 148 | best_rank1 = -np.inf 149 | best_map = -np.inf 150 | best_epoch = 0 151 | logger.info("==> Start training") 152 | for epoch in range(start_epoch, config.TRAIN.MAX_EPOCH): 153 | train_sampler.set_epoch(epoch) 154 | start_train_time = time.time() 155 | if config.LOSS.CAL == 'calwithmemory': 156 | train_cal_with_memory(config, epoch, model, classifier, criterion_cla, criterion_pair, 157 | criterion_adv, optimizer, trainloader, pid2clothes) 158 | else: 159 | train_cal(config, epoch, model, classifier, clothes_classifier, criterion_cla, criterion_pair, 160 | criterion_clothes, criterion_adv, criterion_shuffle, recon_uncloth, recon_contour, recon_cloth, optimizer, optimizer_cc, trainloader, pid2clothes) 161 | train_time += round(time.time() - start_train_time) 162 | 163 | if (epoch+1) > config.TEST.START_EVAL and config.TEST.EVAL_STEP > 0 and \ 164 | (epoch+1) % config.TEST.EVAL_STEP == 0 or (epoch+1) == config.TRAIN.MAX_EPOCH or (epoch+1) >= 30: 165 | logger.info("==> Test") 166 | torch.cuda.empty_cache() 167 | if config.DATA.DATASET == 'prcc': 168 | rank1, mAP = test_prcc(config, model, queryloader_same, queryloader_diff, galleryloader, dataset) 169 | else: 170 | rank1, mAP = test(config, model, queryloader, galleryloader, dataset) 171 | torch.cuda.empty_cache() 172 | is_best = rank1 + mAP > best_rank1 + best_map 173 | if is_best: 174 | best_rank1 = rank1 175 | best_map = mAP 176 | best_epoch = epoch + 1 177 | 178 | model_state_dict = model.module.state_dict() 179 | classifier_state_dict = classifier.module.state_dict() 180 | if config.LOSS.CAL == 'calwithmemory': 181 | clothes_classifier_state_dict = criterion_adv.state_dict() 182 | else: 183 | 
clothes_classifier_state_dict = clothes_classifier.module.state_dict() 184 | if local_rank == 0: 185 | save_checkpoint({ 186 | 'model_state_dict': model_state_dict, 187 | 'classifier_state_dict': classifier_state_dict, 188 | 'clothes_classifier_state_dict': clothes_classifier_state_dict, 189 | 'rank1': rank1, 190 | 'epoch': epoch, 191 | }, is_best, osp.join(config.OUTPUT, 'checkpoint_ep' + str(epoch+1) + '.pth.tar')) 192 | scheduler.step() 193 | 194 | logger.info("==> Best Rank-1 {:.1%}, achieved at epoch {}".format(best_rank1, best_epoch)) 195 | 196 | elapsed = round(time.time() - start_time) 197 | elapsed = str(datetime.timedelta(seconds=elapsed)) 198 | train_time = str(datetime.timedelta(seconds=train_time)) 199 | logger.info("Finished. Total elapsed time (h:m:s): {}. Training time (h:m:s): {}.".format(elapsed, train_time)) 200 | 201 | 202 | if __name__ == '__main__': 203 | config = parse_option() 204 | # Set GPU 205 | os.environ['CUDA_VISIBLE_DEVICES'] = config.GPU 206 | # Init dist 207 | dist.init_process_group(backend="nccl", init_method='env://') 208 | local_rank = dist.get_rank() 209 | # Set random seed 210 | set_seed(config.SEED + local_rank) 211 | # get logger 212 | if (not config.EVAL_MODE) and (not config.INFER_MODE): 213 | output_file = osp.join(config.OUTPUT, 'log_train_.log') 214 | elif config.EVAL_MODE: 215 | output_file = osp.join(config.OUTPUT, 'log_test.log') 216 | elif config.INFER_MODE: 217 | output_file = osp.join(config.OUTPUT, 'log_infer.log') 218 | logger = get_logger(output_file, local_rank, 'reid') 219 | logger.info("Config:\n-----------------------------------------") 220 | logger.info(config) 221 | logger.info("-----------------------------------------") 222 | 223 | main(config) -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from models.classifier import Classifier, NormalizedClassifier 3 | from models.img_resnet import ResNet50 4 | from models.vid_resnet import C2DResNet50, I3DResNet50, AP3DResNet50, NLResNet50, AP3DNLResNet50 5 | 6 | 7 | __factory = { 8 | 'resnet50': ResNet50, 9 | 'c2dres50': C2DResNet50, 10 | 'i3dres50': I3DResNet50, 11 | 'ap3dres50': AP3DResNet50, 12 | 'nlres50': NLResNet50, 13 | 'ap3dnlres50': AP3DNLResNet50, 14 | } 15 | 16 | 17 | def build_model(config, num_identities, num_clothes): 18 | logger = logging.getLogger('reid.model') 19 | # Build backbone 20 | logger.info("Initializing model: {}".format(config.MODEL.NAME)) 21 | if config.MODEL.NAME not in __factory.keys(): 22 | raise KeyError("Invalid model: '{}'".format(config.MODEL.NAME)) 23 | else: 24 | logger.info("Init model: '{}'".format(config.MODEL.NAME)) 25 | model = __factory[config.MODEL.NAME](config) 26 | logger.info("Model size: {:.5f}M".format(sum(p.numel() for p in model.parameters())/1000000.0)) 27 | 28 | # Build classifier 29 | if config.LOSS.CLA_LOSS in ['crossentropy', 'crossentropylabelsmooth']: 30 | identity_classifier = Classifier(feature_dim=config.MODEL.FEATURE_DIM, num_classes=num_identities) 31 | else: 32 | identity_classifier = NormalizedClassifier(feature_dim=config.MODEL.FEATURE_DIM, num_classes=num_identities) 33 | 34 | clothes_classifier = NormalizedClassifier(feature_dim=config.MODEL.FEATURE_DIM, num_classes=num_clothes) 35 | 36 | return model, identity_classifier, clothes_classifier -------------------------------------------------------------------------------- /models/__pycache__/__init__.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ICST-MIPL/DCR-ReID_TCSVT2023/b6c58c48a97feb1c2b3f062b211c50e17bac274e/models/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/classifier.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ICST-MIPL/DCR-ReID_TCSVT2023/b6c58c48a97feb1c2b3f062b211c50e17bac274e/models/__pycache__/classifier.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/img_resnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ICST-MIPL/DCR-ReID_TCSVT2023/b6c58c48a97feb1c2b3f062b211c50e17bac274e/models/__pycache__/img_resnet.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/img_resnet_3.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ICST-MIPL/DCR-ReID_TCSVT2023/b6c58c48a97feb1c2b3f062b211c50e17bac274e/models/__pycache__/img_resnet_3.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/img_resnet_fc.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ICST-MIPL/DCR-ReID_TCSVT2023/b6c58c48a97feb1c2b3f062b211c50e17bac274e/models/__pycache__/img_resnet_fc.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/img_resnet_sep.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ICST-MIPL/DCR-ReID_TCSVT2023/b6c58c48a97feb1c2b3f062b211c50e17bac274e/models/__pycache__/img_resnet_sep.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/vid_resnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ICST-MIPL/DCR-ReID_TCSVT2023/b6c58c48a97feb1c2b3f062b211c50e17bac274e/models/__pycache__/vid_resnet.cpython-36.pyc -------------------------------------------------------------------------------- /models/classifier.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import init 4 | from torch.nn import functional as F 5 | from torch.nn import Parameter 6 | 7 | 8 | __all__ = ['Classifier', 'NormalizedClassifier'] 9 | 10 | 11 | class Classifier(nn.Module): 12 | def __init__(self, feature_dim, num_classes): 13 | super().__init__() 14 | self.classifier = nn.Linear(feature_dim, num_classes) 15 | init.normal_(self.classifier.weight.data, std=0.001) 16 | init.constant_(self.classifier.bias.data, 0.0) 17 | 18 | def forward(self, x): 19 | y = self.classifier(x) 20 | 21 | return y 22 | 23 | 24 | class NormalizedClassifier(nn.Module): 25 | def __init__(self, feature_dim, num_classes): 26 | super().__init__() 27 | self.weight = Parameter(torch.Tensor(num_classes, feature_dim)) 28 | self.weight.data.uniform_(-1, 1).renorm_(2,0,1e-5).mul_(1e5) 29 | 30 | def forward(self, x): 31 | w = self.weight 32 | 33 | x = F.normalize(x, p=2, dim=1) 34 | w = 
F.normalize(w, p=2, dim=1) 35 | 36 | return F.linear(x, w) 37 | 38 | 39 | 40 | -------------------------------------------------------------------------------- /models/img_resnet.py: -------------------------------------------------------------------------------- 1 | import torchvision 2 | import torch 3 | from torch import nn 4 | from torch.nn import init 5 | from models.utils import pooling 6 | 7 | class GEN(nn.Module): 8 | def __init__(self, in_feat_dim, out_img_dim, config, **kwargs): 9 | super().__init__() 10 | 11 | self.in_feat_dim = in_feat_dim 12 | self.out_img_dim = out_img_dim 13 | 14 | self.conv0 = nn.Conv2d(self.in_feat_dim, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=True) 15 | self.conv1 = nn.Conv2d(64, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), bias=True) 16 | self.conv2 = nn.Conv2d(64, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), bias=True) 17 | self.conv3 = nn.Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), bias=True) 18 | self.conv4 = nn.Conv2d(32, self.out_img_dim, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), bias=True) 19 | 20 | self.up = nn.Upsample(scale_factor=2) 21 | 22 | self.bn = nn.BatchNorm2d(64) 23 | init.normal_(self.bn.weight.data, 1.0, 0.02) 24 | init.constant_(self.bn.bias.data, 0.0) 25 | 26 | self.relu = nn.ReLU() 27 | 28 | def forward(self, x): 29 | 30 | x = self.conv0(x) 31 | x = self.bn(x) 32 | x = self.relu(x) 33 | 34 | x = self.up(x) 35 | x = self.conv1(x) 36 | x = self.relu(x) 37 | 38 | x = self.up(x) 39 | x = self.conv2(x) 40 | x = self.relu(x) 41 | 42 | x = self.up(x) 43 | x = self.conv3(x) 44 | x = self.relu(x) 45 | 46 | x = self.up(x) 47 | x = self.conv4(x) 48 | x = torch.tanh(x) 49 | 50 | return x 51 | 52 | 53 | 54 | 55 | 56 | class ResNet50(nn.Module): 57 | def __init__(self, config, **kwargs): 58 | super().__init__() 59 | 60 | resnet50 = torchvision.models.resnet50(pretrained=True) 61 | if config.MODEL.RES4_STRIDE == 1: 62 | resnet50.layer4[0].conv2.stride=(1, 1) 63 | resnet50.layer4[0].downsample[0].stride=(1, 1) 64 | self.base = nn.Sequential(*list(resnet50.children())[:-2]) 65 | 66 | if config.MODEL.POOLING.NAME == 'avg': 67 | self.globalpooling = nn.AdaptiveAvgPool2d(1) 68 | elif config.MODEL.POOLING.NAME == 'max': 69 | self.globalpooling = nn.AdaptiveMaxPool2d(1) 70 | elif config.MODEL.POOLING.NAME == 'gem': 71 | self.globalpooling = pooling.GeMPooling(p=config.MODEL.POOLING.P) 72 | elif config.MODEL.POOLING.NAME == 'maxavg': 73 | self.globalpooling = pooling.MaxAvgPooling() 74 | else: 75 | raise KeyError("Invalid pooling: '{}'".format(config.MODEL.POOLING.NAME)) 76 | 77 | self.bn = nn.BatchNorm1d(config.MODEL.FEATURE_DIM) 78 | init.normal_(self.bn.weight.data, 1.0, 0.02) 79 | init.constant_(self.bn.bias.data, 0.0) 80 | 81 | self.uncloth_dim = config.MODEL.NO_CLOTHES_DIM//2 82 | self.contour_dim = config.MODEL.CONTOUR_DIM//2 83 | self.cloth_dim = config.MODEL.CLOTHES_DIM//2 84 | 85 | self.uncloth_net = GEN(in_feat_dim = self.uncloth_dim, out_img_dim=1, config = config) 86 | self.contour_net = GEN(in_feat_dim = self.contour_dim + self.cloth_dim, out_img_dim=1, config = config) 87 | self.cloth_net = GEN(in_feat_dim = self.cloth_dim, out_img_dim=1, config = config) 88 | 89 | 90 | 91 | def forward(self, x): 92 | x = self.base(x) 93 | x_ori = x 94 | x = self.globalpooling(x) 95 | x = x.view(x.size(0), -1) 96 | f = self.bn(x) 97 | 98 | f_unclo = x_ori[:, 0:self.uncloth_dim, :, :] 99 | f_cont = x_ori[:, self.uncloth_dim:self.uncloth_dim+self.contour_dim+self.cloth_dim, :, :] 100 | 
f_clo = x_ori[:, self.uncloth_dim+self.contour_dim:self.uncloth_dim+self.contour_dim+self.cloth_dim, :, :] 101 | 102 | unclo_img = self.uncloth_net(f_unclo) 103 | cont_img = self.contour_net(f_cont) 104 | clo_img = self.cloth_net(f_clo) 105 | 106 | return (f, unclo_img, cont_img, clo_img) 107 | -------------------------------------------------------------------------------- /models/utils/__pycache__/c3d_blocks.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ICST-MIPL/DCR-ReID_TCSVT2023/b6c58c48a97feb1c2b3f062b211c50e17bac274e/models/utils/__pycache__/c3d_blocks.cpython-36.pyc -------------------------------------------------------------------------------- /models/utils/__pycache__/inflate.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ICST-MIPL/DCR-ReID_TCSVT2023/b6c58c48a97feb1c2b3f062b211c50e17bac274e/models/utils/__pycache__/inflate.cpython-36.pyc -------------------------------------------------------------------------------- /models/utils/__pycache__/nonlocal_blocks.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ICST-MIPL/DCR-ReID_TCSVT2023/b6c58c48a97feb1c2b3f062b211c50e17bac274e/models/utils/__pycache__/nonlocal_blocks.cpython-36.pyc -------------------------------------------------------------------------------- /models/utils/__pycache__/pooling.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ICST-MIPL/DCR-ReID_TCSVT2023/b6c58c48a97feb1c2b3f062b211c50e17bac274e/models/utils/__pycache__/pooling.cpython-36.pyc -------------------------------------------------------------------------------- /models/utils/inflate.py: -------------------------------------------------------------------------------- 1 | # inflate 2D modules to 3D modules 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn import functional as F 5 | 6 | 7 | def inflate_conv(conv2d, 8 | time_dim=1, 9 | time_padding=0, 10 | time_stride=1, 11 | time_dilation=1, 12 | center=False): 13 | # To preserve activations, padding should be by continuity and not zero 14 | # or no padding in time dimension 15 | kernel_dim = (time_dim, conv2d.kernel_size[0], conv2d.kernel_size[1]) 16 | padding = (time_padding, conv2d.padding[0], conv2d.padding[1]) 17 | stride = (time_stride, conv2d.stride[0], conv2d.stride[0]) 18 | dilation = (time_dilation, conv2d.dilation[0], conv2d.dilation[1]) 19 | conv3d = nn.Conv3d( 20 | conv2d.in_channels, 21 | conv2d.out_channels, 22 | kernel_dim, 23 | padding=padding, 24 | dilation=dilation, 25 | stride=stride) 26 | # Repeat filter time_dim times along time dimension 27 | weight_2d = conv2d.weight.data 28 | if center: 29 | weight_3d = torch.zeros(*weight_2d.shape) 30 | weight_3d = weight_3d.unsqueeze(2).repeat(1, 1, time_dim, 1, 1) 31 | middle_idx = time_dim // 2 32 | weight_3d[:, :, middle_idx, :, :] = weight_2d 33 | else: 34 | weight_3d = weight_2d.unsqueeze(2).repeat(1, 1, time_dim, 1, 1) 35 | weight_3d = weight_3d / time_dim 36 | 37 | # Assign new params 38 | conv3d.weight = nn.Parameter(weight_3d) 39 | conv3d.bias = conv2d.bias 40 | return conv3d 41 | 42 | 43 | def inflate_linear(linear2d, time_dim): 44 | """ 45 | Args: 46 | time_dim: final time dimension of the features 47 | """ 48 | linear3d = nn.Linear(linear2d.in_features * time_dim, 49 | 
linear2d.out_features) 50 | weight3d = linear2d.weight.data.repeat(1, time_dim) 51 | weight3d = weight3d / time_dim 52 | 53 | linear3d.weight = nn.Parameter(weight3d) 54 | linear3d.bias = linear2d.bias 55 | return linear3d 56 | 57 | 58 | def inflate_batch_norm(batch2d): 59 | # In pytorch 0.2.0 the 2d and 3d versions of batch norm 60 | # work identically except for the check that verifies the 61 | # input dimensions 62 | 63 | batch3d = nn.BatchNorm3d(batch2d.num_features) 64 | # retrieve 3d _check_input_dim function 65 | batch2d._check_input_dim = batch3d._check_input_dim 66 | return batch2d 67 | 68 | 69 | def inflate_pool(pool2d, 70 | time_dim=1, 71 | time_padding=0, 72 | time_stride=None, 73 | time_dilation=1): 74 | kernel_dim = (time_dim, pool2d.kernel_size, pool2d.kernel_size) 75 | padding = (time_padding, pool2d.padding, pool2d.padding) 76 | if time_stride is None: 77 | time_stride = time_dim 78 | stride = (time_stride, pool2d.stride, pool2d.stride) 79 | if isinstance(pool2d, nn.MaxPool2d): 80 | dilation = (time_dilation, pool2d.dilation, pool2d.dilation) 81 | pool3d = nn.MaxPool3d( 82 | kernel_dim, 83 | padding=padding, 84 | dilation=dilation, 85 | stride=stride, 86 | ceil_mode=pool2d.ceil_mode) 87 | elif isinstance(pool2d, nn.AvgPool2d): 88 | pool3d = nn.AvgPool3d(kernel_dim, stride=stride) 89 | else: 90 | raise ValueError( 91 | '{} is not among known pooling classes'.format(type(pool2d))) 92 | return pool3d 93 | 94 | 95 | class MaxPool2dFor3dInput(nn.Module): 96 | """ 97 | Since nn.MaxPool3d is nondeterministic operation, using fixed random seeds can't get consistent results. 98 | So we attempt to use max_pool2d to implement MaxPool3d with kernelsize (1, kernel_size, kernel_size). 99 | """ 100 | def __init__(self, kernel_size, stride=None, padding=0, dilation=1): 101 | super().__init__() 102 | self.maxpool = nn.MaxPool2d(kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation) 103 | def forward(self, x): 104 | b, c, t, h, w = x.size() 105 | x = x.permute(0, 2, 1, 3, 4).contiguous() # b, t, c, h, w 106 | x = x.view(b*t, c, h, w) 107 | # max pooling 108 | x = self.maxpool(x) 109 | _, _, h, w = x.size() 110 | x = x.view(b, t, c, h, w).permute(0, 2, 1, 3, 4).contiguous() 111 | 112 | return x -------------------------------------------------------------------------------- /models/utils/nonlocal_blocks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from models.utils import inflate 6 | 7 | 8 | class NonLocalBlockND(nn.Module): 9 | def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True): 10 | super(NonLocalBlockND, self).__init__() 11 | 12 | assert dimension in [1, 2, 3] 13 | 14 | self.dimension = dimension 15 | self.sub_sample = sub_sample 16 | self.in_channels = in_channels 17 | self.inter_channels = inter_channels 18 | 19 | if self.inter_channels is None: 20 | self.inter_channels = in_channels // 2 21 | if self.inter_channels == 0: 22 | self.inter_channels = 1 23 | 24 | if dimension == 3: 25 | conv_nd = nn.Conv3d 26 | # max_pool = inflate.MaxPool2dFor3dInput 27 | max_pool = nn.MaxPool3d 28 | bn = nn.BatchNorm3d 29 | elif dimension == 2: 30 | conv_nd = nn.Conv2d 31 | max_pool = nn.MaxPool2d 32 | bn = nn.BatchNorm2d 33 | else: 34 | conv_nd = nn.Conv1d 35 | max_pool = nn.MaxPool1d 36 | bn = nn.BatchNorm1d 37 | 38 | self.g = conv_nd(self.in_channels, self.inter_channels, 39 | kernel_size=1, 
stride=1, padding=0, bias=True) 40 | self.theta = conv_nd(self.in_channels, self.inter_channels, 41 | kernel_size=1, stride=1, padding=0, bias=True) 42 | self.phi = conv_nd(self.in_channels, self.inter_channels, 43 | kernel_size=1, stride=1, padding=0, bias=True) 44 | # if sub_sample: 45 | # self.g = nn.Sequential(self.g, max_pool(kernel_size=2)) 46 | # self.phi = nn.Sequential(self.phi, max_pool(kernel_size=2)) 47 | if sub_sample: 48 | if dimension == 3: 49 | self.g = nn.Sequential(self.g, max_pool((1, 2, 2))) 50 | self.phi = nn.Sequential(self.phi, max_pool((1, 2, 2))) 51 | else: 52 | self.g = nn.Sequential(self.g, max_pool(kernel_size=2)) 53 | self.phi = nn.Sequential(self.phi, max_pool(kernel_size=2)) 54 | 55 | if bn_layer: 56 | self.W = nn.Sequential( 57 | conv_nd(self.inter_channels, self.in_channels, 58 | kernel_size=1, stride=1, padding=0, bias=True), 59 | bn(self.in_channels) 60 | ) 61 | else: 62 | self.W = conv_nd(self.inter_channels, self.in_channels, 63 | kernel_size=1, stride=1, padding=0, bias=True) 64 | 65 | # init 66 | for m in self.modules(): 67 | if isinstance(m, conv_nd): 68 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 69 | m.weight.data.normal_(0, math.sqrt(2. / n)) 70 | elif isinstance(m, bn): 71 | m.weight.data.fill_(1) 72 | m.bias.data.zero_() 73 | 74 | if bn_layer: 75 | nn.init.constant_(self.W[1].weight.data, 0.0) 76 | nn.init.constant_(self.W[1].bias.data, 0.0) 77 | else: 78 | nn.init.constant_(self.W.weight.data, 0.0) 79 | nn.init.constant_(self.W.bias.data, 0.0) 80 | 81 | 82 | def forward(self, x): 83 | ''' 84 | :param x: (b, c, t, h, w) 85 | :return: 86 | ''' 87 | batch_size = x.size(0) 88 | 89 | g_x = self.g(x).view(batch_size, self.inter_channels, -1) 90 | g_x = g_x.permute(0, 2, 1) 91 | 92 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1) 93 | theta_x = theta_x.permute(0, 2, 1) 94 | phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) 95 | f = torch.matmul(theta_x, phi_x) 96 | f = F.softmax(f, dim=-1) 97 | 98 | y = torch.matmul(f, g_x) 99 | y = y.permute(0, 2, 1).contiguous() 100 | y = y.view(batch_size, self.inter_channels, *x.size()[2:]) 101 | y = self.W(y) 102 | z = y + x 103 | 104 | return z 105 | 106 | 107 | class NonLocalBlock1D(NonLocalBlockND): 108 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 109 | super(NonLocalBlock1D, self).__init__(in_channels, 110 | inter_channels=inter_channels, 111 | dimension=1, sub_sample=sub_sample, 112 | bn_layer=bn_layer) 113 | 114 | 115 | class NonLocalBlock2D(NonLocalBlockND): 116 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 117 | super(NonLocalBlock2D, self).__init__(in_channels, 118 | inter_channels=inter_channels, 119 | dimension=2, sub_sample=sub_sample, 120 | bn_layer=bn_layer) 121 | 122 | 123 | class NonLocalBlock3D(NonLocalBlockND): 124 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 125 | super(NonLocalBlock3D, self).__init__(in_channels, 126 | inter_channels=inter_channels, 127 | dimension=3, sub_sample=sub_sample, 128 | bn_layer=bn_layer) 129 | -------------------------------------------------------------------------------- /models/utils/pooling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | 6 | class GeMPooling(nn.Module): 7 | def __init__(self, p=3, eps=1e-6): 8 | super().__init__() 9 | self.p = 
nn.Parameter(torch.ones(1) * p) 10 | self.eps = eps 11 | 12 | def forward(self, x): 13 | return F.avg_pool2d(x.clamp(min=self.eps).pow(self.p), x.size()[2:]).pow(1./self.p) 14 | 15 | 16 | class MaxAvgPooling(nn.Module): 17 | def __init__(self): 18 | super().__init__() 19 | self.maxpooling = nn.AdaptiveMaxPool2d(1) 20 | self.avgpooling = nn.AdaptiveAvgPool2d(1) 21 | 22 | # def forward(self, x): 23 | # max_f = self.maxpooling(x) 24 | # avg_f = self.avgpooling(x) 25 | 26 | # return torch.cat((max_f, avg_f), 1) 27 | def forward(self, x): 28 | max_f = self.maxpooling(x) 29 | avg_f = self.avgpooling(x) 30 | 31 | out = torch.zeros_like(torch.cat((max_f, avg_f), 1), dtype=torch.float) 32 | out[:,::2,:,:] = max_f 33 | out[:,1::2,:,:] = avg_f 34 | 35 | return out -------------------------------------------------------------------------------- /models/vid_resnet.py: -------------------------------------------------------------------------------- 1 | import torchvision 2 | import torch.nn as nn 3 | import torch 4 | from torch.nn import init 5 | from torch.nn import functional as F 6 | from models.utils import inflate 7 | from models.utils import c3d_blocks 8 | from models.utils import nonlocal_blocks 9 | 10 | 11 | __all__ = ['AP3DResNet50', 'AP3DNLResNet50', 'NLResNet50', 'C2DResNet50', 12 | 'I3DResNet50', 13 | ] 14 | 15 | class GEN(nn.Module): 16 | def __init__(self, in_feat_dim, out_img_dim, config, **kwargs): 17 | super().__init__() 18 | 19 | self.in_feat_dim = in_feat_dim 20 | self.out_img_dim = out_img_dim 21 | 22 | self.conv0 = nn.Conv3d(self.in_feat_dim, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=True) 23 | self.conv1 = nn.Conv3d(64, 64, kernel_size=(3, 5, 5), stride=(1, 1, 1), padding=(1, 2, 2), bias=True) 24 | self.conv2 = nn.Conv3d(64, 32, kernel_size=(3, 5, 5), stride=(1, 1, 1), padding=(1, 2, 2), bias=True) 25 | self.conv3 = nn.Conv3d(32, 32, kernel_size=(3, 5, 5), stride=(1, 1, 1), padding=(1, 2, 2), bias=True) 26 | self.conv4 = nn.Conv3d(32, self.out_img_dim, kernel_size=(3, 5, 5), stride=(1, 1, 1), padding=(1, 2, 2), bias=True) 27 | 28 | self.up = nn.Upsample(scale_factor=(1, 2, 2)) 29 | 30 | self.bn = nn.BatchNorm3d(64) 31 | init.normal_(self.bn.weight.data, 1.0, 0.02) 32 | init.constant_(self.bn.bias.data, 0.0) 33 | 34 | self.relu = nn.ReLU() 35 | 36 | def forward(self, x): 37 | x = self.conv0(x) 38 | x = self.bn(x) 39 | x = self.relu(x) 40 | 41 | x = self.up(x) 42 | x = self.conv1(x) 43 | x = self.relu(x) 44 | 45 | x = self.up(x) 46 | x = self.conv2(x) 47 | x = self.relu(x) 48 | 49 | x = self.up(x) 50 | x = self.conv3(x) 51 | x = self.relu(x) 52 | 53 | x = self.up(x) 54 | x = self.conv4(x) 55 | x = torch.tanh(x) 56 | 57 | return x 58 | 59 | 60 | class Bottleneck3D(nn.Module): 61 | def __init__(self, bottleneck2d, block, inflate_time=False, temperature=4, contrastive_att=True): 62 | super().__init__() 63 | self.conv1 = inflate.inflate_conv(bottleneck2d.conv1, time_dim=1) 64 | self.bn1 = inflate.inflate_batch_norm(bottleneck2d.bn1) 65 | if inflate_time == True: 66 | self.conv2 = block(bottleneck2d.conv2, temperature=temperature, contrastive_att=contrastive_att) 67 | else: 68 | self.conv2 = inflate.inflate_conv(bottleneck2d.conv2, time_dim=1) 69 | self.bn2 = inflate.inflate_batch_norm(bottleneck2d.bn2) 70 | self.conv3 = inflate.inflate_conv(bottleneck2d.conv3, time_dim=1) 71 | self.bn3 = inflate.inflate_batch_norm(bottleneck2d.bn3) 72 | self.relu = nn.ReLU(inplace=True) 73 | 74 | if bottleneck2d.downsample is not None: 75 | self.downsample = 
self._inflate_downsample(bottleneck2d.downsample) 76 | else: 77 | self.downsample = None 78 | 79 | def _inflate_downsample(self, downsample2d, time_stride=1): 80 | downsample3d = nn.Sequential( 81 | inflate.inflate_conv(downsample2d[0], time_dim=1, 82 | time_stride=time_stride), 83 | inflate.inflate_batch_norm(downsample2d[1])) 84 | return downsample3d 85 | 86 | def forward(self, x): 87 | residual = x 88 | out = self.conv1(x) 89 | out = self.bn1(out) 90 | out = self.relu(out) 91 | 92 | out = self.conv2(out) 93 | out = self.bn2(out) 94 | out = self.relu(out) 95 | 96 | out = self.conv3(out) 97 | out = self.bn3(out) 98 | 99 | if self.downsample is not None: 100 | residual = self.downsample(x) 101 | 102 | out += residual 103 | out = self.relu(out) 104 | 105 | return out 106 | 107 | 108 | class ResNet503D(nn.Module): 109 | def __init__(self, config, block, c3d_idx, nl_idx, **kwargs): 110 | super().__init__() 111 | self.block = block 112 | self.temperature = config.MODEL.AP3D.TEMPERATURE 113 | self.contrastive_att = config.MODEL.AP3D.CONTRACTIVE_ATT 114 | 115 | resnet2d = torchvision.models.resnet50(pretrained=True) 116 | if config.MODEL.RES4_STRIDE == 1: 117 | resnet2d.layer4[0].conv2.stride=(1, 1) 118 | resnet2d.layer4[0].downsample[0].stride=(1, 1) 119 | 120 | self.conv1 = inflate.inflate_conv(resnet2d.conv1, time_dim=1) 121 | self.bn1 = inflate.inflate_batch_norm(resnet2d.bn1) 122 | self.relu = nn.ReLU(inplace=True) 123 | self.maxpool = inflate.inflate_pool(resnet2d.maxpool, time_dim=1) 124 | # self.maxpool = inflate.MaxPool2dFor3dInput(kernel_size=resnet2d.maxpool.kernel_size, 125 | # stride=resnet2d.maxpool.stride, 126 | # padding=resnet2d.maxpool.padding, 127 | # dilation=resnet2d.maxpool.dilation) 128 | 129 | self.layer1 = self._inflate_reslayer(resnet2d.layer1, c3d_idx=c3d_idx[0], \ 130 | nonlocal_idx=nl_idx[0], nonlocal_channels=256) 131 | self.layer2 = self._inflate_reslayer(resnet2d.layer2, c3d_idx=c3d_idx[1], \ 132 | nonlocal_idx=nl_idx[1], nonlocal_channels=512) 133 | self.layer3 = self._inflate_reslayer(resnet2d.layer3, c3d_idx=c3d_idx[2], \ 134 | nonlocal_idx=nl_idx[2], nonlocal_channels=1024) 135 | self.layer4 = self._inflate_reslayer(resnet2d.layer4, c3d_idx=c3d_idx[3], \ 136 | nonlocal_idx=nl_idx[3], nonlocal_channels=2048) 137 | 138 | self.bn = nn.BatchNorm1d(2048) 139 | init.normal_(self.bn.weight.data, 1.0, 0.02) 140 | init.constant_(self.bn.bias.data, 0.0) 141 | 142 | self.uncloth_dim = config.MODEL.NO_CLOTHES_DIM//2 143 | self.contour_dim = config.MODEL.CONTOUR_DIM//2 144 | self.cloth_dim = config.MODEL.CLOTHES_DIM//2 145 | 146 | self.uncloth_net = GEN(in_feat_dim = self.uncloth_dim, out_img_dim=1, config = config) 147 | self.contour_net = GEN(in_feat_dim = self.contour_dim + self.cloth_dim, out_img_dim=1, config = config) 148 | self.cloth_net = GEN(in_feat_dim = self.cloth_dim, out_img_dim=1, config = config) 149 | 150 | 151 | def _inflate_reslayer(self, reslayer2d, c3d_idx, nonlocal_idx=[], nonlocal_channels=0): 152 | reslayers3d = [] 153 | for i,layer2d in enumerate(reslayer2d): 154 | if i not in c3d_idx: 155 | layer3d = Bottleneck3D(layer2d, c3d_blocks.C2D, inflate_time=False) 156 | else: 157 | layer3d = Bottleneck3D(layer2d, self.block, inflate_time=True, \ 158 | temperature=self.temperature, contrastive_att=self.contrastive_att) 159 | reslayers3d.append(layer3d) 160 | 161 | if i in nonlocal_idx: 162 | non_local_block = nonlocal_blocks.NonLocalBlock3D(nonlocal_channels, sub_sample=True) 163 | reslayers3d.append(non_local_block) 164 | 165 | return 
nn.Sequential(*reslayers3d) 166 | 167 | def forward(self, x): 168 | x = self.conv1(x) 169 | x = self.bn1(x) 170 | x = self.relu(x) 171 | x = self.maxpool(x) 172 | 173 | x = self.layer1(x) 174 | x = self.layer2(x) 175 | x = self.layer3(x) 176 | x = self.layer4(x) 177 | 178 | x_ori = x 179 | b, c, t, h, w = x.size() 180 | x = x.permute(0, 2, 1, 3, 4).contiguous() 181 | x = x.view(b*t, c, h, w) 182 | # spatial max pooling 183 | x = F.max_pool2d(x, x.size()[2:]) 184 | x = x.view(b, t, -1) 185 | # temporal avg pooling 186 | x = x.mean(1) 187 | f = self.bn(x) 188 | 189 | f_unclo = x_ori[:, 0:self.uncloth_dim, :, :, :] 190 | f_cont = x_ori[:, self.uncloth_dim:self.uncloth_dim+self.contour_dim+self.cloth_dim, :, :, :] 191 | f_clo = x_ori[:, self.uncloth_dim+self.contour_dim:self.uncloth_dim+self.contour_dim+self.cloth_dim, :, :, :] 192 | 193 | unclo_img = self.uncloth_net(f_unclo) 194 | cont_img = self.contour_net(f_cont) 195 | clo_img = self.cloth_net(f_clo) 196 | 197 | return (f, unclo_img, cont_img, clo_img) 198 | 199 | 200 | def C2DResNet50(config, **kwargs): 201 | c3d_idx = [[],[],[],[]] 202 | nl_idx = [[],[],[],[]] 203 | 204 | return ResNet503D(config, c3d_blocks.APP3DC, c3d_idx, nl_idx, **kwargs) 205 | 206 | 207 | def AP3DResNet50(config, **kwargs): 208 | c3d_idx = [[],[0, 2],[0, 2, 4],[]] 209 | nl_idx = [[],[],[],[]] 210 | 211 | return ResNet503D(config, c3d_blocks.APP3DC, c3d_idx, nl_idx, **kwargs) 212 | 213 | 214 | def I3DResNet50(config, **kwargs): 215 | c3d_idx = [[],[0, 2],[0, 2, 4],[]] 216 | nl_idx = [[],[],[],[]] 217 | 218 | return ResNet503D(config, c3d_blocks.I3D, c3d_idx, nl_idx, **kwargs) 219 | 220 | 221 | def AP3DNLResNet50(config, **kwargs): 222 | c3d_idx = [[],[0, 2],[0, 2, 4],[]] 223 | nl_idx = [[],[1, 3],[1, 3, 5],[]] 224 | 225 | return ResNet503D(config, c3d_blocks.APP3DC, c3d_idx, nl_idx, **kwargs) 226 | 227 | 228 | def NLResNet50(config, **kwargs): 229 | c3d_idx = [[],[],[],[]] 230 | nl_idx = [[],[1, 3],[1, 3, 5],[]] 231 | 232 | return ResNet503D(config, c3d_blocks.APP3DC, c3d_idx, nl_idx, **kwargs) 233 | -------------------------------------------------------------------------------- /tools/__pycache__/eval_metrics.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ICST-MIPL/DCR-ReID_TCSVT2023/b6c58c48a97feb1c2b3f062b211c50e17bac274e/tools/__pycache__/eval_metrics.cpython-36.pyc -------------------------------------------------------------------------------- /tools/__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ICST-MIPL/DCR-ReID_TCSVT2023/b6c58c48a97feb1c2b3f062b211c50e17bac274e/tools/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /tools/eval_metrics.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import numpy as np 3 | 4 | 5 | def compute_ap_cmc(index, good_index, junk_index): 6 | """ Compute AP and CMC for each sample 7 | """ 8 | ap = 0 9 | cmc = np.zeros(len(index)) 10 | 11 | # remove junk_index 12 | mask = np.in1d(index, junk_index, invert=True) 13 | index = index[mask] 14 | 15 | # find good_index index 16 | ngood = len(good_index) 17 | mask = np.in1d(index, good_index) 18 | rows_good = np.argwhere(mask==True) 19 | rows_good = rows_good.flatten() 20 | 21 | cmc[rows_good[0]:] = 1.0 22 | for i in range(ngood): 23 | d_recall = 1.0/ngood 24 | precision = 
(i+1)*1.0/(rows_good[i]+1) 25 | ap = ap + d_recall*precision 26 | 27 | return ap, cmc 28 | 29 | 30 | def evaluate(distmat, q_pids, g_pids, q_camids, g_camids): 31 | """ Compute CMC and mAP 32 | 33 | Args: 34 | distmat (numpy ndarray): distance matrix with shape (num_query, num_gallery). 35 | q_pids (numpy array): person IDs for query samples. 36 | g_pids (numpy array): person IDs for gallery samples. 37 | q_camids (numpy array): camera IDs for query samples. 38 | g_camids (numpy array): camera IDs for gallery samples. 39 | """ 40 | num_q, num_g = distmat.shape 41 | index = np.argsort(distmat, axis=1) # from small to large 42 | 43 | num_no_gt = 0 # num of query imgs without groundtruth 44 | num_r1 = 0 45 | CMC = np.zeros(len(g_pids)) 46 | AP = 0 47 | 48 | for i in range(num_q): 49 | # groundtruth index 50 | query_index = np.argwhere(g_pids==q_pids[i]) 51 | camera_index = np.argwhere(g_camids==q_camids[i]) 52 | good_index = np.setdiff1d(query_index, camera_index, assume_unique=True) 53 | if good_index.size == 0: 54 | num_no_gt += 1 55 | continue 56 | # remove gallery samples that have the same pid and camid with query 57 | junk_index = np.intersect1d(query_index, camera_index) 58 | 59 | ap_tmp, CMC_tmp = compute_ap_cmc(index[i], good_index, junk_index) 60 | if CMC_tmp[0]==1: 61 | num_r1 += 1 62 | CMC = CMC + CMC_tmp 63 | AP += ap_tmp 64 | 65 | if num_no_gt > 0: 66 | logger = logging.getLogger('reid.evaluate') 67 | logger.info("{} query samples do not have groundtruth.".format(num_no_gt)) 68 | 69 | CMC = CMC / (num_q - num_no_gt) 70 | mAP = AP / (num_q - num_no_gt) 71 | 72 | return CMC, mAP 73 | 74 | 75 | def evaluate_with_clothes(distmat, q_pids, g_pids, q_camids, g_camids, q_clothids, g_clothids, mode='CC'): 76 | """ Compute CMC and mAP with clothes 77 | 78 | Args: 79 | distmat (numpy ndarray): distance matrix with shape (num_query, num_gallery). 80 | q_pids (numpy array): person IDs for query samples. 81 | g_pids (numpy array): person IDs for gallery samples. 82 | q_camids (numpy array): camera IDs for query samples. 83 | g_camids (numpy array): camera IDs for gallery samples. 84 | q_clothids (numpy array): clothes IDs for query samples. 85 | g_clothids (numpy array): clothes IDs for gallery samples. 86 | mode: 'CC' for clothes-changing; 'SC' for the same clothes. 
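              In 'CC' mode, gallery samples that share both the identity and the clothes ID of the query are treated as junk, so only clothes-changing matches count as ground truth; in 'SC' mode, only same-identity samples with the same clothes ID count as ground truth, and same-identity samples with different clothes are treated as junk.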
87 | """ 88 | assert mode in ['CC', 'SC'] 89 | 90 | num_q, num_g = distmat.shape 91 | index = np.argsort(distmat, axis=1) # from small to large 92 | 93 | num_no_gt = 0 # num of query imgs without groundtruth 94 | num_r1 = 0 95 | CMC = np.zeros(len(g_pids)) 96 | AP = 0 97 | 98 | for i in range(num_q): 99 | # groundtruth index 100 | query_index = np.argwhere(g_pids==q_pids[i]) 101 | camera_index = np.argwhere(g_camids==q_camids[i]) 102 | cloth_index = np.argwhere(g_clothids==q_clothids[i]) 103 | good_index = np.setdiff1d(query_index, camera_index, assume_unique=True) 104 | if mode == 'CC': 105 | good_index = np.setdiff1d(good_index, cloth_index, assume_unique=True) 106 | # remove gallery samples that have the same (pid, camid) or (pid, clothid) with query 107 | junk_index1 = np.intersect1d(query_index, camera_index) 108 | junk_index2 = np.intersect1d(query_index, cloth_index) 109 | junk_index = np.union1d(junk_index1, junk_index2) 110 | else: 111 | good_index = np.intersect1d(good_index, cloth_index) 112 | # remove gallery samples that have the same (pid, camid) or 113 | # (the same pid and different clothid) with query 114 | junk_index1 = np.intersect1d(query_index, camera_index) 115 | junk_index2 = np.setdiff1d(query_index, cloth_index) 116 | junk_index = np.union1d(junk_index1, junk_index2) 117 | 118 | if good_index.size == 0: 119 | num_no_gt += 1 120 | continue 121 | 122 | ap_tmp, CMC_tmp = compute_ap_cmc(index[i], good_index, junk_index) 123 | if CMC_tmp[0]==1: 124 | num_r1 += 1 125 | CMC = CMC + CMC_tmp 126 | AP += ap_tmp 127 | 128 | if num_no_gt > 0: 129 | logger = logging.getLogger('reid.evaluate') 130 | logger.info("{} query samples do not have groundtruth.".format(num_no_gt)) 131 | 132 | if (num_q - num_no_gt) != 0: 133 | CMC = CMC / (num_q - num_no_gt) 134 | mAP = AP / (num_q - num_no_gt) 135 | else: 136 | mAP = 0 137 | 138 | return CMC, mAP -------------------------------------------------------------------------------- /tools/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import shutil 4 | import errno 5 | import json 6 | import os.path as osp 7 | import torch 8 | import random 9 | import logging 10 | import numpy as np 11 | 12 | 13 | def set_seed(seed=None): 14 | if seed is None: 15 | return 16 | random.seed(seed) 17 | os.environ['PYTHONHASHSEED'] = ("%s" % seed) 18 | np.random.seed(seed) 19 | torch.manual_seed(seed) 20 | torch.cuda.manual_seed(seed) 21 | torch.cuda.manual_seed_all(seed) 22 | torch.backends.cudnn.benchmark = False 23 | torch.backends.cudnn.deterministic = True 24 | 25 | 26 | def mkdir_if_missing(directory): 27 | if not osp.exists(directory): 28 | try: 29 | os.makedirs(directory) 30 | except OSError as e: 31 | if e.errno != errno.EEXIST: 32 | raise 33 | 34 | 35 | def read_json(fpath): 36 | with open(fpath, 'r') as f: 37 | obj = json.load(f) 38 | return obj 39 | 40 | 41 | def write_json(obj, fpath): 42 | mkdir_if_missing(osp.dirname(fpath)) 43 | with open(fpath, 'w') as f: 44 | json.dump(obj, f, indent=4, separators=(',', ': ')) 45 | 46 | 47 | class AverageMeter(object): 48 | """Computes and stores the average and current value. 
49 | 50 | Code imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262 51 | """ 52 | def __init__(self): 53 | self.reset() 54 | 55 | def reset(self): 56 | self.val = 0 57 | self.avg = 0 58 | self.sum = 0 59 | self.count = 0 60 | 61 | def update(self, val, n=1): 62 | self.val = val 63 | self.sum += val * n 64 | self.count += n 65 | self.avg = self.sum / self.count 66 | 67 | 68 | def save_checkpoint(state, is_best, fpath='checkpoint.pth.tar'): 69 | mkdir_if_missing(osp.dirname(fpath)) 70 | torch.save(state, fpath) 71 | if is_best: 72 | shutil.copy(fpath, osp.join(osp.dirname(fpath), 'best_model.pth.tar')) 73 | 74 | ''' 75 | class Logger(object): 76 | """ 77 | Write console output to external text file. 78 | Code imported from https://github.com/Cysu/open-reid/blob/master/reid/utils/logging.py. 79 | """ 80 | def __init__(self, fpath=None): 81 | self.console = sys.stdout 82 | self.file = None 83 | if fpath is not None: 84 | mkdir_if_missing(os.path.dirname(fpath)) 85 | self.file = open(fpath, 'w') 86 | 87 | def __del__(self): 88 | self.close() 89 | 90 | def __enter__(self): 91 | pass 92 | 93 | def __exit__(self, *args): 94 | self.close() 95 | 96 | def write(self, msg): 97 | self.console.write(msg) 98 | if self.file is not None: 99 | self.file.write(msg) 100 | 101 | def flush(self): 102 | self.console.flush() 103 | if self.file is not None: 104 | self.file.flush() 105 | os.fsync(self.file.fileno()) 106 | 107 | def close(self): 108 | self.console.close() 109 | if self.file is not None: 110 | self.file.close() 111 | ''' 112 | 113 | 114 | def get_logger(fpath, local_rank=0, name=''): 115 | # Creat logger 116 | logger = logging.getLogger(name) 117 | level = logging.INFO if local_rank in [-1, 0] else logging.WARN 118 | logger.setLevel(level=level) 119 | 120 | # Output to console 121 | console_handler = logging.StreamHandler(sys.stdout) 122 | console_handler.setLevel(level=level) 123 | console_handler.setFormatter(logging.Formatter('%(message)s')) 124 | logger.addHandler(console_handler) 125 | 126 | # Output to file 127 | if fpath is not None: 128 | mkdir_if_missing(os.path.dirname(fpath)) 129 | file_handler = logging.FileHandler(fpath, mode='w') 130 | file_handler.setLevel(level=level) 131 | file_handler.setFormatter(logging.Formatter('%(message)s')) 132 | logger.addHandler(file_handler) 133 | 134 | return logger -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import time 2 | import datetime 3 | import logging 4 | import torch 5 | import numpy as np 6 | from apex import amp 7 | from tools.utils import AverageMeter 8 | 9 | 10 | def train_cal(config, epoch, model, classifier, clothes_classifier, criterion_cla, criterion_pair, 11 | criterion_clothes, criterion_adv, criterion_shuffle, recon_uncloth, recon_contour, recon_cloth, optimizer, optimizer_cc, trainloader, pid2clothes): 12 | logger = logging.getLogger('reid.train') 13 | batch_cla_loss = AverageMeter() 14 | batch_pair_loss = AverageMeter() 15 | batch_clo_loss = AverageMeter() 16 | batch_adv_loss = AverageMeter() 17 | batch_rec_loss = AverageMeter() 18 | 19 | batch_clo_rec_loss = AverageMeter() 20 | batch_unclo_rec_loss = AverageMeter() 21 | batch_cont_rec_loss = AverageMeter() 22 | 23 | corrects = AverageMeter() 24 | clothes_corrects = AverageMeter() 25 | batch_time = AverageMeter() 26 | data_time = AverageMeter() 27 | 28 | model.train() 29 | classifier.train() 30 | 
clothes_classifier.train() 31 | 32 | end = time.time() 33 | for batch_idx, (imgs, pids, camids, clothes_ids, cloth, _cloth, contour) in enumerate(trainloader): 34 | 35 | pos_mask = pid2clothes[pids] 36 | imgs, pids, clothes_ids, pos_mask = imgs.cuda(), pids.cuda(), clothes_ids.cuda(), pos_mask.float().cuda() 37 | cloth, _cloth, contour = cloth.cuda(), _cloth.cuda(), contour.cuda() 38 | # Measure data loading time 39 | data_time.update(time.time() - end) 40 | # Forward 41 | features, unclo_img, cont_img, clo_img = model(imgs) 42 | features_shuffle = features.clone() 43 | ori_lst = np.arange(0, int(config.DATA.TRAIN_BATCH/config.DATA.NUM_INSTANCES)) 44 | rdn_lst = np.arange(0, int(config.DATA.TRAIN_BATCH/config.DATA.NUM_INSTANCES)) 45 | np.random.shuffle(rdn_lst) 46 | while np.sum(ori_lst == rdn_lst) > 0: 47 | np.random.shuffle(rdn_lst) 48 | ori_lst = np.arange(0, int(config.DATA.TRAIN_BATCH)) 49 | rdn_lst = np.transpose(np.array([rdn_lst*4+i for i in range(0, config.DATA.NUM_INSTANCES)])).reshape(-1) 50 | features_shuffle[ori_lst, config.MODEL.FEATURE_DIM-config.MODEL.CLOTHES_DIM:config.MODEL.FEATURE_DIM] = features_shuffle[rdn_lst, config.MODEL.FEATURE_DIM-config.MODEL.CLOTHES_DIM:config.MODEL.FEATURE_DIM] 51 | # Classification 52 | outputs = classifier(features) 53 | outputs_shuffle = classifier(features_shuffle) 54 | pred_clothes = clothes_classifier(features.detach()) 55 | _, preds = torch.max(outputs.data, 1) 56 | # Update the clothes discriminator 57 | clothes_loss = criterion_clothes(pred_clothes, clothes_ids) 58 | if epoch >= config.TRAIN.START_EPOCH_CC: 59 | optimizer_cc.zero_grad() 60 | if config.TRAIN.AMP: 61 | with amp.scale_loss(clothes_loss, optimizer_cc) as scaled_loss: 62 | scaled_loss.backward() 63 | else: 64 | clothes_loss.backward() 65 | optimizer_cc.step() 66 | # Update the backbone 67 | new_pred_clothes = clothes_classifier(features) 68 | new_pred_clothes_shuffle = clothes_classifier(features_shuffle) 69 | _, clothes_preds = torch.max(new_pred_clothes.data, 1) 70 | # Compute loss 71 | cla_loss = criterion_cla(outputs, pids) 72 | cla_loss_shuffle = criterion_cla(outputs_shuffle, pids) 73 | pair_loss = criterion_pair(features, pids) 74 | pair_loss_shuffle = criterion_pair(features_shuffle, pids) 75 | adv_loss = criterion_adv(new_pred_clothes, clothes_ids, pos_mask) 76 | adv_loss_shuffle = criterion_shuffle(new_pred_clothes_shuffle, clothes_ids, pos_mask) 77 | 78 | cla_loss = (cla_loss + config.PARA.SHUF_PID_RATIO * cla_loss_shuffle) 79 | pair_loss= (pair_loss + pair_loss_shuffle) 80 | adv_loss = (adv_loss + config.PARA.SHUF_ADV_CLO_RATIO * adv_loss_shuffle) 81 | 82 | unclo_loss = recon_uncloth(unclo_img, _cloth) 83 | cont_loss = recon_contour(cont_img, contour) 84 | clo_loss = recon_cloth(clo_img, cloth) 85 | 86 | rec_loss = (unclo_loss + cont_loss + clo_loss) 87 | 88 | if epoch >= config.TRAIN.START_EPOCH_ADV: 89 | loss = cla_loss + adv_loss + config.LOSS.PAIR_LOSS_WEIGHT * pair_loss + config.PARA.RECON_RATIO * rec_loss 90 | else: 91 | loss = cla_loss + config.LOSS.PAIR_LOSS_WEIGHT * pair_loss + config.PARA.RECON_RATIO * rec_loss 92 | optimizer.zero_grad() 93 | if config.TRAIN.AMP: 94 | with amp.scale_loss(loss, optimizer) as scaled_loss: 95 | scaled_loss.backward() 96 | else: 97 | loss.backward() 98 | optimizer.step() 99 | 100 | # statistics 101 | corrects.update(torch.sum(preds == pids.data).float()/pids.size(0), pids.size(0)) 102 | clothes_corrects.update(torch.sum(clothes_preds == clothes_ids.data).float()/clothes_ids.size(0), clothes_ids.size(0)) 103 | 
batch_cla_loss.update(cla_loss.item(), pids.size(0)) 104 | batch_pair_loss.update(pair_loss.item(), pids.size(0)) 105 | batch_clo_loss.update(clothes_loss.item(), clothes_ids.size(0)) 106 | batch_adv_loss.update(adv_loss.item(), clothes_ids.size(0)) 107 | batch_rec_loss.update(rec_loss.item(), pids.size(0)) 108 | 109 | batch_unclo_rec_loss.update(unclo_loss.item(), pids.size(0)) 110 | batch_clo_rec_loss.update(clo_loss.item(), pids.size(0)) 111 | batch_cont_rec_loss.update(cont_loss.item(), pids.size(0)) 112 | 113 | # measure elapsed time 114 | batch_time.update(time.time() - end) 115 | end = time.time() 116 | 117 | logger.info('Epoch{0} ' 118 | 'Time:{batch_time.sum:.1f}s ' 119 | 'Data:{data_time.sum:.1f}s ' 120 | 'ClaLoss:{cla_loss.avg:.4f} ' 121 | 'PairLoss:{pair_loss.avg:.4f} ' 122 | 'CloLoss:{clo_loss.avg:.4f} ' 123 | 'AdvLoss:{adv_loss.avg:.4f} ' 124 | 'RecLoss:{rec_loss.avg:.4f} ' 125 | 'Acc:{acc.avg:.2%} ' 126 | 'CloAcc:{clo_acc.avg:.2%} ' 127 | 'clo_loss:{clo__loss.avg:.4f} ' 128 | 'unclo_loss:{unclo_loss.avg:.4f} ' 129 | 'cont_loss:{cont_loss.avg:.4f} '.format( 130 | epoch+1, batch_time=batch_time, data_time=data_time, 131 | cla_loss=batch_cla_loss, pair_loss=batch_pair_loss, 132 | clo_loss=batch_clo_loss, adv_loss=batch_adv_loss, rec_loss=batch_rec_loss, 133 | acc=corrects, clo_acc=clothes_corrects, 134 | clo__loss=batch_clo_rec_loss, unclo_loss=batch_unclo_rec_loss,cont_loss=batch_cont_rec_loss)) 135 | 136 | 137 | def train_cal_with_memory(config, epoch, model, classifier, criterion_cla, criterion_pair, 138 | criterion_adv, optimizer, trainloader, pid2clothes): 139 | logger = logging.getLogger('reid.train') 140 | batch_cla_loss = AverageMeter() 141 | batch_pair_loss = AverageMeter() 142 | batch_adv_loss = AverageMeter() 143 | corrects = AverageMeter() 144 | batch_time = AverageMeter() 145 | data_time = AverageMeter() 146 | 147 | model.train() 148 | classifier.train() 149 | 150 | end = time.time() 151 | for batch_idx, (imgs, pids, camids, clothes_ids) in enumerate(trainloader): 152 | # Get all positive clothes classes (belonging to the same identity) for each sample 153 | pos_mask = pid2clothes[pids] 154 | imgs, pids, clothes_ids, pos_mask = imgs.cuda(), pids.cuda(), clothes_ids.cuda(), pos_mask.float().cuda() 155 | # Measure data loading time 156 | data_time.update(time.time() - end) 157 | # Forward 158 | features = model(imgs) 159 | outputs = classifier(features) 160 | _, preds = torch.max(outputs.data, 1) 161 | 162 | # Compute loss 163 | cla_loss = criterion_cla(outputs, pids) 164 | pair_loss = criterion_pair(features, pids) 165 | 166 | if epoch >= config.TRAIN.START_EPOCH_ADV: 167 | adv_loss = criterion_adv(features, clothes_ids, pos_mask) 168 | loss = cla_loss + adv_loss + config.LOSS.PAIR_LOSS_WEIGHT * pair_loss 169 | else: 170 | loss = cla_loss + config.LOSS.PAIR_LOSS_WEIGHT * pair_loss 171 | 172 | optimizer.zero_grad() 173 | if config.TRAIN.AMP: 174 | with amp.scale_loss(loss, optimizer) as scaled_loss: 175 | scaled_loss.backward() 176 | else: 177 | loss.backward() 178 | optimizer.step() 179 | 180 | # statistics 181 | corrects.update(torch.sum(preds == pids.data).float()/pids.size(0), pids.size(0)) 182 | batch_cla_loss.update(cla_loss.item(), pids.size(0)) 183 | batch_pair_loss.update(pair_loss.item(), pids.size(0)) 184 | if epoch >= config.TRAIN.START_EPOCH_ADV: 185 | batch_adv_loss.update(adv_loss.item(), clothes_ids.size(0)) 186 | # measure elapsed time 187 | batch_time.update(time.time() - end) 188 | end = time.time() 189 | 190 | logger.info('Epoch{0} ' 191 | 
'Time:{batch_time.sum:.1f}s ' 192 | 'Data:{data_time.sum:.1f}s ' 193 | 'ClaLoss:{cla_loss.avg:.4f} ' 194 | 'PairLoss:{pair_loss.avg:.4f} ' 195 | 'AdvLoss:{adv_loss.avg:.4f} ' 196 | 'Acc:{acc.avg:.2%} '.format( 197 | epoch+1, batch_time=batch_time, data_time=data_time, 198 | cla_loss=batch_cla_loss, pair_loss=batch_pair_loss, 199 | adv_loss=batch_adv_loss, acc=corrects)) --------------------------------------------------------------------------------
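
The clothes-feature shuffle in `train_cal` (the `features_shuffle` block near the top of the training loop in train.py) is one of the less obvious steps: the last `MODEL.CLOTHES_DIM` channels of every feature are replaced with the corresponding channels of a sample from a different identity. The following is a minimal, self-contained sketch of that step, not repository code; the function name `shuffle_clothes_channels`, the tensor sizes in the usage lines, and the generalization of the hard-coded factor 4 in train.py to an arbitrary `num_instances` are illustrative assumptions.

import numpy as np
import torch


def shuffle_clothes_channels(features, num_instances, clothes_dim):
    """Swap the last `clothes_dim` channels of each feature with those of another identity.

    Assumes the batch holds `num_instances` consecutive samples per identity,
    which is how the training sampler lays out a batch for train_cal().
    """
    batch_size, feat_dim = features.shape
    num_ids = batch_size // num_instances

    # Draw a derangement of the identity groups so that no identity keeps its
    # own clothes channels (same retry loop as in train_cal).
    perm = np.arange(num_ids)
    np.random.shuffle(perm)
    while np.sum(perm == np.arange(num_ids)) > 0:
        np.random.shuffle(perm)

    # Expand the group-level permutation to per-sample source indices:
    # sample i of group g takes its clothes channels from sample i of group perm[g].
    src = np.transpose(np.array([perm * num_instances + i
                                 for i in range(num_instances)])).reshape(-1)
    src = torch.as_tensor(src, dtype=torch.long)

    shuffled = features.clone()
    shuffled[:, feat_dim - clothes_dim:] = features[src, feat_dim - clothes_dim:]
    return shuffled


# Illustrative usage: 8 samples = 4 identities x 2 instances, 16-dim features,
# of which the last 4 channels play the role of MODEL.CLOTHES_DIM.
feats = torch.randn(8, 16)
feats_shuffled = shuffle_clothes_channels(feats, num_instances=2, clothes_dim=4)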