├── .gitignore ├── README.md ├── indexes ├── _base.py ├── annoy.py ├── flat.py ├── gann.py ├── hnsw.py ├── pq.py ├── sq.py └── vamana.py └── tutorials ├── .DS_Store ├── 2023-03-02_data_science_dojo_00_ann_algorithms.ipynb ├── 2023-03-02_data_science_dojo_01_semantic_search.ipynb └── 2023-03-02_data_science_dojo_02_reverse_image_search.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | **.ipynb_checkpoints 2 | tutorials/*.zip 3 | tutorials/*/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # vector-search 2 | -------------------------------------------------------------------------------- /indexes/_base.py: -------------------------------------------------------------------------------- 1 | import queue 2 | 3 | import numpy as np 4 | 5 | 6 | class immarray(np.ndarray): 7 | """Immutable `array` class. 8 | Immediately sets `writeable` to false and adds a hash function. 9 | """ 10 | def __new__(cls, *args, **kwargs): 11 | return super().__new__(cls, *args, **kwargs) 12 | 13 | def __init__(self, *args, **kwargs): 14 | super().__init__(*args, **kwargs) 15 | self.flags.writeable = False 16 | 17 | def __hash__(self): 18 | return id(self) 19 | 20 | 21 | class _BaseIndex: 22 | 23 | def __init__(self): 24 | self._index = None 25 | 26 | def create(self, dataset): 27 | """Index builder.""" 28 | raise NotImplementedError() 29 | 30 | def insert(self, vector): 31 | """Insert a single vector.""" 32 | raise NotImplementedError() 33 | 34 | def search(self, vector, nq=10): 35 | """Naive (flat) search.""" 36 | nns = queue.PriorityQueue() # should probably use heapq 37 | for (n, v) in enumerate(self._index): 38 | d = -np.linalg.norm(v - vector) 39 | if nns.qsize() == 0 or d > nns.queue[0][0]: 40 | nns.put((d, n)) 41 | if nns.qsize() > nq: 42 | nns.get() 43 | out = [] 44 | for n in range(nq): 45 | if nns.empty(): 46 | break 47 | out.insert(0, nns.get()) 48 | return out 49 | 50 | @property 51 | def index(self): 52 | if self._index: 53 | return self._index 54 | raise ValueError("Call create() first") 55 | -------------------------------------------------------------------------------- /indexes/annoy.py: -------------------------------------------------------------------------------- 1 | """ 2 | annoy.py: Approximate Nearest Neighbors Oh Yeah. 3 | """ 4 | 5 | from timeit import default_timer 6 | from typing import List, Optional 7 | import random 8 | 9 | import numpy as np 10 | 11 | from ._base import _BaseIndex 12 | from ._base import immarray 13 | 14 | 15 | class Node(object): 16 | """Initialize with a set of vectors, then call `split()`. 17 | """ 18 | 19 | def __init__(self, ref: np.ndarray, vecs: List[np.ndarray]): 20 | self._ref = ref 21 | self._vecs = vecs 22 | self._left = None 23 | self._right = None 24 | 25 | @property 26 | def ref(self) -> Optional[np.ndarray]: 27 | """Reference point in n-d hyperspace. Evaluates to `False` if root node. 28 | """ 29 | return self._ref 30 | 31 | @property 32 | def vecs(self) -> List[np.ndarray]: 33 | """Vectors for this leaf node. Evaluates to `False` if not a leaf. 34 | """ 35 | return self._vecs 36 | 37 | @property 38 | def left(self) -> Optional[object]: 39 | """Left node. 40 | """ 41 | return self._left 42 | 43 | @property 44 | def right(self) -> Optional[object]: 45 | """Right node. 
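Like `left`, this is populated by a successful `split()` and remains `None` for leaf nodes.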
46 | """ 47 | return self._right 48 | 49 | def split(self, K: int, imb: float) -> bool: 50 | 51 | # stopping condition: maximum # of vectors for a leaf node 52 | if len(self._vecs) <= K: 53 | return False 54 | 55 | # continue for a maximum of 5 iterations 56 | for n in range(5): 57 | left_vecs = [] 58 | right_vecs = [] 59 | 60 | # take two random indexes and set as left and right halves 61 | left_ref = self._vecs.pop(np.random.randint(len(self._vecs))) 62 | right_ref = self._vecs.pop(np.random.randint(len(self._vecs))) 63 | 64 | # split vectors into halves 65 | for vec in self._vecs: 66 | dist_l = np.linalg.norm(vec - left_ref) 67 | dist_r = np.linalg.norm(vec - right_ref) 68 | if dist_l < dist_r: 69 | left_vecs.append(vec) 70 | else: 71 | right_vecs.append(vec) 72 | 73 | # check to make sure that the tree is mostly balanced 74 | r = len(left_vecs) / len(self._vecs) 75 | if r < imb and r > (1 - imb): 76 | self._left = Node(left_ref, left_vecs) 77 | self._right = Node(right_ref, right_vecs) 78 | return True 79 | 80 | # redo tree build process if imbalance is high 81 | self._vecs.append(left_ref) 82 | self._vecs.append(right_ref) 83 | 84 | return False 85 | 86 | 87 | def _select_nearby(node: Node, q: np.ndarray, thresh: int = 0): 88 | """Functions identically to _is_query_in_left_half, but can return both. 89 | """ 90 | if not node.left or not node.right: 91 | return () 92 | dist_l = np.linalg.norm(q - node.left.ref) 93 | dist_r = np.linalg.norm(q - node.right.ref) 94 | if np.abs(dist_l - dist_r) < thresh: 95 | return (node.left, node.right) 96 | if dist_l < dist_r: 97 | return (node.left,) 98 | return (node.right,) 99 | 100 | 101 | def _build_tree(node, K: int, imb: float): 102 | """Recurses on left and right halves to build a tree. 103 | """ 104 | node.split(K=K, imb=imb) 105 | if node.left: 106 | _build_tree(node.left, K=K, imb=imb) 107 | if node.right: 108 | _build_tree(node.right, K=K, imb=imb) 109 | 110 | 111 | def build_forest(vecs: List[np.ndarray], N: int = 32, K: int = 64, imb: float = 0.95) -> List[Node]: 112 | """Builds a forest of `N` trees. 113 | """ 114 | forest = [] 115 | for _ in range(N): 116 | root = Node(None, vecs) 117 | _build_tree(root, K, imb) 118 | forest.append(root) 119 | return forest 120 | 121 | 122 | def _query_linear(vecs: List[np.ndarray], q: np.ndarray, k: int) -> List[np.ndarray]: 123 | return sorted(vecs, key=lambda v: np.linalg.norm(q-v))[:k] 124 | 125 | 126 | def _query_tree(root: Node, q: np.ndarray, k: int) -> List[np.ndarray]: 127 | """Queries a single tree. 
128 | """ 129 | 130 | pq = [root] 131 | nns = [] 132 | while pq: 133 | node = pq.pop(0) 134 | nearby = _select_nearby(node, q, thresh=0.05) 135 | 136 | # if `_select_nearby` does not return either node, then we are at a leaf 137 | if nearby: 138 | pq.extend(nearby) 139 | else: 140 | nns.extend(node.vecs) 141 | 142 | # brute-force search the nearest neighbors 143 | return _query_linear(nns, q, k) 144 | 145 | 146 | def query_forest(forest: List[Node], q, k: int = 10): 147 | nns = set() 148 | for root in forest: 149 | # merge `nns` with query result 150 | res = _query_tree(root, q, k) 151 | nns.update(res) 152 | return _query_linear(nns, q, k) 153 | 154 | 155 | if __name__ == "__main__": 156 | from annoy import AnnoyIndex 157 | import random 158 | 159 | # create dataset 160 | N = 2**17 #131072 161 | N = 2**15 #32768 162 | #N = 2**13 #8192 163 | d = 128 164 | k = 10 165 | 166 | dataset = [] 167 | for _ in range(N): 168 | vec = np.random.random(d) 169 | dataset.append(vec.view(immarray)) 170 | 171 | #print([type(v) for v in dataset]) 172 | 173 | # create query vector 174 | query = np.random.random(128) 175 | 176 | # create index 177 | index = build_forest(dataset) 178 | 179 | # perform query 180 | start = default_timer() 181 | result = query_forest(index, query, k=k) 182 | print(default_timer() - start) 183 | 184 | # brute-force ground truth 185 | start = default_timer() 186 | actual = _query_linear(dataset, query, k=k) 187 | print(default_timer() - start) 188 | 189 | # baseline annoy index 190 | annoy = AnnoyIndex(d, "euclidean") 191 | for (i, vec) in enumerate(dataset): 192 | annoy.add_item(i, vec) 193 | annoy.build(n_trees=32) 194 | baseline = annoy.get_nns_by_vector(query, n=k) 195 | baseline = [dataset[i] for i in baseline] 196 | 197 | # determine top-k recall 198 | r = 0 199 | for vec in actual: 200 | for vec2 in result: 201 | #for vec2 in baseline: 202 | if np.any(vec == vec2): 203 | r += 1 204 | break 205 | 206 | 207 | # print the top-k for each 208 | print(f"recall: {r}/{k}") 209 | #print(f"result: {[np.linalg.norm(query - nn) for nn in result]}") 210 | #print(f"actual: {[np.linalg.norm(query - nn) for nn in actual]}") 211 | 212 | 213 | -------------------------------------------------------------------------------- /indexes/flat.py: -------------------------------------------------------------------------------- 1 | import queue 2 | import numpy as np 3 | 4 | from ._base import _BaseIndex 5 | 6 | 7 | class FlatIndex(_BaseIndex): 8 | 9 | def __init__(self): 10 | pass 11 | 12 | def create(self, dataset): 13 | """The index is the same as the dataset itself.""" 14 | self._index = dataset 15 | 16 | def search(self, vector, nq=10): 17 | """Performs a naive search.""" 18 | return super().search(vector, nq) 19 | 20 | 21 | if __name__ == '__main__': 22 | flat = FlatIndex() 23 | flat.create(np.random.randn(1000, 256)) 24 | print(flat.search(np.random.randn(256))) 25 | -------------------------------------------------------------------------------- /indexes/gann.py: -------------------------------------------------------------------------------- 1 | """ 2 | gann.py: Good (Great/Gangster/Godlike) Approximate Nearest Neighbors. 3 | """ 4 | 5 | from timeit import default_timer 6 | from typing import List, Optional, Tuple 7 | import random 8 | 9 | import numpy as np 10 | 11 | 12 | def perfect_split(vecs: List[np.ndarray]) -> Tuple[np.ndarray]: 13 | """Returns reference points which split the input vectors perfectly. 
14 | """ 15 | pass 16 | 17 | 18 | class Node(object): 19 | """Initialize with a set of vectors, then call `split()`. 20 | """ 21 | 22 | def __init__(self, ref: np.ndarray, vecs: List[np.ndarray]): 23 | self._ref = ref 24 | self._vecs = vecs 25 | self._left = None 26 | self._right = None 27 | 28 | @property 29 | def ref(self) -> Optional[np.ndarray]: 30 | """Reference point in n-d hyperspace. Evaluates to `False` if root node. 31 | """ 32 | return self._ref 33 | 34 | @property 35 | def vecs(self) -> List[np.ndarray]: 36 | """Vectors for this leaf node. Evaluates to `False` if not a leaf. 37 | """ 38 | return self._vecs 39 | 40 | @property 41 | def left(self) -> Optional[object]: 42 | """Left node. 43 | """ 44 | return self._left 45 | 46 | @property 47 | def right(self) -> Optional[object]: 48 | """Right node. 49 | """ 50 | return self._right 51 | 52 | def split(self, K: int, imb: float) -> bool: 53 | 54 | # stopping condition: maximum # of vectors for a leaf node 55 | if len(self._vecs) <= K: 56 | return False 57 | 58 | # continue for a maximum of 5 iterations 59 | for n in range(5): 60 | left_vecs = [] 61 | right_vecs = [] 62 | 63 | # take two random indexes and set as left and right halves 64 | left_ref = self._vecs.pop(np.random.randint(len(self._vecs))) 65 | right_ref = self._vecs.pop(np.random.randint(len(self._vecs))) 66 | 67 | # split vectors into halves 68 | for vec in self._vecs: 69 | dist_l = np.linalg.norm(vec - left_ref) 70 | dist_r = np.linalg.norm(vec - right_ref) 71 | if dist_l < dist_r: 72 | left_vecs.append(vec) 73 | else: 74 | right_vecs.append(vec) 75 | 76 | # check to make sure that the tree is mostly balanced 77 | r = len(left_vecs) / len(right_vecs) 78 | r = len(left_vecs) / len(self._vecs) 79 | #print(r) 80 | if r < imb and r > (1 - imb): 81 | self._left = Node(left_ref, left_vecs) 82 | self._right = Node(right_ref, right_vecs) 83 | return True 84 | 85 | # redo tree build process if imbalance is high 86 | 87 | print("fuck") 88 | return False 89 | 90 | 91 | def _select_nearby(node: Node, q: np.ndarray, thresh: int = 0): 92 | """Functions identically to _is_query_in_left_half, but can return both. 93 | """ 94 | if not node.left or not node.right: 95 | return () 96 | dist_l = np.linalg.norm(q - node.left.ref) 97 | dist_r = np.linalg.norm(q - node.right.ref) 98 | if np.abs(dist_l - dist_r) < thresh: 99 | return (node.left, node.right) 100 | if dist_l < dist_r: 101 | return (node.left,) 102 | return (node.right,) 103 | 104 | 105 | def _build_tree(node, K: int, imb: float): 106 | """Recurses on left and right halves to build a tree. 107 | """ 108 | node.split(K=K, imb=imb) 109 | if node.left and node.right: 110 | _build_tree(node.left, K=K, imb=imb) 111 | _build_tree(node.right, K=K, imb=imb) 112 | 113 | 114 | def build_forest(vecs: List[np.ndarray], N: int = 8, K: int = 64, imb: float = 0.95) -> List[Node]: 115 | """Builds a forest of `N` trees. 116 | """ 117 | forest = [] 118 | for _ in range(N): 119 | root = Node(None, vecs) 120 | _build_tree(root, K, imb) 121 | forest.append(root) 122 | return forest 123 | 124 | 125 | def _query_linear(vecs: List[np.ndarray], q: np.ndarray, k: int) -> List[np.ndarray]: 126 | vecs = np.array(vecs) 127 | idxs = np.argsort(np.linalg.norm(vecs - q, axis=1)) 128 | vecs = vecs[idxs][:k] 129 | return list(vecs) 130 | 131 | 132 | def _query_tree(root: Node, q: np.ndarray, k: int) -> List[np.ndarray]: 133 | """Queries a single tree. 
134 | """ 135 | 136 | pq = [root] 137 | nns = [] 138 | while pq: 139 | # iteratively determine whether right or left node is closer 140 | node = pq.pop(0) 141 | nearby = _select_nearby(node, q, thresh=1e-2) 142 | if nearby: 143 | pq.extend(nearby) 144 | else: 145 | nns.extend(node.vecs) 146 | 147 | # brute-force search the nearest neighbors 148 | return _query_linear(nns, q, k) 149 | 150 | 151 | def query_forest(forest: List[Node], q, k: int = 10): 152 | nns = [] 153 | for root in forest: 154 | nns.extend(_query_tree(root, q, k)) 155 | return _query_linear(nns, q, k) 156 | 157 | 158 | if __name__ == "__main__": 159 | 160 | # create dataset 161 | N = 100000 162 | d = 128 163 | k = 10 164 | dataset = np.random.random((N, d)) 165 | dataset = [np.random.random(d) for _ in range(N)] 166 | 167 | # create query vector 168 | query = np.random.random(128) 169 | 170 | # create index 171 | index = build_forest(dataset) 172 | 173 | # perform query 174 | start = default_timer() 175 | result = query_forest(index, query, k=k) 176 | print(default_timer() - start) 177 | 178 | # brute-force ground truth 179 | start = default_timer() 180 | actual = _query_linear(dataset, query, k=k) 181 | print(default_timer() - start) 182 | 183 | # determine top-k recall 184 | r = 0 185 | for vec in actual: 186 | for vec2 in result: 187 | if np.all(vec == vec2): 188 | r += 1 189 | break 190 | 191 | # print the top-k for each 192 | 193 | print(f"recall: {r}/{k}") 194 | 195 | 196 | -------------------------------------------------------------------------------- /indexes/hnsw.py: -------------------------------------------------------------------------------- 1 | from bisect import insort 2 | from heapq import heapify, heappop, heappush 3 | 4 | import numpy as np 5 | 6 | from ._base import _BaseIndex 7 | 8 | 9 | class HNSW(_BaseIndex): 10 | 11 | def __init__(self, L=5, mL=0.62, efc=10): 12 | self._L = L 13 | self._mL = mL 14 | self._efc = efc 15 | self._index = [[] for _ in range(L)] 16 | 17 | @staticmethod 18 | def _search_layer(graph, entry, query, ef=1): 19 | 20 | best = (np.linalg.norm(graph[entry][0] - query), entry) 21 | 22 | nns = [best] 23 | visit = set(best) # set of visited nodes 24 | candid = [best] # candidate nodes to insert into nearest neighbors 25 | heapify(candid) 26 | 27 | # find top-k nearest neighbors 28 | while candid: 29 | cv = heappop(candid) 30 | 31 | if nns[-1][0] > cv[0]: 32 | break 33 | 34 | # loop through all nearest neighbors to the candidate vector 35 | for e in graph[cv[1]][1]: 36 | d = np.linalg.norm(graph[e][0] - query) 37 | if (d, e) not in visit: 38 | visit.add((d, e)) 39 | 40 | # push only "better" vectors into candidate heap 41 | if d < nns[-1][0] or len(nns) < ef: 42 | heappush(candid, (d, e)) 43 | insort(nns, (d, e)) 44 | if len(nns) > ef: 45 | nns.pop() 46 | 47 | return nns 48 | 49 | def create(self, dataset): 50 | for v in dataset: 51 | self.insert(v) 52 | 53 | def search(self, query, ef=1): 54 | 55 | # if the index is empty, return an empty list 56 | if not self._index[0]: 57 | return [] 58 | 59 | best_v = 0 # set the initial best vertex to the entry point 60 | for graph in self._index: 61 | best_d, best_v = HNSW._search_layer(graph, best_v, query, ef=1)[0] 62 | if graph[best_v][2]: 63 | best_v = graph[best_v][2] 64 | else: 65 | return HNSW._search_layer(graph, best_v, query, ef=ef) 66 | 67 | def _get_insert_layer(self): 68 | # ml is a multiplicative factor used to normalize the distribution 69 | l = -int(np.log(np.random.random()) * self._mL) 70 | return min(l, self._L-1) 71 | 72 | 
def insert(self, vec, efc=10): 73 | 74 | # if the index is empty, insert the vector into all layers and return 75 | if not self._index[0]: 76 | i = None 77 | for graph in self._index[::-1]: 78 | graph.append((vec, [], i)) 79 | i = 0 80 | return 81 | 82 | l = self._get_insert_layer() 83 | 84 | start_v = 0 85 | for n, graph in enumerate(self._index): 86 | 87 | # perform insertion for layers [l, L) only 88 | if n < l: 89 | _, start_v = self._search_layer(graph, start_v, vec, ef=1)[0] 90 | else: 91 | node = (vec, [], len(self._index[n+1]) if n < self._L-1 else None) 92 | nns = self._search_layer(graph, start_v, vec, ef=efc) 93 | for nn in nns: 94 | node[1].append(nn[1]) # outbound connections to NNs 95 | graph[nn[1]][1].append(len(graph)) # inbound connections to node 96 | graph.append(node) 97 | 98 | # set the starting vertex to the nearest neighbor in the next layer 99 | start_v = graph[start_v][2] 100 | 101 | 102 | if __name__ == "__main__": 103 | dim = 128 104 | nvec = 1000 105 | hnsw = HNSW() 106 | 107 | 108 | hnsw.create([np.random.randn(dim) for _ in range(nvec)]) 109 | print(hnsw.search(np.random.randn(dim))) 110 | 111 | 112 | 113 | -------------------------------------------------------------------------------- /indexes/pq.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.cluster.vq import kmeans2 3 | 4 | from ._base import _BaseIndex 5 | 6 | 7 | class ProductQuantizer(_BaseIndex): 8 | 9 | def __init__(self, M=16, K=256): 10 | self.M = M # number of subvectors 11 | self.K = K # number of centroids 12 | self._index = None 13 | self._centroids = None 14 | 15 | def create(self, dataset): 16 | """Fits PQ model based on the input dataset.""" 17 | sublen = dataset.shape[1] // self.M 18 | self._centroids = np.empty((self.M, self.K, sublen), dtype=np.float64) 19 | self._index = np.empty((dataset.shape[0], self.M), dtype=np.uint8) 20 | for m in range(self.M): 21 | subspace = dataset[:,m*sublen:(m+1)*sublen] 22 | (centroids, assignments) = kmeans2(subspace, self.K, iter=32) 23 | self._centroids[m,:,:] = centroids 24 | self._index[:,m] = np.uint8(assignments) 25 | 26 | def search(self, vector, nq=10): 27 | """Performs quantization + naive search.""" 28 | quantized = self.quantize(vector) 29 | return super().search(quantized, nq) 30 | 31 | def quantize(self, vector): 32 | """Quantizes the input vector based on PQ parameters.""" 33 | quantized = np.empty((self.M,), dtype=np.uint8) 34 | sublen = vector.size // self.M 35 | 36 | for m in range(self.M): 37 | subvec = vector[m*sublen:(m+1)*sublen] 38 | centroids = self._centroids[m,:,:] 39 | distances = np.linalg.norm(subvec - centroids, axis=1) 40 | quantized[m] = np.argmin(distances) 41 | return quantized 42 | 43 | def restore(self, vector): 44 | """Restores the original vector using PQ parameters.""" 45 | return np.hstack([self._centroids[m,vector[m],:] for m in range(self.M)]) 46 | 47 | 48 | if __name__ == "__main__": 49 | pq = ProductQuantizer() 50 | pq.create(np.random.randn(10000, 256)) 51 | print(pq.search(np.random.randn(256))) 52 | -------------------------------------------------------------------------------- /indexes/sq.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from ._base import _BaseIndex 4 | 5 | 6 | class ScalarQuantizer(_BaseIndex): 7 | 8 | def __init__(self): 9 | self._index = None 10 | self._starts = None 11 | self._steps = None 12 | 13 | def create(self, dataset): 14 | """Calculates and 
stores SQ parameters based on the input dataset.""" 15 | self._starts = np.min(dataset) 16 | self._steps = (np.max(dataset) - self._starts) / 255 17 | self._index = np.uint8((dataset - self._starts) / self._steps) 18 | 19 | def search(self, vector, nq=10): 20 | """Performs quantization + naive search.""" 21 | quantized = self.quantize(vector) 22 | return super().search(quantized, nq) 23 | 24 | def quantize(self, vector): 25 | """Quantizes the input vector based on SQ parameters""" 26 | return np.uint8((vector - self._starts) / self._steps) 27 | 28 | def restore(self, vector): 29 | """Restores the original vector using SQ parameters.""" 30 | return (vector * self._steps) + self._starts 31 | 32 | 33 | if __name__ == "__main__": 34 | sq = ScalarQuantizer() 35 | sq.create(np.random.randn(1000, 256)) 36 | print(sq.search(np.random.randn(256))) -------------------------------------------------------------------------------- /indexes/vamana.py: -------------------------------------------------------------------------------- 1 | 2 | from heapq import heapify, heappop, heappush 3 | 4 | import numpy as np 5 | 6 | from ._base import _BaseIndex 7 | from ._base import immarray 8 | 9 | 10 | class Vamana(_BaseIndex) 11 | """Vamana graph algorithm implementation. Every element in each graph is a 12 | 2-tuple containing the vector and a list of indexes the vector links to 13 | within the graph. 14 | """ 15 | 16 | def __init__(self, ): 17 | super().__init__() 18 | self._start = None # starting vector 19 | 20 | 21 | def create(self, dataset): 22 | pass 23 | 24 | 25 | def insert(self, vector): 26 | raise NotImplementedError 27 | 28 | 29 | def search(query, nq: int = 10): 30 | """Greedy search. 31 | """ 32 | 33 | best = (np.linalg.norm(self._index[self._start][0] - query), entry) 34 | 35 | nns = [] 36 | visit = set() # set of visited nodes 37 | nns = heapify(nns) 38 | 39 | # find top-k nearest neighbors 40 | while nns - visit: 41 | nn = nns[0] 42 | for idx in nn[1]: 43 | d = np.linalg.norm(self._index[idx][0] - query) 44 | nns.append((d, nn)) 45 | visit.add((d, nn)) 46 | 47 | if len(nns) > nq: 48 | nns = nns[:nq] 49 | 50 | visit.add(cv) 51 | 52 | return nns 53 | 54 | 55 | def _robust_prune(node, candid, a: int = 1.5, R): 56 | 57 | candid.update(node[1]) 58 | node[1] = [] 59 | 60 | while candid: 61 | (d, nn) = (float("inf"), None) 62 | for n in candid: 63 | 64 | 65 | def build_index(): 66 | pass -------------------------------------------------------------------------------- /tutorials/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fzliu/vector-search/2a4035c902e5e8520f0d711cd35ecb1066d74283/tutorials/.DS_Store -------------------------------------------------------------------------------- /tutorials/2023-03-02_data_science_dojo_00_ann_algorithms.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "d322d99a", 6 | "metadata": {}, 7 | "source": [ 8 | "# An Overview of Common Indexing Algorithms\n", 9 | "\n", 10 | "In this notebook, we'll go over some common indexing algorithms and discuss the tradeoffs." 
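,
    "\n",
    "\n",
    "We'll use the small reference implementations from this repository's `indexes` package and time a single query against a flat index, a product quantizer, and an HNSW graph."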
11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 34, 16 | "id": "fc8208ca", 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "import sys; sys.path.insert(0, '..')\n", 21 | "import numpy as np\n", 22 | "\n", 23 | "data = np.random.randn(10000, 256)" 24 | ] 25 | }, 26 | { 27 | "cell_type": "markdown", 28 | "id": "7f77fbcd", 29 | "metadata": {}, 30 | "source": [ 31 | "## Flat Indexing" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": 35, 37 | "id": "b904731d", 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "from indexes.flat import FlatIndex\n", 42 | "\n", 43 | "flat = FlatIndex()\n", 44 | "flat.create(data)" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": 36, 50 | "id": "08cc839e", 51 | "metadata": {}, 52 | "outputs": [ 53 | { 54 | "name": "stdout", 55 | "output_type": "stream", 56 | "text": [ 57 | "74.7 ms ± 2.71 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" 58 | ] 59 | } 60 | ], 61 | "source": [ 62 | "%timeit flat.search(np.random.randn(256))" 63 | ] 64 | }, 65 | { 66 | "cell_type": "markdown", 67 | "id": "31ad2c6a", 68 | "metadata": {}, 69 | "source": [ 70 | "## Product Quantization (PQ)" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": 37, 76 | "id": "4d13c5dd", 77 | "metadata": {}, 78 | "outputs": [], 79 | "source": [ 80 | "from indexes.pq import ProductQuantizer\n", 81 | "\n", 82 | "pq = ProductQuantizer()\n", 83 | "pq.create(data)" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": 38, 89 | "id": "d17c7cce", 90 | "metadata": {}, 91 | "outputs": [ 92 | { 93 | "name": "stdout", 94 | "output_type": "stream", 95 | "text": [ 96 | "75.7 ms ± 2.63 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" 97 | ] 98 | } 99 | ], 100 | "source": [ 101 | "%timeit pq.search(np.random.randn(256))" 102 | ] 103 | }, 104 | { 105 | "cell_type": "markdown", 106 | "id": "e63ed611", 107 | "metadata": {}, 108 | "source": [ 109 | "## Hierarchical Navigable Small Worlds (HNSW)" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": 39, 115 | "id": "a26ed001", 116 | "metadata": {}, 117 | "outputs": [], 118 | "source": [ 119 | "from indexes.hnsw import HNSW\n", 120 | "\n", 121 | "hnsw = HNSW()\n", 122 | "hnsw.create(data)" 123 | ] 124 | }, 125 | { 126 | "cell_type": "code", 127 | "execution_count": 40, 128 | "id": "37ac8c39", 129 | "metadata": {}, 130 | "outputs": [ 131 | { 132 | "name": "stdout", 133 | "output_type": "stream", 134 | "text": [ 135 | "34.1 ms ± 3.9 ms per loop (mean ± std. dev. 
of 7 runs, 10 loops each)\n" 136 | ] 137 | } 138 | ], 139 | "source": [ 140 | "%timeit hnsw.search(np.random.randn(256))" 141 | ] 142 | }, 143 | { 144 | "cell_type": "code", 145 | "execution_count": null, 146 | "id": "6765ae82", 147 | "metadata": {}, 148 | "outputs": [], 149 | "source": [] 150 | } 151 | ], 152 | "metadata": { 153 | "kernelspec": { 154 | "display_name": "Python 3 (ipykernel)", 155 | "language": "python", 156 | "name": "python3" 157 | }, 158 | "language_info": { 159 | "codemirror_mode": { 160 | "name": "ipython", 161 | "version": 3 162 | }, 163 | "file_extension": ".py", 164 | "mimetype": "text/x-python", 165 | "name": "python", 166 | "nbconvert_exporter": "python", 167 | "pygments_lexer": "ipython3", 168 | "version": "3.8.9" 169 | } 170 | }, 171 | "nbformat": 4, 172 | "nbformat_minor": 5 173 | } 174 | -------------------------------------------------------------------------------- /tutorials/2023-03-02_data_science_dojo_01_semantic_search.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Movie Search Using Milvus and SentenceTransformers\n", 8 | "In this example we are going to be going over a Wikipedia article search using Milvus and and the SentenceTransformers library. The dataset we are searching through is the Wikipedia-Movie-Plots Dataset found on [Kaggle](https://www.kaggle.com/datasets/jrobischon/wikipedia-movie-plots). For this example we have rehosted the data in a public google drive.\n", 9 | "\n", 10 | "Lets get started.\n", 11 | "\n", 12 | "## Installing Requirements\n", 13 | "For this example we are going to be using `pymilvus` to connect to use Milvus, `sentence-transformers` to connect to embed the movie plots, and `gdown` to download the example dataset." 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": 18, 19 | "metadata": {}, 20 | "outputs": [ 21 | { 22 | "name": "stdout", 23 | "output_type": "stream", 24 | "text": [ 25 | "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. 
Disabling parallelism to avoid deadlocks...\n", 26 | "To disable this warning, you can either:\n", 27 | "\t- Avoid using `tokenizers` before the fork if possible\n", 28 | "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n", 29 | "Requirement already satisfied: pymilvus in /Users/fzliu/.pyenv/lib/python3.8/site-packages (2.2.1)\n", 30 | "Requirement already satisfied: sentence-transformers in /Users/fzliu/.pyenv/lib/python3.8/site-packages (2.2.2)\n", 31 | "Requirement already satisfied: gdown in /Users/fzliu/.pyenv/lib/python3.8/site-packages (4.3.1)\n", 32 | "Requirement already satisfied: grpcio-tools<=1.48.0,>=1.47.0 in /Users/fzliu/.pyenv/lib/python3.8/site-packages (from pymilvus) (1.47.2)\n", 33 | "Requirement already satisfied: grpcio<=1.48.0,>=1.47.0 in /Users/fzliu/.pyenv/lib/python3.8/site-packages (from pymilvus) (1.47.2)\n", 34 | "Requirement already satisfied: pandas>=1.2.4 in /Users/fzliu/.pyenv/lib/python3.8/site-packages (from pymilvus) (1.4.2)\n", 35 | "Requirement already satisfied: mmh3<=3.0.0,>=2.0 in /Users/fzliu/.pyenv/lib/python3.8/site-packages (from pymilvus) (3.0.0)\n", 36 | "Requirement already satisfied: ujson<=5.4.0,>=2.0.0 in /Users/fzliu/.pyenv/lib/python3.8/site-packages (from pymilvus) (5.1.0)\n", 37 | "Requirement already satisfied: torch>=1.6.0 in /Users/fzliu/.pyenv/lib/python3.8/site-packages (from sentence-transformers) (1.13.1)\n", 38 | "Requirement already satisfied: sentencepiece in /Users/fzliu/.pyenv/lib/python3.8/site-packages (from sentence-transformers) (0.1.96)\n", 39 | "Requirement already satisfied: scipy in /Users/fzliu/.pyenv/lib/python3.8/site-packages (from sentence-transformers) (1.9.3)\n", 40 | "Requirement already satisfied: transformers<5.0.0,>=4.6.0 in /Users/fzliu/.pyenv/lib/python3.8/site-packages (from sentence-transformers) (4.25.1)\n", 41 | "Requirement already satisfied: huggingface-hub>=0.4.0 in /Users/fzliu/.pyenv/lib/python3.8/site-packages (from sentence-transformers) (0.11.1)\n", 42 | "Requirement already satisfied: scikit-learn in /Users/fzliu/.pyenv/lib/python3.8/site-packages (from sentence-transformers) (1.0.2)\n", 43 | "Requirement already satisfied: numpy in /Users/fzliu/.pyenv/lib/python3.8/site-packages (from sentence-transformers) (1.23.5)\n", 44 | "Requirement already satisfied: nltk in /Users/fzliu/.pyenv/lib/python3.8/site-packages (from sentence-transformers) (3.8.1)\n", 45 | "Requirement already satisfied: torchvision in /Users/fzliu/.pyenv/lib/python3.8/site-packages (from sentence-transformers) (0.14.1)\n", 46 | "Requirement already satisfied: tqdm in /Users/fzliu/.pyenv/lib/python3.8/site-packages (from sentence-transformers) (4.62.3)\n", 47 | "Requirement already satisfied: filelock in /Users/fzliu/.pyenv/lib/python3.8/site-packages (from gdown) (3.3.0)\n", 48 | "Requirement already satisfied: six in /Users/fzliu/.pyenv/lib/python3.8/site-packages (from gdown) (1.16.0)\n", 49 | "Requirement already satisfied: requests[socks] in /Users/fzliu/.pyenv/lib/python3.8/site-packages (from gdown) (2.26.0)\n", 50 | "Requirement already satisfied: beautifulsoup4 in /Users/fzliu/.pyenv/lib/python3.8/site-packages (from gdown) (4.10.0)\n", 51 | "Requirement already satisfied: setuptools in /Users/fzliu/.pyenv/lib/python3.8/site-packages (from grpcio-tools<=1.48.0,>=1.47.0->pymilvus) (49.2.1)\n", 52 | "Requirement already satisfied: protobuf<4.0dev,>=3.12.0 in /Users/fzliu/.pyenv/lib/python3.8/site-packages (from grpcio-tools<=1.48.0,>=1.47.0->pymilvus) (3.20.3)\n", 53 | "Requirement 
already satisfied: typing-extensions>=3.7.4.3 in /Users/fzliu/.pyenv/lib/python3.8/site-packages (from huggingface-hub>=0.4.0->sentence-transformers) (3.10.0.2)\n", 54 | "Requirement already satisfied: packaging>=20.9 in /Users/fzliu/.pyenv/lib/python3.8/site-packages (from huggingface-hub>=0.4.0->sentence-transformers) (21.0)\n", 55 | "Requirement already satisfied: pyyaml>=5.1 in /Users/fzliu/.pyenv/lib/python3.8/site-packages (from huggingface-hub>=0.4.0->sentence-transformers) (5.4.1)\n", 56 | "Requirement already satisfied: pyparsing>=2.0.2 in /Users/fzliu/.pyenv/lib/python3.8/site-packages (from packaging>=20.9->huggingface-hub>=0.4.0->sentence-transformers) (2.4.7)\n", 57 | "Requirement already satisfied: python-dateutil>=2.8.1 in /Users/fzliu/.pyenv/lib/python3.8/site-packages (from pandas>=1.2.4->pymilvus) (2.8.2)\n", 58 | "Requirement already satisfied: pytz>=2020.1 in /Users/fzliu/.pyenv/lib/python3.8/site-packages (from pandas>=1.2.4->pymilvus) (2021.1)\n", 59 | "Requirement already satisfied: regex!=2019.12.17 in /Users/fzliu/.pyenv/lib/python3.8/site-packages (from transformers<5.0.0,>=4.6.0->sentence-transformers) (2021.10.8)\n", 60 | "Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /Users/fzliu/.pyenv/lib/python3.8/site-packages (from transformers<5.0.0,>=4.6.0->sentence-transformers) (0.12.1)\n", 61 | "Requirement already satisfied: soupsieve>1.2 in /Users/fzliu/.pyenv/lib/python3.8/site-packages (from beautifulsoup4->gdown) (2.3.1)\n", 62 | "Requirement already satisfied: joblib in /Users/fzliu/.pyenv/lib/python3.8/site-packages (from nltk->sentence-transformers) (1.1.0)\n", 63 | "Requirement already satisfied: click in /Users/fzliu/.pyenv/lib/python3.8/site-packages (from nltk->sentence-transformers) (8.0.3)\n", 64 | "Requirement already satisfied: certifi>=2017.4.17 in /Users/fzliu/.pyenv/lib/python3.8/site-packages (from requests[socks]->gdown) (2021.5.30)\n", 65 | "Requirement already satisfied: idna<4,>=2.5 in /Users/fzliu/.pyenv/lib/python3.8/site-packages (from requests[socks]->gdown) (3.2)\n", 66 | "Requirement already satisfied: charset-normalizer~=2.0.0 in /Users/fzliu/.pyenv/lib/python3.8/site-packages (from requests[socks]->gdown) (2.0.6)\n", 67 | "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /Users/fzliu/.pyenv/lib/python3.8/site-packages (from requests[socks]->gdown) (1.26.7)\n", 68 | "Requirement already satisfied: PySocks!=1.5.7,>=1.5.6 in /Users/fzliu/.pyenv/lib/python3.8/site-packages (from requests[socks]->gdown) (1.7.1)\n", 69 | "Requirement already satisfied: threadpoolctl>=2.0.0 in /Users/fzliu/.pyenv/lib/python3.8/site-packages (from scikit-learn->sentence-transformers) (3.0.0)\n", 70 | "Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /Users/fzliu/.pyenv/lib/python3.8/site-packages (from torchvision->sentence-transformers) (9.1.1)\n", 71 | "\u001b[33mWARNING: You are using pip version 21.2.4; however, version 23.0.1 is available.\n", 72 | "You should consider upgrading via the '/Users/fzliu/.pyenv/bin/python3 -m pip install --upgrade pip' command.\u001b[0m\n" 73 | ] 74 | } 75 | ], 76 | "source": [ 77 | "! pip install pymilvus sentence-transformers gdown" 78 | ] 79 | }, 80 | { 81 | "cell_type": "markdown", 82 | "metadata": {}, 83 | "source": [ 84 | "## Grabbing the Data\n", 85 | "We are going to use `gdown` to grab the zip from Google Drive and then decompress it with the built in `zipfile` library." 
86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": null, 91 | "metadata": { 92 | "scrolled": true 93 | }, 94 | "outputs": [], 95 | "source": [ 96 | "import gdown\n", 97 | "url = 'https://drive.google.com/uc?id=11ISS45aO2ubNCGaC3Lvd3D7NT8Y7MeO8'\n", 98 | "output = './movies.zip'\n", 99 | "gdown.download(url, output)" 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": 3, 105 | "metadata": {}, 106 | "outputs": [], 107 | "source": [ 108 | "import zipfile\n", 109 | "\n", 110 | "with zipfile.ZipFile(\"./movies.zip\",\"r\") as zip_ref:\n", 111 | " zip_ref.extractall(\"./movies\")" 112 | ] 113 | }, 114 | { 115 | "cell_type": "markdown", 116 | "metadata": {}, 117 | "source": [ 118 | "## Global Arguments\n", 119 | "Here we can find the main arguments that need to be modified for running with your own accounts. Beside each is a description of what it is." 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": 19, 125 | "metadata": {}, 126 | "outputs": [], 127 | "source": [ 128 | "# Zilliz Cloud Setup Arguments\n", 129 | "import os\n", 130 | "COLLECTION_NAME = 'movies_db' # Collection name\n", 131 | "DIMENSION = 384 # Embeddings size\n", 132 | "URI=os.getenv('VECTOR_DB_URL') # Endpoint URI obtained from Zilliz Cloud\n", 133 | "USER='db_admin' # Username specified when you created this database\n", 134 | "PASSWORD=os.getenv('VECTOR_DB_PASS') # Password set for that account\n", 135 | "\n", 136 | "# Inference Arguments\n", 137 | "BATCH_SIZE = 128\n", 138 | "\n", 139 | "# Search Arguments\n", 140 | "TOP_K = 3" 141 | ] 142 | }, 143 | { 144 | "cell_type": "markdown", 145 | "metadata": {}, 146 | "source": [ 147 | "## Setting Up Milvus\n", 148 | "At this point we are going to begin setting up Milvus. The steps are as follows:\n", 149 | "\n", 150 | "1. Connect to the Milvus instance using the provided URI.\n", 151 | "2. If the collection already exists, drop it.\n", 152 | "3. Create the collection that holds the id, title of the movie, and the plot embedding.\n", 153 | "4. Create an index on the newly created collection and load it into memory.\n", 154 | "\n", 155 | "Once these steps are done the collection is ready to be inserted into and searched. Any data added will be indexed automatically and be available to search immidiately. 
If the data is very fresh, the search might be slower as brute force searching will be used on data that is still in process of getting indexed.\n" 156 | ] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "execution_count": 20, 161 | "metadata": {}, 162 | "outputs": [], 163 | "source": [ 164 | "from pymilvus import connections\n", 165 | "\n", 166 | "# Connect to Milvus Database\n", 167 | "connections.connect(uri=URI, user=USER, password=PASSWORD, secure=True)" 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": 21, 173 | "metadata": {}, 174 | "outputs": [], 175 | "source": [ 176 | "from pymilvus import utility\n", 177 | "\n", 178 | "# Remove any previous collections with the same name\n", 179 | "if utility.has_collection(COLLECTION_NAME):\n", 180 | " utility.drop_collection(COLLECTION_NAME)" 181 | ] 182 | }, 183 | { 184 | "cell_type": "code", 185 | "execution_count": 22, 186 | "metadata": {}, 187 | "outputs": [], 188 | "source": [ 189 | "from pymilvus import FieldSchema, CollectionSchema, DataType, Collection\n", 190 | "\n", 191 | "\n", 192 | "# Create collection which includes the id, title, and embedding.\n", 193 | "fields = [\n", 194 | " FieldSchema(name='id', dtype=DataType.INT64, is_primary=True, auto_id=True),\n", 195 | " FieldSchema(name='title', dtype=DataType.VARCHAR, max_length=200), # VARCHARS need a maximum length, so for this example they are set to 200 characters\n", 196 | " FieldSchema(name='embedding', dtype=DataType.FLOAT_VECTOR, dim=DIMENSION)\n", 197 | "]\n", 198 | "schema = CollectionSchema(fields=fields)\n", 199 | "collection = Collection(name=COLLECTION_NAME, schema=schema)" 200 | ] 201 | }, 202 | { 203 | "cell_type": "code", 204 | "execution_count": 23, 205 | "metadata": {}, 206 | "outputs": [], 207 | "source": [ 208 | "# Create an IVF_FLAT index for collection.\n", 209 | "index_params = {\n", 210 | " 'metric_type':'L2',\n", 211 | " 'index_type':\"AUTOINDEX\", # IVF_FLAT, IVF_PQ, HNSW, ANNOY, DiskANN\n", 212 | " 'params':{}\n", 213 | "}\n", 214 | "collection.create_index(field_name=\"embedding\", index_params=index_params)\n", 215 | "collection.load()" 216 | ] 217 | }, 218 | { 219 | "cell_type": "markdown", 220 | "metadata": {}, 221 | "source": [ 222 | "## Inserting the Data\n", 223 | "In these next few steps we will be: \n", 224 | "1. Loading the data.\n", 225 | "2. Embedding the plot text data using SentenceTransformers.\n", 226 | "3. Inserting the data into Milvus.\n", 227 | "\n", 228 | "For this example we are going using SentenceTransformers miniLM model to create embeddings of the plot text. This model returns 384-dim embeddings." 229 | ] 230 | }, 231 | { 232 | "cell_type": "code", 233 | "execution_count": 13, 234 | "metadata": {}, 235 | "outputs": [ 236 | { 237 | "name": "stderr", 238 | "output_type": "stream", 239 | "text": [ 240 | "2023-03-01 22:14:24.344794: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA\n", 241 | "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n" 242 | ] 243 | }, 244 | { 245 | "data": { 246 | "application/vnd.jupyter.widget-view+json": { 247 | "model_id": "9b50c848941f4414adab966fd59299c0", 248 | "version_major": 2, 249 | "version_minor": 0 250 | }, 251 | "text/plain": [ 252 | "Downloading: 0%| | 0.00/1.18k [00:00