├── src └── sklearn_ann │ ├── __init__.py │ ├── cluster │ ├── __init__.py │ └── rnn_dbscan.py │ ├── kneighbors │ ├── __init__.py │ ├── sklearn.py │ ├── pynndescent.py │ ├── nmslib.py │ ├── annoy.py │ └── faiss.py │ ├── test_utils.py │ └── utils.py ├── .gitignore ├── .editorconfig ├── tests ├── test_cluster.py └── test_kneighbors │ ├── conftest.py │ ├── test_faiss.py │ ├── test_nmslib.py │ ├── test_annoy.py │ └── test_common.py ├── .readthedocs.yaml ├── .github └── workflows │ ├── publish.yml │ └── pytest.yml ├── docs ├── clustering.rst ├── index.rst ├── Makefile ├── _templates │ └── autosummary │ │ └── class.rst ├── background.rst ├── conf.py └── kneighbors.rst ├── .pre-commit-config.yaml ├── LICENSE ├── examples ├── rnn_dbscan_big.py └── rnn_dbscan_simple.py ├── README.rst └── pyproject.toml /src/sklearn_ann/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/sklearn_ann/cluster/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/sklearn_ann/kneighbors/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Caches 2 | __pycache__/ 3 | /*cache/ 4 | /node_modules/ 5 | 6 | # Build artifacts 7 | /dist/ 8 | /docs/_build/ 9 | 10 | # IDEs and dev files 11 | activate.sh 12 | .vscode/ 13 | -------------------------------------------------------------------------------- /.editorconfig: -------------------------------------------------------------------------------- 1 | # This file configures editors, IDEs, and prettier 2 | 3 | root = true 4 | 5 | [*] 6 | charset = utf-8 7 | end_of_line = lf 8 | insert_final_newline = true 9 | trim_trailing_whitespace = true 10 | max_line_length = 88 11 | indent_size = 4 12 | indent_style = space 13 | 14 | [*.{yml,yaml}] 15 | indent_size = 2 16 | -------------------------------------------------------------------------------- /tests/test_cluster.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from sklearn.utils.estimator_checks import check_estimator 3 | 4 | from sklearn_ann.cluster.rnn_dbscan import RnnDBSCAN 5 | 6 | ESTIMATORS = [RnnDBSCAN] 7 | 8 | 9 | @pytest.mark.parametrize("Estimator", ESTIMATORS) 10 | def test_all_estimators(Estimator): 11 | check_estimator(Estimator()) 12 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | build: 4 | os: ubuntu-22.04 5 | tools: 6 | python: "3.11" 7 | 8 | python: 9 | install: 10 | - method: pip 11 | path: . 
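# NOTE: "docs" and "annlibs" below are optional-dependency groups defined in pyproject.toml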
12 | extra_requirements: 13 | - docs 14 | - annlibs 15 | 16 | sphinx: 17 | configuration: docs/conf.py 18 | # TODO: uncomment once docs build without autosummary warnings 19 | # fail_on_warning: true 20 | -------------------------------------------------------------------------------- /src/sklearn_ann/kneighbors/sklearn.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | from sklearn.neighbors import KNeighborsTransformer 4 | 5 | BallTreeTransformer = partial(KNeighborsTransformer, algorithm="ball_tree") 6 | KDTreeTransformer = partial(KNeighborsTransformer, algorithm="kd_tree") 7 | BruteTransformer = partial(KNeighborsTransformer, algorithm="brute") 8 | 9 | 10 | __all__ = ["BallTreeTransformer", "BruteTransformer", "KDTreeTransformer"] 11 | -------------------------------------------------------------------------------- /src/sklearn_ann/kneighbors/pynndescent.py: -------------------------------------------------------------------------------- 1 | from pynndescent import PyNNDescentTransformer as PyNNDescentTransformerBase 2 | 3 | 4 | def no_op(): 5 | pass 6 | 7 | 8 | class PyNNDescentTransformer(PyNNDescentTransformerBase): 9 | def fit(self, X, compress_index=True): 10 | super().fit(X, compress_index=compress_index) 11 | self.index_.compress_index = no_op 12 | return self 13 | 14 | 15 | __all__ = ["PyNNDescentTransformer"] 16 | -------------------------------------------------------------------------------- /tests/test_kneighbors/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from numpy.random import default_rng 3 | from scipy.spatial.distance import pdist, squareform 4 | 5 | 6 | @pytest.fixture(scope="module") 7 | def random_small(): 8 | gen = default_rng(42) 9 | return 2 * gen.random((64, 128)) - 1 10 | 11 | 12 | @pytest.fixture(scope="module") 13 | def random_small_pdists(random_small): 14 | return { 15 | metric: squareform(pdist(random_small, metric=metric)) 16 | for metric in ["euclidean", "cosine"] 17 | } 18 | -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish Python Package 2 | 3 | on: 4 | release: 5 | types: [published] 6 | 7 | jobs: 8 | publish: 9 | runs-on: ubuntu-latest 10 | environment: pypi 11 | permissions: 12 | id-token: write # to authenticate as Trusted Publisher to pypi.org 13 | steps: 14 | - uses: actions/checkout@v5 15 | - uses: actions/setup-python@v6 16 | with: 17 | python-version: "3.x" 18 | cache: pip 19 | - run: pip install build 20 | - run: python -m build 21 | - uses: pypa/gh-action-pypi-publish@release/v1 22 | -------------------------------------------------------------------------------- /docs/clustering.rst: -------------------------------------------------------------------------------- 1 | Clustering 2 | ========== 3 | 4 | While it is possible to use the transformers of the sklearn_ann.kneighbors module together with clustering algorithms from scikit-learn directly, there is often a mismatch between techniques like DBSCAN, which require for each node its neighbors within a certain radius, and a kNN-graph, which has a fixed number of neighbors per node. This mismatch may result in k being set too high, to make sure all the neighbors within the radius are captured, which slows things down. 5 | 6 | This module contains an implementation of RNN-DBSCAN, which is based on the kNN-graph structure. 7 | 8 | ..
automodule:: sklearn_ann.cluster.rnn_dbscan 9 | :members: 10 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. sklearn-ann documentation master file, created by 2 | sphinx-quickstart on Mon Jan 4 10:40:28 2021. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | sklearn-ann 7 | =========== 8 | 9 | .. toctree:: 10 | :maxdepth: 2 11 | :caption: Contents: 12 | 13 | .. include:: ../README.rst 14 | :start-after: inclusion-marker-do-not-remove 15 | 16 | 17 | User Guide 18 | --------------------- 19 | 20 | .. toctree:: 21 | :maxdepth: 2 22 | 23 | background 24 | kneighbors 25 | clustering 26 | 27 | 28 | Indices and tables 29 | ================== 30 | 31 | * :ref:`genindex` 32 | * :ref:`modindex` 33 | * :ref:`search` 34 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= hatch run sphinx-build 8 | SOURCEDIR = $(CURDIR) 9 | BUILDDIR = $(CURDIR)/_build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v6.0.0 4 | hooks: 5 | - id: trailing-whitespace 6 | - id: end-of-file-fixer 7 | - id: check-added-large-files 8 | - id: check-case-conflict 9 | - id: check-toml 10 | - id: check-yaml 11 | - id: check-merge-conflict 12 | - id: detect-private-key 13 | - id: no-commit-to-branch 14 | args: ["--branch=main"] 15 | - repo: https://github.com/astral-sh/ruff-pre-commit 16 | rev: v0.14.9 17 | hooks: 18 | - id: ruff 19 | args: ["--fix"] 20 | - id: ruff-format 21 | - repo: https://github.com/pre-commit/mirrors-prettier 22 | rev: v4.0.0-alpha.8 23 | hooks: 24 | - id: prettier 25 | -------------------------------------------------------------------------------- /tests/test_kneighbors/test_faiss.py: -------------------------------------------------------------------------------- 1 | from sklearn_ann.test_utils import assert_row_close, needs 2 | 3 | try: 4 | from sklearn_ann.kneighbors.faiss import FAISSTransformer 5 | except ImportError: 6 | pass 7 | 8 | 9 | @needs.faiss 10 | def test_euclidean(random_small, random_small_pdists): 11 | trans = FAISSTransformer(metric="euclidean") 12 | mat = trans.fit_transform(random_small) 13 | euclidean_dist = random_small_pdists["euclidean"] 14 | assert_row_close(mat, euclidean_dist) 15 | 16 | 17 | @needs.faiss 18 | def test_cosine(random_small, random_small_pdists): 19 | trans = FAISSTransformer(metric="cosine") 20 | mat = trans.fit_transform(random_small) 21 | cosine_dist = 
random_small_pdists["cosine"] 22 | assert_row_close(mat, cosine_dist) 23 | -------------------------------------------------------------------------------- /tests/test_kneighbors/test_nmslib.py: -------------------------------------------------------------------------------- 1 | from sklearn_ann.test_utils import assert_row_close, needs 2 | 3 | try: 4 | from sklearn_ann.kneighbors.nmslib import NMSlibTransformer 5 | except ImportError: 6 | pass 7 | 8 | 9 | @needs.nmslib 10 | def test_euclidean(random_small, random_small_pdists): 11 | trans = NMSlibTransformer(metric="euclidean") 12 | mat = trans.fit_transform(random_small) 13 | euclidean_dist = random_small_pdists["euclidean"] 14 | assert_row_close(mat, euclidean_dist) 15 | 16 | 17 | @needs.nmslib 18 | def test_cosine(random_small, random_small_pdists): 19 | trans = NMSlibTransformer(metric="cosine") 20 | mat = trans.fit_transform(random_small) 21 | cosine_dist = random_small_pdists["cosine"] 22 | assert_row_close(mat, cosine_dist) 23 | -------------------------------------------------------------------------------- /docs/_templates/autosummary/class.rst: -------------------------------------------------------------------------------- 1 | {{ fullname | escape | underline}} 2 | 3 | .. currentmodule:: {{ module }} 4 | 5 | .. add toctree option to make autodoc generate the pages 6 | 7 | .. autoclass:: {{ objname }} 8 | 9 | {% block attributes %} 10 | {% if attributes %} 11 | .. rubric:: Attributes 12 | 13 | .. autosummary:: 14 | :toctree: . 15 | {% for item in attributes %} 16 | ~{{ fullname }}.{{ item }} 17 | {%- endfor %} 18 | {% endif %} 19 | {% endblock %} 20 | 21 | {% block methods %} 22 | {% if methods %} 23 | .. rubric:: Methods 24 | 25 | .. autosummary:: 26 | :toctree: . 27 | {% for item in methods %} 28 | {%- if item != '__init__' %} 29 | ~{{ fullname }}.{{ item }} 30 | {%- endif -%} 31 | {%- endfor %} 32 | {% endif %} 33 | {% endblock %} 34 | -------------------------------------------------------------------------------- /tests/test_kneighbors/test_annoy.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | 4 | from sklearn_ann.test_utils import assert_row_close, needs 5 | 6 | try: 7 | from sklearn_ann.kneighbors.annoy import AnnoyTransformer 8 | except ImportError: 9 | pass 10 | 11 | 12 | @needs.annoy 13 | def test_euclidean(random_small, random_small_pdists): 14 | trans = AnnoyTransformer(metric="euclidean") 15 | mat = trans.fit_transform(random_small) 16 | euclidean_dist = random_small_pdists["euclidean"] 17 | assert_row_close(mat, euclidean_dist) 18 | 19 | 20 | @needs.annoy 21 | @pytest.mark.xfail(reason="not sure why this isn't working") 22 | def test_angular(random_small, random_small_pdists): 23 | trans = AnnoyTransformer(metric="angular") 24 | mat = trans.fit_transform(random_small) 25 | angular_dist = np.arccos(1 - random_small_pdists["cosine"]) 26 | assert_row_close(mat, angular_dist) 27 | -------------------------------------------------------------------------------- /src/sklearn_ann/test_utils.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | from importlib.util import find_spec 3 | 4 | 5 | def assert_row_close(sp_mat, actual_pdist, row=42, thresh=0.01): 6 | row_mat = sp_mat.getrow(row) 7 | for col, val in zip(row_mat.indices, row_mat.data): 8 | assert abs(actual_pdist[row, col] - val) < thresh 9 | 10 | 11 | class needs(Enum): 12 | """ 13 | A pytest mark generator for skipping 
tests if a package is not installed. 14 | 15 | Can be used as a decorator: 16 | 17 | >>> @needs.faiss 18 | >>> def test_x(): pass 19 | 20 | or be called to create a mark object: 21 | 22 | >>> pytest.param(..., marks=[needs.annoy()]) 23 | """ 24 | 25 | annoy = ("annoy",) 26 | faiss = ("faiss-cpu", "faiss-gpu") 27 | nmslib = ("nmslib",) 28 | pynndescent = ("pynndescent",) 29 | 30 | def __call__(self, fn=None): 31 | import pytest 32 | 33 | what = ( 34 | f"package {self.value[0]}" 35 | if len(self.value) == 1 36 | else f"one of the packages {set(self.value)}" 37 | ) 38 | mark = pytest.mark.skipif( 39 | not find_spec(self.name), 40 | reason=f"`import {self.name}` needs {what} installed.", 41 | ) 42 | return mark if fn is None else mark(fn) 43 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2020, scikit-ann contributors 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | * Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | * Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | * Neither the name of project-template nor the names of its 15 | contributors may be used to endorse or promote products derived from 16 | this software without specific prior written permission. 17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 22 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 23 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 24 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 25 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 26 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
28 | -------------------------------------------------------------------------------- /.github/workflows/pytest.yml: -------------------------------------------------------------------------------- 1 | name: Run tests 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | workflow_dispatch: 8 | 9 | jobs: 10 | test: 11 | strategy: 12 | matrix: 13 | python-version: ["3.9", "3.13"] 14 | runs-on: ubuntu-latest 15 | timeout-minutes: 60 16 | steps: 17 | - name: Checkout 18 | uses: actions/checkout@v5 19 | with: 20 | filter: blob:none 21 | fetch-depth: 0 22 | 23 | - name: Setup Python 24 | uses: actions/setup-python@v6 25 | with: 26 | python-version: ${{ matrix.python-version }} 27 | cache: pip 28 | 29 | - name: Install project and dependencies 30 | run: pip install .[annlibs,tests] 31 | 32 | - name: Run test suite 33 | run: pytest -v --color=yes 34 | 35 | build: 36 | runs-on: ubuntu-latest 37 | steps: 38 | - uses: actions/checkout@v5 39 | with: 40 | filter: blob:none 41 | fetch-depth: 0 42 | - uses: actions/setup-python@v6 43 | with: 44 | python-version: "3.x" 45 | cache: pip 46 | - name: Install tools 47 | run: pip install twine build 48 | - name: Build and check 49 | run: | 50 | python -m build 51 | twine check dist/*.whl 52 | 53 | check: 54 | if: always() 55 | needs: 56 | - build 57 | - test 58 | runs-on: ubuntu-latest 59 | steps: 60 | # https://github.com/marketplace/actions/alls-green#why 61 | - uses: re-actors/alls-green@release/v1 62 | with: 63 | jobs: ${{ toJSON(needs) }} 64 | -------------------------------------------------------------------------------- /docs/background.rst: -------------------------------------------------------------------------------- 1 | Background and API design 2 | ========================= 3 | 4 | There have been long-standing efficiency issues with scikit-learn's nearest 5 | neighbour search. In particular, the `ball tree`_ and `k-d tree`_ do not scale well to high 6 | dimensional spaces. The decision was taken that the best way to integrate other 7 | techniques was to allow all applicable unsupervised estimator methods to take 8 | a sparse matrix, typically being a KNN-graph of the points, but potentially 9 | being any estimate. These `slides from PyParis 2018`_ explain some background, 10 | while `issue #10463`_ and `pull request #10482`_ give discussion, justification, 11 | benchmarks and more detail regarding the approach. 12 | 13 | The main advantage of this technique is that the sparse matrix/KNN-graph can be built by a transformer from the data, and the two can be sequenced using the scikit-learn pipeline mechanism. This allows, for example, parameter search to be done on the KNN-graph construction technique together with the estimator. Typically the transformer should closely follow the interface of KNeighborsTransformer. The `exact contract is outlined in the user guide`_. There is also `an example notebook with early versions of the transformers in this library`_. 14 | 15 | .. _`ball tree`: https://en.wikipedia.org/wiki/Ball_tree 16 | .. _`k-d tree`: https://en.wikipedia.org/wiki/K-d_tree 17 | .. _`slides from PyParis 2018`: https://tomdlt.github.io/decks/2018_pyparis/ 18 | .. _`issue #10463`: https://github.com/scikit-learn/scikit-learn/issues/10463 19 | .. _`pull request #10482`: https://github.com/scikit-learn/scikit-learn/pull/10482 20 | .. _`exact contract is outlined in the user guide`: https://scikit-learn.org/stable/modules/neighbors.html#neighbors-transformer 21 | ..
_`an example notebook with early versions of the transformers in this library`: https://scikit-learn.org/stable/auto_examples/neighbors/approximate_nearest_neighbors.html#sphx-glr-auto-examples-neighbors-approximate-nearest-neighbors-py 22 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | import os 8 | 9 | # -- Path setup -------------------------------------------------------------- 10 | 11 | # If extensions (or modules to document with autodoc) are in another directory, 12 | # add these directories to sys.path here. If the directory is relative to the 13 | # documentation root, use os.path.abspath to make it absolute, like shown here. 14 | # 15 | # import sys 16 | # sys.path.insert(0, os.path.abspath('.')) 17 | 18 | # -- Project information ----------------------------------------------------- 19 | 20 | project = "sklearn-ann" 21 | copyright = "2021, Frankie Robertson" 22 | author = "Frankie Robertson" 23 | 24 | 25 | # -- General configuration --------------------------------------------------- 26 | 27 | # Add any Sphinx extension module names here, as strings. They can be 28 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 29 | # ones. 30 | extensions = [ 31 | "sphinx.ext.autodoc", 32 | "sphinx.ext.autosummary", 33 | "sphinx.ext.napoleon", 34 | "scanpydoc.definition_list_typed_field", 35 | "scanpydoc.rtd_github_links", 36 | "sphinx_issues", 37 | "sphinx.ext.linkcode", 38 | ] 39 | 40 | # Add any paths that contain templates here, relative to this directory. 41 | templates_path = ["_templates"] 42 | 43 | # List of patterns, relative to source directory, that match files and 44 | # directories to ignore when looking for source files. 45 | # This pattern also affects html_static_path and html_extra_path. 46 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] 47 | 48 | autosummary_generate = True 49 | autodoc_default_options = { 50 | "undoc-members": True, 51 | } 52 | 53 | 54 | # -- Options for HTML output ------------------------------------------------- 55 | 56 | # The theme to use for HTML and HTML Help pages. See the documentation for 57 | # a list of builtin themes. 58 | # 59 | html_theme = "sphinx_book_theme" 60 | html_theme_options = dict( 61 | repository_url="https://github.com/frankier/sklearn-ann", 62 | repository_branch=os.environ.get("READTHEDOCS_GIT_IDENTIFIER", "main"), 63 | ) 64 | rtd_links_prefix = "src" 65 | 66 | autodoc_mock_imports = ["annoy", "faiss", "pynndescent", "nmslib"] 67 | -------------------------------------------------------------------------------- /examples/rnn_dbscan_big.py: -------------------------------------------------------------------------------- 1 | """ 2 | ======================================================= 3 | Demo of RnnDBSCAN clustering algorithm on large dataset 4 | ======================================================= 5 | 6 | Tests RnnDBSCAN on a large dataset. Requires pandas. 
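When run as a script it drops into an interactive prompt from which a transformer class can be imported and passed to run_rnn_dbscan(transformer_cls, n_neighbors).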
7 | 8 | """ 9 | 10 | import numpy as np 11 | from joblib import Memory 12 | from sklearn import metrics 13 | from sklearn.datasets import fetch_openml 14 | 15 | from sklearn_ann.cluster.rnn_dbscan import simple_rnn_dbscan_pipeline 16 | 17 | 18 | # ############################################################################# 19 | # Generate sample data 20 | def fetch_mnist(): 21 | print("Downloading mnist_784") 22 | mnist = fetch_openml("mnist_784") 23 | return mnist.data / 255, mnist.target 24 | 25 | 26 | memory = Memory("./mnist") 27 | 28 | X, y = memory.cache(fetch_mnist)() 29 | 30 | 31 | def run_rnn_dbscan(neighbor_transformer, n_neighbors, **kwargs): 32 | # ############################################################################# 33 | # Compute RnnDBSCAN 34 | 35 | pipeline = simple_rnn_dbscan_pipeline(neighbor_transformer, n_neighbors, **kwargs) 36 | labels = pipeline.fit_predict(X) 37 | db = pipeline.named_steps["rnndbscan"] 38 | core_samples_mask = np.zeros_like(db.labels_, dtype=bool) 39 | core_samples_mask[db.core_sample_indices_] = True 40 | 41 | # Number of clusters in labels, ignoring noise if present. 42 | n_clusters_ = len(set(labels)) - (1 if -1 in labels else 0) 43 | n_noise_ = list(labels).count(-1) 44 | 45 | print(f"""\ 46 | Estimated number of clusters: {n_clusters_} 47 | Estimated number of noise points: {n_noise_} 48 | Homogeneity: {metrics.homogeneity_score(y, labels):0.3f} 49 | Completeness: {metrics.completeness_score(y, labels):0.3f} 50 | V-measure: {metrics.v_measure_score(y, labels):0.3f} 51 | Adjusted Rand Index: {metrics.adjusted_rand_score(y, labels):0.3f} 52 | Adjusted Mutual Information: {metrics.adjusted_mutual_info_score(y, labels):0.3f} 53 | Silhouette Coefficient: {metrics.silhouette_score(X, labels):0.3f}\ 54 | """) 55 | 56 | 57 | if __name__ == "__main__": 58 | import code 59 | 60 | print("""\ 61 | Now you can import your chosen transformer_cls and run: 62 | run_rnn_dbscan(transformer_cls, n_neighbors, **params) 63 | e.g. 64 | from sklearn_ann.kneighbors.pynndescent import PyNNDescentTransformer 65 | run_rnn_dbscan(PyNNDescentTransformer, 10)\ 66 | """) 67 | code.interact(local=locals()) 68 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | .. -*- mode: rst -*- 2 | 3 | |PyPI|_ |ReadTheDocs|_ 4 | 5 | .. |PyPI| image:: https://img.shields.io/pypi/v/sklearn-ann 6 | .. _PyPI: https://pypi.org/project/sklearn-ann/ 7 | 8 | .. |ReadTheDocs| image:: https://readthedocs.org/projects/sklearn-ann/badge/?version=latest 9 | .. _ReadTheDocs: https://sklearn-ann.readthedocs.io/en/latest/?badge=latest 10 | 11 | sklearn-ann 12 | =========== 13 | 14 | .. inclusion-marker-do-not-remove 15 | 16 | **sklearn-ann** eases integration of approximate nearest neighbours 17 | libraries such as annoy, nmslib and faiss into your sklearn 18 | pipelines. It consists of: 19 | 20 | * ``Transformers`` conforming to the same interface as 21 | ``KNeighborsTransformer`` which can be used to transform feature matrices 22 | into sparse distance matrices for use by any estimator that can deal with 23 | sparse distance matrices. Many, but not all, of scikit-learn's clustering and 24 | manifold learning algorithms can work with this kind of input. 25 | * RNN-DBSCAN: a variant of DBSCAN based on reverse nearest 26 | neighbours. 27 | 28 | Installation 29 | ============ 30 | 31 | To install the latest release from PyPI, run: 32 | 33 | .. 
code-block:: bash 34 | 35 | pip install sklearn-ann 36 | 37 | To install the latest development version from GitHub, run: 38 | 39 | .. code-block:: bash 40 | 41 | pip install git+https://github.com/scikit-learn-contrib/sklearn-ann.git#egg=sklearn-ann 42 | 43 | Why? When do I want this? 44 | ========================= 45 | 46 | The main scenario in which this is needed is for performing 47 | *clustering or manifold learning of high dimensional data*. The 48 | reason is that currently the only neighbourhood algorithms which are 49 | built into scikit-learn are essentially the standard tree approaches 50 | to space partitioning: the ball tree and the K-D tree. These do not 51 | perform competitively in high dimensional spaces. 52 | 53 | Development 54 | =========== 55 | 56 | This project is managed using Hatch_ and pre-commit_. To get started, run ``pre-commit 57 | install`` and ``hatch env create``. Run all commands using ``hatch run python 58 | `` which will ensure the environment is kept up to date. pre-commit_ comes into 59 | play on every ``git commit`` after installation. 60 | 61 | Consult ``pyproject.toml`` for which dependency groups and extras exist, 62 | and the Hatch help or user guide for more info on what they are. 63 | 64 | .. _Hatch: https://hatch.pypa.io/ 65 | .. _pre-commit: https://pre-commit.com/ 66 | -------------------------------------------------------------------------------- /src/sklearn_ann/kneighbors/nmslib.py: -------------------------------------------------------------------------------- 1 | import nmslib 2 | import numpy as np 3 | from scipy.sparse import csr_matrix 4 | from sklearn.base import BaseEstimator, TransformerMixin 5 | from sklearn.utils import Tags, TargetTags, TransformerTags 6 | from sklearn.utils.validation import validate_data 7 | 8 | from ..utils import TransformerChecksMixin, check_metric 9 | 10 | # see more metrics in the manual 11 | # https://github.com/nmslib/nmslib/tree/master/manual 12 | METRIC_MAP = { 13 | "sqeuclidean": "l2", 14 | "euclidean": "l2", 15 | "cosine": "cosinesimil", 16 | "l1": "l1", 17 | "l2": "l2", 18 | } 19 | 20 | 21 | class NMSlibTransformer(TransformerChecksMixin, TransformerMixin, BaseEstimator): 22 | """Wrapper for using nmslib as sklearn's KNeighborsTransformer""" 23 | 24 | def __init__( 25 | self, n_neighbors=5, *, metric="euclidean", method="sw-graph", n_jobs=1 26 | ): 27 | self.n_neighbors = n_neighbors 28 | self.method = method 29 | self.metric = metric 30 | self.n_jobs = n_jobs 31 | 32 | def fit(self, X, y=None): 33 | X = validate_data(self, X) 34 | self.n_samples_fit_ = X.shape[0] 35 | 36 | check_metric(self.metric, METRIC_MAP) 37 | space = METRIC_MAP[self.metric] 38 | 39 | self.nmslib_ = nmslib.init(method=self.method, space=space) 40 | self.nmslib_.addDataPointBatch(X) 41 | self.nmslib_.createIndex() 42 | return self 43 | 44 | def transform(self, X): 45 | X = self._transform_checks(X, "nmslib_") 46 | n_samples_transform = X.shape[0] 47 | 48 | # For compatibility reasons, as each sample is considered as its own 49 | # neighbor, one extra neighbor will be computed.
50 | n_neighbors = self.n_neighbors + 1 51 | 52 | results = self.nmslib_.knnQueryBatch(X, k=n_neighbors, num_threads=self.n_jobs) 53 | indices, distances = zip(*results) 54 | indices, distances = np.vstack(indices), np.vstack(distances) 55 | 56 | if self.metric == "sqeuclidean": 57 | distances **= 2 58 | 59 | indptr = np.arange(0, n_samples_transform * n_neighbors + 1, n_neighbors) 60 | kneighbors_graph = csr_matrix( 61 | (distances.ravel(), indices.ravel(), indptr), 62 | shape=(n_samples_transform, self.n_samples_fit_), 63 | ) 64 | 65 | return kneighbors_graph 66 | 67 | def __sklearn_tags__(self) -> Tags: 68 | return Tags( 69 | estimator_type="transformer", 70 | target_tags=TargetTags(required=False), 71 | transformer_tags=TransformerTags(preserves_dtype=[np.float32]), 72 | ) 73 | -------------------------------------------------------------------------------- /docs/kneighbors.rst: -------------------------------------------------------------------------------- 1 | Implementations of the KNeighborsTransformer interface 2 | ======================================================= 3 | 4 | This module contains transformers which transform from array-like structures of 5 | shape (n_samples, n_features) to KNN-graphs encoded as scipy.sparse.csr_matrix. 6 | They conform to the KNeighborsTransformer interface. Each submodule in this 7 | module provides facilities for exactly one external nearest neighbour library. 8 | 9 | Annoy 10 | ----- 11 | 12 | `Annoy (Approximate Nearest Neighbors Oh Yeah)`_ is a C++ library with Python 13 | bindings to search for points in space that are close to a given query point. The project originates from Spotify. 14 | It uses a forest of random projection trees. 15 | 16 | .. _`Annoy (Approximate Nearest Neighbors Oh Yeah)`: https://github.com/spotify/annoy 17 | 18 | 19 | .. automodule:: sklearn_ann.kneighbors.annoy 20 | :members: 21 | 22 | FAISS 23 | ----- 24 | 25 | `FAISS (Facebook AI Similarity Search)`_ is a library for efficient similarity 26 | search and clustering of dense vectors. The project originates from Facebook AI 27 | Research (FAIR). It contains multiple algorithms including algorithms for 28 | exact/brute force nearest neighbour, methods based on quantization and product 29 | quantization, and methods based on Hierarchical Navigable Small World graphs 30 | (HNSW). There are some `guidelines on how to choose the best index for your 31 | purposes`_. 32 | 33 | .. _`FAISS (Facebook AI Similarity Search)`: https://github.com/facebookresearch/faiss 34 | 35 | .. _`guidelines on how to choose the best index for your purposes`: https://github.com/facebookresearch/faiss/wiki/Guidelines-to-choose-an-index 36 | 37 | 38 | .. automodule:: sklearn_ann.kneighbors.faiss 39 | :members: 40 | 41 | nmslib 42 | ------ 43 | 44 | `nmslib (non-metric space library)` is a library for similarity search supporting 45 | metric and non-metric spaces. It contains multiple algorithms. 46 | 47 | 48 | .. automodule:: sklearn_ann.kneighbors.nmslib 49 | :members: 50 | 51 | PyNNDescent 52 | ----------- 53 | 54 | `PyNNDescent <https://github.com/lmcinnes/pynndescent>`_ is a Python nearest neighbor descent for approximate nearest 55 | neighbors. It iteratively improves a kNN-graph using the transitive property, 56 | using random projections for initialisation. This transformer is actually 57 | implemented as part of PyNNDescent, and simply re-exported here for (foolish) 58 | consistency. If you only need this transformer, just use PyNNDescent directly. 59 | 60 | 61 | ..
automodule:: sklearn_ann.kneighbors.pynndescent 62 | :members: 63 | 64 | sklearn 65 | ------- 66 | 67 | `scikit-learn` itself contains ball tree and k-d indices. KNeighborsTransformer is re-exported here specialised for these two types of index for consistency. 68 | 69 | 70 | .. automodule:: sklearn_ann.kneighbors.sklearn 71 | :members: 72 | -------------------------------------------------------------------------------- /examples/rnn_dbscan_simple.py: -------------------------------------------------------------------------------- 1 | """ 2 | =================================== 3 | Demo of RNN-DBSCAN clustering algorithm 4 | =================================== 5 | 6 | Finds core samples of high density and expands clusters from them. 7 | 8 | Mostly copypasted from sklearn's DBSCAN example. 9 | 10 | """ 11 | 12 | import numpy as np 13 | from sklearn import metrics 14 | from sklearn.datasets import make_blobs 15 | from sklearn.preprocessing import StandardScaler 16 | 17 | from sklearn_ann.cluster.rnn_dbscan import RnnDBSCAN 18 | 19 | # ############################################################################# 20 | # Generate sample data 21 | centers = [[1, 1], [-1, -1], [1, -1]] 22 | X, labels_true = make_blobs( 23 | n_samples=750, centers=centers, cluster_std=0.4, random_state=0 24 | ) 25 | 26 | X = StandardScaler().fit_transform(X) 27 | 28 | # ############################################################################# 29 | # Compute DBSCAN 30 | db = RnnDBSCAN(n_neighbors=10).fit(X) 31 | core_samples_mask = np.zeros_like(db.labels_, dtype=bool) 32 | core_samples_mask[db.core_sample_indices_] = True 33 | labels = db.labels_ 34 | 35 | # Number of clusters in labels, ignoring noise if present. 36 | n_clusters_ = len(set(labels)) - (1 if -1 in labels else 0) 37 | n_noise_ = list(labels).count(-1) 38 | 39 | print(f"""\ 40 | Estimated number of clusters: {n_clusters_} 41 | Estimated number of noise points: {n_noise_} 42 | Homogeneity: {metrics.homogeneity_score(labels_true, labels):0.3f} 43 | Completeness: {metrics.completeness_score(labels_true, labels):0.3f} 44 | V-measure: {metrics.v_measure_score(labels_true, labels):0.3f} 45 | Adjusted Rand Index: {metrics.adjusted_rand_score(labels_true, labels):0.3f} 46 | Adjusted Mutual Info: {metrics.adjusted_mutual_info_score(labels_true, labels):0.3f} 47 | Silhouette Coefficient: {metrics.silhouette_score(X, labels):0.3f}\ 48 | """) 49 | 50 | # ############################################################################# 51 | # Plot result 52 | import matplotlib.pyplot as plt 53 | 54 | # Black removed and is used for noise instead. 55 | unique_labels = set(labels) 56 | colors = [plt.cm.Spectral(each) for each in np.linspace(0, 1, len(unique_labels))] 57 | for k, col in zip(unique_labels, colors): 58 | if k == -1: 59 | # Black used for noise. 
60 | col = [0, 0, 0, 1] 61 | 62 | class_member_mask = labels == k 63 | 64 | xy = X[class_member_mask & core_samples_mask] 65 | plt.plot( 66 | xy[:, 0], 67 | xy[:, 1], 68 | "o", 69 | markerfacecolor=tuple(col), 70 | markeredgecolor="k", 71 | markersize=14, 72 | ) 73 | 74 | xy = X[class_member_mask & ~core_samples_mask] 75 | plt.plot( 76 | xy[:, 0], 77 | xy[:, 1], 78 | "o", 79 | markerfacecolor=tuple(col), 80 | markeredgecolor="k", 81 | markersize=6, 82 | ) 83 | 84 | plt.title(f"Estimated number of clusters: {n_clusters_}") 85 | plt.show() 86 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "sklearn-ann" 3 | description = "Various integrations for ANN (Approximate Nearest Neighbours) libraries into scikit-learn." 4 | authors = [ 5 | { name = "Frankie Robertson", email = "frankie@robertson.name" }, 6 | { name = "Philipp Angerer", email = "phil.angerer@gmail.com" } 7 | ] 8 | license = "BSD-3-Clause" 9 | urls.Source = "https://github.com/scikit-learn-contrib/sklearn-ann" 10 | urls.Documentation = "https://sklearn-ann.readthedocs.io/" 11 | dynamic = ["version", "readme"] 12 | requires-python = ">=3.9" 13 | dependencies = [ 14 | "scikit-learn>=1.6.0", 15 | "scipy>=1.11.1,<2.0.0", 16 | ] 17 | 18 | [project.optional-dependencies] 19 | tests = [ 20 | "pytest>=6.2.1", 21 | "pytest-cov>=2.10.1", 22 | ] 23 | docs = [ 24 | "sphinx>=7", 25 | "sphinx-gallery>=0.8.2", 26 | "sphinx-book-theme>=1.1.0", 27 | "sphinx-issues>=1.2.0", 28 | "numpydoc>=1.1.0", 29 | "matplotlib>=3.3.3", 30 | "scanpydoc", 31 | ] 32 | annoy = [ 33 | "annoy>=1.17.0,<2.0.0", 34 | ] 35 | faiss = [ 36 | "faiss-cpu>=1.6.5,<2.0.0", 37 | ] 38 | pynndescent = [ 39 | "pynndescent>=0.5.1,<1.0.0", 40 | "numba>=0.52", 41 | ] 42 | nmslib = [ 43 | "nmslib-metabrainz>=2.1.1,<3.0.0", 44 | ] 45 | annlibs = [ 46 | "sklearn-ann[annoy,faiss,pynndescent,nmslib]", 47 | ] 48 | 49 | [tool.hatch.version] 50 | source = "vcs" 51 | 52 | [tool.hatch.metadata.hooks.fancy-pypi-readme] 53 | content-type = "text/x-rst" 54 | [[tool.hatch.metadata.hooks.fancy-pypi-readme.fragments]] 55 | path = "README.rst" 56 | start-after = ".. inclusion-marker-do-not-remove\n\n" 57 | 58 | [tool.pytest.ini_options] 59 | addopts = [ 60 | "--import-mode=importlib", 61 | ] 62 | # Flaky tests should be marked as xfail(strict=False), 63 | # this will notify us when a test considered broken starts succeeding. 
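# For example, a flaky backend test is marked in tests/test_kneighbors/test_common.py as:
#   pytest.mark.xfail(reason="...", strict=False)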
64 | xfail_strict = true 65 | 66 | [tool.ruff.lint] 67 | select = [ 68 | "F", # Pyflakes 69 | "E", # Pycodestyle errors 70 | "W", # Pycodestyle warnings 71 | "I", # Isort 72 | "UP", # Pyupgrade 73 | "PT", # Pytest style 74 | "PTH", # Pathlib 75 | "RUF", # Ruff’s own rules 76 | "T20", # print statements 77 | ] 78 | ignore = [ 79 | # Don’t complain about “confusables” 80 | "RUF001", "RUF002", "RUF003" 81 | ] 82 | [tool.ruff.lint.per-file-ignores] 83 | "examples/*.py" = ["E402", "T20"] 84 | "tests/*.py" = ["T20"] 85 | [tool.ruff.lint.isort] 86 | known-first-party = ["sklearn_ann"] 87 | 88 | [tool.hatch.envs.docs] 89 | installer = "uv" 90 | features = ["docs", "annlibs"] 91 | scripts.build = "sphinx-build -M html docs docs/_build" 92 | 93 | [tool.hatch.envs.hatch-test] 94 | default-args = [] 95 | features = ["tests", "annlibs"] 96 | 97 | [tool.hatch.build.targets.wheel] 98 | packages = ["src/sklearn_ann"] 99 | 100 | [build-system] 101 | requires = ["hatchling", "hatch-vcs", "hatch-fancy-pypi-readme"] 102 | build-backend = "hatchling.build" 103 | -------------------------------------------------------------------------------- /src/sklearn_ann/kneighbors/annoy.py: -------------------------------------------------------------------------------- 1 | import annoy 2 | import numpy as np 3 | from scipy.sparse import csr_matrix 4 | from sklearn.base import BaseEstimator, TransformerMixin 5 | from sklearn.utils import Tags, TargetTags, TransformerTags 6 | from sklearn.utils.validation import validate_data 7 | 8 | from ..utils import TransformerChecksMixin 9 | 10 | 11 | class AnnoyTransformer(TransformerChecksMixin, TransformerMixin, BaseEstimator): 12 | """Wrapper for using annoy.AnnoyIndex as sklearn's KNeighborsTransformer""" 13 | 14 | def __init__(self, n_neighbors=5, *, metric="euclidean", n_trees=10, search_k=-1): 15 | self.n_neighbors = n_neighbors 16 | self.n_trees = n_trees 17 | self.search_k = search_k 18 | self.metric = metric 19 | 20 | def fit(self, X, y=None): 21 | X = validate_data(self, X) 22 | self.n_samples_fit_ = X.shape[0] 23 | metric = self.metric if self.metric != "sqeuclidean" else "euclidean" 24 | self.annoy_ = annoy.AnnoyIndex(X.shape[1], metric=metric) 25 | for i, x in enumerate(X): 26 | self.annoy_.add_item(i, x.tolist()) 27 | self.annoy_.build(self.n_trees) 28 | return self 29 | 30 | def transform(self, X): 31 | X = self._transform_checks(X, "annoy_") 32 | return self._transform(X) 33 | 34 | def fit_transform(self, X, y=None): 35 | return self.fit(X)._transform(X=None) 36 | 37 | def _transform(self, X): 38 | """As `transform`, but handles X is None for faster `fit_transform`.""" 39 | 40 | n_samples_transform = self.n_samples_fit_ if X is None else X.shape[0] 41 | 42 | # For compatibility reasons, as each sample is considered as its own 43 | # neighbor, one extra neighbor will be computed. 
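        # This matches the KNeighborsTransformer contract exercised in
        # tests/test_kneighbors/test_common.py: each row stores n_neighbors + 1
        # entries so that the query point itself can appear at distance 0.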
44 | n_neighbors = self.n_neighbors + 1 45 | 46 | indices = np.empty((n_samples_transform, n_neighbors), dtype=int) 47 | distances = np.empty((n_samples_transform, n_neighbors)) 48 | 49 | if X is None: 50 | for i in range(self.annoy_.get_n_items()): 51 | ind, dist = self.annoy_.get_nns_by_item( 52 | i, n_neighbors, self.search_k, include_distances=True 53 | ) 54 | 55 | indices[i], distances[i] = ind, dist 56 | else: 57 | for i, x in enumerate(X): 58 | indices[i], distances[i] = self.annoy_.get_nns_by_vector( 59 | x.tolist(), n_neighbors, self.search_k, include_distances=True 60 | ) 61 | 62 | if self.metric == "sqeuclidean": 63 | distances **= 2 64 | 65 | indptr = np.arange(0, n_samples_transform * n_neighbors + 1, n_neighbors) 66 | kneighbors_graph = csr_matrix( 67 | (distances.ravel(), indices.ravel(), indptr), 68 | shape=(n_samples_transform, self.n_samples_fit_), 69 | ) 70 | 71 | return kneighbors_graph 72 | 73 | def __sklearn_tags__(self) -> Tags: 74 | return Tags( 75 | estimator_type="transformer", 76 | target_tags=TargetTags(required=False), 77 | transformer_tags=TransformerTags(), 78 | ) 79 | -------------------------------------------------------------------------------- /src/sklearn_ann/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.sparse import csr_matrix 3 | from sklearn.utils.validation import validate_data 4 | 5 | 6 | def check_metric(metric, metrics): 7 | if metric not in metrics: 8 | raise ValueError(f"Unknown metric {metric!r}. Valid metrics are {metrics!r}") 9 | 10 | 11 | def get_sparse_indices(mat, idx): 12 | start_idx = mat.indptr[idx] 13 | end_idx = mat.indptr[idx + 1] 14 | return mat.indices[start_idx:end_idx] 15 | 16 | 17 | def get_sparse_row(mat, idx): 18 | start_idx = mat.indptr[idx] 19 | end_idx = mat.indptr[idx + 1] 20 | return zip(mat.indices[start_idx:end_idx], mat.data[start_idx:end_idx]) 21 | 22 | 23 | def trunc_csr(csr, k): 24 | indptr = np.empty_like(csr.indptr) 25 | num_rows = len(csr.indptr) - 1 26 | indices = [None] * num_rows 27 | data = [None] * num_rows 28 | cur_indptr = 0 29 | for row_idx in range(num_rows): 30 | indptr[row_idx] = cur_indptr 31 | start_idx = csr.indptr[row_idx] 32 | old_end_idx = csr.indptr[row_idx + 1] 33 | end_idx = min(old_end_idx, start_idx + k) 34 | data[row_idx] = csr.data[start_idx:end_idx] 35 | indices[row_idx] = csr.indices[start_idx:end_idx] 36 | ptr_inc = min(k, old_end_idx - start_idx) 37 | cur_indptr = cur_indptr + ptr_inc 38 | indptr[-1] = cur_indptr 39 | return csr_matrix((np.concatenate(data), np.concatenate(indices), indptr)) 40 | 41 | 42 | def or_else_csrs(csr1, csr2): 43 | # Possible TODO: Could use numba/Cython to speed this up? 
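    # Row-wise "or else" merge of two CSR matrices: the result has the union of
    # their sparsity patterns, and where both store a value at the same position
    # the value from csr1 wins. postprocess_knn_csr uses this to combine a kNN
    # graph with its transpose (the reverse kNN graph).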
44 | if csr1.shape != csr2.shape: 45 | raise ValueError("csr1 and csr2 must be the same shape") 46 | indptr = np.empty_like(csr1.indptr) 47 | indices = [] 48 | data = [] 49 | for row_idx in range(len(indptr) - 1): 50 | indptr[row_idx] = len(indices) 51 | csr1_it = iter(get_sparse_row(csr1, row_idx)) 52 | csr2_it = iter(get_sparse_row(csr2, row_idx)) 53 | cur_csr1 = next(csr1_it, None) 54 | cur_csr2 = next(csr2_it, None) 55 | while 1: 56 | if cur_csr1 is None and cur_csr2 is None: 57 | break 58 | elif cur_csr1 is None: 59 | cur_index, cur_datum = cur_csr2 60 | elif cur_csr2 is None: 61 | cur_index, cur_datum = cur_csr1 62 | elif cur_csr1[0] < cur_csr2[0]: 63 | cur_index, cur_datum = cur_csr1 64 | cur_csr1 = next(csr1_it, None) 65 | elif cur_csr2[0] < cur_csr1[0]: 66 | cur_index, cur_datum = cur_csr2 67 | cur_csr2 = next(csr2_it, None) 68 | else: 69 | cur_index, cur_datum = cur_csr1 70 | cur_csr1 = next(csr1_it, None) 71 | cur_csr2 = next(csr2_it, None) 72 | indices.append(cur_index) 73 | data.append(cur_datum) 74 | indptr[-1] = len(indices) 75 | return csr_matrix((data, indices, indptr), shape=csr1.shape) 76 | 77 | 78 | def postprocess_knn_csr(knns, include_fwd=True, include_rev=False): 79 | if not include_fwd and not include_rev: 80 | raise ValueError("One of include_fwd or include_rev must be True") 81 | elif include_rev and not include_fwd: 82 | return knns.transpose(copy=False) 83 | elif not include_rev and include_fwd: 84 | return knns 85 | else: 86 | inv_knns = knns.transpose(copy=True) 87 | return or_else_csrs(knns, inv_knns) 88 | 89 | 90 | class TransformerChecksMixin: 91 | def _transform_checks(self, X, *fitted_props, **check_params): 92 | from sklearn.utils.validation import check_is_fitted 93 | 94 | X = validate_data(self, X, reset=False, **check_params) 95 | check_is_fitted(self, *fitted_props) 96 | return X 97 | -------------------------------------------------------------------------------- /tests/test_kneighbors/test_common.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | from sklearn.utils.estimator_checks import check_estimator 4 | 5 | from sklearn_ann.test_utils import needs 6 | 7 | try: 8 | from sklearn_ann.kneighbors.annoy import AnnoyTransformer 9 | except ImportError: 10 | AnnoyTransformer = "AnnoyTransformer" 11 | try: 12 | from sklearn_ann.kneighbors.faiss import FAISSTransformer 13 | except ImportError: 14 | FAISSTransformer = "FAISSTransformer" 15 | try: 16 | from sklearn_ann.kneighbors.nmslib import NMSlibTransformer 17 | except ImportError: 18 | NMSlibTransformer = "NMSlibTransformer" 19 | try: 20 | from sklearn_ann.kneighbors.pynndescent import PyNNDescentTransformer 21 | except ImportError: 22 | PyNNDescentTransformer = "PyNNDescentTransformer" 23 | from sklearn_ann.kneighbors.sklearn import BallTreeTransformer, KDTreeTransformer 24 | 25 | ESTIMATORS = [ 26 | pytest.param(AnnoyTransformer, marks=[needs.annoy()]), 27 | pytest.param(FAISSTransformer, marks=[needs.faiss()]), 28 | pytest.param(NMSlibTransformer, marks=[needs.nmslib()]), 29 | pytest.param(PyNNDescentTransformer, marks=[needs.pynndescent()]), 30 | pytest.param(BallTreeTransformer), 31 | pytest.param(KDTreeTransformer), 32 | ] 33 | 34 | PER_ESTIMATOR_XFAIL_CHECKS = { 35 | AnnoyTransformer: dict(check_estimators_pickle="Cannot pickle AnnoyIndex"), 36 | FAISSTransformer: dict( 37 | check_estimators_pickle="Cannot pickle FAISS index", 38 | check_methods_subset_invariance="Unable to reset FAISS internal RNG", 39 | ), 40 | 
NMSlibTransformer: dict(check_estimators_pickle="Cannot pickle NMSLib index"), 41 | } 42 | 43 | 44 | def add_mark(param, mark): 45 | return pytest.param(*param.values, marks=[*param.marks, mark], id=param.id) 46 | 47 | 48 | @pytest.mark.parametrize( 49 | "Estimator", 50 | [ 51 | add_mark( 52 | est, 53 | pytest.mark.xfail( 54 | reason="cannot deal with all dtypes (problem is upsteam)" 55 | ), 56 | ) 57 | if est.values[0] is PyNNDescentTransformer 58 | else est 59 | for est in ESTIMATORS 60 | ], 61 | ) 62 | def test_all_estimators(Estimator): 63 | check_estimator( 64 | Estimator(), 65 | expected_failed_checks=PER_ESTIMATOR_XFAIL_CHECKS.get(Estimator, {}), 66 | ) 67 | 68 | 69 | # The following critera are from: 70 | # https://scikit-learn.org/stable/modules/neighbors.html#nearest-neighbors-transformer 71 | # * only explicitly store nearest neighborhoods of each sample with respect to the 72 | # training data. This should include those at 0 distance from a query point, 73 | # including the matrix diagonal when computing the nearest neighborhoods between the 74 | # training data and itself. 75 | # * each row’s data should store the distance in increasing order 76 | # (optional. Unsorted data will be stable-sorted, adding a computational overhead). 77 | # * all values in data should be non-negative. 78 | # * there should be no duplicate indices in any row (see https://github.com/scipy/scipy/issues/5807). 79 | # * if the algorithm being passed the precomputed matrix uses k nearest neighbors 80 | # (as opposed to radius neighborhood), at least k neighbors must be stored in each row 81 | # (or k+1, as explained in the following note). 82 | 83 | 84 | def mark_diagonal_0_xfail(est): 85 | """Mark flaky tests as xfail(strict=False).""" 86 | # Should probably postprocess these... 87 | reasons = { 88 | PyNNDescentTransformer: "sometimes doesn't return diagonal==0", 89 | FAISSTransformer: "sometimes returns diagonal==eps where eps is small", 90 | } 91 | [val] = est.values 92 | name = val.__name__ if isinstance(val, type) else val 93 | if reason := reasons.get(val): 94 | return add_mark(est, pytest.mark.xfail(reason=f"{name} {reason}", strict=False)) 95 | return est 96 | 97 | 98 | @pytest.mark.parametrize( 99 | "Estimator", [mark_diagonal_0_xfail(est) for est in ESTIMATORS] 100 | ) 101 | def test_all_return_diagonal_0(random_small, Estimator): 102 | # * only explicitly store nearest neighborhoods of each sample with respect to the 103 | # training data. This should include those at 0 distance from a query point, 104 | # including the matrix diagonal when computing the nearest neighborhoods 105 | # between the training data and itself. 
106 | 107 | # Check: do we alway get an "extra" neighbour (diagonal/self) 108 | est = Estimator(n_neighbors=3) 109 | knns = est.fit_transform(random_small) 110 | assert (knns.getnnz(1) == 4).all() 111 | 112 | # Check: diagonal is 0 113 | next_expected_diagonal = 0 114 | for row_idx in range(knns.shape[0]): 115 | start_idx = knns.indptr[row_idx] 116 | end_idx = knns.indptr[row_idx + 1] 117 | for col_idx, val in zip( 118 | knns.indices[start_idx:end_idx], knns.data[start_idx:end_idx] 119 | ): 120 | print("self0", row_idx, start_idx, end_idx, col_idx, val) 121 | if row_idx != col_idx: 122 | continue 123 | assert col_idx == next_expected_diagonal 124 | assert val == 0 125 | next_expected_diagonal += 1 126 | assert next_expected_diagonal == len(random_small) 127 | 128 | 129 | @pytest.mark.parametrize("Estimator", ESTIMATORS) 130 | def test_all_same(random_small, Estimator): 131 | # Again but for the case of the same element 132 | ones = np.ones((64, 4)) 133 | est = Estimator(n_neighbors=3) 134 | knns = est.fit_transform(ones) 135 | print("knns", knns) 136 | assert (knns.getnnz(1) == 4).all() 137 | assert len(knns.nonzero()[0]) == 0 138 | -------------------------------------------------------------------------------- /src/sklearn_ann/kneighbors/faiss.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import math 4 | 5 | import faiss 6 | import numpy as np 7 | from faiss import normalize_L2 8 | from joblib import cpu_count 9 | from scipy.sparse import csr_matrix 10 | from sklearn.base import BaseEstimator, TransformerMixin 11 | from sklearn.utils import Tags, TargetTags, TransformerTags 12 | from sklearn.utils.validation import validate_data 13 | 14 | from ..utils import TransformerChecksMixin, postprocess_knn_csr 15 | 16 | L2_INFO = {"metric": faiss.METRIC_L2, "sqrt": True} 17 | 18 | 19 | METRIC_MAP = { 20 | "cosine": { 21 | "metric": faiss.METRIC_INNER_PRODUCT, 22 | "normalize": True, 23 | "negate": True, 24 | }, 25 | "l1": {"metric": faiss.METRIC_L1}, 26 | "cityblock": {"metric": faiss.METRIC_L1}, 27 | "manhattan": {"metric": faiss.METRIC_L1}, 28 | "l2": L2_INFO, 29 | "euclidean": L2_INFO, 30 | "sqeuclidean": {"metric": faiss.METRIC_L2}, 31 | "canberra": {"metric": faiss.METRIC_Canberra}, 32 | "braycurtis": {"metric": faiss.METRIC_BrayCurtis}, 33 | "jensenshannon": {"metric": faiss.METRIC_JensenShannon}, 34 | } 35 | 36 | 37 | def mk_faiss_index(feats, inner_metric, index_key="", nprobe=128) -> faiss.Index: 38 | size, dim = feats.shape 39 | if not index_key: 40 | if inner_metric == faiss.METRIC_INNER_PRODUCT: 41 | index = faiss.IndexFlatIP(dim) 42 | else: 43 | index = faiss.IndexFlatL2(dim) 44 | else: 45 | if index_key.find("HNSW") < 0: 46 | raise NotImplementedError( 47 | "HNSW not implemented: returns distances insted of sims" 48 | ) 49 | nlist = min(4096, 8 * round(math.sqrt(size))) 50 | if index_key == "IVF": 51 | quantizer = index 52 | index = faiss.IndexIVFFlat(quantizer, dim, nlist, inner_metric) 53 | else: 54 | index = faiss.index_factory(dim, index_key, inner_metric) 55 | if index_key.find("Flat") < 0: 56 | assert not index.is_trained 57 | index.train(feats) 58 | index.nprobe = min(nprobe, nlist) 59 | assert index.is_trained 60 | index.add(feats) 61 | return index 62 | 63 | 64 | class FAISSTransformer(TransformerChecksMixin, TransformerMixin, BaseEstimator): 65 | def __init__( 66 | self, 67 | n_neighbors=5, 68 | *, 69 | metric="euclidean", 70 | index_key="", 71 | n_probe=128, 72 | n_jobs=-1, 73 | 
include_fwd=True, 74 | include_rev=False, 75 | ): 76 | self.n_neighbors = n_neighbors 77 | self.metric = metric 78 | self.index_key = index_key 79 | self.n_probe = n_probe 80 | self.n_jobs = n_jobs 81 | self.include_fwd = include_fwd 82 | self.include_rev = include_rev 83 | 84 | @property 85 | def _metric_info(self): 86 | return METRIC_MAP[self.metric] 87 | 88 | def fit(self, X, y=None): 89 | normalize = self._metric_info.get("normalize", False) 90 | X = validate_data(self, X, dtype=np.float32, copy=normalize) 91 | self.n_samples_fit_ = X.shape[0] 92 | if self.n_jobs == -1: 93 | n_jobs = cpu_count() 94 | else: 95 | n_jobs = self.n_jobs 96 | faiss.omp_set_num_threads(n_jobs) 97 | inner_metric = self._metric_info["metric"] 98 | if normalize: 99 | normalize_L2(X) 100 | self.faiss_ = mk_faiss_index(X, inner_metric, self.index_key, self.n_probe) 101 | return self 102 | 103 | def transform(self, X): 104 | normalize = self._metric_info.get("normalize", False) 105 | X = self._transform_checks(X, "faiss_", dtype=np.float32, copy=normalize) 106 | if normalize: 107 | normalize_L2(X) 108 | return self._transform(X) 109 | 110 | def _transform(self, X): 111 | n_samples_transform = self.n_samples_fit_ if X is None else X.shape[0] 112 | n_neighbors = self.n_neighbors + 1 113 | if X is None: 114 | sims, nbrs = self.faiss_.search( 115 | np.reshape( 116 | faiss.rev_swig_ptr( 117 | self.faiss_.get_xb(), self.faiss_.ntotal * self.faiss_.d 118 | ), 119 | (self.faiss_.ntotal, self.faiss_.d), 120 | ), 121 | k=n_neighbors, 122 | ) 123 | else: 124 | sims, nbrs = self.faiss_.search(X, k=n_neighbors) 125 | dist_arr = np.array(sims, dtype=np.float32) 126 | if self._metric_info.get("sqrt", False): 127 | dist_arr = np.sqrt(dist_arr) 128 | if self._metric_info.get("negate", False): 129 | dist_arr = 1 - dist_arr 130 | del sims 131 | nbr_arr = np.array(nbrs, dtype=np.int32) 132 | del nbrs 133 | indptr = np.arange(0, n_samples_transform * n_neighbors + 1, n_neighbors) 134 | """ 135 | dist_arr = np.concatenate( 136 | [ 137 | np.zeros( 138 | (n_samples_transform, 1), 139 | dtype=dist_arr.dtype 140 | ), 141 | dist_arr 142 | ], axis=1 143 | ) 144 | nbr_arr = np.concatenate( 145 | [ 146 | np.arange(n_samples_transform)[:, np.newaxis], 147 | nbr_arr 148 | ], axis=1 149 | ) 150 | """ 151 | mat = csr_matrix( 152 | (dist_arr.ravel(), nbr_arr.ravel(), indptr), 153 | shape=(n_samples_transform, self.n_samples_fit_), 154 | ) 155 | return postprocess_knn_csr( 156 | mat, include_fwd=self.include_fwd, include_rev=self.include_rev 157 | ) 158 | 159 | def fit_transform(self, X, y=None): 160 | return self.fit(X, y=y)._transform(X=None) 161 | 162 | def __sklearn_tags__(self) -> Tags: 163 | return Tags( 164 | estimator_type="transformer", 165 | target_tags=TargetTags(required=False), 166 | transformer_tags=TransformerTags(preserves_dtype=[np.float32]), 167 | # Could be made deterministic *if* we could reset FAISS's internal RNG 168 | non_deterministic=True, 169 | ) 170 | -------------------------------------------------------------------------------- /src/sklearn_ann/cluster/rnn_dbscan.py: -------------------------------------------------------------------------------- 1 | from collections import deque 2 | from typing import cast 3 | 4 | import numpy as np 5 | from sklearn.base import BaseEstimator, ClusterMixin 6 | from sklearn.neighbors import KNeighborsTransformer 7 | from sklearn.utils import Tags 8 | from sklearn.utils.validation import validate_data 9 | 10 | from ..utils import get_sparse_row 11 | 12 | UNCLASSIFIED = -2 13 | NOISE = -1 
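# Label sentinels: every sample starts as UNCLASSIFIED, and samples that cannot be
# attached to any cluster end up as NOISE (-1), matching scikit-learn's DBSCAN
# convention. join() below merges two (index, distance) iterators sorted by index,
# yielding each index at most once.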
14 | 15 | 16 | def join(it1, it2): 17 | cur_it1 = next(it1, None) 18 | cur_it2 = next(it2, None) 19 | while 1: 20 | if cur_it1 is None and cur_it2 is None: 21 | break 22 | elif cur_it1 is None: 23 | yield cur_it2 24 | cur_it2 = next(it2, None) 25 | elif cur_it2 is None: 26 | yield cur_it1 27 | cur_it1 = next(it1, None) 28 | elif cur_it1[0] == cur_it2[0]: 29 | yield cur_it1 30 | cur_it1 = next(it1, None) 31 | cur_it2 = next(it2, None) 32 | elif cur_it1[0] < cur_it2[0]: 33 | yield cur_it1 34 | cur_it1 = next(it1, None) 35 | else: 36 | yield cur_it2 37 | cur_it2 = next(it2, None) 38 | 39 | 40 | def neighborhood(is_core, knns, rev_knns, idx): 41 | # TODO: Make this inner bit faster 42 | knn_it = get_sparse_row(knns, idx) 43 | rev_core_knn_it = ( 44 | (other_idx, dist) 45 | for other_idx, dist in get_sparse_row(rev_knns, idx) 46 | if is_core[other_idx] 47 | ) 48 | yield from ( 49 | (other_idx, dist) 50 | for other_idx, dist in join(knn_it, rev_core_knn_it) 51 | if other_idx != idx 52 | ) 53 | 54 | 55 | def rnn_dbscan_inner(is_core, knns, rev_knns, labels): 56 | cluster = 0 57 | cur_dens = 0 58 | dens = [] 59 | for x_idx in range(len(labels)): 60 | if labels[x_idx] == UNCLASSIFIED: 61 | # Expand cluster 62 | if is_core[x_idx]: 63 | labels[x_idx] = cluster 64 | # TODO: Make this inner bit faster - can just assume 65 | # sorted an keep sorted 66 | seeds = deque() 67 | for neighbor_idx, dist in neighborhood(is_core, knns, rev_knns, x_idx): 68 | labels[neighbor_idx] = cluster 69 | if dist > cur_dens: 70 | cur_dens = dist 71 | seeds.append(neighbor_idx) 72 | while seeds: 73 | y_idx = seeds.popleft() 74 | if is_core[y_idx]: 75 | for z_idx, dist in neighborhood(is_core, knns, rev_knns, y_idx): 76 | if dist > cur_dens: 77 | cur_dens = dist 78 | if labels[z_idx] == UNCLASSIFIED: 79 | seeds.append(z_idx) 80 | labels[z_idx] = cluster 81 | elif labels[z_idx] == NOISE: 82 | labels[z_idx] = cluster 83 | dens.append(cur_dens) 84 | cur_dens = 0 85 | cluster += 1 86 | else: 87 | labels[x_idx] = NOISE 88 | # Expand clusters 89 | for x_idx in range(len(labels)): 90 | if labels[x_idx] == NOISE: 91 | min_cluster = NOISE 92 | min_dist = float("inf") 93 | for n_idx, n_dist in get_sparse_row(knns, x_idx): 94 | if n_dist >= min_dist or not is_core[n_idx]: 95 | continue 96 | cluster = labels[n_idx] 97 | if n_dist > dens[cluster]: 98 | continue 99 | min_cluster = cluster 100 | min_dist = n_dist 101 | labels[x_idx] = min_cluster 102 | return dens 103 | 104 | 105 | class RnnDBSCAN(ClusterMixin, BaseEstimator): 106 | """ 107 | Implements the RNN-DBSCAN clustering algorithm. 108 | 109 | Parameters 110 | ---------- 111 | n_neighbors : int 112 | The number of neighbors in the kNN-graph (the k in kNN), and the 113 | theshold of reverse nearest neighbors for a node to be considered a 114 | core node. 115 | input_guarantee : "none" | "kneighbors" 116 | A guarantee on input matrices. If equal to "kneighbors", the algorithm 117 | will assume you are passing in the kNN graph exactly as required, e.g. 118 | with n_neighbors. This can be used to pass in a graph produced by one 119 | of the implementations of the KNeighborsTransformer interface. 120 | n_jobs : int 121 | The number of jobs to use. Currently has not effect since no part of 122 | the algorithm has been parallelled. 123 | keep_knns : bool 124 | If true, the kNN and inverse kNN graph will be saved to `knns_` and 125 | `rev_knns_` after fitting. 
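
    Attributes
    ----------
    labels_
        Cluster labels for each sample passed to fit; noise samples are labelled -1.
    core_sample_indices_
        Indices of the core samples found during fitting.
    dens_
        Per-cluster maximum core-neighbour distance, used when attaching the
        remaining noise points to nearby clusters.
    knns_, rev_knns_
        The kNN graph and its transpose (the reverse kNN graph); only stored
        when `keep_knns` is True.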
126 | 127 | See Also 128 | -------- 129 | simple_rnn_dbscan_pipeline: 130 | Create a pipeline of a KNeighborsTransformer and RnnDBSCAN 131 | 132 | References 133 | ---------- 134 | A. Bryant and K. Cios, "RNN-DBSCAN: A Density-Based Clustering 135 | Algorithm Using Reverse Nearest Neighbor Density Estimates," in IEEE 136 | Transactions on Knowledge and Data Engineering, vol. 30, no. 6, pp. 137 | 1109-1121, 1 June 2018, doi: 10.1109/TKDE.2017.2787640. 138 | """ 139 | 140 | def __init__( 141 | self, n_neighbors=5, *, input_guarantee="none", n_jobs=None, keep_knns=False 142 | ): 143 | self.n_neighbors = n_neighbors 144 | self.input_guarantee = input_guarantee 145 | self.n_jobs = n_jobs 146 | self.keep_knns = keep_knns 147 | 148 | def fit(self, X, y=None): 149 | X = validate_data(self, X, accept_sparse="csr") 150 | if self.input_guarantee == "none": 151 | algorithm = KNeighborsTransformer(n_neighbors=self.n_neighbors) 152 | X = algorithm.fit_transform(X) 153 | elif self.input_guarantee == "kneighbors": 154 | pass 155 | else: 156 | raise ValueError( 157 | "Expected input_guarantee to be one of 'none', 'kneighbors'" 158 | ) 159 | 160 | XT = X.transpose().tocsr(copy=True) 161 | if self.keep_knns: 162 | self.knns_ = X 163 | self.rev_knns_ = XT 164 | 165 | # Initially, all samples are unclassified. 166 | labels = np.full(X.shape[0], UNCLASSIFIED, dtype=np.int32) 167 | 168 | # A list of all core samples found. -1 is to account for diagonal. 169 | core_samples = XT.getnnz(1) - 1 >= self.n_neighbors 170 | 171 | dens = rnn_dbscan_inner(core_samples, X, XT, labels) 172 | 173 | self.core_sample_indices_ = core_samples.nonzero() 174 | self.labels_ = labels 175 | self.dens_ = dens 176 | 177 | return self 178 | 179 | def fit_predict(self, X, y=None): 180 | self.fit(X, y=y) 181 | return self.labels_ 182 | 183 | def drop_knns(self): 184 | del self.knns_ 185 | del self.rev_knns_ 186 | 187 | def __sklearn_tags__(self) -> Tags: 188 | tags = cast(Tags, super().__sklearn_tags__()) 189 | tags.input_tags.sparse = True 190 | return tags 191 | 192 | 193 | def simple_rnn_dbscan_pipeline( 194 | neighbor_transformer, n_neighbors, n_jobs=None, keep_knns=None, **kwargs 195 | ): 196 | """ 197 | Create a simple pipeline comprising a transformer and RnnDBSCAN. 198 | 199 | Parameters 200 | ---------- 201 | neighbor_transformer : class implementing KNeighborsTransformer interface 202 | n_neighbors: 203 | Passed to neighbor_transformer and RnnDBSCAN 204 | n_jobs: 205 | Passed to neighbor_transformer and RnnDBSCAN 206 | keep_knns: 207 | Passed to RnnDBSCAN 208 | kwargs: 209 | Passed to neighbor_transformer 210 | """ 211 | from sklearn.pipeline import make_pipeline 212 | 213 | return make_pipeline( 214 | neighbor_transformer(n_neighbors=n_neighbors, n_jobs=n_jobs, **kwargs), 215 | RnnDBSCAN( 216 | n_neighbors=n_neighbors, 217 | input_guarantee="kneighbors", 218 | n_jobs=n_jobs, 219 | keep_knns=keep_knns, 220 | ), 221 | ) 222 | --------------------------------------------------------------------------------