The response has been limited to 50k tokens of the smallest files in the repo. You can remove this limitation by removing the max tokens filter.
├── .github
    └── workflows
    │   └── build.yml
├── .gitignore
├── ALGO_PARAMS.md
├── CMakeLists.txt
├── LICENSE
├── MANIFEST.in
├── Makefile
├── README.md
├── TESTING_RECALL.md
├── examples
    ├── cpp
    │   ├── EXAMPLES.md
    │   ├── example_epsilon_search.cpp
    │   ├── example_filter.cpp
    │   ├── example_mt_filter.cpp
    │   ├── example_mt_replace_deleted.cpp
    │   ├── example_mt_search.cpp
    │   ├── example_multivector_search.cpp
    │   ├── example_replace_deleted.cpp
    │   └── example_search.cpp
    └── python
    │   ├── EXAMPLES.md
    │   ├── example.py
    │   ├── example_filter.py
    │   ├── example_replace_deleted.py
    │   ├── example_search.py
    │   ├── example_serialization.py
    │   └── pyw_hnswlib.py
├── hnswlib
    ├── bruteforce.h
    ├── hnswalg.h
    ├── hnswlib.h
    ├── space_ip.h
    ├── space_l2.h
    ├── stop_condition.h
    └── visited_list_pool.h
├── pyproject.toml
├── python_bindings
    ├── LazyIndex.py
    ├── __init__.py
    ├── bindings.cpp
    ├── setup.py
    └── tests
    │   └── bindings_test_bf_index.py
├── setup.py
└── tests
    ├── cpp
        ├── download_bigann.py
        ├── epsilon_search_test.cpp
        ├── main.cpp
        ├── multiThreadLoad_test.cpp
        ├── multiThread_replace_test.cpp
        ├── multivector_search_test.cpp
        ├── searchKnnCloserFirst_test.cpp
        ├── searchKnnWithFilter_test.cpp
        ├── sift_1b.cpp
        ├── sift_test.cpp
        ├── update_gen_data.py
        └── updates_test.cpp
    └── python
        ├── bindings_test.py
        ├── bindings_test_filter.py
        ├── bindings_test_getdata.py
        ├── bindings_test_labels.py
        ├── bindings_test_metadata.py
        ├── bindings_test_pickle.py
        ├── bindings_test_recall.py
        ├── bindings_test_replace.py
        ├── bindings_test_resize.py
        ├── bindings_test_spaces.py
        ├── bindings_test_stress_mt_replace.py
        ├── draw_git_test_plots.py
        ├── git_tester.py
        └── speedtest.py


/.github/workflows/build.yml:
--------------------------------------------------------------------------------
 1 | name: HNSW CI
 2 | 
 3 | on: [push, pull_request]
 4 | 
 5 | jobs:
 6 |   test_python:
 7 |     runs-on: ${{matrix.os}}
 8 |     strategy:
 9 |       matrix:
10 |         os: [ubuntu-latest, windows-latest, macos-latest]
11 |         python-version: ["3.7", "3.8", "3.9", "3.10"]
12 |     steps:
13 |       - uses: actions/checkout@v3
14 |       - uses: actions/setup-python@v4
15 |         with:
16 |           python-version: ${{ matrix.python-version }}
17 |           
18 |       - name: Build and install
19 |         run: python -m pip install .
20 |       
21 |       - name: Test
22 |         timeout-minutes: 15
23 |         run: |
24 |           python -m unittest discover -v --start-directory examples/python --pattern "example*.py"
25 |           python -m unittest discover -v --start-directory tests/python --pattern "bindings_test*.py"
26 |   
27 |   test_cpp:
28 |     runs-on: ${{matrix.os}}
29 |     strategy:
30 |       matrix:
31 |         os: [ubuntu-latest, windows-latest, macos-latest]
32 |     steps:
33 |       - uses: actions/checkout@v3
34 |       - uses: actions/setup-python@v4
35 |         with:
36 |           python-version: "3.10"
37 | 
38 |       - name: Build
39 |         run: |
40 |           mkdir build
41 |           cd build
42 |           cmake ..
43 |           if [ "$RUNNER_OS" == "Windows" ]; then
44 |             cmake --build ./ --config Release
45 |           else
46 |             make
47 |           fi
48 |         shell: bash
49 | 
50 |       - name: Prepare test data
51 |         run: |
52 |           pip install numpy
53 |           cd tests/cpp/
54 |           python update_gen_data.py
55 |         shell: bash
56 |       
57 |       - name: Test
58 |         timeout-minutes: 15
59 |         run: |
60 |           cd build
61 |           if [ "$RUNNER_OS" == "Windows" ]; then
62 |             cp ./Release/* ./
63 |           fi
64 |           ./example_search
65 |           ./example_filter
66 |           ./example_replace_deleted
67 |           ./example_mt_search
68 |           ./example_mt_filter
69 |           ./example_mt_replace_deleted
70 |           ./example_multivector_search
71 |           ./example_epsilon_search
72 |           ./searchKnnCloserFirst_test
73 |           ./searchKnnWithFilter_test
74 |           ./multiThreadLoad_test
75 |           ./multiThread_replace_test
76 |           ./test_updates
77 |           ./test_updates update
78 |           ./multivector_search_test
79 |           ./epsilon_search_test
80 |         shell: bash
81 | 


--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
 1 | hnswlib.egg-info/
 2 | build/
 3 | dist/
 4 | tmp/
 5 | python_bindings/tests/__pycache__/
 6 | *.pyd
 7 | hnswlib.cpython*.so
 8 | var/
 9 | .idea/
10 | .vscode/
11 | .vs/
12 | **.DS_Store
13 | *.pyc
14 | 


--------------------------------------------------------------------------------
/ALGO_PARAMS.md:
--------------------------------------------------------------------------------
 1 | # HNSW algorithm parameters
 2 | 
 3 | ## Search parameters:
 4 | * ```ef``` - the size of the dynamic list for the nearest neighbors (used during the search). Higher ```ef```
 5 | leads to more accurate but slower search. ```ef``` cannot be set lower than the number of queried nearest neighbors
 6 | ```k```. The value of ```ef``` can be anything between ```k``` and the size of the dataset.
 7 | * ```k``` number of nearest neighbors to be returned as the result.
 8 | The ```knn_query``` function returns two numpy arrays, containing labels and distances to the k found nearest 
 9 | elements for the queries. Note that in case the algorithm is not able to find ```k``` neighbors for all of the queries
10 | (this can be due to problems with the graph or ```k``` > size of the dataset), an exception is thrown.
11 | 
12 | An example of tuning the parameters can be found in [TESTING_RECALL.md](TESTING_RECALL.md)
13 | 
14 | ## Construction parameters:
15 | * ```M``` - the number of bi-directional links created for every new element during construction. Reasonable range for ```M``` 
16 | is 2-100. Higher ```M``` works better on datasets with high intrinsic dimensionality and/or high recall, while low ```M``` works 
17 | better for datasets with low intrinsic dimensionality and/or low recall. The parameter also determines the algorithm's memory 
18 | consumption, which is roughly ```M * 8-10``` bytes per stored element.  
19 | As an example for ```dim```=4 random vectors optimal ```M``` for search is somewhere around 6, while for high dimensional datasets 
20 | (word embeddings, good face descriptors), higher ```M``` are required (e.g. ```M```=48-64) for optimal performance at high recall. 
21 | The range ```M```=12-48 is ok for most of the use cases. When ```M``` is changed, one has to update the other parameters. 
22 | Nonetheless, ef and ef_construction parameters can be roughly estimated by assuming that ```M```*```ef_{construction}``` is 
23 | a constant.
24 | 
25 | * ```ef_construction``` - the parameter has the same meaning as ```ef```, but controls the index_time/index_accuracy. Bigger 
26 | ef_construction leads to longer construction, but better index quality. At some point, increasing ef_construction does
27 | not improve the quality of the index. One way to check if the selection of ef_construction was ok is to measure a recall 
28 | for M nearest neighbor search when ```ef``` = ```ef_construction```: if the recall is lower than 0.9, then there is room 
29 | for improvement.
30 | * ```num_elements``` - defines the maximum number of elements in the index. The index can be extended by saving/loading (load_index
31 | function has a parameter which defines the new maximum number of elements).
32 | 


--------------------------------------------------------------------------------
/CMakeLists.txt:
--------------------------------------------------------------------------------
  1 | cmake_minimum_required(VERSION 3.0...3.26)
  2 | 
  3 | project(hnswlib
  4 |     LANGUAGES CXX)
  5 | 
  6 | include(GNUInstallDirs)
  7 | include(CheckCXXCompilerFlag)
  8 | 
  9 | add_library(hnswlib INTERFACE)
 10 | add_library(hnswlib::hnswlib ALIAS hnswlib)
 11 | 
 12 | target_include_directories(hnswlib INTERFACE
 13 |     $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}>
 14 |     $<INSTALL_INTERFACE:${CMAKE_INSTALL_INCLUDEDIR}>
 15 | 
 16 | # Install
 17 | install(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/hnswlib
 18 |     DESTINATION ${CMAKE_INSTALL_INCLUDEDIR})
 19 | 
 20 | install(TARGETS hnswlib
 21 |     EXPORT hnswlibTargets)
 22 | 
 23 | install(EXPORT hnswlibTargets
 24 |     FILE hnswlibConfig.cmake
 25 |     NAMESPACE hnswlib::
 26 |     DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/hnswlib)
 27 | 
 28 | # Examples and tests
 29 | if(CMAKE_PROJECT_NAME STREQUAL PROJECT_NAME)
 30 |     option(HNSWLIB_EXAMPLES "Build examples and tests." ON)
 31 | else()
 32 |     option(HNSWLIB_EXAMPLES "Build examples and tests." OFF)
 33 | endif()
 34 | if(HNSWLIB_EXAMPLES)
 35 |     set(CMAKE_CXX_STANDARD 11)
 36 | 
 37 |     if (CMAKE_CXX_COMPILER_ID MATCHES "Clang")
 38 |       SET( CMAKE_CXX_FLAGS  "-Ofast -std=c++11 -DHAVE_CXX0X -openmp -fpic -ftree-vectorize" )
 39 |       check_cxx_compiler_flag("-march=native" COMPILER_SUPPORT_NATIVE_FLAG)
 40 |       if(COMPILER_SUPPORT_NATIVE_FLAG)
 41 |         SET( CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=native" )
 42 |         message("set -march=native flag")
 43 |       else()
 44 |         check_cxx_compiler_flag("-mcpu=apple-m1" COMPILER_SUPPORT_M1_FLAG)
 45 |         if(COMPILER_SUPPORT_M1_FLAG)
 46 |           SET( CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mcpu=apple-m1" )
 47 |           message("set -mcpu=apple-m1 flag")
 48 |         endif()
 49 |       endif()
 50 |     elseif (CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
 51 |       SET( CMAKE_CXX_FLAGS  "-Ofast -lrt -std=c++11 -DHAVE_CXX0X -march=native -fpic -w -fopenmp -ftree-vectorize -ftree-vectorizer-verbose=0" )
 52 |     elseif (CMAKE_CXX_COMPILER_ID STREQUAL "MSVC")
 53 |       SET( CMAKE_CXX_FLAGS  "/O2 -DHAVE_CXX0X /W1 /openmp /EHsc" )
 54 |     endif()
 55 | 
 56 |     # examples
 57 |     add_executable(example_search examples/cpp/example_search.cpp)
 58 |     target_link_libraries(example_search hnswlib)
 59 | 
 60 |     add_executable(example_epsilon_search examples/cpp/example_epsilon_search.cpp)
 61 |     target_link_libraries(example_epsilon_search hnswlib)
 62 | 
 63 |     add_executable(example_multivector_search examples/cpp/example_multivector_search.cpp)
 64 |     target_link_libraries(example_multivector_search hnswlib)
 65 | 
 66 |     add_executable(example_filter examples/cpp/example_filter.cpp)
 67 |     target_link_libraries(example_filter hnswlib)
 68 | 
 69 |     add_executable(example_replace_deleted examples/cpp/example_replace_deleted.cpp)
 70 |     target_link_libraries(example_replace_deleted hnswlib)
 71 | 
 72 |     add_executable(example_mt_search examples/cpp/example_mt_search.cpp)
 73 |     target_link_libraries(example_mt_search hnswlib)
 74 | 
 75 |     add_executable(example_mt_filter examples/cpp/example_mt_filter.cpp)
 76 |     target_link_libraries(example_mt_filter hnswlib)
 77 | 
 78 |     add_executable(example_mt_replace_deleted examples/cpp/example_mt_replace_deleted.cpp)
 79 |     target_link_libraries(example_mt_replace_deleted hnswlib)
 80 | 
 81 |     # tests
 82 |     add_executable(multivector_search_test tests/cpp/multivector_search_test.cpp)
 83 |     target_link_libraries(multivector_search_test hnswlib)
 84 | 
 85 |     add_executable(epsilon_search_test tests/cpp/epsilon_search_test.cpp)
 86 |     target_link_libraries(epsilon_search_test hnswlib)
 87 | 
 88 |     add_executable(test_updates tests/cpp/updates_test.cpp)
 89 |     target_link_libraries(test_updates hnswlib)
 90 | 
 91 |     add_executable(searchKnnCloserFirst_test tests/cpp/searchKnnCloserFirst_test.cpp)
 92 |     target_link_libraries(searchKnnCloserFirst_test hnswlib)
 93 | 
 94 |     add_executable(searchKnnWithFilter_test tests/cpp/searchKnnWithFilter_test.cpp)
 95 |     target_link_libraries(searchKnnWithFilter_test hnswlib)
 96 | 
 97 |     add_executable(multiThreadLoad_test tests/cpp/multiThreadLoad_test.cpp)
 98 |     target_link_libraries(multiThreadLoad_test hnswlib)
 99 | 
100 |     add_executable(multiThread_replace_test tests/cpp/multiThread_replace_test.cpp)
101 |     target_link_libraries(multiThread_replace_test hnswlib)
102 | 
103 |     add_executable(main tests/cpp/main.cpp tests/cpp/sift_1b.cpp)
104 |     target_link_libraries(main hnswlib)
105 | endif()
106 | 


--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
  1 |                                  Apache License
  2 |                            Version 2.0, January 2004
  3 |                         http://www.apache.org/licenses/
  4 | 
  5 |    TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
  6 | 
  7 |    1. Definitions.
  8 | 
  9 |       "License" shall mean the terms and conditions for use, reproduction,
 10 |       and distribution as defined by Sections 1 through 9 of this document.
 11 | 
 12 |       "Licensor" shall mean the copyright owner or entity authorized by
 13 |       the copyright owner that is granting the License.
 14 | 
 15 |       "Legal Entity" shall mean the union of the acting entity and all
 16 |       other entities that control, are controlled by, or are under common
 17 |       control with that entity. For the purposes of this definition,
 18 |       "control" means (i) the power, direct or indirect, to cause the
 19 |       direction or management of such entity, whether by contract or
 20 |       otherwise, or (ii) ownership of fifty percent (50%) or more of the
 21 |       outstanding shares, or (iii) beneficial ownership of such entity.
 22 | 
 23 |       "You" (or "Your") shall mean an individual or Legal Entity
 24 |       exercising permissions granted by this License.
 25 | 
 26 |       "Source" form shall mean the preferred form for making modifications,
 27 |       including but not limited to software source code, documentation
 28 |       source, and configuration files.
 29 | 
 30 |       "Object" form shall mean any form resulting from mechanical
 31 |       transformation or translation of a Source form, including but
 32 |       not limited to compiled object code, generated documentation,
 33 |       and conversions to other media types.
 34 | 
 35 |       "Work" shall mean the work of authorship, whether in Source or
 36 |       Object form, made available under the License, as indicated by a
 37 |       copyright notice that is included in or attached to the work
 38 |       (an example is provided in the Appendix below).
 39 | 
 40 |       "Derivative Works" shall mean any work, whether in Source or Object
 41 |       form, that is based on (or derived from) the Work and for which the
 42 |       editorial revisions, annotations, elaborations, or other modifications
 43 |       represent, as a whole, an original work of authorship. For the purposes
 44 |       of this License, Derivative Works shall not include works that remain
 45 |       separable from, or merely link (or bind by name) to the interfaces of,
 46 |       the Work and Derivative Works thereof.
 47 | 
 48 |       "Contribution" shall mean any work of authorship, including
 49 |       the original version of the Work and any modifications or additions
 50 |       to that Work or Derivative Works thereof, that is intentionally
 51 |       submitted to Licensor for inclusion in the Work by the copyright owner
 52 |       or by an individual or Legal Entity authorized to submit on behalf of
 53 |       the copyright owner. For the purposes of this definition, "submitted"
 54 |       means any form of electronic, verbal, or written communication sent
 55 |       to the Licensor or its representatives, including but not limited to
 56 |       communication on electronic mailing lists, source code control systems,
 57 |       and issue tracking systems that are managed by, or on behalf of, the
 58 |       Licensor for the purpose of discussing and improving the Work, but
 59 |       excluding communication that is conspicuously marked or otherwise
 60 |       designated in writing by the copyright owner as "Not a Contribution."
 61 | 
 62 |       "Contributor" shall mean Licensor and any individual or Legal Entity
 63 |       on behalf of whom a Contribution has been received by Licensor and
 64 |       subsequently incorporated within the Work.
 65 | 
 66 |    2. Grant of Copyright License. Subject to the terms and conditions of
 67 |       this License, each Contributor hereby grants to You a perpetual,
 68 |       worldwide, non-exclusive, no-charge, royalty-free, irrevocable
 69 |       copyright license to reproduce, prepare Derivative Works of,
 70 |       publicly display, publicly perform, sublicense, and distribute the
 71 |       Work and such Derivative Works in Source or Object form.
 72 | 
 73 |    3. Grant of Patent License. Subject to the terms and conditions of
 74 |       this License, each Contributor hereby grants to You a perpetual,
 75 |       worldwide, non-exclusive, no-charge, royalty-free, irrevocable
 76 |       (except as stated in this section) patent license to make, have made,
 77 |       use, offer to sell, sell, import, and otherwise transfer the Work,
 78 |       where such license applies only to those patent claims licensable
 79 |       by such Contributor that are necessarily infringed by their
 80 |       Contribution(s) alone or by combination of their Contribution(s)
 81 |       with the Work to which such Contribution(s) was submitted. If You
 82 |       institute patent litigation against any entity (including a
 83 |       cross-claim or counterclaim in a lawsuit) alleging that the Work
 84 |       or a Contribution incorporated within the Work constitutes direct
 85 |       or contributory patent infringement, then any patent licenses
 86 |       granted to You under this License for that Work shall terminate
 87 |       as of the date such litigation is filed.
 88 | 
 89 |    4. Redistribution. You may reproduce and distribute copies of the
 90 |       Work or Derivative Works thereof in any medium, with or without
 91 |       modifications, and in Source or Object form, provided that You
 92 |       meet the following conditions:
 93 | 
 94 |       (a) You must give any other recipients of the Work or
 95 |           Derivative Works a copy of this License; and
 96 | 
 97 |       (b) You must cause any modified files to carry prominent notices
 98 |           stating that You changed the files; and
 99 | 
100 |       (c) You must retain, in the Source form of any Derivative Works
101 |           that You distribute, all copyright, patent, trademark, and
102 |           attribution notices from the Source form of the Work,
103 |           excluding those notices that do not pertain to any part of
104 |           the Derivative Works; and
105 | 
106 |       (d) If the Work includes a "NOTICE" text file as part of its
107 |           distribution, then any Derivative Works that You distribute must
108 |           include a readable copy of the attribution notices contained
109 |           within such NOTICE file, excluding those notices that do not
110 |           pertain to any part of the Derivative Works, in at least one
111 |           of the following places: within a NOTICE text file distributed
112 |           as part of the Derivative Works; within the Source form or
113 |           documentation, if provided along with the Derivative Works; or,
114 |           within a display generated by the Derivative Works, if and
115 |           wherever such third-party notices normally appear. The contents
116 |           of the NOTICE file are for informational purposes only and
117 |           do not modify the License. You may add Your own attribution
118 |           notices within Derivative Works that You distribute, alongside
119 |           or as an addendum to the NOTICE text from the Work, provided
120 |           that such additional attribution notices cannot be construed
121 |           as modifying the License.
122 | 
123 |       You may add Your own copyright statement to Your modifications and
124 |       may provide additional or different license terms and conditions
125 |       for use, reproduction, or distribution of Your modifications, or
126 |       for any such Derivative Works as a whole, provided Your use,
127 |       reproduction, and distribution of the Work otherwise complies with
128 |       the conditions stated in this License.
129 | 
130 |    5. Submission of Contributions. Unless You explicitly state otherwise,
131 |       any Contribution intentionally submitted for inclusion in the Work
132 |       by You to the Licensor shall be under the terms and conditions of
133 |       this License, without any additional terms or conditions.
134 |       Notwithstanding the above, nothing herein shall supersede or modify
135 |       the terms of any separate license agreement you may have executed
136 |       with Licensor regarding such Contributions.
137 | 
138 |    6. Trademarks. This License does not grant permission to use the trade
139 |       names, trademarks, service marks, or product names of the Licensor,
140 |       except as required for reasonable and customary use in describing the
141 |       origin of the Work and reproducing the content of the NOTICE file.
142 | 
143 |    7. Disclaimer of Warranty. Unless required by applicable law or
144 |       agreed to in writing, Licensor provides the Work (and each
145 |       Contributor provides its Contributions) on an "AS IS" BASIS,
146 |       WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 |       implied, including, without limitation, any warranties or conditions
148 |       of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 |       PARTICULAR PURPOSE. You are solely responsible for determining the
150 |       appropriateness of using or redistributing the Work and assume any
151 |       risks associated with Your exercise of permissions under this License.
152 | 
153 |    8. Limitation of Liability. In no event and under no legal theory,
154 |       whether in tort (including negligence), contract, or otherwise,
155 |       unless required by applicable law (such as deliberate and grossly
156 |       negligent acts) or agreed to in writing, shall any Contributor be
157 |       liable to You for damages, including any direct, indirect, special,
158 |       incidental, or consequential damages of any character arising as a
159 |       result of this License or out of the use or inability to use the
160 |       Work (including but not limited to damages for loss of goodwill,
161 |       work stoppage, computer failure or malfunction, or any and all
162 |       other commercial damages or losses), even if such Contributor
163 |       has been advised of the possibility of such damages.
164 | 
165 |    9. Accepting Warranty or Additional Liability. While redistributing
166 |       the Work or Derivative Works thereof, You may choose to offer,
167 |       and charge a fee for, acceptance of support, warranty, indemnity,
168 |       or other liability obligations and/or rights consistent with this
169 |       License. However, in accepting such obligations, You may act only
170 |       on Your own behalf and on Your sole responsibility, not on behalf
171 |       of any other Contributor, and only if You agree to indemnify,
172 |       defend, and hold each Contributor harmless for any liability
173 |       incurred by, or claims asserted against, such Contributor by reason
174 |       of your accepting any such warranty or additional liability.
175 | 
176 |    END OF TERMS AND CONDITIONS
177 | 
178 |    APPENDIX: How to apply the Apache License to your work.
179 | 
180 |       To apply the Apache License to your work, attach the following
181 |       boilerplate notice, with the fields enclosed by brackets "{}"
182 |       replaced with your own identifying information. (Don't include
183 |       the brackets!)  The text should be enclosed in the appropriate
184 |       comment syntax for the file format. We also recommend that a
185 |       file or class name and description of purpose be included on the
186 |       same "printed page" as the copyright notice for easier
187 |       identification within third-party archives.
188 | 
189 |    Copyright {yyyy} {name of copyright owner}
190 | 
191 |    Licensed under the Apache License, Version 2.0 (the "License");
192 |    you may not use this file except in compliance with the License.
193 |    You may obtain a copy of the License at
194 | 
195 |        http://www.apache.org/licenses/LICENSE-2.0
196 | 
197 |    Unless required by applicable law or agreed to in writing, software
198 |    distributed under the License is distributed on an "AS IS" BASIS,
199 |    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 |    See the License for the specific language governing permissions and
201 |    limitations under the License.
202 | 


--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include hnswlib/*.h
2 | include LICENSE
3 | 


--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
 1 | pypi: dist
 2 | 	twine upload dist/*
 3 | 
 4 | dist:
 5 | 	-rm dist/*
 6 | 	pip install build
 7 | 	python3 -m build --sdist
 8 | 
 9 | test:
10 | 	python3 -m unittest discover --start-directory tests/python --pattern "bindings_test*.py"
11 | 
12 | clean:
13 | 	rm -rf *.egg-info build dist tmp var tests/__pycache__ hnswlib.cpython*.so
14 | 
15 | .PHONY: dist
16 | 


--------------------------------------------------------------------------------
/TESTING_RECALL.md:
--------------------------------------------------------------------------------
 1 | # Testing recall
 2 | 
 3 | Selecting HNSW parameters for a specific use case highly impacts the search quality. One way to test the quality of the constructed index is to compare the HNSW search results to the actual results (i.e., the actual `k` nearest neighbors).
 4 | For that purpose, the API enables creating a simple "brute-force" index in which vectors are stored as is, and searching for the `k` nearest neighbors to a query vector requires going over the entire index.  
 5 | Comparing between HNSW and brute-force results may help with finding the desired HNSW parameters for achieving a satisfying recall, based on the index size and data dimension.
 6 | 
 7 | ### Brute force index API
 8 | `hnswlib.BFIndex(space, dim)` creates a non-initialized index in space `space` with integer dimension `dim`.
 9 | 
10 | `hnswlib.BFIndex` methods:
11 | 
12 | `init_index(max_elements)` initializes the index with no elements.
13 | 
14 | max_elements defines the maximum number of elements that can be stored in the structure.
15 | 
16 | `add_items(data, ids)` inserts the data (numpy array of vectors, shape:`N*dim`) into the structure.
17 | `ids` are optional N-size numpy array of integer labels for all elements in data.
18 | 
19 | `delete_vector(label)` deletes the element associated with the given `label` so it will be omitted from search results.
20 | 
21 | `knn_query(data, k = 1)` makes a batch query for the `k` closest elements for each element of the
22 | `data` (shape:`N*dim`). Returns two numpy arrays (labels and distances), each of shape `N*k`.
23 | 
24 | `load_index(path_to_index, max_elements = 0)` loads the index from persistence into the uninitialized index.
25 | 
26 | `save_index(path_to_index)` saves the index to persistence.
27 | 
28 | ### measuring recall example
29 | 
30 | ```python
31 | import hnswlib
32 | import numpy as np
33 | 
34 | dim = 32
35 | num_elements = 100000
36 | k = 10
37 | num_queries = 10
38 | 
39 | # Generating sample data
40 | data = np.float32(np.random.random((num_elements, dim)))
41 | 
42 | # Declaring index
43 | hnsw_index = hnswlib.Index(space='l2', dim=dim)  # possible options are l2, cosine or ip
44 | bf_index = hnswlib.BFIndex(space='l2', dim=dim)
45 | 
46 | # Initing both hnsw and brute force indices
47 | # max_elements - the maximum number of elements (capacity). Will throw an exception if exceeded
48 | # during insertion of an element.
49 | # The capacity can be increased by saving/loading the index, see below.
50 | #
51 | # hnsw construction params:
52 | # ef_construction - controls index search speed/build speed tradeoff
53 | #
54 | # M - is tightly connected with internal dimensionality of the data. Strongly affects the memory consumption (~M)
55 | # Higher M leads to higher accuracy/run_time at fixed ef/efConstruction
56 | 
57 | hnsw_index.init_index(max_elements=num_elements, ef_construction=200, M=16)
58 | bf_index.init_index(max_elements=num_elements)
59 | 
60 | # Controlling the recall for hnsw by setting ef:
61 | # higher ef leads to better accuracy, but slower search
62 | hnsw_index.set_ef(200)
63 | 
64 | # Set number of threads used during batch search/construction in hnsw
65 | # By default using all available cores
66 | hnsw_index.set_num_threads(1)
67 | 
68 | print("Adding batch of %d elements" % (len(data)))
69 | hnsw_index.add_items(data)
70 | bf_index.add_items(data)
71 | 
72 | print("Indices built")
73 | 
74 | # Generating query data
75 | query_data = np.float32(np.random.random((num_queries, dim)))
76 | 
77 | # Query the elements and measure recall:
78 | labels_hnsw, distances_hnsw = hnsw_index.knn_query(query_data, k)
79 | labels_bf, distances_bf = bf_index.knn_query(query_data, k)
80 | 
81 | # Measure recall
82 | correct = 0
83 | for i in range(num_queries):
84 |     for label in labels_hnsw[i]:
85 |         for correct_label in labels_bf[i]:
86 |             if label == correct_label:
87 |                 correct += 1
88 |                 break
89 | 
90 | print("recall is :", float(correct)/(k*num_queries))
91 | ```
92 | 


--------------------------------------------------------------------------------
/examples/cpp/EXAMPLES.md:
--------------------------------------------------------------------------------
  1 | # C++ examples
  2 | 
  3 | Creating index, inserting elements, searching and serialization
  4 | ```cpp
  5 | #include "../../hnswlib/hnswlib.h"
  6 | 
  7 | 
  8 | int main() {
  9 |     int dim = 16;               // Dimension of the elements
 10 |     int max_elements = 10000;   // Maximum number of elements, should be known beforehand
 11 |     int M = 16;                 // Tightly connected with internal dimensionality of the data
 12 |                                 // strongly affects the memory consumption
 13 |     int ef_construction = 200;  // Controls index search speed/build speed tradeoff
 14 | 
 15 |     // Initing index
 16 |     hnswlib::L2Space space(dim);
 17 |     hnswlib::HierarchicalNSW<float>* alg_hnsw = new hnswlib::HierarchicalNSW<float>(&space, max_elements, M, ef_construction);
 18 | 
 19 |     // Generate random data
 20 |     std::mt19937 rng;
 21 |     rng.seed(47);
 22 |     std::uniform_real_distribution<> distrib_real;
 23 |     float* data = new float[dim * max_elements];
 24 |     for (int i = 0; i < dim * max_elements; i++) {
 25 |         data[i] = distrib_real(rng);
 26 |     }
 27 | 
 28 |     // Add data to index
 29 |     for (int i = 0; i < max_elements; i++) {
 30 |         alg_hnsw->addPoint(data + i * dim, i);
 31 |     }
 32 | 
 33 |     // Query the elements for themselves and measure recall
 34 |     float correct = 0;
 35 |     for (int i = 0; i < max_elements; i++) {
 36 |         std::priority_queue<std::pair<float, hnswlib::labeltype>> result = alg_hnsw->searchKnn(data + i * dim, 1);
 37 |         hnswlib::labeltype label = result.top().second;
 38 |         if (label == i) correct++;
 39 |     }
 40 |     float recall = correct / max_elements;
 41 |     std::cout << "Recall: " << recall << "\n";
 42 | 
 43 |     // Serialize index
 44 |     std::string hnsw_path = "hnsw.bin";
 45 |     alg_hnsw->saveIndex(hnsw_path);
 46 |     delete alg_hnsw;
 47 | 
 48 |     // Deserialize index and check recall
 49 |     alg_hnsw = new hnswlib::HierarchicalNSW<float>(&space, hnsw_path);
 50 |     correct = 0;
 51 |     for (int i = 0; i < max_elements; i++) {
 52 |         std::priority_queue<std::pair<float, hnswlib::labeltype>> result = alg_hnsw->searchKnn(data + i * dim, 1);
 53 |         hnswlib::labeltype label = result.top().second;
 54 |         if (label == i) correct++;
 55 |     }
 56 |     recall = (float)correct / max_elements;
 57 |     std::cout << "Recall of deserialized index: " << recall << "\n";
 58 | 
 59 |     delete[] data;
 60 |     delete alg_hnsw;
 61 |     return 0;
 62 | }
 63 | ```
 64 | 
 65 | An example of filtering with a boolean function during the search:
 66 | ```cpp
 67 | #include "../../hnswlib/hnswlib.h"
 68 | 
 69 | 
 70 | // Filter that allows labels divisible by divisor
 71 | class PickDivisibleIds: public hnswlib::BaseFilterFunctor {
 72 | unsigned int divisor = 1;
 73 |  public:
 74 |     PickDivisibleIds(unsigned int divisor): divisor(divisor) {
 75 |         assert(divisor != 0);
 76 |     }
 77 |     bool operator()(hnswlib::labeltype label_id) {
 78 |         return label_id % divisor == 0;
 79 |     }
 80 | };
 81 | 
 82 | 
 83 | int main() {
 84 |     int dim = 16;               // Dimension of the elements
 85 |     int max_elements = 10000;   // Maximum number of elements, should be known beforehand
 86 |     int M = 16;                 // Tightly connected with internal dimensionality of the data
 87 |                                 // strongly affects the memory consumption
 88 |     int ef_construction = 200;  // Controls index search speed/build speed tradeoff
 89 | 
 90 |     // Initing index
 91 |     hnswlib::L2Space space(dim);
 92 |     hnswlib::HierarchicalNSW<float>* alg_hnsw = new hnswlib::HierarchicalNSW<float>(&space, max_elements, M, ef_construction);
 93 | 
 94 |     // Generate random data
 95 |     std::mt19937 rng;
 96 |     rng.seed(47);
 97 |     std::uniform_real_distribution<> distrib_real;
 98 |     float* data = new float[dim * max_elements];
 99 |     for (int i = 0; i < dim * max_elements; i++) {
100 |         data[i] = distrib_real(rng);
101 |     }
102 | 
103 |     // Add data to index
104 |     for (int i = 0; i < max_elements; i++) {
105 |         alg_hnsw->addPoint(data + i * dim, i);
106 |     }
107 | 
108 |     // Create filter that allows only even labels
109 |     PickDivisibleIds pickIdsDivisibleByTwo(2);
110 | 
111 |     // Query the elements for themselves with filter and check returned labels
112 |     int k = 10;
113 |     for (int i = 0; i < max_elements; i++) {
114 |         std::vector<std::pair<float, hnswlib::labeltype>> result = alg_hnsw->searchKnnCloserFirst(data + i * dim, k, &pickIdsDivisibleByTwo);
115 |         for (auto item: result) {
116 |             if (item.second % 2 == 1) std::cout << "Error: found odd label\n";
117 |         }
118 |     }
119 | 
120 |     delete[] data;
121 |     delete alg_hnsw;
122 |     return 0;
123 | }
124 | ```
125 | 
126 | An example with reusing the memory of the deleted elements when new elements are being added (via `allow_replace_deleted` flag):
127 | ```cpp
128 | #include "../../hnswlib/hnswlib.h"
129 | 
130 | 
131 | int main() {
132 |     int dim = 16;               // Dimension of the elements
133 |     int max_elements = 10000;   // Maximum number of elements, should be known beforehand
134 |     int M = 16;                 // Tightly connected with internal dimensionality of the data
135 |                                 // strongly affects the memory consumption
136 |     int ef_construction = 200;  // Controls index search speed/build speed tradeoff
137 | 
138 |     // Initing index
139 |     hnswlib::L2Space space(dim);
140 |     hnswlib::HierarchicalNSW<float>* alg_hnsw = new hnswlib::HierarchicalNSW<float>(&space, max_elements, M, ef_construction, 100, true);
141 | 
142 |     // Generate random data
143 |     std::mt19937 rng;
144 |     rng.seed(47);
145 |     std::uniform_real_distribution<> distrib_real;
146 |     float* data = new float[dim * max_elements];
147 |     for (int i = 0; i < dim * max_elements; i++) {
148 |         data[i] = distrib_real(rng);
149 |     }
150 | 
151 |     // Add data to index
152 |     for (int i = 0; i < max_elements; i++) {
153 |         alg_hnsw->addPoint(data + i * dim, i);
154 |     }
155 | 
156 |     // Mark first half of elements as deleted
157 |     int num_deleted = max_elements / 2;
158 |     for (int i = 0; i < num_deleted; i++) {
159 |         alg_hnsw->markDelete(i);
160 |     }
161 | 
162 |     float* add_data = new float[dim * num_deleted];
163 |     for (int i = 0; i < dim * num_deleted; i++) {
164 |         add_data[i] = distrib_real(rng);
165 |     }
166 | 
167 |     // Replace deleted data with new elements
168 |     // Maximum number of elements is reached therefore we cannot add new items,
169 |     // but we can replace the deleted ones by using replace_deleted=true
170 |     for (int i = 0; i < num_deleted; i++) {
171 |         int label = max_elements + i;
172 |         alg_hnsw->addPoint(add_data + i * dim, label, true);
173 |     }
174 | 
175 |     delete[] data;
176 |     delete[] add_data;
177 |     delete alg_hnsw;
178 |     return 0;
179 | }
180 | ```
181 | 
182 | Multithreaded examples:
183 | * Creating index, inserting elements, searching [example_mt_search.cpp](example_mt_search.cpp)
184 | * Filtering during the search with a boolean function [example_mt_filter.cpp](example_mt_filter.cpp)
185 | * Reusing the memory of the deleted elements when new elements are being added [example_mt_replace_deleted.cpp](example_mt_replace_deleted.cpp)
186 | 
187 | More examples:
188 | * Multivector search [example_multivector_search.cpp](example_multivector_search.cpp)
189 | * Epsilon search [example_epsilon_search.cpp](example_epsilon_search.cpp)


--------------------------------------------------------------------------------
/examples/cpp/example_epsilon_search.cpp:
--------------------------------------------------------------------------------
 1 | #include "../../hnswlib/hnswlib.h"
 2 | 
 3 | typedef unsigned int docidtype;
 4 | typedef float dist_t;
 5 | 
 6 | int main() {
 7 |     int dim = 16;                  // Dimension of the elements
 8 |     int max_elements = 10000;      // Maximum number of elements, should be known beforehand
 9 |     int M = 16;                    // Tightly connected with internal dimensionality of the data
10 |                                    // strongly affects the memory consumption
11 |     int ef_construction = 200;     // Controls index search speed/build speed tradeoff
12 |     int min_num_candidates = 100;  // Minimum number of candidates to search in the epsilon region
13 |                                    // this parameter is similar to ef
14 | 
15 |     int num_queries = 5;
16 |     float epsilon2 = 2.0;          // Squared distance to query
17 | 
18 |     // Initing index
19 |     hnswlib::L2Space space(dim);
20 |     hnswlib::HierarchicalNSW<dist_t>* alg_hnsw = new hnswlib::HierarchicalNSW<dist_t>(&space, max_elements, M, ef_construction);
21 | 
22 |     // Generate random data
23 |     std::mt19937 rng;
24 |     rng.seed(47);
25 |     std::uniform_real_distribution<> distrib_real;
26 | 
27 |     size_t data_point_size = space.get_data_size();
28 |     char* data = new char[data_point_size * max_elements];
29 |     for (int i = 0; i < max_elements; i++) {
30 |         char* point_data = data + i * data_point_size;
31 |         for (int j = 0; j < dim; j++) {
32 |             char* vec_data = point_data + j * sizeof(float);
33 |             float value = distrib_real(rng);
34 |             *(float*)vec_data = value;
35 |         }
36 |     }
37 | 
38 |     // Add data to index
39 |     for (int i = 0; i < max_elements; i++) {
40 |         hnswlib::labeltype label = i;
41 |         char* point_data = data + i * data_point_size;
42 |         alg_hnsw->addPoint(point_data, label);
43 |     }
44 | 
45 |     // Query random vectors
46 |     for (int i = 0; i < num_queries; i++) {
47 |         char* query_data = new char[data_point_size];
48 |         for (int j = 0; j < dim; j++) {
49 |             size_t offset = j * sizeof(float);
50 |             char* vec_data = query_data + offset;
51 |             float value = distrib_real(rng);
52 |             *(float*)vec_data = value;
53 |         }
54 |         std::cout << "Query #" << i << "\n";
55 |         hnswlib::EpsilonSearchStopCondition<dist_t> stop_condition(epsilon2, min_num_candidates, max_elements);
56 |         std::vector<std::pair<float, hnswlib::labeltype>> result = 
57 |             alg_hnsw->searchStopConditionClosest(query_data, stop_condition);
58 |         size_t num_vectors = result.size();
59 |         std::cout << "Found " << num_vectors << " vectors\n";
60 |         delete[] query_data;
61 |     }
62 | 
63 |     delete[] data;
64 |     delete alg_hnsw;
65 |     return 0;
66 | }
67 | 


--------------------------------------------------------------------------------
/examples/cpp/example_filter.cpp:
--------------------------------------------------------------------------------
 1 | #include "../../hnswlib/hnswlib.h"
 2 | 
 3 | 
 4 | // Filter that allows labels divisible by divisor
 5 | class PickDivisibleIds: public hnswlib::BaseFilterFunctor {
 6 | unsigned int divisor = 1;
 7 |  public:
 8 |     PickDivisibleIds(unsigned int divisor): divisor(divisor) {
 9 |         assert(divisor != 0);
10 |     }
11 |     bool operator()(hnswlib::labeltype label_id) {
12 |         return label_id % divisor == 0;
13 |     }
14 | };
15 | 
16 | 
17 | int main() {
18 |     int dim = 16;               // Dimension of the elements
19 |     int max_elements = 10000;   // Maximum number of elements, should be known beforehand
20 |     int M = 16;                 // Tightly connected with internal dimensionality of the data
21 |                                 // strongly affects the memory consumption
22 |     int ef_construction = 200;  // Controls index search speed/build speed tradeoff
23 | 
24 |     // Initing index
25 |     hnswlib::L2Space space(dim);
26 |     hnswlib::HierarchicalNSW<float>* alg_hnsw = new hnswlib::HierarchicalNSW<float>(&space, max_elements, M, ef_construction);
27 | 
28 |     // Generate random data
29 |     std::mt19937 rng;
30 |     rng.seed(47);
31 |     std::uniform_real_distribution<> distrib_real;
32 |     float* data = new float[dim * max_elements];
33 |     for (int i = 0; i < dim * max_elements; i++) {
34 |         data[i] = distrib_real(rng);
35 |     }
36 | 
37 |     // Add data to index
38 |     for (int i = 0; i < max_elements; i++) {
39 |         alg_hnsw->addPoint(data + i * dim, i);
40 |     }
41 | 
42 |     // Create filter that allows only even labels
43 |     PickDivisibleIds pickIdsDivisibleByTwo(2);
44 | 
45 |     // Query the elements for themselves with filter and check returned labels
46 |     int k = 10;
47 |     for (int i = 0; i < max_elements; i++) {
48 |         std::vector<std::pair<float, hnswlib::labeltype>> result = alg_hnsw->searchKnnCloserFirst(data + i * dim, k, &pickIdsDivisibleByTwo);
49 |         for (auto item: result) {
50 |             if (item.second % 2 == 1) std::cout << "Error: found odd label\n";
51 |         }
52 |     }
53 | 
54 |     delete[] data;
55 |     delete alg_hnsw;
56 |     return 0;
57 | }
58 | 


--------------------------------------------------------------------------------
/examples/cpp/example_mt_filter.cpp:
--------------------------------------------------------------------------------
  1 | #include "../../hnswlib/hnswlib.h"
  2 | #include <thread>
  3 | 
  4 | 
  5 | // Multithreaded executor
  6 | // The helper function copied from python_bindings/bindings.cpp (and that itself is copied from nmslib)
  7 | // An alternative is using #pragma omp parallel for or any other C++ threading
  8 | template<class Function>
  9 | inline void ParallelFor(size_t start, size_t end, size_t numThreads, Function fn) {
 10 |     if (numThreads <= 0) {
 11 |         numThreads = std::thread::hardware_concurrency();
 12 |     }
 13 | 
 14 |     if (numThreads == 1) {
 15 |         for (size_t id = start; id < end; id++) {
 16 |             fn(id, 0);
 17 |         }
 18 |     } else {
 19 |         std::vector<std::thread> threads;
 20 |         std::atomic<size_t> current(start);
 21 | 
 22 |         // keep track of exceptions in threads
 23 |         // https://stackoverflow.com/a/32428427/1713196
 24 |         std::exception_ptr lastException = nullptr;
 25 |         std::mutex lastExceptMutex;
 26 | 
 27 |         for (size_t threadId = 0; threadId < numThreads; ++threadId) {
 28 |             threads.push_back(std::thread([&, threadId] {
 29 |                 while (true) {
 30 |                     size_t id = current.fetch_add(1);
 31 | 
 32 |                     if (id >= end) {
 33 |                         break;
 34 |                     }
 35 | 
 36 |                     try {
 37 |                         fn(id, threadId);
 38 |                     } catch (...) {
 39 |                         std::unique_lock<std::mutex> lastExcepLock(lastExceptMutex);
 40 |                         lastException = std::current_exception();
 41 |                         /*
 42 |                          * This will work even when current is the largest value that
 43 |                          * size_t can fit, because fetch_add returns the previous value
 44 |                          * before the increment (which will result in overflow
 45 |                          * and produce 0 instead of current + 1).
 46 |                          */
 47 |                         current = end;
 48 |                         break;
 49 |                     }
 50 |                 }
 51 |             }));
 52 |         }
 53 |         for (auto &thread : threads) {
 54 |             thread.join();
 55 |         }
 56 |         if (lastException) {
 57 |             std::rethrow_exception(lastException);
 58 |         }
 59 |     }
 60 | }
 61 | 
 62 | 
 63 | // Filter that allows labels divisible by divisor
 64 | class PickDivisibleIds: public hnswlib::BaseFilterFunctor {
 65 | unsigned int divisor = 1;
 66 |  public:
 67 |     PickDivisibleIds(unsigned int divisor): divisor(divisor) {
 68 |         assert(divisor != 0);
 69 |     }
 70 |     bool operator()(hnswlib::labeltype label_id) {
 71 |         return label_id % divisor == 0;
 72 |     }
 73 | };
 74 | 
 75 | 
 76 | int main() {
 77 |     int dim = 16;               // Dimension of the elements
 78 |     int max_elements = 10000;   // Maximum number of elements, should be known beforehand
 79 |     int M = 16;                 // Tightly connected with internal dimensionality of the data
 80 |                                 // strongly affects the memory consumption
 81 |     int ef_construction = 200;  // Controls index search speed/build speed tradeoff
 82 |     int num_threads = 20;       // Number of threads for operations with index
 83 | 
 84 |     // Initing index
 85 |     hnswlib::L2Space space(dim);
 86 |     hnswlib::HierarchicalNSW<float>* alg_hnsw = new hnswlib::HierarchicalNSW<float>(&space, max_elements, M, ef_construction);
 87 | 
 88 |     // Generate random data
 89 |     std::mt19937 rng;
 90 |     rng.seed(47);
 91 |     std::uniform_real_distribution<> distrib_real;
 92 |     float* data = new float[dim * max_elements];
 93 |     for (int i = 0; i < dim * max_elements; i++) {
 94 |         data[i] = distrib_real(rng);
 95 |     }
 96 | 
 97 |     // Add data to index
 98 |     ParallelFor(0, max_elements, num_threads, [&](size_t row, size_t threadId) {
 99 |         alg_hnsw->addPoint((void*)(data + dim * row), row);
100 |     });
101 | 
102 |     // Create filter that allows only even labels
103 |     PickDivisibleIds pickIdsDivisibleByTwo(2);
104 | 
105 |     // Query the elements for themselves with filter and check returned labels
106 |     int k = 10;
107 |     std::vector<hnswlib::labeltype> neighbors(max_elements * k);
108 |     ParallelFor(0, max_elements, num_threads, [&](size_t row, size_t threadId) {
109 |         std::priority_queue<std::pair<float, hnswlib::labeltype>> result = alg_hnsw->searchKnn(data + dim * row, k, &pickIdsDivisibleByTwo);
110 |         for (int i = 0; i < k; i++) {
111 |             hnswlib::labeltype label = result.top().second;
112 |             result.pop();
113 |             neighbors[row * k + i] = label;
114 |         }
115 |     });
116 | 
117 |     for (hnswlib::labeltype label: neighbors) {
118 |         if (label % 2 == 1) std::cout << "Error: found odd label\n";
119 |     }
120 | 
121 |     delete[] data;
122 |     delete alg_hnsw;
123 |     return 0;
124 | }
125 | 


