├── pynndescent ├── tests │ ├── __init__.py │ ├── test_data │ │ ├── cosine_hang.npy │ │ ├── pynndescent_bug_np.npz │ │ └── cosine_near_duplicates.npy │ ├── conftest.py │ ├── test_rank.py │ ├── test_distances.py │ ├── test_hub_trees.py │ └── test_pynndescent_.py ├── __init__.py ├── threaded_rp_trees.py ├── graph_utils.py ├── sparse_nndescent.py └── utils.py ├── doc ├── mnist.png ├── sift.png ├── fmnist.png ├── glove100.png ├── glove25.png ├── lastfm.png ├── nytimes.png ├── diversify1.png ├── diversify2.png ├── directed_1nn.png ├── basic_triangle.png ├── undirected_1nn.png ├── common_neighbors.png ├── pynndescent_logo.png ├── neighbor_of_neighbor.png ├── _static │ ├── nndescent_search.mp4 │ ├── nndescent_search_larger.mp4 │ ├── nndescent_search_largest.mp4 │ └── how_pynnd_works_nn_descent_naive.mp4 ├── pynndescent_logo_no_text.png ├── how_pynn_works_naive_nndescent_iter0.png ├── how_pynn_works_naive_nndescent_iter1.png ├── how_pynn_works_naive_nndescent_iter2.png ├── how_pynn_works_naive_nndescent_iter3.png ├── how_pynn_works_naive_nndescent_iter4.png ├── how_pynn_works_naive_nndescent_iter5.png ├── how_pynn_works_initialized_nndescent_iter0.png ├── how_pynn_works_initialized_nndescent_iter1.png ├── how_pynn_works_initialized_nndescent_iter2.png ├── requirements.txt ├── api.rst ├── Makefile ├── make.bat ├── conf.py ├── index.rst ├── sparse_data_with_pynndescent.ipynb ├── performance.ipynb └── pynndescent_metrics.ipynb ├── requirements.txt ├── MANIFEST.in ├── setup.py ├── .readthedocs.yml ├── LICENSE ├── pyproject.toml ├── .gitignore ├── CONTRIBUTING.md ├── CODE_OF_CONDUCT.md ├── azure-pipelines.yml └── README.rst /pynndescent/tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /doc/mnist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmcinnes/pynndescent/HEAD/doc/mnist.png -------------------------------------------------------------------------------- /doc/sift.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmcinnes/pynndescent/HEAD/doc/sift.png -------------------------------------------------------------------------------- /doc/fmnist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmcinnes/pynndescent/HEAD/doc/fmnist.png -------------------------------------------------------------------------------- /doc/glove100.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmcinnes/pynndescent/HEAD/doc/glove100.png -------------------------------------------------------------------------------- /doc/glove25.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmcinnes/pynndescent/HEAD/doc/glove25.png -------------------------------------------------------------------------------- /doc/lastfm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmcinnes/pynndescent/HEAD/doc/lastfm.png -------------------------------------------------------------------------------- /doc/nytimes.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmcinnes/pynndescent/HEAD/doc/nytimes.png 
-------------------------------------------------------------------------------- /doc/diversify1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmcinnes/pynndescent/HEAD/doc/diversify1.png -------------------------------------------------------------------------------- /doc/diversify2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmcinnes/pynndescent/HEAD/doc/diversify2.png -------------------------------------------------------------------------------- /doc/directed_1nn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmcinnes/pynndescent/HEAD/doc/directed_1nn.png -------------------------------------------------------------------------------- /doc/basic_triangle.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmcinnes/pynndescent/HEAD/doc/basic_triangle.png -------------------------------------------------------------------------------- /doc/undirected_1nn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmcinnes/pynndescent/HEAD/doc/undirected_1nn.png -------------------------------------------------------------------------------- /doc/common_neighbors.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmcinnes/pynndescent/HEAD/doc/common_neighbors.png -------------------------------------------------------------------------------- /doc/pynndescent_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmcinnes/pynndescent/HEAD/doc/pynndescent_logo.png -------------------------------------------------------------------------------- /doc/neighbor_of_neighbor.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmcinnes/pynndescent/HEAD/doc/neighbor_of_neighbor.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | joblib 2 | numpy>=1.17 3 | scikit-learn>=0.18 4 | scipy>=1.0 5 | numba>=0.55.0 6 | llvmlite>=0.38 7 | -------------------------------------------------------------------------------- /doc/_static/nndescent_search.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmcinnes/pynndescent/HEAD/doc/_static/nndescent_search.mp4 -------------------------------------------------------------------------------- /doc/pynndescent_logo_no_text.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmcinnes/pynndescent/HEAD/doc/pynndescent_logo_no_text.png -------------------------------------------------------------------------------- /doc/_static/nndescent_search_larger.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmcinnes/pynndescent/HEAD/doc/_static/nndescent_search_larger.mp4 -------------------------------------------------------------------------------- /doc/_static/nndescent_search_largest.mp4: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/lmcinnes/pynndescent/HEAD/doc/_static/nndescent_search_largest.mp4 -------------------------------------------------------------------------------- /pynndescent/tests/test_data/cosine_hang.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmcinnes/pynndescent/HEAD/pynndescent/tests/test_data/cosine_hang.npy -------------------------------------------------------------------------------- /doc/how_pynn_works_naive_nndescent_iter0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmcinnes/pynndescent/HEAD/doc/how_pynn_works_naive_nndescent_iter0.png -------------------------------------------------------------------------------- /doc/how_pynn_works_naive_nndescent_iter1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmcinnes/pynndescent/HEAD/doc/how_pynn_works_naive_nndescent_iter1.png -------------------------------------------------------------------------------- /doc/how_pynn_works_naive_nndescent_iter2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmcinnes/pynndescent/HEAD/doc/how_pynn_works_naive_nndescent_iter2.png -------------------------------------------------------------------------------- /doc/how_pynn_works_naive_nndescent_iter3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmcinnes/pynndescent/HEAD/doc/how_pynn_works_naive_nndescent_iter3.png -------------------------------------------------------------------------------- /doc/how_pynn_works_naive_nndescent_iter4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmcinnes/pynndescent/HEAD/doc/how_pynn_works_naive_nndescent_iter4.png -------------------------------------------------------------------------------- /doc/how_pynn_works_naive_nndescent_iter5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmcinnes/pynndescent/HEAD/doc/how_pynn_works_naive_nndescent_iter5.png -------------------------------------------------------------------------------- /doc/_static/how_pynnd_works_nn_descent_naive.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmcinnes/pynndescent/HEAD/doc/_static/how_pynnd_works_nn_descent_naive.mp4 -------------------------------------------------------------------------------- /doc/how_pynn_works_initialized_nndescent_iter0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmcinnes/pynndescent/HEAD/doc/how_pynn_works_initialized_nndescent_iter0.png -------------------------------------------------------------------------------- /doc/how_pynn_works_initialized_nndescent_iter1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmcinnes/pynndescent/HEAD/doc/how_pynn_works_initialized_nndescent_iter1.png -------------------------------------------------------------------------------- /doc/how_pynn_works_initialized_nndescent_iter2.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/lmcinnes/pynndescent/HEAD/doc/how_pynn_works_initialized_nndescent_iter2.png -------------------------------------------------------------------------------- /pynndescent/tests/test_data/pynndescent_bug_np.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmcinnes/pynndescent/HEAD/pynndescent/tests/test_data/pynndescent_bug_np.npz -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include LICENSE 2 | include *.md 3 | include *.rst 4 | include requirements.txt 5 | recursive-include pynndescent * 6 | prune pynndescent/__pycache__ 7 | -------------------------------------------------------------------------------- /pynndescent/tests/test_data/cosine_near_duplicates.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lmcinnes/pynndescent/HEAD/pynndescent/tests/test_data/cosine_near_duplicates.npy -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | # All package metadata is now defined in pyproject.toml as per PEP 621. 4 | # This setup() call is retained for compatibility or other tooling purposes. 5 | setup() -------------------------------------------------------------------------------- /doc/requirements.txt: -------------------------------------------------------------------------------- 1 | joblib 2 | scikit-learn>=0.18 3 | scipy>=1.0 4 | numba>=0.51.2 5 | llvmlite>=0.34 6 | pygments>=2.4.1 7 | jupyterlab_pygments>=0.1.1 8 | ipykernel 9 | nbsphinx 10 | matplotlib 11 | seaborn 12 | numpy 13 | pandas 14 | sphinx_rtd_theme 15 | -------------------------------------------------------------------------------- /doc/api.rst: -------------------------------------------------------------------------------- 1 | PyNNDescent API Guide 2 | ===================== 3 | 4 | PyNNDescent has only two classes :class:`NNDescent` and :class:`PyNNDescentTransformer`. 5 | 6 | PyNNDescent 7 | ----------- 8 | 9 | .. autoclass:: pynndescent.pynndescent_.NNDescent 10 | :members: 11 | 12 | .. autoclass:: pynndescent.pynndescent_.PyNNDescentTransformer 13 | :members: 14 | 15 | A number of internal functions can also be accessed separately for more fine tuned work. 16 | 17 | Distance Functions 18 | ------------------ 19 | 20 | .. 
automodule:: pynndescent.distances 21 | :members: -------------------------------------------------------------------------------- /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yaml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | # Required 6 | version: 2 7 | 8 | # Set the version of Python and other tools you might need 9 | build: 10 | os: ubuntu-22.04 11 | tools: 12 | python: "3.11" 13 | 14 | # Build documentation in the docs/ directory with Sphinx 15 | sphinx: 16 | configuration: doc/conf.py 17 | 18 | # We recommend specifying your dependencies to enable reproducible builds: 19 | # https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html 20 | python: 21 | install: 22 | - requirements: doc/requirements.txt 23 | -------------------------------------------------------------------------------- /doc/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /pynndescent/__init__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import numba 4 | 5 | from .pynndescent_ import NNDescent, PyNNDescentTransformer 6 | 7 | if sys.version_info[:2] >= (3, 8): 8 | import importlib.metadata as importlib_metadata 9 | else: 10 | import importlib_metadata 11 | 12 | # Workaround: https://github.com/numba/numba/issues/3341 13 | if numba.config.THREADING_LAYER == "omp": 14 | try: 15 | from numba.np.ufunc import tbbpool 16 | 17 | numba.config.THREADING_LAYER = "tbb" 18 | except ImportError as e: 19 | # might be a missing symbol due to e.g. tbb libraries missing 20 | numba.config.THREADING_LAYER = "workqueue" 21 | 22 | __version__ = importlib_metadata.version("pynndescent") 23 | -------------------------------------------------------------------------------- /doc/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 
23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 2-Clause License 2 | 3 | Copyright (c) 2018, Leland McInnes 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 17 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 18 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 19 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 20 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 21 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 22 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 23 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 24 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
26 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.2"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "pynndescent" 7 | version = "0.6.0" 8 | authors = [{name = "Leland McInnes", email = "leland.mcinnes@gmail.com"}] 9 | maintainers = [{name = "Leland McInnes", email = "leland.mcinnes@gmail.com"}] 10 | license = "BSD-2-Clause" 11 | description = "Nearest Neighbor Descent" 12 | keywords = ["nearest", "neighbor", "knn", "ANN"] 13 | readme = "README.rst" 14 | classifiers = [ 15 | "Development Status :: 3 - Alpha", 16 | "Intended Audience :: Science/Research", 17 | "Intended Audience :: Developers", 18 | "Programming Language :: Python", 19 | "Topic :: Software Development", 20 | "Topic :: Scientific/Engineering", 21 | "Operating System :: Microsoft :: Windows", 22 | "Operating System :: POSIX", 23 | "Operating System :: Unix", 24 | "Operating System :: MacOS", 25 | "Programming Language :: Python :: 3.10", 26 | "Programming Language :: Python :: 3.11", 27 | "Programming Language :: Python :: 3.12", 28 | "Programming Language :: Python :: 3.13", 29 | ] 30 | urls = {Homepage = "http://github.com/lmcinnes/pynndescent"} 31 | dependencies = [ 32 | "scikit-learn >= 0.18", 33 | "scipy >= 1.0", 34 | "numba >= 0.55.0", 35 | "llvmlite >= 0.38", 36 | "joblib >= 0.11", 37 | ] 38 | 39 | [project.optional-dependencies] 40 | testing = ["pytest"] 41 | 42 | [tool.setuptools] 43 | packages = ["pynndescent"] 44 | zip-safe = false 45 | include-package-data = false 46 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | doc/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | 103 | # PyCharm 104 | .idea/ 105 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | Contributions of all kinds are welcome. In particular pull requests are appreciated. 4 | The authors will endeavour to help walk you through any issues in the pull request 5 | discussion, so please feel free to open a pull request even if you are new to such things. 6 | 7 | ## Issues 8 | 9 | The easiest contribution to make is to [file an issue](https://github.com/lmcinnes/umap/issues/new). 10 | It is beneficial if you check the [FAQ](https://umap-learn.readthedocs.io/en/latest/faq.html), 11 | and do a cursory search of [existing issues](https://github.com/lmcinnes/umap/issues?utf8=%E2%9C%93&q=is%3Aissue). 12 | It is also helpful, but not necessary, if you can provide clear instruction for 13 | how to reproduce a problem. If you have resolved an issue yourself please consider 14 | contributing to the FAQ to add your problem, and its resolution, so others can 15 | benefit from your work. 16 | 17 | ## Documentation 18 | 19 | Contributing to documentation is the easiest way to get started. Providing simple 20 | clear or helpful documentation for new users is critical. Anything that *you* as 21 | a new user found hard to understand, or difficult to work out, are excellent places 22 | to begin. Contributions to more detailed and descriptive error messages is 23 | especially appreciated. To contribute to the documentation please 24 | [fork the project](https://github.com/lmcinnes/umap/issues#fork-destination-box) 25 | into your own repository, make changes there, and then submit a pull request. 26 | 27 | ## Code 28 | 29 | Code contributions are always welcome, from simple bug fixes, to new features. To 30 | contribute code please 31 | [fork the project](https://github.com/lmcinnes/umap/issues#fork-destination-box) 32 | into your own repository, make changes there, and then submit a pull request. If 33 | you are fixing a known issue please add the issue number to the PR message. If you 34 | are fixing a new issue feel free to file an issue and then reference it in the PR. 35 | You can [browse open issues](https://github.com/lmcinnes/umap/issues), 36 | or consult the [project roadmap](https://github.com/lmcinnes/umap/issues/15), for potential code 37 | contributions. 
Fixes for issues tagged with 'help wanted' are especially appreciated. 38 | 39 | ### Code formatting 40 | 41 | If possible, install the [black code formatter](https://github.com/python/black) (e.g. 42 | `pip install black`) and run it before submitting a pull request. This helps maintain consistency 43 | across the code, but also there is a check in the Travis-CI continuous integration system which 44 | will show up as a failure in the pull request if `black` detects that it hasn't been run. 45 | 46 | Formatting is as simple as running: 47 | 48 | ```bash 49 | black . 50 | ``` 51 | 52 | in the root of the project. 53 | -------------------------------------------------------------------------------- /doc/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | # 13 | import os 14 | import sys 15 | 16 | sys.path.insert(0, os.path.abspath("..")) 17 | sys.path.insert(0, os.path.abspath(".")) 18 | 19 | 20 | # -- Project information ----------------------------------------------------- 21 | 22 | project = "pynndescent" 23 | copyright = "2020, Leland McInnes" 24 | author = "Leland McInnes" 25 | 26 | # The full version, including alpha/beta/rc tags 27 | release = "0.5.0" 28 | 29 | master_doc = "index" 30 | 31 | # -- General configuration --------------------------------------------------- 32 | 33 | # Add any Sphinx extension module names here, as strings. They can be 34 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 35 | # ones. 36 | extensions = [ 37 | "sphinx.ext.autodoc", 38 | "sphinx.ext.intersphinx", 39 | "nbsphinx", 40 | "sphinx.ext.mathjax", 41 | ] 42 | 43 | # Add any paths that contain templates here, relative to this directory. 44 | templates_path = ["_templates"] 45 | 46 | # List of patterns, relative to source directory, that match files and 47 | # directories to ignore when looking for source files. 48 | # This pattern also affects html_static_path and html_extra_path. 49 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] 50 | 51 | 52 | # -- Options for HTML output ------------------------------------------------- 53 | 54 | # The theme to use for HTML and HTML Help pages. See the documentation for 55 | # a list of builtin themes. 56 | # 57 | # html_theme = 'alabaster' 58 | html_theme = "sphinx_rtd_theme" 59 | 60 | # Theme options are theme-specific and customize the look and feel of a theme 61 | # further. For a list of options available for each theme, see the 62 | # documentation. 63 | # 64 | html_theme_options = {"navigation_depth": 3, "logo_only": True} 65 | 66 | html_logo = "pynndescent_logo_no_text.png" 67 | 68 | # Add any paths that contain custom static files (such as style sheets) here, 69 | # relative to this directory. They are copied after the builtin static files, 70 | # so a file named "default.css" will overwrite the builtin "default.css". 
71 | html_static_path = ["_static"] 72 | 73 | 74 | # Example configuration for intersphinx: refer to the Python standard library. 75 | intersphinx_mapping = { 76 | "python": ("https://docs.python.org/{.major}".format(sys.version_info), None), 77 | "numpy": ("https://docs.scipy.org/doc/numpy/", None), 78 | "scipy": ("https://docs.scipy.org/doc/scipy/reference", None), 79 | "matplotlib": ("https://matplotlib.org/", None), 80 | "pandas": ("https://pandas.pydata.org/pandas-docs/stable/", None), 81 | "sklearn": ("http://scikit-learn.org/stable/", None), 82 | "bokeh": ("http://bokeh.pydata.org/en/latest/", None), 83 | } 84 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as contributors and maintainers pledge to making participation in our project and our community a harassment-free experience for everyone, regardless of age, body size, disability, ethnicity, gender identity and expression, level of experience, nationality, personal appearance, race, religion, or sexual identity and orientation. 6 | 7 | ## Our Standards 8 | 9 | Examples of behavior that contributes to creating a positive environment include: 10 | 11 | * Using welcoming and inclusive language 12 | * Being respectful of differing viewpoints and experiences 13 | * Gracefully accepting constructive criticism 14 | * Focusing on what is best for the community 15 | * Showing empathy towards other community members 16 | 17 | Examples of unacceptable behavior by participants include: 18 | 19 | * The use of sexualized language or imagery and unwelcome sexual attention or advances 20 | * Trolling, insulting/derogatory comments, and personal or political attacks 21 | * Public or private harassment 22 | * Publishing others' private information, such as a physical or electronic address, without explicit permission 23 | * Other conduct which could reasonably be considered inappropriate in a professional setting 24 | 25 | ## Our Responsibilities 26 | 27 | Project maintainers are responsible for clarifying the standards of acceptable behavior and are expected to take appropriate and fair corrective action in response to any instances of unacceptable behavior. 28 | 29 | Project maintainers have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, or to ban temporarily or permanently any contributor for other behaviors that they deem inappropriate, threatening, offensive, or harmful. 30 | 31 | ## Scope 32 | 33 | This Code of Conduct applies both within project spaces and in public spaces when an individual is representing the project or its community. Examples of representing a project or community include using an official project e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. Representation of a project may be further defined and clarified by project maintainers. 34 | 35 | ## Enforcement 36 | 37 | Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by contacting the project team at leland.mcinnes@gmail.com. The project team will review and investigate all complaints, and will respond in a way that it deems appropriate to the circumstances. 
The project team is obligated to maintain confidentiality with regard to the reporter of an incident. Further details of specific enforcement policies may be posted separately. 38 | 39 | Project maintainers who do not follow or enforce the Code of Conduct in good faith may face temporary or permanent repercussions as determined by other members of the project's leadership. 40 | 41 | ## Attribution 42 | 43 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, available at [http://contributor-covenant.org/version/1/4][version] 44 | 45 | [homepage]: http://contributor-covenant.org 46 | [version]: http://contributor-covenant.org/version/1/4/ -------------------------------------------------------------------------------- /pynndescent/threaded_rp_trees.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import numba 3 | 4 | from pynndescent.utils import tau_rand_int, norm 5 | 6 | ###################################################### 7 | # Alternative tree approach; should be the basis 8 | # for a dask-distributed version of the algorithm 9 | ###################################################### 10 | 11 | 12 | @numba.njit(fastmath=True, nogil=True) 13 | def apply_hyperplane( 14 | data, 15 | hyperplane_vector, 16 | hyperplane_offset, 17 | hyperplane_node_num, 18 | current_num_nodes, 19 | data_node_loc, 20 | rng_state, 21 | ): 22 | 23 | left_node = current_num_nodes 24 | right_node = current_num_nodes + 1 25 | 26 | for i in range(data_node_loc.shape[0]): 27 | if data_node_loc[i] != hyperplane_node_num: 28 | continue 29 | 30 | margin = hyperplane_offset 31 | for d in range(hyperplane_vector.shape[0]): 32 | margin += hyperplane_vector[d] * data[i, d] 33 | 34 | if margin == 0: 35 | if abs(tau_rand_int(rng_state)) % 2 == 0: 36 | data_node_loc[i] = left_node 37 | else: 38 | data_node_loc[i] = right_node 39 | elif margin > 0: 40 | data_node_loc[i] = left_node 41 | else: 42 | data_node_loc[i] = right_node 43 | 44 | return 45 | 46 | 47 | @numba.njit(fastmath=True, nogil=True) 48 | def make_euclidean_hyperplane(data, indices, rng_state): 49 | left_index = tau_rand_int(rng_state) % indices.shape[0] 50 | right_index = tau_rand_int(rng_state) % indices.shape[0] 51 | right_index += left_index == right_index 52 | right_index = right_index % indices.shape[0] 53 | left = indices[left_index] 54 | right = indices[right_index] 55 | 56 | # Compute the normal vector to the hyperplane (the vector between 57 | # the two points) and the offset from the origin 58 | hyperplane_offset = 0.0 59 | hyperplane_vector = np.empty(data.shape[1], dtype=np.float32) 60 | 61 | for d in range(data.shape[1]): 62 | hyperplane_vector[d] = data[left, d] - data[right, d] 63 | hyperplane_offset -= ( 64 | hyperplane_vector[d] * (data[left, d] + data[right, d]) / 2.0 65 | ) 66 | 67 | return hyperplane_vector, hyperplane_offset 68 | 69 | 70 | @numba.njit(fastmath=True, nogil=True) 71 | def make_angular_hyperplane(data, indices, rng_state): 72 | left_index = tau_rand_int(rng_state) % indices.shape[0] 73 | right_index = tau_rand_int(rng_state) % indices.shape[0] 74 | right_index += left_index == right_index 75 | right_index = right_index % indices.shape[0] 76 | left = indices[left_index] 77 | right = indices[right_index] 78 | 79 | left_norm = norm(data[left]) 80 | right_norm = norm(data[right]) 81 | 82 | if left_norm == 0.0: 83 | left_norm = 1.0 84 | 85 | if right_norm == 0.0: 86 | right_norm = 1.0 87 | 88 | # Compute the normal vector to the 
hyperplane (the vector between 89 | # the two points) and the offset from the origin 90 | hyperplane_offset = 0.0 91 | hyperplane_vector = np.empty(data.shape[1], dtype=np.float32) 92 | 93 | for d in range(data.shape[1]): 94 | hyperplane_vector[d] = (data[left, d] / left_norm) - ( 95 | data[right, d] / right_norm 96 | ) 97 | 98 | return hyperplane_vector, hyperplane_offset 99 | -------------------------------------------------------------------------------- /pynndescent/tests/conftest.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pytest 3 | import numpy as np 4 | from scipy import sparse 5 | 6 | # Making Random Seed as a fixture in case it would be 7 | # needed in tests for random states 8 | @pytest.fixture 9 | def seed(): 10 | return 189212 # 0b101110001100011100 11 | 12 | 13 | np.random.seed(189212) 14 | 15 | 16 | @pytest.fixture 17 | def spatial_data(): 18 | sp_data = np.random.randn(10, 20) 19 | # Add some all zero graph_data for corner case test 20 | sp_data = np.vstack([sp_data, np.zeros((2, 20))]).astype(np.float32, order="C") 21 | return sp_data 22 | 23 | 24 | @pytest.fixture 25 | def binary_data(): 26 | bin_data = np.random.choice(a=[False, True], size=(10, 20), p=[0.66, 1 - 0.66]) 27 | # Add some all zero graph_data for corner case test 28 | bin_data = np.vstack([bin_data, np.zeros((2, 20), dtype="bool")]) 29 | return bin_data 30 | 31 | 32 | @pytest.fixture 33 | def sparse_spatial_data(spatial_data, binary_data): 34 | sp_sparse_data = sparse.csr_matrix(spatial_data * binary_data, dtype=np.float32) 35 | sp_sparse_data.sort_indices() 36 | return sp_sparse_data 37 | 38 | 39 | @pytest.fixture 40 | def sparse_binary_data(binary_data): 41 | bin_sparse_data = sparse.csr_matrix(binary_data) 42 | bin_sparse_data.sort_indices() 43 | return bin_sparse_data 44 | 45 | 46 | @pytest.fixture 47 | def nn_data(): 48 | nndata = np.random.uniform(0, 1, size=(1000, 5)) 49 | # Add some all zero graph_data for corner case test 50 | nndata = np.vstack([nndata, np.zeros((2, 5))]) 51 | return nndata 52 | 53 | 54 | @pytest.fixture 55 | def sparse_nn_data(): 56 | return sparse.random(1000, 50, density=0.5, format="csr") 57 | 58 | 59 | @pytest.fixture 60 | def cosine_hang_data(): 61 | this_dir = os.path.dirname(os.path.abspath(__file__)) 62 | data_path = os.path.join(this_dir, "test_data/cosine_hang.npy") 63 | return np.load(data_path) 64 | 65 | 66 | @pytest.fixture 67 | def cosine_near_duplicates_data(): 68 | this_dir = os.path.dirname(os.path.abspath(__file__)) 69 | data_path = os.path.join(this_dir, "test_data/cosine_near_duplicates.npy") 70 | return np.load(data_path) 71 | 72 | 73 | @pytest.fixture 74 | def small_data(): 75 | return np.random.uniform(40, 5, size=(20, 5)) 76 | 77 | 78 | @pytest.fixture 79 | def sparse_small_data(): 80 | # Too low dim might cause more than one empty row, 81 | # which might decrease the computed performance 82 | return sparse.random(40, 32, density=0.5, format="csr") 83 | 84 | 85 | @pytest.fixture 86 | def update_data(): 87 | np.random.seed(12345) 88 | xs_orig = np.random.uniform(0, 1, size=(1000, 5)) 89 | xs_fresh = np.random.uniform(0, 1, size=xs_orig.shape) 90 | xs_fresh_small = np.random.uniform(0, 1, size=(100, xs_orig.shape[1])) 91 | xs_for_complete_update = np.random.uniform(0, 1, size=xs_orig.shape) 92 | updates = [ 93 | (xs_orig, None, None, None), 94 | (xs_orig, xs_fresh, None, None), 95 | (xs_orig, None, xs_for_complete_update, list(range(xs_orig.shape[0]))), 96 | (xs_orig, None, -xs_orig[0:50:2], 
list(range(0, 50, 2))), 97 | (xs_orig, None, -xs_orig[0:500:2], list(range(0, 500, 2))), 98 | (xs_orig, xs_fresh, xs_for_complete_update, list(range(xs_orig.shape[0]))), 99 | (xs_orig, xs_fresh_small, -xs_orig[0:50:2], list(range(0, 50, 2))), 100 | (xs_orig, xs_fresh, -xs_orig[0:500:2], list(range(0, 500, 2))), 101 | ] 102 | return updates 103 | -------------------------------------------------------------------------------- /doc/index.rst: -------------------------------------------------------------------------------- 1 | .. pynndescent documentation master file, created by 2 | sphinx-quickstart on Sat Sep 12 12:01:55 2020. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | .. image:: pynndescent_logo.png 7 | :width: 600 8 | :align: center 9 | :alt: PyNNDescent Logo 10 | 11 | PyNNDescent for fast Approximate Nearest Neighbors 12 | ================================================== 13 | 14 | PyNNDescent is a Python nearest neighbor descent for approximate nearest neighbors. 15 | It provides a python implementation of Nearest Neighbor 16 | Descent for k-neighbor-graph construction and approximate nearest neighbor 17 | search, as per the paper: 18 | 19 | Dong, Wei, Charikar Moses, and Kai Li. 20 | *"Efficient k-nearest neighbor graph construction for generic similarity 21 | measures."* 22 | Proceedings of the 20th international conference on World wide web. ACM, 2011. 23 | 24 | This library supplements that approach with the use of random projection trees for 25 | initialisation. This can be particularly useful for the metrics that are 26 | amenable to such approaches (euclidean, minkowski, angular, cosine, etc.). Graph 27 | diversification is also performed, pruning the longest edges of any triangles in the 28 | graph. 29 | 30 | Currently this library targets relatively high accuracy 31 | (80%-100% accuracy rate) approximate nearest neighbor searches. 32 | 33 | Why use PyNNDescent? 34 | -------------------- 35 | 36 | PyNNDescent provides fast approximate nearest neighbor queries. The 37 | `ann-benchmarks `_ system puts it 38 | solidly in the mix of top performing ANN libraries: 39 | 40 | **SIFT-128 Euclidean** 41 | 42 | .. image:: https://pynndescent.readthedocs.io/en/latest/_images/sift.png 43 | :alt: ANN benchmark performance for SIFT 128 dataset 44 | 45 | **NYTimes-256 Angular** 46 | 47 | .. 
image:: https://pynndescent.readthedocs.io/en/latest/_images/nytimes.png 48 | :alt: ANN benchmark performance for NYTimes 256 dataset 49 | 50 | While PyNNDescent is among the fastest ANN libraries, it is also both easy to install (pip 51 | and conda installable) with no platform or compilation issues, and very flexible, 52 | supporting a wide variety of distance metrics by default: 53 | 54 | **Minkowski style metrics** 55 | 56 | - euclidean 57 | - l2 58 | - sqeuclidean 59 | - manhattan 60 | - taxicab 61 | - l1 62 | - chebyshev 63 | - linfinity 64 | - minkowski 65 | 66 | **Standardised/weighted spatial metrics** 67 | 68 | - mahalanobis 69 | - wminkowski (weighted_minkowski) 70 | - seuclidean (standardised_euclidean) 71 | 72 | **Miscellaneous spatial metrics** 73 | 74 | - canberra 75 | - braycurtis 76 | - haversine 77 | 78 | **Angular and correlation metrics** 79 | 80 | - cosine 81 | - dot 82 | - inner_product 83 | - correlation 84 | - spearmanr 85 | - tsss 86 | - true_angular 87 | 88 | **Probability metrics** 89 | 90 | - hellinger 91 | - kantorovich (wasserstein) 92 | - wasserstein_1d 93 | - circular_kantorovich (circular_wasserstein) 94 | - sinkhorn 95 | - jensen_shannon 96 | - symmetric_kl 97 | 98 | **Metrics for binary data** 99 | 100 | - hamming 101 | - jaccard 102 | - dice 103 | - matching 104 | - russellrao 105 | - kulsinski 106 | - rogerstanimoto 107 | - sokalmichener 108 | - sokalsneath 109 | - yule 110 | 111 | **Metrics for bit-packed binary data** 112 | 113 | - bit_hamming 114 | - bit_jaccard 115 | 116 | and also custom user-defined distance metrics while still retaining performance. 117 | 118 | PyNNDescent also integrates well with Scikit-learn, including providing support 119 | for the KNeighborsTransformer as a drop-in replacement for algorithms 120 | that make use of nearest neighbor computations. 121 | 122 | Installing 123 | ---------- 124 | 125 | PyNNDescent is designed to be easy to install being a pure python module with 126 | relatively light requirements: 127 | 128 | * numpy 129 | * scipy 130 | * scikit-learn >= 0.22 131 | * numba >= 0.51 132 | 133 | all of which should be pip or conda installable. The easiest way to install should be 134 | via conda: 135 | 136 | .. code:: bash 137 | 138 | conda install -c conda-forge pynndescent 139 | 140 | or via pip: 141 | 142 | .. code:: bash 143 | 144 | pip install pynndescent 145 | 146 | 147 | .. toctree:: 148 | :maxdepth: 2 149 | :caption: User Guide / Tutorial: 150 | 151 | how_to_use_pynndescent 152 | pynndescent_metrics 153 | sparse_data_with_pynndescent 154 | pynndescent_in_pipelines 155 | 156 | .. toctree:: 157 | :maxdepth: 2 158 | :caption: Background 159 | 160 | how_pynndescent_works 161 | performance 162 | 163 | ..
toctree:: 164 | :caption: API Reference: 165 | 166 | api 167 | 168 | Indices and tables 169 | ================== 170 | 171 | * :ref:`genindex` 172 | * :ref:`modindex` 173 | * :ref:`search` 174 | -------------------------------------------------------------------------------- /azure-pipelines.yml: -------------------------------------------------------------------------------- 1 | # Trigger a build when there is a push to the main branch or a tag starts with release- 2 | trigger: 3 | branches: 4 | include: 5 | - master 6 | tags: 7 | include: 8 | - release-* 9 | 10 | # Trigger a build when there is a pull request to the main branch 11 | # Ignore PRs that are just updating the docs 12 | pr: 13 | branches: 14 | include: 15 | - master 16 | exclude: 17 | - doc/* 18 | - README.rst 19 | 20 | parameters: 21 | - name: includeReleaseCandidates 22 | displayName: "Allow pre-release dependencies" 23 | type: boolean 24 | default: false 25 | 26 | variables: 27 | triggeredByPullRequest: $[eq(variables['Build.Reason'], 'PullRequest')] 28 | 29 | stages: 30 | - stage: RunAllTests 31 | displayName: Run test suite 32 | jobs: 33 | - job: run_platform_tests 34 | strategy: 35 | matrix: 36 | 37 | mac_py310: 38 | imageName: 'macOS-latest' 39 | python.version: '3.10' 40 | linux_py310: 41 | imageName: 'ubuntu-latest' 42 | python.version: '3.10' 43 | windows_py310: 44 | imageName: 'windows-latest' 45 | python.version: '3.10' 46 | mac_py311: 47 | imageName: 'macOS-latest' 48 | python.version: '3.11' 49 | linux_py311: 50 | imageName: 'ubuntu-latest' 51 | python.version: '3.11' 52 | windows_py311: 53 | imageName: 'windows-latest' 54 | python.version: '3.11' 55 | mac_py312: 56 | imageName: 'macOS-latest' 57 | python.version: '3.12' 58 | linux_py312: 59 | imageName: 'ubuntu-latest' 60 | python.version: '3.12' 61 | windows_py312: 62 | imageName: 'windows-latest' 63 | python.version: '3.12' 64 | mac_py313: 65 | imageName: 'macOS-latest' 66 | python.version: '3.13' 67 | linux_py313: 68 | imageName: 'ubuntu-latest' 69 | python.version: '3.13' 70 | windows_py313: 71 | imageName: 'windows-latest' 72 | python.version: '3.13' 73 | pool: 74 | vmImage: $(imageName) 75 | 76 | steps: 77 | - task: UsePythonVersion@0 78 | inputs: 79 | versionSpec: '$(python.version)' 80 | displayName: 'Use Python $(python.version)' 81 | 82 | - script: | 83 | python -m pip install --upgrade pip 84 | displayName: 'Upgrade pip' 85 | 86 | - script: | 87 | pip install -r requirements.txt 88 | displayName: 'Install dependencies' 89 | condition: ${{ eq(parameters.includeReleaseCandidates, false) }} 90 | 91 | - script: | 92 | pip install --pre -r requirements.txt 93 | displayName: 'Install dependencies (allow pre-releases)' 94 | condition: ${{ eq(parameters.includeReleaseCandidates, true) }} 95 | 96 | - script: | 97 | pip install -e . 
98 | pip install pytest pytest-azurepipelines 99 | pip install pytest-cov 100 | displayName: 'Install package' 101 | 102 | - script: | 103 | pytest pynndescent/tests --show-capture=no -v --disable-warnings --junitxml=junit/test-results.xml --cov=pynndescent/ --cov-report=xml --cov-report=html 104 | displayName: 'Run tests' 105 | 106 | - task: PublishTestResults@2 107 | inputs: 108 | testResultsFiles: '$(System.DefaultWorkingDirectory)/**/coverage.xml' 109 | testRunTitle: '$(Agent.OS) - $(Build.BuildNumber)[$(Agent.JobName)] - Python $(python.version)' 110 | condition: succeededOrFailed() 111 | 112 | - stage: BuildPublishArtifact 113 | dependsOn: RunAllTests 114 | condition: and(succeeded(), startsWith(variables['Build.SourceBranch'], 'refs/tags/release-'), eq(variables.triggeredByPullRequest, false)) 115 | jobs: 116 | - job: BuildArtifacts 117 | displayName: Build source dists and wheels 118 | pool: 119 | vmImage: 'ubuntu-latest' 120 | steps: 121 | - task: UsePythonVersion@0 122 | inputs: 123 | versionSpec: '3.10' 124 | displayName: 'Use Python 3.10' 125 | 126 | - script: | 127 | python -m pip install --upgrade pip 128 | pip install wheel 129 | pip install -r requirements.txt 130 | displayName: 'Install dependencies' 131 | 132 | - script: | 133 | pip install -e . 134 | displayName: 'Install package locally' 135 | 136 | - script: | 137 | pip install build 138 | python -m build --wheel --sdist --outdir dist/ . 139 | displayName: 'Build package' 140 | 141 | - bash: | 142 | export PACKAGE_VERSION="$(python setup.py --version)" 143 | echo "Package Version: ${PACKAGE_VERSION}" 144 | echo "##vso[task.setvariable variable=packageVersionFormatted;]release-${PACKAGE_VERSION}" 145 | displayName: 'Get package version' 146 | 147 | - script: | 148 | echo "Version in git tag $(Build.SourceBranchName) does not match version derived from setup.py $(packageVersionFormatted)" 149 | exit 1 150 | displayName: Raise error if version doesnt match tag 151 | condition: and(succeeded(), ne(variables['Build.SourceBranchName'], variables['packageVersionFormatted'])) 152 | 153 | - task: DownloadSecureFile@1 154 | name: PYPIRC_CONFIG 155 | displayName: 'Download pypirc' 156 | inputs: 157 | secureFile: 'pypirc' 158 | 159 | - script: | 160 | pip install twine 161 | twine upload --repository pypi --config-file $(PYPIRC_CONFIG.secureFilePath) dist/* 162 | displayName: 'Upload to PyPI' 163 | condition: and(succeeded(), eq(variables['Build.SourceBranchName'], variables['packageVersionFormatted'])) 164 | 165 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | .. image:: doc/pynndescent_logo.png 2 | :width: 600 3 | :align: center 4 | :alt: PyNNDescent Logo 5 | 6 | .. image:: https://dev.azure.com/TutteInstitute/build-pipelines/_apis/build/status%2Flmcinnes.pynndescent?branchName=master 7 | :target: https://dev.azure.com/TutteInstitute/build-pipelines/_build?definitionId=17 8 | :alt: Azure Pipelines Build Status 9 | .. image:: https://readthedocs.org/projects/pynndescent/badge/?version=latest 10 | :target: https://pynndescent.readthedocs.io/en/latest/?badge=latest 11 | :alt: Documentation Status 12 | 13 | =========== 14 | PyNNDescent 15 | =========== 16 | 17 | PyNNDescent is a Python nearest neighbor descent for approximate nearest neighbors. 
18 | It provides a Python implementation of Nearest Neighbor 19 | Descent for k-neighbor-graph construction and approximate nearest neighbor 20 | search, as per the paper: 21 | 22 | Dong, Wei, Charikar Moses, and Kai Li. 23 | *"Efficient k-nearest neighbor graph construction for generic similarity 24 | measures."* 25 | Proceedings of the 20th international conference on World wide web. ACM, 2011. 26 | 27 | This library supplements that approach with the use of random projection trees for 28 | initialisation. This can be particularly useful for the metrics that are 29 | amenable to such approaches (euclidean, minkowski, angular, cosine, etc.). Graph 30 | diversification is also performed, pruning the longest edges of any triangles in the 31 | graph. 32 | 33 | Currently this library targets relatively high accuracy 34 | (80%-100% accuracy rate) approximate nearest neighbor searches. 35 | 36 | -------------------- 37 | Why use PyNNDescent? 38 | -------------------- 39 | 40 | PyNNDescent provides fast approximate nearest neighbor queries. The 41 | `ann-benchmarks `_ system puts it 42 | solidly in the mix of top performing ANN libraries: 43 | 44 | **SIFT-128 Euclidean** 45 | 46 | .. image:: https://pynndescent.readthedocs.io/en/latest/_images/sift.png 47 | :alt: ANN benchmark performance for SIFT 128 dataset 48 | 49 | **NYTimes-256 Angular** 50 | 51 | .. image:: https://pynndescent.readthedocs.io/en/latest/_images/nytimes.png 52 | :alt: ANN benchmark performance for NYTimes 256 dataset 53 | 54 | While PyNNDescent is among the fastest ANN libraries, it is also both easy to install (pip 55 | and conda installable) with no platform or compilation issues, and very flexible, 56 | supporting a wide variety of distance metrics by default: 57 | 58 | **Minkowski style metrics** 59 | 60 | - euclidean 61 | - manhattan 62 | - chebyshev 63 | - minkowski 64 | 65 | **Miscellaneous spatial metrics** 66 | 67 | - canberra 68 | - braycurtis 69 | - haversine 70 | 71 | **Normalized spatial metrics** 72 | 73 | - mahalanobis 74 | - wminkowski 75 | - seuclidean 76 | 77 | **Angular and correlation metrics** 78 | 79 | - cosine 80 | - dot 81 | - correlation 82 | - spearmanr 83 | - tsss 84 | - true_angular 85 | 86 | **Probability metrics** 87 | 88 | - hellinger 89 | - wasserstein 90 | 91 | **Metrics for binary data** 92 | 93 | - hamming 94 | - jaccard 95 | - dice 96 | - russellrao 97 | - kulsinski 98 | - rogerstanimoto 99 | - sokalmichener 100 | - sokalsneath 101 | - yule 102 | 103 | and also custom user-defined distance metrics while still retaining performance. 104 | 105 | PyNNDescent also integrates well with Scikit-learn, including providing support 106 | for the KNeighborsTransformer as a drop-in replacement for algorithms 107 | that make use of nearest neighbor computations. 108 | 109 | ---------------------- 110 | How to use PyNNDescent 111 | ---------------------- 112 | 113 | PyNNDescent aims to have a very simple interface. It is similar to (but more 114 | limited than) KDTrees and BallTrees in ``sklearn``. In practice there are 115 | only two operations -- index construction, and querying an index for nearest 116 | neighbors. 117 | 118 | To build a new search index on some training data ``data`` you can do something 119 | like 120 | 121 | .. code:: python 122 | 123 | from pynndescent import NNDescent 124 | index = NNDescent(data) 125 | 126 | You can then use the index for searching (and can pickle it to disk if you 127 | wish).
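Pickling follows the standard protocol, since the fitted index is an ordinary Python object. A minimal sketch of saving and reloading an index (the file name ``index.pkl`` is only illustrative):

.. code:: python

    import pickle

    # Save the fitted index to disk ...
    with open("index.pkl", "wb") as f:
        pickle.dump(index, f)

    # ... and load it back later, or in another process.
    with open("index.pkl", "rb") as f:
        index = pickle.load(f)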
To search a pynndescent index for the 15 nearest neighbors of a test data 128 | set ``query_data`` you can do something like 129 | 130 | .. code:: python 131 | 132 | index.query(query_data, k=15) 133 | 134 | and that is pretty much all there is to it. You can find more details in the 135 | `documentation `_. 136 | 137 | ---------- 138 | Installing 139 | ---------- 140 | 141 | PyNNDescent is designed to be easy to install being a pure python module with 142 | relatively light requirements: 143 | 144 | * numpy 145 | * scipy 146 | * scikit-learn >= 0.22 147 | * numba >= 0.51 148 | 149 | all of which should be pip or conda installable. The easiest way to install should be 150 | via conda: 151 | 152 | .. code:: bash 153 | 154 | conda install -c conda-forge pynndescent 155 | 156 | or via pip: 157 | 158 | .. code:: bash 159 | 160 | pip install pynndescent 161 | 162 | To manually install this package: 163 | 164 | .. code:: bash 165 | 166 | wget https://github.com/lmcinnes/pynndescent/archive/master.zip 167 | unzip master.zip 168 | rm master.zip 169 | cd pynndescent-master 170 | python setup.py install 171 | 172 | ---------------- 173 | Help and Support 174 | ---------------- 175 | 176 | This project is still young. The documentation is still growing. In the meantime please 177 | `open an issue `_ 178 | and I will try to provide any help and guidance that I can. Please also check 179 | the docstrings on the code, which provide some descriptions of the parameters. 180 | 181 | ------- 182 | License 183 | ------- 184 | 185 | The pynndescent package is 2-clause BSD licensed. Enjoy. 186 | 187 | ------------ 188 | Contributing 189 | ------------ 190 | 191 | Contributions are more than welcome! There are lots of opportunities 192 | for potential projects, so please get in touch if you would like to 193 | help out. Everything from code to notebooks to 194 | examples and documentation are all *equally valuable* so please don't feel 195 | you can't contribute. To contribute please `fork the project `_ make your changes and 196 | submit a pull request. We will do our best to work through any issues with 197 | you and get your code merged into the main branch. 
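As noted above, custom user-defined distance metrics are also supported. A minimal sketch of that usage, assuming the ``metric`` parameter accepts a numba-jitted two-argument function (the metric shown here is purely illustrative):

.. code:: python

    import numba
    import numpy as np
    from pynndescent import NNDescent

    @numba.njit(fastmath=True)
    def squared_euclidean(x, y):
        # Any symmetric dissimilarity of two 1-d arrays will do here.
        result = 0.0
        for i in range(x.shape[0]):
            diff = x[i] - y[i]
            result += diff * diff
        return result

    data = np.random.random((2000, 10)).astype(np.float32)
    index = NNDescent(data, metric=squared_euclidean, n_neighbors=15)
    neighbor_indices, neighbor_distances = index.query(data[:5], k=10)

The ``PyNNDescentTransformer`` mentioned above follows the scikit-learn transformer API; a hedged sketch of standalone use (downstream estimators that accept a precomputed sparse neighbor graph can then consume the result in a ``Pipeline``):

.. code:: python

    import numpy as np
    from pynndescent import PyNNDescentTransformer

    data = np.random.random((2000, 10)).astype(np.float32)
    transformer = PyNNDescentTransformer(n_neighbors=15)
    # A sparse matrix of distances to each point's approximate nearest neighbors.
    neighbor_graph = transformer.fit_transform(data)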
198 | 199 | 200 | -------------------------------------------------------------------------------- /pynndescent/tests/test_rank.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import numpy as np 3 | from numpy.testing import assert_array_equal 4 | 5 | from pynndescent.distances import rankdata 6 | 7 | 8 | def test_empty(): 9 | """rankdata([]) should return an empty array.""" 10 | a = np.array([], dtype=int) 11 | r = rankdata(a) 12 | assert_array_equal(r, np.array([], dtype=np.float64)) 13 | 14 | 15 | def test_one(): 16 | """Check rankdata with an array of length 1.""" 17 | data = [100] 18 | a = np.array(data, dtype=int) 19 | r = rankdata(a) 20 | assert_array_equal(r, np.array([1.0], dtype=np.float64)) 21 | 22 | 23 | def test_basic(): 24 | """Basic tests of rankdata.""" 25 | data = [100, 10, 50] 26 | expected = np.array([3.0, 1.0, 2.0], dtype=np.float64) 27 | a = np.array(data, dtype=int) 28 | r = rankdata(a) 29 | assert_array_equal(r, expected) 30 | 31 | data = [40, 10, 30, 10, 50] 32 | expected = np.array([4.0, 1.5, 3.0, 1.5, 5.0], dtype=np.float64) 33 | a = np.array(data, dtype=int) 34 | r = rankdata(a) 35 | assert_array_equal(r, expected) 36 | 37 | data = [20, 20, 20, 10, 10, 10] 38 | expected = np.array([5.0, 5.0, 5.0, 2.0, 2.0, 2.0], dtype=np.float64) 39 | a = np.array(data, dtype=int) 40 | r = rankdata(a) 41 | assert_array_equal(r, expected) 42 | # The docstring states explicitly that the argument is flattened. 43 | a2d = a.reshape(2, 3) 44 | r = rankdata(a2d) 45 | assert_array_equal(r, expected) 46 | 47 | 48 | def test_rankdata_object_string(): 49 | min_rank = lambda a: [1 + sum(i < j for i in a) for j in a] 50 | max_rank = lambda a: [sum(i <= j for i in a) for j in a] 51 | ordinal_rank = lambda a: min_rank([(x, i) for i, x in enumerate(a)]) 52 | 53 | def average_rank(a): 54 | return np.array([(i + j) / 2.0 for i, j in zip(min_rank(a), max_rank(a))]) 55 | 56 | def dense_rank(a): 57 | b = np.unique(a) 58 | return np.array([1 + sum(i < j for i in b) for j in a]) 59 | 60 | rankf = dict( 61 | min=min_rank, 62 | max=max_rank, 63 | ordinal=ordinal_rank, 64 | average=average_rank, 65 | dense=dense_rank, 66 | ) 67 | 68 | def check_ranks(a): 69 | for method in "min", "max", "dense", "ordinal", "average": 70 | out = rankdata(a, method=method) 71 | assert_array_equal(out, rankf[method](a)) 72 | 73 | check_ranks(np.random.uniform(size=[200])) 74 | 75 | 76 | def test_large_int(): 77 | data = np.array([2**60, 2**60 + 1], dtype=np.uint64) 78 | r = rankdata(data) 79 | assert_array_equal(r, [1.0, 2.0]) 80 | 81 | data = np.array([2**60, 2**60 + 1], dtype=np.int64) 82 | r = rankdata(data) 83 | assert_array_equal(r, [1.0, 2.0]) 84 | 85 | data = np.array([2**60, -(2**60) + 1], dtype=np.int64) 86 | r = rankdata(data) 87 | assert_array_equal(r, [2.0, 1.0]) 88 | 89 | 90 | def test_big_tie(): 91 | for n in [10000, 100000, 1000000]: 92 | data = np.ones(n, dtype=int) 93 | r = rankdata(data) 94 | expected_rank = 0.5 * (n + 1) 95 | assert_array_equal(r, expected_rank * data, "test failed with n=%d" % n) 96 | 97 | 98 | @pytest.mark.parametrize( 99 | "values,method,expected", 100 | [ # values, method, expected 101 | (np.array([], np.float64), "average", np.array([], np.float64)), 102 | (np.array([], np.float64), "min", np.array([], np.float64)), 103 | (np.array([], np.float64), "max", np.array([], np.float64)), 104 | (np.array([], np.float64), "dense", np.array([], np.float64)), 105 | (np.array([], np.float64), "ordinal", np.array([], np.float64)), 106 | 
# 107 | (np.array([100], np.float64), "average", np.array([1.0], np.float64)), 108 | (np.array([100], np.float64), "min", np.array([1.0], np.float64)), 109 | (np.array([100], np.float64), "max", np.array([1.0], np.float64)), 110 | (np.array([100], np.float64), "dense", np.array([1.0], np.float64)), 111 | (np.array([100], np.float64), "ordinal", np.array([1.0], np.float64)), 112 | # # 113 | ( 114 | np.array([100, 100, 100], np.float64), 115 | "average", 116 | np.array([2.0, 2.0, 2.0], np.float64), 117 | ), 118 | ( 119 | np.array([100, 100, 100], np.float64), 120 | "min", 121 | np.array([1.0, 1.0, 1.0], np.float64), 122 | ), 123 | ( 124 | np.array([100, 100, 100], np.float64), 125 | "max", 126 | np.array([3.0, 3.0, 3.0], np.float64), 127 | ), 128 | ( 129 | np.array([100, 100, 100], np.float64), 130 | "dense", 131 | np.array([1.0, 1.0, 1.0], np.float64), 132 | ), 133 | ( 134 | np.array([100, 100, 100], np.float64), 135 | "ordinal", 136 | np.array([1.0, 2.0, 3.0], np.float64), 137 | ), 138 | # 139 | ( 140 | np.array([100, 300, 200], np.float64), 141 | "average", 142 | np.array([1.0, 3.0, 2.0], np.float64), 143 | ), 144 | ( 145 | np.array([100, 300, 200], np.float64), 146 | "min", 147 | np.array([1.0, 3.0, 2.0], np.float64), 148 | ), 149 | ( 150 | np.array([100, 300, 200], np.float64), 151 | "max", 152 | np.array([1.0, 3.0, 2.0], np.float64), 153 | ), 154 | ( 155 | np.array([100, 300, 200], np.float64), 156 | "dense", 157 | np.array([1.0, 3.0, 2.0], np.float64), 158 | ), 159 | ( 160 | np.array([100, 300, 200], np.float64), 161 | "ordinal", 162 | np.array([1.0, 3.0, 2.0], np.float64), 163 | ), 164 | # 165 | ( 166 | np.array([100, 200, 300, 200], np.float64), 167 | "average", 168 | np.array([1.0, 2.5, 4.0, 2.5], np.float64), 169 | ), 170 | ( 171 | np.array([100, 200, 300, 200], np.float64), 172 | "min", 173 | np.array([1.0, 2.0, 4.0, 2.0], np.float64), 174 | ), 175 | ( 176 | np.array([100, 200, 300, 200], np.float64), 177 | "max", 178 | np.array([1.0, 3.0, 4.0, 3.0], np.float64), 179 | ), 180 | ( 181 | np.array([100, 200, 300, 200], np.float64), 182 | "dense", 183 | np.array([1.0, 2.0, 3.0, 2.0], np.float64), 184 | ), 185 | ( 186 | np.array([100, 200, 300, 200], np.float64), 187 | "ordinal", 188 | np.array([1.0, 2.0, 4.0, 3.0], np.float64), 189 | ), 190 | # 191 | ( 192 | np.array([100, 200, 300, 200, 100], np.float64), 193 | "average", 194 | np.array([1.5, 3.5, 5.0, 3.5, 1.5], np.float64), 195 | ), 196 | ( 197 | np.array([100, 200, 300, 200, 100], np.float64), 198 | "min", 199 | np.array([1.0, 3.0, 5.0, 3.0, 1.0], np.float64), 200 | ), 201 | ( 202 | np.array([100, 200, 300, 200, 100], np.float64), 203 | "max", 204 | np.array([2.0, 4.0, 5.0, 4.0, 2.0], np.float64), 205 | ), 206 | ( 207 | np.array([100, 200, 300, 200, 100], np.float64), 208 | "dense", 209 | np.array([1.0, 2.0, 3.0, 2.0, 1.0], np.float64), 210 | ), 211 | ( 212 | np.array([100, 200, 300, 200, 100], np.float64), 213 | "ordinal", 214 | np.array([1.0, 3.0, 5.0, 4.0, 2.0], np.float64), 215 | ), 216 | # 217 | ( 218 | np.array([10] * 30, np.float64), 219 | "ordinal", 220 | np.arange(1.0, 31.0, dtype=np.float64), 221 | ), 222 | ], 223 | ) 224 | def test_cases(values, method, expected): 225 | r = rankdata(values, method=method) 226 | assert_array_equal(r, expected) 227 | -------------------------------------------------------------------------------- /pynndescent/graph_utils.py: -------------------------------------------------------------------------------- 1 | import numba 2 | import numpy as np 3 | import heapq 4 | 5 | from scipy.sparse 
import coo_matrix 6 | from scipy.sparse.csgraph import connected_components 7 | from itertools import combinations 8 | 9 | import pynndescent.distances as pynnd_dist 10 | import joblib 11 | 12 | from pynndescent.utils import ( 13 | rejection_sample, 14 | make_heap, 15 | deheap_sort, 16 | simple_heap_push, 17 | has_been_visited, 18 | mark_visited, 19 | ) 20 | 21 | FLOAT32_EPS = np.finfo(np.float32).eps 22 | 23 | 24 | def create_component_search(index): 25 | alternative_dot = pynnd_dist.alternative_dot 26 | alternative_cosine = pynnd_dist.alternative_cosine 27 | 28 | data = index._raw_data 29 | indptr = index._search_graph.indptr 30 | indices = index._search_graph.indices 31 | dist = index._distance_func 32 | 33 | @numba.njit( 34 | fastmath=True, 35 | nogil=True, 36 | locals={ 37 | "current_query": numba.types.float32[::1], 38 | "i": numba.types.uint32, 39 | "j": numba.types.uint32, 40 | "heap_priorities": numba.types.float32[::1], 41 | "heap_indices": numba.types.int32[::1], 42 | "candidate": numba.types.int32, 43 | "vertex": numba.types.int32, 44 | "d": numba.types.float32, 45 | "d_vertex": numba.types.float32, 46 | "visited": numba.types.uint8[::1], 47 | "indices": numba.types.int32[::1], 48 | "indptr": numba.types.int32[::1], 49 | "data": numba.types.float32[:, ::1], 50 | "heap_size": numba.types.int16, 51 | "distance_scale": numba.types.float32, 52 | "distance_bound": numba.types.float32, 53 | "seed_scale": numba.types.float32, 54 | }, 55 | ) 56 | def custom_search_closure(query_points, candidate_indices, k, epsilon, visited): 57 | result = make_heap(query_points.shape[0], k) 58 | distance_scale = 1.0 + epsilon 59 | 60 | for i in range(query_points.shape[0]): 61 | visited[:] = 0 62 | if dist == alternative_dot or dist == alternative_cosine: 63 | norm = np.sqrt((query_points[i] ** 2).sum()) 64 | if norm > 0.0: 65 | current_query = query_points[i] / norm 66 | else: 67 | continue 68 | else: 69 | current_query = query_points[i] 70 | 71 | heap_priorities = result[1][i] 72 | heap_indices = result[0][i] 73 | seed_set = [(np.float32(np.inf), np.int32(-1)) for j in range(0)] 74 | 75 | ############ Init ################ 76 | n_initial_points = candidate_indices.shape[0] 77 | 78 | for j in range(n_initial_points): 79 | candidate = np.int32(candidate_indices[j]) 80 | d = dist(data[candidate], current_query) 81 | # indices are guaranteed different 82 | simple_heap_push(heap_priorities, heap_indices, d, candidate) 83 | heapq.heappush(seed_set, (d, candidate)) 84 | mark_visited(visited, candidate) 85 | 86 | ############ Search ############## 87 | distance_bound = distance_scale * heap_priorities[0] 88 | 89 | # Find smallest seed point 90 | d_vertex, vertex = heapq.heappop(seed_set) 91 | 92 | while d_vertex < distance_bound: 93 | 94 | for j in range(indptr[vertex], indptr[vertex + 1]): 95 | 96 | candidate = indices[j] 97 | 98 | if has_been_visited(visited, candidate) == 0: 99 | mark_visited(visited, candidate) 100 | 101 | d = dist(data[candidate], current_query) 102 | 103 | if d < distance_bound: 104 | simple_heap_push( 105 | heap_priorities, heap_indices, d, candidate 106 | ) 107 | heapq.heappush(seed_set, (d, candidate)) 108 | # Update bound 109 | distance_bound = distance_scale * heap_priorities[0] 110 | 111 | # find new smallest seed point 112 | if len(seed_set) == 0: 113 | break 114 | else: 115 | d_vertex, vertex = heapq.heappop(seed_set) 116 | 117 | return result 118 | 119 | return custom_search_closure 120 | 121 | 122 | # @numba.njit(nogil=True) 123 | def find_component_connection_edge( 124 | 
component1, 125 | component2, 126 | search_closure, 127 | raw_data, 128 | visited, 129 | rng_state, 130 | search_size=10, 131 | epsilon=0.0, 132 | ): 133 | indices = [np.zeros(1, dtype=np.int64) for i in range(2)] 134 | indices[0] = component1[ 135 | rejection_sample(np.int64(search_size), component1.shape[0], rng_state) 136 | ] 137 | indices[1] = component2[ 138 | rejection_sample(np.int64(search_size), component2.shape[0], rng_state) 139 | ] 140 | query_side = 0 141 | query_points = raw_data[indices[query_side]] 142 | candidate_indices = indices[1 - query_side].copy() 143 | changed = [True, True] 144 | best_dist = np.inf 145 | best_edge = (indices[0][0], indices[1][0]) 146 | 147 | while changed[0] or changed[1]: 148 | inds, dists, _ = search_closure( 149 | query_points, candidate_indices, search_size, epsilon, visited 150 | ) 151 | inds, dists = deheap_sort(inds, dists) 152 | for i in range(dists.shape[0]): 153 | for j in range(dists.shape[1]): 154 | if dists[i, j] < best_dist: 155 | best_dist = dists[i, j] 156 | best_edge = (indices[query_side][i], inds[i, j]) 157 | candidate_indices = indices[query_side] 158 | new_indices = np.unique(inds[:, 0]) 159 | if indices[1 - query_side].shape[0] == new_indices.shape[0]: 160 | changed[1 - query_side] = np.any(indices[1 - query_side] != new_indices) 161 | indices[1 - query_side] = new_indices 162 | query_points = raw_data[indices[1 - query_side]] 163 | query_side = 1 - query_side 164 | 165 | return best_edge[0], best_edge[1], best_dist 166 | 167 | 168 | def adjacency_matrix_representation(neighbor_indices, neighbor_distances): 169 | result = coo_matrix( 170 | (neighbor_indices.shape[0], neighbor_indices.shape[0]), dtype=np.float32 171 | ) 172 | 173 | # Preserve any distance 0 points 174 | neighbor_distances[neighbor_distances == 0.0] = FLOAT32_EPS 175 | 176 | result.row = np.repeat( 177 | np.arange(neighbor_indices.shape[0], dtype=np.int32), neighbor_indices.shape[1] 178 | ) 179 | result.col = neighbor_indices.ravel() 180 | result.data = neighbor_distances.ravel() 181 | 182 | # Get rid of any -1 index entries 183 | result = result.tocsr() 184 | result.data[result.indices == -1] = 0.0 185 | result.eliminate_zeros() 186 | 187 | # Symmetrize 188 | result = result.maximum(result.T) 189 | 190 | return result 191 | 192 | 193 | def connect_graph(graph, index, search_size=10, n_jobs=None): 194 | 195 | search_closure = create_component_search(index) 196 | n_components, component_ids = connected_components(graph) 197 | result = graph.tolil() 198 | 199 | # Translate component ids into internal vertex order 200 | component_ids = component_ids[index._vertex_order] 201 | 202 | def new_edge(c1, c2): 203 | component1 = np.where(component_ids == c1)[0] 204 | component2 = np.where(component_ids == c2)[0] 205 | 206 | i, j, d = find_component_connection_edge( 207 | component1, 208 | component2, 209 | search_closure, 210 | index._raw_data, 211 | index._visited, 212 | index.rng_state, 213 | search_size=search_size, 214 | ) 215 | 216 | # Correct the distance if required 217 | if index._distance_correction is not None: 218 | d = index._distance_correction(d) 219 | 220 | # Convert indices to original data order 221 | i = index._vertex_order[i] 222 | j = index._vertex_order[j] 223 | 224 | return i, j, d 225 | 226 | new_edges = joblib.Parallel(n_jobs=n_jobs, prefer="threads")( 227 | joblib.delayed(new_edge)(c1, c2) 228 | for c1, c2 in combinations(range(n_components), 2) 229 | ) 230 | 231 | for i, j, d in new_edges: 232 | result[i, j] = d 233 | result[j, i] = d 234 | 235 | 
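    # An edge has now been added between every pair of components, so the returned graph is fully connected.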
return result.tocsr() 236 | -------------------------------------------------------------------------------- /doc/sparse_data_with_pynndescent.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Working with sparse data\n", 8 | "\n", 9 | "Not all data conveniently fits in numpy arrays; sometimes a lot of the data entries are zero and we want to use a sparse data storage format. This is especially common for extremely high dimensional data (data with thousands, or even hundreds of thousands of dimensions). Such data is a lot harder to work with for many tasks, including nearest neighbor search. Let's see how we can work with sparse data like this in PyNNDescent.\n", 10 | "\n", 11 | "First we'll need some data. For that let's use a standard NLP dataset that we can pull together with scikit-learn." 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 9, 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "import pynndescent\n", 21 | "import sklearn.datasets\n", 22 | "import sys" 23 | ] 24 | }, 25 | { 26 | "cell_type": "markdown", 27 | "metadata": {}, 28 | "source": [ 29 | "We need to fetch the train and test sets separately, but conveniently the data has already been converted from the text of newsgroup messages into [TF-IDF](https://en.wikipedia.org/wiki/Tf%E2%80%93idf) matrices. This means that we have a feature column for each word in the vocabulary -- though often the vocabulary is pruned a little." 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": 2, 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "news_train = sklearn.datasets.fetch_20newsgroups_vectorized(subset='train')\n", 39 | "news_test = sklearn.datasets.fetch_20newsgroups_vectorized(subset='test')" 40 | ] 41 | }, 42 | { 43 | "cell_type": "markdown", 44 | "metadata": {}, 45 | "source": [ 46 | "Now that we have the data let's see what it looks like:" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": 3, 52 | "metadata": {}, 53 | "outputs": [ 54 | { 55 | "data": { 56 | "text/plain": [ 57 | "<11314x130107 sparse matrix of type ''\n", 58 | "\twith 1787565 stored elements in Compressed Sparse Row format>" 59 | ] 60 | }, 61 | "execution_count": 3, 62 | "metadata": {}, 63 | "output_type": "execute_result" 64 | } 65 | ], 66 | "source": [ 67 | "news_train.data" 68 | ] 69 | }, 70 | { 71 | "cell_type": "markdown", 72 | "metadata": {}, 73 | "source": [ 74 | "Not a numpy array! Instead it is a [SciPy sparse matrix](https://docs.scipy.org/doc/scipy/reference/sparse.html). It has 11314 rows (so not many data samples), but 130107 columns (a *lot* of features -- as noted, one for each word in the vocabulary). Despite that large size there are only 17876565 non-zero entries. The trick with sparse matrices is that they only store information about those entries that aren't zero -- they need to keep track of where they are, and what the value is, but they can ignore all the zero entries. If this were a raw numpy array with all those zeros in place it would have ... 
" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": 7, 80 | "metadata": {}, 81 | "outputs": [ 82 | { 83 | "data": { 84 | "text/plain": [ 85 | "1472030598" 86 | ] 87 | }, 88 | "execution_count": 7, 89 | "metadata": {}, 90 | "output_type": "execute_result" 91 | } 92 | ], 93 | "source": [ 94 | "11314 * 130107" 95 | ] 96 | }, 97 | { 98 | "cell_type": "markdown", 99 | "metadata": {}, 100 | "source": [ 101 | "... a lot of entries. To store all of that in memory would require (at 4 bytes per entry) ..." 102 | ] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "execution_count": 8, 107 | "metadata": {}, 108 | "outputs": [ 109 | { 110 | "data": { 111 | "text/plain": [ 112 | "5.48374130576849" 113 | ] 114 | }, 115 | "execution_count": 8, 116 | "metadata": {}, 117 | "output_type": "execute_result" 118 | } 119 | ], 120 | "source": [ 121 | "(11314 * 130107 * 4) / 1024**3" 122 | ] 123 | }, 124 | { 125 | "cell_type": "markdown", 126 | "metadata": {}, 127 | "source": [ 128 | "... almost 5.5 GB! That is possible, but likely impractical on a laptop. And this is for a case with a small number of data samples. With more samples the size would grow enormous very quickly indeed. Instead we have an object that uses ..." 129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "execution_count": 13, 134 | "metadata": {}, 135 | "outputs": [ 136 | { 137 | "data": { 138 | "text/plain": [ 139 | "15.385956764221191" 140 | ] 141 | }, 142 | "execution_count": 13, 143 | "metadata": {}, 144 | "output_type": "execute_result" 145 | } 146 | ], 147 | "source": [ 148 | "(\n", 149 | " news_train.data.data.nbytes \n", 150 | " + news_train.data.indices.size \n", 151 | " + news_train.data.indptr.nbytes\n", 152 | ") / 1024**2" 153 | ] 154 | }, 155 | { 156 | "cell_type": "markdown", 157 | "metadata": {}, 158 | "source": [ 159 | "... only 15 MB. You will also note that to extract that information required poking at some of the internal attributes of the sparse matrix (``data``, ``indices``, and ``indptr``). This is the downside of the sparse format -- they are more complicated to work with. It is certainly the case that many tools are simply not able to deal with these sparse structures at all, and the data would need to be cast to a numpy array and take up that full amount of memory.\n", 160 | "\n", 161 | "Fortunately PyNNDescent is built to work with sparse matrix data. To see that in practice let's hand the sparse matrix directly to ``NNDescent`` and watch it work." 162 | ] 163 | }, 164 | { 165 | "cell_type": "code", 166 | "execution_count": 4, 167 | "metadata": {}, 168 | "outputs": [ 169 | { 170 | "name": "stdout", 171 | "output_type": "stream", 172 | "text": [ 173 | "CPU times: user 7min 3s, sys: 1.77 s, total: 7min 5s\n", 174 | "Wall time: 2min 24s\n" 175 | ] 176 | } 177 | ], 178 | "source": [ 179 | "%%time\n", 180 | "index = pynndescent.NNDescent(news_train.data, metric=\"cosine\")" 181 | ] 182 | }, 183 | { 184 | "cell_type": "markdown", 185 | "metadata": {}, 186 | "source": [ 187 | "You will note that this is a much longer index construction time than we would normally expect with only around eleven thousand data points -- there is overhead in working with sparse matrices that makes it slower. That, combined with the fact that the data has over a hundred and thirty-thousand dimensions means this is a computationally intensive task. 
Still it is likely better than working with the full 5.5 GB numpy array, and certainly better when dealing with larger sparse matrices where there is simply no way to instantiate a numpy array large enough to hold the data.\n", 188 | "\n", 189 | "We can query the index -- but we have to use the same sparse matrix structure (we can't query with numpy arrays for am index built with sparse data). Fortunately the test data is already in that format so we can simply perform the query:" 190 | ] 191 | }, 192 | { 193 | "cell_type": "code", 194 | "execution_count": 5, 195 | "metadata": {}, 196 | "outputs": [ 197 | { 198 | "name": "stdout", 199 | "output_type": "stream", 200 | "text": [ 201 | "CPU times: user 1min 25s, sys: 1.5 s, total: 1min 26s\n", 202 | "Wall time: 1min 26s\n" 203 | ] 204 | } 205 | ], 206 | "source": [ 207 | "%%time\n", 208 | "neighbors = index.query(news_test.data)" 209 | ] 210 | }, 211 | { 212 | "cell_type": "markdown", 213 | "metadata": {}, 214 | "source": [ 215 | "And that's it! Everything works essentially transparently with sparse data -- it is just slower. Still, slower is a lot better than not working at all." 216 | ] 217 | }, 218 | { 219 | "cell_type": "code", 220 | "execution_count": 6, 221 | "metadata": {}, 222 | "outputs": [ 223 | { 224 | "data": { 225 | "text/plain": [ 226 | "(array([[ 1635, 4487, 9220, ..., 10071, 9572, 1793],\n", 227 | " [ 8648, 567, 2123, ..., 783, 6031, 9275],\n", 228 | " [ 7345, 4852, 4674, ..., 1679, 7518, 4228],\n", 229 | " ...,\n", 230 | " [ 6137, 10518, 6469, ..., 6937, 11083, 6164],\n", 231 | " [ 2926, 2011, 1679, ..., 7229, 1635, 11270],\n", 232 | " [ 3215, 3665, 4899, ..., 10810, 9907, 9311]], dtype=int32),\n", 233 | " array([[0.34316891, 0.34799916, 0.35074973, ..., 0.35952544, 0.36456954,\n", 234 | " 0.36472946],\n", 235 | " [0.25881797, 0.26937056, 0.28056568, ..., 0.29210907, 0.29795945,\n", 236 | " 0.29900843],\n", 237 | " [0.41124642, 0.42213005, 0.4426493 , ..., 0.46922857, 0.4724912 ,\n", 238 | " 0.47429329],\n", 239 | " ...,\n", 240 | " [0.21533132, 0.22482073, 0.24046886, ..., 0.26193857, 0.26805884,\n", 241 | " 0.26866162],\n", 242 | " [0.19485909, 0.19515198, 0.19891578, ..., 0.20851403, 0.21159202,\n", 243 | " 0.21265447],\n", 244 | " [0.43528455, 0.43638128, 0.4378109 , ..., 0.45176154, 0.452402 ,\n", 245 | " 0.45243692]]))" 246 | ] 247 | }, 248 | "execution_count": 6, 249 | "metadata": {}, 250 | "output_type": "execute_result" 251 | } 252 | ], 253 | "source": [ 254 | "neighbors" 255 | ] 256 | }, 257 | { 258 | "cell_type": "markdown", 259 | "metadata": {}, 260 | "source": [ 261 | "One final caveat is that custom distance metrics for sparse data need to be able to work with sparse data and thus have a different function signature. In practice this is really something you only want to try if you are familiar with working with sparse data structures. If that's the case then you can look through ``pynndescent.sparse.py`` for examples of many common distance functions and it will quickly become clear what is required." 
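To make this concrete, here is a minimal hand-written sketch (not code taken from ``pynndescent.sparse``) of the expected shape of such a metric: instead of two dense vectors it receives the column indices and the values of the two CSR rows being compared.

```python
import numba


# Hypothetical sparse Manhattan distance; assumes the column indices of each
# CSR row are sorted, as they are for canonical SciPy CSR matrices.
@numba.njit()
def sparse_manhattan_sketch(ind1, data1, ind2, data2):
    result = 0.0
    i = j = 0
    # Merge-walk the two sorted index arrays.
    while i < ind1.shape[0] and j < ind2.shape[0]:
        if ind1[i] == ind2[j]:
            result += abs(data1[i] - data2[j])
            i += 1
            j += 1
        elif ind1[i] < ind2[j]:
            # Column present only in the first row; the other value is zero.
            result += abs(data1[i])
            i += 1
        else:
            result += abs(data2[j])
            j += 1
    # Any remaining tail entries are compared against implicit zeros.
    while i < ind1.shape[0]:
        result += abs(data1[i])
        i += 1
    while j < ind2.shape[0]:
        result += abs(data2[j])
        j += 1
    return result
```

Assuming a numba-compiled callable like this is passed as the ``metric`` argument together with sparse input data, it would be called with exactly these four arguments; the functions in ``pynndescent.sparse`` follow the same pattern, with a handful of metrics additionally receiving the total number of features as a final argument.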
262 | ] 263 | } 264 | ], 265 | "metadata": { 266 | "kernelspec": { 267 | "display_name": "Python 3", 268 | "language": "python", 269 | "name": "python3" 270 | }, 271 | "language_info": { 272 | "codemirror_mode": { 273 | "name": "ipython", 274 | "version": 3 275 | }, 276 | "file_extension": ".py", 277 | "mimetype": "text/x-python", 278 | "name": "python", 279 | "nbconvert_exporter": "python", 280 | "pygments_lexer": "ipython3", 281 | "version": "3.8.0" 282 | } 283 | }, 284 | "nbformat": 4, 285 | "nbformat_minor": 4 286 | } 287 | -------------------------------------------------------------------------------- /pynndescent/sparse_nndescent.py: -------------------------------------------------------------------------------- 1 | # Author: Leland McInnes 2 | # Enough simple sparse operations in numba to enable sparse UMAP 3 | # 4 | # License: BSD 3 clause 5 | from __future__ import print_function 6 | import numpy as np 7 | import numba 8 | 9 | from pynndescent.utils import ( 10 | tau_rand_int, 11 | make_heap, 12 | new_build_candidates, 13 | deheap_sort, 14 | checked_flagged_heap_push, 15 | sparse_generate_graph_update_array, 16 | apply_graph_update_array, 17 | EMPTY_GRAPH, 18 | ) 19 | 20 | from pynndescent.sparse import sparse_euclidean 21 | 22 | 23 | @numba.njit(parallel=True, cache=False) 24 | def generate_leaf_updates( 25 | updates, 26 | n_updates_per_thread, 27 | leaf_block, 28 | dist_thresholds, 29 | inds, 30 | indptr, 31 | data, 32 | dist, 33 | n_threads, 34 | ): 35 | """Generate leaf updates into pre-allocated arrays for parallel efficiency.""" 36 | n_leaves = leaf_block.shape[0] 37 | leaves_per_thread = (n_leaves + n_threads - 1) // n_threads 38 | 39 | # Reset update counts 40 | for t in range(n_threads): 41 | n_updates_per_thread[t] = 0 42 | 43 | for t in numba.prange(n_threads): 44 | start_leaf = t * leaves_per_thread 45 | end_leaf = min(start_leaf + leaves_per_thread, n_leaves) 46 | max_updates = updates.shape[1] 47 | count = 0 48 | 49 | for leaf_idx in range(start_leaf, end_leaf): 50 | for i in range(leaf_block.shape[1]): 51 | p = leaf_block[leaf_idx, i] 52 | if p < 0: 53 | break 54 | 55 | for j in range(i + 1, leaf_block.shape[1]): 56 | q = leaf_block[leaf_idx, j] 57 | if q < 0: 58 | break 59 | 60 | from_inds = inds[indptr[p] : indptr[p + 1]] 61 | from_data = data[indptr[p] : indptr[p + 1]] 62 | 63 | to_inds = inds[indptr[q] : indptr[q + 1]] 64 | to_data = data[indptr[q] : indptr[q + 1]] 65 | d = dist(from_inds, from_data, to_inds, to_data) 66 | 67 | if d < dist_thresholds[p] or d < dist_thresholds[q]: 68 | if count < max_updates: 69 | updates[t, count, 0] = np.float32(p) 70 | updates[t, count, 1] = np.float32(q) 71 | updates[t, count, 2] = d 72 | count += 1 73 | 74 | n_updates_per_thread[t] = count 75 | 76 | return updates 77 | 78 | 79 | @numba.njit( 80 | locals={"d": numba.float32, "p": numba.int32, "q": numba.int32}, 81 | cache=False, 82 | parallel=True, 83 | ) 84 | def init_rp_tree(inds, indptr, data, dist, current_graph, leaf_array, n_threads=8): 85 | n_leaves = leaf_array.shape[0] 86 | block_size = n_threads * 64 87 | n_blocks = n_leaves // block_size 88 | 89 | max_leaf_size = leaf_array.shape[1] 90 | updates_per_thread = ( 91 | int(block_size * max_leaf_size * (max_leaf_size - 1) / (2 * n_threads)) + 1 92 | ) 93 | updates = np.zeros((n_threads, updates_per_thread, 3), dtype=np.float32) 94 | n_updates_per_thread = np.zeros(n_threads, dtype=np.int32) 95 | 96 | n_vertices = current_graph[0].shape[0] 97 | vertex_block_size = n_vertices // n_threads + 1 98 | 99 | for i in 
range(n_blocks + 1): 100 | block_start = i * block_size 101 | block_end = min(n_leaves, (i + 1) * block_size) 102 | 103 | leaf_block = leaf_array[block_start:block_end] 104 | dist_thresholds = current_graph[1][:, 0] 105 | 106 | generate_leaf_updates( 107 | updates, 108 | n_updates_per_thread, 109 | leaf_block, 110 | dist_thresholds, 111 | inds, 112 | indptr, 113 | data, 114 | dist, 115 | n_threads, 116 | ) 117 | 118 | for t in numba.prange(n_threads): 119 | v_block_start = t * vertex_block_size 120 | v_block_end = min(v_block_start + vertex_block_size, n_vertices) 121 | 122 | for j in range(n_threads): 123 | for k in range(n_updates_per_thread[j]): 124 | p = np.int32(updates[j, k, 0]) 125 | q = np.int32(updates[j, k, 1]) 126 | d = updates[j, k, 2] 127 | 128 | if p >= v_block_start and p < v_block_end: 129 | checked_flagged_heap_push( 130 | current_graph[1][p], 131 | current_graph[0][p], 132 | current_graph[2][p], 133 | d, 134 | q, 135 | np.uint8(1), 136 | ) 137 | if q >= v_block_start and q < v_block_end: 138 | checked_flagged_heap_push( 139 | current_graph[1][q], 140 | current_graph[0][q], 141 | current_graph[2][q], 142 | d, 143 | p, 144 | np.uint8(1), 145 | ) 146 | 147 | 148 | @numba.njit( 149 | fastmath=True, 150 | locals={"d": numba.float32, "i": numba.int32, "idx": numba.int32}, 151 | cache=False, 152 | ) 153 | def init_random(n_neighbors, inds, indptr, data, heap, dist, rng_state): 154 | n_samples = indptr.shape[0] - 1 155 | for i in range(n_samples): 156 | if heap[0][i, 0] < 0.0: 157 | for j in range(n_neighbors - np.sum(heap[0][i] >= 0.0)): 158 | idx = np.abs(tau_rand_int(rng_state)) % n_samples 159 | 160 | from_inds = inds[indptr[idx] : indptr[idx + 1]] 161 | from_data = data[indptr[idx] : indptr[idx + 1]] 162 | 163 | to_inds = inds[indptr[i] : indptr[i + 1]] 164 | to_data = data[indptr[i] : indptr[i + 1]] 165 | d = dist(from_inds, from_data, to_inds, to_data) 166 | 167 | checked_flagged_heap_push( 168 | heap[1][i], heap[0][i], heap[2][i], d, idx, np.uint8(1) 169 | ) 170 | 171 | return 172 | 173 | 174 | @numba.njit(cache=False) 175 | def sparse_process_candidates( 176 | inds, 177 | indptr, 178 | data, 179 | dist, 180 | current_graph, 181 | new_candidate_neighbors, 182 | old_candidate_neighbors, 183 | n_blocks, 184 | block_size, 185 | n_threads, 186 | update_array, 187 | n_updates_per_thread, 188 | ): 189 | """Process candidate neighbors for sparse data using array-based updates.""" 190 | c = 0 191 | n_vertices = new_candidate_neighbors.shape[0] 192 | for i in range(n_blocks + 1): 193 | block_start = i * block_size 194 | block_end = min(n_vertices, (i + 1) * block_size) 195 | 196 | new_candidate_block = new_candidate_neighbors[block_start:block_end] 197 | old_candidate_block = old_candidate_neighbors[block_start:block_end] 198 | 199 | dist_thresholds = current_graph[1][:, 0] 200 | 201 | sparse_generate_graph_update_array( 202 | update_array, 203 | n_updates_per_thread, 204 | new_candidate_block, 205 | old_candidate_block, 206 | dist_thresholds, 207 | inds, 208 | indptr, 209 | data, 210 | dist, 211 | n_threads, 212 | ) 213 | 214 | c += apply_graph_update_array( 215 | current_graph, update_array, n_updates_per_thread, n_threads 216 | ) 217 | 218 | return c 219 | 220 | 221 | @numba.njit() 222 | def nn_descent_internal( 223 | current_graph, 224 | inds, 225 | indptr, 226 | data, 227 | n_neighbors, 228 | rng_state, 229 | max_candidates=50, 230 | dist=sparse_euclidean, 231 | n_iters=10, 232 | delta=0.001, 233 | verbose=False, 234 | ): 235 | n_vertices = indptr.shape[0] - 1 236 | 
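    # Work through the vertices in fixed-size blocks so that the per-thread update arrays pre-allocated below stay at a bounded size.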
block_size = 16384 237 | n_blocks = n_vertices // block_size 238 | n_threads = numba.get_num_threads() 239 | 240 | # Pre-allocate update arrays 241 | max_updates_per_thread = ( 242 | int( 243 | (max_candidates**2 + max_candidates * (max_candidates - 1) / 2) 244 | * block_size 245 | / n_threads 246 | ) 247 | + 1024 248 | ) 249 | update_array = np.empty((n_threads, max_updates_per_thread, 3), dtype=np.float32) 250 | n_updates_per_thread = np.zeros(n_threads, dtype=np.int32) 251 | 252 | for n in range(n_iters): 253 | if verbose: 254 | print("\t", n + 1, " / ", n_iters) 255 | 256 | (new_candidate_neighbors, old_candidate_neighbors) = new_build_candidates( 257 | current_graph, max_candidates, rng_state, n_threads 258 | ) 259 | 260 | c = sparse_process_candidates( 261 | inds, 262 | indptr, 263 | data, 264 | dist, 265 | current_graph, 266 | new_candidate_neighbors, 267 | old_candidate_neighbors, 268 | n_blocks, 269 | block_size, 270 | n_threads, 271 | update_array, 272 | n_updates_per_thread, 273 | ) 274 | 275 | if c <= delta * n_neighbors * n_vertices: 276 | if verbose: 277 | print("\tStopping threshold met -- exiting after", n + 1, "iterations") 278 | return 279 | 280 | 281 | @numba.njit() 282 | def nn_descent( 283 | inds, 284 | indptr, 285 | data, 286 | n_neighbors, 287 | rng_state, 288 | max_candidates=50, 289 | dist=sparse_euclidean, 290 | n_iters=10, 291 | delta=0.001, 292 | init_graph=EMPTY_GRAPH, 293 | rp_tree_init=True, 294 | leaf_array=None, 295 | low_memory=False, 296 | verbose=False, 297 | ): 298 | 299 | n_samples = indptr.shape[0] - 1 300 | 301 | if init_graph[0].shape[0] == 1: # EMPTY_GRAPH 302 | current_graph = make_heap(n_samples, n_neighbors) 303 | 304 | if rp_tree_init: 305 | init_rp_tree(inds, indptr, data, dist, current_graph, leaf_array) 306 | 307 | init_random(n_neighbors, inds, indptr, data, current_graph, dist, rng_state) 308 | elif init_graph[0].shape[0] == n_samples and init_graph[0].shape[1] == n_neighbors: 309 | current_graph = init_graph 310 | else: 311 | raise ValueError("Invalid initial graph specified!") 312 | 313 | # Note: low_memory parameter is kept for API compatibility but 314 | # now uses the efficient array-based implementation 315 | nn_descent_internal( 316 | current_graph, 317 | inds, 318 | indptr, 319 | data, 320 | n_neighbors, 321 | rng_state, 322 | max_candidates=max_candidates, 323 | dist=dist, 324 | n_iters=n_iters, 325 | delta=delta, 326 | verbose=verbose, 327 | ) 328 | 329 | return deheap_sort(current_graph[0], current_graph[1]) 330 | -------------------------------------------------------------------------------- /doc/performance.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# PyNNDescent Performance\n", 8 | "\n", 9 | "How fast is PyNNDescent for approximate nearest neighbor search? How does it compare with other approximate nearest neighbor search algorithms and implementations? To answer these kinds of questions we'll make use of the [ann-benchmarks](https://github.com/erikbern/ann-benchmarks) suite of tools for benchmarking approximate nearest neighbor (ANN) search algorithms. The suite provides a wide array of datasets to benchmark on, and supports a wide array of ANN search libraries. Since the runtime of these benchmarks is quite large we'll be presenting results obtained earlier, and only for a selection of datasets and for the main state-of-the-art implementations. 
This page thus reflects the performance at a given point in time, and on a specific choice of benchmarking hardware. Implementations may (and likely will) improve, and different hardware will likely result in somewhat different performance characteristics amongst the implementations benchmarked here.\n", 10 | "\n", 11 | "We chose the following implementations of ANN search based on their strong performance in ANN search benchmarks in general:\n", 12 | "\n", 13 | " * Annoy (a tree based algorithm for comparison)\n", 14 | " * HNSW from FAISS, Facebook's ANN library\n", 15 | " * HNSW from nmslib, the reference implementation of the algorithm\n", 16 | " * HNSW from hnswlib, a small spinoff library from nmslib\n", 17 | " * ONNG from NGT, a more recent algorithm and implementation with impressive performance\n", 18 | " * PyNNDescent version 0.5\n", 19 | " \n", 20 | "Not all the algorithms ran entirely successfully on all the datasets; where an algorithm gave spurious or unrepresentative results we have left it off the given benchmark.\n", 21 | "\n", 22 | "The ann-benchmarks suite is designed to look at the trade-off in performance between search accuracy and search speed (or other performance statistic, such as index creation time or index size). Since this is a trade-off that can often be tuned by appropriately adjusting parameters, ann-benchmarks handles this by running a predefined (for each algorithm or implementation) range of parameters. It then finds the [pareto frontier](https://en.wikipedia.org/wiki/Pareto_efficiency#Use_in_engineering) for the optimal speed / accuracy trade-off and presents this as a curve. The various implementations can then be compared in terms of the pareto frontier curves. The default choice of measures for ann-benchmarks puts recall (effective search accuracy) along the x-axis and queries-per-second (search speed) on the y-axis. Thus curves that are further up and / or more to the right are providing better speed and / or more accuracy.\n", 23 | "\n", 24 | "\n", 25 | "To get a good overview of the relative performance characteristics of the different implementations we'll look at the speed / accuracy trade-off curves for a variety of datasets. This is because the dataset size, dimensionality, distribution and metric can all have non-trivial impacts on performance in various ways, and results for one dataset are not necessarily representative of how things will look for a different dataset. We will introduce each dataset in turn, and then look at the performance curves. To start with we'll consider datasets which use Euclidean distance." 26 | ] 27 | }, 28 | { 29 | "cell_type": "markdown", 30 | "metadata": {}, 31 | "source": [ 32 | "## Euclidean distance" 33 | ] 34 | }, 35 | { 36 | "cell_type": "markdown", 37 | "metadata": {}, 38 | "source": [ 39 | "Euclidean distance is the usual notion of distance that we are familiar with in everyday life, just extended to arbitrary dimensions (instead of only two or three). It is defined as $d(\bar{x}, \bar{y}) = \sqrt{\sum_i (x_i - y_i)^2}$ for vectors $\bar{x} = (x_1, x_2, \ldots, x_D)$ and $\bar{y} = (y_1, y_2, \ldots, y_D)$. It is widely used as a distance measure, but can have difficulties with high dimensional data in some cases.\n", 40 | "\n", 41 | "The first dataset we will consider that uses Euclidean distance is the MNIST dataset. MNIST consists of grayscale images of handwritten digits (from 0 to 9). Each digit image is 28 by 28 pixels, which is usually unravelled into a single vector of 784 dimensions.
In total there are 70,000 images in the dataset, and ann-benchmarks uses the usual split into 60,000 training samples and 10,000 test samples. The ANN index is built on the training set, and then the test set is used as the query data for benchmarking." 42 | ] 43 | }, 44 | { 45 | "cell_type": "markdown", 46 | "metadata": {}, 47 | "source": [ 48 | "
\"MNIST
" 49 | ] 50 | }, 51 | { 52 | "cell_type": "markdown", 53 | "metadata": {}, 54 | "source": [ 55 | "Remember that up and to the right is better. Also note that the y axis (queries per second) is plotted in *log scale* so each major grid step represents an order of magnitude performance difference. We can see that PyNNDescent performs very well here, outpacing the other ANN libraries in the high accuracy range. It is worth noting, however, that for lower accuracy queries it finishes essentially on par with ONNG, and unlike ONNG and nmslib's HNSW implementation, it does not extend to very high performance but low accuracy queries. If speed is absolutely paramount, and you only need to be in the vaguely right ballpark for accuracy then PyNNDescent may not be the right choice here.\n", 56 | "\n", 57 | "Next up for dataset is Fashion-MNIST. This was a dataset designed to be a drop in replacement for MNIST, but meant to be more challenging for machine learning tasks. Instead of grayscale images of digits it is grayscale images of fashion items (dresses, shirts, pants, boots, sandas, handbags, etc.). Just like MNIST each image is 28 by 28 pixels resulting in 784-dimensional vectors. Also just like MNIST there are 70,000 total images, split into 60,000 training images and 10,000 test images." 58 | ] 59 | }, 60 | { 61 | "cell_type": "markdown", 62 | "metadata": {}, 63 | "source": [ 64 | "
\"Fashion-MNIST
" 65 | ] 66 | }, 67 | { 68 | "cell_type": "markdown", 69 | "metadata": {}, 70 | "source": [ 71 | "Again we see a very similar result (although this should not entirely be a surprise given the similarity of the dataset in terms of the number of samples and dimensionality). PyNNDescent performs very well in the high accuracy regime, but does not scale to the very high performance but low accuracy ranges that ONNG and nmslib's HNSW manage. It is also worth noting the clear difference between the various graph based search algorithms and the tree based Annoy -- while Annoy is a very impressive ANN search implementation it compares poorly to the graph based search techniques on these datasets.\n", 72 | "\n", 73 | "Next up is the SIFT dataset. SIFT stands for [Scale-Invariant Feature Transform](https://en.wikipedia.org/wiki/Scale-invariant_feature_transform) and is a technique from compute vision for generating feature vectors from images. For ann-benchmarks this means that there exist some large databases of SIFT features from image datasets which can be used to test nearest neighbor search. In particular the SIFT dataset in ann-benchmarks is a dataset of one million SIFT vectors where each vector is 128-dimensional. This provides a good contrast to the earlier datasets which had relatively high dimensionality, but not an especially large number of samples. For ann-benchmarks the dataset is split into 990,000 training samples, and 10,000 test samples for querying with. " 74 | ] 75 | }, 76 | { 77 | "cell_type": "markdown", 78 | "metadata": {}, 79 | "source": [ 80 | "
\"SIFT-128
" 81 | ] 82 | }, 83 | { 84 | "cell_type": "markdown", 85 | "metadata": {}, 86 | "source": [ 87 | "Again we see that PyNNDescent performs very well. This time, however, with the more challenging search problem presented by a training set this large, it does produce some lower accuracy searches and in those cases both ONG and nmslib's HNSW outperform it. It's also worth noting that in this lower dimensional dataset Annoy performs better, comparatively, than the previous datasets. Still, over the Euclidean distance datasets tested here PyNNDescent remains a clear winner for high accuracy queries. Let's move on to the angular distance based datasets." 88 | ] 89 | }, 90 | { 91 | "cell_type": "markdown", 92 | "metadata": {}, 93 | "source": [ 94 | "## Angular distance\n", 95 | "\n", 96 | "Angular based distances measure the similarity of two vectors in terms of the angle they span -- the greater the angle the larger the distance between the vectors. Thus two vectors of different length can be viewed as being very close as long as they are pointing in the same direction. Another way of looking at this is to imagine that the data is being projected onto a high dimensional sphere (by intersecting a ray in the vectors direction with a unit sphere), and distances are measured in terms of arcs around the sphere.\n", 97 | "\n", 98 | "In practice the most commonly used angular distance is cosine distance, defined as\n", 99 | "\n", 100 | "$$d(\\bar{x}, \\bar{y}) = 1 - \\sum_i \\frac{x_i y_i}{\\|\\bar{x}\\|_2 \\|\\bar{y}\\|_2}$$\n", 101 | "\n", 102 | "where $\\|\\bar{x}\\|_2$ denotes the $\\ell^2$ [norm](https://en.wikipedia.org/wiki/Norm_\n", 103 | "(mathematics)#Euclidean_norm) of $\\bar{x}$. To see why this is a measure of angular distance note that $\\sum_i x_i y_i$ is the euclidean dot product of $\\bar{x}$ and $\\bar{y}$ and that the euclidean dot product formula gives $\\bar{x}\\cdot \\bar{y} = \\|x\\|_2 \\|y\\|_2 \\cos\\theta$ where $\\theta$ is the angle between the vectors.\n", 104 | "\n", 105 | "In the case where the vectors all have unit norm the cosine distance reduces to just one minus the dot product of the vectors -- which is sometimes used as an angular distance measure. Indeed, that is the case for our first dataset, the LastFM dataset. This dataset is constructed of 64 factors in a recommendation system for the Last FM online music service. It contains 292,385 training samples and 50,000 test samples. Compared to the other datasets explored so far this is considerably lower dimensional and the distance computation is simpler. Let's see what results we get." 106 | ] 107 | }, 108 | { 109 | "cell_type": "markdown", 110 | "metadata": {}, 111 | "source": [ 112 | "
\"LastFM
" 113 | ] 114 | }, 115 | { 116 | "cell_type": "markdown", 117 | "metadata": {}, 118 | "source": [ 119 | "Here we see hnswlib and HNSW from nmslib performing extremely well -- outpacing ONNG unlike we saw in the previous euclidean datasets. The HNSW implementation is FAISS is further behind. While PyNNDescent is not the fastest option on this dataset it is highly competitive with the two top performing HNSW implementations.\n", 120 | "\n", 121 | "The next dataset is a GloVe dataset of word vectors. The GloVe datasets are generated from a word-word co-occurrence count matrix generated from vast collections of text. Each word that occurs (frequently enough) in the text will get a resulting vector, with the principle that words with similar meanings will be assigned vectors that are similar (in angular distance). The dimensionality of the generated vectors is an input to the GloVe algorithm. For the first of the the GloVe datasets we will be looking at the 25 dimensional vectors. Since GloVe vectors were trained useing a vast corpus there are over one million different words represented, and thus we have 1,183,514 training samples and 10,000 test samples to work with. This gives is a low dimensional but extremely large dataset to work with." 122 | ] 123 | }, 124 | { 125 | "cell_type": "markdown", 126 | "metadata": {}, 127 | "source": [ 128 | "
\"GloVe-25
" 129 | ] 130 | }, 131 | { 132 | "cell_type": "markdown", 133 | "metadata": {}, 134 | "source": [ 135 | "In this case PyNNDescent and hnswlib are the apparent winners -- although PyNNDescent, similar to the earlier examples, performs less well once we get below abot 80% accuracy.\n", 136 | "\n", 137 | "Next we'll move up to a higher dimensional version of GloVe vectors. These vectors were trained on the same underlying text dataset, so we have the same number of samples (both for train and test), but now we have 100 dimensional vectors. This makes the problem more challenging as the underlying distance computation is a little more expensive given the higher dimensionality." 138 | ] 139 | }, 140 | { 141 | "cell_type": "markdown", 142 | "metadata": {}, 143 | "source": [ 144 | "
\"GloVe-100
" 145 | ] 146 | }, 147 | { 148 | "cell_type": "markdown", 149 | "metadata": {}, 150 | "source": [ 151 | "This time it is ONNG that surges to the front of the pack. Relatively speaking PyNNDescent is not too far behind. This goes to show, however, how much performance can vary based on the exact nature of the dataset: while ONNG was a (relatively) poor performer on the 25-dimensional version of this data with hnswlib out in front, the roles are reversed for this 100-dimensional data.\n", 152 | "\n", 153 | "The last dataset is the NY-Times dataset. This is data generated as dimension reduced (via PCA) [TF-IDF](https://en.wikipedia.org/wiki/Tf%E2%80%93idf) vectors of New York Times articles. The resulting dataset has 290,000 training samples and 10,000 test samples in 256 dimensions. This is quite a challenging dataset, and all the algorithms have significantly lower query-per-second performance on this data." 154 | ] 155 | }, 156 | { 157 | "cell_type": "markdown", 158 | "metadata": {}, 159 | "source": [ 160 | "
\"NY-Times
" 161 | ] 162 | }, 163 | { 164 | "cell_type": "markdown", 165 | "metadata": {}, 166 | "source": [ 167 | "Here we see that PyNNDescent and ONNG are the best performing implementations, particularly at the higher accuracy range (ONNG has a slight edge on PyNNDescent here).\n", 168 | "\n", 169 | "This concludes our examination of performance for now. Having examined performance for many different datasets it is clear that the various algorithms and implementations vary in performance depending on the exact nature of the data. None the less, we hope that this has demonstrated that PyNNDescent has excellent performance characteristics across a wide variety of datasets, often performing better than many state-of-the-art implementations." 170 | ] 171 | } 172 | ], 173 | "metadata": { 174 | "kernelspec": { 175 | "display_name": "Python [conda env:umap_0.5dev]", 176 | "language": "python", 177 | "name": "conda-env-umap_0.5dev-py" 178 | }, 179 | "language_info": { 180 | "codemirror_mode": { 181 | "name": "ipython", 182 | "version": 3 183 | }, 184 | "file_extension": ".py", 185 | "mimetype": "text/x-python", 186 | "name": "python", 187 | "nbconvert_exporter": "python", 188 | "pygments_lexer": "ipython3", 189 | "version": "3.8.1" 190 | }, 191 | "pycharm": { 192 | "stem_cell": { 193 | "cell_type": "raw", 194 | "source": [], 195 | "metadata": { 196 | "collapsed": false 197 | } 198 | } 199 | } 200 | }, 201 | "nbformat": 4, 202 | "nbformat_minor": 4 203 | } -------------------------------------------------------------------------------- /pynndescent/tests/test_distances.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import numpy as np 3 | from numpy.testing import assert_array_equal, assert_array_almost_equal 4 | import pynndescent.distances as dist 5 | import pynndescent.sparse as spdist 6 | from scipy import stats 7 | from scipy.sparse import csr_matrix 8 | from scipy.version import full_version as scipy_full_version 9 | from sklearn.metrics import pairwise_distances 10 | from sklearn.neighbors import BallTree 11 | from sklearn.preprocessing import normalize 12 | 13 | 14 | @pytest.mark.parametrize( 15 | "metric", 16 | [ 17 | "euclidean", 18 | "manhattan", 19 | "chebyshev", 20 | "minkowski", 21 | "hamming", 22 | "canberra", 23 | "braycurtis", 24 | "cosine", 25 | "correlation", 26 | ], 27 | ) 28 | def test_spatial_check(spatial_data, metric): 29 | dist_matrix = pairwise_distances(spatial_data, metric=metric) 30 | # scipy is bad sometimes 31 | if metric == "braycurtis": 32 | dist_matrix[np.where(~np.isfinite(dist_matrix))] = 0.0 33 | if metric in ("cosine", "correlation"): 34 | dist_matrix[np.where(~np.isfinite(dist_matrix))] = 1.0 35 | # And because distance between all zero vectors should be zero 36 | dist_matrix[10, 11] = 0.0 37 | dist_matrix[11, 10] = 0.0 38 | dist_function = dist.named_distances[metric] 39 | test_matrix = np.array( 40 | [ 41 | [ 42 | dist_function(spatial_data[i], spatial_data[j]) 43 | for j in range(spatial_data.shape[0]) 44 | ] 45 | for i in range(spatial_data.shape[0]) 46 | ] 47 | ) 48 | assert_array_almost_equal( 49 | test_matrix, 50 | dist_matrix, 51 | err_msg="Distances don't match for metric {}".format(metric), 52 | ) 53 | 54 | 55 | @pytest.mark.parametrize( 56 | "metric", 57 | [ 58 | "jaccard", 59 | "matching", 60 | "dice", 61 | "rogerstanimoto", 62 | "russellrao", 63 | "sokalmichener", 64 | "sokalsneath", 65 | "yule", 66 | ], 67 | ) 68 | def test_binary_check(binary_data, metric): 69 | dist_matrix = 
pairwise_distances(binary_data, metric=metric) 70 | if metric in ("jaccard", "dice", "sokalsneath", "yule"): 71 | dist_matrix[np.where(~np.isfinite(dist_matrix))] = 0.0 72 | if metric == "russellrao": 73 | dist_matrix[np.where(~np.isfinite(dist_matrix))] = 0.0 74 | # And because distance between all zero vectors should be zero 75 | dist_matrix[10, 11] = 0.0 76 | dist_matrix[11, 10] = 0.0 77 | dist_function = dist.named_distances[metric] 78 | test_matrix = np.array( 79 | [ 80 | [ 81 | dist_function(binary_data[i], binary_data[j]) 82 | for j in range(binary_data.shape[0]) 83 | ] 84 | for i in range(binary_data.shape[0]) 85 | ] 86 | ) 87 | assert_array_almost_equal( 88 | test_matrix, 89 | dist_matrix, 90 | err_msg="Distances don't match for metric {}".format(metric), 91 | ) 92 | 93 | 94 | @pytest.mark.parametrize( 95 | "metric", 96 | [ 97 | "euclidean", 98 | "manhattan", 99 | "chebyshev", 100 | "minkowski", 101 | "hamming", 102 | "canberra", 103 | "cosine", 104 | "braycurtis", 105 | "correlation", 106 | ], 107 | ) 108 | def test_sparse_spatial_check(sparse_spatial_data, metric, decimal=6): 109 | if metric in spdist.sparse_named_distances: 110 | dist_matrix = pairwise_distances( 111 | np.asarray(sparse_spatial_data.todense()).astype(np.float32), metric=metric 112 | ) 113 | if metric in ("braycurtis", "dice", "sokalsneath", "yule"): 114 | dist_matrix[np.where(~np.isfinite(dist_matrix))] = 0.0 115 | if metric in ("cosine", "correlation", "russellrao"): 116 | dist_matrix[np.where(~np.isfinite(dist_matrix))] = 1.0 117 | # And because distance between all zero vectors should be zero 118 | dist_matrix[10, 11] = 0.0 119 | dist_matrix[11, 10] = 0.0 120 | 121 | dist_function = spdist.sparse_named_distances[metric] 122 | if metric in spdist.sparse_need_n_features: 123 | test_matrix = np.array( 124 | [ 125 | [ 126 | dist_function( 127 | sparse_spatial_data[i].indices, 128 | sparse_spatial_data[i].data, 129 | sparse_spatial_data[j].indices, 130 | sparse_spatial_data[j].data, 131 | sparse_spatial_data.shape[1], 132 | ) 133 | for j in range(sparse_spatial_data.shape[0]) 134 | ] 135 | for i in range(sparse_spatial_data.shape[0]) 136 | ] 137 | ) 138 | else: 139 | test_matrix = np.array( 140 | [ 141 | [ 142 | dist_function( 143 | sparse_spatial_data[i].indices, 144 | sparse_spatial_data[i].data, 145 | sparse_spatial_data[j].indices, 146 | sparse_spatial_data[j].data, 147 | ) 148 | for j in range(sparse_spatial_data.shape[0]) 149 | ] 150 | for i in range(sparse_spatial_data.shape[0]) 151 | ] 152 | ) 153 | assert_array_almost_equal( 154 | test_matrix, 155 | dist_matrix, 156 | err_msg="Sparse distances don't match for metric {}".format(metric), 157 | decimal=decimal, 158 | ) 159 | 160 | 161 | @pytest.mark.parametrize( 162 | "metric", 163 | [ 164 | "jaccard", 165 | "matching", 166 | "dice", 167 | "rogerstanimoto", 168 | "russellrao", 169 | "sokalmichener", 170 | "sokalsneath", 171 | ], 172 | ) 173 | def test_sparse_binary_check(sparse_binary_data, metric): 174 | if metric in spdist.sparse_named_distances: 175 | dist_matrix = pairwise_distances( 176 | np.asarray(sparse_binary_data.todense()), metric=metric 177 | ) 178 | if metric in ("jaccard", "dice", "sokalsneath"): 179 | dist_matrix[np.where(~np.isfinite(dist_matrix))] = 0.0 180 | if metric == "russellrao": 181 | dist_matrix[np.where(~np.isfinite(dist_matrix))] = 1.0 182 | # And because distance between all zero vectors should be zero 183 | dist_matrix[10, 11] = 0.0 184 | dist_matrix[11, 10] = 0.0 185 | 186 | dist_function = spdist.sparse_named_distances[metric] 
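        # Metrics listed in sparse_need_n_features also take the total number of feature columns as a final argument.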
187 | if metric in spdist.sparse_need_n_features: 188 | test_matrix = np.array( 189 | [ 190 | [ 191 | dist_function( 192 | sparse_binary_data[i].indices, 193 | sparse_binary_data[i].data, 194 | sparse_binary_data[j].indices, 195 | sparse_binary_data[j].data, 196 | sparse_binary_data.shape[1], 197 | ) 198 | for j in range(sparse_binary_data.shape[0]) 199 | ] 200 | for i in range(sparse_binary_data.shape[0]) 201 | ] 202 | ) 203 | else: 204 | test_matrix = np.array( 205 | [ 206 | [ 207 | dist_function( 208 | sparse_binary_data[i].indices, 209 | sparse_binary_data[i].data, 210 | sparse_binary_data[j].indices, 211 | sparse_binary_data[j].data, 212 | ) 213 | for j in range(sparse_binary_data.shape[0]) 214 | ] 215 | for i in range(sparse_binary_data.shape[0]) 216 | ] 217 | ) 218 | 219 | assert_array_almost_equal( 220 | test_matrix, 221 | dist_matrix, 222 | err_msg="Sparse distances don't match for metric {}".format(metric), 223 | ) 224 | 225 | 226 | def test_seuclidean(spatial_data): 227 | v = np.abs(np.random.randn(spatial_data.shape[1])) 228 | dist_matrix = pairwise_distances(spatial_data, metric="seuclidean", V=v) 229 | test_matrix = np.array( 230 | [ 231 | [ 232 | dist.standardised_euclidean(spatial_data[i], spatial_data[j], v) 233 | for j in range(spatial_data.shape[0]) 234 | ] 235 | for i in range(spatial_data.shape[0]) 236 | ] 237 | ) 238 | assert_array_almost_equal( 239 | test_matrix, 240 | dist_matrix, 241 | err_msg="Distances don't match for metric seuclidean", 242 | ) 243 | 244 | 245 | @pytest.mark.skipif( 246 | scipy_full_version < "1.8", reason="incorrect function in scipy<1.8" 247 | ) 248 | def test_weighted_minkowski(spatial_data): 249 | v = np.abs(np.random.randn(spatial_data.shape[1])) 250 | dist_matrix = pairwise_distances(spatial_data, metric="minkowski", w=v, p=3) 251 | test_matrix = np.array( 252 | [ 253 | [ 254 | dist.weighted_minkowski(spatial_data[i], spatial_data[j], v, p=3) 255 | for j in range(spatial_data.shape[0]) 256 | ] 257 | for i in range(spatial_data.shape[0]) 258 | ] 259 | ) 260 | assert_array_almost_equal( 261 | test_matrix, 262 | dist_matrix, 263 | err_msg="Distances don't match for metric weighted_minkowski", 264 | ) 265 | 266 | 267 | def test_mahalanobis(spatial_data): 268 | v = np.cov(np.transpose(spatial_data)) 269 | dist_matrix = pairwise_distances(spatial_data, metric="mahalanobis", VI=v) 270 | test_matrix = np.array( 271 | [ 272 | [ 273 | dist.mahalanobis(spatial_data[i], spatial_data[j], v) 274 | for j in range(spatial_data.shape[0]) 275 | ] 276 | for i in range(spatial_data.shape[0]) 277 | ] 278 | ) 279 | assert_array_almost_equal( 280 | test_matrix, 281 | dist_matrix, 282 | err_msg="Distances don't match for metric mahalanobis", 283 | ) 284 | 285 | 286 | def test_haversine(spatial_data): 287 | tree = BallTree(spatial_data[:, :2], metric="haversine") 288 | dist_matrix, _ = tree.query(spatial_data[:, :2], k=spatial_data.shape[0]) 289 | test_matrix = np.array( 290 | [ 291 | [ 292 | dist.haversine(spatial_data[i, :2], spatial_data[j, :2]) 293 | for j in range(spatial_data.shape[0]) 294 | ] 295 | for i in range(spatial_data.shape[0]) 296 | ] 297 | ) 298 | test_matrix.sort(axis=1) 299 | assert_array_almost_equal( 300 | test_matrix, 301 | dist_matrix, 302 | err_msg="Distances don't match for metric haversine", 303 | ) 304 | 305 | 306 | def test_spearmanr(): 307 | x = np.random.randn(100) 308 | y = np.random.randn(100) 309 | 310 | scipy_expected = stats.spearmanr(x, y) 311 | r = dist.spearmanr(x, y) 312 | assert_array_almost_equal(r, 1 - 
scipy_expected.correlation) 313 | 314 | 315 | def test_alternative_distances(): 316 | 317 | for distname in dist.fast_distance_alternatives: 318 | 319 | true_dist = dist.named_distances[distname] 320 | alt_dist = dist.fast_distance_alternatives[distname]["dist"] 321 | correction = dist.fast_distance_alternatives[distname]["correction"] 322 | 323 | for i in range(100): 324 | x = np.random.random(30).astype(np.float32) 325 | y = np.random.random(30).astype(np.float32) 326 | x[x < 0.25] = 0.0 327 | y[y < 0.25] = 0.0 328 | 329 | true_distance = true_dist(x, y) 330 | corrected_alt_distance = correction(alt_dist(x, y)) 331 | 332 | assert np.isclose(true_distance, corrected_alt_distance) 333 | 334 | 335 | def test_jensen_shannon(): 336 | test_data = np.random.random(size=(10, 50)) 337 | test_data = normalize(test_data, norm="l1") 338 | for i in range(test_data.shape[0]): 339 | for j in range(i + 1, test_data.shape[0]): 340 | m = (test_data[i] + test_data[j]) / 2.0 341 | p = test_data[i] 342 | q = test_data[j] 343 | d1 = ( 344 | -np.sum(m * np.log(m)) 345 | + (np.sum(p * np.log(p)) + np.sum(q * np.log(q))) / 2.0 346 | ) 347 | d2 = dist.jensen_shannon_divergence(p, q) 348 | assert np.isclose(d1, d2, rtol=1e-4) 349 | 350 | 351 | def test_sparse_jensen_shannon(): 352 | test_data = np.random.random(size=(10, 100)) 353 | # sparsify 354 | test_data[test_data <= 0.5] = 0.0 355 | sparse_test_data = csr_matrix(test_data) 356 | sparse_test_data = normalize(sparse_test_data, norm="l1") 357 | test_data = normalize(test_data, norm="l1") 358 | 359 | for i in range(test_data.shape[0]): 360 | for j in range(i + 1, test_data.shape[0]): 361 | m = (test_data[i] + test_data[j]) / 2.0 362 | p = test_data[i] 363 | q = test_data[j] 364 | d1 = ( 365 | -np.sum(m[m > 0] * np.log(m[m > 0])) 366 | + ( 367 | np.sum(p[p > 0] * np.log(p[p > 0])) 368 | + np.sum(q[q > 0] * np.log(q[q > 0])) 369 | ) 370 | / 2.0 371 | ) 372 | d2 = spdist.sparse_jensen_shannon_divergence( 373 | sparse_test_data[i].indices, 374 | sparse_test_data[i].data, 375 | sparse_test_data[j].indices, 376 | sparse_test_data[j].data, 377 | ) 378 | assert np.isclose(d1, d2, rtol=1e-3) 379 | 380 | 381 | @pytest.mark.parametrize("p", [1.0, 2.0, 3.0, 0.5]) 382 | def test_wasserstein_1d(p): 383 | test_data = np.random.random(size=(10, 100)) 384 | # sparsify 385 | test_data[test_data <= 0.5] = 0.0 386 | sparse_test_data = csr_matrix(test_data) 387 | 388 | for i in range(test_data.shape[0]): 389 | for j in range(i + 1, test_data.shape[0]): 390 | d1 = dist.wasserstein_1d(test_data[i], test_data[j], p) 391 | d2 = spdist.sparse_wasserstein_1d( 392 | sparse_test_data[i].indices, 393 | sparse_test_data[i].data, 394 | sparse_test_data[j].indices, 395 | sparse_test_data[j].data, 396 | p, 397 | ) 398 | assert np.isclose(d1, d2) 399 | 400 | 401 | def test_bit_hamming(): 402 | test_data = np.random.randint(0, 255, size=(10, 100), dtype=np.uint8) 403 | unpacked_data = np.zeros( 404 | (test_data.shape[0], test_data.shape[1] * 8), dtype=np.float32 405 | ) 406 | for i in range(unpacked_data.shape[0]): 407 | for j in range(unpacked_data.shape[1]): 408 | unpacked_data[i, j] = (test_data[i, j // 8] & (1 << (j % 8))) > 0 409 | 410 | all_pairs = pairwise_distances(unpacked_data, metric="hamming") 411 | for i in range(test_data.shape[0]): 412 | for j in range(i + 1, test_data.shape[0]): 413 | d1 = dist.bit_hamming(test_data[i], test_data[j]) / (test_data.shape[1] * 8) 414 | d2 = all_pairs[i, j] 415 | assert np.isclose(d1, d2) 416 | 417 | 418 | def test_bit_jaccard(): 419 | test_data = 
np.random.randint(0, 255, size=(10, 100), dtype=np.uint8) 420 | unpacked_data = np.zeros( 421 | (test_data.shape[0], test_data.shape[1] * 8), dtype=np.float32 422 | ) 423 | for i in range(unpacked_data.shape[0]): 424 | for j in range(unpacked_data.shape[1]): 425 | unpacked_data[i, j] = (test_data[i, j // 8] & (1 << (j % 8))) > 0 426 | 427 | all_pairs = pairwise_distances(unpacked_data, metric="jaccard") 428 | for i in range(test_data.shape[0]): 429 | for j in range(i + 1, test_data.shape[0]): 430 | d1 = 1.0 - np.exp(-dist.bit_jaccard(test_data[i], test_data[j])) 431 | d2 = all_pairs[i, j] 432 | assert np.isclose(d1, d2) 433 | -------------------------------------------------------------------------------- /pynndescent/tests/test_hub_trees.py: -------------------------------------------------------------------------------- 1 | """Unit tests for hub tree implementations. 2 | 3 | Tests hub-based tree construction for all data types: 4 | - Dense euclidean 5 | - Dense angular (cosine) 6 | - Sparse euclidean 7 | - Sparse angular (cosine) 8 | - Bit-packed 9 | """ 10 | 11 | import numpy as np 12 | import pytest 13 | from sklearn.neighbors import KDTree, NearestNeighbors 14 | from sklearn.preprocessing import normalize 15 | import scipy.sparse as sparse 16 | 17 | from pynndescent import NNDescent 18 | from pynndescent.rp_trees import ( 19 | euclidean_hub_split, 20 | angular_hub_split, 21 | sparse_euclidean_hub_split, 22 | sparse_angular_hub_split, 23 | bit_hub_split, 24 | compute_global_degrees, 25 | ) 26 | 27 | 28 | # ============================================================================ 29 | # Fixtures 30 | # ============================================================================ 31 | 32 | 33 | @pytest.fixture 34 | def hub_tree_data(): 35 | """Generate test data for hub tree tests.""" 36 | np.random.seed(42) 37 | return np.random.uniform(0, 1, size=(500, 20)).astype(np.float32) 38 | 39 | 40 | @pytest.fixture 41 | def hub_tree_sparse_data(): 42 | """Generate sparse test data for hub tree tests.""" 43 | np.random.seed(42) 44 | return sparse.random(500, 50, density=0.5, format="csr", dtype=np.float32) 45 | 46 | 47 | @pytest.fixture 48 | def hub_tree_bit_data(): 49 | """Generate bit-packed test data for hub tree tests.""" 50 | np.random.seed(42) 51 | data = np.random.uniform(0, 1, size=(500, 20)).astype(np.float32) 52 | return (data * 256).astype(np.uint8) 53 | 54 | 55 | # ============================================================================ 56 | # Test hub split functions directly 57 | # ============================================================================ 58 | 59 | 60 | def test_euclidean_hub_split_produces_valid_split(hub_tree_data): 61 | """Test that euclidean_hub_split produces a valid split.""" 62 | # Build a simple neighbor graph 63 | nnd = NNDescent(hub_tree_data, n_neighbors=15, random_state=42) 64 | neighbor_indices = nnd._neighbor_graph[0] 65 | 66 | indices = np.arange(100, dtype=np.int32) # Test with first 100 points 67 | rng_state = np.array([42, 12345, 67890], dtype=np.int64) 68 | global_degrees = compute_global_degrees(neighbor_indices) 69 | 70 | left, right, hyperplane, offset, balance = euclidean_hub_split( 71 | hub_tree_data, indices, neighbor_indices, global_degrees, rng_state 72 | ) 73 | 74 | # Check that split is valid 75 | assert len(left) > 0, "Left partition should not be empty" 76 | assert len(right) > 0, "Right partition should not be empty" 77 | assert len(left) + len(right) == len(indices), "Split should preserve all points" 78 | assert 
len(set(left) & set(right)) == 0, "Partitions should not overlap" 79 | assert ( 80 | hyperplane.shape[0] == hub_tree_data.shape[1] 81 | ), "Hyperplane dim should match data" 82 | 83 | 84 | def test_angular_hub_split_produces_valid_split(hub_tree_data): 85 | """Test that angular_hub_split produces a valid split.""" 86 | angular_data = normalize(hub_tree_data, norm="l2").astype(np.float32) 87 | 88 | nnd = NNDescent(angular_data, metric="cosine", n_neighbors=15, random_state=42) 89 | neighbor_indices = nnd._neighbor_graph[0] 90 | 91 | indices = np.arange(100, dtype=np.int32) 92 | rng_state = np.array([42, 12345, 67890], dtype=np.int64) 93 | global_degrees = compute_global_degrees(neighbor_indices) 94 | 95 | left, right, hyperplane, offset, balance = angular_hub_split( 96 | angular_data, indices, neighbor_indices, global_degrees, rng_state 97 | ) 98 | 99 | assert len(left) > 0, "Left partition should not be empty" 100 | assert len(right) > 0, "Right partition should not be empty" 101 | assert len(left) + len(right) == len(indices), "Split should preserve all points" 102 | assert offset == 0.0, "Angular split should have zero offset" 103 | 104 | 105 | def test_sparse_euclidean_hub_split_produces_valid_split(hub_tree_sparse_data): 106 | """Test that sparse_euclidean_hub_split produces a valid split.""" 107 | nnd = NNDescent(hub_tree_sparse_data, n_neighbors=15, random_state=42) 108 | neighbor_indices = nnd._neighbor_graph[0] 109 | 110 | indices = np.arange(100, dtype=np.int32) 111 | rng_state = np.array([42, 12345, 67890], dtype=np.int64) 112 | global_degrees = compute_global_degrees(neighbor_indices) 113 | 114 | sp_data = hub_tree_sparse_data.tocsr() 115 | left, right, hyperplane, offset = sparse_euclidean_hub_split( 116 | sp_data.indices, 117 | sp_data.indptr, 118 | sp_data.data, 119 | indices, 120 | neighbor_indices, 121 | global_degrees, 122 | rng_state, 123 | ) 124 | 125 | assert len(left) > 0, "Left partition should not be empty" 126 | assert len(right) > 0, "Right partition should not be empty" 127 | assert len(left) + len(right) == len(indices), "Split should preserve all points" 128 | 129 | 130 | def test_sparse_angular_hub_split_produces_valid_split(hub_tree_sparse_data): 131 | """Test that sparse_angular_hub_split produces a valid split.""" 132 | normalized_data = normalize(hub_tree_sparse_data, norm="l2") 133 | 134 | nnd = NNDescent(normalized_data, metric="cosine", n_neighbors=15, random_state=42) 135 | neighbor_indices = nnd._neighbor_graph[0] 136 | 137 | indices = np.arange(100, dtype=np.int32) 138 | rng_state = np.array([42, 12345, 67890], dtype=np.int64) 139 | global_degrees = compute_global_degrees(neighbor_indices) 140 | 141 | sp_data = normalized_data.tocsr() 142 | left, right, hyperplane, offset = sparse_angular_hub_split( 143 | sp_data.indices, 144 | sp_data.indptr, 145 | sp_data.data, 146 | indices, 147 | neighbor_indices, 148 | global_degrees, 149 | rng_state, 150 | ) 151 | 152 | assert len(left) > 0, "Left partition should not be empty" 153 | assert len(right) > 0, "Right partition should not be empty" 154 | assert len(left) + len(right) == len(indices), "Split should preserve all points" 155 | 156 | 157 | def test_bitpacked_hub_split_produces_valid_split(hub_tree_bit_data): 158 | """Test that bit_hub_split produces a valid split.""" 159 | nnd = NNDescent( 160 | hub_tree_bit_data, metric="bit_jaccard", n_neighbors=15, random_state=42 161 | ) 162 | neighbor_indices = nnd._neighbor_graph[0] 163 | 164 | indices = np.arange(100, dtype=np.int32) 165 | rng_state = 
np.array([42, 12345, 67890], dtype=np.int64) 166 | global_degrees = compute_global_degrees(neighbor_indices) 167 | 168 | left, right, hyperplane, offset = bit_hub_split( 169 | hub_tree_bit_data, indices, neighbor_indices, global_degrees, rng_state 170 | ) 171 | 172 | assert len(left) > 0, "Left partition should not be empty" 173 | assert len(right) > 0, "Right partition should not be empty" 174 | assert len(left) + len(right) == len(indices), "Split should preserve all points" 175 | assert ( 176 | hyperplane.shape[0] == hub_tree_bit_data.shape[1] * 2 177 | ), "Bit hyperplane should be 2x dim" 178 | 179 | 180 | # ============================================================================ 181 | # Test hub tree construction 182 | # ============================================================================ 183 | 184 | 185 | # ============================================================================ 186 | # Test hub tree construction (via end-to-end tests) 187 | # Note: Direct tree construction tests removed due to Numba typed list 188 | # serialization issues. The functionality is tested through query accuracy tests. 189 | # ============================================================================ 190 | 191 | 192 | # ============================================================================ 193 | # Test end-to-end query accuracy with hub trees 194 | # ============================================================================ 195 | 196 | 197 | def test_dense_euclidean_hub_tree_query_accuracy(hub_tree_data): 198 | """Test query accuracy after prepare() with dense euclidean hub tree.""" 199 | train_data = hub_tree_data[100:] 200 | query_data = hub_tree_data[:100] 201 | 202 | nnd = NNDescent(train_data, metric="euclidean", n_neighbors=15, random_state=42) 203 | nnd.prepare() # This builds the hub tree 204 | 205 | knn_indices, _ = nnd.query(query_data, k=10, epsilon=0.2) 206 | 207 | # Get true neighbors 208 | tree = KDTree(train_data) 209 | true_indices = tree.query(query_data, 10, return_distance=False) 210 | 211 | num_correct = 0.0 212 | for i in range(query_data.shape[0]): 213 | num_correct += np.sum(np.isin(true_indices[i], knn_indices[i])) 214 | 215 | percent_correct = num_correct / (query_data.shape[0] * 10) 216 | assert percent_correct >= 0.90, f"Query accuracy too low: {percent_correct:.2%}" 217 | 218 | 219 | def test_dense_angular_hub_tree_query_accuracy(hub_tree_data): 220 | """Test query accuracy after prepare() with dense angular hub tree.""" 221 | angular_data = normalize(hub_tree_data, norm="l2").astype(np.float32) 222 | train_data = angular_data[100:] 223 | query_data = angular_data[:100] 224 | 225 | nnd = NNDescent(train_data, metric="cosine", n_neighbors=15, random_state=42) 226 | nnd.prepare() 227 | 228 | knn_indices, _ = nnd.query(query_data, k=10, epsilon=0.2) 229 | 230 | nn = NearestNeighbors(metric="cosine").fit(train_data) 231 | true_indices = nn.kneighbors(query_data, n_neighbors=10, return_distance=False) 232 | 233 | num_correct = 0.0 234 | for i in range(query_data.shape[0]): 235 | num_correct += np.sum(np.isin(true_indices[i], knn_indices[i])) 236 | 237 | percent_correct = num_correct / (query_data.shape[0] * 10) 238 | assert percent_correct >= 0.90, f"Query accuracy too low: {percent_correct:.2%}" 239 | 240 | 241 | def test_sparse_euclidean_hub_tree_query_accuracy(hub_tree_sparse_data): 242 | """Test query accuracy after prepare() with sparse euclidean hub tree.""" 243 | train_data = hub_tree_sparse_data[100:] 244 | query_data = hub_tree_sparse_data[:100] 
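    # Index the held-out training rows and call prepare() to build the sparse
    # hub tree used for searching; recall is then scored against exact
    # neighbors from a KDTree on the densified data.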
245 | 246 | nnd = NNDescent(train_data, metric="euclidean", n_neighbors=15, random_state=42) 247 | nnd.prepare() 248 | 249 | knn_indices, _ = nnd.query(query_data, k=10, epsilon=0.2) 250 | 251 | tree = KDTree(train_data.toarray()) 252 | true_indices = tree.query(query_data.toarray(), 10, return_distance=False) 253 | 254 | num_correct = 0.0 255 | for i in range(query_data.shape[0]): 256 | num_correct += np.sum(np.isin(true_indices[i], knn_indices[i])) 257 | 258 | percent_correct = num_correct / (query_data.shape[0] * 10) 259 | assert percent_correct >= 0.85, f"Query accuracy too low: {percent_correct:.2%}" 260 | 261 | 262 | def test_sparse_angular_hub_tree_query_accuracy(hub_tree_sparse_data): 263 | """Test query accuracy after prepare() with sparse angular hub tree.""" 264 | normalized_data = normalize(hub_tree_sparse_data, norm="l2") 265 | train_data = normalized_data[100:] 266 | query_data = normalized_data[:100] 267 | 268 | nnd = NNDescent(train_data, metric="cosine", n_neighbors=15, random_state=42) 269 | nnd.prepare() 270 | 271 | knn_indices, _ = nnd.query(query_data, k=10, epsilon=0.2) 272 | 273 | nn = NearestNeighbors(metric="cosine").fit(train_data.toarray()) 274 | true_indices = nn.kneighbors( 275 | query_data.toarray(), n_neighbors=10, return_distance=False 276 | ) 277 | 278 | num_correct = 0.0 279 | for i in range(query_data.shape[0]): 280 | num_correct += np.sum(np.isin(true_indices[i], knn_indices[i])) 281 | 282 | percent_correct = num_correct / (query_data.shape[0] * 10) 283 | assert percent_correct >= 0.85, f"Query accuracy too low: {percent_correct:.2%}" 284 | 285 | 286 | def test_bitpacked_hub_tree_query_accuracy(hub_tree_bit_data): 287 | """Test query accuracy after prepare() with bit-packed hub tree.""" 288 | # Unpack for ground truth computation 289 | unpacked_data = np.zeros( 290 | (hub_tree_bit_data.shape[0], hub_tree_bit_data.shape[1] * 8), dtype=np.float32 291 | ) 292 | for i in range(unpacked_data.shape[0]): 293 | for j in range(unpacked_data.shape[1]): 294 | unpacked_data[i, j] = (hub_tree_bit_data[i, j // 8] & (1 << (j % 8))) > 0 295 | 296 | train_idx = slice(100, None) 297 | query_idx = slice(0, 100) 298 | 299 | nnd = NNDescent( 300 | hub_tree_bit_data[train_idx], 301 | metric="bit_jaccard", 302 | n_neighbors=15, 303 | random_state=42, 304 | ) 305 | nnd.prepare() 306 | 307 | knn_indices, _ = nnd.query(hub_tree_bit_data[query_idx], k=10, epsilon=0.3) 308 | 309 | nn = NearestNeighbors(metric="jaccard").fit(unpacked_data[train_idx]) 310 | true_indices = nn.kneighbors( 311 | unpacked_data[query_idx], n_neighbors=10, return_distance=False 312 | ) 313 | 314 | num_correct = 0.0 315 | for i in range(100): 316 | num_correct += np.sum(np.isin(true_indices[i], knn_indices[i])) 317 | 318 | percent_correct = num_correct / (100 * 10) 319 | assert percent_correct >= 0.70, f"Query accuracy too low: {percent_correct:.2%}" 320 | 321 | 322 | # ============================================================================ 323 | # Test self-query accuracy (points should find themselves) 324 | # ============================================================================ 325 | 326 | 327 | def test_dense_euclidean_hub_tree_self_query(hub_tree_data): 328 | """Test that points can find themselves after prepare() with hub tree.""" 329 | nnd = NNDescent(hub_tree_data, metric="euclidean", n_neighbors=15, random_state=42) 330 | nnd.prepare() 331 | 332 | # Query first 50 points 333 | knn_indices, knn_distances = nnd.query(hub_tree_data[:50], k=1) 334 | 335 | self_found = sum(1 for i in 
range(50) if knn_indices[i, 0] == i) 336 | assert self_found >= 45, f"Self-query accuracy too low: {self_found}/50" 337 | 338 | 339 | def test_dense_angular_hub_tree_self_query(hub_tree_data): 340 | """Test self-query with angular hub tree.""" 341 | angular_data = normalize(hub_tree_data, norm="l2").astype(np.float32) 342 | 343 | nnd = NNDescent(angular_data, metric="cosine", n_neighbors=15, random_state=42) 344 | nnd.prepare() 345 | 346 | knn_indices, _ = nnd.query(angular_data[:50], k=1) 347 | 348 | self_found = sum(1 for i in range(50) if knn_indices[i, 0] == i) 349 | assert self_found >= 45, f"Self-query accuracy too low: {self_found}/50" 350 | 351 | 352 | def test_sparse_euclidean_hub_tree_self_query(hub_tree_sparse_data): 353 | """Test self-query with sparse euclidean hub tree.""" 354 | nnd = NNDescent( 355 | hub_tree_sparse_data, metric="euclidean", n_neighbors=15, random_state=42 356 | ) 357 | nnd.prepare() 358 | 359 | knn_indices, _ = nnd.query(hub_tree_sparse_data[:50], k=1) 360 | 361 | self_found = sum(1 for i in range(50) if knn_indices[i, 0] == i) 362 | assert self_found >= 40, f"Self-query accuracy too low: {self_found}/50" 363 | 364 | 365 | def test_sparse_angular_hub_tree_self_query(hub_tree_sparse_data): 366 | """Test self-query with sparse angular hub tree.""" 367 | normalized_data = normalize(hub_tree_sparse_data, norm="l2") 368 | 369 | nnd = NNDescent(normalized_data, metric="cosine", n_neighbors=15, random_state=42) 370 | nnd.prepare() 371 | 372 | knn_indices, _ = nnd.query(normalized_data[:50], k=1) 373 | 374 | self_found = sum(1 for i in range(50) if knn_indices[i, 0] == i) 375 | assert self_found >= 40, f"Self-query accuracy too low: {self_found}/50" 376 | 377 | 378 | def test_bitpacked_hub_tree_self_query(hub_tree_bit_data): 379 | """Test self-query with bit-packed hub tree.""" 380 | nnd = NNDescent( 381 | hub_tree_bit_data, metric="bit_jaccard", n_neighbors=15, random_state=42 382 | ) 383 | nnd.prepare() 384 | 385 | knn_indices, _ = nnd.query(hub_tree_bit_data[:50], k=1) 386 | 387 | self_found = sum(1 for i in range(50) if knn_indices[i, 0] == i) 388 | assert self_found >= 40, f"Self-query accuracy too low: {self_found}/50" 389 | -------------------------------------------------------------------------------- /doc/pynndescent_metrics.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# PyNNDescent with different metrics\n", 8 | "\n", 9 | "In the initial tutorial we looked at how to get PyNNDescent running on your data, and how to query the indexes it builds. Implicit in all of that was the measure of distance used to determine what counts as the \"nearest\" neighbors. By default PyNNDescent uses the euclidean metric (because that is what people generally expect when they talk about distance). This is not the only way to measure distance however, and is often not the right choice for very high dimensional data for example. Let's look at how to use PyNNDescent with other metrics.\n", 10 | "\n", 11 | "First we'll need some libraries, and some test data. As before we will use ann-benchmarks for data, so we will reuse the data download function from the previous tutorial." 
12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 1, 17 | "metadata": { 18 | "pycharm": { 19 | "is_executing": false 20 | } 21 | }, 22 | "outputs": [], 23 | "source": [ 24 | "import pynndescent\n", 25 | "import numpy as np\n", 26 | "import h5py\n", 27 | "from urllib.request import urlretrieve\n", 28 | "import os" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 2, 34 | "metadata": { 35 | "pycharm": { 36 | "is_executing": false 37 | } 38 | }, 39 | "outputs": [], 40 | "source": [ 41 | "def get_ann_benchmark_data(dataset_name):\n", 42 | " if not os.path.exists(f\"{dataset_name}.hdf5\"):\n", 43 | " print(f\"Dataset {dataset_name} is not cached; downloading now ...\")\n", 44 | " urlretrieve(f\"http://ann-benchmarks.com/{dataset_name}.hdf5\", f\"{dataset_name}.hdf5\")\n", 45 | " hdf5_file = h5py.File(f\"{dataset_name}.hdf5\", \"r\")\n", 46 | " return np.array(hdf5_file['train']), np.array(hdf5_file['test']), hdf5_file.attrs['distance']" 47 | ] 48 | }, 49 | { 50 | "cell_type": "markdown", 51 | "metadata": {}, 52 | "source": [ 53 | "## Built in metrics\n", 54 | "\n", 55 | "Let's grab some data where euclidean distance doesn't make sense. We'll use the NY-Times dataset, which is a [TF-IDF](https://en.wikipedia.org/wiki/Tf%E2%80%93idf) matrix of data generated from NY-Times news stories. The particulars are less important here, but what matters is that the most sensible way to measure distance on this data is with an angular metric, such as cosine distance." 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": 3, 61 | "metadata": { 62 | "pycharm": { 63 | "is_executing": false 64 | } 65 | }, 66 | "outputs": [ 67 | { 68 | "name": "stdout", 69 | "output_type": "stream", 70 | "text": [ 71 | "Dataset nytimes-256-angular is not cached; downloading now ...\n" 72 | ] 73 | }, 74 | { 75 | "data": { 76 | "text/plain": [ 77 | "(290000, 256)" 78 | ] 79 | }, 80 | "execution_count": 3, 81 | "metadata": {}, 82 | "output_type": "execute_result" 83 | } 84 | ], 85 | "source": [ 86 | "nytimes_train, nytimes_test, distance = get_ann_benchmark_data('nytimes-256-angular')\n", 87 | "nytimes_train.shape" 88 | ] 89 | }, 90 | { 91 | "cell_type": "markdown", 92 | "metadata": {}, 93 | "source": [ 94 | "Now that we have the data we can check the distance measure suggested by ann-benchmarks" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": 4, 100 | "metadata": { 101 | "pycharm": { 102 | "is_executing": false 103 | } 104 | }, 105 | "outputs": [ 106 | { 107 | "data": { 108 | "text/plain": [ 109 | "'angular'" 110 | ] 111 | }, 112 | "execution_count": 4, 113 | "metadata": {}, 114 | "output_type": "execute_result" 115 | } 116 | ], 117 | "source": [ 118 | "distance" 119 | ] 120 | }, 121 | { 122 | "cell_type": "markdown", 123 | "metadata": {}, 124 | "source": [ 125 | "So an angular measure of distance -- cosine distance will suffice. How do we manage to get PyNNDescent working with cosine distance (which isn't even a real metric! it violates the triangle inequality) instead of standard euclidean?" 
126 | ] 127 | }, 128 | { 129 | "cell_type": "code", 130 | "execution_count": null, 131 | "metadata": { 132 | "pycharm": { 133 | "is_executing": false 134 | } 135 | }, 136 | "outputs": [ 137 | { 138 | "name": "stdout", 139 | "output_type": "stream", 140 | "text": [ 141 | "CPU times: user 5min 2s, sys: 1min 22s, total: 6min 24s\n", 142 | "Wall time: 30.6 s\n" 143 | ] 144 | } 145 | ], 146 | "source": [ 147 | "%%time\n", 148 | "index = pynndescent.NNDescent(nytimes_train, metric=\"cosine\", parallel_batch_queries=True)\n", 149 | "index.prepare()" 150 | ] 151 | }, 152 | { 153 | "cell_type": "markdown", 154 | "metadata": {}, 155 | "source": [ 156 | "That's right, it uses the scikit-learn standard of the ``metric`` keyword and accepts a string that names the metric. We can now query the index, and it will use that metric in the query as well." 157 | ] 158 | }, 159 | { 160 | "cell_type": "code", 161 | "execution_count": 6, 162 | "metadata": { 163 | "pycharm": { 164 | "is_executing": false 165 | } 166 | }, 167 | "outputs": [ 168 | { 169 | "name": "stdout", 170 | "output_type": "stream", 171 | "text": [ 172 | "CPU times: user 26.7 s, sys: 106 ms, total: 26.8 s\n", 173 | "Wall time: 26.8 s\n" 174 | ] 175 | } 176 | ], 177 | "source": [ 178 | "%%time\n", 179 | "neighbors = index.query(nytimes_train)" 180 | ] 181 | }, 182 | { 183 | "cell_type": "markdown", 184 | "metadata": {}, 185 | "source": [ 186 | "It is worth noting at this point that these results will probably be a little sub-optimal since angular distances are harder to index, and as a result to get the same level accuracy in the nearest neighbor approximation we should be using a larger value than the default ``30`` for ``n_neighbors``. Beyond that, however, nothing else changes from the tutorial earlier -- except that we can't use kd-trees to learn the true neighbors, since they require distances that respect the triangle inequality.\n", 187 | "\n", 188 | "How many metrics does PyNNDescent support out of the box? 
Quite a few actually:" 189 | ] 190 | }, 191 | { 192 | "cell_type": "code", 193 | "execution_count": 7, 194 | "metadata": { 195 | "pycharm": { 196 | "is_executing": false 197 | } 198 | }, 199 | "outputs": [ 200 | { 201 | "name": "stdout", 202 | "output_type": "stream", 203 | "text": [ 204 | "euclidean\n", 205 | "l2\n", 206 | "sqeuclidean\n", 207 | "manhattan\n", 208 | "taxicab\n", 209 | "l1\n", 210 | "chebyshev\n", 211 | "linfinity\n", 212 | "linfty\n", 213 | "linf\n", 214 | "minkowski\n", 215 | "seuclidean\n", 216 | "standardised_euclidean\n", 217 | "wminkowski\n", 218 | "weighted_minkowski\n", 219 | "mahalanobis\n", 220 | "canberra\n", 221 | "cosine\n", 222 | "dot\n", 223 | "inner_product\n", 224 | "correlation\n", 225 | "haversine\n", 226 | "braycurtis\n", 227 | "spearmanr\n", 228 | "tsss\n", 229 | "true_angular\n", 230 | "hellinger\n", 231 | "kantorovich\n", 232 | "wasserstein\n", 233 | "wasserstein_1d\n", 234 | "wasserstein-1d\n", 235 | "kantorovich-1d\n", 236 | "kantorovich_1d\n", 237 | "circular_kantorovich\n", 238 | "circular_wasserstein\n", 239 | "sinkhorn\n", 240 | "jensen-shannon\n", 241 | "jensen_shannon\n", 242 | "symmetric-kl\n", 243 | "symmetric_kl\n", 244 | "symmetric_kullback_liebler\n", 245 | "hamming\n", 246 | "jaccard\n", 247 | "dice\n", 248 | "matching\n", 249 | "kulsinski\n", 250 | "rogerstanimoto\n", 251 | "russellrao\n", 252 | "sokalsneath\n", 253 | "sokalmichener\n", 254 | "yule\n", 255 | "bit_hamming\n", 256 | "bit_jaccard\n" 257 | ] 258 | } 259 | ], 260 | "source": [ 261 | "for dist in pynndescent.distances.named_distances:\n", 262 | " print(dist)" 263 | ] 264 | }, 265 | { 266 | "cell_type": "markdown", 267 | "metadata": {}, 268 | "source": [ 269 | "Some of these are repeats or alternate names for the same metric, and some of these are fairly simple, but others, such as ``spearmanr``, or ``hellinger`` are useful statistical measures not often implemented elsewhere, and others, such as ``wasserstein`` are complex and hard to compute metrics. Having all of these readily available in a fast approximate nearest neighbor library is one of PyNNDescent's strengths." 270 | ] 271 | }, 272 | { 273 | "cell_type": "markdown", 274 | "metadata": {}, 275 | "source": [ 276 | "## Custom metrics\n", 277 | "\n", 278 | "We can go even further in terms of interesting metrics however. You can write your own custom metrics and hand them to PyNNDescent to use on your data. There, of course, a few caveats with this. Many nearest neighbor libraries allow for the possibility of user defined metrics. If you are using Python this often ends up coming in two flavours:\n", 279 | "\n", 280 | " 1. Write some C, C++ or Cython code and compile it against the library itself\n", 281 | " 2. Write a python distance function, but lose almost all performance\n", 282 | " \n", 283 | "With PyNNDescent we get a different trade-off. Because we use [Numba](http://numba.pydata.org/) for just-in-time compiling of Python code instead of a C or C++ backend you don't need to do an offline compilation step and can instead have your custom Python distance function compiled and used on the fly. The cost for that is that the custom distance function you write must be a numba jitted function. This, in turn, means that you can only use Python functionality that is [supported by numba](). That is still a fairly large amount of functionality, especially when we are talking about numerical work, but it is a limit. It also means that you will need to import numba and decorate your custom distance function accordingly. 
Let's look at how to do that." 284 | ] 285 | }, 286 | { 287 | "cell_type": "code", 288 | "execution_count": 8, 289 | "metadata": { 290 | "pycharm": { 291 | "is_executing": false 292 | } 293 | }, 294 | "outputs": [], 295 | "source": [ 296 | "import numba" 297 | ] 298 | }, 299 | { 300 | "cell_type": "markdown", 301 | "metadata": {}, 302 | "source": [ 303 | "Let's start by simply implementing euclidean distance where $d(\\mathbf{x},\\mathbf{y}) = \\sqrt{\\sum_i (\\mathbf{x}_i - \\mathbf{y}_i)^2}$. This is already implemented in PyNNDescent, but it is a simple distance measure that everyone knows and will serve to illustrate the process. First let's write the function -- using numpy functionality this will be fairly short:" 304 | ] 305 | }, 306 | { 307 | "cell_type": "code", 308 | "execution_count": 9, 309 | "metadata": { 310 | "pycharm": { 311 | "is_executing": false 312 | } 313 | }, 314 | "outputs": [], 315 | "source": [ 316 | "def euclidean(x, y):\n", 317 | " return np.sqrt(np.sum((x - y)**2))" 318 | ] 319 | }, 320 | { 321 | "cell_type": "markdown", 322 | "metadata": {}, 323 | "source": [ 324 | "Now we need to get the function compiled so PyNNDescent can use it. That is actually as easy as adding a decorator to the top of the function telling numba that it should compile the function when it gets called." 325 | ] 326 | }, 327 | { 328 | "cell_type": "code", 329 | "execution_count": 24, 330 | "metadata": { 331 | "pycharm": { 332 | "is_executing": false 333 | } 334 | }, 335 | "outputs": [], 336 | "source": [ 337 | "@numba.jit\n", 338 | "def euclidean(x, y):\n", 339 | " return np.sqrt(np.sum((x - y)**2))" 340 | ] 341 | }, 342 | { 343 | "cell_type": "markdown", 344 | "metadata": {}, 345 | "source": [ 346 | "To ensure our timing doesn't include the numba compile time of compiling this function, let's run it a few times." 347 | ] 348 | }, 349 | { 350 | "cell_type": "code", 351 | "execution_count": 25, 352 | "metadata": {}, 353 | "outputs": [], 354 | "source": [ 355 | "[euclidean(nytimes_train[0], nytimes_train[i]) for i in range(100)];" 356 | ] 357 | }, 358 | { 359 | "cell_type": "markdown", 360 | "metadata": {}, 361 | "source": [ 362 | "We can now pass this function directly to PyNNdescent as a metric and everything will \"just work\". We'll just train on the smaller test set since it will take a while." 363 | ] 364 | }, 365 | { 366 | "cell_type": "code", 367 | "execution_count": 21, 368 | "metadata": { 369 | "pycharm": { 370 | "is_executing": false 371 | } 372 | }, 373 | "outputs": [ 374 | { 375 | "name": "stdout", 376 | "output_type": "stream", 377 | "text": [ 378 | "CPU times: user 8.72 s, sys: 2.19 s, total: 10.9 s\n", 379 | "Wall time: 253 ms\n" 380 | ] 381 | } 382 | ], 383 | "source": [ 384 | "%%time\n", 385 | "index = pynndescent.NNDescent(nytimes_test, metric=euclidean)" 386 | ] 387 | }, 388 | { 389 | "cell_type": "markdown", 390 | "metadata": {}, 391 | "source": [ 392 | "This is a little slower than we might have expected, and that's because a great deal of the computation time is spent evaluating that metric. While numba will compile what we wrote we can make it a little faster if we look through the [numba performance tips documentation](https://numba.readthedocs.io/en/stable/user/performance-tips.html). The two main things to note are that we can use explicit loops instead of numpy routines, and we can add arguments to the decorator such as ``fastmath=True`` to speed things up a little. 
Let's rewrite it:" 393 | ] 394 | }, 395 | { 396 | "cell_type": "code", 397 | "execution_count": null, 398 | "metadata": { 399 | "pycharm": { 400 | "is_executing": false 401 | } 402 | }, 403 | "outputs": [], 404 | "source": [ 405 | "@numba.jit(fastmath=True)\n", 406 | "def euclidean(x, y):\n", 407 | " result = 0.0\n", 408 | " for i in range(x.shape[0]):\n", 409 | " result += (x[i] - y[i])**2\n", 410 | " return np.sqrt(result)\n", 411 | "\n", 412 | "[euclidean(nytimes_train[0], nytimes_train[i]) for i in range(100)];" 413 | ] 414 | }, 415 | { 416 | "cell_type": "code", 417 | "execution_count": 17, 418 | "metadata": { 419 | "pycharm": { 420 | "is_executing": false 421 | } 422 | }, 423 | "outputs": [ 424 | { 425 | "name": "stdout", 426 | "output_type": "stream", 427 | "text": [ 428 | "CPU times: user 6.39 s, sys: 2.46 s, total: 8.85 s\n", 429 | "Wall time: 202 ms\n" 430 | ] 431 | } 432 | ], 433 | "source": [ 434 | "%%time\n", 435 | "index = pynndescent.NNDescent(nytimes_test, metric=euclidean)" 436 | ] 437 | }, 438 | { 439 | "cell_type": "markdown", 440 | "metadata": {}, 441 | "source": [ 442 | "That is faster! If we are really on the hunt for performance however, you might note that, for the purposes of finding nearest neighbors the exact values of the distance are not as important as the ordering on distances. In other words we could use the square of euclidean distance and we would get all the same neighbors (since the square root is a monotonic order preserving function of squared euclidean distance). That would, for example, save us a square root computation. We could do the square roots afterwards to just the distances to the nearest neighbors. Let's reproduce what PyNNDescent actually uses internally for euclidean distance:" 443 | ] 444 | }, 445 | { 446 | "cell_type": "code", 447 | "execution_count": null, 448 | "metadata": { 449 | "pycharm": { 450 | "is_executing": false 451 | } 452 | }, 453 | "outputs": [], 454 | "source": [ 455 | "@numba.njit(\n", 456 | " [\n", 457 | " \"f4(f4[::1],f4[::1])\",\n", 458 | " numba.types.float32(\n", 459 | " numba.types.Array(numba.types.float32, 1, \"C\", readonly=True),\n", 460 | " numba.types.Array(numba.types.float32, 1, \"C\", readonly=True),\n", 461 | " ),\n", 462 | " ],\n", 463 | " fastmath=True,\n", 464 | " locals={\n", 465 | " \"result\": numba.types.float32,\n", 466 | " \"diff\": numba.types.float32,\n", 467 | " \"dim\": numba.types.uint32,\n", 468 | " \"i\": numba.types.uint16,\n", 469 | " },\n", 470 | ")\n", 471 | "def squared_euclidean(x, y):\n", 472 | " r\"\"\"Squared euclidean distance.\n", 473 | "\n", 474 | " .. math::\n", 475 | " D(x, y) = \\sum_i (x_i - y_i)^2\n", 476 | " \"\"\"\n", 477 | " result = 0.0\n", 478 | " dim = x.shape[0]\n", 479 | " for i in range(dim):\n", 480 | " diff = x[i] - y[i]\n", 481 | " result += diff * diff\n", 482 | "\n", 483 | " return result\n", 484 | "\n", 485 | "[squared_euclidean(nytimes_train[0], nytimes_train[i]) for i in range(100)];" 486 | ] 487 | }, 488 | { 489 | "cell_type": "markdown", 490 | "metadata": {}, 491 | "source": [ 492 | "That is definitely more complicated! Most of it, however, is arguments to the decorator giving it extra typing information to let it squeeze out every drop of performance possible. By default numba will infer types, or even compile different versions for the different types it sees. With a little extra information, however, it can make smarter decisions and optimizations during compilation. 
Let's see how fast that goes:" 493 | ] 494 | }, 495 | { 496 | "cell_type": "code", 497 | "execution_count": 18, 498 | "metadata": { 499 | "pycharm": { 500 | "is_executing": false 501 | } 502 | }, 503 | "outputs": [ 504 | { 505 | "name": "stdout", 506 | "output_type": "stream", 507 | "text": [ 508 | "CPU times: user 5.5 s, sys: 2.02 s, total: 7.52 s\n", 509 | "Wall time: 173 ms\n" 510 | ] 511 | } 512 | ], 513 | "source": [ 514 | "%%time\n", 515 | "index = pynndescent.NNDescent(nytimes_test, metric=squared_euclidean)" 516 | ] 517 | }, 518 | { 519 | "cell_type": "markdown", 520 | "metadata": {}, 521 | "source": [ 522 | "Definitely faster again -- so there are significant gains to be had if you are willing to put in some work to write your function. Still, the naive approach we started with, just decorating the obvious implementation, did very well, so unless you desperately need top tier performance for your custom metric a straightforward approach will suffice. And for comparison here is the tailored C++ implementation that libraries like [nmslib](https://github.com/nmslib/nmslib) and [hnswlib](https://github.com/nmslib/hnswlib) use:\n", 523 | "\n", 524 | "```C++\n", 525 | "static float\n", 526 | "L2SqrSIMD16Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {\n", 527 | " float *pVect1 = (float *) pVect1v;\n", 528 | " float *pVect2 = (float *) pVect2v;\n", 529 | " size_t qty = *((size_t *) qty_ptr);\n", 530 | " float PORTABLE_ALIGN32 TmpRes[8];\n", 531 | " size_t qty16 = qty >> 4;\n", 532 | "\n", 533 | " const float *pEnd1 = pVect1 + (qty16 << 4);\n", 534 | "\n", 535 | " __m256 diff, v1, v2;\n", 536 | " __m256 sum = _mm256_set1_ps(0);\n", 537 | "\n", 538 | " while (pVect1 < pEnd1) {\n", 539 | " v1 = _mm256_loadu_ps(pVect1);\n", 540 | " pVect1 += 8;\n", 541 | " v2 = _mm256_loadu_ps(pVect2);\n", 542 | " pVect2 += 8;\n", 543 | " diff = _mm256_sub_ps(v1, v2);\n", 544 | " sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff));\n", 545 | "\n", 546 | " v1 = _mm256_loadu_ps(pVect1);\n", 547 | " pVect1 += 8;\n", 548 | " v2 = _mm256_loadu_ps(pVect2);\n", 549 | " pVect2 += 8;\n", 550 | " diff = _mm256_sub_ps(v1, v2);\n", 551 | " sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff));\n", 552 | " }\n", 553 | "\n", 554 | " _mm256_store_ps(TmpRes, sum);\n", 555 | " return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7];\n", 556 | "}\n", 557 | "```\n", 558 | "\n", 559 | "Comparatively, the python code, even with its extra numba decorations, looks pretty straightforward. Notably (at last testing) the numba code and this C++ code (when suitably compiled with AVX flags etc.) have essentially indistinguishable performance. Numba is awfully good at finding optimizations for numerical code." 560 | ] 561 | }, 562 | { 563 | "cell_type": "markdown", 564 | "metadata": {}, 565 | "source": [ 566 | "### Beware of bounded distances\n", 567 | "\n", 568 | "There is one remaining caveat on custom distance functions that is important. Many distances, such as cosine distance and jaccard distance are bounded: the values always fall in some fixed finite range (in these cases between 0 and 1). When querying new data points against an index PyNNDescent bounds the search by some multiple (1 + epsilon) of the most distant of the the top k neighbors found so far. This allows a limited amount of backtracking and avoids getting stuck in local minima. 
It does, however, not play well with bounded distances -- a small but non-zero epsilon can end up failing to bound the search at all (suppose epsilon is 0.2 and the most distant of the the top k neighbors has cosine distance 0.8 for example). The trick to getting around this is the same trick described above when we decided not to bother taking the square root of the euclidean distance -- we can apply transform to the distance values that preserves all ordering. This means that, for example, internally PyNNDescent uses the *negative log* of the cosine *similarity* instead of cosine distance (and converts the distance values when done). You will want to use a similar trick if your distance function has a strict finite upper bound." 569 | ] 570 | } 571 | ], 572 | "metadata": { 573 | "kernelspec": { 574 | "display_name": "pynndescent_dev", 575 | "language": "python", 576 | "name": "python3" 577 | }, 578 | "language_info": { 579 | "codemirror_mode": { 580 | "name": "ipython", 581 | "version": 3 582 | }, 583 | "file_extension": ".py", 584 | "mimetype": "text/x-python", 585 | "name": "python", 586 | "nbconvert_exporter": "python", 587 | "pygments_lexer": "ipython3", 588 | "version": "3.12.12" 589 | } 590 | }, 591 | "nbformat": 4, 592 | "nbformat_minor": 4 593 | } 594 | -------------------------------------------------------------------------------- /pynndescent/tests/test_pynndescent_.py: -------------------------------------------------------------------------------- 1 | import os 2 | import io 3 | import re 4 | import pathlib 5 | import pytest 6 | from contextlib import redirect_stdout 7 | 8 | import numpy as np 9 | from sklearn.neighbors import KDTree 10 | from sklearn.neighbors import NearestNeighbors 11 | from sklearn.preprocessing import normalize 12 | import pickle 13 | import joblib 14 | import scipy 15 | 16 | from pynndescent import NNDescent, PyNNDescentTransformer 17 | 18 | 19 | def test_nn_descent_neighbor_accuracy(nn_data, seed): 20 | knn_indices, _ = NNDescent( 21 | nn_data, "euclidean", {}, 10, random_state=np.random.RandomState(seed) 22 | )._neighbor_graph 23 | 24 | tree = KDTree(nn_data) 25 | true_indices = tree.query(nn_data, 10, return_distance=False) 26 | 27 | num_correct = 0.0 28 | for i in range(nn_data.shape[0]): 29 | num_correct += np.sum(np.isin(true_indices[i], knn_indices[i])) 30 | 31 | percent_correct = num_correct / (nn_data.shape[0] * 10) 32 | assert ( 33 | percent_correct >= 0.98 34 | ), "NN-descent did not get 99% accuracy on nearest neighbors" 35 | 36 | 37 | def test_angular_nn_descent_neighbor_accuracy(nn_data, seed): 38 | knn_indices, _ = NNDescent( 39 | nn_data, "cosine", {}, 10, random_state=np.random.RandomState(seed) 40 | )._neighbor_graph 41 | 42 | angular_data = normalize(nn_data, norm="l2") 43 | tree = KDTree(angular_data) 44 | true_indices = tree.query(angular_data, 10, return_distance=False) 45 | 46 | num_correct = 0.0 47 | for i in range(nn_data.shape[0]): 48 | num_correct += np.sum(np.isin(true_indices[i], knn_indices[i])) 49 | 50 | percent_correct = num_correct / (nn_data.shape[0] * 10) 51 | assert ( 52 | percent_correct >= 0.98 53 | ), "NN-descent did not get 99% accuracy on nearest neighbors" 54 | 55 | 56 | def test_bitpacked_nn_descent_neighbor_accuracy(nn_data, seed): 57 | bitpacked_data = (nn_data * 256).astype(np.uint8) 58 | unpacked_data = np.zeros( 59 | (bitpacked_data.shape[0], bitpacked_data.shape[1] * 8), dtype=np.float32 60 | ) 61 | for i in range(unpacked_data.shape[0]): 62 | for j in range(unpacked_data.shape[1]): 63 | unpacked_data[i, 
j] = (bitpacked_data[i, j // 8] & (1 << (j % 8))) > 0 64 | 65 | knn_indices, _ = NNDescent( 66 | bitpacked_data, "bit_jaccard", {}, 10, random_state=np.random.RandomState(seed) 67 | )._neighbor_graph 68 | 69 | nn_finder = NearestNeighbors(n_neighbors=10, metric="jaccard").fit(unpacked_data) 70 | true_indices = nn_finder.kneighbors(unpacked_data, 10, return_distance=False) 71 | 72 | num_correct = 0.0 73 | for i in range(nn_data.shape[0]): 74 | num_correct += np.sum(np.isin(true_indices[i], knn_indices[i])) 75 | 76 | percent_correct = num_correct / (nn_data.shape[0] * 10) 77 | assert ( 78 | percent_correct >= 0.60 79 | ), "NN-descent did not get 60% accuracy on nearest neighbors" 80 | 81 | 82 | @pytest.mark.skipif( 83 | list(map(int, re.findall(r"[0-9]+\.[0-9]+\.?[0-9]*", scipy.version.version)[0].split("."))) < [1, 3, 0], 84 | reason="requires scipy >= 1.3.0", 85 | ) 86 | def test_sparse_nn_descent_neighbor_accuracy(sparse_nn_data, seed): 87 | knn_indices, _ = NNDescent( 88 | sparse_nn_data, "euclidean", n_neighbors=20, random_state=None 89 | )._neighbor_graph 90 | 91 | tree = KDTree(sparse_nn_data.toarray()) 92 | true_indices = tree.query(sparse_nn_data.toarray(), 10, return_distance=False) 93 | 94 | num_correct = 0.0 95 | for i in range(sparse_nn_data.shape[0]): 96 | num_correct += np.sum(np.isin(true_indices[i], knn_indices[i])) 97 | 98 | percent_correct = num_correct / (sparse_nn_data.shape[0] * 10) 99 | assert ( 100 | percent_correct >= 0.85 101 | ), "Sparse NN-descent did not get 95% accuracy on nearest neighbors" 102 | 103 | 104 | @pytest.mark.skipif( 105 | list(map(int, scipy.version.version.split("."))) < [1, 3, 0], 106 | reason="requires scipy >= 1.3.0", 107 | ) 108 | def test_sparse_angular_nn_descent_neighbor_accuracy(sparse_nn_data): 109 | knn_indices, _ = NNDescent( 110 | sparse_nn_data, "cosine", {}, 20, random_state=None 111 | )._neighbor_graph 112 | 113 | angular_data = normalize(sparse_nn_data, norm="l2").toarray() 114 | tree = KDTree(angular_data) 115 | true_indices = tree.query(angular_data, 10, return_distance=False) 116 | 117 | num_correct = 0.0 118 | for i in range(sparse_nn_data.shape[0]): 119 | num_correct += np.sum(np.isin(true_indices[i], knn_indices[i])) 120 | 121 | percent_correct = num_correct / (sparse_nn_data.shape[0] * 10) 122 | assert ( 123 | percent_correct >= 0.85 124 | ), "Sparse angular NN-descent did not get 98% accuracy on nearest neighbors" 125 | 126 | 127 | def test_nn_descent_query_accuracy(nn_data): 128 | nnd = NNDescent(nn_data[200:], "euclidean", n_neighbors=10, random_state=None) 129 | knn_indices, _ = nnd.query(nn_data[:200], k=10, epsilon=0.2) 130 | 131 | tree = KDTree(nn_data[200:]) 132 | true_indices = tree.query(nn_data[:200], 10, return_distance=False) 133 | 134 | num_correct = 0.0 135 | for i in range(true_indices.shape[0]): 136 | num_correct += np.sum(np.isin(true_indices[i], knn_indices[i])) 137 | 138 | percent_correct = num_correct / (true_indices.shape[0] * 10) 139 | assert ( 140 | percent_correct >= 0.95 141 | ), "NN-descent query did not get 95% accuracy on nearest neighbors" 142 | 143 | 144 | def test_nn_descent_query_accuracy_angular(nn_data): 145 | nnd = NNDescent(nn_data[200:], "cosine", n_neighbors=30, random_state=None) 146 | knn_indices, _ = nnd.query(nn_data[:200], k=10, epsilon=0.32) 147 | 148 | nn = NearestNeighbors(metric="cosine").fit(nn_data[200:]) 149 | true_indices = nn.kneighbors(nn_data[:200], n_neighbors=10, return_distance=False) 150 | 151 | num_correct = 0.0 152 | for i in range(true_indices.shape[0]): 153 | 
num_correct += np.sum(np.isin(true_indices[i], knn_indices[i])) 154 | 155 | percent_correct = num_correct / (true_indices.shape[0] * 10) 156 | assert ( 157 | percent_correct >= 0.95 158 | ), "NN-descent query did not get 95% accuracy on nearest neighbors" 159 | 160 | 161 | def test_sparse_nn_descent_query_accuracy(sparse_nn_data): 162 | nnd = NNDescent( 163 | sparse_nn_data[200:], "euclidean", n_neighbors=15, random_state=None 164 | ) 165 | knn_indices, _ = nnd.query(sparse_nn_data[:200], k=10, epsilon=0.24) 166 | 167 | tree = KDTree(sparse_nn_data[200:].toarray()) 168 | true_indices = tree.query(sparse_nn_data[:200].toarray(), 10, return_distance=False) 169 | 170 | num_correct = 0.0 171 | for i in range(true_indices.shape[0]): 172 | num_correct += np.sum(np.isin(true_indices[i], knn_indices[i])) 173 | 174 | percent_correct = num_correct / (true_indices.shape[0] * 10) 175 | assert ( 176 | percent_correct >= 0.95 177 | ), "Sparse NN-descent query did not get 95% accuracy on nearest neighbors" 178 | 179 | 180 | def test_sparse_nn_descent_query_accuracy_angular(sparse_nn_data): 181 | nnd = NNDescent(sparse_nn_data[200:], "cosine", n_neighbors=50, random_state=None) 182 | knn_indices, _ = nnd.query(sparse_nn_data[:200], k=10, epsilon=0.36) 183 | 184 | nn = NearestNeighbors(metric="cosine").fit(sparse_nn_data[200:].toarray()) 185 | true_indices = nn.kneighbors( 186 | sparse_nn_data[:200].toarray(), n_neighbors=10, return_distance=False 187 | ) 188 | 189 | num_correct = 0.0 190 | for i in range(true_indices.shape[0]): 191 | num_correct += np.sum(np.isin(true_indices[i], knn_indices[i])) 192 | 193 | percent_correct = num_correct / (true_indices.shape[0] * 10) 194 | assert ( 195 | percent_correct >= 0.95 196 | ), "Sparse NN-descent query did not get 95% accuracy on nearest neighbors" 197 | 198 | 199 | def test_bitpacked_nn_descent_query_accuracy(nn_data): 200 | bitpacked_data = (nn_data * 256).astype(np.uint8) 201 | unpacked_data = np.zeros( 202 | (bitpacked_data.shape[0], bitpacked_data.shape[1] * 8), dtype=np.float32 203 | ) 204 | for i in range(unpacked_data.shape[0]): 205 | for j in range(unpacked_data.shape[1]): 206 | unpacked_data[i, j] = (bitpacked_data[i, j // 8] & (1 << (j % 8))) > 0 207 | 208 | nnd = NNDescent( 209 | bitpacked_data[200:], "bit_jaccard", n_neighbors=50, random_state=None 210 | ) 211 | knn_indices, _ = nnd.query(bitpacked_data[:200], k=10, epsilon=0.36) 212 | 213 | nn = NearestNeighbors(metric="jaccard").fit(unpacked_data[200:]) 214 | true_indices = nn.kneighbors( 215 | unpacked_data[:200], n_neighbors=10, return_distance=False 216 | ) 217 | 218 | num_correct = 0.0 219 | for i in range(true_indices.shape[0]): 220 | num_correct += np.sum(np.isin(true_indices[i], knn_indices[i])) 221 | 222 | percent_correct = num_correct / (true_indices.shape[0] * 10) 223 | assert ( 224 | percent_correct >= 0.80 225 | ), "Sparse NN-descent query did not get 95% accuracy on nearest neighbors" 226 | 227 | 228 | def test_transformer_equivalence(nn_data): 229 | N_NEIGHBORS = 15 230 | EPSILON = 0.15 231 | train = nn_data[:400] 232 | test = nn_data[:200] 233 | 234 | # Note we shift N_NEIGHBORS to conform to sklearn's KNeighborTransformer defn 235 | nnd = NNDescent( 236 | data=train, n_neighbors=N_NEIGHBORS + 1, random_state=42, compressed=False 237 | ) 238 | indices, dists = nnd.query(test, k=N_NEIGHBORS, epsilon=EPSILON) 239 | sort_idx = np.argsort(indices, axis=1) 240 | indices_sorted = np.vstack( 241 | [indices[i, sort_idx[i]] for i in range(sort_idx.shape[0])] 242 | ) 243 | dists_sorted = 
np.vstack([dists[i, sort_idx[i]] for i in range(sort_idx.shape[0])]) 244 | 245 | # Note we shift N_NEIGHBORS to conform to sklearn' KNeighborTransformer defn 246 | transformer = PyNNDescentTransformer( 247 | n_neighbors=N_NEIGHBORS, search_epsilon=EPSILON, random_state=42 248 | ).fit(train, compress_index=False) 249 | Xt = transformer.transform(test).sorted_indices() 250 | 251 | assert np.all(Xt.indices == indices_sorted.flatten()) 252 | assert np.allclose(Xt.data, dists_sorted.flat) 253 | 254 | 255 | def test_random_state_none(nn_data, spatial_data): 256 | knn_indices, _ = NNDescent( 257 | nn_data, "euclidean", {}, 10, random_state=None 258 | )._neighbor_graph 259 | 260 | tree = KDTree(nn_data) 261 | true_indices = tree.query(nn_data, 10, return_distance=False) 262 | 263 | num_correct = 0.0 264 | for i in range(nn_data.shape[0]): 265 | num_correct += np.sum(np.isin(true_indices[i], knn_indices[i])) 266 | 267 | percent_correct = num_correct / (spatial_data.shape[0] * 10) 268 | assert ( 269 | percent_correct >= 0.99 270 | ), "NN-descent did not get 99% accuracy on nearest neighbors" 271 | 272 | 273 | def test_deterministic(): 274 | seed = np.random.RandomState(42) 275 | 276 | x1 = seed.normal(0, 100, (1000, 50)) 277 | x2 = seed.normal(0, 100, (1000, 50)) 278 | 279 | index1 = NNDescent(x1, random_state=np.random.RandomState(42)) 280 | neighbors1, distances1 = index1.query(x2) 281 | 282 | index2 = NNDescent(x1, random_state=np.random.RandomState(42)) 283 | neighbors2, distances2 = index2.query(x2) 284 | 285 | np.testing.assert_equal(neighbors1, neighbors2) 286 | np.testing.assert_equal(distances1, distances2) 287 | 288 | 289 | # This tests a recursion error on cosine metric reported at: 290 | # https://github.com/lmcinnes/umap/issues/99 291 | # graph_data used is a cut-down version of that provided by @scharron 292 | # It contains lots of all-zero vectors and some other duplicates 293 | def test_rp_trees_should_not_stack_overflow_with_duplicate_data(seed, cosine_hang_data): 294 | 295 | n_neighbors = 10 296 | knn_indices, _ = NNDescent( 297 | cosine_hang_data, 298 | "cosine", 299 | {}, 300 | n_neighbors, 301 | random_state=np.random.RandomState(seed), 302 | n_trees=20, 303 | )._neighbor_graph 304 | 305 | for i in range(cosine_hang_data.shape[0]): 306 | assert len(knn_indices[i]) == len( 307 | np.unique(knn_indices[i]) 308 | ), "Duplicate graph_indices in knn graph" 309 | 310 | 311 | def test_deduplicated_data_behaves_normally(seed, cosine_hang_data): 312 | 313 | data = np.unique(cosine_hang_data, axis=0) 314 | data = data[~np.all(data == 0, axis=1)] 315 | data = data[:1000] 316 | 317 | n_neighbors = 10 318 | knn_indices, _ = NNDescent( 319 | data, 320 | "cosine", 321 | {}, 322 | n_neighbors, 323 | random_state=np.random.RandomState(seed), 324 | n_trees=20, 325 | )._neighbor_graph 326 | 327 | for i in range(data.shape[0]): 328 | assert len(knn_indices[i]) == len( 329 | np.unique(knn_indices[i]) 330 | ), "Duplicate graph_indices in knn graph" 331 | 332 | angular_data = normalize(data, norm="l2") 333 | tree = KDTree(angular_data) 334 | true_indices = tree.query(angular_data, n_neighbors, return_distance=False) 335 | 336 | num_correct = 0 337 | for i in range(data.shape[0]): 338 | num_correct += np.sum(np.isin(true_indices[i], knn_indices[i])) 339 | 340 | proportion_correct = num_correct / (data.shape[0] * n_neighbors) 341 | assert ( 342 | proportion_correct >= 0.95 343 | ), "NN-descent did not get 95% accuracy on nearest neighbors" 344 | 345 | 346 | def 
test_rp_trees_should_not_stack_overflow_with_near_duplicate_data( 347 | seed, cosine_near_duplicates_data 348 | ): 349 | 350 | n_neighbors = 10 351 | knn_indices, _ = NNDescent( 352 | cosine_near_duplicates_data, 353 | "cosine", 354 | {}, 355 | n_neighbors, 356 | random_state=np.random.RandomState(seed), 357 | n_trees=20, 358 | )._neighbor_graph 359 | 360 | for i in range(cosine_near_duplicates_data.shape[0]): 361 | assert len(knn_indices[i]) == len( 362 | np.unique(knn_indices[i]) 363 | ), "Duplicate graph_indices in knn graph" 364 | 365 | 366 | def test_output_when_verbose_is_true(spatial_data, seed): 367 | out = io.StringIO() 368 | with redirect_stdout(out): 369 | _ = NNDescent( 370 | data=spatial_data, 371 | metric="euclidean", 372 | metric_kwds={}, 373 | n_neighbors=4, 374 | random_state=np.random.RandomState(seed), 375 | n_trees=5, 376 | n_iters=2, 377 | verbose=True, 378 | ) 379 | output = out.getvalue() 380 | assert re.match("^.*5 trees", output, re.DOTALL) 381 | assert re.match("^.*2 iterations", output, re.DOTALL) 382 | 383 | 384 | def test_no_output_when_verbose_is_false(spatial_data, seed): 385 | out = io.StringIO() 386 | with redirect_stdout(out): 387 | _ = NNDescent( 388 | data=spatial_data, 389 | metric="euclidean", 390 | metric_kwds={}, 391 | n_neighbors=4, 392 | random_state=np.random.RandomState(seed), 393 | n_trees=5, 394 | n_iters=2, 395 | verbose=False, 396 | ) 397 | output = out.getvalue().strip() 398 | assert len(output) == 0 399 | 400 | 401 | # same as the previous two test, but this time using the PyNNDescentTransformer 402 | # interface 403 | def test_transformer_output_when_verbose_is_true(spatial_data, seed): 404 | out = io.StringIO() 405 | with redirect_stdout(out): 406 | _ = PyNNDescentTransformer( 407 | n_neighbors=4, 408 | metric="euclidean", 409 | metric_kwds={}, 410 | random_state=np.random.RandomState(seed), 411 | n_trees=5, 412 | n_iters=2, 413 | verbose=True, 414 | ).fit_transform(spatial_data) 415 | output = out.getvalue() 416 | assert re.match("^.*5 trees", output, re.DOTALL) 417 | assert re.match("^.*2 iterations", output, re.DOTALL) 418 | 419 | 420 | def test_transformer_output_when_verbose_is_false(spatial_data, seed): 421 | out = io.StringIO() 422 | with redirect_stdout(out): 423 | _ = PyNNDescentTransformer( 424 | n_neighbors=4, 425 | metric="standardised_euclidean", 426 | metric_kwds={"sigma": np.ones(spatial_data.shape[1])}, 427 | random_state=np.random.RandomState(seed), 428 | n_trees=5, 429 | n_iters=2, 430 | verbose=False, 431 | ).fit_transform(spatial_data) 432 | output = out.getvalue().strip() 433 | assert len(output) == 0 434 | 435 | 436 | def test_pickle_unpickle(): 437 | seed = np.random.RandomState(42) 438 | 439 | x1 = seed.normal(0, 100, (1000, 50)) 440 | x2 = seed.normal(0, 100, (1000, 50)) 441 | 442 | index1 = NNDescent(x1, "euclidean", {}, 10, random_state=None) 443 | neighbors1, distances1 = index1.query(x2) 444 | 445 | mem_temp = io.BytesIO() 446 | pickle.dump(index1, mem_temp) 447 | mem_temp.seek(0) 448 | index2 = pickle.load(mem_temp) 449 | 450 | neighbors2, distances2 = index2.query(x2) 451 | 452 | np.testing.assert_equal(neighbors1, neighbors2) 453 | np.testing.assert_equal(distances1, distances2) 454 | 455 | 456 | def test_compressed_pickle_unpickle(): 457 | seed = np.random.RandomState(42) 458 | 459 | x1 = seed.normal(0, 100, (1000, 50)) 460 | x2 = seed.normal(0, 100, (1000, 50)) 461 | 462 | index1 = NNDescent(x1, "euclidean", {}, 10, random_state=None, compressed=True) 463 | neighbors1, distances1 = index1.query(x2) 464 | 
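    # Round-trip the compressed index through an in-memory pickle and check
    # that the reloaded index reproduces the original query results exactly.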
465 | mem_temp = io.BytesIO() 466 | pickle.dump(index1, mem_temp) 467 | mem_temp.seek(0) 468 | index2 = pickle.load(mem_temp) 469 | 470 | neighbors2, distances2 = index2.query(x2) 471 | 472 | np.testing.assert_equal(neighbors1, neighbors2) 473 | np.testing.assert_equal(distances1, distances2) 474 | 475 | 476 | def test_transformer_pickle_unpickle(): 477 | seed = np.random.RandomState(42) 478 | 479 | x1 = seed.normal(0, 100, (1000, 50)) 480 | x2 = seed.normal(0, 100, (1000, 50)) 481 | 482 | index1 = PyNNDescentTransformer(n_neighbors=10).fit(x1) 483 | result1 = index1.transform(x2) 484 | 485 | mem_temp = io.BytesIO() 486 | pickle.dump(index1, mem_temp) 487 | mem_temp.seek(0) 488 | index2 = pickle.load(mem_temp) 489 | 490 | result2 = index2.transform(x2) 491 | 492 | np.testing.assert_equal(result1.indices, result2.indices) 493 | np.testing.assert_equal(result1.data, result2.data) 494 | 495 | 496 | def test_joblib_dump(): 497 | seed = np.random.RandomState(42) 498 | 499 | x1 = seed.normal(0, 100, (1000, 50)) 500 | x2 = seed.normal(0, 100, (1000, 50)) 501 | 502 | index1 = NNDescent(x1, "euclidean", {}, 10, random_state=None) 503 | neighbors1, distances1 = index1.query(x2) 504 | 505 | mem_temp = io.BytesIO() 506 | joblib.dump(index1, mem_temp) 507 | mem_temp.seek(0) 508 | index2 = joblib.load(mem_temp) 509 | 510 | neighbors2, distances2 = index2.query(x2) 511 | 512 | np.testing.assert_equal(neighbors1, neighbors2) 513 | np.testing.assert_equal(distances1, distances2) 514 | 515 | 516 | @pytest.mark.parametrize("metric", ["euclidean", "cosine"]) 517 | def test_update_no_prepare_query_accuracy(nn_data, metric): 518 | nnd = NNDescent(nn_data[200:800], metric=metric, n_neighbors=10, random_state=None) 519 | nnd.update(xs_fresh=nn_data[800:]) 520 | 521 | knn_indices, _ = nnd.query(nn_data[:200], k=10, epsilon=0.2) 522 | 523 | true_nnd = NearestNeighbors(metric=metric).fit(nn_data[200:]) 524 | true_indices = true_nnd.kneighbors(nn_data[:200], 10, return_distance=False) 525 | 526 | num_correct = 0.0 527 | for i in range(true_indices.shape[0]): 528 | num_correct += np.sum(np.isin(true_indices[i], knn_indices[i])) 529 | 530 | percent_correct = num_correct / (true_indices.shape[0] * 10) 531 | assert percent_correct >= 0.95, ( 532 | "NN-descent query did not get 95% " "accuracy on nearest neighbors" 533 | ) 534 | 535 | 536 | @pytest.mark.parametrize("metric", ["euclidean", "cosine"]) 537 | def test_update_w_prepare_query_accuracy(nn_data, metric): 538 | nnd = NNDescent( 539 | nn_data[200:800], 540 | metric=metric, 541 | n_neighbors=10, 542 | random_state=None, 543 | compressed=False, 544 | ) 545 | nnd.prepare() 546 | 547 | nnd.update(xs_fresh=nn_data[800:]) 548 | nnd.prepare() 549 | 550 | knn_indices, _ = nnd.query(nn_data[:200], k=10, epsilon=0.2) 551 | 552 | true_nnd = NearestNeighbors(metric=metric).fit(nn_data[200:]) 553 | true_indices = true_nnd.kneighbors(nn_data[:200], 10, return_distance=False) 554 | 555 | num_correct = 0.0 556 | for i in range(true_indices.shape[0]): 557 | num_correct += np.sum(np.isin(true_indices[i], knn_indices[i])) 558 | 559 | percent_correct = num_correct / (true_indices.shape[0] * 10) 560 | assert percent_correct >= 0.95, ( 561 | "NN-descent query did not get 95% " "accuracy on nearest neighbors" 562 | ) 563 | 564 | 565 | @pytest.mark.parametrize("metric", ["euclidean", "cosine"]) 566 | def test_update_w_prepare_query_accuracy(nn_data, metric): 567 | nnd = NNDescent( 568 | nn_data[200:800], 569 | metric=metric, 570 | n_neighbors=10, 571 | random_state=None, 572 | 
compressed=False, 573 | ) 574 | nnd.prepare() 575 | 576 | nnd.update(xs_fresh=nn_data[800:]) 577 | nnd.prepare() 578 | 579 | knn_indices, _ = nnd.query(nn_data[:200], k=10, epsilon=0.2) 580 | 581 | true_nnd = NearestNeighbors(metric=metric).fit(nn_data[200:]) 582 | true_indices = true_nnd.kneighbors(nn_data[:200], 10, return_distance=False) 583 | 584 | num_correct = 0.0 585 | for i in range(true_indices.shape[0]): 586 | num_correct += np.sum(np.isin(true_indices[i], knn_indices[i])) 587 | 588 | percent_correct = num_correct / (true_indices.shape[0] * 10) 589 | assert percent_correct >= 0.95, ( 590 | "NN-descent query did not get 95% " "accuracy on nearest neighbors" 591 | ) 592 | 593 | 594 | def evaluate_predictions(neighbors_true, neigbhors_computed, n_neighbors): 595 | n_correct = 0 596 | n_all = neighbors_true.shape[0] * n_neighbors 597 | for i in range(neighbors_true.shape[0]): 598 | n_correct += np.sum(np.isin(neighbors_true[i], neigbhors_computed[i])) 599 | return n_correct / n_all 600 | 601 | 602 | @pytest.mark.parametrize("metric", ["manhattan", "euclidean", "cosine"]) 603 | @pytest.mark.parametrize("case", list(range(8))) # the number of cases in update_data 604 | def test_update_with_changed_data(update_data, case, metric): 605 | def evaluate(nn_descent, xs_to_fit, xs_to_query): 606 | true_nn = NearestNeighbors(metric=metric, n_neighbors=k).fit(xs_to_fit) 607 | neighbors, _ = nn_descent.query(xs_to_query, k=k) 608 | neighbors_expected = true_nn.kneighbors(xs_to_query, k, return_distance=False) 609 | p_correct = evaluate_predictions(neighbors_expected, neighbors, k) 610 | assert p_correct >= 0.95, ( 611 | "NN-descent query did not get 95% " "accuracy on nearest neighbors" 612 | ) 613 | 614 | k = 10 615 | xs_orig, xs_fresh, xs_updated, indices_updated = update_data[case] 616 | queries1 = xs_orig 617 | 618 | # original 619 | index = NNDescent(xs_orig, metric=metric, n_neighbors=40, random_state=1234) 620 | index.prepare() 621 | evaluate(index, xs_orig, queries1) 622 | # updated 623 | index.update( 624 | xs_fresh=xs_fresh, xs_updated=xs_updated, updated_indices=indices_updated 625 | ) 626 | if xs_fresh is not None: 627 | xs = np.vstack((xs_orig, xs_fresh)) 628 | queries2 = np.vstack((queries1, xs_fresh)) 629 | else: 630 | xs = xs_orig 631 | queries2 = queries1 632 | if indices_updated is not None: 633 | xs[indices_updated] = xs_updated 634 | evaluate(index, xs, queries2) 635 | if indices_updated is not None: 636 | evaluate(index, xs, xs_updated) 637 | 638 | 639 | @pytest.mark.parametrize("n_trees", [1, 2, 3, 10]) 640 | def test_tree_numbers_after_multiple_updates(n_trees): 641 | trees_after_update = max(1, int(np.round(n_trees / 3))) 642 | 643 | nnd = NNDescent(np.array([[1.0]]), n_neighbors=1, n_trees=n_trees) 644 | 645 | assert nnd.n_trees == n_trees, "NN-descent update changed the number of trees" 646 | assert ( 647 | nnd.n_trees_after_update == trees_after_update 648 | ), "The value of the n_trees_after_update in NN-descent after update(s) is wrong" 649 | for i in range(5): 650 | nnd.update(xs_fresh=np.array([[i]], dtype=np.float64)) 651 | assert ( 652 | nnd.n_trees == trees_after_update 653 | ), "The value of the n_trees in NN-descent after update(s) is wrong" 654 | assert ( 655 | nnd.n_trees_after_update == trees_after_update 656 | ), "The value of the n_trees_after_update in NN-descent after update(s) is wrong" 657 | 658 | 659 | @pytest.mark.parametrize("metric", ["euclidean", "cosine"]) 660 | def test_tree_init_false(nn_data, metric): 661 | nnd = NNDescent( 662 | 
nn_data[200:], metric=metric, n_neighbors=10, random_state=None, tree_init=False 663 | ) 664 | nnd.prepare() 665 | 666 | knn_indices, _ = nnd.query(nn_data[:200], k=10, epsilon=0.2) 667 | 668 | true_nnd = NearestNeighbors(metric=metric).fit(nn_data[200:]) 669 | true_indices = true_nnd.kneighbors(nn_data[:200], 10, return_distance=False) 670 | 671 | num_correct = 0.0 672 | for i in range(true_indices.shape[0]): 673 | num_correct += np.sum(np.isin(true_indices[i], knn_indices[i])) 674 | 675 | percent_correct = num_correct / (true_indices.shape[0] * 10) 676 | assert percent_correct >= 0.95, ( 677 | "NN-descent query did not get 95% " "accuracy on nearest neighbors" 678 | ) 679 | 680 | 681 | @pytest.mark.parametrize( 682 | "metric", ["euclidean", "manhattan"] 683 | ) # cosine makes no sense for 1D 684 | def test_one_dimensional_data(nn_data, metric): 685 | nnd = NNDescent( 686 | nn_data[200:, :1], 687 | metric=metric, 688 | n_neighbors=20, 689 | random_state=None, 690 | tree_init=False, 691 | ) 692 | nnd.prepare() 693 | 694 | knn_indices, _ = nnd.query(nn_data[:200, :1], k=10, epsilon=0.2) 695 | 696 | true_nnd = NearestNeighbors(metric=metric).fit(nn_data[200:, :1]) 697 | true_indices = true_nnd.kneighbors(nn_data[:200, :1], 10, return_distance=False) 698 | 699 | num_correct = 0.0 700 | for i in range(true_indices.shape[0]): 701 | num_correct += np.sum(np.isin(true_indices[i], knn_indices[i])) 702 | 703 | percent_correct = num_correct / (true_indices.shape[0] * 10) 704 | assert percent_correct >= 0.95, ( 705 | "NN-descent query did not get 95% " "accuracy on nearest neighbors" 706 | ) 707 | 708 | 709 | @pytest.mark.parametrize("metric", ["euclidean", "cosine"]) 710 | def test_tree_no_split(small_data, sparse_small_data, metric): 711 | k = 10 712 | for data, data_type in zip([small_data, sparse_small_data], ["dense", "sparse"]): 713 | n_instances = data.shape[0] 714 | leaf_size = n_instances + 1 # just to be safe 715 | data_train = data[n_instances // 2 :] 716 | data_test = data[: n_instances // 2] 717 | 718 | nnd = NNDescent( 719 | data_train, 720 | metric=metric, 721 | n_neighbors=data_train.shape[0] - 1, 722 | random_state=None, 723 | tree_init=True, 724 | leaf_size=leaf_size, 725 | ) 726 | nnd.prepare() 727 | knn_indices, _ = nnd.query(data_test, k=k, epsilon=0.2) 728 | 729 | true_nnd = NearestNeighbors(metric=metric).fit(data_train) 730 | true_indices = true_nnd.kneighbors(data_test, k, return_distance=False) 731 | 732 | num_correct = 0.0 733 | for i in range(true_indices.shape[0]): 734 | num_correct += np.sum(np.isin(true_indices[i], knn_indices[i])) 735 | 736 | percent_correct = num_correct / (true_indices.shape[0] * k) 737 | assert ( 738 | percent_correct >= 0.95 739 | ), "NN-descent query did not get 95% for accuracy on nearest neighbors on {} data".format( 740 | data_type 741 | ) 742 | 743 | 744 | @pytest.mark.skipif( 745 | "NUMBA_DISABLE_JIT" in os.environ, reason="Too expensive for disabled Numba" 746 | ) 747 | def test_bad_data(): 748 | test_data_dir = pathlib.Path(__file__).parent / "test_data" 749 | data = np.sqrt( 750 | np.load(test_data_dir / "pynndescent_bug_np.npz")["arr_0"] 751 | ) 752 | index = NNDescent(data, metric="cosine") 753 | -------------------------------------------------------------------------------- /pynndescent/utils.py: -------------------------------------------------------------------------------- 1 | # Author: Leland McInnes 2 | # 3 | # License: BSD 2 clause 4 | 5 | import time 6 | 7 | import numba 8 | import numpy as np 9 | 10 | 11 | 
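# Numba-compiled building blocks for pynndescent: the tau_rand* pseudo-random
# helpers, an l2 norm and a rejection-sampling utility, the fixed-size
# neighbor heaps (make_heap, the *_heap_push variants, deheap_sort), and the
# parallel candidate-building and graph-update kernels used by NN-descent.
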
@numba.njit("void(i8[:], i8)", cache=True) 12 | def seed(rng_state, seed): 13 | """Seed the random number generator with a given seed.""" 14 | rng_state.fill(seed + 0xFFFF) 15 | 16 | 17 | @numba.njit("i4(i8[:])", cache=True) 18 | def tau_rand_int(state): 19 | """A fast (pseudo)-random number generator. 20 | 21 | Parameters 22 | ---------- 23 | state: array of int64, shape (3,) 24 | The internal state of the rng 25 | 26 | Returns 27 | ------- 28 | A (pseudo)-random int32 value 29 | """ 30 | state[0] = (((state[0] & 4294967294) << 12) & 0xFFFFFFFF) ^ ( 31 | (((state[0] << 13) & 0xFFFFFFFF) ^ state[0]) >> 19 32 | ) 33 | state[1] = (((state[1] & 4294967288) << 4) & 0xFFFFFFFF) ^ ( 34 | (((state[1] << 2) & 0xFFFFFFFF) ^ state[1]) >> 25 35 | ) 36 | state[2] = (((state[2] & 4294967280) << 17) & 0xFFFFFFFF) ^ ( 37 | (((state[2] << 3) & 0xFFFFFFFF) ^ state[2]) >> 11 38 | ) 39 | 40 | return state[0] ^ state[1] ^ state[2] 41 | 42 | 43 | @numba.njit("f4(i8[:])", cache=True) 44 | def tau_rand(state): 45 | """A fast (pseudo)-random number generator for floats in the range [0,1] 46 | 47 | Parameters 48 | ---------- 49 | state: array of int64, shape (3,) 50 | The internal state of the rng 51 | 52 | Returns 53 | ------- 54 | A (pseudo)-random float32 in the interval [0, 1] 55 | """ 56 | integer = tau_rand_int(state) 57 | return abs(float(integer) / 0x7FFFFFFF) 58 | 59 | 60 | @numba.njit( 61 | [ 62 | "f4(f4[::1])", 63 | numba.types.float32( 64 | numba.types.Array(numba.types.float32, 1, "C", readonly=True) 65 | ), 66 | ], 67 | locals={ 68 | "dim": numba.types.intp, 69 | "i": numba.types.uint32, 70 | # "result": numba.types.float32, # This provides speed, but causes errors in corner cases 71 | }, 72 | fastmath=True, 73 | cache=True, 74 | ) 75 | def norm(vec): 76 | """Compute the (standard l2) norm of a vector. 77 | 78 | Parameters 79 | ---------- 80 | vec: array of shape (dim,) 81 | 82 | Returns 83 | ------- 84 | The l2 norm of vec. 85 | """ 86 | result = 0.0 87 | dim = vec.shape[0] 88 | for i in range(dim): 89 | result += vec[i] * vec[i] 90 | return np.sqrt(result) 91 | 92 | 93 | @numba.njit(cache=True) 94 | def rejection_sample(n_samples, pool_size, rng_state): 95 | """Generate n_samples many integers from 0 to pool_size such that no 96 | integer is selected twice. The duplication constraint is achieved via 97 | rejection sampling. 98 | 99 | Parameters 100 | ---------- 101 | n_samples: int 102 | The number of random samples to select from the pool 103 | 104 | pool_size: int 105 | The size of the total pool of candidates to sample from 106 | 107 | rng_state: array of int64, shape (3,) 108 | Internal state of the random number generator 109 | 110 | Returns 111 | ------- 112 | sample: array of shape(n_samples,) 113 | The ``n_samples`` randomly selected elements from the pool. 114 | """ 115 | result = np.empty(n_samples, dtype=np.int64) 116 | for i in range(n_samples): 117 | reject_sample = True 118 | j = 0 119 | while reject_sample: 120 | j = tau_rand_int(rng_state) % pool_size 121 | for k in range(i): 122 | if j == result[k]: 123 | break 124 | else: 125 | reject_sample = False 126 | result[i] = j 127 | return result 128 | 129 | 130 | @numba.njit(cache=True) 131 | def make_heap(n_points, size): 132 | """Constructor for the numba enabled heap objects. The heaps are used 133 | for approximate nearest neighbor search, maintaining a list of potential 134 | neighbors sorted by their distance. We also flag if potential neighbors 135 | are newly added to the list or not. 
Internally this is stored as 136 | a single ndarray; the first axis determines whether we are looking at the 137 | array of candidate graph_indices, the array of distances, or the flag array for 138 | whether elements are new or not. Each of these arrays are of shape 139 | (``n_points``, ``size``) 140 | 141 | Parameters 142 | ---------- 143 | n_points: int 144 | The number of graph_data points to track in the heap. 145 | 146 | size: int 147 | The number of items to keep on the heap for each graph_data point. 148 | 149 | Returns 150 | ------- 151 | heap: An ndarray suitable for passing to other numba enabled heap functions. 152 | """ 153 | indices = np.full((int(n_points), int(size)), -1, dtype=np.int32) 154 | distances = np.full((int(n_points), int(size)), np.inf, dtype=np.float32) 155 | flags = np.zeros((int(n_points), int(size)), dtype=np.uint8) 156 | result = (indices, distances, flags) 157 | 158 | return result 159 | 160 | 161 | # Sentinel value for empty/uninitialized graphs 162 | EMPTY_GRAPH = make_heap(1, 1) 163 | 164 | 165 | @numba.njit(cache=True) 166 | def siftdown(heap1, heap2, elt): 167 | """Restore the heap property for a heap with an out of place element 168 | at position ``elt``. This works with a heap pair where heap1 carries 169 | the weights and heap2 holds the corresponding elements.""" 170 | while elt * 2 + 1 < heap1.shape[0]: 171 | left_child = elt * 2 + 1 172 | right_child = left_child + 1 173 | swap = elt 174 | 175 | if heap1[swap] < heap1[left_child]: 176 | swap = left_child 177 | 178 | if right_child < heap1.shape[0] and heap1[swap] < heap1[right_child]: 179 | swap = right_child 180 | 181 | if swap == elt: 182 | break 183 | else: 184 | heap1[elt], heap1[swap] = heap1[swap], heap1[elt] 185 | heap2[elt], heap2[swap] = heap2[swap], heap2[elt] 186 | elt = swap 187 | 188 | 189 | @numba.njit(parallel=True, cache=False) 190 | def deheap_sort(indices, distances): 191 | """Given two arrays representing a heap (indices and distances), reorder the 192 | arrays by increasing distance. This is effectively just the second half of 193 | heap sort (the first half not being required since we already have the 194 | graph_data in a heap). 195 | 196 | Note that this is done in-place. 197 | 198 | Parameters 199 | ---------- 200 | indices : array of shape (n_samples, n_neighbors) 201 | The graph indices to sort by distance. 202 | distances : array of shape (n_samples, n_neighbors) 203 | The corresponding edge distance. 204 | 205 | Returns 206 | ------- 207 | indices, distances: arrays of shape (n_samples, n_neighbors) 208 | The indices and distances sorted by increasing distance. 209 | """ 210 | for i in numba.prange(indices.shape[0]): 211 | # starting from the end of the array and moving back 212 | for j in range(indices.shape[1] - 1, 0, -1): 213 | indices[i, 0], indices[i, j] = indices[i, j], indices[i, 0] 214 | distances[i, 0], distances[i, j] = distances[i, j], distances[i, 0] 215 | 216 | siftdown(distances[i, :j], indices[i, :j], 0) 217 | 218 | return indices, distances 219 | 220 | 221 | @numba.njit(parallel=True, locals={"idx": numba.types.int64}, cache=False) 222 | def new_build_candidates(current_graph, max_candidates, rng_state, n_threads): 223 | """Build a heap of candidate neighbors for nearest neighbor descent. For 224 | each vertex the candidate neighbors are any current neighbors, and any 225 | vertices that have the vertex as one of their nearest neighbors. 
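
    Neighbors whose "new" flag is set are pushed, with random priorities, onto
    a fixed-size heap of new candidates and the remainder onto a heap of old
    candidates, so each vertex keeps at most ``max_candidates`` of each. The
    work is split into contiguous blocks of vertices so that each thread only
    pushes onto candidate heaps for vertices in its own block, and the "new"
    flag of any neighbor that was selected as a new candidate is cleared
    afterwards.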
226 | 227 | Parameters 228 | ---------- 229 | current_graph: heap 230 | The current state of the graph for nearest neighbor descent. 231 | 232 | max_candidates: int 233 | The maximum number of new candidate neighbors. 234 | 235 | rng_state: array of int64, shape (3,) 236 | The internal state of the rng 237 | 238 | Returns 239 | ------- 240 | candidate_neighbors: A heap with an array of (randomly sorted) candidate 241 | neighbors for each vertex in the graph. 242 | """ 243 | current_indices = current_graph[0] 244 | current_flags = current_graph[2] 245 | 246 | n_vertices = current_indices.shape[0] 247 | n_neighbors = current_indices.shape[1] 248 | 249 | new_candidate_indices = np.full((n_vertices, max_candidates), -1, dtype=np.int32) 250 | new_candidate_priority = np.full( 251 | (n_vertices, max_candidates), np.inf, dtype=np.float32 252 | ) 253 | 254 | old_candidate_indices = np.full((n_vertices, max_candidates), -1, dtype=np.int32) 255 | old_candidate_priority = np.full( 256 | (n_vertices, max_candidates), np.inf, dtype=np.float32 257 | ) 258 | 259 | block_size = n_vertices // n_threads + 1 260 | 261 | for n in numba.prange(n_threads): 262 | local_rng_state = rng_state + n 263 | block_start = n * block_size 264 | block_end = min(block_start + block_size, n_vertices) 265 | 266 | for i in range(n_vertices): 267 | for j in range(n_neighbors): 268 | idx = current_indices[i, j] 269 | 270 | if idx >= 0 and ( 271 | (i >= block_start and i < block_end) 272 | or (idx >= block_start and idx < block_end) 273 | ): 274 | isn = current_flags[i, j] 275 | d = tau_rand(local_rng_state) 276 | 277 | if isn: 278 | if i >= block_start and i < block_end: 279 | checked_heap_push( 280 | new_candidate_priority[i], 281 | new_candidate_indices[i], 282 | d, 283 | idx, 284 | ) 285 | if idx >= block_start and idx < block_end: 286 | checked_heap_push( 287 | new_candidate_priority[idx], 288 | new_candidate_indices[idx], 289 | d, 290 | i, 291 | ) 292 | else: 293 | if i >= block_start and i < block_end: 294 | checked_heap_push( 295 | old_candidate_priority[i], 296 | old_candidate_indices[i], 297 | d, 298 | idx, 299 | ) 300 | if idx >= block_start and idx < block_end: 301 | checked_heap_push( 302 | old_candidate_priority[idx], 303 | old_candidate_indices[idx], 304 | d, 305 | i, 306 | ) 307 | 308 | indices = current_graph[0] 309 | flags = current_graph[2] 310 | 311 | for i in numba.prange(n_vertices): 312 | for j in range(n_neighbors): 313 | idx = indices[i, j] 314 | 315 | for k in range(max_candidates): 316 | if new_candidate_indices[i, k] == idx: 317 | flags[i, j] = 0 318 | break 319 | 320 | return new_candidate_indices, old_candidate_indices 321 | 322 | 323 | @numba.njit("b1(u1[::1],i4)", cache=True) 324 | def has_been_visited(table, candidate): 325 | loc = candidate >> 3 326 | mask = 1 << (candidate & 7) 327 | return table[loc] & mask 328 | 329 | 330 | @numba.njit("void(u1[::1],i4)", cache=True) 331 | def mark_visited(table, candidate): 332 | loc = candidate >> 3 333 | mask = 1 << (candidate & 7) 334 | table[loc] |= mask 335 | return 336 | 337 | 338 | @numba.njit("b1(u1[::1],i4)", cache=True) 339 | def check_and_mark_visited(table, candidate): 340 | """Check if candidate was visited and mark it as visited in one operation. 341 | 342 | Returns True if the candidate was already visited, False otherwise. 343 | More efficient than separate has_been_visited + mark_visited calls. 
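
    The table is a packed bitfield: ``candidate >> 3`` selects the byte and
    ``1 << (candidate & 7)`` the bit within it, so ``table`` must contain at
    least ``ceil(n_candidates / 8)`` zero-initialized bytes.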
344 | """ 345 | loc = candidate >> 3 346 | mask = numba.uint8(1 << (candidate & 7)) 347 | was_visited = table[loc] & mask 348 | table[loc] |= mask 349 | return was_visited 350 | 351 | 352 | @numba.njit( 353 | "i4(f4[::1],i4[::1],f4,i4)", 354 | fastmath=True, 355 | locals={ 356 | "size": numba.types.intp, 357 | "i": numba.types.uint16, 358 | "ic1": numba.types.uint16, 359 | "ic2": numba.types.uint16, 360 | "i_swap": numba.types.uint16, 361 | }, 362 | cache=True, 363 | ) 364 | def simple_heap_push(priorities, indices, p, n): 365 | if p >= priorities[0]: 366 | return 0 367 | 368 | size = priorities.shape[0] 369 | 370 | # insert val at position zero 371 | priorities[0] = p 372 | indices[0] = n 373 | 374 | # descend the heap, swapping values until the max heap criterion is met 375 | i = 0 376 | while True: 377 | ic1 = 2 * i + 1 378 | ic2 = ic1 + 1 379 | 380 | if ic1 >= size: 381 | break 382 | elif ic2 >= size: 383 | if priorities[ic1] > p: 384 | i_swap = ic1 385 | else: 386 | break 387 | elif priorities[ic1] >= priorities[ic2]: 388 | if p < priorities[ic1]: 389 | i_swap = ic1 390 | else: 391 | break 392 | else: 393 | if p < priorities[ic2]: 394 | i_swap = ic2 395 | else: 396 | break 397 | 398 | priorities[i] = priorities[i_swap] 399 | indices[i] = indices[i_swap] 400 | 401 | i = i_swap 402 | 403 | priorities[i] = p 404 | indices[i] = n 405 | 406 | return 1 407 | 408 | 409 | @numba.njit( 410 | "i4(f4[::1],i4[::1],f4,i4)", 411 | fastmath=True, 412 | locals={ 413 | "size": numba.types.intp, 414 | "i": numba.types.uint16, 415 | "ic1": numba.types.uint16, 416 | "ic2": numba.types.uint16, 417 | "i_swap": numba.types.uint16, 418 | }, 419 | cache=True, 420 | ) 421 | def checked_heap_push(priorities, indices, p, n): 422 | if p >= priorities[0]: 423 | return 0 424 | 425 | size = priorities.shape[0] 426 | 427 | # break if we already have this element. 428 | for i in range(size): 429 | if n == indices[i]: 430 | return 0 431 | 432 | # insert val at position zero 433 | priorities[0] = p 434 | indices[0] = n 435 | 436 | # descend the heap, swapping values until the max heap criterion is met 437 | i = 0 438 | while True: 439 | ic1 = 2 * i + 1 440 | ic2 = ic1 + 1 441 | 442 | if ic1 >= size: 443 | break 444 | elif ic2 >= size: 445 | if priorities[ic1] > p: 446 | i_swap = ic1 447 | else: 448 | break 449 | elif priorities[ic1] >= priorities[ic2]: 450 | if p < priorities[ic1]: 451 | i_swap = ic1 452 | else: 453 | break 454 | else: 455 | if p < priorities[ic2]: 456 | i_swap = ic2 457 | else: 458 | break 459 | 460 | priorities[i] = priorities[i_swap] 461 | indices[i] = indices[i_swap] 462 | 463 | i = i_swap 464 | 465 | priorities[i] = p 466 | indices[i] = n 467 | 468 | return 1 469 | 470 | 471 | @numba.njit( 472 | "i4(f4[::1],i4[::1],u1[::1],f4,i4,u1)", 473 | fastmath=True, 474 | locals={ 475 | "size": numba.types.intp, 476 | "i": numba.types.uint16, 477 | "ic1": numba.types.uint16, 478 | "ic2": numba.types.uint16, 479 | "i_swap": numba.types.uint16, 480 | }, 481 | cache=True, 482 | ) 483 | def checked_flagged_heap_push(priorities, indices, flags, p, n, f): 484 | if p >= priorities[0]: 485 | return 0 486 | 487 | size = priorities.shape[0] 488 | 489 | # break if we already have this element. 
490 | for i in range(size): 491 | if n == indices[i]: 492 | return 0 493 | 494 | # insert val at position zero 495 | priorities[0] = p 496 | indices[0] = n 497 | flags[0] = f 498 | 499 | # descend the heap, swapping values until the max heap criterion is met 500 | i = 0 501 | while True: 502 | ic1 = 2 * i + 1 503 | ic2 = ic1 + 1 504 | 505 | if ic1 >= size: 506 | break 507 | elif ic2 >= size: 508 | if priorities[ic1] > p: 509 | i_swap = ic1 510 | else: 511 | break 512 | elif priorities[ic1] >= priorities[ic2]: 513 | if p < priorities[ic1]: 514 | i_swap = ic1 515 | else: 516 | break 517 | else: 518 | if p < priorities[ic2]: 519 | i_swap = ic2 520 | else: 521 | break 522 | 523 | priorities[i] = priorities[i_swap] 524 | indices[i] = indices[i_swap] 525 | flags[i] = flags[i_swap] 526 | 527 | i = i_swap 528 | 529 | priorities[i] = p 530 | indices[i] = n 531 | flags[i] = f 532 | 533 | return 1 534 | 535 | 536 | @numba.njit( 537 | parallel=True, 538 | locals={ 539 | "dist_thresh_p": numba.float32, 540 | "dist_thresh_q": numba.float32, 541 | "p": numba.int32, 542 | "q": numba.int32, 543 | "d": numba.float32, 544 | "max_updates": numba.int32, 545 | "max_threshold": numba.float32, 546 | }, 547 | cache=False, 548 | fastmath=True, 549 | ) 550 | def generate_graph_update_array( 551 | update_array, 552 | n_updates_per_thread, 553 | new_candidate_block, 554 | old_candidate_block, 555 | dist_thresholds, 556 | data, 557 | dist, 558 | n_threads, 559 | ): 560 | """Generate graph updates into a pre-allocated array. 561 | 562 | This is more efficient than generating lists of tuples because: 563 | 1. No dynamic memory allocation during the parallel loop 564 | 2. Better cache locality with contiguous array storage 565 | 3. Each thread writes to its own section of the array 566 | 567 | Parameters 568 | ---------- 569 | update_array : ndarray of shape (n_threads, max_updates_per_thread, 3) 570 | Pre-allocated array to store updates. Each row stores (p, q, d). 571 | 572 | n_updates_per_thread : ndarray of shape (n_threads,) 573 | Output array to store the number of updates generated by each thread. 574 | 575 | new_candidate_block : ndarray of shape (block_size, max_candidates) 576 | New candidate indices for this block. 577 | 578 | old_candidate_block : ndarray of shape (block_size, max_candidates) 579 | Old candidate indices for this block. 580 | 581 | dist_thresholds : ndarray of shape (n_vertices,) 582 | Current distance thresholds (max heap distance) for each vertex. 583 | 584 | data : ndarray of shape (n_vertices, n_features) 585 | The data points. 586 | 587 | dist : callable 588 | Distance function. 589 | 590 | n_threads : int 591 | Number of threads to use. 
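
    Each thread writes its candidate updates as ``(p, q, d)`` triples into its
    own slice ``update_array[t]`` and records how many it produced in
    ``n_updates_per_thread[t]``; once that slice is full, any remaining
    candidate pairs are dropped for this call.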
592 | """ 593 | block_size = new_candidate_block.shape[0] 594 | max_new_candidates = new_candidate_block.shape[1] 595 | max_old_candidates = old_candidate_block.shape[1] 596 | rows_per_thread = (block_size // n_threads) + 1 597 | 598 | for t in numba.prange(n_threads): 599 | idx = 0 600 | max_updates = update_array.shape[1] 601 | 602 | for r in range(rows_per_thread): 603 | i = t * rows_per_thread + r 604 | if i >= block_size or idx >= max_updates: 605 | break 606 | 607 | for j in range(max_new_candidates): 608 | if idx >= max_updates: 609 | break 610 | 611 | p = new_candidate_block[i, j] 612 | if p < 0: 613 | continue 614 | 615 | data_p = data[p] 616 | dist_thresh_p = dist_thresholds[p] 617 | 618 | # Compare with other new candidates (start at j to match original behavior) 619 | for k in range(j, max_new_candidates): 620 | if idx >= max_updates: 621 | break 622 | 623 | q = new_candidate_block[i, k] 624 | if q < 0: 625 | continue 626 | 627 | d = dist(data_p, data[q]) 628 | 629 | # Use max for better branch prediction than OR condition 630 | dist_thresh_q = dist_thresholds[q] 631 | max_threshold = max(dist_thresh_p, dist_thresh_q) 632 | 633 | if d <= max_threshold: 634 | update_array[t, idx, 0] = p 635 | update_array[t, idx, 1] = q 636 | update_array[t, idx, 2] = d 637 | idx += 1 638 | 639 | # Compare with old candidates 640 | for k in range(max_old_candidates): 641 | if idx >= max_updates: 642 | break 643 | 644 | q = old_candidate_block[i, k] 645 | if q < 0: 646 | continue 647 | 648 | d = dist(data_p, data[q]) 649 | dist_thresh_q = dist_thresholds[q] 650 | max_threshold = max(dist_thresh_p, dist_thresh_q) 651 | 652 | if d <= max_threshold: 653 | update_array[t, idx, 0] = p 654 | update_array[t, idx, 1] = q 655 | update_array[t, idx, 2] = d 656 | idx += 1 657 | 658 | n_updates_per_thread[t] = idx 659 | 660 | 661 | @numba.njit( 662 | parallel=True, 663 | cache=True, 664 | locals={ 665 | "p": numba.int32, 666 | "q": numba.int32, 667 | "d": numba.float32, 668 | "added": numba.uint8, 669 | "n": numba.uint32, 670 | "t": numba.uint32, 671 | "j": numba.uint32, 672 | }, 673 | ) 674 | def apply_graph_update_array( 675 | current_graph, update_array, n_updates_per_thread, n_threads 676 | ): 677 | """Apply graph updates from a pre-allocated array. 678 | 679 | Uses block-based processing where each thread only updates vertices 680 | in its assigned block, avoiding the need for duplicate checking. 681 | 682 | Parameters 683 | ---------- 684 | current_graph : tuple of (indices, distances, flags) 685 | The current nearest neighbor graph heap. 686 | 687 | update_array : ndarray of shape (n_threads, max_updates_per_thread, 3) 688 | Array of updates where each row is (p, q, d). 689 | 690 | n_updates_per_thread : ndarray of shape (n_threads,) 691 | Number of valid updates from each generating thread. 692 | 693 | n_threads : int 694 | Number of threads. 695 | 696 | Returns 697 | ------- 698 | n_changes : int 699 | Total number of updates that modified the graph. 
700 | """ 701 | n_changes = 0 702 | priorities = current_graph[1] 703 | indices = current_graph[0] 704 | flags = current_graph[2] 705 | 706 | n_vertices = priorities.shape[0] 707 | vertex_block_size = n_vertices // n_threads + 1 708 | 709 | for n in numba.prange(n_threads): 710 | block_start = n * vertex_block_size 711 | block_end = min(block_start + vertex_block_size, n_vertices) 712 | 713 | # Each thread scans all updates but only applies those 714 | # where p or q falls in its block 715 | for t in range(n_threads): 716 | for j in range(n_updates_per_thread[t]): 717 | p = np.int32(update_array[t, j, 0]) 718 | q = np.int32(update_array[t, j, 1]) 719 | d = np.float32(update_array[t, j, 2]) 720 | 721 | if p >= block_start and p < block_end: 722 | added = checked_flagged_heap_push( 723 | priorities[p], indices[p], flags[p], d, q, 1 724 | ) 725 | n_changes += added 726 | 727 | if q >= block_start and q < block_end: 728 | added = checked_flagged_heap_push( 729 | priorities[q], indices[q], flags[q], d, p, 1 730 | ) 731 | n_changes += added 732 | 733 | return n_changes 734 | 735 | 736 | @numba.njit( 737 | parallel=True, 738 | cache=False, 739 | fastmath=True, 740 | locals={ 741 | "dist_thresh_p": numba.float32, 742 | "dist_thresh_q": numba.float32, 743 | "p": numba.int32, 744 | "q": numba.int32, 745 | "d": numba.float32, 746 | "max_updates": numba.int32, 747 | "max_threshold": numba.float32, 748 | }, 749 | ) 750 | def sparse_generate_graph_update_array( 751 | update_array, 752 | n_updates_per_thread, 753 | new_candidate_block, 754 | old_candidate_block, 755 | dist_thresholds, 756 | inds, 757 | indptr, 758 | data, 759 | dist, 760 | n_threads, 761 | ): 762 | """Generate graph updates for sparse data into a pre-allocated array.""" 763 | block_size = new_candidate_block.shape[0] 764 | max_new_candidates = new_candidate_block.shape[1] 765 | max_old_candidates = old_candidate_block.shape[1] 766 | rows_per_thread = (block_size // n_threads) + 1 767 | 768 | for t in numba.prange(n_threads): 769 | idx = 0 770 | max_updates = update_array.shape[1] 771 | 772 | for r in range(rows_per_thread): 773 | i = t * rows_per_thread + r 774 | if i >= block_size or idx >= max_updates: 775 | break 776 | 777 | for j in range(max_new_candidates): 778 | if idx >= max_updates: 779 | break 780 | 781 | p = new_candidate_block[i, j] 782 | if p < 0: 783 | continue 784 | 785 | from_inds = inds[indptr[p] : indptr[p + 1]] 786 | from_data = data[indptr[p] : indptr[p + 1]] 787 | dist_thresh_p = dist_thresholds[p] 788 | 789 | # Compare with other new candidates (start at j to match original) 790 | for k in range(j, max_new_candidates): 791 | if idx >= max_updates: 792 | break 793 | 794 | q = new_candidate_block[i, k] 795 | if q < 0: 796 | continue 797 | 798 | to_inds = inds[indptr[q] : indptr[q + 1]] 799 | to_data = data[indptr[q] : indptr[q + 1]] 800 | d = dist(from_inds, from_data, to_inds, to_data) 801 | 802 | dist_thresh_q = dist_thresholds[q] 803 | max_threshold = max(dist_thresh_p, dist_thresh_q) 804 | 805 | if d <= max_threshold: 806 | update_array[t, idx, 0] = p 807 | update_array[t, idx, 1] = q 808 | update_array[t, idx, 2] = d 809 | idx += 1 810 | 811 | # Compare with old candidates 812 | for k in range(max_old_candidates): 813 | if idx >= max_updates: 814 | break 815 | 816 | q = old_candidate_block[i, k] 817 | if q < 0: 818 | continue 819 | 820 | to_inds = inds[indptr[q] : indptr[q + 1]] 821 | to_data = data[indptr[q] : indptr[q + 1]] 822 | d = dist(from_inds, from_data, to_inds, to_data) 823 | 824 | dist_thresh_q = 
dist_thresholds[q] 825 | max_threshold = max(dist_thresh_p, dist_thresh_q) 826 | 827 | if d <= max_threshold: 828 | update_array[t, idx, 0] = p 829 | update_array[t, idx, 1] = q 830 | update_array[t, idx, 2] = d 831 | idx += 1 832 | 833 | n_updates_per_thread[t] = idx 834 | 835 | 836 | @numba.njit(cache=False) 837 | def initalize_heap_from_graph_indices(heap, graph_indices, data, metric): 838 | 839 | for i in range(graph_indices.shape[0]): 840 | for idx in range(graph_indices.shape[1]): 841 | j = graph_indices[i, idx] 842 | if j >= 0: 843 | d = metric(data[i], data[j]) 844 | checked_flagged_heap_push(heap[1][i], heap[0][i], heap[2][i], d, j, 1) 845 | 846 | return heap 847 | 848 | 849 | @numba.njit(cache=True) 850 | def initalize_heap_from_graph_indices_and_distances( 851 | heap, graph_indices, graph_distances 852 | ): 853 | for i in range(graph_indices.shape[0]): 854 | for idx in range(graph_indices.shape[1]): 855 | j = graph_indices[i, idx] 856 | if j >= 0: 857 | d = graph_distances[i, idx] 858 | checked_flagged_heap_push(heap[1][i], heap[0][i], heap[2][i], d, j, 1) 859 | 860 | return heap 861 | 862 | 863 | @numba.njit(parallel=True, cache=False) 864 | def sparse_initalize_heap_from_graph_indices( 865 | heap, graph_indices, data_indptr, data_indices, data_vals, metric 866 | ): 867 | 868 | for i in numba.prange(graph_indices.shape[0]): 869 | for idx in range(graph_indices.shape[1]): 870 | j = graph_indices[i, idx] 871 | ind1 = data_indices[data_indptr[i] : data_indptr[i + 1]] 872 | data1 = data_vals[data_indptr[i] : data_indptr[i + 1]] 873 | ind2 = data_indices[data_indptr[j] : data_indptr[j + 1]] 874 | data2 = data_vals[data_indptr[j] : data_indptr[j + 1]] 875 | d = metric(ind1, data1, ind2, data2) 876 | checked_flagged_heap_push(heap[1][i], heap[0][i], heap[2][i], d, j, 1) 877 | 878 | return heap 879 | 880 | 881 | # Generates a timestamp for use in logging messages when verbose=True 882 | def ts(): 883 | return time.ctime(time.time()) 884 | --------------------------------------------------------------------------------
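The heap primitives above (``make_heap``, the ``*_heap_push`` variants and ``deheap_sort``) are the low-level interface the rest of the library builds on. The following is a minimal illustrative sketch, not part of the repository, of how they fit together; it mirrors the calls made in ``initalize_heap_from_graph_indices``, with made-up distances and indices.

import numpy as np
from pynndescent.utils import make_heap, checked_flagged_heap_push, deheap_sort

# A heap tracking the 3 best neighbors of each of 5 points: indices start at
# -1, distances at +inf, and the "new" flags at 0.
indices, distances, flags = make_heap(5, 3)

# Push two candidate neighbors for point 0. Note the argument order of
# checked_flagged_heap_push: (priorities, indices, flags, priority, index, flag).
checked_flagged_heap_push(distances[0], indices[0], flags[0], np.float32(0.50), np.int32(2), np.uint8(1))
checked_flagged_heap_push(distances[0], indices[0], flags[0], np.float32(0.25), np.int32(4), np.uint8(1))

# Reorder every row by increasing distance (done in place).
indices, distances = deheap_sort(indices, distances)
print(indices[0])    # e.g. [ 4  2 -1] -- unfilled slots remain -1
print(distances[0])  # e.g. [0.25 0.5   inf]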