├── requirements.txt ├── MANIFEST.in ├── pytest.ini ├── LICENSE.txt ├── README.rst ├── setup.py ├── .github └── workflows │ └── python-package.yml ├── test.py └── vptree.py /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | pytest 3 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.rst 2 | include LICENSE.txt -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | python_files = test.py 3 | addopts = --verbose 4 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright 2017 Rickard Sjögren 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a 4 | copy of this software and associated documentation files (the "Software"), 5 | to deal in the Software without restriction, including without limitation 6 | the rights to use, copy, modify, merge, publish, distribute, sublicense, 7 | and/or sell copies of the Software, and to permit persons to whom the 8 | Software is furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS 14 | OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 16 | THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR 17 | OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, 18 | ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR 19 | OTHER DEALINGS IN THE SOFTWARE. -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | VP-Tree 2 | ======= 3 | 4 | .. image:: https://badge.fury.io/py/vptree.svg 5 | :target: https://badge.fury.io/py/vptree 6 | 7 | .. image:: https://github.com/RickardSjogren/vptree/actions/workflows/python-package.yml/badge.svg?branch=master 8 | :target: https://github.com/RickardSjogren/vptree/actions/workflows/python-package.yml 9 | 10 | This package contains an implementation of a `vantage-point tree `_ data structure. 11 | 12 | Installation 13 | ------------ 14 | 15 | Simply install through pip: 16 | 17 | .. code-block:: 18 | 19 | pip install vptree 20 | 21 | Example 22 | ------- 23 | 24 | Example usage: 25 | 26 | .. code-block:: python 27 | 28 | import numpy as np 29 | import vptree 30 | 31 | # Define distance function. 32 | def euclidean(p1, p2): 33 | return np.sqrt(np.sum(np.power(p2 - p1, 2))) 34 | 35 | # Generate some random points. 36 | points = np.random.randn(20000, 10) 37 | query = [.5] * 10 38 | 39 | # Build tree in O(n log n) time complexity. 40 | tree = vptree.VPTree(points, euclidean) 41 | 42 | # Query single point. 43 | tree.get_nearest_neighbor(query) 44 | 45 | # Query n-points. 46 | tree.get_n_nearest_neighbors(query, 10) 47 | 48 | # Get all points within certain distance. 49 | tree.get_all_in_range(query, 3.14) 50 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/use/bin/env python 2 | 3 | from setuptools import setup 4 | 5 | 6 | def readme(): 7 | with open('README.rst') as f: 8 | contents = f.read() 9 | return contents 10 | 11 | setup( 12 | name='vptree', 13 | version='1.3', 14 | author='Rickard Sjögren', 15 | author_email='r.sjogren89@gmail.com', 16 | license='MIT', 17 | url='https://github.com/RickardSjogren/vptree', 18 | description=('A package implementing a vantage-point data structure, for ' 19 | 'efficient nearest neighbor searching.'), 20 | long_description=readme(), 21 | py_modules=['vptree'], 22 | test_suite='test', 23 | keywords='python machine learning search', 24 | install_requires=[ 25 | 'numpy', 26 | ], 27 | classifiers=[ 28 | 'Development Status :: 5 - Production/Stable', 29 | 'Intended Audience :: Science/Research', 30 | 'Intended Audience :: Developers', 31 | 'Intended Audience :: Information Technology', 32 | 'License :: OSI Approved :: MIT License', 33 | 'Programming Language :: Python :: 2.7', 34 | 'Programming Language :: Python :: 3.5', 35 | 'Programming Language :: Python :: 3.6', 36 | 'Programming Language :: Python :: 3.7', 37 | 'Programming Language :: Python :: 3.8', 38 | 'Topic :: Scientific/Engineering', 39 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 40 | ] 41 | ) 42 | -------------------------------------------------------------------------------- /.github/workflows/python-package.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: CI 5 | 6 | on: 7 | push: 8 | branches: [ "master" ] 9 | pull_request: 10 | branches: [ "master" ] 11 | 12 | jobs: 13 | build: 14 | 15 | runs-on: ubuntu-latest 16 | strategy: 17 | fail-fast: false 18 | matrix: 19 | python-version: ["3.6", "3.7", "3.8", "3.9", "3.10"] 20 | 21 | steps: 22 | - uses: actions/checkout@v3 23 | - name: Set up Python ${{ matrix.python-version }} 24 | uses: actions/setup-python@v3 25 | with: 26 | python-version: ${{ matrix.python-version }} 27 | - name: Install dependencies 28 | run: | 29 | python -m pip install --upgrade pip 30 | python -m pip install flake8 pytest 31 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi 32 | - name: Lint with flake8 33 | run: | 34 | # stop the build if there are Python syntax errors or undefined names 35 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics 36 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide 37 | flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics 38 | - name: Test with pytest 39 | run: | 40 | pytest 41 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import vptree 3 | import numpy as np 4 | 5 | 6 | class TestVPTree(unittest.TestCase): 7 | 8 | def test_single_nearest_neighbor(self): 9 | dim = 10 10 | query = [.5] * dim 11 | points, brute_force = brute_force_solution(20000, dim, query) 12 | tree = vptree.VPTree(points, euclidean) 13 | 14 | nearest = tree.get_nearest_neighbor(query) 15 | bf_nearest = brute_force[0] 16 | self.assertEqual(nearest[0], bf_nearest[0]) 17 | self.assertTrue(all(n == b for n, b in zip(nearest[1], bf_nearest[1]))) 18 | 19 | def test_nearest_neighbors(self): 20 | dim = 10 21 | query = [.5] * dim 22 | points, brute_force = brute_force_solution(20000, dim, query) 23 | tree = vptree.VPTree(points, euclidean) 24 | 25 | for k in (1, 10, len(points)): 26 | tree_nearest = tree.get_n_nearest_neighbors(query, k) 27 | self.assertEqual(len(tree_nearest), k) 28 | brute_force_nearest = brute_force[:k] 29 | for nearest, bf_nearest in zip(tree_nearest, brute_force_nearest): 30 | self.assertEqual(nearest[0], bf_nearest[0]) 31 | self.assertTrue(all(n == b for n, b in zip(nearest[1], 32 | bf_nearest[1]))) 33 | 34 | def test_epsilon_search(self): 35 | dim = 10 36 | query = [.5] * dim 37 | points, brute_force = brute_force_solution(20000, dim, query) 38 | tree = vptree.VPTree(points, euclidean) 39 | 40 | for eps in (-1, 0, 1, 2, 10): 41 | tree_nearest = sorted(tree.get_all_in_range(query, eps)) 42 | brute_force_nearest = [point for point in brute_force if 43 | point[0] < eps] 44 | for nearest, bf_nearest in zip(tree_nearest, brute_force_nearest): 45 | self.assertEqual(nearest[0], bf_nearest[0]) 46 | self.assertTrue(all(n == b for n, b in zip(nearest[1], 47 | bf_nearest[1]))) 48 | 49 | def test_empty_points_raises_valueerror(self): 50 | self.assertRaises(ValueError, vptree.VPTree, [], euclidean) 51 | 52 | def test_zero_neighbors_raises_valueerror(self): 53 | tree = vptree.VPTree([1, 2, 3], euclidean) 54 | self.assertRaises(ValueError, tree.get_n_nearest_neighbors, [1], 0) 55 | 56 | 57 | def euclidean(p1, p2): 58 | return np.sqrt(np.sum(np.power(p2 - p1, 2))) 59 | 60 | 61 | def brute_force_solution(n, dim, query, dist=euclidean): 62 | points = np.random.randn(n, dim) 63 | brute_force = [(dist(query, point), point) for point in points] 64 | brute_force.sort() 65 | 66 | return points, brute_force 67 | 68 | 69 | if __name__ == '__main__': 70 | unittest.main() 71 | -------------------------------------------------------------------------------- /vptree.py: -------------------------------------------------------------------------------- 1 | """ This module contains an implementation of a Vantage Point-tree (VP-tree).""" 2 | import bisect 3 | import collections 4 | import math 5 | import statistics as stats 6 | 7 | 8 | class VPTree: 9 | 10 | """ VP-Tree data structure for efficient nearest neighbor search. 11 | 12 | The VP-tree is a data structure for efficient nearest neighbor 13 | searching and finds the nearest neighbor in O(log n) 14 | complexity given a tree constructed of n data points. Construction 15 | complexity is O(n log n). 16 | 17 | Parameters 18 | ---------- 19 | points : Iterable 20 | Construction points. 21 | dist_fn : Callable 22 | Function taking to point instances as arguments and returning 23 | the distance between them. 24 | leaf_size : int 25 | Minimum number of points in leaves (IGNORED). 26 | """ 27 | 28 | def __init__(self, points, dist_fn): 29 | self.left = None 30 | self.right = None 31 | self.left_min = math.inf 32 | self.left_max = 0 33 | self.right_min = math.inf 34 | self.right_max = 0 35 | self.dist_fn = dist_fn 36 | 37 | if not len(points): 38 | raise ValueError('Points can not be empty.') 39 | 40 | # Vantage point is point furthest from parent vp. 41 | self.vp = points[0] 42 | points = points[1:] 43 | 44 | if len(points) == 0: 45 | return 46 | 47 | # Choose division boundary at median of distances. 48 | distances = [self.dist_fn(self.vp, p) for p in points] 49 | median = stats.median(distances) 50 | 51 | left_points = [] 52 | right_points = [] 53 | for point, distance in zip(points, distances): 54 | if distance >= median: 55 | self.right_min = min(distance, self.right_min) 56 | if distance > self.right_max: 57 | self.right_max = distance 58 | right_points.insert(0, point) # put furthest first 59 | else: 60 | right_points.append(point) 61 | else: 62 | self.left_min = min(distance, self.left_min) 63 | if distance > self.left_max: 64 | self.left_max = distance 65 | left_points.insert(0, point) # put furthest first 66 | else: 67 | left_points.append(point) 68 | 69 | if len(left_points) > 0: 70 | self.left = VPTree(points=left_points, dist_fn=self.dist_fn) 71 | 72 | if len(right_points) > 0: 73 | self.right = VPTree(points=right_points, dist_fn=self.dist_fn) 74 | 75 | def _is_leaf(self): 76 | return (self.left is None) and (self.right is None) 77 | 78 | def get_nearest_neighbor(self, query): 79 | """ Get single nearest neighbor. 80 | 81 | Parameters 82 | ---------- 83 | query : Any 84 | Query point. 85 | 86 | Returns 87 | ------- 88 | Any 89 | Single nearest neighbor. 90 | """ 91 | return self.get_n_nearest_neighbors(query, n_neighbors=1)[0] 92 | 93 | def get_n_nearest_neighbors(self, query, n_neighbors): 94 | """ Get `n_neighbors` nearest neigbors to `query` 95 | 96 | Parameters 97 | ---------- 98 | query : Any 99 | Query point. 100 | n_neighbors : int 101 | Number of neighbors to fetch. 102 | 103 | Returns 104 | ------- 105 | list 106 | List of `n_neighbors` nearest neighbors. 107 | """ 108 | if not isinstance(n_neighbors, int) or n_neighbors < 1: 109 | raise ValueError('n_neighbors must be strictly positive integer') 110 | neighbors = _AutoSortingList(max_size=n_neighbors) 111 | queue = collections.deque([self]) 112 | furthest_d = math.inf 113 | need_neighbors = True 114 | 115 | while queue: 116 | node = queue.popleft() 117 | if node is None: 118 | continue 119 | d = self.dist_fn(query, node.vp) 120 | 121 | if d < furthest_d or need_neighbors: 122 | neighbors.append((d, node.vp)) 123 | furthest_d = neighbors[-1][0] 124 | if need_neighbors: 125 | need_neighbors = len(neighbors) < n_neighbors 126 | 127 | if node._is_leaf(): 128 | continue 129 | 130 | if need_neighbors: 131 | if d < node.left_max + furthest_d: 132 | queue.append(node.left) 133 | if d >= node.right_min - furthest_d: 134 | queue.append(node.right) 135 | else: 136 | if node.left_min - furthest_d < d < node.left_max + furthest_d: 137 | queue.append(node.left) 138 | if node.right_min - furthest_d <= d <= node.right_max + furthest_d: 139 | queue.append(node.right) 140 | 141 | return list(neighbors) 142 | 143 | def get_all_in_range(self, query, max_distance): 144 | """ Find all neighbours within `max_distance`. 145 | 146 | Parameters 147 | ---------- 148 | query : Any 149 | Query point. 150 | max_distance : float 151 | Threshold distance for query. 152 | 153 | Returns 154 | ------- 155 | neighbors : list 156 | List of points within `max_distance`. 157 | 158 | Notes 159 | ----- 160 | Returned neighbors are not sorted according to distance. 161 | """ 162 | neighbors = list() 163 | nodes_to_visit = [(self, 0)] 164 | 165 | while len(nodes_to_visit) > 0: 166 | node, d0 = nodes_to_visit.pop(0) 167 | if node is None or d0 > max_distance: 168 | continue 169 | 170 | d = self.dist_fn(query, node.vp) 171 | if d < max_distance: 172 | neighbors.append((d, node.vp)) 173 | 174 | if node._is_leaf(): 175 | continue 176 | 177 | if node.left_min <= d <= node.left_max: 178 | nodes_to_visit.insert(0, (node.left, 0)) 179 | elif node.left_min - max_distance <= d <= node.left_max + max_distance: 180 | nodes_to_visit.append((node.left, 181 | node.left_min - d if d < node.left_min 182 | else d - node.left_max)) 183 | 184 | if node.right_min <= d <= node.right_max: 185 | nodes_to_visit.insert(0, (node.right, 0)) 186 | elif node.right_min - max_distance <= d <= node.right_max + max_distance: 187 | nodes_to_visit.append((node.right, 188 | node.right_min - d if d < node.right_min 189 | else d - node.right_max)) 190 | 191 | return neighbors 192 | 193 | 194 | class _AutoSortingList(list): 195 | 196 | """ Simple auto-sorting list. 197 | 198 | Inefficient for large sizes since the queue is sorted at 199 | each push. 200 | 201 | Parameters 202 | --------- 203 | size : int, optional 204 | Max queue size. 205 | """ 206 | 207 | def __init__(self, max_size=None, *args): 208 | super(_AutoSortingList, self).__init__(*args) 209 | self.max_size = max_size 210 | 211 | def append(self, item): 212 | """ insert `item` in sorted order 213 | 214 | Parameters 215 | ---------- 216 | item : Any 217 | Input item. 218 | """ 219 | self.insert(bisect.bisect_left(self, item), item) 220 | if self.max_size is not None and len(self) > self.max_size: 221 | self.pop() 222 | 223 | --------------------------------------------------------------------------------