--------------------------------------------------------------------------------
/examples/cpp/example_mt_replace_deleted.cpp:
--------------------------------------------------------------------------------
  1 | #include "../../hnswlib/hnswlib.h"
  2 | #include <thread>
  3 | 
  4 | 
  5 | // Multithreaded executor
  6 | // The helper function copied from python_bindings/bindings.cpp (and that itself is copied from nmslib)
  7 | // An alternative is using #pragma omp parallel for or any other C++ threading
  8 | template<class Function>
  9 | inline void ParallelFor(size_t start, size_t end, size_t numThreads, Function fn) {
 10 |     if (numThreads <= 0) {
 11 |         numThreads = std::thread::hardware_concurrency();
 12 |     }
 13 | 
 14 |     if (numThreads == 1) {
 15 |         for (size_t id = start; id < end; id++) {
 16 |             fn(id, 0);
 17 |         }
 18 |     } else {
 19 |         std::vector<std::thread> threads;
 20 |         std::atomic<size_t> current(start);
 21 | 
 22 |         // keep track of exceptions in threads
 23 |         // https://stackoverflow.com/a/32428427/1713196
 24 |         std::exception_ptr lastException = nullptr;
 25 |         std::mutex lastExceptMutex;
 26 | 
 27 |         for (size_t threadId = 0; threadId < numThreads; ++threadId) {
 28 |             threads.push_back(std::thread([&, threadId] {
 29 |                 while (true) {
 30 |                     size_t id = current.fetch_add(1);
 31 | 
 32 |                     if (id >= end) {
 33 |                         break;
 34 |                     }
 35 | 
 36 |                     try {
 37 |                         fn(id, threadId);
 38 |                     } catch (...) {
 39 |                         std::unique_lock<std::mutex> lastExcepLock(lastExceptMutex);
 40 |                         lastException = std::current_exception();
 41 |                         /*
 42 |                          * This will work even when current is the largest value that
 43 |                          * size_t can fit, because fetch_add returns the previous value
 44 |                          * before the increment (which will result in overflow
 45 |                          * and produce 0 instead of current + 1).
 46 |                          */
 47 |                         current = end;
 48 |                         break;
 49 |                     }
 50 |                 }
 51 |             }));
 52 |         }
 53 |         for (auto &thread : threads) {
 54 |             thread.join();
 55 |         }
 56 |         if (lastException) {
 57 |             std::rethrow_exception(lastException);
 58 |         }
 59 |     }
 60 | }
 61 | 
 62 | 
 63 | int main() {
 64 |     int dim = 16;               // Dimension of the elements
 65 |     int max_elements = 10000;   // Maximum number of elements, should be known beforehand
 66 |     int M = 16;                 // Tightly connected with internal dimensionality of the data
 67 |                                 // strongly affects the memory consumption
 68 |     int ef_construction = 200;  // Controls index search speed/build speed tradeoff
 69 |     int num_threads = 20;       // Number of threads for operations with index
 70 | 
 71 |     // Initing index with allow_replace_deleted=true
 72 |     int seed = 100; 
 73 |     hnswlib::L2Space space(dim);
 74 |     hnswlib::HierarchicalNSW<float>* alg_hnsw = new hnswlib::HierarchicalNSW<float>(&space, max_elements, M, ef_construction, seed, true);
 75 | 
 76 |     // Generate random data
 77 |     std::mt19937 rng;
 78 |     rng.seed(47);
 79 |     std::uniform_real_distribution<> distrib_real;
 80 |     float* data = new float[dim * max_elements];
 81 |     for (int i = 0; i < dim * max_elements; i++) {
 82 |         data[i] = distrib_real(rng);
 83 |     }
 84 | 
 85 |     // Add data to index
 86 |     ParallelFor(0, max_elements, num_threads, [&](size_t row, size_t threadId) {
 87 |         alg_hnsw->addPoint((void*)(data + dim * row), row);
 88 |     });
 89 | 
 90 |     // Mark first half of elements as deleted
 91 |     int num_deleted = max_elements / 2;
 92 |     ParallelFor(0, num_deleted, num_threads, [&](size_t row, size_t threadId) {
 93 |         alg_hnsw->markDelete(row);
 94 |     });
 95 | 
 96 |     // Generate additional random data
 97 |     float* add_data = new float[dim * num_deleted];
 98 |     for (int i = 0; i < dim * num_deleted; i++) {
 99 |         add_data[i] = distrib_real(rng);
100 |     }
101 | 
102 |     // Replace deleted data with new elements
103 |     // Maximum number of elements is reached therefore we cannot add new items,
104 |     // but we can replace the deleted ones by using replace_deleted=true
105 |     ParallelFor(0, num_deleted, num_threads, [&](size_t row, size_t threadId) {
106 |         hnswlib::labeltype label = max_elements + row;
107 |         alg_hnsw->addPoint((void*)(add_data + dim * row), label, true);
108 |     });
109 | 
110 |     delete[] data;
111 |     delete[] add_data;
112 |     delete alg_hnsw;
113 |     return 0;
114 | }
115 | 


--------------------------------------------------------------------------------
/examples/cpp/example_mt_search.cpp:
--------------------------------------------------------------------------------
  1 | #include "../../hnswlib/hnswlib.h"
  2 | #include <thread>
  3 | 
  4 | 
  5 | // Multithreaded executor
  6 | // The helper function copied from python_bindings/bindings.cpp (and that itself is copied from nmslib)
  7 | // An alternative is using #pragma omp parallel for or any other C++ threading
  8 | template<class Function>
  9 | inline void ParallelFor(size_t start, size_t end, size_t numThreads, Function fn) {
 10 |     if (numThreads <= 0) {
 11 |         numThreads = std::thread::hardware_concurrency();
 12 |     }
 13 | 
 14 |     if (numThreads == 1) {
 15 |         for (size_t id = start; id < end; id++) {
 16 |             fn(id, 0);
 17 |         }
 18 |     } else {
 19 |         std::vector<std::thread> threads;
 20 |         std::atomic<size_t> current(start);
 21 | 
 22 |         // keep track of exceptions in threads
 23 |         // https://stackoverflow.com/a/32428427/1713196
 24 |         std::exception_ptr lastException = nullptr;
 25 |         std::mutex lastExceptMutex;
 26 | 
 27 |         for (size_t threadId = 0; threadId < numThreads; ++threadId) {
 28 |             threads.push_back(std::thread([&, threadId] {
 29 |                 while (true) {
 30 |                     size_t id = current.fetch_add(1);
 31 | 
 32 |                     if (id >= end) {
 33 |                         break;
 34 |                     }
 35 | 
 36 |                     try {
 37 |                         fn(id, threadId);
 38 |                     } catch (...) {
 39 |                         std::unique_lock<std::mutex> lastExcepLock(lastExceptMutex);
 40 |                         lastException = std::current_exception();
 41 |                         /*
 42 |                          * This will work even when current is the largest value that
 43 |                          * size_t can fit, because fetch_add returns the previous value
 44 |                          * before the increment (which will result in overflow
 45 |                          * and produce 0 instead of current + 1).
 46 |                          */
 47 |                         current = end;
 48 |                         break;
 49 |                     }
 50 |                 }
 51 |             }));
 52 |         }
 53 |         for (auto &thread : threads) {
 54 |             thread.join();
 55 |         }
 56 |         if (lastException) {
 57 |             std::rethrow_exception(lastException);
 58 |         }
 59 |     }
 60 | }
 61 | 
 62 | 
 63 | int main() {
 64 |     int dim = 16;               // Dimension of the elements
 65 |     int max_elements = 10000;   // Maximum number of elements, should be known beforehand
 66 |     int M = 16;                 // Tightly connected with internal dimensionality of the data
 67 |                                 // strongly affects the memory consumption
 68 |     int ef_construction = 200;  // Controls index search speed/build speed tradeoff
 69 |     int num_threads = 20;       // Number of threads for operations with index
 70 | 
 71 |     // Initing index
 72 |     hnswlib::L2Space space(dim);
 73 |     hnswlib::HierarchicalNSW<float>* alg_hnsw = new hnswlib::HierarchicalNSW<float>(&space, max_elements, M, ef_construction);
 74 | 
 75 |     // Generate random data
 76 |     std::mt19937 rng;
 77 |     rng.seed(47);
 78 |     std::uniform_real_distribution<> distrib_real;
 79 |     float* data = new float[dim * max_elements];
 80 |     for (int i = 0; i < dim * max_elements; i++) {
 81 |         data[i] = distrib_real(rng);
 82 |     }
 83 | 
 84 |     // Add data to index
 85 |     ParallelFor(0, max_elements, num_threads, [&](size_t row, size_t threadId) {
 86 |         alg_hnsw->addPoint((void*)(data + dim * row), row);
 87 |     });
 88 | 
 89 |     // Query the elements for themselves and measure recall
 90 |     std::vector<hnswlib::labeltype> neighbors(max_elements);
 91 |     ParallelFor(0, max_elements, num_threads, [&](size_t row, size_t threadId) {
 92 |         std::priority_queue<std::pair<float, hnswlib::labeltype>> result = alg_hnsw->searchKnn(data + dim * row, 1);
 93 |         hnswlib::labeltype label = result.top().second;
 94 |         neighbors[row] = label;
 95 |     });
 96 |     float correct = 0;
 97 |     for (int i = 0; i < max_elements; i++) {
 98 |         hnswlib::labeltype label = neighbors[i];
 99 |         if (label == i) correct++;
100 |     }
101 |     float recall = correct / max_elements;
102 |     std::cout << "Recall: " << recall << "\n";
103 | 
104 |     delete[] data;
105 |     delete alg_hnsw;
106 |     return 0;
107 | }
108 | 


--------------------------------------------------------------------------------
/examples/cpp/example_multivector_search.cpp:
--------------------------------------------------------------------------------
 1 | #include "../../hnswlib/hnswlib.h"
 2 | 
 3 | typedef unsigned int docidtype;
 4 | typedef float dist_t;
 5 | 
 6 | int main() {
 7 |     int dim = 16;               // Dimension of the elements
 8 |     int max_elements = 10000;   // Maximum number of elements, should be known beforehand
 9 |     int M = 16;                 // Tightly connected with internal dimensionality of the data
10 |                                 // strongly affects the memory consumption
11 |     int ef_construction = 200;  // Controls index search speed/build speed tradeoff
12 | 
13 |     int num_queries = 5;
14 |     int num_docs = 5;           // Number of documents to search
15 |     int ef_collection = 6;      // Number of candidate documents during the search
 16 |                                 // Controls the recall: higher ef leads to better accuracy, but slower search
17 |     docidtype min_doc_id = 0;
18 |     docidtype max_doc_id = 9;
19 | 
20 |     // Initing index
21 |     hnswlib::MultiVectorL2Space<docidtype> space(dim);
22 |     hnswlib::HierarchicalNSW<dist_t>* alg_hnsw = new hnswlib::HierarchicalNSW<dist_t>(&space, max_elements, M, ef_construction);
23 | 
24 |     // Generate random data
25 |     std::mt19937 rng;
26 |     rng.seed(47);
27 |     std::uniform_real_distribution<> distrib_real;
28 |     std::uniform_int_distribution<docidtype> distrib_docid(min_doc_id, max_doc_id);
29 | 
30 |     size_t data_point_size = space.get_data_size();
31 |     char* data = new char[data_point_size * max_elements];
32 |     for (int i = 0; i < max_elements; i++) {
33 |         // set vector value
34 |         char* point_data = data + i * data_point_size;
35 |         for (int j = 0; j < dim; j++) {
36 |             char* vec_data = point_data + j * sizeof(float);
37 |             float value = distrib_real(rng);
38 |             *(float*)vec_data = value;
39 |         }
40 |         // set document id
41 |         docidtype doc_id = distrib_docid(rng);
42 |         space.set_doc_id(point_data, doc_id);
43 |     }
44 | 
45 |     // Add data to index
46 |     std::unordered_map<hnswlib::labeltype, docidtype> label_docid_lookup;
47 |     for (int i = 0; i < max_elements; i++) {
48 |         hnswlib::labeltype label = i;
49 |         char* point_data = data + i * data_point_size;
50 |         alg_hnsw->addPoint(point_data, label);
51 |         label_docid_lookup[label] = space.get_doc_id(point_data);
52 |     }
53 | 
54 |     // Query random vectors
55 |     size_t query_size = dim * sizeof(float);
56 |     for (int i = 0; i < num_queries; i++) {
57 |         char* query_data = new char[query_size];
58 |         for (int j = 0; j < dim; j++) {
59 |             size_t offset = j * sizeof(float);
60 |             char* vec_data = query_data + offset;
61 |             float value = distrib_real(rng);
62 |             *(float*)vec_data = value;
63 |         }
64 |         std::cout << "Query #" << i << "\n";
65 |         hnswlib::MultiVectorSearchStopCondition<docidtype, dist_t> stop_condition(space, num_docs, ef_collection);
66 |         std::vector<std::pair<float, hnswlib::labeltype>> result = 
67 |             alg_hnsw->searchStopConditionClosest(query_data, stop_condition);
68 |         size_t num_vectors = result.size();
69 | 
70 |         std::unordered_map<docidtype, size_t> doc_counter;
71 |         for (auto pair: result) {
72 |             hnswlib::labeltype label = pair.second;
73 |             docidtype doc_id = label_docid_lookup[label];
74 |             doc_counter[doc_id] += 1;
75 |         }
76 |         std::cout << "Found " << doc_counter.size() << " documents, " << num_vectors << " vectors\n";
77 |         delete[] query_data;
78 |     }
79 | 
80 |     delete[] data;
81 |     delete alg_hnsw;
82 |     return 0;
83 | }
84 | 


