├── test
│   ├── __init__.py
│   ├── test-jaccard.py
│   └── test-metrics.py
├── .dockerignore
├── ann_benchmarks
│   ├── algorithms
│   │   ├── __init__.py
│   │   ├── dummy_algo.py
│   │   ├── rpforest.py
│   │   ├── kdtree.py
│   │   ├── balltree.py
│   │   ├── annoy.py
│   │   ├── base.py
│   │   ├── n2.py
│   │   ├── sptag.py
│   │   ├── flann.py
│   │   ├── lshf.py
│   │   ├── dolphinnpy.py
│   │   ├── datasketch.py
│   │   ├── scann.py
│   │   ├── faiss_hnsw.py
│   │   ├── hnswlib.py
│   │   ├── kgraph.py
│   │   ├── mrpt.py
│   │   ├── puffinn.py
│   │   ├── nearpy.py
│   │   ├── milvus.py
│   │   ├── faiss_gpu.py
│   │   ├── faiss.py
│   │   ├── nmslib.py
│   │   ├── panng_ngt.py
│   │   ├── onng_ngt.py
│   │   ├── elasticsearch.py
│   │   ├── pynndescent.py
│   │   ├── elastiknn.py
│   │   ├── bruteforce.py
│   │   ├── definitions.py
│   │   └── subprocess.py
│   ├── constants.py
│   ├── __init__.py
│   ├── plotting
│   │   ├── __init__.py
│   │   ├── plot_variants.py
│   │   ├── utils.py
│   │   └── metrics.py
│   ├── data.py
│   ├── distance.py
│   ├── results.py
│   ├── main.py
│   └── runner.py
├── run_algorithm.py
├── results
│   ├── lastfm-64-dot.png
│   ├── glove-25-angular.png
│   ├── gist-960-euclidean.png
│   ├── glove-100-angular.png
│   ├── mnist-784-euclidean.png
│   ├── nytimes-256-angular.png
│   ├── sift-128-euclidean.png
│   └── fashion-mnist-784-euclidean.png
├── install
│   ├── Dockerfile.sklearn
│   ├── Dockerfile.datasketch
│   ├── Dockerfile.n2
│   ├── Dockerfile.mrpt
│   ├── Dockerfile.annoy
│   ├── Dockerfile.nearpy
│   ├── Dockerfile.rpforest
│   ├── Dockerfile.dolphinn
│   ├── Dockerfile.puffinn
│   ├── Dockerfile.mih
│   ├── Dockerfile
│   ├── Dockerfile.hnswlib
│   ├── Dockerfile.kgraph
│   ├── Dockerfile.pynndescent
│   ├── Dockerfile.flann
│   ├── Dockerfile.ngt
│   ├── Dockerfile.nmslib
│   ├── Dockerfile.faiss
│   ├── Dockerfile.scann
│   ├── Dockerfile.sptag
│   ├── Dockerfile.milvus
│   ├── Dockerfile.elasticsearch
│   └── Dockerfile.elastiknn
├── run.py
├── requirements.txt
├── .gitignore
├── protocol
│   ├── bf-runner
│   ├── ext-query-parameters.md
│   ├── ext-batch-queries.md
│   ├── ext-add-query-metric.md
│   ├── ext-prepared-queries.md
│   ├── bf-runner.py
│   └── specification.md
├── create_dataset.py
├── logging.conf
├── templates
│   ├── detail_page.html
│   ├── latex.template
│   ├── general.html
│   ├── summary.html
│   └── chartjs.template
├── LICENSE
├── .travis.yml
├── algosP.yaml
├── install.py
├── plot.py
├── README.md
└── create_website.py

/test/__init__.py:
--------------------------------------------------------------------------------
 1 |
--------------------------------------------------------------------------------
/.dockerignore:
--------------------------------------------------------------------------------
 1 | data
 2 | results
 3 |
--------------------------------------------------------------------------------
/ann_benchmarks/algorithms/__init__.py:
--------------------------------------------------------------------------------
 1 |
--------------------------------------------------------------------------------
/ann_benchmarks/constants.py:
--------------------------------------------------------------------------------
 1 | INDEX_DIR = 'indices'
 2 |
--------------------------------------------------------------------------------
/run_algorithm.py:
--------------------------------------------------------------------------------
 1 | from ann_benchmarks.runner import run_from_cmdline
 2 |
 3 | run_from_cmdline()
 4 |
--------------------------------------------------------------------------------
/ann_benchmarks/__init__.py:
--------------------------------------------------------------------------------
 1 | from __future__ import absolute_import
 2 | # from ann_benchmarks.main import *
 3 |
--------------------------------------------------------------------------------
/results/lastfm-64-dot.png:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmcinnes/ann-benchmarks/master/results/lastfm-64-dot.png -------------------------------------------------------------------------------- /results/glove-25-angular.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmcinnes/ann-benchmarks/master/results/glove-25-angular.png -------------------------------------------------------------------------------- /ann_benchmarks/plotting/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from ann_benchmarks.plotting import * 3 | -------------------------------------------------------------------------------- /results/gist-960-euclidean.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmcinnes/ann-benchmarks/master/results/gist-960-euclidean.png -------------------------------------------------------------------------------- /results/glove-100-angular.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmcinnes/ann-benchmarks/master/results/glove-100-angular.png -------------------------------------------------------------------------------- /results/mnist-784-euclidean.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmcinnes/ann-benchmarks/master/results/mnist-784-euclidean.png -------------------------------------------------------------------------------- /results/nytimes-256-angular.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmcinnes/ann-benchmarks/master/results/nytimes-256-angular.png -------------------------------------------------------------------------------- /results/sift-128-euclidean.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmcinnes/ann-benchmarks/master/results/sift-128-euclidean.png -------------------------------------------------------------------------------- /install/Dockerfile.sklearn: -------------------------------------------------------------------------------- 1 | FROM ann-benchmarks 2 | 3 | RUN pip3 install scikit-learn 4 | RUN python3 -c 'import sklearn' 5 | -------------------------------------------------------------------------------- /install/Dockerfile.datasketch: -------------------------------------------------------------------------------- 1 | FROM ann-benchmarks 2 | 3 | RUN pip3 install datasketch 4 | RUN python3 -c 'import datasketch' 5 | -------------------------------------------------------------------------------- /install/Dockerfile.n2: -------------------------------------------------------------------------------- 1 | FROM ann-benchmarks 2 | 3 | RUN pip3 install cython 4 | RUN pip3 install n2 5 | RUN python3 -c 'import n2' 6 | -------------------------------------------------------------------------------- /install/Dockerfile.mrpt: -------------------------------------------------------------------------------- 1 | FROM ann-benchmarks 2 | 3 | RUN pip3 install sklearn 4 | RUN pip3 install git+https://github.com/vioshyvo/mrpt 5 | -------------------------------------------------------------------------------- /results/fashion-mnist-784-euclidean.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmcinnes/ann-benchmarks/master/results/fashion-mnist-784-euclidean.png -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | from ann_benchmarks.main import main 2 | from multiprocessing import freeze_support 3 | 4 | if __name__ == "__main__": 5 | freeze_support() 6 | main() 7 | -------------------------------------------------------------------------------- /install/Dockerfile.annoy: -------------------------------------------------------------------------------- 1 | FROM ann-benchmarks 2 | 3 | RUN git clone https://github.com/spotify/annoy 4 | RUN cd annoy && python3 setup.py install 5 | RUN python3 -c 'import annoy' 6 | -------------------------------------------------------------------------------- /install/Dockerfile.nearpy: -------------------------------------------------------------------------------- 1 | FROM ann-benchmarks 2 | 3 | RUN apt-get install -y libhdf5-openmpi-dev cython 4 | RUN pip3 install nearpy bitarray redis sklearn 5 | RUN python3 -c 'import nearpy' -------------------------------------------------------------------------------- /install/Dockerfile.rpforest: -------------------------------------------------------------------------------- 1 | FROM ann-benchmarks 2 | 3 | RUN git clone https://github.com/lyst/rpforest 4 | RUN cd rpforest && python3 setup.py install 5 | RUN python3 -c 'import rpforest' 6 | -------------------------------------------------------------------------------- /install/Dockerfile.dolphinn: -------------------------------------------------------------------------------- 1 | FROM ann-benchmarks 2 | 3 | RUN git clone https://github.com/ipsarros/DolphinnPy lib-dolphinnpy 4 | ENV PYTHONPATH lib-dolphinnpy 5 | RUN python3 -c 'import dolphinn' 6 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | ansicolors==1.1.8 2 | docker==2.6.1 3 | h5py==2.7.1 4 | matplotlib==2.1.0 5 | numpy==1.13.3 6 | pyyaml==3.12 7 | psutil==5.6.6 8 | scipy==1.0.0 9 | scikit-learn==0.19.1 10 | jinja2==2.10 11 | -------------------------------------------------------------------------------- /install/Dockerfile.puffinn: -------------------------------------------------------------------------------- 1 | FROM ann-benchmarks 2 | 3 | RUN pip3 install pypandoc 4 | RUN git clone https://github.com/puffinn/puffinn 5 | RUN cd puffinn && python3 setup.py install 6 | RUN python3 -c 'import puffinn' 7 | -------------------------------------------------------------------------------- /install/Dockerfile.mih: -------------------------------------------------------------------------------- 1 | FROM ann-benchmarks 2 | RUN apt-get update && apt-get install -y cmake libhdf5-dev 3 | RUN git clone https://github.com/maumueller/mih 4 | RUN cd mih && mkdir bin && cd bin && cmake ../ && make -j4 5 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | *.pyc 3 | *.o 4 | protocol/c/fr-* 5 | 6 | install/*.txt 7 | install/*.yaml 8 | install/lib-*/ 9 | data/* 10 | *.class 11 | 12 | *.log 13 | 14 | results/* 15 | !results/*.png 16 | 17 | venv 18 | 19 | .idea 20 | 
-------------------------------------------------------------------------------- /protocol/bf-runner: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | where="$(dirname "$0")" 4 | if [ "$1" = "2" -o "$1" = "-2" -o "$1" = "--2" ]; then 5 | PYTHON=python2 6 | else 7 | PYTHON=python3 8 | fi 9 | export PYTHONPATH="$where/..:$PYTHONPATH" 10 | exec $PYTHON "$where/bf-runner.py" 11 | -------------------------------------------------------------------------------- /install/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM ubuntu:18.04 2 | 3 | RUN apt-get update 4 | RUN apt-get install -y python3-numpy python3-scipy python3-pip build-essential git 5 | RUN pip3 install -U pip 6 | 7 | WORKDIR /home/app 8 | COPY requirements.txt run_algorithm.py ./ 9 | RUN pip3 install -rrequirements.txt 10 | 11 | ENTRYPOINT ["python3", "run_algorithm.py"] 12 | -------------------------------------------------------------------------------- /install/Dockerfile.hnswlib: -------------------------------------------------------------------------------- 1 | FROM ann-benchmarks 2 | 3 | RUN apt-get install -y python-setuptools python-pip 4 | RUN pip3 install pybind11 numpy setuptools 5 | RUN git clone https://github.com/nmslib/hnsw.git;cd hnsw; git checkout denorm 6 | 7 | RUN cd hnsw/python_bindings; python3 setup.py install 8 | 9 | RUN python3 -c 'import hnswlib' 10 | 11 | -------------------------------------------------------------------------------- /install/Dockerfile.kgraph: -------------------------------------------------------------------------------- 1 | FROM ann-benchmarks 2 | 3 | RUN apt-get update && apt-get install -y libboost-timer-dev libboost-chrono-dev libboost-program-options-dev libboost-system-dev libboost-python-dev 4 | RUN git clone https://github.com/aaalgo/kgraph 5 | RUN cd kgraph && python3 setup.py build && python3 setup.py install 6 | RUN python3 -c 'import pykgraph' 7 | -------------------------------------------------------------------------------- /install/Dockerfile.pynndescent: -------------------------------------------------------------------------------- 1 | FROM ann-benchmarks 2 | 3 | 4 | RUN apt-get -y install llvm-10 5 | RUN export LLVM_CONFIG=/usr/bin/llvm-config-10; pip3 install 'numba==0.51.2' 'llvmlite==0.34' 'numpy==1.16.4' scikit-learn icc_rt 6 | RUN ldconfig 7 | RUN pip3 install 'numpy==1.17' 8 | RUN pip3 install 'pynndescent>=0.5' 9 | RUN python3 -c 'import pynndescent' 10 | -------------------------------------------------------------------------------- /install/Dockerfile.flann: -------------------------------------------------------------------------------- 1 | FROM ann-benchmarks 2 | 3 | RUN apt-get update && apt-get install -y cmake pkg-config liblz4-dev 4 | RUN git clone https://github.com/mariusmuja/flann 5 | RUN mkdir flann/build 6 | RUN cd flann/build && cmake .. 
7 | RUN cd flann/build && make -j4 8 | RUN cd flann/build && make install 9 | RUN pip3 install sklearn 10 | RUN python3 -c 'import pyflann' 11 | -------------------------------------------------------------------------------- /create_dataset.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from ann_benchmarks.datasets import DATASETS, get_dataset_fn 3 | 4 | if __name__ == "__main__": 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument( 7 | '--dataset', 8 | choices=DATASETS.keys(), 9 | required=True) 10 | args = parser.parse_args() 11 | fn = get_dataset_fn(args.dataset) 12 | DATASETS[args.dataset](fn) 13 | -------------------------------------------------------------------------------- /install/Dockerfile.ngt: -------------------------------------------------------------------------------- 1 | FROM ann-benchmarks 2 | 3 | RUN apt-get update 4 | RUN apt-get install -y git cmake g++ python3 python3-setuptools python3-pip 5 | RUN pip3 install wheel pybind11 6 | RUN git clone https://github.com/yahoojapan/ngt.git 7 | RUN mkdir -p ngt/build 8 | RUN cd ngt/build && cmake .. 9 | RUN cd ngt/build && make && make install 10 | RUN ldconfig 11 | RUN cd ngt/python && python3 setup.py bdist_wheel 12 | RUN pip3 install ngt/python/dist/ngt-*-linux_x86_64.whl 13 | 14 | -------------------------------------------------------------------------------- /install/Dockerfile.nmslib: -------------------------------------------------------------------------------- 1 | FROM ann-benchmarks 2 | 3 | RUN apt-get update && apt-get install -y cmake libboost-all-dev libeigen3-dev libgsl0-dev 4 | RUN git clone https://github.com/searchivarius/nmslib.git 5 | RUN cd nmslib/similarity_search && cmake . -DWITH_EXTRAS=1 6 | RUN cd nmslib/similarity_search && make -j4 7 | RUN pip3 install pybind11 8 | RUN cd nmslib/python_bindings && python3 setup.py build 9 | RUN cd nmslib/python_bindings && python3 setup.py install 10 | RUN python3 -c 'import nmslib' 11 | -------------------------------------------------------------------------------- /ann_benchmarks/plotting/plot_variants.py: -------------------------------------------------------------------------------- 1 | from ann_benchmarks.plotting.metrics import all_metrics as metrics 2 | 3 | all_plot_variants = { 4 | "recall/time": ("k-nn", "qps"), 5 | "recall/buildtime": ("k-nn", "build"), 6 | "recall/indexsize": ("k-nn", "indexsize"), 7 | "recall/distcomps": ("k-nn", "distcomps"), 8 | "rel/time": ("rel", "qps"), 9 | "recall/candidates": ("k-nn", "candidates"), 10 | "recall/qpssize": ("k-nn", "queriessize"), 11 | "eps/time": ("epsilon", "qps"), 12 | "largeeps/time": ("largeepsilon", "qps") 13 | } 14 | -------------------------------------------------------------------------------- /logging.conf: -------------------------------------------------------------------------------- 1 | [loggers] 2 | keys=root,annb 3 | 4 | [handlers] 5 | keys=consoleHandler,fileHandler 6 | 7 | [formatters] 8 | keys=simpleFormatter 9 | 10 | [formatter_simpleFormatter] 11 | format=%(asctime)s - %(name)s - %(levelname)s - %(message)s 12 | datefmt= 13 | 14 | [handler_consoleHandler] 15 | class=StreamHandler 16 | level=INFO 17 | formatter=simpleFormatter 18 | args=(sys.stdout,) 19 | 20 | [handler_fileHandler] 21 | class=FileHandler 22 | level=INFO 23 | formatter=simpleFormatter 24 | args=('annb.log','w') 25 | 26 | [logger_root] 27 | level=WARN 28 | handlers=consoleHandler 29 | 30 | [logger_annb] 31 | level=INFO 32 | 
handlers=consoleHandler,fileHandler 33 | qualname=annb 34 | propagate=0 -------------------------------------------------------------------------------- /ann_benchmarks/algorithms/dummy_algo.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import numpy as np 3 | from ann_benchmarks.algorithms.base import BaseANN 4 | 5 | 6 | class DummyAlgoMt(BaseANN): 7 | def __init__(self, metric): 8 | self.name = 'DummyAlgoMultiThread' 9 | 10 | def fit(self, X): 11 | self.len = len(X) - 1 12 | 13 | def query(self, v, n): 14 | return np.random.randint(self.len, size=n) 15 | 16 | 17 | class DummyAlgoSt(BaseANN): 18 | def __init__(self, metric): 19 | self.name = 'DummyAlgoSingleThread' 20 | 21 | def fit(self, X): 22 | self.len = len(X) - 1 23 | 24 | def query(self, v, n): 25 | return np.random.randint(self.len, size=n) 26 | -------------------------------------------------------------------------------- /ann_benchmarks/algorithms/rpforest.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import rpforest 3 | import numpy 4 | from ann_benchmarks.algorithms.base import BaseANN 5 | 6 | 7 | class RPForest(BaseANN): 8 | def __init__(self, leaf_size, n_trees): 9 | self.name = 'RPForest(leaf_size=%d, n_trees=%d)' % (leaf_size, n_trees) 10 | self._model = rpforest.RPForest(leaf_size=leaf_size, no_trees=n_trees) 11 | 12 | def fit(self, X): 13 | if X.dtype != numpy.double: 14 | X = numpy.array(X).astype(numpy.double) 15 | self._model.fit(X) 16 | 17 | def query(self, v, n): 18 | if v.dtype != numpy.double: 19 | v = numpy.array(v).astype(numpy.double) 20 | return self._model.query(v, n) 21 | -------------------------------------------------------------------------------- /install/Dockerfile.faiss: -------------------------------------------------------------------------------- 1 | FROM ann-benchmarks 2 | 3 | RUN apt-get update && apt-get install -y libopenblas-base libopenblas-dev libpython3-dev swig python3-dev libssl-dev wget 4 | RUN wget https://github.com/Kitware/CMake/releases/download/v3.18.3/cmake-3.18.3-Linux-x86_64.sh && mkdir cmake && sh cmake-3.18.3-Linux-x86_64.sh --skip-license --prefix=cmake && rm cmake-3.18.3-Linux-x86_64.sh 5 | RUN git clone https://github.com/facebookresearch/faiss lib-faiss 6 | RUN cd lib-faiss && ../cmake/bin/cmake -DFAISS_ENABLE_GPU=OFF -DPython_EXECUTABLE=/usr/bin/python3 -B build . 
7 | RUN cd lib-faiss && make -C build -j4 8 | RUN cd lib-faiss && cd build && cd faiss && cd python && python3 setup.py install && cd && rm -rf cmake 9 | RUN python3 -c 'import faiss; print(faiss.IndexFlatL2)' 10 | -------------------------------------------------------------------------------- /test/test-jaccard.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import numpy 3 | from ann_benchmarks.distance import jaccard, transform_dense_to_sparse 4 | 5 | class TestJaccard(unittest.TestCase): 6 | def setUp(self): 7 | pass 8 | 9 | def test_similarity(self): 10 | a = [1, 2, 3, 4] 11 | b = [] 12 | c = [1, 2] 13 | d = [5, 6] 14 | 15 | self.assertAlmostEqual(jaccard(a, b), 0.0) 16 | self.assertAlmostEqual(jaccard(a, a), 1.0) 17 | self.assertAlmostEqual(jaccard(a, c), 0.5) 18 | self.assertAlmostEqual(jaccard(c, d), 0.0) 19 | 20 | def test_transformation(self): 21 | X = numpy.array([[True, False, False], [True, False, True], [False, False, True]]) 22 | self.assertEqual(transform_dense_to_sparse(X), [[0],[0, 2], [2]]) 23 | 24 | -------------------------------------------------------------------------------- /ann_benchmarks/algorithms/kdtree.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import sklearn.neighbors 3 | import sklearn.preprocessing 4 | from ann_benchmarks.algorithms.base import BaseANN 5 | 6 | 7 | class KDTree(BaseANN): 8 | def __init__(self, metric, leaf_size=20): 9 | self._leaf_size = leaf_size 10 | self._metric = metric 11 | self.name = 'KDTree(leaf_size=%d)' % self._leaf_size 12 | 13 | def fit(self, X): 14 | if self._metric == 'angular': 15 | X = sklearn.preprocessing.normalize(X, axis=1, norm='l2') 16 | self._tree = sklearn.neighbors.KDTree(X, leaf_size=self._leaf_size) 17 | 18 | def query(self, v, n): 19 | if self._metric == 'angular': 20 | v = sklearn.preprocessing.normalize([v], axis=1, norm='l2')[0] 21 | dist, ind = self._tree.query([v], k=n) 22 | return ind[0] 23 | -------------------------------------------------------------------------------- /ann_benchmarks/algorithms/balltree.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import sklearn.neighbors 3 | import sklearn.preprocessing 4 | from ann_benchmarks.algorithms.base import BaseANN 5 | 6 | 7 | class BallTree(BaseANN): 8 | def __init__(self, metric, leaf_size=20): 9 | self._leaf_size = leaf_size 10 | self._metric = metric 11 | self.name = 'BallTree(leaf_size=%d)' % self._leaf_size 12 | 13 | def fit(self, X): 14 | if self._metric == 'angular': 15 | X = sklearn.preprocessing.normalize(X, axis=1, norm='l2') 16 | self._tree = sklearn.neighbors.BallTree(X, leaf_size=self._leaf_size) 17 | 18 | def query(self, v, n): 19 | if self._metric == 'angular': 20 | v = sklearn.preprocessing.normalize([v], axis=1, norm='l2')[0] 21 | dist, ind = self._tree.query([v], k=n) 22 | return ind[0] 23 | -------------------------------------------------------------------------------- /templates/detail_page.html: -------------------------------------------------------------------------------- 1 | {% extends "general.html" %} 2 | {% block content %} 3 |
<div class="container">
 4 |     {% for item in plot_data.keys() %}
 5 |         {% if item=="normal" %}
 6 |             {% if batch %}
 7 |                 <h2>Plots for {{title}} in batch mode</h2>
 8 |             {% else %}
 9 |                 <h2>Plots for {{title}}</h2>
10 |             {% endif %}
11 |         {% elif item=="scatter" and args.scatter %}
12 |             {% if batch %}
13 |                 <h2>Scatterplots for {{title}} in batch mode</h2>
14 |             {% else %}
15 |                 <h2>Scatterplots for {{title}}</h2>
16 |             {% endif %}
17 |         {% endif %}
18 |         {% for plot in plot_data[item] %}
19 |             {{ plot }}
20 |         {% endfor %}
21 |     </div>
22 | {% endfor %} 23 | {% endblock %} 24 | -------------------------------------------------------------------------------- /ann_benchmarks/algorithms/annoy.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import annoy 3 | from ann_benchmarks.algorithms.base import BaseANN 4 | 5 | 6 | class Annoy(BaseANN): 7 | def __init__(self, metric, n_trees): 8 | self._n_trees = n_trees 9 | self._search_k = None 10 | self._metric = metric 11 | 12 | def fit(self, X): 13 | self._annoy = annoy.AnnoyIndex(X.shape[1], metric=self._metric) 14 | for i, x in enumerate(X): 15 | self._annoy.add_item(i, x.tolist()) 16 | self._annoy.build(self._n_trees) 17 | 18 | def set_query_arguments(self, search_k): 19 | self._search_k = search_k 20 | 21 | def query(self, v, n): 22 | return self._annoy.get_nns_by_vector(v.tolist(), n, self._search_k) 23 | 24 | def __str__(self): 25 | return 'Annoy(n_trees=%d, search_k=%d)' % (self._n_trees, 26 | self._search_k) 27 | -------------------------------------------------------------------------------- /ann_benchmarks/algorithms/base.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import psutil 3 | 4 | 5 | class BaseANN(object): 6 | def done(self): 7 | pass 8 | 9 | def get_memory_usage(self): 10 | """Return the current memory usage of this algorithm instance 11 | (in kilobytes), or None if this information is not available.""" 12 | # return in kB for backwards compatibility 13 | return psutil.Process().memory_info().rss / 1024 14 | 15 | def fit(self, X): 16 | pass 17 | 18 | def query(self, q, n): 19 | return [] # array of candidate indices 20 | 21 | def batch_query(self, X, n): 22 | self.res = [] 23 | for q in X: 24 | self.res.append(self.query(q, n)) 25 | 26 | def get_batch_results(self): 27 | return self.res 28 | 29 | def get_additional(self): 30 | return {} 31 | 32 | def __str__(self): 33 | return self.name 34 | -------------------------------------------------------------------------------- /templates/latex.template: -------------------------------------------------------------------------------- 1 | 2 | \begin{figure} 3 | \centering 4 | \begin{tikzpicture} 5 | \begin{axis}[ 6 | xlabel={ {{xlabel}} }, 7 | ylabel={ {{ylabel}} }, 8 | ymode = log, 9 | yticklabel style={/pgf/number format/fixed, 10 | /pgf/number format/precision=3}, 11 | legend style = { anchor=west}, 12 | cycle list name = black white 13 | ] 14 | {% for algo in plot_data %} 15 | {% if algo.scatter %} 16 | \addplot [only marks] coordinates { 17 | {% else %} 18 | \addplot coordinates { 19 | {% endif %} 20 | {% for coord in algo.coords %} 21 | ({{ coord[0]}}, {{ coord[1] }}) 22 | {% endfor %} 23 | }; 24 | \addlegendentry{ {{algo.name}} }; 25 | {% endfor %} 26 | \end{axis} 27 | \end{tikzpicture} 28 | \caption{ {{caption}} } 29 | \label{} 30 | \end{figure} 31 | -------------------------------------------------------------------------------- /ann_benchmarks/algorithms/n2.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import n2 3 | from ann_benchmarks.algorithms.base import BaseANN 4 | 5 | 6 | class N2(BaseANN): 7 | def __init__(self, metric, method_param): 8 | self._metric = metric 9 | self._m = method_param['M'] 10 | self._m0 = self._m * 2 11 | self._ef_construction = method_param['efConstruction'] 12 | self._n_threads = 1 13 | self._ef_search = -1 14 | 15 | def fit(self, 
X): 16 | self._n2 = n2.HnswIndex(X.shape[1], self._metric) 17 | for x in X: 18 | self._n2.add_data(x) 19 | self._n2.build(m=self._m, max_m0=self._m0, ef_construction=self._ef_construction, n_threads=self._n_threads, graph_merging='merge_level0') 20 | 21 | def set_query_arguments(self, ef): 22 | self._ef_search = ef 23 | 24 | def query(self, v, n): 25 | return self._n2.search_by_vector(v, n, self._ef_search) 26 | 27 | def __str__(self): 28 | return "N2 (M%d_efCon%d)" % (self._m, self._ef_construction) 29 | -------------------------------------------------------------------------------- /ann_benchmarks/algorithms/sptag.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import SPTAG 3 | from ann_benchmarks.algorithms.base import BaseANN 4 | 5 | 6 | class Sptag(BaseANN): 7 | def __init__(self, metric, algo): 8 | self._algo = str(algo) 9 | self._metric = { 10 | 'angular': 'Cosine', 'euclidean': 'L2'}[metric] 11 | 12 | def fit(self, X): 13 | self._sptag = SPTAG.AnnIndex(self._algo, 'Float', X.shape[1]) 14 | self._sptag.SetBuildParam("NumberOfThreads", '32') 15 | self._sptag.SetBuildParam("DistCalcMethod", self._metric) 16 | self._sptag.Build(X, X.shape[0]) 17 | 18 | def set_query_arguments(self, MaxCheck): 19 | self._maxCheck = MaxCheck 20 | self._sptag.SetSearchParam("MaxCheck", str(self._maxCheck)) 21 | 22 | def query(self, v, k): 23 | return self._sptag.Search(v, k)[0] 24 | 25 | def __str__(self): 26 | return 'Sptag(metric=%s, algo=%s, check=%d)' % (self._metric, 27 | self._algo, self._maxCheck) 28 | 29 | -------------------------------------------------------------------------------- /install/Dockerfile.scann: -------------------------------------------------------------------------------- 1 | FROM ann-benchmarks 2 | 3 | RUN apt-get install -y software-properties-common curl gnupg rsync 4 | 5 | RUN curl https://bazel.build/bazel-release.pub.gpg | apt-key add - 6 | RUN echo "deb [arch=amd64] https://storage.googleapis.com/bazel-apt stable jdk1.8" | tee /etc/apt/sources.list.d/bazel.list 7 | RUN apt-get update && apt-get install -y bazel-3.4.1 8 | 9 | RUN add-apt-repository -y ppa:ubuntu-toolchain-r/test 10 | RUN apt-get update 11 | RUN apt-get install -y g++-9 clang-8 12 | 13 | RUN pip3 install --upgrade pip 14 | RUN git clone https://github.com/google-research/google-research.git --depth=1 15 | RUN cd google-research/scann && python3 configure.py 16 | RUN PY3="$(which python3)" && cd google-research/scann && PYTHON_BIN_PATH=$PY3 CC=clang-8 bazel-3.4.1 build -c opt --copt=-mavx2 --copt=-mfma --cxxopt="-D_GLIBCXX_USE_CXX11_ABI=0" --cxxopt="-std=c++17" --copt=-fsized-deallocation --copt=-w :build_pip_pkg 17 | RUN cd google-research/scann && PYTHON=python3 ./bazel-bin/build_pip_pkg && pip3 install *.whl 18 | RUN python3 -c 'import scann' 19 | -------------------------------------------------------------------------------- /ann_benchmarks/algorithms/flann.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import pyflann 3 | import numpy 4 | import sklearn.preprocessing 5 | from ann_benchmarks.algorithms.base import BaseANN 6 | 7 | 8 | class FLANN(BaseANN): 9 | def __init__(self, metric, target_precision): 10 | self._target_precision = target_precision 11 | self.name = 'FLANN(target_precision=%f)' % self._target_precision 12 | self._metric = metric 13 | 14 | def fit(self, X): 15 | self._flann = pyflann.FLANN( 16 | 
target_precision=self._target_precision, 17 | algorithm='autotuned', log_level='info') 18 | if self._metric == 'angular': 19 | X = sklearn.preprocessing.normalize(X, axis=1, norm='l2') 20 | self._flann.build_index(X) 21 | 22 | def query(self, v, n): 23 | if self._metric == 'angular': 24 | v = sklearn.preprocessing.normalize([v], axis=1, norm='l2')[0] 25 | if v.dtype != numpy.float32: 26 | v = v.astype(numpy.float32) 27 | return self._flann.nn_index(v, n)[0][0] 28 | -------------------------------------------------------------------------------- /ann_benchmarks/algorithms/lshf.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import sklearn.neighbors 3 | import sklearn.preprocessing 4 | from ann_benchmarks.algorithms.base import BaseANN 5 | 6 | 7 | class LSHF(BaseANN): 8 | def __init__(self, metric, n_estimators=10, n_candidates=50): 9 | self.name = 'LSHF(n_est=%d, n_cand=%d)' % (n_estimators, n_candidates) 10 | self._metric = metric 11 | self._n_estimators = n_estimators 12 | self._n_candidates = n_candidates 13 | 14 | def fit(self, X): 15 | self._lshf = sklearn.neighbors.LSHForest( 16 | n_estimators=self._n_estimators, n_candidates=self._n_candidates) 17 | if self._metric == 'angular': 18 | X = sklearn.preprocessing.normalize(X, axis=1, norm='l2') 19 | self._lshf.fit(X) 20 | 21 | def query(self, v, n): 22 | if self._metric == 'angular': 23 | v = sklearn.preprocessing.normalize([v], axis=1, norm='l2')[0] 24 | return self._lshf.kneighbors([v], return_distance=False, 25 | n_neighbors=n)[0] 26 | -------------------------------------------------------------------------------- /ann_benchmarks/algorithms/dolphinnpy.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import sys 3 | sys.path.append("install/lib-dolphinnpy") # noqa 4 | import numpy 5 | import ctypes 6 | from dolphinn import Dolphinn 7 | from utils import findmean, isotropize 8 | from ann_benchmarks.algorithms.base import BaseANN 9 | 10 | 11 | class DolphinnPy(BaseANN): 12 | def __init__(self, num_probes): 13 | self.name = 'Dolphinn(num_probes={} )'.format(num_probes) 14 | self.num_probes = num_probes 15 | self.m = 1 16 | self._index = None 17 | 18 | def fit(self, X): 19 | if X.dtype != numpy.float32: 20 | X = numpy.array(X, dtype=numpy.float32) 21 | d = X.shape[1] 22 | self.m = findmean(X, d, 10) 23 | X = isotropize(X, d, self.m) 24 | hypercube_dim = int(numpy.log2(len(X))) - 2 25 | self._index = Dolphinn(X, d, hypercube_dim) 26 | 27 | def query(self, v, n): 28 | q = numpy.array([v]) 29 | q = isotropize(q, len(v), self.m) 30 | res = self._index.queries(q, n, self.num_probes) 31 | return res[0] 32 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Erik Bernhardsson 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | 
copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/ann_benchmarks/algorithms/datasketch.py:
--------------------------------------------------------------------------------
 1 | from __future__ import absolute_import
 2 | from datasketch import MinHashLSHForest, MinHash
 3 | from ann_benchmarks.algorithms.base import BaseANN
 4 |
 5 |
 6 | class DataSketch(BaseANN):
 7 |     def __init__(self, metric, n_perm, n_rep):
 8 |         if metric not in ('jaccard',):
 9 |             raise NotImplementedError(
10 |                 "Datasketch doesn't support metric %s" % metric)
11 |         self._n_perm = n_perm
12 |         self._n_rep = n_rep
13 |         self._metric = metric
14 |         self.name = 'Datasketch(n_perm=%d, n_rep=%d)' % (n_perm, n_rep)
15 |
16 |     def fit(self, X):
17 |         self._index = MinHashLSHForest(num_perm=self._n_perm, l=self._n_rep)
18 |         for i, x in enumerate(X):
19 |             m = MinHash(num_perm=self._n_perm)
20 |             for e in x:
21 |                 m.update(str(e).encode('utf8'))
22 |             self._index.add(str(i), m)
23 |         self._index.index()
24 |
25 |     def query(self, v, n):
26 |         m = MinHash(num_perm=self._n_perm)
27 |         for e in v:
28 |             m.update(str(e).encode('utf8'))
29 |         return map(int, self._index.query(m, n))
30 |
--------------------------------------------------------------------------------
/ann_benchmarks/algorithms/scann.py:
--------------------------------------------------------------------------------
 1 | from __future__ import absolute_import
 2 | import os
 3 | import numpy as np
 4 | import scann
 5 | from ann_benchmarks.algorithms.base import BaseANN
 6 |
 7 | class Scann(BaseANN):
 8 |
 9 |     def __init__(self, n_leaves, avq_threshold, dims_per_block):
10 |         self.name = "scann n_leaves={} avq_threshold={:.02f} dims_per_block={}".format(
11 |             n_leaves, avq_threshold, dims_per_block)
12 |         self.n_leaves = n_leaves
13 |         self.avq_threshold = avq_threshold
14 |         self.dims_per_block = dims_per_block
15 |
16 |     def fit(self, X):
17 |         X[np.linalg.norm(X, axis=1) == 0] = 1.0 / np.sqrt(X.shape[1])
18 |         X /= np.linalg.norm(X, axis=1)[:, np.newaxis]
19 |
20 |         self.searcher = scann.scann_ops_pybind.builder(X, 10, "dot_product").tree(
21 |             self.n_leaves, 1, training_sample_size=350000, spherical=True, quantize_centroids=True).score_ah(
22 |             self.dims_per_block, anisotropic_quantization_threshold=self.avq_threshold).reorder(
23 |             1).build()
24 |
25 |     def set_query_arguments(self, leaves_reorder):
26 |         self.leaves_to_search, self.reorder = leaves_reorder
27 |
28 |     def query(self, v, n):
29 |         return self.searcher.search(v, n, self.reorder, self.leaves_to_search)[0]
30 |
--------------------------------------------------------------------------------
/install/Dockerfile.sptag:
--------------------------------------------------------------------------------
 1 | # Adopted from https://github.com/microsoft/SPTAG/blob/master/Dockerfile
 2 |
 3 | FROM ann-benchmarks
 4 |
 5 | RUN git clone https://github.com/microsoft/SPTAG
 6 | RUN apt-get update && apt-get -y install wget build-essential libtbb-dev software-properties-common swig
 7 |
 8
| # cmake >= 3.12 is required 9 | RUN wget "https://github.com/Kitware/CMake/releases/download/v3.14.4/cmake-3.14.4-Linux-x86_64.tar.gz" -q -O - \ 10 | | tar -xz --strip-components=1 -C /usr/local 11 | 12 | # specific version of boost 13 | RUN wget "https://dl.bintray.com/boostorg/release/1.67.0/source/boost_1_67_0.tar.gz" -q -O - \ 14 | | tar -xz && \ 15 | cd boost_1_67_0 && \ 16 | ./bootstrap.sh && \ 17 | ./b2 install && \ 18 | # update ld cache so it finds boost in /usr/local/lib 19 | ldconfig && \ 20 | cd .. && rm -rf boost_1_67_0 21 | 22 | # SPTAG defaults to Python 2 if it's found on the system, so as a hack, we remove it. See https://github.com/microsoft/SPTAG/blob/master/Wrappers/CMakeLists.txt 23 | RUN apt-get -y remove libpython2.7 24 | 25 | # Compile 26 | RUN cd SPTAG && mkdir build && cd build && cmake .. && make && cd .. 27 | 28 | # so python can find the SPTAG module 29 | ENV PYTHONPATH=/home/app/SPTAG/Release 30 | RUN python3 -c 'import SPTAG' 31 | -------------------------------------------------------------------------------- /install/Dockerfile.milvus: -------------------------------------------------------------------------------- 1 | # Install Milvus 2 | FROM milvusdb/milvus:0.6.0-cpu-d120719-2b40dd as milvus 3 | RUN apt-get update 4 | RUN apt-get install -y wget 5 | RUN wget https://raw.githubusercontent.com/milvus-io/docs/master/v0.6.0/assets/server_config.yaml 6 | RUN sed -i 's/cpu_cache_capacity: 16/cpu_cache_capacity: 4/' server_config.yaml # otherwise my Docker blows up 7 | RUN mv server_config.yaml /var/lib/milvus/conf/server_config.yaml 8 | 9 | # Switch back to ANN-benchmarks base image and copy all files 10 | FROM ann-benchmarks 11 | COPY --from=milvus /var/lib/milvus /var/lib/milvus 12 | ENV LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:/var/lib/milvus/lib" 13 | RUN apt-get update 14 | RUN apt-get install -y libmysqlclient-dev 15 | 16 | # Python client 17 | RUN pip3 install pymilvus==0.2.7 18 | 19 | # Fixing some version incompatibility thing 20 | RUN pip3 install numpy==1.18 scipy==1.1.0 scikit-learn==0.21 21 | 22 | # Dumb entrypoint thing that runs the daemon as well 23 | RUN echo '#!/bin/bash' >> entrypoint.sh 24 | RUN echo '/var/lib/milvus/bin/milvus_server -d -c /var/lib/milvus/conf/server_config.yaml -l /var/lib/milvus/conf/log_config.conf' >> entrypoint.sh 25 | RUN echo 'sleep 5' >> entrypoint.sh 26 | RUN echo 'python3 run_algorithm.py "$@"' >> entrypoint.sh 27 | RUN chmod u+x entrypoint.sh 28 | ENTRYPOINT ["/home/app/entrypoint.sh"] 29 | -------------------------------------------------------------------------------- /ann_benchmarks/data.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import numpy 3 | 4 | 5 | def float_parse_entry(line): 6 | return [float(x) for x in line.strip().split()] 7 | 8 | 9 | def float_unparse_entry(entry): 10 | return " ".join(map(str, entry)) 11 | 12 | 13 | def int_parse_entry(line): 14 | return frozenset([int(x) for x in line.strip().split()]) 15 | 16 | 17 | def int_unparse_entry(entry): 18 | return " ".join(map(str, map(int, entry))) 19 | 20 | 21 | def bit_parse_entry(line): 22 | return [bool(int(x)) for x in list(line.strip() 23 | .replace(" ", "") 24 | .replace("\t", ""))] 25 | 26 | 27 | def bit_unparse_entry(entry): 28 | return " ".join(map(lambda el: "1" if el else "0", entry)) 29 | 30 | 31 | type_info = { 32 | "float": { 33 | "type": numpy.float, 34 | "parse_entry": float_parse_entry, 35 | "unparse_entry": float_unparse_entry, 36 | 
"finish_entries": numpy.vstack 37 | }, 38 | "bit": { 39 | "type": numpy.bool_, 40 | "parse_entry": bit_parse_entry, 41 | "unparse_entry": bit_unparse_entry 42 | }, 43 | "int": { 44 | "type": numpy.object, 45 | "parse_entry": int_parse_entry, 46 | "unparse_entry": int_unparse_entry, 47 | }, 48 | } 49 | -------------------------------------------------------------------------------- /ann_benchmarks/algorithms/faiss_hnsw.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import os 3 | import faiss 4 | import numpy as np 5 | from ann_benchmarks.constants import INDEX_DIR 6 | from ann_benchmarks.algorithms.base import BaseANN 7 | from ann_benchmarks.algorithms.faiss import Faiss 8 | 9 | 10 | class FaissHNSW(Faiss): 11 | def __init__(self, metric, method_param): 12 | self._metric = metric 13 | self.method_param = method_param 14 | 15 | def fit(self, X): 16 | self.index = faiss.IndexHNSWFlat(len(X[0]), self.method_param["M"]) 17 | self.index.hnsw.efConstruction = self.method_param["efConstruction"] 18 | self.index.verbose = True 19 | 20 | if self._metric == 'angular': 21 | X = X / np.linalg.norm(X, axis=1)[:, np.newaxis] 22 | if X.dtype != np.float32: 23 | X = X.astype(np.float32) 24 | 25 | self.index.add(X) 26 | faiss.omp_set_num_threads(1) 27 | 28 | def set_query_arguments(self, ef): 29 | faiss.cvar.hnsw_stats.reset() 30 | self.index.hnsw.efSearch = ef 31 | 32 | def get_additional(self): 33 | return {"dist_comps": faiss.cvar.hnsw_stats.ndis} 34 | 35 | def __str__(self): 36 | return 'faiss (%s, ef: %d)' % (self.method_param, self.index.hnsw.efSearch) 37 | 38 | def freeIndex(self): 39 | del self.p 40 | -------------------------------------------------------------------------------- /ann_benchmarks/algorithms/hnswlib.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import os 3 | import hnswlib 4 | import numpy as np 5 | from ann_benchmarks.constants import INDEX_DIR 6 | from ann_benchmarks.algorithms.base import BaseANN 7 | 8 | 9 | class HnswLib(BaseANN): 10 | def __init__(self, metric, method_param): 11 | self.metric = {'angular': 'cosine', 'euclidean': 'l2'}[metric] 12 | self.method_param = method_param 13 | # print(self.method_param,save_index,query_param) 14 | # self.ef=query_param['ef'] 15 | self.name = 'hnswlib (%s)' % (self.method_param) 16 | 17 | def fit(self, X): 18 | # Only l2 is supported currently 19 | self.p = hnswlib.Index(space=self.metric, dim=len(X[0])) 20 | self.p.init_index(max_elements=len(X), 21 | ef_construction=self.method_param["efConstruction"], 22 | M=self.method_param["M"]) 23 | data_labels = np.arange(len(X)) 24 | self.p.add_items(np.asarray(X), data_labels) 25 | self.p.set_num_threads(1) 26 | 27 | def set_query_arguments(self, ef): 28 | self.p.set_ef(ef) 29 | 30 | def query(self, v, n): 31 | # print(np.expand_dims(v,axis=0).shape) 32 | # print(self.p.knn_query(np.expand_dims(v,axis=0), k = n)[0]) 33 | return self.p.knn_query(np.expand_dims(v, axis=0), k=n)[0][0] 34 | 35 | def freeIndex(self): 36 | del self.p 37 | -------------------------------------------------------------------------------- /ann_benchmarks/algorithms/kgraph.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import os 3 | import numpy 4 | import pykgraph 5 | from ann_benchmarks.constants import INDEX_DIR 6 | from ann_benchmarks.algorithms.base import BaseANN 7 
| 8 | 9 | class KGraph(BaseANN): 10 | def __init__(self, metric, index_params, save_index): 11 | metric = str(metric) 12 | self.name = 'KGraph(%s)' % (metric) 13 | self._metric = metric 14 | self._index_params = index_params 15 | self._save_index = save_index 16 | 17 | def fit(self, X): 18 | if X.dtype != numpy.float32: 19 | X = X.astype(numpy.float32) 20 | self._kgraph = pykgraph.KGraph(X, self._metric) 21 | path = os.path.join(INDEX_DIR, 'kgraph-index-%s' % self._metric) 22 | if os.path.exists(path): 23 | self._kgraph.load(path) 24 | else: 25 | # iterations=30, L=100, delta=0.002, recall=0.99, K=25) 26 | self._kgraph.build(**self._index_params) 27 | if not os.path.exists(INDEX_DIR): 28 | os.makedirs(INDEX_DIR) 29 | self._kgraph.save(path) 30 | 31 | def set_query_arguments(self, P): 32 | self._P = P 33 | 34 | def query(self, v, n): 35 | if v.dtype != numpy.float32: 36 | v = v.astype(numpy.float32) 37 | result = self._kgraph.search( 38 | numpy.array([v]), K=n, threads=1, P=self._P) 39 | return result[0] 40 | -------------------------------------------------------------------------------- /ann_benchmarks/algorithms/mrpt.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import numpy 3 | import sklearn.preprocessing 4 | import mrpt 5 | from ann_benchmarks.algorithms.base import BaseANN 6 | 7 | 8 | class MRPT(BaseANN): 9 | def __init__(self, metric, count): 10 | self._metric = metric 11 | self._k = count 12 | 13 | def fit(self, X): 14 | if X.dtype != numpy.float32: 15 | X = X.astype(numpy.float32) 16 | if self._metric == 'angular': 17 | X = sklearn.preprocessing.normalize(X, axis=1, norm='l2') 18 | 19 | self._index_autotuned = mrpt.MRPTIndex(X) 20 | self._index_autotuned.build_autotune_sample( 21 | target_recall=None, k=self._k, n_test=1000) 22 | 23 | def set_query_arguments(self, target_recall): 24 | self._target_recall = target_recall 25 | self._index = self._index_autotuned.subset(target_recall) 26 | self._par = self._index.parameters() 27 | 28 | def query(self, v, n): 29 | if v.dtype != numpy.float32: 30 | v = v.astype(numpy.float32) 31 | if self._metric == 'angular': 32 | v = sklearn.preprocessing.normalize( 33 | v.reshape(1, -1), axis=1, norm='l2').flatten() 34 | return self._index.ann(v) 35 | 36 | def __str__(self): 37 | str_template = ('MRPT(target recall=%.3f, trees=%d, depth=%d, vote ' 38 | 'threshold=%d, estimated recall=%.3f)') 39 | return str_template % (self._target_recall, self._par['n_trees'], 40 | self._par['depth'], self._par['votes'], 41 | self._par['estimated_recall']) 42 | -------------------------------------------------------------------------------- /install/Dockerfile.elasticsearch: -------------------------------------------------------------------------------- 1 | FROM ann-benchmarks 2 | 3 | WORKDIR /home/app 4 | 5 | # Install elasticsearch. 6 | ENV DEBIAN_FRONTEND noninteractive 7 | RUN apt install -y wget curl htop 8 | RUN wget --quiet https://artifacts.elastic.co/downloads/elasticsearch/elasticsearch-7.9.2-amd64.deb \ 9 | && dpkg -i elasticsearch-7.9.2-amd64.deb \ 10 | && rm elasticsearch-7.9.2-amd64.deb 11 | 12 | # Install python client. 13 | RUN python3 -m pip install --upgrade elasticsearch==7.9.1 14 | 15 | # Configure elasticsearch and JVM for single-node, single-core. 
16 | RUN echo '\ 17 | discovery.type: single-node\n\ 18 | network.host: 0.0.0.0\n\ 19 | node.master: true\n\ 20 | node.data: true\n\ 21 | node.processors: 1\n\ 22 | thread_pool.write.size: 1\n\ 23 | thread_pool.search.size: 1\n\ 24 | thread_pool.search.queue_size: 1\n\ 25 | path.data: /var/lib/elasticsearch\n\ 26 | path.logs: /var/log/elasticsearch\n\ 27 | ' > /etc/elasticsearch/elasticsearch.yml 28 | 29 | RUN echo '\ 30 | -Xms3G\n\ 31 | -Xmx3G\n\ 32 | -XX:+UseG1GC\n\ 33 | -XX:G1ReservePercent=25\n\ 34 | -XX:InitiatingHeapOccupancyPercent=30\n\ 35 | -XX:+HeapDumpOnOutOfMemoryError\n\ 36 | -XX:HeapDumpPath=/var/lib/elasticsearch\n\ 37 | -XX:ErrorFile=/var/log/elasticsearch/hs_err_pid%p.log\n\ 38 | -Xlog:gc*,gc+age=trace,safepoint:file=/var/log/elasticsearch/gc.log:utctime,pid,tags:filecount=32,filesize=64m' > /etc/elasticsearch/jvm.options 39 | 40 | # Make sure you can start the service. 41 | RUN service elasticsearch start && service elasticsearch stop 42 | 43 | # Custom entrypoint that also starts the Elasticsearch server. 44 | RUN echo 'service elasticsearch start && python3 -u run_algorithm.py "$@"' > entrypoint.sh 45 | ENTRYPOINT ["/bin/bash", "/home/app/entrypoint.sh"] 46 | -------------------------------------------------------------------------------- /ann_benchmarks/algorithms/puffinn.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import puffinn 3 | from ann_benchmarks.algorithms.base import BaseANN 4 | import numpy 5 | 6 | class Puffinn(BaseANN): 7 | def __init__(self, metric, space=10**6, hash_function="fht_crosspolytope", hash_source='pool', hash_args=None): 8 | if metric not in ['jaccard', 'angular']: 9 | raise NotImplementedError( 10 | "Puffinn doesn't support metric %s" % metric) 11 | self.metric = metric 12 | self.space = space 13 | self.hash_function = hash_function 14 | self.hash_source = hash_source 15 | self.hash_args = hash_args 16 | 17 | def fit(self, X): 18 | if self.hash_args: 19 | self.index = puffinn.Index(self.metric, len(X[0]), self.space,\ 20 | hash_function=self.hash_function, hash_source=self.hash_source,\ 21 | hash_args=self.hash_args) 22 | else: 23 | self.index = puffinn.Index(self.metric, len(X[0]), self.space,\ 24 | hash_function=self.hash_function, hash_source=self.hash_source) 25 | for i, x in enumerate(X): 26 | if self.metric == 'angular': 27 | x = x.tolist() 28 | self.index.insert(x) 29 | self.index.rebuild() 30 | 31 | def set_query_arguments(self, recall): 32 | self.recall = recall 33 | 34 | def query(self, v, n): 35 | if self.metric == 'angular': 36 | v = v.tolist() 37 | return self.index.search(v, n, self.recall) 38 | 39 | def __str__(self): 40 | return 'PUFFINN(space=%d, recall=%f, hf=%s, hashsource=%s)' % (self.space, self.recall, self.hash_function, self.hash_source) 41 | 42 | -------------------------------------------------------------------------------- /ann_benchmarks/distance.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from scipy.spatial.distance import pdist as scipy_pdist 3 | import itertools 4 | import numpy as np 5 | 6 | def pdist(a, b, metric): 7 | return scipy_pdist([a, b], metric=metric)[0] 8 | 9 | # Need own implementation of jaccard because scipy's 10 | # implementation is different 11 | 12 | def jaccard(a, b): 13 | if len(a) == 0 or len(b) == 0: 14 | return 0 15 | intersect = len(set(a) & set(b)) 16 | return intersect / (float)(len(a) + len(b) - intersect) 17 | 
18 | def transform_dense_to_sparse(X): 19 | """Converts the n * m dataset into a sparse format 20 | that only holds the non-zero entries (Jaccard).""" 21 | # get list of indices of non-zero elements 22 | indices = np.transpose(np.where(X)) 23 | keys = [] 24 | for _, js in itertools.groupby(indices, lambda ij: ij[0]): 25 | keys.append([j for _, j in js]) 26 | 27 | assert len(X) == len(keys) 28 | 29 | return keys 30 | 31 | metrics = { 32 | 'hamming': { 33 | 'distance': lambda a, b: pdist(a, b, "hamming"), 34 | 'distance_valid': lambda a: True 35 | }, 36 | # return 1 - jaccard similarity, because smaller distances are better. 37 | 'jaccard': { 38 | 'distance': lambda a, b: 1 - jaccard(a, b), 39 | 'distance_valid': lambda a: a < 1 - 1e-5 40 | }, 41 | 'euclidean': { 42 | 'distance': lambda a, b: pdist(a, b, "euclidean"), 43 | 'distance_valid': lambda a: True 44 | }, 45 | 'angular': { 46 | 'distance': lambda a, b: pdist(a, b, "cosine"), 47 | 'distance_valid': lambda a: True 48 | } 49 | } 50 | 51 | dataset_transform = { 52 | 'hamming': lambda X: X, 53 | 'euclidean': lambda X: X, 54 | 'angular': lambda X: X, 55 | 'jaccard' : lambda X: transform_dense_to_sparse(X) 56 | } 57 | -------------------------------------------------------------------------------- /ann_benchmarks/algorithms/nearpy.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import nearpy 3 | from nearpy.filters import NearestFilter 4 | import sklearn.preprocessing 5 | from ann_benchmarks.algorithms.base import BaseANN 6 | 7 | 8 | class NearPy(BaseANN): 9 | def __init__(self, metric, n_bits, hash_counts): 10 | self._n_bits = n_bits 11 | self._hash_counts = hash_counts 12 | self._metric = metric 13 | self._filter = NearestFilter(10) 14 | self.name = 'NearPy(n_bits=%d, hash_counts=%d)' % ( 15 | self._n_bits, self._hash_counts) 16 | 17 | def fit(self, X): 18 | hashes = [] 19 | 20 | for k in range(self._hash_counts): 21 | nearpy_rbp = nearpy.hashes.RandomBinaryProjections( 22 | 'rbp_%d' % k, self._n_bits) 23 | hashes.append(nearpy_rbp) 24 | 25 | if self._metric == 'euclidean': 26 | dist = nearpy.distances.EuclideanDistance() 27 | self._nearpy_engine = nearpy.Engine( 28 | X.shape[1], 29 | lshashes=hashes, 30 | distance=dist) 31 | else: # Default (angular) = Cosine distance 32 | self._nearpy_engine = nearpy.Engine( 33 | X.shape[1], 34 | lshashes=hashes, 35 | vector_filters=[self._filter]) 36 | 37 | if self._metric == 'angular': 38 | X = sklearn.preprocessing.normalize(X, axis=1, norm='l2') 39 | for i, x in enumerate(X): 40 | self._nearpy_engine.store_vector(x, i) 41 | 42 | def query(self, v, n): 43 | # XXX: This feels like an unpleasant hack, but it's not clear how to do 44 | # better without making changes to NearPy 45 | self._filter.N = n 46 | if self._metric == 'angular': 47 | v = sklearn.preprocessing.normalize([v], axis=1, norm='l2')[0] 48 | return [y for x, y, z in self._nearpy_engine.neighbours(v)] 49 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | sudo: required 2 | 3 | language: python 4 | python: 5 | - "3.6" 6 | 7 | services: 8 | - docker 9 | 10 | env: 11 | - LIBRARY=annoy DATASET=random-xs-20-angular 12 | - LIBRARY=dolphinn DATASET=random-xs-20-angular 13 | - LIBRARY=faiss DATASET=random-xs-20-angular 14 | - LIBRARY=flann DATASET=random-xs-20-angular 15 | - LIBRARY=kgraph DATASET=random-xs-20-angular 16 | - 
LIBRARY=milvus DATASET=random-xs-20-angular 17 | - LIBRARY=mrpt DATASET=random-xs-20-angular 18 | - LIBRARY=n2 DATASET=random-xs-20-angular 19 | - LIBRARY=nearpy DATASET=random-xs-20-angular 20 | - LIBRARY=ngt DATASET=random-xs-20-angular 21 | - LIBRARY=nmslib DATASET=random-xs-20-angular 22 | - LIBRARY=hnswlib DATASET=random-xs-20-angular 23 | - LIBRARY=puffinn DATASET=random-xs-20-angular 24 | - LIBRARY=pynndescent DATASET=random-xs-20-angular 25 | - LIBRARY=rpforest DATASET=random-xs-20-angular 26 | - LIBRARY=sklearn DATASET=random-xs-20-angular 27 | - LIBRARY=sptag DATASET=random-xs-20-angular 28 | - LIBRARY=mih DATASET=random-xs-16-hamming 29 | - LIBRARY=datasketch DATASET=random-s-jaccard 30 | - LIBRARY=scann DATASET=random-xs-20-angular 31 | - LIBRARY=elasticsearch DATASET=random-xs-20-angular 32 | - LIBRARY=elastiknn DATASET=random-xs-20-angular 33 | 34 | before_install: 35 | - pip install -r requirements.txt 36 | - python install.py 37 | 38 | script: 39 | - python run.py --docker-tag ann-benchmarks-${LIBRARY} --max-n-algorithms 5 --dataset $DATASET --run-disabled --timeout 300 40 | - python run.py --docker-tag ann-benchmarks-${LIBRARY} --max-n-algorithms 5 --dataset $DATASET --run-disabled --batch --timeout 300 41 | - sudo chmod -R 777 results/ 42 | - python plot.py --dataset $DATASET --output plot.png 43 | - python plot.py --dataset $DATASET --output plot-batch.png --batch 44 | - python -m unittest test/test-metrics.py 45 | - python -m unittest test/test-jaccard.py 46 | - python create_website.py --outputdir . --scatter --latex 47 | -------------------------------------------------------------------------------- /ann_benchmarks/algorithms/milvus.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import milvus 3 | import numpy 4 | import sklearn.preprocessing 5 | from ann_benchmarks.algorithms.base import BaseANN 6 | 7 | 8 | class Milvus(BaseANN): 9 | def __init__(self, metric, index_type, nlist): 10 | self._nlist = nlist 11 | self._nprobe = None 12 | self._metric = metric 13 | self._milvus = milvus.Milvus() 14 | self._milvus.connect(host='localhost', port='19530') 15 | self._table_name = 'test01' 16 | self._index_type = index_type 17 | 18 | def fit(self, X): 19 | if self._metric == 'angular': 20 | X = sklearn.preprocessing.normalize(X, axis=1, norm='l2') 21 | 22 | self._milvus.create_table({'table_name': self._table_name, 'dimension': X.shape[1]}) 23 | vector_ids = [id for id in range(len(X))] 24 | self._milvus.insert(table_name=self._table_name, records=X.tolist(), ids=vector_ids) 25 | index_type = getattr(milvus.IndexType, self._index_type) # a bit hacky but works 26 | self._milvus.create_index(self._table_name, {'index_type': index_type, 'nlist': self._nlist}) 27 | 28 | def set_query_arguments(self, nprobe): 29 | if nprobe > self._nlist: 30 | print('warning! 
nprobe > nlist') 31 | nprobe = self._nlist 32 | self._nprobe = nprobe 33 | 34 | def query(self, v, n): 35 | if self._metric == 'angular': 36 | v /= numpy.linalg.norm(v) 37 | v = v.tolist() 38 | status, results = self._milvus.search(table_name=self._table_name, query_records=[v], top_k=n, nprobe=self._nprobe) 39 | if not results: 40 | return [] # Seems to happen occasionally, not sure why 41 | result_ids = [result.id for result in results[0]] 42 | return result_ids 43 | 44 | def __str__(self): 45 | return 'Milvus(index_type=%s, nlist=%d, nprobe=%d)' % (self._index_type, self._nlist, self._nprobe) 46 | -------------------------------------------------------------------------------- /protocol/ext-query-parameters.md: -------------------------------------------------------------------------------- 1 | (This document describes an extension that front-ends aren't required to implement. Front-ends that don't implement this extension should reject attempts to set the `query-parameters` front-end configuration option.) 2 | 3 | Many algorithms expose parameters that can be changed to adjust their search strategies without requiring that training data be resubmitted. When the front-end configuration option `query-parameters` is set to `1`, a new command will be added to query mode allowing these query configuration parameters to be changed. 4 | 5 | (Front-ends that support other optional query modes, such as prepared or batch queries, should also add this command to those modes.) 6 | 7 | ## Commands 8 | 9 | ### Configuration mode 10 | 11 | #### `frontend query-parameters V` (three tokens) 12 | 13 | If `V` is `1`, then request that query mode expose the `query-params` command. If `V` is anything else, then withdraw this request. 14 | 15 | Responses: 16 | 17 | * `epbprtv0 ok` 18 | 19 | The availability of the `query-params` command has been changed accordingly. 20 | 21 | * `epbprtv0 fail` 22 | 23 | This command has had no effect on the availability of the `query-params` command. 24 | 25 | ### Training mode 26 | 27 | This extension makes no changes to training mode. 28 | 29 | ### Query mode 30 | 31 | When the `query-parameters` front-end configuration option has been set to `1`, this extension adds one new command to query mode: 32 | 33 | #### `query-params [VALUE0, ..., VALUEk] set` (two or more tokens) 34 | 35 | Change the values of the query parameters. 36 | 37 | (The final token `set` is required. It exists for the sake of compatibility with the `batch-queries` extension, which also uses variable-length commands but which requires that the last token specify a number.) 38 | 39 | Responses: 40 | 41 | * `epbprtv0 ok` 42 | 43 | The query parameters were changed to the given values. 44 | 45 | * `epbprtv0 fail` 46 | 47 | The query parameters were not changed to the given values, perhaps because one of them was invalid. 
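As an illustration, here is a hypothetical exchange with a front-end wrapping an IVF-style index whose single query parameter is a probe count (the parameter meaning and the values are invented for this example; the protocol itself only ever sees an ordered list of values followed by the token `set`):

`query-params 32 set`

`epbprtv0 ok`

`query-params not-a-number set`

`epbprtv0 fail`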
48 | -------------------------------------------------------------------------------- /ann_benchmarks/algorithms/faiss_gpu.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import sys 3 | # Assumes local installation of FAISS 4 | sys.path.append("faiss") # noqa 5 | import numpy 6 | import ctypes 7 | import faiss 8 | from ann_benchmarks.algorithms.base import BaseANN 9 | 10 | # Implementation based on 11 | # https://github.com/facebookresearch/faiss/blob/master/benchs/bench_gpu_sift1m.py # noqa 12 | 13 | 14 | class FaissGPU(BaseANN): 15 | def __init__(self, n_bits, n_probes): 16 | self.name = 'FaissGPU(n_bits={}, n_probes={})'.format( 17 | n_bits, n_probes) 18 | self._n_bits = n_bits 19 | self._n_probes = n_probes 20 | self._res = faiss.StandardGpuResources() 21 | self._index = None 22 | 23 | def fit(self, X): 24 | X = X.astype(numpy.float32) 25 | self._index = faiss.GpuIndexIVFFlat(self._res, len(X[0]), self._n_bits, 26 | faiss.METRIC_L2) 27 | # self._index = faiss.index_factory(len(X[0]), 28 | # "IVF%d,Flat" % self._n_bits) 29 | # co = faiss.GpuClonerOptions() 30 | # co.useFloat16 = True 31 | # self._index = faiss.index_cpu_to_gpu(self._res, 0, 32 | # self._index, co) 33 | self._index.train(X) 34 | self._index.add(X) 35 | self._index.setNumProbes(self._n_probes) 36 | 37 | def query(self, v, n): 38 | return [label for label, _ in self.query_with_distances(v, n)] 39 | 40 | def query_with_distances(self, v, n): 41 | v = v.astype(numpy.float32).reshape(1, -1) 42 | distances, labels = self._index.search(v, n) 43 | r = [] 44 | for l, d in zip(labels[0], distances[0]): 45 | if l != -1: 46 | r.append((l, d)) 47 | return r 48 | 49 | def batch_query(self, X, n): 50 | self.res = self._index.search(X.astype(numpy.float32), n) 51 | 52 | def get_batch_results(self): 53 | D, L = self.res 54 | res = [] 55 | for i in range(len(D)): 56 | r = [] 57 | for l, d in zip(L[i], D[i]): 58 | if l != -1: 59 | r.append(l) 60 | res.append(r) 61 | return res 62 | -------------------------------------------------------------------------------- /install/Dockerfile.elastiknn: -------------------------------------------------------------------------------- 1 | FROM ann-benchmarks 2 | 3 | WORKDIR /home/app 4 | 5 | # Install elasticsearch. 6 | ENV DEBIAN_FRONTEND noninteractive 7 | RUN apt install -y wget curl htop 8 | RUN wget --quiet https://artifacts.elastic.co/downloads/elasticsearch/elasticsearch-oss-7.9.2-amd64.deb \ 9 | && dpkg -i elasticsearch-oss-7.9.2-amd64.deb \ 10 | && rm elasticsearch-oss-7.9.2-amd64.deb 11 | 12 | # Install python client. 13 | RUN python3 -m pip install --upgrade elastiknn-client==0.1.0rc47 14 | 15 | # Install plugin. 16 | RUN /usr/share/elasticsearch/bin/elasticsearch-plugin install --batch \ 17 | https://github.com/alexklibisz/elastiknn/releases/download/0.1.0-PRE47/elastiknn-0.1.0-PRE47_es7.9.2.zip 18 | 19 | # Configure elasticsearch and JVM for single-node, single-core. 
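# (Rationale, inferred from the settings below rather than stated upstream:
# node.processors and the search/write thread pools are pinned to 1 so that
# measured query times reflect single-threaded search, and the heap is fixed
# at 3G so JVM garbage-collection behaviour is comparable across runs.)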
20 | RUN cp /etc/elasticsearch/jvm.options /etc/elasticsearch/jvm.options.bak 21 | RUN cp /etc/elasticsearch/elasticsearch.yml /etc/elasticsearch/elasticsearch.yml.bak 22 | 23 | RUN echo '\ 24 | discovery.type: single-node\n\ 25 | network.host: 0.0.0.0\n\ 26 | node.master: true\n\ 27 | node.data: true\n\ 28 | node.processors: 1\n\ 29 | thread_pool.write.size: 1\n\ 30 | thread_pool.search.size: 1\n\ 31 | thread_pool.search.queue_size: 1\n\ 32 | path.data: /var/lib/elasticsearch\n\ 33 | path.logs: /var/log/elasticsearch\n\ 34 | ' > /etc/elasticsearch/elasticsearch.yml 35 | 36 | RUN echo '\ 37 | -Xms3G\n\ 38 | -Xmx3G\n\ 39 | -XX:+UseG1GC\n\ 40 | -XX:G1ReservePercent=25\n\ 41 | -XX:InitiatingHeapOccupancyPercent=30\n\ 42 | -XX:+HeapDumpOnOutOfMemoryError\n\ 43 | -XX:HeapDumpPath=/var/lib/elasticsearch\n\ 44 | -XX:ErrorFile=/var/log/elasticsearch/hs_err_pid%p.log\n\ 45 | -Xlog:gc*,gc+age=trace,safepoint:file=/var/log/elasticsearch/gc.log:utctime,pid,tags:filecount=32,filesize=64m\n\ 46 | -Dcom.sun.management.jmxremote.ssl=false\n\ 47 | -Dcom.sun.management.jmxremote.authenticate=false\n\ 48 | -Dcom.sun.management.jmxremote.local.only=false\n\ 49 | -Dcom.sun.management.jmxremote.port=8097\n\ 50 | -Dcom.sun.management.jmxremote.rmi.port=8097\n\ 51 | -Djava.rmi.server.hostname=localhost' > /etc/elasticsearch/jvm.options 52 | 53 | # JMX port. Need to also map the port when running. 54 | EXPOSE 8097 55 | 56 | # Make sure you can start the service. 57 | RUN service elasticsearch start && service elasticsearch stop 58 | 59 | # Custom entrypoint that also starts the Elasticsearch server.\ 60 | RUN echo 'service elasticsearch start && python3 -u run_algorithm.py "$@"' > entrypoint.sh 61 | ENTRYPOINT ["/bin/bash", "/home/app/entrypoint.sh"] 62 | -------------------------------------------------------------------------------- /algosP.yaml: -------------------------------------------------------------------------------- 1 | float: 2 | any: 3 | bruteforce: 4 | docker-tag: ann-benchmarks-sklearn 5 | module: ann_benchmarks.algorithms.bruteforce 6 | constructor: BruteForce 7 | base-args: ["@metric"] 8 | run-groups: 9 | empty: 10 | args: [] 11 | bruteforce-blas: 12 | docker-tag: ann-benchmarks-sklearn 13 | module: ann_benchmarks.algorithms.bruteforce 14 | constructor: BruteForceBLAS 15 | base-args: ["@metric"] 16 | run-groups: 17 | empty: 18 | args: [] 19 | angular: 20 | pp-bruteforce-lo: 21 | module: ann_benchmarks.algorithms.subprocess 22 | docker-tag: ann-benchmarks-subprocess 23 | constructor: FloatSubprocess 24 | base-args: [["protocol/bf-runner"]] 25 | run-groups: 26 | jf-linear: 27 | args: {"point-type": "float", "distance": "angular"} 28 | pp-bruteforce-hi: 29 | module: ann_benchmarks.algorithms.subprocess 30 | docker-tag: ann-benchmarks-subprocess 31 | constructor: FloatSubprocessPrepared 32 | base-args: [["protocol/bf-runner"]] 33 | run-groups: 34 | jf-linear: 35 | args: {"point-type": "float", "distance": "angular"} 36 | pp-bruteforce-blas-lo: 37 | module: ann_benchmarks.algorithms.subprocess 38 | docker-tag: ann-benchmarks-subprocess 39 | constructor: FloatSubprocess 40 | base-args: [["protocol/bf-runner"]] 41 | run-groups: 42 | jf-linear: 43 | args: {"point-type": "float", "distance": "angular", "fast": 1} 44 | pp-bruteforce-blas-hi: 45 | module: ann_benchmarks.algorithms.subprocess 46 | docker-tag: ann-benchmarks-subprocess 47 | constructor: FloatSubprocessPrepared 48 | base-args: [["protocol/bf-runner"]] 49 | run-groups: 50 | jf-linear: 51 | args: {"point-type": "float", "distance": "angular", 
"fast": 1} 52 | pp-bruteforce-batch: 53 | module: ann_benchmarks.algorithms.subprocess 54 | docker-tag: ann-benchmarks-subprocess 55 | constructor: FloatSubprocessBatch 56 | base-args: [["protocol/bf-runner"]] 57 | run-groups: 58 | jf-linear: 59 | args: {"point-type": "float", "distance": "angular"} 60 | pp-bruteforce-blas-batch: 61 | module: ann_benchmarks.algorithms.subprocess 62 | docker-tag: ann-benchmarks-subprocess 63 | constructor: FloatSubprocessBatch 64 | base-args: [["protocol/bf-runner"]] 65 | run-groups: 66 | jf-linear: 67 | args: {"point-type": "float", "distance": "angular", "fast": 1} 68 | -------------------------------------------------------------------------------- /install.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import argparse 4 | import subprocess 5 | from multiprocessing import Pool 6 | from ann_benchmarks.main import positive_int 7 | 8 | 9 | def build(library, args): 10 | print('Building %s...' % library) 11 | if args is not None and len(args) != 0: 12 | q = " ".join(["--build-arg " + x.replace(" ", "\\ ") for x in args]) 13 | else: 14 | q = "" 15 | 16 | try: 17 | subprocess.check_call( 18 | 'docker build %s --rm -t ann-benchmarks-%s -f' 19 | ' install/Dockerfile.%s .' % (q, library, library), shell=True) 20 | return {library: 'success'} 21 | except subprocess.CalledProcessError: 22 | return {library: 'fail'} 23 | 24 | 25 | def build_multiprocess(args): 26 | return build(*args) 27 | 28 | 29 | if __name__ == "__main__": 30 | parser = argparse.ArgumentParser( 31 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 32 | parser.add_argument( 33 | "--proc", 34 | default=1, 35 | type=positive_int, 36 | help="the number of process to build docker images") 37 | parser.add_argument( 38 | '--algorithm', 39 | metavar='NAME', 40 | help='build only the named algorithm image', 41 | default=None) 42 | parser.add_argument( 43 | '--build-arg', 44 | help='pass given args to all docker builds', 45 | nargs="+") 46 | args = parser.parse_args() 47 | 48 | print('Building base image...') 49 | subprocess.check_call( 50 | 'docker build \ 51 | --rm -t ann-benchmarks -f install/Dockerfile .', shell=True) 52 | 53 | if args.algorithm: 54 | print('Building algorithm(%s) image...' % args.algorithm) 55 | build(args.algorithm, args.build_arg) 56 | elif os.getenv('LIBRARY'): 57 | print('Building algorithm(%s) image...' % os.getenv('LIBRARY')) 58 | build(os.getenv('LIBRARY'), args.build_arg) 59 | else: 60 | print('Building algorithm images... 
with (%d) processes' % args.proc) 61 | tags = [fn.split('.')[-1] for fn in os.listdir('install') if fn.startswith('Dockerfile.')] 62 | 63 | if args.proc == 1: 64 | install_status = [build(tag, args.build_arg) for tag in tags] 65 | else: 66 | pool = Pool(processes=args.proc) 67 | install_status = pool.map(build_multiprocess, [(tag, args.build_arg) for tag in tags]) 68 | pool.close() 69 | pool.join() 70 | 71 | print('\n\nInstall Status:\n' + '\n'.join(str(algo) for algo in install_status)) -------------------------------------------------------------------------------- /ann_benchmarks/algorithms/faiss.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import sys 3 | sys.path.append("install/lib-faiss") # noqa 4 | import numpy 5 | import sklearn.preprocessing 6 | import ctypes 7 | import faiss 8 | from ann_benchmarks.algorithms.base import BaseANN 9 | 10 | 11 | class Faiss(BaseANN): 12 | def query(self, v, n): 13 | if self._metric == 'angular': 14 | v /= numpy.linalg.norm(v) 15 | D, I = self.index.search(numpy.expand_dims( 16 | v, axis=0).astype(numpy.float32), n) 17 | return I[0] 18 | 19 | def batch_query(self, X, n): 20 | if self._metric == 'angular': 21 | X /= numpy.linalg.norm(X) 22 | self.res = self.index.search(X.astype(numpy.float32), n) 23 | 24 | def get_batch_results(self): 25 | D, L = self.res 26 | res = [] 27 | for i in range(len(D)): 28 | r = [] 29 | for l, d in zip(L[i], D[i]): 30 | if l != -1: 31 | r.append(l) 32 | res.append(r) 33 | return res 34 | 35 | 36 | class FaissLSH(Faiss): 37 | def __init__(self, metric, n_bits): 38 | self._n_bits = n_bits 39 | self.index = None 40 | self._metric = metric 41 | self.name = 'FaissLSH(n_bits={})'.format(self._n_bits) 42 | 43 | def fit(self, X): 44 | if X.dtype != numpy.float32: 45 | X = X.astype(numpy.float32) 46 | f = X.shape[1] 47 | self.index = faiss.IndexLSH(f, self._n_bits) 48 | self.index.train(X) 49 | self.index.add(X) 50 | 51 | 52 | class FaissIVF(Faiss): 53 | def __init__(self, metric, n_list): 54 | self._n_list = n_list 55 | self._metric = metric 56 | 57 | def fit(self, X): 58 | if self._metric == 'angular': 59 | X = sklearn.preprocessing.normalize(X, axis=1, norm='l2') 60 | 61 | if X.dtype != numpy.float32: 62 | X = X.astype(numpy.float32) 63 | 64 | self.quantizer = faiss.IndexFlatL2(X.shape[1]) 65 | index = faiss.IndexIVFFlat( 66 | self.quantizer, X.shape[1], self._n_list, faiss.METRIC_L2) 67 | index.train(X) 68 | index.add(X) 69 | self.index = index 70 | 71 | def set_query_arguments(self, n_probe): 72 | faiss.cvar.indexIVF_stats.reset() 73 | self._n_probe = n_probe 74 | self.index.nprobe = self._n_probe 75 | 76 | def get_additional(self): 77 | return {"dist_comps": faiss.cvar.indexIVF_stats.ndis + # noqa 78 | faiss.cvar.indexIVF_stats.nq * self._n_list} 79 | 80 | def __str__(self): 81 | return 'FaissIVF(n_list=%d, n_probe=%d)' % (self._n_list, 82 | self._n_probe) 83 | -------------------------------------------------------------------------------- /protocol/ext-batch-queries.md: -------------------------------------------------------------------------------- 1 | (This document describes an extension that front-ends aren't required to implement. Front-ends that don't implement this extension should reject attempts to set the `batch-queries` front-end configuration option.) 
2 | 3 | When the front-end configuration option `batch-queries` is set to `1`, after finishing training mode, the front-end will transition to batch query mode instead of query mode. In batch query mode, all queries are submitted at once, and the front-end will indicate when the queries have finished before any results are returned. 4 | 5 | ## Commands 6 | 7 | ### Configuration mode 8 | 9 | #### `frontend batch-queries V` (three tokens) 10 | 11 | If `V` is `1`, then request that the front-end transition into batch query mode, and not query mode, after training mode has finished. If `V` is anything else, then request that it transition into query mode as usual. 12 | 13 | Responses: 14 | 15 | * `epbprtv0 ok` 16 | 17 | The front-end will transition into the requested query mode after the training mode has finished. 18 | 19 | * `epbprtv0 fail` 20 | 21 | This command has had no effect on the query mode transition. 22 | 23 | ### Training mode 24 | 25 | This extension changes the behaviour of one command in training mode: 26 | 27 | #### *empty line* (zero tokens) 28 | 29 | Finish training mode and enter batch query mode. 30 | 31 | Responses: 32 | 33 | * `epbprtv0 ok COUNT1 [fail COUNT2]` 34 | 35 | `COUNT1` (potentially zero) entries were successfully interpreted and added to the data structure. (`COUNT2` entries couldn't be interpreted or couldn't be added for other reasons.): 36 | 37 | ### Batch query mode 38 | 39 | In batch query mode, front-ends should respond to three different kinds of command: 40 | 41 | #### `ENTRY0 [..., ENTRYk] N` (two or more tokens) 42 | 43 | Prepare to run a query to find at most `N` (greater than or equal to 1) close matches for each of the `k` query points from `ENTRY0` to `ENTRYk`. 44 | 45 | Responses: 46 | 47 | * `epbprtv0 ok` 48 | 49 | Preparation is complete, and the `query` command can now be used. 50 | 51 | * `epbprtv0 fail` 52 | 53 | Preparation has failed, and the `query` command should not be used. This may occur if one of the `k` query points could not be parsed. 54 | 55 | #### `query` (one token) 56 | 57 | Run the last prepared query. 58 | 59 | Responses: 60 | 61 | * `epbprtv0 ok` 62 | 63 | The query was executed successfully. `k` sets of results will appear after this line, each of them of the same form as in the normal query mode. 64 | 65 | * `epbprtv0 fail` 66 | 67 | No query has been prepared. 68 | 69 | #### *empty line* (zero tokens) 70 | 71 | Finish prepared query mode and terminate the front-end. 72 | 73 | Responses: 74 | 75 | * `epbprtv0 ok` 76 | 77 | The front-end has terminated. 78 | -------------------------------------------------------------------------------- /protocol/ext-add-query-metric.md: -------------------------------------------------------------------------------- 1 | (This document describes an extension that front-ends aren't required to implement. In fact, no front-end is *known* to implement it; this document serves as an example of how to extend the protocol. Front-ends that don't implement this extension should reject attempts to set the `add-query-metric` configuration option.) 2 | 3 | When the configuration option `add-query-metric` is set to a value other than `all`, if that value identifies a query metric known to the front-end, then the value for this metric will be appended to each query response. This option may be set several times; each one will (try to) add another query metric. 4 | 5 | Setting this option to the value `all` will cause *all* metrics known to the front-end to be included. 
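For example, a client that wants the hypothetical `buckets_searched` metric (the same illustrative name used in the example response at the end of this document), plus anything else the front-end happens to support, could issue the following two commands during configuration mode:

`add-query-metric buckets_searched`

`epbprtv0 ok`

`add-query-metric all`

`epbprtv0 ok`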
6 | 7 | ## Commands 8 | 9 | ### Configuration mode 10 | 11 | #### `add-query-metric METRIC` (two tokens) 12 | 13 | Request that query responses also include the value of the query metric `METRIC`, if that's recognised by the front-end. 14 | 15 | Responses: 16 | 17 | * `epbprtv0 ok` 18 | 19 | The metric `METRIC` was recognised, and query responses will include a value for it. 20 | 21 | * `epbprtv0 fail` 22 | 23 | The metric `METRIC` was not recognised; query responses will not include a value for it. 24 | 25 | #### `add-query-metric all` (two tokens) 26 | 27 | Request that query responses also include the values of all query metrics recognised by the front-end. 28 | 29 | Responses: 30 | 31 | * `epbprtv0 ok` 32 | 33 | Query responses will include the values of all metrics known to the front-end. (This may not actually change the output; the front-end could, in principle, support this extension but not recognise any query metrics.) 34 | 35 | * `epbprtv0 fail` 36 | 37 | Front-ends may choose to emit this response if they do not recognise *any* query metrics, but they may also emit `epbprtv0 ok` in these circumstances (to indicate that all zero metrics will be included in the output). 38 | 39 | ### Query mode 40 | 41 | #### `ENTRY N` (two tokens) 42 | 43 | This extension changes the behaviour of one response: 44 | 45 | * `epbprtv0 ok R [NAME0 VALUE0 ...]` 46 | 47 | `R` (greater than zero and less than or equal to `N`) close matches were found. Each of the next `R` lines, when tokenised, will consist of the token `epbprtv0` followed by a token specifying the index of a close match. (The first line should identify the *closest* close match, and the `R`-th should identify the furthest away.) 48 | 49 | If additional query metrics were specified and recognised during configuration mode, then their names and values will be provided as a number of pairs of tokens after `R`. 
For example, a response including the hypothetical `buckets_searched` and `candidates_checked` metrics might look like this: 50 | 51 | `epbprtv0 ok 10 buckets_searched 8 candidates_checked 507` 52 | -------------------------------------------------------------------------------- /ann_benchmarks/results.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import h5py 4 | import json 5 | import os 6 | import re 7 | import traceback 8 | 9 | 10 | def get_algorithm_name(name, batch_mode): 11 | if batch_mode: 12 | return name + "-batch" 13 | return name 14 | 15 | 16 | def is_batch(name): 17 | return "-batch" in name 18 | 19 | 20 | def get_result_filename(dataset=None, count=None, definition=None, 21 | query_arguments=None, batch_mode=False): 22 | d = ['results'] 23 | if dataset: 24 | d.append(dataset) 25 | if count: 26 | d.append(str(count)) 27 | if definition: 28 | d.append(get_algorithm_name(definition.algorithm, batch_mode)) 29 | data = definition.arguments + query_arguments 30 | d.append(re.sub(r'\W+', '_', json.dumps(data, 31 | sort_keys=True)).strip('_')) 32 | return os.path.join(*d) 33 | 34 | 35 | def store_results(dataset, count, definition, query_arguments, attrs, results, 36 | batch): 37 | fn = get_result_filename( 38 | dataset, count, definition, query_arguments, batch) 39 | head, tail = os.path.split(fn) 40 | if not os.path.isdir(head): 41 | os.makedirs(head) 42 | f = h5py.File(fn, 'w') 43 | for k, v in attrs.items(): 44 | f.attrs[k] = v 45 | times = f.create_dataset('times', (len(results),), 'f') 46 | neighbors = f.create_dataset('neighbors', (len(results), count), 'i') 47 | distances = f.create_dataset('distances', (len(results), count), 'f') 48 | for i, (time, ds) in enumerate(results): 49 | times[i] = time 50 | neighbors[i] = [n for n, d in ds] + [-1] * (count - len(ds)) 51 | distances[i] = [d for n, d in ds] + [float('inf')] * (count - len(ds)) 52 | f.close() 53 | 54 | 55 | def load_all_results(dataset=None, count=None, split_batched=False, 56 | batch_mode=False): 57 | for root, _, files in os.walk(get_result_filename(dataset, count)): 58 | for fn in files: 59 | try: 60 | if split_batched and batch_mode != is_batch(root): 61 | continue 62 | f = h5py.File(os.path.join(root, fn), 'r+') 63 | properties = dict(f.attrs) 64 | # TODO Fix this properly. 
Sometimes the hdf5 file returns bytes 65 | # This converts these bytes to strings before we work with them 66 | for k in properties.keys(): 67 | try: 68 | properties[k] = properties[k].decode() 69 | except: 70 | pass 71 | yield properties, f 72 | f.close() 73 | except: 74 | print('Was unable to read', fn) 75 | traceback.print_exc() 76 | 77 | 78 | def get_unique_algorithms(): 79 | algorithms = set() 80 | for properties, _ in load_all_results(): 81 | algorithms.add(properties['algo']) 82 | return algorithms 83 | -------------------------------------------------------------------------------- /ann_benchmarks/algorithms/nmslib.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import os 3 | import nmslib 4 | from ann_benchmarks.constants import INDEX_DIR 5 | from ann_benchmarks.algorithms.base import BaseANN 6 | 7 | 8 | class NmslibReuseIndex(BaseANN): 9 | @staticmethod 10 | def encode(d): 11 | return ["%s=%s" % (a, b) for (a, b) in d.items()] 12 | 13 | def __init__(self, metric, method_name, index_param, query_param): 14 | self._nmslib_metric = { 15 | 'angular': 'cosinesimil', 'euclidean': 'l2'}[metric] 16 | self._method_name = method_name 17 | self._save_index = False 18 | self._index_param = NmslibReuseIndex.encode(index_param) 19 | if query_param is not False: 20 | self._query_param = NmslibReuseIndex.encode(query_param) 21 | self.name = ('Nmslib(method_name={}, index_param={}, ' 22 | 'query_param={})'.format(self._method_name, 23 | self._index_param, 24 | self._query_param)) 25 | else: 26 | self._query_param = None 27 | self.name = 'Nmslib(method_name=%s, index_param=%s)' % ( 28 | self._method_name, self._index_param) 29 | 30 | self._index_name = os.path.join(INDEX_DIR, "nmslib_%s_%s_%s" % ( 31 | self._method_name, metric, '_'.join(self._index_param))) 32 | 33 | d = os.path.dirname(self._index_name) 34 | if not os.path.exists(d): 35 | os.makedirs(d) 36 | 37 | def fit(self, X): 38 | if self._method_name == 'vptree': 39 | # To avoid this issue: terminate called after throwing an instance 40 | # of 'std::runtime_error' 41 | # what(): The data size is too small or the bucket size is too 42 | # big. 
Select the parameters so that <total # of records> is NOT 43 | # less than <bucket size> * 1000 44 | # Aborted (core dumped) 45 | self._index_param.append('bucketSize=%d' % 46 | min(int(X.shape[0] * 0.0005), 1000)) 47 | 48 | self._index = nmslib.init( 49 | space=self._nmslib_metric, method=self._method_name) 50 | self._index.addDataPointBatch(X) 51 | 52 | if os.path.exists(self._index_name): 53 | print('Loading index from file') 54 | self._index.loadIndex(self._index_name) 55 | else: 56 | self._index.createIndex(self._index_param) 57 | if self._save_index: 58 | self._index.saveIndex(self._index_name) 59 | if self._query_param is not None: 60 | self._index.setQueryTimeParams(self._query_param) 61 | 62 | def set_query_arguments(self, ef): 63 | if self._method_name == 'hnsw' or self._method_name == 'sw-graph': 64 | self._index.setQueryTimeParams(["efSearch=%s" % (ef)]) 65 | 66 | def query(self, v, n): 67 | ids, distances = self._index.knnQuery(v, n) 68 | return ids 69 | 70 | def batch_query(self, X, n): 71 | self.res = self._index.knnQueryBatch(X, n) 72 | 73 | def get_batch_results(self): 74 | return [x for x, _ in self.res] 75 | -------------------------------------------------------------------------------- /ann_benchmarks/algorithms/panng_ngt.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import sys 3 | import os 4 | import ngtpy 5 | import numpy as np 6 | import subprocess 7 | import time 8 | from ann_benchmarks.algorithms.base import BaseANN 9 | from ann_benchmarks.constants import INDEX_DIR 10 | 11 | 12 | class PANNG(BaseANN): 13 | def __init__(self, metric, object_type, param): 14 | metrics = {'euclidean': 'L2', 'angular': 'Cosine'} 15 | self._edge_size = int(param['edge']) 16 | self._pathadj_size = int(param['pathadj']) 17 | self._edge_size_for_search = int(param['searchedge']) 18 | self._metric = metrics[metric] 19 | self._object_type = object_type 20 | print('PANNG: edge_size=' + str(self._edge_size)) 21 | print('PANNG: pathadj_size=' + str(self._pathadj_size)) 22 | print('PANNG: edge_size_for_search=' + str(self._edge_size_for_search)) 23 | print('PANNG: metric=' + metric) 24 | print('PANNG: object_type=' + object_type) 25 | 26 | def fit(self, X): 27 | print('PANNG: start indexing...') 28 | dim = len(X[0]) 29 | print('PANNG: # of data=' + str(len(X))) 30 | print('PANNG: Dimensionality=' + str(dim)) 31 | index_dir = 'indexes' 32 | if not os.path.exists(index_dir): 33 | os.makedirs(index_dir) 34 | index = os.path.join( 35 | index_dir, 36 | 'PANNG-' + str(self._edge_size) + '-' + str(self._pathadj_size)) 37 | print(index) 38 | if os.path.exists(index): 39 | print('PANNG: index already exists!
' + str(index)) 40 | else: 41 | t0 = time.time() 42 | ngtpy.create(path=index, dimension=dim, 43 | edge_size_for_creation=self._edge_size, 44 | distance_type=self._metric, 45 | object_type=self._object_type) 46 | idx = ngtpy.Index(path=index) 47 | idx.batch_insert(X, num_threads=24, debug=False) 48 | idx.save() 49 | idx.close() 50 | if self._pathadj_size > 0: 51 | print('PANNG: path adjustment') 52 | args = ['ngt', 'prune', '-s ' + str(self._pathadj_size), 53 | index] 54 | subprocess.call(args) 55 | indexingtime = time.time() - t0 56 | print('PANNG: indexing, adjustment and saving time(sec)={}' 57 | .format(indexingtime)) 58 | t0 = time.time() 59 | self.index = ngtpy.Index(path=index, read_only=True) 60 | opentime = time.time() - t0 61 | print('PANNG: open time(sec)=' + str(opentime)) 62 | 63 | def set_query_arguments(self, epsilon): 64 | print("PANNG: epsilon=" + str(epsilon)) 65 | self._epsilon = epsilon - 1.0 66 | self.name = 'PANNG-NGT(%d, %d, %d, %1.3f)' % ( 67 | self._edge_size, 68 | self._pathadj_size, 69 | self._edge_size_for_search, 70 | self._epsilon + 1.0) 71 | 72 | def query(self, v, n): 73 | results = self.index.search( 74 | v, n, self._epsilon, self._edge_size_for_search, 75 | with_distance=False) 76 | return results 77 | 78 | def freeIndex(self): 79 | print('PANNG: free') 80 | -------------------------------------------------------------------------------- /protocol/ext-prepared-queries.md: -------------------------------------------------------------------------------- 1 | (This document describes an extension that front-ends aren't required to implement. Front-ends that don't implement this extension should reject attempts to set the `prepared-queries` front-end configuration option.) 2 | 3 | When the front-end configuration option `prepared-queries` is set to `1`, after finishing training mode, the front-end will transition to prepared query mode instead of query mode. In prepared query mode, parsing a query point -- a potentially expensive operation -- and actually running a query are two different commands; this makes the query timings more representative of the underlying algorithm's behaviour without the overhead of this protocol. 4 | 5 | ## Commands 6 | 7 | ### Configuration mode 8 | 9 | #### `frontend prepared-queries V` (three tokens) 10 | 11 | If `V` is `1`, then request that the front-end transition into prepared query mode, and not query mode, after training mode has finished. If `V` is anything else, then request that it transition into query mode as usual. 12 | 13 | Responses: 14 | 15 | * `epbprtv0 ok` 16 | 17 | The front-end will transition into the requested query mode after the training mode has finished. 18 | 19 | * `epbprtv0 fail` 20 | 21 | This command has had no effect on the query mode transition. 22 | 23 | ### Training mode 24 | 25 | This extension changes the behaviour of one command in training mode: 26 | 27 | #### *empty line* (zero tokens) 28 | 29 | Finish training mode and enter prepared query mode. 30 | 31 | Responses: 32 | 33 | * `epbprtv0 ok COUNT1 [fail COUNT2]` 34 | 35 | `COUNT1` (potentially zero) entries were successfully interpreted and added to the data structure. (`COUNT2` entries couldn't be interpreted or couldn't be added for other reasons.): 36 | 37 | ### Prepared query mode 38 | 39 | In prepared query mode, front-ends should respond to three different kinds of command: 40 | 41 | #### `ENTRY N` (two tokens) 42 | 43 | Prepare to run a query to find at most `N` (greater than or equal to 1) close matches for `ENTRY`. 
44 | 45 | Responses: 46 | 47 | * `epbprtv0 ok prepared true` 48 | 49 | Preparation is complete, the `query` command can now be used, and the underlying library wrapper has special support for prepared queries. 50 | 51 | * `epbprtv0 ok prepared false` 52 | 53 | The `query` command can now be used, but the underlying library wrapper doesn't have support for prepared queries, so the `query` command will perform the parsing of `ENTRY` as it would in normal query mode. 54 | 55 | #### `query` (one token) 56 | 57 | Run the last prepared query. 58 | 59 | Responses: 60 | 61 | * `epbprtv0 ok R` 62 | 63 | `R` (greater than zero and less than or equal to the value of `N` that was specified when the query was prepared) close matches were found. The next `R` lines, when tokenised, will consist of the token `epbprtv0` followed by a token specifying the index of a close match. (The first line should identify the *closest* close match, and the `R`-th should identify the furthest away.) 64 | 65 | * `epbprtv0 fail` 66 | 67 | Either no close matches were found, or no query has been prepared. 68 | 69 | #### *empty line* (zero tokens) 70 | 71 | Finish prepared query mode and terminate the front-end. 72 | 73 | Responses: 74 | 75 | * `epbprtv0 ok` 76 | 77 | The front-end has terminated. 78 | -------------------------------------------------------------------------------- /templates/general.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | {{ title }} 9 | 10 | 11 | 12 | 13 | 14 | 15 | 18 | 19 | 20 | 24 | 25 | 26 | 27 | 48 | 49 | {% block content %} {% endblock %} 50 | 51 |
52-56 | [markup lost in extraction; surviving footer text:] Contact: ANN-Benchmarks has been developed by Martin Aumueller (maau@itu.dk), Erik Bernhardsson (mail@erikbern.com), and Alec Faitfull (alef@itu.dk). Please use Github to submit your implementation or improvements.
57 | 58 | 59 | -------------------------------------------------------------------------------- /test/test-metrics.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from ann_benchmarks.plotting.metrics import ( 3 | knn, queries_per_second, index_size, build_time, candidates, 4 | epsilon, rel) 5 | 6 | 7 | class DummyMetric(): 8 | 9 | def __init__(self): 10 | self.attrs = {} 11 | self.d = {} 12 | 13 | def __getitem__(self, key): 14 | return self.d.get(key, None) 15 | 16 | def __setitem__(self, key, value): 17 | self.d[key] = value 18 | 19 | def __contains__(self, key): 20 | return key in self.d 21 | 22 | def create_group(self, name): 23 | self.d[name] = DummyMetric() 24 | return self.d[name] 25 | 26 | 27 | class TestMetrics(unittest.TestCase): 28 | 29 | def setUp(self): 30 | pass 31 | 32 | def test_recall(self): 33 | exact_queries = [[0.1, 0.25]] 34 | run1 = [[]] 35 | run2 = [[0.2, 0.3]] 36 | run3 = [[0.2]] 37 | run4 = [[0.2, 0.25]] 38 | 39 | self.assertAlmostEqual( 40 | knn(exact_queries, run1, 2, DummyMetric()).attrs['mean'], 0.0) 41 | self.assertAlmostEqual( 42 | knn(exact_queries, run2, 2, DummyMetric()).attrs['mean'], 0.5) 43 | self.assertAlmostEqual( 44 | knn(exact_queries, run3, 2, DummyMetric()).attrs['mean'], 0.5) 45 | self.assertAlmostEqual( 46 | knn(exact_queries, run4, 2, DummyMetric()).attrs['mean'], 1.0) 47 | 48 | def test_epsilon_recall(self): 49 | exact_queries = [[0.05, 0.08, 0.24, 0.3]] 50 | run1 = [[]] 51 | run2 = [[0.1, 0.2, 0.55, 0.7]] 52 | 53 | self.assertAlmostEqual( 54 | epsilon(exact_queries, run1, 4, DummyMetric(), 1).attrs['mean'], 55 | 0.0) 56 | 57 | self.assertAlmostEqual( 58 | epsilon(exact_queries, run2, 4, 59 | DummyMetric(), 0.0001).attrs['mean'], 60 | 0.5) 61 | # distance can be off by factor (1 + 1) * 0.3 = 0.6 => recall .75 62 | self.assertAlmostEqual( 63 | epsilon(exact_queries, run2, 4, DummyMetric(), 1).attrs['mean'], 64 | 0.75) 65 | # distance can be off by factor (1 + 2) * 0.3 = 0.9 => recall 1 66 | self.assertAlmostEqual( 67 | epsilon(exact_queries, run2, 4, DummyMetric(), 2).attrs['mean'], 68 | 1.0) 69 | 70 | def test_relative(self): 71 | exact_queries = [[0.1, 0.2, 0.25, 0.3]] 72 | run1 = [] 73 | run2 = [[0.1, 0.2, 0.25, 0.3]] 74 | run3 = [[0.1, 0.2, 0.55, 0.9]] 75 | 76 | self.assertAlmostEqual( 77 | rel(exact_queries, run1, DummyMetric()), float("inf")) 78 | self.assertAlmostEqual(rel(exact_queries, run2, DummyMetric()), 1) 79 | # total distance exact: 0.85, total distance run3: 1.75 80 | self.assertAlmostEqual(rel(exact_queries, run3, DummyMetric()), 81 | 1.75 / 0.85) 82 | 83 | def test_queries_per_second(self): 84 | self.assertAlmostEqual( 85 | queries_per_second([], {"best_search_time": 0.01}), 86 | 100) 87 | 88 | def test_index_size(self): 89 | self.assertEqual(index_size([], {"index_size": 100}), 100) 90 | 91 | def test_build_time(self): 92 | self.assertEqual(build_time([], {"build_time": 100}), 100) 93 | 94 | def test_candidates(self): 95 | self.assertEqual(candidates([], {"candidates": 10}), 10) 96 | 97 | 98 | if __name__ == '__main__': 99 | unittest.main() 100 | -------------------------------------------------------------------------------- /templates/summary.html: -------------------------------------------------------------------------------- 1 | {% extends "general.html" %} 2 | {% block content %} 3 |
4-56 | [markup lost in extraction; surviving template text and Jinja directives:]
Info
ANN-Benchmarks is a benchmarking environment for approximate nearest neighbor search algorithms. This website contains the current benchmarking results. Please visit http://github.com/erikbern/ann-benchmarks/ to get an overview of the evaluated data sets and algorithms. Make a pull request on Github to add your own code or improvements to the benchmarking system.
Benchmarking Results
Results are split by distance measure and dataset. At the bottom, you can find an overview of each algorithm's performance on all datasets. Each dataset is annotated with (k = ...), the number of nearest neighbors an algorithm was supposed to return. The plot shown depicts Recall (the fraction of true nearest neighbors found, on average over all queries) against Queries per second. Clicking on a plot reveals detailed interactive plots, including approximate recall, index size, and build time.
{% for type in ['non-batch', 'batch'] %} {% if len(dataset_with_distances[type]) > 0 %} {% if type == 'batch' %}
Benchmarks for Batched Queries
{% else %}
Benchmarks for Single Queries
{% endif %}
Results by Dataset
{% for distance_data in dataset_with_distances[type] %}
Distance: {{ distance_data.name }}
{% for entry in distance_data.entries %}
{{entry.desc}}
{% endfor %} {% endfor %}
Results by Algorithm
{% for algo in algorithms[type].keys()%}
{{algo}}
57 | {% endfor %} 58 | {% endif %} 59 | {% endfor %} 60 | {% endblock %} 61 | -------------------------------------------------------------------------------- /ann_benchmarks/algorithms/onng_ngt.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import sys 3 | import os 4 | import ngtpy 5 | import numpy as np 6 | import subprocess 7 | import time 8 | from ann_benchmarks.algorithms.base import BaseANN 9 | from ann_benchmarks.constants import INDEX_DIR 10 | 11 | 12 | class ONNG(BaseANN): 13 | def __init__(self, metric, object_type, epsilon, param): 14 | metrics = {'euclidean': '2', 'angular': 'C'} 15 | self._edge_size = int(param['edge']) 16 | self._outdegree = int(param['outdegree']) 17 | self._indegree = int(param['indegree']) 18 | self._metric = metrics[metric] 19 | self._object_type = object_type 20 | self._edge_size_for_search = int(param['search_edge']) if 'search_edge' in param.keys() else -2 21 | self._tree_disabled = (param['tree'] == False) if 'tree' in param.keys() else False 22 | self._build_time_limit = 4 23 | self._epsilon = epsilon 24 | print('ONNG: edge_size=' + str(self._edge_size)) 25 | print('ONNG: outdegree=' + str(self._outdegree)) 26 | print('ONNG: indegree=' + str(self._indegree)) 27 | print('ONNG: edge_size_for_search=' + str(self._edge_size_for_search)) 28 | print('ONNG: epsilon=' + str(self._epsilon)) 29 | print('ONNG: metric=' + metric) 30 | print('ONNG: object_type=' + object_type) 31 | 32 | def fit(self, X): 33 | print('ONNG: start indexing...') 34 | dim = len(X[0]) 35 | print('ONNG: # of data=' + str(len(X))) 36 | print('ONNG: dimensionality=' + str(dim)) 37 | index_dir = 'indexes' 38 | if not os.path.exists(index_dir): 39 | os.makedirs(index_dir) 40 | index = os.path.join( 41 | index_dir, 42 | 'ONNG-{}-{}-{}'.format(self._edge_size, self._outdegree, 43 | self._indegree)) 44 | anngIndex = os.path.join(index_dir, 'ANNG-' + str(self._edge_size)) 45 | print('ONNG: index=' + index) 46 | if (not os.path.exists(index)) and (not os.path.exists(anngIndex)): 47 | print('ONNG: create ANNG') 48 | t = time.time() 49 | args = ['ngt', 'create', '-it', '-p8', '-b500', '-ga', '-of', 50 | '-D' + self._metric, '-d' + str(dim), 51 | '-E' + str(self._edge_size), '-S0', 52 | '-e' + str(self._epsilon), '-P0', '-B30', 53 | '-T' + str(self._build_time_limit), anngIndex] 54 | subprocess.call(args) 55 | idx = ngtpy.Index(path=anngIndex) 56 | idx.batch_insert(X, num_threads=24, debug=False) 57 | idx.save() 58 | idx.close() 59 | print('ONNG: ANNG construction time(sec)=' + str(time.time() - t)) 60 | if not os.path.exists(index): 61 | print('ONNG: degree adjustment') 62 | t = time.time() 63 | args = ['ngt', 'reconstruct-graph', '-mS', 64 | '-o ' + str(self._outdegree), 65 | '-i ' + str(self._indegree), anngIndex, index] 66 | subprocess.call(args) 67 | print('ONNG: degree adjustment time(sec)=' + str(time.time() - t)) 68 | if os.path.exists(index): 69 | print('ONNG: index already exists! 
' + str(index)) 70 | t = time.time() 71 | print(self._tree_disabled) 72 | #self.index = ngtpy.Index(index, read_only=True, tree_disabled=self._tree_disabled) 73 | self.index = ngtpy.Index(index, read_only=True) 74 | self.indexName = index 75 | print('ONNG: open time(sec)=' + str(time.time() - t)) 76 | else: 77 | print('ONNG: something wrong.') 78 | print('ONNG: end of fit') 79 | 80 | def set_query_arguments(self, epsilon): 81 | print("ONNG: epsilon=" + str(epsilon)) 82 | self._epsilon = epsilon - 1.0 83 | self.name = 'ONNG-NGT(%s, %s, %s, %s, %1.3f)' % ( 84 | self._edge_size, self._outdegree, 85 | self._indegree, self._edge_size_for_search, 86 | self._epsilon + 1.0) 87 | 88 | def query(self, v, n): 89 | results = self.index.search( 90 | v, n, self._epsilon, self._edge_size_for_search, 91 | with_distance=False) 92 | return results 93 | 94 | def freeIndex(self): 95 | print('ONNG: free') 96 | -------------------------------------------------------------------------------- /ann_benchmarks/algorithms/elasticsearch.py: -------------------------------------------------------------------------------- 1 | """ 2 | ann-benchmarks interfaces for Elasticsearch. 3 | Note that this requires X-Pack, which is not included in the OSS version of Elasticsearch. 4 | """ 5 | import logging 6 | from time import sleep 7 | from urllib.error import URLError 8 | from urllib.request import Request, urlopen 9 | 10 | from elasticsearch import Elasticsearch 11 | from elasticsearch.helpers import bulk 12 | 13 | from ann_benchmarks.algorithms.base import BaseANN 14 | 15 | # Configure the elasticsearch logger. 16 | # By default, it writes an INFO statement for every request. 17 | logging.getLogger("elasticsearch").setLevel(logging.WARN) 18 | 19 | # Uncomment these lines if you want to see timing for every HTTP request and its duration. 20 | # logging.basicConfig(level=logging.INFO) 21 | # logging.getLogger("elasticsearch").setLevel(logging.INFO) 22 | 23 | def es_wait(): 24 | print("Waiting for elasticsearch health endpoint...") 25 | req = Request("http://localhost:9200/_cluster/health?wait_for_status=yellow&timeout=1s") 26 | for i in range(30): 27 | try: 28 | res = urlopen(req) 29 | if res.getcode() == 200: 30 | print("Elasticsearch is ready") 31 | return 32 | except URLError: 33 | pass 34 | sleep(1) 35 | raise RuntimeError("Failed to connect to local elasticsearch") 36 | 37 | 38 | class ElasticsearchScriptScoreQuery(BaseANN): 39 | """ 40 | KNN using the Elasticsearch dense_vector datatype and script score functions. 
41 | - Dense vector field type: https://www.elastic.co/guide/en/elasticsearch/reference/master/dense-vector.html 42 | - Dense vector queries: https://www.elastic.co/guide/en/elasticsearch/reference/master/query-dsl-script-score-query.html 43 | """ 44 | 45 | def __init__(self, metric: str, dimension: int): 46 | self.name = f"elasticsearch-script-score-query_metric={metric}_dimension={dimension}" 47 | self.metric = metric 48 | self.dimension = dimension 49 | self.index = f"es-ssq-{metric}-{dimension}" 50 | self.es = Elasticsearch(["http://localhost:9200"]) 51 | self.batch_res = [] 52 | if self.metric == "euclidean": 53 | self.script = "1 / (1 + l2norm(params.query_vec, \"vec\"))" 54 | elif self.metric == "angular": 55 | self.script = "1.0 + cosineSimilarity(params.query_vec, \"vec\")" 56 | else: 57 | raise NotImplementedError(f"Not implemented for metric {self.metric}") 58 | es_wait() 59 | 60 | def fit(self, X): 61 | body = dict(settings=dict(number_of_shards=1, number_of_replicas=0)) 62 | mapping = dict( 63 | properties=dict( 64 | id=dict(type="keyword", store=True), 65 | vec=dict(type="dense_vector", dims=self.dimension) 66 | ) 67 | ) 68 | self.es.indices.create(self.index, body=body) 69 | self.es.indices.put_mapping(mapping, self.index) 70 | 71 | def gen(): 72 | for i, vec in enumerate(X): 73 | yield { "_op_type": "index", "_index": self.index, "vec": vec.tolist(), 'id': str(i + 1) } 74 | 75 | (_, errors) = bulk(self.es, gen(), chunk_size=500, max_retries=9) 76 | assert len(errors) == 0, errors 77 | 78 | self.es.indices.refresh(self.index) 79 | self.es.indices.forcemerge(self.index, max_num_segments=1) 80 | 81 | def query(self, q, n): 82 | body = dict( 83 | query=dict( 84 | script_score=dict( 85 | query=dict(match_all=dict()), 86 | script=dict( 87 | source=self.script, 88 | params=dict(query_vec=q.tolist()) 89 | ) 90 | ) 91 | ) 92 | ) 93 | res = self.es.search(index=self.index, body=body, size=n, _source=False, docvalue_fields=['id'], 94 | stored_fields="_none_", filter_path=["hits.hits.fields.id"]) 95 | return [int(h['fields']['id'][0]) - 1 for h in res['hits']['hits']] 96 | 97 | def batch_query(self, X, n): 98 | self.batch_res = [self.query(q, n) for q in X] 99 | 100 | def get_batch_results(self): 101 | return self.batch_res 102 | 103 | -------------------------------------------------------------------------------- /ann_benchmarks/algorithms/pynndescent.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import pynndescent 3 | from ann_benchmarks.algorithms.base import BaseANN 4 | import numpy as np 5 | import scipy.sparse 6 | 7 | 8 | class PyNNDescent(BaseANN): 9 | def __init__(self, metric, index_param_dict, n_search_trees=1): 10 | if "n_neighbors" in index_param_dict: 11 | self._n_neighbors = int(index_param_dict["n_neighbors"]) 12 | else: 13 | self._n_neighbors = 30 14 | 15 | if "pruning_degree_multiplier" in index_param_dict: 16 | self._pruning_degree_multiplier = float( 17 | index_param_dict["pruning_degree_multiplier"] 18 | ) 19 | else: 20 | self._pruning_degree_multiplier = 1.5 21 | 22 | if "diversify_prob" in index_param_dict: 23 | self._diversify_prob = float(index_param_dict["diversify_prob"]) 24 | else: 25 | self._diversify_prob = 1.0 26 | 27 | if "leaf_size" in index_param_dict: 28 | self._leaf_size = int(index_param_dict["leaf_size"]) 29 | else: 30 | leaf_size = 32 31 | 32 | self._n_search_trees = int(n_search_trees) 33 | 34 | self._pynnd_metric = { 35 | "angular": "dot", 36 | # 'angular': 
'cosine', 37 | "euclidean": "euclidean", 38 | "hamming": "hamming", 39 | "jaccard": "jaccard", 40 | }[metric] 41 | 42 | def _sparse_convert_for_fit(self, X): 43 | lil_data = [] 44 | self._n_cols = 1 45 | self._n_rows = len(X) 46 | for i in range(self._n_rows): 47 | lil_data.append([1] * len(X[i])) 48 | if max(X[i]) + 1 > self._n_cols: 49 | self._n_cols = max(X[i]) + 1 50 | 51 | result = scipy.sparse.lil_matrix( 52 | (self._n_rows, self._n_cols), dtype=np.int 53 | ) 54 | result.rows[:] = list(X) 55 | result.data[:] = lil_data 56 | return result.tocsr() 57 | 58 | def _sparse_convert_for_query(self, v): 59 | result = scipy.sparse.csr_matrix((1, self._n_cols), dtype=np.int) 60 | result.indptr = np.array([0, len(v)]) 61 | result.indices = np.array(v).astype(np.int32) 62 | result.data = np.ones(len(v), dtype=np.int) 63 | return result 64 | 65 | def fit(self, X): 66 | if self._pynnd_metric == "jaccard": 67 | # Convert to sparse matrix format 68 | X = self._sparse_convert_for_fit(X) 69 | 70 | self._index = pynndescent.NNDescent( 71 | X, 72 | n_neighbors=self._n_neighbors, 73 | metric=self._pynnd_metric, 74 | low_memory=True, 75 | leaf_size=self._leaf_size, 76 | pruning_degree_multiplier=self._pruning_degree_multiplier, 77 | diversify_prob=self._diversify_prob, 78 | n_search_trees=self._n_search_trees, 79 | compressed=True, 80 | verbose=True, 81 | ) 82 | if hasattr(self._index, "prepare"): 83 | self._index.prepare() 84 | else: 85 | self._index._init_search_graph() 86 | if self._index._is_sparse: 87 | if hasattr(self._index, "_init_sparse_search_function"): 88 | self._index._init_sparse_search_function() 89 | else: 90 | if hasattr(self._index, "_init_search_function"): 91 | self._index._init_search_function() 92 | 93 | def set_query_arguments(self, epsilon=0.1): 94 | self._epsilon = float(epsilon) 95 | 96 | def query(self, v, n): 97 | if self._pynnd_metric == "jaccard": 98 | # convert index array to sparse matrix format and query 99 | v = self._sparse_convert_for_query(v) 100 | ind, dist = self._index.query(v, k=n, epsilon=self._epsilon) 101 | else: 102 | ind, dist = self._index.query( 103 | v.reshape(1, -1).astype("float32"), k=n, epsilon=self._epsilon 104 | ) 105 | return ind[0] 106 | 107 | def __str__(self): 108 | str_template = "PyNNDescent(n_neighbors=%d, pruning_mult=%.2f, diversify_prob=%.3f, epsilon=%.3f, leaf_size=%02d)" 109 | return str_template % ( 110 | self._n_neighbors, 111 | self._pruning_degree_multiplier, 112 | self._diversify_prob, 113 | self._epsilon, 114 | self._leaf_size, 115 | ) 116 | -------------------------------------------------------------------------------- /ann_benchmarks/plotting/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import itertools 4 | import numpy 5 | from ann_benchmarks.plotting.metrics import all_metrics as metrics 6 | 7 | 8 | def get_or_create_metrics(run): 9 | if 'metrics' not in run: 10 | run.create_group('metrics') 11 | return run['metrics'] 12 | 13 | 14 | def create_pointset(data, xn, yn): 15 | xm, ym = (metrics[xn], metrics[yn]) 16 | rev_y = -1 if ym["worst"] < 0 else 1 17 | rev_x = -1 if xm["worst"] < 0 else 1 18 | data.sort(key=lambda t: (rev_y * t[-1], rev_x * t[-2])) 19 | 20 | axs, ays, als = [], [], [] 21 | # Generate Pareto frontier 22 | xs, ys, ls = [], [], [] 23 | last_x = xm["worst"] 24 | comparator = ((lambda xv, lx: xv > lx) 25 | if last_x < 0 else (lambda xv, lx: xv < lx)) 26 | for algo, algo_name, xv, yv in data: 27 | if not xv or not yv: 28 | 
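# skip runs where either metric value is missing (falsy)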
continue 29 | axs.append(xv) 30 | ays.append(yv) 31 | als.append(algo_name) 32 | if comparator(xv, last_x): 33 | last_x = xv 34 | xs.append(xv) 35 | ys.append(yv) 36 | ls.append(algo_name) 37 | return xs, ys, ls, axs, ays, als 38 | 39 | 40 | def compute_metrics(true_nn_distances, res, metric_1, metric_2, 41 | recompute=False): 42 | all_results = {} 43 | for i, (properties, run) in enumerate(res): 44 | algo = properties['algo'] 45 | algo_name = properties['name'] 46 | # cache distances to avoid access to hdf5 file 47 | run_distances = numpy.array(run['distances']) 48 | if recompute and 'metrics' in run: 49 | del run['metrics'] 50 | metrics_cache = get_or_create_metrics(run) 51 | 52 | metric_1_value = metrics[metric_1]['function']( 53 | true_nn_distances, 54 | run_distances, metrics_cache, properties) 55 | metric_2_value = metrics[metric_2]['function']( 56 | true_nn_distances, 57 | run_distances, metrics_cache, properties) 58 | 59 | print('%3d: %80s %12.3f %12.3f' % 60 | (i, algo_name, metric_1_value, metric_2_value)) 61 | 62 | all_results.setdefault(algo, []).append( 63 | (algo, algo_name, metric_1_value, metric_2_value)) 64 | 65 | return all_results 66 | 67 | 68 | def compute_all_metrics(true_nn_distances, run, properties, recompute=False): 69 | algo = properties["algo"] 70 | algo_name = properties["name"] 71 | print('--') 72 | print(algo_name) 73 | results = {} 74 | # cache distances to avoid access to hdf5 file 75 | run_distances = numpy.array(run["distances"]) 76 | if recompute and 'metrics' in run: 77 | del run['metrics'] 78 | metrics_cache = get_or_create_metrics(run) 79 | 80 | for name, metric in metrics.items(): 81 | v = metric["function"]( 82 | true_nn_distances, run_distances, metrics_cache, properties) 83 | results[name] = v 84 | if v: 85 | print('%s: %g' % (name, v)) 86 | return (algo, algo_name, results) 87 | 88 | 89 | def generate_n_colors(n): 90 | vs = numpy.linspace(0.3, 0.9, 7) 91 | colors = [(.9, .4, .4, 1.)] 92 | 93 | def euclidean(a, b): 94 | return sum((x - y)**2 for x, y in zip(a, b)) 95 | while len(colors) < n: 96 | new_color = max(itertools.product(vs, vs, vs), 97 | key=lambda a: min(euclidean(a, b) for b in colors)) 98 | colors.append(new_color + (1.,)) 99 | return colors 100 | 101 | 102 | def create_linestyles(unique_algorithms): 103 | colors = dict( 104 | zip(unique_algorithms, generate_n_colors(len(unique_algorithms)))) 105 | linestyles = dict((algo, ['--', '-.', '-', ':'][i % 4]) 106 | for i, algo in enumerate(unique_algorithms)) 107 | markerstyles = dict((algo, ['+', '<', 'o', '*', 'x'][i % 5]) 108 | for i, algo in enumerate(unique_algorithms)) 109 | faded = dict((algo, (r, g, b, 0.3)) 110 | for algo, (r, g, b, a) in colors.items()) 111 | return dict((algo, (colors[algo], faded[algo], 112 | linestyles[algo], markerstyles[algo])) 113 | for algo in unique_algorithms) 114 | 115 | 116 | def get_up_down(metric): 117 | if metric["worst"] == float("inf"): 118 | return "down" 119 | return "up" 120 | 121 | 122 | def get_left_right(metric): 123 | if metric["worst"] == float("inf"): 124 | return "left" 125 | return "right" 126 | 127 | 128 | def get_plot_label(xm, ym): 129 | template = ("%(xlabel)s-%(ylabel)s tradeoff - %(updown)s and" 130 | " to the %(leftright)s is better") 131 | return template % {"xlabel": xm["description"], 132 | "ylabel": ym["description"], 133 | "updown": get_up_down(ym), 134 | "leftright": get_left_right(xm)} 135 | -------------------------------------------------------------------------------- /templates/chartjs.template: 
-------------------------------------------------------------------------------- 1-91 | [markup and the chart-drawing script lost in extraction; surviving template text:] {{xlabel}}/{{ylabel}} {% if args.latex %}
92 | 97 | 102 | {% endif %} 103 | -------------------------------------------------------------------------------- /ann_benchmarks/algorithms/elastiknn.py: -------------------------------------------------------------------------------- 1 | """ 2 | ann-benchmarks interfaces for elastiknn: https://github.com/alexklibisz/elastiknn 3 | Uses the elastiknn python client 4 | To install a local copy of the client, run `pip install --upgrade -e /path/to/elastiknn/client-python/` 5 | To monitor the Elasticsearch JVM using Visualvm, add `ports={ "8097": 8097 }` to the `containers.run` call in runner.py. 6 | """ 7 | from sys import stderr 8 | from urllib.error import URLError 9 | 10 | import numpy as np 11 | from elastiknn.api import Vec 12 | from elastiknn.models import ElastiknnModel 13 | from elastiknn.utils import dealias_metric 14 | 15 | from ann_benchmarks.algorithms.base import BaseANN 16 | 17 | from urllib.request import Request, urlopen 18 | from time import sleep, perf_counter 19 | 20 | import logging 21 | 22 | # Mute the elasticsearch logger. 23 | # By default, it writes an INFO statement for every request. 24 | logging.getLogger("elasticsearch").setLevel(logging.WARN) 25 | 26 | 27 | def es_wait(): 28 | print("Waiting for elasticsearch health endpoint...") 29 | req = Request("http://localhost:9200/_cluster/health?wait_for_status=yellow&timeout=1s") 30 | for i in range(30): 31 | try: 32 | res = urlopen(req) 33 | if res.getcode() == 200: 34 | print("Elasticsearch is ready") 35 | return 36 | except URLError: 37 | pass 38 | sleep(1) 39 | raise RuntimeError("Failed to connect to local elasticsearch") 40 | 41 | 42 | class Exact(BaseANN): 43 | 44 | def __init__(self, metric: str, dimension: int): 45 | self.name = f"eknn-exact-metric={metric}_dimension={dimension}" 46 | self.metric = metric 47 | self.dimension = dimension 48 | self.model = ElastiknnModel("exact", dealias_metric(metric)) 49 | self.batch_res = None 50 | es_wait() 51 | 52 | def _handle_sparse(self, X): 53 | # convert list of lists of indices to sparse vectors. 54 | return [Vec.SparseBool(x, self.dimension) for x in X] 55 | 56 | def fit(self, X): 57 | if self.metric in {'jaccard', 'hamming'}: 58 | return self.model.fit(self._handle_sparse(X), shards=1)[0] 59 | else: 60 | return self.model.fit(X, shards=1) 61 | 62 | def query(self, q, n): 63 | if self.metric in {'jaccard', 'hamming'}: 64 | return self.model.kneighbors(self._handle_sparse([q]), n)[0] 65 | else: 66 | return self.model.kneighbors(np.expand_dims(q, 0), n)[0] 67 | 68 | def batch_query(self, X, n): 69 | if self.metric in {'jaccard', 'hamming'}: 70 | self.batch_res = self.model.kneighbors(self._handle_sparse(X), n) 71 | else: 72 | self.batch_res = self.model.kneighbors(X, n) 73 | 74 | def get_batch_results(self): 75 | return self.batch_res 76 | 77 | 78 | class L2Lsh(BaseANN): 79 | 80 | def __init__(self, L: int, k: int, w: int): 81 | self.name_prefix = f"eknn-l2lsh-L={L}-k={k}-w={w}" 82 | self.name = None # set based on query args. 83 | self.model = ElastiknnModel("lsh", "l2", mapping_params=dict(L=L, k=k, w=w)) 84 | self.X_max = 1.0 85 | self.query_params = dict() 86 | self.batch_res = None 87 | self.sum_query_dur = 0 88 | self.num_queries = 0 89 | es_wait() 90 | 91 | def fit(self, X): 92 | print(f"{self.name_prefix}: indexing {len(X)} vectors") 93 | 94 | # I found it's best to scale the vectors into [0, 1], i.e. divide by the max. 
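# Dividing every indexed vector (and, in query(), every query vector) by the
# same constant only rescales L2 distances, so neighbour order is preserved;
# presumably the scaling keeps the LSH bucket width w meaningful across
# datasets of different magnitudes.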
95 | self.X_max = X.max() 96 | return self.model.fit(X / self.X_max, shards=1) 97 | 98 | def set_query_arguments(self, candidates: int, probes: int): 99 | # This gets called when starting a new batch of queries. 100 | # Update the name and model's query parameters based on the given params. 101 | self.name = f"{self.name_prefix}_candidates={candidates}_probes={probes}" 102 | self.model.set_query_params(dict(candidates=candidates, probes=probes)) 103 | # Reset the counters. 104 | self.num_queries = 0 105 | self.sum_query_dur = 0 106 | 107 | def query(self, q, n): 108 | # If QPS after 100 queries is < 10, this setting is bad and won't complete within the default timeout. 109 | if self.num_queries > 100 and self.num_queries / self.sum_query_dur < 10: 110 | print("Throughput after 100 queries is less than 10 q/s. Terminating to avoid wasteful computation.", flush=True) 111 | exit(0) 112 | else: 113 | t0 = perf_counter() 114 | res = self.model.kneighbors(np.expand_dims(q, 0) / self.X_max, n)[0] 115 | dur = (perf_counter() - t0) 116 | self.sum_query_dur += dur 117 | self.num_queries += 1 118 | return res 119 | 120 | def batch_query(self, X, n): 121 | self.batch_res = self.model.kneighbors(X, n) 122 | 123 | def get_batch_results(self): 124 | return self.batch_res 125 | -------------------------------------------------------------------------------- /plot.py: -------------------------------------------------------------------------------- 1 | import os 2 | import matplotlib as mpl 3 | mpl.use('Agg') # noqa 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | import argparse 7 | 8 | from ann_benchmarks.datasets import get_dataset 9 | from ann_benchmarks.algorithms.definitions import get_definitions 10 | from ann_benchmarks.plotting.metrics import all_metrics as metrics 11 | from ann_benchmarks.plotting.utils import (get_plot_label, compute_metrics, 12 | create_linestyles, create_pointset) 13 | from ann_benchmarks.results import (store_results, load_all_results, 14 | get_unique_algorithms, get_algorithm_name) 15 | 16 | 17 | def create_plot(all_data, raw, x_log, y_log, xn, yn, fn_out, linestyles, 18 | batch): 19 | xm, ym = (metrics[xn], metrics[yn]) 20 | # Now generate each plot 21 | handles = [] 22 | labels = [] 23 | plt.figure(figsize=(12, 9)) 24 | for algo in sorted(all_data.keys(), key=lambda x: x.lower()): 25 | xs, ys, ls, axs, ays, als = create_pointset(all_data[algo], xn, yn) 26 | color, faded, linestyle, marker = linestyles[algo] 27 | handle, = plt.plot(xs, ys, '-', label=algo, color=color, 28 | ms=7, mew=3, lw=3, linestyle=linestyle, 29 | marker=marker) 30 | handles.append(handle) 31 | if raw: 32 | handle2, = plt.plot(axs, ays, '-', label=algo, color=faded, 33 | ms=5, mew=2, lw=2, linestyle=linestyle, 34 | marker=marker) 35 | labels.append(get_algorithm_name(algo, batch)) 36 | 37 | if x_log: 38 | plt.gca().set_xscale('log') 39 | if y_log: 40 | plt.gca().set_yscale('log') 41 | plt.gca().set_title(get_plot_label(xm, ym)) 42 | plt.gca().set_ylabel(ym['description']) 43 | plt.gca().set_xlabel(xm['description']) 44 | box = plt.gca().get_position() 45 | # plt.gca().set_position([box.x0, box.y0, box.width * 0.8, box.height]) 46 | plt.gca().legend(handles, labels, loc='center left', 47 | bbox_to_anchor=(1, 0.5), prop={'size': 9}) 48 | plt.grid(b=True, which='major', color='0.65', linestyle='-') 49 | if 'lim' in xm: 50 | plt.xlim(xm['lim']) 51 | if 'lim' in ym: 52 | plt.ylim(ym['lim']) 53 | plt.savefig(fn_out, bbox_inches='tight') 54 | plt.close() 55 | 56 | 57 | if __name__ == "__main__": 
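# Illustrative invocation (all of these flags are defined by the parser
# below): `python plot.py --dataset glove-100-angular -x k-nn -y qps -Y --raw`
# plots recall against queries per second with a logarithmic y-axis, shows
# the non-Pareto-optimal runs in faded colours, and writes the plot to
# results/glove-100-angular.png unless -o/--output is given.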
58 | parser = argparse.ArgumentParser() 59 | parser.add_argument( 60 | '--dataset', 61 | metavar="DATASET", 62 | default='glove-100-angular') 63 | parser.add_argument( 64 | '--count', 65 | default=10) 66 | parser.add_argument( 67 | '--definitions', 68 | metavar='FILE', 69 | help='load algorithm definitions from FILE', 70 | default='algos.yaml') 71 | parser.add_argument( 72 | '--limit', 73 | default=-1) 74 | parser.add_argument( 75 | '-o', '--output') 76 | parser.add_argument( 77 | '-x', '--x-axis', 78 | help='Which metric to use on the X-axis', 79 | choices=metrics.keys(), 80 | default="k-nn") 81 | parser.add_argument( 82 | '-y', '--y-axis', 83 | help='Which metric to use on the Y-axis', 84 | choices=metrics.keys(), 85 | default="qps") 86 | parser.add_argument( 87 | '-X', '--x-log', 88 | help='Draw the X-axis using a logarithmic scale', 89 | action='store_true') 90 | parser.add_argument( 91 | '-Y', '--y-log', 92 | help='Draw the Y-axis using a logarithmic scale', 93 | action='store_true') 94 | parser.add_argument( 95 | '--raw', 96 | help='Show raw results (not just Pareto frontier) in faded colours', 97 | action='store_true') 98 | parser.add_argument( 99 | '--batch', 100 | help='Plot runs in batch mode', 101 | action='store_true') 102 | parser.add_argument( 103 | '--recompute', 104 | help='Clears the cache and recomputes the metrics', 105 | action='store_true') 106 | args = parser.parse_args() 107 | 108 | if not args.output: 109 | args.output = 'results/%s.png' % get_algorithm_name( 110 | args.dataset, args.batch) 111 | print('writing output to %s' % args.output) 112 | 113 | dataset = get_dataset(args.dataset) 114 | count = int(args.count) 115 | unique_algorithms = get_unique_algorithms() 116 | results = load_all_results(args.dataset, count, True, args.batch) 117 | linestyles = create_linestyles(sorted(unique_algorithms)) 118 | runs = compute_metrics(np.array(dataset["distances"]), 119 | results, args.x_axis, args.y_axis, args.recompute) 120 | if not runs: 121 | raise Exception('Nothing to plot') 122 | 123 | create_plot(runs, args.raw, args.x_log, 124 | args.y_log, args.x_axis, args.y_axis, args.output, 125 | linestyles, args.batch) 126 | -------------------------------------------------------------------------------- /ann_benchmarks/algorithms/bruteforce.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import numpy 3 | import sklearn.neighbors 4 | from ann_benchmarks.distance import metrics as pd 5 | from ann_benchmarks.algorithms.base import BaseANN 6 | 7 | 8 | class BruteForce(BaseANN): 9 | def __init__(self, metric): 10 | if metric not in ('angular', 'euclidean', 'hamming'): 11 | raise NotImplementedError( 12 | "BruteForce doesn't support metric %s" % metric) 13 | self._metric = metric 14 | self.name = 'BruteForce()' 15 | 16 | def fit(self, X): 17 | metric = {'angular': 'cosine', 'euclidean': 'l2', 18 | 'hamming': 'hamming'}[self._metric] 19 | self._nbrs = sklearn.neighbors.NearestNeighbors( 20 | algorithm='brute', metric=metric) 21 | self._nbrs.fit(X) 22 | 23 | def query(self, v, n): 24 | return list(self._nbrs.kneighbors( 25 | [v], return_distance=False, n_neighbors=n)[0]) 26 | 27 | def query_with_distances(self, v, n): 28 | (distances, positions) = self._nbrs.kneighbors( 29 | [v], return_distance=True, n_neighbors=n) 30 | return zip(list(positions[0]), list(distances[0])) 31 | 32 | 33 | class BruteForceBLAS(BaseANN): 34 | """kNN search that uses a linear scan = brute force.""" 35 | 36 | def 
__init__(self, metric, precision=numpy.float32): 37 | if metric not in ('angular', 'euclidean', 'hamming', 'jaccard'): 38 | raise NotImplementedError( 39 | "BruteForceBLAS doesn't support metric %s" % metric) 40 | elif metric == 'hamming' and precision != numpy.bool: 41 | raise NotImplementedError( 42 | "BruteForceBLAS doesn't support precision" 43 | " %s with Hamming distances" % precision) 44 | self._metric = metric 45 | self._precision = precision 46 | self.name = 'BruteForceBLAS()' 47 | 48 | def fit(self, X): 49 | """Initialize the search index.""" 50 | if self._metric == 'angular': 51 | # precompute (squared) length of each vector 52 | lens = (X ** 2).sum(-1) 53 | # normalize index vectors to unit length 54 | X /= numpy.sqrt(lens)[..., numpy.newaxis] 55 | self.index = numpy.ascontiguousarray(X, dtype=self._precision) 56 | elif self._metric == 'hamming': 57 | # Regarding bitvectors as vectors in l_2 is faster for blas 58 | X = X.astype(numpy.float32) 59 | # precompute (squared) length of each vector 60 | lens = (X ** 2).sum(-1) 61 | self.index = numpy.ascontiguousarray(X, dtype=numpy.float32) 62 | self.lengths = numpy.ascontiguousarray(lens, dtype=numpy.float32) 63 | elif self._metric == 'euclidean': 64 | # precompute (squared) length of each vector 65 | lens = (X ** 2).sum(-1) 66 | self.index = numpy.ascontiguousarray(X, dtype=self._precision) 67 | self.lengths = numpy.ascontiguousarray(lens, dtype=self._precision) 68 | elif self._metric == 'jaccard': 69 | self.index = X 70 | else: 71 | # shouldn't get past the constructor! 72 | assert False, "invalid metric" 73 | 74 | def query(self, v, n): 75 | return [index for index, _ in self.query_with_distances(v, n)] 76 | 77 | def query_with_distances(self, v, n): 78 | """Find indices of `n` most similar vectors from the index to query 79 | vector `v`.""" 80 | 81 | if self._metric != 'jaccard': 82 | # use same precision for query as for index 83 | v = numpy.ascontiguousarray(v, dtype=self.index.dtype) 84 | 85 | # HACK we ignore query length as that's a constant 86 | # not affecting the final ordering 87 | if self._metric == 'angular': 88 | # argmax_a cossim(a, b) = argmax_a dot(a, b) / |a||b| = argmin_a -dot(a, b) # noqa 89 | dists = -numpy.dot(self.index, v) 90 | elif self._metric == 'euclidean': 91 | # argmin_a (a - b)^2 = argmin_a a^2 - 2ab + b^2 = argmin_a a^2 - 2ab # noqa 92 | dists = self.lengths - 2 * numpy.dot(self.index, v) 93 | elif self._metric == 'hamming': 94 | # Just compute hamming distance using euclidean distance 95 | dists = self.lengths - 2 * numpy.dot(self.index, v) 96 | elif self._metric == 'jaccard': 97 | dists = [pd[self._metric]['distance'](v, e) for e in self.index] 98 | else: 99 | # shouldn't get past the constructor! 
100 | assert False, "invalid metric" 101 | # partition-sort by distance, get `n` closest 102 | nearest_indices = numpy.argpartition(dists, n)[:n] 103 | indices = [idx for idx in nearest_indices if pd[self._metric] 104 | ["distance_valid"](dists[idx])] 105 | 106 | def fix(index): 107 | ep = self.index[index] 108 | ev = v 109 | return (index, pd[self._metric]['distance'](ep, ev)) 110 | return map(fix, indices) 111 | -------------------------------------------------------------------------------- /ann_benchmarks/plotting/metrics.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import numpy as np 3 | 4 | 5 | def knn_threshold(data, count, epsilon): 6 | return data[count - 1] + epsilon 7 | 8 | 9 | def epsilon_threshold(data, count, epsilon): 10 | return data[count - 1] * (1 + epsilon) 11 | 12 | 13 | def get_recall_values(dataset_distances, run_distances, count, threshold, 14 | epsilon=1e-3): 15 | recalls = np.zeros(len(run_distances)) 16 | for i in range(len(run_distances)): 17 | t = threshold(dataset_distances[i], count, epsilon) 18 | actual = 0 19 | for d in run_distances[i][:count]: 20 | if d <= t: 21 | actual += 1 22 | recalls[i] = actual 23 | return (np.mean(recalls) / float(count), 24 | np.std(recalls) / float(count), 25 | recalls) 26 | 27 | 28 | def knn(dataset_distances, run_distances, count, metrics, epsilon=1e-3): 29 | if 'knn' not in metrics: 30 | print('Computing knn metrics') 31 | knn_metrics = metrics.create_group('knn') 32 | mean, std, recalls = get_recall_values(dataset_distances, 33 | run_distances, count, 34 | knn_threshold, epsilon) 35 | knn_metrics.attrs['mean'] = mean 36 | knn_metrics.attrs['std'] = std 37 | knn_metrics['recalls'] = recalls 38 | else: 39 | print("Found cached result") 40 | return metrics['knn'] 41 | 42 | 43 | def epsilon(dataset_distances, run_distances, count, metrics, epsilon=0.01): 44 | s = 'eps' + str(epsilon) 45 | if s not in metrics: 46 | print('Computing epsilon metrics') 47 | epsilon_metrics = metrics.create_group(s) 48 | mean, std, recalls = get_recall_values(dataset_distances, 49 | run_distances, count, 50 | epsilon_threshold, epsilon) 51 | epsilon_metrics.attrs['mean'] = mean 52 | epsilon_metrics.attrs['std'] = std 53 | epsilon_metrics['recalls'] = recalls 54 | else: 55 | print("Found cached result") 56 | return metrics[s] 57 | 58 | 59 | def rel(dataset_distances, run_distances, metrics): 60 | if 'rel' not in metrics.attrs: 61 | print('Computing rel metrics') 62 | total_closest_distance = 0.0 63 | total_candidate_distance = 0.0 64 | for true_distances, found_distances in zip(dataset_distances, 65 | run_distances): 66 | for rdist, cdist in zip(true_distances, found_distances): 67 | total_closest_distance += rdist 68 | total_candidate_distance += cdist 69 | if total_closest_distance < 0.01: 70 | metrics.attrs['rel'] = float("inf") 71 | else: 72 | metrics.attrs['rel'] = total_candidate_distance / \ 73 | total_closest_distance 74 | else: 75 | print("Found cached result") 76 | return metrics.attrs['rel'] 77 | 78 | 79 | def queries_per_second(queries, attrs): 80 | return 1.0 / attrs["best_search_time"] 81 | 82 | 83 | def index_size(queries, attrs): 84 | # TODO(erikbern): should replace this with peak memory usage or something 85 | return attrs.get("index_size", 0) 86 | 87 | 88 | def build_time(queries, attrs): 89 | return attrs["build_time"] 90 | 91 | 92 | def candidates(queries, attrs): 93 | return attrs["candidates"] 94 | 95 | 96 | def dist_computations(queries, 
attrs): 97 | return attrs.get("dist_comps", 0) / (attrs['run_count'] * len(queries)) 98 | 99 | 100 | all_metrics = { 101 | "k-nn": { 102 | "description": "Recall", 103 | "function": lambda true_distances, run_distances, metrics, run_attrs: knn(true_distances, run_distances, run_attrs["count"], metrics).attrs['mean'], # noqa 104 | "worst": float("-inf"), 105 | "lim": [0.0, 1.03] 106 | }, 107 | "epsilon": { 108 | "description": "Epsilon 0.01 Recall", 109 | "function": lambda true_distances, run_distances, metrics, run_attrs: epsilon(true_distances, run_distances, run_attrs["count"], metrics).attrs['mean'], # noqa 110 | "worst": float("-inf") 111 | }, 112 | "largeepsilon": { 113 | "description": "Epsilon 0.1 Recall", 114 | "function": lambda true_distances, run_distances, metrics, run_attrs: epsilon(true_distances, run_distances, run_attrs["count"], metrics, 0.1).attrs['mean'], # noqa 115 | "worst": float("-inf") 116 | }, 117 | "rel": { 118 | "description": "Relative Error", 119 | "function": lambda true_distances, run_distances, metrics, run_attrs: rel(true_distances, run_distances, metrics), # noqa 120 | "worst": float("inf") 121 | }, 122 | "qps": { 123 | "description": "Queries per second (1/s)", 124 | "function": lambda true_distances, run_distances, metrics, run_attrs: queries_per_second(true_distances, run_attrs), # noqa 125 | "worst": float("-inf") 126 | }, 127 | "distcomps": { 128 | "description": "Distance computations", 129 | "function": lambda true_distances, run_distances, metrics, run_attrs: dist_computations(true_distances, run_attrs), # noqa 130 | "worst": float("inf") 131 | }, 132 | "build": { 133 | "description": "Build time (s)", 134 | "function": lambda true_distances, run_distances, metrics, run_attrs: build_time(true_distances, run_attrs), # noqa 135 | "worst": float("inf") 136 | }, 137 | "candidates": { 138 | "description": "Candidates generated", 139 | "function": lambda true_distances, run_distances, metrics, run_attrs: candidates(true_distances, run_attrs), # noqa 140 | "worst": float("inf") 141 | }, 142 | "indexsize": { 143 | "description": "Index size (kB)", 144 | "function": lambda true_distances, run_distances, metrics, run_attrs: index_size(true_distances, run_attrs), # noqa 145 | "worst": float("inf") 146 | }, 147 | "queriessize": { 148 | "description": "Index size (kB)/Queries per second (s)", 149 | "function": lambda true_distances, run_distances, metrics, run_attrs: index_size(true_distances, run_attrs) / queries_per_second(true_distances, run_attrs), # noqa 150 | "worst": float("inf") 151 | } 152 | } 153 | -------------------------------------------------------------------------------- /ann_benchmarks/algorithms/definitions.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from os import sep as pathsep 3 | import collections 4 | import importlib 5 | import os 6 | import sys 7 | import traceback 8 | import yaml 9 | from enum import Enum 10 | from itertools import product 11 | 12 | 13 | Definition = collections.namedtuple( 14 | 'Definition', 15 | ['algorithm', 'constructor', 'module', 'docker_tag', 16 | 'arguments', 'query_argument_groups', 'disabled']) 17 | 18 | 19 | def get_algorithm_name(name, batch): 20 | if batch: 21 | return name + "-batch" 22 | return name 23 | 24 | 25 | def instantiate_algorithm(definition): 26 | print('Trying to instantiate %s.%s(%s)' % 27 | (definition.module, definition.constructor, definition.arguments)) 28 | module = 
importlib.import_module(definition.module) 29 | constructor = getattr(module, definition.constructor) 30 | return constructor(*definition.arguments) 31 | 32 | 33 | class InstantiationStatus(Enum): 34 | AVAILABLE = 0 35 | NO_CONSTRUCTOR = 1 36 | NO_MODULE = 2 37 | 38 | 39 | def algorithm_status(definition): 40 | try: 41 | module = importlib.import_module(definition.module) 42 | if hasattr(module, definition.constructor): 43 | return InstantiationStatus.AVAILABLE 44 | else: 45 | return InstantiationStatus.NO_CONSTRUCTOR 46 | except ImportError: 47 | return InstantiationStatus.NO_MODULE 48 | 49 | 50 | def _generate_combinations(args): 51 | if isinstance(args, list): 52 | args = [el if isinstance(el, list) else [el] for el in args] 53 | return [list(x) for x in product(*args)] 54 | elif isinstance(args, dict): 55 | flat = [] 56 | for k, v in args.items(): 57 | if isinstance(v, list): 58 | flat.append([(k, el) for el in v]) 59 | else: 60 | flat.append([(k, v)]) 61 | return [dict(x) for x in product(*flat)] 62 | else: 63 | raise TypeError("No args handling exists for %s" % type(args).__name__) 64 | 65 | 66 | def _substitute_variables(arg, vs): 67 | if isinstance(arg, dict): 68 | return dict([(k, _substitute_variables(v, vs)) 69 | for k, v in arg.items()]) 70 | elif isinstance(arg, list): 71 | return [_substitute_variables(a, vs) for a in arg] 72 | elif isinstance(arg, str) and arg in vs: 73 | return vs[arg] 74 | else: 75 | return arg 76 | 77 | 78 | def _get_definitions(definition_file): 79 | with open(definition_file, "r") as f: 80 | return yaml.load(f, yaml.SafeLoader) 81 | 82 | 83 | def list_algorithms(definition_file): 84 | definitions = _get_definitions(definition_file) 85 | 86 | print('The following algorithms are supported...') 87 | for point in definitions: 88 | print('\t... for the point type "%s"...' % point) 89 | for metric in definitions[point]: 90 | print('\t\t... 
and the distance metric "%s":' % metric) 91 | for algorithm in definitions[point][metric]: 92 | print('\t\t\t%s' % algorithm) 93 | 94 | 95 | def get_unique_algorithms(definition_file): 96 | definitions = _get_definitions(definition_file) 97 | algos = set() 98 | for point in definitions: 99 | for metric in definitions[point]: 100 | for algorithm in definitions[point][metric]: 101 | algos.add(algorithm) 102 | return list(sorted(algos)) 103 | 104 | 105 | def get_definitions(definition_file, dimension, point_type="float", 106 | distance_metric="euclidean", count=10): 107 | definitions = _get_definitions(definition_file) 108 | 109 | algorithm_definitions = {} 110 | if "any" in definitions[point_type]: 111 | algorithm_definitions.update(definitions[point_type]["any"]) 112 | algorithm_definitions.update(definitions[point_type][distance_metric]) 113 | 114 | definitions = [] 115 | for (name, algo) in algorithm_definitions.items(): 116 | for k in ['docker-tag', 'module', 'constructor']: 117 | if k not in algo: 118 | raise Exception( 119 | 'algorithm %s does not define a "%s" property' % (name, k)) 120 | 121 | base_args = [] 122 | if "base-args" in algo: 123 | base_args = algo["base-args"] 124 | 125 | for run_group in algo["run-groups"].values(): 126 | if "arg-groups" in run_group: 127 | groups = [] 128 | for arg_group in run_group["arg-groups"]: 129 | if isinstance(arg_group, dict): 130 | # Dictionaries need to be expanded into lists in order 131 | # for the subsequent call to _generate_combinations to 132 | # do the right thing 133 | groups.append(_generate_combinations(arg_group)) 134 | else: 135 | groups.append(arg_group) 136 | args = _generate_combinations(groups) 137 | elif "args" in run_group: 138 | args = _generate_combinations(run_group["args"]) 139 | else: 140 | assert False, "? what? 
%s" % run_group 141 | 142 | if "query-arg-groups" in run_group: 143 | groups = [] 144 | for arg_group in run_group["query-arg-groups"]: 145 | if isinstance(arg_group, dict): 146 | groups.append(_generate_combinations(arg_group)) 147 | else: 148 | groups.append(arg_group) 149 | query_args = _generate_combinations(groups) 150 | elif "query-args" in run_group: 151 | query_args = _generate_combinations(run_group["query-args"]) 152 | else: 153 | query_args = [] 154 | 155 | for arg_group in args: 156 | aargs = [] 157 | aargs.extend(base_args) 158 | if isinstance(arg_group, list): 159 | aargs.extend(arg_group) 160 | else: 161 | aargs.append(arg_group) 162 | 163 | vs = { 164 | "@count": count, 165 | "@metric": distance_metric, 166 | "@dimension": dimension 167 | } 168 | aargs = [_substitute_variables(arg, vs) for arg in aargs] 169 | definitions.append(Definition( 170 | algorithm=name, 171 | docker_tag=algo['docker-tag'], 172 | module=algo['module'], 173 | constructor=algo['constructor'], 174 | arguments=aargs, 175 | query_argument_groups=query_args, 176 | disabled=algo.get('disabled', False) 177 | )) 178 | 179 | return definitions 180 | -------------------------------------------------------------------------------- /protocol/bf-runner.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import sys 4 | from enum import Enum 5 | from shlex import split 6 | 7 | from ann_benchmarks.data import type_info 8 | from ann_benchmarks.distance import metrics 9 | from ann_benchmarks.algorithms.bruteforce import BruteForce, BruteForceBLAS 10 | 11 | 12 | class QueryMode(Enum): 13 | NORMAL = 0, 14 | PREPARED = 1, 15 | BATCH = 2 16 | 17 | 18 | __true_print = print 19 | 20 | 21 | def print(*args, **kwargs): 22 | __true_print(*args, **kwargs) 23 | sys.stdout.flush() 24 | 25 | 26 | def next_line(): 27 | for line in iter(sys.stdin.readline, ''): 28 | yield split(line.strip()) 29 | 30 | 31 | if __name__ == '__main__': 32 | point_type = None 33 | distance = None 34 | query_mode = QueryMode.NORMAL 35 | fast = False 36 | query_parameters = False 37 | # Configuration mode 38 | for line in next_line(): 39 | if not line: 40 | break 41 | elif len(line) == 2: 42 | var, val = line[0], line[1] 43 | if var == "point-type": 44 | if val in type_info: 45 | point_type = type_info[val] 46 | print("epbprtv0 ok") 47 | else: 48 | print("epbprtv0 fail") 49 | elif var == "distance": 50 | if val in metrics: 51 | distance = val 52 | print("epbprtv0 ok") 53 | else: 54 | print("epbprtv0 fail") 55 | elif var == "fast": 56 | fast = (val == "1") 57 | print("epbprtv0 ok") 58 | else: 59 | print("epbprtv0 fail") 60 | elif len(line) == 3 and line[0] == "frontend": 61 | var, val = line[1], line[2] 62 | if var == "prepared-queries": 63 | query_mode = \ 64 | QueryMode.PREPARED if val == "1" else QueryMode.NORMAL 65 | print("epbprtv0 ok") 66 | elif var == "batch-queries": 67 | query_mode = \ 68 | QueryMode.BATCH if val == "1" else QueryMode.NORMAL 69 | print("epbprtv0 ok") 70 | elif var == "query-parameters": 71 | query_parameters = (val == "1") 72 | print("epbprtv0 ok") 73 | else: 74 | print("epbprtv0 fail") 75 | else: 76 | print("epbprtv0 fail") 77 | if point_type and distance: 78 | print("epbprtv0 ok") 79 | else: 80 | print("epbprtv0 fail") 81 | sys.exit(1) 82 | 83 | obj = None 84 | if not fast: 85 | obj = BruteForce(distance) 86 | else: 87 | obj = BruteForceBLAS(distance) 88 | 89 | parser = point_type["parse_entry"] 90 | # Training mode 91 | points = [] 92 | for line 
in next_line(): 93 | if not line: 94 | break 95 | elif len(line) == 1: 96 | point = line[0] 97 | try: 98 | parsed = parser(point) 99 | print("epbprtv0 ok %d" % len(points)) 100 | points.append(parsed) 101 | except ValueError: 102 | print("epbprtv0 fail") 103 | else: 104 | print("epbprtv0 fail len %d" % len(line)) 105 | if "finish_entries" in point_type: 106 | points = point_type["finish_entries"](points) 107 | obj.fit(points) 108 | print("epbprtv0 ok %d" % len(points)) 109 | 110 | def _query_parameters(line): 111 | if hasattr(obj, "set_query_arguments"): 112 | try: 113 | obj.set_query_arguments(*line[1:-1]) 114 | print("epbprtv0 ok") 115 | except TypeError: 116 | print("epbprtv0 fail") 117 | else: 118 | print("epbprtv0 fail") 119 | 120 | if query_mode == QueryMode.NORMAL: 121 | # Query mode 122 | for line in next_line(): 123 | if not line: 124 | break 125 | elif query_parameters and line[0] == "query-params" \ 126 | and line[-1] == "set": 127 | _query_parameters(line) 128 | elif len(line) == 2: 129 | try: 130 | query_point, k = line[0], int(line[1]) 131 | parsed = parser(query_point) 132 | results = obj.query(parsed, k) 133 | if results: 134 | print("epbprtv0 ok %d" % len(results)) 135 | for index in results: 136 | print("epbprtv0 %d" % index) 137 | else: 138 | print("epbprtv0 fail") 139 | except ValueError: 140 | print("epbprtv0 fail") 141 | else: 142 | print("epbprtv0 fail") 143 | elif query_mode == QueryMode.PREPARED: 144 | # Prepared query mode 145 | parsed = None 146 | k = None 147 | for line in next_line(): 148 | if not line: 149 | break 150 | elif query_parameters and line[0] == "query-params" \ 151 | and line[-1] == "set": 152 | _query_parameters(line) 153 | elif line == ["query"]: 154 | if parsed and k: 155 | results = obj.query(parsed, k) 156 | if results: 157 | print("epbprtv0 ok %d" % len(results)) 158 | for index in results: 159 | print("epbprtv0 %d" % index) 160 | else: 161 | print("epbprtv0 fail") 162 | else: 163 | print("epbprtv0 fail") 164 | elif len(line) == 2: 165 | try: 166 | parsed, k = parser(line[0]), int(line[1]) 167 | print("epbprtv0 ok prepared true") 168 | except ValueError: 169 | print("epbprtv0 fail") 170 | else: 171 | print("epbprtv0 fail") 172 | elif query_mode == QueryMode.BATCH: 173 | # Batch query mode 174 | parsed = None 175 | k = None 176 | for line in next_line(): 177 | if not line: 178 | break 179 | elif query_parameters and line[0] == "query-params" \ 180 | and line[-1] == "set": 181 | _query_parameters(line) 182 | elif line == ["query"]: 183 | if parsed and k: 184 | results = obj.batch_query(parsed, k) 185 | print("epbprtv0 ok") 186 | for result in obj.get_batch_results(): 187 | if result: 188 | print("epbprtv0 ok %d" % len(result)) 189 | for index in result: 190 | print("epbprtv0 %d" % index) 191 | else: 192 | print("epbprtv0 fail") 193 | else: 194 | print("epbprtv0 fail") 195 | elif len(line) > 1: 196 | try: 197 | parsed, k = list(map(parser, line[0:-1])), int(line[-1]) 198 | print("epbprtv0 ok") 199 | except ValueError as e: 200 | print("epbprtv0 fail %s" % e) 201 | else: 202 | print("epbprtv0 fail") 203 | pass 204 | print("epbprtv0 ok") 205 | -------------------------------------------------------------------------------- /protocol/specification.md: -------------------------------------------------------------------------------- 1 | This document specifies a simple text-based protocol that can be used to benchmark algorithms that don't have a Python wrapper.
A program that implements the algorithm side of this specification will be referred to in the rest of this document as a "front-end". 2 | 3 | This protocol is line-oriented; both sides should configure their input and output streams to be line-buffered. Front-ends receive messages by reading lines from standard input and send messages by writing lines to standard output. 4 | 5 | ## Modes 6 | 7 | A front-end begins in configuration mode. When configuration is complete, it transitions into training mode; when training data has been supplied, into query mode; and, when no more queries remain, it terminates. It isn't possible to return from one mode to an earlier mode without restarting the front-end. 8 | 9 | A front-end reads lines from standard input, tokenises them, and interprets them according to its current mode; responses are written as lines to standard output. To enable protocol responses to be distinguished from other messages that may appear on standard output, the first token of a line containing a response will always be `epbprtv0`; the second will be `ok` when a command succeeds, potentially followed by other tokens, and `fail` when it doesn't. 10 | 11 | (The obscure token `epbprtv0` is intended to uniquely identify this protocol, and is meant to suggest something like "**e**xternal **p**rogram **b**enchmarking **pr**o**t**ocol, **v**ersion **0**".) 12 | 13 | A front-end may choose to include extra tokens in its responses after the tokens required by this specification to communicate more information back to the caller. 14 | 15 | ## Tokenisation 16 | 17 | Both the front-end and `ann-benchmarks` perform *tokenisation* on the lines of text they send and receive. The rules for tokenisation are as follows: 18 | 19 | * A token is a sequence of characters separated by one or more whitespace characters. 20 | 21 | Input | Token 1 | Token 2 | Token 3 22 | ----- | ------- | ------- | ------- 23 | abc | abc | | 24 | a bc | a | bc | 25 | a bc | a | bc | 26 | a b c | a | b | c 27 | 28 | * A sequence surrounded by single quote marks will be treated as part of a token, even if it contains whitespace or doesn't contain any other characters. 29 | 30 | Input | Token 1 | Token 2 | Token 3 31 | ----- | ------- | ------- | ------- 32 | 'a b c' | a b c | | 33 | 'a b c'd | a b cd | | 34 | a '' b | a | *empty string* | b 35 | 36 | * A sequence surrounded by double quote marks will be treated as part of a token, even if it contains whitespace or doesn't contain any other characters. 37 | 38 | Input | Token 1 | Token 2 | Token 3 39 | ----- | ------- | ------- | ------- 40 | "a b c" | a b c | | 41 | "a b c"d | a b cd | | 42 | a "" b | a | *empty string* | b 43 | 44 | * Outside of a quoted sequence, preceding a character with a backslash causes any special significance it may have to be ignored; the character is then said to have been "escaped". 
45 | 46 | Input | Token 1 | Token 2 | Token 3 47 | ----- | ------- | ------- | ------- 48 | \a \b \c | a | b | c 49 | 50 | An escaped whitespace character doesn't separate tokens: 51 | 52 | Input | Token 1 | Token 2 53 | ----- | ------- | ------- 54 | a b\ c | a | b c | 55 | "a b c"\ d | a b c d | 56 | 57 | An escaped quote mark doesn't begin a sequence: 58 | 59 | Input | Token 1 | Token 2 | Token 3 60 | ----- | ------- | ------- | ------- 61 | \'a b c\' | a | b | c | 62 | \"a b c\" | a | b | c | 63 | 64 | An escaped backslash doesn't escape the subsequent character: 65 | 66 | Input | Token 1 | Token 2 67 | ----- | ------- | ------- 68 | a\\\\"b c" d | a\b c | d 69 | 70 | * In sequences begun by a double quote mark, only double quote marks and backslashes (and, for compatibility reasons, dollar signs) may be escaped; the backslash otherwise has no special significance. 71 | 72 | Input | Token 1 | Token 2 | Token 3 73 | ----- | ------- | ------- | ------- 74 | "\a \b" \c | \a \b | c | 75 | "\\\\ \\" \\$ a" "\b" c | \ " $ a | \b | c 76 | 77 | * In sequences begun by a single quote mark, a backslash has no special significance. 78 | 79 | Input | Token 1 | Token 2 80 | ----- | ------- | ------- 81 | 'a b' c | a b | c 82 | 'a b\\' c | a b\ | c 83 | 84 | Apart from the fact that newline characters can't be escaped, these rules should match the tokenisation rules of the POSIX shell. 85 | 86 | ## Commands 87 | 88 | Commands are sent to the front-end by `ann-benchmarks`. Each command consists of a single line of text; the front-end replies with one or more lines of text. Front-ends can't initiate communication; they can only reply to commands. 89 | 90 | This section specifies these commands, along with the possible responses a front-end might send. 91 | 92 | If a front-end receives a command that it doesn't understand in the current mode (or at all), it should respond with `epbprtv0 fail` and continue processing commands. 93 | 94 | ### Configuration mode 95 | 96 | In configuration mode, front-ends should respond to three different kinds of command: 97 | 98 | #### `VAR VAL` (two tokens) 99 | 100 | Set the value of the algorithm configuration option `VAR` to `VAL`. 101 | 102 | Responses: 103 | 104 | * `epbprtv0 ok` 105 | 106 | The value specified for the algorithm configuration option `VAR` was acceptable, and the option has been set. 107 | 108 | * `epbprtv0 fail` 109 | 110 | The value specified for the algorithm configuration option `VAR` wasn't acceptable. No change has been made to the value of this option. 111 | 112 | #### `frontend VAR VAL` (three tokens) 113 | 114 | Set the value of the front-end configuration option `VAR` to `VAL`. Front-end configuration options may cause the front-end to behave in a manner other than that described in this specification. 115 | 116 | Responses: 117 | 118 | * `epbprtv0 ok` 119 | 120 | The value specified for the front-end configuration option `VAR` was acceptable, and the option has been set. 121 | 122 | * `epbprtv0 fail` 123 | 124 | The value specified for the front-end configuration option `VAR` wasn't acceptable. No change has been made to the value of this option. 125 | 126 | #### *empty line* (zero tokens) 127 | 128 | Finish configuration mode and enter training mode. 129 | 130 | Responses: 131 | 132 | * `epbprtv0 ok` 133 | 134 | Training mode has been entered. 135 | 136 | * `epbprtv0 fail` 137 | 138 | One or more configuration options required by the algorithm weren't specified, and so the query process has terminated. 
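As an illustration, a complete configuration-mode exchange with the bundled `protocol/bf-runner.py` front-end might look like the following (`>` marks lines sent by `ann-benchmarks`, `<` lines sent by the front-end; the markers are annotations, not part of the protocol, and the option names are simply the ones that particular front-end accepts):

```
> point-type float
< epbprtv0 ok
> distance euclidean
< epbprtv0 ok
> fast 1
< epbprtv0 ok
>
< epbprtv0 ok
```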
139 | 140 | ### Training mode 141 | 142 | In training mode, front-ends should respond to two different kinds of command: 143 | 144 | #### `ENTRY` (one token) 145 | 146 | Interpret `ENTRY` as an item of training data. 147 | 148 | Responses: 149 | 150 | * `epbprtv0 ok` 151 | 152 | `ENTRY` was added as the next item of training data. The index values returned in query mode refer to the first item added as `0`, the second as `1`, and so on. 153 | 154 | * `epbprtv0 fail` 155 | 156 | Either `ENTRY` couldn't be interpreted as an item of training data, or the training data wasn't accepted. 157 | 158 | #### *empty line* (zero tokens) 159 | 160 | Finish training mode and enter query mode. 161 | 162 | Responses: 163 | 164 | * `epbprtv0 ok COUNT1 [fail COUNT2]` 165 | 166 | `COUNT1` (potentially zero) entries were successfully interpreted and added to the data structure. (`COUNT2` entries couldn't be interpreted or couldn't be added for other reasons.) 167 | 168 | ### Query mode 169 | 170 | In query mode, front-ends should respond to two different kinds of command: 171 | 172 | #### `ENTRY N` (two tokens) 173 | 174 | Return the indices of at most `N` (greater than or equal to 1) close matches for `ENTRY`. 175 | 176 | Responses: 177 | 178 | * `epbprtv0 ok R` 179 | 180 | `R` (greater than zero and less than or equal to `N`) close matches were found. Each of the next `R` lines, when tokenised, will consist of the token `epbprtv0` followed by a token specifying the index of a close match. (The first line should identify the *closest* close match, and the `R`-th should identify the furthest away.) 181 | 182 | * `epbprtv0 fail` 183 | 184 | No close matches were found. 185 | 186 | #### *empty line* (zero tokens) 187 | 188 | Finish query mode and terminate the front-end. 189 | 190 | Responses: 191 | 192 | * `epbprtv0 ok` 193 | 194 | The front-end has terminated. 
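Putting the query-mode commands together, a session that asks for the two closest matches to one query entry and then terminates might look like this (`>`/`<` annotations as above; the returned indices are invented for illustration):

```
> '0.5 -0.25 1.0' 2
< epbprtv0 ok 2
< epbprtv0 17
< epbprtv0 4
>
< epbprtv0 ok
```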
195 | -------------------------------------------------------------------------------- /ann_benchmarks/algorithms/subprocess.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from os.path import basename 3 | import shlex 4 | from types import MethodType 5 | import psutil 6 | import subprocess 7 | from ann_benchmarks.data import \ 8 | bit_unparse_entry, int_unparse_entry, float_unparse_entry 9 | from ann_benchmarks.algorithms.base import BaseANN 10 | 11 | 12 | class SubprocessStoppedError(Exception): 13 | def __init__(self, code): 14 | super(Exception, self).__init__(code) 15 | self.code = code 16 | 17 | 18 | class Subprocess(BaseANN): 19 | def _raw_line(self): 20 | return shlex.split( 21 | self._get_program_handle().stdout.readline().strip()) 22 | 23 | def _line(self): 24 | line = self._raw_line() 25 | # print("<- %s" % (" ".join(line))) 26 | while len(line) < 1 or line[0] != "epbprtv0": 27 | line = self._raw_line() 28 | return line[1:] 29 | 30 | @staticmethod 31 | def _quote(token): 32 | return "'" + str(token).replace("'", "'\\'") + "'" 33 | 34 | def _write(self, string): 35 | # print("-> %s" % string) 36 | self._get_program_handle().stdin.write(string + "\n") 37 | 38 | # Called immediately before transitioning from query mode to training mode 39 | def _configuration_hook(self): 40 | pass 41 | 42 | def _get_program_handle(self): 43 | if self._program: 44 | self._program.poll() 45 | if self._program.returncode: 46 | raise SubprocessStoppedError(self._program.returncode) 47 | else: 48 | self._program = subprocess.Popen( 49 | self._args, 50 | bufsize=1, # line buffering 51 | stdin=subprocess.PIPE, 52 | stdout=subprocess.PIPE, 53 | universal_newlines=True) 54 | 55 | for key, value in iter(self._params.items()): 56 | self._write("%s %s" % 57 | (Subprocess._quote(key), Subprocess._quote(value))) 58 | assert self._line()[0] == "ok", """\ 59 | assigning value '%s' to option '%s' failed""" % (value, key) 60 | self._configuration_hook() 61 | 62 | self._write("") 63 | assert self._line()[0] == "ok", """\ 64 | transitioning to training mode failed""" 65 | return self._program 66 | 67 | def __init__(self, args, encoder, params): 68 | self.name = "Subprocess(program = %s, %s)" % \ 69 | (basename(args[0]), str(params)) 70 | self._program = None 71 | self._args = args 72 | self._encoder = encoder 73 | self._params = params 74 | 75 | def get_memory_usage(self): 76 | if not self._program: 77 | self._get_program_handle() 78 | return psutil.Process(pid=self._program.pid).memory_info().rss / 1024 79 | 80 | def fit(self, X): 81 | for entry in X: 82 | d = Subprocess._quote(self._encoder(entry)) 83 | self._write(d) 84 | assert self._line()[0] == "ok", """\ 85 | encoded training point '%s' was rejected""" % d 86 | self._write("") 87 | assert self._line()[0] == "ok", """\ 88 | transitioning to query mode failed""" 89 | 90 | def query(self, v, n): 91 | d = Subprocess._quote(self._encoder(v)) 92 | self._write("%s %d" % (d, n)) 93 | return self._handle_query_response() 94 | 95 | def _handle_query_response(self): 96 | status = self._line() 97 | if status[0] == "ok": 98 | count = int(status[1]) 99 | return self._collect_query_response_lines(count) 100 | else: 101 | assert status[0] == "fail", """\ 102 | query neither succeeded nor failed""" 103 | return [] 104 | 105 | def _collect_query_response_lines(self, count): 106 | results = [] 107 | i = 0 108 | while i < count: 109 | line = self._line() 110 | results.append(int(line[0])) 111 | i 
+= 1 112 | return results 113 | 114 | def done(self): 115 | if self._program: 116 | self._program.poll() 117 | if not self._program.returncode: 118 | self._program.terminate() 119 | 120 | 121 | class PreparedSubprocess(Subprocess): 122 | def __init__(self, args, encoder, params): 123 | super(PreparedSubprocess, self).__init__(args, encoder, params) 124 | self._result_count = None 125 | 126 | def _configuration_hook(self): 127 | self._write("frontend prepared-queries 1") 128 | assert self._line()[0] == "ok", """\ 129 | enabling prepared queries mode failed""" 130 | 131 | def query(self, v, n): 132 | self.prepare_query(v, n) 133 | self.run_prepared_query() 134 | return self.get_prepared_query_results() 135 | 136 | def prepare_query(self, v, n): 137 | d = Subprocess._quote(self._encoder(v)) 138 | self._write("%s %d" % (d, n)) 139 | assert self._line()[0] == "ok", """\ 140 | preparing the query '%s' failed""" % d 141 | 142 | def run_prepared_query(self): 143 | self._write("query") 144 | status = self._line() 145 | if status[0] == "ok": 146 | self._result_count = int(status[1]) 147 | else: 148 | assert status[0] == "fail", """\ 149 | query neither succeeded nor failed""" 150 | self._result_count = 0 151 | 152 | def get_prepared_query_results(self): 153 | if self._result_count: 154 | try: 155 | return self._collect_query_response_lines(self._result_count) 156 | finally: 157 | self._result_count = 0 158 | else: 159 | return [] 160 | 161 | 162 | class BatchSubprocess(Subprocess): 163 | def __init__(self, args, encoder, params): 164 | super(BatchSubprocess, self).__init__(args, encoder, params) 165 | self._qp_count = None 166 | 167 | def _configuration_hook(self): 168 | self._write("frontend batch-queries 1") 169 | assert self._line()[0] == "ok", """\ 170 | enabling batch queries mode failed""" 171 | 172 | def query(self, v, n): 173 | self.prepare_batch_query([v], n) 174 | self.run_batch_query() 175 | return self.get_batch_results()[0] 176 | 177 | def prepare_batch_query(self, X, n): 178 | d = " ".join(map(lambda p: Subprocess._quote(self._encoder(p)), X)) 179 | self._qp_count = len(X) 180 | self._write("%s %d" % (d, n)) 181 | assert self._line()[0] == "ok", """\ 182 | preparing the batch query '%s' failed""" % d 183 | 184 | def run_batch_query(self): 185 | self._write("query") 186 | status = self._line() 187 | assert status[0] == "ok", """\ 188 | batch query failed completely""" 189 | 190 | def get_batch_results(self): 191 | results = [] 192 | i = 0 193 | while i < self._qp_count: 194 | # print("%d/%d" % (i, self._qp_count)) 195 | status = self._line() 196 | if status[0] == "ok": 197 | rc = int(status[1]) 198 | results.append(self._collect_query_response_lines(rc)) 199 | else: 200 | results.append([]) 201 | i += 1 202 | return results 203 | 204 | 205 | def BitSubprocess(args, params): 206 | return Subprocess(args, bit_unparse_entry, params) 207 | 208 | 209 | def BitSubprocessPrepared(args, params): 210 | return PreparedSubprocess(args, bit_unparse_entry, params) 211 | 212 | 213 | def FloatSubprocess(args, params): 214 | return Subprocess(args, float_unparse_entry, params) 215 | 216 | 217 | def FloatSubprocessPrepared(args, params): 218 | return PreparedSubprocess(args, float_unparse_entry, params) 219 | 220 | 221 | def FloatSubprocessBatch(args, params): 222 | return BatchSubprocess(args, float_unparse_entry, params) 223 | 224 | 225 | def IntSubprocess(args, params): 226 | return Subprocess(args, int_unparse_entry, params) 227 | 228 | 229 | def QueryParamWrapper(constructor, args, params): 230 | r 
= constructor(args, params) 231 | 232 | def _do(self, original=r._configuration_hook): 233 | original() 234 | self._write("frontend query-parameters 1") 235 | assert self._line()[0] == "ok", """\ 236 | enabling query parameter support failed""" 237 | r._configuration_hook = MethodType(_do, r) 238 | 239 | def _sqa(self, *args): 240 | self._write("query-params %s set" % 241 | (" ".join(map(Subprocess._quote, args)))) 242 | assert self._line()[0] == "ok", """\ 243 | reconfiguring query parameters failed""" 244 | print(args) 245 | r.set_query_arguments = MethodType(_sqa, r) 246 | return r 247 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Benchmarking nearest neighbors 2 | ============================== 3 | 4 | [![travis badge](https://img.shields.io/travis/erikbern/ann-benchmarks/master.svg?style=flat)](https://travis-ci.org/erikbern/ann-benchmarks) 5 | 6 | Fast nearest-neighbor search in high-dimensional spaces is an increasingly important problem, but so far there have not been many empirical attempts at comparing approaches in an objective way. 7 | 8 | This project contains some tools to benchmark various implementations of approximate nearest neighbor (ANN) search for different metrics. We have pregenerated datasets (in HDF5 format) and Docker containers for each algorithm. There's a [test suite](https://travis-ci.org/erikbern/ann-benchmarks) that makes sure every algorithm works. 9 | 10 | Evaluated 11 | ========= 12 | 13 | * [Annoy](https://github.com/spotify/annoy) 14 | * [FLANN](http://www.cs.ubc.ca/research/flann/) 15 | * [scikit-learn](http://scikit-learn.org/stable/modules/neighbors.html): LSHForest, KDTree, BallTree 16 | * [PANNS](https://github.com/ryanrhymes/panns) 17 | * [NearPy](http://pixelogik.github.io/NearPy/) 18 | * [KGraph](https://github.com/aaalgo/kgraph) 19 | * [NMSLIB (Non-Metric Space Library)](https://github.com/nmslib/nmslib): SWGraph, HNSW, BallTree, MPLSH 20 | * [hnswlib (a part of nmslib project)](https://github.com/nmslib/hnsw) 21 | * [RPForest](https://github.com/lyst/rpforest) 22 | * [FAISS](https://github.com/facebookresearch/faiss.git) 23 | * [DolphinnPy](https://github.com/ipsarros/DolphinnPy) 24 | * [Datasketch](https://github.com/ekzhu/datasketch) 25 | * [PyNNDescent](https://github.com/lmcinnes/pynndescent) 26 | * [MRPT](https://github.com/teemupitkanen/mrpt) 27 | * [NGT](https://github.com/yahoojapan/NGT): ONNG, PANNG 28 | * [SPTAG](https://github.com/microsoft/SPTAG) 29 | * [PUFFINN](https://github.com/puffinn/puffinn) 30 | * [N2](https://github.com/kakao/n2) 31 | * [ScaNN](https://github.com/google-research/google-research/tree/master/scann) 32 | * [Elastiknn](https://github.com/alexklibisz/elastiknn) 33 | 34 | Data sets 35 | ========= 36 | 37 | We have a number of precomputed data sets for this. All data sets are pre-split into train/test and come with ground truth data in the form of the top 100 neighbors.
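Each file can also be inspected directly with `h5py`; a minimal sketch (the file name is an example, and the `train`, `test` and `distances` keys plus the `distance` attribute are the same ones the benchmark runner and `plot.py` read):

```python
import h5py

with h5py.File("glove-100-angular.hdf5", "r") as f:
    print(f.attrs["distance"])   # e.g. "angular"
    print(f["train"].shape)      # (n_train, dimension) vectors to index
    print(f["test"].shape)       # (n_test, dimension) query vectors
    print(f["distances"].shape)  # true distances to each query's neighbors
```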
We store them in a HDF5 format: 38 | 39 | | Dataset | Dimensions | Train size | Test size | Neighbors | Distance | Download | 40 | | ----------------------------------------------------------------- | ---------: | ---------: | --------: | --------: | --------- | -------------------------------------------------------------------------- | 41 | | [DEEP1B](http://sites.skoltech.ru/compvision/noimi/) | 96 | 9,990,000 | 10,000 | 100 | Angular | [HDF5](http://ann-benchmarks.com/deep-image-96-angular.hdf5) (3.6GB) 42 | | [Fashion-MNIST](https://github.com/zalandoresearch/fashion-mnist) | 784 | 60,000 | 10,000 | 100 | Euclidean | [HDF5](http://ann-benchmarks.com/fashion-mnist-784-euclidean.hdf5) (217MB) | 43 | | [GIST](http://corpus-texmex.irisa.fr/) | 960 | 1,000,000 | 1,000 | 100 | Euclidean | [HDF5](http://ann-benchmarks.com/gist-960-euclidean.hdf5) (3.6GB) | 44 | | [GloVe](http://nlp.stanford.edu/projects/glove/) | 25 | 1,183,514 | 10,000 | 100 | Angular | [HDF5](http://ann-benchmarks.com/glove-25-angular.hdf5) (121MB) | 45 | | GloVe | 50 | 1,183,514 | 10,000 | 100 | Angular | [HDF5](http://ann-benchmarks.com/glove-50-angular.hdf5) (235MB) | 46 | | GloVe | 100 | 1,183,514 | 10,000 | 100 | Angular | [HDF5](http://ann-benchmarks.com/glove-100-angular.hdf5) (463MB) | 47 | | GloVe | 200 | 1,183,514 | 10,000 | 100 | Angular | [HDF5](http://ann-benchmarks.com/glove-200-angular.hdf5) (918MB) | 48 | | [Kosarak](http://fimi.uantwerpen.be/data/) | 27983 | 74,962 | 500 | 100 | Jaccard | [HDF5](http://ann-benchmarks.com/kosarak-jaccard.hdf5) (2.0GB) | 49 | | [MNIST](http://yann.lecun.com/exdb/mnist/) | 784 | 60,000 | 10,000 | 100 | Euclidean | [HDF5](http://ann-benchmarks.com/mnist-784-euclidean.hdf5) (217MB) | 50 | | [NYTimes](https://archive.ics.uci.edu/ml/datasets/bag+of+words) | 256 | 290,000 | 10,000 | 100 | Angular | [HDF5](http://ann-benchmarks.com/nytimes-256-angular.hdf5) (301MB) | 51 | | [SIFT](https://corpus-texmex.irisa.fr/) | 128 | 1,000,000 | 10,000 | 100 | Euclidean | [HDF5](http://ann-benchmarks.com/sift-128-euclidean.hdf5) (501MB) | 52 | | [Last.fm](https://github.com/erikbern/ann-benchmarks/pull/91) | 65 | 292,385 | 50,000 | 100 | Angular | [HDF5](http://ann-benchmarks.com/lastfm-64-dot.hdf5) (135MB) | 53 | 54 | Results 55 | ======= 56 | 57 | These are all as of 2020-07-12, running all benchmarks on a c5.4xlarge machine on AWS with `--parallelism` set to 3: 58 | 59 | glove-100-angular 60 | ----------------- 61 | 62 | ![glove-100-angular](https://raw.github.com/erikbern/ann-benchmarks/master/results/glove-100-angular.png) 63 | 64 | sift-128-euclidean 65 | ------------------ 66 | 67 | ![glove-100-angular](https://raw.github.com/erikbern/ann-benchmarks/master/results/sift-128-euclidean.png) 68 | 69 | fashion-mnist-784-euclidean 70 | --------------------------- 71 | 72 | ![fashion-mnist-784-euclidean](https://raw.github.com/erikbern/ann-benchmarks/master/results/fashion-mnist-784-euclidean.png) 73 | 74 | lastfm-64-dot 75 | ------------------ 76 | 77 | ![lastfm-64-dot](https://raw.github.com/erikbern/ann-benchmarks/master/results/lastfm-64-dot.png) 78 | 79 | nytimes-256-angular 80 | ------------------- 81 | 82 | ![nytimes-256-angular](https://raw.github.com/erikbern/ann-benchmarks/master/results/nytimes-256-angular.png) 83 | 84 | glove-25-angular 85 | ---------------- 86 | 87 | ![glove-25-angular](https://raw.github.com/erikbern/ann-benchmarks/master/results/glove-25-angular.png) 88 | 89 | Install 90 | ======= 91 | 92 | The only prerequisite is Python (tested with 3.6) and Docker. 
93 | 94 | 1. Clone the repo. 95 | 2. Run `pip install -r requirements.txt`. 96 | 3. Run `python install.py` to build all the libraries inside Docker containers (this can take a while, like 10-30 minutes). 97 | 98 | Running 99 | ======= 100 | 101 | 1. Run `python run.py` (this can take an extremely long time, potentially days). 102 | 2. Run `python plot.py` or `python create_website.py` to plot results. 103 | 104 | You can customize the algorithms and datasets if you want to: 105 | 106 | * Check that `algos.yaml` contains the parameter settings that you want to test. 107 | * To run experiments on a particular dataset, invoke e.g. `python run.py --dataset glove-100-angular`. See `python run.py --help` for more information on possible settings. Note that experiments can take a long time. 108 | * To process the results, either use `python plot.py --dataset glove-100-angular` or `python create_website.py`. An example call: `python create_website.py --plottype recall/time --latex --scatter --outputdir website/`. 109 | 110 | Including your algorithm 111 | ======================== 112 | 113 | 1. Add your algorithm into `ann_benchmarks/algorithms` by providing a small Python wrapper. 114 | 2. Add a Dockerfile for it in `install/`. 115 | 3. Add it to `algos.yaml`. 116 | 4. Add it to `.travis.yml`. 117 | 118 | Principles 119 | ========== 120 | 121 | * Everyone is welcome to submit pull requests with tweaks and changes to how each library is being used. 122 | * In particular: if you are the author of any of these libraries, and you think the benchmark can be improved, consider making the improvement and submitting a pull request. 123 | * This is meant to be an ongoing project and to represent the current state. 124 | * Make everything easy to replicate, including installing and preparing the datasets. 125 | * Try many different values of parameters for each library and ignore the points that are not on the precision-performance frontier. 126 | * Use high-dimensional datasets with approximately 100-1000 dimensions. This is challenging but also realistic. Not more than 1000 dimensions, because those problems should probably be solved by doing dimensionality reduction separately. 127 | * Single queries are used by default. ANN-Benchmarks enforces that only one CPU is saturated during experimentation, i.e., no multi-threading. A batch mode is available that provides all queries to the implementations at once. Add the flag `--batch` to `run.py` and `plot.py` to enable batch mode. 128 | * Avoid extremely costly index building (more than several hours). 129 | * Focus on datasets that fit in RAM. Out-of-core ANN could be the topic of a later comparison. 130 | * We mainly support CPU-based ANN algorithms. GPU support exists for FAISS, but it has to be compiled locally with GPU support, and experiments must be run using the flags `--local --batch`. 131 | * Do a proper train/test split of index data and query points. 132 | * Note that set similarity was supported in the past; it will hopefully be added back soon. 133 | 134 | 135 | Authors 136 | ======= 137 | 138 | Built by [Erik Bernhardsson](https://erikbern.com) with significant contributions from [Martin Aumüller](http://itu.dk/people/maau/) and [Alexander Faithfull](https://github.com/ale-f). 139 | 140 | Related Publication 141 | ================== 142 | 143 | The following publication details design principles behind the benchmarking framework: 144 | 145 | - M. Aumüller, E. Bernhardsson, A.
Faithfull: 146 | [ANN-Benchmarks: A Benchmarking Tool for Approximate Nearest Neighbor Algorithms](https://arxiv.org/abs/1807.05614). Information Systems 2019. DOI: [10.1016/j.is.2019.02.006](https://doi.org/10.1016/j.is.2019.02.006) 147 | -------------------------------------------------------------------------------- /ann_benchmarks/main.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import argparse 3 | import logging 4 | import logging.config 5 | 6 | import docker 7 | import multiprocessing.pool 8 | import os 9 | import psutil 10 | import random 11 | import shutil 12 | import sys 13 | import traceback 14 | 15 | from ann_benchmarks.datasets import get_dataset, DATASETS 16 | from ann_benchmarks.constants import INDEX_DIR 17 | from ann_benchmarks.algorithms.definitions import (get_definitions, 18 | list_algorithms, 19 | algorithm_status, 20 | InstantiationStatus) 21 | from ann_benchmarks.results import get_result_filename 22 | from ann_benchmarks.runner import run, run_docker 23 | 24 | 25 | def positive_int(s): 26 | i = None 27 | try: 28 | i = int(s) 29 | except ValueError: 30 | pass 31 | if not i or i < 1: 32 | raise argparse.ArgumentTypeError("%r is not a positive integer" % s) 33 | return i 34 | 35 | 36 | def run_worker(cpu, args, queue): 37 | while not queue.empty(): 38 | definition = queue.get() 39 | if args.local: 40 | run(definition, args.dataset, args.count, args.runs, args.batch) 41 | else: 42 | memory_margin = 500e6 # reserve some extra memory for misc stuff 43 | mem_limit = int((psutil.virtual_memory().available - memory_margin) / args.parallelism) 44 | run_docker(definition, args.dataset, args.count, 45 | args.runs, args.timeout, args.batch, str(cpu), mem_limit) 46 | 47 | 48 | def main(): 49 | parser = argparse.ArgumentParser( 50 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 51 | parser.add_argument( 52 | '--dataset', 53 | metavar='NAME', 54 | help='the dataset to load training points from', 55 | default='glove-100-angular', 56 | choices=DATASETS.keys()) 57 | parser.add_argument( 58 | "-k", "--count", 59 | default=10, 60 | type=positive_int, 61 | help="the number of near neighbours to search for") 62 | parser.add_argument( 63 | '--definitions', 64 | metavar='FILE', 65 | help='load algorithm definitions from FILE', 66 | default='algos.yaml') 67 | parser.add_argument( 68 | '--algorithm', 69 | metavar='NAME', 70 | help='run only the named algorithm', 71 | default=None) 72 | parser.add_argument( 73 | '--docker-tag', 74 | metavar='NAME', 75 | help='run only algorithms in a particular docker image', 76 | default=None) 77 | parser.add_argument( 78 | '--list-algorithms', 79 | help='print the names of all known algorithms and exit', 80 | action='store_true') 81 | parser.add_argument( 82 | '--force', 83 | help='re-run algorithms even if their results already exist', 84 | action='store_true') 85 | parser.add_argument( 86 | '--runs', 87 | metavar='COUNT', 88 | type=positive_int, 89 | help='run each algorithm instance %(metavar)s times and use only' 90 | ' the best result', 91 | default=5) 92 | parser.add_argument( 93 | '--timeout', 94 | type=int, 95 | help='Timeout (in seconds) for each individual algorithm run, or -1' 96 | 'if no timeout should be set', 97 | default=2 * 3600) 98 | parser.add_argument( 99 | '--local', 100 | action='store_true', 101 | help='If set, then will run everything locally (inside the same ' 102 | 'process) rather than using Docker') 103 | parser.add_argument( 104 | 
'--batch', 105 | action='store_true', 106 | help='If set, algorithms get all queries at once') 107 | parser.add_argument( 108 | '--max-n-algorithms', 109 | type=int, 110 | help='Max number of algorithms to run (just used for testing)', 111 | default=-1) 112 | parser.add_argument( 113 | '--run-disabled', 114 | help='run algorithms that are disabled in algos.yml', 115 | action='store_true') 116 | parser.add_argument( 117 | '--parallelism', 118 | type=positive_int, 119 | help='Number of Docker containers in parallel', 120 | default=1) 121 | 122 | args = parser.parse_args() 123 | if args.timeout == -1: 124 | args.timeout = None 125 | 126 | if args.list_algorithms: 127 | list_algorithms(args.definitions) 128 | sys.exit(0) 129 | 130 | logging.config.fileConfig("logging.conf") 131 | logger = logging.getLogger("annb") 132 | 133 | # Nmslib specific code 134 | # Remove old indices stored on disk 135 | if os.path.exists(INDEX_DIR): 136 | shutil.rmtree(INDEX_DIR) 137 | 138 | dataset = get_dataset(args.dataset) 139 | dimension = len(dataset['train'][0]) # TODO(erikbern): ugly 140 | point_type = dataset.attrs.get('point_type', 'float') 141 | distance = dataset.attrs['distance'] 142 | definitions = get_definitions( 143 | args.definitions, dimension, point_type, distance, args.count) 144 | 145 | # Filter out, from the loaded definitions, all those query argument groups 146 | # that correspond to experiments that have already been run. (This might 147 | # mean removing a definition altogether, so we can't just use a list 148 | # comprehension.) 149 | filtered_definitions = [] 150 | for definition in definitions: 151 | query_argument_groups = definition.query_argument_groups 152 | if not query_argument_groups: 153 | query_argument_groups = [[]] 154 | not_yet_run = [] 155 | for query_arguments in query_argument_groups: 156 | fn = get_result_filename(args.dataset, 157 | args.count, definition, 158 | query_arguments, args.batch) 159 | if args.force or not os.path.exists(fn): 160 | not_yet_run.append(query_arguments) 161 | if not_yet_run: 162 | if definition.query_argument_groups: 163 | definition = definition._replace( 164 | query_argument_groups=not_yet_run) 165 | filtered_definitions.append(definition) 166 | definitions = filtered_definitions 167 | 168 | random.shuffle(definitions) 169 | 170 | if args.algorithm: 171 | logger.info(f'running only {args.algorithm}') 172 | definitions = [d for d in definitions if d.algorithm == args.algorithm] 173 | 174 | if not args.local: 175 | # See which Docker images we have available 176 | docker_client = docker.from_env() 177 | docker_tags = set() 178 | for image in docker_client.images.list(): 179 | for tag in image.tags: 180 | tag = tag.split(':')[0] 181 | docker_tags.add(tag) 182 | 183 | if args.docker_tag: 184 | logger.info(f'running only {args.docker_tag}') 185 | definitions = [ 186 | d for d in definitions if d.docker_tag == args.docker_tag] 187 | 188 | if set(d.docker_tag for d in definitions).difference(docker_tags): 189 | logger.info(f'not all docker images available, only: {set(docker_tags)}') 190 | logger.info(f'missing docker images: ' 191 | f'{str(set(d.docker_tag for d in definitions).difference(docker_tags))}') 192 | definitions = [ 193 | d for d in definitions if d.docker_tag in docker_tags] 194 | else: 195 | def _test(df): 196 | status = algorithm_status(df) 197 | # If the module was loaded but doesn't actually have a constructor 198 | # of the right name, then the definition is broken 199 | if status == InstantiationStatus.NO_CONSTRUCTOR: 200 | raise 
Exception("%s.%s(%s): error: the module '%s' does not" 201 | " expose the named constructor" % ( 202 | df.module, df.constructor, 203 | df.arguments, df.module)) 204 | 205 | if status == InstantiationStatus.NO_MODULE: 206 | # If the module couldn't be loaded (presumably because 207 | # of a missing dependency), print a warning and remove 208 | # this definition from the list of things to be run 209 | logging.warning("%s.%s(%s): the module '%s' could not be " 210 | "loaded; skipping" % (df.module, df.constructor, 211 | df.arguments, df.module)) 212 | return False 213 | else: 214 | return True 215 | definitions = [d for d in definitions if _test(d)] 216 | 217 | if not args.run_disabled: 218 | if len([d for d in definitions if d.disabled]): 219 | logger.info(f'Not running disabled algorithms {[d for d in definitions if d.disabled]}') 220 | definitions = [d for d in definitions if not d.disabled] 221 | 222 | if args.max_n_algorithms >= 0: 223 | definitions = definitions[:args.max_n_algorithms] 224 | 225 | if len(definitions) == 0: 226 | raise Exception('Nothing to run') 227 | else: 228 | logger.info(f'Order: {definitions}') 229 | 230 | if args.parallelism > multiprocessing.cpu_count() - 1: 231 | raise Exception('Parallelism larger than %d! (CPU count minus one)' % (multiprocessing.cpu_count() - 1)) 232 | 233 | # Multiprocessing magic to farm this out to all CPUs 234 | queue = multiprocessing.Queue() 235 | for definition in definitions: 236 | queue.put(definition) 237 | workers = [multiprocessing.Process(target=run_worker, args=(i+1, args, queue)) 238 | for i in range(args.parallelism)] 239 | [worker.start() for worker in workers] 240 | [worker.join() for worker in workers] 241 | 242 | # TODO: need to figure out cleanup handling here 243 | -------------------------------------------------------------------------------- /ann_benchmarks/runner.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import logging 4 | import os 5 | import threading 6 | import time 7 | import traceback 8 | 9 | import colors 10 | import docker 11 | import numpy 12 | import psutil 13 | 14 | from ann_benchmarks.algorithms.definitions import (Definition, 15 | instantiate_algorithm, 16 | get_algorithm_name) 17 | from ann_benchmarks.datasets import get_dataset, DATASETS 18 | from ann_benchmarks.distance import metrics, dataset_transform 19 | from ann_benchmarks.results import store_results 20 | 21 | 22 | def run_individual_query(algo, X_train, X_test, distance, count, run_count, 23 | batch): 24 | prepared_queries = \ 25 | (batch and hasattr(algo, "prepare_batch_query")) or \ 26 | ((not batch) and hasattr(algo, "prepare_query")) 27 | 28 | best_search_time = float('inf') 29 | for i in range(run_count): 30 | print('Run %d/%d...' % (i + 1, run_count)) 31 | # a bit dumb but can't be a scalar since of Python's scoping rules 32 | n_items_processed = [0] 33 | 34 | def single_query(v): 35 | if prepared_queries: 36 | algo.prepare_query(v, count) 37 | start = time.time() 38 | algo.run_prepared_query() 39 | total = (time.time() - start) 40 | candidates = algo.get_prepared_query_results() 41 | else: 42 | start = time.time() 43 | candidates = algo.query(v, count) 44 | total = (time.time() - start) 45 | candidates = [(int(idx), float(metrics[distance]['distance'](v, X_train[idx]))) # noqa 46 | for idx in candidates] 47 | n_items_processed[0] += 1 48 | if n_items_processed[0] % 1000 == 0: 49 | print('Processed %d/%d queries...' 
% (n_items_processed[0], len(X_test))) 50 | if len(candidates) > count: 51 | print('warning: algorithm %s returned %d results, but count' 52 | ' is only %d' % (algo, len(candidates), count)) 53 | return (total, candidates) 54 | 55 | def batch_query(X): 56 | if prepared_queries: 57 | algo.prepare_batch_query(X, count) 58 | start = time.time() 59 | algo.run_batch_query() 60 | total = (time.time() - start) 61 | else: 62 | start = time.time() 63 | algo.batch_query(X, count) 64 | total = (time.time() - start) 65 | results = algo.get_batch_results() 66 | candidates = [[(int(idx), float(metrics[distance]['distance'](v, X_train[idx]))) # noqa 67 | for idx in single_results] 68 | for v, single_results in zip(X, results)] 69 | return [(total / float(len(X)), v) for v in candidates] 70 | 71 | if batch: 72 | results = batch_query(X_test) 73 | else: 74 | results = [single_query(x) for x in X_test] 75 | 76 | total_time = sum(time for time, _ in results) 77 | total_candidates = sum(len(candidates) for _, candidates in results) 78 | search_time = total_time / len(X_test) 79 | avg_candidates = total_candidates / len(X_test) 80 | best_search_time = min(best_search_time, search_time) 81 | 82 | verbose = hasattr(algo, "query_verbose") 83 | attrs = { 84 | "batch_mode": batch, 85 | "best_search_time": best_search_time, 86 | "candidates": avg_candidates, 87 | "expect_extra": verbose, 88 | "name": str(algo), 89 | "run_count": run_count, 90 | "distance": distance, 91 | "count": int(count) 92 | } 93 | additional = algo.get_additional() 94 | for k in additional: 95 | attrs[k] = additional[k] 96 | return (attrs, results) 97 | 98 | 99 | def run(definition, dataset, count, run_count, batch): 100 | algo = instantiate_algorithm(definition) 101 | assert not definition.query_argument_groups \ 102 | or hasattr(algo, "set_query_arguments"), """\ 103 | error: query argument groups have been specified for %s.%s(%s), but the \ 104 | algorithm instantiated from it does not implement the set_query_arguments \ 105 | function""" % (definition.module, definition.constructor, definition.arguments) 106 | 107 | D = get_dataset(dataset) 108 | X_train = numpy.array(D['train']) 109 | X_test = numpy.array(D['test']) 110 | distance = D.attrs['distance'] 111 | print('got a train set of size (%d * %d)' % X_train.shape) 112 | print('got %d queries' % len(X_test)) 113 | 114 | X_train = dataset_transform[distance](X_train) 115 | X_test = dataset_transform[distance](X_test) 116 | 117 | try: 118 | prepared_queries = False 119 | if hasattr(algo, "supports_prepared_queries"): 120 | prepared_queries = algo.supports_prepared_queries() 121 | 122 | t0 = time.time() 123 | memory_usage_before = algo.get_memory_usage() 124 | algo.fit(X_train) 125 | build_time = time.time() - t0 126 | index_size = algo.get_memory_usage() - memory_usage_before 127 | print('Built index in', build_time) 128 | print('Index size: ', index_size) 129 | 130 | query_argument_groups = definition.query_argument_groups 131 | # Make sure that algorithms with no query argument groups still get run 132 | # once by providing them with a single, empty, harmless group 133 | if not query_argument_groups: 134 | query_argument_groups = [[]] 135 | 136 | for pos, query_arguments in enumerate(query_argument_groups, 1): 137 | print("Running query argument group %d of %d..." 
% 138 | (pos, len(query_argument_groups))) 139 | if query_arguments: 140 | algo.set_query_arguments(*query_arguments) 141 | descriptor, results = run_individual_query( 142 | algo, X_train, X_test, distance, count, run_count, batch) 143 | descriptor["build_time"] = build_time 144 | descriptor["index_size"] = index_size 145 | descriptor["algo"] = get_algorithm_name( 146 | definition.algorithm, batch) 147 | descriptor["dataset"] = dataset 148 | store_results(dataset, count, definition, 149 | query_arguments, descriptor, results, batch) 150 | finally: 151 | algo.done() 152 | 153 | 154 | def run_from_cmdline(): 155 | parser = argparse.ArgumentParser() 156 | parser.add_argument( 157 | '--dataset', 158 | choices=DATASETS.keys(), 159 | required=True) 160 | parser.add_argument( 161 | '--algorithm', 162 | required=True) 163 | parser.add_argument( 164 | '--module', 165 | required=True) 166 | parser.add_argument( 167 | '--constructor', 168 | required=True) 169 | parser.add_argument( 170 | '--count', 171 | required=True, 172 | type=int) 173 | parser.add_argument( 174 | '--runs', 175 | required=True, 176 | type=int) 177 | parser.add_argument( 178 | '--batch', 179 | action='store_true') 180 | parser.add_argument( 181 | 'build') 182 | parser.add_argument( 183 | 'queries', 184 | nargs='*', 185 | default=[]) 186 | args = parser.parse_args() 187 | algo_args = json.loads(args.build) 188 | query_args = [json.loads(q) for q in args.queries] 189 | 190 | definition = Definition( 191 | algorithm=args.algorithm, 192 | docker_tag=None, # not needed 193 | module=args.module, 194 | constructor=args.constructor, 195 | arguments=algo_args, 196 | query_argument_groups=query_args, 197 | disabled=False 198 | ) 199 | run(definition, args.dataset, args.count, args.runs, args.batch) 200 | 201 | 202 | def run_docker(definition, dataset, count, runs, timeout, batch, cpu_limit, 203 | mem_limit=None): 204 | cmd = ['--dataset', dataset, 205 | '--algorithm', definition.algorithm, 206 | '--module', definition.module, 207 | '--constructor', definition.constructor, 208 | '--runs', str(runs), 209 | '--count', str(count)] 210 | if batch: 211 | cmd += ['--batch'] 212 | cmd.append(json.dumps(definition.arguments)) 213 | cmd += [json.dumps(qag) for qag in definition.query_argument_groups] 214 | 215 | client = docker.from_env() 216 | if mem_limit is None: 217 | mem_limit = psutil.virtual_memory().available 218 | 219 | container = client.containers.run( 220 | definition.docker_tag, 221 | cmd, 222 | volumes={ 223 | os.path.abspath('ann_benchmarks'): 224 | {'bind': '/home/app/ann_benchmarks', 'mode': 'ro'}, 225 | os.path.abspath('data'): 226 | {'bind': '/home/app/data', 'mode': 'ro'}, 227 | os.path.abspath('results'): 228 | {'bind': '/home/app/results', 'mode': 'rw'}, 229 | }, 230 | cpuset_cpus=cpu_limit, 231 | mem_limit=mem_limit, 232 | detach=True) 233 | logger = logging.getLogger(f"annb.{container.short_id}") 234 | 235 | logger.info('Created container %s: CPU limit %s, mem limit %s, timeout %s, command %s' % \ 236 | (container.short_id, cpu_limit, mem_limit, timeout, cmd)) 237 | 238 | def stream_logs(): 239 | for line in container.logs(stream=True): 240 | logger.info(colors.color(line.decode().rstrip(), fg='blue')) 241 | 242 | t = threading.Thread(target=stream_logs, daemon=True) 243 | t.start() 244 | 245 | try: 246 | exit_code = container.wait(timeout=timeout) 247 | 248 | # Report an error if the container exited with a non-zero code 249 | if exit_code not in [0, None]: 250 | logger.error(colors.color(container.logs().decode(), fg='red')) 251 | logger.error('Child process for 
container %s raised exception %d' % (container.short_id, exit_code)) 252 | except Exception: 253 | logger.error('Container.wait for container %s failed with exception' % container.short_id) 254 | traceback.print_exc() 255 | finally: 256 | pass 257 | # container.remove(force=True) 258 | -------------------------------------------------------------------------------- /create_website.py: -------------------------------------------------------------------------------- 1 | import matplotlib as mpl 2 | mpl.use('Agg') # noqa 3 | import argparse 4 | import os 5 | import json 6 | import pickle 7 | import yaml 8 | import numpy 9 | import hashlib 10 | from jinja2 import Environment, FileSystemLoader 11 | 12 | from ann_benchmarks import results 13 | from ann_benchmarks.algorithms.definitions import get_algorithm_name 14 | from ann_benchmarks.datasets import get_dataset 15 | from ann_benchmarks.plotting.plot_variants import (all_plot_variants 16 | as plot_variants) 17 | from ann_benchmarks.plotting.metrics import all_metrics as metrics 18 | from ann_benchmarks.plotting.utils import (get_plot_label, compute_metrics, 19 | compute_all_metrics, 20 | create_pointset, 21 | create_linestyles) 22 | import plot 23 | 24 | colors = [ 25 | "rgba(166,206,227,1)", 26 | "rgba(31,120,180,1)", 27 | "rgba(178,223,138,1)", 28 | "rgba(51,160,44,1)", 29 | "rgba(251,154,153,1)", 30 | "rgba(227,26,28,1)", 31 | "rgba(253,191,111,1)", 32 | "rgba(255,127,0,1)", 33 | "rgba(202,178,214,1)" 34 | ] 35 | 36 | point_styles = { 37 | "o": "circle", 38 | "<": "triangle", 39 | "*": "star", 40 | "x": "cross", 41 | "+": "rect", 42 | } 43 | 44 | 45 | def convert_color(color): 46 | r, g, b, a = color 47 | return "rgba(%(r)d, %(g)d, %(b)d, %(a)f)" % {  # alpha is a float in [0, 1]; %d would truncate it to 0 48 | "r": r * 255, "g": g * 255, "b": b * 255, "a": a} 49 | 50 | 51 | def convert_linestyle(ls): 52 | new_ls = {} 53 | for algo in ls.keys(): 54 | algostyle = ls[algo] 55 | new_ls[algo] = (convert_color(algostyle[0]), 56 | convert_color(algostyle[1]), 57 | algostyle[2], point_styles[algostyle[3]]) 58 | return new_ls 59 | 60 | 61 | def get_run_desc(properties): 62 | return "%(dataset)s_%(count)d_%(distance)s" % properties 63 | 64 | 65 | def get_dataset_from_desc(desc): 66 | return desc.split("_")[0] 67 | 68 | 69 | def get_count_from_desc(desc): 70 | return desc.split("_")[1] 71 | 72 | 73 | def get_distance_from_desc(desc): 74 | return desc.split("_")[2] 75 | 76 | 77 | def get_dataset_label(desc): 78 | return "{} (k = {})".format(get_dataset_from_desc(desc), 79 | get_count_from_desc(desc)) 80 | 81 | 82 | def directory_path(s): 83 | if not os.path.isdir(s): 84 | raise argparse.ArgumentTypeError("'%s' is not a directory" % s) 85 | return s + "/" 86 | 87 | 88 | def prepare_data(data, xn, yn): 89 | """Change format from (algo, instance, dict) to (algo, instance, x, y).""" 90 | res = [] 91 | for algo, algo_name, result in data: 92 | res.append((algo, algo_name, result[xn], result[yn])) 93 | return res 94 | 95 | 96 | parser = argparse.ArgumentParser() 97 | parser.add_argument( 98 | '--plottype', 99 | help='Generate only the plots specified', 100 | nargs='*', 101 | choices=plot_variants.keys(), 102 | default=plot_variants.keys()) 103 | parser.add_argument( 104 | '--outputdir', 105 | help='Select output directory', 106 | default='.', 107 | type=directory_path, 108 | action='store') 109 | parser.add_argument( 110 | '--latex', 111 | help='generates latex code for each plot', 112 | action='store_true') 113 | parser.add_argument( 114 | '--scatter', 115 | help='create scatterplot for data', 116 | action='store_true') 
117 | parser.add_argument( 118 | '--recompute', 119 | help='Clears the cache and recomputes the metrics', 120 | action='store_true') 121 | args = parser.parse_args() 122 | 123 | 124 | def get_lines(all_data, xn, yn, render_all_points): 125 | """For each algorithm run on a dataset, obtain its performance 126 | curve coords.""" 127 | plot_data = [] 128 | for algo in sorted(all_data.keys(), key=lambda x: x.lower()): 129 | xs, ys, ls, axs, ays, als = \ 130 | create_pointset(prepare_data(all_data[algo], xn, yn), xn, yn) 131 | if render_all_points: 132 | xs, ys, ls = axs, ays, als 133 | plot_data.append({"name": algo, "coords": zip(xs, ys), "labels": ls, 134 | "scatter": render_all_points}) 135 | return plot_data 136 | 137 | 138 | def create_plot(all_data, xn, yn, linestyle, j2_env, additional_label="", 139 | plottype="line"): 140 | xm, ym = (metrics[xn], metrics[yn]) 141 | render_all_points = plottype == "bubble" 142 | plot_data = get_lines(all_data, xn, yn, render_all_points) 143 | latex_code = j2_env.get_template("latex.template").\ 144 | render(plot_data=plot_data, caption=get_plot_label(xm, ym), 145 | xlabel=xm["description"], ylabel=ym["description"]) 146 | plot_data = get_lines(all_data, xn, yn, render_all_points)  # regenerate: the zip iterators above were consumed by the latex template 147 | button_label = hashlib.sha224((get_plot_label(xm, ym) + additional_label) 148 | .encode("utf-8")).hexdigest() 149 | return j2_env.get_template("chartjs.template").\ 150 | render(args=args, latex_code=latex_code, button_label=button_label, 151 | data_points=plot_data, 152 | xlabel=xm["description"], ylabel=ym["description"], 153 | plottype=plottype, plot_label=get_plot_label(xm, ym), 154 | label=additional_label, linestyle=linestyle, 155 | render_all_points=render_all_points) 156 | 157 | 158 | def build_detail_site(data, label_func, j2_env, linestyles, batch=False): 159 | for (name, runs) in data.items(): 160 | print("Building '%s'" % name) 161 | all_runs = runs.keys() 162 | label = label_func(name) 163 | page_data = {"normal": [], "scatter": []}  # distinct name so we don't shadow the 'data' parameter mid-iteration 164 | 165 | for plottype in args.plottype: 166 | xn, yn = plot_variants[plottype] 167 | page_data["normal"].append(create_plot( 168 | runs, xn, yn, convert_linestyle(linestyles), j2_env)) 169 | if args.scatter: 170 | page_data["scatter"].append( 171 | create_plot(runs, xn, yn, convert_linestyle(linestyles), 172 | j2_env, "Scatterplot ", "bubble")) 173 | 174 | # create png plot for summary page 175 | data_for_plot = {} 176 | for k in runs.keys(): 177 | data_for_plot[k] = prepare_data(runs[k], 'k-nn', 'qps') 178 | plot.create_plot( 179 | data_for_plot, False, 180 | False, True, 'k-nn', 'qps', 181 | args.outputdir + get_algorithm_name(name, batch) + ".png", 182 | linestyles, batch) 183 | output_path = "".join([args.outputdir, 184 | get_algorithm_name(name, batch), 185 | ".html"]) 186 | with open(output_path, "w") as text_file: 187 | text_file.write(j2_env.get_template("detail_page.html"). 
188 | render(title=label, plot_data=page_data, 189 | args=args, batch=batch)) 190 | 191 | 192 | def build_index_site(datasets, algorithms, j2_env, file_name): 193 | dataset_data = {'batch': [], 'non-batch': []} 194 | for mode in ['batch', 'non-batch']: 195 | distance_measures = sorted( 196 | set([get_distance_from_desc(e) for e in datasets[mode].keys()])) 197 | sorted_datasets = sorted( 198 | set([get_dataset_from_desc(e) for e in datasets[mode].keys()])) 199 | 200 | for dm in distance_measures: 201 | d = {"name": dm.capitalize(), "entries": []} 202 | for ds in sorted_datasets: 203 | matching_datasets = [e for e in datasets[mode].keys() 204 | if get_dataset_from_desc(e) == ds and # noqa 205 | get_distance_from_desc(e) == dm] 206 | sorted_matches = sorted( 207 | matching_datasets, 208 | key=lambda e: int(get_count_from_desc(e))) 209 | for idd in sorted_matches: 210 | d["entries"].append( 211 | {"name": idd, "desc": get_dataset_label(idd)}) 212 | dataset_data[mode].append(d) 213 | 214 | with open(args.outputdir + file_name, "w") as text_file: 215 | text_file.write(j2_env.get_template("summary.html"). 216 | render(title="ANN-Benchmarks", 217 | dataset_with_distances=dataset_data, 218 | algorithms=algorithms, 219 | label_func=get_algorithm_name)) 220 | 221 | 222 | def load_all_results(): 223 | """Read all result files and compute all metrics""" 224 | all_runs_by_dataset = {'batch': {}, 'non-batch': {}} 225 | all_runs_by_algorithm = {'batch': {}, 'non-batch': {}} 226 | cached_true_dist = [] 227 | old_sdn = None 228 | for properties, f in results.load_all_results(): 229 | sdn = get_run_desc(properties) 230 | if sdn != old_sdn: 231 | dataset = get_dataset(properties["dataset"]) 232 | cached_true_dist = list(dataset["distances"]) 233 | old_sdn = sdn 234 | algo = properties["algo"] 235 | ms = compute_all_metrics( 236 | cached_true_dist, f, properties, args.recompute) 237 | algo_ds = get_dataset_label(sdn) 238 | idx = "non-batch" 239 | if properties["batch_mode"]: 240 | idx = "batch" 241 | all_runs_by_algorithm[idx].setdefault( 242 | algo, {}).setdefault(algo_ds, []).append(ms) 243 | all_runs_by_dataset[idx].setdefault( 244 | sdn, {}).setdefault(algo, []).append(ms) 245 | 246 | return (all_runs_by_dataset, all_runs_by_algorithm) 247 | 248 | 249 | j2_env = Environment(loader=FileSystemLoader("./templates/"), trim_blocks=True) 250 | j2_env.globals.update(zip=zip, len=len) 251 | runs_by_ds, runs_by_algo = load_all_results() 252 | dataset_names = [get_dataset_label(x) for x in list( 253 | runs_by_ds['batch'].keys()) + list(runs_by_ds['non-batch'].keys())] 254 | algorithm_names = list(runs_by_algo['batch'].keys( 255 | )) + list(runs_by_algo['non-batch'].keys()) 256 | 257 | linestyles = {**create_linestyles(dataset_names), 258 | **create_linestyles(algorithm_names)} 259 | 260 | build_detail_site( 261 | runs_by_ds['non-batch'], 262 | lambda label: get_dataset_label(label), j2_env, linestyles, False) 263 | 264 | build_detail_site( 265 | runs_by_ds['batch'], 266 | lambda label: get_dataset_label(label), j2_env, linestyles, True) 267 | 268 | build_detail_site( 269 | runs_by_algo['non-batch'], 270 | lambda x: x, j2_env, linestyles, False) 271 | 272 | build_detail_site( 273 | runs_by_algo['batch'], lambda x: x, j2_env, linestyles, True) 274 | 275 | build_index_site(runs_by_ds, runs_by_algo, j2_env, "index.html") 276 | --------------------------------------------------------------------------------
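
For orientation, runner.py above drives every algorithm object through a small implicit interface: get_memory_usage() is sampled before and after fit() to estimate the index size, query() (or batch_query()/get_batch_results()) is called per query group, get_additional() supplies extra result attributes, and done() runs in run()'s finally block. The project's actual base class lives in ann_benchmarks/algorithms/base.py (not reproduced here); the sketch below is only a minimal illustration of that call sequence, and its brute-force query and psutil-based memory probe are assumptions, not that file's contents.

    # Minimal sketch of the interface run() and run_individual_query() exercise.
    # Hypothetical class; the real contract is in ann_benchmarks/algorithms/base.py.
    import numpy
    import psutil


    class SketchANN(object):
        def __init__(self, metric):
            self._metric = metric
            self._train = None

        def get_memory_usage(self):
            # Sampled before and after fit(); the difference is reported
            # as the index size.
            return psutil.Process().memory_info().rss / 1024

        def fit(self, X):
            # Build the index from the (n_points, dimension) training matrix.
            self._train = numpy.array(X)

        def query(self, v, n):
            # Return indices into the training set; run_individual_query()
            # recomputes the actual distances via metrics[distance] itself.
            return numpy.argsort(numpy.linalg.norm(self._train - v, axis=1))[:n]

        def get_additional(self):
            # Optional extra key/value pairs merged into the stored attributes.
            return {}

        def done(self):
            # Called in run()'s finally block; release external resources here.
            pass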
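
run_from_cmdline() (invoked via run_algorithm.py) takes the constructor arguments as a single positional JSON value and each query argument group as a further JSON positional, which set_query_arguments() later receives unpacked; run_docker() assembles exactly this argument list when it launches a container. A hypothetical invocation could look like the following, where the specific values are illustrative only (the real Annoy constructor signature is defined in ann_benchmarks/algorithms/annoy.py):

    python3 run_algorithm.py \
        --dataset glove-100-angular \
        --algorithm annoy \
        --module ann_benchmarks.algorithms.annoy \
        --constructor Annoy \
        --count 10 \
        --runs 3 \
        '["angular", 100]' '[100]' '[1000]'

Here '["angular", 100]' is parsed by json.loads into the constructor's arguments, and '[100]' and '[1000]' each become one query argument group, so the benchmark is run once per group.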