├── LICENSE ├── README.md ├── assets ├── ablation.png ├── radar.png └── sota.png ├── dataloader.py ├── evaluation.py ├── loss.py ├── model.py ├── opts.py ├── re_ranking.py ├── resnet.py ├── sampler.py ├── test.py ├── train.py └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # BUPTCampus 2 | BUPTCampus is a video-based visible-infrared dataset 3 | with approximately pixel-level aligned tracklet pairs 4 | and single-camera auxiliary samples. 5 | 6 | The [paper](https://ieeexplore.ieee.org/document/10335724) is accepted by **IEEE Transactions on Information Forensics & Security (TIFS)** 2023. 
7 | 8 | ![radar](assets/radar.png) 9 | 10 | ## Abstract 11 | 12 | Visible-infrared person re-identification (VI-ReID) aims to match persons captured by visible and infrared cameras, allowing person retrieval and tracking in 24-hour surveillance systems. Previous methods focus on learning from cross-modality person images in different cameras. However, temporal information and single-camera samples tend to be neglected. To crack this nut, in this paper, we first contribute a large-scale VI-ReID dataset named BUPTCampus. Different from most existing VI-ReID datasets, it 1) collects tracklets instead of images to introduce rich temporal information, 2) contains pixel-aligned cross-modality sample pairs for better modality-invariant learning, 3) provides one auxiliary set to help enhance the optimization, in which each identity only appears in a single camera. Based on our constructed dataset, we present a two-stream framework as baseline and apply Generative Adversarial Network (GAN) to narrow the gap between the two modalities. To exploit the advantages introduced by the auxiliary set, we propose a curriculum learning based strategy to jointly learn from both primary and auxiliary sets. Moreover, we design a novel temporal k-reciprocal re-ranking method to refine the ranking list with fine-grained temporal correlation cues. Experimental results demonstrate the effectiveness of the proposed methods. We also reproduce 9 state-of-the-art image-based and video-based VI-ReID methods on BUPTCampus and our methods show substantial superiority to them. The codes and dataset are available at: https://github.com/dyhBUPT/BUPTCampus. 13 | 14 | ## Experiments 15 | 16 | ![ablation](assets/ablation.png) 17 | 18 | ![sota](assets/sota.png) 19 | 20 | ## Data Preparation 21 | 22 | 1\. Download BUPTCampus from [baidu disk](https://pan.baidu.com/s/1GlAlNoSWUuvaGPjOzK4jqQ?pwd=bupt). The file structure should be: 23 | ``` 24 | path_to_dataset 25 | |—— DATA 26 | |—— data_paths.json 27 | |—— gallery.txt 28 | |—— query.txt 29 | |—— train.txt 30 | |—— train_auxiliary.txt 31 | ``` 32 | It contains all training/testing/auxiliary samples with 3,080 identities. 33 | Moreover, in addition to the original RGB/IR samples, 34 | the fake IR samples generated by our PairGAN module are also provided. A minimal Python sketch of how these annotation files are read is given at the end of this README. 35 | 36 | 2\. Set `--data_root` in `opts.py` to the path of your dataset. 37 | 38 | Please note the following license terms: 39 | - The dataset is for academic use only. Please do not use it for commercial purposes. 40 | - Please do not redistribute the dataset. 41 | - Please cite our paper if you use the dataset. 42 | 43 | By downloading our dataset, you agree to be bound by and comply with the license agreement. 44 | 45 | ## Requirements 46 | - torch==1.11.0 47 | - torchvision==0.12.0 48 | - tensorboard==2.10.0 49 | - numpy==1.23.1 50 | - Pillow==7.1.2 51 | 52 | ## Test 53 | For direct testing, please download our prepared checkpoints and extracted features from 54 | [baidu disk](https://pan.baidu.com/s/17yfHjKDhUevtfPLdgTMrNw?pwd=bupt). 55 | 56 | #### 1) Baseline 57 | 58 | Run the following command to load the checkpoint and obtain the baseline results. 59 | ```shell script 60 | python test.py --test_ckpt_path path/ckpt/ckpt_res34_real.pth 61 | ``` 62 | 63 | #### 2) AuxNet 64 | 65 | To reproduce the reported performance of AuxNet, 66 | you can directly use the following command to perform re-ranking based on our extracted features.
67 | ```shell script 68 | python re_ranking.py --test_feat_path path/feat 69 | ``` 70 | 71 | If you want to extract these features yourself, please use the following commands: 72 | ```shell script 73 | python test.py --test_ckpt_path path/ckpt/ckpt_res34_real_auxiliary.pth --test_frame_sample uniform-first_half-second_half --feature_postfix _real-aux 74 | python test.py --test_ckpt_path path/ckpt/ckpt_res34_fake_auxiliary.pth --test_frame_sample uniform-first_half-second_half --feature_postfix _fake-aux --fake 75 | ``` 76 | Then run `re_ranking.py` to get the final metrics of AuxNet. A hedged sketch of how the re-ranked distances are combined is given at the end of this README. 77 | 78 | ## Train 79 | 80 | #### 1) Baseline 81 | You can train our baseline model with: 82 | ```shell script 83 | python train.py --gpus 0,1 84 | ``` 85 | 86 | #### 2) AuxNet 87 | To obtain the full AuxNet model, you need to train it twice. 88 | First, train on `real RGB` & `real IR` samples with auxiliary learning: 89 | ```shell script 90 | python train.py --gpus 0,1 --auxiliary 91 | ``` 92 | This produces the checkpoint corresponding to the provided `ckpt_res34_real_auxiliary.pth`. 93 | 94 | Second, train on `fake IR` & `real IR` samples with auxiliary learning: 95 | ```shell script 96 | python train.py --gpus 0,1 --auxiliary --fake 97 | ``` 98 | This produces the checkpoint corresponding to the provided `ckpt_res34_fake_auxiliary.pth`. 99 | 100 | Finally, for evaluation, please refer to the `Test` section above (feature extraction + re_ranking). 101 | 102 | ## Citation 103 | ``` 104 | @ARTICLE{10335724, 105 | author={Du, Yunhao and Lei, Cheng and Zhao, Zhicheng and Dong, Yuan and Su, Fei}, 106 | journal={IEEE Transactions on Information Forensics and Security}, 107 | title={Video-Based Visible-Infrared Person Re-Identification With Auxiliary Samples}, 108 | year={2024}, 109 | volume={19}, 110 | number={}, 111 | pages={1313-1325}, 112 | doi={10.1109/TIFS.2023.3337972}} 113 | ``` 114 | 115 | ## Acknowledgement 116 | A large part of the code is borrowed from 117 | [DDAG](https://github.com/mangye16/DDAG) and [FastReID](https://github.com/JDAI-CV/fast-reid). 118 | Thanks for their excellent work!
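## Annotation Format (Illustrative Sketch)

The snippet below is a minimal sketch of how the files listed in `Data Preparation` fit together, mirroring the parsing logic in `dataloader.py`. The `data_root` string is a placeholder for your own `--data_root` value; nothing here is required by the official pipeline.

```python
import json
from os.path import join

data_root = 'path_to_dataset/DATA'  # placeholder: the folder you pass as --data_root

# Each line of train/query/gallery/train_auxiliary.txt is:
# "<person_id> <modality> <camera> <tracklet_id>"
# (for the training splits the modality field is 'RGB/IR', since both modalities exist)
with open(join(data_root, '../train.txt')) as f:
    obj_id, modality, camera, tracklet_id = f.readline().strip().split(' ')

# data_paths.json maps person_id -> modality -> camera -> tracklet_id -> list of frame paths
data_paths = json.load(open(join(data_root, '../data_paths.json')))
rgb_frames = data_paths[obj_id]['RGB'][camera][tracklet_id]  # paths relative to data_root
ir_frames = data_paths[obj_id]['IR'][camera][tracklet_id]
fake_ir_frames = [p.replace('/RGB/', '/FakeIR/') for p in rgb_frames]  # PairGAN outputs

print(obj_id, camera, tracklet_id, len(rgb_frames), len(ir_frames))
```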
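## Distance Composition in Re-Ranking (Illustrative Sketch)

`re_ranking.py` blends the plain Euclidean distance with a k-reciprocal (Jaccard) re-ranking distance, and additionally adds temporal cross terms computed from first-half/second-half features. The sketch below reproduces only the non-temporal part of that combination using the class and feature files from this repository; the feature directory is a placeholder, and the full per-modality evaluation loop and temporal terms are in `re_ranking.py` itself.

```python
import torch
from re_ranking import K_Reciprocal
from evaluation import evaluate, print_metrics
from opts import opt

directory = opt.test_feat_path or 'path/feat'  # placeholder for the extracted-feature folder

q_main = torch.load(f'{directory}/query_feats_real-aux_all.pth')
g_main = torch.load(f'{directory}/gallery_feats_real-aux_all.pth')
q_pids = torch.load(f'{directory}/query_pids.pth')
g_pids = torch.load(f'{directory}/gallery_pids.pth')

# Jaccard-only re-ranking distance (lambda_value=0), as in re_ranking.py
rerank = K_Reciprocal(k1=5, k2=3, lambda_value=0)
dist_jaccard = rerank(q_main, g_main)

# Plain Euclidean distance, reusing the helper of K_Reciprocal
dist_euclidean = rerank.get_original_distance(q_main.cpu().numpy(), g_main.cpu().numpy())

# Blend the two, following the lambda_1 weighting used in re_ranking.py
lambda_1 = 0.8
distance = lambda_1 * dist_euclidean + (1 - lambda_1) * dist_jaccard

cmc, mAP = evaluate(distance, q_pids, g_pids, opt)
print_metrics(cmc, mAP, prefix='ALL->ALL: ')
```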
119 | -------------------------------------------------------------------------------- /assets/ablation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dyhBUPT/BUPTCampus/db1f179292eda2ed4ef280d330a850c9a53fb44a/assets/ablation.png -------------------------------------------------------------------------------- /assets/radar.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dyhBUPT/BUPTCampus/db1f179292eda2ed4ef280d330a850c9a53fb44a/assets/radar.png -------------------------------------------------------------------------------- /assets/sota.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dyhBUPT/BUPTCampus/db1f179292eda2ed4ef280d330a850c9a53fb44a/assets/sota.png -------------------------------------------------------------------------------- /dataloader.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Author: Du Yunhao 3 | @Filename: dataloader.py 4 | @Contact: dyh_bupt@163.com 5 | @Time: 2022/8/29 19:42 6 | @Discription: Dataloader 7 | """ 8 | import json 9 | import math 10 | import torch 11 | import random 12 | import numpy as np 13 | from PIL import Image 14 | from os.path import join 15 | from numpy.random import choice 16 | import torchvision.transforms as T 17 | from torch.utils.data import Dataset, DataLoader 18 | 19 | from utils import * 20 | from sampler import RandomIdentitySampler, ConsistentModalitySampler, RandomCameraSampler 21 | 22 | 23 | def get_dataloader(opt, mode, show=False): 24 | if mode in ('train', 'auxiliary'): 25 | sample_mode = opt.train_frame_sample 26 | transform = get_transform(opt, 'train') 27 | elif mode in ('query', 'gallery'): 28 | sample_mode = opt.test_frame_sample 29 | transform = get_transform(opt, 'test') 30 | else: 31 | raise RuntimeError(f'Wrong dataloader mode {mode}.') 32 | 33 | if opt.dataset == 'BUPTCampus': 34 | dataset = BUPTCampus_Dataset( 35 | data_root=opt.data_root, 36 | mode=mode, 37 | sample=sample_mode, 38 | seq_len=opt.sequence_length, 39 | transform=transform, 40 | random_flip=opt.random_flip, 41 | fake=opt.fake, 42 | ) 43 | else: 44 | raise RuntimeError(f'Dataset {opt.dataset} is not supported for now.') 45 | 46 | if show: 47 | dataset.show_information() 48 | 49 | if mode == 'train': 50 | if opt.train_sampler is None: 51 | dataloader = DataLoader( 52 | dataset, 53 | batch_size=opt.train_bs, 54 | shuffle=True, 55 | drop_last=True, 56 | num_workers=opt.num_workers, 57 | ) 58 | elif opt.train_sampler == 'RandomIdentitySampler': 59 | sampler = RandomIdentitySampler( 60 | dataset, 61 | np=opt.train_bs // (opt.train_sampler_nc * opt.train_sampler_nt), 62 | nc=opt.train_sampler_nc, 63 | nt=opt.train_sampler_nt, 64 | ) 65 | dataloader = DataLoader( 66 | dataset, 67 | batch_size=opt.train_bs, 68 | sampler=sampler, 69 | drop_last=True, 70 | num_workers=opt.num_workers, 71 | ) 72 | elif mode == 'auxiliary': 73 | if opt.auxiliary_sampler is None: 74 | dataloader = DataLoader( 75 | dataset, 76 | batch_size=opt.train_bs, 77 | shuffle=True, 78 | drop_last=True, 79 | num_workers=opt.num_workers, 80 | ) 81 | elif opt.auxiliary_sampler == 'RandomCameraSampler': 82 | sampler = RandomCameraSampler( 83 | dataset, 84 | np=opt.train_bs // (opt.auxiliary_sampler_nc * opt.auxiliary_sampler_nt), 85 | nc=opt.auxiliary_sampler_nc, 86 | nt=opt.auxiliary_sampler_nt, 87 | ) 88 | dataloader = DataLoader( 89 
| dataset, 90 | batch_size=opt.train_bs, 91 | sampler=sampler, 92 | drop_last=False, 93 | num_workers=opt.num_workers, 94 | ) 95 | else: 96 | if opt.test_sampler is None: 97 | dataloader = DataLoader( 98 | dataset, 99 | batch_size=opt.test_bs, 100 | shuffle=False, 101 | drop_last=False, 102 | num_workers=opt.num_workers, 103 | ) 104 | elif opt.test_sampler == 'ConsistentModalitySampler': 105 | sampler = ConsistentModalitySampler( 106 | dataset, 107 | batch_size=opt.test_bs, 108 | ) 109 | dataloader = DataLoader( 110 | dataset, 111 | batch_size=opt.test_bs, 112 | sampler=sampler, 113 | drop_last=False, 114 | num_workers=opt.num_workers, 115 | ) 116 | return dataloader, dataset.get_class_num() 117 | 118 | 119 | class BUPTCampus_Dataset(Dataset): 120 | def __init__(self, data_root, mode, sample, seq_len, transform, random_flip=False, fake=False): 121 | """ 122 | :param data_root: 123 | :param mode: 'train', 'query', 'gallery' 124 | :param sample: 'dense', 'uniform' 125 | :param seq_len: 126 | :param transform: 127 | """ 128 | assert mode in ('train', 'query', 'gallery', 'auxiliary') 129 | self.mode = mode 130 | self.sample = sample 131 | self.seq_len = seq_len 132 | self.data_root = data_root 133 | self.transform = transform 134 | self.random_flip = random_flip 135 | self.fake = fake 136 | 137 | self.data_info = self.parse_data() 138 | self.data_paths = json.load(open(join(data_root, '../data_paths.json'))) 139 | 140 | self.pid2label = {pid: label for label, pid in enumerate(self.pids)} 141 | 142 | def parse_data(self): 143 | if self.mode == 'train': 144 | data_info = self._parse_data('../train.txt') 145 | elif self.mode == 'query': 146 | data_info = self._parse_data('../query.txt') 147 | elif self.mode == 'gallery': 148 | data_info = self._parse_data('../gallery.txt') 149 | elif self.mode == 'auxiliary': 150 | data_info = self._parse_data('../train_auxiliary.txt') 151 | return data_info 152 | 153 | def _parse_data(self, path): 154 | data_info, pids = [], [] 155 | path = join(self.data_root, path) 156 | with open(path) as f: 157 | for line in f.readlines(): 158 | obj_id, modality, camera, tracklet_id = line.strip().split(' ') 159 | data_info.append((obj_id, modality, camera, tracklet_id)) 160 | pids.append(obj_id) 161 | self.pids = sorted(set(pids)) 162 | return data_info 163 | 164 | def get_class_num(self): 165 | return len(self.pids) 166 | 167 | def fast_iteration(self): 168 | iteration = [ 169 | [self.pid2label[obj_id], MODALITY[modality], CAMERA[camera]] 170 | for (obj_id, modality, camera, tracklet_id) in self.data_info 171 | ] 172 | return iter(iteration) 173 | 174 | def __getitem__(self, index): 175 | obj_id, modality, camera, tracklet_id = self.data_info[index] 176 | 177 | if self.mode in ('train', 'auxiliary'): 178 | """ 179 | Please note that every sample has two modalities while training, 180 | which means that the final batch size is equal to `2*opt.train_bs` 181 | """ 182 | if modality == 'RGB/IR': 183 | data_paths_ir = self.data_paths[obj_id]['IR'][camera][tracklet_id] 184 | data_paths_rgb = self.data_paths[obj_id]['RGB'][camera][tracklet_id] 185 | if self.fake: 186 | data_paths_rgb = [x.replace('/RGB/', '/FakeIR/') for x in data_paths_rgb] 187 | tra_len = len(data_paths_ir) 188 | else: 189 | raise RuntimeError('Only modality RGB/IR is supported for training.') 190 | 191 | if self.sample == 'random': 192 | replace = tra_len < self.seq_len 193 | frame_idx = sorted(choice(range(tra_len), size=self.seq_len, replace=replace)) 194 | elif self.sample == 'restricted_random': 195 | 
frame_idx = list() 196 | if tra_len >= self.seq_len: 197 | step = tra_len / self.seq_len 198 | tra_idx = list(range(tra_len)) 199 | else: 200 | step = 1 201 | tra_idx = [0] * (self.seq_len - tra_len) + list(range(tra_len)) 202 | for i in range(self.seq_len): 203 | idx = tra_idx[int(i*step): int((i+1)*step)] 204 | frame_idx += random.sample(idx, 1) 205 | else: 206 | raise RuntimeError(f'Wrong sampling method {self.sample}.') 207 | 208 | images_ir = torch.stack( 209 | [self.transform( 210 | Image.open(join(self.data_root, data_paths_ir[idx])).convert('RGB')) 211 | for idx in frame_idx], 212 | dim=0 213 | ) # [T,C,H,W] 214 | images_rgb = torch.stack( 215 | [self.transform( 216 | Image.open(join(self.data_root, data_paths_rgb[idx])).convert('RGB')) 217 | for idx in frame_idx], 218 | dim=0 219 | ) # [T,C,H,W] 220 | 221 | # If set random_flip in self.transform instead, 222 | # frames within the same tracklet may have different directions 223 | if self.random_flip and torch.rand(1) < 0.5: 224 | images_ir = images_ir.flip(-1) 225 | images_rgb = images_rgb.flip(-1) 226 | 227 | label = self.pid2label[obj_id] 228 | 229 | return images_rgb, images_ir, label, CAMERA[camera] 230 | 231 | else: 232 | data_paths = self.data_paths[obj_id][modality][camera][tracklet_id] 233 | if self.fake: 234 | data_paths = [x.replace('/RGB/', '/FakeIR/') for x in data_paths] 235 | tra_len = len(data_paths) 236 | 237 | if self.sample == 'dense': 238 | ''' 239 | Sample all frames for a tracklet. 240 | Only batch_size=1 is supported for this mode. 241 | ''' 242 | frame_idx = range(tra_len) 243 | elif self.sample == 'uniform': 244 | ''' 245 | Uniform sampling frames for a tracklet. 246 | ''' 247 | frame_idx = np.linspace(0, tra_len, self.seq_len, endpoint=False, dtype=int) 248 | elif self.sample == 'first_half': 249 | frame_idx = np.linspace(0, tra_len//2, self.seq_len, endpoint=False, dtype=int) 250 | elif self.sample == 'second_half': 251 | frame_idx = np.linspace(tra_len//2, tra_len, self.seq_len, endpoint=False, dtype=int) 252 | else: 253 | raise RuntimeError(f'Wrong sampling method {self.sample}.') 254 | 255 | images = torch.stack( 256 | [self.transform(Image.open(join(self.data_root, data_paths[idx])).convert('RGB')) 257 | for idx in frame_idx], 258 | dim=0 259 | ) # [T,C,H,W] 260 | 261 | return images, int(obj_id), CAMERA[camera], MODALITY[modality] 262 | 263 | def __len__(self): 264 | return len(self.data_info) 265 | 266 | def show_information(self): 267 | if self.mode in ('train', 'auxiliary'): 268 | factor = 2 269 | else: 270 | factor = 1 271 | print( 272 | f"===> MCPRL-ReID Dataset ({self.mode}) <===\n" 273 | f"Number of identities: {len(self.pids)}\n" 274 | f"Number of samples : {len(self.data_info) * factor}" 275 | ) 276 | -------------------------------------------------------------------------------- /evaluation.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Author: Du Yunhao 3 | @Filename: evaluation.py 4 | @Contact: dyh_bupt@163.com 5 | @Time: 2022/8/30 16:39 6 | @Discription: evaluation 7 | """ 8 | import torch 9 | import numpy as np 10 | 11 | 12 | def print_metrics(cmc, ap, prefix=''): 13 | print( 14 | '{}mAP: {:.2%} | Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%} | Rank-20: {:.2%}.' 
15 | .format(prefix, ap, cmc[0], cmc[4], cmc[9], cmc[19]) 16 | ) 17 | 18 | 19 | def evaluate(distmat, query_pids, gallery_pids, opt): 20 | if isinstance(distmat, torch.Tensor): 21 | distmat = distmat.detach().cpu().numpy() 22 | if isinstance(query_pids, torch.Tensor): 23 | query_pids = query_pids.detach().cpu().numpy() 24 | if isinstance(gallery_pids, torch.Tensor): 25 | gallery_pids = gallery_pids.detach().cpu().numpy() 26 | 27 | num_q, num_g = distmat.shape 28 | assert num_q == len(query_pids) and num_g == len(gallery_pids) 29 | 30 | max_rank = min(opt.max_rank, num_g) 31 | 32 | indices = np.argsort(distmat, axis=1) 33 | matches = (gallery_pids[indices] == query_pids[:, np.newaxis]).astype(np.int32) 34 | 35 | num_valid_query = 0 36 | all_cmc, all_ap = [], [] 37 | for qi in range(num_q): 38 | orig_cmc = matches[qi] 39 | 40 | # This condition is true when the query doesn't appear in gallery. 41 | if not np.any(orig_cmc): 42 | continue 43 | 44 | cmc = orig_cmc.cumsum() 45 | cmc[cmc > 1] = 1 46 | all_cmc.append(cmc[:max_rank]) 47 | num_valid_query += 1. 48 | 49 | # compute average precision 50 | num_rel = orig_cmc.sum() 51 | tmp_cmc = orig_cmc.cumsum() 52 | tmp_cmc = [x / (i + 1.) for i, x in enumerate(tmp_cmc)] 53 | tmp_cmc = np.asarray(tmp_cmc) * orig_cmc 54 | ap = tmp_cmc.sum() / num_rel 55 | all_ap.append(ap) 56 | 57 | assert num_valid_query > 0, "No query appears in gallery." 58 | 59 | all_cmc = np.asarray(all_cmc).astype(np.float32) 60 | all_cmc = all_cmc.sum(0) / num_valid_query 61 | mAP = np.mean(all_ap) 62 | 63 | return all_cmc, mAP 64 | 65 | -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Author: Du Yunhao 3 | @Filename: loss.py 4 | @Contact: dyh_bupt@163.com 5 | @Time: 2022/9/1 16:39 6 | @Discription: loss 7 | Reference: 8 | - https://github.com/JDAI-CV/fast-reid/blob/master/fastreid/modeling/losses/triplet_loss.py 9 | """ 10 | import torch 11 | from torch import nn 12 | import torch.nn.functional as F 13 | 14 | torch.set_printoptions(threshold=1e6, sci_mode=False) 15 | 16 | def get_loss(loss): 17 | if loss == 'cross-entropy': 18 | return nn.CrossEntropyLoss() 19 | elif loss == 'triplet': 20 | return triplet_loss 21 | elif loss == 'auxiliary': 22 | return auxiliary_loss 23 | elif loss == 'l2': 24 | return nn.MSELoss() 25 | elif loss == 'l1': 26 | return nn.L1Loss() 27 | elif loss == 'kl': 28 | return kl_loss 29 | else: 30 | raise RuntimeError(f'Loss {loss} is not supported.') 31 | 32 | 33 | def euclidean_dist(x, y): 34 | m, n = x.size(0), y.size(0) 35 | xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n) 36 | yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t() 37 | dist = xx + yy - 2 * torch.matmul(x, y.t()) 38 | dist = dist.clamp(min=1e-12).sqrt() # for numerical stability 39 | return dist 40 | 41 | 42 | def cosine_dist(x, y): 43 | x = F.normalize(x, dim=1) 44 | y = F.normalize(y, dim=1) 45 | dist = 2 - 2 * torch.mm(x, y.t()) 46 | return dist 47 | 48 | 49 | def softmax_weights(dist, mask, return_scalar=False): 50 | if return_scalar: 51 | max_v = torch.max(dist * mask) 52 | diff = dist - max_v 53 | Z = torch.sum(torch.exp(diff) * mask) + 1e-6 54 | W = torch.exp(diff) * mask / Z 55 | return W 56 | else: 57 | max_v = torch.max(dist * mask, dim=1, keepdim=True)[0] 58 | diff = dist - max_v 59 | Z = torch.sum(torch.exp(diff) * mask, dim=1, keepdim=True) + 1e-6 # avoid division by zero 60 | W = torch.exp(diff) * mask / Z 61 | return W 62 | 63 
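# Illustrative sketch (not called anywhere in the pipeline): softmax_weights() above
# turns each row of masked distances into weights that sum to ~1 over the valid
# (mask == 1) entries; weighted_example_mining() below uses these weights to softly
# aggregate positive/negative distances instead of picking only the hardest one.
def _softmax_weights_example():
    dist = torch.tensor([[0.2, 1.5, 0.9],
                         [1.1, 0.3, 0.7]])
    mask = torch.tensor([[1., 0., 1.],
                         [0., 1., 1.]])
    w = softmax_weights(dist, mask)
    print(w)             # masked-out entries receive zero weight
    print(w.sum(dim=1))  # each row sums to ~1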
| 64 | def hard_example_mining(dist_mat, is_pos, is_neg): 65 | """For each anchor, find the hardest positive and negative sample. 66 | Args: 67 | dist_mat: pair wise distance between samples, shape [N, M] 68 | is_pos: positive index with shape [N, M] 69 | is_neg: negative index with shape [N, M] 70 | Returns: 71 | dist_ap: pytorch Variable, distance(anchor, positive); shape [N] 72 | dist_an: pytorch Variable, distance(anchor, negative); shape [N] 73 | p_inds: pytorch LongTensor, with shape [N]; 74 | indices of selected hard positive samples; 0 <= p_inds[i] <= N - 1 75 | n_inds: pytorch LongTensor, with shape [N]; 76 | indices of selected hard negative samples; 0 <= n_inds[i] <= N - 1 77 | NOTE: Only consider the case in which all labels have same num of samples, 78 | thus we can cope with all anchors in parallel. 79 | """ 80 | 81 | assert len(dist_mat.size()) == 2 82 | 83 | # `dist_ap` means distance(anchor, positive) 84 | # both `dist_ap` and `relative_p_inds` with shape [N] 85 | dist_ap, _ = torch.max(dist_mat * is_pos, dim=1) 86 | # `dist_an` means distance(anchor, negative) 87 | # both `dist_an` and `relative_n_inds` with shape [N] 88 | dist_an, _ = torch.min(dist_mat * is_neg + is_pos * 1e9, dim=1) 89 | 90 | return dist_ap, dist_an 91 | 92 | 93 | def weighted_example_mining(dist_mat, is_pos, is_neg): 94 | """For each anchor, find the weighted positive and negative sample. 95 | Args: 96 | dist_mat: pytorch Variable, pair wise distance between samples, shape [N, N] 97 | is_pos: 98 | is_neg: 99 | Returns: 100 | dist_ap: pytorch Variable, distance(anchor, positive); shape [N] 101 | dist_an: pytorch Variable, distance(anchor, negative); shape [N] 102 | """ 103 | assert len(dist_mat.size()) == 2 104 | 105 | is_pos = is_pos # [B,B], 0/1 106 | is_neg = is_neg # [B,B], 0/1 107 | dist_ap = dist_mat * is_pos # [B,B] 108 | dist_an = dist_mat * is_neg # [B,B] 109 | 110 | weights_ap = softmax_weights(dist_ap, is_pos) # [B,B] 111 | weights_an = softmax_weights(-dist_an, is_neg) # [B,B] 112 | 113 | dist_ap = torch.sum(dist_ap * weights_ap, dim=1) # [B,] 114 | dist_an = torch.sum(dist_an * weights_an, dim=1) # [B,] 115 | 116 | return dist_ap, dist_an 117 | 118 | 119 | def triplet_loss(embedding, targets, margin, norm_feat, hard_mining): 120 | """Modified from Tong Xiao's open-reid (https://github.com/Cysu/open-reid). 
121 | Related Triplet Loss theory can be found in paper 'In Defense of the Triplet 122 | Loss for Person Re-Identification'.""" 123 | 124 | if norm_feat: 125 | dist_mat = cosine_dist(embedding, embedding) 126 | else: 127 | dist_mat = euclidean_dist(embedding, embedding) 128 | 129 | N = dist_mat.size(0) 130 | is_pos = targets.view(N, 1).expand(N, N).eq(targets.view(N, 1).expand(N, N).t()).float() 131 | is_neg = targets.view(N, 1).expand(N, N).ne(targets.view(N, 1).expand(N, N).t()).float() 132 | 133 | if hard_mining: 134 | dist_ap, dist_an = hard_example_mining(dist_mat, is_pos, is_neg) # [B,] 135 | else: 136 | dist_ap, dist_an = weighted_example_mining(dist_mat, is_pos, is_neg) # [B,] 137 | 138 | y = dist_an.new().resize_as_(dist_an).fill_(1) # [B,] filled with 1 139 | 140 | if margin > 0: 141 | loss = F.margin_ranking_loss(dist_an, dist_ap, y, margin=margin) 142 | else: 143 | loss = F.soft_margin_loss(dist_an - dist_ap, y) 144 | # fmt: off 145 | if loss == float('Inf'): loss = F.margin_ranking_loss(dist_an, dist_ap, y, margin=0.3) 146 | # fmt: on 147 | 148 | return loss 149 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Author: Du Yunhao 3 | @Filename: model.py 4 | @Contact: dyh_bupt@163.com 5 | @Time: 2022/8/30 15:57 6 | @Discription: model 7 | """ 8 | import torch 9 | import numpy as np 10 | from torch import nn 11 | import torch.nn.functional as F 12 | import torch.utils.model_zoo as model_zoo 13 | 14 | from resnet import resnet34, resnet50, resnet101, remove_fc, model_urls 15 | from utils import * 16 | 17 | 18 | def get_model(opt, class_num=1, name='Baseline'): 19 | if name == 'Baseline': 20 | model = Baseline( 21 | class_num=class_num, 22 | backbone=opt.backbone, 23 | temporal=opt.temporal, 24 | one_stream=opt.one_stream, 25 | ) 26 | model.cuda() 27 | if opt.gpu_mode == 'dp': 28 | model = nn.DataParallel(model) 29 | return model 30 | 31 | 32 | class Normalize(nn.Module): 33 | def __init__(self, power=2): 34 | super(Normalize, self).__init__() 35 | self.power = power 36 | 37 | def forward(self, x): 38 | norm = x.pow(self.power).sum(1, keepdim=True).pow(1./self.power) 39 | return x / norm 40 | 41 | 42 | class BottleNeck(nn.Module): 43 | def __init__(self, feat_dim): 44 | super(BottleNeck, self).__init__() 45 | self.bn = nn.BatchNorm1d(feat_dim) 46 | self.bn.bias.requires_grad_(False) # no shiftgi 47 | self.bn.apply(weights_init_kaiming) 48 | 49 | def forward(self, x): 50 | return self.bn(x) 51 | 52 | 53 | class Classifier(nn.Module): 54 | def __init__(self, feat_dim, class_num, bias=False): 55 | super(Classifier, self).__init__() 56 | self.fc = nn.Linear(feat_dim, class_num, bias) 57 | self.fc.apply(weights_init_classifier) 58 | 59 | def forward(self, x): 60 | return self.fc(x) 61 | 62 | 63 | class modality_speficic_module(nn.Module): 64 | FLAG = False # 加载整个backbone 65 | def __init__(self, backbone='resnet50', input_channel=3): 66 | super(modality_speficic_module, self).__init__() 67 | pretrained = input_channel == 3 68 | if self.FLAG: 69 | self.backbone = eval(backbone)( 70 | pretrained=pretrained, 71 | last_conv_stride=1, 72 | last_conv_dilation=1, 73 | input_channel=input_channel, 74 | ) 75 | else: 76 | self.conv1 = nn.Conv2d(input_channel, 64, kernel_size=7, stride=2, padding=3, 77 | bias=False) 78 | self.bn1 = nn.BatchNorm2d(64) 79 | self.relu = nn.ReLU(inplace=True) 80 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 81 | 
for m in self.modules(): 82 | if isinstance(m, nn.Conv2d): 83 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 84 | m.weight.data.normal_(0, math.sqrt(2. / n)) 85 | elif isinstance(m, nn.BatchNorm2d): 86 | m.weight.data.fill_(1) 87 | m.bias.data.zero_() 88 | if pretrained: 89 | state_dict = remove_fc(model_zoo.load_url(model_urls[backbone])) 90 | self.load_state_dict(state_dict, strict=False) 91 | 92 | def forward(self, x): 93 | if self.FLAG: 94 | x = self.backbone.conv1(x) 95 | x = self.backbone.bn1(x) 96 | x = self.backbone.relu(x) 97 | x = self.backbone.maxpool(x) 98 | else: 99 | x = self.conv1(x) 100 | x = self.bn1(x) 101 | x = self.relu(x) 102 | x = self.maxpool(x) 103 | return x 104 | 105 | 106 | class modality_shared_module(nn.Module): 107 | def __init__(self, backbone='resnet50'): 108 | super(modality_shared_module, self).__init__() 109 | self.backbone = eval(backbone)( 110 | pretrained=True, 111 | last_conv_stride=1, 112 | last_conv_dilation=1 113 | ) 114 | 115 | def forward(self, x): 116 | x = self.backbone.layer1(x) 117 | x = self.backbone.layer2(x) 118 | x = self.backbone.layer3(x) 119 | x = self.backbone.layer4(x) 120 | return x 121 | 122 | 123 | class temporal_module(nn.Module): 124 | def __init__(self, method='gap', feat_dim=2048): 125 | super(temporal_module, self).__init__() 126 | self.method = method 127 | self.gap = nn.AdaptiveAvgPool1d(output_size=1) 128 | self.gmp = nn.AdaptiveMaxPool1d(output_size=1) 129 | if method == 'self-attention': 130 | self.transformer = nn.TransformerEncoderLayer( 131 | d_model=feat_dim, 132 | nhead=8, 133 | dim_feedforward=1024, 134 | dropout=0.1, 135 | activation='relu' 136 | ) 137 | 138 | def forward(self, x): 139 | """ 140 | :param x: shape [b,t,c] 141 | :return: shape [b,c] 142 | """ 143 | b, t, c = x.size() 144 | if self.method == 'gap': 145 | x = x.permute(0, 2, 1) 146 | x = self.gap(x) 147 | elif self.method == 'gmp': 148 | x = x.permute(0, 2, 1) 149 | x = self.gmp(x) 150 | elif self.method == 'self-attention': 151 | x = x + self.transformer(x) 152 | x = x.permute(0, 2, 1) 153 | x = self.gap(x) 154 | x = x.view(b, -1) 155 | return x 156 | 157 | 158 | class Baseline(nn.Module): 159 | def __init__(self, class_num, backbone='resnet50', temporal='gap', one_stream=False): 160 | super(Baseline, self).__init__() 161 | if backbone in ['resnet18', 'resnet34']: 162 | feat_dim = 512 163 | elif backbone in ['resnet50', 'resnet101', 'resnet152']: 164 | feat_dim = 2048 165 | else: 166 | raise RuntimeError('Wrong backbone.') 167 | self.one_stream = one_stream 168 | self.shared_module = modality_shared_module(backbone) 169 | self.ir_module = modality_speficic_module(backbone, 3) 170 | self.rgb_module = modality_speficic_module(backbone, 3) 171 | self.classifier = Classifier(feat_dim, class_num, bias=False) 172 | self.temporal_module = temporal_module(temporal, feat_dim) 173 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 174 | self.bottleneck = BottleNeck(feat_dim) 175 | self.l2norm = Normalize(2) 176 | 177 | def forward(self, x_rgb=None, x_ir=None, pids=None): 178 | # self.rgb_module = self.ir_module 179 | # [b,t,c,h,w] 180 | if x_rgb is not None and x_ir is not None: 181 | assert x_rgb.size() == x_ir.size() 182 | b, t, c, h, w = x_rgb.size() 183 | x_rgb = x_rgb.contiguous().view(-1, c, h, w) 184 | x_ir = x_ir.contiguous().view(-1, c, h, w) 185 | if self.one_stream: 186 | x_rgb = self.rgb_module(x_rgb) 187 | x_ir = self.rgb_module(x_ir) 188 | else: 189 | x_rgb = self.rgb_module(x_rgb) 190 | x_ir = self.ir_module(x_ir) 191 | 192 | x = 
torch.cat((x_rgb, x_ir), dim=0) 193 | 194 | elif x_rgb is not None: 195 | b, t, c, h, w = x_rgb.size() 196 | x_rgb = x_rgb.view(-1, c, h, w) 197 | x = self.rgb_module(x_rgb) 198 | elif x_ir is not None: 199 | b, t, c, h, w = x_ir.size() 200 | x_ir = x_ir.view(-1, c, h, w) 201 | if self.one_stream: 202 | x = self.rgb_module(x_ir) 203 | else: 204 | x = self.ir_module(x_ir) 205 | else: 206 | raise RuntimeError('Both x_rgb and x_ir are None.') 207 | 208 | x = self.shared_module(x) # [bt,c,h,w] e.g., [160,2048,16,8] 209 | features = self.avgpool(x).squeeze() # [bt,c] 210 | features = features.view(features.size(0)//t, t, -1) # [b,t,c] 211 | features = self.temporal_module(features) # [b,c] 212 | features_bn = self.bottleneck(features) 213 | 214 | if self.training: 215 | pids = pids.repeat(2) 216 | return features, self.classifier(features_bn), pids 217 | else: 218 | return self.l2norm(features_bn) 219 | -------------------------------------------------------------------------------- /opts.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Author: Du Yunhao 3 | @Filename: opts.py 4 | @Contact: dyh_bupt@163.com 5 | @Time: 2022/8/29 19:35 6 | @Discription: options 7 | """ 8 | import json 9 | import argparse 10 | from os.path import join 11 | 12 | 13 | class opts: 14 | def __init__(self): 15 | self.parser = argparse.ArgumentParser() 16 | 17 | # basic settings 18 | self.parser.add_argument('--gpus', type=str, default='0') 19 | self.parser.add_argument('--dataset', type=str, default='BUPTCampus') 20 | self.parser.add_argument('--gpu_mode', type=str, default='dp', help='single/dp/ddp') 21 | self.parser.add_argument('--data_root', type=str, default='/data1/dyh/data/VIData/BUPTCampus/DATA') 22 | self.parser.add_argument('--save_dir', type=str, default='/data1/dyh/results/BUPTCampus/tmp') 23 | self.parser.add_argument('--fake', action='store_true', default=False) 24 | self.parser.add_argument('--feature_postfix', type=str, default='') 25 | 26 | # basic parameters 27 | self.parser.add_argument('--num_workers', type=int, default=4) 28 | self.parser.add_argument('--sequence_length', type=int, default=10) 29 | self.parser.add_argument('--img_hw', nargs='+', type=int, default=(256, 128)) 30 | self.parser.add_argument('--norm_std', type=list, default=[0.229, 0.224, 0.225]) 31 | self.parser.add_argument('--norm_mean', type=list, default=[0.485, 0.456, 0.406]) 32 | 33 | # model 34 | self.parser.add_argument('--temporal', type=str, default='gap', help='gap/self-attention') 35 | self.parser.add_argument('--backbone', type=str, default='resnet34') 36 | self.parser.add_argument('--one_stream', action='store_true', default=False) 37 | 38 | # training 39 | self.parser.add_argument('--train_bs', type=int, default=16) 40 | self.parser.add_argument('--base_lr', type=float, default=2e-4) 41 | self.parser.add_argument('--max_epoch', type=int, default=100) 42 | self.parser.add_argument('--padding', type=int, default=10) 43 | self.parser.add_argument('--eval_freq', type=int, default=1) 44 | self.parser.add_argument('--warmup_epoch', type=int, default=0) 45 | self.parser.add_argument('--warmup_start_lr', type=float, default=1e-5) 46 | self.parser.add_argument('--optimizer', type=str, default='Adam') 47 | self.parser.add_argument('--cosine_end_lr', type=float, default=0.) 
48 | self.parser.add_argument('--weight_decay', type=float, default=1e-5) 49 | self.parser.add_argument('--train_print_freq', type=int, default=100) 50 | self.parser.add_argument('--triplet_margin', type=float, default=0.6) 51 | self.parser.add_argument('--triplet_hard', action='store_false', default=True) 52 | self.parser.add_argument('--train_frame_sample', type=str, default='random') 53 | self.parser.add_argument('--random_flip', action='store_false', default=True) 54 | self.parser.add_argument('--lambda_ce', type=float, default=1) 55 | self.parser.add_argument('--lambda_tri', type=float, default=1) 56 | 57 | # sampler 58 | self.parser.add_argument('--train_sampler', type=str, 59 | default='RandomIdentitySampler', help='None for shuffle') 60 | self.parser.add_argument('--train_sampler_nc', type=int, default=2) 61 | self.parser.add_argument('--train_sampler_nt', type=int, default=1) 62 | self.parser.add_argument('--auxiliary_sampler', type=str, default='RandomCameraSampler', 63 | help='None for shuffle, or RandomIdentitySampler/RandomCameraSampler') 64 | self.parser.add_argument('--auxiliary_sampler_nc', type=int, default=2) 65 | self.parser.add_argument('--auxiliary_sampler_nt', type=int, default=1) # 1 or 2 66 | self.parser.add_argument('--test_sampler', type=str, 67 | default='ConsistentModalitySampler', help='None for no shuffle') 68 | 69 | # Auxiliary 70 | self.parser.add_argument('--auxiliary', action='store_true', default=False) 71 | self.parser.add_argument('--aux_phi', type=float, default=3) 72 | 73 | # resume 74 | self.parser.add_argument('--resume_path', type=str, default='') 75 | 76 | # testing 77 | self.parser.add_argument('--test_bs', type=int, default=64) # Please don't change it. 78 | self.parser.add_argument('--test_frame_sample', type=str, default='uniform') 79 | self.parser.add_argument('--test_ckpt_path', type=str) 80 | self.parser.add_argument('--test_feat_path', type=str) 81 | 82 | # evaluation 83 | self.parser.add_argument('--max_rank', type=int, default=20) 84 | self.parser.add_argument('--distance', type=str, default='euclidean') 85 | self.parser.add_argument('--rerank_lambda', type=float, default=0.7) 86 | self.parser.add_argument('--rerank_k1', type=int, default=3) 87 | self.parser.add_argument('--rerank_k2', type=int, default=1) 88 | 89 | def parse(self, args=''): 90 | if args == '': 91 | opt = self.parser.parse_args() 92 | else: 93 | opt = self.parser.parse_args(args) 94 | 95 | return opt 96 | 97 | 98 | opt = opts().parse() 99 | -------------------------------------------------------------------------------- /re_ranking.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Author: Du Yunhao 3 | @Filename: re_ranking.py 4 | @Contact: dyh_bupt@163.com 5 | @Time: 2023/4/5 10:15 6 | @Discription: k-reciprocal re-ranking 7 | """ 8 | import os 9 | import torch 10 | import numpy as np 11 | from time import time 12 | 13 | from opts import opt 14 | from evaluation import evaluate, print_metrics 15 | from utils import * 16 | 17 | 18 | class K_Reciprocal: 19 | """reference: https://github.com/michuanhaohao/reid-strong-baseline/blob/master/utils/re_ranking.py""" 20 | def __init__(self, k1, k2, lambda_value, alpha=1/2, beta=2/3, distance='euclidean'): 21 | self.k1 = k1 22 | self.k2 = k2 23 | self.lambda_value = lambda_value 24 | self.alpha = alpha 25 | self.beta = beta 26 | self.distance = distance 27 | 28 | def get_original_distance(self, x, y, norm=False): 29 | assert self.distance in ('cosine', 'euclidean') 30 | fn_norm = 
lambda x: x / np.sqrt(np.sum(x ** 2, axis=1))[:, np.newaxis] 31 | if norm: 32 | x, y = fn_norm(x), fn_norm(y) 33 | if self.distance == 'cosine': 34 | return 1 - np.dot(x, y.T) 35 | elif self.distance == 'euclidean': 36 | return np.sqrt( 37 | np.sum(x ** 2, axis=1)[:, np.newaxis] + 38 | np.sum(y ** 2, axis=1)[np.newaxis, :] - 39 | 2 * np.dot(x, y.T) + 1e-5 40 | ) 41 | 42 | def get_jaccard_distance(self, q_num, g_num, features, fast_version=True): 43 | """ 44 | - fast_version: fast calculation based on some tricks. It runs much faster, but harder to read. 45 | """ 46 | jaccard_dist = np.ones((q_num, g_num), dtype=np.float16) 47 | if fast_version: 48 | q_non_zero_index = [np.where(features[i, :] != 0)[0] for i in range(q_num)] 49 | g_non_zero_index = [np.where(features[:, j] != 0)[0] for j in range(q_num + g_num)] 50 | for i, query_feature in enumerate(features[:q_num]): 51 | minimum = np.zeros(q_num + g_num, dtype=np.float16) 52 | q_non_zero_index_i = q_non_zero_index[i] 53 | indices = [g_non_zero_index[idx] for idx in q_non_zero_index_i] 54 | for j in range(len(q_non_zero_index_i)): 55 | minimum[indices[j]] += np.minimum( 56 | features[i, q_non_zero_index_i[j]], features[indices[j], q_non_zero_index_i[j]]) 57 | minimum = minimum[q_num:] 58 | jaccard_dist[i] = 1 - minimum / (2 - minimum) 59 | else: 60 | for i, query_feature in enumerate(features[:q_num]): 61 | for j, gallery_feature in enumerate(features[q_num:]): 62 | minimum = np.minimum(query_feature, gallery_feature).sum() 63 | maximum = np.maximum(query_feature, gallery_feature).sum() 64 | jaccard_dist[i, j] = 1 - minimum / maximum 65 | return jaccard_dist 66 | 67 | def get_k_reciprocal_index(self, query_index, ranking_list, k): 68 | forward_k_neighbor_index = ranking_list[query_index, :k + 1] # forward retrieval 69 | backward_k_neighbor_index = ranking_list[forward_k_neighbor_index, :k + 1] # backward retrieval 70 | k_reciprocal_row = np.where(backward_k_neighbor_index == query_index)[0] 71 | k_reciprocal_index = forward_k_neighbor_index[k_reciprocal_row] 72 | return k_reciprocal_index 73 | 74 | def _to_numpy(self, x): 75 | if isinstance(x, torch.Tensor): 76 | return x.cpu().numpy() 77 | 78 | def __call__(self, query_feats, gallery_feats): 79 | """ 80 | Call k-reciprocal re-ranking 81 | query_feats: [M,L] 82 | gallery_feats: [N,L] 83 | """ 84 | '''1) original distance''' 85 | q_feats, g_feats = self._to_numpy(query_feats), self._to_numpy(gallery_feats) 86 | all_feats = np.concatenate((q_feats, g_feats), axis=0) 87 | q_num, g_num, all_num = q_feats.shape[0], g_feats.shape[0], all_feats.shape[0] 88 | original_dist = self.get_original_distance(all_feats, all_feats) # [M+N, M+N] 89 | original_dist /= original_dist.max(axis=1)[:, np.newaxis] # row normalization 90 | original_rank = np.argsort(original_dist).astype(int) # original ranking list 91 | '''2) k-reciprocal features''' 92 | k_reciprocal_features = np.zeros_like(original_dist, dtype=np.float16) # i.e., `V` in paper 93 | for i in range(all_num): 94 | k_reciprocal_index = self.get_k_reciprocal_index(i, original_rank, k=self.k1) 95 | k_reciprocal_incremental_index = k_reciprocal_index.copy() # index after incrementally adding 96 | '''incrementally adding''' 97 | for j, candidate in enumerate(k_reciprocal_index): 98 | candidate_k_reciprocal_index = self.get_k_reciprocal_index( 99 | candidate, original_rank, k=int(round(self.k1 * self.alpha))) 100 | if len(np.intersect1d(k_reciprocal_index, candidate_k_reciprocal_index)) \ 101 | > self.beta * len(candidate_k_reciprocal_index): 102 | 
k_reciprocal_incremental_index = np.append( 103 | k_reciprocal_incremental_index, candidate_k_reciprocal_index) 104 | k_reciprocal_incremental_index = np.unique(k_reciprocal_incremental_index) 105 | '''compute ''' 106 | weight = np.exp(-original_dist[i, k_reciprocal_incremental_index]) # reassign weights with Gaussian kernel 107 | k_reciprocal_features[i, k_reciprocal_incremental_index] = weight / weight.sum() 108 | '''3) local query expansion''' 109 | if self.k2 != 1: 110 | k_reciprocal_expansion_features = np.zeros_like(k_reciprocal_features) 111 | for i in range(all_num): 112 | k_reciprocal_expansion_features[i, :] = \ 113 | np.mean(k_reciprocal_features[original_rank[i, :self.k2], :], axis=0) 114 | k_reciprocal_features = k_reciprocal_expansion_features 115 | '''4) Jaccard distance''' 116 | jaccard_dist = self.get_jaccard_distance(q_num, g_num, k_reciprocal_features, fast_version=True) 117 | return self.lambda_value * original_dist[:q_num, q_num:] + \ 118 | (1 - self.lambda_value) * jaccard_dist 119 | 120 | 121 | def get_features(mode): 122 | if mode == 'real': 123 | query_feats_main = torch.load(f'{directory}/query_feats_real_all.pth') 124 | query_feats_first = torch.load(f'{directory}/query_feats_real_first.pth') 125 | query_feats_second = torch.load(f'{directory}/query_feats_real_second.pth') 126 | gallery_feats_main = torch.load(f'{directory}/gallery_feats_real_all.pth') 127 | gallery_feats_first = torch.load(f'{directory}/gallery_feats_real_first.pth') 128 | gallery_feats_second = torch.load(f'{directory}/gallery_feats_real_second.pth') 129 | elif mode == 'real-aux': 130 | query_feats_main = torch.load(f'{directory}/query_feats_real-aux_all.pth') 131 | query_feats_first = torch.load(f'{directory}/query_feats_real-aux_first.pth') 132 | query_feats_second = torch.load(f'{directory}/query_feats_real-aux_second.pth') 133 | gallery_feats_main = torch.load(f'{directory}/gallery_feats_real-aux_all.pth') 134 | gallery_feats_first = torch.load(f'{directory}/gallery_feats_real-aux_first.pth') 135 | gallery_feats_second = torch.load(f'{directory}/gallery_feats_real-aux_second.pth') 136 | elif mode == 'real-fake': 137 | query_feats_main = torch.cat(( 138 | torch.load(f'{directory}/query_feats_real_all.pth'), 139 | torch.load(f'{directory}/query_feats_fake_all.pth') 140 | ), dim=1) 141 | query_feats_first = torch.cat(( 142 | torch.load(f'{directory}/query_feats_real_first.pth'), 143 | torch.load(f'{directory}/query_feats_fake_first.pth') 144 | ), dim=1) 145 | query_feats_second = torch.cat(( 146 | torch.load(f'{directory}/query_feats_real_second.pth'), 147 | torch.load(f'{directory}/query_feats_fake_second.pth') 148 | ), dim=1) 149 | gallery_feats_main = torch.cat(( 150 | torch.load(f'{directory}/gallery_feats_real_all.pth'), 151 | torch.load(f'{directory}/gallery_feats_fake_all.pth') 152 | ), dim=1) 153 | gallery_feats_first = torch.cat(( 154 | torch.load(f'{directory}/gallery_feats_real_first.pth'), 155 | torch.load(f'{directory}/gallery_feats_fake_first.pth') 156 | ), dim=1) 157 | gallery_feats_second = torch.cat(( 158 | torch.load(f'{directory}/gallery_feats_real_second.pth'), 159 | torch.load(f'{directory}/gallery_feats_fake_second.pth') 160 | ), dim=1) 161 | elif mode == 'real-fake-aux': 162 | query_feats_main = torch.cat(( 163 | torch.load(f'{directory}/query_feats_real-aux_all.pth'), 164 | torch.load(f'{directory}/query_feats_fake-aux_all.pth') 165 | ), dim=1) 166 | query_feats_first = torch.cat(( 167 | torch.load(f'{directory}/query_feats_real-aux_first.pth'), 168 | 
torch.load(f'{directory}/query_feats_fake-aux_first.pth') 169 | ), dim=1) 170 | query_feats_second = torch.cat(( 171 | torch.load(f'{directory}/query_feats_real-aux_second.pth'), 172 | torch.load(f'{directory}/query_feats_fake-aux_second.pth') 173 | ), dim=1) 174 | gallery_feats_main = torch.cat(( 175 | torch.load(f'{directory}/gallery_feats_real-aux_all.pth'), 176 | torch.load(f'{directory}/gallery_feats_fake-aux_all.pth') 177 | ), dim=1) 178 | gallery_feats_first = torch.cat(( 179 | torch.load(f'{directory}/gallery_feats_real-aux_first.pth'), 180 | torch.load(f'{directory}/gallery_feats_fake-aux_first.pth') 181 | ), dim=1) 182 | gallery_feats_second = torch.cat(( 183 | torch.load(f'{directory}/gallery_feats_real-aux_second.pth'), 184 | torch.load(f'{directory}/gallery_feats_fake-aux_second.pth') 185 | ), dim=1) 186 | return query_feats_main, query_feats_first, query_feats_second, \ 187 | gallery_feats_main, gallery_feats_first, gallery_feats_second 188 | 189 | 190 | if __name__ == '__main__': 191 | directory = opt.test_feat_path 192 | query_pids = torch.load(f'{directory}/query_pids.pth') 193 | query_modals = torch.load(f'{directory}/query_modals.pth') 194 | query_cids = torch.load(f'{directory}/query_cids.pth') 195 | gallery_pids = torch.load(f'{directory}/gallery_pids.pth') 196 | gallery_modals = torch.load(f'{directory}/gallery_modals.pth') 197 | gallery_cids = torch.load(f'{directory}/gallery_cids.pth') 198 | 199 | query_feats_main, query_feats_first, query_feats_second, \ 200 | gallery_feats_main, gallery_feats_first, gallery_feats_second, \ 201 | = get_features(mode='real-fake-aux') 202 | 203 | k_reciprocal = lambda x, y: K_Reciprocal(k1=5, k2=3, lambda_value=0)(x, y) 204 | 205 | lambda_1, lambda_2 = .8, .1 206 | 207 | start = time() 208 | for (q_modal, g_modal) in ((0, 0), (1, 1), (0, 1), (1, 0), (-1, -1)): 209 | if q_modal == -1: 210 | q_mask = query_modals >= q_modal 211 | g_mask = gallery_modals >= g_modal 212 | else: 213 | q_mask = query_modals == q_modal 214 | g_mask = gallery_modals == g_modal 215 | tmp_distance = euclidean_dist(query_feats_main[q_mask], gallery_feats_main[g_mask]) 216 | '''re-ranking''' 217 | if lambda_1 != 0: 218 | tmp_distance_main = k_reciprocal(query_feats_main[q_mask], gallery_feats_main[g_mask]) 219 | tmp_distance = tmp_distance * lambda_1 + tmp_distance_main * (1 - lambda_1) 220 | if lambda_2 != 0: 221 | tmp_distance_1to2 = k_reciprocal(query_feats_first[q_mask], gallery_feats_second[g_mask]) 222 | tmp_distance_2to1 = k_reciprocal(query_feats_second[q_mask], gallery_feats_first[g_mask]) 223 | tmp_distance += (tmp_distance_1to2 + tmp_distance_2to1) * lambda_2 224 | '''evaluate''' 225 | tmp_qid, tmp_gid = query_pids[q_mask], gallery_pids[g_mask] 226 | tmp_cmc, tmp_ap = evaluate(tmp_distance, tmp_qid, tmp_gid, opt) 227 | print_metrics( 228 | tmp_cmc, tmp_ap, 229 | prefix='{:<3}->{:<3}: '.format(MODALITY_[q_modal], MODALITY_[g_modal]) 230 | ) 231 | # print(time() - start) -------------------------------------------------------------------------------- /resnet.py: -------------------------------------------------------------------------------- 1 | """Copied from https://github.com/mangye16/DDAG/blob/master/resnet.py""" 2 | import torch.nn as nn 3 | import math 4 | import torch.utils.model_zoo as model_zoo 5 | 6 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 7 | 'resnet152', 'remove_fc', 'model_urls'] 8 | 9 | model_urls = { 10 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 11 | 'resnet34': 
'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 12 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 13 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 14 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 15 | } 16 | 17 | 18 | def conv3x3(in_planes, out_planes, stride=1, dilation=1): 19 | """3x3 convolution with padding""" 20 | # original padding is 1; original dilation is 1 21 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 22 | padding=dilation, bias=False, dilation=dilation) 23 | 24 | 25 | class BasicBlock(nn.Module): 26 | expansion = 1 27 | 28 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1): 29 | super(BasicBlock, self).__init__() 30 | self.conv1 = conv3x3(inplanes, planes, stride, dilation) 31 | self.bn1 = nn.BatchNorm2d(planes) 32 | self.relu = nn.ReLU(inplace=True) 33 | self.conv2 = conv3x3(planes, planes) 34 | self.bn2 = nn.BatchNorm2d(planes) 35 | self.downsample = downsample 36 | self.stride = stride 37 | 38 | def forward(self, x): 39 | residual = x 40 | 41 | out = self.conv1(x) 42 | out = self.bn1(out) 43 | out = self.relu(out) 44 | 45 | out = self.conv2(out) 46 | out = self.bn2(out) 47 | 48 | if self.downsample is not None: 49 | residual = self.downsample(x) 50 | 51 | out += residual 52 | out = self.relu(out) 53 | 54 | return out 55 | 56 | 57 | class Bottleneck(nn.Module): 58 | expansion = 4 59 | 60 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1): 61 | super(Bottleneck, self).__init__() 62 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 63 | self.bn1 = nn.BatchNorm2d(planes) 64 | # original padding is 1; original dilation is 1 65 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=dilation, bias=False, dilation=dilation) 66 | self.bn2 = nn.BatchNorm2d(planes) 67 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 68 | self.bn3 = nn.BatchNorm2d(planes * 4) 69 | self.relu = nn.ReLU(inplace=True) 70 | self.downsample = downsample 71 | self.stride = stride 72 | 73 | def forward(self, x): 74 | residual = x 75 | 76 | out = self.conv1(x) 77 | out = self.bn1(out) 78 | out = self.relu(out) 79 | 80 | out = self.conv2(out) 81 | out = self.bn2(out) 82 | out = self.relu(out) 83 | 84 | out = self.conv3(out) 85 | out = self.bn3(out) 86 | 87 | if self.downsample is not None: 88 | residual = self.downsample(x) 89 | 90 | out += residual 91 | out = self.relu(out) 92 | 93 | return out 94 | 95 | 96 | class ResNet(nn.Module): 97 | 98 | def __init__(self, block, layers, last_conv_stride=2, last_conv_dilation=1, input_channel=3): 99 | 100 | self.inplanes = 64 101 | super(ResNet, self).__init__() 102 | self.conv1 = nn.Conv2d(input_channel, 64, kernel_size=7, stride=2, padding=3, 103 | bias=False) 104 | self.bn1 = nn.BatchNorm2d(64) 105 | self.relu = nn.ReLU(inplace=True) 106 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 107 | self.layer1 = self._make_layer(block, 64, layers[0]) 108 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 109 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 110 | self.layer4 = self._make_layer(block, 512, layers[3], stride=last_conv_stride, dilation=last_conv_dilation) 111 | 112 | for m in self.modules(): 113 | if isinstance(m, nn.Conv2d): 114 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 115 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 116 | elif isinstance(m, nn.BatchNorm2d): 117 | m.weight.data.fill_(1) 118 | m.bias.data.zero_() 119 | 120 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1): 121 | downsample = None 122 | if stride != 1 or self.inplanes != planes * block.expansion: 123 | downsample = nn.Sequential( 124 | nn.Conv2d(self.inplanes, planes * block.expansion, 125 | kernel_size=1, stride=stride, bias=False), 126 | nn.BatchNorm2d(planes * block.expansion), 127 | ) 128 | 129 | layers = [] 130 | layers.append(block(self.inplanes, planes, stride, downsample, dilation)) 131 | self.inplanes = planes * block.expansion 132 | for i in range(1, blocks): 133 | layers.append(block(self.inplanes, planes)) 134 | 135 | return nn.Sequential(*layers) 136 | 137 | def forward(self, x): 138 | x = self.conv1(x) 139 | x = self.bn1(x) 140 | x = self.relu(x) 141 | x = self.maxpool(x) 142 | 143 | x = self.layer1(x) 144 | x = self.layer2(x) 145 | x = self.layer3(x) 146 | x = self.layer4(x) 147 | 148 | return x 149 | 150 | 151 | def remove_fc(state_dict): 152 | """Remove the fc layer parameters from state_dict.""" 153 | # for key, value in state_dict.items(): 154 | for key, value in list(state_dict.items()): 155 | if key.startswith('fc.'): 156 | del state_dict[key] 157 | return state_dict 158 | 159 | 160 | def resnet18(pretrained=False, **kwargs): 161 | """Constructs a ResNet-18 model. 162 | Args: 163 | pretrained (bool): If True, returns a model pre-trained on ImageNet 164 | """ 165 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 166 | if pretrained: 167 | model.load_state_dict(remove_fc(model_zoo.load_url(model_urls['resnet18']))) 168 | return model 169 | 170 | 171 | def resnet34(pretrained=False, **kwargs): 172 | """Constructs a ResNet-34 model. 173 | Args: 174 | pretrained (bool): If True, returns a model pre-trained on ImageNet 175 | """ 176 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 177 | if pretrained: 178 | model.load_state_dict(remove_fc(model_zoo.load_url(model_urls['resnet34']))) 179 | return model 180 | 181 | 182 | def resnet50(pretrained=False, **kwargs): 183 | """Constructs a ResNet-50 model. 184 | Args: 185 | pretrained (bool): If True, returns a model pre-trained on ImageNet 186 | """ 187 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 188 | if pretrained: 189 | model.load_state_dict(remove_fc(model_zoo.load_url(model_urls['resnet50']))) 190 | return model 191 | 192 | 193 | def resnet101(pretrained=False, **kwargs): 194 | """Constructs a ResNet-101 model. 195 | Args: 196 | pretrained (bool): If True, returns a model pre-trained on ImageNet 197 | """ 198 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 199 | if pretrained: 200 | model.load_state_dict( 201 | remove_fc(model_zoo.load_url(model_urls['resnet101']))) 202 | return model 203 | 204 | 205 | def resnet152(pretrained=False, **kwargs): 206 | """Constructs a ResNet-152 model. 
207 | Args: 208 | pretrained (bool): If True, returns a model pre-trained on ImageNet 209 | """ 210 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 211 | if pretrained: 212 | model.load_state_dict( 213 | remove_fc(model_zoo.load_url(model_urls['resnet152']))) 214 | return model -------------------------------------------------------------------------------- /sampler.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Author: Du Yunhao 3 | @Filename: sampler.py 4 | @Contact: dyh_bupt@163.com 5 | @Time: 2022/9/2 17:09 6 | @Discription: sampler 7 | """ 8 | import copy 9 | import random 10 | from collections import defaultdict 11 | from torch.utils.data.sampler import Sampler 12 | 13 | from utils import * 14 | 15 | 16 | class RandomIdentitySampler(Sampler): 17 | """ 18 | This sampler is for training. 19 | For each batch, randomly sample Np identities. 20 | For each identity, randomly sample Nc cameras. 21 | For each camera, randomly sample Nt tracklets. 22 | Note that each tracklet has two modalities, i.e., RGB and IR. 23 | So, the final batch size is equal to `Np * Nc * Nt * 2` 24 | """ 25 | def __init__(self, dataset, np, nc, nt): 26 | self.np = np 27 | self.nc = nc 28 | self.nt = nt 29 | self.dataset = dataset 30 | 31 | # This line aims to get the self.length 32 | self.final_idx = self._get_final_idx() 33 | 34 | def _get_final_idx(self): 35 | # pid -> cam -> index 36 | index_dict = defaultdict(lambda: defaultdict(list)) 37 | for index, (pid, modal, cam) in enumerate(self.dataset.fast_iteration()): 38 | index_dict[pid][cam].append(index) 39 | 40 | # pid -> batch_idx (e.g., [[0,1], [2,3]]) 41 | pid2batch = defaultdict(list) 42 | for pid, cam2idx in index_dict.items(): 43 | # store those cameras with enough tracklets 44 | available_cameras = [cam for cam in cam2idx 45 | if len(cam2idx[cam]) >= self.nt] 46 | while len(available_cameras) >= self.nc: 47 | batch_idx = [] 48 | cameras = random.sample(available_cameras, self.nc) 49 | for camera in cameras: 50 | sampled_index = random.sample(cam2idx[camera], self.nt) 51 | batch_idx.extend(sampled_index) 52 | cam2idx[camera] = [idx for idx in cam2idx[camera] 53 | if idx not in sampled_index] 54 | if len(cam2idx[camera]) < self.nt: 55 | available_cameras.remove(camera) 56 | pid2batch[pid].append(batch_idx) 57 | 58 | # generate final idx 59 | final_idx = [] 60 | available_pids = copy.deepcopy(list(pid2batch)) 61 | while len(available_pids) >= self.np: 62 | sampled_pids = random.sample(available_pids, self.np) 63 | for pid in sampled_pids: 64 | batch_idx = pid2batch[pid].pop(0) 65 | final_idx.extend(batch_idx) 66 | if len(pid2batch[pid]) == 0: 67 | available_pids.remove(pid) 68 | 69 | self.length = len(final_idx) 70 | return final_idx 71 | 72 | def __iter__(self): 73 | """ 74 | Call self._get_final_idx() in __iter__, 75 | to avoid the same sampling results in all epochs. 76 | """ 77 | return iter(self._get_final_idx()) 78 | 79 | def __len__(self): 80 | return self.length 81 | 82 | 83 | class ConsistentModalitySampler(Sampler): 84 | """ 85 | This sampler is for validation. 86 | It ensures the same modality in one batch. 
87 | """ 88 | def __init__(self, dataset, batch_size): 89 | self.dataset = dataset 90 | self.batch_size = batch_size 91 | 92 | # This line aims to get the self.length 93 | self.final_idx = self._get_final_idx() 94 | 95 | def _get_final_idx(self): 96 | # modality -> index 97 | index_dict = defaultdict(list) 98 | for index, (pid, modal, cam) in enumerate(self.dataset.fast_iteration()): 99 | index_dict[modal].append(index) 100 | 101 | # batch_idx (e.g., [[0,1], [2,3]]) 102 | if self.batch_size > 1: 103 | idx_ir = index_dict[MODALITY['IR']] 104 | idx_rgb = index_dict[MODALITY['RGB']] 105 | dropped = len(idx_rgb) % self.batch_size 106 | idx_rgb = idx_rgb[:len(idx_rgb) - dropped] # Warning: This will drop some samples (slicing by length also keeps the full list when dropped == 0) 107 | final_idx = idx_rgb + idx_ir 108 | 109 | self.length = len(final_idx) 110 | return final_idx 111 | 112 | def __iter__(self): 113 | """ 114 | Call self._get_final_idx() in __iter__, 115 | to avoid the same sampling results in all epochs. 116 | """ 117 | return iter(self._get_final_idx()) 118 | 119 | def __len__(self): 120 | return self.length 121 | 122 | 123 | class RandomCameraSampler(Sampler): 124 | """ 125 | This sampler is for auxiliary training. 126 | For each batch, randomly sample Nc cameras. 127 | For each camera, randomly sample Np identities. 128 | For each identity, randomly sample Nt tracklets. 129 | Note that each tracklet has two modalities, i.e., RGB and IR. 130 | So, the final batch size is equal to `Nc * Np * Nt * 2` 131 | """ 132 | def __init__(self, dataset, nc, np, nt): 133 | self.nc = nc 134 | self.np = np 135 | self.nt = nt 136 | self.dataset = dataset 137 | 138 | # This line aims to get the self.length 139 | self.final_idx = self._get_final_idx() 140 | 141 | def _get_final_idx(self): 142 | # cam -> pid -> index 143 | index_dict = defaultdict(lambda: defaultdict(list)) 144 | for index, (pid, modal, cam) in enumerate(self.dataset.fast_iteration()): 145 | index_dict[cam][pid].append(index) 146 | 147 | # cam -> batch_idx (e.g., [[0,1],[2,3]]) 148 | cam2batch = defaultdict(list) 149 | for cam, pid2idx in index_dict.items(): 150 | # store those pids with enough tracklets 151 | available_pids = [pid for pid in pid2idx 152 | if len(pid2idx[pid]) >= self.nt] 153 | while len(available_pids) >= self.np: 154 | batch_idx = [] 155 | pids = random.sample(available_pids, self.np) 156 | for pid in pids: 157 | sampled_index = random.sample(pid2idx[pid], self.nt) 158 | batch_idx.extend(sampled_index) 159 | pid2idx[pid] = [idx for idx in pid2idx[pid] 160 | if idx not in sampled_index] 161 | if len(pid2idx[pid]) < self.nt: 162 | available_pids.remove(pid) 163 | cam2batch[cam].append(batch_idx) 164 | 165 | # generate final idx 166 | final_idx = [] 167 | available_cams = copy.deepcopy(list(cam2batch)) 168 | while len(available_cams) >= self.nc: 169 | sampled_cams = random.sample(available_cams, self.nc) 170 | for cam in sampled_cams: 171 | batch_idx = cam2batch[cam].pop(0) 172 | final_idx.extend(batch_idx) 173 | if len(cam2batch[cam]) == 0: 174 | available_cams.remove(cam) 175 | 176 | self.length = len(final_idx) 177 | return final_idx 178 | 179 | def __iter__(self): 180 | """ 181 | Call self._get_final_idx() in __iter__, 182 | to avoid the same sampling results in all epochs.
183 | """ 184 | return iter(self._get_final_idx()) 185 | 186 | def __len__(self): 187 | return self.length 188 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Author: Du Yunhao 3 | @Filename: test.py 4 | @Contact: dyh_bupt@163.com 5 | @Time: 2022/8/29 21:34 6 | @Discription: test 7 | """ 8 | import os 9 | import torch 10 | from tqdm import tqdm 11 | import torch.nn.functional as F 12 | 13 | from model import get_model 14 | from dataloader import get_dataloader 15 | 16 | from opts import opt 17 | from evaluation import evaluate, print_metrics 18 | from utils import * 19 | 20 | 21 | def test(model, dataloader_query, dataloader_gallery, show=False, save_dir='', return_all=False, postfix=''): 22 | if save_dir: 23 | os.makedirs(save_dir, exist_ok=True) 24 | print('========== Testing ==========') 25 | model.eval() 26 | with torch.no_grad(): 27 | # query 28 | query_feats, query_pids, query_modals, query_cids = [], [], [], [] 29 | for batch_idx, (imgs, pids, cids, modals) in enumerate(tqdm(dataloader_query)): 30 | imgs, cids = imgs.cuda(), cids.cuda() 31 | modal = modals[0] 32 | if modal == 0: 33 | feats = model(x_rgb=imgs) 34 | elif modal == 1: 35 | feats = model(x_ir=imgs) 36 | else: 37 | continue 38 | query_feats.append(feats) 39 | query_pids.append(pids) 40 | query_cids.append(cids) 41 | query_modals.append(modal.repeat(pids.size())) 42 | query_feats = torch.cat(query_feats, dim=0) # [Nq, C] 43 | query_pids = torch.cat(query_pids, dim=0) # [Nq,] 44 | query_modals = torch.cat(query_modals, dim=0) 45 | query_cids = torch.cat(query_cids, dim=0) 46 | 47 | # gallery 48 | gallery_feats, gallery_pids, gallery_modals, gallery_cids = [], [], [], [] 49 | for batch_idx, (imgs, pids, cids, modals) in enumerate(tqdm(dataloader_gallery)): 50 | imgs, cids = imgs.cuda(), cids.cuda() 51 | modal = modals[0] 52 | assert modals.eq(modal).all() 53 | if modal == 0: 54 | feats = model(x_rgb=imgs) 55 | elif modal == 1: 56 | feats = model(x_ir=imgs) 57 | else: 58 | continue 59 | gallery_feats.append(feats) 60 | gallery_pids.append(pids) 61 | gallery_cids.append(cids) 62 | gallery_modals.append(modal.repeat(pids.size())) 63 | gallery_feats = torch.cat(gallery_feats, dim=0) # [Ng, C] 64 | gallery_pids = torch.cat(gallery_pids, dim=0) # [Ng,] 65 | gallery_modals = torch.cat(gallery_modals, dim=0) 66 | gallery_cids = torch.cat(gallery_cids, dim=0) 67 | 68 | # save 69 | if save_dir: 70 | torch.save(query_feats, join(save_dir, f'query_feats{postfix}.pth')) 71 | torch.save(query_pids, join(save_dir, 'query_pids.pth')) 72 | torch.save(query_modals, join(save_dir, 'query_modals.pth')) 73 | torch.save(query_cids, join(save_dir, 'query_cids.pth')) 74 | torch.save(gallery_feats, join(save_dir, f'gallery_feats{postfix}.pth')) 75 | torch.save(gallery_pids, join(save_dir, 'gallery_pids.pth')) 76 | torch.save(gallery_modals, join(save_dir, 'gallery_modals.pth')) 77 | torch.save(gallery_cids, join(save_dir, 'gallery_cids.pth')) 78 | 79 | # distance 80 | if opt.distance == 'cosine': 81 | distance = 1 - query_feats @ gallery_feats.T 82 | else: 83 | distance = euclidean_dist(query_feats, gallery_feats) 84 | 85 | CMC, MAP = [], [] 86 | 87 | # evaluate (intra/inter-modality) 88 | for q_modal in (0, 1): 89 | for g_modal in (0, 1): 90 | q_mask = query_modals == q_modal 91 | g_mask = gallery_modals == g_modal 92 | tmp_distance = distance[q_mask, :][:, g_mask] 93 | tmp_qid = query_pids[q_mask] 94 | tmp_gid 
= gallery_pids[g_mask] 95 | tmp_cmc, tmp_ap = evaluate(tmp_distance, tmp_qid, tmp_gid, opt) 96 | CMC.append(tmp_cmc * 100) 97 | MAP.append(tmp_ap * 100) 98 | if show: 99 | print_metrics( 100 | tmp_cmc, tmp_ap, 101 | prefix='{:<3}->{:<3}: '.format(MODALITY_[q_modal], MODALITY_[g_modal]) 102 | ) 103 | 104 | # evaluate (omni-modality) 105 | cmc, ap = evaluate(distance, query_pids, gallery_pids, opt) 106 | CMC.append(cmc * 100) 107 | MAP.append(ap * 100) 108 | 109 | if show: 110 | print_metrics(cmc, ap, prefix='AllModal: ') 111 | 112 | del query_feats, query_pids, query_modals, gallery_feats, gallery_pids, gallery_modals, distance 113 | 114 | if return_all: 115 | return CMC, MAP 116 | else: 117 | return cmc * 100, ap * 100 118 | 119 | 120 | if __name__ == '__main__': 121 | os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpus 122 | model = get_model(opt, class_num=1074) 123 | model, _ = load_from_ckpt(model, 'model', opt.test_ckpt_path) 124 | postfix = opt.feature_postfix 125 | frame_samples = opt.test_frame_sample.split('-') 126 | for frame_sample in frame_samples: 127 | print('Frame Sample: {}'.format(frame_sample)) 128 | opt.test_frame_sample = frame_sample 129 | dataloader_query, _ = get_dataloader(opt, 'query', True) 130 | dataloader_gallery, _ = get_dataloader(opt, 'gallery', True) 131 | if frame_sample == 'uniform': 132 | curr_postfix = postfix + '_all' 133 | elif frame_sample == 'first_half': 134 | curr_postfix = postfix + '_first' 135 | elif frame_sample == 'second_half': 136 | curr_postfix = postfix + '_second' 137 | cmc, ap = test( 138 | model, 139 | dataloader_query, dataloader_gallery, 140 | show=True, 141 | postfix=curr_postfix, 142 | save_dir=opt.save_dir, 143 | ) 144 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Author: Du Yunhao 3 | @Filename: train.py 4 | @Contact: dyh_bupt@163.com 5 | @Time: 2022/8/31 21:42 6 | @Discription: train 7 | """ 8 | import os 9 | import time 10 | from itertools import cycle 11 | 12 | import torch 13 | from torch import optim 14 | import torch.nn.functional as F 15 | from torch.utils.data import DataLoader 16 | from torch.cuda.amp import autocast, GradScaler 17 | 18 | from utils import * 19 | from opts import opt 20 | from test import test 21 | from loss import get_loss 22 | from model import get_model 23 | from dataloader import get_dataloader 24 | from evaluation import evaluate, print_metrics 25 | from torch.utils.tensorboard import SummaryWriter 26 | 27 | 28 | os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpus 29 | scaler = GradScaler() 30 | 31 | save_configs(opt) 32 | logger = get_logger(opt.save_dir) 33 | writer = SummaryWriter(opt.save_dir) 34 | 35 | dataloader_query, _ = get_dataloader(opt, 'query', False) 36 | dataloader_gallery, _ = get_dataloader(opt, 'gallery', False) 37 | dataloader_train, class_num = get_dataloader(opt, 'train', True) 38 | dataloader_auxiliary, _ = get_dataloader(opt, 'auxiliary', False) 39 | 40 | model = get_model(opt, class_num=class_num) 41 | 42 | optimizer = eval(f'optim.{opt.optimizer}')( 43 | model.parameters(), 44 | lr=opt.base_lr, 45 | weight_decay=opt.weight_decay 46 | ) 47 | 48 | loss_fn_tri = get_loss('triplet') 49 | loss_fn_ce = get_loss('cross-entropy') 50 | 51 | batch_size = 2 * opt.train_bs 52 | 53 | if opt.resume_path: 54 | model, resume_epoch = load_from_ckpt(model, 'model', opt.resume_path) 55 | else: 56 | resume_epoch = -1 57 | 58 | print('========== Training ==========') 
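# The optimizer above is looked up with eval(f'optim.{opt.optimizer}'). An equivalent
# construction, sketched here for illustration only, resolves the class with getattr
# instead of evaluating a string; it assumes opt.optimizer names a torch.optim class
# such as 'SGD' or 'Adam', exactly as the eval-based version does.
optimizer_cls = getattr(optim, opt.optimizer)
optimizer = optimizer_cls(
    model.parameters(),
    lr=opt.base_lr,
    weight_decay=opt.weight_decay
)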
59 | iteration = 0 60 | logger.info('Start training!') 61 | for epoch in range(resume_epoch+1, opt.max_epoch): 62 | model.train() 63 | LOSS_ID = AverageMeter('Loss(ID)', ':.4e') 64 | LOSS_TRI = AverageMeter('Loss(Tri)', ':.4e') 65 | LOSS_ID_AUX = AverageMeter('Loss(ID-AUX)', ':.4e') 66 | LOSS_TRI_AUX = AverageMeter('Loss(Tri-AUX)', ':.4e') 67 | BATCH_TIME = AverageMeter('Time', ':6.3f') 68 | lr = get_lr(opt, epoch) 69 | set_lr(optimizer, lr) 70 | meters = [BATCH_TIME, LOSS_TRI, LOSS_ID] 71 | PROGRESS = ProgressMeter( 72 | num_batches=len(dataloader_train), 73 | meters=meters, 74 | prefix="Epoch [{}/{}] ".format(epoch, opt.max_epoch), 75 | lr=lr 76 | ) 77 | end = time.time() 78 | if opt.auxiliary: 79 | alpha = get_auxiliary_alpha(epoch, opt.max_epoch, phi=opt.aux_phi) 80 | for batch_idx, (datas, datas_aux) in enumerate(zip(dataloader_train, cycle(dataloader_auxiliary))): 81 | '''Auxiliary Set''' 82 | imgs_rgb_aux, imgs_ir_aux, labels_aux, cids_aux = datas_aux 83 | imgs_rgb_aux, imgs_ir_aux, labels_aux, cids_aux = \ 84 | imgs_rgb_aux.cuda(), imgs_ir_aux.cuda(), labels_aux.cuda(), cids_aux.cuda() 85 | with autocast(): 86 | feats_aux, logits_aux, labels_aux = model(imgs_rgb_aux, imgs_ir_aux, pids=labels_aux) 87 | loss_tri_aux = loss_fn_tri( 88 | feats_aux, 89 | labels_aux, 90 | margin=opt.triplet_margin, 91 | norm_feat=False, 92 | hard_mining=opt.triplet_hard 93 | ) 94 | LOSS_TRI_AUX.update(loss_tri_aux.item(), batch_size) 95 | loss_aux = alpha * loss_tri_aux 96 | '''Primary Set''' 97 | imgs_rgb, imgs_ir, labels, cids = datas 98 | imgs_rgb, imgs_ir, labels, cids = \ 99 | imgs_rgb.cuda(), imgs_ir.cuda(), labels.cuda(), cids.cuda() 100 | with autocast(): 101 | feats, logits, labels = model(imgs_rgb, imgs_ir, pids=labels) 102 | loss_tri = loss_fn_tri( 103 | feats, 104 | labels, 105 | margin=opt.triplet_margin, 106 | norm_feat=False, 107 | hard_mining=opt.triplet_hard 108 | ) 109 | LOSS_TRI.update(loss_tri.item(), batch_size) 110 | loss_id = loss_fn_ce(logits, labels) 111 | LOSS_ID.update(loss_id.item(), batch_size) 112 | loss = (1 - alpha) * (opt.lambda_tri * loss_tri + opt.lambda_ce * loss_id) 113 | '''Backward''' 114 | loss = loss + loss_aux 115 | optimizer.zero_grad() 116 | scaler.scale(loss).backward() 117 | scaler.step(optimizer) 118 | scaler.update() 119 | '''Write''' 120 | BATCH_TIME.update(time.time() - end) 121 | end = time.time() 122 | iteration += 1 123 | writer.add_scalar('Train/LR', lr, iteration) 124 | writer.add_scalar('Train/Alpha', alpha, iteration) 125 | writer.add_scalar('Loss/Triplet', loss_tri.item(), iteration) 126 | writer.add_scalar('Loss/Identity', loss_id.item(), iteration) 127 | writer.add_scalar('Loss/Auxiliary', loss_aux.item(), iteration) 128 | if batch_idx % opt.train_print_freq == 0: 129 | PROGRESS.display(batch_idx) 130 | logger.info( 131 | 'Epoch:[{}/{}] [{}/{}] Loss(Aux):{:.5f}' 132 | .format(epoch, opt.max_epoch, batch_idx, len(dataloader_train), loss_aux.item()) 133 | ) 134 | else: 135 | for batch_idx, (imgs_rgb, imgs_ir, labels, cids) in enumerate(dataloader_train): 136 | '''Primary Set''' 137 | imgs_rgb, imgs_ir, labels, cids = \ 138 | imgs_rgb.cuda(), imgs_ir.cuda(), labels.cuda(), cids.cuda() 139 | with autocast(): 140 | feats, logits, labels = model(imgs_rgb, imgs_ir, pids=labels) 141 | loss_tri = loss_fn_tri( 142 | feats, 143 | labels, 144 | margin=opt.triplet_margin, 145 | norm_feat=False, 146 | hard_mining=opt.triplet_hard 147 | ) 148 | LOSS_TRI.update(loss_tri.item(), batch_size) 149 | loss_id = loss_fn_ce(logits, labels) 150 | 
LOSS_ID.update(loss_id.item(), batch_size) 151 | loss = opt.lambda_tri * loss_tri + opt.lambda_ce * loss_id 152 | '''Backward''' 153 | optimizer.zero_grad() 154 | scaler.scale(loss).backward() 155 | scaler.step(optimizer) 156 | scaler.update() 157 | '''Write''' 158 | BATCH_TIME.update(time.time() - end) 159 | end = time.time() 160 | iteration += 1 161 | writer.add_scalar('Loss/Triplet', loss_tri.item(), iteration) 162 | writer.add_scalar('Loss/Identity', loss_id.item(), iteration) 163 | if batch_idx % opt.train_print_freq == 0: 164 | PROGRESS.display(batch_idx) 165 | logger.info( 166 | 'Epoch:[{}/{}] [{}/{}] Loss(Tri):{:.5f}' 167 | .format(epoch, opt.max_epoch, batch_idx, len(dataloader_train), loss_tri.item()) 168 | ) 169 | 170 | torch.cuda.empty_cache() 171 | if (epoch + 1) % opt.eval_freq == 0: 172 | CMC, MAP = test(model, dataloader_query, dataloader_gallery, show=True, return_all=True) 173 | writer.add_scalar('Eval/mAP(%)', MAP[-1], epoch) 174 | writer.add_scalar('Eval/Rank1(%)', CMC[-1][0], epoch) 175 | writer.add_scalar('Eval/Rank5(%)', CMC[-1][4], epoch) 176 | writer.add_scalar('Eval/Rank10(%)', CMC[-1][9], epoch) 177 | MODE = ['RGB->RGB', 'RGB->IR ', 'IR->RGB ', 'IR->IR ', 'AllModal'] 178 | log_info = 'Epoch:[{}/{}]'.format(epoch, opt.max_epoch) 179 | for i, mode in enumerate(MODE): 180 | log_info += '\n\t{}: mAP:{:.2f}% Rank1:{:.2f}% Rank5:{:.2f}% Rank10:{:.2f}% Rank20:{:.2f}%'\ 181 | .format(mode, MAP[i], CMC[i][0], CMC[i][4], CMC[i][9], CMC[i][19]) 182 | logger.info(log_info) 183 | torch.cuda.empty_cache() 184 | 185 | if epoch + 1 >= 80: 186 | state_dict = { 187 | 'model': model.state_dict(), 188 | 'optimizer': optimizer, 189 | 'epoch': epoch 190 | } 191 | torch.save(state_dict, join(opt.save_dir, f'epoch{epoch}.pth')) 192 | 193 | logger.info('Finish training!') 194 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | @Author: Du Yunhao 3 | @Filename: utils.py 4 | @Contact: dyh_bupt@163.com 5 | @Time: 2022/8/30 21:37 6 | @Discription: utils 7 | """ 8 | import os 9 | import math 10 | import json 11 | import torch 12 | import logging 13 | import numpy as np 14 | from os.path import join 15 | from torch.nn import init 16 | from torchvision import transforms 17 | from collections import defaultdict 18 | 19 | 20 | MODALITY = {'RGB/IR': -1, 'RGB': 0, 'IR': 1} 21 | MODALITY_ = {-1:'All', 0: 'RGB', 1: 'IR'} 22 | CAMERA = {'LS3': 0, 'G25': 1, 'CQ1': 2, 'W4': 3, 'TSG1': 4, 'TSG2': 5} 23 | 24 | 25 | def get_auxiliary_alpha(curr_epoch, max_epoch, phi): 26 | # return phi 27 | # return 0.5 * math.exp(-phi * curr_epoch / max_epoch) 28 | return (math.cos(math.pi * curr_epoch / max_epoch) + phi) / (2 + 2 * phi) 29 | 30 | 31 | def euclidean_dist(x, y): 32 | """ 33 | Args: 34 | x: pytorch Variable, with shape [m, d] 35 | y: pytorch Variable, with shape [n, d] 36 | Returns: 37 | dist: pytorch Variable, with shape [m, n] 38 | """ 39 | m, n = x.size(0), y.size(0) 40 | xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n) 41 | yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t() 42 | dist = xx + yy 43 | dist.addmm_(x, y.T, beta=1, alpha=-2) 44 | dist = dist.clamp(min=1e-12).sqrt() # for numerical stability 45 | return dist.cpu().numpy() 46 | 47 | 48 | def get_transform(opt, mode): 49 | if mode == 'train': 50 | return transforms.Compose([ 51 | transforms.Resize(opt.img_hw), 52 | transforms.Pad(opt.padding), 53 | transforms.RandomCrop(opt.img_hw), 54 | 
transforms.ToTensor(), 55 | transforms.Normalize(opt.norm_mean, opt.norm_std) 56 | ]) 57 | elif mode == 'test': 58 | return transforms.Compose([ 59 | transforms.Resize(opt.img_hw), 60 | transforms.ToTensor(), 61 | transforms.Normalize(opt.norm_mean, opt.norm_std) 62 | ]) 63 | else: 64 | raise RuntimeError('Error transformation mode.') 65 | 66 | 67 | def get_lr(opt, curr_epoch): 68 | if curr_epoch < opt.warmup_epoch: 69 | return ( 70 | opt.warmup_start_lr 71 | + (opt.base_lr - opt.warmup_start_lr) 72 | * curr_epoch 73 | / opt.warmup_epoch 74 | ) 75 | else: 76 | return ( 77 | opt.cosine_end_lr 78 | + (opt.base_lr - opt.cosine_end_lr) 79 | * ( 80 | math.cos( 81 | math.pi * (curr_epoch - opt.warmup_epoch) / (opt.max_epoch - opt.warmup_epoch) 82 | ) 83 | + 1.0 84 | ) 85 | * 0.5 86 | ) 87 | 88 | 89 | def set_lr(optimizer, lr): 90 | for param_group in optimizer.param_groups: 91 | param_group['lr'] = lr 92 | 93 | 94 | def weights_init_kaiming(m): 95 | """Copied from https://github.com/mangye16/DDAG/blob/master/model_main.py""" 96 | classname = m.__class__.__name__ 97 | if classname.find('Conv') != -1: 98 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 99 | elif classname.find('Linear') != -1: 100 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_out') 101 | init.zeros_(m.bias.data) 102 | elif classname.find('BatchNorm1d') != -1: 103 | init.normal_(m.weight.data, 1.0, 0.01) 104 | init.zeros_(m.bias.data) 105 | 106 | 107 | def weights_init_classifier(m): 108 | """Copied from https://github.com/mangye16/DDAG/blob/master/model_main.py""" 109 | classname = m.__class__.__name__ 110 | if classname.find('Linear') != -1: 111 | init.normal_(m.weight.data, 0, 0.001) 112 | if m.bias: 113 | init.zeros_(m.bias.data) 114 | 115 | 116 | class AverageMeter: 117 | """Computes and stores the average and current value""" 118 | def __init__(self, name, fmt=':f'): 119 | self.name = name 120 | self.fmt = fmt 121 | self.reset() 122 | 123 | def reset(self): 124 | self.val = 0 125 | self.avg = 0 126 | self.sum = 0 127 | self.count = 0 128 | 129 | def update(self, val, n=1): 130 | self.val = val 131 | self.sum += val * n 132 | self.count += n 133 | self.avg = self.sum / self.count 134 | 135 | def __str__(self): 136 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 137 | return fmtstr.format(**self.__dict__) 138 | 139 | 140 | class ProgressMeter: 141 | def __init__(self, num_batches, meters, prefix="", lr=0.): 142 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches, lr) 143 | self.meters = meters 144 | self.prefix = prefix 145 | 146 | def display(self, batch): 147 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 148 | entries += [str(meter) for meter in self.meters] 149 | print('\t'.join(entries)) 150 | 151 | @staticmethod 152 | def _get_batch_fmtstr(num_batches, lr): 153 | num_digits = len(str(num_batches // 1)) 154 | fmt = '{:' + str(num_digits) + 'd}' 155 | return '[' + fmt + '/' + fmt.format(num_batches) + '] [lr: {:.2e}]'.format(lr) 156 | 157 | 158 | def save_configs(opt): 159 | configs = vars(opt) 160 | os.makedirs(opt.save_dir, exist_ok=True) 161 | json.dump( 162 | configs, 163 | open(join(opt.save_dir, 'config.json'), 'w'), 164 | indent=2 165 | ) 166 | 167 | 168 | def get_logger(save_dir): 169 | logger = logging.getLogger() 170 | logger.setLevel(logging.INFO) 171 | filename = join(save_dir, 'log.txt') 172 | formatter = logging.Formatter('[%(asctime)s][%(filename)s][%(levelname)s] %(message)s') 173 | 174 | # writting to file 175 | file_handler = logging.FileHandler(filename, 
mode='w') 176 | file_handler.setFormatter(formatter) 177 | logger.addHandler(file_handler) 178 | 179 | # display in terminal 180 | # stream_handler = logging.StreamHandler() 181 | # stream_handler.setFormatter(formatter) 182 | # logger.addHandler(stream_handler) 183 | 184 | return logger 185 | 186 | 187 | def load_from_ckpt(model, model_name, ckpt_path): 188 | print(f'load from {ckpt_path}...') 189 | ckpt = torch.load(ckpt_path) 190 | epoch = ckpt['epoch'] 191 | model.load_state_dict(ckpt[model_name]) 192 | return model, epoch 193 | --------------------------------------------------------------------------------
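The two schedules defined in utils.py, the warm-up-plus-cosine learning rate in get_lr and the cosine-annealed auxiliary weight in get_auxiliary_alpha, can be sanity-checked in a few lines. The snippet below is a minimal standalone sketch rather than part of the repository: it assumes the repository root is on PYTHONPATH, and the SimpleNamespace with illustrative values stands in for the real opts object from opts.py.

from types import SimpleNamespace
from utils import get_lr, get_auxiliary_alpha

# Illustrative settings only; the real values come from opts.py.
opt_demo = SimpleNamespace(
    warmup_epoch=5, warmup_start_lr=1e-5,
    base_lr=3.5e-4, cosine_end_lr=1e-6, max_epoch=120
)
for epoch in (0, 5, 60, 119):
    lr = get_lr(opt_demo, epoch)
    alpha = get_auxiliary_alpha(epoch, opt_demo.max_epoch, phi=1.0)
    print(f'epoch {epoch:3d}: lr={lr:.2e}  aux_alpha={alpha:.3f}')
# With phi=1.0 the auxiliary weight follows (cos(pi * t / T) + 1) / 4,
# decaying smoothly from 0.5 at epoch 0 towards 0.0 at the final epoch.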
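In the same spirit, euclidean_dist in utils.py uses the expansion ||x - y||^2 = ||x||^2 + ||y||^2 - 2<x, y>, computed with a single addmm_ call and clamped before the square root for numerical stability. A quick independent check against torch.cdist, again a standalone sketch assuming utils.py is importable:

import torch
from utils import euclidean_dist

x = torch.randn(4, 8)
y = torch.randn(6, 8)
d_fast = torch.as_tensor(euclidean_dist(x, y))  # euclidean_dist returns a numpy array
d_ref = torch.cdist(x, y, p=2)
print(torch.allclose(d_fast.float(), d_ref, atol=1e-4))  # expected: True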