--------------------------------------------------------------------------------
/examples/cpp/example_replace_deleted.cpp:
--------------------------------------------------------------------------------
 1 | #include "../../hnswlib/hnswlib.h"
 2 | 
 3 | 
 4 | int main() {
 5 |     int dim = 16;               // Dimension of the elements
 6 |     int max_elements = 10000;   // Maximum number of elements, should be known beforehand
 7 |     int M = 16;                 // Tightly connected with internal dimensionality of the data
 8 |                                 // strongly affects the memory consumption
 9 |     int ef_construction = 200;  // Controls index search speed/build speed tradeoff
10 | 
11 |     // Initing index with allow_replace_deleted=true
12 |     int seed = 100; 
13 |     hnswlib::L2Space space(dim);
14 |     hnswlib::HierarchicalNSW<float>* alg_hnsw = new hnswlib::HierarchicalNSW<float>(&space, max_elements, M, ef_construction, seed, true);
15 | 
16 |     // Generate random data
17 |     std::mt19937 rng;
18 |     rng.seed(47);
19 |     std::uniform_real_distribution<> distrib_real;
20 |     float* data = new float[dim * max_elements];
21 |     for (int i = 0; i < dim * max_elements; i++) {
22 |         data[i] = distrib_real(rng);
23 |     }
24 | 
25 |     // Add data to index
26 |     for (int i = 0; i < max_elements; i++) {
27 |         alg_hnsw->addPoint(data + i * dim, i);
28 |     }
29 | 
30 |     // Mark first half of elements as deleted
31 |     int num_deleted = max_elements / 2;
32 |     for (int i = 0; i < num_deleted; i++) {
33 |         alg_hnsw->markDelete(i);
34 |     }
35 | 
36 |     // Generate additional random data
37 |     float* add_data = new float[dim * num_deleted];
38 |     for (int i = 0; i < dim * num_deleted; i++) {
39 |         add_data[i] = distrib_real(rng);
40 |     }
41 | 
42 |     // Replace deleted data with new elements
43 |     // Maximum number of elements is reached therefore we cannot add new items,
44 |     // but we can replace the deleted ones by using replace_deleted=true
45 |     for (int i = 0; i < num_deleted; i++) {
46 |         hnswlib::labeltype label = max_elements + i;
47 |         alg_hnsw->addPoint(add_data + i * dim, label, true);
48 |     }
49 | 
50 |     delete[] data;
51 |     delete[] add_data;
52 |     delete alg_hnsw;
53 |     return 0;
54 | }
55 | 


--------------------------------------------------------------------------------
/examples/cpp/example_search.cpp:
--------------------------------------------------------------------------------
 1 | #include "../../hnswlib/hnswlib.h"
 2 | 
 3 | 
 4 | int main() {
 5 |     int dim = 16;               // Dimension of the elements
 6 |     int max_elements = 10000;   // Maximum number of elements, should be known beforehand
 7 |     int M = 16;                 // Tightly connected with internal dimensionality of the data
 8 |                                 // strongly affects the memory consumption
 9 |     int ef_construction = 200;  // Controls index search speed/build speed tradeoff
10 | 
11 |     // Initing index
12 |     hnswlib::L2Space space(dim);
13 |     hnswlib::HierarchicalNSW<float>* alg_hnsw = new hnswlib::HierarchicalNSW<float>(&space, max_elements, M, ef_construction);
14 | 
15 |     // Generate random data
16 |     std::mt19937 rng;
17 |     rng.seed(47);
18 |     std::uniform_real_distribution<> distrib_real;
19 |     float* data = new float[dim * max_elements];
20 |     for (int i = 0; i < dim * max_elements; i++) {
21 |         data[i] = distrib_real(rng);
22 |     }
23 | 
24 |     // Add data to index
25 |     for (int i = 0; i < max_elements; i++) {
26 |         alg_hnsw->addPoint(data + i * dim, i);
27 |     }
28 | 
29 |     // Query the elements for themselves and measure recall
30 |     float correct = 0;
31 |     for (int i = 0; i < max_elements; i++) {
32 |         std::priority_queue<std::pair<float, hnswlib::labeltype>> result = alg_hnsw->searchKnn(data + i * dim, 1);
33 |         hnswlib::labeltype label = result.top().second;
34 |         if (label == i) correct++;
35 |     }
36 |     float recall = correct / max_elements;
37 |     std::cout << "Recall: " << recall << "\n";
38 | 
39 |     // Serialize index
40 |     std::string hnsw_path = "hnsw.bin";
41 |     alg_hnsw->saveIndex(hnsw_path);
42 |     delete alg_hnsw;
43 | 
44 |     // Deserialize index and check recall
45 |     alg_hnsw = new hnswlib::HierarchicalNSW<float>(&space, hnsw_path);
46 |     correct = 0;
47 |     for (int i = 0; i < max_elements; i++) {
48 |         std::priority_queue<std::pair<float, hnswlib::labeltype>> result = alg_hnsw->searchKnn(data + i * dim, 1);
49 |         hnswlib::labeltype label = result.top().second;
50 |         if (label == i) correct++;
51 |     }
52 |     recall = (float)correct / max_elements;
53 |     std::cout << "Recall of deserialized index: " << recall << "\n";
54 | 
55 |     delete[] data;
56 |     delete alg_hnsw;
57 |     return 0;
58 | }
59 | 


--------------------------------------------------------------------------------
/examples/python/EXAMPLES.md:
--------------------------------------------------------------------------------
  1 | # Python bindings examples
  2 | 
  3 | Creating index, inserting elements, searching and pickle serialization:
  4 | ```python
  5 | import hnswlib
  6 | import numpy as np
  7 | import pickle
  8 | 
  9 | dim = 128
 10 | num_elements = 10000
 11 | 
 12 | # Generating sample data
 13 | data = np.float32(np.random.random((num_elements, dim)))
 14 | ids = np.arange(num_elements)
 15 | 
 16 | # Declaring index
 17 | p = hnswlib.Index(space = 'l2', dim = dim) # possible options are l2, cosine or ip
 18 | 
 19 | # Initializing index - the maximum number of elements should be known beforehand
 20 | p.init_index(max_elements = num_elements, ef_construction = 200, M = 16)
 21 | 
 22 | # Element insertion (can be called several times):
 23 | p.add_items(data, ids)
 24 | 
 25 | # Controlling the recall by setting ef:
 26 | p.set_ef(50) # ef should always be > k
 27 | 
 28 | # Query dataset, k - number of the closest elements (returns 2 numpy arrays)
 29 | labels, distances = p.knn_query(data, k = 1)
 30 | 
 31 | # Index objects support pickling
 32 | # WARNING: serialization via pickle.dumps(p) or p.__getstate__() is NOT thread-safe with p.add_items method!
 33 | # Note: ef parameter is included in serialization; random number generator is initialized with random_seed on Index load
 34 | p_copy = pickle.loads(pickle.dumps(p)) # creates a copy of index p using pickle round-trip
 35 | 
 36 | ### Index parameters are exposed as class properties:
 37 | print(f"Parameters passed to constructor:  space={p_copy.space}, dim={p_copy.dim}") 
 38 | print(f"Index construction: M={p_copy.M}, ef_construction={p_copy.ef_construction}")
 39 | print(f"Index size is {p_copy.element_count} and index capacity is {p_copy.max_elements}")
 40 | print(f"Search speed/quality trade-off parameter: ef={p_copy.ef}")
 41 | ```
 42 | 
 43 | An example with updates after serialization/deserialization:
 44 | ```python
 45 | import hnswlib
 46 | import numpy as np
 47 | 
 48 | dim = 16
 49 | num_elements = 10000
 50 | 
 51 | # Generating sample data
 52 | data = np.float32(np.random.random((num_elements, dim)))
 53 | 
 54 | # We split the data in two batches:
 55 | data1 = data[:num_elements // 2]
 56 | data2 = data[num_elements // 2:]
 57 | 
 58 | # Declaring index
 59 | p = hnswlib.Index(space='l2', dim=dim)  # possible options are l2, cosine or ip
 60 | 
 61 | # Initializing index
 62 | # max_elements - the maximum number of elements (capacity). Will throw an exception if exceeded
 63 | # during insertion of an element.
 64 | # The capacity can be increased by saving/loading the index, see below.
 65 | #
 66 | # ef_construction - controls index search speed/build speed tradeoff
 67 | #
 68 | # M - is tightly connected with internal dimensionality of the data. Strongly affects memory consumption (~M)
 69 | # Higher M leads to higher accuracy/run_time at fixed ef/efConstruction
 70 | 
 71 | p.init_index(max_elements=num_elements//2, ef_construction=100, M=16)
 72 | 
 73 | # Controlling the recall by setting ef:
 74 | # higher ef leads to better accuracy, but slower search
 75 | p.set_ef(10)
 76 | 
 77 | # Set number of threads used during batch search/construction
 78 | # By default using all available cores
 79 | p.set_num_threads(4)
 80 | 
 81 | print("Adding first batch of %d elements" % (len(data1)))
 82 | p.add_items(data1)
 83 | 
 84 | # Query the elements for themselves and measure recall:
 85 | labels, distances = p.knn_query(data1, k=1)
 86 | print("Recall for the first batch:", np.mean(labels.reshape(-1) == np.arange(len(data1))), "\n")
 87 | 
 88 | # Serializing and deleting the index:
 89 | index_path='first_half.bin'
 90 | print("Saving index to '%s'" % index_path)
 91 | p.save_index(index_path)
 92 | del p
 93 | 
 94 | # Re-initializing, loading the index
 95 | p = hnswlib.Index(space='l2', dim=dim)  # the space can be changed - keeps the data, alters the distance function.
 96 | 
 97 | print("\nLoading index from 'first_half.bin'\n")
 98 | 
 99 | # Increase the total capacity (max_elements), so that it will handle the new data
100 | p.load_index("first_half.bin", max_elements = num_elements)
101 | 
102 | print("Adding the second batch of %d elements" % (len(data2)))
103 | p.add_items(data2)
104 | 
105 | # Query the elements for themselves and measure recall:
106 | labels, distances = p.knn_query(data, k=1)
107 | print("Recall for two batches:", np.mean(labels.reshape(-1) == np.arange(len(data))), "\n")
108 | ```
109 | 
110 | An example with a symbolic filter `filter_function` during the search:
111 | ```python
112 | import hnswlib
113 | import numpy as np
114 | 
115 | dim = 16
116 | num_elements = 10000
117 | 
118 | # Generating sample data
119 | data = np.float32(np.random.random((num_elements, dim)))
120 | 
121 | # Declaring index
122 | hnsw_index = hnswlib.Index(space='l2', dim=dim)  # possible options are l2, cosine or ip
123 | 
124 | # Initiating index
125 | # max_elements - the maximum number of elements, should be known beforehand
126 | #     (probably will be made optional in the future)
127 | #
128 | # ef_construction - controls index search speed/build speed tradeoff
129 | # M - is tightly connected with internal dimensionality of the data
130 | #     strongly affects the memory consumption
131 | 
132 | hnsw_index.init_index(max_elements=num_elements, ef_construction=100, M=16)
133 | 
134 | # Controlling the recall by setting ef:
135 | # higher ef leads to better accuracy, but slower search
136 | hnsw_index.set_ef(10)
137 | 
138 | # Set number of threads used during batch search/construction
139 | # By default using all available cores
140 | hnsw_index.set_num_threads(4)
141 | 
142 | print("Adding %d elements" % (len(data)))
143 | # Added elements will have consecutive ids
144 | hnsw_index.add_items(data, ids=np.arange(num_elements))
145 | 
146 | print("Querying only even elements")
147 | # Define filter function that allows only even ids
148 | filter_function = lambda idx: idx%2 == 0
149 | # Query the elements for themselves and search only for even elements:
150 | # Warning: search with python filter works slow in multithreaded mode, therefore we set num_threads=1
151 | labels, distances = hnsw_index.knn_query(data, k=1, num_threads=1, filter=filter_function)
152 | # labels contain only elements with even id
153 | ```
154 | 
155 | An example with reusing the memory of the deleted elements when new elements are being added (via `allow_replace_deleted` flag):
156 | ```python
157 | import hnswlib
158 | import numpy as np
159 | 
160 | dim = 16
161 | num_elements = 1_000
162 | max_num_elements = 2 * num_elements
163 | 
164 | # Generating sample data
165 | labels1 = np.arange(0, num_elements)
166 | data1 = np.float32(np.random.random((num_elements, dim)))  # batch 1
167 | labels2 = np.arange(num_elements, 2 * num_elements)
168 | data2 = np.float32(np.random.random((num_elements, dim)))  # batch 2
169 | labels3 = np.arange(2 * num_elements, 3 * num_elements)
170 | data3 = np.float32(np.random.random((num_elements, dim)))  # batch 3
171 | 
172 | # Declaring index
173 | hnsw_index = hnswlib.Index(space='l2', dim=dim)
174 | 
175 | # Initiating index
176 | # max_elements - the maximum number of elements, should be known beforehand
177 | #     (probably will be made optional in the future)
178 | #
179 | # ef_construction - controls index search speed/build speed tradeoff
180 | # M - is tightly connected with internal dimensionality of the data
181 | #     strongly affects the memory consumption
182 | 
183 | # Enable replacing of deleted elements
184 | hnsw_index.init_index(max_elements=max_num_elements, ef_construction=200, M=16, allow_replace_deleted=True)
185 | 
186 | # Controlling the recall by setting ef:
187 | # higher ef leads to better accuracy, but slower search
188 | hnsw_index.set_ef(10)
189 | 
190 | # Set number of threads used during batch search/construction
191 | # By default using all available cores
192 | hnsw_index.set_num_threads(4)
193 | 
194 | # Add batch 1 and 2 data
195 | hnsw_index.add_items(data1, labels1)
196 | hnsw_index.add_items(data2, labels2)  # Note: maximum number of elements is reached
197 | 
198 | # Delete data of batch 2
199 | for label in labels2:
200 |     hnsw_index.mark_deleted(label)
201 | 
202 | # Replace deleted elements
203 | # Maximum number of elements is reached therefore we cannot add new items,
204 | # but we can replace the deleted ones by using replace_deleted=True
205 | hnsw_index.add_items(data3, labels3, replace_deleted=True)
206 | # hnsw_index contains the data of batch 1 and batch 3 only
207 | ```


--------------------------------------------------------------------------------
/examples/python/example.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | import hnswlib
 3 | import numpy as np
 4 | 
 5 | 
 6 | """
 7 | Example of index building, search and serialization/deserialization
 8 | """
 9 | 
