├── .github └── workflows │ ├── build.yml │ └── publish.yml ├── .gitignore ├── .readthedocs.yaml ├── LICENSE ├── MANIFEST.in ├── Makefile ├── README.md ├── docs ├── Makefile ├── conf.py ├── index.rst ├── make.bat ├── requirements.txt └── source │ ├── api.rst │ └── tutorial.rst ├── mypy.ini ├── nanopq ├── __init__.py ├── convert_faiss.py ├── opq.py └── pq.py ├── pyproject.toml ├── requirements-dev.txt ├── setup.py └── tests ├── __init__.py ├── test_convert_faiss.py ├── test_opq.py └── test_pq.py /.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | name: Building 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | build: 7 | 8 | runs-on: ${{ matrix.os }} 9 | strategy: 10 | matrix: 11 | os: [ubuntu-latest, macos-latest, windows-latest] 12 | python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] 13 | 14 | steps: 15 | - name: Checkout 16 | uses: actions/checkout@v3 17 | 18 | # Install the latest miniconda. The "test" environment is activated 19 | - name: Setup miniconda 20 | uses: conda-incubator/setup-miniconda@v2 21 | with: 22 | miniconda-version: "latest" 23 | python-version: ${{ matrix.python-version }} 24 | 25 | - name: Install dependencies 26 | shell: bash -l {0} # to activate conda 27 | run: | 28 | python -m pip install --upgrade pip 29 | pip install pytest 30 | pip install . # Install this library 31 | pip install .[dev] # Install dev dependencies 32 | conda install -c pytorch faiss-cpu=1.7.4 mkl=2021 blas=1.0=mkl 33 | 34 | - name: Test with pytest 35 | shell: bash -l {0} # to activate conda 36 | run: | 37 | make test 38 | 39 | - name: Run mypy 40 | shell: bash -l {0} # to activate conda 41 | run: | 42 | mypy nanopq --ignore-missing-imports -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publishing 2 | 3 | on: 4 | release: 5 | types: [published] 6 | 7 | jobs: 8 | pypi-publish: 9 | name: Upload release to PyPI 10 | runs-on: ubuntu-latest 11 | environment: release 12 | permissions: 13 | id-token: write # IMPORTANT: this permission is mandatory for trusted publishing 14 | steps: 15 | # retrieve your distributions here 16 | 17 | - name: Checkout 18 | uses: actions/checkout@v4 19 | 20 | - name: Set up Python 3.x 21 | uses: actions/setup-python@v4 22 | with: 23 | python-version: '3.x' 24 | 25 | - name: Build package 26 | run: | 27 | python setup.py sdist bdist_wheel 28 | 29 | - name: Publish package distributions to PyPI 30 | uses: pypa/gh-action-pypi-publish@release/v1 31 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | .pytest_cache 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | env/ 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *,cover 49 | .hypothesis/ 50 | 51 | # Translations 52 | *.mo 53 | *.pot 54 | 55 | # Django stuff: 56 | *.log 57 | local_settings.py 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # IPython Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # dotenv 82 | .env 83 | 84 | # virtualenv 85 | .venv/ 86 | venv/ 87 | ENV/ 88 | 89 | # Spyder project settings 90 | .spyderproject 91 | 92 | # Rope project settings 93 | .ropeproject 94 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | version: 2 2 | build: 3 | os: ubuntu-22.04 4 | tools: 5 | python: "3.10" 6 | # You can also specify other tool versions: 7 | # nodejs: "16" 8 | 9 | # Build documentation in the docs/ directory with Sphinx 10 | sphinx: 11 | configuration: docs/conf.py 12 | 13 | # Dependencies required to build your docs 14 | python: 15 | install: 16 | - requirements: docs/requirements.txt 17 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Yusuke Matsui 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include LICENSE 2 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: test clean build format deploy test_deploy mypy 2 | 3 | test: mypy 4 | pytest 5 | 6 | mypy: 7 | mypy nanopq --ignore-missing-imports 8 | 9 | clean: 10 | rm -rf build dist *.egg-info 11 | 12 | build: 13 | python setup.py sdist bdist_wheel 14 | 15 | # To run format, install pysen by 'pip install "pysen[lint]"' 16 | format: 17 | pysen run format 18 | 19 | deploy: clean build 20 | twine upload dist/* 21 | 22 | test_deploy: clean build 23 | twine upload --repository-url https://test.pypi.org/legacy/ dist/* 24 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # nanopq 2 | 3 | [![Build Status](https://github.com/matsui528/nanopq/actions/workflows/build.yml/badge.svg)](https://github.com/matsui528/nanopq/actions) 4 | [![Documentation Status](https://readthedocs.org/projects/nanopq/badge/?version=latest)](https://nanopq.readthedocs.io/en/latest/?badge=latest) 5 | [![PyPI version](https://badge.fury.io/py/nanopq.svg)](https://badge.fury.io/py/nanopq) 6 | [![Downloads](https://pepy.tech/badge/nanopq)](https://pepy.tech/project/nanopq) 7 | 8 | Nano Product Quantization (nanopq): a vanilla implementation of Product Quantization (PQ) and Optimized Product Quantization (OPQ) written in pure python without any third party dependencies. 9 | 10 | 11 | 12 | ## Installing 13 | You can install the package via pip. This library works with Python 3.5+ on linux. 14 | ``` 15 | pip install nanopq 16 | ``` 17 | 18 | ## [Documentation](https://nanopq.readthedocs.io/en/latest/index.html) 19 | - [Tutorial](https://nanopq.readthedocs.io/en/latest/source/tutorial.html) 20 | - [API](https://nanopq.readthedocs.io/en/latest/source/api.html) 21 | 22 | ## Example 23 | 24 | ```python 25 | import nanopq 26 | import numpy as np 27 | 28 | N, Nt, D = 10000, 2000, 128 29 | X = np.random.random((N, D)).astype(np.float32) # 10,000 128-dim vectors to be indexed 30 | Xt = np.random.random((Nt, D)).astype(np.float32) # 2,000 128-dim vectors for training 31 | query = np.random.random((D,)).astype(np.float32) # a 128-dim query vector 32 | 33 | # Instantiate with M=8 sub-spaces 34 | pq = nanopq.PQ(M=8) 35 | 36 | # Train codewords 37 | pq.fit(Xt) 38 | 39 | # Encode to PQ-codes 40 | X_code = pq.encode(X) # (10000, 8) with dtype=np.uint8 41 | 42 | # Results: create a distance table online, and compute Asymmetric Distance to each PQ-code 43 | dists = pq.dtable(query).adist(X_code) # (10000, ) 44 | ``` 45 | 46 | ## Author 47 | - [Yusuke Matsui](http://yusukematsui.me) 48 | 49 | ## Contributors 50 | - [@Hiroshiba](https://github.com/Hiroshiba) fixed a bug of importlib ([#3](https://github.com/matsui528/nanopq/pull/3)) 51 | - [@calvinmccarter](https://github.com/calvinmccarter) implemented parametric initialization for OPQ ([#14](https://github.com/matsui528/nanopq/pull/14)) 52 | - [@de9uch1](https://github.com/de9uch1) exntended the interface to the faiss so that OPQ can be handled ([#19](https://github.com/matsui528/nanopq/pull/19)) 53 | - [@mpskex](https://github.com/mpskex) implemented (1) initialization of clustering and (2) dot-product for computation ([#24](https://github.com/matsui528/nanopq/pull/24)) 54 | - [@lsb](https://github.com/lsb) fixed a typo ([#26](https://github.com/matsui528/nanopq/pull/26)) 55 | - [@asukaminato0721](https://github.com/asukaminato0721) used Literal for string inputs ([#42](https://github.com/matsui528/nanopq/pull/42)) 56 | 57 | ## Reference 58 | - [H. Jegou, M. Douze, and C. Schmid, "Product Quantization for Nearest Neighbor Search", IEEE TPAMI 2011](https://ieeexplore.ieee.org/document/5432202/) (the original paper of PQ) 59 | - [T. Ge, K. He, Q. Ke, and J. Sun, "Optimized Product Quantization", IEEE TPAMI 2014](https://ieeexplore.ieee.org/document/6678503/) (the original paper of OPQ) 60 | - [Y. Matsui, Y. Uchida, H. Jegou, and S. Satoh, "A Survey of Product Quantization", ITE MTA 2018](https://www.jstage.jst.go.jp/article/mta/6/1/6_2/_pdf/) (a survey paper of PQ) 61 | - [PQ in faiss](https://github.com/facebookresearch/faiss/wiki/Faiss-building-blocks:-clustering,-PCA,-quantization#pq-encoding--decoding) (Faiss contains an optimized implementation of PQ. [See the difference to ours here](https://nanopq.readthedocs.io/en/latest/source/tutorial.html#difference-from-pq-in-faiss)) 62 | - [Rayuela.jl](https://github.com/una-dinosauria/Rayuela.jl) (Julia implementation of several encoding algorithms including PQ and OPQ) 63 | - [PQk-means](https://github.com/DwangoMediaVillage/pqkmeans) (clustering on PQ-codes. The implementation of nanopq is compatible to [that of PQk-means](https://github.com/DwangoMediaVillage/pqkmeans/blob/master/tutorial/1_pqkmeans.ipynb)) 64 | - [Rii](https://github.com/matsui528/rii) (IVFPQ-based ANN algorithm using nanopq) 65 | - [Product quantization in Faiss and from scratch](https://www.youtube.com/watch?v=PNVJvZEkuXo) (Related tutorial) 66 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | SPHINXPROJ = nanopq 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Configuration file for the Sphinx documentation builder. 4 | # 5 | # This file does only contain a selection of the most common options. For a 6 | # full list see the documentation: 7 | # http://www.sphinx-doc.org/en/master/config 8 | 9 | # -- Path setup -------------------------------------------------------------- 10 | 11 | # If extensions (or modules to document with autodoc) are in another directory, 12 | # add these directories to sys.path here. If the directory is relative to the 13 | # documentation root, use os.path.abspath to make it absolute, like shown here. 14 | # 15 | import os 16 | import sys 17 | 18 | sys.path.insert(0, os.path.abspath("../")) 19 | 20 | # -- Project information ----------------------------------------------------- 21 | 22 | project = "nanopq" 23 | copyright = "2019, Yusuke Matsui" 24 | author = "Yusuke Matsui" 25 | 26 | # The short X.Y version 27 | version = "" 28 | # The full version, including alpha/beta/rc tags 29 | release = "" 30 | 31 | 32 | # -- General configuration --------------------------------------------------- 33 | 34 | # If your documentation needs a minimal Sphinx version, state it here. 35 | # 36 | # needs_sphinx = '1.0' 37 | 38 | # Add any Sphinx extension module names here, as strings. They can be 39 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 40 | # ones. 41 | extensions = [ 42 | "sphinx.ext.autodoc", 43 | "sphinx.ext.napoleon", 44 | ] 45 | 46 | # Add any paths that contain templates here, relative to this directory. 47 | templates_path = ["_templates"] 48 | 49 | # The suffix(es) of source filenames. 50 | # You can specify multiple suffix as a list of string: 51 | # 52 | # source_suffix = ['.rst', '.md'] 53 | source_suffix = ".rst" 54 | 55 | # The master toctree document. 56 | master_doc = "index" 57 | 58 | # The language for content autogenerated by Sphinx. Refer to documentation 59 | # for a list of supported languages. 60 | # 61 | # This is also used if you do content translation via gettext catalogs. 62 | # Usually you set "language" from the command line for these cases. 63 | # language = None 64 | 65 | # List of patterns, relative to source directory, that match files and 66 | # directories to ignore when looking for source files. 67 | # This pattern also affects html_static_path and html_extra_path . 68 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] 69 | 70 | # The name of the Pygments (syntax highlighting) style to use. 71 | pygments_style = "sphinx" 72 | 73 | 74 | # -- Options for HTML output ------------------------------------------------- 75 | 76 | # The theme to use for HTML and HTML Help pages. See the documentation for 77 | # a list of builtin themes. 78 | # 79 | html_theme = "sphinx_rtd_theme" 80 | 81 | # Theme options are theme-specific and customize the look and feel of a theme 82 | # further. For a list of options available for each theme, see the 83 | # documentation. 84 | # 85 | # html_theme_options = {} 86 | 87 | # Add any paths that contain custom static files (such as style sheets) here, 88 | # relative to this directory. They are copied after the builtin static files, 89 | # so a file named "default.css" will overwrite the builtin "default.css". 90 | # html_static_path = ['_static'] 91 | 92 | # Custom sidebar templates, must be a dictionary that maps document names 93 | # to template names. 94 | # 95 | # The default sidebars (for documents that don't match any pattern) are 96 | # defined by theme itself. Builtin themes are using these templates by 97 | # default: ``['localtoc.html', 'relations.html', 'sourcelink.html', 98 | # 'searchbox.html']``. 99 | # 100 | # html_sidebars = {} 101 | 102 | 103 | # -- Options for HTMLHelp output --------------------------------------------- 104 | 105 | # Output file base name for HTML help builder. 106 | htmlhelp_basename = "nanopqdoc" 107 | 108 | 109 | # -- Options for LaTeX output ------------------------------------------------ 110 | 111 | latex_elements = { 112 | # The paper size ('letterpaper' or 'a4paper'). 113 | # 114 | # 'papersize': 'letterpaper', 115 | # The font size ('10pt', '11pt' or '12pt'). 116 | # 117 | # 'pointsize': '10pt', 118 | # Additional stuff for the LaTeX preamble. 119 | # 120 | # 'preamble': '', 121 | # Latex figure (float) alignment 122 | # 123 | # 'figure_align': 'htbp', 124 | } 125 | 126 | # Grouping the document tree into LaTeX files. List of tuples 127 | # (source start file, target name, title, 128 | # author, documentclass [howto, manual, or own class]). 129 | latex_documents = [ 130 | (master_doc, "nanopq.tex", "nanopq Documentation", "Yusuke Matsui", "manual"), 131 | ] 132 | 133 | 134 | # -- Options for manual page output ------------------------------------------ 135 | 136 | # One entry per manual page. List of tuples 137 | # (source start file, name, description, authors, manual section). 138 | man_pages = [(master_doc, "nanopq", "nanopq Documentation", [author], 1)] 139 | 140 | 141 | # -- Options for Texinfo output ---------------------------------------------- 142 | 143 | # Grouping the document tree into Texinfo files. List of tuples 144 | # (source start file, target name, title, author, 145 | # dir menu entry, description, category) 146 | texinfo_documents = [ 147 | ( 148 | master_doc, 149 | "nanopq", 150 | "nanopq Documentation", 151 | author, 152 | "nanopq", 153 | "One line description of project.", 154 | "Miscellaneous", 155 | ), 156 | ] 157 | 158 | 159 | # -- Extension configuration ------------------------------------------------- 160 | 161 | # Napoleon settings 162 | # napoleon_include_init_with_doc = True 163 | 164 | # autodoc 165 | autodoc_member_order = "bysource" 166 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | `nanopq `_ documentation 2 | =================================================================== 3 | 4 | 5 | 6 | Installation 7 | ------------- 8 | You can install the package via pip. This library works with Python 3.5+ on linux. 9 | 10 | :: 11 | 12 | $ pip install nanopq 13 | 14 | 15 | Contents 16 | -------- 17 | 18 | .. toctree:: 19 | :maxdepth: 2 20 | 21 | 22 | source/tutorial 23 | source/api 24 | 25 | 26 | 27 | 28 | Indices and tables 29 | ================== 30 | 31 | * :ref:`genindex` 32 | * :ref:`modindex` 33 | * :ref:`search` 34 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | set SPHINXPROJ=nanopq 13 | 14 | if "%1" == "" goto help 15 | 16 | %SPHINXBUILD% >NUL 2>NUL 17 | if errorlevel 9009 ( 18 | echo. 19 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 20 | echo.installed, then set the SPHINXBUILD environment variable to point 21 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 22 | echo.may add the Sphinx directory to PATH. 23 | echo. 24 | echo.If you don't have Sphinx installed, grab it from 25 | echo.http://sphinx-doc.org/ 26 | exit /b 1 27 | ) 28 | 29 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 30 | goto end 31 | 32 | :help 33 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 34 | 35 | :end 36 | popd 37 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | scipy 3 | faiss-cpu # To render docs easily, use un-official pypi version of faiss 4 | sphinx-rtd-theme -------------------------------------------------------------------------------- /docs/source/api.rst: -------------------------------------------------------------------------------- 1 | API Reference 2 | ============== 3 | 4 | .. automodule:: nanopq 5 | 6 | 7 | 8 | Product Quantization (PQ) 9 | ------------------------- 10 | 11 | .. autoclass:: nanopq.PQ 12 | :members: 13 | :undoc-members: 14 | 15 | 16 | Distance Table 17 | ------------------------- 18 | 19 | .. autoclass:: nanopq.DistanceTable 20 | :members: 21 | :undoc-members: 22 | 23 | 24 | Optimized Product Quantization (OPQ) 25 | ------------------------------------ 26 | 27 | .. autoclass:: nanopq.OPQ 28 | :members: 29 | :undoc-members: 30 | 31 | 32 | Convert Functions to/from Faiss 33 | ------------------------------- 34 | 35 | .. autofunction:: nanopq.nanopq_to_faiss 36 | .. autofunction:: nanopq.faiss_to_nanopq 37 | -------------------------------------------------------------------------------- /docs/source/tutorial.rst: -------------------------------------------------------------------------------- 1 | Tutorial 2 | ========== 3 | 4 | Basic of PQ 5 | ------------ 6 | 7 | This tutorial shows the basic usage of Nano Product Quantization Library (nanopq). 8 | Product quantization (PQ) is one of the most widely used algorithms 9 | for memory-efficient approximated nearest neighbor search, 10 | especially in the field of computer vision. 11 | This package contains a vanilla implementation of PQ and its improved version, Optimized Product Quantization (OPQ). 12 | 13 | Let us first prepare 10,000 12-dim vectors for database, 2,000 vectors for training, 14 | and a query vector. They must be np.ndarray with np.float32. 15 | 16 | .. code-block:: python 17 | 18 | import nanopq 19 | import numpy as np 20 | 21 | X = np.random.random((10000, 12)).astype(np.float32) 22 | Xt = np.random.random((2000, 12)).astype(np.float32) 23 | query = np.random.random((12, )).astype(np.float32) 24 | 25 | The basic idea of PQ is to split an input `D`-dim vector into `M` `D/M`-dim sub-vectors. 26 | Each sub-vector is then quantized into an identifier of the nearest codeword. 27 | 28 | First of all, a PQ class (:class:`nanopq.PQ`) is instantiated with the number of sub-vector (`M`) 29 | and the number of codeword for each sub-space (`Ks`). 30 | 31 | .. code-block:: python 32 | 33 | pq = nanopq.PQ(M=4, Ks=256, verbose=True) 34 | 35 | Note that `M` is a parameter to control the trade off of accuracy and memory-cost. 36 | If you set larger `M`, you can achieve better quantization (i.e., less reconstruction error) 37 | with more memory usage. 38 | `Ks` specifies the number of codewords for quantization. 39 | This is tyically 256 so that each sub-space is represented by 8 bits = 1 byte = np.uint8. 40 | The memory cost for each pq-code is `M * log_2 Ks` bits. 41 | 42 | Next, you need to train this quantizer by running k-means clustering for each sub-space 43 | of the training vectors. 44 | 45 | .. code-block:: python 46 | 47 | pq.fit(vecs=Xt, iter=20, seed=123) 48 | 49 | If you do not have training data, you can simply use the database vectors 50 | (or a subset of them) for training: ``pq.fit(vecs=X[:1000])``. After that, you can see codewords by `pq.codewords`. 51 | 52 | Note that, alternatively, you can instantiate and train an instance in one line if you want: 53 | 54 | .. code-block:: python 55 | 56 | pq = nanopq.PQ(M=4, Ks=256).fit(vecs=Xt, iter=20, seed=123) 57 | 58 | 59 | Given this quantizer, database vectors can be encoded to PQ-codes. 60 | 61 | .. code-block:: python 62 | 63 | X_code = pq.encode(vecs=X) 64 | 65 | The resulting PQ-code (a list of identifiers) can be regarded as a memory-efficient representation of the original vector, 66 | where the shape of `X_code` is (N, M). 67 | 68 | For the querying phase, the asymmetric distance between the query 69 | and the database PQ-codes can be computed efficiently. 70 | 71 | .. code-block:: python 72 | 73 | dt = pq.dtable(query=query) # dt.dtable.shape = (4, 256) 74 | dists = dt.adist(codes=X_code) # (10000,) 75 | 76 | For each query, a distance table (`dt`) is first computed online. 77 | `dt` is an instance of :class:`nanopq.DistanceTable` class, which is a wrapper of the actual table (np.array), `dtable`. 78 | The elements of `dt.dtable` are computed by comparing each sub-vector of the query 79 | to the codewords for each sub-subspace. 80 | More specifically, `dt.dtable[m][ks]` contains the squared Euclidean distance between 81 | (1) the `m`-th sub-vector of the query and (2) the `ks`-th codeword 82 | for the `m`-th sub-space (`pq.codewords[m][ks]`). 83 | 84 | Given `dtable`, the asymmetric distance to each PQ-code can be efficiently computed (`adist`). 85 | This can be achieved by simply fetching pre-computed distance value (the element of `dtable`) 86 | using PQ-codes. 87 | 88 | Note that the above two lines can be chained in a single line. 89 | 90 | .. code-block:: python 91 | 92 | dists = pq.dtable(query=query).adist(codes=X_code) # (10000,) 93 | 94 | 95 | The nearest feature is the one with the minimum distance. 96 | 97 | .. code-block:: python 98 | 99 | min_n = np.argmin(dists) 100 | 101 | 102 | Note that the search result is similar to that 103 | by the exact squared Euclidean distance. 104 | 105 | .. code-block:: python 106 | 107 | # The first 30 results by PQ 108 | print(dists[:30]) 109 | 110 | # The first 30 results by the exact scan 111 | dists_exact = np.linalg.norm(X - query, axis=1) ** 2 112 | print(dists_exact[:30]) 113 | 114 | 115 | Decode (reconstruction) 116 | ------------------------------- 117 | 118 | Given PQ-codes, the original `D`-dim vectors can be 119 | approximately reconstructed by fetching codewords 120 | 121 | .. code-block:: python 122 | 123 | X_reconstructed = pq.decode(codes=X_code) # (10000, 12) 124 | # The following two results should be similar 125 | print(X[:3]) 126 | print(X_reconstructed[:3]) 127 | 128 | 129 | 130 | I/O by pickling 131 | ------------------ 132 | 133 | A PQ instance can be pickled. Note that PQ-codes can be pickled as well because they are 134 | just a numpy array. 135 | 136 | .. code-block:: python 137 | 138 | import pickle 139 | 140 | with open('pq.pkl', 'wb') as f: 141 | pickle.dump(pq, f) 142 | 143 | with open('pq.pkl', 'rb') as f: 144 | pq_dumped = pickle.load(f) # pq_dumped is identical to pq 145 | 146 | 147 | 148 | Optimized PQ (OPQ) 149 | ------------------- 150 | 151 | Optimized Product Quantizaion (OPQ; :class:`nanopq.OPQ`), which is an improved version of PQ, is also available 152 | with the same interface as follows. 153 | 154 | .. code-block:: python 155 | 156 | opq = nanopq.OPQ(M=4).fit(vecs=Xt, pq_iter=20, rotation_iter=10, seed=123) 157 | X_code = opq.encode(vecs=X) 158 | dists = opq.dtable(query=query).adist(codes=X_code) 159 | 160 | The resultant codes approximate the original vectors finer, 161 | that usually leads to the better search accuracy. 162 | The training of OPQ will take much longer time compared to that of PQ. 163 | 164 | 165 | Relation to PQ in faiss 166 | ----------------------- 167 | 168 | Note that 169 | `PQ is implemented in Faiss `_, 170 | whereas Faiss is one of the most powerful ANN libraries developed by the original authors of PQ: 171 | 172 | - `faiss.ProductQuantizer `_: The core component of PQ. 173 | - `faiss.IndexPQ `_: The search interface. IndexPQ = ProductQuantizer + PQ-codes. 174 | 175 | Since Faiss is highly optimized, you should use PQ in Faiss if the runtime is your most important criteria. 176 | The difference between PQ in `nanopq` and that in Faiss is highlighted as follows: 177 | 178 | - Our `nanopq` can be installed simply by pip without any third party dependencies such as Intel MKL 179 | - The core part of `nanopq` is a vanilla implementation of PQ written in a single python file. 180 | It would be easier to extend that for further applications. 181 | - A standalone OPQ is implemented. 182 | - The result of :func:`nanopq.DistanceTable.adist` is **not** sorted. This would be useful when you would like to 183 | know not only the nearest but also the other results. 184 | - The accuracy (reconstruction error) of `nanopq.PQ` and that of `faiss.IndexPQ` are `almost same `_. 185 | 186 | You can convert an instance of `nanopq.PQ` to/from that of `faiss.IndexPQ` 187 | by :func:`nanopq.nanopq_to_faiss` or :func:`nanopq.faiss_to_nanopq`. 188 | 189 | .. code-block:: python 190 | 191 | # nanopq -> faiss 192 | pq_nanopq = nanopq.PQ(M).fit(vecs=Xt) 193 | pq_faiss = nanopq.nanopq_to_faiss(pq_nanopq) # faiss.IndexPQ 194 | 195 | # faiss -> nanopq 196 | import faiss 197 | pq_faiss2 = faiss.IndexPQ(D, M, nbits) 198 | pq_faiss2.train(x=Xt) 199 | pq_faiss2.add(x=Xb) 200 | # pq_nanopq2 is an instance of nanopq.PQ. 201 | # Cb is encoded vectors 202 | pq_nanopq2, Cb = nanopq.faiss_to_nanopq(pq_faiss2) 203 | -------------------------------------------------------------------------------- /mypy.ini: -------------------------------------------------------------------------------- 1 | [mypy] 2 | ignore_missing_imports = True -------------------------------------------------------------------------------- /nanopq/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = ["PQ", "OPQ", "DistanceTable", "nanopq_to_faiss", "faiss_to_nanopq"] 2 | __version__ = "0.2.1" 3 | 4 | from .convert_faiss import faiss_to_nanopq, nanopq_to_faiss 5 | from .opq import OPQ 6 | from .pq import PQ, DistanceTable 7 | -------------------------------------------------------------------------------- /nanopq/convert_faiss.py: -------------------------------------------------------------------------------- 1 | # Try to import faiss 2 | import importlib.util 3 | 4 | spec = importlib.util.find_spec("faiss") 5 | if spec is None: 6 | pass # If faiss hasn't been installed. Just skip 7 | else: 8 | import faiss 9 | faiss_metric_map = { 10 | "l2": faiss.METRIC_L2, 11 | "dot": faiss.METRIC_INNER_PRODUCT, 12 | "angular": faiss.METRIC_INNER_PRODUCT, 13 | } 14 | 15 | 16 | import numpy as np 17 | 18 | from .opq import OPQ 19 | from .pq import PQ 20 | 21 | 22 | 23 | def nanopq_to_faiss(pq_nanopq): 24 | """Convert a :class:`nanopq.PQ` instance to `faiss.IndexPQ `_. 25 | To use this function, `faiss module needs to be installed `_. 26 | 27 | Args: 28 | pq_nanopq (nanopq.PQ): An input PQ instance. 29 | 30 | Returns: 31 | faiss.IndexPQ: A converted PQ instance, with the same codewords to the input. 32 | 33 | """ 34 | assert isinstance(pq_nanopq, PQ), "Error. pq_nanopq must be nanopq.pq" 35 | assert ( 36 | pq_nanopq.codewords is not None 37 | ), "Error. pq_nanopq.codewords must have been set beforehand" 38 | D = pq_nanopq.Ds * pq_nanopq.M 39 | nbits = {np.uint8: 8, np.uint16: 16, np.uint32: 32}[pq_nanopq.code_dtype] 40 | 41 | pq_faiss = faiss.IndexPQ(D, pq_nanopq.M, nbits, faiss_metric_map[pq_nanopq.metric]) 42 | 43 | for m in range(pq_nanopq.M): 44 | # Prepare std::vector 45 | codewords_cpp_m = faiss.Float32Vector() 46 | 47 | # Flatten m-th codewords from (Ks, Ds) to (Ks * Ds, ), then copy them to cpp 48 | faiss.copy_array_to_vector(pq_nanopq.codewords[m].reshape(-1), codewords_cpp_m) 49 | 50 | # Set the codeword to ProductQuantizer in IndexPQ 51 | pq_faiss.pq.set_params(centroids=codewords_cpp_m.data(), m=m) 52 | 53 | pq_faiss.is_trained = True 54 | 55 | return pq_faiss 56 | 57 | 58 | def faiss_to_nanopq(pq_faiss): 59 | """Convert a `faiss.IndexPQ `_ 60 | or a `faiss.IndexPreTransform `_ instance to :class:`nanopq.OPQ`. 61 | To use this function, `faiss module needs to be installed `_. 62 | 63 | Args: 64 | pq_faiss (Union[faiss.IndexPQ, faiss.IndexPreTransform]): An input PQ or OPQ instance. 65 | 66 | Returns: 67 | tuple: 68 | * Union[nanopq.PQ, nanopq.OPQ]: A converted PQ or OPQ instance, with the same codewords to the input. 69 | * np.ndarray: Stored PQ codes in the input IndexPQ, with the shape=(N, M). This will be empty if codes are not stored 70 | 71 | """ 72 | assert isinstance( 73 | pq_faiss, (faiss.IndexPQ, faiss.IndexPreTransform) 74 | ), "Error. pq_faiss must be IndexPQ or IndexPreTransform" 75 | assert pq_faiss.is_trained, "Error. pq_faiss must have been trained" 76 | 77 | if isinstance(pq_faiss, faiss.IndexPreTransform): 78 | opq_matrix: faiss.LinearTransform = faiss.downcast_VectorTransform( 79 | pq_faiss.chain.at(0) 80 | ) 81 | pq_faiss: faiss.IndexPQ = faiss.downcast_index(pq_faiss.index) 82 | pq_nanopq = OPQ(M=pq_faiss.pq.M, Ks=int(2**pq_faiss.pq.nbits)) 83 | pq_nanopq.pq.Ds = int(pq_faiss.pq.d / pq_faiss.pq.M) 84 | 85 | # Extract codewords from pq_IndexPQ.ProductQuantizer, reshape them to M*Ks*Ds 86 | codewords = faiss.vector_to_array(pq_faiss.pq.centroids).reshape( 87 | pq_nanopq.M, pq_nanopq.Ks, pq_nanopq.Ds 88 | ) 89 | 90 | pq_nanopq.pq.codewords = codewords 91 | pq_nanopq.R = ( 92 | faiss.vector_to_array(opq_matrix.A) 93 | .reshape(opq_matrix.d_out, opq_matrix.d_in) 94 | .transpose(1, 0) 95 | ) 96 | else: 97 | pq_nanopq = PQ(M=pq_faiss.pq.M, Ks=int(2**pq_faiss.pq.nbits)) 98 | pq_nanopq.Ds = int(pq_faiss.pq.d / pq_faiss.pq.M) 99 | 100 | # Extract codewords from pq_IndexPQ.ProductQuantizer, reshape them to M*Ks*Ds 101 | codewords = faiss.vector_to_array(pq_faiss.pq.centroids).reshape( 102 | pq_nanopq.M, pq_nanopq.Ks, pq_nanopq.Ds 103 | ) 104 | pq_nanopq.codewords = codewords 105 | 106 | return pq_nanopq, faiss.vector_to_array(pq_faiss.codes).reshape(-1, pq_faiss.pq.M) 107 | -------------------------------------------------------------------------------- /nanopq/opq.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | 3 | import numpy as np 4 | from typing import Literal 5 | 6 | from .pq import PQ 7 | 8 | 9 | class OPQ(object): 10 | """Pure python implementation of Optimized Product Quantization (OPQ) [Ge14]_. 11 | 12 | OPQ is a simple extension of PQ. 13 | The best rotation matrix `R` is prepared using training vectors. 14 | Each input vector is rotated via `R`, then quantized into PQ-codes 15 | in the same manner as the original PQ. 16 | 17 | .. [Ge14] T. Ge et al., "Optimized Product Quantization", IEEE TPAMI 2014 18 | 19 | Args: 20 | M (int): The number of sub-spaces 21 | Ks (int): The number of codewords for each subspace (typically 256, so that each sub-vector is quantized 22 | into 8 bits = 1 byte = uint8) 23 | metric (Literal["l2", "dot"]): Type of metric used among vectors (either 'l2' or 'dot') 24 | Note that even for 'dot', kmeans and encoding are performed in the Euclidean space. 25 | verbose (bool): Verbose flag 26 | 27 | Attributes: 28 | R (np.ndarray): Rotation matrix with the shape=(D, D) and dtype=np.float32 29 | 30 | 31 | """ 32 | 33 | def __init__(self, M, Ks=256, metric: Literal["l2", "dot"] = "l2", verbose=True): 34 | self.pq = PQ(M, Ks, metric=metric, verbose=verbose) 35 | self.R = None 36 | 37 | def __eq__(self, other): 38 | if isinstance(other, OPQ): 39 | return self.pq == other.pq and np.array_equal(self.R, other.R) 40 | else: 41 | return False 42 | 43 | @property 44 | def M(self): 45 | """int: The number of sub-space""" 46 | return self.pq.M 47 | 48 | @property 49 | def Ks(self): 50 | """int: The number of codewords for each subspace""" 51 | return self.pq.Ks 52 | 53 | @property 54 | def verbose(self): 55 | """bool: Verbose flag""" 56 | return self.pq.verbose 57 | 58 | @verbose.setter 59 | def verbose(self, v): 60 | self.pq.verbose = v 61 | 62 | @property 63 | def code_dtype(self): 64 | """object: dtype of PQ-code. Either np.uint{8, 16, 32}""" 65 | return self.pq.code_dtype 66 | 67 | @property 68 | def codewords(self): 69 | """np.ndarray: shape=(M, Ks, Ds) with dtype=np.float32. 70 | codewords[m][ks] means ks-th codeword (Ds-dim) for m-th subspace 71 | """ 72 | return self.pq.codewords 73 | 74 | @property 75 | def Ds(self): 76 | """int: The dim of each sub-vector, i.e., Ds=D/M""" 77 | return self.pq.Ds 78 | 79 | def eigenvalue_allocation(self, vecs): 80 | """Given training vectors, this function learns a rotation matrix. 81 | The rotation matrix is computed so as to minimize the distortion bound of PQ, 82 | assuming a multivariate Gaussian distribution. 83 | 84 | This function is a translation from the original MATLAB implementation to that of python 85 | http://kaiminghe.com/cvpr13/index.html 86 | 87 | Args: 88 | vecs: (np.ndarray): Training vectors with shape=(N, D) and dtype=np.float32. 89 | 90 | Returns: 91 | R: (np.ndarray) rotation matrix of shape=(D, D) with dtype=np.float32. 92 | """ 93 | _, D = vecs.shape 94 | cov = np.cov(vecs, rowvar=False) 95 | w, v = np.linalg.eig(cov) 96 | sort_ix = np.argsort(np.abs(w))[::-1] 97 | eig_vals = w[sort_ix] 98 | eig_vecs = v[:, sort_ix] 99 | 100 | assert D % self.M == 0, "input dimension must be dividable by M" 101 | Ds = D // self.M 102 | dim_tables = defaultdict(list) 103 | fvals = np.log(eig_vals + 1e-10) 104 | fvals = fvals - np.min(fvals) + 1 105 | sum_list = np.zeros(self.M) 106 | big_number = 1e10 + np.sum(fvals) 107 | 108 | cur_subidx = 0 109 | for d in range(D): 110 | dim_tables[cur_subidx].append(d) 111 | sum_list[cur_subidx] += fvals[d] 112 | if len(dim_tables[cur_subidx]) == Ds: 113 | sum_list[cur_subidx] = big_number 114 | cur_subidx = np.argmin(sum_list) 115 | 116 | dim_ordered = [] 117 | for m in range(self.M): 118 | dim_ordered.extend(dim_tables[m]) 119 | 120 | R = eig_vecs[:, dim_ordered] 121 | R = R.astype(dtype=np.float32) 122 | return R 123 | 124 | def fit( 125 | self, 126 | vecs, 127 | parametric_init=False, 128 | pq_iter=20, 129 | rotation_iter=10, 130 | seed=123, 131 | minit: Literal["random", "++", "points", "matrix"] = "points", 132 | ): 133 | """Given training vectors, this function alternatively trains 134 | (a) codewords and (b) a rotation matrix. 135 | The procedure of training codewords is same as :func:`PQ.fit`. 136 | The rotation matrix is computed so as to minimize the quantization error 137 | given codewords (Orthogonal Procrustes problem) 138 | 139 | This function is a translation from the original MATLAB implementation to that of python 140 | http://kaiminghe.com/cvpr13/index.html 141 | 142 | If you find the error message is messy, please turn off the verbose flag, then 143 | you can see the reduction of error for each iteration clearly 144 | 145 | Args: 146 | vecs (np.ndarray): Training vectors with shape=(N, D) and dtype=np.float32. 147 | parametric_init (bool): Whether to initialize rotation using parametric assumption. 148 | pq_iter (int): The number of iteration for k-means 149 | rotation_iter (int): The number of iteration for learning rotation 150 | seed (int): The seed for random process 151 | minit (Literal["random", "++", "points", "matrix"]): The method for initialization of centroids for k-means (either 'random', '++', 'points', 'matrix') 152 | 153 | Returns: 154 | object: self 155 | 156 | """ 157 | assert vecs.dtype == np.float32 158 | assert vecs.ndim == 2 159 | _, D = vecs.shape 160 | if parametric_init: 161 | self.R = self.eigenvalue_allocation(vecs) 162 | else: 163 | self.R = np.eye(D, dtype=np.float32) 164 | 165 | for i in range(rotation_iter): 166 | if self.verbose: 167 | print("OPQ rotation training: {} / {}".format(i, rotation_iter)) 168 | X = vecs @ self.R 169 | 170 | # (a) Train codewords 171 | pq_tmp = PQ(M=self.M, Ks=self.Ks, verbose=self.verbose) 172 | if i == rotation_iter - 1: 173 | # In the final loop, run the full training 174 | pq_tmp.fit(X, iter=pq_iter, seed=seed, minit=minit) 175 | else: 176 | # During the training for OPQ, just run one-pass (iter=1) PQ training 177 | pq_tmp.fit(X, iter=1, seed=seed, minit=minit) 178 | 179 | # (b) Update a rotation matrix R 180 | X_ = pq_tmp.decode(pq_tmp.encode(X)) 181 | U, s, V = np.linalg.svd(vecs.T @ X_) 182 | if self.verbose: 183 | print( 184 | "==== Reconstruction error:", np.linalg.norm(X - X_, "fro"), "====" 185 | ) 186 | if i == rotation_iter - 1: 187 | self.pq = pq_tmp 188 | break 189 | else: 190 | self.R = U @ V 191 | 192 | return self 193 | 194 | def rotate(self, vecs): 195 | """Rotate input vector(s) by the rotation matrix.` 196 | 197 | Args: 198 | vecs (np.ndarray): Input vector(s) with dtype=np.float32. 199 | The shape can be a single vector (D, ) or several vectors (N, D) 200 | 201 | Returns: 202 | np.ndarray: Rotated vectors with the same shape and dtype to the input vecs. 203 | 204 | """ 205 | assert vecs.dtype == np.float32 206 | assert vecs.ndim in [1, 2] 207 | 208 | if vecs.ndim == 2: 209 | return vecs @ self.R 210 | elif vecs.ndim == 1: 211 | return (vecs.reshape(1, -1) @ self.R).reshape(-1) 212 | 213 | def encode(self, vecs): 214 | """Rotate input vectors by :func:`OPQ.rotate`, then encode them via :func:`PQ.encode`. 215 | 216 | Args: 217 | vecs (np.ndarray): Input vectors with shape=(N, D) and dtype=np.float32. 218 | 219 | Returns: 220 | np.ndarray: PQ codes with shape=(N, M) and dtype=self.code_dtype 221 | 222 | """ 223 | return self.pq.encode(self.rotate(vecs)) 224 | 225 | def decode(self, codes): 226 | """Given PQ-codes, reconstruct original D-dimensional vectors via :func:`PQ.decode`, 227 | and applying an inverse-rotation. 228 | 229 | Args: 230 | codes (np.ndarray): PQ-cdoes with shape=(N, M) and dtype=self.code_dtype. 231 | Each row is a PQ-code 232 | 233 | Returns: 234 | np.ndarray: Reconstructed vectors with shape=(N, D) and dtype=np.float32 235 | 236 | """ 237 | # Because R is a rotation matrix (R^t * R = I), R^-1 should be R^t 238 | return self.pq.decode(codes) @ self.R.T 239 | 240 | def dtable(self, query): 241 | """Compute a distance table for a query vector. The query is 242 | first rotated by :func:`OPQ.rotate`, then DistanceTable is computed by :func:`PQ.dtable`. 243 | 244 | Args: 245 | query (np.ndarray): Input vector with shape=(D, ) and dtype=np.float32 246 | 247 | Returns: 248 | nanopq.DistanceTable: 249 | Distance table. which contains 250 | dtable with shape=(M, Ks) and dtype=np.float32 251 | 252 | """ 253 | return self.pq.dtable(self.rotate(query)) 254 | -------------------------------------------------------------------------------- /nanopq/pq.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.cluster.vq import kmeans2, vq 3 | from typing import Literal 4 | 5 | 6 | def dist_l2(q, x): 7 | return np.linalg.norm(q - x, ord=2, axis=1) ** 2 8 | 9 | 10 | def dist_ip(q, x): 11 | return q @ x.T 12 | 13 | 14 | metric_function_map = {"l2": dist_l2, "dot": dist_ip} 15 | 16 | 17 | class PQ(object): 18 | """Pure python implementation of Product Quantization (PQ) [Jegou11]_. 19 | 20 | For the indexing phase of database vectors, 21 | a `D`-dim input vector is divided into `M` `D`/`M`-dim sub-vectors. 22 | Each sub-vector is quantized into a small integer via `Ks` codewords. 23 | For the querying phase, given a new `D`-dim query vector, the distance beween the query 24 | and the database PQ-codes are efficiently approximated via Asymmetric Distance. 25 | 26 | All vectors must be np.ndarray with np.float32 27 | 28 | .. [Jegou11] H. Jegou et al., "Product Quantization for Nearest Neighbor Search", IEEE TPAMI 2011 29 | 30 | Args: 31 | M (int): The number of sub-space 32 | Ks (int): The number of codewords for each subspace 33 | (typically 256, so that each sub-vector is quantized 34 | into 8 bits = 1 byte = uint8) 35 | metric (Literal["l2", "dot"]): Type of metric used among vectors (either 'l2' or 'dot') 36 | Note that even for 'dot', kmeans and encoding are performed in the Euclidean space. 37 | verbose (bool): Verbose flag 38 | 39 | Attributes: 40 | M (int): The number of sub-space 41 | Ks (int): The number of codewords for each subspace 42 | metric (Literal["l2", "dot"]): Type of metric used among vectors 43 | verbose (bool): Verbose flag 44 | code_dtype (object): dtype of PQ-code. Either np.uint{8, 16, 32} 45 | codewords (np.ndarray): shape=(M, Ks, Ds) with dtype=np.float32. 46 | codewords[m][ks] means ks-th codeword (Ds-dim) for m-th subspace 47 | Ds (int): The dim of each sub-vector, i.e., Ds=D/M 48 | 49 | """ 50 | 51 | def __init__(self, M, Ks=256, metric: Literal["l2", "dot"] = "l2", verbose=True): 52 | assert 0 < Ks <= 2**32 53 | assert metric in ["l2", "dot"] 54 | self.M, self.Ks, self.metric, self.verbose = M, Ks, metric, verbose 55 | self.code_dtype = ( 56 | np.uint8 if Ks <= 2**8 else (np.uint16 if Ks <= 2**16 else np.uint32) 57 | ) 58 | self.codewords = None 59 | self.Ds = None 60 | 61 | if verbose: 62 | print( 63 | "M: {}, Ks: {}, metric : {}, code_dtype: {}".format( 64 | M, Ks, self.code_dtype, metric 65 | ) 66 | ) 67 | 68 | def __eq__(self, other): 69 | if isinstance(other, PQ): 70 | return ( 71 | self.M, 72 | self.Ks, 73 | self.metric, 74 | self.verbose, 75 | self.code_dtype, 76 | self.Ds, 77 | ) == ( 78 | other.M, 79 | other.Ks, 80 | other.metric, 81 | other.verbose, 82 | other.code_dtype, 83 | other.Ds, 84 | ) and np.array_equal(self.codewords, other.codewords) 85 | else: 86 | return False 87 | 88 | def fit( 89 | self, 90 | vecs, 91 | iter=20, 92 | seed=123, 93 | minit: Literal["random", "++", "points", "matrix"] = "points", 94 | ): 95 | """Given training vectors, run k-means for each sub-space and create 96 | codewords for each sub-space. 97 | 98 | This function should be run once first of all. 99 | 100 | Args: 101 | vecs (np.ndarray): Training vectors with shape=(N, D) and dtype=np.float32. 102 | iter (int): The number of iteration for k-means 103 | seed (int): The seed for random process 104 | minit (Literal["random", "++", "points", "matrix"]): The method for initialization of centroids for k-means (either 'random', '++', 'points', 'matrix') 105 | 106 | Returns: 107 | object: self 108 | 109 | """ 110 | assert vecs.dtype == np.float32 111 | assert vecs.ndim == 2 112 | N, D = vecs.shape 113 | assert self.Ks < N, "the number of training vector should be more than Ks" 114 | assert D % self.M == 0, "input dimension must be dividable by M" 115 | assert minit in ["random", "++", "points", "matrix"] 116 | self.Ds = int(D / self.M) 117 | 118 | np.random.seed(seed) 119 | if self.verbose: 120 | print("iter: {}, seed: {}".format(iter, seed)) 121 | 122 | # [m][ks][ds]: m-th subspace, ks-the codeword, ds-th dim 123 | self.codewords = np.zeros((self.M, self.Ks, self.Ds), dtype=np.float32) 124 | for m in range(self.M): 125 | if self.verbose: 126 | print("Training the subspace: {} / {}".format(m, self.M)) 127 | vecs_sub = vecs[:, m * self.Ds : (m + 1) * self.Ds] 128 | self.codewords[m], _ = kmeans2(vecs_sub, self.Ks, iter=iter, minit=minit) 129 | return self 130 | 131 | def encode(self, vecs): 132 | """Encode input vectors into PQ-codes. 133 | 134 | Args: 135 | vecs (np.ndarray): Input vectors with shape=(N, D) and dtype=np.float32. 136 | 137 | Returns: 138 | np.ndarray: PQ codes with shape=(N, M) and dtype=self.code_dtype 139 | 140 | """ 141 | assert vecs.dtype == np.float32 142 | assert vecs.ndim == 2 143 | N, D = vecs.shape 144 | assert D == self.Ds * self.M, "input dimension must be Ds * M" 145 | 146 | # codes[n][m] : code of n-th vec, m-th subspace 147 | codes = np.empty((N, self.M), dtype=self.code_dtype) 148 | for m in range(self.M): 149 | if self.verbose: 150 | print("Encoding the subspace: {} / {}".format(m, self.M)) 151 | vecs_sub = vecs[:, m * self.Ds : (m + 1) * self.Ds] 152 | codes[:, m], _ = vq(vecs_sub, self.codewords[m]) 153 | 154 | return codes 155 | 156 | def decode(self, codes): 157 | """Given PQ-codes, reconstruct original D-dimensional vectors 158 | approximately by fetching the codewords. 159 | 160 | Args: 161 | codes (np.ndarray): PQ-cdoes with shape=(N, M) and dtype=self.code_dtype. 162 | Each row is a PQ-code 163 | 164 | Returns: 165 | np.ndarray: Reconstructed vectors with shape=(N, D) and dtype=np.float32 166 | 167 | """ 168 | assert codes.ndim == 2 169 | N, M = codes.shape 170 | assert M == self.M 171 | assert codes.dtype == self.code_dtype 172 | 173 | vecs = np.empty((N, self.Ds * self.M), dtype=np.float32) 174 | for m in range(self.M): 175 | vecs[:, m * self.Ds : (m + 1) * self.Ds] = self.codewords[m][codes[:, m], :] 176 | 177 | return vecs 178 | 179 | def dtable(self, query): 180 | """Compute a distance table for a query vector. 181 | The distances are computed by comparing each sub-vector of the query 182 | to the codewords for each sub-subspace. 183 | `dtable[m][ks]` contains the squared Euclidean distance between 184 | the `m`-th sub-vector of the query and the `ks`-th codeword 185 | for the `m`-th sub-space (`self.codewords[m][ks]`). 186 | 187 | Args: 188 | query (np.ndarray): Input vector with shape=(D, ) and dtype=np.float32 189 | 190 | Returns: 191 | nanopq.DistanceTable: 192 | Distance table. which contains 193 | dtable with shape=(M, Ks) and dtype=np.float32 194 | 195 | """ 196 | assert query.dtype == np.float32 197 | assert query.ndim == 1, "input must be a single vector" 198 | (D,) = query.shape 199 | assert D == self.Ds * self.M, "input dimension must be Ds * M" 200 | 201 | # dtable[m] : distance between m-th subvec and m-th codewords (m-th subspace) 202 | # dtable[m][ks] : distance between m-th subvec and ks-th codeword of m-th codewords 203 | dtable = np.empty((self.M, self.Ks), dtype=np.float32) 204 | for m in range(self.M): 205 | query_sub = query[m * self.Ds : (m + 1) * self.Ds] 206 | dtable[m, :] = metric_function_map[self.metric]( 207 | query_sub, self.codewords[m] 208 | ) 209 | 210 | # In case of L2, the above line would be: 211 | # dtable[m, :] = np.linalg.norm(self.codewords[m] - query_sub, axis=1) ** 2 212 | 213 | return DistanceTable(dtable, metric=self.metric) 214 | 215 | 216 | class DistanceTable(object): 217 | """Distance table from query to codewords. 218 | Given a query vector, a PQ/OPQ instance compute this DistanceTable class 219 | using :func:`PQ.dtable` or :func:`OPQ.dtable`. 220 | The Asymmetric Distance from query to each database codes can be computed 221 | by :func:`DistanceTable.adist`. 222 | 223 | Args: 224 | dtable (np.ndarray): Distance table with shape=(M, Ks) and dtype=np.float32 225 | computed by :func:`PQ.dtable` or :func:`OPQ.dtable` 226 | metric (Literal["l2", "dot"]): metric type to calculate distance 227 | 228 | Attributes: 229 | dtable (np.ndarray): Distance table with shape=(M, Ks) and dtype=np.float32. 230 | Note that dtable[m][ks] contains the squared Euclidean distance between 231 | (1) m-th sub-vector of query and (2) ks-th codeword for m-th subspace. 232 | 233 | """ 234 | 235 | def __init__(self, dtable, metric: Literal["l2", "dot"] = "l2"): 236 | assert dtable.ndim == 2 237 | assert dtable.dtype == np.float32 238 | assert metric in ["l2", "dot"] 239 | self.dtable = dtable 240 | self.metric = metric 241 | 242 | def adist(self, codes): 243 | """Given PQ-codes, compute Asymmetric Distances between the query (self.dtable) 244 | and the PQ-codes. 245 | 246 | Args: 247 | codes (np.ndarray): PQ codes with shape=(N, M) and 248 | dtype=pq.code_dtype where pq is a pq instance that creates the codes 249 | 250 | Returns: 251 | np.ndarray: Asymmetric Distances with shape=(N, ) and dtype=np.float32 252 | 253 | """ 254 | 255 | assert codes.ndim == 2 256 | N, M = codes.shape 257 | assert M == self.dtable.shape[0] 258 | 259 | # Fetch distance values using codes. 260 | dists = np.sum(self.dtable[range(M), codes], axis=1) 261 | 262 | # The above line is equivalent to the followings: 263 | # dists = np.zeros((N, )).astype(np.float32) 264 | # for n in range(N): 265 | # for m in range(M): 266 | # dists[n] += self.dtable[m][codes[n][m]] 267 | 268 | return dists 269 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.pysen] 2 | version = "0.10" 3 | 4 | [tool.pysen.lint] 5 | enable_black = true 6 | enable_flake8 = true 7 | enable_isort = true 8 | enable_mypy = true 9 | mypy_preset = "strict" 10 | line_length = 88 11 | py_version = "py37" 12 | [[tool.pysen.lint.mypy_targets]] 13 | paths = ["."] -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | mypy -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | from setuptools import find_packages, setup 4 | 5 | with open("README.md") as f: 6 | readme = f.read() 7 | 8 | with open("nanopq/__init__.py") as f: 9 | version = re.search(r"__version__ = \"(.*?)\"", f.read()).group(1) 10 | 11 | setup( 12 | name="nanopq", 13 | version=version, 14 | description="Pure python implementation of product quantization for nearest neighbor search ", 15 | long_description=readme, 16 | long_description_content_type="text/markdown", 17 | author="Yusuke Matsui", 18 | author_email="matsui528@gmail.com", 19 | url="https://github.com/matsui528/nanopq", 20 | license="MIT", 21 | packages=find_packages(exclude=("tests", "docs")), 22 | install_requires=["numpy", "scipy"], 23 | extras_require={ 24 | "dev": ["mypy"], 25 | }, 26 | ) 27 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matsui528/nanopq/7f4afc9bac69b9f6b587fc19ac48a4302b830c59/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_convert_faiss.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from pathlib import Path 3 | 4 | sys.path.append(str(Path(__file__).resolve().parent)) 5 | import importlib.util 6 | import unittest 7 | 8 | import nanopq 9 | import numpy as np 10 | 11 | spec = importlib.util.find_spec("faiss") 12 | if spec is None: 13 | raise unittest.SkipTest( 14 | "Cannot find the faiss module. Skipt the test for convert_faiss" 15 | ) 16 | else: 17 | import faiss 18 | 19 | print("faiss version:", faiss.__version__) 20 | 21 | 22 | class TestSuite(unittest.TestCase): 23 | def setUp(self): 24 | np.random.seed(123) 25 | 26 | def test_nanopq_to_faiss(self): 27 | D, M, Ks = 32, 4, 256 28 | Nt, Nb, Nq = 2000, 10000, 100 29 | Xt = np.random.rand(Nt, D).astype(np.float32) 30 | Xb = np.random.rand(Nb, D).astype(np.float32) 31 | Xq = np.random.rand(Nq, D).astype(np.float32) 32 | pq_nanopq = nanopq.PQ(M=M, Ks=Ks) 33 | pq_nanopq.fit(vecs=Xt) 34 | 35 | with self.assertRaises(AssertionError): # opq is not supported 36 | opq = nanopq.OPQ(M=M, Ks=Ks) 37 | nanopq.nanopq_to_faiss(opq) 38 | 39 | pq_faiss = nanopq.nanopq_to_faiss(pq_nanopq) # IndexPQ 40 | 41 | # Encoded results should be same 42 | Cb_nanopq = pq_nanopq.encode(vecs=Xb) 43 | Cb_faiss = pq_faiss.pq.compute_codes(x=Xb) # ProductQuantizer in IndexPQ 44 | self.assertTrue(np.array_equal(Cb_nanopq, Cb_faiss)) 45 | 46 | # Search result should be same 47 | topk = 10 48 | pq_faiss.add(Xb) 49 | _, ids1 = pq_faiss.search(x=Xq, k=topk) 50 | ids2 = np.array( 51 | [ 52 | np.argsort(pq_nanopq.dtable(query=xq).adist(codes=Cb_nanopq))[:topk] 53 | for xq in Xq 54 | ] 55 | ) 56 | 57 | self.assertTrue(np.array_equal(ids1, ids2)) 58 | 59 | def test_faiss_to_nanopq_pq(self): 60 | D, M, Ks = 32, 4, 256 61 | Nt, Nb, Nq = 2000, 10000, 100 62 | nbits = int(np.log2(Ks)) 63 | assert nbits == 8 64 | Xt = np.random.rand(Nt, D).astype(np.float32) 65 | Xb = np.random.rand(Nb, D).astype(np.float32) 66 | Xq = np.random.rand(Nq, D).astype(np.float32) 67 | 68 | pq_faiss = faiss.IndexPQ(D, M, nbits) 69 | pq_faiss.train(x=Xt) 70 | pq_faiss.add(x=Xb) 71 | 72 | pq_nanopq, Cb_faiss = nanopq.faiss_to_nanopq(pq_faiss=pq_faiss) 73 | self.assertIsInstance(pq_nanopq, nanopq.PQ) 74 | self.assertEqual(pq_nanopq.codewords.shape, (M, Ks, int(D / M))) 75 | 76 | # Encoded results should be same 77 | Cb_nanopq = pq_nanopq.encode(vecs=Xb) 78 | self.assertTrue(np.array_equal(Cb_nanopq, Cb_faiss)) 79 | 80 | # Search result should be same 81 | topk = 100 82 | _, ids1 = pq_faiss.search(x=Xq, k=topk) 83 | ids2 = np.array( 84 | [ 85 | np.argsort(pq_nanopq.dtable(query=xq).adist(codes=Cb_nanopq))[:topk] 86 | for xq in Xq 87 | ] 88 | ) 89 | self.assertTrue(np.array_equal(ids1, ids2)) 90 | 91 | def test_faiss_to_nanopq_opq(self): 92 | D, M, Ks = 32, 4, 256 93 | Nt, Nb, Nq = 2000, 10000, 100 94 | nbits = int(np.log2(Ks)) 95 | assert nbits == 8 96 | Xt = np.random.rand(Nt, D).astype(np.float32) 97 | Xb = np.random.rand(Nb, D).astype(np.float32) 98 | Xq = np.random.rand(Nq, D).astype(np.float32) 99 | 100 | pq_faiss = faiss.IndexPQ(D, M, nbits) 101 | opq_matrix = faiss.OPQMatrix(D, M=M) 102 | pq_faiss = faiss.IndexPreTransform(opq_matrix, pq_faiss) 103 | pq_faiss.train(x=Xt) 104 | pq_faiss.add(x=Xb) 105 | 106 | pq_nanopq, Cb_faiss = nanopq.faiss_to_nanopq(pq_faiss=pq_faiss) 107 | self.assertIsInstance(pq_nanopq, nanopq.OPQ) 108 | self.assertEqual(pq_nanopq.codewords.shape, (M, Ks, int(D / M))) 109 | 110 | # Encoded results should be same 111 | Cb_nanopq = pq_nanopq.encode(vecs=Xb) 112 | self.assertTrue(np.array_equal(Cb_nanopq, Cb_faiss)) 113 | 114 | # Search result should be same 115 | topk = 100 116 | _, ids1 = pq_faiss.search(x=Xq, k=topk) 117 | ids2 = np.array( 118 | [ 119 | np.argsort(pq_nanopq.dtable(query=xq).adist(codes=Cb_nanopq))[:topk] 120 | for xq in Xq 121 | ] 122 | ) 123 | self.assertTrue(np.array_equal(ids1, ids2)) 124 | 125 | def test_faiss_nanopq_compare_accuracy(self): 126 | D, M, Ks = 32, 4, 256 127 | Nt, Nb, Nq = 20000, 10000, 100 128 | nbits = int(np.log2(Ks)) 129 | assert nbits == 8 130 | Xt = np.random.rand(Nt, D).astype(np.float32) 131 | Xb = np.random.rand(Nb, D).astype(np.float32) 132 | Xq = np.random.rand(Nq, D).astype(np.float32) 133 | 134 | pq_faiss = faiss.IndexPQ(D, M, nbits) 135 | pq_faiss.train(x=Xt) 136 | Cb_faiss = pq_faiss.pq.compute_codes(Xb) 137 | Xb_faiss_ = pq_faiss.pq.decode(Cb_faiss) 138 | 139 | pq_nanopq = nanopq.PQ(M=M, Ks=Ks) 140 | pq_nanopq.fit(vecs=Xt) 141 | Cb_nanopq = pq_nanopq.encode(vecs=Xb) 142 | Xb_nanopq_ = pq_nanopq.decode(codes=Cb_nanopq) 143 | 144 | # Reconstruction error should be almost identical 145 | avg_relative_error_faiss = ((Xb - Xb_faiss_) ** 2).sum() / (Xb**2).sum() 146 | avg_relative_error_nanopq = ((Xb - Xb_nanopq_) ** 2).sum() / (Xb**2).sum() 147 | diff_rel = ( 148 | avg_relative_error_faiss - avg_relative_error_nanopq 149 | ) / avg_relative_error_faiss 150 | diff_rel = np.sqrt(diff_rel**2) 151 | print("avg_rel_error_faiss:", avg_relative_error_faiss) 152 | print("avg_rel_error_nanopq:", avg_relative_error_nanopq) 153 | print("diff rel:", diff_rel) 154 | 155 | self.assertLess(diff_rel, 0.01) 156 | 157 | 158 | if __name__ == "__main__": 159 | unittest.main() 160 | -------------------------------------------------------------------------------- /tests/test_opq.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from pathlib import Path 3 | 4 | sys.path.append(str(Path(__file__).resolve().parent)) 5 | import unittest 6 | 7 | import nanopq 8 | import numpy as np 9 | 10 | 11 | class TestSuite(unittest.TestCase): 12 | def setUp(self): 13 | np.random.seed(123) 14 | 15 | def test_property(self): 16 | opq = nanopq.OPQ(M=4, Ks=256) 17 | self.assertEqual( 18 | (opq.M, opq.Ks, opq.verbose, opq.code_dtype), 19 | (opq.pq.M, opq.pq.Ks, opq.pq.verbose, opq.pq.code_dtype), 20 | ) 21 | 22 | def test_fit(self): 23 | N, D, M, Ks = 100, 12, 4, 10 24 | X = np.random.random((N, D)).astype(np.float32) 25 | opq = nanopq.OPQ(M=M, Ks=Ks) 26 | opq.fit(X) 27 | self.assertEqual(opq.Ds, D / M) 28 | self.assertEqual(opq.codewords.shape, (M, Ks, D / M)) 29 | self.assertEqual(opq.R.shape, (D, D)) 30 | 31 | opq2 = nanopq.OPQ(M=M, Ks=Ks).fit(X) # Can be called as a chain 32 | self.assertTrue(np.allclose(opq.codewords, opq2.codewords)) 33 | 34 | def test_eq(self): 35 | import copy 36 | 37 | N, D, M, Ks = 100, 12, 4, 10 38 | X = np.random.random((N, D)).astype(np.float32) 39 | opq1 = nanopq.OPQ(M=M, Ks=Ks) 40 | opq2 = nanopq.OPQ(M=M, Ks=Ks) 41 | opq3 = copy.deepcopy(opq1) 42 | opq4 = nanopq.OPQ(M=M, Ks=2 * Ks) 43 | self.assertTrue(opq1 == opq1) 44 | self.assertTrue(opq1 == opq2) 45 | self.assertTrue(opq1 == opq3) 46 | self.assertTrue(opq1 != opq4) 47 | 48 | opq1.fit(X) 49 | opq2.fit(X) 50 | opq3 = copy.deepcopy(opq1) 51 | opq4.fit(X) 52 | self.assertTrue(opq1 == opq1) 53 | self.assertTrue(opq1 == opq2) 54 | self.assertTrue(opq1 == opq3) 55 | self.assertTrue(opq1 != opq4) 56 | 57 | def test_rotate(self): 58 | N, D, M, Ks = 100, 12, 4, 10 59 | X = np.random.random((N, D)).astype(np.float32) 60 | opq = nanopq.OPQ(M=M, Ks=Ks) 61 | opq.fit(X) 62 | rotated_vec = opq.rotate(X[0]) 63 | rotated_vecs = opq.rotate(X[:3]) 64 | self.assertEqual(rotated_vec.shape, (D,)) 65 | self.assertEqual(rotated_vecs.shape, (3, D)) 66 | 67 | # Because R is a rotation matrix (R^t * R = I), R^t should be R^(-1) 68 | self.assertAlmostEqual( 69 | np.linalg.norm(opq.R.T - np.linalg.inv(opq.R)), 0.0, places=3 70 | ) 71 | 72 | def test_parametric_init(self): 73 | N, D, M, Ks = 100, 12, 2, 20 74 | X = np.random.random((N, D)).astype(np.float32) 75 | opq = nanopq.OPQ(M=M, Ks=Ks) 76 | opq.fit(X, parametric_init=False, rotation_iter=1) 77 | err_init = np.linalg.norm(X - opq.decode(opq.encode(X))) 78 | 79 | opq = nanopq.OPQ(M=M, Ks=Ks) 80 | opq.fit(X, parametric_init=True, rotation_iter=1) 81 | err = np.linalg.norm(X - opq.decode(opq.encode(X))) 82 | 83 | self.assertLess(err, err_init) 84 | 85 | 86 | if __name__ == "__main__": 87 | unittest.main() 88 | -------------------------------------------------------------------------------- /tests/test_pq.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from pathlib import Path 3 | 4 | sys.path.append(str(Path(__file__).resolve().parent)) 5 | import unittest 6 | 7 | import nanopq 8 | import numpy as np 9 | 10 | 11 | class TestSuite(unittest.TestCase): 12 | def setUp(self): 13 | np.random.seed(123) 14 | 15 | def test_instantiate(self): 16 | pq1 = nanopq.PQ(M=4, Ks=256) 17 | pq2 = nanopq.PQ(M=4, Ks=500) 18 | pq3 = nanopq.PQ(M=4, Ks=2**16 + 10) 19 | self.assertEqual(pq1.code_dtype, np.uint8) 20 | self.assertEqual(pq2.code_dtype, np.uint16) 21 | self.assertEqual(pq3.code_dtype, np.uint32) 22 | 23 | def test_fit(self): 24 | N, D, M, Ks = 100, 12, 4, 10 25 | X = np.random.random((N, D)).astype(np.float32) 26 | pq = nanopq.PQ(M=M, Ks=Ks) 27 | pq.fit(X) 28 | self.assertEqual(pq.Ds, D / M) 29 | self.assertEqual(pq.codewords.shape, (M, Ks, D / M)) 30 | 31 | pq2 = nanopq.PQ(M=M, Ks=Ks).fit(X) # Can be called as a chain 32 | self.assertTrue(np.allclose(pq.codewords, pq2.codewords)) 33 | 34 | def test_eq(self): 35 | import copy 36 | 37 | N, D, M, Ks = 100, 12, 4, 10 38 | X = np.random.random((N, D)).astype(np.float32) 39 | pq1 = nanopq.PQ(M=M, Ks=Ks) 40 | pq2 = nanopq.PQ(M=M, Ks=Ks) 41 | pq3 = copy.deepcopy(pq1) 42 | pq4 = nanopq.PQ(M=M, Ks=2 * Ks) 43 | self.assertTrue(pq1 == pq1) 44 | self.assertTrue(pq1 == pq2) 45 | self.assertTrue(pq1 == pq3) 46 | self.assertTrue(pq1 != pq4) 47 | 48 | pq1.fit(X) 49 | pq2.fit(X) 50 | pq3 = copy.deepcopy(pq1) 51 | pq4.fit(X) 52 | self.assertTrue(pq1 == pq1) 53 | self.assertTrue(pq1 == pq2) 54 | self.assertTrue(pq1 == pq3) 55 | self.assertTrue(pq1 != pq4) 56 | 57 | def test_encode_decode(self): 58 | N, D, M, Ks = 100, 12, 4, 10 59 | X = np.random.random((N, D)).astype(np.float32) 60 | pq = nanopq.PQ(M=M, Ks=Ks) 61 | pq.fit(X) 62 | X_ = pq.encode(X) # encoded 63 | self.assertEqual(X_.shape, (N, M)) 64 | self.assertEqual(X_.dtype, np.uint8) 65 | X__ = pq.decode(X_) # reconstructed 66 | self.assertEqual(X.shape, X__.shape) 67 | # The original X and the reconstructed X__ should be similar 68 | self.assertTrue(np.linalg.norm(X - X__) ** 2 / np.linalg.norm(X) ** 2 < 0.1) 69 | 70 | def test_search(self): 71 | N, D, M, Ks = 100, 12, 4, 10 72 | X = np.random.random((N, D)).astype(np.float32) 73 | pq = nanopq.PQ(M=M, Ks=Ks) 74 | pq.fit(X) 75 | X_ = pq.encode(X) 76 | q = X[13] 77 | dtbl = pq.dtable(q) 78 | self.assertEqual(dtbl.dtable.shape, (M, Ks)) 79 | dists = dtbl.adist(X_) 80 | self.assertEqual(len(dists), N) 81 | self.assertEqual(np.argmin(dists), 13) 82 | dists2 = pq.dtable(q).adist(X_) # can be chained 83 | self.assertAlmostEqual(dists.tolist(), dists2.tolist()) 84 | 85 | def test_pickle(self): 86 | import pickle 87 | 88 | N, D, M, Ks = 100, 12, 4, 10 89 | X = np.random.random((N, D)).astype(np.float32) 90 | pq = nanopq.PQ(M=M, Ks=Ks) 91 | pq.fit(X) 92 | dumped = pickle.dumps(pq) 93 | pq2 = pickle.loads(dumped) 94 | self.assertEqual( 95 | (pq.M, pq.Ks, pq.verbose, pq.code_dtype, pq.Ds), 96 | (pq2.M, pq2.Ks, pq2.verbose, pq2.code_dtype, pq2.Ds), 97 | ) 98 | self.assertTrue(np.allclose(pq.codewords, pq2.codewords)) 99 | self.assertTrue(pq == pq2) 100 | 101 | def test_ip(self): 102 | N, D, M, Ks = 100, 12, 4, 10 103 | X = np.random.random((N, D)).astype(np.float32) 104 | pq = nanopq.PQ(M=M, Ks=Ks, metric="dot") 105 | pq.fit(X) 106 | X_ = pq.encode(X) 107 | q = X[13] 108 | dist1 = pq.dtable(q).adist(X_) 109 | dtable = np.empty((pq.M, pq.Ks), dtype=np.float32) 110 | for m in range(pq.M): 111 | query_sub = q[m * pq.Ds : (m + 1) * pq.Ds] 112 | dtable[m, :] = np.matmul(pq.codewords[m], query_sub[None, :].T).sum(axis=-1) 113 | dist2 = np.sum(dtable[range(M), X_], axis=1) 114 | self.assertTrue((dist1 == dist2).all()) 115 | self.assertTrue(abs(np.mean(np.matmul(X, q[:, None]).squeeze() - dist1)) < 1e-7) 116 | 117 | 118 | if __name__ == "__main__": 119 | unittest.main() 120 | --------------------------------------------------------------------------------