├── .gitignore
├── LICENSE
├── README.md
├── assets
└── seqmatchnet.jpg
├── datasets.py
├── download.sh
├── get_datasets.py
├── get_models.py
├── main.py
├── seqMatchNet.py
├── structFiles
├── nordland_test_d-1_d2-1.db
├── nordland_train_d-40_d2-10.db
├── nordland_val_d-1_d2-1.db
├── oxford-v1.0_2015-03-17-11-08-44_2014-12-16-18-44-24_test_d-10_d2-5.db
├── oxford-v1.0_2015-03-17-11-08-44_2014-12-16-18-44-24_train_d-20_d2-5.db
└── oxford-v1.0_splitInds.npz
├── test.py
├── train.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .ipynb*
2 | *__pycache__/*
3 | data
4 | data/*
5 | wandb/*
6 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Sourav Garg
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SeqMatchNet
2 | Code for the CoRL 2021 Oral paper "SeqMatchNet: Contrastive Learning with Sequence Matching for Place Recognition and Relocalization"
3 |
4 | [[OpenReview](https://openreview.net/forum?id=OQMXb0xiCrt)] [[PDF](https://openreview.net/pdf?id=OQMXb0xiCrt)] [[CoRL 2021 YouTube Video](https://www.youtube.com/watch?v=Rb2Tbu72rG0)]
5 |
6 |
7 |
8 |
SeqMatchNet: Contrastive Learning with Sequence Matching.
9 |
10 |
11 | ## Setup
12 | ### Conda
13 | ```bash
14 | conda create -n seqnet numpy pytorch=1.8.0 torchvision tqdm scikit-learn faiss tensorboardx h5py wandb -c pytorch -c conda-forge
15 | ```
16 |
17 | ### Download
18 | Run `bash download.sh` to download single image NetVLAD descriptors (3.4 GB) for the Nordland-clean dataset [[a]](#nordclean) and the Oxford dataset (0.3 GB) [[b]](#saveLoc).
19 |
20 | You can download model trained on Oxford from [here](https://cloudstor.aarnet.edu.au/plus/s/y27PHvmZ2xpmcId).
21 | ## Run
22 |
23 | ### Train
24 | To train SeqMatchNet on the Oxford dataset with both the loss and negative mining based on sequence matching:
25 | ```python
26 | python main.py --mode train --seqL 5 --pooling --dataset oxford-v1.0 --loss_trip_method meanOfPairs --neg_trip_method meanOfPairs --expName ox10_MoP_negMoP
27 | ```
28 | For the Nordland dataset:
29 | ```python
30 | python main.py --mode train --seqL 5 --pooling --dataset nordland-sw --loss_trip_method meanOfPairs --neg_trip_method meanOfPairs --expName nord-sw_MoP_negMoP
31 | ```
32 |
33 | To train without sequence matching:
34 | ```python
35 | python main.py --mode train --seqL 5 --pooling --dataset oxford-v1.0 --loss_trip_method centerOnly --neg_trip_method centerOnly --expName ox10_CO_negCO
36 | ```
37 |
38 | ### Test
39 | ```python
40 | python main.py --mode test --seqL 5 --pooling -dataset oxford-v1.0 --split test --resume ./data/runs/
41 | ```
42 |
43 | ## Acknowledgement
44 | The code in this repository is based on [oravus/seqNet](https://github.com/oravus/seqNet) and [Nanne/pytorch-NetVlad](https://github.com/Nanne/pytorch-NetVlad).
45 |
46 | ## Citation
47 | ```
48 | @inproceedings{garg2021seqmatchnet,
49 | title={SeqMatchNet: Contrastive Learning with Sequence Matching for Place Recognition \& Relocalization},
50 | author={Garg, Sourav and Vankadari, Madhu and Milford, Michael},
51 | booktitle={5th Annual Conference on Robot Learning},
52 | year={2021}
53 | }
54 | ```
55 |
56 | #### Other Related Projects
57 | [SeqNet](https://github.com/oravus/seqNet);
58 | [Delta Descriptors (2020)](https://github.com/oravus/DeltaDescriptors);
59 | [Patch-NetVLAD (2021)](https://github.com/QVPR/Patch-NetVLAD);
60 | [CoarseHash (2020)](https://github.com/oravus/CoarseHash);
61 | [seq2single (2019)](https://github.com/oravus/seq2single);
62 | [LoST (2018)](https://github.com/oravus/lostX)
63 |
64 | [a] This is the clean version of the dataset that excludes images from the tunnels and red lights and can be downloaded from [here](https://cloudstor.aarnet.edu.au/plus/s/8L7loyTZjK0FsWT).
65 |
66 | [b] These will automatically save to `./data/`, you can modify this path in [download.sh](https://github.com/oravus/seqNet/blob/main/download.sh) and [get_datasets.py](https://github.com/oravus/seqNet/blob/5450829c4294fe1d14966bfa1ac9b7c93237369b/get_datasets.py#L6) to specify your workdir.
67 |
--------------------------------------------------------------------------------
/assets/seqmatchnet.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/oravus/SeqMatchNet/431f2a963587e1bbfd3f840d7d6c299207f8d88c/assets/seqmatchnet.jpg
--------------------------------------------------------------------------------
/datasets.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.utils.data as data
3 | import itertools
4 |
5 | import os
6 | from os.path import join, exists
7 | from scipy.io import loadmat, savemat
8 | import numpy as np
9 | from collections import namedtuple
10 |
11 | from sklearn.neighbors import NearestNeighbors
12 | import faiss
13 | import h5py
14 |
15 | dbStruct = namedtuple('dbStruct', ['whichSet', 'dataset',
16 | 'dbImage', 'utmDb', 'qImage', 'utmQ', 'numDb', 'numQ',
17 | 'posDistThr', 'posDistSqThr', 'nonTrivPosDistSqThr',
18 | 'dbTimeStamp', 'qTimeStamp', 'gpsDb', 'gpsQ'])
19 |
20 | class Dataset():
21 | def __init__(self, dataset_name, train_mat_file, test_mat_file, val_mat_file, opt):
22 | self.dataset_name = dataset_name
23 | self.train_mat_file = train_mat_file
24 | self.test_mat_file = test_mat_file
25 | self.val_mat_file = val_mat_file
26 | self.struct_dir = "./structFiles/"
27 | self.seqL = opt.seqL
28 | self.seqL_filterData = opt.seqL_filterData
29 | self.matcher = opt.matcher
30 |
31 | # descriptor settings
32 | self.dbDescs = None
33 | self.qDescs = None
34 | self.trainInds = None
35 | self.valInds = None
36 | self.testInds = None
37 | self.db_seqBounds = None
38 | self.q_seqBounds = None
39 |
40 | def loadPreComputedDescriptors(self,ft1,ft2,seqBounds=None):
41 | self.dbDescs = np.expand_dims(ft1,(2,3))
42 | self.qDescs = np.expand_dims(ft2,(2,3))
43 | encDim = self.dbDescs.shape[1]
44 | print("All Db descs: ", self.dbDescs.shape)
45 | print("All Qry descs: ", self.qDescs.shape)
46 | if seqBounds is None:
47 | self.db_seqBounds = None
48 | self.q_seqBounds = None
49 | else:
50 | self.db_seqBounds = seqBounds[0]
51 | self.q_seqBounds = seqBounds[1]
52 | return encDim
53 |
54 | def get_whole_training_set(self, onlyDB=False):
55 | structFile = join(self.struct_dir, self.train_mat_file)
56 | indsSplit = self.trainInds
57 | return WholeDatasetFromStruct( structFile, indsSplit, self.dbDescs, self.qDescs, seqL=self.seqL, onlyDB=onlyDB, seqBounds=[self.db_seqBounds,self.q_seqBounds],seqL_filterData=self.seqL)
58 |
59 | def get_whole_val_set(self):
60 | structFile = join(self.struct_dir, self.val_mat_file)
61 | indsSplit = self.valInds
62 | return WholeDatasetFromStruct(structFile, indsSplit, self.dbDescs, self.qDescs, seqL=self.seqL, seqBounds=[self.db_seqBounds,self.q_seqBounds],seqL_filterData=self.seqL_filterData)
63 |
64 | def get_whole_test_set(self):
65 | if self.test_mat_file is not None:
66 | structFile = join(self.struct_dir, self.test_mat_file)
67 | indsSplit = self.testInds
68 | return WholeDatasetFromStruct(structFile, indsSplit, self.dbDescs, self.qDescs, seqL=self.seqL, seqBounds=[self.db_seqBounds,self.q_seqBounds],seqL_filterData=self.seqL_filterData)
69 | else:
70 | raise ValueError('test set not available for dataset ' + self.dataset_name)
71 |
72 | def get_training_query_set(self, margin=0.1, nNegSample=1000):
73 | structFile = join(self.struct_dir, self.train_mat_file)
74 | indsSplit = self.trainInds
75 | return QueryDatasetFromStruct(structFile,indsSplit, self.dbDescs, self.qDescs, nNegSample=nNegSample, margin=margin, seqL=self.seqL, seqBounds=[self.db_seqBounds,self.q_seqBounds], matcher=self.matcher)
76 |
77 | def get_val_query_set(self):
78 | structFile = join(self.struct_dir, self.val_mat_file)
79 | indsSplit = self.valInds
80 | return QueryDatasetFromStruct(structFile, indsSplit, self.dbDescs, self.qDescs, seqL=self.seqL, seqBounds=[self.db_seqBounds,self.q_seqBounds])
81 |
82 | @staticmethod
83 | def collate_fn(batch):
84 | """Creates mini-batch tensors from the list of tuples (query, positive, negatives).
85 |
86 | Args:
87 | batch: list of tuple (query, positive, negatives).
88 | - query: torch tensor of shape (T, C). e.g. (5,4096)
89 | - positive: torch tensor of shape (T, C).
90 | - negative: torch tensor of shape (N, T, C).
91 | Returns:
92 | query: torch tensor of shape (batch_size, T, C).
93 | positive: torch tensor of shape (batch_size, T, C).
94 | negatives: torch tensor of shape (batch_size, T, C).
95 | """
96 |
97 | batch = list(filter(lambda x: x is not None, batch))
98 | if len(batch) == 0:
99 | return None, None, None, None, None
100 |
101 | query, positive, negatives, indices = zip(*batch)
102 |
103 | query = data.dataloader.default_collate(query)
104 | positive = data.dataloader.default_collate(positive)
105 | negCounts = data.dataloader.default_collate([x.shape[0] for x in negatives])
106 | negatives = torch.cat(negatives, 0)
107 | indices = list(itertools.chain(*indices))
108 |
109 | return query, positive, negatives, negCounts, indices
110 |
111 | def getSeqInds(idx,seqL,maxNum,minNum=0,retLenDiff=False):
112 | seqLOrig = seqL
113 | seqInds = np.arange(max(minNum,idx-seqL//2),min(idx+seqL-seqL//2,maxNum),1)
114 | lenDiff = seqLOrig - len(seqInds)
115 | if retLenDiff:
116 | return lenDiff
117 |
118 | if seqInds[0] == minNum:
119 | seqInds = np.concatenate([seqInds,np.arange(seqInds[-1]+1,seqInds[-1]+1+lenDiff,1)])
120 | elif lenDiff > 0 and seqInds[-1] in range(maxNum-1,maxNum):
121 | seqInds = np.concatenate([np.arange(seqInds[0]-lenDiff,seqInds[0],1),seqInds])
122 | return seqInds
123 |
124 | def getValidSeqInds(seqBounds,seqL):
125 | validFlags = []
126 | for i in range(len(seqBounds)):
127 | sIdMin, sIdMax = seqBounds[i]
128 | lenDiff = getSeqInds(i,seqL,sIdMax,minNum=sIdMin,retLenDiff=True)
129 | validFlags.append(True if lenDiff == 0 else False)
130 | return validFlags
131 |
132 | def parse_db_struct(path):
133 | mat = loadmat(path)
134 |
135 | fieldnames = list(mat['dbStruct'][0, 0].dtype.names)
136 |
137 | dataset = mat['dbStruct'][0, 0]['dataset'].item()
138 | whichSet = mat['dbStruct'][0, 0]['whichSet'].item()
139 |
140 | dbImage = [f[0].item() for f in mat['dbStruct'][0, 0]['dbImageFns']]
141 | qImage = [f[0].item() for f in mat['dbStruct'][0, 0]['qImageFns']]
142 |
143 | numDb = mat['dbStruct'][0, 0]['numImages'].item()
144 | numQ = mat['dbStruct'][0, 0]['numQueries'].item()
145 |
146 | posDistThr = mat['dbStruct'][0, 0]['posDistThr'].item()
147 | posDistSqThr = mat['dbStruct'][0, 0]['posDistSqThr'].item()
148 | if 'nonTrivPosDistSqThr' in fieldnames:
149 | nonTrivPosDistSqThr = mat['dbStruct'][0, 0]['nonTrivPosDistSqThr'].item()
150 | else:
151 | nonTrivPosDistSqThr = None
152 |
153 | if 'dbTimeStamp' in fieldnames and 'qTimeStamp' in fieldnames:
154 | dbTimeStamp = [f[0].item() for f in mat['dbStruct'][0, 0]['dbTimeStamp'].T]
155 | qTimeStamp = [f[0].item() for f in mat['dbStruct'][0, 0]['qTimeStamp'].T]
156 | dbTimeStamp = np.array(dbTimeStamp)
157 | qTimeStamp = np.array(qTimeStamp)
158 | else:
159 | dbTimeStamp = None
160 | qTimeStamp = None
161 |
162 | if 'utmQ' in fieldnames and 'utmDb' in fieldnames:
163 | utmDb = mat['dbStruct'][0, 0]['utmDb'].T
164 | utmQ = mat['dbStruct'][0, 0]['utmQ'].T
165 | else:
166 | utmQ = None
167 | utmDb = None
168 |
169 | if 'gpsQ' in fieldnames and 'gpsDb' in fieldnames:
170 | gpsDb = mat['dbStruct'][0, 0]['gpsDb'].T
171 | gpsQ = mat['dbStruct'][0, 0]['gpsQ'].T
172 | else:
173 | gpsQ = None
174 | gpsDb = None
175 |
176 | return dbStruct(whichSet, dataset, dbImage, utmDb, qImage, utmQ, numDb, numQ, posDistThr,
177 | posDistSqThr, nonTrivPosDistSqThr, dbTimeStamp, qTimeStamp, gpsQ, gpsDb)
178 |
179 |
180 | def save_db_struct(path, db_struct):
181 | assert db_struct.numDb == len(db_struct.dbImage)
182 | assert db_struct.numQ == len(db_struct.qImage)
183 |
184 | inner_dict = {
185 | 'whichSet': db_struct.whichSet,
186 | 'dbImageFns': np.array(db_struct.dbImage, dtype=np.object).reshape(-1, 1),
187 | 'qImageFns': np.array(db_struct.qImage, dtype=np.object).reshape(-1, 1),
188 | 'numImages': db_struct.numDb,
189 | 'numQueries': db_struct.numQ,
190 | 'posDistThr': db_struct.posDistThr,
191 | 'posDistSqThr': db_struct.posDistSqThr,
192 | }
193 |
194 | if db_struct.dataset is not None:
195 | inner_dict['dataset'] = db_struct.dataset
196 |
197 | if db_struct.nonTrivPosDistSqThr is not None:
198 | inner_dict['nonTrivPosDistSqThr'] = db_struct.nonTrivPosDistSqThr
199 |
200 | if db_struct.utmDb is not None and db_struct.utmQ is not None:
201 | assert db_struct.numDb == len(db_struct.utmDb)
202 | assert db_struct.numQ == len(db_struct.utmQ)
203 | inner_dict['utmDb'] = db_struct.utmDb.T
204 | inner_dict['utmQ'] = db_struct.utmQ.T
205 |
206 | if db_struct.gpsDb is not None and db_struct.gpsQ is not None:
207 | assert db_struct.numDb == len(db_struct.gpsDb)
208 | assert db_struct.numQ == len(db_struct.gpsQ)
209 | inner_dict['gpsDb'] = db_struct.gpsDb.T
210 | inner_dict['gpsQ'] = db_struct.gpsQ.T
211 |
212 | if db_struct.dbTimeStamp is not None and db_struct.qTimeStamp is not None:
213 | inner_dict['dbTimeStamp'] = db_struct.dbTimeStamp.astype(np.float64)
214 | inner_dict['qTimeStamp'] = db_struct.qTimeStamp.astype(np.float64)
215 |
216 | savemat(path, {'dbStruct': inner_dict})
217 |
218 | def getValidSeqData(seqBounds,seqL_filterData,dbStruct):
219 | validFlags = getValidSeqInds(seqBounds,seqL_filterData)
220 | validInds = np.argwhere(validFlags).flatten()
221 | validInds_db = np.argwhere(validFlags[:dbStruct.numDb]).flatten()
222 | validInds_q = np.argwhere(validFlags[dbStruct.numDb:]).flatten()
223 | dbStruct = dbStruct._replace(utmDb=dbStruct.utmDb[validInds_db], numDb=len(validInds_db), utmQ=dbStruct.utmQ[validInds_q], numQ=len(validInds_q))
224 | print("\n Num sequences violating boundaries: {} \n".format(len(validFlags)-len(validInds)))
225 | return validInds, dbStruct, validInds_db, validInds_q
226 |
227 | class WholeDatasetFromStruct(data.Dataset):
228 | def __init__(self, structFile, indsSplit, dbDescs, qDescs, onlyDB=False, seqL=1, seqBounds=None,seqL_filterData=None):
229 | super().__init__()
230 |
231 | self.seqL = seqL
232 | self.filterBoundaryInds = False if seqL_filterData is None else True
233 |
234 | self.dbStruct = parse_db_struct(structFile)
235 |
236 | self.images = dbDescs[indsSplit[0]]
237 |
238 | if seqBounds[0] is None:
239 | self.seqBounds = np.array([[0,len(self.images)] for _ in range(len(self.images))])
240 |
241 | if not onlyDB:
242 | qImages = qDescs[indsSplit[1]]
243 | self.images = np.concatenate([self.images,qImages],0)
244 | if seqBounds[0] is None:
245 | q_seqBounds = np.array([[len(self.seqBounds),len(self.images)] for _ in range(len(qImages))])
246 | self.seqBounds = np.vstack([self.seqBounds,q_seqBounds])
247 |
248 | if seqBounds[0] is not None:
249 | db_seqBounds = seqBounds[0][indsSplit[0]]
250 | q_seqBounds = db_seqBounds[-1,-1] + seqBounds[1][indsSplit[1]]
251 | self.seqBounds = np.vstack([db_seqBounds,q_seqBounds])
252 |
253 | self.validInds = np.arange(len(self.images))
254 | self.numDb_full, self.numQ_full = self.dbStruct.numDb, self.dbStruct.numQ
255 | # update dbStruct and size variables with valid sequence based indices
256 | if self.filterBoundaryInds:
257 | self.validInds, self.dbStruct, _, _ = getValidSeqData(self.seqBounds,seqL_filterData,self.dbStruct)
258 | self.numDb_valid, self.numQ_valid = self.dbStruct.numDb, self.dbStruct.numQ
259 |
260 | self.whichSet = self.dbStruct.whichSet
261 | self.dataset = self.dbStruct.dataset
262 |
263 | self.positives = None
264 | self.distances = None
265 |
266 | def getSeqIndsFromValidInds(self,index):
267 | sIdMin, sIdMax = self.seqBounds[index]
268 | return getSeqInds(index,self.seqL,sIdMax,minNum=sIdMin)
269 |
270 | def getIndices(self,index):
271 | index = self.validInds[index]
272 | sIdMin, sIdMax = self.seqBounds[index]
273 | return getSeqInds(index,self.seqL,sIdMax,minNum=sIdMin)
274 |
275 | def __getitem__(self, index):
276 | img = self.images[np.array([index])]
277 | return img, index
278 |
279 | def __len__(self):
280 | return len(self.images)
281 |
282 | def get_positives(self,retDists=False):
283 | # positives for evaluation are those within trivial threshold range
284 | # fit NN to find them, search by radius
285 | if self.positives is None:
286 | knn = NearestNeighbors(n_jobs=-1)
287 | knn.fit(self.dbStruct.utmDb)
288 |
289 | print("Using Localization Radius: ", self.dbStruct.posDistThr)
290 | self.distances, self.positives = knn.radius_neighbors(self.dbStruct.utmQ, radius=self.dbStruct.posDistThr)
291 |
292 | if retDists:
293 | return self.positives, self.distances
294 | else:
295 | return self.positives
296 |
297 |
298 | class QueryDatasetFromStruct(data.Dataset):
299 | def __init__(self, structFile, indsSplit, dbDescs, qDescs, nNegSample=1000, nNeg=10, margin=0.1, seqL=1, seqBounds=None, matcher=None):
300 | super().__init__()
301 |
302 | self.matcher = matcher
303 | self.seqL = seqL
304 |
305 | self.dbDescs = dbDescs[indsSplit[0]]
306 | self.qDescs = qDescs[indsSplit[1]]
307 |
308 | self.margin = margin
309 |
310 | self.dbStruct = parse_db_struct(structFile)
311 |
312 | if seqBounds[0] is None:
313 | self.db_seqBounds = np.array([[0,len(self.dbDescs)] for _ in range(len(self.dbDescs))])
314 | self.q_seqBounds = np.array([[0,len(self.qDescs)] for _ in range(len(self.qDescs))])
315 | else:
316 | self.db_seqBounds = seqBounds[0][indsSplit[0]]
317 | self.q_seqBounds = seqBounds[1][indsSplit[1]]
318 |
319 | # update dbStruct and size variables with valid sequnence based indices
320 | self.numDb_full, self.numQ_full = self.dbStruct.numDb, self.dbStruct.numQ
321 | validFlags_db = getValidSeqInds(self.db_seqBounds,seqL)
322 | validFlags_q = getValidSeqInds(self.q_seqBounds,seqL)
323 | self.validInds_db, self.validInds_q = np.argwhere(validFlags_db).flatten(), np.argwhere(validFlags_q).flatten()
324 | self.dbStruct = self.dbStruct._replace(utmDb=self.dbStruct.utmDb[self.validInds_db], numDb=len(self.validInds_db), utmQ=self.dbStruct.utmQ[self.validInds_q], numQ=len(self.validInds_q))
325 | # self.numDb_valid, self.numQ_valid = self.dbStruct.numDb, self.dbStruct.numQ
326 | print("\n Num sequences (db) violating boundaries: {}".format(len(validFlags_db)-len(self.validInds_db)))
327 | print("Num sequences (q) violating boundaries: {} \n".format(len(validFlags_q)-len(self.validInds_q)))
328 |
329 | self.whichSet = self.dbStruct.whichSet
330 | self.dataset = self.dbStruct.dataset
331 | self.nNegSample = nNegSample # number of negatives to randomly sample
332 | self.nNeg = nNeg # number of negatives used for training
333 |
334 | self.use_faiss = True
335 | self.use_h5disMat = False
336 | if self.matcher is not None:
337 | self.use_faiss = False
338 | self.use_h5disMat = True
339 | assert(self.use_faiss!=self.use_h5disMat)
340 |
341 | # potential positives are those within nontrivial threshold range
342 | # fit NN to find them, search by radius
343 | knn = NearestNeighbors(n_jobs=-1)
344 | knn.fit(self.dbStruct.utmDb)
345 |
346 | # TODO use sqeuclidean as metric?
347 | self.nontrivial_distances, self.nontrivial_positives = \
348 | knn.radius_neighbors(self.dbStruct.utmQ, radius=self.dbStruct.nonTrivPosDistSqThr**0.5,
349 | return_distance=True)
350 |
351 | self.nontrivial_positives = list(self.nontrivial_positives)
352 |
353 | # radius returns unsorted, sort once now so we dont have to later
354 | for i, posi in enumerate(self.nontrivial_positives):
355 | self.nontrivial_positives[i] = np.sort(posi)
356 |
357 | # its possible some queries don't have any non trivial potential positives
358 | # lets filter those out
359 | self.queries = np.where(np.array([len(x) for x in self.nontrivial_positives]) > 0)[0]
360 | print("\n Queries within range ",len(self.queries), len(self.nontrivial_positives),"\n")
361 |
362 | # potential negatives are those outside of posDistThr range
363 | potential_positives = knn.radius_neighbors(self.dbStruct.utmQ,
364 | radius=self.dbStruct.posDistThr,
365 | return_distance=False)
366 |
367 | self.potential_negatives = []
368 | for pos in potential_positives:
369 | self.potential_negatives.append(np.setdiff1d(np.arange(self.dbStruct.numDb), pos, assume_unique=True))
370 |
371 | self.cache = None # filepath of HDF5 containing feature vectors for images
372 | self.h5feat = None
373 |
374 | self.negCache = [np.empty((0,)) for _ in range(self.dbStruct.numQ)]
375 |
376 | def __getitem__(self, index):
377 | with h5py.File(self.cache, mode='r') as h5:
378 | h5feat = h5.get("features")
379 | h5disMat = h5.get("disMat")
380 |
381 | qOffset = self.dbStruct.numDb
382 | qFeat = h5feat[index + qOffset]
383 | qDis = h5disMat[:,index]
384 |
385 | if len(self.nontrivial_positives[index]) < 1:
386 | # if none are violating then skip this query
387 | return None
388 |
389 | if self.use_h5disMat:
390 | posDis = qDis[self.nontrivial_positives[index]]
391 | posNN = np.argmin(posDis)
392 | dPos = posDis[posNN]
393 | posIndex = self.nontrivial_positives[index][posNN]
394 | else:
395 | posFeat = h5feat[self.nontrivial_positives[index].tolist()]
396 | if self.use_faiss:
397 | faiss_index = faiss.IndexFlatL2(posFeat.shape[1])
398 | # noinspection PyArgumentList
399 | faiss_index.add(posFeat)
400 | # noinspection PyArgumentList
401 | dPos, posNN = faiss_index.search(qFeat.reshape(1, -1), 1)#posFeat.shape[0])
402 | dPos = np.sqrt(dPos) # faiss returns squared distance
403 | dPos = dPos[0][-1].item()
404 | posIndex = self.nontrivial_positives[index][posNN[0,-1]].item()
405 | else:
406 | knn = NearestNeighbors(n_jobs=-1)
407 | knn.fit(posFeat)
408 | dPos, posNN = knn.kneighbors(qFeat.reshape(1, -1), 1)#posFeat.shape[0])
409 | dPos = dPos[0][-1].item()
410 | posIndex = self.nontrivial_positives[index][posNN[0,-1]].item()
411 |
412 | if self.use_h5disMat:
413 | # negSample = self.potential_negatives[index]
414 | negSample = np.random.choice(self.potential_negatives[index], self.nNegSample)
415 | negSample = np.unique(np.concatenate([self.negCache[index], negSample]))
416 | negSample = np.sort(negSample).astype(int)
417 | negDis = qDis[negSample]
418 | negNN = np.argsort(negDis)[:self.nNeg * 10]
419 | dNeg = negDis[negNN]
420 | else:
421 | negSample = np.random.choice(self.potential_negatives[index], self.nNegSample)
422 | negSample = np.unique(np.concatenate([self.negCache[index], negSample]))
423 | negSample = np.sort(negSample).astype(int) #essential to order ascending, speeds up h5 by about double
424 | negFeat = h5feat[negSample.tolist()]
425 | if self.use_faiss:
426 | faiss_index = faiss.IndexFlatL2(posFeat.shape[1])
427 | # noinspection PyArgumentList
428 | faiss_index.add(negFeat)
429 | # noinspection PyArgumentList
430 | dNeg, negNN = faiss_index.search(qFeat.reshape(1, -1), self.nNeg * 10)
431 | dNeg = np.sqrt(dNeg)
432 | else:
433 | knn.fit(negFeat)
434 |
435 | # to quote netvlad paper code: 10x is hacky but fine
436 | dNeg, negNN = knn.kneighbors(qFeat.reshape(1, -1), self.nNeg * 10)
437 |
438 | dNeg = dNeg.reshape(-1)
439 | negNN = negNN.reshape(-1)
440 |
441 | # try to find negatives that are within margin, if there aren't any return none
442 | violatingNeg = dNeg < dPos + self.margin**0.5
443 |
444 | if np.sum(violatingNeg) < 1:
445 | # if none are violating then skip this query
446 | return None
447 |
448 | negNN = negNN[violatingNeg][:self.nNeg]
449 | negIndices = negSample[negNN].astype(np.int32)
450 | self.negCache[index] = negIndices
451 |
452 |
453 | sIdMin_q, sIdMax_q = self.q_seqBounds[self.validInds_q[index]]
454 | query = self.qDescs[getSeqInds(self.validInds_q[index],self.seqL,sIdMax_q,sIdMin_q)]
455 | sIdMin_p, sIdMax_p = self.db_seqBounds[self.validInds_db[posIndex]]
456 | positive = self.dbDescs[getSeqInds(self.validInds_db[posIndex],self.seqL,sIdMax_p,sIdMin_p)]
457 | negatives = []
458 | for negIndex in negIndices:
459 | sIdMin_n, sIdMax_n = self.db_seqBounds[self.validInds_db[negIndex]]
460 | negative = self.dbDescs[getSeqInds(self.validInds_db[negIndex],self.seqL,sIdMax_n,sIdMin_n)]
461 | negative = torch.tensor(negative)
462 | negatives.append(negative)
463 |
464 | negatives = torch.stack(negatives, 0)
465 |
466 | # noinspection PyTypeChecker
467 | return query, positive, negatives, [index, posIndex] + negIndices.tolist()
468 |
469 | def __len__(self):
470 | return len(self.validInds_q)
471 |
--------------------------------------------------------------------------------
/download.sh:
--------------------------------------------------------------------------------
1 | # download nordland-clean dataset
2 | wget -cO - https://cloudstor.aarnet.edu.au/plus/s/PK98pDvLAesL1aL/download > nordland-clean.zip
3 | mkdir -p ./data/
4 | unzip nordland-clean.zip -d ./data/
5 | rm nordland-clean.zip
6 |
7 | # download oxford descriptors
8 | wget -cO - https://cloudstor.aarnet.edu.au/plus/s/T0M1Ry4HXOAkkGz/download > oxford_2014-12-16-18-44-24_stereo_left.npy
9 | wget -cO - https://cloudstor.aarnet.edu.au/plus/s/vr21RnhMmOkW8S9/download > oxford_2015-03-17-11-08-44_stereo_left.npy
10 | mv oxford* ./data/descData/netvlad-pytorch/
--------------------------------------------------------------------------------
/get_datasets.py:
--------------------------------------------------------------------------------
1 | from datasets import Dataset
2 | from torch.utils.data import DataLoader, SubsetRandomSampler
3 | import numpy as np
4 | from os.path import join
5 |
6 | prefix_data = "./data/"
7 |
8 | def get_dataset(opt):
9 |
10 | if 'nordland' in opt.dataset.lower():
11 | dataset = Dataset('nordland', 'nordland_train_d-40_d2-10.db', 'nordland_test_d-1_d2-1.db', 'nordland_val_d-1_d2-1.db', opt) # train, test, val structs
12 | if 'sw' in opt.dataset.lower():
13 | ref, qry = 'summer', 'winter'
14 | elif 'sf' in opt.dataset.lower():
15 | ref, qry = 'spring', 'fall'
16 | ft1 = np.load(join(prefix_data,"descData/{}/nordland-clean-{}.npy".format(opt.descType,ref)))
17 | ft2 = np.load(join(prefix_data,"descData/{}/nordland-clean-{}.npy".format(opt.descType,qry)))
18 | trainInds, testInds, valInds = np.arange(15000), np.arange(15100,18100), np.arange(18200,21200)
19 |
20 | dataset.trainInds = [trainInds, trainInds]
21 | dataset.valInds = [valInds, valInds]
22 | dataset.testInds = [testInds, testInds]
23 | encoder_dim = dataset.loadPreComputedDescriptors(ft1,ft2)
24 |
25 | elif 'oxford' in opt.dataset.lower():
26 | ref, qry = '2015-03-17-11-08-44', '2014-12-16-18-44-24'
27 | structStr = "{}_{}_{}".format(opt.dataset,ref,qry)
28 | # note: for now temporarily use ox_test as ox_val
29 | if 'v1.0' in opt.dataset:
30 | testStr = '_test_d-10_d2-5.db'
31 | elif 'pnv' in opt.dataset:
32 | testStr = '_test_d-25_d2-5.db'
33 | dataset = Dataset(opt.dataset, structStr+'_train_d-20_d2-5.db', structStr+testStr, structStr+testStr, opt) # train, test, val structs
34 | ft1 = np.load(join(prefix_data,"descData/{}/oxford_{}_stereo_left.npy".format(opt.descType,ref)))
35 | ft2 = np.load(join(prefix_data,"descData/{}/oxford_{}_stereo_left.npy".format(opt.descType,qry)))
36 | splitInds = np.load("./structFiles/{}_splitInds.npz".format(opt.dataset), allow_pickle=True)
37 |
38 | dataset.trainInds = splitInds['trainInds'].tolist()
39 | dataset.valInds = splitInds['valInds'].tolist()
40 | dataset.testInds = splitInds['testInds'].tolist()
41 | encoder_dim = dataset.loadPreComputedDescriptors(ft1,ft2)
42 |
43 | else:
44 | raise Exception('Unknown dataset')
45 |
46 | return dataset, encoder_dim
47 |
48 |
49 | def get_splits(opt, dataset):
50 | whole_train_set, whole_training_data_loader, train_set, whole_test_set = None, None, None, None
51 | if opt.mode.lower() == 'train':
52 | whole_train_set = dataset.get_whole_training_set()
53 | whole_training_data_loader = DataLoader(dataset=whole_train_set,
54 | num_workers=opt.threads, batch_size=opt.cacheBatchSize, shuffle=False,
55 | pin_memory=not opt.nocuda)
56 |
57 | train_set = dataset.get_training_query_set(opt.margin)
58 |
59 | print('====> Training whole set:', len(whole_train_set))
60 | print('====> Training query set:', len(train_set))
61 | whole_test_set = dataset.get_whole_val_set()
62 | print('===> Evaluating on val set, query count:', whole_test_set.dbStruct.numQ)
63 | elif opt.mode.lower() == 'test':
64 | if opt.split.lower() == 'test':
65 | whole_test_set = dataset.get_whole_test_set()
66 | print('===> Evaluating on test set')
67 | elif opt.split.lower() == 'train':
68 | whole_test_set = dataset.get_whole_training_set()
69 | print('===> Evaluating on train set')
70 | elif opt.split.lower() in ['val']:
71 | whole_test_set = dataset.get_whole_val_set()
72 | print('===> Evaluating on val set')
73 | else:
74 | raise ValueError('Unknown dataset split: ' + opt.split)
75 | print('====> Query count:', whole_test_set.dbStruct.numQ)
76 |
77 | return whole_train_set, whole_training_data_loader, train_set, whole_test_set
78 |
--------------------------------------------------------------------------------
/get_models.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | import torch.optim as optim
4 | from os.path import join, isfile
5 | import torchvision.models as models
6 | import torch
7 |
8 | import seqMatchNet
9 |
10 | class WXpB(nn.Module):
11 | def __init__(self, inDims, outDims):
12 | super().__init__()
13 | self.inDims = inDims
14 | self.outDims = outDims
15 | self.conv = nn.Conv1d(inDims, outDims, kernel_size=1)
16 |
17 | def forward(self, x):
18 | x = x.squeeze(-1) # convert [B,C,1,1] to [B,C,1]
19 | feat_transformed = self.conv(x)
20 | return feat_transformed.permute(0,2,1) # return [B,1,C]
21 |
22 | class Flatten(nn.Module):
23 | def forward(self, input):
24 | return input.view(input.size(0), -1)
25 |
26 | class L2Norm(nn.Module):
27 | def __init__(self, dim=1):
28 | super().__init__()
29 | self.dim = dim
30 |
31 | def forward(self, input):
32 | return F.normalize(input, p=2, dim=self.dim)
33 |
34 | def get_pooling(opt,encoder_dim):
35 |
36 | if opt.pooling:
37 | global_pool = nn.AdaptiveMaxPool2d((1,1)) # no effect
38 | poolLayers = nn.Sequential(*[global_pool, WXpB(encoder_dim, opt.outDims), L2Norm(dim=-1)])
39 | else:
40 | global_pool = nn.AdaptiveMaxPool2d((1,1)) # no effect
41 | poolLayers = nn.Sequential(*[global_pool, Flatten(), L2Norm(dim=-1)])
42 | return poolLayers
43 |
44 | def get_matcher(opt,device):
45 |
46 | if opt.matcher == 'seqMatchNet':
47 | sm = seqMatchNet.seqMatchNet()
48 | matcherLayers = nn.Sequential(*[sm])
49 | else:
50 | matcherLayers = None
51 |
52 | return matcherLayers
53 |
54 | def printModelParams(model):
55 |
56 | for name, param in model.named_parameters():
57 | if param.requires_grad:
58 | print(name, param.shape)
59 | return
60 |
61 | def get_model(opt,input_dim,device):
62 | model = nn.Module()
63 | encoder_dim = input_dim
64 |
65 | poolLayers = get_pooling(opt,encoder_dim)
66 | model.add_module('pool', poolLayers)
67 |
68 | matcherLayers = get_matcher(opt,device)
69 | if matcherLayers is not None:
70 | model.add_module('matcher',matcherLayers)
71 |
72 | isParallel = False
73 | if opt.nGPU > 1 and torch.cuda.device_count() > 1:
74 | model.pool = nn.DataParallel(model.pool)
75 | isParallel = True
76 |
77 | if not opt.resume:
78 | model = model.to(device)
79 |
80 | scheduler, optimizer, criterion = None, None, None
81 | if opt.mode.lower() == 'train':
82 | if opt.optim.upper() == 'ADAM':
83 | optimizer = optim.Adam(filter(lambda p: p.requires_grad,
84 | model.parameters()), lr=opt.lr)#, betas=(0,0.9))
85 | elif opt.optim.upper() == 'SGD':
86 | optimizer = optim.SGD(filter(lambda p: p.requires_grad,
87 | model.parameters()), lr=opt.lr,
88 | momentum=opt.momentum,
89 | weight_decay=opt.weightDecay)
90 |
91 | scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=opt.lrStep, gamma=opt.lrGamma)
92 | else:
93 | raise ValueError('Unknown optimizer: ' + opt.optim)
94 |
95 | # used only when matcher is none
96 | criterion = nn.TripletMarginLoss(margin=opt.margin**0.5, p=2, reduction='sum').to(device)
97 |
98 | if opt.resume:
99 | if opt.ckpt.lower() == 'latest':
100 | resume_ckpt = join(opt.resume, 'checkpoints', 'checkpoint.pth.tar')
101 | elif opt.ckpt.lower() == 'best':
102 | resume_ckpt = join(opt.resume, 'checkpoints', 'model_best.pth.tar')
103 |
104 | if isfile(resume_ckpt):
105 | print("=> loading checkpoint '{}'".format(resume_ckpt))
106 | checkpoint = torch.load(resume_ckpt, map_location=lambda storage, loc: storage)
107 | opt.update({"start_epoch" : checkpoint['epoch']}, allow_val_change=True)
108 | best_metric = checkpoint['best_score']
109 | model.load_state_dict(checkpoint['state_dict'])
110 | model = model.to(device)
111 | if opt.mode == 'train':
112 | optimizer.load_state_dict(checkpoint['optimizer'])
113 | print("=> loaded checkpoint '{}' (epoch {})"
114 | .format(resume_ckpt, checkpoint['epoch']))
115 | else:
116 | print("=> no checkpoint found at '{}'".format(resume_ckpt))
117 |
118 | return model, optimizer, scheduler, criterion, isParallel, encoder_dim
119 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import argparse
3 | import random, json
4 | from os.path import join, exists
5 | from os import makedirs
6 |
7 | import wandb
8 | import torch
9 | from datetime import datetime
10 | from tqdm import tqdm
11 |
12 | from tensorboardX import SummaryWriter
13 | import numpy as np
14 |
15 | from get_datasets import get_dataset, get_splits, prefix_data
16 | from get_models import get_model
17 | from utils import save_checkpoint
18 | from train import train
19 | from test import test
20 |
21 | parser = argparse.ArgumentParser(description='SeqMatchNet')
22 | parser.add_argument('--mode', type=str, default='train', help='Mode', choices=['train', 'test'])
23 |
24 | # train settings
25 | parser.add_argument('--batchSize', type=int, default=16, help='Number of triplets (query, pos, negs). Each triplet consists of 12 images.')
26 | parser.add_argument('--cacheBatchSize', type=int, default=64, help='Batch size for caching and testing')
27 | parser.add_argument('--cacheRefreshRate', type=int, default=1000, help='How often to refresh cache, in number of queries. 0 for off')
28 | parser.add_argument('--nEpochs', type=int, default=50, help='number of epochs to train for')
29 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', help='manual epoch number (useful on restarts)')
30 | parser.add_argument('--nGPU', type=int, default=1, help='number of GPU to use.')
31 | parser.add_argument('--optim', type=str, default='SGD', help='optimizer to use', choices=['SGD', 'ADAM'])
32 | parser.add_argument('--lr', type=float, default=0.001, help='Learning Rate.')
33 | parser.add_argument('--lrStep', type=float, default=5, help='Decay LR ever N steps.')
34 | parser.add_argument('--lrGamma', type=float, default=0.5, help='Multiply LR by Gamma for decaying.')
35 | parser.add_argument('--weightDecay', type=float, default=0.001, help='Weight decay for SGD.')
36 | parser.add_argument('--momentum', type=float, default=0.9, help='Momentum for SGD.')
37 | parser.add_argument('--nocuda', action='store_true', help='Dont use cuda')
38 | parser.add_argument('--threads', type=int, default=8, help='Number of threads for each data loader to use')
39 | parser.add_argument('--seed', type=int, default=123, help='Random seed to use.')
40 | parser.add_argument('--patience', type=int, default=0, help='Patience for early stopping. 0 is off.')
41 | parser.add_argument('--evalEvery', type=int, default=1, help='Do a validation set run, and save, every N epochs.')
42 | parser.add_argument('--expName', default='0', help='Unique string for an experiment')
43 |
44 | # path settings
45 | parser.add_argument('--runsPath', type=str, default=join(prefix_data,'runs'), help='Path to save runs to.')
46 | parser.add_argument('--savePath', type=str, default='checkpoints', help='Path to save checkpoints to in logdir. Default=checkpoints/')
47 | parser.add_argument('--cachePath', type=str, default=join(prefix_data,'cache'), help='Path to save cache to.')
48 | parser.add_argument('--resultsPath', type=str, default=None, help='Path to save evaluation results to when mode=test')
49 |
50 | # test settings
51 | parser.add_argument('--resume', type=str, default='', help='Path to load checkpoint from, for resuming training or testing.')
52 | parser.add_argument('--ckpt', type=str, default='latest', help='Resume from latest or best checkpoint.', choices=['latest', 'best'])
53 | parser.add_argument('--split', type=str, default='val', help='Data split to use for testing. Default is val', choices=['test', 'train', 'val'])
54 | parser.add_argument('--numSamples2Project', type=int, default=-1, help='TSNE uses these many samples ([:n]) to project data to 2D; set to -1 to disable')
55 | parser.add_argument('--extractOnly', action='store_true', help='extract descriptors')
56 | parser.add_argument('--predictionsFile', type=str, default=None, help='path to prior predictions data')
57 | parser.add_argument('--seqL_filterData', type=int, help='during testing, db and qry inds will be removed that violate sequence boundaries for this given sequence length')
58 |
59 | # dataset, model etc.
60 | parser.add_argument('--dataset', type=str, default='nordland-sw', help='Dataset to use', choices=['nordland-sw', 'nordland-sf', 'oxford-v1.0', 'oxford-pnv'])
61 | parser.add_argument('--pooling', action='store_true', help='use a fc layer to learn')
62 | parser.add_argument('--seqL', type=int, help='Sequence Length')
63 | parser.add_argument('--outDims', type=int, default=4096, help='Output descriptor dimensions')
64 | parser.add_argument('--margin', type=float, default=0.1, help='Margin for triplet loss. Default=0.1')
65 | parser.add_argument('--descType', type=str, default="netvlad-pytorch", help='underlying descriptor type')
66 |
67 | # matcher settings
68 | parser.add_argument('--matcher', type=str, default='seqMatchNet', help='Matcher Type', choices=['seqMatchNet', 'None'])
69 | parser.add_argument('--loss_trip_method', type=str, default='meanOfPairs', help='', choices=['centerOnly', 'meanOfPairs'])
70 | parser.add_argument('--neg_trip_method', type=str, default='meanOfPairs', help='', choices=['centerOnly', 'meanOfPairs'])
71 |
72 |
73 | if __name__ == "__main__":
74 | # torch.multiprocessing.set_start_method('spawn')
75 |
76 | opt = parser.parse_args()
77 | if opt.matcher == 'None': opt.matcher = None
78 |
79 | restore_var = ['lr', 'lrStep', 'lrGamma', 'weightDecay', 'momentum',
80 | 'runsPath', 'savePath', 'optim', 'margin', 'seed', 'patience', 'outDims']
81 | if not opt.pooling and opt.resume:
82 | raise Exception("Please run without the '--resume' argument when '--pooling' is not used.")
83 |
84 | if opt.resume:
85 | flag_file = join(opt.resume, 'checkpoints', 'flags.json')
86 | if exists(flag_file):
87 | with open(flag_file, 'r') as f:
88 | stored_flags = {'--'+k : str(v) for k,v in json.load(f).items() if k in restore_var}
89 | to_del = []
90 | for flag, val in stored_flags.items():
91 | for act in parser._actions:
92 | if act.dest == flag[2:]:
93 | # store_true / store_false args don't accept arguments, filter these
94 | if type(act.const) == type(True):
95 | if val == str(act.default):
96 | to_del.append(flag)
97 | else:
98 | stored_flags[flag] = ''
99 | for flag in to_del: del stored_flags[flag]
100 |
101 | train_flags = [x for x in list(sum(stored_flags.items(), tuple())) if len(x) > 0]
102 | print('Restored flags:', train_flags)
103 | opt = parser.parse_args(train_flags, namespace=opt)
104 |
105 | wandb_dataStr = opt.dataset.lower()[:4]
106 | wandbResume = False if opt.resume == '' or opt.mode == 'test' else True
107 | wandb.init(project='SeqMatchNet_{}'.format(wandb_dataStr),config=opt,resume=wandbResume,anonymous="allow")
108 | # update runName
109 | runName = wandb.run.name
110 | if opt.expName != '' and runName is not None: #runName is None when running wandb offline
111 | wandb.run.name = opt.expName + "-" + runName.split("-")[-1]
112 | wandb.run.save()
113 | else:
114 | opt.expName = runName
115 |
116 | opt = wandb.config
117 |
118 | print(opt)
119 |
120 | cuda = not opt.nocuda
121 | if cuda and not torch.cuda.is_available():
122 | raise Exception("No GPU found, please run with --nocuda")
123 |
124 | device = torch.device("cuda" if cuda else "cpu")
125 |
126 | random.seed(opt.seed)
127 | np.random.seed(opt.seed)
128 | torch.manual_seed(opt.seed)
129 | if cuda:
130 | torch.cuda.manual_seed(opt.seed)
131 |
132 | print('===> Loading dataset(s)')
133 | dataset, encoder_dim = get_dataset(opt)
134 | whole_train_set, whole_training_data_loader, train_set, whole_test_set = get_splits(opt, dataset)
135 |
136 | print('===> Building model')
137 | model, optimizer, scheduler, criterion, isParallel, encoder_dim = get_model(opt, encoder_dim, device)
138 |
139 | unique_string = datetime.now().strftime('%b%d_%H-%M-%S')+'_l'+str(opt.seqL)+'_'+ opt.expName
140 | writer = None
141 |
142 | if opt.mode.lower() == 'test':
143 | print('===> Running evaluation step')
144 | epoch = 1
145 | recallsOrDesc, dbEmb, qEmb, rAtL, preds = test(opt, model, encoder_dim, device, whole_test_set, writer, epoch, extract_noEval=opt.extractOnly)
146 | if opt.resultsPath is not None:
147 | if not exists(opt.resultsPath):
148 | makedirs(opt.resultsPath)
149 | if opt.extractOnly:
150 | gt = whole_test_set.get_positives()
151 | numDb = whole_test_set.dbStruct.numDb
152 | np.savez(join(opt.resultsPath,unique_string),dbDesc=recallsOrDesc[:numDb],qDesc=recallsOrDesc[numDb:],gt=gt)
153 | else:
154 | np.savez(join(opt.resultsPath,unique_string),args=opt.__dict__,recalls=recallsOrDesc, dbEmb=dbEmb,qEmb=qEmb,rAtL=rAtL,preds=preds)
155 |
156 | elif opt.mode.lower() == 'train':
157 | print('===> Training model')
158 | logdir = join(opt.runsPath,unique_string)
159 | writer = SummaryWriter(log_dir=logdir)
160 | train_set.cache = join(opt.cachePath, train_set.whichSet + '_feat_cache_{}.hdf5'.format(unique_string))
161 |
162 | savePath = join(logdir, opt.savePath)
163 | makedirs(savePath)
164 | if not exists(opt.cachePath): makedirs(opt.cachePath)
165 |
166 | with open(join(savePath, 'flags.json'), 'w') as f:
167 | f.write(json.dumps({k:v for k,v in opt.items()}))
168 | print('===> Saving state to:', logdir)
169 |
170 | not_improved = 0
171 | best_score = 0
172 | for epoch in range(opt.start_epoch+1, opt.nEpochs + 1):
173 | train(opt, model, encoder_dim, device, dataset, criterion, optimizer, train_set, whole_train_set, whole_training_data_loader, epoch, writer)
174 | if opt.optim.upper() == 'SGD':
175 | scheduler.step()
176 | if (epoch % opt.evalEvery) == 0:
177 | recalls = test(opt, model, encoder_dim, device, whole_test_set, writer, epoch)[0]
178 | is_best = recalls[5] > best_score
179 | if is_best:
180 | not_improved = 0
181 | best_score = recalls[5]
182 | else:
183 | not_improved += 1
184 |
185 | save_checkpoint(savePath, {
186 | 'epoch': epoch,
187 | 'state_dict': model.state_dict(),
188 | 'recalls': recalls,
189 | 'best_score': best_score,
190 | 'optimizer' : optimizer.state_dict(),
191 | 'parallel' : isParallel,
192 | }, is_best)
193 |
194 | if opt.patience > 0 and not_improved > (opt.patience / opt.evalEvery):
195 | print('Performance did not improve for', opt.patience, 'epochs. Stopping.')
196 | break
197 |
198 | print("=> Best Recall@5: {:.4f}".format(best_score), flush=True)
199 | writer.close()
200 |
--------------------------------------------------------------------------------
/seqMatchNet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | from tqdm import tqdm
6 | import torch
7 | import time
8 |
9 | class seqMatchNet(nn.Module):
10 |
11 | def __init__(self):
12 | super(seqMatchNet, self).__init__()
13 |
14 | def cdist_quick(self,r,c):
15 | return torch.sqrt(2 - 2*torch.matmul(r,c.transpose(1,2)))
16 |
17 | def aggregateSeqScore(self,data):
18 | r, c, method = data
19 | dMat = self.cdist_quick(r,c)
20 | seqL = dMat.shape[1]
21 | dis = torch.diagonal(dMat,0,1,2)
22 | if method == 'centerOnly':
23 | dis = dis[:,seqL//2]
24 | else: # default to 'meanOfPairs'
25 | dis = dis.mean(-1)
26 | return dis
27 |
28 | def forward(self,data):
29 | return self.aggregateSeqScore(data)
30 |
31 | def computeDisMat_torch(r,c):
32 | # assumes descriptors to be l2-normalized
33 | return torch.stack([torch.sqrt(2 - 2*torch.matmul(r,c[i].unsqueeze(1))).squeeze() for i in range(c.shape[0])]).transpose(0,1)
34 |
35 | def modInd(idx,l,n):
36 | return max(l,min(idx,n-l-1))
37 |
38 | def computeRange(l,n):
39 | li, le = l//2, l-l//2
40 | return torch.stack([torch.arange(modInd(r,li,n)-li,modInd(r,li,n)+le,dtype=int) for r in range(n)])
41 |
42 | def aggregateMatchScores_pt_fromMat_oneShot(dMat,l,device):
43 | convWeight = torch.eye(l,device=device).unsqueeze(0).unsqueeze(0)
44 | dMat_seq = -1*torch.ones(dMat.shape,device=device)
45 | li, le = l//2, l-l//2
46 |
47 | dMat_seq[li:-le+1,li:-le+1] = torch.nn.functional.conv2d(dMat.unsqueeze(0).unsqueeze(0),convWeight).squeeze()
48 |
49 | # fill left and right columns
50 | dMat_seq[:,:li] = dMat_seq[:,li,None]
51 | dMat_seq[:,-le+1:] = dMat_seq[:,-le,None]
52 |
53 | # fill top and bottom rows
54 | dMat_seq[:li,:] = dMat_seq[None,li,:]
55 | dMat_seq[-le+1:,:] = dMat_seq[None,-le,:]
56 |
57 | return dMat_seq
58 |
59 | def aggregateMatchScores_pt_fromMat(dMat,l,device,refCandidates=None):
60 | li, le = l//2, l-l//2
61 | n = dMat.shape[0]
62 | convWeight = torch.eye(l,device=device).unsqueeze(0).unsqueeze(0)
63 |
64 | # dMat = dMat.to('cpu')
65 | if refCandidates is None:
66 | shape = dMat.shape
67 | else:
68 | shape = refCandidates.transpose().shape
69 | preCompInds = computeRange(l,n)
70 |
71 | dMat_seq = -1*torch.ones(shape,device=device)
72 |
73 | durs = []
74 | for j in tqdm(range(li,dMat.shape[1]-li), total=dMat.shape[1]-l, leave=True):
75 | t1 = time.time()
76 | if refCandidates is not None:
77 | rCands = preCompInds[refCandidates[j]].flatten()
78 | dMat_cols = dMat[rCands,j-li:j+le].to(device)
79 | dMat_seq[:,j] = torch.nn.functional.conv2d(dMat_cols.unsqueeze(0).unsqueeze(0),convWeight,stride=l).squeeze()
80 | else:
81 | dMat_cols = dMat[:,j-li:j+le].to(device)
82 | dMat_seq[li:-le+1,j] = torch.nn.functional.conv2d(dMat_cols.unsqueeze(0).unsqueeze(0),convWeight).squeeze()
83 | durs.append(time.time()-t1)
84 |
85 | if refCandidates is None:
86 | # fill left and right columns
87 | dMat_seq[:,:li] = dMat_seq[:,li,None]
88 | dMat_seq[:,-le+1:] = dMat_seq[:,-le,None]
89 |
90 | # fill top and bottom rows
91 | dMat_seq[:li,:] = dMat_seq[None,li,:]
92 | dMat_seq[-le+1:,:] = dMat_seq[None,-le,:]
93 |
94 | # assert(np.sum(dMat_seq==-1)==0)
95 | print("Average Time Per Query", np.mean(durs))
96 | return dMat_seq
97 |
98 | def aggregateMatchScores_pt_fromDesc(dbDesc,qDesc,l,device,refCandidates=None):
99 | numDb, numQ = dbDesc.shape[0], qDesc.shape[0]
100 | convWeight = torch.eye(l,device=device).unsqueeze(0).unsqueeze(0)
101 |
102 | if refCandidates is None:
103 | shape = [numDb,numQ]
104 | else:
105 | shape = refCandidates.transpose().shape
106 |
107 | dMat_seq = -1*torch.ones(shape,device=device)
108 |
109 | durs = []
110 | for j in tqdm(range(numQ), total=numQ, leave=True):
111 | t1 = time.time()
112 | if refCandidates is not None:
113 | rCands = refCandidates[j]
114 | else:
115 | rCands = torch.arange(numDb)
116 |
117 | dMat = torch.cdist(dbDesc[rCands],qDesc[j].unsqueeze(0))
118 | dMat_seq[:,j] = torch.nn.functional.conv2d(dMat.unsqueeze(1),convWeight).squeeze()
119 | durs.append(time.time()-t1)
120 |
121 | # assert(torch.sum(dMat_seq==-1)==0)
122 | print("Average Time Per Query", np.mean(durs), np.std(durs))
123 | return dMat_seq
124 |
125 | def aggregateMatchScores(dMat,l,device='cuda',refCandidates=None,dbDesc=None,qDesc=None,dMatProcOneShot=False):
126 | dMat_seq, matchInds, matchDists = None, None, None
127 | if dMat is None:
128 | dMat_seq = aggregateMatchScores_pt_fromDesc(dbDesc,qDesc,l,device,refCandidates).detach().cpu().numpy()
129 | else:
130 | if dMatProcOneShot:
131 | dMat_seq = aggregateMatchScores_pt_fromMat_oneShot(dMat,l,device).detach().cpu().numpy()
132 | else:
133 | dMat_seq = aggregateMatchScores_pt_fromMat(dMat,l,device,refCandidates).detach().cpu().numpy()
134 | return dMat_seq, matchInds, matchDists
135 |
--------------------------------------------------------------------------------
/structFiles/nordland_test_d-1_d2-1.db:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/oravus/SeqMatchNet/431f2a963587e1bbfd3f840d7d6c299207f8d88c/structFiles/nordland_test_d-1_d2-1.db
--------------------------------------------------------------------------------
/structFiles/nordland_train_d-40_d2-10.db:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/oravus/SeqMatchNet/431f2a963587e1bbfd3f840d7d6c299207f8d88c/structFiles/nordland_train_d-40_d2-10.db
--------------------------------------------------------------------------------
/structFiles/nordland_val_d-1_d2-1.db:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/oravus/SeqMatchNet/431f2a963587e1bbfd3f840d7d6c299207f8d88c/structFiles/nordland_val_d-1_d2-1.db
--------------------------------------------------------------------------------
/structFiles/oxford-v1.0_2015-03-17-11-08-44_2014-12-16-18-44-24_test_d-10_d2-5.db:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/oravus/SeqMatchNet/431f2a963587e1bbfd3f840d7d6c299207f8d88c/structFiles/oxford-v1.0_2015-03-17-11-08-44_2014-12-16-18-44-24_test_d-10_d2-5.db
--------------------------------------------------------------------------------
/structFiles/oxford-v1.0_2015-03-17-11-08-44_2014-12-16-18-44-24_train_d-20_d2-5.db:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/oravus/SeqMatchNet/431f2a963587e1bbfd3f840d7d6c299207f8d88c/structFiles/oxford-v1.0_2015-03-17-11-08-44_2014-12-16-18-44-24_train_d-20_d2-5.db
--------------------------------------------------------------------------------
/structFiles/oxford-v1.0_splitInds.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/oravus/SeqMatchNet/431f2a963587e1bbfd3f840d7d6c299207f8d88c/structFiles/oxford-v1.0_splitInds.npz
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import DataLoader
3 | from tqdm import tqdm
4 | from sklearn.manifold import TSNE
5 | from scipy.spatial.distance import cdist
6 | import numpy as np
7 | import time
8 | import wandb
9 |
10 | from utils import seq2Batch, getRecallAtN, computeMatches, evaluate, N_VALUES
11 |
12 | def test(opt, model, encoder_dim, device, eval_set, writer, epoch=0, extract_noEval=False):
13 | # TODO what if features dont fit in memory?
14 | test_data_loader = DataLoader(dataset=eval_set,
15 | num_workers=opt.threads, batch_size=opt.cacheBatchSize, shuffle=False,
16 | pin_memory=False)
17 |
18 | model.eval()
19 | with torch.no_grad():
20 | print('====> Extracting Features')
21 | pool_size = encoder_dim
22 | validInds = eval_set.validInds
23 | dbFeat_single = torch.zeros((len(eval_set), pool_size),device=device)
24 | durs_batch = []
25 | for iteration, (input, indices) in tqdm(enumerate(test_data_loader, 1),total=len(test_data_loader)-1, leave=False):
26 | t1 = time.time()
27 | image_encoding = seq2Batch(input).float().to(device)
28 | global_single_descs = model.pool(image_encoding).squeeze()
29 | dbFeat_single[indices] = global_single_descs
30 |
31 | if iteration % 50 == 0 or len(test_data_loader) <= 10:
32 | print("==> Batch ({}/{})".format(iteration,
33 | len(test_data_loader)), flush=True)
34 | durs_batch.append(time.time()-t1)
35 | del input, image_encoding, global_single_descs
36 |
37 | del test_data_loader
38 | print("Average batch time:", np.mean(durs_batch), np.std(durs_batch))
39 |
40 | outSeqL = opt.seqL
41 | # use the original single descriptors for fast seqmatch over dMat (MSLS-like non-continuous dataset not supported)
42 | if (not opt.pooling and opt.matcher is None) and ('nordland' in opt.dataset.lower() or 'tmr' in opt.dataset.lower() or 'ox' in opt.dataset.lower()):
43 | dbFeat = dbFeat_single
44 | numDb = eval_set.numDb_full
45 | # fill sequences centered at single images
46 | else:
47 | dbFeat = torch.zeros(len(validInds), outSeqL, pool_size, device=device)
48 | numDb = eval_set.dbStruct.numDb
49 | for ind in range(len(validInds)):
50 | dbFeat[ind] = dbFeat_single[eval_set.getSeqIndsFromValidInds(validInds[ind])]
51 | del dbFeat_single
52 |
53 | # extracted for both db and query, now split in own sets
54 | qFeat = dbFeat[numDb:]
55 | dbFeat = dbFeat[:numDb]
56 | print(dbFeat.shape, qFeat.shape)
57 |
58 | qFeat_np = qFeat.detach().cpu().numpy().astype('float32')
59 | dbFeat_np = dbFeat.detach().cpu().numpy().astype('float32')
60 |
61 | db_emb, q_emb = None, None
62 | if opt.numSamples2Project != -1 and writer is not None:
63 | db_emb = TSNE(n_components=2).fit_transform(dbFeat_np[:opt.numSamples2Project])
64 | q_emb = TSNE(n_components=2).fit_transform(qFeat_np[:opt.numSamples2Project])
65 |
66 | if extract_noEval:
67 | return np.vstack([dbFeat_np,qFeat_np]), db_emb, q_emb, None, None
68 |
69 | predictions, bestDists = computeMatches(opt,N_VALUES,device,dbFeat,qFeat,dbFeat_np,qFeat_np)
70 |
71 | # for each query get those within threshold distance
72 | gt,gtDists = eval_set.get_positives(retDists=True)
73 | gtDistsMat = cdist(eval_set.dbStruct.utmDb,eval_set.dbStruct.utmQ)
74 |
75 | recall_at_n = getRecallAtN(N_VALUES, predictions, gt)
76 | rAtL = evaluate(N_VALUES,predictions,gtDistsMat)
77 |
78 | recalls = {} #make dict for output
79 | for i,n in enumerate(N_VALUES):
80 | recalls[n] = recall_at_n[i]
81 | print("====> Recall@{}: {:.4f}".format(n, recall_at_n[i]))
82 | if writer is not None:
83 | writer.add_scalar('Val/Recall@' + str(n), recall_at_n[i], epoch)
84 | wandb.log({'Val/Recall@' + str(n): recall_at_n[i]},commit=False)
85 |
86 | return recalls, db_emb, q_emb, rAtL, predictions
87 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import DataLoader
3 | from torch.utils.data.dataset import Subset
4 | from tqdm import tqdm
5 | import numpy as np
6 | from os import remove
7 | import h5py
8 | from math import ceil
9 | import wandb
10 | from termcolor import colored
11 |
12 | from utils import batch2Seq, seq2Batch, getRecallAtN, computeMatches, N_VALUES
13 | import seqMatchNet
14 |
15 | def train(opt, model, encoder_dim, device, dataset, criterion, optimizer, train_set, whole_train_set, whole_training_data_loader, epoch, writer):
16 | epoch_loss = 0
17 | startIter = 1 # keep track of batch iter across subsets for logging
18 |
19 | if opt.cacheRefreshRate > 0:
20 | subsetN = ceil(len(train_set) / opt.cacheRefreshRate)
21 | #TODO randomise the arange before splitting?
22 | subsetIdx = np.array_split(np.arange(len(train_set)), subsetN)
23 | else:
24 | subsetN = 1
25 | subsetIdx = [np.arange(len(train_set))]
26 |
27 | nBatches = (len(train_set) + opt.batchSize - 1) // opt.batchSize
28 |
29 | for subIter in range(subsetN):
30 | print('====> Building Cache')
31 | model.eval()
32 | with h5py.File(train_set.cache, mode='w') as h5:
33 | pool_size = encoder_dim
34 | validInds = whole_train_set.validInds
35 | h5feat = h5.create_dataset("features", [len(validInds), pool_size], dtype=np.float32)
36 | h5DisMat = h5.create_dataset("disMat",[whole_train_set.numDb_valid, whole_train_set.numQ_valid],dtype=np.float32)
37 | with torch.no_grad():
38 | dbFeat_single = torch.zeros(len(whole_train_set), pool_size, device=device)
39 | # expected input B,T,C,H,W (T is 1)
40 | for iteration, (input, indices) in tqdm(enumerate(whole_training_data_loader, 1),total=len(whole_training_data_loader)-1, leave=True):
41 | # convert to B*T,C,H,W
42 | image_encoding = seq2Batch(input).float().to(device)
43 |
44 | # input B*T,C,1,1; outputs B,T,C (T=1); squeeze to B,C
45 | global_single_descs = model.pool(image_encoding).squeeze()
46 | dbFeat_single[indices] = global_single_descs
47 | del input, image_encoding, global_single_descs
48 |
49 | outSeqL = opt.seqL
50 | # fill sequences centered at single images
51 | dbFeat = torch.zeros(len(validInds), outSeqL, pool_size, device=device)
52 | for ind in range(len(validInds)):
53 | dbFeat[ind] = dbFeat_single[whole_train_set.getSeqIndsFromValidInds(validInds[ind])]
54 | if opt.matcher is None: # assumes seqL is 1 in this case
55 | h5feat[ind] = dbFeat[ind].squeeze()
56 | del dbFeat_single
57 |
58 | if opt.matcher is not None:
59 | offset = whole_train_set.numDb_valid
60 | # compute distance matrix
61 | print('====> Caching distance matrix')
62 | if opt.matcher == 'seqMatchNet':
63 | if 'nordland' in opt.dataset.lower() or 'tmr' in opt.dataset.lower():
64 | dMat_cont = seqMatchNet.seqMatchNet.computeDisMat_torch(dbFeat[:offset,outSeqL//2], dbFeat[offset:,outSeqL//2])
65 | if opt.neg_trip_method == 'centerOnly':
66 | h5DisMat[...] = dMat_cont.detach().cpu().numpy()
67 | elif opt.neg_trip_method == 'meanOfPairs':
68 | h5DisMat[...] = seqMatchNet.aggregateMatchScores(dMat_cont,outSeqL,device,dMatProcOneShot=False)[0] * (1.0/outSeqL)
69 | else:
70 | if opt.neg_trip_method == 'centerOnly':
71 | h5DisMat[...] = seqMatchNet.aggregateMatchScores(None,1,device, dbDesc=dbFeat[:offset,outSeqL//2:outSeqL//2+1], qDesc=dbFeat[offset:,outSeqL//2:outSeqL//2+1])[0]
72 | elif opt.neg_trip_method == 'meanOfPairs':
73 | h5DisMat[...] = seqMatchNet.aggregateMatchScores(None,outSeqL,device, dbDesc=dbFeat[:offset], qDesc=dbFeat[offset:])[0] * (1.0/outSeqL)
74 | else:
75 | raise("TODO")
76 |
77 | del dbFeat
78 | dMat = h5DisMat[()]
79 | dbFeat_np, qFeat_np = h5feat[:offset].copy(), h5feat[offset:].copy()
80 |
81 | sub_train_set = Subset(dataset=train_set, indices=subsetIdx[subIter])
82 |
83 | training_data_loader = DataLoader(dataset=sub_train_set, num_workers=opt.threads,
84 | batch_size=opt.batchSize, shuffle=True,
85 | collate_fn=dataset.collate_fn, pin_memory=False)
86 |
87 | if not opt.nocuda:
88 | print('Allocated:', torch.cuda.memory_allocated())
89 | print('Cached:', torch.cuda.memory_reserved())
90 |
91 | print('====> Training Queries')
92 | model.train()
93 | for iteration, (query, positives, negatives,
94 | negCounts, indices) in tqdm(enumerate(training_data_loader, startIter),total=len(training_data_loader),leave=True):
95 | loss = 0
96 | if query is None:
97 | continue # in case we get an empty batch
98 |
99 | B, C = len(query), query[0].shape[1]
100 | nNeg = torch.sum(negCounts)
101 | image_encoding = seq2Batch(torch.cat([query, positives, negatives])).float()
102 |
103 | image_encoding = image_encoding.to(device)
104 | global_single_descs = model.pool(image_encoding)
105 | global_single_descs = batch2Seq(global_single_descs.squeeze(1),opt.seqL)
106 |
107 | del image_encoding
108 | g_desc_Q, g_desc_P, g_desc_N = torch.split(global_single_descs, [B, B, nNeg])
109 | del global_single_descs
110 | optimizer.zero_grad()
111 |
112 | # calculate loss for each Query, Positive, Negative triplet
113 | # due to potential difference in number of negatives have to
114 | # do it per query, per negative
115 | trips_a, trips_p, trips_n = [], [], []
116 | for i, negCount in enumerate(negCounts):
117 | for n in range(negCount):
118 | negIx = (torch.sum(negCounts[:i]) + n).item()
119 | if opt.matcher is None:
120 | loss += criterion(g_desc_Q[i:i+1].squeeze(1), g_desc_P[i:i+1].squeeze(1), g_desc_N[negIx:negIx+1].squeeze(1))
121 | else:
122 | trips_a.append(g_desc_Q[i:i+1])
123 | trips_p.append(g_desc_P[i:i+1])
124 | trips_n.append(g_desc_N[negIx:negIx+1])
125 |
126 | del g_desc_Q, g_desc_P, g_desc_N
127 | if opt.matcher is not None:
128 | dis_ap = model.matcher([torch.cat(trips_a), torch.cat(trips_p), opt.loss_trip_method])
129 | dis_an = model.matcher([torch.cat(trips_a), torch.cat(trips_n), opt.loss_trip_method])
130 | loss = torch.max(dis_ap - dis_an + opt.margin**0.5,torch.zeros(dis_ap.shape,device=device)).mean()
131 | del trips_a, trips_p, trips_n
132 | else:
133 | loss /= nNeg.float().to(device) # normalise by actual number of negatives
134 | loss.backward()
135 | optimizer.step()
136 |
137 | batch_loss = loss.item()
138 | epoch_loss += batch_loss
139 |
140 | if iteration % 50 == 0 or nBatches <= 10:
141 | print("==> Epoch[{}]({}/{}): Loss: {:.4f}".format(colored(epoch,'red'), iteration,
142 | nBatches, batch_loss), flush=True)
143 | writer.add_scalar('Train/Loss', batch_loss,
144 | ((epoch-1) * nBatches) + iteration)
145 | writer.add_scalar('Train/nNeg', nNeg,
146 | ((epoch-1) * nBatches) + iteration)
147 | wandb.log({"loss":batch_loss, "nNeg":nNeg, "epoch":epoch})
148 | if not opt.nocuda:
149 | print('Allocated:', torch.cuda.memory_allocated())
150 | print('Cached:', torch.cuda.memory_reserved())
151 |
152 | del query, positives, negatives
153 |
154 | startIter += len(training_data_loader)
155 | del training_data_loader, loss
156 | optimizer.zero_grad()
157 | torch.cuda.empty_cache()
158 | remove(train_set.cache) # delete HDF5 cache
159 |
160 | avg_loss = epoch_loss / nBatches
161 |
162 | print("===> Epoch {} Complete: Avg. Loss: {:.4f}".format(colored(epoch,'red'), avg_loss),
163 | flush=True)
164 | writer.add_scalar('Train/AvgLoss', avg_loss, epoch)
165 | predictions, bestDists = computeMatches(opt,N_VALUES,device,dbFeat_np=dbFeat_np,qFeat_np=qFeat_np,dMat=dMat)
166 | gt,gtDists = whole_train_set.get_positives(retDists=True)
167 | recall_at_n = getRecallAtN(N_VALUES, predictions, gt)
168 | wandb.log({"loss_e":avg_loss},commit=False)
169 | for i,n in enumerate(N_VALUES):
170 | writer.add_scalar('Train/Recall@' + str(n), recall_at_n[i], epoch)
171 | wandb.log({'Train/Recall@' + str(n): recall_at_n[i]},commit=False)
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import faiss
4 | from os.path import join
5 | import shutil
6 |
7 | import seqMatchNet
8 |
9 | N_VALUES = [1,5,10,20,100]
10 |
11 | def batch2Seq(input,l):
12 | inSh = input.shape
13 | input = input.view(inSh[0]//l,l,inSh[1])
14 | return input
15 |
16 | def seq2Batch(input):
17 | inSh = input.shape
18 | input = input.view(inSh[0]*inSh[1],inSh[2],inSh[3],inSh[4])
19 | return input
20 |
21 | def save_checkpoint(savePath, state, is_best, filename='checkpoint.pth.tar'):
22 | model_out_path = join(savePath, filename)
23 | torch.save(state, model_out_path)
24 | if is_best:
25 | shutil.copyfile(model_out_path, join(savePath, 'model_best.pth.tar'))
26 |
27 | def getRecallAtN(n_values, predictions, gt):
28 | correct_at_n = np.zeros(len(n_values))
29 | numQWithoutGt = 0
30 | #TODO can we do this on the matrix in one go?
31 | for qIx, pred in enumerate(predictions):
32 | if len(gt[qIx]) == 0:
33 | numQWithoutGt += 1
34 | continue
35 | for i,n in enumerate(n_values):
36 | # if in top N then also in top NN, where NN > N
37 | if np.any(np.in1d(pred[:n], gt[qIx])):
38 | correct_at_n[i:] += 1
39 | break
40 | # print("Num Q without GT: ", numQWithoutGt, " of ", len(gt))
41 | return correct_at_n / (len(gt)-numQWithoutGt)
42 |
43 | def computeMatches(opt,n_values,device,dbFeat=None,qFeat=None,dbFeat_np=None,qFeat_np=None,dMat=None):
44 |
45 | if opt.matcher is not None:
46 | if dMat is None:
47 | if opt.predictionsFile is not None:
48 | predPrior = np.load(opt.predictionsFile)['preds']
49 | predPriorTopK = predPrior[:,:20]
50 | else:
51 | predPriorTopK = None
52 | outSeqL = opt.seqL
53 | dMat = 1.0/outSeqL * seqMatchNet.aggregateMatchScores(None,outSeqL,device, dbDesc=dbFeat, qDesc=qFeat,refCandidates=predPriorTopK)[0]
54 | print(dMat.shape)
55 | predictions = np.argsort(dMat,axis=0)[:max(n_values),:].transpose()
56 | bestDists = dMat[predictions[:,0],np.arange(dMat.shape[1])]
57 | if opt.predictionsFile is not None:
58 | predictions = np.array([predPriorTopK[qIdx][predictions[qIdx]] for qIdx in range(predictions.shape[0])])
59 | print("Preds:",predictions.shape)
60 |
61 | # single image descriptors
62 | else:
63 | assert(opt.seqL==1)
64 | print('====> Building faiss index')
65 | faiss_index = faiss.IndexFlatL2(dbFeat_np.shape[-1])
66 | faiss_index.add(np.squeeze(dbFeat_np))
67 |
68 | distances, predictions = faiss_index.search(np.squeeze(qFeat_np), max(n_values))
69 | bestDists = distances[:,0]
70 | return predictions, bestDists
71 |
72 | def evaluate(n_values,predictions,gtDistsMat=None):
73 | print('====> Calculating recall @ N')
74 | # compute recall for different loc radii
75 | rAtL = []
76 | for locRad in [1,5,10,20,40,100,200]:
77 | gtAtL = gtDistsMat <= locRad
78 | gtAtL = [np.argwhere(gtAtL[:,qIx]).flatten() for qIx in range(gtDistsMat.shape[1])]
79 | rAtL.append(getRecallAtN(n_values, predictions, gtAtL))
80 | return rAtL
--------------------------------------------------------------------------------