10 | dim = 16
11 | num_elements = 10000
12 | 
13 | # Generating sample data
14 | data = np.float32(np.random.random((num_elements, dim)))
15 | 
16 | # We split the data in two batches:
17 | data1 = data[:num_elements // 2]
18 | data2 = data[num_elements // 2:]
19 | 
20 | # Declaring index
21 | p = hnswlib.Index(space='l2', dim=dim)  # possible options are l2, cosine or ip
22 | 
23 | # Initing index
24 | # max_elements - the maximum number of elements (capacity). Will throw an exception if exceeded
25 | # during insertion of an element.
26 | # The capacity can be increased by saving/loading the index, see below.
27 | #
28 | # ef_construction - controls index search speed/build speed tradeoff
29 | #
30 | # M - is tightly connected with internal dimensionality of the data. Strongly affects the memory consumption (~M)
31 | # Higher M leads to higher accuracy/run_time at fixed ef/efConstruction
32 | 
33 | p.init_index(max_elements=num_elements//2, ef_construction=100, M=16)
34 | 
35 | # Controlling the recall by setting ef:
36 | # higher ef leads to better accuracy, but slower search
37 | p.set_ef(10)
38 | 
39 | # Set number of threads used during batch search/construction
40 | # By default using all available cores
41 | p.set_num_threads(4)
42 | 
43 | print("Adding first batch of %d elements" % (len(data1)))
44 | p.add_items(data1)
45 | 
46 | # Query the elements for themselves and measure recall:
47 | labels, distances = p.knn_query(data1, k=1)
48 | print("Recall for the first batch:", np.mean(labels.reshape(-1) == np.arange(len(data1))), "\n")
49 | 
50 | # Serializing and deleting the index:
51 | index_path='first_half.bin'
52 | print("Saving index to '%s'" % index_path)
53 | p.save_index(index_path)
54 | del p
55 | 
56 | # Reiniting, loading the index
57 | p = hnswlib.Index(space='l2', dim=dim)  # the space can be changed - keeps the data, alters the distance function.
58 | 
59 | print("\nLoading index from 'first_half.bin'\n")
60 | 
61 | # Increase the total capacity (max_elements), so that it will handle the new data
62 | p.load_index("first_half.bin", max_elements = num_elements)
63 | 
64 | print("Adding the second batch of %d elements" % (len(data2)))
65 | p.add_items(data2)
66 | 
67 | # Query the elements for themselves and measure recall:
68 | labels, distances = p.knn_query(data, k=1)
69 | print("Recall for two batches:", np.mean(labels.reshape(-1) == np.arange(len(data))), "\n")
70 | 
71 | os.remove("first_half.bin")
72 | 


--------------------------------------------------------------------------------
/examples/python/example_filter.py:
--------------------------------------------------------------------------------
 1 | import hnswlib
 2 | import numpy as np
 3 | 
 4 | 
 5 | """
 6 | Example of filtering elements when searching
 7 | """
 8 | 
 9 | dim = 16
10 | num_elements = 10000
11 | 
12 | # Generating sample data
13 | data = np.float32(np.random.random((num_elements, dim)))
14 | 
15 | # Declaring index
16 | hnsw_index = hnswlib.Index(space='l2', dim=dim)  # possible options are l2, cosine or ip
17 | 
18 | # Initiating index
19 | # max_elements - the maximum number of elements, should be known beforehand
20 | #     (probably will be made optional in the future)
21 | #
22 | # ef_construction - controls index search speed/build speed tradeoff
23 | # M - is tightly connected with internal dimensionality of the data
24 | #     strongly affects the memory consumption
25 | 
26 | hnsw_index.init_index(max_elements=num_elements, ef_construction=100, M=16)
27 | 
28 | # Controlling the recall by setting ef:
29 | # higher ef leads to better accuracy, but slower search
30 | hnsw_index.set_ef(10)
31 | 
32 | # Set number of threads used during batch search/construction
33 | # By default using all available cores
34 | hnsw_index.set_num_threads(4)
35 | 
36 | print("Adding %d elements" % (len(data)))
37 | # Added elements will have consecutive ids
38 | hnsw_index.add_items(data, ids=np.arange(num_elements))
39 | 
40 | print("Querying only even elements")
41 | # Define filter function that allows only even ids
42 | filter_function = lambda idx: idx%2 == 0
43 | # Query the elements for themselves and search only for even elements:
44 | # Warning: search with a filter works slow in python in multithreaded mode, therefore we set num_threads=1
45 | labels, distances = hnsw_index.knn_query(data, k=1, num_threads=1, filter=filter_function)
46 | # labels contain only elements with even id
47 | 


--------------------------------------------------------------------------------
/examples/python/example_replace_deleted.py:
--------------------------------------------------------------------------------
 1 | import hnswlib
 2 | import numpy as np
 3 | 
 4 | 
 5 | """
 6 | Example of replacing deleted elements with new ones
 7 | """
 8 | 
 9 | dim = 16
10 | num_elements = 1_000
11 | max_num_elements = 2 * num_elements
12 | 
13 | # Generating sample data
14 | labels1 = np.arange(0, num_elements)
15 | data1 = np.float32(np.random.random((num_elements, dim)))  # batch 1
16 | labels2 = np.arange(num_elements, 2 * num_elements)
17 | data2 = np.float32(np.random.random((num_elements, dim)))  # batch 2
18 | labels3 = np.arange(2 * num_elements, 3 * num_elements)
19 | data3 = np.float32(np.random.random((num_elements, dim)))  # batch 3
20 | 
21 | # Declaring index
22 | hnsw_index = hnswlib.Index(space='l2', dim=dim)
23 | 
24 | # Initiating index
25 | # max_elements - the maximum number of elements, should be known beforehand
26 | #     (probably will be made optional in the future)
27 | #
28 | # ef_construction - controls index search speed/build speed tradeoff
29 | # M - is tightly connected with internal dimensionality of the data
30 | #     strongly affects the memory consumption
31 | 
32 | # Enable replacing of deleted elements
33 | hnsw_index.init_index(max_elements=max_num_elements, ef_construction=200, M=16, allow_replace_deleted=True)
34 | 
35 | # Controlling the recall by setting ef:
36 | # higher ef leads to better accuracy, but slower search
37 | hnsw_index.set_ef(10)
38 | 
39 | # Set number of threads used during batch search/construction
40 | # By default using all available cores
41 | hnsw_index.set_num_threads(4)
42 | 
43 | # Add batch 1 and 2 data
44 | hnsw_index.add_items(data1, labels1)
45 | hnsw_index.add_items(data2, labels2)  # Note: maximum number of elements is reached
46 | 
47 | # Delete data of batch 2
48 | for label in labels2:
49 |     hnsw_index.mark_deleted(label)
50 | 
51 | # Replace deleted elements
52 | # Maximum number of elements is reached therefore we cannot add new items,
53 | # but we can replace the deleted ones by using replace_deleted=True
54 | hnsw_index.add_items(data3, labels3, replace_deleted=True)
55 | # hnsw_index contains the data of batch 1 and batch 3 only
56 | 


--------------------------------------------------------------------------------
/examples/python/example_search.py:
--------------------------------------------------------------------------------
 1 | import hnswlib
 2 | import numpy as np
 3 | import pickle
 4 | 
 5 | 
 6 | """
 7 | Example of search
 8 | """
 9 | 
10 | dim = 128
11 | num_elements = 10000
12 | 
13 | # Generating sample data
14 | data = np.float32(np.random.random((num_elements, dim)))
15 | ids = np.arange(num_elements)
16 | 
17 | # Declaring index
18 | p = hnswlib.Index(space='l2', dim=dim)  # possible options are l2, cosine or ip
19 | 
20 | # Initializing index - the maximum number of elements should be known beforehand
21 | p.init_index(max_elements=num_elements, ef_construction=200, M=16)
22 | 
23 | # Element insertion (can be called several times):
24 | p.add_items(data, ids)
25 | 
26 | # Controlling the recall by setting ef:
27 | p.set_ef(50)  # ef should always be > k
28 | 
29 | # Query dataset, k - number of the closest elements (returns 2 numpy arrays)
30 | labels, distances = p.knn_query(data, k=1)
31 | 
32 | # Index objects support pickling
33 | # WARNING: serialization via pickle.dumps(p) or p.__getstate__() is NOT thread-safe with p.add_items method!
34 | # Note: ef parameter is included in serialization; random number generator is initialized with random_seed on Index load
35 | p_copy = pickle.loads(pickle.dumps(p))  # creates a copy of index p using pickle round-trip
36 | 
37 | ### Index parameters are exposed as class properties:
38 | print(f"Parameters passed to constructor:  space={p_copy.space}, dim={p_copy.dim}")
39 | print(f"Index construction: M={p_copy.M}, ef_construction={p_copy.ef_construction}")
40 | print(f"Index size is {p_copy.element_count} and index capacity is {p_copy.max_elements}")
41 | print(f"Search speed/quality trade-off parameter: ef={p_copy.ef}")
42 | 


--------------------------------------------------------------------------------
/examples/python/example_serialization.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | 
 3 | import hnswlib
 4 | import numpy as np
 5 | 
 6 | 
 7 | """
 8 | Example of serialization/deserialization
 9 | """
10 | 
11 | dim = 16
12 | num_elements = 10000
13 | 
14 | # Generating sample data
15 | data = np.float32(np.random.random((num_elements, dim)))
16 | 
17 | # We split the data in two batches:
18 | data1 = data[:num_elements // 2]
19 | data2 = data[num_elements // 2:]
20 | 
21 | # Declaring index
22 | p = hnswlib.Index(space='l2', dim=dim)  # possible options are l2, cosine or ip
23 | 
24 | # Initializing index
25 | # max_elements - the maximum number of elements (capacity). Will throw an exception if exceeded
26 | # during insertion of an element.
27 | # The capacity can be increased by saving/loading the index, see below.
28 | #
29 | # ef_construction - controls index search speed/build speed tradeoff
30 | #
31 | # M - is tightly connected with internal dimensionality of the data. Strongly affects memory consumption (~M)
32 | # Higher M leads to higher accuracy/run_time at fixed ef/efConstruction
33 | 
34 | p.init_index(max_elements=num_elements//2, ef_construction=100, M=16)
35 | 
36 | # Controlling the recall by setting ef:
37 | # higher ef leads to better accuracy, but slower search
38 | p.set_ef(10)
39 | 
40 | # Set number of threads used during batch search/construction
41 | # By default using all available cores
42 | p.set_num_threads(4)
43 | 
44 | print("Adding first batch of %d elements" % (len(data1)))
45 | p.add_items(data1)
46 | 
47 | # Query the elements for themselves and measure recall:
48 | labels, distances = p.knn_query(data1, k=1)
49 | print("Recall for the first batch:", np.mean(labels.reshape(-1) == np.arange(len(data1))), "\n")
50 | 
51 | # Serializing and deleting the index:
52 | index_path='first_half.bin'
53 | print("Saving index to '%s'" % index_path)
54 | p.save_index("first_half.bin")
55 | del p
56 | 
57 | # Re-initializing, loading the index
58 | p = hnswlib.Index(space='l2', dim=dim)  # the space can be changed - keeps the data, alters the distance function.
59 | 
60 | print("\nLoading index from 'first_half.bin'\n")
61 | 
62 | # Increase the total capacity (max_elements), so that it will handle the new data
63 | p.load_index("first_half.bin", max_elements = num_elements)
64 | 
65 | print("Adding the second batch of %d elements" % (len(data2)))
66 | p.add_items(data2)
67 | 
68 | # Query the elements for themselves and measure recall:
69 | labels, distances = p.knn_query(data, k=1)
70 | print("Recall for two batches:", np.mean(labels.reshape(-1) == np.arange(len(data))), "\n")
71 | 
72 | os.remove("first_half.bin")
73 | 


--------------------------------------------------------------------------------
/examples/python/pyw_hnswlib.py:
--------------------------------------------------------------------------------
 1 | import hnswlib
 2 | import numpy as np
 3 | import threading
 4 | import pickle
 5 | 
 6 | 
 7 | """
 8 | Example of python wrapper for hnswlib that supports python objects as ids
 9 | """
10 | 
class Index():
    """Wrapper around hnswlib.Index that allows arbitrary (hashable) python
    objects to be used as element ids, by mapping them to consecutive integer
    labels internally."""

    def __init__(self, space, dim):
        self.index = hnswlib.Index(space, dim)
        self.lock = threading.Lock()
        self.dict_labels = {}  # internal integer label -> user-supplied id
        self.cur_ind = 0       # next free internal label

    def init_index(self, max_elements, ef_construction=200, M=16):
        self.index.init_index(max_elements=max_elements, ef_construction=ef_construction, M=M)

    def add_items(self, data, ids=None):
        """Add elements; `ids` may hold arbitrary python objects.

        Raises:
            ValueError: when data and ids have different lengths (the original
                used a bare assert, which is stripped under ``python -O``).
        """
        if ids is not None and len(data) != len(ids):
            raise ValueError("data and ids must have the same length")
        num_added = len(data)
        with self.lock:
            # Reserve a contiguous range of internal labels and record the
            # id mapping while still holding the lock, so concurrent callers
            # cannot interleave their updates to dict_labels (previously the
            # dict was mutated outside the lock).
            start = self.cur_ind
            self.cur_ind += num_added
            int_labels = list(range(start, start + num_added))
            # Without explicit ids, each element's id is its internal label.
            external_ids = ids if ids is not None else int_labels
            for int_label, external_id in zip(int_labels, external_ids):
                self.dict_labels[int_label] = external_id
        self.index.add_items(data=data, ids=np.asarray(int_labels))

    def set_ef(self, ef):
        # Search speed/quality trade-off parameter.
        self.index.set_ef(ef)

    def load_index(self, path):
        # The label mapping is restored from the sidecar file "<path>.pkl".
        self.index.load_index(path)
        with open(path + ".pkl", "rb") as f:
            self.cur_ind, self.dict_labels = pickle.load(f)

    def save_index(self, path):
        # The label mapping is stored next to the index in "<path>.pkl".
        self.index.save_index(path)
        with open(path + ".pkl", "wb") as f:
            pickle.dump((self.cur_ind, self.dict_labels), f)

    def set_num_threads(self, num_threads):
        self.index.set_num_threads(num_threads)

    def knn_query(self, data, k=1):
        """Return (labels, distances); labels hold the original user ids."""
        labels_int, distances = self.index.knn_query(data=data, k=k)
        labels = [[self.dict_labels[l] for l in li] for li in labels_int]
        return labels, distances
66 | 


--------------------------------------------------------------------------------
/hnswlib/bruteforce.h:
--------------------------------------------------------------------------------
  1 | #pragma once
  2 | #include <unordered_map>
  3 | #include <fstream>
  4 | #include <mutex>
  5 | #include <algorithm>
  6 | #include <assert.h>
  7 | 
  8 | namespace hnswlib {
  9 | template<typename dist_t>
 10 | class BruteforceSearch : public AlgorithmInterface<dist_t> {
 11 |  public:
 12 |     char *data_;
 13 |     size_t maxelements_;
 14 |     size_t cur_element_count;
 15 |     size_t size_per_element_;
 16 | 
 17 |     size_t data_size_;
 18 |     DISTFUNC <dist_t> fstdistfunc_;
 19 |     void *dist_func_param_;
 20 |     std::mutex index_lock;
 21 | 
 22 |     std::unordered_map<labeltype, size_t > dict_external_to_internal;
 23 | 
 24 | 
 25 |     BruteforceSearch(SpaceInterface <dist_t> *s)
 26 |         : data_(nullptr),
 27 |             maxelements_(0),
 28 |             cur_element_count(0),
 29 |             size_per_element_(0),
 30 |             data_size_(0),
 31 |             dist_func_param_(nullptr) {
 32 |     }
 33 | 
 34 | 
 35 |     BruteforceSearch(SpaceInterface<dist_t> *s, const std::string &location)
 36 |         : data_(nullptr),
 37 |             maxelements_(0),
 38 |             cur_element_count(0),
 39 |             size_per_element_(0),
 40 |             data_size_(0),
 41 |             dist_func_param_(nullptr) {
 42 |         loadIndex(location, s);
 43 |     }
 44 | 
 45 | 
 46 |     BruteforceSearch(SpaceInterface <dist_t> *s, size_t maxElements) {
 47 |         maxelements_ = maxElements;
 48 |         data_size_ = s->get_data_size();
 49 |         fstdistfunc_ = s->get_dist_func();
 50 |         dist_func_param_ = s->get_dist_func_param();
 51 |         size_per_element_ = data_size_ + sizeof(labeltype);
 52 |         data_ = (char *) malloc(maxElements * size_per_element_);
 53 |         if (data_ == nullptr)
 54 |             throw std::runtime_error("Not enough memory: BruteforceSearch failed to allocate data");
 55 |         cur_element_count = 0;
 56 |     }
 57 | 
 58 | 
    // Releases the flat storage buffer. free(nullptr) is a no-op, so this is
    // also safe for a default-constructed (empty) index.
    ~BruteforceSearch() {
        free(data_);
    }
 62 | 
 63 | 
 64 |     void addPoint(const void *datapoint, labeltype label, bool replace_deleted = false) {
 65 |         int idx;
 66 |         {
 67 |             std::unique_lock<std::mutex> lock(index_lock);
 68 | 
 69 |             auto search = dict_external_to_internal.find(label);
 70 |             if (search != dict_external_to_internal.end()) {
 71 |                 idx = search->second;
 72 |             } else {
 73 |                 if (cur_element_count >= maxelements_) {
 74 |                     throw std::runtime_error("The number of elements exceeds the specified limit\n");
 75 |                 }
 76 |                 idx = cur_element_count;
 77 |                 dict_external_to_internal[label] = idx;
 78 |                 cur_element_count++;
 79 |             }
 80 |         }
 81 |         memcpy(data_ + size_per_element_ * idx + data_size_, &label, sizeof(labeltype));
 82 |         memcpy(data_ + size_per_element_ * idx, datapoint, data_size_);
 83 |     }
 84 | 
 85 | 
 86 |     void removePoint(labeltype cur_external) {
 87 |         std::unique_lock<std::mutex> lock(index_lock);
 88 | 
 89 |         auto found = dict_external_to_internal.find(cur_external);
 90 |         if (found == dict_external_to_internal.end()) {
 91 |             return;
 92 |         }
 93 | 
 94 |         dict_external_to_internal.erase(found);
 95 | 
 96 |         size_t cur_c = found->second;
 97 |         labeltype label = *((labeltype*)(data_ + size_per_element_ * (cur_element_count-1) + data_size_));
 98 |         dict_external_to_internal[label] = cur_c;
 99 |         memcpy(data_ + size_per_element_ * cur_c,
100 |                 data_ + size_per_element_ * (cur_element_count-1),
101 |                 data_size_+sizeof(labeltype));
102 |         cur_element_count--;
103 |     }
104 | 
105 | 
106 |     std::priority_queue<std::pair<dist_t, labeltype >>
107 |     searchKnn(const void *query_data, size_t k, BaseFilterFunctor* isIdAllowed = nullptr) const {
108 |         assert(k <= cur_element_count);
109 |         std::priority_queue<std::pair<dist_t, labeltype >> topResults;
110 |         if (cur_element_count == 0) return topResults;
111 |         for (int i = 0; i < k; i++) {
112 |             dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_);
113 |             labeltype label = *((labeltype*) (data_ + size_per_element_ * i + data_size_));
114 |             if ((!isIdAllowed) || (*isIdAllowed)(label)) {
115 |                 topResults.emplace(dist, label);
116 |             }
117 |         }
118 |         dist_t lastdist = topResults.empty() ? std::numeric_limits<dist_t>::max() : topResults.top().first;
119 |         for (int i = k; i < cur_element_count; i++) {
120 |             dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_);
121 |             if (dist <= lastdist) {
122 |                 labeltype label = *((labeltype *) (data_ + size_per_element_ * i + data_size_));
123 |                 if ((!isIdAllowed) || (*isIdAllowed)(label)) {
124 |                     topResults.emplace(dist, label);
125 |                 }
126 |                 if (topResults.size() > k)
127 |                     topResults.pop();
128 | 
129 |                 if (!topResults.empty()) {
130 |                     lastdist = topResults.top().first;
131 |                 }
132 |             }
133 |         }
134 |         return topResults;
135 |     }
136 | 
137 | 
138 |     void saveIndex(const std::string &location) {
139 |         std::ofstream output(location, std::ios::binary);
140 |         std::streampos position;
141 | 
142 |         writeBinaryPOD(output, maxelements_);
143 |         writeBinaryPOD(output, size_per_element_);
144 |         writeBinaryPOD(output, cur_element_count);
145 | 
146 |         output.write(data_, maxelements_ * size_per_element_);
147 | 
148 |         output.close();
149 |     }
150 | 
151 | 
152 |     void loadIndex(const std::string &location, SpaceInterface<dist_t> *s) {
153 |         std::ifstream input(location, std::ios::binary);
154 |         std::streampos position;
155 | 
156 |         readBinaryPOD(input, maxelements_);
157 |         readBinaryPOD(input, size_per_element_);
158 |         readBinaryPOD(input, cur_element_count);
159 | 
160 |         data_size_ = s->get_data_size();
161 |         fstdistfunc_ = s->get_dist_func();
162 |         dist_func_param_ = s->get_dist_func_param();
163 |         size_per_element_ = data_size_ + sizeof(labeltype);
164 |         data_ = (char *) malloc(maxelements_ * size_per_element_);
165 |         if (data_ == nullptr)
166 |             throw std::runtime_error("Not enough memory: loadIndex failed to allocate data");
167 | 
168 |         input.read(data_, maxelements_ * size_per_element_);
169 | 
170 |         input.close();
171 |     }
172 | };
173 | }  // namespace hnswlib
174 | 


--------------------------------------------------------------------------------
/hnswlib/hnswlib.h:
--------------------------------------------------------------------------------
  1 | #pragma once
  2 | 
  3 | // https://github.com/nmslib/hnswlib/pull/508
  4 | // This allows others to provide their own error stream (e.g. RcppHNSW)
  5 | #ifndef HNSWLIB_ERR_OVERRIDE
  6 |   #define HNSWERR std::cerr
  7 | #else
  8 |   #define HNSWERR HNSWLIB_ERR_OVERRIDE
  9 | #endif
 10 | 
 11 | #ifndef NO_MANUAL_VECTORIZATION
 12 | #if (defined(__SSE__) || _M_IX86_FP > 0 || defined(_M_AMD64) || defined(_M_X64))
 13 | #define USE_SSE
 14 | #ifdef __AVX__
 15 | #define USE_AVX
 16 | #ifdef __AVX512F__
 17 | #define USE_AVX512
 18 | #endif
 19 | #endif
 20 | #endif
 21 | #endif
 22 | 
 23 | #if defined(USE_AVX) || defined(USE_SSE)
 24 | #ifdef _MSC_VER
 25 | #include <intrin.h>
 26 | #include <stdexcept>
// MSVC: delegate to the compiler intrinsics.
static void cpuid(int32_t out[4], int32_t eax, int32_t ecx) {
    __cpuidex(out, eax, ecx);
}
// Reads the extended control register selected by `x` (XCR0 when x == 0).
static __int64 xgetbv(unsigned int x) {
    return _xgetbv(x);
}
#else
#include <x86intrin.h>
#include <cpuid.h>
#include <stdint.h>
// GCC/Clang: same interface as the MSVC shim above, via __cpuid_count.
static void cpuid(int32_t cpuInfo[4], int32_t eax, int32_t ecx) {
    __cpuid_count(eax, ecx, cpuInfo[0], cpuInfo[1], cpuInfo[2], cpuInfo[3]);
}
// Reads an extended control register via the raw xgetbv instruction
// (EDX:EAX combined into a 64-bit value).
static uint64_t xgetbv(unsigned int index) {
    uint32_t eax, edx;
    __asm__ __volatile__("xgetbv" : "=a"(eax), "=d"(edx) : "c"(index));
    return ((uint64_t)edx << 32) | eax;
}
 45 | #endif
 46 | 
 47 | #if defined(USE_AVX512)
 48 | #include <immintrin.h>
 49 | #endif
 50 | 
 51 | #if defined(__GNUC__)
 52 | #define PORTABLE_ALIGN32 __attribute__((aligned(32)))
 53 | #define PORTABLE_ALIGN64 __attribute__((aligned(64)))
 54 | #else
 55 | #define PORTABLE_ALIGN32 __declspec(align(32))
 56 | #define PORTABLE_ALIGN64 __declspec(align(64))
 57 | #endif
 58 | 
 59 | // Adapted from https://github.com/Mysticial/FeatureDetector
 60 | #define _XCR_XFEATURE_ENABLED_MASK  0
 61 | 
 62 | static bool AVXCapable() {
 63 |     int cpuInfo[4];
 64 | 
 65 |     // CPU support
 66 |     cpuid(cpuInfo, 0, 0);
 67 |     int nIds = cpuInfo[0];
 68 | 
 69 |     bool HW_AVX = false;
 70 |     if (nIds >= 0x00000001) {
 71 |         cpuid(cpuInfo, 0x00000001, 0);
 72 |         HW_AVX = (cpuInfo[2] & ((int)1 << 28)) != 0;
 73 |     }
 74 | 
 75 |     // OS support
 76 |     cpuid(cpuInfo, 1, 0);
 77 | 
 78 |     bool osUsesXSAVE_XRSTORE = (cpuInfo[2] & (1 << 27)) != 0;
 79 |     bool cpuAVXSuport = (cpuInfo[2] & (1 << 28)) != 0;
 80 | 
 81 |     bool avxSupported = false;
 82 |     if (osUsesXSAVE_XRSTORE && cpuAVXSuport) {
 83 |         uint64_t xcrFeatureMask = xgetbv(_XCR_XFEATURE_ENABLED_MASK);
 84 |         avxSupported = (xcrFeatureMask & 0x6) == 0x6;
 85 |     }
 86 |     return HW_AVX && avxSupported;
 87 | }
 88 | 
 89 | static bool AVX512Capable() {
 90 |     if (!AVXCapable()) return false;
 91 | 
 92 |     int cpuInfo[4];
 93 | 
 94 |     // CPU support
 95 |     cpuid(cpuInfo, 0, 0);
 96 |     int nIds = cpuInfo[0];
 97 | 
 98 |     bool HW_AVX512F = false;
 99 |     if (nIds >= 0x00000007) {  //  AVX512 Foundation
100 |         cpuid(cpuInfo, 0x00000007, 0);
101 |         HW_AVX512F = (cpuInfo[1] & ((int)1 << 16)) != 0;
102 |     }
103 | 
104 |     // OS support
105 |     cpuid(cpuInfo, 1, 0);
106 | 
107 |     bool osUsesXSAVE_XRSTORE = (cpuInfo[2] & (1 << 27)) != 0;
108 |     bool cpuAVXSuport = (cpuInfo[2] & (1 << 28)) != 0;
109 | 
110 |     bool avx512Supported = false;
111 |     if (osUsesXSAVE_XRSTORE && cpuAVXSuport) {
112 |         uint64_t xcrFeatureMask = xgetbv(_XCR_XFEATURE_ENABLED_MASK);
113 |         avx512Supported = (xcrFeatureMask & 0xe6) == 0xe6;
114 |     }
115 |     return HW_AVX512F && avx512Supported;
116 | }
117 | #endif
118 | 
119 | #include <queue>
120 | #include <vector>
121 | #include <iostream>
122 | #include <string.h>
123 | 
124 | namespace hnswlib {
// External identifier type users attach to stored elements.
typedef size_t labeltype;

// Filter predicate applied to candidate labels during search; only labels for
// which operator() returns true may appear in results. The default accepts
// everything. This can be extended to store state for filtering
// (e.g. from a std::set).
class BaseFilterFunctor {
 public:
    // Return true if `id` may be included in the search results.
    virtual bool operator()(hnswlib::labeltype id) { return true; }
    virtual ~BaseFilterFunctor() {};
};
133 | 
// Hook interface that lets a search customize its stopping rule and result
// bookkeeping; see MultiVectorSearchStopCondition and
// EpsilonSearchStopCondition in stop_condition.h for implementations.
template<typename dist_t>
class BaseSearchStopCondition {
 public:
    // Notify that (label, datapoint, dist) was added to the result set.
    virtual void add_point_to_result(labeltype label, const void *datapoint, dist_t dist) = 0;

    // Notify that a previously added point was evicted from the result set.
    virtual void remove_point_from_result(labeltype label, const void *datapoint, dist_t dist) = 0;

    // True when the search should terminate, given the next candidate's
    // distance and the current worst kept distance.
    virtual bool should_stop_search(dist_t candidate_dist, dist_t lowerBound) = 0;

    // True when the candidate is worth inserting into the result set.
    virtual bool should_consider_candidate(dist_t candidate_dist, dist_t lowerBound) = 0;

    // True while the result set holds more entries than allowed.
    virtual bool should_remove_extra() = 0;

    // Final pass over collected candidates (e.g. trimming excess entries).
    virtual void filter_results(std::vector<std::pair<dist_t, labeltype >> &candidates) = 0;

    virtual ~BaseSearchStopCondition() {}
};
151 | 
// Comparator ordering pairs by their first component, descending — the
// std::greater analogue on .first, usable as the Compare parameter of
// std::priority_queue to obtain a min-heap on distance.
template <typename T>
class pairGreater {
 public:
    // const: comparators must be callable through const references/copies
    // held by standard containers and algorithms.
    bool operator()(const T& p1, const T& p2) const {
        return p1.first > p2.first;
    }
};
159 | 
// Serializes a trivially-copyable value to `out` as raw bytes.
template<typename T>
static void writeBinaryPOD(std::ostream &out, const T &podRef) {
    const char *bytes = reinterpret_cast<const char *>(&podRef);
    out.write(bytes, sizeof(T));
}
164 | 
// Deserializes a trivially-copyable value from `in` as raw bytes.
template<typename T>
static void readBinaryPOD(std::istream &in, T &podRef) {
    char *dest = reinterpret_cast<char *>(&podRef);
    in.read(dest, sizeof(T));
}
169 | 
// Distance function signature: (vector A, vector B, parameter block) ->
// distance. The parameter block is the opaque pointer returned by
// SpaceInterface::get_dist_func_param().
template<typename MTYPE>
using DISTFUNC = MTYPE(*)(const void *, const void *, const void *);

// Describes a metric space: per-element byte size, the distance function,
// and the parameter forwarded to every distance call.
template<typename MTYPE>
class SpaceInterface {
 public:
    // virtual void search(void *);
    // Size in bytes of one stored element.
    virtual size_t get_data_size() = 0;

    virtual DISTFUNC<MTYPE> get_dist_func() = 0;

    // Opaque argument passed as the third parameter of the distance function.
    virtual void *get_dist_func_param() = 0;

    virtual ~SpaceInterface() {}
};
185 | 
// Common interface implemented by the index types in this library.
template<typename dist_t>
class AlgorithmInterface {
 public:
    // Insert (or update) the vector stored at `datapoint` under `label`.
    virtual void addPoint(const void *datapoint, labeltype label, bool replace_deleted = false) = 0;

    // Return up to k nearest neighbors as a max-heap (furthest kept result on
    // top); `isIdAllowed`, when provided, filters which labels may be returned.
    virtual std::priority_queue<std::pair<dist_t, labeltype>>
        searchKnn(const void*, size_t, BaseFilterFunctor* isIdAllowed = nullptr) const = 0;

    // Return k nearest neighbors in the order of closer first
    virtual std::vector<std::pair<dist_t, labeltype>>
        searchKnnCloserFirst(const void* query_data, size_t k, BaseFilterFunctor* isIdAllowed = nullptr) const;

    // Persist the index to the given file path.
    virtual void saveIndex(const std::string &location) = 0;
    virtual ~AlgorithmInterface(){
    }
};
202 | 
203 | template<typename dist_t>
204 | std::vector<std::pair<dist_t, labeltype>>
205 | AlgorithmInterface<dist_t>::searchKnnCloserFirst(const void* query_data, size_t k,
206 |                                                  BaseFilterFunctor* isIdAllowed) const {
207 |     std::vector<std::pair<dist_t, labeltype>> result;
208 | 
209 |     // here searchKnn returns the result in the order of further first
210 |     auto ret = searchKnn(query_data, k, isIdAllowed);
211 |     {
212 |         size_t sz = ret.size();
213 |         result.resize(sz);
214 |         while (!ret.empty()) {
215 |             result[--sz] = ret.top();
216 |             ret.pop();
217 |         }
218 |     }
219 | 
220 |     return result;
221 | }
222 | }  // namespace hnswlib
223 | 
224 | #include "space_l2.h"
225 | #include "space_ip.h"
226 | #include "stop_condition.h"
227 | #include "bruteforce.h"
228 | #include "hnswalg.h"
229 | 


--------------------------------------------------------------------------------
/hnswlib/space_l2.h:
--------------------------------------------------------------------------------
  1 | #pragma once
  2 | #include "hnswlib.h"
  3 | 
  4 | namespace hnswlib {
  5 | 
  6 | static float
  7 | L2Sqr(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
  8 |     float *pVect1 = (float *) pVect1v;
  9 |     float *pVect2 = (float *) pVect2v;
 10 |     size_t qty = *((size_t *) qty_ptr);
 11 | 
 12 |     float res = 0;
 13 |     for (size_t i = 0; i < qty; i++) {
 14 |         float t = *pVect1 - *pVect2;
 15 |         pVect1++;
 16 |         pVect2++;
 17 |         res += t * t;
 18 |     }
 19 |     return (res);
 20 | }
 21 | 
 22 | #if defined(USE_AVX512)
 23 | 
// AVX-512 squared L2 distance, 16 floats per iteration. Assumes the
// dimension is a multiple of 16 (the residual variants below handle other
// dimensions). Favor using AVX512 if available.
static float
L2SqrSIMD16ExtAVX512(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
    float *pVect1 = (float *) pVect1v;
    float *pVect2 = (float *) pVect2v;
    size_t qty = *((size_t *) qty_ptr);
    float PORTABLE_ALIGN64 TmpRes[16];
    size_t qty16 = qty >> 4;

    const float *pEnd1 = pVect1 + (qty16 << 4);

    __m512 diff, v1, v2;
    __m512 sum = _mm512_set1_ps(0);

    while (pVect1 < pEnd1) {
        v1 = _mm512_loadu_ps(pVect1);
        pVect1 += 16;
        v2 = _mm512_loadu_ps(pVect2);
        pVect2 += 16;
        diff = _mm512_sub_ps(v1, v2);
        // sum = _mm512_fmadd_ps(diff, diff, sum);
        sum = _mm512_add_ps(sum, _mm512_mul_ps(diff, diff));
    }

    // Horizontal reduction of the 16 partial sums.
    _mm512_store_ps(TmpRes, sum);
    float res = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] +
            TmpRes[7] + TmpRes[8] + TmpRes[9] + TmpRes[10] + TmpRes[11] + TmpRes[12] +
            TmpRes[13] + TmpRes[14] + TmpRes[15];

    return (res);
}
 55 | #endif
 56 | 
 57 | #if defined(USE_AVX)
 58 | 
// AVX squared L2 distance, 16 floats per iteration (two 8-wide steps).
// Assumes the dimension is a multiple of 16. Favor using AVX if available.
static float
L2SqrSIMD16ExtAVX(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
    float *pVect1 = (float *) pVect1v;
    float *pVect2 = (float *) pVect2v;
    size_t qty = *((size_t *) qty_ptr);
    float PORTABLE_ALIGN32 TmpRes[8];
    size_t qty16 = qty >> 4;

    const float *pEnd1 = pVect1 + (qty16 << 4);

    __m256 diff, v1, v2;
    __m256 sum = _mm256_set1_ps(0);

    while (pVect1 < pEnd1) {
        v1 = _mm256_loadu_ps(pVect1);
        pVect1 += 8;
        v2 = _mm256_loadu_ps(pVect2);
        pVect2 += 8;
        diff = _mm256_sub_ps(v1, v2);
        sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff));

        v1 = _mm256_loadu_ps(pVect1);
        pVect1 += 8;
        v2 = _mm256_loadu_ps(pVect2);
        pVect2 += 8;
        diff = _mm256_sub_ps(v1, v2);
        sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff));
    }

    // Horizontal reduction of the 8 partial sums.
    _mm256_store_ps(TmpRes, sum);
    return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7];
}
 92 | 
 93 | #endif
 94 | 
 95 | #if defined(USE_SSE)
 96 | 
// SSE squared L2 distance, 16 floats per iteration (loop body unrolled into
// four 4-wide steps). Assumes the dimension is a multiple of 16.
static float
L2SqrSIMD16ExtSSE(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
    float *pVect1 = (float *) pVect1v;
    float *pVect2 = (float *) pVect2v;
    size_t qty = *((size_t *) qty_ptr);
    float PORTABLE_ALIGN32 TmpRes[8];
    size_t qty16 = qty >> 4;

    const float *pEnd1 = pVect1 + (qty16 << 4);

    __m128 diff, v1, v2;
    __m128 sum = _mm_set1_ps(0);

    while (pVect1 < pEnd1) {
        //_mm_prefetch((char*)(pVect2 + 16), _MM_HINT_T0);
        v1 = _mm_loadu_ps(pVect1);
        pVect1 += 4;
        v2 = _mm_loadu_ps(pVect2);
        pVect2 += 4;
        diff = _mm_sub_ps(v1, v2);
        sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff));

        v1 = _mm_loadu_ps(pVect1);
        pVect1 += 4;
        v2 = _mm_loadu_ps(pVect2);
        pVect2 += 4;
        diff = _mm_sub_ps(v1, v2);
        sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff));

        v1 = _mm_loadu_ps(pVect1);
        pVect1 += 4;
        v2 = _mm_loadu_ps(pVect2);
        pVect2 += 4;
        diff = _mm_sub_ps(v1, v2);
        sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff));

        v1 = _mm_loadu_ps(pVect1);
        pVect1 += 4;
        v2 = _mm_loadu_ps(pVect2);
        pVect2 += 4;
        diff = _mm_sub_ps(v1, v2);
        sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff));
    }

    // Horizontal reduction of the 4 partial sums.
    _mm_store_ps(TmpRes, sum);
    return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3];
}
144 | #endif
145 | 
146 | #if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512)
// Dispatch pointer for the 16-wide kernel. Defaults to SSE; the space
// constructors (L2Space, MultiVectorL2Space) swap in the AVX/AVX-512
// variants when the CPU and OS support them.
static DISTFUNC<float> L2SqrSIMD16Ext = L2SqrSIMD16ExtSSE;

// Squared L2 for dimensions that are not a multiple of 16: SIMD over the
// largest multiple-of-16 prefix, scalar L2Sqr over the remaining tail.
static float
L2SqrSIMD16ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
    size_t qty = *((size_t *) qty_ptr);
    size_t qty16 = qty >> 4 << 4;
    float res = L2SqrSIMD16Ext(pVect1v, pVect2v, &qty16);
    float *pVect1 = (float *) pVect1v + qty16;
    float *pVect2 = (float *) pVect2v + qty16;

    size_t qty_left = qty - qty16;
    float res_tail = L2Sqr(pVect1, pVect2, &qty_left);
    return (res + res_tail);
}
161 | #endif
162 | 
163 | 
164 | #if defined(USE_SSE)
// SSE squared L2 distance, 4 floats per iteration. Assumes the dimension is
// a multiple of 4.
static float
L2SqrSIMD4Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
    float PORTABLE_ALIGN32 TmpRes[8];
    float *pVect1 = (float *) pVect1v;
    float *pVect2 = (float *) pVect2v;
    size_t qty = *((size_t *) qty_ptr);


    size_t qty4 = qty >> 2;

    const float *pEnd1 = pVect1 + (qty4 << 2);

    __m128 diff, v1, v2;
    __m128 sum = _mm_set1_ps(0);

    while (pVect1 < pEnd1) {
        v1 = _mm_loadu_ps(pVect1);
        pVect1 += 4;
        v2 = _mm_loadu_ps(pVect2);
        pVect2 += 4;
        diff = _mm_sub_ps(v1, v2);
        sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff));
    }
    // Horizontal reduction of the 4 partial sums.
    _mm_store_ps(TmpRes, sum);
    return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3];
}
191 | 
// Squared L2 for dimensions that are not a multiple of 4: SIMD over the
// largest multiple-of-4 prefix, scalar L2Sqr over the remaining tail.
static float
L2SqrSIMD4ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
    size_t qty = *((size_t *) qty_ptr);
    size_t qty4 = qty >> 2 << 2;

    float res = L2SqrSIMD4Ext(pVect1v, pVect2v, &qty4);
    size_t qty_left = qty - qty4;

    float *pVect1 = (float *) pVect1v + qty4;
    float *pVect2 = (float *) pVect2v + qty4;
    float res_tail = L2Sqr(pVect1, pVect2, &qty_left);

    return (res + res_tail);
}
206 | #endif
207 | 
// Squared-L2 (Euclidean) space over float vectors. The fastest distance
// kernel for the given dimension is selected once, at construction time.
class L2Space : public SpaceInterface<float> {
    DISTFUNC<float> fstdistfunc_;  // selected distance kernel
    size_t data_size_;             // bytes per vector: dim * sizeof(float)
    size_t dim_;                   // dimensionality, handed to the kernel via get_dist_func_param()

 public:
    L2Space(size_t dim) {
        // Scalar fallback; overridden below when SIMD applies.
        fstdistfunc_ = L2Sqr;
#if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512)
    #if defined(USE_AVX512)
        // Upgrade the shared 16-wide kernel to the best supported ISA.
        if (AVX512Capable())
            L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX512;
        else if (AVXCapable())
            L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX;
    #elif defined(USE_AVX)
        if (AVXCapable())
            L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX;
    #endif

        // Pick the kernel by divisibility of the dimension; the residual
        // variants combine a SIMD prefix with a scalar tail.
        if (dim % 16 == 0)
            fstdistfunc_ = L2SqrSIMD16Ext;
        else if (dim % 4 == 0)
            fstdistfunc_ = L2SqrSIMD4Ext;
        else if (dim > 16)
            fstdistfunc_ = L2SqrSIMD16ExtResiduals;
        else if (dim > 4)
            fstdistfunc_ = L2SqrSIMD4ExtResiduals;
#endif
        dim_ = dim;
        data_size_ = dim * sizeof(float);
    }

    size_t get_data_size() {
        return data_size_;
    }

    DISTFUNC<float> get_dist_func() {
        return fstdistfunc_;
    }

    void *get_dist_func_param() {
        return &dim_;
    }

    ~L2Space() {}
};
254 | 
// Integer squared L2 distance over unsigned char vectors, processed in groups
// of four components. Only the first (qty / 4) * 4 dimensions are compared,
// so it is selected (by L2SpaceI) only when the dimension is divisible by 4.
static int
L2SqrI4x(const void *__restrict pVect1, const void *__restrict pVect2, const void *__restrict qty_ptr) {
    size_t qty = *((size_t *) qty_ptr);
    unsigned char *a = (unsigned char *) pVect1;
    unsigned char *b = (unsigned char *) pVect2;

    int res = 0;
    size_t groups = qty >> 2;
    for (size_t g = 0; g < groups; g++) {
        for (int j = 0; j < 4; j++) {
            int d = *a - *b;
            res += d * d;
            a++;
            b++;
        }
    }
    return res;
}
279 | 
// Integer squared L2 distance over unsigned char vectors, one component at a
// time; used for dimensions not divisible by 4.
static int L2SqrI(const void* __restrict pVect1, const void* __restrict pVect2, const void* __restrict qty_ptr) {
    size_t qty = *((size_t*)qty_ptr);
    const unsigned char* a = (const unsigned char*)pVect1;
    const unsigned char* b = (const unsigned char*)pVect2;

    int res = 0;
    for (size_t i = 0; i < qty; i++) {
        int d = (int)a[i] - (int)b[i];
        res += d * d;
    }
    return res;
}
293 | 
294 | class L2SpaceI : public SpaceInterface<int> {
295 |     DISTFUNC<int> fstdistfunc_;
296 |     size_t data_size_;
297 |     size_t dim_;
298 | 
299 |  public:
300 |     L2SpaceI(size_t dim) {
301 |         if (dim % 4 == 0) {
302 |             fstdistfunc_ = L2SqrI4x;
303 |         } else {
304 |             fstdistfunc_ = L2SqrI;
305 |         }
306 |         dim_ = dim;
307 |         data_size_ = dim * sizeof(unsigned char);
308 |     }
309 | 
310 |     size_t get_data_size() {
311 |         return data_size_;
312 |     }
313 | 
314 |     DISTFUNC<int> get_dist_func() {
315 |         return fstdistfunc_;
316 |     }
317 | 
318 |     void *get_dist_func_param() {
319 |         return &dim_;
320 |     }
321 | 
322 |     ~L2SpaceI() {}
323 | };
324 | }  // namespace hnswlib
325 | 


