├── .github └── workflows │ ├── release.yml │ └── test.yml ├── .gitignore ├── ALGO_PARAMS.md ├── CMakeLists.txt ├── Cargo.lock ├── Cargo.toml ├── LICENSE ├── MANIFEST.in ├── Makefile ├── README.md ├── TESTING_RECALL.md ├── build.rs ├── examples ├── cpp │ ├── EXAMPLES.md │ ├── example_filter.cpp │ ├── example_mt_filter.cpp │ ├── example_mt_replace_deleted.cpp │ ├── example_mt_search.cpp │ ├── example_replace_deleted.cpp │ └── example_search.cpp └── python │ ├── EXAMPLES.md │ ├── example.py │ ├── example_filter.py │ ├── example_replace_deleted.py │ ├── example_search.py │ ├── example_serialization.py │ └── pyw_hnswlib.py ├── hnswlib ├── bruteforce.h ├── hnswalg.h ├── hnswlib.h ├── space_ip.h ├── space_l2.h └── visited_list_pool.h ├── pyproject.toml ├── python_bindings ├── LazyIndex.py ├── __init__.py ├── bindings.cpp └── setup.py ├── setup.py ├── src ├── bindings.cpp ├── hnsw.rs └── lib.rs └── tests ├── cpp ├── api_test.cpp ├── download_bigann.py ├── getUnormalized_test.cpp ├── main.cpp ├── multiThreadLoad_test.cpp ├── multiThread_replace_test.cpp ├── persistent_test.cpp ├── searchKnnCloserFirst_test.cpp ├── searchKnnWithFilter_test.cpp ├── sift_1b.cpp ├── sift_test.cpp ├── update_gen_data.py └── updates_test.cpp └── python ├── bindings_test.py ├── bindings_test_filter.py ├── bindings_test_getdata.py ├── bindings_test_labels.py ├── bindings_test_metadata.py ├── bindings_test_persistent.py ├── bindings_test_pickle.py ├── bindings_test_recall.py ├── bindings_test_replace.py ├── bindings_test_resize.py ├── bindings_test_spaces.py ├── bindings_test_stress_mt_replace.py ├── git_tester.py └── speedtest.py /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Release 2 | 3 | on: 4 | push: 5 | tags: 6 | - '[0-9]+.[0-9]+.[0-9]*' 7 | 8 | jobs: 9 | 10 | tests: 11 | uses: ./.github/workflows/test.yml 12 | 13 | build: 14 | needs: [tests] 15 | runs-on: ${{matrix.os}} 16 | strategy: 17 | matrix: 18 | os: 
[ubuntu-latest, windows-latest, macos-latest] 19 | steps: 20 | - name: Set up QEMU 21 | uses: docker/setup-qemu-action@v2 22 | if: matrix.os == 'ubuntu-latest' 23 | - uses: actions/checkout@v3 24 | - uses: actions/setup-python@v4 25 | with: 26 | python-version: '3.10' 27 | - name: Install cibuildwheel 28 | run: python -m pip install cibuildwheel==2.19.1 29 | - name: Build wheels 30 | run: python -m cibuildwheel --output-dir dist 31 | env: 32 | CIBW_ENVIRONMENT: HNSWLIB_NO_NATIVE=true 33 | CIBW_ENVIRONMENT_PASS_LINUX: HNSWLIB_NO_NATIVE 34 | CIBW_PROJECT_REQUIRES_PYTHON: ">=3.7" 35 | CIBW_SKIP: "pp* *musllinux* cp312-win*" 36 | CIBW_ARCHS_MACOS: "x86_64 arm64" 37 | CIBW_ARCHS_WINDOWS: "AMD64" 38 | CIBW_ARCHS_LINUX: "x86_64 aarch64" 39 | - name: Upload artifacts 40 | uses: actions/upload-artifact@v3 41 | with: 42 | name: python-package-distributions 43 | path: dist 44 | 45 | upload: 46 | runs-on: ubuntu-latest 47 | needs: build 48 | steps: 49 | - uses: actions/checkout@v3 50 | - uses: actions/setup-python@v4 51 | with: 52 | python-version: "3.10" 53 | - name: Build sdist 54 | run: | 55 | python -m pip install . 
56 | make dist 57 | - name: Download wheels 58 | uses: actions/download-artifact@v3 59 | with: 60 | name: python-package-distributions 61 | path: dist/ 62 | - name: Publish to Test PyPI 63 | uses: pypa/gh-action-pypi-publish@release/v1 64 | with: 65 | password: ${{ secrets.TEST_PYPI_API_TOKEN }} 66 | repository-url: https://test.pypi.org/legacy/ 67 | - name: Publish to PyPI 68 | uses: pypa/gh-action-pypi-publish@release/v1 69 | with: 70 | password: ${{ secrets.PYPI_API_TOKEN }} 71 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Test 2 | 3 | on: 4 | workflow_call: 5 | push: 6 | branches: 7 | - master 8 | pull_request: 9 | branches: 10 | - master 11 | 12 | jobs: 13 | test_python: 14 | runs-on: ${{matrix.os}} 15 | strategy: 16 | matrix: 17 | os: [ubuntu-latest, windows-latest] 18 | python-version: ["3.8", "3.9", "3.10"] 19 | steps: 20 | - uses: actions/checkout@v3 21 | - uses: actions/setup-python@v4 22 | with: 23 | python-version: ${{ matrix.python-version }} 24 | 25 | - name: Build and install 26 | run: python -m pip install . 27 | 28 | - name: Test 29 | timeout-minutes: 15 30 | run: | 31 | python -m unittest discover -v --start-directory examples/python --pattern "example*.py" 32 | python -m unittest discover -v --start-directory tests/python --pattern "bindings_test*.py" 33 | 34 | test_cpp: 35 | runs-on: ${{matrix.os}} 36 | strategy: 37 | matrix: 38 | os: [ubuntu-latest, windows-latest] 39 | steps: 40 | - uses: actions/checkout@v3 41 | - uses: actions/setup-python@v4 42 | with: 43 | python-version: "3.10" 44 | 45 | - name: Build 46 | run: | 47 | mkdir build 48 | cd build 49 | cmake .. 
50 | if [ "$RUNNER_OS" == "Linux" ]; then 51 | make 52 | elif [ "$RUNNER_OS" == "Windows" ]; then 53 | cmake --build ./ --config Release 54 | fi 55 | shell: bash 56 | 57 | - name: Prepare test data 58 | run: | 59 | pip install numpy 60 | cd tests/cpp/ 61 | python update_gen_data.py 62 | shell: bash 63 | 64 | - name: Test 65 | timeout-minutes: 15 66 | run: | 67 | cd build 68 | if [ "$RUNNER_OS" == "Windows" ]; then 69 | cp ./Release/* ./ 70 | fi 71 | ./example_search 72 | ./example_filter 73 | ./example_replace_deleted 74 | ./example_mt_search 75 | ./example_mt_filter 76 | ./example_mt_replace_deleted 77 | ./searchKnnCloserFirst_test 78 | ./searchKnnWithFilter_test 79 | ./multiThreadLoad_test 80 | ./multiThread_replace_test 81 | ./getUnormalized_test 82 | ./test_updates 83 | ./test_updates update 84 | ./persistent_test 85 | shell: bash 86 | 87 | test_rust: 88 | runs-on: ${{matrix.os}} 89 | strategy: 90 | matrix: 91 | os: [ubuntu-latest, windows-latest] 92 | env: 93 | CARGO_TERM_COLOR: always 94 | steps: 95 | - uses: actions/checkout@v4 96 | - name: Test 97 | run: | 98 | cargo test 99 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | hnswlib.egg-info/ 2 | build/ 3 | dist/ 4 | tmp/ 5 | python_bindings/tests/__pycache__/ 6 | *.pyd 7 | hnswlib.cpython*.so 8 | var/ 9 | .idea/ 10 | .vscode/ 11 | .vs/ 12 | **.DS_Store 13 | *.egg-info/ 14 | venv/ 15 | target 16 | -------------------------------------------------------------------------------- /ALGO_PARAMS.md: -------------------------------------------------------------------------------- 1 | # HNSW algorithm parameters 2 | 3 | ## Search parameters: 4 | * ```ef``` - the size of the dynamic list for the nearest neighbors (used during the search). Higher ```ef``` 5 | leads to more accurate but slower search. ```ef``` cannot be set lower than the number of queried nearest neighbors 6 | ```k```. 
The value of ```ef``` can be anything between ```k``` and the size of the dataset. 7 | * ```k``` number of nearest neighbors to be returned as the result. 8 | The ```knn_query``` function returns two numpy arrays, containing labels and distances to the k found nearest 9 | elements for the queries. Note that in case the algorithm is not able to find ```k``` neighbors to all of the queries, 10 | (this can be due to problems with graph or ```k```>size of the dataset) an exception is thrown. 11 | 12 | An example of tuning the parameters can be found in [TESTING_RECALL.md](TESTING_RECALL.md) 13 | 14 | ## Construction parameters: 15 | * ```M``` - the number of bi-directional links created for every new element during construction. Reasonable range for ```M``` 16 | is 2-100. Higher ```M``` work better on datasets with high intrinsic dimensionality and/or high recall, while low ```M``` work 17 | better for datasets with low intrinsic dimensionality and/or low recalls. The parameter also determines the algorithm's memory 18 | consumption, which is roughly ```M * 8-10``` bytes per stored element. 19 | As an example for ```dim```=4 random vectors optimal ```M``` for search is somewhere around 6, while for high dimensional datasets 20 | (word embeddings, good face descriptors), higher ```M``` are required (e.g. ```M```=48-64) for optimal performance at high recall. 21 | The range ```M```=12-48 is ok for most of the use cases. When ```M``` is changed one has to update the other parameters. 22 | Nonetheless, ef and ef_construction parameters can be roughly estimated by assuming that ```M```*```ef_{construction}``` is 23 | a constant. 24 | 25 | * ```ef_construction``` - the parameter has the same meaning as ```ef```, but controls the index_time/index_accuracy. Bigger 26 | ef_construction leads to longer construction, but better index quality. At some point, increasing ef_construction does 27 | not improve the quality of the index. 
One way to check if the selection of ef_construction was ok is to measure a recall 28 | for M nearest neighbor search when ```ef``` =```ef_construction```: if the recall is lower than 0.9, then there is room 29 | for improvement. 30 | * ```num_elements``` - defines the maximum number of elements in the index. The index can be extended by saving/loading (load_index 31 | function has a parameter which defines the new maximum number of elements). 32 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required (VERSION 2.6) 2 | project(hnsw_lib 3 | LANGUAGES CXX) 4 | 5 | add_library(hnswlib INTERFACE) 6 | target_include_directories(hnswlib INTERFACE .) 7 | 8 | if(CMAKE_PROJECT_NAME STREQUAL PROJECT_NAME) 9 | set(CMAKE_CXX_STANDARD 11) 10 | 11 | if (CMAKE_CXX_COMPILER_ID STREQUAL "Clang") 12 | SET( CMAKE_CXX_FLAGS "-Ofast -DNDEBUG -std=c++11 -DHAVE_CXX0X -openmp -march=native -fpic -ftree-vectorize") 13 | elseif (CMAKE_CXX_COMPILER_ID STREQUAL "GNU") 14 | SET( CMAKE_CXX_FLAGS "-Ofast -lrt -DNDEBUG -std=c++11 -DHAVE_CXX0X -march=native -fpic -w -fopenmp -ftree-vectorize -ftree-vectorizer-verbose=0" ) 15 | elseif (CMAKE_CXX_COMPILER_ID STREQUAL "MSVC") 16 | SET( CMAKE_CXX_FLAGS "-Ofast -lrt -DNDEBUG -std=c++11 -DHAVE_CXX0X -openmp -march=native -fpic -w -fopenmp -ftree-vectorize" ) 17 | endif() 18 | 19 | # examples 20 | add_executable(example_search examples/cpp/example_search.cpp) 21 | target_link_libraries(example_search hnswlib) 22 | 23 | add_executable(example_filter examples/cpp/example_filter.cpp) 24 | target_link_libraries(example_filter hnswlib) 25 | 26 | add_executable(example_replace_deleted examples/cpp/example_replace_deleted.cpp) 27 | target_link_libraries(example_replace_deleted hnswlib) 28 | 29 | add_executable(example_mt_search examples/cpp/example_mt_search.cpp) 30 | target_link_libraries(example_mt_search 
hnswlib) 31 | 32 | add_executable(example_mt_filter examples/cpp/example_mt_filter.cpp) 33 | target_link_libraries(example_mt_filter hnswlib) 34 | 35 | add_executable(example_mt_replace_deleted examples/cpp/example_mt_replace_deleted.cpp) 36 | target_link_libraries(example_mt_replace_deleted hnswlib) 37 | 38 | # tests 39 | add_executable(test_updates tests/cpp/updates_test.cpp) 40 | target_link_libraries(test_updates hnswlib) 41 | 42 | add_executable(searchKnnCloserFirst_test tests/cpp/searchKnnCloserFirst_test.cpp) 43 | target_link_libraries(searchKnnCloserFirst_test hnswlib) 44 | 45 | add_executable(searchKnnWithFilter_test tests/cpp/searchKnnWithFilter_test.cpp) 46 | target_link_libraries(searchKnnWithFilter_test hnswlib) 47 | 48 | add_executable(multiThreadLoad_test tests/cpp/multiThreadLoad_test.cpp) 49 | target_link_libraries(multiThreadLoad_test hnswlib) 50 | 51 | add_executable(multiThread_replace_test tests/cpp/multiThread_replace_test.cpp) 52 | target_link_libraries(multiThread_replace_test hnswlib) 53 | 54 | add_executable(main tests/cpp/main.cpp tests/cpp/sift_1b.cpp) 55 | target_link_libraries(main hnswlib) 56 | 57 | add_executable(getUnormalized_test tests/cpp/getUnormalized_test.cpp) 58 | target_link_libraries(getUnormalized_test hnswlib) 59 | 60 | add_executable(persistent_test tests/cpp/persistent_test.cpp) 61 | target_link_libraries(persistent_test hnswlib) 62 | 63 | add_executable(api_tests tests/cpp/api_test.cpp) 64 | target_link_libraries(api_tests hnswlib) 65 | endif() 66 | 67 | -------------------------------------------------------------------------------- /Cargo.lock: -------------------------------------------------------------------------------- 1 | # This file is automatically @generated by Cargo. 2 | # It is not intended for manual editing. 
3 | version = 4 4 | 5 | [[package]] 6 | name = "bitflags" 7 | version = "2.8.0" 8 | source = "registry+https://github.com/rust-lang/crates.io-index" 9 | checksum = "8f68f53c83ab957f72c32642f3868eec03eb974d1fb82e453128456482613d36" 10 | 11 | [[package]] 12 | name = "byteorder" 13 | version = "1.5.0" 14 | source = "registry+https://github.com/rust-lang/crates.io-index" 15 | checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" 16 | 17 | [[package]] 18 | name = "cc" 19 | version = "1.2.13" 20 | source = "registry+https://github.com/rust-lang/crates.io-index" 21 | checksum = "c7777341816418c02e033934a09f20dc0ccaf65a5201ef8a450ae0105a573fda" 22 | dependencies = [ 23 | "shlex", 24 | ] 25 | 26 | [[package]] 27 | name = "cfg-if" 28 | version = "1.0.0" 29 | source = "registry+https://github.com/rust-lang/crates.io-index" 30 | checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" 31 | 32 | [[package]] 33 | name = "crossbeam-deque" 34 | version = "0.8.6" 35 | source = "registry+https://github.com/rust-lang/crates.io-index" 36 | checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51" 37 | dependencies = [ 38 | "crossbeam-epoch", 39 | "crossbeam-utils", 40 | ] 41 | 42 | [[package]] 43 | name = "crossbeam-epoch" 44 | version = "0.9.18" 45 | source = "registry+https://github.com/rust-lang/crates.io-index" 46 | checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" 47 | dependencies = [ 48 | "crossbeam-utils", 49 | ] 50 | 51 | [[package]] 52 | name = "crossbeam-utils" 53 | version = "0.8.21" 54 | source = "registry+https://github.com/rust-lang/crates.io-index" 55 | checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" 56 | 57 | [[package]] 58 | name = "either" 59 | version = "1.13.0" 60 | source = "registry+https://github.com/rust-lang/crates.io-index" 61 | checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" 62 | 63 | [[package]] 64 | name 
= "errno" 65 | version = "0.3.10" 66 | source = "registry+https://github.com/rust-lang/crates.io-index" 67 | checksum = "33d852cb9b869c2a9b3df2f71a3074817f01e1844f839a144f5fcef059a4eb5d" 68 | dependencies = [ 69 | "libc", 70 | "windows-sys", 71 | ] 72 | 73 | [[package]] 74 | name = "fastrand" 75 | version = "2.3.0" 76 | source = "registry+https://github.com/rust-lang/crates.io-index" 77 | checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" 78 | 79 | [[package]] 80 | name = "getrandom" 81 | version = "0.2.15" 82 | source = "registry+https://github.com/rust-lang/crates.io-index" 83 | checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7" 84 | dependencies = [ 85 | "cfg-if", 86 | "libc", 87 | "wasi 0.11.0+wasi-snapshot-preview1", 88 | ] 89 | 90 | [[package]] 91 | name = "getrandom" 92 | version = "0.3.1" 93 | source = "registry+https://github.com/rust-lang/crates.io-index" 94 | checksum = "43a49c392881ce6d5c3b8cb70f98717b7c07aabbdff06687b9030dbfbe2725f8" 95 | dependencies = [ 96 | "cfg-if", 97 | "libc", 98 | "wasi 0.13.3+wasi-0.2.2", 99 | "windows-targets", 100 | ] 101 | 102 | [[package]] 103 | name = "hnswlib" 104 | version = "0.8.0" 105 | dependencies = [ 106 | "cc", 107 | "rand", 108 | "rayon", 109 | "tempfile", 110 | "thiserror", 111 | ] 112 | 113 | [[package]] 114 | name = "libc" 115 | version = "0.2.169" 116 | source = "registry+https://github.com/rust-lang/crates.io-index" 117 | checksum = "b5aba8db14291edd000dfcc4d620c7ebfb122c613afb886ca8803fa4e128a20a" 118 | 119 | [[package]] 120 | name = "linux-raw-sys" 121 | version = "0.4.15" 122 | source = "registry+https://github.com/rust-lang/crates.io-index" 123 | checksum = "d26c52dbd32dccf2d10cac7725f8eae5296885fb5703b261f7d0a0739ec807ab" 124 | 125 | [[package]] 126 | name = "once_cell" 127 | version = "1.20.3" 128 | source = "registry+https://github.com/rust-lang/crates.io-index" 129 | checksum = "945462a4b81e43c4e3ba96bd7b49d834c6f61198356aa858733bc4acf3cbe62e" 130 
| 131 | [[package]] 132 | name = "ppv-lite86" 133 | version = "0.2.20" 134 | source = "registry+https://github.com/rust-lang/crates.io-index" 135 | checksum = "77957b295656769bb8ad2b6a6b09d897d94f05c41b069aede1fcdaa675eaea04" 136 | dependencies = [ 137 | "zerocopy", 138 | ] 139 | 140 | [[package]] 141 | name = "proc-macro2" 142 | version = "1.0.93" 143 | source = "registry+https://github.com/rust-lang/crates.io-index" 144 | checksum = "60946a68e5f9d28b0dc1c21bb8a97ee7d018a8b322fa57838ba31cc878e22d99" 145 | dependencies = [ 146 | "unicode-ident", 147 | ] 148 | 149 | [[package]] 150 | name = "quote" 151 | version = "1.0.38" 152 | source = "registry+https://github.com/rust-lang/crates.io-index" 153 | checksum = "0e4dccaaaf89514f546c693ddc140f729f958c247918a13380cccc6078391acc" 154 | dependencies = [ 155 | "proc-macro2", 156 | ] 157 | 158 | [[package]] 159 | name = "rand" 160 | version = "0.8.5" 161 | source = "registry+https://github.com/rust-lang/crates.io-index" 162 | checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" 163 | dependencies = [ 164 | "libc", 165 | "rand_chacha", 166 | "rand_core", 167 | ] 168 | 169 | [[package]] 170 | name = "rand_chacha" 171 | version = "0.3.1" 172 | source = "registry+https://github.com/rust-lang/crates.io-index" 173 | checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" 174 | dependencies = [ 175 | "ppv-lite86", 176 | "rand_core", 177 | ] 178 | 179 | [[package]] 180 | name = "rand_core" 181 | version = "0.6.4" 182 | source = "registry+https://github.com/rust-lang/crates.io-index" 183 | checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" 184 | dependencies = [ 185 | "getrandom 0.2.15", 186 | ] 187 | 188 | [[package]] 189 | name = "rayon" 190 | version = "1.10.0" 191 | source = "registry+https://github.com/rust-lang/crates.io-index" 192 | checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa" 193 | dependencies = [ 194 | "either", 195 | 
"rayon-core", 196 | ] 197 | 198 | [[package]] 199 | name = "rayon-core" 200 | version = "1.12.1" 201 | source = "registry+https://github.com/rust-lang/crates.io-index" 202 | checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2" 203 | dependencies = [ 204 | "crossbeam-deque", 205 | "crossbeam-utils", 206 | ] 207 | 208 | [[package]] 209 | name = "rustix" 210 | version = "0.38.44" 211 | source = "registry+https://github.com/rust-lang/crates.io-index" 212 | checksum = "fdb5bc1ae2baa591800df16c9ca78619bf65c0488b41b96ccec5d11220d8c154" 213 | dependencies = [ 214 | "bitflags", 215 | "errno", 216 | "libc", 217 | "linux-raw-sys", 218 | "windows-sys", 219 | ] 220 | 221 | [[package]] 222 | name = "shlex" 223 | version = "1.3.0" 224 | source = "registry+https://github.com/rust-lang/crates.io-index" 225 | checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" 226 | 227 | [[package]] 228 | name = "syn" 229 | version = "2.0.98" 230 | source = "registry+https://github.com/rust-lang/crates.io-index" 231 | checksum = "36147f1a48ae0ec2b5b3bc5b537d267457555a10dc06f3dbc8cb11ba3006d3b1" 232 | dependencies = [ 233 | "proc-macro2", 234 | "quote", 235 | "unicode-ident", 236 | ] 237 | 238 | [[package]] 239 | name = "tempfile" 240 | version = "3.16.0" 241 | source = "registry+https://github.com/rust-lang/crates.io-index" 242 | checksum = "38c246215d7d24f48ae091a2902398798e05d978b24315d6efbc00ede9a8bb91" 243 | dependencies = [ 244 | "cfg-if", 245 | "fastrand", 246 | "getrandom 0.3.1", 247 | "once_cell", 248 | "rustix", 249 | "windows-sys", 250 | ] 251 | 252 | [[package]] 253 | name = "thiserror" 254 | version = "1.0.69" 255 | source = "registry+https://github.com/rust-lang/crates.io-index" 256 | checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52" 257 | dependencies = [ 258 | "thiserror-impl", 259 | ] 260 | 261 | [[package]] 262 | name = "thiserror-impl" 263 | version = "1.0.69" 264 | source = 
"registry+https://github.com/rust-lang/crates.io-index" 265 | checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" 266 | dependencies = [ 267 | "proc-macro2", 268 | "quote", 269 | "syn", 270 | ] 271 | 272 | [[package]] 273 | name = "unicode-ident" 274 | version = "1.0.16" 275 | source = "registry+https://github.com/rust-lang/crates.io-index" 276 | checksum = "a210d160f08b701c8721ba1c726c11662f877ea6b7094007e1ca9a1041945034" 277 | 278 | [[package]] 279 | name = "wasi" 280 | version = "0.11.0+wasi-snapshot-preview1" 281 | source = "registry+https://github.com/rust-lang/crates.io-index" 282 | checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" 283 | 284 | [[package]] 285 | name = "wasi" 286 | version = "0.13.3+wasi-0.2.2" 287 | source = "registry+https://github.com/rust-lang/crates.io-index" 288 | checksum = "26816d2e1a4a36a2940b96c5296ce403917633dff8f3440e9b236ed6f6bacad2" 289 | dependencies = [ 290 | "wit-bindgen-rt", 291 | ] 292 | 293 | [[package]] 294 | name = "windows-sys" 295 | version = "0.59.0" 296 | source = "registry+https://github.com/rust-lang/crates.io-index" 297 | checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b" 298 | dependencies = [ 299 | "windows-targets", 300 | ] 301 | 302 | [[package]] 303 | name = "windows-targets" 304 | version = "0.52.6" 305 | source = "registry+https://github.com/rust-lang/crates.io-index" 306 | checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" 307 | dependencies = [ 308 | "windows_aarch64_gnullvm", 309 | "windows_aarch64_msvc", 310 | "windows_i686_gnu", 311 | "windows_i686_gnullvm", 312 | "windows_i686_msvc", 313 | "windows_x86_64_gnu", 314 | "windows_x86_64_gnullvm", 315 | "windows_x86_64_msvc", 316 | ] 317 | 318 | [[package]] 319 | name = "windows_aarch64_gnullvm" 320 | version = "0.52.6" 321 | source = "registry+https://github.com/rust-lang/crates.io-index" 322 | checksum = 
"32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" 323 | 324 | [[package]] 325 | name = "windows_aarch64_msvc" 326 | version = "0.52.6" 327 | source = "registry+https://github.com/rust-lang/crates.io-index" 328 | checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" 329 | 330 | [[package]] 331 | name = "windows_i686_gnu" 332 | version = "0.52.6" 333 | source = "registry+https://github.com/rust-lang/crates.io-index" 334 | checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" 335 | 336 | [[package]] 337 | name = "windows_i686_gnullvm" 338 | version = "0.52.6" 339 | source = "registry+https://github.com/rust-lang/crates.io-index" 340 | checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" 341 | 342 | [[package]] 343 | name = "windows_i686_msvc" 344 | version = "0.52.6" 345 | source = "registry+https://github.com/rust-lang/crates.io-index" 346 | checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" 347 | 348 | [[package]] 349 | name = "windows_x86_64_gnu" 350 | version = "0.52.6" 351 | source = "registry+https://github.com/rust-lang/crates.io-index" 352 | checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" 353 | 354 | [[package]] 355 | name = "windows_x86_64_gnullvm" 356 | version = "0.52.6" 357 | source = "registry+https://github.com/rust-lang/crates.io-index" 358 | checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" 359 | 360 | [[package]] 361 | name = "windows_x86_64_msvc" 362 | version = "0.52.6" 363 | source = "registry+https://github.com/rust-lang/crates.io-index" 364 | checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" 365 | 366 | [[package]] 367 | name = "wit-bindgen-rt" 368 | version = "0.33.0" 369 | source = "registry+https://github.com/rust-lang/crates.io-index" 370 | checksum = "3268f3d866458b787f390cf61f4bbb563b922d091359f9608842999eaee3943c" 371 | dependencies = [ 372 
| "bitflags", 373 | ] 374 | 375 | [[package]] 376 | name = "zerocopy" 377 | version = "0.7.35" 378 | source = "registry+https://github.com/rust-lang/crates.io-index" 379 | checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" 380 | dependencies = [ 381 | "byteorder", 382 | "zerocopy-derive", 383 | ] 384 | 385 | [[package]] 386 | name = "zerocopy-derive" 387 | version = "0.7.35" 388 | source = "registry+https://github.com/rust-lang/crates.io-index" 389 | checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" 390 | dependencies = [ 391 | "proc-macro2", 392 | "quote", 393 | "syn", 394 | ] 395 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "hnswlib" 3 | edition = "2021" 4 | version = "0.8.1" 5 | 6 | [lib] 7 | path = "src/lib.rs" 8 | 9 | [dependencies] 10 | thiserror = "1.0" 11 | 12 | [dev-dependencies] 13 | tempfile = "3.14.0" 14 | rayon = "1.10.0" 15 | rand = "0.8.5" 16 | 17 | [build-dependencies] 18 | cc = "1.2" 19 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "{}" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright {yyyy} {name of copyright owner} 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include hnswlib/*.h 2 | include LICENSE 3 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | pypi: dist 2 | twine upload dist/* 3 | 4 | dist: 5 | -rm dist/* 6 | pip install build 7 | python3 -m build --sdist 8 | 9 | test: 10 | python3 -m unittest discover --start-directory tests/python --pattern "bindings_test*.py" 11 | 12 | clean: 13 | rm -rf *.egg-info build dist tmp var tests/__pycache__ hnswlib.cpython*.so 14 | 15 | .PHONY: dist 16 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Chroma-Hnswlib - fast approximate nearest neighbor search 2 | Chromas fork of https://github.com/nmslib/hnswlib 3 | 4 | ## Build & Release 5 | 6 | Wheels are automatically built and pushed to PyPI for multiple 7 
| platforms via GitHub actions using the 8 | [cibuildwheel](https://github.com/pypa/cibuildwheel). 9 | 10 | The `Publish` GitHub Action is configured to run whenever a version 11 | tag (a tag string with three period-delimited numbers) is pushed. It 12 | is necessary to ensure that the version number in `setup.py` has also 13 | been updated, or else the `Publish` action will fail. You must also update the `version` field in `Cargo.toml`. 14 | 15 | ### Building AVX Extensions 16 | 17 | For maximum compatibility, the distributed wheels are not compiled to 18 | make use of Advanced Vector Extensions (AVX). If your hardware 19 | supports AVX, you may get better performance by recompiling this 20 | library on the machine on which it is intended to run. 21 | 22 | To force recompilation when installing, specify the `--no-binary 23 | chroma-hnswlib` option to PIP when installing dependencies. This can 24 | be added to your `pip install` command, for example: 25 | 26 | ``` 27 | pip install -r requirements.txt --no-binary chroma-hnswlib 28 | ``` 29 | 30 | You can also put the `--no-binary` directive [in your requirements.txt](https://pip.pypa.io/en/stable/cli/pip_install/#install-no-binary). 31 | 32 | If you've already installed dependencies, you must first uninstall 33 | `chroma-hnswlib` using `pip uninstall chroma-hnswlib` to remove the 34 | precompiled version before reinstalling. 35 | -------------------------------------------------------------------------------- /TESTING_RECALL.md: -------------------------------------------------------------------------------- 1 | # Testing recall 2 | 3 | Selecting HNSW parameters for a specific use case highly impacts the search quality. One way to test the quality of the constructed index is to compare the HNSW search results to the actual results (i.e., the actual `k` nearest neighbors).
4 | For that cause, the API enables creating a simple "brute-force" index in which vectors are stored as is, and searching for the `k` nearest neighbors to a query vector requires going over the entire index. 5 | Comparing between HNSW and brute-force results may help with finding the desired HNSW parameters for achieving a satisfying recall, based on the index size and data dimension. 6 | 7 | ### Brute force index API 8 | `hnswlib.BFIndex(space, dim)` creates a non-initialized index in space `space` with integer dimension `dim`. 9 | 10 | `hnswlib.BFIndex` methods: 11 | 12 | `init_index(max_elements)` initializes the index with no elements. 13 | 14 | max_elements defines the maximum number of elements that can be stored in the structure. 15 | 16 | `add_items(data, ids)` inserts the data (numpy array of vectors, shape:`N*dim`) into the structure. 17 | `ids` are optional N-size numpy array of integer labels for all elements in data. 18 | 19 | `delete_vector(label)` deletes the element associated with the given `label` so it will be omitted from search results. 20 | 21 | `knn_query(data, k = 1)` makes a batch query for the `k` closest elements for each element of the 22 | `data` (shape:`N*dim`). Returns a numpy array (shape:`N*k`). 23 | 24 | `load_index(path_to_index, max_elements = 0)` loads the index from persistence to the uninitialized index. 25 | 26 | `save_index(path_to_index)` saves the index to persistence.
27 | 28 | ### measuring recall example 29 | 30 | ```python 31 | import hnswlib 32 | import numpy as np 33 | 34 | dim = 32 35 | num_elements = 100000 36 | k = 10 37 | nun_queries = 10 38 | 39 | # Generating sample data 40 | data = np.float32(np.random.random((num_elements, dim))) 41 | 42 | # Declaring index 43 | hnsw_index = hnswlib.Index(space='l2', dim=dim) # possible options are l2, cosine or ip 44 | bf_index = hnswlib.BFIndex(space='l2', dim=dim) 45 | 46 | # Initing both hnsw and brute force indices 47 | # max_elements - the maximum number of elements (capacity). Will throw an exception if exceeded 48 | # during insertion of an element. 49 | # The capacity can be increased by saving/loading the index, see below. 50 | # 51 | # hnsw construction params: 52 | # ef_construction - controls index search speed/build speed tradeoff 53 | # 54 | # M - is tightly connected with internal dimensionality of the data. Strongly affects the memory consumption (~M) 55 | # Higher M leads to higher accuracy/run_time at fixed ef/efConstruction 56 | 57 | hnsw_index.init_index(max_elements=num_elements, ef_construction=200, M=16) 58 | bf_index.init_index(max_elements=num_elements) 59 | 60 | # Controlling the recall for hnsw by setting ef: 61 | # higher ef leads to better accuracy, but slower search 62 | hnsw_index.set_ef(200) 63 | 64 | # Set number of threads used during batch search/construction in hnsw 65 | # By default using all available cores 66 | hnsw_index.set_num_threads(1) 67 | 68 | print("Adding batch of %d elements" % (len(data))) 69 | hnsw_index.add_items(data) 70 | bf_index.add_items(data) 71 | 72 | print("Indices built") 73 | 74 | # Generating query data 75 | query_data = np.float32(np.random.random((nun_queries, dim))) 76 | 77 | # Query the elements and measure recall: 78 | labels_hnsw, distances_hnsw = hnsw_index.knn_query(query_data, k) 79 | labels_bf, distances_bf = bf_index.knn_query(query_data, k) 80 | 81 | # Measure recall 82 | correct = 0 83 | for i in 
range(nun_queries): 84 | for label in labels_hnsw[i]: 85 | for correct_label in labels_bf[i]: 86 | if label == correct_label: 87 | correct += 1 88 | break 89 | 90 | print("recall is :", float(correct)/(k*nun_queries)) 91 | ``` 92 | -------------------------------------------------------------------------------- /build.rs: -------------------------------------------------------------------------------- 1 | fn main() -> Result<(), Box> { 2 | // Tell cargo to rerun this build script if the bindings change. 3 | println!("cargo:rerun-if-changed=src/bindings.cpp"); 4 | // Compile the hnswlib bindings. 5 | cc::Build::new() 6 | .cpp(true) 7 | .file("src/bindings.cpp") 8 | .flag("-std=c++11") 9 | .flag("-Ofast") 10 | .flag("-DHAVE_CXX0X") 11 | .flag("-fPIC") 12 | .flag("-ftree-vectorize") 13 | .flag("-w") 14 | .compile("bindings"); 15 | 16 | Ok(()) 17 | } 18 | -------------------------------------------------------------------------------- /examples/cpp/EXAMPLES.md: -------------------------------------------------------------------------------- 1 | # C++ examples 2 | 3 | Creating index, inserting elements, searching and serialization 4 | ```cpp 5 | #include "../../hnswlib/hnswlib.h" 6 | 7 | 8 | int main() { 9 | int dim = 16; // Dimension of the elements 10 | int max_elements = 10000; // Maximum number of elements, should be known beforehand 11 | int M = 16; // Tightly connected with internal dimensionality of the data 12 | // strongly affects the memory consumption 13 | int ef_construction = 200; // Controls index search speed/build speed tradeoff 14 | 15 | // Initing index 16 | hnswlib::L2Space space(dim); 17 | hnswlib::HierarchicalNSW* alg_hnsw = new hnswlib::HierarchicalNSW(&space, max_elements, M, ef_construction); 18 | 19 | // Generate random data 20 | std::mt19937 rng; 21 | rng.seed(47); 22 | std::uniform_real_distribution<> distrib_real; 23 | float* data = new float[dim * max_elements]; 24 | for (int i = 0; i < dim * max_elements; i++) { 25 | data[i] = 
distrib_real(rng); 26 | } 27 | 28 | // Add data to index 29 | for (int i = 0; i < max_elements; i++) { 30 | alg_hnsw->addPoint(data + i * dim, i); 31 | } 32 | 33 | // Query the elements for themselves and measure recall 34 | float correct = 0; 35 | for (int i = 0; i < max_elements; i++) { 36 | std::priority_queue> result = alg_hnsw->searchKnn(data + i * dim, 1); 37 | hnswlib::labeltype label = result.top().second; 38 | if (label == i) correct++; 39 | } 40 | float recall = correct / max_elements; 41 | std::cout << "Recall: " << recall << "\n"; 42 | 43 | // Serialize index 44 | std::string hnsw_path = "hnsw.bin"; 45 | alg_hnsw->saveIndex(hnsw_path); 46 | delete alg_hnsw; 47 | 48 | // Deserialize index and check recall 49 | alg_hnsw = new hnswlib::HierarchicalNSW(&space, hnsw_path); 50 | correct = 0; 51 | for (int i = 0; i < max_elements; i++) { 52 | std::priority_queue> result = alg_hnsw->searchKnn(data + i * dim, 1); 53 | hnswlib::labeltype label = result.top().second; 54 | if (label == i) correct++; 55 | } 56 | recall = (float)correct / max_elements; 57 | std::cout << "Recall of deserialized index: " << recall << "\n"; 58 | 59 | delete[] data; 60 | delete alg_hnsw; 61 | return 0; 62 | } 63 | ``` 64 | 65 | An example of filtering with a boolean function during the search: 66 | ```cpp 67 | #include "../../hnswlib/hnswlib.h" 68 | 69 | 70 | // Filter that allows labels divisible by divisor 71 | class PickDivisibleIds: public hnswlib::BaseFilterFunctor { 72 | unsigned int divisor = 1; 73 | public: 74 | PickDivisibleIds(unsigned int divisor): divisor(divisor) { 75 | assert(divisor != 0); 76 | } 77 | bool operator()(hnswlib::labeltype label_id) { 78 | return label_id % divisor == 0; 79 | } 80 | }; 81 | 82 | 83 | int main() { 84 | int dim = 16; // Dimension of the elements 85 | int max_elements = 10000; // Maximum number of elements, should be known beforehand 86 | int M = 16; // Tightly connected with internal dimensionality of the data 87 | // strongly affects the memory 
consumption 88 | int ef_construction = 200; // Controls index search speed/build speed tradeoff 89 | 90 | // Initing index 91 | hnswlib::L2Space space(dim); 92 | hnswlib::HierarchicalNSW* alg_hnsw = new hnswlib::HierarchicalNSW(&space, max_elements, M, ef_construction); 93 | 94 | // Generate random data 95 | std::mt19937 rng; 96 | rng.seed(47); 97 | std::uniform_real_distribution<> distrib_real; 98 | float* data = new float[dim * max_elements]; 99 | for (int i = 0; i < dim * max_elements; i++) { 100 | data[i] = distrib_real(rng); 101 | } 102 | 103 | // Add data to index 104 | for (int i = 0; i < max_elements; i++) { 105 | alg_hnsw->addPoint(data + i * dim, i); 106 | } 107 | 108 | // Create filter that allows only even labels 109 | PickDivisibleIds pickIdsDivisibleByTwo(2); 110 | 111 | // Query the elements for themselves with filter and check returned labels 112 | int k = 10; 113 | for (int i = 0; i < max_elements; i++) { 114 | std::vector> result = alg_hnsw->searchKnnCloserFirst(data + i * dim, k, &pickIdsDivisibleByTwo); 115 | for (auto item: result) { 116 | if (item.second % 2 == 1) std::cout << "Error: found odd label\n"; 117 | } 118 | } 119 | 120 | delete[] data; 121 | delete alg_hnsw; 122 | return 0; 123 | } 124 | ``` 125 | 126 | An example with reusing the memory of the deleted elements when new elements are being added (via `allow_replace_deleted` flag): 127 | ```cpp 128 | #include "../../hnswlib/hnswlib.h" 129 | 130 | 131 | int main() { 132 | int dim = 16; // Dimension of the elements 133 | int max_elements = 10000; // Maximum number of elements, should be known beforehand 134 | int M = 16; // Tightly connected with internal dimensionality of the data 135 | // strongly affects the memory consumption 136 | int ef_construction = 200; // Controls index search speed/build speed tradeoff 137 | 138 | // Initing index 139 | hnswlib::L2Space space(dim); 140 | hnswlib::HierarchicalNSW* alg_hnsw = new hnswlib::HierarchicalNSW(&space, max_elements, M, 
ef_construction, 100, true); 141 | 142 | // Generate random data 143 | std::mt19937 rng; 144 | rng.seed(47); 145 | std::uniform_real_distribution<> distrib_real; 146 | float* data = new float[dim * max_elements]; 147 | for (int i = 0; i < dim * max_elements; i++) { 148 | data[i] = distrib_real(rng); 149 | } 150 | 151 | // Add data to index 152 | for (int i = 0; i < max_elements; i++) { 153 | alg_hnsw->addPoint(data + i * dim, i); 154 | } 155 | 156 | // Mark first half of elements as deleted 157 | int num_deleted = max_elements / 2; 158 | for (int i = 0; i < num_deleted; i++) { 159 | alg_hnsw->markDelete(i); 160 | } 161 | 162 | float* add_data = new float[dim * num_deleted]; 163 | for (int i = 0; i < dim * num_deleted; i++) { 164 | add_data[i] = distrib_real(rng); 165 | } 166 | 167 | // Replace deleted data with new elements 168 | // Maximum number of elements is reached therefore we cannot add new items, 169 | // but we can replace the deleted ones by using replace_deleted=true 170 | for (int i = 0; i < num_deleted; i++) { 171 | int label = max_elements + i; 172 | alg_hnsw->addPoint(add_data + i * dim, label, true); 173 | } 174 | 175 | delete[] data; 176 | delete[] add_data; 177 | delete alg_hnsw; 178 | return 0; 179 | } 180 | ``` 181 | 182 | Multithreaded examples: 183 | * Creating index, inserting elements, searching [example_mt_search.cpp](example_mt_search.cpp) 184 | * Filtering during the search with a boolean function [example_mt_filter.cpp](example_mt_filter.cpp) 185 | * Reusing the memory of the deleted elements when new elements are being added [example_mt_replace_deleted.cpp](example_mt_replace_deleted.cpp) -------------------------------------------------------------------------------- /examples/cpp/example_filter.cpp: -------------------------------------------------------------------------------- 1 | #include "../../hnswlib/hnswlib.h" 2 | 3 | // Filter that allows labels divisible by divisor 4 | class PickDivisibleIds : public 
hnswlib::BaseFilterFunctor 5 | { 6 | unsigned int divisor = 1; 7 | 8 | public: 9 | PickDivisibleIds(unsigned int divisor) : divisor(divisor) 10 | { 11 | assert(divisor != 0); 12 | } 13 | bool operator()(hnswlib::labeltype label_id) 14 | { 15 | return label_id % divisor == 0; 16 | } 17 | }; 18 | 19 | int main() 20 | { 21 | int dim = 16; // Dimension of the elements 22 | int max_elements = 10000; // Maximum number of elements, should be known beforehand 23 | int M = 16; // Tightly connected with internal dimensionality of the data 24 | // strongly affects the memory consumption 25 | int ef_construction = 200; // Controls index search speed/build speed tradeoff 26 | 27 | // Initing index 28 | hnswlib::L2Space space(dim); 29 | hnswlib::HierarchicalNSW *alg_hnsw = new hnswlib::HierarchicalNSW(&space, max_elements, M, ef_construction); 30 | 31 | // Generate random data 32 | std::mt19937 rng; 33 | rng.seed(47); 34 | std::uniform_real_distribution<> distrib_real; 35 | float *data = new float[dim * max_elements]; 36 | for (int i = 0; i < dim * max_elements; i++) 37 | { 38 | data[i] = distrib_real(rng); 39 | } 40 | 41 | // Add data to index 42 | for (int i = 0; i < max_elements; i++) 43 | { 44 | alg_hnsw->addPoint(data + i * dim, i); 45 | } 46 | 47 | // Create filter that allows only even labels 48 | PickDivisibleIds pickIdsDivisibleByTwo(2); 49 | 50 | // Query the elements for themselves with filter and check returned labels 51 | int k = 10; 52 | for (int i = 0; i < max_elements; i++) 53 | { 54 | std::vector> result = alg_hnsw->searchKnnCloserFirst(data + i * dim, k, &pickIdsDivisibleByTwo); 55 | for (auto item : result) 56 | { 57 | if (item.second % 2 == 1) 58 | std::cout << "Error: found odd label\n"; 59 | } 60 | } 61 | 62 | delete[] data; 63 | delete alg_hnsw; 64 | return 0; 65 | } 66 | -------------------------------------------------------------------------------- /examples/cpp/example_mt_filter.cpp: 
-------------------------------------------------------------------------------- 1 | #include "../../hnswlib/hnswlib.h" 2 | #include 3 | 4 | // Multithreaded executor 5 | // The helper function copied from python_bindings/bindings.cpp (and that itself is copied from nmslib) 6 | // An alternative is using #pragme omp parallel for or any other C++ threading 7 | template 8 | inline void ParallelFor(size_t start, size_t end, size_t numThreads, Function fn) 9 | { 10 | if (numThreads <= 0) 11 | { 12 | numThreads = std::thread::hardware_concurrency(); 13 | } 14 | 15 | if (numThreads == 1) 16 | { 17 | for (size_t id = start; id < end; id++) 18 | { 19 | fn(id, 0); 20 | } 21 | } 22 | else 23 | { 24 | std::vector threads; 25 | std::atomic current(start); 26 | 27 | // keep track of exceptions in threads 28 | // https://stackoverflow.com/a/32428427/1713196 29 | std::exception_ptr lastException = nullptr; 30 | std::mutex lastExceptMutex; 31 | 32 | for (size_t threadId = 0; threadId < numThreads; ++threadId) 33 | { 34 | threads.push_back(std::thread([&, threadId] 35 | { 36 | while (true) { 37 | size_t id = current.fetch_add(1); 38 | 39 | if (id >= end) { 40 | break; 41 | } 42 | 43 | try { 44 | fn(id, threadId); 45 | } catch (...) { 46 | std::unique_lock lastExcepLock(lastExceptMutex); 47 | lastException = std::current_exception(); 48 | /* 49 | * This will work even when current is the largest value that 50 | * size_t can fit, because fetch_add returns the previous value 51 | * before the increment (what will result in overflow 52 | * and produce 0 instead of current + 1). 
53 | */ 54 | current = end; 55 | break; 56 | } 57 | } })); 58 | } 59 | for (auto &thread : threads) 60 | { 61 | thread.join(); 62 | } 63 | if (lastException) 64 | { 65 | std::rethrow_exception(lastException); 66 | } 67 | } 68 | } 69 | 70 | // Filter that allows labels divisible by divisor 71 | class PickDivisibleIds : public hnswlib::BaseFilterFunctor 72 | { 73 | unsigned int divisor = 1; 74 | 75 | public: 76 | PickDivisibleIds(unsigned int divisor) : divisor(divisor) 77 | { 78 | assert(divisor != 0); 79 | } 80 | bool operator()(hnswlib::labeltype label_id) 81 | { 82 | return label_id % divisor == 0; 83 | } 84 | }; 85 | 86 | int main() 87 | { 88 | int dim = 16; // Dimension of the elements 89 | int max_elements = 10000; // Maximum number of elements, should be known beforehand 90 | int M = 16; // Tightly connected with internal dimensionality of the data 91 | // strongly affects the memory consumption 92 | int ef_construction = 200; // Controls index search speed/build speed tradeoff 93 | int num_threads = 20; // Number of threads for operations with index 94 | 95 | // Initing index 96 | hnswlib::L2Space space(dim); 97 | hnswlib::HierarchicalNSW *alg_hnsw = new hnswlib::HierarchicalNSW(&space, max_elements, M, ef_construction); 98 | 99 | // Generate random data 100 | std::mt19937 rng; 101 | rng.seed(47); 102 | std::uniform_real_distribution<> distrib_real; 103 | float *data = new float[dim * max_elements]; 104 | for (int i = 0; i < dim * max_elements; i++) 105 | { 106 | data[i] = distrib_real(rng); 107 | } 108 | 109 | // Add data to index 110 | ParallelFor(0, max_elements, num_threads, [&](size_t row, size_t threadId) 111 | { alg_hnsw->addPoint((void *)(data + dim * row), row); }); 112 | 113 | // Create filter that allows only even labels 114 | PickDivisibleIds pickIdsDivisibleByTwo(2); 115 | 116 | // Query the elements for themselves with filter and check returned labels 117 | int k = 10; 118 | std::vector neighbors(max_elements * k); 119 | ParallelFor(0, 
max_elements, num_threads, [&](size_t row, size_t threadId) 120 | { 121 | std::priority_queue> result = alg_hnsw->searchKnn(data + dim * row, k, &pickIdsDivisibleByTwo); 122 | for (int i = 0; i < k; i++) { 123 | hnswlib::labeltype label = result.top().second; 124 | result.pop(); 125 | neighbors[row * k + i] = label; 126 | } }); 127 | 128 | for (hnswlib::labeltype label : neighbors) 129 | { 130 | if (label % 2 == 1) 131 | std::cout << "Error: found odd label\n"; 132 | } 133 | 134 | delete[] data; 135 | delete alg_hnsw; 136 | return 0; 137 | } 138 | -------------------------------------------------------------------------------- /examples/cpp/example_mt_replace_deleted.cpp: -------------------------------------------------------------------------------- 1 | #include "../../hnswlib/hnswlib.h" 2 | #include 3 | 4 | // Multithreaded executor 5 | // The helper function copied from python_bindings/bindings.cpp (and that itself is copied from nmslib) 6 | // An alternative is using #pragme omp parallel for or any other C++ threading 7 | template 8 | inline void ParallelFor(size_t start, size_t end, size_t numThreads, Function fn) 9 | { 10 | if (numThreads <= 0) 11 | { 12 | numThreads = std::thread::hardware_concurrency(); 13 | } 14 | 15 | if (numThreads == 1) 16 | { 17 | for (size_t id = start; id < end; id++) 18 | { 19 | fn(id, 0); 20 | } 21 | } 22 | else 23 | { 24 | std::vector threads; 25 | std::atomic current(start); 26 | 27 | // keep track of exceptions in threads 28 | // https://stackoverflow.com/a/32428427/1713196 29 | std::exception_ptr lastException = nullptr; 30 | std::mutex lastExceptMutex; 31 | 32 | for (size_t threadId = 0; threadId < numThreads; ++threadId) 33 | { 34 | threads.push_back(std::thread([&, threadId] 35 | { 36 | while (true) { 37 | size_t id = current.fetch_add(1); 38 | 39 | if (id >= end) { 40 | break; 41 | } 42 | 43 | try { 44 | fn(id, threadId); 45 | } catch (...) 
{ 46 | std::unique_lock lastExcepLock(lastExceptMutex); 47 | lastException = std::current_exception(); 48 | /* 49 | * This will work even when current is the largest value that 50 | * size_t can fit, because fetch_add returns the previous value 51 | * before the increment (what will result in overflow 52 | * and produce 0 instead of current + 1). 53 | */ 54 | current = end; 55 | break; 56 | } 57 | } })); 58 | } 59 | for (auto &thread : threads) 60 | { 61 | thread.join(); 62 | } 63 | if (lastException) 64 | { 65 | std::rethrow_exception(lastException); 66 | } 67 | } 68 | } 69 | 70 | int main() 71 | { 72 | int dim = 16; // Dimension of the elements 73 | int max_elements = 10000; // Maximum number of elements, should be known beforehand 74 | int M = 16; // Tightly connected with internal dimensionality of the data 75 | // strongly affects the memory consumption 76 | int ef_construction = 200; // Controls index search speed/build speed tradeoff 77 | int num_threads = 20; // Number of threads for operations with index 78 | 79 | // Initing index with allow_replace_deleted=true 80 | int seed = 100; 81 | hnswlib::L2Space space(dim); 82 | hnswlib::HierarchicalNSW *alg_hnsw = new hnswlib::HierarchicalNSW(&space, max_elements, M, ef_construction, seed, true); 83 | 84 | // Generate random data 85 | std::mt19937 rng; 86 | rng.seed(47); 87 | std::uniform_real_distribution<> distrib_real; 88 | float *data = new float[dim * max_elements]; 89 | for (int i = 0; i < dim * max_elements; i++) 90 | { 91 | data[i] = distrib_real(rng); 92 | } 93 | 94 | // Add data to index 95 | ParallelFor(0, max_elements, num_threads, [&](size_t row, size_t threadId) 96 | { alg_hnsw->addPoint((void *)(data + dim * row), row); }); 97 | 98 | // Mark first half of elements as deleted 99 | int num_deleted = max_elements / 2; 100 | ParallelFor(0, num_deleted, num_threads, [&](size_t row, size_t threadId) 101 | { alg_hnsw->markDelete(row); }); 102 | 103 | // Generate additional random data 104 | float 
*add_data = new float[dim * num_deleted]; 105 | for (int i = 0; i < dim * num_deleted; i++) 106 | { 107 | add_data[i] = distrib_real(rng); 108 | } 109 | 110 | // Replace deleted data with new elements 111 | // Maximum number of elements is reached therefore we cannot add new items, 112 | // but we can replace the deleted ones by using replace_deleted=true 113 | ParallelFor(0, num_deleted, num_threads, [&](size_t row, size_t threadId) 114 | { 115 | hnswlib::labeltype label = max_elements + row; 116 | alg_hnsw->addPoint((void*)(add_data + dim * row), label, true); }); 117 | 118 | delete[] data; 119 | delete[] add_data; 120 | delete alg_hnsw; 121 | return 0; 122 | } 123 | -------------------------------------------------------------------------------- /examples/cpp/example_mt_search.cpp: -------------------------------------------------------------------------------- 1 | #include "../../hnswlib/hnswlib.h" 2 | #include 3 | 4 | // Multithreaded executor 5 | // The helper function copied from python_bindings/bindings.cpp (and that itself is copied from nmslib) 6 | // An alternative is using #pragme omp parallel for or any other C++ threading 7 | template 8 | inline void ParallelFor(size_t start, size_t end, size_t numThreads, Function fn) 9 | { 10 | if (numThreads <= 0) 11 | { 12 | numThreads = std::thread::hardware_concurrency(); 13 | } 14 | 15 | if (numThreads == 1) 16 | { 17 | for (size_t id = start; id < end; id++) 18 | { 19 | fn(id, 0); 20 | } 21 | } 22 | else 23 | { 24 | std::vector threads; 25 | std::atomic current(start); 26 | 27 | // keep track of exceptions in threads 28 | // https://stackoverflow.com/a/32428427/1713196 29 | std::exception_ptr lastException = nullptr; 30 | std::mutex lastExceptMutex; 31 | 32 | for (size_t threadId = 0; threadId < numThreads; ++threadId) 33 | { 34 | threads.push_back(std::thread([&, threadId] 35 | { 36 | while (true) { 37 | size_t id = current.fetch_add(1); 38 | 39 | if (id >= end) { 40 | break; 41 | } 42 | 43 | try { 44 | 
fn(id, threadId); 45 | } catch (...) { 46 | std::unique_lock lastExcepLock(lastExceptMutex); 47 | lastException = std::current_exception(); 48 | /* 49 | * This will work even when current is the largest value that 50 | * size_t can fit, because fetch_add returns the previous value 51 | * before the increment (what will result in overflow 52 | * and produce 0 instead of current + 1). 53 | */ 54 | current = end; 55 | break; 56 | } 57 | } })); 58 | } 59 | for (auto &thread : threads) 60 | { 61 | thread.join(); 62 | } 63 | if (lastException) 64 | { 65 | std::rethrow_exception(lastException); 66 | } 67 | } 68 | } 69 | 70 | int main() 71 | { 72 | int dim = 16; // Dimension of the elements 73 | int max_elements = 10000; // Maximum number of elements, should be known beforehand 74 | int M = 16; // Tightly connected with internal dimensionality of the data 75 | // strongly affects the memory consumption 76 | int ef_construction = 200; // Controls index search speed/build speed tradeoff 77 | int num_threads = 20; // Number of threads for operations with index 78 | 79 | // Initing index 80 | hnswlib::L2Space space(dim); 81 | hnswlib::HierarchicalNSW *alg_hnsw = new hnswlib::HierarchicalNSW(&space, max_elements, M, ef_construction); 82 | 83 | // Generate random data 84 | std::mt19937 rng; 85 | rng.seed(47); 86 | std::uniform_real_distribution<> distrib_real; 87 | float *data = new float[dim * max_elements]; 88 | for (int i = 0; i < dim * max_elements; i++) 89 | { 90 | data[i] = distrib_real(rng); 91 | } 92 | 93 | // Add data to index 94 | ParallelFor(0, max_elements, num_threads, [&](size_t row, size_t threadId) 95 | { alg_hnsw->addPoint((void *)(data + dim * row), row); }); 96 | 97 | // Query the elements for themselves and measure recall 98 | std::vector neighbors(max_elements); 99 | ParallelFor(0, max_elements, num_threads, [&](size_t row, size_t threadId) 100 | { 101 | std::priority_queue> result = alg_hnsw->searchKnn(data + dim * row, 1); 102 | hnswlib::labeltype label = 
result.top().second; 103 | neighbors[row] = label; }); 104 | float correct = 0; 105 | for (int i = 0; i < max_elements; i++) 106 | { 107 | hnswlib::labeltype label = neighbors[i]; 108 | if (label == i) 109 | correct++; 110 | } 111 | float recall = correct / max_elements; 112 | std::cout << "Recall: " << recall << "\n"; 113 | 114 | delete[] data; 115 | delete alg_hnsw; 116 | return 0; 117 | } 118 | -------------------------------------------------------------------------------- /examples/cpp/example_replace_deleted.cpp: -------------------------------------------------------------------------------- 1 | #include "../../hnswlib/hnswlib.h" 2 | 3 | int main() 4 | { 5 | int dim = 16; // Dimension of the elements 6 | int max_elements = 10000; // Maximum number of elements, should be known beforehand 7 | int M = 16; // Tightly connected with internal dimensionality of the data 8 | // strongly affects the memory consumption 9 | int ef_construction = 200; // Controls index search speed/build speed tradeoff 10 | 11 | // Initing index with allow_replace_deleted=true 12 | int seed = 100; 13 | hnswlib::L2Space space(dim); 14 | hnswlib::HierarchicalNSW *alg_hnsw = new hnswlib::HierarchicalNSW(&space, max_elements, M, ef_construction, seed, true); 15 | 16 | // Generate random data 17 | std::mt19937 rng; 18 | rng.seed(47); 19 | std::uniform_real_distribution<> distrib_real; 20 | float *data = new float[dim * max_elements]; 21 | for (int i = 0; i < dim * max_elements; i++) 22 | { 23 | data[i] = distrib_real(rng); 24 | } 25 | 26 | // Add data to index 27 | for (int i = 0; i < max_elements; i++) 28 | { 29 | alg_hnsw->addPoint(data + i * dim, i); 30 | } 31 | 32 | // Mark first half of elements as deleted 33 | int num_deleted = max_elements / 2; 34 | for (int i = 0; i < num_deleted; i++) 35 | { 36 | alg_hnsw->markDelete(i); 37 | } 38 | 39 | // Generate additional random data 40 | float *add_data = new float[dim * num_deleted]; 41 | for (int i = 0; i < dim * num_deleted; i++) 42 | { 43 
| add_data[i] = distrib_real(rng); 44 | } 45 | 46 | // Replace deleted data with new elements 47 | // Maximum number of elements is reached therefore we cannot add new items, 48 | // but we can replace the deleted ones by using replace_deleted=true 49 | for (int i = 0; i < num_deleted; i++) 50 | { 51 | hnswlib::labeltype label = max_elements + i; 52 | alg_hnsw->addPoint(add_data + i * dim, label, true); 53 | } 54 | 55 | delete[] data; 56 | delete[] add_data; 57 | delete alg_hnsw; 58 | return 0; 59 | } 60 | -------------------------------------------------------------------------------- /examples/cpp/example_search.cpp: -------------------------------------------------------------------------------- 1 | #include "../../hnswlib/hnswlib.h" 2 | 3 | int main() 4 | { 5 | int dim = 16; // Dimension of the elements 6 | int max_elements = 10000; // Maximum number of elements, should be known beforehand 7 | int M = 16; // Tightly connected with internal dimensionality of the data 8 | // strongly affects the memory consumption 9 | int ef_construction = 200; // Controls index search speed/build speed tradeoff 10 | 11 | // Initing index 12 | hnswlib::L2Space space(dim); 13 | hnswlib::HierarchicalNSW *alg_hnsw = new hnswlib::HierarchicalNSW(&space, max_elements, M, ef_construction); 14 | 15 | // Generate random data 16 | std::mt19937 rng; 17 | rng.seed(47); 18 | std::uniform_real_distribution<> distrib_real; 19 | float *data = new float[dim * max_elements]; 20 | for (int i = 0; i < dim * max_elements; i++) 21 | { 22 | data[i] = distrib_real(rng); 23 | } 24 | 25 | // Add data to index 26 | for (int i = 0; i < max_elements; i++) 27 | { 28 | alg_hnsw->addPoint(data + i * dim, i); 29 | } 30 | 31 | // Query the elements for themselves and measure recall 32 | float correct = 0; 33 | for (int i = 0; i < max_elements; i++) 34 | { 35 | std::priority_queue> result = alg_hnsw->searchKnn(data + i * dim, 1); 36 | hnswlib::labeltype label = result.top().second; 37 | if (label == i) 38 | 
correct++; 39 | } 40 | float recall = correct / max_elements; 41 | std::cout << "Recall: " << recall << "\n"; 42 | 43 | // Serialize index 44 | std::string hnsw_path = "hnsw.bin"; 45 | alg_hnsw->saveIndex(hnsw_path); 46 | delete alg_hnsw; 47 | 48 | // Deserialize index and check recall 49 | alg_hnsw = new hnswlib::HierarchicalNSW(&space, hnsw_path); 50 | correct = 0; 51 | for (int i = 0; i < max_elements; i++) 52 | { 53 | std::priority_queue> result = alg_hnsw->searchKnn(data + i * dim, 1); 54 | hnswlib::labeltype label = result.top().second; 55 | if (label == i) 56 | correct++; 57 | } 58 | recall = (float)correct / max_elements; 59 | std::cout << "Recall of deserialized index: " << recall << "\n"; 60 | 61 | delete[] data; 62 | delete alg_hnsw; 63 | return 0; 64 | } 65 | -------------------------------------------------------------------------------- /examples/python/EXAMPLES.md: -------------------------------------------------------------------------------- 1 | # Python bindings examples 2 | 3 | Creating index, inserting elements, searching and pickle serialization: 4 | ```python 5 | import hnswlib 6 | import numpy as np 7 | import pickle 8 | 9 | dim = 128 10 | num_elements = 10000 11 | 12 | # Generating sample data 13 | data = np.float32(np.random.random((num_elements, dim))) 14 | ids = np.arange(num_elements) 15 | 16 | # Declaring index 17 | p = hnswlib.Index(space = 'l2', dim = dim) # possible options are l2, cosine or ip 18 | 19 | # Initializing index - the maximum number of elements should be known beforehand 20 | p.init_index(max_elements = num_elements, ef_construction = 200, M = 16) 21 | 22 | # Element insertion (can be called several times): 23 | p.add_items(data, ids) 24 | 25 | # Controlling the recall by setting ef: 26 | p.set_ef(50) # ef should always be > k 27 | 28 | # Query dataset, k - number of the closest elements (returns 2 numpy arrays) 29 | labels, distances = p.knn_query(data, k = 1) 30 | 31 | # Index objects support pickling 32 | # WARNING: 
serialization via pickle.dumps(p) or p.__getstate__() is NOT thread-safe with p.add_items method! 33 | # Note: ef parameter is included in serialization; random number generator is initialized with random_seed on Index load 34 | p_copy = pickle.loads(pickle.dumps(p)) # creates a copy of index p using pickle round-trip 35 | 36 | ### Index parameters are exposed as class properties: 37 | print(f"Parameters passed to constructor: space={p_copy.space}, dim={p_copy.dim}") 38 | print(f"Index construction: M={p_copy.M}, ef_construction={p_copy.ef_construction}") 39 | print(f"Index size is {p_copy.element_count} and index capacity is {p_copy.max_elements}") 40 | print(f"Search speed/quality trade-off parameter: ef={p_copy.ef}") 41 | ``` 42 | 43 | An example with updates after serialization/deserialization: 44 | ```python 45 | import hnswlib 46 | import numpy as np 47 | 48 | dim = 16 49 | num_elements = 10000 50 | 51 | # Generating sample data 52 | data = np.float32(np.random.random((num_elements, dim))) 53 | 54 | # We split the data in two batches: 55 | data1 = data[:num_elements // 2] 56 | data2 = data[num_elements // 2:] 57 | 58 | # Declaring index 59 | p = hnswlib.Index(space='l2', dim=dim) # possible options are l2, cosine or ip 60 | 61 | # Initializing index 62 | # max_elements - the maximum number of elements (capacity). Will throw an exception if exceeded 63 | # during insertion of an element. 64 | # The capacity can be increased by saving/loading the index, see below. 65 | # 66 | # ef_construction - controls index search speed/build speed tradeoff 67 | # 68 | # M - is tightly connected with internal dimensionality of the data. 
Strongly affects memory consumption (~M) 69 | # Higher M leads to higher accuracy/run_time at fixed ef/efConstruction 70 | 71 | p.init_index(max_elements=num_elements//2, ef_construction=100, M=16) 72 | 73 | # Controlling the recall by setting ef: 74 | # higher ef leads to better accuracy, but slower search 75 | p.set_ef(10) 76 | 77 | # Set number of threads used during batch search/construction 78 | # By default using all available cores 79 | p.set_num_threads(4) 80 | 81 | print("Adding first batch of %d elements" % (len(data1))) 82 | p.add_items(data1) 83 | 84 | # Query the elements for themselves and measure recall: 85 | labels, distances = p.knn_query(data1, k=1) 86 | print("Recall for the first batch:", np.mean(labels.reshape(-1) == np.arange(len(data1))), "\n") 87 | 88 | # Serializing and deleting the index: 89 | index_path='first_half.bin' 90 | print("Saving index to '%s'" % index_path) 91 | p.save_index("first_half.bin") 92 | del p 93 | 94 | # Re-initializing, loading the index 95 | p = hnswlib.Index(space='l2', dim=dim) # the space can be changed - keeps the data, alters the distance function. 
96 | 97 | print("\nLoading index from 'first_half.bin'\n") 98 | 99 | # Increase the total capacity (max_elements), so that it will handle the new data 100 | p.load_index("first_half.bin", max_elements = num_elements) 101 | 102 | print("Adding the second batch of %d elements" % (len(data2))) 103 | p.add_items(data2) 104 | 105 | # Query the elements for themselves and measure recall: 106 | labels, distances = p.knn_query(data, k=1) 107 | print("Recall for two batches:", np.mean(labels.reshape(-1) == np.arange(len(data))), "\n") 108 | ``` 109 | 110 | An example with a symbolic filter `filter_function` during the search: 111 | ```python 112 | import hnswlib 113 | import numpy as np 114 | 115 | dim = 16 116 | num_elements = 10000 117 | 118 | # Generating sample data 119 | data = np.float32(np.random.random((num_elements, dim))) 120 | 121 | # Declaring index 122 | hnsw_index = hnswlib.Index(space='l2', dim=dim) # possible options are l2, cosine or ip 123 | 124 | # Initiating index 125 | # max_elements - the maximum number of elements, should be known beforehand 126 | # (probably will be made optional in the future) 127 | # 128 | # ef_construction - controls index search speed/build speed tradeoff 129 | # M - is tightly connected with internal dimensionality of the data 130 | # strongly affects the memory consumption 131 | 132 | hnsw_index.init_index(max_elements=num_elements, ef_construction=100, M=16) 133 | 134 | # Controlling the recall by setting ef: 135 | # higher ef leads to better accuracy, but slower search 136 | hnsw_index.set_ef(10) 137 | 138 | # Set number of threads used during batch search/construction 139 | # By default using all available cores 140 | hnsw_index.set_num_threads(4) 141 | 142 | print("Adding %d elements" % (len(data))) 143 | # Added elements will have consecutive ids 144 | hnsw_index.add_items(data, ids=np.arange(num_elements)) 145 | 146 | print("Querying only even elements") 147 | # Define filter function that allows only even ids 148 | 
filter_function = lambda idx: idx%2 == 0 149 | # Query the elements for themselves and search only for even elements: 150 | # Warning: search with python filter works slow in multithreaded mode, therefore we set num_threads=1 151 | labels, distances = hnsw_index.knn_query(data, k=1, num_threads=1, filter=filter_function) 152 | # labels contain only elements with even id 153 | ``` 154 | 155 | An example with reusing the memory of the deleted elements when new elements are being added (via `allow_replace_deleted` flag): 156 | ```python 157 | import hnswlib 158 | import numpy as np 159 | 160 | dim = 16 161 | num_elements = 1_000 162 | max_num_elements = 2 * num_elements 163 | 164 | # Generating sample data 165 | labels1 = np.arange(0, num_elements) 166 | data1 = np.float32(np.random.random((num_elements, dim))) # batch 1 167 | labels2 = np.arange(num_elements, 2 * num_elements) 168 | data2 = np.float32(np.random.random((num_elements, dim))) # batch 2 169 | labels3 = np.arange(2 * num_elements, 3 * num_elements) 170 | data3 = np.float32(np.random.random((num_elements, dim))) # batch 3 171 | 172 | # Declaring index 173 | hnsw_index = hnswlib.Index(space='l2', dim=dim) 174 | 175 | # Initiating index 176 | # max_elements - the maximum number of elements, should be known beforehand 177 | # (probably will be made optional in the future) 178 | # 179 | # ef_construction - controls index search speed/build speed tradeoff 180 | # M - is tightly connected with internal dimensionality of the data 181 | # strongly affects the memory consumption 182 | 183 | # Enable replacing of deleted elements 184 | hnsw_index.init_index(max_elements=max_num_elements, ef_construction=200, M=16, allow_replace_deleted=True) 185 | 186 | # Controlling the recall by setting ef: 187 | # higher ef leads to better accuracy, but slower search 188 | hnsw_index.set_ef(10) 189 | 190 | # Set number of threads used during batch search/construction 191 | # By default using all available cores 192 | 
hnsw_index.set_num_threads(4) 193 | 194 | # Add batch 1 and 2 data 195 | hnsw_index.add_items(data1, labels1) 196 | hnsw_index.add_items(data2, labels2) # Note: maximum number of elements is reached 197 | 198 | # Delete data of batch 2 199 | for label in labels2: 200 | hnsw_index.mark_deleted(label) 201 | 202 | # Replace deleted elements 203 | # Maximum number of elements is reached therefore we cannot add new items, 204 | # but we can replace the deleted ones by using replace_deleted=True 205 | hnsw_index.add_items(data3, labels3, replace_deleted=True) 206 | # hnsw_index contains the data of batch 1 and batch 3 only 207 | ``` -------------------------------------------------------------------------------- /examples/python/example.py: -------------------------------------------------------------------------------- 1 | import os 2 | import hnswlib 3 | import numpy as np 4 | 5 | 6 | """ 7 | Example of index building, search and serialization/deserialization 8 | """ 9 | 10 | dim = 16 11 | num_elements = 10000 12 | 13 | # Generating sample data 14 | data = np.float32(np.random.random((num_elements, dim))) 15 | 16 | # We split the data in two batches: 17 | data1 = data[: num_elements // 2] 18 | data2 = data[num_elements // 2 :] 19 | 20 | # Declaring index 21 | p = hnswlib.Index(space="l2", dim=dim) # possible options are l2, cosine or ip 22 | 23 | # Initing index 24 | # max_elements - the maximum number of elements (capacity). Will throw an exception if exceeded 25 | # during insertion of an element. 26 | # The capacity can be increased by saving/loading the index, see below. 27 | # 28 | # ef_construction - controls index search speed/build speed tradeoff 29 | # 30 | # M - is tightly connected with internal dimensionality of the data. 
Strongly affects the memory consumption (~M) 31 | # Higher M leads to higher accuracy/run_time at fixed ef/efConstruction 32 | 33 | p.init_index(max_elements=num_elements // 2, ef_construction=100, M=16) 34 | 35 | # Controlling the recall by setting ef: 36 | # higher ef leads to better accuracy, but slower search 37 | p.set_ef(10) 38 | 39 | # Set number of threads used during batch search/construction 40 | # By default using all available cores 41 | p.set_num_threads(4) 42 | 43 | print("Adding first batch of %d elements" % (len(data1))) 44 | p.add_items(data1) 45 | 46 | # Query the elements for themselves and measure recall: 47 | labels, distances = p.knn_query(data1, k=1) 48 | print( 49 | "Recall for the first batch:", 50 | np.mean(labels.reshape(-1) == np.arange(len(data1))), 51 | "\n", 52 | ) 53 | 54 | # Serializing and deleting the index: 55 | index_path = "first_half.bin" 56 | print("Saving index to '%s'" % index_path) 57 | p.save_index(index_path) 58 | del p 59 | 60 | # Reiniting, loading the index 61 | p = hnswlib.Index( 62 | space="l2", dim=dim 63 | ) # the space can be changed - keeps the data, alters the distance function. 
64 | 65 | print("\nLoading index from 'first_half.bin'\n") 66 | 67 | # Increase the total capacity (max_elements), so that it will handle the new data 68 | p.load_index("first_half.bin", max_elements=num_elements) 69 | 70 | print("Adding the second batch of %d elements" % (len(data2))) 71 | p.add_items(data2) 72 | 73 | # Query the elements for themselves and measure recall: 74 | labels, distances = p.knn_query(data, k=1) 75 | print( 76 | "Recall for two batches:", np.mean(labels.reshape(-1) == np.arange(len(data))), "\n" 77 | ) 78 | 79 | os.remove("first_half.bin") 80 | -------------------------------------------------------------------------------- /examples/python/example_filter.py: -------------------------------------------------------------------------------- 1 | import hnswlib 2 | import numpy as np 3 | 4 | 5 | """ 6 | Example of filtering elements when searching 7 | """ 8 | 9 | dim = 16 10 | num_elements = 10000 11 | 12 | # Generating sample data 13 | data = np.float32(np.random.random((num_elements, dim))) 14 | 15 | # Declaring index 16 | hnsw_index = hnswlib.Index(space="l2", dim=dim) # possible options are l2, cosine or ip 17 | 18 | # Initiating index 19 | # max_elements - the maximum number of elements, should be known beforehand 20 | # (probably will be made optional in the future) 21 | # 22 | # ef_construction - controls index search speed/build speed tradeoff 23 | # M - is tightly connected with internal dimensionality of the data 24 | # strongly affects the memory consumption 25 | 26 | hnsw_index.init_index(max_elements=num_elements, ef_construction=100, M=16) 27 | 28 | # Controlling the recall by setting ef: 29 | # higher ef leads to better accuracy, but slower search 30 | hnsw_index.set_ef(10) 31 | 32 | # Set number of threads used during batch search/construction 33 | # By default using all available cores 34 | hnsw_index.set_num_threads(4) 35 | 36 | print("Adding %d elements" % (len(data))) 37 | # Added elements will have consecutive ids 38 | 
hnsw_index.add_items(data, ids=np.arange(num_elements)) 39 | 40 | print("Querying only even elements") 41 | # Define filter function that allows only even ids 42 | filter_function = lambda idx: idx % 2 == 0 43 | # Query the elements for themselves and search only for even elements: 44 | # Warning: search with a filter works slow in python in multithreaded mode, therefore we set num_threads=1 45 | labels, distances = hnsw_index.knn_query( 46 | data, k=1, num_threads=1, filter=filter_function 47 | ) 48 | # labels contain only elements with even id 49 | -------------------------------------------------------------------------------- /examples/python/example_replace_deleted.py: -------------------------------------------------------------------------------- 1 | import hnswlib 2 | import numpy as np 3 | 4 | 5 | """ 6 | Example of replacing deleted elements with new ones 7 | """ 8 | 9 | dim = 16 10 | num_elements = 1_000 11 | max_num_elements = 2 * num_elements 12 | 13 | # Generating sample data 14 | labels1 = np.arange(0, num_elements) 15 | data1 = np.float32(np.random.random((num_elements, dim))) # batch 1 16 | labels2 = np.arange(num_elements, 2 * num_elements) 17 | data2 = np.float32(np.random.random((num_elements, dim))) # batch 2 18 | labels3 = np.arange(2 * num_elements, 3 * num_elements) 19 | data3 = np.float32(np.random.random((num_elements, dim))) # batch 3 20 | 21 | # Declaring index 22 | hnsw_index = hnswlib.Index(space="l2", dim=dim) 23 | 24 | # Initiating index 25 | # max_elements - the maximum number of elements, should be known beforehand 26 | # (probably will be made optional in the future) 27 | # 28 | # ef_construction - controls index search speed/build speed tradeoff 29 | # M - is tightly connected with internal dimensionality of the data 30 | # strongly affects the memory consumption 31 | 32 | # Enable replacing of deleted elements 33 | hnsw_index.init_index( 34 | max_elements=max_num_elements, ef_construction=200, M=16, allow_replace_deleted=True 35 
| ) 36 | 37 | # Controlling the recall by setting ef: 38 | # higher ef leads to better accuracy, but slower search 39 | hnsw_index.set_ef(10) 40 | 41 | # Set number of threads used during batch search/construction 42 | # By default using all available cores 43 | hnsw_index.set_num_threads(4) 44 | 45 | # Add batch 1 and 2 data 46 | hnsw_index.add_items(data1, labels1) 47 | hnsw_index.add_items(data2, labels2) # Note: maximum number of elements is reached 48 | 49 | # Delete data of batch 2 50 | for label in labels2: 51 | hnsw_index.mark_deleted(label) 52 | 53 | # Replace deleted elements 54 | # Maximum number of elements is reached therefore we cannot add new items, 55 | # but we can replace the deleted ones by using replace_deleted=True 56 | hnsw_index.add_items(data3, labels3, replace_deleted=True) 57 | # hnsw_index contains the data of batch 1 and batch 3 only 58 | -------------------------------------------------------------------------------- /examples/python/example_search.py: -------------------------------------------------------------------------------- 1 | import hnswlib 2 | import numpy as np 3 | import pickle 4 | 5 | 6 | """ 7 | Example of search 8 | """ 9 | 10 | dim = 128 11 | num_elements = 10000 12 | 13 | # Generating sample data 14 | data = np.float32(np.random.random((num_elements, dim))) 15 | ids = np.arange(num_elements) 16 | 17 | # Declaring index 18 | p = hnswlib.Index(space="l2", dim=dim) # possible options are l2, cosine or ip 19 | 20 | # Initializing index - the maximum number of elements should be known beforehand 21 | p.init_index(max_elements=num_elements, ef_construction=200, M=16) 22 | 23 | # Element insertion (can be called several times): 24 | p.add_items(data, ids) 25 | 26 | # Controlling the recall by setting ef: 27 | p.set_ef(50) # ef should always be > k 28 | 29 | # Query dataset, k - number of the closest elements (returns 2 numpy arrays) 30 | labels, distances = p.knn_query(data, k=1) 31 | 32 | # Index objects support pickling 33 
| # WARNING: serialization via pickle.dumps(p) or p.__getstate__() is NOT thread-safe with p.add_items method! 34 | # Note: ef parameter is included in serialization; random number generator is initialized with random_seed on Index load 35 | p_copy = pickle.loads( 36 | pickle.dumps(p) 37 | ) # creates a copy of index p using pickle round-trip 38 | 39 | ### Index parameters are exposed as class properties: 40 | print(f"Parameters passed to constructor: space={p_copy.space}, dim={p_copy.dim}") 41 | print(f"Index construction: M={p_copy.M}, ef_construction={p_copy.ef_construction}") 42 | print( 43 | f"Index size is {p_copy.element_count} and index capacity is {p_copy.max_elements}" 44 | ) 45 | print(f"Search speed/quality trade-off parameter: ef={p_copy.ef}") 46 | -------------------------------------------------------------------------------- /examples/python/example_serialization.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import hnswlib 4 | import numpy as np 5 | 6 | 7 | """ 8 | Example of serialization/deserialization 9 | """ 10 | 11 | dim = 16 12 | num_elements = 10000 13 | 14 | # Generating sample data 15 | data = np.float32(np.random.random((num_elements, dim))) 16 | 17 | # We split the data in two batches: 18 | data1 = data[: num_elements // 2] 19 | data2 = data[num_elements // 2 :] 20 | 21 | # Declaring index 22 | p = hnswlib.Index(space="l2", dim=dim) # possible options are l2, cosine or ip 23 | 24 | # Initializing index 25 | # max_elements - the maximum number of elements (capacity). Will throw an exception if exceeded 26 | # during insertion of an element. 27 | # The capacity can be increased by saving/loading the index, see below. 28 | # 29 | # ef_construction - controls index search speed/build speed tradeoff 30 | # 31 | # M - is tightly connected with internal dimensionality of the data. 
Strongly affects memory consumption (~M) 32 | # Higher M leads to higher accuracy/run_time at fixed ef/efConstruction 33 | 34 | p.init_index(max_elements=num_elements // 2, ef_construction=100, M=16) 35 | 36 | # Controlling the recall by setting ef: 37 | # higher ef leads to better accuracy, but slower search 38 | p.set_ef(10) 39 | 40 | # Set number of threads used during batch search/construction 41 | # By default using all available cores 42 | p.set_num_threads(4) 43 | 44 | print("Adding first batch of %d elements" % (len(data1))) 45 | p.add_items(data1) 46 | 47 | # Query the elements for themselves and measure recall: 48 | labels, distances = p.knn_query(data1, k=1) 49 | print( 50 | "Recall for the first batch:", 51 | np.mean(labels.reshape(-1) == np.arange(len(data1))), 52 | "\n", 53 | ) 54 | 55 | # Serializing and deleting the index: 56 | index_path = "first_half.bin" 57 | print("Saving index to '%s'" % index_path) 58 | p.save_index("first_half.bin") 59 | del p 60 | 61 | # Re-initializing, loading the index 62 | p = hnswlib.Index( 63 | space="l2", dim=dim 64 | ) # the space can be changed - keeps the data, alters the distance function. 
65 | 66 | print("\nLoading index from 'first_half.bin'\n") 67 | 68 | # Increase the total capacity (max_elements), so that it will handle the new data 69 | p.load_index("first_half.bin", max_elements=num_elements) 70 | 71 | print("Adding the second batch of %d elements" % (len(data2))) 72 | p.add_items(data2) 73 | 74 | # Query the elements for themselves and measure recall: 75 | labels, distances = p.knn_query(data, k=1) 76 | print( 77 | "Recall for two batches:", np.mean(labels.reshape(-1) == np.arange(len(data))), "\n" 78 | ) 79 | 80 | os.remove("first_half.bin") 81 | -------------------------------------------------------------------------------- /examples/python/pyw_hnswlib.py: -------------------------------------------------------------------------------- 1 | import hnswlib 2 | import numpy as np 3 | import threading 4 | import pickle 5 | 6 | 7 | """ 8 | Example of python wrapper for hnswlib that supports python objects as ids 9 | """ 10 | 11 | 12 | class Index: 13 | def __init__(self, space, dim): 14 | self.index = hnswlib.Index(space, dim) 15 | self.lock = threading.Lock() 16 | self.dict_labels = {} 17 | self.cur_ind = 0 18 | 19 | def init_index(self, max_elements, ef_construction=200, M=16): 20 | self.index.init_index( 21 | max_elements=max_elements, ef_construction=ef_construction, M=M 22 | ) 23 | 24 | def add_items(self, data, ids=None): 25 | if ids is not None: 26 | assert len(data) == len(ids) 27 | num_added = len(data) 28 | with self.lock: 29 | start = self.cur_ind 30 | self.cur_ind += num_added 31 | int_labels = [] 32 | 33 | if ids is not None: 34 | for dl in ids: 35 | int_labels.append(start) 36 | self.dict_labels[start] = dl 37 | start += 1 38 | else: 39 | for _ in range(len(data)): 40 | int_labels.append(start) 41 | self.dict_labels[start] = start 42 | start += 1 43 | self.index.add_items(data=data, ids=np.asarray(int_labels)) 44 | 45 | def set_ef(self, ef): 46 | self.index.set_ef(ef) 47 | 48 | def load_index(self, path): 49 | 
self.index.load_index(path) 50 | with open(path + ".pkl", "rb") as f: 51 | self.cur_ind, self.dict_labels = pickle.load(f) 52 | 53 | def save_index(self, path): 54 | self.index.save_index(path) 55 | with open(path + ".pkl", "wb") as f: 56 | pickle.dump((self.cur_ind, self.dict_labels), f) 57 | 58 | def set_num_threads(self, num_threads): 59 | self.index.set_num_threads(num_threads) 60 | 61 | def knn_query(self, data, k=1): 62 | labels_int, distances = self.index.knn_query(data=data, k=k) 63 | labels = [] 64 | for li in labels_int: 65 | labels.append([self.dict_labels[l] for l in li]) 66 | return labels, distances 67 | -------------------------------------------------------------------------------- /hnswlib/bruteforce.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | namespace hnswlib 9 | { 10 | template 11 | class BruteforceSearch : public AlgorithmInterface 12 | { 13 | public: 14 | char *data_; 15 | size_t maxelements_; 16 | size_t cur_element_count; 17 | size_t size_per_element_; 18 | 19 | size_t data_size_; 20 | DISTFUNC fstdistfunc_; 21 | void *dist_func_param_; 22 | std::mutex index_lock; 23 | 24 | std::unordered_map dict_external_to_internal; 25 | 26 | BruteforceSearch(SpaceInterface *s) 27 | : data_(nullptr), 28 | maxelements_(0), 29 | cur_element_count(0), 30 | size_per_element_(0), 31 | data_size_(0), 32 | dist_func_param_(nullptr) 33 | { 34 | } 35 | 36 | BruteforceSearch(SpaceInterface *s, const std::string &location) 37 | : data_(nullptr), 38 | maxelements_(0), 39 | cur_element_count(0), 40 | size_per_element_(0), 41 | data_size_(0), 42 | dist_func_param_(nullptr) 43 | { 44 | loadIndex(location, s); 45 | } 46 | 47 | BruteforceSearch(SpaceInterface *s, size_t maxElements) 48 | { 49 | maxelements_ = maxElements; 50 | data_size_ = s->get_data_size(); 51 | fstdistfunc_ = s->get_dist_func(); 52 | dist_func_param_ = 
s->get_dist_func_param(); 53 | size_per_element_ = data_size_ + sizeof(labeltype); 54 | data_ = (char *)malloc(maxElements * size_per_element_); 55 | if (data_ == nullptr) 56 | throw std::runtime_error("Not enough memory: BruteforceSearch failed to allocate data"); 57 | cur_element_count = 0; 58 | } 59 | 60 | ~BruteforceSearch() 61 | { 62 | free(data_); 63 | } 64 | 65 | void addPoint(const void *datapoint, labeltype label, bool replace_deleted = false) 66 | { 67 | int idx; 68 | { 69 | std::unique_lock lock(index_lock); 70 | 71 | auto search = dict_external_to_internal.find(label); 72 | if (search != dict_external_to_internal.end()) 73 | { 74 | idx = search->second; 75 | } 76 | else 77 | { 78 | if (cur_element_count >= maxelements_) 79 | { 80 | throw std::runtime_error("The number of elements exceeds the specified limit\n"); 81 | } 82 | idx = cur_element_count; 83 | dict_external_to_internal[label] = idx; 84 | cur_element_count++; 85 | } 86 | } 87 | memcpy(data_ + size_per_element_ * idx + data_size_, &label, sizeof(labeltype)); 88 | memcpy(data_ + size_per_element_ * idx, datapoint, data_size_); 89 | } 90 | 91 | void removePoint(labeltype cur_external) 92 | { 93 | size_t cur_c = dict_external_to_internal[cur_external]; 94 | 95 | dict_external_to_internal.erase(cur_external); 96 | 97 | labeltype label = *((labeltype *)(data_ + size_per_element_ * (cur_element_count - 1) + data_size_)); 98 | dict_external_to_internal[label] = cur_c; 99 | memcpy(data_ + size_per_element_ * cur_c, 100 | data_ + size_per_element_ * (cur_element_count - 1), 101 | data_size_ + sizeof(labeltype)); 102 | cur_element_count--; 103 | } 104 | 105 | std::priority_queue> 106 | searchKnn(const void *query_data, size_t k, BaseFilterFunctor *isIdAllowed = nullptr) const 107 | { 108 | assert(k <= cur_element_count); 109 | std::priority_queue> topResults; 110 | if (cur_element_count == 0) 111 | return topResults; 112 | for (int i = 0; i < k; i++) 113 | { 114 | dist_t dist = fstdistfunc_(query_data, 
data_ + size_per_element_ * i, dist_func_param_); 115 | labeltype label = *((labeltype *)(data_ + size_per_element_ * i + data_size_)); 116 | if ((!isIdAllowed) || (*isIdAllowed)(label)) 117 | { 118 | topResults.push(std::pair(dist, label)); 119 | } 120 | } 121 | dist_t lastdist = topResults.empty() ? std::numeric_limits::max() : topResults.top().first; 122 | for (int i = k; i < cur_element_count; i++) 123 | { 124 | dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_); 125 | if (dist <= lastdist) 126 | { 127 | labeltype label = *((labeltype *)(data_ + size_per_element_ * i + data_size_)); 128 | if ((!isIdAllowed) || (*isIdAllowed)(label)) 129 | { 130 | topResults.push(std::pair(dist, label)); 131 | } 132 | if (topResults.size() > k) 133 | topResults.pop(); 134 | 135 | if (!topResults.empty()) 136 | { 137 | lastdist = topResults.top().first; 138 | } 139 | } 140 | } 141 | return topResults; 142 | } 143 | 144 | void saveIndex(const std::string &location) 145 | { 146 | std::ofstream output(location, std::ios::binary); 147 | std::streampos position; 148 | 149 | writeBinaryPOD(output, maxelements_); 150 | writeBinaryPOD(output, size_per_element_); 151 | writeBinaryPOD(output, cur_element_count); 152 | 153 | output.write(data_, maxelements_ * size_per_element_); 154 | 155 | output.close(); 156 | } 157 | 158 | void loadIndex(const std::string &location, SpaceInterface *s) 159 | { 160 | std::ifstream input(location, std::ios::binary); 161 | std::streampos position; 162 | 163 | readBinaryPOD(input, maxelements_); 164 | readBinaryPOD(input, size_per_element_); 165 | readBinaryPOD(input, cur_element_count); 166 | 167 | data_size_ = s->get_data_size(); 168 | fstdistfunc_ = s->get_dist_func(); 169 | dist_func_param_ = s->get_dist_func_param(); 170 | size_per_element_ = data_size_ + sizeof(labeltype); 171 | data_ = (char *)malloc(maxelements_ * size_per_element_); 172 | if (data_ == nullptr) 173 | throw std::runtime_error("Not enough memory: 
loadIndex failed to allocate data"); 174 | 175 | input.read(data_, maxelements_ * size_per_element_); 176 | 177 | input.close(); 178 | } 179 | }; 180 | } // namespace hnswlib 181 | -------------------------------------------------------------------------------- /hnswlib/hnswlib.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #ifndef NO_MANUAL_VECTORIZATION 3 | #if (defined(__SSE__) || _M_IX86_FP > 0 || defined(_M_AMD64) || defined(_M_X64)) 4 | #define USE_SSE 5 | #ifdef __AVX__ 6 | #define USE_AVX 7 | #ifdef __AVX512F__ 8 | #define USE_AVX512 9 | #endif 10 | #endif 11 | #endif 12 | #endif 13 | 14 | #if defined(USE_AVX) || defined(USE_SSE) 15 | #ifdef _MSC_VER 16 | #include 17 | #include 18 | void cpuid(int32_t out[4], int32_t eax, int32_t ecx) 19 | { 20 | __cpuidex(out, eax, ecx); 21 | } 22 | static __int64 xgetbv(unsigned int x) 23 | { 24 | return _xgetbv(x); 25 | } 26 | #else 27 | #include 28 | #include 29 | #include 30 | static void cpuid(int32_t cpuInfo[4], int32_t eax, int32_t ecx) 31 | { 32 | __cpuid_count(eax, ecx, cpuInfo[0], cpuInfo[1], cpuInfo[2], cpuInfo[3]); 33 | } 34 | static uint64_t xgetbv(unsigned int index) 35 | { 36 | uint32_t eax, edx; 37 | __asm__ __volatile__("xgetbv" : "=a"(eax), "=d"(edx) : "c"(index)); 38 | return ((uint64_t)edx << 32) | eax; 39 | } 40 | #endif 41 | 42 | #if defined(USE_AVX512) 43 | #include 44 | #endif 45 | 46 | #if defined(__GNUC__) 47 | #define PORTABLE_ALIGN32 __attribute__((aligned(32))) 48 | #define PORTABLE_ALIGN64 __attribute__((aligned(64))) 49 | #else 50 | #define PORTABLE_ALIGN32 __declspec(align(32)) 51 | #define PORTABLE_ALIGN64 __declspec(align(64)) 52 | #endif 53 | 54 | // Adapted from https://github.com/Mysticial/FeatureDetector 55 | #define _XCR_XFEATURE_ENABLED_MASK 0 56 | 57 | static bool AVXCapable() 58 | { 59 | int cpuInfo[4]; 60 | 61 | // CPU support 62 | cpuid(cpuInfo, 0, 0); 63 | int nIds = cpuInfo[0]; 64 | 65 | bool HW_AVX = false; 66 | if 
(nIds >= 0x00000001) 67 | { 68 | cpuid(cpuInfo, 0x00000001, 0); 69 | HW_AVX = (cpuInfo[2] & ((int)1 << 28)) != 0; 70 | } 71 | 72 | // OS support 73 | cpuid(cpuInfo, 1, 0); 74 | 75 | bool osUsesXSAVE_XRSTORE = (cpuInfo[2] & (1 << 27)) != 0; 76 | bool cpuAVXSuport = (cpuInfo[2] & (1 << 28)) != 0; 77 | 78 | bool avxSupported = false; 79 | if (osUsesXSAVE_XRSTORE && cpuAVXSuport) 80 | { 81 | uint64_t xcrFeatureMask = xgetbv(_XCR_XFEATURE_ENABLED_MASK); 82 | avxSupported = (xcrFeatureMask & 0x6) == 0x6; 83 | } 84 | return HW_AVX && avxSupported; 85 | } 86 | 87 | static bool AVX512Capable() 88 | { 89 | if (!AVXCapable()) 90 | return false; 91 | 92 | int cpuInfo[4]; 93 | 94 | // CPU support 95 | cpuid(cpuInfo, 0, 0); 96 | int nIds = cpuInfo[0]; 97 | 98 | bool HW_AVX512F = false; 99 | if (nIds >= 0x00000007) 100 | { // AVX512 Foundation 101 | cpuid(cpuInfo, 0x00000007, 0); 102 | HW_AVX512F = (cpuInfo[1] & ((int)1 << 16)) != 0; 103 | } 104 | 105 | // OS support 106 | cpuid(cpuInfo, 1, 0); 107 | 108 | bool osUsesXSAVE_XRSTORE = (cpuInfo[2] & (1 << 27)) != 0; 109 | bool cpuAVXSuport = (cpuInfo[2] & (1 << 28)) != 0; 110 | 111 | bool avx512Supported = false; 112 | if (osUsesXSAVE_XRSTORE && cpuAVXSuport) 113 | { 114 | uint64_t xcrFeatureMask = xgetbv(_XCR_XFEATURE_ENABLED_MASK); 115 | avx512Supported = (xcrFeatureMask & 0xe6) == 0xe6; 116 | } 117 | return HW_AVX512F && avx512Supported; 118 | } 119 | #endif 120 | 121 | #include 122 | #include 123 | #include 124 | #include 125 | 126 | namespace hnswlib 127 | { 128 | typedef size_t labeltype; 129 | 130 | // This can be extended to store state for filtering (e.g. 
from a std::set) 131 | class BaseFilterFunctor 132 | { 133 | public: 134 | virtual bool operator()(hnswlib::labeltype id) { return true; } 135 | }; 136 | 137 | template 138 | class pairGreater 139 | { 140 | public: 141 | bool operator()(const T &p1, const T &p2) 142 | { 143 | return p1.first > p2.first; 144 | } 145 | }; 146 | 147 | template 148 | static void writeBinaryPOD(std::ostream &out, const T &podRef) 149 | { 150 | out.write((char *)&podRef, sizeof(T)); 151 | } 152 | 153 | template 154 | static void readBinaryPOD(std::istream &in, T &podRef) 155 | { 156 | in.read((char *)&podRef, sizeof(T)); 157 | } 158 | 159 | template 160 | using DISTFUNC = MTYPE (*)(const void *, const void *, const void *); 161 | 162 | template 163 | class SpaceInterface 164 | { 165 | public: 166 | // virtual void search(void *); 167 | virtual size_t get_data_size() = 0; 168 | 169 | virtual DISTFUNC get_dist_func() = 0; 170 | 171 | virtual void *get_dist_func_param() = 0; 172 | 173 | virtual ~SpaceInterface() {} 174 | }; 175 | 176 | template 177 | class AlgorithmInterface 178 | { 179 | public: 180 | virtual void addPoint(const void *datapoint, labeltype label, bool replace_deleted = false) = 0; 181 | 182 | virtual std::priority_queue> 183 | searchKnn(const void *, size_t, BaseFilterFunctor *isIdAllowed = nullptr) const = 0; 184 | 185 | // Return k nearest neighbor in the order of closer fist 186 | virtual std::vector> 187 | searchKnnCloserFirst(const void *query_data, size_t k, BaseFilterFunctor *isIdAllowed = nullptr) const; 188 | 189 | virtual void saveIndex(const std::string &location) = 0; 190 | virtual ~AlgorithmInterface() 191 | { 192 | } 193 | }; 194 | 195 | template 196 | std::vector> 197 | AlgorithmInterface::searchKnnCloserFirst(const void *query_data, size_t k, 198 | BaseFilterFunctor *isIdAllowed) const 199 | { 200 | std::vector> result; 201 | 202 | // here searchKnn returns the result in the order of further first 203 | auto ret = searchKnn(query_data, k, isIdAllowed); 204 | 
{ 205 | size_t sz = ret.size(); 206 | result.resize(sz); 207 | while (!ret.empty()) 208 | { 209 | result[--sz] = ret.top(); 210 | ret.pop(); 211 | } 212 | } 213 | 214 | return result; 215 | } 216 | } // namespace hnswlib 217 | 218 | #include "space_l2.h" 219 | #include "space_ip.h" 220 | #include "bruteforce.h" 221 | #include "hnswalg.h" 222 | -------------------------------------------------------------------------------- /hnswlib/space_l2.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "hnswlib.h" 3 | 4 | namespace hnswlib 5 | { 6 | 7 | static float 8 | L2Sqr(const void *pVect1v, const void *pVect2v, const void *qty_ptr) 9 | { 10 | float *pVect1 = (float *)pVect1v; 11 | float *pVect2 = (float *)pVect2v; 12 | size_t qty = *((size_t *)qty_ptr); 13 | 14 | float res = 0; 15 | for (size_t i = 0; i < qty; i++) 16 | { 17 | float t = *pVect1 - *pVect2; 18 | pVect1++; 19 | pVect2++; 20 | res += t * t; 21 | } 22 | return (res); 23 | } 24 | 25 | #if defined(USE_AVX512) 26 | 27 | // Favor using AVX512 if available. 
28 | static float 29 | L2SqrSIMD16ExtAVX512(const void *pVect1v, const void *pVect2v, const void *qty_ptr) 30 | { 31 | float *pVect1 = (float *)pVect1v; 32 | float *pVect2 = (float *)pVect2v; 33 | size_t qty = *((size_t *)qty_ptr); 34 | float PORTABLE_ALIGN64 TmpRes[16]; 35 | size_t qty16 = qty >> 4; 36 | 37 | const float *pEnd1 = pVect1 + (qty16 << 4); 38 | 39 | __m512 diff, v1, v2; 40 | __m512 sum = _mm512_set1_ps(0); 41 | 42 | while (pVect1 < pEnd1) 43 | { 44 | v1 = _mm512_loadu_ps(pVect1); 45 | pVect1 += 16; 46 | v2 = _mm512_loadu_ps(pVect2); 47 | pVect2 += 16; 48 | diff = _mm512_sub_ps(v1, v2); 49 | // sum = _mm512_fmadd_ps(diff, diff, sum); 50 | sum = _mm512_add_ps(sum, _mm512_mul_ps(diff, diff)); 51 | } 52 | 53 | _mm512_store_ps(TmpRes, sum); 54 | float res = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + 55 | TmpRes[7] + TmpRes[8] + TmpRes[9] + TmpRes[10] + TmpRes[11] + TmpRes[12] + 56 | TmpRes[13] + TmpRes[14] + TmpRes[15]; 57 | 58 | return (res); 59 | } 60 | #endif 61 | 62 | #if defined(USE_AVX) 63 | 64 | // Favor using AVX if available. 
65 | static float 66 | L2SqrSIMD16ExtAVX(const void *pVect1v, const void *pVect2v, const void *qty_ptr) 67 | { 68 | float *pVect1 = (float *)pVect1v; 69 | float *pVect2 = (float *)pVect2v; 70 | size_t qty = *((size_t *)qty_ptr); 71 | float PORTABLE_ALIGN32 TmpRes[8]; 72 | size_t qty16 = qty >> 4; 73 | 74 | const float *pEnd1 = pVect1 + (qty16 << 4); 75 | 76 | __m256 diff, v1, v2; 77 | __m256 sum = _mm256_set1_ps(0); 78 | 79 | while (pVect1 < pEnd1) 80 | { 81 | v1 = _mm256_loadu_ps(pVect1); 82 | pVect1 += 8; 83 | v2 = _mm256_loadu_ps(pVect2); 84 | pVect2 += 8; 85 | diff = _mm256_sub_ps(v1, v2); 86 | sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff)); 87 | 88 | v1 = _mm256_loadu_ps(pVect1); 89 | pVect1 += 8; 90 | v2 = _mm256_loadu_ps(pVect2); 91 | pVect2 += 8; 92 | diff = _mm256_sub_ps(v1, v2); 93 | sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff)); 94 | } 95 | 96 | _mm256_store_ps(TmpRes, sum); 97 | return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7]; 98 | } 99 | 100 | #endif 101 | 102 | #if defined(USE_SSE) 103 | 104 | static float 105 | L2SqrSIMD16ExtSSE(const void *pVect1v, const void *pVect2v, const void *qty_ptr) 106 | { 107 | float *pVect1 = (float *)pVect1v; 108 | float *pVect2 = (float *)pVect2v; 109 | size_t qty = *((size_t *)qty_ptr); 110 | float PORTABLE_ALIGN32 TmpRes[8]; 111 | size_t qty16 = qty >> 4; 112 | 113 | const float *pEnd1 = pVect1 + (qty16 << 4); 114 | 115 | __m128 diff, v1, v2; 116 | __m128 sum = _mm_set1_ps(0); 117 | 118 | while (pVect1 < pEnd1) 119 | { 120 | //_mm_prefetch((char*)(pVect2 + 16), _MM_HINT_T0); 121 | v1 = _mm_loadu_ps(pVect1); 122 | pVect1 += 4; 123 | v2 = _mm_loadu_ps(pVect2); 124 | pVect2 += 4; 125 | diff = _mm_sub_ps(v1, v2); 126 | sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); 127 | 128 | v1 = _mm_loadu_ps(pVect1); 129 | pVect1 += 4; 130 | v2 = _mm_loadu_ps(pVect2); 131 | pVect2 += 4; 132 | diff = _mm_sub_ps(v1, v2); 133 | sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); 
134 | 135 | v1 = _mm_loadu_ps(pVect1); 136 | pVect1 += 4; 137 | v2 = _mm_loadu_ps(pVect2); 138 | pVect2 += 4; 139 | diff = _mm_sub_ps(v1, v2); 140 | sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); 141 | 142 | v1 = _mm_loadu_ps(pVect1); 143 | pVect1 += 4; 144 | v2 = _mm_loadu_ps(pVect2); 145 | pVect2 += 4; 146 | diff = _mm_sub_ps(v1, v2); 147 | sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); 148 | } 149 | 150 | _mm_store_ps(TmpRes, sum); 151 | return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3]; 152 | } 153 | #endif 154 | 155 | #if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512) 156 | static DISTFUNC L2SqrSIMD16Ext = L2SqrSIMD16ExtSSE; 157 | 158 | static float 159 | L2SqrSIMD16ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) 160 | { 161 | size_t qty = *((size_t *)qty_ptr); 162 | size_t qty16 = qty >> 4 << 4; 163 | float res = L2SqrSIMD16Ext(pVect1v, pVect2v, &qty16); 164 | float *pVect1 = (float *)pVect1v + qty16; 165 | float *pVect2 = (float *)pVect2v + qty16; 166 | 167 | size_t qty_left = qty - qty16; 168 | float res_tail = L2Sqr(pVect1, pVect2, &qty_left); 169 | return (res + res_tail); 170 | } 171 | #endif 172 | 173 | #if defined(USE_SSE) 174 | static float 175 | L2SqrSIMD4Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) 176 | { 177 | float PORTABLE_ALIGN32 TmpRes[8]; 178 | float *pVect1 = (float *)pVect1v; 179 | float *pVect2 = (float *)pVect2v; 180 | size_t qty = *((size_t *)qty_ptr); 181 | 182 | size_t qty4 = qty >> 2; 183 | 184 | const float *pEnd1 = pVect1 + (qty4 << 2); 185 | 186 | __m128 diff, v1, v2; 187 | __m128 sum = _mm_set1_ps(0); 188 | 189 | while (pVect1 < pEnd1) 190 | { 191 | v1 = _mm_loadu_ps(pVect1); 192 | pVect1 += 4; 193 | v2 = _mm_loadu_ps(pVect2); 194 | pVect2 += 4; 195 | diff = _mm_sub_ps(v1, v2); 196 | sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); 197 | } 198 | _mm_store_ps(TmpRes, sum); 199 | return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3]; 200 | } 201 | 202 | static float 203 
| L2SqrSIMD4ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) 204 | { 205 | size_t qty = *((size_t *)qty_ptr); 206 | size_t qty4 = qty >> 2 << 2; 207 | 208 | float res = L2SqrSIMD4Ext(pVect1v, pVect2v, &qty4); 209 | size_t qty_left = qty - qty4; 210 | 211 | float *pVect1 = (float *)pVect1v + qty4; 212 | float *pVect2 = (float *)pVect2v + qty4; 213 | float res_tail = L2Sqr(pVect1, pVect2, &qty_left); 214 | 215 | return (res + res_tail); 216 | } 217 | #endif 218 | 219 | class L2Space : public SpaceInterface 220 | { 221 | DISTFUNC fstdistfunc_; 222 | size_t data_size_; 223 | size_t dim_; 224 | 225 | public: 226 | L2Space(size_t dim) 227 | { 228 | fstdistfunc_ = L2Sqr; 229 | #if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512) 230 | #if defined(USE_AVX512) 231 | if (AVX512Capable()) 232 | L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX512; 233 | else if (AVXCapable()) 234 | L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX; 235 | #elif defined(USE_AVX) 236 | if (AVXCapable()) 237 | L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX; 238 | #endif 239 | 240 | if (dim % 16 == 0) 241 | fstdistfunc_ = L2SqrSIMD16Ext; 242 | else if (dim % 4 == 0) 243 | fstdistfunc_ = L2SqrSIMD4Ext; 244 | else if (dim > 16) 245 | fstdistfunc_ = L2SqrSIMD16ExtResiduals; 246 | else if (dim > 4) 247 | fstdistfunc_ = L2SqrSIMD4ExtResiduals; 248 | #endif 249 | dim_ = dim; 250 | data_size_ = dim * sizeof(float); 251 | } 252 | 253 | size_t get_data_size() 254 | { 255 | return data_size_; 256 | } 257 | 258 | DISTFUNC get_dist_func() 259 | { 260 | return fstdistfunc_; 261 | } 262 | 263 | void *get_dist_func_param() 264 | { 265 | return &dim_; 266 | } 267 | 268 | ~L2Space() {} 269 | }; 270 | 271 | static int 272 | L2SqrI4x(const void *__restrict pVect1, const void *__restrict pVect2, const void *__restrict qty_ptr) 273 | { 274 | size_t qty = *((size_t *)qty_ptr); 275 | int res = 0; 276 | unsigned char *a = (unsigned char *)pVect1; 277 | unsigned char *b = (unsigned char *)pVect2; 278 | 279 | qty = qty >> 2; 
280 | for (size_t i = 0; i < qty; i++) 281 | { 282 | res += ((*a) - (*b)) * ((*a) - (*b)); 283 | a++; 284 | b++; 285 | res += ((*a) - (*b)) * ((*a) - (*b)); 286 | a++; 287 | b++; 288 | res += ((*a) - (*b)) * ((*a) - (*b)); 289 | a++; 290 | b++; 291 | res += ((*a) - (*b)) * ((*a) - (*b)); 292 | a++; 293 | b++; 294 | } 295 | return (res); 296 | } 297 | 298 | static int L2SqrI(const void *__restrict pVect1, const void *__restrict pVect2, const void *__restrict qty_ptr) 299 | { 300 | size_t qty = *((size_t *)qty_ptr); 301 | int res = 0; 302 | unsigned char *a = (unsigned char *)pVect1; 303 | unsigned char *b = (unsigned char *)pVect2; 304 | 305 | for (size_t i = 0; i < qty; i++) 306 | { 307 | res += ((*a) - (*b)) * ((*a) - (*b)); 308 | a++; 309 | b++; 310 | } 311 | return (res); 312 | } 313 | 314 | class L2SpaceI : public SpaceInterface 315 | { 316 | DISTFUNC fstdistfunc_; 317 | size_t data_size_; 318 | size_t dim_; 319 | 320 | public: 321 | L2SpaceI(size_t dim) 322 | { 323 | if (dim % 4 == 0) 324 | { 325 | fstdistfunc_ = L2SqrI4x; 326 | } 327 | else 328 | { 329 | fstdistfunc_ = L2SqrI; 330 | } 331 | dim_ = dim; 332 | data_size_ = dim * sizeof(unsigned char); 333 | } 334 | 335 | size_t get_data_size() 336 | { 337 | return data_size_; 338 | } 339 | 340 | DISTFUNC get_dist_func() 341 | { 342 | return fstdistfunc_; 343 | } 344 | 345 | void *get_dist_func_param() 346 | { 347 | return &dim_; 348 | } 349 | 350 | ~L2SpaceI() {} 351 | }; 352 | } // namespace hnswlib 353 | -------------------------------------------------------------------------------- /hnswlib/visited_list_pool.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | namespace hnswlib 8 | { 9 | typedef unsigned short int vl_type; 10 | 11 | class VisitedList 12 | { 13 | public: 14 | vl_type curV; 15 | vl_type *mass; 16 | unsigned int numelements; 17 | 18 | VisitedList(int numelements1) 19 | { 20 | curV = -1; 21 | 
numelements = numelements1; 22 | mass = new vl_type[numelements]; 23 | } 24 | 25 | void reset() 26 | { 27 | curV++; 28 | if (curV == 0) 29 | { 30 | memset(mass, 0, sizeof(vl_type) * numelements); 31 | curV++; 32 | } 33 | } 34 | 35 | ~VisitedList() { delete[] mass; } 36 | }; 37 | /////////////////////////////////////////////////////////// 38 | // 39 | // Class for multi-threaded pool-management of VisitedLists 40 | // 41 | ///////////////////////////////////////////////////////// 42 | 43 | class VisitedListPool 44 | { 45 | std::deque pool; 46 | std::mutex poolguard; 47 | int numelements; 48 | 49 | public: 50 | VisitedListPool(int initmaxpools, int numelements1) 51 | { 52 | numelements = numelements1; 53 | for (int i = 0; i < initmaxpools; i++) 54 | pool.push_front(new VisitedList(numelements)); 55 | } 56 | 57 | VisitedList *getFreeVisitedList() 58 | { 59 | VisitedList *rez; 60 | { 61 | std::unique_lock lock(poolguard); 62 | if (pool.size() > 0) 63 | { 64 | rez = pool.front(); 65 | pool.pop_front(); 66 | } 67 | else 68 | { 69 | rez = new VisitedList(numelements); 70 | } 71 | } 72 | rez->reset(); 73 | return rez; 74 | } 75 | 76 | void releaseVisitedList(VisitedList *vl) 77 | { 78 | std::unique_lock lock(poolguard); 79 | pool.push_front(vl); 80 | } 81 | 82 | ~VisitedListPool() 83 | { 84 | while (pool.size()) 85 | { 86 | VisitedList *rez = pool.front(); 87 | pool.pop_front(); 88 | delete rez; 89 | } 90 | } 91 | }; 92 | } // namespace hnswlib 93 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ 3 | "setuptools>=42", 4 | "wheel", 5 | "numpy>=1.10.0", 6 | "pybind11>=2.0", 7 | ] 8 | 9 | build-backend = "setuptools.build_meta" 10 | -------------------------------------------------------------------------------- /python_bindings/LazyIndex.py: 
-------------------------------------------------------------------------------- 1 | import hnswlib 2 | 3 | """ 4 | A python wrapper for lazy indexing, preserves the same api as hnswlib.Index but initializes the index only after adding items for the first time with `add_items`. 5 | """ 6 | 7 | 8 | class LazyIndex(hnswlib.Index): 9 | def __init__(self, space, dim, max_elements=1024, ef_construction=200, M=16): 10 | super().__init__(space, dim) 11 | self.init_max_elements = max_elements 12 | self.init_ef_construction = ef_construction 13 | self.init_M = M 14 | 15 | def init_index(self, max_elements=0, M=0, ef_construction=0): 16 | if max_elements > 0: 17 | self.init_max_elements = max_elements 18 | if ef_construction > 0: 19 | self.init_ef_construction = ef_construction 20 | if M > 0: 21 | self.init_M = M 22 | super().init_index( 23 | self.init_max_elements, self.init_M, self.init_ef_construction 24 | ) 25 | 26 | def add_items(self, data, ids=None, num_threads=-1): 27 | if self.max_elements == 0: 28 | self.init_index() 29 | return super().add_items(data, ids, num_threads) 30 | 31 | def get_items(self, ids=None): 32 | if self.max_elements == 0: 33 | return [] 34 | return super().get_items(ids) 35 | 36 | def knn_query(self, data, k=1, num_threads=-1): 37 | if self.max_elements == 0: 38 | return [], [] 39 | return super().knn_query(data, k, num_threads) 40 | 41 | def resize_index(self, size): 42 | if self.max_elements == 0: 43 | return self.init_index(size) 44 | else: 45 | return super().resize_index(size) 46 | 47 | def set_ef(self, ef): 48 | if self.max_elements == 0: 49 | self.init_ef_construction = ef 50 | return 51 | super().set_ef(ef) 52 | 53 | def get_max_elements(self): 54 | return self.max_elements 55 | 56 | def get_current_count(self): 57 | return self.element_count 58 | -------------------------------------------------------------------------------- /python_bindings/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/chroma-core/hnswlib/340a6fff21f0ba2db2bc720636eba4d7736b9534/python_bindings/__init__.py -------------------------------------------------------------------------------- /python_bindings/setup.py: -------------------------------------------------------------------------------- 1 | ../setup.py -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import platform 4 | 5 | import numpy as np 6 | import pybind11 7 | import setuptools 8 | from setuptools import Extension, setup 9 | from setuptools.command.build_ext import build_ext 10 | 11 | __version__ = "0.7.6" 12 | 13 | include_dirs = [ 14 | pybind11.get_include(), 15 | np.get_include(), 16 | ] 17 | 18 | # compatibility when run in python_bindings 19 | bindings_dir = "python_bindings" 20 | if bindings_dir in os.path.basename(os.getcwd()): 21 | source_files = ["./bindings.cpp"] 22 | include_dirs.extend(["../hnswlib/"]) 23 | else: 24 | source_files = ["./python_bindings/bindings.cpp"] 25 | include_dirs.extend(["./hnswlib/"]) 26 | 27 | 28 | libraries = [] 29 | extra_objects = [] 30 | 31 | 32 | ext_modules = [ 33 | Extension( 34 | "hnswlib", 35 | source_files, 36 | include_dirs=include_dirs, 37 | libraries=libraries, 38 | language="c++", 39 | extra_objects=extra_objects, 40 | ), 41 | ] 42 | 43 | 44 | # As of Python 3.6, CCompiler has a `has_flag` method. 45 | # cf http://bugs.python.org/issue26689 46 | def has_flag(compiler, flagname): 47 | """Return a boolean indicating whether a flag name is supported on 48 | the specified compiler. 
49 | """ 50 | import tempfile 51 | 52 | with tempfile.NamedTemporaryFile("w", suffix=".cpp") as f: 53 | f.write("int main (int argc, char **argv) { return 0; }") 54 | try: 55 | compiler.compile([f.name], extra_postargs=[flagname]) 56 | except setuptools.distutils.errors.CompileError: 57 | return False 58 | return True 59 | 60 | 61 | def cpp_flag(compiler): 62 | """Return the -std=c++[11/14] compiler flag. 63 | The c++14 is prefered over c++11 (when it is available). 64 | """ 65 | if has_flag(compiler, "-std=c++14"): 66 | return "-std=c++14" 67 | elif has_flag(compiler, "-std=c++11"): 68 | return "-std=c++11" 69 | else: 70 | raise RuntimeError( 71 | "Unsupported compiler -- at least C++11 support " "is needed!" 72 | ) 73 | 74 | 75 | class BuildExt(build_ext): 76 | """A custom build extension for adding compiler-specific options.""" 77 | 78 | c_opts = { 79 | "msvc": ["/EHsc", "/openmp", "/O2"], 80 | #'unix': ['-O3', '-march=native'], # , '-w' 81 | "unix": ["-O3"], # , '-w' 82 | } 83 | if not os.environ.get("HNSWLIB_NO_NATIVE"): 84 | c_opts["unix"].append("-march=native") 85 | 86 | link_opts = { 87 | "unix": [], 88 | "msvc": [], 89 | } 90 | 91 | if sys.platform == "darwin": 92 | if platform.machine() == "arm64": 93 | if "-march=native" in c_opts["unix"]: 94 | c_opts["unix"].remove("-march=native") 95 | c_opts["unix"] += ["-stdlib=libc++", "-mmacosx-version-min=10.7"] 96 | link_opts["unix"] += ["-stdlib=libc++", "-mmacosx-version-min=10.7"] 97 | else: 98 | c_opts["unix"].append("-fopenmp") 99 | link_opts["unix"].extend(["-fopenmp", "-pthread"]) 100 | 101 | def build_extensions(self): 102 | ct = self.compiler.compiler_type 103 | opts = self.c_opts.get(ct, []) 104 | if ct == "unix": 105 | opts.append('-DVERSION_INFO="%s"' % self.distribution.get_version()) 106 | opts.append(cpp_flag(self.compiler)) 107 | if has_flag(self.compiler, "-fvisibility=hidden"): 108 | opts.append("-fvisibility=hidden") 109 | elif ct == "msvc": 110 | opts.append('/DVERSION_INFO=\\"%s\\"' % 
self.distribution.get_version()) 111 | 112 | for ext in self.extensions: 113 | ext.extra_compile_args.extend(opts) 114 | ext.extra_link_args.extend(self.link_opts.get(ct, [])) 115 | 116 | build_ext.build_extensions(self) 117 | 118 | 119 | setup( 120 | name="chroma-hnswlib", 121 | version=__version__, 122 | description="Chromas fork of hnswlib", 123 | author="Yury Malkov and the original hnswlib authors + Chroma", 124 | url="https://github.com/chroma-core/hnswlib", 125 | long_description="""hnsw""", 126 | ext_modules=ext_modules, 127 | install_requires=["numpy"], 128 | cmdclass={"build_ext": BuildExt}, 129 | zip_safe=False, 130 | ) 131 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | mod hnsw; 2 | pub use hnsw::*; 3 | -------------------------------------------------------------------------------- /tests/cpp/api_test.cpp: -------------------------------------------------------------------------------- 1 | #include "../../hnswlib/hnswlib.h" 2 | 3 | #include 4 | 5 | void testListAllLabels() 6 | { 7 | int d = 1536; 8 | hnswlib::labeltype n = 1000; 9 | 10 | std::vector data(n * d); 11 | 12 | std::mt19937 rng; 13 | rng.seed(47); 14 | std::uniform_real_distribution<> distrib; 15 | 16 | for (auto i = 0; i < n * d; i++) 17 | { 18 | data[i] = distrib(rng); 19 | } 20 | 21 | hnswlib::InnerProductSpace space(d); 22 | hnswlib::HierarchicalNSW *alg_hnsw = new hnswlib::HierarchicalNSW(&space, n, 16, 200, 100, false, false, true, "."); 23 | 24 | for (size_t i = 0; i < n; i++) 25 | { 26 | alg_hnsw->addPoint(data.data() + d * i, i); 27 | } 28 | // Delete odd points. 29 | for (size_t i = 1; i < n; i += 2) 30 | { 31 | alg_hnsw->markDelete(i); 32 | } 33 | // Get all data. 
34 | auto res = alg_hnsw->getAllLabels(); 35 | auto non_deleted = res.first; 36 | auto deleted = res.second; 37 | assert(non_deleted.size() == n / 2); 38 | assert(deleted.size() == n / 2); 39 | 40 | for (auto idx : non_deleted) 41 | { 42 | assert(idx % 2 == 0); 43 | } 44 | for (auto idx : deleted) 45 | { 46 | assert(idx % 2 == 1); 47 | } 48 | 49 | // After persisting and reloading the data should be the same. 50 | alg_hnsw->persistDirty(); 51 | 52 | // Load the index with n elements 53 | hnswlib::HierarchicalNSW *alg_hnsw2 = new hnswlib::HierarchicalNSW(&space, ".", false, n, false, false, true); 54 | 55 | // Check that all the data is the same 56 | auto res2 = alg_hnsw2->getAllLabels(); 57 | auto non_deleted2 = res2.first; 58 | auto deleted2 = res2.second; 59 | assert(non_deleted2.size() == n / 2); 60 | assert(deleted2.size() == n / 2); 61 | 62 | for (auto idx : non_deleted2) 63 | { 64 | assert(idx % 2 == 0); 65 | } 66 | for (auto idx : deleted2) 67 | { 68 | assert(idx % 2 == 1); 69 | } 70 | } 71 | 72 | int main() 73 | { 74 | std::cout << "Testing ..." 
<< std::endl; 75 | testListAllLabels(); 76 | std::cout << "Test testListAllLabels ok" << std::endl; 77 | return 0; 78 | } -------------------------------------------------------------------------------- /tests/cpp/download_bigann.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import os 3 | 4 | links = [ 5 | "ftp://ftp.irisa.fr/local/texmex/corpus/bigann_query.bvecs.gz", 6 | "ftp://ftp.irisa.fr/local/texmex/corpus/bigann_gnd.tar.gz", 7 | "ftp://ftp.irisa.fr/local/texmex/corpus/bigann_base.bvecs.gz", 8 | ] 9 | 10 | os.makedirs("downloads", exist_ok=True) 11 | os.makedirs("bigann", exist_ok=True) 12 | for link in links: 13 | name = link.rsplit("/", 1)[-1] 14 | filename = os.path.join("downloads", name) 15 | if not os.path.isfile(filename): 16 | print("Downloading: " + filename) 17 | try: 18 | os.system("wget --output-document=" + filename + " " + link) 19 | except Exception as inst: 20 | print(inst) 21 | print(" Encountered unknown error. 
Continuing.") 22 | else: 23 | print("Already downloaded: " + filename) 24 | if filename.endswith(".tar.gz"): 25 | command = "tar -zxf " + filename + " --directory bigann" 26 | else: 27 | command = "cat " + filename + " | gzip -dc > bigann/" + name.replace(".gz", "") 28 | print("Unpacking file:", command) 29 | os.system(command) 30 | -------------------------------------------------------------------------------- /tests/cpp/getUnormalized_test.cpp: -------------------------------------------------------------------------------- 1 | #include "../../hnswlib/hnswlib.h" 2 | 3 | #include 4 | 5 | #include 6 | #include 7 | 8 | namespace 9 | { 10 | 11 | using idx_t = hnswlib::labeltype; 12 | 13 | void testReadUnormalizedData() 14 | { 15 | int d = 4; 16 | idx_t n = 100; 17 | idx_t nq = 10; 18 | size_t k = 10; 19 | 20 | std::vector data(n * d); 21 | 22 | std::mt19937 rng; 23 | rng.seed(47); 24 | std::uniform_real_distribution<> distrib; 25 | 26 | for (idx_t i = 0; i < n * d; i++) 27 | { 28 | data[i] = distrib(rng); 29 | } 30 | 31 | hnswlib::InnerProductSpace space(d); 32 | hnswlib::HierarchicalNSW *alg_hnsw = new hnswlib::HierarchicalNSW(&space, 2 * n, 16, 200, 100, false, true); 33 | 34 | for (size_t i = 0; i < n; i++) 35 | { 36 | alg_hnsw->addPoint(data.data() + d * i, i); 37 | } 38 | 39 | // Check that all data is the same 40 | for (size_t i = 0; i < n; i++) 41 | { 42 | std::vector actual = alg_hnsw->template getDataByLabel(i); 43 | for (size_t j = 0; j < d; j++) 44 | { 45 | // Check that abs difference is less than 1e-6 46 | assert(std::abs(actual[j] - data[d * i + j]) < 1e-6); 47 | } 48 | } 49 | 50 | delete alg_hnsw; 51 | } 52 | 53 | void testSaveAndLoadUnormalizedData() 54 | { 55 | int d = 4; 56 | idx_t n = 100; 57 | idx_t nq = 10; 58 | size_t k = 10; 59 | 60 | std::vector data(n * d); 61 | 62 | std::mt19937 rng; 63 | rng.seed(47); 64 | std::uniform_real_distribution<> distrib; 65 | 66 | for (idx_t i = 0; i < n * d; i++) 67 | { 68 | data[i] = distrib(rng); 69 | } 70 | 
71 | hnswlib::InnerProductSpace space(d); 72 | hnswlib::HierarchicalNSW *alg_hnsw = new hnswlib::HierarchicalNSW(&space, 2 * n, 16, 200, 100, false, true); 73 | 74 | for (size_t i = 0; i < n; i++) 75 | { 76 | alg_hnsw->addPoint(data.data() + d * i, i); 77 | } 78 | 79 | alg_hnsw->saveIndex("test.bin"); 80 | 81 | hnswlib::HierarchicalNSW *alg_hnsw2 = new hnswlib::HierarchicalNSW(&space, "test.bin", false, 2 * n, false, true); 82 | 83 | // Check that all data is the same 84 | for (size_t i = 0; i < n; i++) 85 | { 86 | std::vector actual = alg_hnsw2->template getDataByLabel(i); 87 | for (size_t j = 0; j < d; j++) 88 | { 89 | // Check that abs difference is less than 1e-6 90 | assert(std::abs(actual[j] - data[d * i + j]) < 1e-6); 91 | } 92 | } 93 | 94 | delete alg_hnsw; 95 | } 96 | 97 | void testUpdateUnormalizedData() 98 | { 99 | int d = 4; 100 | idx_t n = 100; 101 | idx_t nq = 10; 102 | size_t k = 10; 103 | 104 | std::vector data(n * d); 105 | 106 | std::mt19937 rng; 107 | rng.seed(47); 108 | std::uniform_real_distribution<> distrib; 109 | 110 | for (idx_t i = 0; i < n * d; i++) 111 | { 112 | data[i] = distrib(rng); 113 | } 114 | 115 | hnswlib::InnerProductSpace space(d); 116 | hnswlib::HierarchicalNSW *alg_hnsw = new hnswlib::HierarchicalNSW(&space, 2 * n, 16, 200, 100, false, true); 117 | 118 | for (size_t i = 0; i < n; i++) 119 | { 120 | alg_hnsw->addPoint(data.data() + d * i, i); 121 | } 122 | 123 | // Check that all data is the same 124 | for (size_t i = 0; i < n; i++) 125 | { 126 | std::vector actual = alg_hnsw->template getDataByLabel(i); 127 | for (size_t j = 0; j < d; j++) 128 | { 129 | // Check that abs difference is less than 1e-6 130 | assert(std::abs(actual[j] - data[d * i + j]) < 1e-6); 131 | } 132 | } 133 | 134 | // Generate new data 135 | std::vector data2(n * d); 136 | for (idx_t i = 0; i < n * d; i++) 137 | { 138 | data2[i] = distrib(rng); 139 | } 140 | 141 | // Update data 142 | for (size_t i = 0; i < n; i++) 143 | { 144 | 
alg_hnsw->addPoint(data2.data() + d * i, i); 145 | } 146 | 147 | // Check that all data is the same 148 | for (size_t i = 0; i < n; i++) 149 | { 150 | std::vector actual = alg_hnsw->template getDataByLabel(i); 151 | for (size_t j = 0; j < d; j++) 152 | { 153 | // Check that abs difference is less than 1e-6 154 | assert(std::abs(actual[j] - data2[d * i + j]) < 1e-6); 155 | } 156 | } 157 | 158 | delete alg_hnsw; 159 | } 160 | 161 | } // namespace 162 | 163 | void testResizeUnormalizedData() 164 | { 165 | int d = 4; 166 | idx_t n = 100; 167 | idx_t nq = 10; 168 | size_t k = 10; 169 | 170 | std::vector data(n * d); 171 | 172 | std::mt19937 rng; 173 | rng.seed(47); 174 | std::uniform_real_distribution<> distrib; 175 | 176 | for (idx_t i = 0; i < n * d; i++) 177 | { 178 | data[i] = distrib(rng); 179 | } 180 | 181 | hnswlib::InnerProductSpace space(d); 182 | hnswlib::HierarchicalNSW *alg_hnsw = new hnswlib::HierarchicalNSW(&space, n, 16, 200, 100, false, true); 183 | 184 | for (size_t i = 0; i < n; i++) 185 | { 186 | alg_hnsw->addPoint(data.data() + d * i, i); 187 | } 188 | 189 | // Expect add to throw exception 190 | try 191 | { 192 | alg_hnsw->addPoint(data.data(), n); 193 | assert(false); 194 | } 195 | catch (std::runtime_error &e) 196 | { 197 | // Pass 198 | } 199 | 200 | // Resize the index 201 | alg_hnsw->resizeIndex(2 * n); 202 | 203 | // Check that all data is the same 204 | for (size_t i = 0; i < n; i++) 205 | { 206 | std::vector actual = alg_hnsw->template getDataByLabel(i); 207 | for (size_t j = 0; j < d; j++) 208 | { 209 | // Check that abs difference is less than 1e-6 210 | assert(std::abs(actual[j] - data[d * i + j]) < 1e-6); 211 | } 212 | } 213 | 214 | // Update / Add new data 215 | std::vector data2(n * 2 * d); 216 | for (idx_t i = 0; i < n * 2 * d; i++) 217 | { 218 | data2[i] = distrib(rng); 219 | } 220 | for (size_t i = 0; i < n; i++) 221 | { 222 | alg_hnsw->addPoint(data2.data() + d * i, i); 223 | } 224 | 225 | // Check that all data is the same 226 | 
for (size_t i = 0; i < n; i++) 227 | { 228 | std::vector actual = alg_hnsw->template getDataByLabel(i); 229 | for (size_t j = 0; j < d; j++) 230 | { 231 | // Check that abs difference is less than 1e-6 232 | assert(std::abs(actual[j] - data2[d * i + j]) < 1e-6); 233 | } 234 | } 235 | } 236 | 237 | int main() 238 | { 239 | std::cout << "Testing ..." << std::endl; 240 | testReadUnormalizedData(); 241 | std::cout << "Test testReadUnormalizedData ok" << std::endl; 242 | testSaveAndLoadUnormalizedData(); 243 | std::cout << "Test testSaveAndLoadUnormalizedData ok" << std::endl; 244 | testUpdateUnormalizedData(); 245 | std::cout << "Test testUpdateUnormalizedData ok" << std::endl; 246 | testResizeUnormalizedData(); 247 | std::cout << "Test testResizeUnormalizedData ok" << std::endl; 248 | 249 | return 0; 250 | } 251 | -------------------------------------------------------------------------------- /tests/cpp/main.cpp: -------------------------------------------------------------------------------- 1 | 2 | 3 | void sift_test1B(); 4 | int main() 5 | { 6 | sift_test1B(); 7 | 8 | return 0; 9 | } 10 | -------------------------------------------------------------------------------- /tests/cpp/multiThreadLoad_test.cpp: -------------------------------------------------------------------------------- 1 | #include "../../hnswlib/hnswlib.h" 2 | #include 3 | #include 4 | 5 | int main() 6 | { 7 | std::cout << "Running multithread load test" << std::endl; 8 | int d = 16; 9 | int max_elements = 1000; 10 | 11 | std::mt19937 rng; 12 | rng.seed(47); 13 | std::uniform_real_distribution<> distrib_real; 14 | 15 | hnswlib::L2Space space(d); 16 | hnswlib::HierarchicalNSW *alg_hnsw = new hnswlib::HierarchicalNSW(&space, 2 * max_elements); 17 | 18 | std::cout << "Building index" << std::endl; 19 | int num_threads = 40; 20 | int num_labels = 10; 21 | 22 | int num_iterations = 10; 23 | int start_label = 0; 24 | 25 | // run threads that will add elements to the index 26 | // about 7 threads (the 
number depends on num_threads and num_labels) 27 | // will add/update element with the same label simultaneously 28 | while (true) 29 | { 30 | // add elements by batches 31 | std::uniform_int_distribution<> distrib_int(start_label, start_label + num_labels - 1); 32 | std::vector threads; 33 | for (size_t thread_id = 0; thread_id < num_threads; thread_id++) 34 | { 35 | threads.push_back( 36 | std::thread( 37 | [&] 38 | { 39 | for (int iter = 0; iter < num_iterations; iter++) 40 | { 41 | std::vector data(d); 42 | hnswlib::labeltype label = distrib_int(rng); 43 | for (int i = 0; i < d; i++) 44 | { 45 | data[i] = distrib_real(rng); 46 | } 47 | alg_hnsw->addPoint(data.data(), label); 48 | } 49 | })); 50 | } 51 | for (auto &thread : threads) 52 | { 53 | thread.join(); 54 | } 55 | if (alg_hnsw->cur_element_count > max_elements - num_labels) 56 | { 57 | break; 58 | } 59 | start_label += num_labels; 60 | } 61 | 62 | // insert remaining elements if needed 63 | for (hnswlib::labeltype label = 0; label < max_elements; label++) 64 | { 65 | auto search = alg_hnsw->label_lookup_.find(label); 66 | if (search == alg_hnsw->label_lookup_.end()) 67 | { 68 | std::cout << "Adding " << label << std::endl; 69 | std::vector data(d); 70 | for (int i = 0; i < d; i++) 71 | { 72 | data[i] = distrib_real(rng); 73 | } 74 | alg_hnsw->addPoint(data.data(), label); 75 | } 76 | } 77 | 78 | std::cout << "Index is created" << std::endl; 79 | 80 | bool stop_threads = false; 81 | std::vector threads; 82 | 83 | // create threads that will do markDeleted and unmarkDeleted of random elements 84 | // each thread works with specific range of labels 85 | std::cout << "Starting markDeleted and unmarkDeleted threads" << std::endl; 86 | num_threads = 20; 87 | int chunk_size = max_elements / num_threads; 88 | for (size_t thread_id = 0; thread_id < num_threads; thread_id++) 89 | { 90 | threads.push_back( 91 | std::thread( 92 | [&, thread_id] 93 | { 94 | std::uniform_int_distribution<> distrib_int(0, chunk_size - 
1); 95 | int start_id = thread_id * chunk_size; 96 | std::vector marked_deleted(chunk_size); 97 | while (!stop_threads) 98 | { 99 | int id = distrib_int(rng); 100 | hnswlib::labeltype label = start_id + id; 101 | if (marked_deleted[id]) 102 | { 103 | alg_hnsw->unmarkDelete(label); 104 | marked_deleted[id] = false; 105 | } 106 | else 107 | { 108 | alg_hnsw->markDelete(label); 109 | marked_deleted[id] = true; 110 | } 111 | } 112 | })); 113 | } 114 | 115 | // create threads that will add and update random elements 116 | std::cout << "Starting add and update elements threads" << std::endl; 117 | num_threads = 20; 118 | std::uniform_int_distribution<> distrib_int_add(max_elements, 2 * max_elements - 1); 119 | for (size_t thread_id = 0; thread_id < num_threads; thread_id++) 120 | { 121 | threads.push_back( 122 | std::thread( 123 | [&] 124 | { 125 | std::vector data(d); 126 | while (!stop_threads) 127 | { 128 | hnswlib::labeltype label = distrib_int_add(rng); 129 | for (int i = 0; i < d; i++) 130 | { 131 | data[i] = distrib_real(rng); 132 | } 133 | alg_hnsw->addPoint(data.data(), label); 134 | std::vector data = alg_hnsw->getDataByLabel(label); 135 | float max_val = *max_element(data.begin(), data.end()); 136 | // never happens but prevents compiler from deleting unused code 137 | if (max_val > 10) 138 | { 139 | throw std::runtime_error("Unexpected value in data"); 140 | } 141 | } 142 | })); 143 | } 144 | 145 | std::cout << "Sleep and continue operations with index" << std::endl; 146 | int sleep_ms = 60 * 1000; 147 | std::this_thread::sleep_for(std::chrono::milliseconds(sleep_ms)); 148 | stop_threads = true; 149 | for (auto &thread : threads) 150 | { 151 | thread.join(); 152 | } 153 | 154 | std::cout << "Finish" << std::endl; 155 | return 0; 156 | } 157 | -------------------------------------------------------------------------------- /tests/cpp/multiThread_replace_test.cpp: -------------------------------------------------------------------------------- 1 | #include 
"../../hnswlib/hnswlib.h" 2 | #include 3 | #include 4 | 5 | template 6 | inline void ParallelFor(size_t start, size_t end, size_t numThreads, Function fn) 7 | { 8 | if (numThreads <= 0) 9 | { 10 | numThreads = std::thread::hardware_concurrency(); 11 | } 12 | 13 | if (numThreads == 1) 14 | { 15 | for (size_t id = start; id < end; id++) 16 | { 17 | fn(id, 0); 18 | } 19 | } 20 | else 21 | { 22 | std::vector threads; 23 | std::atomic current(start); 24 | 25 | // keep track of exceptions in threads 26 | // https://stackoverflow.com/a/32428427/1713196 27 | std::exception_ptr lastException = nullptr; 28 | std::mutex lastExceptMutex; 29 | 30 | for (size_t threadId = 0; threadId < numThreads; ++threadId) 31 | { 32 | threads.push_back(std::thread([&, threadId] 33 | { 34 | while (true) { 35 | size_t id = current.fetch_add(1); 36 | 37 | if (id >= end) { 38 | break; 39 | } 40 | 41 | try { 42 | fn(id, threadId); 43 | } catch (...) { 44 | std::unique_lock lastExcepLock(lastExceptMutex); 45 | lastException = std::current_exception(); 46 | /* 47 | * This will work even when current is the largest value that 48 | * size_t can fit, because fetch_add returns the previous value 49 | * before the increment (what will result in overflow 50 | * and produce 0 instead of current + 1). 
51 | */ 52 | current = end; 53 | break; 54 | } 55 | } })); 56 | } 57 | for (auto &thread : threads) 58 | { 59 | thread.join(); 60 | } 61 | if (lastException) 62 | { 63 | std::rethrow_exception(lastException); 64 | } 65 | } 66 | } 67 | 68 | int main() 69 | { 70 | std::cout << "Running multithread load test" << std::endl; 71 | int d = 16; 72 | int num_elements = 1000; 73 | int max_elements = 2 * num_elements; 74 | int num_threads = 50; 75 | 76 | std::mt19937 rng; 77 | rng.seed(47); 78 | std::uniform_real_distribution<> distrib_real; 79 | 80 | hnswlib::L2Space space(d); 81 | 82 | // generate batch1 and batch2 data 83 | float *batch1 = new float[d * max_elements]; 84 | for (int i = 0; i < d * max_elements; i++) 85 | { 86 | batch1[i] = distrib_real(rng); 87 | } 88 | float *batch2 = new float[d * num_elements]; 89 | for (int i = 0; i < d * num_elements; i++) 90 | { 91 | batch2[i] = distrib_real(rng); 92 | } 93 | 94 | // generate random labels to delete them from index 95 | std::vector rand_labels(max_elements); 96 | for (int i = 0; i < max_elements; i++) 97 | { 98 | rand_labels[i] = i; 99 | } 100 | std::shuffle(rand_labels.begin(), rand_labels.end(), rng); 101 | 102 | int iter = 0; 103 | while (iter < 200) 104 | { 105 | hnswlib::HierarchicalNSW *alg_hnsw = new hnswlib::HierarchicalNSW(&space, max_elements, 16, 200, 123, true); 106 | 107 | // add batch1 data 108 | ParallelFor(0, max_elements, num_threads, [&](size_t row, size_t threadId) 109 | { alg_hnsw->addPoint((void *)(batch1 + d * row), row); }); 110 | 111 | // delete half random elements of batch1 data 112 | for (int i = 0; i < num_elements; i++) 113 | { 114 | alg_hnsw->markDelete(rand_labels[i]); 115 | } 116 | 117 | // replace deleted elements with batch2 data 118 | ParallelFor(0, num_elements, num_threads, [&](size_t row, size_t threadId) 119 | { 120 | int label = rand_labels[row] + max_elements; 121 | alg_hnsw->addPoint((void*)(batch2 + d * row), label, true); }); 122 | 123 | iter += 1; 124 | 125 | delete 
alg_hnsw; 126 | } 127 | 128 | std::cout << "Finish" << std::endl; 129 | 130 | delete[] batch1; 131 | delete[] batch2; 132 | return 0; 133 | } 134 | -------------------------------------------------------------------------------- /tests/cpp/searchKnnCloserFirst_test.cpp: -------------------------------------------------------------------------------- 1 | // This is a test file for testing the interface 2 | // >>> virtual std::vector> 3 | // >>> searchKnnCloserFirst(const void* query_data, size_t k) const; 4 | // of class AlgorithmInterface 5 | 6 | #include "../../hnswlib/hnswlib.h" 7 | 8 | #include 9 | 10 | #include 11 | #include 12 | 13 | namespace 14 | { 15 | 16 | using idx_t = hnswlib::labeltype; 17 | 18 | void test() 19 | { 20 | int d = 4; 21 | idx_t n = 100; 22 | idx_t nq = 10; 23 | size_t k = 10; 24 | 25 | std::vector data(n * d); 26 | std::vector query(nq * d); 27 | 28 | std::mt19937 rng; 29 | rng.seed(47); 30 | std::uniform_real_distribution<> distrib; 31 | 32 | for (idx_t i = 0; i < n * d; ++i) 33 | { 34 | data[i] = distrib(rng); 35 | } 36 | for (idx_t i = 0; i < nq * d; ++i) 37 | { 38 | query[i] = distrib(rng); 39 | } 40 | 41 | hnswlib::L2Space space(d); 42 | hnswlib::AlgorithmInterface *alg_brute = new hnswlib::BruteforceSearch(&space, 2 * n); 43 | hnswlib::AlgorithmInterface *alg_hnsw = new hnswlib::HierarchicalNSW(&space, 2 * n); 44 | 45 | for (size_t i = 0; i < n; ++i) 46 | { 47 | alg_brute->addPoint(data.data() + d * i, i); 48 | alg_hnsw->addPoint(data.data() + d * i, i); 49 | } 50 | 51 | // test searchKnnCloserFirst of BruteforceSearch 52 | for (size_t j = 0; j < nq; ++j) 53 | { 54 | const void *p = query.data() + j * d; 55 | auto gd = alg_brute->searchKnn(p, k); 56 | auto res = alg_brute->searchKnnCloserFirst(p, k); 57 | assert(gd.size() == res.size()); 58 | size_t t = gd.size(); 59 | while (!gd.empty()) 60 | { 61 | assert(gd.top() == res[--t]); 62 | gd.pop(); 63 | } 64 | } 65 | for (size_t j = 0; j < nq; ++j) 66 | { 67 | const void *p = 
query.data() + j * d; 68 | auto gd = alg_hnsw->searchKnn(p, k); 69 | auto res = alg_hnsw->searchKnnCloserFirst(p, k); 70 | assert(gd.size() == res.size()); 71 | size_t t = gd.size(); 72 | while (!gd.empty()) 73 | { 74 | assert(gd.top() == res[--t]); 75 | gd.pop(); 76 | } 77 | } 78 | 79 | delete alg_brute; 80 | delete alg_hnsw; 81 | } 82 | 83 | } // namespace 84 | 85 | int main() 86 | { 87 | std::cout << "Testing ..." << std::endl; 88 | test(); 89 | std::cout << "Test ok" << std::endl; 90 | 91 | return 0; 92 | } 93 | -------------------------------------------------------------------------------- /tests/cpp/searchKnnWithFilter_test.cpp: -------------------------------------------------------------------------------- 1 | // This is a test file for testing the filtering feature 2 | 3 | #include "../../hnswlib/hnswlib.h" 4 | 5 | #include 6 | 7 | #include 8 | #include 9 | 10 | namespace 11 | { 12 | 13 | using idx_t = hnswlib::labeltype; 14 | 15 | class PickDivisibleIds : public hnswlib::BaseFilterFunctor 16 | { 17 | unsigned int divisor = 1; 18 | 19 | public: 20 | PickDivisibleIds(unsigned int divisor) : divisor(divisor) 21 | { 22 | assert(divisor != 0); 23 | } 24 | bool operator()(idx_t label_id) 25 | { 26 | return label_id % divisor == 0; 27 | } 28 | }; 29 | 30 | class PickNothing : public hnswlib::BaseFilterFunctor 31 | { 32 | public: 33 | bool operator()(idx_t label_id) 34 | { 35 | return false; 36 | } 37 | }; 38 | 39 | void test_some_filtering(hnswlib::BaseFilterFunctor &filter_func, size_t div_num, size_t label_id_start) 40 | { 41 | int d = 4; 42 | idx_t n = 100; 43 | idx_t nq = 10; 44 | size_t k = 10; 45 | 46 | std::vector data(n * d); 47 | std::vector query(nq * d); 48 | 49 | std::mt19937 rng; 50 | rng.seed(47); 51 | std::uniform_real_distribution<> distrib; 52 | 53 | for (idx_t i = 0; i < n * d; ++i) 54 | { 55 | data[i] = distrib(rng); 56 | } 57 | for (idx_t i = 0; i < nq * d; ++i) 58 | { 59 | query[i] = distrib(rng); 60 | } 61 | 62 | hnswlib::L2Space space(d); 
63 | hnswlib::AlgorithmInterface *alg_brute = new hnswlib::BruteforceSearch(&space, 2 * n); 64 | hnswlib::AlgorithmInterface *alg_hnsw = new hnswlib::HierarchicalNSW(&space, 2 * n); 65 | 66 | for (size_t i = 0; i < n; ++i) 67 | { 68 | // `label_id_start` is used to ensure that the returned IDs are labels and not internal IDs 69 | alg_brute->addPoint(data.data() + d * i, label_id_start + i); 70 | alg_hnsw->addPoint(data.data() + d * i, label_id_start + i); 71 | } 72 | 73 | // test searchKnnCloserFirst of BruteforceSearch with filtering 74 | for (size_t j = 0; j < nq; ++j) 75 | { 76 | const void *p = query.data() + j * d; 77 | auto gd = alg_brute->searchKnn(p, k, &filter_func); 78 | auto res = alg_brute->searchKnnCloserFirst(p, k, &filter_func); 79 | assert(gd.size() == res.size()); 80 | size_t t = gd.size(); 81 | while (!gd.empty()) 82 | { 83 | assert(gd.top() == res[--t]); 84 | assert((gd.top().second % div_num) == 0); 85 | gd.pop(); 86 | } 87 | } 88 | 89 | // test searchKnnCloserFirst of hnsw with filtering 90 | for (size_t j = 0; j < nq; ++j) 91 | { 92 | const void *p = query.data() + j * d; 93 | auto gd = alg_hnsw->searchKnn(p, k, &filter_func); 94 | auto res = alg_hnsw->searchKnnCloserFirst(p, k, &filter_func); 95 | assert(gd.size() == res.size()); 96 | size_t t = gd.size(); 97 | while (!gd.empty()) 98 | { 99 | assert(gd.top() == res[--t]); 100 | assert((gd.top().second % div_num) == 0); 101 | gd.pop(); 102 | } 103 | } 104 | 105 | delete alg_brute; 106 | delete alg_hnsw; 107 | } 108 | 109 | void test_none_filtering(hnswlib::BaseFilterFunctor &filter_func, size_t label_id_start) 110 | { 111 | int d = 4; 112 | idx_t n = 100; 113 | idx_t nq = 10; 114 | size_t k = 10; 115 | 116 | std::vector data(n * d); 117 | std::vector query(nq * d); 118 | 119 | std::mt19937 rng; 120 | rng.seed(47); 121 | std::uniform_real_distribution<> distrib; 122 | 123 | for (idx_t i = 0; i < n * d; ++i) 124 | { 125 | data[i] = distrib(rng); 126 | } 127 | for (idx_t i = 0; i < nq * d; ++i) 
128 | { 129 | query[i] = distrib(rng); 130 | } 131 | 132 | hnswlib::L2Space space(d); 133 | hnswlib::AlgorithmInterface *alg_brute = new hnswlib::BruteforceSearch(&space, 2 * n); 134 | hnswlib::AlgorithmInterface *alg_hnsw = new hnswlib::HierarchicalNSW(&space, 2 * n); 135 | 136 | for (size_t i = 0; i < n; ++i) 137 | { 138 | // `label_id_start` is used to ensure that the returned IDs are labels and not internal IDs 139 | alg_brute->addPoint(data.data() + d * i, label_id_start + i); 140 | alg_hnsw->addPoint(data.data() + d * i, label_id_start + i); 141 | } 142 | 143 | // test searchKnnCloserFirst of BruteforceSearch with filtering 144 | for (size_t j = 0; j < nq; ++j) 145 | { 146 | const void *p = query.data() + j * d; 147 | auto gd = alg_brute->searchKnn(p, k, &filter_func); 148 | auto res = alg_brute->searchKnnCloserFirst(p, k, &filter_func); 149 | assert(gd.size() == res.size()); 150 | assert(0 == gd.size()); 151 | } 152 | 153 | // test searchKnnCloserFirst of hnsw with filtering 154 | for (size_t j = 0; j < nq; ++j) 155 | { 156 | const void *p = query.data() + j * d; 157 | auto gd = alg_hnsw->searchKnn(p, k, &filter_func); 158 | auto res = alg_hnsw->searchKnnCloserFirst(p, k, &filter_func); 159 | assert(gd.size() == res.size()); 160 | assert(0 == gd.size()); 161 | } 162 | 163 | delete alg_brute; 164 | delete alg_hnsw; 165 | } 166 | 167 | } // namespace 168 | 169 | class CustomFilterFunctor : public hnswlib::BaseFilterFunctor 170 | { 171 | std::unordered_set allowed_values; 172 | 173 | public: 174 | explicit CustomFilterFunctor(const std::unordered_set &values) : allowed_values(values) {} 175 | 176 | bool operator()(idx_t id) 177 | { 178 | return allowed_values.count(id) != 0; 179 | } 180 | }; 181 | 182 | int main() 183 | { 184 | std::cout << "Testing ..." 
<< std::endl; 185 | 186 | // some of the elements are filtered 187 | PickDivisibleIds pickIdsDivisibleByThree(3); 188 | test_some_filtering(pickIdsDivisibleByThree, 3, 17); 189 | PickDivisibleIds pickIdsDivisibleBySeven(7); 190 | test_some_filtering(pickIdsDivisibleBySeven, 7, 17); 191 | 192 | // all of the elements are filtered 193 | PickNothing pickNothing; 194 | test_none_filtering(pickNothing, 17); 195 | 196 | // functor style which can capture context 197 | CustomFilterFunctor pickIdsDivisibleByThirteen({26, 39, 52, 65}); 198 | test_some_filtering(pickIdsDivisibleByThirteen, 13, 21); 199 | 200 | std::cout << "Test ok" << std::endl; 201 | 202 | return 0; 203 | } 204 | -------------------------------------------------------------------------------- /tests/cpp/update_gen_data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | 4 | 5 | def normalized(a, axis=-1, order=2): 6 | l2 = np.atleast_1d(np.linalg.norm(a, order, axis)) 7 | l2[l2 == 0] = 1 8 | return a / np.expand_dims(l2, axis) 9 | 10 | 11 | N = 100000 12 | dummy_data_multiplier = 3 13 | N_queries = 1000 14 | d = 8 15 | K = 5 16 | 17 | np.random.seed(1) 18 | 19 | print("Generating data...") 20 | batches_dummy = [ 21 | normalized(np.float32(np.random.random((N, d)))) 22 | for _ in range(dummy_data_multiplier) 23 | ] 24 | batch_final = normalized(np.float32(np.random.random((N, d)))) 25 | queries = normalized(np.float32(np.random.random((N_queries, d)))) 26 | print("Computing distances...") 27 | dist = np.dot(queries, batch_final.T) 28 | topk = np.argsort(-dist)[:, :K] 29 | print("Saving...") 30 | 31 | try: 32 | os.mkdir("data") 33 | except OSError as e: 34 | pass 35 | 36 | for idx, batch_dummy in enumerate(batches_dummy): 37 | batch_dummy.tofile("data/batch_dummy_%02d.bin" % idx) 38 | batch_final.tofile("data/batch_final.bin") 39 | queries.tofile("data/queries.bin") 40 | np.int32(topk).tofile("data/gt.bin") 41 | with 
open("data/config.txt", "w") as file: 42 | file.write("%d %d %d %d %d" % (N, dummy_data_multiplier, N_queries, d, K)) 43 | -------------------------------------------------------------------------------- /tests/cpp/updates_test.cpp: -------------------------------------------------------------------------------- 1 | #include "../../hnswlib/hnswlib.h" 2 | #include 3 | 4 | class StopW 5 | { 6 | std::chrono::steady_clock::time_point time_begin; 7 | 8 | public: 9 | StopW() 10 | { 11 | time_begin = std::chrono::steady_clock::now(); 12 | } 13 | 14 | float getElapsedTimeMicro() 15 | { 16 | std::chrono::steady_clock::time_point time_end = std::chrono::steady_clock::now(); 17 | return (std::chrono::duration_cast(time_end - time_begin).count()); 18 | } 19 | 20 | void reset() 21 | { 22 | time_begin = std::chrono::steady_clock::now(); 23 | } 24 | }; 25 | 26 | /* 27 | * replacement for the openmp '#pragma omp parallel for' directive 28 | * only handles a subset of functionality (no reductions etc) 29 | * Process ids from start (inclusive) to end (EXCLUSIVE) 30 | * 31 | * The method is borrowed from nmslib 32 | */ 33 | template 34 | inline void ParallelFor(size_t start, size_t end, size_t numThreads, Function fn) 35 | { 36 | if (numThreads <= 0) 37 | { 38 | numThreads = std::thread::hardware_concurrency(); 39 | } 40 | 41 | if (numThreads == 1) 42 | { 43 | for (size_t id = start; id < end; id++) 44 | { 45 | fn(id, 0); 46 | } 47 | } 48 | else 49 | { 50 | std::vector threads; 51 | std::atomic current(start); 52 | 53 | // keep track of exceptions in threads 54 | // https://stackoverflow.com/a/32428427/1713196 55 | std::exception_ptr lastException = nullptr; 56 | std::mutex lastExceptMutex; 57 | 58 | for (size_t threadId = 0; threadId < numThreads; ++threadId) 59 | { 60 | threads.push_back(std::thread([&, threadId] 61 | { 62 | while (true) { 63 | size_t id = current.fetch_add(1); 64 | 65 | if ((id >= end)) { 66 | break; 67 | } 68 | 69 | try { 70 | fn(id, threadId); 71 | } catch (...) 
{ 72 | std::unique_lock lastExcepLock(lastExceptMutex); 73 | lastException = std::current_exception(); 74 | /* 75 | * This will work even when current is the largest value that 76 | * size_t can fit, because fetch_add returns the previous value 77 | * before the increment (what will result in overflow 78 | * and produce 0 instead of current + 1). 79 | */ 80 | current = end; 81 | break; 82 | } 83 | } })); 84 | } 85 | for (auto &thread : threads) 86 | { 87 | thread.join(); 88 | } 89 | if (lastException) 90 | { 91 | std::rethrow_exception(lastException); 92 | } 93 | } 94 | } 95 | 96 | template 97 | std::vector load_batch(std::string path, int size) 98 | { 99 | std::cout << "Loading " << path << "..."; 100 | // float or int32 (python) 101 | assert(sizeof(datatype) == 4); 102 | 103 | std::ifstream file; 104 | file.open(path, std::ios::binary); 105 | if (!file.is_open()) 106 | { 107 | std::cout << "Cannot open " << path << "\n"; 108 | exit(1); 109 | } 110 | std::vector batch(size); 111 | 112 | file.read((char *)batch.data(), size * sizeof(float)); 113 | std::cout << " DONE\n"; 114 | return batch; 115 | } 116 | 117 | template 118 | static float 119 | test_approx(std::vector &queries, size_t qsize, hnswlib::HierarchicalNSW &appr_alg, size_t vecdim, 120 | std::vector> &answers, size_t K) 121 | { 122 | size_t correct = 0; 123 | size_t total = 0; 124 | 125 | for (int i = 0; i < qsize; i++) 126 | { 127 | std::priority_queue> result = appr_alg.searchKnn((char *)(queries.data() + vecdim * i), K); 128 | total += K; 129 | while (result.size()) 130 | { 131 | if (answers[i].find(result.top().second) != answers[i].end()) 132 | { 133 | correct++; 134 | } 135 | else 136 | { 137 | } 138 | result.pop(); 139 | } 140 | } 141 | return 1.0f * correct / total; 142 | } 143 | 144 | static void 145 | test_vs_recall( 146 | std::vector &queries, 147 | size_t qsize, 148 | hnswlib::HierarchicalNSW &appr_alg, 149 | size_t vecdim, 150 | std::vector> &answers, 151 | size_t k) 152 | { 153 | 154 | 
std::vector efs = {1}; 155 | for (int i = k; i < 30; i++) 156 | { 157 | efs.push_back(i); 158 | } 159 | for (int i = 30; i < 400; i += 10) 160 | { 161 | efs.push_back(i); 162 | } 163 | for (int i = 1000; i < 100000; i += 5000) 164 | { 165 | efs.push_back(i); 166 | } 167 | std::cout << "ef\trecall\ttime\thops\tdistcomp\n"; 168 | 169 | bool test_passed = false; 170 | for (size_t ef : efs) 171 | { 172 | appr_alg.setEf(ef); 173 | 174 | appr_alg.metric_hops = 0; 175 | appr_alg.metric_distance_computations = 0; 176 | StopW stopw = StopW(); 177 | 178 | float recall = test_approx(queries, qsize, appr_alg, vecdim, answers, k); 179 | float time_us_per_query = stopw.getElapsedTimeMicro() / qsize; 180 | float distance_comp_per_query = appr_alg.metric_distance_computations / (1.0f * qsize); 181 | float hops_per_query = appr_alg.metric_hops / (1.0f * qsize); 182 | 183 | std::cout << ef << "\t" << recall << "\t" << time_us_per_query << "us \t" << hops_per_query << "\t" << distance_comp_per_query << "\n"; 184 | if (recall > 0.99) 185 | { 186 | test_passed = true; 187 | std::cout << "Recall is over 0.99! 
" << recall << "\t" << time_us_per_query << "us \t" << hops_per_query << "\t" << distance_comp_per_query << "\n"; 188 | break; 189 | } 190 | } 191 | if (!test_passed) 192 | { 193 | std::cerr << "Test failed\n"; 194 | exit(1); 195 | } 196 | } 197 | 198 | int main(int argc, char **argv) 199 | { 200 | int M = 16; 201 | int efConstruction = 200; 202 | int num_threads = std::thread::hardware_concurrency(); 203 | 204 | bool update = false; 205 | 206 | if (argc == 2) 207 | { 208 | if (std::string(argv[1]) == "update") 209 | { 210 | update = true; 211 | std::cout << "Updates are on\n"; 212 | } 213 | else 214 | { 215 | std::cout << "Usage ./test_updates [update]\n"; 216 | exit(1); 217 | } 218 | } 219 | else if (argc > 2) 220 | { 221 | std::cout << "Usage ./test_updates [update]\n"; 222 | exit(1); 223 | } 224 | 225 | std::string path = "../tests/cpp/data/"; 226 | 227 | int N; 228 | int dummy_data_multiplier; 229 | int N_queries; 230 | int d; 231 | int K; 232 | { 233 | std::ifstream configfile; 234 | configfile.open(path + "/config.txt"); 235 | if (!configfile.is_open()) 236 | { 237 | std::cout << "Cannot open config.txt\n"; 238 | return 1; 239 | } 240 | configfile >> N >> dummy_data_multiplier >> N_queries >> d >> K; 241 | 242 | printf("Loaded config: N=%d, d_mult=%d, Nq=%d, dim=%d, K=%d\n", N, dummy_data_multiplier, N_queries, d, K); 243 | } 244 | 245 | hnswlib::L2Space l2space(d); 246 | hnswlib::HierarchicalNSW appr_alg(&l2space, N + 1, M, efConstruction); 247 | 248 | std::vector dummy_batch = load_batch(path + "batch_dummy_00.bin", N * d); 249 | 250 | // Adding enterpoint: 251 | 252 | appr_alg.addPoint((void *)dummy_batch.data(), (size_t)0); 253 | 254 | StopW stopw = StopW(); 255 | 256 | if (update) 257 | { 258 | std::cout << "Update iteration 0\n"; 259 | 260 | ParallelFor(1, N, num_threads, [&](size_t i, size_t threadId) 261 | { appr_alg.addPoint((void *)(dummy_batch.data() + i * d), i); }); 262 | appr_alg.checkIntegrity(); 263 | 264 | ParallelFor(1, N, num_threads, 
[&](size_t i, size_t threadId) 265 | { appr_alg.addPoint((void *)(dummy_batch.data() + i * d), i); }); 266 | appr_alg.checkIntegrity(); 267 | 268 | for (int b = 1; b < dummy_data_multiplier; b++) 269 | { 270 | std::cout << "Update iteration " << b << "\n"; 271 | char cpath[1024]; 272 | sprintf(cpath, "batch_dummy_%02d.bin", b); 273 | std::vector dummy_batchb = load_batch(path + cpath, N * d); 274 | 275 | ParallelFor(0, N, num_threads, [&](size_t i, size_t threadId) 276 | { appr_alg.addPoint((void *)(dummy_batch.data() + i * d), i); }); 277 | appr_alg.checkIntegrity(); 278 | } 279 | } 280 | 281 | std::cout << "Inserting final elements\n"; 282 | std::vector final_batch = load_batch(path + "batch_final.bin", N * d); 283 | 284 | stopw.reset(); 285 | ParallelFor(0, N, num_threads, [&](size_t i, size_t threadId) 286 | { appr_alg.addPoint((void *)(final_batch.data() + i * d), i); }); 287 | std::cout << "Finished. Time taken:" << stopw.getElapsedTimeMicro() * 1e-6 << " s\n"; 288 | std::cout << "Running tests\n"; 289 | std::vector queries_batch = load_batch(path + "queries.bin", N_queries * d); 290 | 291 | std::vector gt = load_batch(path + "gt.bin", N_queries * K); 292 | 293 | std::vector> answers(N_queries); 294 | for (int i = 0; i < N_queries; i++) 295 | { 296 | for (int j = 0; j < K; j++) 297 | { 298 | answers[i].insert(gt[i * K + j]); 299 | } 300 | } 301 | 302 | for (int i = 0; i < 3; i++) 303 | { 304 | std::cout << "Test iteration " << i << "\n"; 305 | test_vs_recall(queries_batch, N_queries, appr_alg, d, answers, K); 306 | } 307 | 308 | return 0; 309 | } 310 | -------------------------------------------------------------------------------- /tests/python/bindings_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import unittest 3 | 4 | import numpy as np 5 | 6 | import hnswlib 7 | 8 | 9 | class RandomSelfTestCase(unittest.TestCase): 10 | def testRandomSelf(self): 11 | dim = 16 12 | num_elements = 10000 13 | 14 
| # Generating sample data 15 | data = np.float32(np.random.random((num_elements, dim))) 16 | 17 | # Declaring index 18 | p = hnswlib.Index(space="l2", dim=dim) # possible options are l2, cosine or ip 19 | 20 | # Initiating index 21 | # max_elements - the maximum number of elements, should be known beforehand 22 | # (probably will be made optional in the future) 23 | # 24 | # ef_construction - controls index search speed/build speed tradeoff 25 | # M - is tightly connected with internal dimensionality of the data 26 | # strongly affects the memory consumption 27 | 28 | p.init_index(max_elements=num_elements, ef_construction=100, M=16) 29 | 30 | # Controlling the recall by setting ef: 31 | # higher ef leads to better accuracy, but slower search 32 | p.set_ef(10) 33 | 34 | p.set_num_threads(4) # by default using all available cores 35 | 36 | # We split the data in two batches: 37 | data1 = data[: num_elements // 2] 38 | data2 = data[num_elements // 2 :] 39 | 40 | print("Adding first batch of %d elements" % (len(data1))) 41 | p.add_items(data1) 42 | 43 | # Query the elements for themselves and measure recall: 44 | labels, distances = p.knn_query(data1, k=1) 45 | self.assertAlmostEqual( 46 | np.mean(labels.reshape(-1) == np.arange(len(data1))), 1.0, 3 47 | ) 48 | 49 | # Serializing and deleting the index: 50 | index_path = "first_half.bin" 51 | print("Saving index to '%s'" % index_path) 52 | p.save_index(index_path) 53 | del p 54 | 55 | # Re-initiating, loading the index 56 | p = hnswlib.Index(space="l2", dim=dim) # you can change the sa 57 | 58 | print("\nLoading index from '%s'\n" % index_path) 59 | p.load_index(index_path) 60 | 61 | print("Adding the second batch of %d elements" % (len(data2))) 62 | p.add_items(data2) 63 | 64 | # Query the elements for themselves and measure recall: 65 | labels, distances = p.knn_query(data, k=1) 66 | 67 | self.assertAlmostEqual( 68 | np.mean(labels.reshape(-1) == np.arange(len(data))), 1.0, 3 69 | ) 70 | 71 | os.remove(index_path) 
class RandomSelfTestCase(unittest.TestCase):
    def testRandomSelf(self):
        """Check that knn_query honors a python filter on both the HNSW and
        the brute-force index."""
        dim = 16
        num_elements = 10000

        # Random sample data to index.
        data = np.float32(np.random.random((num_elements, dim)))

        # One HNSW index plus a brute-force reference index
        # (possible spaces are l2, cosine or ip).
        hnsw_index = hnswlib.Index(space="l2", dim=dim)
        bf_index = hnswlib.BFIndex(space="l2", dim=dim)

        # max_elements must be known up front; ef_construction trades build
        # speed against quality; M drives graph connectivity and strongly
        # affects memory consumption.
        hnsw_index.init_index(max_elements=num_elements, ef_construction=100, M=16)
        bf_index.init_index(max_elements=num_elements)

        # Higher ef gives better accuracy at the cost of slower search.
        hnsw_index.set_ef(10)

        hnsw_index.set_num_threads(4)  # by default using all available cores

        print("Adding %d elements" % (len(data)))
        hnsw_index.add_items(data)
        bf_index.add_items(data)

        # Every point should find itself as its own nearest neighbor.
        labels, distances = hnsw_index.knn_query(data, k=1)
        self.assertAlmostEqual(
            np.mean(labels.reshape(-1) == np.arange(len(data))), 1.0, 3
        )

        print("Querying only even elements")

        def filter_function(id):
            # Accept only even labels.
            return id % 2 == 0

        # Warning: search with a filter works slow in python in multithreaded
        # mode, therefore we set num_threads=1.
        labels, distances = hnsw_index.knn_query(
            data, k=1, num_threads=1, filter=filter_function
        )
        # Exactly the even half of the points can still match themselves.
        self.assertAlmostEqual(
            np.mean(labels.reshape(-1) == np.arange(len(data))), 0.5, 3
        )
        # Verify that only even labels were returned.
        self.assertTrue(np.max(np.mod(labels, 2)) == 0)

        labels, distances = bf_index.knn_query(data, k=1, filter=filter_function)
        self.assertEqual(np.mean(labels.reshape(-1) == np.arange(len(data))), 0.5)
class RandomSelfTestCase(unittest.TestCase):
    def testGettingItems(self):
        """Verify get_items round-trips stored vectors by label for every
        supported space, and rejects invalid lookups."""
        print("\n**** Getting the data by label test ****\n")

        dim = 16
        num_elements = 10000

        # Random sample data with explicit integer labels.
        data = np.float32(np.random.random((num_elements, dim)))
        labels = np.arange(0, num_elements)

        for space in ["l2", "ip", "cosine"]:
            # Fresh index per metric (possible spaces are l2, cosine or ip).
            p = hnswlib.Index(space=space, dim=dim)

            # max_elements must be known up front; ef_construction trades
            # build speed against quality; M strongly affects memory use.
            p.init_index(max_elements=num_elements, ef_construction=100, M=16)

            # Higher ef gives better accuracy at the cost of slower search.
            p.set_ef(100)

            p.set_num_threads(4)  # by default using all available cores

            # Before adding anything, getting any labels should fail.
            self.assertRaises(Exception, lambda: p.get_items(labels))

            print("Adding all elements (%d)" % (len(data)))
            p.add_items(data, labels)

            # A bare scalar label must be rejected with a ValueError.
            self.assertRaises(ValueError, lambda: p.get_items(labels[0]))

            # After adding them, all labels should be retrievable and the
            # vectors must come back (approximately) unchanged.
            returned_items = p.get_items(labels)
            self.assertTrue(np.allclose(data, returned_items, atol=1e-6))
print("Adding first batch of %d elements" % (len(data1))) 47 | p.add_items(data1) 48 | 49 | # Query the elements for themselves and measure recall: 50 | labels, distances = p.knn_query(data1, k=1) 51 | 52 | items = p.get_items(labels) 53 | 54 | # Check the recall: 55 | self.assertAlmostEqual( 56 | np.mean(labels.reshape(-1) == np.arange(len(data1))), 1.0, 3 57 | ) 58 | 59 | # Check that the returned element data is correct: 60 | diff_with_gt_labels = np.mean(np.abs(data1 - items)) 61 | self.assertAlmostEqual(diff_with_gt_labels, 0, delta=1e-4) 62 | 63 | # Serializing and deleting the index. 64 | # We need the part to check that serialization is working properly. 65 | 66 | index_path = "first_half.bin" 67 | print("Saving index to '%s'" % index_path) 68 | p.save_index(index_path) 69 | print("Saved. Deleting...") 70 | del p 71 | print("Deleted") 72 | 73 | print("\n**** Mark delete test ****\n") 74 | # Re-initiating, loading the index 75 | print("Re-initiating") 76 | p = hnswlib.Index(space="l2", dim=dim) 77 | 78 | print("\nLoading index from '%s'\n" % index_path) 79 | p.load_index(index_path) 80 | p.set_ef(100) 81 | 82 | print("Adding the second batch of %d elements" % (len(data2))) 83 | p.add_items(data2) 84 | 85 | # Query the elements for themselves and measure recall: 86 | labels, distances = p.knn_query(data, k=1) 87 | items = p.get_items(labels) 88 | 89 | # Check the recall: 90 | self.assertAlmostEqual( 91 | np.mean(labels.reshape(-1) == np.arange(len(data))), 1.0, 3 92 | ) 93 | 94 | # Check that the returned element data is correct: 95 | diff_with_gt_labels = np.mean(np.abs(data - items)) 96 | self.assertAlmostEqual( 97 | diff_with_gt_labels, 0, delta=1e-4 98 | ) # deleting index. 
99 | 100 | # Checking that all labels are returned correctly: 101 | sorted_labels = sorted(p.get_ids_list()) 102 | self.assertEqual( 103 | np.sum(~np.asarray(sorted_labels) == np.asarray(range(num_elements))), 0 104 | ) 105 | 106 | # Delete data1 107 | labels1_deleted, _ = p.knn_query(data1, k=1) 108 | # delete probable duplicates from nearest neighbors 109 | labels1_deleted_no_dup = set(labels1_deleted.flatten()) 110 | for l in labels1_deleted_no_dup: 111 | p.mark_deleted(l) 112 | labels2, _ = p.knn_query(data2, k=1) 113 | items = p.get_items(labels2) 114 | diff_with_gt_labels = np.mean(np.abs(data2 - items)) 115 | self.assertAlmostEqual(diff_with_gt_labels, 0, delta=1e-3) 116 | 117 | labels1_after, _ = p.knn_query(data1, k=1) 118 | for la in labels1_after: 119 | if la[0] in labels1_deleted_no_dup: 120 | print(f"Found deleted label {la[0]} during knn search") 121 | self.assertTrue(False) 122 | print("All the data in data1 are removed") 123 | 124 | # Checking saving/loading index with elements marked as deleted 125 | del_index_path = "with_deleted.bin" 126 | p.save_index(del_index_path) 127 | p = hnswlib.Index(space="l2", dim=dim) 128 | p.load_index(del_index_path) 129 | p.set_ef(100) 130 | 131 | labels1_after, _ = p.knn_query(data1, k=1) 132 | for la in labels1_after: 133 | if la[0] in labels1_deleted_no_dup: 134 | print( 135 | f"Found deleted label {la[0]} during knn search after index loading" 136 | ) 137 | self.assertTrue(False) 138 | 139 | # Unmark deleted data 140 | for l in labels1_deleted_no_dup: 141 | p.unmark_deleted(l) 142 | labels_restored, _ = p.knn_query(data1, k=1) 143 | self.assertAlmostEqual( 144 | np.mean(labels_restored.reshape(-1) == np.arange(len(data1))), 1.0, 3 145 | ) 146 | print("All the data in data1 are restored") 147 | 148 | os.remove(index_path) 149 | os.remove(del_index_path) 150 | -------------------------------------------------------------------------------- /tests/python/bindings_test_metadata.py: 
class RandomSelfTestCase(unittest.TestCase):
    def testMetadata(self):
        """Check the metadata accessor methods and read-only properties of
        a populated index."""
        dim = 16
        num_elements = 10000

        # Random sample data to index.
        data = np.float32(np.random.random((num_elements, dim)))

        # Possible spaces are l2, cosine or ip.
        p = hnswlib.Index(space="l2", dim=dim)

        # max_elements must be known up front; ef_construction trades build
        # speed against quality; M strongly affects memory consumption.
        p.init_index(max_elements=num_elements, ef_construction=100, M=16)

        # Higher ef gives better accuracy at the cost of slower search.
        p.set_ef(100)

        p.set_num_threads(4)  # by default using all available cores

        print("Adding all elements (%d)" % (len(data)))
        p.add_items(data)

        # Accessor methods.
        self.assertEqual(p.get_max_elements(), num_elements)
        self.assertEqual(p.get_current_count(), num_elements)

        # Properties must mirror the construction parameters.
        self.assertEqual(p.space, "l2")
        self.assertEqual(p.dim, dim)
        self.assertEqual(p.M, 16)
        self.assertEqual(p.ef_construction, 100)
        self.assertEqual(p.max_elements, num_elements)
        self.assertEqual(p.element_count, num_elements)
class RandomSelfTestCase(unittest.TestCase):
    def testPersistentIndex(self):
        """Persist an index to disk and verify a freshly loaded copy returns
        the same items and the same query results."""
        print("\n**** Using a persistent index test ****\n")

        dim = 16
        num_elements = 10000

        # Random sample data with explicit labels.
        data = np.float32(np.random.random((num_elements, dim)))
        labels = np.arange(0, num_elements)

        p = hnswlib.Index(space="l2", dim=dim)

        # The persistence directory must exist before init.
        if not os.path.exists("test_dir"):
            os.makedirs("test_dir")
        p.init_index(
            max_elements=num_elements,
            ef_construction=100,
            M=16,
            is_persistent_index=True,
            persistence_location="test_dir",
        )
        p.set_num_threads(4)

        print("Adding all elements (%d)" % (len(data)))
        p.add_items(data, labels)
        p.persist_dirty()  # flush pending changes to disk

        # Load the persisted index into a second object.
        p2 = hnswlib.Index(space="l2", dim=dim)
        p2.load_index("test_dir", is_persistent_index=True)
        returned_items = p2.get_items(labels)
        self.assertTrue(np.allclose(data, returned_items, atol=1e-6))

        # Both indices must agree on query results for a random query.
        query = np.float32(np.random.random((1, dim)))
        labels, distances = p.knn_query(query, k=10)
        labels2, distances2 = p2.knn_query(query, k=10)
        self.assertTrue((labels == labels2).all())
        self.assertTrue(np.allclose(distances, distances2, atol=1e-6))
def get_dist(metric, pt1, pt2):
    """Reference distance between two points under the given metric.

    Mirrors the distance definitions used by hnswlib: squared L2,
    inner-product distance (1 - dot) and cosine distance.

    Raises:
        ValueError: if *metric* is not one of "l2", "ip", "cosine".
    """
    if metric == "l2":
        return np.sum((pt1 - pt2) ** 2)
    elif metric == "ip":
        return 1.0 - np.sum(np.multiply(pt1, pt2))
    elif metric == "cosine":
        return (
            1.0
            - np.sum(np.multiply(pt1, pt2))
            / (np.sum(pt1**2) * np.sum(pt2**2)) ** 0.5
        )
    # BUG FIX: an unknown metric previously fell through and silently
    # returned None, which would poison the distance matrix downstream.
    raise ValueError(f"unsupported metric: {metric!r}")


def brute_force_distances(metric, items, query_items, k):
    """Exact k-NN labels and distances computed by brute force.

    Returns (labels, dists), each of shape (num_queries, k), sorted by
    ascending distance.
    """
    dists = np.zeros((query_items.shape[0], items.shape[0]))
    for ii in range(items.shape[0]):
        for jj in range(query_items.shape[0]):
            dists[jj, ii] = get_dist(metric, items[ii, :], query_items[jj, :])

    # equivalent, but faster: np.argpartition(dists, range(k), axis=1)
    labels = np.argsort(dists, axis=1)
    # equivalent, but faster: np.partition(dists, range(k), axis=1)
    dists = np.sort(dists, axis=1)

    return labels[:, :k], dists[:, :k]


def check_ann_results(
    self,
    metric,
    items,
    query_items,
    k,
    ann_l,
    ann_d,
    err_thresh=0,
    total_thresh=0,
    dists_thresh=0,
):
    """Compare ANN results against brute force within given tolerances.

    err_thresh:   max missing labels tolerated per query before the query
                  counts as incorrect.
    total_thresh: max number of incorrect queries allowed overall.
    dists_thresh: max number of distance values allowed to deviate from
                  the brute-force values by more than sqrt(1e-3).
    """
    brute_l, brute_d = brute_force_distances(metric, items, query_items, k)
    err_total = 0
    for jj in range(query_items.shape[0]):
        # Labels brute force found that the ANN search missed.
        err = np.sum(np.isin(brute_l[jj, :], ann_l[jj, :], invert=True))
        if err > 0:
            print(
                f"Warning: {err} labels are missing from ann results (k={k}, err_thresh={err_thresh})"
            )

        if err > err_thresh:
            err_total += 1

    self.assertLessEqual(
        err_total,
        total_thresh,
        f"Error: knn_query returned incorrect labels for {err_total} items (k={k})",
    )

    wrong_dists = np.sum(((brute_d - ann_d) ** 2.0) > 1e-3)
    if wrong_dists > 0:
        dists_count = brute_d.shape[0] * brute_d.shape[1]
        print(
            f"Warning: {wrong_dists} ann distance values are different from brute-force values (total # of values={dists_count}, dists_thresh={dists_thresh})"
        )

    self.assertLessEqual(
        wrong_dists,
        dists_thresh,
        msg=f"Error: {wrong_dists} ann distance values are different from brute-force values",
    )
def test_space_main(self, space, dim):
    """Shared body for the pickle tests.

    Verifies that pickling an Index at each lifecycle stage — p0
    (uninitialized), p1 (initialized, empty), p2 (populated) — preserves
    items, query results, and the ef/M/ef_construction parameters.
    """
    # Generating sample data
    data = np.float32(np.random.random((self.num_elements, dim)))
    test_data = np.float32(np.random.random((self.num_test_elements, dim)))

    # Declaring index (possible spaces are l2, cosine or ip)
    p = hnswlib.Index(space=space, dim=dim)
    print(f"Running pickle tests for {p}")

    p.num_threads = self.num_threads  # by default using all available cores

    p0 = pickle.loads(pickle.dumps(p))  # pickle un-initialized Index
    p.init_index(
        max_elements=self.num_elements, ef_construction=self.ef_construction, M=self.M
    )
    p0.init_index(
        max_elements=self.num_elements, ef_construction=self.ef_construction, M=self.M
    )

    p.ef = self.ef
    p0.ef = self.ef

    p1 = pickle.loads(pickle.dumps(p))  # pickle Index before adding items

    # add items to ann index p, p0, p1
    p.add_items(data)
    p1.add_items(data)
    p0.add_items(data)

    # BUG FIX (comment only): p2 is pickled *after* adding items, not before.
    p2 = pickle.loads(pickle.dumps(p))  # pickle Index after adding items

    self.assertTrue(
        np.allclose(p.get_items(), p0.get_items()), "items for p and p0 must be same"
    )
    self.assertTrue(
        np.allclose(p0.get_items(), p1.get_items()), "items for p0 and p1 must be same"
    )
    self.assertTrue(
        np.allclose(p1.get_items(), p2.get_items()), "items for p1 and p2 must be same"
    )

    # Test if returned distances are same
    l, d = p.knn_query(test_data, k=self.k)
    l0, d0 = p0.knn_query(test_data, k=self.k)
    l1, d1 = p1.knn_query(test_data, k=self.k)
    l2, d2 = p2.knn_query(test_data, k=self.k)

    # (removed redundant f-prefixes: the msg strings have no placeholders)
    self.assertLessEqual(
        np.sum(((d - d0) ** 2.0) > 1e-3),
        self.dists_err_thresh,
        msg="knn distances returned by p and p0 must match",
    )
    self.assertLessEqual(
        np.sum(((d0 - d1) ** 2.0) > 1e-3),
        self.dists_err_thresh,
        msg="knn distances returned by p0 and p1 must match",
    )
    self.assertLessEqual(
        np.sum(((d1 - d2) ** 2.0) > 1e-3),
        self.dists_err_thresh,
        msg="knn distances returned by p1 and p2 must match",
    )

    # check if ann results match brute-force search
    # allow for 2 labels to be missing from ann results
    check_ann_results(
        self,
        space,
        data,
        test_data,
        self.k,
        l,
        d,
        err_thresh=self.label_err_thresh,
        total_thresh=self.item_err_thresh,
        dists_thresh=self.dists_err_thresh,
    )

    check_ann_results(
        self,
        space,
        data,
        test_data,
        self.k,
        l2,
        d2,
        err_thresh=self.label_err_thresh,
        total_thresh=self.item_err_thresh,
        dists_thresh=self.dists_err_thresh,
    )

    # Check ef parameter value survived pickling
    self.assertEqual(p.ef, self.ef, "incorrect value of p.ef")
    self.assertEqual(p0.ef, self.ef, "incorrect value of p0.ef")
    self.assertEqual(p2.ef, self.ef, "incorrect value of p2.ef")
    self.assertEqual(p1.ef, self.ef, "incorrect value of p1.ef")

    # Check M parameter value survived pickling
    self.assertEqual(p.M, self.M, "incorrect value of p.M")
    self.assertEqual(p0.M, self.M, "incorrect value of p0.M")
    self.assertEqual(p1.M, self.M, "incorrect value of p1.M")
    self.assertEqual(p2.M, self.M, "incorrect value of p2.M")

    # Check ef_construction parameter value survived pickling
    self.assertEqual(
        p.ef_construction, self.ef_construction, "incorrect value of p.ef_construction"
    )
    self.assertEqual(
        p0.ef_construction,
        self.ef_construction,
        "incorrect value of p0.ef_construction",
    )
    self.assertEqual(
        p1.ef_construction,
        self.ef_construction,
        "incorrect value of p1.ef_construction",
    )
    self.assertEqual(
        p2.ef_construction,
        self.ef_construction,
        "incorrect value of p2.ef_construction",
    )
class PickleUnitTests(unittest.TestCase):
    """Runs test_space_main once per metric with shared tolerances."""

    def setUp(self):
        # Index construction parameters shared by all three metric tests.
        self.ef_construction = 200
        self.M = 32
        self.ef = 400

        self.num_elements = 1000
        self.num_test_elements = 100

        self.num_threads = 4
        self.k = 25

        # max number of missing labels allowed per test item
        self.label_err_thresh = 5
        # max number of items allowed with incorrect labels
        self.item_err_thresh = 5

        # For two distance matrices d1 and d2, dists_err_thresh caps the
        # number of value pairs allowed to differ between them,
        # i.e. the number of values with (d1 - d2)**2 > 1e-3.
        self.dists_err_thresh = 50

    def test_inner_product_space(self):
        test_space_main(self, "ip", 16)

    def test_l2_space(self):
        test_space_main(self, "l2", 53)

    def test_cosine_space(self):
        test_space_main(self, "cosine", 32)
class RandomSelfTestCase(unittest.TestCase):
    def testRandomSelf(self):
        """Build HNSW and brute-force indices over the same data, check HNSW
        recall against the exact results, then verify the brute-force index
        survives a save/load round trip unchanged."""
        dim = 32
        num_elements = 100000
        k = 10
        num_queries = 20

        recall_threshold = 0.95

        def measure_recall(labels_hnsw, labels_bf):
            # Fraction of exact neighbors that HNSW also found.
            # (Extracted: the original duplicated this triple loop twice.)
            correct = 0
            for i in range(num_queries):
                for label in labels_hnsw[i]:
                    for correct_label in labels_bf[i]:
                        if label == correct_label:
                            correct += 1
                            break
            return float(correct) / (k * num_queries)

        # Generating sample data
        data = np.float32(np.random.random((num_elements, dim)))

        # Declaring indices (possible spaces are l2, cosine or ip)
        hnsw_index = hnswlib.Index(space="l2", dim=dim)
        bf_index = hnswlib.BFIndex(space="l2", dim=dim)

        # max_elements is a hard capacity: insertion beyond it throws.
        # The capacity can be increased by saving/loading the index.
        # ef_construction controls the index search speed/build speed
        # tradeoff; M is tightly connected with internal dimensionality of
        # the data, strongly affects memory consumption (~M), and higher M
        # gives higher accuracy/run_time at fixed ef/efConstruction.
        hnsw_index.init_index(max_elements=num_elements, ef_construction=200, M=16)
        bf_index.init_index(max_elements=num_elements)

        # Higher ef gives better accuracy at the cost of slower search.
        hnsw_index.set_ef(200)

        # Threads used during batch search/construction in hnsw;
        # by default all available cores are used.
        hnsw_index.set_num_threads(4)

        print("Adding batch of %d elements" % (len(data)))
        hnsw_index.add_items(data)
        bf_index.add_items(data)

        print("Indices built")

        # Generating query data
        query_data = np.float32(np.random.random((num_queries, dim)))

        # Query the elements and measure recall:
        labels_hnsw, distances_hnsw = hnsw_index.knn_query(query_data, k)
        labels_bf, distances_bf = bf_index.knn_query(query_data, k)

        recall_before = measure_recall(labels_hnsw, labels_bf)
        print("recall is :", recall_before)
        self.assertGreater(recall_before, recall_threshold)

        # test serializing the brute force index
        index_path = "bf_index.bin"
        print("Saving index to '%s'" % index_path)
        bf_index.save_index(index_path)
        del bf_index

        # Re-initiating, loading the index
        bf_index = hnswlib.BFIndex(space="l2", dim=dim)

        print("\nLoading index from '%s'\n" % index_path)
        bf_index.load_index(index_path)

        # Query the brute force index again to verify that we get the
        # same results
        labels_bf, distances_bf = bf_index.knn_query(query_data, k)

        recall_after = measure_recall(labels_hnsw, labels_bf)
        print("recall after reloading is :", recall_after)

        # Reloading must not change the exact results at all.
        self.assertEqual(recall_before, recall_after)

        os.remove(index_path)
= hnswlib.Index(space="l2", dim=dim) 47 | hnsw_index.init_index( 48 | max_elements=max_num_elements, 49 | ef_construction=200, 50 | M=16, 51 | allow_replace_deleted=True, 52 | ) 53 | 54 | hnsw_index.set_ef(100) 55 | hnsw_index.set_num_threads(4) 56 | 57 | # Add batch 1 and 2 58 | print("Adding batch 1") 59 | hnsw_index.add_items(data1, labels1) 60 | print("Adding batch 2") 61 | hnsw_index.add_items(data2, labels2) # maximum number of elements is reached 62 | 63 | # Delete nearest neighbors of batch 2 64 | print("Deleting neighbors of batch 2") 65 | labels2_deleted, _ = hnsw_index.knn_query(data2, k=1) 66 | # delete probable duplicates from nearest neighbors 67 | labels2_deleted_no_dup = set(labels2_deleted.flatten()) 68 | num_duplicates = len(labels2_deleted) - len(labels2_deleted_no_dup) 69 | for l in labels2_deleted_no_dup: 70 | hnsw_index.mark_deleted(l) 71 | labels1_found, _ = hnsw_index.knn_query(data1, k=1) 72 | items = hnsw_index.get_items(labels1_found) 73 | diff_with_gt_labels = np.mean(np.abs(data1 - items)) 74 | self.assertAlmostEqual(diff_with_gt_labels, 0, delta=1e-3) 75 | 76 | labels2_after, _ = hnsw_index.knn_query(data2, k=1) 77 | for la in labels2_after: 78 | if la[0] in labels2_deleted_no_dup: 79 | print(f"Found deleted label {la[0]} during knn search") 80 | self.assertTrue(False) 81 | print("All the neighbors of data2 are removed") 82 | 83 | # Replace deleted elements 84 | print("Inserting batch 3 by replacing deleted elements") 85 | # Maximum number of elements is reached therefore we cannot add new items 86 | # but we can replace the deleted ones 87 | # Note: there may be less than num_elements elements. 
88 | # As we could delete less than num_elements because of duplicates 89 | labels3_tr = labels3[0 : labels3.shape[0] - num_duplicates] 90 | data3_tr = data3[0 : data3.shape[0] - num_duplicates] 91 | hnsw_index.add_items(data3_tr, labels3_tr, replace_deleted=True) 92 | 93 | # After replacing, all labels should be retrievable 94 | print("Checking that remaining labels are in index") 95 | # Get remaining data from batch 1 and batch 2 after deletion of elements 96 | remaining_labels = (set(labels1) | set(labels2)) - labels2_deleted_no_dup 97 | remaining_labels_list = list(remaining_labels) 98 | comb_data = np.concatenate((data1, data2), axis=0) 99 | remaining_data = comb_data[remaining_labels_list] 100 | 101 | returned_items = hnsw_index.get_items(remaining_labels_list) 102 | self.assertSequenceEqual(remaining_data.tolist(), returned_items) 103 | 104 | returned_items = hnsw_index.get_items(labels3_tr) 105 | self.assertSequenceEqual(data3_tr.tolist(), returned_items) 106 | 107 | # Check index serialization 108 | # Delete batch 3 109 | print("Deleting batch 3") 110 | for l in labels3_tr: 111 | hnsw_index.mark_deleted(l) 112 | 113 | # Save index 114 | index_path = "index.bin" 115 | print(f"Saving index to {index_path}") 116 | hnsw_index.save_index(index_path) 117 | del hnsw_index 118 | 119 | # Reinit and load the index 120 | hnsw_index = hnswlib.Index( 121 | space="l2", dim=dim 122 | ) # the space can be changed - keeps the data, alters the distance function. 
123 | hnsw_index.set_num_threads(4) 124 | print(f"Loading index from {index_path}") 125 | hnsw_index.load_index( 126 | index_path, max_elements=max_num_elements, allow_replace_deleted=True 127 | ) 128 | 129 | # Insert batch 4 130 | print("Inserting batch 4 by replacing deleted elements") 131 | labels4_tr = labels4[0 : labels4.shape[0] - num_duplicates] 132 | data4_tr = data4[0 : data4.shape[0] - num_duplicates] 133 | hnsw_index.add_items(data4_tr, labels4_tr, replace_deleted=True) 134 | 135 | # Check recall 136 | print("Checking recall") 137 | labels_found, _ = hnsw_index.knn_query(data4_tr, k=1) 138 | recall = np.mean(labels_found.reshape(-1) == labels4_tr) 139 | print(f"Recall for the 4 batch: {recall}") 140 | self.assertGreater(recall, recall_threshold) 141 | 142 | # Delete batch 4 143 | print("Deleting batch 4") 144 | for l in labels4_tr: 145 | hnsw_index.mark_deleted(l) 146 | 147 | print("Testing pickle serialization") 148 | hnsw_index_pckl = pickle.loads(pickle.dumps(hnsw_index)) 149 | del hnsw_index 150 | # Insert batch 3 151 | print("Inserting batch 3 by replacing deleted elements") 152 | hnsw_index_pckl.add_items(data3_tr, labels3_tr, replace_deleted=True) 153 | 154 | # Check recall 155 | print("Checking recall") 156 | labels_found, _ = hnsw_index_pckl.knn_query(data3_tr, k=1) 157 | recall = np.mean(labels_found.reshape(-1) == labels3_tr) 158 | print(f"Recall for the 3 batch: {recall}") 159 | self.assertGreater(recall, recall_threshold) 160 | 161 | os.remove(index_path) 162 | 163 | def test_recall_degradation(self): 164 | """ 165 | Compares recall of the index with replaced elements and without 166 | Measures recall degradation 167 | """ 168 | dim = 16 169 | num_elements = 10_000 170 | max_num_elements = 2 * num_elements 171 | query_size = 1_000 172 | k = 100 173 | 174 | recall_threshold = 0.98 175 | max_recall_diff = 0.02 176 | 177 | # Generating sample data 178 | print("Generating data") 179 | # batch 1 180 | first_id = 0 181 | last_id = num_elements 182 
| labels1 = np.arange(first_id, last_id) 183 | data1 = np.float32(np.random.random((num_elements, dim))) 184 | # batch 2 185 | first_id += num_elements 186 | last_id += num_elements 187 | labels2 = np.arange(first_id, last_id) 188 | data2 = np.float32(np.random.random((num_elements, dim))) 189 | # batch 3 190 | first_id += num_elements 191 | last_id += num_elements 192 | labels3 = np.arange(first_id, last_id) 193 | data3 = np.float32(np.random.random((num_elements, dim))) 194 | # query to test recall 195 | query_data = np.float32(np.random.random((query_size, dim))) 196 | 197 | # Declaring index 198 | hnsw_index_no_replace = hnswlib.Index(space="l2", dim=dim) 199 | hnsw_index_no_replace.init_index( 200 | max_elements=max_num_elements, 201 | ef_construction=200, 202 | M=16, 203 | allow_replace_deleted=False, 204 | ) 205 | hnsw_index_with_replace = hnswlib.Index(space="l2", dim=dim) 206 | hnsw_index_with_replace.init_index( 207 | max_elements=max_num_elements, 208 | ef_construction=200, 209 | M=16, 210 | allow_replace_deleted=True, 211 | ) 212 | 213 | bf_index = hnswlib.BFIndex(space="l2", dim=dim) 214 | bf_index.init_index(max_elements=max_num_elements) 215 | 216 | hnsw_index_no_replace.set_ef(100) 217 | hnsw_index_no_replace.set_num_threads(50) 218 | hnsw_index_with_replace.set_ef(100) 219 | hnsw_index_with_replace.set_num_threads(50) 220 | 221 | # Add data 222 | print("Adding data") 223 | hnsw_index_with_replace.add_items(data1, labels1) 224 | hnsw_index_with_replace.add_items( 225 | data2, labels2 226 | ) # maximum number of elements is reached 227 | bf_index.add_items(data1, labels1) 228 | bf_index.add_items(data3, labels3) # maximum number of elements is reached 229 | 230 | for l in labels2: 231 | hnsw_index_with_replace.mark_deleted(l) 232 | hnsw_index_with_replace.add_items(data3, labels3, replace_deleted=True) 233 | 234 | hnsw_index_no_replace.add_items(data1, labels1) 235 | hnsw_index_no_replace.add_items( 236 | data3, labels3 237 | ) # maximum number of 
import unittest

import numpy as np

import hnswlib


class RandomSelfTestCase(unittest.TestCase):
    def testRandomSelf(self):
        """Grow an index with resize_index() and verify recall and stored data.

        Runs 16 rounds with different seeds and thread counts: fills an index
        to half capacity, checks self-recall, resizes to full capacity, adds
        the rest, and re-checks recall, stored vectors, and the label set.
        """
        for idx in range(16):
            print("\n**** Index resize test ****\n")

            np.random.seed(idx)
            dim = 16
            num_elements = 10000

            # Generating sample data
            data = np.float32(np.random.random((num_elements, dim)))

            # Declaring index; possible space options are l2, cosine or ip
            p = hnswlib.Index(space="l2", dim=dim)

            # Start at half capacity so the index must be resized later.
            # ef_construction - controls index search speed/build speed tradeoff
            # M - is tightly connected with internal dimensionality of the data;
            #     strongly affects the memory consumption
            p.init_index(max_elements=num_elements // 2, ef_construction=100, M=16)

            # Controlling the recall by setting ef:
            # higher ef leads to better accuracy, but slower search
            p.set_ef(20)

            p.set_num_threads(idx % 8)  # by default using all available cores

            # We split the data in two batches:
            data1 = data[: num_elements // 2]
            data2 = data[num_elements // 2 :]

            print("Adding first batch of %d elements" % (len(data1)))
            p.add_items(data1)

            # Query the elements for themselves and measure recall:
            labels, distances = p.knn_query(data1, k=1)

            items = p.get_items(list(range(len(data1))))

            # Check the recall:
            self.assertAlmostEqual(
                np.mean(labels.reshape(-1) == np.arange(len(data1))), 1.0, 3
            )

            # Check that the returned element data is correct:
            diff_with_gt_labels = np.max(np.abs(data1 - items))
            self.assertAlmostEqual(diff_with_gt_labels, 0, delta=1e-4)

            print("Resizing the index")
            p.resize_index(num_elements)

            print("Adding the second batch of %d elements" % (len(data2)))
            p.add_items(data2)

            # Query the elements for themselves and measure recall:
            labels, distances = p.knn_query(data, k=1)
            items = p.get_items(list(range(num_elements)))

            # Check the recall:
            self.assertAlmostEqual(
                np.mean(labels.reshape(-1) == np.arange(len(data))), 1.0, 3
            )

            # Check that the returned element data is correct:
            diff_with_gt_labels = np.max(np.abs(data - items))
            self.assertAlmostEqual(diff_with_gt_labels, 0, delta=1e-4)

            # Checking that all labels are returned correctly.
            # BUG FIX: the original asserted
            #   np.sum(~np.asarray(sorted_labels) == np.asarray(range(num_elements))) == 0
            # but `~` (bitwise NOT) binds tighter than `==`, so it compared the
            # *inverted* labels against the expected range; that comparison is
            # false for every non-negative label, the sum was always 0, and
            # the assertion passed vacuously.  Count actual mismatches instead.
            sorted_labels = sorted(p.get_ids_list())
            self.assertEqual(
                np.sum(np.asarray(sorted_labels) != np.arange(num_elements)), 0
            )
import unittest

import numpy as np

import hnswlib


class RandomSelfTestCase(unittest.TestCase):
    def testRandomSelf(self):
        """Check that each metric returns the analytically expected distances
        regardless of zero padding around the payload dimensions."""
        # Base vectors; distances from the last vector to all five are known
        # in closed form for every supported space.
        base = np.asarray(
            [
                [1, 0, 0],
                [0, 1, 0],
                [0, 0, 1],
                [1, 0, 1],
                [1, 1, 1],
            ]
        )

        cases = (
            ("l2", [[0.0, 1.0, 2.0, 2.0, 2.0]]),
            ("ip", [[-2.0, -1.0, 0.0, 0.0, 0.0]]),
            ("cosine", [[0, 1.835e-01, 4.23e-01, 4.23e-01, 4.23e-01]]),
        )

        for space, expected_distances in cases:
            # Zero columns on either side must not perturb any distance,
            # wherever the payload sits inside the vector.
            for pad_right in range(1, 128, 3):
                for pad_left in range(1, 32, 5):
                    rows = base.shape[0]
                    padded = np.concatenate(
                        [
                            np.zeros([rows, pad_left]),
                            base,
                            np.zeros([rows, pad_right]),
                        ],
                        axis=1,
                    )

                    index = hnswlib.Index(space=space, dim=padded.shape[1])
                    index.init_index(max_elements=5, ef_construction=100, M=16)
                    index.set_ef(10)
                    index.add_items(padded)

                    # Query with the last vector and compare against the
                    # expected distances for this metric.
                    _, distances = index.knn_query(np.asarray(padded[-1:]), k=5)

                    mean_err = np.mean(np.abs(distances - expected_distances))
                    self.assertAlmostEqual(mean_err, 0, delta=1e-3)
import unittest

import numpy as np

import hnswlib


class RandomSelfTestCase(unittest.TestCase):
    def testRandomSelf(self):
        """Stress multi-threaded delete + replace_deleted on a full index."""
        dim = 16
        num_elements = 1_000
        max_num_elements = 2 * num_elements

        # Three consecutive batches with disjoint label ranges; the RNG is
        # drawn in the same batch order as before.
        batches = []
        for batch_idx in range(3):
            start = batch_idx * num_elements
            batch_labels = np.arange(start, start + num_elements)
            batch_data = np.float32(np.random.random((num_elements, dim)))
            batches.append((batch_labels, batch_data))
        (labels1, data1), (labels2, data2), (labels3, data3) = batches

        for _ in range(100):
            # Fresh index per round, with deleted-slot replacement enabled.
            hnsw_index = hnswlib.Index(space="l2", dim=dim)
            hnsw_index.init_index(
                max_elements=max_num_elements,
                ef_construction=200,
                M=16,
                allow_replace_deleted=True,
            )
            hnsw_index.set_ef(100)
            hnsw_index.set_num_threads(50)

            # Fill the index to capacity with batches 1 and 2.
            hnsw_index.add_items(data1, labels1)
            hnsw_index.add_items(data2, labels2)

            # Mark the nearest neighbor of every batch-2 vector as deleted.
            labels2_deleted, _ = hnsw_index.knn_query(data2, k=1)
            labels2_deleted_flat = labels2_deleted.flatten()
            # The same neighbor can be returned for several queries; dedupe.
            labels2_deleted_no_dup = set(labels2_deleted_flat)
            for label in labels2_deleted_no_dup:
                hnsw_index.mark_deleted(label)

            # Batch 1 must still be fully retrievable after the deletions.
            labels1_found, _ = hnsw_index.knn_query(data1, k=1)
            items = hnsw_index.get_items(labels1_found)
            self.assertAlmostEqual(np.mean(np.abs(data1 - items)), 0, delta=1e-3)

            # No deleted label may be returned for a batch-2 query anymore.
            labels2_after, _ = hnsw_index.knn_query(data2, k=1)
            overlap = np.intersect1d(labels2_after.flatten(), labels2_deleted_flat)
            self.assertTrue(overlap.size == 0)

            # The index is full, so new items can only occupy slots freed by
            # the deletions above.  Duplicate neighbors mean fewer free slots
            # than num_elements, so trim batch 3 accordingly before adding it
            # with replace_deleted=True.
            num_duplicates = len(labels2_deleted) - len(labels2_deleted_no_dup)
            keep = labels3.shape[0] - num_duplicates
            hnsw_index.add_items(data3[:keep], labels3[:keep], replace_deleted=True)
import os
import shutil

from sys import platform
from pydriller import Repository


def _flatten(message):
    """Collapse a commit message onto a single line."""
    return message.replace("\n", " ").replace("\r", " ")


# Copy the benchmark script next to itself so it survives the
# `git checkout` of older revisions below (the copy is untracked).
benchmark_src = os.path.join("tests", "python", "speedtest.py")
benchmark_copy = os.path.join("tests", "python", "speedtest2.py")
shutil.copyfile(benchmark_src, benchmark_copy)

commits = list(Repository(".", from_tag="v0.6.2").traverse_commits())
print("Found commits:")
for idx, commit in enumerate(commits):
    print(idx, commit.hash, _flatten(commit.msg))

for commit in commits:
    # Commas are replaced so the name is safe inside the CSV-style log line.
    name = _flatten(commit.msg).replace(",", ";")
    print("\nProcessing", commit.hash, name)

    # Force a clean rebuild of the extension for this revision.
    if os.path.exists("build"):
        shutil.rmtree("build")
    os.system(f"git checkout {commit.hash}")

    # Verify the checkout actually moved HEAD to the requested commit.
    current_commit = list(Repository(".").traverse_commits())[-1]
    if current_commit.hash != commit.hash:
        for _ in range(4):
            print("git checkout failed!!!!")
        continue

    print("\n\n--------------------\n\n")
    ret = os.system("python -m pip install .")
    print("Install result:", ret)

    if ret != 0:
        for _ in range(4):
            print("build failed!!!!")
        continue

    # Benchmark the freshly built revision, single- and multi-threaded.
    tag = f"{commit.hash[:4]}_{name}"
    os.system(f'python {benchmark_copy} -n "{tag}" -d 16 -t 1')
    os.system(f'python {benchmark_copy} -n "{tag}" -d 16 -t 64')
# Command-line interface:
#   -d  vector dimensionality (int)
#   -n  run name used to label the log entry
#   -t  number of query threads (int)
# FIX: the options were untyped, so a missing flag crashed later with
# `int(None)`; `type=int` + `required=True` now yields a clear usage error.
ap = argparse.ArgumentParser()
ap.add_argument('-d', type=int, required=True, help='vector dimensionality')
ap.add_argument('-n', required=True, help='run name recorded in the log')
ap.add_argument('-t', type=int, required=True, help='number of query threads')
args = ap.parse_args()
dim = args.d
name = args.n
threads = args.t
num_elements = 400000

# Generating reproducible sample data
np.random.seed(1)
data = np.float32(np.random.random((num_elements, dim)))

# Declaring index; possible space options are l2, cosine or ip
p = hnswlib.Index(space='l2', dim=dim)

p.init_index(max_elements=num_elements, ef_construction=60, M=16)

# Controlling the recall by setting ef:
# higher ef leads to better accuracy, but slower search
p.set_ef(10)

# Build with all available cores; the requested query thread count is
# applied only after construction.
p.set_num_threads(64)
t0 = time.time()
p.add_items(data)
construction_time = time.time() - t0

p.set_num_threads(threads)
times = []
time.sleep(1)
p.set_ef(15)
# Three timed query passes over the same workload (the original wrapped
# this in a vacuous `for _ in range(1)` loop, removed here).
for _ in range(3):
    t0 = time.time()
    qdata = data[:5000 * threads]
    labels, distances = p.knn_query(qdata, k=1)
    tt = time.time() - t0
    times.append(tt)
    # Self-recall: each query vector should return its own label.
    recall = np.sum(labels.reshape(-1) == np.arange(len(qdata))) / len(qdata)
    print(f"{tt} seconds, recall= {recall}")

# One CSV-style summary line per run; `recall` is from the last pass.
str_out = f"{np.mean(times)}, {np.median(times)}, {np.std(times)}, {construction_time}, {recall}, {name}"
print(str_out)
with open(f"log2_{dim}_t{threads}.txt", "a") as f:
    f.write(str_out + "\n")