├── Loggers.py
├── Losses.py
├── README.md
├── Utils.py
├── clustering.py
├── data
│   └── .gitignore
├── dr_ae_HardNet.py
├── dr_ae_SIFT.py
├── dr_pca.py
├── dr_ss_HardNet.py
├── dr_ss_SIFT.py
├── dr_sv_HardNet.py
├── dr_sv_SIFT.py
├── example.py
├── hpatches_extract_SIFT_64.py
├── models
│   ├── .gitignore
│   ├── HardNet_sv_dim16.pth
│   ├── HardNet_sv_dim24.pth
│   ├── HardNet_sv_dim32.pth
│   ├── HardNet_sv_dim64.pth
│   ├── SIFT_sv_dim16.pth
│   ├── SIFT_sv_dim24.pth
│   ├── SIFT_sv_dim32.pth
│   └── SIFT_sv_dim64.pth
├── pics
│   ├── HardNet-Hpatches.PNG
│   ├── HardNet-Ours-SV-64.svg
│   ├── HardNet-PCA-64.svg
│   ├── HardNet.svg
│   ├── MKD-Hpatches.PNG
│   ├── MKD-Ours-SV-64.svg
│   ├── MKD-PCA-64.svg
│   ├── MKD.svg
│   ├── SIFT-Hpatches.png
│   ├── SIFT-Ours-SV-64.svg
│   ├── SIFT-PCA-64.svg
│   ├── SIFT.svg
│   ├── TFeat-Hpatches.PNG
│   ├── TFeat-Ours-SV-64.svg
│   ├── TFeat-PCA-64.svg
│   ├── TFeat.svg
│   ├── localization.PNG
│   └── overview.png
├── raw_descriptors
│   └── .gitignore
└── util.py

/Loggers.py:
--------------------------------------------------------------------------------
1 | from tensorboard_logger import configure, log_value
2 | import os
3 | 
4 | class FileLogger:
5 |     """Log text in a file."""
6 |     def __init__(self, path):
7 |         self.path = path
8 | 
9 |     def log_string(self, file_name, string):
10 |         """Appends a log string to the specified file."""
11 |         text_file = open(self.path + file_name + ".log", "a")
12 |         text_file.write(str(string) + '\n')
13 |         text_file.close()
14 | 
15 |     def log_stats(self, file_name, text_to_save, value):
16 |         """Appends a named value to the specified log file."""
17 |         text_file = open(self.path + file_name + ".log", "a")
18 |         text_file.write(text_to_save + ' ' + str(value) + '\n')
19 |         text_file.close()
20 | 
21 | 
22 | class Logger(object):
23 |     """Tensorboard logger."""
24 |     def __init__(self, log_dir):
25 |         # clean previously logged data under the same directory name
26 |         self._remove(log_dir)
27 | 
28 |         # configure the project
29 |         configure(log_dir)
30 | 
31 |         self.global_step = 0
32 | 
33 |     def log_value(self, name, value):
34 |         log_value(name, value, self.global_step)
35 |         return self
36 | 
37 |     def step(self):
38 |         self.global_step += 1
39 | 
40 |     @staticmethod
41 |     def _remove(path):
42 |         """Remove a file or directory; the path may be relative or absolute."""
43 |         if os.path.isfile(path):
44 |             os.remove(path)  # remove the file
45 |         elif os.path.isdir(path):
46 |             import shutil
47 |             shutil.rmtree(path)  # remove dir and all it contains
48 | 
--------------------------------------------------------------------------------
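A minimal usage sketch for the two loggers above; the log directory and tag names here are placeholders for illustration, not code from this repository:

```python
# Hypothetical usage of FileLogger / Logger from Loggers.py.
from Loggers import FileLogger, Logger

flog = FileLogger('data/logs/')        # appends to data/logs/<file_name>.log
flog.log_stats('train', 'epoch 0 loss', 0.42)

tblog = Logger('data/logs/run0')       # note: wipes any existing 'run0' dir
for step in range(3):
    tblog.log_value('loss', 1.0 / (step + 1))
    tblog.step()                       # advances the global step counter
```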
""" 43 | if os.path.isfile(path): 44 | os.remove(path) # remove the file 45 | elif os.path.isdir(path): 46 | import shutil 47 | shutil.rmtree(path) # remove dir and all contains 48 | -------------------------------------------------------------------------------- /Losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import sys 4 | 5 | def distance_matrix_vector(anchor, positive): 6 | """Given batch of anchor descriptors and positive descriptors calculate distance matrix""" 7 | 8 | d1_sq = torch.sum(anchor * anchor, dim=1).unsqueeze(-1) 9 | d2_sq = torch.sum(positive * positive, dim=1).unsqueeze(-1) 10 | 11 | eps = 1e-6 12 | return torch.sqrt((d1_sq.repeat(1, positive.size(0)) + torch.t(d2_sq.repeat(1, anchor.size(0))) 13 | - 2.0 * torch.bmm(anchor.unsqueeze(0), torch.t(positive).unsqueeze(0)).squeeze(0))+eps) 14 | 15 | def inner_dot_matrix(anchor, postive): 16 | inner = torch.mm(anchor, torch.t(postive)) 17 | mask = torch.eye(inner.size(1)).cuda() 18 | inner = inner - 1e-6*mask 19 | dist_m = torch.sqrt( 2.0*(1.0-inner) + 1e-8) 20 | return dist_m 21 | 22 | def distance_vectors_pairwise(anchor, positive, negative = None): 23 | """Given batch of anchor descriptors and positive descriptors calculate distance matrix""" 24 | 25 | a_sq = torch.sum(anchor * anchor, dim=1) 26 | p_sq = torch.sum(positive * positive, dim=1) 27 | 28 | eps = 1e-8 29 | d_a_p = torch.sqrt(a_sq + p_sq - 2*torch.sum(anchor * positive, dim = 1) + eps) 30 | if negative is not None: 31 | n_sq = torch.sum(negative * negative, dim=1) 32 | d_a_n = torch.sqrt(a_sq + n_sq - 2*torch.sum(anchor * negative, dim = 1) + eps) 33 | d_p_n = torch.sqrt(p_sq + n_sq - 2*torch.sum(positive * negative, dim = 1) + eps) 34 | return d_a_p, d_a_n, d_p_n 35 | return d_a_p 36 | def loss_random_sampling(anchor, positive, negative, anchor_swap = False, margin = 1.0, loss_type = "triplet_margin"): 37 | """Loss with random sampling (no hard in batch). 38 | """ 39 | 40 | assert anchor.size() == positive.size(), "Input sizes between positive and negative must be equal." 41 | assert anchor.size() == negative.size(), "Input sizes between positive and negative must be equal." 42 | assert anchor.dim() == 2, "Inputd must be a 2D matrix." 43 | eps = 1e-8 44 | (pos, d_a_n, d_p_n) = distance_vectors_pairwise(anchor, positive, negative) 45 | if anchor_swap: 46 | min_neg = torch.min(d_a_n, d_p_n) 47 | else: 48 | min_neg = d_a_n 49 | 50 | if loss_type == "triplet_margin": 51 | loss = torch.clamp(margin + pos - min_neg, min=0.0) 52 | elif loss_type == 'softmax': 53 | exp_pos = torch.exp(2.0 - pos); 54 | exp_den = exp_pos + torch.exp(2.0 - min_neg) + eps; 55 | loss = - torch.log( exp_pos / exp_den ) 56 | elif loss_type == 'contrastive': 57 | loss = torch.clamp(margin - min_neg, min=0.0) + pos; 58 | else: 59 | print ('Unknown loss type. Try triplet_margin, softmax or contrastive') 60 | sys.exit(1) 61 | loss = torch.mean(loss) 62 | return loss 63 | 64 | def loss_L2Net(anchor, positive, anchor_swap = False, margin = 1.0, loss_type = "triplet_margin"): 65 | """L2Net losses: using whole batch as negatives, not only hardest. 66 | """ 67 | 68 | assert anchor.size() == positive.size(), "Input sizes between positive and negative must be equal." 69 | assert anchor.dim() == 2, "Inputd must be a 2D matrix." 
70 | eps = 1e-8 71 | dist_matrix = distance_matrix_vector(anchor, positive) 72 | eye = torch.autograd.Variable(torch.eye(dist_matrix.size(1))).cuda() 73 | 74 | # steps to filter out same patches that occur in distance matrix as negatives 75 | pos1 = torch.diag(dist_matrix) 76 | dist_without_min_on_diag = dist_matrix+eye*10 77 | mask = (dist_without_min_on_diag.ge(0.008).float()-1.0)*(-1) 78 | mask = mask.type_as(dist_without_min_on_diag)*10 79 | dist_without_min_on_diag = dist_without_min_on_diag+mask 80 | 81 | if loss_type == 'softmax': 82 | exp_pos = torch.exp(2.0 - pos1) 83 | exp_den = torch.sum(torch.exp(2.0 - dist_matrix),1) + eps 84 | loss = -torch.log( exp_pos / exp_den ) 85 | if anchor_swap: 86 | exp_den1 = torch.sum(torch.exp(2.0 - dist_matrix),0) + eps 87 | loss += -torch.log( exp_pos / exp_den1 ) 88 | else: 89 | print ('Only softmax loss works with L2Net sampling') 90 | sys.exit(1) 91 | loss = torch.mean(loss) 92 | return loss 93 | 94 | def loss_HardNet(anchor, positive, anchor_swap = False, anchor_ave = False,\ 95 | margin = 1.0, batch_reduce = 'min', loss_type = "triplet_margin"): 96 | """HardNet margin loss - calculates loss based on distance matrix based on positive distance and closest negative distance. 97 | """ 98 | 99 | assert anchor.size() == positive.size(), "Input sizes between positive and negative must be equal." 100 | assert anchor.dim() == 2, "Inputd must be a 2D matrix." 101 | eps = 1e-8 102 | dist_matrix = distance_matrix_vector(anchor, positive) +eps 103 | eye = torch.autograd.Variable(torch.eye(dist_matrix.size(1))).cuda() 104 | 105 | # steps to filter out same patches that occur in distance matrix as negatives 106 | pos1 = torch.diag(dist_matrix) 107 | dist_without_min_on_diag = dist_matrix+eye*10 108 | mask = (dist_without_min_on_diag.ge(0.008).float()-1.0)*(-1) 109 | mask = mask.type_as(dist_without_min_on_diag)*10 110 | dist_without_min_on_diag = dist_without_min_on_diag+mask 111 | if batch_reduce == 'min': 112 | min_neg = torch.min(dist_without_min_on_diag,1)[0] 113 | if anchor_swap: 114 | min_neg2 = torch.min(dist_without_min_on_diag,0)[0] 115 | min_neg = torch.min(min_neg,min_neg2) 116 | min_neg = min_neg 117 | pos = pos1 118 | elif batch_reduce == 'average': 119 | pos = pos1.repeat(anchor.size(0)).view(-1,1).squeeze(0) 120 | min_neg = dist_without_min_on_diag.view(-1,1) 121 | if anchor_swap: 122 | min_neg2 = torch.t(dist_without_min_on_diag).contiguous().view(-1,1) 123 | min_neg = torch.min(min_neg,min_neg2) 124 | min_neg = min_neg.squeeze(0) 125 | elif batch_reduce == 'random': 126 | idxs = torch.autograd.Variable(torch.randperm(anchor.size()[0]).long()).cuda() 127 | min_neg = dist_without_min_on_diag.gather(1,idxs.view(-1,1)) 128 | if anchor_swap: 129 | min_neg2 = torch.t(dist_without_min_on_diag).gather(1,idxs.view(-1,1)) 130 | min_neg = torch.min(min_neg,min_neg2) 131 | min_neg = torch.t(min_neg).squeeze(0) 132 | pos = pos1 133 | else: 134 | print ('Unknown batch reduce mode. 
Try min, average or random') 135 | sys.exit(1) 136 | if loss_type == "triplet_margin": 137 | loss = torch.clamp(margin + pos - min_neg, min=0.0) 138 | #loss = nn.ReLU()(margin + pos - min_neg) 139 | elif loss_type == "triplet_margin_QHT": 140 | loss = torch.square(torch.clamp(margin + pos - min_neg, min=0.0)) 141 | elif loss_type == 'softmax': 142 | exp_pos = torch.exp(2.0 - pos) 143 | exp_den = exp_pos + torch.exp(2.0 - min_neg) + eps 144 | loss = - torch.log( exp_pos / exp_den ) 145 | elif loss_type == 'contrastive': 146 | loss = torch.clamp(margin - min_neg, min=0.0) + pos 147 | else: 148 | print ('Unknown loss type. Try triplet_margin, softmax or contrastive') 149 | sys.exit(1) 150 | loss = torch.mean(loss) 151 | return loss 152 | 153 | def loss_HardNet_metric(anchor, positive,out_a_raw,out_p_raw, anchor_swap = False, anchor_ave = False,\ 154 | margin = 1.0, batch_reduce = 'min', loss_type = "triplet_margin",alpha=0.0): 155 | """HardNet margin loss - calculates loss based on distance matrix based on positive distance and closest negative distance. 156 | """ 157 | 158 | assert anchor.size() == positive.size(), "Input sizes between positive and negative must be equal." 159 | assert anchor.dim() == 2, "Inputd must be a 2D matrix." 160 | eps = 1e-8 161 | dist_matrix = distance_matrix_vector(anchor, positive) +eps 162 | eye = torch.autograd.Variable(torch.eye(dist_matrix.size(1))).cuda() 163 | 164 | # steps to filter out same patches that occur in distance matrix as negatives 165 | pos1 = torch.diag(dist_matrix) 166 | dist_without_min_on_diag = dist_matrix+eye*10 167 | mask = (dist_without_min_on_diag.ge(0.008).float()-1.0)*(-1) 168 | mask = mask.type_as(dist_without_min_on_diag)*10 169 | dist_without_min_on_diag = dist_without_min_on_diag+mask 170 | if batch_reduce == 'min': 171 | min_neg = torch.min(dist_without_min_on_diag,1)[0] 172 | if anchor_swap: 173 | min_neg2 = torch.min(dist_without_min_on_diag,0)[0] 174 | min_neg = torch.min(min_neg,min_neg2) 175 | min_neg = min_neg 176 | pos = pos1 177 | elif batch_reduce == 'average': 178 | pos = pos1.repeat(anchor.size(0)).view(-1,1).squeeze(0) 179 | min_neg = dist_without_min_on_diag.view(-1,1) 180 | if anchor_swap: 181 | min_neg2 = torch.t(dist_without_min_on_diag).contiguous().view(-1,1) 182 | min_neg = torch.min(min_neg,min_neg2) 183 | min_neg = min_neg.squeeze(0) 184 | elif batch_reduce == 'random': 185 | idxs = torch.autograd.Variable(torch.randperm(anchor.size()[0]).long()).cuda() 186 | min_neg = dist_without_min_on_diag.gather(1,idxs.view(-1,1)) 187 | if anchor_swap: 188 | min_neg2 = torch.t(dist_without_min_on_diag).gather(1,idxs.view(-1,1)) 189 | min_neg = torch.min(min_neg,min_neg2) 190 | min_neg = torch.t(min_neg).squeeze(0) 191 | pos = pos1 192 | else: 193 | print ('Unknown batch reduce mode. Try min, average or random') 194 | sys.exit(1) 195 | if loss_type == "triplet_margin": 196 | loss = torch.clamp(margin + pos - min_neg, min=0.0) 197 | #loss = nn.ReLU()(margin + pos - min_neg) 198 | elif loss_type == "triplet_margin_QHT": 199 | loss = torch.square(torch.clamp(margin + pos - min_neg, min=0.0)) 200 | elif loss_type == 'softmax': 201 | exp_pos = torch.exp(2.0 - pos) 202 | exp_den = exp_pos + torch.exp(2.0 - min_neg) + eps 203 | loss = - torch.log( exp_pos / exp_den ) 204 | elif loss_type == 'contrastive': 205 | loss = torch.clamp(margin - min_neg, min=0.0) + pos 206 | else: 207 | print ('Unknown loss type. 
Try triplet_margin, softmax or contrastive') 208 | sys.exit(1) 209 | loss = torch.mean(loss) 210 | 211 | e_loss = torch.tensor(0.).float().cuda() 212 | 213 | sqdist_matrix_anchor_embeddings = 2 - 2 * anchor @ anchor.T 214 | sqdist_matrix_anchor = 2 - 2 * out_a_raw @ out_a_raw.T 215 | 216 | sqdist_matrix_positive_embeddings = 2 - 2 * positive @ positive.T 217 | sqdist_matrix_positive = 2 - 2 * out_p_raw @ out_p_raw.T 218 | 219 | sqdist_matrix_anchor_positive_embeddings = 2 - 2 * anchor @ positive.T 220 | sqdist_matrix_anchor_positive = 2 - 2 * out_a_raw @ out_p_raw.T 221 | 222 | e_loss += torch.mean( 223 | torch.abs(sqdist_matrix_anchor - sqdist_matrix_anchor_embeddings) 224 | ) 225 | 226 | e_loss += torch.mean( 227 | torch.abs(sqdist_matrix_positive - sqdist_matrix_positive_embeddings) 228 | ) 229 | 230 | e_loss += torch.mean( 231 | torch.abs(sqdist_matrix_anchor_positive - sqdist_matrix_anchor_positive_embeddings) 232 | ) 233 | 234 | if alpha > 0: 235 | loss_sum = loss + alpha * e_loss 236 | elif alpha < 0: 237 | loss_sum = e_loss 238 | else: 239 | loss_sum = loss 240 | 241 | return loss_sum 242 | 243 | def global_orthogonal_regularization(anchor, negative): 244 | 245 | neg_dis = torch.sum(torch.mul(anchor,negative),1) 246 | dim = anchor.size(1) 247 | gor = torch.pow(torch.mean(neg_dis),2) + torch.clamp(torch.mean(torch.pow(neg_dis,2))-1.0/dim, min=0.0) 248 | 249 | return gor 250 | 251 | class Loss_HyNet(): 252 | 253 | def __init__(self, device, dim_desc, margin, alpha, is_sosr, knn_sos=8): 254 | self.device = device 255 | self.margin = margin 256 | self.alpha = alpha 257 | self.is_sosr = is_sosr 258 | self.dim_desc = dim_desc 259 | self.knn_sos = knn_sos 260 | self.index_dim = torch.LongTensor(range(0, dim_desc)) 261 | 262 | def sort_distance(self): 263 | L = self.L.clone().detach() 264 | L = L + 2 * self.mask_pos_pair 265 | L = L + 2 * L.le(dist_th).float() 266 | 267 | R = self.R.clone().detach() 268 | R = R + 2 * self.mask_pos_pair 269 | R = R + 2 * R.le(dist_th).float() 270 | 271 | LR = self.LR.clone().detach() 272 | LR = LR + 2 * self.mask_pos_pair 273 | LR = LR + 2 * LR.le(dist_th).float() 274 | 275 | self.indice_L = torch.argsort(L, dim=1) 276 | self.indice_R = torch.argsort(R, dim=0) 277 | self.indice_LR = torch.argsort(LR, dim=1) 278 | self.indice_RL = torch.argsort(LR, dim=0) 279 | return 280 | 281 | def triplet_loss_hybrid(self): 282 | L = self.L 283 | R = self.R 284 | LR = self.LR 285 | indice_L = self.indice_L[:, 0] 286 | indice_R = self.indice_R[0, :] 287 | indice_LR = self.indice_LR[:, 0] 288 | indice_RL = self.indice_RL[0, :] 289 | index_desc = self.index_desc 290 | 291 | dist_pos = LR[self.mask_pos_pair.bool()] 292 | dist_neg_LL = L[index_desc, indice_L] 293 | dist_neg_RR = R[indice_R, index_desc] 294 | dist_neg_LR = LR[index_desc, indice_LR] 295 | dist_neg_RL = LR[indice_RL, index_desc] 296 | dist_neg = torch.cat((dist_neg_LL.unsqueeze(0), 297 | dist_neg_RR.unsqueeze(0), 298 | dist_neg_LR.unsqueeze(0), 299 | dist_neg_RL.unsqueeze(0)), dim=0) 300 | dist_neg_hard, index_neg_hard = torch.sort(dist_neg, dim=0) 301 | dist_neg_hard = dist_neg_hard[0, :] 302 | # scipy.io.savemat('dist.mat', dict(dist_pos=dist_pos.cpu().detach().numpy(), dist_neg=dist_neg_hard.cpu().detach().numpy())) 303 | 304 | loss_triplet = torch.clamp(self.margin + (dist_pos + dist_pos.pow(2)/2*self.alpha) - (dist_neg_hard + dist_neg_hard.pow(2)/2*self.alpha), min=0.0) 305 | 306 | self.num_triplet_display = loss_triplet.gt(0).sum() 307 | 308 | self.loss = self.loss + loss_triplet.sum() 309 | 
self.dist_pos_display = dist_pos.detach().mean() 310 | self.dist_neg_display = dist_neg_hard.detach().mean() 311 | 312 | return 313 | 314 | def norm_loss_pos(self): 315 | diff_norm = self.norm_L - self.norm_R 316 | self.loss += diff_norm.pow(2).sum().mul(0.1) 317 | 318 | def sos_loss(self): 319 | L = self.L 320 | R = self.R 321 | knn = self.knn_sos 322 | indice_L = self.indice_L[:, 0:knn] 323 | indice_R = self.indice_R[0:knn, :] 324 | indice_LR = self.indice_LR[:, 0:knn] 325 | indice_RL = self.indice_RL[0:knn, :] 326 | index_desc = self.index_desc 327 | num_pt_per_batch = self.num_pt_per_batch 328 | index_row = index_desc.unsqueeze(1).expand(-1, knn) 329 | index_col = index_desc.unsqueeze(0).expand(knn, -1) 330 | 331 | A_L = torch.zeros(num_pt_per_batch, num_pt_per_batch).to(self.device) 332 | A_R = torch.zeros(num_pt_per_batch, num_pt_per_batch).to(self.device) 333 | A_LR = torch.zeros(num_pt_per_batch, num_pt_per_batch).to(self.device) 334 | 335 | A_L[index_row, indice_L] = 1 336 | A_R[indice_R, index_col] = 1 337 | A_LR[index_row, indice_LR] = 1 338 | A_LR[indice_RL, index_col] = 1 339 | 340 | A_L = A_L + A_L.t() 341 | A_L = A_L.gt(0).float() 342 | A_R = A_R + A_R.t() 343 | A_R = A_R.gt(0).float() 344 | A_LR = A_LR + A_LR.t() 345 | A_LR = A_LR.gt(0).float() 346 | A = A_L + A_R + A_LR 347 | A = A.gt(0).float() * self.mask_neg_pair 348 | 349 | sturcture_dif = (L - R) * A 350 | self.loss = self.loss + sturcture_dif.pow(2).sum(dim=1).add(eps_sqrt).sqrt().sum() 351 | 352 | return 353 | 354 | def compute(self, desc_L, desc_R, desc_raw_L, desc_raw_R): 355 | num_pt_per_batch = desc_L.shape[0] 356 | self.num_pt_per_batch = num_pt_per_batch 357 | self.index_desc = torch.LongTensor(range(0, num_pt_per_batch)) 358 | diagnal = torch.eye(num_pt_per_batch) 359 | self.mask_pos_pair = diagnal.eq(1).float().to(self.device) 360 | self.mask_neg_pair = diagnal.eq(0).float().to(self.device) 361 | self.desc_L = desc_L 362 | self.desc_R = desc_R 363 | self.desc_raw_L = desc_raw_L 364 | self.desc_raw_R = desc_raw_R 365 | self.norm_L = self.desc_raw_L.pow(2).sum(1).add(eps_sqrt).sqrt() 366 | self.norm_R = self.desc_raw_R.pow(2).sum(1).add(eps_sqrt).sqrt() 367 | self.L = cal_l2_distance_matrix(desc_L, desc_L) 368 | self.R = cal_l2_distance_matrix(desc_R, desc_R) 369 | self.LR = cal_l2_distance_matrix(desc_L, desc_R) 370 | 371 | self.loss = torch.Tensor([0]).to(self.device) 372 | 373 | self.sort_distance() 374 | self.triplet_loss_hybrid() 375 | self.norm_loss_pos() 376 | if self.is_sosr: 377 | self.sos_loss() 378 | 379 | return self.loss, self.dist_pos_display, self.dist_neg_display 380 | 381 | class Loss_SOSNet(): 382 | 383 | def __init__(self, device, dim_desc, margin, knn_sos=8): 384 | self.device = device 385 | self.margin = margin 386 | self.dim_desc = dim_desc 387 | self.knn_sos = knn_sos 388 | self.index_dim = torch.LongTensor(range(0, dim_desc)) 389 | 390 | def sort_distance(self): 391 | L = self.L.clone().detach() 392 | L = L + 2 * self.mask_pos_pair 393 | L = L + 2 * L.le(dist_th).float() 394 | 395 | R = self.R.clone().detach() 396 | R = R + 2 * self.mask_pos_pair 397 | R = R + 2 * R.le(dist_th).float() 398 | 399 | LR = self.LR.clone().detach() 400 | LR = LR + 2 * self.mask_pos_pair 401 | LR = LR + 2 * LR.le(dist_th).float() 402 | 403 | self.indice_L = torch.argsort(L, dim=1) 404 | self.indice_R = torch.argsort(R, dim=0) 405 | self.indice_LR = torch.argsort(LR, dim=1) 406 | self.indice_RL = torch.argsort(LR, dim=0) 407 | return 408 | 409 | def triplet_loss(self): 410 | L = self.L 411 | R = self.R 412 | 
LR = self.LR 413 | indice_L = self.indice_L[:, 0] 414 | indice_R = self.indice_R[0, :] 415 | indice_LR = self.indice_LR[:, 0] 416 | indice_RL = self.indice_RL[0, :] 417 | index_desc = self.index_desc 418 | 419 | dist_neg_hard_L = torch.min(LR[index_desc, indice_LR], L[index_desc, indice_L]) 420 | dist_neg_hard_R = torch.min(LR[indice_RL, index_desc], R[indice_R, index_desc]) 421 | dist_neg_hard = torch.min(dist_neg_hard_L, dist_neg_hard_R) 422 | dist_pos = LR[self.mask_pos_pair.bool()] 423 | loss = torch.clamp(self.margin + dist_pos - dist_neg_hard, min=0.0) 424 | 425 | loss = loss.pow(2) 426 | 427 | self.loss = self.loss + loss.sum() 428 | self.dist_pos_display = dist_pos.detach().mean() 429 | self.dist_neg_display = dist_neg_hard.detach().mean() 430 | 431 | return 432 | 433 | def sos_loss(self): 434 | L = self.L 435 | R = self.R 436 | knn = self.knn_sos 437 | indice_L = self.indice_L[:, 0:knn] 438 | indice_R = self.indice_R[0:knn, :] 439 | indice_LR = self.indice_LR[:, 0:knn] 440 | indice_RL = self.indice_RL[0:knn, :] 441 | index_desc = self.index_desc 442 | num_pt_per_batch = self.num_pt_per_batch 443 | index_row = index_desc.unsqueeze(1).expand(-1, knn) 444 | index_col = index_desc.unsqueeze(0).expand(knn, -1) 445 | 446 | A_L = torch.zeros(num_pt_per_batch, num_pt_per_batch).to(self.device) 447 | A_R = torch.zeros(num_pt_per_batch, num_pt_per_batch).to(self.device) 448 | A_LR = torch.zeros(num_pt_per_batch, num_pt_per_batch).to(self.device) 449 | 450 | A_L[index_row, indice_L] = 1 451 | A_R[indice_R, index_col] = 1 452 | A_LR[index_row, indice_LR] = 1 453 | A_LR[indice_RL, index_col] = 1 454 | 455 | A_L = A_L + A_L.t() 456 | A_L = A_L.gt(0).float() 457 | A_R = A_R + A_R.t() 458 | A_R = A_R.gt(0).float() 459 | A_LR = A_LR + A_LR.t() 460 | A_LR = A_LR.gt(0).float() 461 | A = A_L + A_R + A_LR 462 | A = A.gt(0).float() * self.mask_neg_pair 463 | 464 | sturcture_dif = (L - R) * A 465 | self.loss = self.loss + sturcture_dif.pow(2).sum(dim=1).add(eps_sqrt).sqrt().sum() 466 | 467 | return 468 | 469 | def compute(self, desc_l, desc_r): 470 | num_pt_per_batch = desc_l.shape[0] 471 | self.num_pt_per_batch = num_pt_per_batch 472 | self.index_desc = torch.LongTensor(range(0, num_pt_per_batch)) 473 | diagnal = torch.eye(num_pt_per_batch) 474 | self.mask_pos_pair = diagnal.eq(1).float().to(self.device) 475 | self.mask_neg_pair = diagnal.eq(0).float().to(self.device) 476 | self.loss = torch.Tensor([0]).to(self.device) 477 | self.L = cal_l2_distance_matrix(desc_l, desc_l) 478 | self.R = cal_l2_distance_matrix(desc_r, desc_r) 479 | self.LR = cal_l2_distance_matrix(desc_l, desc_r) 480 | self.sort_distance() 481 | self.triplet_loss() 482 | self.sos_loss() 483 | 484 | return self.loss, self.dist_pos_display, self.dist_neg_display 485 | 486 | dist_th = 8e-3 487 | eps_sqrt = 1e-6 488 | 489 | def cal_l2_distance_matrix(x, y, flag_sqrt=True): 490 | ''''distance matrix of x with respect to y, d_ij is the distance between x_i and y_j''' 491 | D = torch.abs(2 * (1 - torch.mm(x, y.t()))) 492 | if flag_sqrt: 493 | D = torch.sqrt(D + eps_sqrt) 494 | return D 495 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Descriptors Dimensionality Reduction 2 | 3 | This repository contains the implementation of the paper: 4 | 5 | **Learning-Based Dimensionality Reduction for Computing Compact and Effective Local Feature Descriptors (ICRA2023)** 6 | [Hao 
Dong](https://sites.google.com/view/dong-hao/), [Xieyuanli Chen](https://www.ipb.uni-bonn.de/people/xieyuanli-chen/), [Mihai Dusmanu](https://dsmn.ml/), [Viktor Larsson](https://vlarsson.github.io/), [Marc Pollefeys](https://www.inf.ethz.ch/personal/pomarc/) and [Cyrill Stachniss](http://www.ipb.uni-bonn.de/people/cyrill-stachniss/)
7 | The [arXiv version](https://arxiv.org/abs/2209.13586) of the paper is available.
8 | 
9 | We propose and evaluate an MLP-based network for descriptor dimensionality reduction and show its superiority over PCA on multiple descriptors in various tasks.
10 | 
11 | *(figure: pics/overview.png)*
12 | Overview of our approach. We first compute descriptors of the given image patches. Then an MLP-based network is used for dimensionality reduction. We aim to learn an MLP-based projection that is better than PCA at generating lower-dimensional descriptors.
13 | 
14 | ## Abstract
15 | A distinctive representation of image patches in the form of features is a key component of many computer vision and robotics tasks, such as image matching, image retrieval, and visual localization. State-of-the-art descriptors, from hand-crafted descriptors such as SIFT to learned ones such as HardNet, are usually high-dimensional: 128 dimensions or even more. The higher the dimensionality, the larger the memory consumption and computational time for approaches using such descriptors. In this paper, we investigate multi-layer perceptrons (MLPs) to extract low-dimensional but high-quality descriptors. We thoroughly analyze our method in unsupervised, self-supervised, and supervised settings, and evaluate the dimensionality reduction results on four representative descriptors. We consider different applications, including visual localization, patch verification, image matching and retrieval. The experiments show that our lightweight MLPs achieve better dimensionality reduction than PCA. The lower-dimensional descriptors generated by our approach outperform the original higher-dimensional descriptors in downstream tasks, especially for the hand-crafted ones.
16 | 
17 | ## Installation
18 | ```bash
19 | pip install numpy kornia tqdm torch torchvision scipy faiss-gpu tensorboard_logger tabulate
20 | ```
21 | ## Example
22 | A simple example for SIFT-SV-64 with a pre-trained model:
23 | ```bash
24 | python example.py
25 | ```
26 | 
27 | ## Training
28 | ### PCA Reduction
29 | For SIFT:
30 | ```bash
31 | python dr_pca.py --descriptor SIFT
32 | ```
33 | For HardNet:
34 | ```bash
35 | python dr_pca.py --descriptor HardNet
36 | ```
37 | 
38 | ### Unsupervised Reduction
39 | For SIFT:
40 | ```bash
41 | python dr_ae_SIFT.py
42 | ```
43 | For HardNet:
44 | ```bash
45 | python dr_ae_HardNet.py
46 | ```
47 | 
48 | ### Self-supervised Reduction
49 | For SIFT:
50 | ```bash
51 | python dr_ss_SIFT.py
52 | ```
53 | For HardNet:
54 | ```bash
55 | python dr_ss_HardNet.py
56 | ```
57 | 
58 | ### Supervised Reduction
59 | For SIFT:
60 | ```bash
61 | python dr_sv_SIFT.py
62 | ```
63 | For HardNet:
64 | ```bash
65 | python dr_sv_HardNet.py
66 | ```
67 | 
68 | ### Evaluate on HPatches dataset
69 | Download the HPatches dataset and benchmark:
70 | ```bash
71 | git clone https://github.com/hpatches/hpatches-benchmark.git
72 | cd hpatches-benchmark/
73 | sh download.sh hpatches
74 | ```
75 | Extract descriptors on the HPatches dataset and evaluate:
76 | ```bash
77 | cd ..
78 | python hpatches_extract_SIFT_64.py /path/to/HPatches/dataset
79 | mkdir hpatches-benchmark/data/descriptors
80 | mv SIFT_sv_dim64/ hpatches-benchmark/data/descriptors/
81 | cd hpatches-benchmark/python/
82 | python hpatches_eval.py --descr-name=SIFT_sv_dim64 --task=matching --delimiter=","
83 | python hpatches_results.py --descr=SIFT_sv_dim64 --results-dir=results/ --task=matching
84 | python hpatches_eval.py --descr-name=SIFT_sv_dim64 --task=verification --delimiter=","
85 | python hpatches_results.py --descr=SIFT_sv_dim64 --results-dir=results/ --task=verification
86 | python hpatches_eval.py --descr-name=SIFT_sv_dim64 --task=retrieval --delimiter=","
87 | python hpatches_results.py --descr=SIFT_sv_dim64 --results-dir=results/ --task=retrieval
88 | ```
89 | 
90 | 
91 | 
92 | ## Citation
93 | If you use our implementation in your academic work, please cite the corresponding [paper](https://arxiv.org/abs/2209.13586):
94 | 
95 |     @inproceedings{dong2022dr,
96 |        title={Learning-Based Dimensionality Reduction for Computing Compact and Effective Local Feature Descriptors},
97 |        author={Dong, Hao and Chen, Xieyuanli and Dusmanu, Mihai and Larsson, Viktor and Pollefeys, Marc and Stachniss, Cyrill},
98 |        booktitle={Proceedings of the IEEE International Conference on Robotics and Automation (ICRA)},
99 |        year={2023}
100 |     }
101 | 
102 | ## Localization on Aachen Day-Night v1.1 and InLoc datasets
103 | *(figure: pics/localization.PNG)*
104 | 
105 | ## Supplementary
106 | 
107 | ### Embedding visualization of the descriptors
108 | 
109 | 
110 | *(figure row: SIFT, SIFT-PCA-64, SIFT-Ours-SV-64)*
111 | 
112 | *(figure row: MKD, MKD-PCA-64, MKD-Ours-SV-64)*
113 | 
114 | *(figure row: TFeat, TFeat-PCA-64, TFeat-Ours-SV-64)*
115 | 
116 | *(figure row: HardNet, HardNet-PCA-64, HardNet-Ours-SV-64)*
117 | 
118 | We provide t-SNE embedding visualizations of the descriptors on UBC Phototour Liberty. We visualize the embeddings of SIFT, MKD, TFeat, and HardNet generated using PCA and our supervised method by mapping the high-dimensional descriptors (128 and 64 dimensions) into 2D with t-SNE. We pass the image patches through the descriptor extractor, followed by PCA or our MLPs, to get lower-dimensional descriptors, determine their 2D locations with a t-SNE transformation, and finally draw the entire patch at each location.
119 | 
120 | The visualization shows results similar to those discussed in the paper. For SIFT and MKD, the original descriptor space is irregular, and similar and dissimilar features overlap; a PCA projection therefore keeps this irregular structure of the descriptor space. After learning a more discriminative representation with the triplet loss, however, similar image patches lie close to each other in the descriptor space while dissimilar ones lie far apart. For TFeat and HardNet, the output space is already optimized for the $\ell_2$ metric, so image patches in the descriptor space are already well separated by appearance; a simple PCA preserves this distinctive structure and performs on par with the learned projection.
121 | 
122 | 
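A minimal sketch of the visualization pipeline described above (not the exact plotting code behind the figures); `descs` and `patches` are stand-in arrays for the reduced descriptors and their image patches:

```python
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.offsetbox import AnnotationBbox, OffsetImage
from sklearn.manifold import TSNE

rng = np.random.default_rng(0)
descs = rng.normal(size=(500, 64)).astype(np.float32)  # stand-in for (N, D) reduced descriptors
patches = rng.random(size=(500, 32, 32))               # stand-in for the (N, 32, 32) patches

# map descriptors to 2D and normalize the coordinates to the unit square
xy = TSNE(n_components=2, init='pca', random_state=0).fit_transform(descs)
xy = (xy - xy.min(0)) / (xy.max(0) - xy.min(0))

fig, ax = plt.subplots(figsize=(10, 10))
for (x, y), patch in zip(xy, patches):
    # draw each patch thumbnail at its 2D t-SNE location
    ab = AnnotationBbox(OffsetImage(patch, cmap='gray', zoom=0.5), (x, y), frameon=False)
    ax.add_artist(ab)
ax.set_axis_off()
plt.show()
```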
123 | ### Experiments for patch verification, image matching, and patch retrieval on HPatches dataset
124 | 
125 | *(figure placeholders: the HPatches result plots for SIFT, MKD, TFeat, and HardNet from `pics/*-Hpatches.*`)*
126 | 
127 | 
128 | 
129 | 
130 | ## Acknowledgment
131 | We are grateful to the authors of the following open-source projects:
132 | 
133 | - [deepcluster](https://github.com/facebookresearch/deepcluster)
134 | - [hardnet](https://github.com/DagnyT/hardnet)
--------------------------------------------------------------------------------
/Utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.init
3 | import torch.nn as nn
4 | import cv2
5 | import numpy as np
6 | 
7 | # resize image to 36x36 / 32x32
8 | cv2_scale36 = lambda x: cv2.resize(x, dsize=(36, 36),
9 |                                    interpolation=cv2.INTER_LINEAR)
10 | cv2_scale = lambda x: cv2.resize(x, dsize=(32, 32),
11 |                                  interpolation=cv2.INTER_LINEAR)
12 | # reshape image to 32x32x1
13 | np_reshape = lambda x: np.reshape(x, (32, 32, 1))
14 | 
15 | class L2Norm(nn.Module):
16 |     def __init__(self):
17 |         super(L2Norm, self).__init__()
18 |         self.eps = 1e-10
19 |     def forward(self, x):
20 |         norm = torch.sqrt(torch.sum(x * x, dim=1) + self.eps)
21 |         x = x / norm.unsqueeze(-1).expand_as(x)
22 |         return x
23 | 
24 | class L1Norm(nn.Module):
25 |     def __init__(self):
26 |         super(L1Norm, self).__init__()
27 |         self.eps = 1e-10
28 |     def forward(self, x):
29 |         norm = torch.sum(torch.abs(x), dim=1) + self.eps
30 |         x = x / norm.unsqueeze(-1).expand_as(x)
31 |         return x
32 | 
33 | 
34 | def str2bool(v):
35 |     if v.lower() in ('yes', 'true', 't', 'y', '1'):
36 |         return True
37 |     elif v.lower() in ('no', 'false', 'f', 'n', '0'):
38 |         return False
--------------------------------------------------------------------------------
/clustering.py:
--------------------------------------------------------------------------------
1 | import time
2 | 
3 | import faiss
4 | import numpy as np
5 | from PIL import Image
6 | from PIL import ImageFile
7 | from scipy.sparse import csr_matrix, find
8 | import torch
9 | import torch.utils.data as data
10 | import torchvision.transforms as transforms
11 | 
12 | ImageFile.LOAD_TRUNCATED_IMAGES = True
13 | 
14 | __all__ = ['PIC', 'Kmeans', 'cluster_assign', 'arrange_clustering']
15 | 
16 | 
17 | def pil_loader(path):
18 |     """Loads an image.
19 |     Args:
20 |         path (string): path to image file
21 |     Returns:
22 |         Image
23 |     """
24 |     with open(path, 'rb') as f:
25 |         img = Image.open(f)
26 |         return img.convert('RGB')
27 | 
28 | 
29 | class ReassignedDataset(data.Dataset):
30 |     """A dataset where new labels for the images are given as an argument.
31 | Args: 32 | image_indexes (list): list of data indexes 33 | pseudolabels (list): list of labels for each data 34 | dataset (list): list of tuples with paths to images 35 | transform (callable, optional): a function/transform that takes in 36 | an PIL image and returns a 37 | transformed version 38 | """ 39 | 40 | def __init__(self, image_indexes, pseudolabels, dataset): 41 | self.imgs = self.make_dataset(image_indexes, pseudolabels, dataset) 42 | 43 | def make_dataset(self, image_indexes, pseudolabels, dataset): 44 | label_to_idx = {label: idx for idx, label in enumerate(set(pseudolabels))} 45 | images = [] 46 | for j, idx in enumerate(image_indexes): 47 | descriptor = dataset[idx] 48 | pseudolabel = label_to_idx[pseudolabels[j]] 49 | images.append((descriptor, pseudolabel)) 50 | return images 51 | 52 | def __getitem__(self, index): 53 | """ 54 | Args: 55 | index (int): index of data 56 | Returns: 57 | tuple: (image, pseudolabel) where pseudolabel is the cluster of index datapoint 58 | """ 59 | descriptor, pseudolabel = self.imgs[index] 60 | return descriptor, pseudolabel 61 | 62 | def __len__(self): 63 | return len(self.imgs) 64 | 65 | 66 | def preprocess_features(npdata): 67 | """Preprocess an array of features. 68 | Args: 69 | npdata (np.array N * ndim): features to preprocess 70 | pca (int): dim of output 71 | Returns: 72 | np.array of dim N * pca: data PCA-reduced, whitened and L2-normalized 73 | """ 74 | # _, ndim = npdata.shape 75 | npdata = npdata.astype('float32') 76 | 77 | # Apply PCA-whitening with Faiss 78 | '''mat = faiss.PCAMatrix (ndim, pca, eigen_power=-0.5) 79 | mat.train(npdata) 80 | assert mat.is_trained 81 | npdata = mat.apply_py(npdata) 82 | 83 | # L2 normalization 84 | row_sums = np.linalg.norm(npdata, axis=1) 85 | npdata = npdata / row_sums[:, np.newaxis]''' 86 | 87 | return npdata 88 | 89 | 90 | def make_graph(xb, nnn): 91 | """Builds a graph of nearest neighbors. 92 | Args: 93 | xb (np.array): data 94 | nnn (int): number of nearest neighbors 95 | Returns: 96 | list: for each data the list of ids to its nnn nearest neighbors 97 | list: for each data the list of distances to its nnn NN 98 | """ 99 | N, dim = xb.shape 100 | 101 | # we need only a StandardGpuResources per GPU 102 | res = faiss.StandardGpuResources() 103 | 104 | # L2 105 | flat_config = faiss.GpuIndexFlatConfig() 106 | flat_config.device = int(torch.cuda.device_count()) - 1 107 | index = faiss.GpuIndexFlatL2(res, dim, flat_config) 108 | index.add(xb) 109 | D, I = index.search(xb, nnn + 1) 110 | return I, D 111 | 112 | 113 | def cluster_assign(images_lists, dataset): 114 | """Creates a dataset from clustering, with clusters as labels. 115 | Args: 116 | images_lists (list of list): for each cluster, the list of image indexes 117 | belonging to this cluster 118 | dataset (list): initial dataset 119 | Returns: 120 | ReassignedDataset(torch.utils.data.Dataset): a dataset with clusters as 121 | labels 122 | """ 123 | assert images_lists is not None 124 | pseudolabels = [] 125 | image_indexes = [] 126 | for cluster, images in enumerate(images_lists): 127 | image_indexes.extend(images) 128 | pseudolabels.extend([cluster] * len(images)) 129 | 130 | return ReassignedDataset(image_indexes, pseudolabels, dataset) 131 | 132 | 133 | def run_kmeans(x, nmb_clusters, verbose=False): 134 | """Runs kmeans on 1 GPU. 
135 | Args: 136 | x: data 137 | nmb_clusters (int): number of clusters 138 | Returns: 139 | list: ids of data in each cluster 140 | """ 141 | n_data, d = x.shape 142 | 143 | # faiss implementation of k-means 144 | clus = faiss.Clustering(d, nmb_clusters) 145 | 146 | # Change faiss seed at each k-means so that the randomly picked 147 | # initialization centroids do not correspond to the same feature ids 148 | # from an epoch to another. 149 | clus.seed = np.random.randint(1234) 150 | 151 | clus.niter = 20 152 | clus.max_points_per_centroid = 10000000 153 | res = faiss.StandardGpuResources() 154 | flat_config = faiss.GpuIndexFlatConfig() 155 | flat_config.useFloat16 = False 156 | flat_config.device = 0 157 | index = faiss.GpuIndexFlatL2(res, d, flat_config) 158 | 159 | # perform the training 160 | clus.train(x, index) 161 | _, I = index.search(x, 1) 162 | # losses = faiss.vector_to_array(clus.obj) 163 | stats = clus.iteration_stats 164 | losses = np.array([ 165 | stats.at(i).obj for i in range(stats.size()) 166 | ]) 167 | 168 | if verbose: 169 | print('k-means loss evolution: {0}'.format(losses)) 170 | 171 | return [int(n[0]) for n in I], losses[-1] 172 | 173 | 174 | def arrange_clustering(images_lists): 175 | pseudolabels = [] 176 | image_indexes = [] 177 | for cluster, images in enumerate(images_lists): 178 | image_indexes.extend(images) 179 | pseudolabels.extend([cluster] * len(images)) 180 | indexes = np.argsort(image_indexes) 181 | return np.asarray(pseudolabels)[indexes] 182 | 183 | 184 | class Kmeans(object): 185 | def __init__(self, k): 186 | self.k = k 187 | 188 | def cluster(self, data, verbose=False): 189 | """Performs k-means clustering. 190 | Args: 191 | x_data (np.array N * dim): data to cluster 192 | """ 193 | end = time.time() 194 | 195 | # PCA-reducing, whitening and L2-normalization 196 | xb = preprocess_features(data) 197 | 198 | # cluster the data 199 | I, loss = run_kmeans(xb, self.k, verbose) 200 | self.images_lists = [[] for i in range(self.k)] 201 | for i in range(len(data)): 202 | self.images_lists[I[i]].append(i) 203 | 204 | if verbose: 205 | print('k-means time: {0:.0f} s'.format(time.time() - end)) 206 | 207 | return loss 208 | 209 | 210 | def make_adjacencyW(I, D, sigma): 211 | """Create adjacency matrix with a Gaussian kernel. 212 | Args: 213 | I (numpy array): for each vertex the ids to its nnn linked vertices 214 | + first column of identity. 215 | D (numpy array): for each data the l2 distances to its nnn linked vertices 216 | + first column of zeros. 217 | sigma (float): Bandwidth of the Gaussian kernel. 218 | 219 | Returns: 220 | csr_matrix: affinity matrix of the graph. 
221 | """ 222 | V, k = I.shape 223 | k = k - 1 224 | indices = np.reshape(np.delete(I, 0, 1), (1, -1)) 225 | indptr = np.multiply(k, np.arange(V + 1)) 226 | 227 | def exp_ker(d): 228 | return np.exp(-d / sigma**2) 229 | 230 | exp_ker = np.vectorize(exp_ker) 231 | res_D = exp_ker(D) 232 | data = np.reshape(np.delete(res_D, 0, 1), (1, -1)) 233 | adj_matrix = csr_matrix((data[0], indices[0], indptr), shape=(V, V)) 234 | return adj_matrix 235 | 236 | 237 | def run_pic(I, D, sigma, alpha): 238 | """Run PIC algorithm""" 239 | a = make_adjacencyW(I, D, sigma) 240 | graph = a + a.transpose() 241 | cgraph = graph 242 | nim = graph.shape[0] 243 | 244 | W = graph 245 | t0 = time.time() 246 | 247 | v0 = np.ones(nim) / nim 248 | 249 | # power iterations 250 | v = v0.astype('float32') 251 | 252 | t0 = time.time() 253 | dt = 0 254 | for i in range(200): 255 | vnext = np.zeros(nim, dtype='float32') 256 | 257 | vnext = vnext + W.transpose().dot(v) 258 | 259 | vnext = alpha * vnext + (1 - alpha) / nim 260 | # L1 normalize 261 | vnext /= vnext.sum() 262 | v = vnext 263 | 264 | if i == 200 - 1: 265 | clust = find_maxima_cluster(W, v) 266 | 267 | return [int(i) for i in clust] 268 | 269 | 270 | def find_maxima_cluster(W, v): 271 | n, m = W.shape 272 | assert (n == m) 273 | assign = np.zeros(n) 274 | # for each node 275 | pointers = list(range(n)) 276 | for i in range(n): 277 | best_vi = 0 278 | l0 = W.indptr[i] 279 | l1 = W.indptr[i + 1] 280 | for l in range(l0, l1): 281 | j = W.indices[l] 282 | vi = W.data[l] * (v[j] - v[i]) 283 | if vi > best_vi: 284 | best_vi = vi 285 | pointers[i] = j 286 | n_clus = 0 287 | cluster_ids = -1 * np.ones(n) 288 | for i in range(n): 289 | if pointers[i] == i: 290 | cluster_ids[i] = n_clus 291 | n_clus = n_clus + 1 292 | for i in range(n): 293 | # go from pointers to pointers starting from i until reached a local optim 294 | current_node = i 295 | while pointers[current_node] != current_node: 296 | current_node = pointers[current_node] 297 | 298 | assign[i] = cluster_ids[current_node] 299 | assert (assign[i] >= 0) 300 | return assign 301 | 302 | 303 | class PIC(object): 304 | """Class to perform Power Iteration Clustering on a graph of nearest neighbors. 305 | Args: 306 | args: for consistency with k-means init 307 | sigma (float): bandwidth of the Gaussian kernel (default 0.2) 308 | nnn (int): number of nearest neighbors (default 5) 309 | alpha (float): parameter in PIC (default 0.001) 310 | distribute_singletons (bool): If True, reassign each singleton to 311 | the cluster of its closest non 312 | singleton nearest neighbors (up to nnn 313 | nearest neighbors). 
314 | Attributes: 315 | images_lists (list of list): for each cluster, the list of image indexes 316 | belonging to this cluster 317 | """ 318 | 319 | def __init__(self, args=None, sigma=0.2, nnn=5, alpha=0.001, distribute_singletons=True): 320 | self.sigma = sigma 321 | self.alpha = alpha 322 | self.nnn = nnn 323 | self.distribute_singletons = distribute_singletons 324 | 325 | def cluster(self, data, verbose=False): 326 | end = time.time() 327 | 328 | # preprocess the data 329 | xb = preprocess_features(data) 330 | 331 | # construct nnn graph 332 | I, D = make_graph(xb, self.nnn) 333 | 334 | # run PIC 335 | clust = run_pic(I, D, self.sigma, self.alpha) 336 | images_lists = {} 337 | for h in set(clust): 338 | images_lists[h] = [] 339 | for data, c in enumerate(clust): 340 | images_lists[c].append(data) 341 | 342 | # allocate singletons to clusters of their closest NN not singleton 343 | if self.distribute_singletons: 344 | clust_NN = {} 345 | for i in images_lists: 346 | # if singleton 347 | if len(images_lists[i]) == 1: 348 | s = images_lists[i][0] 349 | # for NN 350 | for n in I[s, 1:]: 351 | # if NN is not a singleton 352 | if not len(images_lists[clust[n]]) == 1: 353 | clust_NN[s] = n 354 | break 355 | for s in clust_NN: 356 | del images_lists[clust[s]] 357 | clust[s] = clust[clust_NN[s]] 358 | images_lists[clust[s]].append(s) 359 | 360 | self.images_lists = [] 361 | for c in images_lists: 362 | self.images_lists.append(images_lists[c]) 363 | 364 | if verbose: 365 | print('pic time: {0:.0f} s'.format(time.time() - end)) 366 | return 0 367 | -------------------------------------------------------------------------------- /data/.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | liberty/ 3 | logs/ 4 | models/ 5 | notredame/ 6 | yosemite/ 7 | *.pt 8 | -------------------------------------------------------------------------------- /dr_ae_HardNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | import os 4 | import numpy as np 5 | from torch.utils.data import DataLoader 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import torch.optim as optim 9 | import argparse 10 | import random 11 | from tqdm import tqdm 12 | 13 | parser = argparse.ArgumentParser(description='PyTorch dr') 14 | parser.add_argument('--descriptor', type=str, default='HardNet', help='descriptor') 15 | parser.add_argument('--dataset_names', type=str, default='liberty', help='dataset_names, notredame, yosemite, liberty') 16 | parser.add_argument('--reduce_dim', type=int, default=64, help='reduce_dim') 17 | parser.add_argument('--hidden', type=int, default=96, help='hidden') 18 | parser.add_argument('--bsz', type=int, default=1024, help='bsz') 19 | parser.add_argument('--lr', type=float, default=0.001, help='learning rate') 20 | parser.add_argument('--seed', type=int, default=0, help='random seed (default: 0)') 21 | args = parser.parse_args() 22 | 23 | np.random.seed(args.seed) 24 | torch.manual_seed(args.seed) 25 | random.seed(args.seed) 26 | torch.backends.cudnn.deterministic = True 27 | torch.backends.cudnn.benchmark = False 28 | 29 | 30 | class DescriotorDataset(Dataset): 31 | def __init__(self, des_dir, descriptor): 32 | self.descriptorsfile = os.path.join(des_dir, descriptor + '-' + args.dataset_names + '.npz') 33 | self.descriptors = np.load(self.descriptorsfile)['descriptors'] 34 | 35 | def __len__(self): 36 | return self.descriptors.shape[0] 
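    # The .npz file above is expected to contain a single (N, 128) array stored
    # under the key 'descriptors', which is exactly what dr_pca.py writes out
    # via np.savez(descriptorsfile, descriptors=descriptors).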
37 | 38 | def __getitem__(self, idx): 39 | descriptor = self.descriptors[idx] 40 | return torch.from_numpy(descriptor) 41 | 42 | descriptors = DescriotorDataset( 43 | des_dir="raw_descriptors", 44 | descriptor=args.descriptor 45 | ) 46 | 47 | 48 | class Encoder(nn.Module): 49 | def __init__(self, n_components, hidden=1024): 50 | super(Encoder, self).__init__() 51 | self.enc_net = nn.Sequential( 52 | nn.Linear(128, hidden), 53 | nn.ReLU(inplace=True), 54 | nn.BatchNorm1d(hidden), 55 | nn.Linear(hidden, n_components), 56 | ) 57 | 58 | def forward(self, x): 59 | output = self.enc_net(x) 60 | output = F.normalize(output, dim=1) 61 | return output 62 | 63 | class Decoder(nn.Module): 64 | def __init__(self, n_components, hidden=1024): 65 | super(Decoder, self).__init__() 66 | self.dec_net = nn.Sequential( 67 | nn.Linear(n_components, hidden), 68 | nn.ReLU(inplace=True), 69 | nn.BatchNorm1d(hidden), 70 | nn.Linear(hidden, 128), 71 | nn.ReLU() 72 | ) 73 | 74 | def forward(self, z): 75 | output = self.dec_net(z) 76 | output = F.normalize(output, dim=1) 77 | return output 78 | 79 | 80 | def distance_loss(encoders, decoders, batch, device, alpha=0.1): 81 | target_descriptors = batch 82 | embeddings = encoders(batch) 83 | 84 | t_loss = torch.tensor(0.).float().to(device) 85 | output_descriptors = decoders(embeddings) 86 | current_loss = torch.mean( 87 | torch.norm(output_descriptors - target_descriptors, dim=1) 88 | ) 89 | t_loss += current_loss 90 | 91 | e_loss = torch.tensor(0.).float().to(device) 92 | 93 | sqdist_matrix_embeddings = 2 - 2 * embeddings @ embeddings.T 94 | sqdist_matrix_target = 2 - 2 * target_descriptors @ target_descriptors.T 95 | 96 | e_loss += torch.mean( 97 | torch.abs(sqdist_matrix_target - sqdist_matrix_embeddings) 98 | ) 99 | 100 | if alpha > 0: 101 | loss = t_loss + alpha * e_loss 102 | else: 103 | loss = t_loss 104 | 105 | return loss, (t_loss.detach(), e_loss.detach()) 106 | 107 | class UpdatingMean(): 108 | def __init__(self): 109 | self.sum = 0 110 | self.n = 0 111 | 112 | def mean(self): 113 | return self.sum / self.n 114 | 115 | def add(self, loss): 116 | self.sum += loss 117 | self.n += 1 118 | 119 | device = torch.device('cuda:0') 120 | 121 | encoder = Encoder(args.reduce_dim, hidden=args.hidden) 122 | encoder.to(device) 123 | 124 | decoder = Decoder(args.reduce_dim, hidden=args.hidden) 125 | decoder.to(device) 126 | 127 | 128 | encoder_optimizer = optim.Adam(encoder.parameters(), lr=args.lr) 129 | decoder_optimizer = optim.Adam(decoder.parameters(), lr=args.lr) 130 | 131 | loss_function = lambda encoders, decoders, batch, device: distance_loss( 132 | encoders, decoders, batch, device, 133 | alpha=0.1 134 | ) 135 | 136 | train_dataloader = DataLoader(descriptors, batch_size=args.bsz, shuffle=True) 137 | print("start training") 138 | num_epochs = 10 139 | for epoch in range(num_epochs): 140 | encoder.train() 141 | decoder.train() 142 | epoch_loss = UpdatingMean() 143 | epoch_t_loss = UpdatingMean() 144 | epoch_e_loss = UpdatingMean() 145 | progress_bar = tqdm(enumerate(train_dataloader), total=len(train_dataloader)) 146 | 147 | for batch_idx, batch in progress_bar: 148 | encoder_optimizer.zero_grad() 149 | decoder_optimizer.zero_grad() 150 | 151 | batch = batch.to(device) 152 | 153 | loss, (t_loss, e_loss) = loss_function(encoder, decoder, batch, device) 154 | 155 | epoch_loss.add(loss.data.cpu().numpy()) 156 | epoch_t_loss.add(t_loss) 157 | epoch_e_loss.add(e_loss) 158 | progress_bar.set_postfix( 159 | loss=('%.4f' % epoch_loss.mean()), 160 | t_loss=('%.4f' % 
epoch_t_loss.mean()), 161 | e_loss=('%.4f' % epoch_e_loss.mean()) 162 | ) 163 | 164 | loss.backward() 165 | 166 | encoder_optimizer.step() 167 | decoder_optimizer.step() 168 | 169 | print ('Epoch {}, train_error: {:.4f}' 170 | .format(epoch, epoch_loss.mean())) 171 | 172 | file_name = 'models/ae_' + args.descriptor + '_' + str(args.reduce_dim) + '_' + args.dataset_names + '.pth' 173 | torch.save(encoder.state_dict(), file_name) 174 | -------------------------------------------------------------------------------- /dr_ae_SIFT.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | import os 4 | import numpy as np 5 | from torch.utils.data import DataLoader 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import torch.optim as optim 9 | import argparse 10 | import random 11 | 12 | parser = argparse.ArgumentParser(description='PyTorch dr') 13 | parser.add_argument('--descriptor', type=str, default='SIFT', help='descriptor') 14 | parser.add_argument('--dataset_names', type=str, default='liberty', help='dataset_names, notredame, yosemite, liberty') 15 | parser.add_argument('--reduce_dim', type=int, default=64, help='reduce_dim') 16 | parser.add_argument('--hidden', type=int, default=1024, help='hidden') 17 | parser.add_argument('--bsz', type=int, default=1024, help='bsz') 18 | parser.add_argument('--lr', type=float, default=0.001, help='learning rate') 19 | parser.add_argument('--seed', type=int, default=0, help='random seed (default: 0)') 20 | args = parser.parse_args() 21 | 22 | np.random.seed(args.seed) 23 | torch.manual_seed(args.seed) 24 | random.seed(args.seed) 25 | torch.backends.cudnn.deterministic = True 26 | torch.backends.cudnn.benchmark = False 27 | 28 | 29 | class DescriotorDataset(Dataset): 30 | def __init__(self, des_dir, descriptor): 31 | self.descriptorsfile = os.path.join(des_dir, descriptor + '-' + args.dataset_names + '.npz') 32 | self.descriptors = np.load(self.descriptorsfile)['descriptors'] 33 | 34 | def __len__(self): 35 | return self.descriptors.shape[0] 36 | 37 | def __getitem__(self, idx): 38 | descriptor = self.descriptors[idx] 39 | return torch.from_numpy(descriptor) 40 | 41 | descriptors = DescriotorDataset( 42 | des_dir="raw_descriptors", 43 | descriptor=args.descriptor 44 | ) 45 | 46 | 47 | class Encoder(nn.Module): 48 | def __init__(self, n_components, hidden=1024): 49 | super(Encoder, self).__init__() 50 | self.enc_net = nn.Sequential( 51 | nn.Linear(128, hidden), 52 | nn.ReLU(inplace=True), 53 | nn.BatchNorm1d(hidden), 54 | nn.Linear(hidden, hidden), 55 | nn.ReLU(inplace=True), 56 | nn.BatchNorm1d(hidden), 57 | nn.Linear(hidden, n_components), 58 | ) 59 | 60 | def forward(self, x): 61 | output = self.enc_net(x) 62 | output = F.normalize(output, dim=1) 63 | return output 64 | 65 | class Decoder(nn.Module): 66 | def __init__(self, n_components, hidden=1024): 67 | super(Decoder, self).__init__() 68 | self.dec_net = nn.Sequential( 69 | nn.Linear(n_components, hidden), 70 | nn.ReLU(inplace=True), 71 | nn.BatchNorm1d(hidden), 72 | nn.Linear(hidden, hidden), 73 | nn.ReLU(inplace=True), 74 | nn.BatchNorm1d(hidden), 75 | nn.Linear(hidden, 128), 76 | nn.ReLU() 77 | ) 78 | 79 | def forward(self, z): 80 | output = self.dec_net(z) 81 | output = F.normalize(output, dim=1) 82 | return output 83 | 84 | device = torch.device('cuda:0') 85 | 86 | encoder = Encoder(args.reduce_dim, hidden=args.hidden) 87 | encoder.to(device) 88 | 89 | decoder = Decoder(args.reduce_dim, 
hidden=args.hidden) 90 | decoder.to(device) 91 | 92 | 93 | encoder_optimizer = optim.Adam(encoder.parameters(), lr=args.lr) 94 | decoder_optimizer = optim.Adam(decoder.parameters(), lr=args.lr) 95 | criterion = nn.MSELoss() 96 | train_dataloader = DataLoader(descriptors, batch_size=args.bsz, shuffle=True) 97 | print("start training") 98 | num_epochs = 5 99 | for epoch in range(num_epochs): 100 | encoder.train() 101 | decoder.train() 102 | losses = [] 103 | for descs in train_dataloader: 104 | encoder_optimizer.zero_grad() 105 | decoder_optimizer.zero_grad() 106 | 107 | descs = descs.to(device) 108 | output = encoder(descs) 109 | output = decoder(output) 110 | 111 | loss = criterion(output, descs) 112 | losses.append(loss.item()) 113 | loss.backward() 114 | 115 | encoder_optimizer.step() 116 | decoder_optimizer.step() 117 | 118 | mean_loss = np.mean(np.array(losses)) 119 | print ('Epoch {}, train_error: {:.4f}' 120 | .format(epoch, mean_loss)) 121 | 122 | file_name = 'models/ae_' + args.descriptor + '_' + str(args.reduce_dim) + '_' + args.dataset_names + '.pth' 123 | torch.save(encoder.state_dict(), file_name) 124 | -------------------------------------------------------------------------------- /dr_pca.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | from torchvision import datasets 4 | import kornia 5 | import cv2 6 | import numpy as np 7 | import argparse 8 | from tqdm import tqdm 9 | from sklearn.decomposition import PCA 10 | import pickle as pk 11 | import os 12 | 13 | parser = argparse.ArgumentParser(description='PyTorch dr') 14 | parser.add_argument('--descriptor', type=str, default='SIFT', help='descriptor') 15 | parser.add_argument('--dataset_names', type=str, default='liberty', help='dataset_names, notredame, yosemite, liberty') 16 | parser.add_argument('--reduce_dim', type=int, default=64, help='reduce_dim') 17 | args = parser.parse_args() 18 | 19 | dataset = datasets.PhotoTour( 20 | root='./data', name=args.dataset_names, train=True, transform=None, download=True) 21 | 22 | if args.descriptor == 'SIFT': 23 | des = kornia.feature.SIFTDescriptor(32, 8, 4, False).cuda() 24 | elif args.descriptor == 'MKD': 25 | des = kornia.feature.MKDDescriptor().cuda() 26 | elif args.descriptor == 'TFeat': 27 | des = kornia.feature.TFeat(pretrained=True).cuda() 28 | elif args.descriptor == 'HardNet': 29 | des = kornia.feature.HardNet(pretrained=True).cuda() 30 | 31 | cv2_scale = lambda x: cv2.resize(x, dsize=(32, 32), 32 | interpolation=cv2.INTER_LINEAR) 33 | 34 | dataloader = DataLoader(dataset, batch_size=128, shuffle=False) 35 | descriptors = torch.empty((0,128), dtype=torch.float) 36 | 37 | for batch_idx, patches in enumerate(tqdm(dataloader)): 38 | patches_32 = np.empty([0,1,32,32]) 39 | patches = patches.cpu().detach().numpy() 40 | for i in range(patches.shape[0]): 41 | patch = cv2_scale(patches[i]) 42 | patch = np.expand_dims(patch, axis=0) 43 | patch = np.expand_dims(patch, axis=0) 44 | patches_32 = np.concatenate((patches_32,patch),axis=0) 45 | descs = des(torch.from_numpy(patches_32).float().cuda()).cpu().detach() 46 | descriptors = torch.vstack((descriptors,descs)) 47 | 48 | descriptors = descriptors.cpu().detach().numpy() 49 | print(descriptors.shape) 50 | 51 | descriptorsfile_name = args.descriptor + '-' + args.dataset_names + '.npz' 52 | descriptorsfile = os.path.join('raw_descriptors', descriptorsfile_name) 53 | np.savez(descriptorsfile, descriptors=descriptors) 54 | 55 | pca = 
PCA(n_components=args.reduce_dim)
56 | pca.fit(descriptors)
57 | 
58 | descriptors_pca = pca.transform(descriptors)
59 | print(descriptors_pca.shape)
60 | 
61 | save_name = 'pca' + str(args.reduce_dim) + '-' + args.descriptor + '-' + args.dataset_names + '.pkl'
62 | pk.dump(pca, open("models/" + save_name, "wb"))
63 | 
64 | pca_reload = pk.load(open("models/" + save_name, 'rb'))
65 | descriptors_pca_reload = pca_reload.transform(descriptors)
66 | print(descriptors_pca_reload.shape)
--------------------------------------------------------------------------------
/dr_ss_HardNet.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import pickle
4 | import time
5 | import random
6 | import faiss
7 | import numpy as np
8 | from copy import deepcopy
9 | from sklearn.metrics.cluster import normalized_mutual_info_score
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.parallel
13 | import torch.backends.cudnn as cudnn
14 | import torch.optim
15 | import torch.utils.data
16 | import torchvision.transforms as transforms
17 | import torchvision.datasets as datasets
18 | import torchvision.datasets as dset
19 | from torch.utils.data import Dataset
20 | from torch.autograd import Variable
21 | import clustering
22 | from util import AverageMeter, Logger, UnifLabelSampler
23 | from tqdm import tqdm
24 | import kornia
25 | import copy
26 | import PIL
27 | import torch.nn.functional as F
28 | 
29 | HardNet = kornia.feature.HardNet(pretrained=True).cuda()
30 | HardNet.eval()
31 | 
32 | def ErrorRateAt95Recall(labels, scores):
33 |     distances = 1.0 / (scores + 1e-8)
34 |     recall_point = 0.95
35 |     labels = labels[np.argsort(distances)]
36 |     threshold_index = np.argmax(np.cumsum(labels) >= recall_point * np.sum(labels))
37 | 
38 |     FP = np.sum(labels[:threshold_index] == 0)  # below the threshold (i.e., labelled positive), but should be negative
39 |     TN = np.sum(labels[threshold_index:] == 0)  # above the threshold (i.e., labelled negative), and should be negative
40 |     return float(FP) / float(FP + TN)
41 | 
42 | dataset_names = ['liberty', 'notredame', 'yosemite']
43 | def parse_args():
44 |     parser = argparse.ArgumentParser(description='PyTorch Implementation of DeepCluster')
45 |     parser.add_argument('--arch', '-a', type=str, metavar='ARCH',
46 |                         choices=['sift', 'hardnet'], default='sift',
47 |                         help='architecture (default: sift)')
48 |     parser.add_argument('--sobel', action='store_true', help='Sobel filtering')
49 |     parser.add_argument('--clustering', type=str, choices=['Kmeans', 'PIC'],
50 |                         default='Kmeans', help='clustering algorithm (default: Kmeans)')
51 |     parser.add_argument('--nmb_cluster', '--k', type=int, default=100000,
52 |                         help='number of clusters for k-means (default: 100000)')
53 |     parser.add_argument('--lr', default=0.001, type=float,
54 |                         help='learning rate (default: 0.001)')
55 |     parser.add_argument('--wd', default=-5, type=float,
56 |                         help='weight decay pow (default: -5)')
57 |     parser.add_argument('--reassign', type=float, default=10.,
58 |                         help="""how many epochs of training between two consecutive
59 |                         reassignments of clusters (default: 10)""")
60 |     parser.add_argument('--workers', default=4, type=int,
61 |                         help='number of data loading workers (default: 4)')
62 |     parser.add_argument('--epochs', type=int, default=20,
63 |                         help='number of total epochs to run (default: 20)')
64 |     parser.add_argument('--start_epoch', default=0, type=int,
65 |                         help='manual epoch number (useful on restarts) (default: 0)')
66 | parser.add_argument('--batch', default=256, type=int,
67 | help='mini-batch size (default: 256)')
68 | parser.add_argument('--momentum', default=0.9, type=float, help='momentum (default: 0.9)')
69 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
70 | help='path to checkpoint (default: None)')
71 | parser.add_argument('--checkpoints', type=int, default=25000,
72 | help='how many iterations between two checkpoints (default: 25000)')
73 | parser.add_argument('--seed', type=int, default=31, help='random seed (default: 31)')
74 | parser.add_argument('--exp', type=str, default='data/logs/ss', help='path to exp folder')
75 | parser.add_argument('--verbose', action='store_true', default=True, help='chatty')
76 | parser.add_argument('--training-set', default= 'liberty',
77 | help='Other options: notredame, yosemite')
78 | parser.add_argument('--dataroot', type=str,
79 | default='data/',
80 | help='path to dataset')
81 | parser.add_argument('--reduce_dim', type=int, default=64, help='reduce_dim')
82 | parser.add_argument('--descriptor', type=str, default='SIFT', help='descriptor')  # pass --descriptor HardNet so the saved file name matches this script
83 | parser.add_argument('--dataset_names', type=str, default='liberty', help='dataset to train on: liberty, notredame or yosemite')
84 | return parser.parse_args()
85 | 
86 | 
87 | class ArcClassifier(nn.Module):
88 | def __init__(self, dim, num_classes, margin=0.1, gamma=1.0,
89 | trainable_gamma=True, eps=1e-7):
90 | super().__init__()
91 | self.weight = nn.parameter.Parameter(torch.empty([num_classes, dim]))
92 | nn.init.xavier_uniform_(self.weight)
93 | self.margin = margin
94 | self.eps = eps
95 | self.gamma = nn.parameter.Parameter(torch.ones(1) * gamma)
96 | if not trainable_gamma:
97 | self.gamma.requires_grad = False
98 | def forward(self, x, labels):
99 | raw_logits = F.linear(x, F.normalize(self.weight))  # x is already L2-normalised by the MLP, so these are cosines
100 | theta = torch.acos(raw_logits.clamp(-1 + self.eps, 1 - self.eps))
101 | # Only apply margin if theta <= np.pi - self.margin.
102 | # mask = (theta <= np.pi - self.margin)
103 | # marginal_target_logits = torch.where(
104 | # mask, torch.cos(theta + self.margin), raw_logits)
105 | # Only apply margin if it lowers the logit.
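# NOTE on the line below: cos is decreasing on [0, pi], so
# cos(theta + margin) < cos(theta) whenever theta + margin <= pi; taking the
# min applies the penalty only in that regime. Tiny example with margin 0.1:
# at theta = 3.1, cos(3.2) ~ -0.998 > cos(3.1) ~ -0.999 would *raise* the
# target logit, so the min keeps the raw value instead.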
106 | marginal_target_logits = torch.min(torch.cos(theta + self.margin), raw_logits)
107 | one_hot = F.one_hot(labels, num_classes=raw_logits.size(1)).bool()
108 | final_logits = torch.where(one_hot, marginal_target_logits, raw_logits)
109 | final_logits *= self.gamma
110 | return final_logits
111 | 
112 | class MLP(nn.Module):
113 | def __init__(self, n_components=64, hidden=96):  # hidden is accepted but unused: this variant is a single linear projection
114 | super(MLP, self).__init__()
115 | self.classifier = nn.Linear(128, n_components)
116 | self._initialize_weights()
117 | 
118 | def forward(self, x):
119 | x = self.classifier(x)
120 | x = F.normalize(x, dim=1)
121 | return x
122 | 
123 | def _initialize_weights(self):
124 | for m in self.modules():
125 | if isinstance(m, nn.BatchNorm1d):
126 | m.weight.data.fill_(1)
127 | m.bias.data.zero_()
128 | elif isinstance(m, nn.Linear):
129 | m.weight.data.normal_(0, 0.01)
130 | m.bias.data.zero_()
131 | 
132 | def main(args, test_loaders=[]):
133 | # fix random seeds
134 | torch.manual_seed(args.seed)
135 | torch.cuda.manual_seed_all(args.seed)
136 | np.random.seed(args.seed)
137 | 
138 | # CNN
139 | if args.verbose:
140 | print('Architecture: {}'.format(args.arch))
141 | 
142 | n_components = args.reduce_dim  # follow --reduce_dim (was hard-coded to 64, unlike the saved file name)
143 | model = MLP(n_components=n_components,hidden=96)
144 | model.cuda()
145 | 
146 | cudnn.benchmark = True
147 | 
148 | optimizer = torch.optim.Adam(
149 | filter(lambda x: x.requires_grad, model.parameters()),
150 | )
151 | 
152 | # define loss function
153 | criterion = nn.CrossEntropyLoss().cuda()
154 | 
155 | # optionally resume from a checkpoint
156 | if args.resume:
157 | if os.path.isfile(args.resume):
158 | print("=> loading checkpoint '{}'".format(args.resume))
159 | checkpoint = torch.load(args.resume)
160 | args.start_epoch = checkpoint['epoch']
161 | # remove top_layer parameters from checkpoint
162 | for key in list(checkpoint['state_dict']):  # iterate over a copy: deleting keys while iterating raises RuntimeError
163 | if 'top_layer' in key:
164 | del checkpoint['state_dict'][key]
165 | model.load_state_dict(checkpoint['state_dict'])
166 | optimizer.load_state_dict(checkpoint['optimizer'])
167 | print("=> loaded checkpoint '{}' (epoch {})"
168 | .format(args.resume, checkpoint['epoch']))
169 | else:
170 | print("=> no checkpoint found at '{}'".format(args.resume))
171 | 
172 | # creating checkpoint repo
173 | exp_check = os.path.join(args.exp, 'checkpoints')
174 | if not os.path.isdir(exp_check):
175 | os.makedirs(exp_check)
176 | 
177 | # creating cluster assignments log
178 | cluster_log = Logger(os.path.join(args.exp, 'clusters'))
179 | 
180 | # load the data
181 | end = time.time()
182 | 
183 | # clustering algorithm to use
184 | deepcluster = clustering.__dict__[args.clustering](args.nmb_cluster)
185 | 
186 | # training convnet with DeepCluster
187 | for epoch in range(args.start_epoch, args.epochs):
188 | end = time.time()
189 | dataloader = create_train_loaders()
190 | if epoch > 0:
191 | features, descriptors = compute_features(dataloader, model, 450092)  # 450092 = number of patches in the liberty set
192 | else:
193 | features = compute_features_init(dataloader, 450092)
194 | descriptors = features
195 | 
196 | # cluster the features
197 | if args.verbose:
198 | print('Cluster the features')
199 | clustering_loss = deepcluster.cluster(features, verbose=args.verbose)
200 | 
201 | # assign pseudo-labels
202 | if args.verbose:
203 | print('Assign pseudo labels')
204 | train_dataset = clustering.cluster_assign(deepcluster.images_lists,
205 | descriptors)
206 | 
207 | # uniformly sample per target
208 | sampler = UnifLabelSampler(int(args.reassign * len(train_dataset)),
209 | deepcluster.images_lists)
210 | 
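# NOTE: UnifLabelSampler (from util.py) draws int(args.reassign * N) indices
# so that every pseudo-class is sampled at roughly the same rate, preventing
# a few huge clusters from dominating the epoch.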
211 | train_dataloader = torch.utils.data.DataLoader(
212 | train_dataset,
213 | batch_size=args.batch,
214 | num_workers=args.workers,
215 | sampler=sampler,
216 | pin_memory=True,
217 | )
218 | 
219 | classifier = ArcClassifier(dim=n_components, num_classes=args.nmb_cluster)  # fresh head each epoch; only the MLP persists
220 | classifier.cuda()
221 | 
222 | # train network with clusters as pseudo-labels
223 | end = time.time()
224 | loss = train(train_dataloader, model, classifier, criterion, optimizer, epoch)
225 | 
226 | # print log
227 | if args.verbose:
228 | print('###### Epoch [{0}] ###### \n'
229 | 'Time: {1:.3f} s\n'
230 | 'Clustering loss: {2:.3f} \n'
231 | 'ConvNet loss: {3:.3f}'
232 | .format(epoch, time.time() - end, clustering_loss, loss))
233 | try:
234 | nmi = normalized_mutual_info_score(
235 | clustering.arrange_clustering(deepcluster.images_lists),
236 | clustering.arrange_clustering(cluster_log.data[-1])
237 | )
238 | print('NMI against previous assignment: {0:.3f}'.format(nmi))
239 | except IndexError:
240 | pass
241 | print('####################### \n')
242 | # save running checkpoint
243 | torch.save({'epoch': epoch + 1,
244 | 'arch': args.arch,
245 | 'state_dict': model.state_dict(),
246 | 'optimizer' : optimizer.state_dict()},
247 | os.path.join(args.exp, 'checkpoint_20000_adam_200.pth.tar'))
248 | 
249 | # save cluster assignments
250 | cluster_log.log(deepcluster.images_lists)
251 | 
252 | file_name = 'models/ss_' + args.descriptor + '_' + str(args.reduce_dim) + '_' + args.dataset_names + '.pth'
253 | torch.save(model.state_dict(), file_name)
254 | for test_loader in test_loaders:
255 | test(test_loader['dataloader'], model, epoch, test_loader['name'])
256 | # test(test_loaders[0]['dataloader'], model, epoch, test_loaders[0]['name'])
257 | 
258 | 
259 | def train(loader, model, classifier, crit, opt, epoch):
260 | batch_time = AverageMeter()
261 | losses = AverageMeter()
262 | data_time = AverageMeter()
263 | forward_time = AverageMeter()
264 | backward_time = AverageMeter()
265 | 
266 | # switch to train mode
267 | model.train()
268 | classifier.train()
269 | optimizer_tl = torch.optim.Adam(
270 | classifier.parameters(),
271 | )
272 | 
273 | end = time.time()
274 | for i, (input_tensor, target) in enumerate(loader):
275 | data_time.update(time.time() - end)
276 | 
277 | # save checkpoint
278 | n = len(loader) * epoch + i
279 | if n % args.checkpoints == 0:
280 | path = os.path.join(
281 | args.exp,
282 | 'checkpoints',
283 | 'checkpoint_' + str(n // args.checkpoints) + '.pth.tar',  # integer division: avoids float file names under Python 3
284 | )
285 | if args.verbose:
286 | print('Save checkpoint at: {0}'.format(path))
287 | torch.save({
288 | 'epoch': epoch + 1,
289 | 'arch': args.arch,
290 | 'state_dict': model.state_dict(),
291 | 'optimizer' : opt.state_dict()
292 | }, path)
293 | 
294 | target = target.cuda()
295 | input_var = torch.autograd.Variable(input_tensor.cuda())
296 | target_var = torch.autograd.Variable(target)
297 | 
298 | output = model(input_var)
299 | output = classifier(output,target_var)
300 | loss = crit(output, target_var)
301 | 
302 | # record loss
303 | losses.update(loss.item(), input_tensor.size(0))
304 | 
305 | # compute gradient and take an optimizer step (Adam for both model and head)
306 | opt.zero_grad()
307 | optimizer_tl.zero_grad()
308 | loss.backward()
309 | opt.step()
310 | optimizer_tl.step()
311 | 
312 | # measure elapsed time
313 | batch_time.update(time.time() - end)
314 | end = time.time()
315 | 
316 | if args.verbose and (i % 200) == 0:
317 | print('Epoch: [{0}][{1}/{2}]\t'
318 | 'Time: {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
319 | 'Data: {data_time.val:.3f} ({data_time.avg:.3f})\t'
320 | 'Loss: 
{loss.val:.4f} ({loss.avg:.4f})' 321 | .format(epoch, i, len(loader), batch_time=batch_time, 322 | data_time=data_time, loss=losses)) 323 | 324 | return losses.avg 325 | 326 | class TripletPhotoTour(dset.PhotoTour): 327 | """ 328 | From the PhotoTour Dataset it generates triplet samples 329 | note: a triplet is composed by a pair of matching images and one of 330 | different class. 331 | """ 332 | def __init__(self, train=True, transform=None, batch_size = None,load_random_triplets = False, *arg, **kw): 333 | super(TripletPhotoTour, self).__init__(*arg, **kw) 334 | self.transform = transform 335 | self.out_triplets = load_random_triplets 336 | self.train = train 337 | self.n_triplets = 5000000 338 | self.batch_size = batch_size 339 | 340 | if self.train: 341 | print('Generating {} triplets'.format(self.n_triplets)) 342 | self.triplets = self.generate_triplets(self.labels, self.n_triplets) 343 | 344 | @staticmethod 345 | def generate_triplets(labels, num_triplets): 346 | def create_indices(_labels): 347 | inds = dict() 348 | for idx, ind in enumerate(_labels): 349 | if ind not in inds: 350 | inds[ind] = [] 351 | inds[ind].append(idx) 352 | return inds 353 | 354 | triplets = [] 355 | indices = create_indices(labels.numpy()) 356 | unique_labels = np.unique(labels.numpy()) 357 | n_classes = unique_labels.shape[0] 358 | # add only unique indices in batch 359 | already_idxs = set() 360 | 361 | for x in tqdm(range(num_triplets)): 362 | if len(already_idxs) >= 1024: 363 | already_idxs = set() 364 | c1 = np.random.randint(0, n_classes) 365 | while c1 in already_idxs: 366 | c1 = np.random.randint(0, n_classes) 367 | already_idxs.add(c1) 368 | c2 = np.random.randint(0, n_classes) 369 | while c1 == c2: 370 | c2 = np.random.randint(0, n_classes) 371 | if len(indices[c1]) == 2: # hack to speed up process 372 | n1, n2 = 0, 1 373 | else: 374 | n1 = np.random.randint(0, len(indices[c1])) 375 | n2 = np.random.randint(0, len(indices[c1])) 376 | while n1 == n2: 377 | n2 = np.random.randint(0, len(indices[c1])) 378 | n3 = np.random.randint(0, len(indices[c2])) 379 | triplets.append([indices[c1][n1], indices[c1][n2], indices[c2][n3]]) 380 | return torch.LongTensor(np.array(triplets)) 381 | 382 | def __getitem__(self, index): 383 | def transform_img(img): 384 | if self.transform is not None: 385 | img = self.transform(img.numpy()) 386 | return img 387 | 388 | if not self.train: 389 | m = self.matches[index] 390 | img1 = transform_img(self.data[m[0]]) 391 | img2 = transform_img(self.data[m[1]]) 392 | return img1, img2, m[2] 393 | 394 | t = self.triplets[index] 395 | a, p, n = self.data[t[0]], self.data[t[1]], self.data[t[2]] 396 | 397 | img_a = transform_img(a) 398 | img_p = transform_img(p) 399 | img_n = None 400 | if self.out_triplets: 401 | img_n = transform_img(n) 402 | # transform images if required 403 | if True: 404 | do_flip = random.random() > 0.5 405 | do_rot = random.random() > 0.5 406 | if do_rot: 407 | img_a = img_a.permute(0,2,1) 408 | img_p = img_p.permute(0,2,1) 409 | if self.out_triplets: 410 | img_n = img_n.permute(0,2,1) 411 | if do_flip: 412 | img_a = torch.from_numpy(deepcopy(img_a.numpy()[:,:,::-1])) 413 | img_p = torch.from_numpy(deepcopy(img_p.numpy()[:,:,::-1])) 414 | if self.out_triplets: 415 | img_n = torch.from_numpy(deepcopy(img_n.numpy()[:,:,::-1])) 416 | if self.out_triplets: 417 | return (img_a, img_p, img_n) 418 | else: 419 | return (img_a, img_p) 420 | 421 | def __len__(self): 422 | if self.train: 423 | return self.triplets.size(0) 424 | else: 425 | return self.matches.size(0) 
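# NOTE on generate_triplets above: each triplet takes (anchor, positive) from
# one 3D-point class c1 and a negative from a different class c2; the
# already_idxs set only enforces class uniqueness within a 1024-draw window,
# mirroring one training batch.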
426 | 427 | def create_loaders(): 428 | 429 | test_dataset_names = copy.copy(dataset_names) 430 | test_dataset_names.remove(args.training_set) 431 | 432 | kwargs = {'num_workers': args.workers, 'pin_memory': True} 433 | 434 | np_reshape64 = lambda x: np.reshape(x, (64, 64, 1)) 435 | transform_test = transforms.Compose([ 436 | transforms.Lambda(np_reshape64), 437 | transforms.ToPILImage(), 438 | transforms.Resize(32), 439 | transforms.ToTensor()]) 440 | 441 | test_loaders = [{'name': name, 442 | 'dataloader': torch.utils.data.DataLoader( 443 | TripletPhotoTour(train=False, 444 | batch_size=1024, 445 | root=args.dataroot, 446 | name=name, 447 | download=True, 448 | transform=transform_test), 449 | batch_size=1024, 450 | shuffle=False, **kwargs)} 451 | for name in test_dataset_names] 452 | 453 | return test_loaders 454 | 455 | class NewPhotoTour(dset.PhotoTour): 456 | """ 457 | From the PhotoTour Dataset it generates triplet samples 458 | note: a triplet is composed by a pair of matching images and one of 459 | different class. 460 | """ 461 | def __init__(self, *arg, **kw): 462 | super(NewPhotoTour, self).__init__(*arg, **kw) 463 | 464 | def __getitem__(self, index): 465 | if self.train: 466 | data = self.data[index] 467 | if self.transform is not None: 468 | data = self.transform(data.numpy()) 469 | return data 470 | 471 | def create_train_loaders(): 472 | kwargs = {'num_workers': args.workers, 'pin_memory': True} 473 | np_reshape64 = lambda x: np.reshape(x, (64, 64, 1)) 474 | transform_train = transforms.Compose([ 475 | transforms.Lambda(np_reshape64), 476 | transforms.ToPILImage(), 477 | transforms.RandomRotation(5,PIL.Image.BILINEAR), 478 | transforms.RandomResizedCrop(32, scale = (0.9,1.0),ratio = (0.9,1.1)), 479 | transforms.Resize(32), 480 | transforms.ToTensor()]) 481 | 482 | train_loader = torch.utils.data.DataLoader( 483 | NewPhotoTour( 484 | root='./data', 485 | name=args.training_set, 486 | train=True, 487 | transform=transform_train, 488 | download=True), 489 | batch_size=args.batch, 490 | shuffle=False, **kwargs) 491 | 492 | return train_loader 493 | 494 | def test(test_loader, model, epoch, logger_test_name): 495 | # switch to evaluate mode 496 | model.eval() 497 | 498 | labels, distances = [], [] 499 | 500 | pbar = tqdm(enumerate(test_loader)) 501 | with torch.no_grad(): 502 | for batch_idx, (data_a, data_p, label) in pbar: 503 | data_a, data_p = data_a.cuda(), data_p.cuda() 504 | out_a = model(HardNet(data_a)) 505 | out_p = model(HardNet(data_p)) 506 | dists = torch.sqrt(torch.sum((out_a - out_p) ** 2, 1)) # euclidean distance 507 | distances.append(dists.data.cpu().numpy().reshape(-1,1)) 508 | ll = label.data.cpu().numpy().reshape(-1, 1) 509 | labels.append(ll) 510 | 511 | if batch_idx % 10 == 0: 512 | pbar.set_description(logger_test_name+' Test Epoch: {} [{}/{} ({:.0f}%)]'.format( 513 | epoch, batch_idx * len(data_a), len(test_loader.dataset), 514 | 100. 
* batch_idx / len(test_loader))) 515 | 516 | num_tests = test_loader.dataset.matches.size(0) 517 | labels = np.vstack(labels).reshape(num_tests) 518 | distances = np.vstack(distances).reshape(num_tests) 519 | 520 | fpr95 = ErrorRateAt95Recall(labels, 1.0 / (distances + 1e-8)) 521 | print('\33[91mTest set: Accuracy(FPR95): {:.8f}\n\33[0m'.format(fpr95)) 522 | 523 | 524 | def compute_features(dataloader, model, N): 525 | if args.verbose: 526 | print('Compute features') 527 | batch_time = AverageMeter() 528 | end = time.time() 529 | model.eval() 530 | # discard the label information in the dataloader 531 | with torch.no_grad(): 532 | for i, input_tensor in enumerate(dataloader): 533 | input_var = input_tensor.cuda() 534 | des = HardNet(input_var) 535 | aux = model(des).cpu().detach().numpy() 536 | descr = des.cpu().detach().numpy() 537 | 538 | if i == 0: 539 | features = np.zeros((N, aux.shape[1]), dtype='float32') 540 | descrs = np.zeros((N, descr.shape[1]), dtype='float32') 541 | 542 | aux = aux.astype('float32') 543 | if i < len(dataloader) - 1: 544 | features[i * args.batch: (i + 1) * args.batch] = aux 545 | descrs[i * args.batch: (i + 1) * args.batch] = descr 546 | else: 547 | # special treatment for final batch 548 | features[i * args.batch:] = aux 549 | descrs[i * args.batch:] = descr 550 | 551 | # measure elapsed time 552 | batch_time.update(time.time() - end) 553 | end = time.time() 554 | 555 | if args.verbose and (i % 200) == 0: 556 | print('{0} / {1}\t' 557 | 'Time: {batch_time.val:.3f} ({batch_time.avg:.3f})' 558 | .format(i, len(dataloader), batch_time=batch_time)) 559 | return features, descrs 560 | 561 | def compute_features_init(dataloader, N): 562 | if args.verbose: 563 | print('Compute features init') 564 | batch_time = AverageMeter() 565 | end = time.time() 566 | # discard the label information in the dataloader 567 | for i, input_tensor in enumerate(dataloader): 568 | aux = HardNet(input_tensor.cuda()).cpu().detach().numpy() 569 | 570 | if i == 0: 571 | features = np.zeros((N, aux.shape[1]), dtype='float32') 572 | 573 | aux = aux.astype('float32') 574 | if i < len(dataloader) - 1: 575 | features[i * args.batch: (i + 1) * args.batch] = aux 576 | else: 577 | # special treatment for final batch 578 | features[i * args.batch:] = aux 579 | 580 | # measure elapsed time 581 | batch_time.update(time.time() - end) 582 | end = time.time() 583 | 584 | if args.verbose and (i % 200) == 0: 585 | print('{0} / {1}\t' 586 | 'Time: {batch_time.val:.3f} ({batch_time.avg:.3f})' 587 | .format(i, len(dataloader), batch_time=batch_time)) 588 | return features 589 | 590 | 591 | if __name__ == '__main__': 592 | args = parse_args() 593 | test_loader = create_loaders() 594 | #test_loader = [] 595 | main(args,test_loader) 596 | -------------------------------------------------------------------------------- /dr_ss_SIFT.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import pickle 4 | import time 5 | import random 6 | import faiss 7 | import numpy as np 8 | from copy import deepcopy 9 | from sklearn.metrics.cluster import normalized_mutual_info_score 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.parallel 13 | import torch.backends.cudnn as cudnn 14 | import torch.optim 15 | import torch.utils.data 16 | import torchvision.transforms as transforms 17 | import torchvision.datasets as datasets 18 | import torchvision.datasets as dset 19 | from torch.utils.data import Dataset 20 | from torch.autograd import 
Variable
21 | import clustering
22 | from util import AverageMeter, Logger, UnifLabelSampler
23 | from tqdm import tqdm
24 | import kornia
25 | import copy
26 | import PIL
27 | import torch.nn.functional as F
28 | 
29 | SIFT = kornia.feature.SIFTDescriptor(32, 8, 4, False).cuda()
30 | SIFT.eval()
31 | 
32 | def ErrorRateAt95Recall(labels, scores):
33 | distances = 1.0 / (scores + 1e-8)
34 | recall_point = 0.95
35 | labels = labels[np.argsort(distances)]
36 | threshold_index = np.argmax(np.cumsum(labels) >= recall_point * np.sum(labels))
37 | 
38 | FP = np.sum(labels[:threshold_index] == 0) # Below threshold (i.e., labelled positive), but should be negative
39 | TN = np.sum(labels[threshold_index:] == 0) # Above threshold (i.e., labelled negative), and should be negative
40 | return float(FP) / float(FP + TN)
41 | 
42 | dataset_names = ['liberty', 'notredame', 'yosemite']
43 | def parse_args():
44 | parser = argparse.ArgumentParser(description='PyTorch Implementation of DeepCluster')
45 | parser.add_argument('--arch', '-a', type=str, metavar='ARCH',
46 | choices=['sift', 'hardnet'], default='sift',
47 | help='architecture (default: sift)')
48 | parser.add_argument('--sobel', action='store_true', help='Sobel filtering')
49 | parser.add_argument('--clustering', type=str, choices=['Kmeans', 'PIC'],
50 | default='Kmeans', help='clustering algorithm (default: Kmeans)')
51 | parser.add_argument('--nmb_cluster', '--k', type=int, default=100000,
52 | help='number of clusters for k-means (default: 100000)')
53 | parser.add_argument('--lr', default=0.001, type=float,
54 | help='learning rate (default: 0.001)')
55 | parser.add_argument('--wd', default=-5, type=float,
56 | help='weight decay pow (default: -5)')
57 | parser.add_argument('--reassign', type=float, default=10.,
58 | help="""how many epochs of training between two consecutive
59 | reassignments of clusters (default: 10)""")
60 | parser.add_argument('--workers', default=4, type=int,
61 | help='number of data loading workers (default: 4)')
62 | parser.add_argument('--epochs', type=int, default=20,
63 | help='number of total epochs to run (default: 20)')
64 | parser.add_argument('--start_epoch', default=0, type=int,
65 | help='manual epoch number (useful on restarts) (default: 0)')
66 | parser.add_argument('--batch', default=256, type=int,
67 | help='mini-batch size (default: 256)')
68 | parser.add_argument('--momentum', default=0.9, type=float, help='momentum (default: 0.9)')
69 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
70 | help='path to checkpoint (default: None)')
71 | parser.add_argument('--checkpoints', type=int, default=25000,
72 | help='how many iterations between two checkpoints (default: 25000)')
73 | parser.add_argument('--seed', type=int, default=31, help='random seed (default: 31)')
74 | parser.add_argument('--exp', type=str, default='data/logs/ss', help='path to exp folder')
75 | parser.add_argument('--verbose', action='store_true', default=True, help='chatty')
76 | parser.add_argument('--training-set', default= 'liberty',
77 | help='Other options: notredame, yosemite')
78 | parser.add_argument('--dataroot', type=str,
79 | default='data/',
80 | help='path to dataset')
81 | parser.add_argument('--reduce_dim', type=int, default=64, help='reduce_dim')
82 | parser.add_argument('--descriptor', type=str, default='SIFT', help='descriptor')
83 | parser.add_argument('--dataset_names', type=str, default='liberty', help='dataset to train on: liberty, notredame or yosemite')
84 | return parser.parse_args()
85 | 
86 | 
87 | class ArcClassifier(nn.Module):
88 | def __init__(self, dim, num_classes, margin=0.1, gamma=1.0,
89 | trainable_gamma=True, eps=1e-7):
90 | super().__init__()
91 | self.weight = nn.parameter.Parameter(torch.empty([num_classes, dim]))
92 | nn.init.xavier_uniform_(self.weight)
93 | self.margin = margin
94 | self.eps = eps
95 | self.gamma = nn.parameter.Parameter(torch.ones(1) * gamma)
96 | if not trainable_gamma:
97 | self.gamma.requires_grad = False
98 | def forward(self, x, labels):
99 | raw_logits = F.linear(x, F.normalize(self.weight))
100 | theta = torch.acos(raw_logits.clamp(-1 + self.eps, 1 - self.eps))
101 | # Only apply margin if theta <= np.pi - self.margin.
102 | # mask = (theta <= np.pi - self.margin)
103 | # marginal_target_logits = torch.where(
104 | # mask, torch.cos(theta + self.margin), raw_logits)
105 | # Only apply margin if it lowers the logit.
106 | marginal_target_logits = torch.min(torch.cos(theta + self.margin), raw_logits)
107 | one_hot = F.one_hot(labels, num_classes=raw_logits.size(1)).bool()
108 | final_logits = torch.where(one_hot, marginal_target_logits, raw_logits)
109 | final_logits *= self.gamma
110 | return final_logits
111 | 
112 | class MLP(nn.Module):
113 | def __init__(self, n_components=64, hidden=96):
114 | super(MLP, self).__init__()
115 | 
116 | self.classifier = nn.Sequential(
117 | nn.Linear(128, hidden),
118 | nn.ReLU(inplace=True),
119 | nn.BatchNorm1d(hidden),
120 | nn.Linear(hidden, n_components),
121 | )
122 | self._initialize_weights()
123 | 
124 | def forward(self, x):
125 | x = self.classifier(x)
126 | x = F.normalize(x, dim=1)
127 | return x
128 | 
129 | def _initialize_weights(self):
130 | for m in self.modules():
131 | if isinstance(m, nn.BatchNorm1d):
132 | m.weight.data.fill_(1)
133 | m.bias.data.zero_()
134 | elif isinstance(m, nn.Linear):
135 | m.weight.data.normal_(0, 0.01)
136 | m.bias.data.zero_()
137 | 
138 | def main(args, test_loaders=[]):
139 | # fix random seeds
140 | torch.manual_seed(args.seed)
141 | torch.cuda.manual_seed_all(args.seed)
142 | np.random.seed(args.seed)
143 | 
144 | if args.verbose:
145 | print('Architecture: {}'.format(args.arch))
146 | 
147 | model = MLP(n_components=args.reduce_dim, hidden=512)
148 | model.cuda()
149 | 
150 | cudnn.benchmark = True
151 | 
152 | optimizer = torch.optim.Adam(
153 | filter(lambda x: x.requires_grad, model.parameters()),
154 | )
155 | 
156 | criterion = nn.CrossEntropyLoss().cuda()
157 | 
158 | # optionally resume from a checkpoint
159 | if args.resume:
160 | if os.path.isfile(args.resume):
161 | print("=> loading checkpoint '{}'".format(args.resume))
162 | checkpoint = torch.load(args.resume)
163 | args.start_epoch = checkpoint['epoch']
164 | # remove top_layer parameters from checkpoint
165 | for key in list(checkpoint['state_dict']):  # iterate over a copy: deleting keys while iterating raises RuntimeError
166 | if 'top_layer' in key:
167 | del checkpoint['state_dict'][key]
168 | model.load_state_dict(checkpoint['state_dict'])
169 | optimizer.load_state_dict(checkpoint['optimizer'])
170 | print("=> loaded checkpoint '{}' (epoch {})"
171 | .format(args.resume, checkpoint['epoch']))
172 | else:
173 | print("=> no checkpoint found at '{}'".format(args.resume))
174 | 
175 | # creating checkpoint repo
176 | exp_check = os.path.join(args.exp, 'checkpoints')
177 | if not os.path.isdir(exp_check):
178 | os.makedirs(exp_check)
179 | 
180 | # creating cluster assignments log
181 | cluster_log = Logger(os.path.join(args.exp, 'clusters'))
182 | 
183 | # load the data
184 | end = time.time()
185 | 
186 | # clustering algorithm to use
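# NOTE: the DeepCluster loop below alternates, once per epoch:
#   (1) embed all patches with the current MLP (raw descriptors at epoch 0),
#   (2) k-means the embeddings into args.nmb_cluster pseudo-classes,
#   (3) train the MLP plus a freshly initialised ArcClassifier head on those
#       pseudo-labels via cross-entropy.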
187 | deepcluster = clustering.__dict__[args.clustering](args.nmb_cluster)
188 | 
189 | # training convnet with DeepCluster
190 | for epoch in range(args.start_epoch, args.epochs):
191 | end = time.time()
192 | 
193 | dataloader = create_train_loaders()
194 | if epoch > 0:
195 | features, descriptors = compute_features(dataloader, model, 450092)  # 450092 = number of patches in the liberty set
196 | else:
197 | features = compute_features_init(dataloader, 450092)
198 | descriptors = features
199 | 
200 | # cluster the features
201 | if args.verbose:
202 | print('Cluster the features')
203 | clustering_loss = deepcluster.cluster(features, verbose=args.verbose)
204 | 
205 | # assign pseudo-labels
206 | if args.verbose:
207 | print('Assign pseudo labels')
208 | train_dataset = clustering.cluster_assign(deepcluster.images_lists,
209 | descriptors)
210 | 
211 | # uniformly sample per target
212 | sampler = UnifLabelSampler(int(args.reassign * len(train_dataset)),
213 | deepcluster.images_lists)
214 | 
215 | train_dataloader = torch.utils.data.DataLoader(
216 | train_dataset,
217 | batch_size=args.batch,
218 | num_workers=args.workers,
219 | sampler=sampler,
220 | pin_memory=True,
221 | )
222 | 
223 | classifier = ArcClassifier(dim=args.reduce_dim, num_classes=args.nmb_cluster)
224 | classifier.cuda()
225 | 
226 | # train network with clusters as pseudo-labels
227 | end = time.time()
228 | loss = train(train_dataloader, model, classifier, criterion, optimizer, epoch)
229 | 
230 | # print log
231 | if args.verbose:
232 | print('###### Epoch [{0}] ###### \n'
233 | 'Time: {1:.3f} s\n'
234 | 'Clustering loss: {2:.3f} \n'
235 | 'ConvNet loss: {3:.3f}'
236 | .format(epoch, time.time() - end, clustering_loss, loss))
237 | try:
238 | nmi = normalized_mutual_info_score(
239 | clustering.arrange_clustering(deepcluster.images_lists),
240 | clustering.arrange_clustering(cluster_log.data[-1])
241 | )
242 | print('NMI against previous assignment: {0:.3f}'.format(nmi))
243 | except IndexError:
244 | pass
245 | print('####################### \n')
246 | # save running checkpoint
247 | torch.save({'epoch': epoch + 1,
248 | 'arch': args.arch,
249 | 'state_dict': model.state_dict(),
250 | 'optimizer' : optimizer.state_dict()},
251 | os.path.join(args.exp, 'checkpoint_20000_adam_200.pth.tar'))
252 | 
253 | # save cluster assignments
254 | cluster_log.log(deepcluster.images_lists)
255 | 
256 | 
257 | file_name = 'models/ss_' + args.descriptor + '_' + str(args.reduce_dim) + '_' + args.dataset_names + '.pth'
258 | torch.save(model.state_dict(), file_name)
259 | test(test_loaders[0]['dataloader'], model, epoch, test_loaders[0]['name'])  # only the first held-out set is evaluated here
260 | 
261 | 
262 | def train(loader, model, classifier, crit, opt, epoch):
263 | batch_time = AverageMeter()
264 | losses = AverageMeter()
265 | data_time = AverageMeter()
266 | 
267 | # switch to train mode
268 | model.train()
269 | classifier.train()
270 | 
271 | optimizer_tl = torch.optim.Adam(
272 | classifier.parameters(),
273 | )
274 | 
275 | end = time.time()
276 | for i, (input_tensor, target) in enumerate(loader):
277 | data_time.update(time.time() - end)
278 | 
279 | # save checkpoint
280 | n = len(loader) * epoch + i
281 | if n % args.checkpoints == 0:
282 | path = os.path.join(
283 | args.exp,
284 | 'checkpoints',
285 | 'checkpoint_' + str(n // args.checkpoints) + '.pth.tar',  # integer division: avoids float file names under Python 3
286 | )
287 | if args.verbose:
288 | print('Save checkpoint at: {0}'.format(path))
289 | torch.save({
290 | 'epoch': epoch + 1,
291 | 'arch': args.arch,
292 | 'state_dict': model.state_dict(),
293 | 'optimizer' : opt.state_dict()
294 | }, path)
295 | 
296 | target = target.cuda()
297 | 
input_var = torch.autograd.Variable(input_tensor.cuda()) 298 | target_var = torch.autograd.Variable(target) 299 | 300 | output = model(input_var) 301 | output = classifier(output,target_var) 302 | loss = crit(output, target_var) 303 | 304 | # record loss 305 | losses.update(loss.item(), input_tensor.size(0)) 306 | 307 | # compute gradient and do SGD step 308 | opt.zero_grad() 309 | optimizer_tl.zero_grad() 310 | loss.backward() 311 | opt.step() 312 | optimizer_tl.step() 313 | 314 | # measure elapsed time 315 | batch_time.update(time.time() - end) 316 | end = time.time() 317 | 318 | if args.verbose and (i % 200) == 0: 319 | print('Epoch: [{0}][{1}/{2}]\t' 320 | 'Time: {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 321 | 'Data: {data_time.val:.3f} ({data_time.avg:.3f})\t' 322 | 'Loss: {loss.val:.4f} ({loss.avg:.4f})' 323 | .format(epoch, i, len(loader), batch_time=batch_time, 324 | data_time=data_time, loss=losses)) 325 | 326 | return losses.avg 327 | 328 | class TripletPhotoTour(dset.PhotoTour): 329 | """ 330 | From the PhotoTour Dataset it generates triplet samples 331 | note: a triplet is composed by a pair of matching images and one of 332 | different class. 333 | """ 334 | def __init__(self, train=True, transform=None, batch_size = None,load_random_triplets = False, *arg, **kw): 335 | super(TripletPhotoTour, self).__init__(*arg, **kw) 336 | self.transform = transform 337 | self.out_triplets = load_random_triplets 338 | self.train = train 339 | self.n_triplets = 5000000 340 | self.batch_size = batch_size 341 | 342 | if self.train: 343 | print('Generating {} triplets'.format(self.n_triplets)) 344 | self.triplets = self.generate_triplets(self.labels, self.n_triplets) 345 | 346 | @staticmethod 347 | def generate_triplets(labels, num_triplets): 348 | def create_indices(_labels): 349 | inds = dict() 350 | for idx, ind in enumerate(_labels): 351 | if ind not in inds: 352 | inds[ind] = [] 353 | inds[ind].append(idx) 354 | return inds 355 | 356 | triplets = [] 357 | indices = create_indices(labels.numpy()) 358 | unique_labels = np.unique(labels.numpy()) 359 | n_classes = unique_labels.shape[0] 360 | # add only unique indices in batch 361 | already_idxs = set() 362 | 363 | for x in tqdm(range(num_triplets)): 364 | if len(already_idxs) >= 1024: 365 | already_idxs = set() 366 | c1 = np.random.randint(0, n_classes) 367 | while c1 in already_idxs: 368 | c1 = np.random.randint(0, n_classes) 369 | already_idxs.add(c1) 370 | c2 = np.random.randint(0, n_classes) 371 | while c1 == c2: 372 | c2 = np.random.randint(0, n_classes) 373 | if len(indices[c1]) == 2: # hack to speed up process 374 | n1, n2 = 0, 1 375 | else: 376 | n1 = np.random.randint(0, len(indices[c1])) 377 | n2 = np.random.randint(0, len(indices[c1])) 378 | while n1 == n2: 379 | n2 = np.random.randint(0, len(indices[c1])) 380 | n3 = np.random.randint(0, len(indices[c2])) 381 | triplets.append([indices[c1][n1], indices[c1][n2], indices[c2][n3]]) 382 | return torch.LongTensor(np.array(triplets)) 383 | 384 | def __getitem__(self, index): 385 | def transform_img(img): 386 | if self.transform is not None: 387 | img = self.transform(img.numpy()) 388 | return img 389 | 390 | if not self.train: 391 | m = self.matches[index] 392 | img1 = transform_img(self.data[m[0]]) 393 | img2 = transform_img(self.data[m[1]]) 394 | return img1, img2, m[2] 395 | 396 | t = self.triplets[index] 397 | a, p, n = self.data[t[0]], self.data[t[1]], self.data[t[2]] 398 | 399 | img_a = transform_img(a) 400 | img_p = transform_img(p) 401 | img_n = None 402 | if 
self.out_triplets: 403 | img_n = transform_img(n) 404 | # transform images if required 405 | if True: 406 | do_flip = random.random() > 0.5 407 | do_rot = random.random() > 0.5 408 | if do_rot: 409 | img_a = img_a.permute(0,2,1) 410 | img_p = img_p.permute(0,2,1) 411 | if self.out_triplets: 412 | img_n = img_n.permute(0,2,1) 413 | if do_flip: 414 | img_a = torch.from_numpy(deepcopy(img_a.numpy()[:,:,::-1])) 415 | img_p = torch.from_numpy(deepcopy(img_p.numpy()[:,:,::-1])) 416 | if self.out_triplets: 417 | img_n = torch.from_numpy(deepcopy(img_n.numpy()[:,:,::-1])) 418 | if self.out_triplets: 419 | return (img_a, img_p, img_n) 420 | else: 421 | return (img_a, img_p) 422 | 423 | def __len__(self): 424 | if self.train: 425 | return self.triplets.size(0) 426 | else: 427 | return self.matches.size(0) 428 | 429 | def create_loaders(): 430 | 431 | test_dataset_names = copy.copy(dataset_names) 432 | test_dataset_names.remove(args.training_set) 433 | 434 | kwargs = {'num_workers': args.workers, 'pin_memory': True} 435 | 436 | np_reshape64 = lambda x: np.reshape(x, (64, 64, 1)) 437 | transform_test = transforms.Compose([ 438 | transforms.Lambda(np_reshape64), 439 | transforms.ToPILImage(), 440 | transforms.Resize(32), 441 | transforms.ToTensor()]) 442 | 443 | test_loaders = [{'name': name, 444 | 'dataloader': torch.utils.data.DataLoader( 445 | TripletPhotoTour(train=False, 446 | batch_size=1024, 447 | root=args.dataroot, 448 | name=name, 449 | download=True, 450 | transform=transform_test), 451 | batch_size=1024, 452 | shuffle=False, **kwargs)} 453 | for name in test_dataset_names] 454 | 455 | return test_loaders 456 | 457 | class NewPhotoTour(dset.PhotoTour): 458 | """ 459 | From the PhotoTour Dataset it generates triplet samples 460 | note: a triplet is composed by a pair of matching images and one of 461 | different class. 
462 | """
463 | def __init__(self, *arg, **kw):
464 | super(NewPhotoTour, self).__init__(*arg, **kw)
465 | 
466 | def __getitem__(self, index):
467 | if self.train:  # only the train branch is ever used; there is no test path here
468 | data = self.data[index]
469 | if self.transform is not None:
470 | data = self.transform(data.numpy())
471 | return data
472 | 
473 | def create_train_loaders():
474 | kwargs = {'num_workers': args.workers, 'pin_memory': True}
475 | np_reshape64 = lambda x: np.reshape(x, (64, 64, 1))
476 | transform_train = transforms.Compose([
477 | transforms.Lambda(np_reshape64),
478 | transforms.ToPILImage(),
479 | transforms.RandomRotation(5,PIL.Image.BILINEAR),
480 | transforms.RandomResizedCrop(32, scale = (0.9,1.0),ratio = (0.9,1.1)),
481 | transforms.Resize(32),
482 | transforms.ToTensor()])
483 | 
484 | train_loader = torch.utils.data.DataLoader(
485 | NewPhotoTour(
486 | root='./data',
487 | name=args.training_set,  # honour --training-set, as in dr_ss_HardNet.py (was hard-coded to 'liberty')
488 | train=True,
489 | transform=transform_train,
490 | download=True),
491 | batch_size=args.batch,
492 | shuffle=False, **kwargs)
493 | 
494 | return train_loader
495 | 
496 | def test(test_loader, model, epoch, logger_test_name):
497 | # switch to evaluate mode
498 | model.eval()
499 | 
500 | labels, distances = [], []
501 | 
502 | pbar = tqdm(enumerate(test_loader))
503 | with torch.no_grad():
504 | for batch_idx, (data_a, data_p, label) in pbar:
505 | data_a, data_p = data_a.cuda(), data_p.cuda()
506 | out_a = model(SIFT(data_a))
507 | out_p = model(SIFT(data_p))
508 | dists = torch.sqrt(torch.sum((out_a - out_p) ** 2, 1)) # euclidean distance
509 | distances.append(dists.data.cpu().numpy().reshape(-1,1))
510 | ll = label.data.cpu().numpy().reshape(-1, 1)
511 | labels.append(ll)
512 | 
513 | if batch_idx % 10 == 0:
514 | pbar.set_description(logger_test_name+' Test Epoch: {} [{}/{} ({:.0f}%)]'.format(
515 | epoch, batch_idx * len(data_a), len(test_loader.dataset),
516 | 100. 
* batch_idx / len(test_loader))) 517 | 518 | num_tests = test_loader.dataset.matches.size(0) 519 | labels = np.vstack(labels).reshape(num_tests) 520 | distances = np.vstack(distances).reshape(num_tests) 521 | 522 | fpr95 = ErrorRateAt95Recall(labels, 1.0 / (distances + 1e-8)) 523 | print('\33[91mTest set: Accuracy(FPR95): {:.8f}\n\33[0m'.format(fpr95)) 524 | 525 | 526 | def compute_features(dataloader, model, N): 527 | if args.verbose: 528 | print('Compute features') 529 | batch_time = AverageMeter() 530 | end = time.time() 531 | model.eval() 532 | # discard the label information in the dataloader 533 | with torch.no_grad(): 534 | for i, input_tensor in enumerate(dataloader): 535 | input_var = input_tensor.cuda() 536 | des = SIFT(input_var) 537 | aux = model(des).cpu().detach().numpy() 538 | descr = des.cpu().detach().numpy() 539 | 540 | if i == 0: 541 | features = np.zeros((N, aux.shape[1]), dtype='float32') 542 | descrs = np.zeros((N, descr.shape[1]), dtype='float32') 543 | 544 | aux = aux.astype('float32') 545 | if i < len(dataloader) - 1: 546 | features[i * args.batch: (i + 1) * args.batch] = aux 547 | descrs[i * args.batch: (i + 1) * args.batch] = descr 548 | else: 549 | # special treatment for final batch 550 | features[i * args.batch:] = aux 551 | descrs[i * args.batch:] = descr 552 | 553 | # measure elapsed time 554 | batch_time.update(time.time() - end) 555 | end = time.time() 556 | 557 | if args.verbose and (i % 200) == 0: 558 | print('{0} / {1}\t' 559 | 'Time: {batch_time.val:.3f} ({batch_time.avg:.3f})' 560 | .format(i, len(dataloader), batch_time=batch_time)) 561 | return features, descrs 562 | 563 | def compute_features_init(dataloader, N): 564 | if args.verbose: 565 | print('Compute features init') 566 | batch_time = AverageMeter() 567 | end = time.time() 568 | # discard the label information in the dataloader 569 | for i, input_tensor in enumerate(dataloader): 570 | aux = SIFT(input_tensor.cuda()).cpu().detach().numpy() 571 | #aux = input_tensor.cpu().numpy() 572 | 573 | if i == 0: 574 | features = np.zeros((N, aux.shape[1]), dtype='float32') 575 | 576 | aux = aux.astype('float32') 577 | if i < len(dataloader) - 1: 578 | features[i * args.batch: (i + 1) * args.batch] = aux 579 | else: 580 | # special treatment for final batch 581 | features[i * args.batch:] = aux 582 | 583 | # measure elapsed time 584 | batch_time.update(time.time() - end) 585 | end = time.time() 586 | 587 | if args.verbose and (i % 200) == 0: 588 | print('{0} / {1}\t' 589 | 'Time: {batch_time.val:.3f} ({batch_time.avg:.3f})' 590 | .format(i, len(dataloader), batch_time=batch_time)) 591 | return features 592 | 593 | 594 | if __name__ == '__main__': 595 | args = parse_args() 596 | test_loader = create_loaders() 597 | main(args,test_loader) 598 | -------------------------------------------------------------------------------- /dr_sv_HardNet.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from copy import deepcopy 3 | import math 4 | import argparse 5 | import torch 6 | import torch.nn.init 7 | import torch.nn as nn 8 | import torch.optim as optim 9 | import torchvision.datasets as dset 10 | import torchvision.transforms as transforms 11 | from torch.autograd import Variable 12 | import torch.backends.cudnn as cudnn 13 | import os 14 | from tqdm import tqdm 15 | import numpy as np 16 | import random 17 | import cv2 18 | import copy 19 | import PIL 20 | from Losses import loss_HardNet_metric 21 | from Utils import L2Norm, cv2_scale, np_reshape 22 | from 
Utils import str2bool
23 | import torch.nn as nn  # duplicate import; harmless
24 | import torch.nn.functional as F
25 | import kornia
26 | 
27 | def ErrorRateAt95Recall(labels, scores):
28 | distances = 1.0 / (scores + 1e-8)
29 | recall_point = 0.95
30 | labels = labels[np.argsort(distances)]
31 | threshold_index = np.argmax(np.cumsum(labels) >= recall_point * np.sum(labels))
32 | 
33 | FP = np.sum(labels[:threshold_index] == 0) # Below threshold (i.e., labelled positive), but should be negative
34 | TN = np.sum(labels[threshold_index:] == 0) # Above threshold (i.e., labelled negative), and should be negative
35 | return float(FP) / float(FP + TN)
36 | 
37 | HardNet = kornia.feature.HardNet(pretrained=True).cuda()
38 | HardNet.eval()
39 | 
40 | parser = argparse.ArgumentParser(description='PyTorch dr')
41 | parser.add_argument('--dataroot', type=str,
42 | default='data/',
43 | help='path to dataset')
44 | parser.add_argument('--enable-logging',type=str2bool, default=False,
45 | help='output to tensorlogger')
46 | parser.add_argument('--log-dir', default='data/logs/',
47 | help='folder to output log')
48 | parser.add_argument('--model-dir', default='data/models/',
49 | help='folder to output model checkpoints')
50 | parser.add_argument('--experiment-name', default= 'triplet/', #
51 | help='experiment path')
52 | parser.add_argument('--training-set', default= 'liberty',
53 | help='Other options: notredame, yosemite')
54 | parser.add_argument('--loss', default= 'triplet_margin',
55 | help='Other options: softmax, contrastive')
56 | parser.add_argument('--batch-reduce', default= 'min',
57 | help='Other options: average, random, random_global, L2Net')
58 | parser.add_argument('--num-workers', default= 0, type=int,
59 | help='Number of workers to be created')
60 | parser.add_argument('--pin-memory',type=bool, default= True,
61 | help='')
62 | parser.add_argument('--decor',type=str2bool, default = False,
63 | help='L2Net decorrelation penalty')
64 | parser.add_argument('--anchorave', type=str2bool, default=False,
65 | help='anchorave')
66 | parser.add_argument('--imageSize', type=int, default=32,
67 | help='the height / width of the input image to network')
68 | parser.add_argument('--mean-image', type=float, default=0.443728476019,
69 | help='mean of train dataset for normalization')
70 | parser.add_argument('--std-image', type=float, default=0.20197947209,
71 | help='std of train dataset for normalization')
72 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
73 | help='path to latest checkpoint (default: none)')
74 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
75 | help='manual epoch number (useful on restarts)')
76 | parser.add_argument('--epochs', type=int, default=10, metavar='E', #
77 | help='number of epochs to train (default: 10)')
78 | parser.add_argument('--anchorswap', type=str2bool, default=True,
79 | help='turns on anchor swap')
80 | parser.add_argument('--batch-size', type=int, default=1024, metavar='BS', #
81 | help='input batch size for training (default: 1024)')
82 | parser.add_argument('--test-batch-size', type=int, default=1024, metavar='BST',
83 | help='input batch size for testing (default: 1024)')
84 | parser.add_argument('--n-triplets', type=int, default=5000000, metavar='N',
85 | help='how many triplets to generate from the dataset')
86 | parser.add_argument('--margin', type=float, default=1.0, metavar='MARGIN',
87 | help='the margin value for the triplet loss function (default: 1.0)')
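# NOTE: a typical invocation (hypothetical values) would be
#   python dr_sv_HardNet.py --training-set liberty --reduce_dim 32 \
#       --epochs 10 --batch-size 1024 --lr 0.001
# which trains the 128 -> 32 encoder and writes models/HardNet_sv_dim32.pth.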
88 | parser.add_argument('--gor',type=str2bool, default=False,
89 | help='use gor')
90 | parser.add_argument('--freq', type=float, default=10.0,
91 | help='frequency for cyclic learning rate')
92 | parser.add_argument('--alpha', type=float, default=3.0, metavar='ALPHA', #
93 | help='gor parameter')
94 | parser.add_argument('--lr', type=float, default=0.001, metavar='LR', #
95 | help='learning rate (default: 0.001)')
96 | parser.add_argument('--fliprot', type=str2bool, default=True,
97 | help='turns on flip and 90deg rotation augmentation')
98 | parser.add_argument('--augmentation', type=str2bool, default=False, #
99 | help='turns on shift and small scale rotation augmentation')
100 | parser.add_argument('--lr-decay', default=1e-6, type=float, metavar='LRD',
101 | help='learning rate decay ratio (default: 1e-6)')
102 | parser.add_argument('--wd', default=1e-4, type=float,
103 | metavar='W', help='weight decay (default: 1e-4)')
104 | parser.add_argument('--optimizer', default='adam', type=str, #
105 | metavar='OPT', help='The optimizer to use (default: adam)')
106 | # Device options
107 | parser.add_argument('--no-cuda', action='store_true', default=False,
108 | help='disables CUDA training')
109 | parser.add_argument('--gpu-id', default='0', type=str,
110 | help='id(s) for CUDA_VISIBLE_DEVICES')
111 | parser.add_argument('--seed', type=int, default=0, metavar='S',
112 | help='random seed (default: 0)')
113 | parser.add_argument('--log-interval', type=int, default=10, metavar='LI',
114 | help='how many batches to wait before logging training status')
115 | parser.add_argument('--reduce_dim', type=int, default=64, help='reduce_dim')
116 | parser.add_argument('--descriptor', type=str, default='HardNet', help='descriptor')
117 | args = parser.parse_args()
118 | 
119 | suffix = '{}_{}_{}'.format(args.experiment_name, args.training_set, args.batch_reduce)
120 | 
121 | if args.gor:
122 | suffix = suffix+'_gor_alpha{:1.1f}'.format(args.alpha)
123 | if args.anchorswap:
124 | suffix = suffix + '_as'
125 | if args.anchorave:
126 | suffix = suffix + '_av'
127 | if args.fliprot:
128 | suffix = suffix + '_fliprot'
129 | 
130 | triplet_flag = (args.batch_reduce == 'random_global') or args.gor
131 | 
132 | dataset_names = ['liberty', 'notredame', 'yosemite']
133 | 
134 | # set the device to use by setting CUDA_VISIBLE_DEVICES env variable in
135 | # order to prevent any memory allocation on unused GPUs
136 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id
137 | 
138 | args.cuda = not args.no_cuda and torch.cuda.is_available()
139 | 
140 | print (("NOT " if not args.cuda else "") + "Using cuda")
141 | 
142 | if args.cuda:
143 | cudnn.benchmark = True
144 | torch.cuda.manual_seed_all(args.seed)
145 | torch.backends.cudnn.deterministic = True
146 | 
147 | # create logging directory
148 | if not os.path.exists(args.log_dir):
149 | os.makedirs(args.log_dir)
150 | 
151 | # set random seeds
152 | random.seed(args.seed)
153 | torch.manual_seed(args.seed)
154 | np.random.seed(args.seed)
155 | 
156 | class TripletPhotoTour(dset.PhotoTour):
157 | """
158 | From the PhotoTour Dataset it generates triplet samples
159 | note: a triplet is composed by a pair of matching images and one of
160 | different class. 
161 | """ 162 | def __init__(self, train=True, transform=None, batch_size = None,load_random_triplets = False, *arg, **kw): 163 | super(TripletPhotoTour, self).__init__(*arg, **kw) 164 | self.transform = transform 165 | self.out_triplets = load_random_triplets 166 | self.train = train 167 | self.n_triplets = args.n_triplets 168 | self.batch_size = batch_size 169 | 170 | if self.train: 171 | print('Generating {} triplets'.format(self.n_triplets)) 172 | self.triplets = self.generate_triplets(self.labels, self.n_triplets) 173 | 174 | @staticmethod 175 | def generate_triplets(labels, num_triplets): 176 | def create_indices(_labels): 177 | inds = dict() 178 | for idx, ind in enumerate(_labels): 179 | if ind not in inds: 180 | inds[ind] = [] 181 | inds[ind].append(idx) 182 | return inds 183 | 184 | triplets = [] 185 | indices = create_indices(labels.numpy()) 186 | unique_labels = np.unique(labels.numpy()) 187 | n_classes = unique_labels.shape[0] 188 | # add only unique indices in batch 189 | already_idxs = set() 190 | 191 | for x in tqdm(range(num_triplets)): 192 | if len(already_idxs) >= args.batch_size: 193 | already_idxs = set() 194 | c1 = np.random.randint(0, n_classes) 195 | while c1 in already_idxs: 196 | c1 = np.random.randint(0, n_classes) 197 | already_idxs.add(c1) 198 | c2 = np.random.randint(0, n_classes) 199 | while c1 == c2: 200 | c2 = np.random.randint(0, n_classes) 201 | if len(indices[c1]) == 2: # hack to speed up process 202 | n1, n2 = 0, 1 203 | else: 204 | n1 = np.random.randint(0, len(indices[c1])) 205 | n2 = np.random.randint(0, len(indices[c1])) 206 | while n1 == n2: 207 | n2 = np.random.randint(0, len(indices[c1])) 208 | n3 = np.random.randint(0, len(indices[c2])) 209 | triplets.append([indices[c1][n1], indices[c1][n2], indices[c2][n3]]) 210 | return torch.LongTensor(np.array(triplets)) 211 | 212 | def __getitem__(self, index): 213 | def transform_img(img): 214 | if self.transform is not None: 215 | img = self.transform(img.numpy()) 216 | return img 217 | 218 | if not self.train: 219 | m = self.matches[index] 220 | img1 = transform_img(self.data[m[0]]) 221 | img2 = transform_img(self.data[m[1]]) 222 | return img1, img2, m[2] 223 | 224 | t = self.triplets[index] 225 | a, p, n = self.data[t[0]], self.data[t[1]], self.data[t[2]] 226 | 227 | img_a = transform_img(a) 228 | img_p = transform_img(p) 229 | img_n = None 230 | if self.out_triplets: 231 | img_n = transform_img(n) 232 | # transform images if required 233 | if args.fliprot: 234 | do_flip = random.random() > 0.5 235 | do_rot = random.random() > 0.5 236 | if do_rot: 237 | img_a = img_a.permute(0,2,1) 238 | img_p = img_p.permute(0,2,1) 239 | if self.out_triplets: 240 | img_n = img_n.permute(0,2,1) 241 | if do_flip: 242 | img_a = torch.from_numpy(deepcopy(img_a.numpy()[:,:,::-1])) 243 | img_p = torch.from_numpy(deepcopy(img_p.numpy()[:,:,::-1])) 244 | if self.out_triplets: 245 | img_n = torch.from_numpy(deepcopy(img_n.numpy()[:,:,::-1])) 246 | if self.out_triplets: 247 | return (img_a, img_p, img_n) 248 | else: 249 | return (img_a, img_p) 250 | 251 | def __len__(self): 252 | if self.train: 253 | return self.triplets.size(0) 254 | else: 255 | return self.matches.size(0) 256 | 257 | class Encoder(nn.Module): 258 | def __init__(self, n_components,hidden=96): 259 | super(Encoder, self).__init__() 260 | self.enc_net = nn.Sequential( 261 | nn.Linear(128, hidden), 262 | nn.BatchNorm1d(hidden), 263 | nn.ReLU(), 264 | nn.Linear(hidden, n_components) 265 | ) 266 | def forward(self, x): 267 | output = self.enc_net(x) 268 | output = 
F.normalize(output, dim=1) 269 | return output 270 | 271 | def create_loaders(load_random_triplets = False): 272 | 273 | test_dataset_names = copy.copy(dataset_names) 274 | test_dataset_names.remove(args.training_set) 275 | 276 | kwargs = {'num_workers': args.num_workers, 'pin_memory': args.pin_memory} if args.cuda else {} 277 | 278 | np_reshape64 = lambda x: np.reshape(x, (64, 64, 1)) 279 | transform_test = transforms.Compose([ 280 | transforms.Lambda(np_reshape64), 281 | transforms.ToPILImage(), 282 | transforms.Resize(32), 283 | transforms.ToTensor()]) 284 | transform_train = transforms.Compose([ 285 | transforms.Lambda(np_reshape64), 286 | transforms.ToPILImage(), 287 | transforms.RandomRotation(5,PIL.Image.BILINEAR), 288 | transforms.RandomResizedCrop(32, scale = (0.9,1.0),ratio = (0.9,1.1)), 289 | transforms.Resize(32), 290 | transforms.ToTensor()]) 291 | transform = transforms.Compose([ 292 | transforms.Lambda(cv2_scale), 293 | transforms.Lambda(np_reshape), 294 | transforms.ToTensor(), 295 | transforms.Normalize((args.mean_image,), (args.std_image,))]) 296 | if not args.augmentation: 297 | transform_train = transform 298 | transform_test = transform 299 | train_loader = torch.utils.data.DataLoader( 300 | TripletPhotoTour(train=True, 301 | load_random_triplets = load_random_triplets, 302 | batch_size=args.batch_size, 303 | root=args.dataroot, 304 | name=args.training_set, 305 | download=True, 306 | transform=transform_train), 307 | batch_size=args.batch_size, 308 | shuffle=False, **kwargs) 309 | 310 | test_loaders = [{'name': name, 311 | 'dataloader': torch.utils.data.DataLoader( 312 | TripletPhotoTour(train=False, 313 | batch_size=args.test_batch_size, 314 | root=args.dataroot, 315 | name=name, 316 | download=True, 317 | transform=transform_test), 318 | batch_size=args.test_batch_size, 319 | shuffle=False, **kwargs)} 320 | for name in test_dataset_names] 321 | 322 | return train_loader, test_loaders 323 | 324 | def train(train_loader, model, optimizer, epoch, logger, load_triplets = False): 325 | # switch to train mode 326 | model.train() 327 | pbar = tqdm(enumerate(train_loader)) 328 | for batch_idx, data in pbar: 329 | if load_triplets: 330 | data_a, data_p, data_n = data 331 | else: 332 | data_a, data_p = data 333 | 334 | if args.cuda: 335 | data_a, data_p = data_a.cuda(), data_p.cuda() 336 | out_a_raw = HardNet(data_a) 337 | out_p_raw = HardNet(data_p) 338 | out_a = model(out_a_raw) 339 | out_p = model(out_p_raw) 340 | if load_triplets: 341 | data_n = data_n.cuda() 342 | out_n = model(HardNet(data_n)) 343 | 344 | loss = loss_HardNet_metric(out_a, out_p,out_a_raw,out_p_raw, 345 | margin=args.margin, 346 | anchor_swap=args.anchorswap, 347 | anchor_ave=args.anchorave, 348 | batch_reduce = args.batch_reduce, 349 | loss_type = args.loss, alpha=args.alpha) 350 | 351 | optimizer.zero_grad() 352 | loss.backward() 353 | optimizer.step() 354 | adjust_learning_rate(optimizer) 355 | if batch_idx % args.log_interval == 0: 356 | pbar.set_description( 357 | 'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( 358 | epoch, batch_idx * len(data_a), len(train_loader.dataset), 359 | 100. 
* batch_idx / len(train_loader), 360 | loss.item())) 361 | 362 | if (args.enable_logging): 363 | logger.log_value('loss', loss.item()).step() 364 | 365 | try: 366 | os.stat('{}{}'.format(args.model_dir,suffix)) 367 | except: 368 | os.makedirs('{}{}'.format(args.model_dir,suffix)) 369 | 370 | def test(test_loader, model, epoch, logger, logger_test_name): 371 | # switch to evaluate mode 372 | model.eval() 373 | 374 | labels, distances = [], [] 375 | 376 | pbar = tqdm(enumerate(test_loader)) 377 | with torch.no_grad(): 378 | for batch_idx, (data_a, data_p, label) in pbar: 379 | 380 | if args.cuda: 381 | data_a, data_p = data_a.cuda(), data_p.cuda() 382 | 383 | out_a = model(HardNet(data_a)) 384 | out_p = model(HardNet(data_p)) 385 | dists = torch.sqrt(torch.sum((out_a - out_p) ** 2, 1)) # euclidean distance 386 | distances.append(dists.data.cpu().numpy().reshape(-1,1)) 387 | ll = label.data.cpu().numpy().reshape(-1, 1) 388 | labels.append(ll) 389 | 390 | if batch_idx % args.log_interval == 0: 391 | pbar.set_description(logger_test_name+' Test Epoch: {} [{}/{} ({:.0f}%)]'.format( 392 | epoch, batch_idx * len(data_a), len(test_loader.dataset), 393 | 100. * batch_idx / len(test_loader))) 394 | 395 | num_tests = test_loader.dataset.matches.size(0) 396 | labels = np.vstack(labels).reshape(num_tests) 397 | distances = np.vstack(distances).reshape(num_tests) 398 | 399 | fpr95 = ErrorRateAt95Recall(labels, 1.0 / (distances + 1e-8)) 400 | print('\33[91mTest set: Accuracy(FPR95): {:.8f}\n\33[0m'.format(fpr95)) 401 | 402 | if (args.enable_logging): 403 | logger.log_value(logger_test_name+' fpr95', fpr95) 404 | return fpr95 405 | 406 | def adjust_learning_rate(optimizer): 407 | """Updates the learning rate given the learning rate decay. 408 | The routine has been implemented according to the original Lua SGD optimizer 409 | """ 410 | for group in optimizer.param_groups: 411 | if 'step' not in group: 412 | group['step'] = 0. 413 | else: 414 | group['step'] += 1. 
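# NOTE: the assignment below decays the rate linearly to zero over the run:
#   lr_t = lr * (1 - step * batch_size / (n_triplets * epochs))
# With the defaults (batch 1024, 5e6 triplets, 10 epochs) it reaches zero
# after ~48828 optimizer steps.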
415 | group['lr'] = args.lr * (
416 | 1.0 - float(group['step']) * float(args.batch_size) / (args.n_triplets * float(args.epochs)))
417 | #print("group['lr']: ", group['lr'])
418 | return
419 | 
420 | def create_optimizer(model, new_lr):
421 | # setup optimizer
422 | if args.optimizer == 'sgd':
423 | optimizer = optim.SGD(model.parameters(), lr=new_lr,
424 | momentum=0.9, dampening=0.9,
425 | weight_decay=args.wd)
426 | elif args.optimizer == 'adam':
427 | optimizer = optim.Adam(model.parameters(), lr=new_lr,
428 | weight_decay=args.wd)
429 | else:
430 | raise Exception('Not supported optimizer: {0}'.format(args.optimizer))
431 | return optimizer
432 | 
433 | 
434 | def main(train_loader, test_loaders, model, logger, file_logger):
435 | # print the experiment configuration
436 | print('\nparsed options:\n{}\n'.format(vars(args)))
437 | if args.cuda:
438 | model.cuda()
439 | 
440 | optimizer1 = create_optimizer(model, args.lr)
441 | 
442 | # optionally resume from a checkpoint
443 | if args.resume:
444 | if os.path.isfile(args.resume):
445 | print('=> loading checkpoint {}'.format(args.resume))
446 | checkpoint = torch.load(args.resume)
447 | args.start_epoch = checkpoint['epoch']
448 | # note: only the model weights are restored here, not the optimizer state
449 | model.load_state_dict(checkpoint['state_dict'])
450 | else:
451 | print('=> no checkpoint found at {}'.format(args.resume))
452 | 
453 | file_fpr95 = 'fpr95.npy'
454 | file_fpr95 = os.path.join('{}{}'.format(args.model_dir,suffix), file_fpr95)
455 | fpr95 = []
456 | 
457 | start = args.start_epoch
458 | end = start + args.epochs
459 | for epoch in range(start, end):
460 | 
461 | # iterate over test loaders and test results
462 | train(train_loader, model, optimizer1, epoch, logger, triplet_flag)
463 | for test_loader in test_loaders:
464 | test(test_loader['dataloader'], model, epoch, logger, test_loader['name'])
465 | fpr95_not = test(test_loaders[0]['dataloader'], model, epoch, logger, test_loaders[0]['name'])  # record the first held-out set's fpr95 for the learning curve
466 | fpr95.append(fpr95_not)
467 | np.save(file_fpr95, fpr95)
468 | #randomize train loader batches
469 | if epoch < (end - 1) :
470 | train_loader, test_loaders2 = create_loaders(load_random_triplets=triplet_flag)  # regenerate triplets; the original test loaders are kept
471 | 
472 | 
473 | if __name__ == '__main__':
474 | LOG_DIR = args.log_dir
475 | if not os.path.isdir(LOG_DIR):
476 | os.makedirs(LOG_DIR)
477 | LOG_DIR = os.path.join(args.log_dir, suffix)
478 | DESCS_DIR = os.path.join(LOG_DIR, 'temp_descs')
479 | logger, file_logger = None, None
480 | model = Encoder(n_components=args.reduce_dim,hidden=96)
481 | if(args.enable_logging):
482 | from Loggers import Logger, FileLogger
483 | logger = Logger(LOG_DIR)
484 | # file_logger = FileLogger('./log/' + suffix)
485 | train_loader, test_loaders = create_loaders(load_random_triplets = triplet_flag)
486 | main(train_loader, test_loaders, model, logger, file_logger)
487 | 
488 | file_name = 'models/' + args.descriptor + '_sv_dim' + str(args.reduce_dim) + '.pth'
489 | torch.save(model.state_dict(), file_name)
490 | 
-------------------------------------------------------------------------------- /dr_sv_SIFT.py: --------------------------------------------------------------------------------
1 | import sys
2 | from copy import deepcopy
3 | import math
4 | import argparse
5 | import torch
6 | import torch.nn.init
7 | import torch.nn as nn
8 | import torch.optim as optim
9 | import torchvision.datasets as dset
10 | import torchvision.transforms as transforms
11 | from torch.autograd import Variable
12 | import torch.backends.cudnn as cudnn
13 | import os
14 | from tqdm import tqdm
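# NOTE: judging by the imports, dr_sv_SIFT.py mirrors dr_sv_HardNet.py: the
# same supervised pipeline, but with a frozen kornia SIFTDescriptor frontend
# and the plain loss_HardNet loss instead of loss_HardNet_metric.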
15 | import numpy as np
16 | import random
17 | import cv2
18 | import copy
19 | import PIL
20 | from Losses import loss_HardNet
21 | from Utils import L2Norm, cv2_scale, np_reshape
22 | from Utils import str2bool
24 | import torch.nn.functional as F
25 | import kornia
26 | 
27 | def ErrorRateAt95Recall(labels, scores):
28 |     distances = 1.0 / (scores + 1e-8)
29 |     recall_point = 0.95
30 |     labels = labels[np.argsort(distances)]
31 |     threshold_index = np.argmax(np.cumsum(labels) >= recall_point * np.sum(labels))
32 | 
33 |     FP = np.sum(labels[:threshold_index] == 0)  # predicted positive (below the distance threshold) but actually negative
34 |     TN = np.sum(labels[threshold_index:] == 0)  # predicted negative (above the threshold) and actually negative
35 |     return float(FP) / float(FP + TN)
36 | 
37 | SIFT = kornia.feature.SIFTDescriptor(32, 8, 4, False).cuda()
38 | SIFT.eval()
39 | 
40 | parser = argparse.ArgumentParser(description='PyTorch dr')
41 | parser.add_argument('--dataroot', type=str,
42 |                     default='data/',
43 |                     help='path to dataset')
44 | parser.add_argument('--enable-logging', type=str2bool, default=True,
45 |                     help='enable TensorBoard logging')
46 | parser.add_argument('--log-dir', default='data/logs/',
47 |                     help='folder to output log')
48 | parser.add_argument('--model-dir', default='data/models/',
49 |                     help='folder to output model checkpoints')
50 | parser.add_argument('--experiment-name', default='triplet/',
51 |                     help='experiment path')
52 | parser.add_argument('--training-set', default='liberty',
53 |                     help='Other options: notredame, yosemite')
54 | parser.add_argument('--loss', default='triplet_margin',
55 |                     help='Other options: softmax, contrastive')
56 | parser.add_argument('--batch-reduce', default='min',
57 |                     help='Other options: average, random, random_global, L2Net')
58 | parser.add_argument('--num-workers', default=0, type=int,
59 |                     help='number of data-loading workers to create')
60 | parser.add_argument('--pin-memory', type=bool, default=True,
61 |                     help='use pinned (page-locked) memory for data loading')
62 | parser.add_argument('--decor', type=str2bool, default=False,
63 |                     help='L2Net decorrelation penalty')
64 | parser.add_argument('--anchorave', type=str2bool, default=False,
65 |                     help='use anchor averaging in the loss')
66 | parser.add_argument('--imageSize', type=int, default=32,
67 |                     help='the height / width of the input image to network')
68 | parser.add_argument('--mean-image', type=float, default=0.443728476019,
69 |                     help='mean of train dataset for normalization')
70 | parser.add_argument('--std-image', type=float, default=0.20197947209,
71 |                     help='std of train dataset for normalization')
72 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
73 |                     help='path to latest checkpoint (default: none)')
74 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
75 |                     help='manual epoch number (useful on restarts)')
76 | parser.add_argument('--epochs', type=int, default=10, metavar='E',
77 |                     help='number of epochs to train (default: 10)')
78 | parser.add_argument('--anchorswap', type=str2bool, default=True,
79 |                     help='turns on anchor swap')
80 | parser.add_argument('--batch-size', type=int, default=1024, metavar='BS',
81 |                     help='input batch size for training (default: 1024)')
82 | parser.add_argument('--test-batch-size', type=int, default=1024, metavar='BST',
83 |                     help='input batch size for testing (default: 1024)')
84 | parser.add_argument('--n-triplets', type=int, default=5000000, metavar='N',
85 |                     help='how many triplets to generate from the dataset')
86 | parser.add_argument('--margin', type=float, default=1.0, metavar='MARGIN',
87 |                     help='the margin value for the triplet loss function (default: 1.0)')
88 | parser.add_argument('--gor', type=str2bool, default=False,
89 |                     help='use global orthogonal regularization (GOR)')
90 | parser.add_argument('--freq', type=float, default=10.0,
91 |                     help='frequency for cyclic learning rate')
92 | parser.add_argument('--alpha', type=float, default=1.0, metavar='ALPHA',
93 |                     help='weight of the GOR term')
94 | parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
95 |                     help='learning rate (default: 0.01)')
96 | parser.add_argument('--fliprot', type=str2bool, default=True,
97 |                     help='turns on flip and 90deg rotation augmentation')
98 | parser.add_argument('--augmentation', type=str2bool, default=False,
99 |                     help='turns on shift and small scale rotation augmentation')
100 | parser.add_argument('--lr-decay', default=1e-6, type=float, metavar='LRD',
101 |                     help='learning rate decay ratio (default: 1e-6)')
102 | parser.add_argument('--wd', default=1e-4, type=float,
103 |                     metavar='W', help='weight decay (default: 1e-4)')
104 | parser.add_argument('--optimizer', default='adam', type=str,
105 |                     metavar='OPT', help='the optimizer to use (default: adam)')
106 | # Device options
107 | parser.add_argument('--no-cuda', action='store_true', default=False,
108 |                     help='disables CUDA training')
109 | parser.add_argument('--gpu-id', default='0', type=str,
110 |                     help='id(s) for CUDA_VISIBLE_DEVICES')
111 | parser.add_argument('--seed', type=int, default=0, metavar='S',
112 |                     help='random seed (default: 0)')
113 | parser.add_argument('--log-interval', type=int, default=10, metavar='LI',
114 |                     help='how many batches to wait before logging training status')
115 | parser.add_argument('--reduce_dim', type=int, default=64, help='output dimensionality of the reduced descriptor')
116 | parser.add_argument('--descriptor', type=str, default='SIFT', help='name of the base descriptor (used in the saved model file name)')
117 | args = parser.parse_args()
118 | 
119 | suffix = '{}_{}_{}'.format(args.experiment_name, args.training_set, args.batch_reduce)
120 | 
121 | if args.gor:
122 |     suffix = suffix+'_gor_alpha{:1.1f}'.format(args.alpha)
123 | if args.anchorswap:
124 |     suffix = suffix + '_as'
125 | if args.anchorave:
126 |     suffix = suffix + '_av'
127 | if args.fliprot:
128 |     suffix = suffix + '_fliprot'
129 | 
130 | triplet_flag = (args.batch_reduce == 'random_global') or args.gor
131 | 
132 | dataset_names = ['liberty', 'notredame', 'yosemite']
133 | 
134 | # set the device by setting the CUDA_VISIBLE_DEVICES env variable in
135 | # order to prevent any memory allocation on unused GPUs
136 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id
137 | 
138 | args.cuda = not args.no_cuda and torch.cuda.is_available()
139 | 
140 | print(("NOT " if not args.cuda else "") + "Using cuda")
141 | 
142 | if args.cuda:
143 |     cudnn.benchmark = True
144 |     torch.cuda.manual_seed_all(args.seed)
145 |     torch.backends.cudnn.deterministic = True
146 | 
147 | # create logging directory
148 | if not os.path.exists(args.log_dir):
149 |     os.makedirs(args.log_dir)
150 | 
151 | # set random seeds
152 | random.seed(args.seed)
153 | torch.manual_seed(args.seed)
154 | np.random.seed(args.seed)
155 | 
156 | class TripletPhotoTour(dset.PhotoTour):
157 |     """
158 |     Generates triplet samples from the PhotoTour dataset.
159 |     Note: a triplet is composed of a pair of matching patches and one
160 |     patch from a different class.
161 |     """
162 |     def __init__(self, train=True, transform=None, batch_size=None, load_random_triplets=False, *arg, **kw):
163 |         super(TripletPhotoTour, self).__init__(*arg, **kw)
164 |         self.transform = transform
165 |         self.out_triplets = load_random_triplets
166 |         self.train = train
167 |         self.n_triplets = args.n_triplets
168 |         self.batch_size = batch_size
169 | 
170 |         if self.train:
171 |             print('Generating {} triplets'.format(self.n_triplets))
172 |             self.triplets = self.generate_triplets(self.labels, self.n_triplets)
173 | 
174 |     @staticmethod
175 |     def generate_triplets(labels, num_triplets):
176 |         def create_indices(_labels):
177 |             inds = dict()
178 |             for idx, ind in enumerate(_labels):
179 |                 if ind not in inds:
180 |                     inds[ind] = []
181 |                 inds[ind].append(idx)
182 |             return inds
183 | 
184 |         triplets = []
185 |         indices = create_indices(labels.numpy())
186 |         unique_labels = np.unique(labels.numpy())
187 |         n_classes = unique_labels.shape[0]
188 |         # keep anchor classes unique within one batch
189 |         already_idxs = set()
190 | 
191 |         for _ in tqdm(range(num_triplets)):
192 |             if len(already_idxs) >= args.batch_size:
193 |                 already_idxs = set()
194 |             c1 = np.random.randint(0, n_classes)
195 |             while c1 in already_idxs:
196 |                 c1 = np.random.randint(0, n_classes)
197 |             already_idxs.add(c1)
198 |             c2 = np.random.randint(0, n_classes)
199 |             while c1 == c2:
200 |                 c2 = np.random.randint(0, n_classes)
201 |             if len(indices[c1]) == 2:  # hack to speed up the process
202 |                 n1, n2 = 0, 1
203 |             else:
204 |                 n1 = np.random.randint(0, len(indices[c1]))
205 |                 n2 = np.random.randint(0, len(indices[c1]))
206 |                 while n1 == n2:
207 |                     n2 = np.random.randint(0, len(indices[c1]))
208 |             n3 = np.random.randint(0, len(indices[c2]))
209 |             triplets.append([indices[c1][n1], indices[c1][n2], indices[c2][n3]])
210 |         return torch.LongTensor(np.array(triplets))
211 | 
212 |     def __getitem__(self, index):
213 |         def transform_img(img):
214 |             if self.transform is not None:
215 |                 img = self.transform(img.numpy())
216 |             return img
217 | 
218 |         if not self.train:
219 |             m = self.matches[index]
220 |             img1 = transform_img(self.data[m[0]])
221 |             img2 = transform_img(self.data[m[1]])
222 |             return img1, img2, m[2]
223 | 
224 |         t = self.triplets[index]
225 |         a, p, n = self.data[t[0]], self.data[t[1]], self.data[t[2]]
226 | 
227 |         img_a = transform_img(a)
228 |         img_p = transform_img(p)
229 |         img_n = None
230 |         if self.out_triplets:
231 |             img_n = transform_img(n)
232 |         # apply random flip / 90-degree rotation augmentation if enabled
233 |         if args.fliprot:
234 |             do_flip = random.random() > 0.5
235 |             do_rot = random.random() > 0.5
236 |             if do_rot:
237 |                 img_a = img_a.permute(0,2,1)
238 |                 img_p = img_p.permute(0,2,1)
239 |                 if self.out_triplets:
240 |                     img_n = img_n.permute(0,2,1)
241 |             if do_flip:
242 |                 img_a = torch.from_numpy(deepcopy(img_a.numpy()[:,:,::-1]))
243 |                 img_p = torch.from_numpy(deepcopy(img_p.numpy()[:,:,::-1]))
244 |                 if self.out_triplets:
245 |                     img_n = torch.from_numpy(deepcopy(img_n.numpy()[:,:,::-1]))
246 |         if self.out_triplets:
247 |             return (img_a, img_p, img_n)
248 |         else:
249 |             return (img_a, img_p)
250 | 
251 |     def __len__(self):
252 |         if self.train:
253 |             return self.triplets.size(0)
254 |         else:
255 |             return self.matches.size(0)
256 | 
257 | class L2Norm(nn.Module):
258 |     def __init__(self):
259 |         super(L2Norm, self).__init__()
260 |         self.eps = 1e-10
261 |     def forward(self, x):
262 |         norm = torch.sqrt(torch.sum(x * x, dim=1) + self.eps)
263 |         x = x / norm.unsqueeze(-1).expand_as(x)
264 |         return x
265 | 
266 | class Encoder(nn.Module):
267 |     def __init__(self, n_components, hidden=512):
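        # A three-layer MLP (two Linear+ReLU+BatchNorm stages, then a final Linear)
        # that compresses the 128-D input descriptor to n_components dimensions;
        # forward() L2-normalizes the output so Euclidean distances remain comparable.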
268 | super(Encoder, self).__init__() 269 | self.enc_net = nn.Sequential( 270 | nn.Linear(128, hidden), 271 | nn.ReLU(), 272 | nn.BatchNorm1d(hidden), 273 | nn.Linear(hidden, hidden), 274 | nn.ReLU(), 275 | nn.BatchNorm1d(hidden), 276 | nn.Linear(hidden, n_components) 277 | ) 278 | def forward(self, x): 279 | return L2Norm()(self.enc_net(x)) 280 | 281 | def create_loaders(load_random_triplets = False): 282 | 283 | test_dataset_names = copy.copy(dataset_names) 284 | test_dataset_names.remove(args.training_set) 285 | 286 | kwargs = {'num_workers': args.num_workers, 'pin_memory': args.pin_memory} if args.cuda else {} 287 | 288 | np_reshape64 = lambda x: np.reshape(x, (64, 64, 1)) 289 | transform_test = transforms.Compose([ 290 | transforms.Lambda(np_reshape64), 291 | transforms.ToPILImage(), 292 | transforms.Resize(32), 293 | transforms.ToTensor()]) 294 | transform_train = transforms.Compose([ 295 | transforms.Lambda(np_reshape64), 296 | transforms.ToPILImage(), 297 | transforms.RandomRotation(5,PIL.Image.BILINEAR), 298 | transforms.RandomResizedCrop(32, scale = (0.9,1.0),ratio = (0.9,1.1)), 299 | transforms.Resize(32), 300 | transforms.ToTensor()]) 301 | transform = transforms.Compose([ 302 | transforms.Lambda(cv2_scale), 303 | transforms.Lambda(np_reshape), 304 | transforms.ToTensor(), 305 | transforms.Normalize((args.mean_image,), (args.std_image,))]) 306 | if not args.augmentation: 307 | transform_train = transform 308 | transform_test = transform 309 | train_loader = torch.utils.data.DataLoader( 310 | TripletPhotoTour(train=True, 311 | load_random_triplets = load_random_triplets, 312 | batch_size=args.batch_size, 313 | root=args.dataroot, 314 | name=args.training_set, 315 | download=True, 316 | transform=transform_train), 317 | batch_size=args.batch_size, 318 | shuffle=False, **kwargs) 319 | 320 | test_loaders = [{'name': name, 321 | 'dataloader': torch.utils.data.DataLoader( 322 | TripletPhotoTour(train=False, 323 | batch_size=args.test_batch_size, 324 | root=args.dataroot, 325 | name=name, 326 | download=True, 327 | transform=transform_test), 328 | batch_size=args.test_batch_size, 329 | shuffle=False, **kwargs)} 330 | for name in test_dataset_names] 331 | 332 | return train_loader, test_loaders 333 | 334 | def train(train_loader, model, optimizer, epoch, logger, load_triplets=False): 335 | # switch to train mode 336 | model.train() 337 | pbar = tqdm(enumerate(train_loader)) 338 | for batch_idx, data in pbar: 339 | if load_triplets: 340 | data_a, data_p, data_n = data 341 | else: 342 | data_a, data_p = data 343 | 344 | if args.cuda: 345 | data_a, data_p = data_a.cuda(), data_p.cuda() 346 | out_a = model(SIFT(data_a)) 347 | out_p = model(SIFT(data_p)) 348 | if load_triplets: 349 | data_n = data_n.cuda() 350 | out_n = model(SIFT(data_n)) 351 | 352 | loss = loss_HardNet(out_a, out_p, 353 | margin=args.margin, 354 | anchor_swap=args.anchorswap, 355 | anchor_ave=args.anchorave, 356 | batch_reduce = args.batch_reduce, 357 | loss_type = args.loss) 358 | 359 | optimizer.zero_grad() 360 | loss.backward() 361 | optimizer.step() 362 | adjust_learning_rate(optimizer) 363 | if batch_idx % args.log_interval == 0: 364 | pbar.set_description( 365 | 'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( 366 | epoch, batch_idx * len(data_a), len(train_loader.dataset), 367 | 100. 
* batch_idx / len(train_loader),
368 |                 loss.item()))
369 | 
370 |     if (args.enable_logging):
371 |         logger.log_value('loss', loss.item()).step()
372 | 
373 |     # make sure the checkpoint directory exists
374 |     if not os.path.exists('{}{}'.format(args.model_dir, suffix)):
375 |         os.makedirs('{}{}'.format(args.model_dir, suffix))
377 | 
378 | 
379 | def test(test_loader, model, epoch, logger, logger_test_name):
380 |     # switch to evaluate mode
381 |     model.eval()
382 | 
383 |     labels, distances = [], []
384 | 
385 |     pbar = tqdm(enumerate(test_loader))
386 |     with torch.no_grad():
387 |         for batch_idx, (data_a, data_p, label) in pbar:
388 | 
389 |             if args.cuda:
390 |                 data_a, data_p = data_a.cuda(), data_p.cuda()
391 | 
392 |             out_a = model(SIFT(data_a))
393 |             out_p = model(SIFT(data_p))
394 |             dists = torch.sqrt(torch.sum((out_a - out_p) ** 2, 1))  # euclidean distance
395 |             distances.append(dists.data.cpu().numpy().reshape(-1,1))
396 |             ll = label.data.cpu().numpy().reshape(-1, 1)
397 |             labels.append(ll)
398 | 
399 |             if batch_idx % args.log_interval == 0:
400 |                 pbar.set_description(logger_test_name+' Test Epoch: {} [{}/{} ({:.0f}%)]'.format(
401 |                     epoch, batch_idx * len(data_a), len(test_loader.dataset),
402 |                     100. * batch_idx / len(test_loader)))
403 | 
404 |     num_tests = test_loader.dataset.matches.size(0)
405 |     labels = np.vstack(labels).reshape(num_tests)
406 |     distances = np.vstack(distances).reshape(num_tests)
407 | 
408 |     fpr95 = ErrorRateAt95Recall(labels, 1.0 / (distances + 1e-8))
409 |     print('\33[91mTest set: FPR at 95% recall: {:.8f}\n\33[0m'.format(fpr95))
410 | 
411 |     if (args.enable_logging):
412 |         logger.log_value(logger_test_name+' fpr95', fpr95)
413 |     return fpr95
414 | 
415 | def adjust_learning_rate(optimizer):
416 |     """Linearly decays the learning rate on every optimizer step.
417 |     The schedule follows the original Lua Torch SGD optimizer.
418 |     """
419 |     for group in optimizer.param_groups:
420 |         if 'step' not in group:
421 |             group['step'] = 0.
422 |         else:
423 |             group['step'] += 1.
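        # With the defaults here (lr=0.01, batch_size=1024, n_triplets=5e6, epochs=10),
        # the factor below reaches zero after n_triplets*epochs/batch_size, about 48,828 steps.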
424 |         group['lr'] = args.lr * (
425 |                 1.0 - float(group['step']) * float(args.batch_size) / (args.n_triplets * float(args.epochs)))
426 |     return
427 | 
428 | def create_optimizer(model, new_lr):
429 |     # setup optimizer
430 |     if args.optimizer == 'sgd':
431 |         optimizer = optim.SGD(model.parameters(), lr=new_lr,
432 |                               momentum=0.9, dampening=0.9,
433 |                               weight_decay=args.wd)
434 |     elif args.optimizer == 'adam':
435 |         optimizer = optim.Adam(model.parameters(), lr=new_lr,
436 |                                weight_decay=args.wd)
437 |     else:
438 |         raise ValueError('Unsupported optimizer: {0}'.format(args.optimizer))
439 |     return optimizer
440 | 
441 | 
442 | def main(train_loader, test_loaders, model, logger, file_logger):
443 |     # print the experiment configuration
444 |     print('\nparsed options:\n{}\n'.format(vars(args)))
445 | 
446 |     if args.cuda:
447 |         model.cuda()
448 | 
449 |     optimizer1 = create_optimizer(model, args.lr)
450 | 
451 |     # optionally resume from a checkpoint
452 |     if args.resume:
453 |         if os.path.isfile(args.resume):
454 |             print('=> loading checkpoint {}'.format(args.resume))
455 |             checkpoint = torch.load(args.resume)
456 |             args.start_epoch = checkpoint['epoch']
458 |             model.load_state_dict(checkpoint['state_dict'])
459 |         else:
460 |             print('=> no checkpoint found at {}'.format(args.resume))
461 | 
462 | 
463 |     start = args.start_epoch
464 |     end = start + args.epochs
465 |     for epoch in range(start, end):
466 | 
467 |         # train for one epoch, then evaluate on the first test split
468 |         train(train_loader, model, optimizer1, epoch, logger, triplet_flag)
469 |         test(test_loaders[0]['dataloader'], model, epoch, logger, test_loaders[0]['name'])
470 |         # regenerate random triplets for the next epoch
471 |         if epoch < (end - 1):
472 |             train_loader, _ = create_loaders(load_random_triplets=triplet_flag)
473 | 
474 | 
475 | if __name__ == '__main__':
476 |     LOG_DIR = args.log_dir
477 |     if not os.path.isdir(LOG_DIR):
478 |         os.makedirs(LOG_DIR)
479 |     LOG_DIR = os.path.join(args.log_dir, suffix)
480 |     DESCS_DIR = os.path.join(LOG_DIR, 'temp_descs')
481 |     logger, file_logger = None, None
482 |     model = Encoder(n_components=args.reduce_dim, hidden=512)
483 |     if(args.enable_logging):
484 |         from Loggers import Logger, FileLogger
485 |         logger = Logger(LOG_DIR)
486 |     # file_logger = FileLogger('./log/' + suffix)
487 |     train_loader, test_loaders = create_loaders(load_random_triplets = triplet_flag)
488 |     main(train_loader, test_loaders, model, logger, file_logger)
489 | 
490 |     file_name = 'models/' + args.descriptor + '_sv_dim' + str(args.reduce_dim) + '.pth'
491 |     torch.save(model.state_dict(), file_name)
492 | 
493 | 
-------------------------------------------------------------------------------- /example.py: --------------------------------------------------------------------------------
1 | import cv2
3 | import numpy as np
4 | from sklearn.decomposition import PCA
5 | import pickle as pk
6 | import kornia
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | from torchvision import datasets
11 | import argparse
12 | from torch.utils.data import DataLoader
13 | from tqdm import tqdm
14 | 
15 | parser = argparse.ArgumentParser(description='PyTorch dr')
16 | parser.add_argument('--descriptor', type=str, default='SIFT', help='name of the base descriptor')
17 | parser.add_argument('--dataset_names', type=str, default='liberty', help='dataset name: liberty, notredame, or yosemite')
18 | parser.add_argument('--reduce_dim', type=int, default=64, help='output dimensionality of the reduced descriptor')
19 | args = parser.parse_args()
20 | 
21 | dataset = datasets.PhotoTour(
22 |     root='./data', name=args.dataset_names, train=True, transform=None, download=True)
23 | 
24 | 
25 | def load_model(model, filename, device):
26 |     model.load_state_dict(torch.load(filename, map_location=lambda storage, loc: storage))
27 |     print('Model loaded from %s.' % filename)
28 |     model.to(device)
29 |     model.eval()
30 | 
31 | class L2Norm(nn.Module):
32 |     def __init__(self):
33 |         super(L2Norm, self).__init__()
34 |         self.eps = 1e-10
35 |     def forward(self, x):
36 |         norm = torch.sqrt(torch.sum(x * x, dim=1) + self.eps)
37 |         x = x / norm.unsqueeze(-1).expand_as(x)
38 |         return x
39 | 
40 | class Encoder(nn.Module):
41 |     def __init__(self, n_components, hidden=128):
42 |         super(Encoder, self).__init__()
43 |         self.enc_net = nn.Sequential(
44 |             nn.Linear(128, hidden),
45 |             nn.ReLU(),
46 |             nn.BatchNorm1d(hidden),
47 |             nn.Linear(hidden, hidden),
48 |             nn.ReLU(),
49 |             nn.BatchNorm1d(hidden),
50 |             nn.Linear(hidden, n_components)
51 |         )
52 | 
53 |     def forward(self, x):
54 |         return L2Norm()(self.enc_net(x))
55 | 
56 | SIFT = kornia.feature.SIFTDescriptor(32, 8, 4, False).cuda()
57 | device = torch.device('cuda:0')
58 | n_components = 64
59 | encoder = Encoder(n_components=n_components, hidden=512)
60 | load_model(encoder, 'models/SIFT_sv_dim64.pth', device)
61 | 
62 | cv2_scale = lambda x: cv2.resize(x, dsize=(32, 32),
63 |                                  interpolation=cv2.INTER_LINEAR)
64 | 
65 | dataloader = DataLoader(dataset, batch_size=128, shuffle=False)
67 | 
68 | for batch_idx, patches in enumerate(tqdm(dataloader)):
69 |     patches_32 = np.empty([0,1,32,32])
70 |     patches = patches.cpu().detach().numpy()
71 |     for i in range(patches.shape[0]):
72 |         patch = cv2_scale(patches[i])
73 |         patch = np.expand_dims(patch, axis=0)
74 |         patch = np.expand_dims(patch, axis=0)
75 |         patches_32 = np.concatenate((patches_32, patch), axis=0)
76 |     descs = SIFT(torch.from_numpy(patches_32).float().cuda())  # original SIFT descriptor (128-D)
77 |     print(descs.shape)
78 |     descs_dr = encoder(descs)  # reduced SIFT descriptor (64-D), to be used in downstream tasks
79 |     print(descs_dr.shape)
80 |     break
81 | 
82 | 
-------------------------------------------------------------------------------- /hpatches_extract_SIFT_64.py: --------------------------------------------------------------------------------
1 | import sys
2 | import glob
3 | import os
4 | import cv2
6 | import numpy as np
7 | from sklearn.decomposition import PCA
8 | import pickle as pk
9 | import kornia
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 | assert len(sys.argv) == 2, "Usage: python hpatches_extract_SIFT_64.py hpatches_db_root_folder"
14 | 
15 | # all types of patches
16 | tps = ['ref','e1','e2','e3','e4','e5','h1','h2','h3','h4','h5',\
17 |        't1','t2','t3','t4','t5']
18 | 
19 | def load_model(model, filename, device):
20 |     model.load_state_dict(torch.load(filename, map_location=lambda storage, loc: storage))
21 |     print('Model loaded from %s.' % filename)
22 |     model.to(device)
23 |     model.eval()
24 | 
25 | class L2Norm(nn.Module):
26 |     def __init__(self):
27 |         super(L2Norm, self).__init__()
28 |         self.eps = 1e-10
29 |     def forward(self, x):
30 |         norm = torch.sqrt(torch.sum(x * x, dim=1) + self.eps)
31 |         x = x / norm.unsqueeze(-1).expand_as(x)
32 |         return x
33 | 
34 | class Encoder(nn.Module):
35 |     def __init__(self, n_components, hidden=128):
36 |         super(Encoder, self).__init__()
37 |         self.enc_net = nn.Sequential(
38 |             nn.Linear(128, hidden),
39 |             nn.ReLU(),
40 |             nn.BatchNorm1d(hidden),
41 |             nn.Linear(hidden, hidden),
42 |             nn.ReLU(),
43 |             nn.BatchNorm1d(hidden),
44 |             nn.Linear(hidden, n_components)
45 |         )
46 | 
47 |     def forward(self, x):
48 |         return L2Norm()(self.enc_net(x))
49 | 
50 | class hpatches_sequence:
51 |     """Class for loading an HPatches sequence from a sequence folder"""
52 |     itr = tps
53 |     def __init__(self, base):
54 |         name = base.split('/')
55 |         self.name = name[-1]
56 |         self.base = base
57 |         for t in self.itr:
58 |             im_path = os.path.join(base, t+'.png')
59 |             im = cv2.imread(im_path, 0)
60 |             self.N = im.shape[0] // 65  # integer count of 65-px patches (np.split needs an int)
61 |             setattr(self, t, np.split(im, self.N))
62 | 
63 | 
64 | seqs = glob.glob(sys.argv[1]+'/*')
65 | seqs = [os.path.abspath(p) for p in seqs]
66 | 
67 | hidden = 512
68 | descr_name = 'SIFT_sv_dim64'
69 | SIFT = kornia.feature.SIFTDescriptor(32, 8, 4, False).cuda()
70 | device = torch.device('cuda:0')
71 | n_components = 64
72 | encoder = Encoder(n_components=n_components, hidden=hidden)
73 | load_model(encoder, 'models/SIFT_sv_dim64.pth', device)
74 | cv2_scale = lambda x: cv2.resize(x, dsize=(32, 32),
75 |                                  interpolation=cv2.INTER_LINEAR)
76 | w = 65
77 | for seq_path in seqs:
78 |     seq = hpatches_sequence(seq_path)
79 |     path = os.path.join(descr_name, seq.name)
80 |     if not os.path.exists(path):
81 |         os.makedirs(path)
82 |     for tp in tps:
83 |         print(seq.name+'/'+tp)
84 |         if os.path.isfile(os.path.join(path, tp+'.csv')):
85 |             continue
86 |         n_patches = len(getattr(seq, tp))
89 |         patches_for_net = np.zeros((n_patches, 1, 32, 32))
90 |         for i, patch in enumerate(getattr(seq, tp)):
91 |             patches_for_net[i, 0, :, :] = cv2.resize(patch[0:w, 0:w], (32, 32))
92 |         encoder.eval()
93 |         outs = []
94 |         bs = 128
95 |         n_batches = int(n_patches / bs) + 1
96 |         for batch_idx in range(n_batches):
97 |             st = batch_idx * bs
98 |             end = min((batch_idx + 1) * bs, n_patches)
105 |             if st >= end:
106 |                 continue
107 |             data_a = patches_for_net[st: end, :, :, :].astype(np.float32)
108 |             data_a = torch.from_numpy(data_a)
109 | 
110 |             data_a = data_a.to(device)
111 | 
112 |             out_a = SIFT(data_a)
113 |             out_a = encoder(out_a)
114 |             outs.append(out_a.data.cpu().numpy().reshape(-1, n_components))
115 |         res_desc = np.concatenate(outs)
116 |         out = np.reshape(res_desc, (n_patches, -1))
118 |         np.savetxt(os.path.join(path, tp+'.csv'), out, delimiter=',', fmt='%10.5f')
119 | 
120 | 
-------------------------------------------------------------------------------- /models/.gitignore: --------------------------------------------------------------------------------
1 | __pycache__/
2 | liberty/
3 | logs/
4 | models/
5 | notredame/
6 | yosemite/
7 | *.pt
8 | *.pkl
-------------------------------------------------------------------------------- /models/HardNet_sv_dim16.pth:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/PRBonn/descriptor-dr/0c186c9a0ebba3c560aee5381bb25f47c6a328ca/models/HardNet_sv_dim16.pth -------------------------------------------------------------------------------- /models/HardNet_sv_dim24.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PRBonn/descriptor-dr/0c186c9a0ebba3c560aee5381bb25f47c6a328ca/models/HardNet_sv_dim24.pth -------------------------------------------------------------------------------- /models/HardNet_sv_dim32.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PRBonn/descriptor-dr/0c186c9a0ebba3c560aee5381bb25f47c6a328ca/models/HardNet_sv_dim32.pth -------------------------------------------------------------------------------- /models/HardNet_sv_dim64.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PRBonn/descriptor-dr/0c186c9a0ebba3c560aee5381bb25f47c6a328ca/models/HardNet_sv_dim64.pth -------------------------------------------------------------------------------- /models/SIFT_sv_dim16.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PRBonn/descriptor-dr/0c186c9a0ebba3c560aee5381bb25f47c6a328ca/models/SIFT_sv_dim16.pth -------------------------------------------------------------------------------- /models/SIFT_sv_dim24.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PRBonn/descriptor-dr/0c186c9a0ebba3c560aee5381bb25f47c6a328ca/models/SIFT_sv_dim24.pth -------------------------------------------------------------------------------- /models/SIFT_sv_dim32.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PRBonn/descriptor-dr/0c186c9a0ebba3c560aee5381bb25f47c6a328ca/models/SIFT_sv_dim32.pth -------------------------------------------------------------------------------- /models/SIFT_sv_dim64.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PRBonn/descriptor-dr/0c186c9a0ebba3c560aee5381bb25f47c6a328ca/models/SIFT_sv_dim64.pth -------------------------------------------------------------------------------- /pics/HardNet-Hpatches.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PRBonn/descriptor-dr/0c186c9a0ebba3c560aee5381bb25f47c6a328ca/pics/HardNet-Hpatches.PNG -------------------------------------------------------------------------------- /pics/MKD-Hpatches.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PRBonn/descriptor-dr/0c186c9a0ebba3c560aee5381bb25f47c6a328ca/pics/MKD-Hpatches.PNG -------------------------------------------------------------------------------- /pics/SIFT-Hpatches.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PRBonn/descriptor-dr/0c186c9a0ebba3c560aee5381bb25f47c6a328ca/pics/SIFT-Hpatches.png -------------------------------------------------------------------------------- /pics/TFeat-Hpatches.PNG: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/PRBonn/descriptor-dr/0c186c9a0ebba3c560aee5381bb25f47c6a328ca/pics/TFeat-Hpatches.PNG
-------------------------------------------------------------------------------- /pics/localization.PNG: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/PRBonn/descriptor-dr/0c186c9a0ebba3c560aee5381bb25f47c6a328ca/pics/localization.PNG
-------------------------------------------------------------------------------- /pics/overview.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/PRBonn/descriptor-dr/0c186c9a0ebba3c560aee5381bb25f47c6a328ca/pics/overview.png
-------------------------------------------------------------------------------- /raw_descriptors/.gitignore: --------------------------------------------------------------------------------
1 | __pycache__/
2 | liberty/
3 | logs/
4 | models/
5 | notredame/
6 | yosemite/
7 | *.pt
8 | *.pkl
9 | *.pth
10 | *.npz
-------------------------------------------------------------------------------- /util.py: --------------------------------------------------------------------------------
1 | # Copyright (c) 2017-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | #
7 | import os
8 | import pickle
9 | 
10 | import numpy as np
11 | from torch.utils.data.sampler import Sampler
12 | 
13 | 
14 | class UnifLabelSampler(Sampler):
15 |     """Samples elements uniformly across pseudolabels.
16 |     Args:
17 |         N (int): size of returned iterator.
18 |         images_lists (dict): maps each pseudolabel to the list of sample indices with that label.
19 |     """
20 | 
21 |     def __init__(self, N, images_lists):
22 |         self.N = N
23 |         self.images_lists = images_lists
24 |         self.indexes = self.generate_indexes_epoch()
25 | 
26 |     def generate_indexes_epoch(self):
27 |         nmb_non_empty_clusters = 0
28 |         for i in range(len(self.images_lists)):
29 |             if len(self.images_lists[i]) != 0:
30 |                 nmb_non_empty_clusters += 1
31 | 
32 |         size_per_pseudolabel = int(self.N / nmb_non_empty_clusters) + 1
33 |         res = np.array([])
34 | 
35 |         for i in range(len(self.images_lists)):
36 |             # skip empty clusters
37 |             if len(self.images_lists[i]) == 0:
38 |                 continue
39 |             indexes = np.random.choice(
40 |                 self.images_lists[i],
41 |                 size_per_pseudolabel,
42 |                 replace=(len(self.images_lists[i]) <= size_per_pseudolabel)
43 |             )
44 |             res = np.concatenate((res, indexes))
45 | 
46 |         np.random.shuffle(res)
47 |         res = list(res.astype('int'))
48 |         if len(res) >= self.N:
49 |             return res[:self.N]
50 |         res += res[: (self.N - len(res))]
51 |         return res
52 | 
53 |     def __iter__(self):
54 |         return iter(self.indexes)
55 | 
56 |     def __len__(self):
57 |         return len(self.indexes)
58 | 
59 | 
60 | class AverageMeter(object):
61 |     """Computes and stores the average and current value"""
62 |     def __init__(self):
63 |         self.reset()
64 | 
65 |     def reset(self):
66 |         self.val = 0
67 |         self.avg = 0
68 |         self.sum = 0
69 |         self.count = 0
70 | 
71 |     def update(self, val, n=1):
72 |         self.val = val
73 |         self.sum += val * n
74 |         self.count += n
75 |         self.avg = self.sum / self.count
76 | 
77 | 
78 | def learning_rate_decay(optimizer, t, lr_0):
79 |     for param_group in optimizer.param_groups:
80 |         lr = lr_0 / np.sqrt(1 + lr_0 * param_group['weight_decay'] * t)
81 |         param_group['lr'] = lr
82 | 
83 | 
84 | class Logger(object):
85 |     """ Class to update every epoch to keep track of the results
86 |         Methods:
87 |             - log(): append a data point and pickle the whole history to disk
88 |     """
89 | 
90 |     def __init__(self, path):
91 |         self.path = path
92 |         self.data = []
93 | 
94 |     def log(self, train_point):
95 |         self.data.append(train_point)
96 |         with open(self.path, 'wb') as fp:
97 |             pickle.dump(self.data, fp, -1)
--------------------------------------------------------------------------------
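None of the scripts above exercise util.py directly, so here is a minimal, self-contained usage sketch of its three helpers; the two pseudolabel clusters and the tiny linear model are made up purely for illustration.

import torch.nn as nn
import torch.optim as optim

from util import UnifLabelSampler, AverageMeter, learning_rate_decay

# Two hypothetical pseudolabel clusters over a six-sample dataset.
images_lists = {0: [0, 2, 4], 1: [1, 3, 5]}
sampler = UnifLabelSampler(N=6, images_lists=images_lists)
print(list(iter(sampler)))  # six indices drawn uniformly across the two clusters
# In practice the sampler would be passed to torch.utils.data.DataLoader(dataset, sampler=sampler, ...).

# Running average of a stream of loss values.
meter = AverageMeter()
for v in (0.5, 0.3, 0.2):
    meter.update(v, n=1)
print(meter.avg)  # (0.5 + 0.3 + 0.2) / 3 = 0.333...

# Square-root learning-rate decay driven by the optimizer's own weight decay.
model = nn.Linear(8, 2)
opt = optim.SGD(model.parameters(), lr=0.05, weight_decay=1e-5)
learning_rate_decay(opt, t=100, lr_0=0.05)  # lr = lr_0 / sqrt(1 + lr_0 * wd * t)
print(opt.param_groups[0]['lr'])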