--------------------------------------------------------------------------------
/hnswlib/stop_condition.h:
--------------------------------------------------------------------------------
  1 | #pragma once
  2 | #include "space_l2.h"
  3 | #include "space_ip.h"
  4 | #include <assert.h>
  5 | #include <unordered_map>
  6 | 
  7 | namespace hnswlib {
  8 | 
// Space whose elements carry a document id alongside the vector data, so
// several vectors can belong to one document. Accessors read/write the id
// inside an element's raw byte block.
template<typename DOCIDTYPE>
class BaseMultiVectorSpace : public SpaceInterface<float> {
 public:
    // Read the document id stored inside `datapoint`.
    virtual DOCIDTYPE get_doc_id(const void *datapoint) = 0;

    // Write `doc_id` into `datapoint`.
    virtual void set_doc_id(void *datapoint, DOCIDTYPE doc_id) = 0;
};
 16 | 
 17 | 
// L2 space for multi-vector elements. Element layout: dim floats followed by
// one DOCIDTYPE document id; the distance kernels see only the float part.
template<typename DOCIDTYPE>
class MultiVectorL2Space : public BaseMultiVectorSpace<DOCIDTYPE> {
    DISTFUNC<float> fstdistfunc_;  // kernel selected for dim (same logic as L2Space)
    size_t data_size_;             // vector_size_ + sizeof(DOCIDTYPE)
    size_t vector_size_;           // bytes of the float vector part
    size_t dim_;                   // dimensionality, handed to the kernel

 public:
    MultiVectorL2Space(size_t dim) {
        // Scalar fallback; overridden below when SIMD applies.
        fstdistfunc_ = L2Sqr;
#if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512)
    #if defined(USE_AVX512)
        // Upgrade the shared 16-wide kernel to the best supported ISA.
        if (AVX512Capable())
            L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX512;
        else if (AVXCapable())
            L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX;
    #elif defined(USE_AVX)
        if (AVXCapable())
            L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX;
    #endif

        // Pick the kernel by divisibility of the dimension.
        if (dim % 16 == 0)
            fstdistfunc_ = L2SqrSIMD16Ext;
        else if (dim % 4 == 0)
            fstdistfunc_ = L2SqrSIMD4Ext;
        else if (dim > 16)
            fstdistfunc_ = L2SqrSIMD16ExtResiduals;
        else if (dim > 4)
            fstdistfunc_ = L2SqrSIMD4ExtResiduals;
#endif
        dim_ = dim;
        vector_size_ = dim * sizeof(float);
        data_size_ = vector_size_ + sizeof(DOCIDTYPE);
    }

    size_t get_data_size() override {
        return data_size_;
    }

    DISTFUNC<float> get_dist_func() override {
        return fstdistfunc_;
    }

    void *get_dist_func_param() override {
        return &dim_;
    }

    // The doc id is stored immediately after the float vector.
    DOCIDTYPE get_doc_id(const void *datapoint) override {
        return *(DOCIDTYPE *)((char *)datapoint + vector_size_);
    }

    void set_doc_id(void *datapoint, DOCIDTYPE doc_id) override {
        *(DOCIDTYPE*)((char *)datapoint + vector_size_) = doc_id;
    }

    ~MultiVectorL2Space() {}
};
 75 | 
 76 | 
 77 | template<typename DOCIDTYPE>
 78 | class MultiVectorInnerProductSpace : public BaseMultiVectorSpace<DOCIDTYPE> {
 79 |     DISTFUNC<float> fstdistfunc_;
 80 |     size_t data_size_;
 81 |     size_t vector_size_;
 82 |     size_t dim_;
 83 | 
 84 |  public:
 85 |     MultiVectorInnerProductSpace(size_t dim) {
 86 |         fstdistfunc_ = InnerProductDistance;
 87 | #if defined(USE_AVX) || defined(USE_SSE) || defined(USE_AVX512)
 88 |     #if defined(USE_AVX512)
 89 |         if (AVX512Capable()) {
 90 |             InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX512;
 91 |             InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtAVX512;
 92 |         } else if (AVXCapable()) {
 93 |             InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX;
 94 |             InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtAVX;
 95 |         }
 96 |     #elif defined(USE_AVX)
 97 |         if (AVXCapable()) {
 98 |             InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX;
 99 |             InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtAVX;
100 |         }
101 |     #endif
102 |     #if defined(USE_AVX)
103 |         if (AVXCapable()) {
104 |             InnerProductSIMD4Ext = InnerProductSIMD4ExtAVX;
105 |             InnerProductDistanceSIMD4Ext = InnerProductDistanceSIMD4ExtAVX;
106 |         }
107 |     #endif
108 | 
109 |         if (dim % 16 == 0)
110 |             fstdistfunc_ = InnerProductDistanceSIMD16Ext;
111 |         else if (dim % 4 == 0)
112 |             fstdistfunc_ = InnerProductDistanceSIMD4Ext;
113 |         else if (dim > 16)
114 |             fstdistfunc_ = InnerProductDistanceSIMD16ExtResiduals;
115 |         else if (dim > 4)
116 |             fstdistfunc_ = InnerProductDistanceSIMD4ExtResiduals;
117 | #endif
118 |         vector_size_ = dim * sizeof(float);
119 |         data_size_ = vector_size_ + sizeof(DOCIDTYPE);
120 |     }
121 | 
122 |     size_t get_data_size() override {
123 |         return data_size_;
124 |     }
125 | 
126 |     DISTFUNC<float> get_dist_func() override {
127 |         return fstdistfunc_;
128 |     }
129 | 
130 |     void *get_dist_func_param() override {
131 |         return &dim_;
132 |     }
133 | 
134 |     DOCIDTYPE get_doc_id(const void *datapoint) override {
135 |         return *(DOCIDTYPE *)((char *)datapoint + vector_size_);
136 |     }
137 | 
138 |     void set_doc_id(void *datapoint, DOCIDTYPE doc_id) override {
139 |         *(DOCIDTYPE*)((char *)datapoint + vector_size_) = doc_id;
140 |     }
141 | 
142 |     ~MultiVectorInnerProductSpace() {}
143 | };
144 | 
145 | 
// Stop condition that counts results in *documents* rather than vectors:
// the search keeps going until ef_collection_ distinct doc ids are collected,
// and filter_results trims down to num_docs_to_search_ documents at the end.
template<typename DOCIDTYPE, typename dist_t>
class MultiVectorSearchStopCondition : public BaseSearchStopCondition<dist_t> {
    size_t curr_num_docs_;         // distinct doc ids currently in the result set
    size_t num_docs_to_search_;    // documents requested by the caller
    size_t ef_collection_;         // documents to collect during search (>= num_docs_to_search_)
    std::unordered_map<DOCIDTYPE, size_t> doc_counter_;  // vectors per doc id in the result set
    std::priority_queue<std::pair<dist_t, DOCIDTYPE>> search_results_;  // max-heap mirroring the result set
    BaseMultiVectorSpace<DOCIDTYPE>& space_;  // used only to extract doc ids from datapoints

 public:
    MultiVectorSearchStopCondition(
        BaseMultiVectorSpace<DOCIDTYPE>& space,
        size_t num_docs_to_search,
        size_t ef_collection = 10)
        : space_(space) {
            curr_num_docs_ = 0;
            num_docs_to_search_ = num_docs_to_search;
            // Collect at least as many docs as will be returned.
            ef_collection_ = std::max(ef_collection, num_docs_to_search);
        }

    void add_point_to_result(labeltype label, const void *datapoint, dist_t dist) override {
        DOCIDTYPE doc_id = space_.get_doc_id(datapoint);
        // First vector of this document -> one more distinct doc.
        if (doc_counter_[doc_id] == 0) {
            curr_num_docs_ += 1;
        }
        search_results_.emplace(dist, doc_id);
        doc_counter_[doc_id] += 1;
    }

    void remove_point_from_result(labeltype label, const void *datapoint, dist_t dist) override {
        DOCIDTYPE doc_id = space_.get_doc_id(datapoint);
        doc_counter_[doc_id] -= 1;
        // Last vector of this document removed -> one fewer distinct doc.
        if (doc_counter_[doc_id] == 0) {
            curr_num_docs_ -= 1;
        }
        search_results_.pop();
    }

    bool should_stop_search(dist_t candidate_dist, dist_t lowerBound) override {
        // Stop once enough documents are collected and the candidate cannot
        // improve on the current worst kept distance.
        bool stop_search = candidate_dist > lowerBound && curr_num_docs_ == ef_collection_;
        return stop_search;
    }

    bool should_consider_candidate(dist_t candidate_dist, dist_t lowerBound) override {
        bool flag_consider_candidate = curr_num_docs_ < ef_collection_ || lowerBound > candidate_dist;
        return flag_consider_candidate;
    }

    bool should_remove_extra() override {
        bool flag_remove_extra = curr_num_docs_ > ef_collection_;
        return flag_remove_extra;
    }

    void filter_results(std::vector<std::pair<dist_t, labeltype >> &candidates) override {
        // Drop furthest entries until only num_docs_to_search_ distinct
        // documents remain. candidates and search_results_ are kept in sync:
        // the back of candidates corresponds to the top of the heap.
        while (curr_num_docs_ > num_docs_to_search_) {
            dist_t dist_cand = candidates.back().first;
            dist_t dist_res = search_results_.top().first;
            assert(dist_cand == dist_res);
            DOCIDTYPE doc_id = search_results_.top().second;
            doc_counter_[doc_id] -= 1;
            if (doc_counter_[doc_id] == 0) {
                curr_num_docs_ -= 1;
            }
            search_results_.pop();
            candidates.pop_back();
        }
    }

    ~MultiVectorSearchStopCondition() {}
};
216 | 
217 | 
218 | template<typename dist_t>
219 | class EpsilonSearchStopCondition : public BaseSearchStopCondition<dist_t> {
220 |     float epsilon_;
221 |     size_t min_num_candidates_;
222 |     size_t max_num_candidates_;
223 |     size_t curr_num_items_;
224 | 
225 |  public:
226 |     EpsilonSearchStopCondition(float epsilon, size_t min_num_candidates, size_t max_num_candidates) {
227 |         assert(min_num_candidates <= max_num_candidates);
228 |         epsilon_ = epsilon;
229 |         min_num_candidates_ = min_num_candidates;
230 |         max_num_candidates_ = max_num_candidates;
231 |         curr_num_items_ = 0;
232 |     }
233 | 
234 |     void add_point_to_result(labeltype label, const void *datapoint, dist_t dist) override {
235 |         curr_num_items_ += 1;
236 |     }
237 | 
238 |     void remove_point_from_result(labeltype label, const void *datapoint, dist_t dist) override {
239 |         curr_num_items_ -= 1;
240 |     }
241 | 
242 |     bool should_stop_search(dist_t candidate_dist, dist_t lowerBound) override {
243 |         if (candidate_dist > lowerBound && curr_num_items_ == max_num_candidates_) {
244 |             // new candidate can't improve found results
245 |             return true;
246 |         }
247 |         if (candidate_dist > epsilon_ && curr_num_items_ >= min_num_candidates_) {
248 |             // new candidate is out of epsilon region and
249 |             // minimum number of candidates is checked
250 |             return true;
251 |         }
252 |         return false;
253 |     }
254 | 
255 |     bool should_consider_candidate(dist_t candidate_dist, dist_t lowerBound) override {
256 |         bool flag_consider_candidate = curr_num_items_ < max_num_candidates_ || lowerBound > candidate_dist;
257 |         return flag_consider_candidate;
258 |     }
259 | 
260 |     bool should_remove_extra() {
261 |         bool flag_remove_extra = curr_num_items_ > max_num_candidates_;
262 |         return flag_remove_extra;
263 |     }
264 | 
265 |     void filter_results(std::vector<std::pair<dist_t, labeltype >> &candidates) override {
266 |         while (!candidates.empty() && candidates.back().first > epsilon_) {
267 |             candidates.pop_back();
268 |         }
269 |         while (candidates.size() > max_num_candidates_) {
270 |             candidates.pop_back();
271 |         }
272 |     }
273 | 
274 |     ~EpsilonSearchStopCondition() {}
275 | };
276 | }  // namespace hnswlib
277 | 


--------------------------------------------------------------------------------
/hnswlib/visited_list_pool.h:
--------------------------------------------------------------------------------
 1 | #pragma once
 2 | 
 3 | #include <mutex>
 4 | #include <string.h>
 5 | #include <deque>
 6 | 
 7 | namespace hnswlib {
 8 | typedef unsigned short int vl_type;
 9 | 
10 | class VisitedList {
11 |  public:
12 |     vl_type curV;
13 |     vl_type *mass;
14 |     unsigned int numelements;
15 | 
16 |     VisitedList(int numelements1) {
17 |         curV = -1;
18 |         numelements = numelements1;
19 |         mass = new vl_type[numelements];
20 |     }
21 | 
22 |     void reset() {
23 |         curV++;
24 |         if (curV == 0) {
25 |             memset(mass, 0, sizeof(vl_type) * numelements);
26 |             curV++;
27 |         }
28 |     }
29 | 
30 |     ~VisitedList() { delete[] mass; }
31 | };
32 | ///////////////////////////////////////////////////////////
33 | //
34 | // Class for multi-threaded pool-management of VisitedLists
35 | //
36 | /////////////////////////////////////////////////////////
37 | 
38 | class VisitedListPool {
39 |     std::deque<VisitedList *> pool;
40 |     std::mutex poolguard;
41 |     int numelements;
42 | 
43 |  public:
44 |     VisitedListPool(int initmaxpools, int numelements1) {
45 |         numelements = numelements1;
46 |         for (int i = 0; i < initmaxpools; i++)
47 |             pool.push_front(new VisitedList(numelements));
48 |     }
49 | 
50 |     VisitedList *getFreeVisitedList() {
51 |         VisitedList *rez;
52 |         {
53 |             std::unique_lock <std::mutex> lock(poolguard);
54 |             if (pool.size() > 0) {
55 |                 rez = pool.front();
56 |                 pool.pop_front();
57 |             } else {
58 |                 rez = new VisitedList(numelements);
59 |             }
60 |         }
61 |         rez->reset();
62 |         return rez;
63 |     }
64 | 
65 |     void releaseVisitedList(VisitedList *vl) {
66 |         std::unique_lock <std::mutex> lock(poolguard);
67 |         pool.push_front(vl);
68 |     }
69 | 
70 |     ~VisitedListPool() {
71 |         while (pool.size()) {
72 |             VisitedList *rez = pool.front();
73 |             pool.pop_front();
74 |             delete rez;
75 |         }
76 |     }
77 | };
78 | }  // namespace hnswlib
79 | 


--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
 1 | [build-system]
 2 | requires = [
 3 |     "setuptools>=42",
 4 |     "wheel",
 5 |     "numpy>=1.10.0",
 6 |     "pybind11>=2.0",
 7 | ]
 8 | 
 9 | build-backend = "setuptools.build_meta"
10 | 


--------------------------------------------------------------------------------
/python_bindings/LazyIndex.py:
--------------------------------------------------------------------------------
 1 | import hnswlib
 2 | """
 3 |     A python wrapper for lazy indexing, preserves the same api as hnswlib.Index but initializes the index only after adding items for the first time with `add_items`.
 4 | """
 5 | class LazyIndex(hnswlib.Index):
 6 |     def __init__(self, space, dim,max_elements=1024, ef_construction=200, M=16):
 7 |         super().__init__(space, dim)
 8 |         self.init_max_elements=max_elements
 9 |         self.init_ef_construction=ef_construction
10 |         self.init_M=M
11 |     def init_index(self, max_elements=0,M=0,ef_construction=0):
12 |         if max_elements>0:
13 |             self.init_max_elements=max_elements
14 |         if ef_construction>0:
15 |             self.init_ef_construction=ef_construction
16 |         if M>0:
17 |             self.init_M=M
18 |         super().init_index(self.init_max_elements, self.init_M, self.init_ef_construction)
19 |     def add_items(self, data, ids=None, num_threads=-1):
20 |         if self.max_elements==0:
21 |             self.init_index()
22 |         return super().add_items(data,ids, num_threads)
23 |     def get_items(self, ids=None):
24 |         if self.max_elements==0:
25 |             return []
26 |         return super().get_items(ids)
27 |     def knn_query(self, data,k=1, num_threads=-1):
28 |         if self.max_elements==0:
29 |             return [], []
30 |         return super().knn_query(data, k, num_threads)
31 |     def resize_index(self, size):
32 |         if self.max_elements==0:
33 |             return self.init_index(size)
34 |         else:
35 |             return super().resize_index(size)
36 |     def set_ef(self, ef):
37 |         if self.max_elements==0:
38 |             self.init_ef_construction=ef
39 |             return
40 |         super().set_ef(ef)
41 |     def get_max_elements(self):
42 |         return self.max_elements
43 |     def get_current_count(self):
44 |         return self.element_count
45 | 


--------------------------------------------------------------------------------
/python_bindings/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nmslib/hnswlib/c1b9b79af3d10c6ee7b5d0afa1ce851ae975254c/python_bindings/__init__.py


--------------------------------------------------------------------------------
/python_bindings/setup.py:
--------------------------------------------------------------------------------
1 | ../setup.py


--------------------------------------------------------------------------------
/python_bindings/tests/bindings_test_bf_index.py:
--------------------------------------------------------------------------------
 1 | import unittest
 2 | 
 3 | import numpy as np
 4 | 
 5 | import hnswlib
 6 | 
 7 | 
 8 | class RandomSelfTestCase(unittest.TestCase):
 9 |     def testBFIndex(self):
10 | 
11 |         dim = 16
12 |         num_elements = 10000
13 |         num_queries = 1000
14 |         k = 20
15 | 
16 |         # Generating sample data
17 |         data = np.float32(np.random.random((num_elements, dim)))
18 | 
19 |         # Declaring index
20 |         bf_index = hnswlib.BFIndex(space='l2', dim=dim)  # possible options are l2, cosine or ip
21 |         bf_index.init_index(max_elements=num_elements)
22 | 
23 |         num_threads = 8
24 |         bf_index.set_num_threads(num_threads)  # by default using all available cores
25 | 
26 |         print(f"Adding all elements {num_elements}")
27 |         bf_index.add_items(data)
28 | 
29 |         self.assertEqual(bf_index.num_threads, num_threads)
30 |         self.assertEqual(bf_index.get_max_elements(), num_elements)
31 |         self.assertEqual(bf_index.get_current_count(), num_elements)
32 | 
33 |         queries = np.float32(np.random.random((num_queries, dim)))
34 |         print("Searching nearest neighbours")
35 |         labels, distances = bf_index.knn_query(queries, k=k)
36 | 
37 |         print("Checking results")
38 |         for i in range(num_queries):
39 |             query = queries[i]
40 |             sq_dists = (data - query)**2
41 |             dists = np.sum(sq_dists, axis=1)
42 |             labels_gt = np.argsort(dists)[:k]
43 |             dists_gt = dists[labels_gt]
44 |             dists_bf = distances[i]
45 |             # we can compare labels but because of numeric errors in distance calculation in C++ and numpy
46 |             # sometimes we get different order of labels, therefore we compare distances
47 |             max_diff_with_gt = np.max(np.abs(dists_gt - dists_bf))
48 | 
49 |             self.assertTrue(max_diff_with_gt < 1e-5)
50 | 


--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
  1 | import os
  2 | import sys
  3 | import platform
  4 | 
  5 | import numpy as np
  6 | import pybind11
  7 | import setuptools
  8 | from setuptools import Extension, setup
  9 | from setuptools.command.build_ext import build_ext
 10 | 
 11 | __version__ = '0.8.0'
 12 | 
 13 | 
# Headers needed to compile the pybind11 extension.
include_dirs = [
    pybind11.get_include(),
    np.get_include(),
]

# compatibility when run in python_bindings
# (relative paths differ depending on whether setup.py is invoked from the
# repository root or from inside the python_bindings directory)
bindings_dir = 'python_bindings'
if bindings_dir in os.path.basename(os.getcwd()):
    source_files = ['./bindings.cpp']
    include_dirs.extend(['../hnswlib/'])
else:
    source_files = ['./python_bindings/bindings.cpp']
    include_dirs.extend(['./hnswlib/'])


libraries = []
extra_objects = []


# Single extension module: the pybind11 bindings compiled as `hnswlib`.
ext_modules = [
    Extension(
        'hnswlib',
        source_files,
        include_dirs=include_dirs,
        libraries=libraries,
        language='c++',
        extra_objects=extra_objects,
    ),
]
 43 | 
 44 | 
 45 | # As of Python 3.6, CCompiler has a `has_flag` method.
 46 | # cf http://bugs.python.org/issue26689
 47 | def has_flag(compiler, flagname):
 48 |     """Return a boolean indicating whether a flag name is supported on
 49 |     the specified compiler.
 50 |     """
 51 |     import tempfile
 52 |     with tempfile.NamedTemporaryFile('w', suffix='.cpp') as f:
 53 |         f.write('int main (int argc, char **argv) { return 0; }')
 54 |         try:
 55 |             compiler.compile([f.name], extra_postargs=[flagname])
 56 |         except setuptools.distutils.errors.CompileError:
 57 |             return False
 58 |     return True
 59 | 
 60 | 
 61 | def cpp_flag(compiler):
 62 |     """Return the -std=c++[11/14] compiler flag.
 63 |     The c++14 is prefered over c++11 (when it is available).
 64 |     """
 65 |     if has_flag(compiler, '-std=c++14'):
 66 |         return '-std=c++14'
 67 |     elif has_flag(compiler, '-std=c++11'):
 68 |         return '-std=c++11'
 69 |     else:
 70 |         raise RuntimeError('Unsupported compiler -- at least C++11 support '
 71 |                            'is needed!')
 72 | 
 73 | 
class BuildExt(build_ext):
    """A custom build extension for adding compiler-specific options."""
    compiler_flag_native = '-march=native'
    # Base compile options per compiler family.
    c_opts = {
        'msvc': ['/EHsc', '/openmp', '/O2'],
        'unix': ['-O3', compiler_flag_native],  # , '-w'
    }
    link_opts = {
        'unix': [],
        'msvc': [],
    }

    # NOTE: the statements below execute once, at class-creation time, and
    # mutate the shared option lists above based on the environment/platform.
    if os.environ.get("HNSWLIB_NO_NATIVE"):
        c_opts['unix'].remove(compiler_flag_native)

    if sys.platform == 'darwin':
        # macOS: use libc++ and a minimum deployment target (no OpenMP here).
        c_opts['unix'] += ['-stdlib=libc++', '-mmacosx-version-min=10.7']
        link_opts['unix'] += ['-stdlib=libc++', '-mmacosx-version-min=10.7']
    else:
        c_opts['unix'].append("-fopenmp")
        link_opts['unix'].extend(['-fopenmp', '-pthread'])

    def build_extensions(self):
        # Chooses per-compiler flags, probing the compiler for optional ones
        # (C++ standard, symbol visibility, -march=native with an Apple-M1
        # fallback), then applies them to every extension before delegating
        # to the stock build_ext implementation.
        ct = self.compiler.compiler_type
        opts = BuildExt.c_opts.get(ct, [])
        if ct == 'unix':
            opts.append('-DVERSION_INFO="%s"' % self.distribution.get_version())
            opts.append(cpp_flag(self.compiler))
            if has_flag(self.compiler, '-fvisibility=hidden'):
                opts.append('-fvisibility=hidden')
            if not os.environ.get("HNSWLIB_NO_NATIVE"):
                # check that native flag is available
                print('checking avalability of flag:', BuildExt.compiler_flag_native)
                if not has_flag(self.compiler, BuildExt.compiler_flag_native):
                    print('removing unsupported compiler flag:', BuildExt.compiler_flag_native)
                    opts.remove(BuildExt.compiler_flag_native)
                    # for macos add apple-m1 flag if it's available
                    if sys.platform == 'darwin':
                        m1_flag = '-mcpu=apple-m1'
                        print('checking avalability of flag:', m1_flag)
                        if has_flag(self.compiler, m1_flag):
                            print('adding flag:', m1_flag)
                            opts.append(m1_flag)
                        else:
                            print(f'flag: {m1_flag} is not available')
                else:
                    print(f'flag: {BuildExt.compiler_flag_native} is available')
        elif ct == 'msvc':
            opts.append('/DVERSION_INFO=\\"%s\\"' % self.distribution.get_version())

        for ext in self.extensions:
            ext.extra_compile_args.extend(opts)
            ext.extra_link_args.extend(BuildExt.link_opts.get(ct, []))

        build_ext.build_extensions(self)
129 | 
130 | 
# Package metadata and build wiring; BuildExt injects the probed flags.
setup(
    name='hnswlib',
    version=__version__,
    description='hnswlib',
    author='Yury Malkov and others',
    url='https://github.com/yurymalkov/hnsw',
    long_description="""hnsw""",
    ext_modules=ext_modules,
    install_requires=['numpy'],
    cmdclass={'build_ext': BuildExt},
    zip_safe=False,
)
143 | 


--------------------------------------------------------------------------------
/tests/cpp/download_bigann.py:
--------------------------------------------------------------------------------
# Downloads the BIGANN (SIFT1B) dataset files into ./downloads and unpacks
# them into ./bigann, skipping files that are already present.
import os.path
import os

links = ['ftp://ftp.irisa.fr/local/texmex/corpus/bigann_query.bvecs.gz',
         'ftp://ftp.irisa.fr/local/texmex/corpus/bigann_gnd.tar.gz',
         'ftp://ftp.irisa.fr/local/texmex/corpus/bigann_base.bvecs.gz']

os.makedirs('downloads', exist_ok=True)
os.makedirs('bigann', exist_ok=True)
for link in links:
    name = link.rsplit('/', 1)[-1]
    filename = os.path.join('downloads', name)
    if not os.path.isfile(filename):
        print('Downloading: ' + filename)
        # NOTE(review): os.system does not raise on a failing command, so this
        # try/except only catches unexpected Python-level errors, not a failed
        # download; consider checking the return code instead.
        try:
            os.system('wget --output-document=' + filename + ' ' + link)
        except Exception as inst:
            print(inst)
            print('  Encountered unknown error. Continuing.')
    else:
        print('Already downloaded: ' + filename)
    # .tar.gz archives are extracted; plain .gz files are just decompressed.
    if filename.endswith('.tar.gz'):
        command = 'tar -zxf ' + filename + ' --directory bigann'
    else:
        command = 'cat ' + filename + ' | gzip -dc > bigann/' + name.replace(".gz", "")
    print("Unpacking file:", command)
    os.system(command)
28 | 


--------------------------------------------------------------------------------
/tests/cpp/epsilon_search_test.cpp:
--------------------------------------------------------------------------------
  1 | #include "assert.h"
  2 | #include "../../hnswlib/hnswlib.h"
  3 | 
  4 | typedef unsigned int docidtype;
  5 | typedef float dist_t;
  6 | 
  7 | int main() {
  8 |     int dim = 16;               // Dimension of the elements
  9 |     int max_elements = 10000;   // Maximum number of elements, should be known beforehand
 10 |     int M = 16;                 // Tightly connected with internal dimensionality of the data
 11 |                                 // strongly affects the memory consumption
 12 |     int ef_construction = 200;  // Controls index search speed/build speed tradeoff
 13 | 
 14 |     int num_queries = 100;
 15 |     float epsilon2 = 1.0;                    // Squared distance to query
 16 |     int max_num_candidates = max_elements;   // Upper bound on the number of returned elements in the epsilon region
 17 |     int min_num_candidates = 2000;           // Minimum number of candidates to search in the epsilon region
 18 |                                              // this parameter is similar to ef
 19 | 
 20 |     // Initing index
 21 |     hnswlib::L2Space space(dim);
 22 |     hnswlib::BruteforceSearch<dist_t>* alg_brute = new hnswlib::BruteforceSearch<dist_t>(&space, max_elements);
 23 |     hnswlib::HierarchicalNSW<dist_t>* alg_hnsw = new hnswlib::HierarchicalNSW<dist_t>(&space, max_elements, M, ef_construction);
 24 | 
 25 |     // Generate random data
 26 |     std::mt19937 rng;
 27 |     rng.seed(47);
 28 |     std::uniform_real_distribution<> distrib_real;
 29 | 
 30 |     float* data = new float[dim * max_elements];
 31 |     for (int i = 0; i < dim * max_elements; i++) {
 32 |         data[i] = distrib_real(rng);
 33 |     }
 34 | 
 35 |     // Add data to index
 36 |     std::cout << "Building index ...\n";
 37 |     for (int i = 0; i < max_elements; i++) {
 38 |         hnswlib::labeltype label = i;
 39 |         float* point_data = data + i * dim;
 40 |         alg_hnsw->addPoint(point_data, label);
 41 |         alg_brute->addPoint(point_data, label);
 42 |     }
 43 |     std::cout << "Index is ready\n";
 44 | 
 45 |     // Query random vectors
 46 |     for (int i = 0; i < num_queries; i++) {
 47 |         float* query_data = new float[dim];
 48 |         for (int j = 0; j < dim; j++) {
 49 |             query_data[j] = distrib_real(rng);
 50 |         }
 51 |         hnswlib::EpsilonSearchStopCondition<dist_t> stop_condition(epsilon2, min_num_candidates, max_num_candidates);
 52 |         std::vector<std::pair<float, hnswlib::labeltype>> result_hnsw =
 53 |             alg_hnsw->searchStopConditionClosest(query_data, stop_condition);
 54 |         
 55 |         // check that returned results are in epsilon region
 56 |         size_t num_vectors = result_hnsw.size();
 57 |         std::unordered_set<hnswlib::labeltype> hnsw_labels;
 58 |         for (auto pair: result_hnsw) {
 59 |             float dist = pair.first;
 60 |             hnswlib::labeltype label = pair.second;
 61 |             hnsw_labels.insert(label);
 62 |             assert(dist >=0 && dist <= epsilon2);
 63 |         }
 64 |         std::priority_queue<std::pair<float, hnswlib::labeltype>> result_brute =
 65 |             alg_brute->searchKnn(query_data, max_elements);
 66 |         
 67 |         // check recall
 68 |         std::unordered_set<hnswlib::labeltype> gt_labels;
 69 |         while (!result_brute.empty()) {
 70 |             float dist = result_brute.top().first;
 71 |             hnswlib::labeltype label = result_brute.top().second;
 72 |             if (dist < epsilon2) {
 73 |                 gt_labels.insert(label);
 74 |             }
 75 |             result_brute.pop();
 76 |         }
 77 |         float correct = 0;
 78 |         for (const auto& hnsw_label: hnsw_labels) {
 79 |             if (gt_labels.find(hnsw_label) != gt_labels.end()) {
 80 |                 correct += 1;
 81 |             }
 82 |         }
 83 |         if (gt_labels.size() == 0) {
 84 |             assert(correct == 0);
 85 |             continue;
 86 |         }
 87 |         float recall = correct / gt_labels.size();
 88 |         assert(recall > 0.95);
 89 |         delete[] query_data;
 90 |     }
 91 |     std::cout << "Recall is OK\n";
 92 | 
 93 |     // Query the elements for themselves and check that query can be found
 94 |     float epsilon2_small = 0.0001f;
 95 |     int min_candidates_small = 500;
 96 |     for (size_t i = 0; i < max_elements; i++) {
 97 |         hnswlib::EpsilonSearchStopCondition<dist_t> stop_condition(epsilon2_small, min_candidates_small, max_num_candidates);
 98 |         std::vector<std::pair<float, hnswlib::labeltype>> result = 
 99 |             alg_hnsw->searchStopConditionClosest(alg_hnsw->getDataByInternalId(i), stop_condition);
100 |         size_t num_vectors = result.size();
101 |         // get closest distance
102 |         float dist = -1;
103 |         if (!result.empty()) {
104 |             dist = result[0].first;
105 |         }
106 |         assert(dist == 0);
107 |     }
108 |     std::cout << "Small epsilon search is OK\n";
109 | 
110 |     delete[] data;
111 |     delete alg_brute;
112 |     delete alg_hnsw;
113 |     return 0;
114 | }
115 | 


--------------------------------------------------------------------------------
/tests/cpp/main.cpp:
--------------------------------------------------------------------------------
1 | 
2 | 
// sift_test1B is implemented in sift_1b.cpp.
void sift_test1B();
// Entry point for the SIFT1B benchmark binary.
int main() {
    sift_test1B();

    return 0;
}
9 | 


