├── LICENSE ├── README.md ├── config ├── __init__.py └── defaults.py ├── configs ├── duke_r101.yml ├── vehicleid_r101.yml └── veri_r101.yml ├── datasets ├── __init__.py ├── base.py ├── base_id.py ├── data_loading.py ├── duke.py ├── init_dataset.py ├── loader.py ├── test_loading.py ├── transform.py ├── vehicleid.py └── veri.py ├── eval.py ├── images ├── affinity_matrix.png └── architecture.png ├── loss ├── __init__.py ├── cross_entropy_loss.py ├── hard_mine_triplet_loss.py └── losses.py ├── main.py ├── model ├── __init__.py ├── lr_schedulers.py ├── models.py ├── optimizers.py ├── resnet.py └── senet.py ├── pkl ├── duke │ └── index.pkl ├── vehicleid │ └── index.pkl └── veri │ ├── cids.pkl │ ├── data.pkl │ └── index_vp.pkl ├── requirements.txt ├── train.py └── utils ├── avgmeter.py ├── create_gms_index.py ├── evaluation.py ├── functions.py ├── generaltools.py ├── iotools.py ├── kwargs.py ├── loggers.py ├── mean_and_std.py ├── reranking.py ├── torchtools.py └── visualtools.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Adhiraj Ghosh 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 |

Relation Preserving Triplet Mining for Stabilising the Triplet Loss in Re-identification Systems

3 |

WACV 2023

4 | 5 | Adhiraj Ghosh1,2, Kuruparan Shanmugalingam1,3, Wen-Yan Lin1 6 | 7 | 1Singapore Management University 2University of Tübingen 3University of New South Wales 8 | 9 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/relation-preserving-triplet-mining-for/vehicle-re-identification-on-veri-776)](https://paperswithcode.com/sota/vehicle-re-identification-on-veri-776?p=relation-preserving-triplet-mining-for) 10 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/relation-preserving-triplet-mining-for/vehicle-re-identification-on-vehicleid-small)](https://paperswithcode.com/sota/vehicle-re-identification-on-vehicleid-small?p=relation-preserving-triplet-mining-for) 11 | 12 | PyTorch 13 | 14 | [[Paper](https://openaccess.thecvf.com/content/WACV2023/html/Ghosh_Relation_Preserving_Triplet_Mining_for_Stabilising_the_Triplet_Loss_In_WACV_2023_paper.html)] 15 | [[Video](https://youtu.be/TseV_Hoz2Ms?si=VlAReJ2eETPmYKh1)] 16 | 17 | The *official* repository for **Relation Preserving Triplet Mining for Stabilising the Triplet Loss in Re-identification Systems**. Our work achieves state-of-the-art results and provides a faster, more optimised and more generalisable model for re-identification. 18 |
19 | 20 | ## Network Architecture 21 | ![Architecture](images/architecture.png) 22 | 23 | ## Preparation 24 | 25 | ### Installation 26 | 27 | 1. Install a CUDA-compatible PyTorch build. Modify the command below to match your CUDA version. 28 | ``` 29 | conda install pytorch torchvision torchaudio pytorch-cuda=11.6 -c pytorch -c nvidia 30 | ``` 31 | 2. Install the other dependencies. 32 | ```bash 33 | pip install -r requirements.txt 34 | ``` 35 | 36 | 3. Install apex (optional but recommended). 37 | 38 | Follow the installation guidelines from https://github.com/NVIDIA/apex 39 | Then set SOLVER.USE_AMP to True, either directly in the config files or via the command line. 40 | ### Prepare Datasets 41 | 42 | ```bash 43 | mkdir data 44 | ``` 45 | 46 | Download the vehicle reID datasets [VehicleID](https://www.pkuml.org/resources/pku-vehicleid.html) and [VeRi-776](https://github.com/JDAI-CV/VeRidataset), and the person reID dataset [DukeMTMC-reID](https://arxiv.org/abs/1609.01775). 47 | Follow the structure and naming convention shown below. 48 | 49 | ``` 50 | data 51 | ├── duke 52 | │   └── images .. 53 | ├── vehicleid 54 | │   └── images .. 55 | └── veri 56 | └── images .. 57 | ``` 58 | 59 | ### Prepare GMS Feature Matches 60 | ```bash 61 | mkdir gms 62 | ``` 63 | 64 | Download the GMS feature matches for VeRi, VehicleID and DukeMTMC: [GMS](https://drive.google.com/drive/folders/1hdk3pi4Bi_Tb2B7XcBmvwG91Sfisi6BO?usp=share_link). 65 | 66 | The folder should follow the structure shown below: 67 | ``` 68 | gms 69 | ├── duke 70 | │   └── 0001.pkl .. 71 | ├── vehicleid 72 | │   └── 00001.pkl .. 73 | └── veri 74 | └── 001.pkl .. 75 | ``` 76 | 77 | You can also create your own GMS matches for VeRi-776, VeRi-Wild and VehicleID by running ```utils/create_gms_index.py```. Choose which dataset to build GMS matches for by editing the initial parameters inside the script. 78 | 79 | ## Running RPTM 80 | 1. Training 81 | ```bash 82 | python main.py --config_file configs/veri_r101.yml 83 | ``` 84 | The above command trains a baseline on VeRi using our RPTM algorithm. Note that after training, the model reports both qualitative and quantitative evaluation results. 85 | 86 | 2. RPTM Thresholding Strategies 87 | 88 | In Section 4.2 of our paper, we define a thresholding strategy for better anchor-positive selection. We expose this in the config files as MODEL.RPTM_SELECT. It defaults to 'mean'; feel free to experiment with 'min' and 'max'. 89 | 90 | #### Min Thresholding 91 | ```bash 92 | python main.py --config_file configs/veri_r101.yml MODEL.RPTM_SELECT 'min' 93 | ``` 94 | 95 | #### Max Thresholding 96 | ```bash 97 | python main.py --config_file configs/veri_r101.yml MODEL.RPTM_SELECT 'max' 98 | ``` 99 | 100 | 3. Testing (set TEST.WEIGHT to the path of your trained checkpoint) 101 | ```bash 102 | mkdir logs 103 | python main.py --config_file configs/veri_r101.yml TEST.WEIGHT '' TEST.EVAL True 104 | ``` 105 | 106 | ## Mean Average Precision (mAP) Results 107 | 1. VeRi776: **88.0%** 108 | 2. VehicleID (query size 800): **84.8%** 109 | 3. VehicleID (query size 1600): **81.2%** 110 | 4. VehicleID (query size 2400): **80.5%** 111 | 5. DukeMTMC: **89.2%**
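To reproduce numbers like these with your own trained model, point TEST.WEIGHT at its checkpoint. A minimal sketch (the checkpoint filename below is a hypothetical placeholder, not a file shipped with this repository):

```bash
python main.py --config_file configs/veri_r101.yml TEST.WEIGHT './logs/veri/model_best.pth.tar' TEST.EVAL True
```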
112 | 113 | ## Acknowledgement 114 | 115 | GMS Feature Matching Algorithm taken from: https://github.com/JiawangBian/GMS-Feature-Matcher 116 | 117 | ## Citation 118 | 119 | If you find this code useful for your research, please cite our paper: 120 | 121 | ``` 122 | @InProceedings{Ghosh_2023_WACV, 123 | author = {Ghosh, Adhiraj and Shanmugalingam, Kuruparan and Lin, Wen-Yan}, 124 | title = {Relation Preserving Triplet Mining for Stabilising the Triplet Loss In re-Identification Systems}, 125 | booktitle = {Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision (WACV)}, 126 | month = {January}, 127 | year = {2023}, 128 | pages = {4840-4849} 129 | } 130 | ``` 131 | 132 | ## Contact 133 | 134 | If you have any questions, please feel free to contact us. E-mail: [Adhiraj Ghosh](mailto:adhirajghosh1998@gmail.com), [Wen-Yan Lin](mailto:daniellin@smu.edu.sg) 135 | -------------------------------------------------------------------------------- /config/__init__.py: -------------------------------------------------------------------------------- 1 | from .defaults import _C as cfg 2 | from .defaults import _C as cfg_test -------------------------------------------------------------------------------- /config/defaults.py: -------------------------------------------------------------------------------- 1 | from yacs.config import CfgNode as CN 2 | 3 | # ----------------------------------------------------------------------------- 4 | # Config definition 5 | # ----------------------------------------------------------------------------- 6 | 7 | _C = CN() 8 | 9 | # ----------------------------------------------------------------------------- 10 | # MODEL 11 | # ----------------------------------------------------------------------------- 12 | _C.MODEL = CN() 13 | _C.MODEL.DEVICE = "cuda" 14 | _C.MODEL.PRETRAIN_CHOICE= 'imagenet' 15 | _C.MODEL.PRETRAIN_PATH= '' 16 | _C.MODEL.ARCH= 'SE_net' 17 | _C.MODEL.DROPRATE= 0 18 | _C.MODEL.STRIDE= 1 19 | _C.MODEL.POOL= 'avg' 20 | _C.MODEL.GPU_ID= ('0') 21 | _C.MODEL.RPTM_SELECT= 'mean' 22 | 23 | # ---------------------------------------------------------------------------- # 24 | # Input options 25 | # ---------------------------------------------------------------------------- # 26 | _C.INPUT = CN() 27 | _C.INPUT.HEIGHT= 128 28 | _C.INPUT.WIDTH= 128 29 | _C.INPUT.PROB = 0.5 30 | _C.INPUT.RANDOM_ERASE = True 31 | _C.INPUT.JITTER= True 32 | _C.INPUT.AUG= True 33 | 34 | # ---------------------------------------------------------------------------- # 35 | # Dataset options 36 | # ---------------------------------------------------------------------------- # 37 | 38 | _C.DATASET = CN() 39 | _C.DATASET.SOURCE_NAME= ['veri'] 40 | _C.DATASET.TARGET_NAME= ['veri'] 41 | _C.DATASET.ROOT_DIR= '' 42 | _C.DATASET.TRAIN_DIR= '' 43 | _C.DATASET.SPLIT_DIR= '' 44 | 45 | # ---------------------------------------------------------------------------- # 46 | # Dataloader options 47 | # ---------------------------------------------------------------------------- # 48 | _C.DATALOADER = CN() 49 | _C.DATALOADER.SAMPLER= 'RandomSampler' 50 | _C.DATALOADER.NUM_INSTANCE= 6 51 | _C.DATALOADER.NUM_WORKERS= 16 52 | 53 | # ---------------------------------------------------------------------------- # 54 | # Solver options 55 | # ---------------------------------------------------------------------------- # 56 | _C.SOLVER = CN() 57 | _C.SOLVER.OPTIMIZER_NAME= 'SGD' 58 | _C.SOLVER.MAX_EPOCHS= 80 59 | _C.SOLVER.BASE_LR= 0.005 60 | _C.SOLVER.LR_SCHEDULER= 
'multi-step' 61 | _C.SOLVER.STEPSIZE= [20, 40, 60] 62 | _C.SOLVER.GAMMA= 0.1 63 | _C.SOLVER.WEIGHT_DECAY= 5e-4 64 | _C.SOLVER.MOMENTUM= 0.9 65 | _C.SOLVER.SGD_DAMP= 0.0 66 | _C.SOLVER.NESTEROV= True 67 | _C.SOLVER.WARMUP_FACTOR= 0.01 68 | _C.SOLVER.WARMUP_EPOCHS= 10 69 | _C.SOLVER.WARMUP_METHOD= 'linear' 70 | _C.SOLVER.LARGE_FC_LR= False 71 | _C.SOLVER.TRAIN_BATCH_SIZE= 20 72 | _C.SOLVER.USE_AMP= True 73 | _C.SOLVER.CHECKPOINT_PERIOD= 10 74 | _C.SOLVER.LOG_PERIOD= 50 75 | _C.SOLVER.EVAL_PERIOD= 1 76 | 77 | # ---------------------------------------------------------------------------- # 78 | # Loss options 79 | # ---------------------------------------------------------------------------- # 80 | _C.LOSS = CN() 81 | _C.LOSS.MARGIN= 1.0 82 | _C.LOSS.LAMBDA_HTRI= 1.0 83 | _C.LOSS.LAMBDA_XENT= 1.0 84 | 85 | # ---------------------------------------------------------------------------- # 86 | # Test options 87 | # ---------------------------------------------------------------------------- # 88 | _C.TEST = CN() 89 | _C.TEST.EVAL= True 90 | _C.TEST.TEST_BATCH_SIZE= 100 91 | _C.TEST.TEST_SIZE = 11579 92 | _C.TEST.RE_RANKING= True 93 | _C.TEST.VIS_RANK= True 94 | _C.TEST.WEIGHT= '' 95 | _C.TEST.NECK_FEAT= 'after' 96 | _C.TEST.FEAT_NORM= 'yes' 97 | 98 | # ---------------------------------------------------------------------------- # 99 | # Misc options 100 | # ---------------------------------------------------------------------------- # 101 | _C.MISC = CN() 102 | _C.MISC.SAVE_DIR= './logs/veri/' 103 | _C.MISC.GMS_PATH= './gms/veri/' 104 | _C.MISC.INDEX_PATH= './pkl/veri/index_vp.pkl' 105 | _C.MISC.USE_GPU= True 106 | _C.MISC.PRINT_FREQ= 100 107 | _C.MISC.FP16= True 108 | -------------------------------------------------------------------------------- /configs/duke_r101.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | ARCH: 'resnet101_ibn_a' 4 | DROPRATE: 0 5 | STRIDE: 1 6 | POOL: 'avg' 7 | GPU_ID: ('0') 8 | RPTM_SELECT: 'mean' 9 | 10 | INPUT: 11 | HEIGHT: 300 12 | WIDTH: 150 13 | PROB: 0.5 # random horizontal flip 14 | RANDOM_ERASE: True 15 | JITTER: True 16 | AUG: True 17 | 18 | 19 | DATASET: 20 | SOURCE_NAME: ['duke'] 21 | TARGET_NAME: ['duke'] 22 | ROOT_DIR: './data/' 23 | TRAIN_DIR: './data/duke/image_train/' 24 | SPLIT_DIR: './data/duke/train_split/' 25 | 26 | DATALOADER: 27 | SAMPLER: 'RandomSampler' 28 | NUM_INSTANCE: 6 29 | NUM_WORKERS: 16 30 | 31 | SOLVER: 32 | OPTIMIZER_NAME: 'sgd' 33 | MAX_EPOCHS: 80 34 | BASE_LR: 0.005 35 | LR_SCHEDULER: 'multi-step' 36 | STEPSIZE: [20,40,60] 37 | GAMMA: 0.1 38 | WEIGHT_DECAY: 5e-4 39 | MOMENTUM: 0.9 40 | SGD_DAMP: 0.0 41 | NESTEROV: True 42 | WARMUP_FACTOR: 0.01 43 | WARMUP_EPOCHS: 10 44 | WARMUP_METHOD: 'linear' 45 | LARGE_FC_LR: False 46 | TRAIN_BATCH_SIZE: 20 47 | USE_AMP: False 48 | CHECKPOINT_PERIOD: 10 49 | LOG_PERIOD: 50 50 | EVAL_PERIOD: 1 51 | 52 | LOSS: 53 | MARGIN: 1.0 54 | LAMBDA_HTRI: 1.0 55 | LAMBDA_XENT: 1.0 56 | 57 | TEST: 58 | EVAL: False 59 | WEIGHT: '' 60 | TEST_BATCH_SIZE: 100 61 | RE_RANKING: True 62 | VIS_RANK: True 63 | NECK_FEAT: 'after' 64 | FEAT_NORM: 'yes' 65 | 66 | MISC: 67 | SAVE_DIR: './logs/duke/' 68 | GMS_PATH: './gms/duke/' 69 | INDEX_PATH: './pkl/duke/index.pkl' 70 | USE_GPU: True 71 | PRINT_FREQ: 100 72 | FP16: True 73 | 74 | 75 | 76 | 77 | -------------------------------------------------------------------------------- /configs/vehicleid_r101.yml: -------------------------------------------------------------------------------- 1 | 
MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | ARCH: 'resnet101_ibn_a' 4 | DROPRATE: 0 5 | STRIDE: 1 6 | POOL: 'avg' 7 | GPU_ID: ('0') 8 | RPTM_SELECT: 'mean' 9 | 10 | INPUT: 11 | HEIGHT: 128 12 | WIDTH: 128 13 | PROB: 0.5 # random horizontal flip 14 | RANDOM_ERASE: True 15 | JITTER: True 16 | AUG: True 17 | 18 | 19 | DATASET: 20 | SOURCE_NAME: ['vehicleid'] 21 | TARGET_NAME: ['vehicleid'] 22 | ROOT_DIR: './data/' 23 | TRAIN_DIR: './data/vehicleid/image_train/' 24 | SPLIT_DIR: './data/vehicleid/train_split/' 25 | 26 | DATALOADER: 27 | SAMPLER: 'RandomSampler' 28 | NUM_INSTANCE: 6 29 | NUM_WORKERS: 16 30 | 31 | SOLVER: 32 | OPTIMIZER_NAME: 'sgd' 33 | MAX_EPOCHS: 40 34 | BASE_LR: 0.005 35 | LR_SCHEDULER: 'multi-step' 36 | STEPSIZE: [10, 20, 30] 37 | GAMMA: 0.1 38 | WEIGHT_DECAY: 5e-4 39 | MOMENTUM: 0.9 40 | SGD_DAMP: 0.0 41 | NESTEROV: True 42 | WARMUP_FACTOR: 0.01 43 | WARMUP_EPOCHS: 10 44 | WARMUP_METHOD: 'linear' 45 | LARGE_FC_LR: False 46 | TRAIN_BATCH_SIZE: 20 47 | USE_AMP: False 48 | CHECKPOINT_PERIOD: 10 49 | LOG_PERIOD: 50 50 | EVAL_PERIOD: 1 51 | 52 | LOSS: 53 | MARGIN: 1.0 54 | LAMBDA_HTRI: 1.0 55 | LAMBDA_XENT: 1.0 56 | 57 | TEST: 58 | EVAL: True 59 | WEIGHT: '' 60 | TEST_BATCH_SIZE: 100 61 | TEST_SIZE: 800 62 | RE_RANKING: True 63 | VIS_RANK: True 64 | NECK_FEAT: 'after' 65 | FEAT_NORM: 'yes' 66 | 67 | MISC: 68 | SAVE_DIR: './logs/vehicleid/' 69 | GMS_PATH: './gms/vehicleid/' 70 | INDEX_PATH: './pkl/vehicleid/index.pkl' 71 | USE_GPU: True 72 | PRINT_FREQ: 100 73 | FP16: True 74 | 75 | 76 | 77 | 78 | -------------------------------------------------------------------------------- /configs/veri_r101.yml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PRETRAIN_CHOICE: 'imagenet' 3 | ARCH: 'resnet101_ibn_a' 4 | DROPRATE: 0 5 | STRIDE: 1 6 | POOL: 'avg' 7 | GPU_ID: ('0') 8 | RPTM_SELECT: 'mean' 9 | 10 | INPUT: 11 | HEIGHT: 128 12 | WIDTH: 128 13 | PROB: 0.5 # random horizontal flip 14 | RANDOM_ERASE: True 15 | JITTER: True 16 | AUG: True 17 | 18 | 19 | DATASET: 20 | SOURCE_NAME: ['veri'] 21 | TARGET_NAME: ['veri'] 22 | ROOT_DIR: './data/' 23 | TRAIN_DIR: './data/veri/image_train/' 24 | SPLIT_DIR: './data/veri/train_split/' 25 | 26 | DATALOADER: 27 | SAMPLER: 'RandomSampler' 28 | NUM_INSTANCE: 6 29 | NUM_WORKERS: 16 30 | 31 | SOLVER: 32 | OPTIMIZER_NAME: 'sgd' 33 | MAX_EPOCHS: 80 34 | BASE_LR: 0.005 35 | LR_SCHEDULER: 'multi-step' 36 | STEPSIZE: [20,40,60] 37 | GAMMA: 0.1 38 | WEIGHT_DECAY: 5e-4 39 | MOMENTUM: 0.9 40 | SGD_DAMP: 0.0 41 | NESTEROV: True 42 | WARMUP_FACTOR: 0.01 43 | WARMUP_EPOCHS: 10 44 | WARMUP_METHOD: 'linear' 45 | LARGE_FC_LR: False 46 | TRAIN_BATCH_SIZE: 20 47 | USE_AMP: False 48 | CHECKPOINT_PERIOD: 10 49 | LOG_PERIOD: 50 50 | EVAL_PERIOD: 1 51 | 52 | LOSS: 53 | MARGIN: 1.0 54 | LAMBDA_HTRI: 1.0 55 | LAMBDA_XENT: 1.0 56 | 57 | TEST: 58 | EVAL: True 59 | WEIGHT: '' 60 | TEST_BATCH_SIZE: 100 61 | RE_RANKING: True 62 | VIS_RANK: True 63 | NECK_FEAT: 'after' 64 | FEAT_NORM: 'yes' 65 | 66 | MISC: 67 | SAVE_DIR: './logs/veri/' 68 | GMS_PATH: './gms/veri/' 69 | INDEX_PATH: './pkl/veri/index_vp.pkl' 70 | USE_GPU: True 71 | PRINT_FREQ: 100 72 | FP16: True 73 | 74 | 75 | 76 | 77 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | 6 | from .loader import * 7 | # from 
.data_loading import * 8 | # from .test_loading import * 9 | # from .transform import * 10 | 11 | -------------------------------------------------------------------------------- /datasets/base.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | 4 | import os.path as osp 5 | 6 | 7 | class BaseDataset(object): 8 | """ 9 | Base class of reid dataset 10 | """ 11 | 12 | def __init__(self, root): 13 | self.root = osp.expanduser(root) 14 | 15 | def get_imagedata_info(self, data): 16 | pids, cams = [], [] 17 | for _, pid, camid in data: 18 | pids += [pid] 19 | cams += [camid] 20 | pids = set(pids) 21 | cams = set(cams) 22 | num_pids = len(pids) 23 | num_cams = len(cams) 24 | num_imgs = len(data) 25 | return num_pids, num_imgs, num_cams 26 | 27 | def print_dataset_statistics(self): 28 | raise NotImplementedError 29 | 30 | 31 | class BaseImageDataset(BaseDataset): 32 | """ 33 | Base class of image reid dataset 34 | """ 35 | 36 | def print_dataset_statistics(self, train, query, gallery): 37 | num_train_pids, num_train_imgs, num_train_cams = self.get_imagedata_info(train) 38 | #num_val_pids, num_val_imgs, num_val_cams = self.get_imagedata_info(val) 39 | num_query_pids, num_query_imgs, num_query_cams = self.get_imagedata_info(query) 40 | num_gallery_pids, num_gallery_imgs, num_gallery_cams = self.get_imagedata_info(gallery) 41 | 42 | print('Image Dataset statistics:') 43 | print(' ----------------------------------------') 44 | print(' subset | # ids | # images | # cameras') 45 | print(' ----------------------------------------') 46 | print(' train | {:5d} | {:8d} | {:9d}'.format(num_train_pids, num_train_imgs, num_train_cams)) 47 | #print(' val | {:5d} | {:8d} | {:9d}'.format(num_val_pids, num_val_imgs, num_val_cams)) 48 | print(' query | {:5d} | {:8d} | {:9d}'.format(num_query_pids, num_query_imgs, num_query_cams)) 49 | print(' gallery | {:5d} | {:8d} | {:9d}'.format(num_gallery_pids, num_gallery_imgs, num_gallery_cams)) 50 | print(' ----------------------------------------') 51 | -------------------------------------------------------------------------------- /datasets/base_id.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | 4 | import os.path as osp 5 | 6 | 7 | class BaseDataset(object): 8 | """ 9 | Base class of reid dataset 10 | """ 11 | 12 | def __init__(self, root): 13 | self.root = osp.expanduser(root) 14 | 15 | def get_imagedata_info(self, data): 16 | pids = [] 17 | for _, pid in data: 18 | pids += [pid] 19 | 20 | pids = set(pids) 21 | num_pids = len(pids) 22 | num_imgs = len(data) 23 | return num_pids, num_imgs 24 | 25 | def print_dataset_statistics(self): 26 | raise NotImplementedError 27 | 28 | 29 | class BaseImageDataset(BaseDataset): 30 | """ 31 | Base class of image reid dataset 32 | """ 33 | 34 | def print_dataset_statistics(self, train, query, gallery): 35 | num_train_pids, num_train_imgs = self.get_imagedata_info(train) 36 | #num_val_pids, num_val_imgs, num_val_cams = self.get_imagedata_info(val) 37 | num_query_pids, num_query_imgs = self.get_imagedata_info(query) 38 | num_gallery_pids, num_gallery_imgs = self.get_imagedata_info(gallery) 39 | 40 | print('Image Dataset statistics:') 41 | print(' ----------------------------') 42 | print(' subset | # ids | # images ') 43 | print(' ----------------------------') 44 | print(' train | {:5d} | 
{:8d} '.format(num_train_pids, num_train_imgs)) 45 | #print(' val | {:5d} | {:8d} | {:9d}'.format(num_val_pids, num_val_imgs, num_val_cams)) 46 | print(' query | {:5d} | {:8d} '.format(num_query_pids, num_query_imgs)) 47 | print(' gallery | {:5d} | {:8d} '.format(num_gallery_pids, num_gallery_imgs)) 48 | print(' ----------------------------') 49 | -------------------------------------------------------------------------------- /datasets/data_loading.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import torch 3 | import os 4 | import sys 5 | import time 6 | import datetime 7 | import os.path as osp 8 | import numpy as np 9 | import warnings 10 | from PIL import Image 11 | from skimage import io, transform 12 | from torch.utils.data import Dataset 13 | 14 | class VeriDataset(Dataset): 15 | """Veri dataset.""" 16 | 17 | def __init__(self, pkl_file, dataset, root_dir, transform=None): 18 | 19 | with open(pkl_file, 'rb') as handle: 20 | c = pickle.load(handle) 21 | self.index = c 22 | self.root_dir = root_dir 23 | self.dataset = dataset 24 | self.transform = transform 25 | 26 | def __len__(self): 27 | return len(self.dataset) 28 | 29 | def __getitem__(self, idx): 30 | if torch.is_tensor(idx): 31 | idx = idx.tolist() 32 | 33 | img_name = os.path.join(self.root_dir, 34 | self.dataset[idx][0]) 35 | img = Image.open(os.path.join(self.root_dir, img_name[-24:])).convert('RGB') 36 | label = self.dataset[idx][1] 37 | pid = self.dataset[idx][2] 38 | cid = self.dataset[idx][3] 39 | if self.dataset[idx][0] not in self.index: 40 | index = 0 41 | else: 42 | index = self.index[self.dataset[idx][0]][1] 43 | 44 | if self.transform: 45 | img = self.transform(img) 46 | 47 | return img,label,index,pid, cid 48 | 49 | class IdDataset(Dataset): 50 | """VehicleId dataset.""" 51 | 52 | def __init__(self, pkl_file, dataset, root_dir, transform=None): 53 | 54 | with open(pkl_file, 'rb') as handle: 55 | c = pickle.load(handle) 56 | self.index = c 57 | self.root_dir = root_dir 58 | self.dataset = dataset 59 | self.transform = transform 60 | 61 | def __len__(self): 62 | return len(self.dataset) 63 | 64 | def __getitem__(self, idx): 65 | if torch.is_tensor(idx): 66 | idx = idx.tolist() 67 | 68 | img_name = os.path.join(self.root_dir, 69 | self.dataset[idx][0]) 70 | img = Image.open(os.path.join(self.root_dir, img_name[-17:])).convert('RGB') 71 | label = self.dataset[idx][1] 72 | pid = self.dataset[idx][2] 73 | cid = self.dataset[idx][3] 74 | index = self.index[self.dataset[idx][0]][1] 75 | 76 | 77 | if self.transform: 78 | img = self.transform(img) 79 | 80 | return img,label,index,pid, cid 81 | 82 | class DukeDataset(Dataset): 83 | """Duke dataset.""" 84 | 85 | def __init__(self, pkl_file, dataset, root_dir, transform=None): 86 | 87 | with open(pkl_file, 'rb') as handle: 88 | c = pickle.load(handle) 89 | self.index = c 90 | self.root_dir = root_dir 91 | self.dataset = dataset 92 | self.transform = transform 93 | 94 | def __len__(self): 95 | return len(self.dataset) 96 | 97 | def __getitem__(self, idx): 98 | if torch.is_tensor(idx): 99 | idx = idx.tolist() 100 | 101 | img_name = os.path.join(self.root_dir, 102 | self.dataset[idx][0]) 103 | img = Image.open(os.path.join(self.root_dir, img_name[-20:])).convert('RGB') 104 | label = self.dataset[idx][1] 105 | pid = self.dataset[idx][2] 106 | cid = self.dataset[idx][3] 107 | index = self.index[self.dataset[idx][0]][1] 108 | 109 | 110 | if self.transform: 111 | img = self.transform(img) 112 | 113 | return 
img, label, index, pid, cid 114 | 115 | 116 | def read_image(img_path): 117 | """Keep reading the image until it succeeds. 118 | This avoids IOError incurred by a heavy IO process.""" 119 | got_img = False 120 | if not osp.exists(img_path): 121 | raise IOError('{} does not exist'.format(img_path)) 122 | while not got_img: 123 | try: 124 | img = Image.open(img_path).convert('RGB') 125 | got_img = True 126 | except IOError: 127 | print('IOError incurred when reading "{}". Will redo. Don\'t worry. Just chill.'.format(img_path)) 128 | pass 129 | return img 130 | 131 | 132 | class ImageDataset(Dataset): 133 | """Image Person ReID Dataset""" 134 | 135 | def __init__(self, dataset, transform=None): 136 | self.dataset = dataset 137 | self.transform = transform 138 | 139 | def __len__(self): 140 | return len(self.dataset) 141 | 142 | def __getitem__(self, index): 143 | img_path, pid, camid = self.dataset[index] 144 | img = read_image(img_path) 145 | 146 | if self.transform is not None: 147 | img = self.transform(img) 148 | 149 | return img, pid, camid, img_path 150 | 151 | class IdImageDataset(Dataset): 152 | """Image ReID dataset without camera labels (VehicleID)""" 153 | 154 | def __init__(self, dataset, transform=None): 155 | self.dataset = dataset 156 | self.transform = transform 157 | 158 | def __len__(self): 159 | return len(self.dataset) 160 | 161 | def __getitem__(self, index): 162 | img_path, pid = self.dataset[index] 163 | img = read_image(img_path) 164 | 165 | if self.transform is not None: 166 | img = self.transform(img) 167 | 168 | return img, pid, 0, img_path # dummy camid of 0; VehicleID carries no camera labels 169 | -------------------------------------------------------------------------------- /datasets/duke.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import glob 6 | import re 7 | import os.path as osp 8 | 9 | from .base import BaseImageDataset 10 | 11 | 12 | class duke(BaseImageDataset): 13 | 14 | dataset_dir = 'duke' 15 | 16 | def __init__(self, root='datasets', dataset_dir = 'duke', verbose=True, **kwargs): 17 | super(duke, self).__init__(root) 18 | self.dataset_dir = osp.join(self.root, self.dataset_dir) 19 | self.train_dir = osp.join(self.dataset_dir, 'image_train') 20 | #self.val_dir = osp.join(self.dataset_dir, 'image_val') 21 | self.query_dir = osp.join(self.dataset_dir, 'image_query') 22 | self.gallery_dir = osp.join(self.dataset_dir, 'image_test') 23 | 24 | self.check_before_run() 25 | 26 | train = self.process_dir(self.train_dir, relabel=True) 27 | #val = self.process_dir(self.val_dir, relabel=True) 28 | query = self.process_dir(self.query_dir, relabel=False) 29 | gallery = self.process_dir(self.gallery_dir, relabel=False) 30 | 31 | if verbose: 32 | print('=> Duke loaded') 33 | self.print_dataset_statistics(train, query, gallery) 34 | 35 | self.train = train 36 | #self.val = val 37 | self.query = query 38 | self.gallery = gallery 39 | 40 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 41 | #self.num_val_pids, self.num_val_imgs, self.num_val_cams = self.get_imagedata_info(self.val) 42 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 43 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 44 | 
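# Editor's note (illustrative, not from the original source): process_dir below
# parses DukeMTMC-reID file names such as '0005_c2_f0046985.jpg' (a hypothetical
# example) with the pattern r'([-\d]+)_c([-\d]+)', yielding pid=5 and camid=2;
# camid is then shifted to be zero-based, and pids are remapped to contiguous
# labels when relabel=True.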
45 | def check_before_run(self): 46 | """Check if all files are available before going deeper""" 47 | if not osp.exists(self.dataset_dir): 48 | raise RuntimeError('"{}" is not available'.format(self.dataset_dir)) 49 | if not osp.exists(self.train_dir): 50 | raise RuntimeError('"{}" is not available'.format(self.train_dir)) 51 | if not osp.exists(self.query_dir): 52 | raise RuntimeError('"{}" is not available'.format(self.query_dir)) 53 | if not osp.exists(self.gallery_dir): 54 | raise RuntimeError('"{}" is not available'.format(self.gallery_dir)) 55 | 56 | def process_dir(self, dir_path, relabel=False): 57 | img_paths = sorted(glob.glob(osp.join(dir_path, '*.jpg'))) 58 | pattern = re.compile(r'([-\d]+)_c([-\d]+)') 59 | 60 | pid_container = set() 61 | for img_path in img_paths: 62 | pid, _ = map(int, pattern.search(img_path).groups()) 63 | if pid == -1: 64 | continue # junk images are just ignored 65 | pid_container.add(pid) 66 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 67 | 68 | dataset = [] 69 | for img_path in img_paths: 70 | pid, camid = map(int, pattern.search(img_path).groups()) 71 | if pid == -1: 72 | continue # junk images are just ignored 73 | assert 0 <= pid <= 7140 # pid == 0 means background 74 | assert 1 <= camid <= 20 75 | camid -= 1 # index starts from 0 76 | if relabel: 77 | pid = pid2label[pid] 78 | dataset.append((img_path, pid, camid)) 79 | 80 | return dataset 81 | -------------------------------------------------------------------------------- /datasets/init_dataset.py: -------------------------------------------------------------------------------- 1 | from .veri import VeRi 2 | from .vehicleid import VehicleID 3 | from .duke import duke 4 | 5 | 6 | __imgreid_factory = { 7 | 'veri': VeRi, 8 | 'vehicleid': VehicleID, 9 | 'duke': duke, 10 | } 11 | def init_imgreid_dataset(name, **kwargs): 12 | if name not in list(__imgreid_factory.keys()): 13 | raise KeyError('Invalid dataset, got "{}", but expected to be one of {}'.format(name, list(__imgreid_factory.keys()))) 14 | return __imgreid_factory[name](**kwargs) 15 | -------------------------------------------------------------------------------- /datasets/loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | from torch.utils.data import DataLoader 4 | from .init_dataset import init_imgreid_dataset 5 | from .transform import * 6 | from .data_loading import VeriDataset as vd 7 | from .data_loading import IdDataset as idd 8 | from .data_loading import DukeDataset as dd 9 | from .test_loading import ImageDataManager 10 | 11 | def data_loader(cfg, dataset_kwargs, transform_kwargs): 12 | dataset = init_imgreid_dataset(root=cfg.DATASET.ROOT_DIR, name=cfg.DATASET.SOURCE_NAME[0]) 13 | num_train_pids = 0 14 | num_train_cams = 0 15 | train = [] 16 | 17 | for img_path, pid, camid in dataset.train: 18 | # path = img_path[-24:] 19 | path = img_path.split('/', 4)[-1] 20 | if cfg.DATASET.SOURCE_NAME[0] == 'veri': 21 | folder = path.split('_', 1)[0][1:] 22 | else: 23 | folder = path.split('_', 1)[0] 24 | pid += num_train_pids 25 | camid += num_train_cams 26 | train.append((path, folder, pid, camid)) 27 | 28 | num_train_pids += dataset.num_train_pids 29 | class_names = num_train_pids 30 | num_train_cams += dataset.num_train_cams 31 | 32 | pid = 0 33 | pidx = {} 34 | for img_path, pid, camid in dataset.train: 35 | path = img_path.split('/', 4)[-1] 36 | if cfg.DATASET.SOURCE_NAME[0] == 'veri': 37 | folder = path.split('_', 1)[0][1:] 38 | else: 39 | folder = path.split('_', 1)[0] 40 | pidx[folder] = pid 41 | pid += 1 42 | 43 | gms = {} 44 | entries = sorted(os.listdir(cfg.MISC.GMS_PATH)) 
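# Editor's note (explanatory, not in the original file): the loop below builds
# `gms` as a mapping from an identity prefix to its pickled GMS feature-match
# matrix, e.g. gms['001'] is loaded from './gms/veri/001.pkl'; the special file
# 'featureMatrix.pkl' keeps its full stem 'featureMatrix' as the key.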
45 | # print(entries) 46 | for name in entries: 47 | f = open((cfg.MISC.GMS_PATH + name), 'rb') 48 | if name == 'featureMatrix.pkl': 49 | s = name[0:13] 50 | else: 51 | s = name[0:3] 52 | gms[s] = pickle.load(f) 53 | f.close() 54 | 55 | transform_t = train_transforms(**transform_kwargs) 56 | if cfg.DATASET.SOURCE_NAME[0] == 'veri': 57 | data_tfr = vd(pkl_file=cfg.MISC.INDEX_PATH, dataset=train, root_dir=cfg.DATASET.TRAIN_DIR, transform=transform_t) 58 | elif cfg.DATASET.SOURCE_NAME[0] == 'vehicleid': 59 | data_tfr = idd(pkl_file=cfg.MISC.INDEX_PATH, dataset=train, root_dir=cfg.DATASET.TRAIN_DIR, transform=transform_t) 60 | elif cfg.DATASET.SOURCE_NAME[0] == 'duke': 61 | data_tfr = dd(pkl_file=cfg.MISC.INDEX_PATH, dataset=train, root_dir=cfg.DATASET.TRAIN_DIR, transform=transform_t) 62 | trainloader = DataLoader(data_tfr, sampler=None, batch_size=cfg.SOLVER.TRAIN_BATCH_SIZE, shuffle=True, num_workers=cfg.DATALOADER.NUM_WORKERS, 63 | pin_memory=False, drop_last=True) 64 | 65 | print('Initializing test data manager') 66 | dm = ImageDataManager(cfg.MISC.USE_GPU, **dataset_kwargs) 67 | testloader_dict = dm.return_dataloaders() 68 | train_dict = {} 69 | train_dict['class_names'] = class_names 70 | train_dict['num_train_pids'] = num_train_pids 71 | train_dict['gms'] = gms 72 | train_dict['pidx'] = pidx 73 | 74 | 75 | return trainloader, train_dict, data_tfr, testloader_dict, dm 76 | -------------------------------------------------------------------------------- /datasets/test_loading.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | 4 | from torch.utils.data import DataLoader 5 | from torchvision import transforms, utils 6 | from torchvision.transforms import * 7 | from .data_loading import ImageDataset 8 | from .init_dataset import init_imgreid_dataset 9 | from .transform import test_transform 10 | 11 | 12 | class BaseDataManager(object): 13 | 14 | def __init__(self, 15 | use_gpu, 16 | source_names, 17 | target_names, 18 | root='datasets', 19 | height=128, 20 | width=256, 21 | train_batch_size=32, 22 | test_batch_size=100, 23 | workers=4, 24 | train_sampler='', 25 | val_sampler='', 26 | random_erase=False, # use random erasing for data augmentation 27 | color_jitter=False, # randomly change the brightness, contrast and saturation 28 | color_aug=False, # randomly alter the intensities of RGB channels 29 | num_instances=4, # number of instances per identity (for RandomIdentitySampler) 30 | **kwargs 31 | ): 32 | self.use_gpu = use_gpu 33 | self.source_names = source_names 34 | self.target_names = target_names 35 | self.root = root 36 | self.height = height 37 | self.width = width 38 | self.train_batch_size = train_batch_size 39 | self.test_batch_size = test_batch_size 40 | self.workers = workers 41 | self.train_sampler = train_sampler 42 | self.val_sampler = val_sampler 43 | self.random_erase = random_erase 44 | self.color_jitter = color_jitter 45 | self.color_aug = color_aug 46 | self.num_instances = num_instances 47 | 48 | transform_test = test_transform(self.height, self.width) 49 | self.transform_test = transform_test 50 | 51 | 52 | def return_dataloaders(self): 53 | """ 54 | Return testloader dictionary 55 | """ 56 | return self.testloader_dict 57 | 58 | def return_testdataset_by_name(self, name): 59 | """ 60 | Return query and gallery, each containing a list of (img_path, pid, camid). 
61 | """ 62 | return self.testdataset_dict[name]['query'], self.testdataset_dict[name]['gallery'] 63 | 64 | 65 | class ImageDataManager(BaseDataManager): 66 | """ 67 | Vehicle-ReID data manager 68 | """ 69 | def __init__(self, 70 | use_gpu, 71 | source_names, 72 | target_names, 73 | **kwargs 74 | ): 75 | super(ImageDataManager, self).__init__(use_gpu, source_names, target_names, **kwargs) 76 | 77 | print('=> Initializing TEST (target) datasets') 78 | self.testloader_dict = {name: {'query': None, 'gallery': None} for name in target_names} 79 | self.testdataset_dict = {name: {'query': None, 'gallery': None} for name in target_names} 80 | 81 | for name in self.target_names: 82 | dataset = init_imgreid_dataset( 83 | root=self.root, name=name) 84 | 85 | self.testloader_dict[name]['query'] = DataLoader( 86 | ImageDataset(dataset.query, transform=self.transform_test), 87 | batch_size=self.test_batch_size, shuffle=False, num_workers=self.workers, 88 | pin_memory=self.use_gpu, drop_last=False 89 | ) 90 | 91 | self.testloader_dict[name]['gallery'] = DataLoader( 92 | ImageDataset(dataset.gallery, transform=self.transform_test), 93 | batch_size=self.test_batch_size, shuffle=False, num_workers=self.workers, 94 | pin_memory=self.use_gpu, drop_last=False 95 | ) 96 | 97 | self.testdataset_dict[name]['query'] = dataset.query 98 | self.testdataset_dict[name]['gallery'] = dataset.gallery 99 | -------------------------------------------------------------------------------- /datasets/transform.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from PIL import Image 6 | import random 7 | import math 8 | 9 | import torch 10 | from torchvision.transforms import * 11 | 12 | class Random2DTranslation(object): 13 | """ 14 | With a probability, first increase image size to (1 + 1/8), and then perform random crop. 15 | Args: 16 | - height (int): target image height. 17 | - width (int): target image width. 18 | - p (float): probability of performing this transformation. Default: 0.5. 19 | """ 20 | 21 | def __init__(self, height, width, p=0.5, interpolation=Image.BILINEAR): 22 | self.height = height 23 | self.width = width 24 | self.p = p 25 | self.interpolation = interpolation 26 | 27 | def __call__(self, img): 28 | """ 29 | Args: 30 | - img (PIL Image): Image to be cropped. 
31 | """ 32 | if random.uniform(0, 1) > self.p: 33 | return img.resize((self.width, self.height), self.interpolation) 34 | 35 | new_width, new_height = int(round(self.width * 1.125)), int(round(self.height * 1.125)) 36 | resized_img = img.resize((new_width, new_height), self.interpolation) 37 | x_maxrange = new_width - self.width 38 | y_maxrange = new_height - self.height 39 | x1 = int(round(random.uniform(0, x_maxrange))) 40 | y1 = int(round(random.uniform(0, y_maxrange))) 41 | croped_img = resized_img.crop((x1, y1, x1 + self.width, y1 + self.height)) 42 | return croped_img 43 | 44 | 45 | class RandomErasing(object): 46 | 47 | def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=[0.4914, 0.4822, 0.4465]): 48 | self.probability = probability 49 | self.mean = mean 50 | self.sl = sl 51 | self.sh = sh 52 | self.r1 = r1 53 | 54 | def __call__(self, img): 55 | 56 | if random.uniform(0, 1) > self.probability: 57 | return img 58 | 59 | for attempt in range(100): 60 | area = img.size()[1] * img.size()[2] 61 | 62 | target_area = random.uniform(self.sl, self.sh) * area 63 | aspect_ratio = random.uniform(self.r1, 1 / self.r1) 64 | 65 | h = int(round(math.sqrt(target_area * aspect_ratio))) 66 | w = int(round(math.sqrt(target_area / aspect_ratio))) 67 | 68 | if w < img.size()[2] and h < img.size()[1]: 69 | x1 = random.randint(0, img.size()[1] - h) 70 | y1 = random.randint(0, img.size()[2] - w) 71 | if img.size()[0] == 3: 72 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0] 73 | img[1, x1:x1 + h, y1:y1 + w] = self.mean[1] 74 | img[2, x1:x1 + h, y1:y1 + w] = self.mean[2] 75 | else: 76 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0] 77 | return img 78 | 79 | return img 80 | 81 | 82 | class ColorAugmentation(object): 83 | """ 84 | Randomly alter the intensities of RGB channels 85 | Reference: 86 | Krizhevsky et al. ImageNet Classification with Deep ConvolutionalNeural Networks. NIPS 2012. 
87 | """ 88 | 89 | def __init__(self, p=0.5): 90 | self.p = p 91 | self.eig_vec = torch.Tensor([ 92 | [0.4009, 0.7192, -0.5675], 93 | [-0.8140, -0.0045, -0.5808], 94 | [0.4203, -0.6948, -0.5836], 95 | ]) 96 | self.eig_val = torch.Tensor([[0.2175, 0.0188, 0.0045]]) 97 | 98 | def _check_input(self, tensor): 99 | assert tensor.dim() == 3 and tensor.size(0) == 3 100 | 101 | def __call__(self, tensor): 102 | if random.uniform(0, 1) > self.p: 103 | return tensor 104 | alpha = torch.normal(mean=torch.zeros_like(self.eig_val)) * 0.1 105 | quatity = torch.mm(self.eig_val * alpha, self.eig_vec) 106 | tensor = tensor + quatity.view(3, 1, 1) 107 | return tensor 108 | 109 | 110 | def build_transforms(height, 111 | width, 112 | random_erase=False, # use random erasing for data augmentation 113 | color_jitter=False, # randomly change the brightness, contrast and saturation 114 | color_aug=False, # randomly alter the intensities of RGB channels 115 | **kwargs): 116 | # use imagenet mean and std as default 117 | # TODO: compute dataset-specific mean and std 118 | imagenet_mean = [0.485, 0.456, 0.406] 119 | imagenet_std = [0.229, 0.224, 0.225] 120 | normalize = Normalize(mean=imagenet_mean, std=imagenet_std) 121 | 122 | # build train transformations 123 | transform_train = [] 124 | transform_train += [Random2DTranslation(height, width)] 125 | transform_train += [RandomHorizontalFlip()] 126 | if color_jitter: 127 | transform_train += [ColorJitter(brightness=0.2, contrast=0.15, saturation=0, hue=0)] 128 | transform_train += [ToTensor()] 129 | if color_aug: 130 | transform_train += [ColorAugmentation()] 131 | transform_train += [normalize] 132 | if random_erase: 133 | transform_train += [RandomErasing()] 134 | transform_train = Compose(transform_train) 135 | 136 | transform_val = [] 137 | transform_val += [Random2DTranslation(height, width)] 138 | transform_val += [RandomHorizontalFlip()] 139 | if color_jitter: 140 | transform_val += [ColorJitter(brightness=0.2, contrast=0.15, saturation=0, hue=0)] 141 | transform_val += [ToTensor()] 142 | if color_aug: 143 | transform_val += [ColorAugmentation()] 144 | transform_val += [normalize] 145 | if random_erase: 146 | transform_val += [RandomErasing()] 147 | transform_val = Compose(transform_val) 148 | 149 | # build test transformations 150 | transform_test = Compose([ 151 | Resize((height, width)), 152 | ToTensor(), 153 | normalize, 154 | ]) 155 | 156 | return transform_train, transform_val, transform_test 157 | #return transform_train, transform_test 158 | 159 | def train_transforms(height, 160 | width, 161 | random_erase=False, # use random erasing for data augmentation 162 | color_jitter=False, # randomly change the brightness, contrast and saturation 163 | color_aug=False, # randomly alter the intensities of RGB channels 164 | **kwargs): 165 | # use imagenet mean and std as default 166 | # TODO: compute dataset-specific mean and std 167 | imagenet_mean = [0.485, 0.456, 0.406] 168 | imagenet_std = [0.229, 0.224, 0.225] 169 | normalize = Normalize(mean=imagenet_mean, std=imagenet_std) 170 | 171 | # build train transformations 172 | transform_train = [] 173 | transform_train += [Random2DTranslation(height, width)] 174 | transform_train += [RandomHorizontalFlip()] 175 | if color_jitter: 176 | transform_train += [ColorJitter(brightness=0.2, contrast=0.15, saturation=0, hue=0)] 177 | transform_train += [ToTensor()] 178 | if color_aug: 179 | transform_train += [ColorAugmentation()] 180 | transform_train += [normalize] 181 | if random_erase: 182 | transform_train += 
[RandomErasing()] 183 | transform_train = Compose(transform_train) 184 | 185 | 186 | return transform_train 187 | 188 | 189 | def test_transform(height, width): 190 | imagenet_mean = [0.485, 0.456, 0.406] 191 | imagenet_std = [0.229, 0.224, 0.225] 192 | normalize = Normalize(mean=imagenet_mean, std=imagenet_std) 193 | 194 | transform_test = Compose([ 195 | Resize((height, width)), 196 | ToTensor(), 197 | normalize, 198 | ]) 199 | 200 | return transform_test 201 | -------------------------------------------------------------------------------- /datasets/vehicleid.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import random 6 | import os 7 | import os.path as osp 8 | import glob 9 | import re 10 | 11 | from .base import BaseImageDataset 12 | from collections import defaultdict 13 | 14 | 15 | class VehicleID(BaseImageDataset): 16 | """ 17 | VehicleID 18 | 19 | Reference: 20 | @inproceedings{liu2016deep, 21 | title={Deep Relative Distance Learning: Tell the Difference Between Similar Vehicles}, 22 | author={Liu, Hongye and Tian, Yonghong and Wang, Yaowei and Pang, Lu and Huang, Tiejun}, 23 | booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition}, 24 | pages={2167--2175}, 25 | year={2016}} 26 | 27 | Dataset statistics: 28 | # train_list: 13164 vehicles for model training 29 | # test_list_800: 800 vehicles for model testing(small test set in paper 30 | # test_list_1600: 1600 vehicles for model testing(medium test set in paper 31 | # test_list_2400: 2400 vehicles for model testing(large test set in paper 32 | # test_list_3200: 3200 vehicles for model testing 33 | # test_list_6000: 6000 vehicles for model testing 34 | # test_list_13164: 13164 vehicles for model testing 35 | """ 36 | dataset_dir = 'vehicleid' 37 | 38 | def __init__(self, root='datasets', verbose=True, test_size=2400, **kwargs): 39 | super(VehicleID, self).__init__(root) 40 | self.dataset_dir = osp.join(self.root, self.dataset_dir) 41 | self.img_dir = osp.join(self.dataset_dir, 'image') 42 | self.train_dir = osp.join(self.dataset_dir, 'image_train') 43 | self.split_dir = osp.join(self.dataset_dir, 'train_test_split') 44 | self.train_list = osp.join(self.split_dir, 'train_list.txt') 45 | self.test_size = test_size 46 | 47 | if self.test_size == 800: 48 | self.gallery_dir = osp.join(self.dataset_dir, 'image_gallery_800') 49 | self.test_list = osp.join(self.split_dir, 'test_list_800.txt') 50 | elif self.test_size == 1600: 51 | self.gallery_dir = osp.join(self.dataset_dir, 'image_gallery_1600') 52 | self.test_list = osp.join(self.split_dir, 'test_list_1600.txt') 53 | elif self.test_size == 2400: 54 | self.gallery_dir = osp.join(self.dataset_dir, 'image_gallery_2400') 55 | self.test_list = osp.join(self.split_dir, 'test_list_2400.txt') 56 | 57 | print(self.gallery_dir) 58 | 59 | self.check_before_run() 60 | 61 | train = self.process_dir(self.train_dir, relabel=True) 62 | query, gallery = self.process_split(relabel=True) 63 | 64 | self.train = train 65 | self.query = query 66 | self.gallery = gallery 67 | 68 | if verbose: 69 | print('=> VehicleID loaded') 70 | self.print_dataset_statistics(train, query, gallery) 71 | 72 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 73 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 74 | 
self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 75 | 76 | 77 | def check_before_run(self): 78 | """Check if all files are available before going deeper""" 79 | if not osp.exists(self.dataset_dir): 80 | raise RuntimeError('"{}" is not available'.format(self.dataset_dir)) 81 | if not osp.exists(self.train_dir): 82 | raise RuntimeError('"{}" is not available'.format(self.train_dir)) 83 | if self.test_size not in [800, 1600, 2400]: 84 | raise RuntimeError('"{}" is not available'.format(self.test_size)) 85 | if not osp.exists(self.gallery_dir): 86 | raise RuntimeError('"{}" is not available'.format(self.gallery_dir)) 87 | 88 | def get_pid2label(self, pids): 89 | pid_container = set(pids) 90 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 91 | return pid2label 92 | 93 | def parse_img_pids(self, nl_pairs, pid2label=None): 94 | # nl_pairs is a list of (img name, label) pairs 95 | output = [] 96 | for info in nl_pairs: 97 | name = info[0] 98 | pid = info[1] 99 | if pid2label is not None: 100 | pid = pid2label[pid] 101 | camid = 1 # no camid information available; use 1 for all 102 | img_path = osp.join(self.img_dir, name+'.jpg') 103 | output.append((img_path, pid, camid)) 104 | return output 105 | 106 | def process_dir(self, dir_path, relabel=False): 107 | img_paths = sorted(glob.glob(osp.join(dir_path, '*.jpg'))) 108 | #pattern = re.compile(r'([-\d]+)_c([-\d]+)') 109 | 110 | pid_container = set() 111 | for img_path in img_paths: 112 | pid = int(re.search(r'([-\d]+)', img_path).group()) 113 | if pid == -1: 114 | continue # junk images are just ignored 115 | pid_container.add(pid) 116 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 117 | 118 | dataset = [] 119 | for img_path in img_paths: 120 | pid = int(re.search(r'([-\d]+)', img_path).group()) 121 | if pid == -1: 122 | continue # junk images are just ignored 123 | assert 0 <= pid <= 131640 # pid == 0 means background 124 | pid = pid2label[pid] 125 | camid = 1 126 | dataset.append((img_path, pid, camid)) 127 | 128 | return dataset 129 | 130 | def process_split(self, relabel=False): 131 | 132 | test_pid_dict = defaultdict(list) 133 | with open(self.test_list) as f_test: 134 | test_data = f_test.readlines() 135 | for data in test_data: 136 | name, pid = data.split(' ') 137 | test_pid_dict[pid].append([name, pid]) 138 | test_pids = list(test_pid_dict.keys()) 139 | num_test_pids = len(test_pids) 140 | assert num_test_pids == self.test_size, 'There should be {} vehicles for testing,' \ 141 | ' but got {}, please check the data'\ 142 | .format(self.test_size, num_test_pids) 143 | 144 | query_data = [] 145 | gallery_data = [] 146 | 147 | # for each test id, randomly choose one image for the gallery 148 | # and the other ones for the query. 
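# Editor's note (not in the original file): random.choice below is unseeded, so
# the VehicleID query/gallery split differs between runs and evaluation numbers
# can vary slightly; calling random.seed(0) beforehand is one way to make the
# split reproducible.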
149 | for pid in test_pids: 150 | imginfo = test_pid_dict[pid] 151 | sample = random.choice(imginfo) 152 | imginfo.remove(sample) 153 | gallery_data.extend(imginfo) 154 | query_data.append(sample) 155 | 156 | query = self.parse_img_pids(query_data) 157 | gallery = self.parse_img_pids(gallery_data) 158 | return query, gallery 159 | 160 | -------------------------------------------------------------------------------- /datasets/veri.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import glob 6 | import re 7 | import os.path as osp 8 | 9 | from .base import BaseImageDataset 10 | 11 | 12 | class VeRi(BaseImageDataset): 13 | 14 | dataset_dir = 'veri' 15 | 16 | def __init__(self, root='datasets', dataset_dir = 'VeRi', verbose=True, **kwargs): 17 | super(VeRi, self).__init__(root) 18 | self.dataset_dir = osp.join(self.root, self.dataset_dir) 19 | self.train_dir = osp.join(self.dataset_dir, 'image_train') 20 | #self.val_dir = osp.join(self.dataset_dir, 'image_val') 21 | self.query_dir = osp.join(self.dataset_dir, 'image_query') 22 | self.gallery_dir = osp.join(self.dataset_dir, 'image_test') 23 | 24 | self.check_before_run() 25 | 26 | train = self.process_dir(self.train_dir, relabel=True) 27 | #val = self.process_dir(self.val_dir, relabel=True) 28 | query = self.process_dir(self.query_dir, relabel=False) 29 | gallery = self.process_dir(self.gallery_dir, relabel=False) 30 | 31 | if verbose: 32 | print('=> VeRi loaded') 33 | self.print_dataset_statistics(train, query, gallery) 34 | 35 | self.train = train 36 | #self.val = val 37 | self.query = query 38 | self.gallery = gallery 39 | 40 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) 41 | #self.num_val_pids, self.num_val_imgs, self.num_val_cams = self.get_imagedata_info(self.val) 42 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) 43 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) 44 | 45 | def check_before_run(self): 46 | """Check if all files are available before going deeper""" 47 | if not osp.exists(self.dataset_dir): 48 | raise RuntimeError('"{}" is not available'.format(self.dataset_dir)) 49 | if not osp.exists(self.train_dir): 50 | raise RuntimeError('"{}" is not available'.format(self.train_dir)) 51 | if not osp.exists(self.query_dir): 52 | raise RuntimeError('"{}" is not available'.format(self.query_dir)) 53 | if not osp.exists(self.gallery_dir): 54 | raise RuntimeError('"{}" is not available'.format(self.gallery_dir)) 55 | 56 | def process_dir(self, dir_path, relabel=False): 57 | img_paths = sorted(glob.glob(osp.join(dir_path, '*.jpg'))) 58 | pattern = re.compile(r'([-\d]+)_c([-\d]+)') 59 | 60 | pid_container = set() 61 | for img_path in img_paths: 62 | pid, _ = map(int, pattern.search(img_path).groups()) 63 | if pid == -1: 64 | continue # junk images are just ignored 65 | pid_container.add(pid) 66 | pid2label = {pid: label for label, pid in enumerate(pid_container)} 67 | 68 | dataset = [] 69 | for img_path in img_paths: 70 | pid, camid = map(int, pattern.search(img_path).groups()) 71 | if pid == -1: 72 | continue # junk images are just ignored 73 | assert 0 <= pid <= 1501 # pid == 0 means background 74 | assert 1 <= camid <= 20 75 | camid -= 1 # index starts from 0 76 | if relabel: 77 | pid = pid2label[pid] 78 | 
dataset.append((img_path, pid, camid)) 79 | 80 | return dataset 81 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import division 3 | import time 4 | import numpy as np 5 | import torch 6 | from utils.evaluation import evaluate, evaluate_vid 7 | from utils.reranking import re_ranking 8 | from utils.avgmeter import AverageMeter 9 | 10 | 11 | def do_test(model, queryloader, galleryloader, batch_size, use_gpu, dataset, ranks=[1, 5, 10]): 12 | batch_time = AverageMeter() 13 | 14 | model.eval() 15 | 16 | with torch.no_grad(): 17 | qf, q_pids, q_camids = [], [], [] 18 | for batch_idx, (imgs, pids, camids, _) in enumerate(queryloader): 19 | if use_gpu: 20 | imgs = imgs.cuda() 21 | 22 | end = time.time() 23 | features = model(imgs) 24 | batch_time.update(time.time() - end) 25 | 26 | features = features.data.cpu() 27 | qf.append(features) 28 | q_pids.extend(pids) 29 | q_camids.extend(camids) 30 | qf = torch.cat(qf, 0) 31 | q_pids = np.asarray(q_pids) 32 | q_camids = np.asarray(q_camids) 33 | 34 | print('Extracted features for query set, obtained {}-by-{} matrix'.format(qf.size(0), qf.size(1))) 35 | 36 | gf, g_pids, g_camids = [], [], [] 37 | for batch_idx, (imgs, pids, camids, _) in enumerate(galleryloader): 38 | if use_gpu: 39 | imgs = imgs.cuda() 40 | 41 | end = time.time() 42 | features = model(imgs) 43 | batch_time.update(time.time() - end) 44 | 45 | features = features.data.cpu() 46 | gf.append(features) 47 | g_pids.extend(pids) 48 | g_camids.extend(camids) 49 | gf = torch.cat(gf, 0) 50 | g_pids = np.asarray(g_pids) 51 | g_camids = np.asarray(g_camids) 52 | 53 | print('Extracted features for gallery set, obtained {}-by-{} matrix'.format(gf.size(0), gf.size(1))) 54 | 55 | print('=> BatchTime(s)/BatchSize(img): {:.3f}/{}'.format(batch_time.avg, batch_size)) 56 | 57 | m, n = qf.size(0), gf.size(0) 58 | distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \ 59 | torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t() 60 | distmat.addmm_(qf, gf.t(), beta=1, alpha=-2) # keyword form; the positional (beta, alpha, ...) overload is deprecated 61 | distmat = distmat.numpy() 62 | 63 | print('Computing CMC and mAP') 64 | if dataset == 'vehicleid': 65 | cmc, mAP = evaluate_vid(distmat, q_pids, g_pids, q_camids, g_camids, 50) 66 | else: 67 | cmc, mAP = evaluate(distmat, q_pids, g_pids, q_camids, g_camids, 50) 68 | 69 | print('Results ----------') 70 | print('mAP: {:.1%}'.format(mAP)) 71 | print('CMC curve') 72 | for r in ranks: 73 | print('Rank-{:<3}: {:.1%}'.format(r, cmc[r - 1])) 74 | print('------------------') 75 | 76 | distmat_re = re_ranking(qf, gf, k1=80, k2=15, lambda_value=0.2) 77 | print('Computing CMC and mAP') 78 | if dataset == 'vehicleid': 79 | cmc_re, mAP_re = evaluate_vid(distmat_re, q_pids, g_pids, q_camids, g_camids, 50) 80 | else: 81 | cmc_re, mAP_re = evaluate(distmat_re, q_pids, g_pids, q_camids, g_camids, 50) 82 | print('Re-Ranked Results--') 83 | print('mAP: {:.1%}'.format(mAP_re)) 84 | print('CMC curve') 85 | for r in ranks: 86 | print('Rank-{:<3}: {:.1%}'.format(r, cmc_re[r - 1])) 87 | print('------------------') 88 | 89 | return cmc[0], distmat, cmc_re[0], distmat_re -------------------------------------------------------------------------------- /images/affinity_matrix.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/adhirajghosh/RPTM_reid/183e1f77a0979ab2ffa08b0bdb1c43ef0f633ad5/images/affinity_matrix.png -------------------------------------------------------------------------------- /images/architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adhirajghosh/RPTM_reid/183e1f77a0979ab2ffa08b0bdb1c43ef0f633ad5/images/architecture.png -------------------------------------------------------------------------------- /loss/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from .cross_entropy_loss import CrossEntropyLoss 6 | from .hard_mine_triplet_loss import TripletLoss 7 | 8 | 9 | def DeepSupervision(criterion, xs, y): 10 | """ 11 | Args: 12 | - criterion: loss function 13 | - xs: tuple of inputs 14 | - y: ground truth 15 | """ 16 | loss = 0. 17 | for x in xs: 18 | loss += criterion(x, y) 19 | loss /= len(xs) 20 | return loss 21 | -------------------------------------------------------------------------------- /loss/cross_entropy_loss.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | class CrossEntropyLoss(nn.Module): 9 | """Cross entropy loss with label smoothing regularizer. 10 | 11 | Reference: 12 | Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016. 13 | 14 | Equation: y = (1 - epsilon) * y + epsilon / K. 15 | 16 | Args: 17 | - num_classes (int): number of classes 18 | - epsilon (float): weight 19 | - use_gpu (bool): whether to use gpu devices 20 | - label_smooth (bool): whether to apply label smoothing, if False, epsilon = 0 21 | """ 22 | 23 | def __init__(self, num_classes, epsilon=0.1, use_gpu=True, label_smooth=True): 24 | super(CrossEntropyLoss, self).__init__() 25 | self.num_classes = num_classes 26 | self.epsilon = epsilon if label_smooth else 0 27 | self.use_gpu = use_gpu 28 | self.logsoftmax = nn.LogSoftmax(dim=1) 29 | 30 | def forward(self, inputs, targets): 31 | """ 32 | Args: 33 | - inputs: prediction matrix (before softmax) with shape (batch_size, num_classes) 34 | - targets: ground truth labels with shape (num_classes) 35 | """ 36 | log_probs = self.logsoftmax(inputs) 37 | targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).data.cpu(), 1) 38 | if self.use_gpu: targets = targets.cuda() 39 | targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes 40 | loss = (- targets * log_probs).mean(0).sum() 41 | return loss 42 | -------------------------------------------------------------------------------- /loss/hard_mine_triplet_loss.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | class TripletLoss(nn.Module): 9 | 10 | 11 | def __init__(self, margin=0.3): 12 | super(TripletLoss, self).__init__() 13 | self.margin = margin 14 | self.ranking_loss = nn.MarginRankingLoss(margin=margin) 15 | 16 | def forward(self, inputs, targets): 17 | """ 18 | Args: 19 | - inputs: feature matrix with shape (batch_size, feat_dim) 20 | - targets: ground truth labels with shape (num_classes) 21 | """ 22 | n = inputs.size(0) 23 | 
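# Editor's note (explanatory, not in the original file): the block below uses
# the identity ||a - b||^2 = ||a||^2 + ||b||^2 - 2*a.b to compute all pairwise
# Euclidean distances at once; the squared norms are broadcast into an n-by-n
# grid and addmm_ subtracts twice the Gram matrix inputs @ inputs.t().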
24 | # Compute pairwise distances via the expansion above (could be replaced by torch.cdist once adopted) 25 | dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n) 26 | dist = dist + dist.t() 27 | dist.addmm_(inputs, inputs.t(), beta=1, alpha=-2) # keyword form; the positional addmm_(1, -2, ...) signature is deprecated 28 | dist = dist.clamp(min=1e-12).sqrt() # for numerical stability 29 | 30 | # For each anchor, find the hardest positive and negative 31 | mask = targets.expand(n, n).eq(targets.expand(n, n).t()) 32 | dist_ap, dist_an = [], [] 33 | for i in range(n): 34 | dist_ap.append(dist[i][mask[i]].max().unsqueeze(0)) 35 | dist_an.append(dist[i][mask[i] == 0].min().unsqueeze(0)) 36 | dist_ap = torch.cat(dist_ap) 37 | dist_an = torch.cat(dist_an) 38 | 39 | # Compute ranking hinge loss 40 | y = torch.ones_like(dist_an) 41 | loss = self.ranking_loss(dist_an, dist_ap, y) 42 | return loss 43 | -------------------------------------------------------------------------------- /loss/losses.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import yaml 7 | from torch.autograd import Variable 8 | from torch.nn import Parameter 9 | from torch.nn import init 10 | 11 | 12 | 13 | def triplet_loss(features, margin, batch_size, size_average = True): 14 | ranking_loss = nn.MarginRankingLoss(margin=margin) 15 | #anchor = L2Normalization(features[0:batch_size]) 16 | #positive = L2Normalization(features[batch_size:batch_size*2]) 17 | #negative = L2Normalization(features[batch_size*2:batch_size*3]) 18 | anchor = features[0:batch_size] 19 | positive = features[batch_size:batch_size*2] 20 | negative = features[batch_size*2:batch_size*3] 21 | #distance1 = torch.sqrt(torch.sum(torch.pow(anchor - positive, 2), 1, keepdims=True)) 22 | #distance2 = torch.sqrt(torch.sum(torch.pow(anchor - negative, 2), 1, keepdims=True)) 23 | distance_positive = (anchor - positive).pow(2).sum(1).pow(.5) 24 | distance_negative = (anchor - negative).pow(2).sum(1).pow(.5) 25 | y = torch.ones_like(distance_negative) 26 | #losses = F.relu(distance_positive - distance_negative + margin) 27 | losses = ranking_loss(distance_negative, distance_positive, y) # y = 1 ranks d_neg above d_pos, matching the relu form above 28 | return losses if size_average else losses.sum() 29 | 30 | 31 | def xent_loss(output, trainY, num_classes): 32 | epsilon = 1.0 # note: with epsilon = 1.0 the smoothed targets are uniform over classes 33 | logsoftmax = nn.LogSoftmax(dim=1) 34 | log_probs = logsoftmax(output) 35 | targets = torch.zeros(log_probs.size()) 36 | for i in range(len(targets)): 37 | for j in range(len(targets[i])): 38 | if j == trainY[i]: 39 | targets[i][j] = torch.tensor(1.0) 40 | break 41 | targets = targets.cuda() 42 | targets = (1 - epsilon) * targets + epsilon / num_classes 43 | loss = (- targets * log_probs).mean(0).sum() 44 | return loss 45 | 46 | def L2Normalization(ff, dim = 1): 47 | # ff is B*N 48 | fnorm = torch.norm(ff, p=2, dim=dim, keepdim=True) + 1e-5 49 | ff = ff.div(fnorm.expand_as(ff)) 50 | return ff 51 | 52 | def myphi(x,m): 53 | x = x * m 54 | return 1-x**2/math.factorial(2)+x**4/math.factorial(4)-x**6/math.factorial(6) + \ 55 | x**8/math.factorial(8) - x**10/math.factorial(10) # truncated Taylor series of cos(m*x); all terms are even powers 56 | 57 | # AngleLinear (SphereFace-style) layer, largely modified from the original implementation 58 | class AngleLinear(nn.Module): 59 | def __init__(self, in_features, out_features, m = 4, phiflag=True): 60 | super(AngleLinear, self).__init__() 61 | self.in_features = in_features 62 | self.out_features = out_features 63 | self.weight = Parameter(torch.Tensor(in_features,out_features)) 64 | init.normal_(self.weight.data, std=0.001) 65 | self.phiflag = phiflag 66
| self.m = m 67 | self.mlambda = [ 68 | lambda x: x**0, 69 | lambda x: x**1, 70 | lambda x: 2*x**2-1, 71 | lambda x: 4*x**3-3*x, 72 | lambda x: 8*x**4-8*x**2+1, 73 | lambda x: 16*x**5-20*x**3+5*x 74 | ] 75 | 76 | def forward(self, input): 77 | x = input # size=(B,F) F is feature len 78 | w = self.weight # size=(F,Classnum) F=in_features Classnum=out_features 79 | 80 | ww = w.renorm(2,1,1e-5).mul(1e5) 81 | xlen = x.pow(2).sum(1).pow(0.5) # size=B 82 | wlen = ww.pow(2).sum(0).pow(0.5) # size=Classnum 83 | 84 | cos_theta = x.mm(ww) # size=(B,Classnum) 85 | cos_theta = cos_theta / xlen.view(-1,1) / wlen.view(1,-1) 86 | cos_theta = cos_theta.clamp(-1,1) 87 | 88 | if self.phiflag: 89 | cos_m_theta = self.mlambda[self.m](cos_theta) 90 | theta = Variable(cos_theta.data.acos()) 91 | k = (self.m*theta/3.14159265).floor() 92 | n_one = k*0.0 - 1 93 | phi_theta = (n_one**k) * cos_m_theta - 2*k 94 | else: 95 | theta = cos_theta.acos() 96 | phi_theta = myphi(theta,self.m) 97 | phi_theta = phi_theta.clamp(-1*self.m,1) 98 | 99 | cos_theta = cos_theta * xlen.view(-1,1) 100 | phi_theta = phi_theta * xlen.view(-1,1) 101 | output = (cos_theta,phi_theta) 102 | return output # size=(B,Classnum,2) 103 | 104 | #https://github.com/auroua/InsightFace_TF/blob/master/losses/face_losses.py#L80 105 | class ArcLinear(nn.Module): 106 | def __init__(self, in_features, out_features, s=64.0): 107 | super(ArcLinear, self).__init__() 108 | self.weight = Parameter(torch.Tensor(in_features,out_features)) 109 | init.normal_(self.weight.data, std=0.001) 110 | self.loss_s = s 111 | 112 | def forward(self, input): 113 | embedding = input 114 | nembedding = L2Normalization(embedding, dim=1)*self.loss_s 115 | _weight = L2Normalization(self.weight, dim=0) 116 | fc7 = nembedding.mm(_weight) 117 | output = (fc7, _weight, nembedding) 118 | return output 119 | 120 | class ArcLoss(nn.Module): 121 | def __init__(self, m1=1.0, m2=0.5, m3 =0.0, s = 64.0): 122 | super(ArcLoss, self).__init__() 123 | self.loss_m1 = m1 124 | self.loss_m2 = m2 125 | self.loss_m3 = m3 126 | self.loss_s = s 127 | 128 | def forward(self, input, target): 129 | fc7, _weight, nembedding = input 130 | 131 | index = fc7.data * 0.0 #size=(B,Classnum) 132 | index.scatter_(1,target.data.view(-1,1),1) 133 | index = index.byte() 134 | index = Variable(index) 135 | 136 | zy = fc7[index] 137 | cos_t = zy/self.loss_s 138 | t = torch.acos(cos_t) 139 | t = t*self.loss_m1 + self.loss_m2 140 | body = torch.cos(t) - self.loss_m3 141 | 142 | new_zy = body*self.loss_s 143 | diff = new_zy - zy 144 | fc7[index] += diff 145 | loss = F.cross_entropy(fc7, target) 146 | return loss 147 | 148 | class AngleLoss(nn.Module): 149 | def __init__(self, gamma=0): 150 | super(AngleLoss, self).__init__() 151 | self.gamma = gamma 152 | self.it = 0 153 | self.LambdaMin = 5.0 154 | self.LambdaMax = 1500.0 155 | self.lamb = 1500.0 156 | 157 | def forward(self, input, target): 158 | self.it += 1 159 | cos_theta,phi_theta = input 160 | target = target.view(-1,1) #size=(B,1) 161 | 162 | index = cos_theta.data * 0.0 #size=(B,Classnum) 163 | index.scatter_(1,target.data.view(-1,1),1) 164 | index = index.byte() 165 | index = Variable(index) 166 | 167 | self.lamb = max(self.LambdaMin,self.LambdaMax/(1+0.1*self.it )) 168 | output = cos_theta * 1.0 #size=(B,Classnum) 169 | output[index] -= cos_theta[index]*(1.0+0)/(1+self.lamb) 170 | output[index] += phi_theta[index]*(1.0+0)/(1+self.lamb) 171 | 172 | logpt = F.log_softmax(output, dim=1) 173 | logpt = logpt.gather(1,target) 174 | logpt = logpt.view(-1) 175 | pt = 
Variable(logpt.data.exp()) 176 | 177 | loss = -1 * (1-pt)**self.gamma * logpt 178 | loss = loss.mean() 179 | 180 | return loss 181 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | import numpy as np 4 | import os 5 | import os.path as osp 6 | import argparse 7 | import sys 8 | try: 9 | from apex.fp16_utils import * 10 | from apex import amp, optimizers 11 | except ImportError: # will be 3.x series 12 | print('This is not an error. If you want to use low precision, i.e., fp16, please install the apex with cuda support (https://github.com/NVIDIA/apex) and update pytorch to 1.0') 13 | from config import cfg 14 | from datasets import data_loader 15 | from model import ft_net_SE, init_model, init_optimizer 16 | from loss import CrossEntropyLoss, TripletLoss 17 | from train import do_train 18 | from eval import do_test 19 | from utils.kwargs import return_kwargs 20 | from utils.loggers import Logger 21 | from utils.torchtools import count_num_param, accuracy, load_pretrained_weights, save_checkpoint 22 | from utils.visualtools import visualize_ranked_results 23 | from utils.functions import create_split_dirs 24 | 25 | try: 26 | from apex import amp 27 | APEX_AVAILABLE = True 28 | except ModuleNotFoundError: 29 | APEX_AVAILABLE = False 30 | 31 | def set_seed(seed): 32 | torch.manual_seed(seed) 33 | torch.cuda.manual_seed(seed) 34 | torch.cuda.manual_seed_all(seed) 35 | np.random.seed(seed) 36 | random.seed(seed) 37 | torch.backends.cudnn.deterministic = True 38 | torch.backends.cudnn.benchmark = True 39 | 40 | def main(): 41 | parser = argparse.ArgumentParser(description="Relation Preserving Triplet Mining for Object Re-identification") 42 | parser.add_argument( 43 | "--config_file", default="configs/veri_r101.yml", help="path to config file", type=str 44 | ) 45 | parser.add_argument("opts", help="Modify config options using the command-line", default=None, 46 | nargs=argparse.REMAINDER) 47 | 48 | args = parser.parse_args() 49 | 50 | #Load the config file 51 | if args.config_file != "": 52 | cfg.merge_from_file(args.config_file) 53 | cfg.merge_from_list(args.opts) 54 | cfg.freeze() 55 | 56 | set_seed(1234) 57 | os.environ['CUDA_VISIBLE_DEVICES'] = cfg.MODEL.GPU_ID 58 | 59 | output_dir = cfg.MISC.SAVE_DIR 60 | if output_dir and not os.path.exists(output_dir): 61 | os.makedirs(output_dir) 62 | 63 | dataset_kwargs, transform_kwargs, optimizer_kwargs, lr_scheduler_kwargs = return_kwargs(cfg) 64 | 65 | if cfg.MISC.FP16: 66 | fp16 = True 67 | 68 | use_gpu = cfg.MISC.USE_GPU 69 | log_name = './log_test.txt' if cfg.TEST.EVAL else './log_train.txt' 70 | sys.stdout = Logger(osp.join(cfg.MISC.SAVE_DIR, log_name)) 71 | 72 | if not os.path.exists(cfg.DATASET.SPLIT_DIR): 73 | create_split_dirs(cfg) 74 | 75 | print("Running for RPTM: ", cfg.MODEL.RPTM_SELECT) 76 | print('Currently using GPU ', cfg.MODEL.GPU_ID) 77 | print('Initializing image data manager') 78 | 79 | trainloader, train_dict, data_tfr, testloader_dict, dm = data_loader(cfg, dataset_kwargs, transform_kwargs) 80 | 81 | print('Initializing model: {}'.format(cfg.MODEL.ARCH)) 82 | 83 | model = init_model(cfg.MODEL.ARCH, train_dict['class_names'], loss={'xent', 'htri'}, use_gpu=use_gpu) 84 | print('Model size: {:.3f} M'.format(count_num_param(model))) 85 | 86 | if cfg.MODEL.PRETRAIN_PATH != '': 87 | print("weights loaded") 88 | load_pretrained_weights(model, cfg.MODEL.PRETRAIN_PATH) 89 | 90 
| if use_gpu: 91 | model = model.cuda() 92 | optimizer = init_optimizer(model, **optimizer_kwargs) 93 | if APEX_AVAILABLE: 94 | model, optimizer = amp.initialize( 95 | model, optimizer, opt_level="O2", 96 | keep_batchnorm_fp32=True, loss_scale="dynamic") 97 | 98 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=cfg.SOLVER.STEPSIZE, gamma=cfg.SOLVER.GAMMA) 99 | 100 | criterion_xent = CrossEntropyLoss(num_classes=train_dict['num_train_pids'], use_gpu=use_gpu, label_smooth=True) 101 | criterion_htri = TripletLoss(margin=cfg.LOSS.MARGIN) 102 | 103 | if cfg.TEST.EVAL: 104 | print('Evaluate only') 105 | 106 | for name in cfg.DATASET.TARGET_NAME: 107 | print('Evaluating {} ...'.format(name)) 108 | queryloader = testloader_dict[name]['query'] 109 | galleryloader = testloader_dict[name]['gallery'] 110 | _, distmat, _, distmat_re = do_test(model, queryloader, galleryloader, cfg.TEST.TEST_BATCH_SIZE, use_gpu, cfg.DATASET.TARGET_NAME[0]) 111 | 112 | if cfg.TEST.VIS_RANK: 113 | visualize_ranked_results( 114 | distmat_re, dm.return_testdataset_by_name(name), 115 | save_dir=osp.join(cfg.MISC.SAVE_DIR, 'ranked_results', name), 116 | topk=20 117 | ) 118 | return 119 | 120 | print('=> Start training') 121 | 122 | do_train(cfg, 123 | trainloader, 124 | train_dict, 125 | data_tfr, 126 | testloader_dict, 127 | dm, 128 | model, 129 | optimizer, 130 | scheduler, 131 | criterion_htri, 132 | criterion_xent, 133 | ) 134 | 135 | 136 | if __name__ == '__main__': 137 | main() -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import * 2 | from .optimizers import * 3 | from .resnet import * 4 | from .senet import * 5 | 6 | __model_factory = { 7 | # image classification models 8 | 'resnet50': resnet50, 9 | 'resnet50_fc512': resnet50_fc512, 10 | 'resnet101': resnet101, 11 | 'resnet152': resnet152, 12 | 'resnet50_ibn_a': resnet50_ibn_a, 13 | 'resnet101_ibn_a': resnet101_ibn_a, 14 | 'resnet152_ibn_a': resnet152_ibn_a, 15 | 'senet154': senet154, 16 | 'se_resnet50': se_resnet50, 17 | 'se_resnet101': se_resnet101, 18 | 'se_resnet152': se_resnet152, 19 | 'se_resnext50_32x4d': se_resnext50_32x4d, 20 | 'se_resnext101_32x4d': se_resnext101_32x4d } 21 | 22 | def get_names(): 23 | return list(__model_factory.keys()) 24 | 25 | 26 | def init_model(name, *args, **kwargs): 27 | if name not in list(__model_factory.keys()): 28 | raise KeyError('Unknown model: {}'.format(name)) 29 | return __model_factory[name](*args, **kwargs) 30 | -------------------------------------------------------------------------------- /model/lr_schedulers.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | 4 | import torch 5 | 6 | 7 | def init_lr_scheduler(optimizer, 8 | lr_scheduler='multi_step', # learning rate scheduler 9 | stepsize=[20, 40], # step size to decay learning rate 10 | gamma=0.1 # learning rate decay 11 | ): 12 | if lr_scheduler == 'single_step': 13 | return torch.optim.lr_scheduler.StepLR(optimizer, step_size=stepsize[0], gamma=gamma) 14 | 15 | elif lr_scheduler == 'multi_step': 16 | return torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=stepsize, gamma=gamma) 17 | elif lr_scheduler == 'plateau': 18 | return torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=gamma, patience=0, verbose=True) 19 | 20 | else: 21 | raise 
ValueError('Unsupported lr_scheduler: {}'.format(lr_scheduler)) 22 | -------------------------------------------------------------------------------- /model/models.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn import init 5 | from torchvision import models 6 | from torch.autograd import Variable 7 | import pretrainedmodels 8 | from .senet import se_resnext101_32x4d 9 | from torch.nn import functional as F 10 | 11 | ###################################################################### 12 | def weights_init_kaiming(m): 13 | classname = m.__class__.__name__ 14 | # print(classname) 15 | if classname.find('Conv') != -1: 16 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') # For old pytorch, you may use kaiming_normal. 17 | elif classname.find('Linear') != -1: 18 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_out') 19 | init.constant_(m.bias.data, 0.0) 20 | elif classname.find('BatchNorm1d') != -1: 21 | init.normal_(m.weight.data, 1.0, 0.02) 22 | init.constant_(m.bias.data, 0.0) 23 | 24 | def weights_init_classifier(m): 25 | classname = m.__class__.__name__ 26 | if classname.find('Linear') != -1: 27 | init.normal_(m.weight.data, std=0.001) 28 | init.constant_(m.bias.data, 0.0) 29 | 30 | def fix_relu(m): 31 | classname = m.__class__.__name__ 32 | if classname.find('ReLU') != -1: 33 | m.inplace=True 34 | # Defines the new fc layer and classification layer 35 | # |--Linear--|--bn--|--relu--|--Linear--| 36 | class ClassBlock(nn.Module): 37 | def __init__(self, input_dim, class_num, droprate, relu=False, bnorm=True, num_bottleneck=512, linear=True, return_f = False): 38 | super(ClassBlock, self).__init__() 39 | self.return_f = return_f 40 | add_block = [] 41 | if linear: 42 | add_block += [nn.Linear(input_dim, num_bottleneck)] 43 | else: 44 | num_bottleneck = input_dim 45 | if bnorm: 46 | add_block += [nn.BatchNorm1d(num_bottleneck)] 47 | if relu: 48 | add_block += [nn.LeakyReLU(0.1)] 49 | if droprate>0: 50 | add_block += [nn.Dropout(p=droprate)] 51 | add_block = nn.Sequential(*add_block) 52 | add_block.apply(weights_init_kaiming) 53 | 54 | classifier = [] 55 | classifier += [nn.Linear(num_bottleneck, class_num)] 56 | classifier = nn.Sequential(*classifier) 57 | classifier.apply(weights_init_classifier) 58 | 59 | self.add_block = add_block 60 | self.classifier = classifier 61 | def forward(self, x): 62 | x = self.add_block(x) 63 | if self.return_f: 64 | f = x 65 | x = self.classifier(x) 66 | return x,f 67 | else: 68 | x = self.classifier(x) 69 | return x 70 | 71 | # Define the SE-based Model 72 | class ft_net_SE(nn.Module): 73 | 74 | def __init__(self, class_num, droprate=0.5, stride=2, pool='avg', init_model=None): 75 | super().__init__() 76 | model_name = 'se_resnext101_32x4d' # could be fbresnet152 or inceptionresnetv2 77 | # model_ft = pretrainedmodels.__dict__[model_name](num_classes=1000, pretrained='imagenet') 78 | model_ft = se_resnext101_32x4d(num_classes=1000, pretrained='imagenet') 79 | 80 | if stride == 1: 81 | model_ft.layer4[0].conv2.stride = (1,1) 82 | model_ft.layer4[0].downsample[0].stride = (1,1) 83 | if pool == 'avg': 84 | model_ft.avg_pool = nn.AdaptiveAvgPool2d((1,1)) 85 | elif pool == 'max': 86 | model_ft.avg_pool = nn.AdaptiveMaxPool2d((1,1)) 87 | elif pool == 'avg+max': 88 | model_ft.avg_pool2 = nn.AdaptiveAvgPool2d((1,1)) 89 | model_ft.max_pool2 = nn.AdaptiveMaxPool2d((1,1)) 90 | else: 91 | print('UNKNOW POOLING!!!!!!!!!!!!!!!!!!!!!!!!!!') 92 
| #model_ft.dropout = nn.Sequential() 93 | model_ft.last_linear = nn.Sequential() 94 | self.model = model_ft 95 | self.pool = pool 96 | # For DenseNet, the feature dim is 2048 97 | if pool == 'avg+max': 98 | self.classifier = ClassBlock(4096, class_num, droprate) 99 | else: 100 | self.classifier = ClassBlock(2048, class_num, droprate) 101 | self.flag = False 102 | if init_model!=None: 103 | self.flag = True 104 | self.model = init_model.model 105 | self.classifier.add_block = init_model.classifier.add_block 106 | self.new_dropout = nn.Sequential(nn.Dropout(p = droprate)) 107 | 108 | def forward(self, x): 109 | x = self.model.features(x) 110 | if self.pool == 'avg+max': 111 | v1 = self.model.avg_pool2(x) 112 | v2 = self.model.max_pool2(x) 113 | v = torch.cat((v1,v2), dim = 1) 114 | else: 115 | v = self.model.avg_pool(x) 116 | v = v.view(v.size(0), v.size(1)) 117 | if not self.training: 118 | return v 119 | # Convolution layers 120 | # Pooling and final linear layer 121 | if self.flag: 122 | v = self.classifier.add_block(v) 123 | v = self.new_dropout(v) 124 | y = self.classifier.classifier(v) 125 | else: 126 | y = self.classifier(v) 127 | return y,v 128 | 129 | -------------------------------------------------------------------------------- /model/optimizers.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | def init_optimizer(model, 9 | optim='adam', # optimizer choices 10 | lr=0.003, # learning rate 11 | weight_decay=5e-4, # weight decay 12 | momentum=0.9, # momentum factor for sgd and rmsprop 13 | sgd_dampening=0, # sgd's dampening for momentum 14 | sgd_nesterov=True, # whether to enable sgd's Nesterov momentum 15 | rmsprop_alpha=0.99, # rmsprop's smoothing constant 16 | adam_beta1=0.9, # exponential decay rate for adam's first moment 17 | adam_beta2=0.999, # # exponential decay rate for adam's second moment 18 | staged_lr=False, # different lr for different layers 19 | new_layers=None, # new layers use the default lr, while other layers's lr is scaled by base_lr_mult 20 | base_lr_mult=0.1, # learning rate multiplier for base layers 21 | ): 22 | if staged_lr: 23 | assert new_layers is not None 24 | base_params = [] 25 | base_layers = [] 26 | new_params = [] 27 | if isinstance(model, nn.DataParallel): 28 | model = model.module 29 | for name, module in model.named_children(): 30 | if name in new_layers: 31 | new_params += [p for p in module.parameters()] 32 | else: 33 | base_params += [p for p in module.parameters()] 34 | base_layers.append(name) 35 | param_groups = [ 36 | {'params': base_params, 'lr': lr * base_lr_mult}, 37 | {'params': new_params}, 38 | ] 39 | print('Use staged learning rate') 40 | print('* Base layers (initial lr = {}): {}'.format(lr * base_lr_mult, base_layers)) 41 | print('* New layers (initial lr = {}): {}'.format(lr, new_layers)) 42 | else: 43 | param_groups = model.parameters() 44 | 45 | # Construct optimizer 46 | if optim == 'adam': 47 | return torch.optim.Adam(param_groups, lr=lr, weight_decay=weight_decay, 48 | betas=(adam_beta1, adam_beta2)) 49 | 50 | elif optim == 'amsgrad': 51 | return torch.optim.Adam(param_groups, lr=lr, weight_decay=weight_decay, 52 | betas=(adam_beta1, adam_beta2), amsgrad=True) 53 | 54 | elif optim == 'sgd': 55 | return torch.optim.SGD(param_groups, lr=lr, momentum=momentum, weight_decay=weight_decay, 56 | dampening=sgd_dampening, nesterov=sgd_nesterov) 57 | 58 | 
elif optim == 'rmsprop': 59 | return torch.optim.RMSprop(param_groups, lr=lr, momentum=momentum, weight_decay=weight_decay, 60 | alpha=rmsprop_alpha) 61 | 62 | else: 63 | raise ValueError('Unsupported optimizer: {}'.format(optim)) 64 | -------------------------------------------------------------------------------- /model/resnet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | 4 | import torch 5 | from torch import nn 6 | import math 7 | import torch.utils.model_zoo as model_zoo 8 | 9 | __all__ = ['resnet50', 'resnet50_fc512', 'resnet101', 'resnet152', 'resnet50_ibn_a', 'resnet101_ibn_a', 10 | 'resnet152_ibn_a'] 11 | 12 | model_urls = { 13 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 14 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 15 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 16 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 17 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 18 | 'resnet101_ibn_a': 'https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnet101_ibn_a-59ea0ac6.pth' 19 | } 20 | 21 | 22 | def conv3x3(in_planes, out_planes, stride=1): 23 | """3x3 convolution with padding""" 24 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 25 | padding=1, bias=False) 26 | 27 | 28 | class BasicBlock(nn.Module): 29 | expansion = 1 30 | 31 | def __init__(self, inplanes, planes, stride=1, downsample=None): 32 | super(BasicBlock, self).__init__() 33 | self.conv1 = conv3x3(inplanes, planes, stride) 34 | self.bn1 = nn.BatchNorm2d(planes) 35 | self.relu = nn.ReLU(inplace=True) 36 | self.conv2 = conv3x3(planes, planes) 37 | self.bn2 = nn.BatchNorm2d(planes) 38 | self.downsample = downsample 39 | self.stride = stride 40 | 41 | def forward(self, x): 42 | residual = x 43 | 44 | out = self.conv1(x) 45 | out = self.bn1(out) 46 | out = self.relu(out) 47 | 48 | out = self.conv2(out) 49 | out = self.bn2(out) 50 | 51 | if self.downsample is not None: 52 | residual = self.downsample(x) 53 | 54 | out += residual 55 | out = self.relu(out) 56 | 57 | return out 58 | 59 | 60 | class Bottleneck(nn.Module): 61 | expansion = 4 62 | 63 | def __init__(self, inplanes, planes, stride=1, downsample=None): 64 | super(Bottleneck, self).__init__() 65 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 66 | self.bn1 = nn.BatchNorm2d(planes) 67 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 68 | padding=1, bias=False) 69 | self.bn2 = nn.BatchNorm2d(planes) 70 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) 71 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 72 | self.relu = nn.ReLU(inplace=True) 73 | self.downsample = downsample 74 | self.stride = stride 75 | 76 | def forward(self, x): 77 | residual = x 78 | 79 | out = self.conv1(x) 80 | out = self.bn1(out) 81 | out = self.relu(out) 82 | 83 | out = self.conv2(out) 84 | out = self.bn2(out) 85 | out = self.relu(out) 86 | 87 | out = self.conv3(out) 88 | out = self.bn3(out) 89 | 90 | if self.downsample is not None: 91 | residual = self.downsample(x) 92 | 93 | out += residual 94 | out = self.relu(out) 95 | 96 | return out 97 | 98 | 99 | class ResNet(nn.Module): 100 | """ 101 | Residual network 102 | Reference: 103 | He et al. Deep Residual Learning for Image Recognition. CVPR 2016. 
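    Configured here with a tunable last_stride and an optional fc embedding head (fc_dims) for re-id feature extraction.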
104 | """ 105 | 106 | def __init__(self, num_classes, loss, block, layers, 107 | last_stride=2, 108 | fc_dims=None, 109 | dropout_p=None, 110 | **kwargs): 111 | self.inplanes = 64 112 | super(ResNet, self).__init__() 113 | self.loss = loss 114 | self.feature_dim = 512 * block.expansion 115 | 116 | # backbone network 117 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 118 | self.bn1 = nn.BatchNorm2d(64) 119 | self.relu = nn.ReLU(inplace=True) 120 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 121 | self.layer1 = self._make_layer(block, 64, layers[0]) 122 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 123 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 124 | self.layer4 = self._make_layer(block, 512, layers[3], stride=last_stride) 125 | 126 | self.global_avgpool = nn.AdaptiveAvgPool2d(1) 127 | self.fc = self._construct_fc_layer(fc_dims, 512 * block.expansion, dropout_p) 128 | self.classifier = nn.Linear(self.feature_dim, num_classes) 129 | 130 | self._init_params() 131 | 132 | def _make_layer(self, block, planes, blocks, stride=1): 133 | downsample = None 134 | if stride != 1 or self.inplanes != planes * block.expansion: 135 | downsample = nn.Sequential( 136 | nn.Conv2d(self.inplanes, planes * block.expansion, 137 | kernel_size=1, stride=stride, bias=False), 138 | nn.BatchNorm2d(planes * block.expansion), 139 | ) 140 | 141 | layers = [] 142 | layers.append(block(self.inplanes, planes, stride, downsample)) 143 | self.inplanes = planes * block.expansion 144 | for i in range(1, blocks): 145 | layers.append(block(self.inplanes, planes)) 146 | 147 | return nn.Sequential(*layers) 148 | 149 | def _construct_fc_layer(self, fc_dims, input_dim, dropout_p=None): 150 | """ 151 | Construct fully connected layer 152 | - fc_dims (list or tuple): dimensions of fc layers, if None, 153 | no fc layers are constructed 154 | - input_dim (int): input dimension 155 | - dropout_p (float): dropout probability, if None, dropout is unused 156 | """ 157 | if fc_dims is None: 158 | self.feature_dim = input_dim 159 | return None 160 | 161 | assert isinstance(fc_dims, (list, tuple)), 'fc_dims must be either list or tuple, but got {}'.format( 162 | type(fc_dims)) 163 | 164 | layers = [] 165 | for dim in fc_dims: 166 | layers.append(nn.Linear(input_dim, dim)) 167 | layers.append(nn.BatchNorm1d(dim)) 168 | layers.append(nn.ReLU(inplace=True)) 169 | if dropout_p is not None: 170 | layers.append(nn.Dropout(p=dropout_p)) 171 | input_dim = dim 172 | 173 | self.feature_dim = fc_dims[-1] 174 | 175 | return nn.Sequential(*layers) 176 | 177 | def _init_params(self): 178 | for m in self.modules(): 179 | if isinstance(m, nn.Conv2d): 180 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 181 | if m.bias is not None: 182 | nn.init.constant_(m.bias, 0) 183 | elif isinstance(m, nn.BatchNorm2d): 184 | nn.init.constant_(m.weight, 1) 185 | nn.init.constant_(m.bias, 0) 186 | elif isinstance(m, nn.BatchNorm1d): 187 | nn.init.constant_(m.weight, 1) 188 | nn.init.constant_(m.bias, 0) 189 | elif isinstance(m, nn.Linear): 190 | nn.init.normal_(m.weight, 0, 0.01) 191 | if m.bias is not None: 192 | nn.init.constant_(m.bias, 0) 193 | 194 | def featuremaps(self, x): 195 | x = self.conv1(x) 196 | x = self.bn1(x) 197 | x = self.relu(x) 198 | x = self.maxpool(x) 199 | x = self.layer1(x) 200 | x = self.layer2(x) 201 | x = self.layer3(x) 202 | x = self.layer4(x) 203 | return x 204 | 205 | def forward(self, x): 206 | f = self.featuremaps(x) 207 | v 
= self.global_avgpool(f) 208 | v = v.view(v.size(0), -1) 209 | 210 | if self.fc is not None: 211 | v = self.fc(v) 212 | 213 | if not self.training: 214 | return v 215 | 216 | y = self.classifier(v) 217 | 218 | if self.loss == {'xent'}: 219 | return y 220 | elif self.loss == {'xent', 'htri'}: 221 | return y, v 222 | else: 223 | raise KeyError("Unsupported loss: {}".format(self.loss)) 224 | 225 | 226 | class IBN(nn.Module): 227 | def __init__(self, planes): 228 | super(IBN, self).__init__() 229 | half1 = int(planes / 2) 230 | self.half = half1 231 | half2 = planes - half1 232 | self.IN = nn.InstanceNorm2d(half1, affine=True) 233 | self.BN = nn.BatchNorm2d(half2) 234 | 235 | def forward(self, x): 236 | split = torch.split(x, self.half, 1) 237 | out1 = self.IN(split[0].contiguous()) 238 | out2 = self.BN(split[1].contiguous()) 239 | out = torch.cat((out1, out2), 1) 240 | return out 241 | 242 | 243 | class Bottleneck_IBN(nn.Module): 244 | expansion = 4 245 | 246 | def __init__(self, inplanes, planes, ibn=False, stride=1, downsample=None): 247 | super(Bottleneck_IBN, self).__init__() 248 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 249 | if ibn: 250 | self.bn1 = IBN(planes) 251 | else: 252 | self.bn1 = nn.BatchNorm2d(planes) 253 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 254 | padding=1, bias=False) 255 | self.bn2 = nn.BatchNorm2d(planes) 256 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) 257 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 258 | self.relu = nn.ReLU(inplace=True) 259 | self.downsample = downsample 260 | self.stride = stride 261 | 262 | def forward(self, x): 263 | residual = x 264 | 265 | out = self.conv1(x) 266 | out = self.bn1(out) 267 | out = self.relu(out) 268 | 269 | out = self.conv2(out) 270 | out = self.bn2(out) 271 | out = self.relu(out) 272 | 273 | out = self.conv3(out) 274 | out = self.bn3(out) 275 | 276 | if self.downsample is not None: 277 | residual = self.downsample(x) 278 | 279 | out += residual 280 | out = self.relu(out) 281 | 282 | return out 283 | 284 | 285 | class ResNet_IBN(nn.Module): 286 | 287 | def __init__(self, last_stride, block, layers, loss, num_classes, fc_dims=None, dropout_p=None, **kwargs): 288 | scale = 64 289 | self.inplanes = scale 290 | super(ResNet_IBN, self).__init__() 291 | self.feature_dim = 512 * block.expansion 292 | self.conv1 = nn.Conv2d(3, scale, kernel_size=7, stride=2, padding=3, 293 | bias=False) 294 | self.bn1 = nn.BatchNorm2d(scale) 295 | self.relu = nn.ReLU(inplace=True) 296 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 297 | self.layer1 = self._make_layer(block, scale, layers[0]) 298 | self.layer2 = self._make_layer(block, scale * 2, layers[1], stride=2) 299 | self.layer3 = self._make_layer(block, scale * 4, layers[2], stride=2) 300 | self.layer4 = self._make_layer(block, scale * 8, layers[3], stride=last_stride) 301 | self.avgpool = nn.AvgPool2d(7) 302 | # self.fc = nn.Linear(scale * 8 * block.expansion, num_classes) 303 | self.loss = loss 304 | self.global_avgpool = nn.AdaptiveAvgPool2d(1) 305 | self.fc = self._construct_fc_layer(fc_dims, scale * 8 * block.expansion, dropout_p) 306 | self.classifier = nn.Linear(self.feature_dim, num_classes) 307 | 308 | for m in self.modules(): 309 | if isinstance(m, nn.Conv2d): 310 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 311 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 312 | elif isinstance(m, nn.BatchNorm2d): 313 | m.weight.data.fill_(1) 314 | m.bias.data.zero_() 315 | elif isinstance(m, nn.InstanceNorm2d): 316 | m.weight.data.fill_(1) 317 | m.bias.data.zero_() 318 | 319 | def _make_layer(self, block, planes, blocks, stride=1): 320 | downsample = None 321 | if stride != 1 or self.inplanes != planes * block.expansion: 322 | downsample = nn.Sequential( 323 | nn.Conv2d(self.inplanes, planes * block.expansion, 324 | kernel_size=1, stride=stride, bias=False), 325 | nn.BatchNorm2d(planes * block.expansion), 326 | ) 327 | 328 | layers = [] 329 | ibn = True 330 | if planes == 512: 331 | ibn = False 332 | layers.append(block(self.inplanes, planes, ibn, stride, downsample)) 333 | self.inplanes = planes * block.expansion 334 | for i in range(1, blocks): 335 | layers.append(block(self.inplanes, planes, ibn)) 336 | 337 | return nn.Sequential(*layers) 338 | 339 | def _construct_fc_layer(self, fc_dims, input_dim, dropout_p=None): 340 | """ 341 | Construct fully connected layer 342 | - fc_dims (list or tuple): dimensions of fc layers, if None, 343 | no fc layers are constructed 344 | - input_dim (int): input dimension 345 | - dropout_p (float): dropout probability, if None, dropout is unused 346 | """ 347 | if fc_dims is None: 348 | self.feature_dim = input_dim 349 | return None 350 | 351 | assert isinstance(fc_dims, (list, tuple)), 'fc_dims must be either list or tuple, but got {}'.format( 352 | type(fc_dims)) 353 | 354 | layers = [] 355 | for dim in fc_dims: 356 | layers.append(nn.Linear(input_dim, dim)) 357 | layers.append(nn.BatchNorm1d(dim)) 358 | layers.append(nn.ReLU(inplace=True)) 359 | if dropout_p is not None: 360 | layers.append(nn.Dropout(p=dropout_p)) 361 | input_dim = dim 362 | 363 | self.feature_dim = fc_dims[-1] 364 | 365 | return nn.Sequential(*layers) 366 | 367 | def forward(self, x): 368 | x = self.conv1(x) 369 | x = self.bn1(x) 370 | x = self.relu(x) 371 | x = self.maxpool(x) 372 | 373 | x = self.layer1(x) 374 | x = self.layer2(x) 375 | x = self.layer3(x) 376 | x = self.layer4(x) 377 | 378 | # x = self.avgpool(x) 379 | # x = x.view(x.size(0), -1) 380 | # x = self.fc(x) 381 | f = x 382 | v = self.global_avgpool(x) 383 | v = v.view(v.size(0), -1) 384 | 385 | if self.fc is not None: 386 | v = self.fc(v) 387 | 388 | if not self.training: 389 | return v 390 | 391 | y = self.classifier(v) 392 | 393 | if self.loss == {'xent'}: 394 | return y 395 | elif self.loss == {'xent', 'htri'}: 396 | return y, v 397 | else: 398 | raise KeyError("Unsupported loss: {}".format(self.loss)) 399 | 400 | def load_param(self, model_path): 401 | param_dict = torch.load(model_path) 402 | for i in param_dict: 403 | if 'fc' in i: 404 | continue 405 | self.state_dict()[i].copy_(param_dict[i]) 406 | 407 | 408 | def init_pretrained_weights(model, model_url): 409 | """ 410 | Initialize model with pretrained weights. 411 | Layers that don't match with pretrained layers in name or size are kept unchanged. 
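    This filtering lets the ImageNet-pretrained backbone load even when the re-id classifier head has a different number of classes.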
412 | """ 413 | pretrain_dict = model_zoo.load_url(model_url) 414 | model_dict = model.state_dict() 415 | pretrain_dict = {k: v for k, v in pretrain_dict.items() if k in model_dict and model_dict[k].size() == v.size()} 416 | model_dict.update(pretrain_dict) 417 | model.load_state_dict(model_dict) 418 | print('Initialized model with pretrained weights from {}'.format(model_url)) 419 | 420 | 421 | """ 422 | Residual network configurations: 423 | -- 424 | resnet18: block=BasicBlock, layers=[2, 2, 2, 2] 425 | resnet34: block=BasicBlock, layers=[3, 4, 6, 3] 426 | resnet50: block=Bottleneck, layers=[3, 4, 6, 3] 427 | resnet101: block=Bottleneck, layers=[3, 4, 23, 3] 428 | resnet152: block=Bottleneck, layers=[3, 8, 36, 3] 429 | """ 430 | 431 | 432 | def resnet50(num_classes, loss={'xent'}, pretrained=True, **kwargs): 433 | model = ResNet( 434 | num_classes=num_classes, 435 | loss=loss, 436 | block=Bottleneck, 437 | layers=[3, 4, 6, 3], 438 | last_stride=2, 439 | fc_dims=None, 440 | dropout_p=None, 441 | **kwargs 442 | ) 443 | if pretrained: 444 | init_pretrained_weights(model, model_urls['resnet50']) 445 | return model 446 | 447 | 448 | def resnet50_fc512(num_classes, loss={'xent'}, pretrained=True, **kwargs): 449 | model = ResNet( 450 | num_classes=num_classes, 451 | loss=loss, 452 | block=Bottleneck, 453 | layers=[3, 4, 6, 3], 454 | last_stride=1, 455 | fc_dims=[512], 456 | dropout_p=None, 457 | **kwargs 458 | ) 459 | if pretrained: 460 | init_pretrained_weights(model, model_urls['resnet50']) 461 | return model 462 | 463 | 464 | def resnet101(num_classes, loss={'xent'}, pretrained=True, **kwargs): 465 | model = ResNet( 466 | num_classes=num_classes, 467 | loss=loss, 468 | block=Bottleneck, 469 | layers=[3, 4, 23, 3], 470 | last_stride=2, 471 | fc_dims=None, 472 | dropout_p=None, 473 | **kwargs 474 | ) 475 | if pretrained: 476 | init_pretrained_weights(model, model_urls['resnet101']) 477 | return model 478 | 479 | 480 | def resnet152(num_classes, loss={'xent'}, pretrained=True, **kwargs): 481 | model = ResNet( 482 | num_classes=num_classes, 483 | loss=loss, 484 | block=Bottleneck, 485 | layers=[3, 8, 36, 3], 486 | last_stride=2, 487 | fc_dims=None, 488 | dropout_p=None, 489 | **kwargs 490 | ) 491 | if pretrained: 492 | init_pretrained_weights(model, model_urls['resnet152']) 493 | return model 494 | 495 | 496 | def resnet50_ibn_a(num_classes, loss={'xent'}, pretrained=True, **kwargs): 497 | """Constructs a ResNet-50 model. 498 | Args: 499 | pretrained (bool): If True, returns a model pre-trained on ImageNet 500 | """ 501 | model = ResNet_IBN(1, Bottleneck_IBN, [3, 4, 6, 3], loss, **kwargs) 502 | if pretrained: 503 | init_pretrained_weights(model, model_urls['resnet50']) 504 | return model 505 | 506 | 507 | def resnet101_ibn_a(num_classes, loss={'xent'}, pretrained=True, **kwargs): 508 | """Constructs a ResNet-101 model. 509 | Args: 510 | pretrained (bool): If True, returns a model pre-trained on ImageNet 511 | """ 512 | model = ResNet_IBN(1, Bottleneck_IBN, [3, 4, 23, 3], loss, num_classes, **kwargs) 513 | if pretrained: 514 | init_pretrained_weights(model, model_urls['resnet101_ibn_a']) 515 | return model 516 | 517 | 518 | def resnet152_ibn_a(num_classes, loss={'xent'}, pretrained=True, **kwargs): 519 | """Constructs a ResNet-152 model. 
520 | Args: 521 | pretrained (bool): If True, returns a model pre-trained on ImageNet 522 | """ 523 | model = ResNet_IBN(1, Bottleneck_IBN, [3, 8, 36, 3], loss, num_classes, **kwargs) 524 | if pretrained: 525 | init_pretrained_weights(model, model_urls['resnet152']) 526 | return model 527 | -------------------------------------------------------------------------------- /model/senet.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division, absolute_import 2 | from collections import OrderedDict 3 | import math 4 | 5 | import torch.nn as nn 6 | from torch.utils import model_zoo 7 | 8 | __all__ = ['SENet', 'senet154', 'se_resnet50', 'se_resnet101', 'se_resnet152', 9 | 'se_resnext50_32x4d', 'se_resnext101_32x4d'] 10 | 11 | model_urls = { 12 | 'senet154': 'http://data.lip6.fr/cadene/pretrainedmodels/senet154-c7b49a05.pth', 13 | 'se_resnet50': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet50-ce0d4300.pth', 14 | 'se_resnet101': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet101-7e38fcc6.pth', 15 | 'se_resnet152': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet152-d17c99b7.pth', 16 | 'se_resnext50_32x4d': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnext50_32x4d-a260b3a4.pth', 17 | 'se_resnext101_32x4d': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnext101_32x4d-3b2fe3d8.pth', 18 | } 19 | 20 | pretrained_settings = { 21 | 'senet154': { 22 | 'imagenet': { 23 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/senet154-c7b49a05.pth', 24 | 'input_space': 'RGB', 25 | 'input_size': [3, 224, 224], 26 | 'input_range': [0, 1], 27 | 'mean': [0.485, 0.456, 0.406], 28 | 'std': [0.229, 0.224, 0.225], 29 | 'num_classes': 1000 30 | } 31 | }, 32 | 'se_resnet50': { 33 | 'imagenet': { 34 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet50-ce0d4300.pth', 35 | 'input_space': 'RGB', 36 | 'input_size': [3, 224, 224], 37 | 'input_range': [0, 1], 38 | 'mean': [0.485, 0.456, 0.406], 39 | 'std': [0.229, 0.224, 0.225], 40 | 'num_classes': 1000 41 | } 42 | }, 43 | 'se_resnet101': { 44 | 'imagenet': { 45 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet101-7e38fcc6.pth', 46 | 'input_space': 'RGB', 47 | 'input_size': [3, 224, 224], 48 | 'input_range': [0, 1], 49 | 'mean': [0.485, 0.456, 0.406], 50 | 'std': [0.229, 0.224, 0.225], 51 | 'num_classes': 1000 52 | } 53 | }, 54 | 'se_resnet152': { 55 | 'imagenet': { 56 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet152-d17c99b7.pth', 57 | 'input_space': 'RGB', 58 | 'input_size': [3, 224, 224], 59 | 'input_range': [0, 1], 60 | 'mean': [0.485, 0.456, 0.406], 61 | 'std': [0.229, 0.224, 0.225], 62 | 'num_classes': 1000 63 | } 64 | }, 65 | 'se_resnext50_32x4d': { 66 | 'imagenet': { 67 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnext50_32x4d-a260b3a4.pth', 68 | 'input_space': 'RGB', 69 | 'input_size': [3, 224, 224], 70 | 'input_range': [0, 1], 71 | 'mean': [0.485, 0.456, 0.406], 72 | 'std': [0.229, 0.224, 0.225], 73 | 'num_classes': 1000 74 | } 75 | }, 76 | 'se_resnext101_32x4d': { 77 | 'imagenet': { 78 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnext101_32x4d-3b2fe3d8.pth', 79 | 'input_space': 'RGB', 80 | 'input_size': [3, 224, 224], 81 | 'input_range': [0, 1], 82 | 'mean': [0.485, 0.456, 0.406], 83 | 'std': [0.229, 0.224, 0.225], 84 | 'num_classes': 1000 85 | } 86 | }, 87 | } 88 | 89 | class SEModule(nn.Module): 90 | 91 | def __init__(self, channels, reduction): 92 | super(SEModule, 
self).__init__() 93 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 94 | self.fc1 = nn.Conv2d(channels, channels // reduction, kernel_size=1, 95 | padding=0) 96 | self.relu = nn.ReLU(inplace=True) 97 | self.fc2 = nn.Conv2d(channels // reduction, channels, kernel_size=1, 98 | padding=0) 99 | self.sigmoid = nn.Sigmoid() 100 | 101 | def forward(self, x): 102 | module_input = x 103 | x = self.avg_pool(x) 104 | x = self.fc1(x) 105 | x = self.relu(x) 106 | x = self.fc2(x) 107 | x = self.sigmoid(x) 108 | return module_input * x 109 | 110 | 111 | class Bottleneck(nn.Module): 112 | """ 113 | Base class for bottlenecks that implements `forward()` method. 114 | """ 115 | def forward(self, x): 116 | residual = x 117 | 118 | out = self.conv1(x) 119 | out = self.bn1(out) 120 | out = self.relu(out) 121 | 122 | out = self.conv2(out) 123 | out = self.bn2(out) 124 | out = self.relu(out) 125 | 126 | out = self.conv3(out) 127 | out = self.bn3(out) 128 | 129 | if self.downsample is not None: 130 | residual = self.downsample(x) 131 | 132 | out = self.se_module(out) + residual 133 | out = self.relu(out) 134 | 135 | return out 136 | 137 | 138 | class SEBottleneck(Bottleneck): 139 | """ 140 | Bottleneck for SENet154. 141 | """ 142 | expansion = 4 143 | 144 | def __init__(self, inplanes, planes, groups, reduction, stride=1, 145 | downsample=None): 146 | super(SEBottleneck, self).__init__() 147 | self.conv1 = nn.Conv2d(inplanes, planes * 2, kernel_size=1, bias=False) 148 | self.bn1 = nn.BatchNorm2d(planes * 2) 149 | self.conv2 = nn.Conv2d(planes * 2, planes * 4, kernel_size=3, 150 | stride=stride, padding=1, groups=groups, 151 | bias=False) 152 | self.bn2 = nn.BatchNorm2d(planes * 4) 153 | self.conv3 = nn.Conv2d(planes * 4, planes * 4, kernel_size=1, 154 | bias=False) 155 | self.bn3 = nn.BatchNorm2d(planes * 4) 156 | self.relu = nn.ReLU(inplace=True) 157 | self.se_module = SEModule(planes * 4, reduction=reduction) 158 | self.downsample = downsample 159 | self.stride = stride 160 | 161 | 162 | class SEResNetBottleneck(Bottleneck): 163 | """ 164 | ResNet bottleneck with a Squeeze-and-Excitation module. It follows Caffe 165 | implementation and uses `stride=stride` in `conv1` and not in `conv2` 166 | (the latter is used in the torchvision implementation of ResNet). 167 | """ 168 | expansion = 4 169 | 170 | def __init__(self, inplanes, planes, groups, reduction, stride=1, 171 | downsample=None): 172 | super(SEResNetBottleneck, self).__init__() 173 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False, 174 | stride=stride) 175 | self.bn1 = nn.BatchNorm2d(planes) 176 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, 177 | groups=groups, bias=False) 178 | self.bn2 = nn.BatchNorm2d(planes) 179 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 180 | self.bn3 = nn.BatchNorm2d(planes * 4) 181 | self.relu = nn.ReLU(inplace=True) 182 | self.se_module = SEModule(planes * 4, reduction=reduction) 183 | self.downsample = downsample 184 | self.stride = stride 185 | 186 | 187 | class SEResNeXtBottleneck(Bottleneck): 188 | """ 189 | ResNeXt bottleneck type C with a Squeeze-and-Excitation module. 
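    The grouped 3x3 convolution operates on width = floor(planes * base_width / 64) * groups channels, as computed in __init__ below.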
190 | """ 191 | expansion = 4 192 | 193 | def __init__(self, inplanes, planes, groups, reduction, stride=1, 194 | downsample=None, base_width=4): 195 | super(SEResNeXtBottleneck, self).__init__() 196 | width = math.floor(planes * (base_width / 64)) * groups 197 | self.conv1 = nn.Conv2d(inplanes, width, kernel_size=1, bias=False, 198 | stride=1) 199 | self.bn1 = nn.BatchNorm2d(width) 200 | self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride, 201 | padding=1, groups=groups, bias=False) 202 | self.bn2 = nn.BatchNorm2d(width) 203 | self.conv3 = nn.Conv2d(width, planes * 4, kernel_size=1, bias=False) 204 | self.bn3 = nn.BatchNorm2d(planes * 4) 205 | self.relu = nn.ReLU(inplace=True) 206 | self.se_module = SEModule(planes * 4, reduction=reduction) 207 | self.downsample = downsample 208 | self.stride = stride 209 | 210 | 211 | class SENet(nn.Module): 212 | 213 | def __init__(self, block, layers, groups, reduction, dropout_p=0.2, 214 | inplanes=128, input_3x3=True, downsample_kernel_size=3, 215 | downsample_padding=1, num_classes=1000, loss={'xent'}): 216 | """ 217 | Parameters 218 | ---------- 219 | block (nn.Module): Bottleneck class. 220 | - For SENet154: SEBottleneck 221 | - For SE-ResNet models: SEResNetBottleneck 222 | - For SE-ResNeXt models: SEResNeXtBottleneck 223 | layers (list of ints): Number of residual blocks for 4 layers of the 224 | network (layer1...layer4). 225 | groups (int): Number of groups for the 3x3 convolution in each 226 | bottleneck block. 227 | - For SENet154: 64 228 | - For SE-ResNet models: 1 229 | - For SE-ResNeXt models: 32 230 | reduction (int): Reduction ratio for Squeeze-and-Excitation modules. 231 | - For all models: 16 232 | dropout_p (float or None): Drop probability for the Dropout layer. 233 | If `None` the Dropout layer is not used. 234 | - For SENet154: 0.2 235 | - For SE-ResNet models: None 236 | - For SE-ResNeXt models: None 237 | inplanes (int): Number of input channels for layer1. 238 | - For SENet154: 128 239 | - For SE-ResNet models: 64 240 | - For SE-ResNeXt models: 64 241 | input_3x3 (bool): If `True`, use three 3x3 convolutions instead of 242 | a single 7x7 convolution in layer0. 243 | - For SENet154: True 244 | - For SE-ResNet models: False 245 | - For SE-ResNeXt models: False 246 | downsample_kernel_size (int): Kernel size for downsampling convolutions 247 | in layer2, layer3 and layer4. 248 | - For SENet154: 3 249 | - For SE-ResNet models: 1 250 | - For SE-ResNeXt models: 1 251 | downsample_padding (int): Padding for downsampling convolutions in 252 | layer2, layer3 and layer4. 253 | - For SENet154: 1 254 | - For SE-ResNet models: 0 255 | - For SE-ResNeXt models: 0 256 | num_classes (int): Number of outputs in `last_linear` layer. 
257 | - For all models: 1000 258 | """ 259 | super(SENet, self).__init__() 260 | self.inplanes = inplanes 261 | if input_3x3: 262 | layer0_modules = [ 263 | ('conv1', nn.Conv2d(3, 64, 3, stride=2, padding=1, 264 | bias=False)), 265 | ('bn1', nn.BatchNorm2d(64)), 266 | ('relu1', nn.ReLU(inplace=True)), 267 | ('conv2', nn.Conv2d(64, 64, 3, stride=1, padding=1, 268 | bias=False)), 269 | ('bn2', nn.BatchNorm2d(64)), 270 | ('relu2', nn.ReLU(inplace=True)), 271 | ('conv3', nn.Conv2d(64, inplanes, 3, stride=1, padding=1, 272 | bias=False)), 273 | ('bn3', nn.BatchNorm2d(inplanes)), 274 | ('relu3', nn.ReLU(inplace=True)), 275 | ] 276 | else: 277 | layer0_modules = [ 278 | ('conv1', nn.Conv2d(3, inplanes, kernel_size=7, stride=2, 279 | padding=3, bias=False)), 280 | ('bn1', nn.BatchNorm2d(inplanes)), 281 | ('relu1', nn.ReLU(inplace=True)), 282 | ] 283 | # To preserve compatibility with Caffe weights `ceil_mode=True` 284 | # is used instead of `padding=1`. 285 | layer0_modules.append(('pool', nn.MaxPool2d(3, stride=2, 286 | ceil_mode=True))) 287 | self.layer0 = nn.Sequential(OrderedDict(layer0_modules)) 288 | self.layer1 = self._make_layer( 289 | block, 290 | planes=64, 291 | blocks=layers[0], 292 | groups=groups, 293 | reduction=reduction, 294 | downsample_kernel_size=1, 295 | downsample_padding=0 296 | ) 297 | self.layer2 = self._make_layer( 298 | block, 299 | planes=128, 300 | blocks=layers[1], 301 | stride=2, 302 | groups=groups, 303 | reduction=reduction, 304 | downsample_kernel_size=downsample_kernel_size, 305 | downsample_padding=downsample_padding 306 | ) 307 | self.layer3 = self._make_layer( 308 | block, 309 | planes=256, 310 | blocks=layers[2], 311 | stride=2, 312 | groups=groups, 313 | reduction=reduction, 314 | downsample_kernel_size=downsample_kernel_size, 315 | downsample_padding=downsample_padding 316 | ) 317 | self.layer4 = self._make_layer( 318 | block, 319 | planes=512, 320 | blocks=layers[3], 321 | stride=2, 322 | groups=groups, 323 | reduction=reduction, 324 | downsample_kernel_size=downsample_kernel_size, 325 | downsample_padding=downsample_padding 326 | ) 327 | self.avg_pool = nn.AvgPool2d(7, stride=1) 328 | self.dropout = nn.Dropout(dropout_p) if dropout_p is not None else None 329 | self.last_linear = nn.Linear(512 * block.expansion, num_classes) 330 | self.loss = loss 331 | 332 | def _make_layer(self, block, planes, blocks, groups, reduction, stride=1, 333 | downsample_kernel_size=1, downsample_padding=0): 334 | downsample = None 335 | if stride != 1 or self.inplanes != planes * block.expansion: 336 | downsample = nn.Sequential( 337 | nn.Conv2d(self.inplanes, planes * block.expansion, 338 | kernel_size=downsample_kernel_size, stride=stride, 339 | padding=downsample_padding, bias=False), 340 | nn.BatchNorm2d(planes * block.expansion), 341 | ) 342 | 343 | layers = [] 344 | layers.append(block(self.inplanes, planes, groups, reduction, stride, 345 | downsample)) 346 | self.inplanes = planes * block.expansion 347 | for i in range(1, blocks): 348 | layers.append(block(self.inplanes, planes, groups, reduction)) 349 | 350 | return nn.Sequential(*layers) 351 | 352 | def features(self, x): 353 | x = self.layer0(x) 354 | x = self.layer1(x) 355 | x = self.layer2(x) 356 | x = self.layer3(x) 357 | x = self.layer4(x) 358 | return x 359 | 360 | def logits(self, x): 361 | x = self.avg_pool(x) 362 | if self.dropout is not None: 363 | x = self.dropout(x) 364 | x = x.view(x.size(0), -1) 365 | x = self.last_linear(x) 366 | return x 367 | 368 | def forward(self, x): 369 | x = self.conv1(x) 370 | x 
= self.bn1(x) 371 | x = self.relu(x) 372 | x = self.maxpool(x) 373 | 374 | x = self.layer1(x) 375 | x = self.layer2(x) 376 | x = self.layer3(x) 377 | x = self.layer4(x) 378 | 379 | v = self.global_avgpool(x) 380 | v = v.view(v.size(0), -1) 381 | 382 | if self.fc is not None: 383 | v = self.fc(v) 384 | 385 | if not self.training: 386 | return v 387 | 388 | y = self.classifier(v) 389 | 390 | if self.loss == {'xent'}: 391 | return y 392 | elif self.loss == {'xent', 'htri'}: 393 | return y, v 394 | else: 395 | raise KeyError("Unsupported loss: {}".format(self.loss)) 396 | 397 | def init_pretrained_weights(model, model_url): 398 | """ 399 | Initialize model with pretrained weights. 400 | Layers that don't match with pretrained layers in name or size are kept unchanged. 401 | """ 402 | pretrain_dict = model_zoo.load_url(model_url) 403 | model_dict = model.state_dict() 404 | pretrain_dict = {k: v for k, v in pretrain_dict.items() if k in model_dict and model_dict[k].size() == v.size()} 405 | model_dict.update(pretrain_dict) 406 | model.load_state_dict(model_dict) 407 | print('Initialized model with pretrained weights from {}'.format(model_url)) 408 | 409 | 410 | def initialize_pretrained_model(model, num_classes, settings): 411 | assert num_classes == settings['num_classes'], \ 412 | 'num_classes should be {}, but is {}'.format( 413 | settings['num_classes'], num_classes) 414 | model.load_state_dict(model_zoo.load_url(settings['url'])) 415 | model.input_space = settings['input_space'] 416 | model.input_size = settings['input_size'] 417 | model.input_range = settings['input_range'] 418 | model.mean = settings['mean'] 419 | model.std = settings['std'] 420 | 421 | 422 | def senet154(num_classes, loss={'xent'}, pretrained=True, **kwargs): 423 | model = SENet(SEBottleneck, [3, 8, 36, 3], groups=64, reduction=16, 424 | dropout_p=0.2, num_classes=num_classes, loss=loss) 425 | if pretrained: 426 | init_pretrained_weights(model, model_urls['senet154']) 427 | return model 428 | 429 | 430 | def se_resnet50(num_classes, loss={'xent'}, pretrained=True, **kwargs): 431 | model = SENet(SEResNetBottleneck, [3, 4, 6, 3], groups=1, reduction=16, 432 | dropout_p=None, inplanes=64, input_3x3=False, 433 | downsample_kernel_size=1, downsample_padding=0, 434 | num_classes=num_classes, loss=loss) 435 | if pretrained: 436 | init_pretrained_weights(model, model_urls['se_resnet50']) 437 | return model 438 | 439 | 440 | def se_resnet101(num_classes, loss={'xent'}, pretrained=True, **kwargs): 441 | model = SENet(SEResNetBottleneck, [3, 4, 23, 3], groups=1, reduction=16, 442 | dropout_p=None, inplanes=64, input_3x3=False, 443 | downsample_kernel_size=1, downsample_padding=0, 444 | num_classes=num_classes, loss=loss) 445 | if pretrained: 446 | init_pretrained_weights(model, model_urls['se_resnet101']) 447 | return model 448 | 449 | 450 | def se_resnet152(num_classes, loss={'xent'}, pretrained=True, **kwargs): 451 | model = SENet(SEResNetBottleneck, [3, 8, 36, 3], groups=1, reduction=16, 452 | dropout_p=None, inplanes=64, input_3x3=False, 453 | downsample_kernel_size=1, downsample_padding=0, 454 | num_classes=num_classes, loss=loss) 455 | if pretrained: 456 | init_pretrained_weights(model, model_urls['se_resnet152']) 457 | return model 458 | 459 | 460 | def se_resnext50_32x4d(num_classes, loss={'xent'}, pretrained=True, **kwargs): 461 | model = SENet(SEResNeXtBottleneck, [3, 4, 6, 3], groups=32, reduction=16, 462 | dropout_p=None, inplanes=64, input_3x3=False, 463 | downsample_kernel_size=1, downsample_padding=0, 464 | 
num_classes=num_classes, loss=loss) 465 | if pretrained: 466 | init_pretrained_weights(model, model_urls['se_resnext50_32x4d']) 467 | return model 468 | 469 | 470 | def se_resnext101_32x4d(num_classes, loss={'xent'}, pretrained=True, **kwargs): 471 | model = SENet(SEResNeXtBottleneck, [3, 4, 23, 3], groups=32, reduction=16, 472 | dropout_p=None, inplanes=64, input_3x3=False, 473 | downsample_kernel_size=1, downsample_padding=0, 474 | num_classes=num_classes, loss=loss) 475 | # if pretrained is not None: 476 | # settings = pretrained_settings['se_resnext101_32x4d'][pretrained] 477 | # initialize_pretrained_model(model, num_classes, settings) 478 | if pretrained: 479 | init_pretrained_weights(model, model_urls['se_resnext101_32x4d']) 480 | return model -------------------------------------------------------------------------------- /pkl/duke/index.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adhirajghosh/RPTM_reid/183e1f77a0979ab2ffa08b0bdb1c43ef0f633ad5/pkl/duke/index.pkl -------------------------------------------------------------------------------- /pkl/vehicleid/index.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adhirajghosh/RPTM_reid/183e1f77a0979ab2ffa08b0bdb1c43ef0f633ad5/pkl/vehicleid/index.pkl -------------------------------------------------------------------------------- /pkl/veri/cids.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adhirajghosh/RPTM_reid/183e1f77a0979ab2ffa08b0bdb1c43ef0f633ad5/pkl/veri/cids.pkl -------------------------------------------------------------------------------- /pkl/veri/data.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adhirajghosh/RPTM_reid/183e1f77a0979ab2ffa08b0bdb1c43ef0f633ad5/pkl/veri/data.pkl -------------------------------------------------------------------------------- /pkl/veri/index_vp.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adhirajghosh/RPTM_reid/183e1f77a0979ab2ffa08b0bdb1c43ef0f633ad5/pkl/veri/index_vp.pkl -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | timm 2 | yacs 3 | opencv-python 4 | albumentations 5 | matplotlib 6 | umap-learn 7 | Pillow 8 | pretrainedmodels -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import division 3 | import os 4 | import os.path as osp 5 | import time 6 | import torch 7 | import numpy as np 8 | import numpy.ma as ma 9 | import random 10 | try: 11 | from apex.fp16_utils import * 12 | from apex import amp, optimizers 13 | except ImportError: # will be 3.x series 14 | print('This is not an error. 
If you want to use low precision, i.e., fp16, please install apex with CUDA support (https://github.com/NVIDIA/apex) and update PyTorch to 1.0')
15 | 
16 | from eval import do_test
17 | from utils.loggers import RankLogger
18 | from utils.torchtools import accuracy, save_checkpoint
19 | from utils.functions import search, strint
20 | from utils.avgmeter import AverageMeter
21 | from utils.visualtools import visualize_ranked_results
22 | 
23 | 
24 | def do_train(cfg, trainloader, train_dict, data_tfr, testloader_dict, dm,
25 |              model, optimizer, scheduler, criterion_htri, criterion_xent):
26 |     ranklogger = RankLogger(cfg.DATASET.SOURCE_NAME, cfg.DATASET.TARGET_NAME)
27 |     gms = train_dict['gms']
28 |     pidx = train_dict['pidx']
29 |     folders = []
30 |     for fld in os.listdir(cfg.DATASET.SPLIT_DIR):
31 |         folders.append(fld)
32 |     # data_index = search_index(gms, cfg.DATASET.SPLIT_DIR, folders)
33 |     data_index = search(cfg.DATASET.SPLIT_DIR)
34 | 
35 |     for epoch in range(cfg.SOLVER.MAX_EPOCHS):
36 |         losses = AverageMeter()
37 |         xent_losses = AverageMeter()
38 |         htri_losses = AverageMeter()
39 |         accs = AverageMeter()
40 |         batch_time = AverageMeter()
41 | 
42 |         model.train()
43 |         for p in model.parameters():
44 |             p.requires_grad = True  # open all layers
45 | 
46 |         end = time.time()
47 |         for batch_idx, (img, label, index, pid, _) in enumerate(trainloader):
48 | 
49 |             trainX, trainY = torch.zeros((cfg.SOLVER.TRAIN_BATCH_SIZE * 3, 3, cfg.INPUT.HEIGHT, cfg.INPUT.WIDTH), dtype=torch.float32), torch.zeros(
50 |                 (cfg.SOLVER.TRAIN_BATCH_SIZE * 3), dtype=torch.int64)
51 | 
52 |             for i in range(cfg.SOLVER.TRAIN_BATCH_SIZE):
53 | 
54 |                 labelx = str(label[i])
55 |                 # print(labelx)
56 |                 indexx = int(index[i])
57 |                 cidx = int(pid[i])
58 |                 if indexx > len(gms[labelx]) - 1:
59 |                     indexx = len(gms[labelx]) - 1
60 |                 a = gms[labelx][indexx]
61 | 
62 |                 if cfg.MODEL.RPTM_SELECT == 'min':
63 |                     threshold = np.arange(10)
64 |                 elif cfg.MODEL.RPTM_SELECT == 'mean':
65 |                     threshold = np.arange(np.amax(gms[labelx][indexx]) // 2)
66 |                 elif cfg.MODEL.RPTM_SELECT == 'max':
67 |                     threshold = np.arange(np.amax(gms[labelx][indexx]))
68 |                 else:
69 |                     threshold = np.arange(np.amax(gms[labelx][indexx]) // 2)  # defaults to mean
70 | 
71 |                 minpos = np.argmin(ma.masked_where(a == threshold, a))
72 |                 pos_dic = data_tfr[data_index[cidx][1] + minpos]
73 |                 # print(pos_dic[1])
74 |                 neg_label = int(labelx)
75 |                 while True:  # sample a random negative identity that has an existing split folder
76 |                     neg_label = random.choice(range(1, 770))
77 |                     if neg_label != int(labelx) and os.path.isdir(
78 |                             os.path.join(cfg.DATASET.SPLIT_DIR, strint(neg_label, 'veri'))):
79 |                         break
80 |                 negative_label = strint(neg_label, 'veri')
81 |                 neg_cid = pidx[negative_label]
82 |                 neg_index = random.choice(range(0, len(gms[negative_label])))
83 | 
84 |                 neg_dic = data_tfr[data_index[neg_cid][1] + neg_index]
85 |                 trainX[i] = img[i]
86 |                 trainX[i + cfg.SOLVER.TRAIN_BATCH_SIZE] = pos_dic[0]
87 |                 trainX[i + (cfg.SOLVER.TRAIN_BATCH_SIZE * 2)] = neg_dic[0]
88 |                 trainY[i] = cidx
89 |                 trainY[i + cfg.SOLVER.TRAIN_BATCH_SIZE] = pos_dic[3]
90 |                 trainY[i + (cfg.SOLVER.TRAIN_BATCH_SIZE * 2)] = neg_dic[3]
91 |             optimizer.zero_grad()
92 |             trainX = trainX.cuda()
93 |             trainY = trainY.cuda()
94 |             outputs, features = model(trainX)
95 |             xent_loss = criterion_xent(outputs[0:cfg.SOLVER.TRAIN_BATCH_SIZE], trainY[0:cfg.SOLVER.TRAIN_BATCH_SIZE])
96 |             htri_loss = criterion_htri(features, trainY)
97 | 
98 | 
99 |             loss = cfg.LOSS.LAMBDA_HTRI * htri_loss + cfg.LOSS.LAMBDA_XENT * xent_loss
100 | 
101 |             if cfg.SOLVER.USE_AMP:
102 |                 with amp.scale_loss(loss, optimizer) as scaled_loss:
103 |                     scaled_loss.backward()
104 |             else:
105 |                 loss.backward()
106 | 
107 |             optimizer.step()
108 |             for param_group in optimizer.param_groups:
109 |                 # print(param_group['lr'])
110 |                 lrrr = str(param_group['lr'])
111 | 
112 |             batch_time.update(time.time() - end)
113 |             losses.update(loss.item(), trainY.size(0))
114 |             htri_losses.update(htri_loss.item(), trainY.size(0))
115 |             accs.update(accuracy(outputs[0:cfg.SOLVER.TRAIN_BATCH_SIZE], trainY[0:cfg.SOLVER.TRAIN_BATCH_SIZE])[0])
116 | 
117 |             if batch_idx % cfg.MISC.PRINT_FREQ == 0:
118 |                 print('Train ', end=" ")
119 |                 print('Epoch: [{0}][{1}/{2}]\t'
120 |                       'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
121 |                       'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
122 |                       'Acc {acc.val:.2f} ({acc.avg:.2f})\t'
123 |                       'lr {lrrr} \t'.format(
124 |                     epoch + 1, batch_idx + 1, len(trainloader),
125 |                     batch_time=batch_time,
126 |                     loss=losses,
127 |                     acc=accs,
128 |                     lrrr=lrrr,
129 |                 ))
130 | 
131 |             end = time.time()
132 | 
133 |         scheduler.step()
134 |         print('=> Test')
135 | 
136 |         for name in cfg.DATASET.TARGET_NAME:
137 |             print('Evaluating {} ...'.format(name))
138 |             queryloader = testloader_dict[name]['query']
139 |             galleryloader = testloader_dict[name]['gallery']
140 |             rank1, distmat, rank2, distmat_re = do_test(model, queryloader, galleryloader, cfg.TEST.TEST_BATCH_SIZE, cfg.MISC.USE_GPU, name)
141 | 
142 |             ranklogger.write(name, epoch + 1, rank1)
143 |             ranklogger.write(name, epoch + 1, rank2)
144 | 
145 |             if (epoch + 1) == cfg.SOLVER.MAX_EPOCHS and cfg.TEST.VIS_RANK:
146 |                 visualize_ranked_results(
147 |                     distmat_re, dm.return_testdataset_by_name(name),
148 |                     save_dir=osp.join(cfg.MISC.SAVE_DIR, 'ranked_results', name),
149 |                     topk=20)
150 | 
151 |             del queryloader
152 |             del galleryloader
153 |             del distmat
154 |             # print(torch.cuda.memory_allocated(), torch.cuda.memory_cached())
155 |             torch.cuda.empty_cache()
156 | 
157 |         if (epoch + 1) == cfg.SOLVER.MAX_EPOCHS:
158 |             save_checkpoint({
159 |                 'state_dict': model.state_dict(),
160 |                 'rank1': rank2,
161 |                 'epoch': epoch + 1,
162 |                 'arch': cfg.MODEL.ARCH,
163 |                 'optimizer': optimizer.state_dict(),
164 |             }, cfg.MISC.SAVE_DIR, cfg.SOLVER.OPTIMIZER_NAME)
--------------------------------------------------------------------------------
/utils/avgmeter.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | 
4 | 
5 | class AverageMeter(object):
6 |     """Computes and stores the average and current value. 
7 | 8 | Code imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262 9 | """ 10 | 11 | def __init__(self): 12 | self.reset() 13 | 14 | def reset(self): 15 | self.val = 0 16 | self.avg = 0 17 | self.sum = 0 18 | self.count = 0 19 | 20 | def update(self, val, n=1): 21 | self.val = val 22 | self.sum += val * n 23 | self.count += n 24 | self.avg = self.sum / self.count 25 | -------------------------------------------------------------------------------- /utils/create_gms_index.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import re 4 | import xml.etree.ElementTree as ET 5 | from collections import defaultdict 6 | 7 | import cv2 8 | import numpy as np 9 | from tqdm import tqdm 10 | 11 | def str2int(car_id_num: str, dataset: str): 12 | if dataset == 'veri': 13 | if len(car_id_num) == 1: 14 | car_id_num = '00' + str(car_id_num) 15 | elif len(car_id_num) == 2: 16 | car_id_num = '0' + str(car_id_num) 17 | else: 18 | pass 19 | elif dataset == 'vehicleid' or dataset == 'veriwild': 20 | if len(car_id_num) == 1: 21 | car_id_num = '0000' + car_id_num 22 | elif len(car_id_num) == 2: 23 | car_id_num = '000' + car_id_num 24 | elif len(car_id_num) == 3: 25 | car_id_num = '00' + car_id_num 26 | elif len(car_id_num) == 4: 27 | car_id_num = '0' + car_id_num 28 | else: 29 | pass 30 | else: 31 | raise ValueError(f"Unknown dataset: {dataset}") 32 | return car_id_num 33 | 34 | def compute_gms_matches(orb: cv2.ORB, bf: cv2.BFMatcher, img1: np.ndarray, img2: np.ndarray, verbose: bool = False): 35 | # Detect and compute keypoints and descriptors 36 | kp1, des1 = orb.detectAndCompute(img1, None) 37 | kp2, des2 = orb.detectAndCompute(img2, None) 38 | 39 | # Check if descriptors were found 40 | if des1 is None or des2 is None or len(des1) == 0 or len(des2) == 0: 41 | if verbose: 42 | print(f"Warning: No descriptors found for one of the images. Returning 0 matches.") 43 | return 0 44 | 45 | if des1.shape[1] != des2.shape[1]: 46 | if verbose: 47 | print(f"Error: Descriptor sizes don't match. 
Cannot proceed with matching.")
48 |         return 0
49 | 
50 |     # Convert des1 and des2 to have the same type
51 |     # Fixes: cv2.error: OpenCV(4.10.0) /io/opencv/modules/core/src/batch_distance.cpp:274: error: (-215:Assertion failed) type == src2.type() && src1.cols == src2.cols && (type == CV_32F || type == CV_8U) in function 'batchDistance'
52 |     if des1.dtype != np.uint8:
53 |         if verbose:
54 |             print(f"Warning: Converting descriptors to np.uint8.")
55 |         des1 = des1.astype(np.uint8)
56 | 
57 |     if des2.dtype != np.uint8:
58 |         if verbose:
59 |             print(f"Warning: Converting descriptors to np.uint8.")
60 |         des2 = des2.astype(np.uint8)
61 | 
62 |     # Perform initial matching
63 |     matches = bf.match(des1, des2)
64 | 
65 |     # Apply GMS matching
66 |     gms_matches = cv2.xfeatures2d.matchGMS(size1=img1.shape[:2], size2=img2.shape[:2],
67 |                                            keypoints1=kp1, keypoints2=kp2,
68 |                                            matches1to2=matches, withRotation=True)
69 |     return len(gms_matches)
70 | 
71 | def process_class(image_paths: list, image_size: tuple = (224, 224), verbose: bool = False):
72 |     n = len(image_paths)
73 |     width, height = image_size
74 |     adj_matrix = np.zeros((n, n), dtype=np.int32)  # Initialize the adjacency matrix
75 | 
76 |     # Iterate over all the images
77 |     for i in range(n):
78 |         # Read and resize, as per paper
79 |         img1 = cv2.imread(image_paths[i], cv2.IMREAD_GRAYSCALE)
80 |         img1 = cv2.resize(img1, (width, height))
81 |         img1 = img1.astype(np.uint8)
82 | 
83 |         # Only iterate over j > i
84 |         for j in range(i + 1, n):
85 |             # Read and resize, as per paper
86 |             img2 = cv2.imread(image_paths[j], cv2.IMREAD_GRAYSCALE)
87 |             img2 = cv2.resize(img2, (width, height))
88 |             img2 = img2.astype(np.uint8)
89 | 
90 |             # Compute GMS matches
91 |             matches = compute_gms_matches(orb, bf, img1, img2, verbose)  # orb and bf are module-level globals (see the MAIN block below)
92 | 
93 |             # Set both (i,j) and (j,i) at once
94 |             adj_matrix[i, j] = matches
95 |             adj_matrix[j, i] = matches
96 | 
97 |         pbar.update(1)  # pbar is the module-level tqdm bar created in the MAIN block
98 | 
99 |     return adj_matrix
100 | 
101 | def get_dict(dataset: str, train_file: str, img_dir: str):
102 |     class_images = defaultdict(list)
103 |     original_to_new_id = {}
104 |     new_id_counter = 1
105 | 
106 |     # Read train split file names
107 |     if (dataset == 'veri'):
108 |         # Open the file with the correct encoding
109 |         with open(train_file, 'r', encoding='gb2312') as file:
110 |             xml_content = file.read()
111 | 
112 |         # Parse the XML string
113 |         root = ET.fromstring(xml_content)
114 | 
115 |         # Iterate through each Item element
116 |         for item in root.findall('.//Item'):
117 |             vehicle_id_str = item.get('vehicleID')
118 | 
119 |             car_id_num = str(int(re.search(r'\d+', vehicle_id_str).group()))
120 |             car_id_num = str2int(car_id_num, dataset)
121 | 
122 |             full_image_path = os.path.join(img_dir, item.get('imageName'))
123 | 
124 |             class_images[car_id_num].append(full_image_path)
125 |     elif (dataset == 'veriwild'):
126 |         with open(train_file, 'r') as file:
127 |             lines = [line.strip().split(' ') for line in file.readlines()]
128 | 
129 |         # Iterate through each line of the split file
130 |         for line in tqdm(lines, desc='Splitting train images'):
131 |             vehicle_id = line[0].split('/')[0]
132 |             image_name = line[0].split('/')[1]
133 |             full_image_path = os.path.join(img_dir, vehicle_id, image_name)
134 | 
135 |             # Here we map the original vehicle ID to a new ID
136 |             if vehicle_id not in original_to_new_id:
137 |                 new_id = str2int(str(new_id_counter), dataset)
138 | 
139 |                 original_to_new_id[vehicle_id] = new_id
140 |                 new_id_counter += 1
141 | 
142 |             new_vehicle_id = original_to_new_id[vehicle_id]
143 |             class_images[new_vehicle_id].append(full_image_path)
144 |     elif (dataset == 'vehicleid'):
145 |         with open(train_file, 'r') as file:
146 |             lines = [line.strip() for line in file.readlines()]
147 | 
148 |         # Iterate through each line of the split file
149 |         for line in tqdm(lines, desc='Splitting train images'):
150 |             image_name = line.split(' ')[0]
151 |             vehicle_id = line.split(' ')[1]
152 |             full_image_path = os.path.join(img_dir, image_name + '.jpg')
153 | 
154 |             # Here we map the original vehicle ID to a new ID
155 |             if vehicle_id not in original_to_new_id:
156 |                 new_id = str2int(str(new_id_counter), dataset)
157 | 
158 |                 original_to_new_id[vehicle_id] = new_id
159 |                 new_id_counter += 1
160 | 
161 |             new_vehicle_id = original_to_new_id[vehicle_id]
162 |             class_images[new_vehicle_id].append(full_image_path)
163 |     else:
164 |         raise ValueError(f"Unknown dataset: {dataset}")
165 | 
166 |     # Return both the dictionary and the mapping from original to new IDs
167 |     return class_images, original_to_new_id
168 | 
169 | # ========================== MAIN ========================== #
170 | # Set up paths
171 | dataset = 'veri'  # 'veri' (VeRi-776) / 'veriwild' / 'vehicleid'
172 | base_datapath = 'data'
173 | gms_path = 'gms'
174 | image_size = (224, 224)  # Before computing GMS matches, resize the images to this size (as per paper)
175 | verbose = False  # Set to True to see more detailed output, errors etc.
176 | 
177 | if (dataset == 'veri'):
178 |     data_path = os.path.join(base_datapath, 'veri')
179 |     img_dir = os.path.join(data_path, 'image_train')
180 |     train_file = os.path.join(data_path, 'train_label.xml')
181 | elif (dataset == 'veriwild'):
182 |     data_path = os.path.join(base_datapath, 'veriwild')
183 |     img_dir = os.path.join(data_path, 'images')
184 |     train_file = os.path.join(data_path, 'train_test_split', 'train_list_start0.txt')
185 | elif (dataset == 'vehicleid'):
186 |     data_path = os.path.join(base_datapath, 'vehicleid')
187 |     img_dir = os.path.join(data_path, 'image')
188 |     train_file = os.path.join(data_path, 'train_test_split', 'train_list.txt')
189 | else:
190 |     raise ValueError(f"Unknown dataset: {dataset}")
191 | 
192 | output = os.path.join(gms_path, dataset)
193 | if not os.path.exists(output):
194 |     os.makedirs(output)
195 |     if verbose:
196 |         print(f"Output directory created at: {output}")
197 | 
198 | # Instantiate the ORB and BFMatcher objects
199 | orb = cv2.ORB_create(nfeatures=10000, fastThreshold=0)
200 | bf = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=False)
201 | 
202 | # Get the dictionary of class images
203 | # It will contain keys as class labels and values as lists of image paths
204 | # Example:
205 | # {
206 | #     '1': ['/190472.jpg', '/134671.jpg', ...],
207 | #     '2': ['/134718.jpg', '/824511.jpg', ...],
208 | #     ...
209 | # } 210 | class_images, id_mapping = get_dict(dataset, train_file, img_dir) 211 | 212 | # # In case you want to filter the dictionary to start from a certain class (resuming from a checkpoint, basically) 213 | # resuming_class = 5344 214 | # class_images = {k: v for k, v in class_images.items() if int(k) >= resuming_class} 215 | 216 | # Create the index_vp.pkl file 217 | # dict_index should contain the name of the images as keys, and a tuple (class_label, counter) as values 218 | dict_index = {os.path.basename(image): (class_label, counter) 219 | for class_label, images in class_images.items() 220 | for counter, image in enumerate(images)} 221 | 222 | with open(os.path.join(output, f'index_vp_{dataset}.pkl'), 'wb') as f: 223 | pickle.dump(dict_index, f) 224 | if verbose: 225 | print("Successfully saved the Index Pickle file.") 226 | 227 | # Get how many iterations are needed (for tqdm) 228 | total_iterations = sum([len(images) for images in class_images.values()]) 229 | 230 | # Process each class 231 | with tqdm(total=total_iterations, desc="Processing pickle files") as pbar: 232 | for class_label, images in class_images.items(): 233 | print(f"Processing class {class_label} with {len(images)} images") 234 | adj_matrix = process_class(images, image_size=image_size, verbose=verbose) 235 | 236 | # Save the adjacency matrix 237 | with open(os.path.join(output, f'{class_label}.pkl'), 'wb') as f: 238 | pickle.dump(adj_matrix, f) 239 | 240 | print("Processing complete. Adjacency matrices saved.") 241 | # ========================================================== # -------------------------------------------------------------------------------- /utils/evaluation.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | import numpy as np 6 | 7 | 8 | def evaluate(distmat, q_pids, g_pids, q_camids, g_camids, max_rank): 9 | """Evaluation with veri metric 10 | Key: for each query identity, its gallery images from the same camera view are discarded. 11 | """ 12 | num_q, num_g = distmat.shape 13 | 14 | if num_g < max_rank: 15 | max_rank = num_g 16 | print('Note: number of gallery samples is quite small, got {}'.format(num_g)) 17 | 18 | indices = np.argsort(distmat, axis=1) 19 | matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32) 20 | 21 | # compute cmc curve for each query 22 | all_cmc = [] 23 | all_AP = [] 24 | num_valid_q = 0. # number of valid query 25 | 26 | for q_idx in range(num_q): 27 | # get query pid and camid 28 | q_pid = q_pids[q_idx] 29 | q_camid = q_camids[q_idx] 30 | 31 | # remove gallery samples that have the same pid and camid with query 32 | order = indices[q_idx] 33 | remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid) 34 | keep = np.invert(remove) 35 | 36 | # compute cmc curve 37 | raw_cmc = matches[q_idx][keep] # binary vector, positions with value 1 are correct matches 38 | if not np.any(raw_cmc): 39 | # this condition is true when query identity does not appear in gallery 40 | continue 41 | 42 | cmc = raw_cmc.cumsum() 43 | cmc[cmc > 1] = 1 44 | 45 | all_cmc.append(cmc[:max_rank]) 46 | num_valid_q += 1. 47 | 48 | # compute average precision 49 | # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision 50 | num_rel = raw_cmc.sum() 51 | tmp_cmc = raw_cmc.cumsum() 52 | tmp_cmc = [x / (i + 1.) 
for i, x in enumerate(tmp_cmc)]
53 |         tmp_cmc = np.asarray(tmp_cmc) * raw_cmc
54 |         AP = tmp_cmc.sum() / num_rel
55 |         all_AP.append(AP)
56 | 
57 |     #assert num_valid_q > 0, 'Error: no query identity appears in the gallery'
58 | 
59 |     all_cmc = np.asarray(all_cmc).astype(np.float32)
60 |     all_cmc = all_cmc.sum(0) / num_valid_q
61 |     #mAP = np.amax(all_AP)
62 |     mAP = np.mean(all_AP)
63 |     #mAP = all_AP
64 |     return all_cmc, mAP
65 | 
66 | def evaluate_vid(distmat, q_pids, g_pids, q_camids, g_camids, max_rank):
67 |     """Evaluation with the VehicleID metric.
68 |     Key: the gallery contains one image for each test vehicle; the remaining
69 |     test images are used as queries.
70 |     """
71 |     num_q, num_g = distmat.shape
72 | 
73 |     if num_g < max_rank:
74 |         max_rank = num_g
75 |         print('Note: number of gallery samples is quite small, got {}'.format(num_g))
76 | 
77 |     indices = np.argsort(distmat, axis=1)
78 |     matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)
79 | 
80 |     # compute cmc curve for each query
81 |     all_cmc = []
82 |     all_AP = []
83 |     num_valid_q = 0.  # number of valid queries
84 | 
85 |     for q_idx in range(num_q):
86 |         # get query pid and camid
87 |         # remove gallery samples that have the same pid and camid with query
88 |         '''
89 |         q_pid = q_pids[q_idx]
90 |         q_camid = q_camids[q_idx]
91 |         order = indices[q_idx]
92 |         remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid)  # original remove
93 |         '''
94 |         remove = False  # without camid information, remove no images from the gallery
95 |         keep = np.invert(remove)
96 |         # compute cmc curve
97 |         raw_cmc = matches[q_idx][keep]  # binary vector, positions with value 1 are correct matches
98 |         if not np.any(raw_cmc):
99 |             # this condition is true when query identity does not appear in gallery
100 |             continue
101 | 
102 |         cmc = raw_cmc.cumsum()
103 |         cmc[cmc > 1] = 1
104 | 
105 |         all_cmc.append(cmc[:max_rank])
106 |         num_valid_q += 1.
107 | 
108 |         # compute average precision
109 |         # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision
110 |         num_rel = raw_cmc.sum()
111 |         tmp_cmc = raw_cmc.cumsum()
112 |         tmp_cmc = [x / (i + 1.) for i, x in enumerate(tmp_cmc)]
113 |         tmp_cmc = np.asarray(tmp_cmc) * raw_cmc
114 |         AP = tmp_cmc.sum() / num_rel
115 |         all_AP.append(AP)
116 | 
117 |     assert num_valid_q > 0, 'Error: no query identity appears in the gallery'
118 | 
119 |     all_cmc = np.asarray(all_cmc).astype(np.float32)
120 |     all_cmc = all_cmc.sum(0) / num_valid_q
121 |     mAP = np.mean(all_AP)
122 | 
123 |     return all_cmc, mAP
--------------------------------------------------------------------------------
/utils/functions.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | 
4 | def keyfromval(dic, val):
5 |     return list(dic.keys())[list(dic.values()).index(val)]
6 | 
7 | def strint(x, dataset):  # zero-pad an integer ID to the dataset's folder-name width
8 |     if dataset == 'veri':
9 | 
10 |         if len(str(x)) == 1:
11 |             return '00' + str(x)
12 |         if len(str(x)) == 2:
13 |             return '0' + str(x)
14 |         if len(str(x)) == 3:
15 |             return str(x)
16 | 
17 |     if dataset == 'duke':
18 |         if len(str(x)) == 1:
19 |             return '000' + str(x)
20 |         if len(str(x)) == 2:
21 |             return '00' + str(x)
22 |         if len(str(x)) == 3:
23 |             return '0' + str(x)
24 |         if len(str(x)) == 4:
25 |             return str(x)
26 | 
27 |     if dataset == 'vehicleid':
28 |         if len(str(x)) == 1:
29 |             return '0000' + str(x)
30 |         if len(str(x)) == 2:
31 |             return '000' + str(x)
32 |         if len(str(x)) == 3:
33 |             return '00' + str(x)
34 |         if len(str(x)) == 4:
35 |             return '0' + str(x)
36 |         if len(str(x)) == 5:
37 |             return str(x)
38 | 
39 | def search1(pkl, path):
40 |     #MAIN ONE
41 |     start = 0
42 |     count = 0
43 |     end = 0
44 |     data_index = []
45 |     for i in range(1, 777):
46 |         label = strint(i, 'veri')
47 |         if os.path.isdir(os.path.join(path, label)):
48 |             size = len(pkl[label])
49 |             start = end
50 |             end = end + size
51 |             data_index.append((count, start, end - 1))
52 |             count += 1
53 |         if label == '769':
54 |             size = len(pkl[label])
55 |             start = end
56 |             end = end + size
57 |             data_index.append((count, start, end - 1))
58 |             break
59 |     return data_index
60 | 
61 | def search(path):
62 |     #MAIN ONE
63 |     start = 0
64 |     count = 0
65 |     end = 0
66 |     data_index = []
67 |     for i in sorted(os.listdir(path)):
68 |         x = len(os.listdir(os.path.join(path, i)))
69 |         data_index.append((count, start, start + x - 1))
70 |         count = count + 1
71 |         start = start + x
72 |     return data_index
73 | 
74 | def search_index(pkl, path, folders):
75 |     start = 0
76 |     count = 0
77 |     end = 0
78 |     data_index = []
79 |     for i in range(0, len(folders)):
80 |         label = folders[i]
81 |         size = len(pkl[label])
82 |         start = end
83 |         end = end + size
84 |         data_index.append((count, start, end - 1))
85 |         count += 1
86 |     return data_index
87 | 
88 | def create_split_dirs(cfg):
89 |     src_root = cfg.DATASET.TRAIN_DIR
90 |     dest_root = cfg.DATASET.SPLIT_DIR
91 |     if cfg.DATASET.SOURCE_NAME[0] == 'vehicleid':
92 |         if os.path.exists(os.path.join(cfg.DATASET.ROOT_DIR, 'vehicleid/images/')):
93 |             return
94 |     for i in os.listdir(src_root):
95 |         if cfg.DATASET.SOURCE_NAME[0] == 'veri':
96 |             folder_name = i.split('_', 2)[0][1:]
97 |         else:
98 |             folder_name = i.split('_', 2)[0]
99 |         if not os.path.exists(os.path.join(dest_root, folder_name)):
100 |             os.makedirs(os.path.join(dest_root, folder_name))
101 |         shutil.copyfile(os.path.join(src_root, i), os.path.join(dest_root, folder_name, i))
102 | 
--------------------------------------------------------------------------------
/utils/generaltools.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import print_function
3 | from __future__ import division
4 | 
5 | import random
6 | import numpy as np
7 | import 
torch 8 | 9 | 10 | def set_random_seed(seed): 11 | random.seed(seed) 12 | np.random.seed(seed) 13 | torch.manual_seed(seed) 14 | torch.cuda.manual_seed_all(seed) 15 | -------------------------------------------------------------------------------- /utils/iotools.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import os 4 | import os.path as osp 5 | import errno 6 | import json 7 | import warnings 8 | 9 | 10 | def mkdir_if_missing(directory): 11 | if not osp.exists(directory): 12 | try: 13 | os.makedirs(directory) 14 | except OSError as e: 15 | if e.errno != errno.EEXIST: 16 | raise 17 | 18 | 19 | def check_isfile(path): 20 | isfile = osp.isfile(path) 21 | if not isfile: 22 | warnings.warn('No file found at "{}"'.format(path)) 23 | return isfile 24 | 25 | 26 | def read_json(fpath): 27 | with open(fpath, 'r') as f: 28 | obj = json.load(f) 29 | return obj 30 | 31 | 32 | def write_json(obj, fpath): 33 | mkdir_if_missing(osp.dirname(fpath)) 34 | with open(fpath, 'w') as f: 35 | json.dump(obj, f, indent=4, separators=(',', ': ')) 36 | -------------------------------------------------------------------------------- /utils/kwargs.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from __future__ import division 3 | 4 | def return_kwargs(cfg): 5 | if cfg.DATASET.SOURCE_NAME[0] == 'vehicleid': 6 | dataset_kwargs = { 7 | 'source_names': cfg.DATASET.SOURCE_NAME, 8 | 'target_names': cfg.DATASET.TARGET_NAME, 9 | 'root': cfg.DATASET.ROOT_DIR, 10 | 'height': cfg.INPUT.HEIGHT, 11 | 'width': cfg.INPUT.WIDTH, 12 | 'test_size': cfg.TEST.TEST_SIZE, 13 | 'train_batch_size': cfg.SOLVER.TRAIN_BATCH_SIZE, 14 | 'test_batch_size': cfg.TEST.TEST_BATCH_SIZE, 15 | 'train_sampler': cfg.DATALOADER.SAMPLER, 16 | 'random_erase': cfg.INPUT.RANDOM_ERASE, 17 | 'color_jitter': cfg.INPUT.JITTER, 18 | 'color_aug': cfg.INPUT.AUG 19 | } 20 | else: 21 | dataset_kwargs = { 22 | 'source_names': cfg.DATASET.SOURCE_NAME, 23 | 'target_names': cfg.DATASET.TARGET_NAME, 24 | 'root': cfg.DATASET.ROOT_DIR, 25 | 'height': cfg.INPUT.HEIGHT, 26 | 'width': cfg.INPUT.WIDTH, 27 | 'test_size': cfg.TEST.TEST_SIZE, 28 | 'train_batch_size': cfg.SOLVER.TRAIN_BATCH_SIZE, 29 | 'test_batch_size': cfg.TEST.TEST_BATCH_SIZE, 30 | 'train_sampler': cfg.DATALOADER.SAMPLER, 31 | 'random_erase': cfg.INPUT.RANDOM_ERASE, 32 | 'color_jitter': cfg.INPUT.JITTER, 33 | 'color_aug': cfg.INPUT.AUG 34 | } 35 | 36 | transform_kwargs = { 37 | 'height': cfg.INPUT.HEIGHT, 38 | 'width': cfg.INPUT.WIDTH, 39 | 'random_erase': cfg.INPUT.RANDOM_ERASE, 40 | 'color_jitter': cfg.INPUT.JITTER, 41 | 'color_aug': cfg.INPUT.AUG 42 | } 43 | 44 | optimizer_kwargs = { 45 | 'optim': cfg.SOLVER.OPTIMIZER_NAME, 46 | 'lr': cfg.SOLVER.BASE_LR, 47 | 'weight_decay': cfg.SOLVER.WEIGHT_DECAY, 48 | 'momentum': cfg.SOLVER.MOMENTUM, 49 | 'sgd_dampening': cfg.SOLVER.SGD_DAMP, 50 | 'sgd_nesterov': cfg.SOLVER.NESTEROV 51 | } 52 | 53 | lr_scheduler_kwargs = { 54 | 'lr_scheduler': cfg.SOLVER.LR_SCHEDULER, 55 | 'stepsize': cfg.SOLVER.STEPSIZE, 56 | 'gamma': cfg.SOLVER.GAMMA 57 | } 58 | 59 | return dataset_kwargs, transform_kwargs, optimizer_kwargs, lr_scheduler_kwargs -------------------------------------------------------------------------------- /utils/loggers.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import sys 4 | import os 5 | import os.path as 
osp
6 | 
7 | from .iotools import mkdir_if_missing
8 | 
9 | 
10 | class Logger(object):
11 |     """
12 |     Write console output to external text file.
13 |     Code imported from https://github.com/Cysu/open-reid/blob/master/reid/utils/logging.py.
14 |     """
15 |     def __init__(self, fpath=None):
16 |         self.console = sys.stdout
17 |         self.file = None
18 |         if fpath is not None:
19 |             mkdir_if_missing(osp.dirname(fpath))
20 |             self.file = open(fpath, 'w')
21 | 
22 |     def __del__(self):
23 |         self.close()
24 | 
25 |     def __enter__(self):
26 |         return self
27 | 
28 |     def __exit__(self, *args):
29 |         self.close()
30 | 
31 |     def write(self, msg):
32 |         self.console.write(msg)
33 |         if self.file is not None:
34 |             self.file.write(msg)
35 | 
36 |     def flush(self):
37 |         self.console.flush()
38 |         if self.file is not None:
39 |             self.file.flush()
40 |             os.fsync(self.file.fileno())
41 | 
42 |     def close(self):
43 |         self.console.close()
44 |         if self.file is not None:
45 |             self.file.close()
46 | 
47 | 
48 | class RankLogger(object):
49 |     """
50 |     RankLogger records the rank1 matching accuracy obtained for each
51 |     test dataset at specified evaluation steps and provides a function
52 |     to show the summarized results, which are convenient for analysis.
53 |     Args:
54 |     - source_names (list): list of strings (names) of source datasets.
55 |     - target_names (list): list of strings (names) of target datasets.
56 |     """
57 |     def __init__(self, source_names, target_names):
58 |         self.source_names = source_names
59 |         self.target_names = target_names
60 |         self.logger = {name: {'epoch': [], 'rank1': []} for name in self.target_names}
61 | 
62 |     def write(self, name, epoch, rank1):
63 |         self.logger[name]['epoch'].append(epoch)
64 |         self.logger[name]['rank1'].append(rank1)
65 | 
66 |     def show_summary(self):
67 |         print('=> Show performance summary')
68 |         for name in self.target_names:
69 |             from_where = 'source' if name in self.source_names else 'target'
70 |             print('{} ({})'.format(name, from_where))
71 |             for epoch, rank1 in zip(self.logger[name]['epoch'], self.logger[name]['rank1']):
72 |                 print('- epoch {}\t rank1 {:.1%}'.format(epoch, rank1))
73 | 
--------------------------------------------------------------------------------
/utils/mean_and_std.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | 
4 | def get_mean_and_std(dataloader, dataset):
5 |     # Compute the mean and std of the dataset (assumes the dataloader yields one sample per batch).
6 |     mean = torch.zeros(3)
7 |     std = torch.zeros(3)
8 |     print('==> Computing mean and std..')
9 |     for inputs, _, _ in dataloader:
10 |         for i in range(3):
11 |             mean[i] += inputs[:,i,:,:].mean()
12 |             std[i] += inputs[:,i,:,:].std()
13 |     mean.div_(len(dataset))
14 |     std.div_(len(dataset))
15 |     return mean, std
16 | 
17 | 
18 | def calculate_mean_and_std(dataset_loader, dataset_size):
19 |     mean = torch.zeros(3)
20 |     std = torch.zeros(3)
21 |     for data in dataset_loader:
22 |         now_batch_size, c, h, w = data[0].shape
23 |         mean += torch.sum(torch.mean(torch.mean(data[0], dim=3), dim=2), dim=0)
24 |         std += torch.sum(torch.std(data[0].view(now_batch_size, c, h * w), dim=2), dim=0)
25 |     return mean/dataset_size, std/dataset_size
26 | 
--------------------------------------------------------------------------------
/utils/reranking.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Created on Fri, 25 May 2018 20:29:09
5 | 
6 | @author: luohao
7 | """
8 | 
9 | """
10 | CVPR2017 paper: Zhong Z, Zheng L, Cao D, et al. Re-ranking Person Re-identification with k-reciprocal Encoding[J]. 2017.
11 | url:http://openaccess.thecvf.com/content_cvpr_2017/papers/Zhong_Re-Ranking_Person_Re-Identification_CVPR_2017_paper.pdf
12 | Matlab version: https://github.com/zhunzhong07/person-re-ranking
13 | """
14 | 
15 | """
16 | API
17 | 
18 | probFea: all feature vectors of the query set (torch tensor)
19 | galFea: all feature vectors of the gallery set (torch tensor)
20 | k1, k2, lambda_value: parameters; the original paper uses (k1=20, k2=6, lambda=0.3)
21 | MemorySave: set to 'True' when using MemorySave mode
22 | Minibatch: available when 'MemorySave' is 'True'
23 | """
24 | 
25 | import numpy as np
26 | import torch
27 | from scipy.spatial.distance import cdist
28 | 
29 | def re_ranking(probFea, galFea, k1, k2, lambda_value, local_distmat=None, only_local=False):
30 |     # if the feature vectors are numpy arrays, convert them with torch.tensor first
31 |     query_num = probFea.size(0)
32 |     all_num = query_num + galFea.size(0)
33 |     if only_local:
34 |         original_dist = local_distmat
35 |     else:
36 |         feat = torch.cat([probFea, galFea])
37 |         # print('using GPU to compute original distance')
38 |         distmat = torch.pow(feat, 2).sum(dim=1, keepdim=True).expand(all_num, all_num) + \
39 |                   torch.pow(feat, 2).sum(dim=1, keepdim=True).expand(all_num, all_num).t()
40 |         distmat.addmm_(feat, feat.t(), beta=1, alpha=-2)  # distmat <- distmat - 2 * feat @ feat.t()
41 |         original_dist = distmat.cpu().numpy()
42 |         del feat
43 |     if local_distmat is not None:
44 |         original_dist = original_dist + local_distmat
45 |     gallery_num = original_dist.shape[0]
46 |     original_dist = np.transpose(original_dist / np.max(original_dist, axis=0))
47 |     V = np.zeros_like(original_dist).astype(np.float16)
48 |     initial_rank = np.argsort(original_dist).astype(np.int32)
49 | 
50 |     # print('starting re_ranking')
51 |     for i in range(all_num):
52 |         # k-reciprocal neighbors
53 |         forward_k_neigh_index = initial_rank[i, :k1 + 1]
54 |         backward_k_neigh_index = initial_rank[forward_k_neigh_index, :k1 + 1]
55 |         fi = np.where(backward_k_neigh_index == i)[0]
56 |         k_reciprocal_index = forward_k_neigh_index[fi]
57 |         k_reciprocal_expansion_index = k_reciprocal_index
58 |         for j in range(len(k_reciprocal_index)):
59 |             candidate = k_reciprocal_index[j]
60 |             candidate_forward_k_neigh_index = initial_rank[candidate, :int(np.around(k1 / 2)) + 1]
61 |             candidate_backward_k_neigh_index = initial_rank[candidate_forward_k_neigh_index,
62 |                                                :int(np.around(k1 / 2)) + 1]
63 |             fi_candidate = np.where(candidate_backward_k_neigh_index == candidate)[0]
64 |             candidate_k_reciprocal_index = candidate_forward_k_neigh_index[fi_candidate]
65 |             if len(np.intersect1d(candidate_k_reciprocal_index, k_reciprocal_index)) > 2 / 3 * len(
66 |                     candidate_k_reciprocal_index):
67 |                 k_reciprocal_expansion_index = np.append(k_reciprocal_expansion_index, candidate_k_reciprocal_index)
68 | 
69 |         k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index)
70 |         weight = np.exp(-original_dist[i, k_reciprocal_expansion_index])
71 |         V[i, k_reciprocal_expansion_index] = weight / np.sum(weight)
72 |     original_dist = original_dist[:query_num, ]
73 |     if k2 != 1:
74 |         V_qe = np.zeros_like(V, dtype=np.float16)
75 |         for i in range(all_num):
76 |             V_qe[i, :] = np.mean(V[initial_rank[i, :k2], :], axis=0)
77 |         V = V_qe
78 |         del V_qe
79 |     del initial_rank
80 |     invIndex = []
81 |     for i in range(gallery_num):
82 |         invIndex.append(np.where(V[:, i] != 0)[0])
83 | 
84 |     jaccard_dist = np.zeros_like(original_dist, dtype=np.float16)
85 | 
86 |     for i in range(query_num):
87 |         temp_min = np.zeros(shape=[1, 
gallery_num], dtype=np.float16) 88 | indNonZero = np.where(V[i, :] != 0)[0] 89 | indImages = [invIndex[ind] for ind in indNonZero] 90 | for j in range(len(indNonZero)): 91 | temp_min[0, indImages[j]] = temp_min[0, indImages[j]] + np.minimum(V[i, indNonZero[j]], 92 | V[indImages[j], indNonZero[j]]) 93 | jaccard_dist[i] = 1 - temp_min / (2 - temp_min) 94 | 95 | final_dist = jaccard_dist * (1 - lambda_value) + original_dist * lambda_value 96 | del original_dist 97 | del V 98 | del jaccard_dist 99 | final_dist = final_dist[:query_num, query_num:] 100 | return final_dist 101 | 102 | def re_ranking_numpy(probFea, galFea, k1, k2, lambda_value, local_distmat=None, only_local=False): 103 | query_num = probFea.shape[0] 104 | all_num = query_num + galFea.shape[0] 105 | if only_local: 106 | original_dist = local_distmat 107 | else: 108 | q_g_dist = cdist(probFea, galFea) 109 | q_q_dist = cdist(probFea, probFea) 110 | g_g_dist = cdist(galFea, galFea) 111 | original_dist = np.concatenate( 112 | [np.concatenate([q_q_dist, q_g_dist], axis=1), np.concatenate([q_g_dist.T, g_g_dist], axis=1)], 113 | axis=0) 114 | original_dist = np.power(original_dist, 2).astype(np.float32) 115 | original_dist = np.transpose(1. * original_dist / np.max(original_dist, axis=0)) 116 | if not local_distmat is None: 117 | original_dist = original_dist + local_distmat 118 | gallery_num = original_dist.shape[0] 119 | original_dist = np.transpose(original_dist / np.max(original_dist, axis=0)) 120 | V = np.zeros_like(original_dist).astype(np.float16) 121 | initial_rank = np.argsort(original_dist).astype(np.int32) 122 | 123 | print('starting re_ranking') 124 | for i in range(all_num): 125 | # k-reciprocal neighbors 126 | forward_k_neigh_index = initial_rank[i, :k1 + 1] 127 | backward_k_neigh_index = initial_rank[forward_k_neigh_index, :k1 + 1] 128 | fi = np.where(backward_k_neigh_index == i)[0] 129 | k_reciprocal_index = forward_k_neigh_index[fi] 130 | k_reciprocal_expansion_index = k_reciprocal_index 131 | for j in range(len(k_reciprocal_index)): 132 | candidate = k_reciprocal_index[j] 133 | candidate_forward_k_neigh_index = initial_rank[candidate, :int(np.around(k1 / 2)) + 1] 134 | candidate_backward_k_neigh_index = initial_rank[candidate_forward_k_neigh_index, 135 | :int(np.around(k1 / 2)) + 1] 136 | fi_candidate = np.where(candidate_backward_k_neigh_index == candidate)[0] 137 | candidate_k_reciprocal_index = candidate_forward_k_neigh_index[fi_candidate] 138 | if len(np.intersect1d(candidate_k_reciprocal_index, k_reciprocal_index)) > 2 / 3 * len( 139 | candidate_k_reciprocal_index): 140 | k_reciprocal_expansion_index = np.append(k_reciprocal_expansion_index, candidate_k_reciprocal_index) 141 | 142 | k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index) 143 | weight = np.exp(-original_dist[i, k_reciprocal_expansion_index]) 144 | V[i, k_reciprocal_expansion_index] = weight / np.sum(weight) 145 | original_dist = original_dist[:query_num, ] 146 | if k2 != 1: 147 | V_qe = np.zeros_like(V, dtype=np.float16) 148 | for i in range(all_num): 149 | V_qe[i, :] = np.mean(V[initial_rank[i, :k2], :], axis=0) 150 | V = V_qe 151 | del V_qe 152 | del initial_rank 153 | invIndex = [] 154 | for i in range(gallery_num): 155 | invIndex.append(np.where(V[:, i] != 0)[0]) 156 | 157 | jaccard_dist = np.zeros_like(original_dist, dtype=np.float16) 158 | 159 | for i in range(query_num): 160 | temp_min = np.zeros(shape=[1, gallery_num], dtype=np.float16) 161 | indNonZero = np.where(V[i, :] != 0)[0] 162 | indImages = [invIndex[ind] for ind in 
indNonZero] 163 | for j in range(len(indNonZero)): 164 | temp_min[0, indImages[j]] = temp_min[0, indImages[j]] + np.minimum(V[i, indNonZero[j]], 165 | V[indImages[j], indNonZero[j]]) 166 | jaccard_dist[i] = 1 - temp_min / (2 - temp_min) 167 | 168 | final_dist = jaccard_dist * (1 - lambda_value) + original_dist * lambda_value 169 | del original_dist 170 | del V 171 | del jaccard_dist 172 | final_dist = final_dist[:query_num, query_num:] 173 | return final_dist -------------------------------------------------------------------------------- /utils/torchtools.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | from __future__ import division 4 | 5 | from collections import OrderedDict 6 | import shutil 7 | import warnings 8 | import os.path as osp 9 | 10 | import torch 11 | import torch.nn as nn 12 | 13 | from .iotools import mkdir_if_missing 14 | 15 | 16 | def save_checkpoint(state, save_dir, opt, is_best=False, remove_module_from_keys=False): 17 | mkdir_if_missing(save_dir) 18 | if remove_module_from_keys: 19 | # remove 'module.' in state_dict's keys 20 | state_dict = state['state_dict'] 21 | new_state_dict = OrderedDict() 22 | for k, v in state_dict.items(): 23 | if k.startswith('module.'): 24 | k = k[7:] 25 | new_state_dict[k] = v 26 | state['state_dict'] = new_state_dict 27 | # save 28 | epoch = state['epoch'] 29 | 30 | arch = state['arch'] 31 | fpath = osp.join(save_dir, 'model_' + arch+'_'+opt+ '_'+str(epoch)+'.pth.tar') 32 | torch.save(state, fpath) 33 | print('Checkpoint saved to "{}"'.format(fpath)) 34 | if is_best: 35 | shutil.copy(fpath, osp.join(osp.dirname(fpath), 'best_model.pth.tar')) 36 | 37 | 38 | def resume_from_checkpoint(ckpt_path, model, optimizer=None): 39 | print('Loading checkpoint from "{}"'.format(ckpt_path)) 40 | ckpt = torch.load(ckpt_path) 41 | model.load_state_dict(ckpt['state_dict']) 42 | print('Loaded model weights') 43 | if optimizer is not None: 44 | optimizer.load_state_dict(ckpt['optimizer']) 45 | print('Loaded optimizer') 46 | start_epoch = ckpt['epoch'] 47 | print('** previous epoch = {}\t previous rank1 = {:.1%}'.format(start_epoch, ckpt['rank1'])) 48 | return start_epoch 49 | 50 | 51 | def adjust_learning_rate(optimizer, base_lr, epoch, stepsize=20, gamma=0.1, 52 | linear_decay=False, final_lr=0, max_epoch=100): 53 | if linear_decay: 54 | # linearly decay learning rate from base_lr to final_lr 55 | frac_done = epoch / max_epoch 56 | lr = frac_done * final_lr + (1. - frac_done) * base_lr 57 | else: 58 | # decay learning rate by gamma for every stepsize 59 | lr = base_lr * (gamma ** (epoch // stepsize)) 60 | 61 | for param_group in optimizer.param_groups: 62 | param_group['lr'] = lr 63 | 64 | 65 | def set_bn_to_eval(m): 66 | # 1. no update for running mean and var 67 | # 2. scale and shift parameters are still trainable 68 | classname = m.__class__.__name__ 69 | if classname.find('BatchNorm') != -1: 70 | m.eval() 71 | 72 | 73 | def open_all_layers(model): 74 | """ 75 | Open all layers in model for training. 76 | Args: 77 | - model (nn.Module): neural net model. 78 | """ 79 | model.train() 80 | for p in model.parameters(): 81 | p.requires_grad = True 82 | 83 | 84 | def open_specified_layers(model, open_layers): 85 | """ 86 | Open specified layers in model for training while keeping 87 | other layers frozen. 88 | Args: 89 | - model (nn.Module): neural net model. 90 | - open_layers (list): list of layer names. 
91 | """ 92 | if isinstance(model, nn.DataParallel): 93 | model = model.module 94 | 95 | for layer in open_layers: 96 | assert hasattr(model, layer), '"{}" is not an attribute of the model, please provide the correct name'.format( 97 | layer) 98 | 99 | for name, module in model.named_children(): 100 | if name in open_layers: 101 | module.train() 102 | for p in module.parameters(): 103 | p.requires_grad = True 104 | else: 105 | module.eval() 106 | for p in module.parameters(): 107 | p.requires_grad = False 108 | 109 | 110 | def count_num_param(model): 111 | num_param = sum(p.numel() for p in model.parameters()) / 1e+06 112 | 113 | if isinstance(model, nn.DataParallel): 114 | model = model.module 115 | 116 | if hasattr(model, 'classifier') and isinstance(model.classifier, nn.Module): 117 | # we ignore the classifier because it is unused at test time 118 | num_param -= sum(p.numel() for p in model.classifier.parameters()) / 1e+06 119 | return num_param 120 | 121 | 122 | def accuracy(output, target, topk=(1,)): 123 | """Computes the accuracy over the k top predictions for the specified values of k""" 124 | with torch.no_grad(): 125 | maxk = max(topk) 126 | batch_size = target.size(0) 127 | 128 | if isinstance(output, (tuple, list)): 129 | output = output[0] 130 | 131 | _, pred = output.topk(maxk, 1, True, True) 132 | pred = pred.t() 133 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 134 | 135 | res = [] 136 | for k in topk: 137 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 138 | acc = correct_k.mul_(100.0 / batch_size) 139 | res.append(acc.item()) 140 | return res 141 | 142 | 143 | def load_pretrained_weights(model, weight_path): 144 | """Load pretrianed weights to model 145 | Incompatible layers (unmatched in name or size) will be ignored 146 | Args: 147 | - model (nn.Module): network model, which must not be nn.DataParallel 148 | - weight_path (str): path to pretrained weights 149 | """ 150 | checkpoint = torch.load(weight_path) 151 | if 'state_dict' in checkpoint: 152 | state_dict = checkpoint['state_dict'] 153 | else: 154 | state_dict = checkpoint 155 | model_dict = model.state_dict() 156 | new_state_dict = OrderedDict() 157 | matched_layers, discarded_layers = [], [] 158 | for k, v in state_dict.items(): 159 | # If the pretrained state_dict was saved as nn.DataParallel, 160 | # keys would contain "module.", which should be ignored. 
161 | if k.startswith('module.'): 162 | k = k[7:] 163 | if k in model_dict and model_dict[k].size() == v.size(): 164 | new_state_dict[k] = v 165 | matched_layers.append(k) 166 | else: 167 | discarded_layers.append(k) 168 | model_dict.update(new_state_dict) 169 | model.load_state_dict(model_dict) 170 | if len(matched_layers) == 0: 171 | warnings.warn( 172 | 'The pretrained weights "{}" cannot be loaded, please check the key names manually (** ignored and continue **)'.format( 173 | weight_path)) 174 | else: 175 | print('Successfully loaded pretrained weights from "{}"'.format(weight_path)) 176 | if len(discarded_layers) > 0: 177 | print("** The following layers are discarded due to unmatched keys or layer size: {}".format( 178 | discarded_layers)) 179 | -------------------------------------------------------------------------------- /utils/visualtools.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import print_function 3 | 4 | import numpy as np 5 | import os.path as osp 6 | import shutil 7 | 8 | from .iotools import mkdir_if_missing 9 | 10 | 11 | def visualize_ranked_results(distmat, dataset, save_dir='log/ranked_results', topk=20): 12 | """ 13 | Visualize ranked results 14 | Args: 15 | - distmat: distance matrix of shape (num_query, num_gallery). 16 | - dataset: a 2-tuple containing (query, gallery), each contains a list of (img_path, pid, camid); 17 | for imgreid, img_path is a string, while for vidreid, img_path is a tuple containing 18 | a sequence of strings. 19 | - save_dir: directory to save output images. 20 | - topk: int, denoting top-k images in the rank list to be visualized. 21 | """ 22 | num_q, num_g = distmat.shape 23 | 24 | print('Visualizing top-{} ranks'.format(topk)) 25 | print('# query: {}\n# gallery {}'.format(num_q, num_g)) 26 | print('Saving images to "{}"'.format(save_dir)) 27 | 28 | query, gallery = dataset 29 | #assert num_q == len(query) 30 | #assert num_g == len(gallery) 31 | 32 | indices = np.argsort(distmat, axis=1) 33 | mkdir_if_missing(save_dir) 34 | 35 | def _cp_img_to(src, dst, rank, prefix): 36 | """ 37 | - src: image path or tuple (for vidreid) 38 | - dst: target directory 39 | - rank: int, denoting ranked position, starting from 1 40 | - prefix: string 41 | """ 42 | if isinstance(src, tuple) or isinstance(src, list): 43 | dst = osp.join(dst, prefix + '_top' + str(rank).zfill(3)) 44 | mkdir_if_missing(dst) 45 | for img_path in src: 46 | shutil.copy(img_path, dst) 47 | else: 48 | dst = osp.join(dst, prefix + '_top' + str(rank).zfill(3) + '_name_' + osp.basename(src)) 49 | shutil.copy(src, dst) 50 | 51 | for q_idx in range(num_q): 52 | qimg_path, qpid, qcamid = query[q_idx] 53 | if isinstance(qimg_path, tuple) or isinstance(qimg_path, list): 54 | qdir = osp.join(save_dir, osp.basename(qimg_path[0])) 55 | else: 56 | qdir = osp.join(save_dir, osp.basename(qimg_path)) 57 | mkdir_if_missing(qdir) 58 | _cp_img_to(qimg_path, qdir, rank=0, prefix='query') 59 | 60 | rank_idx = 1 61 | for g_idx in indices[q_idx, :]: 62 | gimg_path, gpid, gcamid = gallery[g_idx] 63 | invalid = (qpid == gpid) & (qcamid == gcamid) 64 | if not invalid: 65 | _cp_img_to(gimg_path, qdir, rank=rank_idx, prefix='gallery') 66 | rank_idx += 1 67 | if rank_idx > topk: 68 | break 69 | 70 | print("Done") 71 | --------------------------------------------------------------------------------
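--------------------------------------------------------------------------------
Usage note: a minimal sketch of calling re_ranking from utils/reranking.py with
the paper's default parameters (k1=20, k2=6, lambda=0.3). The feature dimensions
and random tensors below are illustrative assumptions only, not values fixed by
this repository; in the actual pipeline the features come from the trained model.

import torch
from utils.reranking import re_ranking

qf = torch.randn(8, 2048)    # hypothetical query features (num_query x feat_dim)
gf = torch.randn(100, 2048)  # hypothetical gallery features (num_gallery x feat_dim)

# Returns a (num_query, num_gallery) numpy array of re-ranked distances that
# blends the k-reciprocal Jaccard distance with the original distance.
distmat_re = re_ranking(qf, gf, k1=20, k2=6, lambda_value=0.3)
print(distmat_re.shape)  # (8, 100)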