├── LICENSE ├── README.md ├── configs ├── RegDB_MCJA.yml ├── SYSU_MCJA.yml ├── __init__.py └── default │ ├── __init__.py │ ├── dataset.py │ └── strategy.py ├── data ├── __init__.py ├── dataset.py ├── sampler.py └── transform.py ├── engine ├── __init__.py ├── engine.py └── metric.py ├── figs └── mcja_overall_structure.png ├── losses ├── __init__.py └── cm_retrieval_loss.py ├── main.py ├── models ├── __init__.py ├── backbones │ ├── __init__.py │ └── resnet.py ├── mcja.py └── modules │ ├── __init__.py │ └── mda.py └── utils ├── __init__.py ├── calc_acc.py ├── eval_data.py ├── mser_rerank.py └── tools.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 workingcoder 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Bridging the Gap: Multi-level Cross-modality Joint Alignment for Visible-infrared Person Re-identification 2 | 3 | By [Tengfei Liang](https://scholar.google.com/citations?user=YE6fPvgAAAAJ), [Yi Jin](https://scholar.google.com/citations?user=NQAenU0AAAAJ), [Wu Liu](https://scholar.google.com/citations?user=rQpizr0AAAAJ), [Tao Wang](https://scholar.google.com/citations?user=F3C5oAcAAAAJ), [Songhe Feng](https://scholar.google.com/citations?user=K5lqMYgAAAAJ), [Yidong Li](https://scholar.google.com/citations?user=3PagRQEAAAAJ). 4 | 5 | This repository is an official implementation of the paper [Bridging the Gap: Multi-level Cross-modality Joint Alignment for Visible-infrared Person Re-identification](https://ieeexplore.ieee.org/abstract/document/10472470). [`IEEEXplore`](https://ieeexplore.ieee.org/abstract/document/10472470) [`Google Drive`](https://drive.google.com/file/d/19-2f-gTj3P9tV-YhabtVpXSgShxFrGrs/view?usp=sharing) 6 | 7 | *Notes:* 8 | 9 | This repository offers the complete code of the entire method, featuring a well-organized directory structure and detailed comments, facilitating the training and testing of the model. 10 | It is hoped that this can serve as a new baseline for cross-modal visible-infrared person re-identification. 11 | 12 | 13 | ## Abstract 14 | 15 | Visible-Infrared person Re-IDentification (VI-ReID) is a challenging cross-modality image retrieval task that aims to match pedestrians' images across visible and infrared cameras. 
16 | To solve the modality gap, existing mainstream methods adopt a learning paradigm converting the image retrieval task into an image classification task with cross-entropy loss and auxiliary metric learning losses. 17 | These losses follow the strategy of adjusting the distribution of extracted embeddings to reduce the intra-class distance and increase the inter-class distance. 18 | However, such objectives do not precisely correspond to the final test setting of the retrieval task, resulting in a new gap at the optimization level. 19 | By rethinking these keys of VI-ReID, we propose a simple and effective method, the Multi-level Cross-modality Joint Alignment (MCJA), bridging both the modality and objective-level gap. 20 | For the former, we design the Visible-Infrared Modality Coordinator in the image space and propose the Modality Distribution Adapter in the feature space, effectively reducing modality discrepancy of the feature extraction process. 21 | For the latter, we introduce a new Cross-Modality Retrieval loss. 22 | It is the first work to constrain from the perspective of the ranking list in the VI-ReID, aligning with the goal of the testing stage. 23 | Moreover, to strengthen the robustness and cross-modality retrieval ability, we further introduce a Multi-Spectral Enhanced Ranking strategy for the testing phase. 24 | Based on the global feature only, our method outperforms existing methods by a large margin, achieving the remarkable rank-1 of 89.51% and mAP of 87.58% on the most challenging single-shot setting and all-search mode of the SYSU-MM01 dataset. 25 | (For more details, please refer to [the original paper](https://ieeexplore.ieee.org/abstract/document/10472470)) 26 | 27 |
28 | ![MCJA overall structure](figs/mcja_overall_structure.png)
29 | 30 | 31 | Fig. 1: Overall architecture of the proposed MCJA model. 32 |
33 | 34 | 35 | ## Requirements 36 | 37 | The code of this repository is designed to run on a single GPU. 38 | Here are the Python packages and their corresponding versions used during our experiments: 39 | 40 | - Python 3.8 41 | - apex==0.1 42 | - numpy==1.21.5 43 | - Pillow==8.4.0 44 | - pytorch_ignite==0.2.1 45 | - scipy==1.7.3 46 | - torch==1.8.1+cu111 47 | - torchsort==0.1.9 48 | - torchvision==0.9.1+cu111 49 | - yacs==0.1.8 50 | 51 | *Notes:* 52 | When installing the 'apex' package, please refer to its [official repository - apex](https://github.com/NVIDIA/apex). 53 | 54 | *P.S.* Higher or lower versions of these packages might be supported. 55 | When attempting to use a different version of PyTorch, please be mindful of its compatibility with pytorch_ignite, torchsort, etc. 56 | 57 |
58 | ## Dataset & Preparation 59 | 60 | In our experiments, we evaluate the proposed method on the publicly available SYSU-MM01 and RegDB datasets, which are commonly used for comparison in VI-ReID. 61 | Please download the corresponding datasets and modify the data_root path in [configs/default/dataset.py](./configs/default/dataset.py). 62 | 63 |
64 | ## Experiments 65 | 66 | Our [main.py](./main.py) supports both training (with evaluation) and testing only. 67 | 68 | ### Train 69 | 70 | Executing one of the following commands trains and evaluates the MCJA model on the SYSU-MM01 or RegDB dataset, respectively: 71 | 72 | ```bash 73 | python main.py --cfg configs/SYSU_MCJA.yml --gpu 0 --seed 8 --desc MCJA 74 | ``` 75 | 76 | ```bash 77 | python main.py --cfg configs/RegDB_MCJA.yml --gpu 0 --seed 8 --desc MCJA 78 | ``` 79 | 80 | ### Test 81 | 82 | When conducting tests only, set 'test_only' to true in the 'XXXX.yml' configuration file and specify the path of the model to load in the 'resume' setting. 83 | Then, execute the same command as mentioned above to complete the testing and evaluation: 84 | 85 | ```bash 86 | python main.py --cfg configs/SYSU_MCJA.yml --gpu 0 --desc MCJA_test_only 87 | ``` 88 | 89 | ```bash 90 | python main.py --cfg configs/RegDB_MCJA.yml --gpu 0 --desc MCJA_test_only 91 | ``` 92 | 93 | *Notes:* 94 | The '--seed' and '--desc' options of [main.py](./main.py) are optional. 95 | The former controls the random seed for the experiment, while the latter adds a suffix description to the current run. 96 | 97 |
98 | ## Citation 99 | If you find MCJA useful in your research, please kindly cite this paper in your publications: 100 | ```bibtex 101 | @article{TCSVT24_MCJA, 102 | author = {Liang, Tengfei and Jin, Yi and Liu, Wu and Wang, Tao and Feng, Songhe and Li, Yidong}, 103 | title = {Bridging the Gap: Multi-level Cross-modality Joint Alignment for Visible-infrared Person Re-identification}, 104 | journal = {IEEE Transactions on Circuits and Systems for Video Technology}, 105 | pages = {1-1}, 106 | year = {2024}, 107 | doi = {10.1109/TCSVT.2024.3377252} 108 | } 109 | ``` 110 | 111 |
112 | ## Related Repos 113 | Our repository builds upon the work of others, and we extend our gratitude for their contributions. 114 | Below is a list of some of these works: 115 | 116 | - AGW - https://github.com/mangye16/Cross-Modal-Re-ID-baseline 117 | - MPANet - https://github.com/DoubtedSteam/MPANet 118 | 119 |
120 | ## License 121 | 122 | This repository is released under the MIT license. Please see the [LICENSE](./LICENSE) file for more information.
123 | -------------------------------------------------------------------------------- /configs/RegDB_MCJA.yml: -------------------------------------------------------------------------------- 1 | # Customized Strategy Config 2 | 3 | prefix: RegDB 4 | 5 | # Setting for Data 6 | dataset: regdb 7 | image_size: (384, 192) 8 | sample_method: norm_triplet 9 | p_size: 4 10 | k_size: 32 11 | batch_size: 128 12 | 13 | # Settings for Augmentation 14 | random_flip: true 15 | random_crop: true 16 | random_erase: true 17 | color_jitter: true 18 | padding: 10 19 | vimc_wg: true 20 | vimc_cc: true 21 | vimc_sj: true 22 | 23 | # Settings for Model 24 | drop_last_stride: false 25 | mda_ratio: 2 26 | mda_m: 12 27 | 28 | # Setting for Loss 29 | loss_id: true 30 | loss_cmr: true 31 | 32 | # Settings for Training 33 | lr: 0.00035 34 | wd: 0.0005 35 | num_epoch: 300 36 | lr_step: [ 300 ] 37 | fp16: true 38 | 39 | # Settings for Logging 40 | start_eval: 0 41 | log_period: 5 42 | eval_interval: 1 43 | 44 | # Settings for Testing 45 | resume: '' 46 | mser: true 47 | test_only: false 48 | -------------------------------------------------------------------------------- /configs/SYSU_MCJA.yml: -------------------------------------------------------------------------------- 1 | # Customized Strategy Config 2 | 3 | prefix: SYSU 4 | 5 | # Setting for Data 6 | dataset: sysu 7 | image_size: (384, 192) 8 | sample_method: norm_triplet 9 | p_size: 8 10 | k_size: 16 11 | batch_size: 128 12 | 13 | # Settings for Augmentation 14 | random_flip: true 15 | random_crop: true 16 | random_erase: true 17 | color_jitter: true 18 | padding: 10 19 | vimc_wg: true 20 | vimc_cc: true 21 | vimc_sj: true 22 | 23 | # Settings for Model 24 | drop_last_stride: false 25 | mda_ratio: 2 26 | mda_m: 2 27 | 28 | # Setting for Loss 29 | loss_id: true 30 | loss_cmr: true 31 | 32 | # Settings for Training 33 | lr: 0.00035 34 | wd: 0.0005 35 | num_epoch: 140 36 | lr_step: [ 80, 120 ] 37 | fp16: true 38 | 39 | # Settings for Logging 40 | start_eval: 0 41 | log_period: 10 42 | eval_interval: 1 43 | 44 | # Settings for Testing 45 | resume: '' 46 | mser: true 47 | test_only: false 48 | -------------------------------------------------------------------------------- /configs/__init__.py: -------------------------------------------------------------------------------- 1 | """MCJA/configs/__init__.py 2 | It is used to mark a directory as a Python package directory. 3 | """ -------------------------------------------------------------------------------- /configs/default/__init__.py: -------------------------------------------------------------------------------- 1 | """MCJA/configs/default/__init__.py 2 | It serves as an entry point for the default configuration settings. 3 | """ 4 | 5 | from configs.default.dataset import dataset_cfg 6 | from configs.default.strategy import strategy_cfg 7 | 8 | __all__ = ["dataset_cfg", "strategy_cfg"] 9 | -------------------------------------------------------------------------------- /configs/default/dataset.py: -------------------------------------------------------------------------------- 1 | """MCJA/configs/default/dataset.py 2 | It defines the default configuration settings for customized datasets. 
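
A minimal usage sketch (illustrative only, not part of the framework; the local path below is a
hypothetical example of where a downloaded copy of SYSU-MM01 might live):

    from configs.default.dataset import dataset_cfg

    cfg = dataset_cfg.clone()               # work on a copy, keeping the shipped defaults intact
    cfg.sysu.data_root = '/data/SYSU-MM01'  # hypothetical local path to the dataset
    cfg.freeze()                            # make the configuration node read-only
    print(cfg.sysu.num_id, cfg.sysu.data_root)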
3 | """ 4 | 5 | from yacs.config import CfgNode 6 | 7 | dataset_cfg = CfgNode() 8 | 9 | dataset_cfg.sysu = CfgNode() 10 | dataset_cfg.sysu.num_id = 395 11 | dataset_cfg.sysu.num_cam = 6 12 | dataset_cfg.sysu.data_root = "../../Datasets/SYSU-MM01" 13 | 14 | dataset_cfg.regdb = CfgNode() 15 | dataset_cfg.regdb.num_id = 206 16 | dataset_cfg.regdb.num_cam = 2 17 | dataset_cfg.regdb.data_root = "../../Datasets/RegDB" 18 | -------------------------------------------------------------------------------- /configs/default/strategy.py: -------------------------------------------------------------------------------- 1 | """MCJA/configs/default/strategy.py 2 | It outlines the default strategy configurations for the framework. 3 | """ 4 | 5 | from yacs.config import CfgNode 6 | 7 | strategy_cfg = CfgNode() 8 | 9 | strategy_cfg.prefix = "SYSU" 10 | 11 | # Settings for Data 12 | strategy_cfg.dataset = "sysu" 13 | strategy_cfg.image_size = (384, 128) 14 | strategy_cfg.sample_method = "norm_triplet" 15 | strategy_cfg.p_size = 8 16 | strategy_cfg.k_size = 16 17 | strategy_cfg.batch_size = 128 18 | 19 | # Settings for Augmentation 20 | strategy_cfg.random_flip = True 21 | strategy_cfg.random_crop = True 22 | strategy_cfg.random_erase = True 23 | strategy_cfg.color_jitter = True 24 | strategy_cfg.padding = 10 25 | strategy_cfg.vimc_wg = True 26 | strategy_cfg.vimc_cc = True 27 | strategy_cfg.vimc_sj = True 28 | 29 | # Settings for Model 30 | strategy_cfg.drop_last_stride = False 31 | strategy_cfg.mda_ratio = 2 32 | strategy_cfg.mda_m = 2 33 | 34 | # Setting for Loss 35 | strategy_cfg.loss_id = True 36 | strategy_cfg.loss_cmr = True 37 | 38 | # Settings for Training 39 | strategy_cfg.lr = 0.00035 40 | strategy_cfg.wd = 0.0005 41 | strategy_cfg.num_epoch = 140 42 | strategy_cfg.lr_step = [80, 120] 43 | strategy_cfg.fp16 = True 44 | 45 | # Settings for Logging 46 | strategy_cfg.start_eval = 0 47 | strategy_cfg.log_period = 10 48 | strategy_cfg.eval_interval = 1 49 | 50 | # Settings for Testing 51 | strategy_cfg.resume = '' 52 | strategy_cfg.mser = True 53 | strategy_cfg.test_only = False 54 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | """MCJA/data/__init__.py 2 | It orchestrates the data handling process, focusing on data transforming and loader for the training and testing. 3 | """ 4 | 5 | import torch 6 | import torchvision.transforms as T 7 | from torch.utils.data import DataLoader 8 | 9 | from data.dataset import SYSUDataset 10 | from data.dataset import RegDBDataset 11 | 12 | from data.sampler import CrossModalityIdentitySampler 13 | from data.sampler import CrossModalityRandomSampler 14 | from data.sampler import IdentityCrossModalitySampler 15 | from data.sampler import NormTripletSampler 16 | 17 | from data.transform import WeightedGrayscale 18 | from data.transform import ChannelCutMix 19 | from data.transform import SpectrumJitter 20 | from data.transform import ChannelAugmentation 21 | from data.transform import NoTransform 22 | 23 | 24 | def collate_fn(batch): 25 | """ 26 | Custom collate function for DataLoader to handle batches of data. This function is designed to process a batch of 27 | data by separating and recombining the elements of each data point in the batch, except for a specified element 28 | (e.g., image paths). 
The recombination is done in such a way that it preserves the integrity of multi-modal data 29 | or other structured data necessary for model training or evaluation. The function operates by zipping the batch 30 | (which combines elements from each data point across the batch), then selectively stacking the elements to form a 31 | new batch tensor. It specifically skips over a index (in this case, the image paths) and reinserts this non-tensor 32 | data back into its original position in the batch. This approach ensures compatibility with models expecting data 33 | in a specific format while accommodating for elements like paths that should not be converted into tensors. 34 | 35 | Args: 36 | - batch (list): A list of tuples, where each tuple represents a data point and contains elements, 37 | including images, labels, camera IDs, image paths, and image IDs (img, label, cam_id, img_path, img_id). 38 | 39 | Returns: 40 | - list: A list of tensors and other data types recombined from the input batch, with tensor elements 41 | stacked along a new dimension and non-tensor elements (e.g., paths) preserved in their original form. 42 | """ 43 | 44 | samples = list(zip(*batch)) 45 | 46 | data = [torch.stack(x, 0) for i, x in enumerate(samples) if i != 3] 47 | data.insert(3, samples[3]) 48 | return data 49 | 50 | 51 | def get_train_loader(dataset, root, sample_method, batch_size, p_size, k_size, image_size, 52 | random_flip=False, random_crop=False, random_erase=False, color_jitter=False, padding=0, 53 | vimc_wg=False, vimc_cc=False, vimc_sj=False, num_workers=4): 54 | """ 55 | Constructs and returns a DataLoader for training with specific datasets (SYSU or RegDB), incorporating a variety 56 | of data augmentation techniques and sampling strategies tailored for mixed-modality (visible and infrared) computer 57 | vision tasks. This function allows for extensive customization of the data preprocessing pipeline, including options 58 | for random flipping, cropping, erasing, color jitter, and innovative visible-infrared modality coordination (VIMC). 59 | The sampling strategy for forming batches can be selected from among several options, including norm triplet, cross 60 | modality random, cross modality identity, and identity cross modality samplers, to suit different training needs 61 | and objectives. This function plays a critical role in preparing the data for efficient and effective training by 62 | dynamically adjusting to the specified dataset, sample method, data augmentation preferences, etc. 63 | 64 | Args: 65 | - dataset (str): Name of the dataset to use ('sysu' or 'regdb'). 66 | - root (str): Root directory where the dataset is stored. 67 | - sample_method (str): Method used for sampling data points to form batches. 68 | - batch_size (int): Number of data points in each batch. 69 | - p_size (int): Number of identities per batch (used in certain sampling methods). 70 | - k_size (int): Number of instances per identity (used in certain sampling methods). 71 | - image_size (tuple): The size to which the images are resized. 72 | - random_flip (bool): Whether to randomly flip images horizontally. 73 | - random_crop (bool): Whether to randomly crop images. 74 | - random_erase (bool): Whether to randomly erase parts of images. 75 | - color_jitter (bool): Whether to apply random color jittering. 76 | - padding (int): Padding size used for random cropping. 77 | - vimc_wg (bool): Whether to apply weighted grayscale conversion. 78 | - vimc_cc (bool): Whether to apply channel cutmix augmentation. 
79 | - vimc_sj (bool): Whether to apply spectrum jitter. 80 | - num_workers (int): Number of worker threads to use for loading data. 81 | 82 | Returns: 83 | - DataLoader: A DataLoader object ready for training, with batches formed 84 | according to the specified sample method and data augmentation settings. 85 | """ 86 | 87 | # Data Transform - RGB --------------------------------------------------------------------------------------------- 88 | t = [T.Resize(image_size)] 89 | 90 | t.append(T.RandomChoice([ 91 | T.RandomApply([T.ColorJitter(hue=0.20)], p=0.5) if color_jitter else NoTransform(), 92 | 93 | ###### Visible-Infrared Modality Coordinator (VIMC) ###### 94 | WeightedGrayscale(p=0.5) if vimc_wg else NoTransform(), 95 | ChannelCutMix(p=0.5) if vimc_cc else NoTransform(), 96 | SpectrumJitter(factor=(0.00, 1.00), p=0.5) if vimc_sj else NoTransform(), 97 | ])) 98 | 99 | if random_flip: 100 | t.append(T.RandomHorizontalFlip()) 101 | 102 | if random_crop: 103 | t.append(T.RandomCrop(image_size, padding=padding, fill=127)) 104 | 105 | t.extend([T.ToTensor(), T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) 106 | 107 | if random_erase: 108 | t.append(T.RandomErasing(value=0, scale=(0.02, 0.30))) 109 | 110 | transform_rgb = T.Compose(t) 111 | 112 | # Data Transform - IR ---------------------------------------------------------------------------------------------- 113 | t = [T.Resize(image_size)] 114 | 115 | if random_flip: 116 | t.append(T.RandomHorizontalFlip()) 117 | 118 | if random_crop: 119 | t.append(T.RandomCrop(image_size, padding=padding, fill=127)) 120 | 121 | t.extend([T.ToTensor(), T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) 122 | 123 | if random_erase: 124 | t.append(T.RandomErasing(value=0, scale=(0.02, 0.30))) 125 | 126 | transform_ir = T.Compose(t) 127 | 128 | # Dataset ---------------------------------------------------------------------------------------------------------- 129 | if dataset == 'sysu': 130 | train_dataset = SYSUDataset(root, mode='train', transform_rgb=transform_rgb, transform_ir=transform_ir) 131 | elif dataset == 'regdb': 132 | train_dataset = RegDBDataset(root, mode='train', transform_rgb=transform_rgb, transform_ir=transform_ir) 133 | else: 134 | raise NotImplementedError(f'Dataset - {dataset} is not supported') 135 | 136 | # DataSampler ------------------------------------------------------------------------------------------------------ 137 | assert sample_method in ['none', 'norm_triplet', 138 | 'cross_modality_random', 139 | 'cross_modality_identity', 140 | 'identity_cross_modality'] 141 | shuffle = False 142 | if sample_method == 'none': 143 | sampler = None 144 | shuffle = True 145 | elif sample_method == 'norm_triplet': 146 | batch_size = p_size * k_size 147 | sampler = NormTripletSampler(train_dataset, p_size * k_size, k_size) 148 | elif sample_method == 'cross_modality_random': 149 | sampler = CrossModalityRandomSampler(train_dataset, batch_size) 150 | elif sample_method == 'cross_modality_identity': 151 | batch_size = p_size * k_size 152 | sampler = CrossModalityIdentitySampler(train_dataset, p_size, k_size) 153 | elif sample_method == 'identity_cross_modality': 154 | batch_size = p_size * k_size 155 | sampler = IdentityCrossModalitySampler(train_dataset, p_size * k_size, k_size) 156 | ## Note: 157 | ## When sample_method is in [none, cross_modity_random], 158 | ## batch_size is adopted, and p_size and k_size are invalid. 
159 | ## When sample_method is in [norm_triplet, cross_modity_identity, identity_cross_modity], 160 | ## p_size and k_size are adopted, and batch_size is invalid. 161 | 162 | # DataLoader ------------------------------------------------------------------------------------------------------- 163 | train_loader = DataLoader(train_dataset, batch_size, sampler=sampler, 164 | shuffle=shuffle, drop_last=True, pin_memory=True, 165 | collate_fn=collate_fn, num_workers=num_workers) 166 | return train_loader 167 | 168 | 169 | def get_test_loader(dataset, root, batch_size, image_size, num_workers=4, mode=None): 170 | """ 171 | Creates and returns DataLoader objects for the gallery and query datasets, intended for use in the testing phase of 172 | mixed-modality (visible and infrared) computer vision tasks. This function configures data preprocessing pipelines 173 | with transformations tailored for evaluation, including resizing, channel augmentation based on a specified mode, 174 | and normalization. It supports various modes for channel augmentation, enabling flexibility in how images are 175 | processed and potentially enhancing model robustness during evaluation. The function is designed to work with 176 | specific datasets (SYSU or RegDB), preparing both gallery and query sets for efficient and effective testing. 177 | 178 | Args: 179 | - dataset (str): The name of the dataset to be used ('sysu' or 'regdb'). 180 | - root (str): The root directory where the dataset is stored. 181 | - batch_size (int): The number of data points in each batch. 182 | - image_size (tuple): The size to which the images are resized. 183 | - num_workers (int): The number of worker threads to use for data loading. 184 | - mode (str, optional): The mode of channel augmentation to apply to the RGB data. 185 | Options include None, 'avg', 'r', 'g', 'b', 'rand', 'wg', 'cc', 'sj', with each providing a different manner. 186 | 187 | Returns: 188 | - tuple: A tuple containing two DataLoader objects, one for the gallery dataset and one for the query dataset, 189 | both configured for testing with the specified transformations and settings. 
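
    Example (an illustrative sketch, not code executed by the framework; the data root is a hypothetical path):

        from data import get_test_loader

        gallery_loader, query_loader = get_test_loader(
            dataset='sysu', root='/data/SYSU-MM01',
            batch_size=128, image_size=(384, 192), num_workers=4, mode=None)
        imgs, labels, cams, paths, items = next(iter(query_loader))
        # imgs is a float tensor of shape (B, 3, 384, 192), already resized and normalized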
190 | """ 191 | 192 | assert mode in [None, 'avg', 'r', 'g', 'b', 'rand', 'wg', 'cc', 'sj'] 193 | 194 | # Data Transform - RGB --------------------------------------------------------------------------------------------- 195 | transform_rgb = T.Compose([ 196 | T.Resize(image_size), 197 | ChannelAugmentation(mode=mode), 198 | T.ToTensor(), 199 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 200 | ]) 201 | 202 | # Data Transform - IR ---------------------------------------------------------------------------------------------- 203 | transform_ir = T.Compose([ 204 | T.Resize(image_size), 205 | T.ToTensor(), 206 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 207 | ]) 208 | 209 | # Dataset ---------------------------------------------------------------------------------------------------------- 210 | if dataset == 'sysu': 211 | gallery_dataset = SYSUDataset(root, mode='gallery', transform_rgb=transform_rgb, transform_ir=transform_ir) 212 | query_dataset = SYSUDataset(root, mode='query', transform_rgb=transform_rgb, transform_ir=transform_ir) 213 | elif dataset == 'regdb': 214 | gallery_dataset = RegDBDataset(root, mode='gallery', transform_rgb=transform_rgb, transform_ir=transform_ir) 215 | query_dataset = RegDBDataset(root, mode='query', transform_rgb=transform_rgb, transform_ir=transform_ir) 216 | else: 217 | raise NotImplementedError(f'Dataset - {dataset} is not supported') 218 | 219 | # DataLoader ------------------------------------------------------------------------------------------------------- 220 | query_loader = DataLoader(dataset=query_dataset, 221 | batch_size=batch_size, 222 | shuffle=False, 223 | pin_memory=True, 224 | drop_last=False, 225 | collate_fn=collate_fn, 226 | num_workers=num_workers) 227 | 228 | gallery_loader = DataLoader(dataset=gallery_dataset, 229 | batch_size=batch_size, 230 | shuffle=False, 231 | pin_memory=True, 232 | drop_last=False, 233 | collate_fn=collate_fn, 234 | num_workers=num_workers) 235 | 236 | return gallery_loader, query_loader 237 | -------------------------------------------------------------------------------- /data/dataset.py: -------------------------------------------------------------------------------- 1 | """MCJA/data/dataset.py 2 | It defines dataset classes for handling image data in the context of cross-modality person re-identification task. 3 | """ 4 | 5 | import os 6 | from glob import glob 7 | 8 | import torch 9 | from PIL import Image 10 | from torch.utils.data import Dataset 11 | 12 | 13 | class SYSUDataset(Dataset): 14 | """ 15 | A dataset class tailored for the SYSU-MM01 dataset, designed to support the loading and preprocessing of data for 16 | training, querying, and gallery modes in cross-modality (visible and infrared) person re-identification tasks. The 17 | class handles the specifics of dataset directory structure, selecting appropriate subsets of images based on the 18 | mode (train, gallery, query) and performing specified transformations on the visible and infrared images separately. 19 | This class ensures that images are correctly matched with their labels, camera IDs, and other relevant information, 20 | facilitating their use in training and evaluation of models for person re-identification. 21 | 22 | The constructor of this class takes the root directory of the SYSU-MM01 dataset, the mode of operation (training, 23 | gallery, or query), and optional transformations for RGB (visible) and IR (infrared) images. 
If `memory_loaded` is 24 | set to True, all images are loaded into memory at initialization for faster access during training or evaluation. 25 | This class is compatible with PyTorch's DataLoader, making it easy to batch and shuffle the dataset as needed. 26 | 27 | Args: 28 | - root (str): The root directory where the SYSU-MM01 dataset is stored. 29 | - mode (str): The mode of dataset usage, which can be 'train', 'gallery', or 'query'. 30 | - transform_rgb (callable, optional): A function that takes in an RGB image and returns a transformed version. 31 | - transform_ir (callable, optional): A function that takes in an IR image and returns a transformed version. 32 | - memory_loaded (bool): If True, all images will be loaded into memory at initialization. 33 | 34 | Attributes: 35 | - img_paths (list): A list of paths to images that belong to the selected mode and IDs. 36 | - cam_ids (list): Camera IDs corresponding to each image in `img_paths`. 37 | - num_ids (int): The number of unique identities present in the selected mode. 38 | - ids (list): A list of identity labels corresponding to each image. 39 | - img_data (list, optional): If `memory_loaded` is True, this list contains preloaded images from `img_paths`. 40 | 41 | Methods: 42 | - __len__(): Returns the total number of images in the img_paths. 43 | - __getitem__(item): Retrieves the image and its metadata at the specified index, 44 | applying the appropriate transformations based on the camera ID (modality labels). 45 | """ 46 | 47 | def __init__(self, root, mode='train', transform_rgb=None, transform_ir=None, memory_loaded=False): 48 | assert os.path.isdir(root) 49 | assert mode in ['train', 'gallery', 'query'] 50 | 51 | if mode == 'train': 52 | train_ids = open(os.path.join(root, 'exp', 'train_id.txt')).readline() 53 | val_ids = open(os.path.join(root, 'exp', 'val_id.txt')).readline() 54 | 55 | train_ids = train_ids.strip('\n').split(',') 56 | val_ids = val_ids.strip('\n').split(',') 57 | selected_ids = train_ids + val_ids 58 | else: 59 | test_ids = open(os.path.join(root, 'exp', 'test_id.txt')).readline() 60 | selected_ids = test_ids.strip('\n').split(',') 61 | 62 | selected_ids = [int(i) for i in selected_ids] 63 | num_ids = len(selected_ids) 64 | 65 | img_paths = glob(os.path.join(root, '**/*.jpg'), recursive=True) 66 | img_paths = [path for path in img_paths if int(path.split('/')[-2]) in selected_ids] 67 | 68 | if mode == 'gallery': 69 | img_paths = [path for path in img_paths if int(path.split('/')[-3][-1]) in (1, 2, 4, 5)] 70 | elif mode == 'query': 71 | img_paths = [path for path in img_paths if int(path.split('/')[-3][-1]) in (3, 6)] 72 | 73 | img_paths = sorted(img_paths) 74 | self.img_paths = img_paths 75 | self.cam_ids = [int(path.split('/')[-3][-1]) for path in img_paths] 76 | self.num_ids = num_ids 77 | self.transform_rgb = transform_rgb 78 | self.transform_ir = transform_ir 79 | 80 | if mode == 'train': 81 | id_map = dict(zip(selected_ids, range(num_ids))) 82 | self.ids = [id_map[int(path.split('/')[-2])] for path in img_paths] 83 | else: 84 | self.ids = [int(path.split('/')[-2]) for path in img_paths] 85 | 86 | self.memory_loaded = memory_loaded 87 | if memory_loaded: 88 | self.img_data = [Image.open(path) for path in self.img_paths] 89 | 90 | def __len__(self): 91 | return len(self.img_paths) 92 | 93 | def __getitem__(self, item): 94 | path = self.img_paths[item] 95 | if self.memory_loaded: 96 | img = self.img_data[item] 97 | else: 98 | img = Image.open(path) 99 | 100 | label = torch.as_tensor(self.ids[item], 
dtype=torch.long) 101 | cam = torch.as_tensor(self.cam_ids[item], dtype=torch.long) 102 | item = torch.as_tensor(item, dtype=torch.long) 103 | 104 | if cam == 3 or cam == 6: 105 | if self.transform_ir is not None: 106 | img = self.transform_ir(img) 107 | else: 108 | if self.transform_rgb is not None: 109 | img = self.transform_rgb(img) 110 | 111 | return img, label, cam, path, item 112 | 113 | 114 | class RegDBDataset(Dataset): 115 | """ 116 | A dataset class specifically designed for the RegDB dataset, facilitating the loading and preprocessing of images 117 | for the cross-modality (visible and infrared) person re-identification task. It supports different operational modes 118 | including training, gallery, and query, applying distinct preprocessing routines to RGB (visible) and IR (infrared) 119 | images as specified. This class handles the unique structure of the RegDB dataset, including its division into 120 | separate visible and thermal image sets, and it prepares the dataset for use in a PyTorch DataLoader, ensuring that 121 | images are appropriately matched with their labels and camera IDs. 122 | 123 | The constructor of this class takes several parameters including dataset's root directory, the mode of operation, 124 | optional transformations for both RGB and IR images, and a flag indicating whether images should be loaded into 125 | memory at initialization. This facilitates faster access during model training and evaluation, especially useful 126 | when working with large datasets or in environments where I/O speed is a bottleneck. 127 | 128 | Args: 129 | - root (str): The root directory where the RegDB dataset is stored. 130 | - mode (str): The mode of dataset usage, which can be 'train', 'gallery', or 'query'. 131 | - transform_rgb (callable, optional): A function/transform that applies to RGB images. 132 | - transform_ir (callable, optional): A function/transform that applies to IR images. 133 | - memory_loaded (bool): If set to True, all images are loaded into memory upfront for faster access. 134 | 135 | Attributes: 136 | - img_paths (list): A list of paths to images that belong to the selected mode and IDs. 137 | - cam_ids (list): Camera IDs derived from the image paths, (with visible cameras marked as 2 and thermal as 3). 138 | - num_ids (int): The number of unique identities present in the selected mode. 139 | - ids (list): A list of identity labels corresponding to each image. 140 | - img_data (list, optional): If `memory_loaded` is True, this list contains preloaded images from `img_paths`. 141 | 142 | Methods: 143 | - __len__(): Returns the total number of images in the img_paths. 144 | - __getitem__(item): Retrieves the image and its metadata at the specified index, 145 | applying the appropriate transformations based on the camera ID (modality labels). 
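
    Example (an illustrative sketch; the dataset root is a hypothetical path):

        import torchvision.transforms as T
        from data.dataset import RegDBDataset

        t = T.Compose([T.Resize((384, 192)), T.ToTensor()])
        query_set = RegDBDataset('/data/RegDB', mode='query', transform_rgb=t, transform_ir=t)
        img, label, cam, path, item = query_set[0]
        # query images come from the thermal split, so cam is 3 and transform_ir is the branch applied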
146 | """ 147 | 148 | def __init__(self, root, mode='train', transform_rgb=None, transform_ir=None, memory_loaded=False): 149 | assert os.path.isdir(root) 150 | assert mode in ['train', 'gallery', 'query'] 151 | 152 | def loadIdx(index): 153 | Lines = index.readlines() 154 | idx = [] 155 | for line in Lines: 156 | tmp = line.strip('\n') 157 | tmp = tmp.split(' ') 158 | idx.append(tmp) 159 | return idx 160 | 161 | num = '1' 162 | if mode == 'train': 163 | index_RGB = loadIdx(open(root + '/idx/train_visible_' + num + '.txt', 'r')) 164 | index_IR = loadIdx(open(root + '/idx/train_thermal_' + num + '.txt', 'r')) 165 | else: 166 | index_RGB = loadIdx(open(root + '/idx/test_visible_' + num + '.txt', 'r')) 167 | index_IR = loadIdx(open(root + '/idx/test_thermal_' + num + '.txt', 'r')) 168 | 169 | if mode == 'gallery': 170 | img_paths = [root + '/' + path for path, _ in index_RGB] 171 | elif mode == 'query': 172 | img_paths = [root + '/' + path for path, _ in index_IR] 173 | else: 174 | img_paths = [root + '/' + path for path, _ in index_RGB] + [root + '/' + path for path, _ in index_IR] 175 | 176 | selected_ids = [int(path.split('/')[-2]) for path in img_paths] 177 | selected_ids = list(set(selected_ids)) 178 | num_ids = len(selected_ids) 179 | 180 | img_paths = sorted(img_paths) 181 | self.img_paths = img_paths 182 | self.cam_ids = [int(path.split('/')[-3] == 'Thermal') + 2 for path in img_paths] 183 | # Note: In SYSU-MM01 dataset, the visible cams are 1 2 4 5, and thermal cams are 3 6. 184 | # To simplify the code, visible cam is 2 and thermal cam is 3 in RegDB dataset. 185 | self.num_ids = num_ids 186 | self.transform_rgb = transform_rgb 187 | self.transform_ir = transform_ir 188 | 189 | if mode == 'train': 190 | id_map = dict(zip(selected_ids, range(num_ids))) 191 | self.ids = [id_map[int(path.split('/')[-2])] for path in img_paths] 192 | else: 193 | self.ids = [int(path.split('/')[-2]) for path in img_paths] 194 | 195 | self.memory_loaded = memory_loaded 196 | if memory_loaded: 197 | self.img_data = [Image.open(path) for path in self.img_paths] 198 | 199 | def __len__(self): 200 | return len(self.img_paths) 201 | 202 | def __getitem__(self, item): 203 | if self.memory_loaded: 204 | img = self.img_data[item] 205 | else: 206 | path = self.img_paths[item] 207 | img = Image.open(path) 208 | 209 | label = torch.tensor(self.ids[item], dtype=torch.long) 210 | cam = torch.tensor(self.cam_ids[item], dtype=torch.long) 211 | item = torch.tensor(item, dtype=torch.long) 212 | 213 | if cam == 3 or cam == 6: 214 | if self.transform_ir is not None: 215 | img = self.transform_ir(img) 216 | else: 217 | if self.transform_rgb is not None: 218 | img = self.transform_rgb(img) 219 | 220 | return img, label, cam, path, item 221 | -------------------------------------------------------------------------------- /data/sampler.py: -------------------------------------------------------------------------------- 1 | """MCJA/data/sampler.py 2 | It defines several Sampler classes, designed to facilitate cross-modality (e.g.,RGB and IR) person re-identification. 3 | """ 4 | 5 | import copy 6 | import numpy as np 7 | 8 | from collections import defaultdict 9 | from torch.utils.data import Sampler 10 | 11 | 12 | class NormTripletSampler(Sampler): 13 | """ 14 | Randomly sample N identities, then for each identity, 15 | randomly sample K instances, therefore batch size is N*K. 16 | It does not distinguish modalities. 17 | 18 | Args: 19 | - dataset (Dataset): Instance of dataset class. 
20 | - num_instances (int): Number of instances per identity in a batch. 21 | - batch_size (int): Number of examples in a batch. 22 | """ 23 | 24 | def __init__(self, dataset, batch_size, num_instances): 25 | self.dataset = dataset 26 | self.batch_size = batch_size 27 | self.num_instances = num_instances 28 | self.num_pids_per_batch = self.batch_size // self.num_instances 29 | self.index_dic = defaultdict(list) 30 | for index, pid in enumerate(self.dataset.ids): 31 | self.index_dic[pid].append(index) 32 | self.pids = list(self.index_dic.keys()) 33 | 34 | # estimate number of examples in an epoch 35 | self.length = 0 36 | for pid in self.pids: 37 | idxs = self.index_dic[pid] 38 | num = len(idxs) 39 | if num < self.num_instances: 40 | num = self.num_instances 41 | self.length += num - num % self.num_instances 42 | 43 | def __iter__(self): 44 | batch_idxs_dict = defaultdict(list) 45 | 46 | for pid in self.pids: 47 | idxs = copy.deepcopy(self.index_dic[pid]) 48 | if len(idxs) < self.num_instances: 49 | idxs = np.random.choice(idxs, size=self.num_instances, replace=True) 50 | np.random.shuffle(idxs) 51 | batch_idxs = [] 52 | for idx in idxs: 53 | batch_idxs.append(idx) 54 | if len(batch_idxs) == self.num_instances: 55 | batch_idxs_dict[pid].append(batch_idxs) 56 | batch_idxs = [] 57 | 58 | avai_pids = copy.deepcopy(self.pids) 59 | final_idxs = [] 60 | 61 | while len(avai_pids) >= self.num_pids_per_batch: 62 | selected_pids = np.random.choice(avai_pids, self.num_pids_per_batch, replace=False) 63 | for pid in selected_pids: 64 | batch_idxs = batch_idxs_dict[pid].pop(0) 65 | final_idxs.extend(batch_idxs) 66 | if len(batch_idxs_dict[pid]) == 0: 67 | avai_pids.remove(pid) 68 | 69 | self.length = len(final_idxs) 70 | return iter(final_idxs) 71 | 72 | def __len__(self): 73 | return self.length 74 | 75 | 76 | class CrossModalityRandomSampler(Sampler): 77 | """ 78 | The first half of a batch are randomly selected RGB images, 79 | and the second half are randomly selected IR images. 80 | 81 | Args: 82 | - dataset (Dataset): Instance of dataset class. 83 | - batch_size (int): Total number of images in a batch. 
84 | """ 85 | 86 | def __init__(self, dataset, batch_size): 87 | self.dataset = dataset 88 | self.batch_size = batch_size 89 | 90 | self.rgb_list = [] 91 | self.ir_list = [] 92 | for i, cam in enumerate(dataset.cam_ids): 93 | if cam in [3, 6]: 94 | self.ir_list.append(i) 95 | else: 96 | self.rgb_list.append(i) 97 | 98 | def __len__(self): 99 | return max(len(self.rgb_list), len(self.ir_list)) * 2 100 | 101 | def __iter__(self): 102 | sample_list = [] 103 | rgb_list = np.random.permutation(self.rgb_list).tolist() 104 | ir_list = np.random.permutation(self.ir_list).tolist() 105 | 106 | rgb_size = len(self.rgb_list) 107 | ir_size = len(self.ir_list) 108 | if rgb_size >= ir_size: 109 | diff = rgb_size - ir_size 110 | reps = diff // ir_size 111 | pad_size = diff % ir_size 112 | for _ in range(reps): 113 | ir_list.extend(np.random.permutation(self.ir_list).tolist()) 114 | ir_list.extend(np.random.choice(self.ir_list, pad_size, replace=False).tolist()) 115 | else: 116 | diff = ir_size - rgb_size 117 | reps = diff // ir_size 118 | pad_size = diff % ir_size 119 | for _ in range(reps): 120 | rgb_list.extend(np.random.permutation(self.rgb_list).tolist()) 121 | rgb_list.extend(np.random.choice(self.rgb_list, pad_size, replace=False).tolist()) 122 | 123 | assert len(rgb_list) == len(ir_list) 124 | 125 | half_bs = self.batch_size // 2 126 | for start in range(0, len(rgb_list), half_bs): 127 | sample_list.extend(rgb_list[start:start + half_bs]) 128 | sample_list.extend(ir_list[start:start + half_bs]) 129 | 130 | return iter(sample_list) 131 | 132 | 133 | class CrossModalityIdentitySampler(Sampler): 134 | """ 135 | The first half of a batch are randomly selected k_size/2 RGB images for each randomly selected p_size people, 136 | and the second half are randomly selected k_size/2 IR images for each the same p_size people. 137 | Batch - [id1_rgb, id1_rgb, ..., id2_rgb, id2_rgb, ..., id1_ir, id1_ir, ..., id2_ir, id2_ir, ...] 138 | 139 | Args: 140 | - dataset (Dataset): Instance of dataset class. 141 | - p_size (int): Number of identities per batch. 142 | - k_size (int): Number of instances per identity. 
143 | """ 144 | 145 | def __init__(self, dataset, p_size, k_size): 146 | self.dataset = dataset 147 | self.p_size = p_size 148 | self.k_size = k_size // 2 149 | self.batch_size = p_size * k_size * 2 150 | 151 | self.id2idx_rgb = defaultdict(list) 152 | self.id2idx_ir = defaultdict(list) 153 | for i, identity in enumerate(dataset.ids): 154 | if dataset.cam_ids[i] in [3, 6]: 155 | self.id2idx_ir[identity].append(i) 156 | else: 157 | self.id2idx_rgb[identity].append(i) 158 | 159 | self.num_base_samples = self.dataset.num_ids * self.k_size * 2 160 | 161 | self.num_repeats = len(dataset.ids) // self.num_base_samples 162 | self.num_samples = self.num_base_samples * self.num_repeats 163 | 164 | # num_ir, num_rgb = 0, 0 165 | # for c_id in dataset.cam_ids: 166 | # if c_id in [3, 6]: 167 | # num_ir += 1 168 | # else: 169 | # num_rgb += 1 170 | # self.num_repeats = (num_ir * 2) // self.num_base_samples 171 | # self.num_samples = self.num_base_samples * self.num_repeats 172 | 173 | def __len__(self): 174 | return self.num_samples 175 | 176 | def __iter__(self): 177 | sample_list = [] 178 | 179 | for r in range(self.num_repeats): 180 | id_perm = np.random.permutation(self.dataset.num_ids) 181 | for start in range(0, self.dataset.num_ids, self.p_size): 182 | selected_ids = id_perm[start:start + self.p_size] 183 | 184 | sample = [] 185 | for identity in selected_ids: 186 | replace = len(self.id2idx_rgb[identity]) < self.k_size 187 | s = np.random.choice(self.id2idx_rgb[identity], size=self.k_size, replace=replace) 188 | sample.extend(s) 189 | 190 | sample_list.extend(sample) 191 | 192 | sample.clear() 193 | for identity in selected_ids: 194 | replace = len(self.id2idx_ir[identity]) < self.k_size 195 | s = np.random.choice(self.id2idx_ir[identity], size=self.k_size, replace=replace) 196 | sample.extend(s) 197 | 198 | sample_list.extend(sample) 199 | 200 | return iter(sample_list) 201 | 202 | 203 | class IdentityCrossModalitySampler(Sampler): 204 | """ 205 | It is equivalent to CrossModalityIdentitySampler, but the arrangement is different. 206 | Batch - [id1_ir, id1_rgb, id1_ir, id1_rgb, ..., id2_ir, id2_rgb, id2_ir, id2_rgb, ...] 207 | 208 | Args: 209 | - dataset (Dataset): Instance of dataset class. 210 | - batch_size (int): Number of examples in a batch. 211 | - num_instances (int): Number of instances per identity in a batch. 
212 | """ 213 | 214 | def __init__(self, dataset, batch_size, num_instances): 215 | self.dataset = dataset 216 | self.batch_size = batch_size 217 | self.num_instances = num_instances 218 | self.num_pids_per_batch = self.batch_size // self.num_instances 219 | self.index_dic_R = defaultdict(list) 220 | self.index_dic_I = defaultdict(list) 221 | for i, identity in enumerate(dataset.ids): 222 | if dataset.cam_ids[i] in [3, 6]: 223 | self.index_dic_I[identity].append(i) 224 | else: 225 | self.index_dic_R[identity].append(i) 226 | self.pids = list(self.index_dic_I.keys()) 227 | 228 | # estimate number of examples in an epoch 229 | self.length = 0 230 | for pid in self.pids: 231 | idxs = self.index_dic_I[pid] 232 | num = len(idxs) 233 | if num < self.num_instances: 234 | num = self.num_instances 235 | self.length += num - num % self.num_instances 236 | 237 | def __len__(self): 238 | return self.length 239 | 240 | def __iter__(self): 241 | batch_idxs_dict = defaultdict(list) 242 | 243 | for pid in self.pids: 244 | idxs_I = copy.deepcopy(self.index_dic_I[pid]) 245 | idxs_R = copy.deepcopy(self.index_dic_R[pid]) 246 | if len(idxs_I) < self.num_instances // 2 and len(idxs_R) < self.num_instances // 2: 247 | idxs_I = np.random.choice(idxs_I, size=self.num_instances // 2, replace=True) 248 | idxs_R = np.random.choice(idxs_R, size=self.num_instances // 2, replace=True) 249 | if len(idxs_I) > len(idxs_R): 250 | idxs_I = np.random.choice(idxs_I, size=len(idxs_R), replace=False) 251 | if len(idxs_R) > len(idxs_I): 252 | idxs_R = np.random.choice(idxs_R, size=len(idxs_I), replace=False) 253 | np.random.shuffle(idxs_I) 254 | np.random.shuffle(idxs_R) 255 | batch_idxs = [] 256 | for idx_I, idx_R in zip(idxs_I, idxs_R): 257 | batch_idxs.append(idx_I) 258 | batch_idxs.append(idx_R) 259 | if len(batch_idxs) == self.num_instances: 260 | batch_idxs_dict[pid].append(batch_idxs) 261 | batch_idxs = [] 262 | 263 | avai_pids = copy.deepcopy(self.pids) 264 | final_idxs = [] 265 | 266 | while len(avai_pids) >= self.num_pids_per_batch: 267 | selected_pids = np.random.choice(avai_pids, self.num_pids_per_batch, replace=False) 268 | for pid in selected_pids: 269 | batch_idxs = batch_idxs_dict[pid].pop(0) 270 | final_idxs.extend(batch_idxs) 271 | if len(batch_idxs_dict[pid]) == 0: 272 | avai_pids.remove(pid) 273 | 274 | self.length = len(final_idxs) 275 | return iter(final_idxs) 276 | -------------------------------------------------------------------------------- /data/transform.py: -------------------------------------------------------------------------------- 1 | """MCJA/data/transform.py 2 | It contains a collection of custom image transformation classes designed for augmenting and preprocessing images. 3 | """ 4 | 5 | import math 6 | import numbers 7 | import random 8 | import numpy as np 9 | from PIL import Image 10 | 11 | import torch 12 | import torch.nn as nn 13 | 14 | 15 | class WeightedGrayscale(nn.Module): 16 | """ 17 | A module that applies a weighted grayscale transformation to an image, which converts an RGB image to grayscale 18 | by applying custom weights to each channel before averaging them. This approach allows for more flexibility than 19 | standard grayscale, potentially emphasizing certain features more than others depending on the chosen weights. 20 | 21 | Args: 22 | - weights (tuple of floats, optional): The weights to apply to the R, G, and B channels, respectively. 23 | If not specified, weights are randomly generated for each image. 
24 | - p (float): The probability with which the weighted grayscale transformation is applied. 25 | A value of 1.0 means the transformation is always applied, whereas a value of 0 means it is never applied. 26 | 27 | Methods: 28 | - forward(img): Applies the weighted grayscale transformation to the given img with probability p. 29 | If the transformation is not applied, the original image is returned unchanged. 30 | """ 31 | 32 | def __init__(self, weights=None, p=1.0): 33 | super().__init__() 34 | self.weights = weights 35 | self.p = p 36 | 37 | def forward(self, img): 38 | if self.p < torch.rand(1): 39 | return img 40 | 41 | if self.weights is not None: 42 | w1, w2, w3 = self.weights 43 | else: 44 | w1 = random.uniform(0, 1) 45 | w2 = random.uniform(0, 1) 46 | w3 = random.uniform(0, 1) 47 | s = w1 + w2 + w3 48 | w1, w2, w3 = w1 / s, w2 / s, w3 / s 49 | img_data = np.asarray(img) 50 | img_data = w1 * img_data[:, :, 0] + w2 * img_data[:, :, 1] + w3 * img_data[:, :, 2] 51 | img_data = np.expand_dims(img_data, axis=-1).repeat(3, axis=-1) 52 | 53 | return Image.fromarray(np.uint8(img_data)) 54 | 55 | 56 | class ChannelCutMix(nn.Module): 57 | """ 58 | A module that implements the ChannelCutMix augmentation, a variant of the CutMix augmentation strategy that operates 59 | at the channel level. Unlike traditional CutMix, which combines patches from different images, ChannelCutMix 60 | selectively replaces a region in one channel of an image with the corresponding region from another channel of the 61 | same image. This process introduces diversity in the training data by blending features from different channels, 62 | potentially enhancing the robustness of models to variations in input data. 63 | 64 | Args: 65 | - p (float): The probability with which the ChannelCutMix augmentation is applied. 66 | A value of 1.0 means the augmentation is always applied, whereas a value of 0 means it is never applied. 67 | - scale (tuple of floats): The range of scales relative to the original area of the 68 | image that determines the size of the region to be replaced. 69 | - ratio (tuple of floats): The range of aspect ratios of the region to be replaced. 70 | This controls the shape of the region, allowing for both narrow and wide regions to be selected. 71 | 72 | Methods: 73 | - forward(img): Applies the ChannelCutMix augmentation to the given img with probability p. 74 | If the augmentation is not applied, the original image is returned unchanged. 
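
    Example (an illustrative sketch; 'person.jpg' is a hypothetical image file):

        from PIL import Image
        from data.transform import ChannelCutMix

        aug = ChannelCutMix(p=1.0)                     # p=1.0 forces the augmentation, for demonstration
        img = Image.open('person.jpg').convert('RGB')
        out = aug(img)                                 # one channel with a pasted region, replicated back to three channels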
75 | """ 76 | 77 | def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3)): 78 | super().__init__() 79 | self.p = p 80 | self.scale = scale 81 | self.ratio = ratio 82 | 83 | def forward(self, img): 84 | if self.p < torch.rand(1): 85 | return img 86 | 87 | img_h, img_w = img.size # PIL Image Type 88 | area = img_h * img_w 89 | log_ratio = torch.log(torch.tensor(self.ratio)) 90 | i, j, h, w = None, None, None, None 91 | for _ in range(10): 92 | cutmix_area = area * torch.empty(1).uniform_(self.scale[0], self.scale[1]).item() 93 | aspect_ratio = torch.exp(torch.empty(1).uniform_(log_ratio[0], log_ratio[1])).item() 94 | h = int(round(math.sqrt(cutmix_area * aspect_ratio))) 95 | w = int(round(math.sqrt(cutmix_area / aspect_ratio))) 96 | if not (h < img_h and w < img_w): 97 | continue 98 | i = torch.randint(0, img_h - h + 1, size=(1,)).item() 99 | j = torch.randint(0, img_w - w + 1, size=(1,)).item() 100 | break 101 | 102 | img_data = np.asarray(img) 103 | bg_c, fg_c = random.sample(range(3), k=2) 104 | bg_img_data = img_data[:, :, bg_c] 105 | fg_img_data = img_data[:, :, fg_c] 106 | bg_img_data[i:i + h, j:j + w] = fg_img_data[i:i + h, j:j + w] 107 | img_data = np.expand_dims(bg_img_data, axis=-1).repeat(3, axis=-1) 108 | 109 | return Image.fromarray(np.uint8(img_data)) 110 | 111 | 112 | class SpectrumJitter(nn.Module): 113 | """ 114 | A module for applying Spectrum Jitter augmentation to an image, which selectively alters the intensity of a randomly 115 | chosen color channel and blends it with the original image. This augmentation can introduce variations in color 116 | intensity and distribution across different channels, mimicking conditions of varying spectrum that a model might 117 | encounter in real-world scenarios. The purpose is to improve the model's robustness to changes in spectrum by 118 | exposing it to a wider range of color spectrum during training. 119 | 120 | Args: 121 | - factor (float or tuple of float): Specifies the range of factors to use for blending the altered channel back into 122 | the original image. If a single float is provided, it's interpreted as the maximum deviation from the default 123 | intensity of 1.0, creating a range [1-factor, 1+factor]. If a tuple is provided, it directly specifies the range. 124 | The factor influences how strongly the selected channel's intensity is altered. 125 | - p (float): The probability with which the Spectrum Jitter augmentation is applied to any given image. 126 | A value of 1.0 means the augmentation is always applied, while a value of 0 means it is never applied. 127 | 128 | Methods: 129 | - forward(img): Applies the Spectrum Jitter augmentation to the given img with probability p. 130 | If the augmentation is not applied, the original image is returned unchanged. 
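
    Example (an illustrative sketch; 'person.jpg' is a hypothetical image file):

        from PIL import Image
        from data.transform import SpectrumJitter

        aug = SpectrumJitter(factor=(0.00, 1.00), p=1.0)     # same factor range as the VIMC training setting
        out = aug(Image.open('person.jpg').convert('RGB'))
        # out blends a randomly selected single-channel view back into the original RGB image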
131 | """ 132 | 133 | def __init__(self, factor=0.5, p=0.5): 134 | super().__init__() 135 | self.factor = self._check_input(factor, 'spectrum') 136 | self.p = p 137 | 138 | @torch.jit.unused # Inspired by the implementation of color jitter in standard libraries 139 | def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True): 140 | if isinstance(value, numbers.Number): 141 | if value < 0: 142 | raise ValueError("If {} is a single number, it must be non negative.".format(name)) 143 | value = [center - float(value), center + float(value)] 144 | if clip_first_on_zero: 145 | value[0] = max(value[0], 0.0) 146 | elif isinstance(value, (tuple, list)) and len(value) == 2: 147 | if not bound[0] <= value[0] <= value[1] <= bound[1]: 148 | raise ValueError("{} values should be between {}".format(name, bound)) 149 | else: 150 | raise TypeError("{} should be a single number or a list/tuple with lenght 2.".format(name)) 151 | 152 | if value[0] == value[1] == center: 153 | value = None 154 | return value 155 | 156 | def forward(self, img): 157 | if self.p < torch.rand(1): 158 | return img 159 | 160 | selected_c = random.randint(0, 2) 161 | 162 | img_data = np.asarray(img) 163 | img_data = img_data[:, :, selected_c] 164 | img_data = np.expand_dims(img_data, axis=-1).repeat(3, axis=-1) 165 | degenerate = Image.fromarray(np.uint8(img_data)) 166 | 167 | factor = float(torch.empty(1).uniform_(self.factor[0], self.factor[1])) 168 | return Image.blend(degenerate, img, factor) 169 | 170 | 171 | class ChannelAugmentation(nn.Module): 172 | """ 173 | A module that encapsulates various channel-based augmentation strategies, allowing for flexible and probabilistic 174 | application of different augmentations to an image. This module supports a range of augmentation modes, including 175 | selection of individual RGB channels, averaging of channels, mixing channels, and more advanced techniques such 176 | as weighted grayscale conversion, channel cutmix, and spectrum jitter. The specific augmentation to apply can be 177 | selected via the `mode` parameter, providing a versatile tool for enhancing the diversity of original data. 178 | 179 | Args: 180 | - mode (str, optional): Specifies the augmentation technique to apply. Options include: 181 | - None: No augmentation is applied. 182 | - 'r', 'g', 'b': Selects a specific RGB channel. 183 | - 'avg': Averages the RGB channels. 184 | - 'rg_avg', 'rb_avg', 'gb_avg': Averages specified pairs of RGB channels. 185 | - 'rand': Randomly selects a channel at each call. 186 | - 'wg': Applies weighted grayscale augmentation. 187 | - 'cc': Applies channel cutmix augmentation. 188 | - 'sj': Applies spectrum jitter augmentation. 189 | Each mode introduces different types of variations, from simple channel selection to more complex transformations. 190 | - p (float): The probability with which the selected augmentation is applied. 191 | A value of 1.0 means the augmentation is always applied, while a value of 0 means it is never applied. 192 | 193 | Methods: 194 | - forward(img): Applies the configured augmentation to the given img with probability p. 195 | If the augmentation is not applied (either because p < 1 or mode is None), the original image is returned unchanged. 
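
    Example (an illustrative sketch of switching between test-time augmentation branches):

        from data.transform import ChannelAugmentation

        ca_wg = ChannelAugmentation(mode='wg', p=1.0)   # weighted grayscale branch
        ca_off = ChannelAugmentation(mode=None)         # falls back to NoTransform, so images pass through unchanged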
196 | """ 197 | 198 | def __init__(self, mode=None, p=1.0): 199 | super().__init__() 200 | assert mode in [None, 'r', 'g', 'b', 'avg', 'rg_avg', 'rb_avg', 'gb_avg', 'rand', 'wg', 'cc', 'sj'] 201 | self.mode = mode 202 | self.p = p 203 | if mode in ['r', 'g', 'b', 'avg', 'rg_avg', 'rb_avg', 'gb_avg', 'rand']: 204 | self.ca = ChannelSelection(mode=mode, p=p) 205 | elif mode == 'wg': 206 | self.ca = WeightedGrayscale(p=p) 207 | elif mode == 'cc': 208 | self.ca = ChannelCutMix(p=p) 209 | elif mode == 'sj': 210 | self.ca = SpectrumJitter(factor=(0.00, 1.00), p=p) 211 | else: 212 | self.ca = NoTransform() 213 | 214 | def forward(self, img): 215 | return self.ca(img) 216 | 217 | 218 | class ChannelSelection(nn.Module): 219 | """ 220 | A module that selectively manipulates the color channels of an image according to a specified mode. 221 | This augmentation technique can emphasize or de-emphasize certain features in the image based on color, 222 | which might be beneficial for tasks sensitive to specific color channels. The module supports a variety 223 | of modes that target different channels or combinations thereof, and it applies these transformations 224 | with a specified probability, allowing for stochastic data augmentation. 225 | 226 | Args: 227 | - mode (str, optional): Specifies the channel manipulation mode. 228 | It can be one of 'r', 'g', 'b', 'avg', 'rg_avg', 'rb_avg', 'gb_avg', or 'rand'. The default is 'rand', 229 | which randomly selects one of the RGB channels each time the augmentation is applied. 230 | - p (float): The probability with which the channel selection or modification is applied. 231 | A value of 1.0 means the transformation is always applied, whereas a value of 0 means it is never applied. 232 | """ 233 | 234 | def __init__(self, mode='rand', p=1.0): 235 | super().__init__() 236 | assert mode in ['r', 'g', 'b', 'avg', 'rg_avg', 'rb_avg', 'gb_avg', 'rand'] 237 | self.mode = mode 238 | self.p = p 239 | 240 | def forward(self, img): 241 | if self.p < torch.rand(1): 242 | return img 243 | 244 | img_data = np.asarray(img) 245 | if 'avg' in self.mode: 246 | if self.mode == 'avg': 247 | pass 248 | elif self.mode == 'rg_avg': 249 | img_data = np.stack([img_data[:, :, 0], img_data[:, :, 1]]) 250 | elif self.mode == 'rb_avg': 251 | img_data = np.stack([img_data[:, :, 0], img_data[:, :, 2]]) 252 | elif self.mode == 'gb_avg': 253 | img_data = np.stack([img_data[:, :, 1], img_data[:, :, 2]]) 254 | img_data = img_data.mean(axis=-1) 255 | else: 256 | if self.mode == 'r': 257 | selected_c = 0 258 | elif self.mode == 'g': 259 | selected_c = 1 260 | elif self.mode == 'b': 261 | selected_c = 2 262 | elif self.mode == 'rand': 263 | selected_c = random.randint(0, 2) 264 | img_data = img_data[:, :, selected_c] 265 | img_data = np.expand_dims(img_data, axis=-1).repeat(3, axis=-1) 266 | return Image.fromarray(np.uint8(img_data)) 267 | 268 | 269 | class NoTransform(nn.Module): 270 | """ 271 | A module that acts as a placeholder for a transformation step, performing no operation on the input image. 272 | It is designed to seamlessly integrate into data processing pipelines or augmentation sequences where conditional 273 | application of transformations is required but, in some cases, no actual transformation should be applied. 274 | 275 | Methods: 276 | - forward(img): Returns the input image unchanged, serving as a pass-through operation in a transformation pipeline. 
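
    Example (an illustrative sketch of keeping a disabled augmentation branch in a pipeline):

        import torchvision.transforms as T
        from data.transform import NoTransform, SpectrumJitter

        branch = T.RandomChoice([SpectrumJitter(p=0.5), NoTransform()])
        # mirrors how get_train_loader substitutes NoTransform when a VIMC option is turned off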
277 | """ 278 | 279 | def __init__(self): 280 | super().__init__() 281 | 282 | def forward(self, img): 283 | return img 284 | -------------------------------------------------------------------------------- /engine/__init__.py: -------------------------------------------------------------------------------- 1 | """MCJA/engine/__init__.py 2 | It initializes the training and evaluation engines for the Multi-level Cross-modality Joint Alignment (MCJA) method. 3 | """ 4 | 5 | import os 6 | 7 | import numpy as np 8 | import scipy.io as sio 9 | import torch 10 | 11 | from glob import glob 12 | from ignite.engine import Events 13 | from ignite.handlers import ModelCheckpoint 14 | from ignite.handlers import Timer 15 | 16 | from engine.engine import create_eval_engine 17 | from engine.engine import create_train_engine 18 | from engine.metric import AutoKVMetric 19 | from utils.eval_data import eval_sysu, eval_regdb 20 | from configs.default.dataset import dataset_cfg 21 | 22 | 23 | def get_trainer(dataset, model, optimizer, lr_scheduler=None, logger=None, writer=None, non_blocking=False, 24 | log_period=10, save_dir="checkpoints", prefix="model", eval_interval=None, start_eval=None, 25 | gallery_loader=None, query_loader=None): 26 | """ 27 | A factory function that assembles and returns a training engine configured for VI-ReID tasks. This function sets up 28 | a trainer with custom event handlers for various stages of the training process, including model checkpointing, 29 | evaluation, logging, and learning rate scheduling. It integrates functionalities for performance evaluation using 30 | specified metrics and supports conditional execution of evaluations and logging activities based on the training. 31 | 32 | Args: 33 | - dataset (str): The name of the dataset being used, which dictates certain evaluation protocols. 34 | - model (nn.Module): The neural network model to be trained. 35 | - optimizer (Optimizer): The optimizer used for training the model. 36 | - lr_scheduler (Optional[Scheduler]): A learning rate scheduler for adjusting the learning rate across epochs. 37 | - logger (Logger): A logger for recording training progress and evaluation results. 38 | - writer (Optional[SummaryWriter]): A TensorBoard writer for logging metrics and visualizations. 39 | - non_blocking (bool): If set to True, attempts to asynchronously transfer data to device to improve performance. 40 | - log_period (int): The frequency (in iterations) with which training metrics are logged. 41 | - save_dir (str): The directory where model checkpoints are saved. 42 | - prefix (str): The prefix used for naming saved model files. 43 | - eval_interval (Optional[int]): The frequency (in epochs) with which the model is evaluated. 44 | - start_eval (Optional[int]): The epoch from which to start performing evaluations. 45 | - gallery_loader (Optional[DataLoader]): The DataLoader for the gallery set used in evaluations. 46 | - query_loader (Optional[DataLoader]): The DataLoader for the query set used in evaluations. 47 | 48 | Returns: 49 | - Engine: An Ignite Engine object configured for training, 50 | equipped with handlers for checkpointing, evaluation, and logging. 
51 | """ 52 | 53 | # Trainer 54 | trainer = create_train_engine(model, optimizer, non_blocking) 55 | 56 | # Checkpoint Handler 57 | handler = ModelCheckpoint(save_dir, prefix, save_interval=eval_interval, n_saved=3, create_dir=True, 58 | save_as_state_dict=True, require_empty=False) 59 | trainer.add_event_handler(Events.EPOCH_COMPLETED, handler, {"model": model}) 60 | 61 | timer = Timer(average=True) 62 | kv_metric = AutoKVMetric() 63 | 64 | # Evaluator 65 | evaluator = None 66 | if not type(eval_interval) == int: 67 | raise TypeError("The parameter 'validate_interval' must be type INT.") 68 | if not type(start_eval) == int: 69 | raise TypeError("The parameter 'start_eval' must be type INT.") 70 | if eval_interval > 0 and gallery_loader is not None and query_loader is not None: 71 | evaluator = create_eval_engine(model, non_blocking) 72 | 73 | def run_init_eval(engine): 74 | logger.info('\n## Checking model performance with initial parameters...') 75 | 76 | # Extract Query Feature 77 | evaluator.run(query_loader) 78 | q_feats = torch.cat(evaluator.state.feat_list, dim=0) 79 | q_ids = torch.cat(evaluator.state.id_list, dim=0).numpy() 80 | q_cams = torch.cat(evaluator.state.cam_list, dim=0).numpy() 81 | q_img_paths = np.concatenate(evaluator.state.img_path_list, axis=0) 82 | 83 | # Extract Gallery Feature 84 | evaluator.run(gallery_loader) 85 | g_feats = torch.cat(evaluator.state.feat_list, dim=0) 86 | g_ids = torch.cat(evaluator.state.id_list, dim=0).numpy() 87 | g_cams = torch.cat(evaluator.state.cam_list, dim=0).numpy() 88 | g_img_paths = np.concatenate(evaluator.state.img_path_list, axis=0) 89 | 90 | if dataset == 'sysu': 91 | perm = sio.loadmat(os.path.join(dataset_cfg.sysu.data_root, 'exp', 'rand_perm_cam.mat'))['rand_perm_cam'] 92 | eval_sysu(q_feats, q_ids, q_cams, q_img_paths, 93 | g_feats, g_ids, g_cams, g_img_paths, 94 | perm, mode='all', num_shots=1) 95 | eval_sysu(q_feats, q_ids, q_cams, q_img_paths, 96 | g_feats, g_ids, g_cams, g_img_paths, 97 | perm, mode='all', num_shots=10) 98 | eval_sysu(q_feats, q_ids, q_cams, q_img_paths, 99 | g_feats, g_ids, g_cams, g_img_paths, 100 | perm, mode='indoor', num_shots=1) 101 | eval_sysu(q_feats, q_ids, q_cams, q_img_paths, 102 | g_feats, g_ids, g_cams, g_img_paths, 103 | perm, mode='indoor', num_shots=10) 104 | elif dataset == 'regdb': 105 | logger.info('Test Mode - infrared to visible') 106 | eval_regdb(q_feats, q_ids, q_cams, q_img_paths, 107 | g_feats, g_ids, g_cams, g_img_paths, mode='i2v') 108 | logger.info('Test Mode - visible to infrared') 109 | eval_regdb(g_feats, g_ids, g_cams, g_img_paths, 110 | q_feats, q_ids, q_cams, q_img_paths, mode='v2i') 111 | else: 112 | raise NotImplementedError(f'Dataset - {dataset} is not supported') 113 | 114 | evaluator.state.feat_list.clear() 115 | evaluator.state.id_list.clear() 116 | evaluator.state.cam_list.clear() 117 | evaluator.state.img_path_list.clear() 118 | del q_feats, q_ids, q_cams, g_feats, g_ids, g_cams 119 | 120 | logger.info('\n## Starting the training process...') 121 | 122 | @trainer.on(Events.STARTED) 123 | def train_start(engine): 124 | setattr(engine.state, "best_rank1", 0.0) 125 | run_init_eval(engine) 126 | 127 | @trainer.on(Events.COMPLETED) 128 | def train_completed(engine): 129 | pass 130 | 131 | @trainer.on(Events.EPOCH_STARTED) 132 | def epoch_started_callback(engine): 133 | kv_metric.reset() 134 | timer.reset() 135 | 136 | @trainer.on(Events.EPOCH_COMPLETED) 137 | def epoch_completed_callback(engine): 138 | epoch = engine.state.epoch 139 | 140 | if lr_scheduler is not 
None: 141 | lr_scheduler.step() 142 | 143 | if epoch % eval_interval == 0: 144 | logger.info("Model saved at {}/{}_model_{}.pth".format(save_dir, prefix, epoch)) 145 | 146 | if evaluator and epoch % eval_interval == 0 and epoch >= start_eval: 147 | # Extract Query Feature 148 | evaluator.run(query_loader) 149 | q_feats = torch.cat(evaluator.state.feat_list, dim=0) 150 | q_ids = torch.cat(evaluator.state.id_list, dim=0).numpy() 151 | q_cams = torch.cat(evaluator.state.cam_list, dim=0).numpy() 152 | q_img_paths = np.concatenate(evaluator.state.img_path_list, axis=0) 153 | 154 | # Extract Gallery Feature 155 | evaluator.run(gallery_loader) 156 | g_feats = torch.cat(evaluator.state.feat_list, dim=0) 157 | g_ids = torch.cat(evaluator.state.id_list, dim=0).numpy() 158 | g_cams = torch.cat(evaluator.state.cam_list, dim=0).numpy() 159 | g_img_paths = np.concatenate(evaluator.state.img_path_list, axis=0) 160 | 161 | if dataset == 'sysu': 162 | perm = sio.loadmat(os.path.join(dataset_cfg.sysu.data_root, 'exp', 'rand_perm_cam.mat'))[ 163 | 'rand_perm_cam'] 164 | r1, r5, r10, r20, mAP, mINP, _ = eval_sysu(q_feats, q_ids, q_cams, q_img_paths, 165 | g_feats, g_ids, g_cams, g_img_paths, 166 | perm, mode='all', num_shots=1) 167 | elif dataset == 'regdb': 168 | logger.info('Test Mode - infrared to visible') 169 | r1, r5, r10, r20, mAP, mINP, _ = eval_regdb(q_feats, q_ids, q_cams, q_img_paths, 170 | g_feats, g_ids, g_cams, g_img_paths, mode='i2v') 171 | logger.info('Test Mode - visible to infrared') 172 | r1_, r5_, r10_, r20_, mAP_, mINP_, _ = eval_regdb(g_feats, g_ids, g_cams, g_img_paths, 173 | q_feats, q_ids, q_cams, q_img_paths, mode='v2i') 174 | r1 = (r1 + r1_) / 2 175 | r5 = (r5 + r5_) / 2 176 | r10 = (r10 + r10_) / 2 177 | r20 = (r20 + r20_) / 2 178 | mAP = (mAP + mAP_) / 2 179 | mINP = (mINP + mINP_) / 2 180 | else: 181 | raise NotImplementedError(f'Dataset - {dataset} is not supported') 182 | 183 | if r1 > engine.state.best_rank1: 184 | for rm_best_model_path in glob("{}/{}_model_best-*.pth".format(save_dir, prefix)): 185 | os.remove(rm_best_model_path) 186 | engine.state.best_rank1 = r1 187 | torch.save(model.state_dict(), "{}/{}_model_best-{}.pth".format(save_dir, prefix, epoch)) 188 | 189 | if writer is not None: 190 | writer.add_scalar('eval/r1', r1, epoch) 191 | writer.add_scalar('eval/r5', r5, epoch) 192 | writer.add_scalar('eval/r10', r10, epoch) 193 | writer.add_scalar('eval/r20', r20, epoch) 194 | writer.add_scalar('eval/mAP', mAP, epoch) 195 | writer.add_scalar('eval/mINP', mINP, epoch) 196 | 197 | evaluator.state.feat_list.clear() 198 | evaluator.state.id_list.clear() 199 | evaluator.state.cam_list.clear() 200 | evaluator.state.img_path_list.clear() 201 | del q_feats, q_ids, q_cams, g_feats, g_ids, g_cams 202 | 203 | @trainer.on(Events.ITERATION_COMPLETED) 204 | def iteration_complete_callback(engine): 205 | timer.step() 206 | kv_metric.update(engine.state.output) 207 | 208 | epoch = engine.state.epoch 209 | iteration = engine.state.iteration 210 | iter_in_epoch = iteration - (epoch - 1) * len(engine.state.dataloader) 211 | 212 | if iter_in_epoch % log_period == 0 and iter_in_epoch > 0: 213 | batch_size = engine.state.batch[0].size(0) 214 | speed = batch_size / timer.value() 215 | msg = "Epoch[%d] Batch [%d] Speed: %.2f samples/sec" % (epoch, iter_in_epoch, speed) 216 | metric_dict = kv_metric.compute() 217 | if logger is not None: 218 | for k in sorted(metric_dict.keys()): 219 | msg += " %s: %.4f" % (k, metric_dict[k]) 220 | if writer is not None: 221 | 
writer.add_scalar('metric/{}'.format(k), metric_dict[k], iteration) 222 | logger.info(msg) 223 | kv_metric.reset() 224 | timer.reset() 225 | 226 | return trainer 227 | -------------------------------------------------------------------------------- /engine/engine.py: -------------------------------------------------------------------------------- 1 | """MCJA/engine/engine.py 2 | It defines the creation of training and evaluation engines using the Ignite library. 3 | """ 4 | 5 | import numpy as np 6 | import torch 7 | 8 | from torch.autograd import no_grad 9 | from ignite.engine import Engine 10 | from ignite.engine import Events 11 | from apex import amp 12 | 13 | 14 | def create_train_engine(model, optimizer, non_blocking=False): 15 | """ 16 | A factory function that creates and returns an Ignite Engine configured for training a VI-ReID model. This engine 17 | orchestrates the training process, managing the data flow, loss calculation, parameter updates, and any additional 18 | computations needed per iteration. The function encapsulates the core training loop, including data loading to the 19 | device, executing model's forward pass, computing the loss, performing backpropagation, and updating model weights. 20 | 21 | Args: 22 | - model (nn.Module): The model to be trained. The model should accept input data, labels, camera IDs, and 23 | potentially other information like image paths and epoch number, returning computed loss and additional metrics. 24 | - optimizer (Optimizer): The optimizer used for updating the model parameters based on the computed gradients. 25 | - non_blocking (bool): If set to True, allows asynchronous data transfers to the GPU for improved efficiency. 26 | 27 | Returns: 28 | - Engine: An Ignite Engine object that processes batches of data using the provided model and optimizer. 29 | """ 30 | 31 | device = torch.device("cuda", torch.cuda.current_device()) 32 | 33 | def _process_func(engine, batch): 34 | model.train() 35 | 36 | data, labels, cam_ids, img_paths, img_idxes = batch 37 | epoch = engine.state.epoch 38 | data = data.to(device, non_blocking=non_blocking) 39 | labels = labels.to(device, non_blocking=non_blocking) 40 | cam_ids = cam_ids.to(device, non_blocking=non_blocking) 41 | 42 | optimizer.zero_grad(set_to_none=True) 43 | 44 | loss, metric = model(data, labels, 45 | cam_ids=cam_ids, 46 | img_paths=img_paths, 47 | epoch=epoch) 48 | 49 | with amp.scale_loss(loss, optimizer) as scaled_loss: 50 | scaled_loss.backward() 51 | 52 | optimizer.step() 53 | 54 | return metric 55 | 56 | return Engine(_process_func) 57 | 58 | 59 | def create_eval_engine(model, non_blocking=False): 60 | """ 61 | A factory function that creates and returns an Ignite Engine configured for evaluating a VI-ReID model. This engine 62 | manages evaluation process, facilitating the flow of data through the model and the collection of output features 63 | for later analysis. It operates in evaluation mode, ensuring that the model's behavior is consistent with inference 64 | conditions, such as disabled dropout layers. 65 | 66 | Args: 67 | - model (nn.Module): The model to be evaluated. The model should accept input data and potentially other 68 | information like camera IDs, returning feature representations. 69 | - non_blocking (bool): If set to True, allows asynchronous data transfers to the GPU to improve efficiency. 70 | 71 | Returns: 72 | - Engine: An Ignite Engine object that processes batches of data through the provided model in evaluation mode. 
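    Example (a minimal sketch; `model` and `query_loader` are assumed to exist in the calling scope):

        evaluator = create_eval_engine(model, non_blocking=True)
        evaluator.run(query_loader)
        q_feats = torch.cat(evaluator.state.feat_list, dim=0)        # (N, D) query features
        q_ids = torch.cat(evaluator.state.id_list, dim=0).numpy()    # identity labels
        q_cams = torch.cat(evaluator.state.cam_list, dim=0).numpy()  # camera ids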
73 | """ 74 | 75 | device = torch.device("cuda", torch.cuda.current_device()) 76 | 77 | def _process_func(engine, batch): 78 | model.eval() 79 | data, labels, cam_ids, img_paths = batch[:4] 80 | data = data.to(device, non_blocking=non_blocking) 81 | 82 | with no_grad(): 83 | feat = model(data, cam_ids=cam_ids.to(device, non_blocking=non_blocking)) 84 | 85 | return feat.data.float().cpu(), labels, cam_ids, np.array(img_paths) 86 | 87 | engine = Engine(_process_func) 88 | 89 | @engine.on(Events.EPOCH_STARTED) 90 | def clear_data(engine): 91 | if not hasattr(engine.state, "feat_list"): 92 | setattr(engine.state, "feat_list", []) 93 | else: 94 | engine.state.feat_list.clear() 95 | 96 | if not hasattr(engine.state, "id_list"): 97 | setattr(engine.state, "id_list", []) 98 | else: 99 | engine.state.id_list.clear() 100 | 101 | if not hasattr(engine.state, "cam_list"): 102 | setattr(engine.state, "cam_list", []) 103 | else: 104 | engine.state.cam_list.clear() 105 | 106 | if not hasattr(engine.state, "img_path_list"): 107 | setattr(engine.state, "img_path_list", []) 108 | else: 109 | engine.state.img_path_list.clear() 110 | 111 | @engine.on(Events.ITERATION_COMPLETED) 112 | def store_data(engine): 113 | engine.state.feat_list.append(engine.state.output[0]) 114 | engine.state.id_list.append(engine.state.output[1]) 115 | engine.state.cam_list.append(engine.state.output[2]) 116 | engine.state.img_path_list.append(engine.state.output[3]) 117 | 118 | return engine 119 | -------------------------------------------------------------------------------- /engine/metric.py: -------------------------------------------------------------------------------- 1 | """MCJA/engine/metric.py 2 | It provides a flexible mechanism for aggregating and computing metrics of cross-modality person re-identification. 3 | """ 4 | 5 | from collections import defaultdict 6 | 7 | import torch 8 | from ignite.exceptions import NotComputableError 9 | from ignite.metrics import Metric, Accuracy 10 | 11 | 12 | class ScalarMetric(Metric): 13 | """ 14 | A simple, generic implementation of an Ignite Metric for aggregating scalar values over iterations or epochs. This 15 | class provides a framework for tracking and computing the average of any scalar metric (e.g., loss, accuracy) during 16 | the training or evaluation process of a machine learning model. It accumulates the sum of the scalar values and the 17 | count of instances (batches) it has seen, allowing for the calculation of average scalar metric over all instances. 18 | 19 | Methods: 20 | - update(value): Adds a new scalar value to the running sum and increments the instance count. 21 | This method is called at each iteration with the scalar metric value for that iteration. 22 | - reset(): Resets the running sum and instance count to zero. 23 | Typically called at the start of each epoch or evaluation run to prepare for new calculations. 24 | - compute(): Calculates and returns the average of all scalar values added since the last reset. 25 | If no instances have been added, it raises a NotComputableError, indicating that there is not enough data. 
26 | """ 27 | 28 | def update(self, value): 29 | self.sum_metric += value 30 | self.sum_inst += 1 31 | 32 | def reset(self): 33 | self.sum_inst = 0 34 | self.sum_metric = 0 35 | 36 | def compute(self): 37 | if self.sum_inst == 0: 38 | raise NotComputableError('Accuracy must have at least one example before it can be computed') 39 | return self.sum_metric / self.sum_inst 40 | 41 | 42 | class IgnoreAccuracy(Accuracy): 43 | """ 44 | An extension of the Ignite Accuracy metric that incorporates the ability to ignore certain target labels during the 45 | computation of accuracy. This class is particularly useful in scenarios where some target labels in the dataset 46 | should not contribute to the accuracy calculation, such as padding tokens in sequence models or background classes 47 | in segmentation tasks. By specifying an ignore index, instances with this target label are excluded from both the 48 | numerator and denominator of the accuracy calculation. 49 | 50 | Args: 51 | - ignore_index (int): The target label that should be ignored in the accuracy computation. Instances with this 52 | label are not considered correct or incorrect predictions, effectively being excluded from the metric. 53 | 54 | Methods: 55 | - reset(): Resets the internal counters for correct predictions and total examples, 56 | preparing the metric for a new set of calculations. 57 | - update(output): Processes a batch of predictions and targets, 58 | updating the internal counters by counting correct predictions that do not correspond to the ignore index. 59 | - compute(): Calculates and returns the accuracy over all batches processed since the last reset, 60 | excluding instances with the ignore index from the calculation. 61 | """ 62 | 63 | def __init__(self, ignore_index=-1): 64 | super(IgnoreAccuracy, self).__init__() 65 | 66 | self.ignore_index = ignore_index 67 | 68 | def reset(self): 69 | self._num_correct = 0 70 | self._num_examples = 0 71 | 72 | def update(self, output): 73 | 74 | y_pred, y = self._check_shape(output) 75 | self._check_type((y_pred, y)) 76 | 77 | if self._type == "binary": 78 | indices = torch.round(y_pred).type(y.type()) 79 | elif self._type == "multiclass": 80 | indices = torch.max(y_pred, dim=1)[1] 81 | 82 | correct = torch.eq(indices, y).view(-1) 83 | ignore = torch.eq(y, self.ignore_index).view(-1) 84 | self._num_correct += torch.sum(correct).item() 85 | self._num_examples += correct.shape[0] - ignore.sum().item() 86 | 87 | def compute(self): 88 | if self._num_examples == 0: 89 | raise NotComputableError('Accuracy must have at least one example before it can be computed') 90 | return self._num_correct / self._num_examples 91 | 92 | 93 | class AutoKVMetric(Metric): 94 | """ 95 | A flexible metric class in the Ignite framework that computes and stores key-value (KV) pair metrics for each 96 | output of a model during training or evaluation. The AutoKVMetric class is designed to handle outputs in the 97 | form of dictionaries where each key corresponds to a specific metric name, and its value represents the metric 98 | value for that batch. This class allows for the automatic aggregation of multiple metrics over all batches, 99 | providing a convenient way to track a variety of performance indicators within a single metric class. 100 | 101 | Methods: 102 | - update(output): Updates the running sum of each metric based on the current batch's output. The output is expected 103 | to be a dictionary where each key-value pair represents a metric name and its corresponding value. 
104 | - reset(): Resets all internal counters and sums for each metric, preparing metric for a new round of calculations. 105 | This method is typically called at the start of each epoch or evaluation run. 106 | - compute(): Calculates and returns the average value of each metric over all processed batches since last reset. 107 | The return value is a dictionary mirroring the structure of the input to `update`, with each key corresponding to 108 | a metric name and each value being the average metric value. 109 | """ 110 | 111 | def __init__(self): 112 | self.kv_sum_metric = defaultdict(lambda: torch.tensor(0., device="cuda")) 113 | self.kv_sum_inst = defaultdict(lambda: torch.tensor(0., device="cuda")) 114 | 115 | self.kv_metric = defaultdict(lambda: 0) 116 | 117 | super(AutoKVMetric, self).__init__() 118 | 119 | def update(self, output): 120 | if not isinstance(output, dict): 121 | raise TypeError('The output must be a key-value dict.') 122 | 123 | for k in output.keys(): 124 | self.kv_sum_metric[k].add_(output[k]) 125 | self.kv_sum_inst[k].add_(1) 126 | 127 | def reset(self): 128 | for k in self.kv_sum_metric.keys(): 129 | self.kv_sum_metric[k].zero_() 130 | self.kv_sum_inst[k].zero_() 131 | self.kv_metric[k] = 0 132 | 133 | def compute(self): 134 | for k in self.kv_sum_metric.keys(): 135 | if self.kv_sum_inst[k] == 0: 136 | continue 137 | # raise NotComputableError('Accuracy must have at least one example before it can be computed') 138 | 139 | metric_value = self.kv_sum_metric[k] / self.kv_sum_inst[k] 140 | self.kv_metric[k] = metric_value.item() 141 | 142 | return self.kv_metric 143 | -------------------------------------------------------------------------------- /figs/mcja_overall_structure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/workingcoder/MCJA/efc1ffe9f1cb2c4f76a463bf708150626555b295/figs/mcja_overall_structure.png -------------------------------------------------------------------------------- /losses/__init__.py: -------------------------------------------------------------------------------- 1 | """MCJA/losses/__init__.py 2 | It is used to mark a directory as a Python package directory. 3 | """ -------------------------------------------------------------------------------- /losses/cm_retrieval_loss.py: -------------------------------------------------------------------------------- 1 | """MCJA/losses/cm_retrieval_loss.py 2 | It defines the `CMRetrievalLoss` class, a loss function specifically designed for cross-modality retrieval task. 3 | """ 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torchsort 8 | 9 | from torch.nn import functional as F 10 | 11 | 12 | class CMRetrievalLoss(nn.Module): 13 | """ 14 | A module that implements a Cross-Modal Retrieval (CMR) Loss, designed for use in training models on tasks involving 15 | the retrieval of relevant items across different modalities. The CMR Loss computes pairwise distances between 16 | embeddings from different modalities, classifying them as either matches (same identity label across modalities) or 17 | mismatches (different identity labels). It uses a soft ranking mechanism to assign ranks based on the pairwise 18 | distances, aiming to ensure that matches are ranked higher (closer) than mismatches. 19 | 20 | Args: 21 | - embeddings (Tensor): The embeddings generated by the model for a batch of items. 22 | - id_labels (Tensor): The identity labels for each item in the batch. 
23 | - m_labels (Tensor): The modality labels for each item in the batch. 24 | 25 | Methods: 26 | - forward(embeddings, id_labels, m_labels): Computes the CMR Loss for a batch of embeddings, 27 | calculating the soft ranking loss between predicted and target ranks. 28 | """ 29 | 30 | def __init__(self): 31 | super(CMRetrievalLoss, self).__init__() 32 | 33 | def forward(self, embeddings, id_labels, m_labels): 34 | m_labels_unique = torch.unique(m_labels) 35 | m_num = len(m_labels_unique) 36 | 37 | embeddings_list = [embeddings[m_labels == m_label] for m_label in m_labels_unique] 38 | id_labels_list = [id_labels[m_labels == m_label] for m_label in m_labels_unique] 39 | 40 | cmr_loss = 0 41 | valid_m_count = 0 42 | for i in range(m_num): 43 | cur_m_embeddings = embeddings_list[i] 44 | cur_m_id_labels = id_labels_list[i] 45 | other_m_embeddings = torch.cat([embeddings_list[j] for j in range(len(m_labels_unique)) if j != i], dim=0) 46 | other_m_id_labels = torch.cat([id_labels_list[j] for j in range(len(m_labels_unique)) if j != i], dim=0) 47 | 48 | match_mask = cur_m_id_labels.unsqueeze(dim=1) == other_m_id_labels.unsqueeze(dim=0) 49 | mismatch_mask = ~match_mask 50 | match_num = match_mask.sum(dim=-1) 51 | mismatch_num = mismatch_mask.sum(dim=-1) 52 | 53 | # Remove invalid queries (It has no cross-modal matching in the batch) 54 | remove_mask = (match_num == 0) | (match_num == len(other_m_id_labels)) 55 | if remove_mask.all(): 56 | continue 57 | cur_m_embeddings = cur_m_embeddings[~remove_mask] 58 | cur_m_id_labels = cur_m_id_labels[~remove_mask] 59 | match_mask = match_mask[~remove_mask] 60 | mismatch_mask = mismatch_mask[~remove_mask] 61 | match_num = match_num[~remove_mask] 62 | mismatch_num = mismatch_num[~remove_mask] 63 | 64 | dist_mat = F.cosine_similarity(cur_m_embeddings[:, None, :], other_m_embeddings[None, :, :], dim=-1) 65 | dist_mat = (1 - dist_mat) / 2 66 | 67 | predict_rank = torchsort.soft_rank(dist_mat, regularization="l2", regularization_strength=0.5) 68 | 69 | target_rank = torch.zeros_like(predict_rank) 70 | q_num, g_num = match_mask.shape 71 | 72 | target_rank[match_mask] = 1 73 | target_rank[mismatch_mask] = g_num 74 | 75 | cmr_loss += F.l1_loss(predict_rank, target_rank) 76 | 77 | valid_m_count += 1 78 | 79 | cmr_loss = cmr_loss / valid_m_count if valid_m_count > 0 else 0 80 | 81 | return cmr_loss 82 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | """MCJA/main.py 2 | It is the main entry point for training the Multi-level Cross-modality Joint Alignment (MCJA) method. 
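Typical invocations (an illustrative sketch; the flags follow the argparse options defined at the bottom of this file,
and the --desc value is an arbitrary experiment tag):

    python main.py --cfg configs/SYSU_MCJA.yml --gpu 0
    python main.py --cfg configs/RegDB_MCJA.yml --gpu 0 --seed 8 --desc regdb_run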
3 | """ 4 | 5 | import os 6 | import glob 7 | import pprint 8 | import logging 9 | 10 | import numpy as np 11 | import scipy.io as sio 12 | 13 | import torch 14 | from torch import optim 15 | from torch.utils.tensorboard import SummaryWriter 16 | from apex import amp 17 | 18 | from data import get_train_loader 19 | from data import get_test_loader 20 | from models.mcja import MCJA 21 | from engine import get_trainer 22 | from engine.engine import create_eval_engine 23 | from utils.eval_data import eval_sysu, eval_regdb 24 | 25 | 26 | def train(cfg): 27 | # Recorder --------------------------------------------------------------------------------------------------------- 28 | logger = logging.getLogger('MCJA') 29 | tb_dir = os.path.join(cfg.log_dir, 'tensorboard') 30 | if not os.path.isdir(tb_dir): 31 | os.makedirs(tb_dir, exist_ok=True) 32 | writer = SummaryWriter(log_dir=tb_dir) 33 | 34 | # Train DataLoader ------------------------------------------------------------------------------------------------- 35 | train_loader = get_train_loader(dataset=cfg.dataset, root=cfg.data_root, 36 | sample_method=cfg.sample_method, 37 | batch_size=cfg.batch_size, 38 | p_size=cfg.p_size, 39 | k_size=cfg.k_size, 40 | image_size=cfg.image_size, 41 | random_flip=cfg.random_flip, 42 | random_crop=cfg.random_crop, 43 | random_erase=cfg.random_erase, 44 | color_jitter=cfg.color_jitter, 45 | padding=cfg.padding, 46 | vimc_wg=cfg.vimc_wg, 47 | vimc_cc=cfg.vimc_cc, 48 | vimc_sj=cfg.vimc_sj, 49 | num_workers=4) 50 | 51 | # Test DataLoader -------------------------------------------------------------------------------------------------- 52 | gallery_loader, query_loader = None, None 53 | if cfg.eval_interval > 0: 54 | gallery_loader, query_loader = get_test_loader(dataset=cfg.dataset, 55 | root=cfg.data_root, 56 | batch_size=cfg.batch_size, 57 | image_size=cfg.image_size, 58 | num_workers=4) 59 | 60 | # Model ------------------------------------------------------------------------------------------------------------ 61 | model = MCJA(num_classes=cfg.num_id, 62 | drop_last_stride=cfg.drop_last_stride, 63 | mda_ratio=cfg.mda_ratio, 64 | mda_m=cfg.mda_m, 65 | loss_id=cfg.loss_id, 66 | loss_cmr=cfg.loss_cmr) 67 | 68 | def get_parameter_number(net): 69 | total_num = sum(p.numel() for p in net.parameters()) 70 | trainable_num = sum(p.numel() for p in net.parameters() if p.requires_grad) 71 | return {'Total': total_num, 'Trainable': trainable_num} 72 | 73 | logger.info(f'Model Parameter Num - {get_parameter_number(model)}') 74 | 75 | model.cuda() 76 | 77 | # Optimizer -------------------------------------------------------------------------------------------------------- 78 | optimizer = optim.Adam(model.parameters(), lr=cfg.lr, weight_decay=cfg.wd) 79 | model, optimizer = amp.initialize(model, optimizer, enabled=cfg.fp16, opt_level='O1', verbosity=0) 80 | lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer=optimizer, milestones=cfg.lr_step, gamma=0.1) 81 | 82 | # Resume ----------------------------------------------------------------------------------------------------------- 83 | if cfg.resume: 84 | checkpoint = torch.load(cfg.resume) 85 | for key in list(checkpoint.keys()): 86 | model_state_dict = model.state_dict() 87 | if key in model_state_dict: 88 | if torch.is_tensor(checkpoint[key]) and checkpoint[key].shape != model_state_dict[key].shape: 89 | logger.info(f'Warning during loading weights - Auto remove mismatch key: {key}') 90 | checkpoint.pop(key) 91 | model.load_state_dict(checkpoint, strict=False) 92 | 93 | 
# Engine ----------------------------------------------------------------------------------------------------------- 94 | checkpoint_dir = os.path.join('ckptlog/', cfg.dataset, cfg.prefix) 95 | engine = get_trainer(dataset=cfg.dataset, 96 | model=model, 97 | optimizer=optimizer, 98 | lr_scheduler=lr_scheduler, 99 | logger=logger, 100 | writer=writer, 101 | non_blocking=True, 102 | log_period=cfg.log_period, 103 | save_dir=checkpoint_dir, 104 | prefix=cfg.prefix, 105 | eval_interval=cfg.eval_interval, 106 | start_eval=cfg.start_eval, 107 | gallery_loader=gallery_loader, 108 | query_loader=query_loader) 109 | engine.run(train_loader, max_epochs=cfg.num_epoch) 110 | writer.close() 111 | 112 | 113 | def test(cfg): 114 | # Logger ----------------------------------------------------------------------------------------------------------- 115 | logger = logging.getLogger('MCJA') 116 | logger.info('\n## Starting the testing process...') 117 | 118 | # Test DataLoader -------------------------------------------------------------------------------------------------- 119 | gallery_loader, query_loader = get_test_loader(dataset=cfg.dataset, 120 | root=cfg.data_root, 121 | batch_size=cfg.batch_size, 122 | image_size=cfg.image_size, 123 | num_workers=4, 124 | mode=None) 125 | if cfg.mser: 126 | gallery_loader_r, query_loader_r = get_test_loader(dataset=cfg.dataset, 127 | root=cfg.data_root, 128 | batch_size=cfg.batch_size, 129 | image_size=cfg.image_size, 130 | num_workers=4, 131 | mode='r') 132 | gallery_loader_g, query_loader_g = get_test_loader(dataset=cfg.dataset, 133 | root=cfg.data_root, 134 | batch_size=cfg.batch_size, 135 | image_size=cfg.image_size, 136 | num_workers=4, 137 | mode='g') 138 | gallery_loader_b, query_loader_b = get_test_loader(dataset=cfg.dataset, 139 | root=cfg.data_root, 140 | batch_size=cfg.batch_size, 141 | image_size=cfg.image_size, 142 | num_workers=4, 143 | mode='b') 144 | 145 | # Model ------------------------------------------------------------------------------------------------------------ 146 | model = MCJA(num_classes=cfg.num_id, 147 | drop_last_stride=cfg.drop_last_stride, 148 | mda_ratio=cfg.mda_ratio, 149 | mda_m=cfg.mda_m, 150 | loss_id=cfg.loss_id, 151 | loss_cmr=cfg.loss_cmr) 152 | model.cuda() 153 | model = amp.initialize(model, enabled=cfg.fp16, opt_level='O1', verbosity=0) 154 | 155 | # Resume ----------------------------------------------------------------------------------------------------------- 156 | resume_path = cfg.resume if cfg.resume else glob.glob(f'{cfg.log_dir}/*best*')[0] 157 | ## Note: if cfg.resume is specified, it will be used; 158 | ## otherwise, the best model trained in the current experiment will be automatically loaded. 
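    ## An illustrative path matching the naming used by the trainer (the epoch suffix depends on the run):
    ##   ckptlog/<dataset>/<prefix>/<prefix>_model_best-<epoch>.pth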
159 | checkpoint = torch.load(resume_path) 160 | for key in list(checkpoint.keys()): 161 | model_state_dict = model.state_dict() 162 | if key in model_state_dict: 163 | if torch.is_tensor(checkpoint[key]) and checkpoint[key].shape != model_state_dict[key].shape: 164 | logger.info(f'Warning during loading weights - Auto remove mismatch key: {key}') 165 | checkpoint.pop(key) 166 | model.load_state_dict(checkpoint, strict=False) 167 | 168 | # Evaluator -------------------------------------------------------------------------------------------------------- 169 | non_blocking = True 170 | evaluator = create_eval_engine(model, non_blocking) 171 | # extract query feature 172 | evaluator.run(query_loader) 173 | q_feats = torch.cat(evaluator.state.feat_list, dim=0) 174 | q_ids = torch.cat(evaluator.state.id_list, dim=0).numpy() 175 | q_cams = torch.cat(evaluator.state.cam_list, dim=0).numpy() 176 | q_img_paths = np.concatenate(evaluator.state.img_path_list, axis=0) 177 | # extract gallery feature 178 | evaluator.run(gallery_loader) 179 | g_feats = torch.cat(evaluator.state.feat_list, dim=0) 180 | g_ids = torch.cat(evaluator.state.id_list, dim=0).numpy() 181 | g_cams = torch.cat(evaluator.state.cam_list, dim=0).numpy() 182 | g_img_paths = np.concatenate(evaluator.state.img_path_list, axis=0) 183 | 184 | if cfg.mser: 185 | ###### Multi-Spectral Enhanced Ranking (MSER) ###### 186 | evaluator = create_eval_engine(model, non_blocking) 187 | # extract query feature mode - r 188 | evaluator.run(query_loader_r) 189 | q_feats_r = torch.cat(evaluator.state.feat_list, dim=0) 190 | q_ids_r = torch.cat(evaluator.state.id_list, dim=0).numpy() 191 | q_cams_r = torch.cat(evaluator.state.cam_list, dim=0).numpy() 192 | q_img_paths_r = np.concatenate(evaluator.state.img_path_list, axis=0) 193 | # extract gallery feature mode - r 194 | evaluator.run(gallery_loader_r) 195 | g_feats_r = torch.cat(evaluator.state.feat_list, dim=0) 196 | g_ids_r = torch.cat(evaluator.state.id_list, dim=0).numpy() 197 | g_cams_r = torch.cat(evaluator.state.cam_list, dim=0).numpy() 198 | g_img_paths_r = np.concatenate(evaluator.state.img_path_list, axis=0) 199 | 200 | evaluator = create_eval_engine(model, non_blocking) 201 | # extract query feature mode - g 202 | evaluator.run(query_loader_g) 203 | q_feats_g = torch.cat(evaluator.state.feat_list, dim=0) 204 | q_ids_g = torch.cat(evaluator.state.id_list, dim=0).numpy() 205 | q_cams_g = torch.cat(evaluator.state.cam_list, dim=0).numpy() 206 | q_img_paths_g = np.concatenate(evaluator.state.img_path_list, axis=0) 207 | # extract gallery feature mode - g 208 | evaluator.run(gallery_loader_g) 209 | g_feats_g = torch.cat(evaluator.state.feat_list, dim=0) 210 | g_ids_g = torch.cat(evaluator.state.id_list, dim=0).numpy() 211 | g_cams_g = torch.cat(evaluator.state.cam_list, dim=0).numpy() 212 | g_img_paths_g = np.concatenate(evaluator.state.img_path_list, axis=0) 213 | 214 | evaluator = create_eval_engine(model, non_blocking) 215 | # extract query feature mode - b 216 | evaluator.run(query_loader_b) 217 | q_feats_b = torch.cat(evaluator.state.feat_list, dim=0) 218 | q_ids_b = torch.cat(evaluator.state.id_list, dim=0).numpy() 219 | q_cams_b = torch.cat(evaluator.state.cam_list, dim=0).numpy() 220 | q_img_paths_b = np.concatenate(evaluator.state.img_path_list, axis=0) 221 | # extract gallery feature mode - b 222 | evaluator.run(gallery_loader_b) 223 | g_feats_b = torch.cat(evaluator.state.feat_list, dim=0) 224 | g_ids_b = torch.cat(evaluator.state.id_list, dim=0).numpy() 225 | g_cams_b = 
torch.cat(evaluator.state.cam_list, dim=0).numpy() 226 | g_img_paths_b = np.concatenate(evaluator.state.img_path_list, axis=0) 227 | 228 | q_feats_mser = [q_feats, q_feats_r, q_feats_g, q_feats_b] 229 | q_ids_mser = [q_ids, q_ids_r, q_ids_g, q_ids_b] 230 | q_cams_mser = [q_cams, q_cams_r, q_cams_g, q_cams_b] 231 | q_img_paths_mser = [q_img_paths, q_img_paths_r, q_img_paths_g, q_img_paths_b] 232 | g_feats_mser = [g_feats, g_feats_r, g_feats_g, g_feats_b] 233 | g_ids_mser = [g_ids, g_ids_r, g_ids_g, g_ids_b] 234 | g_cams_mser = [g_cams, g_cams_r, g_cams_g, g_cams_b] 235 | g_img_paths_mser = [g_img_paths, g_img_paths_r, g_img_paths_g, g_img_paths_b] 236 | 237 | if cfg.dataset == 'sysu': 238 | perm = sio.loadmat(os.path.join(cfg.data_root, 'exp', 'rand_perm_cam.mat'))['rand_perm_cam'] 239 | eval_sysu(q_feats, q_ids, q_cams, q_img_paths, 240 | g_feats, g_ids, g_cams, g_img_paths, perm, mode='all', num_shots=1) 241 | eval_sysu(q_feats, q_ids, q_cams, q_img_paths, 242 | g_feats, g_ids, g_cams, g_img_paths, perm, mode='all', num_shots=10) 243 | eval_sysu(q_feats, q_ids, q_cams, q_img_paths, 244 | g_feats, g_ids, g_cams, g_img_paths, perm, mode='indoor', num_shots=1) 245 | eval_sysu(q_feats, q_ids, q_cams, q_img_paths, 246 | g_feats, g_ids, g_cams, g_img_paths, perm, mode='indoor', num_shots=10) 247 | if cfg.mser: 248 | eval_sysu(q_feats_mser, q_ids_mser, q_cams_mser, q_img_paths_mser, 249 | g_feats_mser, g_ids_mser, g_cams_mser, g_img_paths_mser, 250 | perm, mode='all', num_shots=1, mser=True) 251 | eval_sysu(q_feats_mser, q_ids_mser, q_cams_mser, q_img_paths_mser, 252 | g_feats_mser, g_ids_mser, g_cams_mser, g_img_paths_mser, 253 | perm, mode='all', num_shots=10, mser=True) 254 | eval_sysu(q_feats_mser, q_ids_mser, q_cams_mser, q_img_paths_mser, 255 | g_feats_mser, g_ids_mser, g_cams_mser, g_img_paths_mser, 256 | perm, mode='indoor', num_shots=1, mser=True) 257 | eval_sysu(q_feats_mser, q_ids_mser, q_cams_mser, q_img_paths_mser, 258 | g_feats_mser, g_ids_mser, g_cams_mser, g_img_paths_mser, 259 | perm, mode='indoor', num_shots=10, mser=True) 260 | 261 | elif cfg.dataset == 'regdb': 262 | logger.info('Test Mode - infrared to visible') 263 | eval_regdb(q_feats, q_ids, q_cams, q_img_paths, 264 | g_feats, g_ids, g_cams, g_img_paths, mode='i2v') 265 | logger.info('Test Mode - visible to infrared') 266 | eval_regdb(g_feats, g_ids, g_cams, g_img_paths, 267 | q_feats, q_ids, q_cams, q_img_paths, mode='v2i') 268 | if cfg.mser: 269 | logger.info('Test Mode - infrared to visible') 270 | eval_regdb(q_feats_mser, q_ids_mser, q_cams_mser, q_img_paths_mser, 271 | g_feats_mser, g_ids_mser, g_cams_mser, g_img_paths_mser, mode='i2v', mser=True) 272 | logger.info('Test Mode - visible to infrared') 273 | eval_regdb(g_feats_mser, g_ids_mser, g_cams_mser, g_img_paths_mser, 274 | q_feats_mser, q_ids_mser, q_cams_mser, q_img_paths_mser, mode='v2i', mser=True) 275 | else: 276 | raise NotImplementedError(f'Dataset - {cfg.dataset} is not supported') 277 | 278 | evaluator.state.feat_list.clear() 279 | evaluator.state.id_list.clear() 280 | evaluator.state.cam_list.clear() 281 | evaluator.state.img_path_list.clear() 282 | 283 | 284 | if __name__ == '__main__': 285 | # Tools ------------------------------------------------------------------------------------------------------------ 286 | import argparse 287 | from configs.default import strategy_cfg 288 | from configs.default import dataset_cfg 289 | from utils.tools import set_seed, time_str 290 | 291 | # Argument Parser 
-------------------------------------------------------------------------------------------------- 292 | parser = argparse.ArgumentParser() 293 | parser.add_argument('--cfg', type=str, default='configs/SYSU_MCJA.yml', help='customized strategy config') 294 | parser.add_argument('--seed', type=int, default=8, help='random seed - choose a lucky number') 295 | parser.add_argument('--desc', type=str, default=None, help='auxiliary description of this experiment') 296 | parser.add_argument('--gpu', type=str, default='0', help='GPU device for the training process') 297 | args = parser.parse_args() 298 | 299 | # Environment ------------------------------------------------------------------------------------------------------ 300 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 301 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 302 | set_seed(args.seed) 303 | 304 | # Configuration ---------------------------------------------------------------------------------------------------- 305 | ## strategy_cfg 306 | cfg = strategy_cfg 307 | cfg.merge_from_file(args.cfg) 308 | ## dataset_cfg 309 | dataset_cfg = dataset_cfg.get(cfg.dataset) 310 | for k, v in dataset_cfg.items(): 311 | cfg[k] = v 312 | ## other cfg 313 | cfg.prefix += f'_Time-{time_str()}' 314 | cfg.prefix += f'_{args.desc}' if (args.desc is not None) else '' 315 | cfg['log_dir'] = os.path.join('ckptlog/', cfg.dataset, cfg.prefix) 316 | ## freeze cfg 317 | cfg.freeze() 318 | 319 | # Logger --------------------------------------------------------------------------------------------------------- 320 | if not os.path.isdir(cfg.log_dir): 321 | os.makedirs(cfg.log_dir, exist_ok=True) 322 | logger = logging.getLogger('MCJA') 323 | logger.setLevel(logging.DEBUG) 324 | consoleHandler = logging.StreamHandler() 325 | consoleHandler.setLevel(logging.INFO) 326 | fileHandler = logging.FileHandler(filename=os.path.join(cfg.log_dir, 'log.txt')) 327 | fileHandler.setLevel(logging.INFO) 328 | formatter = logging.Formatter(fmt='%(asctime)s %(message)s', datefmt='[%Y-%m-%d %H:%M:%S]') 329 | consoleHandler.setFormatter(formatter) 330 | fileHandler.setFormatter(formatter) 331 | logger.addHandler(consoleHandler) 332 | logger.addHandler(fileHandler) 333 | logger.info('\n' + pprint.pformat(cfg)) 334 | 335 | # Train & Test ----------------------------------------------------------------------------------------------------- 336 | if not cfg.test_only: 337 | train(cfg) 338 | test(cfg) 339 | 340 | # ------------------------------------------------------------------------------------------------------------------ 341 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | """MCJA/models/__init__.py 2 | It is used to mark a directory as a Python package directory. 3 | """ -------------------------------------------------------------------------------- /models/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | """MCJA/models/backbones/__init__.py 2 | It is used to mark a directory as a Python package directory. 3 | """ -------------------------------------------------------------------------------- /models/backbones/resnet.py: -------------------------------------------------------------------------------- 1 | """MCJA/models/backbones/resnet.py 2 | It implements a series of ResNets (the code originates from the standard library of PyTorch). 
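Illustrative usage (a minimal sketch; `images` is assumed to be a (B, 3, H, W) tensor, and the /16 spatial factor
assumes drop_last_stride=True as used by the MCJA model):

    backbone = resnet50(pretrained=True, drop_last_stride=True)  # layer4 keeps stride 1
    feat_map = backbone(images, mode='features')                 # (B, 2048, H/16, W/16) feature maps
    embedding = backbone(images, mode='embeddings')              # (B, 2048, 1, 1) pooled embeddings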
3 | """ 4 | 5 | import torch 6 | from torch import Tensor 7 | import torch.nn as nn 8 | from torchvision.models.resnet import load_state_dict_from_url 9 | from typing import Type, Any, Callable, Union, List, Optional 10 | 11 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 12 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 13 | 'wide_resnet50_2', 'wide_resnet101_2'] 14 | 15 | model_urls = { 16 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 17 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 18 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 19 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 20 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 21 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 22 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 23 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', 24 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', 25 | 26 | # 'resnet50': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_a1_0-14fe96d1.pth', 27 | # 'resnet50': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_a2_0-a2746f79.pth', 28 | # 'resnet50': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_a3_0-59cae1ef.pth', 29 | } 30 | 31 | 32 | def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d: 33 | """3x3 convolution with padding""" 34 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 35 | padding=dilation, groups=groups, bias=False, dilation=dilation) 36 | 37 | 38 | def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: 39 | """1x1 convolution""" 40 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 41 | 42 | 43 | class BasicBlock(nn.Module): 44 | expansion: int = 1 45 | 46 | def __init__( 47 | self, 48 | inplanes: int, 49 | planes: int, 50 | stride: int = 1, 51 | downsample: Optional[nn.Module] = None, 52 | groups: int = 1, 53 | base_width: int = 64, 54 | dilation: int = 1, 55 | norm_layer: Optional[Callable[..., nn.Module]] = None 56 | ) -> None: 57 | super(BasicBlock, self).__init__() 58 | if norm_layer is None: 59 | norm_layer = nn.BatchNorm2d 60 | if groups != 1 or base_width != 64: 61 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 62 | if dilation > 1: 63 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 64 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 65 | self.conv1 = conv3x3(inplanes, planes, stride) 66 | self.bn1 = norm_layer(planes) 67 | self.relu = nn.ReLU(inplace=True) 68 | self.conv2 = conv3x3(planes, planes) 69 | self.bn2 = norm_layer(planes) 70 | self.downsample = downsample 71 | self.stride = stride 72 | 73 | def forward(self, x: Tensor) -> Tensor: 74 | identity = x 75 | 76 | out = self.conv1(x) 77 | out = self.bn1(out) 78 | out = self.relu(out) 79 | 80 | out = self.conv2(out) 81 | out = self.bn2(out) 82 | 83 | if self.downsample is not None: 84 | identity = self.downsample(x) 85 | 86 | out += identity 87 | out = self.relu(out) 88 | 89 | return out 90 | 91 | 92 | class Bottleneck(nn.Module): 93 | # 
Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) 94 | # while original implementation places the stride at the first 1x1 convolution(self.conv1) 95 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. 96 | # This variant is also known as ResNet V1.5 and improves accuracy according to 97 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 98 | 99 | expansion: int = 4 100 | 101 | def __init__( 102 | self, 103 | inplanes: int, 104 | planes: int, 105 | stride: int = 1, 106 | downsample: Optional[nn.Module] = None, 107 | groups: int = 1, 108 | base_width: int = 64, 109 | dilation: int = 1, 110 | norm_layer: Optional[Callable[..., nn.Module]] = None 111 | ) -> None: 112 | super(Bottleneck, self).__init__() 113 | if norm_layer is None: 114 | norm_layer = nn.BatchNorm2d 115 | width = int(planes * (base_width / 64.)) * groups 116 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 117 | self.conv1 = conv1x1(inplanes, width) 118 | self.bn1 = norm_layer(width) 119 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 120 | self.bn2 = norm_layer(width) 121 | self.conv3 = conv1x1(width, planes * self.expansion) 122 | self.bn3 = norm_layer(planes * self.expansion) 123 | self.relu = nn.ReLU(inplace=True) 124 | self.downsample = downsample 125 | self.stride = stride 126 | 127 | def forward(self, x: Tensor) -> Tensor: 128 | identity = x 129 | 130 | out = self.conv1(x) 131 | out = self.bn1(out) 132 | out = self.relu(out) 133 | 134 | out = self.conv2(out) 135 | out = self.bn2(out) 136 | out = self.relu(out) 137 | 138 | out = self.conv3(out) 139 | out = self.bn3(out) 140 | 141 | if self.downsample is not None: 142 | identity = self.downsample(x) 143 | 144 | out += identity 145 | out = self.relu(out) 146 | 147 | return out 148 | 149 | 150 | class ResNet(nn.Module): 151 | 152 | def __init__( 153 | self, 154 | block: Type[Union[BasicBlock, Bottleneck]], 155 | layers: List[int], 156 | num_classes: int = 1000, 157 | zero_init_residual: bool = False, 158 | groups: int = 1, 159 | width_per_group: int = 64, 160 | replace_stride_with_dilation: Optional[List[bool]] = None, 161 | norm_layer: Optional[Callable[..., nn.Module]] = None, 162 | **kwargs: Any 163 | ) -> None: 164 | super(ResNet, self).__init__() 165 | if norm_layer is None: 166 | norm_layer = nn.BatchNorm2d 167 | self._norm_layer = norm_layer 168 | 169 | self.inplanes = 64 170 | self.dilation = 1 171 | if replace_stride_with_dilation is None: 172 | # each element in the tuple indicates if we should replace 173 | # the 2x2 stride with a dilated convolution instead 174 | replace_stride_with_dilation = [False, False, False] 175 | if len(replace_stride_with_dilation) != 3: 176 | raise ValueError("replace_stride_with_dilation should be None " 177 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 178 | self.groups = groups 179 | self.base_width = width_per_group 180 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 181 | bias=False) 182 | self.bn1 = norm_layer(self.inplanes) 183 | self.relu = nn.ReLU(inplace=True) 184 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 185 | self.layer1 = self._make_layer(block, 64, layers[0]) 186 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 187 | dilate=replace_stride_with_dilation[0]) 188 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 189 | 
dilate=replace_stride_with_dilation[1]) 190 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1 if kwargs.get('drop_last_stride') else 2, 191 | dilate=replace_stride_with_dilation[2]) 192 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 193 | self.fc = nn.Linear(512 * block.expansion, num_classes) 194 | 195 | for m in self.modules(): 196 | if isinstance(m, nn.Conv2d): 197 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 198 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 199 | nn.init.constant_(m.weight, 1) 200 | nn.init.constant_(m.bias, 0) 201 | 202 | # Zero-initialize the last BN in each residual branch, 203 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 204 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 205 | if zero_init_residual: 206 | for m in self.modules(): 207 | if isinstance(m, Bottleneck): 208 | nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type] 209 | elif isinstance(m, BasicBlock): 210 | nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type] 211 | 212 | def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int, 213 | stride: int = 1, dilate: bool = False) -> nn.Sequential: 214 | norm_layer = self._norm_layer 215 | downsample = None 216 | previous_dilation = self.dilation 217 | if dilate: 218 | self.dilation *= stride 219 | stride = 1 220 | if stride != 1 or self.inplanes != planes * block.expansion: 221 | downsample = nn.Sequential( 222 | conv1x1(self.inplanes, planes * block.expansion, stride), 223 | norm_layer(planes * block.expansion), 224 | ) 225 | 226 | layers = [] 227 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 228 | self.base_width, previous_dilation, norm_layer)) 229 | self.inplanes = planes * block.expansion 230 | for _ in range(1, blocks): 231 | layers.append(block(self.inplanes, planes, groups=self.groups, 232 | base_width=self.base_width, dilation=self.dilation, 233 | norm_layer=norm_layer)) 234 | 235 | return nn.Sequential(*layers) 236 | 237 | def _forward_impl(self, x: Tensor, mode) -> Tensor: 238 | # See note [TorchScript super()] 239 | x = self.conv1(x) 240 | x = self.bn1(x) 241 | x = self.relu(x) 242 | x = self.maxpool(x) 243 | 244 | x = self.layer1(x) 245 | x = self.layer2(x) 246 | x = self.layer3(x) 247 | x = self.layer4(x) 248 | 249 | if mode == 'features': 250 | return x 251 | 252 | x = self.avgpool(x) 253 | if mode == 'embeddings': 254 | return x 255 | 256 | x = torch.flatten(x, 1) 257 | x = self.fc(x) 258 | return x 259 | 260 | def forward(self, x: Tensor, mode='logits') -> Tensor: 261 | assert mode in ['features', 'embeddings', 'logits'] 262 | return self._forward_impl(x, mode) 263 | 264 | 265 | def _resnet( 266 | arch: str, 267 | block: Type[Union[BasicBlock, Bottleneck]], 268 | layers: List[int], 269 | pretrained: bool, 270 | progress: bool, 271 | **kwargs: Any 272 | ) -> ResNet: 273 | model = ResNet(block, layers, **kwargs) 274 | if pretrained: 275 | state_dict = load_state_dict_from_url(model_urls[arch], 276 | progress=progress) 277 | model.load_state_dict(state_dict) 278 | return model 279 | 280 | 281 | def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 282 | r"""ResNet-18 model from 283 | `"Deep Residual Learning for Image Recognition" `_. 
284 | 285 | Args: 286 | pretrained (bool): If True, returns a model pre-trained on ImageNet 287 | progress (bool): If True, displays a progress bar of the download to stderr 288 | """ 289 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, 290 | **kwargs) 291 | 292 | 293 | def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 294 | r"""ResNet-34 model from 295 | `"Deep Residual Learning for Image Recognition" `_. 296 | 297 | Args: 298 | pretrained (bool): If True, returns a model pre-trained on ImageNet 299 | progress (bool): If True, displays a progress bar of the download to stderr 300 | """ 301 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, 302 | **kwargs) 303 | 304 | 305 | def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 306 | r"""ResNet-50 model from 307 | `"Deep Residual Learning for Image Recognition" `_. 308 | 309 | Args: 310 | pretrained (bool): If True, returns a model pre-trained on ImageNet 311 | progress (bool): If True, displays a progress bar of the download to stderr 312 | """ 313 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, 314 | **kwargs) 315 | 316 | 317 | def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 318 | r"""ResNet-101 model from 319 | `"Deep Residual Learning for Image Recognition" `_. 320 | 321 | Args: 322 | pretrained (bool): If True, returns a model pre-trained on ImageNet 323 | progress (bool): If True, displays a progress bar of the download to stderr 324 | """ 325 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, 326 | **kwargs) 327 | 328 | 329 | def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 330 | r"""ResNet-152 model from 331 | `"Deep Residual Learning for Image Recognition" `_. 332 | 333 | Args: 334 | pretrained (bool): If True, returns a model pre-trained on ImageNet 335 | progress (bool): If True, displays a progress bar of the download to stderr 336 | """ 337 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, 338 | **kwargs) 339 | 340 | 341 | def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 342 | r"""ResNeXt-50 32x4d model from 343 | `"Aggregated Residual Transformation for Deep Neural Networks" `_. 344 | 345 | Args: 346 | pretrained (bool): If True, returns a model pre-trained on ImageNet 347 | progress (bool): If True, displays a progress bar of the download to stderr 348 | """ 349 | kwargs['groups'] = 32 350 | kwargs['width_per_group'] = 4 351 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], 352 | pretrained, progress, **kwargs) 353 | 354 | 355 | def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 356 | r"""ResNeXt-101 32x8d model from 357 | `"Aggregated Residual Transformation for Deep Neural Networks" `_. 358 | 359 | Args: 360 | pretrained (bool): If True, returns a model pre-trained on ImageNet 361 | progress (bool): If True, displays a progress bar of the download to stderr 362 | """ 363 | kwargs['groups'] = 32 364 | kwargs['width_per_group'] = 8 365 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], 366 | pretrained, progress, **kwargs) 367 | 368 | 369 | def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 370 | r"""Wide ResNet-50-2 model from 371 | `"Wide Residual Networks" `_. 
372 | 373 | The model is the same as ResNet except for the bottleneck number of channels 374 | which is twice larger in every block. The number of channels in outer 1x1 375 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 376 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 377 | 378 | Args: 379 | pretrained (bool): If True, returns a model pre-trained on ImageNet 380 | progress (bool): If True, displays a progress bar of the download to stderr 381 | """ 382 | kwargs['width_per_group'] = 64 * 2 383 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], 384 | pretrained, progress, **kwargs) 385 | 386 | 387 | def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 388 | r"""Wide ResNet-101-2 model from 389 | `"Wide Residual Networks" `_. 390 | 391 | The model is the same as ResNet except for the bottleneck number of channels 392 | which is twice larger in every block. The number of channels in outer 1x1 393 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 394 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 395 | 396 | Args: 397 | pretrained (bool): If True, returns a model pre-trained on ImageNet 398 | progress (bool): If True, displays a progress bar of the download to stderr 399 | """ 400 | kwargs['width_per_group'] = 64 * 2 401 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], 402 | pretrained, progress, **kwargs) 403 | -------------------------------------------------------------------------------- /models/mcja.py: -------------------------------------------------------------------------------- 1 | """MCJA/models/mcja.py 2 | It defines the Multi-level Cross-modality Joint Alignment (MCJA) model, a framework for cross-modality VI-ReID task. 3 | """ 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | from models.backbones.resnet import resnet50 9 | from models.modules.mda import MDA 10 | from losses.cm_retrieval_loss import CMRetrievalLoss 11 | from utils.calc_acc import calc_acc 12 | 13 | 14 | class MCJA(nn.Module): 15 | """ 16 | The class implementing the Multi-level Cross-modality Joint Alignment (MCJA) model, designed for cross-modality person re-identification tasks. 17 | This model integrates various components, including a backbone for feature extraction, the Modality Distribution 18 | Adapter (MDA) for better cross-modality feature alignment & distribution adaptation, a neck for feature embedding, 19 | a head for classification, and specialized loss functions (identity and cross-modality retrieval (CMR) losses). 20 | 21 | Args: 22 | - num_classes (int): The number of identity classes in the dataset. 23 | - drop_last_stride (bool): A flag to determine whether the last stride in the backbone should be dropped. 24 | - mda_ratio (int): The ratio for reducing the channel dimensions in MDA layers. 25 | - mda_m (int): The number of modalities considered by the MDA layers. 26 | - loss_id (bool): Whether to use the identity loss during training. 27 | - loss_cmr (bool): Whether to use the cross-modality retrieval loss during training. 28 | 29 | Methods: 30 | - forward(inputs, labels=None, **kwargs): Processes the input through the MCJA model. 31 | In training mode, it computes the loss and metrics based on the provided labels and additional information (e.g., 32 | camera IDs for modality labels). In evaluation mode, it returns the feature embeddings after BN neck processing. 33 | - train_forward(embeddings, labels, **kwargs): A helper function called during training to compute losses.
34 | It calculates the identity and CMR losses based on embeddings, identity labels, and modality labels. 35 | """ 36 | 37 | def __init__(self, num_classes, drop_last_stride=False, mda_ratio=2, mda_m=2, loss_id=True, loss_cmr=True): 38 | super(MCJA, self).__init__() 39 | 40 | # Backbone ----------------------------------------------------------------------------------------------------- 41 | self.backbone = resnet50(pretrained=True, drop_last_stride=drop_last_stride) 42 | self.base_dim = 2048 43 | 44 | # Neck --------------------------------------------------------------------------------------------------------- 45 | self.bn_neck = nn.BatchNorm1d(self.base_dim) 46 | nn.init.constant_(self.bn_neck.bias, 0) 47 | self.bn_neck.bias.requires_grad_(False) 48 | 49 | # Head --------------------------------------------------------------------------------------------------------- 50 | self.classifier = nn.Linear(self.base_dim, num_classes, bias=False) 51 | 52 | # Loss --------------------------------------------------------------------------------------------------------- 53 | self.id_loss = nn.CrossEntropyLoss() if loss_id else None 54 | ###### Cross-Modality Retrieval Loss (CMR) ###### 55 | self.cmr_loss = CMRetrievalLoss() if loss_cmr else None 56 | 57 | # Module ------------------------------------------------------------------------------------------------------- 58 | layers = [3, 4, 6, 3] # Just for ResNet50 59 | ###### Modality Distribution Adapter (MDA) ###### 60 | mda_layers = [0, 2, 3, 0] 61 | self.MDA_1 = nn.ModuleList( 62 | [MDA(in_channels=256, inter_ratio=mda_ratio, m_num=mda_m) for _ in range(mda_layers[0])]) 63 | self.MDA_1_idx = sorted([layers[0] - (i + 1) for i in range(mda_layers[0])]) 64 | self.MDA_2 = nn.ModuleList( 65 | [MDA(in_channels=512, inter_ratio=mda_ratio, m_num=mda_m) for _ in range(mda_layers[1])]) 66 | self.MDA_2_idx = sorted([layers[1] - (i + 1) for i in range(mda_layers[1])]) 67 | self.MDA_3 = nn.ModuleList( 68 | [MDA(in_channels=1024, inter_ratio=mda_ratio, m_num=mda_m) for _ in range(mda_layers[2])]) 69 | self.MDA_3_idx = sorted([layers[2] - (i + 1) for i in range(mda_layers[2])]) 70 | self.MDA_4 = nn.ModuleList( 71 | [MDA(in_channels=2048, inter_ratio=mda_ratio, m_num=mda_m) for _ in range(mda_layers[3])]) 72 | self.MDA_4_idx = sorted([layers[3] - (i + 1) for i in range(mda_layers[3])]) 73 | 74 | def forward(self, inputs, labels=None, **kwargs): 75 | 76 | # Feature Extraction ------------------------------------------------------------------------------------------- 77 | feats = self.backbone.conv1(inputs) 78 | feats = self.backbone.bn1(feats) 79 | feats = self.backbone.relu(feats) 80 | feats = self.backbone.maxpool(feats) 81 | 82 | MDA_1_counter = 0 83 | if len(self.MDA_1_idx) == 0: self.MDA_1_idx = [-1] 84 | for i in range(len(self.backbone.layer1)): 85 | feats = self.backbone.layer1[i](feats) 86 | if i == self.MDA_1_idx[MDA_1_counter]: 87 | _, C, H, W = feats.shape 88 | feats = self.MDA_1[MDA_1_counter](feats) 89 | MDA_1_counter += 1 90 | MDA_2_counter = 0 91 | if len(self.MDA_2_idx) == 0: self.MDA_2_idx = [-1] 92 | for i in range(len(self.backbone.layer2)): 93 | feats = self.backbone.layer2[i](feats) 94 | if i == self.MDA_2_idx[MDA_2_counter]: 95 | _, C, H, W = feats.shape 96 | feats = self.MDA_2[MDA_2_counter](feats) 97 | MDA_2_counter += 1 98 | MDA_3_counter = 0 99 | if len(self.MDA_3_idx) == 0: self.MDA_3_idx = [-1] 100 | for i in range(len(self.backbone.layer3)): 101 | feats = self.backbone.layer3[i](feats) 102 | if i == self.MDA_3_idx[MDA_3_counter]: 
103 | _, C, H, W = feats.shape 104 | feats = self.MDA_3[MDA_3_counter](feats) 105 | MDA_3_counter += 1 106 | MDA_4_counter = 0 107 | if len(self.MDA_4_idx) == 0: self.MDA_4_idx = [-1] 108 | for i in range(len(self.backbone.layer4)): 109 | feats = self.backbone.layer4[i](feats) 110 | if i == self.MDA_4_idx[MDA_4_counter]: 111 | _, C, H, W = feats.shape 112 | feats = self.MDA_4[MDA_4_counter](feats) 113 | MDA_4_counter += 1 114 | global_feats = feats 115 | 116 | # Feature Embedding -------------------------------------------------------------------------------------------- 117 | b, c, h, w = global_feats.shape 118 | global_feats = global_feats.view(b, c, -1) 119 | p = 3.0 120 | embeddings = (torch.mean(global_feats ** p, dim=-1) + 1e-12) ** (1 / p) # GeMPooling 121 | 122 | # Train & Test Return ------------------------------------------------------------------------------------------ 123 | if self.training: 124 | return self.train_forward(embeddings, labels, **kwargs) 125 | else: 126 | return self.bn_neck(embeddings) 127 | 128 | def train_forward(self, embeddings, labels, **kwargs): 129 | loss = 0 130 | metric = {} 131 | 132 | embeddings = self.bn_neck(embeddings) 133 | 134 | # modality labels 135 | cam_ids = kwargs.get('cam_ids') 136 | rgb_idx_mask = (cam_ids == 1) + (cam_ids == 2) + (cam_ids == 4) + (cam_ids == 5) 137 | ir_idx_mask = (cam_ids == 3) + (cam_ids == 6) 138 | m_labels = torch.ones((len(labels))) 139 | m_labels[rgb_idx_mask] = 0 140 | m_labels[ir_idx_mask] = 1 141 | 142 | if self.cmr_loss is not None: 143 | ###### Cross-Modality Retrieval Loss (CMR) ###### 144 | cmr_loss = self.cmr_loss(embeddings.float(), id_labels=labels, m_labels=m_labels) 145 | loss += cmr_loss 146 | metric.update({'loss_cmr': cmr_loss.data}) 147 | 148 | logits = self.classifier(embeddings) 149 | 150 | if self.id_loss is not None: 151 | # Identity Loss (ID Loss) 152 | id_loss = self.id_loss(logits.float(), labels) 153 | loss += id_loss 154 | metric.update({'cls_acc': calc_acc(logits.data, labels), 'loss_id': id_loss.data}) 155 | 156 | return loss, metric 157 | -------------------------------------------------------------------------------- /models/modules/__init__.py: -------------------------------------------------------------------------------- 1 | """MCJA/models/modules/__init__.py 2 | It is used to mark a directory as a Python package directory. 3 | """ -------------------------------------------------------------------------------- /models/modules/mda.py: -------------------------------------------------------------------------------- 1 | """MCJA/models/modules/mda.py 2 | It defines the Modality Distribution Adapter (MDA) class, a module designed to enhance cross-modal feature learning. 3 | """ 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | 9 | class MDA(nn.Module): 10 | """ 11 | A module implementing the Modality Distribution Adapter (MDA), a mechanism designed to enhance cross-modal learning 12 | by adaptively re-weighting feature channels based on their relevance to different modalities. The MDA module 13 | dynamically adjusts the contribution of different feature channels to the task at hand, depending on the context 14 | provided by different modalities, thereby facilitating more effective integration of multimodal information. 15 | 16 | Args: 17 | - in_channels (int): Number of channels in the input feature map. 18 | - inter_ratio (int): Reduction ratio for intermediate channel dimensions, controlling the compactness of the module. 
19 | - m_num (int): Number of distinct modalities that the model needs to adapt to. 20 | 21 | Methods: 22 | - forward(x): Processes the input tensor through the MDA to produce an output with adapted feature distributions. 23 | """ 24 | 25 | def __init__(self, in_channels, inter_ratio=2, m_num=2): 26 | super(MDA, self).__init__() 27 | self.in_channels = in_channels 28 | self.inter_ratio = inter_ratio 29 | self.planes = in_channels // inter_ratio 30 | self.m_num = m_num 31 | 32 | self.sc_conv = nn.Conv2d(in_channels, m_num, kernel_size=1) 33 | self.softmax = nn.Softmax(dim=-1) 34 | 35 | self.ca_conv = nn.Conv1d(self.in_channels, self.in_channels, 36 | kernel_size=m_num, groups=self.in_channels, bias=False) 37 | self.ca_bn = nn.BatchNorm1d(self.in_channels) 38 | 39 | self.norm_bn = nn.BatchNorm2d(self.in_channels) 40 | nn.init.constant_(self.norm_bn.weight, 0.0) 41 | nn.init.constant_(self.norm_bn.bias, 0.0) 42 | 43 | def forward(self, x): 44 | input_x = x 45 | batch, channel, height, width = x.size() 46 | 47 | # Spatial Characteristics Learning ----------------------------------------------------------------------------- 48 | # [B, C, H, W] -> [B, C, H * W] 49 | input_x = input_x.view(batch, channel, height * width) 50 | # [B, C, H * W] -> [B, 1, C, H * W] 51 | input_x = input_x.unsqueeze(1) 52 | # [B, C, H, W] -> [B, M, H, W] 53 | context_mask = self.sc_conv(x) 54 | # [B, M, H, W] -> [B, M, H * W] 55 | context_mask = context_mask.view(batch, self.m_num, height * width) 56 | # [B, M, H * W] -> [B, M, H * W] 57 | context_mask = self.softmax(context_mask) 58 | # [B, M, H * W] -> [B, M, H * W, 1] 59 | context_mask = context_mask.unsqueeze(-1) 60 | # [B, 1, C, H * W] [B, M, H * W, 1] -> [B, M, C, 1] 61 | context = torch.matmul(input_x, context_mask) 62 | # [B, M, C, 1] -> [B, C, M] 63 | context = context.squeeze(-1).permute(0, 2, 1) 64 | 65 | # Characteristics Aggregation ---------------------------------------------------------------------------------- 66 | # [B, C, M] -> [B, C, 1] 67 | z = self.ca_conv(context) 68 | # [B, C, 1] -> [B, C, 1] 69 | z = self.ca_bn(z) 70 | # [B, C, 1] -> [B, C, 1] 71 | g = torch.sigmoid(z) 72 | 73 | # Feature Distribution Adaption -------------------------------------------------------------------------------- 74 | # [B, C, 1] -> [B, C, 1, 1] 75 | g = g.view(batch, channel, 1, 1) 76 | # [B, C, H, W] [B, C, 1, 1] -> [B, C, H, W] 77 | out = self.norm_bn(x * g.expand_as(x)) + x 78 | 79 | return out 80 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | """MCJA/utils/__init__.py 2 | It is used to mark a directory as a Python package directory. 3 | """ -------------------------------------------------------------------------------- /utils/calc_acc.py: -------------------------------------------------------------------------------- 1 | """MCJA/utils/calc_acc.py 2 | This utility file defines a function `calc_acc` for calculating classification accuracy. 3 | """ 4 | 5 | import torch 6 | 7 | 8 | def calc_acc(logits, label, ignore_index=-100, mode="multiclass"): 9 | """ 10 | A utility function for calculating the accuracy of model predictions given the logits and corresponding labels. 11 | It supports both binary and multiclass classification tasks by interpreting the logits according to the specified 12 | mode and comparing them against the ground truth labels to determine the number of correct predictions. 
13 | The function also accommodates scenarios where certain examples should be ignored in the accuracy calculation, 14 | based on a designated ignore_index or the structure of the labels. 15 | 16 | Args: 17 | - logits (Tensor): The output logits from a model. For binary classification, logits should be a 1D tensor of 18 | probabilities. For multiclass classification, logits should be a 2D tensor with shape [batch_size, num_classes]. 19 | - label (Tensor): The ground truth labels for the predictions. For multiclass classification, labels should be a 20 | 1D tensor of class indices. For binary classification, labels should be a tensor with the same shape as logits. 21 | - ignore_index (int, optional): Specifies a label value that should be ignored when calculating accuracy. 22 | Examples with this label are not considered in the denominator of the accuracy calculation. Default is -100. 23 | - mode (str, optional): Determines how logits are interpreted. Can be "binary" for binary classification tasks, 24 | where logits are rounded to 0 or 1, or "multiclass" for tasks with more than two classes, where the class with 25 | the highest logit is selected. Default is "multiclass". 26 | 27 | Returns: 28 | - Tensor: The calculated accuracy as a float value, representing the proportion of correct predictions out 29 | of the total number of examples considered (excluding ignored examples). 30 | """ 31 | 32 | if mode == "binary": 33 | indices = torch.round(logits).type(label.type()) 34 | elif mode == "multiclass": 35 | indices = torch.max(logits, dim=1)[1] 36 | 37 | if label.size() == logits.size(): 38 | ignore = 1 - torch.round(label.sum(dim=1)) 39 | label = torch.max(label, dim=1)[1] 40 | else: 41 | ignore = torch.eq(label, ignore_index).view(-1) 42 | 43 | correct = torch.eq(indices, label).view(-1) 44 | num_correct = torch.sum(correct) 45 | num_examples = logits.shape[0] - ignore.sum() 46 | 47 | return num_correct.float() / num_examples.float() 48 | -------------------------------------------------------------------------------- /utils/eval_data.py: -------------------------------------------------------------------------------- 1 | """MCJA/utils/eval_data.py 2 | It provides evaluation utilities for assessing the performance of cross-modal person re-identification methods. 3 | """ 4 | 5 | import os 6 | import logging 7 | import numpy as np 8 | import torch 9 | from torch.nn import functional as F 10 | from .mser_rerank import mser_rerank, pairwise_distance 11 | 12 | 13 | def get_gallery_names_sysu(perm, cams, ids, trial_id, num_shots=1): 14 | """ 15 | A utility function designed specifically for constructing a list of gallery image file paths for the SYSU-MM01 16 | dataset in a cross-modality person re-identification task. This function takes a permutation array that organizes 17 | the data into camera views and identities, and then selects a specified number of shots (instances) for each 18 | identity from the desired cameras to compile the gallery set for a given trial. The output is a list of formatted 19 | strings that represent the file paths to the selected gallery images, adhering to the dataset's directory structure. 20 | 21 | Args: 22 | - perm (list): A nested list where each element corresponds to a camera view in the dataset. 23 | Each camera's list contains arrays of instance indices for each identity, organized by trials. 24 | - cams (list): A list of integers indicating which camera views to include in the gallery set. 25 | Camera numbers should match those used in the dataset. 
26 | - ids (list): A list of integers specifying the identities to be included in the gallery set. 27 | Identity numbers should correspond to those in the dataset. 28 | - trial_id (int): The index of the trial for which to construct the gallery. 29 | This index is used to select specific instances from the permutation arrays, 30 | allowing for variability across different evaluation runs. 31 | - num_shots (int, optional): The number of shots to select for each identity from each camera view. 32 | 33 | Returns: 34 | - list: A list of strings, each representing the file path to an image selected for the gallery. 35 | Paths are formatted to match the directory structure of the SYSU-MM01 dataset. 36 | """ 37 | 38 | names = [] 39 | for cam in cams: 40 | cam_perm = perm[cam - 1][0].squeeze() 41 | for i in ids: 42 | if (i - 1) < len(cam_perm) and len(cam_perm[i - 1]) > 0: 43 | instance_id = cam_perm[i - 1][trial_id][:num_shots] 44 | names.extend(['cam{}/{:0>4d}/{:0>4d}'.format(cam, i, ins) for ins in instance_id.tolist()]) 45 | return names 46 | 47 | 48 | def get_unique(array): 49 | """ 50 | A utility function that returns a sorted unique array of elements from the input array. It identifies all unique 51 | elements within the input array and selects their first occurrence, preserving the order of these unique elements 52 | based on their initial appearance in the input array. This function is particularly useful for processing arrays 53 | where the uniqueness and order of elements are essential, such as when filtering duplicate entries from lists of 54 | identifiers or categories without disrupting their original sequence. 55 | 56 | Args: 57 | - array (ndarray): An input array from which unique elements are to be extracted. 58 | 59 | Returns: 60 | - ndarray: A new array containing only the unique elements of the input array, 61 | sorted according to their first occurrence in the original array. 62 | """ 63 | 64 | _, idx = np.unique(array, return_index=True) 65 | array_new = array[np.sort(idx)] 66 | return array_new 67 | 68 | 69 | def get_cmc(sorted_indices, query_ids, query_cam_ids, gallery_ids, gallery_cam_ids): 70 | """ 71 | A utility function for calculating the Cumulative Matching Characteristics (CMC) curve in person re-identification 72 | tasks. The CMC curve is a standard evaluation metric used to measure the performance of a re-identification model 73 | by determining the probability that a query identity appears in different sized candidate lists. This function 74 | processes the sorted indices of gallery samples for each query, excludes gallery samples captured by the same 75 | camera as the query to avoid camera bias, and computes the CMC curve based on the first correct match's position. 76 | 77 | Args: 78 | - sorted_indices (ndarray): An array of indices that sorts the gallery samples 79 | in ascending order of their distance to each query sample. 80 | - query_ids (ndarray): An array containing the identity labels of the query samples. 81 | - query_cam_ids (ndarray): An array containing the camera IDs associated with each query sample. 82 | - gallery_ids (ndarray): An array containing the identity labels of the gallery samples. 83 | - gallery_cam_ids (ndarray): An array containing the camera IDs associated with each gallery sample. 84 | 85 | Returns: 86 | - ndarray: The CMC curve represented as a 1D array where each element at index i indicates the probability 87 | that a query identity is correctly matched within the top-(i+1) ranks of the sorted gallery list. 
88 | """ 89 | 90 | gallery_unique_count = get_unique(gallery_ids).shape[0] 91 | match_counter = np.zeros((gallery_unique_count,)) 92 | 93 | result = gallery_ids[sorted_indices] 94 | cam_locations_result = gallery_cam_ids[sorted_indices] 95 | 96 | valid_probe_sample_count = 0 97 | 98 | for probe_index in range(sorted_indices.shape[0]): 99 | # remove gallery samples from the same camera of the probe 100 | result_i = result[probe_index, :] 101 | result_i[np.equal(cam_locations_result[probe_index], query_cam_ids[probe_index])] = -1 102 | 103 | # remove the -1 entries from the label result 104 | result_i = np.array([i for i in result_i if i != -1]) 105 | 106 | # remove duplicated id in "stable" manner - following the official test protocol in VI-ReID 107 | result_i_unique = get_unique(result_i) 108 | 109 | # match for probe i 110 | match_i = np.equal(result_i_unique, query_ids[probe_index]) 111 | 112 | if np.sum(match_i) != 0: # if there is true matching in gallery 113 | valid_probe_sample_count += 1 114 | match_counter += match_i 115 | 116 | rank = match_counter / valid_probe_sample_count 117 | cmc = np.cumsum(rank) 118 | return cmc 119 | 120 | 121 | def get_mAP(sorted_indices, query_ids, query_cam_ids, gallery_ids, gallery_cam_ids): 122 | """ 123 | A utility function for computing the mean Average Precision (mAP) for evaluating person re-identification methods. 124 | The mAP metric provides a single-figure measure of quality across recall levels, particularly useful in scenarios 125 | where the query identity appears multiple times in the gallery. This function iterates over each query, excludes 126 | gallery images captured by the same camera to prevent bias, and calculates the Average Precision (AP) for each 127 | query based on its matches in the gallery. The mAP is then obtained by averaging the APs across all queries. 128 | 129 | Args: 130 | - sorted_indices (ndarray): An array of indices that sorts the gallery samples 131 | in ascending order of their distance to each query sample. 132 | - query_ids (ndarray): An array containing the identity labels of the query samples. 133 | - query_cam_ids (ndarray): An array containing the camera IDs associated with each query sample. 134 | - gallery_ids (ndarray): An array containing the identity labels of the gallery samples. 135 | - gallery_cam_ids (ndarray): An array containing the camera IDs associated with each gallery sample. 136 | 137 | Returns: 138 | - float: The mean Average Precision (mAP) calculated across all query samples, 139 | representing the overall precision of the re-identification method at varying levels of recall. 
140 | """ 141 | 142 | result = gallery_ids[sorted_indices] 143 | cam_locations_result = gallery_cam_ids[sorted_indices] 144 | 145 | valid_probe_sample_count = 0 146 | avg_precision_sum = 0 147 | 148 | for probe_index in range(sorted_indices.shape[0]): 149 | # remove gallery samples from the same camera of the probe 150 | result_i = result[probe_index, :] 151 | result_i[cam_locations_result[probe_index, :] == query_cam_ids[probe_index]] = -1 152 | 153 | # remove the -1 entries from the label result 154 | result_i = np.array([i for i in result_i if i != -1]) 155 | 156 | # match for probe i 157 | match_i = result_i == query_ids[probe_index] 158 | true_match_count = np.sum(match_i) 159 | 160 | if true_match_count != 0: # if there is true matching in gallery 161 | valid_probe_sample_count += 1 162 | true_match_rank = np.where(match_i)[0] 163 | 164 | ap = np.mean(np.arange(1, true_match_count + 1) / (true_match_rank + 1)) 165 | avg_precision_sum += ap 166 | 167 | mAP = avg_precision_sum / valid_probe_sample_count 168 | return mAP 169 | 170 | 171 | def get_mINP(sorted_indices, query_ids, query_cam_ids, gallery_ids, gallery_cam_ids): 172 | """ 173 | A utility function designed for evaluating the mean Inverse Negative Penalty (mINP) across all query samples in 174 | a person re-identification method. The mINP metric focuses on the hardest positive sample's position in the ranked 175 | list of gallery samples for each query, providing insight into the system's ability to recall all relevant instances 176 | of an identity. This function computes the INP for each query by excluding gallery samples captured by the same 177 | camera as the query, identifying the rank position of the farthest correct match, and calculating the INP based on 178 | this position. The mean INP is then derived by averaging the INP scores across all valid queries. 179 | 180 | Args: 181 | - sorted_indices (ndarray): An array of indices that sorts the gallery samples 182 | in ascending order of their distance to each query sample. 183 | - query_ids (ndarray): An array containing the identity labels of the query samples. 184 | - query_cam_ids (ndarray): An array containing the camera IDs associated with each query sample. 185 | - gallery_ids (ndarray): An array containing the identity labels of the gallery samples. 186 | - gallery_cam_ids (ndarray): An array containing the camera IDs associated with each gallery sample. 187 | 188 | Returns: 189 | - float: The mean Inverse Negative Penalty (mINP) calculated across all queries, reflecting the method's 190 | effectiveness in retrieving relevant matches from gallery, particularly the most challenging matches to identify. 
191 | """ 192 | 193 | result = gallery_ids[sorted_indices] 194 | cam_locations_result = gallery_cam_ids[sorted_indices] 195 | 196 | valid_probe_sample_count = 0 197 | INP_sum = 0 198 | 199 | for probe_index in range(sorted_indices.shape[0]): 200 | # remove gallery samples from the same camera of the probe 201 | result_i = result[probe_index, :] 202 | result_i[np.equal(cam_locations_result[probe_index], query_cam_ids[probe_index])] = -1 203 | 204 | # remove the -1 entries from the label result 205 | result_i = np.array([i for i in result_i if i != -1]) 206 | 207 | # match for probe i 208 | match_i = result_i == query_ids[probe_index] 209 | true_match_count = np.sum(match_i) 210 | 211 | if true_match_count != 0: # if there is true matching in gallery 212 | valid_probe_sample_count += 1 213 | true_match_rank = np.where(match_i)[0] 214 | hardest_match_pos = true_match_rank[-1] + 1 215 | 216 | NP = (hardest_match_pos - true_match_count) / hardest_match_pos 217 | INP = 1 - NP 218 | 219 | INP_sum += INP 220 | 221 | mINP = INP_sum / valid_probe_sample_count 222 | 223 | return mINP 224 | 225 | 226 | def get_cmc_mAP_mINP(sorted_indices, query_ids, query_cam_ids, gallery_ids, gallery_cam_ids): 227 | """ 228 | A comprehensive utility function designed to compute three key metrics simultaneously for evaluating the performance 229 | of person re-identification methods: Cumulative Matching Characteristics (CMC), mean Average Precision (mAP), and 230 | mean Inverse Negative Penalty (mINP). This function integrates the processes of calculating these metrics into a 231 | single operation, optimizing the evaluation workflow for the person re-identification task. 232 | 233 | Args: 234 | - sorted_indices (ndarray): Indices sorting the gallery samples by ascending similarity to each query sample. 235 | - query_ids (ndarray): Identity labels for the query samples. 236 | - query_cam_ids (ndarray): Camera IDs from which each query sample was captured. 237 | - gallery_ids (ndarray): Identity labels for the gallery samples. 238 | - gallery_cam_ids (ndarray): Camera IDs from which each gallery sample was captured. 239 | 240 | Returns: 241 | - tuple: A tuple containing three elements: 242 | - cmc (ndarray): The CMC curve as a 1D array. 243 | - mAP (float): The mean Average Precision score. 244 | - mINP (float): The mean Inverse Negative Penalty score. 
245 | """ 246 | 247 | gallery_unique_count = get_unique(gallery_ids).shape[0] 248 | match_counter = np.zeros((gallery_unique_count,)) 249 | 250 | result = gallery_ids[sorted_indices] 251 | cam_locations_result = gallery_cam_ids[sorted_indices] 252 | 253 | valid_probe_sample_count = 0 254 | avg_precision_sum = 0 255 | INP_sum = 0 256 | 257 | for probe_index in range(sorted_indices.shape[0]): 258 | # remove gallery samples from the same camera of the probe 259 | result_i = result[probe_index, :] 260 | result_i[np.equal(cam_locations_result[probe_index], query_cam_ids[probe_index])] = -1 261 | 262 | # remove the -1 entries from the label result 263 | result_i = np.array([i for i in result_i if i != -1]) 264 | 265 | # remove duplicated id in "stable" manner - following the official test protocol in VI-ReID 266 | result_i_unique = get_unique(result_i) 267 | 268 | # match for probe i 269 | match_i = result_i == query_ids[probe_index] 270 | true_match_count = np.sum(match_i) 271 | match_i_unique = np.equal(result_i_unique, query_ids[probe_index]) 272 | 273 | if true_match_count != 0: # if there is true matching in gallery 274 | valid_probe_sample_count += 1 275 | if match_counter.shape != match_i_unique.shape: 276 | sub_num = match_counter.shape[0] - match_i_unique.shape[0] 277 | match_i_unique = np.hstack([match_i_unique, [False] * sub_num]) 278 | match_counter += match_i_unique 279 | true_match_rank = np.where(match_i)[0] 280 | ap = np.mean(np.arange(1, true_match_count + 1) / (true_match_rank + 1)) 281 | avg_precision_sum += ap 282 | hardest_match_pos = true_match_rank[-1] + 1 283 | NP = (hardest_match_pos - true_match_count) / hardest_match_pos 284 | INP = 1 - NP 285 | INP_sum += INP 286 | 287 | rank = match_counter / valid_probe_sample_count 288 | cmc = np.cumsum(rank) 289 | mAP = avg_precision_sum / valid_probe_sample_count 290 | mINP = INP_sum / valid_probe_sample_count 291 | return cmc, mAP, mINP 292 | 293 | 294 | def eval_sysu(query_feats, query_ids, query_cam_ids, query_img_paths, 295 | gallery_feats, gallery_ids, gallery_cam_ids, gallery_img_paths, 296 | perm, mode='all', num_shots=1, num_trials=10, mser=False): 297 | """ 298 | A versatile function designed for evaluating the performance of VI-ReID models on the SYSU-MM01 dataset, 299 | offering the flexibility to apply either a basic evaluation strategy or the Multi-Spectral Enhanced Ranking (MSER) 300 | strategy based on the specified parameters. This function orchestrates the evaluation process across multiple 301 | trials, adjusting for different experimental settings such as the evaluation mode and the number of gallery shots. 302 | 303 | Args: 304 | - query_feats (Tensor or List[Tensor]): Feature vectors of query images. 305 | - query_ids (Tensor or List[Tensor]): Identity labels associated with query images. 306 | - query_cam_ids (Tensor or List[Tensor]): Camera IDs from which each query image was captured. 307 | - query_img_paths (ndarray or List[ndarray]): File paths of query images. 308 | - gallery_feats (Tensor or List[Tensor]): Feature vectors of gallery images. 309 | - gallery_ids (Tensor or List[Tensor]): Identity labels associated with gallery images. 310 | - gallery_cam_ids (Tensor or List[Tensor]): Camera IDs from which each gallery image was captured. 311 | - gallery_img_paths (ndarray or List[ndarray]): File paths of gallery images. 312 | - perm (ndarray): A permutation array for determining gallery subsets in each trial. 313 | - mode (str): Specifies the subset of gallery images to use. 
314 | Options include 'indoor' for indoor cameras only and 'all' for all cameras. 315 | - num_shots (int): Number of instances of each identity to include in the gallery for each trial. 316 | - num_trials (int): Number of trials to perform, with each trial potentially using a different subset of gallery. 317 | - mser (bool): A flag indicating whether to use the MSER strategy for re-ranking gallery images. 318 | If set to False, a basic evaluation strategy is used. 319 | 320 | Returns: 321 | - tuple: A tuple containing average metrics of rank-1, rank-5, rank-10, rank-20 precision, mAP, and mINP 322 | across all trials, along with detailed rank results for further analysis. 323 | """ 324 | 325 | if mser: 326 | return eval_sysu_mser(query_feats, query_ids, query_cam_ids, query_img_paths, 327 | gallery_feats, gallery_ids, gallery_cam_ids, gallery_img_paths, 328 | perm, mode, num_shots, num_trials) 329 | return eval_sysu_base(query_feats, query_ids, query_cam_ids, query_img_paths, 330 | gallery_feats, gallery_ids, gallery_cam_ids, gallery_img_paths, 331 | perm, mode, num_shots, num_trials) 332 | 333 | 334 | def eval_sysu_base(query_feats, query_ids, query_cam_ids, query_img_paths, 335 | gallery_feats, gallery_ids, gallery_cam_ids, gallery_img_paths, 336 | perm, mode='all', num_shots=1, num_trials=10): 337 | """ 338 | A function designed for evaluating the performance of a VI-ReID model on the SYSU-MM01 dataset under specific 339 | experimental settings. This function conducts evaluations across multiple trials, each trial potentially utilizing 340 | a different subset of gallery images based on the specified mode (indoor or all locations) and the number of shots. 341 | It computes re-identification metrics including rank-1, rank-5, rank-10, rank-20 precision, mean Average Precision 342 | (mAP), and mean Inverse Negative Penalty (mINP) across all trials, averaging the results to provide a comprehensive 343 | assessment of the model's performance. 344 | 345 | Args: 346 | - query_feats (Tensor): The feature representations of query images. 347 | - query_ids (Tensor): The identity labels associated with query images. 348 | - query_cam_ids (Tensor): The camera IDs from which each query image was captured. 349 | - query_img_paths (ndarray): The file paths of query images. 350 | - gallery_feats (Tensor): The feature representations of gallery images. 351 | - gallery_ids (Tensor): The identity labels associated with gallery images. 352 | - gallery_cam_ids (Tensor): The camera IDs from which each gallery image was captured. 353 | - gallery_img_paths (ndarray): The file paths of gallery images. 354 | - perm (ndarray): A permutation array used for determining gallery subsets in each trial. 355 | - mode (str): Specifies subset of gallery images to use ('indoor' for indoor cameras only, 'all' for all cameras). 356 | - num_shots (int): The number of instances per identity (per cameras) to include in the gallery for each trial. 357 | - num_trials (int): The number of trials to perform, with each trial potentially using a different gallery subset. 358 | 359 | Returns: 360 | - tuple: A tuple containing the average values of rank-1, rank-5, rank-10, rank-20 precision, mAP, and mINP 361 | across all trials, along with detailed rank results for each trial. 
362 | """ 363 | 364 | assert mode in ['indoor', 'all'] 365 | 366 | gallery_cams = [1, 2] if mode == 'indoor' else [1, 2, 4, 5] 367 | 368 | # cam2 and cam3 are in the same location 369 | query_cam_ids[np.equal(query_cam_ids, 3)] = 2 370 | query_feats = F.normalize(query_feats, dim=1) 371 | 372 | gallery_indices = np.in1d(gallery_cam_ids, gallery_cams) 373 | 374 | gallery_feats = gallery_feats[gallery_indices] 375 | gallery_feats = F.normalize(gallery_feats, dim=1) 376 | gallery_ids = gallery_ids[gallery_indices] 377 | gallery_cam_ids = gallery_cam_ids[gallery_indices] 378 | gallery_img_paths = gallery_img_paths[gallery_indices] 379 | gallery_names = np.array(['/'.join(os.path.splitext(path)[0].split('/')[-3:]) for path in gallery_img_paths]) 380 | 381 | gallery_id_set = np.unique(gallery_ids) 382 | 383 | r1, r5, r10, r20, mAP, mINP = 0, 0, 0, 0, 0, 0 384 | rank_results = [] 385 | for t in range(num_trials): 386 | names = get_gallery_names_sysu(perm, gallery_cams, gallery_id_set, t, num_shots) 387 | flag = np.in1d(gallery_names, names) 388 | 389 | g_feats = gallery_feats[flag] 390 | g_ids = gallery_ids[flag] 391 | g_cam_ids = gallery_cam_ids[flag] 392 | g_img_paths = gallery_img_paths[flag] 393 | 394 | # dist_mat = pairwise_distance(query_feats, g_feats) # A 395 | dist_mat = -torch.mm(query_feats, g_feats.permute(1, 0)) # B 396 | # When using normalization on extracted features, these two distance measures are equivalent (A = 2 + 2 * B) 397 | # B is a little faster than A 398 | sorted_indices = np.argsort(dist_mat, axis=1) 399 | 400 | cur_rank_results = dict() 401 | cur_rank_results['query_ids'] = query_ids 402 | cur_rank_results['query_img_paths'] = query_img_paths 403 | cur_rank_results['gallery_ids'] = g_ids 404 | cur_rank_results['gallery_img_paths'] = g_img_paths 405 | cur_rank_results['dist_mat'] = dist_mat 406 | cur_rank_results['sorted_indices'] = sorted_indices 407 | rank_results.append(cur_rank_results) 408 | 409 | # cur_cmc = get_cmc(sorted_indices, query_ids, query_cam_ids, g_ids, g_cam_ids) 410 | # cur_mAP = get_mAP(sorted_indices, query_ids, query_cam_ids, g_ids, g_cam_ids) 411 | # cur_mINP = get_mINP(sorted_indices, query_ids, query_cam_ids, g_ids, g_cam_ids) 412 | cur_cmc, cur_mAP, cur_mINP = get_cmc_mAP_mINP(sorted_indices, query_ids, query_cam_ids, g_ids, g_cam_ids) 413 | r1 += cur_cmc[0] 414 | r5 += cur_cmc[4] 415 | r10 += cur_cmc[9] 416 | r20 += cur_cmc[19] 417 | mAP += cur_mAP 418 | mINP += cur_mINP 419 | 420 | r1 = r1 / num_trials * 100 421 | r5 = r5 / num_trials * 100 422 | r10 = r10 / num_trials * 100 423 | r20 = r20 / num_trials * 100 424 | mAP = mAP / num_trials * 100 425 | mINP = mINP / num_trials * 100 426 | 427 | logger = logging.getLogger('MCJA') 428 | logger.info('-' * 150) 429 | perf = '{} num-shot:{} r1 precision = {:.2f} , r10 precision = {:.2f} , r20 precision = {:.2f} , ' \ 430 | 'mAP = {:.2f} , mINP = {:.2f}' 431 | logger.info(perf.format(mode, num_shots, r1, r10, r20, mAP, mINP)) 432 | logger.info('-' * 150) 433 | 434 | return r1, r5, r10, r20, mAP, mINP, rank_results 435 | 436 | 437 | def eval_sysu_mser(query_feats_list, query_ids_list, query_cam_ids_list, query_img_paths_list, 438 | gallery_feats_list, gallery_ids_list, gallery_cam_ids_list, gallery_img_paths_list, 439 | perm, mode='all', num_shots=1, num_trials=10): 440 | """ 441 | A function designed to evaluate the performance of a VI-ReID model using the Multi-Spectral Enhanced Ranking (MSER) 442 | strategy on the SYSU-MM01 dataset. 
The MSER strategy involves a novel re-ranking process that enhances the initial 443 | ranking of gallery images based on their similarity to query images, considering multiple spectral representations. 444 | This evaluation function also supports variable experimental settings, such as different evaluation modes and the 445 | number of shots, across multiple trials for a comprehensive performance assessment. 446 | 447 | Args: 448 | - query_feats_list (List[Tensor]): List of tensors representing features of query images. 449 | - query_ids_list (List[Tensor]): List of tensors containing the identity labels of the query images. 450 | - query_cam_ids_list (List[Tensor]): List of tensors with camera IDs from which each query image was captured. 451 | - query_img_paths_list (List[ndarray]): List of ndarrays holding the file paths for each query image. 452 | - gallery_feats_list (List[Tensor]): List of tensors representing the feature vectors of gallery images. 453 | - gallery_ids_list (List[Tensor]): List of tensors containing the identity labels of the gallery images. 454 | - gallery_cam_ids_list (List[Tensor]): List of tensors with camera IDs from which each gallery image was captured. 455 | - gallery_img_paths_list (List[ndarray]): List of ndarrays holding the file paths for each gallery image. 456 | - perm (ndarray): A permutation array for determining the subsets of gallery images used in each trial. 457 | - mode (str): Specifies the subset of gallery images to use for evaluation. 458 | Options include 'indoor' for indoor cameras only and 'all' for all cameras. 459 | - num_shots (int): Specifies the number of instances of each identity to include in the gallery set for each trial. 460 | - num_trials (int): The number of trials to perform, with each trial using a different subset of gallery images. 461 | 462 | Returns: 463 | - tuple: A tuple containing the average metrics of rank-1, rank-5, rank-10, rank-20 precision, mean Average 464 | Precision (mAP), and mean Inverse Negative Penalty (mINP) across all trials. Additionally, detailed rank 465 | results for each trial are provided for further analysis. 
466 | """ 467 | 468 | assert mode in ['indoor', 'all'] 469 | 470 | gallery_cams = [1, 2] if mode == 'indoor' else [1, 2, 4, 5] 471 | 472 | list_num = len(query_feats_list) 473 | 474 | for c in range(list_num): 475 | # cam2 and cam3 are in the same location 476 | query_cam_ids_list[c][np.equal(query_cam_ids_list[c], 3)] = 2 477 | query_feats_list[c] = F.normalize(query_feats_list[c], dim=1) 478 | 479 | gallery_indices = np.in1d(gallery_cam_ids_list[c], gallery_cams) 480 | 481 | gallery_feats_list[c] = gallery_feats_list[c][gallery_indices] 482 | gallery_feats_list[c] = F.normalize(gallery_feats_list[c], dim=1) 483 | gallery_ids_list[c] = gallery_ids_list[c][gallery_indices] 484 | gallery_cam_ids_list[c] = gallery_cam_ids_list[c][gallery_indices] 485 | gallery_img_paths_list[c] = gallery_img_paths_list[c][gallery_indices] 486 | 487 | gallery_names = np.array( 488 | ['/'.join(os.path.splitext(path)[0].split('/')[-3:]) for path in gallery_img_paths_list[0]]) 489 | gallery_id_set = np.unique(gallery_ids_list[0]) 490 | 491 | r1, r5, r10, r20, mAP, mINP = 0, 0, 0, 0, 0, 0 492 | rank_results = [] 493 | for t in range(num_trials): 494 | names = get_gallery_names_sysu(perm, gallery_cams, gallery_id_set, t, num_shots) 495 | flag = np.in1d(gallery_names, names) 496 | 497 | g_feats_list, g_ids_list, g_cam_ids_list, g_img_paths_list = [], [], [], [] 498 | for c in range(list_num): 499 | g_feats = gallery_feats_list[c][flag] 500 | g_ids = gallery_ids_list[c][flag] 501 | g_cam_ids = gallery_cam_ids_list[c][flag] 502 | g_img_paths = gallery_img_paths_list[c][flag] 503 | g_feats_list.append(g_feats) 504 | g_ids_list.append(g_ids) 505 | g_cam_ids_list.append(g_cam_ids) 506 | g_img_paths_list.append(g_img_paths) 507 | 508 | dist_mat = mser_rerank(query_feats_list, g_feats_list, 509 | k1=40, k2=20, lambda_value=0.3, mode='i2v') 510 | sorted_indices = np.argsort(dist_mat, axis=1) 511 | 512 | cur_rank_results = dict() 513 | cur_rank_results['query_ids'] = query_ids_list[0] 514 | cur_rank_results['query_img_paths'] = query_img_paths_list[0] 515 | cur_rank_results['gallery_ids'] = g_ids_list[0] 516 | cur_rank_results['gallery_img_paths'] = g_img_paths_list[0] 517 | cur_rank_results['dist_mat'] = dist_mat 518 | cur_rank_results['sorted_indices'] = sorted_indices 519 | rank_results.append(cur_rank_results) 520 | 521 | # cur_cmc = get_cmc(sorted_indices, query_ids, query_cam_ids, g_ids, g_cam_ids) 522 | # cur_mAP = get_mAP(sorted_indices, query_ids, query_cam_ids, g_ids, g_cam_ids) 523 | # cur_mINP = get_mINP(sorted_indices, query_ids, query_cam_ids, g_ids, g_cam_ids) 524 | cur_cmc, cur_mAP, cur_mINP = get_cmc_mAP_mINP(sorted_indices, 525 | query_ids_list[0], query_cam_ids_list[0], 526 | g_ids_list[0], g_cam_ids_list[0]) 527 | r1 += cur_cmc[0] 528 | r5 += cur_cmc[4] 529 | r10 += cur_cmc[9] 530 | r20 += cur_cmc[19] 531 | mAP += cur_mAP 532 | mINP += cur_mINP 533 | 534 | r1 = r1 / num_trials * 100 535 | r5 = r5 / num_trials * 100 536 | r10 = r10 / num_trials * 100 537 | r20 = r20 / num_trials * 100 538 | mAP = mAP / num_trials * 100 539 | mINP = mINP / num_trials * 100 540 | 541 | logger = logging.getLogger('MCJA') 542 | logger.info('-' * 150) 543 | perf = '[MSER] {} num-shot:{} r1 precision = {:.2f} , r10 precision = {:.2f} , r20 precision = {:.2f} , ' \ 544 | 'mAP = {:.2f} , mINP = {:.2f}' 545 | logger.info(perf.format(mode, num_shots, r1, r10, r20, mAP, mINP)) 546 | logger.info('-' * 150) 547 | 548 | return r1, r5, r10, r20, mAP, mINP, rank_results 549 | 550 | 551 | def eval_regdb(query_feats, query_ids, 
query_cam_ids, query_img_paths, 552 | gallery_feats, gallery_ids, gallery_cam_ids, gallery_img_paths, mode='i2v', mser=False): 553 | """ 554 | A comprehensive function tailored for evaluating VI-ReID models on the RegDB dataset, which integrates both basic 555 | and advanced Multi-Spectral Enhanced Ranking (MSER) strategies for performance assessment. This evaluation mechanism 556 | is devised to accommodate the unique modality challenges presented by the RegDB dataset, specifically focusing on 557 | the thermal-to-visible (t2v or i2v) and visible-to-thermal (v2t or v2i) matching scenarios. By leveraging the 558 | flexibility in choosing between a straightforward evaluation approach and a new MSER strategy, this function enables 559 | analysis of model performance. 560 | 561 | Args: 562 | - query_feats (Tensor): The feature representations of query images. 563 | - query_ids (Tensor): The identity labels associated with each query image. 564 | - query_cam_ids (Tensor): The camera IDs from which each query image was captured, indicating the source modality. 565 | - query_img_paths (ndarray): The file paths for query images, useful for detailed analysis and debugging. 566 | - gallery_feats (Tensor): The feature representations of gallery images. 567 | - gallery_ids (Tensor): The identity labels for gallery images. 568 | - gallery_cam_ids (Tensor): The camera IDs for gallery images, highlighting the target modality for matching. 569 | - gallery_img_paths (ndarray): The file paths for gallery images, enabling precise tracking of evaluated samples. 570 | - mode (str): Determines the direction of modality matching, either 'i2v' or 'v2i'. 571 | - mser (bool): Indicates whether the Multi-Spectral Enhanced Ranking strategy should be applied. 572 | 573 | Returns: 574 | - tuple: Delivers evaluation metrics including rank-1, rank-5, rank-10, rank-20 precision, mean Average Precision 575 | (mAP), and mean Inverse Negative Penalty (mINP), alongside detailed ranking results for in-depth analysis. 576 | """ 577 | 578 | if mser: 579 | return eval_regdb_mser(query_feats, query_ids, query_cam_ids, query_img_paths, 580 | gallery_feats, gallery_ids, gallery_cam_ids, gallery_img_paths, mode) 581 | return eval_regdb_base(query_feats, query_ids, query_cam_ids, query_img_paths, 582 | gallery_feats, gallery_ids, gallery_cam_ids, gallery_img_paths, mode) 583 | 584 | 585 | def eval_regdb_base(query_feats, query_ids, query_cam_ids, query_img_paths, 586 | gallery_feats, gallery_ids, gallery_cam_ids, gallery_img_paths, mode='i2v'): 587 | """ 588 | A function specifically developed for evaluating VI-ReID models on the RegDB dataset, focusing on the thermal-to- 589 | visible (t2v or i2v) or visible-to-thermal (v2t or v2i) matching scenarios. This evaluation function computes the 590 | similarity between query and gallery features using cosine similarity, applies normalization to feature vectors, 591 | and ranks the gallery images based on their similarity to the query set. The performance is quantified using 592 | re-identification metrics such as rank-1, rank-5, rank-10, rank-20 precision, mean Average Precision (mAP), and 593 | mean Inverse Negative Penalty (mINP), providing a detailed analysis of the model's performance. 594 | 595 | Args: 596 | - query_feats (Tensor or List[Tensor]): Feature vectors of query images. 597 | - query_ids (Tensor or List[Tensor]): Identity labels associated with query images. 598 | - query_cam_ids (Tensor or List[Tensor]): Camera IDs from which each query image was captured. 
599 | - query_img_paths (ndarray or List[ndarray]): File paths of query images. 600 | - gallery_feats (Tensor or List[Tensor]): Feature vectors of gallery images. 601 | - gallery_ids (Tensor or List[Tensor]): Identity labels associated with gallery images. 602 | - gallery_cam_ids (Tensor or List[Tensor]): Camera IDs from which each gallery image was captured. 603 | - gallery_img_paths (ndarray or List[ndarray]): File paths of gallery images. 604 | - mode (str): The evaluation mode, either 'i2v' (infrared to visible) or 'v2i' (visible to infrared), 605 | dictating the direction of matching between the query and gallery sets. 606 | 607 | Returns: 608 | - tuple: A tuple containing metrics of rank-1, rank-5, rank-10, rank-20 precision, mAP, and mINP, 609 | along with detailed rank results for further analysis. 610 | """ 611 | 612 | assert mode in ['i2v', 'v2i'] 613 | 614 | gallery_feats = F.normalize(gallery_feats, dim=1) 615 | query_feats = F.normalize(query_feats, dim=1) 616 | 617 | # dist_mat = pairwise_distance(query_feats, gallery_feats) 618 | dist_mat = -torch.mm(query_feats, gallery_feats.t()) 619 | sorted_indices = np.argsort(dist_mat, axis=1) 620 | 621 | rank_results = [] 622 | cur_rank_results = dict() 623 | cur_rank_results['query_ids'] = query_ids 624 | cur_rank_results['query_img_paths'] = query_img_paths 625 | cur_rank_results['gallery_ids'] = gallery_ids 626 | cur_rank_results['gallery_img_paths'] = gallery_img_paths 627 | cur_rank_results['dist_mat'] = dist_mat 628 | cur_rank_results['sorted_indices'] = sorted_indices 629 | rank_results.append(cur_rank_results) 630 | 631 | # mAP = get_mAP(sorted_indices, query_ids, query_cam_ids, gallery_ids, gallery_cam_ids) 632 | # cmc = get_cmc(sorted_indices, query_ids, query_cam_ids, gallery_ids, gallery_cam_ids) 633 | # mINP = get_mINP(sorted_indices, query_ids, query_cam_ids, gallery_ids, gallery_cam_ids) 634 | cmc, mAP, mINP = get_cmc_mAP_mINP(sorted_indices, query_ids, query_cam_ids, gallery_ids, gallery_cam_ids) 635 | 636 | r1 = cmc[0] 637 | r5 = cmc[4] 638 | r10 = cmc[9] 639 | r20 = cmc[19] 640 | 641 | r1 = r1 * 100 642 | r5 = r5 * 100 643 | r10 = r10 * 100 644 | r20 = r20 * 100 645 | mAP = mAP * 100 646 | mINP = mINP * 100 647 | 648 | logger = logging.getLogger('MCJA') 649 | logger.info('-' * 150) 650 | perf = 'r1 precision = {:.2f} , r10 precision = {:.2f} , r20 precision = {:.2f} , ' \ 651 | 'mAP = {:.2f} , mINP = {:.2f}' 652 | logger.info(perf.format(r1, r10, r20, mAP, mINP)) 653 | logger.info('-' * 150) 654 | 655 | return r1, r5, r10, r20, mAP, mINP, rank_results 656 | 657 | 658 | def eval_regdb_mser(query_feats_list, query_ids_list, query_cam_ids_list, query_img_paths_list, 659 | gallery_feats_list, gallery_ids_list, gallery_cam_ids_list, gallery_img_paths_list, mode='i2v'): 660 | """ 661 | A function crafted for evaluating VI-ReID models on the RegDB dataset using the Multi-Spectral Enhanced Ranking 662 | (MSER) strategy, tailored specifically for the thermal-to-visible (t2v or i2v) or visible-to-thermal (v2t or v2i) 663 | matching scenarios. The MSER strategy applies a re-ranking mechanism to enhance the initial distance matrix 664 | computation, utilizing multiple spectral representations. This function performs normalization on the feature 665 | vectors of both query and gallery sets, computes the distance matrix with MSER re-ranking, and evaluates the 666 | model based on re-identification metrics. 667 | 668 | Args: 669 | - query_feats_list (List[Tensor]): List of tensors representing features of query images. 
670 | - query_ids_list (List[Tensor]): List of tensors containing the identity labels of the query images. 671 | - query_cam_ids_list (List[Tensor]): List of tensors with camera IDs from which each query image was captured. 672 | - query_img_paths_list (List[ndarray]): List of ndarrays holding the file paths for each query image. 673 | - gallery_feats_list (List[Tensor]): List of tensors representing the feature vectors of gallery images. 674 | - gallery_ids_list (List[Tensor]): List of tensors containing the identity labels of the gallery images. 675 | - gallery_cam_ids_list (List[Tensor]): List of tensors with camera IDs from which each gallery image was captured. 676 | - gallery_img_paths_list (List[ndarray]): List of ndarrays holding the file paths for each gallery image. 677 | - mode (str): The evaluation mode, either 'i2v' (infrared to visible) or 'v2i' (visible to infrared), 678 | dictating the direction of matching between query and gallery sets. 679 | 680 | Returns: 681 | - tuple: A tuple containing metrics of rank-1, rank-5, rank-10, rank-20 precision, mean Average Precision (mAP), 682 | and mean Inverse Negative Penalty (mINP), along with detailed rank results for further analysis. 683 | """ 684 | 685 | list_num = len(query_feats_list) 686 | for c in range(list_num): 687 | gallery_feats_list[c] = F.normalize(gallery_feats_list[c], dim=1) 688 | query_feats_list[c] = F.normalize(query_feats_list[c], dim=1) 689 | 690 | dist_mat = mser_rerank(query_feats_list, gallery_feats_list, 691 | k1=50, k2=10, lambda_value=0.2, eval_type=False, mode=mode) 692 | sorted_indices = np.argsort(dist_mat, axis=1) 693 | 694 | rank_results = [] 695 | cur_rank_results = dict() 696 | cur_rank_results['query_ids'] = query_ids_list[0] 697 | cur_rank_results['query_img_paths'] = query_img_paths_list[0] 698 | cur_rank_results['gallery_ids'] = gallery_ids_list[0] 699 | cur_rank_results['gallery_img_paths'] = gallery_img_paths_list[0] 700 | cur_rank_results['dist_mat'] = dist_mat[0] 701 | cur_rank_results['sorted_indices'] = sorted_indices[0] 702 | rank_results.append(cur_rank_results) 703 | 704 | # mAP = get_mAP(sorted_indices, query_ids, query_cam_ids, gallery_ids, gallery_cam_ids) 705 | # cmc = get_cmc(sorted_indices, query_ids, query_cam_ids, gallery_ids, gallery_cam_ids) 706 | # mINP = get_mINP(sorted_indices, query_ids, query_cam_ids, gallery_ids, gallery_cam_ids) 707 | cmc, mAP, mINP = get_cmc_mAP_mINP(sorted_indices, 708 | query_ids_list[0], query_cam_ids_list[0], 709 | gallery_ids_list[0], gallery_cam_ids_list[0]) 710 | 711 | r1 = cmc[0] 712 | r5 = cmc[4] 713 | r10 = cmc[9] 714 | r20 = cmc[19] 715 | 716 | r1 = r1 * 100 717 | r5 = r5 * 100 718 | r10 = r10 * 100 719 | r20 = r20 * 100 720 | mAP = mAP * 100 721 | mINP = mINP * 100 722 | 723 | logger = logging.getLogger('MCJA') 724 | logger.info('-' * 150) 725 | perf = '[MSER] r1 precision = {:.2f} , r10 precision = {:.2f} , r20 precision = {:.2f} , ' \ 726 | 'mAP = {:.2f} , mINP = {:.2f}' 727 | logger.info(perf.format(r1, r10, r20, mAP, mINP)) 728 | logger.info('-' * 150) 729 | 730 | return r1, r5, r10, r20, mAP, mINP, rank_results 731 | -------------------------------------------------------------------------------- /utils/mser_rerank.py: -------------------------------------------------------------------------------- 1 | """MCJA/utils/mser_rerank.py 2 | It introduces the Multi-Spectral Enhanced Ranking (MSER) re-ranking strategy. 
3 | """ 4 | 5 | import numpy as np 6 | import torch 7 | 8 | 9 | def pairwise_distance(query_features, gallery_features): 10 | """ 11 | A function that efficiently computes the pairwise Euclidean distances between two sets of features, typically used 12 | in the context of person re-identification tasks to measure similarities between query and gallery sets. This 13 | implementation leverages matrix operations for high performance, calculating the squared differences between 14 | each pair of features in the query and gallery feature tensors. 15 | 16 | Args: 17 | - query_features (Tensor): A tensor containing the feature vectors of the query set. 18 | Each row represents the feature vector of a query sample. 19 | - gallery_features (Tensor): A tensor containing the feature vectors of the gallery set. 20 | Each row represents the feature vector of a gallery sample. 21 | 22 | Returns: 23 | - Tensor: A matrix of distances where each element (i, j) represents the Euclidean distance between 24 | the i-th query feature and the j-th gallery feature. 25 | """ 26 | 27 | x = query_features 28 | y = gallery_features 29 | m, n = x.size(0), y.size(0) 30 | x = x.view(m, -1) 31 | y = y.view(n, -1) 32 | dist = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) + \ 33 | torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t() 34 | dist.addmm_(beta=1, alpha=-2, mat1=x, mat2=y.t()) 35 | return dist 36 | 37 | 38 | def mser_rerank(query_feats_list, g_feats_list, k1=40, k2=20, lambda_value=0.3, eval_type=True, mode='v2i'): 39 | """ 40 | A function designed to perform Multi-Spectral Enhanced Ranking (MSER) for cross-modality person re-identification. 41 | The MSER strategy advances the initial matching process by integrating a re-ranking mechanism that emphasizes 42 | reciprocal relationships and spectral characteristics. This approach recalculates the pairwise distances between 43 | query and gallery sets, refining the initial matches through k-reciprocal nearest neighbors and a Jaccard distance 44 | measure to produce a more accurate ranking. 45 | 46 | As mentioned in our paper, mser is based on the rerank[1] strategy in single-modality ReID and extends it to VI-ReID 47 | with multi-spectral information within images. Here, we utilize the rerank code from [2] in our implementation. 48 | 49 | Ref: 50 | [1] (Paper) Re-ranking Person Re-identification with k-Reciprocal Encoding, CVPR 2017. 51 | [2] (Code) https://github.com/DoubtedSteam/MPANet/blob/main/utils/rerank.py 52 | 53 | Args: 54 | - query_feats_list (List[Tensor]): List of tensors representing feature vectors of query images. 55 | - g_feats_list (List[Tensor]): List of tensors representing feature vectors of gallery images. 56 | - k1 (int): The primary parameter controlling the extent of k-reciprocal neighbors to consider, 57 | affecting the initial scope of re-ranking. 58 | - k2 (int): The secondary parameter influencing the expansion of reciprocal neighbors, 59 | further refining the selection based on mutual nearest neighbors. 60 | - lambda_value (float): A coefficient used to balance the original distance matrix with the Jaccard distance, 61 | adjusting the influence of each component in the final distance computation. 62 | - eval_type (bool): Indicates the type of evaluation to be performed. 63 | - mode (str): Specifies the modality matching direction, either 'i2v' for infrared-to-visible or 'v2i' for 64 | visible-to-infrared, adapting the function for different dataset characteristics. 
65 | 
66 |     Returns:
67 |     - numpy.ndarray: The re-ranked distance matrix, where each element reflects the recalculated distance between a
68 |       query and a gallery feature vector, with lower values denoting higher similarity.
69 |     """
70 | 
71 |     # Note: The MSER strategy requires more CPU memory.
72 | 
73 |     assert mode in ['i2v', 'v2i']
74 | 
75 |     if mode == 'i2v':
76 |         list_num = len(g_feats_list)
77 |         q_feat = query_feats_list[0]
78 |         feats = torch.cat([q_feat] + g_feats_list, 0)
79 |     else:  # mode == 'v2i'
80 |         list_num = len(query_feats_list)
81 |         q_feat = torch.cat(query_feats_list, 0)
82 |         g_feat = g_feats_list[0]
83 |         feats = torch.cat(query_feats_list + [g_feat], 0)
84 | 
85 |     dist = pairwise_distance(feats, feats)
86 |     # dist = -torch.mm(feats, feats.permute(1, 0))
87 |     original_dist = dist.clone().numpy()
88 |     original_dist = np.transpose(original_dist / np.max(original_dist, axis=0))
89 |     V = np.zeros_like(original_dist).astype(np.float16)
90 | 
91 |     query_num = q_feat.size(0)
92 |     all_num = original_dist.shape[0]
93 |     if eval_type:
94 |         dist[:, query_num:] = dist.max()
95 |     dist = dist.numpy()
96 |     initial_rank = np.argsort(dist).astype(np.int32)
97 | 
98 |     for i in range(all_num):
99 |         # k-reciprocal neighbors
100 |         forward_k_neigh_index = initial_rank[i, :k1 + 1]
101 |         backward_k_neigh_index = initial_rank[forward_k_neigh_index, :k1 + 1]
102 |         fi = np.where(backward_k_neigh_index == i)[0]
103 |         k_reciprocal_index = forward_k_neigh_index[fi]
104 |         k_reciprocal_expansion_index = k_reciprocal_index
105 | 
106 |         # for j in range(len(k_reciprocal_index)):
107 |         #     candidate = k_reciprocal_index[j]
108 |         #     candidate_forward_k_neigh_index = initial_rank[candidate, :int(np.around(k1 / 2)) + 1]
109 |         #     candidate_backward_k_neigh_index = initial_rank[candidate_forward_k_neigh_index,
110 |         #                                        :int(np.around(k1 / 2)) + 1]
111 |         #     fi_candidate = np.where(candidate_backward_k_neigh_index == candidate)[0]
112 |         #     candidate_k_reciprocal_index = candidate_forward_k_neigh_index[fi_candidate]
113 |         #     if len(np.intersect1d(candidate_k_reciprocal_index, k_reciprocal_index)) > 2 / 3 * len(
114 |         #             candidate_k_reciprocal_index):
115 |         #         k_reciprocal_expansion_index = np.append(k_reciprocal_expansion_index, candidate_k_reciprocal_index)
116 | 
117 |         k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index)
118 |         weight = np.exp(-original_dist[i, k_reciprocal_expansion_index])
119 |         V[i, k_reciprocal_expansion_index] = weight / np.sum(weight)
120 |     original_dist = original_dist[:query_num, ]
121 |     if k2 != 1:
122 |         V_qe = np.zeros_like(V, dtype=np.float16)
123 |         for i in range(all_num):
124 |             V_qe[i, :] = np.mean(V[initial_rank[i, :k2], :], axis=0)
125 |         V = V_qe
126 |         del V_qe
127 |     del initial_rank
128 |     invIndex = []  # row index
129 |     for i in range(all_num):
130 |         invIndex.append(np.where(V[:, i] != 0)[0])
131 | 
132 |     jaccard_dist = np.zeros_like(original_dist, dtype=np.float16)
133 | 
134 |     for i in range(query_num):
135 |         temp_min = np.zeros(shape=[1, all_num], dtype=np.float16)
136 |         indNonZero = np.where(V[i, :] != 0)[0]
137 |         indImages = [invIndex[ind] for ind in indNonZero]
138 |         for j in range(len(indNonZero)):
139 |             temp_min[0, indImages[j]] = temp_min[0, indImages[j]] + np.minimum(V[i, indNonZero[j]],
140 |                                                                               V[indImages[j], indNonZero[j]])
141 |         jaccard_dist[i] = 1 - temp_min / (2 - temp_min)
142 | 
143 |     final_dist = jaccard_dist * (1 - lambda_value) + original_dist * lambda_value
144 |     del original_dist
145 |     del V
146 |     del jaccard_dist
147 |     final_dist = final_dist[:query_num, query_num:]
148 | 
149 |     if mode == 'i2v':
150 |         final_dist = final_dist.reshape(query_num, list_num, -1)
151 |         final_dist = np.mean(final_dist, axis=1)
152 |     else:  # mode == 'v2i'
153 |         final_dist = final_dist
154 |         final_dist = final_dist.reshape(list_num, query_num // list_num, -1)
155 |         final_dist = np.mean(final_dist, axis=0)
156 |     return final_dist
157 | 
158 | 
159 | if __name__ == '__main__':
160 |     q_feat = torch.randn((8, 16))
161 |     g_feat = torch.randn((4, 16))
162 |     dist = mser_rerank([q_feat, q_feat, q_feat], [g_feat, g_feat, g_feat], k1=6, k2=4, mode='v2i')
163 |     dist = mser_rerank([q_feat, q_feat, q_feat], [g_feat, g_feat, g_feat], k1=6, k2=4, mode='i2v')
164 | 
--------------------------------------------------------------------------------
/utils/tools.py:
--------------------------------------------------------------------------------
1 | """MCJA/utils/tools.py
2 | This utility file provides some helper functions.
3 | """
4 | 
5 | import random
6 | import datetime
7 | import numpy as np
8 | import torch
9 | 
10 | 
11 | def set_seed(seed, cuda=True):
12 |     """
13 |     A utility function for setting the random seed across various libraries commonly used in deep learning projects to
14 |     ensure reproducibility of results. By fixing the random seed, this function makes experiments deterministic, meaning
15 |     that running the same code with the same inputs and the same seed (on the same experimental platform) will produce
16 |     the same outputs every time, which is crucial for debugging and comparing different models and configurations.
17 | 
18 |     Args:
19 |     - seed (int): The random seed value to be set across all libraries.
20 |     - cuda (bool, optional): A flag indicating whether to apply the seed to CUDA operations as well. Default is True.
21 |     """
22 | 
23 |     random.seed(seed)
24 |     np.random.seed(seed)
25 |     torch.manual_seed(seed)
26 |     if cuda:
27 |         torch.cuda.manual_seed_all(seed)
28 |         torch.backends.cudnn.deterministic = True
29 |         torch.backends.cudnn.benchmark = False
30 | 
31 | 
32 | def time_str(fmt=None):
33 |     """
34 |     A utility function for generating a formatted string representing the current date and time. This function is
35 |     particularly useful for creating timestamps for logging, file naming, or any other task that requires capturing
36 |     the exact moment when an event occurs. By default, the function produces a string formatted as "YYYY-MM-DD_hh-mm-ss",
37 |     but it allows for customization of the format according to the user's needs.
38 | 
39 |     Args:
40 |     - fmt (str, optional): A format string defining how the date and time should be represented.
41 |       This string should follow the formatting rules used by Python's `strftime` method. If no format is specified, the
42 |       default format "%Y-%m-%d_%H-%M-%S" is used, which corresponds to the "year-month-day_hour-minute-second" format.
43 | 
44 |     Returns:
45 |     - str: A string representation of the current date and time,
46 |       formatted according to the provided or default format specification.
47 |     """
48 | 
49 |     if fmt is None:
50 |         fmt = '%Y-%m-%d_%H-%M-%S'
51 |     return datetime.datetime.today().strftime(fmt)
52 | 
--------------------------------------------------------------------------------
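
To make the calling convention of `mser_rerank` concrete, here is a minimal usage sketch; it is not a file of the repository, the random tensors stand in for real multi-spectral features, and the choice of three spectral variants simply mirrors the smoke test in the `__main__` block of `mser_rerank.py` above.

# Editorial usage sketch -- not part of the repository. It mirrors how the
# evaluation code in utils/eval_data.py drives mser_rerank: L2-normalize the
# per-spectrum features, re-rank, then argsort the returned distance matrix.
import numpy as np
import torch
import torch.nn.functional as F

from utils.mser_rerank import mser_rerank  # run from the repository root

num_query, num_gallery, feat_dim, num_spectra = 8, 4, 16, 3

# One feature tensor per spectrum; in the real pipeline these come from the
# same backbone applied to different spectral versions of each image.
query_feats_list = [F.normalize(torch.randn(num_query, feat_dim), dim=1) for _ in range(num_spectra)]
gallery_feats_list = [F.normalize(torch.randn(num_gallery, feat_dim), dim=1) for _ in range(num_spectra)]

# 'v2i': the spectral copies of the queries are concatenated, re-ranked jointly,
# and the resulting distances are averaged back to one row per original query.
dist_mat = mser_rerank(query_feats_list, gallery_feats_list,
                       k1=6, k2=4, lambda_value=0.2, eval_type=False, mode='v2i')
assert dist_mat.shape == (num_query, num_gallery)
sorted_indices = np.argsort(dist_mat, axis=1)  # per-query ranking of gallery entries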