--------------------------------------------------------------------------------
/tests/cpp/multiThreadLoad_test.cpp:
--------------------------------------------------------------------------------
  1 | #include "../../hnswlib/hnswlib.h"
  2 | #include <thread>
  3 | #include <chrono>
  4 | 
  5 | 
  6 | int main() {
  7 |     std::cout << "Running multithread load test" << std::endl;
  8 |     int d = 16;
  9 |     int max_elements = 1000;
 10 | 
 11 |     std::mt19937 rng;
 12 |     rng.seed(47);
 13 |     std::uniform_real_distribution<> distrib_real;
 14 | 
 15 |     hnswlib::L2Space space(d);
 16 |     hnswlib::HierarchicalNSW<float>* alg_hnsw = new hnswlib::HierarchicalNSW<float>(&space, 2 * max_elements);
 17 | 
 18 |     std::cout << "Building index" << std::endl;
 19 |     int num_threads = 40;
 20 |     int num_labels = 10;
 21 | 
 22 |     int num_iterations = 10;
 23 |     int start_label = 0;
 24 | 
 25 |     // run threads that will add elements to the index
 26 |     // about 7 threads (the number depends on num_threads and num_labels)
 27 |     // will add/update element with the same label simultaneously
 28 |     while (true) {
 29 |         // add elements by batches
 30 |         std::uniform_int_distribution<> distrib_int(start_label, start_label + num_labels - 1);
 31 |         std::vector<std::thread> threads;
 32 |         for (size_t thread_id = 0; thread_id < num_threads; thread_id++) {
 33 |             threads.push_back(
 34 |                 std::thread(
 35 |                     [&] {
 36 |                         for (int iter = 0; iter < num_iterations; iter++) {
 37 |                             std::vector<float> data(d);
 38 |                             hnswlib::labeltype label = distrib_int(rng);
 39 |                             for (int i = 0; i < d; i++) {
 40 |                                 data[i] = distrib_real(rng);
 41 |                             }
 42 |                             alg_hnsw->addPoint(data.data(), label);
 43 |                         }
 44 |                     }
 45 |                 )
 46 |             );
 47 |         }
 48 |         for (auto &thread : threads) {
 49 |             thread.join();
 50 |         }
 51 |         if (alg_hnsw->cur_element_count > max_elements - num_labels) {
 52 |             break;
 53 |         }
 54 |         start_label += num_labels;
 55 |     }
 56 | 
 57 |     // insert remaining elements if needed
 58 |     for (hnswlib::labeltype label = 0; label < max_elements; label++) {
 59 |         auto search = alg_hnsw->label_lookup_.find(label);
 60 |         if (search == alg_hnsw->label_lookup_.end()) {
 61 |             std::cout << "Adding " << label << std::endl;
 62 |             std::vector<float> data(d);
 63 |             for (int i = 0; i < d; i++) {
 64 |                 data[i] = distrib_real(rng);
 65 |             }
 66 |             alg_hnsw->addPoint(data.data(), label);
 67 |         }
 68 |     }
 69 | 
 70 |     std::cout << "Index is created" << std::endl;
 71 | 
 72 |     bool stop_threads = false;
 73 |     std::vector<std::thread> threads;
 74 | 
 75 |     // create threads that will do markDeleted and unmarkDeleted of random elements
 76 |     // each thread works with specific range of labels
 77 |     std::cout << "Starting markDeleted and unmarkDeleted threads" << std::endl;
 78 |     num_threads = 20;
 79 |     int chunk_size = max_elements / num_threads;
 80 |     for (size_t thread_id = 0; thread_id < num_threads; thread_id++) {
 81 |         threads.push_back(
 82 |             std::thread(
 83 |                 [&, thread_id] {
 84 |                     std::uniform_int_distribution<> distrib_int(0, chunk_size - 1);
 85 |                     int start_id = thread_id * chunk_size;
 86 |                     std::vector<bool> marked_deleted(chunk_size);
 87 |                     while (!stop_threads) {
 88 |                         int id = distrib_int(rng);
 89 |                         hnswlib::labeltype label = start_id + id;
 90 |                         if (marked_deleted[id]) {
 91 |                             alg_hnsw->unmarkDelete(label);
 92 |                             marked_deleted[id] = false;
 93 |                         } else {
 94 |                             alg_hnsw->markDelete(label);
 95 |                             marked_deleted[id] = true;
 96 |                         }
 97 |                     }
 98 |                 }
 99 |             )
100 |         );
101 |     }
102 | 
103 |     // create threads that will add and update random elements
104 |     std::cout << "Starting add and update elements threads" << std::endl;
105 |     num_threads = 20;
106 |     std::uniform_int_distribution<> distrib_int_add(max_elements, 2 * max_elements - 1);
107 |     for (size_t thread_id = 0; thread_id < num_threads; thread_id++) {
108 |         threads.push_back(
109 |             std::thread(
110 |                 [&] {
111 |                     std::vector<float> data(d);
112 |                     while (!stop_threads) {
113 |                         hnswlib::labeltype label = distrib_int_add(rng);
114 |                         for (int i = 0; i < d; i++) {
115 |                             data[i] = distrib_real(rng);
116 |                         }
117 |                         alg_hnsw->addPoint(data.data(), label);
118 |                         std::vector<float> data = alg_hnsw->getDataByLabel<float>(label);
119 |                         float max_val = *max_element(data.begin(), data.end());
120 |                         // never happens but prevents compiler from deleting unused code
121 |                         if (max_val > 10) {
122 |                             throw std::runtime_error("Unexpected value in data");
123 |                         }
124 |                     }
125 |                 }
126 |             )
127 |         );
128 |     }
129 | 
130 |     std::cout << "Sleep and continue operations with index" << std::endl;
131 |     int sleep_ms = 60 * 1000;
132 |     std::this_thread::sleep_for(std::chrono::milliseconds(sleep_ms));
133 |     stop_threads = true;
134 |     for (auto &thread : threads) {
135 |         thread.join();
136 |     }
137 |     
138 |     std::cout << "Finish" << std::endl;
139 |     return 0;
140 | }
141 | 


--------------------------------------------------------------------------------
/tests/cpp/multiThread_replace_test.cpp:
--------------------------------------------------------------------------------
  1 | #include "../../hnswlib/hnswlib.h"
  2 | #include <thread>
  3 | #include <chrono>
  4 | 
  5 | 
// Runs fn(id, threadId) for every id in [start, end) using numThreads workers.
// numThreads <= 0 selects hardware_concurrency(); numThreads == 1 runs inline
// without spawning threads. The first exception thrown by fn stops all workers
// and is rethrown in the calling thread after they have joined.
// NOTE(review): numThreads is size_t, so the `<= 0` check only catches 0.
template<class Function>
inline void ParallelFor(size_t start, size_t end, size_t numThreads, Function fn) {
    if (numThreads <= 0) {
        numThreads = std::thread::hardware_concurrency();
    }

    if (numThreads == 1) {
        // Sequential fallback: threadId is always 0.
        for (size_t id = start; id < end; id++) {
            fn(id, 0);
        }
    } else {
        std::vector<std::thread> threads;
        std::atomic<size_t> current(start);

        // keep track of exceptions in threads
        // https://stackoverflow.com/a/32428427/1713196
        std::exception_ptr lastException = nullptr;
        std::mutex lastExceptMutex;

        for (size_t threadId = 0; threadId < numThreads; ++threadId) {
            threads.push_back(std::thread([&, threadId] {
                while (true) {
                    // Atomically claim the next work item.
                    size_t id = current.fetch_add(1);

                    if (id >= end) {
                        break;
                    }

                    try {
                        fn(id, threadId);
                    } catch (...) {
                        std::unique_lock<std::mutex> lastExcepLock(lastExceptMutex);
                        lastException = std::current_exception();
                        /*
                         * This will work even when current is the largest value that
                         * size_t can fit, because fetch_add returns the previous value
                         * before the increment (what will result in overflow
                         * and produce 0 instead of current + 1).
                         */
                        current = end;
                        break;
                    }
                }
            }));
        }
        for (auto &thread : threads) {
            thread.join();
        }
        if (lastException) {
            std::rethrow_exception(lastException);
        }
    }
}
 59 | 
 60 | 
// Stress test: repeatedly builds an index with replace_deleted enabled, adds
// elements from many threads, deletes half of them, and concurrently replaces
// the deleted slots with new elements.
int main() {
    std::cout << "Running multithread load test" << std::endl;
    int d = 16;
    int num_elements = 1000;
    int max_elements = 2 * num_elements;
    int num_threads = 50;

    std::mt19937 rng;
    rng.seed(47);
    std::uniform_real_distribution<> distrib_real;

    hnswlib::L2Space space(d);

    // generate batch1 and batch2 data
    float* batch1 = new float[d * max_elements];
    for (int i = 0; i < d * max_elements; i++) {
        batch1[i] = distrib_real(rng);
    }
    float* batch2 = new float[d * num_elements];
    for (int i = 0; i < d * num_elements; i++) {
        batch2[i] = distrib_real(rng);
    }

    // generate random labels to delete them from index
    std::vector<int> rand_labels(max_elements);
    for (int i = 0; i < max_elements; i++) {
        rand_labels[i] = i;
    }
    std::shuffle(rand_labels.begin(), rand_labels.end(), rng);

    int iter = 0;
    while (iter < 200) {
        // last constructor argument enables replacement of deleted elements
        hnswlib::HierarchicalNSW<float>* alg_hnsw = new hnswlib::HierarchicalNSW<float>(&space, max_elements, 16, 200, 123, true);

        // add batch1 data
        ParallelFor(0, max_elements, num_threads, [&](size_t row, size_t threadId) {
            alg_hnsw->addPoint((void*)(batch1 + d * row), row);
        });

        // delete half random elements of batch1 data
        for (int i = 0; i < num_elements; i++) {
            alg_hnsw->markDelete(rand_labels[i]);
        }

        // replace deleted elements with batch2 data
        // (replacement labels are offset by max_elements so they cannot collide)
        ParallelFor(0, num_elements, num_threads, [&](size_t row, size_t threadId) {
            int label = rand_labels[row] + max_elements;
            alg_hnsw->addPoint((void*)(batch2 + d * row), label, true);
        });

        iter += 1;

        delete alg_hnsw;
    }
    
    std::cout << "Finish" << std::endl;

    delete[] batch1;
    delete[] batch2;
    return 0;
}
122 | 


--------------------------------------------------------------------------------
/tests/cpp/multivector_search_test.cpp:
--------------------------------------------------------------------------------
  1 | #include <assert.h>
  2 | #include "../../hnswlib/hnswlib.h"
  3 | 
  4 | typedef unsigned int docidtype;
  5 | typedef float dist_t;
  6 | 
  7 | int main() {
  8 |     int dim = 16;               // Dimension of the elements
  9 |     int max_elements = 1000;    // Maximum number of elements, should be known beforehand
 10 |     int M = 16;                 // Tightly connected with internal dimensionality of the data
 11 |                                 // strongly affects the memory consumption
 12 |     int ef_construction = 200;  // Controls index search speed/build speed tradeoff
 13 | 
 14 |     int num_queries = 100;
 15 |     int num_docs = 10;          // Number of documents to search
 16 |     int ef_collection = 15;     // Number of candidate documents during the search
 17 |                                 // Controlls the recall: higher ef leads to better accuracy, but slower search
 18 |     docidtype min_doc_id = 0;
 19 |     docidtype max_doc_id = 49;
 20 | 
 21 |     // Initing index
 22 |     hnswlib::MultiVectorL2Space<docidtype> space(dim);
 23 |     hnswlib::BruteforceSearch<dist_t>* alg_brute = new hnswlib::BruteforceSearch<dist_t>(&space, max_elements);
 24 |     hnswlib::HierarchicalNSW<dist_t>* alg_hnsw = new hnswlib::HierarchicalNSW<dist_t>(&space, max_elements, M, ef_construction);
 25 | 
 26 |     // Generate random data
 27 |     std::mt19937 rng;
 28 |     rng.seed(47);
 29 |     std::uniform_real_distribution<> distrib_real;
 30 |     std::uniform_int_distribution<docidtype> distrib_docid(min_doc_id, max_doc_id);
 31 | 
 32 |     size_t data_point_size = space.get_data_size();
 33 |     char* data = new char[data_point_size * max_elements];
 34 |     for (int i = 0; i < max_elements; i++) {
 35 |         // set vector value
 36 |         char* point_data = data + i * data_point_size;
 37 |         for (int j = 0; j < dim; j++) {
 38 |             char* vec_data = point_data + j * sizeof(float);
 39 |             float value = distrib_real(rng);
 40 |             *(float*)vec_data = value;
 41 |         }
 42 |         // set document id
 43 |         docidtype doc_id = distrib_docid(rng);
 44 |         space.set_doc_id(point_data, doc_id);
 45 |     }
 46 | 
 47 |     // Add data to index
 48 |     std::unordered_map<hnswlib::labeltype, docidtype> label_docid_lookup;
 49 |     for (int i = 0; i < max_elements; i++) {
 50 |         hnswlib::labeltype label = i;
 51 |         char* point_data = data + i * data_point_size;
 52 |         alg_hnsw->addPoint(point_data, label);
 53 |         alg_brute->addPoint(point_data, label);
 54 |         label_docid_lookup[label] = space.get_doc_id(point_data);
 55 |     }
 56 | 
 57 |     // Query random vectors and check overall recall
 58 |     float correct = 0;
 59 |     float total_num_elements = 0;
 60 |     size_t query_size = dim * sizeof(float);
 61 |     for (int i = 0; i < num_queries; i++) {
 62 |         char* query_data = new char[query_size];
 63 |         for (int j = 0; j < dim; j++) {
 64 |             size_t offset = j * sizeof(float);
 65 |             char* vec_data = query_data + offset;
 66 |             float value = distrib_real(rng);
 67 |             *(float*)vec_data = value;
 68 |         }
 69 |         hnswlib::MultiVectorSearchStopCondition<docidtype, dist_t> stop_condition(space, num_docs, ef_collection);
 70 |         std::vector<std::pair<dist_t, hnswlib::labeltype>> hnsw_results =
 71 |             alg_hnsw->searchStopConditionClosest(query_data, stop_condition);
 72 | 
 73 |         // check number of found documents
 74 |         std::unordered_set<docidtype> hnsw_docs;
 75 |         std::unordered_set<hnswlib::labeltype> hnsw_labels;
 76 |         for (auto pair: hnsw_results) {
 77 |             hnswlib::labeltype label = pair.second;
 78 |             hnsw_labels.emplace(label);
 79 |             docidtype doc_id = label_docid_lookup[label];
 80 |             hnsw_docs.emplace(doc_id);
 81 |         }
 82 |         assert(hnsw_docs.size() == num_docs);
 83 | 
 84 |         // Check overall recall
 85 |         std::vector<std::pair<dist_t, hnswlib::labeltype>> gt_results = 
 86 |             alg_brute->searchKnnCloserFirst(query_data, max_elements);
 87 |         std::unordered_set<docidtype> gt_docs;
 88 |         for (int i = 0; i < gt_results.size(); i++) {
 89 |             if (gt_docs.size() == num_docs) {
 90 |                 break;
 91 |             }
 92 |             hnswlib::labeltype gt_label = gt_results[i].second;
 93 |             if (hnsw_labels.find(gt_label) != hnsw_labels.end()) {
 94 |                 correct += 1;
 95 |             }
 96 |             docidtype gt_doc_id = label_docid_lookup[gt_label];
 97 |             gt_docs.emplace(gt_doc_id);
 98 |             total_num_elements += 1;
 99 |         }
100 |         delete[] query_data;
101 |     }
102 |     float recall = correct / total_num_elements;
103 |     std::cout << "random elements search recall : " << recall << "\n";
104 |     assert(recall > 0.95);
105 | 
106 |     // Query the elements for themselves and measure recall
107 |     correct = 0;
108 |     for (int i = 0; i < max_elements; i++) {
109 |         hnswlib::MultiVectorSearchStopCondition<docidtype, dist_t> stop_condition(space, num_docs, ef_collection);
110 |         std::vector<std::pair<float, hnswlib::labeltype>> result =
111 |             alg_hnsw->searchStopConditionClosest(data + i * data_point_size, stop_condition);
112 |         hnswlib::labeltype label = -1;
113 |         if (!result.empty()) {
114 |             label = result[0].second;
115 |         }
116 |         if (label == i) correct++;
117 |     }
118 |     recall = correct / max_elements;
119 |     std::cout << "same elements search recall : " << recall << "\n";
120 |     assert(recall > 0.99);
121 | 
122 |     delete[] data;
123 |     delete alg_brute;
124 |     delete alg_hnsw;
125 |     return 0;
126 | }
127 | 


--------------------------------------------------------------------------------
/tests/cpp/searchKnnCloserFirst_test.cpp:
--------------------------------------------------------------------------------
 1 | // This is a test file for testing the interface
 2 | //  >>> virtual std::vector<std::pair<dist_t, labeltype>>
 3 | //  >>>    searchKnnCloserFirst(const void* query_data, size_t k) const;
 4 | // of class AlgorithmInterface
 5 | 
 6 | #include "../../hnswlib/hnswlib.h"
 7 | 
 8 | #include <assert.h>
 9 | 
10 | #include <vector>
11 | #include <iostream>
12 | 
13 | namespace {
14 | 
15 | using idx_t = hnswlib::labeltype;
16 | 
17 | void test() {
18 |     int d = 4;
19 |     idx_t n = 100;
20 |     idx_t nq = 10;
21 |     size_t k = 10;
22 | 
23 |     std::vector<float> data(n * d);
24 |     std::vector<float> query(nq * d);
25 | 
26 |     std::mt19937 rng;
27 |     rng.seed(47);
28 |     std::uniform_real_distribution<> distrib;
29 | 
30 |     for (idx_t i = 0; i < n * d; ++i) {
31 |         data[i] = distrib(rng);
32 |     }
33 |     for (idx_t i = 0; i < nq * d; ++i) {
34 |         query[i] = distrib(rng);
35 |     }
36 | 
37 |     hnswlib::L2Space space(d);
38 |     hnswlib::AlgorithmInterface<float>* alg_brute  = new hnswlib::BruteforceSearch<float>(&space, 2 * n);
39 |     hnswlib::AlgorithmInterface<float>* alg_hnsw = new hnswlib::HierarchicalNSW<float>(&space, 2 * n);
40 | 
41 |     for (size_t i = 0; i < n; ++i) {
42 |         alg_brute->addPoint(data.data() + d * i, i);
43 |         alg_hnsw->addPoint(data.data() + d * i, i);
44 |     }
45 | 
46 |     // test searchKnnCloserFirst of BruteforceSearch
47 |     for (size_t j = 0; j < nq; ++j) {
48 |         const void* p = query.data() + j * d;
49 |         auto gd = alg_brute->searchKnn(p, k);
50 |         auto res = alg_brute->searchKnnCloserFirst(p, k);
51 |         assert(gd.size() == res.size());
52 |         size_t t = gd.size();
53 |         while (!gd.empty()) {
54 |             assert(gd.top() == res[--t]);
55 |             gd.pop();
56 |         }
57 |     }
58 |     for (size_t j = 0; j < nq; ++j) {
59 |         const void* p = query.data() + j * d;
60 |         auto gd = alg_hnsw->searchKnn(p, k);
61 |         auto res = alg_hnsw->searchKnnCloserFirst(p, k);
62 |         assert(gd.size() == res.size());
63 |         size_t t = gd.size();
64 |         while (!gd.empty()) {
65 |             assert(gd.top() == res[--t]);
66 |             gd.pop();
67 |         }
68 |     }
69 | 
70 |     delete alg_brute;
71 |     delete alg_hnsw;
72 | }
73 | 
74 | }  // namespace
75 | 
76 | int main() {
77 |     std::cout << "Testing ..." << std::endl;
78 |     test();
79 |     std::cout << "Test ok" << std::endl;
80 | 
81 |     return 0;
82 | }
83 | 


--------------------------------------------------------------------------------
/tests/cpp/searchKnnWithFilter_test.cpp:
--------------------------------------------------------------------------------
  1 | // This is a test file for testing the filtering feature
  2 | 
  3 | #include "../../hnswlib/hnswlib.h"
  4 | 
  5 | #include <assert.h>
  6 | 
  7 | #include <vector>
  8 | #include <iostream>
  9 | 
namespace {

using idx_t = hnswlib::labeltype;

// Filter functor that accepts only labels divisible by `divisor`.
class PickDivisibleIds: public hnswlib::BaseFilterFunctor {
unsigned int divisor = 1;
 public:
    PickDivisibleIds(unsigned int divisor): divisor(divisor) {
        assert(divisor != 0);
    }
    bool operator()(idx_t label_id) {
        return label_id % divisor == 0;
    }
};

// Filter functor that rejects every label, so filtered searches return nothing.
class PickNothing: public hnswlib::BaseFilterFunctor {
 public:
    bool operator()(idx_t label_id) {
        return false;
    }
};

// Checks filtered search on both brute-force and HNSW indexes:
// every returned label must be divisible by `div_num`, and
// searchKnnCloserFirst must return searchKnn's results in closest-first order.
// `filter_func` is expected to accept exactly the labels divisible by `div_num`.
void test_some_filtering(hnswlib::BaseFilterFunctor& filter_func, size_t div_num, size_t label_id_start) {
    int d = 4;
    idx_t n = 100;
    idx_t nq = 10;
    size_t k = 10;

    std::vector<float> data(n * d);
    std::vector<float> query(nq * d);

    // Fixed seed keeps the test deterministic.
    std::mt19937 rng;
    rng.seed(47);
    std::uniform_real_distribution<> distrib;

    for (idx_t i = 0; i < n * d; ++i) {
        data[i] = distrib(rng);
    }
    for (idx_t i = 0; i < nq * d; ++i) {
        query[i] = distrib(rng);
    }

    hnswlib::L2Space space(d);
    hnswlib::AlgorithmInterface<float>* alg_brute  = new hnswlib::BruteforceSearch<float>(&space, 2 * n);
    hnswlib::AlgorithmInterface<float>* alg_hnsw = new hnswlib::HierarchicalNSW<float>(&space, 2 * n);

    for (size_t i = 0; i < n; ++i) {
        // `label_id_start` is used to ensure that the returned IDs are labels and not internal IDs
        alg_brute->addPoint(data.data() + d * i, label_id_start + i);
        alg_hnsw->addPoint(data.data() + d * i, label_id_start + i);
    }

    // test searchKnnCloserFirst of BruteforceSearch with filtering
    for (size_t j = 0; j < nq; ++j) {
        const void* p = query.data() + j * d;
        auto gd = alg_brute->searchKnn(p, k, &filter_func);
        auto res = alg_brute->searchKnnCloserFirst(p, k, &filter_func);
        assert(gd.size() == res.size());
        size_t t = gd.size();
        // Popping the max-heap walks the closer-first vector backwards.
        while (!gd.empty()) {
            assert(gd.top() == res[--t]);
            assert((gd.top().second % div_num) == 0);
            gd.pop();
        }
    }

    // test searchKnnCloserFirst of hnsw with filtering
    for (size_t j = 0; j < nq; ++j) {
        const void* p = query.data() + j * d;
        auto gd = alg_hnsw->searchKnn(p, k, &filter_func);
        auto res = alg_hnsw->searchKnnCloserFirst(p, k, &filter_func);
        assert(gd.size() == res.size());
        size_t t = gd.size();
        // Same ordering check as the brute-force case above.
        while (!gd.empty()) {
            assert(gd.top() == res[--t]);
            assert((gd.top().second % div_num) == 0);
            gd.pop();
        }
    }

    delete alg_brute;
    delete alg_hnsw;
}

// Checks that a filter rejecting everything makes both index types
// return empty result sets from searchKnn and searchKnnCloserFirst.
void test_none_filtering(hnswlib::BaseFilterFunctor& filter_func, size_t label_id_start) {
    int d = 4;
    idx_t n = 100;
    idx_t nq = 10;
    size_t k = 10;

    std::vector<float> data(n * d);
    std::vector<float> query(nq * d);

    // Fixed seed keeps the test deterministic.
    std::mt19937 rng;
    rng.seed(47);
    std::uniform_real_distribution<> distrib;

    for (idx_t i = 0; i < n * d; ++i) {
        data[i] = distrib(rng);
    }
    for (idx_t i = 0; i < nq * d; ++i) {
        query[i] = distrib(rng);
    }

    hnswlib::L2Space space(d);
    hnswlib::AlgorithmInterface<float>* alg_brute  = new hnswlib::BruteforceSearch<float>(&space, 2 * n);
    hnswlib::AlgorithmInterface<float>* alg_hnsw = new hnswlib::HierarchicalNSW<float>(&space, 2 * n);

    for (size_t i = 0; i < n; ++i) {
        // `label_id_start` is used to ensure that the returned IDs are labels and not internal IDs
        alg_brute->addPoint(data.data() + d * i, label_id_start + i);
        alg_hnsw->addPoint(data.data() + d * i, label_id_start + i);
    }

    // test searchKnnCloserFirst of BruteforceSearch with filtering
    for (size_t j = 0; j < nq; ++j) {
        const void* p = query.data() + j * d;
        auto gd = alg_brute->searchKnn(p, k, &filter_func);
        auto res = alg_brute->searchKnnCloserFirst(p, k, &filter_func);
        assert(gd.size() == res.size());
        assert(0 == gd.size());
    }

    // test searchKnnCloserFirst of hnsw with filtering
    for (size_t j = 0; j < nq; ++j) {
        const void* p = query.data() + j * d;
        auto gd = alg_hnsw->searchKnn(p, k, &filter_func);
        auto res = alg_hnsw->searchKnnCloserFirst(p, k, &filter_func);
        assert(gd.size() == res.size());
        assert(0 == gd.size());
    }

    delete alg_brute;
    delete alg_hnsw;
}

}  // namespace
147 | 
148 | class CustomFilterFunctor: public hnswlib::BaseFilterFunctor {
149 |     std::unordered_set<idx_t> allowed_values;
150 | 
151 |  public:
152 |     explicit CustomFilterFunctor(const std::unordered_set<idx_t>& values) : allowed_values(values) {}
153 | 
154 |     bool operator()(idx_t id) {
155 |         return allowed_values.count(id) != 0;
156 |     }
157 | };
158 | 
159 | int main() {
160 |     std::cout << "Testing ..." << std::endl;
161 | 
162 |     // some of the elements are filtered
163 |     PickDivisibleIds pickIdsDivisibleByThree(3);
164 |     test_some_filtering(pickIdsDivisibleByThree, 3, 17);
165 |     PickDivisibleIds pickIdsDivisibleBySeven(7);
166 |     test_some_filtering(pickIdsDivisibleBySeven, 7, 17);
167 | 
168 |     // all of the elements are filtered
169 |     PickNothing pickNothing;
170 |     test_none_filtering(pickNothing, 17);
171 | 
172 |     // functor style which can capture context
173 |     CustomFilterFunctor pickIdsDivisibleByThirteen({26, 39, 52, 65});
174 |     test_some_filtering(pickIdsDivisibleByThirteen, 13, 21);
175 | 
176 |     std::cout << "Test ok" << std::endl;
177 | 
178 |     return 0;
179 | }
180 | 


--------------------------------------------------------------------------------
/tests/cpp/update_gen_data.py:
--------------------------------------------------------------------------------
 1 | import numpy as np
 2 | import os
 3 | 
 4 | def normalized(a, axis=-1, order=2):
 5 |     l2 = np.atleast_1d(np.linalg.norm(a, order, axis))
 6 |     l2[l2==0] = 1
 7 |     return a / np.expand_dims(l2, axis)
 8 | 
 9 | N=100000
10 | dummy_data_multiplier=3
11 | N_queries = 1000
12 | d=8
13 | K=5
14 | 
15 | np.random.seed(1)
16 | 
17 | print("Generating data...")
18 | batches_dummy= [ normalized(np.float32(np.random.random( (N,d)))) for _ in range(dummy_data_multiplier)]
19 | batch_final = normalized (np.float32(np.random.random( (N,d))))
20 | queries = normalized(np.float32(np.random.random( (N_queries,d))))
21 | print("Computing distances...")
22 | dist=np.dot(queries,batch_final.T)
23 | topk=np.argsort(-dist)[:,:K]
24 | print("Saving...")
25 | 
26 | try:
27 |     os.mkdir("data")
28 | except OSError as e:
29 |     pass
30 | 
31 | for idx, batch_dummy in enumerate(batches_dummy):
32 |     batch_dummy.tofile('data/batch_dummy_%02d.bin' % idx)
33 | batch_final.tofile('data/batch_final.bin')
34 | queries.tofile('data/queries.bin')
35 | np.int32(topk).tofile('data/gt.bin')
36 | with open("data/config.txt", "w") as file:
37 |     file.write("%d %d %d %d %d" %(N, dummy_data_multiplier, N_queries, d, K))


