├── LICENSE ├── README.md ├── attention.py ├── data_loader.py ├── data_manager.py ├── eval_metrics.py ├── loss.py ├── model_main.py ├── resnet.py ├── test_ddag.py ├── train_ddag.py └── utils.py /LICENSE: --------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2020 mangye16
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | # DDAG
2 | PyTorch code of DDAG for Visible-Infrared Person Re-Identification (ECCV 2020). [PDF](https://arxiv.org/pdf/2007.09314.pdf)
3 | 
4 | A Huawei MindSpore implementation of our DDAG method is [HERE](https://gitee.com/mindspore/models/tree/master/research/cv/DDAG). Thanks to Zhiwei Zhang (zhangzw12319@163.com).
5 | 
6 | ## Highlight
7 | 
8 | The goal of this work is to learn a robust and discriminative cross-modality representation for visible-infrared person re-identification.
9 | 
10 | - Intra-modality Weighted-Part Aggregation (IWPA): It learns discriminative part-aggregated features by mining the contextual part relations (see the sketch below).
11 | 
12 | - Cross-modality Graph Structured Attention (CGSA): It enhances the features by incorporating the neighborhood information across the two modalities.
13 | 
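A minimal sketch of the weighted-part idea behind IWPA (illustrative only; the full `IWPA` module in `attention.py` additionally applies cross-part attention and a BN bottleneck, and all shapes below are assumptions):

```python
import torch
import torch.nn.functional as F

# Pool p horizontal part features from a backbone map and fuse them
# with a learnable, softmax-normalized gate (toy version of IWPA).
B, C, H, W, p = 4, 2048, 18, 9, 3
x = torch.randn(B, C, H, W)                                  # backbone feature map
part_feat = F.adaptive_avg_pool2d(x, (p, 1)).squeeze(-1)     # B x C x p
gate = torch.nn.Parameter(torch.full((p,), 1.0 / p))         # learnable part importance
weighted = torch.matmul(part_feat, F.softmax(gate, dim=-1))  # B x C aggregation
print(weighted.shape)  # torch.Size([4, 2048])
```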
14 | ### Results on the SYSU-MM01 Dataset
15 | | Method | Datasets | Rank@1 | mAP | mINP |
16 | |------|--------|-----|-----|-----|
17 | | AGW [[1](https://github.com/mangye16/Cross-Modal-Re-ID-baseline)] | SYSU-MM01 (All-Search) | ~ 47.50% | ~ 47.65% | ~ 35.30% |
18 | | DDAG | SYSU-MM01 (All-Search) | ~ 54.75% | ~ 53.02% | ~ 39.62% |
19 | | AGW [[1](https://github.com/mangye16/Cross-Modal-Re-ID-baseline)] | SYSU-MM01 (Indoor-Search) | ~ 54.17% | ~ 62.97% | ~ 59.23% |
20 | | DDAG | SYSU-MM01 (Indoor-Search) | ~ 61.02% | ~ 67.98% | ~ 62.61% |
21 | 
22 | *The code has been tested with Python 3.7 and PyTorch 1.0. Results on both datasets may fluctuate slightly due to random splitting.*
23 | 
24 | ### 1. Prepare the datasets.
25 | 
26 | - (1) RegDB Dataset [1]: The RegDB dataset can be downloaded from this [website](http://dm.dongguk.edu/link.html) by submitting a copyright form.
27 | 
28 | - (It is named "Dongguk Body-based Person Recognition Database (DBPerson-Recog-DB1)" on their website.)
29 | 
30 | - A private download link can be requested by sending me an email (mangye16@gmail.com).
31 | 
32 | - (2) SYSU-MM01 Dataset [2]: The SYSU-MM01 dataset can be downloaded from this [website](http://isee.sysu.edu.cn/project/RGBIRReID.htm).
33 | 
34 | - Run `python pre_process_sysu.py` ([link](https://github.com/mangye16/Cross-Modal-Re-ID-baseline/blob/master/pre_process_sysu.py)) to prepare the dataset; the training data will be stored in ".npy" format.
35 | 
36 | ### 2. Training.
37 | Train a model by
38 | ```bash
39 | python train_ddag.py --dataset sysu --lr 0.1 --graph --wpa --part 3 --gpu 0
40 | ```
41 | 
42 | - `--dataset`: which dataset, "sysu" or "regdb".
43 | 
44 | - `--lr`: initial learning rate.
45 | 
46 | - `--graph`: use graph attention.
47 | 
48 | - `--wpa`: use weighted part attention.
49 | 
50 | - `--part`: number of parts.
51 | 
52 | - `--gpu`: which gpu to run on.
53 | 
54 | You may need to manually define the data path first.
55 | 
56 | 
57 | ### 3. Testing.
58 | 
59 | Test a model on the SYSU-MM01 or RegDB dataset by
60 | ```bash
61 | python test_ddag.py --dataset sysu --mode all --wpa --graph --gpu 1 --resume 'model_path'
62 | ```
63 | - `--dataset`: which dataset, "sysu" or "regdb".
64 | 
65 | - `--mode`: "all" or "indoor", i.e., all search or indoor search (only for the sysu dataset).
66 | 
67 | - `--trial`: testing trial (only for the RegDB dataset).
68 | 
69 | - `--resume`: the saved model path. **Important**
70 | 
71 | - `--gpu`: which gpu to run on.
72 | 
73 | ### 4. Citation
74 | 
75 | Please kindly cite the following papers in your publications if this work helps your research:
76 | ```
77 | @inproceedings{eccv20ddag,
78 |   title={Dynamic Dual-Attentive Aggregation Learning for Visible-Infrared Person Re-Identification},
79 |   author={Ye, Mang and Shen, Jianbing and Crandall, David J. and Shao, Ling and Luo, Jiebo},
80 |   booktitle={European Conference on Computer Vision (ECCV)},
81 |   year={2020},
82 | }
83 | ```
84 | 
85 | ```
86 | @article{arxiv20reidsurvey,
87 |   title={Deep Learning for Person Re-identification: A Survey and Outlook},
88 |   author={Ye, Mang and Shen, Jianbing and Lin, Gaojie and Xiang, Tao and Shao, Ling and Hoi, Steven C. H.},
89 |   journal={arXiv preprint arXiv:2001.04193},
90 |   year={2020},
91 | }
92 | ```
93 | 
94 | Contact: mangye16@gmail.com
95 | 
-------------------------------------------------------------------------------- /attention.py: --------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | """
6 | Part of the code is from the following link:
7 | https://github.com/Diego999/pyGAT/blob/master/layers.py
8 | """
9 | 
10 | 
11 | class Normalize(nn.Module):
12 |     def __init__(self, power=2):
13 |         super(Normalize, self).__init__()
14 |         self.power = power
15 | 
16 |     def forward(self, x):
17 |         norm = x.pow(self.power).sum(1, keepdim=True).pow(1. / self.power)
18 |         out = x.div(norm)
19 |         return out
20 | 
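`Normalize` performs Lp normalization along dimension 1; for the default `power=2` it matches `F.normalize`, as this small illustrative check (not part of the repo) shows:

```python
import torch
import torch.nn.functional as F

x = torch.randn(8, 2048)
out = x.div(x.pow(2).sum(1, keepdim=True).pow(0.5))  # what Normalize(power=2) computes
assert torch.allclose(out, F.normalize(x, p=2, dim=1), atol=1e-6)
```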
21 | class GraphAttentionLayer(nn.Module):
22 |     """
23 |     Simple GAT layer, similar to https://arxiv.org/abs/1710.10903
24 |     """
25 | 
26 |     def __init__(self, in_features, out_features, dropout, alpha=0.2, concat=True):
27 |         super(GraphAttentionLayer, self).__init__()
28 |         self.dropout = dropout
29 |         self.in_features = in_features
30 |         self.out_features = out_features
31 |         self.alpha = alpha
32 |         self.concat = concat
33 | 
34 |         self.W = nn.Parameter(torch.zeros(size=(in_features, out_features)))
35 |         nn.init.xavier_uniform_(self.W.data, gain=1.414)
36 |         self.a = nn.Parameter(torch.zeros(size=(2 * out_features, 1)))
37 |         nn.init.xavier_uniform_(self.a.data, gain=1.414)
38 | 
39 |         self.leakyrelu = nn.LeakyReLU(self.alpha)
40 | 
41 |     def forward(self, input, adj):
42 |         h = torch.mm(input, self.W)
43 |         N = h.size()[0]
44 | 
45 |         a_input = torch.cat([h.repeat(1, N).view(N * N, -1), h.repeat(N, 1)], dim=1).view(N, -1, 2 * self.out_features)
46 |         e = self.leakyrelu(torch.matmul(a_input, self.a).squeeze(2))
47 | 
48 |         zero_vec = -9e15 * torch.ones_like(e)
49 |         attention = torch.where(adj > 0, e, zero_vec)
50 |         attention = F.softmax(attention, dim=1)
51 |         attention = F.dropout(attention, self.dropout, training=self.training)
52 |         h_prime = torch.matmul(attention, h)
53 | 
54 |         if self.concat:
55 |             return F.elu(h_prime)
56 |         else:
57 |             return h_prime
58 | 
59 |     def __repr__(self):
60 |         return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')'
61 | 
62 | 
63 | class SpecialSpmmFunction(torch.autograd.Function):
64 |     """Special function for sparse-region backpropagation only."""
65 | 
66 |     @staticmethod
67 |     def forward(ctx, indices, values, shape, b):
68 |         assert indices.requires_grad == False
69 |         a = torch.sparse_coo_tensor(indices, values, shape)
70 |         ctx.save_for_backward(a, b)
71 |         ctx.N = shape[0]
72 |         return torch.matmul(a, b)
73 | 
74 |     @staticmethod
75 |     def backward(ctx, grad_output):
76 |         a, b = ctx.saved_tensors
77 |         grad_values = grad_b = None
78 |         if ctx.needs_input_grad[1]:
79 |             grad_a_dense = grad_output.matmul(b.t())
80 |             edge_idx = a._indices()[0, :] * ctx.N + a._indices()[1, :]
81 |             grad_values = grad_a_dense.view(-1)[edge_idx]
82 |         if ctx.needs_input_grad[3]:
83 |             grad_b = a.t().matmul(grad_output)
84 |         return None, grad_values, None, grad_b
85 | 
86 | 
87 | class SpecialSpmm(nn.Module):
88 |     def forward(self, indices, values, shape, b):
89 |         return SpecialSpmmFunction.apply(indices, values, shape, b)
90 | 
91 | 
92 | class SpGraphAttentionLayer(nn.Module):
93 |     """
94 |     Sparse version GAT layer, similar to https://arxiv.org/abs/1710.10903
95 |     """
96 | 
97 |     def __init__(self, in_features, out_features, dropout, alpha, concat=True):
98 |         super(SpGraphAttentionLayer, self).__init__()
99 |         self.in_features = in_features
100 |         self.out_features = out_features
101 |         self.alpha = alpha
102 |         self.concat = concat
103 | 
104 |         self.W = nn.Parameter(torch.zeros(size=(in_features, out_features)))
105 |         nn.init.xavier_normal_(self.W.data, gain=1.414)
106 | 
107 |         self.a = nn.Parameter(torch.zeros(size=(1, 2 * out_features)))
108 |         nn.init.xavier_normal_(self.a.data, gain=1.414)
109 | 
110 |         self.dropout = nn.Dropout(dropout)
111 |         self.leakyrelu = nn.LeakyReLU(self.alpha)
112 |         self.special_spmm = SpecialSpmm()
113 | 
114 |     def forward(self, input, adj):
115 |         dv = 'cuda' if input.is_cuda else 'cpu'
116 | 
117 |         N = input.size()[0]
118 |         edge = adj.nonzero().t()
119 | 
120 |         h = torch.mm(input, self.W)
121 | 
# h: N x out 122 | assert not torch.isnan(h).any() 123 | 124 | # Self-attention on the nodes - Shared attention mechanism 125 | edge_h = torch.cat((h[edge[0, :], :], h[edge[1, :], :]), dim=1).t() 126 | # edge: 2*D x E 127 | 128 | edge_e = torch.exp(-self.leakyrelu(self.a.mm(edge_h).squeeze())) 129 | assert not torch.isnan(edge_e).any() 130 | # edge_e: E 131 | 132 | e_rowsum = self.special_spmm(edge, edge_e, torch.Size([N, N]), torch.ones(size=(N, 1), device=dv)) 133 | # e_rowsum: N x 1 134 | 135 | edge_e = self.dropout(edge_e) 136 | # edge_e: E 137 | 138 | h_prime = self.special_spmm(edge, edge_e, torch.Size([N, N]), h) 139 | assert not torch.isnan(h_prime).any() 140 | # h_prime: N x out 141 | 142 | h_prime = h_prime.div(e_rowsum) 143 | # h_prime: N x out 144 | assert not torch.isnan(h_prime).any() 145 | 146 | if self.concat: 147 | # if this layer is not last layer, 148 | return F.elu(h_prime) 149 | else: 150 | # if this layer is last layer, 151 | return h_prime 152 | 153 | def __repr__(self): 154 | return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')' 155 | 156 | 157 | class IWPA(nn.Module): 158 | """ 159 | Part attention layer, "Dynamic Dual-Attentive Aggregation Learning for Visible-Infrared Person Re-Identification" 160 | """ 161 | def __init__(self, in_channels, part = 3, inter_channels=None, out_channels=None): 162 | super(IWPA, self).__init__() 163 | 164 | self.in_channels = in_channels 165 | self.inter_channels = inter_channels 166 | self.out_channels = out_channels 167 | self.l2norm = Normalize(2) 168 | 169 | if self.inter_channels is None: 170 | self.inter_channels = in_channels 171 | 172 | if self.out_channels is None: 173 | self.out_channels = in_channels 174 | 175 | conv_nd = nn.Conv2d 176 | 177 | self.fc1 = nn.Sequential( 178 | conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, kernel_size=1, stride=1, 179 | padding=0), 180 | ) 181 | 182 | self.fc2 = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 183 | kernel_size=1, stride=1, padding=0) 184 | 185 | self.fc3 = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 186 | kernel_size=1, stride=1, padding=0) 187 | 188 | self.W = nn.Sequential( 189 | conv_nd(in_channels=self.inter_channels, out_channels=self.out_channels, 190 | kernel_size=1, stride=1, padding=0), 191 | nn.BatchNorm2d(self.out_channels), 192 | ) 193 | nn.init.constant_(self.W[1].weight, 0.0) 194 | nn.init.constant_(self.W[1].bias, 0.0) 195 | 196 | 197 | self.bottleneck = nn.BatchNorm1d(in_channels) 198 | self.bottleneck.bias.requires_grad_(False) # no shift 199 | 200 | nn.init.normal_(self.bottleneck.weight.data, 1.0, 0.01) 201 | nn.init.zeros_(self.bottleneck.bias.data) 202 | 203 | # weighting vector of the part features 204 | self.gate = nn.Parameter(torch.FloatTensor(part)) 205 | nn.init.constant_(self.gate, 1/part) 206 | def forward(self, x, feat, t=None, part=0): 207 | bt, c, h, w = x.shape 208 | b = bt // t 209 | 210 | # get part features 211 | part_feat = F.adaptive_avg_pool2d(x, (part, 1)) 212 | part_feat = part_feat.view(b, t, c, part) 213 | part_feat = part_feat.permute(0, 2, 1, 3) # B, C, T, Part 214 | 215 | part_feat1 = self.fc1(part_feat).view(b, self.inter_channels, -1) # B, C//r, T*Part 216 | part_feat1 = part_feat1.permute(0, 2, 1) # B, T*Part, C//r 217 | 218 | part_feat2 = self.fc2(part_feat).view(b, self.inter_channels, -1) # B, C//r, T*Part 219 | 220 | part_feat3 = self.fc3(part_feat).view(b, self.inter_channels, -1) # B, C//r, T*Part 
221 | part_feat3 = part_feat3.permute(0, 2, 1) # B, T*Part, C//r 222 | 223 | # get cross-part attention 224 | cpa_att = torch.matmul(part_feat1, part_feat2) # B, T*Part, T*Part 225 | cpa_att = F.softmax(cpa_att, dim=-1) 226 | 227 | # collect contextual information 228 | refined_part_feat = torch.matmul(cpa_att, part_feat3) # B, T*Part, C//r 229 | refined_part_feat = refined_part_feat.permute(0, 2, 1).contiguous() # B, C//r, T*Part 230 | refined_part_feat = refined_part_feat.view(b, self.inter_channels, part) # B, C//r, T, Part 231 | 232 | gate = F.softmax(self.gate, dim=-1) 233 | weight_part_feat = torch.matmul(refined_part_feat, gate) 234 | x = F.adaptive_avg_pool2d(x, (1, 1)) 235 | # weight_part_feat = weight_part_feat + x.view(x.size(0), x.size(1)) 236 | 237 | weight_part_feat = weight_part_feat + feat 238 | feat = self.bottleneck(weight_part_feat) 239 | 240 | return feat -------------------------------------------------------------------------------- /data_loader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image, ImageChops 3 | from torchvision import transforms 4 | import random 5 | import pdb 6 | import torch 7 | import torchvision.datasets as datasets 8 | import torch.utils.data as data 9 | 10 | 11 | class SYSUData(data.Dataset): 12 | def __init__(self, data_dir, transform=None, colorIndex = None, thermalIndex = None): 13 | 14 | # data_dir = '../Datasets/SYSU-MM01/' 15 | 16 | # Load training images (path) and labels 17 | train_color_image = np.load(data_dir + 'train_rgb_resized_img.npy') 18 | self.train_color_label = np.load(data_dir + 'train_rgb_resized_label.npy') 19 | 20 | train_thermal_image = np.load(data_dir + 'train_ir_resized_img.npy') 21 | self.train_thermal_label = np.load(data_dir + 'train_ir_resized_label.npy') 22 | 23 | # BGR to RGB 24 | self.train_color_image = train_color_image 25 | self.train_thermal_image = train_thermal_image 26 | self.transform = transform 27 | self.cIndex = colorIndex 28 | self.tIndex = thermalIndex 29 | 30 | def __getitem__(self, index): 31 | 32 | img1, target1 = self.train_color_image[self.cIndex[index]], self.train_color_label[self.cIndex[index]] 33 | img2, target2 = self.train_thermal_image[self.tIndex[index]], self.train_thermal_label[self.tIndex[index]] 34 | 35 | img1 = self.transform(img1) 36 | img2 = self.transform(img2) 37 | 38 | return img1, img2, target1, target2 39 | 40 | def __len__(self): 41 | return len(self.train_color_label) 42 | 43 | 44 | class RegDBData(data.Dataset): 45 | def __init__(self, data_dir, trial, transform=None, colorIndex = None, thermalIndex = None): 46 | # Load training images (path) and labels 47 | # data_dir = '../Datasets/RegDB/' 48 | train_color_list = data_dir + 'idx/train_visible_{}'.format(trial)+ '.txt' 49 | train_thermal_list = data_dir + 'idx/train_thermal_{}'.format(trial)+ '.txt' 50 | 51 | color_img_file, train_color_label = load_data(train_color_list) 52 | thermal_img_file, train_thermal_label = load_data(train_thermal_list) 53 | 54 | train_color_image = [] 55 | for i in range(len(color_img_file)): 56 | 57 | img = Image.open(data_dir+ color_img_file[i]) 58 | img = img.resize((144, 288), Image.ANTIALIAS) 59 | pix_array = np.array(img) 60 | train_color_image.append(pix_array) 61 | train_color_image = np.array(train_color_image) 62 | 63 | train_thermal_image = [] 64 | for i in range(len(thermal_img_file)): 65 | img = Image.open(data_dir+ thermal_img_file[i]) 66 | img = img.resize((144, 288), Image.ANTIALIAS) 67 | 
pix_array = np.array(img)
68 |             train_thermal_image.append(pix_array)
69 |         train_thermal_image = np.array(train_thermal_image)
70 | 
71 |         # BGR to RGB
72 |         self.train_color_image = train_color_image
73 |         self.train_color_label = train_color_label
74 | 
75 |         # BGR to RGB
76 |         self.train_thermal_image = train_thermal_image
77 |         self.train_thermal_label = train_thermal_label
78 | 
79 |         self.transform = transform
80 |         self.cIndex = colorIndex
81 |         self.tIndex = thermalIndex
82 | 
83 |     def __getitem__(self, index):
84 | 
85 |         img1, target1 = self.train_color_image[self.cIndex[index]], self.train_color_label[self.cIndex[index]]
86 |         img2, target2 = self.train_thermal_image[self.tIndex[index]], self.train_thermal_label[self.tIndex[index]]
87 | 
88 |         img1 = self.transform(img1)
89 |         img2 = self.transform(img2)
90 | 
91 |         return img1, img2, target1, target2
92 | 
93 |     def __len__(self):
94 |         return len(self.train_color_label)
95 | 
96 | class TestData(data.Dataset):
97 |     def __init__(self, test_img_file, test_label, transform=None, img_size = (144,288)):
98 | 
99 |         test_image = []
100 |         for i in range(len(test_img_file)):
101 |             img = Image.open(test_img_file[i])
102 |             img = img.resize((img_size[0], img_size[1]), Image.ANTIALIAS)
103 |             pix_array = np.array(img)
104 |             test_image.append(pix_array)
105 |         test_image = np.array(test_image)
106 |         self.test_image = test_image
107 |         self.test_label = test_label
108 |         self.transform = transform
109 | 
110 |     def __getitem__(self, index):
111 |         img1, target1 = self.test_image[index], self.test_label[index]
112 |         img1 = self.transform(img1)
113 |         return img1, target1
114 | 
115 |     def __len__(self):
116 |         return len(self.test_image)
117 | 
118 | class TestDataOld(data.Dataset):
119 |     def __init__(self, data_dir, test_img_file, test_label, transform=None, img_size = (144,288)):
120 | 
121 |         test_image = []
122 |         for i in range(len(test_img_file)):
123 |             img = Image.open(data_dir + test_img_file[i])
124 |             img = img.resize((img_size[0], img_size[1]), Image.ANTIALIAS)
125 |             pix_array = np.array(img)
126 |             test_image.append(pix_array)
127 |         test_image = np.array(test_image)
128 |         self.test_image = test_image
129 |         self.test_label = test_label
130 |         self.transform = transform
131 | 
132 |     def __getitem__(self, index):
133 |         img1, target1 = self.test_image[index], self.test_label[index]
134 |         img1 = self.transform(img1)
135 |         return img1, target1
136 | 
137 |     def __len__(self):
138 |         return len(self.test_image)
139 | def load_data(input_data_path):
140 |     with open(input_data_path, 'rt') as f:
141 |         data_file_list = f.read().splitlines()
142 |         # Get the full list of image paths and labels
143 |         file_image = [s.split(' ')[0] for s in data_file_list]
144 |         file_label = [int(s.split(' ')[1]) for s in data_file_list]
145 | 
146 |     return file_image, file_label
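A minimal usage sketch of these datasets (illustrative only; the data path is the placeholder from the comments above, and real training uses identity-aligned index lists produced by the sampler utilities in `utils.py`, not `np.arange`):

```python
import numpy as np
import torchvision.transforms as transforms
from data_loader import SYSUData

# Hypothetical wiring; assumes the pre-processed .npy files exist under data_dir.
transform = transforms.Compose([transforms.ToPILImage(), transforms.ToTensor()])
trainset = SYSUData('../Datasets/SYSU-MM01/', transform=transform,
                    colorIndex=np.arange(64), thermalIndex=np.arange(64))
img_rgb, img_ir, lbl_rgb, lbl_ir = trainset[0]  # one visible/infrared pair
```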
= ["%04d" % x for x in ids] 21 | 22 | for id in sorted(ids): 23 | for cam in ir_cameras: 24 | img_dir = os.path.join(data_path,cam,id) 25 | if os.path.isdir(img_dir): 26 | new_files = sorted([img_dir+'/'+i for i in os.listdir(img_dir)]) 27 | files_ir.extend(new_files) 28 | query_img = [] 29 | query_id = [] 30 | query_cam = [] 31 | for img_path in files_ir: 32 | camid, pid = int(img_path[-15]), int(img_path[-13:-9]) 33 | query_img.append(img_path) 34 | query_id.append(pid) 35 | query_cam.append(camid) 36 | return query_img, np.array(query_id), np.array(query_cam) 37 | 38 | def process_gallery_sysu(data_path, mode = 'all', trial = 0, relabel=False): 39 | 40 | random.seed(trial) 41 | 42 | if mode== 'all': 43 | rgb_cameras = ['cam1','cam2','cam4','cam5'] 44 | elif mode =='indoor': 45 | rgb_cameras = ['cam1','cam2'] 46 | 47 | file_path = os.path.join(data_path,'exp/test_id.txt') 48 | files_rgb = [] 49 | with open(file_path, 'r') as file: 50 | ids = file.read().splitlines() 51 | ids = [int(y) for y in ids[0].split(',')] 52 | ids = ["%04d" % x for x in ids] 53 | 54 | for id in sorted(ids): 55 | for cam in rgb_cameras: 56 | img_dir = os.path.join(data_path,cam,id) 57 | if os.path.isdir(img_dir): 58 | new_files = sorted([img_dir+'/'+i for i in os.listdir(img_dir)]) 59 | files_rgb.append(random.choice(new_files)) 60 | gall_img = [] 61 | gall_id = [] 62 | gall_cam = [] 63 | for img_path in files_rgb: 64 | camid, pid = int(img_path[-15]), int(img_path[-13:-9]) 65 | gall_img.append(img_path) 66 | gall_id.append(pid) 67 | gall_cam.append(camid) 68 | return gall_img, np.array(gall_id), np.array(gall_cam) 69 | 70 | def process_test_regdb(img_dir, trial = 1, modal = 'visible'): 71 | if modal=='visible': 72 | input_data_path = img_dir + 'idx/test_visible_{}'.format(trial) + '.txt' 73 | elif modal=='thermal': 74 | input_data_path = img_dir + 'idx/test_thermal_{}'.format(trial) + '.txt' 75 | 76 | with open(input_data_path) as f: 77 | data_file_list = open(input_data_path, 'rt').read().splitlines() 78 | # Get full list of image and labels 79 | file_image = [img_dir + '/' + s.split(' ')[0] for s in data_file_list] 80 | file_label = [int(s.split(' ')[1]) for s in data_file_list] 81 | 82 | return file_image, np.array(file_label) -------------------------------------------------------------------------------- /eval_metrics.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import numpy as np 3 | """Cross-Modality ReID""" 4 | 5 | def eval_sysu(distmat, q_pids, g_pids, q_camids, g_camids, max_rank = 20): 6 | """Evaluation with sysu metric 7 | Key: for each query identity, its gallery images from the same camera view are discarded. "Following the original setting in ite dataset" 8 | """ 9 | num_q, num_g = distmat.shape 10 | if num_g < max_rank: 11 | max_rank = num_g 12 | print("Note: number of gallery samples is quite small, got {}".format(num_g)) 13 | indices = np.argsort(distmat, axis=1) 14 | pred_label = g_pids[indices] 15 | matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32) 16 | 17 | # compute cmc curve for each query 18 | new_all_cmc = [] 19 | all_cmc = [] 20 | all_AP = [] 21 | all_INP = [] 22 | num_valid_q = 0. 
-------------------------------------------------------------------------------- /eval_metrics.py: --------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 | import numpy as np
3 | """Cross-Modality ReID"""
4 | 
5 | def eval_sysu(distmat, q_pids, g_pids, q_camids, g_camids, max_rank = 20):
6 |     """Evaluation with the SYSU-MM01 metric.
7 |     Key: for each query identity, its gallery images from the same camera view are discarded, following the original setting of the dataset.
8 |     """
9 |     num_q, num_g = distmat.shape
10 |     if num_g < max_rank:
11 |         max_rank = num_g
12 |         print("Note: number of gallery samples is quite small, got {}".format(num_g))
13 |     indices = np.argsort(distmat, axis=1)
14 |     pred_label = g_pids[indices]
15 |     matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)
16 | 
17 |     # compute cmc curve for each query
18 |     new_all_cmc = []
19 |     all_cmc = []
20 |     all_AP = []
21 |     all_INP = []
22 |     num_valid_q = 0. # number of valid queries
23 |     for q_idx in range(num_q):
24 |         # get query pid and camid
25 |         q_pid = q_pids[q_idx]
26 |         q_camid = q_camids[q_idx]
27 | 
28 |         # remove gallery samples captured at the same location as the query (cam2 and cam3 are co-located)
29 |         order = indices[q_idx]
30 |         remove = (q_camid == 3) & (g_camids[order] == 2)
31 |         keep = np.invert(remove)
32 | 
33 |         # compute cmc curve
34 |         # the cmc calculation is different from the standard protocol
35 |         # we follow the protocol of the author's released code
36 |         new_cmc = pred_label[q_idx][keep]
37 |         new_index = np.unique(new_cmc, return_index=True)[1]
38 |         new_cmc = [new_cmc[index] for index in sorted(new_index)]
39 | 
40 |         new_match = (new_cmc == q_pid).astype(np.int32)
41 |         new_cmc = new_match.cumsum()
42 |         new_all_cmc.append(new_cmc[:max_rank])
43 | 
44 |         orig_cmc = matches[q_idx][keep] # binary vector, positions with value 1 are correct matches
45 |         if not np.any(orig_cmc):
46 |             # this condition is true when query identity does not appear in gallery
47 |             continue
48 | 
49 |         cmc = orig_cmc.cumsum()
50 | 
51 |         # compute mINP
52 |         # reference: Deep Learning for Person Re-identification: A Survey and Outlook
53 |         pos_idx = np.where(orig_cmc == 1)
54 |         pos_max_idx = np.max(pos_idx)
55 |         inp = cmc[pos_max_idx]/ (pos_max_idx + 1.0)
56 |         all_INP.append(inp)
57 | 
58 |         cmc[cmc > 1] = 1
59 | 
60 |         all_cmc.append(cmc[:max_rank])
61 |         num_valid_q += 1.
62 | 
63 |         # compute average precision
64 |         # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision
65 |         num_rel = orig_cmc.sum()
66 |         tmp_cmc = orig_cmc.cumsum()
67 |         tmp_cmc = [x / (i+1.) for i, x in enumerate(tmp_cmc)]
68 |         tmp_cmc = np.asarray(tmp_cmc) * orig_cmc
69 |         AP = tmp_cmc.sum() / num_rel
70 |         all_AP.append(AP)
71 | 
72 |     assert num_valid_q > 0, "Error: all query identities do not appear in gallery"
73 | 
74 |     all_cmc = np.asarray(all_cmc).astype(np.float32)
75 |     all_cmc = all_cmc.sum(0) / num_valid_q # standard CMC
76 | 
77 |     new_all_cmc = np.asarray(new_all_cmc).astype(np.float32)
78 |     new_all_cmc = new_all_cmc.sum(0) / num_valid_q
79 |     mAP = np.mean(all_AP)
80 |     mINP = np.mean(all_INP)
81 |     return new_all_cmc, mAP, mINP
82 | 
83 | 
84 | 
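For intuition, a tiny worked example of the metrics computed above (a toy check mirroring the formulas, not repo code): for one query with ranked match vector [0, 1, 1, 0, 1], Rank-1 accuracy is 0, AP averages the precision at each correct match, (1/2 + 2/3 + 3/5) / 3 ≈ 0.589, and the inverse negative penalty uses the hardest (last) correct match: INP = 3/5 = 0.6.

```python
import numpy as np

raw_cmc = np.array([0, 1, 1, 0, 1])      # 1 = correct match at that rank
cmc = raw_cmc.cumsum()
last_hit = np.max(np.where(raw_cmc == 1))
inp = cmc[last_hit] / (last_hit + 1.0)   # 3 / 5 = 0.6
prec = cmc / (np.arange(len(raw_cmc)) + 1.0)
ap = (prec * raw_cmc).sum() / raw_cmc.sum()
print(round(ap, 3), inp)                 # 0.589 0.6
```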
85 | def eval_regdb(distmat, q_pids, g_pids, max_rank = 20):
86 |     num_q, num_g = distmat.shape
87 |     if num_g < max_rank:
88 |         max_rank = num_g
89 |         print("Note: number of gallery samples is quite small, got {}".format(num_g))
90 |     indices = np.argsort(distmat, axis=1)
91 |     matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)
92 | 
93 |     # compute cmc curve for each query
94 |     all_cmc = []
95 |     all_AP = []
96 |     all_INP = []
97 |     num_valid_q = 0. # number of valid queries
98 | 
99 |     # only two cameras
100 |     q_camids = np.ones(num_q).astype(np.int32)
101 |     g_camids = 2* np.ones(num_g).astype(np.int32)
102 | 
103 |     for q_idx in range(num_q):
104 |         # get query pid and camid
105 |         q_pid = q_pids[q_idx]
106 |         q_camid = q_camids[q_idx]
107 | 
108 |         # remove gallery samples that have the same pid and camid with query
109 |         order = indices[q_idx]
110 |         remove = (g_pids[order] == q_pid) & (g_camids[order] == q_camid)
111 |         keep = np.invert(remove)
112 | 
113 |         # compute cmc curve
114 |         raw_cmc = matches[q_idx][keep] # binary vector, positions with value 1 are correct matches
115 |         if not np.any(raw_cmc):
116 |             # this condition is true when query identity does not appear in gallery
117 |             continue
118 | 
119 |         cmc = raw_cmc.cumsum()
120 | 
121 |         # compute mINP
122 |         # reference: Deep Learning for Person Re-identification: A Survey and Outlook
123 |         pos_idx = np.where(raw_cmc == 1)
124 |         pos_max_idx = np.max(pos_idx)
125 |         inp = cmc[pos_max_idx]/ (pos_max_idx + 1.0)
126 |         all_INP.append(inp)
127 | 
128 |         cmc[cmc > 1] = 1
129 | 
130 |         all_cmc.append(cmc[:max_rank])
131 |         num_valid_q += 1.
132 | 
133 |         # compute average precision
134 |         # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision
135 |         num_rel = raw_cmc.sum()
136 |         tmp_cmc = raw_cmc.cumsum()
137 |         tmp_cmc = [x / (i+1.) for i, x in enumerate(tmp_cmc)]
138 |         tmp_cmc = np.asarray(tmp_cmc) * raw_cmc
139 |         AP = tmp_cmc.sum() / num_rel
140 |         all_AP.append(AP)
141 | 
142 |     assert num_valid_q > 0, "Error: all query identities do not appear in gallery"
143 | 
144 |     all_cmc = np.asarray(all_cmc).astype(np.float32)
145 |     all_cmc = all_cmc.sum(0) / num_valid_q
146 |     mAP = np.mean(all_AP)
147 |     mINP = np.mean(all_INP)
148 |     return all_cmc, mAP, mINP
-------------------------------------------------------------------------------- /loss.py: --------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from torch.autograd.function import Function
6 | from torch.autograd import Variable
7 | 
8 | class KLLoss(nn.Module):
9 |     def __init__(self):
10 |         super(KLLoss, self).__init__()
11 |     def forward(self, pred, label):
12 |         # pred: 2D matrix (batch_size, num_classes) of logits
13 |         # label: 2D matrix of soft target logits (same shape as pred)
14 |         T=3
15 | 
16 |         predict = F.log_softmax(pred/T,dim=1)
17 |         target_data = F.softmax(label/T,dim=1)
18 |         target_data =target_data+10**(-7)
19 |         target = Variable(target_data.data.cuda(),requires_grad=False)
20 |         loss=T*T*((target*(target.log()-predict)).sum(1).sum()/target.size()[0])
21 |         return loss
22 | 
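Note that `KLLoss` is a temperature-scaled (T=3) KL-divergence distillation loss between two sets of logits; despite the `label` argument name, it expects a logit matrix, not class indices, as the corrected comment above indicates. A minimal call sketch (shapes are assumptions; a GPU is required because the loss moves the target to CUDA internally):

```python
import torch

criterion_kl = KLLoss()
logits_a = torch.randn(8, 395).cuda()    # predictions from one branch
logits_b = torch.randn(8, 395).cuda()    # soft targets from another branch
loss = criterion_kl(logits_a, logits_b)
```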
23 | class OriTripletLoss(nn.Module):
24 |     """Triplet loss with hard positive/negative mining.
25 | 
26 |     Reference:
27 |     Hermans et al. In Defense of the Triplet Loss for Person Re-Identification. arXiv:1703.07737.
28 |     Code imported from https://github.com/Cysu/open-reid/blob/master/reid/loss/triplet.py.
29 | 
30 |     Args:
31 |     - margin (float): margin for triplet.
32 |     """
33 | 
34 |     def __init__(self, batch_size, margin=0.3):
35 |         super(OriTripletLoss, self).__init__()
36 |         self.margin = margin
37 |         self.ranking_loss = nn.MarginRankingLoss(margin=margin)
38 | 
39 |     def forward(self, inputs, targets):
40 |         """
41 |         Args:
42 |         - inputs: feature matrix with shape (batch_size, feat_dim)
43 |         - targets: ground truth labels with shape (batch_size)
44 |         """
45 |         n = inputs.size(0)
46 | 
47 |         # Compute pairwise distance; replace by the official version when merged
48 |         dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n)
49 |         dist = dist + dist.t()
50 |         dist.addmm_(1, -2, inputs, inputs.t())
51 |         dist = dist.clamp(min=1e-12).sqrt() # for numerical stability
52 | 
53 |         # For each anchor, find the hardest positive and negative
54 |         mask = targets.expand(n, n).eq(targets.expand(n, n).t())
55 |         dist_ap, dist_an = [], []
56 |         for i in range(n):
57 |             dist_ap.append(dist[i][mask[i]].max().unsqueeze(0))
58 |             dist_an.append(dist[i][mask[i] == 0].min().unsqueeze(0))
59 |         dist_ap = torch.cat(dist_ap)
60 |         dist_an = torch.cat(dist_an)
61 | 
62 |         # Compute ranking hinge loss
63 |         y = torch.ones_like(dist_an)
64 |         loss = self.ranking_loss(dist_an, dist_ap, y)
65 | 
66 |         # compute accuracy
67 |         correct = torch.ge(dist_an, dist_ap).sum().item()
68 |         return loss, correct
69 | 
70 | 
71 | class TripletLoss(nn.Module):
72 |     """Triplet loss with hard positive/negative mining.
73 | 
74 |     Reference:
75 |     Hermans et al. In Defense of the Triplet Loss for Person Re-Identification. arXiv:1703.07737.
76 |     Code imported from https://github.com/Cysu/open-reid/blob/master/reid/loss/triplet.py.
77 | 
78 |     Args:
79 |     - margin (float): margin for triplet.
80 |     """
81 |     def __init__(self, batch_size, margin=0.5):
82 |         super(TripletLoss, self).__init__()
83 |         self.margin = margin
84 |         self.ranking_loss = nn.MarginRankingLoss(margin=margin)
85 |         self.batch_size = batch_size
86 |         self.mask = torch.eye(batch_size)
87 |     def forward(self, input, target):
88 |         """
89 |         Args:
90 |         - input: feature matrix with shape (batch_size, feat_dim)
91 |         - target: ground truth labels with shape (batch_size)
92 |         """
93 |         n = self.batch_size
94 |         input1 = input.narrow(0,0,n)
95 |         input2 = input.narrow(0,n,n)
96 | 
97 |         # Compute pairwise distance; replace by the official version when merged
98 |         dist = pdist_torch(input1, input2)
99 | 
100 |         # For each anchor, find the hardest positive and negative
101 |         # mask = target1.expand(n, n).eq(target1.expand(n, n).t())
102 |         dist_ap, dist_an = [], []
103 |         for i in range(n):
104 |             dist_ap.append(dist[i,i].unsqueeze(0))
105 |             dist_an.append(dist[i][self.mask[i] == 0].min().unsqueeze(0))
106 |         dist_ap = torch.cat(dist_ap)
107 |         dist_an = torch.cat(dist_an)
108 | 
109 |         # Compute ranking hinge loss
110 |         y = torch.ones_like(dist_an)
111 |         loss = self.ranking_loss(dist_an, dist_ap, y)
112 | 
113 |         # compute accuracy
114 |         correct = torch.ge(dist_an, dist_ap).sum().item()
115 |         return loss, correct*2
116 | 
117 | class BiTripletLoss(nn.Module):
118 |     """Triplet loss with hard positive/negative mining.
119 | 
120 |     Reference:
121 |     Hermans et al. In Defense of the Triplet Loss for Person Re-Identification. arXiv:1703.07737.
122 |     Code imported from https://github.com/Cysu/open-reid/blob/master/reid/loss/triplet.py.
123 | 
124 |     Args:
125 |     - margin (float): margin for triplet.
126 |     """
127 |     def __init__(self, batch_size, margin=0.5):
128 |         super(BiTripletLoss, self).__init__()
129 |         self.margin = margin
130 |         self.ranking_loss = nn.MarginRankingLoss(margin=margin)
131 |         self.batch_size = batch_size
132 |         self.mask = torch.eye(batch_size)
133 |     def forward(self, input, target):
134 |         """
135 |         Args:
136 |         - input: feature matrix with shape (batch_size, feat_dim)
137 |         - target: ground truth labels with shape (batch_size)
138 |         """
139 |         n = self.batch_size
140 |         input1 = input.narrow(0,0,n)
141 |         input2 = input.narrow(0,n,n)
142 | 
143 |         # Compute pairwise distance; replace by the official version when merged
144 |         dist = pdist_torch(input1, input2)
145 | 
146 |         # For each anchor, find the hardest positive and negative
147 |         # mask = target1.expand(n, n).eq(target1.expand(n, n).t())
148 |         dist_ap, dist_an = [], []
149 |         for i in range(n):
150 |             dist_ap.append(dist[i,i].unsqueeze(0))
151 |             dist_an.append(dist[i][self.mask[i] == 0].min().unsqueeze(0))
152 |         dist_ap = torch.cat(dist_ap)
153 |         dist_an = torch.cat(dist_an)
154 | 
155 |         # Compute ranking hinge loss
156 |         y = torch.ones_like(dist_an)
157 |         loss1 = self.ranking_loss(dist_an, dist_ap, y)
158 | 
159 |         # compute accuracy
160 |         correct1 = torch.ge(dist_an, dist_ap).sum().item()
161 | 
162 |         # Compute pairwise distance; replace by the official version when merged
163 |         dist2 = pdist_torch(input2, input1)
164 | 
165 |         # For each anchor, find the hardest positive and negative
166 |         dist_ap2, dist_an2 = [], []
167 |         for i in range(n):
168 |             dist_ap2.append(dist2[i,i].unsqueeze(0))
169 |             dist_an2.append(dist2[i][self.mask[i] == 0].min().unsqueeze(0))
170 |         dist_ap2 = torch.cat(dist_ap2)
171 |         dist_an2 = torch.cat(dist_an2)
172 | 
173 |         # Compute ranking hinge loss
174 |         y2 = torch.ones_like(dist_an2)
175 |         # loss2 = self.ranking_loss(dist_an2, dist_ap2, y2)
176 | 
177 |         loss2 = torch.sum(torch.nn.functional.relu(dist_ap2 + self.margin - dist_an2))
178 | 
179 |         # compute accuracy
180 |         correct2 = torch.ge(dist_an2, dist_ap2).sum().item()
181 | 
182 |         loss = torch.add(loss1, loss2)
183 |         return loss, correct1 + correct2
184 | 
185 | 
186 | class BDTRLoss(nn.Module):
187 |     """Triplet loss with hard positive/negative mining.
188 | 
189 |     Reference:
190 |     Hermans et al. In Defense of the Triplet Loss for Person Re-Identification. arXiv:1703.07737.
191 |     Code imported from https://github.com/Cysu/open-reid/blob/master/reid/loss/triplet.py.
192 | 
193 |     Args:
194 |     - margin (float): margin for triplet.
195 |     """
196 |     def __init__(self, batch_size, margin=0.5):
197 |         super(BDTRLoss, self).__init__()
198 |         self.margin = margin
199 |         self.ranking_loss = nn.MarginRankingLoss(margin=margin)
200 |         self.batch_size = batch_size
201 |         self.mask = torch.eye(batch_size)
202 |     def forward(self, inputs, targets):
203 |         """
204 |         Args:
205 |         - inputs: feature matrix with shape (batch_size, feat_dim)
206 |         - targets: ground truth labels with shape (batch_size)
207 |         """
208 |         n = inputs.size(0)
209 | 
210 |         # Compute pairwise distance; replace by the official version when merged
211 |         dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n)
212 |         dist = dist + dist.t()
213 |         dist.addmm_(1, -2, inputs, inputs.t())
214 |         dist = dist.clamp(min=1e-12).sqrt() # for numerical stability
215 | 
216 |         # For each anchor, find the hardest positive and negative
217 |         mask = targets.expand(n, n).eq(targets.expand(n, n).t())
218 |         dist_ap, dist_an = [], []
219 |         for i in range(n):
220 |             dist_ap.append(dist[i][mask[i]].max().unsqueeze(0))
221 |             dist_an.append(dist[i][mask[i] == 0].min().unsqueeze(0))
222 |         dist_ap = torch.cat(dist_ap)
223 |         dist_an = torch.cat(dist_an)
224 | 
225 |         # Compute ranking hinge loss
226 |         y = torch.ones_like(dist_an)
227 |         loss = self.ranking_loss(dist_an, dist_ap, y)
228 |         correct = torch.ge(dist_an, dist_ap).sum().item()
229 |         return loss, correct
230 | 
231 | def pdist_torch(emb1, emb2):
232 |     '''
233 |     compute the euclidean distance matrix between embeddings1 and embeddings2
234 |     using gpu
235 |     '''
236 |     m, n = emb1.shape[0], emb2.shape[0]
237 |     emb1_pow = torch.pow(emb1, 2).sum(dim = 1, keepdim = True).expand(m, n)
238 |     emb2_pow = torch.pow(emb2, 2).sum(dim = 1, keepdim = True).expand(n, m).t()
239 |     dist_mtx = emb1_pow + emb2_pow
240 |     dist_mtx = dist_mtx.addmm_(1, -2, emb1, emb2.t())
241 |     # dist_mtx = dist_mtx.clamp(min = 1e-12)
242 |     dist_mtx = dist_mtx.clamp(min = 1e-12).sqrt()
243 |     return dist_mtx
244 | 
245 | 
246 | def pdist_np(emb1, emb2):
247 |     '''
248 |     compute the euclidean distance matrix between embeddings1 and embeddings2
249 |     using cpu
250 |     '''
251 |     m, n = emb1.shape[0], emb2.shape[0]
252 |     emb1_pow = np.square(emb1).sum(axis = 1)[..., np.newaxis]
253 |     emb2_pow = np.square(emb2).sum(axis = 1)[np.newaxis, ...]
254 |     dist_mtx = -2 * np.matmul(emb1, emb2.T) + emb1_pow + emb2_pow
255 |     # dist_mtx = np.sqrt(dist_mtx.clip(min = 1e-12))
256 |     return dist_mtx
-------------------------------------------------------------------------------- /model_main.py: --------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import init
4 | from torchvision import models
5 | from torch.autograd import Variable
6 | from resnet import resnet50, resnet18
7 | import torch.nn.functional as F
8 | import math
9 | from attention import GraphAttentionLayer, IWPA
10 | 
11 | class Normalize(nn.Module):
12 |     def __init__(self, power=2):
13 |         super(Normalize, self).__init__()
14 |         self.power = power
15 | 
16 |     def forward(self, x):
17 |         norm = x.pow(self.power).sum(1, keepdim=True).pow(1.
/ self.power) 18 | out = x.div(norm) 19 | return out 20 | 21 | 22 | 23 | # ##################################################################### 24 | def weights_init_kaiming(m): 25 | classname = m.__class__.__name__ 26 | # print(classname) 27 | if classname.find('Conv') != -1: 28 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 29 | elif classname.find('Linear') != -1: 30 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_out') 31 | init.zeros_(m.bias.data) 32 | elif classname.find('BatchNorm1d') != -1: 33 | init.normal_(m.weight.data, 1.0, 0.01) 34 | init.zeros_(m.bias.data) 35 | 36 | 37 | def weights_init_classifier(m): 38 | classname = m.__class__.__name__ 39 | if classname.find('Linear') != -1: 40 | init.normal_(m.weight.data, 0, 0.001) 41 | if m.bias: 42 | init.zeros_(m.bias.data) 43 | 44 | # Defines the new fc layer and classification layer 45 | # |--Linear--|--bn--|--relu--|--Linear--| 46 | class FeatureBlock(nn.Module): 47 | def __init__(self, input_dim, low_dim, dropout=0.5, relu=True): 48 | super(FeatureBlock, self).__init__() 49 | feat_block = [] 50 | feat_block += [nn.Linear(input_dim, low_dim)] 51 | feat_block += [nn.BatchNorm1d(low_dim)] 52 | 53 | feat_block = nn.Sequential(*feat_block) 54 | feat_block.apply(weights_init_kaiming) 55 | self.feat_block = feat_block 56 | 57 | def forward(self, x): 58 | x = self.feat_block(x) 59 | return x 60 | 61 | 62 | class ClassBlock(nn.Module): 63 | def __init__(self, input_dim, class_num, dropout=0.5, relu=True): 64 | super(ClassBlock, self).__init__() 65 | classifier = [] 66 | if relu: 67 | classifier += [nn.LeakyReLU(0.1)] 68 | if dropout: 69 | classifier += [nn.Dropout(p=dropout)] 70 | 71 | classifier += [nn.Linear(input_dim, class_num)] 72 | classifier = nn.Sequential(*classifier) 73 | classifier.apply(weights_init_classifier) 74 | 75 | self.classifier = classifier 76 | 77 | def forward(self, x): 78 | x = self.classifier(x) 79 | return x 80 | 81 | class visible_module(nn.Module): 82 | def __init__(self, arch='resnet50'): 83 | super(visible_module, self).__init__() 84 | 85 | model_v = resnet50(pretrained=True, 86 | last_conv_stride=1, last_conv_dilation=1) 87 | # avg pooling to global pooling 88 | self.visible = model_v 89 | 90 | def forward(self, x): 91 | x = self.visible.conv1(x) 92 | x = self.visible.bn1(x) 93 | x = self.visible.relu(x) 94 | x = self.visible.maxpool(x) 95 | return x 96 | 97 | 98 | class thermal_module(nn.Module): 99 | def __init__(self, arch='resnet50'): 100 | super(thermal_module, self).__init__() 101 | 102 | model_t = resnet50(pretrained=True, 103 | last_conv_stride=1, last_conv_dilation=1) 104 | # avg pooling to global pooling 105 | self.thermal = model_t 106 | 107 | def forward(self, x): 108 | x = self.thermal.conv1(x) 109 | x = self.thermal.bn1(x) 110 | x = self.thermal.relu(x) 111 | x = self.thermal.maxpool(x) 112 | return x 113 | 114 | 115 | class base_resnet(nn.Module): 116 | def __init__(self, arch='resnet50'): 117 | super(base_resnet, self).__init__() 118 | 119 | model_base = resnet50(pretrained=True, 120 | last_conv_stride=1, last_conv_dilation=1) 121 | # avg pooling to global pooling 122 | model_base.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 123 | self.base = model_base 124 | 125 | def forward(self, x): 126 | x = self.base.layer1(x) 127 | x = self.base.layer2(x) 128 | x = self.base.layer3(x) 129 | x = self.base.layer4(x) 130 | return x 131 | 132 | 133 | class embed_net(nn.Module): 134 | def __init__(self, low_dim, class_num, drop=0.2, part = 3, alpha=0.2, nheads=4, arch='resnet50', wpa = False): 
135 | super(embed_net, self).__init__() 136 | 137 | self.thermal_module = thermal_module(arch=arch) 138 | self.visible_module = visible_module(arch=arch) 139 | self.base_resnet = base_resnet(arch=arch) 140 | pool_dim = 2048 141 | self.dropout = drop 142 | self.part = part 143 | self.lpa = wpa 144 | 145 | self.l2norm = Normalize(2) 146 | self.bottleneck = nn.BatchNorm1d(pool_dim) 147 | self.bottleneck.bias.requires_grad_(False) # no shift 148 | 149 | self.classifier = nn.Linear(pool_dim, class_num, bias=False) 150 | 151 | self.classifier1 = nn.Linear(pool_dim, class_num, bias=False) 152 | self.classifier2 = nn.Linear(pool_dim, class_num, bias=False) 153 | 154 | self.bottleneck.apply(weights_init_kaiming) 155 | self.classifier.apply(weights_init_classifier) 156 | self.classifier1.apply(weights_init_classifier) 157 | self.classifier2.apply(weights_init_classifier) 158 | 159 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 160 | self.wpa = IWPA(pool_dim, part) 161 | 162 | 163 | self.attentions = [GraphAttentionLayer(pool_dim, low_dim, dropout=drop, alpha=alpha, concat=True) for _ in range(nheads)] 164 | for i, attention in enumerate(self.attentions): 165 | self.add_module('attention_{}'.format(i), attention) 166 | 167 | self.out_att = GraphAttentionLayer(low_dim * nheads, class_num, dropout=drop, alpha=alpha, concat=False) 168 | 169 | def forward(self, x1, x2, adj, modal=0, cpa = False): 170 | # domain specific block 171 | if modal == 0: 172 | x1 = self.visible_module(x1) 173 | x2 = self.thermal_module(x2) 174 | x = torch.cat((x1, x2), 0) 175 | elif modal == 1: 176 | x = self.visible_module(x1) 177 | elif modal == 2: 178 | x = self.thermal_module(x2) 179 | 180 | # shared four blocks 181 | x = self.base_resnet(x) 182 | x_pool = self.avgpool(x) 183 | x_pool = x_pool.view(x_pool.size(0), x_pool.size(1)) 184 | feat = self.bottleneck(x_pool) 185 | 186 | if self.lpa: 187 | # intra-modality weighted part attention 188 | feat_att = self.wpa(x, feat, 1, self.part) 189 | 190 | if self.training: 191 | # cross-modality graph attention 192 | x_g = F.dropout(x_pool, self.dropout, training=self.training) 193 | x_g = torch.cat([att(x_g, adj) for att in self.attentions], dim=1) 194 | x_g = F.dropout(x_g, self.dropout, training=self.training) 195 | x_g = F.elu(self.out_att(x_g, adj)) 196 | return x_pool, self.classifier(feat), self.classifier(feat_att), F.log_softmax(x_g, dim=1) 197 | else: 198 | return self.l2norm(feat), self.l2norm(feat_att) -------------------------------------------------------------------------------- /resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | import torch.utils.model_zoo as model_zoo 4 | 5 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 6 | 'resnet152'] 7 | 8 | model_urls = { 9 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 10 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 11 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 12 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 13 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 14 | } 15 | 16 | 17 | def conv3x3(in_planes, out_planes, stride=1, dilation=1): 18 | """3x3 convolution with padding""" 19 | # original padding is 1; original dilation is 1 20 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 21 | padding=dilation, bias=False, dilation=dilation) 22 | 23 | 24 | 
class BasicBlock(nn.Module): 25 | expansion = 1 26 | 27 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1): 28 | super(BasicBlock, self).__init__() 29 | self.conv1 = conv3x3(inplanes, planes, stride, dilation) 30 | self.bn1 = nn.BatchNorm2d(planes) 31 | self.relu = nn.ReLU(inplace=True) 32 | self.conv2 = conv3x3(planes, planes) 33 | self.bn2 = nn.BatchNorm2d(planes) 34 | self.downsample = downsample 35 | self.stride = stride 36 | 37 | def forward(self, x): 38 | residual = x 39 | 40 | out = self.conv1(x) 41 | out = self.bn1(out) 42 | out = self.relu(out) 43 | 44 | out = self.conv2(out) 45 | out = self.bn2(out) 46 | 47 | if self.downsample is not None: 48 | residual = self.downsample(x) 49 | 50 | out += residual 51 | out = self.relu(out) 52 | 53 | return out 54 | 55 | 56 | class Bottleneck(nn.Module): 57 | expansion = 4 58 | 59 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1): 60 | super(Bottleneck, self).__init__() 61 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 62 | self.bn1 = nn.BatchNorm2d(planes) 63 | # original padding is 1; original dilation is 1 64 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=dilation, bias=False, dilation=dilation) 65 | self.bn2 = nn.BatchNorm2d(planes) 66 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 67 | self.bn3 = nn.BatchNorm2d(planes * 4) 68 | self.relu = nn.ReLU(inplace=True) 69 | self.downsample = downsample 70 | self.stride = stride 71 | 72 | def forward(self, x): 73 | residual = x 74 | 75 | out = self.conv1(x) 76 | out = self.bn1(out) 77 | out = self.relu(out) 78 | 79 | out = self.conv2(out) 80 | out = self.bn2(out) 81 | out = self.relu(out) 82 | 83 | out = self.conv3(out) 84 | out = self.bn3(out) 85 | 86 | if self.downsample is not None: 87 | residual = self.downsample(x) 88 | 89 | out += residual 90 | out = self.relu(out) 91 | 92 | return out 93 | 94 | 95 | class ResNet(nn.Module): 96 | 97 | def __init__(self, block, layers, last_conv_stride=2, last_conv_dilation=1): 98 | 99 | self.inplanes = 64 100 | super(ResNet, self).__init__() 101 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 102 | bias=False) 103 | self.bn1 = nn.BatchNorm2d(64) 104 | self.relu = nn.ReLU(inplace=True) 105 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 106 | self.layer1 = self._make_layer(block, 64, layers[0]) 107 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 108 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 109 | self.layer4 = self._make_layer(block, 512, layers[3], stride=last_conv_stride, dilation=last_conv_dilation) 110 | 111 | for m in self.modules(): 112 | if isinstance(m, nn.Conv2d): 113 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 114 | m.weight.data.normal_(0, math.sqrt(2. 
/ n))
115 |             elif isinstance(m, nn.BatchNorm2d):
116 |                 m.weight.data.fill_(1)
117 |                 m.bias.data.zero_()
118 | 
119 |     def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
120 |         downsample = None
121 |         if stride != 1 or self.inplanes != planes * block.expansion:
122 |             downsample = nn.Sequential(
123 |                 nn.Conv2d(self.inplanes, planes * block.expansion,
124 |                           kernel_size=1, stride=stride, bias=False),
125 |                 nn.BatchNorm2d(planes * block.expansion),
126 |             )
127 | 
128 |         layers = []
129 |         layers.append(block(self.inplanes, planes, stride, downsample, dilation))
130 |         self.inplanes = planes * block.expansion
131 |         for i in range(1, blocks):
132 |             layers.append(block(self.inplanes, planes))
133 | 
134 |         return nn.Sequential(*layers)
135 | 
136 |     def forward(self, x):
137 |         x = self.conv1(x)
138 |         x = self.bn1(x)
139 |         x = self.relu(x)
140 |         x = self.maxpool(x)
141 | 
142 |         x = self.layer1(x)
143 |         x = self.layer2(x)
144 |         x = self.layer3(x)
145 |         x = self.layer4(x)
146 | 
147 |         return x
148 | 
149 | 
150 | def remove_fc(state_dict):
151 |     """Remove the fc layer parameters from state_dict."""
152 |     for key, value in list(state_dict.items()):
153 |         if key.startswith('fc.'):
154 |             del state_dict[key]
155 |     return state_dict
156 | 
157 | 
158 | def resnet18(pretrained=False, **kwargs):
159 |     """Constructs a ResNet-18 model.
160 |     Args:
161 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
162 |     """
163 |     model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
164 |     if pretrained:
165 |         model.load_state_dict(remove_fc(model_zoo.load_url(model_urls['resnet18'])))
166 |     return model
167 | 
168 | 
169 | def resnet34(pretrained=False, **kwargs):
170 |     """Constructs a ResNet-34 model.
171 |     Args:
172 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
173 |     """
174 |     model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
175 |     if pretrained:
176 |         model.load_state_dict(remove_fc(model_zoo.load_url(model_urls['resnet34'])))
177 |     return model
178 | 
179 | 
180 | def resnet50(pretrained=False, **kwargs):
181 |     """Constructs a ResNet-50 model.
182 |     Args:
183 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
184 |     """
185 |     model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
186 |     if pretrained:
187 |         model.load_state_dict(remove_fc(model_zoo.load_url(model_urls['resnet50'])))
188 |     return model
189 | 
190 | 
191 | def resnet101(pretrained=False, **kwargs):
192 |     """Constructs a ResNet-101 model.
193 |     Args:
194 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
195 |     """
196 |     model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
197 |     if pretrained:
198 |         model.load_state_dict(
199 |             remove_fc(model_zoo.load_url(model_urls['resnet101'])))
200 |     return model
201 | 
202 | 
203 | def resnet152(pretrained=False, **kwargs):
204 |     """Constructs a ResNet-152 model.
205 |     Args:
206 |         pretrained (bool): If True, returns a model pre-trained on ImageNet
207 |     """
208 |     model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
209 |     if pretrained:
210 |         model.load_state_dict(
211 |             remove_fc(model_zoo.load_url(model_urls['resnet152'])))
212 |     return model
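A quick orientation on the backbone configuration (an illustrative snippet, not repo code): `model_main.py` builds these ResNets with `last_conv_stride=1`, so the final stage keeps a 16x-downsampled feature map for part pooling, and `remove_fc` drops the ImageNet classifier weights before `load_state_dict`:

```python
import torch
from resnet import resnet50

# pretrained=False here just to avoid the weight download in this sketch.
net = resnet50(pretrained=False, last_conv_stride=1, last_conv_dilation=1)
feat = net(torch.randn(2, 3, 288, 144))  # 288x144 person crops, as in the README
print(feat.shape)                        # torch.Size([2, 2048, 18, 9]), stride 16 instead of 32
```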
-------------------------------------------------------------------------------- /test_ddag.py: --------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import argparse
3 | import sys
4 | import torch
5 | import torch.nn as nn
6 | import torch.optim as optim
7 | import torch.nn.functional as F
8 | import torch.backends.cudnn as cudnn
9 | from torch.autograd import Variable
10 | import torch.utils.data as data
11 | import torchvision
12 | import torchvision.transforms as transforms
13 | from data_loader import SYSUData, RegDBData, TestData
14 | from data_manager import *
15 | from eval_metrics import eval_sysu, eval_regdb
16 | from model_main import embed_net
17 | from utils import *
18 | 
19 | import time
20 | import scipy.io as scio
21 | 
22 | parser = argparse.ArgumentParser(description='PyTorch Cross-Modality Testing')
23 | parser.add_argument('--dataset', default='sysu', help='dataset name: regdb or sysu')
24 | parser.add_argument('--lr', default=0.01, type=float, help='learning rate')
25 | parser.add_argument('--optim', default='sgd', type=str, help='optimizer')
26 | parser.add_argument('--arch', default='resnet50', type=str, help='network baseline')
27 | parser.add_argument('--resume', '-r', default='', type=str, help='resume from checkpoint')
28 | parser.add_argument('--model_path', default='save_model/', type=str, help='model save path')
29 | parser.add_argument('--log_path', default='log/', type=str, help='log save path')
30 | parser.add_argument('--workers', default=4, type=int, metavar='N',
31 |                     help='number of data loading workers (default: 4)')
32 | parser.add_argument('--low-dim', default=512, type=int,
33 |                     metavar='D', help='feature dimension')
34 | parser.add_argument('--img_w', default=144, type=int,
35 |                     metavar='imgw', help='img width')
36 | parser.add_argument('--img_h', default=288, type=int,
37 |                     metavar='imgh', help='img height')
38 | parser.add_argument('--batch-size', default=32, type=int,
39 |                     metavar='B', help='training batch size')
40 | parser.add_argument('--part', default=3, type=int,
41 |                     metavar='tb', help='part number')
42 | parser.add_argument('--test-batch', default=64, type=int,
43 |                     metavar='tb', help='testing batch size')
44 | parser.add_argument('--method', default='id', type=str,
45 |                     metavar='m', help='Method type')
46 | parser.add_argument('--drop', default=0.0, type=float,
47 |                     metavar='drop', help='dropout ratio')
48 | parser.add_argument('--trial', default=1, type=int,
49 |                     metavar='t', help='trial')
50 | parser.add_argument('--gpu', default='0', type=str,
51 |                     help='gpu device ids for CUDA_VISIBLE_DEVICES')
52 | parser.add_argument('--mode', default='all', type=str, help='all or indoor')
53 | parser.add_argument('--graph', action='store_true', help='whether to add graph attention')
54 | parser.add_argument('--wpa', action='store_true', help='whether to add weighted part attention')
55 | args = parser.parse_args()
56 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
57 | np.random.seed(1)
58 | dataset = args.dataset
59 | if dataset == 'sysu':
60 |     # TODO: define your data path for the SYSU dataset
61 |     data_path = 'YOUR DATA PATH'
62 |     n_class = 395
63 |     test_mode = [1, 2]
64 | elif dataset =='regdb':
65 |     # 
TODO: define your data path for RegDB dataset 66 | data_path = 'YOUR DATA PATH' 67 | n_class = 206 68 | test_mode = [2, 1] 69 | 70 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 71 | best_acc = 0 # best test accuracy 72 | start_epoch = 0 73 | 74 | print('==> Building model..') 75 | net = embed_net(args.low_dim, n_class, drop=args.drop, part=args.part, arch=args.arch, wpa=args.wpa) 76 | net.to(device) 77 | cudnn.benchmark = True 78 | 79 | print('==> Resuming from checkpoint..') 80 | checkpoint_path = args.model_path 81 | if len(args.resume)>0: 82 | model_path = checkpoint_path + args.resume 83 | # model_path = checkpoint_path + 'test_best.t' 84 | if os.path.isfile(model_path): 85 | print('==> loading checkpoint {}'.format(args.resume)) 86 | checkpoint = torch.load(model_path) 87 | start_epoch = checkpoint['epoch'] 88 | # pdb.set_trace() 89 | net.load_state_dict(checkpoint['net']) 90 | print('==> loaded checkpoint {} (epoch {})' 91 | .format(args.resume, checkpoint['epoch'])) 92 | else: 93 | print('==> no checkpoint found at {}!!!!!!!!!!'.format(args.resume)) 94 | 95 | 96 | if args.method =='id': 97 | criterion = nn.CrossEntropyLoss() 98 | criterion.to(device) 99 | 100 | print('==> Loading data..') 101 | # Data loading code 102 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]) 103 | transform_train = transforms.Compose([ 104 | transforms.ToPILImage(), 105 | # transforms.Resize((280,150), interpolation=2), 106 | transforms.RandomCrop((args.img_h,args.img_w)), 107 | transforms.RandomHorizontalFlip(), 108 | transforms.ToTensor(), 109 | normalize, 110 | ]) 111 | 112 | transform_test = transforms.Compose([ 113 | transforms.ToPILImage(), 114 | transforms.Resize((args.img_h,args.img_w)), 115 | transforms.ToTensor(), 116 | normalize, 117 | ]) 118 | 119 | end = time.time() 120 | 121 | if dataset =='sysu': 122 | # testing set 123 | query_img, query_label, query_cam = process_query_sysu(data_path, mode = args.mode) 124 | gall_img, gall_label, gall_cam = process_gallery_sysu(data_path, mode = args.mode, trial = 0) 125 | 126 | nquery = len(query_label) 127 | ngall = len(gall_label) 128 | print("Dataset statistics:") 129 | print(" ------------------------------") 130 | print(" subset | # ids | # images") 131 | print(" ------------------------------") 132 | print(" query | {:5d} | {:8d}".format(len(np.unique(query_label)), nquery)) 133 | print(" gallery | {:5d} | {:8d}".format(len(np.unique(gall_label)), ngall)) 134 | print(" ------------------------------") 135 | 136 | 137 | queryset = TestData(query_img, query_label, transform = transform_test, img_size =(args.img_w, args.img_h)) 138 | query_loader = data.DataLoader(queryset, batch_size=args.test_batch, shuffle=False, num_workers=4) 139 | 140 | elif dataset =='regdb': 141 | # training set 142 | trainset = RegDBData(data_path, args.trial, transform=transform_train) 143 | # generate the idx of each person identity 144 | color_pos, thermal_pos = GenIdx(trainset.train_color_label, trainset.train_thermal_label) 145 | 146 | # testing set 147 | query_img, query_label = process_test_regdb(data_path, trial = args.trial, modal = 'visible') 148 | gall_img, gall_label = process_test_regdb(data_path, trial = args.trial, modal = 'thermal') 149 | 150 | gallset = TestData(gall_img, gall_label, transform = transform_test, img_size =(args.img_w,args.img_h)) 151 | gall_loader = data.DataLoader(gallset, batch_size=args.test_batch, shuffle=False, num_workers=args.workers) 152 | 153 | print('Data Loading Time:\t 

feature_dim = 2048
if args.arch == 'resnet50':
    pool_dim = 2048
elif args.arch == 'resnet18':
    pool_dim = 512

def extract_gall_feat(gall_loader):
    net.eval()
    print('Extracting Gallery Feature...')
    start = time.time()
    ptr = 0
    gall_feat = np.zeros((ngall, feature_dim))
    gall_feat_att = np.zeros((ngall, pool_dim))
    with torch.no_grad():
        for batch_idx, (input, label) in enumerate(gall_loader):
            batch_num = input.size(0)
            input = Variable(input.cuda())
            feat, feat_att = net(input, input, 0, test_mode[0])
            gall_feat[ptr:ptr + batch_num, :] = feat.detach().cpu().numpy()
            gall_feat_att[ptr:ptr + batch_num, :] = feat_att.detach().cpu().numpy()
            ptr = ptr + batch_num
    print('Extracting Time:\t {:.3f}'.format(time.time() - start))
    return gall_feat, gall_feat_att

def extract_query_feat(query_loader):
    net.eval()
    print('Extracting Query Feature...')
    start = time.time()
    ptr = 0
    query_feat = np.zeros((nquery, feature_dim))
    query_feat_att = np.zeros((nquery, pool_dim))
    with torch.no_grad():
        for batch_idx, (input, label) in enumerate(query_loader):
            batch_num = input.size(0)
            input = Variable(input.cuda())
            feat, feat_att = net(input, input, 0, test_mode[1])
            query_feat[ptr:ptr + batch_num, :] = feat.detach().cpu().numpy()
            query_feat_att[ptr:ptr + batch_num, :] = feat_att.detach().cpu().numpy()
            ptr = ptr + batch_num
    print('Extracting Time:\t {:.3f}'.format(time.time() - start))
    return query_feat, query_feat_att

query_feat, query_feat_att = extract_query_feat(query_loader)

all_cmc = 0
all_mAP = 0

for trial in range(10):
    gall_img, gall_label, gall_cam = process_gallery_sysu(data_path, mode=args.mode, trial=trial)

    trial_gallset = TestData(gall_img, gall_label, transform=transform_test, img_size=(args.img_w, args.img_h))
    trial_gall_loader = data.DataLoader(trial_gallset, batch_size=args.test_batch, shuffle=False, num_workers=4)

    gall_feat, gall_feat_att = extract_gall_feat(trial_gall_loader)

    # fc feature: inner product of features as similarity, negated so that
    # eval_sysu can treat it as a distance matrix
    distmat = np.matmul(query_feat, np.transpose(gall_feat))
    cmc, mAP, mINP = eval_sysu(-distmat, query_label, gall_label, query_cam, gall_cam)

    # attention feature
    distmat_att = np.matmul(query_feat_att, np.transpose(gall_feat_att))
    cmc_att, mAP_att, mINP_att = eval_sysu(-distmat_att, query_label, gall_label, query_cam, gall_cam)
    if trial == 0:
        all_cmc = cmc
        all_mAP = mAP
        all_mINP = mINP
        all_cmc_att = cmc_att
        all_mAP_att = mAP_att
        all_mINP_att = mINP_att
    else:
        all_cmc = all_cmc + cmc
        all_mAP = all_mAP + mAP
        all_mINP = all_mINP + mINP
        all_cmc_att = all_cmc_att + cmc_att
        all_mAP_att = all_mAP_att + mAP_att
        all_mINP_att = all_mINP_att + mINP_att

    print('Test Trial: {}'.format(trial))
    print('FC: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format(
        cmc[0], cmc[4], cmc[9], cmc[19], mAP, mINP))
    print('FC_att: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format(
        cmc_att[0], cmc_att[4], cmc_att[9], cmc_att[19], mAP_att, mINP_att))

cmc = all_cmc / 10
mAP = all_mAP / 10
mINP = all_mINP / 10

cmc_att = all_cmc_att / 10
mAP_att = all_mAP_att / 10
mINP_att = all_mINP_att / 10
print('All Average:')
print('FC: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format(
    cmc[0], cmc[4], cmc[9], cmc[19], mAP, mINP))
print('FC_att: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format(
    cmc_att[0], cmc_att[4], cmc_att[9], cmc_att[19], mAP_att, mINP_att))
--------------------------------------------------------------------------------
/train_ddag.py:
--------------------------------------------------------------------------------
from __future__ import print_function
import argparse
import sys
import time
import torch
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torch.autograd import Variable
import torch.utils.data as data
import torchvision
import torchvision.transforms as transforms
from data_loader import SYSUData, RegDBData, TestData
from data_manager import *
from eval_metrics import eval_sysu, eval_regdb
from model_main import embed_net
from utils import *
from loss import OriTripletLoss
from torch.optim import lr_scheduler
from tensorboardX import SummaryWriter
import torch.nn.functional as F
import math

parser = argparse.ArgumentParser(description='PyTorch Cross-Modality Training')
parser.add_argument('--dataset', default='sysu', help='dataset name: regdb or sysu')
parser.add_argument('--lr', default=0.1, type=float, help='learning rate, 0.00035 for adam')
parser.add_argument('--optim', default='sgd', type=str, help='optimizer')
parser.add_argument('--arch', default='resnet50', type=str,
                    help='network baseline: resnet50')
parser.add_argument('--resume', '-r', default='', type=str,
                    help='resume from checkpoint')
parser.add_argument('--test-only', action='store_true', help='test only')
parser.add_argument('--model_path', default='save_model/', type=str,
                    help='model save path')
parser.add_argument('--save_epoch', default=20, type=int,
                    metavar='s', help='epoch interval for saving the model')
parser.add_argument('--log_path', default='log/', type=str,
                    help='log save path')
parser.add_argument('--vis_log_path', default='log/vis_log_ddag/', type=str,
                    help='tensorboard log save path')
parser.add_argument('--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('--low-dim', default=512, type=int,
                    metavar='D', help='feature dimension')
parser.add_argument('--img_w', default=144, type=int,
                    metavar='imgw', help='img width')
parser.add_argument('--img_h', default=288, type=int,
                    metavar='imgh', help='img height')
parser.add_argument('--batch-size', default=8, type=int,
                    metavar='B', help='training batch size')
parser.add_argument('--test-batch', default=64, type=int,
                    metavar='tb', help='testing batch size')
parser.add_argument('--part', default=3, type=int,
                    metavar='tb', help='part number')
parser.add_argument('--method', default='id+tri', type=str,
                    metavar='m', help='method type')
parser.add_argument('--drop', default=0.2, type=float,
                    metavar='drop', help='dropout ratio')
parser.add_argument('--margin', default=0.3, type=float,
                    metavar='margin', help='triplet loss margin')
parser.add_argument('--num_pos', default=4, type=int,
                    help='num of pos per identity in each modality')
parser.add_argument('--trial', default=1, type=int,
                    metavar='t', help='trial (only for RegDB dataset)')
parser.add_argument('--seed', default=0, type=int,
                    metavar='t', help='random seed')
parser.add_argument('--gpu', default='0', type=str,
                    help='gpu device ids for CUDA_VISIBLE_DEVICES')
parser.add_argument('--mode', default='all', type=str, help='all or indoor')
parser.add_argument('--lambda0', default=1.0, type=float,
                    metavar='lambda0', help='graph attention weights')
parser.add_argument('--graph', action='store_true', help='whether to add graph attention')
parser.add_argument('--wpa', action='store_true', help='whether to add weighted part attention')

args = parser.parse_args()
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

set_seed(args.seed)

dataset = args.dataset
if dataset == 'sysu':
    # TODO: define your data path
    data_path = 'YOUR DATA PATH'
    log_path = args.log_path + 'sysu_log_ddag/'
    test_mode = [1, 2]  # infrared (query) to visible (gallery)
elif dataset == 'regdb':
    # TODO: define your data path for RegDB dataset
    data_path = 'YOUR DATA PATH'
    log_path = args.log_path + 'regdb_log_ddag/'
    test_mode = [2, 1]  # visible (query) to infrared (gallery)

checkpoint_path = args.model_path

if not os.path.isdir(log_path):
    os.makedirs(log_path)
if not os.path.isdir(checkpoint_path):
    os.makedirs(checkpoint_path)
if not os.path.isdir(args.vis_log_path):
    os.makedirs(args.vis_log_path)

# log file name
suffix = dataset
if args.graph:
    suffix = suffix + '_G'
if args.wpa:
    suffix = suffix + '_P_{}'.format(args.part)
suffix = suffix + '_drop_{}_{}_{}_lr_{}_seed_{}'.format(args.drop, args.num_pos, args.batch_size, args.lr, args.seed)
if not args.optim == 'sgd':
    suffix = suffix + '_' + args.optim
if dataset == 'regdb':
    suffix = suffix + '_trial_{}'.format(args.trial)
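
# For reference, a worked example of the suffix built above (not part of the
# original code): with the defaults plus `--graph --wpa --part 3` on
# SYSU-MM01, the log/checkpoint names become
#
#     sysu_G_P_3_drop_0.2_4_8_lr_0.1_seed_0
#
# i.e. dataset, graph flag, part number, then drop / num_pos / batch_size /
# lr / seed, so runs with different hyper-parameters never overwrite each
# other's logs or checkpoints.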

test_log_file = open(log_path + suffix + '.txt', "w")
sys.stdout = Logger(log_path + suffix + '_os.txt')

vis_log_dir = args.vis_log_path + suffix + '/'

if not os.path.isdir(vis_log_dir):
    os.makedirs(vis_log_dir)
writer = SummaryWriter(vis_log_dir)
print("==========\nArgs:{}\n==========".format(args))
device = 'cuda' if torch.cuda.is_available() else 'cpu'
best_acc = 0  # best test accuracy
start_epoch = 0
feature_dim = args.low_dim
wG = 0
end = time.time()

print('==> Loading data..')
# Data loading code
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
transform_train = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Pad(10),
    transforms.RandomCrop((args.img_h, args.img_w)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
])
transform_test = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((args.img_h, args.img_w)),
    transforms.ToTensor(),
    normalize,
])


if dataset == 'sysu':
    # training set
    trainset = SYSUData(data_path, transform=transform_train)
    # generate the idx of each person identity
    color_pos, thermal_pos = GenIdx(trainset.train_color_label, trainset.train_thermal_label)

    # testing set
    query_img, query_label, query_cam = process_query_sysu(data_path, mode=args.mode)
    gall_img, gall_label, gall_cam = process_gallery_sysu(data_path, mode=args.mode, trial=0)

elif dataset == 'regdb':
    # training set
    trainset = RegDBData(data_path, args.trial, transform=transform_train)
    # generate the idx of each person identity
    color_pos, thermal_pos = GenIdx(trainset.train_color_label, trainset.train_thermal_label)

    # testing set
    query_img, query_label = process_test_regdb(data_path, trial=args.trial, modal='visible')
    gall_img, gall_label = process_test_regdb(data_path, trial=args.trial, modal='thermal')

gallset = TestData(gall_img, gall_label, transform=transform_test, img_size=(args.img_w, args.img_h))
queryset = TestData(query_img, query_label, transform=transform_test, img_size=(args.img_w, args.img_h))

# testing data loader
gall_loader = data.DataLoader(gallset, batch_size=args.test_batch, shuffle=False, num_workers=args.workers)
query_loader = data.DataLoader(queryset, batch_size=args.test_batch, shuffle=False, num_workers=args.workers)

n_class = len(np.unique(trainset.train_color_label))
nquery = len(query_label)
ngall = len(gall_label)

print('Dataset {} statistics:'.format(dataset))
print('  ------------------------------')
print('  subset   | # ids | # images')
print('  ------------------------------')
print('  visible  | {:5d} | {:8d}'.format(n_class, len(trainset.train_color_label)))
print('  thermal  | {:5d} | {:8d}'.format(n_class, len(trainset.train_thermal_label)))
print('  ------------------------------')
print('  query    | {:5d} | {:8d}'.format(len(np.unique(query_label)), nquery))
print('  gallery  | {:5d} | {:8d}'.format(len(np.unique(gall_label)), ngall))
print('  ------------------------------')
print('Data Loading Time:\t {:.3f}'.format(time.time() - end))

print('==> Building model..')
net = embed_net(args.low_dim, n_class, drop=args.drop, part=args.part, arch=args.arch, wpa=args.wpa)
net.to(device)
cudnn.benchmark = True

if len(args.resume) > 0:
    model_path = checkpoint_path + args.resume
    if os.path.isfile(model_path):
        print('==> loading checkpoint {}'.format(args.resume))
        checkpoint = torch.load(model_path)
        start_epoch = checkpoint['epoch']
        net.load_state_dict(checkpoint['net'])
        print('==> loaded checkpoint {} (epoch {})'
              .format(args.resume, checkpoint['epoch']))
    else:
        print('==> no checkpoint found at {}'.format(args.resume))

# define loss function
criterion1 = nn.CrossEntropyLoss()
loader_batch = args.batch_size * args.num_pos
criterion2 = OriTripletLoss(batch_size=loader_batch, margin=args.margin)
criterion1.to(device)
criterion2.to(device)

# optimizer
if args.optim == 'sgd':
    ignored_params = list(map(id, net.bottleneck.parameters())) \
                     + list(map(id, net.classifier.parameters())) \
                     + list(map(id, net.wpa.parameters())) \
                     + list(map(id, net.attention_0.parameters())) \
                     + list(map(id, net.attention_1.parameters())) \
                     + list(map(id, net.attention_2.parameters())) \
                     + list(map(id, net.attention_3.parameters())) \
                     + list(map(id, net.out_att.parameters()))

    base_params = filter(lambda p: id(p) not in ignored_params, net.parameters())
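
    # A note on the parameter groups built above (explanatory, not from the
    # original code): base_params holds the ImageNet-pretrained ResNet
    # backbone, while the modules in ignored_params (bottleneck, classifier,
    # WPA and the graph-attention layers) are trained from scratch. The SGD
    # optimizer below therefore runs the backbone at 0.1 * args.lr and the
    # new layers at the full args.lr, the usual fine-tuning convention;
    # adjust_learning_rate() preserves this 10:1 ratio through warmup and decay.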
    optimizer_P = optim.SGD([
        {'params': base_params, 'lr': 0.1 * args.lr},
        {'params': net.bottleneck.parameters(), 'lr': args.lr},
        {'params': net.classifier.parameters(), 'lr': args.lr},
        {'params': net.wpa.parameters(), 'lr': args.lr},
        {'params': net.attention_0.parameters(), 'lr': args.lr},
        {'params': net.attention_1.parameters(), 'lr': args.lr},
        {'params': net.attention_2.parameters(), 'lr': args.lr},
        {'params': net.attention_3.parameters(), 'lr': args.lr},
        {'params': net.out_att.parameters(), 'lr': args.lr}, ],
        weight_decay=5e-4, momentum=0.9, nesterov=True)

    # optimizer_G groups only the graph-attention parameters; note that in
    # this script only optimizer_P is stepped inside train()
    optimizer_G = optim.SGD([
        {'params': net.attention_0.parameters(), 'lr': args.lr},
        {'params': net.attention_1.parameters(), 'lr': args.lr},
        {'params': net.attention_2.parameters(), 'lr': args.lr},
        {'params': net.attention_3.parameters(), 'lr': args.lr},
        {'params': net.out_att.parameters(), 'lr': args.lr}, ],
        weight_decay=5e-4, momentum=0.9, nesterov=True)


def adjust_learning_rate(optimizer_P, optimizer_G, epoch):
    """Linear warmup over the first 10 epochs, then decay the initial LR by
    10x at epoch 20 and by 100x at epoch 50 (backbone group stays at 0.1x)."""
    if epoch < 10:
        lr = args.lr * (epoch + 1) / 10
    elif 10 <= epoch < 20:
        lr = args.lr
    elif 20 <= epoch < 50:
        lr = args.lr * 0.1
    elif epoch >= 50:
        lr = args.lr * 0.01

    optimizer_P.param_groups[0]['lr'] = 0.1 * lr
    for i in range(len(optimizer_P.param_groups) - 1):
        optimizer_P.param_groups[i + 1]['lr'] = lr
    return lr


def train(epoch, wG):
    # adjust learning rate
    current_lr = adjust_learning_rate(optimizer_P, optimizer_G, epoch)
    train_loss = AverageMeter()
    id_loss = AverageMeter()
    tri_loss = AverageMeter()
    graph_loss = AverageMeter()
    data_time = AverageMeter()
    batch_time = AverageMeter()
    correct = 0
    total = 0

    # switch to train mode
    net.train()
    end = time.time()

    for batch_idx, (input1, input2, label1, label2) in enumerate(trainloader):

        labels = torch.cat((label1, label2), 0)

        # Graph construction
        # one_hot = F.one_hot(labels, num_classes=n_class)  # for torch > 1.2
        one_hot = torch.index_select(torch.eye(n_class), dim=0, index=labels)
        # Compute A in Eq. (6)
        adj = torch.mm(one_hot, torch.transpose(one_hot, 0, 1)).float() + torch.eye(labels.size()[0]).float()
        w_norm = adj.pow(2).sum(1, keepdim=True).pow(1. / 2)
        adj_norm = adj.div(w_norm)  # row-wise L2-normalized adjacency matrix
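
        # A toy illustration of the graph construction above (comments only,
        # not executed), assuming the concatenated labels are [0, 1, 0, 1]:
        #
        #   one_hot @ one_hot.T          + eye(4)  ->  adj
        #   [[1, 0, 1, 0],                  [[2, 0, 1, 0],
        #    [0, 1, 0, 1],        ==>        [0, 2, 0, 1],
        #    [1, 0, 1, 0],                   [1, 0, 2, 0],
        #    [0, 1, 0, 1]]                   [0, 1, 0, 2]]
        #
        # After dividing each row by its L2 norm (sqrt(5) here), the first row
        # of adj_norm is roughly [0.894, 0.0, 0.447, 0.0]: every sample is
        # connected to itself and to the other samples of the same identity.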

        input1 = Variable(input1.cuda())
        input2 = Variable(input2.cuda())

        labels = Variable(labels.cuda())
        adj_norm = Variable(adj_norm.cuda())
        data_time.update(time.time() - end)

        # Forward into the network
        feat, out0, out_att, output = net(input1, input2, adj_norm)

        # baseline loss: identity loss + triplet loss Eq. (1)
        loss_id = criterion1(out0, labels)
        loss_tri, batch_acc = criterion2(feat, labels)
        correct += (batch_acc / 2)
        _, predicted = out0.max(1)
        correct += (predicted.eq(labels).sum().item() / 2)

        # Part attention loss
        loss_p = criterion1(out_att, labels)

        # Graph attention loss Eq. (9)
        loss_G = F.nll_loss(output, labels)

        # Instance-level part-aggregated feature learning Eq. (10)
        loss = loss_id + loss_tri + loss_p
        # Overall loss Eq. (11)
        loss_total = loss + wG * loss_G

        # optimization
        optimizer_P.zero_grad()
        loss_total.backward()
        optimizer_P.step()

        # log different loss components
        train_loss.update(loss.item(), 2 * input1.size(0))
        id_loss.update(loss_id.item(), 2 * input1.size(0))
        tri_loss.update(loss_tri.item(), 2 * input1.size(0))
        graph_loss.update(loss_G.item(), 2 * input1.size(0))
        total += labels.size(0)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if batch_idx % 10 == 0:
            print('Epoch: [{}][{}/{}] '
                  'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                  'lr:{} '
                  'Loss: {train_loss.val:.4f} ({train_loss.avg:.4f}) '
                  'iLoss: {id_loss.val:.4f} ({id_loss.avg:.4f}) '
                  'TLoss: {tri_loss.val:.4f} ({tri_loss.avg:.4f}) '
                  'GLoss: {graph_loss.val:.4f} ({graph_loss.avg:.4f}) '
                  'Accu: {:.2f}'.format(
                      epoch, batch_idx, len(trainloader), current_lr,
                      100. * correct / total, batch_time=batch_time,
                      train_loss=train_loss, id_loss=id_loss, tri_loss=tri_loss, graph_loss=graph_loss))

    writer.add_scalar('total_loss', train_loss.avg, epoch)
    writer.add_scalar('id_loss', id_loss.avg, epoch)
    writer.add_scalar('tri_loss', tri_loss.avg, epoch)
    writer.add_scalar('graph_loss', graph_loss.avg, epoch)
    writer.add_scalar('lr', current_lr, epoch)
    # compute wG, the dynamic weight of the graph loss for the next epoch
    return 1. / (1. + train_loss.avg)
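
# A worked example of the dynamic weight returned above (illustrative numbers,
# not from the original code): train() returns wG = 1 / (1 + avg_loss), which
# the main loop feeds back as the weight on loss_G for the *next* epoch
# (starting from wG = 0). Early on, when the average baseline loss is large
# (say 9.0), wG = 0.1 and the graph loss barely contributes; once training
# converges (say an average loss of 0.25), wG = 0.8 and the graph-attention
# branch is weighted much more strongly.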

def test(epoch):
    # switch to evaluation mode
    net.eval()
    print('Extracting Gallery Feature...')
    start = time.time()
    ptr = 0
    gall_feat = np.zeros((ngall, 2048))
    gall_feat_att = np.zeros((ngall, 2048))
    with torch.no_grad():
        for batch_idx, (input, label) in enumerate(gall_loader):
            batch_num = input.size(0)
            input = Variable(input.cuda())
            feat, feat_att = net(input, input, 0, test_mode[0])
            gall_feat[ptr:ptr + batch_num, :] = feat.detach().cpu().numpy()
            gall_feat_att[ptr:ptr + batch_num, :] = feat_att.detach().cpu().numpy()
            ptr = ptr + batch_num
    print('Extracting Time:\t {:.3f}'.format(time.time() - start))

    # switch to evaluation
    net.eval()
    print('Extracting Query Feature...')
    start = time.time()
    ptr = 0
    query_feat = np.zeros((nquery, 2048))
    query_feat_att = np.zeros((nquery, 2048))
    with torch.no_grad():
        for batch_idx, (input, label) in enumerate(query_loader):
            batch_num = input.size(0)
            input = Variable(input.cuda())
            feat, feat_att = net(input, input, 0, test_mode[1])
            query_feat[ptr:ptr + batch_num, :] = feat.detach().cpu().numpy()
            query_feat_att[ptr:ptr + batch_num, :] = feat_att.detach().cpu().numpy()
            ptr = ptr + batch_num
    print('Extracting Time:\t {:.3f}'.format(time.time() - start))

    start = time.time()
    # compute the similarity (negated below so eval treats it as a distance)
    distmat = np.matmul(query_feat, np.transpose(gall_feat))
    distmat_att = np.matmul(query_feat_att, np.transpose(gall_feat_att))

    # evaluation
    if dataset == 'regdb':
        cmc, mAP, mINP = eval_regdb(-distmat, query_label, gall_label)
        cmc_att, mAP_att, mINP_att = eval_regdb(-distmat_att, query_label, gall_label)
    elif dataset == 'sysu':
        cmc, mAP, mINP = eval_sysu(-distmat, query_label, gall_label, query_cam, gall_cam)
        cmc_att, mAP_att, mINP_att = eval_sysu(-distmat_att, query_label, gall_label, query_cam, gall_cam)
    print('Evaluation Time:\t {:.3f}'.format(time.time() - start))

    writer.add_scalar('rank1', cmc[0], epoch)
    writer.add_scalar('mAP', mAP, epoch)
    writer.add_scalar('mINP', mINP, epoch)
    writer.add_scalar('rank1_att', cmc_att[0], epoch)
    writer.add_scalar('mAP_att', mAP_att, epoch)
    writer.add_scalar('mINP_att', mINP_att, epoch)
    return cmc, mAP, mINP, cmc_att, mAP_att, mINP_att


# training
print('==> Start Training...')
for epoch in range(start_epoch, 81):

    print('==> Preparing Data Loader...')
    # identity sampler
    sampler = IdentitySampler(trainset.train_color_label,
                              trainset.train_thermal_label, color_pos, thermal_pos, args.num_pos, args.batch_size,
                              epoch)

    trainset.cIndex = sampler.index1  # color index
    trainset.tIndex = sampler.index2  # infrared index
    print(epoch)
    print(trainset.cIndex)
    print(trainset.tIndex)

    loader_batch = args.batch_size * args.num_pos

    trainloader = data.DataLoader(trainset, batch_size=loader_batch,
                                  sampler=sampler, num_workers=args.workers, drop_last=True)

    # training
    wG = train(epoch, wG)

    if epoch > 0 and epoch % 2 == 0:
        print('Test Epoch: {}'.format(epoch))
        print('Test Epoch: {}'.format(epoch), file=test_log_file)

        # testing
        cmc, mAP, mINP, cmc_att, mAP_att, mINP_att = test(epoch)
        # log output
        print('FC: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format(
            cmc[0], cmc[4], cmc[9], cmc[19], mAP, mINP))
        print('FC: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format(
            cmc[0], cmc[4], cmc[9], cmc[19], mAP, mINP), file=test_log_file)

        print('FC_att: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format(
            cmc_att[0], cmc_att[4], cmc_att[9], cmc_att[19], mAP_att, mINP_att))
        print('FC_att: Rank-1: {:.2%} | Rank-5: {:.2%} | Rank-10: {:.2%}| Rank-20: {:.2%}| mAP: {:.2%}| mINP: {:.2%}'.format(
            cmc_att[0], cmc_att[4], cmc_att[9], cmc_att[19], mAP_att, mINP_att), file=test_log_file)
        test_log_file.flush()

        # save model
        if cmc_att[0] > best_acc:  # not the real best for sysu-mm01
            best_acc = cmc_att[0]
            state = {
                'net': net.state_dict(),
                'cmc': cmc_att,
                'mAP': mAP_att,
                'epoch': epoch,
            }
            torch.save(state, checkpoint_path + suffix + '_best.t')
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import os
import errno
from collections import defaultdict
import numbers
import numpy as np
from torch.utils.data.sampler import Sampler
import sys
import os.path as osp
import scipy.io as scio
import torch

def load_data(input_data_path):
    with open(input_data_path, 'rt') as f:
        data_file_list = f.read().splitlines()
    # Get full list of image paths and labels
    file_image = [s.split(' ')[0] for s in data_file_list]
    file_label = [int(s.split(' ')[1]) for s in data_file_list]

    return file_image, file_label
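
# For reference, load_data() expects a plain-text index file with one
# "<image path> <integer label>" pair per line, e.g. (hypothetical paths,
# not from the original repository):
#
#     cam1/0001/0001.jpg 0
#     cam1/0001/0002.jpg 0
#     cam2/0002/0001.jpg 1
#
# The token after the first space on each line is parsed as the identity label.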

def GenIdx(train_color_label, train_thermal_label):
    color_pos = []
    unique_label_color = np.unique(train_color_label)
    for i in range(len(unique_label_color)):
        tmp_pos = [k for k, v in enumerate(train_color_label) if v == unique_label_color[i]]
        color_pos.append(tmp_pos)

    thermal_pos = []
    unique_label_thermal = np.unique(train_thermal_label)
    for i in range(len(unique_label_thermal)):
        tmp_pos = [k for k, v in enumerate(train_thermal_label) if v == unique_label_thermal[i]]
        thermal_pos.append(tmp_pos)
    return color_pos, thermal_pos

def GenCamIdx(gall_img, gall_label, mode):
    if mode == 'indoor':
        camIdx = [1, 2]
    else:
        camIdx = [1, 2, 4, 5]
    gall_cam = []
    for i in range(len(gall_img)):
        gall_cam.append(int(gall_img[i][-10]))

    sample_pos = []
    unique_label = np.unique(gall_label)
    for i in range(len(unique_label)):
        for j in range(len(camIdx)):
            id_pos = [k for k, v in enumerate(gall_label) if v == unique_label[i] and gall_cam[k] == camIdx[j]]
            if id_pos:
                sample_pos.append(id_pos)
    return sample_pos

def ExtractCam(gall_img):
    # the camera id is encoded at a fixed position in the image file path
    gall_cam = []
    for i in range(len(gall_img)):
        cam_id = int(gall_img[i][-10])
        gall_cam.append(cam_id)

    return np.array(gall_cam)

class IdentitySampler(Sampler):
    """Sample person identities evenly in each batch.
    Args:
        train_color_label, train_thermal_label: labels of the two modalities
        color_pos, thermal_pos: positions of each identity
        num_pos: number of positive samples per identity
        batchSize: batch size (number of identities per batch)
    """

    def __init__(self, train_color_label, train_thermal_label, color_pos, thermal_pos, num_pos, batchSize, epoch):
        uni_label = np.unique(train_color_label)
        self.n_classes = len(uni_label)

        N = np.maximum(len(train_color_label), len(train_thermal_label))
        for j in range(int(N / (batchSize * num_pos)) + 1):
            batch_idx = np.random.choice(uni_label, batchSize, replace=False)
            for i in range(batchSize):
                sample_color = np.random.choice(color_pos[batch_idx[i]], num_pos)
                sample_thermal = np.random.choice(thermal_pos[batch_idx[i]], num_pos)

                if j == 0 and i == 0:
                    index1 = sample_color
                    index2 = sample_thermal
                else:
                    index1 = np.hstack((index1, sample_color))
                    index2 = np.hstack((index2, sample_thermal))

        self.index1 = index1
        self.index2 = index2
        self.N = N

    def __iter__(self):
        return iter(np.arange(len(self.index1)))

    def __len__(self):
        return self.N
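
# A sketch of what the sampler produces (illustrative, not from the original
# code): with batchSize=8 identities and num_pos=4 images per identity,
# index1/index2 are parallel index lists into the visible/infrared sets, laid
# out in blocks of 8 * 4 = 32 so that every mini-batch drawn sequentially
# contains the same 8 identities in both modalities, e.g.
#
#     sampler = IdentitySampler(train_color_label, train_thermal_label,
#                               color_pos, thermal_pos, num_pos=4,
#                               batchSize=8, epoch=0)
#     trainset.cIndex = sampler.index1   # visible image for each slot
#     trainset.tIndex = sampler.index2   # infrared image for the same slot
#
# which is why train_ddag.py builds its DataLoader with
# batch_size = args.batch_size * args.num_pos and this sampler.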
128 | """ 129 | def __init__(self, fpath=None): 130 | self.console = sys.stdout 131 | self.file = None 132 | if fpath is not None: 133 | mkdir_if_missing(osp.dirname(fpath)) 134 | self.file = open(fpath, 'w') 135 | 136 | def __del__(self): 137 | self.close() 138 | 139 | def __enter__(self): 140 | pass 141 | 142 | def __exit__(self, *args): 143 | self.close() 144 | 145 | def write(self, msg): 146 | self.console.write(msg) 147 | if self.file is not None: 148 | self.file.write(msg) 149 | 150 | def flush(self): 151 | self.console.flush() 152 | if self.file is not None: 153 | self.file.flush() 154 | os.fsync(self.file.fileno()) 155 | 156 | def close(self): 157 | self.console.close() 158 | if self.file is not None: 159 | self.file.close() 160 | 161 | def set_seed(seed, cuda=True): 162 | np.random.seed(seed) 163 | torch.manual_seed(seed) 164 | if cuda: 165 | torch.cuda.manual_seed(seed) 166 | 167 | def set_requires_grad(nets, requires_grad=False): 168 | """Set requies_grad=Fasle for all the networks to avoid unnecessary computations 169 | Parameters: 170 | nets (network list) -- a list of networks 171 | requires_grad (bool) -- whether the networks require gradients or not 172 | """ 173 | if not isinstance(nets, list): 174 | nets = [nets] 175 | for net in nets: 176 | if net is not None: 177 | for param in net.parameters(): 178 | param.requires_grad = requires_grad --------------------------------------------------------------------------------