├── LICENSE ├── README.md ├── cirtorch ├── __pycache__ │ └── functional.cpython-38.pyc └── functional.py ├── datasets ├── __pycache__ │ └── pitts.cpython-38.pyc └── pitts.py ├── find_pair.py ├── main.py ├── netvlad.py ├── networks └── tscm.py ├── options.py ├── readmat.py ├── requirement.txt ├── trainer.py ├── utils.py └── vis.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Yehui Shen, Xieyuanli Chen, NuBot 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TSCM: A Teacher-Student Model for Vision Place Recognition Using Cross-Metric Knowledge Distillation 2 | Our work has been accepted by ICRA2024 :clap: 🎉 3 | 4 | If you use our code in your work, please star our repo and cite our paper [[pdf](https://arxiv.org/pdf/2404.01587)]. 
5 | 
6 | ```bibtex
7 | @inproceedings{shen2024icra,
8 |   title={{TSCM: A Teacher-Student Model for Vision Place Recognition Using Cross-Metric Knowledge Distillation}},
9 |   author={Shen, Yehui and Liu, Mingmin and Lu, Huimin and Chen, Xieyuanli},
10 |   booktitle={IEEE International Conference on Robotics and Automation (ICRA)},
11 |   year={2024}
12 | }
13 | ```
14 | 
15 | ## Pittsburgh Dataset
16 | You can download the Pittsburgh dataset at https://www.dropbox.com/s/ynep8wzii1z0r6h/pittsburgh.zip?dl=0
17 | ## How to use
18 | If you want to quickly verify the results, download the [stu_30k.pickle](https://www.dropbox.com/scl/fi/2rad0vkf0fd2v9er10g2v/stu_30k.pickle?rlkey=b04iygbqlsspt1upkr9jjyjaj&dl=0) file, put it in the folder /logs/contrast/, and run the following code:
19 | ```shell
20 | python vis.py
21 | ```
22 | 
23 | ## Train the models
24 | In training mode, keep the split option in options.py at its default value `val`:
25 | ```python
26 | self.parser.add_argument('--split', type=str, default='val', help='Split to use', choices=['val', 'test'])
27 | ```
28 | ```shell
29 | # train the teacher net
30 | python main.py --phase=train_tea
31 | 
32 | # train the student net supervised by the pretrained teacher net
33 | python main.py --phase=train_stu --resume=[logs/teacher_net_xxx/ckpt_best.pth.tar]
34 | ```
35 | ## Evaluate the pretrained models
36 | In test mode, the line in options.py
37 | ```python
38 | self.parser.add_argument('--split', type=str, default='val', help='Split to use', choices=['val', 'test'])
39 | ```
40 | needs to be changed to
41 | ```python
42 | self.parser.add_argument('--split', type=str, default='test', help='Split to use', choices=['val', 'test'])
43 | ```
44 | You can download the [pre-trained models](https://www.dropbox.com/scl/fo/c7why2nf82gn1ffv6dsr7/h?rlkey=7ariswrhfecaezjh0xz40599i&dl=0). **Replace teacher_triplet/ckpt.pth.tar in the command below with the appropriate checkpoint name.**
45 | 
46 | If a pre-trained model doesn't work or you need more pre-trained models, please open an issue.
47 | ```shell
48 | python main.py --phase=test_stu --resume=logs/teacher_triplet/ckpt.pth.tar
49 | ```
50 | ## Use the pretrained model for place recognition
51 | If you want to use the model for place recognition, replace the code in **trainer.py** with the code in **find_pair.py** and run the following code:
52 | ```shell
53 | python main.py --phase=test_stu --resume=logs/teacher_triplet/ckpt.pth.tar
54 | ```
55 | 
56 | Thanks to the open-source [baseline](https://github.com/ramdrop/stun) (STUN), on which the code of TSCM is based.
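
## Minimal retrieval example
As a quick illustration (not part of the original pipeline), the core retrieval step performed in `trainer.py`/`find_pair.py` boils down to the sketch below. Here `model`, `transform`, `db_embs`, and `query.jpg` are placeholders you must supply yourself: a trained TSCM network in eval mode that returns `(mu, var)`, the dataset's `input_transform`, a float32 NumPy array of database descriptors produced by the same model, and a query image.
```python
import faiss
import torch
from PIL import Image

# hypothetical inputs: `model`, `transform`, `db_embs` as described above
index = faiss.IndexFlatL2(db_embs.shape[1])  # exact L2 search over database descriptors
index.add(db_embs)

img = transform(Image.open('query.jpg')).unsqueeze(0)  # (1, 3, H, W)
with torch.no_grad():
    mu, _ = model(img)                                  # (1, D) global descriptor
dists, preds = index.search(mu.cpu().numpy().astype('float32'), 3)  # top-3 nearest database images
print('top-1 database index:', preds[0, 0])
```
This mirrors the `faiss.IndexFlatL2` usage in `get_recall` and `test4image` of `find_pair.py`.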
57 | -------------------------------------------------------------------------------- /cirtorch/__pycache__/functional.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nubot-nudt/TSCM/6c46db02488c13ff2428345fbd67b522d532d492/cirtorch/__pycache__/functional.cpython-38.pyc -------------------------------------------------------------------------------- /cirtorch/functional.py: -------------------------------------------------------------------------------- 1 | import math 2 | import pdb 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | # -------------------------------------- 8 | # pooling 9 | # -------------------------------------- 10 | 11 | 12 | def mac(x): 13 | return F.max_pool2d(x, (x.size(-2), x.size(-1))) 14 | # return F.adaptive_max_pool2d(x, (1,1)) # alternative 15 | 16 | 17 | def spoc(x): 18 | return F.avg_pool2d(x, (x.size(-2), x.size(-1))) 19 | # return F.adaptive_avg_pool2d(x, (1,1)) # alternative 20 | 21 | 22 | def gem(x, p=3, eps=1e-6): 23 | return F.avg_pool2d(x.clamp(min=eps).pow(p), (x.size(-2), x.size(-1))).pow(1. / p) 24 | # return F.lp_pool2d(F.threshold(x, eps, eps), p, (x.size(-2), x.size(-1))) # alternative 25 | 26 | 27 | def rmac(x, L=3, eps=1e-6): 28 | ovr = 0.4 # desired overlap of neighboring regions 29 | steps = torch.Tensor([2, 3, 4, 5, 6, 7]) # possible regions for the long dimension 30 | 31 | W = x.size(3) 32 | H = x.size(2) 33 | 34 | w = min(W, H) 35 | w2 = math.floor(w / 2.0 - 1) 36 | 37 | b = (max(H, W) - w) / (steps - 1) 38 | (tmp, idx) = torch.min(torch.abs(((w**2 - w * b) / w**2) - ovr), 0) # steps(idx) regions for long dimension 39 | 40 | # region overplus per dimension 41 | Wd = 0 42 | Hd = 0 43 | if H < W: 44 | Wd = idx.item() + 1 45 | elif H > W: 46 | Hd = idx.item() + 1 47 | 48 | v = F.max_pool2d(x, (x.size(-2), x.size(-1))) 49 | v = v / (torch.norm(v, p=2, dim=1, keepdim=True) + eps).expand_as(v) 50 | 51 | for l in range(1, L + 1): 52 | wl = math.floor(2 * w / (l + 1)) 53 | wl2 = math.floor(wl / 2 - 1) 54 | 55 | if l + Wd == 1: 56 | b = 0 57 | else: 58 | b = (W - wl) / (l + Wd - 1) 59 | cenW = torch.floor(wl2 + torch.Tensor(range(l - 1 + Wd + 1)) * b) - wl2 # center coordinates 60 | if l + Hd == 1: 61 | b = 0 62 | else: 63 | b = (H - wl) / (l + Hd - 1) 64 | cenH = torch.floor(wl2 + torch.Tensor(range(l - 1 + Hd + 1)) * b) - wl2 # center coordinates 65 | 66 | for i_ in cenH.tolist(): 67 | for j_ in cenW.tolist(): 68 | if wl == 0: 69 | continue 70 | R = x[:, :, (int(i_) + torch.Tensor(range(wl)).long()).tolist(), :] 71 | R = R[:, :, :, (int(j_) + torch.Tensor(range(wl)).long()).tolist()] 72 | vt = F.max_pool2d(R, (R.size(-2), R.size(-1))) 73 | vt = vt / (torch.norm(vt, p=2, dim=1, keepdim=True) + eps).expand_as(vt) 74 | v += vt 75 | 76 | return v 77 | 78 | 79 | def roipool(x, rpool, L=3, eps=1e-6): 80 | ovr = 0.4 # desired overlap of neighboring regions 81 | steps = torch.Tensor([2, 3, 4, 5, 6, 7]) # possible regions for the long dimension 82 | 83 | W = x.size(3) 84 | H = x.size(2) 85 | 86 | w = min(W, H) 87 | w2 = math.floor(w / 2.0 - 1) 88 | 89 | b = (max(H, W) - w) / (steps - 1) 90 | _, idx = torch.min(torch.abs(((w**2 - w * b) / w**2) - ovr), 0) # steps(idx) regions for long dimension 91 | 92 | # region overplus per dimension 93 | Wd = 0 94 | Hd = 0 95 | if H < W: 96 | Wd = idx.item() + 1 97 | elif H > W: 98 | Hd = idx.item() + 1 99 | 100 | vecs = [] 101 | vecs.append(rpool(x).unsqueeze(1)) 102 | 103 | for l in range(1, L + 1): 104 | wl = 
math.floor(2 * w / (l + 1))
105 |         wl2 = math.floor(wl / 2 - 1)
106 | 
107 |         if l + Wd == 1:
108 |             b = 0
109 |         else:
110 |             b = (W - wl) / (l + Wd - 1)
111 |         cenW = torch.floor(wl2 + torch.Tensor(range(l - 1 + Wd + 1)) * b).int() - wl2 # center coordinates
112 |         if l + Hd == 1:
113 |             b = 0
114 |         else:
115 |             b = (H - wl) / (l + Hd - 1)
116 |         cenH = torch.floor(wl2 + torch.Tensor(range(l - 1 + Hd + 1)) * b).int() - wl2 # center coordinates
117 | 
118 |         for i_ in cenH.tolist():
119 |             for j_ in cenW.tolist():
120 |                 if wl == 0:
121 |                     continue
122 |                 vecs.append(rpool(x.narrow(2, i_, wl).narrow(3, j_, wl)).unsqueeze(1))
123 | 
124 |     return torch.cat(vecs, dim=1)
125 | 
126 | 
127 | # --------------------------------------
128 | # normalization
129 | # --------------------------------------
130 | 
131 | 
132 | def l2n(x, eps=1e-6):
133 |     return x / (torch.norm(x, p=2, dim=1, keepdim=True) + eps).expand_as(x)
134 | 
135 | 
136 | def powerlaw(x, eps=1e-6):
137 |     x = x + eps # fixed: the original `x = x + self.eps` raises a NameError in a module-level function
138 |     return x.abs().sqrt().mul(x.sign())
139 | 
140 | 
141 | # --------------------------------------
142 | # loss
143 | # --------------------------------------
144 | 
145 | 
146 | def contrastive_loss(x, label, margin=0.7, eps=1e-6):
147 |     # x is D x N
148 |     dim = x.size(0) # D
149 |     nq = torch.sum(label.data == -1) # number of tuples
150 |     S = x.size(1) // nq # number of images per tuple including query: 1+1+n
151 | 
152 |     x1 = x[:, ::S].permute(1, 0).repeat(1, S - 1).view((S - 1) * nq, dim).permute(1, 0)
153 |     idx = [i for i in range(len(label)) if label.data[i] != -1]
154 |     x2 = x[:, idx]
155 |     lbl = label[label != -1]
156 | 
157 |     dif = x1 - x2
158 |     D = torch.pow(dif + eps, 2).sum(dim=0).sqrt()
159 | 
160 |     y = 0.5 * lbl * torch.pow(D, 2) + 0.5 * (1 - lbl) * torch.pow(torch.clamp(margin - D, min=0), 2)
161 |     y = torch.sum(y)
162 |     return y
163 | 
164 | 
165 | def triplet_loss(x, label, margin=0.1):
166 |     # x is D x N
167 |     dim = x.size(0) # D
168 |     nq = torch.sum(label.data == -1).item() # number of tuples
169 |     S = x.size(1) // nq # number of images per tuple including query: 1+1+n
170 | 
171 |     xa = x[:, label.data == -1].permute(1, 0).repeat(1, S - 2).view((S - 2) * nq, dim).permute(1, 0)
172 |     xp = x[:, label.data == 1].permute(1, 0).repeat(1, S - 2).view((S - 2) * nq, dim).permute(1, 0)
173 |     xn = x[:, label.data == 0]
174 | 
175 |     dist_pos = torch.sum(torch.pow(xa - xp, 2), dim=0)
176 |     dist_neg = torch.sum(torch.pow(xa - xn, 2), dim=0)
177 | 
178 |     return torch.sum(torch.clamp(dist_pos - dist_neg + margin, min=0))
--------------------------------------------------------------------------------
/datasets/__pycache__/pitts.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nubot-nudt/TSCM/6c46db02488c13ff2428345fbd67b522d532d492/datasets/__pycache__/pitts.cpython-38.pyc
--------------------------------------------------------------------------------
/datasets/pitts.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 | from os.path import join
3 | 
4 | import h5py
5 | import numpy as np
6 | import torch
7 | import torch.utils.data as data
8 | import torchvision.transforms as transforms
9 | from PIL import Image
10 | from scipy.io import loadmat
11 | from sklearn.neighbors import NearestNeighbors
12 | from torchvision.transforms import InterpolationMode
13 | 
14 | 
15 | def input_transform(opt=None):
16 |     return transforms.Compose([
17 |         transforms.ToTensor(),
18 |         
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 19 | transforms.Resize((opt.height, opt.width), interpolation=InterpolationMode.BILINEAR), 20 | ]) 21 | 22 | 23 | def get_whole_training_set(opt, onlyDB=False, forCluster=False, return_labels=False): 24 | return WholeDatasetFromStruct(opt, join(opt.structDir, 'pitts30k_train.mat'), opt.imgDir, input_transform=input_transform(opt), onlyDB=onlyDB, forCluster=forCluster, return_labels=return_labels) 25 | 26 | 27 | def get_whole_val_set(opt, return_labels=False): 28 | return WholeDatasetFromStruct(opt, join(opt.structDir, 'pitts30k_val.mat'), opt.imgDir, input_transform=input_transform(opt), return_labels=return_labels) 29 | 30 | 31 | def get_whole_test_set(opt, return_labels=False): 32 | return WholeDatasetFromStruct(opt, join(opt.structDir, 'pitts30k_test.mat'), opt.imgDir, input_transform=input_transform(opt), return_labels=return_labels) 33 | 34 | 35 | def get_training_query_set(opt, margin=0.1): 36 | return QueryDatasetFromStruct(opt, join(opt.structDir, 'pitts30k_train.mat'), opt.imgDir, input_transform=input_transform(opt), margin=margin) 37 | 38 | 39 | def get_val_query_set(opt, margin=0.1): 40 | return QueryDatasetFromStruct(opt, join(opt.structDir, 'pitts30k_val.mat'), opt.imgDir, input_transform=input_transform(opt), margin=margin) 41 | 42 | 43 | def get_quad_set(opt, margin, margin2): 44 | return QuadrupletDataset(opt, join(opt.structDir, 'pitts30k_train.mat'), opt.imgDir, input_transform=input_transform(opt), margin=margin, margin2=margin2) 45 | 46 | 47 | dbStruct = namedtuple('dbStruct', ['whichSet', 'dataset', 'dbImage', 'utmDb', 'qImage', 'utmQ', 'numDb', 'numQ', 'posDistThr', 'posDistSqThr', 'nonTrivPosDistSqThr']) 48 | 49 | 50 | def parse_dbStruct(path): 51 | mat = loadmat(path) 52 | matStruct = mat['dbStruct'].item() 53 | 54 | dataset = 'nuscenes' 55 | 56 | whichSet = matStruct[0].item() 57 | 58 | # .mat file is generated by python, Kaiwen replaces the use of cell (in Matlab) with char (in Python) 59 | dbImage = [f[0].item() for f in matStruct[1]] 60 | # dbImage = matStruct[1] 61 | utmDb = matStruct[2].T 62 | # utmDb = matStruct[2] 63 | 64 | # .mat file is generated by python, I replace the use of cell (in Matlab) with char (in Python) 65 | qImage = [f[0].item() for f in matStruct[3]] 66 | # qImage = matStruct[3] 67 | utmQ = matStruct[4].T 68 | # utmQ = matStruct[4] 69 | 70 | numDb = matStruct[5].item() 71 | numQ = matStruct[6].item() 72 | 73 | posDistThr = matStruct[7].item() 74 | posDistSqThr = matStruct[8].item() 75 | nonTrivPosDistSqThr = matStruct[9].item() 76 | 77 | return dbStruct(whichSet, dataset, dbImage, utmDb, qImage, utmQ, numDb, numQ, posDistThr, posDistSqThr, nonTrivPosDistSqThr) 78 | 79 | 80 | class WholeDatasetFromStructForCluster(data.Dataset): 81 | def __init__(self, opt, structFile, img_dir, input_transform=None, onlyDB=False): 82 | super().__init__() 83 | 84 | self.input_transform = input_transform 85 | 86 | self.dbStruct = parse_dbStruct(structFile) 87 | 88 | self.images = [join(img_dir, 'database', dbIm) for dbIm in self.dbStruct.dbImage] 89 | if not onlyDB: 90 | self.images += [join(img_dir, 'query', qIm) for qIm in self.dbStruct.qImage] 91 | 92 | self.whichSet = self.dbStruct.whichSet 93 | self.dataset = self.dbStruct.dataset 94 | 95 | self.positives = None 96 | self.distances = None 97 | 98 | def __getitem__(self, index): 99 | img = Image.open(self.images[index]) 100 | if self.input_transform: 101 | img = self.input_transform(img) 102 | 103 | return img, index 104 | 105 | 
def __len__(self): 106 | return len(self.images) 107 | 108 | def getPositives(self): 109 | # positives for evaluation are those within trivial threshold range 110 | # fit NN to find them, search by radius 111 | if self.positives is None: 112 | knn = NearestNeighbors(n_jobs=-1) 113 | knn.fit(self.dbStruct.utmDb) 114 | self.distances, self.positives = knn.radius_neighbors(self.dbStruct.utmQ, radius=self.dbStruct.nonTrivPosDistSqThr**0.5) # TODO: sort!! 115 | 116 | return self.positives 117 | 118 | 119 | class WholeDatasetFromStruct(data.Dataset): 120 | def __init__(self, opt, structFile, img_dir, input_transform=None, onlyDB=False, forCluster=False, return_labels=False): 121 | super().__init__() 122 | self.opt = opt 123 | self.forCluster = forCluster 124 | self.return_labels = return_labels 125 | 126 | self.input_transform = input_transform 127 | 128 | self.dbStruct = parse_dbStruct(structFile) 129 | 130 | self.images = [join(img_dir, 'database', dbIm) for dbIm in self.dbStruct.dbImage] 131 | if not onlyDB: 132 | self.images += [join(img_dir, 'query', qIm) for qIm in self.dbStruct.qImage] 133 | 134 | self.whichSet = self.dbStruct.whichSet 135 | self.dataset = self.dbStruct.dataset 136 | 137 | self.positives = None 138 | self.distances = None 139 | 140 | def load_images(self, index): 141 | filename = self.images[index] 142 | # imgs = [] 143 | img = Image.open(filename) 144 | if self.input_transform: 145 | img = self.input_transform(img) 146 | # imgs.append(img) 147 | # imgs = torch.stack(imgs, 0) 148 | 149 | return img, index 150 | 151 | def __getitem__(self, index): 152 | if self.forCluster: 153 | img = Image.open(self.images[index]) 154 | if self.input_transform: 155 | img = self.input_transform(img) 156 | 157 | return img, index 158 | else: 159 | if self.return_labels: 160 | imgs, index = self.load_images(index) 161 | return imgs, index, self.dbStruct.utmQ[index] 162 | else: 163 | imgs, index = self.load_images(index) 164 | return imgs, index 165 | 166 | def __len__(self): 167 | return len(self.images) 168 | 169 | def get_databases(self): 170 | return self.dbStruct.utmDb 171 | 172 | def get_queries(self): 173 | return self.dbStruct.utmQ 174 | 175 | def get_positives(self): 176 | # positives for evaluation are those within trivial threshold range 177 | # fit NN to find them, search by radius 178 | if self.positives is None: 179 | knn = NearestNeighbors(n_jobs=-1) 180 | knn.fit(self.dbStruct.utmDb) 181 | self.distances, self.positives = knn.radius_neighbors(self.dbStruct.utmQ, radius=self.dbStruct.nonTrivPosDistSqThr**0.5) # TODO: sort!! 182 | 183 | return self.positives 184 | 185 | 186 | def collate_fn(batch): 187 | """Creates mini-batch tensors from the list of tuples (query, positive, negatives). 188 | 189 | Args: 190 | data: list of tuple (query, positive, negatives). 191 | - query: torch tensor of shape (3, h, w). 192 | - positive: torch tensor of shape (3, h, w). 193 | - negative: torch tensor of shape (n, 3, h, w). 194 | Returns: 195 | query: torch tensor of shape (batch_size, 3, h, w). 196 | positive: torch tensor of shape (batch_size, 3, h, w). 197 | negatives: torch tensor of shape (batch_size, n, 3, h, w). 198 | """ 199 | 200 | batch = list(filter(lambda x: x is not None, batch)) 201 | if len(batch) == 0: 202 | return None, None, None, None, None 203 | 204 | query, positive, negatives, indices = zip(*batch) 205 | 206 | query = data.dataloader.default_collate(query) # ([8, 3, 200, 200]) = [(3, 200, 200), (3, 200, 200), .. 
] ([8, 1, 3, 200, 200])
207 |     positive = data.dataloader.default_collate(positive)
208 |     negCounts = data.dataloader.default_collate([x.shape[0] for x in negatives])
209 |     negatives = torch.cat(negatives, 0) # ([80, 3, 200, 200]) ([80, 1, 3, 200, 200])
210 |     import itertools
211 |     indices = list(itertools.chain(*indices))
212 | 
213 |     return query, positive, negatives, negCounts, indices
214 | 
215 | 
216 | class QueryDatasetFromStruct(data.Dataset):
217 |     def __init__(self, opt, structFile, img_dir, nNegSample=1000, nNeg=10, margin=0.1, input_transform=None):
218 |         super().__init__()
219 |         self.opt = opt
220 |         self.img_dir = img_dir
221 |         self.input_transform = input_transform
222 |         self.margin = margin
223 | 
224 |         self.dbStruct = parse_dbStruct(structFile)
225 |         self.whichSet = self.dbStruct.whichSet
226 |         self.dataset = self.dbStruct.dataset
227 |         self.nNegSample = nNegSample # number of negatives to randomly sample
228 |         self.nNeg = nNeg # number of negatives used for training
229 | 
230 |         # potential positives are those within nontrivial threshold range
231 |         # fit NN to find them, search by radius
232 |         knn = NearestNeighbors(n_jobs=-1)
233 |         knn.fit(self.dbStruct.utmDb)
234 | 
235 |         # TODO use sqeuclidean as metric?
236 |         # search within the 100 (nonTrivPosDistSqThr) threshold
237 |         self.nontrivial_positives = list(knn.radius_neighbors(self.dbStruct.utmQ, radius=self.dbStruct.nonTrivPosDistSqThr**0.5, return_distance=False))
238 |         # radius returns unsorted, sort once now so we don't have to later
239 |         for i, posi in enumerate(self.nontrivial_positives):
240 |             self.nontrivial_positives[i] = np.sort(posi)
241 |         # it's possible some queries don't have any non-trivial potential positives
242 |         # let's filter those out
243 |         self.queries = np.where(np.array([len(x) for x in self.nontrivial_positives]) > 0)[0]
244 | 
245 |         # potential negatives are those outside of posDistThr range
246 |         # search within the 25 (posDistThr) threshold
247 |         potential_positives = knn.radius_neighbors(self.dbStruct.utmQ, radius=self.dbStruct.posDistThr, return_distance=False)
248 | 
249 |         self.potential_negatives = []
250 |         for pos in potential_positives:
251 |             self.potential_negatives.append(np.setdiff1d(np.arange(self.dbStruct.numDb), pos, assume_unique=True))
252 | 
253 |         self.cache = None # filepath of HDF5 containing feature vectors for images
254 | 
255 |         self.negCache = [np.empty((0, ), dtype=np.int64) for _ in range(self.dbStruct.numQ)]
256 | 
257 |     def load_images(self, filename):
258 |         # imgs = []
259 |         img = Image.open(filename)
260 |         if self.input_transform:
261 |             img = self.input_transform(img)
262 |         # imgs.append(img)
263 |         # imgs = torch.stack(imgs, 0)
264 | 
265 |         return img
266 | 
267 |     def __getitem__(self, index):
268 | 
269 |         index = self.queries[index] # re-map index to match dataset
270 |         with h5py.File(self.cache, mode='r') as h5:
271 |             h5feat = h5.get("features")
272 |             qOffset = self.dbStruct.numDb
273 | 
274 |             qFeat = h5feat[index + qOffset]
275 |             posFeat = h5feat[self.nontrivial_positives[index].tolist()]
276 |             qFeat = torch.tensor(qFeat)
277 |             posFeat = torch.tensor(posFeat)
278 |             dist = torch.norm(qFeat - posFeat, dim=1, p=None)
279 |             result = dist.topk(1, largest=False)
280 |             dPos, posNN = result.values, result.indices
281 |             posIndex = self.nontrivial_positives[index][posNN].item()
282 | 
283 |             negSample = np.random.choice(self.potential_negatives[index], self.nNegSample) # randomly choose potential_negatives
284 |             negSample = np.unique(np.concatenate([self.negCache[index], negSample])) # remember negSamples history for each query
285 | 
286 |             negFeat = 
h5feat[negSample.tolist()] 287 | negFeat = torch.tensor(negFeat) 288 | dist = torch.norm(qFeat - negFeat, dim=1, p=None) 289 | result = dist.topk(self.nNeg * 10, largest=False) 290 | dNeg, negNN = result.values, result.indices 291 | 292 | if self.opt.loss == 'cont': 293 | violatingNeg = dNeg.numpy() < self.margin**0.5 294 | else: 295 | violatingNeg = dNeg.numpy() < dPos.numpy() + self.margin**0.5 296 | 297 | if np.sum(violatingNeg) < 1: 298 | return None 299 | 300 | negNN = negNN.numpy() 301 | negNN = negNN[violatingNeg][:self.nNeg] 302 | negIndices = negSample[negNN].astype(np.int32) 303 | self.negCache[index] = negIndices 304 | 305 | query = self.load_images(join(self.img_dir, 'query', self.dbStruct.qImage[index])) 306 | positive = self.load_images(join(self.img_dir, 'database', self.dbStruct.dbImage[posIndex])) 307 | 308 | negatives = [] 309 | for negIndex in negIndices: 310 | negative = self.load_images(join(self.img_dir, 'database', self.dbStruct.dbImage[negIndex])) 311 | negatives.append(negative) 312 | 313 | negatives = torch.stack(negatives, 0) # ([10, 3, 200, 200]) 314 | return query, positive, negatives, [index, posIndex] + negIndices.tolist() 315 | 316 | def __len__(self): 317 | return len(self.queries) 318 | 319 | 320 | class QuadrupletDataset(data.Dataset): 321 | def __init__(self, opt, structFile, img_dir, nNegSample=1000, nNeg=10, margin=0.1, margin2=0.05, input_transform=None): 322 | super().__init__() 323 | self.opt = opt 324 | self.img_dir = img_dir 325 | self.input_transform = input_transform 326 | self.margin = margin 327 | self.margin2 = margin2 328 | 329 | self.dbStruct = parse_dbStruct(structFile) 330 | self.whichSet = self.dbStruct.whichSet 331 | self.dataset = self.dbStruct.dataset 332 | self.nNegSample = nNegSample # number of negatives to randomly sample 333 | self.nNeg = nNeg # number of negatives used for training 334 | 335 | # potential positives are those within nontrivial threshold range, fit NN to find them, search by radius 336 | knn = NearestNeighbors(n_jobs=-1) 337 | knn.fit(self.dbStruct.utmDb) 338 | 339 | self.db_potential_positives = knn.radius_neighbors(self.dbStruct.utmDb, radius=self.dbStruct.posDistThr, return_distance=False) # 6312 340 | self.db_potential_negatives = [] 341 | for pos in self.db_potential_positives: 342 | self.db_potential_negatives.append(np.setdiff1d(np.arange(self.dbStruct.numDb), pos, assume_unique=True)) 343 | 344 | # TODO use sqeuclidean as metric? 
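        # note: nonTrivPosDistSqThr is a squared distance threshold, hence the **0.5 below to obtain a plain radius for the kNN search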
345 |         self.nontrivial_positives = list(knn.radius_neighbors(self.dbStruct.utmQ, radius=self.dbStruct.nonTrivPosDistSqThr**0.5, return_distance=False)) # 7075
346 |         # radius returns unsorted, sort once now so we don't have to later
347 |         for i, posi in enumerate(self.nontrivial_positives):
348 |             self.nontrivial_positives[i] = np.sort(posi)
349 |         # it's possible some queries don't have any non-trivial potential positives, let's filter those out
350 |         self.queries = np.where(np.array([len(x) for x in self.nontrivial_positives]) > 0)[0]
351 | 
352 |         # potential negatives are those outside of posDistThr range
353 |         self.potential_positives = knn.radius_neighbors(self.dbStruct.utmQ, radius=self.dbStruct.posDistThr, return_distance=False)
354 | 
355 |         self.potential_negatives = []
356 |         for pos in self.potential_positives:
357 |             self.potential_negatives.append(np.setdiff1d(np.arange(self.dbStruct.numDb), pos, assume_unique=True))
358 | 
359 |         self.cache = None # filepath of HDF5 containing feature vectors for images
360 | 
361 |         self.negCache = [np.empty((0, ), dtype=np.int64) for _ in range(self.dbStruct.numQ)]
362 | 
363 |     def load_images(self, filename):
364 |         img = Image.open(filename)
365 |         if self.input_transform:
366 |             img = self.input_transform(img)
367 |         return img
368 | 
369 |     def __getitem__(self, index):
370 |         index = self.queries[index] # re-map index to match dataset
371 |         with h5py.File(self.cache, mode='r') as h5:
372 |             h5feat = h5.get("features")
373 |             qOffset = self.dbStruct.numDb
374 | 
375 |             qFeat = h5feat[index + qOffset]
376 |             tmp = self.nontrivial_positives[index]
377 |             tmp = tmp.tolist()
378 |             posFeat = h5feat[self.nontrivial_positives[index].tolist()]
379 |             qFeat = torch.tensor(qFeat)
380 |             posFeat = torch.tensor(posFeat)
381 |             dist = torch.norm(qFeat - posFeat, dim=1, p=None)
382 |             result = dist.topk(1, largest=False) # choose the closest positive
383 |             dPos, posNN = result.values, result.indices
384 |             posIndex = self.nontrivial_positives[index][posNN].item()
385 | 
386 |             negSample = np.random.choice(self.potential_negatives[index], self.nNegSample) # randomly choose potential_negatives
387 |             negSample = np.unique(np.concatenate([self.negCache[index], negSample])) # merge the cached negatives from the previous round with the fresh sample
388 | 
389 |             negFeat = h5feat[negSample.tolist()]
390 |             negFeat = torch.tensor(negFeat)
391 |             dist = torch.norm(qFeat - negFeat, dim=1, p=None)
392 |             result = dist.topk(self.nNeg * 10, largest=False)
393 |             dNeg, negNN = result.values, result.indices
394 | 
395 |             # try to find negatives that are within margin, if there aren't any return none
396 |             violatingNeg = dNeg.numpy() < dPos.numpy() + self.margin**0.5
397 | 
398 |             if np.sum(violatingNeg) < 1:
399 |                 # if none are violating then skip this query
400 |                 return None
401 | 
402 |             negNN = negNN.numpy()
403 |             negNN = negNN[violatingNeg][:self.nNeg]
404 |             negIndices = negSample[negNN].astype(np.int32)
405 |             self.negCache[index] = negIndices
406 | 
407 |         query = self.load_images(join(self.img_dir, 'query', self.dbStruct.qImage[index]))
408 |         positive = self.load_images(join(self.img_dir, 'database', self.dbStruct.dbImage[posIndex]))
409 | 
410 |         negatives = []
411 |         negatives2 = []
412 |         negIndices2 = []
413 |         for negIndex in negIndices:
414 |             anchor_neg_negs = np.random.choice(self.db_potential_negatives[negIndex], 1000, replace=False)
415 |             anchor_poss = self.potential_positives[index]
416 |             anchor_neg_negs_clean = np.setdiff1d(anchor_neg_negs, anchor_poss, assume_unique=True)
417 |             anchor_neg_negs_clean = 
np.sort(anchor_neg_negs_clean) 418 | with h5py.File(self.cache, mode='r') as h5: 419 | h5feat = h5.get("features") 420 | negFeat = h5feat[anchor_neg_negs_clean.tolist()] 421 | negFeat = torch.tensor(negFeat) 422 | dist = torch.norm(qFeat - negFeat, dim=1, p=None) 423 | result = dist.topk(self.nNeg * 10, largest=False) 424 | dNeg, negNN = result.values, result.indices 425 | violatingNeg = dNeg.numpy() < dPos.numpy() + self.margin2**0.5 # increase negative samples by using **0.5 426 | if np.sum(violatingNeg) < 1: 427 | return None 428 | negNN = negNN.numpy() 429 | negNN = negNN[violatingNeg][:1] 430 | neg2Index = anchor_neg_negs_clean[negNN].astype(np.int32)[0] 431 | 432 | negative = self.load_images(join(self.img_dir, 'database', self.dbStruct.dbImage[negIndex])) 433 | negative2 = self.load_images(join(self.img_dir, 'database', self.dbStruct.dbImage[neg2Index])) 434 | negatives.append(negative) 435 | negatives2.append(negative2) 436 | negIndices2.append(neg2Index) 437 | 438 | negatives = torch.stack(negatives, 0) # ([num_neg, C, H, W]) 439 | negatives2 = torch.stack(negatives2, 0) # ([num_neg, C, H, W]) 440 | return query, positive, negatives, negatives2, [index, posIndex] + negIndices.tolist() + negIndices2 441 | 442 | def __len__(self): 443 | return len(self.queries) 444 | 445 | 446 | def collate_quad_fn(batch): 447 | """Creates mini-batch tensors from the list of tuples (query, positive, negatives). 448 | 449 | Args: 450 | data: list of tuple (query, positive, negatives). 451 | - query: torch tensor of shape (3, h, w). 452 | - positive: torch tensor of shape (3, h, w). 453 | - negative: torch tensor of shape (n, 3, h, w). 454 | - negative2: torch tensor of shape (n, 3, h, w). 455 | Returns: 456 | query: torch tensor of shape (batch_size, 3, h, w). 457 | positive: torch tensor of shape (batch_size, 3, h, w). 458 | negatives: torch tensor of shape (batch_size, n, 3, h, w). 459 | """ 460 | 461 | batch = list(filter(lambda x: x is not None, batch)) 462 | if len(batch) == 0: 463 | return None, None, None, None, None, None 464 | 465 | query, positive, negatives, negatives2, indices = zip(*batch) 466 | 467 | query = data.dataloader.default_collate(query) # ([8, 3, 200, 200]) = [(3, 200, 200), (3, 200, 200), .. 
] ([8, 1, 3, 200, 200]) 468 | positive = data.dataloader.default_collate(positive) 469 | negCounts = data.dataloader.default_collate([x.shape[0] for x in negatives]) 470 | negatives = torch.cat(negatives, 0) # ([80, 3, 200, 200]) ([80, 1, 3, 200, 200]) 471 | negatives2 = torch.cat(negatives2, 0) # ([80, 3, 200, 200]) ([80, 1, 3, 200, 200]) 472 | import itertools 473 | indices = list(itertools.chain(*indices)) 474 | 475 | return query, positive, negatives, negatives2, negCounts, indices -------------------------------------------------------------------------------- /find_pair.py: -------------------------------------------------------------------------------- 1 | # %% 2 | import importlib 3 | import os 4 | import pickle 5 | import shutil 6 | from os.path import dirname, exists, join 7 | import h5py 8 | import faiss 9 | import numpy as np 10 | import torch 11 | import torch.nn as nn 12 | import wandb 13 | from torch.utils.data import DataLoader 14 | from tqdm import tqdm 15 | from datetime import datetime 16 | import json 17 | import torch.optim as optim 18 | from torchsummary import summary 19 | 20 | os.sys.path.append(os.path.join(os.path.dirname(__file__), '../')) 21 | import time 22 | from options import FixRandom 23 | from utils import cal_recall, light_log, schedule_device 24 | import importlib 25 | from PIL import Image 26 | import torchvision.transforms as transforms 27 | from torchvision.transforms import InterpolationMode 28 | from scipy.io import loadmat 29 | from collections import namedtuple 30 | 31 | dbStruct = namedtuple('dbStruct', 32 | ['whichSet', 'dataset', 'dbImage', 'utmDb', 'qImage', 'utmQ', 'numDb', 'numQ', 'posDistThr', 33 | 'posDistSqThr', 'nonTrivPosDistSqThr']) 34 | 35 | class CKD_loss(nn.Module): 36 | def __init__(self, margin) -> None: 37 | super().__init__() 38 | self.margin = margin 39 | 40 | def forward(self, embs_a, embs_p, embs_n, mu_tea_a, mu_tea_p, mu_tea_n): # (1, D) 41 | SaTp = torch.norm(embs_a - mu_tea_p, p=2).pow(2) 42 | SpTa = torch.norm(embs_p - mu_tea_a, p=2).pow(2) 43 | 44 | SaTn = torch.norm(embs_a - mu_tea_n, p=2).pow(2) 45 | SnTa = torch.norm(embs_n - mu_tea_a, p=2).pow(2) 46 | 47 | SaTa = torch.norm(embs_a - mu_tea_a, p=2).pow(2) 48 | SpTp = torch.norm(embs_p - mu_tea_p, p=2).pow(2) 49 | SnTn = torch.norm(embs_n - mu_tea_n, p=2).pow(2) 50 | dis_D = SpTp + SnTn 51 | # dis_D =SaTp+SpTa+SaTa+SpTp+SnTn 52 | # dis_D=SaTa+SpTp+SnTn 53 | loss = 0.5 * (torch.clamp(self.margin + dis_D, min=0).pow(2)) 54 | 55 | return loss 56 | 57 | 58 | class Trainer: 59 | def __init__(self, options) -> None: 60 | 61 | self.opt = options 62 | 63 | # r variables 64 | self.step = 0 65 | self.epoch = 0 66 | self.current_lr = 0 67 | self.best_recalls = [0, 0, 0] 68 | 69 | # seed 70 | fix_random = FixRandom(self.opt.seed) 71 | self.seed_worker = fix_random.seed_worker() 72 | self.time_stamp = datetime.now().strftime('%m%d_%H%M%S') 73 | 74 | # set device 75 | if self.opt.phase == 'train_tea': 76 | self.opt.cGPU = schedule_device() 77 | if self.opt.cuda and not torch.cuda.is_available(): 78 | raise Exception("No GPU found, please run with --nocuda :(") 79 | torch.cuda.set_device(self.opt.cGPU) 80 | self.device = torch.device("cuda") 81 | print('{}:{}{}'.format('device', self.device, torch.cuda.current_device())) 82 | 83 | # CKD_loss 84 | self.CKD_loss = CKD_loss(margin=torch.tensor(self.opt.margin, device=self.device)) 85 | # make model 86 | if self.opt.phase == 'train_tea': 87 | self.model, self.optimizer, self.scheduler, self.criterion = self.make_model() 88 | elif 
self.opt.phase == 'train_stu': 89 | self.teacher_net, self.student_net, self.optimizer, self.scheduler, self.criterion = self.make_model() 90 | self.model = self.teacher_net 91 | elif self.opt.phase in ['test_tea', 'test_stu']: 92 | self.model = self.make_model() 93 | else: 94 | raise Exception('Undefined phase :(') 95 | 96 | # make folders 97 | self.make_folders() 98 | # make dataset 99 | self.make_dataset() 100 | # online logs 101 | if self.opt.phase in ['train_tea', 'train_stu']: 102 | wandb.init(project="STUN", config=vars(self.opt), 103 | name=f"{self.opt.loss}_{self.opt.phase}_{self.time_stamp}") 104 | 105 | def make_folders(self): 106 | ''' create folders to store tensorboard files and a copy of networks files 107 | ''' 108 | if self.opt.phase in ['train_tea', 'train_stu']: 109 | self.opt.runsPath = join(self.opt.logsPath, f"{self.opt.loss}_{self.opt.phase}_{self.time_stamp}") 110 | if not os.path.exists(join(self.opt.runsPath, 'models')): 111 | os.makedirs(join(self.opt.runsPath, 'models')) 112 | 113 | if not os.path.exists(join(self.opt.runsPath, 'transformed')): 114 | os.makedirs(join(self.opt.runsPath, 'transformed')) 115 | 116 | for file in [__file__, 'datasets/{}.py'.format(self.opt.dataset), 'networks/{}.py'.format(self.opt.net)]: 117 | shutil.copyfile(file, os.path.join(self.opt.runsPath, 'models', file.split('/')[-1])) 118 | 119 | with open(join(self.opt.runsPath, 'flags.json'), 'w') as f: 120 | f.write(json.dumps({k: v for k, v in vars(self.opt).items()}, indent='')) 121 | 122 | def make_dataset(self): 123 | ''' make dataset 124 | ''' 125 | if self.opt.phase in ['train_tea', 'train_stu']: 126 | assert os.path.exists(f'datasets/{self.opt.dataset}.py'), 'Cannot find ' + f'{self.opt.dataset}.py :(' 127 | self.dataset = importlib.import_module('datasets.' 
+ self.opt.dataset) 128 | elif self.opt.phase in ['test_tea', 'test_stu']: 129 | self.dataset = importlib.import_module('tmp.models.{}'.format(self.opt.dataset)) 130 | 131 | # for emb cache 132 | self.whole_train_set = self.dataset.get_whole_training_set(self.opt) 133 | self.whole_training_data_loader = DataLoader(dataset=self.whole_train_set, num_workers=self.opt.threads, 134 | batch_size=self.opt.cacheBatchSize, shuffle=False, 135 | pin_memory=self.opt.cuda, worker_init_fn=self.seed_worker) 136 | self.whole_val_set = self.dataset.get_whole_val_set(self.opt) 137 | self.whole_val_data_loader = DataLoader(dataset=self.whole_val_set, num_workers=self.opt.threads, 138 | batch_size=self.opt.cacheBatchSize, shuffle=False, 139 | pin_memory=self.opt.cuda, worker_init_fn=self.seed_worker) 140 | self.whole_test_set = self.dataset.get_whole_test_set(self.opt) 141 | self.whole_test_data_loader = DataLoader(dataset=self.whole_test_set, num_workers=self.opt.threads, 142 | batch_size=self.opt.cacheBatchSize, shuffle=False, 143 | pin_memory=self.opt.cuda, worker_init_fn=self.seed_worker) 144 | # for train tuples 145 | self.train_set = self.dataset.get_training_query_set(self.opt, self.opt.margin) 146 | self.training_data_loader = DataLoader(dataset=self.train_set, num_workers=8, batch_size=self.opt.batchSize, 147 | shuffle=True, collate_fn=self.dataset.collate_fn, 148 | worker_init_fn=self.seed_worker) 149 | print('{}:{}, {}:{}, {}:{}, {}:{}, {}:{}'.format('dataset', self.opt.dataset, 'database', 150 | self.whole_train_set.dbStruct.numDb, 'train_set', 151 | self.whole_train_set.dbStruct.numQ, 'val_set', 152 | self.whole_val_set.dbStruct.numQ, 'test_set', 153 | self.whole_test_set.dbStruct.numQ)) 154 | print('{}:{}, {}:{}'.format('cache_bs', self.opt.cacheBatchSize, 'tuple_bs', self.opt.batchSize)) 155 | 156 | def make_model(self): 157 | '''build model 158 | ''' 159 | if self.opt.phase == 'train_tea': 160 | # build teacher net 161 | assert os.path.exists(f'networks/{self.opt.net}.py'), 'Cannot find ' + f'{self.opt.net}.py :(' 162 | network = importlib.import_module('networks.' + self.opt.net) 163 | model = network.deliver_model(self.opt, 'tea') 164 | model = model.to(self.device) 165 | outputs = model(torch.rand((2, 3, self.opt.height, self.opt.width), device=self.device)) 166 | self.opt.output_dim = \ 167 | model(torch.rand((2, 3, self.opt.height, self.opt.width), device=self.device))[0].shape[-1] 168 | self.opt.sigma_dim = \ 169 | model(torch.rand((2, 3, self.opt.height, self.opt.width), device=self.device))[1].shape[-1] # place holder 170 | elif self.opt.phase == 'train_stu': # load teacher net 171 | assert self.opt.resume != '', 'You need to define the teacher/resume path :(' 172 | if exists('tmp'): 173 | shutil.rmtree('tmp') 174 | os.mkdir('tmp') 175 | shutil.copytree(join(dirname(self.opt.resume), 'models'), join('tmp', 'models')) 176 | network = importlib.import_module(f'tmp.models.{self.opt.net}') 177 | model_tea = network.deliver_model(self.opt, 'tea').to(self.device) 178 | checkpoint = torch.load(self.opt.resume) 179 | model_tea.load_state_dict(checkpoint['state_dict']) 180 | # build student net 181 | assert os.path.exists(f'networks/{self.opt.net}.py'), 'Cannot find ' + f'{self.opt.net}.py :(' 182 | network = importlib.import_module('networks.' 
+ self.opt.net) 183 | model = network.deliver_model(self.opt, 'stu').to(self.device) 184 | #checkpointS = torch.load('logs/tri_train_stu_0820_220921/ckpt_e_56.pth.tar') 185 | #model.load_state_dict(checkpointS['state_dict']) 186 | self.opt.output_dim = \ 187 | model(torch.rand((2, 3, self.opt.height, self.opt.width), device=self.device))[0].shape[-1] 188 | self.opt.sigma_dim = \ 189 | model(torch.rand((2, 3, self.opt.height, self.opt.width), device=self.device))[1].shape[-1] 190 | elif self.opt.phase in ['test_tea', 'test_stu']: 191 | # load teacher or student net 192 | assert self.opt.resume != '', 'You need to define a teacher/resume path :(' 193 | if exists('tmp'): 194 | shutil.rmtree('tmp') 195 | os.mkdir('tmp') 196 | shutil.copytree(join(dirname(self.opt.resume), 'models'), join('tmp', 'models')) 197 | network = importlib.import_module('tmp.models.{}'.format(self.opt.net)) 198 | model = network.deliver_model(self.opt, self.opt.phase[-3:]).to(self.device) 199 | checkpoint = torch.load(self.opt.resume) 200 | model.load_state_dict(checkpoint['state_dict']) 201 | 202 | print('{}:{}, {}:{}, {}:{}'.format(model.id, self.opt.net, 'loss', self.opt.loss, 'mu_dim', self.opt.output_dim, 203 | 'sigma_dim', self.opt.sigma_dim if self.opt.phase[-3:] == 'stu' else '-')) 204 | 205 | if self.opt.phase in ['train_tea', 'train_stu']: 206 | # optimizer 207 | if self.opt.optim == 'adam': 208 | optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), self.opt.lr, 209 | weight_decay=self.opt.weightDecay) 210 | scheduler = optim.lr_scheduler.ExponentialLR(optimizer, self.opt.lrGamma, last_epoch=-1, verbose=False) 211 | elif self.opt.optim == 'sgd': 212 | optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=self.opt.lr, 213 | momentum=self.opt.momentum, weight_decay=self.opt.weightDecay) 214 | scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=self.opt.lrStep, gamma=self.opt.lrGamma) 215 | else: 216 | raise NameError('Undefined optimizer :(') 217 | 218 | # loss function 219 | criterion = nn.TripletMarginLoss(margin=self.opt.margin, p=2, reduction='sum').to(self.device) 220 | 221 | if self.opt.nGPU > 1: 222 | model = nn.DataParallel(model) 223 | 224 | if self.opt.phase == 'train_tea': 225 | return model, optimizer, scheduler, criterion 226 | elif self.opt.phase == 'train_stu': 227 | return model_tea, model, optimizer, scheduler, criterion 228 | elif self.opt.phase in ['test_tea', 'test_stu']: 229 | return model 230 | else: 231 | raise NameError('Undefined phase :(') 232 | 233 | def build_embedding_cache(self): 234 | '''build embedding cache, such that we can find the corresponding (p) and (n) with respect to (a) in embedding space 235 | ''' 236 | self.train_set.cache = os.path.join(self.opt.runsPath, self.train_set.whichSet + '_feat_cache.hdf5') 237 | with h5py.File(self.train_set.cache, mode='w') as h5: 238 | h5feat = h5.create_dataset("features", [len(self.whole_train_set), self.opt.output_dim], dtype=np.float32) 239 | with torch.no_grad(): 240 | for iteration, (input, indices) in enumerate(tqdm(self.whole_training_data_loader), 1): 241 | input = input.to(self.device) # torch.Size([32, 3, 154, 154]) ([32, 5, 3, 200, 200]) 242 | emb, _ = self.model(input) 243 | h5feat[indices.detach().numpy(), :] = emb.detach().cpu().numpy() 244 | del input, emb 245 | 246 | def build_embedding_cache_stu(self): 247 | '''build embedding cache, such that we can find the corresponding (p) and (n) with respect to (a) in embedding space 248 | ''' 249 | self.train_set.cache = 
os.path.join(self.opt.runsPath, self.train_set.whichSet + '_feat_cache.hdf5') 250 | with h5py.File(self.train_set.cache, mode='w') as h5: 251 | h5feat = h5.create_dataset("features", [len(self.whole_train_set), self.opt.output_dim], dtype=np.float32) 252 | with torch.no_grad(): 253 | for iteration, (input, indices) in enumerate(tqdm(self.whole_training_data_loader), 1): 254 | input = input.to(self.device) # torch.Size([32, 3, 154, 154]) ([32, 5, 3, 200, 200]) 255 | emb, _ = self.student_net(input) 256 | h5feat[indices.detach().numpy(), :] = emb.detach().cpu().numpy() 257 | del input, emb 258 | 259 | def process_batch(self, batch_inputs): 260 | ''' 261 | process a batch of input 262 | ''' 263 | anchor, positives, negatives, neg_counts, indices = batch_inputs 264 | 265 | # in case we get an empty batch 266 | if anchor is None: 267 | return None, None 268 | 269 | # some reshaping to put query, pos, negs in a single (N, 3, H, W) tensor, where N = batchSize * (nQuery + nPos + n_neg) 270 | B = anchor.shape[0] # ([8, 1, 3, 200, 200]) 271 | n_neg = torch.sum(neg_counts) # tensor(80) = torch.sum(torch.Size([8])) 272 | 273 | input = torch.cat([anchor, positives, negatives]) # ([B, C, H, 200]) 274 | 275 | input = input.to(self.device) # ([96, 1, C, H, W]) 276 | embs, vars = self.model(input) # ([96, D]) 277 | 278 | tuple_loss = 0 279 | # Standard triplet loss (via PyTorch library) 280 | if self.opt.loss == 'tri': 281 | embs_a, embs_p, embs_n = torch.split(embs, [B, B, n_neg]) 282 | for i, neg_count in enumerate(neg_counts): 283 | for n in range(neg_count): 284 | negIx = (torch.sum(neg_counts[:i]) + n).item() 285 | tuple_loss += self.criterion(embs_a[i:i + 1], embs_p[i:i + 1], embs_n[negIx:negIx + 1]) 286 | tuple_loss /= n_neg.float().to(self.device) 287 | 288 | del input, embs, embs_a, embs_p, embs_n 289 | del anchor, positives, negatives 290 | 291 | return tuple_loss, n_neg 292 | 293 | def process_batch_stu(self, batch_inputs): 294 | ''' 295 | process a batch of input 296 | ''' 297 | anchor, positives, negatives, neg_counts, indices = batch_inputs 298 | 299 | # in case we get an empty batch 300 | if anchor is None: 301 | return None, None 302 | 303 | # some reshaping to put query, pos, negs in a single (N, 3, H, W) tensor, where N = batchSize * (nQuery + nPos + n_neg) 304 | B = anchor.shape[0] # ([8, 1, 3, 200, 200]) 305 | n_neg = torch.sum(neg_counts) # tensor(80) = torch.sum(torch.Size([8])) 306 | input = torch.cat([anchor, positives, negatives]) # ([B, C, H, 200]) 307 | 308 | input = input.to(self.device) # ([96, 1, C, H, W]) 309 | embs, vars = self.student_net(input) # ([96, D]) 310 | 311 | anchor = anchor.to(self.device) 312 | with torch.no_grad(): 313 | mu_tea, _ = self.teacher_net(input) # ([B, D]) 314 | # mu_stu, log_sigma_sq = self.student_net(anchor) # ([B, D]), ([B, D]) 315 | 316 | tuple_loss = 0 317 | loss = 0 318 | CKDloss = 0 319 | 320 | # Standard triplet loss (via PyTorch library) 321 | if self.opt.loss == 'tri': 322 | embs_a, embs_p, embs_n = torch.split(embs, [B, B, n_neg]) 323 | vars_a, vars_p, vars_n = torch.split(vars, [B, B, n_neg]) 324 | mu_tea_a, mu_tea_p, mu_tea_n = torch.split(mu_tea, [B, B, n_neg]) 325 | for i, neg_count in enumerate(neg_counts): 326 | for n in range(neg_count): 327 | negIx = (torch.sum(neg_counts[:i]) + n).item() 328 | tuple_loss += self.criterion(embs_a[i:i + 1], embs_p[i:i + 1], embs_n[negIx:negIx + 1]) 329 | CKDloss += self.CKD_loss(embs_a[i:i + 1], embs_p[i:i + 1], embs_n[negIx:negIx + 1], 330 | mu_tea_a[i:i + 1], mu_tea_p[i:i + 1], 
mu_tea_n[negIx:negIx + 1]) 331 | 332 | tuple_loss /= n_neg.float().to(self.device) 333 | CKDloss /= n_neg.float().to(self.device) 334 | 335 | del input, embs, embs_a, embs_p, embs_n 336 | del anchor, positives, negatives 337 | 338 | return loss, n_neg 339 | 340 | def train(self): 341 | not_improved = 0 342 | for epoch in range(self.opt.nEpochs): 343 | self.epoch = epoch 344 | self.current_lr = self.optimizer.state_dict()['param_groups'][0]['lr'] 345 | 346 | # build embedding cache 347 | if self.epoch % self.opt.cacheRefreshEvery == 0: 348 | self.model.eval() 349 | self.build_embedding_cache() 350 | self.model.train() 351 | 352 | # train 353 | tuple_loss_sum = 0 354 | for _, batch_inputs in enumerate(tqdm(self.training_data_loader)): 355 | self.step += 1 356 | 357 | self.optimizer.zero_grad() 358 | tuple_loss, n_neg = self.process_batch(batch_inputs) 359 | if tuple_loss is None: 360 | continue 361 | tuple_loss.backward() 362 | self.optimizer.step() 363 | tuple_loss_sum += tuple_loss.item() 364 | 365 | if self.step % 10 == 0: 366 | wandb.log({'train_tuple_loss': tuple_loss.item()}, step=self.step) 367 | wandb.log({'train_batch_num_neg': n_neg}, step=self.step) 368 | 369 | n_batches = len(self.training_data_loader) 370 | wandb.log({'train_avg_tuple_loss': tuple_loss_sum / n_batches}, step=self.step) 371 | torch.cuda.empty_cache() 372 | self.scheduler.step() 373 | 374 | # val every x epochs 375 | if (self.epoch % self.opt.evalEvery) == 0: 376 | recalls = self.val(self.model) 377 | if recalls[0] > self.best_recalls[0]: 378 | self.best_recalls = recalls 379 | not_improved = 0 380 | else: 381 | not_improved += self.opt.evalEvery 382 | # light log 383 | vars_to_log = [ 384 | 'e={:>2d},'.format(self.epoch), 385 | 'lr={:>.8f},'.format(self.current_lr), 386 | 'tl={:>.4f},'.format(tuple_loss_sum / n_batches), 387 | 'r@1/5/10={:.2f}/{:.2f}/{:.2f}'.format(recalls[0], recalls[1], recalls[2]), 388 | '\n' if not_improved else ' *\n', 389 | ] 390 | light_log(self.opt.runsPath, vars_to_log) 391 | else: 392 | recalls = None 393 | self.save_model(self.model, is_best=not not_improved) 394 | 395 | # stop when not improving for a period 396 | if self.opt.phase == 'train_tea': 397 | if self.opt.patience > 0 and not_improved > self.opt.patience: 398 | print('terminated because performance has not improve for', self.opt.patience, 'epochs') 399 | break 400 | 401 | self.save_model(self.model, is_best=False) 402 | print('best r@1/5/10={:.2f}/{:.2f}/{:.2f}'.format(self.best_recalls[0], self.best_recalls[1], 403 | self.best_recalls[2])) 404 | 405 | return self.best_recalls 406 | 407 | def train_student(self): 408 | not_improved = 0 409 | for epoch in range(self.opt.nEpochs): 410 | self.epoch = epoch 411 | self.current_lr = self.optimizer.state_dict()['param_groups'][0]['lr'] 412 | 413 | # build embedding cache 414 | if self.epoch % self.opt.cacheRefreshEvery == 0: 415 | self.student_net.eval() 416 | self.build_embedding_cache() 417 | self.student_net.train() 418 | # train 419 | tuple_loss_sum = 0 420 | for _, batch_inputs in enumerate(tqdm(self.training_data_loader)): 421 | self.step += 1 422 | 423 | self.optimizer.zero_grad() 424 | tuple_loss, n_neg = self.process_batch_stu(batch_inputs) 425 | if tuple_loss is None: 426 | continue 427 | tuple_loss.backward() 428 | self.optimizer.step() 429 | tuple_loss_sum += tuple_loss.item() 430 | loss_sum = tuple_loss_sum 431 | if self.step % 10 == 0: 432 | wandb.log({'train_tuple_loss': tuple_loss.item()}, step=self.step) 433 | wandb.log({'train_batch_num_neg': n_neg}, 
step=self.step) 434 | 435 | n_batches = len(self.training_data_loader) 436 | wandb.log({'train_avg_tuple_loss': tuple_loss_sum / n_batches}, step=self.step) 437 | wandb.log({'student/epoch_loss': loss_sum / n_batches}, step=self.step) 438 | torch.cuda.empty_cache() 439 | self.scheduler.step() 440 | 441 | # val 442 | if (self.epoch % self.opt.evalEvery) == 0: 443 | recalls = self.val(self.student_net) 444 | if recalls[0] > self.best_recalls[0]: 445 | self.best_recalls = recalls 446 | not_improved = 0 447 | else: 448 | not_improved += self.opt.evalEvery 449 | 450 | light_log(self.opt.runsPath, [ 451 | f'e={self.epoch:>2d},', 452 | f'lr={self.current_lr:>.8f},', 453 | f'tl={loss_sum / n_batches:>.4f},', 454 | f'r@1/5/10={recalls[0]:.2f}/{recalls[1]:.2f}/{recalls[2]:.2f}', 455 | '\n' if not_improved else ' *\n', 456 | ]) 457 | else: 458 | recalls = None 459 | 460 | self.save_model(self.student_net, is_best=False, save_every_epoch=True) 461 | if self.opt.patience > 0 and not_improved > self.opt.patience: 462 | print('terminated because performance has not improve for', self.opt.patience, 'epochs') 463 | break 464 | 465 | print('best r@1/5/10={:.2f}/{:.2f}/{:.2f}'.format(self.best_recalls[0], self.best_recalls[1], 466 | self.best_recalls[2])) 467 | return self.best_recalls 468 | 469 | def val(self, model): 470 | recalls, _ = self.get_recall(model) 471 | for i, n in enumerate([1, 5, 10]): 472 | wandb.log({'{}/{}_r@{}'.format(model.id, self.opt.split, n): recalls[i]}, step=self.step) 473 | # self.writer.add_scalar('{}/{}_r@{}'.format(model.id, self.opt.split, n), recalls[i], self.epoch) 474 | 475 | return recalls 476 | 477 | def test(self): 478 | # recalls, _ = self.get_recall(self.model, save_embs=True) 479 | # print('best r@1/5/10={:.2f}/{:.2f}/{:.2f}'.format(recalls[0], recalls[1], recalls[2])) 480 | self.test4image(self.model) 481 | # return recalls 482 | return None 483 | 484 | def input_transform(self, opt=None): 485 | return transforms.Compose([ 486 | transforms.ToTensor(), 487 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 488 | transforms.Resize((opt.height, opt.width), interpolation=InterpolationMode.BILINEAR), 489 | ]) 490 | 491 | def test4image(self, model): 492 | model.eval() 493 | 494 | if self.opt.split == 'val': 495 | eval_dataloader = self.whole_val_data_loader 496 | eval_set = self.whole_val_set 497 | elif self.opt.split == 'test': 498 | eval_dataloader = self.whole_test_data_loader 499 | eval_set = self.whole_test_set 500 | # print(f"{self.opt.split} len:{len(eval_set)}") 501 | whole_mu = torch.zeros((len(eval_set), self.opt.output_dim), device=self.device) # (N, D) 502 | whole_var = torch.zeros((len(eval_set), self.opt.sigma_dim), device=self.device) # (N, D) 503 | mu_in = torch.zeros((1, self.opt.output_dim), device=self.device) 504 | whole_input = torch.zeros((len(eval_set), 1), device=self.device) 505 | gt = eval_set.get_positives() # (N, n_pos) 506 | 507 | with torch.no_grad(): 508 | inputimage_path = "pittsburgh/query/004/004828_pitch2_yaw11.jpg" 509 | input_transform = self.input_transform(self.opt) 510 | inputimage = input_transform(Image.open(inputimage_path)) 511 | inputimage_device = inputimage.unsqueeze(0) 512 | inputimage_device = inputimage_device.to(self.device) 513 | mu_inputimage, var_inputimage = model(inputimage_device) 514 | mu_in[0, :] = mu_inputimage 515 | 516 | del mu_inputimage, var_inputimage 517 | for iteration, (input, indices) in enumerate(tqdm(eval_dataloader), 1): 518 | input = input.to(self.device) 519 | mu, var = 
model(input) # (B, D) 520 | # print(input) #(128,3,224,224) 521 | # var = torch.exp(var) 522 | 523 | whole_mu[indices, :] = mu 524 | whole_var[indices, :] = var 525 | del input, mu, var 526 | n_values = [3] 527 | 528 | whole_var = torch.exp(whole_var) 529 | whole_mu = whole_mu.cpu().numpy() 530 | whole_var = whole_var.cpu().numpy() 531 | mu_in = mu_in.cpu().numpy() 532 | # print(mu_in.shape) 533 | mu_inquery = mu_in[:1].astype('float32') 534 | # print(mu_inquery.shape) 535 | mu_q = whole_mu[eval_set.dbStruct.numDb:].astype('float32') 536 | # print(mu_q.shape) 537 | mu_db = whole_mu[:eval_set.dbStruct.numDb].astype('float32') 538 | sigma_q = whole_var[eval_set.dbStruct.numDb:].astype('float32') 539 | sigma_db = whole_var[:eval_set.dbStruct.numDb].astype('float32') 540 | faiss_index = faiss.IndexFlatL2(mu_q.shape[1]) 541 | faiss_index.add(mu_db) 542 | dists, preds = faiss_index.search(mu_q, max(n_values)) # the results is sorted 543 | 544 | dists_input, preds_input = faiss_index.search(mu_inquery, max(n_values)) 545 | print(preds_input[0, 0]) 546 | print(dists_input[0, 0]) 547 | pair_index = preds_input[0, 0] 548 | 549 | structFile = join(self.opt.structDir, 'pitts30k_test.mat') 550 | self.dbStruct = parse_dbStruct(structFile) 551 | image_pair_path = join('pittsburgh', 'database', self.dbStruct.dbImage[pair_index]) 552 | print(image_pair_path) 553 | image_pair = Image.open(image_pair_path) 554 | image_pair.save('output_image.jpg') 555 | # img_dir= 556 | # path=join(img_dir, 'database', dbIm) 557 | 558 | return None 559 | 560 | def save_model(self, model, is_best=False, save_every_epoch=False): 561 | if is_best: 562 | torch.save({ 563 | 'epoch': self.epoch, 564 | 'step': self.step, 565 | 'state_dict': model.state_dict(), 566 | 'optimizer': self.optimizer.state_dict(), 567 | 'scheduler': self.scheduler.state_dict(), 568 | }, os.path.join(self.opt.runsPath, 'ckpt_best.pth.tar')) 569 | 570 | if save_every_epoch: 571 | torch.save({ 572 | 'epoch': self.epoch, 573 | 'step': self.step, 574 | 'state_dict': model.state_dict(), 575 | 'optimizer': self.optimizer.state_dict(), 576 | 'scheduler': self.scheduler.state_dict(), 577 | }, os.path.join(self.opt.runsPath, 'ckpt_e_{}.pth.tar'.format(self.epoch))) 578 | 579 | def get_recall(self, model, save_embs=False): 580 | model.eval() 581 | 582 | if self.opt.split == 'val': 583 | eval_dataloader = self.whole_val_data_loader 584 | eval_set = self.whole_val_set 585 | elif self.opt.split == 'test': 586 | eval_dataloader = self.whole_test_data_loader 587 | eval_set = self.whole_test_set 588 | # print(f"{self.opt.split} len:{len(eval_set)}") 589 | 590 | whole_mu = torch.zeros((len(eval_set), self.opt.output_dim), device=self.device) # (N, D) 591 | whole_var = torch.zeros((len(eval_set), self.opt.sigma_dim), device=self.device) # (N, D) 592 | gt = eval_set.get_positives() # (N, n_pos) 593 | start_time = time.time() 594 | with torch.no_grad(): 595 | for iteration, (input, indices) in enumerate(tqdm(eval_dataloader), 1): 596 | # print(f"Batch {iteration}, Indices: {indices}") 597 | input = input.to(self.device) 598 | mu, var = model(input) # (B, D) 599 | # summary(self.model, input_size=input.shape[1:]) 600 | # print(input.shape) 601 | # var = torch.exp(var) 602 | whole_mu[indices, :] = mu 603 | whole_var[indices, :] = var 604 | del input, mu, var 605 | end_time = time.time() 606 | 607 | elapsed_time = end_time - start_time 608 | print("Elapsed Time:", elapsed_time) 609 | n_values = [1, 5, 10] 610 | 611 | whole_var = torch.exp(whole_var) 612 | whole_mu = 
whole_mu.cpu().numpy() 613 | whole_var = whole_var.cpu().numpy() 614 | mu_q = whole_mu[eval_set.dbStruct.numDb:].astype('float32') 615 | mu_db = whole_mu[:eval_set.dbStruct.numDb].astype('float32') 616 | sigma_q = whole_var[eval_set.dbStruct.numDb:].astype('float32') 617 | sigma_db = whole_var[:eval_set.dbStruct.numDb].astype('float32') 618 | faiss_index = faiss.IndexFlatL2(mu_q.shape[1]) 619 | faiss_index.add(mu_db) 620 | dists, preds = faiss_index.search(mu_q, max(n_values)) # the results is sorted 621 | 622 | # cull queries without any ground truth positives in the database 623 | val_inds = [True if len(gt[ind]) != 0 else False for ind in range(len(gt))] 624 | val_inds = np.array(val_inds) 625 | mu_q = mu_q[val_inds] 626 | sigma_q = sigma_q[val_inds] 627 | preds = preds[val_inds] 628 | dists = dists[val_inds] 629 | gt = gt[val_inds] 630 | 631 | recall_at_k = cal_recall(preds, gt, n_values) 632 | 633 | if save_embs: 634 | with open(join(self.opt.runsPath, '{}_db_embeddings_{}.pickle'.format(self.opt.split, 635 | self.opt.resume.split('.')[-3].split( 636 | '_')[-1])), 'wb') as handle: 637 | pickle.dump(mu_q, handle, protocol=pickle.HIGHEST_PROTOCOL) 638 | pickle.dump(mu_db, handle, protocol=pickle.HIGHEST_PROTOCOL) 639 | pickle.dump(sigma_q, handle, protocol=pickle.HIGHEST_PROTOCOL) 640 | pickle.dump(sigma_db, handle, protocol=pickle.HIGHEST_PROTOCOL) 641 | pickle.dump(preds, handle, protocol=pickle.HIGHEST_PROTOCOL) 642 | pickle.dump(dists, handle, protocol=pickle.HIGHEST_PROTOCOL) 643 | pickle.dump(gt, handle, protocol=pickle.HIGHEST_PROTOCOL) 644 | pickle.dump(whole_mu, handle, protocol=pickle.HIGHEST_PROTOCOL) 645 | pickle.dump(whole_var, handle, protocol=pickle.HIGHEST_PROTOCOL) 646 | print('embeddings saved for post processing') 647 | 648 | return recall_at_k, None 649 | 650 | 651 | def parse_dbStruct(path): 652 | mat = loadmat(path) 653 | matStruct = mat['dbStruct'].item() 654 | 655 | dataset = 'nuscenes' 656 | 657 | whichSet = matStruct[0].item() 658 | 659 | # .mat file is generated by python, Kaiwen replaces the use of cell (in Matlab) with char (in Python) 660 | dbImage = [f[0].item() for f in matStruct[1]] 661 | # dbImage = matStruct[1] 662 | utmDb = matStruct[2].T 663 | # utmDb = matStruct[2] 664 | 665 | # .mat file is generated by python, I replace the use of cell (in Matlab) with char (in Python) 666 | qImage = [f[0].item() for f in matStruct[3]] 667 | # qImage = matStruct[3] 668 | utmQ = matStruct[4].T 669 | # utmQ = matStruct[4] 670 | 671 | numDb = matStruct[5].item() 672 | numQ = matStruct[6].item() 673 | 674 | posDistThr = matStruct[7].item() 675 | posDistSqThr = matStruct[8].item() 676 | nonTrivPosDistSqThr = matStruct[9].item() 677 | 678 | return dbStruct(whichSet, dataset, dbImage, utmDb, qImage, utmQ, numDb, numQ, posDistThr, posDistSqThr, 679 | nonTrivPosDistSqThr) 680 | 681 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from os.path import dirname, join 2 | 3 | import trainer 4 | from options import Options 5 | 6 | options_handler = Options() 7 | options = options_handler.parse() 8 | 9 | if __name__ == "__main__": 10 | 11 | if options.phase in ['test_tea', 'test_stu', 'train_stu']: 12 | print(f'resume from {options.resume}') 13 | options = options_handler.update_opt_from_json(join(dirname(options.resume), 'flags.json'), options) 14 | options.nEpochs = 200 15 | tr = trainer.Trainer(options) 16 | print(tr.opt.phase, '-->', 
tr.opt.runsPath) 17 | elif options.phase in ['train_tea']: 18 | tr = trainer.Trainer(options) 19 | print(tr.opt.phase, '-->', tr.opt.runsPath) 20 | 21 | if options.phase in ['train_tea']: 22 | tr.train() 23 | elif options.phase in ['train_stu']: 24 | tr.train_student() 25 | elif options.phase in ['test_tea', 'test_stu']: 26 | tr.test() -------------------------------------------------------------------------------- /netvlad.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | 6 | class NetVLADLoupe(nn.Module): 7 | def __init__(self, feature_size, max_samples, cluster_size, output_dim, 8 | gating=True, add_batch_norm=True, is_training=True): 9 | super(NetVLADLoupe, self).__init__() 10 | self.feature_size = feature_size 11 | self.max_samples = max_samples 12 | self.output_dim = output_dim 13 | self.is_training = is_training 14 | self.gating = gating 15 | self.add_batch_norm = add_batch_norm 16 | self.cluster_size = cluster_size 17 | self.softmax = nn.Softmax(dim=-1) 18 | 19 | self.cluster_weights = nn.Parameter(torch.randn( 20 | feature_size, cluster_size) * 1 / math.sqrt(feature_size)) 21 | self.cluster_weights2 = nn.Parameter(torch.randn( 22 | 1, feature_size, cluster_size) * 1 / math.sqrt(feature_size)) 23 | self.hidden1_weights = nn.Parameter(torch.randn( 24 | cluster_size * feature_size, output_dim) * 1 / math.sqrt(feature_size)) 25 | 26 | if add_batch_norm: 27 | self.cluster_biases = None 28 | self.bn1 = nn.BatchNorm1d(cluster_size) 29 | else: 30 | self.cluster_biases = nn.Parameter(torch.randn( 31 | cluster_size) * 1 / math.sqrt(feature_size)) 32 | self.bn1 = None 33 | 34 | self.bn2 = nn.BatchNorm1d(output_dim) 35 | 36 | if gating: 37 | self.context_gating = GatingContext( 38 | output_dim, add_batch_norm=add_batch_norm) 39 | 40 | def forward(self, x): 41 | x = x.transpose(1, 3).contiguous() 42 | x = x.view((-1, self.max_samples, self.feature_size)) 43 | activation = torch.matmul(x, self.cluster_weights) 44 | if self.add_batch_norm: 45 | activation = activation.view(-1, self.cluster_size) 46 | activation = self.bn1(activation) 47 | activation = activation.view(-1, self.max_samples, self.cluster_size) 48 | else: 49 | activation = activation + self.cluster_biases 50 | activation = self.softmax(activation) 51 | activation = activation.view((-1, self.max_samples, self.cluster_size)) 52 | 53 | a_sum = activation.sum(-2, keepdim=True) 54 | a = a_sum * self.cluster_weights2 55 | 56 | activation = torch.transpose(activation, 2, 1) 57 | x = x.view((-1, self.max_samples, self.feature_size)) 58 | vlad = torch.matmul(activation, x) 59 | vlad = torch.transpose(vlad, 2, 1) 60 | vlad = vlad - a 61 | 62 | vlad = F.normalize(vlad, dim=1, p=2) 63 | vlad = vlad.reshape((-1, self.cluster_size * self.feature_size)) 64 | vlad = F.normalize(vlad, dim=1, p=2) 65 | vlad = torch.matmul(vlad, self.hidden1_weights) 66 | 67 | if self.gating: 68 | vlad = self.context_gating(vlad) 69 | 70 | return vlad 71 | 72 | 73 | class GatingContext(nn.Module): 74 | def __init__(self, dim, add_batch_norm=True): 75 | super(GatingContext, self).__init__() 76 | self.dim = dim 77 | self.add_batch_norm = add_batch_norm 78 | self.gating_weights = nn.Parameter( 79 | torch.randn(dim, dim) * 1 / math.sqrt(dim)) 80 | self.sigmoid = nn.Sigmoid() 81 | 82 | if add_batch_norm: 83 | self.gating_biases = None 84 | self.bn1 = nn.BatchNorm1d(dim) 85 | else: 86 | self.gating_biases = nn.Parameter( 87 | torch.randn(dim) 
* 1 / math.sqrt(dim))
88 |             self.bn1 = None
89 | 
90 |     def forward(self, x):
91 |         gates = torch.matmul(x, self.gating_weights)
92 | 
93 |         if self.add_batch_norm:
94 |             gates = self.bn1(gates)
95 |         else:
96 |             gates = gates + self.gating_biases
97 | 
98 |         gates = self.sigmoid(gates)
99 |         activation = x * gates
100 | 
101 |         return activation
102 | 
103 | if __name__ == '__main__':
104 |     net_vlad = NetVLADLoupe(feature_size=512, max_samples=224, cluster_size=64,
105 |                             output_dim=256, gating=True, add_batch_norm=False,
106 |                             is_training=True)
107 |     inputs = torch.rand((1, 512, 224, 1))
108 |     outputs_tea = net_vlad(inputs)
109 |     print(outputs_tea.shape)
110 | 
--------------------------------------------------------------------------------
/networks/tscm.py:
--------------------------------------------------------------------------------
1 | #%%
2 | import sys
3 | 
4 | sys.path.append('..')
5 | from re import L  # note: unused import
6 | import torch
7 | import torch.nn as nn
8 | from torch.nn.parameter import Parameter
9 | import cirtorch.functional as LF
10 | import math
11 | import torch.nn.functional as F
12 | import torch  # note: duplicate of the import above
13 | import timm
14 | from einops import rearrange, reduce, repeat
15 | from einops.layers.torch import Rearrange, Reduce
16 | import torchvision.models as models
17 | from netvlad import NetVLADLoupe
18 | from torch import Tensor
19 | class L2Norm(nn.Module):
20 |     def __init__(self, dim=1):
21 |         super().__init__()
22 |         self.dim = dim
23 | 
24 |     def forward(self, input):
25 |         return F.normalize(input, p=2, dim=self.dim)
26 | 
27 | 
28 | class PatchEmbedding(nn.Module):
29 |     def __init__(self, in_channels: int = 3, patch_size: int = 16, emb_size: int = 768):
30 |         self.patch_size = patch_size
31 |         super().__init__()
32 |         self.projection = nn.Sequential(
33 |             # use a convolution instead of a linear projection -> better performance
34 |             nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=patch_size),
35 |             # flatten the patches produced by the convolution
36 |             Rearrange('b e h w -> b (h w) e'),
37 |         )
38 | 
39 |     def forward(self, x: Tensor) -> Tensor:
40 |         x = self.projection(x)
41 |         return x
42 | 
43 | 
44 | class GeM(nn.Module):
45 |     def __init__(self, p=3, eps=1e-6):
46 |         super(GeM, self).__init__()
47 |         self.p = Parameter(torch.ones(1) * p)
48 |         self.eps = eps
49 | 
50 |     def forward(self, x):
51 |         return LF.gem(x, p=self.p, eps=self.eps)
52 | 
53 |     def __repr__(self):
54 |         return self.__class__.__name__ + '(' + 'p=' + '{:.4f}'.format(self.p.data.tolist()[0]) + ', ' + 'eps=' + str(self.eps) + ')'
55 | 
56 | 
57 | class FeedForward(nn.Module):
58 |     def __init__(self, d_model, d_ff=1024, dropout=0.1):
59 |         super().__init__()
60 | 
61 |         self.linear_1 = nn.Linear(d_model, d_ff)
62 |         self.dropout = nn.Dropout(dropout)
63 |         self.linear_2 = nn.Linear(d_ff, d_model)
64 | 
65 |     def forward(self, x):
66 |         x = self.dropout(F.relu(self.linear_1(x)))
67 |         x = self.linear_2(x)
68 |         return x
69 | 
70 | 
71 | class Norm(nn.Module):
72 |     def __init__(self, d_model, eps=1e-6):
73 |         super().__init__()
74 | 
75 |         self.size = d_model
76 |         self.alpha = nn.Parameter(torch.ones(self.size))
77 |         self.bias = nn.Parameter(torch.zeros(self.size))
78 |         self.eps = eps
79 | 
80 |     def forward(self, x):
81 |         norm = self.alpha * (x - x.mean(dim=-1, keepdim=True)) \
82 |             / (x.std(dim=-1, keepdim=True) + self.eps) + self.bias
83 |         return norm
84 | 
85 | 
86 | def attention(q, k, v, d_k, mask=None, dropout=None):
87 |     scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
88 | 
89 |     if mask is not None:
90 |         mask = mask.unsqueeze(1)
91 |         scores = scores.masked_fill(mask == 0, -1e9)
92 | 
93 |     scores = F.softmax(scores, dim=-1)
94 | 
95 |     if dropout is not None:
96 |         scores = dropout(scores)
97 | 
98 |     output = torch.matmul(scores, v)
99 |     return output
100 | 
101 | 
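# --- Added sketch (not part of the original file): a quick shape sanity check for
# --- the scaled dot-product attention defined above; call it manually if useful.
def _attention_shape_demo():
    B, h, S, d_k = 2, 4, 10, 32                 # batch, heads, tokens, per-head dim
    q = torch.rand(B, h, S, d_k)
    k = torch.rand(B, h, S, d_k)
    v = torch.rand(B, h, S, d_k)
    out = attention(q, k, v, d_k)               # softmax(q k^T / sqrt(d_k)) v
    assert out.shape == (B, h, S, d_k)          # values keep their shape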
102 | class MultiHeadAttention(nn.Module):
103 |     def __init__(self, heads, d_model, dropout=0.1):
104 |         super().__init__()
105 | 
106 |         self.d_model = d_model
107 |         self.d_k = d_model // heads
108 |         self.h = heads
109 | 
110 |         self.q_linear = nn.Linear(d_model, d_model)
111 |         self.v_linear = nn.Linear(d_model, d_model)
112 |         self.k_linear = nn.Linear(d_model, d_model)
113 | 
114 |         self.dropout = nn.Dropout(dropout)
115 |         self.out = nn.Linear(d_model, d_model)
116 | 
117 |     def forward(self, q, k, v, mask=None):
118 |         bs = q.size(0)
119 | 
120 |         k = self.k_linear(k).view(bs, -1, self.h, self.d_k)
121 |         q = self.q_linear(q).view(bs, -1, self.h, self.d_k)
122 |         v = self.v_linear(v).view(bs, -1, self.h, self.d_k)
123 | 
124 |         k = k.transpose(1, 2)
125 |         q = q.transpose(1, 2)
126 |         v = v.transpose(1, 2)
127 | 
128 |         scores = attention(q, k, v, self.d_k, mask, self.dropout)
129 |         concat = scores.transpose(1, 2).contiguous() \
130 |             .view(bs, -1, self.d_model)
131 |         output = self.out(concat)
132 | 
133 |         return output
134 | 
135 | 
136 | class Shen(nn.Module):  # fuses the ViT and ResNet branches
137 |     def __init__(self, opt=None):
138 |         super().__init__()
139 |         heads = 4
140 |         d_model = 512
141 |         dropout = 0.1
142 |         resnet50 = models.resnet50(pretrained=True)
143 |         Vit = timm.create_model('vit_base_patch16_224', pretrained=False)
144 |         self.linear = nn.Linear(768, 512)
145 |         self.linear2 = nn.Linear(1024, 512)
146 |         featuresV = list(Vit.children())[:-1]  # ViT
147 |         featuresR = list(resnet50.children())[:-3]  # ResNet without its last stage
148 |         self.backboneVV = nn.Sequential(*featuresV)
149 |         self.backboneV = nn.Sequential(*featuresV, ClassificationHead(), self.linear)
150 |         self.Classification = ClassificationHead()
151 |         self.backboneRR = nn.Sequential(*featuresR)
152 |         self.backboneR = nn.Sequential(*featuresR, GeM(), nn.Flatten())
153 |         self.gem = GeM()
154 |         self.Fl = nn.Flatten()
155 | 
156 |         self.HW = Rearrange('b e h w -> b (h w) e')
157 | 
158 |         self.attn1 = MultiHeadAttention(heads, d_model, dropout=dropout)
159 |         self.attn2 = MultiHeadAttention(heads, d_model, dropout=dropout)
160 | 
161 | 
162 |         self.ff1 = FeedForward(d_model, dropout=dropout)
163 |         self.ff2 = FeedForward(d_model, dropout=dropout)
164 | 
165 | 
166 | 
167 |         self.net_vlad = NetVLADLoupe(feature_size=512, max_samples=784, cluster_size=64,
168 |                                      output_dim=512, gating=True, add_batch_norm=False,
169 |                                      is_training=True)
170 |         self.net_vlad_R = NetVLADLoupe(feature_size=256, max_samples=392, cluster_size=64,
171 |                                        output_dim=256, gating=True, add_batch_norm=False,
172 |                                        is_training=True)
173 |         self.net_vlad_V = NetVLADLoupe(feature_size=256, max_samples=392, cluster_size=64,
174 |                                        output_dim=256, gating=True, add_batch_norm=False,
175 |                                        is_training=True)
176 |     def forward(self, inputs):
177 |         # ViT branch
178 |         outVV = self.backboneVV(inputs)  # (B, S, C)
179 |         feature_V = self.linear(outVV)
180 |         # Res branch
181 |         outRR = self.backboneRR(inputs)
182 |         outR = self.gem(outRR)
183 |         outR = self.Fl(outR)  # (B, C) (1*1024) for the final concatenation
184 |         outRR = self.HW(outRR)  # (B, S, C)
185 |         feature_R = self.linear2(outRR)
186 |         # Inter-Transformer encoder
187 |         feature_fuse1 = feature_V + self.attn1(feature_V, feature_R, feature_R, mask=None)
188 |         feature_fuse1 = feature_fuse1 + self.ff1(feature_fuse1)
189 |         feature_fuse2 = feature_R + self.attn2(feature_R, feature_V, feature_V, mask=None)
190 |         feature_fuse2 = feature_fuse2 + self.ff2(feature_fuse2)
191 |         feature_fuse = torch.cat((feature_fuse1, feature_fuse2), dim=-2)
192 |         feature_cat_origin = torch.cat((feature_V, feature_R), dim=-2)
193 |         feature_fuse = torch.cat((feature_fuse, feature_cat_origin), dim=-1)
194 | 
195 |         # descriptor from the Inter-Transformer encoder (1*512)
196 |         feature_fuse = feature_fuse.permute(0, 2, 1)
197 |         feature_com = feature_fuse.unsqueeze(3)
198 |         feature_com = self.net_vlad(feature_com)
199 | 
200 |         # descriptor from Res (1*256)
201 |         feature_R = feature_R.permute(0, 2, 1)
202 |         feature_R = feature_R.unsqueeze(-1)
203 |         feature_R_enhanced = self.net_vlad_R(feature_R)
204 |         # descriptor from ViT (1*256)
205 |         feature_V = feature_V.permute(0, 2, 1)
206 |         feature_V = feature_V.unsqueeze(-1)
207 |         feature_V_enhanced = self.net_vlad_V(feature_V)
208 | 
209 |         # concatenate all descriptors
210 |         feature_com = torch.cat((feature_R_enhanced, feature_com), dim=1)
211 |         feature_com = torch.cat((feature_com, feature_V_enhanced), dim=1)
212 |         feature_com = torch.cat((feature_com, outR), dim=1)
213 | 
214 |         return feature_com
215 | 
216 | 
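# --- Added note (inferred from the layer definitions above): the teacher descriptor
# --- is the concatenation of four parts, so its dimension is
# ---     256 (net_vlad_R) + 512 (net_vlad) + 256 (net_vlad_V)
# ---   + 1024 (GeM over the ResNet50 layer3 feature map) = 2048,
# --- which matches mu_dim = sigma_dim = 2048 declared in Backbone below.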
217 | class ClassificationHead(nn.Sequential):
218 |     def __init__(self, emb_size: int = 768, n_classes: int = 1000):
219 |         super().__init__(
220 |             Reduce('b n e -> b e', reduction='mean'))
221 | 
222 | class Backbone(nn.Module):
223 |     def __init__(self, opt=None):
224 |         super().__init__()
225 | 
226 |         self.sigma_dim = 2048
227 |         self.mu_dim = 2048
228 | 
229 |         self.backbone = Shen()
230 | 
231 | 
232 | class Stu_Backbone(nn.Module):
233 |     def __init__(self):
234 |         super(Stu_Backbone, self).__init__()
235 |         resnet50 = models.resnet50(pretrained=True)
236 |         featuresR = list(resnet50.children())[:-3]  # ResNet50 without layer4, avgpool and fc
237 | 
238 |         self.gem = GeM()
239 |         self.Fl = nn.Flatten()
240 |         self.backboneR_Stu = nn.Sequential(*featuresR)
241 |         self.linear0 = nn.Linear(256, 1024)
242 | 
243 |         self.cHead = ClassificationHead()
244 |         self.net_vlad = NetVLADLoupe(feature_size=196, max_samples=768, cluster_size=64,
245 |                                      output_dim=256, gating=True, add_batch_norm=False,
246 |                                      is_training=True)
247 | 
248 |         self.HW = Rearrange('b e h w -> b (h w) e')
249 |         self.PatchEmbedding = PatchEmbedding()
250 |     def forward(self, inputs):
251 |         # Res branch (1*1024)
252 |         outRR = self.backboneR_Stu(inputs)
253 |         outR = self.gem(outRR)
254 |         outR = self.Fl(outR)  # (B, C) for the final concatenation
255 |         # descriptor (1*1024)
256 |         outVV = self.PatchEmbedding(inputs)  # (B, S, C)
257 |         feature_V = outVV.permute(0, 2, 1)
258 |         feature_V = feature_V.unsqueeze(-1)
259 |         feature_V_enhanced = self.net_vlad(feature_V)
260 |         outV = self.linear0(feature_V_enhanced)
261 | 
262 |         # concatenation
263 |         feature_fuse = torch.cat((outV, outR), dim=-1)
264 | 
265 | 
266 |         return feature_fuse
267 | 
268 | 
269 | class TeacherNet(Backbone):
270 |     def __init__(self, opt=None):
271 |         super().__init__()
272 |         self.id = 'teacher'
273 |         self.mean_head = nn.Sequential(L2Norm(dim=1))
274 | 
275 |     def forward(self, inputs):
276 |         B, C, H, W = inputs.shape  # (B, 3, 224, 224)
277 |         # inputs = inputs.view(B * L, C, H, W)  # ([B, 3, 224, 224])
278 | 
279 |         backbone_output = self.backbone(inputs)  # ([B, 2048])
280 |         mu = self.mean_head(backbone_output).view(B, -1)  # ([B, 2048])
281 | 
282 |         return mu, torch.zeros_like(mu)
283 | 
284 | 
285 | class StudentNet(TeacherNet):
286 |     def __init__(self, opt=None):
287 |         super().__init__()
288 |         self.id = 'student'
289 |         self.var_head = nn.Sequential(nn.Linear(2048, self.sigma_dim), nn.Sigmoid())
290 |         self.backboneS = Stu_Backbone()
291 |     def forward(self, inputs):
292 |         B, C, H, W = inputs.shape  # (B, 3, 224, 224)
293 |         inputs = inputs.view(B, C, H, W)  # ([B, 3, 224, 224])
294 |         backbone_output = self.backboneS(inputs)
295 | 
296 |         mu = self.mean_head(backbone_output).view(B, -1)  # ([B, 2048])
297 |         log_sigma_sq = self.var_head(backbone_output).view(B, -1)  # ([B, 2048])
298 | 
299 |         return mu, log_sigma_sq
300 | 
301 | 
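# --- Added sketch (not part of the original file): both nets return a pair
# --- (mu, log_sigma_sq); the teacher's second output is an all-zero placeholder.
# --- Downstream code (see Trainer.get_recall) recovers the variance with exp():
def _uncertainty_demo(model, images):
    mu, log_sigma_sq = model(images)        # each (B, 2048)
    sigma_sq = torch.exp(log_sigma_sq)      # per-dimension variance of the embedding
    return mu, sigma_sq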
302 | def deliver_model(opt, id):
303 |     if id == 'tea':
304 |         return TeacherNet(opt)
305 |     elif id == 'stu':
306 |         return StudentNet(opt)
307 | 
308 | 
309 | if __name__ == '__main__':
310 |     tea = TeacherNet()
311 |     stu = StudentNet()
312 |     inputs = torch.rand((1, 3, 224, 224))
313 |     outputs_tea = tea(inputs)
314 |     outputs_stu = stu(inputs)
315 |     # print(outputs_tea.shape)
316 |     # print(outputs_stu.shape)
317 |     # print(tea.state_dict())
318 |     print(outputs_tea[0].shape, outputs_tea[1].shape)  # expected: torch.Size([1, 2048]) torch.Size([1, 2048])
319 |     print(outputs_stu[0].shape, outputs_stu[1].shape)  # expected: torch.Size([1, 2048]) torch.Size([1, 2048])
320 | 
--------------------------------------------------------------------------------
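The snippet below is an illustrative sketch (not a file in the repo) of embedding a single image with a trained student net; the checkpoint layout (`'state_dict'`, etc.) matches `Trainer.save_model` in trainer.py, and the checkpoint path is a placeholder.

```python
import torch
from networks.tscm import deliver_model

model = deliver_model(opt=None, id='stu').cuda().eval()
ckpt = torch.load('logs/tri_train_stu_xxx/ckpt_best.pth.tar')  # placeholder path
model.load_state_dict(ckpt['state_dict'])

with torch.no_grad():
    img = torch.rand((1, 3, 224, 224), device='cuda')  # stand-in for a real query image
    mu, log_sigma_sq = model(img)                      # both (1, 2048)
```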
/options.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | import random
5 | 
6 | import numpy as np
7 | import torch
8 | from genericpath import exists
9 | 
10 | 
11 | class Options:
12 |     def __init__(self):
13 |         self.parser = argparse.ArgumentParser(description="Options")
14 |         self.parser.add_argument('--phase', type=str, default='test_tea', help='phase', choices=['train_tea', 'test_tea', 'train_stu', 'test_stu'])
15 |         self.parser.add_argument('--dataset', type=str, default='pitts', help='choose dataset.')
16 |         self.parser.add_argument('--structDir', type=str, default='pittsburgh/structure', help='Path for structure.')
17 |         self.parser.add_argument('--imgDir', type=str, default='pittsburgh', help='Path for images.')
18 |         self.parser.add_argument('--com', type=str, default='', help='comment')
19 |         self.parser.add_argument('--height', type=int, default=224, help='image height.')
20 |         self.parser.add_argument('--width', type=int, default=224, help='image width.')
21 |         self.parser.add_argument('--net', type=str, default='tscm', help='network')
22 |         self.parser.add_argument('--trainer', type=str, default='trainer', help='trainer')
23 |         self.parser.add_argument('--loss', type=str, default='tri', help='triplet, contrastive or quadruplet loss', choices=['tri', 'cont', 'quad'])
24 |         self.parser.add_argument('--margin', type=float, default=0.1, help='Margin for triplet loss. Default=0.1')
25 |         self.parser.add_argument('--margin2', type=float, default=0.1, help='Margin2 for quadruplet loss. Default=0.1')
26 |         self.parser.add_argument('--output_dim', type=int, default=0, help='Feature dimension; 0 means inferred from the network at startup.')
27 |         self.parser.add_argument('--sigma_dim', type=int, default=0, help='Sigma dimension; 0 means inferred from the network at startup.')
28 |         self.parser.add_argument('--batchSize', type=int, default=8, help='Number of triplets (query, pos, negs). Each triplet consists of 12 images.')
29 |         self.parser.add_argument('--cacheBatchSize', type=int, default=128, help='Batch size for caching and testing')
30 |         self.parser.add_argument('--cacheRefreshRate', type=int, default=0, help='How often to refresh cache, in number of queries. 0 for off')
31 |         self.parser.add_argument('--nEpochs', type=int, default=200, help='number of epochs to train for')
32 |         self.parser.add_argument('--nGPU', type=int, default=1, help='number of GPUs to use.')
33 |         self.parser.add_argument('--cGPU', type=int, default=2, help='index of the GPU to use.')
34 |         self.parser.add_argument('--optim', type=str, default='adam', help='optimizer to use', choices=['sgd', 'adam'])
35 |         self.parser.add_argument('--lr', type=float, default=1e-5, help='Learning Rate.')
36 |         self.parser.add_argument('--lrStep', type=float, default=5, help='Decay LR every N steps.')
37 |         self.parser.add_argument('--lrGamma', type=float, default=0.99, help='Multiply LR by Gamma for decaying.')
38 |         self.parser.add_argument('--weightDecay', type=float, default=0.001, help='Weight decay for SGD.')
39 |         self.parser.add_argument('--momentum', type=float, default=0.9, help='Momentum for SGD.')
40 |         self.parser.add_argument('--cuda', action='store_false', help='CUDA is on by default; pass this flag to disable it.')
41 |         self.parser.add_argument('--d', action='store_true', help='debug mode')
42 |         self.parser.add_argument('--threads', type=int, default=8, help='Number of threads for each data loader to use')
43 |         self.parser.add_argument('--seed', type=int, default=1234, help='Random seed to use.')
44 |         self.parser.add_argument('--logsPath', type=str, default='./logs', help='Path to save runs to.')
45 |         self.parser.add_argument('--runsPath', type=str, default='not defined', help='Path to save runs to.')
46 |         self.parser.add_argument('--resume', type=str, default='', help='Path to load checkpoint from, for resuming training or testing.')
47 |         self.parser.add_argument('--evalEvery', type=int, default=1, help='Do a validation set run, and save, every N epochs.')
48 |         self.parser.add_argument('--cacheRefreshEvery', type=int, default=1, help='refresh embedding cache, every N epochs.')
49 |         self.parser.add_argument('--patience', type=int, default=10, help='Patience for early stopping. 0 is off.')
50 |         self.parser.add_argument('--split', type=str, default='val', help='Split to use', choices=['val', 'test'])
51 |         self.parser.add_argument('--encoder_dim', type=int, default=512, help='Number of feature dimension. Default=512')
52 | 
53 |     def parse(self):
54 |         options = self.parser.parse_args()
55 |         return options
56 | 
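    # --- Added note (sketch): update_opt_from_json below re-applies the flags stored
    # --- at training time (flags.json, written by Trainer.make_folders) on top of the
    # --- current command line, skipping keys in do_not_update_list. Conceptually:
    # ---
    # ---     stored  = json.load(open('logs/<run>/flags.json'))   # e.g. {'net': 'tscm', ...}
    # ---     argv    = ['--net', 'tscm', ...]                     # non-default entries only
    # ---     options = self.parser.parse_args(argv, namespace=options)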
57 |     def update_opt_from_json(self, flag_file, options):
58 |         if not exists(flag_file):
59 |             raise ValueError('{} does not exist'.format(flag_file))
60 |         # restore_var = ['runsPath', 'net', 'seqLen', 'num_clusters', 'output_dim', 'structDir', 'imgDir', 'lrStep', 'lrGamma', 'weightDecay', 'momentum', 'num_clusters', 'optim', 'margin', 'seed', 'patience']
61 |         do_not_update_list = ['resume', 'mode', 'phase', 'optim', 'split']
62 |         if os.path.exists(flag_file):
63 |             with open(flag_file, 'r') as f:
64 |                 # stored_flags = {'--' + k: str(v) for k, v in json.load(f).items() if k in restore_var}
65 |                 stored_flags = {'--' + k: str(v) for k, v in json.load(f).items() if k not in do_not_update_list}
66 |                 to_del = []
67 |                 for flag, val in stored_flags.items():
68 |                     for act in self.parser._actions:
69 |                         if act.dest == flag[2:]:  # stored parser matches current parser
70 |                             # store_true / store_false args don't accept arguments, filter these
71 |                             if type(act.const) == type(True):
72 |                                 if val == str(act.default):
73 |                                     to_del.append(flag)
74 |                                 else:
75 |                                     stored_flags[flag] = ''
76 |                             else:
77 |                                 if val == str(act.default):
78 |                                     to_del.append(flag)
79 | 
80 |                 for flag, val in stored_flags.items():
81 |                     missing = True
82 |                     for act in self.parser._actions:
83 |                         if flag[2:] == act.dest:
84 |                             missing = False
85 |                     if missing:
86 |                         to_del.append(flag)
87 | 
88 |                 for flag in to_del:
89 |                     del stored_flags[flag]
90 | 
91 |                 train_flags = [x for x in list(sum(stored_flags.items(), tuple())) if len(x) > 0]
92 |                 print('restored flags:', train_flags)
93 |                 options = self.parser.parse_args(train_flags, namespace=options)
94 |         return options
95 | 
96 | 
97 | class FixRandom:
98 |     def __init__(self, seed) -> None:
99 |         self.seed = seed
100 |         torch.manual_seed(self.seed)
101 |         random.seed(self.seed)
102 |         np.random.seed(self.seed)
103 |         torch.backends.cudnn.benchmark = False
104 |         torch.use_deterministic_algorithms(True)
105 |         os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
106 | 
107 |     def seed_worker(self):
108 |         worker_seed = self.seed
109 |         np.random.seed(worker_seed)
110 |         random.seed(worker_seed)
111 | 
--------------------------------------------------------------------------------
/readmat.py:
--------------------------------------------------------------------------------
1 | from scipy.io import loadmat
2 | from os.path import join
3 | features_struct = loadmat('pittsburgh/structure/pitts30k_train.mat')
4 | print(features_struct)
--------------------------------------------------------------------------------
/requirement.txt:
--------------------------------------------------------------------------------
1 | _libgcc_mutex 0.1 main defaults
2 | _openmp_mutex 4.5 1_gnu defaults
3 | absl-py 1.1.0 pypi_0 pypi
4 | anyio 3.6.1 pypi_0 pypi
5 | appdirs 1.4.4 pypi_0 pypi
6 | argon2-cffi 21.3.0 pypi_0 pypi
7 | argon2-cffi-bindings 21.2.0 pypi_0 pypi
8 | asttokens 2.0.5 pypi_0 pypi
9 | attrs 21.4.0 pypi_0 pypi
10 | babel 2.10.3 pypi_0 pypi
11 | backcall 0.2.0 pypi_0 pypi
12 | beautifulsoup4 4.11.1 pypi_0 pypi
13 | bleach 5.0.0 pypi_0 pypi
14 | brotlipy 0.7.0 py38h27cfd23_1003 defaults
15 | ca-certificates 2021.7.5 h06a4308_1 defaults
16 | cachetools 5.2.0 pypi_0 pypi
17 | certifi 2021.5.30 py38h06a4308_0 defaults
18 | cffi 1.14.6 py38h400218f_0 defaults
19 | chardet 4.0.0 py38h06a4308_1003 defaults
20 | click 8.1.7 pypi_0 pypi
21 | conda 4.10.3 py38h06a4308_0 defaults
22 | conda-package-handling 1.7.3 py38h27cfd23_1 defaults
23 | 
cryptography 3.4.7 py38hd23ed53_0 defaults 24 | cycler 0.11.0 pypi_0 pypi 25 | debugpy 1.6.0 pypi_0 pypi 26 | decorator 5.1.1 pypi_0 pypi 27 | defusedxml 0.7.1 pypi_0 pypi 28 | docker-pycreds 0.4.0 pypi_0 pypi 29 | einops 0.7.0 pypi_0 pypi 30 | entrypoints 0.4 pypi_0 pypi 31 | executing 0.8.3 pypi_0 pypi 32 | faiss-gpu 1.7.2 pypi_0 pypi 33 | fastjsonschema 2.15.3 pypi_0 pypi 34 | filelock 3.13.1 pypi_0 pypi 35 | fonttools 4.33.3 pypi_0 pypi 36 | fsspec 2024.2.0 pypi_0 pypi 37 | gitdb 4.0.11 pypi_0 pypi 38 | gitpython 3.1.42 pypi_0 pypi 39 | google-auth 2.8.0 pypi_0 pypi 40 | google-auth-oauthlib 0.4.6 pypi_0 pypi 41 | grpcio 1.46.3 pypi_0 pypi 42 | h5py 3.10.0 pypi_0 pypi 43 | huggingface-hub 0.21.4 pypi_0 pypi 44 | idna 2.10 pyhd3eb1b0_0 defaults 45 | imageio 2.34.0 pypi_0 pypi 46 | importlib-metadata 4.11.4 pypi_0 pypi 47 | importlib-resources 5.8.0 pypi_0 pypi 48 | ipykernel 6.15.0 pypi_0 pypi 49 | ipython 8.4.0 pypi_0 pypi 50 | ipython-genutils 0.2.0 pypi_0 pypi 51 | ipywidgets 7.7.0 pypi_0 pypi 52 | jedi 0.18.1 pypi_0 pypi 53 | jinja2 3.1.2 pypi_0 pypi 54 | joblib 1.3.2 pypi_0 pypi 55 | json5 0.9.8 pypi_0 pypi 56 | jsonschema 4.6.0 pypi_0 pypi 57 | jupyter-client 7.3.4 pypi_0 pypi 58 | jupyter-core 4.10.0 pypi_0 pypi 59 | jupyter-server 1.17.1 pypi_0 pypi 60 | jupyterlab 3.4.3 pypi_0 pypi 61 | jupyterlab-language-pack-zh-cn 3.4.post1 pypi_0 pypi 62 | jupyterlab-pygments 0.2.2 pypi_0 pypi 63 | jupyterlab-server 2.14.0 pypi_0 pypi 64 | jupyterlab-widgets 1.1.0 pypi_0 pypi 65 | kiwisolver 1.4.3 pypi_0 pypi 66 | lazy-loader 0.3 pypi_0 pypi 67 | ld_impl_linux-64 2.35.1 h7274673_9 defaults 68 | libffi 3.3 he6710b0_2 defaults 69 | libgcc-ng 9.3.0 h5101ec6_17 defaults 70 | libgomp 9.3.0 h5101ec6_17 defaults 71 | libstdcxx-ng 9.3.0 hd4cf53a_17 defaults 72 | markdown 3.3.7 pypi_0 pypi 73 | markupsafe 2.1.1 pypi_0 pypi 74 | matplotlib 3.5.2 pypi_0 pypi 75 | matplotlib-inline 0.1.3 pypi_0 pypi 76 | mistune 0.8.4 pypi_0 pypi 77 | nbclassic 0.3.7 pypi_0 pypi 78 | nbclient 0.6.4 pypi_0 pypi 79 | nbconvert 6.5.0 pypi_0 pypi 80 | nbformat 5.4.0 pypi_0 pypi 81 | ncurses 6.2 he6710b0_1 defaults 82 | nest-asyncio 1.5.5 pypi_0 pypi 83 | networkx 3.1 pypi_0 pypi 84 | notebook 6.4.12 pypi_0 pypi 85 | notebook-shim 0.1.0 pypi_0 pypi 86 | numpy 1.22.4 pypi_0 pypi 87 | oauthlib 3.2.0 pypi_0 pypi 88 | openssl 1.1.1k h27cfd23_0 defaults 89 | packaging 21.3 pypi_0 pypi 90 | pandocfilters 1.5.0 pypi_0 pypi 91 | parso 0.8.3 pypi_0 pypi 92 | pexpect 4.8.0 pypi_0 pypi 93 | pickleshare 0.7.5 pypi_0 pypi 94 | pillow 9.1.1 pypi_0 pypi 95 | pip 21.1.3 py38h06a4308_0 defaults 96 | prometheus-client 0.14.1 pypi_0 pypi 97 | prompt-toolkit 3.0.29 pypi_0 pypi 98 | protobuf 3.19.4 pypi_0 pypi 99 | psutil 5.9.1 pypi_0 pypi 100 | ptyprocess 0.7.0 pypi_0 pypi 101 | pure-eval 0.2.2 pypi_0 pypi 102 | pyasn1 0.4.8 pypi_0 pypi 103 | pyasn1-modules 0.2.8 pypi_0 pypi 104 | pycosat 0.6.3 py38h7b6447c_1 defaults 105 | pycparser 2.20 py_2 defaults 106 | pygments 2.12.0 pypi_0 pypi 107 | pynvml 11.5.0 pypi_0 pypi 108 | pyopenssl 20.0.1 pyhd3eb1b0_1 defaults 109 | pyparsing 3.0.9 pypi_0 pypi 110 | pyrsistent 0.18.1 pypi_0 pypi 111 | pysocks 1.7.1 py38h06a4308_0 defaults 112 | python 3.8.10 h12debd9_8 defaults 113 | python-dateutil 2.8.2 pypi_0 pypi 114 | pytz 2022.1 pypi_0 pypi 115 | pywavelets 1.4.1 pypi_0 pypi 116 | pyyaml 6.0.1 pypi_0 pypi 117 | pyzmq 23.2.0 pypi_0 pypi 118 | readline 8.1 h27cfd23_0 defaults 119 | requests 2.25.1 pyhd3eb1b0_0 defaults 120 | requests-oauthlib 1.3.1 pypi_0 pypi 121 | rsa 4.8 pypi_0 pypi 122 | ruamel_yaml 
0.15.100 py38h27cfd23_0 defaults 123 | safetensors 0.4.2 pypi_0 pypi 124 | scikit-image 0.21.0 pypi_0 pypi 125 | scikit-learn 1.3.2 pypi_0 pypi 126 | scipy 1.10.1 pypi_0 pypi 127 | send2trash 1.8.0 pypi_0 pypi 128 | sentry-sdk 1.42.0 pypi_0 pypi 129 | setproctitle 1.3.3 pypi_0 pypi 130 | setuptools 52.0.0 py38h06a4308_0 defaults 131 | six 1.16.0 pyhd3eb1b0_0 defaults 132 | smmap 5.0.1 pypi_0 pypi 133 | sniffio 1.2.0 pypi_0 pypi 134 | soupsieve 2.3.2.post1 pypi_0 pypi 135 | sqlite 3.36.0 hc218d9a_0 defaults 136 | stack-data 0.3.0 pypi_0 pypi 137 | supervisor 4.2.4 pypi_0 pypi 138 | tensorboard 2.9.1 pypi_0 pypi 139 | tensorboard-data-server 0.6.1 pypi_0 pypi 140 | tensorboard-plugin-wit 1.8.1 pypi_0 pypi 141 | terminado 0.15.0 pypi_0 pypi 142 | threadpoolctl 3.3.0 pypi_0 pypi 143 | tifffile 2023.7.10 pypi_0 pypi 144 | timm 0.9.16 pypi_0 pypi 145 | tinycss2 1.1.1 pypi_0 pypi 146 | tk 8.6.10 hbc83047_0 defaults 147 | torch 1.11.0+cu113 pypi_0 pypi 148 | torchsummary 1.5.1 pypi_0 pypi 149 | torchvision 0.12.0+cu113 pypi_0 pypi 150 | tornado 6.1 pypi_0 pypi 151 | tqdm 4.61.2 pyhd3eb1b0_1 defaults 152 | traitlets 5.3.0 pypi_0 pypi 153 | typing-extensions 4.2.0 pypi_0 pypi 154 | urllib3 1.26.18 pypi_0 pypi 155 | wandb 0.16.4 pypi_0 pypi 156 | wcwidth 0.2.5 pypi_0 pypi 157 | webencodings 0.5.1 pypi_0 pypi 158 | websocket-client 1.3.3 pypi_0 pypi 159 | werkzeug 2.1.2 pypi_0 pypi 160 | wheel 0.36.2 pyhd3eb1b0_0 defaults 161 | widgetsnbextension 3.6.0 pypi_0 pypi 162 | xz 5.2.5 h7b6447c_0 defaults 163 | yaml 0.2.5 h7b6447c_0 defaults 164 | zipp 3.8.0 pypi_0 pypi 165 | zlib 1.2.11 h7b6447c_3 defaults 166 | -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | # %% 2 | import importlib 3 | import os 4 | import pickle 5 | import shutil 6 | from os.path import dirname, exists, join 7 | import h5py 8 | import faiss 9 | import numpy as np 10 | import torch 11 | import torch.nn as nn 12 | import wandb 13 | from torch.utils.data import DataLoader 14 | from tqdm import tqdm 15 | from datetime import datetime 16 | import json 17 | import torch.optim as optim 18 | from torchsummary import summary 19 | 20 | os.sys.path.append(os.path.join(os.path.dirname(__file__), '../')) 21 | import time 22 | from options import FixRandom 23 | from utils import cal_recall, light_log, schedule_device 24 | 25 | 26 | class CKD_loss(nn.Module): 27 | def __init__(self, margin) -> None: 28 | super().__init__() 29 | self.margin = margin 30 | 31 | def forward(self, embs_a, embs_p, embs_n, mu_tea_a, mu_tea_p, mu_tea_n): # (1, D) 32 | SaTp = torch.norm(embs_a - mu_tea_p, p=2).pow(2) 33 | SpTa = torch.norm(embs_p - mu_tea_a, p=2).pow(2) 34 | 35 | SaTn = torch.norm(embs_a - mu_tea_n, p=2).pow(2) 36 | SnTa = torch.norm(embs_n - mu_tea_a, p=2).pow(2) 37 | 38 | SaTa = torch.norm(embs_a - mu_tea_a, p=2).pow(2) 39 | SpTp = torch.norm(embs_p - mu_tea_p, p=2).pow(2) 40 | SnTn = torch.norm(embs_n - mu_tea_n, p=2).pow(2) 41 | 42 | dis_D = SaTp + SpTa + SaTa + SpTp + SnTn 43 | # dis_D=SaTp+SpTa 44 | loss = 0.5 * (torch.clamp(self.margin + dis_D, min=0).pow(2)) 45 | 46 | return loss 47 | 48 | class Trainer: 49 | def __init__(self, options) -> None: 50 | 51 | self.opt = options 52 | 53 | # r variables 54 | self.step = 0 55 | self.epoch = 0 56 | self.current_lr = 0 57 | self.best_recalls = [0, 0, 0] 58 | 59 | # seed 60 | fix_random = FixRandom(self.opt.seed) 61 | self.seed_worker = fix_random.seed_worker() 62 | self.time_stamp = 
datetime.now().strftime('%m%d_%H%M%S')
63 | 
64 |         # set device
65 |         if self.opt.phase == 'train_tea':
66 |             self.opt.cGPU = schedule_device()
67 |         if self.opt.cuda and not torch.cuda.is_available():
68 |             raise Exception("No GPU found; CUDA is on by default, pass --cuda to disable it :(")
69 |         torch.cuda.set_device(self.opt.cGPU)
70 |         self.device = torch.device("cuda")
71 |         print('{}:{}{}'.format('device', self.device, torch.cuda.current_device()))
72 | 
73 |         # CKD_loss
74 |         self.CKD_loss = CKD_loss(margin=torch.tensor(self.opt.margin, device=self.device))
75 |         # make model
76 |         if self.opt.phase == 'train_tea':
77 |             self.model, self.optimizer, self.scheduler, self.criterion = self.make_model()
78 |         elif self.opt.phase == 'train_stu':
79 |             self.teacher_net, self.student_net, self.optimizer, self.scheduler, self.criterion = self.make_model()
80 |             self.model = self.teacher_net
81 |         elif self.opt.phase in ['test_tea', 'test_stu']:
82 |             self.model = self.make_model()
83 |         else:
84 |             raise Exception('Undefined phase :(')
85 | 
86 |         # make folders
87 |         self.make_folders()
88 |         # make dataset
89 |         self.make_dataset()
90 |         # online logs
91 |         if self.opt.phase in ['train_tea', 'train_stu']:
92 |             wandb.init(project="TSCM", config=vars(self.opt),
93 |                        name=f"{self.opt.loss}_{self.opt.phase}_{self.time_stamp}")
94 | 
95 |     def make_folders(self):
96 |         ''' create folders to store tensorboard files and a copy of networks files
97 |         '''
98 |         if self.opt.phase in ['train_tea', 'train_stu']:
99 |             self.opt.runsPath = join(self.opt.logsPath, f"{self.opt.loss}_{self.opt.phase}_{self.time_stamp}")
100 |             if not os.path.exists(join(self.opt.runsPath, 'models')):
101 |                 os.makedirs(join(self.opt.runsPath, 'models'))
102 | 
103 |             if not os.path.exists(join(self.opt.runsPath, 'transformed')):
104 |                 os.makedirs(join(self.opt.runsPath, 'transformed'))
105 | 
106 |             for file in [__file__, 'datasets/{}.py'.format(self.opt.dataset), 'networks/{}.py'.format(self.opt.net)]:
107 |                 shutil.copyfile(file, os.path.join(self.opt.runsPath, 'models', file.split('/')[-1]))
108 | 
109 |             with open(join(self.opt.runsPath, 'flags.json'), 'w') as f:
110 |                 f.write(json.dumps({k: v for k, v in vars(self.opt).items()}, indent=''))
111 | 
112 |     def make_dataset(self):
113 |         ''' make dataset
114 |         '''
115 |         if self.opt.phase in ['train_tea', 'train_stu']:
116 |             assert os.path.exists(f'datasets/{self.opt.dataset}.py'), 'Cannot find ' + f'{self.opt.dataset}.py :('
117 |             self.dataset = importlib.import_module('datasets.' 
+ self.opt.dataset) 118 | elif self.opt.phase in ['test_tea', 'test_stu']: 119 | self.dataset = importlib.import_module('tmp.models.{}'.format(self.opt.dataset)) 120 | 121 | # for emb cache 122 | self.whole_train_set = self.dataset.get_whole_training_set(self.opt) 123 | self.whole_training_data_loader = DataLoader(dataset=self.whole_train_set, num_workers=self.opt.threads, 124 | batch_size=self.opt.cacheBatchSize, shuffle=False, 125 | pin_memory=self.opt.cuda, worker_init_fn=self.seed_worker) 126 | self.whole_val_set = self.dataset.get_whole_val_set(self.opt) 127 | self.whole_val_data_loader = DataLoader(dataset=self.whole_val_set, num_workers=self.opt.threads, 128 | batch_size=self.opt.cacheBatchSize, shuffle=False, 129 | pin_memory=self.opt.cuda, worker_init_fn=self.seed_worker) 130 | self.whole_test_set = self.dataset.get_whole_test_set(self.opt) 131 | self.whole_test_data_loader = DataLoader(dataset=self.whole_test_set, num_workers=self.opt.threads, 132 | batch_size=self.opt.cacheBatchSize, shuffle=False, 133 | pin_memory=self.opt.cuda, worker_init_fn=self.seed_worker) 134 | 135 | self.train_set = self.dataset.get_training_query_set(self.opt, self.opt.margin) 136 | self.training_data_loader = DataLoader(dataset=self.train_set, num_workers=8, batch_size=self.opt.batchSize, 137 | shuffle=True, collate_fn=self.dataset.collate_fn, 138 | worker_init_fn=self.seed_worker) 139 | print('{}:{}, {}:{}, {}:{}, {}:{}, {}:{}'.format('dataset', self.opt.dataset, 'database', 140 | self.whole_train_set.dbStruct.numDb, 'train_set', 141 | self.whole_train_set.dbStruct.numQ, 'val_set', 142 | self.whole_val_set.dbStruct.numQ, 'test_set', 143 | self.whole_test_set.dbStruct.numQ)) 144 | print('{}:{}, {}:{}'.format('cache_bs', self.opt.cacheBatchSize, 'tuple_bs', self.opt.batchSize)) 145 | 146 | def make_model(self): 147 | '''build model 148 | ''' 149 | if self.opt.phase == 'train_tea': 150 | # build teacher net 151 | assert os.path.exists(f'networks/{self.opt.net}.py'), 'Cannot find ' + f'{self.opt.net}.py :(' 152 | network = importlib.import_module('networks.' + self.opt.net) 153 | model = network.deliver_model(self.opt, 'tea') 154 | model = model.to(self.device) 155 | outputs = model(torch.rand((2, 3, self.opt.height, self.opt.width), device=self.device)) 156 | self.opt.output_dim = \ 157 | model(torch.rand((2, 3, self.opt.height, self.opt.width), device=self.device))[0].shape[-1] 158 | self.opt.sigma_dim = \ 159 | model(torch.rand((2, 3, self.opt.height, self.opt.width), device=self.device))[1].shape[-1] # place holder 160 | elif self.opt.phase == 'train_stu': # load teacher net 161 | assert self.opt.resume != '', 'You need to define the teacher/resume path :(' 162 | if exists('tmp'): 163 | shutil.rmtree('tmp') 164 | os.mkdir('tmp') 165 | shutil.copytree(join(dirname(self.opt.resume), 'models'), join('tmp', 'models')) 166 | network = importlib.import_module(f'tmp.models.{self.opt.net}') 167 | model_tea = network.deliver_model(self.opt, 'tea').to(self.device) 168 | checkpoint = torch.load(self.opt.resume) 169 | model_tea.load_state_dict(checkpoint['state_dict']) 170 | # build student net 171 | assert os.path.exists(f'networks/{self.opt.net}.py'), 'Cannot find ' + f'{self.opt.net}.py :(' 172 | network = importlib.import_module('networks.' 
+ self.opt.net) 173 | model = network.deliver_model(self.opt, 'stu').to(self.device) 174 | # checkpointS=torch.load('logs/tri_train_stu_0804_180109/ckpt_e_1.pth.tar') 175 | # model.load_state_dict(checkpointS['state_dict']) 176 | self.opt.output_dim = \ 177 | model(torch.rand((2, 3, self.opt.height, self.opt.width), device=self.device))[0].shape[-1] 178 | self.opt.sigma_dim = \ 179 | model(torch.rand((2, 3, self.opt.height, self.opt.width), device=self.device))[1].shape[-1] 180 | elif self.opt.phase in ['test_tea', 'test_stu']: 181 | # load teacher or student net 182 | assert self.opt.resume != '', 'You need to define a teacher/resume path :(' 183 | if exists('tmp'): 184 | shutil.rmtree('tmp') 185 | os.mkdir('tmp') 186 | shutil.copytree(join(dirname(self.opt.resume), 'models'), join('tmp', 'models')) 187 | network = importlib.import_module('tmp.models.{}'.format(self.opt.net)) 188 | model = network.deliver_model(self.opt, self.opt.phase[-3:]).to(self.device) 189 | checkpoint = torch.load(self.opt.resume) 190 | model.load_state_dict(checkpoint['state_dict']) 191 | 192 | print('{}:{}, {}:{}, {}:{}'.format(model.id, self.opt.net, 'loss', self.opt.loss, 'mu_dim', self.opt.output_dim, 193 | 'sigma_dim', self.opt.sigma_dim if self.opt.phase[-3:] == 'stu' else '-')) 194 | 195 | if self.opt.phase in ['train_tea', 'train_stu']: 196 | # optimizer 197 | if self.opt.optim == 'adam': 198 | optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), self.opt.lr, 199 | weight_decay=self.opt.weightDecay) 200 | scheduler = optim.lr_scheduler.ExponentialLR(optimizer, self.opt.lrGamma, last_epoch=-1, verbose=False) 201 | elif self.opt.optim == 'sgd': 202 | optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=self.opt.lr, 203 | momentum=self.opt.momentum, weight_decay=self.opt.weightDecay) 204 | scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=self.opt.lrStep, gamma=self.opt.lrGamma) 205 | else: 206 | raise NameError('Undefined optimizer :(') 207 | 208 | 209 | criterion = nn.TripletMarginLoss(margin=self.opt.margin, p=2, reduction='sum').to(self.device) 210 | 211 | if self.opt.nGPU > 1: 212 | model = nn.DataParallel(model) 213 | 214 | if self.opt.phase == 'train_tea': 215 | return model, optimizer, scheduler, criterion 216 | elif self.opt.phase == 'train_stu': 217 | return model_tea, model, optimizer, scheduler, criterion 218 | elif self.opt.phase in ['test_tea', 'test_stu']: 219 | return model 220 | else: 221 | raise NameError('Undefined phase :(') 222 | 223 | def build_embedding_cache(self): 224 | '''build embedding cache, such that we can find the corresponding (p) and (n) with respect to (a) in embedding space 225 | ''' 226 | self.train_set.cache = os.path.join(self.opt.runsPath, self.train_set.whichSet + '_feat_cache.hdf5') 227 | with h5py.File(self.train_set.cache, mode='w') as h5: 228 | h5feat = h5.create_dataset("features", [len(self.whole_train_set), self.opt.output_dim], dtype=np.float32) 229 | with torch.no_grad(): 230 | for iteration, (input, indices) in enumerate(tqdm(self.whole_training_data_loader), 1): 231 | input = input.to(self.device) # torch.Size([32, 3, 154, 154]) ([32, 5, 3, 200, 200]) 232 | emb, _ = self.model(input) 233 | h5feat[indices.detach().numpy(), :] = emb.detach().cpu().numpy() 234 | del input, emb 235 | 236 | def build_embedding_cache_stu(self): 237 | '''build embedding cache, such that we can find the corresponding (p) and (n) with respect to (a) in embedding space 238 | ''' 239 | self.train_set.cache = 
os.path.join(self.opt.runsPath, self.train_set.whichSet + '_feat_cache.hdf5') 240 | with h5py.File(self.train_set.cache, mode='w') as h5: 241 | h5feat = h5.create_dataset("features", [len(self.whole_train_set), self.opt.output_dim], dtype=np.float32) 242 | with torch.no_grad(): 243 | for iteration, (input, indices) in enumerate(tqdm(self.whole_training_data_loader), 1): 244 | input = input.to(self.device) # torch.Size([32, 3, 154, 154]) ([32, 5, 3, 200, 200]) 245 | emb, _ = self.student_net(input) 246 | h5feat[indices.detach().numpy(), :] = emb.detach().cpu().numpy() 247 | del input, emb 248 | 249 | def process_batch(self, batch_inputs): 250 | ''' 251 | process a batch of input 252 | ''' 253 | 254 | anchor, positives, negatives, neg_counts, indices = batch_inputs 255 | 256 | # in case we get an empty batch 257 | if anchor is None: 258 | return None, None 259 | 260 | # some reshaping to put query, pos, negs in a single (N, 3, H, W) tensor, where N = batchSize * (nQuery + nPos + n_neg) 261 | B = anchor.shape[0] # ([8, 1, 3, 200, 200]) 262 | n_neg = torch.sum(neg_counts) # tensor(80) = torch.sum(torch.Size([8])) 263 | 264 | input = torch.cat([anchor, positives, negatives]) # ([B, C, H, 200]) 265 | 266 | input = input.to(self.device) # ([96, 1, C, H, W]) 267 | embs, vars = self.model(input) # ([96, D]) 268 | 269 | tuple_loss = 0 270 | # Standard triplet loss (via PyTorch library) 271 | if self.opt.loss == 'tri': 272 | embs_a, embs_p, embs_n = torch.split(embs, [B, B, n_neg]) 273 | for i, neg_count in enumerate(neg_counts): 274 | for n in range(neg_count): 275 | negIx = (torch.sum(neg_counts[:i]) + n).item() 276 | tuple_loss += self.criterion(embs_a[i:i + 1], embs_p[i:i + 1], embs_n[negIx:negIx + 1]) 277 | tuple_loss /= n_neg.float().to(self.device) 278 | 279 | 280 | del input, embs, embs_a, embs_p, embs_n 281 | del anchor, positives, negatives 282 | 283 | return tuple_loss, n_neg 284 | 285 | def process_batch_stu(self, batch_inputs): 286 | ''' 287 | process a batch of input 288 | ''' 289 | anchor, positives, negatives, neg_counts, indices = batch_inputs 290 | 291 | # in case we get an empty batch 292 | if anchor is None: 293 | return None, None 294 | 295 | # some reshaping to put query, pos, negs in a single (N, 3, H, W) tensor, where N = batchSize * (nQuery + nPos + n_neg) 296 | B = anchor.shape[0] # ([8, 1, 3, 200, 200]) 297 | n_neg = torch.sum(neg_counts) # tensor(80) = torch.sum(torch.Size([8])) 298 | 299 | input = torch.cat([anchor, positives, negatives]) # ([B, C, H, 200]) 300 | 301 | input = input.to(self.device) # ([96, 1, C, H, W]) 302 | embs, vars = self.student_net(input) # ([96, D]) 303 | 304 | anchor = anchor.to(self.device) 305 | with torch.no_grad(): 306 | mu_tea, _ = self.teacher_net(input) # ([B, D]) 307 | # mu_stu, log_sigma_sq = self.student_net(anchor) # ([B, D]), ([B, D]) 308 | 309 | 310 | tuple_loss = 0 311 | CKDloss = 0 312 | 313 | # Standard triplet loss (via PyTorch library) 314 | if self.opt.loss == 'tri': 315 | embs_a, embs_p, embs_n = torch.split(embs, [B, B, n_neg]) 316 | vars_a, vars_p, vars_n = torch.split(vars, [B, B, n_neg]) 317 | mu_tea_a, mu_tea_p, mu_tea_n = torch.split(mu_tea, [B, B, n_neg]) 318 | for i, neg_count in enumerate(neg_counts): 319 | for n in range(neg_count): 320 | negIx = (torch.sum(neg_counts[:i]) + n).item() 321 | tuple_loss += self.criterion(embs_a[i:i + 1], embs_p[i:i + 1], embs_n[negIx:negIx + 1]) 322 | CKDloss += self.CKD_loss(embs_a[i:i + 1], embs_p[i:i + 1], embs_n[negIx:negIx + 1], 323 | mu_tea_a[i:i + 1], mu_tea_p[i:i + 1], 
mu_tea_n[negIx:negIx + 1])
324 | 
325 |             tuple_loss /= n_neg.float().to(self.device)
326 |             CKDloss /= n_neg.float().to(self.device)
327 |         del input, embs, embs_a, embs_p, embs_n
328 |         del anchor, positives, negatives
329 |         return tuple_loss + CKDloss, n_neg
330 | 
331 |     def train(self):
332 |         not_improved = 0
333 |         for epoch in range(self.opt.nEpochs):
334 |             self.epoch = epoch
335 |             self.current_lr = self.optimizer.state_dict()['param_groups'][0]['lr']
336 | 
337 |             # build embedding cache
338 |             if self.epoch % self.opt.cacheRefreshEvery == 0:
339 |                 self.model.eval()
340 |                 self.build_embedding_cache()
341 |                 self.model.train()
342 | 
343 |             # train
344 |             tuple_loss_sum = 0
345 |             for _, batch_inputs in enumerate(tqdm(self.training_data_loader)):
346 |                 self.step += 1
347 | 
348 |                 self.optimizer.zero_grad()
349 |                 tuple_loss, n_neg = self.process_batch(batch_inputs)
350 |                 if tuple_loss is None:
351 |                     continue
352 |                 tuple_loss.backward()
353 |                 self.optimizer.step()
354 |                 tuple_loss_sum += tuple_loss.item()
355 | 
356 |                 if self.step % 10 == 0:
357 |                     wandb.log({'train_tuple_loss': tuple_loss.item()}, step=self.step)
358 |                     wandb.log({'train_batch_num_neg': n_neg}, step=self.step)
359 | 
360 |             n_batches = len(self.training_data_loader)
361 |             wandb.log({'train_avg_tuple_loss': tuple_loss_sum / n_batches}, step=self.step)
362 |             torch.cuda.empty_cache()
363 |             self.scheduler.step()
364 | 
365 |             # val every x epochs
366 |             if (self.epoch % self.opt.evalEvery) == 0:
367 |                 recalls = self.val(self.model)
368 |                 if recalls[0] > self.best_recalls[0]:
369 |                     self.best_recalls = recalls
370 |                     not_improved = 0
371 |                 else:
372 |                     not_improved += self.opt.evalEvery
373 |                 # light log
374 |                 vars_to_log = [
375 |                     'e={:>2d},'.format(self.epoch),
376 |                     'lr={:>.8f},'.format(self.current_lr),
377 |                     'tl={:>.4f},'.format(tuple_loss_sum / n_batches),
378 |                     'r@1/5/10={:.2f}/{:.2f}/{:.2f}'.format(recalls[0], recalls[1], recalls[2]),
379 |                     '\n' if not_improved else ' *\n',
380 |                 ]
381 |                 light_log(self.opt.runsPath, vars_to_log)
382 |             else:
383 |                 recalls = None
384 |             self.save_model(self.model, is_best=not not_improved)
385 | 
386 |             # stop when not improving for a period
387 |             if self.opt.phase == 'train_tea':
388 |                 if self.opt.patience > 0 and not_improved > self.opt.patience:
389 |                     print('terminated because performance has not improved for', self.opt.patience, 'epochs')
390 |                     break
391 | 
392 |         self.save_model(self.model, is_best=False)
393 |         print('best r@1/5/10={:.2f}/{:.2f}/{:.2f}'.format(self.best_recalls[0], self.best_recalls[1],
394 |                                                           self.best_recalls[2]))
395 | 
396 |         return self.best_recalls
397 | 
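    # --- Added note (summarizing the code above): per tuple, process_batch_stu returns
    # ---     L = (1 / n_neg) * sum_over_triplets [ L_triplet + L_CKD ],
    # --- where L_CKD = 0.5 * clamp(margin + SaTp + SpTa + SaTa + SpTp + SnTn, min=0)^2
    # --- and SxTy denotes the squared L2 distance between student embedding x and
    # --- teacher embedding y (see CKD_loss at the top of this file).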
398 |     def train_student(self):
399 |         not_improved = 0
400 |         for epoch in range(self.opt.nEpochs):
401 |             self.epoch = epoch
402 |             self.current_lr = self.optimizer.state_dict()['param_groups'][0]['lr']
403 | 
404 |             # build embedding cache
405 |             if self.epoch % self.opt.cacheRefreshEvery == 0:
406 |                 self.student_net.eval()
407 |                 self.build_embedding_cache_stu()  # refresh the cache with student embeddings
408 |                 self.student_net.train()
409 |             # train
410 |             tuple_loss_sum = 0
411 |             for _, batch_inputs in enumerate(tqdm(self.training_data_loader)):
412 |                 self.step += 1
413 | 
414 |                 self.optimizer.zero_grad()
415 |                 tuple_loss, n_neg = self.process_batch_stu(batch_inputs)
416 |                 if tuple_loss is None:
417 |                     continue
418 |                 tuple_loss.backward()
419 |                 self.optimizer.step()
420 |                 tuple_loss_sum += tuple_loss.item()
421 |                 loss_sum = tuple_loss_sum
422 |                 if self.step % 10 == 0:
423 |                     wandb.log({'train_tuple_loss': tuple_loss.item()}, step=self.step)
424 |                     wandb.log({'train_batch_num_neg': n_neg}, step=self.step)
425 | 
426 |             n_batches = len(self.training_data_loader)
427 |             wandb.log({'train_avg_tuple_loss': tuple_loss_sum / n_batches}, step=self.step)
428 |             wandb.log({'student/epoch_loss': loss_sum / n_batches}, step=self.step)
429 |             torch.cuda.empty_cache()
430 |             self.scheduler.step()
431 | 
432 |             # val
433 |             if (self.epoch % self.opt.evalEvery) == 0:
434 |                 recalls = self.val(self.student_net)
435 |                 if recalls[0] > self.best_recalls[0]:
436 |                     self.best_recalls = recalls
437 |                     not_improved = 0
438 |                 else:
439 |                     not_improved += self.opt.evalEvery
440 | 
441 |                 light_log(self.opt.runsPath, [
442 |                     f'e={self.epoch:>2d},',
443 |                     f'lr={self.current_lr:>.8f},',
444 |                     f'tl={loss_sum / n_batches:>.4f},',
445 |                     f'r@1/5/10={recalls[0]:.2f}/{recalls[1]:.2f}/{recalls[2]:.2f}',
446 |                     '\n' if not_improved else ' *\n',
447 |                 ])
448 |             else:
449 |                 recalls = None
450 | 
451 |             self.save_model(self.student_net, is_best=False, save_every_epoch=True)
452 |             if self.opt.patience > 0 and not_improved > self.opt.patience:
453 |                 print('terminated because performance has not improved for', self.opt.patience, 'epochs')
454 |                 break
455 | 
456 |         print('best r@1/5/10={:.2f}/{:.2f}/{:.2f}'.format(self.best_recalls[0], self.best_recalls[1],
457 |                                                           self.best_recalls[2]))
458 |         return self.best_recalls
459 | 
460 |     def val(self, model):
461 |         recalls, _ = self.get_recall(model)
462 |         for i, n in enumerate([1, 5, 10]):
463 |             wandb.log({'{}/{}_r@{}'.format(model.id, self.opt.split, n): recalls[i]}, step=self.step)
464 |             # self.writer.add_scalar('{}/{}_r@{}'.format(model.id, self.opt.split, n), recalls[i], self.epoch)
465 | 
466 |         return recalls
467 | 
468 |     def test(self):
469 |         recalls, _ = self.get_recall(self.model, save_embs=True)
470 |         print('best r@1/5/10={:.2f}/{:.2f}/{:.2f}'.format(recalls[0], recalls[1], recalls[2]))
471 |         # summary(self.model, input_size=(3, 224, 224))
472 | 
473 |         return recalls
474 | 
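    # --- Added note (sketch): checkpoints written by save_model below are plain dicts,
    # --- so they can be inspected or partially loaded, e.g.:
    # ---
    # ---     ckpt = torch.load('logs/<run>/ckpt_best.pth.tar', map_location='cpu')
    # ---     ckpt.keys()   # 'epoch', 'step', 'state_dict', 'optimizer', 'scheduler'
    # ---     model.load_state_dict(ckpt['state_dict'])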
475 |     def save_model(self, model, is_best=False, save_every_epoch=False):
476 |         if is_best:
477 |             torch.save({
478 |                 'epoch': self.epoch,
479 |                 'step': self.step,
480 |                 'state_dict': model.state_dict(),
481 |                 'optimizer': self.optimizer.state_dict(),
482 |                 'scheduler': self.scheduler.state_dict(),
483 |             }, os.path.join(self.opt.runsPath, 'ckpt_best.pth.tar'))
484 | 
485 |         if save_every_epoch:
486 |             torch.save({
487 |                 'epoch': self.epoch,
488 |                 'step': self.step,
489 |                 'state_dict': model.state_dict(),
490 |                 'optimizer': self.optimizer.state_dict(),
491 |                 'scheduler': self.scheduler.state_dict(),
492 |             }, os.path.join(self.opt.runsPath, 'ckpt_e_{}.pth.tar'.format(self.epoch)))
493 | 
494 |     def get_recall(self, model, save_embs=False):
495 |         model.eval()
496 | 
497 |         if self.opt.split == 'val':
498 |             eval_dataloader = self.whole_val_data_loader
499 |             eval_set = self.whole_val_set
500 |         elif self.opt.split == 'test':
501 |             eval_dataloader = self.whole_test_data_loader
502 |             eval_set = self.whole_test_set
503 |         # print(f"{self.opt.split} len:{len(eval_set)}")
504 | 
505 |         whole_mu = torch.zeros((len(eval_set), self.opt.output_dim), device=self.device)  # (N, D)
506 |         whole_var = torch.zeros((len(eval_set), self.opt.sigma_dim), device=self.device)  # (N, D)
507 |         gt = eval_set.get_positives()  # (N, n_pos)
508 |         start_time = time.time()
509 |         with torch.no_grad():
510 |             for iteration, (input, indices) in enumerate(tqdm(eval_dataloader), 1):
511 |                 input = input.to(self.device)
512 |                 # print(input.shape)
513 |                 mu, var = model(input)  # (B, D)
514 |                 # summary(self.model, input_size=input.shape[1:])
515 |                 # print(input.shape)
516 |                 # var = torch.exp(var)
517 |                 whole_mu[indices, :] = mu
518 |                 whole_var[indices, :] = var
519 |                 del input, mu, var
520 |         end_time = time.time()
521 | 
522 |         elapsed_time = end_time - start_time
523 |         print("Elapsed Time:", elapsed_time)
524 |         n_values = [1, 5, 10]
525 | 
526 |         whole_var = torch.exp(whole_var)
527 |         whole_mu = whole_mu.cpu().numpy()
528 |         whole_var = whole_var.cpu().numpy()
529 |         mu_q = whole_mu[eval_set.dbStruct.numDb:].astype('float32')
530 |         mu_db = whole_mu[:eval_set.dbStruct.numDb].astype('float32')
531 |         sigma_q = whole_var[eval_set.dbStruct.numDb:].astype('float32')
532 |         sigma_db = whole_var[:eval_set.dbStruct.numDb].astype('float32')
533 |         faiss_index = faiss.IndexFlatL2(mu_q.shape[1])
534 |         faiss_index.add(mu_db)
535 |         dists, preds = faiss_index.search(mu_q, max(n_values))  # the results are sorted
536 | 
537 |         # cull queries without any ground truth positives in the database
538 |         val_inds = [True if len(gt[ind]) != 0 else False for ind in range(len(gt))]
539 |         val_inds = np.array(val_inds)
540 |         mu_q = mu_q[val_inds]
541 |         sigma_q = sigma_q[val_inds]
542 |         preds = preds[val_inds]
543 |         dists = dists[val_inds]
544 |         gt = gt[val_inds]
545 | 
546 |         recall_at_k = cal_recall(preds, gt, n_values)
547 | 
548 |         if save_embs:
549 |             with open(join(self.opt.runsPath, '{}_db_embeddings_{}.pickle'.format(self.opt.split,
550 |                                                                                   self.opt.resume.split('.')[-3].split(
551 |                                                                                       '_')[-1])), 'wb') as handle:
552 |                 pickle.dump(mu_q, handle, protocol=pickle.HIGHEST_PROTOCOL)
553 |                 pickle.dump(mu_db, handle, protocol=pickle.HIGHEST_PROTOCOL)
554 |                 pickle.dump(sigma_q, handle, protocol=pickle.HIGHEST_PROTOCOL)
555 |                 pickle.dump(sigma_db, handle, protocol=pickle.HIGHEST_PROTOCOL)
556 |                 pickle.dump(preds, handle, protocol=pickle.HIGHEST_PROTOCOL)
557 |                 pickle.dump(dists, handle, protocol=pickle.HIGHEST_PROTOCOL)
558 |                 pickle.dump(gt, handle, protocol=pickle.HIGHEST_PROTOCOL)
559 |                 pickle.dump(whole_mu, handle, protocol=pickle.HIGHEST_PROTOCOL)
560 |                 pickle.dump(whole_var, handle, protocol=pickle.HIGHEST_PROTOCOL)
561 |             print('embeddings saved for post processing')
562 | 
563 |         return recall_at_k, None
564 | 
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 | from os.path import join
3 | 
4 | import faiss
5 | import numpy as np
6 | from pynvml import *
7 | from scipy import stats
8 | from scipy.io import loadmat
9 | from scipy.optimize import least_squares
10 | from skimage import io
11 | 
12 | 
13 | def linear_fit(x, y, w, report_error=False):
14 |     def cost(p, x, y, w):
15 |         k = p[0]
16 |         b = p[1]
17 |         error = y - (k * x + b)
18 |         error *= w
19 |         return error
20 | 
21 |     p_init = np.array([-1, 1])
22 |     ret = least_squares(cost, p_init, args=(x, y, w), verbose=0)
23 |     # print(ret['x'][0], ret['x'][1], )
24 |     y_fitted = ret['x'][0] * x + ret['x'][1]
25 |     error = ret['cost']
26 |     if report_error:
27 |         return y_fitted, error
28 |     else:
29 |         return y_fitted
30 | 
31 | 
32 | def reduce_sigma(sigma, std_or_sq, log_or_linear, hmean_or_mean):
33 |     '''
34 |     input sigma: sigma^2, ([1, D])
35 |     output sigma: sigma, (1)
36 |     '''
37 |     if log_or_linear == 'log':
38 |         print('log')
39 |         sigma = np.log(sigma)
40 |     elif log_or_linear == 'linear':
41 |         pass
42 |     else:
43 |         raise NameError('undefined')
44 | 
45 |     if std_or_sq == 'std':
46 |         sigma = np.sqrt(sigma)
47 |     elif std_or_sq == 'sq':
48 |         pass
49 |     else:
50 |         raise NameError('undefined')
51 | 
52 |     if hmean_or_mean == 
'hmean': 53 | sigma = stats.hmean(sigma, axis=1) # ([numQ,]) 54 | elif hmean_or_mean == 'mean': 55 | sigma = np.mean(sigma, axis=1) # ([numQ,]) 56 | else: 57 | raise NameError('undefined') 58 | 59 | return sigma 60 | 61 | 62 | def schedule_device(): 63 | ''' output id of the graphic card with most free memory 64 | ''' 65 | nvmlInit() 66 | deviceCount = nvmlDeviceGetCount() 67 | frees = [] 68 | for i in range(deviceCount): 69 | handle = nvmlDeviceGetHandleByIndex(i) 70 | # print("GPU", i, ":", nvmlDeviceGetName(handle)) 71 | info = nvmlDeviceGetMemoryInfo(handle) 72 | frees.append(info.free / 1e9) 73 | nvmlShutdown() 74 | # print(frees) 75 | id = frees.index(max(frees)) 76 | # print(id) 77 | return id 78 | 79 | def light_log(path, args): 80 | with open(join(path, 'screen.log'), 'a') as f: 81 | for arg in args: 82 | f.write(arg) 83 | f.flush() 84 | print(arg, end='') 85 | 86 | 87 | def cal_recall_from_embeddings(gt, qFeat, dbFeat): 88 | n_values = [1, 5, 10] 89 | 90 | # ---------------------------------------------------- sklearn --------------------------------------------------- # 91 | # knn = NearestNeighbors(n_jobs=-1) 92 | # knn.fit(dbFeat) 93 | # dists, predictions = knn.kneighbors(qFeat, len(dbFeat)) 94 | 95 | # --------------------------------- use faiss to do NN search -------------------------------- # 96 | faiss_index = faiss.IndexFlatL2(qFeat.shape[1]) 97 | faiss_index.add(dbFeat) 98 | dists, predictions = faiss_index.search(qFeat, max(n_values)) # the results is sorted 99 | 100 | recall_at_n = cal_recall(predictions, gt, n_values) 101 | return recall_at_n 102 | 103 | 104 | def cal_recall(ranks, pidx, ks): 105 | recall_at_k = np.zeros(len(ks)) 106 | for qidx in range(ranks.shape[0]): 107 | for i, k in enumerate(ks): 108 | #print("recall") 109 | #print(ranks[qidx, :k]) 110 | if np.sum(np.in1d(ranks[qidx, :k], pidx[qidx])) > 0: 111 | recall_at_k[i:] += 1 112 | #print("--") 113 | #print(pidx[qidx]) 114 | break 115 | 116 | recall_at_k /= ranks.shape[0] 117 | 118 | return recall_at_k * 100.0 119 | 120 | 121 | def cal_apk(pidx, rank, k): 122 | if len(rank) > k: 123 | rank = rank[:k] 124 | 125 | score = 0.0 126 | num_hits = 0.0 127 | 128 | for i, p in enumerate(rank): 129 | if p in pidx and p not in rank[:i]: 130 | num_hits += 1.0 131 | score += num_hits / (i + 1.0) 132 | 133 | return score / min(len(pidx), k) * 100.0 134 | 135 | 136 | def cal_mapk(ranks, pidxs, k): 137 | 138 | return np.mean([cal_apk(a, p, k) for a, p in zip(pidxs, ranks)]) 139 | 140 | 141 | def get_zoomed_bins(sigma, num_of_bins): 142 | s_min = np.min(sigma) 143 | s_max = np.max(sigma) 144 | print(s_min, s_max) 145 | bins_parent = np.linspace(s_min, s_max, num=num_of_bins) 146 | k = 0 147 | while True: 148 | indices = [] 149 | bins_child = np.linspace(bins_parent[0], bins_parent[-1 - k], num=num_of_bins) 150 | for index in range(num_of_bins - 1): 151 | target_q_ind_l = np.where(sigma >= bins_child[index]) 152 | if index != num_of_bins - 2: 153 | target_q_ind_r = np.where(sigma < bins_child[index + 1]) 154 | else: 155 | target_q_ind_r = np.where(sigma <= bins_child[index + 1]) 156 | target_q_ind = np.intersect1d(target_q_ind_l[0], target_q_ind_r[0]) 157 | indices.append(target_q_ind) 158 | # if len(indices[-1]) > int(sigma.shape[0] * 0.0005): 159 | if len(indices[-1]) > int(sigma.shape[0] * 0.001) or k == num_of_bins - 2: 160 | break 161 | else: 162 | k = k + 1 163 | # print('{:.3f}'.format(sum([len(x) for x in indices]) / sigma.shape[0]), [len(x) for x in indices]) 164 | # print('k=', k) 165 | return indices, 
61 | 
62 | def schedule_device():
63 |     ''' output the id of the graphics card with the most free memory
64 |     '''
65 |     nvmlInit()
66 |     deviceCount = nvmlDeviceGetCount()
67 |     frees = []
68 |     for i in range(deviceCount):
69 |         handle = nvmlDeviceGetHandleByIndex(i)
70 |         # print("GPU", i, ":", nvmlDeviceGetName(handle))
71 |         info = nvmlDeviceGetMemoryInfo(handle)
72 |         frees.append(info.free / 1e9)
73 |     nvmlShutdown()
74 |     # print(frees)
75 |     gpu_id = frees.index(max(frees))
76 |     # print(gpu_id)
77 |     return gpu_id
78 | 
79 | def light_log(path, args):
80 |     with open(join(path, 'screen.log'), 'a') as f:
81 |         for arg in args:
82 |             f.write(arg)
83 |             f.flush()
84 |             print(arg, end='')
85 | 
86 | 
87 | def cal_recall_from_embeddings(gt, qFeat, dbFeat):
88 |     n_values = [1, 5, 10]
89 | 
90 |     # ---------------------------------------------------- sklearn --------------------------------------------------- #
91 |     # knn = NearestNeighbors(n_jobs=-1)
92 |     # knn.fit(dbFeat)
93 |     # dists, predictions = knn.kneighbors(qFeat, len(dbFeat))
94 | 
95 |     # --------------------------------- use faiss to do NN search -------------------------------- #
96 |     faiss_index = faiss.IndexFlatL2(qFeat.shape[1])
97 |     faiss_index.add(dbFeat)
98 |     dists, predictions = faiss_index.search(qFeat, max(n_values))  # the results are sorted
99 | 
100 |     recall_at_n = cal_recall(predictions, gt, n_values)
101 |     return recall_at_n
102 | 
103 | 
104 | def cal_recall(ranks, pidx, ks):
105 |     recall_at_k = np.zeros(len(ks))
106 |     for qidx in range(ranks.shape[0]):
107 |         for i, k in enumerate(ks):
108 |             # print("recall")
109 |             # print(ranks[qidx, :k])
110 |             if np.sum(np.in1d(ranks[qidx, :k], pidx[qidx])) > 0:
111 |                 recall_at_k[i:] += 1
112 |                 # print("--")
113 |                 # print(pidx[qidx])
114 |                 break
115 | 
116 |     recall_at_k /= ranks.shape[0]
117 | 
118 |     return recall_at_k * 100.0
119 | 
120 | 
121 | def cal_apk(pidx, rank, k):
122 |     if len(rank) > k:
123 |         rank = rank[:k]
124 | 
125 |     score = 0.0
126 |     num_hits = 0.0
127 | 
128 |     for i, p in enumerate(rank):
129 |         if p in pidx and p not in rank[:i]:
130 |             num_hits += 1.0
131 |             score += num_hits / (i + 1.0)
132 | 
133 |     return score / min(len(pidx), k) * 100.0
134 | 
135 | 
136 | def cal_mapk(ranks, pidxs, k):
137 | 
138 |     return np.mean([cal_apk(a, p, k) for a, p in zip(pidxs, ranks)])
139 | 
140 | 
141 | def get_zoomed_bins(sigma, num_of_bins):
142 |     s_min = np.min(sigma)
143 |     s_max = np.max(sigma)
144 |     print(s_min, s_max)
145 |     bins_parent = np.linspace(s_min, s_max, num=num_of_bins)
146 |     k = 0
147 |     while True:
148 |         indices = []
149 |         bins_child = np.linspace(bins_parent[0], bins_parent[-1 - k], num=num_of_bins)
150 |         for index in range(num_of_bins - 1):
151 |             target_q_ind_l = np.where(sigma >= bins_child[index])
152 |             if index != num_of_bins - 2:
153 |                 target_q_ind_r = np.where(sigma < bins_child[index + 1])
154 |             else:
155 |                 target_q_ind_r = np.where(sigma <= bins_child[index + 1])
156 |             target_q_ind = np.intersect1d(target_q_ind_l[0], target_q_ind_r[0])
157 |             indices.append(target_q_ind)
158 |         # if len(indices[-1]) > int(sigma.shape[0] * 0.0005):
159 |         if len(indices[-1]) > int(sigma.shape[0] * 0.001) or k == num_of_bins - 2:
160 |             break
161 |         else:
162 |             k = k + 1
163 |     # print('{:.3f}'.format(sum([len(x) for x in indices]) / sigma.shape[0]), [len(x) for x in indices])
164 |     # print('k=', k)
165 |     return indices, bins_child, k
166 | 
167 | 
168 | def bin_pr(preds, dists, gt, vis=False):
169 |     # dists_m = np.around(dists[:, 0], 2)  # (4620,)
170 |     # dists_u = np.array(list(set(dists_m)))
171 |     # dists_u = np.sort(dists_u)  # sorted small -> large
172 | 
173 |     dists_u = np.linspace(np.min(dists[:, 0]), np.max(dists[:, 0]), num=100)
174 | 
175 |     recalls = []
176 |     precisions = []
177 |     for th in dists_u:
178 |         TPCount = 0
179 |         FPCount = 0
180 |         FNCount = 0
181 |         TNCount = 0
182 |         for index_q in range(dists.shape[0]):
183 |             # Positive
184 |             if dists[index_q, 0] < th:
185 |                 # True
186 |                 if np.any(np.in1d(preds[index_q, 0], gt[index_q])):
187 |                     TPCount += 1
188 |                 else:
189 |                     FPCount += 1
190 |             else:
191 |                 if np.any(np.in1d(preds[index_q, 0], gt[index_q])):
192 |                     FNCount += 1
193 |                 else:
194 |                     TNCount += 1
195 |         assert TPCount + FPCount + FNCount + TNCount == dists.shape[0], 'Count Error!'
196 |         if TPCount + FNCount == 0 or TPCount + FPCount == 0:
197 |             # print('zero')
198 |             continue
199 |         recall = TPCount / (TPCount + FNCount)
200 |         precision = TPCount / (TPCount + FPCount)
201 |         recalls.append(recall)
202 |         precisions.append(precision)
203 |     if vis:
204 |         from matplotlib import pyplot as plt
205 |         plt.style.use('ggplot')
206 |         fig = plt.figure(figsize=(5, 5))
207 |         ax = fig.add_subplot(111)
208 |         ax.plot(recalls, precisions)
209 |         ax.set_title('Precision-Recall')
210 |         ax.set_xlabel('Recall')
211 |         ax.set_ylabel('Precision')
212 |         ax.set_xlim([0, 1])
213 |         ax.set_ylim([0, 1])
214 |         plt.savefig('pr.png', dpi=200)
215 |     return recalls, precisions
216 | 
217 | 
218 | 
219 | def parse_dbStruct_pitts(path):
220 |     dbStruct = namedtuple('dbStruct', ['whichSet', 'dataset', 'dbImage', 'utmDb', 'qImage', 'utmQ', 'numDb', 'numQ', 'posDistThr', 'posDistSqThr', 'nonTrivPosDistSqThr'])
221 | 
222 |     mat = loadmat(path)
223 |     matStruct = mat['dbStruct'].item()
224 | 
225 |     dataset = 'pitts'
226 | 
227 |     whichSet = matStruct[0].item()
228 | 
229 |     # the .mat file is generated by Python; MATLAB cells are replaced with Python chars
230 |     dbImage = [f[0].item() for f in matStruct[1]]
231 |     # dbImage = matStruct[1]
232 |     utmDb = matStruct[2].T
233 |     # utmDb = matStruct[2]
234 | 
235 |     # the .mat file is generated by Python; MATLAB cells are replaced with Python chars
236 |     qImage = [f[0].item() for f in matStruct[3]]
237 |     # qImage = matStruct[3]
238 |     utmQ = matStruct[4].T
239 |     # utmQ = matStruct[4]
240 | 
241 |     numDb = matStruct[5].item()
242 |     numQ = matStruct[6].item()
243 | 
244 |     posDistThr = matStruct[7].item()
245 |     posDistSqThr = matStruct[8].item()
246 |     nonTrivPosDistSqThr = matStruct[9].item()
247 | 
248 |     return dbStruct(whichSet, dataset, dbImage, utmDb, qImage, utmQ, numDb, numQ, posDistThr, posDistSqThr, nonTrivPosDistSqThr)
249 | 
250 | def cal_hs(img_path):
251 |     img = io.imread(img_path, as_gray=True).reshape(-1, 1)
252 |     counts, bins = np.histogram((img * 255).astype(np.int16), np.arange(0, 256, 1))
253 |     counts = counts / np.sum(counts)
254 |     cumulative = np.cumsum(counts)
255 |     in_min = np.min((img * 255).astype(np.int16))
256 |     in_max = np.max((img * 255).astype(np.int16))
257 |     per_75 = np.argwhere(cumulative < 0.75)[-1]
258 |     per_25 = np.argwhere(cumulative < 0.25)[-1]
259 |     hs = (per_75 - per_25) / 255
260 |     return hs
261 | 
262 | if __name__ == '__main__':
263 |     pass
264 | 
--------------------------------------------------------------------------------
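
A small, self-contained sanity check for the metric helpers in utils.py; the arrays are dummy data, and the expected values in the comments follow directly from the definitions of cal_recall and cal_mapk above:

```python
# Toy sanity check for cal_recall / cal_mapk (dummy data, illustrative only).
import numpy as np

from utils import cal_recall, cal_mapk

# 3 queries, top-5 ranked database indices per query.
ranks = np.array([[0, 1, 2, 3, 4],
                  [5, 6, 7, 8, 9],
                  [4, 3, 2, 1, 0]])
# Ground-truth positive database indices per query.
gt = [np.array([2]),      # first hit appears at rank 3
      np.array([42]),     # never retrieved
      np.array([4, 0])]   # hit already at rank 1

print(cal_recall(ranks, gt, [1, 5]))  # -> [33.33..., 66.66...]
print(cal_mapk(ranks, gt, 5))         # mean AP@5 over the 3 queries
```
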
/vis.py:
--------------------------------------------------------------------------------
1 | #%%
2 | import pickle
3 | import numpy as np
4 | from matplotlib import pyplot as plt
5 | from tqdm import tqdm
6 | from os.path import join
7 | import utils
8 | import importlib
9 | import faiss
10 | importlib.reload(utils)
11 | 
12 | # --------------------------------------------------------------------------------------------------------------------- #
13 | NETWORK = 'contrast'
14 | # Choose NETWORK to match the log folder under logs/, e.g. |'contrast'|'teacher_triplet'|'student_contrast'|'student_triplet'|'student_quadruplet'|
15 | # --------------------------------------------------------------------------------------------------------------------- #
16 | 
17 | NUM_BINS = 11
18 | SHOW_AP = True
19 | exp = NETWORK
20 | resume = join('logs', NETWORK)
21 | 
22 | with open(join(resume, 'stu_30k.pickle'), 'rb') as handle:
23 |     q_mu = pickle.load(handle)
24 |     db_mu = pickle.load(handle)
25 |     q_sigma_sq = pickle.load(handle)
26 |     db_sigma_sq = pickle.load(handle)
27 |     preds = pickle.load(handle)
28 |     dists = pickle.load(handle)
29 |     gt = pickle.load(handle)
30 |     _ = pickle.load(handle)
31 |     _ = pickle.load(handle)
32 | 
33 | 
34 | 
35 | q_sigma_sq_h = np.mean(q_sigma_sq, axis=1)
36 | db_sigma_sq_h = np.mean(db_sigma_sq, axis=1)
37 | indices, _, _ = utils.get_zoomed_bins(q_sigma_sq_h, NUM_BINS)
38 | 
39 | n_values = [1, 5, 10]
40 | 
41 | # print(preds.shape)
42 | # print(gt)
43 | # ---------------------------- recognition metric ---------------------------- #
44 | recall = utils.cal_recall(preds, gt, n_values) / 100.0
45 | print('rec@1/5/10: {:.3f} / {:.3f} / {:.3f}'.format(recall[0], recall[1], recall[2]))
46 | mapk = [utils.cal_mapk(preds, gt, n) / 100.0 for n in n_values]
47 | print('mAP@1/5/10: {:.3f} / {:.3f} / {:.3f}'.format(mapk[0], mapk[1], mapk[2]))
48 | 
49 | if SHOW_AP:
50 |     recalls, precisions = utils.bin_pr(preds, dists, gt)
51 |     ap = 0
52 |     for index_j in range(len(recalls) - 1):
53 |         ap += precisions[index_j] * (recalls[index_j + 1] - recalls[index_j])
54 | 
55 |     print('AP: {:.3f}'.format(ap))
56 | 
57 | 
58 | 
--------------------------------------------------------------------------------
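
The `indices` from get_zoomed_bins are computed in vis.py but not consumed in the excerpt above. A hedged sketch of the per-bin evaluation they appear intended for (the helper name and the calibration reading are assumptions, in the style of the STUN baseline): report recall@1 inside each uncertainty bin, so a well-calibrated model shows recall falling as the predicted sigma^2 grows.

```python
# Per-uncertainty-bin recall sketch (assumed helper, illustrative only).
import utils


def recall_per_bin(q_sigma_sq_h, preds, gt, num_bins=11):
    """Print recall@1 inside each uncertainty bin from utils.get_zoomed_bins."""
    indices, _, _ = utils.get_zoomed_bins(q_sigma_sq_h, num_bins)
    for b, idx in enumerate(indices):
        if len(idx) == 0:
            continue  # skip empty bins
        r1 = utils.cal_recall(preds[idx], gt[idx], [1])[0]
        print('bin {:02d}: {:5d} queries, rec@1 = {:.1f}%'.format(b, len(idx), r1))


# e.g., with the arrays already loaded in vis.py:
# recall_per_bin(q_sigma_sq_h, preds, gt, NUM_BINS)
```
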