├── LICENSE
├── README.md
├── configs
│   ├── RegDB_MCJA.yml
│   ├── SYSU_MCJA.yml
│   ├── __init__.py
│   └── default
│       ├── __init__.py
│       ├── dataset.py
│       └── strategy.py
├── data
│   ├── __init__.py
│   ├── dataset.py
│   ├── sampler.py
│   └── transform.py
├── engine
│   ├── __init__.py
│   ├── engine.py
│   └── metric.py
├── figs
│   └── mcja_overall_structure.png
├── losses
│   ├── __init__.py
│   └── cm_retrieval_loss.py
├── main.py
├── models
│   ├── __init__.py
│   ├── backbones
│   │   ├── __init__.py
│   │   └── resnet.py
│   ├── mcja.py
│   └── modules
│       ├── __init__.py
│       └── mda.py
└── utils
    ├── __init__.py
    ├── calc_acc.py
    ├── eval_data.py
    ├── mser_rerank.py
    └── tools.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 workingcoder
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Bridging the Gap: Multi-level Cross-modality Joint Alignment for Visible-infrared Person Re-identification
2 |
3 | By [Tengfei Liang](https://scholar.google.com/citations?user=YE6fPvgAAAAJ), [Yi Jin](https://scholar.google.com/citations?user=NQAenU0AAAAJ), [Wu Liu](https://scholar.google.com/citations?user=rQpizr0AAAAJ), [Tao Wang](https://scholar.google.com/citations?user=F3C5oAcAAAAJ), [Songhe Feng](https://scholar.google.com/citations?user=K5lqMYgAAAAJ), [Yidong Li](https://scholar.google.com/citations?user=3PagRQEAAAAJ).
4 |
5 | This repository is an official implementation of the paper [Bridging the Gap: Multi-level Cross-modality Joint Alignment for Visible-infrared Person Re-identification](https://ieeexplore.ieee.org/abstract/document/10472470). [`IEEEXplore`](https://ieeexplore.ieee.org/abstract/document/10472470) [`Google Drive`](https://drive.google.com/file/d/19-2f-gTj3P9tV-YhabtVpXSgShxFrGrs/view?usp=sharing)
6 |
7 | *Notes:*
8 |
 9 | This repository provides the complete code of the method, with a well-organized directory structure and detailed comments, making it straightforward to train and test the model.
10 | We hope it can serve as a new baseline for cross-modality visible-infrared person re-identification.
11 |
12 |
13 | ## Abstract
14 |
15 | Visible-Infrared person Re-IDentification (VI-ReID) is a challenging cross-modality image retrieval task that aims to match pedestrians' images across visible and infrared cameras.
16 | To solve the modality gap, existing mainstream methods adopt a learning paradigm converting the image retrieval task into an image classification task with cross-entropy loss and auxiliary metric learning losses.
17 | These losses follow the strategy of adjusting the distribution of extracted embeddings to reduce the intra-class distance and increase the inter-class distance.
18 | However, such objectives do not precisely correspond to the final test setting of the retrieval task, resulting in a new gap at the optimization level.
19 | By rethinking these keys of VI-ReID, we propose a simple and effective method, the Multi-level Cross-modality Joint Alignment (MCJA), bridging both the modality and objective-level gap.
20 | For the former, we design the Visible-Infrared Modality Coordinator in the image space and propose the Modality Distribution Adapter in the feature space, effectively reducing modality discrepancy of the feature extraction process.
21 | For the latter, we introduce a new Cross-Modality Retrieval loss.
22 | It is the first work to constrain from the perspective of the ranking list in the VI-ReID, aligning with the goal of the testing stage.
23 | Moreover, to strengthen the robustness and cross-modality retrieval ability, we further introduce a Multi-Spectral Enhanced Ranking strategy for the testing phase.
24 | Based on the global feature only, our method outperforms existing methods by a large margin, achieving the remarkable rank-1 of 89.51% and mAP of 87.58% on the most challenging single-shot setting and all-search mode of the SYSU-MM01 dataset.
25 | (For more details, please refer to [the original paper](https://ieeexplore.ieee.org/abstract/document/10472470))
26 |
27 |
28 |
28 | ![Overall architecture of the MCJA model](./figs/mcja_overall_structure.png)
29 |
30 |
31 | Fig. 1: Overall architecture of the proposed MCJA model.
32 |
33 |
34 |
35 | ## Requirements
36 |
37 | The code of this repository is designed to run on a single GPU.
38 | Here are the Python packages and their corresponding versions during the execution of our experiments:
39 |
40 | - Python 3.8
41 | - apex==0.1
42 | - numpy==1.21.5
43 | - Pillow==8.4.0
44 | - pytorch_ignite==0.2.1
45 | - scipy==1.7.3
46 | - torch==1.8.1+cu111
47 | - torchsort==0.1.9
48 | - torchvision==0.9.1+cu111
49 | - yacs==0.1.8
50 |
51 | *Notes:*
52 | When installing the 'apex' package, please refer to its [official repository - apex](https://github.com/NVIDIA/apex).
53 |
54 | *P.S.* Higher or lower versions of these packages might also work.
55 | When using a different version of PyTorch, please be mindful of its compatibility with pytorch_ignite, torchsort, etc.
56 |
57 |
58 | ## Dataset & Preparation
59 |
60 | During the experiment, we evaluate our proposed method on publicly available datasets, SYSU-MM01 and RegDB, which are commonly used for comparison in VI-ReID.
61 | Please download the corresponding datasets and modify the data_root paths in [configs/default/dataset.py](./configs/default/dataset.py), as shown below.
62 |
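For reference, the relevant entries in [configs/default/dataset.py](./configs/default/dataset.py) are shown below; point `data_root` to the actual locations of the datasets on your machine (the paths here are only the repository defaults):

```python
dataset_cfg.sysu.data_root = "../../Datasets/SYSU-MM01"   # path to the SYSU-MM01 dataset
dataset_cfg.regdb.data_root = "../../Datasets/RegDB"      # path to the RegDB dataset
```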
63 |
64 | ## Experiments
65 |
66 | Our [main.py](./main.py) supports both training with evaluation and a test-only mode.
67 |
68 | ### Train
69 |
70 | To train and evaluate the MCJA model on the SYSU-MM01 or RegDB dataset, execute one of the following commands:
71 |
72 | ```bash
73 | python main.py --cfg configs/SYSU_MCJA.yml --gpu 0 --seed 8 --desc MCJA
74 | ```
75 |
76 | ```bash
77 | python main.py --cfg configs/RegDB_MCJA.yml --gpu 0 --seed 8 --desc MCJA
78 | ```
79 |
80 | ### Test
81 |
82 | To run testing only, set 'test_only' to true in the corresponding 'XXXX.yml' configuration file and specify the path of the model checkpoint to load in the 'resume' setting.
83 | Then, execute the same command as mentioned above to complete the testing and evaluation:
84 |
85 | ```bash
86 | python main.py --cfg configs/SYSU_MCJA.yml --gpu 0 --desc MCJA_test_only
87 | ```
88 |
89 | ```bash
90 | python main.py --cfg configs/RegDB_MCJA.yml --gpu 0 --desc MCJA_test_only
91 | ```
92 |
93 | *Notes:*
94 | The '--seed' and '--desc' arguments of [main.py](./main.py) are optional.
95 | The former controls the random seed for the experiment, while the latter adds a suffix description to the current run.
96 |
97 |
98 | ## Citation
99 | If you find MCJA useful in your research, please kindly cite this paper in your publications:
100 | ```bibtex
101 | @article{TCSVT24_MCJA,
102 | author = {Liang, Tengfei and Jin, Yi and Liu, Wu and Wang, Tao and Feng, Songhe and Li, Yidong},
103 | title = {Bridging the Gap: Multi-level Cross-modality Joint Alignment for Visible-infrared Person Re-identification},
104 | journal = {IEEE Transactions on Circuits and Systems for Video Technology},
105 | pages = {1-1},
106 | year = {2024},
107 | doi = {10.1109/TCSVT.2024.3377252}
108 | }
109 | ```
110 |
111 |
112 | ## Related Repos
113 | Our repository builds upon the work of others, and we extend our gratitude for their contributions.
114 | Below is a list of some of these works:
115 |
116 | - AGW - https://github.com/mangye16/Cross-Modal-Re-ID-baseline
117 | - MPANet - https://github.com/DoubtedSteam/MPANet
118 |
119 |
120 | ## License
121 |
122 | This repository is released under the MIT license. Please see the [LICENSE](./LICENSE) file for more information.
123 |
--------------------------------------------------------------------------------
/configs/RegDB_MCJA.yml:
--------------------------------------------------------------------------------
1 | # Customized Strategy Config
2 |
3 | prefix: RegDB
4 |
5 | # Setting for Data
6 | dataset: regdb
7 | image_size: (384, 192)
8 | sample_method: norm_triplet
9 | p_size: 4
10 | k_size: 32
11 | batch_size: 128
12 |
13 | # Settings for Augmentation
14 | random_flip: true
15 | random_crop: true
16 | random_erase: true
17 | color_jitter: true
18 | padding: 10
19 | vimc_wg: true
20 | vimc_cc: true
21 | vimc_sj: true
22 |
23 | # Settings for Model
24 | drop_last_stride: false
25 | mda_ratio: 2
26 | mda_m: 12
27 |
28 | # Setting for Loss
29 | loss_id: true
30 | loss_cmr: true
31 |
32 | # Settings for Training
33 | lr: 0.00035
34 | wd: 0.0005
35 | num_epoch: 300
36 | lr_step: [ 300 ]
37 | fp16: true
38 |
39 | # Settings for Logging
40 | start_eval: 0
41 | log_period: 5
42 | eval_interval: 1
43 |
44 | # Settings for Testing
45 | resume: ''
46 | mser: true
47 | test_only: false
48 |
--------------------------------------------------------------------------------
/configs/SYSU_MCJA.yml:
--------------------------------------------------------------------------------
1 | # Customized Strategy Config
2 |
3 | prefix: SYSU
4 |
5 | # Setting for Data
6 | dataset: sysu
7 | image_size: (384, 192)
8 | sample_method: norm_triplet
9 | p_size: 8
10 | k_size: 16
11 | batch_size: 128
12 |
13 | # Settings for Augmentation
14 | random_flip: true
15 | random_crop: true
16 | random_erase: true
17 | color_jitter: true
18 | padding: 10
19 | vimc_wg: true
20 | vimc_cc: true
21 | vimc_sj: true
22 |
23 | # Settings for Model
24 | drop_last_stride: false
25 | mda_ratio: 2
26 | mda_m: 2
27 |
28 | # Setting for Loss
29 | loss_id: true
30 | loss_cmr: true
31 |
32 | # Settings for Training
33 | lr: 0.00035
34 | wd: 0.0005
35 | num_epoch: 140
36 | lr_step: [ 80, 120 ]
37 | fp16: true
38 |
39 | # Settings for Logging
40 | start_eval: 0
41 | log_period: 10
42 | eval_interval: 1
43 |
44 | # Settings for Testing
45 | resume: ''
46 | mser: true
47 | test_only: false
48 |
--------------------------------------------------------------------------------
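As a minimal sketch of the configuration workflow (assuming the standard yacs usage; the actual merging logic lives in main.py, which is not shown in this section), a customized YAML file such as configs/SYSU_MCJA.yml overrides the defaults defined in configs/default/strategy.py roughly as follows:

```python
from configs.default import strategy_cfg

cfg = strategy_cfg.clone()                     # start from the defaults in configs/default/strategy.py
cfg.merge_from_file('configs/SYSU_MCJA.yml')   # override them with the customized strategy config
cfg.freeze()
print(cfg.dataset, cfg.p_size, cfg.k_size)     # sysu 8 16
```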
/configs/__init__.py:
--------------------------------------------------------------------------------
1 | """MCJA/configs/__init__.py
2 | It is used to mark a directory as a Python package directory.
3 | """
--------------------------------------------------------------------------------
/configs/default/__init__.py:
--------------------------------------------------------------------------------
1 | """MCJA/configs/default/__init__.py
2 | It serves as an entry point for the default configuration settings.
3 | """
4 |
5 | from configs.default.dataset import dataset_cfg
6 | from configs.default.strategy import strategy_cfg
7 |
8 | __all__ = ["dataset_cfg", "strategy_cfg"]
9 |
--------------------------------------------------------------------------------
/configs/default/dataset.py:
--------------------------------------------------------------------------------
1 | """MCJA/configs/default/dataset.py
2 | It defines the default configuration settings for customized datasets.
3 | """
4 |
5 | from yacs.config import CfgNode
6 |
7 | dataset_cfg = CfgNode()
8 |
9 | dataset_cfg.sysu = CfgNode()
10 | dataset_cfg.sysu.num_id = 395
11 | dataset_cfg.sysu.num_cam = 6
12 | dataset_cfg.sysu.data_root = "../../Datasets/SYSU-MM01"
13 |
14 | dataset_cfg.regdb = CfgNode()
15 | dataset_cfg.regdb.num_id = 206
16 | dataset_cfg.regdb.num_cam = 2
17 | dataset_cfg.regdb.data_root = "../../Datasets/RegDB"
18 |
--------------------------------------------------------------------------------
/configs/default/strategy.py:
--------------------------------------------------------------------------------
1 | """MCJA/configs/default/strategy.py
2 | It outlines the default strategy configurations for the framework.
3 | """
4 |
5 | from yacs.config import CfgNode
6 |
7 | strategy_cfg = CfgNode()
8 |
9 | strategy_cfg.prefix = "SYSU"
10 |
11 | # Settings for Data
12 | strategy_cfg.dataset = "sysu"
13 | strategy_cfg.image_size = (384, 128)
14 | strategy_cfg.sample_method = "norm_triplet"
15 | strategy_cfg.p_size = 8
16 | strategy_cfg.k_size = 16
17 | strategy_cfg.batch_size = 128
18 |
19 | # Settings for Augmentation
20 | strategy_cfg.random_flip = True
21 | strategy_cfg.random_crop = True
22 | strategy_cfg.random_erase = True
23 | strategy_cfg.color_jitter = True
24 | strategy_cfg.padding = 10
25 | strategy_cfg.vimc_wg = True
26 | strategy_cfg.vimc_cc = True
27 | strategy_cfg.vimc_sj = True
28 |
29 | # Settings for Model
30 | strategy_cfg.drop_last_stride = False
31 | strategy_cfg.mda_ratio = 2
32 | strategy_cfg.mda_m = 2
33 |
34 | # Setting for Loss
35 | strategy_cfg.loss_id = True
36 | strategy_cfg.loss_cmr = True
37 |
38 | # Settings for Training
39 | strategy_cfg.lr = 0.00035
40 | strategy_cfg.wd = 0.0005
41 | strategy_cfg.num_epoch = 140
42 | strategy_cfg.lr_step = [80, 120]
43 | strategy_cfg.fp16 = True
44 |
45 | # Settings for Logging
46 | strategy_cfg.start_eval = 0
47 | strategy_cfg.log_period = 10
48 | strategy_cfg.eval_interval = 1
49 |
50 | # Settings for Testing
51 | strategy_cfg.resume = ''
52 | strategy_cfg.mser = True
53 | strategy_cfg.test_only = False
54 |
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | """MCJA/data/__init__.py
 2 | It orchestrates the data handling process, focusing on the data transforms and data loaders used for training and testing.
3 | """
4 |
5 | import torch
6 | import torchvision.transforms as T
7 | from torch.utils.data import DataLoader
8 |
9 | from data.dataset import SYSUDataset
10 | from data.dataset import RegDBDataset
11 |
12 | from data.sampler import CrossModalityIdentitySampler
13 | from data.sampler import CrossModalityRandomSampler
14 | from data.sampler import IdentityCrossModalitySampler
15 | from data.sampler import NormTripletSampler
16 |
17 | from data.transform import WeightedGrayscale
18 | from data.transform import ChannelCutMix
19 | from data.transform import SpectrumJitter
20 | from data.transform import ChannelAugmentation
21 | from data.transform import NoTransform
22 |
23 |
24 | def collate_fn(batch):
25 | """
26 | Custom collate function for DataLoader to handle batches of data. This function is designed to process a batch of
27 | data by separating and recombining the elements of each data point in the batch, except for a specified element
28 | (e.g., image paths). The recombination is done in such a way that it preserves the integrity of multi-modal data
29 | or other structured data necessary for model training or evaluation. The function operates by zipping the batch
30 | (which combines elements from each data point across the batch), then selectively stacking the elements to form a
31 | new batch tensor. It specifically skips over one index (in this case, the image paths) and reinserts this non-tensor
32 | data back into its original position in the batch. This approach ensures compatibility with models expecting data
33 | in a specific format while accommodating elements like paths that should not be converted into tensors.
34 |
35 | Args:
36 | - batch (list): A list of tuples, where each tuple represents a data point and contains elements,
37 | including images, labels, camera IDs, image paths, and image IDs (img, label, cam_id, img_path, img_id).
38 |
39 | Returns:
40 | - list: A list of tensors and other data types recombined from the input batch, with tensor elements
41 | stacked along a new dimension and non-tensor elements (e.g., paths) preserved in their original form.
42 | """
43 |
44 | samples = list(zip(*batch))
45 |
46 | data = [torch.stack(x, 0) for i, x in enumerate(samples) if i != 3]
47 | data.insert(3, samples[3])
48 | return data
49 |
50 |
51 | def get_train_loader(dataset, root, sample_method, batch_size, p_size, k_size, image_size,
52 | random_flip=False, random_crop=False, random_erase=False, color_jitter=False, padding=0,
53 | vimc_wg=False, vimc_cc=False, vimc_sj=False, num_workers=4):
54 | """
55 | Constructs and returns a DataLoader for training with specific datasets (SYSU or RegDB), incorporating a variety
56 | of data augmentation techniques and sampling strategies tailored for mixed-modality (visible and infrared) computer
57 | vision tasks. This function allows for extensive customization of the data preprocessing pipeline, including options
58 | for random flipping, cropping, erasing, color jitter, and the proposed Visible-Infrared Modality Coordinator (VIMC).
59 | The sampling strategy for forming batches can be selected from among several options, including norm triplet, cross
60 | modality random, cross modality identity, and identity cross modality samplers, to suit different training needs
61 | and objectives. This function plays a critical role in preparing the data for efficient and effective training by
62 | dynamically adjusting to the specified dataset, sample method, data augmentation preferences, etc.
63 |
64 | Args:
65 | - dataset (str): Name of the dataset to use ('sysu' or 'regdb').
66 | - root (str): Root directory where the dataset is stored.
67 | - sample_method (str): Method used for sampling data points to form batches.
68 | - batch_size (int): Number of data points in each batch.
69 | - p_size (int): Number of identities per batch (used in certain sampling methods).
70 | - k_size (int): Number of instances per identity (used in certain sampling methods).
71 | - image_size (tuple): The size to which the images are resized.
72 | - random_flip (bool): Whether to randomly flip images horizontally.
73 | - random_crop (bool): Whether to randomly crop images.
74 | - random_erase (bool): Whether to randomly erase parts of images.
75 | - color_jitter (bool): Whether to apply random color jittering.
76 | - padding (int): Padding size used for random cropping.
77 | - vimc_wg (bool): Whether to apply weighted grayscale conversion.
78 | - vimc_cc (bool): Whether to apply channel cutmix augmentation.
79 | - vimc_sj (bool): Whether to apply spectrum jitter.
80 | - num_workers (int): Number of worker threads to use for loading data.
81 |
82 | Returns:
83 | - DataLoader: A DataLoader object ready for training, with batches formed
84 | according to the specified sample method and data augmentation settings.
85 | """
86 |
87 | # Data Transform - RGB ---------------------------------------------------------------------------------------------
88 | t = [T.Resize(image_size)]
89 |
90 | t.append(T.RandomChoice([
91 | T.RandomApply([T.ColorJitter(hue=0.20)], p=0.5) if color_jitter else NoTransform(),
92 |
93 | ###### Visible-Infrared Modality Coordinator (VIMC) ######
94 | WeightedGrayscale(p=0.5) if vimc_wg else NoTransform(),
95 | ChannelCutMix(p=0.5) if vimc_cc else NoTransform(),
96 | SpectrumJitter(factor=(0.00, 1.00), p=0.5) if vimc_sj else NoTransform(),
97 | ]))
98 |
99 | if random_flip:
100 | t.append(T.RandomHorizontalFlip())
101 |
102 | if random_crop:
103 | t.append(T.RandomCrop(image_size, padding=padding, fill=127))
104 |
105 | t.extend([T.ToTensor(), T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
106 |
107 | if random_erase:
108 | t.append(T.RandomErasing(value=0, scale=(0.02, 0.30)))
109 |
110 | transform_rgb = T.Compose(t)
111 |
112 | # Data Transform - IR ----------------------------------------------------------------------------------------------
113 | t = [T.Resize(image_size)]
114 |
115 | if random_flip:
116 | t.append(T.RandomHorizontalFlip())
117 |
118 | if random_crop:
119 | t.append(T.RandomCrop(image_size, padding=padding, fill=127))
120 |
121 | t.extend([T.ToTensor(), T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
122 |
123 | if random_erase:
124 | t.append(T.RandomErasing(value=0, scale=(0.02, 0.30)))
125 |
126 | transform_ir = T.Compose(t)
127 |
128 | # Dataset ----------------------------------------------------------------------------------------------------------
129 | if dataset == 'sysu':
130 | train_dataset = SYSUDataset(root, mode='train', transform_rgb=transform_rgb, transform_ir=transform_ir)
131 | elif dataset == 'regdb':
132 | train_dataset = RegDBDataset(root, mode='train', transform_rgb=transform_rgb, transform_ir=transform_ir)
133 | else:
134 | raise NotImplementedError(f'Dataset - {dataset} is not supported')
135 |
136 | # DataSampler ------------------------------------------------------------------------------------------------------
137 | assert sample_method in ['none', 'norm_triplet',
138 | 'cross_modality_random',
139 | 'cross_modality_identity',
140 | 'identity_cross_modality']
141 | shuffle = False
142 | if sample_method == 'none':
143 | sampler = None
144 | shuffle = True
145 | elif sample_method == 'norm_triplet':
146 | batch_size = p_size * k_size
147 | sampler = NormTripletSampler(train_dataset, p_size * k_size, k_size)
148 | elif sample_method == 'cross_modality_random':
149 | sampler = CrossModalityRandomSampler(train_dataset, batch_size)
150 | elif sample_method == 'cross_modality_identity':
151 | batch_size = p_size * k_size
152 | sampler = CrossModalityIdentitySampler(train_dataset, p_size, k_size)
153 | elif sample_method == 'identity_cross_modality':
154 | batch_size = p_size * k_size
155 | sampler = IdentityCrossModalitySampler(train_dataset, p_size * k_size, k_size)
156 | ## Note:
157 | ## When sample_method is in [none, cross_modality_random],
158 | ## batch_size is adopted, and p_size and k_size are ignored.
159 | ## When sample_method is in [norm_triplet, cross_modality_identity, identity_cross_modality],
160 | ## p_size and k_size are adopted, and batch_size is ignored.
161 |
162 | # DataLoader -------------------------------------------------------------------------------------------------------
163 | train_loader = DataLoader(train_dataset, batch_size, sampler=sampler,
164 | shuffle=shuffle, drop_last=True, pin_memory=True,
165 | collate_fn=collate_fn, num_workers=num_workers)
166 | return train_loader
167 |
168 |
169 | def get_test_loader(dataset, root, batch_size, image_size, num_workers=4, mode=None):
170 | """
171 | Creates and returns DataLoader objects for the gallery and query datasets, intended for use in the testing phase of
172 | mixed-modality (visible and infrared) computer vision tasks. This function configures data preprocessing pipelines
173 | with transformations tailored for evaluation, including resizing, channel augmentation based on a specified mode,
174 | and normalization. It supports various modes for channel augmentation, enabling flexibility in how images are
175 | processed and potentially enhancing model robustness during evaluation. The function is designed to work with
176 | specific datasets (SYSU or RegDB), preparing both gallery and query sets for efficient and effective testing.
177 |
178 | Args:
179 | - dataset (str): The name of the dataset to be used ('sysu' or 'regdb').
180 | - root (str): The root directory where the dataset is stored.
181 | - batch_size (int): The number of data points in each batch.
182 | - image_size (tuple): The size to which the images are resized.
183 | - num_workers (int): The number of worker threads to use for data loading.
184 | - mode (str, optional): The mode of channel augmentation to apply to the RGB data.
185 | Options include None, 'avg', 'r', 'g', 'b', 'rand', 'wg', 'cc', 'sj', each applying a different channel augmentation.
186 |
187 | Returns:
188 | - tuple: A tuple containing two DataLoader objects, one for the gallery dataset and one for the query dataset,
189 | both configured for testing with the specified transformations and settings.
190 | """
191 |
192 | assert mode in [None, 'avg', 'r', 'g', 'b', 'rand', 'wg', 'cc', 'sj']
193 |
194 | # Data Transform - RGB ---------------------------------------------------------------------------------------------
195 | transform_rgb = T.Compose([
196 | T.Resize(image_size),
197 | ChannelAugmentation(mode=mode),
198 | T.ToTensor(),
199 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
200 | ])
201 |
202 | # Data Transform - IR ----------------------------------------------------------------------------------------------
203 | transform_ir = T.Compose([
204 | T.Resize(image_size),
205 | T.ToTensor(),
206 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
207 | ])
208 |
209 | # Dataset ----------------------------------------------------------------------------------------------------------
210 | if dataset == 'sysu':
211 | gallery_dataset = SYSUDataset(root, mode='gallery', transform_rgb=transform_rgb, transform_ir=transform_ir)
212 | query_dataset = SYSUDataset(root, mode='query', transform_rgb=transform_rgb, transform_ir=transform_ir)
213 | elif dataset == 'regdb':
214 | gallery_dataset = RegDBDataset(root, mode='gallery', transform_rgb=transform_rgb, transform_ir=transform_ir)
215 | query_dataset = RegDBDataset(root, mode='query', transform_rgb=transform_rgb, transform_ir=transform_ir)
216 | else:
217 | raise NotImplementedError(f'Dataset - {dataset} is not supported')
218 |
219 | # DataLoader -------------------------------------------------------------------------------------------------------
220 | query_loader = DataLoader(dataset=query_dataset,
221 | batch_size=batch_size,
222 | shuffle=False,
223 | pin_memory=True,
224 | drop_last=False,
225 | collate_fn=collate_fn,
226 | num_workers=num_workers)
227 |
228 | gallery_loader = DataLoader(dataset=gallery_dataset,
229 | batch_size=batch_size,
230 | shuffle=False,
231 | pin_memory=True,
232 | drop_last=False,
233 | collate_fn=collate_fn,
234 | num_workers=num_workers)
235 |
236 | return gallery_loader, query_loader
237 |
--------------------------------------------------------------------------------
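A hedged usage sketch of the two factory functions above, with argument values mirroring configs/SYSU_MCJA.yml and configs/default/dataset.py (the dataset path is only a placeholder):

```python
from data import get_train_loader, get_test_loader

root = '../../Datasets/SYSU-MM01'  # placeholder path, see configs/default/dataset.py

train_loader = get_train_loader(
    dataset='sysu', root=root, sample_method='norm_triplet',
    batch_size=128, p_size=8, k_size=16, image_size=(384, 192),
    random_flip=True, random_crop=True, random_erase=True, color_jitter=True, padding=10,
    vimc_wg=True, vimc_cc=True, vimc_sj=True)

gallery_loader, query_loader = get_test_loader(
    dataset='sysu', root=root, batch_size=128, image_size=(384, 192))

# Each batch follows the layout produced by collate_fn:
imgs, labels, cam_ids, img_paths, img_ids = next(iter(train_loader))
```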
/data/dataset.py:
--------------------------------------------------------------------------------
1 | """MCJA/data/dataset.py
 2 | It defines dataset classes for handling image data in the context of the cross-modality person re-identification task.
3 | """
4 |
5 | import os
6 | from glob import glob
7 |
8 | import torch
9 | from PIL import Image
10 | from torch.utils.data import Dataset
11 |
12 |
13 | class SYSUDataset(Dataset):
14 | """
15 | A dataset class tailored for the SYSU-MM01 dataset, designed to support the loading and preprocessing of data for
16 | training, querying, and gallery modes in cross-modality (visible and infrared) person re-identification tasks. The
17 | class handles the specifics of dataset directory structure, selecting appropriate subsets of images based on the
18 | mode (train, gallery, query) and performing specified transformations on the visible and infrared images separately.
19 | This class ensures that images are correctly matched with their labels, camera IDs, and other relevant information,
20 | facilitating their use in training and evaluation of models for person re-identification.
21 |
22 | The constructor of this class takes the root directory of the SYSU-MM01 dataset, the mode of operation (training,
23 | gallery, or query), and optional transformations for RGB (visible) and IR (infrared) images. If `memory_loaded` is
24 | set to True, all images are loaded into memory at initialization for faster access during training or evaluation.
25 | This class is compatible with PyTorch's DataLoader, making it easy to batch and shuffle the dataset as needed.
26 |
27 | Args:
28 | - root (str): The root directory where the SYSU-MM01 dataset is stored.
29 | - mode (str): The mode of dataset usage, which can be 'train', 'gallery', or 'query'.
30 | - transform_rgb (callable, optional): A function that takes in an RGB image and returns a transformed version.
31 | - transform_ir (callable, optional): A function that takes in an IR image and returns a transformed version.
32 | - memory_loaded (bool): If True, all images will be loaded into memory at initialization.
33 |
34 | Attributes:
35 | - img_paths (list): A list of paths to images that belong to the selected mode and IDs.
36 | - cam_ids (list): Camera IDs corresponding to each image in `img_paths`.
37 | - num_ids (int): The number of unique identities present in the selected mode.
38 | - ids (list): A list of identity labels corresponding to each image.
39 | - img_data (list, optional): If `memory_loaded` is True, this list contains preloaded images from `img_paths`.
40 |
41 | Methods:
42 | - __len__(): Returns the total number of images in the img_paths.
43 | - __getitem__(item): Retrieves the image and its metadata at the specified index,
44 | applying the appropriate transformations based on the camera ID (modality labels).
45 | """
46 |
47 | def __init__(self, root, mode='train', transform_rgb=None, transform_ir=None, memory_loaded=False):
48 | assert os.path.isdir(root)
49 | assert mode in ['train', 'gallery', 'query']
50 |
51 | if mode == 'train':
52 | train_ids = open(os.path.join(root, 'exp', 'train_id.txt')).readline()
53 | val_ids = open(os.path.join(root, 'exp', 'val_id.txt')).readline()
54 |
55 | train_ids = train_ids.strip('\n').split(',')
56 | val_ids = val_ids.strip('\n').split(',')
57 | selected_ids = train_ids + val_ids
58 | else:
59 | test_ids = open(os.path.join(root, 'exp', 'test_id.txt')).readline()
60 | selected_ids = test_ids.strip('\n').split(',')
61 |
62 | selected_ids = [int(i) for i in selected_ids]
63 | num_ids = len(selected_ids)
64 |
65 | img_paths = glob(os.path.join(root, '**/*.jpg'), recursive=True)
66 | img_paths = [path for path in img_paths if int(path.split('/')[-2]) in selected_ids]
67 |
68 | if mode == 'gallery':
69 | img_paths = [path for path in img_paths if int(path.split('/')[-3][-1]) in (1, 2, 4, 5)]
70 | elif mode == 'query':
71 | img_paths = [path for path in img_paths if int(path.split('/')[-3][-1]) in (3, 6)]
72 |
73 | img_paths = sorted(img_paths)
74 | self.img_paths = img_paths
75 | self.cam_ids = [int(path.split('/')[-3][-1]) for path in img_paths]
76 | self.num_ids = num_ids
77 | self.transform_rgb = transform_rgb
78 | self.transform_ir = transform_ir
79 |
80 | if mode == 'train':
81 | id_map = dict(zip(selected_ids, range(num_ids)))
82 | self.ids = [id_map[int(path.split('/')[-2])] for path in img_paths]
83 | else:
84 | self.ids = [int(path.split('/')[-2]) for path in img_paths]
85 |
86 | self.memory_loaded = memory_loaded
87 | if memory_loaded:
88 | self.img_data = [Image.open(path) for path in self.img_paths]
89 |
90 | def __len__(self):
91 | return len(self.img_paths)
92 |
93 | def __getitem__(self, item):
94 | path = self.img_paths[item]
95 | if self.memory_loaded:
96 | img = self.img_data[item]
97 | else:
98 | img = Image.open(path)
99 |
100 | label = torch.as_tensor(self.ids[item], dtype=torch.long)
101 | cam = torch.as_tensor(self.cam_ids[item], dtype=torch.long)
102 | item = torch.as_tensor(item, dtype=torch.long)
103 |
104 | if cam == 3 or cam == 6:
105 | if self.transform_ir is not None:
106 | img = self.transform_ir(img)
107 | else:
108 | if self.transform_rgb is not None:
109 | img = self.transform_rgb(img)
110 |
111 | return img, label, cam, path, item
112 |
113 |
114 | class RegDBDataset(Dataset):
115 | """
116 | A dataset class specifically designed for the RegDB dataset, facilitating the loading and preprocessing of images
117 | for the cross-modality (visible and infrared) person re-identification task. It supports different operational modes
118 | including training, gallery, and query, applying distinct preprocessing routines to RGB (visible) and IR (infrared)
119 | images as specified. This class handles the unique structure of the RegDB dataset, including its division into
120 | separate visible and thermal image sets, and it prepares the dataset for use in a PyTorch DataLoader, ensuring that
121 | images are appropriately matched with their labels and camera IDs.
122 |
123 | The constructor of this class takes several parameters, including the dataset's root directory, the mode of operation,
124 | optional transformations for both RGB and IR images, and a flag indicating whether images should be loaded into
125 | memory at initialization. This facilitates faster access during model training and evaluation, especially useful
126 | when working with large datasets or in environments where I/O speed is a bottleneck.
127 |
128 | Args:
129 | - root (str): The root directory where the RegDB dataset is stored.
130 | - mode (str): The mode of dataset usage, which can be 'train', 'gallery', or 'query'.
131 | - transform_rgb (callable, optional): A function/transform that applies to RGB images.
132 | - transform_ir (callable, optional): A function/transform that applies to IR images.
133 | - memory_loaded (bool): If set to True, all images are loaded into memory upfront for faster access.
134 |
135 | Attributes:
136 | - img_paths (list): A list of paths to images that belong to the selected mode and IDs.
137 | - cam_ids (list): Camera IDs derived from the image paths (with visible cameras marked as 2 and thermal as 3).
138 | - num_ids (int): The number of unique identities present in the selected mode.
139 | - ids (list): A list of identity labels corresponding to each image.
140 | - img_data (list, optional): If `memory_loaded` is True, this list contains preloaded images from `img_paths`.
141 |
142 | Methods:
143 | - __len__(): Returns the total number of images in the img_paths.
144 | - __getitem__(item): Retrieves the image and its metadata at the specified index,
145 | applying the appropriate transformations based on the camera ID (modality labels).
146 | """
147 |
148 | def __init__(self, root, mode='train', transform_rgb=None, transform_ir=None, memory_loaded=False):
149 | assert os.path.isdir(root)
150 | assert mode in ['train', 'gallery', 'query']
151 |
152 | def loadIdx(index):
153 | Lines = index.readlines()
154 | idx = []
155 | for line in Lines:
156 | tmp = line.strip('\n')
157 | tmp = tmp.split(' ')
158 | idx.append(tmp)
159 | return idx
160 |
161 | num = '1'
162 | if mode == 'train':
163 | index_RGB = loadIdx(open(root + '/idx/train_visible_' + num + '.txt', 'r'))
164 | index_IR = loadIdx(open(root + '/idx/train_thermal_' + num + '.txt', 'r'))
165 | else:
166 | index_RGB = loadIdx(open(root + '/idx/test_visible_' + num + '.txt', 'r'))
167 | index_IR = loadIdx(open(root + '/idx/test_thermal_' + num + '.txt', 'r'))
168 |
169 | if mode == 'gallery':
170 | img_paths = [root + '/' + path for path, _ in index_RGB]
171 | elif mode == 'query':
172 | img_paths = [root + '/' + path for path, _ in index_IR]
173 | else:
174 | img_paths = [root + '/' + path for path, _ in index_RGB] + [root + '/' + path for path, _ in index_IR]
175 |
176 | selected_ids = [int(path.split('/')[-2]) for path in img_paths]
177 | selected_ids = list(set(selected_ids))
178 | num_ids = len(selected_ids)
179 |
180 | img_paths = sorted(img_paths)
181 | self.img_paths = img_paths
182 | self.cam_ids = [int(path.split('/')[-3] == 'Thermal') + 2 for path in img_paths]
183 | # Note: In SYSU-MM01 dataset, the visible cams are 1 2 4 5, and thermal cams are 3 6.
184 | # To simplify the code, visible cam is 2 and thermal cam is 3 in RegDB dataset.
185 | self.num_ids = num_ids
186 | self.transform_rgb = transform_rgb
187 | self.transform_ir = transform_ir
188 |
189 | if mode == 'train':
190 | id_map = dict(zip(selected_ids, range(num_ids)))
191 | self.ids = [id_map[int(path.split('/')[-2])] for path in img_paths]
192 | else:
193 | self.ids = [int(path.split('/')[-2]) for path in img_paths]
194 |
195 | self.memory_loaded = memory_loaded
196 | if memory_loaded:
197 | self.img_data = [Image.open(path) for path in self.img_paths]
198 |
199 | def __len__(self):
200 | return len(self.img_paths)
201 |
202 | def __getitem__(self, item):
203 | path = self.img_paths[item]  # fetch the path first so it is defined for both branches below
204 | if self.memory_loaded:
205 | img = self.img_data[item]
206 | else:
207 | img = Image.open(path)
208 |
209 | label = torch.tensor(self.ids[item], dtype=torch.long)
210 | cam = torch.tensor(self.cam_ids[item], dtype=torch.long)
211 | item = torch.tensor(item, dtype=torch.long)
212 |
213 | if cam == 3 or cam == 6:
214 | if self.transform_ir is not None:
215 | img = self.transform_ir(img)
216 | else:
217 | if self.transform_rgb is not None:
218 | img = self.transform_rgb(img)
219 |
220 | return img, label, cam, path, item
221 |
--------------------------------------------------------------------------------
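For illustration, instantiating one of the dataset classes directly (the dataset path is a placeholder) returns the tuple described in __getitem__ above:

```python
from data.dataset import SYSUDataset

query_set = SYSUDataset('../../Datasets/SYSU-MM01', mode='query')  # placeholder path
img, label, cam, path, idx = query_set[0]  # PIL image (no transform passed), id label, cam id, file path, index
print(len(query_set), label.item(), cam.item())
```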
/data/sampler.py:
--------------------------------------------------------------------------------
1 | """MCJA/data/sampler.py
 2 | It defines several Sampler classes designed to facilitate cross-modality (e.g., RGB and IR) person re-identification.
3 | """
4 |
5 | import copy
6 | import numpy as np
7 |
8 | from collections import defaultdict
9 | from torch.utils.data import Sampler
10 |
11 |
12 | class NormTripletSampler(Sampler):
13 | """
14 | Randomly sample N identities, then for each identity,
15 | randomly sample K instances, therefore batch size is N*K.
16 | It does not distinguish modalities.
17 |
18 | Args:
19 | - dataset (Dataset): Instance of dataset class.
20 | - batch_size (int): Number of examples in a batch.
21 | - num_instances (int): Number of instances per identity in a batch.
22 | """
23 |
24 | def __init__(self, dataset, batch_size, num_instances):
25 | self.dataset = dataset
26 | self.batch_size = batch_size
27 | self.num_instances = num_instances
28 | self.num_pids_per_batch = self.batch_size // self.num_instances
29 | self.index_dic = defaultdict(list)
30 | for index, pid in enumerate(self.dataset.ids):
31 | self.index_dic[pid].append(index)
32 | self.pids = list(self.index_dic.keys())
33 |
34 | # estimate number of examples in an epoch
35 | self.length = 0
36 | for pid in self.pids:
37 | idxs = self.index_dic[pid]
38 | num = len(idxs)
39 | if num < self.num_instances:
40 | num = self.num_instances
41 | self.length += num - num % self.num_instances
42 |
43 | def __iter__(self):
44 | batch_idxs_dict = defaultdict(list)
45 |
46 | for pid in self.pids:
47 | idxs = copy.deepcopy(self.index_dic[pid])
48 | if len(idxs) < self.num_instances:
49 | idxs = np.random.choice(idxs, size=self.num_instances, replace=True)
50 | np.random.shuffle(idxs)
51 | batch_idxs = []
52 | for idx in idxs:
53 | batch_idxs.append(idx)
54 | if len(batch_idxs) == self.num_instances:
55 | batch_idxs_dict[pid].append(batch_idxs)
56 | batch_idxs = []
57 |
58 | avai_pids = copy.deepcopy(self.pids)
59 | final_idxs = []
60 |
61 | while len(avai_pids) >= self.num_pids_per_batch:
62 | selected_pids = np.random.choice(avai_pids, self.num_pids_per_batch, replace=False)
63 | for pid in selected_pids:
64 | batch_idxs = batch_idxs_dict[pid].pop(0)
65 | final_idxs.extend(batch_idxs)
66 | if len(batch_idxs_dict[pid]) == 0:
67 | avai_pids.remove(pid)
68 |
69 | self.length = len(final_idxs)
70 | return iter(final_idxs)
71 |
72 | def __len__(self):
73 | return self.length
74 |
75 |
76 | class CrossModalityRandomSampler(Sampler):
77 | """
78 | The first half of a batch consists of randomly selected RGB images,
79 | and the second half consists of randomly selected IR images.
80 |
81 | Args:
82 | - dataset (Dataset): Instance of dataset class.
83 | - batch_size (int): Total number of images in a batch.
84 | """
85 |
86 | def __init__(self, dataset, batch_size):
87 | self.dataset = dataset
88 | self.batch_size = batch_size
89 |
90 | self.rgb_list = []
91 | self.ir_list = []
92 | for i, cam in enumerate(dataset.cam_ids):
93 | if cam in [3, 6]:
94 | self.ir_list.append(i)
95 | else:
96 | self.rgb_list.append(i)
97 |
98 | def __len__(self):
99 | return max(len(self.rgb_list), len(self.ir_list)) * 2
100 |
101 | def __iter__(self):
102 | sample_list = []
103 | rgb_list = np.random.permutation(self.rgb_list).tolist()
104 | ir_list = np.random.permutation(self.ir_list).tolist()
105 |
106 | rgb_size = len(self.rgb_list)
107 | ir_size = len(self.ir_list)
108 | if rgb_size >= ir_size:
109 | diff = rgb_size - ir_size
110 | reps = diff // ir_size
111 | pad_size = diff % ir_size
112 | for _ in range(reps):
113 | ir_list.extend(np.random.permutation(self.ir_list).tolist())
114 | ir_list.extend(np.random.choice(self.ir_list, pad_size, replace=False).tolist())
115 | else:
116 | diff = ir_size - rgb_size
117 | reps = diff // rgb_size  # pad the shorter RGB list with full random permutations plus a remainder
118 | pad_size = diff % rgb_size
119 | for _ in range(reps):
120 | rgb_list.extend(np.random.permutation(self.rgb_list).tolist())
121 | rgb_list.extend(np.random.choice(self.rgb_list, pad_size, replace=False).tolist())
122 |
123 | assert len(rgb_list) == len(ir_list)
124 |
125 | half_bs = self.batch_size // 2
126 | for start in range(0, len(rgb_list), half_bs):
127 | sample_list.extend(rgb_list[start:start + half_bs])
128 | sample_list.extend(ir_list[start:start + half_bs])
129 |
130 | return iter(sample_list)
131 |
132 |
133 | class CrossModalityIdentitySampler(Sampler):
134 | """
135 | The first half of a batch consists of k_size/2 randomly selected RGB images for each of p_size randomly selected identities,
136 | and the second half consists of k_size/2 randomly selected IR images for the same p_size identities.
137 | Batch - [id1_rgb, id1_rgb, ..., id2_rgb, id2_rgb, ..., id1_ir, id1_ir, ..., id2_ir, id2_ir, ...]
138 |
139 | Args:
140 | - dataset (Dataset): Instance of dataset class.
141 | - p_size (int): Number of identities per batch.
142 | - k_size (int): Number of instances per identity.
143 | """
144 |
145 | def __init__(self, dataset, p_size, k_size):
146 | self.dataset = dataset
147 | self.p_size = p_size
148 | self.k_size = k_size // 2
149 | self.batch_size = p_size * k_size * 2
150 |
151 | self.id2idx_rgb = defaultdict(list)
152 | self.id2idx_ir = defaultdict(list)
153 | for i, identity in enumerate(dataset.ids):
154 | if dataset.cam_ids[i] in [3, 6]:
155 | self.id2idx_ir[identity].append(i)
156 | else:
157 | self.id2idx_rgb[identity].append(i)
158 |
159 | self.num_base_samples = self.dataset.num_ids * self.k_size * 2
160 |
161 | self.num_repeats = len(dataset.ids) // self.num_base_samples
162 | self.num_samples = self.num_base_samples * self.num_repeats
163 |
164 | # num_ir, num_rgb = 0, 0
165 | # for c_id in dataset.cam_ids:
166 | # if c_id in [3, 6]:
167 | # num_ir += 1
168 | # else:
169 | # num_rgb += 1
170 | # self.num_repeats = (num_ir * 2) // self.num_base_samples
171 | # self.num_samples = self.num_base_samples * self.num_repeats
172 |
173 | def __len__(self):
174 | return self.num_samples
175 |
176 | def __iter__(self):
177 | sample_list = []
178 |
179 | for r in range(self.num_repeats):
180 | id_perm = np.random.permutation(self.dataset.num_ids)
181 | for start in range(0, self.dataset.num_ids, self.p_size):
182 | selected_ids = id_perm[start:start + self.p_size]
183 |
184 | sample = []
185 | for identity in selected_ids:
186 | replace = len(self.id2idx_rgb[identity]) < self.k_size
187 | s = np.random.choice(self.id2idx_rgb[identity], size=self.k_size, replace=replace)
188 | sample.extend(s)
189 |
190 | sample_list.extend(sample)
191 |
192 | sample.clear()
193 | for identity in selected_ids:
194 | replace = len(self.id2idx_ir[identity]) < self.k_size
195 | s = np.random.choice(self.id2idx_ir[identity], size=self.k_size, replace=replace)
196 | sample.extend(s)
197 |
198 | sample_list.extend(sample)
199 |
200 | return iter(sample_list)
201 |
202 |
203 | class IdentityCrossModalitySampler(Sampler):
204 | """
205 | It is equivalent to CrossModalityIdentitySampler, but the arrangement is different.
206 | Batch - [id1_ir, id1_rgb, id1_ir, id1_rgb, ..., id2_ir, id2_rgb, id2_ir, id2_rgb, ...]
207 |
208 | Args:
209 | - dataset (Dataset): Instance of dataset class.
210 | - batch_size (int): Number of examples in a batch.
211 | - num_instances (int): Number of instances per identity in a batch.
212 | """
213 |
214 | def __init__(self, dataset, batch_size, num_instances):
215 | self.dataset = dataset
216 | self.batch_size = batch_size
217 | self.num_instances = num_instances
218 | self.num_pids_per_batch = self.batch_size // self.num_instances
219 | self.index_dic_R = defaultdict(list)
220 | self.index_dic_I = defaultdict(list)
221 | for i, identity in enumerate(dataset.ids):
222 | if dataset.cam_ids[i] in [3, 6]:
223 | self.index_dic_I[identity].append(i)
224 | else:
225 | self.index_dic_R[identity].append(i)
226 | self.pids = list(self.index_dic_I.keys())
227 |
228 | # estimate number of examples in an epoch
229 | self.length = 0
230 | for pid in self.pids:
231 | idxs = self.index_dic_I[pid]
232 | num = len(idxs)
233 | if num < self.num_instances:
234 | num = self.num_instances
235 | self.length += num - num % self.num_instances
236 |
237 | def __len__(self):
238 | return self.length
239 |
240 | def __iter__(self):
241 | batch_idxs_dict = defaultdict(list)
242 |
243 | for pid in self.pids:
244 | idxs_I = copy.deepcopy(self.index_dic_I[pid])
245 | idxs_R = copy.deepcopy(self.index_dic_R[pid])
246 | if len(idxs_I) < self.num_instances // 2 and len(idxs_R) < self.num_instances // 2:
247 | idxs_I = np.random.choice(idxs_I, size=self.num_instances // 2, replace=True)
248 | idxs_R = np.random.choice(idxs_R, size=self.num_instances // 2, replace=True)
249 | if len(idxs_I) > len(idxs_R):
250 | idxs_I = np.random.choice(idxs_I, size=len(idxs_R), replace=False)
251 | if len(idxs_R) > len(idxs_I):
252 | idxs_R = np.random.choice(idxs_R, size=len(idxs_I), replace=False)
253 | np.random.shuffle(idxs_I)
254 | np.random.shuffle(idxs_R)
255 | batch_idxs = []
256 | for idx_I, idx_R in zip(idxs_I, idxs_R):
257 | batch_idxs.append(idx_I)
258 | batch_idxs.append(idx_R)
259 | if len(batch_idxs) == self.num_instances:
260 | batch_idxs_dict[pid].append(batch_idxs)
261 | batch_idxs = []
262 |
263 | avai_pids = copy.deepcopy(self.pids)
264 | final_idxs = []
265 |
266 | while len(avai_pids) >= self.num_pids_per_batch:
267 | selected_pids = np.random.choice(avai_pids, self.num_pids_per_batch, replace=False)
268 | for pid in selected_pids:
269 | batch_idxs = batch_idxs_dict[pid].pop(0)
270 | final_idxs.extend(batch_idxs)
271 | if len(batch_idxs_dict[pid]) == 0:
272 | avai_pids.remove(pid)
273 |
274 | self.length = len(final_idxs)
275 | return iter(final_idxs)
276 |
--------------------------------------------------------------------------------
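A small sketch of how NormTripletSampler (used by the default 'norm_triplet' sample method) yields P x K structured batches; the dataset path is a placeholder:

```python
from data.dataset import SYSUDataset
from data.sampler import NormTripletSampler

dataset = SYSUDataset('../../Datasets/SYSU-MM01', mode='train')  # placeholder path
p_size, k_size = 8, 16                               # 8 identities x 16 instances per batch
sampler = NormTripletSampler(dataset, p_size * k_size, k_size)
indices = list(iter(sampler))                        # epoch-long index sequence
batch = indices[:p_size * k_size]                    # the first 128 indices form one P x K batch
print(sorted({dataset.ids[i] for i in batch}))       # 8 distinct identity labels
```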
/data/transform.py:
--------------------------------------------------------------------------------
1 | """MCJA/data/transform.py
2 | It contains a collection of custom image transformation classes designed for augmenting and preprocessing images.
3 | """
4 |
5 | import math
6 | import numbers
7 | import random
8 | import numpy as np
9 | from PIL import Image
10 |
11 | import torch
12 | import torch.nn as nn
13 |
14 |
15 | class WeightedGrayscale(nn.Module):
16 | """
17 | A module that applies a weighted grayscale transformation to an image, which converts an RGB image to grayscale
18 | by applying custom weights to each channel before averaging them. This approach allows for more flexibility than
19 | standard grayscale, potentially emphasizing certain features more than others depending on the chosen weights.
20 |
21 | Args:
22 | - weights (tuple of floats, optional): The weights to apply to the R, G, and B channels, respectively.
23 | If not specified, weights are randomly generated for each image.
24 | - p (float): The probability with which the weighted grayscale transformation is applied.
25 | A value of 1.0 means the transformation is always applied, whereas a value of 0 means it is never applied.
26 |
27 | Methods:
28 | - forward(img): Applies the weighted grayscale transformation to the given img with probability p.
29 | If the transformation is not applied, the original image is returned unchanged.
30 | """
31 |
32 | def __init__(self, weights=None, p=1.0):
33 | super().__init__()
34 | self.weights = weights
35 | self.p = p
36 |
37 | def forward(self, img):
38 | if self.p < torch.rand(1):
39 | return img
40 |
41 | if self.weights is not None:
42 | w1, w2, w3 = self.weights
43 | else:
44 | w1 = random.uniform(0, 1)
45 | w2 = random.uniform(0, 1)
46 | w3 = random.uniform(0, 1)
47 | s = w1 + w2 + w3
48 | w1, w2, w3 = w1 / s, w2 / s, w3 / s
49 | img_data = np.asarray(img)
50 | img_data = w1 * img_data[:, :, 0] + w2 * img_data[:, :, 1] + w3 * img_data[:, :, 2]
51 | img_data = np.expand_dims(img_data, axis=-1).repeat(3, axis=-1)
52 |
53 | return Image.fromarray(np.uint8(img_data))
54 |
55 |
56 | class ChannelCutMix(nn.Module):
57 | """
58 | A module that implements the ChannelCutMix augmentation, a variant of the CutMix augmentation strategy that operates
59 | at the channel level. Unlike traditional CutMix, which combines patches from different images, ChannelCutMix
60 | selectively replaces a region in one channel of an image with the corresponding region from another channel of the
61 | same image. This process introduces diversity in the training data by blending features from different channels,
62 | potentially enhancing the robustness of models to variations in input data.
63 |
64 | Args:
65 | - p (float): The probability with which the ChannelCutMix augmentation is applied.
66 | A value of 1.0 means the augmentation is always applied, whereas a value of 0 means it is never applied.
67 | - scale (tuple of floats): The range of scales relative to the original area of the
68 | image that determines the size of the region to be replaced.
69 | - ratio (tuple of floats): The range of aspect ratios of the region to be replaced.
70 | This controls the shape of the region, allowing for both narrow and wide regions to be selected.
71 |
72 | Methods:
73 | - forward(img): Applies the ChannelCutMix augmentation to the given img with probability p.
74 | If the augmentation is not applied, the original image is returned unchanged.
75 | """
76 |
77 | def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3)):
78 | super().__init__()
79 | self.p = p
80 | self.scale = scale
81 | self.ratio = ratio
82 |
83 | def forward(self, img):
84 | if self.p < torch.rand(1):
85 | return img
86 |
87 | img_w, img_h = img.size  # PIL Image - size returns (width, height)
88 | area = img_h * img_w
89 | log_ratio = torch.log(torch.tensor(self.ratio))
90 | i, j, h, w = None, None, None, None
91 | for _ in range(10):
92 | cutmix_area = area * torch.empty(1).uniform_(self.scale[0], self.scale[1]).item()
93 | aspect_ratio = torch.exp(torch.empty(1).uniform_(log_ratio[0], log_ratio[1])).item()
94 | h = int(round(math.sqrt(cutmix_area * aspect_ratio)))
95 | w = int(round(math.sqrt(cutmix_area / aspect_ratio)))
96 | if not (h < img_h and w < img_w):
97 | continue
98 | i = torch.randint(0, img_h - h + 1, size=(1,)).item()
99 | j = torch.randint(0, img_w - w + 1, size=(1,)).item()
100 | break
101 |
102 | img_data = np.asarray(img)
103 | bg_c, fg_c = random.sample(range(3), k=2)
104 | bg_img_data = img_data[:, :, bg_c]
105 | fg_img_data = img_data[:, :, fg_c]
106 | bg_img_data[i:i + h, j:j + w] = fg_img_data[i:i + h, j:j + w]
107 | img_data = np.expand_dims(bg_img_data, axis=-1).repeat(3, axis=-1)
108 |
109 | return Image.fromarray(np.uint8(img_data))
110 |
111 |
112 | class SpectrumJitter(nn.Module):
113 | """
114 | A module for applying Spectrum Jitter augmentation to an image, which selectively alters the intensity of a randomly
115 | chosen color channel and blends it with the original image. This augmentation can introduce variations in color
116 | intensity and distribution across different channels, mimicking conditions of varying spectrum that a model might
117 | encounter in real-world scenarios. The purpose is to improve the model's robustness to changes in spectrum by
118 | exposing it to a wider range of color spectrum during training.
119 |
120 | Args:
121 | - factor (float or tuple of float): Specifies the range of factors to use for blending the altered channel back into
122 | the original image. If a single float is provided, it's interpreted as the maximum deviation from the default
123 | intensity of 1.0, creating a range [1-factor, 1+factor]. If a tuple is provided, it directly specifies the range.
124 | The factor influences how strongly the selected channel's intensity is altered.
125 | - p (float): The probability with which the Spectrum Jitter augmentation is applied to any given image.
126 | A value of 1.0 means the augmentation is always applied, while a value of 0 means it is never applied.
127 |
128 | Methods:
129 | - forward(img): Applies the Spectrum Jitter augmentation to the given img with probability p.
130 | If the augmentation is not applied, the original image is returned unchanged.
131 | """
132 |
133 | def __init__(self, factor=0.5, p=0.5):
134 | super().__init__()
135 | self.factor = self._check_input(factor, 'spectrum')
136 | self.p = p
137 |
138 | @torch.jit.unused # Inspired by the implementation of color jitter in standard libraries
139 | def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True):
140 | if isinstance(value, numbers.Number):
141 | if value < 0:
142 | raise ValueError("If {} is a single number, it must be non negative.".format(name))
143 | value = [center - float(value), center + float(value)]
144 | if clip_first_on_zero:
145 | value[0] = max(value[0], 0.0)
146 | elif isinstance(value, (tuple, list)) and len(value) == 2:
147 | if not bound[0] <= value[0] <= value[1] <= bound[1]:
148 | raise ValueError("{} values should be between {}".format(name, bound))
149 | else:
150 | raise TypeError("{} should be a single number or a list/tuple with length 2.".format(name))
151 |
152 | if value[0] == value[1] == center:
153 | value = None
154 | return value
155 |
156 | def forward(self, img):
157 | if self.p < torch.rand(1):
158 | return img
159 |
160 | selected_c = random.randint(0, 2)
161 |
162 | img_data = np.asarray(img)
163 | img_data = img_data[:, :, selected_c]
164 | img_data = np.expand_dims(img_data, axis=-1).repeat(3, axis=-1)
165 | degenerate = Image.fromarray(np.uint8(img_data))
166 |
167 | factor = float(torch.empty(1).uniform_(self.factor[0], self.factor[1]))
168 | return Image.blend(degenerate, img, factor)
169 |
170 |
171 | class ChannelAugmentation(nn.Module):
172 | """
173 | A module that encapsulates various channel-based augmentation strategies, allowing for flexible and probabilistic
174 | application of different augmentations to an image. This module supports a range of augmentation modes, including
175 | selection of individual RGB channels, averaging of channels, mixing channels, and more advanced techniques such
176 | as weighted grayscale conversion, channel cutmix, and spectrum jitter. The specific augmentation to apply can be
177 | selected via the `mode` parameter, providing a versatile tool for enhancing the diversity of original data.
178 |
179 | Args:
180 | - mode (str, optional): Specifies the augmentation technique to apply. Options include:
181 | - None: No augmentation is applied.
182 | - 'r', 'g', 'b': Selects a specific RGB channel.
183 | - 'avg': Averages the RGB channels.
184 | - 'rg_avg', 'rb_avg', 'gb_avg': Averages specified pairs of RGB channels.
185 | - 'rand': Randomly selects a channel at each call.
186 | - 'wg': Applies weighted grayscale augmentation.
187 | - 'cc': Applies channel cutmix augmentation.
188 | - 'sj': Applies spectrum jitter augmentation.
189 | Each mode introduces different types of variations, from simple channel selection to more complex transformations.
190 | - p (float): The probability with which the selected augmentation is applied.
191 | A value of 1.0 means the augmentation is always applied, while a value of 0 means it is never applied.
192 |
193 | Methods:
194 | - forward(img): Applies the configured augmentation to the given img with probability p.
195 | If the augmentation is not applied (either because p < 1 or mode is None), the original image is returned unchanged.
196 | """
197 |
198 | def __init__(self, mode=None, p=1.0):
199 | super().__init__()
200 | assert mode in [None, 'r', 'g', 'b', 'avg', 'rg_avg', 'rb_avg', 'gb_avg', 'rand', 'wg', 'cc', 'sj']
201 | self.mode = mode
202 | self.p = p
203 | if mode in ['r', 'g', 'b', 'avg', 'rg_avg', 'rb_avg', 'gb_avg', 'rand']:
204 | self.ca = ChannelSelection(mode=mode, p=p)
205 | elif mode == 'wg':
206 | self.ca = WeightedGrayscale(p=p)
207 | elif mode == 'cc':
208 | self.ca = ChannelCutMix(p=p)
209 | elif mode == 'sj':
210 | self.ca = SpectrumJitter(factor=(0.00, 1.00), p=p)
211 | else:
212 | self.ca = NoTransform()
213 |
214 | def forward(self, img):
215 | return self.ca(img)
216 |
217 |
218 | class ChannelSelection(nn.Module):
219 | """
220 | A module that selectively manipulates the color channels of an image according to a specified mode.
221 | This augmentation technique can emphasize or de-emphasize certain features in the image based on color,
222 | which might be beneficial for tasks sensitive to specific color channels. The module supports a variety
223 | of modes that target different channels or combinations thereof, and it applies these transformations
224 | with a specified probability, allowing for stochastic data augmentation.
225 |
226 | Args:
227 | - mode (str, optional): Specifies the channel manipulation mode.
228 | It can be one of 'r', 'g', 'b', 'avg', 'rg_avg', 'rb_avg', 'gb_avg', or 'rand'. The default is 'rand',
229 | which randomly selects one of the RGB channels each time the augmentation is applied.
230 | - p (float): The probability with which the channel selection or modification is applied.
231 | A value of 1.0 means the transformation is always applied, whereas a value of 0 means it is never applied.
232 | """
233 |
234 | def __init__(self, mode='rand', p=1.0):
235 | super().__init__()
236 | assert mode in ['r', 'g', 'b', 'avg', 'rg_avg', 'rb_avg', 'gb_avg', 'rand']
237 | self.mode = mode
238 | self.p = p
239 |
240 | def forward(self, img):
241 | if self.p < torch.rand(1):
242 | return img
243 |
244 | img_data = np.asarray(img)
245 | if 'avg' in self.mode:
246 | if self.mode == 'avg':
247 | pass
248 | elif self.mode == 'rg_avg':
249 | img_data = np.stack([img_data[:, :, 0], img_data[:, :, 1]], axis=-1)
250 | elif self.mode == 'rb_avg':
251 | img_data = np.stack([img_data[:, :, 0], img_data[:, :, 2]], axis=-1)
252 | elif self.mode == 'gb_avg':
253 | img_data = np.stack([img_data[:, :, 1], img_data[:, :, 2]], axis=-1)
254 | img_data = img_data.mean(axis=-1)
255 | else:
256 | if self.mode == 'r':
257 | selected_c = 0
258 | elif self.mode == 'g':
259 | selected_c = 1
260 | elif self.mode == 'b':
261 | selected_c = 2
262 | elif self.mode == 'rand':
263 | selected_c = random.randint(0, 2)
264 | img_data = img_data[:, :, selected_c]
265 | img_data = np.expand_dims(img_data, axis=-1).repeat(3, axis=-1)
266 | return Image.fromarray(np.uint8(img_data))
267 |
268 |
269 | class NoTransform(nn.Module):
270 | """
271 | A module that acts as a placeholder for a transformation step, performing no operation on the input image.
272 | It is designed to seamlessly integrate into data processing pipelines or augmentation sequences where conditional
273 | application of transformations is required but, in some cases, no actual transformation should be applied.
274 |
275 | Methods:
276 | - forward(img): Returns the input image unchanged, serving as a pass-through operation in a transformation pipeline.
277 | """
278 |
279 | def __init__(self):
280 | super().__init__()
281 |
282 | def forward(self, img):
283 | return img
284 |
--------------------------------------------------------------------------------
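As a rough usage sketch (not part of the repository files above), the channel-level augmentations defined in data/transform.py can be slotted into a standard torchvision preprocessing pipeline; they operate on PIL images, so they belong before ToTensor. The image size, probability, and file path below are illustrative assumptions:

    import torchvision.transforms as T
    from PIL import Image

    from data.transform import ChannelAugmentation  # defined above

    # Example pipeline: spectrum jitter ('sj') applied with probability 0.5.
    train_transform = T.Compose([
        T.Resize((384, 128)),                   # illustrative input size
        T.RandomHorizontalFlip(),
        ChannelAugmentation(mode='sj', p=0.5),  # works on the PIL image, so it precedes ToTensor
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    img = Image.open('some_visible_image.jpg').convert('RGB')  # hypothetical path
    tensor = train_transform(img)
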
/engine/__init__.py:
--------------------------------------------------------------------------------
1 | """MCJA/engine/__init__.py
2 | It initializes the training and evaluation engines for the Multi-level Cross-modality Joint Alignment (MCJA) method.
3 | """
4 |
5 | import os
6 |
7 | import numpy as np
8 | import scipy.io as sio
9 | import torch
10 |
11 | from glob import glob
12 | from ignite.engine import Events
13 | from ignite.handlers import ModelCheckpoint
14 | from ignite.handlers import Timer
15 |
16 | from engine.engine import create_eval_engine
17 | from engine.engine import create_train_engine
18 | from engine.metric import AutoKVMetric
19 | from utils.eval_data import eval_sysu, eval_regdb
20 | from configs.default.dataset import dataset_cfg
21 |
22 |
23 | def get_trainer(dataset, model, optimizer, lr_scheduler=None, logger=None, writer=None, non_blocking=False,
24 | log_period=10, save_dir="checkpoints", prefix="model", eval_interval=None, start_eval=None,
25 | gallery_loader=None, query_loader=None):
26 | """
27 | A factory function that assembles and returns a training engine configured for VI-ReID tasks. This function sets up
28 | a trainer with custom event handlers for various stages of the training process, including model checkpointing,
29 | evaluation, logging, and learning rate scheduling. It integrates functionalities for performance evaluation using
30 | specified metrics and supports conditional execution of evaluations and logging activities based on the training progress.
31 |
32 | Args:
33 | - dataset (str): The name of the dataset being used, which dictates certain evaluation protocols.
34 | - model (nn.Module): The neural network model to be trained.
35 | - optimizer (Optimizer): The optimizer used for training the model.
36 | - lr_scheduler (Optional[Scheduler]): A learning rate scheduler for adjusting the learning rate across epochs.
37 | - logger (Logger): A logger for recording training progress and evaluation results.
38 | - writer (Optional[SummaryWriter]): A TensorBoard writer for logging metrics and visualizations.
39 | - non_blocking (bool): If set to True, attempts to asynchronously transfer data to device to improve performance.
40 | - log_period (int): The frequency (in iterations) with which training metrics are logged.
41 | - save_dir (str): The directory where model checkpoints are saved.
42 | - prefix (str): The prefix used for naming saved model files.
43 | - eval_interval (Optional[int]): The frequency (in epochs) with which the model is evaluated.
44 | - start_eval (Optional[int]): The epoch from which to start performing evaluations.
45 | - gallery_loader (Optional[DataLoader]): The DataLoader for the gallery set used in evaluations.
46 | - query_loader (Optional[DataLoader]): The DataLoader for the query set used in evaluations.
47 |
48 | Returns:
49 | - Engine: An Ignite Engine object configured for training,
50 | equipped with handlers for checkpointing, evaluation, and logging.
51 | """
52 |
53 | # Trainer
54 | trainer = create_train_engine(model, optimizer, non_blocking)
55 |
56 | # Checkpoint Handler
57 | handler = ModelCheckpoint(save_dir, prefix, save_interval=eval_interval, n_saved=3, create_dir=True,
58 | save_as_state_dict=True, require_empty=False)
59 | trainer.add_event_handler(Events.EPOCH_COMPLETED, handler, {"model": model})
60 |
61 | timer = Timer(average=True)
62 | kv_metric = AutoKVMetric()
63 |
64 | # Evaluator
65 | evaluator = None
66 | if not isinstance(eval_interval, int):
67 | raise TypeError("The parameter 'eval_interval' must be of type int.")
68 | if not isinstance(start_eval, int):
69 | raise TypeError("The parameter 'start_eval' must be of type int.")
70 | if eval_interval > 0 and gallery_loader is not None and query_loader is not None:
71 | evaluator = create_eval_engine(model, non_blocking)
72 |
73 | def run_init_eval(engine):
74 | logger.info('\n## Checking model performance with initial parameters...')
75 |
76 | # Extract Query Feature
77 | evaluator.run(query_loader)
78 | q_feats = torch.cat(evaluator.state.feat_list, dim=0)
79 | q_ids = torch.cat(evaluator.state.id_list, dim=0).numpy()
80 | q_cams = torch.cat(evaluator.state.cam_list, dim=0).numpy()
81 | q_img_paths = np.concatenate(evaluator.state.img_path_list, axis=0)
82 |
83 | # Extract Gallery Feature
84 | evaluator.run(gallery_loader)
85 | g_feats = torch.cat(evaluator.state.feat_list, dim=0)
86 | g_ids = torch.cat(evaluator.state.id_list, dim=0).numpy()
87 | g_cams = torch.cat(evaluator.state.cam_list, dim=0).numpy()
88 | g_img_paths = np.concatenate(evaluator.state.img_path_list, axis=0)
89 |
90 | if dataset == 'sysu':
91 | perm = sio.loadmat(os.path.join(dataset_cfg.sysu.data_root, 'exp', 'rand_perm_cam.mat'))['rand_perm_cam']
92 | eval_sysu(q_feats, q_ids, q_cams, q_img_paths,
93 | g_feats, g_ids, g_cams, g_img_paths,
94 | perm, mode='all', num_shots=1)
95 | eval_sysu(q_feats, q_ids, q_cams, q_img_paths,
96 | g_feats, g_ids, g_cams, g_img_paths,
97 | perm, mode='all', num_shots=10)
98 | eval_sysu(q_feats, q_ids, q_cams, q_img_paths,
99 | g_feats, g_ids, g_cams, g_img_paths,
100 | perm, mode='indoor', num_shots=1)
101 | eval_sysu(q_feats, q_ids, q_cams, q_img_paths,
102 | g_feats, g_ids, g_cams, g_img_paths,
103 | perm, mode='indoor', num_shots=10)
104 | elif dataset == 'regdb':
105 | logger.info('Test Mode - infrared to visible')
106 | eval_regdb(q_feats, q_ids, q_cams, q_img_paths,
107 | g_feats, g_ids, g_cams, g_img_paths, mode='i2v')
108 | logger.info('Test Mode - visible to infrared')
109 | eval_regdb(g_feats, g_ids, g_cams, g_img_paths,
110 | q_feats, q_ids, q_cams, q_img_paths, mode='v2i')
111 | else:
112 | raise NotImplementedError(f'Dataset - {dataset} is not supported')
113 |
114 | evaluator.state.feat_list.clear()
115 | evaluator.state.id_list.clear()
116 | evaluator.state.cam_list.clear()
117 | evaluator.state.img_path_list.clear()
118 | del q_feats, q_ids, q_cams, g_feats, g_ids, g_cams
119 |
120 | logger.info('\n## Starting the training process...')
121 |
122 | @trainer.on(Events.STARTED)
123 | def train_start(engine):
124 | setattr(engine.state, "best_rank1", 0.0)
125 | run_init_eval(engine)
126 |
127 | @trainer.on(Events.COMPLETED)
128 | def train_completed(engine):
129 | pass
130 |
131 | @trainer.on(Events.EPOCH_STARTED)
132 | def epoch_started_callback(engine):
133 | kv_metric.reset()
134 | timer.reset()
135 |
136 | @trainer.on(Events.EPOCH_COMPLETED)
137 | def epoch_completed_callback(engine):
138 | epoch = engine.state.epoch
139 |
140 | if lr_scheduler is not None:
141 | lr_scheduler.step()
142 |
143 | if epoch % eval_interval == 0:
144 | logger.info("Model saved at {}/{}_model_{}.pth".format(save_dir, prefix, epoch))
145 |
146 | if evaluator and epoch % eval_interval == 0 and epoch >= start_eval:
147 | # Extract Query Feature
148 | evaluator.run(query_loader)
149 | q_feats = torch.cat(evaluator.state.feat_list, dim=0)
150 | q_ids = torch.cat(evaluator.state.id_list, dim=0).numpy()
151 | q_cams = torch.cat(evaluator.state.cam_list, dim=0).numpy()
152 | q_img_paths = np.concatenate(evaluator.state.img_path_list, axis=0)
153 |
154 | # Extract Gallery Feature
155 | evaluator.run(gallery_loader)
156 | g_feats = torch.cat(evaluator.state.feat_list, dim=0)
157 | g_ids = torch.cat(evaluator.state.id_list, dim=0).numpy()
158 | g_cams = torch.cat(evaluator.state.cam_list, dim=0).numpy()
159 | g_img_paths = np.concatenate(evaluator.state.img_path_list, axis=0)
160 |
161 | if dataset == 'sysu':
162 | perm = sio.loadmat(os.path.join(dataset_cfg.sysu.data_root, 'exp', 'rand_perm_cam.mat'))[
163 | 'rand_perm_cam']
164 | r1, r5, r10, r20, mAP, mINP, _ = eval_sysu(q_feats, q_ids, q_cams, q_img_paths,
165 | g_feats, g_ids, g_cams, g_img_paths,
166 | perm, mode='all', num_shots=1)
167 | elif dataset == 'regdb':
168 | logger.info('Test Mode - infrared to visible')
169 | r1, r5, r10, r20, mAP, mINP, _ = eval_regdb(q_feats, q_ids, q_cams, q_img_paths,
170 | g_feats, g_ids, g_cams, g_img_paths, mode='i2v')
171 | logger.info('Test Mode - visible to infrared')
172 | r1_, r5_, r10_, r20_, mAP_, mINP_, _ = eval_regdb(g_feats, g_ids, g_cams, g_img_paths,
173 | q_feats, q_ids, q_cams, q_img_paths, mode='v2i')
174 | r1 = (r1 + r1_) / 2
175 | r5 = (r5 + r5_) / 2
176 | r10 = (r10 + r10_) / 2
177 | r20 = (r20 + r20_) / 2
178 | mAP = (mAP + mAP_) / 2
179 | mINP = (mINP + mINP_) / 2
180 | else:
181 | raise NotImplementedError(f'Dataset - {dataset} is not supported')
182 |
183 | if r1 > engine.state.best_rank1:
184 | for rm_best_model_path in glob("{}/{}_model_best-*.pth".format(save_dir, prefix)):
185 | os.remove(rm_best_model_path)
186 | engine.state.best_rank1 = r1
187 | torch.save(model.state_dict(), "{}/{}_model_best-{}.pth".format(save_dir, prefix, epoch))
188 |
189 | if writer is not None:
190 | writer.add_scalar('eval/r1', r1, epoch)
191 | writer.add_scalar('eval/r5', r5, epoch)
192 | writer.add_scalar('eval/r10', r10, epoch)
193 | writer.add_scalar('eval/r20', r20, epoch)
194 | writer.add_scalar('eval/mAP', mAP, epoch)
195 | writer.add_scalar('eval/mINP', mINP, epoch)
196 |
197 | evaluator.state.feat_list.clear()
198 | evaluator.state.id_list.clear()
199 | evaluator.state.cam_list.clear()
200 | evaluator.state.img_path_list.clear()
201 | del q_feats, q_ids, q_cams, g_feats, g_ids, g_cams
202 |
203 | @trainer.on(Events.ITERATION_COMPLETED)
204 | def iteration_complete_callback(engine):
205 | timer.step()
206 | kv_metric.update(engine.state.output)
207 |
208 | epoch = engine.state.epoch
209 | iteration = engine.state.iteration
210 | iter_in_epoch = iteration - (epoch - 1) * len(engine.state.dataloader)
211 |
212 | if iter_in_epoch % log_period == 0 and iter_in_epoch > 0:
213 | batch_size = engine.state.batch[0].size(0)
214 | speed = batch_size / timer.value()
215 | msg = "Epoch[%d] Batch [%d] Speed: %.2f samples/sec" % (epoch, iter_in_epoch, speed)
216 | metric_dict = kv_metric.compute()
217 | if logger is not None:
218 | for k in sorted(metric_dict.keys()):
219 | msg += " %s: %.4f" % (k, metric_dict[k])
220 | if writer is not None:
221 | writer.add_scalar('metric/{}'.format(k), metric_dict[k], iteration)
222 | logger.info(msg)
223 | kv_metric.reset()
224 | timer.reset()
225 |
226 | return trainer
227 |
--------------------------------------------------------------------------------
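For readers unfamiliar with Ignite, all the handlers registered in get_trainer follow the same event pattern; the minimal, generic sketch below (illustrative Ignite usage, not MCJA code) shows how an Engine wraps a per-batch process function and fires decorated callbacks on events such as EPOCH_COMPLETED:

    from ignite.engine import Engine, Events

    def process_batch(engine, batch):
        # In get_trainer, this role is filled by the function built in create_train_engine.
        return {"loss": sum(batch) / len(batch)}

    trainer = Engine(process_batch)

    @trainer.on(Events.EPOCH_COMPLETED)
    def on_epoch_completed(engine):
        # get_trainer hooks LR stepping, checkpoint logging, and evaluation at this event.
        print(f"epoch {engine.state.epoch} done, last output: {engine.state.output}")

    trainer.run([[1.0, 2.0], [3.0, 4.0]], max_epochs=2)
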
/engine/engine.py:
--------------------------------------------------------------------------------
1 | """MCJA/engine/engine.py
2 | It defines the creation of training and evaluation engines using the Ignite library.
3 | """
4 |
5 | import numpy as np
6 | import torch
7 |
8 | from torch.autograd import no_grad
9 | from ignite.engine import Engine
10 | from ignite.engine import Events
11 | from apex import amp
12 |
13 |
14 | def create_train_engine(model, optimizer, non_blocking=False):
15 | """
16 | A factory function that creates and returns an Ignite Engine configured for training a VI-ReID model. This engine
17 | orchestrates the training process, managing the data flow, loss calculation, parameter updates, and any additional
18 | computations needed per iteration. The function encapsulates the core training loop, including data loading to the
19 | device, executing model's forward pass, computing the loss, performing backpropagation, and updating model weights.
20 |
21 | Args:
22 | - model (nn.Module): The model to be trained. The model should accept input data, labels, camera IDs, and
23 | potentially other information like image paths and epoch number, returning computed loss and additional metrics.
24 | - optimizer (Optimizer): The optimizer used for updating the model parameters based on the computed gradients.
25 | - non_blocking (bool): If set to True, allows asynchronous data transfers to the GPU for improved efficiency.
26 |
27 | Returns:
28 | - Engine: An Ignite Engine object that processes batches of data using the provided model and optimizer.
29 | """
30 |
31 | device = torch.device("cuda", torch.cuda.current_device())
32 |
33 | def _process_func(engine, batch):
34 | model.train()
35 |
36 | data, labels, cam_ids, img_paths, img_idxes = batch
37 | epoch = engine.state.epoch
38 | data = data.to(device, non_blocking=non_blocking)
39 | labels = labels.to(device, non_blocking=non_blocking)
40 | cam_ids = cam_ids.to(device, non_blocking=non_blocking)
41 |
42 | optimizer.zero_grad(set_to_none=True)
43 |
44 | loss, metric = model(data, labels,
45 | cam_ids=cam_ids,
46 | img_paths=img_paths,
47 | epoch=epoch)
48 |
49 | with amp.scale_loss(loss, optimizer) as scaled_loss:
50 | scaled_loss.backward()
51 |
52 | optimizer.step()
53 |
54 | return metric
55 |
56 | return Engine(_process_func)
57 |
58 |
59 | def create_eval_engine(model, non_blocking=False):
60 | """
61 | A factory function that creates and returns an Ignite Engine configured for evaluating a VI-ReID model. This engine
62 | manages the evaluation process, facilitating the flow of data through the model and the collection of output features
63 | for later analysis. It operates in evaluation mode, ensuring that the model's behavior is consistent with inference
64 | conditions, such as disabled dropout layers.
65 |
66 | Args:
67 | - model (nn.Module): The model to be evaluated. The model should accept input data and potentially other
68 | information like camera IDs, returning feature representations.
69 | - non_blocking (bool): If set to True, allows asynchronous data transfers to the GPU to improve efficiency.
70 |
71 | Returns:
72 | - Engine: An Ignite Engine object that processes batches of data through the provided model in evaluation mode.
73 | """
74 |
75 | device = torch.device("cuda", torch.cuda.current_device())
76 |
77 | def _process_func(engine, batch):
78 | model.eval()
79 | data, labels, cam_ids, img_paths = batch[:4]
80 | data = data.to(device, non_blocking=non_blocking)
81 |
82 | with no_grad():
83 | feat = model(data, cam_ids=cam_ids.to(device, non_blocking=non_blocking))
84 |
85 | return feat.data.float().cpu(), labels, cam_ids, np.array(img_paths)
86 |
87 | engine = Engine(_process_func)
88 |
89 | @engine.on(Events.EPOCH_STARTED)
90 | def clear_data(engine):
91 | if not hasattr(engine.state, "feat_list"):
92 | setattr(engine.state, "feat_list", [])
93 | else:
94 | engine.state.feat_list.clear()
95 |
96 | if not hasattr(engine.state, "id_list"):
97 | setattr(engine.state, "id_list", [])
98 | else:
99 | engine.state.id_list.clear()
100 |
101 | if not hasattr(engine.state, "cam_list"):
102 | setattr(engine.state, "cam_list", [])
103 | else:
104 | engine.state.cam_list.clear()
105 |
106 | if not hasattr(engine.state, "img_path_list"):
107 | setattr(engine.state, "img_path_list", [])
108 | else:
109 | engine.state.img_path_list.clear()
110 |
111 | @engine.on(Events.ITERATION_COMPLETED)
112 | def store_data(engine):
113 | engine.state.feat_list.append(engine.state.output[0])
114 | engine.state.id_list.append(engine.state.output[1])
115 | engine.state.cam_list.append(engine.state.output[2])
116 | engine.state.img_path_list.append(engine.state.output[3])
117 |
118 | return engine
119 |
--------------------------------------------------------------------------------
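The training step above relies on NVIDIA apex for mixed precision. As a substitute sketch only (not the repository's implementation), a roughly equivalent step using torch.cuda.amp could look like the following, assuming the model returns a (loss, metrics) pair as in create_train_engine:

    import torch

    scaler = torch.cuda.amp.GradScaler()

    def amp_train_step(model, optimizer, data, labels, **kwargs):
        model.train()
        optimizer.zero_grad(set_to_none=True)
        with torch.cuda.amp.autocast():
            loss, metric = model(data, labels, **kwargs)  # assumed (loss, metrics) return, mirroring the code above
        scaler.scale(loss).backward()   # replaces amp.scale_loss(...).backward()
        scaler.step(optimizer)          # unscales gradients, then calls optimizer.step()
        scaler.update()
        return metric
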
/engine/metric.py:
--------------------------------------------------------------------------------
1 | """MCJA/engine/metric.py
2 | It provides a flexible mechanism for aggregating and computing metrics of cross-modality person re-identification.
3 | """
4 |
5 | from collections import defaultdict
6 |
7 | import torch
8 | from ignite.exceptions import NotComputableError
9 | from ignite.metrics import Metric, Accuracy
10 |
11 |
12 | class ScalarMetric(Metric):
13 | """
14 | A simple, generic implementation of an Ignite Metric for aggregating scalar values over iterations or epochs. This
15 | class provides a framework for tracking and computing the average of any scalar metric (e.g., loss, accuracy) during
16 | the training or evaluation process of a machine learning model. It accumulates the sum of the scalar values and the
17 | count of instances (batches) it has seen, allowing for the calculation of average scalar metric over all instances.
18 |
19 | Methods:
20 | - update(value): Adds a new scalar value to the running sum and increments the instance count.
21 | This method is called at each iteration with the scalar metric value for that iteration.
22 | - reset(): Resets the running sum and instance count to zero.
23 | Typically called at the start of each epoch or evaluation run to prepare for new calculations.
24 | - compute(): Calculates and returns the average of all scalar values added since the last reset.
25 | If no instances have been added, it raises a NotComputableError, indicating that there is not enough data.
26 | """
27 |
28 | def update(self, value):
29 | self.sum_metric += value
30 | self.sum_inst += 1
31 |
32 | def reset(self):
33 | self.sum_inst = 0
34 | self.sum_metric = 0
35 |
36 | def compute(self):
37 | if self.sum_inst == 0:
38 | raise NotComputableError('ScalarMetric must have at least one value before it can be computed')
39 | return self.sum_metric / self.sum_inst
40 |
41 |
42 | class IgnoreAccuracy(Accuracy):
43 | """
44 | An extension of the Ignite Accuracy metric that incorporates the ability to ignore certain target labels during the
45 | computation of accuracy. This class is particularly useful in scenarios where some target labels in the dataset
46 | should not contribute to the accuracy calculation, such as padding tokens in sequence models or background classes
47 | in segmentation tasks. By specifying an ignore index, instances with this target label are excluded from both the
48 | numerator and denominator of the accuracy calculation.
49 |
50 | Args:
51 | - ignore_index (int): The target label that should be ignored in the accuracy computation. Instances with this
52 | label are not considered correct or incorrect predictions, effectively being excluded from the metric.
53 |
54 | Methods:
55 | - reset(): Resets the internal counters for correct predictions and total examples,
56 | preparing the metric for a new set of calculations.
57 | - update(output): Processes a batch of predictions and targets,
58 | updating the internal counters by counting correct predictions that do not correspond to the ignore index.
59 | - compute(): Calculates and returns the accuracy over all batches processed since the last reset,
60 | excluding instances with the ignore index from the calculation.
61 | """
62 |
63 | def __init__(self, ignore_index=-1):
64 | super(IgnoreAccuracy, self).__init__()
65 |
66 | self.ignore_index = ignore_index
67 |
68 | def reset(self):
69 | self._num_correct = 0
70 | self._num_examples = 0
71 |
72 | def update(self, output):
73 |
74 | y_pred, y = self._check_shape(output)
75 | self._check_type((y_pred, y))
76 |
77 | if self._type == "binary":
78 | indices = torch.round(y_pred).type(y.type())
79 | elif self._type == "multiclass":
80 | indices = torch.max(y_pred, dim=1)[1]
81 |
82 | correct = torch.eq(indices, y).view(-1)
83 | ignore = torch.eq(y, self.ignore_index).view(-1)
84 | self._num_correct += torch.sum(correct).item()
85 | self._num_examples += correct.shape[0] - ignore.sum().item()
86 |
87 | def compute(self):
88 | if self._num_examples == 0:
89 | raise NotComputableError('Accuracy must have at least one example before it can be computed')
90 | return self._num_correct / self._num_examples
91 |
92 |
93 | class AutoKVMetric(Metric):
94 | """
95 | A flexible metric class in the Ignite framework that computes and stores key-value (KV) pair metrics for each
96 | output of a model during training or evaluation. The AutoKVMetric class is designed to handle outputs in the
97 | form of dictionaries where each key corresponds to a specific metric name, and its value represents the metric
98 | value for that batch. This class allows for the automatic aggregation of multiple metrics over all batches,
99 | providing a convenient way to track a variety of performance indicators within a single metric class.
100 |
101 | Methods:
102 | - update(output): Updates the running sum of each metric based on the current batch's output. The output is expected
103 | to be a dictionary where each key-value pair represents a metric name and its corresponding value.
104 | - reset(): Resets all internal counters and sums for each metric, preparing metric for a new round of calculations.
105 | This method is typically called at the start of each epoch or evaluation run.
106 | - compute(): Calculates and returns the average value of each metric over all processed batches since last reset.
107 | The return value is a dictionary mirroring the structure of the input to `update`, with each key corresponding to
108 | a metric name and each value being the average metric value.
109 | """
110 |
111 | def __init__(self):
112 | self.kv_sum_metric = defaultdict(lambda: torch.tensor(0., device="cuda"))
113 | self.kv_sum_inst = defaultdict(lambda: torch.tensor(0., device="cuda"))
114 |
115 | self.kv_metric = defaultdict(lambda: 0)
116 |
117 | super(AutoKVMetric, self).__init__()
118 |
119 | def update(self, output):
120 | if not isinstance(output, dict):
121 | raise TypeError('The output must be a key-value dict.')
122 |
123 | for k in output.keys():
124 | self.kv_sum_metric[k].add_(output[k])
125 | self.kv_sum_inst[k].add_(1)
126 |
127 | def reset(self):
128 | for k in self.kv_sum_metric.keys():
129 | self.kv_sum_metric[k].zero_()
130 | self.kv_sum_inst[k].zero_()
131 | self.kv_metric[k] = 0
132 |
133 | def compute(self):
134 | for k in self.kv_sum_metric.keys():
135 | if self.kv_sum_inst[k] == 0:
136 | continue
137 | # raise NotComputableError('Accuracy must have at least one example before it can be computed')
138 |
139 | metric_value = self.kv_sum_metric[k] / self.kv_sum_inst[k]
140 | self.kv_metric[k] = metric_value.item()
141 |
142 | return self.kv_metric
143 |
--------------------------------------------------------------------------------
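A small sketch of how AutoKVMetric is meant to be fed (its internal running sums are allocated on the GPU, so a CUDA device is assumed); the keys and values below are arbitrary examples:

    from engine.metric import AutoKVMetric  # defined above

    metric = AutoKVMetric()
    metric.reset()
    metric.update({"id_loss": 1.2, "acc": 0.50})  # per-batch dict, as produced by the train engine
    metric.update({"id_loss": 0.8, "acc": 0.75})
    print(metric.compute())  # averages over the two updates: id_loss -> 1.0, acc -> 0.625
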
/figs/mcja_overall_structure.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/workingcoder/MCJA/efc1ffe9f1cb2c4f76a463bf708150626555b295/figs/mcja_overall_structure.png
--------------------------------------------------------------------------------
/losses/__init__.py:
--------------------------------------------------------------------------------
1 | """MCJA/losses/__init__.py
2 | It is used to mark a directory as a Python package directory.
3 | """
--------------------------------------------------------------------------------
/losses/cm_retrieval_loss.py:
--------------------------------------------------------------------------------
1 | """MCJA/losses/cm_retrieval_loss.py
2 | It defines the `CMRetrievalLoss` class, a loss function specifically designed for cross-modality retrieval task.
3 | """
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torchsort
8 |
9 | from torch.nn import functional as F
10 |
11 |
12 | class CMRetrievalLoss(nn.Module):
13 | """
14 | A module that implements a Cross-Modal Retrieval (CMR) Loss, designed for use in training models on tasks involving
15 | the retrieval of relevant items across different modalities. The CMR Loss computes pairwise distances between
16 | embeddings from different modalities, classifying them as either matches (same identity label across modalities) or
17 | mismatches (different identity labels). It uses a soft ranking mechanism to assign ranks based on the pairwise
18 | distances, aiming to ensure that matches are ranked higher (closer) than mismatches.
19 |
20 | Args:
21 | - embeddings (Tensor): The embeddings generated by the model for a batch of items.
22 | - id_labels (Tensor): The identity labels for each item in the batch.
23 | - m_labels (Tensor): The modality labels for each item in the batch.
24 |
25 | Methods:
26 | - forward(embeddings, id_labels, m_labels): Computes the CMR Loss for a batch of embeddings,
27 | calculating the soft ranking loss between predicted and target ranks.
28 | """
29 |
30 | def __init__(self):
31 | super(CMRetrievalLoss, self).__init__()
32 |
33 | def forward(self, embeddings, id_labels, m_labels):
34 | m_labels_unique = torch.unique(m_labels)
35 | m_num = len(m_labels_unique)
36 |
37 | embeddings_list = [embeddings[m_labels == m_label] for m_label in m_labels_unique]
38 | id_labels_list = [id_labels[m_labels == m_label] for m_label in m_labels_unique]
39 |
40 | cmr_loss = 0
41 | valid_m_count = 0
42 | for i in range(m_num):
43 | cur_m_embeddings = embeddings_list[i]
44 | cur_m_id_labels = id_labels_list[i]
45 | other_m_embeddings = torch.cat([embeddings_list[j] for j in range(len(m_labels_unique)) if j != i], dim=0)
46 | other_m_id_labels = torch.cat([id_labels_list[j] for j in range(len(m_labels_unique)) if j != i], dim=0)
47 |
48 | match_mask = cur_m_id_labels.unsqueeze(dim=1) == other_m_id_labels.unsqueeze(dim=0)
49 | mismatch_mask = ~match_mask
50 | match_num = match_mask.sum(dim=-1)
51 | mismatch_num = mismatch_mask.sum(dim=-1)
52 |
53 | # Remove invalid queries (those with no cross-modal match in the batch, or whose cross-modal samples are all matches)
54 | remove_mask = (match_num == 0) | (match_num == len(other_m_id_labels))
55 | if remove_mask.all():
56 | continue
57 | cur_m_embeddings = cur_m_embeddings[~remove_mask]
58 | cur_m_id_labels = cur_m_id_labels[~remove_mask]
59 | match_mask = match_mask[~remove_mask]
60 | mismatch_mask = mismatch_mask[~remove_mask]
61 | match_num = match_num[~remove_mask]
62 | mismatch_num = mismatch_num[~remove_mask]
63 |
64 | dist_mat = F.cosine_similarity(cur_m_embeddings[:, None, :], other_m_embeddings[None, :, :], dim=-1)
65 | dist_mat = (1 - dist_mat) / 2
66 |
67 | predict_rank = torchsort.soft_rank(dist_mat, regularization="l2", regularization_strength=0.5)
68 |
69 | target_rank = torch.zeros_like(predict_rank)
70 | q_num, g_num = match_mask.shape
71 |
72 | target_rank[match_mask] = 1
73 | target_rank[mismatch_mask] = g_num
74 |
75 | cmr_loss += F.l1_loss(predict_rank, target_rank)
76 |
77 | valid_m_count += 1
78 |
79 | cmr_loss = cmr_loss / valid_m_count if valid_m_count > 0 else 0
80 |
81 | return cmr_loss
82 |
--------------------------------------------------------------------------------
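An illustrative call of CMRetrievalLoss on a toy batch (requires the torchsort package imported above); the shapes, identities, and modality coding are made-up examples:

    import torch
    from losses.cm_retrieval_loss import CMRetrievalLoss

    embeddings = torch.randn(8, 256, requires_grad=True)   # 8 samples with 256-d features
    id_labels = torch.tensor([0, 0, 1, 1, 0, 0, 1, 1])     # person identities
    m_labels = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1])      # e.g., 0 = visible, 1 = infrared

    criterion = CMRetrievalLoss()
    loss = criterion(embeddings, id_labels, m_labels)       # pushes cross-modal matches to rank ahead of mismatches
    loss.backward()                                         # soft_rank keeps the ranking step differentiable
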
/main.py:
--------------------------------------------------------------------------------
1 | """MCJA/main.py
2 | It is the main entry point for training and testing the Multi-level Cross-modality Joint Alignment (MCJA) method.
3 | """
4 |
5 | import os
6 | import glob
7 | import pprint
8 | import logging
9 |
10 | import numpy as np
11 | import scipy.io as sio
12 |
13 | import torch
14 | from torch import optim
15 | from torch.utils.tensorboard import SummaryWriter
16 | from apex import amp
17 |
18 | from data import get_train_loader
19 | from data import get_test_loader
20 | from models.mcja import MCJA
21 | from engine import get_trainer
22 | from engine.engine import create_eval_engine
23 | from utils.eval_data import eval_sysu, eval_regdb
24 |
25 |
26 | def train(cfg):
27 | # Recorder ---------------------------------------------------------------------------------------------------------
28 | logger = logging.getLogger('MCJA')
29 | tb_dir = os.path.join(cfg.log_dir, 'tensorboard')
30 | if not os.path.isdir(tb_dir):
31 | os.makedirs(tb_dir, exist_ok=True)
32 | writer = SummaryWriter(log_dir=tb_dir)
33 |
34 | # Train DataLoader -------------------------------------------------------------------------------------------------
35 | train_loader = get_train_loader(dataset=cfg.dataset, root=cfg.data_root,
36 | sample_method=cfg.sample_method,
37 | batch_size=cfg.batch_size,
38 | p_size=cfg.p_size,
39 | k_size=cfg.k_size,
40 | image_size=cfg.image_size,
41 | random_flip=cfg.random_flip,
42 | random_crop=cfg.random_crop,
43 | random_erase=cfg.random_erase,
44 | color_jitter=cfg.color_jitter,
45 | padding=cfg.padding,
46 | vimc_wg=cfg.vimc_wg,
47 | vimc_cc=cfg.vimc_cc,
48 | vimc_sj=cfg.vimc_sj,
49 | num_workers=4)
50 |
51 | # Test DataLoader --------------------------------------------------------------------------------------------------
52 | gallery_loader, query_loader = None, None
53 | if cfg.eval_interval > 0:
54 | gallery_loader, query_loader = get_test_loader(dataset=cfg.dataset,
55 | root=cfg.data_root,
56 | batch_size=cfg.batch_size,
57 | image_size=cfg.image_size,
58 | num_workers=4)
59 |
60 | # Model ------------------------------------------------------------------------------------------------------------
61 | model = MCJA(num_classes=cfg.num_id,
62 | drop_last_stride=cfg.drop_last_stride,
63 | mda_ratio=cfg.mda_ratio,
64 | mda_m=cfg.mda_m,
65 | loss_id=cfg.loss_id,
66 | loss_cmr=cfg.loss_cmr)
67 |
68 | def get_parameter_number(net):
69 | total_num = sum(p.numel() for p in net.parameters())
70 | trainable_num = sum(p.numel() for p in net.parameters() if p.requires_grad)
71 | return {'Total': total_num, 'Trainable': trainable_num}
72 |
73 | logger.info(f'Model Parameter Num - {get_parameter_number(model)}')
74 |
75 | model.cuda()
76 |
77 | # Optimizer --------------------------------------------------------------------------------------------------------
78 | optimizer = optim.Adam(model.parameters(), lr=cfg.lr, weight_decay=cfg.wd)
79 | model, optimizer = amp.initialize(model, optimizer, enabled=cfg.fp16, opt_level='O1', verbosity=0)
80 | lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer=optimizer, milestones=cfg.lr_step, gamma=0.1)
81 |
82 | # Resume -----------------------------------------------------------------------------------------------------------
83 | if cfg.resume:
84 | checkpoint = torch.load(cfg.resume)
85 | for key in list(checkpoint.keys()):
86 | model_state_dict = model.state_dict()
87 | if key in model_state_dict:
88 | if torch.is_tensor(checkpoint[key]) and checkpoint[key].shape != model_state_dict[key].shape:
89 | logger.info(f'Warning while loading weights - automatically removing mismatched key: {key}')
90 | checkpoint.pop(key)
91 | model.load_state_dict(checkpoint, strict=False)
92 |
93 | # Engine -----------------------------------------------------------------------------------------------------------
94 | checkpoint_dir = os.path.join('ckptlog/', cfg.dataset, cfg.prefix)
95 | engine = get_trainer(dataset=cfg.dataset,
96 | model=model,
97 | optimizer=optimizer,
98 | lr_scheduler=lr_scheduler,
99 | logger=logger,
100 | writer=writer,
101 | non_blocking=True,
102 | log_period=cfg.log_period,
103 | save_dir=checkpoint_dir,
104 | prefix=cfg.prefix,
105 | eval_interval=cfg.eval_interval,
106 | start_eval=cfg.start_eval,
107 | gallery_loader=gallery_loader,
108 | query_loader=query_loader)
109 | engine.run(train_loader, max_epochs=cfg.num_epoch)
110 | writer.close()
111 |
112 |
113 | def test(cfg):
114 | # Logger -----------------------------------------------------------------------------------------------------------
115 | logger = logging.getLogger('MCJA')
116 | logger.info('\n## Starting the testing process...')
117 |
118 | # Test DataLoader --------------------------------------------------------------------------------------------------
119 | gallery_loader, query_loader = get_test_loader(dataset=cfg.dataset,
120 | root=cfg.data_root,
121 | batch_size=cfg.batch_size,
122 | image_size=cfg.image_size,
123 | num_workers=4,
124 | mode=None)
125 | if cfg.mser:
126 | gallery_loader_r, query_loader_r = get_test_loader(dataset=cfg.dataset,
127 | root=cfg.data_root,
128 | batch_size=cfg.batch_size,
129 | image_size=cfg.image_size,
130 | num_workers=4,
131 | mode='r')
132 | gallery_loader_g, query_loader_g = get_test_loader(dataset=cfg.dataset,
133 | root=cfg.data_root,
134 | batch_size=cfg.batch_size,
135 | image_size=cfg.image_size,
136 | num_workers=4,
137 | mode='g')
138 | gallery_loader_b, query_loader_b = get_test_loader(dataset=cfg.dataset,
139 | root=cfg.data_root,
140 | batch_size=cfg.batch_size,
141 | image_size=cfg.image_size,
142 | num_workers=4,
143 | mode='b')
144 |
145 | # Model ------------------------------------------------------------------------------------------------------------
146 | model = MCJA(num_classes=cfg.num_id,
147 | drop_last_stride=cfg.drop_last_stride,
148 | mda_ratio=cfg.mda_ratio,
149 | mda_m=cfg.mda_m,
150 | loss_id=cfg.loss_id,
151 | loss_cmr=cfg.loss_cmr)
152 | model.cuda()
153 | model = amp.initialize(model, enabled=cfg.fp16, opt_level='O1', verbosity=0)
154 |
155 | # Resume -----------------------------------------------------------------------------------------------------------
156 | resume_path = cfg.resume if cfg.resume else glob.glob(f'{cfg.log_dir}/*best*')[0]
157 | ## Note: if cfg.resume is specified, it will be used;
158 | ## otherwise, the best model trained in the current experiment will be automatically loaded.
159 | checkpoint = torch.load(resume_path)
160 | for key in list(checkpoint.keys()):
161 | model_state_dict = model.state_dict()
162 | if key in model_state_dict:
163 | if torch.is_tensor(checkpoint[key]) and checkpoint[key].shape != model_state_dict[key].shape:
164 | logger.info(f'Warning while loading weights - automatically removing mismatched key: {key}')
165 | checkpoint.pop(key)
166 | model.load_state_dict(checkpoint, strict=False)
167 |
168 | # Evaluator --------------------------------------------------------------------------------------------------------
169 | non_blocking = True
170 | evaluator = create_eval_engine(model, non_blocking)
171 | # extract query feature
172 | evaluator.run(query_loader)
173 | q_feats = torch.cat(evaluator.state.feat_list, dim=0)
174 | q_ids = torch.cat(evaluator.state.id_list, dim=0).numpy()
175 | q_cams = torch.cat(evaluator.state.cam_list, dim=0).numpy()
176 | q_img_paths = np.concatenate(evaluator.state.img_path_list, axis=0)
177 | # extract gallery feature
178 | evaluator.run(gallery_loader)
179 | g_feats = torch.cat(evaluator.state.feat_list, dim=0)
180 | g_ids = torch.cat(evaluator.state.id_list, dim=0).numpy()
181 | g_cams = torch.cat(evaluator.state.cam_list, dim=0).numpy()
182 | g_img_paths = np.concatenate(evaluator.state.img_path_list, axis=0)
183 |
184 | if cfg.mser:
185 | ###### Multi-Spectral Enhanced Ranking (MSER) ######
186 | evaluator = create_eval_engine(model, non_blocking)
187 | # extract query feature mode - r
188 | evaluator.run(query_loader_r)
189 | q_feats_r = torch.cat(evaluator.state.feat_list, dim=0)
190 | q_ids_r = torch.cat(evaluator.state.id_list, dim=0).numpy()
191 | q_cams_r = torch.cat(evaluator.state.cam_list, dim=0).numpy()
192 | q_img_paths_r = np.concatenate(evaluator.state.img_path_list, axis=0)
193 | # extract gallery feature mode - r
194 | evaluator.run(gallery_loader_r)
195 | g_feats_r = torch.cat(evaluator.state.feat_list, dim=0)
196 | g_ids_r = torch.cat(evaluator.state.id_list, dim=0).numpy()
197 | g_cams_r = torch.cat(evaluator.state.cam_list, dim=0).numpy()
198 | g_img_paths_r = np.concatenate(evaluator.state.img_path_list, axis=0)
199 |
200 | evaluator = create_eval_engine(model, non_blocking)
201 | # extract query feature mode - g
202 | evaluator.run(query_loader_g)
203 | q_feats_g = torch.cat(evaluator.state.feat_list, dim=0)
204 | q_ids_g = torch.cat(evaluator.state.id_list, dim=0).numpy()
205 | q_cams_g = torch.cat(evaluator.state.cam_list, dim=0).numpy()
206 | q_img_paths_g = np.concatenate(evaluator.state.img_path_list, axis=0)
207 | # extract gallery feature mode - g
208 | evaluator.run(gallery_loader_g)
209 | g_feats_g = torch.cat(evaluator.state.feat_list, dim=0)
210 | g_ids_g = torch.cat(evaluator.state.id_list, dim=0).numpy()
211 | g_cams_g = torch.cat(evaluator.state.cam_list, dim=0).numpy()
212 | g_img_paths_g = np.concatenate(evaluator.state.img_path_list, axis=0)
213 |
214 | evaluator = create_eval_engine(model, non_blocking)
215 | # extract query feature mode - b
216 | evaluator.run(query_loader_b)
217 | q_feats_b = torch.cat(evaluator.state.feat_list, dim=0)
218 | q_ids_b = torch.cat(evaluator.state.id_list, dim=0).numpy()
219 | q_cams_b = torch.cat(evaluator.state.cam_list, dim=0).numpy()
220 | q_img_paths_b = np.concatenate(evaluator.state.img_path_list, axis=0)
221 | # extract gallery feature mode - b
222 | evaluator.run(gallery_loader_b)
223 | g_feats_b = torch.cat(evaluator.state.feat_list, dim=0)
224 | g_ids_b = torch.cat(evaluator.state.id_list, dim=0).numpy()
225 | g_cams_b = torch.cat(evaluator.state.cam_list, dim=0).numpy()
226 | g_img_paths_b = np.concatenate(evaluator.state.img_path_list, axis=0)
227 |
228 | q_feats_mser = [q_feats, q_feats_r, q_feats_g, q_feats_b]
229 | q_ids_mser = [q_ids, q_ids_r, q_ids_g, q_ids_b]
230 | q_cams_mser = [q_cams, q_cams_r, q_cams_g, q_cams_b]
231 | q_img_paths_mser = [q_img_paths, q_img_paths_r, q_img_paths_g, q_img_paths_b]
232 | g_feats_mser = [g_feats, g_feats_r, g_feats_g, g_feats_b]
233 | g_ids_mser = [g_ids, g_ids_r, g_ids_g, g_ids_b]
234 | g_cams_mser = [g_cams, g_cams_r, g_cams_g, g_cams_b]
235 | g_img_paths_mser = [g_img_paths, g_img_paths_r, g_img_paths_g, g_img_paths_b]
236 |
237 | if cfg.dataset == 'sysu':
238 | perm = sio.loadmat(os.path.join(cfg.data_root, 'exp', 'rand_perm_cam.mat'))['rand_perm_cam']
239 | eval_sysu(q_feats, q_ids, q_cams, q_img_paths,
240 | g_feats, g_ids, g_cams, g_img_paths, perm, mode='all', num_shots=1)
241 | eval_sysu(q_feats, q_ids, q_cams, q_img_paths,
242 | g_feats, g_ids, g_cams, g_img_paths, perm, mode='all', num_shots=10)
243 | eval_sysu(q_feats, q_ids, q_cams, q_img_paths,
244 | g_feats, g_ids, g_cams, g_img_paths, perm, mode='indoor', num_shots=1)
245 | eval_sysu(q_feats, q_ids, q_cams, q_img_paths,
246 | g_feats, g_ids, g_cams, g_img_paths, perm, mode='indoor', num_shots=10)
247 | if cfg.mser:
248 | eval_sysu(q_feats_mser, q_ids_mser, q_cams_mser, q_img_paths_mser,
249 | g_feats_mser, g_ids_mser, g_cams_mser, g_img_paths_mser,
250 | perm, mode='all', num_shots=1, mser=True)
251 | eval_sysu(q_feats_mser, q_ids_mser, q_cams_mser, q_img_paths_mser,
252 | g_feats_mser, g_ids_mser, g_cams_mser, g_img_paths_mser,
253 | perm, mode='all', num_shots=10, mser=True)
254 | eval_sysu(q_feats_mser, q_ids_mser, q_cams_mser, q_img_paths_mser,
255 | g_feats_mser, g_ids_mser, g_cams_mser, g_img_paths_mser,
256 | perm, mode='indoor', num_shots=1, mser=True)
257 | eval_sysu(q_feats_mser, q_ids_mser, q_cams_mser, q_img_paths_mser,
258 | g_feats_mser, g_ids_mser, g_cams_mser, g_img_paths_mser,
259 | perm, mode='indoor', num_shots=10, mser=True)
260 |
261 | elif cfg.dataset == 'regdb':
262 | logger.info('Test Mode - infrared to visible')
263 | eval_regdb(q_feats, q_ids, q_cams, q_img_paths,
264 | g_feats, g_ids, g_cams, g_img_paths, mode='i2v')
265 | logger.info('Test Mode - visible to infrared')
266 | eval_regdb(g_feats, g_ids, g_cams, g_img_paths,
267 | q_feats, q_ids, q_cams, q_img_paths, mode='v2i')
268 | if cfg.mser:
269 | logger.info('Test Mode - infrared to visible')
270 | eval_regdb(q_feats_mser, q_ids_mser, q_cams_mser, q_img_paths_mser,
271 | g_feats_mser, g_ids_mser, g_cams_mser, g_img_paths_mser, mode='i2v', mser=True)
272 | logger.info('Test Mode - visible to infrared')
273 | eval_regdb(g_feats_mser, g_ids_mser, g_cams_mser, g_img_paths_mser,
274 | q_feats_mser, q_ids_mser, q_cams_mser, q_img_paths_mser, mode='v2i', mser=True)
275 | else:
276 | raise NotImplementedError(f'Dataset - {cfg.dataset} is not supported')
277 |
278 | evaluator.state.feat_list.clear()
279 | evaluator.state.id_list.clear()
280 | evaluator.state.cam_list.clear()
281 | evaluator.state.img_path_list.clear()
282 |
283 |
284 | if __name__ == '__main__':
285 | # Tools ------------------------------------------------------------------------------------------------------------
286 | import argparse
287 | from configs.default import strategy_cfg
288 | from configs.default import dataset_cfg
289 | from utils.tools import set_seed, time_str
290 |
291 | # Argument Parser --------------------------------------------------------------------------------------------------
292 | parser = argparse.ArgumentParser()
293 | parser.add_argument('--cfg', type=str, default='configs/SYSU_MCJA.yml', help='customized strategy config')
294 | parser.add_argument('--seed', type=int, default=8, help='random seed - choose a lucky number')
295 | parser.add_argument('--desc', type=str, default=None, help='auxiliary description of this experiment')
296 | parser.add_argument('--gpu', type=str, default='0', help='GPU device for the training process')
297 | args = parser.parse_args()
298 |
299 | # Environment ------------------------------------------------------------------------------------------------------
300 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
301 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
302 | set_seed(args.seed)
303 |
304 | # Configuration ----------------------------------------------------------------------------------------------------
305 | ## strategy_cfg
306 | cfg = strategy_cfg
307 | cfg.merge_from_file(args.cfg)
308 | ## dataset_cfg
309 | dataset_cfg = dataset_cfg.get(cfg.dataset)
310 | for k, v in dataset_cfg.items():
311 | cfg[k] = v
312 | ## other cfg
313 | cfg.prefix += f'_Time-{time_str()}'
314 | cfg.prefix += f'_{args.desc}' if (args.desc is not None) else ''
315 | cfg['log_dir'] = os.path.join('ckptlog/', cfg.dataset, cfg.prefix)
316 | ## freeze cfg
317 | cfg.freeze()
318 |
319 | # Logger ---------------------------------------------------------------------------------------------------------
320 | if not os.path.isdir(cfg.log_dir):
321 | os.makedirs(cfg.log_dir, exist_ok=True)
322 | logger = logging.getLogger('MCJA')
323 | logger.setLevel(logging.DEBUG)
324 | consoleHandler = logging.StreamHandler()
325 | consoleHandler.setLevel(logging.INFO)
326 | fileHandler = logging.FileHandler(filename=os.path.join(cfg.log_dir, 'log.txt'))
327 | fileHandler.setLevel(logging.INFO)
328 | formatter = logging.Formatter(fmt='%(asctime)s %(message)s', datefmt='[%Y-%m-%d %H:%M:%S]')
329 | consoleHandler.setFormatter(formatter)
330 | fileHandler.setFormatter(formatter)
331 | logger.addHandler(consoleHandler)
332 | logger.addHandler(fileHandler)
333 | logger.info('\n' + pprint.pformat(cfg))
334 |
335 | # Train & Test -----------------------------------------------------------------------------------------------------
336 | if not cfg.test_only:
337 | train(cfg)
338 | test(cfg)
339 |
340 | # ------------------------------------------------------------------------------------------------------------------
341 |
--------------------------------------------------------------------------------
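For reference, a run is typically launched with the flags defined in the argparse block above and one of the config files shipped in configs/, for example:

    python main.py --cfg configs/SYSU_MCJA.yml --gpu 0
    python main.py --cfg configs/RegDB_MCJA.yml --gpu 0 --desc my_experiment

When the chosen config enables test_only, train(cfg) is skipped and only test(cfg) is executed.
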
/models/__init__.py:
--------------------------------------------------------------------------------
1 | """MCJA/models/__init__.py
2 | It is used to mark a directory as a Python package directory.
3 | """
--------------------------------------------------------------------------------
/models/backbones/__init__.py:
--------------------------------------------------------------------------------
1 | """MCJA/models/backbones/__init__.py
2 | It is used to mark a directory as a Python package directory.
3 | """
--------------------------------------------------------------------------------
/models/backbones/resnet.py:
--------------------------------------------------------------------------------
1 | """MCJA/models/backbones/resnet.py
2 | It implements a series of ResNets (the code originates from the standard library of PyTorch).
3 | """
4 |
5 | import torch
6 | from torch import Tensor
7 | import torch.nn as nn
8 | from torchvision.models.resnet import load_state_dict_from_url
9 | from typing import Type, Any, Callable, Union, List, Optional
10 |
11 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
12 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
13 | 'wide_resnet50_2', 'wide_resnet101_2']
14 |
15 | model_urls = {
16 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
17 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
18 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
19 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
20 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
21 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
22 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
23 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
24 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
25 |
26 | # 'resnet50': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_a1_0-14fe96d1.pth',
27 | # 'resnet50': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_a2_0-a2746f79.pth',
28 | # 'resnet50': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_a3_0-59cae1ef.pth',
29 | }
30 |
31 |
32 | def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
33 | """3x3 convolution with padding"""
34 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
35 | padding=dilation, groups=groups, bias=False, dilation=dilation)
36 |
37 |
38 | def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
39 | """1x1 convolution"""
40 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
41 |
42 |
43 | class BasicBlock(nn.Module):
44 | expansion: int = 1
45 |
46 | def __init__(
47 | self,
48 | inplanes: int,
49 | planes: int,
50 | stride: int = 1,
51 | downsample: Optional[nn.Module] = None,
52 | groups: int = 1,
53 | base_width: int = 64,
54 | dilation: int = 1,
55 | norm_layer: Optional[Callable[..., nn.Module]] = None
56 | ) -> None:
57 | super(BasicBlock, self).__init__()
58 | if norm_layer is None:
59 | norm_layer = nn.BatchNorm2d
60 | if groups != 1 or base_width != 64:
61 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
62 | if dilation > 1:
63 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
64 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
65 | self.conv1 = conv3x3(inplanes, planes, stride)
66 | self.bn1 = norm_layer(planes)
67 | self.relu = nn.ReLU(inplace=True)
68 | self.conv2 = conv3x3(planes, planes)
69 | self.bn2 = norm_layer(planes)
70 | self.downsample = downsample
71 | self.stride = stride
72 |
73 | def forward(self, x: Tensor) -> Tensor:
74 | identity = x
75 |
76 | out = self.conv1(x)
77 | out = self.bn1(out)
78 | out = self.relu(out)
79 |
80 | out = self.conv2(out)
81 | out = self.bn2(out)
82 |
83 | if self.downsample is not None:
84 | identity = self.downsample(x)
85 |
86 | out += identity
87 | out = self.relu(out)
88 |
89 | return out
90 |
91 |
92 | class Bottleneck(nn.Module):
93 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
94 | # while original implementation places the stride at the first 1x1 convolution(self.conv1)
95 | # according to "Deep residual learning for image recognition" https://arxiv.org/abs/1512.03385.
96 | # This variant is also known as ResNet V1.5 and improves accuracy according to
97 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
98 |
99 | expansion: int = 4
100 |
101 | def __init__(
102 | self,
103 | inplanes: int,
104 | planes: int,
105 | stride: int = 1,
106 | downsample: Optional[nn.Module] = None,
107 | groups: int = 1,
108 | base_width: int = 64,
109 | dilation: int = 1,
110 | norm_layer: Optional[Callable[..., nn.Module]] = None
111 | ) -> None:
112 | super(Bottleneck, self).__init__()
113 | if norm_layer is None:
114 | norm_layer = nn.BatchNorm2d
115 | width = int(planes * (base_width / 64.)) * groups
116 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1
117 | self.conv1 = conv1x1(inplanes, width)
118 | self.bn1 = norm_layer(width)
119 | self.conv2 = conv3x3(width, width, stride, groups, dilation)
120 | self.bn2 = norm_layer(width)
121 | self.conv3 = conv1x1(width, planes * self.expansion)
122 | self.bn3 = norm_layer(planes * self.expansion)
123 | self.relu = nn.ReLU(inplace=True)
124 | self.downsample = downsample
125 | self.stride = stride
126 |
127 | def forward(self, x: Tensor) -> Tensor:
128 | identity = x
129 |
130 | out = self.conv1(x)
131 | out = self.bn1(out)
132 | out = self.relu(out)
133 |
134 | out = self.conv2(out)
135 | out = self.bn2(out)
136 | out = self.relu(out)
137 |
138 | out = self.conv3(out)
139 | out = self.bn3(out)
140 |
141 | if self.downsample is not None:
142 | identity = self.downsample(x)
143 |
144 | out += identity
145 | out = self.relu(out)
146 |
147 | return out
148 |
149 |
150 | class ResNet(nn.Module):
151 |
152 | def __init__(
153 | self,
154 | block: Type[Union[BasicBlock, Bottleneck]],
155 | layers: List[int],
156 | num_classes: int = 1000,
157 | zero_init_residual: bool = False,
158 | groups: int = 1,
159 | width_per_group: int = 64,
160 | replace_stride_with_dilation: Optional[List[bool]] = None,
161 | norm_layer: Optional[Callable[..., nn.Module]] = None,
162 | **kwargs: Any
163 | ) -> None:
164 | super(ResNet, self).__init__()
165 | if norm_layer is None:
166 | norm_layer = nn.BatchNorm2d
167 | self._norm_layer = norm_layer
168 |
169 | self.inplanes = 64
170 | self.dilation = 1
171 | if replace_stride_with_dilation is None:
172 | # each element in the tuple indicates if we should replace
173 | # the 2x2 stride with a dilated convolution instead
174 | replace_stride_with_dilation = [False, False, False]
175 | if len(replace_stride_with_dilation) != 3:
176 | raise ValueError("replace_stride_with_dilation should be None "
177 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
178 | self.groups = groups
179 | self.base_width = width_per_group
180 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
181 | bias=False)
182 | self.bn1 = norm_layer(self.inplanes)
183 | self.relu = nn.ReLU(inplace=True)
184 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
185 | self.layer1 = self._make_layer(block, 64, layers[0])
186 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
187 | dilate=replace_stride_with_dilation[0])
188 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
189 | dilate=replace_stride_with_dilation[1])
190 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1 if kwargs.get('drop_last_stride') else 2,  # drop_last_stride=True keeps stride 1 in the last stage, preserving spatial resolution for ReID features
191 | dilate=replace_stride_with_dilation[2])
192 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
193 | self.fc = nn.Linear(512 * block.expansion, num_classes)
194 |
195 | for m in self.modules():
196 | if isinstance(m, nn.Conv2d):
197 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
198 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
199 | nn.init.constant_(m.weight, 1)
200 | nn.init.constant_(m.bias, 0)
201 |
202 | # Zero-initialize the last BN in each residual branch,
203 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
204 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
205 | if zero_init_residual:
206 | for m in self.modules():
207 | if isinstance(m, Bottleneck):
208 | nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type]
209 | elif isinstance(m, BasicBlock):
210 | nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type]
211 |
212 | def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int,
213 | stride: int = 1, dilate: bool = False) -> nn.Sequential:
214 | norm_layer = self._norm_layer
215 | downsample = None
216 | previous_dilation = self.dilation
217 | if dilate:
218 | self.dilation *= stride
219 | stride = 1
220 | if stride != 1 or self.inplanes != planes * block.expansion:
221 | downsample = nn.Sequential(
222 | conv1x1(self.inplanes, planes * block.expansion, stride),
223 | norm_layer(planes * block.expansion),
224 | )
225 |
226 | layers = []
227 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
228 | self.base_width, previous_dilation, norm_layer))
229 | self.inplanes = planes * block.expansion
230 | for _ in range(1, blocks):
231 | layers.append(block(self.inplanes, planes, groups=self.groups,
232 | base_width=self.base_width, dilation=self.dilation,
233 | norm_layer=norm_layer))
234 |
235 | return nn.Sequential(*layers)
236 |
237 | def _forward_impl(self, x: Tensor, mode) -> Tensor:
238 | # See note [TorchScript super()]
239 | x = self.conv1(x)
240 | x = self.bn1(x)
241 | x = self.relu(x)
242 | x = self.maxpool(x)
243 |
244 | x = self.layer1(x)
245 | x = self.layer2(x)
246 | x = self.layer3(x)
247 | x = self.layer4(x)
248 |
249 | if mode == 'features':
250 | return x
251 |
252 | x = self.avgpool(x)
253 | if mode == 'embeddings':
254 | return x
255 |
256 | x = torch.flatten(x, 1)
257 | x = self.fc(x)
258 | return x
259 |
260 | def forward(self, x: Tensor, mode='logits') -> Tensor:
261 | assert mode in ['features', 'embeddings', 'logits']  # 'features': pre-pooling map, 'embeddings': pooled vector, 'logits': classifier output
262 | return self._forward_impl(x, mode)
263 |
264 |
265 | def _resnet(
266 | arch: str,
267 | block: Type[Union[BasicBlock, Bottleneck]],
268 | layers: List[int],
269 | pretrained: bool,
270 | progress: bool,
271 | **kwargs: Any
272 | ) -> ResNet:
273 | model = ResNet(block, layers, **kwargs)
274 | if pretrained:
275 | state_dict = load_state_dict_from_url(model_urls[arch],
276 | progress=progress)
277 | model.load_state_dict(state_dict)
278 | return model
279 |
280 |
281 | def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
282 | r"""ResNet-18 model from
283 | `"Deep Residual Learning for Image Recognition" `_.
284 |
285 | Args:
286 | pretrained (bool): If True, returns a model pre-trained on ImageNet
287 | progress (bool): If True, displays a progress bar of the download to stderr
288 | """
289 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
290 | **kwargs)
291 |
292 |
293 | def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
294 | r"""ResNet-34 model from
295 | `"Deep Residual Learning for Image Recognition" `_.
296 |
297 | Args:
298 | pretrained (bool): If True, returns a model pre-trained on ImageNet
299 | progress (bool): If True, displays a progress bar of the download to stderr
300 | """
301 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
302 | **kwargs)
303 |
304 |
305 | def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
306 | r"""ResNet-50 model from
307 | `"Deep Residual Learning for Image Recognition" `_.
308 |
309 | Args:
310 | pretrained (bool): If True, returns a model pre-trained on ImageNet
311 | progress (bool): If True, displays a progress bar of the download to stderr
312 | """
313 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
314 | **kwargs)
315 |
316 |
317 | def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
318 | r"""ResNet-101 model from
319 | `"Deep Residual Learning for Image Recognition" `_.
320 |
321 | Args:
322 | pretrained (bool): If True, returns a model pre-trained on ImageNet
323 | progress (bool): If True, displays a progress bar of the download to stderr
324 | """
325 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
326 | **kwargs)
327 |
328 |
329 | def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
330 | r"""ResNet-152 model from
331 | `"Deep Residual Learning for Image Recognition" `_.
332 |
333 | Args:
334 | pretrained (bool): If True, returns a model pre-trained on ImageNet
335 | progress (bool): If True, displays a progress bar of the download to stderr
336 | """
337 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
338 | **kwargs)
339 |
340 |
341 | def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
342 | r"""ResNeXt-50 32x4d model from
343 | `"Aggregated Residual Transformation for Deep Neural Networks" `_.
344 |
345 | Args:
346 | pretrained (bool): If True, returns a model pre-trained on ImageNet
347 | progress (bool): If True, displays a progress bar of the download to stderr
348 | """
349 | kwargs['groups'] = 32
350 | kwargs['width_per_group'] = 4
351 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
352 | pretrained, progress, **kwargs)
353 |
354 |
355 | def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
356 | r"""ResNeXt-101 32x8d model from
357 | `"Aggregated Residual Transformation for Deep Neural Networks" `_.
358 |
359 | Args:
360 | pretrained (bool): If True, returns a model pre-trained on ImageNet
361 | progress (bool): If True, displays a progress bar of the download to stderr
362 | """
363 | kwargs['groups'] = 32
364 | kwargs['width_per_group'] = 8
365 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
366 | pretrained, progress, **kwargs)
367 |
368 |
369 | def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
370 | r"""Wide ResNet-50-2 model from
371 | `"Wide Residual Networks" `_.
372 |
373 | The model is the same as ResNet except that the number of bottleneck channels
374 | is twice as large in every block. The number of channels in the outer 1x1
375 | convolutions is unchanged, e.g. the last block in ResNet-50 has 2048-512-2048
376 | channels, while in Wide ResNet-50-2 it has 2048-1024-2048.
377 |
378 | Args:
379 | pretrained (bool): If True, returns a model pre-trained on ImageNet
380 | progress (bool): If True, displays a progress bar of the download to stderr
381 | """
382 | kwargs['width_per_group'] = 64 * 2
383 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
384 | pretrained, progress, **kwargs)
385 |
386 |
387 | def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
388 | r"""Wide ResNet-101-2 model from
389 | `"Wide Residual Networks" `_.
390 |
391 | The model is the same as ResNet except that the number of bottleneck channels
392 | is twice as large in every block. The number of channels in the outer 1x1
393 | convolutions is unchanged, e.g. the last block in ResNet-50 has 2048-512-2048
394 | channels, while in Wide ResNet-50-2 it has 2048-1024-2048.
395 |
396 | Args:
397 | pretrained (bool): If True, returns a model pre-trained on ImageNet
398 | progress (bool): If True, displays a progress bar of the download to stderr
399 | """
400 | kwargs['width_per_group'] = 64 * 2
401 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
402 | pretrained, progress, **kwargs)
403 |
--------------------------------------------------------------------------------
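The backbone above keeps the standard torchvision factory interface but adds a `drop_last_stride` option (layer4 stride 1) and a `mode` argument on `forward`. The sketch below is not part of the repository; it assumes the repository root is on `PYTHONPATH`, a 384x128 input size, and the usual torchvision default for the classifier head, and simply illustrates the three return modes:

```python
# Minimal sketch of the modified ResNet's forward modes (shapes assume a 384x128 input
# and drop_last_stride=True, which keeps stride 1 in layer4).
import torch
from models.backbones.resnet import resnet50

backbone = resnet50(pretrained=False, drop_last_stride=True)
x = torch.randn(2, 3, 384, 128)

feat_map = backbone(x, mode='features')    # [2, 2048, 24, 8] spatial feature map
pooled = backbone(x, mode='embeddings')    # [2, 2048, 1, 1] after adaptive avg pooling
logits = backbone(x, mode='logits')        # [2, num_classes] through the fc head
```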
/models/mcja.py:
--------------------------------------------------------------------------------
1 | """MCJA/models/mcja.py
2 | It defines the Multi-level Cross-modality Joint Alignment (MCJA) model, a framework for cross-modality VI-ReID task.
3 | """
4 |
5 | import torch
6 | import torch.nn as nn
7 |
8 | from models.backbones.resnet import resnet50
9 | from models.modules.mda import MDA
10 | from losses.cm_retrieval_loss import CMRetrievalLoss
11 | from utils.calc_acc import calc_acc
12 |
13 |
14 | class MCJA(nn.Module):
15 | """
16 | The class of the Multi-level Cross-modality Joint Alignment (MCJA) model, designed for cross-modality person re-identification tasks.
17 | This model integrates various components, including a backbone for feature extraction, the Modality Distribution
18 | Adapter (MDA) for better cross-modality feature alignment & distribution adaptation, a neck for feature embedding,
19 | a head for classification, and specialized loss functions (identity and cross-modality retrieval (CMR) losses).
20 |
21 | Args:
22 | - num_classes (int): The number of identity classes in the dataset.
23 | - drop_last_stride (bool): A flag to determine whether the last stride in the backbone should be dropped.
24 | - mda_ratio (int): The ratio for reducing the channel dimensions in MDA layers.
25 | - mda_m (int): The number of modalities considered by the MDA layers.
26 | - loss_id (bool): Whether to use the identity loss during training.
27 | - loss_cmr (bool): Whether to use the cross-modality retrieval loss during training.
28 |
29 | Methods:
30 | - forward(inputs, labels=None, **kwargs): Processes the input through the MCJA model.
31 | In training mode, it computes the loss and metrics based on the provided labels and additional information (e.g.,
32 | camera IDs for modality labels). In evaluation mode, it returns the feature embeddings after BN neck processing.
33 | - train_forward(embeddings, labels, **kwargs): A helper function called during training to compute losses.
34 | It calculates the identity and CMR losses based on embeddings, identity labels, and modality labels.
35 | """
36 |
37 | def __init__(self, num_classes, drop_last_stride=False, mda_ratio=2, mda_m=2, loss_id=True, loss_cmr=True):
38 | super(MCJA, self).__init__()
39 |
40 | # Backbone -----------------------------------------------------------------------------------------------------
41 | self.backbone = resnet50(pretrained=True, drop_last_stride=drop_last_stride)
42 | self.base_dim = 2048
43 |
44 | # Neck ---------------------------------------------------------------------------------------------------------
45 | self.bn_neck = nn.BatchNorm1d(self.base_dim)
46 | nn.init.constant_(self.bn_neck.bias, 0)
47 | self.bn_neck.bias.requires_grad_(False)
48 |
49 | # Head ---------------------------------------------------------------------------------------------------------
50 | self.classifier = nn.Linear(self.base_dim, num_classes, bias=False)
51 |
52 | # Loss ---------------------------------------------------------------------------------------------------------
53 | self.id_loss = nn.CrossEntropyLoss() if loss_id else None
54 | ###### Cross-Modality Retrieval Loss (CMR) ######
55 | self.cmr_loss = CMRetrievalLoss() if loss_cmr else None
56 |
57 | # Module -------------------------------------------------------------------------------------------------------
58 | layers = [3, 4, 6, 3] # Just for ResNet50
59 | ###### Modality Distribution Adapter (MDA) ######
60 | mda_layers = [0, 2, 3, 0]
61 | self.MDA_1 = nn.ModuleList(
62 | [MDA(in_channels=256, inter_ratio=mda_ratio, m_num=mda_m) for _ in range(mda_layers[0])])
63 | self.MDA_1_idx = sorted([layers[0] - (i + 1) for i in range(mda_layers[0])])
64 | self.MDA_2 = nn.ModuleList(
65 | [MDA(in_channels=512, inter_ratio=mda_ratio, m_num=mda_m) for _ in range(mda_layers[1])])
66 | self.MDA_2_idx = sorted([layers[1] - (i + 1) for i in range(mda_layers[1])])
67 | self.MDA_3 = nn.ModuleList(
68 | [MDA(in_channels=1024, inter_ratio=mda_ratio, m_num=mda_m) for _ in range(mda_layers[2])])
69 | self.MDA_3_idx = sorted([layers[2] - (i + 1) for i in range(mda_layers[2])])
70 | self.MDA_4 = nn.ModuleList(
71 | [MDA(in_channels=2048, inter_ratio=mda_ratio, m_num=mda_m) for _ in range(mda_layers[3])])
72 | self.MDA_4_idx = sorted([layers[3] - (i + 1) for i in range(mda_layers[3])])
73 |
74 | def forward(self, inputs, labels=None, **kwargs):
75 |
76 | # Feature Extraction -------------------------------------------------------------------------------------------
77 | feats = self.backbone.conv1(inputs)
78 | feats = self.backbone.bn1(feats)
79 | feats = self.backbone.relu(feats)
80 | feats = self.backbone.maxpool(feats)
81 |
82 | MDA_1_counter = 0
83 | if len(self.MDA_1_idx) == 0: self.MDA_1_idx = [-1]
84 | for i in range(len(self.backbone.layer1)):
85 | feats = self.backbone.layer1[i](feats)
86 | if i == self.MDA_1_idx[MDA_1_counter]:
87 | _, C, H, W = feats.shape
88 | feats = self.MDA_1[MDA_1_counter](feats)
89 | MDA_1_counter += 1
90 | MDA_2_counter = 0
91 | if len(self.MDA_2_idx) == 0: self.MDA_2_idx = [-1]
92 | for i in range(len(self.backbone.layer2)):
93 | feats = self.backbone.layer2[i](feats)
94 | if i == self.MDA_2_idx[MDA_2_counter]:
95 | _, C, H, W = feats.shape
96 | feats = self.MDA_2[MDA_2_counter](feats)
97 | MDA_2_counter += 1
98 | MDA_3_counter = 0
99 | if len(self.MDA_3_idx) == 0: self.MDA_3_idx = [-1]
100 | for i in range(len(self.backbone.layer3)):
101 | feats = self.backbone.layer3[i](feats)
102 | if i == self.MDA_3_idx[MDA_3_counter]:
103 | _, C, H, W = feats.shape
104 | feats = self.MDA_3[MDA_3_counter](feats)
105 | MDA_3_counter += 1
106 | MDA_4_counter = 0
107 | if len(self.MDA_4_idx) == 0: self.MDA_4_idx = [-1]
108 | for i in range(len(self.backbone.layer4)):
109 | feats = self.backbone.layer4[i](feats)
110 | if i == self.MDA_4_idx[MDA_4_counter]:
111 | _, C, H, W = feats.shape
112 | feats = self.MDA_4[MDA_4_counter](feats)
113 | MDA_4_counter += 1
114 | global_feats = feats
115 |
116 | # Feature Embedding --------------------------------------------------------------------------------------------
117 | b, c, h, w = global_feats.shape
118 | global_feats = global_feats.view(b, c, -1)
119 | p = 3.0
120 | embeddings = (torch.mean(global_feats ** p, dim=-1) + 1e-12) ** (1 / p) # GeMPooling
121 |
122 | # Train & Test Return ------------------------------------------------------------------------------------------
123 | if self.training:
124 | return self.train_forward(embeddings, labels, **kwargs)
125 | else:
126 | return self.bn_neck(embeddings)
127 |
128 | def train_forward(self, embeddings, labels, **kwargs):
129 | loss = 0
130 | metric = {}
131 |
132 | embeddings = self.bn_neck(embeddings)
133 |
134 | # modality labels
135 | cam_ids = kwargs.get('cam_ids')
136 | rgb_idx_mask = (cam_ids == 1) + (cam_ids == 2) + (cam_ids == 4) + (cam_ids == 5)
137 | ir_idx_mask = (cam_ids == 3) + (cam_ids == 6)
138 | m_labels = torch.ones((len(labels)))
139 | m_labels[rgb_idx_mask] = 0
140 | m_labels[ir_idx_mask] = 1
141 |
142 | if self.cmr_loss is not None:
143 | ###### Cross-Modality Retrieval Loss (CMR) ######
144 | cmr_loss = self.cmr_loss(embeddings.float(), id_labels=labels, m_labels=m_labels)
145 | loss += cmr_loss
146 | metric.update({'loss_cmr': cmr_loss.data})
147 |
148 | logits = self.classifier(embeddings)
149 |
150 | if self.id_loss is not None:
151 | # Identity Loss (ID Loss)
152 | id_loss = self.id_loss(logits.float(), labels)
153 | loss += id_loss
154 | metric.update({'cls_acc': calc_acc(logits.data, labels), 'loss_id': id_loss.data})
155 |
156 | return loss, metric
157 |
--------------------------------------------------------------------------------
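To make the two execution paths of `MCJA.forward` concrete, here is a hedged usage sketch. The toy tensors only illustrate the call signature: in practice the batch composition (each identity sampled from both modalities, SYSU-style camera IDs where 1/2/4/5 are visible and 3/6 are infrared) is controlled by the sampler in `data/sampler.py`, and constructing the model downloads ImageNet weights because the backbone is built with `pretrained=True`.

```python
# Sketch only: assumes the repository root is importable, network access for the
# pretrained backbone, and SYSU-MM01 camera numbering for the modality masks.
import torch
from models.mcja import MCJA

model = MCJA(num_classes=395, drop_last_stride=True)

images = torch.randn(8, 3, 384, 128)
labels = torch.tensor([0, 0, 1, 1, 2, 2, 3, 3])    # each identity appears in both modalities
cam_ids = torch.tensor([1, 3, 2, 6, 4, 3, 5, 6])   # 1/2/4/5 -> visible, 3/6 -> infrared

model.train()
loss, metric = model(images, labels=labels, cam_ids=cam_ids)  # scalar loss + metric dict

model.eval()
with torch.no_grad():
    embeddings = model(images)                      # [8, 2048] BN-neck features for retrieval
```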
/models/modules/__init__.py:
--------------------------------------------------------------------------------
1 | """MCJA/models/modules/__init__.py
2 | It is used to mark a directory as a Python package directory.
3 | """
--------------------------------------------------------------------------------
/models/modules/mda.py:
--------------------------------------------------------------------------------
1 | """MCJA/models/modules/mda.py
2 | It defines the Modality Distribution Adapter (MDA) class, a module designed to enhance cross-modal feature learning.
3 | """
4 |
5 | import torch
6 | import torch.nn as nn
7 |
8 |
9 | class MDA(nn.Module):
10 | """
11 | A module implementing the Modality Distribution Adapter (MDA), a mechanism designed to enhance cross-modal learning
12 | by adaptively re-weighting feature channels based on their relevance to different modalities. The MDA module
13 | dynamically adjusts the contribution of different feature channels to the task at hand, depending on the context
14 | provided by different modalities, thereby facilitating more effective integration of multimodal information.
15 |
16 | Args:
17 | - in_channels (int): Number of channels in the input feature map.
18 | - inter_ratio (int): Reduction ratio for intermediate channel dimensions, controlling the compactness of the module.
19 | - m_num (int): Number of distinct modalities that the model needs to adapt to.
20 |
21 | Methods:
22 | - forward(x): Processes the input tensor through the MDA to produce an output with adapted feature distributions.
23 | """
24 |
25 | def __init__(self, in_channels, inter_ratio=2, m_num=2):
26 | super(MDA, self).__init__()
27 | self.in_channels = in_channels
28 | self.inter_ratio = inter_ratio
29 | self.planes = in_channels // inter_ratio
30 | self.m_num = m_num
31 |
32 | self.sc_conv = nn.Conv2d(in_channels, m_num, kernel_size=1)
33 | self.softmax = nn.Softmax(dim=-1)
34 |
35 | self.ca_conv = nn.Conv1d(self.in_channels, self.in_channels,
36 | kernel_size=m_num, groups=self.in_channels, bias=False)
37 | self.ca_bn = nn.BatchNorm1d(self.in_channels)
38 |
39 | self.norm_bn = nn.BatchNorm2d(self.in_channels)
40 | nn.init.constant_(self.norm_bn.weight, 0.0)
41 | nn.init.constant_(self.norm_bn.bias, 0.0)
42 |
43 | def forward(self, x):
44 | input_x = x
45 | batch, channel, height, width = x.size()
46 |
47 | # Spatial Characteristics Learning -----------------------------------------------------------------------------
48 | # [B, C, H, W] -> [B, C, H * W]
49 | input_x = input_x.view(batch, channel, height * width)
50 | # [B, C, H * W] -> [B, 1, C, H * W]
51 | input_x = input_x.unsqueeze(1)
52 | # [B, C, H, W] -> [B, M, H, W]
53 | context_mask = self.sc_conv(x)
54 | # [B, M, H, W] -> [B, M, H * W]
55 | context_mask = context_mask.view(batch, self.m_num, height * width)
56 | # [B, M, H * W] -> [B, M, H * W]
57 | context_mask = self.softmax(context_mask)
58 | # [B, M, H * W] -> [B, M, H * W, 1]
59 | context_mask = context_mask.unsqueeze(-1)
60 | # [B, 1, C, H * W] [B, M, H * W, 1] -> [B, M, C, 1]
61 | context = torch.matmul(input_x, context_mask)
62 | # [B, M, C, 1] -> [B, C, M]
63 | context = context.squeeze(-1).permute(0, 2, 1)
64 |
65 | # Characteristics Aggregation ----------------------------------------------------------------------------------
66 | # [B, C, M] -> [B, C, 1]
67 | z = self.ca_conv(context)
68 | # [B, C, 1] -> [B, C, 1]
69 | z = self.ca_bn(z)
70 | # [B, C, 1] -> [B, C, 1]
71 | g = torch.sigmoid(z)
72 |
73 | # Feature Distribution Adaption --------------------------------------------------------------------------------
74 | # [B, C, 1] -> [B, C, 1, 1]
75 | g = g.view(batch, channel, 1, 1)
76 | # [B, C, H, W] [B, C, 1, 1] -> [B, C, H, W]
77 | out = self.norm_bn(x * g.expand_as(x)) + x
78 |
79 | return out
80 |
--------------------------------------------------------------------------------
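Since MDA only re-weights channels and adds a normalized residual, its output always matches the input shape, so it can be inserted after any residual block. A quick shape check (a sketch; the channel count matches ResNet-50's layer2):

```python
# Sketch: verify that MDA preserves the feature-map shape.
import torch
from models.modules.mda import MDA

mda = MDA(in_channels=512, inter_ratio=2, m_num=2)
x = torch.randn(4, 512, 48, 16)
y = mda(x)
assert y.shape == x.shape  # [4, 512, 48, 16]
```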
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | """MCJA/utils/__init__.py
2 | It is used to mark a directory as a Python package directory.
3 | """
--------------------------------------------------------------------------------
/utils/calc_acc.py:
--------------------------------------------------------------------------------
1 | """MCJA/utils/calc_acc.py
2 | This utility file defines a function `calc_acc` for calculating classification accuracy.
3 | """
4 |
5 | import torch
6 |
7 |
8 | def calc_acc(logits, label, ignore_index=-100, mode="multiclass"):
9 | """
10 | A utility function for calculating the accuracy of model predictions given the logits and corresponding labels.
11 | It supports both binary and multiclass classification tasks by interpreting the logits according to the specified
12 | mode and comparing them against the ground truth labels to determine the number of correct predictions.
13 | The function also accommodates scenarios where certain examples should be ignored in the accuracy calculation,
14 | based on a designated ignore_index or the structure of the labels.
15 |
16 | Args:
17 | - logits (Tensor): The output logits from a model. For binary classification, logits should be a 1D tensor of
18 | probabilities. For multiclass classification, logits should be a 2D tensor with shape [batch_size, num_classes].
19 | - label (Tensor): The ground truth labels for the predictions. For multiclass classification, labels should be a
20 | 1D tensor of class indices. For binary classification, labels should be a tensor with the same shape as logits.
21 | - ignore_index (int, optional): Specifies a label value that should be ignored when calculating accuracy.
22 | Examples with this label are not considered in the denominator of the accuracy calculation. Default is -100.
23 | - mode (str, optional): Determines how logits are interpreted. Can be "binary" for binary classification tasks,
24 | where logits are rounded to 0 or 1, or "multiclass" for tasks with more than two classes, where the class with
25 | the highest logit is selected. Default is "multiclass".
26 |
27 | Returns:
28 | - Tensor: The calculated accuracy as a float value, representing the proportion of correct predictions out
29 | of the total number of examples considered (excluding ignored examples).
30 | """
31 |
32 | if mode == "binary":
33 | indices = torch.round(logits).type(label.type())
34 | elif mode == "multiclass":
35 | indices = torch.max(logits, dim=1)[1]
36 |
37 | if label.size() == logits.size():
38 | ignore = 1 - torch.round(label.sum(dim=1))
39 | label = torch.max(label, dim=1)[1]
40 | else:
41 | ignore = torch.eq(label, ignore_index).view(-1)
42 |
43 | correct = torch.eq(indices, label).view(-1)
44 | num_correct = torch.sum(correct)
45 | num_examples = logits.shape[0] - ignore.sum()
46 |
47 | return num_correct.float() / num_examples.float()
48 |
--------------------------------------------------------------------------------
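A tiny worked example of `calc_acc` in the default multiclass mode (values chosen purely for illustration):

```python
# Sketch: two of the three predictions match the labels, so accuracy is 2/3.
import torch
from utils.calc_acc import calc_acc

logits = torch.tensor([[2.0, 0.1, 0.3],
                       [0.2, 1.5, 0.1],
                       [0.1, 0.2, 3.0]])
labels = torch.tensor([0, 2, 2])
print(calc_acc(logits, labels))  # tensor(0.6667)
```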
/utils/eval_data.py:
--------------------------------------------------------------------------------
1 | """MCJA/utils/eval_data.py
2 | It provides evaluation utilities for assessing the performance of cross-modal person re-identification methods.
3 | """
4 |
5 | import os
6 | import logging
7 | import numpy as np
8 | import torch
9 | from torch.nn import functional as F
10 | from .mser_rerank import mser_rerank, pairwise_distance
11 |
12 |
13 | def get_gallery_names_sysu(perm, cams, ids, trial_id, num_shots=1):
14 | """
15 | A utility function designed specifically for constructing a list of gallery image file paths for the SYSU-MM01
16 | dataset in a cross-modality person re-identification task. This function takes a permutation array that organizes
17 | the data into camera views and identities, and then selects a specified number of shots (instances) for each
18 | identity from the desired cameras to compile the gallery set for a given trial. The output is a list of formatted
19 | strings that represent the file paths to the selected gallery images, adhering to the dataset's directory structure.
20 |
21 | Args:
22 | - perm (list): A nested list where each element corresponds to a camera view in the dataset.
23 | Each camera's list contains arrays of instance indices for each identity, organized by trials.
24 | - cams (list): A list of integers indicating which camera views to include in the gallery set.
25 | Camera numbers should match those used in the dataset.
26 | - ids (list): A list of integers specifying the identities to be included in the gallery set.
27 | Identity numbers should correspond to those in the dataset.
28 | - trial_id (int): The index of the trial for which to construct the gallery.
29 | This index is used to select specific instances from the permutation arrays,
30 | allowing for variability across different evaluation runs.
31 | - num_shots (int, optional): The number of shots to select for each identity from each camera view.
32 |
33 | Returns:
34 | - list: A list of strings, each representing the file path to an image selected for the gallery.
35 | Paths are formatted to match the directory structure of the SYSU-MM01 dataset.
36 | """
37 |
38 | names = []
39 | for cam in cams:
40 | cam_perm = perm[cam - 1][0].squeeze()
41 | for i in ids:
42 | if (i - 1) < len(cam_perm) and len(cam_perm[i - 1]) > 0:
43 | instance_id = cam_perm[i - 1][trial_id][:num_shots]
44 | names.extend(['cam{}/{:0>4d}/{:0>4d}'.format(cam, i, ins) for ins in instance_id.tolist()])
45 | return names
46 |
47 |
48 | def get_unique(array):
49 | """
50 | A utility function that returns the unique elements of an array in order of first occurrence. It identifies all unique
51 | elements within the input array and selects their first occurrence, preserving the order of these unique elements
52 | based on their initial appearance in the input array. This function is particularly useful for processing arrays
53 | where the uniqueness and order of elements are essential, such as when filtering duplicate entries from lists of
54 | identifiers or categories without disrupting their original sequence.
55 |
56 | Args:
57 | - array (ndarray): An input array from which unique elements are to be extracted.
58 |
59 | Returns:
60 | - ndarray: A new array containing only the unique elements of the input array,
61 | ordered by their first occurrence in the original array.
62 | """
63 |
64 | _, idx = np.unique(array, return_index=True)
65 | array_new = array[np.sort(idx)]
66 | return array_new
67 |
68 |
69 | def get_cmc(sorted_indices, query_ids, query_cam_ids, gallery_ids, gallery_cam_ids):
70 | """
71 | A utility function for calculating the Cumulative Matching Characteristics (CMC) curve in person re-identification
72 | tasks. The CMC curve is a standard evaluation metric used to measure the performance of a re-identification model
73 | by determining the probability that a query identity appears in different sized candidate lists. This function
74 | processes the sorted indices of gallery samples for each query, excludes gallery samples captured by the same
75 | camera as the query to avoid camera bias, and computes the CMC curve based on the first correct match's position.
76 |
77 | Args:
78 | - sorted_indices (ndarray): An array of indices that sorts the gallery samples
79 | in ascending order of their distance to each query sample.
80 | - query_ids (ndarray): An array containing the identity labels of the query samples.
81 | - query_cam_ids (ndarray): An array containing the camera IDs associated with each query sample.
82 | - gallery_ids (ndarray): An array containing the identity labels of the gallery samples.
83 | - gallery_cam_ids (ndarray): An array containing the camera IDs associated with each gallery sample.
84 |
85 | Returns:
86 | - ndarray: The CMC curve represented as a 1D array where each element at index i indicates the probability
87 | that a query identity is correctly matched within the top-(i+1) ranks of the sorted gallery list.
88 | """
89 |
90 | gallery_unique_count = get_unique(gallery_ids).shape[0]
91 | match_counter = np.zeros((gallery_unique_count,))
92 |
93 | result = gallery_ids[sorted_indices]
94 | cam_locations_result = gallery_cam_ids[sorted_indices]
95 |
96 | valid_probe_sample_count = 0
97 |
98 | for probe_index in range(sorted_indices.shape[0]):
99 | # remove gallery samples from the same camera of the probe
100 | result_i = result[probe_index, :]
101 | result_i[np.equal(cam_locations_result[probe_index], query_cam_ids[probe_index])] = -1
102 |
103 | # remove the -1 entries from the label result
104 | result_i = np.array([i for i in result_i if i != -1])
105 |
106 | # remove duplicated id in "stable" manner - following the official test protocol in VI-ReID
107 | result_i_unique = get_unique(result_i)
108 |
109 | # match for probe i
110 | match_i = np.equal(result_i_unique, query_ids[probe_index])
111 |
112 | if np.sum(match_i) != 0: # if there is true matching in gallery
113 | valid_probe_sample_count += 1
114 | match_counter += match_i
115 |
116 | rank = match_counter / valid_probe_sample_count
117 | cmc = np.cumsum(rank)
118 | return cmc
119 |
120 |
121 | def get_mAP(sorted_indices, query_ids, query_cam_ids, gallery_ids, gallery_cam_ids):
122 | """
123 | A utility function for computing the mean Average Precision (mAP) for evaluating person re-identification methods.
124 | The mAP metric provides a single-figure measure of quality across recall levels, particularly useful in scenarios
125 | where the query identity appears multiple times in the gallery. This function iterates over each query, excludes
126 | gallery images captured by the same camera to prevent bias, and calculates the Average Precision (AP) for each
127 | query based on its matches in the gallery. The mAP is then obtained by averaging the APs across all queries.
128 |
129 | Args:
130 | - sorted_indices (ndarray): An array of indices that sorts the gallery samples
131 | in ascending order of their distance to each query sample.
132 | - query_ids (ndarray): An array containing the identity labels of the query samples.
133 | - query_cam_ids (ndarray): An array containing the camera IDs associated with each query sample.
134 | - gallery_ids (ndarray): An array containing the identity labels of the gallery samples.
135 | - gallery_cam_ids (ndarray): An array containing the camera IDs associated with each gallery sample.
136 |
137 | Returns:
138 | - float: The mean Average Precision (mAP) calculated across all query samples,
139 | representing the overall precision of the re-identification method at varying levels of recall.
140 | """
141 |
142 | result = gallery_ids[sorted_indices]
143 | cam_locations_result = gallery_cam_ids[sorted_indices]
144 |
145 | valid_probe_sample_count = 0
146 | avg_precision_sum = 0
147 |
148 | for probe_index in range(sorted_indices.shape[0]):
149 | # remove gallery samples from the same camera of the probe
150 | result_i = result[probe_index, :]
151 | result_i[cam_locations_result[probe_index, :] == query_cam_ids[probe_index]] = -1
152 |
153 | # remove the -1 entries from the label result
154 | result_i = np.array([i for i in result_i if i != -1])
155 |
156 | # match for probe i
157 | match_i = result_i == query_ids[probe_index]
158 | true_match_count = np.sum(match_i)
159 |
160 | if true_match_count != 0: # if there is true matching in gallery
161 | valid_probe_sample_count += 1
162 | true_match_rank = np.where(match_i)[0]
163 |
164 | ap = np.mean(np.arange(1, true_match_count + 1) / (true_match_rank + 1))
165 | avg_precision_sum += ap
166 |
167 | mAP = avg_precision_sum / valid_probe_sample_count
168 | return mAP
169 |
170 |
171 | def get_mINP(sorted_indices, query_ids, query_cam_ids, gallery_ids, gallery_cam_ids):
172 | """
173 | A utility function designed for evaluating the mean Inverse Negative Penalty (mINP) across all query samples in
174 | a person re-identification method. The mINP metric focuses on the hardest positive sample's position in the ranked
175 | list of gallery samples for each query, providing insight into the system's ability to recall all relevant instances
176 | of an identity. This function computes the INP for each query by excluding gallery samples captured by the same
177 | camera as the query, identifying the rank position of the farthest correct match, and calculating the INP based on
178 | this position. The mean INP is then derived by averaging the INP scores across all valid queries.
179 |
180 | Args:
181 | - sorted_indices (ndarray): An array of indices that sorts the gallery samples
182 | in ascending order of their distance to each query sample.
183 | - query_ids (ndarray): An array containing the identity labels of the query samples.
184 | - query_cam_ids (ndarray): An array containing the camera IDs associated with each query sample.
185 | - gallery_ids (ndarray): An array containing the identity labels of the gallery samples.
186 | - gallery_cam_ids (ndarray): An array containing the camera IDs associated with each gallery sample.
187 |
188 | Returns:
189 | - float: The mean Inverse Negative Penalty (mINP) calculated across all queries, reflecting the method's
190 | effectiveness in retrieving relevant matches from the gallery, particularly the hardest ones to identify.
191 | """
192 |
193 | result = gallery_ids[sorted_indices]
194 | cam_locations_result = gallery_cam_ids[sorted_indices]
195 |
196 | valid_probe_sample_count = 0
197 | INP_sum = 0
198 |
199 | for probe_index in range(sorted_indices.shape[0]):
200 | # remove gallery samples from the same camera of the probe
201 | result_i = result[probe_index, :]
202 | result_i[np.equal(cam_locations_result[probe_index], query_cam_ids[probe_index])] = -1
203 |
204 | # remove the -1 entries from the label result
205 | result_i = np.array([i for i in result_i if i != -1])
206 |
207 | # match for probe i
208 | match_i = result_i == query_ids[probe_index]
209 | true_match_count = np.sum(match_i)
210 |
211 | if true_match_count != 0: # if there is true matching in gallery
212 | valid_probe_sample_count += 1
213 | true_match_rank = np.where(match_i)[0]
214 | hardest_match_pos = true_match_rank[-1] + 1
215 |
216 | NP = (hardest_match_pos - true_match_count) / hardest_match_pos
217 | INP = 1 - NP
218 |
219 | INP_sum += INP
220 |
221 | mINP = INP_sum / valid_probe_sample_count
222 |
223 | return mINP
224 |
225 |
226 | def get_cmc_mAP_mINP(sorted_indices, query_ids, query_cam_ids, gallery_ids, gallery_cam_ids):
227 | """
228 | A comprehensive utility function designed to compute three key metrics simultaneously for evaluating the performance
229 | of person re-identification methods: Cumulative Matching Characteristics (CMC), mean Average Precision (mAP), and
230 | mean Inverse Negative Penalty (mINP). This function integrates the processes of calculating these metrics into a
231 | single operation, optimizing the evaluation workflow for the person re-identification task.
232 |
233 | Args:
234 | - sorted_indices (ndarray): Indices sorting the gallery samples by ascending similarity to each query sample.
235 | - query_ids (ndarray): Identity labels for the query samples.
236 | - query_cam_ids (ndarray): Camera IDs from which each query sample was captured.
237 | - gallery_ids (ndarray): Identity labels for the gallery samples.
238 | - gallery_cam_ids (ndarray): Camera IDs from which each gallery sample was captured.
239 |
240 | Returns:
241 | - tuple: A tuple containing three elements:
242 | - cmc (ndarray): The CMC curve as a 1D array.
243 | - mAP (float): The mean Average Precision score.
244 | - mINP (float): The mean Inverse Negative Penalty score.
245 | """
246 |
247 | gallery_unique_count = get_unique(gallery_ids).shape[0]
248 | match_counter = np.zeros((gallery_unique_count,))
249 |
250 | result = gallery_ids[sorted_indices]
251 | cam_locations_result = gallery_cam_ids[sorted_indices]
252 |
253 | valid_probe_sample_count = 0
254 | avg_precision_sum = 0
255 | INP_sum = 0
256 |
257 | for probe_index in range(sorted_indices.shape[0]):
258 | # remove gallery samples from the same camera of the probe
259 | result_i = result[probe_index, :]
260 | result_i[np.equal(cam_locations_result[probe_index], query_cam_ids[probe_index])] = -1
261 |
262 | # remove the -1 entries from the label result
263 | result_i = np.array([i for i in result_i if i != -1])
264 |
265 | # remove duplicated id in "stable" manner - following the official test protocol in VI-ReID
266 | result_i_unique = get_unique(result_i)
267 |
268 | # match for probe i
269 | match_i = result_i == query_ids[probe_index]
270 | true_match_count = np.sum(match_i)
271 | match_i_unique = np.equal(result_i_unique, query_ids[probe_index])
272 |
273 | if true_match_count != 0: # if there is true matching in gallery
274 | valid_probe_sample_count += 1
275 | if match_counter.shape != match_i_unique.shape:
276 | sub_num = match_counter.shape[0] - match_i_unique.shape[0]
277 | match_i_unique = np.hstack([match_i_unique, [False] * sub_num])
278 | match_counter += match_i_unique
279 | true_match_rank = np.where(match_i)[0]
280 | ap = np.mean(np.arange(1, true_match_count + 1) / (true_match_rank + 1))
281 | avg_precision_sum += ap
282 | hardest_match_pos = true_match_rank[-1] + 1
283 | NP = (hardest_match_pos - true_match_count) / hardest_match_pos
284 | INP = 1 - NP
285 | INP_sum += INP
286 |
287 | rank = match_counter / valid_probe_sample_count
288 | cmc = np.cumsum(rank)
289 | mAP = avg_precision_sum / valid_probe_sample_count
290 | mINP = INP_sum / valid_probe_sample_count
291 | return cmc, mAP, mINP
292 |
293 |
294 | def eval_sysu(query_feats, query_ids, query_cam_ids, query_img_paths,
295 | gallery_feats, gallery_ids, gallery_cam_ids, gallery_img_paths,
296 | perm, mode='all', num_shots=1, num_trials=10, mser=False):
297 | """
298 | A versatile function designed for evaluating the performance of VI-ReID models on the SYSU-MM01 dataset,
299 | offering the flexibility to apply either a basic evaluation strategy or the Multi-Spectral Enhanced Ranking (MSER)
300 | strategy based on the specified parameters. This function orchestrates the evaluation process across multiple
301 | trials, adjusting for different experimental settings such as the evaluation mode and the number of gallery shots.
302 |
303 | Args:
304 | - query_feats (Tensor or List[Tensor]): Feature vectors of query images.
305 | - query_ids (Tensor or List[Tensor]): Identity labels associated with query images.
306 | - query_cam_ids (Tensor or List[Tensor]): Camera IDs from which each query image was captured.
307 | - query_img_paths (ndarray or List[ndarray]): File paths of query images.
308 | - gallery_feats (Tensor or List[Tensor]): Feature vectors of gallery images.
309 | - gallery_ids (Tensor or List[Tensor]): Identity labels associated with gallery images.
310 | - gallery_cam_ids (Tensor or List[Tensor]): Camera IDs from which each gallery image was captured.
311 | - gallery_img_paths (ndarray or List[ndarray]): File paths of gallery images.
312 | - perm (ndarray): A permutation array for determining gallery subsets in each trial.
313 | - mode (str): Specifies the subset of gallery images to use.
314 | Options include 'indoor' for indoor cameras only and 'all' for all cameras.
315 | - num_shots (int): Number of instances of each identity to include in the gallery for each trial.
316 | - num_trials (int): Number of trials to perform, with each trial potentially using a different subset of the gallery.
317 | - mser (bool): A flag indicating whether to use the MSER strategy for re-ranking gallery images.
318 | If set to False, a basic evaluation strategy is used.
319 |
320 | Returns:
321 | - tuple: A tuple containing average metrics of rank-1, rank-5, rank-10, rank-20 precision, mAP, and mINP
322 | across all trials, along with detailed rank results for further analysis.
323 | """
324 |
325 | if mser:
326 | return eval_sysu_mser(query_feats, query_ids, query_cam_ids, query_img_paths,
327 | gallery_feats, gallery_ids, gallery_cam_ids, gallery_img_paths,
328 | perm, mode, num_shots, num_trials)
329 | return eval_sysu_base(query_feats, query_ids, query_cam_ids, query_img_paths,
330 | gallery_feats, gallery_ids, gallery_cam_ids, gallery_img_paths,
331 | perm, mode, num_shots, num_trials)
332 |
333 |
334 | def eval_sysu_base(query_feats, query_ids, query_cam_ids, query_img_paths,
335 | gallery_feats, gallery_ids, gallery_cam_ids, gallery_img_paths,
336 | perm, mode='all', num_shots=1, num_trials=10):
337 | """
338 | A function designed for evaluating the performance of a VI-ReID model on the SYSU-MM01 dataset under specific
339 | experimental settings. This function conducts evaluations across multiple trials, each trial potentially utilizing
340 | a different subset of gallery images based on the specified mode (indoor or all locations) and the number of shots.
341 | It computes re-identification metrics including rank-1, rank-5, rank-10, rank-20 precision, mean Average Precision
342 | (mAP), and mean Inverse Negative Penalty (mINP) across all trials, averaging the results to provide a comprehensive
343 | assessment of the model's performance.
344 |
345 | Args:
346 | - query_feats (Tensor): The feature representations of query images.
347 | - query_ids (Tensor): The identity labels associated with query images.
348 | - query_cam_ids (Tensor): The camera IDs from which each query image was captured.
349 | - query_img_paths (ndarray): The file paths of query images.
350 | - gallery_feats (Tensor): The feature representations of gallery images.
351 | - gallery_ids (Tensor): The identity labels associated with gallery images.
352 | - gallery_cam_ids (Tensor): The camera IDs from which each gallery image was captured.
353 | - gallery_img_paths (ndarray): The file paths of gallery images.
354 | - perm (ndarray): A permutation array used for determining gallery subsets in each trial.
355 | - mode (str): Specifies subset of gallery images to use ('indoor' for indoor cameras only, 'all' for all cameras).
356 | - num_shots (int): The number of instances per identity (per camera) to include in the gallery for each trial.
357 | - num_trials (int): The number of trials to perform, with each trial potentially using a different gallery subset.
358 |
359 | Returns:
360 | - tuple: A tuple containing the average values of rank-1, rank-5, rank-10, rank-20 precision, mAP, and mINP
361 | across all trials, along with detailed rank results for each trial.
362 | """
363 |
364 | assert mode in ['indoor', 'all']
365 |
366 | gallery_cams = [1, 2] if mode == 'indoor' else [1, 2, 4, 5]
367 |
368 | # cam2 and cam3 are in the same location
369 | query_cam_ids[np.equal(query_cam_ids, 3)] = 2
370 | query_feats = F.normalize(query_feats, dim=1)
371 |
372 | gallery_indices = np.in1d(gallery_cam_ids, gallery_cams)
373 |
374 | gallery_feats = gallery_feats[gallery_indices]
375 | gallery_feats = F.normalize(gallery_feats, dim=1)
376 | gallery_ids = gallery_ids[gallery_indices]
377 | gallery_cam_ids = gallery_cam_ids[gallery_indices]
378 | gallery_img_paths = gallery_img_paths[gallery_indices]
379 | gallery_names = np.array(['/'.join(os.path.splitext(path)[0].split('/')[-3:]) for path in gallery_img_paths])
380 |
381 | gallery_id_set = np.unique(gallery_ids)
382 |
383 | r1, r5, r10, r20, mAP, mINP = 0, 0, 0, 0, 0, 0
384 | rank_results = []
385 | for t in range(num_trials):
386 | names = get_gallery_names_sysu(perm, gallery_cams, gallery_id_set, t, num_shots)
387 | flag = np.in1d(gallery_names, names)
388 |
389 | g_feats = gallery_feats[flag]
390 | g_ids = gallery_ids[flag]
391 | g_cam_ids = gallery_cam_ids[flag]
392 | g_img_paths = gallery_img_paths[flag]
393 |
394 | # dist_mat = pairwise_distance(query_feats, g_feats) # A
395 | dist_mat = -torch.mm(query_feats, g_feats.permute(1, 0)) # B
396 | # When using normalization on extracted features, these two distance measures are equivalent (A = 2 + 2 * B)
397 | # B is a little faster than A
398 | sorted_indices = np.argsort(dist_mat, axis=1)
399 |
400 | cur_rank_results = dict()
401 | cur_rank_results['query_ids'] = query_ids
402 | cur_rank_results['query_img_paths'] = query_img_paths
403 | cur_rank_results['gallery_ids'] = g_ids
404 | cur_rank_results['gallery_img_paths'] = g_img_paths
405 | cur_rank_results['dist_mat'] = dist_mat
406 | cur_rank_results['sorted_indices'] = sorted_indices
407 | rank_results.append(cur_rank_results)
408 |
409 | # cur_cmc = get_cmc(sorted_indices, query_ids, query_cam_ids, g_ids, g_cam_ids)
410 | # cur_mAP = get_mAP(sorted_indices, query_ids, query_cam_ids, g_ids, g_cam_ids)
411 | # cur_mINP = get_mINP(sorted_indices, query_ids, query_cam_ids, g_ids, g_cam_ids)
412 | cur_cmc, cur_mAP, cur_mINP = get_cmc_mAP_mINP(sorted_indices, query_ids, query_cam_ids, g_ids, g_cam_ids)
413 | r1 += cur_cmc[0]
414 | r5 += cur_cmc[4]
415 | r10 += cur_cmc[9]
416 | r20 += cur_cmc[19]
417 | mAP += cur_mAP
418 | mINP += cur_mINP
419 |
420 | r1 = r1 / num_trials * 100
421 | r5 = r5 / num_trials * 100
422 | r10 = r10 / num_trials * 100
423 | r20 = r20 / num_trials * 100
424 | mAP = mAP / num_trials * 100
425 | mINP = mINP / num_trials * 100
426 |
427 | logger = logging.getLogger('MCJA')
428 | logger.info('-' * 150)
429 | perf = '{} num-shot:{} r1 precision = {:.2f} , r10 precision = {:.2f} , r20 precision = {:.2f} , ' \
430 | 'mAP = {:.2f} , mINP = {:.2f}'
431 | logger.info(perf.format(mode, num_shots, r1, r10, r20, mAP, mINP))
432 | logger.info('-' * 150)
433 |
434 | return r1, r5, r10, r20, mAP, mINP, rank_results
435 |
436 |
437 | def eval_sysu_mser(query_feats_list, query_ids_list, query_cam_ids_list, query_img_paths_list,
438 | gallery_feats_list, gallery_ids_list, gallery_cam_ids_list, gallery_img_paths_list,
439 | perm, mode='all', num_shots=1, num_trials=10):
440 | """
441 | A function designed to evaluate the performance of a VI-ReID model using the Multi-Spectral Enhanced Ranking (MSER)
442 | strategy on the SYSU-MM01 dataset. The MSER strategy involves a novel re-ranking process that enhances the initial
443 | ranking of gallery images based on their similarity to query images, considering multiple spectral representations.
444 | This evaluation function also supports variable experimental settings, such as different evaluation modes and the
445 | number of shots, across multiple trials for a comprehensive performance assessment.
446 |
447 | Args:
448 | - query_feats_list (List[Tensor]): List of tensors representing features of query images.
449 | - query_ids_list (List[Tensor]): List of tensors containing the identity labels of the query images.
450 | - query_cam_ids_list (List[Tensor]): List of tensors with camera IDs from which each query image was captured.
451 | - query_img_paths_list (List[ndarray]): List of ndarrays holding the file paths for each query image.
452 | - gallery_feats_list (List[Tensor]): List of tensors representing the feature vectors of gallery images.
453 | - gallery_ids_list (List[Tensor]): List of tensors containing the identity labels of the gallery images.
454 | - gallery_cam_ids_list (List[Tensor]): List of tensors with camera IDs from which each gallery image was captured.
455 | - gallery_img_paths_list (List[ndarray]): List of ndarrays holding the file paths for each gallery image.
456 | - perm (ndarray): A permutation array for determining the subsets of gallery images used in each trial.
457 | - mode (str): Specifies the subset of gallery images to use for evaluation.
458 | Options include 'indoor' for indoor cameras only and 'all' for all cameras.
459 | - num_shots (int): Specifies the number of instances of each identity to include in the gallery set for each trial.
460 | - num_trials (int): The number of trials to perform, with each trial using a different subset of gallery images.
461 |
462 | Returns:
463 | - tuple: A tuple containing the average metrics of rank-1, rank-5, rank-10, rank-20 precision, mean Average
464 | Precision (mAP), and mean Inverse Negative Penalty (mINP) across all trials. Additionally, detailed rank
465 | results for each trial are provided for further analysis.
466 | """
467 |
468 | assert mode in ['indoor', 'all']
469 |
470 | gallery_cams = [1, 2] if mode == 'indoor' else [1, 2, 4, 5]
471 |
472 | list_num = len(query_feats_list)
473 |
474 | for c in range(list_num):
475 | # cam2 and cam3 are in the same location
476 | query_cam_ids_list[c][np.equal(query_cam_ids_list[c], 3)] = 2
477 | query_feats_list[c] = F.normalize(query_feats_list[c], dim=1)
478 |
479 | gallery_indices = np.in1d(gallery_cam_ids_list[c], gallery_cams)
480 |
481 | gallery_feats_list[c] = gallery_feats_list[c][gallery_indices]
482 | gallery_feats_list[c] = F.normalize(gallery_feats_list[c], dim=1)
483 | gallery_ids_list[c] = gallery_ids_list[c][gallery_indices]
484 | gallery_cam_ids_list[c] = gallery_cam_ids_list[c][gallery_indices]
485 | gallery_img_paths_list[c] = gallery_img_paths_list[c][gallery_indices]
486 |
487 | gallery_names = np.array(
488 | ['/'.join(os.path.splitext(path)[0].split('/')[-3:]) for path in gallery_img_paths_list[0]])
489 | gallery_id_set = np.unique(gallery_ids_list[0])
490 |
491 | r1, r5, r10, r20, mAP, mINP = 0, 0, 0, 0, 0, 0
492 | rank_results = []
493 | for t in range(num_trials):
494 | names = get_gallery_names_sysu(perm, gallery_cams, gallery_id_set, t, num_shots)
495 | flag = np.in1d(gallery_names, names)
496 |
497 | g_feats_list, g_ids_list, g_cam_ids_list, g_img_paths_list = [], [], [], []
498 | for c in range(list_num):
499 | g_feats = gallery_feats_list[c][flag]
500 | g_ids = gallery_ids_list[c][flag]
501 | g_cam_ids = gallery_cam_ids_list[c][flag]
502 | g_img_paths = gallery_img_paths_list[c][flag]
503 | g_feats_list.append(g_feats)
504 | g_ids_list.append(g_ids)
505 | g_cam_ids_list.append(g_cam_ids)
506 | g_img_paths_list.append(g_img_paths)
507 |
508 | dist_mat = mser_rerank(query_feats_list, g_feats_list,
509 | k1=40, k2=20, lambda_value=0.3, mode='i2v')
510 | sorted_indices = np.argsort(dist_mat, axis=1)
511 |
512 | cur_rank_results = dict()
513 | cur_rank_results['query_ids'] = query_ids_list[0]
514 | cur_rank_results['query_img_paths'] = query_img_paths_list[0]
515 | cur_rank_results['gallery_ids'] = g_ids_list[0]
516 | cur_rank_results['gallery_img_paths'] = g_img_paths_list[0]
517 | cur_rank_results['dist_mat'] = dist_mat
518 | cur_rank_results['sorted_indices'] = sorted_indices
519 | rank_results.append(cur_rank_results)
520 |
521 | # cur_cmc = get_cmc(sorted_indices, query_ids, query_cam_ids, g_ids, g_cam_ids)
522 | # cur_mAP = get_mAP(sorted_indices, query_ids, query_cam_ids, g_ids, g_cam_ids)
523 | # cur_mINP = get_mINP(sorted_indices, query_ids, query_cam_ids, g_ids, g_cam_ids)
524 | cur_cmc, cur_mAP, cur_mINP = get_cmc_mAP_mINP(sorted_indices,
525 | query_ids_list[0], query_cam_ids_list[0],
526 | g_ids_list[0], g_cam_ids_list[0])
527 | r1 += cur_cmc[0]
528 | r5 += cur_cmc[4]
529 | r10 += cur_cmc[9]
530 | r20 += cur_cmc[19]
531 | mAP += cur_mAP
532 | mINP += cur_mINP
533 |
534 | r1 = r1 / num_trials * 100
535 | r5 = r5 / num_trials * 100
536 | r10 = r10 / num_trials * 100
537 | r20 = r20 / num_trials * 100
538 | mAP = mAP / num_trials * 100
539 | mINP = mINP / num_trials * 100
540 |
541 | logger = logging.getLogger('MCJA')
542 | logger.info('-' * 150)
543 | perf = '[MSER] {} num-shot:{} r1 precision = {:.2f} , r10 precision = {:.2f} , r20 precision = {:.2f} , ' \
544 | 'mAP = {:.2f} , mINP = {:.2f}'
545 | logger.info(perf.format(mode, num_shots, r1, r10, r20, mAP, mINP))
546 | logger.info('-' * 150)
547 |
548 | return r1, r5, r10, r20, mAP, mINP, rank_results
549 |
550 |
551 | def eval_regdb(query_feats, query_ids, query_cam_ids, query_img_paths,
552 | gallery_feats, gallery_ids, gallery_cam_ids, gallery_img_paths, mode='i2v', mser=False):
553 | """
554 | A comprehensive function tailored for evaluating VI-ReID models on the RegDB dataset, which integrates both basic
555 | and advanced Multi-Spectral Enhanced Ranking (MSER) strategies for performance assessment. This evaluation mechanism
556 | is devised to accommodate the unique modality challenges presented by the RegDB dataset, specifically focusing on
557 | the thermal-to-visible (t2v or i2v) and visible-to-thermal (v2t or v2i) matching scenarios. By leveraging the
558 | flexibility to choose between a straightforward evaluation approach and the MSER re-ranking strategy, this function
559 | supports performance analysis under both matching directions.
560 |
561 | Args:
562 | - query_feats (Tensor): The feature representations of query images.
563 | - query_ids (Tensor): The identity labels associated with each query image.
564 | - query_cam_ids (Tensor): The camera IDs from which each query image was captured, indicating the source modality.
565 | - query_img_paths (ndarray): The file paths for query images, useful for detailed analysis and debugging.
566 | - gallery_feats (Tensor): The feature representations of gallery images.
567 | - gallery_ids (Tensor): The identity labels for gallery images.
568 | - gallery_cam_ids (Tensor): The camera IDs for gallery images, highlighting the target modality for matching.
569 | - gallery_img_paths (ndarray): The file paths for gallery images, enabling precise tracking of evaluated samples.
570 | - mode (str): Determines the direction of modality matching, either 'i2v' or 'v2i'.
571 | - mser (bool): Indicates whether the Multi-Spectral Enhanced Ranking strategy should be applied.
572 |
573 | Returns:
574 | - tuple: Delivers evaluation metrics including rank-1, rank-5, rank-10, rank-20 precision, mean Average Precision
575 | (mAP), and mean Inverse Negative Penalty (mINP), alongside detailed ranking results for in-depth analysis.
576 | """
577 |
578 | if mser:
579 | return eval_regdb_mser(query_feats, query_ids, query_cam_ids, query_img_paths,
580 | gallery_feats, gallery_ids, gallery_cam_ids, gallery_img_paths, mode)
581 | return eval_regdb_base(query_feats, query_ids, query_cam_ids, query_img_paths,
582 | gallery_feats, gallery_ids, gallery_cam_ids, gallery_img_paths, mode)
583 |
584 |
585 | def eval_regdb_base(query_feats, query_ids, query_cam_ids, query_img_paths,
586 | gallery_feats, gallery_ids, gallery_cam_ids, gallery_img_paths, mode='i2v'):
587 | """
588 | A function specifically developed for evaluating VI-ReID models on the RegDB dataset, focusing on the thermal-to-
589 | visible (t2v or i2v) or visible-to-thermal (v2t or v2i) matching scenarios. This evaluation function computes the
590 | similarity between query and gallery features using cosine similarity, applies normalization to feature vectors,
591 | and ranks the gallery images based on their similarity to the query set. The performance is quantified using
592 | re-identification metrics such as rank-1, rank-5, rank-10, rank-20 precision, mean Average Precision (mAP), and
593 | mean Inverse Negative Penalty (mINP), providing a detailed analysis of the model's performance.
594 |
595 | Args:
596 | - query_feats (Tensor or List[Tensor]): Feature vectors of query images.
597 | - query_ids (Tensor or List[Tensor]): Identity labels associated with query images.
598 | - query_cam_ids (Tensor or List[Tensor]): Camera IDs from which each query image was captured.
599 | - query_img_paths (ndarray or List[ndarray]): File paths of query images.
600 | - gallery_feats (Tensor or List[Tensor]): Feature vectors of gallery images.
601 | - gallery_ids (Tensor or List[Tensor]): Identity labels associated with gallery images.
602 | - gallery_cam_ids (Tensor or List[Tensor]): Camera IDs from which each gallery image was captured.
603 | - gallery_img_paths (ndarray or List[ndarray]): File paths of gallery images.
604 | - mode (str): The evaluation mode, either 'i2v' (infrared to visible) or 'v2i' (visible to infrared),
605 | dictating the direction of matching between the query and gallery sets.
606 |
607 | Returns:
608 | - tuple: A tuple containing metrics of rank-1, rank-5, rank-10, rank-20 precision, mAP, and mINP,
609 | along with detailed rank results for further analysis.
610 | """
611 |
612 | assert mode in ['i2v', 'v2i']
613 |
614 | gallery_feats = F.normalize(gallery_feats, dim=1)
615 | query_feats = F.normalize(query_feats, dim=1)
616 |
617 | # dist_mat = pairwise_distance(query_feats, gallery_feats)
618 | dist_mat = -torch.mm(query_feats, gallery_feats.t())
619 | sorted_indices = np.argsort(dist_mat, axis=1)
620 |
621 | rank_results = []
622 | cur_rank_results = dict()
623 | cur_rank_results['query_ids'] = query_ids
624 | cur_rank_results['query_img_paths'] = query_img_paths
625 | cur_rank_results['gallery_ids'] = gallery_ids
626 | cur_rank_results['gallery_img_paths'] = gallery_img_paths
627 | cur_rank_results['dist_mat'] = dist_mat
628 | cur_rank_results['sorted_indices'] = sorted_indices
629 | rank_results.append(cur_rank_results)
630 |
631 | # mAP = get_mAP(sorted_indices, query_ids, query_cam_ids, gallery_ids, gallery_cam_ids)
632 | # cmc = get_cmc(sorted_indices, query_ids, query_cam_ids, gallery_ids, gallery_cam_ids)
633 | # mINP = get_mINP(sorted_indices, query_ids, query_cam_ids, gallery_ids, gallery_cam_ids)
634 | cmc, mAP, mINP = get_cmc_mAP_mINP(sorted_indices, query_ids, query_cam_ids, gallery_ids, gallery_cam_ids)
635 |
636 | r1 = cmc[0]
637 | r5 = cmc[4]
638 | r10 = cmc[9]
639 | r20 = cmc[19]
640 |
641 | r1 = r1 * 100
642 | r5 = r5 * 100
643 | r10 = r10 * 100
644 | r20 = r20 * 100
645 | mAP = mAP * 100
646 | mINP = mINP * 100
647 |
648 | logger = logging.getLogger('MCJA')
649 | logger.info('-' * 150)
650 | perf = 'r1 precision = {:.2f} , r10 precision = {:.2f} , r20 precision = {:.2f} , ' \
651 | 'mAP = {:.2f} , mINP = {:.2f}'
652 | logger.info(perf.format(r1, r10, r20, mAP, mINP))
653 | logger.info('-' * 150)
654 |
655 | return r1, r5, r10, r20, mAP, mINP, rank_results
656 |
657 |
658 | def eval_regdb_mser(query_feats_list, query_ids_list, query_cam_ids_list, query_img_paths_list,
659 | gallery_feats_list, gallery_ids_list, gallery_cam_ids_list, gallery_img_paths_list, mode='i2v'):
660 | """
661 | A function crafted for evaluating VI-ReID models on the RegDB dataset using the Multi-Spectral Enhanced Ranking
662 | (MSER) strategy, tailored specifically for the thermal-to-visible (t2v or i2v) or visible-to-thermal (v2t or v2i)
663 | matching scenarios. The MSER strategy applies a re-ranking mechanism to enhance the initial distance matrix
664 | computation, utilizing multiple spectral representations. This function performs normalization on the feature
665 | vectors of both query and gallery sets, computes the distance matrix with MSER re-ranking, and evaluates the
666 | model based on re-identification metrics.
667 |
668 | Args:
669 | - query_feats_list (List[Tensor]): List of tensors representing features of query images.
670 | - query_ids_list (List[Tensor]): List of tensors containing the identity labels of the query images.
671 | - query_cam_ids_list (List[Tensor]): List of tensors with camera IDs from which each query image was captured.
672 | - query_img_paths_list (List[ndarray]): List of ndarrays holding the file paths for each query image.
673 | - gallery_feats_list (List[Tensor]): List of tensors representing the feature vectors of gallery images.
674 | - gallery_ids_list (List[Tensor]): List of tensors containing the identity labels of the gallery images.
675 | - gallery_cam_ids_list (List[Tensor]): List of tensors with camera IDs from which each gallery image was captured.
676 | - gallery_img_paths_list (List[ndarray]): List of ndarrays holding the file paths for each gallery image.
677 | - mode (str): The evaluation mode, either 'i2v' (infrared to visible) or 'v2i' (visible to infrared),
678 | dictating the direction of matching between query and gallery sets.
679 |
680 | Returns:
681 | - tuple: A tuple containing metrics of rank-1, rank-5, rank-10, rank-20 precision, mean Average Precision (mAP),
682 | and mean Inverse Negative Penalty (mINP), along with detailed rank results for further analysis.
683 | """
684 |
685 | list_num = len(query_feats_list)
686 | for c in range(list_num):
687 | gallery_feats_list[c] = F.normalize(gallery_feats_list[c], dim=1)
688 | query_feats_list[c] = F.normalize(query_feats_list[c], dim=1)
689 |
690 | dist_mat = mser_rerank(query_feats_list, gallery_feats_list,
691 | k1=50, k2=10, lambda_value=0.2, eval_type=False, mode=mode)
692 | sorted_indices = np.argsort(dist_mat, axis=1)
693 |
694 | rank_results = []
695 | cur_rank_results = dict()
696 | cur_rank_results['query_ids'] = query_ids_list[0]
697 | cur_rank_results['query_img_paths'] = query_img_paths_list[0]
698 | cur_rank_results['gallery_ids'] = gallery_ids_list[0]
699 | cur_rank_results['gallery_img_paths'] = gallery_img_paths_list[0]
700 | cur_rank_results['dist_mat'] = dist_mat
701 | cur_rank_results['sorted_indices'] = sorted_indices
702 | rank_results.append(cur_rank_results)
703 |
704 | # mAP = get_mAP(sorted_indices, query_ids, query_cam_ids, gallery_ids, gallery_cam_ids)
705 | # cmc = get_cmc(sorted_indices, query_ids, query_cam_ids, gallery_ids, gallery_cam_ids)
706 | # mINP = get_mINP(sorted_indices, query_ids, query_cam_ids, gallery_ids, gallery_cam_ids)
707 | cmc, mAP, mINP = get_cmc_mAP_mINP(sorted_indices,
708 | query_ids_list[0], query_cam_ids_list[0],
709 | gallery_ids_list[0], gallery_cam_ids_list[0])
710 |
711 | r1 = cmc[0]
712 | r5 = cmc[4]
713 | r10 = cmc[9]
714 | r20 = cmc[19]
715 |
716 | r1 = r1 * 100
717 | r5 = r5 * 100
718 | r10 = r10 * 100
719 | r20 = r20 * 100
720 | mAP = mAP * 100
721 | mINP = mINP * 100
722 |
723 | logger = logging.getLogger('MCJA')
724 | logger.info('-' * 150)
725 | perf = '[MSER] r1 precision = {:.2f} , r10 precision = {:.2f} , r20 precision = {:.2f} , ' \
726 | 'mAP = {:.2f} , mINP = {:.2f}'
727 | logger.info(perf.format(r1, r10, r20, mAP, mINP))
728 | logger.info('-' * 150)
729 |
730 | return r1, r5, r10, r20, mAP, mINP, rank_results
731 |
--------------------------------------------------------------------------------
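
For readers wiring up the evaluation, the sketch below illustrates the input structure `eval_regdb_mser` expects: one list entry per spectral channel for features, identity labels, camera IDs, and image paths, with the same channel order on the query and gallery sides. It is a minimal sketch with synthetic stand-in data (random features, cyclic labels) and an illustrative import path assuming the repository root is on `PYTHONPATH`; it demonstrates only the call signature, not a real RegDB evaluation.

```python
import numpy as np
import torch

from utils.eval_data import eval_regdb_mser  # assumes the repo root is on PYTHONPATH

# Synthetic stand-ins: 3 spectral channels, 60 query / 120 gallery images, 512-D features.
num_channels, n_query, n_gallery, feat_dim = 3, 60, 120, 512

def fake_channel_data(n, cam_id, prefix):
    """Builds one channel's worth of fake features / labels / camera IDs / image paths."""
    feats = torch.randn(n, feat_dim)                      # in practice: per-channel model features
    ids = torch.arange(n) % 30                            # cyclic IDs so every query has gallery matches
    cam_ids = torch.full((n,), cam_id, dtype=torch.long)  # distinct query / gallery cameras
    paths = np.array(['{}_{:04d}.jpg'.format(prefix, i) for i in range(n)])
    return feats, ids, cam_ids, paths

q = [fake_channel_data(n_query, cam_id=2, prefix='thermal') for _ in range(num_channels)]
g = [fake_channel_data(n_gallery, cam_id=1, prefix='visible') for _ in range(num_channels)]

r1, r5, r10, r20, mAP, mINP, rank_results = eval_regdb_mser(
    [x[0] for x in q], [x[1] for x in q], [x[2] for x in q], [x[3] for x in q],
    [x[0] for x in g], [x[1] for x in g], [x[2] for x in g], [x[3] for x in g],
    mode='i2v')
```
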
/utils/mser_rerank.py:
--------------------------------------------------------------------------------
1 | """MCJA/utils/mser_rerank.py
2 | It introduces the Multi-Spectral Enhanced Ranking (MSER) re-ranking strategy.
3 | """
4 |
5 | import numpy as np
6 | import torch
7 |
8 |
9 | def pairwise_distance(query_features, gallery_features):
10 | """
11 | A function that efficiently computes the pairwise squared Euclidean distances between two sets of features, typically used
12 | in the context of person re-identification tasks to measure similarities between query and gallery sets. This
13 | implementation leverages matrix operations for high performance, calculating the squared differences between
14 | each pair of features in the query and gallery feature tensors.
15 |
16 | Args:
17 | - query_features (Tensor): A tensor containing the feature vectors of the query set.
18 | Each row represents the feature vector of a query sample.
19 | - gallery_features (Tensor): A tensor containing the feature vectors of the gallery set.
20 | Each row represents the feature vector of a gallery sample.
21 |
22 | Returns:
23 |     - Tensor: A matrix of squared Euclidean distances, where element (i, j) is the squared distance between
24 |       the i-th query feature and the j-th gallery feature (no square root is taken; rankings are unaffected).
25 | """
26 |
27 | x = query_features
28 | y = gallery_features
29 | m, n = x.size(0), y.size(0)
30 | x = x.view(m, -1)
31 | y = y.view(n, -1)
32 | dist = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) + \
33 | torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t()
34 | dist.addmm_(beta=1, alpha=-2, mat1=x, mat2=y.t())
35 | return dist
36 |
37 |
38 | def mser_rerank(query_feats_list, g_feats_list, k1=40, k2=20, lambda_value=0.3, eval_type=True, mode='v2i'):
39 | """
40 | A function designed to perform Multi-Spectral Enhanced Ranking (MSER) for cross-modality person re-identification.
41 | The MSER strategy advances the initial matching process by integrating a re-ranking mechanism that emphasizes
42 | reciprocal relationships and spectral characteristics. This approach recalculates the pairwise distances between
43 | query and gallery sets, refining the initial matches through k-reciprocal nearest neighbors and a Jaccard distance
44 | measure to produce a more accurate ranking.
45 |
46 |     As mentioned in our paper, MSER builds on the re-ranking strategy [1] from single-modality ReID and extends it to
47 |     VI-ReID by exploiting multi-spectral information within images. Our implementation adapts the re-ranking code from [2].
48 |
49 | Ref:
50 | [1] (Paper) Re-ranking Person Re-identification with k-Reciprocal Encoding, CVPR 2017.
51 | [2] (Code) https://github.com/DoubtedSteam/MPANet/blob/main/utils/rerank.py
52 |
53 | Args:
54 | - query_feats_list (List[Tensor]): List of tensors representing feature vectors of query images.
55 | - g_feats_list (List[Tensor]): List of tensors representing feature vectors of gallery images.
56 | - k1 (int): The primary parameter controlling the extent of k-reciprocal neighbors to consider,
57 | affecting the initial scope of re-ranking.
58 | - k2 (int): The secondary parameter influencing the expansion of reciprocal neighbors,
59 | further refining the selection based on mutual nearest neighbors.
60 | - lambda_value (float): A coefficient used to balance the original distance matrix with the Jaccard distance,
61 | adjusting the influence of each component in the final distance computation.
62 |     - eval_type (bool): If True, gallery samples are masked (distances set to the maximum) when building the initial rank list.
63 | - mode (str): Specifies the modality matching direction, either 'i2v' for infrared-to-visible or 'v2i' for
64 | visible-to-infrared, adapting the function for different dataset characteristics.
65 |
66 | Returns:
67 | - numpy.ndarray: The re-ranked distance matrix, where each element reflects the recalculated distance between a
68 | query and a gallery feature vector, with lower values denoting higher similarity.
69 | """
70 |
71 | # Note: The MSER strategy requires more CPU memory, as it builds all_num x all_num distance and weight matrices over all spectral copies.
72 |
73 | assert mode in ['i2v', 'v2i']
74 |
75 | if mode == 'i2v':
76 | list_num = len(g_feats_list)
77 | q_feat = query_feats_list[0]
78 | feats = torch.cat([q_feat] + g_feats_list, 0)
79 | else: # mode == 'v2i'
80 | list_num = len(query_feats_list)
81 | q_feat = torch.cat(query_feats_list, 0)
82 | g_feat = g_feats_list[0]
83 | feats = torch.cat(query_feats_list + [g_feat], 0)
84 |
85 | dist = pairwise_distance(feats, feats)
86 | # dist = -torch.mm(feats, feats.permute(1, 0))
87 | original_dist = dist.clone().numpy()
88 | original_dist = np.transpose(original_dist / np.max(original_dist, axis=0))
89 | V = np.zeros_like(original_dist).astype(np.float16)
90 |
91 | query_num = q_feat.size(0)
92 | all_num = original_dist.shape[0]
93 | if eval_type:
94 | dist[:, query_num:] = dist.max()
95 | dist = dist.numpy()
96 | initial_rank = np.argsort(dist).astype(np.int32)
97 |
98 | for i in range(all_num):
99 | # k-reciprocal neighbors
100 | forward_k_neigh_index = initial_rank[i, :k1 + 1]
101 | backward_k_neigh_index = initial_rank[forward_k_neigh_index, :k1 + 1]
102 | fi = np.where(backward_k_neigh_index == i)[0]
103 | k_reciprocal_index = forward_k_neigh_index[fi]
104 | k_reciprocal_expansion_index = k_reciprocal_index
105 |
106 | # for j in range(len(k_reciprocal_index)):
107 | # candidate = k_reciprocal_index[j]
108 | # candidate_forward_k_neigh_index = initial_rank[candidate, :int(np.around(k1 / 2)) + 1]
109 | # candidate_backward_k_neigh_index = initial_rank[candidate_forward_k_neigh_index,
110 | # :int(np.around(k1 / 2)) + 1]
111 | # fi_candidate = np.where(candidate_backward_k_neigh_index == candidate)[0]
112 | # candidate_k_reciprocal_index = candidate_forward_k_neigh_index[fi_candidate]
113 | # if len(np.intersect1d(candidate_k_reciprocal_index, k_reciprocal_index)) > 2 / 3 * len(
114 | # candidate_k_reciprocal_index):
115 | # k_reciprocal_expansion_index = np.append(k_reciprocal_expansion_index, candidate_k_reciprocal_index)
116 |
117 | k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index)
118 | weight = np.exp(-original_dist[i, k_reciprocal_expansion_index])
119 | V[i, k_reciprocal_expansion_index] = weight / np.sum(weight)
120 | original_dist = original_dist[:query_num, ]
121 | if k2 != 1:
122 | V_qe = np.zeros_like(V, dtype=np.float16)
123 | for i in range(all_num):
124 | V_qe[i, :] = np.mean(V[initial_rank[i, :k2], :], axis=0)
125 | V = V_qe
126 | del V_qe
127 | del initial_rank
128 | invIndex = [] # row index
129 | for i in range(all_num):
130 | invIndex.append(np.where(V[:, i] != 0)[0])
131 |
132 | jaccard_dist = np.zeros_like(original_dist, dtype=np.float16)
133 |
134 | for i in range(query_num):
135 | temp_min = np.zeros(shape=[1, all_num], dtype=np.float16)
136 | indNonZero = np.where(V[i, :] != 0)[0]
137 | indImages = [invIndex[ind] for ind in indNonZero]
138 | for j in range(len(indNonZero)):
139 | temp_min[0, indImages[j]] = temp_min[0, indImages[j]] + np.minimum(V[i, indNonZero[j]],
140 | V[indImages[j], indNonZero[j]])
141 | jaccard_dist[i] = 1 - temp_min / (2 - temp_min)
142 |
143 | final_dist = jaccard_dist * (1 - lambda_value) + original_dist * lambda_value
144 | del original_dist
145 | del V
146 | del jaccard_dist
147 | final_dist = final_dist[:query_num, query_num:]
148 |
149 | if mode == 'i2v':
150 | final_dist = final_dist.reshape(query_num, list_num, -1)
151 | final_dist = np.mean(final_dist, axis=1)
152 | else: # mode == 'v2i'
153 | final_dist = final_dist
154 | final_dist = final_dist.reshape(list_num, query_num // list_num, -1)
155 | final_dist = np.mean(final_dist, axis=0)
156 | return final_dist
157 |
158 |
159 | if __name__ == '__main__':
160 | q_feat = torch.randn((8, 16))
161 | g_feat = torch.randn((4, 16))
162 | dist = mser_rerank([q_feat, q_feat, q_feat], [g_feat, g_feat, g_feat], k1=6, k2=4, mode='v2i')
163 | dist = mser_rerank([q_feat, q_feat, q_feat], [g_feat, g_feat, g_feat], k1=6, k2=4, mode='i2v')
164 |
--------------------------------------------------------------------------------
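
A note on `pairwise_distance`: it returns squared Euclidean distances, computed with the identity ||x - y||^2 = ||x||^2 + ||y||^2 - 2 x.y^T rather than an explicit per-pair subtraction. Since the square root is monotonic, ranking by squared distances yields the same order as ranking by true distances, so the missing `sqrt` does not change retrieval results. The minimal sketch below, using synthetic tensors and an illustrative import path (assuming the repo root is on `PYTHONPATH`), checks this computation against `torch.cdist`.

```python
import torch

from utils.mser_rerank import pairwise_distance  # assumes the repo root is on PYTHONPATH

torch.manual_seed(0)
q = torch.randn(8, 16)   # 8 query features, 16-D
g = torch.randn(4, 16)   # 4 gallery features, 16-D

d_fast = pairwise_distance(q, g)      # ||q||^2 + ||g||^2 - 2 q g^T  (squared distances)
d_ref = torch.cdist(q, g, p=2) ** 2   # reference: squared Euclidean distances

# Should print True up to floating-point tolerance (addmm_ accumulates small rounding error).
print(torch.allclose(d_fast, d_ref, atol=1e-4))
```
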
/utils/tools.py:
--------------------------------------------------------------------------------
1 | """MCJA/utils/tools.py
2 | This utility file provides helper functions for reproducibility (random seeding) and timestamp generation.
3 | """
4 |
5 | import random
6 | import datetime
7 | import numpy as np
8 | import torch
9 |
10 |
11 | def set_seed(seed, cuda=True):
12 | """
13 | A utility function for setting the random seed across various libraries commonly used in deep learning projects to
14 | ensure reproducibility of results. By fixing the random seed, this function makes experiments deterministic, meaning
15 | that running the same code with the same inputs and the same seed (on the same experimental platform) will produce
16 | the same outputs every time, which is crucial for debugging and comparing different models and configurations.
17 |
18 | Args:
19 | - seed (int): The random seed value to be set across all libraries.
20 |     - cuda (bool, optional): Whether to also seed all CUDA devices and enforce deterministic cuDNN behavior. Default is True.
21 | """
22 |
23 | random.seed(seed)
24 | np.random.seed(seed)
25 | torch.manual_seed(seed)
26 | if cuda:
27 | torch.cuda.manual_seed_all(seed)
28 | torch.backends.cudnn.deterministic = True
29 | torch.backends.cudnn.benchmark = False
30 |
31 |
32 | def time_str(fmt=None):
33 | """
34 | A utility function for generating a formatted string representing the current date and time. This function is
35 | particularly useful for creating timestamps for logging, file naming, or any other task that requires capturing
36 | the exact moment when an event occurs. By default, the function produces a string formatted as "YYYY-MM-DD_hh-mm-ss",
37 | but it allows for customization of the format according to the user's needs.
38 |
39 | Args:
40 | - fmt (str, optional): A format string defining how the date and time should be represented.
41 | This string should follow the formatting rules used by Python's `strftime` method. If no format is specified, the
42 | default format "%Y-%m-%d_%H-%M-%S" is used, which corresponds to the "year-month-day_hour-minute-second" format.
43 |
44 | Returns:
45 | - str: A string representation of the current date and time,
46 | formatted according to the provided or default format specification.
47 | """
48 |
49 | if fmt is None:
50 | fmt = '%Y-%m-%d_%H-%M-%S'
51 | return datetime.datetime.today().strftime(fmt)
52 |
--------------------------------------------------------------------------------
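
To round out the utilities, here is a minimal sketch of how `set_seed` and `time_str` are typically combined at the start of an experiment: fix the random seed before any model or data loader is constructed, and use the timestamp to name a unique output directory. The `logs/` layout, the seed value, and the import path are illustrative assumptions, not something prescribed by this repository.

```python
import os

from utils.tools import set_seed, time_str  # assumes the repo root is on PYTHONPATH

set_seed(0, cuda=True)  # fix all RNGs before building models / data loaders (seed value is arbitrary)

run_dir = os.path.join('logs', 'MCJA_' + time_str())  # e.g., logs/MCJA_2024-03-18_10-30-00
os.makedirs(run_dir, exist_ok=True)
print('Logging to:', run_dir)
```
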