├── .gitignore ├── .vscode └── settings.json ├── LICENSE ├── README.md ├── cirtorch └── functional.py ├── datasets └── pitts.py ├── download.sh ├── ece.jpg ├── eval_batch.sh ├── eval_ece_sh.py ├── main.py ├── networks └── res50gem.py ├── options.py ├── trainer.py ├── utils.py └── vis_results.py /.gitignore: -------------------------------------------------------------------------------- 1 | pittsburgh/ 2 | tmp/ 3 | __pycache__/ 4 | logs/ 5 | *.zip 6 | dropbox_shared/ 7 | wandb/ 8 | .vscode/ -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "git.ignoreLimitWarning": true 3 | } -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2023, Kaiwen Cai 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | 8 | 1. Redistributions of source code must retain the above copyright notice, this 9 | list of conditions and the following disclaimer. 10 | 11 | 2. Redistributions in binary form must reproduce the above copyright notice, 12 | this list of conditions and the following disclaimer in the documentation 13 | and/or other materials provided with the distribution. 14 | 15 | 3. Neither the name of the copyright holder nor the names of its 16 | contributors may be used to endorse or promote products derived from 17 | this software without specific prior written permission. 18 | 19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ### [IROS2022] STUN: Self-Teaching Uncertainty Estimation for Place Recognition 2 | 3 | #### 0. Environment Setup ⚙️ 4 | - Ubuntu 18.04, Python 3.8, A100 5 | - PyTorch 1.8.1 + CUDA 11.1 6 | 7 | #### 1. Download the Pittsburgh dataset and pretrained models 📨 8 | ```shell 9 | bash ./download.sh 10 | ``` 11 | 12 | The files will be downloaded and saved in the following folders: 13 | 14 | ```shell 15 | pittsburgh 16 | ├── database 17 | ├── query 18 | └── structure 19 | 20 | logs 21 | ├── student_contrast 22 | ├── student_quadruplet 23 | ├── student_triplet 24 | └── teacher_triplet 25 | ``` 26 | 27 | #### 2. Evaluate the pretrained models 🔍 28 | 29 | ```shell 30 | # STUN 31 | python main.py --resume=logs/student_triplet/ckpt.pth.tar 32 | 33 | # STUN (Contrast) 34 | python main.py --resume=logs/student_contrast/ckpt.pth.tar 35 | 36 | # STUN (Quadruplet) 37 | python main.py --resume=logs/student_quadruplet/ckpt.pth.tar 38 | 39 | # Standard Triplet 40 | python main.py --phase=test_tea --resume=logs/teacher_triplet/ckpt.pth.tar 41 | 42 | ``` 43 | 44 | #### 3. Plot results 📈 45 | 46 | ```shell 47 | python vis_results.py 48 | # you can plot results of different models by populating the NETWORK variable. 49 | ``` 50 | 51 | #### 4. 
Train and evaluate STUN from scratch 🧭 52 | 53 | ```shell 54 | # train the teacher net 55 | python main.py --phase=train_tea --loss=tri 56 | 57 | # train the student net supervised by the pretrained teacher net 58 | python main.py --phase=train_stu --resume=[teacher_net_xxx/ckpt_best.pth.tar] 59 | 60 | ``` 61 | After analyzing empirical figures, we found that the correlation between recall@N and uncertainty level evolves into a sensible trend after 30 epochs. However, ECE (Expected Calibration Error) diverges if the student network is trained for too long. As a result, we examined the model's performance from epoch=30 to epoch=35 and chose the checkpoint with the lowest ECE. 62 | ```shell 63 | # evaluate 64 | ./eval_batch.sh 65 | ``` 66 | ![ece.png](ece.jpg) 67 | 68 | If you find our work useful, please consider citing: 69 | ``` 70 | @INPROCEEDINGS{stun_cai, 71 | author={Cai, Kaiwen and Lu, Chris Xiaoxuan and Huang, Xiaowei}, 72 | booktitle={2022 IEEE/RSJ International Conference on Intelligent Robots and Systems (IROS)}, 73 | title={STUN: Self-Teaching Uncertainty Estimation for Place Recognition}, 74 | year={2022}, 75 | volume={}, 76 | number={}, 77 | pages={6614-6621}, 78 | doi={10.1109/IROS47612.2022.9981546}} 79 | ``` 80 | 81 | 92 | -------------------------------------------------------------------------------- /cirtorch/functional.py: -------------------------------------------------------------------------------- 1 | import math 2 | import pdb 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | # -------------------------------------- 8 | # pooling 9 | # -------------------------------------- 10 | 11 | 12 | def mac(x): 13 | return F.max_pool2d(x, (x.size(-2), x.size(-1))) 14 | # return F.adaptive_max_pool2d(x, (1,1)) # alternative 15 | 16 | 17 | def spoc(x): 18 | return F.avg_pool2d(x, (x.size(-2), x.size(-1))) 19 | # return F.adaptive_avg_pool2d(x, (1,1)) # alternative 20 | 21 | 22 | def gem(x, p=3, eps=1e-6): 23 | return
F.avg_pool2d(x.clamp(min=eps).pow(p), (x.size(-2), x.size(-1))).pow(1. / p) 24 | # return F.lp_pool2d(F.threshold(x, eps, eps), p, (x.size(-2), x.size(-1))) # alternative 25 | 26 | 27 | def rmac(x, L=3, eps=1e-6): 28 | ovr = 0.4 # desired overlap of neighboring regions 29 | steps = torch.Tensor([2, 3, 4, 5, 6, 7]) # possible regions for the long dimension 30 | 31 | W = x.size(3) 32 | H = x.size(2) 33 | 34 | w = min(W, H) 35 | w2 = math.floor(w / 2.0 - 1) 36 | 37 | b = (max(H, W) - w) / (steps - 1) 38 | (tmp, idx) = torch.min(torch.abs(((w**2 - w * b) / w**2) - ovr), 0) # steps(idx) regions for long dimension 39 | 40 | # region overplus per dimension 41 | Wd = 0 42 | Hd = 0 43 | if H < W: 44 | Wd = idx.item() + 1 45 | elif H > W: 46 | Hd = idx.item() + 1 47 | 48 | v = F.max_pool2d(x, (x.size(-2), x.size(-1))) 49 | v = v / (torch.norm(v, p=2, dim=1, keepdim=True) + eps).expand_as(v) 50 | 51 | for l in range(1, L + 1): 52 | wl = math.floor(2 * w / (l + 1)) 53 | wl2 = math.floor(wl / 2 - 1) 54 | 55 | if l + Wd == 1: 56 | b = 0 57 | else: 58 | b = (W - wl) / (l + Wd - 1) 59 | cenW = torch.floor(wl2 + torch.Tensor(range(l - 1 + Wd + 1)) * b) - wl2 # center coordinates 60 | if l + Hd == 1: 61 | b = 0 62 | else: 63 | b = (H - wl) / (l + Hd - 1) 64 | cenH = torch.floor(wl2 + torch.Tensor(range(l - 1 + Hd + 1)) * b) - wl2 # center coordinates 65 | 66 | for i_ in cenH.tolist(): 67 | for j_ in cenW.tolist(): 68 | if wl == 0: 69 | continue 70 | R = x[:, :, (int(i_) + torch.Tensor(range(wl)).long()).tolist(), :] 71 | R = R[:, :, :, (int(j_) + torch.Tensor(range(wl)).long()).tolist()] 72 | vt = F.max_pool2d(R, (R.size(-2), R.size(-1))) 73 | vt = vt / (torch.norm(vt, p=2, dim=1, keepdim=True) + eps).expand_as(vt) 74 | v += vt 75 | 76 | return v 77 | 78 | 79 | def roipool(x, rpool, L=3, eps=1e-6): 80 | ovr = 0.4 # desired overlap of neighboring regions 81 | steps = torch.Tensor([2, 3, 4, 5, 6, 7]) # possible regions for the long dimension 82 | 83 | W = x.size(3) 84 | H = x.size(2) 
85 | 86 | w = min(W, H) 87 | w2 = math.floor(w / 2.0 - 1) 88 | 89 | b = (max(H, W) - w) / (steps - 1) 90 | _, idx = torch.min(torch.abs(((w**2 - w * b) / w**2) - ovr), 0) # steps(idx) regions for long dimension 91 | 92 | # region overplus per dimension 93 | Wd = 0 94 | Hd = 0 95 | if H < W: 96 | Wd = idx.item() + 1 97 | elif H > W: 98 | Hd = idx.item() + 1 99 | 100 | vecs = [] 101 | vecs.append(rpool(x).unsqueeze(1)) 102 | 103 | for l in range(1, L + 1): 104 | wl = math.floor(2 * w / (l + 1)) 105 | wl2 = math.floor(wl / 2 - 1) 106 | 107 | if l + Wd == 1: 108 | b = 0 109 | else: 110 | b = (W - wl) / (l + Wd - 1) 111 | cenW = torch.floor(wl2 + torch.Tensor(range(l - 1 + Wd + 1)) * b).int() - wl2 # center coordinates 112 | if l + Hd == 1: 113 | b = 0 114 | else: 115 | b = (H - wl) / (l + Hd - 1) 116 | cenH = torch.floor(wl2 + torch.Tensor(range(l - 1 + Hd + 1)) * b).int() - wl2 # center coordinates 117 | 118 | for i_ in cenH.tolist(): 119 | for j_ in cenW.tolist(): 120 | if wl == 0: 121 | continue 122 | vecs.append(rpool(x.narrow(2, i_, wl).narrow(3, j_, wl)).unsqueeze(1)) 123 | 124 | return torch.cat(vecs, dim=1) 125 | 126 | 127 | # -------------------------------------- 128 | # normalization 129 | # -------------------------------------- 130 | 131 | 132 | def l2n(x, eps=1e-6): 133 | return x / (torch.norm(x, p=2, dim=1, keepdim=True) + eps).expand_as(x) 134 | 135 | 136 | def powerlaw(x, eps=1e-6): 137 | x = x + eps # use the eps argument; a module-level function has no self 138 | return x.abs().sqrt().mul(x.sign()) 139 | 140 | 141 | # -------------------------------------- 142 | # loss 143 | # -------------------------------------- 144 | 145 | 146 | def contrastive_loss(x, label, margin=0.7, eps=1e-6): 147 | # x is D x N 148 | dim = x.size(0) # D 149 | nq = torch.sum(label.data == -1).item() # number of tuples 150 | S = x.size(1) // nq # number of images per tuple including query: 1+1+n 151 | 152 | x1 = x[:, ::S].permute(1, 0).repeat(1, S - 1).view((S - 1) * nq, dim).permute(1, 0) 153 | idx = [i for i in range(len(label)) if
label.data[i] != -1] 154 | x2 = x[:, idx] 155 | lbl = label[label != -1] 156 | 157 | dif = x1 - x2 158 | D = torch.pow(dif + eps, 2).sum(dim=0).sqrt() 159 | 160 | y = 0.5 * lbl * torch.pow(D, 2) + 0.5 * (1 - lbl) * torch.pow(torch.clamp(margin - D, min=0), 2) 161 | y = torch.sum(y) 162 | return y 163 | 164 | 165 | def triplet_loss(x, label, margin=0.1): 166 | # x is D x N 167 | dim = x.size(0) # D 168 | nq = torch.sum(label.data == -1).item() # number of tuples 169 | S = x.size(1) // nq # number of images per tuple including query: 1+1+n 170 | 171 | xa = x[:, label.data == -1].permute(1, 0).repeat(1, S - 2).view((S - 2) * nq, dim).permute(1, 0) 172 | xp = x[:, label.data == 1].permute(1, 0).repeat(1, S - 2).view((S - 2) * nq, dim).permute(1, 0) 173 | xn = x[:, label.data == 0] 174 | 175 | dist_pos = torch.sum(torch.pow(xa - xp, 2), dim=0) 176 | dist_neg = torch.sum(torch.pow(xa - xn, 2), dim=0) 177 | 178 | return torch.sum(torch.clamp(dist_pos - dist_neg + margin, min=0)) -------------------------------------------------------------------------------- /datasets/pitts.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | from os.path import join 3 | 4 | import h5py 5 | import numpy as np 6 | import torch 7 | import torch.utils.data as data 8 | import torchvision.transforms as transforms 9 | from PIL import Image 10 | from scipy.io import loadmat 11 | from sklearn.neighbors import NearestNeighbors 12 | from torchvision.transforms import InterpolationMode 13 | 14 | 15 | def input_transform(opt=None): 16 | return transforms.Compose([ 17 | transforms.ToTensor(), 18 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 19 | transforms.Resize((opt.height, opt.width), interpolation=InterpolationMode.BILINEAR), 20 | ]) 21 | 22 | 23 | def get_whole_training_set(opt, onlyDB=False, forCluster=False, return_labels=False): 24 | return WholeDatasetFromStruct(opt, join(opt.structDir, 
'pitts30k_train.mat'), opt.imgDir, input_transform=input_transform(opt), onlyDB=onlyDB, forCluster=forCluster, return_labels=return_labels) 25 | 26 | 27 | def get_whole_val_set(opt, return_labels=False): 28 | return WholeDatasetFromStruct(opt, join(opt.structDir, 'pitts30k_val.mat'), opt.imgDir, input_transform=input_transform(opt), return_labels=return_labels) 29 | 30 | 31 | def get_whole_test_set(opt, return_labels=False): 32 | return WholeDatasetFromStruct(opt, join(opt.structDir, 'pitts30k_test.mat'), opt.imgDir, input_transform=input_transform(opt), return_labels=return_labels) 33 | 34 | 35 | def get_training_query_set(opt, margin=0.1): 36 | return QueryDatasetFromStruct(opt, join(opt.structDir, 'pitts30k_train.mat'), opt.imgDir, input_transform=input_transform(opt), margin=margin) 37 | 38 | 39 | def get_val_query_set(opt, margin=0.1): 40 | return QueryDatasetFromStruct(opt, join(opt.structDir, 'pitts30k_val.mat'), opt.imgDir, input_transform=input_transform(opt), margin=margin) 41 | 42 | 43 | def get_quad_set(opt, margin, margin2): 44 | return QuadrupletDataset(opt, join(opt.structDir, 'pitts30k_train.mat'), opt.imgDir, input_transform=input_transform(opt), margin=margin, margin2=margin2) 45 | 46 | 47 | dbStruct = namedtuple('dbStruct', ['whichSet', 'dataset', 'dbImage', 'utmDb', 'qImage', 'utmQ', 'numDb', 'numQ', 'posDistThr', 'posDistSqThr', 'nonTrivPosDistSqThr']) 48 | 49 | 50 | def parse_dbStruct(path): 51 | mat = loadmat(path) 52 | matStruct = mat['dbStruct'].item() 53 | 54 | dataset = 'nuscenes' 55 | 56 | whichSet = matStruct[0].item() 57 | 58 | # .mat file is generated by python, Kaiwen replaces the use of cell (in Matlab) with char (in Python) 59 | dbImage = [f[0].item() for f in matStruct[1]] 60 | # dbImage = matStruct[1] 61 | utmDb = matStruct[2].T 62 | # utmDb = matStruct[2] 63 | 64 | # .mat file is generated by python, I replace the use of cell (in Matlab) with char (in Python) 65 | qImage = [f[0].item() for f in matStruct[3]] 66 | # qImage = 
matStruct[3] 67 | utmQ = matStruct[4].T 68 | # utmQ = matStruct[4] 69 | 70 | numDb = matStruct[5].item() 71 | numQ = matStruct[6].item() 72 | 73 | posDistThr = matStruct[7].item() 74 | posDistSqThr = matStruct[8].item() 75 | nonTrivPosDistSqThr = matStruct[9].item() 76 | 77 | return dbStruct(whichSet, dataset, dbImage, utmDb, qImage, utmQ, numDb, numQ, posDistThr, posDistSqThr, nonTrivPosDistSqThr) 78 | 79 | 80 | class WholeDatasetFromStructForCluster(data.Dataset): 81 | def __init__(self, opt, structFile, img_dir, input_transform=None, onlyDB=False): 82 | super().__init__() 83 | 84 | self.input_transform = input_transform 85 | 86 | self.dbStruct = parse_dbStruct(structFile) 87 | 88 | self.images = [join(img_dir, 'database', dbIm) for dbIm in self.dbStruct.dbImage] 89 | if not onlyDB: 90 | self.images += [join(img_dir, 'query', qIm) for qIm in self.dbStruct.qImage] 91 | 92 | self.whichSet = self.dbStruct.whichSet 93 | self.dataset = self.dbStruct.dataset 94 | 95 | self.positives = None 96 | self.distances = None 97 | 98 | def __getitem__(self, index): 99 | img = Image.open(self.images[index]) 100 | if self.input_transform: 101 | img = self.input_transform(img) 102 | 103 | return img, index 104 | 105 | def __len__(self): 106 | return len(self.images) 107 | 108 | def getPositives(self): 109 | # positives for evaluation are those within trivial threshold range 110 | # fit NN to find them, search by radius 111 | if self.positives is None: 112 | knn = NearestNeighbors(n_jobs=-1) 113 | knn.fit(self.dbStruct.utmDb) 114 | self.distances, self.positives = knn.radius_neighbors(self.dbStruct.utmQ, radius=self.dbStruct.nonTrivPosDistSqThr**0.5) # TODO: sort!! 
115 | 116 | return self.positives 117 | 118 | 119 | class WholeDatasetFromStruct(data.Dataset): 120 | def __init__(self, opt, structFile, img_dir, input_transform=None, onlyDB=False, forCluster=False, return_labels=False): 121 | super().__init__() 122 | self.opt = opt 123 | self.forCluster = forCluster 124 | self.return_labels = return_labels 125 | 126 | self.input_transform = input_transform 127 | 128 | self.dbStruct = parse_dbStruct(structFile) 129 | 130 | self.images = [join(img_dir, 'database', dbIm) for dbIm in self.dbStruct.dbImage] 131 | if not onlyDB: 132 | self.images += [join(img_dir, 'query', qIm) for qIm in self.dbStruct.qImage] 133 | 134 | self.whichSet = self.dbStruct.whichSet 135 | self.dataset = self.dbStruct.dataset 136 | 137 | self.positives = None 138 | self.distances = None 139 | 140 | def load_images(self, index): 141 | filename = self.images[index] 142 | # imgs = [] 143 | img = Image.open(filename) 144 | if self.input_transform: 145 | img = self.input_transform(img) 146 | # imgs.append(img) 147 | # imgs = torch.stack(imgs, 0) 148 | 149 | return img, index 150 | 151 | def __getitem__(self, index): 152 | if self.forCluster: 153 | img = Image.open(self.images[index]) 154 | if self.input_transform: 155 | img = self.input_transform(img) 156 | 157 | return img, index 158 | else: 159 | if self.return_labels: 160 | imgs, index = self.load_images(index) 161 | return imgs, index, self.dbStruct.utmQ[index] 162 | else: 163 | imgs, index = self.load_images(index) 164 | return imgs, index 165 | 166 | def __len__(self): 167 | return len(self.images) 168 | 169 | def get_databases(self): 170 | return self.dbStruct.utmDb 171 | 172 | def get_queries(self): 173 | return self.dbStruct.utmQ 174 | 175 | def get_positives(self): 176 | # positives for evaluation are those within trivial threshold range 177 | # fit NN to find them, search by radius 178 | if self.positives is None: 179 | knn = NearestNeighbors(n_jobs=-1) 180 | knn.fit(self.dbStruct.utmDb) 181 | 
self.distances, self.positives = knn.radius_neighbors(self.dbStruct.utmQ, radius=self.dbStruct.nonTrivPosDistSqThr**0.5) # TODO: sort!! 182 | 183 | return self.positives 184 | 185 | 186 | def collate_fn(batch): 187 | """Creates mini-batch tensors from the list of tuples (query, positive, negatives). 188 | 189 | Args: 190 | data: list of tuple (query, positive, negatives). 191 | - query: torch tensor of shape (3, h, w). 192 | - positive: torch tensor of shape (3, h, w). 193 | - negative: torch tensor of shape (n, 3, h, w). 194 | Returns: 195 | query: torch tensor of shape (batch_size, 3, h, w). 196 | positive: torch tensor of shape (batch_size, 3, h, w). 197 | negatives: torch tensor of shape (batch_size, n, 3, h, w). 198 | """ 199 | 200 | batch = list(filter(lambda x: x is not None, batch)) 201 | if len(batch) == 0: 202 | return None, None, None, None, None 203 | 204 | query, positive, negatives, indices = zip(*batch) 205 | 206 | query = data.dataloader.default_collate(query) # ([8, 3, 200, 200]) = [(3, 200, 200), (3, 200, 200), .. 
] ([8, 1, 3, 200, 200]) 207 | positive = data.dataloader.default_collate(positive) 208 | negCounts = data.dataloader.default_collate([x.shape[0] for x in negatives]) 209 | negatives = torch.cat(negatives, 0) # ([80, 3, 200, 200]) ([80, 1, 3, 200, 200]) 210 | import itertools 211 | indices = list(itertools.chain(*indices)) 212 | 213 | return query, positive, negatives, negCounts, indices 214 | 215 | 216 | class QueryDatasetFromStruct(data.Dataset): 217 | def __init__(self, opt, structFile, img_dir, nNegSample=1000, nNeg=10, margin=0.1, input_transform=None): 218 | super().__init__() 219 | self.opt = opt 220 | self.img_dir = img_dir 221 | self.input_transform = input_transform 222 | self.margin = margin 223 | 224 | self.dbStruct = parse_dbStruct(structFile) 225 | self.whichSet = self.dbStruct.whichSet 226 | self.dataset = self.dbStruct.dataset 227 | self.nNegSample = nNegSample # number of negatives to randomly sample 228 | self.nNeg = nNeg # number of negatives used for training 229 | 230 | # potential positives are those within nontrivial threshold range 231 | # fit NN to find them, search by radius 232 | knn = NearestNeighbors(n_jobs=-1) 233 | knn.fit(self.dbStruct.utmDb) 234 | 235 | # TODO use sqeuclidean as metric? 
236 | self.nontrivial_positives = list(knn.radius_neighbors(self.dbStruct.utmQ, radius=self.dbStruct.nonTrivPosDistSqThr**0.5, return_distance=False)) 237 | # radius returns unsorted, sort once now so we dont have to later 238 | for i, posi in enumerate(self.nontrivial_positives): 239 | self.nontrivial_positives[i] = np.sort(posi) 240 | # its possible some queries don't have any non trivial potential positives 241 | # lets filter those out 242 | self.queries = np.where(np.array([len(x) for x in self.nontrivial_positives]) > 0)[0] 243 | 244 | # potential negatives are those outside of posDistThr range 245 | potential_positives = knn.radius_neighbors(self.dbStruct.utmQ, radius=self.dbStruct.posDistThr, return_distance=False) 246 | 247 | self.potential_negatives = [] 248 | for pos in potential_positives: 249 | self.potential_negatives.append(np.setdiff1d(np.arange(self.dbStruct.numDb), pos, assume_unique=True)) 250 | 251 | self.cache = None # filepath of HDF5 containing feature vectors for images 252 | 253 | self.negCache = [np.empty((0, ), dtype=np.int64) for _ in range(self.dbStruct.numQ)] 254 | 255 | def load_images(self, filename): 256 | # imgs = [] 257 | img = Image.open(filename) 258 | if self.input_transform: 259 | img = self.input_transform(img) 260 | # imgs.append(img) 261 | # imgs = torch.stack(imgs, 0) 262 | 263 | return img 264 | 265 | def __getitem__(self, index): 266 | 267 | index = self.queries[index] # re-map index to match dataset 268 | with h5py.File(self.cache, mode='r') as h5: 269 | h5feat = h5.get("features") 270 | qOffset = self.dbStruct.numDb 271 | 272 | qFeat = h5feat[index + qOffset] 273 | posFeat = h5feat[self.nontrivial_positives[index].tolist()] 274 | qFeat = torch.tensor(qFeat) 275 | posFeat = torch.tensor(posFeat) 276 | dist = torch.norm(qFeat - posFeat, dim=1, p=None) 277 | result = dist.topk(1, largest=False) 278 | dPos, posNN = result.values, result.indices 279 | posIndex = self.nontrivial_positives[index][posNN].item() 280 | 281 | 
negSample = np.random.choice(self.potential_negatives[index], self.nNegSample) # randomly choose potential_negatives 282 | negSample = np.unique(np.concatenate([self.negCache[index], negSample])) # remember negSamples history for each query 283 | 284 | negFeat = h5feat[negSample.tolist()] 285 | negFeat = torch.tensor(negFeat) 286 | dist = torch.norm(qFeat - negFeat, dim=1, p=None) 287 | result = dist.topk(self.nNeg * 10, largest=False) 288 | dNeg, negNN = result.values, result.indices 289 | 290 | if self.opt.loss == 'cont': 291 | violatingNeg = dNeg.numpy() < self.margin**0.5 292 | else: 293 | violatingNeg = dNeg.numpy() < dPos.numpy() + self.margin**0.5 294 | 295 | if np.sum(violatingNeg) < 1: 296 | return None 297 | 298 | negNN = negNN.numpy() 299 | negNN = negNN[violatingNeg][:self.nNeg] 300 | negIndices = negSample[negNN].astype(np.int32) 301 | self.negCache[index] = negIndices 302 | 303 | query = self.load_images(join(self.img_dir, 'query', self.dbStruct.qImage[index])) 304 | positive = self.load_images(join(self.img_dir, 'database', self.dbStruct.dbImage[posIndex])) 305 | 306 | negatives = [] 307 | for negIndex in negIndices: 308 | negative = self.load_images(join(self.img_dir, 'database', self.dbStruct.dbImage[negIndex])) 309 | negatives.append(negative) 310 | 311 | negatives = torch.stack(negatives, 0) # ([10, 3, 200, 200]) 312 | return query, positive, negatives, [index, posIndex] + negIndices.tolist() 313 | 314 | def __len__(self): 315 | return len(self.queries) 316 | 317 | 318 | class QuadrupletDataset(data.Dataset): 319 | def __init__(self, opt, structFile, img_dir, nNegSample=1000, nNeg=10, margin=0.1, margin2=0.05, input_transform=None): 320 | super().__init__() 321 | self.opt = opt 322 | self.img_dir = img_dir 323 | self.input_transform = input_transform 324 | self.margin = margin 325 | self.margin2 = margin2 326 | 327 | self.dbStruct = parse_dbStruct(structFile) 328 | self.whichSet = self.dbStruct.whichSet 329 | self.dataset = self.dbStruct.dataset 
330 | self.nNegSample = nNegSample # number of negatives to randomly sample 331 | self.nNeg = nNeg # number of negatives used for training 332 | 333 | # potential positives are those within nontrivial threshold range, fit NN to find them, search by radius 334 | knn = NearestNeighbors(n_jobs=-1) 335 | knn.fit(self.dbStruct.utmDb) 336 | 337 | self.db_potential_positives = knn.radius_neighbors(self.dbStruct.utmDb, radius=self.dbStruct.posDistThr, return_distance=False) # 6312 338 | self.db_potential_negatives = [] 339 | for pos in self.db_potential_positives: 340 | self.db_potential_negatives.append(np.setdiff1d(np.arange(self.dbStruct.numDb), pos, assume_unique=True)) 341 | 342 | # TODO use sqeuclidean as metric? 343 | self.nontrivial_positives = list(knn.radius_neighbors(self.dbStruct.utmQ, radius=self.dbStruct.nonTrivPosDistSqThr**0.5, return_distance=False)) # 7075 344 | # radius returns unsorted, sort once now so we dont have to later 345 | for i, posi in enumerate(self.nontrivial_positives): 346 | self.nontrivial_positives[i] = np.sort(posi) 347 | # its possible some queries don't have any non trivial potential positives, lets filter those out 348 | self.queries = np.where(np.array([len(x) for x in self.nontrivial_positives]) > 0)[0] 349 | 350 | # potential negatives are those outside of posDistThr range 351 | self.potential_positives = knn.radius_neighbors(self.dbStruct.utmQ, radius=self.dbStruct.posDistThr, return_distance=False) 352 | 353 | self.potential_negatives = [] 354 | for pos in self.potential_positives: 355 | self.potential_negatives.append(np.setdiff1d(np.arange(self.dbStruct.numDb), pos, assume_unique=True)) 356 | 357 | self.cache = None # filepath of HDF5 containing feature vectors for images 358 | 359 | self.negCache = [np.empty((0, ), dtype=np.int64) for _ in range(self.dbStruct.numQ)] 360 | 361 | def load_images(self, filename): 362 | img = Image.open(filename) 363 | if self.input_transform: 364 | img = self.input_transform(img) 365 | return 
img 366 | 367 | def __getitem__(self, index): 368 | index = self.queries[index] # re-map index to match dataset 369 | with h5py.File(self.cache, mode='r') as h5: 370 | h5feat = h5.get("features") 371 | qOffset = self.dbStruct.numDb 372 | 373 | qFeat = h5feat[index + qOffset] 374 | tmp = self.nontrivial_positives[index] 375 | tmp = tmp.tolist() 376 | posFeat = h5feat[self.nontrivial_positives[index].tolist()] 377 | qFeat = torch.tensor(qFeat) 378 | posFeat = torch.tensor(posFeat) 379 | dist = torch.norm(qFeat - posFeat, dim=1, p=None) 380 | result = dist.topk(1, largest=False) # choose the closet positive 381 | dPos, posNN = result.values, result.indices 382 | posIndex = self.nontrivial_positives[index][posNN].item() 383 | 384 | negSample = np.random.choice(self.potential_negatives[index], self.nNegSample) # randomly choose potential_negatives 385 | negSample = np.unique(np.concatenate([self.negCache[index], negSample])) # encourage to sample from last negIndices + current last negIndices 386 | 387 | negFeat = h5feat[negSample.tolist()] 388 | negFeat = torch.tensor(negFeat) 389 | dist = torch.norm(qFeat - negFeat, dim=1, p=None) 390 | result = dist.topk(self.nNeg * 10, largest=False) 391 | dNeg, negNN = result.values, result.indices 392 | 393 | # try to find negatives that are within margin, if there aren't any return none 394 | violatingNeg = dNeg.numpy() < dPos.numpy() + self.margin**0.5 395 | 396 | if np.sum(violatingNeg) < 1: 397 | # if none are violating then skip this query 398 | return None 399 | 400 | negNN = negNN.numpy() 401 | negNN = negNN[violatingNeg][:self.nNeg] 402 | negIndices = negSample[negNN].astype(np.int32) 403 | self.negCache[index] = negIndices 404 | 405 | query = self.load_images(join(self.img_dir, 'query', self.dbStruct.qImage[index])) 406 | positive = self.load_images(join(self.img_dir, 'database', self.dbStruct.dbImage[posIndex])) 407 | 408 | negatives = [] 409 | negatives2 = [] 410 | negIndices2 = [] 411 | for negIndex in negIndices: 412 
| anchor_neg_negs = np.random.choice(self.db_potential_negatives[negIndex], 1000, replace=False) 413 | anchor_poss = self.potential_positives[index] 414 | anchor_neg_negs_clean = np.setdiff1d(anchor_neg_negs, anchor_poss, assume_unique=True) 415 | anchor_neg_negs_clean = np.sort(anchor_neg_negs_clean) 416 | with h5py.File(self.cache, mode='r') as h5: 417 | h5feat = h5.get("features") 418 | negFeat = h5feat[anchor_neg_negs_clean.tolist()] 419 | negFeat = torch.tensor(negFeat) 420 | dist = torch.norm(qFeat - negFeat, dim=1, p=None) 421 | result = dist.topk(self.nNeg * 10, largest=False) 422 | dNeg, negNN = result.values, result.indices 423 | violatingNeg = dNeg.numpy() < dPos.numpy() + self.margin2**0.5 # increase negative samples by using **0.5 424 | if np.sum(violatingNeg) < 1: 425 | return None 426 | negNN = negNN.numpy() 427 | negNN = negNN[violatingNeg][:1] 428 | neg2Index = anchor_neg_negs_clean[negNN].astype(np.int32)[0] 429 | 430 | negative = self.load_images(join(self.img_dir, 'database', self.dbStruct.dbImage[negIndex])) 431 | negative2 = self.load_images(join(self.img_dir, 'database', self.dbStruct.dbImage[neg2Index])) 432 | negatives.append(negative) 433 | negatives2.append(negative2) 434 | negIndices2.append(neg2Index) 435 | 436 | negatives = torch.stack(negatives, 0) # ([num_neg, C, H, W]) 437 | negatives2 = torch.stack(negatives2, 0) # ([num_neg, C, H, W]) 438 | return query, positive, negatives, negatives2, [index, posIndex] + negIndices.tolist() + negIndices2 439 | 440 | def __len__(self): 441 | return len(self.queries) 442 | 443 | 444 | def collate_quad_fn(batch): 445 | """Creates mini-batch tensors from the list of tuples (query, positive, negatives). 446 | 447 | Args: 448 | data: list of tuple (query, positive, negatives). 449 | - query: torch tensor of shape (3, h, w). 450 | - positive: torch tensor of shape (3, h, w). 451 | - negative: torch tensor of shape (n, 3, h, w). 452 | - negative2: torch tensor of shape (n, 3, h, w). 
453 | Returns: 454 | query: torch tensor of shape (batch_size, 3, h, w). 455 | positive: torch tensor of shape (batch_size, 3, h, w). 456 | negatives: torch tensor of shape (batch_size, n, 3, h, w). 457 | """ 458 | 459 | batch = list(filter(lambda x: x is not None, batch)) 460 | if len(batch) == 0: 461 | return None, None, None, None, None, None 462 | 463 | query, positive, negatives, negatives2, indices = zip(*batch) 464 | 465 | query = data.dataloader.default_collate(query) # ([8, 3, 200, 200]) = [(3, 200, 200), (3, 200, 200), .. ] ([8, 1, 3, 200, 200]) 466 | positive = data.dataloader.default_collate(positive) 467 | negCounts = data.dataloader.default_collate([x.shape[0] for x in negatives]) 468 | negatives = torch.cat(negatives, 0) # ([80, 3, 200, 200]) ([80, 1, 3, 200, 200]) 469 | negatives2 = torch.cat(negatives2, 0) # ([80, 3, 200, 200]) ([80, 1, 3, 200, 200]) 470 | import itertools 471 | indices = list(itertools.chain(*indices)) 472 | 473 | return query, positive, negatives, negatives2, negCounts, indices -------------------------------------------------------------------------------- /download.sh: -------------------------------------------------------------------------------- 1 | echo "downloading the pittsburgh dataset .." 2 | wget -O pittsburgh.zip "https://www.dropbox.com/s/ynep8wzii1z0r6h/pittsburgh.zip?dl=0" 3 | unzip -q pittsburgh.zip 4 | 5 | echo "downloading the pretrained models .." 
6 | wget -O logs.zip "https://www.dropbox.com/s/bjph9hkhzdrevk5/logs.zip?dl=0" 7 | unzip -q logs.zip 8 | 9 | -------------------------------------------------------------------------------- /ece.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ramdrop/stun/bda3537fcd3562c8f9f3e9a8e104789d960b48bd/ece.jpg -------------------------------------------------------------------------------- /eval_batch.sh: -------------------------------------------------------------------------------- 1 | 2 | # for((i=0;i<60;i=i+5)); 3 | for((i=31;i<35;i=i+1)); 4 | do 5 | resume=logs/tri_train_stu_0312_174824/ckpt_e_$i.pth.tar 6 | echo "running ${i}.." 7 | python main.py --phase='test_stu' --split='val' --resume=${resume} 8 | python main.py --phase='test_stu' --split='test' --resume=${resume} 9 | python eval_ece_sh.py --split='val' --epoch=${i} --resume=${resume} --network='res50' 10 | python eval_ece_sh.py --split='test' --epoch=${i} --resume=${resume} --network='res50' 11 | done 12 | 13 | 14 | 15 | 16 | 17 | -------------------------------------------------------------------------------- /eval_ece_sh.py: -------------------------------------------------------------------------------- 1 | #%% 2 | from scipy import stats 3 | import pickle 4 | import numpy as np 5 | from matplotlib import pyplot as plt 6 | from tqdm import tqdm 7 | from os.path import join, dirname 8 | import utils 9 | import importlib 10 | import argparse 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument("--resume", type=str, default='/LOCAL/ramdrop/dataset/mmrec_dataset/7n5s_xy11') 14 | parser.add_argument("--split", type=str, default='test', choices=['test', 'val']) 15 | parser.add_argument("--network", type=str, default='res50') 16 | parser.add_argument("--epoch", type=int) 17 | args = parser.parse_args() 18 | 19 | importlib.reload(utils) 20 | # ------------------------------------- - ------------------------------------ # 21 | DATASET =
'pitts' 22 | NETWORK = args.network # 'res50' 23 | LOSS = 'tri' # |'cont'|'tri'|'quad'| 24 | 25 | LOG_OR_LINEAR = 'linear' # |'linear'|'log'| 26 | STD_OR_SQ = 'sq' # |'std'|'sq'| 27 | HMEAN_OR_MEAN = 'mean' # |'hmean'|'mean'| 28 | 29 | NUM_BINS = 11 30 | SHOW_AP = False 31 | # ------------------------------------- - ------------------------------------ # 32 | 33 | exp = '{}_{}_{}_{}_{}'.format(DATASET, NETWORK, LOG_OR_LINEAR, STD_OR_SQ, HMEAN_OR_MEAN) 34 | resume = args.resume 35 | print(resume) 36 | with open(join(dirname(resume), '{}_db_embeddings_{}.pickle'.format(args.split, resume.split('.')[-3].split('_')[-1])), 'rb') as handle: 37 | q_mu = pickle.load(handle) 38 | db_mu = pickle.load(handle) 39 | q_sigma_sq = pickle.load(handle) 40 | db_sigma_sq = pickle.load(handle) 41 | preds = pickle.load(handle) 42 | dists = pickle.load(handle) 43 | gt = pickle.load(handle) 44 | _ = pickle.load(handle) 45 | _ = pickle.load(handle) 46 | 47 | #%% 48 | # CALCULATE ECE ====================== # 49 | q_sigma_sq_h = utils.reduce_sigma(q_sigma_sq, STD_OR_SQ, LOG_OR_LINEAR, HMEAN_OR_MEAN) 50 | indices, _, k = utils.get_zoomed_bins(q_sigma_sq_h, NUM_BINS) 51 | 52 | bins_recall = np.zeros((NUM_BINS-1, 3)) 53 | bins_map = np.zeros((NUM_BINS-1, 3)) 54 | bins_ap = np.zeros((NUM_BINS - 1)) 55 | 56 | ece_bins_recall = np.zeros((NUM_BINS - 1, 3)) 57 | ece_bins_map = np.zeros((NUM_BINS - 1, 3)) 58 | ece_bins_ap = np.zeros((NUM_BINS - 1)) 59 | 60 | n_values = [1, 5, 10] 61 | for index in tqdm(range(NUM_BINS - 1)): 62 | if len(indices[index]) == 0: 63 | continue 64 | 65 | pred_bin = preds[indices[index]] 66 | dist_bin = dists[indices[index]] 67 | gt_bin = gt[indices[index]] 68 | 69 | if SHOW_AP: 70 | # calculate AP 71 | recalls, precisions = utils.bin_pr(pred_bin, dist_bin, gt_bin) 72 | ap = 0 73 | for index_j in range(len(recalls) - 1): 74 | ap += precisions[index_j] * (recalls[index_j + 1] - recalls[index_j]) 75 | bins_ap[index] = ap 76 | ece_bins_ap[index] = len(indices[index]) / 
q_sigma_sq_h.shape[0] * np.abs(ap - (NUM_BINS - 1 - index) / ((NUM_BINS - 1))) 77 | # ece_bins_ap[index] = np.abs(ap - (10 - index) * 0.1) 78 | 79 | # calculate r@N 80 | recall_at_n = utils.cal_recall(pred_bin, gt_bin, n_values) 81 | bins_recall[index] = recall_at_n 82 | ece_bins_recall[index] = np.array([len(indices[index]) / q_sigma_sq_h.shape[0] * np.abs(recall_at_n[i] / 100.0 - (NUM_BINS - 1 - index) / ((NUM_BINS - 1))) for i in range(len(n_values))]) 83 | # ece_bins_recall[index] = np.array([np.abs(recall_at_n[i] / 100.0 - (10 - index) * 0.1) for i in range(len(n_values))]) 84 | 85 | # calculate mAP@N 86 | map_n = [utils.cal_mapk(pred_bin, gt_bin, n) for n in n_values] 87 | bins_map[index] = map_n 88 | ece_bins_map[index] = np.array([len(indices[index]) / q_sigma_sq_h.shape[0] * np.abs(map_n[i] / 100.0 - (NUM_BINS - 1 - index) / ((NUM_BINS - 1))) for i in range(len(n_values))]) 89 | # ece_bins_map[index] = np.array([np.abs(map_n[i] / 100.0 - (10 - index) * 0.1) for i in range(len(n_values))]) 90 | 91 | 92 | # PRINT SUMMARY ====================== # 93 | # print('ECE_rec@1/5/10: {:.3f}/{:.3f}/{:.3f}'.format(ece_bins_recall.sum(axis=0)[0], ece_bins_recall.sum(axis=0)[1], ece_bins_recall.sum(axis=0)[2])) 94 | # print('ECE_mAP@1/5/10: {:.3f}/{:.3f}/{:.3f}'.format(ece_bins_map.sum(axis=0)[0], ece_bins_map.sum(axis=0)[1], ece_bins_map.sum(axis=0)[2])) 95 | # print('ECE_AP: {:.3f}'.format(ece_bins_ap.sum())) 96 | 97 | #%% 98 | # RECOGNITION METRIC ================= # 99 | recall = utils.cal_recall(preds, gt, n_values) / 100.0 100 | # print('rec@1/5/10: {:.3f}/{:.3f}/{:.3f}'.format(recall[0], recall[1], recall[2])) 101 | map = [utils.cal_mapk(preds, gt, n) / 100.0 for n in n_values] 102 | # print('mAP@1/5/10: {:.3f}/{:.3f}/{:.3f}'.format(map[0], map[1], map[2])) 103 | 104 | if SHOW_AP: 105 | recalls, precisions = utils.bin_pr(preds, dists, gt) 106 | ap = 0 107 | for index_j in range(len(recalls) - 1): 108 | ap += precisions[index_j] * (recalls[index_j + 1] - 
recalls[index_j]) 109 | # print('AP: {:.3f}'.format(ap)) 110 | 111 | #%% 112 | # VISUALIZATION ====================== # 113 | w = np.array([len(indices[index]) / q_sigma_sq_h.shape[0] for index in range(NUM_BINS - 1)]) 114 | x = np.arange(0, NUM_BINS - 1, 1) 115 | 116 | plt.style.use('ggplot') 117 | fig, axs = plt.subplots(2, 2, figsize=(10, 10), squeeze=False) 118 | fig.suptitle('k={}'.format(k)) 119 | 120 | ax = axs[0][0] 121 | ax.bar(np.arange(len(indices)), [len(x) for x in indices]) 122 | ax.set_xlabel('sigma^2\n(uncertainty: low -> high)') 123 | ax.set_ylabel('num of samples') 124 | 125 | ax = axs[0][1] 126 | ax.plot(np.arange(NUM_BINS - 1), bins_recall[:, 0], marker='o') 127 | ax.plot(np.arange(NUM_BINS - 1), utils.linear_fit(x, bins_recall[:, 0], w), marker='', alpha=0.2, c='black') 128 | ax.plot(np.arange(NUM_BINS - 1), bins_recall[:, 1], marker='o') 129 | ax.plot(np.arange(NUM_BINS - 1), utils.linear_fit(x, bins_recall[:, 1], w), marker='', alpha=0.2, c='black') 130 | ax.plot(np.arange(NUM_BINS - 1), bins_recall[:, 2], marker='o') 131 | ax.plot(np.arange(NUM_BINS - 1), utils.linear_fit(x, bins_recall[:, 2], w), marker='', alpha=0.2, c='black') 132 | 133 | ax.set_xlabel('sigma^2\n(uncertainty: low -> high)') 134 | ax.set_ylabel('recall@n') 135 | 136 | ax = axs[1][0] 137 | ax.plot(np.arange(NUM_BINS - 1), bins_map[:, 0], marker='o') 138 | ax.plot(np.arange(NUM_BINS - 1), utils.linear_fit(x, bins_map[:, 0], w), marker='', alpha=0.2, c='black') 139 | ax.plot(np.arange(NUM_BINS - 1), bins_map[:, 1], marker='o') 140 | ax.plot(np.arange(NUM_BINS - 1), utils.linear_fit(x, bins_map[:, 1], w), marker='', alpha=0.2, c='black') 141 | ax.plot(np.arange(NUM_BINS - 1), bins_map[:, 2], marker='o') 142 | ax.plot(np.arange(NUM_BINS - 1), utils.linear_fit(x, bins_map[:, 2], w), marker='', alpha=0.2, c='black') 143 | 144 | ax.set_xlabel('sigma^2\n(uncertainty: low -> high)') 145 | ax.set_ylabel('mAP@n') 146 | 147 | if SHOW_AP: 148 | ax = axs[1][1] 149 | 
ax.plot(np.arange(NUM_BINS - 1), bins_ap, marker='o') 150 | ax.plot(np.arange(NUM_BINS - 1), utils.linear_fit(x, bins_ap, w), marker='', alpha=0.2, c='black') 151 | ax.set_xlabel('sigma^2\n(uncertainty: low -> high)') 152 | ax.set_ylabel('AP') 153 | 154 | 155 | with open(join(dirname(resume),'{}.log'.format(args.split)), 'a') as f: 156 | arg = 'e:{:>2d} rec@1: {:.3f}'.format(args.epoch, recall[0]) 157 | arg += ' ECE_rec@1/5/10: {:.3f}/{:.3f}/{:.3f}'.format(ece_bins_recall.sum(axis=0)[0], ece_bins_recall.sum(axis=0)[1], ece_bins_recall.sum(axis=0)[2]) 158 | arg += ' ECE_mAP@1/5/10: {:.3f}/{:.3f}/{:.3f}\n'.format(ece_bins_map.sum(axis=0)[0], ece_bins_map.sum(axis=0)[1], ece_bins_map.sum(axis=0)[2]) 159 | f.write(arg) 160 | f.flush() 161 | 162 | plt.savefig(join(dirname(resume), 'ece_{}_{}.jpg'.format(args.split, args.epoch)), dpi=200) 163 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from os.path import dirname, join 2 | 3 | import trainer 4 | from options import Options 5 | 6 | options_handler = Options() 7 | options = options_handler.parse() 8 | 9 | if __name__ == "__main__": 10 | 11 | if options.phase in ['test_tea', 'test_stu', 'train_stu']: 12 | print(f'resume from {options.resume}') 13 | options = options_handler.update_opt_from_json(join(dirname(options.resume), 'flags.json'), options) 14 | tr = trainer.Trainer(options) 15 | print(tr.opt.phase, '-->', tr.opt.runsPath) 16 | elif options.phase in ['train_tea']: 17 | tr = trainer.Trainer(options) 18 | print(tr.opt.phase, '-->', tr.opt.runsPath) 19 | 20 | if options.phase in ['train_tea']: 21 | tr.train() 22 | elif options.phase in ['train_stu']: 23 | tr.train_student() 24 | elif options.phase in ['test_tea', 'test_stu']: 25 | tr.test() -------------------------------------------------------------------------------- /networks/res50gem.py: 
-------------------------------------------------------------------------------- 1 | #%% 2 | import sys 3 | 4 | sys.path.append('..') 5 | from re import L 6 | import torch 7 | import torch.nn as nn 8 | from torch.nn.parameter import Parameter 9 | import cirtorch.functional as LF 10 | import math 11 | import torch.nn.functional as F 12 | 13 | 14 | class L2Norm(nn.Module): 15 | def __init__(self, dim=1): 16 | super().__init__() 17 | self.dim = dim 18 | 19 | def forward(self, input): 20 | return F.normalize(input, p=2, dim=self.dim) 21 | 22 | 23 | class GeM(nn.Module): 24 | def __init__(self, p=3, eps=1e-6): 25 | super(GeM, self).__init__() 26 | self.p = Parameter(torch.ones(1) * p) 27 | self.eps = eps 28 | 29 | def forward(self, x): 30 | return LF.gem(x, p=self.p, eps=self.eps) 31 | 32 | def __repr__(self): 33 | return self.__class__.__name__ + '(' + 'p=' + '{:.4f}'.format(self.p.data.tolist()[0]) + ', ' + 'eps=' + str(self.eps) + ')' 34 | 35 | 36 | class Backbone(nn.Module): 37 | def __init__(self, opt=None): 38 | super().__init__() 39 | 40 | self.sigma_dim = 2048 41 | self.mu_dim = 2048 42 | 43 | resnet50 = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True, verbose=False) 44 | features = list(resnet50.children())[:-2] 45 | # feature map: ([B,3,224,224])->([B,2048,7,7]) 46 | 47 | self.backbone = nn.Sequential(*features, GeM(), nn.Flatten()) 48 | for module in self.backbone.modules(): 49 | if isinstance(module, nn.BatchNorm2d): 50 | if hasattr(module, 'weight'): 51 | module.weight.requires_grad_(False) 52 | if hasattr(module, 'bias'): 53 | module.bias.requires_grad_(False) 54 | 55 | 56 | class TeacherNet(Backbone): 57 | def __init__(self, opt=None): 58 | super().__init__() 59 | self.id = 'teacher' 60 | self.mean_head = nn.Sequential(L2Norm(dim=1)) 61 | 62 | def forward(self, inputs): 63 | B, C, H, W = inputs.shape # (B, 1, 3, 224, 224) 64 | # inputs = inputs.view(B * L, C, H, W) # ([B, 3, 224, 224]) 65 | 66 | backbone_output = self.backbone(inputs) 
# ([B, 2048, 1, 1]) 67 | mu = self.mean_head(backbone_output).view(B, -1) # ([B, 2048]) <= ([B, 2048, 1, 1]) 68 | 69 | return mu, torch.zeros_like(mu) 70 | 71 | 72 | class StudentNet(TeacherNet): 73 | def __init__(self, opt=None): 74 | super().__init__() 75 | self.id = 'student' 76 | self.var_head = nn.Sequential(nn.Linear(2048, self.sigma_dim), nn.Sigmoid()) 77 | 78 | def forward(self, inputs): 79 | B, C, H, W = inputs.shape # ([B, 3, 224, 224]) 80 | inputs = inputs.view(B, C, H, W) # ([B, 3, 224, 224]) 81 | 82 | backbone_output = self.backbone(inputs) # ([B, 2048]) 83 | mu = self.mean_head(backbone_output).view(B, -1) # ([B, 2048]) <= ([B, 2048, 1, 1]) 84 | log_sigma_sq = self.var_head(backbone_output).view(B, -1) # ([B, 2048]) <= ([B, 2048, 1, 1]) 85 | 86 | return mu, log_sigma_sq 87 | 88 | 89 | def deliver_model(opt, id): 90 | if id == 'tea': 91 | return TeacherNet(opt) 92 | elif id == 'stu': 93 | return StudentNet(opt) 94 | 95 | 96 | if __name__ == '__main__': 97 | tea = TeacherNet() 98 | stu = StudentNet() 99 | inputs = torch.rand((1, 3, 224, 224)) 100 | outputs_tea = tea(inputs) 101 | outputs_stu = stu(inputs) 102 | 103 | print(outputs_tea[0].shape, outputs_tea[1].shape) 104 | print(outputs_stu[0].shape, outputs_stu[1].shape) 105 | -------------------------------------------------------------------------------- /options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import random 5 | 6 | import numpy as np 7 | import torch 8 | from genericpath import exists 9 | 10 | 11 | class Options: 12 | def __init__(self): 13 | self.parser = argparse.ArgumentParser(description="Options") 14 | self.parser.add_argument('--phase', type=str, default='test_tea', help='phase', choices=['train_tea', 'test_tea', 'train_stu', 'test_stu']) 15 | self.parser.add_argument('--dataset', type=str, default='pitts', help='choose dataset.') 16 | self.parser.add_argument('--structDir', 
type=str, default='pittsburgh/structure', help='Path for structure.') 17 | self.parser.add_argument('--imgDir', type=str, default='pittsburgh', help='Path for images.') 18 | self.parser.add_argument('--com', type=str, default='', help='comment') 19 | self.parser.add_argument('--height', type=int, default=200, help='Input image height.') 20 | self.parser.add_argument('--width', type=int, default=200, help='Input image width.') 21 | self.parser.add_argument('--net', type=str, default='res50gem', help='network') 22 | self.parser.add_argument('--trainer', type=str, default='trainer', help='trainer') 23 | self.parser.add_argument('--loss', type=str, default='tri', help='Loss function: triplet, contrastive, or quadruplet.', choices=['tri', 'cont', 'quad']) 24 | self.parser.add_argument('--margin', type=float, default=0.1, help='Margin for triplet loss. Default=0.1') 25 | self.parser.add_argument('--margin2', type=float, default=0.1, help='Margin2 for quadruplet loss. Default=0.1') 26 | self.parser.add_argument('--output_dim', type=int, default=0, help='Number of feature dimensions. Default=0 (overwritten from the network output when training).') 27 | self.parser.add_argument('--sigma_dim', type=int, default=0, help='Number of sigma dimensions. Default=0 (overwritten from the network output when training).') 28 | self.parser.add_argument('--batchSize', type=int, default=8, help='Number of triplets (query, pos, negs). Each triplet consists of 12 images.') 29 | self.parser.add_argument('--cacheBatchSize', type=int, default=128, help='Batch size for caching and testing') 30 | self.parser.add_argument('--cacheRefreshRate', type=int, default=0, help='How often to refresh cache, in number of queries. 
0 for off') 31 | self.parser.add_argument('--nEpochs', type=int, default=60, help='number of epochs to train for') 32 | self.parser.add_argument('--nGPU', type=int, default=1, help='Number of GPUs to use.') 33 | self.parser.add_argument('--cGPU', type=int, default=2, help='Index of the GPU to use.') 34 | self.parser.add_argument('--optim', type=str, default='adam', help='optimizer to use', choices=['sgd', 'adam']) 35 | self.parser.add_argument('--lr', type=float, default=1e-5, help='Learning Rate.') 36 | self.parser.add_argument('--lrStep', type=float, default=5, help='Decay LR every N steps.') 37 | self.parser.add_argument('--lrGamma', type=float, default=0.99, help='Multiply LR by Gamma for decaying.') 38 | self.parser.add_argument('--weightDecay', type=float, default=0.001, help='Weight decay for SGD.') 39 | self.parser.add_argument('--momentum', type=float, default=0.9, help='Momentum for SGD.') 40 | self.parser.add_argument('--cuda', action='store_false', help='CUDA is enabled by default; pass this flag to disable it.') 41 | self.parser.add_argument('--d', action='store_true', help='debug mode') 42 | self.parser.add_argument('--threads', type=int, default=8, help='Number of threads for each data loader to use') 43 | self.parser.add_argument('--seed', type=int, default=1234, help='Random seed to use.') 44 | self.parser.add_argument('--logsPath', type=str, default='./logs', help='Path to save runs to.') 45 | self.parser.add_argument('--runsPath', type=str, default='not defined', help='Path to save runs to.') 46 | self.parser.add_argument('--resume', type=str, default='', help='Path to load checkpoint from, for resuming training or testing.') 47 | self.parser.add_argument('--evalEvery', type=int, default=1, help='Do a validation set run, and save, every N epochs.') 48 | self.parser.add_argument('--cacheRefreshEvery', type=int, default=1, help='refresh embedding cache, every N epochs.') 49 | self.parser.add_argument('--patience', type=int, default=10, help='Patience for early stopping. 
0 is off.') 50 | self.parser.add_argument('--split', type=str, default='val', help='Split to use', choices=['val', 'test']) 51 | self.parser.add_argument('--encoder_dim', type=int, default=512, help='Number of feature dimension. Default=512') 52 | 53 | def parse(self): 54 | options = self.parser.parse_args() 55 | return options 56 | 57 | def update_opt_from_json(self, flag_file, options): 58 | if not exists(flag_file): 59 | raise ValueError('{} not exist'.format(flag_file)) 60 | # restore_var = ['runsPath', 'net', 'seqLen', 'num_clusters', 'output_dim', 'structDir', 'imgDir', 'lrStep', 'lrGamma', 'weightDecay', 'momentum', 'num_clusters', 'optim', 'margin', 'seed', 'patience'] 61 | do_not_update_list = ['resume', 'mode', 'phase', 'optim', 'split'] 62 | if os.path.exists(flag_file): 63 | with open(flag_file, 'r') as f: 64 | # stored_flags = {'--' + k: str(v) for k, v in json.load(f).items() if k in restore_var} 65 | stored_flags = {'--' + k: str(v) for k, v in json.load(f).items() if k not in do_not_update_list} 66 | to_del = [] 67 | for flag, val in stored_flags.items(): 68 | for act in self.parser._actions: 69 | if act.dest == flag[2:]: # stored parser match current parser 70 | # store_true / store_false args don't accept arguments, filter these 71 | if type(act.const) == type(True): 72 | if val == str(act.default): 73 | to_del.append(flag) 74 | else: 75 | stored_flags[flag] = '' 76 | else: 77 | if val == str(act.default): 78 | to_del.append(flag) 79 | 80 | for flag, val in stored_flags.items(): 81 | missing = True 82 | for act in self.parser._actions: 83 | if flag[2:] == act.dest: 84 | missing = False 85 | if missing: 86 | to_del.append(flag) 87 | 88 | for flag in to_del: 89 | del stored_flags[flag] 90 | 91 | train_flags = [x for x in list(sum(stored_flags.items(), tuple())) if len(x) > 0] 92 | print('restored flags:', train_flags) 93 | options = self.parser.parse_args(train_flags, namespace=options) 94 | return options 95 | 96 | 97 | class FixRandom: 98 | def 
__init__(self, seed) -> None: 99 | self.seed = seed 100 | torch.manual_seed(self.seed) 101 | random.seed(self.seed) 102 | np.random.seed(self.seed) 103 | torch.backends.cudnn.benchmark = False 104 | torch.use_deterministic_algorithms(True) 105 | os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' 106 | 107 | def seed_worker(self): 108 | worker_seed = self.seed 109 | np.random.seed(worker_seed) 110 | random.seed(worker_seed) 111 | -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | # %% 2 | import importlib 3 | import os 4 | import pickle 5 | import shutil 6 | from os.path import dirname, exists, join 7 | import h5py 8 | import faiss 9 | import numpy as np 10 | import torch 11 | import torch.nn as nn 12 | import wandb 13 | from torch.utils.data import DataLoader 14 | from tqdm import tqdm 15 | from datetime import datetime 16 | import json 17 | import torch.optim as optim 18 | 19 | os.sys.path.append(os.path.join(os.path.dirname(__file__), '../')) 20 | 21 | from options import FixRandom 22 | from utils import cal_recall, light_log, schedule_device 23 | 24 | 25 | class ContrastiveLoss(nn.Module): 26 | def __init__(self, margin) -> None: 27 | super().__init__() 28 | self.margin = margin 29 | 30 | def forward(self, emb_a, emb, pos_pair=True): # (1, D) 31 | if pos_pair: 32 | loss = 0.5 * (torch.norm(emb_a - emb, dim=1).pow(2)) 33 | else: 34 | dis_D = torch.norm(emb_a - emb, dim=1) 35 | loss = 0.5 * (torch.clamp(self.margin - dis_D, min=0).pow(2)) 36 | 37 | return loss 38 | 39 | 40 | class QuadrupletLoss(nn.Module): 41 | def __init__(self, margin, margin2) -> None: 42 | super().__init__() 43 | device = torch.device("cuda") 44 | self.cri = nn.TripletMarginLoss(margin=margin, p=2, reduction='sum').to(device) 45 | self.cri2 = nn.TripletMarginLoss(margin=margin2, p=2, reduction='sum').to(device) 46 | 47 | def forward(self, emb_a, emb_p, emb_n, emb_n2): 
# (1, D) 48 | loss1 = self.cri(emb_a, emb_p, emb_n) 49 | loss2 = self.cri2(emb_a, emb_p, emb_n2) 50 | loss = loss1 + loss2 51 | return loss 52 | 53 | 54 | class Trainer: 55 | def __init__(self, options) -> None: 56 | 57 | self.opt = options 58 | 59 | # r variables 60 | self.step = 0 61 | self.epoch = 0 62 | self.current_lr = 0 63 | self.best_recalls = [0, 0, 0] 64 | 65 | # seed 66 | fix_random = FixRandom(self.opt.seed) 67 | self.seed_worker = fix_random.seed_worker() 68 | self.time_stamp = datetime.now().strftime('%m%d_%H%M%S') 69 | 70 | # set device 71 | if self.opt.phase == 'train_tea': 72 | self.opt.cGPU = schedule_device() 73 | if self.opt.cuda and not torch.cuda.is_available(): 74 | raise Exception("No GPU found, please run with --nocuda :(") 75 | torch.cuda.set_device(self.opt.cGPU) 76 | self.device = torch.device("cuda") 77 | print('{}:{}{}'.format('device', self.device, torch.cuda.current_device())) 78 | 79 | # make model 80 | if self.opt.phase == 'train_tea': 81 | self.model, self.optimizer, self.scheduler, self.criterion = self.make_model() 82 | elif self.opt.phase == 'train_stu': 83 | self.teacher_net, self.student_net, self.optimizer, self.scheduler, self.criterion = self.make_model() 84 | self.model = self.teacher_net 85 | elif self.opt.phase in ['test_tea', 'test_stu']: 86 | self.model = self.make_model() 87 | else: 88 | raise Exception('Undefined phase :(') 89 | 90 | # make folders 91 | self.make_folders() 92 | # make dataset 93 | self.make_dataset() 94 | # online logs 95 | if self.opt.phase in ['train_tea', 'train_stu']: 96 | wandb.init(project="STUN", config=vars(self.opt), name=f"{self.opt.loss}_{self.opt.phase}_{self.time_stamp}") 97 | 98 | 99 | def make_folders(self): 100 | ''' create folders to store tensorboard files and a copy of networks files 101 | ''' 102 | if self.opt.phase in ['train_tea', 'train_stu']: 103 | self.opt.runsPath = join(self.opt.logsPath, f"{self.opt.loss}_{self.opt.phase}_{self.time_stamp}") 104 | if not 
os.path.exists(join(self.opt.runsPath, 'models')): 105 | os.makedirs(join(self.opt.runsPath, 'models')) 106 | 107 | if not os.path.exists(join(self.opt.runsPath, 'transformed')): 108 | os.makedirs(join(self.opt.runsPath, 'transformed')) 109 | 110 | for file in [__file__, 'datasets/{}.py'.format(self.opt.dataset), 'networks/{}.py'.format(self.opt.net)]: 111 | shutil.copyfile(file, os.path.join(self.opt.runsPath, 'models', file.split('/')[-1])) 112 | 113 | with open(join(self.opt.runsPath, 'flags.json'), 'w') as f: 114 | f.write(json.dumps({k: v for k, v in vars(self.opt).items()}, indent='')) 115 | 116 | def make_dataset(self): 117 | ''' make dataset 118 | ''' 119 | if self.opt.phase in ['train_tea', 'train_stu']: 120 | assert os.path.exists(f'datasets/{self.opt.dataset}.py'), 'Cannot find ' + f'{self.opt.dataset}.py :(' 121 | self.dataset = importlib.import_module('datasets.' + self.opt.dataset) 122 | elif self.opt.phase in ['test_tea', 'test_stu']: 123 | self.dataset = importlib.import_module('tmp.models.{}'.format(self.opt.dataset)) 124 | 125 | # for emb cache 126 | self.whole_train_set = self.dataset.get_whole_training_set(self.opt) 127 | self.whole_training_data_loader = DataLoader(dataset=self.whole_train_set, num_workers=self.opt.threads, batch_size=self.opt.cacheBatchSize, shuffle=False, pin_memory=self.opt.cuda, worker_init_fn=self.seed_worker) 128 | self.whole_val_set = self.dataset.get_whole_val_set(self.opt) 129 | self.whole_val_data_loader = DataLoader(dataset=self.whole_val_set, num_workers=self.opt.threads, batch_size=self.opt.cacheBatchSize, shuffle=False, pin_memory=self.opt.cuda, worker_init_fn=self.seed_worker) 130 | self.whole_test_set = self.dataset.get_whole_test_set(self.opt) 131 | self.whole_test_data_loader = DataLoader(dataset=self.whole_test_set, num_workers=self.opt.threads, batch_size=self.opt.cacheBatchSize, shuffle=False, pin_memory=self.opt.cuda, worker_init_fn=self.seed_worker) 132 | # for train tuples 133 | if self.opt.loss == 
'quad': 134 | self.train_set = self.dataset.get_quad_set(self.opt, self.opt.margin, self.opt.margin2) 135 | self.training_data_loader = DataLoader(dataset=self.train_set, num_workers=8, batch_size=self.opt.batchSize, shuffle=True, collate_fn=self.dataset.collate_quad_fn, worker_init_fn=self.seed_worker) 136 | else: 137 | self.train_set = self.dataset.get_training_query_set(self.opt, self.opt.margin) 138 | self.training_data_loader = DataLoader(dataset=self.train_set, num_workers=8, batch_size=self.opt.batchSize, shuffle=True, collate_fn=self.dataset.collate_fn, worker_init_fn=self.seed_worker) 139 | print('{}:{}, {}:{}, {}:{}, {}:{}, {}:{}'.format('dataset', self.opt.dataset, 'database', self.whole_train_set.dbStruct.numDb, 'train_set', self.whole_train_set.dbStruct.numQ, 'val_set', self.whole_val_set.dbStruct.numQ, 'test_set', 140 | self.whole_test_set.dbStruct.numQ)) 141 | print('{}:{}, {}:{}'.format('cache_bs', self.opt.cacheBatchSize, 'tuple_bs', self.opt.batchSize)) 142 | 143 | 144 | def make_model(self): 145 | '''build model 146 | ''' 147 | if self.opt.phase == 'train_tea': 148 | # build teacher net 149 | assert os.path.exists(f'networks/{self.opt.net}.py'), 'Cannot find ' + f'{self.opt.net}.py :(' 150 | network = importlib.import_module('networks.' 
+ self.opt.net) 151 | model = network.deliver_model(self.opt, 'tea') 152 | model = model.to(self.device) 153 | outputs = model(torch.rand((2, 3, self.opt.height, self.opt.width), device=self.device)) 154 | self.opt.output_dim = model(torch.rand((2, 3, self.opt.height, self.opt.width), device=self.device))[0].shape[-1] 155 | self.opt.sigma_dim = model(torch.rand((2, 3, self.opt.height, self.opt.width), device=self.device))[1].shape[-1] # place holder 156 | elif self.opt.phase == 'train_stu': # load teacher net 157 | assert self.opt.resume != '', 'You need to define the teacher/resume path :(' 158 | if exists('tmp'): 159 | shutil.rmtree('tmp') 160 | os.mkdir('tmp') 161 | shutil.copytree(join(dirname(self.opt.resume), 'models'), join('tmp', 'models')) 162 | network = importlib.import_module(f'tmp.models.{self.opt.net}') 163 | model_tea = network.deliver_model(self.opt, 'tea').to(self.device) 164 | checkpoint = torch.load(self.opt.resume) 165 | model_tea.load_state_dict(checkpoint['state_dict']) 166 | # build student net 167 | assert os.path.exists(f'networks/{self.opt.net}.py'), 'Cannot find ' + f'{self.opt.net}.py :(' 168 | network = importlib.import_module('networks.' 
+ self.opt.net) 169 | model = network.deliver_model(self.opt, 'stu').to(self.device) 170 | self.opt.output_dim = model(torch.rand((2, 3, self.opt.height, self.opt.width), device=self.device))[0].shape[-1] 171 | self.opt.sigma_dim = model(torch.rand((2, 3, self.opt.height, self.opt.width), device=self.device))[1].shape[-1] 172 | elif self.opt.phase in ['test_tea', 'test_stu']: 173 | # load teacher or student net 174 | assert self.opt.resume != '', 'You need to define a teacher/resume path :(' 175 | if exists('tmp'): 176 | shutil.rmtree('tmp') 177 | os.mkdir('tmp') 178 | shutil.copytree(join(dirname(self.opt.resume), 'models'), join('tmp', 'models')) 179 | network = importlib.import_module('tmp.models.{}'.format(self.opt.net)) 180 | model = network.deliver_model(self.opt, self.opt.phase[-3:]).to(self.device) 181 | checkpoint = torch.load(self.opt.resume) 182 | model.load_state_dict(checkpoint['state_dict']) 183 | 184 | print('{}:{}, {}:{}, {}:{}, {}:{}'.format(model.id, self.opt.net, 'loss', self.opt.loss, 'mu_dim', self.opt.output_dim, 'sigma_dim', self.opt.sigma_dim if self.opt.phase[-3:] == 'stu' else '-')) 185 | 186 | if self.opt.phase in ['train_tea', 'train_stu']: 187 | # optimizer 188 | if self.opt.optim == 'adam': 189 | optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), self.opt.lr, weight_decay=self.opt.weightDecay) 190 | scheduler = optim.lr_scheduler.ExponentialLR(optimizer, self.opt.lrGamma, last_epoch=-1, verbose=False) 191 | elif self.opt.optim == 'sgd': 192 | optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=self.opt.lr, momentum=self.opt.momentum, weight_decay=self.opt.weightDecay) 193 | scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=self.opt.lrStep, gamma=self.opt.lrGamma) 194 | else: 195 | raise NameError('Undefined optimizer :(') 196 | 197 | # loss function 198 | if self.opt.loss == 'tri': 199 | criterion = nn.TripletMarginLoss(margin=self.opt.margin, p=2, 
reduction='sum').to(self.device) 200 | elif self.opt.loss == 'cont': 201 | criterion = ContrastiveLoss(margin=torch.tensor(self.opt.margin, device=self.device)) 202 | elif self.opt.loss == 'quad': 203 | criterion = QuadrupletLoss(margin=self.opt.margin, margin2=self.opt.margin2).to(self.device) 204 | 205 | if self.opt.nGPU > 1: 206 | model = nn.DataParallel(model) 207 | 208 | if self.opt.phase == 'train_tea': 209 | return model, optimizer, scheduler, criterion 210 | elif self.opt.phase == 'train_stu': 211 | return model_tea, model, optimizer, scheduler, criterion 212 | elif self.opt.phase in ['test_tea', 'test_stu']: 213 | return model 214 | else: 215 | raise NameError('Undefined phase :(') 216 | 217 | 218 | def build_embedding_cache(self): 219 | '''build embedding cache, such that we can find the corresponding (p) and (n) with respect to (a) in embedding space 220 | ''' 221 | self.train_set.cache = os.path.join(self.opt.runsPath, self.train_set.whichSet + '_feat_cache.hdf5') 222 | with h5py.File(self.train_set.cache, mode='w') as h5: 223 | h5feat = h5.create_dataset("features", [len(self.whole_train_set), self.opt.output_dim], dtype=np.float32) 224 | with torch.no_grad(): 225 | for iteration, (input, indices) in enumerate(tqdm(self.whole_training_data_loader), 1): 226 | input = input.to(self.device) # torch.Size([32, 3, 154, 154]) ([32, 5, 3, 200, 200]) 227 | emb, _ = self.model(input) 228 | h5feat[indices.detach().numpy(), :] = emb.detach().cpu().numpy() 229 | del input, emb 230 | 231 | def process_batch(self, batch_inputs): 232 | ''' 233 | process a batch of input 234 | ''' 235 | if self.opt.loss == 'quad': 236 | anchor, positives, negatives, negatives2, neg_counts, indices = batch_inputs 237 | else: 238 | anchor, positives, negatives, neg_counts, indices = batch_inputs 239 | 240 | # in case we get an empty batch 241 | if anchor is None: 242 | return None, None 243 | 244 | # some reshaping to put query, pos, negs in a single (N, 3, H, W) tensor, where N = 
batchSize * (nQuery + nPos + n_neg) 245 | B = anchor.shape[0] # ([8, 1, 3, 200, 200]) 246 | n_neg = torch.sum(neg_counts) # tensor(80) = torch.sum(torch.Size([8])) 247 | if self.opt.loss == 'quad': 248 | input = torch.cat([anchor, positives, negatives, negatives2]) # ([B, C, H, 200]) 249 | else: 250 | input = torch.cat([anchor, positives, negatives]) # ([B, C, H, 200]) 251 | 252 | input = input.to(self.device) # ([96, 1, C, H, W]) 253 | embs, vars = self.model(input) # ([96, D]) 254 | 255 | # monitor uncertainty values 256 | if self.step % 100 == 0: 257 | wandb.log({'sigma_sq/avg': torch.mean(vars).item()}, step=self.step) 258 | wandb.log({'sigma_sq/max': torch.max(vars).item()}, step=self.step) 259 | wandb.log({'sigma_sq/min': torch.min(vars).item()}, step=self.step) 260 | 261 | tuple_loss = 0 262 | # Standard triplet loss (via PyTorch library) 263 | if self.opt.loss == 'tri': 264 | embs_a, embs_p, embs_n = torch.split(embs, [B, B, n_neg]) 265 | for i, neg_count in enumerate(neg_counts): 266 | for n in range(neg_count): 267 | negIx = (torch.sum(neg_counts[:i]) + n).item() 268 | tuple_loss += self.criterion(embs_a[i:i + 1], embs_p[i:i + 1], embs_n[negIx:negIx + 1]) 269 | tuple_loss /= n_neg.float().to(self.device) 270 | # Contrastive loss 271 | elif self.opt.loss == 'cont': 272 | embs_a, embs_p, embs_n = torch.split(embs, [B, B, n_neg]) # embs_a: ([B, D]) 273 | dis_pos_min, dis_neg_min, dis_neg_avg = 0, 0, 0 274 | for i, neg_count in enumerate(neg_counts): 275 | dis_pos_min += torch.norm(embs_a[i:i + 1] - embs_p[i:i + 1], dim=1) 276 | tuple_loss += self.criterion(embs_a[i:i + 1], embs_p[i:i + 1], pos_pair=True) 277 | for n in range(neg_count): 278 | negIx = (torch.sum(neg_counts[:i]) + n).item() 279 | if n == 0: 280 | dis_neg_min += torch.norm(embs_a[i:i + 1] - embs_n[negIx:negIx + 1], dim=1) 281 | dis_neg_avg += dis_neg_min 282 | else: 283 | dis_neg_avg += torch.norm(embs_a[i:i + 1] - embs_n[negIx:negIx + 1], dim=1) 284 | 285 | tuple_loss += 
self.criterion(embs_a[i:i + 1], embs_n[negIx:negIx + 1], pos_pair=False) 286 | tuple_loss /= (n_neg + 1).float().to(self.device) 287 | if self.step % 100 == 0: 288 | wandb.log({'pair_dis/pos_min': dis_pos_min.item()}, step=self.step) 289 | wandb.log({'pair_dis/neg_min': (dis_neg_min / n_neg).item()}, step=self.step) 290 | wandb.log({'pair_dis/neg_avg': (dis_neg_avg / (n_neg + 1)).item()}, step=self.step) 291 | # Quadruplet loss 292 | elif self.opt.loss == 'quad': 293 | embs_a, embs_p, embs_n, embs_n2 = torch.split(embs, [B, B, n_neg, n_neg]) 294 | for i, neg_count in enumerate(neg_counts): 295 | for n in range(neg_count): 296 | negIx = (torch.sum(neg_counts[:i]) + n).item() 297 | tuple_loss += self.criterion(embs_a[i:i + 1], embs_p[i:i + 1], embs_n[negIx:negIx + 1], embs_n2[negIx:negIx + 1]) 298 | tuple_loss /= 2 * n_neg.float().to(self.device) 299 | 300 | del input, embs, embs_a, embs_p, embs_n 301 | del anchor, positives, negatives 302 | 303 | return tuple_loss, n_neg 304 | 305 | def train(self): 306 | not_improved = 0 307 | for epoch in range(self.opt.nEpochs): 308 | self.epoch = epoch 309 | self.current_lr = self.optimizer.state_dict()['param_groups'][0]['lr'] 310 | 311 | # build embedding cache 312 | if self.epoch % self.opt.cacheRefreshEvery == 0: 313 | self.model.eval() 314 | self.build_embedding_cache() 315 | self.model.train() 316 | 317 | # train 318 | tuple_loss_sum = 0 319 | for _, batch_inputs in enumerate(tqdm(self.training_data_loader)): 320 | self.step += 1 321 | 322 | self.optimizer.zero_grad() 323 | tuple_loss, n_neg = self.process_batch(batch_inputs) 324 | if tuple_loss is None: 325 | continue 326 | tuple_loss.backward() 327 | self.optimizer.step() 328 | tuple_loss_sum += tuple_loss.item() 329 | 330 | if self.step % 10 == 0: 331 | wandb.log({'train_tuple_loss': tuple_loss.item()}, step=self.step) 332 | wandb.log({'train_batch_num_neg': n_neg}, step=self.step) 333 | 334 | n_batches = len(self.training_data_loader) 335 | 
            wandb.log({'train_avg_tuple_loss': tuple_loss_sum / n_batches}, step=self.step)
            torch.cuda.empty_cache()
            self.scheduler.step()

            # val every x epochs
            if (self.epoch % self.opt.evalEvery) == 0:
                recalls = self.val(self.model)
                if recalls[0] > self.best_recalls[0]:
                    self.best_recalls = recalls
                    not_improved = 0
                else:
                    not_improved += self.opt.evalEvery
                # light log
                vars_to_log = [
                    'e={:>2d},'.format(self.epoch),
                    'lr={:>.8f},'.format(self.current_lr),
                    'tl={:>.4f},'.format(tuple_loss_sum / n_batches),
                    'r@1/5/10={:.2f}/{:.2f}/{:.2f}'.format(recalls[0], recalls[1], recalls[2]),
                    '\n' if not_improved else ' *\n',
                ]
                light_log(self.opt.runsPath, vars_to_log)
            else:
                recalls = None
            self.save_model(self.model, is_best=not not_improved)

            # stop when not improving for a period
            if self.opt.phase == 'train_tea':
                if self.opt.patience > 0 and not_improved > self.opt.patience:
                    print('terminated because performance has not improved for', self.opt.patience, 'epochs')
                    break

        self.save_model(self.model, is_best=False)
        print('best r@1/5/10={:.2f}/{:.2f}/{:.2f}'.format(self.best_recalls[0], self.best_recalls[1], self.best_recalls[2]))

        return self.best_recalls

    def train_student(self):
        not_improved = 0
        for epoch in range(self.opt.nEpochs):
            self.epoch = epoch
            self.current_lr = self.optimizer.state_dict()['param_groups'][0]['lr']

            mu_delta_sq_sum, sigma_sq_sum, log_sigma_sq_sum, left_sum, loss_sum = 0, 0, 0, 0, 0
            n_batches = len(self.whole_training_data_loader)
            for iteration, (input, indices) in enumerate(tqdm(self.whole_training_data_loader)):
                self.step += 1
                input = input.to(self.device)  # ([B, C, H, W])
                self.optimizer.zero_grad()

                with torch.no_grad():
                    mu_tea, _ = self.teacher_net(input)  # ([B, D])
                mu_stu, log_sigma_sq = self.student_net(input)  # ([B, D]), ([B, D])

                # ---------------------- shift sigma_sq ---------------------- #
                if self.opt.loss in ['tri', 'quad']:  # empirically found shifting distribution to be helpful for these losses
                    log_sigma_sq = torch.clamp(10 * log_sigma_sq + 0.2, 0, 1)
                # == numerator
                mu_delta = torch.norm((mu_stu - mu_tea), p=2, dim=-1, keepdim=True)  # L2 norm -> ([B, 1])
                # == denominator
                sigma_sq = torch.exp(log_sigma_sq)
                # == regularizer
                loss = (mu_delta / sigma_sq + log_sigma_sq).mean()  # ([B, D]) averaged to a scalar

                loss.backward()
                self.optimizer.step()

                mu_delta_sq_sum += mu_delta.mean().item()
                sigma_sq_sum += sigma_sq.mean().item()
                log_sigma_sq_sum += log_sigma_sq.mean().item()
                left_sum += (mu_delta / sigma_sq).mean().item()
                loss_sum += loss.item()
                if self.step % 5 == 0:
                    wandb.log({'student/loss_mu_delta_sq': mu_delta.mean().item()}, step=self.step)
                    wandb.log({'student/loss_sigma_sq': sigma_sq.mean().item()}, step=self.step)
                    wandb.log({'student/loss_log_sigma_sq': log_sigma_sq.mean().item()}, step=self.step)
                    wandb.log({'student/loss_left': (mu_delta / sigma_sq).mean().item()}, step=self.step)
                    wandb.log({'student/loss': loss.item()}, step=self.step)

            wandb.log({'student/epoch_loss_mu_delta_sq': mu_delta_sq_sum / n_batches}, step=self.step)
            wandb.log({'student/epoch_loss_sigma_sq': sigma_sq_sum / n_batches}, step=self.step)
            wandb.log({'student/epoch_loss_log_sigma_sq': log_sigma_sq_sum / n_batches}, step=self.step)
            wandb.log({'student/epoch_loss_left': left_sum / n_batches}, step=self.step)
            wandb.log({'student/epoch_loss': loss_sum / n_batches}, step=self.step)
            self.scheduler.step()

            # val
            if (self.epoch % self.opt.evalEvery) == 0:
                recalls = self.val(self.student_net)
                if recalls[0] > self.best_recalls[0]:
                    self.best_recalls = recalls
                    not_improved = 0
                else:
                    not_improved += self.opt.evalEvery

                light_log(self.opt.runsPath, [
                    f'e={self.epoch:>2d},',
                    f'lr={self.current_lr:>.8f},',
                    f'tl={loss_sum / n_batches:>.4f},',
                    f'r@1/5/10={recalls[0]:.2f}/{recalls[1]:.2f}/{recalls[2]:.2f}',
                    '\n' if not_improved else ' *\n',
                ])
            else:
                recalls = None

            self.save_model(self.student_net, is_best=False, save_every_epoch=True)
            if self.opt.patience > 0 and not_improved > self.opt.patience:
                print('terminated because performance has not improved for', self.opt.patience, 'epochs')
                break

        print('best r@1/5/10={:.2f}/{:.2f}/{:.2f}'.format(self.best_recalls[0], self.best_recalls[1], self.best_recalls[2]))
        return self.best_recalls

    def val(self, model):
        recalls, _ = self.get_recall(model)
        for i, n in enumerate([1, 5, 10]):
            wandb.log({'{}/{}_r@{}'.format(model.id, self.opt.split, n): recalls[i]}, step=self.step)
            # self.writer.add_scalar('{}/{}_r@{}'.format(model.id, self.opt.split, n), recalls[i], self.epoch)

        return recalls

    def test(self):
        recalls, _ = self.get_recall(self.model, save_embs=True)
        print('best r@1/5/10={:.2f}/{:.2f}/{:.2f}'.format(recalls[0], recalls[1], recalls[2]))

        return recalls

    def save_model(self, model, is_best=False, save_every_epoch=False):
        if is_best:
            torch.save({
                'epoch': self.epoch,
                'step': self.step,
                'state_dict': model.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'scheduler': self.scheduler.state_dict(),
            }, os.path.join(self.opt.runsPath, 'ckpt_best.pth.tar'))

        if save_every_epoch:
            torch.save({
                'epoch': self.epoch,
                'step': self.step,
                'state_dict': model.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'scheduler': self.scheduler.state_dict(),
            }, os.path.join(self.opt.runsPath, 'ckpt_e_{}.pth.tar'.format(self.epoch)))

    def get_recall(self, model, save_embs=False):
        model.eval()

        if self.opt.split == 'val':
            eval_dataloader = self.whole_val_data_loader
            eval_set = self.whole_val_set
        elif self.opt.split == 'test':
            eval_dataloader = self.whole_test_data_loader
            eval_set = self.whole_test_set
        # print(f"{self.opt.split} len:{len(eval_set)}")

        whole_mu = torch.zeros((len(eval_set), self.opt.output_dim), device=self.device)  # (N, D)
        whole_var = torch.zeros((len(eval_set), self.opt.sigma_dim), device=self.device)  # (N, D)
        gt = eval_set.get_positives()  # (N, n_pos)

        with torch.no_grad():
            for iteration, (input, indices) in enumerate(tqdm(eval_dataloader), 1):
                input = input.to(self.device)
                mu, var = model(input)  # (B, D)

                # var = torch.exp(var)
                whole_mu[indices, :] = mu
                whole_var[indices, :] = var
                del input, mu, var

        n_values = [1, 5, 10]

        whole_var = torch.exp(whole_var)
        whole_mu = whole_mu.cpu().numpy()
        whole_var = whole_var.cpu().numpy()
        mu_q = whole_mu[eval_set.dbStruct.numDb:].astype('float32')
        mu_db = whole_mu[:eval_set.dbStruct.numDb].astype('float32')
        sigma_q = whole_var[eval_set.dbStruct.numDb:].astype('float32')
        sigma_db = whole_var[:eval_set.dbStruct.numDb].astype('float32')
        faiss_index = faiss.IndexFlatL2(mu_q.shape[1])
        faiss_index.add(mu_db)
        dists, preds = faiss_index.search(mu_q, max(n_values))  # the results are sorted

        # cull queries without any ground truth positives in the database
        val_inds = [True if len(gt[ind]) != 0 else False for ind in range(len(gt))]
        val_inds = np.array(val_inds)
        mu_q = mu_q[val_inds]
        sigma_q = sigma_q[val_inds]
        preds = preds[val_inds]
        dists = dists[val_inds]
        gt = gt[val_inds]

        recall_at_k = cal_recall(preds, gt, n_values)

        if save_embs:
            with open(join(self.opt.runsPath, '{}_db_embeddings_{}.pickle'.format(self.opt.split, self.opt.resume.split('.')[-3].split('_')[-1])), 'wb') as handle:
                pickle.dump(mu_q, handle, protocol=pickle.HIGHEST_PROTOCOL)
                pickle.dump(mu_db, handle, protocol=pickle.HIGHEST_PROTOCOL)
                pickle.dump(sigma_q, handle, protocol=pickle.HIGHEST_PROTOCOL)
                pickle.dump(sigma_db, handle, protocol=pickle.HIGHEST_PROTOCOL)
                pickle.dump(preds, handle, protocol=pickle.HIGHEST_PROTOCOL)
                pickle.dump(dists, handle, protocol=pickle.HIGHEST_PROTOCOL)
                pickle.dump(gt, handle, protocol=pickle.HIGHEST_PROTOCOL)
                pickle.dump(whole_mu, handle, protocol=pickle.HIGHEST_PROTOCOL)
                pickle.dump(whole_var, handle, protocol=pickle.HIGHEST_PROTOCOL)
                print('embeddings saved for post processing')

        return recall_at_k, None
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
from collections import namedtuple
from os.path import join

import faiss
import numpy as np
from pynvml import *
from scipy import stats
from scipy.io import loadmat
from scipy.optimize import least_squares
from skimage import io


def linear_fit(x, y, w, report_error=False):
    def cost(p, x, y, w):
        k = p[0]
        b = p[1]
        error = y - (k * x + b)
        error *= w
        return error

    p_init = np.array([-1, 1])
    ret = least_squares(cost, p_init, args=(x, y, w), verbose=0)
    # print(ret['x'][0], ret['x'][1], )
    y_fitted = ret['x'][0] * x + ret['x'][1]
    error = ret['cost']
    if report_error:
        return y_fitted, error
    else:
        return y_fitted


def reduce_sigma(sigma, std_or_sq, log_or_linear, hmean_or_mean):
    '''
    input sigma: sigma^2, ([1, D])
    output sigma: sigma, (1)
    '''
    if log_or_linear == 'log':
        print('log')
        sigma = np.log(sigma)
    elif log_or_linear == 'linear':
        pass
    else:
        raise NameError('undefined')

    if std_or_sq == 'std':
        sigma = np.sqrt(sigma)
    elif std_or_sq == 'sq':
        pass
    else:
        raise NameError('undefined')

    if hmean_or_mean == 'hmean':
        sigma = stats.hmean(sigma, axis=1)  # ([numQ,])
    elif hmean_or_mean == 'mean':
        sigma = np.mean(sigma, axis=1)  # ([numQ,])
    else:
        raise NameError('undefined')

    return sigma


def schedule_device():
    ''' output id of the graphics card with the most free memory
    '''
    nvmlInit()
    deviceCount = nvmlDeviceGetCount()
    frees = []
    for i in range(deviceCount):
        handle = nvmlDeviceGetHandleByIndex(i)
        # print("GPU", i, ":", nvmlDeviceGetName(handle))
        info = nvmlDeviceGetMemoryInfo(handle)
        frees.append(info.free / 1e9)
    nvmlShutdown()
    # print(frees)
    id = frees.index(max(frees))
    # print(id)
    return id

def light_log(path, args):
    with open(join(path, 'screen.log'), 'a') as f:
        for arg in args:
            f.write(arg)
            f.flush()
            print(arg, end='')


def cal_recall_from_embeddings(gt, qFeat, dbFeat):
    n_values = [1, 5, 10]

    # ---------------------------------------------------- sklearn --------------------------------------------------- #
    # knn = NearestNeighbors(n_jobs=-1)
    # knn.fit(dbFeat)
    # dists, predictions = knn.kneighbors(qFeat, len(dbFeat))

    # --------------------------------- use faiss to do NN search -------------------------------- #
    faiss_index = faiss.IndexFlatL2(qFeat.shape[1])
    faiss_index.add(dbFeat)
    dists, predictions = faiss_index.search(qFeat, max(n_values))  # the results are sorted

    recall_at_n = cal_recall(predictions, gt, n_values)
    return recall_at_n


def cal_recall(ranks, pidx, ks):

    recall_at_k = np.zeros(len(ks))
    for qidx in range(ranks.shape[0]):
        for i, k in enumerate(ks):
            if np.sum(np.in1d(ranks[qidx, :k], pidx[qidx])) > 0:
                recall_at_k[i:] += 1
                break

    recall_at_k /= ranks.shape[0]

    return recall_at_k * 100.0


def cal_apk(pidx, rank, k):
    if len(rank) > k:
        rank = rank[:k]

    score = 0.0
    num_hits = 0.0

    for i, p in enumerate(rank):
        if p in pidx and p not in rank[:i]:
            num_hits += 1.0
            score += num_hits / (i + 1.0)

    return score / min(len(pidx), k) * 100.0


def cal_mapk(ranks, pidxs, k):

    return np.mean([cal_apk(a, p, k) for a, p in zip(pidxs, ranks)])


def get_zoomed_bins(sigma, num_of_bins):
    s_min = np.min(sigma)
    s_max = np.max(sigma)
    print(s_min, s_max)
    bins_parent = np.linspace(s_min, s_max, num=num_of_bins)
    k = 0
    while True:
        indices = []
        bins_child = np.linspace(bins_parent[0], bins_parent[-1 - k], num=num_of_bins)
        for index in range(num_of_bins - 1):
            target_q_ind_l = np.where(sigma >= bins_child[index])
            if index != num_of_bins - 2:
                target_q_ind_r = np.where(sigma < bins_child[index + 1])
            else:
                target_q_ind_r = np.where(sigma <= bins_child[index + 1])
            target_q_ind = np.intersect1d(target_q_ind_l[0], target_q_ind_r[0])
            indices.append(target_q_ind)
        # if len(indices[-1]) > int(sigma.shape[0] * 0.0005):
        if len(indices[-1]) > int(sigma.shape[0] * 0.001) or k == num_of_bins - 2:
            break
        else:
            k = k + 1
    # print('{:.3f}'.format(sum([len(x) for x in indices]) / sigma.shape[0]), [len(x) for x in indices])
    # print('k=', k)
    return indices, bins_child, k


def bin_pr(preds, dists, gt, vis=False):
    # dists_m = np.around(dists[:, 0], 2)  # (4620,)
    # dists_u = np.array(list(set(dists_m)))
    # dists_u = np.sort(dists_u)  # small > large

    dists_u = np.linspace(np.min(dists[:, 0]), np.max(dists[:, 0]), num=100)

    recalls = []
    precisions = []
    for th in dists_u:
        TPCount = 0
        FPCount = 0
        FNCount = 0
        TNCount = 0
        for index_q in range(dists.shape[0]):
            # Positive
            if dists[index_q, 0] < th:
                # True
                if np.any(np.in1d(preds[index_q, 0], gt[index_q])):
                    TPCount += 1
                else:
                    FPCount += 1
            else:
                if np.any(np.in1d(preds[index_q, 0], gt[index_q])):
                    FNCount += 1
                else:
                    TNCount += 1
        assert TPCount + FPCount + FNCount + TNCount == dists.shape[0], 'Count Error!'
        if TPCount + FNCount == 0 or TPCount + FPCount == 0:
            # print('zero')
            continue
        recall = TPCount / (TPCount + FNCount)
        precision = TPCount / (TPCount + FPCount)
        recalls.append(recall)
        precisions.append(precision)
    if vis:
        from matplotlib import pyplot as plt
        plt.style.use('ggplot')
        fig = plt.figure(figsize=(5, 5))
        ax = fig.add_subplot(111)
        ax.plot(recalls, precisions)
        ax.set_title('Precision-Recall')
        ax.set_xlabel('Recall')
        ax.set_ylabel('Precision')
        ax.set_xlim([0, 1])
        ax.set_ylim([0, 1])
        plt.savefig('pr.png', dpi=200)
    return recalls, precisions


def parse_dbStruct_pitts(path):
    dbStruct = namedtuple('dbStruct', ['whichSet', 'dataset', 'dbImage', 'utmDb', 'qImage', 'utmQ', 'numDb', 'numQ', 'posDistThr', 'posDistSqThr', 'nonTrivPosDistSqThr'])

    mat = loadmat(path)
    matStruct = mat['dbStruct'].item()

    dataset = 'pitts'

    whichSet = matStruct[0].item()

    # .mat file is generated by python, I replace the use of cell (in Matlab) with char (in Python)
    dbImage = [f[0].item() for f in matStruct[1]]
    # dbImage = matStruct[1]
    utmDb = matStruct[2].T
    # utmDb = matStruct[2]

    # .mat file is generated by python, I replace the use of cell (in Matlab) with char (in Python)
    qImage = [f[0].item() for f in matStruct[3]]
    # qImage = matStruct[3]
    utmQ = matStruct[4].T
    # utmQ = matStruct[4]

    numDb = matStruct[5].item()
    numQ = matStruct[6].item()

    posDistThr = matStruct[7].item()
    posDistSqThr = matStruct[8].item()
    nonTrivPosDistSqThr = matStruct[9].item()

    return dbStruct(whichSet, dataset, dbImage, utmDb, qImage, utmQ, numDb, numQ, posDistThr, posDistSqThr, nonTrivPosDistSqThr)

def cal_hs(img_path):
    img = io.imread(img_path, as_gray=True).reshape(-1, 1)
    counts, bins = np.histogram((img * 255).astype(np.int16), np.arange(0, 256, 1))
    counts = counts / np.sum(counts)
    cumulative = np.cumsum(counts)
    in_min = np.min((img*255).astype(np.int16))
    in_max = np.max((img*255).astype(np.int16))
    per_75 = np.argwhere(cumulative < 0.75)[-1]
    per_25 = np.argwhere(cumulative < 0.25)[-1]
    hs = (per_75 - per_25)/255
    return hs

if __name__ == '__main__':
    pass
--------------------------------------------------------------------------------
/vis_results.py:
--------------------------------------------------------------------------------
#%%
import pickle
import numpy as np
from matplotlib import pyplot as plt
from tqdm import tqdm
from os.path import join
import utils
import importlib

importlib.reload(utils)

# --------------------------------------------------------------------------------------------------------------------- #
NETWORK = 'teacher_triplet'
# Choose NETWORK from |'teacher_triplet'|'student_contrast'|'student_triplet'|'student_quadruplet'|
# --------------------------------------------------------------------------------------------------------------------- #

NUM_BINS = 11
SHOW_AP = True
exp = NETWORK
resume = join('logs', NETWORK)

with open(join(resume, 'embs.pickle'), 'rb') as handle:
    q_mu = pickle.load(handle)
    db_mu = pickle.load(handle)
    q_sigma_sq = pickle.load(handle)
    db_sigma_sq = pickle.load(handle)
    preds = pickle.load(handle)
    dists = pickle.load(handle)
    gt = pickle.load(handle)
    _ = pickle.load(handle)
    _ = pickle.load(handle)


q_sigma_sq_h = np.mean(q_sigma_sq, axis=1)
db_sigma_sq_h = np.mean(db_sigma_sq, axis=1)
indices, _, _ = utils.get_zoomed_bins(q_sigma_sq_h, NUM_BINS)

# ---------------------- ECE and Recognition performance --------------------- #
bins_recall = np.zeros((NUM_BINS-1, 3))
bins_map = np.zeros((NUM_BINS-1, 3))
bins_ap = np.zeros((NUM_BINS - 1))

ece_bins_recall = np.zeros((NUM_BINS - 1, 3))
ece_bins_map = np.zeros((NUM_BINS - 1, 3))
ece_bins_ap = np.zeros((NUM_BINS - 1))

n_values = [1, 5, 10]
for index in tqdm(range(NUM_BINS - 1)):
    if len(indices[index]) == 0:
        continue

    pred_bin = preds[indices[index]]
    dist_bin = dists[indices[index]]
    gt_bin = gt[indices[index]]

    if SHOW_AP:
        # calculate AP
        recalls, precisions = utils.bin_pr(pred_bin, dist_bin, gt_bin)
        ap = 0
        for index_j in range(len(recalls) - 1):
            ap += precisions[index_j] * (recalls[index_j + 1] - recalls[index_j])
        bins_ap[index] = ap
        ece_bins_ap[index] = len(indices[index]) / q_sigma_sq_h.shape[0] * np.abs(ap - (NUM_BINS - 1 - index) / ((NUM_BINS - 1)))

    # calculate r@N
    recall_at_n = utils.cal_recall(pred_bin, gt_bin, n_values)
    bins_recall[index] = recall_at_n
    ece_bins_recall[index] = np.array([len(indices[index]) / q_sigma_sq_h.shape[0] * np.abs(recall_at_n[i] / 100.0 - (NUM_BINS - 1 - index) / ((NUM_BINS - 1))) for i in range(len(n_values))])

    # calculate mAP@N
    map_n = [utils.cal_mapk(pred_bin, gt_bin, n) for n in n_values]
    bins_map[index] = map_n
    ece_bins_map[index] = np.array([len(indices[index]) / q_sigma_sq_h.shape[0] * np.abs(map_n[i] / 100.0 - (NUM_BINS - 1 - index) / ((NUM_BINS - 1))) for i in range(len(n_values))])


# ---------------------------- uncertainty metric ---------------------------- #
print('ECE_rec@1/5/10: {:.3f} / {:.3f} / {:.3f}'.format(ece_bins_recall.sum(axis=0)[0], ece_bins_recall.sum(axis=0)[1], ece_bins_recall.sum(axis=0)[2]))
print('ECE_mAP@1/5/10: {:.3f} / {:.3f} / {:.3f}'.format(ece_bins_map.sum(axis=0)[0], ece_bins_map.sum(axis=0)[1], ece_bins_map.sum(axis=0)[2]))
if SHOW_AP:
    print('ECE_AP: {:.3f}'.format(ece_bins_ap.sum()))

# ---------------------------- recognition metric ---------------------------- #
recall = utils.cal_recall(preds, gt, n_values) / 100.0
print('rec@1/5/10: {:.3f} / {:.3f} / {:.3f}'.format(recall[0], recall[1], recall[2]))
map = [utils.cal_mapk(preds, gt, n) / 100.0 for n in n_values]
print('mAP@1/5/10: {:.3f} / {:.3f} / {:.3f}'.format(map[0], map[1], map[2]))

if SHOW_AP:
    recalls, precisions = utils.bin_pr(preds, dists, gt)
    ap = 0
    for index_j in range(len(recalls) - 1):
        ap += precisions[index_j] * (recalls[index_j + 1] - recalls[index_j])

    print('AP: {:.3f}'.format(ap))


# ------------------------------- visualization ------------------------------ #
w = np.array([len(indices[index]) / q_sigma_sq_h.shape[0] for index in range(NUM_BINS - 1)])
x = np.arange(0, NUM_BINS - 1, 1)

plt.style.use('ggplot')
fig, axs = plt.subplots(2, 2, figsize=(10, 10), squeeze=False)
fig.suptitle(exp)

ax = axs[0][0]
ax.bar(np.arange(len(indices)), [len(x) for x in indices])
ax.set_xlabel('sigma^2\n(uncertainty: low -> high)')
ax.set_ylabel('num of samples')

ax = axs[0][1]
ax.plot(np.arange(NUM_BINS - 1), bins_recall[:, 0], marker='o')
ax.plot(np.arange(NUM_BINS - 1), bins_recall[:, 1], marker='o')
ax.plot(np.arange(NUM_BINS - 1), bins_recall[:, 2], marker='o')

ax.set_xlabel('sigma^2\n(uncertainty: low -> high)')
ax.set_ylabel('recall@n')

ax = axs[1][0]
ax.plot(np.arange(NUM_BINS - 1), bins_map[:, 0], marker='o')
ax.plot(np.arange(NUM_BINS - 1), bins_map[:, 1], marker='o')
ax.plot(np.arange(NUM_BINS - 1), bins_map[:, 2], marker='o')

ax.set_xlabel('sigma^2\n(uncertainty: low -> high)')
ax.set_ylabel('mAP@n')

if SHOW_AP:
    ax = axs[1][1]
    ax.plot(np.arange(NUM_BINS - 1), bins_ap, marker='o')
    ax.set_xlabel('sigma^2\n(uncertainty: low -> high)')
    ax.set_ylabel('AP')
plt.savefig(join(resume, 'performance.png'), dpi=200)

--------------------------------------------------------------------------------
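The per-bin ECE terms in `vis_results.py` all follow one pattern: each uncertainty bin is weighted by its share of the queries and compared against a target score that falls linearly from 1.0 (lowest sigma^2 bin) to 1/n_bins (highest). The sketch below isolates that pattern as a standalone function so it can be checked on toy numbers. The name `binned_ece` and its parameters are hypothetical and not part of the repository; `n_total` stands in for `q_sigma_sq_h.shape[0]`, and it defaults to the sum of the bin sizes, which is an assumption that only holds when no query falls outside the zoomed bins.

```python
import numpy as np


def binned_ece(bin_sizes, bin_scores, n_total=None):
    """Hypothetical helper mirroring the ECE terms summed in vis_results.py.

    bin_sizes:  number of queries per uncertainty bin, ordered low -> high sigma^2
    bin_scores: per-bin metric scaled to [0, 1] (e.g. recall@1 / 100)
    n_total:    total number of queries; assumed equal to sum(bin_sizes) if omitted
    """
    sizes = np.asarray(bin_sizes, dtype=float)
    scores = np.asarray(bin_scores, dtype=float)
    n_bins = len(sizes)
    if n_total is None:
        n_total = sizes.sum()
    # Bin i gets the target (n_bins - i) / n_bins, i.e. (NUM_BINS - 1 - index) / (NUM_BINS - 1).
    targets = (n_bins - np.arange(n_bins)) / n_bins
    # Empty bins get zero weight, which matches skipping them in the script's loop.
    weights = sizes / n_total
    return float(np.sum(weights * np.abs(scores - targets)))


# Toy usage: two equally sized bins whose scores sit exactly on the targets 1.0 and 0.5.
perfect = binned_ece([50, 50], [1.0, 0.5])
# Same bins, but the low-uncertainty bin underperforms its target by 0.2.
off = binned_ece([50, 50], [0.8, 0.5])
```

Passing `n_total` explicitly is the closer match to the script, because `get_zoomed_bins` may shrink the upper bin edge and leave some high-uncertainty queries unbinned while the script still divides by the full query count.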