--------------------------------------------------------------------------------
/tests/cpp/updates_test.cpp:
--------------------------------------------------------------------------------
  1 | #include "../../hnswlib/hnswlib.h"
  2 | #include <thread>
  3 | 
  4 | 
  5 | class StopW {
  6 |     std::chrono::steady_clock::time_point time_begin;
  7 | 
  8 |  public:
  9 |     StopW() {
 10 |         time_begin = std::chrono::steady_clock::now();
 11 |     }
 12 | 
 13 |     float getElapsedTimeMicro() {
 14 |         std::chrono::steady_clock::time_point time_end = std::chrono::steady_clock::now();
 15 |         return (std::chrono::duration_cast<std::chrono::microseconds>(time_end - time_begin).count());
 16 |     }
 17 | 
 18 |     void reset() {
 19 |         time_begin = std::chrono::steady_clock::now();
 20 |     }
 21 | };
 22 | 
 23 | 
 24 | /*
 25 |  * replacement for the openmp '#pragma omp parallel for' directive
 26 |  * only handles a subset of functionality (no reductions etc)
 27 |  * Process ids from start (inclusive) to end (EXCLUSIVE)
 28 |  *
 29 |  * The method is borrowed from nmslib 
 30 |  */
 31 | template<class Function>
 32 | inline void ParallelFor(size_t start, size_t end, size_t numThreads, Function fn) {
 33 |     if (numThreads <= 0) {
 34 |         numThreads = std::thread::hardware_concurrency();
 35 |     }
 36 | 
 37 |     if (numThreads == 1) {
 38 |         for (size_t id = start; id < end; id++) {
 39 |             fn(id, 0);
 40 |         }
 41 |     } else {
 42 |         std::vector<std::thread> threads;
 43 |         std::atomic<size_t> current(start);
 44 | 
 45 |         // keep track of exceptions in threads
 46 |         // https://stackoverflow.com/a/32428427/1713196
 47 |         std::exception_ptr lastException = nullptr;
 48 |         std::mutex lastExceptMutex;
 49 | 
 50 |         for (size_t threadId = 0; threadId < numThreads; ++threadId) {
 51 |             threads.push_back(std::thread([&, threadId] {
 52 |                 while (true) {
 53 |                     size_t id = current.fetch_add(1);
 54 | 
 55 |                     if ((id >= end)) {
 56 |                         break;
 57 |                     }
 58 | 
 59 |                     try {
 60 |                         fn(id, threadId);
 61 |                     } catch (...) {
 62 |                         std::unique_lock<std::mutex> lastExcepLock(lastExceptMutex);
 63 |                         lastException = std::current_exception();
 64 |                         /*
 65 |                          * This will work even when current is the largest value that
 66 |                          * size_t can fit, because fetch_add returns the previous value
 67 |                          * before the increment (what will result in overflow
 68 |                          * and produce 0 instead of current + 1).
 69 |                          */
 70 |                         current = end;
 71 |                         break;
 72 |                     }
 73 |                 }
 74 |             }));
 75 |         }
 76 |         for (auto &thread : threads) {
 77 |             thread.join();
 78 |         }
 79 |         if (lastException) {
 80 |             std::rethrow_exception(lastException);
 81 |         }
 82 |     }
 83 | }
 84 | 
 85 | 
 86 | template <typename datatype>
 87 | std::vector<datatype> load_batch(std::string path, int size) {
 88 |     std::cout << "Loading " << path << "...";
 89 |     // float or int32 (python)
 90 |     assert(sizeof(datatype) == 4);
 91 | 
 92 |     std::ifstream file;
 93 |     file.open(path, std::ios::binary);
 94 |     if (!file.is_open()) {
 95 |         std::cout << "Cannot open " << path << "\n";
 96 |         exit(1);
 97 |     }
 98 |     std::vector<datatype> batch(size);
 99 | 
100 |     file.read((char *)batch.data(), size * sizeof(float));
101 |     std::cout << " DONE\n";
102 |     return batch;
103 | }
104 | 
105 | 
106 | template <typename d_type>
107 | static float
108 | test_approx(std::vector<float> &queries, size_t qsize, hnswlib::HierarchicalNSW<d_type> &appr_alg, size_t vecdim,
109 |             std::vector<std::unordered_set<hnswlib::labeltype>> &answers, size_t K) {
110 |     size_t correct = 0;
111 |     size_t total = 0;
112 | 
113 |     for (int i = 0; i < qsize; i++) {
114 |         std::priority_queue<std::pair<d_type, hnswlib::labeltype>> result = appr_alg.searchKnn((char *)(queries.data() + vecdim * i), K);
115 |         total += K;
116 |         while (result.size()) {
117 |             if (answers[i].find(result.top().second) != answers[i].end()) {
118 |                 correct++;
119 |             } else {
120 |             }
121 |             result.pop();
122 |         }
123 |     }
124 |     return 1.0f * correct / total;
125 | }
126 | 
127 | 
128 | static void
129 | test_vs_recall(
130 |     std::vector<float> &queries,
131 |     size_t qsize,
132 |     hnswlib::HierarchicalNSW<float> &appr_alg,
133 |     size_t vecdim,
134 |     std::vector<std::unordered_set<hnswlib::labeltype>> &answers,
135 |     size_t k) {
136 | 
137 |     std::vector<size_t> efs = {1};
138 |     for (int i = k; i < 30; i++) {
139 |         efs.push_back(i);
140 |     }
141 |     for (int i = 30; i < 400; i+=10) {
142 |         efs.push_back(i);
143 |     }
144 |     for (int i = 1000; i < 100000; i += 5000) {
145 |         efs.push_back(i);
146 |     }
147 |     std::cout << "ef\trecall\ttime\thops\tdistcomp\n";
148 | 
149 |     bool test_passed = false;
150 |     for (size_t ef : efs) {
151 |         appr_alg.setEf(ef);
152 | 
153 |         appr_alg.metric_hops = 0;
154 |         appr_alg.metric_distance_computations = 0;
155 |         StopW stopw = StopW();
156 | 
157 |         float recall = test_approx<float>(queries, qsize, appr_alg, vecdim, answers, k);
158 |         float time_us_per_query = stopw.getElapsedTimeMicro() / qsize;
159 |         float distance_comp_per_query =  appr_alg.metric_distance_computations / (1.0f * qsize);
160 |         float hops_per_query =  appr_alg.metric_hops / (1.0f * qsize);
161 | 
162 |         std::cout << ef << "\t" << recall << "\t" << time_us_per_query << "us \t" << hops_per_query << "\t" << distance_comp_per_query << "\n";
163 |         if (recall > 0.99) {
164 |             test_passed = true;
165 |             std::cout << "Recall is over 0.99! " << recall << "\t" << time_us_per_query << "us \t" << hops_per_query << "\t" << distance_comp_per_query << "\n";
166 |             break;
167 |         }
168 |     }
169 |     if (!test_passed) {
170 |         std::cerr << "Test failed\n";
171 |         exit(1);
172 |     }
173 | }
174 | 
175 | 
176 | int main(int argc, char **argv) {
177 |     int M = 16;
178 |     int efConstruction = 200;
179 |     int num_threads = std::thread::hardware_concurrency();
180 | 
181 |     bool update = false;
182 | 
183 |     if (argc == 2) {
184 |         if (std::string(argv[1]) == "update") {
185 |             update = true;
186 |             std::cout << "Updates are on\n";
187 |         } else {
188 |             std::cout << "Usage ./test_updates [update]\n";
189 |             exit(1);
190 |         }
191 |     } else if (argc > 2) {
192 |         std::cout << "Usage ./test_updates [update]\n";
193 |         exit(1);
194 |     }
195 | 
196 |     std::string path = "../tests/cpp/data/";
197 | 
198 |     int N;
199 |     int dummy_data_multiplier;
200 |     int N_queries;
201 |     int d;
202 |     int K;
203 |     {
204 |         std::ifstream configfile;
205 |         configfile.open(path + "/config.txt");
206 |         if (!configfile.is_open()) {
207 |             std::cout << "Cannot open config.txt\n";
208 |             return 1;
209 |         }
210 |         configfile >> N >> dummy_data_multiplier >> N_queries >> d >> K;
211 | 
212 |         printf("Loaded config: N=%d, d_mult=%d, Nq=%d, dim=%d, K=%d\n", N, dummy_data_multiplier, N_queries, d, K);
213 |     }
214 | 
215 |     hnswlib::L2Space l2space(d);
216 |     hnswlib::HierarchicalNSW<float> appr_alg(&l2space, N + 1, M, efConstruction);
217 | 
218 |     std::vector<float> dummy_batch = load_batch<float>(path + "batch_dummy_00.bin", N * d);
219 | 
220 |     // Adding enterpoint:
221 | 
222 |     appr_alg.addPoint((void *)dummy_batch.data(), (size_t)0);
223 | 
224 |     StopW stopw = StopW();
225 | 
226 |     if (update) {
227 |         std::cout << "Update iteration 0\n";
228 | 
229 |         ParallelFor(1, N, num_threads, [&](size_t i, size_t threadId) {
230 |             appr_alg.addPoint((void *)(dummy_batch.data() + i * d), i);
231 |         });
232 |         appr_alg.checkIntegrity();
233 | 
234 |         ParallelFor(1, N, num_threads, [&](size_t i, size_t threadId) {
235 |             appr_alg.addPoint((void *)(dummy_batch.data() + i * d), i);
236 |         });
237 |         appr_alg.checkIntegrity();
238 | 
239 |         for (int b = 1; b < dummy_data_multiplier; b++) {
240 |             std::cout << "Update iteration " << b << "\n";
241 |             char cpath[1024];
242 |             snprintf(cpath, sizeof(cpath), "batch_dummy_%02d.bin", b);
243 |             std::vector<float> dummy_batchb = load_batch<float>(path + cpath, N * d);
244 | 
245 |             ParallelFor(0, N, num_threads, [&](size_t i, size_t threadId) {
246 |                 appr_alg.addPoint((void *)(dummy_batch.data() + i * d), i);
247 |             });
248 |             appr_alg.checkIntegrity();
249 |         }
250 |     }
251 | 
252 |     std::cout << "Inserting final elements\n";
253 |     std::vector<float> final_batch = load_batch<float>(path + "batch_final.bin", N * d);
254 | 
255 |     stopw.reset();
256 |     ParallelFor(0, N, num_threads, [&](size_t i, size_t threadId) {
257 |                     appr_alg.addPoint((void *)(final_batch.data() + i * d), i);
258 |                 });
259 |     std::cout << "Finished. Time taken:" << stopw.getElapsedTimeMicro()*1e-6 << " s\n";
260 |     std::cout << "Running tests\n";
261 |     std::vector<float> queries_batch = load_batch<float>(path + "queries.bin", N_queries * d);
262 | 
263 |     std::vector<int> gt = load_batch<int>(path + "gt.bin", N_queries * K);
264 | 
265 |     std::vector<std::unordered_set<hnswlib::labeltype>> answers(N_queries);
266 |     for (int i = 0; i < N_queries; i++) {
267 |         for (int j = 0; j < K; j++) {
268 |             answers[i].insert(gt[i * K + j]);
269 |         }
270 |     }
271 | 
272 |     for (int i = 0; i < 3; i++) {
273 |         std::cout << "Test iteration " << i << "\n";
274 |         test_vs_recall(queries_batch, N_queries, appr_alg, d, answers, K);
275 |     }
276 | 
277 |     return 0;
278 | }
279 | 


--------------------------------------------------------------------------------
/tests/python/bindings_test.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | import unittest
 3 | 
 4 | import numpy as np
 5 | 
 6 | import hnswlib
 7 | 
 8 | 
 9 | class RandomSelfTestCase(unittest.TestCase):
