├── .gitignore
├── LICENSE
├── README.md
├── assets
│   └── seqmatchnet.jpg
├── datasets.py
├── download.sh
├── get_datasets.py
├── get_models.py
├── main.py
├── seqMatchNet.py
├── structFiles
│   ├── nordland_test_d-1_d2-1.db
│   ├── nordland_train_d-40_d2-10.db
│   ├── nordland_val_d-1_d2-1.db
│   ├── oxford-v1.0_2015-03-17-11-08-44_2014-12-16-18-44-24_test_d-10_d2-5.db
│   ├── oxford-v1.0_2015-03-17-11-08-44_2014-12-16-18-44-24_train_d-20_d2-5.db
│   └── oxford-v1.0_splitInds.npz
├── test.py
├── train.py
└── utils.py

/.gitignore:
--------------------------------------------------------------------------------
1 | .ipynb*
2 | *__pycache__/*
3 | data
4 | data/*
5 | wandb/*
6 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Sourav Garg
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SeqMatchNet
2 | Code for the CoRL 2021 Oral paper "SeqMatchNet: Contrastive Learning with Sequence Matching for Place Recognition and Relocalization"
3 |
4 | [[OpenReview](https://openreview.net/forum?id=OQMXb0xiCrt)] [[PDF](https://openreview.net/pdf?id=OQMXb0xiCrt)] [[CoRL 2021 YouTube Video](https://www.youtube.com/watch?v=Rb2Tbu72rG0)]
5 |
6 | <p align="center">
7 | <img src="assets/seqmatchnet.jpg" alt="SeqMatchNet">
8 | <br><em>SeqMatchNet: Contrastive Learning with Sequence Matching.</em>
9 | </p>
10 |
11 | ## Setup
12 | ### Conda
13 | ```bash
14 | conda create -n seqnet numpy pytorch=1.8.0 torchvision tqdm scikit-learn faiss tensorboardx h5py wandb -c pytorch -c conda-forge
15 | ```
16 |
17 | ### Download
18 | Run `bash download.sh` to download single-image NetVLAD descriptors (3.4 GB) for the Nordland-clean dataset [[a]](#nordclean) and the Oxford dataset (0.3 GB) [[b]](#saveLoc).
19 |
20 | You can download the model trained on Oxford from [here](https://cloudstor.aarnet.edu.au/plus/s/y27PHvmZ2xpmcId).
21 | ## Run
22 |
23 | ### Train
24 | To train SeqMatchNet on the Oxford dataset with both the loss and negative mining based on sequence matching (the sketch at the end of this README illustrates what these two options compute):
25 | ```bash
26 | python main.py --mode train --seqL 5 --pooling --dataset oxford-v1.0 --loss_trip_method meanOfPairs --neg_trip_method meanOfPairs --expName ox10_MoP_negMoP
27 | ```
28 | For the Nordland dataset:
29 | ```bash
30 | python main.py --mode train --seqL 5 --pooling --dataset nordland-sw --loss_trip_method meanOfPairs --neg_trip_method meanOfPairs --expName nord-sw_MoP_negMoP
31 | ```
32 |
33 | To train without sequence matching, that is, using only the central image of each sequence:
34 | ```bash
35 | python main.py --mode train --seqL 5 --pooling --dataset oxford-v1.0 --loss_trip_method centerOnly --neg_trip_method centerOnly --expName ox10_CO_negCO
36 | ```
37 |
38 | ### Test
39 | ```bash
40 | python main.py --mode test --seqL 5 --pooling --dataset oxford-v1.0 --split test --resume ./data/runs/
41 | ```
42 |
43 | ## Acknowledgement
44 | The code in this repository is based on [oravus/seqNet](https://github.com/oravus/seqNet) and [Nanne/pytorch-NetVlad](https://github.com/Nanne/pytorch-NetVlad).
45 |
46 | ## Citation
47 | ```
48 | @inproceedings{garg2021seqmatchnet,
49 |   title={SeqMatchNet: Contrastive Learning with Sequence Matching for Place Recognition \& Relocalization},
50 |   author={Garg, Sourav and Vankadari, Madhu and Milford, Michael},
51 |   booktitle={5th Annual Conference on Robot Learning},
52 |   year={2021}
53 | }
54 | ```
55 |
56 | #### Other Related Projects
57 | [SeqNet](https://github.com/oravus/seqNet);
58 | [Delta Descriptors (2020)](https://github.com/oravus/DeltaDescriptors);
59 | [Patch-NetVLAD (2021)](https://github.com/QVPR/Patch-NetVLAD);
60 | [CoarseHash (2020)](https://github.com/oravus/CoarseHash);
61 | [seq2single (2019)](https://github.com/oravus/seq2single);
62 | [LoST (2018)](https://github.com/oravus/lostX)
63 |
64 | [a] This is the clean version of the dataset that excludes images from the tunnels and red lights and can be downloaded from [here](https://cloudstor.aarnet.edu.au/plus/s/8L7loyTZjK0FsWT).
65 |
66 | [b] These will automatically save to `./data/`; you can modify this path in [download.sh](https://github.com/oravus/seqNet/blob/main/download.sh) and [get_datasets.py](https://github.com/oravus/seqNet/blob/5450829c4294fe1d14966bfa1ac9b7c93237369b/get_datasets.py#L6) to specify your workdir.
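#### What `meanOfPairs` and `centerOnly` compute

Both options reduce a pair of aligned descriptor sequences to a single distance, mirroring `aggregateSeqScore` in [seqMatchNet.py](./seqMatchNet.py). A minimal standalone sketch (assuming L2-normalized descriptors, as produced by the model's pooling layer):

```python
import torch

def seq_distance(q, r, method='meanOfPairs'):
    """Distance between two aligned descriptor sequences q, r of shape [T, C]."""
    dMat = torch.cdist(q.unsqueeze(0), r.unsqueeze(0)).squeeze(0)  # [T, T] pairwise L2 distances
    pairDists = torch.diagonal(dMat)           # distances of the aligned pairs (q_t, r_t)
    if method == 'centerOnly':
        return pairDists[len(pairDists) // 2]  # single-image matching on the central frame
    return pairDists.mean()                    # 'meanOfPairs': average over all T aligned pairs
```

With `--seqL 5`, `meanOfPairs` averages five per-frame distances, whereas `centerOnly` ignores the sequence and matches only the central frame.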
67 | -------------------------------------------------------------------------------- /assets/seqmatchnet.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oravus/SeqMatchNet/431f2a963587e1bbfd3f840d7d6c299207f8d88c/assets/seqmatchnet.jpg -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data as data 3 | import itertools 4 | 5 | import os 6 | from os.path import join, exists 7 | from scipy.io import loadmat, savemat 8 | import numpy as np 9 | from collections import namedtuple 10 | 11 | from sklearn.neighbors import NearestNeighbors 12 | import faiss 13 | import h5py 14 | 15 | dbStruct = namedtuple('dbStruct', ['whichSet', 'dataset', 16 | 'dbImage', 'utmDb', 'qImage', 'utmQ', 'numDb', 'numQ', 17 | 'posDistThr', 'posDistSqThr', 'nonTrivPosDistSqThr', 18 | 'dbTimeStamp', 'qTimeStamp', 'gpsDb', 'gpsQ']) 19 | 20 | class Dataset(): 21 | def __init__(self, dataset_name, train_mat_file, test_mat_file, val_mat_file, opt): 22 | self.dataset_name = dataset_name 23 | self.train_mat_file = train_mat_file 24 | self.test_mat_file = test_mat_file 25 | self.val_mat_file = val_mat_file 26 | self.struct_dir = "./structFiles/" 27 | self.seqL = opt.seqL 28 | self.seqL_filterData = opt.seqL_filterData 29 | self.matcher = opt.matcher 30 | 31 | # descriptor settings 32 | self.dbDescs = None 33 | self.qDescs = None 34 | self.trainInds = None 35 | self.valInds = None 36 | self.testInds = None 37 | self.db_seqBounds = None 38 | self.q_seqBounds = None 39 | 40 | def loadPreComputedDescriptors(self,ft1,ft2,seqBounds=None): 41 | self.dbDescs = np.expand_dims(ft1,(2,3)) 42 | self.qDescs = np.expand_dims(ft2,(2,3)) 43 | encDim = self.dbDescs.shape[1] 44 | print("All Db descs: ", self.dbDescs.shape) 45 | print("All Qry descs: ", self.qDescs.shape) 46 | if seqBounds is None: 47 | self.db_seqBounds = None 48 | self.q_seqBounds = None 49 | else: 50 | self.db_seqBounds = seqBounds[0] 51 | self.q_seqBounds = seqBounds[1] 52 | return encDim 53 | 54 | def get_whole_training_set(self, onlyDB=False): 55 | structFile = join(self.struct_dir, self.train_mat_file) 56 | indsSplit = self.trainInds 57 | return WholeDatasetFromStruct( structFile, indsSplit, self.dbDescs, self.qDescs, seqL=self.seqL, onlyDB=onlyDB, seqBounds=[self.db_seqBounds,self.q_seqBounds],seqL_filterData=self.seqL) 58 | 59 | def get_whole_val_set(self): 60 | structFile = join(self.struct_dir, self.val_mat_file) 61 | indsSplit = self.valInds 62 | return WholeDatasetFromStruct(structFile, indsSplit, self.dbDescs, self.qDescs, seqL=self.seqL, seqBounds=[self.db_seqBounds,self.q_seqBounds],seqL_filterData=self.seqL_filterData) 63 | 64 | def get_whole_test_set(self): 65 | if self.test_mat_file is not None: 66 | structFile = join(self.struct_dir, self.test_mat_file) 67 | indsSplit = self.testInds 68 | return WholeDatasetFromStruct(structFile, indsSplit, self.dbDescs, self.qDescs, seqL=self.seqL, seqBounds=[self.db_seqBounds,self.q_seqBounds],seqL_filterData=self.seqL_filterData) 69 | else: 70 | raise ValueError('test set not available for dataset ' + self.dataset_name) 71 | 72 | def get_training_query_set(self, margin=0.1, nNegSample=1000): 73 | structFile = join(self.struct_dir, self.train_mat_file) 74 | indsSplit = self.trainInds 75 | return QueryDatasetFromStruct(structFile,indsSplit, self.dbDescs, self.qDescs, nNegSample=nNegSample, 
margin=margin, seqL=self.seqL, seqBounds=[self.db_seqBounds,self.q_seqBounds], matcher=self.matcher)
76 |
77 |     def get_val_query_set(self):
78 |         structFile = join(self.struct_dir, self.val_mat_file)
79 |         indsSplit = self.valInds
80 |         return QueryDatasetFromStruct(structFile, indsSplit, self.dbDescs, self.qDescs, seqL=self.seqL, seqBounds=[self.db_seqBounds,self.q_seqBounds])
81 |
82 |     @staticmethod
83 |     def collate_fn(batch):
84 |         """Creates mini-batch tensors from the list of tuples (query, positive, negatives).
85 |
86 |         Args:
87 |             batch: list of tuple (query, positive, negatives).
88 |                 - query: torch tensor of shape (T, C). e.g. (5,4096)
89 |                 - positive: torch tensor of shape (T, C).
90 |                 - negative: torch tensor of shape (N, T, C).
91 |         Returns:
92 |             query: torch tensor of shape (batch_size, T, C).
93 |             positive: torch tensor of shape (batch_size, T, C).
94 |             negatives: torch tensor of shape (total_negs, T, C), i.e. all negatives concatenated along dim 0; negCounts (negatives per query) and indices are returned alongside.
95 |         """
96 |
97 |         batch = list(filter(lambda x: x is not None, batch))
98 |         if len(batch) == 0:
99 |             return None, None, None, None, None
100 |
101 |         query, positive, negatives, indices = zip(*batch)
102 |
103 |         query = data.dataloader.default_collate(query)
104 |         positive = data.dataloader.default_collate(positive)
105 |         negCounts = data.dataloader.default_collate([x.shape[0] for x in negatives])
106 |         negatives = torch.cat(negatives, 0)
107 |         indices = list(itertools.chain(*indices))
108 |
109 |         return query, positive, negatives, negCounts, indices
110 |
111 | def getSeqInds(idx,seqL,maxNum,minNum=0,retLenDiff=False):
112 |     seqLOrig = seqL
113 |     seqInds = np.arange(max(minNum,idx-seqL//2),min(idx+seqL-seqL//2,maxNum),1)
114 |     lenDiff = seqLOrig - len(seqInds)
115 |     if retLenDiff:
116 |         return lenDiff
117 |
118 |     if seqInds[0] == minNum:
119 |         seqInds = np.concatenate([seqInds,np.arange(seqInds[-1]+1,seqInds[-1]+1+lenDiff,1)])
120 |     elif lenDiff > 0 and seqInds[-1] == maxNum-1:
121 |         seqInds = np.concatenate([np.arange(seqInds[0]-lenDiff,seqInds[0],1),seqInds])
122 |     return seqInds
123 |
124 | def getValidSeqInds(seqBounds,seqL):
125 |     validFlags = []
126 |     for i in range(len(seqBounds)):
127 |         sIdMin, sIdMax = seqBounds[i]
128 |         lenDiff = getSeqInds(i,seqL,sIdMax,minNum=sIdMin,retLenDiff=True)
129 |         validFlags.append(lenDiff == 0)
130 |     return validFlags
131 |
132 | def parse_db_struct(path):
133 |     mat = loadmat(path)
134 |
135 |     fieldnames = list(mat['dbStruct'][0, 0].dtype.names)
136 |
137 |     dataset = mat['dbStruct'][0, 0]['dataset'].item()
138 |     whichSet = mat['dbStruct'][0, 0]['whichSet'].item()
139 |
140 |     dbImage = [f[0].item() for f in mat['dbStruct'][0, 0]['dbImageFns']]
141 |     qImage = [f[0].item() for f in mat['dbStruct'][0, 0]['qImageFns']]
142 |
143 |     numDb = mat['dbStruct'][0, 0]['numImages'].item()
144 |     numQ = mat['dbStruct'][0, 0]['numQueries'].item()
145 |
146 |     posDistThr = mat['dbStruct'][0, 0]['posDistThr'].item()
147 |     posDistSqThr = mat['dbStruct'][0, 0]['posDistSqThr'].item()
148 |     if 'nonTrivPosDistSqThr' in fieldnames:
149 |         nonTrivPosDistSqThr = mat['dbStruct'][0, 0]['nonTrivPosDistSqThr'].item()
150 |     else:
151 |         nonTrivPosDistSqThr = None
152 |
153 |     if 'dbTimeStamp' in fieldnames and 'qTimeStamp' in fieldnames:
154 |         dbTimeStamp = [f[0].item() for f in mat['dbStruct'][0, 0]['dbTimeStamp'].T]
155 |         qTimeStamp = [f[0].item() for f in mat['dbStruct'][0, 0]['qTimeStamp'].T]
156 |         dbTimeStamp = np.array(dbTimeStamp)
157 |         qTimeStamp = np.array(qTimeStamp)
158 |     else:
159 |         dbTimeStamp = None
160 |         qTimeStamp = None
161 |
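    # optional ground-truth coordinates parsed below: utmDb/utmQ are consumed by the dataset classes for positive/negative mining, while gpsDb/gpsQ are only carried through the struct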
162 |     if 'utmQ' in fieldnames and 'utmDb' in fieldnames:
163 |         utmDb = mat['dbStruct'][0, 0]['utmDb'].T
164 |         utmQ = mat['dbStruct'][0, 0]['utmQ'].T
165 |     else:
166 |         utmQ = None
167 |         utmDb = None
168 |
169 |     if 'gpsQ' in fieldnames and 'gpsDb' in fieldnames:
170 |         gpsDb = mat['dbStruct'][0, 0]['gpsDb'].T
171 |         gpsQ = mat['dbStruct'][0, 0]['gpsQ'].T
172 |     else:
173 |         gpsQ = None
174 |         gpsDb = None
175 |
176 |     return dbStruct(whichSet, dataset, dbImage, utmDb, qImage, utmQ, numDb, numQ, posDistThr,
177 |                     posDistSqThr, nonTrivPosDistSqThr, dbTimeStamp, qTimeStamp, gpsDb, gpsQ)  # argument order must match the dbStruct namedtuple fields
178 |
179 |
180 | def save_db_struct(path, db_struct):
181 |     assert db_struct.numDb == len(db_struct.dbImage)
182 |     assert db_struct.numQ == len(db_struct.qImage)
183 |
184 |     inner_dict = {
185 |         'whichSet': db_struct.whichSet,
186 |         'dbImageFns': np.array(db_struct.dbImage, dtype=object).reshape(-1, 1),
187 |         'qImageFns': np.array(db_struct.qImage, dtype=object).reshape(-1, 1),
188 |         'numImages': db_struct.numDb,
189 |         'numQueries': db_struct.numQ,
190 |         'posDistThr': db_struct.posDistThr,
191 |         'posDistSqThr': db_struct.posDistSqThr,
192 |     }
193 |
194 |     if db_struct.dataset is not None:
195 |         inner_dict['dataset'] = db_struct.dataset
196 |
197 |     if db_struct.nonTrivPosDistSqThr is not None:
198 |         inner_dict['nonTrivPosDistSqThr'] = db_struct.nonTrivPosDistSqThr
199 |
200 |     if db_struct.utmDb is not None and db_struct.utmQ is not None:
201 |         assert db_struct.numDb == len(db_struct.utmDb)
202 |         assert db_struct.numQ == len(db_struct.utmQ)
203 |         inner_dict['utmDb'] = db_struct.utmDb.T
204 |         inner_dict['utmQ'] = db_struct.utmQ.T
205 |
206 |     if db_struct.gpsDb is not None and db_struct.gpsQ is not None:
207 |         assert db_struct.numDb == len(db_struct.gpsDb)
208 |         assert db_struct.numQ == len(db_struct.gpsQ)
209 |         inner_dict['gpsDb'] = db_struct.gpsDb.T
210 |         inner_dict['gpsQ'] = db_struct.gpsQ.T
211 |
212 |     if db_struct.dbTimeStamp is not None and db_struct.qTimeStamp is not None:
213 |         inner_dict['dbTimeStamp'] = db_struct.dbTimeStamp.astype(np.float64)
214 |         inner_dict['qTimeStamp'] = db_struct.qTimeStamp.astype(np.float64)
215 |
216 |     savemat(path, {'dbStruct': inner_dict})
217 |
218 | def getValidSeqData(seqBounds,seqL_filterData,dbStruct):
219 |     validFlags = getValidSeqInds(seqBounds,seqL_filterData)
220 |     validInds = np.argwhere(validFlags).flatten()
221 |     validInds_db = np.argwhere(validFlags[:dbStruct.numDb]).flatten()
222 |     validInds_q = np.argwhere(validFlags[dbStruct.numDb:]).flatten()
223 |     dbStruct = dbStruct._replace(utmDb=dbStruct.utmDb[validInds_db], numDb=len(validInds_db), utmQ=dbStruct.utmQ[validInds_q], numQ=len(validInds_q))
224 |     print("\n Num sequences violating boundaries: {} \n".format(len(validFlags)-len(validInds)))
225 |     return validInds, dbStruct, validInds_db, validInds_q
226 |
227 | class WholeDatasetFromStruct(data.Dataset):
228 |     def __init__(self, structFile, indsSplit, dbDescs, qDescs, onlyDB=False, seqL=1, seqBounds=None,seqL_filterData=None):
229 |         super().__init__()
230 |
231 |         self.seqL = seqL
232 |         self.filterBoundaryInds = seqL_filterData is not None
233 |
234 |         self.dbStruct = parse_db_struct(structFile)
235 |
236 |         self.images = dbDescs[indsSplit[0]]
237 |
238 |         if seqBounds[0] is None:
239 |             self.seqBounds = np.array([[0,len(self.images)] for _ in range(len(self.images))])
240 |
241 |         if not onlyDB:
242 |             qImages = qDescs[indsSplit[1]]
243 |             self.images = np.concatenate([self.images,qImages],0)
244 |             if seqBounds[0] is None:
245 |                 q_seqBounds =
np.array([[len(self.seqBounds),len(self.images)] for _ in range(len(qImages))])
246 |                 self.seqBounds = np.vstack([self.seqBounds,q_seqBounds])
247 |
248 |         if seqBounds[0] is not None:
249 |             db_seqBounds = seqBounds[0][indsSplit[0]]
250 |             q_seqBounds = db_seqBounds[-1,-1] + seqBounds[1][indsSplit[1]]
251 |             self.seqBounds = np.vstack([db_seqBounds,q_seqBounds])
252 |
253 |         self.validInds = np.arange(len(self.images))
254 |         self.numDb_full, self.numQ_full = self.dbStruct.numDb, self.dbStruct.numQ
255 |         # update dbStruct and size variables with valid sequence based indices
256 |         if self.filterBoundaryInds:
257 |             self.validInds, self.dbStruct, _, _ = getValidSeqData(self.seqBounds,seqL_filterData,self.dbStruct)
258 |         self.numDb_valid, self.numQ_valid = self.dbStruct.numDb, self.dbStruct.numQ
259 |
260 |         self.whichSet = self.dbStruct.whichSet
261 |         self.dataset = self.dbStruct.dataset
262 |
263 |         self.positives = None
264 |         self.distances = None
265 |
266 |     def getSeqIndsFromValidInds(self,index):
267 |         sIdMin, sIdMax = self.seqBounds[index]
268 |         return getSeqInds(index,self.seqL,sIdMax,minNum=sIdMin)
269 |
270 |     def getIndices(self,index):
271 |         index = self.validInds[index]
272 |         sIdMin, sIdMax = self.seqBounds[index]
273 |         return getSeqInds(index,self.seqL,sIdMax,minNum=sIdMin)
274 |
275 |     def __getitem__(self, index):
276 |         img = self.images[np.array([index])]
277 |         return img, index
278 |
279 |     def __len__(self):
280 |         return len(self.images)
281 |
282 |     def get_positives(self,retDists=False):
283 |         # positives for evaluation are those within trivial threshold range
284 |         # fit NN to find them, search by radius
285 |         if self.positives is None:
286 |             knn = NearestNeighbors(n_jobs=-1)
287 |             knn.fit(self.dbStruct.utmDb)
288 |
289 |             print("Using Localization Radius: ", self.dbStruct.posDistThr)
290 |             self.distances, self.positives = knn.radius_neighbors(self.dbStruct.utmQ, radius=self.dbStruct.posDistThr)
291 |
292 |         if retDists:
293 |             return self.positives, self.distances
294 |         else:
295 |             return self.positives
296 |
297 |
298 | class QueryDatasetFromStruct(data.Dataset):
299 |     def __init__(self, structFile, indsSplit, dbDescs, qDescs, nNegSample=1000, nNeg=10, margin=0.1, seqL=1, seqBounds=None, matcher=None):
300 |         super().__init__()
301 |
302 |         self.matcher = matcher
303 |         self.seqL = seqL
304 |
305 |         self.dbDescs = dbDescs[indsSplit[0]]
306 |         self.qDescs = qDescs[indsSplit[1]]
307 |
308 |         self.margin = margin
309 |
310 |         self.dbStruct = parse_db_struct(structFile)
311 |
312 |         if seqBounds[0] is None:
313 |             self.db_seqBounds = np.array([[0,len(self.dbDescs)] for _ in range(len(self.dbDescs))])
314 |             self.q_seqBounds = np.array([[0,len(self.qDescs)] for _ in range(len(self.qDescs))])
315 |         else:
316 |             self.db_seqBounds = seqBounds[0][indsSplit[0]]
317 |             self.q_seqBounds = seqBounds[1][indsSplit[1]]
318 |
319 |         # update dbStruct and size variables with valid sequence based indices
320 |         self.numDb_full, self.numQ_full = self.dbStruct.numDb, self.dbStruct.numQ
321 |         validFlags_db = getValidSeqInds(self.db_seqBounds,seqL)
322 |         validFlags_q = getValidSeqInds(self.q_seqBounds,seqL)
323 |         self.validInds_db, self.validInds_q = np.argwhere(validFlags_db).flatten(), np.argwhere(validFlags_q).flatten()
324 |         self.dbStruct = self.dbStruct._replace(utmDb=self.dbStruct.utmDb[self.validInds_db], numDb=len(self.validInds_db), utmQ=self.dbStruct.utmQ[self.validInds_q], numQ=len(self.validInds_q))
325 |         # self.numDb_valid, self.numQ_valid = self.dbStruct.numDb, self.dbStruct.numQ
326 |         print("\n Num
sequences (db) violating boundaries: {}".format(len(validFlags_db)-len(self.validInds_db)))
327 |         print("Num sequences (q) violating boundaries: {} \n".format(len(validFlags_q)-len(self.validInds_q)))
328 |
329 |         self.whichSet = self.dbStruct.whichSet
330 |         self.dataset = self.dbStruct.dataset
331 |         self.nNegSample = nNegSample # number of negatives to randomly sample
332 |         self.nNeg = nNeg # number of negatives used for training
333 |
334 |         self.use_faiss = True
335 |         self.use_h5disMat = False
336 |         if self.matcher is not None:
337 |             self.use_faiss = False
338 |             self.use_h5disMat = True
339 |         assert(self.use_faiss!=self.use_h5disMat)
340 |
341 |         # potential positives are those within nontrivial threshold range
342 |         # fit NN to find them, search by radius
343 |         knn = NearestNeighbors(n_jobs=-1)
344 |         knn.fit(self.dbStruct.utmDb)
345 |
346 |         # TODO use sqeuclidean as metric?
347 |         self.nontrivial_distances, self.nontrivial_positives = \
348 |                 knn.radius_neighbors(self.dbStruct.utmQ, radius=self.dbStruct.nonTrivPosDistSqThr**0.5,
349 |                 return_distance=True)
350 |
351 |         self.nontrivial_positives = list(self.nontrivial_positives)
352 |
353 |         # radius returns unsorted, sort once now so we don't have to later
354 |         for i, posi in enumerate(self.nontrivial_positives):
355 |             self.nontrivial_positives[i] = np.sort(posi)
356 |
357 |         # it's possible some queries don't have any non-trivial potential positives,
358 |         # let's filter those out
359 |         self.queries = np.where(np.array([len(x) for x in self.nontrivial_positives]) > 0)[0]
360 |         print("\n Queries within range ",len(self.queries), len(self.nontrivial_positives),"\n")
361 |
362 |         # potential negatives are those outside of posDistThr range
363 |         potential_positives = knn.radius_neighbors(self.dbStruct.utmQ,
364 |                 radius=self.dbStruct.posDistThr,
365 |                 return_distance=False)
366 |
367 |         self.potential_negatives = []
368 |         for pos in potential_positives:
369 |             self.potential_negatives.append(np.setdiff1d(np.arange(self.dbStruct.numDb), pos, assume_unique=True))
370 |
371 |         self.cache = None # filepath of HDF5 containing feature vectors for images
372 |         self.h5feat = None
373 |
374 |         self.negCache = [np.empty((0,)) for _ in range(self.dbStruct.numQ)]
375 |
376 |     def __getitem__(self, index):
377 |         with h5py.File(self.cache, mode='r') as h5:
378 |             h5feat = h5.get("features")
379 |             h5disMat = h5.get("disMat")
380 |
381 |             qOffset = self.dbStruct.numDb
382 |             qFeat = h5feat[index + qOffset]
383 |             qDis = h5disMat[:,index]
384 |
385 |             if len(self.nontrivial_positives[index]) < 1:
386 |                 # no non-trivial positives for this query, skip it
387 |                 return None
388 |
389 |             if self.use_h5disMat:
390 |                 posDis = qDis[self.nontrivial_positives[index]]
391 |                 posNN = np.argmin(posDis)
392 |                 dPos = posDis[posNN]
393 |                 posIndex = self.nontrivial_positives[index][posNN]
394 |             else:
395 |                 posFeat = h5feat[self.nontrivial_positives[index].tolist()]
396 |                 if self.use_faiss:
397 |                     faiss_index = faiss.IndexFlatL2(posFeat.shape[1])
398 |                     # noinspection PyArgumentList
399 |                     faiss_index.add(posFeat)
400 |                     # noinspection PyArgumentList
401 |                     dPos, posNN = faiss_index.search(qFeat.reshape(1, -1), 1)#posFeat.shape[0])
402 |                     dPos = np.sqrt(dPos)  # faiss returns squared distance
403 |                     dPos = dPos[0][-1].item()
404 |                     posIndex = self.nontrivial_positives[index][posNN[0,-1]].item()
405 |                 else:
406 |                     knn = NearestNeighbors(n_jobs=-1)
407 |                     knn.fit(posFeat)
408 |                     dPos, posNN = knn.kneighbors(qFeat.reshape(1, -1), 1)#posFeat.shape[0])
409 |                     dPos = dPos[0][-1].item()
410 |                     posIndex =
self.nontrivial_positives[index][posNN[0,-1]].item()
411 |
412 |             if self.use_h5disMat:
413 |                 # negSample = self.potential_negatives[index]
414 |                 negSample = np.random.choice(self.potential_negatives[index], self.nNegSample)
415 |                 negSample = np.unique(np.concatenate([self.negCache[index], negSample]))
416 |                 negSample = np.sort(negSample).astype(int)
417 |                 negDis = qDis[negSample]
418 |                 negNN = np.argsort(negDis)[:self.nNeg * 10]
419 |                 dNeg = negDis[negNN]
420 |             else:
421 |                 negSample = np.random.choice(self.potential_negatives[index], self.nNegSample)
422 |                 negSample = np.unique(np.concatenate([self.negCache[index], negSample]))
423 |                 negSample = np.sort(negSample).astype(int) # essential to order ascending, speeds up h5 by about double
424 |                 negFeat = h5feat[negSample.tolist()]
425 |                 if self.use_faiss:
426 |                     faiss_index = faiss.IndexFlatL2(negFeat.shape[1])
427 |                     # noinspection PyArgumentList
428 |                     faiss_index.add(negFeat)
429 |                     # noinspection PyArgumentList
430 |                     dNeg, negNN = faiss_index.search(qFeat.reshape(1, -1), self.nNeg * 10)
431 |                     dNeg = np.sqrt(dNeg)
432 |                 else:
433 |                     knn.fit(negFeat)
434 |
435 |                     # to quote netvlad paper code: 10x is hacky but fine
436 |                     dNeg, negNN = knn.kneighbors(qFeat.reshape(1, -1), self.nNeg * 10)
437 |
438 |                 dNeg = dNeg.reshape(-1)
439 |                 negNN = negNN.reshape(-1)
440 |
441 |             # try to find negatives that are within margin, if there aren't any return None
442 |             violatingNeg = dNeg < dPos + self.margin**0.5
443 |
444 |             if np.sum(violatingNeg) < 1:
445 |                 # if none are violating then skip this query
446 |                 return None
447 |
448 |             negNN = negNN[violatingNeg][:self.nNeg]
449 |             negIndices = negSample[negNN].astype(np.int32)
450 |             self.negCache[index] = negIndices
451 |
452 |
453 |         sIdMin_q, sIdMax_q = self.q_seqBounds[self.validInds_q[index]]
454 |         query = self.qDescs[getSeqInds(self.validInds_q[index],self.seqL,sIdMax_q,sIdMin_q)]
455 |         sIdMin_p, sIdMax_p = self.db_seqBounds[self.validInds_db[posIndex]]
456 |         positive = self.dbDescs[getSeqInds(self.validInds_db[posIndex],self.seqL,sIdMax_p,sIdMin_p)]
457 |         negatives = []
458 |         for negIndex in negIndices:
459 |             sIdMin_n, sIdMax_n = self.db_seqBounds[self.validInds_db[negIndex]]
460 |             negative = self.dbDescs[getSeqInds(self.validInds_db[negIndex],self.seqL,sIdMax_n,sIdMin_n)]
461 |             negative = torch.tensor(negative)
462 |             negatives.append(negative)
463 |
464 |         negatives = torch.stack(negatives, 0)
465 |
466 |         # noinspection PyTypeChecker
467 |         return query, positive, negatives, [index, posIndex] + negIndices.tolist()
468 |
469 |     def __len__(self):
470 |         return len(self.validInds_q)
471 |
--------------------------------------------------------------------------------
/download.sh:
--------------------------------------------------------------------------------
1 | # download nordland-clean dataset
2 | wget -cO - https://cloudstor.aarnet.edu.au/plus/s/PK98pDvLAesL1aL/download > nordland-clean.zip
3 | mkdir -p ./data/
4 | unzip nordland-clean.zip -d ./data/
5 | rm nordland-clean.zip
6 |
7 | # download oxford descriptors
8 | wget -cO - https://cloudstor.aarnet.edu.au/plus/s/T0M1Ry4HXOAkkGz/download > oxford_2014-12-16-18-44-24_stereo_left.npy
9 | wget -cO - https://cloudstor.aarnet.edu.au/plus/s/vr21RnhMmOkW8S9/download > oxford_2015-03-17-11-08-44_stereo_left.npy
10 | mv oxford* ./data/descData/netvlad-pytorch/
--------------------------------------------------------------------------------
/get_datasets.py:
--------------------------------------------------------------------------------
1 | from datasets import Dataset
2 | from torch.utils.data import DataLoader, SubsetRandomSampler 3 | import numpy as np 4 | from os.path import join 5 | 6 | prefix_data = "./data/" 7 | 8 | def get_dataset(opt): 9 | 10 | if 'nordland' in opt.dataset.lower(): 11 | dataset = Dataset('nordland', 'nordland_train_d-40_d2-10.db', 'nordland_test_d-1_d2-1.db', 'nordland_val_d-1_d2-1.db', opt) # train, test, val structs 12 | if 'sw' in opt.dataset.lower(): 13 | ref, qry = 'summer', 'winter' 14 | elif 'sf' in opt.dataset.lower(): 15 | ref, qry = 'spring', 'fall' 16 | ft1 = np.load(join(prefix_data,"descData/{}/nordland-clean-{}.npy".format(opt.descType,ref))) 17 | ft2 = np.load(join(prefix_data,"descData/{}/nordland-clean-{}.npy".format(opt.descType,qry))) 18 | trainInds, testInds, valInds = np.arange(15000), np.arange(15100,18100), np.arange(18200,21200) 19 | 20 | dataset.trainInds = [trainInds, trainInds] 21 | dataset.valInds = [valInds, valInds] 22 | dataset.testInds = [testInds, testInds] 23 | encoder_dim = dataset.loadPreComputedDescriptors(ft1,ft2) 24 | 25 | elif 'oxford' in opt.dataset.lower(): 26 | ref, qry = '2015-03-17-11-08-44', '2014-12-16-18-44-24' 27 | structStr = "{}_{}_{}".format(opt.dataset,ref,qry) 28 | # note: for now temporarily use ox_test as ox_val 29 | if 'v1.0' in opt.dataset: 30 | testStr = '_test_d-10_d2-5.db' 31 | elif 'pnv' in opt.dataset: 32 | testStr = '_test_d-25_d2-5.db' 33 | dataset = Dataset(opt.dataset, structStr+'_train_d-20_d2-5.db', structStr+testStr, structStr+testStr, opt) # train, test, val structs 34 | ft1 = np.load(join(prefix_data,"descData/{}/oxford_{}_stereo_left.npy".format(opt.descType,ref))) 35 | ft2 = np.load(join(prefix_data,"descData/{}/oxford_{}_stereo_left.npy".format(opt.descType,qry))) 36 | splitInds = np.load("./structFiles/{}_splitInds.npz".format(opt.dataset), allow_pickle=True) 37 | 38 | dataset.trainInds = splitInds['trainInds'].tolist() 39 | dataset.valInds = splitInds['valInds'].tolist() 40 | dataset.testInds = splitInds['testInds'].tolist() 41 | encoder_dim = dataset.loadPreComputedDescriptors(ft1,ft2) 42 | 43 | else: 44 | raise Exception('Unknown dataset') 45 | 46 | return dataset, encoder_dim 47 | 48 | 49 | def get_splits(opt, dataset): 50 | whole_train_set, whole_training_data_loader, train_set, whole_test_set = None, None, None, None 51 | if opt.mode.lower() == 'train': 52 | whole_train_set = dataset.get_whole_training_set() 53 | whole_training_data_loader = DataLoader(dataset=whole_train_set, 54 | num_workers=opt.threads, batch_size=opt.cacheBatchSize, shuffle=False, 55 | pin_memory=not opt.nocuda) 56 | 57 | train_set = dataset.get_training_query_set(opt.margin) 58 | 59 | print('====> Training whole set:', len(whole_train_set)) 60 | print('====> Training query set:', len(train_set)) 61 | whole_test_set = dataset.get_whole_val_set() 62 | print('===> Evaluating on val set, query count:', whole_test_set.dbStruct.numQ) 63 | elif opt.mode.lower() == 'test': 64 | if opt.split.lower() == 'test': 65 | whole_test_set = dataset.get_whole_test_set() 66 | print('===> Evaluating on test set') 67 | elif opt.split.lower() == 'train': 68 | whole_test_set = dataset.get_whole_training_set() 69 | print('===> Evaluating on train set') 70 | elif opt.split.lower() in ['val']: 71 | whole_test_set = dataset.get_whole_val_set() 72 | print('===> Evaluating on val set') 73 | else: 74 | raise ValueError('Unknown dataset split: ' + opt.split) 75 | print('====> Query count:', whole_test_set.dbStruct.numQ) 76 | 77 | return whole_train_set, whole_training_data_loader, train_set, whole_test_set 
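For reference, `get_dataset` and `get_splits` are driven entirely by the parsed options; a minimal sketch of how they are wired together, mirroring `main.py` (the option values below are illustrative, and it assumes the descriptor files from `download.sh` are under `./data/`):

```python
from argparse import Namespace
from get_datasets import get_dataset, get_splits

# illustrative options; only the fields read by get_dataset/get_splits are set
opt = Namespace(dataset='nordland-sw', descType='netvlad-pytorch',
                seqL=5, seqL_filterData=None, matcher='seqMatchNet',
                mode='test', split='test')

dataset, encoder_dim = get_dataset(opt)             # loads .npy descriptors + struct files
_, _, _, whole_test_set = get_splits(opt, dataset)  # mode='test' returns only the eval split
print(encoder_dim, whole_test_set.dbStruct.numQ)
```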
78 | -------------------------------------------------------------------------------- /get_models.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import torch.optim as optim 4 | from os.path import join, isfile 5 | import torchvision.models as models 6 | import torch 7 | 8 | import seqMatchNet 9 | 10 | class WXpB(nn.Module): 11 | def __init__(self, inDims, outDims): 12 | super().__init__() 13 | self.inDims = inDims 14 | self.outDims = outDims 15 | self.conv = nn.Conv1d(inDims, outDims, kernel_size=1) 16 | 17 | def forward(self, x): 18 | x = x.squeeze(-1) # convert [B,C,1,1] to [B,C,1] 19 | feat_transformed = self.conv(x) 20 | return feat_transformed.permute(0,2,1) # return [B,1,C] 21 | 22 | class Flatten(nn.Module): 23 | def forward(self, input): 24 | return input.view(input.size(0), -1) 25 | 26 | class L2Norm(nn.Module): 27 | def __init__(self, dim=1): 28 | super().__init__() 29 | self.dim = dim 30 | 31 | def forward(self, input): 32 | return F.normalize(input, p=2, dim=self.dim) 33 | 34 | def get_pooling(opt,encoder_dim): 35 | 36 | if opt.pooling: 37 | global_pool = nn.AdaptiveMaxPool2d((1,1)) # no effect 38 | poolLayers = nn.Sequential(*[global_pool, WXpB(encoder_dim, opt.outDims), L2Norm(dim=-1)]) 39 | else: 40 | global_pool = nn.AdaptiveMaxPool2d((1,1)) # no effect 41 | poolLayers = nn.Sequential(*[global_pool, Flatten(), L2Norm(dim=-1)]) 42 | return poolLayers 43 | 44 | def get_matcher(opt,device): 45 | 46 | if opt.matcher == 'seqMatchNet': 47 | sm = seqMatchNet.seqMatchNet() 48 | matcherLayers = nn.Sequential(*[sm]) 49 | else: 50 | matcherLayers = None 51 | 52 | return matcherLayers 53 | 54 | def printModelParams(model): 55 | 56 | for name, param in model.named_parameters(): 57 | if param.requires_grad: 58 | print(name, param.shape) 59 | return 60 | 61 | def get_model(opt,input_dim,device): 62 | model = nn.Module() 63 | encoder_dim = input_dim 64 | 65 | poolLayers = get_pooling(opt,encoder_dim) 66 | model.add_module('pool', poolLayers) 67 | 68 | matcherLayers = get_matcher(opt,device) 69 | if matcherLayers is not None: 70 | model.add_module('matcher',matcherLayers) 71 | 72 | isParallel = False 73 | if opt.nGPU > 1 and torch.cuda.device_count() > 1: 74 | model.pool = nn.DataParallel(model.pool) 75 | isParallel = True 76 | 77 | if not opt.resume: 78 | model = model.to(device) 79 | 80 | scheduler, optimizer, criterion = None, None, None 81 | if opt.mode.lower() == 'train': 82 | if opt.optim.upper() == 'ADAM': 83 | optimizer = optim.Adam(filter(lambda p: p.requires_grad, 84 | model.parameters()), lr=opt.lr)#, betas=(0,0.9)) 85 | elif opt.optim.upper() == 'SGD': 86 | optimizer = optim.SGD(filter(lambda p: p.requires_grad, 87 | model.parameters()), lr=opt.lr, 88 | momentum=opt.momentum, 89 | weight_decay=opt.weightDecay) 90 | 91 | scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=opt.lrStep, gamma=opt.lrGamma) 92 | else: 93 | raise ValueError('Unknown optimizer: ' + opt.optim) 94 | 95 | # used only when matcher is none 96 | criterion = nn.TripletMarginLoss(margin=opt.margin**0.5, p=2, reduction='sum').to(device) 97 | 98 | if opt.resume: 99 | if opt.ckpt.lower() == 'latest': 100 | resume_ckpt = join(opt.resume, 'checkpoints', 'checkpoint.pth.tar') 101 | elif opt.ckpt.lower() == 'best': 102 | resume_ckpt = join(opt.resume, 'checkpoints', 'model_best.pth.tar') 103 | 104 | if isfile(resume_ckpt): 105 | print("=> loading checkpoint '{}'".format(resume_ckpt)) 106 | checkpoint = 
torch.load(resume_ckpt, map_location=lambda storage, loc: storage)
107 |             opt.update({"start_epoch" : checkpoint['epoch']}, allow_val_change=True)
108 |             best_metric = checkpoint['best_score']
109 |             model.load_state_dict(checkpoint['state_dict'])
110 |             model = model.to(device)
111 |             if opt.mode == 'train':
112 |                 optimizer.load_state_dict(checkpoint['optimizer'])
113 |             print("=> loaded checkpoint '{}' (epoch {})"
114 |                   .format(resume_ckpt, checkpoint['epoch']))
115 |         else:
116 |             print("=> no checkpoint found at '{}'".format(resume_ckpt))
117 |
118 |     return model, optimizer, scheduler, criterion, isParallel, encoder_dim
119 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import argparse
3 | import random, json
4 | from os.path import join, exists
5 | from os import makedirs
6 |
7 | import wandb
8 | import torch
9 | from datetime import datetime
10 | from tqdm import tqdm
11 |
12 | from tensorboardX import SummaryWriter
13 | import numpy as np
14 |
15 | from get_datasets import get_dataset, get_splits, prefix_data
16 | from get_models import get_model
17 | from utils import save_checkpoint
18 | from train import train
19 | from test import test
20 |
21 | parser = argparse.ArgumentParser(description='SeqMatchNet')
22 | parser.add_argument('--mode', type=str, default='train', help='Mode', choices=['train', 'test'])
23 |
24 | # train settings
25 | parser.add_argument('--batchSize', type=int, default=16, help='Number of triplets (query, pos, negs). Each triplet consists of 12 images.')
26 | parser.add_argument('--cacheBatchSize', type=int, default=64, help='Batch size for caching and testing')
27 | parser.add_argument('--cacheRefreshRate', type=int, default=1000, help='How often to refresh cache, in number of queries. 0 for off')
28 | parser.add_argument('--nEpochs', type=int, default=50, help='number of epochs to train for')
29 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', help='manual epoch number (useful on restarts)')
30 | parser.add_argument('--nGPU', type=int, default=1, help='number of GPU to use.')
31 | parser.add_argument('--optim', type=str, default='SGD', help='optimizer to use', choices=['SGD', 'ADAM'])
32 | parser.add_argument('--lr', type=float, default=0.001, help='Learning Rate.')
33 | parser.add_argument('--lrStep', type=float, default=5, help='Decay LR every N steps.')
34 | parser.add_argument('--lrGamma', type=float, default=0.5, help='Multiply LR by Gamma for decaying.')
35 | parser.add_argument('--weightDecay', type=float, default=0.001, help='Weight decay for SGD.')
36 | parser.add_argument('--momentum', type=float, default=0.9, help='Momentum for SGD.')
37 | parser.add_argument('--nocuda', action='store_true', help="Don't use cuda")
38 | parser.add_argument('--threads', type=int, default=8, help='Number of threads for each data loader to use')
39 | parser.add_argument('--seed', type=int, default=123, help='Random seed to use.')
40 | parser.add_argument('--patience', type=int, default=0, help='Patience for early stopping. 
0 is off.')
41 | parser.add_argument('--evalEvery', type=int, default=1, help='Do a validation set run, and save, every N epochs.')
42 | parser.add_argument('--expName', default='0', help='Unique string for an experiment')
43 |
44 | # path settings
45 | parser.add_argument('--runsPath', type=str, default=join(prefix_data,'runs'), help='Path to save runs to.')
46 | parser.add_argument('--savePath', type=str, default='checkpoints', help='Path to save checkpoints to in logdir. Default=checkpoints/')
47 | parser.add_argument('--cachePath', type=str, default=join(prefix_data,'cache'), help='Path to save cache to.')
48 | parser.add_argument('--resultsPath', type=str, default=None, help='Path to save evaluation results to when mode=test')
49 |
50 | # test settings
51 | parser.add_argument('--resume', type=str, default='', help='Path to load checkpoint from, for resuming training or testing.')
52 | parser.add_argument('--ckpt', type=str, default='latest', help='Resume from latest or best checkpoint.', choices=['latest', 'best'])
53 | parser.add_argument('--split', type=str, default='val', help='Data split to use for testing. Default is val', choices=['test', 'train', 'val'])
54 | parser.add_argument('--numSamples2Project', type=int, default=-1, help='t-SNE uses this many samples ([:n]) to project data to 2D; set to -1 to disable')
55 | parser.add_argument('--extractOnly', action='store_true', help='extract descriptors')
56 | parser.add_argument('--predictionsFile', type=str, default=None, help='path to prior predictions data')
57 | parser.add_argument('--seqL_filterData', type=int, help='during testing, remove db and query indices whose sequences would violate sequence boundaries for this given sequence length')
58 |
59 | # dataset, model etc.
60 | parser.add_argument('--dataset', type=str, default='nordland-sw', help='Dataset to use', choices=['nordland-sw', 'nordland-sf', 'oxford-v1.0', 'oxford-pnv'])
61 | parser.add_argument('--pooling', action='store_true', help='learn a fc (1x1 conv) transform on top of the input descriptors')
62 | parser.add_argument('--seqL', type=int, help='Sequence Length')
63 | parser.add_argument('--outDims', type=int, default=4096, help='Output descriptor dimensions')
64 | parser.add_argument('--margin', type=float, default=0.1, help='Margin for triplet loss. 
Default=0.1') 65 | parser.add_argument('--descType', type=str, default="netvlad-pytorch", help='underlying descriptor type') 66 | 67 | # matcher settings 68 | parser.add_argument('--matcher', type=str, default='seqMatchNet', help='Matcher Type', choices=['seqMatchNet', 'None']) 69 | parser.add_argument('--loss_trip_method', type=str, default='meanOfPairs', help='', choices=['centerOnly', 'meanOfPairs']) 70 | parser.add_argument('--neg_trip_method', type=str, default='meanOfPairs', help='', choices=['centerOnly', 'meanOfPairs']) 71 | 72 | 73 | if __name__ == "__main__": 74 | # torch.multiprocessing.set_start_method('spawn') 75 | 76 | opt = parser.parse_args() 77 | if opt.matcher == 'None': opt.matcher = None 78 | 79 | restore_var = ['lr', 'lrStep', 'lrGamma', 'weightDecay', 'momentum', 80 | 'runsPath', 'savePath', 'optim', 'margin', 'seed', 'patience', 'outDims'] 81 | if not opt.pooling and opt.resume: 82 | raise Exception("Please run without the '--resume' argument when '--pooling' is not used.") 83 | 84 | if opt.resume: 85 | flag_file = join(opt.resume, 'checkpoints', 'flags.json') 86 | if exists(flag_file): 87 | with open(flag_file, 'r') as f: 88 | stored_flags = {'--'+k : str(v) for k,v in json.load(f).items() if k in restore_var} 89 | to_del = [] 90 | for flag, val in stored_flags.items(): 91 | for act in parser._actions: 92 | if act.dest == flag[2:]: 93 | # store_true / store_false args don't accept arguments, filter these 94 | if type(act.const) == type(True): 95 | if val == str(act.default): 96 | to_del.append(flag) 97 | else: 98 | stored_flags[flag] = '' 99 | for flag in to_del: del stored_flags[flag] 100 | 101 | train_flags = [x for x in list(sum(stored_flags.items(), tuple())) if len(x) > 0] 102 | print('Restored flags:', train_flags) 103 | opt = parser.parse_args(train_flags, namespace=opt) 104 | 105 | wandb_dataStr = opt.dataset.lower()[:4] 106 | wandbResume = False if opt.resume == '' or opt.mode == 'test' else True 107 | wandb.init(project='SeqMatchNet_{}'.format(wandb_dataStr),config=opt,resume=wandbResume,anonymous="allow") 108 | # update runName 109 | runName = wandb.run.name 110 | if opt.expName != '' and runName is not None: #runName is None when running wandb offline 111 | wandb.run.name = opt.expName + "-" + runName.split("-")[-1] 112 | wandb.run.save() 113 | else: 114 | opt.expName = runName 115 | 116 | opt = wandb.config 117 | 118 | print(opt) 119 | 120 | cuda = not opt.nocuda 121 | if cuda and not torch.cuda.is_available(): 122 | raise Exception("No GPU found, please run with --nocuda") 123 | 124 | device = torch.device("cuda" if cuda else "cpu") 125 | 126 | random.seed(opt.seed) 127 | np.random.seed(opt.seed) 128 | torch.manual_seed(opt.seed) 129 | if cuda: 130 | torch.cuda.manual_seed(opt.seed) 131 | 132 | print('===> Loading dataset(s)') 133 | dataset, encoder_dim = get_dataset(opt) 134 | whole_train_set, whole_training_data_loader, train_set, whole_test_set = get_splits(opt, dataset) 135 | 136 | print('===> Building model') 137 | model, optimizer, scheduler, criterion, isParallel, encoder_dim = get_model(opt, encoder_dim, device) 138 | 139 | unique_string = datetime.now().strftime('%b%d_%H-%M-%S')+'_l'+str(opt.seqL)+'_'+ opt.expName 140 | writer = None 141 | 142 | if opt.mode.lower() == 'test': 143 | print('===> Running evaluation step') 144 | epoch = 1 145 | recallsOrDesc, dbEmb, qEmb, rAtL, preds = test(opt, model, encoder_dim, device, whole_test_set, writer, epoch, extract_noEval=opt.extractOnly) 146 | if opt.resultsPath is not None: 147 | if not 
exists(opt.resultsPath): 148 | makedirs(opt.resultsPath) 149 | if opt.extractOnly: 150 | gt = whole_test_set.get_positives() 151 | numDb = whole_test_set.dbStruct.numDb 152 | np.savez(join(opt.resultsPath,unique_string),dbDesc=recallsOrDesc[:numDb],qDesc=recallsOrDesc[numDb:],gt=gt) 153 | else: 154 | np.savez(join(opt.resultsPath,unique_string),args=opt.__dict__,recalls=recallsOrDesc, dbEmb=dbEmb,qEmb=qEmb,rAtL=rAtL,preds=preds) 155 | 156 | elif opt.mode.lower() == 'train': 157 | print('===> Training model') 158 | logdir = join(opt.runsPath,unique_string) 159 | writer = SummaryWriter(log_dir=logdir) 160 | train_set.cache = join(opt.cachePath, train_set.whichSet + '_feat_cache_{}.hdf5'.format(unique_string)) 161 | 162 | savePath = join(logdir, opt.savePath) 163 | makedirs(savePath) 164 | if not exists(opt.cachePath): makedirs(opt.cachePath) 165 | 166 | with open(join(savePath, 'flags.json'), 'w') as f: 167 | f.write(json.dumps({k:v for k,v in opt.items()})) 168 | print('===> Saving state to:', logdir) 169 | 170 | not_improved = 0 171 | best_score = 0 172 | for epoch in range(opt.start_epoch+1, opt.nEpochs + 1): 173 | train(opt, model, encoder_dim, device, dataset, criterion, optimizer, train_set, whole_train_set, whole_training_data_loader, epoch, writer) 174 | if opt.optim.upper() == 'SGD': 175 | scheduler.step() 176 | if (epoch % opt.evalEvery) == 0: 177 | recalls = test(opt, model, encoder_dim, device, whole_test_set, writer, epoch)[0] 178 | is_best = recalls[5] > best_score 179 | if is_best: 180 | not_improved = 0 181 | best_score = recalls[5] 182 | else: 183 | not_improved += 1 184 | 185 | save_checkpoint(savePath, { 186 | 'epoch': epoch, 187 | 'state_dict': model.state_dict(), 188 | 'recalls': recalls, 189 | 'best_score': best_score, 190 | 'optimizer' : optimizer.state_dict(), 191 | 'parallel' : isParallel, 192 | }, is_best) 193 | 194 | if opt.patience > 0 and not_improved > (opt.patience / opt.evalEvery): 195 | print('Performance did not improve for', opt.patience, 'epochs. 
Stopping.')
196 |                     break
197 |
198 |     print("=> Best Recall@5: {:.4f}".format(best_score), flush=True)
199 |     writer.close()
200 |
--------------------------------------------------------------------------------
/seqMatchNet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | from tqdm import tqdm
6 |
7 | import time
8 |
9 | class seqMatchNet(nn.Module):
10 |
11 |     def __init__(self):
12 |         super(seqMatchNet, self).__init__()
13 |
14 |     def cdist_quick(self,r,c):
15 |         return torch.sqrt(2 - 2*torch.matmul(r,c.transpose(1,2)))  # equals L2 distance for L2-normalized descriptors
16 |
17 |     def aggregateSeqScore(self,data):
18 |         r, c, method = data
19 |         dMat = self.cdist_quick(r,c)
20 |         seqL = dMat.shape[1]
21 |         dis = torch.diagonal(dMat,0,1,2)  # per-pair distances along the aligned sequence
22 |         if method == 'centerOnly':
23 |             dis = dis[:,seqL//2]
24 |         else: # default to 'meanOfPairs'
25 |             dis = dis.mean(-1)
26 |         return dis
27 |
28 |     def forward(self,data):
29 |         return self.aggregateSeqScore(data)
30 |
31 | def computeDisMat_torch(r,c):
32 |     # assumes descriptors to be l2-normalized
33 |     return torch.stack([torch.sqrt(2 - 2*torch.matmul(r,c[i].unsqueeze(1))).squeeze() for i in range(c.shape[0])]).transpose(0,1)
34 |
35 | def modInd(idx,l,n):
36 |     return max(l,min(idx,n-l-1))
37 |
38 | def computeRange(l,n):
39 |     li, le = l//2, l-l//2
40 |     return torch.stack([torch.arange(modInd(r,li,n)-li,modInd(r,li,n)+le,dtype=int) for r in range(n)])
41 |
42 | def aggregateMatchScores_pt_fromMat_oneShot(dMat,l,device):
43 |     convWeight = torch.eye(l,device=device).unsqueeze(0).unsqueeze(0)
44 |     dMat_seq = -1*torch.ones(dMat.shape,device=device)
45 |     li, le = l//2, l-l//2
46 |
47 |     dMat_seq[li:-le+1,li:-le+1] = torch.nn.functional.conv2d(dMat.unsqueeze(0).unsqueeze(0),convWeight).squeeze()
48 |
49 |     # fill left and right columns
50 |     dMat_seq[:,:li] = dMat_seq[:,li,None]
51 |     dMat_seq[:,-le+1:] = dMat_seq[:,-le,None]
52 |
53 |     # fill top and bottom rows
54 |     dMat_seq[:li,:] = dMat_seq[None,li,:]
55 |     dMat_seq[-le+1:,:] = dMat_seq[None,-le,:]
56 |
57 |     return dMat_seq
58 |
59 | def aggregateMatchScores_pt_fromMat(dMat,l,device,refCandidates=None):
60 |     li, le = l//2, l-l//2
61 |     n = dMat.shape[0]
62 |     convWeight = torch.eye(l,device=device).unsqueeze(0).unsqueeze(0)
63 |
64 |     # dMat = dMat.to('cpu')
65 |     if refCandidates is None:
66 |         shape = dMat.shape
67 |     else:
68 |         shape = refCandidates.transpose().shape
69 |         preCompInds = computeRange(l,n)
70 |
71 |     dMat_seq = -1*torch.ones(shape,device=device)
72 |
73 |     durs = []
74 |     for j in tqdm(range(li,dMat.shape[1]-li), total=dMat.shape[1]-l, leave=True):
75 |         t1 = time.time()
76 |         if refCandidates is not None:
77 |             rCands = preCompInds[refCandidates[j]].flatten()
78 |             dMat_cols = dMat[rCands,j-li:j+le].to(device)
79 |             dMat_seq[:,j] = torch.nn.functional.conv2d(dMat_cols.unsqueeze(0).unsqueeze(0),convWeight,stride=l).squeeze()
80 |         else:
81 |             dMat_cols = dMat[:,j-li:j+le].to(device)
82 |             dMat_seq[li:-le+1,j] = torch.nn.functional.conv2d(dMat_cols.unsqueeze(0).unsqueeze(0),convWeight).squeeze()
83 |         durs.append(time.time()-t1)
84 |
85 |     if refCandidates is None:
86 |         # fill left and right columns
87 |         dMat_seq[:,:li] = dMat_seq[:,li,None]
88 |         dMat_seq[:,-le+1:] = dMat_seq[:,-le,None]
89 |
90 |         # fill top and bottom rows
91 |         dMat_seq[:li,:] = dMat_seq[None,li,:]
92 |         dMat_seq[-le+1:,:] = dMat_seq[None,-le,:]
93 |
94 |     # assert(np.sum(dMat_seq==-1)==0)
95 |     print("Average Time Per Query", np.mean(durs))
96 |     return 
dMat_seq 97 | 98 | def aggregateMatchScores_pt_fromDesc(dbDesc,qDesc,l,device,refCandidates=None): 99 | numDb, numQ = dbDesc.shape[0], qDesc.shape[0] 100 | convWeight = torch.eye(l,device=device).unsqueeze(0).unsqueeze(0) 101 | 102 | if refCandidates is None: 103 | shape = [numDb,numQ] 104 | else: 105 | shape = refCandidates.transpose().shape 106 | 107 | dMat_seq = -1*torch.ones(shape,device=device) 108 | 109 | durs = [] 110 | for j in tqdm(range(numQ), total=numQ, leave=True): 111 | t1 = time.time() 112 | if refCandidates is not None: 113 | rCands = refCandidates[j] 114 | else: 115 | rCands = torch.arange(numDb) 116 | 117 | dMat = torch.cdist(dbDesc[rCands],qDesc[j].unsqueeze(0)) 118 | dMat_seq[:,j] = torch.nn.functional.conv2d(dMat.unsqueeze(1),convWeight).squeeze() 119 | durs.append(time.time()-t1) 120 | 121 | # assert(torch.sum(dMat_seq==-1)==0) 122 | print("Average Time Per Query", np.mean(durs), np.std(durs)) 123 | return dMat_seq 124 | 125 | def aggregateMatchScores(dMat,l,device='cuda',refCandidates=None,dbDesc=None,qDesc=None,dMatProcOneShot=False): 126 | dMat_seq, matchInds, matchDists = None, None, None 127 | if dMat is None: 128 | dMat_seq = aggregateMatchScores_pt_fromDesc(dbDesc,qDesc,l,device,refCandidates).detach().cpu().numpy() 129 | else: 130 | if dMatProcOneShot: 131 | dMat_seq = aggregateMatchScores_pt_fromMat_oneShot(dMat,l,device).detach().cpu().numpy() 132 | else: 133 | dMat_seq = aggregateMatchScores_pt_fromMat(dMat,l,device,refCandidates).detach().cpu().numpy() 134 | return dMat_seq, matchInds, matchDists 135 | -------------------------------------------------------------------------------- /structFiles/nordland_test_d-1_d2-1.db: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oravus/SeqMatchNet/431f2a963587e1bbfd3f840d7d6c299207f8d88c/structFiles/nordland_test_d-1_d2-1.db -------------------------------------------------------------------------------- /structFiles/nordland_train_d-40_d2-10.db: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oravus/SeqMatchNet/431f2a963587e1bbfd3f840d7d6c299207f8d88c/structFiles/nordland_train_d-40_d2-10.db -------------------------------------------------------------------------------- /structFiles/nordland_val_d-1_d2-1.db: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oravus/SeqMatchNet/431f2a963587e1bbfd3f840d7d6c299207f8d88c/structFiles/nordland_val_d-1_d2-1.db -------------------------------------------------------------------------------- /structFiles/oxford-v1.0_2015-03-17-11-08-44_2014-12-16-18-44-24_test_d-10_d2-5.db: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oravus/SeqMatchNet/431f2a963587e1bbfd3f840d7d6c299207f8d88c/structFiles/oxford-v1.0_2015-03-17-11-08-44_2014-12-16-18-44-24_test_d-10_d2-5.db -------------------------------------------------------------------------------- /structFiles/oxford-v1.0_2015-03-17-11-08-44_2014-12-16-18-44-24_train_d-20_d2-5.db: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oravus/SeqMatchNet/431f2a963587e1bbfd3f840d7d6c299207f8d88c/structFiles/oxford-v1.0_2015-03-17-11-08-44_2014-12-16-18-44-24_train_d-20_d2-5.db -------------------------------------------------------------------------------- /structFiles/oxford-v1.0_splitInds.npz: 
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/oravus/SeqMatchNet/431f2a963587e1bbfd3f840d7d6c299207f8d88c/structFiles/oxford-v1.0_splitInds.npz
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import DataLoader
3 | from tqdm import tqdm
4 | from sklearn.manifold import TSNE
5 | from scipy.spatial.distance import cdist
6 | import numpy as np
7 | import time
8 | import wandb
9 |
10 | from utils import seq2Batch, getRecallAtN, computeMatches, evaluate, N_VALUES
11 |
12 | def test(opt, model, encoder_dim, device, eval_set, writer, epoch=0, extract_noEval=False):
13 |     # TODO what if features don't fit in memory?
14 |     test_data_loader = DataLoader(dataset=eval_set,
15 |                 num_workers=opt.threads, batch_size=opt.cacheBatchSize, shuffle=False,
16 |                 pin_memory=False)
17 |
18 |     model.eval()
19 |     with torch.no_grad():
20 |         print('====> Extracting Features')
21 |         pool_size = encoder_dim
22 |         validInds = eval_set.validInds
23 |         dbFeat_single = torch.zeros((len(eval_set), pool_size),device=device)
24 |         durs_batch = []
25 |         for iteration, (input, indices) in tqdm(enumerate(test_data_loader, 1),total=len(test_data_loader)-1, leave=False):
26 |             t1 = time.time()
27 |             image_encoding = seq2Batch(input).float().to(device)
28 |             global_single_descs = model.pool(image_encoding).squeeze()
29 |             dbFeat_single[indices] = global_single_descs
30 |
31 |             if iteration % 50 == 0 or len(test_data_loader) <= 10:
32 |                 print("==> Batch ({}/{})".format(iteration,
33 |                     len(test_data_loader)), flush=True)
34 |             durs_batch.append(time.time()-t1)
35 |             del input, image_encoding, global_single_descs
36 |
37 |         del test_data_loader
38 |         print("Average batch time:", np.mean(durs_batch), np.std(durs_batch))
39 |
40 |         outSeqL = opt.seqL
41 |         # use the original single descriptors for fast seqmatch over dMat (MSLS-like non-continuous dataset not supported)
42 |         if (not opt.pooling and opt.matcher is None) and ('nordland' in opt.dataset.lower() or 'tmr' in opt.dataset.lower() or 'ox' in opt.dataset.lower()):
43 |             dbFeat = dbFeat_single
44 |             numDb = eval_set.numDb_full
45 |         # fill sequences centered at single images
46 |         else:
47 |             dbFeat = torch.zeros(len(validInds), outSeqL, pool_size, device=device)
48 |             numDb = eval_set.dbStruct.numDb
49 |             for ind in range(len(validInds)):
50 |                 dbFeat[ind] = dbFeat_single[eval_set.getSeqIndsFromValidInds(validInds[ind])]
51 |             del dbFeat_single
52 |
53 |         # extracted for both db and query, now split into their own sets
54 |         qFeat = dbFeat[numDb:]
55 |         dbFeat = dbFeat[:numDb]
56 |         print(dbFeat.shape, qFeat.shape)
57 |
58 |         qFeat_np = qFeat.detach().cpu().numpy().astype('float32')
59 |         dbFeat_np = dbFeat.detach().cpu().numpy().astype('float32')
60 |
61 |     db_emb, q_emb = None, None
62 |     if opt.numSamples2Project != -1 and writer is not None:
63 |         db_emb = TSNE(n_components=2).fit_transform(dbFeat_np[:opt.numSamples2Project])
64 |         q_emb = TSNE(n_components=2).fit_transform(qFeat_np[:opt.numSamples2Project])
65 |
66 |     if extract_noEval:
67 |         return np.vstack([dbFeat_np,qFeat_np]), db_emb, q_emb, None, None
68 |
69 |     predictions, bestDists = computeMatches(opt,N_VALUES,device,dbFeat,qFeat,dbFeat_np,qFeat_np)
70 |
71 |     # for each query get those within threshold distance
72 |     gt,gtDists = eval_set.get_positives(retDists=True)
73 |     gtDistsMat = 
cdist(eval_set.dbStruct.utmDb,eval_set.dbStruct.utmQ) 74 | 75 | recall_at_n = getRecallAtN(N_VALUES, predictions, gt) 76 | rAtL = evaluate(N_VALUES,predictions,gtDistsMat) 77 | 78 | recalls = {} #make dict for output 79 | for i,n in enumerate(N_VALUES): 80 | recalls[n] = recall_at_n[i] 81 | print("====> Recall@{}: {:.4f}".format(n, recall_at_n[i])) 82 | if writer is not None: 83 | writer.add_scalar('Val/Recall@' + str(n), recall_at_n[i], epoch) 84 | wandb.log({'Val/Recall@' + str(n): recall_at_n[i]},commit=False) 85 | 86 | return recalls, db_emb, q_emb, rAtL, predictions 87 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | from torch.utils.data.dataset import Subset 4 | from tqdm import tqdm 5 | import numpy as np 6 | from os import remove 7 | import h5py 8 | from math import ceil 9 | import wandb 10 | from termcolor import colored 11 | 12 | from utils import batch2Seq, seq2Batch, getRecallAtN, computeMatches, N_VALUES 13 | import seqMatchNet 14 | 15 | def train(opt, model, encoder_dim, device, dataset, criterion, optimizer, train_set, whole_train_set, whole_training_data_loader, epoch, writer): 16 | epoch_loss = 0 17 | startIter = 1 # keep track of batch iter across subsets for logging 18 | 19 | if opt.cacheRefreshRate > 0: 20 | subsetN = ceil(len(train_set) / opt.cacheRefreshRate) 21 | #TODO randomise the arange before splitting? 22 | subsetIdx = np.array_split(np.arange(len(train_set)), subsetN) 23 | else: 24 | subsetN = 1 25 | subsetIdx = [np.arange(len(train_set))] 26 | 27 | nBatches = (len(train_set) + opt.batchSize - 1) // opt.batchSize 28 | 29 | for subIter in range(subsetN): 30 | print('====> Building Cache') 31 | model.eval() 32 | with h5py.File(train_set.cache, mode='w') as h5: 33 | pool_size = encoder_dim 34 | validInds = whole_train_set.validInds 35 | h5feat = h5.create_dataset("features", [len(validInds), pool_size], dtype=np.float32) 36 | h5DisMat = h5.create_dataset("disMat",[whole_train_set.numDb_valid, whole_train_set.numQ_valid],dtype=np.float32) 37 | with torch.no_grad(): 38 | dbFeat_single = torch.zeros(len(whole_train_set), pool_size, device=device) 39 | # expected input B,T,C,H,W (T is 1) 40 | for iteration, (input, indices) in tqdm(enumerate(whole_training_data_loader, 1),total=len(whole_training_data_loader)-1, leave=True): 41 | # convert to B*T,C,H,W 42 | image_encoding = seq2Batch(input).float().to(device) 43 | 44 | # input B*T,C,1,1; outputs B,T,C (T=1); squeeze to B,C 45 | global_single_descs = model.pool(image_encoding).squeeze() 46 | dbFeat_single[indices] = global_single_descs 47 | del input, image_encoding, global_single_descs 48 | 49 | outSeqL = opt.seqL 50 | # fill sequences centered at single images 51 | dbFeat = torch.zeros(len(validInds), outSeqL, pool_size, device=device) 52 | for ind in range(len(validInds)): 53 | dbFeat[ind] = dbFeat_single[whole_train_set.getSeqIndsFromValidInds(validInds[ind])] 54 | if opt.matcher is None: # assumes seqL is 1 in this case 55 | h5feat[ind] = dbFeat[ind].squeeze() 56 | del dbFeat_single 57 | 58 | if opt.matcher is not None: 59 | offset = whole_train_set.numDb_valid 60 | # compute distance matrix 61 | print('====> Caching distance matrix') 62 | if opt.matcher == 'seqMatchNet': 63 | if 'nordland' in opt.dataset.lower() or 'tmr' in opt.dataset.lower(): 64 | dMat_cont = 
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import DataLoader
3 | from torch.utils.data.dataset import Subset
4 | from tqdm import tqdm
5 | import numpy as np
6 | from os import remove
7 | import h5py
8 | from math import ceil
9 | import wandb
10 | from termcolor import colored
11 | 
12 | from utils import batch2Seq, seq2Batch, getRecallAtN, computeMatches, N_VALUES
13 | import seqMatchNet
14 | 
15 | def train(opt, model, encoder_dim, device, dataset, criterion, optimizer, train_set, whole_train_set, whole_training_data_loader, epoch, writer):
16 |     epoch_loss = 0
17 |     startIter = 1  # keep track of batch iter across subsets for logging
18 | 
19 |     if opt.cacheRefreshRate > 0:
20 |         subsetN = ceil(len(train_set) / opt.cacheRefreshRate)
21 |         # TODO: randomise the arange before splitting?
22 |         subsetIdx = np.array_split(np.arange(len(train_set)), subsetN)
23 |     else:
24 |         subsetN = 1
25 |         subsetIdx = [np.arange(len(train_set))]
26 | 
27 |     nBatches = (len(train_set) + opt.batchSize - 1) // opt.batchSize
28 | 
29 |     for subIter in range(subsetN):
30 |         print('====> Building Cache')
31 |         model.eval()
32 |         with h5py.File(train_set.cache, mode='w') as h5:
33 |             pool_size = encoder_dim
34 |             validInds = whole_train_set.validInds
35 |             h5feat = h5.create_dataset("features", [len(validInds), pool_size], dtype=np.float32)
36 |             h5DisMat = h5.create_dataset("disMat", [whole_train_set.numDb_valid, whole_train_set.numQ_valid], dtype=np.float32)
37 |             with torch.no_grad():
38 |                 dbFeat_single = torch.zeros(len(whole_train_set), pool_size, device=device)
39 |                 # expected input B,T,C,H,W (T is 1)
40 |                 for iteration, (input, indices) in tqdm(enumerate(whole_training_data_loader, 1), total=len(whole_training_data_loader), leave=True):
41 |                     # convert to B*T,C,H,W
42 |                     image_encoding = seq2Batch(input).float().to(device)
43 | 
44 |                     # input B*T,C,1,1; outputs B,T,C (T=1); squeeze to B,C
45 |                     global_single_descs = model.pool(image_encoding).squeeze()
46 |                     dbFeat_single[indices] = global_single_descs
47 |                     del input, image_encoding, global_single_descs
48 | 
49 |                 outSeqL = opt.seqL
50 |                 # fill sequences centered at single images
51 |                 dbFeat = torch.zeros(len(validInds), outSeqL, pool_size, device=device)
52 |                 for ind in range(len(validInds)):
53 |                     dbFeat[ind] = dbFeat_single[whole_train_set.getSeqIndsFromValidInds(validInds[ind])]
54 |                     if opt.matcher is None:  # assumes seqL is 1 in this case
55 |                         h5feat[ind] = dbFeat[ind].squeeze()
56 |                 del dbFeat_single
57 | 
58 |             if opt.matcher is not None:
59 |                 offset = whole_train_set.numDb_valid
60 |                 # compute distance matrix
61 |                 print('====> Caching distance matrix')
62 |                 if opt.matcher == 'seqMatchNet':
63 |                     if 'nordland' in opt.dataset.lower() or 'tmr' in opt.dataset.lower():
64 |                         dMat_cont = seqMatchNet.seqMatchNet.computeDisMat_torch(dbFeat[:offset, outSeqL//2], dbFeat[offset:, outSeqL//2])
65 |                         if opt.neg_trip_method == 'centerOnly':
66 |                             h5DisMat[...] = dMat_cont.detach().cpu().numpy()
67 |                         elif opt.neg_trip_method == 'meanOfPairs':
68 |                             h5DisMat[...] = seqMatchNet.aggregateMatchScores(dMat_cont, outSeqL, device, dMatProcOneShot=False)[0] * (1.0/outSeqL)
69 |                     else:
70 |                         if opt.neg_trip_method == 'centerOnly':
71 |                             h5DisMat[...] = seqMatchNet.aggregateMatchScores(None, 1, device, dbDesc=dbFeat[:offset, outSeqL//2:outSeqL//2+1], qDesc=dbFeat[offset:, outSeqL//2:outSeqL//2+1])[0]
72 |                         elif opt.neg_trip_method == 'meanOfPairs':
73 |                             h5DisMat[...] = seqMatchNet.aggregateMatchScores(None, outSeqL, device, dbDesc=dbFeat[:offset], qDesc=dbFeat[offset:])[0] * (1.0/outSeqL)
74 |                 else:
75 |                     raise NotImplementedError("TODO: only the 'seqMatchNet' matcher is supported")
76 | 
77 |             del dbFeat
78 |             dMat = h5DisMat[()]
79 |             dbFeat_np, qFeat_np = h5feat[:offset].copy(), h5feat[offset:].copy()
80 | 
81 |         sub_train_set = Subset(dataset=train_set, indices=subsetIdx[subIter])
82 | 
83 |         training_data_loader = DataLoader(dataset=sub_train_set, num_workers=opt.threads,
84 |                                           batch_size=opt.batchSize, shuffle=True,
85 |                                           collate_fn=dataset.collate_fn, pin_memory=False)
86 | 
87 |         if not opt.nocuda:
88 |             print('Allocated:', torch.cuda.memory_allocated())
89 |             print('Cached:', torch.cuda.memory_reserved())
90 | 
91 |         print('====> Training Queries')
92 |         model.train()
93 |         for iteration, (query, positives, negatives,
94 |                         negCounts, indices) in tqdm(enumerate(training_data_loader, startIter), total=len(training_data_loader), leave=True):
95 |             loss = 0
96 |             if query is None:
97 |                 continue  # in case we get an empty batch
98 | 
99 |             B, C = len(query), query[0].shape[1]
100 |             nNeg = torch.sum(negCounts)
101 |             image_encoding = seq2Batch(torch.cat([query, positives, negatives])).float()
102 | 
103 |             image_encoding = image_encoding.to(device)
104 |             global_single_descs = model.pool(image_encoding)
105 |             global_single_descs = batch2Seq(global_single_descs.squeeze(1), opt.seqL)
106 | 
107 |             del image_encoding
108 |             g_desc_Q, g_desc_P, g_desc_N = torch.split(global_single_descs, [B, B, nNeg])
109 |             del global_single_descs
110 |             optimizer.zero_grad()
111 | 
112 |             # calculate loss for each Query, Positive, Negative triplet;
113 |             # due to potential differences in the number of negatives we have to
114 |             # do it per query, per negative
115 |             trips_a, trips_p, trips_n = [], [], []
116 |             for i, negCount in enumerate(negCounts):
117 |                 for n in range(negCount):
118 |                     negIx = (torch.sum(negCounts[:i]) + n).item()
119 |                     if opt.matcher is None:
120 |                         loss += criterion(g_desc_Q[i:i+1].squeeze(1), g_desc_P[i:i+1].squeeze(1), g_desc_N[negIx:negIx+1].squeeze(1))
121 |                     else:
122 |                         trips_a.append(g_desc_Q[i:i+1])
123 |                         trips_p.append(g_desc_P[i:i+1])
124 |                         trips_n.append(g_desc_N[negIx:negIx+1])
125 | 
126 |             del g_desc_Q, g_desc_P, g_desc_N
127 |             if opt.matcher is not None:
128 |                 dis_ap = model.matcher([torch.cat(trips_a), torch.cat(trips_p), opt.loss_trip_method])
129 |                 dis_an = model.matcher([torch.cat(trips_a), torch.cat(trips_n), opt.loss_trip_method])
130 |                 loss = torch.max(dis_ap - dis_an + opt.margin**0.5, torch.zeros(dis_ap.shape, device=device)).mean()
131 |                 del trips_a, trips_p, trips_n
132 |             else:
133 |                 loss /= nNeg.float().to(device)  # normalise by actual number of negatives
134 |             loss.backward()
135 |             optimizer.step()
136 | 
137 |             batch_loss = loss.item()
138 |             epoch_loss += batch_loss
139 | 
140 |             if iteration % 50 == 0 or nBatches <= 10:
141 |                 print("==> Epoch[{}]({}/{}): Loss: {:.4f}".format(colored(epoch, 'red'), iteration,
142 |                     nBatches, batch_loss), flush=True)
143 |                 writer.add_scalar('Train/Loss', batch_loss,
144 |                     ((epoch-1) * nBatches) + iteration)
145 |                 writer.add_scalar('Train/nNeg', nNeg,
146 |                     ((epoch-1) * nBatches) + iteration)
147 |                 wandb.log({"loss": batch_loss, "nNeg": nNeg, "epoch": epoch})
148 |                 if not opt.nocuda:
149 |                     print('Allocated:', torch.cuda.memory_allocated())
150 |                     print('Cached:', torch.cuda.memory_reserved())
151 | 
152 |             del query, positives, negatives
153 | 
154 |         startIter += len(training_data_loader)
155 |         del training_data_loader, loss
156 |         optimizer.zero_grad()
157 |         torch.cuda.empty_cache()
158 |         remove(train_set.cache)  # delete HDF5 cache
159 | 
160 |     avg_loss = epoch_loss / nBatches
161 | 
162 |     print("===> Epoch {} Complete: Avg. Loss: {:.4f}".format(colored(epoch, 'red'), avg_loss),
163 |           flush=True)
164 |     writer.add_scalar('Train/AvgLoss', avg_loss, epoch)
165 |     predictions, bestDists = computeMatches(opt, N_VALUES, device, dbFeat_np=dbFeat_np, qFeat_np=qFeat_np, dMat=dMat)
166 |     gt, gtDists = whole_train_set.get_positives(retDists=True)
167 |     recall_at_n = getRecallAtN(N_VALUES, predictions, gt)
168 |     wandb.log({"loss_e": avg_loss}, commit=False)
169 |     for i, n in enumerate(N_VALUES):
170 |         writer.add_scalar('Train/Recall@' + str(n), recall_at_n[i], epoch)
171 |         wandb.log({'Train/Recall@' + str(n): recall_at_n[i]}, commit=False)
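The `opt.matcher is not None` branch above implements a triplet hinge on sequence-level distances: `max(d(a,p) - d(a,n) + sqrt(margin), 0)` averaged over the batch, with `d` computed by `model.matcher`. Below is a minimal sketch of that objective with a stand-in `meanOfPairs` distance (mean L2 over temporally aligned frame pairs); the `mean_of_pairs` helper and the tensor shapes are illustrative assumptions, the actual distance lives in `seqMatchNet.py`:

```python
import torch

def mean_of_pairs(a, b):
    # a, b: (B, seqL, C) sequence descriptors;
    # average the per-frame L2 distances over the sequence
    return (a - b).norm(dim=-1).mean(dim=-1)  # -> (B,)

anchor, pos, neg = (torch.randn(8, 5, 4096) for _ in range(3))
margin = 0.1  # plays the role of opt.margin
d_ap = mean_of_pairs(anchor, pos)
d_an = mean_of_pairs(anchor, neg)
# hinge as in train.py: max(d(a,p) - d(a,n) + sqrt(margin), 0), mean over triplets
loss = torch.clamp(d_ap - d_an + margin ** 0.5, min=0).mean()
print(loss.item())
```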
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import faiss
4 | from os.path import join
5 | import shutil
6 | 
7 | import seqMatchNet
8 | 
9 | N_VALUES = [1,5,10,20,100]
10 | 
11 | def batch2Seq(input, l):
12 |     inSh = input.shape
13 |     input = input.view(inSh[0]//l, l, inSh[1])
14 |     return input
15 | 
16 | def seq2Batch(input):
17 |     inSh = input.shape
18 |     input = input.view(inSh[0]*inSh[1], inSh[2], inSh[3], inSh[4])
19 |     return input
20 | 
21 | def save_checkpoint(savePath, state, is_best, filename='checkpoint.pth.tar'):
22 |     model_out_path = join(savePath, filename)
23 |     torch.save(state, model_out_path)
24 |     if is_best:
25 |         shutil.copyfile(model_out_path, join(savePath, 'model_best.pth.tar'))
26 | 
27 | def getRecallAtN(n_values, predictions, gt):
28 |     correct_at_n = np.zeros(len(n_values))
29 |     numQWithoutGt = 0
30 |     # TODO: can we do this on the matrix in one go?
31 |     for qIx, pred in enumerate(predictions):
32 |         if len(gt[qIx]) == 0:
33 |             numQWithoutGt += 1
34 |             continue
35 |         for i, n in enumerate(n_values):
36 |             # if in top N then also in top NN, where NN > N
37 |             if np.any(np.in1d(pred[:n], gt[qIx])):
38 |                 correct_at_n[i:] += 1
39 |                 break
40 |     # print("Num Q without GT: ", numQWithoutGt, " of ", len(gt))
41 |     return correct_at_n / (len(gt) - numQWithoutGt)
42 | 
43 | def computeMatches(opt, n_values, device, dbFeat=None, qFeat=None, dbFeat_np=None, qFeat_np=None, dMat=None):
44 | 
45 |     if opt.matcher is not None:
46 |         if dMat is None:
47 |             if opt.predictionsFile is not None:
48 |                 predPrior = np.load(opt.predictionsFile)['preds']
49 |                 predPriorTopK = predPrior[:, :20]
50 |             else:
51 |                 predPriorTopK = None
52 |             outSeqL = opt.seqL
53 |             dMat = 1.0/outSeqL * seqMatchNet.aggregateMatchScores(None, outSeqL, device, dbDesc=dbFeat, qDesc=qFeat, refCandidates=predPriorTopK)[0]
54 |         print(dMat.shape)
55 |         predictions = np.argsort(dMat, axis=0)[:max(n_values), :].transpose()
56 |         bestDists = dMat[predictions[:, 0], np.arange(dMat.shape[1])]
57 |         if opt.predictionsFile is not None:
58 |             predictions = np.array([predPriorTopK[qIdx][predictions[qIdx]] for qIdx in range(predictions.shape[0])])
59 |             print("Preds:", predictions.shape)
60 | 
61 |     # single image descriptors
62 |     else:
63 |         assert(opt.seqL == 1)
64 |         print('====> Building faiss index')
65 |         faiss_index = faiss.IndexFlatL2(dbFeat_np.shape[-1])
66 |         faiss_index.add(np.squeeze(dbFeat_np))
67 | 
68 |         distances, predictions = faiss_index.search(np.squeeze(qFeat_np), max(n_values))
69 |         bestDists = distances[:, 0]
70 |     return predictions, bestDists
71 | 
72 | def evaluate(n_values, predictions, gtDistsMat=None):
73 |     print('====> Calculating recall @ N')
74 |     # compute recall for different loc radii
75 |     rAtL = []
76 |     for locRad in [1, 5, 10, 20, 40, 100, 200]:
77 |         gtAtL = gtDistsMat <= locRad
78 |         gtAtL = [np.argwhere(gtAtL[:, qIx]).flatten() for qIx in range(gtDistsMat.shape[1])]
79 |         rAtL.append(getRecallAtN(n_values, predictions, gtAtL))
80 |     return rAtL
--------------------------------------------------------------------------------
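As a quick sanity check of `getRecallAtN` above, a toy run with two queries and made-up indices (run from the repo root with the environment set up as in the README; Recall@1 is 0.5 since only the second query's top-1 is correct, and Recall@5 is 1.0):

```python
import numpy as np
from utils import getRecallAtN

n_values = [1, 5]
predictions = np.array([[3, 7, 1, 0, 2],   # query 0: true match (db index 7) at rank 2
                        [5, 4, 9, 2, 6]])  # query 1: true match (db index 5) at rank 1
gt = [np.array([7]), np.array([5])]        # ground-truth db indices per query
print(getRecallAtN(n_values, predictions, gt))  # -> [0.5 1. ]
```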