├── .gitignore ├── README.md ├── get_accuracy_dist.py ├── hnsw.py ├── hnsw_origin.py └── test ├── test_balanced.py ├── test_balanced_accuracy.py ├── test_origin.py └── test_origin_accuracy.py /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | .idea 3 | *.ind 4 | *.hdf5 5 | __pycache__ 6 | bak -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # hnsw-python 2 | 3 | HNSW implemented in Python. 4 | 5 | #### Supported distances: 6 | 7 | | Distance | parameter | Equation | 8 | | ----------------- | --------- | ------------------------------------------------------- | 9 | | L2 (Euclidean) | 'l2' | d = sqrt(sum((Ai-Bi)^2)) | 10 | | Cosine similarity | 'cosine' | d = 1.0 - sum(Ai\*Bi) / sqrt(sum(Ai\*Ai) \* sum(Bi*Bi)) | 11 | 12 | #### examples 13 | 14 | ```python 15 | import time, numpy as np 16 | from progressbar import * 17 | import pickle 18 | from hnsw import HNSW 19 | 20 | dim = 200 21 | num_elements = 10000 22 | 23 | data = np.array(np.float32(np.random.random((num_elements, dim)))) 24 | hnsw = HNSW('cosine', m0=16, ef=128) 25 | widgets = ['Progress: ',Percentage(), ' ', Bar('#'),' ', Timer(), ' ', ETA()] 26 | 27 | # show progressbar 28 | pbar = ProgressBar(widgets=widgets, maxval=len(data)).start() 29 | for i in range(len(data)): 30 | hnsw.add(data[i]) 31 | pbar.update(i + 1) 32 | pbar.finish() 33 | 34 | # save index 35 | with open('glove.ind', 'wb') as f: 36 | pickle.dump(hnsw, f, pickle.HIGHEST_PROTOCOL) 37 | 38 | # load index 39 | fr = open('glove.ind','rb') 40 | hnsw_n = pickle.load(fr) 41 | 42 | add_point_time = time.time() 43 | idx = hnsw_n.search(np.float32(np.random.random((1, 200))), 10) 44 | search_time = time.time() 45 | print("Searchtime: %f" % (search_time - add_point_time)) 46 | ``` 47 | 48 | -------------------------------------------------------------------------------- 
/get_accuracy_dist.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import h5py 4 | import pprint 5 | 6 | f = h5py.File('glove-25-angular.hdf5','r') 7 | distances = f['distances'] 8 | neighbors = f['neighbors'] 9 | test = f['test'] 10 | train = f['train'] 11 | 12 | pprint.pprint(distances[2]) -------------------------------------------------------------------------------- /hnsw.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import pprint 4 | import sys 5 | from heapq import heapify, heappop, heappush, heapreplace, nlargest, nsmallest 6 | from math import log2 7 | from operator import itemgetter 8 | from random import random 9 | 10 | import numpy as np 11 | 12 | 13 | class HNSW(object): 14 | # self._graphs[level][i] contains a {j: dist} dictionary, 15 | # where j is a neighbor of i and dist is distance 16 | 17 | def l2_distance(self, a, b): 18 | return np.linalg.norm(a - b) 19 | 20 | def cosine_distance(self, a, b): 21 | try: 22 | return np.dot(a, b)/(np.linalg.norm(a)*(np.linalg.norm(b))) 23 | except ValueError: 24 | print(a) 25 | print(b) 26 | 27 | 28 | def _distance(self, x, y): 29 | return self.distance_func(x, [y])[0] 30 | 31 | def vectorized_distance_(self, x, ys): 32 | pprint.pprint([self.distance_func(x, y) for y in ys]) 33 | return [self.distance_func(x, y) for y in ys] 34 | 35 | def __init__(self, distance_type, m=5, ef=200, m0=None, heuristic=True, vectorized=False): 36 | self.data = [] 37 | if distance_type == "l2": 38 | # l2 distance 39 | distance_func = self.l2_distance 40 | elif distance_type == "cosine": 41 | # cosine distance 42 | distance_func = self.cosine_distance 43 | else: 44 | raise TypeError('Please check your distance type!') 45 | 46 | self.distance_func = distance_func 47 | 48 | if vectorized: 49 | self.distance = self._distance 50 | self.vectorized_distance = distance_func 51 | else: 52 | 
self.distance = distance_func 53 | self.vectorized_distance = self.vectorized_distance_ 54 | 55 | self._m = m 56 | self._ef = ef 57 | self._m0 = 2 * m if m0 is None else m0 58 | self._level_mult = 1 / log2(m) 59 | self._graphs = [] 60 | self._enter_point = None 61 | 62 | self._select = ( 63 | self._select_heuristic if heuristic else self._select_naive) 64 | 65 | def add(self, elem, ef=None): 66 | 67 | if ef is None: 68 | ef = self._ef 69 | 70 | distance = self.distance 71 | data = self.data 72 | graphs = self._graphs 73 | point = self._enter_point 74 | m = self._m 75 | 76 | # level at which the element will be inserted 77 | level = int(-log2(random()) * self._level_mult) + 1 78 | # print("level: %d" % level) 79 | 80 | # elem will be at data[idx] 81 | idx = len(data) 82 | data.append(elem) 83 | 84 | if point is not None: # the HNSW is not empty, we have an entry point 85 | dist = distance(elem, data[point]) 86 | # for all levels in which we dont have to insert elem, 87 | # we search for the closest neighbor 88 | for layer in reversed(graphs[level:]): 89 | point, dist = self._search_graph_ef1(elem, point, dist, layer) 90 | # at these levels we have to insert elem; ep is a heap of entry points. 
91 | ep = [(-dist, point)] 92 | # pprint.pprint(ep) 93 | layer0 = graphs[0] 94 | for layer in reversed(graphs[:level]): 95 | level_m = m if layer is not layer0 else self._m0 96 | # navigate the graph and update ep with the closest 97 | # nodes we find 98 | ep = self._search_graph(elem, ep, layer, ef) 99 | # insert in g[idx] the best neighbors 100 | layer[idx] = layer_idx = {} 101 | self._select(layer_idx, ep, level_m, layer, heap=True) 102 | # assert len(layer_idx) <= level_m 103 | # insert backlinks to the new node 104 | for j, dist in layer_idx.items(): 105 | self._select(layer[j], (idx, dist), level_m, layer) 106 | # assert len(g[j]) <= level_m 107 | # assert all(e in g for _, e in ep) 108 | for i in range(len(graphs), level): 109 | # for all new levels, we create an empty graph 110 | graphs.append({idx: {}}) 111 | self._enter_point = idx 112 | 113 | def balanced_add(self, elem, ef=None): 114 | if ef is None: 115 | ef = self._ef 116 | 117 | distance = self.distance 118 | data = self.data 119 | graphs = self._graphs 120 | point = self._enter_point 121 | m = self._m 122 | m0 = self._m0 123 | 124 | idx = len(data) 125 | data.append(elem) 126 | 127 | if point is not None: 128 | dist = distance(elem, data[point]) 129 | pd = [(point, dist)] 130 | # pprint.pprint(len(graphs)) 131 | for layer in reversed(graphs[1:]): 132 | point, dist = self._search_graph_ef1(elem, point, dist, layer) 133 | pd.append((point, dist)) 134 | for level, layer in enumerate(graphs): 135 | # print('\n') 136 | # pprint.pprint(layer) 137 | level_m = m0 if level == 0 else m 138 | candidates = self._search_graph( 139 | elem, [(-dist, point)], layer, ef) 140 | layer[idx] = layer_idx = {} 141 | self._select(layer_idx, candidates, level_m, layer, heap=True) 142 | # add reverse edges 143 | for j, dist in layer_idx.items(): 144 | self._select(layer[j], [idx, dist], level_m, layer) 145 | assert len(layer[j]) <= level_m 146 | if len(layer_idx) < level_m: 147 | return 148 | if level < len(graphs) - 1: 149 
| if any(p in graphs[level + 1] for p in layer_idx): 150 | return 151 | point, dist = pd.pop() 152 | graphs.append({idx: {}}) 153 | self._enter_point = idx 154 | 155 | def search(self, q, k=None, ef=None): 156 | """Find the k points closest to q.""" 157 | 158 | distance = self.distance 159 | graphs = self._graphs 160 | point = self._enter_point 161 | 162 | if ef is None: 163 | ef = self._ef 164 | 165 | if point is None: 166 | raise ValueError("Empty graph") 167 | 168 | dist = distance(q, self.data[point]) 169 | # look for the closest neighbor from the top to the 2nd level 170 | for layer in reversed(graphs[1:]): 171 | point, dist = self._search_graph_ef1(q, point, dist, layer) 172 | # look for ef neighbors in the bottom level 173 | ep = self._search_graph(q, [(-dist, point)], graphs[0], ef) 174 | 175 | if k is not None: 176 | ep = nlargest(k, ep) 177 | else: 178 | ep.sort(reverse=True) 179 | 180 | return [(idx, -md) for md, idx in ep] 181 | 182 | def _search_graph_ef1(self, q, entry, dist, layer): 183 | """Equivalent to _search_graph when ef=1.""" 184 | 185 | vectorized_distance = self.vectorized_distance 186 | data = self.data 187 | 188 | best = entry 189 | best_dist = dist 190 | candidates = [(dist, entry)] 191 | visited = set([entry]) 192 | 193 | while candidates: 194 | dist, c = heappop(candidates) 195 | if dist > best_dist: 196 | break 197 | edges = [e for e in layer[c] if e not in visited] 198 | visited.update(edges) 199 | dists = vectorized_distance(q, [data[e] for e in edges]) 200 | for e, dist in zip(edges, dists): 201 | if dist < best_dist: 202 | best = e 203 | best_dist = dist 204 | heappush(candidates, (dist, e)) 205 | # break 206 | 207 | return best, best_dist 208 | 209 | def _search_graph(self, q, ep, layer, ef): 210 | 211 | vectorized_distance = self.vectorized_distance 212 | data = self.data 213 | 214 | candidates = [(-mdist, p) for mdist, p in ep] 215 | heapify(candidates) 216 | visited = set(p for _, p in ep) 217 | 218 | while candidates: 219 | 
dist, c = heappop(candidates) 220 | mref = ep[0][0] 221 | if dist > -mref: 222 | break 223 | # pprint.pprint(layer[c]) 224 | edges = [e for e in layer[c] if e not in visited] 225 | visited.update(edges) 226 | dists = vectorized_distance(q, [data[e] for e in edges]) 227 | for e, dist in zip(edges, dists): 228 | mdist = -dist 229 | if len(ep) < ef: 230 | heappush(candidates, (dist, e)) 231 | heappush(ep, (mdist, e)) 232 | mref = ep[0][0] 233 | elif mdist > mref: 234 | heappush(candidates, (dist, e)) 235 | heapreplace(ep, (mdist, e)) 236 | mref = ep[0][0] 237 | 238 | return ep 239 | 240 | def _select_naive(self, d, to_insert, m, layer, heap=False): 241 | 242 | if not heap: 243 | idx, dist = to_insert 244 | assert idx not in d 245 | if len(d) < m: 246 | d[idx] = dist 247 | else: 248 | max_idx, max_dist = max(d.items(), key=itemgetter(1)) 249 | if dist < max_dist: 250 | del d[max_idx] 251 | d[idx] = dist 252 | return 253 | 254 | assert not any(idx in d for _, idx in to_insert) 255 | to_insert = nlargest(m, to_insert) # smallest m distances 256 | unchecked = m - len(d) 257 | assert 0 <= unchecked <= m 258 | to_insert, checked_ins = to_insert[:unchecked], to_insert[unchecked:] 259 | to_check = len(checked_ins) 260 | if to_check > 0: 261 | checked_del = nlargest(to_check, d.items(), key=itemgetter(1)) 262 | else: 263 | checked_del = [] 264 | for md, idx in to_insert: 265 | d[idx] = -md 266 | zipped = zip(checked_ins, checked_del) 267 | for (md_new, idx_new), (idx_old, d_old) in zipped: 268 | if d_old <= -md_new: 269 | break 270 | del d[idx_old] 271 | d[idx_new] = -md_new 272 | assert len(d) == m 273 | 274 | def _select_heuristic(self, d, to_insert, m, g, heap=False): 275 | 276 | nb_dicts = [g[idx] for idx in d] 277 | 278 | def prioritize(idx, dist): 279 | return any(nd.get(idx, float('inf')) < dist for nd in nb_dicts), dist, idx 280 | 281 | if not heap: 282 | idx, dist = to_insert 283 | to_insert = [prioritize(idx, dist)] 284 | else: 285 | to_insert = nsmallest(m, 
(prioritize(idx, -mdist) 286 | for mdist, idx in to_insert)) 287 | 288 | assert len(to_insert) > 0 289 | assert not any(idx in d for _, _, idx in to_insert) 290 | 291 | unchecked = m - len(d) 292 | assert 0 <= unchecked <= m 293 | to_insert, checked_ins = to_insert[:unchecked], to_insert[unchecked:] 294 | to_check = len(checked_ins) 295 | if to_check > 0: 296 | checked_del = nlargest(to_check, (prioritize(idx, dist) 297 | for idx, dist in d.items())) 298 | else: 299 | checked_del = [] 300 | for _, dist, idx in to_insert: 301 | d[idx] = dist 302 | zipped = zip(checked_ins, checked_del) 303 | for (p_new, d_new, idx_new), (p_old, d_old, idx_old) in zipped: 304 | if (p_old, d_old) <= (p_new, d_new): 305 | break 306 | del d[idx_old] 307 | d[idx_new] = d_new 308 | assert len(d) == m 309 | 310 | def __getitem__(self, idx): 311 | 312 | for g in self._graphs: 313 | try: 314 | yield from g[idx].items() 315 | except KeyError: 316 | return 317 | 318 | 319 | if __name__ == "__main__": 320 | dim = 25 321 | num_elements = 1000 322 | 323 | import h5py 324 | import time 325 | from progressbar import * 326 | import pickle 327 | 328 | f = h5py.File('glove-25-angular.hdf5','r') 329 | distances = f['distances'] 330 | neighbors = f['neighbors'] 331 | test = f['test'] 332 | train = f['train'] 333 | pprint.pprint(list(f.keys())) 334 | pprint.pprint(train.shape) 335 | # pprint.pprint() 336 | 337 | # Generating sample data 338 | data = np.array(np.float32(np.random.random((num_elements, dim)))) 339 | data_labels = np.arange(num_elements) 340 | 341 | 342 | hnsw = HNSW('cosine', m0=5, ef=10) 343 | 344 | 345 | # widgets = ['Progress: ',Percentage(), ' ', Bar('#'),' ', Timer(), 346 | # ' ', ETA()] 347 | # pbar = ProgressBar(widgets=widgets, maxval=train.shape[0]).start() 348 | # for i in range(train.shape[0]): 349 | # # if i == 1000: 350 | # # break 351 | # hnsw.balanced_add(train[i]) 352 | # pbar.update(i + 1) 353 | # pbar.finish() 354 | 355 | for index, i in enumerate(data): 356 | if index % 
1000 == 0: 357 | pprint.pprint('train No.%d' % index) 358 | # hnsw.balanced_add(i) 359 | hnsw.add(i) 360 | 361 | # with open('glove-25-angular-balanced-128.ind', 'wb') as f: 362 | # picklestring = pickle.dump(hnsw, f, pickle.HIGHEST_PROTOCOL) 363 | 364 | # add_point_time = time.time() 365 | # idx = hnsw.search(np.float32(np.random.random((1, 25))), 1) 366 | # search_time = time.time() 367 | # pprint.pprint(idx) 368 | # pprint.pprint("add point time: %f" % (add_point_time - time_start)) 369 | # pprint.pprint("searchtime: %f" % (search_time - add_point_time)) 370 | # print('\n') 371 | # # pprint.pprint(hnsw._graphs) 372 | # for n in hnsw._graphs: 373 | # pprint.pprint(len(n)) 374 | # pprint.pprint(len(hnsw._graphs)) 375 | # print(hnsw.data) 376 | 377 | # for index, i in enumerate(data): 378 | # idx = hnsw.search(i, 1) 379 | # pprint.pprint(idx[0][0]) 380 | # pprint.pprint(i) 381 | # pprint.pprint(hnsw.data[idx[0][0]]) 382 | 383 | # pprint.pprint('------------------------------') 384 | # pprint.pprint(hnsw.data) 385 | # pprint.pprint('------------------------------') 386 | # pprint.pprint(data) 387 | # pprint.pprint('------------------------------') 388 | # pprint.pprint(hnsw._graphs) 389 | # pprint.pprint(len(hnsw._graphs)) -------------------------------------------------------------------------------- /hnsw_origin.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import pprint 4 | import sys 5 | from heapq import heapify, heappop, heappush, heapreplace, nlargest, nsmallest 6 | from math import log2 7 | from operator import itemgetter 8 | from random import random 9 | 10 | import numpy as np 11 | 12 | 13 | class HNSW(object): 14 | # self._graphs[level][i] contains a {j: dist} dictionary, 15 | # where j is a neighbor of i and dist is distance 16 | 17 | def l2_distance(self, a, b): 18 | return np.linalg.norm(a - b) 19 | 20 | def cosine_distance(self, a, b): 21 | try: 22 | return np.dot(a, 
b)/(np.linalg.norm(a)*(np.linalg.norm(b))) 23 | except ValueError: 24 | print(a) 25 | print(b) 26 | 27 | 28 | def _distance(self, x, y): 29 | return self.distance_func(x, [y])[0] 30 | 31 | def vectorized_distance_(self, x, ys): 32 | return [self.distance_func(x, y) for y in ys] 33 | 34 | def __init__(self, distance_type, m=5, ef=200, m0=None, heuristic=True, vectorized=False): 35 | self.data = [] 36 | if distance_type == "l2": 37 | # l2 distance 38 | distance_func = self.l2_distance 39 | elif distance_type == "cosine": 40 | # cosine distance 41 | distance_func = self.cosine_distance 42 | else: 43 | raise TypeError('Please check your distance type!') 44 | 45 | self.distance_func = distance_func 46 | 47 | if vectorized: 48 | # def distance_1(x, y): 49 | # return distance_func(x, [y])[0] 50 | 51 | self.distance = self._distance 52 | self.vectorized_distance = distance_func 53 | else: 54 | self.distance = distance_func 55 | 56 | # def vectorized_distance(x, ys): 57 | # return [distance_func(x, y) for y in ys] 58 | 59 | self.vectorized_distance = self.vectorized_distance_ 60 | 61 | self._m = m 62 | self._ef = ef 63 | self._m0 = 2 * m if m0 is None else m0 64 | self._level_mult = 1 / log2(m) 65 | self._graphs = [] 66 | self._enter_point = None 67 | 68 | self._select = ( 69 | self._select_heuristic if heuristic else self._select_naive) 70 | 71 | def add(self, elem, ef=None): 72 | 73 | if ef is None: 74 | ef = self._ef 75 | 76 | distance = self.distance 77 | data = self.data 78 | graphs = self._graphs 79 | point = self._enter_point 80 | m = self._m 81 | 82 | # level at which the element will be inserted 83 | level = int(-log2(random()) * self._level_mult) + 1 84 | # print("level: %d" % level) 85 | 86 | # elem will be at data[idx] 87 | idx = len(data) 88 | data.append(elem) 89 | 90 | if point is not None: # the HNSW is not empty, we have an entry point 91 | dist = distance(elem, data[point]) 92 | # for all levels in which we dont have to insert elem, 93 | # we search for 
the closest neighbor 94 | for layer in reversed(graphs[level:]): 95 | point, dist = self._search_graph_ef1(elem, point, dist, layer) 96 | # at these levels we have to insert elem; ep is a heap of entry points. 97 | ep = [(-dist, point)] 98 | layer0 = graphs[0] 99 | for layer in reversed(graphs[:level]): 100 | level_m = m if layer is not layer0 else self._m0 101 | # navigate the graph and update ep with the closest 102 | # nodes we find 103 | ep = self._search_graph(elem, ep, layer, ef) 104 | # insert in g[idx] the best neighbors 105 | layer[idx] = layer_idx = {} 106 | self._select(layer_idx, ep, level_m, layer, heap=True) 107 | # assert len(layer_idx) <= level_m 108 | # insert backlinks to the new node 109 | for j, dist in layer_idx.items(): 110 | self._select(layer[j], (idx, dist), level_m, layer) 111 | # assert len(g[j]) <= level_m 112 | # assert all(e in g for _, e in ep) 113 | for i in range(len(graphs), level): 114 | # for all new levels, we create an empty graph 115 | graphs.append({idx: {}}) 116 | self._enter_point = idx 117 | 118 | def balanced_add(self, elem, ef=None): 119 | if ef is None: 120 | ef = self._ef 121 | 122 | distance = self.distance 123 | data = self.data 124 | graphs = self._graphs 125 | point = self._enter_point 126 | m = self._m 127 | m0 = self._m0 128 | 129 | idx = len(data) 130 | data.append(elem) 131 | 132 | if point is not None: 133 | dist = distance(elem, data[point]) 134 | pd = [(point, dist)] 135 | # pprint.pprint(len(graphs)) 136 | for layer in reversed(graphs[1:]): 137 | point, dist = self._search_graph_ef1(elem, point, dist, layer) 138 | pd.append((point, dist)) 139 | for level, layer in enumerate(graphs): 140 | # print('\n') 141 | # pprint.pprint(layer) 142 | level_m = m0 if level == 0 else m 143 | candidates = self._search_graph( 144 | elem, [(-dist, point)], layer, ef) 145 | layer[idx] = layer_idx = {} 146 | self._select(layer_idx, candidates, level_m, layer, heap=True) 147 | # add reverse edges 148 | for j, dist in 
layer_idx.items(): 149 | self._select(layer[j], [idx, dist], level_m, layer) 150 | assert len(layer[j]) <= level_m 151 | if len(layer_idx) < level_m: 152 | return 153 | if level < len(graphs) - 1: 154 | if any(p in graphs[level + 1] for p in layer_idx): 155 | return 156 | point, dist = pd.pop() 157 | graphs.append({idx: {}}) 158 | self._enter_point = idx 159 | 160 | def search(self, q, k=None, ef=None): 161 | 162 | distance = self.distance 163 | graphs = self._graphs 164 | point = self._enter_point 165 | 166 | if ef is None: 167 | ef = self._ef 168 | 169 | if point is None: 170 | raise ValueError("Empty graph") 171 | 172 | dist = distance(q, self.data[point]) 173 | # look for the closest neighbor from the top to the 2nd level 174 | for layer in reversed(graphs[1:]): 175 | point, dist = self._search_graph_ef1(q, point, dist, layer) 176 | # look for ef neighbors in the bottom level 177 | ep = self._search_graph(q, [(-dist, point)], graphs[0], ef) 178 | 179 | if k is not None: 180 | ep = nlargest(k, ep) 181 | else: 182 | ep.sort(reverse=True) 183 | 184 | return [(idx, -md) for md, idx in ep] 185 | 186 | def _search_graph_ef1(self, q, entry, dist, layer): 187 | 188 | vectorized_distance = self.vectorized_distance 189 | data = self.data 190 | 191 | best = entry 192 | best_dist = dist 193 | candidates = [(dist, entry)] 194 | visited = set([entry]) 195 | 196 | while candidates: 197 | dist, c = heappop(candidates) 198 | if dist > best_dist: 199 | break 200 | edges = [e for e in layer[c] if e not in visited] 201 | visited.update(edges) 202 | dists = vectorized_distance(q, [data[e] for e in edges]) 203 | for e, dist in zip(edges, dists): 204 | if dist < best_dist: 205 | best = e 206 | best_dist = dist 207 | heappush(candidates, (dist, e)) 208 | # break 209 | 210 | return best, best_dist 211 | 212 | def _search_graph(self, q, ep, layer, ef): 213 | 214 | vectorized_distance = self.vectorized_distance 215 | data = self.data 216 | 217 | candidates = [(-mdist, p) for mdist, p in 
ep] 218 | heapify(candidates) 219 | visited = set(p for _, p in ep) 220 | 221 | while candidates: 222 | dist, c = heappop(candidates) 223 | mref = ep[0][0] 224 | if dist > -mref: 225 | break 226 | 227 | edges = [e for e in layer[c] if e not in visited] 228 | visited.update(edges) 229 | dists = vectorized_distance(q, [data[e] for e in edges]) 230 | for e, dist in zip(edges, dists): 231 | mdist = -dist 232 | if len(ep) < ef: 233 | heappush(candidates, (dist, e)) 234 | heappush(ep, (mdist, e)) 235 | mref = ep[0][0] 236 | elif mdist > mref: 237 | heappush(candidates, (dist, e)) 238 | heapreplace(ep, (mdist, e)) 239 | mref = ep[0][0] 240 | 241 | return ep 242 | 243 | def _select_naive(self, d, to_insert, m, layer, heap=False): 244 | 245 | if not heap: # shortcut when we've got only one thing to insert 246 | idx, dist = to_insert 247 | assert idx not in d 248 | if len(d) < m: 249 | d[idx] = dist 250 | else: 251 | max_idx, max_dist = max(d.items(), key=itemgetter(1)) 252 | if dist < max_dist: 253 | del d[max_idx] 254 | d[idx] = dist 255 | return 256 | 257 | # so we have more than one item to insert, it's a bit more tricky 258 | assert not any(idx in d for _, idx in to_insert) 259 | to_insert = nlargest(m, to_insert) # smallest m distances 260 | unchecked = m - len(d) 261 | assert 0 <= unchecked <= m 262 | to_insert, checked_ins = to_insert[:unchecked], to_insert[unchecked:] 263 | to_check = len(checked_ins) 264 | if to_check > 0: 265 | checked_del = nlargest(to_check, d.items(), key=itemgetter(1)) 266 | else: 267 | checked_del = [] 268 | for md, idx in to_insert: 269 | d[idx] = -md 270 | zipped = zip(checked_ins, checked_del) 271 | for (md_new, idx_new), (idx_old, d_old) in zipped: 272 | if d_old <= -md_new: 273 | break 274 | del d[idx_old] 275 | d[idx_new] = -md_new 276 | assert len(d) == m 277 | 278 | def _select_heuristic(self, d, to_insert, m, g, heap=False): 279 | 280 | nb_dicts = [g[idx] for idx in d] 281 | 282 | def prioritize(idx, dist): 283 | return 
any(nd.get(idx, float('inf')) < dist for nd in nb_dicts), dist, idx 284 | 285 | if not heap: 286 | idx, dist = to_insert 287 | to_insert = [prioritize(idx, dist)] 288 | else: 289 | to_insert = nsmallest(m, (prioritize(idx, -mdist) 290 | for mdist, idx in to_insert)) 291 | 292 | assert len(to_insert) > 0 293 | assert not any(idx in d for _, _, idx in to_insert) 294 | 295 | unchecked = m - len(d) 296 | assert 0 <= unchecked <= m 297 | to_insert, checked_ins = to_insert[:unchecked], to_insert[unchecked:] 298 | to_check = len(checked_ins) 299 | if to_check > 0: 300 | checked_del = nlargest(to_check, (prioritize(idx, dist) 301 | for idx, dist in d.items())) 302 | else: 303 | checked_del = [] 304 | for _, dist, idx in to_insert: 305 | d[idx] = dist 306 | zipped = zip(checked_ins, checked_del) 307 | for (p_new, d_new, idx_new), (p_old, d_old, idx_old) in zipped: 308 | if (p_old, d_old) <= (p_new, d_new): 309 | break 310 | del d[idx_old] 311 | d[idx_new] = d_new 312 | assert len(d) == m 313 | 314 | def __getitem__(self, idx): 315 | 316 | for g in self._graphs: 317 | try: 318 | yield from g[idx].items() 319 | except KeyError: 320 | return 321 | 322 | 323 | if __name__ == "__main__": 324 | # dim = 200 325 | # num_elements = 100 326 | 327 | import h5py 328 | import time 329 | from progressbar import * 330 | import pickle 331 | 332 | f = h5py.File('glove-25-angular.hdf5','r') 333 | distances = f['distances'] 334 | neighbors = f['neighbors'] 335 | test = f['test'] 336 | train = f['train'] 337 | train_len = train.shape[0] 338 | pprint.pprint(list(f.keys())) 339 | pprint.pprint(train.shape) 340 | # pprint.pprint() 341 | 342 | # # Generating sample data 343 | # data = np.array(np.float32(np.random.random((num_elements, dim)))) 344 | # data_labels = np.arange(num_elements) 345 | 346 | 347 | hnsw = HNSW('cosine', m0=16, ef=128) 348 | 349 | widgets = ['Progress: ',Percentage(), ' ', Bar('#'),' ', Timer(), 350 | ' ', ETA()] 351 | pbar = ProgressBar(widgets=widgets, 
maxval=train_len).start() 352 | for i in range(train_len): 353 | # if i == 1000: 354 | # break 355 | hnsw.add(train[i]) 356 | pbar.update(i + 1) 357 | pbar.finish() 358 | 359 | with open('glove-25-angular-origin-128.ind', 'wb') as f: 360 | picklestring = pickle.dump(hnsw, f, pickle.HIGHEST_PROTOCOL) 361 | 362 | add_point_time = time.time() 363 | idx = hnsw.search(np.float32(np.random.random((1, 25))), 10) 364 | search_time = time.time() 365 | # pprint.pprint(idx) 366 | # pprint.pprint("add point time: %f" % (add_point_time - time_start)) 367 | pprint.pprint("searchtime: %f" % (search_time - add_point_time)) 368 | # print('\n') 369 | # # pprint.pprint(hnsw._graphs) 370 | # for n in hnsw._graphs: 371 | # pprint.pprint(len(n)) 372 | # pprint.pprint(len(hnsw._graphs)) 373 | # print(hnsw.data) 374 | -------------------------------------------------------------------------------- /test/test_balanced.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import pickle 4 | import pprint 5 | import time 6 | 7 | import h5py 8 | import numpy as np 9 | from pandas import DataFrame 10 | import sys 11 | 12 | sys.path.append('/Users/Ryan/code/python/hnsw-python') 13 | from hnsw import HNSW 14 | 15 | fr = open('glove-25-angular-balanced.ind','rb') 16 | hnsw_n = pickle.load(fr) 17 | 18 | f = h5py.File('glove-25-angular.hdf5','r') 19 | distances = f['distances'] 20 | neighbors = f['neighbors'] 21 | test = f['test'] 22 | train = f['train'] 23 | 24 | variance_record = [] 25 | mean_record = [] 26 | 27 | for j in range(20): 28 | print(j) 29 | time_record = [] 30 | for index, i in enumerate(test): 31 | search_begin = time.time() 32 | idx = hnsw_n.search(i, 10) 33 | # pprint.pprint(idx) 34 | search_end = time.time() 35 | search_time = search_end - search_begin 36 | time_record.append(search_time * 1000) 37 | 38 | variance_n = np.var(time_record) 39 | mean_n = np.mean(time_record) 40 | pprint.pprint('variance: %f' % 
variance_n) 41 | pprint.pprint('mean: %f' % mean_n) 42 | variance_record.append(variance_n) 43 | mean_record.append(mean_n) 44 | 45 | data = { 46 | 'mean_balanced': mean_record, 47 | 'variance_balanced': variance_record 48 | } 49 | 50 | df = DataFrame(data) 51 | df.to_excel('variance_result_balanced_8.xlsx') 52 | -------------------------------------------------------------------------------- /test/test_balanced_accuracy.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import pickle 4 | import pprint 5 | import sys 6 | import time 7 | 8 | import h5py 9 | import numpy as np 10 | 11 | sys.path.append('/Users/Ryan/code/python/hnsw-python') 12 | from hnsw import HNSW 13 | 14 | 15 | fr = open('glove-25-angular-balanced-128.ind','rb') 16 | hnsw_n = pickle.load(fr) 17 | 18 | # add_point_time = time.time() 19 | # idx = hnsw_n.search(np.float32(np.random.random((1, 25))), 10) 20 | # search_time = time.time() 21 | # pprint.pprint(idx) 22 | # pprint.pprint("searchtime: %f" % (search_time - add_point_time)) 23 | 24 | f = h5py.File('glove-25-angular.hdf5','r') 25 | distances = f['distances'] 26 | neighbors = f['neighbors'] 27 | test = f['test'] 28 | train = f['train'] 29 | 30 | pprint.pprint(len(hnsw_n._graphs)) 31 | 32 | for index, i in enumerate(train[0:10]): 33 | idx = hnsw_n.search(i, 5) 34 | pprint.pprint(idx) 35 | -------------------------------------------------------------------------------- /test/test_origin.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import pickle 4 | import pprint 5 | import sys 6 | import time 7 | 8 | import h5py 9 | import numpy as np 10 | from pandas import DataFrame 11 | 12 | sys.path.append('/Users/Ryan/code/python/hnsw-python') 13 | from hnsw import HNSW 14 | 15 | 16 | fr = open('glove-25-angular-origin-128.ind','rb') 17 | hnsw_n = pickle.load(fr) 18 | 19 | # add_point_time = time.time() 20 | # idx 
= hnsw_n.search(np.float32(np.random.random((1, 25))), 10) 21 | # search_time = time.time() 22 | # pprint.pprint(idx) 23 | # pprint.pprint("searchtime: %f" % (search_time - add_point_time)) 24 | 25 | f = h5py.File('glove-25-angular.hdf5','r') 26 | distances = f['distances'] 27 | neighbors = f['neighbors'] 28 | test = f['test'] 29 | train = f['train'] 30 | 31 | variance_record = [] 32 | mean_record = [] 33 | 34 | for j in range(20): 35 | print(j) 36 | time_record = [] 37 | for index, i in enumerate(test): 38 | search_begin = time.time() 39 | idx = hnsw_n.search(i, 10) 40 | # pprint.pprint(idx) 41 | search_end = time.time() 42 | search_time = search_end - search_begin 43 | time_record.append(search_time * 1000) 44 | 45 | variance_n = np.var(time_record) 46 | mean_n = np.mean(time_record) 47 | pprint.pprint('variance: %f' % variance_n) 48 | pprint.pprint('mean: %f' % mean_n) 49 | variance_record.append(variance_n) 50 | mean_record.append(mean_n) 51 | 52 | data = { 53 | 'mean_origin': mean_record, 54 | 'variance_origin': variance_record 55 | } 56 | 57 | df = DataFrame(data) 58 | df.to_excel('variance_result_origin_8.xlsx') 59 | -------------------------------------------------------------------------------- /test/test_origin_accuracy.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import pickle 4 | import pprint 5 | import sys 6 | import time 7 | 8 | import h5py 9 | import numpy as np 10 | 11 | sys.path.append('/Users/Ryan/code/python/hnsw-python') 12 | from hnsw import HNSW 13 | 14 | 15 | fr = open('glove-25-angular-origin-128.ind','rb') 16 | hnsw_n = pickle.load(fr) 17 | 18 | # add_point_time = time.time() 19 | # idx = hnsw_n.search(np.float32(np.random.random((1, 25))), 10) 20 | # search_time = time.time() 21 | # pprint.pprint(idx) 22 | # pprint.pprint("searchtime: %f" % (search_time - add_point_time)) 23 | 24 | f = h5py.File('glove-25-angular.hdf5','r') 25 | distances = f['distances'] 26 | 
neighbors = f['neighbors'] 27 | test = f['test'] 28 | train = f['train'] 29 | 30 | pprint.pprint(len(hnsw_n._graphs)) 31 | 32 | for index, i in enumerate(train[0:10]): 33 | idx = hnsw_n.search(i, 5) 34 | pprint.pprint(idx) 35 | --------------------------------------------------------------------------------