├── LICENSE
├── README.md
├── config
│   ├── __init__.py
│   └── defaults.py
├── configs
│   ├── duke_r101.yml
│   ├── vehicleid_r101.yml
│   └── veri_r101.yml
├── datasets
│   ├── __init__.py
│   ├── base.py
│   ├── base_id.py
│   ├── data_loading.py
│   ├── duke.py
│   ├── init_dataset.py
│   ├── loader.py
│   ├── test_loading.py
│   ├── transform.py
│   ├── vehicleid.py
│   └── veri.py
├── eval.py
├── images
│   ├── affinity_matrix.png
│   └── architecture.png
├── loss
│   ├── __init__.py
│   ├── cross_entropy_loss.py
│   ├── hard_mine_triplet_loss.py
│   └── losses.py
├── main.py
├── model
│   ├── __init__.py
│   ├── lr_schedulers.py
│   ├── models.py
│   ├── optimizers.py
│   ├── resnet.py
│   └── senet.py
├── pkl
│   ├── duke
│   │   └── index.pkl
│   ├── vehicleid
│   │   └── index.pkl
│   └── veri
│       ├── cids.pkl
│       ├── data.pkl
│       └── index_vp.pkl
├── requirements.txt
├── train.py
└── utils
    ├── avgmeter.py
    ├── create_gms_index.py
    ├── evaluation.py
    ├── functions.py
    ├── generaltools.py
    ├── iotools.py
    ├── kwargs.py
    ├── loggers.py
    ├── mean_and_std.py
    ├── reranking.py
    ├── torchtools.py
    └── visualtools.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Adhiraj Ghosh
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # Relation Preserving Triplet Mining for Stabilising the Triplet Loss in Re-identification Systems
3 | ### WACV 2023
4 |
5 | Adhiraj Ghosh<sup>1,2</sup>, Kuruparan Shanmugalingam<sup>1,3</sup>, Wen-Yan Lin<sup>1</sup>
6 |
7 | <sup>1</sup>Singapore Management University, <sup>2</sup>University of Tübingen, <sup>3</sup>University of New South Wales
8 |
9 | [](https://paperswithcode.com/sota/vehicle-re-identification-on-veri-776?p=relation-preserving-triplet-mining-for)
10 | [](https://paperswithcode.com/sota/vehicle-re-identification-on-vehicleid-small?p=relation-preserving-triplet-mining-for)
11 |
12 |

13 |
14 | [[Paper](https://openaccess.thecvf.com/content/WACV2023/html/Ghosh_Relation_Preserving_Triplet_Mining_for_Stabilising_the_Triplet_Loss_In_WACV_2023_paper.html)]
15 | [[Video](https://youtu.be/TseV_Hoz2Ms?si=VlAReJ2eETPmYKh1)]
16 |
17 | The *official* repository for **Relation Preserving Triplet Mining for Stabilising the Triplet Loss in Re-identification Systems**. Our work achieves state-of-the-art results and provides a faster, better-optimised and more generalisable model for re-identification.
18 |
19 |
20 | ## Network Architecture
21 | 
22 |
23 | ## Preparation
24 |
25 | ### Installation
26 |
27 | 1. Install a CUDA-compatible PyTorch build. Modify the command below for your CUDA version.
28 | ```bash
29 | conda install pytorch torchvision torchaudio pytorch-cuda=11.6 -c pytorch -c nvidia
30 | ```
31 | 2. Install other dependencies.
32 | ```bash
33 | pip install -r requirements.txt
34 | ```
35 |
36 | 3. Install apex (optional but recommended)
37 |
38 | Follow the installation guidelines from https://github.com/NVIDIA/apex.
39 | Then set SOLVER.USE_AMP to True, either directly in the config files or via the command line.
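For example, AMP can be switched on at launch with a yacs-style override, in the same way as the other command-line options shown later in this README:

```bash
python main.py --config_file configs/veri_r101.yml SOLVER.USE_AMP True
```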
40 | ### Prepare Datasets
41 |
42 | ```bash
43 | mkdir data
44 | ```
45 |
46 | Download the vehicle reID datasets [VehicleID](https://www.pkuml.org/resources/pku-vehicleid.html) and [VeRi-776](https://github.com/JDAI-CV/VeRidataset), and the person reID dataset [DukeMTMC-reID](https://arxiv.org/abs/1609.01775).
47 | Follow the structure and naming convention shown below.
48 |
49 | ```
50 | data
51 | ├── duke
52 | │ └── images ..
53 | ├── vehicleid
54 | │ └── images ..
55 | └── veri
56 | └── images ..
57 | ```
58 |
59 | ### Prepare GMS Feature Matches
60 | ```bash
61 | mkdir gms
62 | ```
63 |
64 | You need to download the GMS feature matches for VeRi, VehicleID and DukeMTMC: [GMS](https://drive.google.com/drive/folders/1hdk3pi4Bi_Tb2B7XcBmvwG91Sfisi6BO?usp=share_link).
65 |
66 | The folder should follow the structure as shown below:
67 | ```
68 | gms
69 | ├── duke
70 | │ └── 0001.pkl ..
71 | ├── vehicleid
72 | │ └── 00001.pkl ..
73 | └── veri
74 | └── 001.pkl ..
75 | ```
76 |
77 | You can also create your own GMS matches for VeRi-776, VeRi-Wild and VehicleID by running the script ```utils/create_gms_index.py```. To choose which dataset to build GMS matches for, edit the initial parameters inside the script.
78 |
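Once downloaded, a quick sanity check of one identity's matches can look like this (a minimal sketch; `datasets/loader.py` simply unpickles each file into a dictionary keyed by the identity prefix of the filename):

```python
import os
import pickle

# Inspect the first GMS pickle for VeRi (path follows the structure above).
gms_path = './gms/veri/'
name = sorted(os.listdir(gms_path))[0]   # e.g. '001.pkl'
with open(os.path.join(gms_path, name), 'rb') as f:
    matches = pickle.load(f)
print(name[:3], type(matches))           # identity key and the stored object
```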
79 | ## Running RPTM
80 | 1. Training
81 | ```bash
82 | python main.py --config_file configs/veri_r101.yml
83 | ```
84 | The above command trains a baseline using our RPTM algorithm for VeRi. Note that after training, the model provides evaluation results, both qualitative and quantitative.
85 |
86 | 2. RPTM Thresholding Strategies
87 |
88 | In Section 4.2 of our paper, we defined a thresholding strategy for better anchor-positive selections. We expose this in the config files as MODEL.RPTM_SELECT. It defaults to 'mean'; feel free to experiment with 'min' and 'max', as shown below.
89 |
90 | #### Min Thresholding
91 | ```bash
92 | python main.py --config_file configs/veri_r101.yml MODEL.RPTM_SELECT 'min'
93 | ```
94 |
95 | #### Max Thresholding
96 | ```bash
97 | python main.py --config_file configs/veri_r101.yml MODEL.RPTM_SELECT 'max'
98 | ```
99 |
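To make the three modes concrete, here is a toy sketch of the thresholding idea (illustrative only; the actual selection operates on GMS match counts inside the training code):

```python
import numpy as np

def qualifying_positives(match_counts, select='mean'):
    """Toy RPTM_SELECT illustration: keep the anchor's same-identity images
    whose GMS match count reaches the chosen threshold."""
    counts = np.asarray(match_counts, dtype=float)
    if select == 'mean':
        threshold = counts.mean()
    elif select == 'min':
        threshold = counts.min()   # most permissive: every positive qualifies
    elif select == 'max':
        threshold = counts.max()   # most restrictive: only the best-matched image
    else:
        raise ValueError(select)
    return np.flatnonzero(counts >= threshold)

print(qualifying_positives([12, 40, 7, 25], select='mean'))  # -> [1 3]
```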
100 | 3. Testing
101 | ```bash
102 | mkdir logs
103 | python main.py --config_file configs/veri_r101.yml TEST.WEIGHT '' TEST.EVAL True
104 | ```
105 |
106 | ## Mean Average Precision (mAP) Results
107 | 1. VeRi776: **88.0%**
108 | 2. VehicleID (query size 800): **84.8%**
109 | 3. VehicleID (query size 1600): **81.2%**
110 | 4. VehicleID (query size 2400): **80.5%**
111 | 5. DukeMTMC: **89.2%**
112 |
113 | ## Acknowledgement
114 |
115 | GMS Feature Matching Algorithm taken from: https://github.com/JiawangBian/GMS-Feature-Matcher
116 |
117 | ## Citation
118 |
119 | If you find this code useful for your research, please cite our paper:
120 |
121 | ```
122 | @InProceedings{Ghosh_2023_WACV,
123 | author = {Ghosh, Adhiraj and Shanmugalingam, Kuruparan and Lin, Wen-Yan},
124 | title = {Relation Preserving Triplet Mining for Stabilising the Triplet Loss In re-Identification Systems},
125 | booktitle = {Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision (WACV)},
126 | month = {January},
127 | year = {2023},
128 | pages = {4840-4849}
129 | }
130 | ```
131 |
132 | ## Contact
133 |
134 | If you have any questions, please feel free to contact us. E-mail: [Adhiraj Ghosh](mailto:adhirajghosh1998@gmail.com), [Wen-Yan Lin](mailto:daniellin@smu.edu.sg)
135 |
--------------------------------------------------------------------------------
/config/__init__.py:
--------------------------------------------------------------------------------
1 | from .defaults import _C as cfg
2 | from .defaults import _C as cfg_test  # note: aliases the same CfgNode object as cfg
--------------------------------------------------------------------------------
/config/defaults.py:
--------------------------------------------------------------------------------
1 | from yacs.config import CfgNode as CN
2 |
3 | # -----------------------------------------------------------------------------
4 | # Config definition
5 | # -----------------------------------------------------------------------------
6 |
7 | _C = CN()
8 |
9 | # -----------------------------------------------------------------------------
10 | # MODEL
11 | # -----------------------------------------------------------------------------
12 | _C.MODEL = CN()
13 | _C.MODEL.DEVICE = "cuda"
14 | _C.MODEL.PRETRAIN_CHOICE = 'imagenet'
15 | _C.MODEL.PRETRAIN_PATH = ''
16 | _C.MODEL.ARCH = 'SE_net'
17 | _C.MODEL.DROPRATE = 0
18 | _C.MODEL.STRIDE = 1
19 | _C.MODEL.POOL = 'avg'
20 | _C.MODEL.GPU_ID = ('0')
21 | _C.MODEL.RPTM_SELECT = 'mean'
22 |
23 | # ---------------------------------------------------------------------------- #
24 | # Input options
25 | # ---------------------------------------------------------------------------- #
26 | _C.INPUT = CN()
27 | _C.INPUT.HEIGHT = 128
28 | _C.INPUT.WIDTH = 128
29 | _C.INPUT.PROB = 0.5
30 | _C.INPUT.RANDOM_ERASE = True
31 | _C.INPUT.JITTER = True
32 | _C.INPUT.AUG = True
33 |
34 | # ---------------------------------------------------------------------------- #
35 | # Dataset options
36 | # ---------------------------------------------------------------------------- #
37 |
38 | _C.DATASET = CN()
39 | _C.DATASET.SOURCE_NAME = ['veri']
40 | _C.DATASET.TARGET_NAME = ['veri']
41 | _C.DATASET.ROOT_DIR = ''
42 | _C.DATASET.TRAIN_DIR = ''
43 | _C.DATASET.SPLIT_DIR = ''
44 |
45 | # ---------------------------------------------------------------------------- #
46 | # Dataloader options
47 | # ---------------------------------------------------------------------------- #
48 | _C.DATALOADER = CN()
49 | _C.DATALOADER.SAMPLER = 'RandomSampler'
50 | _C.DATALOADER.NUM_INSTANCE = 6
51 | _C.DATALOADER.NUM_WORKERS = 16
52 |
53 | # ---------------------------------------------------------------------------- #
54 | # Solver options
55 | # ---------------------------------------------------------------------------- #
56 | _C.SOLVER = CN()
57 | _C.SOLVER.OPTIMIZER_NAME = 'SGD'
58 | _C.SOLVER.MAX_EPOCHS = 80
59 | _C.SOLVER.BASE_LR = 0.005
60 | _C.SOLVER.LR_SCHEDULER = 'multi-step'
61 | _C.SOLVER.STEPSIZE = [20, 40, 60]
62 | _C.SOLVER.GAMMA = 0.1
63 | _C.SOLVER.WEIGHT_DECAY = 5e-4
64 | _C.SOLVER.MOMENTUM = 0.9
65 | _C.SOLVER.SGD_DAMP = 0.0
66 | _C.SOLVER.NESTEROV = True
67 | _C.SOLVER.WARMUP_FACTOR = 0.01
68 | _C.SOLVER.WARMUP_EPOCHS = 10
69 | _C.SOLVER.WARMUP_METHOD = 'linear'
70 | _C.SOLVER.LARGE_FC_LR = False
71 | _C.SOLVER.TRAIN_BATCH_SIZE = 20
72 | _C.SOLVER.USE_AMP = True
73 | _C.SOLVER.CHECKPOINT_PERIOD = 10
74 | _C.SOLVER.LOG_PERIOD = 50
75 | _C.SOLVER.EVAL_PERIOD = 1
76 |
77 | # ---------------------------------------------------------------------------- #
78 | # Loss options
79 | # ---------------------------------------------------------------------------- #
80 | _C.LOSS = CN()
81 | _C.LOSS.MARGIN = 1.0
82 | _C.LOSS.LAMBDA_HTRI = 1.0
83 | _C.LOSS.LAMBDA_XENT = 1.0
84 |
85 | # ---------------------------------------------------------------------------- #
86 | # Test options
87 | # ---------------------------------------------------------------------------- #
88 | _C.TEST = CN()
89 | _C.TEST.EVAL = True
90 | _C.TEST.TEST_BATCH_SIZE = 100
91 | _C.TEST.TEST_SIZE = 11579
92 | _C.TEST.RE_RANKING = True
93 | _C.TEST.VIS_RANK = True
94 | _C.TEST.WEIGHT = ''
95 | _C.TEST.NECK_FEAT = 'after'
96 | _C.TEST.FEAT_NORM = 'yes'
97 |
98 | # ---------------------------------------------------------------------------- #
99 | # Misc options
100 | # ---------------------------------------------------------------------------- #
101 | _C.MISC = CN()
102 | _C.MISC.SAVE_DIR = './logs/veri/'
103 | _C.MISC.GMS_PATH = './gms/veri/'
104 | _C.MISC.INDEX_PATH = './pkl/veri/index_vp.pkl'
105 | _C.MISC.USE_GPU = True
106 | _C.MISC.PRINT_FREQ = 100
107 | _C.MISC.FP16 = True
108 |
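# Usage sketch (an assumption based on the standard yacs workflow; the actual
# entry point is main.py):
#
#     from config import cfg
#     cfg.merge_from_file('configs/veri_r101.yml')        # YAML overrides
#     cfg.merge_from_list(['MODEL.RPTM_SELECT', 'max'])   # CLI overrides
#     cfg.freeze()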
--------------------------------------------------------------------------------
/configs/duke_r101.yml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | PRETRAIN_CHOICE: 'imagenet'
3 | ARCH: 'resnet101_ibn_a'
4 | DROPRATE: 0
5 | STRIDE: 1
6 | POOL: 'avg'
7 | GPU_ID: ('0')
8 | RPTM_SELECT: 'mean'
9 |
10 | INPUT:
11 | HEIGHT: 300
12 | WIDTH: 150
13 | PROB: 0.5 # random horizontal flip
14 | RANDOM_ERASE: True
15 | JITTER: True
16 | AUG: True
17 |
18 |
19 | DATASET:
20 | SOURCE_NAME: ['duke']
21 | TARGET_NAME: ['duke']
22 | ROOT_DIR: './data/'
23 | TRAIN_DIR: './data/duke/image_train/'
24 | SPLIT_DIR: './data/duke/train_split/'
25 |
26 | DATALOADER:
27 | SAMPLER: 'RandomSampler'
28 | NUM_INSTANCE: 6
29 | NUM_WORKERS: 16
30 |
31 | SOLVER:
32 | OPTIMIZER_NAME: 'sgd'
33 | MAX_EPOCHS: 80
34 | BASE_LR: 0.005
35 | LR_SCHEDULER: 'multi-step'
36 | STEPSIZE: [20,40,60]
37 | GAMMA: 0.1
38 | WEIGHT_DECAY: 5e-4
39 | MOMENTUM: 0.9
40 | SGD_DAMP: 0.0
41 | NESTEROV: True
42 | WARMUP_FACTOR: 0.01
43 | WARMUP_EPOCHS: 10
44 | WARMUP_METHOD: 'linear'
45 | LARGE_FC_LR: False
46 | TRAIN_BATCH_SIZE: 20
47 | USE_AMP: False
48 | CHECKPOINT_PERIOD: 10
49 | LOG_PERIOD: 50
50 | EVAL_PERIOD: 1
51 |
52 | LOSS:
53 | MARGIN: 1.0
54 | LAMBDA_HTRI: 1.0
55 | LAMBDA_XENT: 1.0
56 |
57 | TEST:
58 | EVAL: False
59 | WEIGHT: ''
60 | TEST_BATCH_SIZE: 100
61 | RE_RANKING: True
62 | VIS_RANK: True
63 | NECK_FEAT: 'after'
64 | FEAT_NORM: 'yes'
65 |
66 | MISC:
67 | SAVE_DIR: './logs/duke/'
68 | GMS_PATH: './gms/duke/'
69 | INDEX_PATH: './pkl/duke/index.pkl'
70 | USE_GPU: True
71 | PRINT_FREQ: 100
72 | FP16: True
73 |
74 |
75 |
76 |
77 |
--------------------------------------------------------------------------------
/configs/vehicleid_r101.yml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | PRETRAIN_CHOICE: 'imagenet'
3 | ARCH: 'resnet101_ibn_a'
4 | DROPRATE: 0
5 | STRIDE: 1
6 | POOL: 'avg'
7 | GPU_ID: ('0')
8 | RPTM_SELECT: 'mean'
9 |
10 | INPUT:
11 | HEIGHT: 128
12 | WIDTH: 128
13 | PROB: 0.5 # random horizontal flip
14 | RANDOM_ERASE: True
15 | JITTER: True
16 | AUG: True
17 |
18 |
19 | DATASET:
20 | SOURCE_NAME: ['vehicleid']
21 | TARGET_NAME: ['vehicleid']
22 | ROOT_DIR: './data/'
23 | TRAIN_DIR: './data/vehicleid/image_train/'
24 | SPLIT_DIR: './data/vehicleid/train_split/'
25 |
26 | DATALOADER:
27 | SAMPLER: 'RandomSampler'
28 | NUM_INSTANCE: 6
29 | NUM_WORKERS: 16
30 |
31 | SOLVER:
32 | OPTIMIZER_NAME: 'sgd'
33 | MAX_EPOCHS: 40
34 | BASE_LR: 0.005
35 | LR_SCHEDULER: 'multi-step'
36 | STEPSIZE: [10, 20, 30]
37 | GAMMA: 0.1
38 | WEIGHT_DECAY: 5e-4
39 | MOMENTUM: 0.9
40 | SGD_DAMP: 0.0
41 | NESTEROV: True
42 | WARMUP_FACTOR: 0.01
43 | WARMUP_EPOCHS: 10
44 | WARMUP_METHOD: 'linear'
45 | LARGE_FC_LR: False
46 | TRAIN_BATCH_SIZE: 20
47 | USE_AMP: False
48 | CHECKPOINT_PERIOD: 10
49 | LOG_PERIOD: 50
50 | EVAL_PERIOD: 1
51 |
52 | LOSS:
53 | MARGIN: 1.0
54 | LAMBDA_HTRI: 1.0
55 | LAMBDA_XENT: 1.0
56 |
57 | TEST:
58 | EVAL: True
59 | WEIGHT: ''
60 | TEST_BATCH_SIZE: 100
61 | TEST_SIZE: 800
62 | RE_RANKING: True
63 | VIS_RANK: True
64 | NECK_FEAT: 'after'
65 | FEAT_NORM: 'yes'
66 |
67 | MISC:
68 | SAVE_DIR: './logs/vehicleid/'
69 | GMS_PATH: './gms/vehicleid/'
70 | INDEX_PATH: './pkl/vehicleid/index.pkl'
71 | USE_GPU: True
72 | PRINT_FREQ: 100
73 | FP16: True
74 |
75 |
76 |
77 |
78 |
--------------------------------------------------------------------------------
/configs/veri_r101.yml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | PRETRAIN_CHOICE: 'imagenet'
3 | ARCH: 'resnet101_ibn_a'
4 | DROPRATE: 0
5 | STRIDE: 1
6 | POOL: 'avg'
7 | GPU_ID: ('0')
8 | RPTM_SELECT: 'mean'
9 |
10 | INPUT:
11 | HEIGHT: 128
12 | WIDTH: 128
13 | PROB: 0.5 # random horizontal flip
14 | RANDOM_ERASE: True
15 | JITTER: True
16 | AUG: True
17 |
18 |
19 | DATASET:
20 | SOURCE_NAME: ['veri']
21 | TARGET_NAME: ['veri']
22 | ROOT_DIR: './data/'
23 | TRAIN_DIR: './data/veri/image_train/'
24 | SPLIT_DIR: './data/veri/train_split/'
25 |
26 | DATALOADER:
27 | SAMPLER: 'RandomSampler'
28 | NUM_INSTANCE: 6
29 | NUM_WORKERS: 16
30 |
31 | SOLVER:
32 | OPTIMIZER_NAME: 'sgd'
33 | MAX_EPOCHS: 80
34 | BASE_LR: 0.005
35 | LR_SCHEDULER: 'multi-step'
36 | STEPSIZE: [20,40,60]
37 | GAMMA: 0.1
38 | WEIGHT_DECAY: 5e-4
39 | MOMENTUM: 0.9
40 | SGD_DAMP: 0.0
41 | NESTEROV: True
42 | WARMUP_FACTOR: 0.01
43 | WARMUP_EPOCHS: 10
44 | WARMUP_METHOD: 'linear'
45 | LARGE_FC_LR: False
46 | TRAIN_BATCH_SIZE: 20
47 | USE_AMP: False
48 | CHECKPOINT_PERIOD: 10
49 | LOG_PERIOD: 50
50 | EVAL_PERIOD: 1
51 |
52 | LOSS:
53 | MARGIN: 1.0
54 | LAMBDA_HTRI: 1.0
55 | LAMBDA_XENT: 1.0
56 |
57 | TEST:
58 | EVAL: True
59 | WEIGHT: ''
60 | TEST_BATCH_SIZE: 100
61 | RE_RANKING: True
62 | VIS_RANK: True
63 | NECK_FEAT: 'after'
64 | FEAT_NORM: 'yes'
65 |
66 | MISC:
67 | SAVE_DIR: './logs/veri/'
68 | GMS_PATH: './gms/veri/'
69 | INDEX_PATH: './pkl/veri/index_vp.pkl'
70 | USE_GPU: True
71 | PRINT_FREQ: 100
72 | FP16: True
73 |
74 |
75 |
76 |
77 |
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 |
6 | from .loader import *
7 | # from .data_loading import *
8 | # from .test_loading import *
9 | # from .transform import *
10 |
11 |
--------------------------------------------------------------------------------
/datasets/base.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import print_function
3 |
4 | import os.path as osp
5 |
6 |
7 | class BaseDataset(object):
8 | """
9 | Base class of reid dataset
10 | """
11 |
12 | def __init__(self, root):
13 | self.root = osp.expanduser(root)
14 |
15 | def get_imagedata_info(self, data):
16 | pids, cams = [], []
17 | for _, pid, camid in data:
18 | pids += [pid]
19 | cams += [camid]
20 | pids = set(pids)
21 | cams = set(cams)
22 | num_pids = len(pids)
23 | num_cams = len(cams)
24 | num_imgs = len(data)
25 | return num_pids, num_imgs, num_cams
26 |
27 | def print_dataset_statistics(self):
28 | raise NotImplementedError
29 |
30 |
31 | class BaseImageDataset(BaseDataset):
32 | """
33 | Base class of image reid dataset
34 | """
35 |
36 | def print_dataset_statistics(self, train, query, gallery):
37 | num_train_pids, num_train_imgs, num_train_cams = self.get_imagedata_info(train)
38 | #num_val_pids, num_val_imgs, num_val_cams = self.get_imagedata_info(val)
39 | num_query_pids, num_query_imgs, num_query_cams = self.get_imagedata_info(query)
40 | num_gallery_pids, num_gallery_imgs, num_gallery_cams = self.get_imagedata_info(gallery)
41 |
42 | print('Image Dataset statistics:')
43 | print(' ----------------------------------------')
44 | print(' subset | # ids | # images | # cameras')
45 | print(' ----------------------------------------')
46 | print(' train | {:5d} | {:8d} | {:9d}'.format(num_train_pids, num_train_imgs, num_train_cams))
47 | #print(' val | {:5d} | {:8d} | {:9d}'.format(num_val_pids, num_val_imgs, num_val_cams))
48 | print(' query | {:5d} | {:8d} | {:9d}'.format(num_query_pids, num_query_imgs, num_query_cams))
49 | print(' gallery | {:5d} | {:8d} | {:9d}'.format(num_gallery_pids, num_gallery_imgs, num_gallery_cams))
50 | print(' ----------------------------------------')
51 |
--------------------------------------------------------------------------------
/datasets/base_id.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import print_function
3 |
4 | import os.path as osp
5 |
6 |
7 | class BaseDataset(object):
8 | """
9 | Base class of reid dataset
10 | """
11 |
12 | def __init__(self, root):
13 | self.root = osp.expanduser(root)
14 |
15 | def get_imagedata_info(self, data):
16 | pids = []
17 | for _, pid in data:
18 | pids += [pid]
19 |
20 | pids = set(pids)
21 | num_pids = len(pids)
22 | num_imgs = len(data)
23 | return num_pids, num_imgs
24 |
25 | def print_dataset_statistics(self):
26 | raise NotImplementedError
27 |
28 |
29 | class BaseImageDataset(BaseDataset):
30 | """
31 | Base class of image reid dataset
32 | """
33 |
34 | def print_dataset_statistics(self, train, query, gallery):
35 | num_train_pids, num_train_imgs = self.get_imagedata_info(train)
36 | #num_val_pids, num_val_imgs, num_val_cams = self.get_imagedata_info(val)
37 | num_query_pids, num_query_imgs = self.get_imagedata_info(query)
38 | num_gallery_pids, num_gallery_imgs = self.get_imagedata_info(gallery)
39 |
40 | print('Image Dataset statistics:')
41 | print(' ----------------------------')
42 | print(' subset | # ids | # images ')
43 | print(' ----------------------------')
44 | print(' train | {:5d} | {:8d} '.format(num_train_pids, num_train_imgs))
45 | #print(' val | {:5d} | {:8d} | {:9d}'.format(num_val_pids, num_val_imgs, num_val_cams))
46 | print(' query | {:5d} | {:8d} '.format(num_query_pids, num_query_imgs))
47 | print(' gallery | {:5d} | {:8d} '.format(num_gallery_pids, num_gallery_imgs))
48 | print(' ----------------------------')
49 |
--------------------------------------------------------------------------------
/datasets/data_loading.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import torch
3 | import os
4 | import sys
5 | import time
6 | import datetime
7 | import os.path as osp
8 | import numpy as np
9 | import warnings
10 | from PIL import Image
11 | from skimage import io, transform
12 | from torch.utils.data import Dataset
13 |
14 | class VeriDataset(Dataset):
15 | """Veri dataset."""
16 |
17 | def __init__(self, pkl_file, dataset, root_dir, transform=None):
18 |
19 | with open(pkl_file, 'rb') as handle:
20 | c = pickle.load(handle)
21 | self.index = c
22 | self.root_dir = root_dir
23 | self.dataset = dataset
24 | self.transform = transform
25 |
26 | def __len__(self):
27 | return len(self.dataset)
28 |
29 | def __getitem__(self, idx):
30 | if torch.is_tensor(idx):
31 | idx = idx.tolist()
32 |
33 | img_name = os.path.join(self.root_dir,
34 | self.dataset[idx][0])
35 | img = Image.open(os.path.join(self.root_dir, img_name[-24:])).convert('RGB')  # VeRi filenames are 24 chars, e.g. '0002_c002_00030600_0.jpg'
36 | label = self.dataset[idx][1]
37 | pid = self.dataset[idx][2]
38 | cid = self.dataset[idx][3]
39 | if self.dataset[idx][0] not in self.index:
40 | index = 0
41 | else:
42 | index = self.index[self.dataset[idx][0]][1]
43 |
44 | if self.transform:
45 | img = self.transform(img)
46 |
47 | return img,label,index,pid, cid
48 |
49 | class IdDataset(Dataset):
50 | """VehicleId dataset."""
51 |
52 | def __init__(self, pkl_file, dataset, root_dir, transform=None):
53 |
54 | with open(pkl_file, 'rb') as handle:
55 | c = pickle.load(handle)
56 | self.index = c
57 | self.root_dir = root_dir
58 | self.dataset = dataset
59 | self.transform = transform
60 |
61 | def __len__(self):
62 | return len(self.dataset)
63 |
64 | def __getitem__(self, idx):
65 | if torch.is_tensor(idx):
66 | idx = idx.tolist()
67 |
68 | img_name = os.path.join(self.root_dir,
69 | self.dataset[idx][0])
70 | img = Image.open(os.path.join(self.root_dir, img_name[-17:])).convert('RGB')
71 | label = self.dataset[idx][1]
72 | pid = self.dataset[idx][2]
73 | cid = self.dataset[idx][3]
74 | index = self.index[self.dataset[idx][0]][1]
75 |
76 |
77 | if self.transform:
78 | img = self.transform(img)
79 |
80 | return img,label,index,pid, cid
81 |
82 | class DukeDataset(Dataset):
83 | """Duke dataset."""
84 |
85 | def __init__(self, pkl_file, dataset, root_dir, transform=None):
86 |
87 | with open(pkl_file, 'rb') as handle:
88 | c = pickle.load(handle)
89 | self.index = c
90 | self.root_dir = root_dir
91 | self.dataset = dataset
92 | self.transform = transform
93 |
94 | def __len__(self):
95 | return len(self.dataset)
96 |
97 | def __getitem__(self, idx):
98 | if torch.is_tensor(idx):
99 | idx = idx.tolist()
100 |
101 | img_name = os.path.join(self.root_dir,
102 | self.dataset[idx][0])
103 | img = Image.open(os.path.join(self.root_dir, img_name[-20:])).convert('RGB')  # DukeMTMC filenames are 20 chars, e.g. '0001_c2_f0046182.jpg'
104 | label = self.dataset[idx][1]
105 | pid = self.dataset[idx][2]
106 | cid = self.dataset[idx][3]
107 | index = self.index[self.dataset[idx][0]][1]
108 |
109 |
110 | if self.transform:
111 | img = self.transform(img)
112 |
113 | return img,label,index,pid, cid
114 |
115 |
116 | def read_image(img_path):
117 | """Keep reading image until succeed.
118 | This can avoid IOError incurred by heavy IO process."""
119 | got_img = False
120 | if not osp.exists(img_path):
121 | raise IOError('{} does not exist'.format(img_path))
122 | while not got_img:
123 | try:
124 | img = Image.open(img_path).convert('RGB')
125 | got_img = True
126 | except IOError:
127 | print('IOError incurred when reading "{}". Will redo. Don\'t worry. Just chill.'.format(img_path))
128 | pass
129 | return img
130 |
131 |
132 | class ImageDataset(Dataset):
133 | """Image Person ReID Dataset"""
134 |
135 | def __init__(self, dataset, transform=None):
136 | self.dataset = dataset
137 | self.transform = transform
138 |
139 | def __len__(self):
140 | return len(self.dataset)
141 |
142 | def __getitem__(self, index):
143 | img_path, pid, camid = self.dataset[index]
144 | img = read_image(img_path)
145 |
146 | if self.transform is not None:
147 | img = self.transform(img)
148 |
149 | return img, pid, camid, img_path
150 |
151 | class IdImageDataset(Dataset):
152 | """Image Person ReID Dataset"""
153 |
154 | def __init__(self, dataset, transform=None):
155 | self.dataset = dataset
156 | self.transform = transform
157 |
158 | def __len__(self):
159 | return len(self.dataset)
160 |
161 | def __getitem__(self, index):
162 | img_path, pid = self.dataset[index]
163 | img = read_image(img_path)
164 |
165 | if self.transform is not None:
166 | img = self.transform(img)
167 |
168 | return img, pid, 0, img_path  # no camid in VehicleID; 0 is a placeholder (the original returned an undefined name)
169 |
--------------------------------------------------------------------------------
/datasets/duke.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import glob
6 | import re
7 | import os.path as osp
8 |
9 | from .base import BaseImageDataset
10 |
11 |
12 | class duke(BaseImageDataset):
13 |
14 | dataset_dir = 'duke'
15 |
16 | def __init__(self, root='datasets', dataset_dir = 'duke', verbose=True, **kwargs):
17 | super(duke, self).__init__(root)
18 | self.dataset_dir = osp.join(self.root, self.dataset_dir)
19 | self.train_dir = osp.join(self.dataset_dir, 'image_train')
20 | #self.val_dir = osp.join(self.dataset_dir, 'image_val')
21 | self.query_dir = osp.join(self.dataset_dir, 'image_query')
22 | self.gallery_dir = osp.join(self.dataset_dir, 'image_test')
23 |
24 | self.check_before_run()
25 |
26 | train = self.process_dir(self.train_dir, relabel=True)
27 | #val = self.process_dir(self.val_dir, relabel=True)
28 | query = self.process_dir(self.query_dir, relabel=False)
29 | gallery = self.process_dir(self.gallery_dir, relabel=False)
30 |
31 | if verbose:
32 | print('=> Duke loaded')
33 | self.print_dataset_statistics(train, query, gallery)
34 |
35 | self.train = train
36 | #self.val = val
37 | self.query = query
38 | self.gallery = gallery
39 |
40 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train)
41 | #self.num_val_pids, self.num_val_imgs, self.num_val_cams = self.get_imagedata_info(self.val)
42 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query)
43 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery)
44 |
45 | def check_before_run(self):
46 | """Check if all files are available before going deeper"""
47 | if not osp.exists(self.dataset_dir):
48 | raise RuntimeError('"{}" is not available'.format(self.dataset_dir))
49 | if not osp.exists(self.train_dir):
50 | raise RuntimeError('"{}" is not available'.format(self.train_dir))
51 | if not osp.exists(self.query_dir):
52 | raise RuntimeError('"{}" is not available'.format(self.query_dir))
53 | if not osp.exists(self.gallery_dir):
54 | raise RuntimeError('"{}" is not available'.format(self.gallery_dir))
55 |
56 | def process_dir(self, dir_path, relabel=False):
57 | img_paths = sorted(glob.glob(osp.join(dir_path, '*.jpg')))
58 | pattern = re.compile(r'([-\d]+)_c([-\d]+)')
59 |
60 | pid_container = set()
61 | for img_path in img_paths:
62 | pid, _ = map(int, pattern.search(img_path).groups())
63 | if pid == -1:
64 | continue # junk images are just ignored
65 | pid_container.add(pid)
66 | pid2label = {pid: label for label, pid in enumerate(pid_container)}
67 |
68 | dataset = []
69 | for img_path in img_paths:
70 | pid, camid = map(int, pattern.search(img_path).groups())
71 | if pid == -1:
72 | continue # junk images are just ignored
73 | assert 0 <= pid <= 7140 # pid == 0 means background
74 | assert 1 <= camid <= 20
75 | camid -= 1 # index starts from 0
76 | if relabel:
77 | pid = pid2label[pid]
78 | dataset.append((img_path, pid, camid))
79 |
80 | return dataset
81 |
--------------------------------------------------------------------------------
/datasets/init_dataset.py:
--------------------------------------------------------------------------------
1 | from .veri import VeRi
2 | from .vehicleid import VehicleID
3 | from .duke import duke
4 |
5 |
6 | __imgreid_factory = {
7 | 'veri': VeRi,
8 | 'vehicleid': VehicleID,  # lowercase key matches DATASET.SOURCE_NAME in the configs
9 | 'duke': duke,
10 | }
11 | def init_imgreid_dataset(name, **kwargs):
12 | if name not in list(__imgreid_factory.keys()):
13 | raise KeyError('Invalid dataset, got "{}", but expected to be one of {}'.format(name, list(__imgreid_factory.keys())))
14 | return __imgreid_factory[name](**kwargs)
15 |
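# Usage sketch (names taken from the factory above):
#
#     dataset = init_imgreid_dataset(name='veri', root='./data')
#     print(dataset.num_train_pids, dataset.num_query_pids)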
--------------------------------------------------------------------------------
/datasets/loader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | from torch.utils.data import DataLoader
4 | from .init_dataset import init_imgreid_dataset
5 | from .transform import *
6 | from .data_loading import VeriDataset as vd
7 | from .data_loading import IdDataset as id
8 | from .data_loading import DukeDataset as dd
9 | from .test_loading import ImageDataManager
10 |
11 | def data_loader(cfg, dataset_kwargs, transform_kwargs):
12 | dataset = init_imgreid_dataset(root=cfg.DATASET.ROOT_DIR, name=cfg.DATASET.SOURCE_NAME[0])
13 | num_train_pids = 0
14 | num_train_cams = 0
15 | train = []
16 |
17 | for img_path, pid, camid in dataset.train:
18 | # path = img_path[-24:]
19 | path = img_path.split('/', 4)[-1]
20 | if cfg.DATASET.SOURCE_NAME[0] == 'veri':
21 | folder = path.split('_', 1)[0][1:]
22 | else:
23 | folder = path.split('_', 1)[0]
24 | pid += num_train_pids
25 | camid += num_train_cams
26 | train.append((path, folder, pid, camid))
27 |
28 | num_train_pids += dataset.num_train_pids
29 | class_names = num_train_pids
30 | num_train_cams += dataset.num_train_cams
31 |
32 | pid = 0
33 | pidx = {}
34 | for img_path, pid, camid in dataset.train:
35 | path = img_path.split('/', 4)[-1]
36 | if cfg.DATASET.SOURCE_NAME[0] == 'veri':
37 | folder = path.split('_', 1)[0][1:]
38 | else:
39 | folder = path.split('_', 1)[0]
40 | pidx[folder] = pid
41 | pid += 1
42 |
43 | gms = {}
44 | entries = sorted(os.listdir(cfg.MISC.GMS_PATH))
45 | # print(entries)
46 | for name in entries:
47 | f = open(cfg.MISC.GMS_PATH + name, 'rb')
48 | if name == 'featureMatrix.pkl':
49 | s = name[0:13]
50 | else:
51 | s = name[0:3]
52 | gms[s] = pickle.load(f)
53 | f.close()
54 |
55 | transform_t = train_transforms(**transform_kwargs)
56 | if cfg.DATASET.SOURCE_NAME[0] == 'veri':
57 | data_tfr = vd(pkl_file=cfg.MISC.INDEX_PATH, dataset=train, root_dir=cfg.DATASET.TRAIN_DIR, transform=transform_t)
58 | elif cfg.DATASET.SOURCE_NAME[0] == 'vehicleid':
59 | data_tfr = id(pkl_file=cfg.MISC.INDEX_PATH, dataset=train, root_dir=cfg.DATASET.TRAIN_DIR, transform=transform_t)
60 | elif cfg.DATASET.SOURCE_NAME[0] == 'duke':
61 | data_tfr = dd(pkl_file=cfg.MISC.INDEX_PATH, dataset=train, root_dir=cfg.DATASET.TRAIN_DIR, transform=transform_t)
62 | trainloader = DataLoader(data_tfr, sampler=None, batch_size=cfg.SOLVER.TRAIN_BATCH_SIZE, shuffle=True, num_workers=cfg.DATALOADER.NUM_WORKERS,
63 | pin_memory=False, drop_last=True)
64 |
65 | print('Initializing test data manager')
66 | dm = ImageDataManager(cfg.MISC.USE_GPU, **dataset_kwargs)
67 | testloader_dict = dm.return_dataloaders()
68 | train_dict = {}
69 | train_dict['class_names'] = class_names
70 | train_dict['num_train_pids'] = num_train_pids
71 | train_dict['gms'] = gms
72 | train_dict['pidx'] = pidx
73 |
74 |
75 | return trainloader, train_dict, data_tfr, testloader_dict, dm
76 |
--------------------------------------------------------------------------------
/datasets/test_loading.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import print_function
3 |
4 | from torch.utils.data import DataLoader
5 | from torchvision import transforms, utils
6 | from torchvision.transforms import *
7 | from .data_loading import ImageDataset
8 | from .init_dataset import init_imgreid_dataset
9 | from .transform import test_transform
10 |
11 |
12 | class BaseDataManager(object):
13 |
14 | def __init__(self,
15 | use_gpu,
16 | source_names,
17 | target_names,
18 | root='datasets',
19 | height=128,
20 | width=256,
21 | train_batch_size=32,
22 | test_batch_size=100,
23 | workers=4,
24 | train_sampler='',
25 | val_sampler='',
26 | random_erase=False, # use random erasing for data augmentation
27 | color_jitter=False, # randomly change the brightness, contrast and saturation
28 | color_aug=False, # randomly alter the intensities of RGB channels
29 | num_instances=4, # number of instances per identity (for RandomIdentitySampler)
30 | **kwargs
31 | ):
32 | self.use_gpu = use_gpu
33 | self.source_names = source_names
34 | self.target_names = target_names
35 | self.root = root
36 | self.height = height
37 | self.width = width
38 | self.train_batch_size = train_batch_size
39 | self.test_batch_size = test_batch_size
40 | self.workers = workers
41 | self.train_sampler = train_sampler
42 | self.val_sampler = val_sampler
43 | self.random_erase = random_erase
44 | self.color_jitter = color_jitter
45 | self.color_aug = color_aug
46 | self.num_instances = num_instances
47 |
48 | transform_test = test_transform(self.height, self.width)
49 | self.transform_test = transform_test
50 |
51 |
52 | def return_dataloaders(self):
53 | """
54 | Return testloader dictionary
55 | """
56 | return self.testloader_dict
57 |
58 | def return_testdataset_by_name(self, name):
59 | """
60 | Return query and gallery, each containing a list of (img_path, pid, camid).
61 | """
62 | return self.testdataset_dict[name]['query'], self.testdataset_dict[name]['gallery']
63 |
64 |
65 | class ImageDataManager(BaseDataManager):
66 | """
67 | Vehicle-ReID data manager
68 | """
69 | def __init__(self,
70 | use_gpu,
71 | source_names,
72 | target_names,
73 | **kwargs
74 | ):
75 | super(ImageDataManager, self).__init__(use_gpu, source_names, target_names, **kwargs)
76 |
77 | print('=> Initializing TEST (target) datasets')
78 | self.testloader_dict = {name: {'query': None, 'gallery': None} for name in target_names}
79 | self.testdataset_dict = {name: {'query': None, 'gallery': None} for name in target_names}
80 |
81 | for name in self.target_names:
82 | dataset = init_imgreid_dataset(
83 | root=self.root, name=name)
84 |
85 | self.testloader_dict[name]['query'] = DataLoader(
86 | ImageDataset(dataset.query, transform=self.transform_test),
87 | batch_size=self.test_batch_size, shuffle=False, num_workers=self.workers,
88 | pin_memory=self.use_gpu, drop_last=False
89 | )
90 |
91 | self.testloader_dict[name]['gallery'] = DataLoader(
92 | ImageDataset(dataset.gallery, transform=self.transform_test),
93 | batch_size=self.test_batch_size, shuffle=False, num_workers=self.workers,
94 | pin_memory=self.use_gpu, drop_last=False
95 | )
96 |
97 | self.testdataset_dict[name]['query'] = dataset.query
98 | self.testdataset_dict[name]['gallery'] = dataset.gallery
99 |
--------------------------------------------------------------------------------
/datasets/transform.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | from PIL import Image
6 | import random
7 | import math
8 |
9 | import torch
10 | from torchvision.transforms import *
11 |
12 | class Random2DTranslation(object):
13 | """
14 | With a probability, first increase image size to (1 + 1/8), and then perform random crop.
15 | Args:
16 | - height (int): target image height.
17 | - width (int): target image width.
18 | - p (float): probability of performing this transformation. Default: 0.5.
19 | """
20 |
21 | def __init__(self, height, width, p=0.5, interpolation=Image.BILINEAR):
22 | self.height = height
23 | self.width = width
24 | self.p = p
25 | self.interpolation = interpolation
26 |
27 | def __call__(self, img):
28 | """
29 | Args:
30 | - img (PIL Image): Image to be cropped.
31 | """
32 | if random.uniform(0, 1) > self.p:
33 | return img.resize((self.width, self.height), self.interpolation)
34 |
35 | new_width, new_height = int(round(self.width * 1.125)), int(round(self.height * 1.125))
36 | resized_img = img.resize((new_width, new_height), self.interpolation)
37 | x_maxrange = new_width - self.width
38 | y_maxrange = new_height - self.height
39 | x1 = int(round(random.uniform(0, x_maxrange)))
40 | y1 = int(round(random.uniform(0, y_maxrange)))
41 | cropped_img = resized_img.crop((x1, y1, x1 + self.width, y1 + self.height))
42 | return cropped_img
43 |
44 |
45 | class RandomErasing(object):
46 |
47 | def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=[0.4914, 0.4822, 0.4465]):
48 | self.probability = probability
49 | self.mean = mean
50 | self.sl = sl
51 | self.sh = sh
52 | self.r1 = r1
53 |
54 | def __call__(self, img):
55 |
56 | if random.uniform(0, 1) > self.probability:
57 | return img
58 |
59 | for attempt in range(100):
60 | area = img.size()[1] * img.size()[2]
61 |
62 | target_area = random.uniform(self.sl, self.sh) * area
63 | aspect_ratio = random.uniform(self.r1, 1 / self.r1)
64 |
65 | h = int(round(math.sqrt(target_area * aspect_ratio)))
66 | w = int(round(math.sqrt(target_area / aspect_ratio)))
67 |
68 | if w < img.size()[2] and h < img.size()[1]:
69 | x1 = random.randint(0, img.size()[1] - h)
70 | y1 = random.randint(0, img.size()[2] - w)
71 | if img.size()[0] == 3:
72 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0]
73 | img[1, x1:x1 + h, y1:y1 + w] = self.mean[1]
74 | img[2, x1:x1 + h, y1:y1 + w] = self.mean[2]
75 | else:
76 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0]
77 | return img
78 |
79 | return img
80 |
81 |
82 | class ColorAugmentation(object):
83 | """
84 | Randomly alter the intensities of RGB channels
85 | Reference:
86 | Krizhevsky et al. ImageNet Classification with Deep ConvolutionalNeural Networks. NIPS 2012.
87 | """
88 |
89 | def __init__(self, p=0.5):
90 | self.p = p
91 | self.eig_vec = torch.Tensor([
92 | [0.4009, 0.7192, -0.5675],
93 | [-0.8140, -0.0045, -0.5808],
94 | [0.4203, -0.6948, -0.5836],
95 | ])
96 | self.eig_val = torch.Tensor([[0.2175, 0.0188, 0.0045]])
97 |
98 | def _check_input(self, tensor):
99 | assert tensor.dim() == 3 and tensor.size(0) == 3
100 |
101 | def __call__(self, tensor):
102 | if random.uniform(0, 1) > self.p:
103 | return tensor
104 | alpha = torch.normal(mean=torch.zeros_like(self.eig_val)) * 0.1
105 | quantity = torch.mm(self.eig_val * alpha, self.eig_vec)
106 | tensor = tensor + quantity.view(3, 1, 1)
107 | return tensor
108 |
109 |
110 | def build_transforms(height,
111 | width,
112 | random_erase=False, # use random erasing for data augmentation
113 | color_jitter=False, # randomly change the brightness, contrast and saturation
114 | color_aug=False, # randomly alter the intensities of RGB channels
115 | **kwargs):
116 | # use imagenet mean and std as default
117 | # TODO: compute dataset-specific mean and std
118 | imagenet_mean = [0.485, 0.456, 0.406]
119 | imagenet_std = [0.229, 0.224, 0.225]
120 | normalize = Normalize(mean=imagenet_mean, std=imagenet_std)
121 |
122 | # build train transformations
123 | transform_train = []
124 | transform_train += [Random2DTranslation(height, width)]
125 | transform_train += [RandomHorizontalFlip()]
126 | if color_jitter:
127 | transform_train += [ColorJitter(brightness=0.2, contrast=0.15, saturation=0, hue=0)]
128 | transform_train += [ToTensor()]
129 | if color_aug:
130 | transform_train += [ColorAugmentation()]
131 | transform_train += [normalize]
132 | if random_erase:
133 | transform_train += [RandomErasing()]
134 | transform_train = Compose(transform_train)
135 |
136 | transform_val = []
137 | transform_val += [Random2DTranslation(height, width)]
138 | transform_val += [RandomHorizontalFlip()]
139 | if color_jitter:
140 | transform_val += [ColorJitter(brightness=0.2, contrast=0.15, saturation=0, hue=0)]
141 | transform_val += [ToTensor()]
142 | if color_aug:
143 | transform_val += [ColorAugmentation()]
144 | transform_val += [normalize]
145 | if random_erase:
146 | transform_val += [RandomErasing()]
147 | transform_val = Compose(transform_val)
148 |
149 | # build test transformations
150 | transform_test = Compose([
151 | Resize((height, width)),
152 | ToTensor(),
153 | normalize,
154 | ])
155 |
156 | return transform_train, transform_val, transform_test
157 | #return transform_train, transform_test
158 |
159 | def train_transforms(height,
160 | width,
161 | random_erase=False, # use random erasing for data augmentation
162 | color_jitter=False, # randomly change the brightness, contrast and saturation
163 | color_aug=False, # randomly alter the intensities of RGB channels
164 | **kwargs):
165 | # use imagenet mean and std as default
166 | # TODO: compute dataset-specific mean and std
167 | imagenet_mean = [0.485, 0.456, 0.406]
168 | imagenet_std = [0.229, 0.224, 0.225]
169 | normalize = Normalize(mean=imagenet_mean, std=imagenet_std)
170 |
171 | # build train transformations
172 | transform_train = []
173 | transform_train += [Random2DTranslation(height, width)]
174 | transform_train += [RandomHorizontalFlip()]
175 | if color_jitter:
176 | transform_train += [ColorJitter(brightness=0.2, contrast=0.15, saturation=0, hue=0)]
177 | transform_train += [ToTensor()]
178 | if color_aug:
179 | transform_train += [ColorAugmentation()]
180 | transform_train += [normalize]
181 | if random_erase:
182 | transform_train += [RandomErasing()]
183 | transform_train = Compose(transform_train)
184 |
185 |
186 | return transform_train
187 |
188 |
189 | def test_transform(height, width):
190 | imagenet_mean = [0.485, 0.456, 0.406]
191 | imagenet_std = [0.229, 0.224, 0.225]
192 | normalize = Normalize(mean=imagenet_mean, std=imagenet_std)
193 |
194 | transform_test = Compose([
195 | Resize((height, width)),
196 | ToTensor(),
197 | normalize,
198 | ])
199 |
200 | return transform_test
201 |
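# Usage sketch (assumed sizes; mirrors how test_loading.py builds its
# evaluation pipeline):
#
#     from PIL import Image
#     tfm = test_transform(128, 128)
#     tensor = tfm(Image.open('car.jpg').convert('RGB'))  # 3x128x128 float tensor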
--------------------------------------------------------------------------------
/datasets/vehicleid.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import random
6 | import os
7 | import os.path as osp
8 | import glob
9 | import re
10 |
11 | from .base import BaseImageDataset
12 | from collections import defaultdict
13 |
14 |
15 | class VehicleID(BaseImageDataset):
16 | """
17 | VehicleID
18 |
19 | Reference:
20 | @inproceedings{liu2016deep,
21 | title={Deep Relative Distance Learning: Tell the Difference Between Similar Vehicles},
22 | author={Liu, Hongye and Tian, Yonghong and Wang, Yaowei and Pang, Lu and Huang, Tiejun},
23 | booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
24 | pages={2167--2175},
25 | year={2016}}
26 |
27 | Dataset statistics:
28 | # train_list: 13164 vehicles for model training
29 | # test_list_800: 800 vehicles for model testing (small test set in paper)
30 | # test_list_1600: 1600 vehicles for model testing (medium test set in paper)
31 | # test_list_2400: 2400 vehicles for model testing (large test set in paper)
32 | # test_list_3200: 3200 vehicles for model testing
33 | # test_list_6000: 6000 vehicles for model testing
34 | # test_list_13164: 13164 vehicles for model testing
35 | """
36 | dataset_dir = 'vehicleid'
37 |
38 | def __init__(self, root='datasets', verbose=True, test_size=2400, **kwargs):
39 | super(VehicleID, self).__init__(root)
40 | self.dataset_dir = osp.join(self.root, self.dataset_dir)
41 | self.img_dir = osp.join(self.dataset_dir, 'image')
42 | self.train_dir = osp.join(self.dataset_dir, 'image_train')
43 | self.split_dir = osp.join(self.dataset_dir, 'train_test_split')
44 | self.train_list = osp.join(self.split_dir, 'train_list.txt')
45 | self.test_size = test_size
46 |
47 | if self.test_size == 800:
48 | self.gallery_dir = osp.join(self.dataset_dir, 'image_gallery_800')
49 | self.test_list = osp.join(self.split_dir, 'test_list_800.txt')
50 | elif self.test_size == 1600:
51 | self.gallery_dir = osp.join(self.dataset_dir, 'image_gallery_1600')
52 | self.test_list = osp.join(self.split_dir, 'test_list_1600.txt')
53 | elif self.test_size == 2400:
54 | self.gallery_dir = osp.join(self.dataset_dir, 'image_gallery_2400')
55 | self.test_list = osp.join(self.split_dir, 'test_list_2400.txt')
56 |
57 | print(self.gallery_dir)
58 |
59 | self.check_before_run()
60 |
61 | train = self.process_dir(self.train_dir, relabel=True)
62 | query, gallery = self.process_split(relabel=True)
63 |
64 | self.train = train
65 | self.query = query
66 | self.gallery = gallery
67 |
68 | if verbose:
69 | print('=> VehicleID loaded')
70 | self.print_dataset_statistics(train, query, gallery)
71 |
72 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train)
73 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query)
74 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery)
75 |
76 |
77 | def check_before_run(self):
78 | """Check if all files are available before going deeper"""
79 | if not osp.exists(self.dataset_dir):
80 | raise RuntimeError('"{}" is not available'.format(self.dataset_dir))
81 | if not osp.exists(self.train_dir):
82 | raise RuntimeError('"{}" is not available'.format(self.train_dir))
83 | if self.test_size not in [800, 1600, 2400]:
84 | raise RuntimeError('"{}" is not available'.format(self.test_size))
85 | if not osp.exists(self.gallery_dir):
86 | raise RuntimeError('"{}" is not available'.format(self.gallery_dir))
87 |
88 | def get_pid2label(self, pids):
89 | pid_container = set(pids)
90 | pid2label = {pid: label for label, pid in enumerate(pid_container)}
91 | return pid2label
92 |
93 | def parse_img_pids(self, nl_pairs, pid2label=None):
94 | # nl_pairs is a list of (img name, label) pairs
95 | output = []
96 | for info in nl_pairs:
97 | name = info[0]
98 | pid = info[1]
99 | if pid2label is not None:
100 | pid = pid2label[pid]
101 | camid = 1 # don't have camid information use 1 for all
102 | img_path = osp.join(self.img_dir, name+'.jpg')
103 | output.append((img_path, pid, camid))
104 | return output
105 |
106 | def process_dir(self, dir_path, relabel=False):
107 | img_paths = sorted(glob.glob(osp.join(dir_path, '*.jpg')))
108 | #pattern = re.compile(r'([-\d]+)_c([-\d]+)')
109 |
110 | pid_container = set()
111 | for img_path in img_paths:
112 | pid = int(re.search(r'([-\d]+)', img_path).group())
113 | if pid == -1:
114 | continue # junk images are just ignored
115 | pid_container.add(pid)
116 | pid2label = {pid: label for label, pid in enumerate(pid_container)}
117 |
118 | dataset = []
119 | for img_path in img_paths:
120 | pid = int(re.search(r'([-\d]+)', img_path).group())
121 | if pid == -1:
122 | continue # junk images are just ignored
123 | assert 0 <= pid <= 131640 # pid == 0 means background
124 | pid = pid2label[pid]
125 | camid = 1
126 | dataset.append((img_path, pid, camid))
127 |
128 | return dataset
129 |
130 | def process_split(self, relabel=False):
131 |
132 | test_pid_dict = defaultdict(list)
133 | with open(self.test_list) as f_test:
134 | test_data = f_test.readlines()
135 | for data in test_data:
136 | name, pid = data.split(' ')
137 | test_pid_dict[pid].append([name, pid])
138 | test_pids = list(test_pid_dict.keys())
139 | num_test_pids = len(test_pids)
140 | assert num_test_pids == self.test_size, 'There should be {} vehicles for testing,' \
141 | ' but got {}, please check the data'\
142 | .format(self.test_size, num_test_pids)
143 |
144 | query_data = []
145 | gallery_data = []
146 |
147 | # for each test id, randomly choose one image for the query set
148 | # and put the remaining ones in the gallery.
149 | for pid in test_pids:
150 | imginfo = test_pid_dict[pid]
151 | sample = random.choice(imginfo)
152 | imginfo.remove(sample)
153 | gallery_data.extend(imginfo)
154 | query_data.append(sample)
155 |
156 | query = self.parse_img_pids(query_data)
157 | gallery = self.parse_img_pids(gallery_data)
158 | return query, gallery
159 |
160 |
--------------------------------------------------------------------------------
/datasets/veri.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import glob
6 | import re
7 | import os.path as osp
8 |
9 | from .base import BaseImageDataset
10 |
11 |
12 | class VeRi(BaseImageDataset):
13 |
14 | dataset_dir = 'veri'
15 |
16 | def __init__(self, root='datasets', dataset_dir = 'VeRi', verbose=True, **kwargs):
17 | super(VeRi, self).__init__(root)
18 | self.dataset_dir = osp.join(self.root, self.dataset_dir)
19 | self.train_dir = osp.join(self.dataset_dir, 'image_train')
20 | #self.val_dir = osp.join(self.dataset_dir, 'image_val')
21 | self.query_dir = osp.join(self.dataset_dir, 'image_query')
22 | self.gallery_dir = osp.join(self.dataset_dir, 'image_test')
23 |
24 | self.check_before_run()
25 |
26 | train = self.process_dir(self.train_dir, relabel=True)
27 | #val = self.process_dir(self.val_dir, relabel=True)
28 | query = self.process_dir(self.query_dir, relabel=False)
29 | gallery = self.process_dir(self.gallery_dir, relabel=False)
30 |
31 | if verbose:
32 | print('=> VeRi loaded')
33 | self.print_dataset_statistics(train, query, gallery)
34 |
35 | self.train = train
36 | #self.val = val
37 | self.query = query
38 | self.gallery = gallery
39 |
40 | self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train)
41 | #self.num_val_pids, self.num_val_imgs, self.num_val_cams = self.get_imagedata_info(self.val)
42 | self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query)
43 | self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery)
44 |
45 | def check_before_run(self):
46 | """Check if all files are available before going deeper"""
47 | if not osp.exists(self.dataset_dir):
48 | raise RuntimeError('"{}" is not available'.format(self.dataset_dir))
49 | if not osp.exists(self.train_dir):
50 | raise RuntimeError('"{}" is not available'.format(self.train_dir))
51 | if not osp.exists(self.query_dir):
52 | raise RuntimeError('"{}" is not available'.format(self.query_dir))
53 | if not osp.exists(self.gallery_dir):
54 | raise RuntimeError('"{}" is not available'.format(self.gallery_dir))
55 |
56 | def process_dir(self, dir_path, relabel=False):
57 | img_paths = sorted(glob.glob(osp.join(dir_path, '*.jpg')))
58 | pattern = re.compile(r'([-\d]+)_c([-\d]+)')
59 |
60 | pid_container = set()
61 | for img_path in img_paths:
62 | pid, _ = map(int, pattern.search(img_path).groups())
63 | if pid == -1:
64 | continue # junk images are just ignored
65 | pid_container.add(pid)
66 | pid2label = {pid: label for label, pid in enumerate(pid_container)}
67 |
68 | dataset = []
69 | for img_path in img_paths:
70 | pid, camid = map(int, pattern.search(img_path).groups())
71 | if pid == -1:
72 | continue # junk images are just ignored
73 | assert 0 <= pid <= 1501 # pid == 0 means background
74 | assert 1 <= camid <= 20
75 | camid -= 1 # index starts from 0
76 | if relabel:
77 | pid = pid2label[pid]
78 | dataset.append((img_path, pid, camid))
79 |
80 | return dataset
81 |
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | from __future__ import division
3 | import time
4 | import numpy as np
5 | import torch
6 | from utils.evaluation import evaluate, evaluate_vid
7 | from utils.reranking import re_ranking
8 | from utils.avgmeter import AverageMeter
9 |
10 |
11 | def do_test(model, queryloader, galleryloader, batch_size, use_gpu, dataset, ranks=[1, 5, 10]):
12 | batch_time = AverageMeter()
13 |
14 | model.eval()
15 |
16 | with torch.no_grad():
17 | qf, q_pids, q_camids = [], [], []
18 | for batch_idx, (imgs, pids, camids, _) in enumerate(queryloader):
19 | if use_gpu:
20 | imgs = imgs.cuda()
21 |
22 | end = time.time()
23 | features = model(imgs)
24 | batch_time.update(time.time() - end)
25 |
26 | features = features.data.cpu()
27 | qf.append(features)
28 | q_pids.extend(pids)
29 | q_camids.extend(camids)
30 | qf = torch.cat(qf, 0)
31 | q_pids = np.asarray(q_pids)
32 | q_camids = np.asarray(q_camids)
33 |
34 | print('Extracted features for query set, obtained {}-by-{} matrix'.format(qf.size(0), qf.size(1)))
35 |
36 | gf, g_pids, g_camids = [], [], []
37 | for batch_idx, (imgs, pids, camids, _) in enumerate(galleryloader):
38 | if use_gpu:
39 | imgs = imgs.cuda()
40 |
41 | end = time.time()
42 | features = model(imgs)
43 | batch_time.update(time.time() - end)
44 |
45 | features = features.data.cpu()
46 | gf.append(features)
47 | g_pids.extend(pids)
48 | g_camids.extend(camids)
49 | gf = torch.cat(gf, 0)
50 | g_pids = np.asarray(g_pids)
51 | g_camids = np.asarray(g_camids)
52 |
53 | print('Extracted features for gallery set, obtained {}-by-{} matrix'.format(gf.size(0), gf.size(1)))
54 |
55 | print('=> BatchTime(s)/BatchSize(img): {:.3f}/{}'.format(batch_time.avg, batch_size))
56 |
57 | m, n = qf.size(0), gf.size(0)
58 | distmat = torch.pow(qf, 2).sum(dim=1, keepdim=True).expand(m, n) + \
59 | torch.pow(gf, 2).sum(dim=1, keepdim=True).expand(n, m).t()
60 | distmat.addmm_(qf, gf.t(), beta=1, alpha=-2)  # ||q-g||^2 = ||q||^2 + ||g||^2 - 2*q.g; keyword form required by current PyTorch
61 | distmat = distmat.numpy()
62 |
63 | print('Computing CMC and mAP')
64 | if dataset == 'vehicleid':
65 | cmc, mAP = evaluate_vid(distmat, q_pids, g_pids, q_camids, g_camids, 50)
66 | else:
67 | cmc, mAP = evaluate(distmat, q_pids, g_pids, q_camids, g_camids, 50)
68 |
69 | print('Results ----------')
70 | print('mAP: {:.1%}'.format(mAP))
71 | print('CMC curve')
72 | for r in ranks:
73 | print('Rank-{:<3}: {:.1%}'.format(r, cmc[r - 1]))
74 | print('------------------')
75 |
76 | distmat_re = re_ranking(qf, gf, k1=80, k2=15, lambda_value=0.2)
77 | print('Computing CMC and mAP')
78 | if dataset == 'vehicleid':
79 | cmc_re, mAP_re = evaluate_vid(distmat_re, q_pids, g_pids, q_camids, g_camids, 50)
80 | else:
81 | cmc_re, mAP_re = evaluate(distmat_re, q_pids, g_pids, q_camids, g_camids, 50)
82 | print('Re-Ranked Results--')
83 | print('mAP: {:.1%}'.format(mAP_re))
84 | print('CMC curve')
85 | for r in ranks:
86 | print('Rank-{:<3}: {:.1%}'.format(r, cmc_re[r - 1]))
87 | print('------------------')
88 |
89 | return cmc[0], distmat, cmc_re[0], distmat_re
--------------------------------------------------------------------------------
/images/affinity_matrix.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adhirajghosh/RPTM_reid/183e1f77a0979ab2ffa08b0bdb1c43ef0f633ad5/images/affinity_matrix.png
--------------------------------------------------------------------------------
/images/architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adhirajghosh/RPTM_reid/183e1f77a0979ab2ffa08b0bdb1c43ef0f633ad5/images/architecture.png
--------------------------------------------------------------------------------
/loss/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | from .cross_entropy_loss import CrossEntropyLoss
6 | from .hard_mine_triplet_loss import TripletLoss
7 |
8 |
9 | def DeepSupervision(criterion, xs, y):
10 | """
11 | Args:
12 | - criterion: loss function
13 | - xs: tuple of inputs
14 | - y: ground truth
15 | """
16 | loss = 0.
17 | for x in xs:
18 | loss += criterion(x, y)
19 | loss /= len(xs)
20 | return loss
21 |
--------------------------------------------------------------------------------
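DeepSupervision above simply averages a criterion over a tuple of outputs, which is what multi-head (deeply supervised) models need. A minimal usage sketch (not part of the repository; the two-head setup is hypothetical):

```python
import torch
import torch.nn as nn
from loss import DeepSupervision

criterion = nn.CrossEntropyLoss()
xs = (torch.randn(8, 10), torch.randn(8, 10))  # outputs of two heads
y = torch.randint(0, 10, (8,))                 # shared ground truth

loss = DeepSupervision(criterion, xs, y)       # mean of the per-head losses
```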
/loss/cross_entropy_loss.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 |
4 | import torch
5 | import torch.nn as nn
6 |
7 |
8 | class CrossEntropyLoss(nn.Module):
9 | """Cross entropy loss with label smoothing regularizer.
10 |
11 | Reference:
12 | Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016.
13 |
14 | Equation: y = (1 - epsilon) * y + epsilon / K.
15 |
16 | Args:
17 | - num_classes (int): number of classes
18 | - epsilon (float): weight
19 | - use_gpu (bool): whether to use gpu devices
20 | - label_smooth (bool): whether to apply label smoothing, if False, epsilon = 0
21 | """
22 |
23 | def __init__(self, num_classes, epsilon=0.1, use_gpu=True, label_smooth=True):
24 | super(CrossEntropyLoss, self).__init__()
25 | self.num_classes = num_classes
26 | self.epsilon = epsilon if label_smooth else 0
27 | self.use_gpu = use_gpu
28 | self.logsoftmax = nn.LogSoftmax(dim=1)
29 |
30 | def forward(self, inputs, targets):
31 | """
32 | Args:
33 | - inputs: prediction matrix (before softmax) with shape (batch_size, num_classes)
34 |         - targets: ground truth labels with shape (batch_size)
35 | """
36 | log_probs = self.logsoftmax(inputs)
37 | targets = torch.zeros(log_probs.size()).scatter_(1, targets.unsqueeze(1).data.cpu(), 1)
38 | if self.use_gpu: targets = targets.cuda()
39 | targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
40 | loss = (- targets * log_probs).mean(0).sum()
41 | return loss
42 |
--------------------------------------------------------------------------------
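To make the smoothing equation concrete: with epsilon = 0.1 and K = 4 classes, a one-hot target [0, 1, 0, 0] becomes (1 - 0.1) * [0, 1, 0, 0] + 0.1/4 = [0.025, 0.925, 0.025, 0.025]. A minimal CPU usage sketch (not part of the repository):

```python
import torch
from loss import CrossEntropyLoss

criterion = CrossEntropyLoss(num_classes=4, epsilon=0.1, use_gpu=False)
logits = torch.randn(2, 4)       # (batch_size, num_classes), pre-softmax
targets = torch.tensor([1, 3])   # ground-truth class indices
loss = criterion(logits, targets)
```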
/loss/hard_mine_triplet_loss.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 |
4 | import torch
5 | import torch.nn as nn
6 |
7 |
8 | class TripletLoss(nn.Module):
9 |     """Triplet loss with hard positive/negative mining.
10 |     Reference: Hermans et al. In Defense of the Triplet Loss for Person Re-Identification. arXiv:1703.07737."""
11 | def __init__(self, margin=0.3):
12 | super(TripletLoss, self).__init__()
13 | self.margin = margin
14 | self.ranking_loss = nn.MarginRankingLoss(margin=margin)
15 |
16 | def forward(self, inputs, targets):
17 | """
18 | Args:
19 | - inputs: feature matrix with shape (batch_size, feat_dim)
20 |         - targets: ground truth labels with shape (batch_size)
21 | """
22 | n = inputs.size(0)
23 |
24 |         # Compute pairwise distance via the squared-norm expansion (torch.cdist is the official alternative)
25 | dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n)
26 | dist = dist + dist.t()
27 |         dist.addmm_(inputs, inputs.t(), beta=1, alpha=-2)  # keyword form; the positional signature is deprecated
28 | dist = dist.clamp(min=1e-12).sqrt() # for numerical stability
29 |
30 | # For each anchor, find the hardest positive and negative
31 | mask = targets.expand(n, n).eq(targets.expand(n, n).t())
32 | dist_ap, dist_an = [], []
33 | for i in range(n):
34 | dist_ap.append(dist[i][mask[i]].max().unsqueeze(0))
35 | dist_an.append(dist[i][mask[i] == 0].min().unsqueeze(0))
36 | dist_ap = torch.cat(dist_ap)
37 | dist_an = torch.cat(dist_an)
38 |
39 | # Compute ranking hinge loss
40 | y = torch.ones_like(dist_an)
41 | loss = self.ranking_loss(dist_an, dist_ap, y)
42 | return loss
43 |
--------------------------------------------------------------------------------
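A minimal usage sketch for the batch-hard mining above (not part of the repository): each identity must appear at least twice in the batch so every anchor has a positive, hence the PK-style labels below.

```python
import torch
from loss import TripletLoss

inputs = torch.randn(8, 128)                      # embeddings (hypothetical dim)
targets = torch.tensor([0, 0, 1, 1, 2, 2, 3, 3])  # 4 identities x 2 samples

criterion = TripletLoss(margin=0.3)
loss = criterion(inputs, targets)  # hinge on hardest positive vs. hardest negative
```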
/loss/losses.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import yaml
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from torch.autograd import Variable
8 | from torch.nn import Parameter
9 | from torch.nn import init
10 |
11 |
12 |
13 | def triplet_loss(features, margin, batch_size, size_average = True):
14 | ranking_loss = nn.MarginRankingLoss(margin=margin)
15 | #anchor = L2Normalization(features[0:batch_size])
16 | #positive = L2Normalization(features[batch_size:batch_size*2])
17 | #negative = L2Normalization(features[batch_size*2:batch_size*3])
18 | anchor = features[0:batch_size]
19 | positive = features[batch_size:batch_size*2]
20 | negative = features[batch_size*2:batch_size*3]
21 | #distance1 = torch.sqrt(torch.sum(torch.pow(anchor - positive, 2), 1, keepdims=True))
22 | #distance2 = torch.sqrt(torch.sum(torch.pow(anchor - negative, 2), 1, keepdims=True))
23 | distance_positive = (anchor - positive).pow(2).sum(1).pow(.5)
24 | distance_negative = (anchor - negative).pow(2).sum(1).pow(.5)
25 | y = torch.ones_like(distance_negative)
26 | #losses = F.relu(distance_positive - distance_negative + margin)
27 |     losses = ranking_loss(distance_negative, distance_positive, y)  # hinge: max(0, d_pos - d_neg + margin)
28 | return losses if size_average else losses.sum()
29 |
30 |
31 | def xent_loss(output, trainY, num_classes):
32 |     epsilon = 1.0  # note: 1.0 makes the targets uniform; standard label smoothing uses a small value like 0.1
33 | logsoftmax = nn.LogSoftmax(dim=1)
34 | log_probs = logsoftmax(output)
35 | targets = torch.zeros(log_probs.size())
36 | for i in range(len(targets)):
37 | for j in range(len(targets[i])):
38 | if j == trainY[i]:
39 | targets[i][j] = torch.tensor(1.0)
40 | break
41 | targets = targets.cuda()
42 | targets = (1 - epsilon) * targets + epsilon / num_classes
43 | loss = (- targets * log_probs).mean(0).sum()
44 | return loss
45 |
46 | def L2Normalization(ff, dim = 1):
47 | # ff is B*N
48 | fnorm = torch.norm(ff, p=2, dim=dim, keepdim=True) + 1e-5
49 | ff = ff.div(fnorm.expand_as(ff))
50 | return ff
51 |
52 | def myphi(x,m):  # Taylor-style approximation of cos(m*x); note the final x**9/9! term deviates from the cosine series
53 | x = x * m
54 | return 1-x**2/math.factorial(2)+x**4/math.factorial(4)-x**6/math.factorial(6) + \
55 | x**8/math.factorial(8) - x**9/math.factorial(9)
56 |
57 | # AngleLinear layer for SphereFace-style angular-margin losses (largely modified from the original implementation)
58 | class AngleLinear(nn.Module):
59 | def __init__(self, in_features, out_features, m = 4, phiflag=True):
60 | super(AngleLinear, self).__init__()
61 | self.in_features = in_features
62 | self.out_features = out_features
63 | self.weight = Parameter(torch.Tensor(in_features,out_features))
64 | init.normal_(self.weight.data, std=0.001)
65 | self.phiflag = phiflag
66 | self.m = m
67 | self.mlambda = [
68 | lambda x: x**0,
69 | lambda x: x**1,
70 | lambda x: 2*x**2-1,
71 | lambda x: 4*x**3-3*x,
72 | lambda x: 8*x**4-8*x**2+1,
73 | lambda x: 16*x**5-20*x**3+5*x
74 | ]
75 |
76 | def forward(self, input):
77 | x = input # size=(B,F) F is feature len
78 | w = self.weight # size=(F,Classnum) F=in_features Classnum=out_features
79 |
80 | ww = w.renorm(2,1,1e-5).mul(1e5)
81 | xlen = x.pow(2).sum(1).pow(0.5) # size=B
82 | wlen = ww.pow(2).sum(0).pow(0.5) # size=Classnum
83 |
84 | cos_theta = x.mm(ww) # size=(B,Classnum)
85 | cos_theta = cos_theta / xlen.view(-1,1) / wlen.view(1,-1)
86 | cos_theta = cos_theta.clamp(-1,1)
87 |
88 | if self.phiflag:
89 | cos_m_theta = self.mlambda[self.m](cos_theta)
90 | theta = Variable(cos_theta.data.acos())
91 | k = (self.m*theta/3.14159265).floor()
92 | n_one = k*0.0 - 1
93 | phi_theta = (n_one**k) * cos_m_theta - 2*k
94 | else:
95 | theta = cos_theta.acos()
96 | phi_theta = myphi(theta,self.m)
97 | phi_theta = phi_theta.clamp(-1*self.m,1)
98 |
99 | cos_theta = cos_theta * xlen.view(-1,1)
100 | phi_theta = phi_theta * xlen.view(-1,1)
101 | output = (cos_theta,phi_theta)
102 | return output # size=(B,Classnum,2)
103 |
104 | #https://github.com/auroua/InsightFace_TF/blob/master/losses/face_losses.py#L80
105 | class ArcLinear(nn.Module):
106 | def __init__(self, in_features, out_features, s=64.0):
107 | super(ArcLinear, self).__init__()
108 | self.weight = Parameter(torch.Tensor(in_features,out_features))
109 | init.normal_(self.weight.data, std=0.001)
110 | self.loss_s = s
111 |
112 | def forward(self, input):
113 | embedding = input
114 | nembedding = L2Normalization(embedding, dim=1)*self.loss_s
115 | _weight = L2Normalization(self.weight, dim=0)
116 | fc7 = nembedding.mm(_weight)
117 | output = (fc7, _weight, nembedding)
118 | return output
119 |
120 | class ArcLoss(nn.Module):
121 | def __init__(self, m1=1.0, m2=0.5, m3 =0.0, s = 64.0):
122 | super(ArcLoss, self).__init__()
123 | self.loss_m1 = m1
124 | self.loss_m2 = m2
125 | self.loss_m3 = m3
126 | self.loss_s = s
127 |
128 | def forward(self, input, target):
129 | fc7, _weight, nembedding = input
130 |
131 | index = fc7.data * 0.0 #size=(B,Classnum)
132 | index.scatter_(1,target.data.view(-1,1),1)
133 |         index = index.bool()  # boolean mask; byte masks are deprecated for indexing
134 | index = Variable(index)
135 |
136 | zy = fc7[index]
137 | cos_t = zy/self.loss_s
138 | t = torch.acos(cos_t)
139 | t = t*self.loss_m1 + self.loss_m2
140 | body = torch.cos(t) - self.loss_m3
141 |
142 | new_zy = body*self.loss_s
143 | diff = new_zy - zy
144 | fc7[index] += diff
145 | loss = F.cross_entropy(fc7, target)
146 | return loss
147 |
148 | class AngleLoss(nn.Module):
149 | def __init__(self, gamma=0):
150 | super(AngleLoss, self).__init__()
151 | self.gamma = gamma
152 | self.it = 0
153 | self.LambdaMin = 5.0
154 | self.LambdaMax = 1500.0
155 | self.lamb = 1500.0
156 |
157 | def forward(self, input, target):
158 | self.it += 1
159 | cos_theta,phi_theta = input
160 | target = target.view(-1,1) #size=(B,1)
161 |
162 | index = cos_theta.data * 0.0 #size=(B,Classnum)
163 | index.scatter_(1,target.data.view(-1,1),1)
164 |         index = index.bool()  # boolean mask; byte masks are deprecated for indexing
165 | index = Variable(index)
166 |
167 | self.lamb = max(self.LambdaMin,self.LambdaMax/(1+0.1*self.it ))
168 | output = cos_theta * 1.0 #size=(B,Classnum)
169 | output[index] -= cos_theta[index]*(1.0+0)/(1+self.lamb)
170 | output[index] += phi_theta[index]*(1.0+0)/(1+self.lamb)
171 |
172 | logpt = F.log_softmax(output, dim=1)
173 | logpt = logpt.gather(1,target)
174 | logpt = logpt.view(-1)
175 | pt = Variable(logpt.data.exp())
176 |
177 | loss = -1 * (1-pt)**self.gamma * logpt
178 | loss = loss.mean()
179 |
180 | return loss
181 |
--------------------------------------------------------------------------------
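Unlike the batch-hard version above, triplet_loss in this file assumes the feature tensor is stacked as [anchors | positives | negatives] in consecutive thirds. A minimal sketch of that layout (not part of the repository; dimensions are hypothetical):

```python
import torch
from loss.losses import triplet_loss

batch_size = 4
anchors, positives, negatives = (torch.randn(batch_size, 64) for _ in range(3))
# Rows 0..3 are anchors, 4..7 positives, 8..11 negatives.
features = torch.cat([anchors, positives, negatives], dim=0)

loss = triplet_loss(features, margin=0.3, batch_size=batch_size)
```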
/main.py:
--------------------------------------------------------------------------------
1 | import random
2 | import torch
3 | import numpy as np
4 | import os
5 | import os.path as osp
6 | import argparse
7 | import sys
8 | try:
9 | from apex.fp16_utils import *
10 | from apex import amp, optimizers
11 | except ImportError:  # apex is optional; only needed for fp16 training
12 |     print('This is not an error. To train with low precision (fp16), install apex with CUDA support (https://github.com/NVIDIA/apex) and use PyTorch >= 1.0.')
13 | from config import cfg
14 | from datasets import data_loader
15 | from model import ft_net_SE, init_model, init_optimizer
16 | from loss import CrossEntropyLoss, TripletLoss
17 | from train import do_train
18 | from eval import do_test
19 | from utils.kwargs import return_kwargs
20 | from utils.loggers import Logger
21 | from utils.torchtools import count_num_param, accuracy, load_pretrained_weights, save_checkpoint
22 | from utils.visualtools import visualize_ranked_results
23 | from utils.functions import create_split_dirs
24 |
25 | try:
26 | from apex import amp
27 | APEX_AVAILABLE = True
28 | except ModuleNotFoundError:
29 | APEX_AVAILABLE = False
30 |
31 | def set_seed(seed):
32 | torch.manual_seed(seed)
33 | torch.cuda.manual_seed(seed)
34 | torch.cuda.manual_seed_all(seed)
35 | np.random.seed(seed)
36 | random.seed(seed)
37 | torch.backends.cudnn.deterministic = True
38 |     torch.backends.cudnn.benchmark = False  # benchmark must be off for fully deterministic runs
39 |
40 | def main():
41 | parser = argparse.ArgumentParser(description="Relation Preserving Triplet Mining for Object Re-identification")
42 | parser.add_argument(
43 | "--config_file", default="configs/veri_r101.yml", help="path to config file", type=str
44 | )
45 | parser.add_argument("opts", help="Modify config options using the command-line", default=None,
46 | nargs=argparse.REMAINDER)
47 |
48 | args = parser.parse_args()
49 |
50 | #Load the config file
51 | if args.config_file != "":
52 | cfg.merge_from_file(args.config_file)
53 | cfg.merge_from_list(args.opts)
54 | cfg.freeze()
55 |
56 | set_seed(1234)
57 | os.environ['CUDA_VISIBLE_DEVICES'] = cfg.MODEL.GPU_ID
58 |
59 | output_dir = cfg.MISC.SAVE_DIR
60 | if output_dir and not os.path.exists(output_dir):
61 | os.makedirs(output_dir)
62 |
63 | dataset_kwargs, transform_kwargs, optimizer_kwargs, lr_scheduler_kwargs = return_kwargs(cfg)
64 |
65 | if cfg.MISC.FP16:
66 | fp16 = True
67 |
68 | use_gpu = cfg.MISC.USE_GPU
69 | log_name = './log_test.txt' if cfg.TEST.EVAL else './log_train.txt'
70 | sys.stdout = Logger(osp.join(cfg.MISC.SAVE_DIR, log_name))
71 |
72 | if not os.path.exists(cfg.DATASET.SPLIT_DIR):
73 | create_split_dirs(cfg)
74 |
75 | print("Running for RPTM: ", cfg.MODEL.RPTM_SELECT)
76 | print('Currently using GPU ', cfg.MODEL.GPU_ID)
77 | print('Initializing image data manager')
78 |
79 | trainloader, train_dict, data_tfr, testloader_dict, dm = data_loader(cfg, dataset_kwargs, transform_kwargs)
80 |
81 | print('Initializing model: {}'.format(cfg.MODEL.ARCH))
82 |
83 | model = init_model(cfg.MODEL.ARCH, train_dict['class_names'], loss={'xent', 'htri'}, use_gpu=use_gpu)
84 | print('Model size: {:.3f} M'.format(count_num_param(model)))
85 |
86 | if cfg.MODEL.PRETRAIN_PATH != '':
87 |         load_pretrained_weights(model, cfg.MODEL.PRETRAIN_PATH)
88 |         print('Loaded pretrained weights from {}'.format(cfg.MODEL.PRETRAIN_PATH))
89 |
90 | if use_gpu:
91 | model = model.cuda()
92 | optimizer = init_optimizer(model, **optimizer_kwargs)
93 | if APEX_AVAILABLE:
94 | model, optimizer = amp.initialize(
95 | model, optimizer, opt_level="O2",
96 | keep_batchnorm_fp32=True, loss_scale="dynamic")
97 |
98 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=cfg.SOLVER.STEPSIZE, gamma=cfg.SOLVER.GAMMA)
99 |
100 | criterion_xent = CrossEntropyLoss(num_classes=train_dict['num_train_pids'], use_gpu=use_gpu, label_smooth=True)
101 | criterion_htri = TripletLoss(margin=cfg.LOSS.MARGIN)
102 |
103 | if cfg.TEST.EVAL:
104 | print('Evaluate only')
105 |
106 | for name in cfg.DATASET.TARGET_NAME:
107 | print('Evaluating {} ...'.format(name))
108 | queryloader = testloader_dict[name]['query']
109 | galleryloader = testloader_dict[name]['gallery']
110 |             _, distmat, _, distmat_re = do_test(model, queryloader, galleryloader, cfg.TEST.TEST_BATCH_SIZE, use_gpu, name)
111 |
112 | if cfg.TEST.VIS_RANK:
113 | visualize_ranked_results(
114 | distmat_re, dm.return_testdataset_by_name(name),
115 | save_dir=osp.join(cfg.MISC.SAVE_DIR, 'ranked_results', name),
116 | topk=20
117 | )
118 | return
119 |
120 | print('=> Start training')
121 |
122 | do_train(cfg,
123 | trainloader,
124 | train_dict,
125 | data_tfr,
126 | testloader_dict,
127 | dm,
128 | model,
129 | optimizer,
130 | scheduler,
131 | criterion_htri,
132 | criterion_xent,
133 | )
134 |
135 |
136 | if __name__ == '__main__':
137 | main()
--------------------------------------------------------------------------------
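main.py layers command-line overrides on top of the YAML config via yacs. A sketch of the equivalent Python calls (not part of the repository; the keys come from this script, the values are hypothetical, and each value must parse to the type declared in config/defaults.py):

```python
from config import cfg

# Equivalent of: python main.py --config_file configs/veri_r101.yml \
#                               TEST.EVAL True LOSS.MARGIN 0.3
cfg.merge_from_file('configs/veri_r101.yml')
cfg.merge_from_list(['TEST.EVAL', 'True', 'LOSS.MARGIN', '0.3'])
cfg.freeze()  # later merges now raise, so the run's config is immutable
```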
/model/__init__.py:
--------------------------------------------------------------------------------
1 | from .models import *
2 | from .optimizers import *
3 | from .resnet import *
4 | from .senet import *
5 |
6 | __model_factory = {
7 | # image classification models
8 | 'resnet50': resnet50,
9 | 'resnet50_fc512': resnet50_fc512,
10 | 'resnet101': resnet101,
11 | 'resnet152': resnet152,
12 | 'resnet50_ibn_a': resnet50_ibn_a,
13 | 'resnet101_ibn_a': resnet101_ibn_a,
14 | 'resnet152_ibn_a': resnet152_ibn_a,
15 | 'senet154': senet154,
16 | 'se_resnet50': se_resnet50,
17 | 'se_resnet101': se_resnet101,
18 | 'se_resnet152': se_resnet152,
19 | 'se_resnext50_32x4d': se_resnext50_32x4d,
20 | 'se_resnext101_32x4d': se_resnext101_32x4d }
21 |
22 | def get_names():
23 | return list(__model_factory.keys())
24 |
25 |
26 | def init_model(name, *args, **kwargs):
27 |     if name not in __model_factory:
28 | raise KeyError('Unknown model: {}'.format(name))
29 | return __model_factory[name](*args, **kwargs)
30 |
--------------------------------------------------------------------------------
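A minimal sketch of the factory (not part of the repository; the 576-class count matches VeRi's training identities, and pretrained=True downloads the corresponding checkpoint on first use):

```python
from model import get_names, init_model

print(get_names())  # every key in __model_factory, e.g. 'resnet101_ibn_a'

model = init_model('resnet101_ibn_a', num_classes=576,
                   loss={'xent', 'htri'}, pretrained=True)
```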
/model/lr_schedulers.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import print_function
3 |
4 | import torch
5 |
6 |
7 | def init_lr_scheduler(optimizer,
8 | lr_scheduler='multi_step', # learning rate scheduler
9 | stepsize=[20, 40], # step size to decay learning rate
10 | gamma=0.1 # learning rate decay
11 | ):
12 | if lr_scheduler == 'single_step':
13 | return torch.optim.lr_scheduler.StepLR(optimizer, step_size=stepsize[0], gamma=gamma)
14 |
15 | elif lr_scheduler == 'multi_step':
16 | return torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=stepsize, gamma=gamma)
17 | elif lr_scheduler == 'plateau':
18 | return torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=gamma, patience=0, verbose=True)
19 |
20 | else:
21 | raise ValueError('Unsupported lr_scheduler: {}'.format(lr_scheduler))
22 |
--------------------------------------------------------------------------------
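A minimal sketch of the multi-step schedule (not part of the repository; the optimizer is a dummy): with stepsize=[20, 40] and gamma=0.1, an initial lr of 0.003 drops to 3e-4 after 20 steps and 3e-5 after 40.

```python
import torch
from model.lr_schedulers import init_lr_scheduler

optimizer = torch.optim.Adam([torch.zeros(1, requires_grad=True)], lr=0.003)
scheduler = init_lr_scheduler(optimizer, 'multi_step', stepsize=[20, 40], gamma=0.1)

for epoch in range(60):
    # ... one training epoch ...
    scheduler.step()
```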
/model/models.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import torch.nn as nn
4 | from torch.nn import init
5 | from torchvision import models
6 | from torch.autograd import Variable
7 | import pretrainedmodels
8 | from .senet import se_resnext101_32x4d
9 | from torch.nn import functional as F
10 |
11 | ######################################################################
12 | def weights_init_kaiming(m):
13 | classname = m.__class__.__name__
14 | # print(classname)
15 | if classname.find('Conv') != -1:
16 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') # For old pytorch, you may use kaiming_normal.
17 | elif classname.find('Linear') != -1:
18 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_out')
19 | init.constant_(m.bias.data, 0.0)
20 | elif classname.find('BatchNorm1d') != -1:
21 | init.normal_(m.weight.data, 1.0, 0.02)
22 | init.constant_(m.bias.data, 0.0)
23 |
24 | def weights_init_classifier(m):
25 | classname = m.__class__.__name__
26 | if classname.find('Linear') != -1:
27 | init.normal_(m.weight.data, std=0.001)
28 | init.constant_(m.bias.data, 0.0)
29 |
30 | def fix_relu(m):
31 | classname = m.__class__.__name__
32 | if classname.find('ReLU') != -1:
33 | m.inplace=True
34 | # Defines the new fc layer and classification layer
35 | # |--Linear--|--bn--|--relu--|--Linear--|
36 | class ClassBlock(nn.Module):
37 | def __init__(self, input_dim, class_num, droprate, relu=False, bnorm=True, num_bottleneck=512, linear=True, return_f = False):
38 | super(ClassBlock, self).__init__()
39 | self.return_f = return_f
40 | add_block = []
41 | if linear:
42 | add_block += [nn.Linear(input_dim, num_bottleneck)]
43 | else:
44 | num_bottleneck = input_dim
45 | if bnorm:
46 | add_block += [nn.BatchNorm1d(num_bottleneck)]
47 | if relu:
48 | add_block += [nn.LeakyReLU(0.1)]
49 | if droprate>0:
50 | add_block += [nn.Dropout(p=droprate)]
51 | add_block = nn.Sequential(*add_block)
52 | add_block.apply(weights_init_kaiming)
53 |
54 | classifier = []
55 | classifier += [nn.Linear(num_bottleneck, class_num)]
56 | classifier = nn.Sequential(*classifier)
57 | classifier.apply(weights_init_classifier)
58 |
59 | self.add_block = add_block
60 | self.classifier = classifier
61 | def forward(self, x):
62 | x = self.add_block(x)
63 | if self.return_f:
64 | f = x
65 | x = self.classifier(x)
66 | return x,f
67 | else:
68 | x = self.classifier(x)
69 | return x
70 |
71 | # Define the SE-based Model
72 | class ft_net_SE(nn.Module):
73 |
74 | def __init__(self, class_num, droprate=0.5, stride=2, pool='avg', init_model=None):
75 | super().__init__()
76 | model_name = 'se_resnext101_32x4d' # could be fbresnet152 or inceptionresnetv2
77 | # model_ft = pretrainedmodels.__dict__[model_name](num_classes=1000, pretrained='imagenet')
78 | model_ft = se_resnext101_32x4d(num_classes=1000, pretrained='imagenet')
79 |
80 | if stride == 1:
81 | model_ft.layer4[0].conv2.stride = (1,1)
82 | model_ft.layer4[0].downsample[0].stride = (1,1)
83 | if pool == 'avg':
84 | model_ft.avg_pool = nn.AdaptiveAvgPool2d((1,1))
85 | elif pool == 'max':
86 | model_ft.avg_pool = nn.AdaptiveMaxPool2d((1,1))
87 | elif pool == 'avg+max':
88 | model_ft.avg_pool2 = nn.AdaptiveAvgPool2d((1,1))
89 | model_ft.max_pool2 = nn.AdaptiveMaxPool2d((1,1))
90 | else:
91 |             print('Unknown pooling: {}'.format(pool))
92 | #model_ft.dropout = nn.Sequential()
93 | model_ft.last_linear = nn.Sequential()
94 | self.model = model_ft
95 | self.pool = pool
96 |         # SE-ResNeXt101 feature dim is 2048 (4096 when avg and max pooling are concatenated)
97 | if pool == 'avg+max':
98 | self.classifier = ClassBlock(4096, class_num, droprate)
99 | else:
100 | self.classifier = ClassBlock(2048, class_num, droprate)
101 | self.flag = False
102 |         if init_model is not None:
103 | self.flag = True
104 | self.model = init_model.model
105 | self.classifier.add_block = init_model.classifier.add_block
106 | self.new_dropout = nn.Sequential(nn.Dropout(p = droprate))
107 |
108 | def forward(self, x):
109 | x = self.model.features(x)
110 | if self.pool == 'avg+max':
111 | v1 = self.model.avg_pool2(x)
112 | v2 = self.model.max_pool2(x)
113 | v = torch.cat((v1,v2), dim = 1)
114 | else:
115 | v = self.model.avg_pool(x)
116 | v = v.view(v.size(0), v.size(1))
117 | if not self.training:
118 | return v
119 | # Convolution layers
120 | # Pooling and final linear layer
121 | if self.flag:
122 | v = self.classifier.add_block(v)
123 | v = self.new_dropout(v)
124 | y = self.classifier.classifier(v)
125 | else:
126 | y = self.classifier(v)
127 | return y,v
128 |
129 |
--------------------------------------------------------------------------------
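A minimal sketch of ft_net_SE's forward contract (not part of the repository; constructing the model downloads the SE-ResNeXt101 ImageNet weights on first use): in train mode it returns (logits, embedding), in eval mode only the 2048-d pooled embedding that re-identification actually uses.

```python
import torch
from model.models import ft_net_SE

model = ft_net_SE(class_num=576)      # hypothetical identity count
x = torch.randn(2, 3, 224, 224)

model.train()
y, v = model(x)                       # logits (2, 576), features (2, 2048)

model.eval()
with torch.no_grad():
    v = model(x)                      # features only, (2, 2048)
```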
/model/optimizers.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import print_function
3 |
4 | import torch
5 | import torch.nn as nn
6 |
7 |
8 | def init_optimizer(model,
9 | optim='adam', # optimizer choices
10 | lr=0.003, # learning rate
11 | weight_decay=5e-4, # weight decay
12 | momentum=0.9, # momentum factor for sgd and rmsprop
13 | sgd_dampening=0, # sgd's dampening for momentum
14 | sgd_nesterov=True, # whether to enable sgd's Nesterov momentum
15 | rmsprop_alpha=0.99, # rmsprop's smoothing constant
16 | adam_beta1=0.9, # exponential decay rate for adam's first moment
17 |                    adam_beta2=0.999, # exponential decay rate for adam's second moment
18 | staged_lr=False, # different lr for different layers
19 |                    new_layers=None, # new layers use the default lr, while other layers' lr is scaled by base_lr_mult
20 | base_lr_mult=0.1, # learning rate multiplier for base layers
21 | ):
22 | if staged_lr:
23 | assert new_layers is not None
24 | base_params = []
25 | base_layers = []
26 | new_params = []
27 | if isinstance(model, nn.DataParallel):
28 | model = model.module
29 | for name, module in model.named_children():
30 | if name in new_layers:
31 | new_params += [p for p in module.parameters()]
32 | else:
33 | base_params += [p for p in module.parameters()]
34 | base_layers.append(name)
35 | param_groups = [
36 | {'params': base_params, 'lr': lr * base_lr_mult},
37 | {'params': new_params},
38 | ]
39 | print('Use staged learning rate')
40 | print('* Base layers (initial lr = {}): {}'.format(lr * base_lr_mult, base_layers))
41 | print('* New layers (initial lr = {}): {}'.format(lr, new_layers))
42 | else:
43 | param_groups = model.parameters()
44 |
45 | # Construct optimizer
46 | if optim == 'adam':
47 | return torch.optim.Adam(param_groups, lr=lr, weight_decay=weight_decay,
48 | betas=(adam_beta1, adam_beta2))
49 |
50 | elif optim == 'amsgrad':
51 | return torch.optim.Adam(param_groups, lr=lr, weight_decay=weight_decay,
52 | betas=(adam_beta1, adam_beta2), amsgrad=True)
53 |
54 | elif optim == 'sgd':
55 | return torch.optim.SGD(param_groups, lr=lr, momentum=momentum, weight_decay=weight_decay,
56 | dampening=sgd_dampening, nesterov=sgd_nesterov)
57 |
58 | elif optim == 'rmsprop':
59 | return torch.optim.RMSprop(param_groups, lr=lr, momentum=momentum, weight_decay=weight_decay,
60 | alpha=rmsprop_alpha)
61 |
62 | else:
63 | raise ValueError('Unsupported optimizer: {}'.format(optim))
64 |
--------------------------------------------------------------------------------
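A minimal sketch of the staged learning rate (not part of the repository; the two-part model is hypothetical): child modules listed in new_layers keep the base lr, everything else is scaled by base_lr_mult.

```python
import torch.nn as nn
from model.optimizers import init_optimizer

model = nn.Sequential()
model.add_module('backbone', nn.Linear(128, 64))   # lr = 0.003 * 0.1
model.add_module('classifier', nn.Linear(64, 10))  # lr = 0.003

optimizer = init_optimizer(model, optim='adam', lr=0.003,
                           staged_lr=True, new_layers=['classifier'],
                           base_lr_mult=0.1)
```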
/model/resnet.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 |
4 | import torch
5 | from torch import nn
6 | import math
7 | import torch.utils.model_zoo as model_zoo
8 |
9 | __all__ = ['resnet50', 'resnet50_fc512', 'resnet101', 'resnet152', 'resnet50_ibn_a', 'resnet101_ibn_a',
10 | 'resnet152_ibn_a']
11 |
12 | model_urls = {
13 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
14 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
15 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
16 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
17 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
18 | 'resnet101_ibn_a': 'https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnet101_ibn_a-59ea0ac6.pth'
19 | }
20 |
21 |
22 | def conv3x3(in_planes, out_planes, stride=1):
23 | """3x3 convolution with padding"""
24 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
25 | padding=1, bias=False)
26 |
27 |
28 | class BasicBlock(nn.Module):
29 | expansion = 1
30 |
31 | def __init__(self, inplanes, planes, stride=1, downsample=None):
32 | super(BasicBlock, self).__init__()
33 | self.conv1 = conv3x3(inplanes, planes, stride)
34 | self.bn1 = nn.BatchNorm2d(planes)
35 | self.relu = nn.ReLU(inplace=True)
36 | self.conv2 = conv3x3(planes, planes)
37 | self.bn2 = nn.BatchNorm2d(planes)
38 | self.downsample = downsample
39 | self.stride = stride
40 |
41 | def forward(self, x):
42 | residual = x
43 |
44 | out = self.conv1(x)
45 | out = self.bn1(out)
46 | out = self.relu(out)
47 |
48 | out = self.conv2(out)
49 | out = self.bn2(out)
50 |
51 | if self.downsample is not None:
52 | residual = self.downsample(x)
53 |
54 | out += residual
55 | out = self.relu(out)
56 |
57 | return out
58 |
59 |
60 | class Bottleneck(nn.Module):
61 | expansion = 4
62 |
63 | def __init__(self, inplanes, planes, stride=1, downsample=None):
64 | super(Bottleneck, self).__init__()
65 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
66 | self.bn1 = nn.BatchNorm2d(planes)
67 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
68 | padding=1, bias=False)
69 | self.bn2 = nn.BatchNorm2d(planes)
70 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
71 | self.bn3 = nn.BatchNorm2d(planes * self.expansion)
72 | self.relu = nn.ReLU(inplace=True)
73 | self.downsample = downsample
74 | self.stride = stride
75 |
76 | def forward(self, x):
77 | residual = x
78 |
79 | out = self.conv1(x)
80 | out = self.bn1(out)
81 | out = self.relu(out)
82 |
83 | out = self.conv2(out)
84 | out = self.bn2(out)
85 | out = self.relu(out)
86 |
87 | out = self.conv3(out)
88 | out = self.bn3(out)
89 |
90 | if self.downsample is not None:
91 | residual = self.downsample(x)
92 |
93 | out += residual
94 | out = self.relu(out)
95 |
96 | return out
97 |
98 |
99 | class ResNet(nn.Module):
100 | """
101 | Residual network
102 | Reference:
103 | He et al. Deep Residual Learning for Image Recognition. CVPR 2016.
104 | """
105 |
106 | def __init__(self, num_classes, loss, block, layers,
107 | last_stride=2,
108 | fc_dims=None,
109 | dropout_p=None,
110 | **kwargs):
111 | self.inplanes = 64
112 | super(ResNet, self).__init__()
113 | self.loss = loss
114 | self.feature_dim = 512 * block.expansion
115 |
116 | # backbone network
117 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
118 | self.bn1 = nn.BatchNorm2d(64)
119 | self.relu = nn.ReLU(inplace=True)
120 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
121 | self.layer1 = self._make_layer(block, 64, layers[0])
122 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
123 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
124 | self.layer4 = self._make_layer(block, 512, layers[3], stride=last_stride)
125 |
126 | self.global_avgpool = nn.AdaptiveAvgPool2d(1)
127 | self.fc = self._construct_fc_layer(fc_dims, 512 * block.expansion, dropout_p)
128 | self.classifier = nn.Linear(self.feature_dim, num_classes)
129 |
130 | self._init_params()
131 |
132 | def _make_layer(self, block, planes, blocks, stride=1):
133 | downsample = None
134 | if stride != 1 or self.inplanes != planes * block.expansion:
135 | downsample = nn.Sequential(
136 | nn.Conv2d(self.inplanes, planes * block.expansion,
137 | kernel_size=1, stride=stride, bias=False),
138 | nn.BatchNorm2d(planes * block.expansion),
139 | )
140 |
141 | layers = []
142 | layers.append(block(self.inplanes, planes, stride, downsample))
143 | self.inplanes = planes * block.expansion
144 | for i in range(1, blocks):
145 | layers.append(block(self.inplanes, planes))
146 |
147 | return nn.Sequential(*layers)
148 |
149 | def _construct_fc_layer(self, fc_dims, input_dim, dropout_p=None):
150 | """
151 | Construct fully connected layer
152 | - fc_dims (list or tuple): dimensions of fc layers, if None,
153 | no fc layers are constructed
154 | - input_dim (int): input dimension
155 | - dropout_p (float): dropout probability, if None, dropout is unused
156 | """
157 | if fc_dims is None:
158 | self.feature_dim = input_dim
159 | return None
160 |
161 | assert isinstance(fc_dims, (list, tuple)), 'fc_dims must be either list or tuple, but got {}'.format(
162 | type(fc_dims))
163 |
164 | layers = []
165 | for dim in fc_dims:
166 | layers.append(nn.Linear(input_dim, dim))
167 | layers.append(nn.BatchNorm1d(dim))
168 | layers.append(nn.ReLU(inplace=True))
169 | if dropout_p is not None:
170 | layers.append(nn.Dropout(p=dropout_p))
171 | input_dim = dim
172 |
173 | self.feature_dim = fc_dims[-1]
174 |
175 | return nn.Sequential(*layers)
176 |
177 | def _init_params(self):
178 | for m in self.modules():
179 | if isinstance(m, nn.Conv2d):
180 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
181 | if m.bias is not None:
182 | nn.init.constant_(m.bias, 0)
183 | elif isinstance(m, nn.BatchNorm2d):
184 | nn.init.constant_(m.weight, 1)
185 | nn.init.constant_(m.bias, 0)
186 | elif isinstance(m, nn.BatchNorm1d):
187 | nn.init.constant_(m.weight, 1)
188 | nn.init.constant_(m.bias, 0)
189 | elif isinstance(m, nn.Linear):
190 | nn.init.normal_(m.weight, 0, 0.01)
191 | if m.bias is not None:
192 | nn.init.constant_(m.bias, 0)
193 |
194 | def featuremaps(self, x):
195 | x = self.conv1(x)
196 | x = self.bn1(x)
197 | x = self.relu(x)
198 | x = self.maxpool(x)
199 | x = self.layer1(x)
200 | x = self.layer2(x)
201 | x = self.layer3(x)
202 | x = self.layer4(x)
203 | return x
204 |
205 | def forward(self, x):
206 | f = self.featuremaps(x)
207 | v = self.global_avgpool(f)
208 | v = v.view(v.size(0), -1)
209 |
210 | if self.fc is not None:
211 | v = self.fc(v)
212 |
213 | if not self.training:
214 | return v
215 |
216 | y = self.classifier(v)
217 |
218 | if self.loss == {'xent'}:
219 | return y
220 | elif self.loss == {'xent', 'htri'}:
221 | return y, v
222 | else:
223 | raise KeyError("Unsupported loss: {}".format(self.loss))
224 |
225 |
226 | class IBN(nn.Module):
227 | def __init__(self, planes):
228 | super(IBN, self).__init__()
229 | half1 = int(planes / 2)
230 | self.half = half1
231 | half2 = planes - half1
232 | self.IN = nn.InstanceNorm2d(half1, affine=True)
233 | self.BN = nn.BatchNorm2d(half2)
234 |
235 | def forward(self, x):
236 | split = torch.split(x, self.half, 1)
237 | out1 = self.IN(split[0].contiguous())
238 | out2 = self.BN(split[1].contiguous())
239 | out = torch.cat((out1, out2), 1)
240 | return out
241 |
242 |
243 | class Bottleneck_IBN(nn.Module):
244 | expansion = 4
245 |
246 | def __init__(self, inplanes, planes, ibn=False, stride=1, downsample=None):
247 | super(Bottleneck_IBN, self).__init__()
248 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
249 | if ibn:
250 | self.bn1 = IBN(planes)
251 | else:
252 | self.bn1 = nn.BatchNorm2d(planes)
253 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
254 | padding=1, bias=False)
255 | self.bn2 = nn.BatchNorm2d(planes)
256 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
257 | self.bn3 = nn.BatchNorm2d(planes * self.expansion)
258 | self.relu = nn.ReLU(inplace=True)
259 | self.downsample = downsample
260 | self.stride = stride
261 |
262 | def forward(self, x):
263 | residual = x
264 |
265 | out = self.conv1(x)
266 | out = self.bn1(out)
267 | out = self.relu(out)
268 |
269 | out = self.conv2(out)
270 | out = self.bn2(out)
271 | out = self.relu(out)
272 |
273 | out = self.conv3(out)
274 | out = self.bn3(out)
275 |
276 | if self.downsample is not None:
277 | residual = self.downsample(x)
278 |
279 | out += residual
280 | out = self.relu(out)
281 |
282 | return out
283 |
284 |
285 | class ResNet_IBN(nn.Module):
286 |
287 | def __init__(self, last_stride, block, layers, loss, num_classes, fc_dims=None, dropout_p=None, **kwargs):
288 | scale = 64
289 | self.inplanes = scale
290 | super(ResNet_IBN, self).__init__()
291 | self.feature_dim = 512 * block.expansion
292 | self.conv1 = nn.Conv2d(3, scale, kernel_size=7, stride=2, padding=3,
293 | bias=False)
294 | self.bn1 = nn.BatchNorm2d(scale)
295 | self.relu = nn.ReLU(inplace=True)
296 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
297 | self.layer1 = self._make_layer(block, scale, layers[0])
298 | self.layer2 = self._make_layer(block, scale * 2, layers[1], stride=2)
299 | self.layer3 = self._make_layer(block, scale * 4, layers[2], stride=2)
300 | self.layer4 = self._make_layer(block, scale * 8, layers[3], stride=last_stride)
301 | self.avgpool = nn.AvgPool2d(7)
302 | # self.fc = nn.Linear(scale * 8 * block.expansion, num_classes)
303 | self.loss = loss
304 | self.global_avgpool = nn.AdaptiveAvgPool2d(1)
305 | self.fc = self._construct_fc_layer(fc_dims, scale * 8 * block.expansion, dropout_p)
306 | self.classifier = nn.Linear(self.feature_dim, num_classes)
307 |
308 | for m in self.modules():
309 | if isinstance(m, nn.Conv2d):
310 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
311 | m.weight.data.normal_(0, math.sqrt(2. / n))
312 | elif isinstance(m, nn.BatchNorm2d):
313 | m.weight.data.fill_(1)
314 | m.bias.data.zero_()
315 | elif isinstance(m, nn.InstanceNorm2d):
316 | m.weight.data.fill_(1)
317 | m.bias.data.zero_()
318 |
319 | def _make_layer(self, block, planes, blocks, stride=1):
320 | downsample = None
321 | if stride != 1 or self.inplanes != planes * block.expansion:
322 | downsample = nn.Sequential(
323 | nn.Conv2d(self.inplanes, planes * block.expansion,
324 | kernel_size=1, stride=stride, bias=False),
325 | nn.BatchNorm2d(planes * block.expansion),
326 | )
327 |
328 | layers = []
329 | ibn = True
330 | if planes == 512:
331 | ibn = False
332 | layers.append(block(self.inplanes, planes, ibn, stride, downsample))
333 | self.inplanes = planes * block.expansion
334 | for i in range(1, blocks):
335 | layers.append(block(self.inplanes, planes, ibn))
336 |
337 | return nn.Sequential(*layers)
338 |
339 | def _construct_fc_layer(self, fc_dims, input_dim, dropout_p=None):
340 | """
341 | Construct fully connected layer
342 | - fc_dims (list or tuple): dimensions of fc layers, if None,
343 | no fc layers are constructed
344 | - input_dim (int): input dimension
345 | - dropout_p (float): dropout probability, if None, dropout is unused
346 | """
347 | if fc_dims is None:
348 | self.feature_dim = input_dim
349 | return None
350 |
351 | assert isinstance(fc_dims, (list, tuple)), 'fc_dims must be either list or tuple, but got {}'.format(
352 | type(fc_dims))
353 |
354 | layers = []
355 | for dim in fc_dims:
356 | layers.append(nn.Linear(input_dim, dim))
357 | layers.append(nn.BatchNorm1d(dim))
358 | layers.append(nn.ReLU(inplace=True))
359 | if dropout_p is not None:
360 | layers.append(nn.Dropout(p=dropout_p))
361 | input_dim = dim
362 |
363 | self.feature_dim = fc_dims[-1]
364 |
365 | return nn.Sequential(*layers)
366 |
367 | def forward(self, x):
368 | x = self.conv1(x)
369 | x = self.bn1(x)
370 | x = self.relu(x)
371 | x = self.maxpool(x)
372 |
373 | x = self.layer1(x)
374 | x = self.layer2(x)
375 | x = self.layer3(x)
376 | x = self.layer4(x)
377 |
378 | # x = self.avgpool(x)
379 | # x = x.view(x.size(0), -1)
380 | # x = self.fc(x)
381 | f = x
382 | v = self.global_avgpool(x)
383 | v = v.view(v.size(0), -1)
384 |
385 | if self.fc is not None:
386 | v = self.fc(v)
387 |
388 | if not self.training:
389 | return v
390 |
391 | y = self.classifier(v)
392 |
393 | if self.loss == {'xent'}:
394 | return y
395 | elif self.loss == {'xent', 'htri'}:
396 | return y, v
397 | else:
398 | raise KeyError("Unsupported loss: {}".format(self.loss))
399 |
400 | def load_param(self, model_path):
401 | param_dict = torch.load(model_path)
402 | for i in param_dict:
403 | if 'fc' in i:
404 | continue
405 | self.state_dict()[i].copy_(param_dict[i])
406 |
407 |
408 | def init_pretrained_weights(model, model_url):
409 | """
410 | Initialize model with pretrained weights.
411 | Layers that don't match with pretrained layers in name or size are kept unchanged.
412 | """
413 | pretrain_dict = model_zoo.load_url(model_url)
414 | model_dict = model.state_dict()
415 | pretrain_dict = {k: v for k, v in pretrain_dict.items() if k in model_dict and model_dict[k].size() == v.size()}
416 | model_dict.update(pretrain_dict)
417 | model.load_state_dict(model_dict)
418 | print('Initialized model with pretrained weights from {}'.format(model_url))
419 |
420 |
421 | """
422 | Residual network configurations:
423 | --
424 | resnet18: block=BasicBlock, layers=[2, 2, 2, 2]
425 | resnet34: block=BasicBlock, layers=[3, 4, 6, 3]
426 | resnet50: block=Bottleneck, layers=[3, 4, 6, 3]
427 | resnet101: block=Bottleneck, layers=[3, 4, 23, 3]
428 | resnet152: block=Bottleneck, layers=[3, 8, 36, 3]
429 | """
430 |
431 |
432 | def resnet50(num_classes, loss={'xent'}, pretrained=True, **kwargs):
433 | model = ResNet(
434 | num_classes=num_classes,
435 | loss=loss,
436 | block=Bottleneck,
437 | layers=[3, 4, 6, 3],
438 | last_stride=2,
439 | fc_dims=None,
440 | dropout_p=None,
441 | **kwargs
442 | )
443 | if pretrained:
444 | init_pretrained_weights(model, model_urls['resnet50'])
445 | return model
446 |
447 |
448 | def resnet50_fc512(num_classes, loss={'xent'}, pretrained=True, **kwargs):
449 | model = ResNet(
450 | num_classes=num_classes,
451 | loss=loss,
452 | block=Bottleneck,
453 | layers=[3, 4, 6, 3],
454 | last_stride=1,
455 | fc_dims=[512],
456 | dropout_p=None,
457 | **kwargs
458 | )
459 | if pretrained:
460 | init_pretrained_weights(model, model_urls['resnet50'])
461 | return model
462 |
463 |
464 | def resnet101(num_classes, loss={'xent'}, pretrained=True, **kwargs):
465 | model = ResNet(
466 | num_classes=num_classes,
467 | loss=loss,
468 | block=Bottleneck,
469 | layers=[3, 4, 23, 3],
470 | last_stride=2,
471 | fc_dims=None,
472 | dropout_p=None,
473 | **kwargs
474 | )
475 | if pretrained:
476 | init_pretrained_weights(model, model_urls['resnet101'])
477 | return model
478 |
479 |
480 | def resnet152(num_classes, loss={'xent'}, pretrained=True, **kwargs):
481 | model = ResNet(
482 | num_classes=num_classes,
483 | loss=loss,
484 | block=Bottleneck,
485 | layers=[3, 8, 36, 3],
486 | last_stride=2,
487 | fc_dims=None,
488 | dropout_p=None,
489 | **kwargs
490 | )
491 | if pretrained:
492 | init_pretrained_weights(model, model_urls['resnet152'])
493 | return model
494 |
495 |
496 | def resnet50_ibn_a(num_classes, loss={'xent'}, pretrained=True, **kwargs):
497 | """Constructs a ResNet-50 model.
498 | Args:
499 | pretrained (bool): If True, returns a model pre-trained on ImageNet
500 | """
501 | model = ResNet_IBN(1, Bottleneck_IBN, [3, 4, 6, 3], loss, **kwargs)
502 | if pretrained:
503 | init_pretrained_weights(model, model_urls['resnet50'])
504 | return model
505 |
506 |
507 | def resnet101_ibn_a(num_classes, loss={'xent'}, pretrained=True, **kwargs):
508 | """Constructs a ResNet-101 model.
509 | Args:
510 | pretrained (bool): If True, returns a model pre-trained on ImageNet
511 | """
512 | model = ResNet_IBN(1, Bottleneck_IBN, [3, 4, 23, 3], loss, num_classes, **kwargs)
513 | if pretrained:
514 | init_pretrained_weights(model, model_urls['resnet101_ibn_a'])
515 | return model
516 |
517 |
518 | def resnet152_ibn_a(num_classes, loss={'xent'}, pretrained=True, **kwargs):
519 | """Constructs a ResNet-152 model.
520 | Args:
521 | pretrained (bool): If True, returns a model pre-trained on ImageNet
522 | """
523 | model = ResNet_IBN(1, Bottleneck_IBN, [3, 8, 36, 3], loss, num_classes, **kwargs)
524 | if pretrained:
525 | init_pretrained_weights(model, model_urls['resnet152'])
526 | return model
527 |
--------------------------------------------------------------------------------
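The name-and-shape filter in init_pretrained_weights is what lets an ImageNet checkpoint initialise a re-id model: the 1000-way ImageNet head has no matching parameter here (the head is named 'classifier' and sized to the identity count), so it is silently skipped and only the backbone tensors are copied. A minimal sketch (not part of the repository; downloads the checkpoint on first use):

```python
from model.resnet import resnet101

# Backbone weights are copied; the mismatched classification head is skipped.
model = resnet101(num_classes=576, loss={'xent', 'htri'}, pretrained=True)
```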
/model/senet.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, division, absolute_import
2 | from collections import OrderedDict
3 | import math
4 |
5 | import torch.nn as nn
6 | from torch.utils import model_zoo
7 |
8 | __all__ = ['SENet', 'senet154', 'se_resnet50', 'se_resnet101', 'se_resnet152',
9 | 'se_resnext50_32x4d', 'se_resnext101_32x4d']
10 |
11 | model_urls = {
12 | 'senet154': 'http://data.lip6.fr/cadene/pretrainedmodels/senet154-c7b49a05.pth',
13 | 'se_resnet50': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet50-ce0d4300.pth',
14 | 'se_resnet101': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet101-7e38fcc6.pth',
15 | 'se_resnet152': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet152-d17c99b7.pth',
16 | 'se_resnext50_32x4d': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnext50_32x4d-a260b3a4.pth',
17 | 'se_resnext101_32x4d': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnext101_32x4d-3b2fe3d8.pth',
18 | }
19 |
20 | pretrained_settings = {
21 | 'senet154': {
22 | 'imagenet': {
23 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/senet154-c7b49a05.pth',
24 | 'input_space': 'RGB',
25 | 'input_size': [3, 224, 224],
26 | 'input_range': [0, 1],
27 | 'mean': [0.485, 0.456, 0.406],
28 | 'std': [0.229, 0.224, 0.225],
29 | 'num_classes': 1000
30 | }
31 | },
32 | 'se_resnet50': {
33 | 'imagenet': {
34 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet50-ce0d4300.pth',
35 | 'input_space': 'RGB',
36 | 'input_size': [3, 224, 224],
37 | 'input_range': [0, 1],
38 | 'mean': [0.485, 0.456, 0.406],
39 | 'std': [0.229, 0.224, 0.225],
40 | 'num_classes': 1000
41 | }
42 | },
43 | 'se_resnet101': {
44 | 'imagenet': {
45 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet101-7e38fcc6.pth',
46 | 'input_space': 'RGB',
47 | 'input_size': [3, 224, 224],
48 | 'input_range': [0, 1],
49 | 'mean': [0.485, 0.456, 0.406],
50 | 'std': [0.229, 0.224, 0.225],
51 | 'num_classes': 1000
52 | }
53 | },
54 | 'se_resnet152': {
55 | 'imagenet': {
56 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet152-d17c99b7.pth',
57 | 'input_space': 'RGB',
58 | 'input_size': [3, 224, 224],
59 | 'input_range': [0, 1],
60 | 'mean': [0.485, 0.456, 0.406],
61 | 'std': [0.229, 0.224, 0.225],
62 | 'num_classes': 1000
63 | }
64 | },
65 | 'se_resnext50_32x4d': {
66 | 'imagenet': {
67 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnext50_32x4d-a260b3a4.pth',
68 | 'input_space': 'RGB',
69 | 'input_size': [3, 224, 224],
70 | 'input_range': [0, 1],
71 | 'mean': [0.485, 0.456, 0.406],
72 | 'std': [0.229, 0.224, 0.225],
73 | 'num_classes': 1000
74 | }
75 | },
76 | 'se_resnext101_32x4d': {
77 | 'imagenet': {
78 | 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnext101_32x4d-3b2fe3d8.pth',
79 | 'input_space': 'RGB',
80 | 'input_size': [3, 224, 224],
81 | 'input_range': [0, 1],
82 | 'mean': [0.485, 0.456, 0.406],
83 | 'std': [0.229, 0.224, 0.225],
84 | 'num_classes': 1000
85 | }
86 | },
87 | }
88 |
89 | class SEModule(nn.Module):
90 |
91 | def __init__(self, channels, reduction):
92 | super(SEModule, self).__init__()
93 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
94 | self.fc1 = nn.Conv2d(channels, channels // reduction, kernel_size=1,
95 | padding=0)
96 | self.relu = nn.ReLU(inplace=True)
97 | self.fc2 = nn.Conv2d(channels // reduction, channels, kernel_size=1,
98 | padding=0)
99 | self.sigmoid = nn.Sigmoid()
100 |
101 | def forward(self, x):
102 | module_input = x
103 | x = self.avg_pool(x)
104 | x = self.fc1(x)
105 | x = self.relu(x)
106 | x = self.fc2(x)
107 | x = self.sigmoid(x)
108 | return module_input * x
109 |
110 |
111 | class Bottleneck(nn.Module):
112 | """
113 | Base class for bottlenecks that implements `forward()` method.
114 | """
115 | def forward(self, x):
116 | residual = x
117 |
118 | out = self.conv1(x)
119 | out = self.bn1(out)
120 | out = self.relu(out)
121 |
122 | out = self.conv2(out)
123 | out = self.bn2(out)
124 | out = self.relu(out)
125 |
126 | out = self.conv3(out)
127 | out = self.bn3(out)
128 |
129 | if self.downsample is not None:
130 | residual = self.downsample(x)
131 |
132 | out = self.se_module(out) + residual
133 | out = self.relu(out)
134 |
135 | return out
136 |
137 |
138 | class SEBottleneck(Bottleneck):
139 | """
140 | Bottleneck for SENet154.
141 | """
142 | expansion = 4
143 |
144 | def __init__(self, inplanes, planes, groups, reduction, stride=1,
145 | downsample=None):
146 | super(SEBottleneck, self).__init__()
147 | self.conv1 = nn.Conv2d(inplanes, planes * 2, kernel_size=1, bias=False)
148 | self.bn1 = nn.BatchNorm2d(planes * 2)
149 | self.conv2 = nn.Conv2d(planes * 2, planes * 4, kernel_size=3,
150 | stride=stride, padding=1, groups=groups,
151 | bias=False)
152 | self.bn2 = nn.BatchNorm2d(planes * 4)
153 | self.conv3 = nn.Conv2d(planes * 4, planes * 4, kernel_size=1,
154 | bias=False)
155 | self.bn3 = nn.BatchNorm2d(planes * 4)
156 | self.relu = nn.ReLU(inplace=True)
157 | self.se_module = SEModule(planes * 4, reduction=reduction)
158 | self.downsample = downsample
159 | self.stride = stride
160 |
161 |
162 | class SEResNetBottleneck(Bottleneck):
163 | """
164 | ResNet bottleneck with a Squeeze-and-Excitation module. It follows Caffe
165 | implementation and uses `stride=stride` in `conv1` and not in `conv2`
166 | (the latter is used in the torchvision implementation of ResNet).
167 | """
168 | expansion = 4
169 |
170 | def __init__(self, inplanes, planes, groups, reduction, stride=1,
171 | downsample=None):
172 | super(SEResNetBottleneck, self).__init__()
173 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False,
174 | stride=stride)
175 | self.bn1 = nn.BatchNorm2d(planes)
176 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1,
177 | groups=groups, bias=False)
178 | self.bn2 = nn.BatchNorm2d(planes)
179 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
180 | self.bn3 = nn.BatchNorm2d(planes * 4)
181 | self.relu = nn.ReLU(inplace=True)
182 | self.se_module = SEModule(planes * 4, reduction=reduction)
183 | self.downsample = downsample
184 | self.stride = stride
185 |
186 |
187 | class SEResNeXtBottleneck(Bottleneck):
188 | """
189 | ResNeXt bottleneck type C with a Squeeze-and-Excitation module.
190 | """
191 | expansion = 4
192 |
193 | def __init__(self, inplanes, planes, groups, reduction, stride=1,
194 | downsample=None, base_width=4):
195 | super(SEResNeXtBottleneck, self).__init__()
196 | width = math.floor(planes * (base_width / 64)) * groups
197 | self.conv1 = nn.Conv2d(inplanes, width, kernel_size=1, bias=False,
198 | stride=1)
199 | self.bn1 = nn.BatchNorm2d(width)
200 | self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride,
201 | padding=1, groups=groups, bias=False)
202 | self.bn2 = nn.BatchNorm2d(width)
203 | self.conv3 = nn.Conv2d(width, planes * 4, kernel_size=1, bias=False)
204 | self.bn3 = nn.BatchNorm2d(planes * 4)
205 | self.relu = nn.ReLU(inplace=True)
206 | self.se_module = SEModule(planes * 4, reduction=reduction)
207 | self.downsample = downsample
208 | self.stride = stride
209 |
210 |
211 | class SENet(nn.Module):
212 |
213 | def __init__(self, block, layers, groups, reduction, dropout_p=0.2,
214 | inplanes=128, input_3x3=True, downsample_kernel_size=3,
215 | downsample_padding=1, num_classes=1000, loss={'xent'}):
216 | """
217 | Parameters
218 | ----------
219 | block (nn.Module): Bottleneck class.
220 | - For SENet154: SEBottleneck
221 | - For SE-ResNet models: SEResNetBottleneck
222 | - For SE-ResNeXt models: SEResNeXtBottleneck
223 | layers (list of ints): Number of residual blocks for 4 layers of the
224 | network (layer1...layer4).
225 | groups (int): Number of groups for the 3x3 convolution in each
226 | bottleneck block.
227 | - For SENet154: 64
228 | - For SE-ResNet models: 1
229 | - For SE-ResNeXt models: 32
230 | reduction (int): Reduction ratio for Squeeze-and-Excitation modules.
231 | - For all models: 16
232 | dropout_p (float or None): Drop probability for the Dropout layer.
233 | If `None` the Dropout layer is not used.
234 | - For SENet154: 0.2
235 | - For SE-ResNet models: None
236 | - For SE-ResNeXt models: None
237 | inplanes (int): Number of input channels for layer1.
238 | - For SENet154: 128
239 | - For SE-ResNet models: 64
240 | - For SE-ResNeXt models: 64
241 | input_3x3 (bool): If `True`, use three 3x3 convolutions instead of
242 | a single 7x7 convolution in layer0.
243 | - For SENet154: True
244 | - For SE-ResNet models: False
245 | - For SE-ResNeXt models: False
246 | downsample_kernel_size (int): Kernel size for downsampling convolutions
247 | in layer2, layer3 and layer4.
248 | - For SENet154: 3
249 | - For SE-ResNet models: 1
250 | - For SE-ResNeXt models: 1
251 | downsample_padding (int): Padding for downsampling convolutions in
252 | layer2, layer3 and layer4.
253 | - For SENet154: 1
254 | - For SE-ResNet models: 0
255 | - For SE-ResNeXt models: 0
256 | num_classes (int): Number of outputs in `last_linear` layer.
257 | - For all models: 1000
258 | """
259 | super(SENet, self).__init__()
260 | self.inplanes = inplanes
261 | if input_3x3:
262 | layer0_modules = [
263 | ('conv1', nn.Conv2d(3, 64, 3, stride=2, padding=1,
264 | bias=False)),
265 | ('bn1', nn.BatchNorm2d(64)),
266 | ('relu1', nn.ReLU(inplace=True)),
267 | ('conv2', nn.Conv2d(64, 64, 3, stride=1, padding=1,
268 | bias=False)),
269 | ('bn2', nn.BatchNorm2d(64)),
270 | ('relu2', nn.ReLU(inplace=True)),
271 | ('conv3', nn.Conv2d(64, inplanes, 3, stride=1, padding=1,
272 | bias=False)),
273 | ('bn3', nn.BatchNorm2d(inplanes)),
274 | ('relu3', nn.ReLU(inplace=True)),
275 | ]
276 | else:
277 | layer0_modules = [
278 | ('conv1', nn.Conv2d(3, inplanes, kernel_size=7, stride=2,
279 | padding=3, bias=False)),
280 | ('bn1', nn.BatchNorm2d(inplanes)),
281 | ('relu1', nn.ReLU(inplace=True)),
282 | ]
283 | # To preserve compatibility with Caffe weights `ceil_mode=True`
284 | # is used instead of `padding=1`.
285 | layer0_modules.append(('pool', nn.MaxPool2d(3, stride=2,
286 | ceil_mode=True)))
287 | self.layer0 = nn.Sequential(OrderedDict(layer0_modules))
288 | self.layer1 = self._make_layer(
289 | block,
290 | planes=64,
291 | blocks=layers[0],
292 | groups=groups,
293 | reduction=reduction,
294 | downsample_kernel_size=1,
295 | downsample_padding=0
296 | )
297 | self.layer2 = self._make_layer(
298 | block,
299 | planes=128,
300 | blocks=layers[1],
301 | stride=2,
302 | groups=groups,
303 | reduction=reduction,
304 | downsample_kernel_size=downsample_kernel_size,
305 | downsample_padding=downsample_padding
306 | )
307 | self.layer3 = self._make_layer(
308 | block,
309 | planes=256,
310 | blocks=layers[2],
311 | stride=2,
312 | groups=groups,
313 | reduction=reduction,
314 | downsample_kernel_size=downsample_kernel_size,
315 | downsample_padding=downsample_padding
316 | )
317 | self.layer4 = self._make_layer(
318 | block,
319 | planes=512,
320 | blocks=layers[3],
321 | stride=2,
322 | groups=groups,
323 | reduction=reduction,
324 | downsample_kernel_size=downsample_kernel_size,
325 | downsample_padding=downsample_padding
326 | )
327 | self.avg_pool = nn.AvgPool2d(7, stride=1)
328 | self.dropout = nn.Dropout(dropout_p) if dropout_p is not None else None
329 | self.last_linear = nn.Linear(512 * block.expansion, num_classes)
330 | self.loss = loss
331 |
332 | def _make_layer(self, block, planes, blocks, groups, reduction, stride=1,
333 | downsample_kernel_size=1, downsample_padding=0):
334 | downsample = None
335 | if stride != 1 or self.inplanes != planes * block.expansion:
336 | downsample = nn.Sequential(
337 | nn.Conv2d(self.inplanes, planes * block.expansion,
338 | kernel_size=downsample_kernel_size, stride=stride,
339 | padding=downsample_padding, bias=False),
340 | nn.BatchNorm2d(planes * block.expansion),
341 | )
342 |
343 | layers = []
344 | layers.append(block(self.inplanes, planes, groups, reduction, stride,
345 | downsample))
346 | self.inplanes = planes * block.expansion
347 | for i in range(1, blocks):
348 | layers.append(block(self.inplanes, planes, groups, reduction))
349 |
350 | return nn.Sequential(*layers)
351 |
352 | def features(self, x):
353 | x = self.layer0(x)
354 | x = self.layer1(x)
355 | x = self.layer2(x)
356 | x = self.layer3(x)
357 | x = self.layer4(x)
358 | return x
359 |
360 | def logits(self, x):
361 | x = self.avg_pool(x)
362 | if self.dropout is not None:
363 | x = self.dropout(x)
364 | x = x.view(x.size(0), -1)
365 | x = self.last_linear(x)
366 | return x
367 |
368 |     def forward(self, x):
369 |         # The stem lives in self.layer0 (SENet defines no conv1/bn1/
370 |         # maxpool attributes), so the input is routed through
371 |         # features(), which runs layer0 through layer4.
372 |         x = self.features(x)
373 |
374 |         v = self.avg_pool(x)
375 |         if self.dropout is not None:
376 |             v = self.dropout(v)
377 |         v = v.view(v.size(0), -1)
378 |
379 |         # At test time, return the pooled embedding directly; the
380 |         # classifier head (last_linear) is only applied during training.
381 |         if not self.training:
382 |             return v
383 |
384 |         y = self.last_linear(v)
385 |
386 |         # Return logits alone for cross-entropy training, or logits plus
387 |         # the embedding when the triplet loss is also enabled (matching
388 |         # ResNet.forward in model/resnet.py).
389 |         if self.loss == {'xent'}:
390 |             return y
391 |         elif self.loss == {'xent', 'htri'}:
392 |             return y, v
393 |         else:
394 |             raise KeyError("Unsupported loss: {}".format(self.loss))
395 |
396 |
397 | def init_pretrained_weights(model, model_url):
398 | """
399 | Initialize model with pretrained weights.
400 | Layers that don't match with pretrained layers in name or size are kept unchanged.
401 | """
402 | pretrain_dict = model_zoo.load_url(model_url)
403 | model_dict = model.state_dict()
404 | pretrain_dict = {k: v for k, v in pretrain_dict.items() if k in model_dict and model_dict[k].size() == v.size()}
405 | model_dict.update(pretrain_dict)
406 | model.load_state_dict(model_dict)
407 | print('Initialized model with pretrained weights from {}'.format(model_url))
408 |
409 |
410 | def initialize_pretrained_model(model, num_classes, settings):
411 | assert num_classes == settings['num_classes'], \
412 | 'num_classes should be {}, but is {}'.format(
413 | settings['num_classes'], num_classes)
414 | model.load_state_dict(model_zoo.load_url(settings['url']))
415 | model.input_space = settings['input_space']
416 | model.input_size = settings['input_size']
417 | model.input_range = settings['input_range']
418 | model.mean = settings['mean']
419 | model.std = settings['std']
420 |
421 |
422 | def senet154(num_classes, loss={'xent'}, pretrained=True, **kwargs):
423 | model = SENet(SEBottleneck, [3, 8, 36, 3], groups=64, reduction=16,
424 | dropout_p=0.2, num_classes=num_classes, loss=loss)
425 | if pretrained:
426 | init_pretrained_weights(model, model_urls['senet154'])
427 | return model
428 |
429 |
430 | def se_resnet50(num_classes, loss={'xent'}, pretrained=True, **kwargs):
431 | model = SENet(SEResNetBottleneck, [3, 4, 6, 3], groups=1, reduction=16,
432 | dropout_p=None, inplanes=64, input_3x3=False,
433 | downsample_kernel_size=1, downsample_padding=0,
434 | num_classes=num_classes, loss=loss)
435 | if pretrained:
436 | init_pretrained_weights(model, model_urls['se_resnet50'])
437 | return model
438 |
439 |
440 | def se_resnet101(num_classes, loss={'xent'}, pretrained=True, **kwargs):
441 | model = SENet(SEResNetBottleneck, [3, 4, 23, 3], groups=1, reduction=16,
442 | dropout_p=None, inplanes=64, input_3x3=False,
443 | downsample_kernel_size=1, downsample_padding=0,
444 | num_classes=num_classes, loss=loss)
445 | if pretrained:
446 | init_pretrained_weights(model, model_urls['se_resnet101'])
447 | return model
448 |
449 |
450 | def se_resnet152(num_classes, loss={'xent'}, pretrained=True, **kwargs):
451 | model = SENet(SEResNetBottleneck, [3, 8, 36, 3], groups=1, reduction=16,
452 | dropout_p=None, inplanes=64, input_3x3=False,
453 | downsample_kernel_size=1, downsample_padding=0,
454 | num_classes=num_classes, loss=loss)
455 | if pretrained:
456 | init_pretrained_weights(model, model_urls['se_resnet152'])
457 | return model
458 |
459 |
460 | def se_resnext50_32x4d(num_classes, loss={'xent'}, pretrained=True, **kwargs):
461 | model = SENet(SEResNeXtBottleneck, [3, 4, 6, 3], groups=32, reduction=16,
462 | dropout_p=None, inplanes=64, input_3x3=False,
463 | downsample_kernel_size=1, downsample_padding=0,
464 | num_classes=num_classes, loss=loss)
465 | if pretrained:
466 | init_pretrained_weights(model, model_urls['se_resnext50_32x4d'])
467 | return model
468 |
469 |
470 | def se_resnext101_32x4d(num_classes, loss={'xent'}, pretrained=True, **kwargs):
471 | model = SENet(SEResNeXtBottleneck, [3, 4, 23, 3], groups=32, reduction=16,
472 | dropout_p=None, inplanes=64, input_3x3=False,
473 | downsample_kernel_size=1, downsample_padding=0,
474 | num_classes=num_classes, loss=loss)
475 | # if pretrained is not None:
476 | # settings = pretrained_settings['se_resnext101_32x4d'][pretrained]
477 | # initialize_pretrained_model(model, num_classes, settings)
478 | if pretrained:
479 | init_pretrained_weights(model, model_urls['se_resnext101_32x4d'])
480 | return model
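481 |
482 |
483 | # --- Usage sketch (illustration, not part of the original file) ---
484 | # The loss set controls what forward() returns in train mode: logits for
485 | # {'xent'}, (logits, features) for {'xent', 'htri'}. num_classes=576 below is
486 | # an assumption (VeRi-776 training identities), purely for illustration.
487 | if __name__ == '__main__':
488 |     import torch
489 |     net = se_resnext101_32x4d(num_classes=576, loss={'xent', 'htri'}, pretrained=False)
490 |     net.train()
491 |     logits, feats = net(torch.randn(2, 3, 224, 224))
492 |     print(logits.shape, feats.shape)  # (2, 576) and (2, feat_dim)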
--------------------------------------------------------------------------------
/pkl/duke/index.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adhirajghosh/RPTM_reid/183e1f77a0979ab2ffa08b0bdb1c43ef0f633ad5/pkl/duke/index.pkl
--------------------------------------------------------------------------------
/pkl/vehicleid/index.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adhirajghosh/RPTM_reid/183e1f77a0979ab2ffa08b0bdb1c43ef0f633ad5/pkl/vehicleid/index.pkl
--------------------------------------------------------------------------------
/pkl/veri/cids.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adhirajghosh/RPTM_reid/183e1f77a0979ab2ffa08b0bdb1c43ef0f633ad5/pkl/veri/cids.pkl
--------------------------------------------------------------------------------
/pkl/veri/data.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adhirajghosh/RPTM_reid/183e1f77a0979ab2ffa08b0bdb1c43ef0f633ad5/pkl/veri/data.pkl
--------------------------------------------------------------------------------
/pkl/veri/index_vp.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adhirajghosh/RPTM_reid/183e1f77a0979ab2ffa08b0bdb1c43ef0f633ad5/pkl/veri/index_vp.pkl
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | timm
2 | yacs
3 | opencv-python
4 | albumentations
5 | matplotlib
6 | umap-learn
7 | Pillow
8 | pretrainedmodels
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | from __future__ import division
3 | import os
4 | import os.path as osp
5 | import time
6 | import torch
7 | import numpy as np
8 | import numpy.ma as ma
9 | import random
10 | try:
11 | from apex.fp16_utils import *
12 | from apex import amp, optimizers
13 | except ImportError:  # apex is optional; only needed for fp16 (mixed precision) training
14 |     print('This is not an error. To train in low precision (fp16), install apex with CUDA support (https://github.com/NVIDIA/apex) and use PyTorch >= 1.0.')
15 |
16 | from eval import do_test
17 | from utils.loggers import RankLogger
18 | from utils.torchtools import accuracy, save_checkpoint
19 | from utils.functions import search, strint
20 | from utils.avgmeter import AverageMeter
21 | from utils.visualtools import visualize_ranked_results
22 |
23 |
24 | def do_train(cfg, trainloader, train_dict, data_tfr, testloader_dict, dm,
25 | model, optimizer, scheduler, criterion_htri,criterion_xent):
26 | ranklogger = RankLogger(cfg.DATASET.SOURCE_NAME, cfg.DATASET.TARGET_NAME)
27 | gms = train_dict['gms']
28 | pidx = train_dict['pidx']
29 | folders = []
30 | for fld in os.listdir(cfg.DATASET.SPLIT_DIR):
31 | folders.append(fld)
32 | # data_index = search_index(gms, cfg.DATASET.SPLIT_DIR, folders)
33 | data_index = search(cfg.DATASET.SPLIT_DIR)
34 |
35 | for epoch in range(cfg.SOLVER.MAX_EPOCHS):
36 | losses = AverageMeter()
37 | xent_losses = AverageMeter()
38 | htri_losses = AverageMeter()
39 | accs = AverageMeter()
40 | batch_time = AverageMeter()
41 |
42 | model.train()
43 | for p in model.parameters():
44 | p.requires_grad = True # open all layers
45 |
46 | end = time.time()
47 | for batch_idx, (img, label, index, pid, _) in enumerate(trainloader):
48 |
49 | trainX, trainY = torch.zeros((cfg.SOLVER.TRAIN_BATCH_SIZE * 3, 3, cfg.INPUT.HEIGHT, cfg.INPUT.WIDTH), dtype=torch.float32), torch.zeros(
50 | (cfg.SOLVER.TRAIN_BATCH_SIZE * 3), dtype=torch.int64)
51 |
52 | for i in range(cfg.SOLVER.TRAIN_BATCH_SIZE):
53 |
54 | labelx = str(label[i])
55 | # print(labelx)
56 | indexx = int(index[i])
57 | cidx = int(pid[i])
58 | if indexx > len(gms[labelx]) - 1:
59 | indexx = len(gms[labelx]) - 1
60 | a = gms[labelx][indexx]
61 |
62 | if cfg.MODEL.RPTM_SELECT == 'min':
63 | threshold = np.arange(10)
64 | elif cfg.MODEL.RPTM_SELECT == 'mean':
65 | threshold = np.arange(np.amax(gms[labelx][indexx])//2)
66 | elif cfg.MODEL.RPTM_SELECT == 'max':
67 | threshold = np.arange(np.amax(gms[labelx][indexx]))
68 | else:
69 | threshold = np.arange(np.amax(gms[labelx][indexx]) // 2) #defaults to mean
70 |
71 | minpos = np.argmin(ma.masked_where(a == threshold, a))
72 | pos_dic = data_tfr[data_index[cidx][1] + minpos]
73 | # print(pos_dic[1])
74 |                 # sample a random negative identity whose split folder exists on disk
75 |                 while True:
76 |                     neg_label = random.choice(range(1, 770))
77 |                     if neg_label != int(labelx) and os.path.isdir(
78 |                             os.path.join(cfg.DATASET.SPLIT_DIR, strint(neg_label, 'veri'))):
79 |                         break
80 | negative_label = strint(neg_label, 'veri')
81 | neg_cid = pidx[negative_label]
82 | neg_index = random.choice(range(0, len(gms[negative_label])))
83 |
84 | neg_dic = data_tfr[data_index[neg_cid][1] + neg_index]
85 | trainX[i] = img[i]
86 | trainX[i + cfg.SOLVER.TRAIN_BATCH_SIZE] = pos_dic[0]
87 | trainX[i + (cfg.SOLVER.TRAIN_BATCH_SIZE * 2)] = neg_dic[0]
88 | trainY[i] = cidx
89 | trainY[i + cfg.SOLVER.TRAIN_BATCH_SIZE] = pos_dic[3]
90 | trainY[i + (cfg.SOLVER.TRAIN_BATCH_SIZE * 2)] = neg_dic[3]
91 | optimizer.zero_grad()
92 | trainX = trainX.cuda()
93 | trainY = trainY.cuda()
94 | outputs, features = model(trainX)
95 | xent_loss = criterion_xent(outputs[0:cfg.SOLVER.TRAIN_BATCH_SIZE], trainY[0:cfg.SOLVER.TRAIN_BATCH_SIZE])
96 | htri_loss = criterion_htri(features, trainY)
97 |
98 |
99 | loss = cfg.LOSS.LAMBDA_HTRI * htri_loss + cfg.LOSS.LAMBDA_XENT * xent_loss
100 |
101 | if cfg.SOLVER.USE_AMP:
102 | with amp.scale_loss(loss, optimizer) as scaled_loss:
103 | scaled_loss.backward()
104 | else:
105 | loss.backward()
106 |
107 | optimizer.step()
108 | for param_group in optimizer.param_groups:
109 | # print(param_group['lr'] )
110 | lrrr = str(param_group['lr'])
111 |
112 | batch_time.update(time.time() - end)
113 | losses.update(loss.item(), trainY.size(0))
114 | htri_losses.update(htri_loss.item(), trainY.size(0))
115 | accs.update(accuracy(outputs[0:cfg.SOLVER.TRAIN_BATCH_SIZE], trainY[0:cfg.SOLVER.TRAIN_BATCH_SIZE])[0])
116 |
117 | if (batch_idx) % cfg.MISC.PRINT_FREQ == 0:
118 | print('Train ', end=" ")
119 | print('Epoch: [{0}][{1}/{2}]\t'
120 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
121 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
122 | 'Acc {acc.val:.2f} ({acc.avg:.2f})\t'
123 | 'lr {lrrr} \t'.format(
124 | epoch + 1, batch_idx + 1, len(trainloader),
125 | batch_time=batch_time,
126 | loss=losses,
127 | acc=accs,
128 | lrrr=lrrr,
129 | ))
130 |
131 | end = time.time()
132 |
133 | scheduler.step()
134 | print('=> Test')
135 |
136 | for name in cfg.DATASET.TARGET_NAME:
137 | print('Evaluating {} ...'.format(name))
138 | queryloader = testloader_dict[name]['query']
139 | galleryloader = testloader_dict[name]['gallery']
140 | rank1, distmat, rank2, distmat_re = do_test(model, queryloader, galleryloader, cfg.TEST.TEST_BATCH_SIZE, cfg.MISC.USE_GPU, cfg.DATASET.TARGET_NAME[0])
141 |
142 | ranklogger.write(name, epoch + 1, rank1)
143 | ranklogger.write(name, epoch + 1, rank2)
144 |
145 |             if (epoch + 1) == cfg.SOLVER.MAX_EPOCHS and cfg.TEST.VIS_RANK:
146 | visualize_ranked_results(
147 | distmat_re, dm.return_testdataset_by_name(name),
148 | save_dir=osp.join(cfg.MISC.SAVE_DIR, 'ranked_results', name),
149 | topk=20)
150 |
151 | del queryloader
152 | del galleryloader
153 | del distmat
154 | # print(torch.cuda.memory_allocated(),torch.cuda.memory_cached())
155 | torch.cuda.empty_cache()
156 |
157 | if (epoch + 1) == cfg.SOLVER.MAX_EPOCHS:
158 | save_checkpoint({
159 | 'state_dict': model.state_dict(),
160 | 'rank1': rank2,
161 | 'epoch': epoch + 1,
162 | 'arch': cfg.MODEL.ARCH,
163 | 'optimizer': optimizer.state_dict(),
164 | }, cfg.MISC.SAVE_DIR, cfg.SOLVER.OPTIMIZER_NAME)
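165 |
166 |
167 | # --- Illustration only (not called by do_train) ---
168 | # The loop above packs each batch as three stacked blocks of TRAIN_BATCH_SIZE
169 | # samples: anchors, GMS-mined positives, random negatives. Cross-entropy is
170 | # applied to the anchor block only, while the triplet loss sees all
171 | # 3 * TRAIN_BATCH_SIZE samples.
172 | def _triplet_batch_slices(batch_size):
173 |     """Return (anchor, positive, negative) index slices of a packed batch."""
174 |     return (slice(0, batch_size),
175 |             slice(batch_size, 2 * batch_size),
176 |             slice(2 * batch_size, 3 * batch_size))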
--------------------------------------------------------------------------------
/utils/avgmeter.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 |
4 |
5 | class AverageMeter(object):
6 | """Computes and stores the average and current value.
7 |
8 | Code imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262
9 | """
10 |
11 | def __init__(self):
12 | self.reset()
13 |
14 | def reset(self):
15 | self.val = 0
16 | self.avg = 0
17 | self.sum = 0
18 | self.count = 0
19 |
20 | def update(self, val, n=1):
21 | self.val = val
22 | self.sum += val * n
23 | self.count += n
24 | self.avg = self.sum / self.count
25 |
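26 | # Usage sketch (not part of the original file): track a running loss, where
27 | # n is the number of samples the reported value was averaged over.
28 | if __name__ == '__main__':
29 |     losses = AverageMeter()
30 |     losses.update(0.9, n=32)
31 |     losses.update(0.7, n=32)
32 |     print(losses.val, losses.avg)  # 0.7 and 0.8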
--------------------------------------------------------------------------------
/utils/create_gms_index.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | import re
4 | import xml.etree.ElementTree as ET
5 | from collections import defaultdict
6 |
7 | import cv2
8 | import numpy as np
9 | from tqdm import tqdm
10 |
11 | def str2int(car_id_num: str, dataset: str):
12 | if dataset == 'veri':
13 | if len(car_id_num) == 1:
14 | car_id_num = '00' + str(car_id_num)
15 | elif len(car_id_num) == 2:
16 | car_id_num = '0' + str(car_id_num)
17 | else:
18 | pass
19 | elif dataset == 'vehicleid' or dataset == 'veriwild':
20 | if len(car_id_num) == 1:
21 | car_id_num = '0000' + car_id_num
22 | elif len(car_id_num) == 2:
23 | car_id_num = '000' + car_id_num
24 | elif len(car_id_num) == 3:
25 | car_id_num = '00' + car_id_num
26 | elif len(car_id_num) == 4:
27 | car_id_num = '0' + car_id_num
28 | else:
29 | pass
30 | else:
31 | raise ValueError(f"Unknown dataset: {dataset}")
32 | return car_id_num
33 |
34 | def compute_gms_matches(orb: cv2.ORB, bf: cv2.BFMatcher, img1: np.ndarray, img2: np.ndarray, verbose: bool = False):
35 | # Detect and compute keypoints and descriptors
36 | kp1, des1 = orb.detectAndCompute(img1, None)
37 | kp2, des2 = orb.detectAndCompute(img2, None)
38 |
39 | # Check if descriptors were found
40 | if des1 is None or des2 is None or len(des1) == 0 or len(des2) == 0:
41 | if verbose:
42 | print(f"Warning: No descriptors found for one of the images. Returning 0 matches.")
43 | return 0
44 |
45 | if des1.shape[1] != des2.shape[1]:
46 | if verbose:
47 | print(f"Error: Descriptor sizes don't match. Cannot proceed with matching.")
48 | return 0
49 |
50 | # Convert des1 and des2 to have the same type
51 | # Fixes: cv2.error: OpenCV(4.10.0) /io/opencv/modules/core/src/batch_distance.cpp:274: error: (-215:Assertion failed) type == src2.type() && src1.cols == src2.cols && (type == CV_32F || type == CV_8U) in function 'batchDistance'
52 |     # Ensure the descriptor dtype is one BFMatcher accepts (uint8 or float32)
53 |     if des1.dtype not in (np.uint8, np.float32):
54 |         if verbose:
55 |             print("Warning: Converting descriptors to np.uint8.")
56 |         des1 = des1.astype(np.uint8)
57 |     if des2.dtype not in (np.uint8, np.float32):
58 |         if verbose:
59 |             print("Warning: Converting descriptors to np.uint8.")
60 |         des2 = des2.astype(np.uint8)
61 |
62 | # Perform initial matching
63 | matches = bf.match(des1, des2)
64 |
65 | # Apply GMS matching
66 | gms_matches = cv2.xfeatures2d.matchGMS(size1=img1.shape[:2], size2=img2.shape[:2],
67 | keypoints1=kp1, keypoints2=kp2,
68 | matches1to2=matches, withRotation=True)
69 | return len(gms_matches)
70 |
71 | def process_class(image_paths: list, image_size: tuple = (224, 224), verbose: bool = False):
72 | n = len(image_paths)
73 | width, height = image_size
74 | adj_matrix = np.zeros((n, n), dtype=np.int32) # Initialize the adjacency matrix
75 |
76 | # Iterate over all the images
77 | for i in range(n):
78 | # Read and resize, as per paper
79 | img1 = cv2.imread(image_paths[i], cv2.IMREAD_GRAYSCALE)
80 | img1 = cv2.resize(img1, (width, height))
81 | img1 = img1.astype(np.uint8)
82 |
83 | # Only iterate over j > i
84 | for j in range(i + 1, n):
85 | # Read and resize, as per paper
86 | img2 = cv2.imread(image_paths[j], cv2.IMREAD_GRAYSCALE)
87 | img2 = cv2.resize(img2, (width, height))
88 | img2 = img2.astype(np.uint8)
89 |
90 | # Compute GMS matches
91 | matches = compute_gms_matches(orb, bf, img1, img2, verbose)
92 |
93 | # Set both (i,j) and (j,i) at once
94 | adj_matrix[i, j] = matches
95 | adj_matrix[j, i] = matches
96 |
97 | pbar.update(1)
98 |
99 | return adj_matrix
100 |
101 | def get_dict(dataset: str, train_file: str, img_dir: str):
102 | class_images = defaultdict(list)
103 | original_to_new_id = {}
104 | new_id_counter = 1
105 |
106 | # Read query file names
107 | if (dataset == 'veri'):
108 | # Open the file with the correct encoding
109 | with open(train_file, 'r', encoding='gb2312') as file:
110 | xml_content = file.read()
111 |
112 | # Parse the XML string
113 | root = ET.fromstring(xml_content)
114 |
115 | # Iterate through each Item element
116 | for item in root.findall('.//Item'):
117 | vehicle_id_str = item.get('vehicleID')
118 |
119 | car_id_num = str(int(re.search(r'\d+', vehicle_id_str).group()))
120 | car_id_num = str2int(car_id_num, dataset)
121 |
122 | full_image_path = os.path.join(img_dir, item.get('imageName'))
123 |
124 | class_images[car_id_num].append(full_image_path)
125 | elif (dataset == 'veriwild'):
126 | with open(train_file, 'r') as file:
127 | lines = [line.strip().split(' ') for line in file.readlines()]
128 |
129 |         # Iterate through each line of the split file
130 | for line in tqdm(lines, desc='Splitting train images'):
131 | vehicle_id = line[0].split('/')[0]
132 | image_name = line[0].split('/')[1]
133 | full_image_path = os.path.join(img_dir, vehicle_id, image_name)
134 |
135 | # Here we map the original vehicle ID to a new ID
136 | if vehicle_id not in original_to_new_id:
137 | new_id = str2int(str(new_id_counter), dataset)
138 |
139 | original_to_new_id[vehicle_id] = new_id
140 | new_id_counter += 1
141 |
142 | new_vehicle_id = original_to_new_id[vehicle_id]
143 | class_images[new_vehicle_id].append(full_image_path)
144 | elif (dataset == 'vehicleid'):
145 | with open(train_file, 'r') as file:
146 | lines = [line.strip() for line in file.readlines()]
147 |
148 |         # Iterate through each line of the split file
149 | for line in tqdm(lines, desc='Splitting train images'):
150 | image_name = line.split(' ')[0]
151 | vehicle_id = line.split(' ')[1]
152 | full_image_path = os.path.join(img_dir, image_name + '.jpg')
153 |
154 | # Here we map the original vehicle ID to a new ID
155 | if vehicle_id not in original_to_new_id:
156 | new_id = str2int(str(new_id_counter), dataset)
157 |
158 | original_to_new_id[vehicle_id] = new_id
159 | new_id_counter += 1
160 |
161 | new_vehicle_id = original_to_new_id[vehicle_id]
162 | class_images[new_vehicle_id].append(full_image_path)
163 | else:
164 | raise ValueError(f"Unknown dataset: {dataset}")
165 |
166 | # Return both the dictionary and the mapping from original to new IDs
167 | return class_images, original_to_new_id
168 |
169 | # ========================== MAIN ========================== #
170 | # Set up paths
171 | dataset = 'veri' # 'veri' (Which is: VeRi-776) / 'veriwild' / 'vehicleid'
172 | base_datapath = 'data'
173 | gms_path = 'gms'
174 | image_size = (224, 224) # Before computing GMS matches, resize the images to this size (as per paper)
175 | verbose = False # Set to True to see more detailed output, errors etc.
176 |
177 | if (dataset == 'veri'):
178 | data_path = os.path.join(base_datapath, 'veri')
179 | img_dir = os.path.join(data_path, 'image_train')
180 | train_file = os.path.join(data_path, 'train_label.xml')
181 | elif (dataset == 'veriwild'):
182 | data_path = os.path.join(base_datapath, 'veriwild')
183 | img_dir = os.path.join(data_path, 'images')
184 | train_file = os.path.join(data_path, 'train_test_split', 'train_list_start0.txt')
185 | elif (dataset == 'vehicleid'):
186 | data_path = os.path.join(base_datapath, 'vehicleid')
187 | img_dir = os.path.join(data_path, 'image')
188 | train_file = os.path.join(data_path, 'train_test_split', 'train_list.txt')
189 | else:
190 | raise ValueError(f"Unknown dataset: {dataset}")
191 |
192 | output = os.path.join(gms_path, dataset)
193 | if (os.path.exists(output) == False):
194 | os.makedirs(output)
195 | if verbose:
196 | print(f"Output directory created at: {output}")
197 |
198 | # Instantiate the ORB and BFMatcher objects
199 | orb = cv2.ORB_create(nfeatures = 10000, fastThreshold = 0)
200 | bf = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck = False)
201 |
202 | # Get the dictionary of class images
203 | # It will contain keys as class labels and values as lists of image paths
204 | # Example:
205 | # {
206 | # '1': ['/190472.jpg', '/134671.jpg', ...],
207 | # '2': ['/134718.jpg', '/824511.jpg', ...],
208 | # ...
209 | # }
210 | class_images, id_mapping = get_dict(dataset, train_file, img_dir)
211 |
212 | # # In case you want to filter the dictionary to start from a certain class (resuming from a checkpoint, basically)
213 | # resuming_class = 5344
214 | # class_images = {k: v for k, v in class_images.items() if int(k) >= resuming_class}
215 |
216 | # Create the index_vp.pkl file
217 | # dict_index should contain the name of the images as keys, and a tuple (class_label, counter) as values
218 | dict_index = {os.path.basename(image): (class_label, counter)
219 | for class_label, images in class_images.items()
220 | for counter, image in enumerate(images)}
221 |
222 | with open(os.path.join(output, f'index_vp_{dataset}.pkl'), 'wb') as f:
223 | pickle.dump(dict_index, f)
224 | if verbose:
225 | print("Successfully saved the Index Pickle file.")
226 |
227 | # Get how many iterations are needed (for tqdm)
228 | total_iterations = sum([len(images) for images in class_images.values()])
229 |
230 | # Process each class
231 | with tqdm(total=total_iterations, desc="Processing pickle files") as pbar:
232 | for class_label, images in class_images.items():
233 | print(f"Processing class {class_label} with {len(images)} images")
234 | adj_matrix = process_class(images, image_size=image_size, verbose=verbose)
235 |
236 | # Save the adjacency matrix
237 | with open(os.path.join(output, f'{class_label}.pkl'), 'wb') as f:
238 | pickle.dump(adj_matrix, f)
239 |
240 | print("Processing complete. Adjacency matrices saved.")
241 | # ========================================================== #
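242 |
243 | # Example of a single pairwise call (commented out so the script's behaviour
244 | # is unchanged; 'a.jpg'/'b.jpg' are placeholder paths):
245 | #
246 | #   img_a = cv2.resize(cv2.imread('a.jpg', cv2.IMREAD_GRAYSCALE), image_size)
247 | #   img_b = cv2.resize(cv2.imread('b.jpg', cv2.IMREAD_GRAYSCALE), image_size)
248 | #   n_matches = compute_gms_matches(orb, bf, img_a, img_b)
249 | #   # n_matches is the (i, j) entry process_class() writes into adj_matrix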
--------------------------------------------------------------------------------
/utils/evaluation.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import print_function
3 | from __future__ import division
4 |
5 | import numpy as np
6 |
7 |
8 | def evaluate(distmat, q_pids, g_pids, q_camids, g_camids, max_rank):
9 | """Evaluation with veri metric
10 | Key: for each query identity, its gallery images from the same camera view are discarded.
11 | """
12 | num_q, num_g = distmat.shape
13 |
14 | if num_g < max_rank:
15 | max_rank = num_g
16 | print('Note: number of gallery samples is quite small, got {}'.format(num_g))
17 |
18 | indices = np.argsort(distmat, axis=1)
19 | matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)
20 |
21 | # compute cmc curve for each query
22 | all_cmc = []
23 | all_AP = []
24 | num_valid_q = 0. # number of valid query
25 |
26 | for q_idx in range(num_q):
27 | # get query pid and camid
28 | q_pid = q_pids[q_idx]
29 | q_camid = q_camids[q_idx]
30 |
31 | # remove gallery samples that have the same pid and camid with query
32 | order = indices[q_idx]
33 | remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid)
34 | keep = np.invert(remove)
35 |
36 | # compute cmc curve
37 | raw_cmc = matches[q_idx][keep] # binary vector, positions with value 1 are correct matches
38 | if not np.any(raw_cmc):
39 | # this condition is true when query identity does not appear in gallery
40 | continue
41 |
42 | cmc = raw_cmc.cumsum()
43 | cmc[cmc > 1] = 1
44 |
45 | all_cmc.append(cmc[:max_rank])
46 | num_valid_q += 1.
47 |
48 | # compute average precision
49 | # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision
50 | num_rel = raw_cmc.sum()
51 | tmp_cmc = raw_cmc.cumsum()
52 | tmp_cmc = [x / (i + 1.) for i, x in enumerate(tmp_cmc)]
53 | tmp_cmc = np.asarray(tmp_cmc) * raw_cmc
54 | AP = tmp_cmc.sum() / num_rel
55 | all_AP.append(AP)
56 |
57 | #assert num_valid_q > 0, 'Error: all query identities do not appear in gallery'
58 |
59 | all_cmc = np.asarray(all_cmc).astype(np.float32)
60 | all_cmc = all_cmc.sum(0) / num_valid_q
61 | #mAP = np.amax(all_AP)
62 | mAP = np.mean(all_AP)
63 | #mAP = all_AP
64 | return all_cmc, mAP
65 |
66 | def evaluate_vid(distmat, q_pids, g_pids, q_camids, g_camids, max_rank):
67 | """Evaluation with vehicleid metric
68 | Key: gallery contains one images for each test vehicles and the other images in test
69 | use as query
70 | """
71 | num_q, num_g = distmat.shape
72 |
73 | if num_g < max_rank:
74 | max_rank = num_g
75 | print('Note: number of gallery samples is quite small, got {}'.format(num_g))
76 |
77 | indices = np.argsort(distmat, axis=1)
78 | matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)
79 |
80 | # compute cmc curve for each query
81 | all_cmc = []
82 | all_AP = []
83 | num_valid_q = 0. # number of valid query
84 |
85 | for q_idx in range(num_q):
86 | # get query pid and camid
87 | # remove gallery samples that have the same pid and camid with query
88 | '''
89 | q_pid = q_pids[q_idx]
90 | q_camid = q_camids[q_idx]
91 | order = indices[q_idx]
92 | remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid) # original remove
93 | '''
94 |         remove = False  # without camid information, no gallery images are removed
95 | keep = np.invert(remove)
96 | # compute cmc curve
97 | raw_cmc = matches[q_idx][keep] # binary vector, positions with value 1 are correct matches
98 | if not np.any(raw_cmc):
99 | # this condition is true when query identity does not appear in gallery
100 | continue
101 |
102 | cmc = raw_cmc.cumsum()
103 | cmc[cmc > 1] = 1
104 |
105 | all_cmc.append(cmc[:max_rank])
106 | num_valid_q += 1.
107 |
108 | # compute average precision
109 | # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision
110 | num_rel = raw_cmc.sum()
111 | tmp_cmc = raw_cmc.cumsum()
112 | tmp_cmc = [x / (i + 1.) for i, x in enumerate(tmp_cmc)]
113 | tmp_cmc = np.asarray(tmp_cmc) * raw_cmc
114 | AP = tmp_cmc.sum() / num_rel
115 | all_AP.append(AP)
116 |
117 | assert num_valid_q > 0, 'Error: all query identities do not appear in gallery'
118 |
119 | all_cmc = np.asarray(all_cmc).astype(np.float32)
120 | all_cmc = all_cmc.sum(0) / num_valid_q
121 | mAP = np.mean(all_AP)
122 |
123 | return all_cmc, mAP
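124 |
125 |
126 | # Usage sketch (not part of the original file): both evaluators take a
127 | # (num_query, num_gallery) distance matrix plus per-image pids/camids, with
128 | # smaller distance meaning a better match. Random data, just to show the API.
129 | if __name__ == '__main__':
130 |     rng = np.random.RandomState(0)
131 |     distmat = rng.rand(5, 20)
132 |     q_pids, g_pids = rng.randint(0, 4, 5), rng.randint(0, 4, 20)
133 |     q_camids, g_camids = rng.randint(0, 2, 5), rng.randint(0, 2, 20)
134 |     cmc, mAP = evaluate(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=10)
135 |     print('rank-1: {:.1%}  mAP: {:.1%}'.format(cmc[0], mAP))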
--------------------------------------------------------------------------------
/utils/functions.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 |
4 | def keyfromval(dic, val):
5 | return list(dic.keys())[list(dic.values()).index(val)]
6 |
7 | def strint(x, dataset):
8 | if dataset =='veri':
9 |
10 | if len(str(x))==1:
11 | return '00'+str(x)
12 | if len(str(x))==2:
13 | return '0'+str(x)
14 | if len(str(x))==3:
15 | return str(x)
16 |
17 | if dataset == 'duke':
18 | if len(str(x))==1:
19 | return '000'+str(x)
20 | if len(str(x))==2:
21 | return '00'+str(x)
22 | if len(str(x))==3:
23 | return '0'+str(x)
24 | if len(str(x))==4:
25 | return str(x)
26 |
27 | if dataset == 'vehicleid':
28 | if len(str(x))==1:
29 | return '0000'+str(x)
30 | if len(str(x))==2:
31 | return '000'+str(x)
32 | if len(str(x))==3:
33 | return '00'+str(x)
34 | if len(str(x))==4:
35 | return '0'+str(x)
36 | if len(str(x))==5:
37 | return str(x)
38 |
39 | def search1(pkl, path):
40 | #MAIN ONE
41 | start = 0
42 | count = 0
43 | end = 0
44 | data_index = []
45 | for i in range(1, 777):
46 |         label = strint(i, 'veri')  # search1 is VeRi-specific (IDs 1-776)
47 | if os.path.isdir(os.path.join(path, label)) is True:
48 | size = len(pkl[label])
49 | start = end
50 | end = end+size
51 | data_index.append((count, start, end-1))
52 | count+=1
53 | if label == '769':
54 | size = len(pkl[label])
55 | start = end
56 | end = end+size
57 | data_index.append((count, start, end-1))
58 | break
59 | return data_index
60 |
61 | def search(path):
62 | #MAIN ONE
63 | start = 0
64 | count = 0
65 | end = 0
66 | data_index = []
67 | for i in sorted(os.listdir(path)):
68 | x = len(os.listdir(os.path.join(path,i)))
69 | data_index.append((count, start, start+x-1))
70 | count = count+1
71 | start = start+x
72 | return data_index
73 |
74 | def search_index(pkl, path, folders):
75 | start = 0
76 | count = 0
77 | end = 0
78 | data_index = []
79 | for i in range(0, len(folders)):
80 | label = folders[i]
81 | size = len(pkl[label])
82 | start = end
83 | end = end+size
84 | data_index.append((count, start, end-1))
85 | count+=1
86 | return data_index
87 |
88 | def create_split_dirs(cfg):
89 | src_root = cfg.DATASET.TRAIN_DIR
90 | dest_root = cfg.DATASET.SPLIT_DIR
91 | if cfg.DATASET.SOURCE_NAME[0] == 'vehicleid':
92 | if os.path.exists(os.path.join(cfg.DATASET.ROOT_DIR, 'vehicleid/images/')):
93 | return
94 | for i in os.listdir(src_root):
95 | if cfg.DATASET.SOURCE_NAME[0] == 'veri':
96 | folder_name = i.split('_', 2)[0][1:]
97 | else:
98 | folder_name = i.split('_', 2)[0]
99 | if not os.path.exists(os.path.join(dest_root, folder_name)):
100 | os.makedirs(os.path.join(dest_root, folder_name))
101 | shutil.copyfile(os.path.join(src_root, i), os.path.join(dest_root, folder_name, i))
102 |
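103 |
104 | # Quick illustration (not part of the original file): strint zero-pads an id
105 | # to each dataset's folder-name width; search() walks a split directory and
106 | # returns (class_index, first_image_index, last_image_index) tuples over the
107 | # flat, sorted image ordering that the training loop indexes into.
108 | if __name__ == '__main__':
109 |     print(strint(7, 'veri'))        # '007'
110 |     print(strint(42, 'vehicleid'))  # '00042'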
--------------------------------------------------------------------------------
/utils/generaltools.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import print_function
3 | from __future__ import division
4 |
5 | import random
6 | import numpy as np
7 | import torch
8 |
9 |
10 | def set_random_seed(seed):
11 | random.seed(seed)
12 | np.random.seed(seed)
13 | torch.manual_seed(seed)
14 | torch.cuda.manual_seed_all(seed)
15 |
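16 | # Note (not part of the original file): set_random_seed seeds Python, NumPy
17 | # and all CUDA devices. For stricter run-to-run reproducibility one would
18 | # typically also pin cuDNN:
19 | #
20 | #   torch.backends.cudnn.deterministic = True
21 | #   torch.backends.cudnn.benchmark = False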
--------------------------------------------------------------------------------
/utils/iotools.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | import os
4 | import os.path as osp
5 | import errno
6 | import json
7 | import warnings
8 |
9 |
10 | def mkdir_if_missing(directory):
11 | if not osp.exists(directory):
12 | try:
13 | os.makedirs(directory)
14 | except OSError as e:
15 | if e.errno != errno.EEXIST:
16 | raise
17 |
18 |
19 | def check_isfile(path):
20 | isfile = osp.isfile(path)
21 | if not isfile:
22 | warnings.warn('No file found at "{}"'.format(path))
23 | return isfile
24 |
25 |
26 | def read_json(fpath):
27 | with open(fpath, 'r') as f:
28 | obj = json.load(f)
29 | return obj
30 |
31 |
32 | def write_json(obj, fpath):
33 | mkdir_if_missing(osp.dirname(fpath))
34 | with open(fpath, 'w') as f:
35 | json.dump(obj, f, indent=4, separators=(',', ': '))
36 |
--------------------------------------------------------------------------------
/utils/kwargs.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | from __future__ import division
3 |
4 | def return_kwargs(cfg):
5 |     if cfg.DATASET.SOURCE_NAME[0] == 'vehicleid':  # NOTE: both branches currently build identical kwargs
6 | dataset_kwargs = {
7 | 'source_names': cfg.DATASET.SOURCE_NAME,
8 | 'target_names': cfg.DATASET.TARGET_NAME,
9 | 'root': cfg.DATASET.ROOT_DIR,
10 | 'height': cfg.INPUT.HEIGHT,
11 | 'width': cfg.INPUT.WIDTH,
12 | 'test_size': cfg.TEST.TEST_SIZE,
13 | 'train_batch_size': cfg.SOLVER.TRAIN_BATCH_SIZE,
14 | 'test_batch_size': cfg.TEST.TEST_BATCH_SIZE,
15 | 'train_sampler': cfg.DATALOADER.SAMPLER,
16 | 'random_erase': cfg.INPUT.RANDOM_ERASE,
17 | 'color_jitter': cfg.INPUT.JITTER,
18 | 'color_aug': cfg.INPUT.AUG
19 | }
20 | else:
21 | dataset_kwargs = {
22 | 'source_names': cfg.DATASET.SOURCE_NAME,
23 | 'target_names': cfg.DATASET.TARGET_NAME,
24 | 'root': cfg.DATASET.ROOT_DIR,
25 | 'height': cfg.INPUT.HEIGHT,
26 | 'width': cfg.INPUT.WIDTH,
27 | 'test_size': cfg.TEST.TEST_SIZE,
28 | 'train_batch_size': cfg.SOLVER.TRAIN_BATCH_SIZE,
29 | 'test_batch_size': cfg.TEST.TEST_BATCH_SIZE,
30 | 'train_sampler': cfg.DATALOADER.SAMPLER,
31 | 'random_erase': cfg.INPUT.RANDOM_ERASE,
32 | 'color_jitter': cfg.INPUT.JITTER,
33 | 'color_aug': cfg.INPUT.AUG
34 | }
35 |
36 | transform_kwargs = {
37 | 'height': cfg.INPUT.HEIGHT,
38 | 'width': cfg.INPUT.WIDTH,
39 | 'random_erase': cfg.INPUT.RANDOM_ERASE,
40 | 'color_jitter': cfg.INPUT.JITTER,
41 | 'color_aug': cfg.INPUT.AUG
42 | }
43 |
44 | optimizer_kwargs = {
45 | 'optim': cfg.SOLVER.OPTIMIZER_NAME,
46 | 'lr': cfg.SOLVER.BASE_LR,
47 | 'weight_decay': cfg.SOLVER.WEIGHT_DECAY,
48 | 'momentum': cfg.SOLVER.MOMENTUM,
49 | 'sgd_dampening': cfg.SOLVER.SGD_DAMP,
50 | 'sgd_nesterov': cfg.SOLVER.NESTEROV
51 | }
52 |
53 | lr_scheduler_kwargs = {
54 | 'lr_scheduler': cfg.SOLVER.LR_SCHEDULER,
55 | 'stepsize': cfg.SOLVER.STEPSIZE,
56 | 'gamma': cfg.SOLVER.GAMMA
57 | }
58 |
59 | return dataset_kwargs, transform_kwargs, optimizer_kwargs, lr_scheduler_kwargs
--------------------------------------------------------------------------------
/utils/loggers.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | import sys
4 | import os
5 | import os.path as osp
6 |
7 | from .iotools import mkdir_if_missing
8 |
9 |
10 | class Logger(object):
11 | """
12 | Write console output to external text file.
13 | Code imported from https://github.com/Cysu/open-reid/blob/master/reid/utils/logging.py.
14 | """
15 | def __init__(self, fpath=None):
16 | self.console = sys.stdout
17 | self.file = None
18 | if fpath is not None:
19 | mkdir_if_missing(osp.dirname(fpath))
20 | self.file = open(fpath, 'w')
21 |
22 | def __del__(self):
23 | self.close()
24 |
25 | def __enter__(self):
26 |         return self  # so "with Logger(...) as log:" binds the logger
27 |
28 | def __exit__(self, *args):
29 | self.close()
30 |
31 | def write(self, msg):
32 | self.console.write(msg)
33 | if self.file is not None:
34 | self.file.write(msg)
35 |
36 | def flush(self):
37 | self.console.flush()
38 | if self.file is not None:
39 | self.file.flush()
40 | os.fsync(self.file.fileno())
41 |
42 | def close(self):
43 | self.console.close()
44 | if self.file is not None:
45 | self.file.close()
46 |
47 |
48 | class RankLogger(object):
49 | """
50 | RankLogger records the rank1 matching accuracy obtained for each
51 | test dataset at specified evaluation steps and provides a function
52 | to show the summarized results, which are convenient for analysis.
53 | Args:
54 | - source_names (list): list of strings (names) of source datasets.
55 | - target_names (list): list of strings (names) of target datasets.
56 | """
57 | def __init__(self, source_names, target_names):
58 | self.source_names = source_names
59 | self.target_names = target_names
60 | self.logger = {name: {'epoch': [], 'rank1': []} for name in self.target_names}
61 |
62 | def write(self, name, epoch, rank1):
63 | self.logger[name]['epoch'].append(epoch)
64 | self.logger[name]['rank1'].append(rank1)
65 |
66 | def show_summary(self):
67 | print('=> Show performance summary')
68 | for name in self.target_names:
69 | from_where = 'source' if name in self.source_names else 'target'
70 | print('{} ({})'.format(name, from_where))
71 | for epoch, rank1 in zip(self.logger[name]['epoch'], self.logger[name]['rank1']):
72 | print('- epoch {}\t rank1 {:.1%}'.format(epoch, rank1))
73 |
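74 |
75 | # Usage sketch (not part of the original file): Logger is normally installed
76 | # as a stdout tee, so every print() also lands in the log file.
77 | if __name__ == '__main__':
78 |     sys.stdout = Logger(osp.join('log', 'train.log'))
79 |     print('this goes to the console and to log/train.log')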
--------------------------------------------------------------------------------
/utils/mean_and_std.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def get_mean_and_std(dataloader, dataset):
5 | # Compute the mean and std value of dataset.
6 | mean = torch.zeros(3)
7 | std = torch.zeros(3)
8 | print('==> Computing mean and std..')
9 | for inputs, _, _ in dataloader:
10 | for i in range(3):
11 | mean[i] += inputs[:,i,:,:].mean()
12 | std[i] += inputs[:,i,:,:].std()
13 | mean.div_(len(dataset))
14 | std.div_(len(dataset))
15 | return mean, std
16 |
17 |
18 | def calculate_mean_and_std(dataset_loader, dataset_size):
19 | mean = torch.zeros(3)
20 | std = torch.zeros(3)
21 | for data in dataset_loader:
22 | now_batch_size, c, h, w = data[0].shape
23 | mean += torch.sum(torch.mean(torch.mean(data[0], dim=3), dim=2), dim=0)
24 | std += torch.sum(torch.std(data[0].view(now_batch_size, c, h * w), dim=2), dim=0)
25 | return mean/dataset_size, std/dataset_size
26 |
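27 | # Note (not part of the original file): calculate_mean_and_std sums per-image
28 | # channel statistics across batches, so dataset_size must be the number of
29 | # images, not the number of batches. A typical call:
30 | #
31 | #   mean, std = calculate_mean_and_std(train_loader, len(train_dataset))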
--------------------------------------------------------------------------------
/utils/reranking.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Created on Fri, 25 May 2018 20:29:09
5 |
6 | @author: luohao
7 | """
8 |
9 | """
10 | CVPR2017 paper: Zhong Z, Zheng L, Cao D, et al. Re-ranking Person Re-identification with k-reciprocal Encoding[J]. 2017.
11 | url: http://openaccess.thecvf.com/content_cvpr_2017/papers/Zhong_Re-Ranking_Person_Re-Identification_CVPR_2017_paper.pdf
12 | Matlab version: https://github.com/zhunzhong07/person-re-ranking
13 | """
14 |
15 | """
16 | API
17 |
18 | probFea: all feature vectors of the query set (torch tensor)
19 | galFea: all feature vectors of the gallery set (torch tensor)
20 | k1, k2, lambda: parameters; the original paper uses (k1=20, k2=6, lambda=0.3)
21 | MemorySave: set to 'True' when using MemorySave mode
22 | Minibatch: available when 'MemorySave' is 'True'
23 | """
24 |
25 | import numpy as np
26 | import torch
27 | from scipy.spatial.distance import cdist
28 |
29 | def re_ranking(probFea, galFea, k1, k2, lambda_value, local_distmat=None, only_local=False):
30 |     # if the feature vectors are numpy arrays, convert them to torch tensors first
31 | query_num = probFea.size(0)
32 | all_num = query_num + galFea.size(0)
33 | if only_local:
34 | original_dist = local_distmat
35 | else:
36 | feat = torch.cat([probFea, galFea])
37 | # print('using GPU to compute original distance')
38 | distmat = torch.pow(feat, 2).sum(dim=1, keepdim=True).expand(all_num, all_num) + \
39 | torch.pow(feat, 2).sum(dim=1, keepdim=True).expand(all_num, all_num).t()
40 |         distmat.addmm_(feat, feat.t(), beta=1, alpha=-2)  # keyword form; the positional (beta, alpha) signature is deprecated
41 | original_dist = distmat.cpu().numpy()
42 | del feat
43 | if not local_distmat is None:
44 | original_dist = original_dist + local_distmat
45 | gallery_num = original_dist.shape[0]
46 | original_dist = np.transpose(original_dist / np.max(original_dist, axis=0))
47 | V = np.zeros_like(original_dist).astype(np.float16)
48 | initial_rank = np.argsort(original_dist).astype(np.int32)
49 |
50 | # print('starting re_ranking')
51 | for i in range(all_num):
52 | # k-reciprocal neighbors
53 | forward_k_neigh_index = initial_rank[i, :k1 + 1]
54 | backward_k_neigh_index = initial_rank[forward_k_neigh_index, :k1 + 1]
55 | fi = np.where(backward_k_neigh_index == i)[0]
56 | k_reciprocal_index = forward_k_neigh_index[fi]
57 | k_reciprocal_expansion_index = k_reciprocal_index
58 | for j in range(len(k_reciprocal_index)):
59 | candidate = k_reciprocal_index[j]
60 | candidate_forward_k_neigh_index = initial_rank[candidate, :int(np.around(k1 / 2)) + 1]
61 | candidate_backward_k_neigh_index = initial_rank[candidate_forward_k_neigh_index,
62 | :int(np.around(k1 / 2)) + 1]
63 | fi_candidate = np.where(candidate_backward_k_neigh_index == candidate)[0]
64 | candidate_k_reciprocal_index = candidate_forward_k_neigh_index[fi_candidate]
65 | if len(np.intersect1d(candidate_k_reciprocal_index, k_reciprocal_index)) > 2 / 3 * len(
66 | candidate_k_reciprocal_index):
67 | k_reciprocal_expansion_index = np.append(k_reciprocal_expansion_index, candidate_k_reciprocal_index)
68 |
69 | k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index)
70 | weight = np.exp(-original_dist[i, k_reciprocal_expansion_index])
71 | V[i, k_reciprocal_expansion_index] = weight / np.sum(weight)
72 | original_dist = original_dist[:query_num, ]
73 | if k2 != 1:
74 | V_qe = np.zeros_like(V, dtype=np.float16)
75 | for i in range(all_num):
76 | V_qe[i, :] = np.mean(V[initial_rank[i, :k2], :], axis=0)
77 | V = V_qe
78 | del V_qe
79 | del initial_rank
80 | invIndex = []
81 | for i in range(gallery_num):
82 | invIndex.append(np.where(V[:, i] != 0)[0])
83 |
84 | jaccard_dist = np.zeros_like(original_dist, dtype=np.float16)
85 |
86 | for i in range(query_num):
87 | temp_min = np.zeros(shape=[1, gallery_num], dtype=np.float16)
88 | indNonZero = np.where(V[i, :] != 0)[0]
89 | indImages = [invIndex[ind] for ind in indNonZero]
90 | for j in range(len(indNonZero)):
91 | temp_min[0, indImages[j]] = temp_min[0, indImages[j]] + np.minimum(V[i, indNonZero[j]],
92 | V[indImages[j], indNonZero[j]])
93 | jaccard_dist[i] = 1 - temp_min / (2 - temp_min)
94 |
95 | final_dist = jaccard_dist * (1 - lambda_value) + original_dist * lambda_value
96 | del original_dist
97 | del V
98 | del jaccard_dist
99 | final_dist = final_dist[:query_num, query_num:]
100 | return final_dist
101 |
102 | def re_ranking_numpy(probFea, galFea, k1, k2, lambda_value, local_distmat=None, only_local=False):
103 | query_num = probFea.shape[0]
104 | all_num = query_num + galFea.shape[0]
105 | if only_local:
106 | original_dist = local_distmat
107 | else:
108 | q_g_dist = cdist(probFea, galFea)
109 | q_q_dist = cdist(probFea, probFea)
110 | g_g_dist = cdist(galFea, galFea)
111 | original_dist = np.concatenate(
112 | [np.concatenate([q_q_dist, q_g_dist], axis=1), np.concatenate([q_g_dist.T, g_g_dist], axis=1)],
113 | axis=0)
114 | original_dist = np.power(original_dist, 2).astype(np.float32)
115 | original_dist = np.transpose(1. * original_dist / np.max(original_dist, axis=0))
116 | if not local_distmat is None:
117 | original_dist = original_dist + local_distmat
118 | gallery_num = original_dist.shape[0]
119 | original_dist = np.transpose(original_dist / np.max(original_dist, axis=0))
120 | V = np.zeros_like(original_dist).astype(np.float16)
121 | initial_rank = np.argsort(original_dist).astype(np.int32)
122 |
123 | print('starting re_ranking')
124 | for i in range(all_num):
125 | # k-reciprocal neighbors
126 | forward_k_neigh_index = initial_rank[i, :k1 + 1]
127 | backward_k_neigh_index = initial_rank[forward_k_neigh_index, :k1 + 1]
128 | fi = np.where(backward_k_neigh_index == i)[0]
129 | k_reciprocal_index = forward_k_neigh_index[fi]
130 | k_reciprocal_expansion_index = k_reciprocal_index
131 | for j in range(len(k_reciprocal_index)):
132 | candidate = k_reciprocal_index[j]
133 | candidate_forward_k_neigh_index = initial_rank[candidate, :int(np.around(k1 / 2)) + 1]
134 | candidate_backward_k_neigh_index = initial_rank[candidate_forward_k_neigh_index,
135 | :int(np.around(k1 / 2)) + 1]
136 | fi_candidate = np.where(candidate_backward_k_neigh_index == candidate)[0]
137 | candidate_k_reciprocal_index = candidate_forward_k_neigh_index[fi_candidate]
138 | if len(np.intersect1d(candidate_k_reciprocal_index, k_reciprocal_index)) > 2 / 3 * len(
139 | candidate_k_reciprocal_index):
140 | k_reciprocal_expansion_index = np.append(k_reciprocal_expansion_index, candidate_k_reciprocal_index)
141 |
142 | k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index)
143 | weight = np.exp(-original_dist[i, k_reciprocal_expansion_index])
144 | V[i, k_reciprocal_expansion_index] = weight / np.sum(weight)
145 | original_dist = original_dist[:query_num, ]
146 | if k2 != 1:
147 | V_qe = np.zeros_like(V, dtype=np.float16)
148 | for i in range(all_num):
149 | V_qe[i, :] = np.mean(V[initial_rank[i, :k2], :], axis=0)
150 | V = V_qe
151 | del V_qe
152 | del initial_rank
153 | invIndex = []
154 | for i in range(gallery_num):
155 | invIndex.append(np.where(V[:, i] != 0)[0])
156 |
157 | jaccard_dist = np.zeros_like(original_dist, dtype=np.float16)
158 |
159 | for i in range(query_num):
160 | temp_min = np.zeros(shape=[1, gallery_num], dtype=np.float16)
161 | indNonZero = np.where(V[i, :] != 0)[0]
162 | indImages = [invIndex[ind] for ind in indNonZero]
163 | for j in range(len(indNonZero)):
164 | temp_min[0, indImages[j]] = temp_min[0, indImages[j]] + np.minimum(V[i, indNonZero[j]],
165 | V[indImages[j], indNonZero[j]])
166 | jaccard_dist[i] = 1 - temp_min / (2 - temp_min)
167 |
168 | final_dist = jaccard_dist * (1 - lambda_value) + original_dist * lambda_value
169 | del original_dist
170 | del V
171 | del jaccard_dist
172 | final_dist = final_dist[:query_num, query_num:]
173 | return final_dist
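174 |
175 | # Usage sketch (not part of the original file), with the defaults quoted in
176 | # the header (k1=20, k2=6, lambda=0.3). probFea/galFea are feature matrices;
177 | # the result is a (num_query, num_gallery) distance matrix to rank on.
178 | if __name__ == '__main__':
179 |     qf = torch.randn(4, 256)   # query features
180 |     gf = torch.randn(12, 256)  # gallery features
181 |     dist = re_ranking(qf, gf, k1=20, k2=6, lambda_value=0.3)
182 |     print(dist.shape)  # (4, 12)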
--------------------------------------------------------------------------------
/utils/torchtools.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import print_function
3 | from __future__ import division
4 |
5 | from collections import OrderedDict
6 | import shutil
7 | import warnings
8 | import os.path as osp
9 |
10 | import torch
11 | import torch.nn as nn
12 |
13 | from .iotools import mkdir_if_missing
14 |
15 |
16 | def save_checkpoint(state, save_dir, opt, is_best=False, remove_module_from_keys=False):
17 | mkdir_if_missing(save_dir)
18 | if remove_module_from_keys:
19 | # remove 'module.' in state_dict's keys
20 | state_dict = state['state_dict']
21 | new_state_dict = OrderedDict()
22 | for k, v in state_dict.items():
23 | if k.startswith('module.'):
24 | k = k[7:]
25 | new_state_dict[k] = v
26 | state['state_dict'] = new_state_dict
27 | # save
28 | epoch = state['epoch']
29 |
30 | arch = state['arch']
31 | fpath = osp.join(save_dir, 'model_' + arch+'_'+opt+ '_'+str(epoch)+'.pth.tar')
32 | torch.save(state, fpath)
33 | print('Checkpoint saved to "{}"'.format(fpath))
34 | if is_best:
35 | shutil.copy(fpath, osp.join(osp.dirname(fpath), 'best_model.pth.tar'))
36 |
37 |
38 | def resume_from_checkpoint(ckpt_path, model, optimizer=None):
39 | print('Loading checkpoint from "{}"'.format(ckpt_path))
40 | ckpt = torch.load(ckpt_path)
41 | model.load_state_dict(ckpt['state_dict'])
42 | print('Loaded model weights')
43 | if optimizer is not None:
44 | optimizer.load_state_dict(ckpt['optimizer'])
45 | print('Loaded optimizer')
46 | start_epoch = ckpt['epoch']
47 | print('** previous epoch = {}\t previous rank1 = {:.1%}'.format(start_epoch, ckpt['rank1']))
48 | return start_epoch
49 |
50 |
51 | def adjust_learning_rate(optimizer, base_lr, epoch, stepsize=20, gamma=0.1,
52 | linear_decay=False, final_lr=0, max_epoch=100):
53 | if linear_decay:
54 | # linearly decay learning rate from base_lr to final_lr
55 | frac_done = epoch / max_epoch
56 | lr = frac_done * final_lr + (1. - frac_done) * base_lr
57 | else:
58 | # decay learning rate by gamma for every stepsize
59 | lr = base_lr * (gamma ** (epoch // stepsize))
60 |
61 | for param_group in optimizer.param_groups:
62 | param_group['lr'] = lr
63 |
64 |
65 | def set_bn_to_eval(m):
66 | # 1. no update for running mean and var
67 | # 2. scale and shift parameters are still trainable
68 | classname = m.__class__.__name__
69 | if classname.find('BatchNorm') != -1:
70 | m.eval()
71 |
72 |
73 | def open_all_layers(model):
74 | """
75 | Open all layers in model for training.
76 | Args:
77 | - model (nn.Module): neural net model.
78 | """
79 | model.train()
80 | for p in model.parameters():
81 | p.requires_grad = True
82 |
83 |
84 | def open_specified_layers(model, open_layers):
85 | """
86 | Open specified layers in model for training while keeping
87 | other layers frozen.
88 | Args:
89 | - model (nn.Module): neural net model.
90 | - open_layers (list): list of layer names.
91 | """
92 | if isinstance(model, nn.DataParallel):
93 | model = model.module
94 |
95 | for layer in open_layers:
96 | assert hasattr(model, layer), '"{}" is not an attribute of the model, please provide the correct name'.format(
97 | layer)
98 |
99 | for name, module in model.named_children():
100 | if name in open_layers:
101 | module.train()
102 | for p in module.parameters():
103 | p.requires_grad = True
104 | else:
105 | module.eval()
106 | for p in module.parameters():
107 | p.requires_grad = False
108 |
109 |
110 | def count_num_param(model):
111 | num_param = sum(p.numel() for p in model.parameters()) / 1e+06
112 |
113 | if isinstance(model, nn.DataParallel):
114 | model = model.module
115 |
116 | if hasattr(model, 'classifier') and isinstance(model.classifier, nn.Module):
117 | # we ignore the classifier because it is unused at test time
118 | num_param -= sum(p.numel() for p in model.classifier.parameters()) / 1e+06
119 | return num_param
120 |
121 |
122 | def accuracy(output, target, topk=(1,)):
123 | """Computes the accuracy over the k top predictions for the specified values of k"""
124 | with torch.no_grad():
125 | maxk = max(topk)
126 | batch_size = target.size(0)
127 |
128 | if isinstance(output, (tuple, list)):
129 | output = output[0]
130 |
131 | _, pred = output.topk(maxk, 1, True, True)
132 | pred = pred.t()
133 | correct = pred.eq(target.view(1, -1).expand_as(pred))
134 |
135 | res = []
136 | for k in topk:
137 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
138 | acc = correct_k.mul_(100.0 / batch_size)
139 | res.append(acc.item())
140 | return res
141 |
142 |
143 | def load_pretrained_weights(model, weight_path):
144 |     """Load pretrained weights into model.
145 |     Incompatible layers (unmatched in name or size) will be ignored.
146 | Args:
147 | - model (nn.Module): network model, which must not be nn.DataParallel
148 | - weight_path (str): path to pretrained weights
149 | """
150 | checkpoint = torch.load(weight_path)
151 | if 'state_dict' in checkpoint:
152 | state_dict = checkpoint['state_dict']
153 | else:
154 | state_dict = checkpoint
155 | model_dict = model.state_dict()
156 | new_state_dict = OrderedDict()
157 | matched_layers, discarded_layers = [], []
158 | for k, v in state_dict.items():
159 | # If the pretrained state_dict was saved as nn.DataParallel,
160 | # keys would contain "module.", which should be ignored.
161 | if k.startswith('module.'):
162 | k = k[7:]
163 | if k in model_dict and model_dict[k].size() == v.size():
164 | new_state_dict[k] = v
165 | matched_layers.append(k)
166 | else:
167 | discarded_layers.append(k)
168 | model_dict.update(new_state_dict)
169 | model.load_state_dict(model_dict)
170 | if len(matched_layers) == 0:
171 | warnings.warn(
172 | 'The pretrained weights "{}" cannot be loaded, please check the key names manually (** ignored and continue **)'.format(
173 | weight_path))
174 | else:
175 | print('Successfully loaded pretrained weights from "{}"'.format(weight_path))
176 | if len(discarded_layers) > 0:
177 | print("** The following layers are discarded due to unmatched keys or layer size: {}".format(
178 | discarded_layers))
179 |
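180 |
181 | # Usage sketch (not part of the original file): accuracy() takes raw logits
182 | # and integer targets and returns percentages, one value per requested k.
183 | if __name__ == '__main__':
184 |     logits = torch.tensor([[2.0, 1.0, 0.1], [0.2, 0.3, 2.5]])
185 |     target = torch.tensor([0, 1])
186 |     top1, top2 = accuracy(logits, target, topk=(1, 2))
187 |     print(top1, top2)  # 50.0 100.0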
--------------------------------------------------------------------------------
/utils/visualtools.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import print_function
3 |
4 | import numpy as np
5 | import os.path as osp
6 | import shutil
7 |
8 | from .iotools import mkdir_if_missing
9 |
10 |
11 | def visualize_ranked_results(distmat, dataset, save_dir='log/ranked_results', topk=20):
12 | """
13 | Visualize ranked results
14 | Args:
15 | - distmat: distance matrix of shape (num_query, num_gallery).
16 | - dataset: a 2-tuple containing (query, gallery), each contains a list of (img_path, pid, camid);
17 | for imgreid, img_path is a string, while for vidreid, img_path is a tuple containing
18 | a sequence of strings.
19 | - save_dir: directory to save output images.
20 | - topk: int, denoting top-k images in the rank list to be visualized.
21 | """
22 | num_q, num_g = distmat.shape
23 |
24 | print('Visualizing top-{} ranks'.format(topk))
25 | print('# query: {}\n# gallery {}'.format(num_q, num_g))
26 | print('Saving images to "{}"'.format(save_dir))
27 |
28 | query, gallery = dataset
29 | #assert num_q == len(query)
30 | #assert num_g == len(gallery)
31 |
32 | indices = np.argsort(distmat, axis=1)
33 | mkdir_if_missing(save_dir)
34 |
35 | def _cp_img_to(src, dst, rank, prefix):
36 | """
37 | - src: image path or tuple (for vidreid)
38 | - dst: target directory
39 | - rank: int, denoting ranked position, starting from 1
40 | - prefix: string
41 | """
42 | if isinstance(src, tuple) or isinstance(src, list):
43 | dst = osp.join(dst, prefix + '_top' + str(rank).zfill(3))
44 | mkdir_if_missing(dst)
45 | for img_path in src:
46 | shutil.copy(img_path, dst)
47 | else:
48 | dst = osp.join(dst, prefix + '_top' + str(rank).zfill(3) + '_name_' + osp.basename(src))
49 | shutil.copy(src, dst)
50 |
51 | for q_idx in range(num_q):
52 | qimg_path, qpid, qcamid = query[q_idx]
53 | if isinstance(qimg_path, tuple) or isinstance(qimg_path, list):
54 | qdir = osp.join(save_dir, osp.basename(qimg_path[0]))
55 | else:
56 | qdir = osp.join(save_dir, osp.basename(qimg_path))
57 | mkdir_if_missing(qdir)
58 | _cp_img_to(qimg_path, qdir, rank=0, prefix='query')
59 |
60 | rank_idx = 1
61 | for g_idx in indices[q_idx, :]:
62 | gimg_path, gpid, gcamid = gallery[g_idx]
63 | invalid = (qpid == gpid) & (qcamid == gcamid)
64 | if not invalid:
65 | _cp_img_to(gimg_path, qdir, rank=rank_idx, prefix='gallery')
66 | rank_idx += 1
67 | if rank_idx > topk:
68 | break
69 |
70 | print("Done")
71 |
--------------------------------------------------------------------------------