├── lib
│   ├── layers
│   │   ├── __init__.py
│   │   ├── normalization.py
│   │   ├── pooling.py
│   │   ├── loss.py
│   │   └── functional.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── general.py
│   │   ├── flops.py
│   │   ├── whiten.py
│   │   ├── evaluate.py
│   │   ├── download.py
│   │   └── download_win.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── testdataset.py
│   │   ├── datahelpers.py
│   │   ├── diffusion.py
│   │   ├── genericdataset.py
│   │   └── traindataset.py
│   ├── networks
│   │   ├── __init__.py
│   │   ├── mobilenet_v2.py
│   │   └── imageretrievalnet.py
│   ├── __init__.py
│   ├── parse_args.py
│   └── cli.py
├── dataset.py
├── README.md
├── training.py
├── download.py
├── extract_features.py
├── main.py
└── test.py

--------------------------------------------------------------------------------
/lib/layers/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/lib/utils/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/lib/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/lib/networks/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/lib/__init__.py:
--------------------------------------------------------------------------------
1 | from . import datasets, layers, networks, utils
2 | 
3 | from .datasets import datahelpers, genericdataset, testdataset, traindataset
4 | from .layers import functional, loss, normalization, pooling
5 | from .networks import imageretrievalnet
6 | from .utils import general, download, evaluate, whiten
--------------------------------------------------------------------------------
/lib/layers/normalization.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | 
4 | import lib.layers.functional as LF
5 | 
6 | # --------------------------------------
7 | # Normalization layers
8 | # --------------------------------------
9 | 
10 | class L2N(nn.Module):
11 | 
12 |     def __init__(self, eps=1e-6):
13 |         super(L2N, self).__init__()
14 |         self.eps = eps
15 | 
16 |     def forward(self, x):
17 |         return LF.l2n(x, eps=self.eps)
18 | 
19 |     def __repr__(self):
20 |         return self.__class__.__name__ + '(' + 'eps=' + str(self.eps) + ')'
21 | 
22 | 
23 | class PowerLaw(nn.Module):
24 | 
25 |     def __init__(self, eps=1e-6):
26 |         super(PowerLaw, self).__init__()
27 |         self.eps = eps
28 | 
29 |     def forward(self, x):
30 |         return LF.powerlaw(x, eps=self.eps)
31 | 
32 |     def __repr__(self):
33 |         return self.__class__.__name__ + '(' + 'eps=' + str(self.eps) + ')'
--------------------------------------------------------------------------------
/lib/utils/general.py:
--------------------------------------------------------------------------------
1 | import os
2 | import hashlib
3 | 
4 | def get_root():
5 |     return os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))))
6 | 
7 | 
8 | def get_data_root():
9 |     return os.path.join(get_root(), 'data')
10 | 
11 | 
12 | def htime(c):
13 |     c = round(c)
14 | 
15 |     days = c // 86400
16 |     hours = c // 3600 % 24
17 |     minutes = c // 60 % 60
18 |     seconds = c % 60
19 | 
20 |     if days > 0:
21 |         return '{:d}d {:d}h {:d}m {:d}s'.format(days, hours, minutes, seconds)
22 |     if hours > 0:
23 |         return '{:d}h {:d}m {:d}s'.format(hours, minutes, seconds)
24 |     if minutes > 0:
25 |         return '{:d}m {:d}s'.format(minutes, seconds)
26 |     return '{:d}s'.format(seconds)
27 | 
28 | 
29 | def sha256_hash(filename, block_size=65536, length=8):
30 |     sha256 = hashlib.sha256()
31 |     with open(filename, 'rb') as f:
32 |         for block in iter(lambda: f.read(block_size), b''):
33 |             sha256.update(block)
34 |     return sha256.hexdigest()[:length]
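These utilities are used throughout the download and training scripts for path resolution, timing logs, and file fingerprints. A minimal usage sketch (the archive path in the last line is hypothetical):

```python
from lib.utils.general import get_data_root, htime, sha256_hash

# resolve the repository data folder (created on first download)
print(get_data_root())                             # e.g. <repo>/data

# human-readable duration, as printed in the training logs
print(htime(3 * 86400 + 4 * 3600 + 5 * 60 + 6))    # '3d 4h 5m 6s'

# short content fingerprint of a downloaded archive (hypothetical path)
# print(sha256_hash('data/test/oxford5k/jpg/oxbuild_images.tgz'))
```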
--------------------------------------------------------------------------------
/lib/datasets/testdataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | import pdb
4 | 
5 | DATASETS = ['oxford5k', 'paris6k', 'roxford5k', 'rparis6k', 'retrieval-sfm-120k', 'revisitop1m', 'instre']
6 | 
7 | def configdataset(dataset, dir_main):
8 | 
9 |     dataset = dataset.lower()
10 |     if dataset not in DATASETS:
11 |         raise ValueError('Unknown dataset: {}!'.format(dataset))
12 | 
13 |     # loading imlist, qimlist, and gnd, in cfg as a dict
14 |     gnd_fname = os.path.join(dir_main, dataset, 'gnd_{}.pkl'.format(dataset))
15 |     with open(gnd_fname, 'rb') as f:
16 |         cfg = pickle.load(f)
17 |     cfg['gnd_fname'] = gnd_fname
18 | 
19 |     cfg['ext'] = '.jpg'
20 |     cfg['qext'] = '.jpg'
21 |     cfg['dir_data'] = os.path.join(dir_main, dataset)
22 |     cfg['dir_images'] = os.path.join(cfg['dir_data'], 'jpg')
23 | 
24 |     cfg['n'] = len(cfg['imlist'])
25 |     cfg['nq'] = len(cfg['qimlist'])
26 | 
27 |     cfg['im_fname'] = config_imname
28 |     cfg['qim_fname'] = config_qimname
29 | 
30 |     cfg['dataset'] = dataset
31 | 
32 |     return cfg
33 | 
34 | def config_imname(cfg, i):
35 |     return os.path.join(cfg['dir_images'], cfg['imlist'][i] + cfg['ext'])
36 | 
37 | def config_qimname(cfg, i):
38 |     return os.path.join(cfg['dir_images'], cfg['qimlist'][i] + cfg['qext'])
39 | 
--------------------------------------------------------------------------------
/lib/parse_args.py:
--------------------------------------------------------------------------------
1 | 
2 | def from_args_to_string(args):
3 |     # build the export directory name from the training arguments
4 |     directory = "{}".format(args.training_dataset)
5 |     directory += "_{}".format(args.arch)
6 |     directory += "_{}".format(args.pool)
7 |     if args.local_whitening:
8 |         directory += "_lwhiten"
9 |     if args.regional:
10 |         directory += "_r"
11 |     if args.whitening:
12 |         directory += "_whiten"
13 |     if not args.pretrained:
14 |         directory += "_notpretrained"
15 |     if args.test_whiten:
16 |         directory += "_test_whiten_on_{}".format(args.test_whiten)
17 |     directory += "_{}_m{:.2f}".format(args.loss, args.loss_margin)
18 |     directory += "_{}_lr{:.1e}_wd{:.1e}".format(args.optimizer, args.lr, args.weight_decay)
19 |     directory += "_nnum{}_qsize{}_psize{}".format(args.neg_num, args.query_size, args.pool_size)
20 |     directory += "_bsize{}_uevery{}_imsize{}".format(args.batch_size, args.update_every, args.image_size)
21 |     directory += "_temp{}".format(args.temp)
22 |     directory += "_{}".format(args.mode)
23 |     if args.mode == "ap":
24 |         directory += "_nexamples_{}".format(args.nexamples)
25 |     #if args.ts:
26 |     #    directory += "_ts"
27 |     #if args.reg:
28 |     #    directory += "_reg"
29 |     directory += "_teach_{}".format(args.teacher)
30 |     directory += args.comment
31 |     return directory
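The function encodes every hyperparameter that identifies a run into the export directory name. A minimal sketch of what it produces, with a hypothetical subset of the arguments that `main.py` parses:

```python
from argparse import Namespace
from lib.parse_args import from_args_to_string

# hypothetical argument values; the real defaults are defined in main.py
args = Namespace(training_dataset='retrieval-SfM-120k', arch='mobilenetv2', pool='gem',
                 local_whitening=False, regional=False, whitening=False, pretrained=True,
                 test_whiten=False, loss='contrastive', loss_margin=0.7, optimizer='adam',
                 lr=1e-6, weight_decay=1e-6, neg_num=5, query_size=2000, pool_size=20000,
                 batch_size=5, update_every=1, image_size=362, temp=1.0, mode='ts',
                 nexamples=0, teacher='vgg16', comment='')

print(from_args_to_string(args))
# retrieval-SfM-120k_mobilenetv2_gem_contrastive_m0.70_adam_lr1.0e-06_wd1.0e-06_...
```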
--------------------------------------------------------------------------------
/lib/utils/flops.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | 
3 | import torchvision.models as models
4 | import torch
5 | from ptflops import get_model_complexity_info
6 | 
7 | from lib.networks.imageretrievalnet import init_network
8 | 
9 | parser = argparse.ArgumentParser(description='PyTorch CNN Flops calculation')
10 | 
11 | # network
12 | group = parser.add_mutually_exclusive_group(required=True)
13 | group.add_argument('--network-path', '-npath', metavar='NETWORK',
14 |                    help="pretrained network or network path (destination where network is saved)")
15 | 
16 | 
17 | def main():
18 |     args = parser.parse_args()
19 | 
20 |     state = torch.load(args.network_path)
21 | 
22 |     net_params = {}
23 |     net_params['architecture'] = state['meta']['architecture']
24 |     net_params['pooling'] = state['meta']['pooling']
25 |     net_params['local_whitening'] = state['meta'].get('local_whitening', False)
26 |     net_params['regional'] = state['meta'].get('regional', False)
27 |     net_params['whitening'] = state['meta'].get('whitening', False)
28 |     net_params['mean'] = state['meta']['mean']
29 |     net_params['std'] = state['meta']['std']
30 |     net_params['pretrained'] = False
31 |     net_params['teacher'] = 'resnet101'
32 | 
33 |     net = init_network(net_params)
34 |     net.cuda()
35 |     flops, params = get_model_complexity_info(net, (3, 362, 362), as_strings=True, print_per_layer_stat=True)
36 |     print('{:<30} {:<8}'.format('Computational complexity: ', flops))
37 |     print('{:<30} {:<8}'.format('Number of parameters: ', params))
38 | 
39 | 
40 | if __name__ == '__main__':
41 |     main()
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | 
4 | DATASETS = ['roxford5k', 'rparis6k', 'revisitop1m']
5 | 
6 | def configdataset(dataset, dir_main):
7 | 
8 |     dataset = dataset.lower()
9 | 
10 |     if dataset not in DATASETS:
11 |         raise ValueError('Unknown dataset: {}!'.format(dataset))
12 | 
13 |     if dataset == 'roxford5k' or dataset == 'rparis6k':
14 |         # loading imlist, qimlist, and gnd, in cfg as a dict
15 |         gnd_fname = os.path.join(dir_main, dataset, 'gnd_{}.pkl'.format(dataset))
16 |         with open(gnd_fname, 'rb') as f:
17 |             cfg = pickle.load(f)
18 |         cfg['gnd_fname'] = gnd_fname
19 |         cfg['ext'] = '.jpg'
20 |         cfg['qext'] = '.jpg'
21 | 
22 |     elif dataset == 'revisitop1m':
23 |         # loading imlist from a .txt file
24 |         cfg = {}
25 |         cfg['imlist_fname'] = os.path.join(dir_main, dataset, '{}.txt'.format(dataset))
26 |         cfg['imlist'] = read_imlist(cfg['imlist_fname'])
27 |         cfg['qimlist'] = []
28 |         cfg['ext'] = ''
29 |         cfg['qext'] = ''
30 | 
31 |     cfg['dir_data'] = os.path.join(dir_main, dataset)
32 |     cfg['dir_images'] = os.path.join(cfg['dir_data'], 'jpg')
33 | 
34 |     cfg['n'] = len(cfg['imlist'])
35 |     cfg['nq'] = len(cfg['qimlist'])
36 | 
37 |     cfg['im_fname'] = config_imname
38 |     cfg['qim_fname'] = config_qimname
39 | 
40 |     cfg['dataset'] = dataset
41 | 
42 |     return cfg
43 | 
44 | def config_imname(cfg, i):
45 |     return os.path.join(cfg['dir_images'], cfg['imlist'][i] + cfg['ext'])
46 | 
47 | def config_qimname(cfg, i):
48 |     return os.path.join(cfg['dir_images'], cfg['qimlist'][i] + cfg['qext'])
49 | 
50 | def read_imlist(imlist_fn):
51 |     with open(imlist_fn, 'r') as file:
52 |         imlist = file.read().splitlines()
53 |     return imlist
54 | 
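`configdataset` returns a plain dict that the test scripts use to enumerate database and query images. A minimal sketch, assuming the ground-truth pickle has already been downloaded under `data/test/`:

```python
from dataset import configdataset

# assumes gnd_roxford5k.pkl exists under data/test/roxford5k/
cfg = configdataset('roxford5k', 'data/test')
print(cfg['n'], cfg['nq'])        # number of database and query images
print(cfg['im_fname'](cfg, 0))    # full path of the first database image
```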
--------------------------------------------------------------------------------
/lib/utils/whiten.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | 
4 | def whitenapply(X, m, P, dimensions=None):
5 | 
6 |     if not dimensions:
7 |         dimensions = P.shape[0]
8 | 
9 |     X = np.dot(P[:dimensions, :], X-m)
10 |     X = X / (np.linalg.norm(X, ord=2, axis=0, keepdims=True) + 1e-6)
11 | 
12 |     return X
13 | 
14 | def pcawhitenlearn(X):
15 | 
16 |     N = X.shape[1]
17 | 
18 |     # Learning PCA w/o annotations
19 |     m = X.mean(axis=1, keepdims=True)
20 |     Xc = X - m
21 |     Xcov = np.dot(Xc, Xc.T)
22 |     Xcov = (Xcov + Xcov.T) / (2*N)
23 |     eigval, eigvec = np.linalg.eig(Xcov)
24 |     order = eigval.argsort()[::-1]
25 |     eigval = eigval[order]
26 |     eigvec = eigvec[:, order]
27 | 
28 |     P = np.dot(np.linalg.inv(np.sqrt(np.diag(eigval))), eigvec.T)
29 | 
30 |     return m, P
31 | 
32 | def whitenlearn(X, qidxs, pidxs):
33 | 
34 |     # Learning Lw with annotations
35 | 
36 |     m = X[:, qidxs].mean(axis=1, keepdims=True)
37 |     df = X[:, qidxs] - X[:, pidxs]
38 |     S = np.dot(df, df.T) / df.shape[1]
39 |     P = np.linalg.inv(cholesky(S))
40 |     df = np.dot(P, X-m)
41 |     D = np.dot(df, df.T)
42 |     eigval, eigvec = np.linalg.eig(D)
43 |     order = eigval.argsort()[::-1]
44 |     eigval = eigval[order]
45 |     eigvec = eigvec[:, order]
46 | 
47 |     P = np.dot(eigvec.T, P)
48 | 
49 |     return m, P
50 | 
51 | def cholesky(S):
52 |     # Cholesky decomposition
53 |     # with adding a small value on the diagonal
54 |     # until matrix is positive definite
55 |     alpha = 0
56 |     while True:
57 |         try:
58 |             L = np.linalg.cholesky(S + alpha*np.eye(*S.shape))
59 |             return L
60 |         except np.linalg.LinAlgError:
61 |             if alpha == 0:
62 |                 alpha = 1e-10
63 |             else:
64 |                 alpha *= 10
65 |             print(">>>> {}::cholesky: Matrix is not positive definite, adding {:.0e} on the diagonal"
66 |                   .format(os.path.basename(__file__), alpha))
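Descriptors in this repository are stored column-wise (a D x N matrix). A minimal sketch of unsupervised PCA whitening with these helpers, on random data standing in for extracted features:

```python
import numpy as np
from lib.utils.whiten import pcawhitenlearn, whitenapply

X = np.random.randn(512, 2000)               # D x N stand-in for extracted descriptors
m, P = pcawhitenlearn(X)                      # mean and whitening projection
Xw = whitenapply(X, m, P, dimensions=256)     # project, truncate to 256-D, re-normalize
```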
--------------------------------------------------------------------------------
/lib/datasets/datahelpers.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image, ImageFile
3 | 
4 | import torch
5 | 
6 | 
7 | 
8 | def cid2filename(cid, prefix):
9 |     """
10 |     Creates a training image path out of its CID name
11 | 
12 |     Arguments
13 |     ---------
14 |     cid    : name of the image
15 |     prefix : root directory where images are saved
16 | 
17 |     Returns
18 |     -------
19 |     filename : full image filename
20 |     """
21 |     return os.path.join(prefix, cid[-2:], cid[-4:-2], cid[-6:-4], cid)
22 | 
23 | def pil_loader(path):
24 |     # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
25 |     ImageFile.LOAD_TRUNCATED_IMAGES = True
26 |     with open(path, 'rb') as f:
27 |         img = Image.open(f)
28 |         return img.convert('RGB')
29 | 
30 | def accimage_loader(path):
31 |     import accimage
32 |     try:
33 |         return accimage.Image(path)
34 |     except IOError:
35 |         # Potentially a decoding problem, fall back to PIL.Image
36 |         return pil_loader(path)
37 | 
38 | def default_loader(path):
39 |     from torchvision import get_image_backend
40 |     if get_image_backend() == 'accimage':
41 |         return accimage_loader(path)
42 |     else:
43 |         return pil_loader(path)
44 | 
45 | def imresize(img, imsize):
46 |     img.thumbnail((imsize, imsize), Image.ANTIALIAS)
47 |     return img
48 | 
49 | def flip(x, dim):
50 |     xsize = x.size()
51 |     dim = x.dim() + dim if dim < 0 else dim
52 |     x = x.view(-1, *xsize[dim:])
53 |     x = x.view(x.size(0), x.size(1), -1)[:, getattr(torch.arange(x.size(1)-1, -1, -1), ('cpu','cuda')[x.is_cuda])().long(), :]
54 |     return x.view(xsize)
55 | 
56 | def collate_tuples(batch):
57 |     if len(batch) == 1:
58 |         return [batch[0][0]], [batch[0][1]]
59 |     return [batch[i][0] for i in range(len(batch))], [batch[i][1] for i in range(len(batch))]
60 | 
61 | def collate_tuples_dist(batch):
62 |     if len(batch) == 1:
63 |         return [batch[0][0]], [batch[0][1]], [batch[0][2]]
64 |     return [batch[i][0] for i in range(len(batch))], [batch[i][1] for i in range(len(batch))], [batch[i][2] for i in range(len(batch))]
--------------------------------------------------------------------------------
/lib/layers/pooling.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn.parameter import Parameter
4 | 
5 | import lib.layers.functional as LF
6 | from lib.layers.normalization import L2N
7 | 
8 | # --------------------------------------
9 | # Pooling layers
10 | # --------------------------------------
11 | 
12 | class MAC(nn.Module):
13 | 
14 |     def __init__(self):
15 |         super(MAC, self).__init__()
16 | 
17 |     def forward(self, x):
18 |         return LF.mac(x)
19 | 
20 |     def __repr__(self):
21 |         return self.__class__.__name__ + '()'
22 | 
23 | 
24 | class SPoC(nn.Module):
25 | 
26 |     def __init__(self):
27 |         super(SPoC, self).__init__()
28 | 
29 |     def forward(self, x):
30 |         return LF.spoc(x)
31 | 
32 |     def __repr__(self):
33 |         return self.__class__.__name__ + '()'
34 | 
35 | 
36 | class GeM(nn.Module):
37 | 
38 |     def __init__(self, p=3, eps=1e-6):
39 |         super(GeM, self).__init__()
40 |         self.p = Parameter(torch.ones(1)*p)
41 |         self.eps = eps
42 | 
43 |     def forward(self, x):
44 |         return LF.gem(x, p=self.p, eps=self.eps)
45 | 
46 |     def __repr__(self):
47 |         return self.__class__.__name__ + '(' + 'p=' + '{:.4f}'.format(self.p.data.tolist()[0]) + ', ' + 'eps=' + str(self.eps) + ')'
48 | 
49 | class GeMmp(nn.Module):
50 | 
51 |     def __init__(self, p=3, mp=1, eps=1e-6):
52 |         super(GeMmp, self).__init__()
53 |         self.p = Parameter(torch.ones(mp)*p)
54 |         self.mp = mp
55 |         self.eps = eps
56 | 
57 |     def forward(self, x):
58 |         return LF.gem(x, p=self.p.unsqueeze(-1).unsqueeze(-1), eps=self.eps)
59 | 
60 |     def __repr__(self):
61 |         return self.__class__.__name__ + '(' + 'p=' + '[{}]'.format(self.mp) + ', ' + 'eps=' + str(self.eps) + ')'
62 | 
63 | class RMAC(nn.Module):
64 | 
65 |     def __init__(self, L=3, eps=1e-6):
66 |         super(RMAC, self).__init__()
67 |         self.L = L
68 |         self.eps = eps
69 | 
70 |     def forward(self, x):
71 |         return LF.rmac(x, L=self.L, eps=self.eps)
72 | 
73 |     def __repr__(self):
74 |         return self.__class__.__name__ + '(' + 'L=' + '{}'.format(self.L) + ')'
75 | 
76 | 
77 | class Rpool(nn.Module):
78 | 
79 |     def __init__(self, rpool, whiten=None, L=3, eps=1e-6):
80 |         super(Rpool, self).__init__()
81 |         self.rpool = rpool
82 |         self.L = L
83 |         self.whiten = whiten
84 |         self.norm = L2N()
85 |         self.eps = eps
86 | 
87 |     def forward(self, x, aggregate=True):
88 |         # features -> roipool
89 |         o = LF.roipool(x, self.rpool, self.L, self.eps) # size: #im, #reg, D, 1, 1
90 | 
91 |         # concatenate regions from all images in the batch
92 |         s = o.size()
93 |         o = o.view(s[0]*s[1], s[2], s[3], s[4]) # size: #im x #reg, D, 1, 1
94 | 
95 |         # rvecs -> norm
96 |         o = self.norm(o)
97 | 
98 |         # rvecs -> whiten -> norm
99 |         if self.whiten is not None:
100 |             o = self.norm(self.whiten(o.squeeze(-1).squeeze(-1)))
101 | 
102 |         # reshape back to regions per image
103 |         o = o.view(s[0], s[1], s[2], s[3], s[4]) # size: #im, #reg, D, 1, 1
104 | 
105 |         # aggregate regions into a single global vector per image
106 |         if aggregate:
107 |             # rvecs -> sumpool -> norm
108 |             o = self.norm(o.sum(1, keepdim=False)) # size: #im, D, 1, 1
109 | 
110 |         return o
111 | 
112 |     def __repr__(self):
113 |         return super(Rpool, self).__repr__() + '(' + 'L=' + '{}'.format(self.L) + ')'
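All pooling layers map a batch of CNN feature maps (B x D x H x W) to one global descriptor per image. A minimal sketch with GeM, whose pooling exponent `p` is a learnable parameter:

```python
import torch
from lib.layers.pooling import GeM

x = torch.randn(2, 512, 16, 16)   # a batch of CNN feature maps: B x D x H x W
pool = GeM(p=3)                   # generalized-mean pooling with learnable p
v = pool(x)                       # B x D x 1 x 1 global descriptors
print(pool)                       # GeM(p=3.0000, eps=1e-06)
```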
--------------------------------------------------------------------------------
/lib/datasets/diffusion.py:
--------------------------------------------------------------------------------
1 | 
2 | import os
3 | import numpy as np
4 | from scipy.io import loadmat
5 | from scipy.sparse import csr_matrix, eye, diags
6 | from scipy.sparse import linalg as s_linalg
7 | 
8 | def sim_kernel(dot_product):
9 |     return np.maximum(np.power(dot_product, 3), 0)
10 | 
11 | def normalize_connection_graph(G):
12 |     W = csr_matrix(G)
13 |     W = W - diags(W.diagonal())
14 |     D = np.array(1. / np.sqrt(W.sum(axis=1)))
15 |     D[np.isnan(D)] = 0
16 |     D[np.isinf(D)] = 0
17 |     D_mh = diags(D.reshape(-1))
18 |     Wn = D_mh * W * D_mh
19 |     return Wn
20 | 
21 | def topK_W(G, K=100):
22 |     sortidxs = np.argsort(-G, axis=1)
23 |     for i in range(G.shape[0]):
24 |         G[i, sortidxs[i, K:]] = 0
25 |     G = np.minimum(G, G.T)
26 |     return G
27 | 
28 | def find_trunc_graph(qs, W, levels=3):
29 |     needed_idxs = []
30 |     needed_idxs = list(np.nonzero(qs > 0)[0])
31 |     for l in range(levels):
32 |         idid = W.nonzero()[1]
33 |         needed_idxs.extend(list(idid))
34 |     needed_idxs = list(set(needed_idxs))
35 |     return np.array(needed_idxs), W[needed_idxs, :][:, needed_idxs]
36 | 
37 | def dfs_trunk(sim, A, alpha=0.99, QUERYKNN=10, maxiter=8, K=100, tol=1e-3):
38 |     qsim = sim_kernel(sim).T
39 |     sortidxs = np.argsort(-qsim, axis=1)
40 |     for i in range(len(qsim)):
41 |         qsim[i, sortidxs[i, QUERYKNN:]] = 0
42 |     qsims = sim_kernel(qsim)
43 |     W = sim_kernel(A)
44 |     W = csr_matrix(topK_W(W, K))
45 |     out_ranks = []
46 |     # t = time()
47 |     for i in range(qsims.shape[0]):
48 |         qs = qsims[i, :]
49 |         # tt = time()
50 |         w_idxs, W_trunk = find_trunc_graph(qs, W, 2)
51 |         Wn = normalize_connection_graph(W_trunk)
52 |         Wnn = eye(Wn.shape[0]) - alpha * Wn
53 |         f, inf = s_linalg.minres(Wnn, qs[w_idxs], tol=tol, maxiter=maxiter)
54 |         ranks = w_idxs[np.argsort(-f.reshape(-1))]
55 |         missing = np.setdiff1d(np.arange(A.shape[1]), ranks)
56 |         out_ranks.append(np.concatenate([ranks.reshape(-1, 1), missing.reshape(-1, 1)], axis=0))
57 |     # print(time() - t, 'qtime')
58 |     out_ranks = np.concatenate(out_ranks, axis=1)
59 |     return out_ranks
60 | 
61 | def cg_diffusion(qsims, Wn, alpha=0.99, maxiter=20, tol=1e-6):
62 |     Wnn = eye(Wn.shape[0]) - alpha * Wn
63 | 
64 |     out_sims = []
65 |     for i in range(qsims.shape[0]):
66 |         f, inf = s_linalg.cg(Wnn, qsims[i, :], tol=tol, maxiter=maxiter)
67 |         # f, inf = s_linalg.minres(Wnn, qsims[i, :], tol=tol, maxiter=maxiter)
68 |         out_sims.append(f.reshape(-1, 1))
69 | 
70 |     out_sims = np.concatenate(out_sims, axis=1)
71 | 
72 |     ranks = np.argsort(-out_sims, axis=0)
73 | 
74 |     return ranks, out_sims
75 | 
76 | def cg_diffusion_sel(qsims, Wnn, maxiter=20, tol=1e-6):
77 | 
78 | 
79 |     f, inf = s_linalg.cg(Wnn, qsims, tol=tol, maxiter=maxiter)
80 | 
81 | 
82 |     ranks = np.argsort(-f, axis=0)
83 | 
84 |     return ranks, f
85 | 
86 | def fsr_rankR(qsims, Wn, alpha=0.99, R=2000):
87 |     vals, vecs = s_linalg.eigsh(Wn, k=R)
88 |     p2 = diags((1.0 - alpha) / (1.0 - alpha*vals))
89 |     vc = csr_matrix(vecs)
90 |     p3 = vc.dot(p2)
91 |     vc_norm = (vc.multiply(vc)).sum(axis=0)
92 |     out_sims = []
93 |     for i in range(qsims.shape[0]):
94 |         qsims_sparse = csr_matrix(qsims[i:i+1, :])
95 |         p1 = (vc.T).dot(qsims_sparse.T)
96 |         diff_sim = csr_matrix(p3)*csr_matrix(p1)
97 |         out_sims.append(diff_sim.todense().reshape(-1, 1))
98 |     out_sims = np.concatenate(out_sims, axis=1)
99 |     ranks = np.argsort(-out_sims, axis=0)
100 |     return ranks
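The module implements query-time diffusion over a truncated kNN graph of the database descriptors. A minimal sketch of the intended call pattern, inferred from the code above (random vectors stand in for L2-normalized descriptors; shapes are assumptions):

```python
import numpy as np
from scipy.sparse import csr_matrix
from lib.datasets.diffusion import sim_kernel, topK_W, normalize_connection_graph, cg_diffusion

rng = np.random.RandomState(0)
X = rng.randn(128, 1000); X /= np.linalg.norm(X, axis=0, keepdims=True)   # database, D x N
Q = rng.randn(128, 5);    Q /= np.linalg.norm(Q, axis=0, keepdims=True)   # queries,  D x Nq

W = csr_matrix(topK_W(sim_kernel(np.dot(X.T, X)), K=50))   # truncated affinity graph
Wn = normalize_connection_graph(W)                         # symmetric normalization
qsims = sim_kernel(np.dot(Q.T, X))                         # Nq x N query similarities
ranks, sims = cg_diffusion(qsims, Wn, alpha=0.99)          # ranks: N x Nq
```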
--------------------------------------------------------------------------------
/lib/datasets/genericdataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pdb
3 | 
4 | import torch
5 | import torch.utils.data as data
6 | 
7 | from lib.datasets.datahelpers import default_loader, imresize
8 | 
9 | 
10 | class ImagesFromList(data.Dataset):
11 |     """A generic data loader that loads images from a list
12 |     (based on ImageFolder from pytorch)
13 | 
14 |     Args:
15 |         root (string): Root directory path.
16 |         images (list): Relative image paths as strings.
17 |         imsize (int, Default: None): Defines the maximum size of the longer image side
18 |         bbxs (list): List of (x1,y1,x2,y2) tuples to crop the query images
19 |         transform (callable, optional): A function/transform that takes in a PIL image
20 |             and returns a transformed version. E.g., ``transforms.RandomCrop``
21 |         loader (callable, optional): A function to load an image given its path.
22 | 
23 |     Attributes:
24 |         images_fn (list): List of full image filenames
25 |     """
26 | 
27 |     def __init__(self, root, images, imsize=None, bbxs=None, transform=None, loader=default_loader):
28 | 
29 |         images_fn = [os.path.join(root, images[i]) for i in range(len(images))]
30 | 
31 |         if len(images_fn) == 0:
32 |             raise(RuntimeError("Dataset contains 0 images!"))
33 | 
34 |         self.root = root
35 |         self.images = images
36 |         self.imsize = imsize
37 |         self.images_fn = images_fn
38 |         self.bbxs = bbxs
39 |         self.transform = transform
40 |         self.loader = loader
41 | 
42 |     def __getitem__(self, index):
43 |         """
44 |         Args:
45 |             index (int): Index
46 | 
47 |         Returns:
48 |             image (PIL): Loaded image
49 |         """
50 |         path = self.images_fn[index]
51 |         img = self.loader(path)
52 |         imfullsize = max(img.size)
53 | 
54 |         if self.bbxs is not None:
55 |             img = img.crop(self.bbxs[index])
56 | 
57 |         if self.imsize is not None:
58 |             if self.bbxs is not None:
59 |                 img = imresize(img, self.imsize * max(img.size) / imfullsize)
60 |             else:
61 |                 img = imresize(img, self.imsize)
62 | 
63 |         if self.transform is not None:
64 |             img = self.transform(img)
65 | 
66 |         return img
67 | 
68 |     def __len__(self):
69 |         return len(self.images_fn)
70 | 
71 |     def __repr__(self):
72 |         fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
73 |         fmt_str += '    Number of images: {}\n'.format(self.__len__())
74 |         fmt_str += '    Root Location: {}\n'.format(self.root)
75 |         tmp = '    Transforms (if any): '
76 |         fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
77 |         return fmt_str
78 | 
79 | class ImagesFromDataList(data.Dataset):
80 |     """A generic data loader that loads images given as an array of pytorch tensors
81 |     (based on ImageFolder from pytorch)
82 | 
83 |     Args:
84 |         images (list): Images as tensors.
85 |         transform (callable, optional): A function/transform that takes an image tensor
86 |             and returns a transformed version. E.g., ``normalize`` with mean and std
87 |     """
88 | 
89 |     def __init__(self, images, transform=None):
90 | 
91 |         if len(images) == 0:
92 |             raise(RuntimeError("Dataset contains 0 images!"))
93 | 
94 |         self.images = images
95 |         self.transform = transform
96 | 
97 |     def __getitem__(self, index):
98 |         """
99 |         Args:
100 |             index (int): Index
101 | 
102 |         Returns:
103 |             image (Tensor): Loaded image
104 |         """
105 |         img = self.images[index]
106 |         if self.transform is not None:
107 |             img = self.transform(img)
108 | 
109 |         if len(img.size()):
110 |             img = img.unsqueeze(0)
111 | 
112 |         return img
113 | 
114 |     def __len__(self):
115 |         return len(self.images)
116 | 
117 |     def __repr__(self):
118 |         fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
119 |         fmt_str += '    Number of images: {}\n'.format(self.__len__())
120 |         tmp = '    Transforms (if any): '
121 |         fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
122 |         return fmt_str
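`ImagesFromList` is the loader used when extracting descriptors for whole datasets. A minimal sketch of wiring it to a `DataLoader` (the root directory and image name below are hypothetical; in the test scripts they come from `configdataset()`):

```python
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from lib.datasets.genericdataset import ImagesFromList

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
dataset = ImagesFromList(root='data/test/roxford5k/jpg',          # hypothetical root
                         images=['all_souls_000013.jpg'],         # hypothetical image name
                         imsize=1024,                             # longer side resized to 1024
                         transform=transforms.Compose([transforms.ToTensor(), normalize]))
loader = DataLoader(dataset, batch_size=1, shuffle=False)
```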
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Asymmetric metric learning
2 | 
3 | This is the official code that enables the reproduction of the results from our paper:
4 | 
5 | **Asymmetric metric learning for knowledge transfer**,
6 | Budnik M., Avrithis Y.
7 | [[arXiv](https://arxiv.org/abs/2006.16331)]
8 | 
9 | ### Content
10 | 
11 | This repository provides the means to train and test all the models presented in the paper. This includes:
12 | 
13 | 1. Code to train the models with and without the teacher (asymmetric and symmetric).
14 | 1. Code to do symmetric and asymmetric testing on rOxford and rParis datasets.
15 | 1. Best pre-trained models (including whitening).
16 | 
17 | ### Dependencies
18 | 
19 | 1. Python3 (tested on version 3.6)
20 | 1. Numpy 1.19
21 | 1. PyTorch (tested on version 1.4.0)
22 | 1. Datasets and base models will be downloaded automatically.
23 | 
24 | 
25 | ### Training and testing the networks
26 | 
27 | To train a model, use the following script:
28 | ```bash
29 | python main.py [-h] [--training-dataset DATASET] [--directory EXPORT_DIR] [--no-val]
30 |                [--test-datasets DATASETS] [--test-whiten DATASET]
31 |                [--val-freq N] [--save-freq N] [--arch ARCH] [--pool POOL]
32 |                [--local-whitening] [--regional] [--whitening]
33 |                [--not-pretrained] [--loss LOSS] [--loss-margin LM]
34 |                [--mode MODE] [--teacher TEACHER] [--sym]
35 |                [--feat-path FEAT] [--feat-val-path FEATVAL]
36 |                [--image-size N] [--neg-num N] [--query-size N]
37 |                [--pool-size N] [--gpu-id N] [--workers N] [--epochs N]
38 |                [--batch-size N] [--optimizer OPTIMIZER] [--lr LR]
39 |                [--momentum M] [--weight-decay W] [--print-freq N]
40 |                [--resume FILENAME] [--comment COMMENT]
41 | 
42 | ```
43 | Most parameters are the same as in [CNN Image Retrieval in PyTorch](https://github.com/filipradenovic/cnnimageretrieval-pytorch). Here, we describe the parameters added or modified in this work, namely:
44 | --arch - the architecture of the model to be trained, in our case the student.
45 | --mode - the training mode, which determines how the dataset is handled: e.g., whether the tuples are constructed randomly or with mining, and which examples come from the teacher vs. the student. For example, while --loss is set to 'contrastive', 'ts' enables standard student-teacher training (includes mining), 'ts_self' trains using the Contr+ approach, and 'reg' uses regression. When using 'rand' or 'reg', no mining is used.
With 'std' it follows the original training protocol from [here](https://github.com/filipradenovic/cnnimageretrieval-pytorch) (the teacher model is not used). 46 | --teacher - the model of the teacher(vgg16 or resnet101), note that this param makes the last layer of the student match that of the teacher. Therefore, this can be used even in a standard symmetric training. 47 | --sym - a flag that indicates if the training should be symmetric or asymmetric. 48 | --feat-path and --feat-val-path - a path to the extracted teacher features used to train the student. The features can be extracted using the extract_features.py script. 49 | 50 | To perform a symmetric test of the model that is already trained: 51 | ```bash 52 | python test.py [-h] (--network-path NETWORK | --network-offtheshelf NETWORK) 53 | [--datasets DATASETS] [--image-size N] [--multiscale MULTISCALE] 54 | [--whitening WHITENING] [--teacher TEACHER] 55 | ``` 56 | For the asymmetric testing: 57 | 58 | ```bash 59 | python test.py [-h] (--network-path NETWORK | --network-offtheshelf NETWORK) 60 | [--datasets DATASETS] [--image-size N] [--multiscale MULTISCALE] 61 | [--whitening WHITENING] [--teacher TEACHER] [--asym] 62 | ``` 63 | 64 | Examples: 65 | 66 | Perform a symmetric test with a pre-trained model: 67 | 68 | ```bash 69 | 70 | python test.py -npath mobilenet-v2-gem-contr-vgg16 -d 'roxford5k,rparis6k' -ms '[1, 1/2**(1/2), 1/2]' -w retrieval-SfM-120k --teacher vgg16 71 | ``` 72 | 73 | For an asymmetric test: 74 | 75 | ```bash 76 | 77 | python test.py -npath mobilenet-v2-gem-contr-vgg16 -d 'roxford5k,rparis6k' -ms '[1, 1/2**(1/2), 1/2]' -w retrieval-SfM-120k --teacher vgg16 --asym 78 | ``` 79 | 80 | If you are interested in just the trained models, you can find the links to them in the test.py file. 81 | 82 | ### Acknowledgements 83 | 84 | This code is adapted and modified based on the amazing repository by F. Radenović called 85 | [CNN Image Retrieval in PyTorch: Training and evaluating CNNs for Image Retrieval in PyTorch](https://github.com/filipradenovic/cnnimageretrieval-pytorch) 86 | 87 | -------------------------------------------------------------------------------- /lib/networks/mobilenet_v2.py: -------------------------------------------------------------------------------- 1 | 2 | import torch.nn as nn 3 | import math 4 | 5 | __all__ = ['mobilenetv2'] 6 | 7 | 8 | def _make_divisible(v, divisor, min_value=None): 9 | """ 10 | This function is taken from the original tf repo. 11 | It ensures that all layers have a channel number that is divisible by 8 12 | It can be seen here: 13 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py 14 | :param v: 15 | :param divisor: 16 | :param min_value: 17 | :return: 18 | """ 19 | if min_value is None: 20 | min_value = divisor 21 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 22 | # Make sure that round down does not go down by more than 10%. 
23 | if new_v < 0.9 * v: 24 | new_v += divisor 25 | return new_v 26 | 27 | 28 | def conv_3x3_bn(inp, oup, stride): 29 | return nn.Sequential( 30 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 31 | nn.BatchNorm2d(oup), 32 | nn.ReLU6(inplace=True) 33 | ) 34 | 35 | 36 | def conv_1x1_bn(inp, oup): 37 | return nn.Sequential( 38 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 39 | nn.BatchNorm2d(oup), 40 | nn.ReLU6(inplace=True) 41 | ) 42 | 43 | 44 | class InvertedResidual(nn.Module): 45 | def __init__(self, inp, oup, stride, expand_ratio): 46 | super(InvertedResidual, self).__init__() 47 | assert stride in [1, 2] 48 | 49 | hidden_dim = round(inp * expand_ratio) 50 | self.identity = stride == 1 and inp == oup 51 | 52 | if expand_ratio == 1: 53 | self.conv = nn.Sequential( 54 | # dw 55 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), 56 | nn.BatchNorm2d(hidden_dim), 57 | nn.ReLU6(inplace=True), 58 | # pw-linear 59 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 60 | nn.BatchNorm2d(oup), 61 | ) 62 | else: 63 | self.conv = nn.Sequential( 64 | # pw 65 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False), 66 | nn.BatchNorm2d(hidden_dim), 67 | nn.ReLU6(inplace=True), 68 | # dw 69 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), 70 | nn.BatchNorm2d(hidden_dim), 71 | nn.ReLU6(inplace=True), 72 | # pw-linear 73 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 74 | nn.BatchNorm2d(oup), 75 | ) 76 | 77 | def forward(self, x): 78 | if self.identity: 79 | return x + self.conv(x) 80 | else: 81 | return self.conv(x) 82 | 83 | 84 | class MobileNetV2(nn.Module): 85 | def __init__(self, num_classes=1000, width_mult=1.): 86 | super(MobileNetV2, self).__init__() 87 | # setting of inverted residual blocks 88 | self.cfgs = [ 89 | # t, c, n, s 90 | [1, 16, 1, 1], 91 | [6, 24, 2, 2], 92 | [6, 32, 3, 2], 93 | [6, 64, 4, 2], 94 | [6, 96, 3, 1], 95 | [6, 160, 3, 2], 96 | [6, 320, 1, 1], 97 | ] 98 | 99 | # building first layer 100 | input_channel = _make_divisible(32 * width_mult, 4 if width_mult == 0.1 else 8) 101 | layers = [conv_3x3_bn(3, input_channel, 2)] 102 | # building inverted residual blocks 103 | block = InvertedResidual 104 | for t, c, n, s in self.cfgs: 105 | output_channel = _make_divisible(c * width_mult, 4 if width_mult == 0.1 else 8) 106 | for i in range(n): 107 | layers.append(block(input_channel, output_channel, s if i == 0 else 1, t)) 108 | input_channel = output_channel 109 | self.features = nn.Sequential(*layers) 110 | # building last several layers 111 | output_channel = _make_divisible(1280 * width_mult, 4 if width_mult == 0.1 else 8) if width_mult > 1.0 else 1280 112 | self.conv = conv_1x1_bn(input_channel, output_channel) 113 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 114 | self.classifier = nn.Linear(output_channel, num_classes) 115 | 116 | self._initialize_weights() 117 | 118 | def forward(self, x): 119 | x = self.features(x) 120 | x = self.conv(x) 121 | x = self.avgpool(x) 122 | x = x.view(x.size(0), -1) 123 | x = self.classifier(x) 124 | return x 125 | 126 | def _initialize_weights(self): 127 | for m in self.modules(): 128 | if isinstance(m, nn.Conv2d): 129 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 130 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 131 | if m.bias is not None: 132 | m.bias.data.zero_() 133 | elif isinstance(m, nn.BatchNorm2d): 134 | m.weight.data.fill_(1) 135 | m.bias.data.zero_() 136 | elif isinstance(m, nn.Linear): 137 | m.weight.data.normal_(0, 0.01) 138 | m.bias.data.zero_() 139 | 140 | def mobilenetv2(**kwargs): 141 | """ 142 | Constructs a MobileNet V2 model 143 | """ 144 | return MobileNetV2(**kwargs) 145 | 146 | -------------------------------------------------------------------------------- /lib/utils/evaluate.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def compute_ap(ranks, nres): 4 | """ 5 | Computes average precision for given ranked indexes. 6 | 7 | Arguments 8 | --------- 9 | ranks : zerro-based ranks of positive images 10 | nres : number of positive images 11 | 12 | Returns 13 | ------- 14 | ap : average precision 15 | """ 16 | 17 | # number of images ranked by the system 18 | nimgranks = len(ranks) 19 | 20 | # accumulate trapezoids in PR-plot 21 | ap = 0 22 | 23 | recall_step = 1. / nres 24 | 25 | for j in np.arange(nimgranks): 26 | rank = ranks[j] 27 | 28 | if rank == 0: 29 | precision_0 = 1. 30 | else: 31 | precision_0 = float(j) / rank 32 | 33 | precision_1 = float(j + 1) / (rank + 1) 34 | 35 | ap += (precision_0 + precision_1) * recall_step / 2. 36 | 37 | return ap 38 | 39 | def compute_map(ranks, gnd, kappas=[]): 40 | """ 41 | Computes the mAP for a given set of returned results. 42 | 43 | Usage: 44 | map = compute_map (ranks, gnd) 45 | computes mean average precsion (map) only 46 | 47 | map, aps, pr, prs = compute_map (ranks, gnd, kappas) 48 | computes mean average precision (map), average precision (aps) for each query 49 | computes mean precision at kappas (pr), precision at kappas (prs) for each query 50 | 51 | Notes: 52 | 1) ranks starts from 0, ranks.shape = db_size X #queries 53 | 2) The junk results (e.g., the query itself) should be declared in the gnd stuct array 54 | 3) If there are no positive images for some query, that query is excluded from the evaluation 55 | """ 56 | 57 | map = 0. 
58 | nq = len(gnd) # number of queries 59 | aps = np.zeros(nq) 60 | pr = np.zeros(len(kappas)) 61 | prs = np.zeros((nq, len(kappas))) 62 | nempty = 0 63 | 64 | for i in np.arange(nq): 65 | qgnd = np.array(gnd[i]['ok']) 66 | 67 | # no positive images, skip from the average 68 | if qgnd.shape[0] == 0: 69 | aps[i] = float('nan') 70 | prs[i, :] = float('nan') 71 | nempty += 1 72 | continue 73 | 74 | try: 75 | qgndj = np.array(gnd[i]['junk']) 76 | except: 77 | qgndj = np.empty(0) 78 | 79 | # sorted positions of positive and junk images (0 based) 80 | pos = np.arange(ranks.shape[0])[np.in1d(ranks[:,i], qgnd)] 81 | junk = np.arange(ranks.shape[0])[np.in1d(ranks[:,i], qgndj)] 82 | 83 | k = 0; 84 | ij = 0; 85 | if len(junk): 86 | # decrease positions of positives based on the number of 87 | # junk images appearing before them 88 | ip = 0 89 | while (ip < len(pos)): 90 | while (ij < len(junk) and pos[ip] > junk[ij]): 91 | k += 1 92 | ij += 1 93 | pos[ip] = pos[ip] - k 94 | ip += 1 95 | 96 | # compute ap 97 | ap = compute_ap(pos, len(qgnd)) 98 | map = map + ap 99 | aps[i] = ap 100 | 101 | # compute precision @ k 102 | pos += 1 # get it to 1-based 103 | for j in np.arange(len(kappas)): 104 | kq = min(max(pos), kappas[j]); 105 | prs[i, j] = (pos <= kq).sum() / kq 106 | pr = pr + prs[i, :] 107 | 108 | map = map / (nq - nempty) 109 | pr = pr / (nq - nempty) 110 | 111 | return map, aps, pr, prs 112 | 113 | 114 | def compute_map_and_print(dataset, ranks, gnd, log, kappas=[1, 5, 10]): 115 | 116 | # old evaluation protocol 117 | if dataset.startswith('oxford5k') or dataset.startswith('paris6k') or dataset.startswith('instre'): 118 | map, aps, _, _ = compute_map(ranks, gnd) 119 | res = '>> {}: mAP {:.2f}'.format(dataset, np.around(map*100, decimals=2)) 120 | print(res) 121 | log.write(res+'\n') 122 | # new evaluation protocol 123 | elif dataset.startswith('roxford5k') or dataset.startswith('rparis6k'): 124 | 125 | gnd_t = [] 126 | for i in range(len(gnd)): 127 | g = {} 128 | g['ok'] = np.concatenate([gnd[i]['easy']]) 129 | g['junk'] = np.concatenate([gnd[i]['junk'], gnd[i]['hard']]) 130 | gnd_t.append(g) 131 | mapE, apsE, mprE, prsE = compute_map(ranks, gnd_t, kappas) 132 | 133 | gnd_t = [] 134 | for i in range(len(gnd)): 135 | g = {} 136 | g['ok'] = np.concatenate([gnd[i]['easy'], gnd[i]['hard']]) 137 | g['junk'] = np.concatenate([gnd[i]['junk']]) 138 | gnd_t.append(g) 139 | mapM, apsM, mprM, prsM = compute_map(ranks, gnd_t, kappas) 140 | 141 | gnd_t = [] 142 | for i in range(len(gnd)): 143 | g = {} 144 | g['ok'] = np.concatenate([gnd[i]['hard']]) 145 | g['junk'] = np.concatenate([gnd[i]['junk'], gnd[i]['easy']]) 146 | gnd_t.append(g) 147 | mapH, apsH, mprH, prsH = compute_map(ranks, gnd_t, kappas) 148 | res_1 = '>> {}: mAP E: {}, M: {}, H: {}'.format(dataset, np.around(mapE*100, decimals=2), np.around(mapM*100, decimals=2), np.around(mapH*100, decimals=2)) 149 | res_2 = '>> {}: mP@k{} E: {}, M: {}, H: {}'.format(dataset, kappas, np.around(mprE*100, decimals=2), np.around(mprM*100, decimals=2), np.around(mprH*100, decimals=2)) 150 | print(res_1) 151 | print(res_2) 152 | log.write(res_1+'\n') 153 | log.write(res_2+'\n') -------------------------------------------------------------------------------- /lib/layers/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import lib.layers.functional as LF 5 | import torch.nn.functional as F 6 | import pdb 7 | 8 | 9 | 10 | # -------------------------------------- 
11 | # Loss/Error layers 12 | # -------------------------------------- 13 | 14 | class ContrastiveLoss(nn.Module): 15 | r"""CONTRASTIVELOSS layer that computes contrastive loss for a batch of images: 16 | Q query tuples, each packed in the form of (q,p,n1,..nN) 17 | 18 | Args: 19 | x: tuples arranges in columns as [q,p,n1,nN, ... ] 20 | label: -1 for query, 1 for corresponding positive, 0 for corresponding negative 21 | margin: contrastive loss margin. Default: 0.7 22 | 23 | >>> contrastive_loss = ContrastiveLoss(margin=0.7) 24 | >>> input = torch.randn(128, 35, requires_grad=True) 25 | >>> label = torch.Tensor([-1, 1, 0, 0, 0, 0, 0] * 5) 26 | >>> output = contrastive_loss(input, label) 27 | >>> output.backward() 28 | """ 29 | 30 | def __init__(self, margin=0.7, eps=1e-6): 31 | super(ContrastiveLoss, self).__init__() 32 | self.margin = margin 33 | self.eps = eps 34 | 35 | def forward(self, x, label): 36 | return LF.contrastive_loss(x, label, margin=self.margin, eps=self.eps) 37 | 38 | def __repr__(self): 39 | return self.__class__.__name__ + '(' + 'margin=' + '{:.4f}'.format(self.margin) + ')' 40 | 41 | 42 | 43 | class ContrastiveDistLoss(nn.Module): 44 | r"""CONTRASTIVELOSS layer that computes contrastive loss for a batch of images: 45 | Q query tuples, each packed in the form of (q,p,n1,..nN) 46 | 47 | Args: 48 | x: tuples arranges in columns as [q,p,n1,nN, ... ] 49 | label: -1 for query, 1 for corresponding positive, 0 for corresponding negative 50 | margin: contrastive loss margin. Default: 0.7 51 | 52 | >>> contrastive_loss = ContrastiveLoss(margin=0.7) 53 | >>> input = torch.randn(128, 35, requires_grad=True) 54 | >>> label = torch.Tensor([-1, 1, 0, 0, 0, 0, 0] * 5) 55 | >>> output = contrastive_loss(input, label) 56 | >>> output.backward() 57 | """ 58 | 59 | def __init__(self, margin=0.7, eps=1e-6): 60 | super(ContrastiveDistLoss, self).__init__() 61 | self.margin = margin 62 | self.eps = eps 63 | 64 | def forward(self, x, label, dist): 65 | return LF.contrastive_loss_dist(x, label, dist, eps=self.eps) 66 | 67 | def __repr__(self): 68 | return self.__class__.__name__ + '(' + 'margin=' + '{:.4f}'.format(self.margin) + ')' 69 | 70 | 71 | class TripletLoss(nn.Module): 72 | 73 | def __init__(self, margin=0.1): 74 | super(TripletLoss, self).__init__() 75 | self.margin = margin 76 | 77 | def forward(self, x, label): 78 | return LF.triplet_loss(x, label, margin=self.margin) 79 | 80 | def __repr__(self): 81 | return self.__class__.__name__ + '(' + 'margin=' + '{:.4f}'.format(self.margin) + ')' 82 | 83 | 84 | class CrossEntropyLoss(nn.Module): 85 | 86 | def __init__(self, temp=1, eps=1e-6): 87 | super(CrossEntropyLoss, self).__init__() 88 | self.temp = temp 89 | self.eps = eps 90 | 91 | 92 | def forward(self, x, label): 93 | return LF.cross_entropy_loss(x, label, temp=self.temp, eps=self.eps) 94 | 95 | def __repr__(self): 96 | return self.__class__.__name__ + '(' + 'temp=' + '{:.2f}'.format(self.temp) + ')' 97 | 98 | 99 | 100 | class CrossEntropyDistLoss(nn.Module): 101 | 102 | def __init__(self, temp=1, eps=1e-6): 103 | super(CrossEntropyDistLoss, self).__init__() 104 | self.temp = temp 105 | self.eps = eps 106 | 107 | def forward(self, x, label): 108 | return LF.cross_entropy_loss_dist(x, label, temp=self.temp, eps=self.eps) 109 | 110 | def __repr__(self): 111 | return self.__class__.__name__ + '(' + 'temp=' + '{:.2f}'.format(self.temp) + ')' 112 | 113 | 114 | 115 | class MultiSimilarityLoss(nn.Module): 116 | def __init__(self): 117 | super(MultiSimilarityLoss, self).__init__() 118 | 
        self.thresh = 0.5
119 |         self.margin = 0.1
120 | 
121 |         self.scale_pos = 2.0
122 |         self.scale_neg = 40.0
123 | 
124 |     def forward(self, feats, labels):
125 | 
126 |         assert feats.size(1) == labels.size(0), \
127 |             f"feats.size(1): {feats.size(1)} is not equal to labels.size(0): {labels.size(0)}"
128 |         batch_size = feats.size(1)
129 |         sim_mat = torch.matmul(torch.t(feats), feats)
130 | 
131 |         epsilon = 1e-5
132 | 
133 |         pos_pair = sim_mat[0][labels == 1]
134 | 
135 |         neg_pair = sim_mat[0][labels == 0]
136 | 
137 | 
138 |         loss = 0
139 |         if len(neg_pair) >= 1 or len(pos_pair) >= 1:
140 |             # weighting step
141 |             pos_loss = 1.0 / self.scale_pos * torch.log(
142 |                 1 + torch.sum(torch.exp(-self.scale_pos * (pos_pair - self.thresh))))
143 |             neg_loss = 1.0 / self.scale_neg * torch.log(
144 |                 1 + torch.sum(torch.exp(self.scale_neg * (neg_pair - self.thresh))))
145 | 
146 |             loss = pos_loss + neg_loss
147 |         if loss == 0:
148 |             return torch.zeros([], requires_grad=True)
149 | 
150 |         return loss
151 | 
152 | class HardDarkRank(nn.Module):
153 |     def __init__(self, alpha=3, beta=3, permute_len=4):
154 |         super().__init__()
155 |         self.alpha = alpha
156 |         self.beta = beta
157 |         self.permute_len = permute_len
158 | 
159 |     def forward(self, student, teacher):
160 |         score_teacher = -1 * self.alpha * LF.pdist(teacher, squared=False).pow(self.beta)
161 |         score_student = -1 * self.alpha * LF.pdist(student, squared=False).pow(self.beta)
162 | 
163 |         permute_idx = score_teacher.sort(dim=1, descending=True)[1][:, 1:(self.permute_len+1)]
164 |         ordered_student = torch.gather(score_student, 1, permute_idx)
165 | 
166 |         log_prob = (ordered_student - torch.stack([torch.logsumexp(ordered_student[:, i:], dim=1) for i in range(permute_idx.size(1))], dim=1)).sum(dim=1)
167 |         loss = (-1 * log_prob).mean()
168 | 
169 |         return loss
170 | 
171 | class RKdAD(nn.Module):
172 |     def __init__(self):
173 |         super(RKdAD, self).__init__()
174 | 
175 |     def pdist(self, e, squared=False, eps=1e-12):
176 |         e_square = e.pow(2).sum(dim=1)
177 |         prod = e @ e.t()
178 |         res = (e_square.unsqueeze(1) + e_square.unsqueeze(0) - 2 * prod).clamp(min=eps)
179 | 
180 |         if not squared:
181 |             res = res.sqrt()
182 | 
183 |         res = res.clone()
184 |         res[range(len(e)), range(len(e))] = 0
185 |         return res
186 | 
187 |     def forward(self, student, teacher):
188 |         # student, teacher: N x C descriptor batches;
189 |         # pairwise differences below are N x N x C
190 | 
191 |         with torch.no_grad():
192 |             td = (teacher.unsqueeze(0) - teacher.unsqueeze(1))
193 |             norm_td = F.normalize(td, p=2, dim=2)
194 |             t_angle = torch.bmm(norm_td, norm_td.transpose(1, 2)).view(-1)
195 | 
196 |             t_d = self.pdist(teacher, squared=False)
197 |             mean_td = t_d[t_d > 0].mean()
198 |             t_d = t_d / mean_td
199 | 
200 |         sd = (student.unsqueeze(0) - student.unsqueeze(1))
201 |         norm_sd = F.normalize(sd, p=2, dim=2)
202 |         s_angle = torch.bmm(norm_sd, norm_sd.transpose(1, 2)).view(-1)
203 | 
204 |         d = self.pdist(student, squared=False)
205 |         mean_d = d[d > 0].mean()
206 |         d = d / mean_d
207 | 
208 |         loss_a = F.smooth_l1_loss(s_angle, t_angle, reduction='mean')
209 |         loss_d = F.smooth_l1_loss(d, t_d, reduction='mean')
210 | 
211 | 
212 |         return 2*loss_a + loss_d
213 | 
214 | 
--------------------------------------------------------------------------------
/training.py:
--------------------------------------------------------------------------------
1 | import time
2 | import torch
3 | 
4 | 
5 | def train(train_loader, model, criterion, optimizer, epoch, log, args):
6 |     batch_time = AverageMeter()
7 |     data_time = AverageMeter()
8 |     losses = AverageMeter()
9 | 
10 |     # create tuples for
training 11 | avg_neg_distance = train_loader.dataset.create_epoch_tuples(model) 12 | 13 | # switch to train mode 14 | model.train() 15 | model.apply(set_batchnorm_eval) 16 | 17 | # zero out gradients 18 | optimizer.zero_grad() 19 | 20 | end = time.time() 21 | for i, (input, target) in enumerate(train_loader): 22 | # measure data loading time 23 | data_time.update(time.time() - end) 24 | 25 | nq = len(input) # number of training tuples 26 | ni = len(input[0]) # number of images per tuple 27 | if args.mode == 'rand': 28 | outputs = torch.zeros(nq, target[0].shape[0]).cuda() 29 | for q in range(nq): 30 | ni = len(input[q]) 31 | output = torch.zeros(model.meta['outputdim'], ni).cuda() 32 | if args.mode in ['ts', 'ts_self', 'ts_rand', 'reg', 'reg_only_pos', 'rand_tpl_a']: 33 | if args.sym == True: 34 | for imi in range(ni): 35 | output[:, imi] = model(input[q][imi].cuda()).squeeze() 36 | else: 37 | for imi in range(ni): 38 | if imi == 0: 39 | output[:, imi] = model(input[q][imi].cuda()).squeeze() 40 | else: 41 | output[:, imi] = torch.tensor(input[q][imi]).float().cuda() 42 | 43 | elif args.mode in ['std', 'rand_tpl']: 44 | for imi in range(ni): 45 | output[:, imi] = model(input[q][imi].cuda()).squeeze() 46 | else: 47 | for imi in range(ni): 48 | output[:, imi] = model(input[q][imi].cuda()).squeeze() 49 | outputs[q,:] = output.squeeze() 50 | if args.mode != 'rand': 51 | loss = criterion(output, target[q].t().cuda()) 52 | losses.update(loss.item()) 53 | loss.backward() 54 | if args.mode == 'rand': 55 | targets = torch.stack(target).cuda() 56 | loss = criterion(outputs, targets) 57 | losses.update(loss.item()) 58 | loss.backward() 59 | 60 | if (i + 1) % args.update_every == 0: 61 | # do one step for multiple batches 62 | # accumulated gradients are used 63 | optimizer.step() 64 | # zero out gradients so we can 65 | # accumulate new ones over batches 66 | optimizer.zero_grad() 67 | # print('>> Train: [{0}][{1}/{2}]\t' 68 | # 'Weight update performed'.format( 69 | # epoch+1, i+1, len(train_loader))) 70 | 71 | # measure elapsed time 72 | batch_time.update(time.time() - end) 73 | end = time.time() 74 | 75 | if (i+1) % args.print_freq == 0 or i == 0 or (i+1) == len(train_loader): 76 | out = '>> Train: [{0}][{1}/{2}]\t Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t \ 77 | Data {data_time.val:.3f} ({data_time.avg:.3f})\t \ 78 | Loss {loss.val:.4f} ({loss.avg:.4f})'.format( 79 | epoch+1, i+1, len(train_loader), batch_time=batch_time, 80 | data_time=data_time, loss=losses) 81 | print(out) 82 | log.write(out+'\n') 83 | 84 | return losses.avg 85 | 86 | 87 | def validate(val_loader, model, criterion, epoch, args): 88 | batch_time = AverageMeter() 89 | losses = AverageMeter() 90 | 91 | # create tuples for validation 92 | avg_neg_distance = val_loader.dataset.create_epoch_tuples(model) 93 | 94 | # switch to evaluate mode 95 | model.eval() 96 | 97 | end = time.time() 98 | for i, (input, target) in enumerate(val_loader): 99 | 100 | nq = len(input) # number of training tuples 101 | ni = len(input[0]) # number of images per tuple 102 | output = torch.zeros(model.meta['outputdim'], nq*ni).cuda() 103 | if args.mode == 'rand': 104 | outputs = torch.zeros(nq, target[0].shape[0]).cuda() 105 | for q in range(nq): 106 | if args.mode in ['ts', 'reg', 'reg_only_pos', 'ts_self', 'ts_rand', 'rand_tpl_a']: 107 | if args.sym == True: 108 | for imi in range(ni): 109 | output[:, q*ni + imi] = model(input[q][imi].cuda()).squeeze() 110 | else: 111 | for imi in range(ni): 112 | if imi == 0: 113 | output[:, q*ni + imi] = 
model(input[q][imi].cuda()).squeeze() 114 | else: 115 | output[:, q*ni + imi] = torch.tensor(input[q][imi]).float().cuda() 116 | elif args.mode == 'rand': 117 | for imi in range(ni): 118 | output[:, imi] = model(input[q][imi].cuda()).squeeze() 119 | else: 120 | for imi in range(ni): 121 | # compute output vector for image imi of query q 122 | output[:, q*ni + imi] = model(input[q][imi].cuda()).squeeze() 123 | 124 | # no need to reduce memory consumption (no backward pass): 125 | # compute loss for the full batch 126 | 127 | if args.mode == 'rand': 128 | targets = torch.stack(target).cuda().t() 129 | loss = criterion(output.t(), targets.t()) 130 | else: 131 | if args.sym: 132 | loss = criterion(output, torch.cat(target).cuda().t()) 133 | else: 134 | loss = criterion(output, torch.cat(target).cuda()) 135 | 136 | # record loss 137 | losses.update(loss.item()/nq, nq) 138 | 139 | # measure elapsed time 140 | batch_time.update(time.time() - end) 141 | end = time.time() 142 | 143 | if (i+1) % args.print_freq == 0 or i == 0 or (i+1) == len(val_loader): 144 | print('>> Val: [{0}][{1}/{2}]\t' 145 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 146 | 'Loss {loss.val:.4f} ({loss.avg:.4f})'.format( 147 | epoch+1, i+1, len(val_loader), batch_time=batch_time, loss=losses)) 148 | 149 | return losses.avg 150 | 151 | class AverageMeter(object): 152 | """Computes and stores the average and current value""" 153 | def __init__(self): 154 | self.reset() 155 | 156 | def reset(self): 157 | self.val = 0 158 | self.avg = 0 159 | self.sum = 0 160 | self.count = 0 161 | 162 | def update(self, val, n=1): 163 | self.val = val 164 | self.sum += val * n 165 | self.count += n 166 | self.avg = self.sum / self.count 167 | 168 | def set_batchnorm_eval(m): 169 | classname = m.__class__.__name__ 170 | if classname.find('BatchNorm') != -1: 171 | # freeze running mean and std: 172 | # we do training one image at a time 173 | # so the statistics would not be per batch 174 | # hence we choose freezing (ie using imagenet statistics) 175 | m.eval() 176 | # # freeze parameters: 177 | # # in fact no need to freeze scale and bias 178 | # # they can be learned 179 | # # that is why next two lines are commented 180 | # for p in m.parameters(): 181 | # p.requires_grad = False 182 | -------------------------------------------------------------------------------- /lib/utils/download.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | def download_test(data_dir): 4 | """ 5 | DOWNLOAD_TEST Checks, and, if required, downloads the necessary datasets for the testing. 6 | 7 | download_test(DATA_ROOT) checks if the data necessary for running the example script exist. 
8 | If not it downloads it in the folder structure: 9 | DATA_ROOT/test/oxford5k/ : folder with Oxford images and ground truth file 10 | DATA_ROOT/test/paris6k/ : folder with Paris images and ground truth file 11 | DATA_ROOT/test/roxford5k/ : folder with Oxford images and revisited ground truth file 12 | DATA_ROOT/test/rparis6k/ : folder with Paris images and revisited ground truth file 13 | """ 14 | 15 | # Create data folder if it does not exist 16 | if not os.path.isdir(data_dir): 17 | os.mkdir(data_dir) 18 | 19 | # Create datasets folder if it does not exist 20 | datasets_dir = os.path.join(data_dir, 'test') 21 | if not os.path.isdir(datasets_dir): 22 | os.mkdir(datasets_dir) 23 | 24 | # Download datasets folders test/DATASETNAME/ 25 | datasets = ['oxford5k', 'paris6k', 'roxford5k', 'rparis6k'] 26 | for di in range(len(datasets)): 27 | dataset = datasets[di] 28 | 29 | if dataset == 'oxford5k': 30 | src_dir = 'http://www.robots.ox.ac.uk/~vgg/data/oxbuildings' 31 | dl_files = ['oxbuild_images.tgz'] 32 | elif dataset == 'paris6k': 33 | src_dir = 'http://www.robots.ox.ac.uk/~vgg/data/parisbuildings' 34 | dl_files = ['paris_1.tgz', 'paris_2.tgz'] 35 | elif dataset == 'roxford5k': 36 | src_dir = 'http://www.robots.ox.ac.uk/~vgg/data/oxbuildings' 37 | dl_files = ['oxbuild_images.tgz'] 38 | elif dataset == 'rparis6k': 39 | src_dir = 'http://www.robots.ox.ac.uk/~vgg/data/parisbuildings' 40 | dl_files = ['paris_1.tgz', 'paris_2.tgz'] 41 | else: 42 | raise ValueError('Unknown dataset: {}!'.format(dataset)) 43 | 44 | dst_dir = os.path.join(datasets_dir, dataset, 'jpg') 45 | if not os.path.isdir(dst_dir): 46 | 47 | # for oxford and paris download images 48 | if dataset == 'oxford5k' or dataset == 'paris6k': 49 | print('>> Dataset {} directory does not exist. Creating: {}'.format(dataset, dst_dir)) 50 | os.makedirs(dst_dir) 51 | for dli in range(len(dl_files)): 52 | dl_file = dl_files[dli] 53 | src_file = os.path.join(src_dir, dl_file) 54 | dst_file = os.path.join(dst_dir, dl_file) 55 | print('>> Downloading dataset {} archive {}...'.format(dataset, dl_file)) 56 | os.system('wget {} -O {}'.format(src_file, dst_file)) 57 | print('>> Extracting dataset {} archive {}...'.format(dataset, dl_file)) 58 | # create tmp folder 59 | dst_dir_tmp = os.path.join(dst_dir, 'tmp') 60 | os.system('mkdir {}'.format(dst_dir_tmp)) 61 | # extract in tmp folder 62 | os.system('tar -zxf {} -C {}'.format(dst_file, dst_dir_tmp)) 63 | # remove all (possible) subfolders by moving only files in dst_dir 64 | os.system('find {} -type f -exec mv -i {{}} {} \\;'.format(dst_dir_tmp, dst_dir)) 65 | # remove tmp folder 66 | os.system('rm -rf {}'.format(dst_dir_tmp)) 67 | print('>> Extracted, deleting dataset {} archive {}...'.format(dataset, dl_file)) 68 | os.system('rm {}'.format(dst_file)) 69 | 70 | # for roxford and rparis just make sym links 71 | elif dataset == 'roxford5k' or dataset == 'rparis6k': 72 | print('>> Dataset {} directory does not exist. 
Creating: {}'.format(dataset, dst_dir)) 73 | dataset_old = dataset[1:] 74 | dst_dir_old = os.path.join(datasets_dir, dataset_old, 'jpg') 75 | os.mkdir(os.path.join(datasets_dir, dataset)) 76 | os.system('ln -s {} {}'.format(dst_dir_old, dst_dir)) 77 | print('>> Created symbolic link from {} jpg to {} jpg'.format(dataset_old, dataset)) 78 | 79 | 80 | gnd_src_dir = os.path.join('http://cmp.felk.cvut.cz/cnnimageretrieval/data', 'test', dataset) 81 | gnd_dst_dir = os.path.join(datasets_dir, dataset) 82 | gnd_dl_file = 'gnd_{}.pkl'.format(dataset) 83 | gnd_src_file = os.path.join(gnd_src_dir, gnd_dl_file) 84 | gnd_dst_file = os.path.join(gnd_dst_dir, gnd_dl_file) 85 | if not os.path.exists(gnd_dst_file): 86 | print('>> Downloading dataset {} ground truth file...'.format(dataset)) 87 | os.system('wget {} -O {}'.format(gnd_src_file, gnd_dst_file)) 88 | 89 | 90 | def download_train(data_dir): 91 | """ 92 | DOWNLOAD_TRAIN Checks, and, if required, downloads the necessary datasets for the training. 93 | 94 | download_train(DATA_ROOT) checks if the data necessary for running the example script exist. 95 | If not it downloads it in the folder structure: 96 | DATA_ROOT/train/retrieval-SfM-120k/ : folder with rsfm120k images and db files 97 | DATA_ROOT/train/retrieval-SfM-30k/ : folder with rsfm30k images and db files 98 | """ 99 | 100 | # Create data folder if it does not exist 101 | if not os.path.isdir(data_dir): 102 | os.mkdir(data_dir) 103 | 104 | # Create datasets folder if it does not exist 105 | datasets_dir = os.path.join(data_dir, 'train') 106 | if not os.path.isdir(datasets_dir): 107 | os.mkdir(datasets_dir) 108 | 109 | # Download folder train/retrieval-SfM-120k/ 110 | src_dir = os.path.join('http://cmp.felk.cvut.cz/cnnimageretrieval/data', 'train', 'ims') 111 | dst_dir = os.path.join(datasets_dir, 'retrieval-SfM-120k', 'ims') 112 | dl_file = 'ims.tar.gz' 113 | if not os.path.isdir(dst_dir): 114 | src_file = os.path.join(src_dir, dl_file) 115 | dst_file = os.path.join(dst_dir, dl_file) 116 | print('>> Image directory does not exist. Creating: {}'.format(dst_dir)) 117 | os.makedirs(dst_dir) 118 | print('>> Downloading ims.tar.gz...') 119 | os.system('wget {} -O {}'.format(src_file, dst_file)) 120 | print('>> Extracting {}...'.format(dst_file)) 121 | os.system('tar -zxf {} -C {}'.format(dst_file, dst_dir)) 122 | print('>> Extracted, deleting {}...'.format(dst_file)) 123 | os.system('rm {}'.format(dst_file)) 124 | 125 | # Create symlink for train/retrieval-SfM-30k/ 126 | dst_dir_old = os.path.join(datasets_dir, 'retrieval-SfM-120k', 'ims') 127 | dst_dir = os.path.join(datasets_dir, 'retrieval-SfM-30k', 'ims') 128 | if not os.path.isdir(dst_dir): 129 | os.makedirs(os.path.join(datasets_dir, 'retrieval-SfM-30k')) 130 | os.system('ln -s {} {}'.format(dst_dir_old, dst_dir)) 131 | print('>> Created symbolic link from retrieval-SfM-120k/ims to retrieval-SfM-30k/ims') 132 | 133 | # Download db files 134 | src_dir = os.path.join('http://cmp.felk.cvut.cz/cnnimageretrieval/data', 'train', 'dbs') 135 | datasets = ['retrieval-SfM-120k', 'retrieval-SfM-30k'] 136 | for dataset in datasets: 137 | dst_dir = os.path.join(datasets_dir, dataset) 138 | if dataset == 'retrieval-SfM-120k': 139 | dl_files = ['{}.pkl'.format(dataset), '{}-whiten.pkl'.format(dataset)] 140 | elif dataset == 'retrieval-SfM-30k': 141 | dl_files = ['{}-whiten.pkl'.format(dataset)] 142 | 143 | if not os.path.isdir(dst_dir): 144 | print('>> Dataset directory does not exist. 
Creating: {}'.format(dst_dir)) 145 | os.mkdir(dst_dir) 146 | 147 | for i in range(len(dl_files)): 148 | src_file = os.path.join(src_dir, dl_files[i]) 149 | dst_file = os.path.join(dst_dir, dl_files[i]) 150 | if not os.path.isfile(dst_file): 151 | print('>> DB file {} does not exist. Downloading...'.format(dl_files[i])) 152 | os.system('wget {} -O {}'.format(src_file, dst_file)) 153 | -------------------------------------------------------------------------------- /lib/utils/download_win.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | def download_test(data_dir): 4 | """ 5 | DOWNLOAD_TEST Checks, and, if required, downloads the necessary datasets for the testing. 6 | 7 | download_test(DATA_ROOT) checks if the data necessary for running the example script exist. 8 | If not it downloads it in the folder structure: 9 | DATA_ROOT/test/oxford5k/ : folder with Oxford images and ground truth file 10 | DATA_ROOT/test/paris6k/ : folder with Paris images and ground truth file 11 | DATA_ROOT/test/roxford5k/ : folder with Oxford images and revisited ground truth file 12 | DATA_ROOT/test/rparis6k/ : folder with Paris images and revisited ground truth file 13 | """ 14 | 15 | # Create data folder if it does not exist 16 | if not os.path.isdir(data_dir): 17 | os.mkdir(data_dir) 18 | 19 | # Create datasets folder if it does not exist 20 | datasets_dir = os.path.join(data_dir, 'test') 21 | if not os.path.isdir(datasets_dir): 22 | os.mkdir(datasets_dir) 23 | 24 | # Download datasets folders test/DATASETNAME/ 25 | datasets = ['oxford5k', 'paris6k', 'roxford5k', 'rparis6k'] 26 | for di in range(len(datasets)): 27 | dataset = datasets[di] 28 | 29 | if dataset == 'oxford5k': 30 | src_dir = 'http://www.robots.ox.ac.uk/~vgg/data/oxbuildings' 31 | dl_files = ['oxbuild_images.tgz'] 32 | elif dataset == 'paris6k': 33 | src_dir = 'http://www.robots.ox.ac.uk/~vgg/data/parisbuildings' 34 | dl_files = ['paris_1.tgz', 'paris_2.tgz'] 35 | elif dataset == 'roxford5k': 36 | src_dir = 'http://www.robots.ox.ac.uk/~vgg/data/oxbuildings' 37 | dl_files = ['oxbuild_images.tgz'] 38 | elif dataset == 'rparis6k': 39 | src_dir = 'http://www.robots.ox.ac.uk/~vgg/data/parisbuildings' 40 | dl_files = ['paris_1.tgz', 'paris_2.tgz'] 41 | else: 42 | raise ValueError('Unknown dataset: {}!'.format(dataset)) 43 | 44 | dst_dir = os.path.join(datasets_dir, dataset, 'jpg') 45 | if not os.path.isdir(dst_dir): 46 | 47 | # for oxford and paris download images 48 | if dataset == 'oxford5k' or dataset == 'paris6k': 49 | print('>> Dataset {} directory does not exist. 
Creating: {}'.format(dataset, dst_dir))
50 |                 os.makedirs(dst_dir)
51 |                 for dli in range(len(dl_files)):
52 |                     dl_file = dl_files[dli]
53 |                     src_file = os.path.join(src_dir, dl_file)
54 |                     dst_file = os.path.join(dst_dir, dl_file)
55 |                     print('>> Downloading dataset {} archive {}...'.format(dataset, dl_file))
56 |                     os.system('wget {} -O {}'.format(src_file, dst_file))
57 |                     print('>> Extracting dataset {} archive {}...'.format(dataset, dl_file))
58 |                     # create tmp folder
59 |                     dst_dir_tmp = os.path.join(dst_dir, 'tmp')
60 |                     os.system('mkdir {}'.format(dst_dir_tmp))
61 |                     # extract in tmp folder
62 |                     os.system('tar -zxf {} -C {}'.format(dst_file, dst_dir_tmp))
63 |                     # remove all (possible) subfolders by moving only files to dst_dir ('find' is not available on Windows)
64 |                     for root, _, files in os.walk(dst_dir_tmp):
65 |                         for fn in files: os.replace(os.path.join(root, fn), os.path.join(dst_dir, fn))
66 |                     # remove tmp folder ('/s /q' deletes a non-empty folder without prompting)
67 |                     os.system('rd /s /q {}'.format(dst_dir_tmp))
68 |                     print('>> Extracted, deleting dataset {} archive {}...'.format(dataset, dl_file))
69 |                     os.system('del {}'.format(dst_file))
70 | 
71 |             # for roxford and rparis just make sym links
72 |             elif dataset == 'roxford5k' or dataset == 'rparis6k':
73 |                 print('>> Dataset {} directory does not exist. Creating: {}'.format(dataset, dst_dir))
74 |                 dataset_old = dataset[1:]
75 |                 dst_dir_old = os.path.join(datasets_dir, dataset_old, 'jpg')
76 |                 os.mkdir(os.path.join(datasets_dir, dataset))
77 |                 os.system('cmd /c mklink /d {} {}'.format(dst_dir, dst_dir_old))
78 |                 print('>> Created symbolic link from {} jpg to {} jpg'.format(dataset_old, dataset))
79 | 
80 | 
81 |         gnd_src_dir = os.path.join('http://cmp.felk.cvut.cz/cnnimageretrieval/data', 'test', dataset)
82 |         gnd_dst_dir = os.path.join(datasets_dir, dataset)
83 |         gnd_dl_file = 'gnd_{}.pkl'.format(dataset)
84 |         gnd_src_file = os.path.join(gnd_src_dir, gnd_dl_file)
85 |         gnd_dst_file = os.path.join(gnd_dst_dir, gnd_dl_file)
86 |         if not os.path.exists(gnd_dst_file):
87 |             print('>> Downloading dataset {} ground truth file...'.format(dataset))
88 |             os.system('wget {} -O {}'.format(gnd_src_file, gnd_dst_file))
89 | 
90 | 
91 | def download_train(data_dir):
92 |     """
93 |     DOWNLOAD_TRAIN Checks, and, if required, downloads the necessary datasets for the training.
94 | 
95 |     download_train(DATA_ROOT) checks if the data necessary for running the example script exist.
96 |     If not, it downloads it into the folder structure:
97 |         DATA_ROOT/train/retrieval-SfM-120k/ : folder with rsfm120k images and db files
98 |         DATA_ROOT/train/retrieval-SfM-30k/  : folder with rsfm30k images and db files
99 |     """
100 | 
101 |     # Create data folder if it does not exist
102 |     if not os.path.isdir(data_dir):
103 |         os.mkdir(data_dir)
104 |     # Create datasets folder if it does not exist
105 |     datasets_dir = os.path.join(data_dir, 'train')
106 |     if not os.path.isdir(datasets_dir):
107 |         os.mkdir(datasets_dir)
108 | 
109 |     # Download folder train/retrieval-SfM-120k/
110 |     src_dir = os.path.join('http://cmp.felk.cvut.cz/cnnimageretrieval/data', 'train', 'ims')
111 |     dst_dir = os.path.join(datasets_dir, 'retrieval-SfM-120k', 'ims')
112 |     dl_file = 'ims.tar.gz'
113 |     if not os.path.isdir(dst_dir):
114 |         src_file = os.path.join(src_dir, dl_file)
115 |         dst_file = os.path.join(dst_dir, dl_file)
116 |         print('>> Image directory does not exist.
Creating: {}'.format(dst_dir))
117 |         os.makedirs(dst_dir)
118 |         print('>> Downloading ims.tar.gz...')
119 |         os.system('wget {} -O {}'.format(src_file, dst_file))
120 |         print('>> Extracting {}...'.format(dst_file))
121 |         os.system('tar -zxf {} -C {}'.format(dst_file, dst_dir))
122 |         print('>> Extracted, deleting {}...'.format(dst_file))
123 |         os.system('del {}'.format(dst_file))
124 | 
125 |     # Create symlink for train/retrieval-SfM-30k/
126 |     dst_dir_old = os.path.join(datasets_dir, 'retrieval-SfM-120k', 'ims')
127 |     dst_dir = os.path.join(datasets_dir, 'retrieval-SfM-30k', 'ims')
128 |     if not os.path.isdir(dst_dir):
129 |         os.makedirs(os.path.join(datasets_dir, 'retrieval-SfM-30k'))
130 |         os.system('cmd /c mklink /d {} {}'.format(dst_dir, dst_dir_old))
131 |         print('>> Created symbolic link from retrieval-SfM-120k/ims to retrieval-SfM-30k/ims')
132 | 
133 |     # Download db files
134 |     src_dir = os.path.join('http://cmp.felk.cvut.cz/cnnimageretrieval/data', 'train', 'dbs')
135 |     datasets = ['retrieval-SfM-120k', 'retrieval-SfM-30k']
136 |     for dataset in datasets:
137 |         dst_dir = os.path.join(datasets_dir, dataset)
138 |         if dataset == 'retrieval-SfM-120k':
139 |             dl_files = ['{}.pkl'.format(dataset), '{}-whiten.pkl'.format(dataset)]
140 |         elif dataset == 'retrieval-SfM-30k':
141 |             dl_files = ['{}-whiten.pkl'.format(dataset)]
142 | 
143 |         if not os.path.isdir(dst_dir):
144 |             print('>> Dataset directory does not exist. Creating: {}'.format(dst_dir))
145 |             os.mkdir(dst_dir)
146 | 
147 |         for i in range(len(dl_files)):
148 |             src_file = os.path.join(src_dir, dl_files[i])
149 |             dst_file = os.path.join(dst_dir, dl_files[i])
150 |             if not os.path.isfile(dst_file):
151 |                 print('>> DB file {} does not exist. Downloading...'.format(dl_files[i]))
152 |                 os.system('wget {} -O {}'.format(src_file, dst_file))
153 | 
--------------------------------------------------------------------------------
/lib/cli.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | 
3 | import torchvision.models as models
4 | 
5 | # Choice lists used by the parser below; they mirror the definitions in main.py
6 | # so that this module can be imported on its own.
7 | training_dataset_names = ['retrieval-SfM-120k']
8 | test_datasets_names = ['oxford5k', 'paris6k', 'roxford5k', 'rparis6k']
9 | test_whiten_names = ['retrieval-SfM-30k', 'retrieval-SfM-120k']
10 | model_names = sorted(name for name in models.__dict__
11 |                      if name.islower() and not name.startswith("__")
12 |                      and callable(models.__dict__[name])) + ['mobilenet_v3', 'efficientnet_b3', 'efficientnet_b3_new']
13 | pool_names = ['mac', 'spoc', 'gem', 'gemmp']
14 | loss_names = ['contrastive', 'triplet', 'contrastive_dist', 'cross_entropy', 'cross_entropy_dist', 'multi', 'rkd']
15 | mode_names = ['ts', 'ts_self', 'reg', 'reg_only_pos', 'std', 'rand', 'rand_tpl', 'rand_tpl_a', 'ts_rand']
16 | teacher_names = ['vgg16', 'resnet101']
17 | optimizer_names = ['sgd', 'adam']
18 | 
19 | def create_parser():
20 |     parser = argparse.ArgumentParser(description='Asymmetric metric learning')
21 | 
22 |     # export directory, training and val datasets, test datasets
23 |     parser.add_argument('--directory', metavar='EXPORT_DIR',
24 |                         help='destination where trained network should be saved')
25 |     parser.add_argument('--training-dataset', '-d', metavar='DATASET', default='retrieval-SfM-120k', choices=training_dataset_names,
26 |                         help='training dataset: ' +
27 |                             ' | '.join(training_dataset_names) +
28 |                             ' (default: retrieval-SfM-120k)')
29 |     parser.add_argument('--no-val', dest='val', action='store_false',
30 |                         help='do not run validation')
31 |     parser.add_argument('--test-datasets', '-td', metavar='DATASETS', default='roxford5k,rparis6k',
32 |                         help='comma separated list of test datasets: ' +
33 |                             ' | '.join(test_datasets_names) +
34 |                             ' (default: roxford5k,rparis6k)')
35 |     parser.add_argument('--test-whiten', metavar='DATASET', default='', choices=test_whiten_names,
36 |                         help='dataset used to learn whitening for testing: ' +
37 |                             ' | '.join(test_whiten_names) +
38 |                             ' (default: None)')
39 |     parser.add_argument('--val-freq', default=1, type=int, metavar='N',
40 |                         help='run val evaluation every N epochs (default: 1)')
41 |     parser.add_argument('--save-freq', default=1, type=int, metavar='N',
42 |                         help='save model every N epochs (default: 1)')
43 | 
44 |     # network architecture and initialization options
45 |     parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet101', choices=model_names,
46 |                         help='model
architecture: ' +
47 |                             ' | '.join(model_names) +
48 |                             ' (default: resnet101)')
49 |     parser.add_argument('--pool', '-p', metavar='POOL', default='gem', choices=pool_names,
50 |                         help='pooling options: ' +
51 |                             ' | '.join(pool_names) +
52 |                             ' (default: gem)')
53 |     parser.add_argument('--local-whitening', '-lw', dest='local_whitening', action='store_true',
54 |                         help='train model with learnable local whitening (linear layer) before the pooling')
55 |     parser.add_argument('--regional', '-r', dest='regional', action='store_true',
56 |                         help='train model with regional pooling using fixed grid')
57 |     parser.add_argument('--whitening', '-w', dest='whitening', action='store_true',
58 |                         help='train model with learnable whitening (linear layer) after the pooling')
59 |     parser.add_argument('--not-pretrained', dest='pretrained', action='store_false',
60 |                         help='initialize model with random weights (default: pretrained on imagenet)')
61 |     parser.add_argument('--loss', '-l', metavar='LOSS', default='contrastive',
62 |                         choices=loss_names,
63 |                         help='training loss options: ' +
64 |                             ' | '.join(loss_names) +
65 |                             ' (default: contrastive)')
66 |     parser.add_argument('--mode', '-m', metavar='MODE', default='std',
67 |                         choices=mode_names,
68 |                         help='training mode options: ' +
69 |                             ' | '.join(mode_names) +
70 |                             ' (default: std)')
71 |     parser.add_argument('--loss-margin', '-lm', metavar='LM', default=0.7, type=float,
72 |                         help='loss margin: (default: 0.7)')
73 | 
74 |     # train/val options specific for image retrieval learning
75 |     parser.add_argument('--image-size', default=1024, type=int, metavar='N',
76 |                         help='maximum size of longer image side used for training (default: 1024)')
77 |     parser.add_argument('--neg-num', '-nn', default=5, type=int, metavar='N',
78 |                         help='number of negative image per train/val tuple (default: 5)')
79 |     parser.add_argument('--query-size', '-qs', default=2000, type=int, metavar='N',
80 |                         help='number of queries randomly drawn per one train epoch (default: 2000)')
81 |     parser.add_argument('--pool-size', '-ps', default=20000, type=int, metavar='N',
82 |                         help='size of the pool for hard negative mining (default: 20000)')
83 | 
84 |     # standard train/val options
85 |     parser.add_argument('--gpu-id', '-g', default='0', metavar='N',
86 |                         help='gpu id used for training (default: 0)')
87 |     parser.add_argument('--workers', '-j', default=8, type=int, metavar='N',
88 |                         help='number of data loading workers (default: 8)')
89 |     parser.add_argument('--epochs', default=100, type=int, metavar='N',
90 |                         help='number of total epochs to run (default: 100)')
91 |     parser.add_argument('--batch-size', '-b', default=5, type=int, metavar='N',
92 |                         help='number of (q,p,n1,...,nN) tuples in a mini-batch (default: 5)')
93 |     parser.add_argument('--update-every', '-u', default=1, type=int, metavar='N',
94 |                         help='update model weights every N batches, used to handle really large batches, ' +
95 |                             'batch_size effectively becomes update_every x batch_size (default: 1)')
96 |     parser.add_argument('--optimizer', '-o', metavar='OPTIMIZER', default='adam',
97 |                         choices=optimizer_names,
98 |                         help='optimizer options: ' +
99 |                             ' | '.join(optimizer_names) +
100 |                             ' (default: adam)')
101 |     parser.add_argument('--lr', '--learning-rate', default=1e-6, type=float,
102 |                         metavar='LR', help='initial learning rate (default: 1e-6)')
103 |     parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
104 |                         help='momentum (default: 0.9)')
105 |     parser.add_argument('--weight-decay', '--wd', default=1e-6, type=float,
106 |                         metavar='W', help='weight decay
(default: 1e-6)')
107 |     parser.add_argument('--print-freq', default=10, type=int,
108 |                         metavar='N', help='print frequency (default: 10)')
109 |     parser.add_argument('--resume', default='', type=str, metavar='FILENAME',
110 |                         help='name of the latest checkpoint (default: None)')
111 |     parser.add_argument('--comment', '-c', default='', type=str, metavar='COMMENT',
112 |                         help='additional experiment comment')
113 |     parser.add_argument('--temp', default=0.1, type=float, metavar='TEMP',
114 |                         help='temperature for the softmax loss function')
115 |     parser.add_argument('--nexamples', default=1000, type=int, metavar='N', # Probably don't need !!!
116 |                         help='number of negative examples for AP or cross (default: 1000)')
117 |     parser.add_argument('--teacher', '-t', metavar='TEACHER', default='vgg16',
118 |                         choices=teacher_names,
119 |                         help='teacher network options: ' +
120 |                             ' | '.join(teacher_names) +
121 |                             ' (default: vgg16)')
122 |     parser.add_argument('--sym', dest='sym', action='store_true',
123 |                         help='symmetric training')
124 |     parser.add_argument('--feat-path', metavar='FEAT',
125 |                         help='Path to the extracted features from the teacher for training')
126 |     parser.add_argument('--feat-val-path', metavar='FEATVAL',
127 |                         help='Path to the extracted features from the teacher for validation')
128 |     parser.add_argument('--pos-num', '-pn', default=3, type=int, metavar='N',
129 |                         help='number of positive images per train/val tuple (default: 3)')
130 | 
131 |     return parser
132 | 
133 | def parse_commandline_args():
134 |     return create_parser().parse_args()
135 | 
--------------------------------------------------------------------------------
/download.py:
--------------------------------------------------------------------------------
1 | import os
2 | import urllib.request
3 | import tarfile
4 | 
5 | def download_datasets(data_dir):
6 |     """
7 |     DOWNLOAD_DATASETS Checks, and, if required, downloads the necessary datasets for the testing.
8 | 
9 |     download_datasets(DATA_ROOT) checks if the data necessary for running the example script exist.
10 |     If not, it downloads it into the folder structure:
11 |         DATA_ROOT/datasets/roxford5k/ : folder with Oxford images
12 |         DATA_ROOT/datasets/rparis6k/  : folder with Paris images
13 |     """
14 | 
15 |     # Create data folder if it does not exist
16 |     if not os.path.isdir(data_dir):
17 |         os.mkdir(data_dir)
18 | 
19 |     # Create datasets folder if it does not exist
20 |     datasets_dir = os.path.join(data_dir, 'datasets')
21 |     if not os.path.isdir(datasets_dir):
22 |         os.mkdir(datasets_dir)
23 | 
24 |     # Download datasets folders datasets/DATASETNAME/
25 |     datasets = ['roxford5k', 'rparis6k']
26 |     for di in range(len(datasets)):
27 |         dataset = datasets[di]
28 | 
29 |         if dataset == 'roxford5k':
30 |             src_dir = 'http://www.robots.ox.ac.uk/~vgg/data/oxbuildings'
31 |             dl_files = ['oxbuild_images.tgz']
32 |         elif dataset == 'rparis6k':
33 |             src_dir = 'http://www.robots.ox.ac.uk/~vgg/data/parisbuildings'
34 |             dl_files = ['paris_1.tgz', 'paris_2.tgz']
35 |         else:
36 |             raise ValueError('Unknown dataset: {}!'.format(dataset))
37 | 
38 |         dst_dir = os.path.join(data_dir, 'datasets', dataset, 'jpg')
39 |         if not os.path.isdir(dst_dir):
40 |             print('>> Dataset {} directory does not exist.
Creating: {}'.format(dataset, dst_dir)) 41 | os.makedirs(dst_dir) 42 | for dli in range(len(dl_files)): 43 | dl_file = dl_files[dli] 44 | src_file = os.path.join(src_dir, dl_file) 45 | dst_file = os.path.join(dst_dir, dl_file) 46 | print('>> Downloading dataset {} archive {}...'.format(dataset, dl_file)) 47 | os.system('wget {} -O {}'.format(src_file, dst_file)) 48 | print('>> Extracting dataset {} archive {}...'.format(dataset, dl_file)) 49 | # create tmp folder 50 | dst_dir_tmp = os.path.join(dst_dir, 'tmp') 51 | os.system('mkdir {}'.format(dst_dir_tmp)) 52 | # extract in tmp folder 53 | os.system('tar -zxf {} -C {}'.format(dst_file, dst_dir_tmp)) 54 | # remove all (possible) subfolders by moving only files in dst_dir 55 | os.system('find {} -type f -exec mv -i {{}} {} \\;'.format(dst_dir_tmp, dst_dir)) 56 | # remove tmp folder 57 | os.system('rm -rf {}'.format(dst_dir_tmp)) 58 | print('>> Extracted, deleting dataset {} archive {}...'.format(dataset, dl_file)) 59 | os.system('rm {}'.format(dst_file)) 60 | 61 | gnd_src_dir = os.path.join('http://cmp.felk.cvut.cz/revisitop/data', 'datasets', dataset) 62 | gnd_dst_dir = os.path.join(data_dir, 'datasets', dataset) 63 | gnd_dl_file = 'gnd_{}.pkl'.format(dataset) 64 | gnd_src_file = os.path.join(gnd_src_dir, gnd_dl_file) 65 | gnd_dst_file = os.path.join(gnd_dst_dir, gnd_dl_file) 66 | if not os.path.exists(gnd_dst_file): 67 | print('>> Downloading dataset {} ground truth file...'.format(dataset)) 68 | os.system('wget {} -O {}'.format(gnd_src_file, gnd_dst_file)) 69 | 70 | 71 | def download_distractors(data_dir): 72 | """ 73 | DOWNLOAD_DISTRACTORS Checks, and, if required, downloads the distractor dataset. 74 | 75 | download_distractors(DATA_ROOT) checks if the distractor dataset exist. 76 | If not it downloads it in the folder: 77 | DATA_ROOT/datasets/revisitop1m/ : folder with 1M distractor images 78 | """ 79 | 80 | # Create data folder if it does not exist 81 | if not os.path.isdir(data_dir): 82 | os.mkdir(data_dir) 83 | 84 | # Create datasets folder if it does not exist 85 | datasets_dir = os.path.join(data_dir, 'datasets') 86 | if not os.path.isdir(datasets_dir): 87 | os.mkdir(datasets_dir) 88 | 89 | dataset = 'revisitop1m' 90 | nfiles = 100 91 | src_dir = 'http://ptak.felk.cvut.cz/revisitop/revisitop1m/jpg' 92 | dl_files = 'revisitop1m.{}.tar.gz' 93 | dst_dir = os.path.join(data_dir, 'datasets', dataset, 'jpg') 94 | dst_dir_tmp = os.path.join(data_dir, 'datasets', dataset, 'jpg_tmp') 95 | if not os.path.isdir(dst_dir): 96 | print('>> Dataset {} directory does not exist.\n>> Creating: {}'.format(dataset, dst_dir)) 97 | if not os.path.isdir(dst_dir_tmp): 98 | os.makedirs(dst_dir_tmp) 99 | for dfi in range(nfiles): 100 | dl_file = dl_files.format(dfi+1) 101 | src_file = os.path.join(src_dir, dl_file) 102 | dst_file = os.path.join(dst_dir_tmp, dl_file) 103 | dst_file_tmp = os.path.join(dst_dir_tmp, dl_file + '.tmp') 104 | if os.path.exists(dst_file): 105 | print('>> [{}/{}] Skipping dataset {} archive {}, already exists...'.format(dfi+1, nfiles, dataset, dl_file)) 106 | else: 107 | while 1: 108 | try: 109 | print('>> [{}/{}] Downloading dataset {} archive {}...'.format(dfi+1, nfiles, dataset, dl_file)) 110 | urllib.request.urlretrieve(src_file, dst_file_tmp) 111 | os.rename(dst_file_tmp, dst_file) 112 | break 113 | except: 114 | print('>>>> Download failed. 
Trying again...')
115 |         for dfi in range(nfiles):
116 |             dl_file = dl_files.format(dfi+1)
117 |             dst_file = os.path.join(dst_dir_tmp, dl_file)
118 |             print('>> [{}/{}] Extracting dataset {} archive {}...'.format(dfi+1, nfiles, dataset, dl_file))
119 |             tar = tarfile.open(dst_file)
120 |             tar.extractall(path=dst_dir_tmp)
121 |             tar.close()
122 |             print('>> [{}/{}] Extracted, deleting dataset {} archive {}...'.format(dfi+1, nfiles, dataset, dl_file))
123 |             os.remove(dst_file)
124 |         # rename tmp folder
125 |         os.rename(dst_dir_tmp, dst_dir)
126 | 
127 |     # download image list
128 |     gnd_src_dir = 'http://ptak.felk.cvut.cz/revisitop/revisitop1m/'
129 |     gnd_dst_dir = os.path.join(data_dir, 'datasets', dataset)
130 |     gnd_dl_file = '{}.txt'.format(dataset)
131 |     gnd_src_file = os.path.join(gnd_src_dir, gnd_dl_file)
132 |     gnd_dst_file = os.path.join(gnd_dst_dir, gnd_dl_file)
133 |     if not os.path.exists(gnd_dst_file):
134 |         print('>> Downloading dataset {} image list file...'.format(dataset))
135 |         urllib.request.urlretrieve(gnd_src_file, gnd_dst_file)
136 | 
137 | 
138 | def download_features(data_dir):
139 |     """
140 |     DOWNLOAD_FEATURES Checks, and, if required, downloads the necessary features for the example testing.
141 | 
142 |     download_features(DATA_ROOT) checks if the data necessary for running the example script exist.
143 |     If not, it downloads it in the folder: DATA_ROOT/features
144 |     """
145 | 
146 |     # Create data folder if it does not exist
147 |     if not os.path.isdir(data_dir):
148 |         os.mkdir(data_dir)
149 | 
150 |     # Create features folder if it does not exist
151 |     features_dir = os.path.join(data_dir, 'features')
152 |     if not os.path.isdir(features_dir):
153 |         os.mkdir(features_dir)
154 | 
155 |     # Download example features
156 |     datasets = ['roxford5k', 'rparis6k']
157 |     for di in range(len(datasets)):
158 |         dataset = datasets[di]
159 | 
160 |         feat_src_dir = os.path.join('http://cmp.felk.cvut.cz/revisitop/data', 'features')
161 |         feat_dst_dir = os.path.join(data_dir, 'features')
162 |         feat_dl_file = '{}_resnet_rsfm120k_gem.mat'.format(dataset)
163 |         feat_src_file = os.path.join(feat_src_dir, feat_dl_file)
164 |         feat_dst_file = os.path.join(feat_dst_dir, feat_dl_file)
165 |         if not os.path.exists(feat_dst_file):
166 |             print('>> Downloading dataset {} features file {}...'.format(dataset, feat_dl_file))
167 |             os.system('wget {} -O {}'.format(feat_src_file, feat_dst_file))
168 | 
--------------------------------------------------------------------------------
/lib/layers/functional.py:
--------------------------------------------------------------------------------
1 | import math
2 | 
3 | import torch
4 | import torch.nn.functional as F
5 | 
6 | # --------------------------------------
7 | # pooling
8 | # --------------------------------------
9 | 
10 | def mac(x):
11 |     return F.max_pool2d(x, (x.size(-2), x.size(-1)))
12 |     # return F.adaptive_max_pool2d(x, (1,1)) # alternative
13 | 
14 | 
15 | def spoc(x):
16 |     return F.avg_pool2d(x, (x.size(-2), x.size(-1)))
17 |     # return F.adaptive_avg_pool2d(x, (1,1)) # alternative
18 | 
19 | 
20 | def gem(x, p=3, eps=1e-6):
21 |     return F.avg_pool2d(x.clamp(min=eps).pow(p), (x.size(-2), x.size(-1))).pow(1./p)
22 |     # return F.lp_pool2d(F.threshold(x, eps, eps), p, (x.size(-2), x.size(-1))) # alternative
23 | 
24 | 
25 | def rmac(x, L=3, eps=1e-6):
26 |     ovr = 0.4 # desired overlap of neighboring regions
27 |     steps = torch.Tensor([2, 3, 4, 5, 6, 7]) # possible regions for the long dimension
28 | 
29 |     W = x.size(3)
30 |     H = x.size(2)
31 | 
32 |     w = min(W, H)
33 |     w2 =
math.floor(w/2.0 - 1)
34 | 
35 |     b = (max(H, W)-w)/(steps-1)
36 |     (tmp, idx) = torch.min(torch.abs(((w**2 - w*b)/w**2)-ovr), 0) # steps(idx) regions for long dimension
37 | 
38 |     # region overplus per dimension
39 |     Wd = 0;
40 |     Hd = 0;
41 |     if H < W:
42 |         Wd = idx.item() + 1
43 |     elif H > W:
44 |         Hd = idx.item() + 1
45 | 
46 |     v = F.max_pool2d(x, (x.size(-2), x.size(-1)))
47 |     v = v / (torch.norm(v, p=2, dim=1, keepdim=True) + eps).expand_as(v)
48 | 
49 |     for l in range(1, L+1):
50 |         wl = math.floor(2*w/(l+1))
51 |         wl2 = math.floor(wl/2 - 1)
52 | 
53 |         if l+Wd == 1:
54 |             b = 0
55 |         else:
56 |             b = (W-wl)/(l+Wd-1)
57 |         cenW = torch.floor(wl2 + torch.Tensor(range(l-1+Wd+1))*b) - wl2 # center coordinates
58 |         if l+Hd == 1:
59 |             b = 0
60 |         else:
61 |             b = (H-wl)/(l+Hd-1)
62 |         cenH = torch.floor(wl2 + torch.Tensor(range(l-1+Hd+1))*b) - wl2 # center coordinates
63 | 
64 |         for i_ in cenH.tolist():
65 |             for j_ in cenW.tolist():
66 |                 if wl == 0:
67 |                     continue
68 |                 R = x[:,:,(int(i_)+torch.Tensor(range(wl)).long()).tolist(),:]
69 |                 R = R[:,:,:,(int(j_)+torch.Tensor(range(wl)).long()).tolist()]
70 |                 vt = F.max_pool2d(R, (R.size(-2), R.size(-1)))
71 |                 vt = vt / (torch.norm(vt, p=2, dim=1, keepdim=True) + eps).expand_as(vt)
72 |                 v += vt
73 | 
74 |     return v
75 | 
76 | 
77 | def roipool(x, rpool, L=3, eps=1e-6):
78 |     ovr = 0.4 # desired overlap of neighboring regions
79 |     steps = torch.Tensor([2, 3, 4, 5, 6, 7]) # possible regions for the long dimension
80 | 
81 |     W = x.size(3)
82 |     H = x.size(2)
83 | 
84 |     w = min(W, H)
85 |     w2 = math.floor(w/2.0 - 1)
86 | 
87 |     b = (max(H, W)-w)/(steps-1)
88 |     _, idx = torch.min(torch.abs(((w**2 - w*b)/w**2)-ovr), 0) # steps(idx) regions for long dimension
89 | 
90 |     # region overplus per dimension
91 |     Wd = 0;
92 |     Hd = 0;
93 |     if H < W:
94 |         Wd = idx.item() + 1
95 |     elif H > W:
96 |         Hd = idx.item() + 1
97 | 
98 |     vecs = []
99 |     vecs.append(rpool(x).unsqueeze(1))
100 | 
101 |     for l in range(1, L+1):
102 |         wl = math.floor(2*w/(l+1))
103 |         wl2 = math.floor(wl/2 - 1)
104 | 
105 |         if l+Wd == 1:
106 |             b = 0
107 |         else:
108 |             b = (W-wl)/(l+Wd-1)
109 |         cenW = torch.floor(wl2 + torch.Tensor(range(l-1+Wd+1))*b).int() - wl2 # center coordinates
110 |         if l+Hd == 1:
111 |             b = 0
112 |         else:
113 |             b = (H-wl)/(l+Hd-1)
114 |         cenH = torch.floor(wl2 + torch.Tensor(range(l-1+Hd+1))*b).int() - wl2 # center coordinates
115 | 
116 |         for i_ in cenH.tolist():
117 |             for j_ in cenW.tolist():
118 |                 if wl == 0:
119 |                     continue
120 |                 vecs.append(rpool(x.narrow(2,i_,wl).narrow(3,j_,wl)).unsqueeze(1))
121 | 
122 |     return torch.cat(vecs, dim=1)
123 | 
124 | 
125 | # --------------------------------------
126 | # normalization
127 | # --------------------------------------
128 | 
129 | def l2n(x, eps=1e-6):
130 |     return x / (torch.norm(x, p=2, dim=1, keepdim=True) + eps).expand_as(x)
131 | 
132 | def powerlaw(x, eps=1e-6):
133 |     x = x + eps
134 |     return x.abs().sqrt().mul(x.sign())
135 | 
136 | # --------------------------------------
137 | # loss
138 | # --------------------------------------
139 | 
140 | def contrastive_loss(x, label, margin=0.7, eps=1e-6):
141 |     # x is D x N
142 |     dim = x.size(0) # D
143 |     nq = torch.sum(label.data==-1) # number of tuples
144 |     S = x.size(1) // int(nq) # number of images per tuple including query: 1+1+n
145 | 
146 |     x1 = x[:, ::S].permute(1,0).repeat(1,S-1).view((S-1)*nq,dim).permute(1,0)
147 | 
148 |     idx = [i for i in range(len(label)) if label.data[i] != -1]
149 |     x2 = x[:, idx]
150 |     lbl = label[label!=-1]
151 | 
152 |     dif = x1 - x2
153 |     D = torch.pow(dif+eps, 2).sum(dim=0).sqrt()
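154 |     # D holds the Euclidean distances between each query and its positives/negatives;
155 |     # positives (lbl == 1) are pulled together, negatives (lbl == 0) pushed beyond the margin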
156 |     y = 0.5*lbl*torch.pow(D,2) + 0.5*(1-lbl)*torch.pow(torch.clamp(margin-D, min=0),2)
157 |     y = torch.sum(y)
158 |     return y
159 | 
160 | 
161 | def contrastive_loss_dist(x, label, dists, eps=1e-6):
162 |     # x is D x N
163 |     dim = x.size(0) # D
164 |     nq = torch.sum(label.data==-1) # number of tuples
165 |     S = x.size(1) // nq # number of images per tuple including query: 1+1+n
166 | 
167 |     x1 = x[:, ::S].permute(1,0).repeat(1,S-1).view((S-1)*nq,dim).permute(1,0)
168 |     idx = [i for i in range(len(label)) if label.data[i] != -1]
169 |     x2 = x[:, idx]
170 |     lbl = label[label!=-1]
171 | 
172 |     dif = x1 - x2
173 |     D = torch.pow(dif+eps, 2).sum(dim=0).sqrt()
174 |     dists = torch.Tensor(dists).cuda()
175 |     y = 0.5*lbl*torch.pow(D,2) + 0.5*(1-lbl)*torch.pow(dists-D,2)
176 |     y = torch.sum(y)
177 |     return y
178 | 
179 | def triplet_loss(x, label, margin=0.1):
180 |     # x is D x N
181 |     dim = x.size(0) # D
182 |     nq = torch.sum(label.data==-1).item() # number of tuples
183 |     S = x.size(1) // nq # number of images per tuple including query: 1+1+n
184 | 
185 |     xa = x[:, label.data==-1].permute(1,0).repeat(1,S-2).view((S-2)*nq,dim).permute(1,0)
186 |     xp = x[:, label.data==1].permute(1,0).repeat(1,S-2).view((S-2)*nq,dim).permute(1,0)
187 |     xn = x[:, label.data==0]
188 | 
189 |     dist_pos = torch.sum(torch.pow(xa - xp, 2), dim=0)
190 |     dist_neg = torch.sum(torch.pow(xa - xn, 2), dim=0)
191 | 
192 |     # debug logging of pair distances to a hard-coded study path (left disabled):
193 |     # out = '/nfs/nas4/mbudnik/dataset_descs/feature_translation/study/'
194 |     # pf = out+'triplet_dist_pos.txt'
195 |     # with open(pf,'a') as pout:
196 |     #     for i in range(len(dist_pos)):
197 |     #         pout.write(str(round(float(dist_pos[i]),4))+'\n')
198 |     # nf = out+'triplet_dist_neg.txt'
199 |     # with open(nf, 'a') as nout:
200 |     #     for i in range(len(dist_neg)):
201 |     #         nout.write(str(round(float(dist_neg[i]),4))+'\n')
202 |     # df = out+'triplet_dist_diff.txt'
203 |     # with open(df, 'a') as nout:
204 |     #     for i in range(len(dist_neg)):
205 |     #         nout.write(str(round(float(dist_pos[i] - dist_neg[i]),4))+'\n')
206 |     #pdb.set_trace()
207 | 
208 |     return torch.sum(torch.clamp(dist_pos - dist_neg + margin, min=0))
209 | 
210 | def cross_entropy_loss(x, label, temp, eps=1e-6):
211 | 
212 |     #pdb.set_trace()
213 |     y = label*F.log_softmax(x/temp)
214 |     y = -torch.sum(y)
215 |     #print(y)
216 |     return y
217 | 
218 | def cross_entropy_loss_dist(x, label, temp, eps=1e-6):
219 | 
220 |     dim = x.size(0) # D
221 |     nq = torch.sum(label.data==-1) # number of tuples
222 |     S = x.size(1) // int(nq) # number of images per tuple including query: 1+1+n
223 |     x1 = x[:, ::S].permute(1,0).repeat(1,S-1).view((S-1)*nq,dim).permute(1,0)
224 |     idx = [i for i in range(len(label)) if label.data[i] != -1]
225 |     x2 = x[:, idx]
226 |     lbl = label[label!=-1]
227 | 
228 |     dif = x1 - x2
229 |     D = torch.pow(dif+eps, 2).sum(dim=0).sqrt()
230 | 
231 |     y = lbl*F.log_softmax(-D/temp)
232 |     y = -torch.sum(y)
233 |     return y
234 | 
235 | def pdist(e, squared=False, eps=1e-12):
236 |     e_square = e.pow(2).sum(dim=1)
237 |     prod = e @ e.t()
238 |     res = (e_square.unsqueeze(1) + e_square.unsqueeze(0) - 2 * prod).clamp(min=eps)
239 | 
240 |     if not squared:
241 |         res = res.sqrt()
242 | 
243 |     res = res.clone()
244 |     res[range(len(e)), range(len(e))] = 0
245 |     return res
246 | 
247 | 
--------------------------------------------------------------------------------
/extract_features.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import time
4 | import pickle
5 | import pdb
6 | from PIL import Image, ImageFile
7 | 
8 | import numpy as np
9 | 
10 | import torch
11 | from torch.utils.model_zoo
import load_url
12 | from torchvision import transforms
13 | 
14 | from lib.networks.imageretrievalnet import init_network, extract_vectors
15 | from lib.datasets.datahelpers import cid2filename
16 | from lib.datasets.testdataset import configdataset
17 | from lib.utils.download import download_train, download_test
18 | from lib.utils.whiten import whitenlearn, whitenapply
19 | from lib.utils.evaluate import compute_map_and_print
20 | from lib.utils.general import get_data_root, htime
21 | 
22 | PRETRAINED = {
23 |     'retrievalSfM120k-vgg16-gem'        : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/retrieval-SfM-120k/retrievalSfM120k-vgg16-gem-b4dcdc6.pth',
24 |     'retrievalSfM120k-resnet101-gem'    : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/retrieval-SfM-120k/retrievalSfM120k-resnet101-gem-b80fb85.pth',
25 |     'rSfM120k-tl-resnet50-gem-w'        : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/retrieval-SfM-120k/rSfM120k-tl-resnet50-gem-w-97bf910.pth',
26 |     'rSfM120k-tl-resnet101-gem-w'       : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/retrieval-SfM-120k/rSfM120k-tl-resnet101-gem-w-a155e54.pth',
27 |     'rSfM120k-tl-resnet152-gem-w'       : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/retrieval-SfM-120k/rSfM120k-tl-resnet152-gem-w-f39cada.pth',
28 |     'gl18-tl-resnet50-gem-w'            : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/gl18/gl18-tl-resnet50-gem-w-83fdc30.pth',
29 |     'gl18-tl-resnet101-gem-w'           : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/gl18/gl18-tl-resnet101-gem-w-a4d43db.pth',
30 |     'gl18-tl-resnet152-gem-w'           : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/gl18/gl18-tl-resnet152-gem-w-21278d5.pth',
31 | }
32 | 
33 | datasets_names = ['oxford5k', 'paris6k', 'roxford5k', 'rparis6k', 'retrieval-SfM-120k', 'retrieval-SfM-30k']
34 | 
35 | parser = argparse.ArgumentParser(description='Feature extractor for a given model and dataset.')
36 | 
37 | # network
38 | group = parser.add_mutually_exclusive_group(required=True)
39 | group.add_argument('--network-path', '-npath', metavar='NETWORK',
40 |                    help="pretrained network or network path (destination where network is saved)")
41 | group.add_argument('--network-offtheshelf', '-noff', metavar='NETWORK',
42 |                    help="off-the-shelf network, in the format 'ARCHITECTURE-POOLING' or 'ARCHITECTURE-POOLING-{reg-lwhiten-whiten}'," +
43 |                         " examples: 'resnet101-gem' | 'resnet101-gem-reg' | 'resnet101-gem-whiten' | 'resnet101-gem-lwhiten' | 'resnet101-gem-reg-whiten'")
44 | 
45 | parser.add_argument('--datasets', '-d', metavar='DATASETS', default='roxford5k,rparis6k',
46 |                     help="comma separated list of test datasets: " +
47 |                          " | ".join(datasets_names) +
48 |                          " (default: 'roxford5k,rparis6k')")
49 | parser.add_argument('--image-size', '-imsize', default=1024, type=int, metavar='N',
50 |                     help="maximum size of longer image side used for testing (default: 1024)")
51 | parser.add_argument('--multiscale', '-ms', metavar='MULTISCALE', default='[1]',
52 |                     help="use multiscale vectors for testing, " +
53 |                          " examples: '[1]' | '[1, 1/2**(1/2), 1/2]' | '[1, 2**(1/2), 1/2**(1/2)]' (default: '[1]')")
54 | parser.add_argument('--gpu-id', '-g', default='0', metavar='N',
55 |                     help="gpu id used for testing (default: '0')")
56 | 
57 | 
58 | def pil_loader(path):
59 |     # to avoid crashing for truncated (corrupted images)
60 |     ImageFile.LOAD_TRUNCATED_IMAGES = True
61 |     # open path as file to avoid ResourceWarning
62 |     #(https://github.com/python-pillow/Pillow/issues/835)
63 |     with open(path,
'rb') as f:
64 |         img = Image.open(f)
65 |         return img.convert('RGB')
66 | 
67 | def main():
68 |     args = parser.parse_args()
69 | 
70 |     # check if there are unknown datasets
71 |     for dataset in args.datasets.split(','):
72 |         if dataset not in datasets_names:
73 |             raise ValueError('Unsupported or unknown dataset: {}!'.format(dataset))
74 | 
75 |     # check if test dataset are downloaded
76 |     # and download if they are not
77 |     download_train(get_data_root())
78 |     download_test(get_data_root())
79 | 
80 |     # setting up the visible GPU
81 |     os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id
82 | 
83 |     # loading network from path
84 |     if args.network_path is not None:
85 | 
86 |         print(">> Loading network:\n>>>> '{}'".format(args.network_path))
87 |         if args.network_path in PRETRAINED:
88 |             # pretrained networks (downloaded automatically)
89 |             state = load_url(PRETRAINED[args.network_path], model_dir=os.path.join(get_data_root(), 'networks'))
90 |         else:
91 |             # fine-tuned network from path
92 |             state = torch.load(args.network_path)
93 | 
94 |         # parsing net params from meta
95 |         # architecture, pooling, mean, std required
96 |         # the rest have default values, in case they do not exist
97 |         net_params = {}
98 |         net_params['architecture'] = state['meta']['architecture']
99 |         net_params['pooling'] = state['meta']['pooling']
100 |         net_params['local_whitening'] = state['meta'].get('local_whitening', False)
101 |         net_params['regional'] = state['meta'].get('regional', False)
102 |         net_params['whitening'] = state['meta'].get('whitening', False)
103 |         net_params['mean'] = state['meta']['mean']
104 |         net_params['std'] = state['meta']['std']
105 |         net_params['pretrained'] = False
106 | 
107 |         # load network
108 |         net = init_network(net_params)
109 |         net.load_state_dict(state['state_dict'])
110 | 
111 |         # if whitening is precomputed
112 | 
113 |         print(">>>> loaded network: ")
114 |         print(net.meta_repr())
115 | 
116 |     # loading offtheshelf network
117 |     elif args.network_offtheshelf is not None:
118 | 
119 |         # parse off-the-shelf parameters
120 |         offtheshelf = args.network_offtheshelf.split('-')
121 |         net_params = {}
122 |         net_params['architecture'] = offtheshelf[0]
123 |         net_params['pooling'] = offtheshelf[1]
124 |         net_params['local_whitening'] = 'lwhiten' in offtheshelf[2:]
125 |         net_params['regional'] = 'reg' in offtheshelf[2:]
126 |         net_params['whitening'] = 'whiten' in offtheshelf[2:]
127 |         net_params['pretrained'] = True
128 | 
129 |         # load off-the-shelf network
130 |         print(">> Loading off-the-shelf network:\n>>>> '{}'".format(args.network_offtheshelf))
131 |         net = init_network(net_params)
132 |         print(">>>> loaded network: ")
133 |         print(net.meta_repr())
134 | 
135 |     # setting up the multi-scale parameters
136 |     ms = list(eval(args.multiscale))
137 |     if len(ms)>1 and net.meta['pooling'] == 'gem' and not net.meta['regional'] and not net.meta['whitening']:
138 |         msp = net.pool.p.item()
139 |         print(">> Set-up multiscale:")
140 |         print(">>>> ms: {}".format(ms))
141 |         print(">>>> msp: {}".format(msp))
142 |     else:
143 |         msp = 1
144 | 
145 |     # moving network to gpu and eval mode
146 |     net.cuda()
147 |     net.eval()
148 | 
149 |     # set up the transform
150 |     normalize = transforms.Normalize(
151 |         mean=net.meta['mean'],
152 |         std=net.meta['std']
153 |     )
154 |     transform = transforms.Compose([
155 |         transforms.ToTensor(),
156 |         normalize
157 |     ])
158 | 
159 |     # evaluate on test datasets
160 |     datasets = args.datasets.split(',')
161 |     for dataset in datasets:
162 |         start = time.time()
163 | 
164 |         print('>> {}: Extracting...'.format(dataset))
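165 |         # compute one descriptor per image and save both the feature matrix and the image
166 |         # order, so vectors can later be matched back to filenames (e.g. as teacher features)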
167 |         data_root = get_data_root()
168 |         cfg = configdataset(dataset, os.path.join(data_root, 'test'))  # test sets were downloaded to DATA_ROOT/test above
169 |         images = [cfg['im_fname'](cfg,i) for i in range(cfg['n'])]
170 |         vecs = extract_vectors(net, images, args.image_size, transform, ms=ms, msp=msp)
171 |         feat_dir = os.path.join(data_root, 'features')
172 |         if not os.path.isdir(feat_dir):
173 |             os.makedirs(feat_dir)
174 |         out_path = os.path.join(feat_dir, '%s_%s' % (dataset, args.network_path))
175 |         np.save(out_path, vecs)
176 |         out_path_list = os.path.join(feat_dir, '%s_%s_img_list.txt' % (dataset, args.network_path))
177 |         with open(out_path_list, 'w') as opl:
178 |             for x in images:
179 |                 opl.write(x+'\n')
180 | 
181 |         print('>> {}: elapsed time: {}'.format(dataset, htime(time.time()-start)))
182 | 
183 | 
184 | if __name__ == '__main__':
185 |     main()
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import time
4 | import math
5 | import pickle
6 | import pdb
7 | 
8 | import numpy as np
9 | 
10 | import torch
11 | import torch.nn as nn
12 | import torch.optim
13 | import torch.utils.data
14 | 
15 | import torchvision.transforms as transforms
16 | import torchvision.models as models
17 | 
18 | from lib.networks.mobilenet_v2 import MobileNetV2
19 | 
20 | from lib.networks.imageretrievalnet import init_network, extract_vectors
21 | from lib.layers.loss import ContrastiveLoss, TripletLoss, ContrastiveDistLoss, CrossEntropyLoss, CrossEntropyDistLoss, MultiSimilarityLoss, RKdAD
22 | 
23 | from lib.datasets.traindataset import TuplesDataset, TuplesDatasetTS, TuplesDatasetTSWithSelf, RegressionTS, RegressionTSOnlyPos, TuplesDatasetRand
24 | from lib.datasets.traindataset import RandomTriplet, RandomTripletAsym, TuplesDatasetTSRand
25 | 
26 | from lib.datasets.testdataset import configdataset
27 | from lib.datasets.datahelpers import collate_tuples, collate_tuples_dist, cid2filename
28 | from lib.utils.download import download_train, download_test
29 | from lib.utils.whiten import whitenlearn, whitenapply
30 | from lib.utils.evaluate import compute_map_and_print
31 | from lib.utils.general import get_data_root, htime
32 | from lib import cli, parse_args
33 | 
34 | from training import train, validate
35 | 
36 | import torch.nn.functional as F
37 | 
38 | training_dataset_names = ['retrieval-SfM-120k']
39 | test_datasets_names = ['oxford5k', 'paris6k', 'roxford5k', 'rparis6k']
40 | test_whiten_names = ['retrieval-SfM-30k', 'retrieval-SfM-120k']
41 | 
42 | model_names = sorted(name for name in models.__dict__
43 |                      if name.islower() and not name.startswith("__")
44 |                      and callable(models.__dict__[name]))
45 | model_names.append('mobilenet_v3')
46 | model_names.append('efficientnet_b3')
47 | model_names.append('efficientnet_b3_new') # which one?
48 | 
49 | pool_names = ['mac', 'spoc', 'gem', 'gemmp']
50 | 
51 | loss_names = ['contrastive', 'triplet', 'contrastive_dist', 'cross_entropy', 'cross_entropy_dist', 'multi', 'rkd']
52 | mode_names = ['ts', 'ts_self', 'reg', 'reg_only_pos', 'std', 'rand', 'rand_tpl', 'rand_tpl_a', 'ts_rand']
53 | 
54 | teacher_names = ['vgg16', 'resnet101']
55 | optimizer_names = ['sgd', 'adam']
56 | 
57 | min_loss = float('inf')
58 | 
59 | def main():
60 |     global min_loss
61 | 
62 |     # manually check if there are unknown test datasets
63 |     for dataset in args.test_datasets.split(','):
64 |         if dataset not in test_datasets_names:
65 |             raise ValueError('Unsupported or unknown test dataset: {}!'.format(dataset))
66 | 
67 |     # check if train/test datasets are downloaded
68 |     # and download if they are not
69 |     data_root = get_data_root()
70 |     download_train(data_root)
71 |     download_test(data_root)
72 | 
73 |     directory = parse_args.from_args_to_string(args)
74 | 
75 |     args.directory = os.path.join(args.directory, directory)
76 |     print(">> Creating directory if it does not exist:\n>> '{}'".format(args.directory))
77 |     if not os.path.exists(args.directory):
78 |         os.makedirs(args.directory)
79 |     log_out = args.directory+'/log.txt'
80 |     log = open(log_out,'a')
81 |     loss_log_out = args.directory+'/loss_log.txt'
82 |     loss_log = open(loss_log_out,'a')
83 | 
84 |     # set cuda visible device
85 |     os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id
86 | 
87 |     # set random seeds
88 |     # TODO: maybe pass as argument in future implementation?
89 |     torch.manual_seed(0)
90 |     torch.cuda.manual_seed_all(0)
91 |     np.random.seed(0)
92 | 
93 |     # initialize model
94 |     if args.pretrained:
95 |         print(">> Using pre-trained model '{}'".format(args.arch))
96 |     else:
97 |         print(">> Using model from scratch (random weights) '{}'".format(args.arch))
98 | 
99 |     model_params = {}
100 |     model_params['architecture'] = args.arch
101 |     model_params['pooling'] = args.pool
102 |     model_params['local_whitening'] = args.local_whitening
103 |     model_params['regional'] = args.regional
104 |     model_params['whitening'] = args.whitening
105 |     # model_params['mean'] = ...  # will use default
106 |     # model_params['std'] = ...
# will use default 107 | model_params['pretrained'] = args.pretrained 108 | model_params['teacher'] = args.teacher 109 | model = init_network(model_params) 110 | 111 | # move network to gpu 112 | model.cuda() 113 | 114 | # define loss function (criterion) and optimizer 115 | if args.loss == 'contrastive': 116 | criterion = ContrastiveLoss(margin=args.loss_margin).cuda() 117 | elif args.loss == 'contrastive_dist': 118 | criterion = ContrastiveDistLoss(margin=args.loss_margin).cuda() 119 | elif args.loss == 'triplet': 120 | criterion = TripletLoss(margin=args.loss_margin).cuda() 121 | elif args.loss == 'cross_entropy': 122 | criterion = CrossEntropyLoss(temp=args.temp).cuda() 123 | elif args.loss == 'cross_entropy_dist': 124 | criterion = CrossEntropyDistLoss(temp=args.temp).cuda() 125 | elif args.loss == 'multi': 126 | criterion = MultiSimilarityLoss().cuda() 127 | elif args.loss == 'rkd': 128 | criterion = RKdAD().cuda() 129 | else: 130 | raise(RuntimeError("Loss {} not available!".format(args.loss))) 131 | 132 | # parameters split into features, pool, whitening 133 | # IMPORTANT: no weight decay for pooling parameter p in GeM or regional-GeM 134 | parameters = [] 135 | # add feature parameters 136 | parameters.append({'params': model.features.parameters()}) 137 | # add local whitening if exists 138 | if model.lwhiten is not None: 139 | parameters.append({'params': model.lwhiten.parameters()}) 140 | # add pooling parameters (or regional whitening which is part of the pooling layer!) 141 | if not args.regional: 142 | # global, only pooling parameter p weight decay should be 0 143 | if args.pool == 'gem': 144 | parameters.append({'params': model.pool.parameters(), 'lr': args.lr*10, 'weight_decay': 0}) 145 | elif args.pool == 'gemmp': 146 | parameters.append({'params': model.pool.parameters(), 'lr': args.lr*100, 'weight_decay': 0}) 147 | else: 148 | # regional, pooling parameter p weight decay should be 0, 149 | # and we want to add regional whitening if it is there 150 | if args.pool == 'gem': 151 | parameters.append({'params': model.pool.rpool.parameters(), 'lr': args.lr*10, 'weight_decay': 0}) 152 | elif args.pool == 'gemmp': 153 | parameters.append({'params': model.pool.rpool.parameters(), 'lr': args.lr*100, 'weight_decay': 0}) 154 | if model.pool.whiten is not None: 155 | parameters.append({'params': model.pool.whiten.parameters()}) 156 | # add final whitening if exists 157 | if model.whiten is not None: 158 | parameters.append({'params': model.whiten.parameters()}) 159 | 160 | # define optimizer 161 | if args.optimizer == 'sgd': 162 | optimizer = torch.optim.SGD(parameters, args.lr, momentum=args.momentum, weight_decay=args.weight_decay) 163 | elif args.optimizer == 'adam': 164 | optimizer = torch.optim.Adam(parameters, args.lr, weight_decay=args.weight_decay) 165 | 166 | # define learning rate decay schedule 167 | # TODO: maybe pass as argument in future implementation? 
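168 |     # ExponentialLR multiplies the learning rate by gamma = exp(-0.01) ~ 0.99 after every
169 |     # epoch, i.e. lr(epoch) = lr0 * exp(-0.01 * epoch)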
170 |     exp_decay = math.exp(-0.01)
171 |     scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=exp_decay)
172 |     # optionally resume from a checkpoint
173 |     start_epoch = 0
174 |     if args.resume:
175 |         args.resume = os.path.join(args.directory, args.resume)
176 |         if os.path.isfile(args.resume):
177 |             # load checkpoint weights and update model and optimizer
178 |             print(">> Loading checkpoint:\n>> '{}'".format(args.resume))
179 |             checkpoint = torch.load(args.resume)
180 |             start_epoch = checkpoint['epoch']
181 |             min_loss = checkpoint['min_loss']
182 |             model.load_state_dict(checkpoint['state_dict'])
183 |             optimizer.load_state_dict(checkpoint['optimizer'])
184 |             print(">>>> loaded checkpoint:\n>>>> '{}' (epoch {})"
185 |                   .format(args.resume, checkpoint['epoch']))
186 |             # important not to forget scheduler updating
187 |             scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=exp_decay, last_epoch=checkpoint['epoch']-1)
188 |         else:
189 |             print(">> Finding the last checkpoint")
190 |             all_file = os.listdir(args.directory)
191 |             last_ckpt = 0
192 |             # keep the highest epoch number among saved 'model_epochN.pth.tar' files
193 |             for f in all_file:
194 |                 if f.startswith('model_epoch'):
195 |                     ckpt_temp = int(f.split('.')[0].split('model_epoch')[1])
196 |                     if ckpt_temp > last_ckpt:
197 |                         last_ckpt = ckpt_temp
198 | 
199 |             resume_last = os.path.join(args.directory, 'model_epoch'+str(last_ckpt)+'.pth.tar')
200 |             if os.path.isfile(resume_last):
201 |                 print(">> Loading checkpoint:\n>> '{}'".format(resume_last))
202 |                 checkpoint = torch.load(resume_last)
203 |                 start_epoch = checkpoint['epoch']
204 |                 min_loss = checkpoint['min_loss']
205 |                 model.load_state_dict(checkpoint['state_dict'])
206 |                 optimizer.load_state_dict(checkpoint['optimizer'])
207 |                 print(">>>> loaded checkpoint:\n>>>> '{}' (epoch {})"
208 |                       .format(resume_last, checkpoint['epoch']))
209 |                 scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=exp_decay, last_epoch=checkpoint['epoch']-1)
210 |             else:
211 |                 print(">> No checkpoint found at '{}'".format(resume_last))
212 | 
213 |     # Data loading code
214 |     normalize = transforms.Normalize(mean=model.meta['mean'], std=model.meta['std'])
215 |     transform = transforms.Compose([
216 |         transforms.ToTensor(),
217 |         normalize,
218 |     ])
219 | 
220 |     if args.mode == 'ts':
221 |         tr_dataset = TuplesDatasetTS
222 |     elif args.mode == 'ts_self':
223 |         tr_dataset = TuplesDatasetTSWithSelf
224 |     elif args.mode == 'ts_rand':
225 |         tr_dataset = TuplesDatasetTSRand
226 |     elif args.mode == 'rand':
227 |         tr_dataset = TuplesDatasetRand
228 |     elif args.mode == 'rand_tpl':
229 |         tr_dataset = RandomTriplet
230 |     elif args.mode == 'rand_tpl_a':
231 |         tr_dataset = RandomTripletAsym
232 |     elif args.mode == 'reg' or args.mode == 'reg_only_pos':
233 |         tr_dataset = RegressionTS
234 |     else:
235 |         tr_dataset = TuplesDataset
236 | 
237 |     train_dataset = tr_dataset(
238 |         name=args.training_dataset,
239 |         mode='train',
240 |         imsize=args.image_size,
241 |         nnum=args.neg_num,
242 |         qsize=args.query_size,
243 |         poolsize=args.pool_size,
244 |         feat_path=args.feat_path,
245 |         transform=transform,
246 |         nexamples=args.nexamples
247 |     )
248 | 
249 |     train_loader = torch.utils.data.DataLoader(
250 |         train_dataset, batch_size=args.batch_size, shuffle=True,
251 |         num_workers=args.workers, pin_memory=False, sampler=None,
252 |         drop_last=True, collate_fn=collate_tuples
253 |     )
254 | 
255 |     #----------------------- VALIDATION -----------------------------------
256 |     if args.val:
257 |         if args.mode in ['std', 'rand_tpl']:
258 |             vl_dataset =
TuplesDataset
259 |         elif args.mode == 'rand':
260 |             vl_dataset = TuplesDatasetRand
261 |         elif args.mode == 'ts_rand':
262 |             vl_dataset = TuplesDatasetTSRand
263 |         else:
264 |             vl_dataset = TuplesDatasetTS
265 | 
266 |         val_dataset = vl_dataset(name=args.training_dataset, mode='val',
267 |                                  imsize=args.image_size, nnum=args.neg_num, qsize=float('Inf'),
268 |                                  poolsize=float('Inf'), feat_path=args.feat_val_path, transform=transform)
269 | 
270 |         val_loader = torch.utils.data.DataLoader(
271 |             val_dataset, batch_size=args.batch_size, shuffle=False,
272 |             num_workers=args.workers, pin_memory=True,
273 |             drop_last=True, collate_fn=collate_tuples
274 |         )
275 | 
276 | 
277 |     loss_log.write("epoch, train_loss, val_loss\n")
278 |     for epoch in range(start_epoch, args.epochs):
279 |         # set manual seeds per epoch
280 |         np.random.seed(epoch)
281 |         torch.manual_seed(epoch)
282 |         torch.cuda.manual_seed_all(epoch)
283 | 
284 |         # adjust learning rate for each epoch
285 |         scheduler.step()
286 | 
287 |         # train for one epoch on train set
288 | 
289 |         loss = train(train_loader, model, criterion, optimizer, epoch, log, args)
290 | 
291 |         loss_log.write('%s, %s' %(epoch, loss))
292 |         # evaluate on validation set
293 |         if args.val and (epoch + 1) % args.val_freq == 0:
294 |             with torch.no_grad():
295 |                 loss = validate(val_loader, model, criterion, epoch, args)
296 |             loss_log.write(', %s' % loss)
297 | 
298 |         loss_log.write('\n')
299 | 
300 |         # remember best loss and save checkpoint
301 |         is_best = False
302 |         if args.val and (epoch + 1) % args.val_freq == 0:
303 |             is_best = loss < min_loss
304 |             min_loss = min(loss, min_loss)
305 |         elif not args.val:
306 |             is_best = loss < min_loss
307 |             min_loss = min(loss, min_loss)
308 |         if (epoch + 1) % args.save_freq == 0:
309 |             save_checkpoint({
310 |                 'epoch': epoch + 1,
311 |                 'meta': model.meta,
312 |                 'state_dict': model.state_dict(),
313 |                 'min_loss': min_loss,
314 |                 'optimizer' : optimizer.state_dict(),
315 |             }, is_best, args.directory)
316 |         if is_best:
317 |             save_checkpoint_best({
318 |                 'epoch': epoch + 1,
319 |                 'meta': model.meta,
320 |                 'state_dict': model.state_dict(),
321 |                 'min_loss': min_loss,
322 |                 'optimizer' : optimizer.state_dict(),
323 |             }, args.directory)
324 |     log.close()
325 |     loss_log.close()
326 | 
327 | 
328 | def save_checkpoint(state, is_best, directory):
329 |     filename = os.path.join(directory, 'model_epoch%d.pth.tar' % state['epoch'])
330 |     torch.save(state, filename)
331 |     if is_best:
332 |         filename_best = os.path.join(directory, 'model_best.pth.tar')
333 |         shutil.copyfile(filename, filename_best)
334 | 
335 | 
336 | def save_checkpoint_best(state, directory):
337 |     filename_best = os.path.join(directory, 'model_best.pth.tar')
338 |     torch.save(state, filename_best)
339 | 
340 | 
341 | if __name__ == '__main__':
342 |     args = cli.parse_commandline_args()
343 |     main()
344 | 
--------------------------------------------------------------------------------
/lib/networks/imageretrievalnet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pdb
3 | 
4 | import torch
5 | import torch.nn as nn
6 | import torch.utils.model_zoo as model_zoo
7 | 
8 | import torchvision
9 | 
10 | import timm
11 | from efficientnet_pytorch import EfficientNet
12 | 
13 | from lib.layers.pooling import MAC, SPoC, GeM, GeMmp, RMAC, Rpool
14 | from lib.layers.normalization import L2N, PowerLaw
15 | from lib.datasets.genericdataset import ImagesFromList
16 | from lib.utils.general import get_data_root
17 | 
18 | # for some models, we have imported
features (convolutions) from caffe because the image retrieval performance is higher for them
19 | FEATURES = {
20 |     'vgg16'     : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/imagenet/imagenet-caffe-vgg16-features-d369c8e.pth',
21 |     'resnet50'  : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/imagenet/imagenet-caffe-resnet50-features-ac468af.pth',
22 |     'resnet101' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/imagenet/imagenet-caffe-resnet101-features-10a101d.pth',
23 |     'resnet152' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/imagenet/imagenet-caffe-resnet152-features-1011020.pth',
24 | }
25 | 
26 | # TODO: pre-compute for more architectures and properly test variations (pre l2norm, post l2norm)
27 | # pre-computed local pca whitening that can be applied before the pooling layer
28 | L_WHITENING = {
29 |     'resnet101' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-resnet101-lwhiten-9f830ef.pth', # no pre l2 norm
30 |     # 'resnet101' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-resnet101-lwhiten-da5c935.pth', # with pre l2 norm
31 | }
32 | 
33 | # possible global pooling layers, each one of these can be made regional
34 | POOLING = {
35 |     'mac'   : MAC,
36 |     'spoc'  : SPoC,
37 |     'gem'   : GeM,
38 |     'gemmp' : GeMmp,
39 |     'rmac'  : RMAC,
40 | }
41 | 
42 | # TODO: pre-compute for: resnet50-gem-r, resnet50-mac-r, vgg16-mac-r, alexnet-mac-r
43 | # pre-computed regional whitening, for most commonly used architectures and pooling methods
44 | R_WHITENING = {
45 |     'alexnet-gem-r'   : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-alexnet-gem-r-rwhiten-c8cf7e2.pth',
46 |     'vgg16-gem-r'     : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-vgg16-gem-r-rwhiten-19b204e.pth',
47 |     'resnet101-mac-r' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-resnet101-mac-r-rwhiten-7f1ed8c.pth',
48 |     'resnet101-gem-r' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-resnet101-gem-r-rwhiten-adace84.pth',
49 | }
50 | 
51 | # TODO: pre-compute for more architectures
52 | # pre-computed final (global) whitening, for most commonly used architectures and pooling methods
53 | WHITENING = {
54 |     'alexnet-gem'     : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-alexnet-gem-whiten-454ad53.pth',
55 |     'alexnet-gem-r'   : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-alexnet-gem-r-whiten-4c9126b.pth',
56 |     'vgg16-gem'       : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-vgg16-gem-whiten-eaa6695.pth',
57 |     'vgg16-gem-r'     : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-vgg16-gem-r-whiten-83582df.pth',
58 |     'resnet50-gem'    : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-resnet50-gem-whiten-f15da7b.pth',
59 |     'resnet101-mac-r' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-resnet101-mac-r-whiten-9df41d3.pth',
60 |     'resnet101-gem'   : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-resnet101-gem-whiten-22ab0c1.pth',
61 |     'resnet101-gem-r' :
'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-resnet101-gem-r-whiten-b379c0a.pth',
62 |     'resnet101-gemmp' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-resnet101-gemmp-whiten-770f53c.pth',
63 |     'resnet152-gem'   : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-resnet152-gem-whiten-abe7b93.pth',
64 |     'densenet121-gem' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-densenet121-gem-whiten-79e3eea.pth',
65 |     'densenet169-gem' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-densenet169-gem-whiten-6b2a76a.pth',
66 |     'densenet201-gem' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/whiten/retrieval-SfM-120k/retrieval-SfM-120k-densenet201-gem-whiten-22ea45c.pth',
67 | }
68 | 
69 | # output dimensionality for supported architectures
70 | OUTPUT_DIM = {
71 |     'alexnet'             : 256,
72 |     'vgg11'               : 512,
73 |     'vgg13'               : 512,
74 |     'vgg16'               : 512,
75 |     'vgg19'               : 512,
76 |     'resnet18'            : 512,
77 |     'resnet34'            : 512,
78 |     'resnet50'            : 2048,
79 |     'resnet101'           : 2048,
80 |     'resnet152'           : 2048,
81 |     'densenet121'         : 1024,
82 |     'densenet169'         : 1664,
83 |     'densenet201'         : 1920,
84 |     'densenet161'         : 2208, # largest densenet
85 |     'squeezenet1_0'       : 512,
86 |     'squeezenet1_1'       : 512,
87 |     'mobilenet_v2'        : 512,
88 |     'mobilenet_v3'        : 512,
89 |     'efficientnet_b3'     : 512,
90 |     'efficientnet_b3_new' : 512,
91 | }
92 | 
93 | 
94 | OUTPUT_DIM_RESNET = {
95 |     'resnet101'       : 2048,
96 |     'resnet152'       : 2048,
97 |     'densenet121'     : 1024,
98 |     'densenet169'     : 1664,
99 |     'densenet201'     : 1920,
100 |     'densenet161'     : 2208, # largest densenet
101 |     'mobilenet_v2'    : 2048,
102 |     'mobilenet_v3'    : 2048,
103 |     'efficientnet_b3' : 2048,
104 | }
105 | 
106 | class ImageRetrievalNet(nn.Module):
107 | 
108 |     def __init__(self, features, lwhiten, pool, whiten, meta):
109 |         super(ImageRetrievalNet, self).__init__()
110 |         self.features = nn.Sequential(*features)
111 |         self.lwhiten = lwhiten
112 |         self.pool = pool
113 |         self.whiten = whiten
114 |         self.norm = L2N()
115 |         self.meta = meta
116 | 
117 |     def forward(self, x):
118 |         # x -> features
119 |         o = self.features(x)
120 | 
121 |         # TODO: properly test (with pre-l2norm and/or post-l2norm)
122 |         # if lwhiten exists: features -> local whiten
123 |         if self.lwhiten is not None:
124 |             # o = self.norm(o)
125 |             s = o.size()
126 |             o = o.permute(0,2,3,1).contiguous().view(-1, s[1])
127 |             o = self.lwhiten(o)
128 |             o = o.view(s[0],s[2],s[3],self.lwhiten.out_features).permute(0,3,1,2)
129 |             # o = self.norm(o)
130 | 
131 |         # features -> pool -> norm
132 |         o = self.norm(self.pool(o)).squeeze(-1).squeeze(-1)
133 | 
134 |         # if whiten exists: pooled features -> whiten -> norm
135 |         if self.whiten is not None:
136 |             o = self.norm(self.whiten(o))
137 | 
138 |         # permute so that it is Dx1 column vector per image (DxN if many images)
139 |         return o.permute(1,0)
140 | 
141 |     def __repr__(self):
142 |         tmpstr = super(ImageRetrievalNet, self).__repr__()[:-1]
143 |         tmpstr += self.meta_repr()
144 |         tmpstr = tmpstr + ')'
145 |         return tmpstr
146 | 
147 |     def meta_repr(self):
148 |         tmpstr = '  (' + 'meta' + '): dict( \n' # + self.meta.__repr__() + '\n'
149 |         tmpstr += '     architecture: {}\n'.format(self.meta['architecture'])
150 |         tmpstr += '     local_whitening: {}\n'.format(self.meta['local_whitening'])
151 |         tmpstr += '     pooling: {}\n'.format(self.meta['pooling'])
152 |         tmpstr += '     regional:
{}\n'.format(self.meta['regional']) 153 | tmpstr += ' whitening: {}\n'.format(self.meta['whitening']) 154 | tmpstr += ' outputdim: {}\n'.format(self.meta['outputdim']) 155 | tmpstr += ' mean: {}\n'.format(self.meta['mean']) 156 | tmpstr += ' std: {}\n'.format(self.meta['std']) 157 | tmpstr = tmpstr + ' )\n' 158 | return tmpstr 159 | 160 | 161 | def init_network(params): 162 | 163 | # parse params with default values 164 | architecture = params.get('architecture', 'vgg16') 165 | local_whitening = params.get('local_whitening', False) 166 | pooling = params.get('pooling', 'gem') 167 | regional = params.get('regional', False) 168 | whitening = params.get('whitening', False) 169 | mean = params.get('mean', [0.485, 0.456, 0.406]) 170 | std = params.get('std', [0.229, 0.224, 0.225]) 171 | pretrained = params.get('pretrained', True) 172 | teacher = params.get('teacher') 173 | # get output dimensionality size 174 | if teacher == 'resnet101': 175 | dim = OUTPUT_DIM_RESNET[architecture] 176 | else: 177 | dim = OUTPUT_DIM[architecture] 178 | 179 | # loading network from torchvision 180 | if pretrained: 181 | if architecture not in FEATURES and architecture != 'mobilenet_v3' and architecture != 'efficientnet_b3' and architecture != 'efficientnet_b3_new': 182 | # initialize with network pretrained on imagenet in pytorch 183 | net_in = getattr(torchvision.models, architecture)(pretrained=True) 184 | elif architecture == 'mobilenet_v3': 185 | net_in = timm.create_model('mobilenetv3_100', num_classes = 1000, in_chans = 3, pretrained=True, checkpoint_path='') 186 | elif architecture == 'efficientnet_b3': 187 | net_in = timm.create_model('tf_efficientnet_b3', num_classes = 1000, in_chans = 3, pretrained=True, checkpoint_path='') 188 | elif architecture == 'efficientnet_b3_new': 189 | net_in = EfficientNet.from_pretrained('efficientnet-b3') 190 | else: 191 | # initialize with random weights, later on we will fill features with custom pretrained network 192 | net_in = getattr(torchvision.models, architecture)(pretrained=False) 193 | else: 194 | # initialize with random weights 195 | if architecture != 'mobilenet_v3' and architecture != 'efficientnet_b3': 196 | net_in = getattr(torchvision.models, architecture)(pretrained=False) 197 | elif architecture == 'efficientnet_b3': 198 | net_in = timm.create_model('tf_efficientnet_b3', num_classes = 1000, in_chans = 3, pretrained=False, checkpoint_path='') 199 | else: 200 | net_in = timm.create_model('mobilenetv3_100', num_classes = 1000, in_chans = 3, pretrained=False, checkpoint_path='') 201 | #pdb.set_trace() 202 | 203 | # initialize features 204 | # take only convolutions for features, 205 | # always ends with ReLU to make last activations non-negative 206 | if architecture.startswith('alexnet'): 207 | features = list(net_in.features.children())[:-1] 208 | elif architecture.startswith('vgg'): 209 | features = list(net_in.features.children())[:-1] 210 | elif architecture.startswith('resnet'): 211 | features = list(net_in.children())[:-2] 212 | elif architecture.startswith('densenet'): 213 | features = list(net_in.features.children()) 214 | features.append(nn.ReLU(inplace=True)) 215 | elif architecture.startswith('squeezenet'): 216 | features = list(net_in.features.children()) 217 | elif architecture.startswith('mobilenet_v2'): 218 | features = net_in.features 219 | features = nn.Sequential(features, nn.Conv2d(1280, dim, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda()) 220 | features = list(features.children()) 221 | features.append(nn.ReLU(inplace=True)) 222 | 
elif architecture.startswith('mobilenet_v3'): 223 | features = list(net_in.children())[:-2] 224 | features.append(nn.Conv2d(1280, dim, kernel_size=(1, 1), stride=(1, 1), bias=False)) 225 | features.append(nn.ReLU(inplace=True)) 226 | elif architecture == 'efficientnet_b3': 227 | features = list(net_in.children())[:-2] 228 | features.append(nn.Conv2d(1536, dim, kernel_size=(1, 1), stride=(1, 1), bias=False)) 229 | features.append(nn.ReLU(inplace=True)) 230 | elif architecture == 'efficientnet_b3_new': 231 | features = list(net_in.children())[:-4] 232 | features.append(nn.Conv2d(1536, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)) 233 | features.append(nn.ReLU(inplace=True)) 234 | else: 235 | raise ValueError('Unsupported or unknown architecture: {}!'.format(architecture)) 236 | 237 | # initialize local whitening 238 | if local_whitening: 239 | lwhiten = nn.Linear(dim, dim, bias=True) 240 | # TODO: lwhiten with possible dimensionality reduce 241 | 242 | if pretrained: 243 | lw = architecture 244 | if lw in L_WHITENING: 245 | print(">> {}: for '{}' custom computed local whitening '{}' is used" 246 | .format(os.path.basename(__file__), lw, os.path.basename(L_WHITENING[lw]))) 247 | whiten_dir = os.path.join(get_data_root(), 'whiten') 248 | lwhiten.load_state_dict(model_zoo.load_url(L_WHITENING[lw], model_dir=whiten_dir)) 249 | else: 250 | print(">> {}: for '{}' there is no local whitening computed, random weights are used" 251 | .format(os.path.basename(__file__), lw)) 252 | 253 | else: 254 | lwhiten = None 255 | 256 | # initialize pooling 257 | if pooling == 'gemmp': 258 | pool = POOLING[pooling](mp=dim) 259 | else: 260 | pool = POOLING[pooling]() 261 | 262 | # initialize regional pooling 263 | if regional: 264 | rpool = pool 265 | rwhiten = nn.Linear(dim, dim, bias=True) 266 | # TODO: rwhiten with possible dimensionality reduce 267 | 268 | if pretrained: 269 | rw = '{}-{}-r'.format(architecture, pooling) 270 | if rw in R_WHITENING: 271 | print(">> {}: for '{}' custom computed regional whitening '{}' is used" 272 | .format(os.path.basename(__file__), rw, os.path.basename(R_WHITENING[rw]))) 273 | whiten_dir = os.path.join(get_data_root(), 'whiten') 274 | rwhiten.load_state_dict(model_zoo.load_url(R_WHITENING[rw], model_dir=whiten_dir)) 275 | else: 276 | print(">> {}: for '{}' there is no regional whitening computed, random weights are used" 277 | .format(os.path.basename(__file__), rw)) 278 | 279 | pool = Rpool(rpool, rwhiten) 280 | 281 | # initialize whitening 282 | if whitening: 283 | whiten = nn.Linear(dim, dim, bias=True) 284 | # TODO: whiten with possible dimensionality reduce 285 | 286 | if pretrained: 287 | w = architecture 288 | if local_whitening: 289 | w += '-lw' 290 | w += '-' + pooling 291 | if regional: 292 | w += '-r' 293 | if w in WHITENING: 294 | print(">> {}: for '{}' custom computed whitening '{}' is used" 295 | .format(os.path.basename(__file__), w, os.path.basename(WHITENING[w]))) 296 | whiten_dir = os.path.join(get_data_root(), 'whiten') 297 | whiten.load_state_dict(model_zoo.load_url(WHITENING[w], model_dir=whiten_dir)) 298 | else: 299 | print(">> {}: for '{}' there is no whitening computed, random weights are used" 300 | .format(os.path.basename(__file__), w)) 301 | else: 302 | whiten = None 303 | 304 | # create meta information to be stored in the network 305 | meta = { 306 | 'architecture' : architecture, 307 | 'local_whitening' : local_whitening, 308 | 'pooling' : pooling, 309 | 'regional' : regional, 310 | 'whitening' : whitening, 311 | 'mean' : mean, 312 | 
'std' : std, 313 | 'outputdim' : dim, 314 | } 315 | 316 | # create a generic image retrieval network 317 | net = ImageRetrievalNet(features, lwhiten, pool, whiten, meta) 318 | 319 | # initialize features with custom pretrained network if needed 320 | if pretrained and architecture in FEATURES: 321 | print(">> {}: for '{}' custom pretrained features '{}' are used" 322 | .format(os.path.basename(__file__), architecture, os.path.basename(FEATURES[architecture]))) 323 | model_dir = os.path.join(get_data_root(), 'networks') 324 | net.features.load_state_dict(model_zoo.load_url(FEATURES[architecture], model_dir=model_dir)) 325 | 326 | return net 327 | 328 | 329 | def extract_vectors(net, images, image_size, transform, bbxs=None, ms=[1], msp=1, print_freq=10, workers=8): 330 | # moving network to gpu and eval mode 331 | net.cuda() 332 | net.eval() 333 | 334 | # creating dataset loader 335 | loader = torch.utils.data.DataLoader( 336 | ImagesFromList(root='', images=images, imsize=image_size, bbxs=bbxs, transform=transform), 337 | batch_size=1, shuffle=False, num_workers=workers, pin_memory=True 338 | ) 339 | 340 | # extracting vectors 341 | with torch.no_grad(): 342 | vecs = torch.zeros(net.meta['outputdim'], len(images)) 343 | for i, input in enumerate(loader): 344 | input = input.cuda() 345 | 346 | if len(ms) == 1 and ms[0] == 1: 347 | vecs[:, i] = extract_ss(net, input) 348 | else: 349 | vecs[:, i] = extract_ms(net, input, ms, msp) 350 | 351 | if (i+1) % print_freq == 0 or (i+1) == len(images): 352 | print('\r>>>> {}/{} done...'.format((i+1), len(images)), end='') 353 | print('') 354 | 355 | return vecs 356 | 357 | def extract_ss(net, input): 358 | return net(input).cpu().data.squeeze() 359 | 360 | def extract_ms(net, input, ms, msp): 361 | 362 | v = torch.zeros(net.meta['outputdim']) 363 | 364 | for s in ms: 365 | if s == 1: 366 | input_t = input.clone() 367 | else: 368 | input_t = nn.functional.interpolate(input, scale_factor=s, mode='bilinear', align_corners=False) 369 | v += net(input_t).pow(msp).cpu().data.squeeze() 370 | 371 | v /= len(ms) 372 | v = v.pow(1./msp) 373 | v /= v.norm() 374 | 375 | return v 376 | 377 | 378 | def extract_regional_vectors(net, images, image_size, transform, bbxs=None, ms=[1], msp=1, print_freq=10): 379 | # moving network to gpu and eval mode 380 | net.cuda() 381 | net.eval() 382 | 383 | # creating dataset loader 384 | loader = torch.utils.data.DataLoader( 385 | ImagesFromList(root='', images=images, imsize=image_size, bbxs=bbxs, transform=transform), 386 | batch_size=1, shuffle=False, num_workers=8, pin_memory=True 387 | ) 388 | 389 | # extracting vectors 390 | with torch.no_grad(): 391 | vecs = [] 392 | for i, input in enumerate(loader): 393 | input = input.cuda() 394 | 395 | if len(ms) == 1: 396 | vecs.append(extract_ssr(net, input)) 397 | else: 398 | # TODO: not implemented yet 399 | # vecs.append(extract_msr(net, input, ms, msp)) 400 | raise NotImplementedError 401 | 402 | if (i+1) % print_freq == 0 or (i+1) == len(images): 403 | print('\r>>>> {}/{} done...'.format((i+1), len(images)), end='') 404 | print('') 405 | 406 | return vecs 407 | 408 | def extract_ssr(net, input): 409 | return net.pool(net.features(input), aggregate=False).squeeze(0).squeeze(-1).squeeze(-1).permute(1,0).cpu().data 410 | 411 | 412 | def extract_local_vectors(net, images, image_size, transform, bbxs=None, ms=[1], msp=1, print_freq=10): 413 | # moving network to gpu and eval mode 414 | net.cuda() 415 | net.eval() 416 | 417 | # creating dataset loader 418 | loader = 
torch.utils.data.DataLoader(
419 |         ImagesFromList(root='', images=images, imsize=image_size, bbxs=bbxs, transform=transform),
420 |         batch_size=1, shuffle=False, num_workers=8, pin_memory=True
421 |     )
422 | 
423 |     # extracting vectors
424 |     with torch.no_grad():
425 |         vecs = []
426 |         for i, input in enumerate(loader):
427 |             input = input.cuda()
428 | 
429 |             if len(ms) == 1:
430 |                 vecs.append(extract_ssl(net, input))
431 |             else:
432 |                 # TODO: not implemented yet
433 |                 # vecs.append(extract_msl(net, input, ms, msp))
434 |                 raise NotImplementedError
435 | 
436 |             if (i+1) % print_freq == 0 or (i+1) == len(images):
437 |                 print('\r>>>> {}/{} done...'.format((i+1), len(images)), end='')
438 |         print('')
439 | 
440 |     return vecs
441 | 
442 | def extract_ssl(net, input):
443 |     return net.norm(net.features(input)).squeeze(0).view(net.meta['outputdim'], -1).cpu().data
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import time
4 | import pickle
5 | 
6 | 
7 | import numpy as np
8 | 
9 | import torch
10 | from torch.utils.model_zoo import load_url
11 | from torchvision import transforms
12 | 
13 | from lib.networks.imageretrievalnet import init_network, extract_vectors
14 | from lib.datasets.datahelpers import cid2filename
15 | from lib.datasets.testdataset import configdataset
16 | from lib.utils.download import download_train, download_test
17 | from lib.utils.whiten import whitenlearn, whitenapply
18 | from lib.utils.evaluate import compute_map_and_print
19 | from lib.utils.general import get_data_root, htime
20 | 
21 | PRETRAINED = {
22 |     'retrievalSfM120k-vgg16-gem'     : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/retrieval-SfM-120k/retrievalSfM120k-vgg16-gem-b4dcdc6.pth',
23 |     'retrievalSfM120k-resnet101-gem' : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/retrieval-SfM-120k/retrievalSfM120k-resnet101-gem-b80fb85.pth',
24 |     # new networks with whitening learned end-to-end
25 |     'rSfM120k-tl-resnet50-gem-w'     : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/retrieval-SfM-120k/rSfM120k-tl-resnet50-gem-w-97bf910.pth',
26 |     'rSfM120k-tl-resnet101-gem-w'    : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/retrieval-SfM-120k/rSfM120k-tl-resnet101-gem-w-a155e54.pth',
27 |     'rSfM120k-tl-resnet152-gem-w'    : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/retrieval-SfM-120k/rSfM120k-tl-resnet152-gem-w-f39cada.pth',
28 |     'gl18-tl-resnet50-gem-w'         : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/gl18/gl18-tl-resnet50-gem-w-83fdc30.pth',
29 |     'gl18-tl-resnet101-gem-w'        : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/gl18/gl18-tl-resnet101-gem-w-a4d43db.pth',
30 |     'gl18-tl-resnet152-gem-w'        : 'http://cmp.felk.cvut.cz/cnnimageretrieval/data/networks/gl18/gl18-tl-resnet152-gem-w-21278d5.pth',
31 |     # pretrained student models without teacher:
32 |     'efficientnet-b3-gem-contr_2048' : 'http://files.inria.fr/aml/efficientnet_b3_gem_contr_2048.pth.tar',
33 |     'efficientnet-b3-gem-contr_512'  : 'http://files.inria.fr/aml/efficientnet_b3_gem_contr_512.pth.tar',
34 |     'mobilenet-v2-gem-contr-2048'    : 'http://files.inria.fr/aml/mobilenet_v2_gem_contr_2048.pth.tar',
35 |     'mobilenet-v2-gem-contr-512'     : 'http://files.inria.fr/aml/mobilenet_v2_gem_contr_512.pth.tar',
36 |     # pretrained student models with teacher:
37 |     'efficientnet-b3-gem-contr-plus-resnet101' : 'http://files.inria.fr/aml/efficientnet_b3_gem_contr_plus_resnet101.pth.tar',
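    # (naming convention, inferred from the checkpoint keys themselves rather than stated
    # anywhere in this file: student models are named '<student>-gem-<method>[-<teacher>]',
    # where the method tag -- contr, contr-plus, reg, dark, rkd, ms, triplet -- identifies
    # the objective used to train the student)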
38 |     'efficientnet-b3-gem-contr-plus-vgg16'  : 'http://files.inria.fr/aml/efficientnet_b3_gem_contr_plus_vgg16.pth.tar',
39 |     'efficientnet-b3-gem-reg-resnet101'     : 'http://files.inria.fr/aml/efficientnet_b3_gem_reg_resnet101.pth.tar',
40 |     'efficientnet-b3-gem-reg-vgg16'         : 'http://files.inria.fr/aml/efficientnet_b3_gem_reg_vgg16.pth.tar',
41 |     'efficientnet-b3-gem-contr-resnet101'   : 'http://files.inria.fr/aml/efficientnet_b3_gem_contr_resnet101.pth.tar',
42 |     'efficientnet-b3-gem-contr-vgg16'       : 'http://files.inria.fr/aml/efficientnet_b3_gem_contr_vgg16.pth.tar',
43 |     'efficientnet-b3-gem-dark-resnet101'    : 'http://files.inria.fr/aml/efficientnet_b3_gem_dark_resnet101.pth.tar',
44 |     'efficientnet-b3-gem-dark-vgg16'        : 'http://files.inria.fr/aml/efficientnet_b3_gem_dark_vgg16.pth.tar',
45 |     'efficientnet-b3-gem-ms-resnet101'      : 'http://files.inria.fr/aml/efficientnet_b3_gem_ms_resnet101.pth.tar',
46 |     'efficientnet-b3-gem-ms-vgg16'          : 'http://files.inria.fr/aml/efficientnet_b3_gem_ms_vgg16.pth.tar',
47 |     'efficientnet-b3-gem-rkd-resnet101'     : 'http://files.inria.fr/aml/efficientnet_b3_gem_rkd_resnet101.pth.tar',
48 |     'efficientnet-b3-gem-rkd-vgg16'         : 'http://files.inria.fr/aml/efficientnet_b3_gem_rkd_vgg16.pth.tar',
49 |     'efficientnet-b3-gem-triplet-resnet101' : 'http://files.inria.fr/aml/efficientnet_b3_gem_triplet_resnet101.pth.tar',
50 |     'efficientnet-b3-gem-triplet-vgg16'     : 'http://files.inria.fr/aml/efficientnet_b3_gem_triplet_vgg16.pth.tar',
51 |     'mobilenet-v2-gem-contr-plus-resnet101' : 'http://files.inria.fr/aml/mobilenet_v2_gem_contr_plus_resnet101.pth.tar',
52 |     'mobilenet-v2-gem-contr-plus-vgg16'     : 'http://files.inria.fr/aml/mobilenet_v2_gem_contr_plus_vgg16.pth.tar',
53 |     'mobilenet-v2-gem-reg-resnet101'        : 'http://files.inria.fr/aml/mobilenet_v2_gem_reg_resnet101.pth.tar',
54 |     'mobilenet-v2-gem-reg-vgg16'            : 'http://files.inria.fr/aml/mobilenet_v2_gem_reg_vgg16.pth.tar',
55 |     'mobilenet-v2-gem-contr-resnet101'      : 'http://files.inria.fr/aml/mobilenet_v2_gem_contr_resnet101.pth.tar',
56 |     'mobilenet-v2-gem-contr-vgg16'          : 'http://files.inria.fr/aml/mobilenet_v2_gem_contr_vgg16.pth.tar',
57 |     'mobilenet-v2-gem-dark-resnet101'       : 'http://files.inria.fr/aml/mobilenet_v2_gem_dark_resnet101.pth.tar',
58 |     'mobilenet-v2-gem-dark-vgg16'           : 'http://files.inria.fr/aml/mobilenet_v2_gem_dark_vgg16.pth.tar',
59 |     'mobilenet-v2-gem-ms-resnet101'         : 'http://files.inria.fr/aml/mobilenet_v2_gem_ms_resnet101.pth.tar',
60 |     'mobilenet-v2-gem-ms-vgg16'             : 'http://files.inria.fr/aml/mobilenet_v2_gem_ms_vgg16.pth.tar',
61 |     'mobilenet-v2-gem-rkd-resnet101'        : 'http://files.inria.fr/aml/mobilenet_v2_gem_rkd_resnet101.pth.tar',
62 |     'mobilenet-v2-gem-rkd-vgg16'            : 'http://files.inria.fr/aml/mobilenet_v2_gem_rkd_vgg16.pth.tar',
63 |     'mobilenet-v2-gem-triplet-resnet101'    : 'http://files.inria.fr/aml/mobilenet_v2_gem_triplet_resnet101.pth.tar',
64 |     'mobilenet-v2-gem-triplet-vgg16'        : 'http://files.inria.fr/aml/mobilenet_v2_gem_triplet_vgg16.pth.tar'
65 | }
66 | 
67 | PRETRAINED_WHITENING = {
68 |     # pretrained whitening for student models without teacher:
69 |     'efficientnet-b3-gem-contr_2048' : 'http://files.inria.fr/aml/efficientnet_b3_gem_contr_2048_whitening.pth',
70 |     'efficientnet-b3-gem-contr_512'  : 'http://files.inria.fr/aml/efficientnet_b3_gem_contr_512_whitening.pth',
71 |     'mobilenet-v2-gem-contr-2048'    : 'http://files.inria.fr/aml/mobilenet_v2_gem_contr_2048_whitening.pth',
72 |     'mobilenet-v2-gem-contr-512'     : 'http://files.inria.fr/aml/mobilenet_v2_gem_contr_512_whitening.pth',
73 |     # pretrained whitening for student models with teacher:
74 |     'efficientnet-b3-gem-contr-plus-resnet101' : 'http://files.inria.fr/aml/efficientnet_b3_gem_contr_plus_resnet101_whitening.pth',
75 |     'efficientnet-b3-gem-contr-plus-vgg16'     : 'http://files.inria.fr/aml/efficientnet_b3_gem_contr_plus_vgg16_whitening.pth',
76 |     'efficientnet-b3-gem-reg-resnet101'        : 'http://files.inria.fr/aml/efficientnet_b3_gem_reg_resnet101_whitening.pth',
77 |     'efficientnet-b3-gem-reg-vgg16'            : 'http://files.inria.fr/aml/efficientnet_b3_gem_reg_vgg16_whitening.pth',
78 |     'efficientnet-b3-gem-contr-resnet101'      : 'http://files.inria.fr/aml/efficientnet_b3_gem_contr_resnet101_whitening.pth',
79 |     'efficientnet-b3-gem-contr-vgg16'          : 'http://files.inria.fr/aml/efficientnet_b3_gem_contr_vgg16_whitening.pth',
80 |     'efficientnet-b3-gem-dark-resnet101'       : 'http://files.inria.fr/aml/efficientnet_b3_gem_dark_resnet101_whitening.pth',
81 |     'efficientnet-b3-gem-dark-vgg16'           : 'http://files.inria.fr/aml/efficientnet_b3_gem_dark_vgg16_whitening.pth',
82 |     'efficientnet-b3-gem-ms-resnet101'         : 'http://files.inria.fr/aml/efficientnet_b3_gem_ms_resnet101_whitening.pth',
83 |     'efficientnet-b3-gem-ms-vgg16'             : 'http://files.inria.fr/aml/efficientnet_b3_gem_ms_vgg16_whitening.pth',
84 |     'efficientnet-b3-gem-rkd-resnet101'        : 'http://files.inria.fr/aml/efficientnet_b3_gem_rkd_resnet101_whitening.pth',
85 |     'efficientnet-b3-gem-rkd-vgg16'            : 'http://files.inria.fr/aml/efficientnet_b3_gem_rkd_vgg16_whitening.pth',
86 |     'efficientnet-b3-gem-triplet-resnet101'    : 'http://files.inria.fr/aml/efficientnet_b3_gem_triplet_resnet101_whitening.pth',
87 |     'efficientnet-b3-gem-triplet-vgg16'        : 'http://files.inria.fr/aml/efficientnet_b3_gem_triplet_vgg16_whitening.pth',
88 |     'mobilenet-v2-gem-contr-plus-resnet101'    : 'http://files.inria.fr/aml/mobilenet_v2_gem_contr_plus_resnet101_whitening.pth',
89 |     'mobilenet-v2-gem-contr-plus-vgg16'        : 'http://files.inria.fr/aml/mobilenet_v2_gem_contr_plus_vgg16_whitening.pth',
90 |     'mobilenet-v2-gem-reg-resnet101'           : 'http://files.inria.fr/aml/mobilenet_v2_gem_reg_resnet101_whitening.pth',
91 |     'mobilenet-v2-gem-reg-vgg16'               : 'http://files.inria.fr/aml/mobilenet_v2_gem_reg_vgg16_whitening.pth',
92 |     'mobilenet-v2-gem-contr-resnet101'         : 'http://files.inria.fr/aml/mobilenet_v2_gem_contr_resnet101_whitening.pth',
93 |     'mobilenet-v2-gem-contr-vgg16'             : 'http://files.inria.fr/aml/mobilenet_v2_gem_contr_vgg16_whitening.pth',
94 |     'mobilenet-v2-gem-dark-resnet101'          : 'http://files.inria.fr/aml/mobilenet_v2_gem_dark_resnet101_whitening.pth',
95 |     'mobilenet-v2-gem-dark-vgg16'              : 'http://files.inria.fr/aml/mobilenet_v2_gem_dark_vgg16_whitening.pth',
96 |     'mobilenet-v2-gem-ms-resnet101'            : 'http://files.inria.fr/aml/mobilenet_v2_gem_ms_resnet101_whitening.pth',
97 |     'mobilenet-v2-gem-ms-vgg16'                : 'http://files.inria.fr/aml/mobilenet_v2_gem_ms_vgg16_whitening.pth',
98 |     'mobilenet-v2-gem-rkd-resnet101'           : 'http://files.inria.fr/aml/mobilenet_v2_gem_rkd_resnet101_whitening.pth',
99 |     'mobilenet-v2-gem-rkd-vgg16'               : 'http://files.inria.fr/aml/mobilenet_v2_gem_rkd_vgg16_whitening.pth',
100 |     'mobilenet-v2-gem-triplet-resnet101'       : 'http://files.inria.fr/aml/mobilenet_v2_gem_triplet_resnet101_whitening.pth',
101 |     'mobilenet-v2-gem-triplet-vgg16'           : 'http://files.inria.fr/aml/mobilenet_v2_gem_triplet_vgg16_whitening.pth'
102 | }
103 | 
104 | datasets_names = ['oxford5k', 'paris6k', 'roxford5k', 'rparis6k', 'retrieval-SfM-120k', 'instre']
105 | whitening_names = ['retrieval-SfM-30k', 'retrieval-SfM-120k']
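# Example invocation (the flags are defined on the parser below; the network name must be a
# key of PRETRAINED, and test datasets are downloaded automatically on first use):
#
#   python test.py -npath mobilenet-v2-gem-contr-vgg16 -d 'roxford5k,rparis6k' \
#                  -w retrieval-SfM-120k -ms '[1, 1/2**(1/2), 1/2]'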
106 | 
107 | parser = argparse.ArgumentParser(description='PyTorch CNN Image Retrieval Testing')
108 | 
109 | # network
110 | group = parser.add_mutually_exclusive_group(required=True)
111 | group.add_argument('--network-path', '-npath', metavar='NETWORK',
112 |                    help="pretrained network or network path (destination where network is saved)")
113 | group.add_argument('--network-offtheshelf', '-noff', metavar='NETWORK',
114 |                    help="off-the-shelf network, in the format 'ARCHITECTURE-POOLING' or 'ARCHITECTURE-POOLING-{reg-lwhiten-whiten}'," +
115 |                         " examples: 'resnet101-gem' | 'resnet101-gem-reg' | 'resnet101-gem-whiten' | 'resnet101-gem-lwhiten' | 'resnet101-gem-reg-whiten'")
116 | 
117 | # test options
118 | parser.add_argument('--datasets', '-d', metavar='DATASETS', default='roxford5k,rparis6k',
119 |                     help="comma separated list of test datasets: " +
120 |                          " | ".join(datasets_names) +
121 |                          " (default: 'roxford5k,rparis6k')")
122 | parser.add_argument('--image-size', '-imsize', default=1024, type=int, metavar='N',
123 |                     help="maximum size of longer image side used for testing (default: 1024)")
124 | parser.add_argument('--multiscale', '-ms', metavar='MULTISCALE', default='[1]',
125 |                     help="use multiscale vectors for testing, " +
126 |                          " examples: '[1]' | '[1, 1/2**(1/2), 1/2]' | '[1, 2**(1/2), 1/2**(1/2)]' (default: '[1]')")
127 | parser.add_argument('--whitening', '-w', metavar='WHITENING', default=None, choices=whitening_names,
128 |                     help="dataset used to learn whitening for testing: " +
129 |                          " | ".join(whitening_names) +
130 |                          " (default: None)")
131 | parser.add_argument('--workers', '-j', default=8, type=int, metavar='N',
132 |                     help='number of data loading workers (default: 8)')
133 | # GPU ID
134 | parser.add_argument('--gpu-id', '-g', default='0', metavar='N',
135 |                     help="gpu id used for testing (default: '0')")
136 | parser.add_argument('--teacher', '-t', default='vgg16', metavar='TEACHER',
137 |                     help="the teacher network used for training the student model")
138 | parser.add_argument('--asym', dest='asym', action='store_true',
139 |                     help='run asymmetric testing (teacher database, student queries); testing is symmetric by default')
140 | 
141 | def main():
142 |     args = parser.parse_args()
143 | 
144 |     # check if there are unknown datasets
145 |     for dataset in args.datasets.split(','):
146 |         if dataset not in datasets_names:
147 |             raise ValueError('Unsupported or unknown dataset: {}!'.format(dataset))
148 | 
149 |     # check if test datasets are downloaded,
150 |     # and download them if they are not
151 |     data_root = get_data_root()
152 |     #download_train(get_data_root())
153 |     download_test(data_root)
154 |     model_path = os.path.join(data_root, 'model')
155 |     # setting up the visible GPU
156 |     os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id
157 |     if args.asym:
158 |         t_name = 'retrievalSfM120k-%s-gem' % (args.teacher)
159 |         state = load_url(PRETRAINED[t_name], model_dir=model_path)
160 |         net_params_teacher = {}
161 |         net_params_teacher['architecture'] = state['meta']['architecture']
162 |         net_params_teacher['pooling'] = state['meta']['pooling']
163 |         net_params_teacher['local_whitening'] = state['meta'].get('local_whitening', False)
164 |         net_params_teacher['regional'] = state['meta'].get('regional', False)
165 |         net_params_teacher['whitening'] = state['meta'].get('whitening', False)
166 |         net_params_teacher['mean'] = state['meta']['mean']
167 |         net_params_teacher['std'] = state['meta']['std']
168 |         net_params_teacher['pretrained'] = True
169 |         net_teacher = init_network(net_params_teacher)
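        # With --asym, evaluation is asymmetric: database vectors are extracted with this
        # fixed teacher while query vectors come from the student, so each ranking score in
        # the test loop below is a dot product across the two embedding spaces,
        # score(q, x) = <student(q), teacher(x)>.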
170 |         net_teacher.load_state_dict(state['state_dict'])
171 |         if 'Lw' in state['meta']:
172 |             net_teacher.meta['Lw'] = state['meta']['Lw']
173 | 
174 |     # loading network from path
175 |     if args.network_path is not None:
176 | 
177 |         print(">> Loading network:\n>>>> '{}'".format(args.network_path))
178 |         if args.network_path in PRETRAINED:
179 |             # pretrained networks (downloaded automatically)
180 |             #state = load_url(PRETRAINED[args.network_path], model_dir=os.path.join(get_data_root(), 'networks'))
181 |             state = load_url(PRETRAINED[args.network_path], model_dir=model_path)
182 |         else:
183 |             # fine-tuned network from path
184 |             state = torch.load(args.network_path)
185 | 
186 |         # parsing net params from meta:
187 |         # architecture, pooling, mean, std are required;
188 |         # the rest have default values in case they are missing
189 |         net_params = {}
190 |         net_params['architecture'] = state['meta']['architecture']
191 |         net_params['pooling'] = state['meta']['pooling']
192 |         net_params['local_whitening'] = state['meta'].get('local_whitening', False)
193 |         net_params['regional'] = state['meta'].get('regional', False)
194 |         net_params['whitening'] = state['meta'].get('whitening', False)
195 |         net_params['mean'] = state['meta']['mean']
196 |         net_params['std'] = state['meta']['std']
197 |         net_params['pretrained'] = False
198 |         if args.teacher == 'resnet101':
199 |             net_params['teacher'] = 'resnet101'
200 |         else:
201 |             net_params['teacher'] = 'vgg16'
202 | 
203 |         # load network
204 | 
205 |         net = init_network(net_params)
206 |         #pdb.set_trace()
207 |         net.load_state_dict(state['state_dict'])
208 | 
209 |         # if whitening is precomputed
210 |         if 'Lw' in state['meta']:
211 |             net.meta['Lw'] = state['meta']['Lw']
212 | 
213 |         print(">>>> loaded network: ")
214 |         print(net.meta_repr())
215 | 
216 |     # loading off-the-shelf network
217 |     elif args.network_offtheshelf is not None:
218 | 
219 |         # parse off-the-shelf parameters
220 |         offtheshelf = args.network_offtheshelf.split('-')
221 |         net_params = {}
222 |         net_params['architecture'] = offtheshelf[0]
223 |         net_params['pooling'] = offtheshelf[1]
224 |         net_params['local_whitening'] = 'lwhiten' in offtheshelf[2:]
225 |         net_params['regional'] = 'reg' in offtheshelf[2:]
226 |         net_params['whitening'] = 'whiten' in offtheshelf[2:]
227 |         net_params['pretrained'] = True
228 | 
229 |         # load off-the-shelf network
230 |         print(">> Loading off-the-shelf network:\n>>>> '{}'".format(args.network_offtheshelf))
231 |         net = init_network(net_params)
232 |         print(">>>> loaded network: ")
233 |         print(net.meta_repr())
234 | 
235 |     # setting up the multi-scale parameters
236 |     ms = list(eval(args.multiscale))
237 |     if len(ms)>1 and net.meta['pooling'] == 'gem' and not net.meta['regional'] and not net.meta['whitening']:
238 |         msp = net.pool.p.item()
239 |         print(">> Set-up multiscale:")
240 |         print(">>>> ms: {}".format(ms))
241 |         print(">>>> msp: {}".format(msp))
242 |     else:
243 |         msp = 1
244 | 
245 |     # moving network to gpu and eval mode
246 |     net.cuda()
247 |     net.eval()
248 | 
249 |     # set up the transform
250 |     normalize = transforms.Normalize(
251 |         mean=net.meta['mean'],
252 |         std=net.meta['std']
253 |     )
254 |     transform = transforms.Compose([
255 |         transforms.ToTensor(),
256 |         normalize
257 |     ])
258 |     if args.asym:
259 |         log_out = './log_test_asym.txt'
260 |     else:
261 |         log_out = './log_test.txt'
262 |     log = open(log_out,'a')
263 | 
264 |     # compute whitening
265 |     if args.whitening is not None:
266 |         start = time.time()
267 |         if args.asym:
268 |             net_wh = net_teacher
269 |         else:
270 |             net_wh = net
271 | 
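        # The learned whitening Lw stores a mean m and a projection P; whitenapply
        # (lib/utils/whiten.py) essentially computes
        #   x_w = P (x - m) / || P (x - m) ||_2
        # for every descriptor column -- a paraphrase of the usual PCA-whitening transform;
        # see whitenlearn/whitenapply for the exact computation.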
272 |         if 'Lw' in net_wh.meta and args.whitening in net_wh.meta['Lw']:
273 |             print('>> {}: Whitening is precomputed, loading it...'.format(args.whitening))
274 |             if len(ms)>1:
275 |                 Lw = net_wh.meta['Lw'][args.whitening]['ms']
276 |             else:
277 |                 Lw = net_wh.meta['Lw'][args.whitening]['ss']
278 | 
279 |         else:
280 |             # if we evaluate networks from path, we should save/load whitening
281 |             # so that we do not recompute it every time
282 | 
283 |             if args.network_path is not None and args.network_path in PRETRAINED_WHITENING:
284 |                 whiten_fn = args.network_path.split('.')[0] + '_whitening.pth'
285 |             elif args.network_path is not None:
286 |                 whiten_fn = args.network_path + '_{}_whiten'.format(args.whitening)
287 |                 if len(ms) > 1:
288 |                     whiten_fn += '_ms'
289 |                 whiten_fn += '.pth'
290 |             else:
291 |                 whiten_fn = None
292 | 
293 |             if args.network_path in PRETRAINED_WHITENING:
294 |                 Lw = load_url(PRETRAINED_WHITENING[args.network_path], model_dir=model_path)
295 | 
296 |             elif whiten_fn is not None and os.path.isfile(whiten_fn):
297 |                 print('>> {}: Whitening is precomputed, loading it...'.format(args.whitening))
298 |                 Lw = torch.load(whiten_fn)
299 | 
300 |             else:
301 |                 print('>> {}: Learning whitening...'.format(args.whitening))
302 | 
303 |                 # loading db
304 |                 db_root = os.path.join(get_data_root(), 'train', args.whitening)
305 |                 ims_root = os.path.join(db_root, 'ims')
306 |                 db_fn = os.path.join(db_root, '{}-whiten.pkl'.format(args.whitening))
307 |                 with open(db_fn, 'rb') as f:
308 |                     db = pickle.load(f)
309 |                 images = [cid2filename(db['cids'][i], ims_root) for i in range(len(db['cids']))]
310 | 
311 |                 # extract whitening vectors
312 |                 print('>> {}: Extracting...'.format(args.whitening))
313 |                 wvecs = extract_vectors(net_wh, images, args.image_size, transform, ms=ms, msp=msp, workers=args.workers)
314 | 
315 |                 # learning whitening
316 |                 print('>> {}: Learning...'.format(args.whitening))
317 |                 wvecs = wvecs.numpy()
318 |                 m, P = whitenlearn(wvecs, db['qidxs'], db['pidxs'])
319 |                 Lw = {'m': m, 'P': P}
320 | 
321 |                 # saving whitening if whiten_fn exists
322 |                 if whiten_fn is not None:
323 |                     print('>> {}: Saving to {}...'.format(args.whitening, whiten_fn))
324 |                     torch.save(Lw, whiten_fn)
325 | 
326 |         print('>> {}: elapsed time: {}'.format(args.whitening, htime(time.time()-start)))
327 | 
328 |     else:
329 |         Lw = None
330 | 
331 |     # evaluate on test datasets
332 |     datasets = args.datasets.split(',')
333 |     for dataset in datasets:
334 |         start = time.time()
335 | 
336 |         print('>> {}: Extracting...'.format(dataset))
337 | 
338 |         # prepare config structure for the test dataset
339 |         cfg = configdataset(dataset, os.path.join(get_data_root(), 'test'))
340 |         #cfg = configdataset(dataset,data_path)
341 | 
342 |         images = [cfg['im_fname'](cfg,i) for i in range(cfg['n'])]
343 |         qimages = [cfg['qim_fname'](cfg,i) for i in range(cfg['nq'])]
344 |         try:
345 |             bbxs = [tuple(cfg['gnd'][i]['bbx']) for i in range(cfg['nq'])]
346 |         except:
347 |             bbxs = None  # for holidaysmanrot and copydays
348 | 
349 |         # extract database and query vectors
350 |         print('>> {}: database images...'.format(dataset))
351 |         if args.asym:
352 |             vecs = extract_vectors(net_teacher, images, args.image_size, transform, ms=ms, msp=msp)
353 |         else:
354 |             vecs = extract_vectors(net, images, args.image_size, transform, ms=ms, msp=msp, workers=args.workers)
355 |         print('>> {}: query images...'.format(dataset))
356 |         qvecs = extract_vectors(net, qimages, args.image_size, transform, bbxs=bbxs, ms=ms, msp=msp, workers=args.workers)
357 | 
358 |         print('>> {}: Evaluating...'.format(dataset))
359 | 
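        # Descriptors come out l2-normalised, so the dot products below are cosine
        # similarities and argsort(-scores) ranks database images by decreasing
        # similarity to each query.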
360 |         # convert to numpy
361 |         vecs = vecs.numpy()
362 |         qvecs = qvecs.numpy()
363 | 
364 |         # search, rank, and print
365 |         scores = np.dot(vecs.T, qvecs)
366 |         ranks = np.argsort(-scores, axis=0)
367 |         compute_map_and_print(dataset, ranks, cfg['gnd'], log)
368 | 
369 |         if Lw is not None:
370 |             # whiten the vectors
371 |             vecs_lw = whitenapply(vecs, Lw['m'], Lw['P'])
372 |             qvecs_lw = whitenapply(qvecs, Lw['m'], Lw['P'])
373 | 
374 |             # search, rank, and print
375 |             scores = np.dot(vecs_lw.T, qvecs_lw)
376 |             ranks = np.argsort(-scores, axis=0)
377 |             compute_map_and_print(dataset + ' + whiten', ranks, cfg['gnd'], log)
378 | 
379 |         print('>> {}: elapsed time: {}'.format(dataset, htime(time.time()-start)))
380 | 
381 | 
382 | if __name__ == '__main__':
383 |     main()
--------------------------------------------------------------------------------
/lib/datasets/traindataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | import pdb
4 | import numpy as np
5 | import torch
6 | import torch.utils.data as data
7 | import time
8 | 
9 | from lib.datasets.datahelpers import default_loader, imresize, cid2filename
10 | from lib.datasets.genericdataset import ImagesFromList
11 | from lib.utils.general import get_data_root
12 | 
13 | import torch.nn.functional as F
14 | from shutil import copyfile
15 | 
16 | class TuplesDataset(data.Dataset):
17 |     """
18 |     Args:
19 |         name (string): dataset name: 'retrieval-SfM-120k'
20 |         mode (string): 'train' or 'val' for training and validation parts of dataset
21 |         imsize (int, Default: None): Defines the maximum size of longer image side
22 |         transform (callable, optional): A function/transform that takes in a PIL image
23 |             and returns a transformed version. E.g, ``transforms.RandomCrop``
24 |         loader (callable, optional): A function to load an image given its path.
25 |         nnum (int, Default: 5): Number of negatives for a query image in a training tuple
26 |         qsize (int, Default: 2000): Number of query images, i.e. number of (q,p,n1,...,nN) tuples, to be processed in one epoch
27 |         poolsize (int, Default: 20000): Pool size for negative image re-mining
28 | 
29 |     Attributes:
30 |         images (list): List of full filenames for each image
31 |         clusters (list): List of clusterID per image
32 |         qpool (list): List of all query image indexes
33 |         ppool (list): List of positive image indexes, each corresponding to query at the same position in qpool
34 | 
35 |         qidxs (list): List of qsize query image indexes to be processed in an epoch
36 |         pidxs (list): List of qsize positive image indexes, each corresponding to query at the same position in qidxs
37 |         nidxs (list): List of qsize tuples of negative images
38 |             Each nidxs tuple contains nnum images corresponding to query image at the same position in qidxs
39 | 
40 |         Lists qidxs, pidxs, nidxs are refreshed by calling the ``create_epoch_tuples()`` method,
41 |         i.e. new q-p pairs are picked and negative images are re-mined
42 |     """
43 | 
44 |     def __init__(self, name, mode, imsize=None, nnum=5, qsize=2000, poolsize=20000, feat_path='', transform=None, loader=default_loader, nexamples = 1000):
45 | 
46 |         if not (mode == 'train' or mode == 'val'):
47 |             raise(RuntimeError("MODE should be either train or val, passed as string"))
48 | 
49 |         if name.startswith('retrieval-SfM'):
50 |             # setting up paths
51 |             data_root = get_data_root()
52 |             db_root = os.path.join(data_root, 'train', name)
53 |             ims_root = os.path.join(db_root, 'ims')
54 | 
55 |             # loading db
56 |             db_fn = os.path.join(db_root, '{}.pkl'.format(name))
57 |             with open(db_fn, 'rb') as f:
58 |                 db = pickle.load(f)[mode]
59 | 
60 |             # setting fullpath for images
61 |             self.images = [cid2filename(db['cids'][i], ims_root) for i in range(len(db['cids']))]
62 |             #for x in self.images:
63 |             #    if not os.path.isfile(x):
64 |             #        print(x)
65 | 
66 |         elif name.startswith('gl'):
67 |             ## TODO: NOT IMPLEMENTED YET PROPERLY (WITH AUTOMATIC DOWNLOAD)
68 | 
69 |             # setting up paths
70 |             db_root = '/mnt/fry2/users/datasets/landmarkscvprw18/recognition/'
71 |             ims_root = os.path.join(db_root, 'images', 'train')
72 | 
73 |             # loading db
74 |             db_fn = os.path.join(db_root, '{}.pkl'.format(name))
75 |             with open(db_fn, 'rb') as f:
76 |                 db = pickle.load(f)[mode]
77 | 
78 |             # setting fullpath for images
79 |             self.images = [os.path.join(ims_root, db['cids'][i]+'.jpg') for i in range(len(db['cids']))]
80 | 
81 |         else:
82 |             raise(RuntimeError("Unknown dataset name!"))
83 | 
84 |         # initializing tuples dataset
85 |         self.name = name
86 |         self.mode = mode
87 |         self.imsize = imsize
88 |         self.clusters = db['cluster']
89 |         self.qpool = db['qidxs']
90 |         self.ppool = db['pidxs']
91 | 
92 |         ## If we want to keep only unique q-p pairs
93 |         ## However, ordering of pairs will change, although that is not important
94 |         # qpidxs = list(set([(self.qidxs[i], self.pidxs[i]) for i in range(len(self.qidxs))]))
95 |         # self.qidxs = [qpidxs[i][0] for i in range(len(qpidxs))]
96 |         # self.pidxs = [qpidxs[i][1] for i in range(len(qpidxs))]
97 | 
98 |         # size of training subset for an epoch
99 |         self.nnum = nnum
100 |         self.qsize = min(qsize, len(self.qpool))
101 |         self.poolsize = min(poolsize, len(self.images))
102 |         self.qidxs = None
103 |         self.pidxs = None
104 |         self.nidxs = None
105 | 
106 |         self.transform = transform
107 |         self.loader = loader
108 | 
109 |         self.print_freq = 10
110 | 
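    # Tuples are served as (q, p, n_1, ..., n_nnum) together with a target vector
    # [-1, 1, 0, ..., 0]: -1 marks the query, 1 the positive and 0 the negatives (see
    # __getitem__ below). A contrastive objective over such a tuple -- a sketch of the
    # convention only, with margin tau an assumption; the actual loss lives in
    # lib/layers/loss.py -- would read, with y_i in {0, 1}:
    #
    #   L = sum_i  y_i * 0.5 * d(q, x_i)^2  +  (1 - y_i) * 0.5 * max(0, tau - d(q, x_i))^2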
111 |     def __getitem__(self, index):
112 |         """
113 |         Args:
114 |             index (int): Index
115 | 
116 |         Returns:
117 |             images tuple (q,p,n1,...,nN): Loaded train/val tuple at index of self.qidxs
118 |         """
119 |         if self.__len__() == 0:
120 |             raise(RuntimeError("List qidxs is empty. Run ``dataset.create_epoch_tuples(net)`` method to create subset for train/val!"))
121 | 
122 |         output = []
123 |         # query image
124 |         output.append(self.loader(self.images[self.qidxs[index]]))
125 |         # positive image
126 |         output.append(self.loader(self.images[self.pidxs[index]]))
127 |         # negative images
128 |         for i in range(len(self.nidxs[index])):
129 |             output.append(self.loader(self.images[self.nidxs[index][i]]))
130 | 
131 |         if self.imsize is not None:
132 |             output = [imresize(img, self.imsize) for img in output]
133 | 
134 |         if self.transform is not None:
135 |             output = [self.transform(output[i]).unsqueeze_(0) for i in range(len(output))]
136 | 
137 |         target = torch.Tensor([-1, 1] + [0]*len(self.nidxs[index]))
138 | 
139 |         return output, target
140 | 
141 |     def __len__(self):
142 |         # if not self.qidxs:
143 |         #     return 0
144 |         # return len(self.qidxs)
145 |         return self.qsize
146 | 
147 |     def __repr__(self):
148 |         fmt_str = self.__class__.__name__ + '\n'
149 |         fmt_str += ' Name and mode: {} {}\n'.format(self.name, self.mode)
150 |         fmt_str += ' Number of images: {}\n'.format(len(self.images))
151 |         fmt_str += ' Number of training tuples: {}\n'.format(len(self.qpool))
152 |         fmt_str += ' Number of negatives per tuple: {}\n'.format(self.nnum)
153 |         fmt_str += ' Number of tuples processed in an epoch: {}\n'.format(self.qsize)
154 |         fmt_str += ' Pool size for negative remining: {}\n'.format(self.poolsize)
155 |         tmp = ' Transforms (if any): '
156 |         fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
157 |         return fmt_str
158 | 
159 |     def create_epoch_tuples(self, net):
160 | 
161 |         print('>> Creating tuples for an epoch of {}-{}...'.format(self.name, self.mode))
162 |         print(">>>> used network: ")
163 |         print(net.meta_repr())
164 | 
165 |         ## ------------------------
166 |         ## SELECTING POSITIVE PAIRS
167 |         ## ------------------------
168 | 
169 |         # draw qsize random queries for tuples
170 |         idxs2qpool = torch.randperm(len(self.qpool))[:self.qsize]
171 |         self.qidxs = [self.qpool[i] for i in idxs2qpool]
172 |         self.pidxs = [self.ppool[i] for i in idxs2qpool]
173 | 
174 |         ## ------------------------
175 |         ## SELECTING NEGATIVE PAIRS
176 |         ## ------------------------
177 | 
178 |         # if nnum = 0 create dummy nidxs,
179 |         # useful when only positives are used for training
180 |         if self.nnum == 0:
181 |             self.nidxs = [[] for _ in range(len(self.qidxs))]
182 |             return 0
183 | 
184 |         # draw poolsize random images for the pool of negative images
185 |         idxs2images = torch.randperm(len(self.images))[:self.poolsize]
186 | 
187 |         # prepare network
188 |         net.cuda()
189 |         net.eval()
190 | 
191 |         # no gradients computed, to reduce memory and increase speed
192 |         with torch.no_grad():
193 | 
194 |             print('>> Extracting descriptors for query images...')
195 |             # prepare query loader
196 |             loader = torch.utils.data.DataLoader(
197 |                 ImagesFromList(root='', images=[self.images[i] for i in self.qidxs], imsize=self.imsize, transform=self.transform),
198 |                 batch_size=1, shuffle=False, num_workers=8, pin_memory=True
199 |             )
200 |             # extract query vectors
201 |             qvecs = torch.zeros(net.meta['outputdim'], len(self.qidxs)).cuda()
202 |             for i, input in enumerate(loader):
203 |                 qvecs[:, i] = net(input.cuda()).data.squeeze()
204 |                 if (i+1) % self.print_freq == 0 or (i+1) == len(self.qidxs):
205 |                     print('\r>>>> {}/{} 
done...'.format(i+1, len(self.qidxs)), end='') 206 | print('') 207 | 208 | print('>> Extracting descriptors for negative pool...') 209 | # prepare negative pool data loader 210 | loader = torch.utils.data.DataLoader( 211 | ImagesFromList(root='', images=[self.images[i] for i in idxs2images], imsize=self.imsize, transform=self.transform), 212 | batch_size=1, shuffle=False, num_workers=8, pin_memory=True 213 | ) 214 | # extract negative pool vectors 215 | poolvecs = torch.zeros(net.meta['outputdim'], len(idxs2images)).cuda() 216 | for i, input in enumerate(loader): 217 | poolvecs[:, i] = net(input.cuda()).data.squeeze() 218 | if (i+1) % self.print_freq == 0 or (i+1) == len(idxs2images): 219 | print('\r>>>> {}/{} done...'.format(i+1, len(idxs2images)), end='') 220 | print('') 221 | 222 | print('>> Searching for hard negatives...') 223 | # compute dot product scores and ranks on GPU 224 | scores = torch.mm(poolvecs.t(), qvecs) 225 | scores, ranks = torch.sort(scores, dim=0, descending=True) 226 | avg_ndist = torch.tensor(0).float().cuda() # for statistics 227 | n_ndist = torch.tensor(0).float().cuda() # for statistics 228 | # selection of negative examples 229 | self.nidxs = [] 230 | for q in range(len(self.qidxs)): 231 | # do not use query cluster, 232 | # those images are potentially positive 233 | qcluster = self.clusters[self.qidxs[q]] 234 | clusters = [qcluster] 235 | nidxs = [] 236 | r = 0 237 | while len(nidxs) < self.nnum: 238 | potential = idxs2images[ranks[r, q]] 239 | # take at most one image from the same cluster 240 | if not self.clusters[potential] in clusters: 241 | nidxs.append(potential) 242 | clusters.append(self.clusters[potential]) 243 | avg_ndist += torch.pow(qvecs[:,q]-poolvecs[:,ranks[r, q]]+1e-6, 2).sum(dim=0).sqrt() 244 | n_ndist += 1 245 | r += 1 246 | self.nidxs.append(nidxs) 247 | print('>>>> Average negative l2-distance: {:.2f}'.format(avg_ndist/n_ndist)) 248 | print('>>>> Done') 249 | 250 | return (avg_ndist/n_ndist).item() # return average negative l2-distance 251 | 252 | 253 | class TuplesDatasetTS(TuplesDataset): 254 | """ 255 | Args: 256 | name (string): dataset name: 'retrieval-sfm-120k' 257 | mode (string): 'train' or 'val' for training and validation parts of dataset 258 | imsize (int, Default: None): Defines the maximum size of longer image side 259 | transform (callable, optional): A function/transform that takes in an PIL image 260 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 261 | loader (callable, optional): A function to load an image given its path. 
262 |         nnum (int, Default: 5): Number of negatives for a query image in a training tuple
263 |         qsize (int, Default: 2000): Number of query images, i.e. number of (q,p,n1,...,nN) tuples, to be processed in one epoch
264 |         poolsize (int, Default: 20000): Pool size for negative image re-mining
265 | 
266 |     Attributes:
267 |         images (list): List of full filenames for each image
268 |         clusters (list): List of clusterID per image
269 |         qpool (list): List of all query image indexes
270 |         ppool (list): List of positive image indexes, each corresponding to query at the same position in qpool
271 | 
272 |         qidxs (list): List of qsize query image indexes to be processed in an epoch
273 |         pidxs (list): List of qsize positive image indexes, each corresponding to query at the same position in qidxs
274 |         nidxs (list): List of qsize tuples of negative images
275 |             Each nidxs tuple contains nnum images corresponding to query image at the same position in qidxs
276 | 
277 |         Lists qidxs, pidxs, nidxs are refreshed by calling the ``create_epoch_tuples()`` method,
278 |         i.e. new q-p pairs are picked and negative images are re-mined
279 |     """
280 | 
281 |     def __init__(self, name, mode, imsize=None, nnum=5, qsize=2000, poolsize=20000, feat_path='', transform=None, loader=default_loader, nexamples = 1000):
282 |         super().__init__(name, mode, imsize, nnum, qsize, poolsize, feat_path, transform, loader, nexamples)
283 |         self.feat = np.load(feat_path)
284 | 
285 |     def __getitem__(self, index):
286 |         """
287 |         Args:
288 |             index (int): Index
289 | 
290 |         Returns:
291 |             images tuple (q,p,n1,...,nN): Loaded train/val tuple at index of self.qidxs
292 |         """
293 |         if self.__len__() == 0:
294 |             raise(RuntimeError("List qidxs is empty. Run ``dataset.create_epoch_tuples(net)`` method to create subset for train/val!"))
295 | 
296 |         output = []
297 |         # query image
298 |         output.append(self.loader(self.images[self.qidxs[index]]))
299 |         if self.imsize is not None:
300 |             output = [imresize(img, self.imsize) for img in output]
301 |         if self.transform is not None:
302 |             output = [self.transform(output[i]).unsqueeze_(0) for i in range(len(output))]
303 | 
304 |         # positive and negative vectors from the teacher
305 |         output.append(self.feat[:,self.pidxs[index]])
306 |         for i in range(len(self.nidxs[index])):
307 |             output.append(self.feat[:,self.nidxs[index][i]])
308 | 
309 |         target = torch.Tensor([-1, 1] + [0]*len(self.nidxs[index]))
310 | 
311 |         return output, target
312 | 
313 | 
314 |     def create_epoch_tuples(self, net):
315 | 
316 |         print('>> Creating tuples for an epoch of {}-{}...'.format(self.name, self.mode))
317 |         print(">>>> used network: ")
318 |         print(net.meta_repr())
319 | 
320 |         ## ------------------------
321 |         ## SELECTING POSITIVE PAIRS
322 |         ## ------------------------
323 | 
324 |         # draw qsize random queries for tuples
325 |         idxs2qpool = torch.randperm(len(self.qpool))[:self.qsize]
326 |         self.qidxs = [self.qpool[i] for i in idxs2qpool]
327 |         self.pidxs = [self.ppool[i] for i in idxs2qpool]
328 | 
329 |         ## ------------------------
330 |         ## SELECTING NEGATIVE PAIRS
331 |         ## ------------------------
332 | 
333 |         # if nnum = 0 create dummy nidxs,
334 |         # useful when only positives are used for training
335 |         if self.nnum == 0:
336 |             self.nidxs = [[] for _ in range(len(self.qidxs))]
337 |             return 0
338 | 
339 |         # draw poolsize random images for the pool of negative images
340 |         idxs2images = torch.randperm(len(self.images))[:self.poolsize]
341 | 
342 |         # prepare network
343 |         net.cuda()
344 |         net.eval()
345 | 
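        # Mining here is cross-space: query descriptors are recomputed with the current
        # student network, while the negative pool is filled directly from the pre-extracted
        # teacher features (self.feat), so hard negatives are ranked by student-query vs
        # teacher-database similarity.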
346 |         # no gradients computed, to reduce memory and increase speed
347 |         with torch.no_grad():
348 | 
349 |             print('>> Extracting descriptors for query images...')
350 |             # prepare query loader
351 |             loader = torch.utils.data.DataLoader(
352 |                 ImagesFromList(root='', images=[self.images[i] for i in self.qidxs], imsize=self.imsize, transform=self.transform),
353 |                 batch_size=1, shuffle=False, num_workers=8, pin_memory=True
354 |             )
355 |             # extract query vectors
356 | 
357 |             qvecs = torch.zeros(net.meta['outputdim'], len(self.qidxs)).cuda()
358 |             for i, input in enumerate(loader):
359 |                 qvecs[:, i] = net(input.cuda()).data.squeeze()
360 |                 if (i+1) % self.print_freq == 0 or (i+1) == len(self.qidxs):
361 |                     print('\r>>>> {}/{} done...'.format(i+1, len(self.qidxs)), end='')
362 |             print('')
363 | 
364 |             print('>> Extracting descriptors for negative pool...')
365 |             # copy negative pool vectors from the pre-extracted teacher features
366 |             poolvecs = torch.zeros(net.meta['outputdim'], len(idxs2images)).cuda()
367 |             for i in range(len(idxs2images)):
368 |                 poolvecs[:, i] = torch.tensor(self.feat[:,idxs2images[i]]).float().cuda()
369 | 
370 |             print('')
371 |             print('>> Searching for hard negatives...')
372 |             # compute dot product scores and ranks on GPU
373 |             scores = torch.mm(poolvecs.t(), qvecs)
374 |             scores, ranks = torch.sort(scores, dim=0, descending=True)
375 |             avg_ndist = torch.tensor(0).float().cuda()  # for statistics
376 |             n_ndist = torch.tensor(0).float().cuda()  # for statistics
377 |             # selection of negative examples
378 |             self.nidxs = []
379 |             for q in range(len(self.qidxs)):
380 |                 # do not use the query cluster,
381 |                 # those images are potentially positive
382 |                 qcluster = self.clusters[self.qidxs[q]]
383 |                 clusters = [qcluster]
384 |                 nidxs = []
385 |                 r = 0
386 |                 while len(nidxs) < self.nnum:
387 | 
388 |                     potential = idxs2images[ranks[r, q]]
389 |                     # take at most one image from the same cluster
390 |                     if not self.clusters[potential] in clusters:
391 |                         nidxs.append(potential)
392 |                         clusters.append(self.clusters[potential])
393 |                         avg_ndist += torch.pow(qvecs[:,q]-poolvecs[:,ranks[r, q]]+1e-6, 2).sum(dim=0).sqrt()
394 |                         n_ndist += 1
395 |                     r += 1
396 |                 self.nidxs.append(nidxs)
397 |             print('>>>> Average negative l2-distance: {:.2f}'.format(avg_ndist/n_ndist))
398 |             print('>>>> Done')
399 | 
400 |         return (avg_ndist/n_ndist).item()  # return average negative l2-distance
401 | 
402 | 
403 | class TuplesDatasetRand(TuplesDataset):
404 |     """
405 |     Used for regression. Outputs the anchor image and its corresponding vector
406 |     in the teacher space.
407 |     """
408 |     def __init__(self, name, mode, imsize=None, nnum=5, qsize=2000, poolsize=20000, feat_path='', transform=None, loader=default_loader, nexamples = 1000):
409 |         super().__init__(name, mode, imsize, nnum, qsize, poolsize, feat_path, transform, loader, nexamples)
410 |         self.feat = np.load(feat_path)  # teacher descriptors; __getitem__ below reads them (the base class does not load them)
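    # Anchors are simply redrawn at random every epoch (see create_epoch_tuples below); each
    # item pairs an anchor image with its own pre-extracted teacher descriptor, so training on
    # this dataset amounts to regressing the student embedding onto the teacher's.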
411 |     def __getitem__(self, index):
412 | 
413 |         if self.__len__() == 0:
414 |             raise(RuntimeError("List qidxs is empty. Run ``dataset.create_epoch_tuples(net)`` method to create subset for train/val!"))
415 | 
416 |         output = []
417 |         output.append(self.loader(self.images[self.qidxs[index]]))
418 | 
419 |         if self.imsize is not None:
420 |             output = [imresize(img, self.imsize) for img in output]
421 |         if self.transform is not None:
422 |             output = [self.transform(output[i]).unsqueeze_(0) for i in range(len(output))]
423 | 
424 |         target = torch.Tensor(self.feat[:,self.qidxs[index]])
425 | 
426 |         return output, target
427 | 
428 |     def __len__(self):
429 |         return self.poolsize
430 | 
431 |     def create_epoch_tuples(self, net):
432 | 
433 |         self.qidxs = torch.randperm(len(self.images))[:self.poolsize]
434 | 
435 |         return 0
436 | 
437 | 
438 | class TuplesDatasetTSWithSelf(TuplesDatasetTS):
439 | 
440 |     def __init__(self, name, mode, imsize=None, nnum=5, qsize=2000, poolsize=20000, feat_path='', transform=None, loader=default_loader, nexamples = 1000):
441 |         super().__init__(name, mode, imsize, nnum, qsize, poolsize, feat_path, transform, loader, nexamples)
442 | 
443 |     def __getitem__(self, index):
444 |         """
445 |         Args:
446 |             index (int): Index
447 | 
448 |         Returns:
449 |             images tuple (q,a,p,n1,...,nN): Loaded train/val tuple at index of self.qidxs
450 |             The first element is an image, the following are vectors from the teacher.
451 |         """
452 |         if self.__len__() == 0:
453 |             raise(RuntimeError("List qidxs is empty. Run ``dataset.create_epoch_tuples(net)`` method to create subset for train/val!"))
454 | 
455 |         output = []
456 |         # query image
457 |         output.append(self.loader(self.images[self.qidxs[index]]))
458 | 
459 |         if self.imsize is not None:
460 |             output = [imresize(img, self.imsize) for img in output]
461 | 
462 |         if self.transform is not None:
463 |             output = [self.transform(output[i]).unsqueeze_(0) for i in range(len(output))]
464 | 
465 |         # anchor vector from the teacher
466 |         output.append(self.feat[:,self.qidxs[index]])
467 | 
468 |         # positive vector from the teacher
469 |         output.append(self.feat[:,self.pidxs[index]])
470 |         # negatives from the teacher
471 |         for i in range(len(self.nidxs[index])):
472 |             output.append(self.feat[:,self.nidxs[index][i]])
473 | 
474 |         target = torch.Tensor([-1, 1, 1] + [0]*len(self.nidxs[index]))
475 | 
476 |         return output, target
477 | 
478 | 
479 | class RegressionTS(TuplesDatasetTS):
480 | 
481 |     def __init__(self, name, mode, imsize=None, nnum=5, qsize=2000, poolsize=20000, feat_path='', transform=None, loader=default_loader, nexamples = 1000):
482 | 
483 |         super().__init__(name, mode, imsize, nnum, qsize, poolsize, feat_path, transform, loader, nexamples)
484 | 
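    # Same (image, teacher-vector) pairing as TuplesDatasetRand, drawn over a random subset of
    # the whole image pool; a plain regression loss on the student output realises the
    # objective -- a minimal sketch under these conventions, not the project's loss module:
    #
    #   loss = ((net(q).squeeze() - torch.as_tensor(t).float().cuda()) ** 2).sum()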
485 |     def __getitem__(self, index):
486 |         """
487 |         Args:
488 |             index (int): Index
489 | 
490 |         Returns:
491 |             images tuple (q,p): Loaded train/val tuple at index of self.qidxs
492 |             The first element is an image, the second a vector from the teacher.
493 |         """
494 |         if self.__len__() == 0:
495 |             raise(RuntimeError("List qidxs is empty. Run ``dataset.create_epoch_tuples(net)`` method to create subset for train/val!"))
496 | 
497 |         output = []
498 |         # query image
499 |         output.append(self.loader(self.images[self.qidxs[index]]))
500 |         if self.imsize is not None:
501 |             output = [imresize(img, self.imsize) for img in output]
502 |         if self.transform is not None:
503 |             output = [self.transform(output[i]).unsqueeze_(0) for i in range(len(output))]
504 |         # anchor vector from the teacher
505 |         output.append(self.feat[:,self.qidxs[index]])
506 | 
507 |         target = torch.Tensor([-1, 1])
508 |         return output, target
509 | 
510 |     def __len__(self):
511 |         return self.poolsize
512 | 
513 |     def create_epoch_tuples(self, net):
514 |         self.qidxs = torch.randperm(len(self.images))[:self.poolsize]
515 |         return 0
516 | 
517 | 
518 | class RegressionTSOnlyPos(TuplesDatasetTS):
519 | 
520 |     def __init__(self, name, mode, imsize=None, nnum=5, qsize=2000, poolsize=20000, feat_path='', transform=None, loader=default_loader, nexamples = 1000):
521 | 
522 |         super().__init__(name, mode, imsize, nnum, qsize, poolsize, feat_path, transform, loader, nexamples)
523 | 
524 |     def __getitem__(self, index):
525 |         """
526 |         Args:
527 |             index (int): Index
528 | 
529 |         Returns:
530 |             images tuple (q,t): the query image and its own vector in the teacher space
531 |         """
532 |         if self.__len__() == 0:
533 |             raise(RuntimeError("List qidxs is empty. Run ``dataset.create_epoch_tuples(net)`` method to create subset for train/val!"))
534 | 
535 |         output = []
536 |         # query image
537 |         output.append(self.loader(self.images[self.qidxs[index]]))
538 |         if self.imsize is not None:
539 |             output = [imresize(img, self.imsize) for img in output]
540 |         if self.transform is not None:
541 |             output = [self.transform(output[i]).unsqueeze_(0) for i in range(len(output))]
542 | 
543 |         output.append(self.feat[:,self.qidxs[index]])
544 | 
545 |         target = torch.Tensor([-1, 1])
546 |         return output, target
547 | 
548 |     def __len__(self):
549 |         return self.poolsize
550 | 
551 |     def create_epoch_tuples(self, net):
552 |         idxs2qpool = torch.randperm(len(self.qpool))[:self.qsize]
553 |         self.qidxs = [self.qpool[i] for i in idxs2qpool]
554 |         self.pidxs = [self.ppool[i] for i in idxs2qpool]
555 |         return 0
556 | 
557 | 
558 | class RandomTriplet(TuplesDataset):
559 | 
560 |     def __init__(self, name, mode, imsize=None, nnum=5, qsize=2000, poolsize=20000, feat_path='', transform=None, loader=default_loader, nexamples = 1000):
561 | 
562 |         super().__init__(name, mode, imsize, nnum, qsize, poolsize, feat_path, transform, loader, nexamples)
563 |         self.feat = np.load(feat_path)  # teacher descriptors; create_epoch_tuples below ranks the pool with them
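    # Triplets are mined in the teacher space: for every random query, two of its top-k
    # (k = 300) teacher-space neighbours are sampled, the higher-ranked one becoming the
    # positive and the lower-ranked one the negative (see create_epoch_tuples below), with
    # target [-1, 1, 0].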
564 |     def __getitem__(self, index):
565 | 
566 |         if self.__len__() == 0:
567 |             raise(RuntimeError("List qidxs is empty. Run ``dataset.create_epoch_tuples(net)`` method to create subset for train/val!"))
568 | 
569 |         output = []
570 |         # query image
571 |         output.append(self.loader(self.images[self.qidxs[index]]))
572 |         output.append(self.loader(self.images[self.pidxs[index]]))
573 |         output.append(self.loader(self.images[self.nidxs[index]]))
574 |         if self.imsize is not None:
575 |             output = [imresize(img, self.imsize) for img in output]
576 |         if self.transform is not None:
577 |             output = [self.transform(output[i]).unsqueeze_(0) for i in range(len(output))]
578 | 
579 |         target = torch.Tensor([-1, 1, 0])
580 |         return output, target
581 | 
582 |     def create_epoch_tuples(self, net):
583 | 
584 |         perm_temp = torch.randperm(len(self.images))
585 |         self.qidxs = perm_temp[:self.qsize]
586 |         k = 300
587 |         idxs2images = perm_temp[self.qsize:self.poolsize+self.qsize]
588 | 
589 |         # prepare network
590 |         net.cuda()
591 |         net.eval()
592 | 
593 |         # no gradients computed, to reduce memory and increase speed
594 |         with torch.no_grad():
595 |             # prepare query loader
596 |             loader = torch.utils.data.DataLoader(
597 |                 ImagesFromList(root='', images=[self.images[i] for i in self.qidxs], imsize=self.imsize, transform=self.transform),
598 |                 batch_size=1, shuffle=False, num_workers=8, pin_memory=True
599 |             )
600 |             # copy query vectors from the pre-extracted teacher features
601 |             qvecs = torch.zeros(net.meta['outputdim'], len(self.qidxs)).cuda()
602 |             for i in range(len(self.qidxs)):
603 |                 qvecs[:, i] = torch.tensor(self.feat[:,self.qidxs[i]]).float().cuda()
604 | 
605 |             # extract negative pool vectors
606 |             poolvecs = torch.zeros(net.meta['outputdim'], len(idxs2images)).cuda()
607 |             for i in range(len(idxs2images)):
608 |                 poolvecs[:, i] = torch.tensor(self.feat[:,idxs2images[i]]).float().cuda()
609 | 
610 |             scores = torch.mm(poolvecs.t(), qvecs)
611 |             scores, ranks = torch.sort(scores, dim=0, descending=True)
612 | 
613 |             self.nidxs = []
614 |             self.pidxs = []
615 | 
616 |             for q in range(len(self.qidxs)):
617 |                 cands = torch.randperm(k)[:2]
618 |                 nidxs = []
619 |                 if cands[0] < cands[1]:
620 |                     self.nidxs.append(idxs2images[ranks[cands[1], q]])
621 |                     self.pidxs.append(idxs2images[ranks[cands[0], q]])
622 |                 else:
623 |                     self.nidxs.append(idxs2images[ranks[cands[0], q]])
624 |                     self.pidxs.append(idxs2images[ranks[cands[1], q]])
625 | 
626 |         return 0
627 | 
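# RandomTripletAsym keeps RandomTriplet's mining but serves the positive and negative as
# teacher vectors rather than images; only the query is loaded and transformed, matching the
# asymmetric setting in which the student must embed queries into the fixed teacher space.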
628 | class RandomTripletAsym(RandomTriplet):
629 | 
630 |     def __init__(self, name, mode, imsize=None, nnum=5, qsize=2000, poolsize=20000, feat_path='', transform=None, loader=default_loader, nexamples = 1000):
631 |         super().__init__(name, mode, imsize, nnum, qsize, poolsize, feat_path, transform, loader, nexamples)
632 |         self.feat = np.load(feat_path)
633 | 
634 |     def __getitem__(self, index):
635 | 
636 |         if self.__len__() == 0:
637 |             raise(RuntimeError("List qidxs is empty. Run ``dataset.create_epoch_tuples(net)`` method to create subset for train/val!"))
638 |         output = []
639 |         output.append(self.loader(self.images[self.qidxs[index]]))
640 | 
641 |         if self.imsize is not None:
642 |             output = [imresize(img, self.imsize) for img in output]
643 | 
644 |         if self.transform is not None:
645 |             output = [self.transform(output[i]).unsqueeze_(0) for i in range(len(output))]
646 | 
647 |         output.append(self.feat[:,self.pidxs[index]])
648 |         output.append(self.feat[:,self.nidxs[index]])
649 |         target = torch.Tensor([-1, 1, 0])
650 | 
651 |         return output, target
652 | 
653 | class TuplesDatasetTSRand(TuplesDatasetTS):
654 | 
655 |     def __init__(self, name, mode, imsize=None, nnum=5, qsize=2000, poolsize=20000, feat_path='', transform=None, loader=default_loader, nexamples = 1000):
656 |         super().__init__(name, mode, imsize, nnum, qsize, poolsize, feat_path, transform, loader, nexamples)
657 | 
658 | 
659 |     def __getitem__(self, index):
660 |         """
661 |         Args:
662 |             index (int): Index
663 | 
664 |         Returns:
665 |             images tuple (q,p,n1,...,nN): Loaded train/val tuple at index of self.qidxs
666 |             The first element is an image, the rest are vectors from the teacher.
667 |         """
668 |         if self.__len__() == 0:
669 |             raise(RuntimeError("List qidxs is empty. Run ``dataset.create_epoch_tuples(net)`` method to create subset for train/val!"))
670 |         output = []
671 |         # query image
672 |         output.append(self.loader(self.images[self.qidxs[index]]))
673 |         if self.imsize is not None:
674 |             output = [imresize(img, self.imsize) for img in output]
675 |         if self.transform is not None:
676 |             output = [self.transform(output[i]).unsqueeze_(0) for i in range(len(output))]
677 |         # positive vector from the teacher
678 |         output.append(self.feat[:,self.pidxs[index]])
679 |         # negative vectors from the teacher
680 |         for i in range(len(self.nidxs[index])):
681 |             output.append(self.feat[:,self.nidxs[index][i]])
682 | 
683 |         target = torch.Tensor([-1, 1] + [0]*len(self.nidxs[index]))
684 | 
685 |         return output, target
686 | 
687 |     def create_epoch_tuples(self, net):
688 | 
689 |         print('>> Creating tuples for an epoch of {}-{}...'.format(self.name, self.mode))
690 |         print(">>>> used network: ")
691 |         print(net.meta_repr())
692 | 
693 |         ## SELECTING POSITIVE PAIRS
694 |         # draw qsize random queries for tuples
695 |         idxs2qpool = torch.randperm(len(self.qpool))[:self.qsize]
696 |         self.qidxs = [self.qpool[i] for i in idxs2qpool]
697 |         self.pidxs = [self.ppool[i] for i in idxs2qpool]
698 | 
699 |         ## SELECTING NEGATIVE PAIRS
700 |         if self.nnum == 0:
701 |             self.nidxs = [[] for _ in range(len(self.qidxs))]
702 |             return 0
703 | 
704 |         idxs2images = torch.randperm(len(self.images))[:self.poolsize]
705 |         # prepare network
706 |         net.cuda()
707 |         net.eval()
708 |         with torch.no_grad():
709 |             self.nidxs = []
710 |             for q in range(len(self.qidxs)):
711 |                 qcluster = self.clusters[self.qidxs[q]]
712 |                 clusters = [qcluster]
713 |                 nidxs = []
714 |                 r = 0
715 |                 while len(nidxs) < self.nnum:
716 |                     rand_sel = torch.randperm(len(idxs2images))
717 |                     potential = idxs2images[rand_sel[r]]
718 |                     if not self.clusters[potential] in clusters:
719 |                         nidxs.append(potential)
720 |                         clusters.append(self.clusters[potential])
721 | 
722 |                     r += 1
723 |                 self.nidxs.append(nidxs)
724 | 
725 |         return 0
726 | 
--------------------------------------------------------------------------------