├── README.md
├── main.py
├── netvlad.py
├── pittsburgh.py
└── tokyo247.py

/README.md:
--------------------------------------------------------------------------------
1 | # pytorch-NetVlad
2 |
3 | Implementation of [NetVlad](https://arxiv.org/abs/1511.07247) in PyTorch, including code for training the model on the Pittsburgh dataset.
4 |
5 | ### Reproducing the paper
6 |
7 | Below are our results compared to those reported in the third row of the right column of Table 1:
8 |
9 | | |R@1|R@5|R@10|
10 | |---|---|---|---|
11 | | [NetVlad paper](https://arxiv.org/abs/1511.07247) | 84.1 | 94.6 | 95.5 |
12 | | pytorch-NetVlad(alexnet) | 68.6 | 84.6 | 89.3 |
13 | | pytorch-NetVlad(vgg16) | 85.2 | 94.8 | 97.0 |
14 |
15 | Running main.py in train mode with default settings should give similar scores to the ones shown above. Additionally, the model state for the above run is
16 | available here: https://drive.google.com/open?id=17luTjZFCX639guSVy00OUtzfTQo4AMF2
17 |
18 | Using this checkpoint and the following command you can obtain the results shown above:
19 |
20 |     python main.py --mode=test --split=val --resume=vgg16_netvlad_checkpoint/
21 |
22 | # Setup
23 |
24 | ## Dependencies
25 |
26 | 1. [PyTorch](https://pytorch.org/get-started/locally/) (at least v0.4.0)
27 | 2. [Faiss](https://github.com/facebookresearch/faiss)
28 | 3. [scipy](https://www.scipy.org/)
29 |     - [numpy](http://www.numpy.org/)
30 |     - [sklearn](https://scikit-learn.org/stable/)
31 |     - [h5py](https://www.h5py.org/)
32 | 4. [tensorboardX](https://github.com/lanpa/tensorboardX)
33 |
34 | ## Data
35 |
36 | Running this code requires a copy of the Pittsburgh 250k (available [here](https://github.com/Relja/netvlad/issues/42)),
37 | and the dataset specifications for the Pittsburgh dataset (available [here](https://www.di.ens.fr/willow/research/netvlad/data/netvlad_v100_datasets.tar.gz)).
38 | `pittsburgh.py` contains a hardcoded path to a directory, where the code expects directories `000` to `010` with the various Pittsburgh database images, a directory
39 | `queries_real` with subdirectories `000` to `010` with the query images, and a directory `datasets` with the dataset specifications (.mat files).
40 |
41 |
42 | # Usage
43 |
44 | `main.py` contains the majority of the code, and has three different modes (`train`, `test`, `cluster`) which we'll discuss in more detail below.
45 |
46 | ## Train
47 |
48 | In order to initialise the NetVlad layer it is necessary to first run `main.py` with the correct settings and `--mode=cluster`. Afterwards, a model can be trained using the following (default) flags:
49 |
50 |     python main.py --mode=train --arch=vgg16 --pooling=netvlad --num_clusters=64
51 |
52 | The commandline args, the tensorboard data, and the model state will all be saved to `opt.runsPath`, which can subsequently be used for testing or for resuming training.
53 |
54 | For more information on all commandline arguments run:
55 |
56 |     python main.py --help
57 |
58 | ## Test
59 |
60 | To test a previously trained model on the Pittsburgh 30k testset (replace the directory with the correct one for your case):
61 |
62 |     python main.py --mode=test --resume=runsPath/Nov19_12-00-00_vgg16_netvlad --split=test
63 |
64 | The commandline arguments for training were saved, so we shouldn't need to specify them for testing.
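For quick experimentation outside of `main.py`, a checkpoint can also be loaded programmatically. Below is a minimal sketch (assuming the default vgg16 + NetVlad configuration trained on a single GPU; the checkpoint path and image filename are placeholders) that rebuilds the model the same way `main.py` does and extracts a single global descriptor:

    import torch
    import torch.nn as nn
    import torchvision.models as models
    import torchvision.transforms as transforms
    from PIL import Image
    import netvlad

    encoder_dim, num_clusters = 512, 64  # vgg16 conv5 channels, default cluster count
    # same construction as main.py: truncated vgg16 features + NetVLAD pooling
    layers = list(models.vgg16(pretrained=True).features.children())[:-2]
    model = nn.Module()
    model.add_module('encoder', nn.Sequential(*layers))
    model.add_module('pool', netvlad.NetVLAD(num_clusters=num_clusters, dim=encoder_dim))

    # hypothetical path; checkpoints trained with --nGPU > 1 carry 'module.'-prefixed keys
    ckpt = torch.load('vgg16_netvlad_checkpoint/checkpoints/checkpoint.pth.tar',
                      map_location='cpu')
    model.load_state_dict(ckpt['state_dict'])
    model.eval()

    tf = transforms.Compose([transforms.ToTensor(),
                             transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                  std=[0.229, 0.224, 0.225])])
    img = tf(Image.open('query.jpg')).unsqueeze(0)  # hypothetical image file
    with torch.no_grad():
        descriptor = model.pool(model.encoder(img))  # shape (1, 512 * 64)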
65 | Additionally, to obtain the 'off the shelf' performance we can also omit the resume directory: 66 | 67 | python main.py --mode=test 68 | 69 | ## Cluster 70 | 71 | In order to initialise the NetVlad layer we need to first sample from the data and obtain `opt.num_clusters` centroids. This step is 72 | necessary for each configuration of the network and for each dataset. To cluster simply run 73 | 74 | python main.py --mode=cluster --arch=vgg16 --pooling=netvlad --num_clusters=64 75 | 76 | with the correct values for any additional commandline arguments. 77 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | from math import log10, ceil 4 | import random, shutil, json 5 | from os.path import join, exists, isfile, realpath, dirname 6 | from os import makedirs, remove, chdir, environ 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import torch.optim as optim 12 | from torch.autograd import Variable 13 | from torch.utils.data import DataLoader, SubsetRandomSampler 14 | from torch.utils.data.dataset import Subset 15 | import torchvision.transforms as transforms 16 | from PIL import Image 17 | from datetime import datetime 18 | import torchvision.datasets as datasets 19 | import torchvision.models as models 20 | import h5py 21 | import faiss 22 | 23 | from tensorboardX import SummaryWriter 24 | import numpy as np 25 | import netvlad 26 | 27 | parser = argparse.ArgumentParser(description='pytorch-NetVlad') 28 | parser.add_argument('--mode', type=str, default='train', help='Mode', choices=['train', 'test', 'cluster']) 29 | parser.add_argument('--batchSize', type=int, default=4, 30 | help='Number of triplets (query, pos, negs). Each triplet consists of 12 images.') 31 | parser.add_argument('--cacheBatchSize', type=int, default=24, help='Batch size for caching and testing') 32 | parser.add_argument('--cacheRefreshRate', type=int, default=1000, 33 | help='How often to refresh cache, in number of queries. 
0 for off')
34 | parser.add_argument('--nEpochs', type=int, default=30, help='number of epochs to train for')
35 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
36 |                     help='manual epoch number (useful on restarts)')
37 | parser.add_argument('--nGPU', type=int, default=1, help='number of GPUs to use.')
38 | parser.add_argument('--optim', type=str, default='SGD', help='optimizer to use', choices=['SGD', 'ADAM'])
39 | parser.add_argument('--lr', type=float, default=0.0001, help='Learning Rate.')
40 | parser.add_argument('--lrStep', type=float, default=5, help='Decay LR every N steps.')
41 | parser.add_argument('--lrGamma', type=float, default=0.5, help='Multiply LR by Gamma for decaying.')
42 | parser.add_argument('--weightDecay', type=float, default=0.001, help='Weight decay for SGD.')
43 | parser.add_argument('--momentum', type=float, default=0.9, help='Momentum for SGD.')
44 | parser.add_argument('--nocuda', action='store_true', help="Don't use cuda")
45 | parser.add_argument('--threads', type=int, default=8, help='Number of threads for each data loader to use')
46 | parser.add_argument('--seed', type=int, default=123, help='Random seed to use.')
47 | parser.add_argument('--dataPath', type=str, default='/nfs/ibrahimi/data/', help='Path for centroid data.')
48 | parser.add_argument('--runsPath', type=str, default='/nfs/ibrahimi/runs/', help='Path to save runs to.')
49 | parser.add_argument('--savePath', type=str, default='checkpoints',
50 |                     help='Path to save checkpoints to in logdir. Default=checkpoints/')
51 | parser.add_argument('--cachePath', type=str, default=environ['TMPDIR'], help='Path to save cache to.')
52 | parser.add_argument('--resume', type=str, default='', help='Path to load checkpoint from, for resuming training or testing.')
53 | parser.add_argument('--ckpt', type=str, default='latest',
54 |                     help='Resume from latest or best checkpoint.', choices=['latest', 'best'])
55 | parser.add_argument('--evalEvery', type=int, default=1,
56 |                     help='Do a validation set run, and save, every N epochs.')
57 | parser.add_argument('--patience', type=int, default=10, help='Patience for early stopping. 0 is off.')
58 | parser.add_argument('--dataset', type=str, default='pittsburgh',
59 |                     help='Dataset to use', choices=['pittsburgh'])
60 | parser.add_argument('--arch', type=str, default='vgg16',
61 |                     help='base network to use', choices=['vgg16', 'alexnet'])
62 | parser.add_argument('--vladv2', action='store_true', help='Use VLAD v2')
63 | parser.add_argument('--pooling', type=str, default='netvlad', help='type of pooling to use',
64 |                     choices=['netvlad', 'max', 'avg'])
65 | parser.add_argument('--num_clusters', type=int, default=64, help='Number of NetVlad clusters. Default=64')
66 | parser.add_argument('--margin', type=float, default=0.1, help='Margin for triplet loss. Default=0.1')
67 | parser.add_argument('--split', type=str, default='val', help='Data split to use for testing. Default is val',
68 |                     choices=['test', 'test250k', 'train', 'val'])
69 | parser.add_argument('--fromscratch', action='store_true', help='Train from scratch rather than using pretrained models')
70 |
71 | def train(epoch):
72 |     epoch_loss = 0
73 |     startIter = 1 # keep track of batch iter across subsets for logging
74 |
75 |     if opt.cacheRefreshRate > 0:
76 |         subsetN = ceil(len(train_set) / opt.cacheRefreshRate)
77 |         #TODO randomise the arange before splitting?
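        # The epoch is processed in subsetN chunks: before each chunk, descriptors
        # for the whole training set are re-extracted into an HDF5 cache (below),
        # so the hard-negative mining in the dataset's __getitem__ works from
        # reasonably fresh features rather than features from the epoch's start.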
78 | subsetIdx = np.array_split(np.arange(len(train_set)), subsetN) 79 | else: 80 | subsetN = 1 81 | subsetIdx = [np.arange(len(train_set))] 82 | 83 | nBatches = (len(train_set) + opt.batchSize - 1) // opt.batchSize 84 | 85 | for subIter in range(subsetN): 86 | print('====> Building Cache') 87 | model.eval() 88 | train_set.cache = join(opt.cachePath, train_set.whichSet + '_feat_cache.hdf5') 89 | with h5py.File(train_set.cache, mode='w') as h5: 90 | pool_size = encoder_dim 91 | if opt.pooling.lower() == 'netvlad': pool_size *= opt.num_clusters 92 | h5feat = h5.create_dataset("features", 93 | [len(whole_train_set), pool_size], 94 | dtype=np.float32) 95 | with torch.no_grad(): 96 | for iteration, (input, indices) in enumerate(whole_training_data_loader, 1): 97 | input = input.to(device) 98 | image_encoding = model.encoder(input) 99 | vlad_encoding = model.pool(image_encoding) 100 | h5feat[indices.detach().numpy(), :] = vlad_encoding.detach().cpu().numpy() 101 | del input, image_encoding, vlad_encoding 102 | 103 | sub_train_set = Subset(dataset=train_set, indices=subsetIdx[subIter]) 104 | 105 | training_data_loader = DataLoader(dataset=sub_train_set, num_workers=opt.threads, 106 | batch_size=opt.batchSize, shuffle=True, 107 | collate_fn=dataset.collate_fn, pin_memory=cuda) 108 | 109 | print('Allocated:', torch.cuda.memory_allocated()) 110 | print('Cached:', torch.cuda.memory_cached()) 111 | 112 | model.train() 113 | for iteration, (query, positives, negatives, 114 | negCounts, indices) in enumerate(training_data_loader, startIter): 115 | # some reshaping to put query, pos, negs in a single (N, 3, H, W) tensor 116 | # where N = batchSize * (nQuery + nPos + nNeg) 117 | if query is None: continue # in case we get an empty batch 118 | 119 | B, C, H, W = query.shape 120 | nNeg = torch.sum(negCounts) 121 | input = torch.cat([query, positives, negatives]) 122 | 123 | input = input.to(device) 124 | image_encoding = model.encoder(input) 125 | vlad_encoding = model.pool(image_encoding) 126 | 127 | vladQ, vladP, vladN = torch.split(vlad_encoding, [B, B, nNeg]) 128 | 129 | optimizer.zero_grad() 130 | 131 | # calculate loss for each Query, Positive, Negative triplet 132 | # due to potential difference in number of negatives have to 133 | # do it per query, per negative 134 | loss = 0 135 | for i, negCount in enumerate(negCounts): 136 | for n in range(negCount): 137 | negIx = (torch.sum(negCounts[:i]) + n).item() 138 | loss += criterion(vladQ[i:i+1], vladP[i:i+1], vladN[negIx:negIx+1]) 139 | 140 | loss /= nNeg.float().to(device) # normalise by actual number of negatives 141 | loss.backward() 142 | optimizer.step() 143 | del input, image_encoding, vlad_encoding, vladQ, vladP, vladN 144 | del query, positives, negatives 145 | 146 | batch_loss = loss.item() 147 | epoch_loss += batch_loss 148 | 149 | if iteration % 50 == 0 or nBatches <= 10: 150 | print("==> Epoch[{}]({}/{}): Loss: {:.4f}".format(epoch, iteration, 151 | nBatches, batch_loss), flush=True) 152 | writer.add_scalar('Train/Loss', batch_loss, 153 | ((epoch-1) * nBatches) + iteration) 154 | writer.add_scalar('Train/nNeg', nNeg, 155 | ((epoch-1) * nBatches) + iteration) 156 | print('Allocated:', torch.cuda.memory_allocated()) 157 | print('Cached:', torch.cuda.memory_cached()) 158 | 159 | startIter += len(training_data_loader) 160 | del training_data_loader, loss 161 | optimizer.zero_grad() 162 | torch.cuda.empty_cache() 163 | remove(train_set.cache) # delete HDF5 cache 164 | 165 | avg_loss = epoch_loss / nBatches 166 | 167 | print("===> Epoch {} 
Complete: Avg. Loss: {:.4f}".format(epoch, avg_loss), 168 | flush=True) 169 | writer.add_scalar('Train/AvgLoss', avg_loss, epoch) 170 | 171 | def test(eval_set, epoch=0, write_tboard=False): 172 | # TODO what if features dont fit in memory? 173 | test_data_loader = DataLoader(dataset=eval_set, 174 | num_workers=opt.threads, batch_size=opt.cacheBatchSize, shuffle=False, 175 | pin_memory=cuda) 176 | 177 | model.eval() 178 | with torch.no_grad(): 179 | print('====> Extracting Features') 180 | pool_size = encoder_dim 181 | if opt.pooling.lower() == 'netvlad': pool_size *= opt.num_clusters 182 | dbFeat = np.empty((len(eval_set), pool_size)) 183 | 184 | for iteration, (input, indices) in enumerate(test_data_loader, 1): 185 | input = input.to(device) 186 | image_encoding = model.encoder(input) 187 | vlad_encoding = model.pool(image_encoding) 188 | 189 | dbFeat[indices.detach().numpy(), :] = vlad_encoding.detach().cpu().numpy() 190 | if iteration % 50 == 0 or len(test_data_loader) <= 10: 191 | print("==> Batch ({}/{})".format(iteration, 192 | len(test_data_loader)), flush=True) 193 | 194 | del input, image_encoding, vlad_encoding 195 | del test_data_loader 196 | 197 | # extracted for both db and query, now split in own sets 198 | qFeat = dbFeat[eval_set.dbStruct.numDb:].astype('float32') 199 | dbFeat = dbFeat[:eval_set.dbStruct.numDb].astype('float32') 200 | 201 | print('====> Building faiss index') 202 | faiss_index = faiss.IndexFlatL2(pool_size) 203 | faiss_index.add(dbFeat) 204 | 205 | print('====> Calculating recall @ N') 206 | n_values = [1,5,10,20] 207 | 208 | _, predictions = faiss_index.search(qFeat, max(n_values)) 209 | 210 | # for each query get those within threshold distance 211 | gt = eval_set.getPositives() 212 | 213 | correct_at_n = np.zeros(len(n_values)) 214 | #TODO can we do this on the matrix in one go? 
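    # correct_at_n[i] counts queries with a true positive among their top-n
    # retrievals; since a hit in the top n also counts for every larger n,
    # correct_at_n[i:] is incremented in one go and the loop breaks early.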
215 | for qIx, pred in enumerate(predictions): 216 | for i,n in enumerate(n_values): 217 | # if in top N then also in top NN, where NN > N 218 | if np.any(np.in1d(pred[:n], gt[qIx])): 219 | correct_at_n[i:] += 1 220 | break 221 | recall_at_n = correct_at_n / eval_set.dbStruct.numQ 222 | 223 | recalls = {} #make dict for output 224 | for i,n in enumerate(n_values): 225 | recalls[n] = recall_at_n[i] 226 | print("====> Recall@{}: {:.4f}".format(n, recall_at_n[i])) 227 | if write_tboard: writer.add_scalar('Val/Recall@' + str(n), recall_at_n[i], epoch) 228 | 229 | return recalls 230 | 231 | def get_clusters(cluster_set): 232 | nDescriptors = 50000 233 | nPerImage = 100 234 | nIm = ceil(nDescriptors/nPerImage) 235 | 236 | sampler = SubsetRandomSampler(np.random.choice(len(cluster_set), nIm, replace=False)) 237 | data_loader = DataLoader(dataset=cluster_set, 238 | num_workers=opt.threads, batch_size=opt.cacheBatchSize, shuffle=False, 239 | pin_memory=cuda, 240 | sampler=sampler) 241 | 242 | if not exists(join(opt.dataPath, 'centroids')): 243 | makedirs(join(opt.dataPath, 'centroids')) 244 | 245 | initcache = join(opt.dataPath, 'centroids', opt.arch + '_' + cluster_set.dataset + '_' + str(opt.num_clusters) + '_desc_cen.hdf5') 246 | with h5py.File(initcache, mode='w') as h5: 247 | with torch.no_grad(): 248 | model.eval() 249 | print('====> Extracting Descriptors') 250 | dbFeat = h5.create_dataset("descriptors", 251 | [nDescriptors, encoder_dim], 252 | dtype=np.float32) 253 | 254 | for iteration, (input, indices) in enumerate(data_loader, 1): 255 | input = input.to(device) 256 | image_descriptors = model.encoder(input).view(input.size(0), encoder_dim, -1).permute(0, 2, 1) 257 | 258 | batchix = (iteration-1)*opt.cacheBatchSize*nPerImage 259 | for ix in range(image_descriptors.size(0)): 260 | # sample different location for each image in batch 261 | sample = np.random.choice(image_descriptors.size(1), nPerImage, replace=False) 262 | startix = batchix + ix*nPerImage 263 | dbFeat[startix:startix+nPerImage, :] = image_descriptors[ix, sample, :].detach().cpu().numpy() 264 | 265 | if iteration % 50 == 0 or len(data_loader) <= 10: 266 | print("==> Batch ({}/{})".format(iteration, 267 | ceil(nIm/opt.cacheBatchSize)), flush=True) 268 | del input, image_descriptors 269 | 270 | print('====> Clustering..') 271 | niter = 100 272 | kmeans = faiss.Kmeans(encoder_dim, opt.num_clusters, niter=niter, verbose=False) 273 | kmeans.train(dbFeat[...]) 274 | 275 | print('====> Storing centroids', kmeans.centroids.shape) 276 | h5.create_dataset('centroids', data=kmeans.centroids) 277 | print('====> Done!') 278 | 279 | def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): 280 | model_out_path = join(opt.savePath, filename) 281 | torch.save(state, model_out_path) 282 | if is_best: 283 | shutil.copyfile(model_out_path, join(opt.savePath, 'model_best.pth.tar')) 284 | 285 | class Flatten(nn.Module): 286 | def forward(self, input): 287 | return input.view(input.size(0), -1) 288 | 289 | class L2Norm(nn.Module): 290 | def __init__(self, dim=1): 291 | super().__init__() 292 | self.dim = dim 293 | 294 | def forward(self, input): 295 | return F.normalize(input, p=2, dim=self.dim) 296 | 297 | if __name__ == "__main__": 298 | opt = parser.parse_args() 299 | 300 | restore_var = ['lr', 'lrStep', 'lrGamma', 'weightDecay', 'momentum', 301 | 'runsPath', 'savePath', 'arch', 'num_clusters', 'pooling', 'optim', 302 | 'margin', 'seed', 'patience'] 303 | if opt.resume: 304 | flag_file = join(opt.resume, 'checkpoints', 'flags.json') 
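        # flags.json was written at training time; restoring these flags ensures a
        # test/resume run rebuilds the model with the same architecture, pooling,
        # and cluster count it was trained with, whatever the current command line.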
305 | if exists(flag_file): 306 | with open(flag_file, 'r') as f: 307 | stored_flags = {'--'+k : str(v) for k,v in json.load(f).items() if k in restore_var} 308 | to_del = [] 309 | for flag, val in stored_flags.items(): 310 | for act in parser._actions: 311 | if act.dest == flag[2:]: 312 | # store_true / store_false args don't accept arguments, filter these 313 | if type(act.const) == type(True): 314 | if val == str(act.default): 315 | to_del.append(flag) 316 | else: 317 | stored_flags[flag] = '' 318 | for flag in to_del: del stored_flags[flag] 319 | 320 | train_flags = [x for x in list(sum(stored_flags.items(), tuple())) if len(x) > 0] 321 | print('Restored flags:', train_flags) 322 | opt = parser.parse_args(train_flags, namespace=opt) 323 | 324 | print(opt) 325 | 326 | if opt.dataset.lower() == 'pittsburgh': 327 | import pittsburgh as dataset 328 | else: 329 | raise Exception('Unknown dataset') 330 | 331 | cuda = not opt.nocuda 332 | if cuda and not torch.cuda.is_available(): 333 | raise Exception("No GPU found, please run with --nocuda") 334 | 335 | device = torch.device("cuda" if cuda else "cpu") 336 | 337 | random.seed(opt.seed) 338 | np.random.seed(opt.seed) 339 | torch.manual_seed(opt.seed) 340 | if cuda: 341 | torch.cuda.manual_seed(opt.seed) 342 | 343 | print('===> Loading dataset(s)') 344 | if opt.mode.lower() == 'train': 345 | whole_train_set = dataset.get_whole_training_set() 346 | whole_training_data_loader = DataLoader(dataset=whole_train_set, 347 | num_workers=opt.threads, batch_size=opt.cacheBatchSize, shuffle=False, 348 | pin_memory=cuda) 349 | 350 | train_set = dataset.get_training_query_set(opt.margin) 351 | 352 | print('====> Training query set:', len(train_set)) 353 | whole_test_set = dataset.get_whole_val_set() 354 | print('===> Evaluating on val set, query count:', whole_test_set.dbStruct.numQ) 355 | elif opt.mode.lower() == 'test': 356 | if opt.split.lower() == 'test': 357 | whole_test_set = dataset.get_whole_test_set() 358 | print('===> Evaluating on test set') 359 | elif opt.split.lower() == 'test250k': 360 | whole_test_set = dataset.get_250k_test_set() 361 | print('===> Evaluating on test250k set') 362 | elif opt.split.lower() == 'train': 363 | whole_test_set = dataset.get_whole_training_set() 364 | print('===> Evaluating on train set') 365 | elif opt.split.lower() == 'val': 366 | whole_test_set = dataset.get_whole_val_set() 367 | print('===> Evaluating on val set') 368 | else: 369 | raise ValueError('Unknown dataset split: ' + opt.split) 370 | print('====> Query count:', whole_test_set.dbStruct.numQ) 371 | elif opt.mode.lower() == 'cluster': 372 | whole_train_set = dataset.get_whole_training_set(onlyDB=True) 373 | 374 | print('===> Building model') 375 | 376 | pretrained = not opt.fromscratch 377 | if opt.arch.lower() == 'alexnet': 378 | encoder_dim = 256 379 | encoder = models.alexnet(pretrained=pretrained) 380 | # capture only features and remove last relu and maxpool 381 | layers = list(encoder.features.children())[:-2] 382 | 383 | if pretrained: 384 | # if using pretrained only train conv5 385 | for l in layers[:-1]: 386 | for p in l.parameters(): 387 | p.requires_grad = False 388 | 389 | elif opt.arch.lower() == 'vgg16': 390 | encoder_dim = 512 391 | encoder = models.vgg16(pretrained=pretrained) 392 | # capture only feature part and remove last relu and maxpool 393 | layers = list(encoder.features.children())[:-2] 394 | 395 | if pretrained: 396 | # if using pretrained then only train conv5_1, conv5_2, and conv5_3 397 | for l in layers[:-5]: 398 | for p in 
l.parameters(): 399 | p.requires_grad = False 400 | 401 | if opt.mode.lower() == 'cluster' and not opt.vladv2: 402 | layers.append(L2Norm()) 403 | 404 | encoder = nn.Sequential(*layers) 405 | model = nn.Module() 406 | model.add_module('encoder', encoder) 407 | 408 | if opt.mode.lower() != 'cluster': 409 | if opt.pooling.lower() == 'netvlad': 410 | net_vlad = netvlad.NetVLAD(num_clusters=opt.num_clusters, dim=encoder_dim, vladv2=opt.vladv2) 411 | if not opt.resume: 412 | if opt.mode.lower() == 'train': 413 | initcache = join(opt.dataPath, 'centroids', opt.arch + '_' + train_set.dataset + '_' + str(opt.num_clusters) +'_desc_cen.hdf5') 414 | else: 415 | initcache = join(opt.dataPath, 'centroids', opt.arch + '_' + whole_test_set.dataset + '_' + str(opt.num_clusters) +'_desc_cen.hdf5') 416 | 417 | if not exists(initcache): 418 | raise FileNotFoundError('Could not find clusters, please run with --mode=cluster before proceeding') 419 | 420 | with h5py.File(initcache, mode='r') as h5: 421 | clsts = h5.get("centroids")[...] 422 | traindescs = h5.get("descriptors")[...] 423 | net_vlad.init_params(clsts, traindescs) 424 | del clsts, traindescs 425 | 426 | model.add_module('pool', net_vlad) 427 | elif opt.pooling.lower() == 'max': 428 | global_pool = nn.AdaptiveMaxPool2d((1,1)) 429 | model.add_module('pool', nn.Sequential(*[global_pool, Flatten(), L2Norm()])) 430 | elif opt.pooling.lower() == 'avg': 431 | global_pool = nn.AdaptiveAvgPool2d((1,1)) 432 | model.add_module('pool', nn.Sequential(*[global_pool, Flatten(), L2Norm()])) 433 | else: 434 | raise ValueError('Unknown pooling type: ' + opt.pooling) 435 | 436 | isParallel = False 437 | if opt.nGPU > 1 and torch.cuda.device_count() > 1: 438 | model.encoder = nn.DataParallel(model.encoder) 439 | if opt.mode.lower() != 'cluster': 440 | model.pool = nn.DataParallel(model.pool) 441 | isParallel = True 442 | 443 | if not opt.resume: 444 | model = model.to(device) 445 | 446 | if opt.mode.lower() == 'train': 447 | if opt.optim.upper() == 'ADAM': 448 | optimizer = optim.Adam(filter(lambda p: p.requires_grad, 449 | model.parameters()), lr=opt.lr)#, betas=(0,0.9)) 450 | elif opt.optim.upper() == 'SGD': 451 | optimizer = optim.SGD(filter(lambda p: p.requires_grad, 452 | model.parameters()), lr=opt.lr, 453 | momentum=opt.momentum, 454 | weight_decay=opt.weightDecay) 455 | 456 | scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=opt.lrStep, gamma=opt.lrGamma) 457 | else: 458 | raise ValueError('Unknown optimizer: ' + opt.optim) 459 | 460 | # original paper/code doesn't sqrt() the distances, we do, so sqrt() the margin, I think :D 461 | criterion = nn.TripletMarginLoss(margin=opt.margin**0.5, 462 | p=2, reduction='sum').to(device) 463 | 464 | if opt.resume: 465 | if opt.ckpt.lower() == 'latest': 466 | resume_ckpt = join(opt.resume, 'checkpoints', 'checkpoint.pth.tar') 467 | elif opt.ckpt.lower() == 'best': 468 | resume_ckpt = join(opt.resume, 'checkpoints', 'model_best.pth.tar') 469 | 470 | if isfile(resume_ckpt): 471 | print("=> loading checkpoint '{}'".format(resume_ckpt)) 472 | checkpoint = torch.load(resume_ckpt, map_location=lambda storage, loc: storage) 473 | opt.start_epoch = checkpoint['epoch'] 474 | best_metric = checkpoint['best_score'] 475 | model.load_state_dict(checkpoint['state_dict']) 476 | model = model.to(device) 477 | if opt.mode == 'train': 478 | optimizer.load_state_dict(checkpoint['optimizer']) 479 | print("=> loaded checkpoint '{}' (epoch {})" 480 | .format(resume_ckpt, checkpoint['epoch'])) 481 | else: 482 | print("=> no checkpoint 
found at '{}'".format(resume_ckpt)) 483 | 484 | if opt.mode.lower() == 'test': 485 | print('===> Running evaluation step') 486 | epoch = 1 487 | recalls = test(whole_test_set, epoch, write_tboard=False) 488 | elif opt.mode.lower() == 'cluster': 489 | print('===> Calculating descriptors and clusters') 490 | get_clusters(whole_train_set) 491 | elif opt.mode.lower() == 'train': 492 | print('===> Training model') 493 | writer = SummaryWriter(log_dir=join(opt.runsPath, datetime.now().strftime('%b%d_%H-%M-%S')+'_'+opt.arch+'_'+opt.pooling)) 494 | 495 | # write checkpoints in logdir 496 | logdir = writer.file_writer.get_logdir() 497 | opt.savePath = join(logdir, opt.savePath) 498 | if not opt.resume: 499 | makedirs(opt.savePath) 500 | 501 | with open(join(opt.savePath, 'flags.json'), 'w') as f: 502 | f.write(json.dumps( 503 | {k:v for k,v in vars(opt).items()} 504 | )) 505 | print('===> Saving state to:', logdir) 506 | 507 | not_improved = 0 508 | best_score = 0 509 | for epoch in range(opt.start_epoch+1, opt.nEpochs + 1): 510 | if opt.optim.upper() == 'SGD': 511 | scheduler.step(epoch) 512 | train(epoch) 513 | if (epoch % opt.evalEvery) == 0: 514 | recalls = test(whole_test_set, epoch, write_tboard=True) 515 | is_best = recalls[5] > best_score 516 | if is_best: 517 | not_improved = 0 518 | best_score = recalls[5] 519 | else: 520 | not_improved += 1 521 | 522 | save_checkpoint({ 523 | 'epoch': epoch, 524 | 'state_dict': model.state_dict(), 525 | 'recalls': recalls, 526 | 'best_score': best_score, 527 | 'optimizer' : optimizer.state_dict(), 528 | 'parallel' : isParallel, 529 | }, is_best) 530 | 531 | if opt.patience > 0 and not_improved > (opt.patience / opt.evalEvery): 532 | print('Performance did not improve for', opt.patience, 'epochs. Stopping.') 533 | break 534 | 535 | print("=> Best Recall@5: {:.4f}".format(best_score), flush=True) 536 | writer.close() 537 | -------------------------------------------------------------------------------- /netvlad.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from sklearn.neighbors import NearestNeighbors 5 | import numpy as np 6 | 7 | # based on https://github.com/lyakaap/NetVLAD-pytorch/blob/master/netvlad.py 8 | class NetVLAD(nn.Module): 9 | """NetVLAD layer implementation""" 10 | 11 | def __init__(self, num_clusters=64, dim=128, 12 | normalize_input=True, vladv2=False): 13 | """ 14 | Args: 15 | num_clusters : int 16 | The number of clusters 17 | dim : int 18 | Dimension of descriptors 19 | alpha : float 20 | Parameter of initialization. Larger value is harder assignment. 21 | normalize_input : bool 22 | If true, descriptor-wise L2 normalization is applied to input. 
23 |             vladv2 : bool
24 |                 If true, use vladv2; otherwise use vladv1.
25 |         """
26 |         super(NetVLAD, self).__init__()
27 |         self.num_clusters = num_clusters
28 |         self.dim = dim
29 |         self.alpha = 0
30 |         self.vladv2 = vladv2
31 |         self.normalize_input = normalize_input
32 |         self.conv = nn.Conv2d(dim, num_clusters, kernel_size=(1, 1), bias=vladv2)
33 |         self.centroids = nn.Parameter(torch.rand(num_clusters, dim))
34 |
35 |     def init_params(self, clsts, traindescs):
36 |         #TODO replace numpy ops with pytorch ops
37 |         if not self.vladv2:
38 |             clstsAssign = clsts / np.linalg.norm(clsts, axis=1, keepdims=True)
39 |             dots = np.dot(clstsAssign, traindescs.T)
40 |             dots.sort(0)
41 |             dots = dots[::-1, :] # sort, descending
42 |
43 |             self.alpha = (-np.log(0.01) / np.mean(dots[0,:] - dots[1,:])).item()
44 |             self.centroids = nn.Parameter(torch.from_numpy(clsts))
45 |             self.conv.weight = nn.Parameter(torch.from_numpy(self.alpha*clstsAssign).unsqueeze(2).unsqueeze(3))
46 |             self.conv.bias = None
47 |         else:
48 |             knn = NearestNeighbors(n_jobs=-1) #TODO faiss?
49 |             knn.fit(traindescs)
50 |             del traindescs
51 |             dsSq = np.square(knn.kneighbors(clsts, 2)[0]) # kneighbors returns (distances, indices); we want the squared distances
52 |             del knn
53 |             self.alpha = (-np.log(0.01) / np.mean(dsSq[:,1] - dsSq[:,0])).item()
54 |             self.centroids = nn.Parameter(torch.from_numpy(clsts))
55 |             del clsts, dsSq
56 |
57 |             self.conv.weight = nn.Parameter(
58 |                 (2.0 * self.alpha * self.centroids).unsqueeze(-1).unsqueeze(-1)
59 |             )
60 |             self.conv.bias = nn.Parameter(
61 |                 - self.alpha * self.centroids.norm(dim=1)
62 |             )
63 |
64 |     def forward(self, x):
65 |         N, C = x.shape[:2]
66 |
67 |         if self.normalize_input:
68 |             x = F.normalize(x, p=2, dim=1)  # across descriptor dim
69 |
70 |         # soft-assignment
71 |         soft_assign = self.conv(x).view(N, self.num_clusters, -1)
72 |         soft_assign = F.softmax(soft_assign, dim=1)
73 |
74 |         x_flatten = x.view(N, C, -1)
75 |
76 |         # calculate residuals to each cluster
77 |         vlad = torch.zeros([N, self.num_clusters, C], dtype=x.dtype, layout=x.layout, device=x.device)
78 |         for clus in range(self.num_clusters): # slower than non-looped, but lower memory usage; renamed from C to avoid shadowing the channel count above
79 |             residual = x_flatten.unsqueeze(0).permute(1, 0, 2, 3) - \
80 |                 self.centroids[clus:clus+1, :].expand(x_flatten.size(-1), -1, -1).permute(1, 2, 0).unsqueeze(0)
81 |             residual *= soft_assign[:,clus:clus+1,:].unsqueeze(2)
82 |             vlad[:,clus:clus+1,:] = residual.sum(dim=-1)
83 |
84 |         vlad = F.normalize(vlad, p=2, dim=2)  # intra-normalization
85 |         vlad = vlad.view(x.size(0), -1)  # flatten
86 |         vlad = F.normalize(vlad, p=2, dim=1)  # L2 normalize
87 |
88 |         return vlad
89 |
--------------------------------------------------------------------------------
/pittsburgh.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision.transforms as transforms
3 | import torch.utils.data as data
4 |
5 | from os.path import join, exists
6 | from scipy.io import loadmat
7 | import numpy as np
8 | from random import randint, random
9 | from collections import namedtuple
10 | from PIL import Image
11 |
12 | from sklearn.neighbors import NearestNeighbors
13 | import h5py
14 |
15 | root_dir = '/nfs/ibrahimi/data/pittsburgh/'
16 | if not exists(root_dir):
17 |     raise FileNotFoundError('root_dir is hardcoded, please adjust to point to the Pittsburgh dataset')
18 |
19 | struct_dir = join(root_dir, 'datasets/')
20 | queries_dir = join(root_dir, 'queries_real')
21 |
22 | def input_transform():
23 |     return transforms.Compose([
24 |         transforms.ToTensor(),
25 |         transforms.Normalize(mean=[0.485, 0.456, 0.406],
26 |                              std=[0.229, 0.224,
0.225]), 27 | ]) 28 | 29 | def get_whole_training_set(onlyDB=False): 30 | structFile = join(struct_dir, 'pitts30k_train.mat') 31 | return WholeDatasetFromStruct(structFile, 32 | input_transform=input_transform(), 33 | onlyDB=onlyDB) 34 | 35 | def get_whole_val_set(): 36 | structFile = join(struct_dir, 'pitts30k_val.mat') 37 | return WholeDatasetFromStruct(structFile, 38 | input_transform=input_transform()) 39 | 40 | def get_250k_val_set(): 41 | structFile = join(struct_dir, 'pitts250k_val.mat') 42 | return WholeDatasetFromStruct(structFile, 43 | input_transform=input_transform()) 44 | def get_whole_test_set(): 45 | structFile = join(struct_dir, 'pitts30k_test.mat') 46 | return WholeDatasetFromStruct(structFile, 47 | input_transform=input_transform()) 48 | 49 | def get_250k_test_set(): 50 | structFile = join(struct_dir, 'pitts250k_test.mat') 51 | return WholeDatasetFromStruct(structFile, 52 | input_transform=input_transform()) 53 | 54 | def get_training_query_set(margin=0.1): 55 | structFile = join(struct_dir, 'pitts30k_train.mat') 56 | return QueryDatasetFromStruct(structFile, 57 | input_transform=input_transform(), margin=margin) 58 | 59 | def get_val_query_set(): 60 | structFile = join(struct_dir, 'pitts30k_val.mat') 61 | return QueryDatasetFromStruct(structFile, 62 | input_transform=input_transform()) 63 | 64 | def get_250k_val_query_set(): 65 | structFile = join(struct_dir, 'pitts250k_val.mat') 66 | return QueryDatasetFromStruct(structFile, 67 | input_transform=input_transform()) 68 | 69 | dbStruct = namedtuple('dbStruct', ['whichSet', 'dataset', 70 | 'dbImage', 'utmDb', 'qImage', 'utmQ', 'numDb', 'numQ', 71 | 'posDistThr', 'posDistSqThr', 'nonTrivPosDistSqThr']) 72 | 73 | def parse_dbStruct(path): 74 | mat = loadmat(path) 75 | matStruct = mat['dbStruct'].item() 76 | 77 | if '250k' in path.split('/')[-1]: 78 | dataset = 'pitts250k' 79 | else: 80 | dataset = 'pitts30k' 81 | 82 | whichSet = matStruct[0].item() 83 | 84 | dbImage = [f[0].item() for f in matStruct[1]] 85 | utmDb = matStruct[2].T 86 | 87 | qImage = [f[0].item() for f in matStruct[3]] 88 | utmQ = matStruct[4].T 89 | 90 | numDb = matStruct[5].item() 91 | numQ = matStruct[6].item() 92 | 93 | posDistThr = matStruct[7].item() 94 | posDistSqThr = matStruct[8].item() 95 | nonTrivPosDistSqThr = matStruct[9].item() 96 | 97 | return dbStruct(whichSet, dataset, dbImage, utmDb, qImage, 98 | utmQ, numDb, numQ, posDistThr, 99 | posDistSqThr, nonTrivPosDistSqThr) 100 | 101 | class WholeDatasetFromStruct(data.Dataset): 102 | def __init__(self, structFile, input_transform=None, onlyDB=False): 103 | super().__init__() 104 | 105 | self.input_transform = input_transform 106 | 107 | self.dbStruct = parse_dbStruct(structFile) 108 | self.images = [join(root_dir, dbIm) for dbIm in self.dbStruct.dbImage] 109 | if not onlyDB: 110 | self.images += [join(queries_dir, qIm) for qIm in self.dbStruct.qImage] 111 | 112 | self.whichSet = self.dbStruct.whichSet 113 | self.dataset = self.dbStruct.dataset 114 | 115 | self.positives = None 116 | self.distances = None 117 | 118 | def __getitem__(self, index): 119 | img = Image.open(self.images[index]) 120 | 121 | if self.input_transform: 122 | img = self.input_transform(img) 123 | 124 | return img, index 125 | 126 | def __len__(self): 127 | return len(self.images) 128 | 129 | def getPositives(self): 130 | # positives for evaluation are those within trivial threshold range 131 | #fit NN to find them, search by radius 132 | if self.positives is None: 133 | knn = NearestNeighbors(n_jobs=-1) 134 | 
knn.fit(self.dbStruct.utmDb) 135 | 136 | self.distances, self.positives = knn.radius_neighbors(self.dbStruct.utmQ, 137 | radius=self.dbStruct.posDistThr) 138 | 139 | return self.positives 140 | 141 | def collate_fn(batch): 142 | """Creates mini-batch tensors from the list of tuples (query, positive, negatives). 143 | 144 | Args: 145 | data: list of tuple (query, positive, negatives). 146 | - query: torch tensor of shape (3, h, w). 147 | - positive: torch tensor of shape (3, h, w). 148 | - negative: torch tensor of shape (n, 3, h, w). 149 | Returns: 150 | query: torch tensor of shape (batch_size, 3, h, w). 151 | positive: torch tensor of shape (batch_size, 3, h, w). 152 | negatives: torch tensor of shape (batch_size, n, 3, h, w). 153 | """ 154 | 155 | batch = list(filter (lambda x:x is not None, batch)) 156 | if len(batch) == 0: return None, None, None, None, None 157 | 158 | query, positive, negatives, indices = zip(*batch) 159 | 160 | query = data.dataloader.default_collate(query) 161 | positive = data.dataloader.default_collate(positive) 162 | negCounts = data.dataloader.default_collate([x.shape[0] for x in negatives]) 163 | negatives = torch.cat(negatives, 0) 164 | import itertools 165 | indices = list(itertools.chain(*indices)) 166 | 167 | return query, positive, negatives, negCounts, indices 168 | 169 | class QueryDatasetFromStruct(data.Dataset): 170 | def __init__(self, structFile, nNegSample=1000, nNeg=10, margin=0.1, input_transform=None): 171 | super().__init__() 172 | 173 | self.input_transform = input_transform 174 | self.margin = margin 175 | 176 | self.dbStruct = parse_dbStruct(structFile) 177 | self.whichSet = self.dbStruct.whichSet 178 | self.dataset = self.dbStruct.dataset 179 | self.nNegSample = nNegSample # number of negatives to randomly sample 180 | self.nNeg = nNeg # number of negatives used for training 181 | 182 | # potential positives are those within nontrivial threshold range 183 | #fit NN to find them, search by radius 184 | knn = NearestNeighbors(n_jobs=-1) 185 | knn.fit(self.dbStruct.utmDb) 186 | 187 | # TODO use sqeuclidean as metric? 
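        # nonTrivPosDistSqThr from the .mat spec is a *squared* distance
        # threshold (metres^2), hence the **0.5 below to obtain a search radius.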
188 | self.nontrivial_positives = list(knn.radius_neighbors(self.dbStruct.utmQ, 189 | radius=self.dbStruct.nonTrivPosDistSqThr**0.5, 190 | return_distance=False)) 191 | # radius returns unsorted, sort once now so we dont have to later 192 | for i,posi in enumerate(self.nontrivial_positives): 193 | self.nontrivial_positives[i] = np.sort(posi) 194 | # its possible some queries don't have any non trivial potential positives 195 | # lets filter those out 196 | self.queries = np.where(np.array([len(x) for x in self.nontrivial_positives])>0)[0] 197 | 198 | # potential negatives are those outside of posDistThr range 199 | potential_positives = knn.radius_neighbors(self.dbStruct.utmQ, 200 | radius=self.dbStruct.posDistThr, 201 | return_distance=False) 202 | 203 | self.potential_negatives = [] 204 | for pos in potential_positives: 205 | self.potential_negatives.append(np.setdiff1d(np.arange(self.dbStruct.numDb), 206 | pos, assume_unique=True)) 207 | 208 | self.cache = None # filepath of HDF5 containing feature vectors for images 209 | 210 | self.negCache = [np.empty((0,)) for _ in range(self.dbStruct.numQ)] 211 | 212 | def __getitem__(self, index): 213 | index = self.queries[index] # re-map index to match dataset 214 | with h5py.File(self.cache, mode='r') as h5: 215 | h5feat = h5.get("features") 216 | 217 | qOffset = self.dbStruct.numDb 218 | qFeat = h5feat[index+qOffset] 219 | 220 | posFeat = h5feat[self.nontrivial_positives[index].tolist()] 221 | knn = NearestNeighbors(n_jobs=-1) # TODO replace with faiss? 222 | knn.fit(posFeat) 223 | dPos, posNN = knn.kneighbors(qFeat.reshape(1,-1), 1) 224 | dPos = dPos.item() 225 | posIndex = self.nontrivial_positives[index][posNN[0]].item() 226 | 227 | negSample = np.random.choice(self.potential_negatives[index], self.nNegSample) 228 | negSample = np.unique(np.concatenate([self.negCache[index], negSample])) 229 | 230 | negFeat = h5feat[list(map(int, negSample))] 231 | knn.fit(negFeat) 232 | 233 | dNeg, negNN = knn.kneighbors(qFeat.reshape(1,-1), 234 | self.nNeg*10) # to quote netvlad paper code: 10x is hacky but fine 235 | dNeg = dNeg.reshape(-1) 236 | negNN = negNN.reshape(-1) 237 | 238 | # try to find negatives that are within margin, if there aren't any return none 239 | violatingNeg = dNeg < dPos + self.margin**0.5 240 | 241 | if np.sum(violatingNeg) < 1: 242 | #if none are violating then skip this query 243 | return None 244 | 245 | negNN = negNN[violatingNeg][:self.nNeg] 246 | negIndices = negSample[negNN].astype(np.int32) 247 | self.negCache[index] = negIndices 248 | 249 | query = Image.open(join(queries_dir, self.dbStruct.qImage[index])) 250 | positive = Image.open(join(root_dir, self.dbStruct.dbImage[posIndex])) 251 | 252 | if self.input_transform: 253 | query = self.input_transform(query) 254 | positive = self.input_transform(positive) 255 | 256 | negatives = [] 257 | for negIndex in negIndices: 258 | negative = Image.open(join(root_dir, self.dbStruct.dbImage[negIndex])) 259 | if self.input_transform: 260 | negative = self.input_transform(negative) 261 | negatives.append(negative) 262 | 263 | negatives = torch.stack(negatives, 0) 264 | 265 | return query, positive, negatives, [index, posIndex]+negIndices.tolist() 266 | 267 | def __len__(self): 268 | return len(self.queries) 269 | -------------------------------------------------------------------------------- /tokyo247.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.transforms as transforms 3 | import torch.utils.data as 
data
4 |
5 | from os.path import join, exists
6 | from scipy.io import loadmat
7 | import numpy as np
8 | from random import randint, random
9 | from collections import namedtuple
10 | from PIL import Image
11 |
12 | from sklearn.neighbors import NearestNeighbors
13 | import h5py
14 |
15 | root_dir = '/nfs/ibrahimi/data/pittsburgh/' # NOTE: still the Pittsburgh path; point this at the Tokyo data
16 | if not exists(root_dir):
17 |     raise FileNotFoundError('root_dir is hardcoded, please adjust to point to the Tokyo dataset')
18 |
19 | struct_dir = join(root_dir, 'datasets/')
20 | queries_dir = join(root_dir, 'queries_real') # NOTE: was commented out but is used below; 'queries_real' is copied from pittsburgh.py, adjust for the Tokyo query images
21 |
22 | def input_transform():
23 |     return transforms.Compose([
24 |         transforms.ToTensor(),
25 |         transforms.Normalize(mean=[0.485, 0.456, 0.406],
26 |                              std=[0.229, 0.224, 0.225]),
27 |     ])
28 |
29 | def get_whole_training_set(onlyDB=False):
30 |     structFile = join(struct_dir, 'tokyoTM_train.mat')
31 |     return WholeDatasetFromStruct(structFile,
32 |                                   input_transform=input_transform(),
33 |                                   onlyDB=onlyDB)
34 |
35 | def get_whole_val_set():
36 |     structFile = join(struct_dir, 'tokyoTM_val.mat')
37 |     return WholeDatasetFromStruct(structFile,
38 |                                   input_transform=input_transform())
39 |
40 | def get_training_query_set(margin=0.1):
41 |     structFile = join(struct_dir, 'tokyoTM_train.mat')
42 |     return QueryDatasetFromStruct(structFile,
43 |                                   input_transform=input_transform(), margin=margin)
44 |
45 | def get_val_query_set():
46 |     structFile = join(struct_dir, 'tokyoTM_val.mat')
47 |     return QueryDatasetFromStruct(structFile,
48 |                                   input_transform=input_transform())
49 |
50 | dbStruct = namedtuple('dbStruct', ['whichSet', 'dataset',
51 |     'dbImage', 'utmDb', 'qImage', 'utmQ', 'numDb', 'numQ',
52 |     'posDistThr', 'posDistSqThr', 'nonTrivPosDistSqThr'])
53 |
54 | def parse_dbStruct(path):
55 |     mat = loadmat(path)
56 |     matStruct = mat['dbStruct'].item()
57 |     dataset = 'tokyoTM' # NOTE: was previously undefined; the struct files loaded above are the tokyoTM_*.mat specs
58 |     whichSet = matStruct[0].item()
59 |
60 |     dbImage = [f[0].item() for f in matStruct[1]]
61 |     utmDb = matStruct[2].T
62 |
63 |     qImage = [f[0].item() for f in matStruct[4]]
64 |     utmQ = matStruct[5].T
65 |
66 |     numDb = matStruct[7].item()
67 |     numQ = matStruct[8].item()
68 |
69 |     posDistThr = matStruct[9].item()
70 |     posDistSqThr = matStruct[10].item()
71 |     nonTrivPosDistSqThr = matStruct[11].item()
72 |
73 |     return dbStruct(whichSet, dataset, dbImage, utmDb, qImage,
74 |         utmQ, numDb, numQ, posDistThr,
75 |         posDistSqThr, nonTrivPosDistSqThr)
76 |
77 | class WholeDatasetFromStruct(data.Dataset):
78 |     def __init__(self, structFile, input_transform=None, onlyDB=False):
79 |         super().__init__()
80 |
81 |         self.input_transform = input_transform
82 |
83 |         self.dbStruct = parse_dbStruct(structFile)
84 |         self.images = [join(root_dir, dbIm) for dbIm in self.dbStruct.dbImage]
85 |         if not onlyDB:
86 |             self.images += [join(queries_dir, qIm) for qIm in self.dbStruct.qImage]
87 |
88 |         self.whichSet = self.dbStruct.whichSet
89 |         self.dataset = self.dbStruct.dataset
90 |
91 |         self.positives = None
92 |         self.distances = None
93 |
94 |     def __getitem__(self, index):
95 |         img = Image.open(self.images[index])
96 |
97 |         if self.input_transform:
98 |             img = self.input_transform(img)
99 |
100 |         return img, index
101 |
102 |     def __len__(self):
103 |         return len(self.images)
104 |
105 |     def getPositives(self):
106 |         # positives for evaluation are those within trivial threshold range
107 |         # fit NN to find them, search by radius
108 |         if self.positives is None:
109 |             knn = NearestNeighbors(n_jobs=-1)
110 |             knn.fit(self.dbStruct.utmDb)
111 |
112 |             self.distances, self.positives = knn.radius_neighbors(self.dbStruct.utmQ,
113 |
radius=self.dbStruct.posDistThr) 114 | 115 | return self.positives 116 | 117 | def collate_fn(batch): 118 | """Creates mini-batch tensors from the list of tuples (query, positive, negatives). 119 | 120 | Args: 121 | data: list of tuple (query, positive, negatives). 122 | - query: torch tensor of shape (3, h, w). 123 | - positive: torch tensor of shape (3, h, w). 124 | - negative: torch tensor of shape (n, 3, h, w). 125 | Returns: 126 | query: torch tensor of shape (batch_size, 3, h, w). 127 | positive: torch tensor of shape (batch_size, 3, h, w). 128 | negatives: torch tensor of shape (batch_size, n, 3, h, w). 129 | """ 130 | 131 | batch = list(filter (lambda x:x is not None, batch)) 132 | if len(batch) == 0: return None, None, None, None, None 133 | 134 | query, positive, negatives, indices = zip(*batch) 135 | 136 | query = data.dataloader.default_collate(query) 137 | positive = data.dataloader.default_collate(positive) 138 | negCounts = data.dataloader.default_collate([x.shape[0] for x in negatives]) 139 | negatives = torch.cat(negatives, 0) 140 | import itertools 141 | indices = list(itertools.chain(*indices)) 142 | 143 | return query, positive, negatives, negCounts, indices 144 | 145 | class QueryDatasetFromStruct(data.Dataset): 146 | def __init__(self, structFile, nNegSample=1000, nNeg=10, margin=0.1, input_transform=None): 147 | super().__init__() 148 | 149 | self.input_transform = input_transform 150 | self.margin = margin 151 | 152 | self.dbStruct = parse_dbStruct(structFile) 153 | self.whichSet = self.dbStruct.whichSet 154 | self.dataset = self.dbStruct.dataset 155 | self.nNegSample = nNegSample # number of negatives to randomly sample 156 | self.nNeg = nNeg # number of negatives used for training 157 | 158 | # potential positives are those within nontrivial threshold range 159 | #fit NN to find them, search by radius 160 | knn = NearestNeighbors(n_jobs=-1) 161 | knn.fit(self.dbStruct.utmDb) 162 | 163 | # TODO use sqeuclidean as metric? 164 | self.nontrivial_positives = list(knn.radius_neighbors(self.dbStruct.utmQ, 165 | radius=self.dbStruct.nonTrivPosDistSqThr**0.5, 166 | return_distance=False)) 167 | # radius returns unsorted, sort once now so we dont have to later 168 | for i,posi in enumerate(self.nontrivial_positives): 169 | self.nontrivial_positives[i] = np.sort(posi) 170 | # its possible some queries don't have any non trivial potential positives 171 | # lets filter those out 172 | self.queries = np.where(np.array([len(x) for x in self.nontrivial_positives])>0)[0] 173 | 174 | # potential negatives are those outside of posDistThr range 175 | potential_positives = knn.radius_neighbors(self.dbStruct.utmQ, 176 | radius=self.dbStruct.posDistThr, 177 | return_distance=False) 178 | 179 | self.potential_negatives = [] 180 | for pos in potential_positives: 181 | self.potential_negatives.append(np.setdiff1d(np.arange(self.dbStruct.numDb), 182 | pos, assume_unique=True)) 183 | 184 | self.cache = None # filepath of HDF5 containing feature vectors for images 185 | 186 | self.negCache = [np.empty((0,)) for _ in range(self.dbStruct.numQ)] 187 | 188 | def __getitem__(self, index): 189 | index = self.queries[index] # re-map index to match dataset 190 | with h5py.File(self.cache, mode='r') as h5: 191 | h5feat = h5.get("features") 192 | 193 | qOffset = self.dbStruct.numDb 194 | qFeat = h5feat[index+qOffset] 195 | 196 | posFeat = h5feat[self.nontrivial_positives[index].tolist()] 197 | knn = NearestNeighbors(n_jobs=-1) # TODO replace with faiss? 
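            # Hard-example mining on the cached descriptors: the nearest
            # nontrivial positive sets dPos, then negatives violating the margin
            # (dNeg < dPos + margin**0.5) are kept; the margin is sqrt'd because
            # these are Euclidean (not squared) distances, matching the sqrt'd
            # margin used for the TripletMarginLoss in main.py.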
198 |             knn.fit(posFeat)
199 |             dPos, posNN = knn.kneighbors(qFeat.reshape(1,-1), 1)
200 |             dPos = dPos.item()
201 |             posIndex = self.nontrivial_positives[index][posNN[0]].item()
202 |
203 |             negSample = np.random.choice(self.potential_negatives[index], self.nNegSample)
204 |             negSample = np.unique(np.concatenate([self.negCache[index], negSample]))
205 |
206 |             negFeat = h5feat[list(map(int, negSample))] # h5py needs int indices; negSample is float64 after concatenating with the empty negCache (matches pittsburgh.py)
207 |             knn.fit(negFeat)
208 |
209 |             dNeg, negNN = knn.kneighbors(qFeat.reshape(1,-1),
210 |                                          self.nNeg*10) # to quote netvlad paper code: 10x is hacky but fine
211 |             dNeg = dNeg.reshape(-1)
212 |             negNN = negNN.reshape(-1)
213 |
214 |             # try to find negatives that are within margin, if there aren't any return none
215 |             violatingNeg = dNeg < dPos + self.margin**0.5
216 |
217 |             if np.sum(violatingNeg) < 1:
218 |                 #if none are violating then skip this query
219 |                 return None
220 |
221 |             negNN = negNN[violatingNeg][:self.nNeg]
222 |             negIndices = negSample[negNN].astype(np.int32)
223 |             self.negCache[index] = negIndices
224 |
225 |         query = Image.open(join(queries_dir, self.dbStruct.qImage[index]))
226 |         positive = Image.open(join(root_dir, self.dbStruct.dbImage[posIndex]))
227 |
228 |         if self.input_transform:
229 |             query = self.input_transform(query)
230 |             positive = self.input_transform(positive)
231 |
232 |         negatives = []
233 |         for negIndex in negIndices:
234 |             negative = Image.open(join(root_dir, self.dbStruct.dbImage[negIndex]))
235 |             if self.input_transform:
236 |                 negative = self.input_transform(negative)
237 |             negatives.append(negative)
238 |
239 |         negatives = torch.stack(negatives, 0)
240 |
241 |         return query, positive, negatives, [index, posIndex]+negIndices.tolist()
242 |
243 |     def __len__(self):
244 |         return len(self.queries)
245 |
--------------------------------------------------------------------------------