├── .gitignore ├── ChannelAug.py ├── LICENSE ├── README.md ├── data_loader.py ├── data_manager.py ├── eval_metrics.py ├── loss.py ├── model_bn.py ├── pre_process_sysu.py ├── resnet.py ├── run.sh ├── testy.py ├── train.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # wandb 2 | wandb/ 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | cover/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | .pybuilder/ 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | # For a library or package, you might want to ignore these files since the code is 90 | # intended to run in multiple environments; otherwise, check them in: 91 | # .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # poetry 101 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 102 | # This is especially recommended for binary packages to ensure reproducibility, and is more 103 | # commonly ignored for libraries. 104 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 105 | #poetry.lock 106 | 107 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow
108 | __pypackages__/
109 |
110 | # Celery stuff
111 | celerybeat-schedule
112 | celerybeat.pid
113 |
114 | # SageMath parsed files
115 | *.sage.py
116 |
117 | # Environments
118 | .env
119 | .venv
120 | env/
121 | venv/
122 | ENV/
123 | env.bak/
124 | venv.bak/
125 |
126 | # Spyder project settings
127 | .spyderproject
128 | .spyproject
129 |
130 | # Rope project settings
131 | .ropeproject
132 |
133 | # mkdocs documentation
134 | /site
135 |
136 | # mypy
137 | .mypy_cache/
138 | .dmypy.json
139 | dmypy.json
140 |
141 | # Pyre type checker
142 | .pyre/
143 |
144 | # pytype static type analyzer
145 | .pytype/
146 |
147 | # Cython debug symbols
148 | cython_debug/
149 |
150 | # PyCharm
151 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
152 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
153 | # and can be added to the global gitignore or merged into this file. For a more nuclear
154 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
155 | #.idea/
156 |
157 | # some logging files
158 | log/
159 | save_model/
160 | traindistill.py
161 | run.sh
162 | run2.sh
163 | testori.py
164 | eval_sysu.py
165 | pre_process*
--------------------------------------------------------------------------------
/ChannelAug.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | from torchvision.transforms import *
4 |
5 | #from PIL import Image
6 | import random
7 | import math
8 | #import numpy as np
9 | #import torch
10 |
11 |
12 | class ChannelAdap(object):
13 | """ Adaptively selects one color channel and replicates it
14 | across all three channels, or leaves the image unchanged.
15 |
16 | Args:
17 | probability: probability of applying the transform
18 | (currently unused; the corresponding check in
19 | __call__ is commented out).
20 | """
21 |
22 | def __init__(self, probability = 0.5):
23 | self.probability = probability
24 |
25 |
26 | def __call__(self, img):
27 |
28 | # if random.uniform(0, 1) > self.probability:
29 | # return img
30 |
31 | idx = random.randint(0, 3)
32 |
33 | if idx ==0:
34 | # randomly select the R channel
35 | img[1, :,:] = img[0,:,:]
36 | img[2, :,:] = img[0,:,:]
37 | elif idx ==1:
38 | # randomly select the G channel
39 | img[0, :,:] = img[1,:,:]
40 | img[2, :,:] = img[1,:,:]
41 | elif idx ==2:
42 | # randomly select the B channel
43 | img[0, :,:] = img[2,:,:]
44 | img[1, :,:] = img[2,:,:]
45 | else:
46 | img = img
47 |
48 | return img
49 |
50 |
51 | class ChannelAdapGray(object):
52 | """ Adaptively selects one color channel to replicate across
53 | all three channels, or converts the image to grayscale.
54 |
55 | Args:
56 | probability: probability of converting to grayscale
57 | (rather than keeping the original image) when no
58 | single channel is selected.
59 | """
60 |
61 | def __init__(self, probability = 0.5):
62 | self.probability = probability
63 |
64 |
65 | def __call__(self, img):
66 |
67 | # if random.uniform(0, 1) > self.probability:
68 | # return img
69 |
70 | idx = random.randint(0, 3)
71 |
72 | if idx ==0:
73 | # randomly select the R channel
74 | img[1, :,:] = img[0,:,:]
75 | img[2, :,:] = img[0,:,:]
76 | elif idx ==1:
77 | # randomly select the G channel
78 | img[0, :,:] = img[1,:,:]
79 | img[2, :,:] = img[1,:,:]
80 | elif idx ==2:
81 | # randomly select the B channel
82 | img[0, :,:] = img[2,:,:]
83 | img[1, :,:] = img[2,:,:]
84 | else:
85 | if random.uniform(0, 1) > self.probability:
86 | # return img
87 | img = img
88 | else:
89 | tmp_img = 0.2989 * img[0,:,:] + 0.5870 * img[1,:,:] + 0.1140 * img[2,:,:]
90 | img[0,:,:] = tmp_img
91 | img[1,:,:] = tmp_img
92 | img[2,:,:] = tmp_img
93 | return img
94 |
95 | class ChannelRandomErasing(object):
96 | """ Randomly selects a rectangle region in an image and erases its pixels.
97 | 'Random Erasing Data Augmentation' by Zhong et al.
98 | See https://arxiv.org/pdf/1708.04896.pdf
99 | Args:
100 | probability: The probability that the Random Erasing operation will be performed.
101 | sl: Minimum proportion of erased area against input image.
102 | sh: Maximum proportion of erased area against input image.
103 | r1: Minimum aspect ratio of erased area.
104 | mean: Erasing value.
105 | """
106 |
107 | def __init__(self, probability = 0.5, sl = 0.02, sh = 0.4, r1 = 0.3, mean=[0.4914, 0.4822, 0.4465]):
108 |
109 | self.probability = probability
110 | self.mean = mean
111 | self.sl = sl
112 | self.sh = sh
113 | self.r1 = r1
114 |
115 | def __call__(self, img):
116 |
117 | if random.uniform(0, 1) > self.probability:
118 | return img
119 |
120 | for attempt in range(100):
121 | area = img.size()[1] * img.size()[2]
122 |
123 | target_area = random.uniform(self.sl, self.sh) * area
124 | aspect_ratio = random.uniform(self.r1, 1/self.r1)
125 |
126 | h = int(round(math.sqrt(target_area * aspect_ratio)))
127 | w = int(round(math.sqrt(target_area / aspect_ratio)))
128 |
129 | if w < img.size()[2] and h < img.size()[1]:
130 | x1 = random.randint(0, img.size()[1] - h)
131 | y1 = random.randint(0, img.size()[2] - w)
132 | if img.size()[0] == 3:
133 | img[0, x1:x1+h, y1:y1+w] = self.mean[0]
134 | img[1, x1:x1+h, y1:y1+w] = self.mean[1]
135 | img[2, x1:x1+h, y1:y1+w] = self.mean[2]
136 | else:
137 | img[0, x1:x1+h, y1:y1+w] = self.mean[0]
138 | return img
139 |
140 | return img
--------------------------------------------------------------------------------
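These transforms operate on C×H×W tensors, so data_loader.py applies them after `ToTensor` and `Normalize`. A minimal usage sketch of the infrared pipeline, mirroring the composition actually used in data_loader.py:

```python
import torchvision.transforms as transforms
from ChannelAug import ChannelRandomErasing, ChannelAdapGray

# Augmentation pipeline for infrared images, as in data_loader.py.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
transform_thermal = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Pad(10),
    transforms.RandomCrop((288, 144)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),               # channel ops below expect a tensor
    normalize,
    ChannelRandomErasing(probability=0.5),
    ChannelAdapGray(probability=0.5),
])
```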
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 feng jiawei
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SGIEL_VIReID (CVPR 2023)
2 | Official PyTorch implementation of "Shape-Erased Feature Learning for Visible-Infrared Person Re-Identification" (CVPR'23)
3 |
4 |
5 |
6 | ### Datasets
7 | We follow [Cross-Modal-Re-ID-baseline](https://github.com/mangye16/Cross-Modal-Re-ID-baseline) to preprocess the SYSU-MM01 dataset.
8 |
9 | For VCM-HITSZ, please refer to [its official repository](https://github.com/VCM-project233/MITML).
10 |
11 |
12 | ### Body Shape Data
13 |
14 | We use a pre-trained Self-Correction Human Parsing ([SCHP](https://github.com/GoGoDuck912/Self-Correction-Human-Parsing)) model (trained on the Pascal-Person-Part dataset) to segment body shapes from the background. For each pixel of a visible or infrared image, we sum the SCHP-predicted probabilities of it belonging to the head, torso, or limbs to obtain the body-shape map (a sketch of this step follows the README).
15 |
16 | You can also download the body shape data for SYSU-MM01 through this [link](https://drive.google.com/drive/folders/1i3YosMId359OjDe_DfNmvB98kuMclIdc?usp=drive_link).
17 | ### Dependencies
18 |
19 | * python 3.7.9
20 | * pytorch >1.0 (>1.7 recommended)
21 | * torchvision 0.8.2
22 | * cudatoolkit 11.0
23 |
24 | ### Training and Model
25 |
26 | To reproduce our results on SYSU-MM01, simply run (after declaring the dataset paths)
27 | ```
28 | bash run.sh
29 | ```
30 |
31 | We are actively working through the issues. Please feel free to contact me (fengjw151@gmail.com) if you need any further information.
32 |
33 | We have uploaded a trained [model](https://drive.google.com/file/d/1FSLhVCPynfOX_Ms3y4cwwNDYwZmLABAX/view?usp=drive_link) for SYSU-MM01.
34 | ### Acknowledgement
35 |
36 | Thanks to the open-sourced [Cross-Modal-Re-ID-baseline](https://github.com/mangye16/Cross-Modal-Re-ID-baseline) for the great code base.
--------------------------------------------------------------------------------
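The body-shape maps described in the README can be derived from SCHP's per-pixel class probabilities roughly as follows. This is a minimal sketch only: the probability-map layout and the part indices are assumptions about the SCHP output format, not something fixed by this repository.

```python
import numpy as np

# Hypothetical Pascal-Person-Part class indices in the SCHP softmax output:
# 0 = background, 1 = head, 2 = torso, 3-6 = upper/lower arms and legs.
BODY_PART_CLASSES = [1, 2, 3, 4, 5, 6]

def body_shape_map(probs: np.ndarray) -> np.ndarray:
    """probs: (num_classes, H, W) per-pixel class probabilities from SCHP.
    Returns an (H, W) map: the probability of being head, torso, or limbs."""
    return probs[BODY_PART_CLASSES].sum(axis=0)
```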
/data_loader.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 | import torch.utils.data as data
4 | from ChannelAug import ChannelAdap, ChannelAdapGray, ChannelRandomErasing
5 | import torchvision.transforms as transforms
6 | import random
7 | import math
8 | import os
9 |
10 |
11 |
12 | class RandomErasing(object):
13 | """ Randomly selects a rectangle region in an image and erases its pixels.
14 | 'Random Erasing Data Augmentation' by Zhong et al.
15 | See https://arxiv.org/pdf/1708.04896.pdf
16 | Args:
17 | probability: The probability that the Random Erasing operation will be performed.
18 | sl: Minimum proportion of erased area against input image.
19 | sh: Maximum proportion of erased area against input image.
20 | r1: Minimum aspect ratio of erased area.
21 | mean: Erasing value.
22 | """
23 |
24 | def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=(0.485, 0.456, 0.406)):
25 | self.probability = probability
26 | self.mean = mean
27 | self.sl = sl
28 | self.sh = sh
29 | self.r1 = r1
30 |
31 | def __call__(self, img):
32 |
33 | if random.uniform(0, 1) >= self.probability:
34 | return img
35 |
36 | for attempt in range(100):
37 | area = img.size()[1] * img.size()[2]
38 |
39 | target_area = random.uniform(self.sl, self.sh) * area
40 | aspect_ratio = random.uniform(self.r1, 1 / self.r1)
41 |
42 | h = int(round(math.sqrt(target_area * aspect_ratio)))
43 | w = int(round(math.sqrt(target_area / aspect_ratio)))
44 |
45 | if w < img.size()[2] and h < img.size()[1]:
46 | x1 = random.randint(0, img.size()[1] - h)
47 | y1 = random.randint(0, img.size()[2] - w)
48 | if img.size()[0] == 3:
49 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0]
50 | img[1, x1:x1 + h, y1:y1 + w] = self.mean[1]
51 | img[2, x1:x1 + h, y1:y1 + w] = self.mean[2]
52 | else:
53 | img[0, x1:x1 + h, y1:y1 + w] = self.mean[0]
54 | return img
55 |
56 | return img
57 |
58 |
59 | def read_image(img_path):
60 | """Keep reading the image until it succeeds.
61 | This can avoid the IOError incurred by a heavy IO process."""
62 | got_img = False
63 | while not got_img:
64 | try:
65 | img = Image.open(img_path).convert('RGB')
66 | got_img = True
67 | except IOError:
68 | print("IOError incurred when reading '{}'. Will redo. Don't worry. Just chill.".format(img_path))
69 | pass
70 | return img
71 |
72 | class ChannelExchange(object):
73 | """ Randomly replicates one color channel across all three channels,
74 | or converts the image to grayscale.
75 |
76 | Args:
77 | gray: upper bound of the random channel index; indices 0/1/2
78 | replicate the R/G/B channel, and any larger index triggers
79 | the grayscale conversion (never reached with the default gray = 2).
80 | """
81 |
82 | def __init__(self, gray = 2):
83 | self.gray = gray
84 |
85 | def __call__(self, img):
86 |
87 | idx = random.randint(0, self.gray)
88 |
89 | if idx ==0:
90 | # randomly select the R channel
91 | img[1, :,:] = img[0,:,:]
92 | img[2, :,:] = img[0,:,:]
93 | elif idx ==1:
94 | # randomly select the G channel
95 | img[0, :,:] = img[1,:,:]
96 | img[2, :,:] = img[1,:,:]
97 | elif idx ==2:
98 | # randomly select the B channel
99 | img[0, :,:] = img[2,:,:]
100 | img[1, :,:] = img[2,:,:]
101 | else:
102 | tmp_img = 0.2989 * img[0,:,:] + 0.5870 * img[1,:,:] + 0.1140 * img[2,:,:]
103 | img[0,:,:] = tmp_img
104 | img[1,:,:] = tmp_img
105 | img[2,:,:] = tmp_img
106 | return img
107 |
108 |
109 |
110 | class SYSUData(data.Dataset):
111 | def __init__(self, transform=None, colorIndex = None, thermalIndex = None, data_dir = None, data_dir1 = None):
112 |
113 |
114 | # Load training images (path) and labels
115 | train_color_image = np.load(data_dir + 'train_rgb_resized_img.npy')
116 | self.train_color_label = np.load(data_dir + 'train_rgb_resized_label.npy')
117 |
118 | train_thermal_image = np.load(data_dir + 'train_ir_resized_img.npy')
119 | self.train_thermal_label = np.load(data_dir + 'train_ir_resized_label.npy')
120 |
121 | train_color_image_shape = np.load(data_dir1 + 'train_rgb_resized_img.npy')
122 | self.train_color_label_shape = np.load(data_dir1 + 'train_rgb_resized_label.npy')
123 |
124 | train_thermal_image_shape = np.load(data_dir1 + 'train_ir_resized_img.npy')
125 | self.train_thermal_label_shape = np.load(data_dir1 + 'train_ir_resized_label.npy')
126 | print(train_color_image.shape, train_color_image_shape.shape)
127 | print(train_thermal_image.shape, train_thermal_image_shape.shape)
128 |
129 | # BGR to RGB
130 | self.train_color_image = train_color_image
131 | self.train_color_image_shape = train_color_image_shape
132 | self.train_thermal_image = train_thermal_image
133 | self.train_thermal_image_shape = train_thermal_image_shape
134 | self.cIndex = colorIndex
135 | self.tIndex = thermalIndex
136 |
137 |
138 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
139 | self.transform_thermal = transforms.Compose( [
140 | transforms.ToPILImage(),
141 | transforms.Pad(10),
142 | transforms.RandomCrop((288, 144)),
143 | transforms.RandomHorizontalFlip(),
144 | transforms.ToTensor(),
145 | normalize,
146 | ChannelRandomErasing(probability = 0.5),
147 | ChannelAdapGray(probability =0.5)])
148 | self.transform_thermal_simple = transforms.Compose( [
149 | transforms.ToPILImage(),
150 | transforms.Pad(10),
151 | transforms.RandomCrop((288, 144)),
152 | transforms.RandomHorizontalFlip(),
153 | transforms.ToTensor(),
154 | normalize,
155 | ])
156 | self.transform_color_simple = transforms.Compose( [
157 | transforms.ToPILImage(),
158 | transforms.Pad(10),
159 | transforms.RandomCrop((288, 144)),
160 | transforms.RandomHorizontalFlip(),
161 | transforms.ToTensor(),
162 | normalize,
163 | ])
164 |
165 | self.transform_color = transforms.Compose( [
166 | transforms.ToPILImage(),
167 | transforms.Pad(10),
168 | transforms.RandomCrop((288, 144)),
169 | transforms.RandomHorizontalFlip(),
170 | # transforms.RandomGrayscale(p = 0.1),
171 | transforms.ToTensor(),
172 | normalize,
173 | ChannelRandomErasing(probability = 0.5),
174 | ])
175 |
176 |
177 | self.transform_color1 = transforms.Compose( [
178 | transforms.ToPILImage(),
179 | transforms.Pad(10),
180 | transforms.RandomCrop((288, 144)),
181 | transforms.RandomHorizontalFlip(),
182 | transforms.ToTensor(),
183 |
normalize, 184 | ChannelRandomErasing(probability = 0.5), 185 | ChannelExchange(gray = 2)]) 186 | 187 | def __getitem__(self, index): 188 | 189 | img1, target1 = self.train_color_image[self.cIndex[index]], self.train_color_label[self.cIndex[index]] 190 | img2, target2 = self.train_thermal_image[self.tIndex[index]], self.train_thermal_label[self.tIndex[index]] 191 | img1_shape, target1_shape = self.train_color_image_shape[self.cIndex[index]], self.train_color_label_shape[self.cIndex[index]] 192 | assert target1 == target1_shape 193 | img2_shape, target2_shape = self.train_thermal_image_shape[self.tIndex[index]], self.train_thermal_label_shape[self.tIndex[index]] 194 | assert target2 == target2_shape 195 | 196 | if random.uniform(0, 1) > 0.5: 197 | trans_rgb = self.transform_color 198 | else: 199 | trans_rgb = self.transform_color1 200 | 201 | img1 = trans_rgb(img1) 202 | img2 = self.transform_thermal(img2) 203 | 204 | img1_shape = self.transform_color_simple(img1_shape) 205 | img2_shape = self.transform_thermal_simple(img2_shape) 206 | 207 | return img1, img1_shape, img2, img2_shape, target1, target2 208 | 209 | def __len__(self): 210 | return len(self.train_color_label) 211 | 212 | 213 | class RegDBData(data.Dataset): 214 | def __init__(self, data_dir, trial, transform=None, colorIndex = None, thermalIndex = None): 215 | # Load training images (path) and labels 216 | data_dir = '/home/share/reid_dataset/RGB-IR_RegDB/' 217 | data_dir1 = '/home/share/fengjw/RegDB_shape/' 218 | train_color_list = data_dir + 'idx/train_visible_{}'.format(trial)+ '.txt' 219 | train_thermal_list = data_dir + 'idx/train_thermal_{}'.format(trial)+ '.txt' 220 | 221 | color_img_file, train_color_label = load_data(train_color_list) 222 | thermal_img_file, train_thermal_label = load_data(train_thermal_list) 223 | 224 | train_color_image = [] 225 | train_color_image_shape = [] 226 | for i in range(len(color_img_file)): 227 | 228 | img = Image.open(data_dir+ color_img_file[i]) 229 | img = img.resize((144, 288), Image.ANTIALIAS) 230 | pix_array = np.array(img) 231 | train_color_image.append(pix_array) 232 | 233 | 234 | img1 = Image.open(data_dir1+ color_img_file[i]) 235 | img1 = img1.resize((144, 288), Image.ANTIALIAS) 236 | pix_array1 = np.array(img1) 237 | train_color_image_shape.append(pix_array1) 238 | train_color_image_shape = np.array(train_color_image_shape) 239 | 240 | train_thermal_image = [] 241 | train_thermal_image_shape = [] 242 | for i in range(len(thermal_img_file)): 243 | img = Image.open(data_dir+ thermal_img_file[i]) 244 | img = img.resize((144, 288), Image.ANTIALIAS) 245 | pix_array = np.array(img) 246 | train_thermal_image.append(pix_array) 247 | 248 | img1 = Image.open(data_dir1+ thermal_img_file[i]) 249 | img1 = img1.resize((144, 288), Image.ANTIALIAS) 250 | pix_array1 = np.array(img1) 251 | train_thermal_image_shape.append(pix_array1) 252 | 253 | train_thermal_image_shape = np.array(train_thermal_image_shape) 254 | 255 | # BGR to RGB 256 | self.train_color_image = train_color_image 257 | self.train_color_label = train_color_label 258 | 259 | self.train_color_image_shape = train_color_image_shape 260 | 261 | # BGR to RGB 262 | self.train_thermal_image_shape = train_thermal_image_shape 263 | self.train_thermal_image = train_thermal_image 264 | self.train_thermal_label = train_thermal_label 265 | 266 | self.transform = transform 267 | self.cIndex = colorIndex 268 | self.tIndex = thermalIndex 269 | 270 | 271 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 272 | 
self.transform_thermal = transforms.Compose( [ 273 | transforms.ToPILImage(), 274 | transforms.Pad(10), 275 | transforms.RandomCrop((288, 144)), 276 | transforms.RandomHorizontalFlip(), 277 | transforms.ToTensor(), 278 | normalize, 279 | ChannelRandomErasing(probability = 0.5), 280 | ChannelAdapGray(probability =0.5)]) 281 | 282 | self.transform_thermal_simple = transforms.Compose( [ 283 | transforms.ToPILImage(), 284 | transforms.Pad(10), 285 | transforms.RandomCrop((288, 144)), 286 | transforms.RandomHorizontalFlip(), 287 | transforms.ToTensor(), 288 | normalize 289 | ]) 290 | self.transform_color_simple = transforms.Compose( [ 291 | transforms.ToPILImage(), 292 | transforms.Pad(10), 293 | transforms.RandomCrop((288, 144)), 294 | transforms.RandomHorizontalFlip(), 295 | transforms.ToTensor(), 296 | normalize 297 | ]) 298 | 299 | self.transform_color = transforms.Compose( [ 300 | transforms.ToPILImage(), 301 | transforms.Pad(10), 302 | transforms.RandomCrop((288, 144)), 303 | transforms.RandomHorizontalFlip(), 304 | # transforms.RandomGrayscale(p = 0.1), 305 | transforms.ToTensor(), 306 | normalize, 307 | ChannelRandomErasing(probability = 0.5)]) 308 | 309 | self.transform_color1 = transforms.Compose( [ 310 | transforms.ToPILImage(), 311 | transforms.Pad(10), 312 | transforms.RandomCrop((288, 144)), 313 | transforms.RandomHorizontalFlip(), 314 | transforms.ToTensor(), 315 | normalize, 316 | ChannelRandomErasing(probability = 0.5), 317 | ChannelExchange(gray = 2)]) 318 | 319 | def __getitem__(self, index): 320 | 321 | img1, target1 = self.train_color_image[self.cIndex[index]], self.train_color_label[self.cIndex[index]] 322 | img2, target2 = self.train_thermal_image[self.tIndex[index]], self.train_thermal_label[self.tIndex[index]] 323 | 324 | img1_shape = self.train_color_image_shape[self.cIndex[index]] 325 | img2_shape = self.train_thermal_image_shape[self.tIndex[index]] 326 | 327 | 328 | if random.uniform(0, 1) > 0.5: 329 | trans_rgb = self.transform_color 330 | else: 331 | trans_rgb = self.transform_color1 332 | 333 | img1 = trans_rgb(img1) 334 | img2 = self.transform_thermal(img2) 335 | 336 | img1_shape = self.transform_color_simple(img1_shape) 337 | img2_shape = self.transform_thermal_simple(img2_shape) 338 | 339 | return img1, img1_shape, img2, img2_shape, target1, target2 340 | 341 | def __len__(self): 342 | return len(self.train_color_label) 343 | 344 | 345 | 346 | def decoder_pic_path(fname): 347 | base = fname[0:4] 348 | modality = fname[5] 349 | if modality == '1' : 350 | modality_str = 'ir' 351 | else: 352 | modality_str = 'rgb' 353 | T_pos = fname.find('T') 354 | D_pos = fname.find('D') 355 | F_pos = fname.find('F') 356 | camera = fname[D_pos:T_pos] 357 | picture = fname[F_pos+1:] 358 | path = base + '/' + modality_str + '/' + camera + '/' + picture 359 | return path 360 | 361 | 362 | class VCM(object): 363 | def __init__(self, colorIndex = None, thermalIndex = None): 364 | # Load training images (path) and labels 365 | min_seq_len = 12 366 | data_dir = '/home/share/reid_dataset/HITSZ-VCM-UNZIP/' 367 | 368 | data_dir1 = '/home/share/fengjw/HITSZ-VCM-UNZIP_shape/' 369 | 370 | train_name_path = os.path.join(data_dir,'info/train_name.txt') 371 | track_train_info_path = os.path.join(data_dir,'info/track_train_info.txt') 372 | 373 | test_name_path = os.path.join(data_dir,'info/test_name.txt') 374 | track_test_info_path = os.path.join(data_dir,'info/track_test_info.txt') 375 | query_IDX_path = os.path.join(data_dir,'info/query_IDX.txt') 376 | 377 | # train_color_list = data_dir + 
'idx/train_visible_{}'.format(trial)+ '.txt' 378 | # train_thermal_list = data_dir + 'idx/train_thermal_{}'.format(trial)+ '.txt' 379 | 380 | # color_img_file, train_color_label = load_data(train_color_list) 381 | # thermal_img_file, train_thermal_label = load_data(train_thermal_list) 382 | 383 | train_names = self._get_names(train_name_path) 384 | track_train = self._get_tracks(track_train_info_path) 385 | 386 | test_names = self._get_names(test_name_path) 387 | track_test = self._get_tracks(track_test_info_path) 388 | query_IDX = self._get_query_idx(query_IDX_path) 389 | query_IDX -= 1 390 | 391 | track_query = track_test[query_IDX,:] 392 | print('query') 393 | print(track_query) 394 | gallery_IDX = [i for i in range(track_test.shape[0]) if i not in query_IDX] 395 | track_gallery = track_test[gallery_IDX,:] 396 | print('gallery') 397 | print(track_gallery) 398 | 399 | #---------visible to infrared----------- 400 | gallery_IDX_1 = self._get_query_idx(query_IDX_path) 401 | gallery_IDX_1 -= 1 402 | track_gallery_1 = track_test[gallery_IDX_1,:] 403 | 404 | query_IDX_1 = [j for j in range(track_test.shape[0]) if j not in gallery_IDX_1] 405 | track_query_1 = track_test[query_IDX_1,:] 406 | #----------------------------------------- 407 | 408 | train_ir, train_ir_shape, num_train_tracklets_ir,num_train_imgs_ir,train_rgb, train_rgb_shape, num_train_tracklets_rgb,num_train_imgs_rgb,num_train_pids,ir_label,rgb_label = \ 409 | self._process_data_train(train_names,track_train,relabel=True,min_seq_len=min_seq_len, rootpath=data_dir, rootpath1=data_dir1) 410 | 411 | 412 | query, num_query_tracklets, num_query_pids, num_query_imgs = \ 413 | self._process_data_test(test_names, track_query, relabel=False, min_seq_len=min_seq_len, rootpath=data_dir) 414 | 415 | gallery, num_gallery_tracklets, num_gallery_pids, num_gallery_imgs = \ 416 | self._process_data_test(test_names, track_gallery, relabel=False, min_seq_len=min_seq_len, rootpath=data_dir) 417 | 418 | 419 | #--------visible to infrared----------- 420 | query_1, num_query_tracklets_1, num_query_pids_1, num_query_imgs_1 = \ 421 | self._process_data_test(test_names, track_query_1, relabel=False, min_seq_len=min_seq_len, rootpath=data_dir) 422 | 423 | gallery_1, num_gallery_tracklets_1, num_gallery_pids_1, num_gallery_imgs_1 = \ 424 | self._process_data_test(test_names, track_gallery_1, relabel=False, min_seq_len=min_seq_len, rootpath=data_dir) 425 | #--------------------------------------- 426 | 427 | 428 | print("=> VCM loaded") 429 | print("Dataset statistics:") 430 | print("---------------------------------") 431 | print("subset | # ids | # tracklets") 432 | print("---------------------------------") 433 | print("train_ir | {:5d} | {:8d}".format(num_train_pids,num_train_tracklets_ir)) 434 | print("train_rgb | {:5d} | {:8d}".format(num_train_pids,num_train_tracklets_rgb)) 435 | print("query | {:5d} | {:8d}".format(num_query_pids, num_query_tracklets)) 436 | print("gallery | {:5d} | {:8d}".format(num_gallery_pids, num_gallery_tracklets)) 437 | print("---------------------------------") 438 | print("ir_label | {}".format(np.unique(ir_label))) 439 | print("rgb_label | {}".format(np.unique(rgb_label))) 440 | 441 | 442 | 443 | self.train_ir = train_ir 444 | self.train_ir_shape = train_ir_shape 445 | self.train_rgb = train_rgb 446 | self.train_rgb_shape = train_rgb_shape 447 | self.ir_label = ir_label 448 | self.rgb_label = rgb_label 449 | 450 | self.query = query 451 | self.gallery = gallery 452 | 453 | self.num_train_pids = num_train_pids 454 | 
self.num_query_pids = num_query_pids
455 | self.num_gallery_pids = num_gallery_pids
456 | self.num_query_tracklets = num_query_tracklets
457 | self.num_gallery_tracklets = num_gallery_tracklets
458 |
459 | #------- visible to infrared------------
460 | self.query_1 = query_1
461 | self.gallery_1 = gallery_1
462 |
463 | self.num_query_pids_1 = num_query_pids_1
464 | self.num_gallery_pids_1 = num_gallery_pids_1
465 | self.num_query_tracklets_1 = num_query_tracklets_1
466 | self.num_gallery_tracklets_1 = num_gallery_tracklets_1
467 | #---------------------------------------
468 |
469 |
470 | def _get_names(self,fpath):
471 | """Get image names; return the name list."""
472 | names = []
473 | with open(fpath,'r') as f:
474 | for line in f:
475 | new_line = line.rstrip()
476 | names.append(new_line)
477 | return names
478 |
479 | def _get_tracks(self,fpath):
480 | """Parse the track-info file into an int array."""
481 | names = []
482 | with open(fpath,'r') as f:
483 | for line in f:
484 | new_line = line.rstrip()
485 |
486 |
487 | tmp = new_line.split(' ')[0:]
488 |
489 | tmp = list(map(int, tmp))
490 | names.append(tmp)
491 | names = np.array(names)
492 | return names
493 |
494 |
495 | def _get_query_idx(self, fpath):
496 | with open(fpath, 'r') as f:
497 | for line in f:
498 | new_line = line.rstrip()
499 |
500 |
501 | tmp = new_line.split(' ')[0:]
502 |
503 |
504 | tmp = list(map(int, tmp))
505 | idxs = tmp
506 | idxs = np.array(idxs)
507 | print(idxs)
508 | return idxs
509 |
510 | def _process_data_train(self,names,meta_data,relabel=False,min_seq_len=0, rootpath=None, rootpath1=None):
511 | num_tracklets = meta_data.shape[0]
512 | pid_list = list(set(meta_data[:,3].tolist()))
513 | num_pids = len(pid_list)
514 |
515 | # dict {pid : label}
516 | if relabel: pid2label = {pid: label for label, pid in enumerate(pid_list)}
517 | print('pid_list')
518 | print(pid_list)
519 | print(pid2label)
520 | tracklets_ir = []
521 | tracklets_ir_shape = []
522 | num_imgs_per_tracklet_ir = []
523 | ir_label = []
524 |
525 | tracklets_rgb = []
526 | tracklets_rgb_shape = []
527 | num_imgs_per_tracklet_rgb = []
528 | rgb_label = []
529 |
530 | for tracklet_idx in range(num_tracklets):
531 | data = meta_data[tracklet_idx,...]
532 | m,start_index,end_index,pid,camid = data 533 | if relabel: pid = pid2label[pid] 534 | 535 | if m == 1: 536 | img_names = names[start_index-1:end_index] 537 | img_ir_paths = [os.path.join(rootpath,'Train',decoder_pic_path(img_name)) for img_name in img_names] 538 | img_ir_paths_shape = [os.path.join(rootpath1,'Train',decoder_pic_path(img_name)) for img_name in img_names] 539 | if len(img_ir_paths) >= min_seq_len: 540 | img_ir_paths = tuple(img_ir_paths) 541 | ir_label.append(pid) 542 | tracklets_ir.append((img_ir_paths,pid,camid)) 543 | # same id 544 | num_imgs_per_tracklet_ir.append(len(img_ir_paths)) 545 | 546 | # for shape 547 | img_ir_paths_shape = tuple(img_ir_paths_shape) 548 | tracklets_ir_shape.append((img_ir_paths_shape,pid,camid)) 549 | # same id 550 | # num_imgs_per_tracklet_ir.append(len(img_ir_paths_shape)) 551 | else: 552 | img_names = names[start_index-1:end_index] 553 | img_rgb_paths = [os.path.join(rootpath,'Train',decoder_pic_path(img_name)) for img_name in img_names] 554 | img_rgb_paths_shape = [os.path.join(rootpath1,'Train',decoder_pic_path(img_name)) for img_name in img_names] 555 | if len(img_rgb_paths) >= min_seq_len: 556 | img_rgb_paths = tuple(img_rgb_paths) 557 | img_rgb_paths_shape = tuple(img_rgb_paths_shape) 558 | rgb_label.append(pid) 559 | tracklets_rgb.append((img_rgb_paths,pid,camid)) 560 | tracklets_rgb_shape.append((img_rgb_paths_shape,pid,camid)) 561 | #same id 562 | num_imgs_per_tracklet_rgb.append(len(img_rgb_paths)) 563 | 564 | num_tracklets_ir = len(tracklets_ir) 565 | num_tracklets_rgb = len(tracklets_rgb) 566 | num_tracklets = num_tracklets_rgb + num_tracklets_ir 567 | 568 | return tracklets_ir, tracklets_ir_shape, num_tracklets_ir,num_imgs_per_tracklet_ir,tracklets_rgb, tracklets_rgb_shape, num_tracklets_rgb,num_imgs_per_tracklet_rgb,num_pids,ir_label,rgb_label 569 | 570 | def _process_data_test(self,names,meta_data,relabel=False,min_seq_len=0,rootpath=None): 571 | num_tracklets = meta_data.shape[0] 572 | pid_list = list(set(meta_data[:,3].tolist())) 573 | num_pids = len(pid_list) 574 | 575 | # dict {pid : label} 576 | if relabel: pid2label = {pid: label for label, pid in enumerate(pid_list)} 577 | tracklets = [] 578 | num_imgs_per_tracklet = [] 579 | 580 | for tracklet_idx in range(num_tracklets): 581 | data = meta_data[tracklet_idx,...] 582 | m,start_index,end_index,pid,camid = data 583 | if relabel: pid = pid2label[pid] 584 | 585 | img_names = names[start_index-1:end_index] 586 | img_paths = [os.path.join(rootpath,'Test',decoder_pic_path(img_name)) for img_name in img_names] 587 | if len(img_paths) >= min_seq_len: 588 | img_paths = tuple(img_paths) 589 | tracklets.append((img_paths, pid, camid)) 590 | num_imgs_per_tracklet.append(len(img_paths)) 591 | 592 | num_tracklets = len(tracklets) 593 | 594 | return tracklets, num_tracklets, num_pids, num_imgs_per_tracklet 595 | 596 | 597 | 598 | import torch 599 | class VideoDataset_test(data.Dataset): 600 | """Video Person ReID Dataset. 601 | Note batch data has shape (batch, seq_len, channel, height, width). 
602 | """ 603 | sample_methods = ['evenly', 'random', 'all'] 604 | 605 | def __init__(self, dataset, seq_len=12, sample='evenly', transform=None): 606 | self.dataset = dataset 607 | self.seq_len = seq_len 608 | self.sample = sample 609 | self.transform = transform 610 | 611 | def __len__(self): 612 | return len(self.dataset) 613 | 614 | def __getitem__(self, index): 615 | img_paths, pid, camid = self.dataset[index] 616 | num = len(img_paths) 617 | 618 | S = self.seq_len 619 | sample_clip_ir = [] 620 | frame_indices_ir = list(range(num)) 621 | if num < S: 622 | strip_ir = list(range(num)) + [frame_indices_ir[-1]] * (S - num) 623 | for s in range(S): 624 | pool_ir = strip_ir[s * 1:(s + 1) * 1] 625 | sample_clip_ir.append(list(pool_ir)) 626 | else: 627 | inter_val_ir = math.ceil(num / S) 628 | strip_ir = list(range(num)) + [frame_indices_ir[-1]] * (inter_val_ir * S - num) 629 | for s in range(S): 630 | pool_ir = strip_ir[inter_val_ir * s:inter_val_ir * (s + 1)] 631 | sample_clip_ir.append(list(pool_ir)) 632 | 633 | sample_clip_ir = np.array(sample_clip_ir) 634 | 635 | if self.sample == 'dense': 636 | """ 637 | Sample all frames in a video into a list of clips, each clip contains seq_len frames, batch_size needs to be set to 1. 638 | This sampling strategy is used in test phase. 639 | """ 640 | cur_index=0 641 | frame_indices = range(num) 642 | indices_list=[] 643 | while num-cur_index > self.seq_len: 644 | indices_list.append(frame_indices[cur_index:cur_index+self.seq_len]) 645 | cur_index+=self.seq_len 646 | last_seq=frame_indices[cur_index:] 647 | last_seq = list(last_seq) 648 | for index in last_seq: 649 | if len(last_seq) >= self.seq_len: 650 | break 651 | last_seq.append(index) 652 | indices_list.append(last_seq) 653 | imgs_list=[] 654 | for indices in indices_list: 655 | imgs = [] 656 | for index in indices: 657 | index=int(index) 658 | img_path = img_paths[index] 659 | img = read_image(img_path) 660 | 661 | img = np.array(img) 662 | if self.transform is not None: 663 | img = self.transform(img) 664 | img = img.unsqueeze(0) 665 | imgs.append(img) 666 | imgs = torch.cat(imgs, dim=0) 667 | 668 | imgs_list.append(imgs) 669 | imgs_array = torch.stack(imgs_list) 670 | return imgs_array, pid, camid 671 | 672 | if self.sample == 'random': 673 | """ 674 | Randomly sample seq_len consecutive frames from num frames, 675 | if num is smaller than seq_len, then replicate items. 676 | This sampling strategy is used in training phase. 
677 | """ 678 | num_ir = len(img_paths) 679 | frame_indices = range(num_ir) 680 | rand_end = max(0, len(frame_indices) - self.seq_len - 1) 681 | begin_index = random.randint(0, rand_end) 682 | end_index = min(begin_index + self.seq_len, len(frame_indices)) 683 | 684 | indices = frame_indices[begin_index:end_index] 685 | indices = list(indices) 686 | for index in indices: 687 | if len(indices) >= self.seq_len: 688 | break 689 | indices.append(index) 690 | indices = np.array(indices) 691 | imgs_ir = [] 692 | for index in indices: 693 | index = int(index) 694 | img_path = img_paths[index] 695 | img = read_image(img_path) 696 | 697 | img = np.array(img) 698 | if self.transform is not None: 699 | img = self.transform(img) 700 | 701 | imgs_ir.append(img) 702 | imgs_ir = torch.cat(imgs_ir, dim=0) 703 | return imgs_ir, pid, camid 704 | 705 | if self.sample == 'video_test': 706 | number = sample_clip_ir[:, 0] 707 | imgs_ir = [] 708 | for index in number: 709 | index = int(index) 710 | img_path = img_paths[index] 711 | img = read_image(img_path) 712 | 713 | img = np.array(img) 714 | if self.transform is not None: 715 | img = self.transform(img) 716 | 717 | imgs_ir.append(img.unsqueeze(0)) 718 | imgs_ir = torch.cat(imgs_ir, dim=0) 719 | return imgs_ir, pid, camid 720 | else: 721 | raise KeyError("Unknown sample method: {}. Expected one of {}".format(self.sample, self.sample_methods)) 722 | 723 | 724 | class VideoDataset_train(data.Dataset): 725 | """Video Person ReID Dataset. 726 | Note batch data has shape (batch, seq_len, channel, height, width). 727 | """ 728 | sample_methods = ['evenly', 'random', 'all'] 729 | 730 | def __init__(self, dataset_ir,dataset_rgb, dataset_ir_shape, dataset_rgb_shape, seq_len=12, sample='evenly', transform=None, index1=[], index2=[]): 731 | self.dataset_ir = dataset_ir 732 | self.dataset_ir_shape = dataset_ir_shape 733 | self.dataset_rgb = dataset_rgb 734 | self.dataset_rgb_shape = dataset_rgb_shape 735 | self.seq_len = 3 736 | self.sample = sample 737 | self.transform = transform 738 | self.index1 = index1 739 | self.index2 = index2 740 | 741 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 742 | self.transform_thermal = self.transform #transforms.Compose( [ 743 | # transforms.ToPILImage(), 744 | # transforms.Resize((288,144)), 745 | # transforms.Pad(10), 746 | # transforms.RandomCrop((288, 144)), 747 | # transforms.RandomHorizontalFlip(), 748 | # transforms.ToTensor(), 749 | # normalize, 750 | # ChannelRandomErasing(probability = 0.5), 751 | # ChannelAdapGray(probability =0.5)]) 752 | 753 | self.transform_thermal_simple = self.transform#transforms.Compose( [ 754 | # transforms.ToPILImage(), 755 | # transforms.Resize((288,144)), 756 | # transforms.Pad(10), 757 | # transforms.RandomCrop((288, 144)), 758 | # transforms.RandomHorizontalFlip(), 759 | # transforms.ToTensor(), 760 | # normalize 761 | # ]) 762 | self.transform_color_simple = self.transform #transforms.Compose( [ 763 | # transforms.ToPILImage(), 764 | # transforms.Resize((288,144)), 765 | # transforms.Pad(10), 766 | # transforms.RandomCrop((288, 144)), 767 | # transforms.RandomHorizontalFlip(), 768 | # transforms.ToTensor(), 769 | # normalize 770 | # ]) 771 | 772 | self.transform_color = self.transform#transforms.Compose( [ 773 | # transforms.ToPILImage(), 774 | # transforms.Resize((288,144)), 775 | # transforms.Pad(10), 776 | # transforms.RandomCrop((288, 144)), 777 | # transforms.RandomHorizontalFlip(), 778 | # # transforms.RandomGrayscale(p = 0.1), 779 | # 
transforms.ToTensor(), 780 | # normalize, 781 | # ChannelRandomErasing(probability = 0.5)]) 782 | 783 | self.transform_color1 = self.transform#transforms.Compose( [ 784 | # transforms.ToPILImage(), 785 | # transforms.Resize((288,144)), 786 | # transforms.Pad(10), 787 | # transforms.RandomCrop((288, 144)), 788 | # transforms.RandomHorizontalFlip(), 789 | # transforms.ToTensor(), 790 | # normalize, 791 | # ChannelRandomErasing(probability = 0.5), 792 | # ChannelExchange(gray = 2)]) 793 | 794 | 795 | def __len__(self): 796 | return len(self.dataset_rgb) 797 | 798 | 799 | def __getitem__(self, index): 800 | 801 | if random.uniform(0, 1) > 0.5: 802 | trans_rgb = self.transform_color 803 | else: 804 | trans_rgb = self.transform_color1 805 | 806 | img_ir_paths, pid_ir, camid_ir = self.dataset_ir[self.index2[index]] 807 | img_ir_paths_shape, pid_ir_shape, camid_ir_shape = self.dataset_ir_shape[self.index2[index]] 808 | 809 | num_ir = len(img_ir_paths) 810 | 811 | img_rgb_paths,pid_rgb,camid_rgb = self.dataset_rgb[self.index1[index]] 812 | img_rgb_paths_shape, pid_rgb_shape, camid_rgb_shape = self.dataset_rgb_shape[self.index1[index]] 813 | num_rgb = len(img_rgb_paths) 814 | 815 | idx1 = np.random.choice(num_ir, self.seq_len) 816 | imgs_ir = [] 817 | imgs_ir_shape = [] 818 | for index in idx1: 819 | index = int(index) 820 | img_path = img_ir_paths[index] 821 | img_path_shape = img_ir_paths_shape[index] 822 | img = read_image(img_path) 823 | img_shape = read_image(img_path_shape) 824 | img = np.array(img) 825 | img_shape = np.array(img_shape) 826 | img = self.transform_thermal(img) 827 | img_shape = self.transform_thermal_simple(img_shape) 828 | 829 | imgs_ir.append(img.unsqueeze(0)) 830 | imgs_ir_shape.append(img_shape.unsqueeze(0)) 831 | imgs_ir = torch.cat(imgs_ir, dim=0) 832 | imgs_ir_shape = torch.cat(imgs_ir_shape, dim=0) 833 | 834 | idx2 = np.random.choice(num_rgb, self.seq_len) 835 | imgs_rgb = [] 836 | imgs_rgb_shape = [] 837 | for index in idx2: 838 | index = int(index) 839 | img_path = img_rgb_paths[index] 840 | img_path_shape = img_rgb_paths_shape[index] 841 | img = read_image(img_path) 842 | img_shape = read_image(img_path_shape) 843 | img = np.array(img) 844 | img_shape = np.array(img_shape) 845 | 846 | img = trans_rgb(img) 847 | img_shape = self.transform_color_simple(img_shape) 848 | 849 | imgs_rgb.append(img.unsqueeze(0)) 850 | imgs_rgb_shape.append(img_shape.unsqueeze(0)) 851 | imgs_rgb = torch.cat(imgs_rgb, dim=0) 852 | imgs_rgb_shape = torch.cat(imgs_rgb_shape, dim=0) 853 | pid_ir = torch.tensor(pid_ir).repeat(self.seq_len) 854 | pid_rgb = torch.tensor(pid_rgb).repeat(self.seq_len) 855 | 856 | 857 | 858 | return imgs_rgb, imgs_rgb_shape, imgs_ir, imgs_ir_shape, pid_rgb, pid_ir 859 | 860 | 861 | 862 | 863 | 864 | class TestData(data.Dataset): 865 | def __init__(self, test_img_file, test_label, test_cam, transform=None, img_size = (144,288)): 866 | 867 | test_image = [] 868 | for i in range(len(test_img_file)): 869 | img = Image.open(test_img_file[i]) 870 | img = img.resize((img_size[0], img_size[1]), Image.ANTIALIAS) 871 | pix_array = np.array(img) 872 | test_image.append(pix_array) 873 | test_image = np.array(test_image) 874 | self.test_image = test_image 875 | self.test_label = test_label 876 | self.test_cam = test_cam 877 | self.transform = transform 878 | 879 | def __getitem__(self, index): 880 | img1, target1, cam1 = self.test_image[index], self.test_label[index], self.test_cam[index] 881 | img1 = self.transform(img1) 882 | return img1, target1, cam1 883 | 884 | def 
__len__(self):
885 | return len(self.test_image)
886 |
887 |
888 | class TestDataOld(data.Dataset):
889 | def __init__(self, data_dir, test_img_file, test_label, transform=None, img_size = (144,288)):
890 |
891 | test_image = []
892 | for i in range(len(test_img_file)):
893 | img = Image.open(data_dir + test_img_file[i])
894 | img = img.resize((img_size[0], img_size[1]), Image.ANTIALIAS)
895 | pix_array = np.array(img)
896 | test_image.append(pix_array)
897 | test_image = np.array(test_image)
898 | self.test_image = test_image
899 | self.test_label = test_label
900 | self.transform = transform
901 |
902 | def __getitem__(self, index):
903 | img1, target1 = self.test_image[index], self.test_label[index]
904 | img1 = self.transform(img1)
905 | return img1, target1
906 |
907 | def __len__(self):
908 | return len(self.test_image)
909 | def load_data(input_data_path):
910 | with open(input_data_path, 'rt') as f:
911 | data_file_list = f.read().splitlines()
912 | # Get the full list of images and labels
913 | file_image = [s.split(' ')[0] for s in data_file_list]
914 | file_label = [int(s.split(' ')[1]) for s in data_file_list]
915 | return file_image, file_label
916 |
917 |
918 |
919 |
920 | if __name__ == '__main__':
921 | dataset = VCM()
--------------------------------------------------------------------------------
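load_data above expects each line of a RegDB split file to hold a relative image path and an integer identity label separated by a single space. An illustrative example of such an idx file (the file names here are made up, not actual RegDB paths):

```
Visible/person_001/img_0001.bmp 0
Visible/person_001/img_0002.bmp 0
Visible/person_002/img_0001.bmp 1
```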
/data_manager.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 | import os
3 | import numpy as np
4 | import random
5 |
6 | def process_query_sysu(mode = 'all', relabel=False,data_path_ori = '/home/share/reid_dataset/SYSU-MM01/'):
7 | if mode== 'all':
8 | ir_cameras = ['cam3','cam6']
9 | elif mode =='indoor':
10 | ir_cameras = ['cam3','cam6']
11 |
12 |
13 | file_path = os.path.join(data_path_ori,'exp/test_id.txt')
14 | files_ir = []
15 |
16 | with open(file_path, 'r') as file:
17 | ids = file.read().splitlines()
18 | ids = [int(y) for y in ids[0].split(',')]
19 | ids = ["%04d" % x for x in ids]
20 |
21 | for id in sorted(ids):
22 | for cam in ir_cameras:
23 | img_dir = os.path.join(data_path_ori,cam,id)
24 | if os.path.isdir(img_dir):
25 | new_files = sorted([img_dir+'/'+i for i in os.listdir(img_dir)])
26 | files_ir.extend(new_files)
27 | query_img = []
28 | query_id = []
29 | query_cam = []
30 | for img_path in files_ir:
31 | camid, pid = int(img_path[-15]), int(img_path[-13:-9])
32 | query_img.append(img_path)
33 | query_id.append(pid)
34 | query_cam.append(camid)
35 | return query_img, np.array(query_id), np.array(query_cam)
36 | def process_gallery_sysu(mode = 'all', trial = 0, data_path_ori = '/home/share/reid_dataset/SYSU-MM01/'):
37 |
38 | random.seed(trial)
39 |
40 | if mode== 'all':
41 | rgb_cameras = ['cam1','cam2','cam4','cam5']
42 | elif mode =='indoor':
43 | rgb_cameras = ['cam1','cam2']
44 |
45 | file_path = os.path.join(data_path_ori,'exp/test_id.txt')
46 | files_rgb = []
47 | with open(file_path, 'r') as file:
48 | ids = file.read().splitlines()
49 | ids = [int(y) for y in ids[0].split(',')]
50 | ids = ["%04d" % x for x in ids]
51 | for id in sorted(ids):
52 | for cam in rgb_cameras:
53 | img_dir = os.path.join(data_path_ori,cam,id)
54 | if os.path.isdir(img_dir):
55 | new_files = sorted([img_dir+'/'+i for i in os.listdir(img_dir)])
56 | files_rgb.append(random.choice(new_files))
57 |
58 | gall_img = []
59 | gall_id = []
60 | gall_cam = []
61 | for img_path in files_rgb:
62 | camid, pid = int(img_path[-15]), int(img_path[-13:-9])
63 | gall_img.append(img_path)
64 | gall_id.append(pid)
65 | gall_cam.append(camid)
66 | return gall_img, np.array(gall_id), np.array(gall_cam)
67 |
68 | def process_gallery_sysu_all(mode = 'all', data_path_ori = '/home/share/reid_dataset/SYSU-MM01/'):
69 | if mode== 'all':
70 | rgb_cameras = ['cam1','cam2','cam4','cam5']
71 | elif mode =='indoor':
72 | rgb_cameras = ['cam1','cam2']
73 |
74 | file_path = os.path.join(data_path_ori,'exp/test_id.txt')
75 | files_rgb = []
76 | with open(file_path, 'r') as file:
77 | ids = file.read().splitlines()
78 | ids = [int(y) for y in ids[0].split(',')]
79 | ids = ["%04d" % x for x in ids]
80 | for id in sorted(ids):
81 | for cam in rgb_cameras:
82 | img_dir = os.path.join(data_path_ori,cam,id)
83 | if os.path.isdir(img_dir):
84 | new_files = sorted([img_dir+'/'+i for i in os.listdir(img_dir)])
85 | # files_rgb.append(random.choice(new_files))
86 | files_rgb.extend(new_files)
87 |
88 | gall_img = []
89 | gall_id = []
90 | gall_cam = []
91 | for img_path in files_rgb:
92 | camid, pid = int(img_path[-15]), int(img_path[-13:-9])
93 | gall_img.append(img_path)
94 | gall_id.append(pid)
95 | gall_cam.append(camid)
96 | return gall_img, np.array(gall_id), np.array(gall_cam)
97 |
98 | def process_test_regdb(img_dir, trial = 1, modal = 'visible'):
99 | if modal=='visible':
100 | input_data_path = img_dir + 'idx/test_visible_{}'.format(trial) + '.txt'
101 | elif modal=='thermal':
102 | input_data_path = img_dir + 'idx/test_thermal_{}'.format(trial) + '.txt'
103 |
104 | with open(input_data_path, 'rt') as f:
105 | data_file_list = f.read().splitlines()
106 | # Get the full list of images and labels
107 | file_image = [img_dir + '/' + s.split(' ')[0] for s in data_file_list]
108 | file_label = [int(s.split(' ')[1]) for s in data_file_list]
109 |
110 | return file_image, np.array(file_label)
--------------------------------------------------------------------------------
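A minimal sketch of how these helpers plug into data_loader.TestData for SYSU-MM01 evaluation. The dataset path and the test transform below are illustrative, not prescribed by the repository:

```python
import torch
import torchvision.transforms as transforms
from data_loader import TestData
from data_manager import process_query_sysu

transform_test = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((288, 144)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# Build the infrared query set from the official test split.
query_img, query_label, query_cam = process_query_sysu(
    mode='all', data_path_ori='/path/to/SYSU-MM01/')
queryset = TestData(query_img, query_label, query_cam,
                    transform=transform_test, img_size=(144, 288))
query_loader = torch.utils.data.DataLoader(
    queryset, batch_size=32, shuffle=False, num_workers=4)
```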
/eval_metrics.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 | import numpy as np
3 | """Cross-Modality ReID"""
4 | import pdb
5 |
6 | def eval_sysu(distmat, q_pids, g_pids, q_camids, g_camids, max_rank = 20):
7 | """Evaluation with the SYSU metric.
8 | Key: for each query identity, its gallery images from the same camera view are discarded, following the original setting of the dataset.
9 | """
10 | num_q, num_g = distmat.shape
11 | if num_g < max_rank:
12 | max_rank = num_g
13 | print("Note: number of gallery samples is quite small, got {}".format(num_g))
14 | indices = np.argsort(distmat, axis=1)
15 | pred_label = g_pids[indices]
16 | matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)
17 |
18 | # compute cmc curve for each query
19 | new_all_cmc = []
20 | all_cmc = []
21 | all_AP = []
22 | all_INP = []
23 | num_valid_q = 0. # number of valid queries
24 | right_idx = []
25 | for q_idx in range(num_q):
26 | # get query pid and camid
27 | q_pid = q_pids[q_idx]
28 | q_camid = q_camids[q_idx]
29 |
30 | # remove gallery samples that have the same pid and camid as the query
31 | order = indices[q_idx]
32 | remove = (q_camid == 3) & (g_camids[order] == 2)
33 | keep = np.invert(remove)
34 |
35 | # compute cmc curve
36 | # the cmc calculation is different from the standard protocol
37 | # we follow the protocol of the author's released code
38 | new_cmc = pred_label[q_idx][keep]
39 | new_index = np.unique(new_cmc, return_index=True)[1]
40 | new_cmc = [new_cmc[index] for index in sorted(new_index)]
41 |
42 | new_match = (new_cmc == q_pid).astype(np.int32)
43 | new_cmc = new_match.cumsum()
44 | new_all_cmc.append(new_cmc[:max_rank])
45 |
46 | orig_cmc = matches[q_idx][keep] # binary vector, positions with value 1 are correct matches
47 | if not np.any(orig_cmc):
48 | # this condition is true when the query identity does not appear in the gallery
49 | continue
50 |
51 | cmc = orig_cmc.cumsum()
52 |
53 | # compute mINP
54 | # reference: Deep Learning for Person Re-identification: A Survey and Outlook
55 | pos_idx = np.where(orig_cmc == 1)
56 | pos_max_idx = np.max(pos_idx)
57 | inp = cmc[pos_max_idx]/ (pos_max_idx + 1.0)
58 | all_INP.append(inp)
59 |
60 | cmc[cmc > 1] = 1
61 |
62 | all_cmc.append(cmc[:max_rank])
63 | num_valid_q += 1.
64 |
65 | # compute average precision
66 | # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision
67 | num_rel = orig_cmc.sum()
68 | tmp_cmc = orig_cmc.cumsum()
69 | tmp_cmc = [x / (i+1.) for i, x in enumerate(tmp_cmc)]
70 | tmp_cmc = np.asarray(tmp_cmc) * orig_cmc
71 | AP = tmp_cmc.sum() / num_rel
72 | all_AP.append(AP)
73 | if pos_idx[0][0] == 0:
74 | right_idx.append(q_idx)
75 |
76 |
77 | assert num_valid_q > 0, "Error: all query identities do not appear in gallery"
78 |
79 | all_cmc = np.asarray(all_cmc).astype(np.float32)
80 | all_cmc = all_cmc.sum(0) / num_valid_q # standard CMC
81 |
82 | new_all_cmc = np.asarray(new_all_cmc).astype(np.float32)
83 | new_all_cmc = new_all_cmc.sum(0) / num_valid_q
84 | mAP = np.mean(all_AP)
85 | mINP = np.mean(all_INP)
86 | return new_all_cmc, mAP, mINP#, right_idx
87 |
88 |
89 |
90 | def eval_regdb(distmat, q_pids, g_pids, max_rank = 20):
91 | num_q, num_g = distmat.shape
92 | if num_g < max_rank:
93 | max_rank = num_g
94 | print("Note: number of gallery samples is quite small, got {}".format(num_g))
95 | indices = np.argsort(distmat, axis=1)
96 | matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)
97 |
98 | # compute cmc curve for each query
99 | all_cmc = []
100 | all_AP = []
101 | all_INP = []
102 | num_valid_q = 0.
# number of valid queries
103 |
104 | # only two cameras
105 | q_camids = np.ones(num_q).astype(np.int32)
106 | g_camids = 2* np.ones(num_g).astype(np.int32)
107 |
108 | for q_idx in range(num_q):
109 | # get query pid and camid
110 | q_pid = q_pids[q_idx]
111 | q_camid = q_camids[q_idx]
112 |
113 | # remove gallery samples that have the same pid and camid as the query
114 | order = indices[q_idx]
115 | remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid)
116 | keep = np.invert(remove)
117 |
118 | # compute cmc curve
119 | raw_cmc = matches[q_idx][keep] # binary vector, positions with value 1 are correct matches
120 | if not np.any(raw_cmc):
121 | # this condition is true when the query identity does not appear in the gallery
122 | continue
123 |
124 | cmc = raw_cmc.cumsum()
125 |
126 | # compute mINP
127 | # reference: Deep Learning for Person Re-identification: A Survey and Outlook
128 | pos_idx = np.where(raw_cmc == 1)
129 | pos_max_idx = np.max(pos_idx)
130 | inp = cmc[pos_max_idx]/ (pos_max_idx + 1.0)
131 | all_INP.append(inp)
132 |
133 | cmc[cmc > 1] = 1
134 |
135 | all_cmc.append(cmc[:max_rank])
136 | num_valid_q += 1.
137 |
138 | # compute average precision
139 | # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision
140 | num_rel = raw_cmc.sum()
141 | tmp_cmc = raw_cmc.cumsum()
142 | tmp_cmc = [x / (i+1.) for i, x in enumerate(tmp_cmc)]
143 | tmp_cmc = np.asarray(tmp_cmc) * raw_cmc
144 | AP = tmp_cmc.sum() / num_rel
145 | all_AP.append(AP)
146 |
147 | assert num_valid_q > 0, "Error: all query identities do not appear in gallery"
148 |
149 | all_cmc = np.asarray(all_cmc).astype(np.float32)
150 | all_cmc = all_cmc.sum(0) / num_valid_q
151 | mAP = np.mean(all_AP)
152 | mINP = np.mean(all_INP)
153 | return all_cmc, mAP, mINP
154 |
155 |
156 | def eval_vcm(distmat, q_pids, g_pids, q_camids, g_camids, max_rank=20):
157 | # print("it is evaluating now")
158 | num_q, num_g = distmat.shape
159 | if num_g < max_rank:
160 | max_rank = num_g
161 | print("Note: number of gallery samples is quite small, got {}".format(num_g))
162 | indices = np.argsort(distmat, axis=1)
163 | matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)
164 |
165 | # compute cmc curve for each query
166 | all_cmc = []
167 | all_AP = []
168 | num_valid_q = 0.
169 | for q_idx in range(num_q):
170 | # get query pid and camid
171 | q_pid = q_pids[q_idx]
172 | q_camid = q_camids[q_idx]
173 |
174 | # remove gallery samples that have the same pid and camid as the query
175 | order = indices[q_idx]
176 | remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid)
177 | keep = np.invert(remove)
178 |
179 | # compute cmc curve
180 | orig_cmc = matches[q_idx][keep] # binary vector, positions with value 1 are correct matches
181 | if not np.any(orig_cmc):
182 | # this condition is true when the query identity does not appear in the gallery
183 | continue
184 |
185 | cmc = orig_cmc.cumsum()
186 | cmc[cmc > 1] = 1
187 |
188 | all_cmc.append(cmc[:max_rank])
189 | num_valid_q += 1.
190 |
191 | # compute average precision
192 | # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision
193 | num_rel = orig_cmc.sum()
194 | tmp_cmc = orig_cmc.cumsum()
195 | tmp_cmc = [x / (i+1.) for i, x in enumerate(tmp_cmc)]
196 | tmp_cmc = np.asarray(tmp_cmc) * orig_cmc
197 | AP = tmp_cmc.sum() / num_rel
198 | all_AP.append(AP)
199 |
200 | assert num_valid_q > 0, "Error: all query identities do not appear in gallery"
201 |
202 | all_cmc = np.asarray(all_cmc).astype(np.float32)
203 | all_cmc = all_cmc.sum(0) / num_valid_q
204 | mAP = np.mean(all_AP)
205 |
206 | return all_cmc, mAP
--------------------------------------------------------------------------------
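A quick smoke test for eval_sysu with synthetic inputs; the shapes, identity counts, and camera ids below are arbitrary choices, not values fixed by the protocol:

```python
import numpy as np
from eval_metrics import eval_sysu

num_q, num_g = 40, 200
rng = np.random.default_rng(0)
distmat = rng.random((num_q, num_g))          # smaller distance = more similar
q_pids = np.arange(num_q) % 20                # 20 identities in the query set
g_pids = np.arange(num_g) % 20                # every identity appears in the gallery
q_camids = np.full(num_q, 3)                  # SYSU IR cameras are cam3/cam6
g_camids = rng.choice([1, 4, 5], num_g)       # RGB gallery cameras
cmc, mAP, mINP = eval_sysu(distmat, q_pids, g_pids, q_camids, g_camids)
print('rank-1: {:.2%}  mAP: {:.2%}  mINP: {:.2%}'.format(cmc[0], mAP, mINP))
```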
81 | class TripletLoss_WRT(nn.Module):
82 |     """Weighted Regularized Triplet loss."""
83 |
84 |     def __init__(self):
85 |         super(TripletLoss_WRT, self).__init__()
86 |         self.ranking_loss = nn.SoftMarginLoss()
87 |
88 |     def forward(self, inputs, targets, normalize_feature=False):
89 |         if normalize_feature:
90 |             inputs = normalize(inputs, axis=-1)
91 |         dist_mat = pdist_torch(inputs, inputs)
92 |
93 |         N = dist_mat.size(0)
94 |         # shape [N, N]
95 |         is_pos = targets.expand(N, N).eq(targets.expand(N, N).t()).float()
96 |         is_neg = targets.expand(N, N).ne(targets.expand(N, N).t()).float()
97 |
98 |         # `dist_ap` means distance(anchor, positive);
99 |         # both `dist_ap` and `dist_an` have shape [N, N]
100 |         dist_ap = dist_mat * is_pos
101 |         dist_an = dist_mat * is_neg
102 |
103 |         weights_ap = softmax_weights(dist_ap, is_pos)
104 |         weights_an = softmax_weights(-dist_an, is_neg)
105 |         furthest_positive = torch.sum(dist_ap * weights_ap, dim=1)
106 |         closest_negative = torch.sum(dist_an * weights_an, dim=1)
107 |
108 |         y = furthest_positive.new().resize_as_(furthest_positive).fill_(1)
109 |         loss = self.ranking_loss(closest_negative - furthest_positive, y)
110 |
111 |         # compute accuracy
112 |         correct = torch.ge(closest_negative, furthest_positive).sum().item()
113 |         return loss, correct
114 |
115 | class TripletLoss_ADP(nn.Module):
116 |     """Weighted Regularized Triplet loss with adaptive weighting strength (alpha) and loss scaling (gamma)."""
117 |
118 |     def __init__(self, alpha=1, gamma=1, square=0):
119 |         super(TripletLoss_ADP, self).__init__()
120 |         self.ranking_loss = nn.SoftMarginLoss()
121 |         self.alpha = alpha
122 |         self.gamma = gamma
123 |         self.square = square
124 |
125 |     def forward(self, inputs, targets, normalize_feature=False):
126 |         if normalize_feature:
127 |             inputs = normalize(inputs, axis=-1)
128 |         dist_mat = pdist_torch(inputs, inputs)
129 |
130 |         N = dist_mat.size(0)
131 |         # shape [N, N]
132 |         is_pos = targets.expand(N, N).eq(targets.expand(N, N).t()).float()
133 |         is_neg = targets.expand(N, N).ne(targets.expand(N, N).t()).float()
134 |
135 |         # `dist_ap` means distance(anchor, positive);
136 |         # both `dist_ap` and `dist_an` have shape [N, N]
137 |         dist_ap = dist_mat * is_pos
138 |         dist_an = dist_mat * is_neg
139 |
140 |         weights_ap = softmax_weights(dist_ap * self.alpha, is_pos)
141 |         weights_an = softmax_weights(-dist_an * self.alpha, is_neg)
142 |         furthest_positive = torch.sum(dist_ap * weights_ap, dim=1)
143 |         closest_negative = torch.sum(dist_an * weights_an, dim=1)
144 |
145 |
146 |         # ranking_loss = nn.SoftMarginLoss(reduction = 'none')
147 |         # loss1 = ranking_loss(closest_negative - furthest_positive, y)
148 |
149 |         # squared difference
150 |         if self.square == 0:
151 |             y = furthest_positive.new().resize_as_(furthest_positive).fill_(1)
152 |             loss = self.ranking_loss(self.gamma * (closest_negative - furthest_positive), y)
153 |         else:
154 |             diff_pow = torch.pow(furthest_positive - closest_negative, 2) * self.gamma
155 |             diff_pow = torch.clamp_max(diff_pow, max=10)
156 |
157 |             # Compute ranking hinge loss
158 |             y1 = (furthest_positive > closest_negative).float()
159 |             y2 = y1 - 1
160 |             y = -(y1 + y2)   # y is +1 for correctly ordered anchors, -1 for violated ones
161 |
162 |             loss = self.ranking_loss(diff_pow, y)
163 |
164 |             # loss = self.ranking_loss(self.gamma*(closest_negative - furthest_positive), y)
165 |
166 |         # compute accuracy
167 |         correct = torch.ge(closest_negative, furthest_positive).sum().item()
168 |         return loss, correct
169 |
170 |
171 |
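# --- Editor's example (not part of the original loss.py) ----------------------
# What softmax_weights does to one anchor's positive distances: every pair
# contributes, with the hardest one dominating, so the weighted sums above are
# soft versions of the max/min used by hard mining.
if __name__ == '__main__':
    _dist = torch.tensor([[1.0, 2.0, 3.0]])
    _mask = torch.ones_like(_dist)
    _w = softmax_weights(_dist, _mask)
    # _w ~ [0.09, 0.24, 0.67]; (_dist * _w).sum() ~ 2.58, a smooth max(_dist)
    print(_w, (_dist * _w).sum())
# ------------------------------------------------------------------------------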
172 | def sce(new_logits, old_logits):
173 |     loss_ke_ce = (- F.softmax(old_logits, dim=1).detach() * F.log_softmax(new_logits, dim=1)).mean(0).sum()  # soft cross-entropy, averaged over the batch
174 |     return loss_ke_ce
175 |
176 |
177 |
178 |
179 |
180 | def shape_cpmt_cross_modal_ce(x1, y1, outputs):
181 |
182 |     with torch.no_grad():
183 |         batch_size = y1.shape[0]
184 |
185 |         rgb_shape_normed = F.normalize(outputs['shape']['zp'][:x1.shape[0]], p=2, dim=1)
186 |         ir_shape_normed = F.normalize(outputs['shape']['zp'][x1.shape[0]:], p=2, dim=1)
187 |         rgb_ir_shape_cossim = torch.mm(rgb_shape_normed, ir_shape_normed.t())
188 |         mask = y1.expand(batch_size, batch_size).eq(y1.expand(batch_size, batch_size).t())
189 |         target4rgb, target4ir = [], []
190 |         # idx_temp = torch.arange(batch_size)
191 |         idx_temp = torch.arange(batch_size, device=rgb_shape_normed.device)
192 |         for i in range(batch_size):
193 |             sorted_idx_rgb = rgb_ir_shape_cossim[i][mask[i]].sort(descending=False)[1]
194 |             sorted_idx_ir = rgb_ir_shape_cossim.t()[i][mask.t()[i]].sort(descending=False)[1]
195 |             target4rgb.append(idx_temp[mask[i]][sorted_idx_rgb[0]].unsqueeze(0))   # least shape-similar same-id IR sample
196 |             target4ir.append(idx_temp[mask.t()[i]][sorted_idx_ir[0]].unsqueeze(0))  # least shape-similar same-id RGB sample
197 |         target4rgb = torch.cat(target4rgb)
198 |         target4ir = torch.cat(target4ir)
199 |     loss_top1 = sce(outputs['rgbir']['logit2'][:x1.shape[0]], outputs['rgbir']['logit2'][x1.shape[0]:][target4rgb]) + sce(outputs['rgbir']['logit2'][x1.shape[0]:], outputs['rgbir']['logit2'][:x1.shape[0]][target4ir])
200 |
201 |
202 |     loss_random = sce(outputs['rgbir']['logit2'][:x1.shape[0]], outputs['rgbir']['logit2'][x1.shape[0]:]) + sce(outputs['rgbir']['logit2'][x1.shape[0]:], outputs['rgbir']['logit2'][:x1.shape[0]])
203 |
204 |     loss_kl_rgbir2 = loss_random + loss_top1
205 |
206 |     return loss_kl_rgbir2
207 |
208 |
209 |
210 | def pdist_torch(emb1, emb2):
211 |     '''
212 |     compute the euclidean distance matrix between embeddings1 and embeddings2
213 |     using gpu
214 |     '''
215 |     m, n = emb1.shape[0], emb2.shape[0]
216 |     emb1_pow = torch.pow(emb1, 2).sum(dim=1, keepdim=True).expand(m, n)
217 |     emb2_pow = torch.pow(emb2, 2).sum(dim=1, keepdim=True).expand(n, m).t()
218 |     dist_mtx = emb1_pow + emb2_pow
219 |     # dist_mtx = dist_mtx.addmm_(1, -2, emb1, emb2.t())
220 |     dist_mtx = dist_mtx.addmm_(emb1, emb2.t(), beta=1, alpha=-2)
221 |     # dist_mtx = dist_mtx.clamp(min = 1e-12)
222 |     dist_mtx = dist_mtx.clamp(min=1e-12).sqrt()
223 |     return dist_mtx
224 |
225 |
226 | def pdist_np(emb1, emb2):
227 |     '''
228 |     compute the euclidean distance matrix between embeddings1 and embeddings2
229 |     using cpu
230 |     '''
231 |     m, n = emb1.shape[0], emb2.shape[0]
232 |     emb1_pow = np.square(emb1).sum(axis=1)[..., np.newaxis]
233 |     emb2_pow = np.square(emb2).sum(axis=1)[np.newaxis, ...]
234 |     dist_mtx = -2 * np.matmul(emb1, emb2.T) + emb1_pow + emb2_pow
235 |     # dist_mtx = np.sqrt(dist_mtx.clip(min = 1e-12))
236 |     return dist_mtx  # note: returns squared distances; the sqrt above is commented out
-------------------------------------------------------------------------------- /model_bn.py: --------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import init
4 | import torch.nn.functional as F
5 | from resnet import resnet18, resnet50, resnet101
6 | from loss import sce, OriTripletLoss, shape_cpmt_cross_modal_ce
7 | class Normalize(nn.Module):
8 |     def __init__(self, power=2):
9 |         super(Normalize, self).__init__()
10 |         self.power = power
11 |
12 |     def forward(self, x):
13 |         norm = x.pow(self.power).sum(1, keepdim=True).pow(1. / self.power)
14 |         out = x.div(norm)
15 |         return out
16 |
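# --- Editor's example (not part of the original model_bn.py) ------------------
# Normalize is a plain Lp-normalization layer; with the default power=2 every
# row of the output has unit L2 norm, which is what the retrieval code expects
# before computing cosine similarities between query and gallery features.
if __name__ == '__main__':
    _x = torch.randn(4, 2048)
    print(Normalize(2)(_x).norm(dim=1))   # expected: four values close to 1.0
# ------------------------------------------------------------------------------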
17 | class Non_local(nn.Module):
18 |     def __init__(self, in_channels, reduc_ratio=2):
19 |         super(Non_local, self).__init__()
20 |
21 |         self.in_channels = in_channels
22 |         self.inter_channels = reduc_ratio // reduc_ratio  # note: this evaluates to 1; in_channels // reduc_ratio was likely intended
23 |
24 |         self.g = nn.Sequential(
25 |             nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1,
26 |                       padding=0),
27 |         )
28 |
29 |         self.W = nn.Sequential(
30 |             nn.Conv2d(in_channels=self.inter_channels, out_channels=self.in_channels,
31 |                       kernel_size=1, stride=1, padding=0),
32 |             nn.BatchNorm2d(self.in_channels),
33 |         )
34 |         self.Wbn_shape = nn.BatchNorm2d(self.in_channels)   # separate output BN for the shape modality
35 |         nn.init.constant_(self.W[1].weight, 0.0)
36 |         nn.init.constant_(self.W[1].bias, 0.0)
37 |         nn.init.constant_(self.Wbn_shape.weight, 0.0)
38 |         nn.init.constant_(self.Wbn_shape.bias, 0.0)
39 |
40 |
41 |         self.theta = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
42 |                                kernel_size=1, stride=1, padding=0)
43 |
44 |         self.phi = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
45 |                              kernel_size=1, stride=1, padding=0)
46 |
47 |     def forward(self, x, shape=False):
48 |         '''
49 |         :param x: (b, c, h, w)
50 |         :return:
51 |         '''
52 |
53 |         batch_size = x.size(0)
54 |         g_x = self.g(x).view(batch_size, self.inter_channels, -1)
55 |         g_x = g_x.permute(0, 2, 1)
56 |
57 |         theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
58 |         theta_x = theta_x.permute(0, 2, 1)
59 |         phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)
60 |         f = torch.matmul(theta_x, phi_x)
61 |         N = f.size(-1)
62 |         # f_div_C = torch.nn.functional.softmax(f, dim=-1)
63 |         f_div_C = f / N
64 |
65 |         y = torch.matmul(f_div_C, g_x)
66 |         y = y.permute(0, 2, 1).contiguous()
67 |         y = y.view(batch_size, self.inter_channels, *x.size()[2:])
68 |         if shape:
69 |             W_y = self.Wbn_shape(self.W[0](y))
70 |         else:
71 |             W_y = self.W(y)
72 |         z = W_y + x
73 |
74 |         return z
75 |
76 |
77 | # #####################################################################
78 | def weights_init_kaiming(m):
79 |     classname = m.__class__.__name__
80 |     # print(classname)
81 |     if classname.find('Conv') != -1:
82 |         init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
83 |     elif classname.find('Linear') != -1:
84 |         init.kaiming_normal_(m.weight.data, a=0, mode='fan_out')
85 |         if m.bias is not None:
86 |             init.zeros_(m.bias.data)
87 |     elif classname.find('BatchNorm1d') != -1:
88 |         init.normal_(m.weight.data, 1.0, 0.01)
89 |         init.zeros_(m.bias.data)
90 |
91 | def weights_init_classifier(m):
92 |     classname = m.__class__.__name__
93 |     if classname.find('Linear') != -1:
94 |         init.normal_(m.weight.data, 0, 0.001)
95 |         if m.bias is not None:
96 |             init.zeros_(m.bias.data)
97 |
98 |
99 |
100 | class visible_module(nn.Module):
101 |     def __init__(self, isshape, modalbn):
102 |         super(visible_module, self).__init__()
103 |
104 |         model_v = resnet50(pretrained=True,
105 |                            last_conv_stride=1, last_conv_dilation=1, isshape=isshape, onlyshallow=True, modalbn=modalbn)
106 |         print('visible module:', model_v.isshape, model_v.modalbn)
107 |
108 |         # shallow stem only (conv1/bn1/relu/maxpool); the shared trunk lives in base_resnet
109 |         self.visible = model_v
110 |
111 |     def forward(self, x, modal=0):
112 |         x = self.visible.conv1(x)
113 |         if modal == 0:  # RGB
114 |             bbn1 = self.visible.bn1
115 |         elif modal == 3:  # shape
116 |             bbn1 = self.visible.bn1_shape
117 |         x = bbn1(x)
118 |         x = self.visible.relu(x)
119 |         x = self.visible.maxpool(x)
120 |         return x
121 |
122 |
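# --- Editor's example (not part of the original model_bn.py) ------------------
# The `modal` flag is how one set of shared stem weights is paired with
# modality-specific BatchNorm: modal=0 routes through bn1 (RGB statistics),
# modal=3 through bn1_shape, while conv1 itself is identical. Running this
# downloads the ImageNet ResNet-50 checkpoint.
if __name__ == '__main__':
    _stem = visible_module(isshape=True, modalbn=1)
    _rgb = torch.randn(2, 3, 288, 144)
    _out_rgb = _stem(_rgb, modal=0)     # uses self.visible.bn1
    _out_shape = _stem(_rgb, modal=3)   # uses self.visible.bn1_shape
    print(_out_rgb.shape, _out_shape.shape)   # both: (2, 64, 72, 36)
# ------------------------------------------------------------------------------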
123 | class thermal_module(nn.Module):
124 |     def __init__(self, isshape, modalbn):
125 |         super(thermal_module, self).__init__()
126 |
127 |         model_t = resnet50(pretrained=True,
128 |                            last_conv_stride=1, last_conv_dilation=1, isshape=isshape, onlyshallow=True, modalbn=modalbn)
129 |         print('thermal module:', model_t.isshape, model_t.modalbn)
130 |
131 |         # shallow stem only (conv1/bn1/relu/maxpool); the shared trunk lives in base_resnet
132 |         self.thermal = model_t
133 |
134 |
135 |     def forward(self, x, modal=1):
136 |         x = self.thermal.conv1(x)
137 |         if modal == 1:  # IR
138 |             bbn1 = self.thermal.bn1
139 |         elif modal == 3:  # shape
140 |             bbn1 = self.thermal.bn1_shape
141 |         x = bbn1(x)
142 |         x = self.thermal.relu(x)
143 |         x = self.thermal.maxpool(x)
144 |         return x
145 |
146 |
147 |
148 |
149 | class base_resnet(nn.Module):
150 |     def __init__(self, isshape, modalbn):
151 |         super(base_resnet, self).__init__()
152 |
153 |         model_base = resnet50(pretrained=True,
154 |                               last_conv_stride=1, last_conv_dilation=1, isshape=isshape, modalbn=modalbn)
155 |         print('base resnet:', model_base.isshape, model_base.modalbn)
156 |         # avg pooling to global pooling
157 |         model_base.avgpool = nn.AdaptiveAvgPool2d((1, 1))
158 |         self.base = model_base
159 |
160 |     def forward(self, x, modal=0):
161 |         x = self.base.layer1(x, modal)
162 |         x = self.base.layer2(x, modal)
163 |         x = self.base.layer3(x, modal)
164 |         x = self.base.layer4(x, modal)
165 |         return x
166 |
167 | class embed_net(nn.Module):
168 |     def __init__(self, class_num, no_local='on', gm_pool='on', arch='resnet50'):
169 |         super(embed_net, self).__init__()
170 |         self.isshape = True
171 |         self.modalbn = 2
172 |
173 |         self.thermal_module = thermal_module(self.isshape, 1)
174 |         self.visible_module = visible_module(self.isshape, 1)
175 |         self.base_resnet = base_resnet(self.isshape, self.modalbn)
176 |
177 |         # TODO init_bn or not
178 |         self.base_resnet.base.init_bn()
179 |         self.thermal_module.thermal.init_bn()
180 |         self.visible_module.visible.init_bn()
181 |         self.non_local = no_local
182 |         if self.non_local == 'on':
183 |             layers = [3, 4, 6, 3]
184 |             non_layers = [0, 2, 3, 0]
185 |             self.NL_1 = nn.ModuleList(
186 |                 [Non_local(256) for i in range(non_layers[0])])
187 |             self.NL_1_idx = sorted([layers[0] - (i + 1) for i in range(non_layers[0])])
188 |             self.NL_2 = nn.ModuleList(
189 |                 [Non_local(512) for i in range(non_layers[1])])
190 |             self.NL_2_idx = sorted([layers[1] - (i + 1) for i in range(non_layers[1])])
191 |             self.NL_3 = nn.ModuleList(
192 |                 [Non_local(1024) for i in range(non_layers[2])])
193 |             self.NL_3_idx = sorted([layers[2] - (i + 1) for i in range(non_layers[2])])
194 |             self.NL_4 = nn.ModuleList(
195 |                 [Non_local(2048) for i in range(non_layers[3])])
196 |             self.NL_4_idx = sorted([layers[3] - (i + 1) for i in range(non_layers[3])])
197 |
198 |         pool_dim = 2048
199 |         kk = 4   # the shape-related projections reduce pool_dim by this factor
200 |         self.l2norm = Normalize(2)
201 |         self.bottleneck = nn.BatchNorm1d(pool_dim)
202 |         self.bottleneck.bias.requires_grad_(False)  # no shift
203 |         self.classifier = nn.Linear(pool_dim, class_num, bias=False)
204 |         self.bottleneck.apply(weights_init_kaiming)
205 |         self.classifier.apply(weights_init_classifier)
206 |
207 |
208 |         if self.isshape:
209 |             self.bottleneck_shape = nn.BatchNorm1d(pool_dim)
210 |             self.bottleneck_shape.bias.requires_grad_(False)  # no shift
211 |             self.classifier_shape = nn.Linear(pool_dim // kk, class_num, bias=False)
212 |
213 |             self.projs = nn.ParameterList([])
214 |             proj = nn.Parameter(torch.zeros([pool_dim, pool_dim // kk], dtype=torch.float32, requires_grad=True))
215 |             # proj2 = nn.Parameter(torch.zeros([pool_dim,pool_dim//4*3], dtype=torch.float32, requires_grad=True))
216 |             proj_shape =
nn.Parameter(torch.zeros([pool_dim,pool_dim//kk], dtype=torch.float32, requires_grad=True)) 217 | 218 | nn.init.kaiming_normal_(proj, nonlinearity="linear") 219 | nn.init.kaiming_normal_(proj_shape, nonlinearity="linear") 220 | self.bottleneck_shape.apply(weights_init_kaiming) 221 | self.classifier_shape.apply(weights_init_classifier) 222 | self.projs.append(proj) 223 | self.projs.append(proj_shape) 224 | if self.modalbn >= 2: 225 | self.bottleneck_ir = nn.BatchNorm1d(pool_dim) 226 | self.bottleneck_ir.bias.requires_grad_(False) # no shift 227 | self.classifier_ir = nn.Linear(pool_dim//4, class_num, bias=False) 228 | self.bottleneck_ir.apply(weights_init_kaiming) 229 | self.classifier_ir.apply(weights_init_classifier) 230 | if self.modalbn == 3: 231 | self.bottleneck_modalx = nn.BatchNorm1d(pool_dim) 232 | self.bottleneck_modalx.bias.requires_grad_(False) # no shift 233 | # self.classifier_ = nn.Linear(pool_dim, class_num, bias=False) 234 | self.bottleneck_modalx.apply(weights_init_kaiming) 235 | # self.classifier_rgb.apply(weights_init_classifier) 236 | self.proj_modalx = nn.Linear(pool_dim, pool_dim//kk, bias=False) 237 | self.proj_modalx.apply(weights_init_kaiming) 238 | 239 | 240 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 241 | self.gm_pool = gm_pool 242 | def forward(self, x1, x2, x1_shape=None, x2_shape=None, mode=0): 243 | if mode == 0: # training 244 | 245 | x1 = self.visible_module(x1) 246 | x2 = self.thermal_module(x2) 247 | if x1_shape is not None: 248 | x1_shape = self.visible_module(x1_shape, modal=3) 249 | x2_shape = self.thermal_module(x2_shape, modal=3) 250 | x_shape = torch.cat((x1_shape, x2_shape), 0) 251 | 252 | elif mode == 1: # eval rgb 253 | x = self.visible_module(x1) 254 | elif mode == 2: # eval ir 255 | x = self.thermal_module(x2) 256 | 257 | # shared block 258 | if mode > 0: # eval, only one modality per forward 259 | x = self.base_resnet(x, modal=mode-1) 260 | else: # training 261 | x1 = self.base_resnet(x1, modal=0) 262 | x2 = self.base_resnet(x2, modal=1) 263 | x = torch.cat((x1, x2), 0) 264 | 265 | if mode == 0 and x1_shape is not None: # shape for training 266 | x_shape = self.base_resnet(x_shape, modal=3) 267 | 268 | # gempooling 269 | b, c, h, w = x.shape 270 | x = x.view(b, c, -1) 271 | p = 3.0 272 | x_pool = (torch.mean(x**p, dim=-1) + 1e-12)**(1/p) 273 | 274 | if mode == 0 and x1_shape is not None: 275 | b, c, h, w = x_shape.shape 276 | x_shape = x_shape.view(b, c, -1) 277 | p = 3.0 278 | x_pool_shape = (torch.mean(x_shape**p, dim=-1) + 1e-12)**(1/p) 279 | 280 | # BNNeck 281 | if mode == 1: 282 | feat = self.bottleneck(x_pool) 283 | elif mode == 2: 284 | feat = self.bottleneck_ir(x_pool) 285 | elif mode == 0: 286 | assert x1.shape[0] == x2.shape[0] 287 | feat1 = self.bottleneck(x_pool[:x1.shape[0]]) 288 | feat2 = self.bottleneck_ir(x_pool[x1.shape[0]:]) 289 | feat = torch.cat((feat1, feat2), 0) 290 | if mode == 0 and x1_shape is not None: 291 | feat_shape = self.bottleneck_shape(x_pool_shape) 292 | 293 | # shape-erased feature 294 | if mode == 0: 295 | if x1_shape is not None: 296 | feat_p = torch.mm(feat, self.projs[0]) 297 | proj_norm = F.normalize(self.projs[0], 2, 0) 298 | 299 | feat_pnpn = torch.mm(torch.mm(feat, proj_norm), proj_norm.t()) 300 | 301 | feat_shape_p = torch.mm(feat_shape, self.projs[1]) 302 | 303 | logit2_rgbir = self.classifier(feat-feat_pnpn) 304 | logit_rgbir = self.classifier(feat) 305 | logit_shape = self.classifier_shape(feat_shape_p) 306 | 307 | return {'rgbir':{'bef':x_pool, 'aft':feat, 'logit': logit_rgbir, 'logit2': 
logit2_rgbir, 'zp': feat_p, 'other': feat - feat_pnpn}, 'shape': {'bef': x_pool_shape, 'aft': feat_shape, 'logit': logit_shape, 'zp': feat_shape_p}}
308 |
309 |             else:
310 |                 return x_pool, self.classifier(feat)
311 |         else:
312 |
313 |             return self.l2norm(x_pool), self.l2norm(feat)
314 |     def myparameters(self):  # backbone parameters only; classifier/projection/BN/bottleneck weights are excluded
315 |         res = []
316 |         for k, v in self.named_parameters():
317 |             if v.requires_grad:
318 |                 if 'classifier' in k or 'proj' in k or 'bn' in k or 'bottleneck' in k:
319 |                     continue
320 |                 res.append(v)
321 |         return res
322 |
-------------------------------------------------------------------------------- /pre_process_sysu.py: --------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 | import pdb
4 | import os
5 |
6 |
7 | data_path = '/home/share/reid_dataset/SYSU-MM01/'
8 | data_path1 = '/home/share/fengjw/SYSU_MM01_SHAPE/'
9 | data_path2 = '/home/share/fengjw/SYSU-MM01_SHAPE_composepow/'
10 |
11 |
12 | rgb_cameras = ['cam1', 'cam2', 'cam4', 'cam5']
13 | ir_cameras = ['cam3', 'cam6']
14 |
15 | # load id info
16 | file_path_train = os.path.join(data_path, 'exp/train_id.txt')
17 | file_path_val = os.path.join(data_path, 'exp/val_id.txt')
18 | with open(file_path_train, 'r') as file:
19 |     ids = file.read().splitlines()
20 |     ids = [int(y) for y in ids[0].split(',')]
21 |     id_train = ["%04d" % x for x in ids]
22 |     print(len(ids))
23 | with open(file_path_val, 'r') as file:
24 |     ids = file.read().splitlines()
25 |     ids = [int(y) for y in ids[0].split(',')]
26 |     id_val = ["%04d" % x for x in ids]
27 |     print(len(ids))
28 | # combine train and val split
29 | id_train.extend(id_val)
30 |
31 | files_rgb = []
32 | files_rgb_shape = []
33 | files_rgb_mask = []
34 | files_ir = []
35 | files_ir_shape = []
36 | files_ir_mask = []
37 | for id in sorted(id_train):
38 |     for cam in rgb_cameras:
39 |         img_dir = os.path.join(data_path, cam, id)
40 |         if os.path.isdir(img_dir):
41 |             new_files = sorted([img_dir + '/' + i for i in os.listdir(img_dir)])
42 |             files_rgb.extend(new_files)
43 |         img_dir1 = os.path.join(data_path1, cam, id)
44 |         if os.path.isdir(img_dir1):
45 |             new_files = sorted([img_dir1 + '/' + i for i in os.listdir(img_dir1)])
46 |             files_rgb_shape.extend(new_files)
47 |         img_dir2 = os.path.join(data_path2, cam, id)
48 |         if os.path.isdir(img_dir2):
49 |             new_files = sorted([img_dir2 + '/' + i for i in os.listdir(img_dir2)])
50 |             files_rgb_mask.extend(new_files)
51 |     for cam in ir_cameras:
52 |         img_dir = os.path.join(data_path, cam, id)
53 |         if os.path.isdir(img_dir):
54 |             new_files = sorted([img_dir + '/' + i for i in os.listdir(img_dir)])
55 |             files_ir.extend(new_files)
56 |         img_dir1 = os.path.join(data_path1, cam, id)
57 |         if os.path.isdir(img_dir1):
58 |             new_files = sorted([img_dir1 + '/' + i for i in os.listdir(img_dir1)])
59 |             files_ir_shape.extend(new_files)
60 |         img_dir2 = os.path.join(data_path2, cam, id)
61 |         if os.path.isdir(img_dir2):
62 |             new_files = sorted([img_dir2 + '/' + i for i in os.listdir(img_dir2)])
63 |             files_ir_mask.extend(new_files)
64 | for i in range(len(files_rgb)):   # sanity checks: the three directory trees must stay aligned file-by-file
65 |     if not files_rgb[i][-19:-1] == files_rgb_mask[i][-19:-1]:
66 |         import pdb
67 |         pdb.set_trace()
68 |     if not files_rgb[i][-19:-1] == files_rgb_shape[i][-19:-1]:
69 |         import pdb
70 |         pdb.set_trace()
71 | for i in range(len(files_ir)):
72 |     if not files_ir[i][-19:-1] == files_ir_mask[i][-19:-1]:
73 |         import pdb
74 |         pdb.set_trace()
75 |     if not files_ir[i][-19:-1] == files_ir_shape[i][-19:-1]:
76 |         import pdb
77 |         pdb.set_trace()
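# --- Editor's note (not part of the original pre_process_sysu.py) -------------
# The fixed slices above assume SYSU-MM01's uniform path layout. For an
# illustrative path p = '/data/SYSU-MM01/cam1/0007/0021.jpg':
#     p[-19:-1]  -> '/cam1/0007/0021.jp'  (camera/id/filename, used to verify
#                                          the three directory trees stay aligned)
#     p[-13:-9]  -> '0007'                (the 4-digit person id)
# ------------------------------------------------------------------------------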
78 | # relabel
79 | pid_container = set()
80 | for img_path in files_ir:
81 |     pid = int(img_path[-13:-9])
82 |     pid_container.add(pid)
83 | print(len(pid_container))
84 | pid2label = {pid: label for label, pid in enumerate(pid_container)}
85 | fix_image_width = 144
86 | fix_image_height = 288
87 | def read_imgs(train_image):
88 |     train_img = []
89 |     train_label = []
90 |     for img_path in train_image:
91 |         # img
92 |         img = Image.open(img_path)
93 |         img = img.resize((fix_image_width, fix_image_height), Image.LANCZOS)  # Image.ANTIALIAS is a removed alias of LANCZOS in Pillow >= 10
94 |         pix_array = np.array(img)
95 |
96 |         train_img.append(pix_array)
97 |
98 |         # label
99 |         pid = int(img_path[-13:-9])
100 |         pid = pid2label[pid]
101 |         train_label.append(pid)
102 |     return np.array(train_img), np.array(train_label)
103 |
104 | # rgb images
105 | rgb_img, rgb_label = read_imgs(files_rgb)
106 | rgb_img_shape, rgb_label_shape = read_imgs(files_rgb_shape)
107 | train_img, train_label = read_imgs(files_ir)
108 | train_img_shape, train_label_shape = read_imgs(files_ir_shape)
109 |
110 | def func(img, img_shape):
111 |     res = []
112 |     shape_mean = img_shape.mean()
113 |     imgsize = 288 * 144
114 |     for i in range(img.shape[0]):
115 |         tmp = (img_shape[i][:, :, 0] > shape_mean).sum() / imgsize  # fraction of above-mean (foreground) pixels in the shape map
116 |         res.append(tmp)
117 |     return np.array(res)
118 |
119 | res_rgb = func(rgb_img, rgb_img_shape)
120 | res_ir = func(train_img, train_img_shape)
121 | thres_rgb = 0.1   # drop samples whose shape map covers too little of the image
122 | thres_ir = 0.01
123 |
124 | np.save(data_path + 'train_rgb_resized_img_new.npy', rgb_img[res_rgb > thres_rgb])
125 | np.save(data_path + 'train_rgb_resized_label_new.npy', rgb_label[res_rgb > thres_rgb])
126 |
127 | np.save(data_path1 + 'train_rgb_resized_img_new.npy', rgb_img_shape[res_rgb > thres_rgb])
128 | np.save(data_path1 + 'train_rgb_resized_label_new.npy', rgb_label_shape[res_rgb > thres_rgb])
129 |
130 | # ir images
131 |
132 |
133 | np.save(data_path + 'train_ir_resized_img_new.npy', train_img[res_ir > thres_ir])
134 | np.save(data_path + 'train_ir_resized_label_new.npy', train_label[res_ir > thres_ir])
135 |
136 | np.save(data_path1 + 'train_ir_resized_img_new.npy', train_img_shape[res_ir > thres_ir])
137 | np.save(data_path1 + 'train_ir_resized_label_new.npy', train_label_shape[res_ir > thres_ir])
138 |
139 |
-------------------------------------------------------------------------------- /resnet.py: --------------------------------------------------------------------------------
1 | import torch.nn.functional as F
2 | import torch.nn as nn
3 | import math
4 | import torch.utils.model_zoo as model_zoo
5 | import torch
6 |
7 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
8 |            'resnet152']
9 |
10 | model_urls = {
11 |     'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
12 |     'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
13 |     'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
14 |     'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
15 |     'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
16 | }
17 |
18 |
19 | def conv3x3(in_planes, out_planes, stride=1, dilation=1):
20 |     """3x3 convolution with padding"""
21 |     # original padding is 1; original dilation is 1
22 |     return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
23 |                      padding=dilation, bias=False, dilation=dilation)
24 |
25 |
26 | class Sequential(nn.Sequential):
27 |     # def __init__(self, *args):
28 |     #     super(Sequential, self).__init__()
29 |     def forward(self, input, modal=0):  # unlike nn.Sequential, every submodule also receives the modality flag
30 |         for module in self:
31 |             res = module(input, modal)
32 |             input = res
33 |         return res
34 |
35 |
36 | class
BasicBlock(nn.Module): 37 | expansion = 1 38 | 39 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1): 40 | super(BasicBlock, self).__init__() 41 | self.conv1 = conv3x3(inplanes, planes, stride, dilation) 42 | self.bn1 = nn.BatchNorm2d(planes) 43 | self.relu = nn.ReLU(inplace=True) 44 | self.conv2 = conv3x3(planes, planes) 45 | self.bn2 = nn.BatchNorm2d(planes) 46 | self.downsample = downsample 47 | self.stride = stride 48 | 49 | def forward(self, x): 50 | residual = x 51 | 52 | out = self.conv1(x) 53 | out = self.bn1(out) 54 | out = self.relu(out) 55 | 56 | out = self.conv2(out) 57 | out = self.bn2(out) 58 | 59 | if self.downsample is not None: 60 | residual = self.downsample(x) 61 | 62 | out += residual 63 | out = self.relu(out) 64 | 65 | return out 66 | 67 | 68 | class Bottleneck(nn.Module): 69 | expansion = 4 70 | 71 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1, isshape=False, modalbn=1): 72 | super(Bottleneck, self).__init__() 73 | self.isshape = isshape 74 | self.modalbn = modalbn 75 | assert modalbn == 1 or modalbn == 2 or modalbn == 3 76 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 77 | self.bn1 = nn.BatchNorm2d(planes) 78 | if isshape: 79 | self.bn1_shape = nn.BatchNorm2d(planes) 80 | if modalbn == 2: 81 | self.bn1_ir = nn.BatchNorm2d(planes) 82 | if modalbn == 3: 83 | self.bn1_ir = nn.BatchNorm2d(planes) 84 | self.bn1_modalx = nn.BatchNorm2d(planes) 85 | 86 | # original padding is 1; original dilation is 1 87 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=dilation, bias=False, dilation=dilation) 88 | self.bn2 = nn.BatchNorm2d(planes) 89 | if isshape: 90 | self.bn2_shape = nn.BatchNorm2d(planes) 91 | if modalbn == 2: 92 | self.bn2_ir = nn.BatchNorm2d(planes) 93 | if modalbn == 3: 94 | self.bn2_ir = nn.BatchNorm2d(planes) 95 | self.bn2_modalx = nn.BatchNorm2d(planes) 96 | 97 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 98 | self.bn3 = nn.BatchNorm2d(planes * 4) 99 | if isshape: 100 | self.bn3_shape = nn.BatchNorm2d(planes * 4) 101 | if modalbn == 2: 102 | self.bn3_ir = nn.BatchNorm2d(planes * 4) 103 | if modalbn == 3: 104 | self.bn3_ir = nn.BatchNorm2d(planes * 4) 105 | self.bn3_modalx = nn.BatchNorm2d(planes * 4) 106 | 107 | 108 | self.relu = nn.ReLU(inplace=True) 109 | self.downsample = downsample 110 | if downsample is not None: 111 | if isshape: 112 | self.dsbn_shape = nn.BatchNorm2d(downsample[1].weight.shape[0]) 113 | if modalbn == 2: 114 | self.dsbn_ir = nn.BatchNorm2d(downsample[1].weight.shape[0]) 115 | if modalbn == 3: 116 | self.dsbn_ir = nn.BatchNorm2d(downsample[1].weight.shape[0]) 117 | self.dsbn_modalx = nn.BatchNorm2d(downsample[1].weight.shape[0]) 118 | self.stride = stride 119 | 120 | def forward(self, x, modal=0): 121 | if modal == 0: # RGB 122 | bbn1 = self.bn1 123 | bbn2 = self.bn2 124 | bbn3 = self.bn3 125 | if self.downsample is not None: 126 | dsbn = self.downsample[1] 127 | elif modal == 1: # IR 128 | bbn1 = self.bn1_ir 129 | bbn2 = self.bn2_ir 130 | bbn3 = self.bn3_ir 131 | if self.downsample is not None: 132 | dsbn = self.dsbn_ir 133 | elif modal == 2: # modalx 134 | bbn1 = self.bn1_modalx 135 | bbn2 = self.bn2_modalx 136 | bbn3 = self.bn3_modalx 137 | if self.downsample is not None: 138 | dsbn = self.dsbn_modalx 139 | elif modal == 3: # shape 140 | assert self.isshape == True 141 | bbn1 = self.bn1_shape 142 | bbn2 = self.bn2_shape 143 | bbn3 = self.bn3_shape 144 | if self.downsample is not None: 145 | dsbn = 
self.dsbn_shape 146 | 147 | residual = x 148 | 149 | out = self.conv1(x) 150 | out = bbn1(out) 151 | out = self.relu(out) 152 | out = self.conv2(out) 153 | out = bbn2(out) 154 | out = self.relu(out) 155 | out = self.conv3(out) 156 | out = bbn3(out) 157 | 158 | if self.downsample is not None: 159 | residual = dsbn(self.downsample[0](x)) 160 | 161 | out += residual 162 | outt = F.relu(out) 163 | 164 | return outt 165 | 166 | 167 | class ResNet(nn.Module): 168 | 169 | def __init__(self, block, layers, last_conv_stride=2, last_conv_dilation=1, isshape=False, modalbn=1, onlyshallow=False): 170 | self.isshape = isshape 171 | self.modalbn = modalbn 172 | self.inplanes = 64 173 | super(ResNet, self).__init__() 174 | # if onlyshallow: 175 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 176 | bias=False) 177 | self.bn1 = nn.BatchNorm2d(64) 178 | if self.isshape: 179 | self.bn1_shape = nn.BatchNorm2d(64) 180 | if self.modalbn == 2: 181 | self.bn1_ir = nn.BatchNorm2d(64) 182 | elif self.modalbn == 3: 183 | self.bn1_ir = nn.BatchNorm2d(64) 184 | self.bn1_modalx = nn.BatchNorm2d(64) 185 | self.relu = nn.ReLU(inplace=True) 186 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 187 | # else: 188 | self.layer1 = self._make_layer(block, 64, layers[0]) 189 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 190 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 191 | self.layer4 = self._make_layer(block, 512, layers[3], stride=last_conv_stride, dilation=last_conv_dilation) 192 | 193 | for m in self.modules(): 194 | if isinstance(m, nn.Conv2d): 195 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 196 | m.weight.data.normal_(0, math.sqrt(2. / n)) 197 | elif isinstance(m, nn.BatchNorm2d): 198 | m.weight.data.fill_(1) 199 | m.bias.data.zero_() 200 | 201 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1): 202 | downsample = None 203 | if stride != 1 or self.inplanes != planes * block.expansion: 204 | downsample = nn.Sequential( 205 | nn.Conv2d(self.inplanes, planes * block.expansion, 206 | kernel_size=1, stride=stride, bias=False), 207 | nn.BatchNorm2d(planes * block.expansion), 208 | ) 209 | 210 | 211 | layers = [] 212 | layers.append(block(self.inplanes, planes, stride, downsample, dilation, isshape=self.isshape, modalbn=self.modalbn)) 213 | self.inplanes = planes * block.expansion 214 | for i in range(1, blocks): 215 | layers.append(block(self.inplanes, planes, isshape=self.isshape, modalbn=self.modalbn)) 216 | 217 | return Sequential(*layers) 218 | 219 | def forward(self, x, modal=0): 220 | x = self.conv1(x) 221 | if modal == 0: # RGB 222 | bbn1 = self.bn1 223 | elif modal == 1: # IR 224 | bbn1 = self.bn1_ir 225 | elif modal == 2: # modalx 226 | bbn1 = self.bn1_modalx 227 | elif modal == 3: # shape 228 | assert self.isshape == True 229 | bbn1 = self.bn1_shape 230 | 231 | x = bbn1(x) 232 | 233 | x = self.relu(x) 234 | x = self.maxpool(x) 235 | 236 | x = self.layer1(x, modal) 237 | x = self.layer2(x, modal) 238 | x = self.layer3(x, modal) 239 | x = self.layer4(x, modal) 240 | 241 | return x 242 | 243 | def init_bn_layer(self, layer): 244 | for i in range(len(layer)): 245 | if self.isshape: 246 | layer[i].bn1_shape.weight.data = layer[i].bn1.weight.data.clone() 247 | layer[i].bn1_shape.bias.data = layer[i].bn1.bias.data.clone() 248 | layer[i].bn2_shape.weight.data = layer[i].bn2.weight.data.clone() 249 | layer[i].bn2_shape.bias.data = layer[i].bn2.bias.data.clone() 250 | layer[i].bn3_shape.weight.data = 
layer[i].bn3.weight.data.clone() 251 | layer[i].bn3_shape.bias.data = layer[i].bn3.bias.data.clone() 252 | if layer[i].downsample is not None: 253 | layer[i].dsbn_shape.weight.data = layer[i].downsample[1].weight.data.clone() 254 | layer[i].dsbn_shape.bias.data = layer[i].downsample[1].bias.data.clone() 255 | if self.modalbn >= 2: 256 | layer[i].bn1_ir.weight.data = layer[i].bn1.weight.data.clone() 257 | layer[i].bn1_ir.bias.data = layer[i].bn1.bias.data.clone() 258 | layer[i].bn2_ir.weight.data = layer[i].bn2.weight.data.clone() 259 | layer[i].bn2_ir.bias.data = layer[i].bn2.bias.data.clone() 260 | layer[i].bn3_ir.weight.data = layer[i].bn3.weight.data.clone() 261 | layer[i].bn3_ir.bias.data = layer[i].bn3.bias.data.clone() 262 | if layer[i].downsample is not None: 263 | layer[i].dsbn_ir.weight.data = layer[i].downsample[1].weight.data.clone() 264 | layer[i].dsbn_ir.bias.data = layer[i].downsample[1].bias.data.clone() 265 | if self.modalbn == 3: 266 | layer[i].bn1_modalx.weight.data = layer[i].bn1.weight.data.clone() 267 | layer[i].bn1_modalx.bias.data = layer[i].bn1.bias.data.clone() 268 | layer[i].bn2_modalx.weight.data = layer[i].bn2.weight.data.clone() 269 | layer[i].bn2_modalx.bias.data = layer[i].bn2.bias.data.clone() 270 | layer[i].bn3_modalx.weight.data = layer[i].bn3.weight.data.clone() 271 | layer[i].bn3_modalx.bias.data = layer[i].bn3.bias.data.clone() 272 | if layer[i].downsample is not None: 273 | layer[i].dsbn_modalx.weight.data = layer[i].downsample[1].weight.data.clone() 274 | layer[i].dsbn_modalx.bias.data = layer[i].downsample[1].bias.data.clone() 275 | def init_bn(self, onlyshallow=False): 276 | # if onlyshallow: 277 | if self.isshape: 278 | self.bn1_shape.weight.data = self.bn1.weight.data.clone() 279 | self.bn1_shape.bias.data = self.bn1.bias.data.clone() 280 | if self.modalbn >= 2: 281 | self.bn1_ir.weight.data = self.bn1.weight.data.clone() 282 | self.bn1_ir.bias.data = self.bn1.bias.data.clone() 283 | if self.modalbn == 3: 284 | self.bn1_modalx.weight.data = self.bn1.weight.data.clone() 285 | self.bn1_modalx.bias.data = self.bn1.bias.data.clone() 286 | # else: 287 | self.init_bn_layer(self.layer1) 288 | self.init_bn_layer(self.layer2) 289 | self.init_bn_layer(self.layer3) 290 | self.init_bn_layer(self.layer4) 291 | 292 | 293 | def average_bn_layer(self, layer): 294 | for i in range(len(layer)): 295 | bn1w = (layer[i].bn1.weight.data.clone()+layer[i].bn1_ir.weight.data.clone()+layer[i].bn1_shape.weight.data.clone())/3 296 | bn1b = (layer[i].bn1.bias.data.clone()+layer[i].bn1_ir.bias.data.clone()+layer[i].bn1_shape.bias.data.clone())/3 297 | bn2w = (layer[i].bn2.weight.data.clone()+layer[i].bn2_ir.weight.data.clone()+layer[i].bn2_shape.weight.data.clone())/3 298 | bn2b = (layer[i].bn2.bias.data.clone()+layer[i].bn2_ir.bias.data.clone()+layer[i].bn2_shape.bias.data.clone())/3 299 | bn3w = (layer[i].bn3.weight.data.clone()+layer[i].bn3_ir.weight.data.clone()+layer[i].bn3_shape.weight.data.clone())/3 300 | bn3b = (layer[i].bn3.bias.data.clone()+layer[i].bn3_ir.bias.data.clone()+layer[i].bn3_shape.bias.data.clone())/3 301 | if layer[i].downsample is not None: 302 | dbbnw = (layer[i].downsample[1].weight.data.clone()+layer[i].dsbn_shape.weight.data.clone()+layer[i].dsbn_ir.weight.data.clone())/3 303 | dbbnb = (layer[i].downsample[1].bias.data.clone()+layer[i].dsbn_shape.bias.data.clone()+layer[i].dsbn_ir.bias.data.clone())/3 304 | layer[i].bn1.weight.data = bn1w.clone() 305 | layer[i].bn1.bias.data = bn1b.clone() 306 | layer[i].bn2.weight.data = bn2w.clone() 307 | 
layer[i].bn2.bias.data = bn2b.clone() 308 | layer[i].bn3.weight.data = bn3w.clone() 309 | layer[i].bn3.bias.data = bn3b.clone() 310 | if layer[i].downsample is not None: 311 | layer[i].downsample[1].weight.data = dbbnw.clone() 312 | layer[i].downsample[1].bias.data = dbbnb.clone() 313 | if self.isshape: 314 | layer[i].bn1_shape.weight.data = bn1w.clone() 315 | layer[i].bn1_shape.bias.data = bn1b.clone() 316 | layer[i].bn2_shape.weight.data = bn2w.clone() 317 | layer[i].bn2_shape.bias.data = bn2b.clone() 318 | layer[i].bn3_shape.weight.data = bn3w.clone() 319 | layer[i].bn3_shape.bias.data = bn3b.clone() 320 | if layer[i].downsample is not None: 321 | layer[i].dsbn_shape.weight.data = dbbnw.clone() 322 | layer[i].dsbn_shape.bias.data = dbbnb.clone() 323 | if self.modalbn >= 2: 324 | layer[i].bn1_ir.weight.data = bn1w.clone() 325 | layer[i].bn1_ir.bias.data = bn1b.clone() 326 | layer[i].bn2_ir.weight.data = bn2w.clone() 327 | layer[i].bn2_ir.bias.data = bn2b.clone() 328 | layer[i].bn3_ir.weight.data = bn3w.clone() 329 | layer[i].bn3_ir.bias.data = bn3b.clone() 330 | if layer[i].downsample is not None: 331 | layer[i].dsbn_ir.weight.data = dbbnw.clone() 332 | layer[i].dsbn_ir.bias.data = dbbnb.clone() 333 | if self.modalbn == 3: 334 | layer[i].bn1_modalx.weight.data = bn1w.clone() 335 | layer[i].bn1_modalx.bias.data = bn1b.clone() 336 | layer[i].bn2_modalx.weight.data = bn2w.clone() 337 | layer[i].bn2_modalx.bias.data = bn2b.clone() 338 | layer[i].bn3_modalx.weight.data = bn3w.clone() 339 | layer[i].bn3_modalx.bias.data = bn3b.clone() 340 | if layer[i].downsample is not None: 341 | layer[i].dsbn_modalx.weight.data = dbbnw.clone() 342 | layer[i].dsbn_modalx.bias.data = dbbnb.clone() 343 | def average_bn(self, onlyshallow=False): 344 | if onlyshallow: 345 | tmpw = self.bn1.weight.data.clone() + self.bn1_shape.weight.data.clone() 346 | tmpw /= 2 347 | tmpb = self.bn1.bias.data.clone() + self.bn1_shape.bias.data.clone() 348 | tmpb /= 2 349 | self.bn1.weight.data = tmpw.clone() 350 | self.bn1.bias.data = tmpb.clone() 351 | 352 | if self.isshape: 353 | self.bn1_shape.weight.data = tmpw.clone() 354 | self.bn1_shape.bias.data = tmpb.clone() 355 | if self.modalbn >= 2: 356 | self.bn1_ir.weight.data = tmpw.clone() 357 | self.bn1_ir.bias.data = tmpb.clone() 358 | if self.modalbn == 3: 359 | self.bn1_modalx.weight.data = tmpw.clone() 360 | self.bn1_modalx.bias.data = tmpb.clone() 361 | else: 362 | self.average_bn_layer(self.layer1) 363 | self.average_bn_layer(self.layer2) 364 | self.average_bn_layer(self.layer3) 365 | self.average_bn_layer(self.layer4) 366 | # def last_layer_shared(self, ): 367 | # res = [] 368 | # for k, v in self.layer4.named_parameters(): 369 | # if 'conv' in k: 370 | # res.append(v) 371 | # return res 372 | 373 | def remove_fc(state_dict): 374 | """Remove the fc layer parameters from state_dict.""" 375 | # for key, value in state_dict.items(): 376 | for key, value in list(state_dict.items()): 377 | if key.startswith('fc.'): 378 | del state_dict[key] 379 | return state_dict 380 | 381 | 382 | def resnet18(pretrained=False, **kwargs): 383 | """Constructs a ResNet-18 model. 384 | Args: 385 | pretrained (bool): If True, returns a model pre-trained on ImageNet 386 | """ 387 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 388 | if pretrained: 389 | model.load_state_dict(remove_fc(model_zoo.load_url(model_urls['resnet18']))) 390 | return model 391 | 392 | 393 | def resnet34(pretrained=False, **kwargs): 394 | """Constructs a ResNet-34 model. 
395 |     Args:
396 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
397 |     """
398 |     model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
399 |     if pretrained:
400 |         model.load_state_dict(remove_fc(model_zoo.load_url(model_urls['resnet34'])))
401 |     return model
402 |
403 |
404 | def resnet50(pretrained=False, **kwargs):
405 |     """Constructs a ResNet-50 model.
406 |     Args:
407 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
408 |     """
409 |     model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
410 |     if pretrained:
411 |         # model.load_state_dict(remove_fc(model_zoo.load_url(model_urls['resnet50'])))
412 |         model.load_state_dict(remove_fc(model_zoo.load_url(model_urls['resnet50'])), strict=False)  # strict=False: the extra modality-specific BN branches are absent from the checkpoint
413 |     return model
414 |
415 |
416 | def resnet101(pretrained=False, **kwargs):
417 |     """Constructs a ResNet-101 model.
418 |     Args:
419 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
420 |     """
421 |     model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
422 |     if pretrained:
423 |         model.load_state_dict(
424 |             remove_fc(model_zoo.load_url(model_urls['resnet101'])), strict=False)
425 |     return model
426 |
427 |
428 | def resnet152(pretrained=False, **kwargs):
429 |     """Constructs a ResNet-152 model.
430 |     Args:
431 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
432 |     """
433 |     model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
434 |     if pretrained:
435 |         model.load_state_dict(
436 |             remove_fc(model_zoo.load_url(model_urls['resnet152'])), strict=False)
437 |     return model
-------------------------------------------------------------------------------- /run.sh: --------------------------------------------------------------------------------
1 | #!/bin/bash
2 | python train.py --dataset sysu --lr 0.1 --method test --gpu 4 --date 4.11 --gradclip 11 --seed 3 --gpuversion 3090
-------------------------------------------------------------------------------- /testy.py: --------------------------------------------------------------------------------
1 | import argparse
2 | import sys
3 | import time
4 | import torch
5 | from torch.autograd import Variable
6 | import torch.utils.data as data
7 | import torchvision.transforms as transforms
8 | from data_loader import SYSUData, TestData
9 | from data_manager import *
10 | from eval_metrics import eval_sysu, eval_regdb
11 | from model_bn import embed_net
12 | from utils import *
13 | from eval_sysu import eval_cross_cmc_map
14 | import pdb
15 |
16 | parser = argparse.ArgumentParser(description='PyTorch Cross-Modality test on SYSU-MM01')
17 | parser.add_argument('--dataset', default='sysu', help='dataset name: regdb or sysu')
18 | parser.add_argument('--arch', default='resnet50', type=str,
19 |                     help='network baseline: resnet18 or resnet50')
20 | parser.add_argument('--resume', '-r', default='', type=str,
21 |                     help='resume from checkpoint')
22 | parser.add_argument('--workers', default=8, type=int, metavar='N',
23 |                     help='number of data loading workers (default: 8)')
24 | parser.add_argument('--img_w', default=144, type=int,
25 |                     metavar='imgw', help='img width')
26 | parser.add_argument('--img_h', default=288, type=int,
27 |                     metavar='imgh', help='img height')
28 | parser.add_argument('--test-batch', default=64, type=int,
29 |                     metavar='tb', help='testing batch size')
30 | parser.add_argument('--method', default='agw', type=str,
31 |                     metavar='m', help='method type: base or agw, adp')
32 | parser.add_argument('--trial', default=1, type=int,
33 |                     metavar='t', help='trial (only for RegDB dataset)')
34 |
parser.add_argument('--seed', default=0, type=int, 35 | metavar='t', help='random seed') 36 | parser.add_argument('--gpu', default='0', type=str, 37 | help='gpu device ids for CUDA_VISIBLE_DEVICES') 38 | parser.add_argument('--mode', default='all', type=str, help='all or indoor') 39 | parser.add_argument('--date', default='') 40 | parser.add_argument('--gpuversion', default= '3090', type=str) 41 | path_dict = {} 42 | path_dict['3090'] = ['/home/share/reid_dataset/SYSU-MM01/', '/home/share/fengjw/SYSU_MM01_SHAPE/'] 43 | path_dict['4090'] = ['/home/jiawei/data/SYSU-MM01/', '/home/jiawei/data/SYSU_MM01_SHAPE/'] 44 | args = parser.parse_args() 45 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 46 | 47 | set_seed(args.seed) 48 | 49 | 50 | test_mode = [1, 2] # thermal to visible 51 | pool_dim = 2048 52 | def extract_gall_feat(gall_loader): 53 | net.eval() 54 | print ('Extracting Gallery Feature...') 55 | start = time.time() 56 | ptr = 0 57 | ngall = len(gall_loader.dataset) 58 | gall_feat_pool = np.zeros((ngall, pool_dim)) 59 | gall_feat_fc = np.zeros((ngall, pool_dim)) 60 | with torch.no_grad(): 61 | for batch_idx, (input, label, cam) in enumerate(gall_loader): 62 | batch_num = input.size(0) 63 | input = Variable(input.cuda()) 64 | 65 | feat_pool, feat_fc = get_feature(input, test_mode[0]) 66 | # feat_pool, feat_fc = net(input, input, mode=test_mode[0]) 67 | gall_feat_pool[ptr:ptr+batch_num,: ] = feat_pool.detach().cpu().numpy() 68 | gall_feat_fc[ptr:ptr+batch_num,: ] = feat_fc.detach().cpu().numpy() 69 | ptr = ptr + batch_num 70 | print('Extracting Time:\t {:.3f}'.format(time.time()-start)) 71 | return gall_feat_pool, gall_feat_fc 72 | 73 | def extract_query_feat(query_loader): 74 | net.eval() 75 | print ('Extracting Query Feature...') 76 | start = time.time() 77 | ptr = 0 78 | nquery = len(query_loader.dataset) 79 | query_feat_pool = np.zeros((nquery, pool_dim)) 80 | query_feat_fc = np.zeros((nquery, pool_dim)) 81 | with torch.no_grad(): 82 | for batch_idx, (input, label, cam) in enumerate(query_loader): 83 | batch_num = input.size(0) 84 | input = Variable(input.cuda()) 85 | feat_pool, feat_fc = get_feature(input, test_mode[1]) 86 | # feat_pool, feat_fc = net(input, input, mode=test_mode[1]) 87 | query_feat_pool[ptr:ptr+batch_num,: ] = feat_pool.detach().cpu().numpy() 88 | query_feat_fc[ptr:ptr+batch_num,: ] = feat_fc.detach().cpu().numpy() 89 | ptr = ptr + batch_num 90 | print('Extracting Time:\t {:.3f}'.format(time.time()-start)) 91 | return query_feat_pool, query_feat_fc 92 | 93 | def fliplr(img): 94 | '''flip horizontal''' 95 | inv_idx = torch.arange(img.size(3)-1,-1,-1,device=img.device).long() # N x C x H x W 96 | img_flip = img.index_select(3,inv_idx) 97 | return img_flip 98 | 99 | def get_feature(input, mode): 100 | feat_pool, feat_fc = net(input, input, mode=mode) 101 | input2 = fliplr(input) 102 | feat_pool2, feat_fc2 = net(input2, input2, mode=mode) 103 | feat_pool = (feat_pool+feat_pool2)/2 104 | feat_fc = (feat_fc+feat_fc2)/2 105 | 106 | return feat_pool, feat_fc 107 | 108 | 109 | print("==========\nArgs:{}\n==========".format(args)) 110 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 111 | best_acc = 0 # best test accuracy 112 | best_acc_ema = 0 # best test accuracy 113 | start_epoch = 0 114 | 115 | print('==> Loading data..') 116 | # Data loading code 117 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 118 | 119 | 120 | transform_test = transforms.Compose( [ 121 | transforms.ToPILImage(), 122 | transforms.Resize((args.img_h, 
args.img_w)), 123 | transforms.ToTensor(), 124 | normalize]) 125 | 126 | end = time.time() 127 | # training set 128 | trainset = SYSUData(data_dir=path_dict[args.gpuversion][0],data_dir1=path_dict[args.gpuversion][1]) 129 | # testing set 130 | query_img, query_label, query_cam = process_query_sysu(mode=args.mode,data_path_ori=path_dict[args.gpuversion][0]) 131 | gall_img, gall_label, gall_cam = process_gallery_sysu(mode=args.mode,data_path_ori=path_dict[args.gpuversion][0]) 132 | 133 | gallset = TestData(gall_img, gall_label, gall_cam, transform=transform_test, img_size=(args.img_w, args.img_h)) 134 | queryset = TestData(query_img, query_label, query_cam, transform=transform_test, img_size=(args.img_w, args.img_h)) 135 | 136 | # testing data loader 137 | gall_loader = data.DataLoader(gallset, batch_size=args.test_batch, shuffle=False, num_workers=args.workers) 138 | query_loader = data.DataLoader(queryset, batch_size=args.test_batch, shuffle=False, num_workers=args.workers) 139 | 140 | n_class = len(np.unique(trainset.train_color_label)) 141 | 142 | net = embed_net(n_class, no_local= 'off', gm_pool = 'on', arch=args.arch) 143 | 144 | net.to(device) 145 | # cudnn.benchmark = True 146 | 147 | # best two 148 | net_path = 'save_model/sysu_step2085_p0.2intercutmix_bothcegkl_gradclip11.0_seed3_KL_1.0_p4_n8_lr_0.1_seed_3_ema_best.t' 149 | 150 | net.load_state_dict(torch.load(net_path)['net']) 151 | net.eval() 152 | 153 | def test(net): 154 | # switch to evaluation mode 155 | 156 | 157 | 158 | net.eval() 159 | query_feat_pool, query_feat_fc = extract_query_feat(query_loader) 160 | for trial in range(10): 161 | 162 | gall_img, gall_label, gall_cam = process_gallery_sysu(mode=args.mode, trial=trial,data_path_ori=path_dict[args.gpuversion][0]) 163 | 164 | trial_gallset = TestData(gall_img, gall_label, gall_cam, transform=transform_test, img_size=(args.img_w, args.img_h)) 165 | trial_gall_loader = data.DataLoader(trial_gallset, batch_size=args.test_batch, shuffle=False, num_workers=4) 166 | 167 | gall_feat_pool, gall_feat_fc = extract_gall_feat(trial_gall_loader) 168 | 169 | # pool5 feature 170 | distmat_pool = np.matmul(query_feat_pool, np.transpose(gall_feat_pool)) 171 | cmc_pool, mAP_pool, mINP_pool = eval_sysu(-distmat_pool, query_label, gall_label, query_cam, gall_cam) 172 | 173 | # fc feature 174 | distmat = np.matmul(query_feat_fc, np.transpose(gall_feat_fc)) 175 | cmc, mAP, mINP = eval_sysu(-distmat, query_label, gall_label, query_cam, gall_cam) 176 | if trial == 0: 177 | all_cmc = cmc 178 | all_mAP = mAP 179 | all_mINP = mINP 180 | all_cmc_pool = cmc_pool 181 | all_mAP_pool = mAP_pool 182 | all_mINP_pool = mINP_pool 183 | else: 184 | all_cmc = all_cmc + cmc 185 | all_mAP = all_mAP + mAP 186 | all_mINP = all_mINP + mINP 187 | all_cmc_pool = all_cmc_pool + cmc_pool 188 | all_mAP_pool = all_mAP_pool + mAP_pool 189 | all_mINP_pool = all_mINP_pool + mINP_pool 190 | 191 | print('Test Trial: {}'.format(trial)) 192 | print( 193 | 'FC: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 194 | cmc[0], cmc[4], cmc[9], cmc[19], mAP, mINP)) 195 | print( 196 | 'POOL: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 197 | cmc_pool[0], cmc_pool[4], cmc_pool[9], cmc_pool[19], mAP_pool, mINP_pool)) 198 | 199 | print( 200 | 'FC: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format( 201 | all_cmc[0], all_cmc[4], all_cmc[9], all_cmc[19], all_mAP, all_mINP)) 202 
| print(
203 |         'POOL: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format(
204 |             all_cmc_pool[0], all_cmc_pool[4], all_cmc_pool[9], all_cmc_pool[19], all_mAP_pool, all_mINP_pool))  # note: these values are accumulated over the 10 trials; divide by 10 for the means
205 |
206 |
207 |
208 | test(net)
209 |
-------------------------------------------------------------------------------- /train.py: --------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import argparse
3 | import sys
4 | import time
5 | import torch
6 | import torch.nn as nn
7 | import torch.optim as optim
8 | import torch.backends.cudnn as cudnn
9 | from torch.autograd import Variable
10 | import torch.utils.data as data
11 | import torchvision
12 | import torch.nn.functional as F
13 | import torchvision.transforms as transforms
14 | from data_loader import SYSUData, TestData
15 | from data_manager import *
16 | from eval_metrics import eval_sysu
17 | from model_bn import embed_net
18 | from utils import *
19 | import copy
20 | from loss import OriTripletLoss, TripletLoss_WRT, TripletLoss_ADP, sce, shape_cpmt_cross_modal_ce
21 | # from tensorboardX import SummaryWriter
22 | from ChannelAug import ChannelAdap, ChannelAdapGray, ChannelRandomErasing
23 | import pdb
24 | import wandb
25 |
26 | parser = argparse.ArgumentParser(description='PyTorch Cross-Modality Training')
27 | parser.add_argument('--dataset', default='sysu', help='dataset name: regdb or sysu')
28 | parser.add_argument('--lr', default=0.1, type=float, help='learning rate, 0.00035 for adam')
29 | parser.add_argument('--optim', default='sgd', type=str, help='optimizer')
30 | parser.add_argument('--arch', default='resnet50', type=str,
31 |                     help='network baseline: resnet18 or resnet50')
32 | parser.add_argument('--resume', '-r', default='', type=str,
33 |                     help='resume from checkpoint')
34 | parser.add_argument('--test-only', action='store_true', help='test only')
35 | parser.add_argument('--model_path', default='save_model/', type=str,
36 |                     help='model save path')
37 | parser.add_argument('--save_epoch', default=20, type=int,
38 |                     metavar='s', help='save the model every save_epoch epochs (default: 20)')
39 | parser.add_argument('--log_path', default='log/', type=str,
40 |                     help='log save path')
41 | parser.add_argument('--vis_log_path', default='log/vis_log/', type=str,
42 |                     help='log save path')
43 | parser.add_argument('--workers', default=8, type=int, metavar='N',
44 |                     help='number of data loading workers (default: 8)')
45 | parser.add_argument('--img_w', default=144, type=int,
46 |                     metavar='imgw', help='img width')
47 | parser.add_argument('--img_h', default=288, type=int,
48 |                     metavar='imgh', help='img height')
49 | parser.add_argument('--batch-size', default=8, type=int,
50 |                     metavar='B', help='training batch size')
51 | parser.add_argument('--test-batch', default=64, type=int,
52 |                     metavar='tb', help='testing batch size')
53 | parser.add_argument('--method', default='agw', type=str,
54 |                     metavar='m', help='method type: base or agw, adp')
55 | parser.add_argument('--margin', default=0.3, type=float,
56 |                     metavar='margin', help='triplet loss margin')
57 | parser.add_argument('--num_pos', default=4, type=int,
58 |                     help='num of pos per identity in each modality')
59 | parser.add_argument('--trial', default=1, type=int,
60 |                     metavar='t', help='trial (only for RegDB dataset)')
61 | parser.add_argument('--seed', default=3, type=int,
62 |                     metavar='t', help='random seed')
63 | parser.add_argument('--gpu', default='0', type=str,
64 |                     help='gpu device ids for
CUDA_VISIBLE_DEVICES') 65 | parser.add_argument('--mode', default='all', type=str, help='all or indoor') 66 | 67 | parser.add_argument('--date', default='12.22', help='date of exp') 68 | 69 | parser.add_argument('--gradclip', default= 11, type=float, 70 | metavar='gradclip', help='gradient clip') 71 | parser.add_argument('--gpuversion', default= '3090', type=str, help='3090 or 4090') 72 | path_dict = {} 73 | path_dict['3090'] = ['/home/share/reid_dataset/SYSU-MM01/', '/home/share/fengjw/SYSU_MM01_SHAPE/'] 74 | path_dict['4090'] = ['/home/jiawei/data/SYSU-MM01/', '/home/jiawei/data/SYSU_MM01_SHAPE/'] 75 | args = parser.parse_args() 76 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 77 | wandb.init(config=args, project='rgbir-reid2') 78 | args.method = args.method + "_gradclip" + str(args.gradclip) + "_seed" + str(args.seed) 79 | wandb.run.name = args.method 80 | # set_seed(args.seed) 81 | 82 | dataset = args.dataset 83 | if dataset == 'sysu': 84 | log_path = args.log_path + 'sysu_log/' 85 | test_mode = [1, 2] # thermal to visible 86 | 87 | 88 | checkpoint_path = args.model_path 89 | 90 | if not os.path.isdir(log_path): 91 | os.makedirs(log_path) 92 | if not os.path.isdir(checkpoint_path): 93 | os.makedirs(checkpoint_path) 94 | if not os.path.isdir(args.vis_log_path): 95 | os.makedirs(args.vis_log_path) 96 | 97 | suffix = dataset 98 | # if args.method == 'adp': 99 | # suffix = suffix + '_{}_joint_co_nog_ch_nog_sq{}'.format(args.method, args.square) 100 | # else: 101 | suffix = suffix + '_{}'.format(args.method) 102 | 103 | suffix = suffix + '_p{}_n{}_lr_{}'.format( args.num_pos, args.batch_size, args.lr) 104 | 105 | if not args.optim == 'sgd': 106 | suffix = suffix + '_' + args.optim 107 | 108 | 109 | sys.stdout = Logger(log_path + args.date + '/' + suffix + '_os.txt') 110 | 111 | vis_log_dir = args.vis_log_path + args.date + '/' + suffix + '/' 112 | 113 | if not os.path.isdir(vis_log_dir): 114 | os.makedirs(vis_log_dir) 115 | # writer = SummaryWriter(vis_log_dir) 116 | print("==========\nArgs:{}\n==========".format(args)) 117 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 118 | best_acc = 0 # best test accuracy 119 | best_acc_ema = 0 # best test accuracy 120 | start_epoch = 0 121 | 122 | print('==> Loading data..') 123 | # Data loading code 124 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 125 | transform_train_list = [ 126 | transforms.ToPILImage(), 127 | transforms.Pad(10), 128 | transforms.RandomCrop((args.img_h, args.img_w)), 129 | transforms.RandomHorizontalFlip(), 130 | transforms.ToTensor(), 131 | normalize] 132 | 133 | transform_test = transforms.Compose( [ 134 | transforms.ToPILImage(), 135 | transforms.Resize((args.img_h, args.img_w)), 136 | transforms.ToTensor(), 137 | normalize]) 138 | 139 | 140 | transform_train = transforms.Compose( transform_train_list ) 141 | 142 | end = time.time() 143 | if dataset == 'sysu': 144 | # training set 145 | trainset = SYSUData(data_dir=path_dict[args.gpuversion][0],data_dir1=path_dict[args.gpuversion][1]) 146 | # generate the idx of each person identity 147 | color_pos, thermal_pos = GenIdx(trainset.train_color_label, trainset.train_thermal_label) 148 | 149 | # testing set 150 | query_img, query_label, query_cam = process_query_sysu(mode=args.mode,data_path_ori=path_dict[args.gpuversion][0]) 151 | # gall_img, gall_label, gall_cam = process_gallery_sysu_all(mode=args.mode,data_path_ori=path_dict[args.gpuversion][0]) 152 | gall_img, gall_label, gall_cam = process_gallery_sysu(mode=args.mode, 
trial=0, data_path_ori=path_dict[args.gpuversion][0]) 153 | 154 | set_seed(args.seed) 155 | 156 | 157 | 158 | gallset = TestData(gall_img, gall_label, gall_cam, transform=transform_test, img_size=(args.img_w, args.img_h)) 159 | queryset = TestData(query_img, query_label, query_cam, transform=transform_test, img_size=(args.img_w, args.img_h)) 160 | 161 | # testing data loader 162 | gall_loader = data.DataLoader(gallset, batch_size=args.test_batch, shuffle=False, num_workers=args.workers) 163 | query_loader = data.DataLoader(queryset, batch_size=args.test_batch, shuffle=False, num_workers=args.workers) 164 | 165 | n_class = len(np.unique(trainset.train_color_label)) 166 | nquery = len(query_label) 167 | ngall = len(gall_label) 168 | 169 | print('Dataset {} statistics:'.format(dataset)) 170 | print(' ------------------------------') 171 | print(' subset | # ids | # images') 172 | print(' ------------------------------') 173 | print(' visible | {:5d} | {:8d}'.format(n_class, len(trainset.train_color_label))) 174 | print(' thermal | {:5d} | {:8d}'.format(n_class, len(trainset.train_thermal_label))) 175 | print(' ------------------------------') 176 | print(' query | {:5d} | {:8d}'.format(len(np.unique(query_label)), nquery)) 177 | print(' gallery | {:5d} | {:8d}'.format(len(np.unique(gall_label)), ngall)) 178 | print(' ------------------------------') 179 | print('Data Loading Time:\t {:.3f}'.format(time.time() - end)) 180 | 181 | print('==> Building model..', args.method) 182 | 183 | net = embed_net(n_class, no_local= 'off', gm_pool = 'on', arch=args.arch) 184 | net_ema = embed_net(n_class, no_local= 'off', gm_pool = 'on', arch=args.arch) 185 | print('use model without nonlocal but gmpool') 186 | 187 | net.to(device) 188 | net_ema.to(device) 189 | # cudnn.benchmark = True 190 | 191 | if len(args.resume) > 0: 192 | model_path = checkpoint_path + args.resume 193 | if os.path.isfile(model_path): 194 | print('==> loading checkpoint {}'.format(args.resume)) 195 | checkpoint = torch.load(model_path) 196 | start_epoch = checkpoint['epoch'] 197 | net.load_state_dict(checkpoint['net']) 198 | print('==> loaded checkpoint {} (epoch {})' 199 | .format(args.resume, checkpoint['epoch'])) 200 | else: 201 | print('==> no checkpoint found at {}'.format(args.resume)) 202 | 203 | # define loss function 204 | criterion_id = nn.CrossEntropyLoss() 205 | if 'agw' in args.method: 206 | criterion_tri = TripletLoss_WRT() 207 | else: 208 | loader_batch = args.batch_size * args.num_pos 209 | criterion_tri= OriTripletLoss(batch_size=loader_batch, margin=args.margin) 210 | criterion_id.to(device) 211 | criterion_tri.to(device) 212 | if args.optim == 'sgd': 213 | ignored_params = list(map(id, net.classifier.parameters())) 214 | ignored_params += list(map(id, net.bottleneck.parameters())) 215 | ignored_params += list(map(id, net.bottleneck_ir.parameters())) 216 | ignored_params += list(map(id, net.classifier_ir.parameters())) 217 | if hasattr(net,'classifier_shape'): 218 | ignored_params += list(map(id, net.classifier_shape.parameters())) 219 | ignored_params += list(map(id, net.bottleneck_shape.parameters())) 220 | 221 | ignored_params += list(map(id, net.projs.parameters())) 222 | print('#####larger lr for shape#####') 223 | 224 | base_params = filter(lambda p: id(p) not in ignored_params, net.parameters()) 225 | params = [{'params': base_params, 'lr': 0.1 * args.lr}, {'params': net.classifier.parameters(), 'lr': args.lr},] 226 | 227 | params.append({'params': net.bottleneck.parameters(), 'lr': args.lr}) 228 | 
    params.append({'params': net.bottleneck_shape.parameters(), 'lr': args.lr})
229 |     params.append({'params': net.classifier_ir.parameters(), 'lr': args.lr})
230 |     params.append({'params': net.classifier_shape.parameters(), 'lr': args.lr})
231 |     params.append({'params': net.projs.parameters(), 'lr': args.lr})
232 |     params.append({'params': net.bottleneck_ir.parameters(), 'lr': args.lr})
233 | 
234 |     # optimizer = optim.Adam(params, weight_decay=5e-4)
235 |     optimizer = optim.SGD(params, weight_decay=5e-4, momentum=0.9, nesterov=True)
236 | 
237 | 
238 | def adjust_learning_rate(optimizer, epoch):
239 |     """Warmup then step decay: linear warmup over epochs 0-9, base LR for 10-19, 0.1x for 20-84, 0.01x afterwards; also returns the EMA denominator ema_w."""
240 | 
241 |     ema_w = 1000
242 |     if epoch < 10:
243 |         lr = args.lr * (epoch + 1) / 10
244 |     elif epoch >= 10 and epoch < 20:
245 |         lr = args.lr
246 |     elif epoch >= 20 and epoch < 85:
247 |         lr = args.lr * 0.1
248 |         ema_w = 10000
249 |     else:  # epoch >= 85 (was `elif epoch < 120`, which left lr unbound for any epoch past 119)
250 |         lr = args.lr * 0.01
251 |         ema_w = 100000
252 |     optimizer.param_groups[0]['lr'] = 0.1 * lr  # group 0 is the backbone, trained at 0.1x
253 |     for i in range(len(optimizer.param_groups) - 1):
254 |         optimizer.param_groups[i + 1]['lr'] = lr
255 | 
256 |     return lr, ema_w
257 | 
258 | 
259 | def update_ema_variables(net, net_ema, alpha, global_step=None):
260 |     with torch.no_grad():
261 |         for ema_item, new_item in zip(net_ema.named_parameters(), net.named_parameters()):
262 |             ema_key, ema_param = ema_item
263 |             new_key, new_param = new_item
264 |             if 'classifier' in ema_key or 'bottleneck' in ema_key or 'projs' in ema_key:
265 |                 alpha_now = alpha * 2  # heads track the student twice as fast as the backbone
266 |             else:
267 |                 alpha_now = alpha
268 |             mygrad = new_param.data - ema_param.data
269 |             ema_param.data.add_(mygrad, alpha=alpha_now)  # ema += alpha_now * (student - ema)
270 | 
271 | def rand_bbox(size, lam):
272 |     W = size[2]  # note: size is NCHW, so this is actually the height axis; naming is legacy but used consistently below
273 |     H = size[3]
274 |     cut_rat = np.sqrt(1. - lam)
275 |     cut_w = int(W * cut_rat)
276 |     cut_h = int(H * cut_rat)
277 | 
278 |     # uniform
279 |     cx = np.random.randint(W)
280 |     cy = np.random.randint(H)
281 | 
282 |     bbx1 = np.clip(cx - cut_w // 2, 0, W)
283 |     bby1 = np.clip(cy - cut_h // 2, 0, H)
284 |     bbx2 = np.clip(cx + cut_w // 2, 0, W)
285 |     bby2 = np.clip(cy + cut_h // 2, 0, H)
286 | 
287 |     return bbx1, bby1, bbx2, bby2
288 | 
289 | def train(epoch):
290 | 
291 |     current_lr, ema_w = adjust_learning_rate(optimizer, epoch)
292 |     print('current lr', current_lr)
293 |     train_loss = AverageMeter()
294 |     id_loss = AverageMeter()
295 |     id_loss_shape = AverageMeter()
296 |     id_loss_shape2 = AverageMeter()
297 |     mutual_loss = AverageMeter()
298 |     mutual_loss2 = AverageMeter()
299 |     kl_loss = AverageMeter()
300 |     data_time = AverageMeter()
301 |     batch_time = AverageMeter()
302 |     correct = 0
303 |     total = 0
304 | 
305 |     # switch to train mode
306 |     net.train()
307 |     net_ema.train()
308 |     end = time.time()
309 |     for batch_idx, inputs in enumerate(trainloader):
310 |         x1, x1_shape, x2, x2_shape, y1, y2 = inputs
311 |         y = torch.cat((y1, y2), 0)
312 |         x1, x1_shape, x2, x2_shape, y1, y2, y = x1.cuda(), x1_shape.cuda(), x2.cuda(), x2_shape.cuda(), y1.cuda(), y2.cuda(), y.cuda()
313 | 
314 |         data_time.update(time.time() - end)
315 | 
316 |         cutmix_prob = np.random.rand()  # scalar in [0, 1); `rand(1)` returned a 1-element array
317 |         if cutmix_prob < 0.2:
318 |             # generate mixed sample
319 |             x = torch.cat((x1, x2), 0)
320 |             x_shape = torch.cat((x1_shape, x2_shape), 0)
321 |             lam = np.random.beta(1, 1)
322 |             bbx1, bby1, bbx2, bby2 = rand_bbox(x.size(), lam)
323 | 
324 |             rand_index = torch.randperm(y1.size()[0]).cuda()
325 |             target_a = y
326 |             target_b = torch.cat((y2[rand_index], y1[rand_index]), 0)
327 |             x[:, :, bbx1:bbx2, bby1:bby2] = torch.cat((x2[rand_index, :, bbx1:bbx2, bby1:bby2], x1[rand_index, :, bbx1:bbx2, bby1:bby2]), 0)
328 |             x_shape[:, :, bbx1:bbx2, bby1:bby2] = torch.cat((x2_shape[rand_index, :, bbx1:bbx2, bby1:bby2], x1_shape[rand_index, :, bbx1:bbx2, bby1:bby2]), 0)
329 | 
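            # --- added commentary (not part of the original training code) ---
            # Cross-modal CutMix: x stacks the RGB half (x1) on top of the IR
            # half (x2), while the pasted patches come from the *opposite*
            # modality, so each mixed image carries a patch from the other
            # spectrum. target_b therefore pairs every image with the identity
            # of its patch donor: torch.cat((y2[rand_index], y1[rand_index])).
            # Worked example of the lam correction below, assuming a 288x144
            # crop (a common SYSU-MM01 size; the real value comes from
            # --img_h/--img_w): a 100x50 box gives
            # lam = 1 - (100*50)/(288*144) ~= 0.879, i.e. ~88% of the pixels
            # still belong to the original image.
            # ------------------------------------------------------------------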
330 |             # adjust lambda to exactly match pixel ratio
331 |             lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (x.size()[-1] * x.size()[-2]))
332 |             # compute output
333 |             outputs = net(x[:y1.shape[0]], x[y1.shape[0]:], x_shape[:y1.shape[0]], x_shape[y1.shape[0]:])
334 |             with torch.no_grad():
335 |                 outputs_ema = net_ema(x[:y1.shape[0]], x[y1.shape[0]:], x_shape[:y1.shape[0]], x_shape[y1.shape[0]:])
336 | 
337 |             loss_id = criterion_id(outputs['rgbir']['logit'], target_a) * lam + criterion_id(outputs['rgbir']['logit'], target_b) * (1. - lam)
338 |             loss_id2 = torch.tensor([0.]).cuda()
339 |             loss_id_shape = criterion_id(outputs['shape']['logit'], target_a) * lam + criterion_id(outputs['shape']['logit'], target_b) * (1. - lam)
340 |             loss_tri = torch.tensor([0.]).cuda()
341 |             loss_kl_rgbir = sce(outputs['rgbir']['logit'][:x1.shape[0]], outputs['rgbir']['logit'][x1.shape[0]:]) + sce(outputs['rgbir']['logit'][x1.shape[0]:], outputs['rgbir']['logit'][:x1.shape[0]])
342 |             w1 = torch.tensor([1.]).cuda()
343 |             loss_estimate = torch.tensor([0.]).cuda()
344 |             w2 = torch.tensor([1.]).cuda()
345 |             loss_kl_rgbir2 = torch.tensor([0.]).cuda()
346 | 
347 |         else:
348 |             with torch.no_grad():
349 |                 outputs_ema = net_ema(x1, x2, x1_shape, x2_shape)
350 |             outputs = net(x1, x2, x1_shape, x2_shape)
351 | 
352 |             # id loss
353 |             # if epoch < 40:
354 |             loss_id = criterion_id(outputs['rgbir']['logit'], y)
355 |             loss_id2 = criterion_id(outputs['rgbir']['logit2'], y)
356 |             loss_id_shape = criterion_id(outputs['shape']['logit'], y)
357 | 
358 |             # triplet loss
359 |             loss_tri, batch_acc = criterion_tri(outputs['rgbir']['bef'], y)
360 | 
361 |             # cross modal distill
362 |             loss_kl_rgbir = sce(outputs['rgbir']['logit'][:x1.shape[0]], outputs['rgbir']['logit'][x1.shape[0]:]) + sce(outputs['rgbir']['logit'][x1.shape[0]:], outputs['rgbir']['logit'][:x1.shape[0]])
363 |             # shape complementary
364 |             loss_kl_rgbir2 = shape_cpmt_cross_modal_ce(x1, y1, outputs)
365 | 
366 |             # shape consistent
367 |             loss_estimate = ((outputs['rgbir']['zp'] - outputs_ema['shape']['zp'].detach()) ** 2).mean(1).mean() + sce(torch.mm(outputs['rgbir']['zp'], net.classifier_shape.weight.data.detach().t()), outputs_ema['shape']['logit'])
368 | 
369 | 
370 |             ############## reweighting ###############
371 |             complement_grad = torch.autograd.grad(loss_id2 + loss_kl_rgbir2, outputs['rgbir']['bef'], retain_graph=True)[0]  # renamed from the original's misspelled `compliment_grad`
372 |             consistent_grad = torch.autograd.grad(loss_estimate, outputs['rgbir']['bef'], retain_graph=True)[0]
373 | 
374 |             with torch.no_grad():
375 |                 complement_grad_norm = (complement_grad.norm(p=2, dim=-1)).mean()
376 |                 consistent_grad_norm = (consistent_grad.norm(p=2, dim=-1)).mean()
377 |                 w1 = consistent_grad_norm / (complement_grad_norm + consistent_grad_norm) * 2
378 |                 w2 = complement_grad_norm / (complement_grad_norm + consistent_grad_norm) * 2
379 | 
380 |         ############## orthogonalize loss ###############
381 |         proj_inner = torch.mm(F.normalize(net.projs[0], 2, 0).t(), F.normalize(net.projs[0], 2, 0))
382 |         eye_label = torch.eye(net.projs[0].shape[1], device=device)
383 |         loss_ortho = (proj_inner - eye_label).abs().sum(1).mean()
384 | 
385 | 
386 | 
387 |         loss = loss_id + loss_tri + loss_id_shape + loss_kl_rgbir + w1 * loss_estimate + w2 * loss_id2 + w2 * loss_kl_rgbir2 + loss_ortho
388 | 
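        # --- added commentary (not part of the original training code) -------
        # Gradient-norm reweighting: both auxiliary objectives are
        # differentiated w.r.t. the same shared feature
        # outputs['rgbir']['bef']; w1 and w2 always sum to 2 and split the
        # emphasis between the shape-consistency term (loss_estimate) and the
        # shape-complementary terms (loss_id2, loss_kl_rgbir2) according to
        # their relative gradient norms on that feature. The orthogonality
        # term then pushes the columns of net.projs[0] toward an orthonormal
        # set by penalizing |P^T P - I| after column-normalizing P.
        # ----------------------------------------------------------------------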
389 |         if not check_loss(loss):  # guard against NaN / negative / exploding loss
390 |             import pdb
391 |             pdb.set_trace()
392 |         optimizer.zero_grad()
393 |         loss.backward()
394 |         torch.nn.utils.clip_grad_norm_(net.parameters(), args.gradclip)
395 |         optimizer.step()
396 | 
397 | 
398 |         update_ema_variables(net, net_ema, 1 / ema_w)
399 | 
400 | 
401 |         # update P
402 |         train_loss.update(loss_id2.item(), 2 * x1.size(0))
403 |         id_loss.update(loss_id.item(), 2 * x1.size(0))
404 |         id_loss_shape.update(loss_id_shape.item(), 2 * x1.size(0))
405 |         id_loss_shape2.update(w1.item(), 2 * x1.size(0))
406 |         mutual_loss2.update(loss_kl_rgbir2.item(), 2 * x1.size(0))
407 |         mutual_loss.update(loss_ortho.item(), 2 * x1.size(0))
408 |         # kl_loss.update(loss_kl2.item()+loss_kl.item() , 2 * x1.size(0))
409 |         kl_loss.update(loss_estimate.item(), 2 * x1.size(0))
410 |         total += y.size(0)
411 | 
412 |         # measure elapsed time
413 |         batch_time.update(time.time() - end)
414 |         end = time.time()
415 |         if batch_idx % 50 == 0:
416 |             # import pdb
417 |             # pdb.set_trace()
418 |             print('Epoch:[{}][{}/{}]'
419 |                   'L:{id_loss.val:.4f}({id_loss.avg:.4f}) '
420 |                   'L2:{train_loss.val:.4f}({train_loss.avg:.4f}) '
421 |                   'sL:{id_loss_shape.val:.4f}({id_loss_shape.avg:.4f}) '
422 |                   'w1:{id_loss_shape2.val:.4f}({id_loss_shape2.avg:.4f}) '
423 |                   'or:{mutual_loss.val:.4f}({mutual_loss.avg:.4f}) '
424 |                   'ML2:{mutual_loss2.val:.4f}({mutual_loss2.avg:.4f}) '
425 |                   'KL:{kl_loss.val:.4f}({kl_loss.avg:.4f}) '.format(
426 |                 epoch, batch_idx, len(trainloader),
427 |                 train_loss=train_loss, id_loss=id_loss, id_loss_shape=id_loss_shape, id_loss_shape2=id_loss_shape2, mutual_loss=mutual_loss, mutual_loss2=mutual_loss2, kl_loss=kl_loss))
428 | 
429 | 
430 | def test(net):
431 |     pool_dim = 2048
432 |     def fliplr(img):
433 |         '''flip horizontal'''
434 |         inv_idx = torch.arange(img.size(3) - 1, -1, -1, device=img.device).long()  # N x C x H x W
435 |         img_flip = img.index_select(3, inv_idx)
436 |         return img_flip
437 |     def extract_gall_feat(gall_loader):
438 |         net.eval()
439 |         print('Extracting Gallery Feature...')
440 |         start = time.time()
441 |         ptr = 0
442 |         ngall = len(gall_loader.dataset)
443 |         gall_feat_pool = np.zeros((ngall, pool_dim))
444 |         gall_feat_fc = np.zeros((ngall, pool_dim))
445 |         with torch.no_grad():
446 |             for batch_idx, (input, label, cam) in enumerate(gall_loader):
447 |                 batch_num = input.size(0)
448 |                 input = input.cuda()  # the deprecated Variable() wrapper was dropped; tensors carry autograd natively
449 |                 feat_pool, feat_fc = net(input, input, mode=test_mode[0])
450 |                 input2 = fliplr(input)
451 |                 feat_pool2, feat_fc2 = net(input2, input2, mode=test_mode[0])
452 |                 feat_pool = (feat_pool + feat_pool2) / 2
453 |                 feat_fc = (feat_fc + feat_fc2) / 2
454 | 
455 |                 gall_feat_pool[ptr:ptr + batch_num, :] = feat_pool.detach().cpu().numpy()
456 |                 gall_feat_fc[ptr:ptr + batch_num, :] = feat_fc.detach().cpu().numpy()
457 |                 ptr = ptr + batch_num
458 |         print('Extracting Time:\t {:.3f}'.format(time.time() - start))
459 |         return gall_feat_pool, gall_feat_fc
460 | 
461 |     def extract_query_feat(query_loader):
462 |         net.eval()
463 |         print('Extracting Query Feature...')
464 |         start = time.time()
465 |         ptr = 0
466 |         nquery = len(query_loader.dataset)
467 |         query_feat_pool = np.zeros((nquery, pool_dim))
468 |         query_feat_fc = np.zeros((nquery, pool_dim))
469 |         with torch.no_grad():
470 |             for batch_idx, (input, label, cam) in enumerate(query_loader):
471 |                 batch_num = input.size(0)
472 |                 input = input.cuda()  # Variable() wrapper dropped here as well
473 |                 feat_pool, feat_fc = net(input, input, mode=test_mode[1])
474 |                 input2 = fliplr(input)
475 |                 feat_pool2, feat_fc2 = net(input2, input2, mode=test_mode[1])
476 |                 feat_pool = (feat_pool + feat_pool2) / 2
477 |                 feat_fc = (feat_fc + feat_fc2) / 2
478 |                 query_feat_pool[ptr:ptr + batch_num, :] = feat_pool.detach().cpu().numpy()
479 |                 query_feat_fc[ptr:ptr + batch_num, :] = feat_fc.detach().cpu().numpy()
480 |                 ptr = ptr + batch_num
481 |         print('Extracting Time:\t {:.3f}'.format(time.time() - start))
482 |         return query_feat_pool, query_feat_fc
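    # --- added commentary (not part of the original test code) ---------------
    # Both extractors apply simple flip test-time augmentation: each image is
    # embedded once as-is and once horizontally flipped (fliplr reverses dim 3,
    # the width axis of the N x C x H x W tensor), and the two embeddings are
    # averaged. pool_dim = 2048 matches the 2048-d output of the ResNet-style
    # backbone used here.
    # --------------------------------------------------------------------------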
483 |     # switch to evaluation mode
484 |     net.eval()
485 |     query_feat_pool, query_feat_fc = extract_query_feat(query_loader)
486 | 
487 |     # gall_img, gall_label, gall_cam = process_gallery_sysu(mode=args.mode, trial=0)
488 | 
489 |     trial_gallset = TestData(gall_img, gall_label, gall_cam, transform=transform_test, img_size=(args.img_w, args.img_h))
490 |     trial_gall_loader = data.DataLoader(trial_gallset, batch_size=args.test_batch, shuffle=False, num_workers=4)
491 | 
492 |     gall_feat_pool, gall_feat_fc = extract_gall_feat(trial_gall_loader)
493 | 
494 |     # pool5 feature
495 |     distmat_pool = np.matmul(query_feat_pool, np.transpose(gall_feat_pool))  # inner-product similarity
496 |     cmc_pool, mAP_pool, mINP_pool = eval_sysu(-distmat_pool, query_label, gall_label, query_cam, gall_cam)  # negated so eval ranks it like a distance
497 | 
498 |     # fc feature
499 |     distmat = np.matmul(query_feat_fc, np.transpose(gall_feat_fc))
500 |     cmc, mAP, mINP = eval_sysu(-distmat, query_label, gall_label, query_cam, gall_cam)
501 |     all_cmc = cmc
502 |     all_mAP = mAP
503 |     all_mINP = mINP
504 |     all_cmc_pool = cmc_pool
505 |     all_mAP_pool = mAP_pool
506 |     all_mINP_pool = mINP_pool
507 |     return all_cmc, all_mAP
508 | 
509 | 
510 | 
511 | def seed_worker(worker_id):
512 |     worker_seed = torch.initial_seed() % 2**32
513 |     np.random.seed(worker_seed)
514 |     random.seed(worker_seed)
515 | 
516 | # training
517 | print('==> Start Training...')
518 | for epoch in range(start_epoch, 120):  # was `range(start_epoch, 120 - start_epoch)`; identical when start_epoch == 0, but it cut training short after a resume
519 | 
520 |     print('==> Preparing Data Loader...')
521 |     # identity sampler
522 |     sampler = IdentitySampler(trainset.train_color_label, \
523 |                               trainset.train_thermal_label, color_pos, thermal_pos, args.num_pos, args.batch_size,
524 |                               epoch)
525 | 
526 |     trainset.cIndex = sampler.index1  # color index
527 |     trainset.tIndex = sampler.index2  # thermal index
528 |     print(epoch)
529 |     # print(trainset.cIndex)
530 |     # print(trainset.tIndex)
531 | 
532 |     loader_batch = args.batch_size * args.num_pos
533 | 
534 |     trainloader = data.DataLoader(trainset, batch_size=loader_batch, \
535 |                                   sampler=sampler, num_workers=args.workers, drop_last=True)
536 | 
537 |     # training
538 | 
539 |     if epoch == 0:
540 |         net_ema.load_state_dict(net.state_dict())
541 |         print('init ema model')
542 | 
543 | 
544 |     train(epoch)
545 | 
546 |     print('Test Epoch: {}'.format(epoch))
547 | 
548 |     # testing
549 |     cmc, mAP = test(net)
550 |     wandb.log({'rank1': cmc[0],
551 |                'mAP': mAP,
552 |                }, step=epoch)
553 |     cmc_ema, mAP_ema = test(net_ema)
554 |     wandb.log({'rank1_ema': cmc_ema[0],
555 |                'mAP_ema': mAP_ema,
556 |                }, step=epoch)
557 |     # save model
558 |     if cmc[0] > best_acc:
559 |         best_acc = cmc[0]
560 |         best_epoch = epoch
561 |         state = {
562 |             'net': net.state_dict(),
563 |             'cmc': cmc,
564 |             'mAP': mAP,
565 |             'epoch': epoch,
566 |         }
567 |         torch.save(state, checkpoint_path + suffix + '_best.t')
568 |     if cmc_ema[0] > best_acc_ema:
569 |         best_acc_ema = cmc_ema[0]
570 |         best_epoch_ema = epoch
571 |         state = {
572 |             'net': net_ema.state_dict(),
573 |             'cmc': cmc_ema,
574 |             'mAP': mAP_ema,
575 |             'epoch': epoch,
576 |         }
577 |         torch.save(state, checkpoint_path + suffix + '_ema_best.t')
578 |     if epoch % 5 == 0:
579 |         state = {
580 |             'net': net_ema.state_dict(),
581 |         }
582 |         torch.save(state, checkpoint_path + suffix + '_' + str(epoch) + '_.t')
583 | 
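    # --- added commentary (not part of the original loop) ---------------------
    # Checkpointing policy: the raw student net and its EMA teacher are tracked
    # separately, each saved whenever its own rank-1 improves ('_best.t' and
    # '_ema_best.t'); the EMA weights are additionally snapshotted every 5
    # epochs under '<suffix>_<epoch>_.t'.
    # ---------------------------------------------------------------------------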
584 |     print('Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%} | Rank-20: {:.2%} | mAP: {:.2%}'.format(
585 |         cmc[0], cmc[4], cmc[9], cmc[19], mAP))
586 |     print('Best Epoch [{}]'.format(best_epoch))
587 | 
588 |     print('------------------ema eval------------------')
589 |     print('Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%} | Rank-20: {:.2%} | mAP: {:.2%}'.format(
590 |         cmc_ema[0], cmc_ema[4], cmc_ema[9], cmc_ema[19], mAP_ema))
591 |     print('Best Epoch [{}]'.format(best_epoch_ema))
592 | 
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from torch.utils.data.sampler import Sampler
4 | import sys
5 | import os.path as osp
6 | import torch
7 | import random
8 | def load_data(input_data_path):
9 |     with open(input_data_path, 'rt') as f:  # the original opened the file twice and never used `f`
10 |         data_file_list = f.read().splitlines()
11 |     # Get full list of color image and labels
12 |     file_image = [s.split(' ')[0] for s in data_file_list]
13 |     file_label = [int(s.split(' ')[1]) for s in data_file_list]
14 | 
15 |     return file_image, file_label
16 | 
17 | 
18 | def GenIdx(train_color_label, train_thermal_label):
19 |     color_pos = []
20 |     unique_label_color = np.unique(train_color_label)
21 |     for i in range(len(unique_label_color)):
22 |         tmp_pos = [k for k, v in enumerate(train_color_label) if v == unique_label_color[i]]
23 |         color_pos.append(tmp_pos)
24 | 
25 |     thermal_pos = []
26 |     unique_label_thermal = np.unique(train_thermal_label)
27 |     for i in range(len(unique_label_thermal)):
28 |         tmp_pos = [k for k, v in enumerate(train_thermal_label) if v == unique_label_thermal[i]]
29 |         thermal_pos.append(tmp_pos)
30 |     return color_pos, thermal_pos
31 | 
32 | def GenCamIdx(gall_img, gall_label, mode):
33 |     if mode == 'indoor':
34 |         camIdx = [1, 2]
35 |     else:
36 |         camIdx = [1, 2, 4, 5]
37 |     gall_cam = []
38 |     for i in range(len(gall_img)):
39 |         gall_cam.append(int(gall_img[i][-10]))  # camera id sits at a fixed offset in the SYSU file path
40 | 
41 |     sample_pos = []
42 |     unique_label = np.unique(gall_label)
43 |     for i in range(len(unique_label)):
44 |         for j in range(len(camIdx)):
45 |             id_pos = [k for k, v in enumerate(gall_label) if v == unique_label[i] and gall_cam[k] == camIdx[j]]
46 |             if id_pos:
47 |                 sample_pos.append(id_pos)
48 |     return sample_pos
49 | 
50 | def ExtractCam(gall_img):
51 |     gall_cam = []
52 |     for i in range(len(gall_img)):
53 |         cam_id = int(gall_img[i][-10])
54 |         # if cam_id ==3:
55 |         #     cam_id = 2
56 |         gall_cam.append(cam_id)
57 | 
58 |     return np.array(gall_cam)
59 | 
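# --- added commentary (not part of the original utils) ------------------------
# Worked example: with train_color_label == [5, 5, 7, 7, 7], GenIdx returns
# color_pos == [[0, 1], [2, 3, 4]] -- one list of image indices per unique
# identity, in np.unique() order. IdentitySampler below indexes these lists by
# label value, which assumes labels have already been relabelled to
# 0..n_classes-1 (as the SYSU preprocessing pipeline is expected to do).
# -------------------------------------------------------------------------------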
60 | class IdentitySampler(Sampler):
61 |     """Sample person identities evenly in each batch.
62 |     Args:
63 |         train_color_label, train_thermal_label: labels of two modalities
64 |         color_pos, thermal_pos: positions of each identity
65 |         batchSize: batch size (number of identities per batch)
66 |     """
67 | 
68 |     def __init__(self, train_color_label, train_thermal_label, color_pos, thermal_pos, num_pos, batchSize, epoch):
69 |         uni_label = np.unique(train_color_label)
70 |         self.n_classes = len(uni_label)
71 | 
72 | 
73 |         N = np.maximum(len(train_color_label), len(train_thermal_label))
74 |         for j in range(int(N / (batchSize * num_pos)) + 1):
75 |             batch_idx = np.random.choice(uni_label, batchSize, replace=False)  # identities for this batch; label values double as indices into color_pos/thermal_pos
76 |             for i in range(batchSize):
77 |                 sample_color = np.random.choice(color_pos[batch_idx[i]], num_pos)
78 |                 sample_thermal = np.random.choice(thermal_pos[batch_idx[i]], num_pos)
79 | 
80 |                 if j == 0 and i == 0:
81 |                     index1 = sample_color
82 |                     index2 = sample_thermal
83 |                 else:
84 |                     index1 = np.hstack((index1, sample_color))
85 |                     index2 = np.hstack((index2, sample_thermal))
86 | 
87 |         self.index1 = index1
88 |         self.index2 = index2
89 |         self.N = N
90 | 
91 |     def __iter__(self):
92 |         return iter(np.arange(len(self.index1)))
93 | 
94 |     def __len__(self):
95 |         return self.N
96 | 
97 | class AverageMeter(object):
98 |     """Computes and stores the average and current value"""
99 |     def __init__(self):
100 |         self.reset()
101 | 
102 |     def reset(self):
103 |         self.val = 0
104 |         self.avg = 0
105 |         self.sum = 0
106 |         self.count = 0
107 | 
108 |     def update(self, val, n=1):
109 |         self.val = val
110 |         self.sum += val * n
111 |         self.count += n
112 |         self.avg = self.sum / self.count
113 | 
114 | def mkdir_if_missing(directory):
115 |     import errno  # local import: errno was used below but never imported at module level
116 |     try:
117 |         os.makedirs(directory)
118 |     except OSError as e:
119 |         if e.errno != errno.EEXIST:
120 |             raise
121 | class Logger(object):
122 |     """
123 |     Write console output to external text file.
124 |     Code imported from https://github.com/Cysu/open-reid/blob/master/reid/utils/logging.py.
125 |     """
126 |     def __init__(self, fpath=None):
127 |         self.console = sys.stdout
128 |         self.file = None
129 |         if fpath is not None:
130 |             mkdir_if_missing(osp.dirname(fpath))
131 |             self.file = open(fpath, 'w')
132 | 
133 |     def __del__(self):
134 |         self.close()
135 | 
136 |     def __enter__(self):
137 |         return self  # was `pass`, which made `with Logger(...) as f:` bind None
138 | 
139 |     def __exit__(self, *args):
140 |         self.close()
141 | 
142 |     def write(self, msg):
143 |         self.console.write(msg)
144 |         if self.file is not None:
145 |             self.file.write(msg)
146 | 
147 |     def flush(self):
148 |         self.console.flush()
149 |         if self.file is not None:
150 |             self.file.flush()
151 |             os.fsync(self.file.fileno())
152 | 
153 |     def close(self):
154 |         self.console.close()
155 |         if self.file is not None:
156 |             self.file.close()
157 | 
158 | def set_seed(seed, cuda=True):
159 |     random.seed(seed)
160 |     np.random.seed(seed)
161 |     torch.manual_seed(seed)
162 |     torch.cuda.manual_seed_all(seed)
163 |     torch.cuda.manual_seed(seed)
164 |     torch.backends.cudnn.deterministic = True  # This will slow down training.
165 |     torch.backends.cudnn.benchmark = False
166 | 
167 | def check_loss(loss):
168 |     return not bool(torch.isnan(loss).item()) and bool((loss >= 0.0).item()) and bool((loss < 1e6).item())
169 | 
170 | def set_requires_grad(nets, requires_grad=False):
171 |     """Set requires_grad=False for all the networks to avoid unnecessary computations
172 |     Parameters:
173 |         nets (network list)   -- a list of networks
174 |         requires_grad (bool)  -- whether the networks require gradients or not
175 |     """
176 |     if not isinstance(nets, list):
177 |         nets = [nets]
178 |     for net in nets:
179 |         if net is not None:
180 |             for param in net.parameters():
181 |                 param.requires_grad = requires_grad
--------------------------------------------------------------------------------