├── .gitignore ├── LICENSE ├── Makefile ├── README.md ├── data └── .gitkeep ├── dataset.py ├── diffusion.py ├── evaluate.py ├── knn.py ├── mat2npy.py ├── rank.py ├── slides.pdf └── tmp └── .gitkeep /.gitignore: -------------------------------------------------------------------------------- 1 | # python 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | *.so 6 | *.egg 7 | *.egg-info 8 | *.jbl 9 | *.npy 10 | data 11 | tmp 12 | 13 | # matlab 14 | *.mat 15 | 16 | # git 17 | .git 18 | 19 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Fan YANG 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | # directory to data 2 | DATA_DIR=./data 3 | # directory to cache files 4 | TMP_DIR=./tmp 5 | # oxford5k, oxford105k, paris6k, paris106k 6 | DATASET=oxford5k 7 | # resnet or siamac 8 | FEATURE_TYPE=resnet 9 | 10 | .PHONY: rank 11 | rank: 12 | python rank.py \ 13 | --cache_dir $(TMP_DIR)/$(DATASET)_$(FEATURE_TYPE) \ 14 | --query_path $(DATA_DIR)/query/$(DATASET)_$(FEATURE_TYPE)_glob.npy \ 15 | --gallery_path $(DATA_DIR)/gallery/$(DATASET)_$(FEATURE_TYPE)_glob.npy \ 16 | --gnd_path $(DATA_DIR)/gnd_$(DATASET).pkl \ 17 | --dataset_name $(DATASET) \ 18 | --truncation_size 1000 19 | 20 | 21 | .PHONY: mat2npy 22 | mat2npy: 23 | python mat2npy.py \ 24 | --dataset_name $(DATASET) \ 25 | --feature_type $(FEATURE_TYPE) \ 26 | --mat_dir $(DATA_DIR) 27 | 28 | 29 | .PHONY: download 30 | download: 31 | wget http://cmp.felk.cvut.cz/cnnimageretrieval/data/test/oxford5k/gnd_oxford5k.pkl -O $(DATA_DIR)/gnd_oxford5k.pkl 32 | wget http://cmp.felk.cvut.cz/cnnimageretrieval/data/test/paris6k/gnd_paris6k.pkl -O $(DATA_DIR)/gnd_paris6k.pkl 33 | ln -s $(DATA_DIR)/gnd_oxford5k.pkl $(DATA_DIR)/gnd_oxford105k.pkl 34 | ln -s $(DATA_DIR)/gnd_paris6k.pkl $(DATA_DIR)/gnd_paris106k.pkl 35 | for dataset in oxford5k oxford105k paris6k paris106k; do \ 36 | for feature in siamac resnet; do \ 37 | wget ftp://ftp.irisa.fr/local/texmex/corpus/diffusion/data/$$dataset\_$$feature.mat -O $(DATA_DIR)/$$dataset\_$$feature.mat; \ 38 | done; \ 39 | done 40 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | This is a faster and improved version of diffusion retrieval, inspired by [diffusion-retrieval](https://github.com/ahmetius/diffusion-retrieval). 2 | 3 | Reference: 4 | - [F. 
Yang](https://fyang.me/about), [R. Hinami](http://www.satoh-lab.nii.ac.jp/member/hinami/), [Y. Matsui](http://yusukematsui.me), S. Ly, [S. Satoh](http://research.nii.ac.jp/~satoh/index.html), "**Efficient Image Retrieval via Decoupling Diffusion into Online and Offline Processing**", AAAI 2019. \[[arXiv](https://arxiv.org/abs/1811.10907)\] 5 | 6 | For further details of our method, see these [slides](https://github.com/fyang93/diffusion/blob/master/slides.pdf). 7 | 8 | ## Features 9 | 10 | - All random walk processes are moved offline, making the online search remarkably fast 11 | 12 | - In contrast to previous works, we achieve better performance by applying late truncation instead of early truncation to the graph 13 | 14 | ## Requirements 15 | 16 | - Install Facebook [FAISS](https://github.com/facebookresearch/faiss) by running `conda install faiss-cpu -c pytorch` 17 | > Optional: install faiss-gpu by following the [instructions](https://github.com/facebookresearch/faiss/blob/master/INSTALL.md) for your CUDA version 18 | 19 | - Install joblib by running `conda install joblib` 20 | 21 | - Install tqdm by running `conda install tqdm` 22 | 23 | ## Parameters 24 | 25 | All parameters can be modified in `Makefile`. You may want to edit [DATASET](https://github.com/fyang93/diffusion/blob/master/Makefile#L6) and [FEATURE_TYPE](https://github.com/fyang93/diffusion/blob/master/Makefile#L8) to test all combinations of dataset and feature type. 26 | Another parameter, [truncation_size](https://github.com/fyang93/diffusion/blob/master/Makefile#L18), is set to 1000 by default; for large datasets like Oxford105k and Paris106k, changing it to 5000 will improve the performance. 27 | 28 | ## Run 29 | 30 | - Run `make download` to download the files needed for the experiments; 31 | 32 | - Run `make mat2npy` to convert .mat files to .npy files; 33 | 34 | - Run `make rank` to get the results.
If you have GPUs, try commands like `CUDA_VISIBLE_DEVICES=0,1 make rank`, where `0,1` are example GPU ids. 35 | > Note: on the Oxford5k and Paris6k datasets, the `truncation_size` parameter should be no larger than 1024 when using GPUs, due to a FAISS limitation. You can use CPUs instead. 36 | 37 | ## Updates!! 38 | 39 | - We changed the evaluation protocol to the official one. Our previous evaluation code had an issue in computing the precision for the first true positive result, which made the mAP slightly higher than its real value. Since all results in the paper were obtained with the previous evaluation, the comparison is still solid. 40 | - We provide a new retrieval method that uses all queries at once, which produces better performance. If you want to use the algorithm described in the paper, please check `search_old` in `rank.py`. 41 | 42 | ## Authors 43 | 44 | - [Fan Yang](https://fyang.me/about) wrote the algorithm 45 | - [Ryota Hinami](http://www.satoh-lab.nii.ac.jp/member/hinami/) wrote the first evaluator implementation (changed to [official evaluation](https://github.com/filipradenovic/cnnimageretrieval-pytorch/blob/master/cirtorch/utils/evaluate.py)) 46 | 47 | -------------------------------------------------------------------------------- /data/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fyang93/diffusion/5f5c6b3dae1bf887c686c0f424869b417d632481/data/.gitkeep -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | " dataset module " 5 | 6 | import os 7 | import numpy as np 8 | import joblib 9 | 10 | 11 | def load(path): 12 | """Load features 13 | """ 14 | if not os.path.exists(path): 15 | raise Exception("{} does not exist".format(path)) 16 | ext = os.path.splitext(path)[-1] 17 | return
{'.npy': np, '.jbl': joblib}[ext].load(path) 18 | 19 | 20 | class Dataset(object): 21 | """Dataset class 22 | """ 23 | 24 | def __init__(self, query_path, gallery_path): 25 | self.query_path = query_path 26 | self.gallery_path = gallery_path 27 | self._queries = None 28 | self._gallery = None 29 | 30 | @property 31 | def queries(self): 32 | if self._queries is None: 33 | self._queries = load(self.query_path) 34 | return self._queries 35 | 36 | @property 37 | def gallery(self): 38 | if self._gallery is None: 39 | self._gallery = load(self.gallery_path) 40 | return self._gallery 41 | 42 | -------------------------------------------------------------------------------- /diffusion.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | " diffusion module " 5 | 6 | import os 7 | import time 8 | import numpy as np 9 | import joblib 10 | from joblib import Parallel, delayed 11 | import scipy.sparse as sparse 12 | import scipy.sparse.linalg as linalg 13 | from tqdm import tqdm 14 | from knn import KNN, ANN 15 | 16 | 17 | trunc_ids = None 18 | trunc_init = None 19 | lap_alpha = None 20 | 21 | 22 | def get_offline_result(i): 23 | ids = trunc_ids[i] 24 | trunc_lap = lap_alpha[ids][:, ids] 25 | scores, _ = linalg.cg(trunc_lap, trunc_init, tol=1e-6, maxiter=20) 26 | return scores 27 | 28 | 29 | def cache(filename): 30 | """Decorator to cache results 31 | """ 32 | def decorator(func): 33 | def wrapper(*args, **kw): 34 | self = args[0] 35 | path = os.path.join(self.cache_dir, filename) 36 | time0 = time.time() 37 | if os.path.exists(path): 38 | result = joblib.load(path) 39 | cost = time.time() - time0 40 | print('[cache] loading {} costs {:.2f}s'.format(path, cost)) 41 | return result 42 | result = func(*args, **kw) 43 | cost = time.time() - time0 44 | print('[cache] obtaining {} costs {:.2f}s'.format(path, cost)) 45 | joblib.dump(result, path) 46 | return result 47 | return wrapper 48 | 
return decorator 49 | 50 | 51 | class Diffusion(object): 52 | """Diffusion class 53 | """ 54 | def __init__(self, features, cache_dir): 55 | self.features = features 56 | self.N = len(self.features) 57 | self.cache_dir = cache_dir 58 | # use ANN for large datasets 59 | self.use_ann = self.N >= 100000 60 | if self.use_ann: 61 | self.ann = ANN(self.features, method='cosine') 62 | self.knn = KNN(self.features, method='cosine') 63 | 64 | @cache('offline.jbl') 65 | def get_offline_results(self, n_trunc, kd=50): 66 | """Get offline diffusion results for each gallery feature 67 | """ 68 | print('[offline] starting offline diffusion') 69 | print('[offline] 1) prepare Laplacian and initial state') 70 | global trunc_ids, trunc_init, lap_alpha 71 | if self.use_ann: 72 | _, trunc_ids = self.ann.search(self.features, n_trunc) 73 | sims, ids = self.knn.search(self.features, kd) 74 | lap_alpha = self.get_laplacian(sims, ids) 75 | else: 76 | sims, ids = self.knn.search(self.features, n_trunc) 77 | trunc_ids = ids 78 | lap_alpha = self.get_laplacian(sims[:, :kd], ids[:, :kd]) 79 | trunc_init = np.zeros(n_trunc) 80 | trunc_init[0] = 1 81 | 82 | print('[offline] 2) gallery-side diffusion') 83 | results = Parallel(n_jobs=-1, prefer='threads')(delayed(get_offline_result)(i) 84 | for i in tqdm(range(self.N), 85 | desc='[offline] diffusion')) 86 | all_scores = np.concatenate(results) 87 | 88 | print('[offline] 3) merge offline results') 89 | rows = np.repeat(np.arange(self.N), n_trunc) 90 | offline = sparse.csr_matrix((all_scores, (rows, trunc_ids.reshape(-1))), 91 | shape=(self.N, self.N), 92 | dtype=np.float32) 93 | return offline 94 | 95 | # @cache('laplacian.jbl') 96 | def get_laplacian(self, sims, ids, alpha=0.99): 97 | """Get Laplacian_alpha matrix 98 | """ 99 | affinity = self.get_affinity(sims, ids) 100 | num = affinity.shape[0] 101 | degrees = affinity @ np.ones(num) + 1e-12 102 | # mat: degree matrix ^ (-1/2) 103 | mat = sparse.dia_matrix( 104 | (degrees ** (-0.5), [0]), 
shape=(num, num), dtype=np.float32) 105 | stochastic = mat @ affinity @ mat 106 | sparse_eye = sparse.dia_matrix( 107 | (np.ones(num), [0]), shape=(num, num), dtype=np.float32) 108 | lap_alpha = sparse_eye - alpha * stochastic 109 | return lap_alpha 110 | 111 | # @cache('affinity.jbl') 112 | def get_affinity(self, sims, ids, gamma=3): 113 | """Create affinity matrix for the mutual kNN graph of the whole dataset 114 | Args: 115 | sims: similarities of kNN 116 | ids: indexes of kNN 117 | Returns: 118 | affinity: affinity matrix 119 | """ 120 | num = sims.shape[0] 121 | sims[sims < 0] = 0 # similarity should be non-negative 122 | sims = sims ** gamma 123 | # vec_ids: feature vectors' ids 124 | # mut_ids: mutual (reciprocal) nearest neighbors' ids 125 | # mut_sims: similarities between feature vectors and their mutual nearest neighbors 126 | vec_ids, mut_ids, mut_sims = [], [], [] 127 | for i in range(num): 128 | # check reciprocity: i is in j's kNN and j is in i's kNN when i != j 129 | ismutual = np.isin(ids[ids[i]], i).any(axis=1) 130 | ismutual[0] = False 131 | if ismutual.any(): 132 | vec_ids.append(i * np.ones(ismutual.sum(), dtype=int)) 133 | mut_ids.append(ids[i, ismutual]) 134 | mut_sims.append(sims[i, ismutual]) 135 | vec_ids, mut_ids, mut_sims = map(np.concatenate, [vec_ids, mut_ids, mut_sims]) 136 | affinity = sparse.csc_matrix((mut_sims, (vec_ids, mut_ids)), 137 | shape=(num, num), dtype=np.float32) 138 | return affinity 139 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def compute_ap(ranks, nres): 5 | """ 6 | Computes average precision for given ranked indexes.
7 | 8 | Arguments 9 | --------- 10 | ranks : zero-based ranks of positive images 11 | nres : number of positive images 12 | 13 | Returns 14 | ------- 15 | ap : average precision 16 | """ 17 | 18 | # number of images ranked by the system 19 | nimgranks = len(ranks) 20 | # accumulate trapezoids in PR-plot 21 | ap = 0 22 | 23 | recall_step = 1.0 / nres 24 | 25 | for j in np.arange(nimgranks): 26 | rank = ranks[j] 27 | if rank == 0: 28 | precision_0 = 1.0 29 | else: 30 | precision_0 = float(j) / rank 31 | precision_1 = float(j + 1) / (rank + 1) 32 | ap += (precision_0 + precision_1) * recall_step / 2.0 33 | return ap 34 | 35 | 36 | def compute_map(ranks, gnd, kappas=[]): 37 | """ 38 | Computes the mAP for a given set of returned results. 39 | 40 | Usage: 41 | map = compute_map (ranks, gnd) 42 | computes mean average precision (map) only 43 | 44 | map, aps, pr, prs = compute_map (ranks, gnd, kappas) 45 | computes mean average precision (map), average precision (aps) for each query 46 | computes mean precision at kappas (pr), precision at kappas (prs) for each query 47 | 48 | Notes: 49 | 1) ranks starts from 0, ranks.shape = db_size X #queries 50 | 2) The junk results (e.g., the query itself) should be declared in the gnd struct array 51 | 3) If there are no positive images for some query, that query is excluded from the evaluation 52 | """ 53 | 54 | map = 0.0 55 | nq = len(gnd) # number of queries 56 | aps = np.zeros(nq) 57 | pr = np.zeros(len(kappas)) 58 | prs = np.zeros((nq, len(kappas))) 59 | nempty = 0 60 | 61 | for i in np.arange(nq): 62 | qgnd = np.array(gnd[i]["ok"]) 63 | 64 | # no positive images, skip from the average 65 | if qgnd.shape[0] == 0: 66 | aps[i] = float("nan") 67 | prs[i, :] = float("nan") 68 | nempty += 1 69 | continue 70 | 71 | try: 72 | qgndj = np.array(gnd[i]["junk"]) 73 | except KeyError: 74 | qgndj = np.empty(0) 75 | 76 | # sorted positions of positive and junk images (0 based) 77 | pos = np.arange(ranks.shape[0])[np.in1d(ranks[:, i], qgnd)] 78 | junk =
np.arange(ranks.shape[0])[np.in1d(ranks[:, i], qgndj)] 79 | 80 | k = 0 81 | ij = 0 82 | if len(junk): 83 | # decrease positions of positives based on the number of 84 | # junk images appearing before them 85 | ip = 0 86 | while ip < len(pos): 87 | while ij < len(junk) and pos[ip] > junk[ij]: 88 | k += 1 89 | ij += 1 90 | pos[ip] = pos[ip] - k 91 | ip += 1 92 | 93 | # compute ap 94 | ap = compute_ap(pos, len(qgnd)) 95 | map = map + ap 96 | aps[i] = ap 97 | 98 | # compute precision @ k 99 | pos += 1 # get it to 1-based 100 | for j in np.arange(len(kappas)): 101 | kq = min(max(pos), kappas[j]) 102 | prs[i, j] = (pos <= kq).sum() / kq 103 | pr = pr + prs[i, :] 104 | 105 | map = map / (nq - nempty) 106 | pr = pr / (nq - nempty) 107 | return map, aps, pr, prs 108 | 109 | 110 | def compute_map_and_print(dataset, ranks, gnd, kappas=[1, 5, 10]): 111 | # old evaluation protocol 112 | if dataset.startswith("oxford") or dataset.startswith("paris"): 113 | map, aps, _, _ = compute_map(ranks, gnd) 114 | print(">> {}: mAP {:.2f}".format(dataset, 115 | np.around(map * 100, decimals=2))) 116 | 117 | # new evaluation protocol 118 | elif dataset.startswith("roxford") or dataset.startswith("rparis"): 119 | gnd_t = [] 120 | for i in range(len(gnd)): 121 | g = {} 122 | g["ok"] = np.concatenate([gnd[i]["easy"]]) 123 | g["junk"] = np.concatenate([gnd[i]["junk"], gnd[i]["hard"]]) 124 | gnd_t.append(g) 125 | mapE, apsE, mprE, prsE = compute_map(ranks, gnd_t, kappas) 126 | 127 | gnd_t = [] 128 | for i in range(len(gnd)): 129 | g = {} 130 | g["ok"] = np.concatenate([gnd[i]["easy"], gnd[i]["hard"]]) 131 | g["junk"] = np.concatenate([gnd[i]["junk"]]) 132 | gnd_t.append(g) 133 | mapM, apsM, mprM, prsM = compute_map(ranks, gnd_t, kappas) 134 | 135 | gnd_t = [] 136 | for i in range(len(gnd)): 137 | g = {} 138 | g["ok"] = np.concatenate([gnd[i]["hard"]]) 139 | g["junk"] = np.concatenate([gnd[i]["junk"], gnd[i]["easy"]]) 140 | gnd_t.append(g) 141 | mapH, apsH, mprH, prsH = compute_map(ranks, gnd_t, 
kappas) 142 | 143 | print(">> {}: mAP E: {}, M: {}, H: {}".format( 144 | dataset, 145 | np.around(mapE * 100, decimals=2), 146 | np.around(mapM * 100, decimals=2), 147 | np.around(mapH * 100, decimals=2), 148 | )) 149 | print(">> {}: mP@k{} E: {}, M: {}, H: {}".format( 150 | dataset, 151 | kappas, 152 | np.around(mprE * 100, decimals=2), 153 | np.around(mprM * 100, decimals=2), 154 | np.around(mprH * 100, decimals=2), 155 | )) 156 | -------------------------------------------------------------------------------- /knn.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | " knn module, all credits to faiss! " 5 | 6 | import os 7 | import numpy as np 8 | import time 9 | import faiss 10 | from tqdm import tqdm 11 | 12 | 13 | class BaseKNN(object): 14 | """KNN base class""" 15 | def __init__(self, database, method): 16 | if database.dtype != np.float32: 17 | database = database.astype(np.float32) 18 | self.N = len(database) 19 | self.D = database[0].shape[-1] 20 | self.database = database if database.flags['C_CONTIGUOUS'] \ 21 | else np.ascontiguousarray(database) 22 | 23 | def add(self, batch_size=10000): 24 | """Add data into index""" 25 | if self.N <= batch_size: 26 | self.index.add(self.database) 27 | else: 28 | [self.index.add(self.database[i:i+batch_size]) 29 | for i in tqdm(range(0, len(self.database), batch_size), 30 | desc='[index] add')] 31 | 32 | def search(self, queries, k): 33 | """Search 34 | Args: 35 | queries: query vectors 36 | k: get top-k results 37 | Returns: 38 | sims: similarities of k-NN 39 | ids: indexes of k-NN 40 | """ 41 | if not queries.flags['C_CONTIGUOUS']: 42 | queries = np.ascontiguousarray(queries) 43 | if queries.dtype != np.float32: 44 | queries = queries.astype(np.float32) 45 | sims, ids = self.index.search(queries, k) 46 | return sims, ids 47 | 48 | 49 | class KNN(BaseKNN): 50 | """KNN class 51 | Args: 52 | database: feature vectors in 
database 53 | method: distance metric 54 | """ 55 | def __init__(self, database, method): 56 | super().__init__(database, method) 57 | self.index = {'cosine': faiss.IndexFlatIP, 58 | 'euclidean': faiss.IndexFlatL2}[method](self.D) 59 | if os.environ.get('CUDA_VISIBLE_DEVICES'): 60 | self.index = faiss.index_cpu_to_all_gpus(self.index) 61 | self.add() 62 | 63 | 64 | class ANN(BaseKNN): 65 | """Approximate nearest neighbor search class 66 | Args: 67 | database: feature vectors in database 68 | method: distance metric 69 | """ 70 | def __init__(self, database, method, M=128, nbits=8, nlist=316, nprobe=64): 71 | super().__init__(database, method) 72 | self.quantizer = {'cosine': faiss.IndexFlatIP, 73 | 'euclidean': faiss.IndexFlatL2}[method](self.D) 74 | self.index = faiss.IndexIVFPQ(self.quantizer, self.D, nlist, M, nbits) 75 | samples = database[np.random.permutation(np.arange(self.N))[:self.N // 5]] 76 | print("[ANN] train") 77 | self.index.train(samples) 78 | self.add() 79 | self.index.nprobe = nprobe 80 | 81 | -------------------------------------------------------------------------------- /mat2npy.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | " convert .mat file to .npy file " 5 | 6 | import os 7 | import argparse 8 | import numpy as np 9 | import joblib 10 | import h5py 11 | 12 | 13 | def parse_args(): 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('--dataset_name', 16 | type=str, 17 | required=True, 18 | choices=['oxford5k', 'oxford105k', 19 | 'paris6k', 'paris106k'], 20 | help=""" 21 | Name of the dataset 22 | """) 23 | parser.add_argument('--feature_type', 24 | type=str, 25 | required=True, 26 | choices=['resnet', 'siamac'], 27 | help=""" 28 | Feature type 29 | """) 30 | parser.add_argument('--mat_dir', 31 | type=str, 32 | required=True, 33 | help=""" 34 | Directory to the .mat file 35 | """) 36 | args = parser.parse_args() 37 | return args 
38 | 39 | 40 | if __name__ == '__main__': 41 | args = parse_args() 42 | input_file = '{}_{}.mat'.format(args.dataset_name, args.feature_type) 43 | glob_output_file = '{}_{}_glob.npy'.format(args.dataset_name, args.feature_type) 44 | query_dir = os.path.join(args.mat_dir, 'query') 45 | gallery_dir = os.path.join(args.mat_dir, 'gallery') 46 | if not os.path.exists(query_dir): 47 | os.makedirs(query_dir) 48 | if not os.path.exists(gallery_dir): 49 | os.makedirs(gallery_dir) 50 | with h5py.File(os.path.join(args.mat_dir, input_file), 'r') as f: 51 | glob_q = np.array([f[x[0]][:] for x in f['/glob/Q']]) 52 | np.save(os.path.join(args.mat_dir, 'query', 53 | glob_output_file), np.squeeze(glob_q, axis=1)) 54 | glob_g = np.array([f[x[0]][:] for x in f['/glob/V']]) 55 | np.save(os.path.join(args.mat_dir, 'gallery', 56 | glob_output_file), np.squeeze(glob_g, axis=1)) 57 | -------------------------------------------------------------------------------- /rank.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | " rank module " 5 | 6 | import os 7 | import time 8 | import argparse 9 | import pickle 10 | import numpy as np 11 | from tqdm import tqdm 12 | from dataset import Dataset 13 | from knn import KNN 14 | from diffusion import Diffusion 15 | from sklearn import preprocessing 16 | from evaluate import compute_map_and_print 17 | 18 | 19 | def search(): 20 | n_query = len(queries) 21 | diffusion = Diffusion(np.vstack([queries, gallery]), args.cache_dir) 22 | offline = diffusion.get_offline_results(args.truncation_size, args.kd) 23 | features = preprocessing.normalize(offline, norm="l2", axis=1) 24 | scores = features[:n_query] @ features[n_query:].T 25 | ranks = np.argsort(-scores.todense()) 26 | evaluate(ranks) 27 | 28 | 29 | def search_old(gamma=3): 30 | diffusion = Diffusion(gallery, args.cache_dir) 31 | offline = diffusion.get_offline_results(args.truncation_size, args.kd) 32 | 33 
| time0 = time.time() 34 | print('[search] 1) k-NN search') 35 | sims, ids = diffusion.knn.search(queries, args.kq) 36 | sims = sims ** gamma 37 | qr_num = ids.shape[0] 38 | 39 | print('[search] 2) linear combination') 40 | all_scores = np.empty((qr_num, args.truncation_size), dtype=np.float32) 41 | all_ranks = np.empty((qr_num, args.truncation_size), dtype=int) 42 | for i in tqdm(range(qr_num), desc='[search] query'): 43 | scores = sims[i] @ offline[ids[i]] 44 | parts = np.argpartition(-scores, args.truncation_size)[:args.truncation_size] 45 | ranks = np.argsort(-scores[parts]) 46 | all_scores[i] = scores[parts][ranks] 47 | all_ranks[i] = parts[ranks] 48 | print('[search] search costs {:.2f}s'.format(time.time() - time0)) 49 | 50 | # 3) evaluation 51 | evaluate(all_ranks) 52 | 53 | 54 | def evaluate(ranks): 55 | gnd_name = os.path.splitext(os.path.basename(args.gnd_path))[0] 56 | with open(args.gnd_path, 'rb') as f: 57 | gnd = pickle.load(f)['gnd'] 58 | compute_map_and_print(gnd_name.split("_")[-1], ranks.T, gnd) 59 | 60 | 61 | def parse_args(): 62 | parser = argparse.ArgumentParser() 63 | parser.add_argument('--cache_dir', 64 | type=str, 65 | default='./cache', 66 | help=""" 67 | Directory to cache 68 | """) 69 | parser.add_argument('--dataset_name', 70 | type=str, 71 | required=True, 72 | help=""" 73 | Name of the dataset 74 | """) 75 | parser.add_argument('--query_path', 76 | type=str, 77 | required=True, 78 | help=""" 79 | Path to query features 80 | """) 81 | parser.add_argument('--gallery_path', 82 | type=str, 83 | required=True, 84 | help=""" 85 | Path to gallery features 86 | """) 87 | parser.add_argument('--gnd_path', 88 | type=str, 89 | help=""" 90 | Path to ground-truth 91 | """) 92 | parser.add_argument('-n', '--truncation_size', 93 | type=int, 94 | default=1000, 95 | help=""" 96 | Number of images in the truncated gallery 97 | """) 98 | args = parser.parse_args() 99 | args.kq, args.kd = 10, 50 100 | return args 101 | 102 | 103 | if __name__ ==
"__main__": 104 | args = parse_args() 105 | if not os.path.isdir(args.cache_dir): 106 | os.makedirs(args.cache_dir) 107 | dataset = Dataset(args.query_path, args.gallery_path) 108 | queries, gallery = dataset.queries, dataset.gallery 109 | search() 110 | 111 | -------------------------------------------------------------------------------- /slides.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fyang93/diffusion/5f5c6b3dae1bf887c686c0f424869b417d632481/slides.pdf -------------------------------------------------------------------------------- /tmp/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fyang93/diffusion/5f5c6b3dae1bf887c686c0f424869b417d632481/tmp/.gitkeep --------------------------------------------------------------------------------