├── LICENSE ├── README.md ├── config ├── __init__.py └── defaults.py ├── configs ├── MARS │ ├── pit-test.yml │ └── pit.yml └── iLIDS-VID │ ├── pit-test.yml │ └── pit.yml ├── datasets ├── __init__.py ├── bases.py ├── ilids.py ├── make_dataloader.py ├── mars.py ├── misc.py ├── preprocessing.py ├── sampler.py ├── sampler_ddp.py └── temporal_transforms.py ├── framework.jpg ├── loss ├── __init__.py ├── arcface.py ├── center_loss.py ├── make_loss.py ├── metric_learning.py ├── softmax_loss.py └── triplet_loss.py ├── model ├── __init__.py ├── backbones │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── resnet.cpython-36.pyc │ │ └── vit_pytorch.cpython-36.pyc │ ├── layers.py │ ├── resnet.py │ └── vit_pytorch.py └── make_model.py ├── processor ├── __init__.py └── processor.py ├── solver ├── __init__.py ├── cosine_lr.py ├── lr_scheduler.py ├── make_optimizer.py ├── scheduler.py └── scheduler_factory.py ├── thop ├── __init__.py ├── fx_profile.py ├── onnx_profile.py ├── profile.py ├── rnn_hooks.py ├── utils.py └── vision │ ├── __init__.py │ ├── basic_hooks.py │ ├── counter.py │ ├── efficientnet.py │ └── onnx_counter.py ├── train.py ├── utils ├── __init__.py ├── iotools.py ├── logger.py ├── meter.py ├── metrics.py ├── reranking.py └── saver.py └── vis.py /README.md: -------------------------------------------------------------------------------- 1 | # Multi-direction and Multi-scale Pyramid in Transformer for Video-based Pedestrian Retrieval 2 | ![LICENSE](https://img.shields.io/badge/license-GPL%202.0-green) ![Python](https://img.shields.io/badge/python-3.6-blue.svg) ![pytorch](https://img.shields.io/badge/pytorch-1.8.1-orange) 3 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/multi-direction-and-multi-scale-pyramid-in-1/person-re-identification-on-ilids-vid)](https://paperswithcode.com/sota/person-re-identification-on-ilids-vid?p=multi-direction-and-multi-scale-pyramid-in-1) [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/multi-direction-and-multi-scale-pyramid-in-1/person-re-identification-on-mars)](https://paperswithcode.com/sota/person-re-identification-on-mars?p=multi-direction-and-multi-scale-pyramid-in-1) 4 | 5 | Implementation of the proposed PiT. For the preprint version, please refer to [[Arxiv]](https://arxiv.org/pdf/2202.06014.pdf). 6 | 7 | ![framework](./framework.jpg) 8 | 9 | 10 | ## Getting Started 11 | ### Requirements 12 | Here is a brief instruction for installing the experimental environment. 13 | ``` 14 | # install virtual envs 15 | $ conda create -n PiT python=3.6 -y 16 | $ conda activate PiT 17 | # install pytorch 1.8.1/1.6.0 (other versions may also work) 18 | $ pip install timm scipy einops yacs opencv-python tensorboard pandas 19 | ``` 20 | 21 | ### Download pre-trained model 22 | The pre-trained vit model can be downloaded in this [link](https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth) and should be put in the `/home/[USER]/.cache/torch/checkpoints/` directory. 23 | 24 | ### Dataset Preparation 25 | For iLIDS-VID, please refer to this [issue](https://github.com/deropty/PiT/issues/2). 26 | 27 | ## Training and Testing 28 | ``` 29 | # This command below includes the training and testing processes. 
30 | $ python train.py --config_file configs/MARS/pit.yml MODEL.DEVICE_ID "('0')" 31 | # For testing only 32 | $ python train.py --config_file configs/MARS/pit-test.yml MODEL.DEVICE_ID "('0')" 33 | ``` 34 | 35 | 36 | ## Results in the Paper 37 | The models for MARS and iLIDS-VID were trained on a single 24 GB NVIDIA GPU; the resulting accuracies are provided below. You can reduce the parameter `DATALOADER.P` in the yml file to decrease the GPU memory cost. 38 | 39 | | Model | Rank-1@MARS | Rank-1@iLIDS-VID | Rank-1@MARS/iLIDS-VID | 40 | | --- | --- | --- | --- | 41 | | PiT | [90.22](https://pan.baidu.com/s/1nw5yofEilW0ffG_ZF4eoXQ) (code: wqxv)| [92.07](https://pan.baidu.com/s/10LosWwUMktTiWvbHEP1Tjw) (code: quci)| 90.22/92.07([google drive](https://drive.google.com/drive/folders/1P7xiJ05yVBYz9xQSSuEFbs2FWSrsAhua?usp=sharing)) 42 | 43 | You can download these models, put them in the `../logs/[DATASET]_PiT_1x210_3x70_105x2_6p` directory, and then evaluate them with the command below (use the corresponding yml under `configs/iLIDS-VID/` for iLIDS-VID). 44 | ``` 45 | $ python train.py --config_file configs/MARS/pit-test.yml MODEL.DEVICE_ID "('0')" 46 | ``` 47 | 48 | 49 | ## Acknowledgement 50 | 51 | This repository is built upon the repository [TransReID](https://github.com/damo-cv/TransReID). 52 | 53 | ## Citation 54 | If you find this project useful for your research, please kindly cite: 55 | 56 | ``` 57 | @ARTICLE{9714137, 58 | author={Zang, Xianghao and Li, Ge and Gao, Wei}, 59 | journal={IEEE Transactions on Industrial Informatics}, 60 | title={Multidirection and Multiscale Pyramid in Transformer for Video-Based Pedestrian Retrieval}, 61 | year={2022}, 62 | volume={18}, 63 | number={12}, 64 | pages={8776-8785}, 65 | doi={10.1109/TII.2022.3151766} 66 | } 67 | ``` 68 | 69 | ## License 70 | This repository is released under the GPL-2.0 License as found in the [LICENSE](LICENSE) file. 
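## Configuration Notes

The training and testing commands above combine three sources of settings: the defaults in `config/defaults.py`, a yml file passed via `--config_file`, and trailing key/value overrides such as `MODEL.DEVICE_ID "('0')"`. The sketch below shows the usual yacs pattern for consuming these arguments; it is an illustrative assumption, not the exact code of `train.py`.

```python
# Illustrative sketch (assumed yacs usage; the real train.py may differ in details).
import argparse

from config import cfg  # _C from config/defaults.py, exported by config/__init__.py

parser = argparse.ArgumentParser(description="PiT training (configuration sketch)")
parser.add_argument("--config_file", default="", type=str, help="path to a yml config")
parser.add_argument("opts", nargs=argparse.REMAINDER,
                    help="override options, e.g. MODEL.DEVICE_ID \"('0')\"")
args = parser.parse_args()

if args.config_file:
    cfg.merge_from_file(args.config_file)  # e.g. configs/MARS/pit.yml
cfg.merge_from_list(args.opts)             # command-line key/value overrides
cfg.freeze()

print(cfg.MODEL.TRANSFORMER_TYPE, cfg.DATALOADER.P, cfg.OUTPUT_DIR)
```

Later values override earlier ones, so a yml file only needs to list the keys that differ from `config/defaults.py`, and command-line pairs take precedence over both.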
71 | -------------------------------------------------------------------------------- /config/__init__.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | from .defaults import _C as cfg 8 | from .defaults import _C as cfg_test 9 | -------------------------------------------------------------------------------- /config/defaults.py: -------------------------------------------------------------------------------- 1 | from yacs.config import CfgNode as CN 2 | 3 | # ----------------------------------------------------------------------------- 4 | # Convention about Training / Test specific parameters 5 | # ----------------------------------------------------------------------------- 6 | # Whenever an argument can be either used for training or for testing, the 7 | # corresponding name will be post-fixed by a _TRAIN for a training parameter, 8 | 9 | # ----------------------------------------------------------------------------- 10 | # Config definition 11 | # ----------------------------------------------------------------------------- 12 | 13 | _C = CN() 14 | # ----------------------------------------------------------------------------- 15 | # MODEL 16 | # ----------------------------------------------------------------------------- 17 | _C.MODEL = CN() 18 | # Using cuda or cpu for training 19 | _C.MODEL.DEVICE = "cuda" 20 | # ID number of GPU 21 | _C.MODEL.DEVICE_ID = '0' 22 | # Name of backbone 23 | _C.MODEL.NAME = 'resnet50' 24 | # Last stride of backbone 25 | _C.MODEL.LAST_STRIDE = 1 26 | # Path to pretrained model of backbone 27 | _C.MODEL.PRETRAIN_PATH = '' 28 | 29 | # Use ImageNet pretrained model to initialize backbone or use self trained model to initialize the whole model 30 | # Options: 'imagenet' , 'self' , 'finetune' 31 | _C.MODEL.PRETRAIN_CHOICE = 'imagenet' 32 | 33 | # If train with BNNeck, options: 'bnneck' or 'no' 34 | _C.MODEL.NECK = 'bnneck' 35 | # If train loss include center loss, options: 'yes' or 'no'. 
Loss with center loss has different optimizer configuration 36 | _C.MODEL.IF_WITH_CENTER = 'no' 37 | 38 | _C.MODEL.ID_LOSS_TYPE = 'softmax' 39 | _C.MODEL.ID_LOSS_WEIGHT = 1.0 40 | _C.MODEL.TRIPLET_LOSS_WEIGHT = 1.0 41 | 42 | _C.MODEL.METRIC_LOSS_TYPE = 'triplet' 43 | # If train with multi-gpu ddp mode, options: 'True', 'False' 44 | _C.MODEL.DIST_TRAIN = False 45 | # If train with soft triplet loss, options: 'True', 'False' 46 | _C.MODEL.NO_MARGIN = False 47 | # If train with label smooth, options: 'on', 'off' 48 | _C.MODEL.IF_LABELSMOOTH = 'on' 49 | # If train with arcface loss, options: 'True', 'False' 50 | _C.MODEL.COS_LAYER = False 51 | 52 | # Transformer setting 53 | _C.MODEL.DROP_PATH = 0.1 54 | _C.MODEL.DROP_OUT = 0.0 55 | _C.MODEL.ATT_DROP_RATE = 0.0 56 | _C.MODEL.TRANSFORMER_TYPE = 'None' 57 | _C.MODEL.STRIDE_SIZE = [16, 16] 58 | 59 | # JPM Parameter 60 | _C.MODEL.JPM = False 61 | _C.MODEL.SHIFT_NUM = 5 62 | _C.MODEL.SHUFFLE_GROUP = 2 63 | _C.MODEL.DEVIDE_LENGTH = 5 64 | _C.MODEL.RE_ARRANGE = True 65 | 66 | # SIE Parameter 67 | _C.MODEL.SIE_COE = 3.0 68 | _C.MODEL.SIE_CAMERA = False 69 | _C.MODEL.SIE_VIEW = False 70 | 71 | # Sample Strategy 72 | _C.MODEL.TRAIN_STRATEGY = '' # ['multiview', 'chunk'] 73 | _C.MODEL.SPATIAL = False 74 | _C.MODEL.TEMPORAL = False 75 | _C.MODEL.FREEZE = False 76 | _C.MODEL.PYRAMID0_TYPE = '' 77 | _C.MODEL.PYRAMID1_TYPE = '' 78 | _C.MODEL.PYRAMID2_TYPE = '' 79 | _C.MODEL.PYRAMID3_TYPE = '' 80 | _C.MODEL.PYRAMID4_TYPE = '' 81 | _C.MODEL.LAYER_COMBIN = 1 82 | _C.MODEL.LAYER0_DIVISION_TYPE = 'NULL' 83 | _C.MODEL.LAYER1_DIVISION_TYPE = 'NULL' 84 | _C.MODEL.LAYER2_DIVISION_TYPE = 'NULL' 85 | _C.MODEL.LAYER3_DIVISION_TYPE = 'NULL' 86 | _C.MODEL.LAYER4_DIVISION_TYPE = 'NULL' 87 | _C.MODEL.DIVERSITY = False 88 | 89 | 90 | 91 | # ----------------------------------------------------------------------------- 92 | # INPUT 93 | # ----------------------------------------------------------------------------- 94 | _C.INPUT = CN() 95 | # Size of the image during training 96 | _C.INPUT.SIZE_TRAIN = [384, 128] 97 | # Size of the image during test 98 | _C.INPUT.SIZE_TEST = [384, 128] 99 | # Random probability for image horizontal flip 100 | _C.INPUT.PROB = 0.5 101 | # Random probability for random erasing 102 | _C.INPUT.RE_PROB = 0.5 103 | # Random erasing 104 | _C.INPUT.RE = True 105 | # Values to be used for image normalization 106 | _C.INPUT.PIXEL_MEAN = [0.485, 0.456, 0.406] 107 | # Values to be used for image normalization 108 | _C.INPUT.PIXEL_STD = [0.229, 0.224, 0.225] 109 | # Value of padding size 110 | _C.INPUT.PADDING = 10 111 | 112 | # ----------------------------------------------------------------------------- 113 | # Dataset 114 | # ----------------------------------------------------------------------------- 115 | _C.DATASETS = CN() 116 | # List of the dataset names for training, as present in paths_catalog.py 117 | _C.DATASETS.NAMES = ('market1501') 118 | # Root directory where datasets should be used (and downloaded if not found) 119 | _C.DATASETS.ROOT_DIR = ('../data') 120 | 121 | 122 | # ----------------------------------------------------------------------------- 123 | # DataLoader 124 | # ----------------------------------------------------------------------------- 125 | _C.DATALOADER = CN() 126 | # Number of data loading threads 127 | _C.DATALOADER.NUM_WORKERS = 8 128 | # Sampler for data loading 129 | _C.DATALOADER.SAMPLER = 'softmax' 130 | # Number of instance for one batch 131 | _C.DATALOADER.NUM_INSTANCE = 16 132 | # random select p persons for each 
sample 133 | _C.DATALOADER.P = 16 134 | # random select k tracklets for each person 135 | _C.DATALOADER.K = 8 136 | # random select 8 images of each tracklet for test 137 | _C.DATALOADER.NUM_TEST_IMAGES = 8 138 | # random select 8 images of each tracklet for train 139 | _C.DATALOADER.NUM_TRAIN_IMAGES = 8 140 | 141 | # ---------------------------------------------------------------------------- # 142 | # Solver 143 | # ---------------------------------------------------------------------------- # 144 | _C.SOLVER = CN() 145 | # Name of optimizer 146 | _C.SOLVER.OPTIMIZER_NAME = "Adam" 147 | # Number of max epoches 148 | _C.SOLVER.MAX_EPOCHS = 100 149 | # Base learning rate 150 | _C.SOLVER.BASE_LR = 3e-4 151 | # Whether using larger learning rate for fc layer 152 | _C.SOLVER.LARGE_FC_LR = False 153 | # Factor of learning bias 154 | _C.SOLVER.BIAS_LR_FACTOR = 1 155 | # Factor of learning bias 156 | _C.SOLVER.SEED = 1234 157 | # Momentum 158 | _C.SOLVER.MOMENTUM = 0.9 159 | # Margin of triplet loss 160 | _C.SOLVER.MARGIN = 0.3 161 | # Learning rate of SGD to learn the centers of center loss 162 | _C.SOLVER.CENTER_LR = 0.5 163 | # Balanced weight of center loss 164 | _C.SOLVER.CENTER_LOSS_WEIGHT = 0.0005 165 | 166 | # Settings of weight decay 167 | _C.SOLVER.WEIGHT_DECAY = 0.0005 168 | _C.SOLVER.WEIGHT_DECAY_BIAS = 0.0005 169 | 170 | # decay rate of learning rate 171 | _C.SOLVER.GAMMA = 0.1 172 | # decay step of learning rate 173 | _C.SOLVER.STEPS = (40, 70) 174 | # warm up factor 175 | _C.SOLVER.WARMUP_FACTOR = 0.01 176 | # warm up epochs 177 | _C.SOLVER.WARMUP_EPOCHS = 5 178 | # method of warm up, option: 'constant','linear' 179 | _C.SOLVER.WARMUP_METHOD = "linear" 180 | 181 | _C.SOLVER.COSINE_MARGIN = 0.5 182 | _C.SOLVER.COSINE_SCALE = 30 183 | 184 | # epoch number of saving checkpoints 185 | _C.SOLVER.CHECKPOINT_PERIOD = 10 186 | # iteration of display training log 187 | _C.SOLVER.LOG_PERIOD = 100 188 | # epoch number of validation 189 | _C.SOLVER.EVAL_PERIOD = 10 190 | # Number of images per batch 191 | # This is global, so if we have 8 GPUs and IMS_PER_BATCH = 128, each GPU will 192 | # contain 16 images per batch 193 | _C.SOLVER.IMS_PER_BATCH = 64 194 | 195 | # ---------------------------------------------------------------------------- # 196 | # TEST 197 | # ---------------------------------------------------------------------------- # 198 | 199 | _C.TEST = CN() 200 | # Number of images per batch during test 201 | _C.TEST.IMS_PER_BATCH = 128 202 | # If test with re-ranking, options: 'True','False' 203 | _C.TEST.RE_RANKING = False 204 | # Path to trained model 205 | _C.TEST.WEIGHT = "" 206 | # Which feature of BNNeck to be used for test, before or after BNNneck, options: 'before' or 'after' 207 | _C.TEST.NECK_FEAT = 'after' 208 | # Whether feature is nomalized before test, if yes, it is equivalent to cosine distance 209 | _C.TEST.FEAT_NORM = 'yes' 210 | 211 | # Name for saving the distmat after testing. 
212 | _C.TEST.DIST_MAT = "dist_mat.npy" 213 | # Whether calculate the eval score option: 'True', 'False' 214 | _C.TEST.EVAL = False 215 | # 216 | _C.TEST.IMG_TEST_BATCH = 512 217 | _C.TEST.TEST_BATCH = 32 218 | _C.TEST.VIS = False 219 | # ---------------------------------------------------------------------------- # 220 | # Misc options 221 | # ---------------------------------------------------------------------------- # 222 | # Path to checkpoint and saved log of trained model 223 | _C.OUTPUT_DIR = "" 224 | -------------------------------------------------------------------------------- /configs/MARS/pit-test.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/zangxh/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('6') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_PiT' 11 | STRIDE_SIZE: [12, 12] 12 | SIE_CAMERA: True # original 13 | SIE_COE: 3.0 14 | JPM: True # original 15 | RE_ARRANGE: False # original 16 | TRAIN_STRATEGY: 'chunk' 17 | SPATIAL: True 18 | TEMPORAL: False 19 | FREEZE: True 20 | PYRAMID0_TYPE: 'horizontal' # 'horizontal','vertical','patch' 21 | LAYER0_DIVISION_TYPE: '1x210' # the type of division 22 | PYRAMID1_TYPE: 'horizontal' 23 | LAYER1_DIVISION_TYPE: '3x70' 24 | PYRAMID2_TYPE: 'vertical' 25 | LAYER2_DIVISION_TYPE: '105x2' 26 | PYRAMID3_TYPE: 'patch' 27 | LAYER3_DIVISION_TYPE: '6p' 28 | PYRAMID4_TYPE: 'horizontal' 29 | LAYER4_DIVISION_TYPE: 'NULL' 30 | LAYER_COMBIN: 4 # use how many layers of pyramid 31 | DIVERSITY: False 32 | 33 | INPUT: 34 | SIZE_TRAIN: [256, 128] 35 | SIZE_TEST: [256, 128] 36 | PROB: 0.5 # random horizontal flip 37 | RE_PROB: 0.5 # random erasing 38 | RE: True # random erasing 39 | PADDING: 10 40 | PIXEL_MEAN: [0.5, 0.5, 0.5] 41 | PIXEL_STD: [0.5, 0.5, 0.5] 42 | 43 | DATASETS: 44 | NAMES: ('mars') 45 | ROOT_DIR: ('../../data') 46 | 47 | DATALOADER: 48 | SAMPLER: 'softmax_triplet' 49 | NUM_INSTANCE: 4 50 | NUM_WORKERS: 8 51 | P: 2 52 | K: 4 53 | NUM_TEST_IMAGES: 8 54 | NUM_TRAIN_IMAGES: 8 55 | 56 | SOLVER: 57 | OPTIMIZER_NAME: 'SGD' 58 | SEED: 1151 59 | MAX_EPOCHS: 121 60 | BASE_LR: 0.01 61 | IMS_PER_BATCH: 64 62 | WARMUP_METHOD: 'linear' 63 | LARGE_FC_LR: False 64 | CHECKPOINT_PERIOD: 20 # 65 | LOG_PERIOD: 50 66 | EVAL_PERIOD: 20 # 67 | WEIGHT_DECAY: 1e-4 68 | WEIGHT_DECAY_BIAS: 1e-4 69 | BIAS_LR_FACTOR: 2 70 | WARMUP_EPOCHS: 5 71 | 72 | TEST: 73 | EVAL: True 74 | IMS_PER_BATCH: 256 75 | RE_RANKING: False 76 | WEIGHT: '../logs/mars_PiT_1x210_3x70_105x2_6p' 77 | NECK_FEAT: 'before' 78 | FEAT_NORM: 'yes' 79 | 80 | OUTPUT_DIR: '../logs/mars_PiT' 81 | 82 | 83 | -------------------------------------------------------------------------------- /configs/MARS/pit.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/zangxh/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('6') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_PiT' 11 | STRIDE_SIZE: [12, 12] 12 | SIE_CAMERA: True # original 13 | SIE_COE: 3.0 14 | JPM: True # original 15 | RE_ARRANGE: False # original 16 | TRAIN_STRATEGY: 'chunk' 17 | SPATIAL: True 18 | TEMPORAL: False 19 | FREEZE: True 20 | 
PYRAMID0_TYPE: 'horizontal' # 'horizontal','vertical','patch' 21 | LAYER0_DIVISION_TYPE: '1x210' # the type of division 22 | PYRAMID1_TYPE: 'horizontal' 23 | LAYER1_DIVISION_TYPE: '3x70' 24 | PYRAMID2_TYPE: 'vertical' 25 | LAYER2_DIVISION_TYPE: '105x2' 26 | PYRAMID3_TYPE: 'patch' 27 | LAYER3_DIVISION_TYPE: '6p' 28 | PYRAMID4_TYPE: 'horizontal' 29 | LAYER4_DIVISION_TYPE: 'NULL' 30 | LAYER_COMBIN: 4 # use how many layers of pyramid 31 | DIVERSITY: False 32 | 33 | INPUT: 34 | SIZE_TRAIN: [256, 128] 35 | SIZE_TEST: [256, 128] 36 | PROB: 0.5 # random horizontal flip 37 | RE_PROB: 0.5 # random erasing 38 | RE: True # random erasing 39 | PADDING: 10 40 | PIXEL_MEAN: [0.5, 0.5, 0.5] 41 | PIXEL_STD: [0.5, 0.5, 0.5] 42 | 43 | DATASETS: 44 | NAMES: ('mars') 45 | ROOT_DIR: ('../../data') 46 | 47 | DATALOADER: 48 | SAMPLER: 'softmax_triplet' 49 | NUM_INSTANCE: 4 50 | NUM_WORKERS: 8 51 | P: 2 52 | K: 4 53 | NUM_TEST_IMAGES: 8 54 | NUM_TRAIN_IMAGES: 8 55 | 56 | SOLVER: 57 | OPTIMIZER_NAME: 'SGD' 58 | SEED: 1151 59 | MAX_EPOCHS: 121 60 | BASE_LR: 0.01 61 | IMS_PER_BATCH: 64 62 | WARMUP_METHOD: 'linear' 63 | LARGE_FC_LR: False 64 | CHECKPOINT_PERIOD: 20 # 65 | LOG_PERIOD: 50 66 | EVAL_PERIOD: 20 # 67 | WEIGHT_DECAY: 1e-4 68 | WEIGHT_DECAY_BIAS: 1e-4 69 | BIAS_LR_FACTOR: 2 70 | WARMUP_EPOCHS: 5 71 | 72 | TEST: 73 | EVAL: True 74 | IMS_PER_BATCH: 256 75 | RE_RANKING: False 76 | WEIGHT: '' # such as '../logs/mars_PiT_1x210_3x70_105x2_6p' 77 | NECK_FEAT: 'before' 78 | FEAT_NORM: 'yes' 79 | 80 | OUTPUT_DIR: '../logs/mars_PiT' 81 | 82 | 83 | -------------------------------------------------------------------------------- /configs/iLIDS-VID/pit-test.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/zangxh/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('6') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_PiT' 11 | STRIDE_SIZE: [12, 12] 12 | SIE_CAMERA: True # original 13 | SIE_COE: 3.0 14 | JPM: True # original 15 | RE_ARRANGE: False # original 16 | TRAIN_STRATEGY: 'chunk' 17 | SPATIAL: True 18 | TEMPORAL: False 19 | FREEZE: True 20 | PYRAMID0_TYPE: 'horizontal' # 'horizontal','vertical','patch' 21 | LAYER0_DIVISION_TYPE: '1x210' # the type of division 22 | PYRAMID1_TYPE: 'horizontal' 23 | LAYER1_DIVISION_TYPE: '3x70' 24 | PYRAMID2_TYPE: 'vertical' 25 | LAYER2_DIVISION_TYPE: '105x2' 26 | PYRAMID3_TYPE: 'patch' 27 | LAYER3_DIVISION_TYPE: '6p' 28 | PYRAMID4_TYPE: 'horizontal' 29 | LAYER4_DIVISION_TYPE: 'NULL' 30 | LAYER_COMBIN: 4 # use how many layers of pyramid 31 | DIVERSITY: False 32 | 33 | INPUT: 34 | SIZE_TRAIN: [256, 128] 35 | SIZE_TEST: [256, 128] 36 | PROB: 0.5 # random horizontal flip 37 | RE_PROB: 0.5 # random erasing 38 | RE: True # random erasing 39 | PADDING: 10 40 | PIXEL_MEAN: [0.5, 0.5, 0.5] 41 | PIXEL_STD: [0.5, 0.5, 0.5] 42 | 43 | DATASETS: 44 | NAMES: ('ilids') 45 | ROOT_DIR: ('../../data') 46 | 47 | DATALOADER: 48 | SAMPLER: 'softmax_triplet' 49 | NUM_INSTANCE: 4 50 | NUM_WORKERS: 8 51 | P: 2 52 | K: 4 53 | NUM_TEST_IMAGES: 8 54 | NUM_TRAIN_IMAGES: 8 55 | 56 | SOLVER: 57 | OPTIMIZER_NAME: 'SGD' 58 | SEED: 1171 59 | MAX_EPOCHS: 121 60 | BASE_LR: 0.01 61 | IMS_PER_BATCH: 64 62 | WARMUP_METHOD: 'linear' 63 | LARGE_FC_LR: False 64 | CHECKPOINT_PERIOD: 20 # 65 | LOG_PERIOD: 50 66 | EVAL_PERIOD: 20 # 67 | WEIGHT_DECAY: 1e-4 68 | WEIGHT_DECAY_BIAS: 
1e-4 69 | BIAS_LR_FACTOR: 2 70 | WARMUP_EPOCHS: 5 71 | 72 | TEST: 73 | EVAL: True 74 | IMS_PER_BATCH: 256 75 | RE_RANKING: False 76 | WEIGHT: '../logs/ilids_PiT_1x210_3x70_105x2_6p' 77 | NECK_FEAT: 'before' 78 | FEAT_NORM: 'yes' 79 | 80 | OUTPUT_DIR: '../logs/ilids_PiT' 81 | 82 | 83 | -------------------------------------------------------------------------------- /configs/iLIDS-VID/pit.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | PRETRAIN_PATH: '/home/zangxh/.cache/torch/checkpoints/jx_vit_base_p16_224-80ecf9dd.pth' 4 | METRIC_LOSS_TYPE: 'triplet' 5 | IF_LABELSMOOTH: 'off' 6 | IF_WITH_CENTER: 'no' 7 | NAME: 'transformer' 8 | NO_MARGIN: True 9 | DEVICE_ID: ('6') 10 | TRANSFORMER_TYPE: 'vit_base_patch16_224_PiT' 11 | STRIDE_SIZE: [12, 12] 12 | SIE_CAMERA: True # original 13 | SIE_COE: 3.0 14 | JPM: True # original 15 | RE_ARRANGE: False # original 16 | TRAIN_STRATEGY: 'chunk' 17 | SPATIAL: True 18 | TEMPORAL: False 19 | FREEZE: True 20 | PYRAMID0_TYPE: 'horizontal' # 'horizontal','vertical','patch' 21 | LAYER0_DIVISION_TYPE: '1x210' # the type of division 22 | PYRAMID1_TYPE: 'horizontal' 23 | LAYER1_DIVISION_TYPE: '3x70' 24 | PYRAMID2_TYPE: 'vertical' 25 | LAYER2_DIVISION_TYPE: '105x2' 26 | PYRAMID3_TYPE: 'patch' 27 | LAYER3_DIVISION_TYPE: '6p' 28 | PYRAMID4_TYPE: 'horizontal' 29 | LAYER4_DIVISION_TYPE: 'NULL' 30 | LAYER_COMBIN: 4 # use how many layers of pyramid 31 | DIVERSITY: False 32 | 33 | INPUT: 34 | SIZE_TRAIN: [256, 128] 35 | SIZE_TEST: [256, 128] 36 | PROB: 0.5 # random horizontal flip 37 | RE_PROB: 0.5 # random erasing 38 | RE: True # random erasing 39 | PADDING: 10 40 | PIXEL_MEAN: [0.5, 0.5, 0.5] 41 | PIXEL_STD: [0.5, 0.5, 0.5] 42 | 43 | DATASETS: 44 | NAMES: ('ilids') 45 | ROOT_DIR: ('../../data') 46 | 47 | DATALOADER: 48 | SAMPLER: 'softmax_triplet' 49 | NUM_INSTANCE: 4 50 | NUM_WORKERS: 8 51 | P: 2 52 | K: 4 53 | NUM_TEST_IMAGES: 8 54 | NUM_TRAIN_IMAGES: 8 55 | 56 | SOLVER: 57 | OPTIMIZER_NAME: 'SGD' 58 | SEED: 1171 59 | MAX_EPOCHS: 121 60 | BASE_LR: 0.01 61 | IMS_PER_BATCH: 64 62 | WARMUP_METHOD: 'linear' 63 | LARGE_FC_LR: False 64 | CHECKPOINT_PERIOD: 20 # 65 | LOG_PERIOD: 50 66 | EVAL_PERIOD: 20 # 67 | WEIGHT_DECAY: 1e-4 68 | WEIGHT_DECAY_BIAS: 1e-4 69 | BIAS_LR_FACTOR: 2 70 | WARMUP_EPOCHS: 5 71 | 72 | TEST: 73 | EVAL: True 74 | IMS_PER_BATCH: 256 75 | RE_RANKING: False 76 | WEIGHT: '' # such as '../logs/ilids_PiT_1x210_3x70_105x2_6p' 77 | NECK_FEAT: 'before' 78 | FEAT_NORM: 'yes' 79 | 80 | OUTPUT_DIR: '../logs/ilids_PiT' 81 | 82 | 83 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .make_dataloader import make_dataloader -------------------------------------------------------------------------------- /datasets/bases.py: -------------------------------------------------------------------------------- 1 | from PIL import Image, ImageFile 2 | 3 | from torch.utils.data import Dataset 4 | import os.path as osp 5 | import torch 6 | from .misc import get_default_video_loader 7 | from .temporal_transforms import MultiViewTemporalTransform 8 | ImageFile.LOAD_TRUNCATED_IMAGES = True 9 | 10 | 11 | def read_image(img_path): 12 | """Keep reading image until succeed. 
13 | This can avoid IOError incurred by heavy IO process.""" 14 | got_img = False 15 | if not osp.exists(img_path): 16 | raise IOError("{} does not exist".format(img_path)) 17 | while not got_img: 18 | try: 19 | img = Image.open(img_path).convert('RGB') 20 | got_img = True 21 | except IOError: 22 | print("IOError incurred when reading '{}'. Will redo. Don't worry. Just chill.".format(img_path)) 23 | pass 24 | return img 25 | 26 | 27 | class BaseDataset(object): 28 | """ 29 | Base class of reid dataset 30 | """ 31 | 32 | def get_imagedata_info(self, data): 33 | pids, cams, tracks = [], [], [] 34 | num_imgs = 0 35 | 36 | for img_paths, pid, camid, trackid in data: 37 | num_imgs += len(img_paths) 38 | pids += [pid] 39 | cams += [camid] 40 | tracks += [trackid] 41 | pids = set(pids) 42 | cams = set(cams) 43 | tracks = set(tracks) 44 | num_pids = len(pids) 45 | num_cams = len(cams) 46 | # num_imgs = len(data) 47 | num_views = len(tracks) 48 | return num_pids, num_imgs, num_cams, num_views 49 | 50 | def print_dataset_statistics(self): 51 | raise NotImplementedError 52 | 53 | 54 | class BaseImageDataset(BaseDataset): 55 | """ 56 | Base class of image reid dataset 57 | """ 58 | 59 | def print_dataset_statistics(self, train, query, gallery): 60 | num_train_pids, num_train_imgs, num_train_cams, num_train_views = self.get_imagedata_info(train) 61 | num_query_pids, num_query_imgs, num_query_cams, num_train_views = self.get_imagedata_info(query) 62 | num_gallery_pids, num_gallery_imgs, num_gallery_cams, num_train_views = self.get_imagedata_info(gallery) 63 | 64 | print("Dataset statistics:") 65 | print(" ----------------------------------------") 66 | print(" subset | # ids | # images | # cameras") 67 | print(" ----------------------------------------") 68 | print(" train | {:5d} | {:8d} | {:9d}".format(num_train_pids, num_train_imgs, num_train_cams)) 69 | print(" query | {:5d} | {:8d} | {:9d}".format(num_query_pids, num_query_imgs, num_query_cams)) 70 | print(" gallery | {:5d} | {:8d} | {:9d}".format(num_gallery_pids, num_gallery_imgs, num_gallery_cams)) 71 | print(" ----------------------------------------") 72 | 73 | 74 | class ImageDataset(Dataset): 75 | def __init__(self, dataset, transform=None): 76 | self.dataset = dataset 77 | self.transform = transform 78 | 79 | def __len__(self): 80 | return len(self.dataset) 81 | 82 | def __getitem__(self, index): 83 | img_path, pid, camid, trackid = self.dataset[index] 84 | img = read_image(img_path) 85 | 86 | if self.transform is not None: 87 | img = self.transform(img) 88 | 89 | return img, pid, camid, trackid, img_path.split('/')[-1] 90 | 91 | 92 | class VideoDataset(Dataset): 93 | """Video ReID Dataset. 94 | Note: 95 | Batch data has shape N x C x T x H x W 96 | Args: 97 | dataset (list): List with items (img_paths, pid, camid) 98 | temporal_transform (callable, optional): A function/transform that takes in a list of frame indices 99 | and returns a transformed version 100 | target_transform (callable, optional): A function/transform that takes in the 101 | target and transforms it. 
102 | """ 103 | 104 | def __init__(self, 105 | dataset, 106 | spatial_transform=None, 107 | temporal_transform=None, 108 | get_loader=get_default_video_loader): 109 | self.dataset = dataset 110 | self.spatial_transform = spatial_transform 111 | self.temporal_transform = temporal_transform 112 | self.loader = get_loader() 113 | self.teacher_mode = False 114 | 115 | def __len__(self): 116 | return len(self.dataset) 117 | 118 | def get_num_pids(self): 119 | return len(np.unique([el[1] for el in self.dataset])) 120 | 121 | def get_num_cams(self): 122 | return len(np.unique([el[2] for el in self.dataset])) 123 | 124 | def set_teacher_mode(self, is_teacher: bool): 125 | self.teacher_mode = is_teacher 126 | 127 | def __getitem__(self, index): 128 | """ 129 | Args: 130 | index (int): Index 131 | 132 | Returns: 133 | tuple: (clip, pid, camid) where pid is identity of the clip. 134 | """ 135 | img_paths, pid, camid, tracklet_id = self.dataset[index] 136 | 137 | if isinstance(self.temporal_transform, MultiViewTemporalTransform): 138 | candidates = list(filter(lambda x: x[1] == pid, self.dataset)) 139 | img_paths = self.temporal_transform(candidates, index) 140 | elif self.temporal_transform is not None: 141 | img_paths = self.temporal_transform(img_paths, index) 142 | 143 | clip = self.loader(img_paths) 144 | 145 | if not self.teacher_mode: 146 | clip = [self.spatial_transform(img) for img in clip] 147 | else: 148 | clip_aug = [self.spatial_transform(img) for img in clip] 149 | std_daug = T.Compose([ 150 | self.spatial_transform.transforms[0], 151 | T.ToTensor(), 152 | self.spatial_transform.transforms[-1] if not isinstance(self.spatial_transform.transforms[-1], T.RandomErasing) else self.spatial_transform.transforms[-2] 153 | ]) 154 | clip_std = [std_daug(img) for img in clip] 155 | clip = clip_aug + clip_std 156 | 157 | clip = torch.stack(clip, 0) 158 | 159 | return clip, pid, camid, tracklet_id, [i.split('/')[-1] for i in img_paths] -------------------------------------------------------------------------------- /datasets/ilids.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os 3 | import glob 4 | import re 5 | import sys 6 | import urllib 7 | import tarfile 8 | import zipfile 9 | import os.path as osp 10 | from scipy.io import loadmat 11 | import numpy as np 12 | import json 13 | from .bases import BaseImageDataset 14 | 15 | 16 | """Dataset classes""" 17 | 18 | class iLIDSVID(BaseImageDataset): 19 | """ 20 | iLIDS-VID 21 | 22 | Reference: 23 | Wang et al. Person Re-Identification by Video Ranking. ECCV 2014. 24 | 25 | Dataset statistics: 26 | # identities: 300 27 | # tracklets: 600 28 | # cameras: 2 29 | 30 | Args: 31 | split_id (int): indicates which split to use. There are totally 10 splits. 
32 | """ 33 | root = '/data/i-LIDS-VID' 34 | dataset_url = 'http://www.eecs.qmul.ac.uk/~xiatian/iLIDS-VID/iLIDS-VID.tar' 35 | 36 | def __init__(self, root='/data/i-LIDS-VID', split_id=0): 37 | self.root = root 38 | self.data_dir = osp.join(self.root, 'i-LIDS-VID') 39 | self.split_dir = osp.join(self.data_dir, 'train-test people splits') 40 | self.split_mat_path = osp.join(self.split_dir, 'train_test_splits_ilidsvid.mat') 41 | self.split_path = osp.join(self.data_dir, 'splits.json') 42 | self.cam_1_path = osp.join(self.data_dir, 'i-LIDS-VID/sequences/cam1') 43 | self.cam_2_path = osp.join(self.data_dir, 'i-LIDS-VID/sequences/cam2') 44 | 45 | if split_id == 0: 46 | self._download_data() 47 | self._check_before_run() 48 | 49 | self._prepare_split() 50 | splits = read_json(self.split_path) 51 | if split_id >= len(splits): 52 | raise ValueError("split_id exceeds range, received {}, but expected between 0 and {}".format(split_id, len(splits)-1)) 53 | split = splits[split_id] 54 | train_dirs, test_dirs = split['train'], split['test'] 55 | if split_id == 0: 56 | print("# train identites: {}, # test identites {}".format(len(train_dirs), len(test_dirs))) 57 | 58 | train, num_train_tracklets, num_train_pids, num_imgs_train = \ 59 | self._process_data(train_dirs, cam1=True, cam2=True) 60 | query, num_query_tracklets, num_query_pids, num_imgs_query = \ 61 | self._process_data(test_dirs, cam1=True, cam2=False) 62 | gallery, num_gallery_tracklets, num_gallery_pids, num_imgs_gallery = \ 63 | self._process_data(test_dirs, cam1=False, cam2=True) 64 | 65 | num_imgs_per_tracklet = num_imgs_train + num_imgs_query + num_imgs_gallery 66 | min_num = np.min(num_imgs_per_tracklet) 67 | max_num = np.max(num_imgs_per_tracklet) 68 | avg_num = np.mean(num_imgs_per_tracklet) 69 | 70 | num_total_pids = num_train_pids + num_query_pids 71 | num_total_tracklets = num_train_tracklets + num_query_tracklets + num_gallery_tracklets 72 | 73 | if split_id == 0: 74 | print("=> iLIDS-VID loaded") 75 | print("Dataset statistics:") 76 | print(" ------------------------------") 77 | print(" subset | # ids | # tracklets") 78 | print(" ------------------------------") 79 | print(" train | {:5d} | {:8d}".format(num_train_pids, num_train_tracklets)) 80 | print(" query | {:5d} | {:8d}".format(num_query_pids, num_query_tracklets)) 81 | print(" gallery | {:5d} | {:8d}".format(num_gallery_pids, num_gallery_tracklets)) 82 | print(" ------------------------------") 83 | print(" total | {:5d} | {:8d}".format(num_total_pids, num_total_tracklets)) 84 | print(" number of images per tracklet: {} ~ {}, average {:.1f}".format(min_num, max_num, avg_num)) 85 | print(" ------------------------------") 86 | 87 | self.train = train 88 | self.query = query 89 | self.gallery = gallery 90 | 91 | self.num_train_pids, self.num_train_imgs, self.num_train_cams, self.num_train_vids = self.get_imagedata_info( 92 | self.train) 93 | self.num_query_pids, self.num_query_imgs, self.num_query_cams, self.num_query_vids = self.get_imagedata_info( 94 | self.query) 95 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams, self.num_gallery_vids = self.get_imagedata_info( 96 | self.gallery) 97 | 98 | def _download_data(self): 99 | if osp.exists(self.root): 100 | print("This dataset has been downloaded.") 101 | return 102 | 103 | mkdir_if_missing(self.root) 104 | fpath = osp.join(self.root, osp.basename(self.dataset_url)) 105 | 106 | print("Downloading iLIDS-VID dataset") 107 | url_opener = urllib.request.URLopener() 108 | url_opener.retrieve(self.dataset_url, 
fpath) 109 | 110 | print("Extracting files") 111 | tar = tarfile.open(fpath) 112 | tar.extractall(path=self.root) 113 | tar.close() 114 | 115 | def _check_before_run(self): 116 | """Check if all files are available before going deeper""" 117 | if not osp.exists(self.root): 118 | raise RuntimeError("'{}' is not available".format(self.root)) 119 | if not osp.exists(self.data_dir): 120 | raise RuntimeError("'{}' is not available".format(self.data_dir)) 121 | if not osp.exists(self.split_dir): 122 | raise RuntimeError("'{}' is not available".format(self.split_dir)) 123 | 124 | def _prepare_split(self): 125 | if not osp.exists(self.split_path): 126 | print("Creating splits") 127 | mat_split_data = loadmat(self.split_mat_path)['ls_set'] 128 | 129 | num_splits = mat_split_data.shape[0] 130 | num_total_ids = mat_split_data.shape[1] 131 | assert num_splits == 10 132 | assert num_total_ids == 300 133 | num_ids_each = int(num_total_ids/2) 134 | 135 | # pids in mat_split_data are indices, so we need to transform them 136 | # to real pids 137 | person_cam1_dirs = os.listdir(self.cam_1_path) 138 | person_cam2_dirs = os.listdir(self.cam_2_path) 139 | 140 | # make sure persons in one camera view can be found in the other camera view 141 | assert set(person_cam1_dirs) == set(person_cam2_dirs) 142 | 143 | splits = [] 144 | for i_split in range(num_splits): 145 | # first 50% for testing and the remaining for training, following Wang et al. ECCV'14. 146 | train_idxs = sorted(list(mat_split_data[i_split,num_ids_each:])) 147 | test_idxs = sorted(list(mat_split_data[i_split,:num_ids_each])) 148 | 149 | train_idxs = [int(i)-1 for i in train_idxs] 150 | test_idxs = [int(i)-1 for i in test_idxs] 151 | 152 | # transform pids to person dir names 153 | train_dirs = [person_cam1_dirs[i] for i in train_idxs] 154 | test_dirs = [person_cam1_dirs[i] for i in test_idxs] 155 | 156 | split = {'train': train_dirs, 'test': test_dirs} 157 | splits.append(split) 158 | 159 | print("Totally {} splits are created, following Wang et al. 
ECCV'14".format(len(splits))) 160 | print("Split file is saved to {}".format(self.split_path)) 161 | write_json(splits, self.split_path) 162 | 163 | print("Splits created") 164 | 165 | def _process_data(self, dirnames, cam1=True, cam2=True): 166 | tracklets = [] 167 | num_imgs_per_tracklet = [] 168 | dirname2pid = {dirname:i for i, dirname in enumerate(dirnames)} 169 | 170 | #txt_name = 'ilids' + str(cam1) + str(cam2) + '.txt' 171 | #fid = open(txt_name, "w") 172 | 173 | for dirname in dirnames: 174 | if cam1: 175 | person_dir = osp.join(self.cam_1_path, dirname) 176 | img_names = glob.glob(osp.join(person_dir, '*.png')) 177 | assert len(img_names) > 0 178 | img_names = tuple(img_names) 179 | pid = dirname2pid[dirname] 180 | tracklets.append((img_names, pid, 0, 1)) 181 | num_imgs_per_tracklet.append(len(img_names)) 182 | #fid.write("cam1_" + dirname + '\n') 183 | if cam2: 184 | person_dir = osp.join(self.cam_2_path, dirname) 185 | img_names = glob.glob(osp.join(person_dir, '*.png')) 186 | assert len(img_names) > 0 187 | img_names = tuple(img_names) 188 | pid = dirname2pid[dirname] 189 | tracklets.append((img_names, pid, 1, 1)) 190 | num_imgs_per_tracklet.append(len(img_names)) 191 | #fid.write("cam2_" + dirname + '\n') 192 | 193 | num_tracklets = len(tracklets) 194 | num_pids = len(dirnames) 195 | 196 | return tracklets, num_tracklets, num_pids, num_imgs_per_tracklet 197 | 198 | """Create dataset""" 199 | 200 | __factory = { 201 | 'ilidsvid': iLIDSVID, 202 | } 203 | 204 | def get_names(): 205 | return __factory.keys() 206 | 207 | def init_dataset(name, *args, **kwargs): 208 | if name not in __factory.keys(): 209 | raise KeyError("Unknown dataset: {}".format(name)) 210 | return __factory[name](*args, **kwargs) 211 | 212 | if __name__ == '__main__': 213 | # test 214 | dataset = iLIDSVID() 215 | 216 | def mkdir_if_missing(directory): 217 | if not osp.exists(directory): 218 | try: 219 | os.makedirs(directory) 220 | except OSError as e: 221 | if e.errno != errno.EEXIST: 222 | raise 223 | 224 | def read_json(fpath): 225 | with open(fpath, 'r') as f: 226 | obj = json.load(f) 227 | return obj 228 | 229 | def write_json(obj, fpath): 230 | mkdir_if_missing(osp.dirname(fpath)) 231 | with open(fpath, 'w') as f: 232 | json.dump(obj, f, indent=4, separators=(',', ': ')) 233 | 234 | 235 | -------------------------------------------------------------------------------- /datasets/make_dataloader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.transforms as T 3 | from torch.utils.data import DataLoader 4 | 5 | from .bases import ImageDataset, VideoDataset 6 | from timm.data.random_erasing import RandomErasing 7 | from .sampler import RandomIdentitySampler, ReIDBatchSampler 8 | from .sampler_ddp import RandomIdentitySampler_DDP 9 | import torch.distributed as dist 10 | from .mars import Mars 11 | from .ilids import iLIDSVID 12 | 13 | from .misc import get_transforms, init_worker 14 | 15 | __factory = { 16 | 'mars': Mars, 17 | 'ilids': iLIDSVID, 18 | } 19 | 20 | def train_collate_fn(batch): 21 | """ 22 | # collate_fn这个函数的输入就是一个list,list的长度是一个batch size,list中的每个元素都是__getitem__得到的结果 23 | """ 24 | imgs, pids, camids, viewids , _ = zip(*batch) 25 | pids = torch.tensor(pids, dtype=torch.int64) 26 | viewids = torch.tensor(viewids, dtype=torch.int64) 27 | camids = torch.tensor(camids, dtype=torch.int64) 28 | return torch.stack(imgs, dim=0), pids, camids, viewids, 29 | 30 | def val_collate_fn(batch): 31 | imgs, pids, camids, viewids, 
img_paths = zip(*batch) 32 | viewids = torch.tensor(viewids, dtype=torch.int64) 33 | camids_batch = torch.tensor(camids, dtype=torch.int64) 34 | return torch.stack(imgs, dim=0), pids, camids, camids_batch, viewids, img_paths 35 | 36 | def make_dataloader(cfg): 37 | train_transforms = T.Compose([ 38 | T.Resize(cfg.INPUT.SIZE_TRAIN, interpolation=3), 39 | T.RandomHorizontalFlip(p=cfg.INPUT.PROB), 40 | T.Pad(cfg.INPUT.PADDING), 41 | T.RandomCrop(cfg.INPUT.SIZE_TRAIN), 42 | T.ToTensor(), 43 | T.Normalize(mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD), 44 | RandomErasing(probability=cfg.INPUT.RE_PROB, mode='pixel', max_count=1, device='cpu'), 45 | ]) 46 | 47 | val_transforms = T.Compose([ 48 | T.Resize(cfg.INPUT.SIZE_TEST), 49 | T.ToTensor(), 50 | T.Normalize(mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD) 51 | ]) 52 | 53 | num_workers = cfg.DATALOADER.NUM_WORKERS 54 | pin_memory = True 55 | 56 | if cfg.DATASETS.NAMES in ['ilids', 'prid']: 57 | dataset_10trails = [] 58 | for i in range(10): 59 | dataset = __factory[cfg.DATASETS.NAMES](root=cfg.DATASETS.ROOT_DIR, split_id=i) 60 | dataset_10trails.append(dataset) 61 | else: 62 | dataset = __factory[cfg.DATASETS.NAMES](root=cfg.DATASETS.ROOT_DIR) 63 | 64 | if cfg.DATASETS.NAMES in ['mars', 'duke-video-reid']: 65 | s_tr_train, t_tr_train = get_transforms(True, cfg) 66 | train_set = VideoDataset(dataset.train, spatial_transform=s_tr_train, 67 | temporal_transform=t_tr_train) 68 | elif cfg.DATASETS.NAMES in ['ilids', 'prid']: 69 | s_tr_train, t_tr_train = get_transforms(True, cfg) 70 | train_set = [VideoDataset(i.train, spatial_transform=s_tr_train, 71 | temporal_transform=t_tr_train) for i in dataset_10trails] 72 | else: 73 | train_set = ImageDataset(dataset.train, train_transforms) 74 | train_set_normal = ImageDataset(dataset.train, val_transforms) 75 | num_classes = dataset.num_train_pids 76 | cam_num = dataset.num_train_cams 77 | view_num = dataset.num_train_vids 78 | 79 | if 'triplet' in cfg.DATALOADER.SAMPLER: 80 | if cfg.MODEL.DIST_TRAIN: 81 | print('DIST_TRAIN START') 82 | mini_batch_size = cfg.SOLVER.IMS_PER_BATCH // dist.get_world_size() 83 | data_sampler = RandomIdentitySampler_DDP(dataset.train, cfg.SOLVER.IMS_PER_BATCH, cfg.DATALOADER.NUM_INSTANCE) 84 | batch_sampler = torch.utils.data.sampler.BatchSampler(data_sampler, mini_batch_size, True) 85 | train_loader = torch.utils.data.DataLoader( 86 | train_set, 87 | num_workers=num_workers, 88 | batch_sampler=batch_sampler, 89 | collate_fn=train_collate_fn, 90 | pin_memory=True, 91 | ) 92 | else: 93 | if cfg.DATASETS.NAMES in ['mars']: 94 | train_loader = [DataLoader( 95 | train_set, batch_sampler=ReIDBatchSampler(dataset.train, p=cfg.DATALOADER.P, 96 | k=cfg.DATALOADER.K), 97 | num_workers=num_workers, collate_fn=train_collate_fn, worker_init_fn=init_worker, 98 | pin_memory=pin_memory 99 | )] 100 | elif cfg.DATASETS.NAMES in ['ilids']: 101 | train_loader = [DataLoader( 102 | i, batch_sampler=ReIDBatchSampler(j.train, p=cfg.DATALOADER.P, 103 | k=cfg.DATALOADER.K), 104 | num_workers=num_workers, collate_fn=train_collate_fn, worker_init_fn=init_worker, 105 | pin_memory=pin_memory 106 | ) for i,j in zip(train_set, dataset_10trails)] 107 | else: 108 | train_loader = DataLoader( 109 | train_set, batch_size=cfg.SOLVER.IMS_PER_BATCH, 110 | sampler=RandomIdentitySampler(dataset.train, cfg.SOLVER.IMS_PER_BATCH, 111 | cfg.DATALOADER.NUM_INSTANCE), 112 | num_workers=num_workers, collate_fn=train_collate_fn 113 | ) 114 | elif cfg.DATALOADER.SAMPLER == 'softmax': 115 | print('using softmax 
sampler') 116 | train_loader = DataLoader( 117 | train_set, batch_size=cfg.SOLVER.IMS_PER_BATCH, shuffle=True, num_workers=num_workers, 118 | collate_fn=train_collate_fn 119 | ) 120 | else: 121 | print('unsupported sampler! expected softmax or triplet but got {}'.format(cfg.SAMPLER)) 122 | 123 | if cfg.DATASETS.NAMES in ['mars']: 124 | s_tr_test, t_tr_test = get_transforms(False, cfg) 125 | 126 | val_set = VideoDataset(dataset.query + dataset.gallery, spatial_transform=s_tr_test, 127 | temporal_transform=t_tr_test) 128 | val_loader = [DataLoader( 129 | val_set, batch_size=cfg.TEST.TEST_BATCH, shuffle=False, num_workers=2, 130 | pin_memory=pin_memory, drop_last=False, worker_init_fn=init_worker, 131 | collate_fn=val_collate_fn 132 | )] 133 | elif cfg.DATASETS.NAMES in ['ilids']: 134 | s_tr_test, t_tr_test = get_transforms(False, cfg) 135 | 136 | val_set = [VideoDataset(i.query + i.gallery, spatial_transform=s_tr_test, 137 | temporal_transform=t_tr_test) for i in dataset_10trails] 138 | val_loader = [DataLoader( 139 | i, batch_size=cfg.TEST.TEST_BATCH, shuffle=False, num_workers=2, 140 | pin_memory=pin_memory, drop_last=False, worker_init_fn=init_worker, 141 | collate_fn=val_collate_fn 142 | ) for i in val_set] 143 | else: 144 | val_set = ImageDataset(dataset.query + dataset.gallery, val_transforms) 145 | 146 | val_loader = DataLoader( 147 | val_set, batch_size=cfg.TEST.IMS_PER_BATCH, shuffle=False, num_workers=num_workers, 148 | collate_fn=val_collate_fn 149 | ) 150 | 151 | return train_loader, val_loader, len(dataset.query), num_classes, cam_num, view_num 152 | -------------------------------------------------------------------------------- /datasets/mars.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | from scipy.io import loadmat 3 | import numpy as np 4 | 5 | from .bases import BaseImageDataset 6 | 7 | class Mars(BaseImageDataset): 8 | """ 9 | MARS 10 | 11 | Reference: 12 | Zheng et al. MARS: A Video Benchmark for Large-Scale Person Re-identification. ECCV 2016. 
13 | """ 14 | 15 | def __init__(self, root='/data/datasets/', min_seq_len=0): 16 | self.root = osp.join(root, 'mars') 17 | self.train_name_path = osp.join(self.root, 'info/train_name.txt') 18 | self.test_name_path = osp.join(self.root, 'info/test_name.txt') 19 | self.track_train_info_path = osp.join(self.root, 'info/tracks_train_info.mat') 20 | self.track_test_info_path = osp.join(self.root, 'info/tracks_test_info.mat') 21 | self.query_IDX_path = osp.join(self.root, 'info/query_IDX.mat') 22 | 23 | self._check_before_run() 24 | 25 | # prepare meta data 26 | train_names = self._get_names(self.train_name_path) 27 | test_names = self._get_names(self.test_name_path) 28 | track_train = loadmat(self.track_train_info_path)['track_train_info'] # numpy.ndarray (8298, 4) 29 | track_test = loadmat(self.track_test_info_path)['track_test_info'] # numpy.ndarray (12180, 4) 30 | query_IDX = loadmat(self.query_IDX_path)['query_IDX'].squeeze() # numpy.ndarray (1980,) 31 | query_IDX -= 1 # index from 0 32 | track_query = track_test[query_IDX,:] 33 | gallery_IDX = [i for i in range(track_test.shape[0]) if i not in query_IDX] 34 | track_gallery = track_test[gallery_IDX,:] 35 | # track_gallery = track_test 36 | 37 | train, num_train_tracklets, num_train_pids, num_train_imgs = \ 38 | self._process_data(train_names, track_train, home_dir='bbox_train', relabel=True, min_seq_len=min_seq_len) 39 | 40 | query, num_query_tracklets, num_query_pids, num_query_imgs = \ 41 | self._process_data(test_names, track_query, home_dir='bbox_test', relabel=False, min_seq_len=min_seq_len) 42 | 43 | gallery, num_gallery_tracklets, num_gallery_pids, num_gallery_imgs = \ 44 | self._process_data(test_names, track_gallery, home_dir='bbox_test', relabel=False, min_seq_len=min_seq_len) 45 | 46 | train_img, _, _ = \ 47 | self._extract_1stfeame(train_names, track_train, home_dir='bbox_train', relabel=True) 48 | 49 | query_img, _, _ = \ 50 | self._extract_1stfeame(test_names, track_query, home_dir='bbox_test', relabel=False) 51 | 52 | gallery_img, _, _ = \ 53 | self._extract_1stfeame(test_names, track_gallery, home_dir='bbox_test', relabel=False) 54 | 55 | num_imgs_per_tracklet = num_train_imgs + num_gallery_imgs + num_query_imgs 56 | total_num = np.sum(num_imgs_per_tracklet) 57 | min_num = np.min(num_imgs_per_tracklet) 58 | max_num = np.max(num_imgs_per_tracklet) 59 | avg_num = np.mean(num_imgs_per_tracklet) 60 | 61 | num_total_pids = num_train_pids + num_query_pids 62 | num_total_tracklets = num_train_tracklets + num_gallery_tracklets + num_query_tracklets 63 | 64 | print("=> MARS loaded") 65 | print("Dataset statistics:") 66 | print(" -----------------------------------------") 67 | print(" subset | # ids | # tracklets | # images") 68 | print(" -----------------------------------------") 69 | print(" train | {:5d} | {:8d} | {:8d}".format(num_train_pids, num_train_tracklets, np.sum(num_train_imgs))) 70 | print(" query | {:5d} | {:8d} | {:8d}".format(num_query_pids, num_query_tracklets, np.sum(num_query_imgs))) 71 | print(" gallery | {:5d} | {:8d} | {:8d}".format(num_gallery_pids, num_gallery_tracklets, np.sum(num_gallery_imgs))) 72 | print(" -----------------------------------------") 73 | print(" total | {:5d} | {:8d} | {:8d}".format(num_total_pids, num_total_tracklets, total_num)) 74 | print(" -----------------------------------------") 75 | print(" number of images per tracklet: {} ~ {}, average {:.1f}".format(min_num, max_num, avg_num)) 76 | print(" -----------------------------------------") 77 | 78 | self.train = train 79 | self.query 
= query 80 | self.gallery = gallery 81 | 82 | self.train_img = train_img 83 | self.query_img = query_img 84 | self.gallery_img = gallery_img 85 | 86 | self.num_train_pids, self.num_train_imgs, self.num_train_cams, self.num_train_vids = self.get_imagedata_info( 87 | self.train) 88 | self.num_query_pids, self.num_query_imgs, self.num_query_cams, self.num_query_vids = self.get_imagedata_info( 89 | self.query) 90 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams, self.num_gallery_vids = self.get_imagedata_info( 91 | self.gallery) 92 | 93 | def _check_before_run(self): 94 | """Check if all files are available before going deeper""" 95 | if not osp.exists(self.root): 96 | raise RuntimeError("'{}' is not available".format(self.root)) 97 | if not osp.exists(self.train_name_path): 98 | raise RuntimeError("'{}' is not available".format(self.train_name_path)) 99 | if not osp.exists(self.test_name_path): 100 | raise RuntimeError("'{}' is not available".format(self.test_name_path)) 101 | if not osp.exists(self.track_train_info_path): 102 | raise RuntimeError("'{}' is not available".format(self.track_train_info_path)) 103 | if not osp.exists(self.track_test_info_path): 104 | raise RuntimeError("'{}' is not available".format(self.track_test_info_path)) 105 | if not osp.exists(self.query_IDX_path): 106 | raise RuntimeError("'{}' is not available".format(self.query_IDX_path)) 107 | 108 | def _get_names(self, fpath): 109 | names = [] 110 | with open(fpath, 'r') as f: 111 | for line in f: 112 | new_line = line.rstrip() 113 | names.append(new_line) 114 | return names 115 | 116 | def _process_data(self, names, meta_data, home_dir=None, relabel=False, min_seq_len=0): 117 | assert home_dir in ['bbox_train', 'bbox_test'] 118 | num_tracklets = meta_data.shape[0] 119 | pid_list = list(set(meta_data[:,2].tolist())) 120 | num_pids = len(pid_list) 121 | 122 | if relabel: pid2label = {pid:label for label, pid in enumerate(pid_list)} 123 | tracklets = [] 124 | num_imgs_per_tracklet = [] 125 | 126 | for tracklet_idx in range(num_tracklets): 127 | data = meta_data[tracklet_idx,...] 128 | start_index, end_index, pid, camid = data 129 | if pid == -1: continue # junk images are just ignored 130 | assert 1 <= camid <= 6 131 | if relabel: pid = pid2label[pid] 132 | camid -= 1 # index starts from 0 133 | img_names = names[start_index-1:end_index] 134 | 135 | # make sure image names correspond to the same person 136 | pnames = [img_name[:4] for img_name in img_names] 137 | assert len(set(pnames)) == 1, "Error: a single tracklet contains different person images" 138 | 139 | # make sure all images are captured under the same camera 140 | camnames = [img_name[5] for img_name in img_names] 141 | assert len(set(camnames)) == 1, "Error: images are captured under different cameras!" 
142 | 143 | # append image names with directory information 144 | img_paths = [osp.join(self.root, home_dir, img_name[:4], img_name) for img_name in img_names] 145 | if len(img_paths) >= min_seq_len: 146 | img_paths = tuple(img_paths) 147 | tracklets.append((img_paths, pid, camid, 1)) 148 | num_imgs_per_tracklet.append(len(img_paths)) 149 | 150 | num_tracklets = len(tracklets) 151 | 152 | return tracklets, num_tracklets, num_pids, num_imgs_per_tracklet 153 | 154 | def _extract_1stfeame(self, names, meta_data, home_dir=None, relabel=False): 155 | assert home_dir in ['bbox_train', 'bbox_test'] 156 | num_tracklets = meta_data.shape[0] 157 | pid_list = list(set(meta_data[:,2].tolist())) 158 | num_pids = len(pid_list) 159 | 160 | if relabel: pid2label = {pid:label for label, pid in enumerate(pid_list)} 161 | imgs = [] 162 | 163 | for tracklet_idx in range(num_tracklets): 164 | data = meta_data[tracklet_idx,...] 165 | start_index, end_index, pid, camid = data 166 | if pid == -1: continue # junk images are just ignored 167 | assert 1 <= camid <= 6 168 | if relabel: pid = pid2label[pid] 169 | camid -= 1 # index starts from 0 170 | img_name = names[start_index-1] 171 | 172 | # append image names with directory information 173 | img_path = osp.join(self.root, home_dir, img_name[:4], img_name) 174 | 175 | imgs.append(([img_path], pid, camid)) 176 | 177 | num_imgs = len(imgs) 178 | 179 | return imgs, num_imgs, num_pids 180 | -------------------------------------------------------------------------------- /datasets/misc.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import os 3 | import argparse 4 | from argparse import Namespace 5 | import numpy as np 6 | from PIL import Image 7 | from torchvision import transforms as T 8 | 9 | from .temporal_transforms import TemporalChunkCrop, TemporalRandomFrames, RandomTemporalChunkCrop, MultiViewTemporalTransform 10 | 11 | def init_worker(worker_id): 12 | np.random.seed(1234 + worker_id) 13 | 14 | def str2bool(v): 15 | if isinstance(v, bool): 16 | return v 17 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 18 | return True 19 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 20 | return False 21 | else: 22 | raise argparse.ArgumentTypeError('Boolean value expected.') 23 | 24 | def pil_loader(path): 25 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 26 | with open(path, 'rb') as f: 27 | with Image.open(f) as img: 28 | return img.convert('RGB') 29 | 30 | 31 | def accimage_loader(path): 32 | try: 33 | import accimage 34 | return accimage.Image(path) 35 | except IOError: 36 | # Potentially a decoding problem, fall back to PIL.Image 37 | return pil_loader(path) 38 | 39 | 40 | def get_default_image_loader(): 41 | from torchvision import get_image_backend 42 | if get_image_backend() == 'accimage': 43 | return accimage_loader 44 | else: 45 | return pil_loader 46 | 47 | 48 | def image_loader(path): 49 | from torchvision import get_image_backend 50 | if get_image_backend() == 'accimage': 51 | return accimage_loader(path) 52 | else: 53 | return pil_loader(path) 54 | 55 | 56 | def video_loader(img_paths, image_loader): 57 | video = [] 58 | for image_path in img_paths: 59 | if os.path.exists(image_path): 60 | video.append(image_loader(image_path)) 61 | else: 62 | return video 63 | 64 | return video 65 | 66 | 67 | def get_default_video_loader(): 68 | image_loader = get_default_image_loader() 69 | return functools.partial(video_loader, 
image_loader=image_loader) 70 | 71 | 72 | def get_transforms(train_mode: bool, cfg: Namespace): 73 | 74 | mean, var = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225] 75 | 76 | input_res = { 77 | 'mars': (256, 128), 78 | 'ilids': (256, 128), 79 | }.get(cfg.DATASETS.NAMES, (224, 224)) 80 | 81 | erase_ratio = { 82 | 'mars': (0.3, 3.3), 83 | 'ilids': (0.3, 3.3), 84 | }.get(cfg.DATASETS.NAMES, (0.7, 1.4)) 85 | 86 | erase_scale = (0.02, 0.4) 87 | 88 | resize_operation = { 89 | 'mars': T.Resize(input_res, interpolation=3), 90 | 'ilids': T.Resize(input_res, interpolation=3), 91 | }.get(cfg.DATASETS.NAMES, AdaptiveResize(height=input_res[0], width=input_res[1])) 92 | 93 | if not train_mode: 94 | t_tr_test = TemporalChunkCrop(cfg.DATALOADER.NUM_TEST_IMAGES) 95 | 96 | s_tr_test = T.Compose([ 97 | resize_operation, 98 | T.ToTensor(), 99 | T.Normalize(mean, var) 100 | ]) 101 | return s_tr_test, t_tr_test 102 | 103 | tr_re = [T.RandomErasing(p=0.5, scale=erase_scale, ratio=erase_ratio)] \ 104 | if cfg.INPUT.RE else [] 105 | 106 | # Data augmentation 107 | s_tr_train = T.Compose([ 108 | resize_operation, 109 | T.Pad(10), 110 | T.RandomCrop(input_res), 111 | T.RandomHorizontalFlip(), 112 | T.ToTensor(), 113 | T.Normalize(mean, var) 114 | ] + tr_re) 115 | 116 | if cfg.MODEL.TRAIN_STRATEGY == 'random': 117 | t_tr_train = TemporalRandomFrames(cfg.DATALOADER.NUM_TRAIN_IMAGES) 118 | elif cfg.MODEL.TRAIN_STRATEGY == 'chunk': 119 | t_tr_train = RandomTemporalChunkCrop(cfg.DATALOADER.NUM_TRAIN_IMAGES) 120 | elif cfg.MODEL.TRAIN_STRATEGY == 'temporal': 121 | t_tr_train = TemporalChunkCrop(cfg.DATALOADER.NUM_TRAIN_IMAGES) 122 | elif cfg.MODEL.TRAIN_STRATEGY == 'multiview': 123 | t_tr_train = MultiViewTemporalTransform(cfg.DATALOADER.NUM_TRAIN_IMAGES) 124 | else: 125 | raise ValueError 126 | 127 | return s_tr_train, t_tr_train 128 | 129 | 130 | class AdaptiveResize: 131 | def __init__(self, width, height, interpolation=3): 132 | self.height = height 133 | self.width = width 134 | self.interpolation = interpolation 135 | 136 | @staticmethod 137 | def get_padding(padding): 138 | if padding == 0: 139 | p_1, p_2 = 0, 0 140 | elif padding % 2 == 0: 141 | p_1, p_2 = padding // 2, padding // 2 142 | else: 143 | p_1, p_2 = padding // 2 + 1, padding // 2 144 | return p_1, p_2 145 | 146 | def __call__(self, img: Image.Image): 147 | h, w = img.height, img.width 148 | # resize to ensure fit in target shape 149 | ratio_w = self.width / w 150 | ratio_h = self.height / h 151 | ratio = min(ratio_w, ratio_h) 152 | new_w, new_h = map(lambda x: int(np.floor(x * ratio)), (w, h)) 153 | img = img.resize((new_w, new_h), resample=self.interpolation) 154 | 155 | # compute padding 156 | h, w = img.height, img.width 157 | p_t, p_b = self.get_padding(self.height - h) 158 | p_l, p_r = self.get_padding(self.width - w) 159 | 160 | # copy into new buffer 161 | img = np.pad(np.asarray(img), ((p_t, p_b), (p_l, p_r), (0, 0)), mode='constant') 162 | 163 | return Image.fromarray(img) 164 | -------------------------------------------------------------------------------- /datasets/preprocessing.py: -------------------------------------------------------------------------------- 1 | import random 2 | import math 3 | 4 | 5 | class RandomErasing(object): 6 | """ Randomly selects a rectangle region in an image and erases its pixels. 7 | 'Random Erasing Data Augmentation' by Zhong et al. 8 | See https://arxiv.org/pdf/1708.04896.pdf 9 | Args: 10 | probability: The probability that the Random Erasing operation will be performed. 
11 | sl: Minimum proportion of erased area against input image. 12 | sh: Maximum proportion of erased area against input image. 13 | r1: Minimum aspect ratio of erased area. 14 | mean: Erasing value. 15 | """ 16 | 17 | def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=(0.4914, 0.4822, 0.4465)): 18 | self.probability = probability 19 | self.mean = mean 20 | self.sl = sl 21 | self.sh = sh 22 | self.r1 = r1 23 | 24 | def __call__(self, img): 25 | 26 | if random.uniform(0, 1) >= self.probability: 27 | return img 28 | 29 | for attempt in range(100): 30 | area = img.size()[1] * img.size()[2] 31 | 32 | target_area = random.uniform(self.sl, self.sh) * area 33 | aspect_ratio = random.uniform(self.r1, 1 / self.r1) 34 | 35 | h = int(round(math.sqrt(target_area * aspect_ratio))) 36 | w = int(round(math.sqrt(target_area / aspect_ratio))) 37 | 38 | if w < img.size()[2] and h < img.size()[1]: 39 | x1 = random.randint(0, img.size()[1] - h) 40 | y1 = random.randint(0, img.size()[2] - w) 41 | if img.size()[0] == 3: 42 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0] 43 | img[1, x1:x1 + h, y1:y1 + w] = self.mean[1] 44 | img[2, x1:x1 + h, y1:y1 + w] = self.mean[2] 45 | else: 46 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0] 47 | return img 48 | 49 | return img 50 | 51 | -------------------------------------------------------------------------------- /datasets/sampler.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data.sampler import Sampler 2 | from collections import defaultdict 3 | import copy 4 | import random 5 | import numpy as np 6 | from itertools import chain 7 | 8 | class RandomIdentitySampler(Sampler): 9 | """ 10 | Randomly sample N identities, then for each identity, 11 | randomly sample K instances, therefore batch size is N*K. 12 | Args: 13 | - data_source (list): list of (img_path, pid, camid). 14 | - num_instances (int): number of instances per identity in a batch. 15 | - batch_size (int): number of examples in a batch. 
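- Example (illustrative numbers): batch_size=64 with num_instances=4 yields 64 // 4 = 16 identities per batch.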
16 | """ 17 | 18 | def __init__(self, data_source, batch_size, num_instances): 19 | self.data_source = data_source 20 | self.batch_size = batch_size 21 | self.num_instances = num_instances 22 | self.num_pids_per_batch = self.batch_size // self.num_instances 23 | self.index_dic = defaultdict(list) #dict with list value 24 | #{783: [0, 5, 116, 876, 1554, 2041],...,} 25 | for index, (_, pid, _, _) in enumerate(self.data_source): 26 | self.index_dic[pid].append(index) 27 | self.pids = list(self.index_dic.keys()) 28 | 29 | # estimate number of examples in an epoch 30 | self.length = 0 31 | for pid in self.pids: 32 | idxs = self.index_dic[pid] 33 | num = len(idxs) 34 | if num < self.num_instances: 35 | num = self.num_instances 36 | self.length += num - num % self.num_instances 37 | 38 | def __iter__(self): 39 | batch_idxs_dict = defaultdict(list) 40 | 41 | for pid in self.pids: 42 | idxs = copy.deepcopy(self.index_dic[pid]) 43 | if len(idxs) < self.num_instances: 44 | idxs = np.random.choice(idxs, size=self.num_instances, replace=True) 45 | random.shuffle(idxs) 46 | batch_idxs = [] 47 | for idx in idxs: 48 | batch_idxs.append(idx) 49 | if len(batch_idxs) == self.num_instances: 50 | batch_idxs_dict[pid].append(batch_idxs) 51 | batch_idxs = [] 52 | 53 | avai_pids = copy.deepcopy(self.pids) 54 | final_idxs = [] 55 | 56 | while len(avai_pids) >= self.num_pids_per_batch: 57 | selected_pids = random.sample(avai_pids, self.num_pids_per_batch) 58 | for pid in selected_pids: 59 | batch_idxs = batch_idxs_dict[pid].pop(0) 60 | final_idxs.extend(batch_idxs) 61 | if len(batch_idxs_dict[pid]) == 0: 62 | avai_pids.remove(pid) 63 | 64 | return iter(final_idxs) 65 | 66 | def __len__(self): 67 | return self.length 68 | 69 | def compute_pids_and_pids_dict(data_source): 70 | 71 | index_dic = defaultdict(list) 72 | for index, (_, pid, _, _) in enumerate(data_source): 73 | index_dic[pid].append(index) 74 | pids = list(index_dic.keys()) 75 | return pids, index_dic 76 | 77 | 78 | class ReIDBatchSampler(Sampler): 79 | 80 | def __init__(self, data_source, p: int, k: int): 81 | 82 | self._p = p 83 | self._k = k 84 | 85 | pids, index_dic = compute_pids_and_pids_dict(data_source) 86 | 87 | self._unique_labels = np.array(pids) 88 | self._label_to_items = index_dic.copy() 89 | 90 | self._num_iterations = len(self._unique_labels) // self._p 91 | 92 | def __iter__(self): 93 | 94 | def sample(set, n): 95 | if len(set) < n: 96 | return np.random.choice(set, n, replace=True) 97 | return np.random.choice(set, n, replace=False) 98 | 99 | np.random.shuffle(self._unique_labels) 100 | 101 | for k, v in self._label_to_items.items(): 102 | random.shuffle(self._label_to_items[k]) 103 | 104 | curr_p = 0 105 | 106 | for idx in range(self._num_iterations): 107 | p_labels = self._unique_labels[curr_p: curr_p + self._p] 108 | curr_p += self._p 109 | batch = [sample(self._label_to_items[l], self._k) for l in p_labels] 110 | batch = list(chain(*batch)) 111 | yield batch 112 | 113 | def __len__(self): 114 | return self._num_iterations -------------------------------------------------------------------------------- /datasets/sampler_ddp.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data.sampler import Sampler 2 | from collections import defaultdict 3 | import copy 4 | import random 5 | import numpy as np 6 | import math 7 | import torch.distributed as dist 8 | _LOCAL_PROCESS_GROUP = None 9 | import torch 10 | import pickle 11 | 12 | def _get_global_gloo_group(): 13 | """ 14 | Return a 
process group based on gloo backend, containing all the ranks 15 | The result is cached. 16 | """ 17 | if dist.get_backend() == "nccl": 18 | return dist.new_group(backend="gloo") 19 | else: 20 | return dist.group.WORLD 21 | 22 | def _serialize_to_tensor(data, group): 23 | backend = dist.get_backend(group) 24 | assert backend in ["gloo", "nccl"] 25 | device = torch.device("cpu" if backend == "gloo" else "cuda") 26 | 27 | buffer = pickle.dumps(data) 28 | if len(buffer) > 1024 ** 3: 29 | print( 30 | "Rank {} trying to all-gather {:.2f} GB of data on device {}".format( 31 | dist.get_rank(), len(buffer) / (1024 ** 3), device 32 | ) 33 | ) 34 | storage = torch.ByteStorage.from_buffer(buffer) 35 | tensor = torch.ByteTensor(storage).to(device=device) 36 | return tensor 37 | 38 | def _pad_to_largest_tensor(tensor, group): 39 | """ 40 | Returns: 41 | list[int]: size of the tensor, on each rank 42 | Tensor: padded tensor that has the max size 43 | """ 44 | world_size = dist.get_world_size(group=group) 45 | assert ( 46 | world_size >= 1 47 | ), "comm.gather/all_gather must be called from ranks within the given group!" 48 | local_size = torch.tensor([tensor.numel()], dtype=torch.int64, device=tensor.device) 49 | size_list = [ 50 | torch.zeros([1], dtype=torch.int64, device=tensor.device) for _ in range(world_size) 51 | ] 52 | dist.all_gather(size_list, local_size, group=group) 53 | size_list = [int(size.item()) for size in size_list] 54 | 55 | max_size = max(size_list) 56 | 57 | # we pad the tensor because torch all_gather does not support 58 | # gathering tensors of different shapes 59 | if local_size != max_size: 60 | padding = torch.zeros((max_size - local_size,), dtype=torch.uint8, device=tensor.device) 61 | tensor = torch.cat((tensor, padding), dim=0) 62 | return size_list, tensor 63 | 64 | def all_gather(data, group=None): 65 | """ 66 | Run all_gather on arbitrary picklable data (not necessarily tensors). 67 | Args: 68 | data: any picklable object 69 | group: a torch process group. By default, will use a group which 70 | contains all ranks on gloo backend. 71 | Returns: 72 | list[data]: list of data gathered from each rank 73 | """ 74 | if dist.get_world_size() == 1: 75 | return [data] 76 | if group is None: 77 | group = _get_global_gloo_group() 78 | if dist.get_world_size(group) == 1: 79 | return [data] 80 | 81 | tensor = _serialize_to_tensor(data, group) 82 | 83 | size_list, tensor = _pad_to_largest_tensor(tensor, group) 84 | max_size = max(size_list) 85 | 86 | # receiving Tensor from all ranks 87 | tensor_list = [ 88 | torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list 89 | ] 90 | dist.all_gather(tensor_list, tensor, group=group) 91 | 92 | data_list = [] 93 | for size, tensor in zip(size_list, tensor_list): 94 | buffer = tensor.cpu().numpy().tobytes()[:size] 95 | data_list.append(pickle.loads(buffer)) 96 | 97 | return data_list 98 | 99 | def shared_random_seed(): 100 | """ 101 | Returns: 102 | int: a random number that is the same across all workers. 103 | If workers need a shared RNG, they can use this shared seed to 104 | create one. 105 | All workers must call this function, otherwise it will deadlock. 106 | """ 107 | ints = np.random.randint(2 ** 31) 108 | all_ints = all_gather(ints) 109 | return all_ints[0] 110 | 111 | class RandomIdentitySampler_DDP(Sampler): 112 | """ 113 | Randomly sample N identities, then for each identity, 114 | randomly sample K instances, therefore batch size is N*K. 
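Under DDP, batch_size here is the global batch: each rank draws batch_size // world_size samples per step, grouped as num_instances consecutive samples per identity.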
115 | Args: 116 | - data_source (list): list of (img_path, pid, camid). 117 | - num_instances (int): number of instances per identity in a batch. 118 | - batch_size (int): number of examples in a batch. 119 | """ 120 | 121 | def __init__(self, data_source, batch_size, num_instances): 122 | self.data_source = data_source 123 | self.batch_size = batch_size 124 | self.world_size = dist.get_world_size() 125 | self.num_instances = num_instances 126 | self.mini_batch_size = self.batch_size // self.world_size 127 | self.num_pids_per_batch = self.mini_batch_size // self.num_instances 128 | self.index_dic = defaultdict(list) 129 | 130 | for index, (_, pid, _, _) in enumerate(self.data_source): 131 | self.index_dic[pid].append(index) 132 | self.pids = list(self.index_dic.keys()) 133 | 134 | # estimate number of examples in an epoch 135 | self.length = 0 136 | for pid in self.pids: 137 | idxs = self.index_dic[pid] 138 | num = len(idxs) 139 | if num < self.num_instances: 140 | num = self.num_instances 141 | self.length += num - num % self.num_instances 142 | 143 | self.rank = dist.get_rank() 144 | #self.world_size = dist.get_world_size() 145 | self.length //= self.world_size 146 | 147 | def __iter__(self): 148 | seed = shared_random_seed() 149 | np.random.seed(seed) 150 | self._seed = int(seed) 151 | final_idxs = self.sample_list() 152 | length = int(math.ceil(len(final_idxs) * 1.0 / self.world_size)) 153 | #final_idxs = final_idxs[self.rank * length:(self.rank + 1) * length] 154 | final_idxs = self.__fetch_current_node_idxs(final_idxs, length) 155 | self.length = len(final_idxs) 156 | return iter(final_idxs) 157 | 158 | 159 | def __fetch_current_node_idxs(self, final_idxs, length): 160 | total_num = len(final_idxs) 161 | block_num = (length // self.mini_batch_size) 162 | index_target = [] 163 | for i in range(0, block_num * self.world_size, self.world_size): 164 | index = range(self.mini_batch_size * self.rank + self.mini_batch_size * i, min(self.mini_batch_size * self.rank + self.mini_batch_size * (i+1), total_num)) 165 | index_target.extend(index) 166 | index_target_npy = np.array(index_target) 167 | final_idxs = list(np.array(final_idxs)[index_target_npy]) 168 | return final_idxs 169 | 170 | 171 | def sample_list(self): 172 | #np.random.seed(self._seed) 173 | avai_pids = copy.deepcopy(self.pids) 174 | batch_idxs_dict = {} 175 | 176 | batch_indices = [] 177 | while len(avai_pids) >= self.num_pids_per_batch: 178 | selected_pids = np.random.choice(avai_pids, self.num_pids_per_batch, replace=False).tolist() 179 | for pid in selected_pids: 180 | if pid not in batch_idxs_dict: 181 | idxs = copy.deepcopy(self.index_dic[pid]) 182 | if len(idxs) < self.num_instances: 183 | idxs = np.random.choice(idxs, size=self.num_instances, replace=True).tolist() 184 | np.random.shuffle(idxs) 185 | batch_idxs_dict[pid] = idxs 186 | 187 | avai_idxs = batch_idxs_dict[pid] 188 | for _ in range(self.num_instances): 189 | batch_indices.append(avai_idxs.pop(0)) 190 | 191 | if len(avai_idxs) < self.num_instances: avai_pids.remove(pid) 192 | 193 | return batch_indices 194 | 195 | def __len__(self): 196 | return self.length 197 | 198 | -------------------------------------------------------------------------------- /datasets/temporal_transforms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import math 3 | 4 | 5 | class TemporalChunkCrop(object): 6 | 7 | def __init__(self, size: int = 4): 8 | self.S = size 9 | 10 | def __call__(self, frame_indices, tracklet_index): 
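        # Deterministic chunk sampling: split the tracklet into self.S equal chunks
        # (padding with the last frame when it is shorter than S) and keep the first
        # frame index of every chunk, so exactly S indices are always returned.
        # e.g. F=10, S=4 -> interval=3, chunks [0,1,2][3,4,5][6,7,8][9,9,9] -> [0, 3, 6, 9]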
11 | sample_clip = [] 12 | F = len(frame_indices) 13 | if F < self.S: 14 | strip = list(range(0, F)) + [F-1] * (self.S - F) 15 | for s in range(self.S): 16 | pool = strip[s * 1:(s + 1) * 1] 17 | sample_clip.append(list(pool)) 18 | else: 19 | interval = math.ceil(F / self.S) 20 | strip = list(range(0, F)) + [F-1] * (interval * self.S - F) 21 | for s in range(self.S): 22 | pool = strip[s * interval:(s + 1) * interval] 23 | sample_clip.append(list(pool)) 24 | return [ frame_indices[idx] for idx 25 | in np.array(sample_clip)[:, 0].tolist() ] 26 | 27 | 28 | class RandomTemporalChunkCrop(object): 29 | 30 | def __init__(self, size: int = 4): 31 | self.S = size 32 | 33 | def __call__(self, frame_indices, tracklet_index): 34 | sample_clip = [] 35 | F = len(frame_indices) 36 | if F < self.S: 37 | strip = list(range(0, F)) + [F-1] * (self.S - F) 38 | for s in range(self.S): 39 | pool = strip[s * 1:(s + 1) * 1] 40 | sample_clip.append(list(pool)) 41 | else: 42 | interval = math.ceil(F / self.S) 43 | strip = list(range(0, F)) + [F-1] * (interval * self.S - F) 44 | for s in range(self.S): 45 | pool = strip[s * interval:(s + 1) * interval] 46 | sample_clip.append(list(pool)) 47 | 48 | sample_clip = np.array(sample_clip) 49 | sample_clip = sample_clip[np.arange(self.S), 50 | np.random.randint(0, sample_clip.shape[1], self.S)] 51 | return [ frame_indices[idx] for idx in sample_clip ] 52 | 53 | 54 | class MultiViewTemporalTransform(object): 55 | 56 | def __init__(self, size: int = 4): 57 | self.size = size 58 | 59 | def __call__(self, candidate, tracklet_index): 60 | img_paths = [] 61 | candidate_perm = np.random.permutation(len(candidate)) 62 | for idx in range(self.size): 63 | cur_tracklet = candidate_perm[idx % len(candidate_perm)] 64 | cur_frame = np.random.randint(0, len(candidate[cur_tracklet][0])) 65 | cur_img_path = candidate[cur_tracklet][0][cur_frame] 66 | img_paths.append(cur_img_path) 67 | return img_paths 68 | 69 | 70 | class TemporalRandomFrames(object): 71 | """ 72 | Get size random frames (without replacement if possible) from a video 73 | """ 74 | 75 | def __init__(self, num_images=4): 76 | self.num_images = num_images 77 | 78 | def __call__(self, frame_indices, tracklet_index): 79 | frame_indices = list(frame_indices) 80 | if len(frame_indices) < self.num_images: 81 | return list(np.random.choice(frame_indices, size=self.num_images, replace=True)) 82 | 83 | return list(np.random.choice(frame_indices, size=self.num_images, replace=False)) 84 | -------------------------------------------------------------------------------- /framework.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deropty/PiT/779d3110808929f1f3d8c4e0ec6ec3f7f60dbffa/framework.jpg -------------------------------------------------------------------------------- /loss/__init__.py: -------------------------------------------------------------------------------- 1 | from .make_loss import make_loss 2 | from .arcface import ArcFace -------------------------------------------------------------------------------- /loss/arcface.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn import Parameter 5 | import math 6 | 7 | 8 | class ArcFace(nn.Module): 9 | def __init__(self, in_features, out_features, s=30.0, m=0.50, bias=False): 10 | super(ArcFace, self).__init__() 11 | self.in_features = in_features 12 | self.out_features = out_features 
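        # s scales the cosine logits and m is the additive angular margin (in radians);
        # cos(m) and sin(m) are cached so forward() can form cos(theta + m) via the
        # angle-addition identity, while (th, mm) handle the case where theta + m > pi.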
13 | self.s = s 14 | self.m = m 15 | self.cos_m = math.cos(m) 16 | self.sin_m = math.sin(m) 17 | 18 | self.th = math.cos(math.pi - m) 19 | self.mm = math.sin(math.pi - m) * m 20 | 21 | self.weight = Parameter(torch.Tensor(out_features, in_features)) 22 | if bias: 23 | self.bias = Parameter(torch.Tensor(out_features)) 24 | else: 25 | self.register_parameter('bias', None) 26 | self.reset_parameters() 27 | 28 | def reset_parameters(self): 29 | nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) 30 | if self.bias is not None: 31 | fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) 32 | bound = 1 / math.sqrt(fan_in) 33 | nn.init.uniform_(self.bias, -bound, bound) 34 | 35 | def forward(self, input, label): 36 | cosine = F.linear(F.normalize(input), F.normalize(self.weight)) 37 | sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(0, 1)) 38 | phi = cosine * self.cos_m - sine * self.sin_m 39 | phi = torch.where(cosine > self.th, phi, cosine - self.mm) 40 | # --------------------------- convert label to one-hot --------------------------- 41 | # one_hot = torch.zeros(cosine.size(), requires_grad=True, device='cuda') 42 | one_hot = torch.zeros(cosine.size(), device='cuda') 43 | one_hot.scatter_(1, label.view(-1, 1).long(), 1) 44 | # -------------torch.where(out_i = {x_i if condition_i else y_i) ------------- 45 | output = (one_hot * phi) + ( 46 | (1.0 - one_hot) * cosine) # you can use torch.where if your torch.__version__ is 0.4 47 | output *= self.s 48 | # print(output) 49 | 50 | return output 51 | 52 | class CircleLoss(nn.Module): 53 | def __init__(self, in_features, num_classes, s=256, m=0.25): 54 | super(CircleLoss, self).__init__() 55 | self.weight = Parameter(torch.Tensor(num_classes, in_features)) 56 | self.s = s 57 | self.m = m 58 | self._num_classes = num_classes 59 | self.reset_parameters() 60 | 61 | 62 | def reset_parameters(self): 63 | nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) 64 | 65 | def __call__(self, bn_feat, targets): 66 | 67 | sim_mat = F.linear(F.normalize(bn_feat), F.normalize(self.weight)) 68 | alpha_p = torch.clamp_min(-sim_mat.detach() + 1 + self.m, min=0.) 69 | alpha_n = torch.clamp_min(sim_mat.detach() + self.m, min=0.) 70 | delta_p = 1 - self.m 71 | delta_n = self.m 72 | 73 | s_p = self.s * alpha_p * (sim_mat - delta_p) 74 | s_n = self.s * alpha_n * (sim_mat - delta_n) 75 | 76 | targets = F.one_hot(targets, num_classes=self._num_classes) 77 | 78 | pred_class_logits = targets * s_p + (1.0 - targets) * s_n 79 | 80 | return pred_class_logits -------------------------------------------------------------------------------- /loss/center_loss.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | from torch import nn 5 | 6 | 7 | class CenterLoss(nn.Module): 8 | """Center loss. 9 | 10 | Reference: 11 | Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016. 12 | 13 | Args: 14 | num_classes (int): number of classes. 15 | feat_dim (int): feature dimension. 
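        use_gpu (bool): whether the learnable class centers are created on GPU.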
16 | """ 17 | 18 | def __init__(self, num_classes=751, feat_dim=2048, use_gpu=True): 19 | super(CenterLoss, self).__init__() 20 | self.num_classes = num_classes 21 | self.feat_dim = feat_dim 22 | self.use_gpu = use_gpu 23 | 24 | if self.use_gpu: 25 | self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim).cuda()) 26 | else: 27 | self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim)) 28 | 29 | def forward(self, x, labels): 30 | """ 31 | Args: 32 | x: feature matrix with shape (batch_size, feat_dim). 33 | labels: ground truth labels with shape (num_classes). 34 | """ 35 | assert x.size(0) == labels.size(0), "features.size(0) is not equal to labels.size(0)" 36 | 37 | batch_size = x.size(0) 38 | distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \ 39 | torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t() 40 | distmat.addmm_(1, -2, x, self.centers.t()) 41 | 42 | classes = torch.arange(self.num_classes).long() 43 | if self.use_gpu: classes = classes.cuda() 44 | labels = labels.unsqueeze(1).expand(batch_size, self.num_classes) 45 | mask = labels.eq(classes.expand(batch_size, self.num_classes)) 46 | 47 | dist = [] 48 | for i in range(batch_size): 49 | value = distmat[i][mask[i]] 50 | value = value.clamp(min=1e-12, max=1e+12) # for numerical stability 51 | dist.append(value) 52 | dist = torch.cat(dist) 53 | loss = dist.mean() 54 | return loss 55 | 56 | 57 | if __name__ == '__main__': 58 | use_gpu = False 59 | center_loss = CenterLoss(use_gpu=use_gpu) 60 | features = torch.rand(16, 2048) 61 | targets = torch.Tensor([0, 1, 2, 3, 2, 3, 1, 4, 5, 3, 2, 1, 0, 0, 5, 4]).long() 62 | if use_gpu: 63 | features = torch.rand(16, 2048).cuda() 64 | targets = torch.Tensor([0, 1, 2, 3, 2, 3, 1, 4, 5, 3, 2, 1, 0, 0, 5, 4]).cuda() 65 | 66 | loss = center_loss(features, targets) 67 | print(loss) 68 | -------------------------------------------------------------------------------- /loss/make_loss.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import torch.nn.functional as F 8 | from .softmax_loss import CrossEntropyLabelSmooth, LabelSmoothingCrossEntropy 9 | from .triplet_loss import TripletLoss 10 | from .center_loss import CenterLoss 11 | 12 | 13 | def make_loss(cfg, num_classes): # modified by gu 14 | sampler = cfg.DATALOADER.SAMPLER 15 | feat_dim = 2048 16 | center_criterion = CenterLoss(num_classes=num_classes, feat_dim=feat_dim, use_gpu=True) # center loss 17 | if 'triplet' in cfg.MODEL.METRIC_LOSS_TYPE: 18 | if cfg.MODEL.NO_MARGIN: 19 | triplet = TripletLoss() 20 | print("using soft triplet loss for training") 21 | else: 22 | triplet = TripletLoss(cfg.SOLVER.MARGIN) # triplet loss 23 | print("using triplet loss with margin:{}".format(cfg.SOLVER.MARGIN)) 24 | else: 25 | print('expected METRIC_LOSS_TYPE should be triplet' 26 | 'but got {}'.format(cfg.MODEL.METRIC_LOSS_TYPE)) 27 | 28 | if cfg.MODEL.IF_LABELSMOOTH == 'on': 29 | xent = CrossEntropyLabelSmooth(num_classes=num_classes) 30 | print("label smooth on, numclasses:", num_classes) 31 | 32 | if sampler == 'softmax': 33 | def loss_func(score, feat, target): 34 | return F.cross_entropy(score, target) 35 | 36 | elif cfg.DATALOADER.SAMPLER == 'softmax_triplet': 37 | def loss_func(score, feat, target, target_cam): 38 | if cfg.MODEL.METRIC_LOSS_TYPE == 'triplet': 39 | if cfg.MODEL.IF_LABELSMOOTH == 'on': 40 
| if isinstance(score, list): 41 | ID_LOSS = [xent(scor, target) for scor in score[1:]] 42 | ID_LOSS = sum(ID_LOSS) / len(ID_LOSS) 43 | ID_LOSS = 0.5 * ID_LOSS + 0.5 * xent(score[0], target) 44 | else: 45 | ID_LOSS = xent(score, target) 46 | 47 | if isinstance(feat, list): 48 | TRI_LOSS = [triplet(feats, target)[0] for feats in feat[1:]] 49 | TRI_LOSS = sum(TRI_LOSS) / len(TRI_LOSS) 50 | TRI_LOSS = 0.5 * TRI_LOSS + 0.5 * triplet(feat[0], target)[0] 51 | else: 52 | TRI_LOSS = triplet(feat, target)[0] 53 | 54 | return cfg.MODEL.ID_LOSS_WEIGHT * ID_LOSS + \ 55 | cfg.MODEL.TRIPLET_LOSS_WEIGHT * TRI_LOSS 56 | else: 57 | if isinstance(score, list): 58 | ID_LOSS = sum([sum([F.cross_entropy(s, target) 59 | for s in scor]) / len(scor) 60 | for scor in score]) / len(score) 61 | else: 62 | ID_LOSS = F.cross_entropy(score, target) 63 | 64 | if isinstance(feat, list): 65 | TRI_LOSS = sum([sum([triplet(f, target)[0] 66 | for f in fea]) / len(fea) 67 | for fea in feat]) / len(feat) 68 | else: 69 | TRI_LOSS = triplet(feat, target)[0] 70 | 71 | return ID_LOSS, TRI_LOSS 72 | else: 73 | print('expected METRIC_LOSS_TYPE should be triplet' 74 | 'but got {}'.format(cfg.MODEL.METRIC_LOSS_TYPE)) 75 | 76 | else: 77 | print('expected sampler should be softmax, triplet, softmax_triplet or softmax_triplet_center' 78 | 'but got {}'.format(cfg.DATALOADER.SAMPLER)) 79 | return loss_func, center_criterion 80 | 81 | 82 | -------------------------------------------------------------------------------- /loss/metric_learning.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.autograd 5 | from torch.nn import Parameter 6 | import math 7 | 8 | 9 | class ContrastiveLoss(nn.Module): 10 | def __init__(self, margin=0.3, **kwargs): 11 | super(ContrastiveLoss, self).__init__() 12 | self.margin = margin 13 | 14 | def forward(self, inputs, targets): 15 | n = inputs.size(0) 16 | # Compute similarity matrix 17 | sim_mat = torch.matmul(inputs, inputs.t()) 18 | targets = targets 19 | loss = list() 20 | c = 0 21 | 22 | for i in range(n): 23 | pos_pair_ = torch.masked_select(sim_mat[i], targets == targets[i]) 24 | 25 | # move itself 26 | pos_pair_ = torch.masked_select(pos_pair_, pos_pair_ < 1) 27 | neg_pair_ = torch.masked_select(sim_mat[i], targets != targets[i]) 28 | 29 | pos_pair_ = torch.sort(pos_pair_)[0] 30 | neg_pair_ = torch.sort(neg_pair_)[0] 31 | 32 | neg_pair = torch.masked_select(neg_pair_, neg_pair_ > self.margin) 33 | 34 | neg_loss = 0 35 | 36 | pos_loss = torch.sum(-pos_pair_ + 1) 37 | if len(neg_pair) > 0: 38 | neg_loss = torch.sum(neg_pair) 39 | loss.append(pos_loss + neg_loss) 40 | 41 | loss = sum(loss) / n 42 | return loss 43 | 44 | 45 | class CircleLoss(nn.Module): 46 | def __init__(self, in_features, num_classes, s=256, m=0.25): 47 | super(CircleLoss, self).__init__() 48 | self.weight = Parameter(torch.Tensor(num_classes, in_features)) 49 | self.s = s 50 | self.m = m 51 | self._num_classes = num_classes 52 | self.reset_parameters() 53 | 54 | 55 | def reset_parameters(self): 56 | nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) 57 | 58 | def __call__(self, bn_feat, targets): 59 | 60 | sim_mat = F.linear(F.normalize(bn_feat), F.normalize(self.weight)) 61 | alpha_p = torch.clamp_min(-sim_mat.detach() + 1 + self.m, min=0.) 62 | alpha_n = torch.clamp_min(sim_mat.detach() + self.m, min=0.) 
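        # alpha_p / alpha_n are the per-pair adaptive weights of Circle Loss: the further a
        # similarity is from its optimum (delta_p = 1 - m for positives, delta_n = m for
        # negatives), the larger its weight; self.s scales the resulting logits.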
63 | delta_p = 1 - self.m 64 | delta_n = self.m 65 | 66 | s_p = self.s * alpha_p * (sim_mat - delta_p) 67 | s_n = self.s * alpha_n * (sim_mat - delta_n) 68 | 69 | targets = F.one_hot(targets, num_classes=self._num_classes) 70 | 71 | pred_class_logits = targets * s_p + (1.0 - targets) * s_n 72 | 73 | return pred_class_logits 74 | 75 | 76 | class Arcface(nn.Module): 77 | r"""Implement of large margin arc distance: : 78 | Args: 79 | in_features: size of each input sample 80 | out_features: size of each output sample 81 | s: norm of input feature 82 | m: margin 83 | cos(theta + m) 84 | """ 85 | def __init__(self, in_features, out_features, s=30.0, m=0.30, easy_margin=False, ls_eps=0.0): 86 | super(Arcface, self).__init__() 87 | self.in_features = in_features 88 | self.out_features = out_features 89 | self.s = s 90 | self.m = m 91 | self.ls_eps = ls_eps # label smoothing 92 | self.weight = Parameter(torch.FloatTensor(out_features, in_features)) 93 | nn.init.xavier_uniform_(self.weight) 94 | 95 | self.easy_margin = easy_margin 96 | self.cos_m = math.cos(m) 97 | self.sin_m = math.sin(m) 98 | self.th = math.cos(math.pi - m) 99 | self.mm = math.sin(math.pi - m) * m 100 | 101 | def forward(self, input, label): 102 | # --------------------------- cos(theta) & phi(theta) --------------------------- 103 | cosine = F.linear(F.normalize(input), F.normalize(self.weight)) 104 | sine = torch.sqrt(1.0 - torch.pow(cosine, 2)) 105 | phi = cosine * self.cos_m - sine * self.sin_m 106 | phi = phi.type_as(cosine) 107 | if self.easy_margin: 108 | phi = torch.where(cosine > 0, phi, cosine) 109 | else: 110 | phi = torch.where(cosine > self.th, phi, cosine - self.mm) 111 | # --------------------------- convert label to one-hot --------------------------- 112 | # one_hot = torch.zeros(cosine.size(), requires_grad=True, device='cuda') 113 | one_hot = torch.zeros(cosine.size(), device='cuda') 114 | one_hot.scatter_(1, label.view(-1, 1).long(), 1) 115 | if self.ls_eps > 0: 116 | one_hot = (1 - self.ls_eps) * one_hot + self.ls_eps / self.out_features 117 | # -------------torch.where(out_i = {x_i if condition_i else y_i) ------------- 118 | output = (one_hot * phi) + ((1.0 - one_hot) * cosine) 119 | output *= self.s 120 | 121 | return output 122 | 123 | 124 | class Cosface(nn.Module): 125 | r"""Implement of large margin cosine distance: : 126 | Args: 127 | in_features: size of each input sample 128 | out_features: size of each output sample 129 | s: norm of input feature 130 | m: margin 131 | cos(theta) - m 132 | """ 133 | 134 | def __init__(self, in_features, out_features, s=30.0, m=0.30): 135 | super(Cosface, self).__init__() 136 | self.in_features = in_features 137 | self.out_features = out_features 138 | self.s = s 139 | self.m = m 140 | self.weight = Parameter(torch.FloatTensor(out_features, in_features)) 141 | nn.init.xavier_uniform_(self.weight) 142 | 143 | def forward(self, input, label): 144 | # --------------------------- cos(theta) & phi(theta) --------------------------- 145 | cosine = F.linear(F.normalize(input), F.normalize(self.weight)) 146 | phi = cosine - self.m 147 | # --------------------------- convert label to one-hot --------------------------- 148 | one_hot = torch.zeros(cosine.size(), device='cuda') 149 | # one_hot = one_hot.cuda() if cosine.is_cuda else one_hot 150 | one_hot.scatter_(1, label.view(-1, 1).long(), 1) 151 | # -------------torch.where(out_i = {x_i if condition_i else y_i) ------------- 152 | output = (one_hot * phi) + ((1.0 - one_hot) * cosine) # you can use torch.where if your 
torch.__version__ is 0.4 153 | output *= self.s 154 | # print(output) 155 | 156 | return output 157 | 158 | def __repr__(self): 159 | return self.__class__.__name__ + '(' \ 160 | + 'in_features=' + str(self.in_features) \ 161 | + ', out_features=' + str(self.out_features) \ 162 | + ', s=' + str(self.s) \ 163 | + ', m=' + str(self.m) + ')' 164 | 165 | 166 | class AMSoftmax(nn.Module): 167 | def __init__(self, in_features, out_features, s=30.0, m=0.30): 168 | super(AMSoftmax, self).__init__() 169 | self.m = m 170 | self.s = s 171 | self.in_feats = in_features 172 | self.W = torch.nn.Parameter(torch.randn(in_features, out_features), requires_grad=True) 173 | self.ce = nn.CrossEntropyLoss() 174 | nn.init.xavier_normal_(self.W, gain=1) 175 | 176 | def forward(self, x, lb): 177 | assert x.size()[0] == lb.size()[0] 178 | assert x.size()[1] == self.in_feats 179 | x_norm = torch.norm(x, p=2, dim=1, keepdim=True).clamp(min=1e-12) 180 | x_norm = torch.div(x, x_norm) 181 | w_norm = torch.norm(self.W, p=2, dim=0, keepdim=True).clamp(min=1e-12) 182 | w_norm = torch.div(self.W, w_norm) 183 | costh = torch.mm(x_norm, w_norm) 184 | # print(x_norm.shape, w_norm.shape, costh.shape) 185 | lb_view = lb.view(-1, 1) 186 | delt_costh = torch.zeros(costh.size(), device='cuda').scatter_(1, lb_view, self.m) 187 | costh_m = costh - delt_costh 188 | costh_m_s = self.s * costh_m 189 | return costh_m_s -------------------------------------------------------------------------------- /loss/softmax_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | class CrossEntropyLabelSmooth(nn.Module): 5 | """Cross entropy loss with label smoothing regularizer. 6 | 7 | Reference: 8 | Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016. 9 | Equation: y = (1 - epsilon) * y + epsilon / K. 10 | 11 | Args: 12 | num_classes (int): number of classes. 13 | epsilon (float): weight. 14 | """ 15 | 16 | def __init__(self, num_classes, epsilon=0.1, use_gpu=True): 17 | super(CrossEntropyLabelSmooth, self).__init__() 18 | self.num_classes = num_classes 19 | self.epsilon = epsilon 20 | self.use_gpu = use_gpu 21 | self.logsoftmax = nn.LogSoftmax(dim=1) 22 | 23 | def forward(self, inputs, targets): 24 | """ 25 | Args: 26 | inputs: prediction matrix (before softmax) with shape (batch_size, num_classes) 27 | targets: ground truth labels with shape (num_classes) 28 | """ 29 | log_probs = self.logsoftmax(inputs) 30 | targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).data.cpu(), 1) 31 | if self.use_gpu: targets = targets.cuda() 32 | targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes 33 | loss = (- targets * log_probs).mean(0).sum() 34 | return loss 35 | 36 | class LabelSmoothingCrossEntropy(nn.Module): 37 | """ 38 | NLL loss with label smoothing. 39 | """ 40 | def __init__(self, smoothing=0.1): 41 | """ 42 | Constructor for the LabelSmoothing module. 43 | :param smoothing: label smoothing factor 44 | """ 45 | super(LabelSmoothingCrossEntropy, self).__init__() 46 | assert smoothing < 1.0 47 | self.smoothing = smoothing 48 | self.confidence = 1. 
- smoothing 49 | 50 | def forward(self, x, target): 51 | logprobs = F.log_softmax(x, dim=-1) 52 | nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1)) 53 | nll_loss = nll_loss.squeeze(1) 54 | smooth_loss = -logprobs.mean(dim=-1) 55 | loss = self.confidence * nll_loss + self.smoothing * smooth_loss 56 | return loss.mean() -------------------------------------------------------------------------------- /loss/triplet_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | def normalize(x, axis=-1): 6 | """Normalizing to unit length along the specified dimension. 7 | Args: 8 | x: pytorch Variable 9 | Returns: 10 | x: pytorch Variable, same shape as input 11 | """ 12 | x = 1. * x / (torch.norm(x, 2, axis, keepdim=True).expand_as(x) + 1e-12) 13 | return x 14 | 15 | 16 | def euclidean_dist(x, y): 17 | """ 18 | Args: 19 | x: pytorch Variable, with shape [m, d] 20 | y: pytorch Variable, with shape [n, d] 21 | Returns: 22 | dist: pytorch Variable, with shape [m, n] 23 | """ 24 | m, n = x.size(0), y.size(0) 25 | xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n) 26 | yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t() 27 | dist = xx + yy 28 | dist = dist - 2 * torch.matmul(x, y.t()) 29 | # dist.addmm_(1, -2, x, y.t()) 30 | dist = dist.clamp(min=1e-12).sqrt() # for numerical stability 31 | return dist 32 | 33 | 34 | def cosine_dist(x, y): 35 | """ 36 | Args: 37 | x: pytorch Variable, with shape [m, d] 38 | y: pytorch Variable, with shape [n, d] 39 | Returns: 40 | dist: pytorch Variable, with shape [m, n] 41 | """ 42 | m, n = x.size(0), y.size(0) 43 | x_norm = torch.pow(x, 2).sum(1, keepdim=True).sqrt().expand(m, n) 44 | y_norm = torch.pow(y, 2).sum(1, keepdim=True).sqrt().expand(n, m).t() 45 | xy_intersection = torch.mm(x, y.t()) 46 | dist = xy_intersection/(x_norm * y_norm) 47 | dist = (1. - dist) / 2 48 | return dist 49 | 50 | 51 | def hard_example_mining(dist_mat, labels, return_inds=False): 52 | """For each anchor, find the hardest positive and negative sample. 53 | Args: 54 | dist_mat: pytorch Variable, pair wise distance between samples, shape [N, N] 55 | labels: pytorch LongTensor, with shape [N] 56 | return_inds: whether to return the indices. Save time if `False`(?) 57 | Returns: 58 | dist_ap: pytorch Variable, distance(anchor, positive); shape [N] 59 | dist_an: pytorch Variable, distance(anchor, negative); shape [N] 60 | p_inds: pytorch LongTensor, with shape [N]; 61 | indices of selected hard positive samples; 0 <= p_inds[i] <= N - 1 62 | n_inds: pytorch LongTensor, with shape [N]; 63 | indices of selected hard negative samples; 0 <= n_inds[i] <= N - 1 64 | NOTE: Only consider the case in which all labels have same num of samples, 65 | thus we can cope with all anchors in parallel. 
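    With a P x K identity sampler this assumption holds: every anchor sees exactly K positives
    (itself included) and N - K negatives, so the masked distances reshape cleanly to [N, -1].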
66 | """ 67 | 68 | assert len(dist_mat.size()) == 2 69 | assert dist_mat.size(0) == dist_mat.size(1) 70 | N = dist_mat.size(0) 71 | 72 | # shape [N, N] 73 | is_pos = labels.expand(N, N).eq(labels.expand(N, N).t()) 74 | is_neg = labels.expand(N, N).ne(labels.expand(N, N).t()) 75 | 76 | # `dist_ap` means distance(anchor, positive) 77 | # both `dist_ap` and `relative_p_inds` with shape [N, 1] 78 | dist_ap, relative_p_inds = torch.max( 79 | dist_mat[is_pos].contiguous().view(N, -1), 1, keepdim=True) 80 | # print(dist_mat[is_pos].shape) 81 | # `dist_an` means distance(anchor, negative) 82 | # both `dist_an` and `relative_n_inds` with shape [N, 1] 83 | dist_an, relative_n_inds = torch.min( 84 | dist_mat[is_neg].contiguous().view(N, -1), 1, keepdim=True) 85 | # shape [N] 86 | dist_ap = dist_ap.squeeze(1) 87 | dist_an = dist_an.squeeze(1) 88 | 89 | if return_inds: 90 | # shape [N, N] 91 | ind = (labels.new().resize_as_(labels) 92 | .copy_(torch.arange(0, N).long()) 93 | .unsqueeze(0).expand(N, N)) 94 | # shape [N, 1] 95 | p_inds = torch.gather( 96 | ind[is_pos].contiguous().view(N, -1), 1, relative_p_inds.data) 97 | n_inds = torch.gather( 98 | ind[is_neg].contiguous().view(N, -1), 1, relative_n_inds.data) 99 | # shape [N] 100 | p_inds = p_inds.squeeze(1) 101 | n_inds = n_inds.squeeze(1) 102 | return dist_ap, dist_an, p_inds, n_inds 103 | 104 | return dist_ap, dist_an 105 | 106 | 107 | class TripletLoss(object): 108 | """ 109 | Triplet loss using HARDER example mining, 110 | modified based on original triplet loss using hard example mining 111 | """ 112 | 113 | def __init__(self, margin=None, hard_factor=0.0): 114 | self.margin = margin 115 | self.hard_factor = hard_factor 116 | if margin is not None: 117 | self.ranking_loss = nn.MarginRankingLoss(margin=margin) 118 | else: 119 | self.ranking_loss = nn.SoftMarginLoss() 120 | 121 | def __call__(self, global_feat, labels, normalize_feature=False): 122 | if normalize_feature: 123 | global_feat = normalize(global_feat, axis=-1) 124 | dist_mat = euclidean_dist(global_feat, global_feat) 125 | dist_ap, dist_an = hard_example_mining(dist_mat, labels) 126 | 127 | dist_ap *= (1.0 + self.hard_factor) 128 | dist_an *= (1.0 - self.hard_factor) 129 | 130 | y = dist_an.new().resize_as_(dist_an).fill_(1) 131 | if self.margin is not None: 132 | loss = self.ranking_loss(dist_an, dist_ap, y) 133 | else: 134 | loss = self.ranking_loss(dist_an - dist_ap, y) 135 | return loss, dist_ap, dist_an 136 | 137 | 138 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from .make_model import make_model -------------------------------------------------------------------------------- /model/backbones/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deropty/PiT/779d3110808929f1f3d8c4e0ec6ec3f7f60dbffa/model/backbones/__init__.py -------------------------------------------------------------------------------- /model/backbones/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deropty/PiT/779d3110808929f1f3d8c4e0ec6ec3f7f60dbffa/model/backbones/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /model/backbones/__pycache__/resnet.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/deropty/PiT/779d3110808929f1f3d8c4e0ec6ec3f7f60dbffa/model/backbones/__pycache__/resnet.cpython-36.pyc -------------------------------------------------------------------------------- /model/backbones/__pycache__/vit_pytorch.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deropty/PiT/779d3110808929f1f3d8c4e0ec6ec3f7f60dbffa/model/backbones/__pycache__/vit_pytorch.cpython-36.pyc -------------------------------------------------------------------------------- /model/backbones/resnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch import nn 5 | 6 | 7 | def conv3x3(in_planes, out_planes, stride=1): 8 | """3x3 convolution with padding""" 9 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 10 | padding=1, bias=False) 11 | 12 | 13 | class BasicBlock(nn.Module): 14 | expansion = 1 15 | 16 | def __init__(self, inplanes, planes, stride=1, downsample=None): 17 | super(BasicBlock, self).__init__() 18 | self.conv1 = conv3x3(inplanes, planes, stride) 19 | self.bn1 = nn.BatchNorm2d(planes) 20 | self.relu = nn.ReLU(inplace=True) 21 | self.conv2 = conv3x3(planes, planes) 22 | self.bn2 = nn.BatchNorm2d(planes) 23 | self.downsample = downsample 24 | self.stride = stride 25 | 26 | def forward(self, x): 27 | residual = x 28 | 29 | out = self.conv1(x) 30 | out = self.bn1(out) 31 | out = self.relu(out) 32 | 33 | out = self.conv2(out) 34 | out = self.bn2(out) 35 | 36 | if self.downsample is not None: 37 | residual = self.downsample(x) 38 | 39 | out += residual 40 | out = self.relu(out) 41 | 42 | return out 43 | 44 | 45 | class Bottleneck(nn.Module): 46 | expansion = 4 47 | 48 | def __init__(self, inplanes, planes, stride=1, downsample=None): 49 | super(Bottleneck, self).__init__() 50 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 51 | self.bn1 = nn.BatchNorm2d(planes) 52 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 53 | padding=1, bias=False) 54 | self.bn2 = nn.BatchNorm2d(planes) 55 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 56 | self.bn3 = nn.BatchNorm2d(planes * 4) 57 | self.relu = nn.ReLU(inplace=True) 58 | self.downsample = downsample 59 | self.stride = stride 60 | 61 | def forward(self, x): 62 | residual = x 63 | 64 | out = self.conv1(x) 65 | out = self.bn1(out) 66 | out = self.relu(out) 67 | 68 | out = self.conv2(out) 69 | out = self.bn2(out) 70 | out = self.relu(out) 71 | 72 | out = self.conv3(out) 73 | out = self.bn3(out) 74 | 75 | if self.downsample is not None: 76 | residual = self.downsample(x) 77 | 78 | out += residual 79 | out = self.relu(out) 80 | 81 | return out 82 | 83 | 84 | class ResNet(nn.Module): 85 | def __init__(self, last_stride=2, block=Bottleneck,layers=[3, 4, 6, 3]): 86 | self.inplanes = 64 87 | super().__init__() 88 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 89 | bias=False) 90 | self.bn1 = nn.BatchNorm2d(64) 91 | # self.relu = nn.ReLU(inplace=True) # add missed relu 92 | self.maxpool = nn.MaxPool2d(kernel_size=2, stride=None, padding=0) 93 | self.layer1 = self._make_layer(block, 64, layers[0]) 94 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 95 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 96 | self.layer4 = self._make_layer(block, 512, 
layers[3], stride=last_stride) 97 | 98 | def _make_layer(self, block, planes, blocks, stride=1): 99 | downsample = None 100 | if stride != 1 or self.inplanes != planes * block.expansion: 101 | downsample = nn.Sequential( 102 | nn.Conv2d(self.inplanes, planes * block.expansion, 103 | kernel_size=1, stride=stride, bias=False), 104 | nn.BatchNorm2d(planes * block.expansion), 105 | ) 106 | 107 | layers = [] 108 | layers.append(block(self.inplanes, planes, stride, downsample)) 109 | self.inplanes = planes * block.expansion 110 | for i in range(1, blocks): 111 | layers.append(block(self.inplanes, planes)) 112 | 113 | return nn.Sequential(*layers) 114 | 115 | def forward(self, x, cam_label=None): 116 | x = self.conv1(x) 117 | x = self.bn1(x) 118 | # x = self.relu(x) # add missed relu 119 | x = self.maxpool(x) 120 | x = self.layer1(x) 121 | x = self.layer2(x) 122 | x = self.layer3(x) 123 | x = self.layer4(x) 124 | 125 | return x 126 | 127 | def load_param(self, model_path): 128 | param_dict = torch.load(model_path) 129 | for i in param_dict: 130 | if 'fc' in i: 131 | continue 132 | self.state_dict()[i].copy_(param_dict[i]) 133 | 134 | def random_init(self): 135 | for m in self.modules(): 136 | if isinstance(m, nn.Conv2d): 137 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 138 | m.weight.data.normal_(0, math.sqrt(2. / n)) 139 | elif isinstance(m, nn.BatchNorm2d): 140 | m.weight.data.fill_(1) 141 | m.bias.data.zero_() -------------------------------------------------------------------------------- /processor/__init__.py: -------------------------------------------------------------------------------- 1 | from .processor import do_train, do_inference -------------------------------------------------------------------------------- /processor/processor.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import time 4 | import torch 5 | import torch.nn as nn 6 | from utils.meter import AverageMeter 7 | from utils.metrics import R1_mAP_eval 8 | import torch.distributed as dist 9 | from torch.cuda import amp 10 | 11 | def do_train(cfg, 12 | model, 13 | center_criterion, 14 | train_loader, 15 | val_loader, 16 | optimizer, 17 | optimizer_center, 18 | scheduler, 19 | loss_fn, 20 | num_query, local_rank, saver, num, test): 21 | log_period = cfg.SOLVER.LOG_PERIOD 22 | checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD 23 | eval_period = cfg.SOLVER.EVAL_PERIOD 24 | 25 | device = "cuda" 26 | epochs = cfg.SOLVER.MAX_EPOCHS 27 | 28 | logger = logging.getLogger("pit.train") 29 | logger.info('start training') 30 | _LOCAL_PROCESS_GROUP = None 31 | if device: 32 | model.to(local_rank) 33 | if torch.cuda.device_count() > 1 and cfg.MODEL.DIST_TRAIN: 34 | print('Using {} GPUs for training'.format(torch.cuda.device_count())) 35 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank], find_unused_parameters=True) 36 | 37 | cls_loss_meter = AverageMeter() 38 | tri_loss_meter = AverageMeter() 39 | acc_meter = AverageMeter() 40 | if cfg.MODEL.DIVERSITY: 41 | div_loss_meter = AverageMeter() 42 | 43 | evaluator = R1_mAP_eval(num_query, max_rank=50, feat_norm=cfg.TEST.FEAT_NORM) 44 | 45 | scaler = amp.GradScaler() 46 | isVideo = True if cfg.DATASETS.NAMES in ['mars', 'duke-video-reid', 'ilids', 'prid'] else False 47 | freeze_layers = ['base', 'pyramid_layer'] 48 | freeze_epochs = cfg.SOLVER.WARMUP_EPOCHS 49 | freeze_or_not = cfg.MODEL.FREEZE 50 | # train 51 | for epoch in range(1, epochs + 1): 52 | if not test: 53 | 
start_time = time.time() 54 | cls_loss_meter.reset() 55 | tri_loss_meter.reset() 56 | if cfg.MODEL.DIVERSITY: 57 | div_loss_meter.reset() 58 | acc_meter.reset() 59 | scheduler.step(epoch) 60 | model.train() 61 | if freeze_or_not and epoch <= freeze_epochs: # freeze layers for 2000 iterations 62 | for name, module in model.named_children(): 63 | if name in freeze_layers: 64 | module.eval() 65 | for n_iter, (img, vid, target_cam, target_view) in enumerate(train_loader): 66 | optimizer.zero_grad() 67 | optimizer_center.zero_grad() 68 | img = img.to(device) 69 | target = vid.to(device) 70 | target_cam = target_cam.to(device) 71 | target_view = target_view.to(device) 72 | with amp.autocast(enabled=True): 73 | score, feat, diversity = model(img, target, cam_label=target_cam, view_label=target_view ) 74 | ID_LOSS, TRI_LOSS = loss_fn(score, feat, target, target_cam) 75 | loss = cfg.MODEL.ID_LOSS_WEIGHT * ID_LOSS + cfg.MODEL.TRIPLET_LOSS_WEIGHT * TRI_LOSS 76 | if cfg.MODEL.DIVERSITY: 77 | DIV_LOSS = sum([sum(diver_loss) / len(diver_loss) for diver_loss in diversity]) / len(diversity) 78 | loss += 1.0 * DIV_LOSS 79 | 80 | scaler.scale(loss).backward() 81 | 82 | scaler.step(optimizer) 83 | scaler.update() 84 | 85 | if 'center' in cfg.MODEL.METRIC_LOSS_TYPE: 86 | for param in center_criterion.parameters(): 87 | param.grad.data *= (1. / cfg.SOLVER.CENTER_LOSS_WEIGHT) 88 | scaler.step(optimizer_center) 89 | scaler.update() 90 | 91 | if isinstance(score, list): 92 | acc = (score[0][0].max(1)[1] == target).float().mean() 93 | else: 94 | acc = (score.max(1)[1] == target).float().mean() 95 | 96 | cls_loss_meter.update(ID_LOSS.item(), img.shape[0]) 97 | tri_loss_meter.update(TRI_LOSS.item(), img.shape[0]) 98 | if cfg.MODEL.DIVERSITY: 99 | div_loss_meter.update(DIV_LOSS.item(), img.shape[0]) 100 | acc_meter.update(acc, 1) 101 | 102 | torch.cuda.synchronize() 103 | if (n_iter + 1) % log_period == 0: 104 | if cfg.MODEL.DIVERSITY: 105 | logger.info( 106 | "Epoch[{}] Iteration[{}/{}] cls_loss: {:.3f}, tri_loss: {:.3f}, div_loss: {:.3f}, Acc: {:.3f}, Base Lr: {:.2e}" 107 | .format(epoch, (n_iter + 1), len(train_loader), 108 | cls_loss_meter.avg, tri_loss_meter.avg, div_loss_meter.avg, acc_meter.avg, scheduler._get_lr(epoch)[0])) 109 | else: 110 | logger.info("Epoch[{}] Iteration[{}/{}] cls_loss: {:.3f}, tri_loss: {:.3f}, Acc: {:.3f}, Base Lr: {:.2e}" 111 | .format(epoch, (n_iter + 1), len(train_loader), 112 | cls_loss_meter.avg, tri_loss_meter.avg,acc_meter.avg, scheduler._get_lr(epoch)[0])) 113 | 114 | end_time = time.time() 115 | time_per_batch = (end_time - start_time) / (n_iter + 1) 116 | 117 | saver.dump_metric_tb(cls_loss_meter.avg, epoch, f'losses', f'cls_loss') 118 | saver.dump_metric_tb(tri_loss_meter.avg, epoch, f'losses', f'tri_loss') 119 | if cfg.MODEL.DIVERSITY: 120 | saver.dump_metric_tb(div_loss_meter.avg, epoch, f'losses', f'div_loss') 121 | saver.dump_metric_tb(acc_meter.avg, epoch, f'losses', f'acc') 122 | saver.dump_metric_tb(optimizer.param_groups[0]['lr'], epoch, f'losses', f'lr') 123 | 124 | if cfg.MODEL.DIST_TRAIN: 125 | pass 126 | else: 127 | if isVideo: 128 | num_samples = cfg.DATALOADER.P * cfg.DATALOADER.K * cfg.DATALOADER.NUM_TRAIN_IMAGES 129 | logger.info("Epoch {} done. Time per batch: {:.3f}[s] Speed: {:.1f}[samples/s]" 130 | .format(epoch, time_per_batch, num_samples / time_per_batch)) 131 | else: 132 | logger.info("Epoch {} done. 
Time per batch: {:.3f}[s] Speed: {:.1f}[samples/s]" 133 | .format(epoch, time_per_batch, train_loader.batch_size / time_per_batch)) 134 | 135 | if epoch % checkpoint_period == 0: 136 | if cfg.MODEL.DIST_TRAIN: 137 | if dist.get_rank() == 0: 138 | torch.save(model.state_dict(), 139 | os.path.join(cfg.OUTPUT_DIR, cfg.MODEL.NAME + '_{}.pth'.format(epoch))) 140 | else: 141 | torch.save(model.state_dict(), 142 | os.path.join(cfg.OUTPUT_DIR, str(num+1), cfg.MODEL.NAME + '_{}.pth'.format(epoch))) 143 | elif epoch != 120: 144 | continue 145 | 146 | evaluator.reset() 147 | if epoch % eval_period == 0 or epoch == 1: 148 | if cfg.MODEL.DIST_TRAIN: 149 | if dist.get_rank() == 0: 150 | model.eval() 151 | for n_iter, (img, vid, camid, camids, target_view, _) in enumerate(val_loader): 152 | with torch.no_grad(): 153 | img = img.to(device) 154 | camids = camids.to(device) 155 | target_view = target_view.to(device) 156 | feat = model(img, cam_label=camids, view_label=target_view) 157 | evaluator.update((feat, vid, camid)) 158 | cmc, mAP, _, _, _, _, _ = evaluator.compute() 159 | logger.info("Validation Results - Epoch: {}".format(epoch)) 160 | logger.info("mAP: {:.1%}".format(mAP)) 161 | for r in [1, 5, 10]: 162 | logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(r, cmc[r - 1])) 163 | torch.cuda.empty_cache() 164 | else: 165 | model.eval() 166 | for n_iter, (img, vid, camid, camids, target_view, _) in enumerate(val_loader): 167 | with torch.no_grad(): 168 | img = img.to(device) 169 | camids = camids.to(device) 170 | target_view = target_view.to(device) 171 | feat = model(img, cam_label=camids, view_label=target_view) 172 | evaluator.update((feat, vid, camid)) 173 | cmc, mAP, _, _, _, _, _ = evaluator.compute() 174 | logger.info("Validation Results - Epoch: {}".format(epoch)) 175 | logger.info("mAP: {:.3%}".format(mAP)) 176 | for r in [1, 5, 10, 20]: 177 | logger.info("CMC curve, Rank-{:<3}:{:.3%}".format(r, cmc[r - 1])) 178 | torch.cuda.empty_cache() 179 | 180 | saver.dump_metric_tb(mAP, epoch, f'v2v', f'mAP') 181 | for cmc_v in [1, 5, 10, 20]: 182 | saver.dump_metric_tb(cmc[cmc_v-1], epoch, f'v2v', f'cmc{cmc_v}') 183 | 184 | return cmc, mAP 185 | 186 | 187 | def do_inference(cfg, 188 | model, 189 | val_loader, 190 | num_query): 191 | device = "cuda" 192 | logger = logging.getLogger("pit.test") 193 | logger.info("Enter inferencing") 194 | 195 | evaluator = R1_mAP_eval(num_query, max_rank=50, feat_norm=cfg.TEST.FEAT_NORM, reranking=cfg.TEST.RE_RANKING) 196 | 197 | evaluator.reset() 198 | 199 | if device: 200 | if torch.cuda.device_count() > 1: 201 | print('Using {} GPUs for inference'.format(torch.cuda.device_count())) 202 | model = nn.DataParallel(model) 203 | model.to(device) 204 | 205 | model.eval() 206 | img_path_list = [] 207 | 208 | for n_iter, (img, pid, camid, camids, target_view, imgpath) in enumerate(val_loader): 209 | with torch.no_grad(): 210 | img = img.to(device) 211 | camids = camids.to(device) 212 | target_view = target_view.to(device) 213 | feat = model(img, cam_label=camids, view_label=target_view) 214 | evaluator.update((feat, pid, camid)) 215 | img_path_list.extend(imgpath) 216 | 217 | import pandas as pd 218 | import numpy as np 219 | img_path_list = np.asarray(img_path_list) 220 | data = pd.DataFrame({str(i): img_path_list[:, i] for i in range(img_path_list.shape[1])}) 221 | data.to_csv('img_path.csv', index=True, sep=',') 222 | 223 | cmc, mAP, _, _, _, _, _ = evaluator.compute() 224 | logger.info("Validation Results ") 225 | logger.info("mAP: {:.1%}".format(mAP)) 226 | for r in [1, 5, 
10, 20]: 227 | logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(r, cmc[r - 1])) 228 | return cmc[0], cmc[4] 229 | 230 | 231 | -------------------------------------------------------------------------------- /solver/__init__.py: -------------------------------------------------------------------------------- 1 | from .lr_scheduler import WarmupMultiStepLR 2 | from .make_optimizer import make_optimizer -------------------------------------------------------------------------------- /solver/cosine_lr.py: -------------------------------------------------------------------------------- 1 | """ Cosine Scheduler 2 | 3 | Cosine LR schedule with warmup, cycle/restarts, noise. 4 | 5 | Hacked together by / Copyright 2020 Ross Wightman 6 | """ 7 | import logging 8 | import math 9 | import torch 10 | 11 | from .scheduler import Scheduler 12 | 13 | 14 | _logger = logging.getLogger(__name__) 15 | 16 | 17 | class CosineLRScheduler(Scheduler): 18 | """ 19 | Cosine decay with restarts. 20 | This is described in the paper https://arxiv.org/abs/1608.03983. 21 | 22 | Inspiration from 23 | https://github.com/allenai/allennlp/blob/master/allennlp/training/learning_rate_schedulers/cosine.py 24 | """ 25 | 26 | def __init__(self, 27 | optimizer: torch.optim.Optimizer, 28 | t_initial: int, 29 | t_mul: float = 1., 30 | lr_min: float = 0., 31 | decay_rate: float = 1., 32 | warmup_t=0, 33 | warmup_lr_init=0, 34 | warmup_prefix=False, 35 | cycle_limit=0, 36 | t_in_epochs=True, 37 | noise_range_t=None, 38 | noise_pct=0.67, 39 | noise_std=1.0, 40 | noise_seed=42, 41 | initialize=True) -> None: 42 | super().__init__( 43 | optimizer, param_group_field="lr", 44 | noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed, 45 | initialize=initialize) 46 | 47 | assert t_initial > 0 48 | assert lr_min >= 0 49 | if t_initial == 1 and t_mul == 1 and decay_rate == 1: 50 | _logger.warning("Cosine annealing scheduler will have no effect on the learning " 51 | "rate since t_initial = t_mul = eta_mul = 1.") 52 | self.t_initial = t_initial 53 | self.t_mul = t_mul 54 | self.lr_min = lr_min 55 | self.decay_rate = decay_rate 56 | self.cycle_limit = cycle_limit 57 | self.warmup_t = warmup_t 58 | self.warmup_lr_init = warmup_lr_init 59 | self.warmup_prefix = warmup_prefix 60 | self.t_in_epochs = t_in_epochs 61 | if self.warmup_t: 62 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] 63 | super().update_groups(self.warmup_lr_init) 64 | else: 65 | self.warmup_steps = [1 for _ in self.base_values] 66 | 67 | def _get_lr(self, t): 68 | if t < self.warmup_t: 69 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] 70 | else: 71 | if self.warmup_prefix: 72 | t = t - self.warmup_t 73 | 74 | if self.t_mul != 1: 75 | i = math.floor(math.log(1 - t / self.t_initial * (1 - self.t_mul), self.t_mul)) 76 | t_i = self.t_mul ** i * self.t_initial 77 | t_curr = t - (1 - self.t_mul ** i) / (1 - self.t_mul) * self.t_initial 78 | else: 79 | i = t // self.t_initial 80 | t_i = self.t_initial 81 | t_curr = t - (self.t_initial * i) 82 | 83 | gamma = self.decay_rate ** i 84 | lr_min = self.lr_min * gamma 85 | lr_max_values = [v * gamma for v in self.base_values] 86 | 87 | if self.cycle_limit == 0 or (self.cycle_limit > 0 and i < self.cycle_limit): 88 | lrs = [ 89 | lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * t_curr / t_i)) for lr_max in lr_max_values 90 | ] 91 | else: 92 | lrs = [self.lr_min for _ in self.base_values] 93 | 94 | return lrs 95 | 96 | def 
get_epoch_values(self, epoch: int): 97 | if self.t_in_epochs: 98 | return self._get_lr(epoch) 99 | else: 100 | return None 101 | 102 | def get_update_values(self, num_updates: int): 103 | if not self.t_in_epochs: 104 | return self._get_lr(num_updates) 105 | else: 106 | return None 107 | 108 | def get_cycle_length(self, cycles=0): 109 | if not cycles: 110 | cycles = self.cycle_limit 111 | cycles = max(1, cycles) 112 | if self.t_mul == 1.0: 113 | return self.t_initial * cycles 114 | else: 115 | return int(math.floor(-self.t_initial * (self.t_mul ** cycles - 1) / (1 - self.t_mul))) 116 | -------------------------------------------------------------------------------- /solver/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: liaoxingyu 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | from bisect import bisect_right 7 | import torch 8 | 9 | 10 | # FIXME ideally this would be achieved with a CombinedLRScheduler, 11 | # separating MultiStepLR with WarmupLR 12 | # but the current LRScheduler design doesn't allow it 13 | 14 | class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler): 15 | def __init__( 16 | self, 17 | optimizer, 18 | milestones, # steps 19 | gamma=0.1, 20 | warmup_factor=1.0 / 3, 21 | warmup_iters=500, 22 | warmup_method="linear", 23 | last_epoch=-1, 24 | ): 25 | if not list(milestones) == sorted(milestones): 26 | raise ValueError( 27 | "Milestones should be a list of" " increasing integers. Got {}", 28 | milestones, 29 | ) 30 | 31 | if warmup_method not in ("constant", "linear"): 32 | raise ValueError( 33 | "Only 'constant' or 'linear' warmup_method accepted" 34 | "got {}".format(warmup_method) 35 | ) 36 | self.milestones = milestones 37 | self.gamma = gamma 38 | self.warmup_factor = warmup_factor 39 | self.warmup_iters = warmup_iters 40 | self.warmup_method = warmup_method 41 | super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch) 42 | 43 | def _get_lr(self): 44 | warmup_factor = 1 45 | if self.last_epoch < self.warmup_iters: 46 | if self.warmup_method == "constant": 47 | warmup_factor = self.warmup_factor 48 | elif self.warmup_method == "linear": 49 | alpha = self.last_epoch / self.warmup_iters 50 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha 51 | return [ 52 | base_lr 53 | * warmup_factor 54 | * self.gamma ** bisect_right(self.milestones, self.last_epoch) 55 | for base_lr in self.base_lrs 56 | ] 57 | -------------------------------------------------------------------------------- /solver/make_optimizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def make_optimizer(cfg, model, center_criterion): 5 | params = [] 6 | for key, value in model.named_parameters(): 7 | if not value.requires_grad: 8 | continue 9 | lr = cfg.SOLVER.BASE_LR 10 | weight_decay = cfg.SOLVER.WEIGHT_DECAY 11 | if "bias" in key: 12 | lr = cfg.SOLVER.BASE_LR * cfg.SOLVER.BIAS_LR_FACTOR 13 | weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS 14 | if cfg.SOLVER.LARGE_FC_LR: 15 | if "classifier" in key or "arcface" in key: 16 | lr = cfg.SOLVER.BASE_LR * 2 17 | print('Using two times learning rate for fc ') 18 | 19 | params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}] 20 | 21 | if cfg.SOLVER.OPTIMIZER_NAME == 'SGD': 22 | optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(params, momentum=cfg.SOLVER.MOMENTUM) 23 | elif cfg.SOLVER.OPTIMIZER_NAME == 'AdamW': 24 | optimizer = torch.optim.AdamW(params, 
lr=cfg.SOLVER.BASE_LR, weight_decay=cfg.SOLVER.WEIGHT_DECAY) 25 | elif cfg.SOLVER.OPTIMIZER_NAME == 'Adam': 26 | optimizer = torch.optim.Adam(params, lr=cfg.SOLVER.BASE_LR, weight_decay=cfg.SOLVER.WEIGHT_DECAY) 27 | else: 28 | optimizer = getattr(torch.optim, cfg.SOLVER.OPTIMIZER_NAME)(params) 29 | optimizer_center = torch.optim.SGD(center_criterion.parameters(), lr=cfg.SOLVER.CENTER_LR) 30 | 31 | return optimizer, optimizer_center 32 | -------------------------------------------------------------------------------- /solver/scheduler.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Any 2 | 3 | import torch 4 | 5 | 6 | class Scheduler: 7 | """ Parameter Scheduler Base Class 8 | A scheduler base class that can be used to schedule any optimizer parameter groups. 9 | 10 | Unlike the builtin PyTorch schedulers, this is intended to be consistently called 11 | * At the END of each epoch, before incrementing the epoch count, to calculate next epoch's value 12 | * At the END of each optimizer update, after incrementing the update count, to calculate next update's value 13 | 14 | The schedulers built on this should try to remain as stateless as possible (for simplicity). 15 | 16 | This family of schedulers is attempting to avoid the confusion of the meaning of 'last_epoch' 17 | and -1 values for special behaviour. All epoch and update counts must be tracked in the training 18 | code and explicitly passed in to the schedulers on the corresponding step or step_update call. 19 | 20 | Based on ideas from: 21 | * https://github.com/pytorch/fairseq/tree/master/fairseq/optim/lr_scheduler 22 | * https://github.com/allenai/allennlp/tree/master/allennlp/training/learning_rate_schedulers 23 | """ 24 | 25 | def __init__(self, 26 | optimizer: torch.optim.Optimizer, 27 | param_group_field: str, 28 | noise_range_t=None, 29 | noise_type='normal', 30 | noise_pct=0.67, 31 | noise_std=1.0, 32 | noise_seed=None, 33 | initialize: bool = True) -> None: 34 | self.optimizer = optimizer 35 | self.param_group_field = param_group_field 36 | self._initial_param_group_field = f"initial_{param_group_field}" 37 | if initialize: 38 | for i, group in enumerate(self.optimizer.param_groups): 39 | if param_group_field not in group: 40 | raise KeyError(f"{param_group_field} missing from param_groups[{i}]") 41 | group.setdefault(self._initial_param_group_field, group[param_group_field]) 42 | else: 43 | for i, group in enumerate(self.optimizer.param_groups): 44 | if self._initial_param_group_field not in group: 45 | raise KeyError(f"{self._initial_param_group_field} missing from param_groups[{i}]") 46 | self.base_values = [group[self._initial_param_group_field] for group in self.optimizer.param_groups] 47 | self.metric = None # any point to having this for all? 
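        # Optional LR-noise settings: when noise_range_t is given, _add_noise()
        # (defined later in this class) perturbs each scheduled value by a random
        # factor bounded by noise_pct, drawn from a generator seeded with noise_seed + t.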
48 | self.noise_range_t = noise_range_t 49 | self.noise_pct = noise_pct 50 | self.noise_type = noise_type 51 | self.noise_std = noise_std 52 | self.noise_seed = noise_seed if noise_seed is not None else 42 53 | self.update_groups(self.base_values) 54 | 55 | def state_dict(self) -> Dict[str, Any]: 56 | return {key: value for key, value in self.__dict__.items() if key != 'optimizer'} 57 | 58 | def load_state_dict(self, state_dict: Dict[str, Any]) -> None: 59 | self.__dict__.update(state_dict) 60 | 61 | def get_epoch_values(self, epoch: int): 62 | return None 63 | 64 | def get_update_values(self, num_updates: int): 65 | return None 66 | 67 | def step(self, epoch: int, metric: float = None) -> None: 68 | self.metric = metric 69 | values = self.get_epoch_values(epoch) 70 | if values is not None: 71 | values = self._add_noise(values, epoch) 72 | self.update_groups(values) 73 | 74 | def step_update(self, num_updates: int, metric: float = None): 75 | self.metric = metric 76 | values = self.get_update_values(num_updates) 77 | if values is not None: 78 | values = self._add_noise(values, num_updates) 79 | self.update_groups(values) 80 | 81 | def update_groups(self, values): 82 | if not isinstance(values, (list, tuple)): 83 | values = [values] * len(self.optimizer.param_groups) 84 | for param_group, value in zip(self.optimizer.param_groups, values): 85 | param_group[self.param_group_field] = value 86 | 87 | def _add_noise(self, lrs, t): 88 | if self.noise_range_t is not None: 89 | if isinstance(self.noise_range_t, (list, tuple)): 90 | apply_noise = self.noise_range_t[0] <= t < self.noise_range_t[1] 91 | else: 92 | apply_noise = t >= self.noise_range_t 93 | if apply_noise: 94 | g = torch.Generator() 95 | g.manual_seed(self.noise_seed + t) 96 | if self.noise_type == 'normal': 97 | while True: 98 | # resample if noise out of percent limit, brute force but shouldn't spin much 99 | noise = torch.randn(1, generator=g).item() 100 | if abs(noise) < self.noise_pct: 101 | break 102 | else: 103 | noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct 104 | lrs = [v + v * noise for v in lrs] 105 | return lrs 106 | -------------------------------------------------------------------------------- /solver/scheduler_factory.py: -------------------------------------------------------------------------------- 1 | """ Scheduler Factory 2 | Hacked together by / Copyright 2020 Ross Wightman 3 | """ 4 | from .cosine_lr import CosineLRScheduler 5 | 6 | 7 | def create_scheduler(cfg, optimizer): 8 | num_epochs = cfg.SOLVER.MAX_EPOCHS 9 | # type 1 10 | # lr_min = 0.01 * cfg.SOLVER.BASE_LR 11 | # warmup_lr_init = 0.001 * cfg.SOLVER.BASE_LR 12 | # type 2 13 | lr_min = 0.002 * cfg.SOLVER.BASE_LR 14 | warmup_lr_init = 0.01 * cfg.SOLVER.BASE_LR 15 | # type 3 16 | # lr_min = 0.001 * cfg.SOLVER.BASE_LR 17 | # warmup_lr_init = 0.01 * cfg.SOLVER.BASE_LR 18 | 19 | warmup_t = cfg.SOLVER.WARMUP_EPOCHS 20 | noise_range = None 21 | 22 | lr_scheduler = CosineLRScheduler( 23 | optimizer, 24 | t_initial=num_epochs, 25 | lr_min=lr_min, 26 | t_mul= 1., 27 | decay_rate=0.1, 28 | warmup_lr_init=warmup_lr_init, 29 | warmup_t=warmup_t, 30 | cycle_limit=1, 31 | t_in_epochs=True, 32 | noise_range_t=noise_range, 33 | noise_pct= 0.67, 34 | noise_std= 1., 35 | noise_seed=42, 36 | ) 37 | 38 | return lr_scheduler 39 | -------------------------------------------------------------------------------- /thop/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import clever_format 2 | from 
.profile import profile, profile_origin 3 | from .onnx_profile import OnnxProfile 4 | import torch 5 | 6 | default_dtype = torch.float64 7 | -------------------------------------------------------------------------------- /thop/fx_profile.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch as th 3 | import torch.nn as nn 4 | from distutils.version import LooseVersion 5 | 6 | if LooseVersion(torch.__version__) < LooseVersion("1.8.0"): 7 | logging.warning( 8 | f"torch.fx requires version higher than 1.8.0. " 9 | f"But You are using an old version PyTorch {torch.__version__}. " 10 | ) 11 | 12 | 13 | def count_clamp(input_shapes, output_shapes): 14 | return 0 15 | 16 | 17 | def count_mul(input_shapes, output_shapes): 18 | # element-wise 19 | return output_shapes[0].numel() 20 | 21 | 22 | def count_matmul(input_shapes, output_shapes): 23 | in_shape = input_shapes[0] 24 | out_shape = output_shapes[0] 25 | in_features = in_shape[-1] 26 | num_elements = out_shape.numel() 27 | return in_features * num_elements 28 | 29 | 30 | def count_fn_linear(input_shapes, output_shapes, *args, **kwargs): 31 | mul_flops = count_matmul(input_shapes, output_shapes) 32 | if "bias" in kwargs: 33 | add_flops = output_shapes[0].numel() 34 | return mul_flops 35 | 36 | 37 | from .vision.counter import counter_conv 38 | 39 | 40 | def count_fn_conv2d(input_shapes, output_shapes, *args, **kwargs): 41 | inputs, weight, bias, stride, padding, dilation, groups = args 42 | if len(input_shapes) == 2: 43 | x_shape, k_shape = input_shapes 44 | elif len(input_shapes) == 3: 45 | x_shape, k_shape, b_shape = input_shapes 46 | out_shape = output_shapes[0] 47 | 48 | kernel_parameters = k_shape[2:].numel() 49 | bias_op = 0 # check it later 50 | in_channel = x_shape[1] 51 | 52 | total_ops = counter_conv( 53 | bias_op, kernel_parameters, out_shape.numel(), in_channel, groups 54 | ).item() 55 | return int(total_ops) 56 | 57 | 58 | def count_nn_linear(module: nn.Module, input_shapes, output_shapes): 59 | return count_matmul(input_shapes, output_shapes) 60 | 61 | 62 | def count_zero_ops(module: nn.Module, input_shapes, output_shapes, *args, **kwargs): 63 | return 0 64 | 65 | 66 | def count_nn_conv2d(module: nn.Conv2d, input_shapes, output_shapes): 67 | bias_op = 1 if module.bias is not None else 0 68 | out_shape = output_shapes[0] 69 | 70 | in_channel = module.in_channels 71 | groups = module.groups 72 | kernel_ops = module.weight.shape[2:].numel() 73 | total_ops = counter_conv( 74 | bias_op, kernel_ops, out_shape.numel(), in_channel, groups 75 | ).item() 76 | return int(total_ops) 77 | 78 | 79 | def count_nn_bn2d(module: nn.BatchNorm2d, input_shapes, output_shapes): 80 | assert len(output_shapes) == 1, "nn.BatchNorm2d should only have one output" 81 | y = output_shapes[0] 82 | # y = (x - mean) / \sqrt{var + e} * weight + bias 83 | total_ops = 2 * y.numel() 84 | return total_ops 85 | 86 | 87 | zero_ops = ( 88 | nn.ReLU, 89 | nn.ReLU6, 90 | nn.Dropout, 91 | nn.MaxPool2d, 92 | nn.AvgPool2d, 93 | nn.AdaptiveAvgPool2d, 94 | ) 95 | 96 | count_map = { 97 | nn.Linear: count_nn_linear, 98 | nn.Conv2d: count_nn_conv2d, 99 | nn.BatchNorm2d: count_nn_bn2d, 100 | "function linear": count_fn_linear, 101 | "clamp": count_clamp, 102 | "built-in function add": count_zero_ops, 103 | "built-in method fl": count_zero_ops, 104 | "built-in method conv2d of type object": count_fn_conv2d, 105 | "built-in function mul": count_mul, 106 | "built-in function truediv": count_mul, 107 | } 108 | 109 | for 
k in zero_ops: 110 | count_map[k] = count_zero_ops 111 | 112 | missing_maps = {} 113 | 114 | from torch.fx import symbolic_trace 115 | from torch.fx.passes.shape_prop import ShapeProp 116 | from .utils import prGreen, prRed, prYellow 117 | 118 | 119 | def null_print(*args, **kwargs): 120 | return 121 | 122 | 123 | def fx_profile(mod: nn.Module, input: th.Tensor, verbose=False): 124 | gm: torch.fx.GraphModule = symbolic_trace(mod) 125 | g = gm.graph 126 | ShapeProp(gm).propagate(input) 127 | 128 | fprint = null_print 129 | if verbose: 130 | fprint = print 131 | 132 | v_maps = {} 133 | total_flops = 0 134 | 135 | for node in gm.graph.nodes: 136 | # print(f"{node.target},\t{node.op},\t{node.meta['tensor_meta'].dtype},\t{node.meta['tensor_meta'].shape}") 137 | fprint( 138 | f"NodeOP:{node.op},\tTarget:{node.target},\tNodeName:{node.name},\tNodeArgs:{node.args}" 139 | ) 140 | # node_op_type = str(node.target).split(".")[-1] 141 | node_flops = None 142 | 143 | input_shapes = [] 144 | output_shapes = [] 145 | fprint("input_shape:", end="\t") 146 | for arg in node.args: 147 | if str(arg) not in v_maps: 148 | continue 149 | fprint(f"{v_maps[str(arg)]}", end="\t") 150 | input_shapes.append(v_maps[str(arg)]) 151 | fprint() 152 | fprint(f"output_shape:\t{node.meta['tensor_meta'].shape}") 153 | output_shapes.append(node.meta["tensor_meta"].shape) 154 | 155 | if node.op in ["output", "placeholder"]: 156 | node_flops = 0 157 | elif node.op == "call_function": 158 | # torch internal functions 159 | key = ( 160 | str(node.target) 161 | .split("at")[0] 162 | .replace("<", "") 163 | .replace(">", "") 164 | .strip() 165 | ) 166 | if key in count_map: 167 | node_flops = count_map[key]( 168 | input_shapes, output_shapes, *node.args, **node.kwargs 169 | ) 170 | else: 171 | missing_maps[key] = (node.op, key) 172 | prRed(f"|{key}| is missing") 173 | elif node.op == "call_method": 174 | # torch internal functions 175 | # fprint(str(node.target) in count_map, str(node.target), count_map.keys()) 176 | key = str(node.target) 177 | if key in count_map: 178 | node_flops = count_map[key](input_shapes, output_shapes) 179 | else: 180 | missing_maps[key] = (node.op, key) 181 | prRed(f"{key} is missing") 182 | elif node.op == "call_module": 183 | # torch.nn modules 184 | # m = getattr(mod, node.target, None) 185 | m = mod.get_submodule(node.target) 186 | key = type(m) 187 | fprint(type(m), type(m) in count_map) 188 | if type(m) in count_map: 189 | node_flops = count_map[type(m)](m, input_shapes, output_shapes) 190 | else: 191 | missing_maps[key] = (node.op,) 192 | prRed(f"{key} is missing") 193 | print("module type:", type(m)) 194 | if isinstance(m, zero_ops): 195 | print(f"weight_shape: None") 196 | else: 197 | print(type(m)) 198 | print( 199 | f"weight_shape: {mod.state_dict()[node.target + '.weight'].shape}" 200 | ) 201 | 202 | v_maps[str(node.name)] = node.meta["tensor_meta"].shape 203 | if node_flops is not None: 204 | total_flops += node_flops 205 | prYellow(f"Current node's FLOPs: {node_flops}, total FLOPs: {total_flops}") 206 | fprint("==" * 40) 207 | 208 | if len(missing_maps.keys()) > 0: 209 | from pprint import pprint 210 | print("Missing operators: ") 211 | pprint(missing_maps) 212 | return total_flops 213 | 214 | 215 | if __name__ == "__main__": 216 | 217 | class MyOP(nn.Module): 218 | def forward(self, input): 219 | return input / 1 220 | 221 | class MyModule(torch.nn.Module): 222 | def __init__(self): 223 | super().__init__() 224 | self.linear1 = torch.nn.Linear(5, 3) 225 | self.linear2 = torch.nn.Linear(5, 3) 
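        # Toy demo module for fx_profile(): two parallel Linear(5, 3) branches; one is
        # clamped, the two are summed, and MyOP divides the sum by 1, exercising the
        # linear, clamp, add and truediv rules registered in count_map above.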
226 | self.myop = MyOP() 227 | 228 | def forward(self, x): 229 | out1 = self.linear1(x) 230 | out2 = self.linear2(x).clamp(min=0.0, max=1.0) 231 | return self.myop(out1 + out2) 232 | 233 | net = MyModule() 234 | data = th.randn(20, 5) 235 | flops = fx_profile(net, data, verbose=False) 236 | print(flops) 237 | -------------------------------------------------------------------------------- /thop/onnx_profile.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn 3 | import onnx 4 | from onnx import numpy_helper 5 | import numpy as np 6 | from thop.vision.onnx_counter import onnx_operators 7 | 8 | 9 | class OnnxProfile: 10 | def __init__(self): 11 | pass 12 | 13 | def calculate_params(self, model: onnx.ModelProto): 14 | onnx_weights = model.graph.initializer 15 | params = 0 16 | 17 | for onnx_w in onnx_weights: 18 | try: 19 | weight = numpy_helper.to_array(onnx_w) 20 | params += np.prod(weight.shape) 21 | except Exception as _: 22 | pass 23 | 24 | return params 25 | 26 | def create_dict(self, weight, input, output): 27 | diction = {} 28 | for w in weight: 29 | dim = np.array(w.dims) 30 | diction[str(w.name)] = dim 31 | if dim.size == 1: 32 | diction[str(w.name)] = np.append(1, dim) 33 | for i in input: 34 | # print(i.type.tensor_type.shape.dim[0].dim_value) 35 | dim = np.array(i.type.tensor_type.shape.dim[0].dim_value) 36 | # print(i.type.tensor_type.shape.dim.__sizeof__()) 37 | # name2dims[str(i.name)] = [dim] 38 | dim = [] 39 | for key in i.type.tensor_type.shape.dim: 40 | dim = np.append(dim, int(key.dim_value)) 41 | # print(key.dim_value) 42 | # print(dim) 43 | diction[str(i.name)] = dim 44 | if dim.size == 1: 45 | diction[str(i.name)] = np.append(1, dim) 46 | for o in output: 47 | dim = np.array(o.type.tensor_type.shape.dim[0].dim_value) 48 | diction[str(o.name)] = [dim] 49 | if dim.size == 1: 50 | diction[str(o.name)] = np.append(1, dim) 51 | return diction 52 | 53 | def nodes_counter(self, diction, node): 54 | if node.op_type not in onnx_operators: 55 | print("Sorry, we haven't add ", node.op_type, "into dictionary.") 56 | return 0, None, None 57 | else: 58 | fn = onnx_operators[node.op_type] 59 | return fn(diction, node) 60 | 61 | def calculate_macs(self, model: onnx.ModelProto) -> torch.DoubleTensor: 62 | macs = 0 63 | name2dims = {} 64 | weight = model.graph.initializer 65 | nodes = model.graph.node 66 | input = model.graph.input 67 | output = model.graph.output 68 | name2dims = self.create_dict(weight, input, output) 69 | macs = 0 70 | for n in nodes: 71 | macs_adding, out_size, outname = self.nodes_counter(name2dims, n) 72 | 73 | name2dims[outname] = out_size 74 | macs += macs_adding 75 | return np.array(macs[0]) 76 | -------------------------------------------------------------------------------- /thop/profile.py: -------------------------------------------------------------------------------- 1 | from distutils.version import LooseVersion 2 | 3 | from thop.vision.basic_hooks import * 4 | from thop.rnn_hooks import * 5 | 6 | 7 | # logger = logging.getLogger(__name__) 8 | # logger.setLevel(logging.INFO) 9 | 10 | from .utils import prGreen, prRed, prYellow 11 | 12 | if LooseVersion(torch.__version__) < LooseVersion("1.0.0"): 13 | logging.warning( 14 | "You are using an old version PyTorch {version}, which THOP does NOT support.".format( 15 | version=torch.__version__ 16 | ) 17 | ) 18 | 19 | default_dtype = torch.float64 20 | 21 | register_hooks = { 22 | nn.ZeroPad2d: zero_ops, # padding does not involve any 
multiplication. 23 | nn.Conv1d: count_convNd, 24 | nn.Conv2d: count_convNd, 25 | nn.Conv3d: count_convNd, 26 | nn.ConvTranspose1d: count_convNd, 27 | nn.ConvTranspose2d: count_convNd, 28 | nn.ConvTranspose3d: count_convNd, 29 | nn.BatchNorm1d: count_bn, 30 | nn.BatchNorm2d: count_bn, 31 | nn.BatchNorm3d: count_bn, 32 | nn.LayerNorm: count_ln, 33 | nn.InstanceNorm1d: count_in, 34 | nn.InstanceNorm2d: count_in, 35 | nn.InstanceNorm3d: count_in, 36 | nn.PReLU: count_prelu, 37 | nn.Softmax: count_softmax, 38 | nn.ReLU: zero_ops, 39 | nn.ReLU6: zero_ops, 40 | nn.LeakyReLU: count_relu, 41 | nn.MaxPool1d: zero_ops, 42 | nn.MaxPool2d: zero_ops, 43 | nn.MaxPool3d: zero_ops, 44 | nn.AdaptiveMaxPool1d: zero_ops, 45 | nn.AdaptiveMaxPool2d: zero_ops, 46 | nn.AdaptiveMaxPool3d: zero_ops, 47 | nn.AvgPool1d: count_avgpool, 48 | nn.AvgPool2d: count_avgpool, 49 | nn.AvgPool3d: count_avgpool, 50 | nn.AdaptiveAvgPool1d: count_adap_avgpool, 51 | nn.AdaptiveAvgPool2d: count_adap_avgpool, 52 | nn.AdaptiveAvgPool3d: count_adap_avgpool, 53 | nn.Linear: count_linear, 54 | nn.Dropout: zero_ops, 55 | nn.Upsample: count_upsample, 56 | nn.UpsamplingBilinear2d: count_upsample, 57 | nn.UpsamplingNearest2d: count_upsample, 58 | nn.RNNCell: count_rnn_cell, 59 | nn.GRUCell: count_gru_cell, 60 | nn.LSTMCell: count_lstm_cell, 61 | nn.RNN: count_rnn, 62 | nn.GRU: count_gru, 63 | nn.LSTM: count_lstm, 64 | nn.Sequential: zero_ops, 65 | } 66 | 67 | if LooseVersion(torch.__version__) >= LooseVersion("1.1.0"): 68 | register_hooks.update({nn.SyncBatchNorm: count_bn}) 69 | 70 | 71 | def profile_origin(model, inputs, custom_ops=None, verbose=True, report_missing=False): 72 | handler_collection = [] 73 | types_collection = set() 74 | if custom_ops is None: 75 | custom_ops = {} 76 | if report_missing: 77 | verbose = True 78 | 79 | def add_hooks(m): 80 | if len(list(m.children())) > 0: 81 | return 82 | 83 | if hasattr(m, "total_ops") or hasattr(m, "total_params"): 84 | logging.warning( 85 | "Either .total_ops or .total_params is already defined in %s. " 86 | "Be careful, it might change your code's behavior." % str(m) 87 | ) 88 | 89 | m.register_buffer("total_ops", torch.zeros(1, dtype=default_dtype)) 90 | m.register_buffer("total_params", torch.zeros(1, dtype=default_dtype)) 91 | 92 | for p in m.parameters(): 93 | m.total_params += torch.DoubleTensor([p.numel()]) 94 | 95 | m_type = type(m) 96 | 97 | fn = None 98 | if ( 99 | m_type in custom_ops 100 | ): # if defined both op maps, use custom_ops to overwrite. 101 | fn = custom_ops[m_type] 102 | if m_type not in types_collection and verbose: 103 | print("[INFO] Customize rule %s() %s." % (fn.__qualname__, m_type)) 104 | elif m_type in register_hooks: 105 | fn = register_hooks[m_type] 106 | if m_type not in types_collection and verbose: 107 | print("[INFO] Register %s() for %s." % (fn.__qualname__, m_type)) 108 | else: 109 | if m_type not in types_collection and report_missing: 110 | prRed( 111 | "[WARN] Cannot find rule for %s. Treat it as zero Macs and zero Params." 
112 | % m_type 113 | ) 114 | 115 | if fn is not None: 116 | handler = m.register_forward_hook(fn) 117 | handler_collection.append(handler) 118 | types_collection.add(m_type) 119 | 120 | training = model.training 121 | 122 | model.eval() 123 | model.apply(add_hooks) 124 | 125 | with torch.no_grad(): 126 | model(*inputs) 127 | 128 | total_ops = 0 129 | total_params = 0 130 | for m in model.modules(): 131 | if len(list(m.children())) > 0: # skip for non-leaf module 132 | continue 133 | total_ops += m.total_ops 134 | total_params += m.total_params 135 | 136 | total_ops = total_ops.item() 137 | total_params = total_params.item() 138 | 139 | # reset model to original status 140 | model.train(training) 141 | for handler in handler_collection: 142 | handler.remove() 143 | 144 | # remove temporal buffers 145 | for n, m in model.named_modules(): 146 | if len(list(m.children())) > 0: 147 | continue 148 | if "total_ops" in m._buffers: 149 | m._buffers.pop("total_ops") 150 | if "total_params" in m._buffers: 151 | m._buffers.pop("total_params") 152 | 153 | return total_ops, total_params 154 | 155 | 156 | def profile( 157 | model: nn.Module, 158 | inputs, 159 | custom_ops=None, 160 | verbose=True, 161 | ret_layer_info=False, 162 | report_missing=False, 163 | ): 164 | handler_collection = {} 165 | types_collection = set() 166 | if custom_ops is None: 167 | custom_ops = {} 168 | if report_missing: 169 | # overwrite `verbose` option when enable report_missing 170 | verbose = True 171 | 172 | def add_hooks(m: nn.Module): 173 | m.register_buffer("total_ops", torch.zeros(1, dtype=torch.float64)) 174 | m.register_buffer("total_params", torch.zeros(1, dtype=torch.float64)) 175 | 176 | # for p in m.parameters(): 177 | # m.total_params += torch.DoubleTensor([p.numel()]) 178 | 179 | m_type = type(m) 180 | 181 | fn = None 182 | if m_type in custom_ops: 183 | # if defined both op maps, use custom_ops to overwrite. 184 | fn = custom_ops[m_type] 185 | if m_type not in types_collection and verbose: 186 | print("[INFO] Customize rule %s() %s." % (fn.__qualname__, m_type)) 187 | elif m_type in register_hooks: 188 | fn = register_hooks[m_type] 189 | if m_type not in types_collection and verbose: 190 | print("[INFO] Register %s() for %s." % (fn.__qualname__, m_type)) 191 | else: 192 | if m_type not in types_collection and report_missing: 193 | prRed( 194 | "[WARN] Cannot find rule for %s. Treat it as zero Macs and zero Params." 
195 | % m_type 196 | ) 197 | 198 | if fn is not None: 199 | handler_collection[m] = ( 200 | m.register_forward_hook(fn), 201 | m.register_forward_hook(count_parameters), 202 | ) 203 | types_collection.add(m_type) 204 | 205 | prev_training_status = model.training 206 | 207 | model.eval() 208 | model.apply(add_hooks) 209 | 210 | with torch.no_grad(): 211 | model(*inputs) 212 | 213 | def dfs_count(module: nn.Module, prefix="\t") -> (int, int): 214 | total_ops, total_params = module.total_ops.item(), 0 215 | ret_dict = {} 216 | for n, m in module.named_children(): 217 | # if not hasattr(m, "total_ops") and not hasattr(m, "total_params"): # and len(list(m.children())) > 0: 218 | # m_ops, m_params = dfs_count(m, prefix=prefix + "\t") 219 | # else: 220 | # m_ops, m_params = m.total_ops, m.total_params 221 | next_dict = {} 222 | if m in handler_collection and not isinstance( 223 | m, (nn.Sequential, nn.ModuleList) 224 | ): 225 | m_ops, m_params = m.total_ops.item(), m.total_params.item() 226 | else: 227 | m_ops, m_params, next_dict = dfs_count(m, prefix=prefix + "\t") 228 | ret_dict[n] = (m_ops, m_params, next_dict) 229 | total_ops += m_ops 230 | total_params += m_params 231 | # print(prefix, module._get_name(), (total_ops, total_params)) 232 | return total_ops, total_params, ret_dict 233 | 234 | total_ops, total_params, ret_dict = dfs_count(model) 235 | 236 | # reset model to original status 237 | model.train(prev_training_status) 238 | for m, (op_handler, params_handler) in handler_collection.items(): 239 | op_handler.remove() 240 | params_handler.remove() 241 | m._buffers.pop("total_ops") 242 | m._buffers.pop("total_params") 243 | 244 | if ret_layer_info: 245 | return total_ops, total_params, ret_dict 246 | return total_ops, total_params 247 | -------------------------------------------------------------------------------- /thop/rnn_hooks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.utils.rnn import PackedSequence 4 | 5 | 6 | def _count_rnn_cell(input_size, hidden_size, bias=True): 7 | # h' = \tanh(W_{ih} x + b_{ih} + W_{hh} h + b_{hh}) 8 | total_ops = hidden_size * (input_size + hidden_size) + hidden_size 9 | if bias: 10 | total_ops += hidden_size * 2 11 | 12 | return total_ops 13 | 14 | 15 | def count_rnn_cell(m: nn.RNNCell, x: torch.Tensor, y: torch.Tensor): 16 | total_ops = _count_rnn_cell(m.input_size, m.hidden_size, m.bias) 17 | 18 | batch_size = x[0].size(0) 19 | total_ops *= batch_size 20 | 21 | m.total_ops += torch.DoubleTensor([int(total_ops)]) 22 | 23 | 24 | def _count_gru_cell(input_size, hidden_size, bias=True): 25 | total_ops = 0 26 | # r = \sigma(W_{ir} x + b_{ir} + W_{hr} h + b_{hr}) \\ 27 | # z = \sigma(W_{iz} x + b_{iz} + W_{hz} h + b_{hz}) \\ 28 | state_ops = (hidden_size + input_size) * hidden_size + hidden_size 29 | if bias: 30 | state_ops += hidden_size * 2 31 | total_ops += state_ops * 2 32 | 33 | # n = \tanh(W_{in} x + b_{in} + r * (W_{hn} h + b_{hn})) \\ 34 | total_ops += (hidden_size + input_size) * hidden_size + hidden_size 35 | if bias: 36 | total_ops += hidden_size * 2 37 | # r hadamard : r * (~) 38 | total_ops += hidden_size 39 | 40 | # h' = (1 - z) * n + z * h 41 | # hadamard hadamard add 42 | total_ops += hidden_size * 3 43 | 44 | return total_ops 45 | 46 | 47 | def count_gru_cell(m: nn.GRUCell, x: torch.Tensor, y: torch.Tensor): 48 | total_ops = _count_gru_cell(m.input_size, m.hidden_size, m.bias) 49 | 50 | batch_size = x[0].size(0) 51 | total_ops *= batch_size 
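    # _count_gru_cell gives the per-sample cost (the r and z gates, the candidate
    # state n, and the final interpolation h' = (1 - z) * n + z * h); multiplying by
    # batch_size yields the cost of one cell call, accumulated on total_ops below.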
52 | 53 | m.total_ops += torch.DoubleTensor([int(total_ops)]) 54 | 55 | 56 | def _count_lstm_cell(input_size, hidden_size, bias=True): 57 | total_ops = 0 58 | 59 | # i = \sigma(W_{ii} x + b_{ii} + W_{hi} h + b_{hi}) \\ 60 | # f = \sigma(W_{if} x + b_{if} + W_{hf} h + b_{hf}) \\ 61 | # o = \sigma(W_{io} x + b_{io} + W_{ho} h + b_{ho}) \\ 62 | # g = \tanh(W_{ig} x + b_{ig} + W_{hg} h + b_{hg}) \\ 63 | state_ops = (input_size + hidden_size) * hidden_size + hidden_size 64 | if bias: 65 | state_ops += hidden_size * 2 66 | total_ops += state_ops * 4 67 | 68 | # c' = f * c + i * g \\ 69 | # hadamard hadamard add 70 | total_ops += hidden_size * 3 71 | 72 | # h' = o * \tanh(c') \\ 73 | total_ops += hidden_size 74 | 75 | return total_ops 76 | 77 | 78 | def count_lstm_cell(m: nn.LSTMCell, x: torch.Tensor, y: torch.Tensor): 79 | total_ops = _count_lstm_cell(m.input_size, m.hidden_size, m.bias) 80 | 81 | batch_size = x[0].size(0) 82 | total_ops *= batch_size 83 | 84 | m.total_ops += torch.DoubleTensor([int(total_ops)]) 85 | 86 | 87 | def count_rnn(m: nn.RNN, x, y): 88 | bias = m.bias 89 | input_size = m.input_size 90 | hidden_size = m.hidden_size 91 | num_layers = m.num_layers 92 | 93 | if isinstance(x[0], PackedSequence): 94 | batch_size = torch.max(x[0].batch_sizes) 95 | num_steps = x[0].batch_sizes.size(0) 96 | else: 97 | if m.batch_first: 98 | batch_size = x[0].size(0) 99 | num_steps = x[0].size(1) 100 | else: 101 | batch_size = x[0].size(1) 102 | num_steps = x[0].size(0) 103 | 104 | total_ops = 0 105 | if m.bidirectional: 106 | total_ops += _count_rnn_cell(input_size, hidden_size, bias) * 2 107 | else: 108 | total_ops += _count_rnn_cell(input_size, hidden_size, bias) 109 | 110 | for i in range(num_layers - 1): 111 | if m.bidirectional: 112 | total_ops += _count_rnn_cell(hidden_size * 2, hidden_size, bias) * 2 113 | else: 114 | total_ops += _count_rnn_cell(hidden_size, hidden_size, bias) 115 | 116 | # time unroll 117 | total_ops *= num_steps 118 | # batch_size 119 | total_ops *= batch_size 120 | 121 | m.total_ops += torch.DoubleTensor([int(total_ops)]) 122 | 123 | 124 | def count_gru(m: nn.GRU, x, y): 125 | bias = m.bias 126 | input_size = m.input_size 127 | hidden_size = m.hidden_size 128 | num_layers = m.num_layers 129 | 130 | if isinstance(x[0], PackedSequence): 131 | batch_size = torch.max(x[0].batch_sizes) 132 | num_steps = x[0].batch_sizes.size(0) 133 | else: 134 | if m.batch_first: 135 | batch_size = x[0].size(0) 136 | num_steps = x[0].size(1) 137 | else: 138 | batch_size = x[0].size(1) 139 | num_steps = x[0].size(0) 140 | 141 | total_ops = 0 142 | if m.bidirectional: 143 | total_ops += _count_gru_cell(input_size, hidden_size, bias) * 2 144 | else: 145 | total_ops += _count_gru_cell(input_size, hidden_size, bias) 146 | 147 | for i in range(num_layers - 1): 148 | if m.bidirectional: 149 | total_ops += _count_gru_cell(hidden_size * 2, hidden_size, bias) * 2 150 | else: 151 | total_ops += _count_gru_cell(hidden_size, hidden_size, bias) 152 | 153 | # time unroll 154 | total_ops *= num_steps 155 | # batch_size 156 | total_ops *= batch_size 157 | 158 | m.total_ops += torch.DoubleTensor([int(total_ops)]) 159 | 160 | 161 | def count_lstm(m: nn.LSTM, x, y): 162 | bias = m.bias 163 | input_size = m.input_size 164 | hidden_size = m.hidden_size 165 | num_layers = m.num_layers 166 | 167 | if isinstance(x[0], PackedSequence): 168 | batch_size = torch.max(x[0].batch_sizes) 169 | num_steps = x[0].batch_sizes.size(0) 170 | else: 171 | if m.batch_first: 172 | batch_size = x[0].size(0) 173 | num_steps = 
x[0].size(1) 174 | else: 175 | batch_size = x[0].size(1) 176 | num_steps = x[0].size(0) 177 | 178 | total_ops = 0 179 | if m.bidirectional: 180 | total_ops += _count_lstm_cell(input_size, hidden_size, bias) * 2 181 | else: 182 | total_ops += _count_lstm_cell(input_size, hidden_size, bias) 183 | 184 | for i in range(num_layers - 1): 185 | if m.bidirectional: 186 | total_ops += _count_lstm_cell(hidden_size * 2, hidden_size, bias) * 2 187 | else: 188 | total_ops += _count_lstm_cell(hidden_size, hidden_size, bias) 189 | 190 | # time unroll 191 | total_ops *= num_steps 192 | # batch_size 193 | total_ops *= batch_size 194 | 195 | m.total_ops += torch.DoubleTensor([int(total_ops)]) 196 | -------------------------------------------------------------------------------- /thop/utils.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Iterable 2 | 3 | COLOR_RED = "91m" 4 | COLOR_GREEN = "92m" 5 | COLOR_YELLOW = "93m" 6 | 7 | def colorful_print(fn_print, color=COLOR_RED): 8 | def actual_call(*args, **kwargs): 9 | print(f"\033[{color}", end="") 10 | fn_print(*args, **kwargs) 11 | print("\033[00m", end="") 12 | return actual_call 13 | 14 | prRed = colorful_print(print, color=COLOR_RED) 15 | prGreen = colorful_print(print, color=COLOR_GREEN) 16 | prYellow = colorful_print(print, color=COLOR_YELLOW) 17 | 18 | # def prRed(skk): 19 | # print("\033[91m{}\033[00m".format(skk)) 20 | 21 | # def prGreen(skk): 22 | # print("\033[92m{}\033[00m".format(skk)) 23 | 24 | # def prYellow(skk): 25 | # print("\033[93m{}\033[00m".format(skk)) 26 | 27 | 28 | def clever_format(nums, format="%.2f"): 29 | if not isinstance(nums, Iterable): 30 | nums = [nums] 31 | clever_nums = [] 32 | 33 | for num in nums: 34 | if num > 1e12: 35 | clever_nums.append(format % (num / 1e12) + "T") 36 | elif num > 1e9: 37 | clever_nums.append(format % (num / 1e9) + "G") 38 | elif num > 1e6: 39 | clever_nums.append(format % (num / 1e6) + "M") 40 | elif num > 1e3: 41 | clever_nums.append(format % (num / 1e3) + "K") 42 | else: 43 | clever_nums.append(format % num + "B") 44 | 45 | clever_nums = clever_nums[0] if len(clever_nums) == 1 else (*clever_nums,) 46 | 47 | return clever_nums 48 | 49 | 50 | if __name__ == "__main__": 51 | prRed("hello", "world") 52 | prGreen("hello", "world") 53 | prYellow("hello", "world") -------------------------------------------------------------------------------- /thop/vision/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deropty/PiT/779d3110808929f1f3d8c4e0ec6ec3f7f60dbffa/thop/vision/__init__.py -------------------------------------------------------------------------------- /thop/vision/basic_hooks.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | from .counter import ( 4 | counter_parameters, 5 | counter_conv, 6 | counter_norm, 7 | counter_relu, 8 | counter_softmax, 9 | counter_avgpool, 10 | counter_adap_avg, 11 | counter_zero_ops, 12 | counter_upsample, 13 | counter_linear, 14 | ) 15 | import torch 16 | import torch.nn as nn 17 | from torch.nn.modules.conv import _ConvNd 18 | 19 | multiply_adds = 1 20 | 21 | 22 | def count_parameters(m, x, y): 23 | total_params = 0 24 | for p in m.parameters(): 25 | total_params += torch.DoubleTensor([p.numel()]) 26 | m.total_params[0] = counter_parameters(m.parameters()) 27 | 28 | 29 | def zero_ops(m, x, y): 30 | m.total_ops += counter_zero_ops() 31 | 32 | 33 | 
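# Convolution hook: MACs are estimated as
#   output_elements * (in_channels / groups * kernel_h * kernel_w + bias_flag),
# i.e. the formula implemented by counter_conv in thop/vision/counter.py.
# Illustrative numbers: a 3x3 conv with in_channels=3, groups=1, no bias and a
# 1x16x8x8 output gives 16*8*8 * (3*9) = 27,648 MACs.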
def count_convNd(m: _ConvNd, x: (torch.Tensor,), y: torch.Tensor): 34 | x = x[0] 35 | 36 | kernel_ops = torch.zeros(m.weight.size()[2:]).numel() # Kw x Kh 37 | bias_ops = 1 if m.bias is not None else 0 38 | 39 | # N x Cout x H x W x (Cin x Kw x Kh + bias) 40 | m.total_ops += counter_conv( 41 | bias_ops, 42 | torch.zeros(m.weight.size()[2:]).numel(), 43 | y.nelement(), 44 | m.in_channels, 45 | m.groups, 46 | ) 47 | 48 | 49 | def count_convNd_ver2(m: _ConvNd, x: (torch.Tensor,), y: torch.Tensor): 50 | x = x[0] 51 | 52 | # N x H x W (exclude Cout) 53 | output_size = torch.zeros((y.size()[:1] + y.size()[2:])).numel() 54 | # # Cout x Cin x Kw x Kh 55 | # kernel_ops = m.weight.nelement() 56 | # if m.bias is not None: 57 | # # Cout x 1 58 | # kernel_ops += + m.bias.nelement() 59 | # # x N x H x W x Cout x (Cin x Kw x Kh + bias) 60 | # m.total_ops += torch.DoubleTensor([int(output_size * kernel_ops)]) 61 | m.total_ops += counter_conv(m.bias.nelement(), m.weight.nelement(), output_size) 62 | 63 | 64 | def count_bn(m, x, y): 65 | x = x[0] 66 | if not m.training: 67 | m.total_ops += counter_norm(x.numel()) 68 | 69 | 70 | def count_ln(m, x, y): 71 | x = x[0] 72 | if not m.training: 73 | m.total_ops += counter_norm(x.numel()) 74 | 75 | 76 | def count_in(m, x, y): 77 | x = x[0] 78 | if not m.training: 79 | m.total_ops += counter_norm(x.numel()) 80 | 81 | 82 | def count_prelu(m, x, y): 83 | x = x[0] 84 | 85 | nelements = x.numel() 86 | if not m.training: 87 | m.total_ops += counter_relu(nelements) 88 | 89 | 90 | def count_relu(m, x, y): 91 | x = x[0] 92 | 93 | nelements = x.numel() 94 | 95 | m.total_ops += counter_relu(nelements) 96 | 97 | 98 | def count_softmax(m, x, y): 99 | x = x[0] 100 | nfeatures = x.size()[m.dim] 101 | batch_size = x.numel() // nfeatures 102 | 103 | m.total_ops += counter_softmax(batch_size, nfeatures) 104 | 105 | 106 | def count_avgpool(m, x, y): 107 | # total_add = torch.prod(torch.Tensor([m.kernel_size])) 108 | # total_div = 1 109 | # kernel_ops = total_add + total_div 110 | num_elements = y.numel() 111 | m.total_ops += counter_avgpool(num_elements) 112 | 113 | 114 | def count_adap_avgpool(m, x, y): 115 | kernel = torch.DoubleTensor([*(x[0].shape[2:])]) // torch.DoubleTensor( 116 | [*(y.shape[2:])] 117 | ) 118 | total_add = torch.prod(kernel) 119 | num_elements = y.numel() 120 | m.total_ops += counter_adap_avg(total_add, num_elements) 121 | 122 | 123 | # TODO: verify the accuracy 124 | def count_upsample(m, x, y): 125 | if m.mode not in ( 126 | "nearest", 127 | "linear", 128 | "bilinear", 129 | "bicubic", 130 | ): # "trilinear" 131 | logging.warning("mode %s is not implemented yet, take it a zero op" % m.mode) 132 | return counter_zero_ops() 133 | 134 | if m.mode == "nearest": 135 | return counter_zero_ops() 136 | 137 | x = x[0] 138 | m.total_ops += counter_upsample(m.mode, y.nelement()) 139 | 140 | 141 | # nn.Linear 142 | def count_linear(m, x, y): 143 | # per output element 144 | total_mul = m.in_features 145 | # total_add = m.in_features - 1 146 | # total_add += 1 if m.bias is not None else 0 147 | num_elements = y.numel() 148 | 149 | m.total_ops += counter_linear(total_mul, num_elements) 150 | -------------------------------------------------------------------------------- /thop/vision/counter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def counter_parameters(para_list): 6 | total_params = 0 7 | for p in para_list: 8 | total_params += torch.DoubleTensor([p.nelement()]) 9 | return 
total_params 10 | 11 | 12 | def counter_zero_ops(): 13 | return torch.DoubleTensor([int(0)]) 14 | 15 | 16 | def counter_conv(bias, kernel_size, output_size, in_channel, group): 17 | """inputs are all numbers!""" 18 | return torch.DoubleTensor([output_size * (in_channel / group * kernel_size + bias)]) 19 | 20 | 21 | def counter_norm(input_size): 22 | """input is a number not a array or tensor""" 23 | return torch.DoubleTensor([2 * input_size]) 24 | 25 | 26 | def counter_relu(input_size: torch.Tensor): 27 | return torch.DoubleTensor([int(input_size)]) 28 | 29 | 30 | def counter_softmax(batch_size, nfeatures): 31 | total_exp = nfeatures 32 | total_add = nfeatures - 1 33 | total_div = nfeatures 34 | total_ops = batch_size * (total_exp + total_add + total_div) 35 | return torch.DoubleTensor([int(total_ops)]) 36 | 37 | 38 | def counter_avgpool(input_size): 39 | return torch.DoubleTensor([int(input_size)]) 40 | 41 | 42 | def counter_adap_avg(kernel_size, output_size): 43 | total_div = 1 44 | kernel_op = kernel_size + total_div 45 | return torch.DoubleTensor([int(kernel_op * output_size)]) 46 | 47 | 48 | def counter_upsample(mode: str, output_size): 49 | total_ops = output_size 50 | if mode == "linear": 51 | total_ops *= 5 52 | elif mode == "bilinear": 53 | total_ops *= 11 54 | elif mode == "bicubic": 55 | ops_solve_A = 224 # 128 muls + 96 adds 56 | ops_solve_p = 35 # 16 muls + 12 adds + 4 muls + 3 adds 57 | total_ops *= ops_solve_A + ops_solve_p 58 | elif mode == "trilinear": 59 | total_ops *= 13 * 2 + 5 60 | return torch.DoubleTensor([int(total_ops)]) 61 | 62 | 63 | def counter_linear(in_feature, num_elements): 64 | return torch.DoubleTensor([int(in_feature * num_elements)]) 65 | 66 | 67 | def counter_matmul(input_size, output_size): 68 | input_size = np.array(input_size) 69 | output_size = np.array(output_size) 70 | return np.prod(input_size) * output_size[-1] 71 | 72 | 73 | def counter_mul(input_size): 74 | return input_size 75 | 76 | 77 | def counter_pow(input_size): 78 | return input_size 79 | 80 | 81 | def counter_sqrt(input_size): 82 | return input_size 83 | 84 | 85 | def counter_div(input_size): 86 | return input_size 87 | -------------------------------------------------------------------------------- /thop/vision/efficientnet.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torch.nn.modules.conv import _ConvNd 7 | 8 | from efficientnet_pytorch.utils import Conv2dDynamicSamePadding, Conv2dStaticSamePadding 9 | 10 | register_hooks = {} 11 | -------------------------------------------------------------------------------- /thop/vision/onnx_counter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from onnx import numpy_helper 4 | from thop.vision.basic_hooks import zero_ops 5 | from .counter import ( 6 | counter_matmul, 7 | counter_zero_ops, 8 | counter_conv, 9 | counter_mul, 10 | counter_norm, 11 | counter_pow, 12 | counter_sqrt, 13 | counter_div, 14 | counter_softmax, 15 | counter_avgpool, 16 | ) 17 | 18 | 19 | def onnx_counter_matmul(diction, node): 20 | input1 = node.input[0] 21 | input2 = node.input[1] 22 | input1_dim = diction[input1] 23 | input2_dim = diction[input2] 24 | out_size = np.append(input1_dim[0:-1], input2_dim[-1]) 25 | output_name = node.output[0] 26 | macs = counter_matmul(input1_dim, out_size[-2:]) 27 | return macs, out_size, output_name 28 | 29 | 30 | def 
onnx_counter_add(diction, node): 31 | if np.array(diction[node.input[1]]).size >= np.array(diction[node.input[0]]).size: 32 | out_size = diction[node.input[1]] 33 | else: 34 | out_size = diction[node.input[0]] 35 | output_name = node.output[0] 36 | macs = counter_zero_ops() 37 | # if '140' in diction: 38 | # print(diction['140'],output_name) 39 | return macs, out_size, output_name 40 | 41 | 42 | def onnx_counter_conv(diction, node): 43 | # print(node) 44 | # bias,kernelsize,outputsize 45 | dim_bias = 0 46 | input_count = 0 47 | for i in node.input: 48 | input_count += 1 49 | if input_count == 3: 50 | dim_bias = 1 51 | dim_weight = diction[node.input[1]] 52 | else: 53 | dim_weight = diction[node.input[1]] 54 | for attr in node.attribute: 55 | # print(attr) 56 | if attr.name == "kernel_shape": 57 | dim_kernel = attr.ints # kw,kh 58 | if attr.name == "strides": 59 | dim_stride = attr.ints 60 | if attr.name == "pads": 61 | dim_pad = attr.ints 62 | if attr.name == "dilations": 63 | dim_dil = attr.ints 64 | if attr.name == "group": 65 | group = attr.i 66 | # print(dim_dil) 67 | dim_input = diction[node.input[0]] 68 | output_size = np.append( 69 | dim_input[0 : -np.array(dim_kernel).size - 1], dim_weight[0] 70 | ) 71 | hw = np.array(dim_input[-np.array(dim_kernel).size :]) 72 | for i in range(hw.size): 73 | hw[i] = int( 74 | (hw[i] + 2 * dim_pad[i] - dim_dil[i] * (dim_kernel[i] - 1) - 1) 75 | / dim_stride[i] 76 | + 1 77 | ) 78 | output_size = np.append(output_size, hw) 79 | macs = counter_conv( 80 | dim_bias, np.prod(dim_kernel), np.prod(output_size), dim_weight[1], group 81 | ) 82 | output_name = node.output[0] 83 | 84 | # if '140' in diction: 85 | # print("conv",diction['140'],output_name) 86 | return macs, output_size, output_name 87 | 88 | 89 | def onnx_counter_constant(diction, node): 90 | # print("constant",node) 91 | macs = counter_zero_ops() 92 | output_name = node.output[0] 93 | output_size = [1] 94 | # print(macs, output_size, output_name) 95 | return macs, output_size, output_name 96 | 97 | 98 | def onnx_counter_mul(diction, node): 99 | if np.array(diction[node.input[1]]).size >= np.array(diction[node.input[0]]).size: 100 | input_size = diction[node.input[1]] 101 | else: 102 | input_size = diction[node.input[0]] 103 | macs = counter_mul(np.prod(input_size)) 104 | output_size = diction[node.input[0]] 105 | output_name = node.output[0] 106 | return macs, output_size, output_name 107 | 108 | 109 | def onnx_counter_bn(diction, node): 110 | input_size = diction[node.input[0]] 111 | macs = counter_norm(np.prod(input_size)) 112 | output_name = node.output[0] 113 | output_size = input_size 114 | return macs, output_size, output_name 115 | 116 | 117 | def onnx_counter_relu(diction, node): 118 | input_size = diction[node.input[0]] 119 | macs = counter_zero_ops() 120 | output_name = node.output[0] 121 | output_size = input_size 122 | # print(macs, output_size, output_name) 123 | # if '140' in diction: 124 | # print("relu",diction['140'],output_name) 125 | return macs, output_size, output_name 126 | 127 | 128 | def onnx_counter_reducemean(diction, node): 129 | keep_dim = 0 130 | for attr in node.attribute: 131 | if "axes" in attr.name: 132 | dim_axis = np.array(attr.ints) 133 | elif "keepdims" in attr.name: 134 | keep_dim = attr.i 135 | 136 | input_size = diction[node.input[0]] 137 | macs = counter_zero_ops() 138 | output_name = node.output[0] 139 | if keep_dim == 1: 140 | output_size = input_size 141 | else: 142 | output_size = np.delete(input_size, dim_axis) 143 | # output_size = input_size 144 
| return macs, output_size, output_name 145 | 146 | 147 | def onnx_counter_sub(diction, node): 148 | input_size = diction[node.input[0]] 149 | macs = counter_zero_ops() 150 | output_name = node.output[0] 151 | output_size = input_size 152 | return macs, output_size, output_name 153 | 154 | 155 | def onnx_counter_pow(diction, node): 156 | if np.array(diction[node.input[1]]).size >= np.array(diction[node.input[0]]).size: 157 | input_size = diction[node.input[1]] 158 | else: 159 | input_size = diction[node.input[0]] 160 | macs = counter_pow(np.prod(input_size)) 161 | output_name = node.output[0] 162 | output_size = input_size 163 | return macs, output_size, output_name 164 | 165 | 166 | def onnx_counter_sqrt(diction, node): 167 | input_size = diction[node.input[0]] 168 | macs = counter_sqrt(np.prod(input_size)) 169 | output_name = node.output[0] 170 | output_size = input_size 171 | return macs, output_size, output_name 172 | 173 | 174 | def onnx_counter_div(diction, node): 175 | if np.array(diction[node.input[1]]).size >= np.array(diction[node.input[0]]).size: 176 | input_size = diction[node.input[1]] 177 | else: 178 | input_size = diction[node.input[0]] 179 | macs = counter_div(np.prod(input_size)) 180 | output_name = node.output[0] 181 | output_size = input_size 182 | return macs, output_size, output_name 183 | 184 | 185 | def onnx_counter_instance(diction, node): 186 | input_size = diction[node.input[0]] 187 | macs = counter_norm(np.prod(input_size)) 188 | output_name = node.output[0] 189 | output_size = input_size 190 | return macs, output_size, output_name 191 | 192 | 193 | def onnx_counter_softmax(diction, node): 194 | input_size = diction[node.input[0]] 195 | dim = node.attribute[0].i 196 | nfeatures = input_size[dim] 197 | batch_size = np.prod(input_size) / nfeatures 198 | macs = counter_softmax(nfeatures, batch_size) 199 | output_name = node.output[0] 200 | output_size = input_size 201 | return macs, output_size, output_name 202 | 203 | 204 | def onnx_counter_pad(diction, node): 205 | # # TODO add constant name and output real vector 206 | # if 207 | # if (np.array(diction[node.input[1]]).size >= np.array(diction[node.input[0]]).size): 208 | # input_size = diction[node.input[1]] 209 | # else: 210 | # input_size = diction[node.input[0]] 211 | input_size = diction[node.input[0]] 212 | macs = counter_zero_ops() 213 | output_name = node.output[0] 214 | output_size = input_size 215 | return macs, output_size, output_name 216 | 217 | 218 | def onnx_counter_averagepool(diction, node): 219 | # TODO add support of ceil_mode and floor 220 | macs = counter_avgpool(np.prod(diction[node.input[0]])) 221 | output_name = node.output[0] 222 | dim_pad = None 223 | for attr in node.attribute: 224 | # print(attr) 225 | if attr.name == "kernel_shape": 226 | dim_kernel = attr.ints # kw,kh 227 | elif attr.name == "strides": 228 | dim_stride = attr.ints 229 | elif attr.name == "pads": 230 | dim_pad = attr.ints 231 | elif attr.name == "dilations": 232 | dim_dil = attr.ints 233 | # print(dim_dil) 234 | dim_input = diction[node.input[0]] 235 | hw = dim_input[-np.array(dim_kernel).size :] 236 | if dim_pad is not None: 237 | for i in range(hw.size): 238 | hw[i] = int((hw[i] + 2 * dim_pad[i] - dim_kernel[i]) / dim_stride[i] + 1) 239 | output_size = np.append(dim_input[0 : -np.array(dim_kernel).size], hw) 240 | else: 241 | for i in range(hw.size): 242 | hw[i] = int((hw[i] - dim_kernel[i]) / dim_stride[i] + 1) 243 | output_size = np.append(dim_input[0 : -np.array(dim_kernel).size], hw) 244 | # print(macs, 
output_size, output_name) 245 | return macs, output_size, output_name 246 | 247 | 248 | def onnx_counter_flatten(diction, node): 249 | # print(node) 250 | macs = counter_zero_ops() 251 | output_name = node.output[0] 252 | axis = node.attribute[0].i 253 | input_size = diction[node.input[0]] 254 | output_size = np.append(input_size[axis - 1], np.prod(input_size[axis:])) 255 | # print("flatten",output_size) 256 | return macs, output_size, output_name 257 | 258 | 259 | def onnx_counter_gemm(diction, node): 260 | # print(node) 261 | # Compute Y = alpha * A' * B' + beta * C 262 | input_size = diction[node.input[0]] 263 | dim_weight = diction[node.input[1]] 264 | # print(input_size,dim_weight) 265 | macs = np.prod(input_size) * dim_weight[1] + dim_weight[0] 266 | output_size = np.append(input_size[0:-1], dim_weight[0]) 267 | output_name = node.output[0] 268 | return macs, output_size, output_name 269 | pass 270 | 271 | 272 | def onnx_counter_maxpool(diction, node): 273 | # TODO add support of ceil_mode and floor 274 | # print(node) 275 | macs = counter_zero_ops() 276 | output_name = node.output[0] 277 | dim_pad = None 278 | for attr in node.attribute: 279 | # print(attr) 280 | if attr.name == "kernel_shape": 281 | dim_kernel = attr.ints # kw,kh 282 | elif attr.name == "strides": 283 | dim_stride = attr.ints 284 | elif attr.name == "pads": 285 | dim_pad = attr.ints 286 | elif attr.name == "dilations": 287 | dim_dil = attr.ints 288 | # print(dim_dil) 289 | dim_input = diction[node.input[0]] 290 | hw = dim_input[-np.array(dim_kernel).size :] 291 | if dim_pad is not None: 292 | for i in range(hw.size): 293 | hw[i] = int((hw[i] + 2 * dim_pad[i] - dim_kernel[i]) / dim_stride[i] + 1) 294 | output_size = np.append(dim_input[0 : -np.array(dim_kernel).size], hw) 295 | else: 296 | for i in range(hw.size): 297 | hw[i] = int((hw[i] - dim_kernel[i]) / dim_stride[i] + 1) 298 | output_size = np.append(dim_input[0 : -np.array(dim_kernel).size], hw) 299 | # print(macs, output_size, output_name) 300 | return macs, output_size, output_name 301 | 302 | 303 | def onnx_counter_globalaveragepool(diction, node): 304 | macs = counter_zero_ops() 305 | output_name = node.output[0] 306 | input_size = diction[node.input[0]] 307 | output_size = input_size 308 | return macs, output_size, output_name 309 | 310 | 311 | def onnx_counter_concat(diction, node): 312 | # print(node) 313 | # print(diction[node.input[0]]) 314 | axis = node.attribute[0].i 315 | input_size = diction[node.input[0]] 316 | for i in node.input: 317 | dim_concat = diction[i][axis] 318 | output_size = input_size 319 | output_size[axis] = dim_concat 320 | output_name = node.output[0] 321 | macs = counter_zero_ops() 322 | return macs, output_size, output_name 323 | 324 | 325 | def onnx_counter_clip(diction, node): 326 | macs = counter_zero_ops() 327 | output_name = node.output[0] 328 | input_size = diction[node.input[0]] 329 | output_size = input_size 330 | return macs, output_size, output_name 331 | 332 | 333 | onnx_operators = { 334 | "MatMul": onnx_counter_matmul, 335 | "Add": onnx_counter_add, 336 | "Conv": onnx_counter_conv, 337 | "Mul": onnx_counter_mul, 338 | "Constant": onnx_counter_constant, 339 | "BatchNormalization": onnx_counter_bn, 340 | "Relu": onnx_counter_relu, 341 | "ReduceMean": onnx_counter_reducemean, 342 | "Sub": onnx_counter_sub, 343 | "Pow": onnx_counter_pow, 344 | "Sqrt": onnx_counter_sqrt, 345 | "Div": onnx_counter_div, 346 | "InstanceNormalization": onnx_counter_instance, 347 | "Softmax": onnx_counter_softmax, 348 | "Pad": 
onnx_counter_pad, 349 | "AveragePool": onnx_counter_averagepool, 350 | "MaxPool": onnx_counter_maxpool, 351 | "Flatten": onnx_counter_flatten, 352 | "Gemm": onnx_counter_gemm, 353 | "GlobalAveragePool": onnx_counter_globalaveragepool, 354 | "Concat": onnx_counter_concat, 355 | "Clip": onnx_counter_clip, 356 | None: None, 357 | } 358 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from utils.logger import setup_logger 2 | from datasets import make_dataloader 3 | from model import make_model 4 | from solver import make_optimizer 5 | from solver.scheduler_factory import create_scheduler 6 | from loss import make_loss 7 | from processor import do_train 8 | import random 9 | import torch 10 | import numpy as np 11 | import os 12 | import argparse 13 | # from timm.scheduler import create_scheduler 14 | from config import cfg 15 | from utils.saver import Saver 16 | 17 | def set_seed(seed): 18 | torch.manual_seed(seed) 19 | torch.cuda.manual_seed(seed) 20 | torch.cuda.manual_seed_all(seed) 21 | np.random.seed(seed) 22 | random.seed(seed) 23 | torch.backends.cudnn.deterministic = True 24 | torch.backends.cudnn.benchmark = True 25 | 26 | if __name__ == '__main__': 27 | 28 | parser = argparse.ArgumentParser(description="ReID Baseline Training") 29 | parser.add_argument( 30 | "--config_file", default="", help="path to config file", type=str 31 | ) 32 | 33 | parser.add_argument("opts", help="Modify config options using the command-line", default=None, 34 | nargs=argparse.REMAINDER) 35 | parser.add_argument("--local_rank", default=0, type=int) 36 | args = parser.parse_args() 37 | 38 | if args.config_file != "": 39 | cfg.merge_from_file(args.config_file) 40 | cfg.merge_from_list(args.opts) 41 | cfg.freeze() 42 | 43 | set_seed(cfg.SOLVER.SEED) 44 | 45 | if cfg.MODEL.DIST_TRAIN: 46 | torch.cuda.set_device(args.local_rank) 47 | 48 | output_dir = cfg.OUTPUT_DIR 49 | if output_dir and not os.path.exists(output_dir): 50 | os.makedirs(output_dir) 51 | 52 | logger = setup_logger("pit", output_dir, if_train=True) 53 | logger.info("Saving model in the path :{}".format(cfg.OUTPUT_DIR)) 54 | logger.info(args) 55 | 56 | if args.config_file != "": 57 | logger.info("Loaded configuration file {}".format(args.config_file)) 58 | with open(args.config_file, 'r') as cf: 59 | config_str = "\n" + cf.read() 60 | logger.info(config_str) 61 | logger.info("Running with config:\n{}".format(cfg)) 62 | 63 | if cfg.MODEL.DIST_TRAIN: 64 | torch.distributed.init_process_group(backend='nccl', init_method='env://') 65 | 66 | os.environ['CUDA_VISIBLE_DEVICES'] = cfg.MODEL.DEVICE_ID 67 | train_loader, val_loader, num_query, num_classes, camera_num, view_num = make_dataloader(cfg) 68 | 69 | if cfg.DATASETS.NAMES in ['ilids']: 70 | num_trails = 10 71 | else: 72 | num_trails = 1 73 | cmcs, mAPs = [], [] 74 | for i in range(num_trails): 75 | output_dir = cfg.OUTPUT_DIR + '/' + str(i + 1) 76 | test_weight = cfg.TEST.WEIGHT + '/' + str(i + 1) 77 | 78 | saver = Saver(output_dir, f'tensorboard') 79 | model = make_model(cfg, num_class=num_classes, camera_num=camera_num, view_num = view_num) 80 | test = os.path.isdir(test_weight) 81 | if test: 82 | model.load_param(os.path.join(test_weight, 'transformer_' + str(120) + '.pth')) 83 | 84 | computation_complexity = False 85 | if computation_complexity: 86 | from thop import profile 87 | import torch 88 | 89 | x = torch.randn(1, 1, 3, 256, 128) 90 | label = 
torch.randn(16).long() 91 | cam_label = torch.randn(1).long().clamp(0, 1) 92 | view_label = torch.randn(1).long().clamp(1, 1) 93 | macs, params = profile(model, inputs=(x,label, cam_label, view_label)) 94 | print(macs); print(params); 95 | 96 | loss_func, center_criterion = make_loss(cfg, num_classes=num_classes) 97 | 98 | optimizer, optimizer_center = make_optimizer(cfg, model, center_criterion) 99 | 100 | scheduler = create_scheduler(cfg, optimizer) 101 | 102 | cmc, mAP = do_train( 103 | cfg, 104 | model, 105 | center_criterion, 106 | train_loader[i], 107 | val_loader[i], 108 | optimizer, 109 | optimizer_center, 110 | scheduler, 111 | loss_func, 112 | num_query, args.local_rank, saver, i, test 113 | ) 114 | cmcs.append(cmc) 115 | mAPs.append(mAP) 116 | mAP = np.stack(mAPs).mean(axis=0) 117 | cmc = np.stack(cmcs).mean(axis=0) 118 | logger.info("{} trails average:".format(num_trails)) 119 | logger.info("mAP: {:.3%}".format(mAP)) 120 | for r in [1, 5, 10, 20]: 121 | logger.info("CMC curve, Rank-{:<3}:{:.3%}".format(r, cmc[r - 1])) 122 | 123 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deropty/PiT/779d3110808929f1f3d8c4e0ec6ec3f7f60dbffa/utils/__init__.py -------------------------------------------------------------------------------- /utils/iotools.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import errno 8 | import json 9 | import os 10 | 11 | import os.path as osp 12 | 13 | 14 | def mkdir_if_missing(directory): 15 | if not osp.exists(directory): 16 | try: 17 | os.makedirs(directory) 18 | except OSError as e: 19 | if e.errno != errno.EEXIST: 20 | raise 21 | 22 | 23 | def check_isfile(path): 24 | isfile = osp.isfile(path) 25 | if not isfile: 26 | print("=> Warning: no file found at '{}' (ignored)".format(path)) 27 | return isfile 28 | 29 | 30 | def read_json(fpath): 31 | with open(fpath, 'r') as f: 32 | obj = json.load(f) 33 | return obj 34 | 35 | 36 | def write_json(obj, fpath): 37 | mkdir_if_missing(osp.dirname(fpath)) 38 | with open(fpath, 'w') as f: 39 | json.dump(obj, f, indent=4, separators=(',', ': ')) 40 | -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import sys 4 | import os.path as osp 5 | def setup_logger(name, save_dir, if_train): 6 | logger = logging.getLogger(name) 7 | logger.setLevel(logging.DEBUG) 8 | 9 | ch = logging.StreamHandler(stream=sys.stdout) 10 | ch.setLevel(logging.DEBUG) 11 | formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s") 12 | ch.setFormatter(formatter) 13 | logger.addHandler(ch) 14 | 15 | if save_dir: 16 | if not osp.exists(save_dir): 17 | os.makedirs(save_dir) 18 | if if_train: 19 | fh = logging.FileHandler(os.path.join(save_dir, "train_log.txt"), mode='w') 20 | else: 21 | fh = logging.FileHandler(os.path.join(save_dir, "test_log.txt"), mode='w') 22 | fh.setLevel(logging.DEBUG) 23 | fh.setFormatter(formatter) 24 | logger.addHandler(fh) 25 | 26 | return logger -------------------------------------------------------------------------------- /utils/meter.py: -------------------------------------------------------------------------------- 1 | class 
AverageMeter(object): 2 | """Computes and stores the average and current value""" 3 | 4 | def __init__(self): 5 | self.val = 0 6 | self.avg = 0 7 | self.sum = 0 8 | self.count = 0 9 | 10 | def reset(self): 11 | self.val = 0 12 | self.avg = 0 13 | self.sum = 0 14 | self.count = 0 15 | 16 | def update(self, val, n=1): 17 | self.val = val 18 | self.sum += val * n 19 | self.count += n 20 | self.avg = self.sum / self.count -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import os 4 | from utils.reranking import re_ranking 5 | 6 | 7 | def euclidean_distance(qf, gf): 8 | m = qf.shape[0] 9 | n = gf.shape[0] 10 | dist_mat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \ 11 | torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t() 12 | dist_mat.addmm_(1, -2, qf, gf.t()) 13 | return dist_mat.cpu().numpy() 14 | 15 | def cosine_similarity(qf, gf): 16 | epsilon = 0.00001 17 | dist_mat = qf.mm(gf.t()) 18 | qf_norm = torch.norm(qf, p=2, dim=1, keepdim=True) # mx1 19 | gf_norm = torch.norm(gf, p=2, dim=1, keepdim=True) # nx1 20 | qg_normdot = qf_norm.mm(gf_norm.t()) 21 | 22 | dist_mat = dist_mat.mul(1 / qg_normdot).cpu().numpy() 23 | dist_mat = np.clip(dist_mat, -1 + epsilon, 1 - epsilon) 24 | dist_mat = np.arccos(dist_mat) 25 | return dist_mat 26 | 27 | 28 | def eval_func(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=50): 29 | """Evaluation with market1501 metric 30 | Key: for each query identity, its gallery images from the same camera view are discarded. 31 | """ 32 | num_q, num_g = distmat.shape 33 | # distmat g 34 | # q 1 3 2 4 35 | # 4 1 2 3 36 | if num_g < max_rank: 37 | max_rank = num_g 38 | print("Note: number of gallery samples is quite small, got {}".format(num_g)) 39 | indices = np.argsort(distmat, axis=1) 40 | # 0 2 1 3 41 | # 1 2 3 0 42 | matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32) 43 | # compute cmc curve for each query 44 | all_cmc = [] 45 | all_AP = [] 46 | num_valid_q = 0. # number of valid query 47 | All_AP = [] 48 | for q_idx in range(num_q): 49 | # get query pid and camid 50 | q_pid = q_pids[q_idx] 51 | q_camid = q_camids[q_idx] 52 | 53 | # remove gallery samples that have the same pid and camid with query 54 | order = indices[q_idx] # select one row 55 | remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid) 56 | keep = np.invert(remove) 57 | 58 | # compute cmc curve 59 | # binary vector, positions with value 1 are correct matches 60 | orig_cmc = matches[q_idx][keep] 61 | if not np.any(orig_cmc): 62 | # this condition is true when query identity does not appear in gallery 63 | All_AP.append('None') 64 | continue 65 | 66 | cmc = orig_cmc.cumsum() 67 | cmc[cmc > 1] = 1 68 | 69 | all_cmc.append(cmc[:max_rank]) 70 | num_valid_q += 1. 71 | 72 | # compute average precision 73 | # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision 74 | num_rel = orig_cmc.sum() 75 | tmp_cmc = orig_cmc.cumsum() 76 | #tmp_cmc = [x / (i + 1.) 
for i, x in enumerate(tmp_cmc)] 77 | y = np.arange(1, tmp_cmc.shape[0] + 1) * 1.0 78 | tmp_cmc = tmp_cmc / y 79 | tmp_cmc = np.asarray(tmp_cmc) * orig_cmc 80 | AP = tmp_cmc.sum() / num_rel 81 | all_AP.append(AP) 82 | All_AP.append(AP) 83 | 84 | assert num_valid_q > 0, "Error: all query identities do not appear in gallery" 85 | 86 | all_cmc = np.asarray(all_cmc).astype(np.float32) 87 | all_cmc = all_cmc.sum(0) / num_valid_q 88 | mAP = np.mean(all_AP) 89 | 90 | import pandas as pd 91 | All_AP = np.asarray(All_AP) 92 | data = pd.DataFrame({'AP': All_AP}) 93 | data.to_csv('All_AP.csv', index=False, sep=',') 94 | 95 | data = pd.DataFrame({str(i): distmat[:, i] for i in range(distmat.shape[1])}) 96 | data.to_csv('distmat.csv', index=True, sep=',') 97 | 98 | return all_cmc, mAP 99 | 100 | 101 | class R1_mAP_eval(): 102 | def __init__(self, num_query, max_rank=50, feat_norm=True, reranking=False): 103 | super(R1_mAP_eval, self).__init__() 104 | self.num_query = num_query 105 | self.max_rank = max_rank 106 | self.feat_norm = feat_norm 107 | self.reranking = reranking 108 | 109 | def reset(self): 110 | self.feats = [] 111 | self.pids = [] 112 | self.camids = [] 113 | 114 | def update(self, output): # called once for each batch 115 | feat, pid, camid = output 116 | self.feats.append(feat.cpu()) 117 | self.pids.extend(np.asarray(pid)) 118 | self.camids.extend(np.asarray(camid)) 119 | 120 | def compute(self): # called after each epoch 121 | feats = torch.cat(self.feats, dim=0) 122 | if self.feat_norm: 123 | print("The test feature is normalized") 124 | feats = torch.nn.functional.normalize(feats, dim=1, p=2) # along channel 125 | # query 126 | qf = feats[:self.num_query] 127 | q_pids = np.asarray(self.pids[:self.num_query]) 128 | q_camids = np.asarray(self.camids[:self.num_query]) 129 | # gallery 130 | gf = feats[self.num_query:] 131 | g_pids = np.asarray(self.pids[self.num_query:]) 132 | 133 | g_camids = np.asarray(self.camids[self.num_query:]) 134 | if self.reranking: 135 | print('=> Enter reranking') 136 | # distmat = re_ranking(qf, gf, k1=20, k2=6, lambda_value=0.3) 137 | distmat = re_ranking(qf, gf, k1=50, k2=15, lambda_value=0.3) 138 | 139 | else: 140 | print('=> Computing DistMat with euclidean_distance') 141 | distmat = euclidean_distance(qf, gf) 142 | cmc, mAP = eval_func(distmat, q_pids, g_pids, q_camids, g_camids) 143 | 144 | return cmc, mAP, distmat, self.pids, self.camids, qf, gf 145 | 146 | 147 | 148 | -------------------------------------------------------------------------------- /utils/reranking.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Fri, 25 May 2018 20:29:09 5 | 6 | 7 | """ 8 | 9 | """ 10 | CVPR2017 paper:Zhong Z, Zheng L, Cao D, et al. Re-ranking Person Re-identification with k-reciprocal Encoding[J]. 2017. 
11 | url:http://openaccess.thecvf.com/content_cvpr_2017/papers/Zhong_Re-Ranking_Person_Re-Identification_CVPR_2017_paper.pdf 12 | Matlab version: https://github.com/zhunzhong07/person-re-ranking 13 | """ 14 | 15 | """ 16 | API 17 | 18 | probFea: all feature vectors of the query set (torch tensor) 19 | probFea: all feature vectors of the gallery set (torch tensor) 20 | k1,k2,lambda: parameters, the original paper is (k1=20,k2=6,lambda=0.3) 21 | MemorySave: set to 'True' when using MemorySave mode 22 | Minibatch: avaliable when 'MemorySave' is 'True' 23 | """ 24 | 25 | import numpy as np 26 | import torch 27 | 28 | 29 | def re_ranking(probFea, galFea, k1, k2, lambda_value, local_distmat=None, only_local=False): 30 | # if feature vector is numpy, you should use 'torch.tensor' transform it to tensor 31 | query_num = probFea.size(0) 32 | all_num = query_num + galFea.size(0) 33 | if only_local: 34 | original_dist = local_distmat 35 | else: 36 | feat = torch.cat([probFea, galFea]) 37 | # print('using GPU to compute original distance') 38 | distmat = torch.pow(feat, 2).sum(dim=1, keepdim=True).expand(all_num, all_num) + \ 39 | torch.pow(feat, 2).sum(dim=1, keepdim=True).expand(all_num, all_num).t() 40 | distmat.addmm_(1, -2, feat, feat.t()) 41 | original_dist = distmat.cpu().numpy() 42 | del feat 43 | if not local_distmat is None: 44 | original_dist = original_dist + local_distmat 45 | gallery_num = original_dist.shape[0] 46 | original_dist = np.transpose(original_dist / np.max(original_dist, axis=0)) 47 | V = np.zeros_like(original_dist).astype(np.float16) 48 | initial_rank = np.argsort(original_dist).astype(np.int32) 49 | 50 | # print('starting re_ranking') 51 | for i in range(all_num): 52 | # k-reciprocal neighbors 53 | forward_k_neigh_index = initial_rank[i, :k1 + 1] 54 | backward_k_neigh_index = initial_rank[forward_k_neigh_index, :k1 + 1] 55 | fi = np.where(backward_k_neigh_index == i)[0] 56 | k_reciprocal_index = forward_k_neigh_index[fi] 57 | k_reciprocal_expansion_index = k_reciprocal_index 58 | for j in range(len(k_reciprocal_index)): 59 | candidate = k_reciprocal_index[j] 60 | candidate_forward_k_neigh_index = initial_rank[candidate, :int(np.around(k1 / 2)) + 1] 61 | candidate_backward_k_neigh_index = initial_rank[candidate_forward_k_neigh_index, 62 | :int(np.around(k1 / 2)) + 1] 63 | fi_candidate = np.where(candidate_backward_k_neigh_index == candidate)[0] 64 | candidate_k_reciprocal_index = candidate_forward_k_neigh_index[fi_candidate] 65 | if len(np.intersect1d(candidate_k_reciprocal_index, k_reciprocal_index)) > 2 / 3 * len( 66 | candidate_k_reciprocal_index): 67 | k_reciprocal_expansion_index = np.append(k_reciprocal_expansion_index, candidate_k_reciprocal_index) 68 | 69 | k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index) 70 | weight = np.exp(-original_dist[i, k_reciprocal_expansion_index]) 71 | V[i, k_reciprocal_expansion_index] = weight / np.sum(weight) 72 | original_dist = original_dist[:query_num, ] 73 | if k2 != 1: 74 | V_qe = np.zeros_like(V, dtype=np.float16) 75 | for i in range(all_num): 76 | V_qe[i, :] = np.mean(V[initial_rank[i, :k2], :], axis=0) 77 | V = V_qe 78 | del V_qe 79 | del initial_rank 80 | invIndex = [] 81 | for i in range(gallery_num): 82 | invIndex.append(np.where(V[:, i] != 0)[0]) 83 | 84 | jaccard_dist = np.zeros_like(original_dist, dtype=np.float16) 85 | 86 | for i in range(query_num): 87 | temp_min = np.zeros(shape=[1, gallery_num], dtype=np.float16) 88 | indNonZero = np.where(V[i, :] != 0)[0] 89 | indImages = [invIndex[ind] for ind 
in indNonZero] 90 | for j in range(len(indNonZero)): 91 | temp_min[0, indImages[j]] = temp_min[0, indImages[j]] + np.minimum(V[i, indNonZero[j]], 92 | V[indImages[j], indNonZero[j]]) 93 | jaccard_dist[i] = 1 - temp_min / (2 - temp_min) 94 | 95 | final_dist = jaccard_dist * (1 - lambda_value) + original_dist * lambda_value 96 | del original_dist 97 | del V 98 | del jaccard_dist 99 | final_dist = final_dist[:query_num, query_num:] 100 | return final_dist 101 | 102 | -------------------------------------------------------------------------------- /utils/saver.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | from pathlib import Path 3 | 4 | import cv2 5 | import numpy as np 6 | import torch 7 | import json 8 | 9 | from torch.utils.tensorboard import SummaryWriter 10 | 11 | 12 | class Saver(object): 13 | """ 14 | """ 15 | def __init__(self, path: str, uuid: str): 16 | self.path = Path(path) / uuid 17 | self.path.mkdir(exist_ok=True, parents=True) 18 | 19 | self.chk_path = self.path / 'chk' 20 | self.chk_path.mkdir(exist_ok=True) 21 | 22 | self.log_path = self.path / 'logs' 23 | self.log_path.mkdir(exist_ok=True) 24 | 25 | self.params_path = self.path / 'params' 26 | self.params_path.mkdir(exist_ok=True) 27 | 28 | # TB logs 29 | self.writer = SummaryWriter(str(self.path)) 30 | 31 | # Dump the `git log` and `git diff`. In this way one can checkout 32 | # the last commit, add the diff and should be in the same state. 33 | for cmd in ['log', 'diff']: 34 | with open(self.path / f'git_{cmd}.txt', mode='wt') as f: 35 | subprocess.run(['git', cmd], stdout=f) 36 | 37 | def load_logs(self): 38 | with open(str(self.params_path / 'params.json'), 'r') as fp: 39 | params = json.load(fp) 40 | with open(str(self.params_path / 'hparams.json'), 'r') as fp: 41 | hparams = json.load(fp) 42 | return params, hparams 43 | 44 | @staticmethod 45 | def load_net(path: str, chk_name: str, dataset_name: str): 46 | with open(str(Path(path) / 'params' / 'hparams.json'), 'r') as fp: 47 | net_hparams = json.load(fp) 48 | with open(str(Path(path) / 'params' / 'params.json'), 'r') as fp: 49 | net_params = json.load(fp) 50 | 51 | assert dataset_name == net_params['dataset_name'] 52 | net = TriNet(backbone_type=net_hparams['backbone_type'], pretrained=True, 53 | num_classes=net_hparams['num_classes']) 54 | net_state_dict = torch.load(Path(path) / 'chk' / chk_name) 55 | net.load_state_dict(net_state_dict) 56 | return net 57 | 58 | def write_logs(self, model: torch.nn.Module, params: dict): 59 | with open(str(self.params_path / 'params.json'), 'w') as fp: 60 | json.dump(params, fp) 61 | with open(str(self.params_path / 'hparams.json'), 'w') as fp: 62 | json.dump(model.get_hparams(), fp) 63 | 64 | def write_image(self, image: np.ndarray, epoch: int, name: str): 65 | out_image_path = self.log_path / f'{epoch:05d}_{name}.jpg' 66 | cv2.imwrite(str(out_image_path), image) 67 | 68 | image = image[..., ::-1] 69 | self.writer.add_image(f'{name}', image, epoch, dataformats='HWC') 70 | 71 | def dump_metric_tb(self, value: float, epoch: int, m_type: str, m_desc: str): 72 | self.writer.add_scalar(f'{m_type}/{m_desc}', value, epoch) 73 | 74 | def save_net(self, net: torch.nn.Module, name: str = 'weights', overwrite: bool = False): 75 | weights_path = self.chk_path / name 76 | if weights_path.exists() and not overwrite: 77 | raise ValueError('PREVENT OVERWRITE WEIGHTS') 78 | torch.save(net.state_dict(), weights_path) 79 | 80 | def dump_hparams(self, hparams: dict, metrics: dict): 81 | 
self.writer.add_hparams(hparams, metrics) 82 | -------------------------------------------------------------------------------- /vis.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch, cv2 3 | import numpy as np 4 | from config import cfg 5 | import argparse 6 | from datasets import make_dataloader 7 | from model import make_model 8 | from processor import do_inference 9 | from utils.logger import setup_logger 10 | import torchvision.transforms as T 11 | from PIL import Image 12 | import matplotlib.pyplot as plt 13 | from model.make_model import __num_of_layers 14 | 15 | 16 | if __name__ == "__main__": 17 | # import matplotlib.pyplot as plt 18 | # from matplotlib.backends.backend_pdf import PdfPages 19 | # 20 | # x1 = [1, 2, 3, 5, 6, 7] 21 | # x2 = [1, 6, 14, 15] 22 | # y1 = [84, 85.07, 85.72, 85.56, 85.96, 85.99] 23 | # y2 = [84, 85.33, 85.86, 86.24, 85.94, 86] 24 | # y3 = [84, 86.01, 86.11, 86.17] 25 | # plt.plot(x1, y1, 's-', color='r', label="Vertical Division") 26 | # plt.plot(x1, y2, 'o-', color='g', label="Horizontal Division") 27 | # plt.plot(x2, y3, 'd-', color='b', label="Patch-based Division") 28 | # plt.xlabel("Number of Parts") 29 | # plt.ylabel("mAP") 30 | # plt.legend(loc="best") 31 | # plt.savefig("map.pdf") 32 | parser = argparse.ArgumentParser(description="ReID Baseline Training") 33 | parser.add_argument( 34 | "--config_file", default="", help="path to config file", type=str 35 | ) 36 | parser.add_argument("opts", help="Modify config options using the command-line", default=None, 37 | nargs=argparse.REMAINDER) 38 | args = parser.parse_args() 39 | 40 | 41 | 42 | if args.config_file != "": 43 | cfg.merge_from_file(args.config_file) 44 | cfg.merge_from_list(args.opts) 45 | cfg.merge_from_list(['TEST.VIS', "True"]) 46 | cfg.freeze() 47 | 48 | output_dir = cfg.OUTPUT_DIR 49 | if output_dir and not os.path.exists(output_dir): 50 | os.makedirs(output_dir) 51 | 52 | logger = setup_logger("pit", output_dir, if_train=False) 53 | logger.info(args) 54 | 55 | if args.config_file != "": 56 | logger.info("Loaded configuration file {}".format(args.config_file)) 57 | with open(args.config_file, 'r') as cf: 58 | config_str = "\n" + cf.read() 59 | logger.info(config_str) 60 | logger.info("Running with config:\n{}".format(cfg)) 61 | 62 | os.environ['CUDA_VISIBLE_DEVICES'] = cfg.MODEL.DEVICE_ID 63 | 64 | epochs = cfg.SOLVER.MAX_EPOCHS 65 | eval_period = cfg.SOLVER.EVAL_PERIOD 66 | OUTPUT_DIR = cfg.TEST.WEIGHT 67 | 68 | val_transforms = T.Compose([ 69 | T.Resize(cfg.INPUT.SIZE_TEST, interpolation=3), 70 | T.ToTensor(), 71 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 72 | ]) 73 | 74 | filename = '0208C6T0044F010' 75 | with open(filename+'.jpg', 'rb') as f: 76 | with Image.open(f) as im: 77 | img = im.convert('RGB') 78 | img = torch.stack([val_transforms(img)], 0).unsqueeze(0).to("cuda") 79 | 80 | model_path = 'transformer_120.pth' 81 | train_loader, val_loader, num_query, num_classes, camera_num, view_num = make_dataloader(cfg) 82 | model = make_model(cfg, num_class=num_classes, camera_num=camera_num, view_num=view_num) 83 | model.load_param(os.path.join(OUTPUT_DIR, model_path)) 84 | model.to("cuda") 85 | model.eval() 86 | attns = model(img, cam_label=[0]) 87 | 88 | attn_base, attn_head = attns 89 | att_base, att_head = torch.stack(attn_base).squeeze(1), torch.stack(attn_head).squeeze(1) 90 | 91 | # Average the attention weights across all heads. 
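# Note on the steps below: this is the standard attention-rollout recipe
# (cf. Abnar & Zuidema, 2020). The per-head maps are averaged, an identity
# matrix is added to account for the residual connections, each row is
# re-normalized, and the per-layer matrices are multiplied together. The CLS
# row of the final product (v[0, 1:] below) then gives a relevance score per
# patch token, which is reshaped to the 21x10 patch grid and upsampled onto
# the input image to form the visualization mask.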
92 | att_base, att_head = torch.mean(att_base, dim=1), torch.mean(att_head, dim=1) 93 | 94 | # To account for residual connections, we add an identity matrix to the 95 | # attention matrix and re-normalize the weights. 96 | residual_att_base, redisual_att_head = \ 97 | torch.eye(att_base.size(1)).to("cuda"), \ 98 | torch.eye(att_head.size(1)).to("cuda") 99 | aug_att_base, aug_att_head = \ 100 | att_base + residual_att_base, \ 101 | att_head + redisual_att_head 102 | aug_att_base, aug_att_head = \ 103 | aug_att_base / aug_att_base.sum(dim=-1).unsqueeze(-1), \ 104 | aug_att_head / aug_att_head.sum(dim=-1).unsqueeze(-1) 105 | 106 | # Recursively multiply the weight matrices 107 | joint_attentions = torch.zeros(aug_att_base.size()).to("cuda") 108 | joint_attentions[0] = aug_att_base[0] 109 | 110 | for n in range(1, aug_att_base.size(0)): 111 | joint_attentions[n] = torch.matmul(aug_att_base[n], joint_attentions[n - 1]) 112 | 113 | # Attention from the att_base output token to the input space. 114 | v = joint_attentions[-1] 115 | img_H, img_W = 21, 10 116 | mask = v[0, 1:].reshape(img_H, img_W).detach().cpu().numpy() 117 | mask = cv2.resize(mask / mask.max(), im.size)[..., np.newaxis] 118 | result = (mask * im).astype("uint8") 119 | 120 | # # Attention for the att_head output token to the input space 121 | # cls_token = joint_attentions[-1][0:1, 0:1 ] 122 | # patch_token = joint_attentions[-1][0:1, 1:].reshape(img_H, img_W) 123 | # num_patch = __num_of_layers[cfg.MODEL.LAYER0_DIVISION_TYPE] 124 | # division_length = patch_token.size(1) // num_patch 125 | # # tokens = [patch_token[:, i*division_length: (i+1)*division_length] for i in range(num_patch)] 126 | # # tokens = [torch.cat((cls_token, i), dim=1) for i in tokens] 127 | # # joint_attentions_head = [torch.matmul(i,j) for i,j in zip(tokens, aug_att_head)] 128 | # 129 | # if cfg.MODEL.PYRAMID0_TYPE == 'patch': 130 | # if num_patch == 6: 131 | # # patch_tokens = [patch_token[m * 7:(m + 1) * 7, n * 5:(n + 1) * 5].reshape(35, 1) 132 | # # for m in range(3) for n in range(2)] 133 | # # tokens = [torch.cat((cls_token, i), dim=0) for i in patch_tokens] 134 | # # joint_attentions_head = [torch.matmul(i, j) for i, j in zip(aug_att_head, tokens)] 135 | # # attns = [i[1:, :].reshape(7, 5) for i in joint_attentions_head] 136 | # # attns = torch.cat([torch.cat( 137 | # # [attns[m*2+n] for n in range(2)], dim=1) 138 | # # for m in range(3)], dim=0) 139 | # # attns = attns.reshape(img_H, img_W).detach().cpu().numpy() 140 | # # mask = cv2.resize(attns / attns.max(), im.size)[..., np.newaxis] 141 | # # result = (mask * im).astype("uint8") 142 | # 143 | # # pos = [[m * 7, (m + 1) * 7, n * 5, (n + 1) * 5] for m in range(3) for n in range(2)] 144 | # # pos_store = [] 145 | # # for p in pos: 146 | # # i,j,m,n = p 147 | # # tmp = [] 148 | # # for x in range(i,j): 149 | # # for y in range(m,n): 150 | # # tmp.append(x*10+y+1) 151 | # # pos_store.append(tmp) 152 | # # tokens = joint_attentions[-1] 153 | # # attentions = [] 154 | # # for p in pos_store: 155 | # # p_ = [0] + p 156 | # # tmp = [tokens[0][p_]] 157 | # # for p_in in p: 158 | # # tmp.append(tokens[p_in][p_]) 159 | # # tmp = torch.stack(tmp, dim=0) 160 | # # attentions.append(tmp) 161 | # # joint_attentions_head = [torch.matmul(i, j) for i, j in zip(attentions, aug_att_head)] 162 | # # attns = [i[0, 1:].reshape(7, 5) for i in joint_attentions_head] 163 | # # attns = torch.cat([torch.cat( 164 | # # [attns[m*2+n] for n in range(2)], dim=1) 165 | # # for m in range(3)], dim=0) 166 | # # attns = 
attns.reshape(img_H, img_W).detach().cpu().numpy() 167 | # 168 | # att_head_vis = att_head[:,0,1:].reshape(num_patch, 7, 5) 169 | # attns = torch.cat([torch.cat( 170 | # [att_head_vis[m*2+n] for n in range(2)], dim=1) 171 | # for m in range(3)], dim=0) 172 | # p = attns.detach().cpu().numpy() 173 | # plt.figure() 174 | # plt.imshow(p) 175 | # plt.show() 176 | # 177 | # 178 | # 179 | # elif num_patch == 14: 180 | # attns = [i[:, :, 0, 1:].reshape(1, -1, 3, 5) for i in attns] 181 | # attns = torch.cat([torch.cat( 182 | # [attns[m * 2 + n] for n in range(2)], dim=3) 183 | # for m in range(7)], dim=2) 184 | # elif num_patch == 15: 185 | # attns = [i[:, :, 0, 1:].reshape(1, -1, 7, 2) for i in attns] 186 | # attns = torch.cat([torch.cat( 187 | # [attns[m * 2 + n] for n in range(3)], dim=3) 188 | # for m in range(5)], dim=2) 189 | # else: 190 | # if cfg.MODEL.PYRAMID0_TYPE == 'horizontal': 191 | # feature = feature.reshape(B, N, -1, self.in_planes) 192 | # elif cfg.MODEL.PYRAMID0_TYPE == 'vertical': 193 | # feature = feature.transpose(-3, -2).reshape(B, N, -1, self.in_planes) 194 | # division_length = (features.size(2) - 1) // self.layer_division_num[i] 195 | # local_feats = [feature[:, :, m * division_length:(m + 1) * division_length] 196 | # for m in range(self.layer_division_num[i])] 197 | 198 | from matplotlib.backends.backend_pdf import PdfPages 199 | plt.figure() 200 | plt.imshow(result) 201 | # plt.show() 202 | plt.savefig(filename + '_result.pdf') 203 | 204 | plt.figure() 205 | plt.imshow(mask) 206 | # plt.show() 207 | plt.savefig(filename + '_mask.pdf') 208 | 209 | plt.figure() 210 | plt.imshow(im) 211 | # plt.show() 212 | plt.savefig(filename + '_im.pdf') 213 | 214 | 215 | 216 | --------------------------------------------------------------------------------