10 |     def testRandomSelf(self):
11 | 
12 |         dim = 16
13 |         num_elements = 10000
14 | 
15 |         # Generating sample data
16 |         data = np.float32(np.random.random((num_elements, dim)))
17 | 
18 |         # Declaring index
19 |         p = hnswlib.Index(space='l2', dim=dim)  # possible options are l2, cosine or ip
20 | 
21 |         # Initiating index
22 |         # max_elements - the maximum number of elements, should be known beforehand
23 |         #     (probably will be made optional in the future)
24 |         #
25 |         # ef_construction - controls index search speed/build speed tradeoff
26 |         # M - is tightly connected with internal dimensionality of the data
27 |         #     strongly affects the memory consumption
28 | 
29 |         p.init_index(max_elements=num_elements, ef_construction=100, M=16)
30 | 
31 |         # Controlling the recall by setting ef:
32 |         # higher ef leads to better accuracy, but slower search
33 |         p.set_ef(10)
34 | 
35 |         p.set_num_threads(4)  # by default using all available cores
36 | 
37 |         # We split the data in two batches:
38 |         data1 = data[:num_elements // 2]
39 |         data2 = data[num_elements // 2:]
40 | 
41 |         print("Adding first batch of %d elements" % (len(data1)))
42 |         p.add_items(data1)
43 | 
44 |         # Query the elements for themselves and measure recall:
45 |         labels, distances = p.knn_query(data1, k=1)
46 |         self.assertAlmostEqual(np.mean(labels.reshape(-1) == np.arange(len(data1))), 1.0, 3)
47 | 
48 |         # Serializing and deleting the index:
49 |         index_path = 'first_half.bin'
50 |         print("Saving index to '%s'" % index_path)
51 |         p.save_index(index_path)
52 |         del p
53 | 
54 |         # Re-initiating, loading the index
55 |         p = hnswlib.Index(space='l2', dim=dim)  # you can change the sa
56 | 
57 |         print("\nLoading index from '%s'\n" % index_path)
58 |         p.load_index(index_path)
59 | 
60 |         print("Adding the second batch of %d elements" % (len(data2)))
61 |         p.add_items(data2)
62 | 
63 |         # Query the elements for themselves and measure recall:
64 |         labels, distances = p.knn_query(data, k=1)
65 | 
66 |         self.assertAlmostEqual(np.mean(labels.reshape(-1) == np.arange(len(data))), 1.0, 3)
67 |         
68 |         os.remove(index_path)
69 | 


--------------------------------------------------------------------------------
/tests/python/bindings_test_filter.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | import unittest
 3 | 
 4 | import numpy as np
 5 | 
 6 | import hnswlib
 7 | 
 8 | 
 9 | class RandomSelfTestCase(unittest.TestCase):
10 |     def testRandomSelf(self):
11 | 
12 |         dim = 16
13 |         num_elements = 10000
14 | 
15 |         # Generating sample data
16 |         data = np.float32(np.random.random((num_elements, dim)))
17 | 
18 |         # Declaring index
19 |         hnsw_index = hnswlib.Index(space='l2', dim=dim)  # possible options are l2, cosine or ip
20 |         bf_index = hnswlib.BFIndex(space='l2', dim=dim)
21 | 
22 |         # Initiating index
23 |         # max_elements - the maximum number of elements, should be known beforehand
24 |         #     (probably will be made optional in the future)
25 |         #
26 |         # ef_construction - controls index search speed/build speed tradeoff
27 |         # M - is tightly connected with internal dimensionality of the data
28 |         #     strongly affects the memory consumption
29 | 
30 |         hnsw_index.init_index(max_elements=num_elements, ef_construction=100, M=16)
31 |         bf_index.init_index(max_elements=num_elements)
32 | 
33 |         # Controlling the recall by setting ef:
34 |         # higher ef leads to better accuracy, but slower search
35 |         hnsw_index.set_ef(10)
36 | 
37 |         hnsw_index.set_num_threads(4)  # by default using all available cores
38 | 
39 |         print("Adding %d elements" % (len(data)))
40 |         hnsw_index.add_items(data)
41 |         bf_index.add_items(data)
42 | 
43 |         # Query the elements for themselves and measure recall:
44 |         labels, distances = hnsw_index.knn_query(data, k=1)
45 |         self.assertAlmostEqual(np.mean(labels.reshape(-1) == np.arange(len(data))), 1.0, 3)
46 | 
47 |         print("Querying only even elements")
48 |         # Query the even elements for themselves and measure recall:
49 |         filter_function = lambda id: id%2 == 0
50 |         # Warning: search with a filter works slow in python in multithreaded mode, therefore we set num_threads=1
51 |         labels, distances = hnsw_index.knn_query(data, k=1, num_threads=1, filter=filter_function)
52 |         self.assertAlmostEqual(np.mean(labels.reshape(-1) == np.arange(len(data))), .5, 3)
53 |         # Verify that there are only even elements:
54 |         self.assertTrue(np.max(np.mod(labels, 2)) == 0)
55 | 
56 |         labels, distances = bf_index.knn_query(data, k=1, filter=filter_function)
57 |         self.assertEqual(np.mean(labels.reshape(-1) == np.arange(len(data))), .5)
58 | 


--------------------------------------------------------------------------------
/tests/python/bindings_test_getdata.py:
--------------------------------------------------------------------------------
 1 | import unittest
 2 | 
 3 | import numpy as np
 4 | 
 5 | import hnswlib
 6 | 
 7 | 
 8 | class RandomSelfTestCase(unittest.TestCase):
 9 |     def testGettingItems(self):
10 |         print("\n**** Getting the data by label test ****\n")
11 | 
12 |         dim = 16
13 |         num_elements = 10000
14 | 
15 |         # Generating sample data
16 |         data = np.float32(np.random.random((num_elements, dim)))
17 |         labels = np.arange(0, num_elements)
18 | 
19 |         # Declaring index
20 |         p = hnswlib.Index(space='l2', dim=dim)  # possible options are l2, cosine or ip
21 | 
22 |         # Initiating index
23 |         # max_elements - the maximum number of elements, should be known beforehand
24 |         #     (probably will be made optional in the future)
25 |         #
26 |         # ef_construction - controls index search speed/build speed tradeoff
27 |         # M - is tightly connected with internal dimensionality of the data
28 |         #     strongly affects the memory consumption
29 | 
30 |         p.init_index(max_elements=num_elements, ef_construction=100, M=16)
31 | 
32 |         # Controlling the recall by setting ef:
33 |         # higher ef leads to better accuracy, but slower search
34 |         p.set_ef(100)
35 | 
36 |         p.set_num_threads(4)  # by default using all available cores
37 | 
38 |         # Before adding anything, getting any labels should fail
39 |         self.assertRaises(Exception, lambda: p.get_items(labels))
40 | 
41 |         print("Adding all elements (%d)" % (len(data)))
42 |         p.add_items(data, labels)
43 | 
44 |         # Getting data by label should raise an exception if a scalar is passed:
45 |         self.assertRaises(ValueError, lambda: p.get_items(labels[0]))
46 | 
47 |         # After adding them, all labels should be retrievable
48 |         returned_items_np = p.get_items(labels)
49 |         self.assertTrue((data == returned_items_np).all())
50 | 
51 |         # check returned type of get_items
52 |         self.assertTrue(isinstance(returned_items_np, np.ndarray))
53 |         returned_items_list = p.get_items(labels, return_type="list")
54 |         self.assertTrue(isinstance(returned_items_list, list))
55 |         self.assertTrue(isinstance(returned_items_list[0], list))
56 | 


--------------------------------------------------------------------------------
/tests/python/bindings_test_labels.py:
--------------------------------------------------------------------------------
  1 | import os
  2 | import unittest
  3 | 
  4 | import numpy as np
  5 | 
  6 | import hnswlib
  7 | 
  8 | 
  9 | class RandomSelfTestCase(unittest.TestCase):
 10 |     def testRandomSelf(self):
 11 |         for idx in range(2):
 12 |             print("\n**** Index save-load test ****\n")
 13 | 
 14 |             np.random.seed(idx)
 15 |             dim = 16
 16 |             num_elements = 10000
 17 | 
 18 |             # Generating sample data
 19 |             data = np.float32(np.random.random((num_elements, dim)))
 20 | 
 21 |             # Declaring index
 22 |             p = hnswlib.Index(space='l2', dim=dim)  # possible options are l2, cosine or ip
 23 | 
 24 |             # Initiating index
 25 |             # max_elements - the maximum number of elements, should be known beforehand
 26 |             #     (probably will be made optional in the future)
 27 |             #
 28 |             # ef_construction - controls index search speed/build speed tradeoff
 29 |             # M - is tightly connected with internal dimensionality of the data
 30 |             #     strongly affects the memory consumption
 31 | 
 32 |             p.init_index(max_elements=num_elements, ef_construction=100, M=16)
 33 | 
 34 |             # Controlling the recall by setting ef:
 35 |             # higher ef leads to better accuracy, but slower search
 36 |             p.set_ef(100)
 37 | 
 38 |             p.set_num_threads(4)  # by default using all available cores
 39 | 
 40 |             # We split the data in two batches:
 41 |             data1 = data[:num_elements // 2]
 42 |             data2 = data[num_elements // 2:]
 43 | 
 44 |             print("Adding first batch of %d elements" % (len(data1)))
 45 |             p.add_items(data1)
 46 | 
 47 |             # Query the elements for themselves and measure recall:
 48 |             labels, distances = p.knn_query(data1, k=1)
 49 | 
 50 |             items = p.get_items(labels)
 51 | 
 52 |             # Check the recall:
 53 |             self.assertAlmostEqual(np.mean(labels.reshape(-1) == np.arange(len(data1))), 1.0, 3)
 54 | 
 55 |             # Check that the returned element data is correct:
 56 |             diff_with_gt_labels=np.mean(np.abs(data1-items))
 57 |             self.assertAlmostEqual(diff_with_gt_labels, 0, delta=1e-4)
 58 | 
 59 |             # Serializing and deleting the index.
 60 |             # We need the part to check that serialization is working properly.
 61 | 
 62 |             index_path = 'first_half.bin'
 63 |             print("Saving index to '%s'" % index_path)
 64 |             p.save_index(index_path)
 65 |             print("Saved. Deleting...")
 66 |             del p
 67 |             print("Deleted")
 68 | 
 69 |             print("\n**** Mark delete test ****\n")
 70 |             # Re-initiating, loading the index
 71 |             print("Re-initiating")
 72 |             p = hnswlib.Index(space='l2', dim=dim)
 73 | 
 74 |             print("\nLoading index from '%s'\n" % index_path)
 75 |             p.load_index(index_path)
 76 |             p.set_ef(100)
 77 | 
 78 |             print("Adding the second batch of %d elements" % (len(data2)))
 79 |             p.add_items(data2)
 80 | 
 81 |             # Query the elements for themselves and measure recall:
 82 |             labels, distances = p.knn_query(data, k=1)
 83 |             items = p.get_items(labels)
 84 | 
 85 |             # Check the recall:
 86 |             self.assertAlmostEqual(np.mean(labels.reshape(-1) == np.arange(len(data))), 1.0, 3)
 87 | 
 88 |             # Check that the returned element data is correct:
 89 |             diff_with_gt_labels = np.mean(np.abs(data-items))
 90 |             self.assertAlmostEqual(diff_with_gt_labels, 0, delta=1e-4) # deleting index.
 91 | 
 92 |             # Checking that all labels are returned correctly:
 93 |             sorted_labels = sorted(p.get_ids_list())
 94 |             self.assertEqual(np.sum(~np.asarray(sorted_labels) == np.asarray(range(num_elements))), 0)
 95 | 
 96 |             # Delete data1
 97 |             labels1_deleted, _ = p.knn_query(data1, k=1)
 98 |             # delete probable duplicates from nearest neighbors
 99 |             labels1_deleted_no_dup = set(labels1_deleted.flatten())
100 |             for l in labels1_deleted_no_dup:
101 |                 p.mark_deleted(l)
102 |             labels2, _ = p.knn_query(data2, k=1)
103 |             items = p.get_items(labels2)
104 |             diff_with_gt_labels = np.mean(np.abs(data2-items))
105 |             self.assertAlmostEqual(diff_with_gt_labels, 0, delta=1e-3)
106 | 
107 |             labels1_after, _ = p.knn_query(data1, k=1)
108 |             for la in labels1_after:
109 |                 if la[0] in labels1_deleted_no_dup:
110 |                     print(f"Found deleted label {la[0]} during knn search")
111 |                     self.assertTrue(False)
112 |             print("All the data in data1 are removed")
113 | 
114 |             # Checking saving/loading index with elements marked as deleted
115 |             del_index_path = "with_deleted.bin"
116 |             p.save_index(del_index_path)
117 |             p = hnswlib.Index(space='l2', dim=dim)
118 |             p.load_index(del_index_path)
119 |             p.set_ef(100)
120 | 
121 |             labels1_after, _ = p.knn_query(data1, k=1)
122 |             for la in labels1_after:
123 |                 if la[0] in labels1_deleted_no_dup:
124 |                     print(f"Found deleted label {la[0]} during knn search after index loading")
125 |                     self.assertTrue(False)
126 | 
127 |             # Unmark deleted data
128 |             for l in labels1_deleted_no_dup:
129 |                 p.unmark_deleted(l)
130 |             labels_restored, _ = p.knn_query(data1, k=1)
131 |             self.assertAlmostEqual(np.mean(labels_restored.reshape(-1) == np.arange(len(data1))), 1.0, 3)
132 |             print("All the data in data1 are restored")
133 | 
134 |         os.remove(index_path)
135 |         os.remove(del_index_path)
136 | 


--------------------------------------------------------------------------------
/tests/python/bindings_test_metadata.py:
--------------------------------------------------------------------------------
 1 | import unittest
 2 | 
 3 | import numpy as np
 4 | 
 5 | import hnswlib
 6 | 
 7 | 
 8 | class RandomSelfTestCase(unittest.TestCase):
 9 |     def testMetadata(self):
10 | 
11 |         dim = 16
12 |         num_elements = 10000
13 | 
14 |         # Generating sample data
15 |         data = np.float32(np.random.random((num_elements, dim)))
16 | 
17 |         # Declaring index
18 |         p = hnswlib.Index(space='l2', dim=dim)  # possible options are l2, cosine or ip
19 | 
20 |         # Initing index
21 |         # max_elements - the maximum number of elements, should be known beforehand
22 |         #     (probably will be made optional in the future)
23 |         #
24 |         # ef_construction - controls index search speed/build speed tradeoff
25 |         # M - is tightly connected with internal dimensionality of the data
26 |         #     stronlgy affects the memory consumption
27 | 
28 |         p.init_index(max_elements=num_elements, ef_construction=100, M=16)
29 | 
30 |         # Controlling the recall by setting ef:
31 |         # higher ef leads to better accuracy, but slower search
32 |         p.set_ef(100)
33 | 
34 |         p.set_num_threads(4)  # by default using all available cores
35 | 
36 |         print("Adding all elements (%d)" % (len(data)))
37 |         p.add_items(data)
38 | 
39 |         # test methods
40 |         self.assertEqual(p.get_max_elements(), num_elements)
41 |         self.assertEqual(p.get_current_count(), num_elements)
42 | 
43 |         # test properties
44 |         self.assertEqual(p.space, 'l2')
45 |         self.assertEqual(p.dim, dim)
46 |         self.assertEqual(p.M, 16)
47 |         self.assertEqual(p.ef_construction, 100)
48 |         self.assertEqual(p.max_elements, num_elements)
49 |         self.assertEqual(p.element_count, num_elements)
50 | 


--------------------------------------------------------------------------------
/tests/python/bindings_test_pickle.py:
--------------------------------------------------------------------------------
  1 | import pickle
  2 | import unittest
  3 | 
  4 | import numpy as np
  5 | 
  6 | import hnswlib
  7 | 
  8 | 
  9 | def get_dist(metric, pt1, pt2):
 10 |     if metric == 'l2':
 11 |         return np.sum((pt1-pt2)**2)
 12 |     elif metric == 'ip':
 13 |         return 1. - np.sum(np.multiply(pt1, pt2))
 14 |     elif metric == 'cosine':
 15 |         return 1. - np.sum(np.multiply(pt1, pt2)) / (np.sum(pt1**2) * np.sum(pt2**2))**.5
 16 | 
 17 | 
 18 | def brute_force_distances(metric, items, query_items, k):
 19 |     dists = np.zeros((query_items.shape[0], items.shape[0]))
 20 |     for ii in range(items.shape[0]):
 21 |         for jj in range(query_items.shape[0]):
 22 |             dists[jj,ii] = get_dist(metric, items[ii, :], query_items[jj, :])
 23 | 
 24 |     labels = np.argsort(dists, axis=1) # equivalent, but faster: np.argpartition(dists, range(k), axis=1)
 25 |     dists = np.sort(dists, axis=1)     # equivalent, but faster: np.partition(dists, range(k), axis=1)
 26 | 
 27 |     return labels[:, :k], dists[:, :k]
 28 | 
 29 | 
 30 | def check_ann_results(self, metric, items, query_items, k, ann_l, ann_d, err_thresh=0, total_thresh=0, dists_thresh=0):
 31 |     brute_l, brute_d = brute_force_distances(metric, items, query_items, k)
 32 |     err_total = 0
 33 |     for jj in range(query_items.shape[0]):
 34 |         err = np.sum(np.isin(brute_l[jj, :], ann_l[jj, :], invert=True))
 35 |         if err > 0:
 36 |             print(f"Warning: {err} labels are missing from ann results (k={k}, err_thresh={err_thresh})")
 37 | 
 38 |         if err > err_thresh:
 39 |             err_total += 1
 40 | 
 41 |     self.assertLessEqual(err_total, total_thresh, f"Error: knn_query returned incorrect labels for {err_total} items (k={k})")
 42 | 
 43 |     wrong_dists = np.sum(((brute_d - ann_d)**2.) > 1e-3)
 44 |     if wrong_dists > 0:
 45 |         dists_count = brute_d.shape[0]*brute_d.shape[1]
 46 |         print(f"Warning: {wrong_dists} ann distance values are different from brute-force values (total # of values={dists_count}, dists_thresh={dists_thresh})")
 47 | 
 48 |     self.assertLessEqual(wrong_dists, dists_thresh, msg=f"Error: {wrong_dists} ann distance values are different from brute-force values")
 49 | 
 50 | 
def test_space_main(self, space, dim):
    """Shared pickle-test driver for one (space, dim) combination.

    Builds an index, snapshots it via pickle at three life-cycle stages
    (p0: uninitialized, p1: initialized but empty, p2: populated) and
    checks that all copies hold the same items, return matching knn
    distances, agree with brute-force search, and keep their
    ef/M/ef_construction parameters. `self` is the calling TestCase,
    which supplies sizes and thresholds as attributes.
    """

    # Generating sample data
    data = np.float32(np.random.random((self.num_elements, dim)))
    test_data = np.float32(np.random.random((self.num_test_elements, dim)))

    # Declaring index
    p = hnswlib.Index(space=space, dim=dim)  # possible options are l2, cosine or ip
    print(f"Running pickle tests for {p}")

    p.num_threads = self.num_threads  # by default using all available cores

    p0 = pickle.loads(pickle.dumps(p)) # pickle un-initialized Index
    p.init_index(max_elements=self.num_elements, ef_construction=self.ef_construction, M=self.M)
    p0.init_index(max_elements=self.num_elements, ef_construction=self.ef_construction, M=self.M)

    p.ef = self.ef
    p0.ef = self.ef

    p1 = pickle.loads(pickle.dumps(p)) # pickle Index before adding items

    # add items to ann index p,p0,p1
    p.add_items(data)
    p1.add_items(data)
    p0.add_items(data)

    p2=pickle.loads(pickle.dumps(p)) # pickle Index after adding items

    # All four copies must contain identical item data.
    self.assertTrue(np.allclose(p.get_items(), p0.get_items()), "items for p and p0 must be same")
    self.assertTrue(np.allclose(p0.get_items(), p1.get_items()), "items for p0 and p1 must be same")
    self.assertTrue(np.allclose(p1.get_items(), p2.get_items()), "items for p1 and p2 must be same")

    # Test if returned distances are same
    l, d = p.knn_query(test_data, k=self.k)
    l0, d0 = p0.knn_query(test_data, k=self.k)
    l1, d1 = p1.knn_query(test_data, k=self.k)
    l2, d2 = p2.knn_query(test_data, k=self.k)

    self.assertLessEqual(np.sum(((d-d0)**2.) > 1e-3), self.dists_err_thresh, msg=f"knn distances returned by p and p0 must match")
    self.assertLessEqual(np.sum(((d0-d1)**2.) > 1e-3), self.dists_err_thresh, msg=f"knn distances returned by p0 and p1 must match")
    self.assertLessEqual(np.sum(((d1-d2)**2.) > 1e-3), self.dists_err_thresh, msg=f"knn distances returned by p1 and p2 must match")

    # check if ann results match brute-force search
    #   allow for 2 labels to be missing from ann results
    check_ann_results(self, space, data, test_data, self.k, l, d,
                           err_thresh=self.label_err_thresh,
                           total_thresh=self.item_err_thresh,
                           dists_thresh=self.dists_err_thresh)

    check_ann_results(self, space, data, test_data, self.k, l2, d2,
                           err_thresh=self.label_err_thresh,
                           total_thresh=self.item_err_thresh,
                           dists_thresh=self.dists_err_thresh)

    # Check ef parameter value survived pickling
    self.assertEqual(p.ef, self.ef, "incorrect value of p.ef")
    self.assertEqual(p0.ef, self.ef, "incorrect value of p0.ef")
    self.assertEqual(p2.ef, self.ef, "incorrect value of p2.ef")
    self.assertEqual(p1.ef, self.ef, "incorrect value of p1.ef")

    # Check M parameter value survived pickling
    self.assertEqual(p.M, self.M, "incorrect value of p.M")
    self.assertEqual(p0.M, self.M, "incorrect value of p0.M")
    self.assertEqual(p1.M, self.M, "incorrect value of p1.M")
    self.assertEqual(p2.M, self.M, "incorrect value of p2.M")

    # Check ef_construction parameter value survived pickling
    self.assertEqual(p.ef_construction, self.ef_construction, "incorrect value of p.ef_construction")
    self.assertEqual(p0.ef_construction, self.ef_construction, "incorrect value of p0.ef_construction")
    self.assertEqual(p1.ef_construction, self.ef_construction, "incorrect value of p1.ef_construction")
    self.assertEqual(p2.ef_construction, self.ef_construction, "incorrect value of p2.ef_construction")
122 | 
123 | 
124 | class PickleUnitTests(unittest.TestCase):
125 | 
126 |     def setUp(self):
127 |         self.ef_construction = 200
128 |         self.M = 32
129 |         self.ef = 400
130 | 
131 |         self.num_elements = 1000
132 |         self.num_test_elements = 100
133 | 
134 |         self.num_threads = 4
135 |         self.k = 25
136 | 
137 |         self.label_err_thresh = 5  # max number of missing labels allowed per test item
138 |         self.item_err_thresh = 5   # max number of items allowed with incorrect labels
139 | 
140 |         self.dists_err_thresh = 50 # for two matrices, d1 and d2, dists_err_thresh controls max
141 |                                  # number of value pairs that are allowed to be different in d1 and d2
142 |                                  # i.e., number of values that are (d1-d2)**2>1e-3
143 | 
144 |     def test_inner_product_space(self):
145 |         test_space_main(self, 'ip', 16)
146 | 
147 |     def test_l2_space(self):
148 |         test_space_main(self, 'l2', 53)
149 | 
150 |     def test_cosine_space(self):
151 |         test_space_main(self, 'cosine', 32)
152 | 


--------------------------------------------------------------------------------
/tests/python/bindings_test_recall.py:
--------------------------------------------------------------------------------
  1 | import os
  2 | import hnswlib
  3 | import numpy as np
  4 | import unittest
  5 | 
  6 | 
  7 | class RandomSelfTestCase(unittest.TestCase):
  8 |     def testRandomSelf(self):
  9 |         dim = 32
 10 |         num_elements = 100000
 11 |         k = 10
 12 |         num_queries = 20
 13 | 
 14 |         recall_threshold = 0.95
 15 | 
 16 |         # Generating sample data
 17 |         data = np.float32(np.random.random((num_elements, dim)))
 18 | 
 19 |         # Declaring index
 20 |         hnsw_index = hnswlib.Index(space='l2', dim=dim)  # possible options are l2, cosine or ip
 21 |         bf_index = hnswlib.BFIndex(space='l2', dim=dim)
 22 | 
 23 |         # Initing both hnsw and brute force indices
 24 |         # max_elements - the maximum number of elements (capacity). Will throw an exception if exceeded
 25 |         # during insertion of an element.
 26 |         # The capacity can be increased by saving/loading the index, see below.
 27 |         #
 28 |         # hnsw construction params:
 29 |         # ef_construction - controls index search speed/build speed tradeoff
 30 |         #
 31 |         # M - is tightly connected with internal dimensionality of the data. Strongly affects the memory consumption (~M)
 32 |         # Higher M leads to higher accuracy/run_time at fixed ef/efConstruction
 33 | 
 34 |         hnsw_index.init_index(max_elements=num_elements, ef_construction=200, M=16)
 35 |         bf_index.init_index(max_elements=num_elements)
 36 | 
 37 |         # Controlling the recall for hnsw by setting ef:
 38 |         # higher ef leads to better accuracy, but slower search
 39 |         hnsw_index.set_ef(200)
 40 | 
 41 |         # Set number of threads used during batch search/construction in hnsw
 42 |         # By default using all available cores
 43 |         hnsw_index.set_num_threads(4)
 44 | 
 45 |         print("Adding batch of %d elements" % (len(data)))
 46 |         hnsw_index.add_items(data)
 47 |         bf_index.add_items(data)
 48 | 
 49 |         print("Indices built")
 50 | 
 51 |         # Generating query data
 52 |         query_data = np.float32(np.random.random((num_queries, dim)))
 53 | 
 54 |         # Query the elements and measure recall:
 55 |         labels_hnsw, distances_hnsw = hnsw_index.knn_query(query_data, k)
 56 |         labels_bf, distances_bf = bf_index.knn_query(query_data, k)
 57 | 
 58 |         # Measure recall
 59 |         correct = 0
 60 |         for i in range(num_queries):
 61 |             for label in labels_hnsw[i]:
 62 |                 for correct_label in labels_bf[i]:
 63 |                     if label == correct_label:
 64 |                         correct += 1
 65 |                         break
 66 | 
 67 |         recall_before = float(correct) / (k*num_queries)
 68 |         print("recall is :", recall_before)
 69 |         self.assertGreater(recall_before, recall_threshold)
 70 | 
 71 |         # test serializing  the brute force index
 72 |         index_path = 'bf_index.bin'
 73 |         print("Saving index to '%s'" % index_path)
 74 |         bf_index.save_index(index_path)
 75 |         del bf_index
 76 | 
 77 |         # Re-initiating, loading the index
 78 |         bf_index = hnswlib.BFIndex(space='l2', dim=dim)
 79 | 
 80 |         print("\nLoading index from '%s'\n" % index_path)
 81 |         bf_index.load_index(index_path)
 82 | 
 83 |         # Query the brute force index again to verify that we get the same results
 84 |         labels_bf, distances_bf = bf_index.knn_query(query_data, k)
 85 | 
 86 |         # Measure recall
 87 |         correct = 0
 88 |         for i in range(num_queries):
 89 |             for label in labels_hnsw[i]:
 90 |                 for correct_label in labels_bf[i]:
 91 |                     if label == correct_label:
 92 |                         correct += 1
 93 |                         break
 94 | 
 95 |         recall_after = float(correct) / (k*num_queries)
 96 |         print("recall after reloading is :", recall_after)
 97 | 
 98 |         self.assertEqual(recall_before, recall_after)
 99 | 
100 |         os.remove(index_path)
101 | 


--------------------------------------------------------------------------------
/tests/python/bindings_test_replace.py:
--------------------------------------------------------------------------------
  1 | import os
  2 | import pickle
  3 | import unittest
  4 | 
  5 | import numpy as np
  6 | 
  7 | import hnswlib
  8 | 
  9 | 
 10 | class RandomSelfTestCase(unittest.TestCase):
    def testRandomSelf(self):
        """
            Tests if replace of deleted elements works correctly
            Tests serialization of the index with replaced elements

            Flow: fill the index to capacity with batches 1-2, mark the
            nearest neighbors of batch 2 deleted, then reuse the freed slots
            for batch 3 (replace_deleted=True). Repeats the replacement after
            a save/load round trip (batch 4) and after a pickle round trip
            (batch 3 again), asserting recall each time.
        """
        dim = 16
        num_elements = 5000
        # Capacity is two batches; later batches can only go into freed slots.
        max_num_elements = 2 * num_elements

        recall_threshold = 0.98

        # Generating sample data: four batches with disjoint label ranges.
        print("Generating data")
        # batch 1
        first_id = 0
        last_id = num_elements
        labels1 = np.arange(first_id, last_id)
        data1 = np.float32(np.random.random((num_elements, dim)))
        # batch 2
        first_id += num_elements
        last_id += num_elements
        labels2 = np.arange(first_id, last_id)
        data2 = np.float32(np.random.random((num_elements, dim)))
        # batch 3
        first_id += num_elements
        last_id += num_elements
        labels3 = np.arange(first_id, last_id)
        data3 = np.float32(np.random.random((num_elements, dim)))
        # batch 4
        first_id += num_elements
        last_id += num_elements
        labels4 = np.arange(first_id, last_id)
        data4 = np.float32(np.random.random((num_elements, dim)))

        # Declaring index with slot reuse enabled (allow_replace_deleted=True).
        hnsw_index = hnswlib.Index(space='l2', dim=dim)
        hnsw_index.init_index(max_elements=max_num_elements, ef_construction=200, M=16, allow_replace_deleted=True)

        hnsw_index.set_ef(100)
        hnsw_index.set_num_threads(4)

        # Add batch 1 and 2
        print("Adding batch 1")
        hnsw_index.add_items(data1, labels1)
        print("Adding batch 2")
        hnsw_index.add_items(data2, labels2)  # maximum number of elements is reached

        # Delete nearest neighbors of batch 2
        print("Deleting neighbors of batch 2")
        labels2_deleted, _ = hnsw_index.knn_query(data2, k=1)
        # delete probable duplicates from nearest neighbors
        labels2_deleted_no_dup = set(labels2_deleted.flatten())
        num_duplicates = len(labels2_deleted) - len(labels2_deleted_no_dup)
        for l in labels2_deleted_no_dup:
            hnsw_index.mark_deleted(l)
        # Batch 1 must still be intact after the deletions.
        labels1_found, _ = hnsw_index.knn_query(data1, k=1)
        items = hnsw_index.get_items(labels1_found)
        diff_with_gt_labels = np.mean(np.abs(data1 - items))
        self.assertAlmostEqual(diff_with_gt_labels, 0, delta=1e-3)

        # Deleted labels must never appear in search results.
        labels2_after, _ = hnsw_index.knn_query(data2, k=1)
        for la in labels2_after:
            if la[0] in labels2_deleted_no_dup:
                print(f"Found deleted label {la[0]} during knn search")
                self.assertTrue(False)
        print("All the neighbors of data2 are removed")

        # Replace deleted elements
        print("Inserting batch 3 by replacing deleted elements")
        # Maximum number of elements is reached therefore we cannot add new items
        # but we can replace the deleted ones
        # Note: there may be less than num_elements elements.
        #       As we could delete less than num_elements because of duplicates
        labels3_tr = labels3[0:labels3.shape[0] - num_duplicates]
        data3_tr = data3[0:data3.shape[0] - num_duplicates]
        hnsw_index.add_items(data3_tr, labels3_tr, replace_deleted=True)

        # After replacing, all labels should be retrievable
        print("Checking that remaining labels are in index")
        # Get remaining data from batch 1 and batch 2 after deletion of elements
        remaining_labels = (set(labels1) | set(labels2)) - labels2_deleted_no_dup
        remaining_labels_list = list(remaining_labels)
        comb_data = np.concatenate((data1, data2), axis=0)
        remaining_data = comb_data[remaining_labels_list]

        returned_items = hnsw_index.get_items(remaining_labels_list)
        self.assertTrue((remaining_data == returned_items).all())

        # The replacement batch must also be retrievable by its labels.
        returned_items = hnsw_index.get_items(labels3_tr)
        self.assertTrue((data3_tr == returned_items).all())

        # Check index serialization
        # Delete batch 3
        print("Deleting batch 3")
        for l in labels3_tr:
            hnsw_index.mark_deleted(l)

        # Save index
        index_path = "index.bin"
        print(f"Saving index to {index_path}")
        hnsw_index.save_index(index_path)
        del hnsw_index

        # Reinit and load the index
        hnsw_index = hnswlib.Index(space='l2', dim=dim)  # the space can be changed - keeps the data, alters the distance function.
        hnsw_index.set_num_threads(4)
        print(f"Loading index from {index_path}")
        # allow_replace_deleted must be re-enabled when loading.
        hnsw_index.load_index(index_path, max_elements=max_num_elements, allow_replace_deleted=True)

        # Insert batch 4 into the slots freed by deleting batch 3.
        print("Inserting batch 4 by replacing deleted elements")
        labels4_tr = labels4[0:labels4.shape[0] - num_duplicates]
        data4_tr = data4[0:data4.shape[0] - num_duplicates]
        hnsw_index.add_items(data4_tr, labels4_tr, replace_deleted=True)

        # Check recall
        print("Checking recall")
        labels_found, _ = hnsw_index.knn_query(data4_tr, k=1)
        recall = np.mean(labels_found.reshape(-1) == labels4_tr)
        print(f"Recall for the 4 batch: {recall}")
        self.assertGreater(recall, recall_threshold)

        # Delete batch 4
        print("Deleting batch 4")
        for l in labels4_tr:
            hnsw_index.mark_deleted(l)

        # Replacement must also work on a pickled copy of the index.
        print("Testing pickle serialization")
        hnsw_index_pckl = pickle.loads(pickle.dumps(hnsw_index))
        del hnsw_index
        # Insert batch 3
        print("Inserting batch 3 by replacing deleted elements")
        hnsw_index_pckl.add_items(data3_tr, labels3_tr, replace_deleted=True)

        # Check recall
        print("Checking recall")
        labels_found, _ = hnsw_index_pckl.knn_query(data3_tr, k=1)
        recall = np.mean(labels_found.reshape(-1) == labels3_tr)
        print(f"Recall for the 3 batch: {recall}")
        self.assertGreater(recall, recall_threshold)

        # Clean up the saved index file.
        os.remove(index_path)
153 | 
154 | 
155 |     def test_recall_degradation(self):
156 |         """
157 |             Compares recall of the index with replaced elements and without
158 |             Measures recall degradation
159 |         """
160 |         dim = 16
161 |         num_elements = 10_000
162 |         max_num_elements = 2 * num_elements
163 |         query_size = 1_000
164 |         k = 100
165 | 
166 |         recall_threshold = 0.98
167 |         max_recall_diff = 0.02
168 | 
169 |         # Generating sample data
170 |         print("Generating data")
171 |         # batch 1
172 |         first_id = 0
173 |         last_id = num_elements
174 |         labels1 = np.arange(first_id, last_id)
175 |         data1 = np.float32(np.random.random((num_elements, dim)))
176 |         # batch 2
177 |         first_id += num_elements
178 |         last_id += num_elements
179 |         labels2 = np.arange(first_id, last_id)
180 |         data2 = np.float32(np.random.random((num_elements, dim)))
181 |         # batch 3
182 |         first_id += num_elements
183 |         last_id += num_elements
184 |         labels3 = np.arange(first_id, last_id)
185 |         data3 = np.float32(np.random.random((num_elements, dim)))
186 |         # query to test recall
187 |         query_data = np.float32(np.random.random((query_size, dim)))
188 | 
189 |         # Declaring index
190 |         hnsw_index_no_replace = hnswlib.Index(space='l2', dim=dim)
191 |         hnsw_index_no_replace.init_index(max_elements=max_num_elements, ef_construction=200, M=16, allow_replace_deleted=False)
192 |         hnsw_index_with_replace = hnswlib.Index(space='l2', dim=dim)
193 |         hnsw_index_with_replace.init_index(max_elements=max_num_elements, ef_construction=200, M=16, allow_replace_deleted=True)
194 | 
195 |         bf_index = hnswlib.BFIndex(space='l2', dim=dim)
196 |         bf_index.init_index(max_elements=max_num_elements)
197 | 
198 |         hnsw_index_no_replace.set_ef(100)
199 |         hnsw_index_no_replace.set_num_threads(50)
200 |         hnsw_index_with_replace.set_ef(100)
201 |         hnsw_index_with_replace.set_num_threads(50)
202 | 
203 |         # Add data
204 |         print("Adding data")
205 |         hnsw_index_with_replace.add_items(data1, labels1)
206 |         hnsw_index_with_replace.add_items(data2, labels2)  # maximum number of elements is reached
207 |         bf_index.add_items(data1, labels1)
208 |         bf_index.add_items(data3, labels3)  # maximum number of elements is reached
209 | 
210 |         for l in labels2:
211 |             hnsw_index_with_replace.mark_deleted(l)
212 |         hnsw_index_with_replace.add_items(data3, labels3, replace_deleted=True)
213 | 
214 |         hnsw_index_no_replace.add_items(data1, labels1)
215 |         hnsw_index_no_replace.add_items(data3, labels3)  # maximum number of elements is reached
216 | 
217 |         # Query the elements and measure recall:
218 |         labels_hnsw_with_replace, _ = hnsw_index_with_replace.knn_query(query_data, k)
219 |         labels_hnsw_no_replace, _ = hnsw_index_no_replace.knn_query(query_data, k)
220 |         labels_bf, distances_bf = bf_index.knn_query(query_data, k)
221 | 
222 |         # Measure recall
223 |         correct_with_replace = 0
224 |         correct_no_replace = 0
225 |         for i in range(query_size):
226 |             for label in labels_hnsw_with_replace[i]:
227 |                 for correct_label in labels_bf[i]:
228 |                     if label == correct_label:
229 |                         correct_with_replace += 1
230 |                         break
231 |             for label in labels_hnsw_no_replace[i]:
232 |                 for correct_label in labels_bf[i]:
233 |                     if label == correct_label:
234 |                         correct_no_replace += 1
235 |                         break
236 | 
237 |         recall_with_replace = float(correct_with_replace) / (k*query_size)
238 |         recall_no_replace = float(correct_no_replace) / (k*query_size)
239 |         print("recall with replace:", recall_with_replace)
240 |         print("recall without replace:", recall_no_replace)
241 | 
242 |         recall_diff = abs(recall_with_replace - recall_with_replace)
243 | 
244 |         self.assertGreater(recall_no_replace, recall_threshold)
245 |         self.assertLess(recall_diff, max_recall_diff)
246 | 


--------------------------------------------------------------------------------
/tests/python/bindings_test_resize.py:
--------------------------------------------------------------------------------
 1 | import unittest
 2 | 
 3 | import numpy as np
 4 | 
 5 | import hnswlib
 6 | 
 7 | 
 8 | class RandomSelfTestCase(unittest.TestCase):
 9 |     def testRandomSelf(self):
10 |         for idx in range(16):
11 |             print("\n**** Index resize test ****\n")
12 | 
13 |             np.random.seed(idx)
14 |             dim = 16
15 |             num_elements = 10000
16 | 
17 |             # Generating sample data
18 |             data = np.float32(np.random.random((num_elements, dim)))
19 | 
20 |             # Declaring index
21 |             p = hnswlib.Index(space='l2', dim=dim)  # possible options are l2, cosine or ip
22 | 
23 |             # Initiating index
24 |             # max_elements - the maximum number of elements, should be known beforehand
25 |             #     (probably will be made optional in the future)
26 |             #
27 |             # ef_construction - controls index search speed/build speed tradeoff
28 |             # M - is tightly connected with internal dimensionality of the data
29 |             #     strongly affects the memory consumption
30 | 
31 |             p.init_index(max_elements=num_elements//2, ef_construction=100, M=16)
32 | 
33 |             # Controlling the recall by setting ef:
34 |             # higher ef leads to better accuracy, but slower search
35 |             p.set_ef(20)
36 | 
37 |             p.set_num_threads(idx % 8)  # by default using all available cores
38 | 
39 |             # We split the data in two batches:
40 |             data1 = data[:num_elements // 2]
41 |             data2 = data[num_elements // 2:]
42 | 
43 |             print("Adding first batch of %d elements" % (len(data1)))
44 |             p.add_items(data1)
45 | 
46 |             # Query the elements for themselves and measure recall:
47 |             labels, distances = p.knn_query(data1, k=1)
48 | 
49 |             items = p.get_items(list(range(len(data1))))
50 | 
51 |             # Check the recall:
52 |             self.assertAlmostEqual(np.mean(labels.reshape(-1) == np.arange(len(data1))), 1.0, 3)
53 | 
54 |             # Check that the returned element data is correct:
55 |             diff_with_gt_labels = np.max(np.abs(data1-items))
56 |             self.assertAlmostEqual(diff_with_gt_labels, 0, delta=1e-4)
57 | 
58 |             print("Resizing the index")
59 |             p.resize_index(num_elements)
60 | 
61 |             print("Adding the second batch of %d elements" % (len(data2)))
62 |             p.add_items(data2)
63 | 
64 |             # Query the elements for themselves and measure recall:
65 |             labels, distances = p.knn_query(data, k=1)
66 |             items=p.get_items(list(range(num_elements)))
67 | 
68 |             # Check the recall:
69 |             self.assertAlmostEqual(np.mean(labels.reshape(-1) == np.arange(len(data))), 1.0, 3)
70 | 
71 |             # Check that the returned element data is correct:
72 |             diff_with_gt_labels = np.max(np.abs(data-items))
73 |             self.assertAlmostEqual(diff_with_gt_labels, 0, delta=1e-4)
74 | 
75 |             # Checking that all labels are returned correctly:
76 |             sorted_labels = sorted(p.get_ids_list())
77 |             self.assertEqual(np.sum(~np.asarray(sorted_labels) == np.asarray(range(num_elements))), 0)
78 | 


--------------------------------------------------------------------------------
/tests/python/bindings_test_spaces.py:
--------------------------------------------------------------------------------
 1 | import unittest
 2 | 
 3 | import numpy as np
 4 | 
 5 | import hnswlib
 6 | 
 7 | class RandomSelfTestCase(unittest.TestCase):
 8 |     def testRandomSelf(self):
 9 | 
10 |         data1 = np.asarray([[1, 0, 0],
11 |                             [0, 1, 0],
12 |                             [0, 0, 1],
13 |                             [1, 0, 1],
14 |                             [1, 1, 1],
15 |                             ])
16 | 
17 |         for space, expected_distances in [
18 |             ('l2', [[0., 1., 2., 2., 2.]]),
19 |             ('ip', [[-2., -1., 0., 0., 0.]]),
20 |             ('cosine', [[0, 1.835e-01, 4.23e-01, 4.23e-01, 4.23e-01]])]:
21 | 
22 |             for rightdim in range(1, 128, 3):
23 |                 for leftdim in range(1, 32, 5):
24 |                     data2 = np.concatenate(
25 |                         [np.zeros([data1.shape[0], leftdim]), data1, np.zeros([data1.shape[0], rightdim])], axis=1)
26 |                     dim = data2.shape[1]
27 |                     p = hnswlib.Index(space=space, dim=dim)
28 |                     p.init_index(max_elements=5, ef_construction=100, M=16)
29 | 
30 |                     p.set_ef(10)
31 | 
32 |                     p.add_items(data2)
33 | 
34 |                     # Query the elements for themselves and measure recall:
35 |                     labels, distances = p.knn_query(np.asarray(data2[-1:]), k=5)
36 | 
37 |                     
38 |                     diff=np.mean(np.abs(distances-expected_distances))                    
39 |                     self.assertAlmostEqual(diff, 0, delta=1e-3)
40 | 


--------------------------------------------------------------------------------
/tests/python/bindings_test_stress_mt_replace.py:
--------------------------------------------------------------------------------
 1 | import unittest
 2 | 
 3 | import numpy as np
 4 | 
 5 | import hnswlib
 6 | 
 7 | 
 8 | class RandomSelfTestCase(unittest.TestCase):
 9 |     def testRandomSelf(self):
10 |         dim = 16
11 |         num_elements = 1_000
12 |         max_num_elements = 2 * num_elements
13 | 
14 |         # Generating sample data
15 |         # batch 1
16 |         first_id = 0
17 |         last_id = num_elements
18 |         labels1 = np.arange(first_id, last_id)
19 |         data1 = np.float32(np.random.random((num_elements, dim)))
20 |         # batch 2
21 |         first_id += num_elements
22 |         last_id += num_elements
23 |         labels2 = np.arange(first_id, last_id)
24 |         data2 = np.float32(np.random.random((num_elements, dim)))
25 |         # batch 3
26 |         first_id += num_elements
27 |         last_id += num_elements
28 |         labels3 = np.arange(first_id, last_id)
29 |         data3 = np.float32(np.random.random((num_elements, dim)))
30 | 
31 |         for _ in range(100):
32 |             # Declaring index
33 |             hnsw_index = hnswlib.Index(space='l2', dim=dim)
34 |             hnsw_index.init_index(max_elements=max_num_elements, ef_construction=200, M=16, allow_replace_deleted=True)
35 | 
36 |             hnsw_index.set_ef(100)
37 |             hnsw_index.set_num_threads(50)
38 | 
39 |             # Add batch 1 and 2
40 |             hnsw_index.add_items(data1, labels1)
41 |             hnsw_index.add_items(data2, labels2)  # maximum number of elements is reached
42 | 
43 |             # Delete nearest neighbors of batch 2
44 |             labels2_deleted, _ = hnsw_index.knn_query(data2, k=1)
45 |             labels2_deleted_flat = labels2_deleted.flatten()
46 |             # delete probable duplicates from nearest neighbors
47 |             labels2_deleted_no_dup = set(labels2_deleted_flat)
48 |             for l in labels2_deleted_no_dup:
49 |                 hnsw_index.mark_deleted(l)
50 |             labels1_found, _ = hnsw_index.knn_query(data1, k=1)
51 |             items = hnsw_index.get_items(labels1_found)
52 |             diff_with_gt_labels = np.mean(np.abs(data1 - items))
53 |             self.assertAlmostEqual(diff_with_gt_labels, 0, delta=1e-3)
54 | 
55 |             labels2_after, _ = hnsw_index.knn_query(data2, k=1)
56 |             labels2_after_flat = labels2_after.flatten()
57 |             common = np.intersect1d(labels2_after_flat, labels2_deleted_flat)
58 |             self.assertTrue(common.size == 0)
59 | 
60 |             # Replace deleted elements
61 |             # Maximum number of elements is reached therefore we cannot add new items
62 |             # but we can replace the deleted ones
63 |             # Note: there may be less than num_elements elements.
64 |             #       As we could delete less than num_elements because of duplicates
65 |             num_duplicates = len(labels2_deleted) - len(labels2_deleted_no_dup)
66 |             labels3_tr = labels3[0:labels3.shape[0] - num_duplicates]
67 |             data3_tr = data3[0:data3.shape[0] - num_duplicates]
68 |             hnsw_index.add_items(data3_tr, labels3_tr, replace_deleted=True)
69 | 


--------------------------------------------------------------------------------
/tests/python/draw_git_test_plots.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | import glob
 3 | import pandas as pd
 4 | import matplotlib.pyplot as plt
 5 | import numpy as np
 6 | def plot_data_from_file(file_path):
 7 |     # Load the data, assuming the last column is text
 8 |     data = pd.read_csv(file_path, header=None)
 9 |     rep_size=len(set(data[data.columns[-1]]))
10 |     data.drop(data.columns[-1], axis=1, inplace=True)  # Drop the last column (text)
11 | 
12 |     # Number of numerical columns
13 |     num_columns = data.shape[1]
14 | 
15 |     # Create a subplot for each column
16 |     fig, axes = plt.subplots(num_columns, 1, figsize=(10, 6 * num_columns))
17 |     
18 |     # In case there is only one column, axes will not be an array, so we convert it
19 |     if num_columns == 1:
20 |         axes = [axes]
21 |     
22 |     for i, ax in enumerate(axes):
23 |         idx=0
24 |         ax.scatter(np.asarray(data.index,dtype=np.int64)%rep_size, data[i], label=f'Column {i+1}')
25 |         ax.set_title(f'Column {i+1}')
26 |         ax.set_xlabel('ID Number')
27 |         ax.set_ylabel('Value')
28 |         ax.legend()
29 |         ax.grid(True)
30 | 
31 |     plt.tight_layout()
32 |     plt.suptitle(f'Data from {os.path.basename(file_path)}')
33 | 
34 |     # Save the plot to a file
35 |     plt.savefig(file_path.replace('.txt', '.png'))
36 |     plt.close()
37 | 
def scan_and_plot(directory):
    """Render a plot for every .txt file found directly inside `directory`."""
    # Scan for .txt files and process each one in turn
    for txt_file in glob.glob(os.path.join(directory, '*.txt')):
        print(f'Processing {txt_file}...')
        plot_data_from_file(txt_file)
        print(f'Plot saved for {txt_file}')
# Replace 'your_folder_path' with the path to the folder containing the .txt files
# NOTE: runs at import time — plots every .txt file in the current working directory.
scan_and_plot('./')


--------------------------------------------------------------------------------
/tests/python/git_tester.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | import shutil
 3 | 
 4 | from sys import platform
 5 | from pydriller import Repository
 6 | 
 7 | 
 8 | speedtest_src_path = os.path.join("tests", "python", "speedtest.py")
 9 | speedtest_copy_path = os.path.join("tests", "python", "speedtest2.py")
10 | shutil.copyfile(speedtest_src_path, speedtest_copy_path) # the file has to be outside of git
11 | 
12 | commits = list(Repository('.', from_tag="v0.7.0").traverse_commits())
13 | print("Found commits:")
14 | for idx, commit in enumerate(commits):
15 |     name = commit.msg.replace('\n', ' ').replace('\r', ' ')
16 |     print(idx, commit.hash, name)
17 | 
18 | for commit in commits:
19 |     commit_time = commit.author_date.strftime("%Y-%m-%d %H:%M:%S") 
20 |     author_name = commit.author.name
21 |     name = "auth:"+author_name+"_"+commit_time+"_msg:"+commit.msg.replace('\n', ' ').replace('\r', ' ').replace(",", ";")
22 |     print("\nProcessing", commit.hash, name)
23 |     
24 |     if os.path.exists("build"):
25 |         shutil.rmtree("build")
26 |     os.system(f"git checkout {commit.hash}")
27 |     
28 |     # Checking we have actually switched the branch:
29 |     current_commit=list(Repository('.').traverse_commits())[-1]
30 |     if current_commit.hash != commit.hash:
31 |         print("git checkout failed!!!!")
32 |         print("git checkout failed!!!!")
33 |         print("git checkout failed!!!!")
34 |         print("git checkout failed!!!!")
35 |         continue
36 |     
37 |     print("\n\n--------------------\n\n")
38 |     ret = os.system("python -m pip install .")
39 |     print("Install result:", ret)
40 | 
41 |     if ret != 0:
42 |         print("build failed!!!!")
43 |         print("build failed!!!!")
44 |         print("build failed!!!!")
45 |         print("build failed!!!!")
46 |         continue
47 | 
48 | 
49 |     os.system(f'python {speedtest_copy_path} -n "{commit.hash[:4]}_{name}" -d 16 -t 1')
50 |     os.system(f'python {speedtest_copy_path} -n "{commit.hash[:4]}_{name}" -d 16 -t 64')
51 |     os.system(f'python {speedtest_copy_path} -n "{commit.hash[:4]}_{name}" -d 4 -t 1')
52 |     os.system(f'python {speedtest_copy_path} -n "{commit.hash[:4]}_{name}" -d 4 -t 64')
53 |     os.system(f'python {speedtest_copy_path} -n "{commit.hash[:4]}_{name}" -d 128 -t 1')
54 |     os.system(f'python {speedtest_copy_path} -n "{commit.hash[:4]}_{name}" -d 128 -t 64')
55 | 
56 | 


--------------------------------------------------------------------------------
/tests/python/speedtest.py:
--------------------------------------------------------------------------------
 1 | import hnswlib
 2 | import numpy as np
 3 | import os.path
 4 | import time
 5 | import argparse
 6 | 
 7 | # Use nargs to specify how many arguments an option should take.
 8 | ap = argparse.ArgumentParser()
 9 | ap.add_argument('-d')
10 | ap.add_argument('-n')
11 | ap.add_argument('-t')
12 | args = ap.parse_args()
13 | dim = int(args.d)
14 | name = args.n
15 | threads=int(args.t)
16 | num_elements = 400000
17 | 
18 | # Generating sample data
19 | np.random.seed(1)
20 | data = np.float32(np.random.random((num_elements, dim)))
21 | 
22 | 
23 | # index_path=f'speed_index{dim}.bin'
24 | # Declaring index
25 | p = hnswlib.Index(space='l2', dim=dim)  # possible options are l2, cosine or ip
26 | 
27 | # if not os.path.isfile(index_path) :
28 | 
29 | p.init_index(max_elements=num_elements, ef_construction=60, M=16)
30 | 
31 | # Controlling the recall by setting ef:
32 | # higher ef leads to better accuracy, but slower search
33 | p.set_ef(10)
34 | 
35 | # Set number of threads used during batch search/construction
36 | # By default using all available cores
37 | p.set_num_threads(64)
38 | t0=time.time()
39 | p.add_items(data)
40 | construction_time=time.time()-t0
41 | # Serializing and deleting the index:
42 | 
43 | # print("Saving index to '%s'" % index_path)
44 | # p.save_index(index_path)
45 | p.set_num_threads(threads)
46 | times=[]
47 | time.sleep(1)
48 | p.set_ef(15)
49 | for _ in range(1):
50 |     # p.load_index(index_path)
51 |     for _ in range(3):
52 |         t0=time.time()
53 |         qdata=data[:5000*threads]
54 |         labels, distances = p.knn_query(qdata, k=1)
55 |         tt=time.time()-t0
56 |         times.append(tt)
57 |         recall=np.sum(labels.reshape(-1)==np.arange(len(qdata)))/len(qdata)
58 |         print(f"{tt} seconds, recall= {recall}")    
59 |         
60 | str_out=f"{np.mean(times)}, {np.median(times)}, {np.std(times)}, {construction_time}, {recall}, {name}"
61 | print(str_out)
62 | with open (f"log2_{dim}_t{threads}.txt","a") as f:
63 |     f.write(str_out+"\n")
64 |     f.flush()
65 | 
66 | 


--------------------------------------------------------------------------------