├── .formatter.exs ├── .github ├── do-build-manylinux2014.sh ├── script-to-build-manylinux2014.sh └── workflows │ ├── ci.yml │ ├── precompile-manylinux.yml │ └── precompile.yml ├── .gitignore ├── 3rd_party └── hnswlib │ ├── bruteforce.h │ ├── hnswalg.h │ ├── hnswlib.h │ ├── space_ip.h │ ├── space_l2.h │ ├── stop_condition.h │ └── visited_list_pool.h ├── CMakeLists.txt ├── LICENSE ├── Makefile ├── Makefile.win ├── README.md ├── c_src ├── hnswlib_index.hpp ├── hnswlib_nif.cpp ├── nif_utils.cpp └── nif_utils.hpp ├── lib ├── hnswlib_bfindex.ex ├── hnswlib_helper.ex ├── hnswlib_index.ex └── hnswlib_nif.ex ├── mix.exs ├── mix.lock └── test ├── hnswlib_bfindex_test.exs ├── hnswlib_index_test.exs └── test_helper.exs /.formatter.exs: -------------------------------------------------------------------------------- 1 | # Used by "mix format" 2 | [ 3 | inputs: ["{mix,.formatter}.exs", "{config,lib,test}/**/*.{ex,exs}"] 4 | ] 5 | -------------------------------------------------------------------------------- /.github/do-build-manylinux2014.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | set -x 4 | 5 | MIX_ENV=$1 6 | OTP_VERSION=$2 7 | ELIXIR_VERSION=$3 8 | OPENSSL_VERSION=$4 9 | ARCH=$5 10 | HNSWLIB_CI_PRECOMPILE=$6 11 | 12 | OTP_MAJOR_VERSION="${OTP_VERSION%%.*}" # e.g. 26.2.2 -> 26; POSIX parameter expansion, safe under /bin/sh unlike a bash-only here-string
13 | OPENSSL_VERSION=${OPENSSL_VERSION:-3.2.1} 14 | PREFIX_DIR="/openssl-${ARCH}" 15 | OPENSSL_ARCHIVE="openssl-${ARCH}-linux-gnu.tar.gz" 16 | OTP_ARCHIVE="otp-${ARCH}-linux-gnu.tar.gz" 17 | 18 | yum install -y openssl-devel ncurses-devel && \ 19 | cd / && \ 20 | curl -fSL "https://github.com/cocoa-xu/openssl-build/releases/download/v${OPENSSL_VERSION}/${OPENSSL_ARCHIVE}" -o "${OPENSSL_ARCHIVE}" && \ 21 | mkdir -p "${PREFIX_DIR}" && \ 22 | tar -xf "${OPENSSL_ARCHIVE}" -C "${PREFIX_DIR}" && \ 23 | curl -fSL "https://github.com/cocoa-xu/otp-build/releases/download/v${OTP_VERSION}/${OTP_ARCHIVE}" -o "${OTP_ARCHIVE}" && \ 24 | mkdir -p "otp" && \ 25 | tar -xf "${OTP_ARCHIVE}" -C "otp" && \ 26 | export PATH="/otp/usr/local/bin:${PATH}" && \ 27 | export ERL_ROOTDIR="/otp/usr/local/lib/erlang" && \ 28 | mkdir -p "elixir-${ELIXIR_VERSION}" && \ 29 | cd "elixir-${ELIXIR_VERSION}" && \ 30 | curl -fSL "https://github.com/elixir-lang/elixir/releases/download/${ELIXIR_VERSION}/elixir-otp-${OTP_MAJOR_VERSION}.zip" -o "elixir-otp-${OTP_MAJOR_VERSION}.zip" && \ 31 | unzip -q "elixir-otp-${OTP_MAJOR_VERSION}.zip" && \ 32 | export PATH="/elixir-${ELIXIR_VERSION}/bin:${PATH}" && \ 33 | export CMAKE_HNSWLIB_OPTIONS="-D CMAKE_C_FLAGS=\"-static-libgcc -static-libstdc++\" -D CMAKE_CXX_FLAGS=\"-static-libgcc -static-libstdc++\"" && \ 34 | cd /work && \ 35 | mix local.hex --force && \ 36 | mix deps.get 37 | 38 | # Mix compile 39 | cd /work 40 | export MIX_ENV="${MIX_ENV}" 41 | export ELIXIR_MAKE_CACHE_DIR=$(pwd)/cache 42 | export HNSWLIB_CI_PRECOMPILE="${HNSWLIB_CI_PRECOMPILE}" 43 | 44 | mkdir -p "${ELIXIR_MAKE_CACHE_DIR}" 45 | mix elixir_make.precompile 46 | -------------------------------------------------------------------------------- /.github/script-to-build-manylinux2014.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | set -x 4 | 5 | MIX_ENV=$1 6 | OTP_VERSION=$2 7 | ELIXIR_VERSION=$3 8 | OPENSSL_VERSION=$4 9 | ARCH=$5 10 | HNSWLIB_CI_PRECOMPILE=$6 11 | 12 | sudo docker run --privileged --network=host --rm -v "$(pwd)":/work "quay.io/pypa/manylinux2014_$ARCH:latest" \ 13 | sh -c "chmod a+x /work/do-build-manylinux2014.sh && /work/do-build-manylinux2014.sh ${MIX_ENV} ${OTP_VERSION} ${ELIXIR_VERSION} ${OPENSSL_VERSION} ${ARCH} ${HNSWLIB_CI_PRECOMPILE}" 14 | STATUS=$? # keep the container's exit code so the chmod below cannot mask a failed build 15 | if [ -d "$(pwd)/cache" ]; then 16 | sudo chmod -R a+wr "$(pwd)/cache" ; 17 | fi 18 | exit "${STATUS}" 19 |
-------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: Test 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: 7 | - main 8 | paths-ignore: 9 | - '**/*.md' 10 | - 'LICENSE*' 11 | - '.github/workflows/precompile.yml' 12 | 13 | concurrency: 14 | group: ${{ github.workflow }}-${{ github.ref }} 15 | cancel-in-progress: true 16 | 17 | jobs: 18 | linux: 19 | runs-on: ubuntu-20.04 20 | env: 21 | MIX_ENV: test 22 | 23 | steps: 24 | - uses: actions/checkout@v4 25 | - uses: erlef/setup-beam@v1 26 | with: 27 | otp-version: "25" 28 | elixir-version: "1.14" 29 | 30 | - name: Compile and Test 31 | run: | 32 | mix deps.get 33 | mix elixir_make.precompile 34 | mix test 35 | 36 | windows: 37 | runs-on: windows-latest 38 | env: 39 | MIX_ENV: test 40 | steps: 41 | - uses: actions/checkout@v4 42 | 43 | - uses: erlef/setup-beam@v1 44 | with: 45 | otp-version: "26.2.1" 46 | elixir-version: "1.16.0" 47 | 48 | - uses: ilammy/msvc-dev-cmd@v1 49 | with: 50 | arch: x64 51 | 52 | - name: Compile and Test 53 | shell: bash 54 | run: | 55 | mix deps.get 56 | mix elixir_make.precompile 57 | mix test 58 | 59 | macos: 60 | runs-on: macos-13 61 | env: 62 | MIX_ENV: test 63 | ELIXIR_VERSION: "1.16.0" 64 | OTP_VERSION: ${{ matrix.otp_version }} 65 | strategy: 66 | matrix: 67 | otp_version: ["25.3.2.8", "26.2.2"] 68 | 69 | name: macOS x86_64 - OTP ${{ matrix.otp_version }} 70 | 71 | steps: 72 | - name: Checkout 73 | uses: actions/checkout@v4 74 | 75 | - name: Install OTP and Elixir 76 | run: | 77 | curl -fsSO https://elixir-lang.org/install.sh 78 | sh install.sh "elixir@${ELIXIR_VERSION}" "otp@${OTP_VERSION}" 79 | 80 | - name: Compile and Test 81 | run: | 82 | export OTP_MAIN_VER="${OTP_VERSION%%.*}" 83 | export PATH=$HOME/.elixir-install/installs/otp/${OTP_VERSION}/bin:$PATH 84 | export PATH=$HOME/.elixir-install/installs/elixir/${ELIXIR_VERSION}-otp-${OTP_MAIN_VER}/bin:$PATH 85 | 86 | mix local.hex --force 87 | mix local.rebar --force 88 | 89 | mix deps.get 90 | mix elixir_make.precompile 91 | mix test 92 |
-------------------------------------------------------------------------------- /.github/workflows/precompile-manylinux.yml: -------------------------------------------------------------------------------- 1 | name: precompile-manylinux 2 | 3 | on: 4 | push: 5 | tags: 6 | - 'v*' 7 | 8 | concurrency: 9 | group: ${{ github.workflow }}-${{ github.ref }} 10 | cancel-in-progress: true 11 | 12 | permissions: 13 | contents: write 14 | 15 | jobs: 16 | linux: 17 | runs-on: ubuntu-latest 18 | env: 19 | MIX_ENV: prod 20 | OPENSSL_VERSION: "3.2.1" 21 | ELIXIR_VERSION: "v1.15.4" 22 | HNSWLIB_CI_PRECOMPILE: "manylinux2014" 23 | strategy: 24 | matrix: 25 | otp_version: ["25.3.2.9", "26.2.2"] 26 | arch: [x86_64, i686, s390x] 27 | 28 | name: ${{ matrix.arch }}-linux-gnu - OTP ${{ matrix.otp_version }} 29 | 30 | steps: 31 | - uses: actions/checkout@v4 32 | 33 | - name: Pull docker image 34 | run: | 35 | sudo docker pull quay.io/pypa/manylinux2014_${{ matrix.arch }}:latest 36 | 37 | - name: Install binfmt 38 | run: | 39 | sudo apt-get update && sudo apt-get install -y binfmt-support qemu-user-static 40 | 41 | - name: Precompile 42 | run: | 43 | cp .github/script-to-build-manylinux2014.sh ./ 44 | cp .github/do-build-manylinux2014.sh ./ 45 | 46 | bash ./script-to-build-manylinux2014.sh "${{ env.MIX_ENV }}" "${{ matrix.otp_version }}" "${{ env.ELIXIR_VERSION }}" "${{ env.OPENSSL_VERSION }}" "${{ matrix.arch }}" "${{ env.HNSWLIB_CI_PRECOMPILE }}" 47 | 48 | - uses: softprops/action-gh-release@v2 49 | if: startsWith(github.ref, 'refs/tags/') 50 | with: 51 | files: | 52 | cache/*.tar.gz 53 | cache/*.sha256 54 |
-------------------------------------------------------------------------------- /.github/workflows/precompile.yml: -------------------------------------------------------------------------------- 1 | name: precompile 2 | 3 | on: 4 | push: 5 | tags: 6 | - 'v*' 7 | 8 | concurrency: 9 | group: ${{ github.workflow }}-${{ github.ref }} 10 | cancel-in-progress: true 11 | 12 | permissions: 13 | contents: write 14 | 15 | jobs: 16 | linux: 17 | runs-on: ubuntu-20.04 18 | env: 19 | MIX_ENV: prod 20 | HNSWLIB_CI_PRECOMPILE: "true" 21 | strategy: 22 | matrix: 23 | otp_version: [25, 26] 24 | 25 | name: Linux GNU - OTP ${{ matrix.otp_version }} 26 | 27 | steps: 28 | - name: Checkout 29 | uses: actions/checkout@v4 30 | 31 | - uses: erlef/setup-beam@v1 32 | with: 33 | otp-version: ${{ matrix.otp_version }} 34 | elixir-version: "1.15" 35 | 36 | - name: Install system dependencies 37 | run: | 38 | sudo apt-get update 39 | sudo apt-get install -y build-essential automake autoconf pkg-config bc m4 unzip zip \ 40 | gcc-aarch64-linux-gnu g++-aarch64-linux-gnu \ 41 | gcc-arm-linux-gnueabihf g++-arm-linux-gnueabihf \ 42 | gcc-riscv64-linux-gnu g++-riscv64-linux-gnu \ 43 | gcc-powerpc64le-linux-gnu g++-powerpc64le-linux-gnu 44 | 45 | - name: Precompile 46 | run: | 47 | export ELIXIR_MAKE_CACHE_DIR=$(pwd)/cache 48 | mkdir -p "${ELIXIR_MAKE_CACHE_DIR}" 49 | mix deps.get 50 | mix elixir_make.precompile 51 | 52 | - uses: softprops/action-gh-release@v2 53 | if: startsWith(github.ref, 'refs/tags/') 54 | with: 55 | files: | 56 | cache/*.tar.gz 57 | cache/*.sha256 58 | 59 | macos: 60 | runs-on: macos-13 61 | env: 62 | MIX_ENV: prod 63 | ELIXIR_VERSION: "1.16.0" 64 | OTP_VERSION: ${{ matrix.otp_version }} 65 | strategy: 66 | matrix: 67 | otp_version: ["25.3.2.8", "26.2.2"] 68 | 69 | name: macOS - OTP ${{ matrix.otp_version }} 70 | 71 | steps: 72 | - name: Checkout 73 | uses: actions/checkout@v4 74 | 75 | - name: Install OTP and Elixir 76 | run: | 77 | curl -fsSO https://elixir-lang.org/install.sh 78 | sh install.sh "elixir@${ELIXIR_VERSION}" "otp@${OTP_VERSION}" 79 | 80 | - name: Precompile 81 | run: | 82 | export OTP_MAIN_VER="${OTP_VERSION%%.*}" 83 | export PATH=$HOME/.elixir-install/installs/otp/${OTP_VERSION}/bin:$PATH 84 | export PATH=$HOME/.elixir-install/installs/elixir/${ELIXIR_VERSION}-otp-${OTP_MAIN_VER}/bin:$PATH 85 | 86 | export ELIXIR_MAKE_CACHE_DIR=$(pwd)/cache 87 | mkdir -p "${ELIXIR_MAKE_CACHE_DIR}" 88 | 89 | mix local.hex --force 90 | mix local.rebar --force 91 | 92 | mix deps.get 93 | mix elixir_make.precompile 94 | 95 | - uses: softprops/action-gh-release@v2 96 | if: startsWith(github.ref, 'refs/tags/') 97 | with: 98 | files: | 99 | cache/*.tar.gz 100 | cache/*.sha256 101 | 102 | windows: 103 | runs-on: windows-latest 104 | env: 105 | MIX_ENV: prod 106 | strategy: 107 | matrix: 108 | otp_version: [25, 26] 109 | 110 | name: Windows - OTP ${{ matrix.otp_version }} 111 | 112 | steps: 113 | - uses: actions/checkout@v4 114 | 115 | - uses: erlef/setup-beam@v1 116 | with: 117 | otp-version: ${{ matrix.otp_version }} 118 | elixir-version: "1.15" 119 | 120 | - uses: ilammy/msvc-dev-cmd@v1 121 | with: 122 | arch: x64 123 | 124 | - name: Precompile 125 | shell: bash 126 | run: | 127 | export ELIXIR_MAKE_CACHE_DIR=$(pwd)/cache 128 | mkdir -p "${ELIXIR_MAKE_CACHE_DIR}" 129 | mix deps.get 130 | mix elixir_make.precompile 131 | 132 | - uses: softprops/action-gh-release@v2 133 | if: startsWith(github.ref, 'refs/tags/') 134 | with: 135 | files: | 136 | cache/*.tar.gz 137 | cache/*.sha256 138 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # The directory Mix will write compiled artifacts to.
2 | /_build/ 3 | 4 | # If you run "mix test --cover", coverage assets end up here. 5 | /cover/ 6 | 7 | # The directory Mix downloads your dependencies sources to. 8 | /deps/ 9 | 10 | # Where third-party dependencies like ExDoc output generated docs. 11 | /doc/ 12 | 13 | # Ignore .fetch files in case you like to edit your project deps locally. 14 | /.fetch 15 | 16 | # If the VM crashes, it generates a dump, let's ignore it too. 17 | erl_crash.dump 18 | 19 | # Also ignore archive artifacts (built via "mix archive.build"). 20 | *.ez 21 | 22 | # Ignore package tarball (built via "mix hex.build"). 23 | hnsw_elixir-*.tar 24 | 25 | # Temporary files, for example, from tests. 26 | /tmp/ 27 | 28 | # macOS 29 | .DS_Store 30 | 31 | # IDE 32 | .vscode 33 | -------------------------------------------------------------------------------- /3rd_party/hnswlib/bruteforce.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | namespace hnswlib { 9 | template 10 | class BruteforceSearch : public AlgorithmInterface { 11 | public: 12 | char *data_; 13 | size_t maxelements_; 14 | size_t cur_element_count; 15 | size_t size_per_element_; 16 | 17 | size_t data_size_; 18 | DISTFUNC fstdistfunc_; 19 | void *dist_func_param_; 20 | std::mutex index_lock; 21 | 22 | std::unordered_map dict_external_to_internal; // external label -> slot index; not written by saveIndex, rebuilt by loadIndex 23 | 24 | 25 | BruteforceSearch(SpaceInterface *s) 26 | : data_(nullptr), 27 | maxelements_(0), 28 | cur_element_count(0), 29 | size_per_element_(0), 30 | data_size_(0), fstdistfunc_(nullptr), 31 | dist_func_param_(nullptr) { 32 | } 33 | 34 | 35 | BruteforceSearch(SpaceInterface *s, const std::string &location) 36 | : data_(nullptr), 37 | maxelements_(0), 38 | cur_element_count(0), 39 | size_per_element_(0), 40 | data_size_(0), fstdistfunc_(nullptr), 41 | dist_func_param_(nullptr) { 42 | loadIndex(location, s); 43 | } 44 | 45 | 46 | BruteforceSearch(SpaceInterface *s, size_t maxElements) { 47 | maxelements_ =
maxElements; 48 | data_size_ = s->get_data_size(); 49 | fstdistfunc_ = s->get_dist_func(); 50 | dist_func_param_ = s->get_dist_func_param(); 51 | size_per_element_ = data_size_ + sizeof(labeltype); 52 | data_ = (char *) malloc(maxElements * size_per_element_); 53 | if (data_ == nullptr) 54 | throw std::runtime_error("Not enough memory: BruteforceSearch failed to allocate data"); 55 | cur_element_count = 0; 56 | } 57 | 58 | 59 | ~BruteforceSearch() { 60 | free(data_); 61 | } 62 | 63 | 64 | void addPoint(const void *datapoint, labeltype label, bool replace_deleted = false) { 65 | size_t idx; 66 | { 67 | std::unique_lock lock(index_lock); 68 | 69 | auto search = dict_external_to_internal.find(label); 70 | if (search != dict_external_to_internal.end()) { 71 | idx = search->second; 72 | } else { 73 | if (cur_element_count >= maxelements_) { 74 | throw std::runtime_error("The number of elements exceeds the specified limit\n"); 75 | } 76 | idx = cur_element_count; 77 | dict_external_to_internal[label] = idx; 78 | cur_element_count++; 79 | } 80 | } 81 | memcpy(data_ + size_per_element_ * idx + data_size_, &label, sizeof(labeltype)); 82 | memcpy(data_ + size_per_element_ * idx, datapoint, data_size_); 83 | } 84 | 85 | 86 | void removePoint(labeltype cur_external) { 87 | std::unique_lock lock(index_lock); 88 | 89 | auto found = dict_external_to_internal.find(cur_external); 90 | if (found == dict_external_to_internal.end()) { 91 | return; 92 | } 93 | 94 | size_t cur_c = found->second; 95 | dict_external_to_internal.erase(found); 96 | size_t last_c = cur_element_count - 1; 97 | if (cur_c != last_c) { // move the last slot into the hole; if the hole IS the last slot, re-mapping its label would resurrect the removed entry 98 | labeltype label = *((labeltype*)(data_ + size_per_element_ * last_c + data_size_)); 99 | dict_external_to_internal[label] = cur_c; 100 | memcpy(data_ + size_per_element_ * cur_c, data_ + size_per_element_ * last_c, data_size_+sizeof(labeltype)); 101 | } 102 | cur_element_count--; 103 | } 104 | 105 | 106 | std::priority_queue> 107 | searchKnn(const void *query_data, size_t k, BaseFilterFunctor* isIdAllowed =
nullptr) const { 108 | size_t seed_count = std::min(k, cur_element_count); // clamp instead of assert: with NDEBUG, k > cur_element_count would read unused or out-of-bounds slots 109 | std::priority_queue> topResults; 110 | if (cur_element_count == 0 || k == 0) return topResults; 111 | for (size_t i = 0; i < seed_count; i++) { 112 | dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_); 113 | labeltype label = *((labeltype*) (data_ + size_per_element_ * i + data_size_)); 114 | if ((!isIdAllowed) || (*isIdAllowed)(label)) { 115 | topResults.emplace(dist, label); 116 | } 117 | } 118 | dist_t lastdist = topResults.empty() ? std::numeric_limits::max() : topResults.top().first; 119 | for (size_t i = seed_count; i < cur_element_count; i++) { 120 | dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_); 121 | if (topResults.size() < k || dist <= lastdist) { // the filter may have rejected seeds; keep accepting candidates until k results are held 122 | labeltype label = *((labeltype *) (data_ + size_per_element_ * i + data_size_)); 123 | if ((!isIdAllowed) || (*isIdAllowed)(label)) { 124 | topResults.emplace(dist, label); 125 | } 126 | if (topResults.size() > k) 127 | topResults.pop(); 128 | 129 | if (!topResults.empty()) { 130 | lastdist = topResults.top().first; 131 | } 132 | } 133 | } 134 | return topResults; 135 | } 136 | 137 | 138 | void saveIndex(const std::string &location) { 139 | std::ofstream output(location, std::ios::binary); 140 | std::streampos position; 141 | 142 | writeBinaryPOD(output, maxelements_); 143 | writeBinaryPOD(output, size_per_element_); 144 | writeBinaryPOD(output, cur_element_count); 145 | 146 | output.write(data_, maxelements_ * size_per_element_); 147 | 148 | output.close(); 149 | } 150 | 151 | 152 | void loadIndex(const std::string &location, SpaceInterface *s) { 153 | std::ifstream input(location, std::ios::binary); 154 | if (!input.is_open()) throw std::runtime_error("Cannot open file"); 155 | size_t stored_size_per_element = 0; 156 | readBinaryPOD(input, maxelements_); 157 | readBinaryPOD(input, stored_size_per_element); 158 | readBinaryPOD(input, cur_element_count); 159 | data_size_ = s->get_data_size(); 160 | fstdistfunc_ = s->get_dist_func(); 161 | dist_func_param_ = s->get_dist_func_param(); 162 | size_per_element_ = data_size_ + sizeof(labeltype); 163 |
if (stored_size_per_element != size_per_element_) throw std::runtime_error("Index file does not match the space's data size"); 164 | data_ = (char *) malloc(maxelements_ * size_per_element_); 165 | if (data_ == nullptr) 166 | throw std::runtime_error("Not enough memory: loadIndex failed to allocate data"); 167 | input.read(data_, maxelements_ * size_per_element_); 168 | for (size_t i = 0; i < cur_element_count; i++) // saveIndex does not write the label -> slot map; rebuild it from the label stored after each vector 169 | dict_external_to_internal[*((labeltype *) (data_ + size_per_element_ * i + data_size_))] = i; 170 | input.close(); 171 | } 172 | }; 173 | } // namespace hnswlib 174 | -------------------------------------------------------------------------------- /3rd_party/hnswlib/hnswlib.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | // https://github.com/nmslib/hnswlib/pull/508 4 | // This allows others to provide their own error stream (e.g. RcppHNSW) 5 | #ifndef HNSWLIB_ERR_OVERRIDE 6 | #define HNSWERR std::cerr 7 | #else 8 | #define HNSWERR HNSWLIB_ERR_OVERRIDE 9 | #endif 10 | 11 | #ifndef NO_MANUAL_VECTORIZATION 12 | #if (defined(__SSE__) || _M_IX86_FP > 0 || defined(_M_AMD64) || defined(_M_X64)) 13 | #define USE_SSE 14 | #ifdef __AVX__ 15 | #define USE_AVX 16 | #ifdef __AVX512F__ 17 | #define USE_AVX512 18 | #endif 19 | #endif 20 | #endif 21 | #endif 22 | 23 | #if defined(USE_AVX) || defined(USE_SSE) 24 | #ifdef _MSC_VER 25 | #include 26 | #include 27 | static void cpuid(int32_t out[4], int32_t eax, int32_t ecx) { 28 | __cpuidex(out, eax, ecx); 29 | } 30 | static __int64 xgetbv(unsigned int x) { 31 | return _xgetbv(x); 32 | } 33 | #else 34 | #include 35 | #include 36 | #include 37 | static void cpuid(int32_t cpuInfo[4], int32_t eax, int32_t ecx) { 38 | __cpuid_count(eax, ecx, cpuInfo[0], cpuInfo[1], cpuInfo[2], cpuInfo[3]); 39 | } 40 | static uint64_t xgetbv(unsigned int index) { 41 | uint32_t eax, edx; 42 | __asm__ __volatile__("xgetbv" : "=a"(eax), "=d"(edx) : "c"(index)); 43 | return ((uint64_t)edx << 32) | eax; 44 | } 45 | #endif 46 | 47 | #if defined(USE_AVX512) 48 | #include 49 | #endif 50 | 51 | #if defined(__GNUC__) 52 | #define PORTABLE_ALIGN32 __attribute__((aligned(32))) 53 | #define
PORTABLE_ALIGN64 __attribute__((aligned(64))) 54 | #else 55 | #define PORTABLE_ALIGN32 __declspec(align(32)) 56 | #define PORTABLE_ALIGN64 __declspec(align(64)) 57 | #endif 58 | 59 | // Adapted from https://github.com/Mysticial/FeatureDetector 60 | #define _XCR_XFEATURE_ENABLED_MASK 0 61 | 62 | static bool AVXCapable() { 63 | int cpuInfo[4]; 64 | 65 | // CPU support 66 | cpuid(cpuInfo, 0, 0); 67 | int nIds = cpuInfo[0]; 68 | 69 | bool HW_AVX = false; 70 | if (nIds >= 0x00000001) { 71 | cpuid(cpuInfo, 0x00000001, 0); 72 | HW_AVX = (cpuInfo[2] & ((int)1 << 28)) != 0; 73 | } 74 | 75 | // OS support 76 | cpuid(cpuInfo, 1, 0); 77 | 78 | bool osUsesXSAVE_XRSTORE = (cpuInfo[2] & (1 << 27)) != 0; 79 | bool cpuAVXSuport = (cpuInfo[2] & (1 << 28)) != 0; 80 | 81 | bool avxSupported = false; 82 | if (osUsesXSAVE_XRSTORE && cpuAVXSuport) { 83 | uint64_t xcrFeatureMask = xgetbv(_XCR_XFEATURE_ENABLED_MASK); 84 | avxSupported = (xcrFeatureMask & 0x6) == 0x6; 85 | } 86 | return HW_AVX && avxSupported; 87 | } 88 | 89 | static bool AVX512Capable() { 90 | if (!AVXCapable()) return false; 91 | 92 | int cpuInfo[4]; 93 | 94 | // CPU support 95 | cpuid(cpuInfo, 0, 0); 96 | int nIds = cpuInfo[0]; 97 | 98 | bool HW_AVX512F = false; 99 | if (nIds >= 0x00000007) { // AVX512 Foundation 100 | cpuid(cpuInfo, 0x00000007, 0); 101 | HW_AVX512F = (cpuInfo[1] & ((int)1 << 16)) != 0; 102 | } 103 | 104 | // OS support 105 | cpuid(cpuInfo, 1, 0); 106 | 107 | bool osUsesXSAVE_XRSTORE = (cpuInfo[2] & (1 << 27)) != 0; 108 | bool cpuAVXSuport = (cpuInfo[2] & (1 << 28)) != 0; 109 | 110 | bool avx512Supported = false; 111 | if (osUsesXSAVE_XRSTORE && cpuAVXSuport) { 112 | uint64_t xcrFeatureMask = xgetbv(_XCR_XFEATURE_ENABLED_MASK); 113 | avx512Supported = (xcrFeatureMask & 0xe6) == 0xe6; 114 | } 115 | return HW_AVX512F && avx512Supported; 116 | } 117 | #endif 118 | 119 | #include 120 | #include 121 | #include 122 | #include 123 | 124 | namespace hnswlib { 125 | typedef size_t labeltype; 126 | 127 | // This
can be extended to store state for filtering (e.g. from a std::set) 128 | class BaseFilterFunctor { 129 | public: 130 | virtual bool operator()(hnswlib::labeltype id) { return true; } 131 | virtual ~BaseFilterFunctor() {}; 132 | }; 133 | 134 | template 135 | class BaseSearchStopCondition { 136 | public: 137 | virtual void add_point_to_result(labeltype label, const void *datapoint, dist_t dist) = 0; 138 | 139 | virtual void remove_point_from_result(labeltype label, const void *datapoint, dist_t dist) = 0; 140 | 141 | virtual bool should_stop_search(dist_t candidate_dist, dist_t lowerBound) = 0; 142 | 143 | virtual bool should_consider_candidate(dist_t candidate_dist, dist_t lowerBound) = 0; 144 | 145 | virtual bool should_remove_extra() = 0; 146 | 147 | virtual void filter_results(std::vector> &candidates) = 0; 148 | 149 | virtual ~BaseSearchStopCondition() {} 150 | }; 151 | 152 | template 153 | class pairGreater { 154 | public: 155 | bool operator()(const T& p1, const T& p2) const { 156 | return p1.first > p2.first; 157 | } 158 | }; 159 | 160 | template 161 | static void writeBinaryPOD(std::ostream &out, const T &podRef) { 162 | out.write((char *) &podRef, sizeof(T)); 163 | } 164 | 165 | template 166 | static void readBinaryPOD(std::istream &in, T &podRef) { 167 | in.read((char *) &podRef, sizeof(T)); 168 | } 169 | 170 | template 171 | using DISTFUNC = MTYPE(*)(const void *, const void *, const void *); 172 | 173 | template 174 | class SpaceInterface { 175 | public: 176 | // virtual void search(void *); 177 | virtual size_t get_data_size() = 0; 178 | 179 | virtual DISTFUNC get_dist_func() = 0; 180 | 181 | virtual void *get_dist_func_param() = 0; 182 | 183 | virtual ~SpaceInterface() {} 184 | }; 185 | 186 | template 187 | class AlgorithmInterface { 188 | public: 189 | virtual void addPoint(const void *datapoint, labeltype label, bool replace_deleted = false) = 0; 190 | 191 | virtual std::priority_queue> 192 | searchKnn(const void*, size_t, BaseFilterFunctor* isIdAllowed =
nullptr) const = 0; 193 | 194 | // Return k nearest neighbor in the order of closer first 195 | virtual std::vector> 196 | searchKnnCloserFirst(const void* query_data, size_t k, BaseFilterFunctor* isIdAllowed = nullptr) const; 197 | 198 | virtual void saveIndex(const std::string &location) = 0; 199 | virtual ~AlgorithmInterface(){ 200 | } 201 | }; 202 | 203 | template 204 | std::vector> 205 | AlgorithmInterface::searchKnnCloserFirst(const void* query_data, size_t k, 206 | BaseFilterFunctor* isIdAllowed) const { 207 | std::vector> result; 208 | 209 | // here searchKnn returns the result in the order of farther first 210 | auto ret = searchKnn(query_data, k, isIdAllowed); 211 | { 212 | size_t sz = ret.size(); 213 | result.resize(sz); 214 | while (!ret.empty()) { 215 | result[--sz] = ret.top(); 216 | ret.pop(); 217 | } 218 | } 219 | 220 | return result; 221 | } 222 | } // namespace hnswlib 223 | 224 | #include "space_l2.h" 225 | #include "space_ip.h" 226 | #include "stop_condition.h" 227 | #include "bruteforce.h" 228 | #include "hnswalg.h" 229 | -------------------------------------------------------------------------------- /3rd_party/hnswlib/space_ip.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "hnswlib.h" 3 | 4 | namespace hnswlib { 5 | 6 | static float 7 | InnerProduct(const void *pVect1, const void *pVect2, const void *qty_ptr) { 8 | size_t qty = *((size_t *) qty_ptr); 9 | float res = 0; 10 | for (size_t i = 0; i < qty; i++) { 11 | res += ((float *) pVect1)[i] * ((float *) pVect2)[i]; 12 | } 13 | return res; 14 | } 15 | 16 | static float 17 | InnerProductDistance(const void *pVect1, const void *pVect2, const void *qty_ptr) { 18 | return 1.0f - InnerProduct(pVect1, pVect2, qty_ptr); 19 | } 20 | 21 | #if defined(USE_AVX) 22 | 23 | // Favor using AVX if available.
24 | static float 25 | InnerProductSIMD4ExtAVX(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { 26 | float PORTABLE_ALIGN32 TmpRes[8]; 27 | float *pVect1 = (float *) pVect1v; 28 | float *pVect2 = (float *) pVect2v; 29 | size_t qty = *((size_t *) qty_ptr); 30 | 31 | size_t qty16 = qty / 16; 32 | size_t qty4 = qty / 4; 33 | 34 | const float *pEnd1 = pVect1 + 16 * qty16; 35 | const float *pEnd2 = pVect1 + 4 * qty4; 36 | 37 | __m256 sum256 = _mm256_set1_ps(0); 38 | 39 | while (pVect1 < pEnd1) { 40 | //_mm_prefetch((char*)(pVect2 + 16), _MM_HINT_T0); 41 | 42 | __m256 v1 = _mm256_loadu_ps(pVect1); 43 | pVect1 += 8; 44 | __m256 v2 = _mm256_loadu_ps(pVect2); 45 | pVect2 += 8; 46 | sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2)); 47 | 48 | v1 = _mm256_loadu_ps(pVect1); 49 | pVect1 += 8; 50 | v2 = _mm256_loadu_ps(pVect2); 51 | pVect2 += 8; 52 | sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2)); 53 | } 54 | 55 | __m128 v1, v2; 56 | __m128 sum_prod = _mm_add_ps(_mm256_extractf128_ps(sum256, 0), _mm256_extractf128_ps(sum256, 1)); 57 | 58 | while (pVect1 < pEnd2) { 59 | v1 = _mm_loadu_ps(pVect1); 60 | pVect1 += 4; 61 | v2 = _mm_loadu_ps(pVect2); 62 | pVect2 += 4; 63 | sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); 64 | } 65 | 66 | _mm_store_ps(TmpRes, sum_prod); 67 | float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3]; 68 | return sum; 69 | } 70 | 71 | static float 72 | InnerProductDistanceSIMD4ExtAVX(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { 73 | return 1.0f - InnerProductSIMD4ExtAVX(pVect1v, pVect2v, qty_ptr); 74 | } 75 | 76 | #endif 77 | 78 | #if defined(USE_SSE) 79 | 80 | static float 81 | InnerProductSIMD4ExtSSE(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { 82 | float PORTABLE_ALIGN32 TmpRes[8]; 83 | float *pVect1 = (float *) pVect1v; 84 | float *pVect2 = (float *) pVect2v; 85 | size_t qty = *((size_t *) qty_ptr); 86 | 87 | size_t qty16 = qty / 16; 88 | size_t qty4 = qty / 4; 89 | 90 | const 
float *pEnd1 = pVect1 + 16 * qty16; 91 | const float *pEnd2 = pVect1 + 4 * qty4; 92 | 93 | __m128 v1, v2; 94 | __m128 sum_prod = _mm_set1_ps(0); 95 | // Main loop: four 4-float products (16 floats) per iteration. 96 | while (pVect1 < pEnd1) { 97 | v1 = _mm_loadu_ps(pVect1); 98 | pVect1 += 4; 99 | v2 = _mm_loadu_ps(pVect2); 100 | pVect2 += 4; 101 | sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); 102 | 103 | v1 = _mm_loadu_ps(pVect1); 104 | pVect1 += 4; 105 | v2 = _mm_loadu_ps(pVect2); 106 | pVect2 += 4; 107 | sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); 108 | 109 | v1 = _mm_loadu_ps(pVect1); 110 | pVect1 += 4; 111 | v2 = _mm_loadu_ps(pVect2); 112 | pVect2 += 4; 113 | sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); 114 | 115 | v1 = _mm_loadu_ps(pVect1); 116 | pVect1 += 4; 117 | v2 = _mm_loadu_ps(pVect2); 118 | pVect2 += 4; 119 | sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); 120 | } 121 | // Tail loop: remaining whole groups of 4 floats. 122 | while (pVect1 < pEnd2) { 123 | v1 = _mm_loadu_ps(pVect1); 124 | pVect1 += 4; 125 | v2 = _mm_loadu_ps(pVect2); 126 | pVect2 += 4; 127 | sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); 128 | } 129 | 130 | _mm_store_ps(TmpRes, sum_prod); 131 | float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3]; 132 | 133 | return sum; 134 | } 135 | // Distance form: 1 - dot product. 136 | static float 137 | InnerProductDistanceSIMD4ExtSSE(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { 138 | return 1.0f - InnerProductSIMD4ExtSSE(pVect1v, pVect2v, qty_ptr); 139 | } 140 | 141 | #endif 142 | 143 | 144 | #if defined(USE_AVX512) 145 | // AVX-512 variant: sums only whole 16-float blocks (the first 16*(qty/16) elements); TmpRes is not used. 146 | static float 147 | InnerProductSIMD16ExtAVX512(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { 148 | float PORTABLE_ALIGN64 TmpRes[16]; 149 | float *pVect1 = (float *) pVect1v; 150 | float *pVect2 = (float *) pVect2v; 151 | size_t qty = *((size_t *) qty_ptr); 152 | 153 | size_t qty16 = qty / 16; 154 | 155 | 156 | const float *pEnd1 = pVect1 + 16 * qty16; 157 | 158 | __m512 sum512 = _mm512_set1_ps(0); 159 | // Unrolled: four 16-float blocks (64 floats) per iteration; leftover blocks are handled after the loop. 160 | size_t loop = qty16 / 4; 161 | 162 | while (loop--) { 163 | __m512 v1 =
_mm512_loadu_ps(pVect1); 164 | __m512 v2 = _mm512_loadu_ps(pVect2); 165 | pVect1 += 16; 166 | pVect2 += 16; 167 | 168 | __m512 v3 = _mm512_loadu_ps(pVect1); 169 | __m512 v4 = _mm512_loadu_ps(pVect2); 170 | pVect1 += 16; 171 | pVect2 += 16; 172 | 173 | __m512 v5 = _mm512_loadu_ps(pVect1); 174 | __m512 v6 = _mm512_loadu_ps(pVect2); 175 | pVect1 += 16; 176 | pVect2 += 16; 177 | 178 | __m512 v7 = _mm512_loadu_ps(pVect1); 179 | __m512 v8 = _mm512_loadu_ps(pVect2); 180 | pVect1 += 16; 181 | pVect2 += 16; 182 | // Accumulate the four products with fused multiply-add. 183 | sum512 = _mm512_fmadd_ps(v1, v2, sum512); 184 | sum512 = _mm512_fmadd_ps(v3, v4, sum512); 185 | sum512 = _mm512_fmadd_ps(v5, v6, sum512); 186 | sum512 = _mm512_fmadd_ps(v7, v8, sum512); 187 | } 188 | // Leftover 16-float blocks (fewer than four). 189 | while (pVect1 < pEnd1) { 190 | __m512 v1 = _mm512_loadu_ps(pVect1); 191 | __m512 v2 = _mm512_loadu_ps(pVect2); 192 | pVect1 += 16; 193 | pVect2 += 16; 194 | sum512 = _mm512_fmadd_ps(v1, v2, sum512); 195 | } 196 | // Horizontal sum across the 16 lanes. 197 | float sum = _mm512_reduce_add_ps(sum512); 198 | return sum; 199 | } 200 | // Distance form: 1 - dot product. 201 | static float 202 | InnerProductDistanceSIMD16ExtAVX512(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { 203 | return 1.0f - InnerProductSIMD16ExtAVX512(pVect1v, pVect2v, qty_ptr); 204 | } 205 | 206 | #endif 207 | 208 | #if defined(USE_AVX) 209 | // AVX variant: sums only whole 16-float blocks, two 8-float loads per iteration. 210 | static float 211 | InnerProductSIMD16ExtAVX(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { 212 | float PORTABLE_ALIGN32 TmpRes[8]; 213 | float *pVect1 = (float *) pVect1v; 214 | float *pVect2 = (float *) pVect2v; 215 | size_t qty = *((size_t *) qty_ptr); 216 | 217 | size_t qty16 = qty / 16; 218 | 219 | 220 | const float *pEnd1 = pVect1 + 16 * qty16; 221 | 222 | __m256 sum256 = _mm256_set1_ps(0); 223 | 224 | while (pVect1 < pEnd1) { 225 | //_mm_prefetch((char*)(pVect2 + 16), _MM_HINT_T0); 226 | 227 | __m256 v1 = _mm256_loadu_ps(pVect1); 228 | pVect1 += 8; 229 | __m256 v2 = _mm256_loadu_ps(pVect2); 230 | pVect2 += 8; 231 | sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2)); 232 |
233 | v1 = _mm256_loadu_ps(pVect1); 234 | pVect1 += 8; 235 | v2 = _mm256_loadu_ps(pVect2); 236 | pVect2 += 8; 237 | sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2)); 238 | } 239 | // Spill the 8 AVX lanes and add them horizontally. 240 | _mm256_store_ps(TmpRes, sum256); 241 | float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7]; 242 | 243 | return sum; 244 | } 245 | // Distance form: 1 - dot product. 246 | static float 247 | InnerProductDistanceSIMD16ExtAVX(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { 248 | return 1.0f - InnerProductSIMD16ExtAVX(pVect1v, pVect2v, qty_ptr); 249 | } 250 | 251 | #endif 252 | 253 | #if defined(USE_SSE) 254 | // SSE variant: sums only whole 16-float blocks, four 4-float loads per iteration. 255 | static float 256 | InnerProductSIMD16ExtSSE(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { 257 | float PORTABLE_ALIGN32 TmpRes[8]; 258 | float *pVect1 = (float *) pVect1v; 259 | float *pVect2 = (float *) pVect2v; 260 | size_t qty = *((size_t *) qty_ptr); 261 | 262 | size_t qty16 = qty / 16; 263 | 264 | const float *pEnd1 = pVect1 + 16 * qty16; 265 | 266 | __m128 v1, v2; 267 | __m128 sum_prod = _mm_set1_ps(0); 268 | 269 | while (pVect1 < pEnd1) { 270 | v1 = _mm_loadu_ps(pVect1); 271 | pVect1 += 4; 272 | v2 = _mm_loadu_ps(pVect2); 273 | pVect2 += 4; 274 | sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); 275 | 276 | v1 = _mm_loadu_ps(pVect1); 277 | pVect1 += 4; 278 | v2 = _mm_loadu_ps(pVect2); 279 | pVect2 += 4; 280 | sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); 281 | 282 | v1 = _mm_loadu_ps(pVect1); 283 | pVect1 += 4; 284 | v2 = _mm_loadu_ps(pVect2); 285 | pVect2 += 4; 286 | sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); 287 | 288 | v1 = _mm_loadu_ps(pVect1); 289 | pVect1 += 4; 290 | v2 = _mm_loadu_ps(pVect2); 291 | pVect2 += 4; 292 | sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); 293 | } 294 | _mm_store_ps(TmpRes, sum_prod); 295 | float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3]; 296 | 297 | return sum; 298 | } 299 | // Distance form: 1 - dot product. 300 | static float 301 | InnerProductDistanceSIMD16ExtSSE(const void
*pVect1v, const void *pVect2v, const void *qty_ptr) { 302 | return 1.0f - InnerProductSIMD16ExtSSE(pVect1v, pVect2v, qty_ptr); 303 | } 304 | 305 | #endif 306 | // Dispatch pointers default to the SSE kernels; space constructors may repoint them to AVX/AVX-512 at runtime. NOTE(review): these are file-level statics written without synchronization. 307 | #if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512) 308 | static DISTFUNC InnerProductSIMD16Ext = InnerProductSIMD16ExtSSE; 309 | static DISTFUNC InnerProductSIMD4Ext = InnerProductSIMD4ExtSSE; 310 | static DISTFUNC InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtSSE; 311 | static DISTFUNC InnerProductDistanceSIMD4Ext = InnerProductDistanceSIMD4ExtSSE; 312 | // Any dim: SIMD16 kernel over the first qty16 floats, scalar InnerProduct over the remaining qty%16. 313 | static float 314 | InnerProductDistanceSIMD16ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { 315 | size_t qty = *((size_t *) qty_ptr); 316 | size_t qty16 = qty >> 4 << 4; 317 | float res = InnerProductSIMD16Ext(pVect1v, pVect2v, &qty16); 318 | float *pVect1 = (float *) pVect1v + qty16; 319 | float *pVect2 = (float *) pVect2v + qty16; 320 | 321 | size_t qty_left = qty - qty16; 322 | float res_tail = InnerProduct(pVect1, pVect2, &qty_left); 323 | return 1.0f - (res + res_tail); 324 | } 325 | // Any dim: SIMD4 kernel over the first qty4 floats, scalar InnerProduct over the remaining qty%4. 326 | static float 327 | InnerProductDistanceSIMD4ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { 328 | size_t qty = *((size_t *) qty_ptr); 329 | size_t qty4 = qty >> 2 << 2; 330 | 331 | float res = InnerProductSIMD4Ext(pVect1v, pVect2v, &qty4); 332 | size_t qty_left = qty - qty4; 333 | 334 | float *pVect1 = (float *) pVect1v + qty4; 335 | float *pVect2 = (float *) pVect2v + qty4; 336 | float res_tail = InnerProduct(pVect1, pVect2, &qty_left); 337 | 338 | return 1.0f - (res + res_tail); 339 | } 340 | #endif 341 | // Inner-product space over float vectors of length dim; distance is 1 - dot product. 342 | class InnerProductSpace : public SpaceInterface { 343 | DISTFUNC fstdistfunc_; 344 | size_t data_size_; 345 | size_t dim_; 346 | 347 | public: 348 | InnerProductSpace(size_t dim) { 349 | fstdistfunc_ = InnerProductDistance; 350 | #if defined(USE_AVX) || defined(USE_SSE) || defined(USE_AVX512) 351 | #if defined(USE_AVX512) 352 | if (AVX512Capable()) { 353 |
InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX512; 354 | InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtAVX512; 355 | } else if (AVXCapable()) { 356 | InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX; 357 | InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtAVX; 358 | } 359 | #elif defined(USE_AVX) 360 | if (AVXCapable()) { 361 | InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX; 362 | InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtAVX; 363 | } 364 | #endif 365 | #if defined(USE_AVX) 366 | if (AVXCapable()) { 367 | InnerProductSIMD4Ext = InnerProductSIMD4ExtAVX; 368 | InnerProductDistanceSIMD4Ext = InnerProductDistanceSIMD4ExtAVX; 369 | } 370 | #endif 371 | // Kernel choice: exact multiples of 16 or 4 use pure SIMD; other dims > 4 use SIMD plus a scalar tail; dims 1-3 keep scalar InnerProductDistance. 372 | if (dim % 16 == 0) 373 | fstdistfunc_ = InnerProductDistanceSIMD16Ext; 374 | else if (dim % 4 == 0) 375 | fstdistfunc_ = InnerProductDistanceSIMD4Ext; 376 | else if (dim > 16) 377 | fstdistfunc_ = InnerProductDistanceSIMD16ExtResiduals; 378 | else if (dim > 4) 379 | fstdistfunc_ = InnerProductDistanceSIMD4ExtResiduals; 380 | #endif 381 | dim_ = dim; 382 | data_size_ = dim * sizeof(float); 383 | } 384 | 385 | size_t get_data_size() { 386 | return data_size_; 387 | } 388 | 389 | DISTFUNC get_dist_func() { 390 | return fstdistfunc_; 391 | } 392 | // Returns &dim_ (a size_t); presumably passed to the kernels as their qty_ptr element count. 393 | void *get_dist_func_param() { 394 | return &dim_; 395 | } 396 | 397 | ~InnerProductSpace() {} 398 | }; 399 | 400 | } // namespace hnswlib 401 | -------------------------------------------------------------------------------- /3rd_party/hnswlib/space_l2.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "hnswlib.h" 3 | 4 | namespace hnswlib { 5 | // Scalar squared Euclidean distance (no square root) over *qty_ptr floats. 6 | static float 7 | L2Sqr(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { 8 | float *pVect1 = (float *) pVect1v; 9 | float *pVect2 = (float *) pVect2v; 10 | size_t qty = *((size_t *) qty_ptr); 11 | 12 | float res = 0; 13 | for (size_t i = 0; i < qty; i++) { 14 | float t = *pVect1 - *pVect2; 15 |
pVect1++; 16 | pVect2++; 17 | res += t * t; 18 | } 19 | return (res); 20 | } 21 | // AVX-512 kernel below: squared L2 over whole 16-float blocks only (the first 16*(qty/16) elements). 22 | #if defined(USE_AVX512) 23 | 24 | // Favor using AVX512 if available. 25 | static float 26 | L2SqrSIMD16ExtAVX512(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { 27 | float *pVect1 = (float *) pVect1v; 28 | float *pVect2 = (float *) pVect2v; 29 | size_t qty = *((size_t *) qty_ptr); 30 | float PORTABLE_ALIGN64 TmpRes[16]; 31 | size_t qty16 = qty >> 4; 32 | 33 | const float *pEnd1 = pVect1 + (qty16 << 4); 34 | 35 | __m512 diff, v1, v2; 36 | __m512 sum = _mm512_set1_ps(0); 37 | 38 | while (pVect1 < pEnd1) { 39 | v1 = _mm512_loadu_ps(pVect1); 40 | pVect1 += 16; 41 | v2 = _mm512_loadu_ps(pVect2); 42 | pVect2 += 16; 43 | diff = _mm512_sub_ps(v1, v2); 44 | // sum = _mm512_fmadd_ps(diff, diff, sum); 45 | sum = _mm512_add_ps(sum, _mm512_mul_ps(diff, diff)); 46 | } 47 | // Horizontal sum of the 16 lanes via an aligned spill. 48 | _mm512_store_ps(TmpRes, sum); 49 | float res = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + 50 | TmpRes[7] + TmpRes[8] + TmpRes[9] + TmpRes[10] + TmpRes[11] + TmpRes[12] + 51 | TmpRes[13] + TmpRes[14] + TmpRes[15]; 52 | 53 | return (res); 54 | } 55 | #endif 56 | 57 | #if defined(USE_AVX) 58 | 59 | // Favor using AVX if available.
60 | static float 61 | L2SqrSIMD16ExtAVX(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { 62 | float *pVect1 = (float *) pVect1v; 63 | float *pVect2 = (float *) pVect2v; 64 | size_t qty = *((size_t *) qty_ptr); 65 | float PORTABLE_ALIGN32 TmpRes[8]; 66 | size_t qty16 = qty >> 4; 67 | // AVX kernel: squared L2 over whole 16-float blocks only, two 8-float loads per iteration. 68 | const float *pEnd1 = pVect1 + (qty16 << 4); 69 | 70 | __m256 diff, v1, v2; 71 | __m256 sum = _mm256_set1_ps(0); 72 | 73 | while (pVect1 < pEnd1) { 74 | v1 = _mm256_loadu_ps(pVect1); 75 | pVect1 += 8; 76 | v2 = _mm256_loadu_ps(pVect2); 77 | pVect2 += 8; 78 | diff = _mm256_sub_ps(v1, v2); 79 | sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff)); 80 | 81 | v1 = _mm256_loadu_ps(pVect1); 82 | pVect1 += 8; 83 | v2 = _mm256_loadu_ps(pVect2); 84 | pVect2 += 8; 85 | diff = _mm256_sub_ps(v1, v2); 86 | sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff)); 87 | } 88 | // Horizontal sum of the 8 AVX lanes. 89 | _mm256_store_ps(TmpRes, sum); 90 | return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7]; 91 | } 92 | 93 | #endif 94 | 95 | #if defined(USE_SSE) 96 | // SSE kernel: squared L2 over whole 16-float blocks only, four 4-float loads per iteration. 97 | static float 98 | L2SqrSIMD16ExtSSE(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { 99 | float *pVect1 = (float *) pVect1v; 100 | float *pVect2 = (float *) pVect2v; 101 | size_t qty = *((size_t *) qty_ptr); 102 | float PORTABLE_ALIGN32 TmpRes[8]; 103 | size_t qty16 = qty >> 4; 104 | 105 | const float *pEnd1 = pVect1 + (qty16 << 4); 106 | 107 | __m128 diff, v1, v2; 108 | __m128 sum = _mm_set1_ps(0); 109 | 110 | while (pVect1 < pEnd1) { 111 | //_mm_prefetch((char*)(pVect2 + 16), _MM_HINT_T0); 112 | v1 = _mm_loadu_ps(pVect1); 113 | pVect1 += 4; 114 | v2 = _mm_loadu_ps(pVect2); 115 | pVect2 += 4; 116 | diff = _mm_sub_ps(v1, v2); 117 | sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); 118 | 119 | v1 = _mm_loadu_ps(pVect1); 120 | pVect1 += 4; 121 | v2 = _mm_loadu_ps(pVect2); 122 | pVect2 += 4; 123 | diff = _mm_sub_ps(v1, v2); 124 | sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); 125 | 126 | v1 =
_mm_loadu_ps(pVect1); 127 | pVect1 += 4; 128 | v2 = _mm_loadu_ps(pVect2); 129 | pVect2 += 4; 130 | diff = _mm_sub_ps(v1, v2); 131 | sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); 132 | 133 | v1 = _mm_loadu_ps(pVect1); 134 | pVect1 += 4; 135 | v2 = _mm_loadu_ps(pVect2); 136 | pVect2 += 4; 137 | diff = _mm_sub_ps(v1, v2); 138 | sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); 139 | } 140 | // Horizontal sum of the 4 SSE lanes. 141 | _mm_store_ps(TmpRes, sum); 142 | return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3]; 143 | } 144 | #endif 145 | // NOTE(review): L2SqrSIMD16ExtSSE and L2SqrSIMD4Ext* are only defined under USE_SSE, but this block also compiles for USE_AVX/USE_AVX512; presumably hnswlib.h defines USE_SSE whenever AVX is enabled, so confirm. 146 | #if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512) 147 | static DISTFUNC L2SqrSIMD16Ext = L2SqrSIMD16ExtSSE; 148 | // Any dim: SIMD16 kernel over the first qty16 floats, scalar L2Sqr over the remaining qty%16. 149 | static float 150 | L2SqrSIMD16ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { 151 | size_t qty = *((size_t *) qty_ptr); 152 | size_t qty16 = qty >> 4 << 4; 153 | float res = L2SqrSIMD16Ext(pVect1v, pVect2v, &qty16); 154 | float *pVect1 = (float *) pVect1v + qty16; 155 | float *pVect2 = (float *) pVect2v + qty16; 156 | 157 | size_t qty_left = qty - qty16; 158 | float res_tail = L2Sqr(pVect1, pVect2, &qty_left); 159 | return (res + res_tail); 160 | } 161 | #endif 162 | 163 | // SSE kernel for multiples of 4: squared L2 over the first 4*(qty/4) floats. 164 | #if defined(USE_SSE) 165 | static float 166 | L2SqrSIMD4Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { 167 | float PORTABLE_ALIGN32 TmpRes[8]; 168 | float *pVect1 = (float *) pVect1v; 169 | float *pVect2 = (float *) pVect2v; 170 | size_t qty = *((size_t *) qty_ptr); 171 | 172 | 173 | size_t qty4 = qty >> 2; 174 | 175 | const float *pEnd1 = pVect1 + (qty4 << 2); 176 | 177 | __m128 diff, v1, v2; 178 | __m128 sum = _mm_set1_ps(0); 179 | 180 | while (pVect1 < pEnd1) { 181 | v1 = _mm_loadu_ps(pVect1); 182 | pVect1 += 4; 183 | v2 = _mm_loadu_ps(pVect2); 184 | pVect2 += 4; 185 | diff = _mm_sub_ps(v1, v2); 186 | sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); 187 | } 188 | _mm_store_ps(TmpRes, sum); 189 | return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3]; 190 | } 191 | // Any dim: SIMD4 kernel over the first qty4 floats, scalar L2Sqr over the remaining qty%4. 192 | static float 193 |
L2SqrSIMD4ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { 194 | size_t qty = *((size_t *) qty_ptr); 195 | size_t qty4 = qty >> 2 << 2; 196 | 197 | float res = L2SqrSIMD4Ext(pVect1v, pVect2v, &qty4); 198 | size_t qty_left = qty - qty4; 199 | 200 | float *pVect1 = (float *) pVect1v + qty4; 201 | float *pVect2 = (float *) pVect2v + qty4; 202 | float res_tail = L2Sqr(pVect1, pVect2, &qty_left); 203 | 204 | return (res + res_tail); 205 | } 206 | #endif 207 | // Squared-L2 space over float vectors of length dim; the kernel is chosen once per construction. 208 | class L2Space : public SpaceInterface { 209 | DISTFUNC fstdistfunc_; 210 | size_t data_size_; 211 | size_t dim_; 212 | 213 | public: 214 | L2Space(size_t dim) { 215 | fstdistfunc_ = L2Sqr; 216 | #if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512) 217 | #if defined(USE_AVX512) 218 | if (AVX512Capable()) 219 | L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX512; 220 | else if (AVXCapable()) 221 | L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX; 222 | #elif defined(USE_AVX) 223 | if (AVXCapable()) 224 | L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX; 225 | #endif 226 | // Kernel choice: multiples of 16 or 4 use pure SIMD; other dims > 4 use SIMD plus a scalar tail; dims 1-3 keep scalar L2Sqr. 227 | if (dim % 16 == 0) 228 | fstdistfunc_ = L2SqrSIMD16Ext; 229 | else if (dim % 4 == 0) 230 | fstdistfunc_ = L2SqrSIMD4Ext; 231 | else if (dim > 16) 232 | fstdistfunc_ = L2SqrSIMD16ExtResiduals; 233 | else if (dim > 4) 234 | fstdistfunc_ = L2SqrSIMD4ExtResiduals; 235 | #endif 236 | dim_ = dim; 237 | data_size_ = dim * sizeof(float); 238 | } 239 | 240 | size_t get_data_size() { 241 | return data_size_; 242 | } 243 | 244 | DISTFUNC get_dist_func() { 245 | return fstdistfunc_; 246 | } 247 | // Returns &dim_ (a size_t); presumably passed to the kernels as their qty_ptr element count. 248 | void *get_dist_func_param() { 249 | return &dim_; 250 | } 251 | 252 | ~L2Space() {} 253 | }; 254 | // Squared L2 over unsigned-char vectors, unrolled by 4; only the first 4*(qty/4) bytes are compared. 255 | static int 256 | L2SqrI4x(const void *__restrict pVect1, const void *__restrict pVect2, const void *__restrict qty_ptr) { 257 | size_t qty = *((size_t *) qty_ptr); 258 | int res = 0; 259 | unsigned char *a = (unsigned char *) pVect1; 260 | unsigned char *b = (unsigned char *) pVect2; 261 | 262 | qty = qty >> 2; 263 | for (size_t i = 0; i < qty; i++)
{ 264 | res += ((*a) - (*b)) * ((*a) - (*b)); 265 | a++; 266 | b++; 267 | res += ((*a) - (*b)) * ((*a) - (*b)); 268 | a++; 269 | b++; 270 | res += ((*a) - (*b)) * ((*a) - (*b)); 271 | a++; 272 | b++; 273 | res += ((*a) - (*b)) * ((*a) - (*b)); 274 | a++; 275 | b++; 276 | } 277 | return (res); 278 | } 279 | // Scalar squared L2 over qty unsigned-char elements (any qty). 280 | static int L2SqrI(const void* __restrict pVect1, const void* __restrict pVect2, const void* __restrict qty_ptr) { 281 | size_t qty = *((size_t*)qty_ptr); 282 | int res = 0; 283 | unsigned char* a = (unsigned char*)pVect1; 284 | unsigned char* b = (unsigned char*)pVect2; 285 | 286 | for (size_t i = 0; i < qty; i++) { 287 | res += ((*a) - (*b)) * ((*a) - (*b)); 288 | a++; 289 | b++; 290 | } 291 | return (res); 292 | } 293 | // Squared-L2 space over unsigned-char vectors; uses the 4x-unrolled kernel only when dim % 4 == 0. 294 | class L2SpaceI : public SpaceInterface { 295 | DISTFUNC fstdistfunc_; 296 | size_t data_size_; 297 | size_t dim_; 298 | 299 | public: 300 | L2SpaceI(size_t dim) { 301 | if (dim % 4 == 0) { 302 | fstdistfunc_ = L2SqrI4x; 303 | } else { 304 | fstdistfunc_ = L2SqrI; 305 | } 306 | dim_ = dim; 307 | data_size_ = dim * sizeof(unsigned char); 308 | } 309 | 310 | size_t get_data_size() { 311 | return data_size_; 312 | } 313 | 314 | DISTFUNC get_dist_func() { 315 | return fstdistfunc_; 316 | } 317 | 318 | void *get_dist_func_param() { 319 | return &dim_; 320 | } 321 | 322 | ~L2SpaceI() {} 323 | }; 324 | } // namespace hnswlib 325 | -------------------------------------------------------------------------------- /3rd_party/hnswlib/stop_condition.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "space_l2.h" 3 | #include "space_ip.h" 4 | #include 5 | #include 6 | 7 | namespace hnswlib { 8 | // Space whose data points carry a document id, read and written through get_doc_id/set_doc_id. 9 | template 10 | class BaseMultiVectorSpace : public SpaceInterface { 11 | public: 12 | virtual DOCIDTYPE get_doc_id(const void *datapoint) = 0; 13 | 14 | virtual void set_doc_id(void *datapoint, DOCIDTYPE doc_id) = 0; 15 | }; 16 | 17 | 18 | template 19 | class MultiVectorL2Space : public
BaseMultiVectorSpace { 20 | DISTFUNC fstdistfunc_; 21 | size_t data_size_; 22 | size_t vector_size_; 23 | size_t dim_; 24 | // Data point layout: dim floats followed by a DOCIDTYPE doc id (data_size_ = vector_size_ + sizeof(DOCIDTYPE)). 25 | public: 26 | MultiVectorL2Space(size_t dim) { 27 | fstdistfunc_ = L2Sqr; 28 | #if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512) 29 | #if defined(USE_AVX512) 30 | if (AVX512Capable()) 31 | L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX512; 32 | else if (AVXCapable()) 33 | L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX; 34 | #elif defined(USE_AVX) 35 | if (AVXCapable()) 36 | L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX; 37 | #endif 38 | 39 | if (dim % 16 == 0) 40 | fstdistfunc_ = L2SqrSIMD16Ext; 41 | else if (dim % 4 == 0) 42 | fstdistfunc_ = L2SqrSIMD4Ext; 43 | else if (dim > 16) 44 | fstdistfunc_ = L2SqrSIMD16ExtResiduals; 45 | else if (dim > 4) 46 | fstdistfunc_ = L2SqrSIMD4ExtResiduals; 47 | #endif 48 | dim_ = dim; 49 | vector_size_ = dim * sizeof(float); 50 | data_size_ = vector_size_ + sizeof(DOCIDTYPE); 51 | } 52 | 53 | size_t get_data_size() override { 54 | return data_size_; 55 | } 56 | 57 | DISTFUNC get_dist_func() override { 58 | return fstdistfunc_; 59 | } 60 | 61 | void *get_dist_func_param() override { 62 | return &dim_; 63 | } 64 | // The doc id is read and written at byte offset vector_size_, right after the float payload. 65 | DOCIDTYPE get_doc_id(const void *datapoint) override { 66 | return *(DOCIDTYPE *)((char *)datapoint + vector_size_); 67 | } 68 | 69 | void set_doc_id(void *datapoint, DOCIDTYPE doc_id) override { 70 | *(DOCIDTYPE*)((char *)datapoint + vector_size_) = doc_id; 71 | } 72 | 73 | ~MultiVectorL2Space() {} 74 | }; 75 | 76 | // Inner-product counterpart of MultiVectorL2Space with the same data-point layout. 77 | template 78 | class MultiVectorInnerProductSpace : public BaseMultiVectorSpace { 79 | DISTFUNC fstdistfunc_; 80 | size_t data_size_; 81 | size_t vector_size_; 82 | size_t dim_; 83 | 84 | public: 85 | MultiVectorInnerProductSpace(size_t dim) { 86 | fstdistfunc_ = InnerProductDistance; 87 | #if defined(USE_AVX) || defined(USE_SSE) || defined(USE_AVX512) 88 | #if defined(USE_AVX512) 89 | if (AVX512Capable()) { 90 | InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX512; 91 |
InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtAVX512; 92 | } else if (AVXCapable()) { 93 | InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX; 94 | InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtAVX; 95 | } 96 | #elif defined(USE_AVX) 97 | if (AVXCapable()) { 98 | InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX; 99 | InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtAVX; 100 | } 101 | #endif 102 | #if defined(USE_AVX) 103 | if (AVXCapable()) { 104 | InnerProductSIMD4Ext = InnerProductSIMD4ExtAVX; 105 | InnerProductDistanceSIMD4Ext = InnerProductDistanceSIMD4ExtAVX; 106 | } 107 | #endif 108 | 109 | if (dim % 16 == 0) 110 | fstdistfunc_ = InnerProductDistanceSIMD16Ext; 111 | else if (dim % 4 == 0) 112 | fstdistfunc_ = InnerProductDistanceSIMD4Ext; 113 | else if (dim > 16) 114 | fstdistfunc_ = InnerProductDistanceSIMD16ExtResiduals; 115 | else if (dim > 4) 116 | fstdistfunc_ = InnerProductDistanceSIMD4ExtResiduals; 117 | #endif // dim_ must be set here (as MultiVectorL2Space does): get_dist_func_param() returns &dim_, which the kernels presumably receive as their element count 118 | dim_ = dim; vector_size_ = dim * sizeof(float); 119 | data_size_ = vector_size_ + sizeof(DOCIDTYPE); 120 | } 121 | 122 | size_t get_data_size() override { 123 | return data_size_; 124 | } 125 | 126 | DISTFUNC get_dist_func() override { 127 | return fstdistfunc_; 128 | } 129 | 130 | void *get_dist_func_param() override { 131 | return &dim_; 132 | } 133 | 134 | DOCIDTYPE get_doc_id(const void *datapoint) override { 135 | return *(DOCIDTYPE *)((char *)datapoint + vector_size_); 136 | } 137 | 138 | void set_doc_id(void *datapoint, DOCIDTYPE doc_id) override { 139 | *(DOCIDTYPE*)((char *)datapoint + vector_size_) = doc_id; 140 | } 141 | 142 | ~MultiVectorInnerProductSpace() {} 143 | }; 144 | 145 | // Stop condition that counts distinct documents rather than individual points. 146 | template 147 | class MultiVectorSearchStopCondition : public BaseSearchStopCondition { 148 | size_t curr_num_docs_; 149 | size_t num_docs_to_search_; 150 | size_t ef_collection_; 151 | std::unordered_map doc_counter_; 152 | std::priority_queue> search_results_; 153 | BaseMultiVectorSpace& space_; 154 | 155 | public:
156 | MultiVectorSearchStopCondition( 157 | BaseMultiVectorSpace& space, 158 | size_t num_docs_to_search, 159 | size_t ef_collection = 10) 160 | : space_(space) { 161 | curr_num_docs_ = 0; 162 | num_docs_to_search_ = num_docs_to_search; 163 | ef_collection_ = std::max(ef_collection, num_docs_to_search); 164 | } 165 | // doc_counter_ counts result points per doc; curr_num_docs_ counts docs with at least one point. 166 | void add_point_to_result(labeltype label, const void *datapoint, dist_t dist) override { 167 | DOCIDTYPE doc_id = space_.get_doc_id(datapoint); 168 | if (doc_counter_[doc_id] == 0) { 169 | curr_num_docs_ += 1; 170 | } 171 | search_results_.emplace(dist, doc_id); 172 | doc_counter_[doc_id] += 1; 173 | } 174 | // NOTE(review): pops search_results_.top() regardless of which point is removed; assumes the removed point is the one at the top. Confirm against callers. 175 | void remove_point_from_result(labeltype label, const void *datapoint, dist_t dist) override { 176 | DOCIDTYPE doc_id = space_.get_doc_id(datapoint); 177 | doc_counter_[doc_id] -= 1; 178 | if (doc_counter_[doc_id] == 0) { 179 | curr_num_docs_ -= 1; 180 | } 181 | search_results_.pop(); 182 | } 183 | // Stop once ef_collection_ distinct docs are held and the candidate is farther than lowerBound. 184 | bool should_stop_search(dist_t candidate_dist, dist_t lowerBound) override { 185 | bool stop_search = candidate_dist > lowerBound && curr_num_docs_ == ef_collection_; 186 | return stop_search; 187 | } 188 | 189 | bool should_consider_candidate(dist_t candidate_dist, dist_t lowerBound) override { 190 | bool flag_consider_candidate = curr_num_docs_ < ef_collection_ || lowerBound > candidate_dist; 191 | return flag_consider_candidate; 192 | } 193 | 194 | bool should_remove_extra() override { 195 | bool flag_remove_extra = curr_num_docs_ > ef_collection_; 196 | return flag_remove_extra; 197 | } 198 | // Pop trailing candidates until at most num_docs_to_search_ docs remain; the assert checks both containers agree on the dropped element. 199 | void filter_results(std::vector> &candidates) override { 200 | while (curr_num_docs_ > num_docs_to_search_) { 201 | dist_t dist_cand = candidates.back().first; 202 | dist_t dist_res = search_results_.top().first; 203 | assert(dist_cand == dist_res); 204 | DOCIDTYPE doc_id = search_results_.top().second; 205 | doc_counter_[doc_id] -= 1; 206 | if (doc_counter_[doc_id] == 0) { 207 | curr_num_docs_ -= 1; 208 | } 209 | search_results_.pop();
210 | candidates.pop_back(); 211 | } 212 | } 213 | 214 | ~MultiVectorSearchStopCondition() {} 215 | }; 216 | 217 | // Stop condition bounded by a distance threshold (epsilon_) plus min/max result counts. 218 | template 219 | class EpsilonSearchStopCondition : public BaseSearchStopCondition { 220 | float epsilon_; 221 | size_t min_num_candidates_; 222 | size_t max_num_candidates_; 223 | size_t curr_num_items_; 224 | 225 | public: 226 | EpsilonSearchStopCondition(float epsilon, size_t min_num_candidates, size_t max_num_candidates) { 227 | assert(min_num_candidates <= max_num_candidates); 228 | epsilon_ = epsilon; 229 | min_num_candidates_ = min_num_candidates; 230 | max_num_candidates_ = max_num_candidates; 231 | curr_num_items_ = 0; 232 | } 233 | 234 | void add_point_to_result(labeltype label, const void *datapoint, dist_t dist) override { 235 | curr_num_items_ += 1; 236 | } 237 | 238 | void remove_point_from_result(labeltype label, const void *datapoint, dist_t dist) override { 239 | curr_num_items_ -= 1; 240 | } 241 | 242 | bool should_stop_search(dist_t candidate_dist, dist_t lowerBound) override { 243 | if (candidate_dist > lowerBound && curr_num_items_ == max_num_candidates_) { 244 | // new candidate can't improve found results 245 | return true; 246 | } 247 | if (candidate_dist > epsilon_ && curr_num_items_ >= min_num_candidates_) { 248 | // new candidate is out of epsilon region and 249 | // minimum number of candidates is checked 250 | return true; 251 | } 252 | return false; 253 | } 254 | 255 | bool should_consider_candidate(dist_t candidate_dist, dist_t lowerBound) override { 256 | bool flag_consider_candidate = curr_num_items_ < max_num_candidates_ || lowerBound > candidate_dist; 257 | return flag_consider_candidate; 258 | } 259 | // Marked override like its siblings (clang warns under -Winconsistent-missing-override otherwise). 260 | bool should_remove_extra() override { 261 | bool flag_remove_extra = curr_num_items_ > max_num_candidates_; 262 | return flag_remove_extra; 263 | } 264 | // Drop trailing candidates beyond epsilon_, then cap the count at max_num_candidates_. 265 | void filter_results(std::vector> &candidates) override { 266 | while (!candidates.empty() && candidates.back().first > epsilon_) { 267 | candidates.pop_back(); 268 |
} 269 | while (candidates.size() > max_num_candidates_) { 270 | candidates.pop_back(); 271 | } 272 | } 273 | 274 | ~EpsilonSearchStopCondition() {} 275 | }; 276 | } // namespace hnswlib 277 | -------------------------------------------------------------------------------- /3rd_party/hnswlib/visited_list_pool.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | namespace hnswlib { 8 | typedef unsigned short int vl_type; 9 | // NOTE(review): mass[] holds per-element epoch stamps; presumably an element counts as visited when mass[i] == curV (callers not shown here). 10 | class VisitedList { 11 | public: 12 | vl_type curV; 13 | vl_type *mass; 14 | unsigned int numelements; 15 | 16 | VisitedList(int numelements1) { 17 | curV = -1; 18 | numelements = numelements1; 19 | mass = new vl_type[numelements]; 20 | } 21 | // Start a new epoch; on wrap-around to 0 the array is cleared. curV starts at (vl_type)-1, so the first reset() always clears the uninitialized array. 22 | void reset() { 23 | curV++; 24 | if (curV == 0) { 25 | memset(mass, 0, sizeof(vl_type) * numelements); 26 | curV++; 27 | } 28 | } 29 | 30 | ~VisitedList() { delete[] mass; } 31 | }; 32 | /////////////////////////////////////////////////////////// 33 | // 34 | // Class for multi-threaded pool-management of VisitedLists 35 | // 36 | ///////////////////////////////////////////////////////// 37 | 38 | class VisitedListPool { 39 | std::deque pool; 40 | std::mutex poolguard; 41 | int numelements; 42 | 43 | public: 44 | VisitedListPool(int initmaxpools, int numelements1) { 45 | numelements = numelements1; 46 | for (int i = 0; i < initmaxpools; i++) 47 | pool.push_front(new VisitedList(numelements)); 48 | } 49 | // Takes a pooled list under poolguard (or allocates a new one), then resets it outside the lock. 50 | VisitedList *getFreeVisitedList() { 51 | VisitedList *rez; 52 | { 53 | std::unique_lock lock(poolguard); 54 | if (pool.size() > 0) { 55 | rez = pool.front(); 56 | pool.pop_front(); 57 | } else { 58 | rez = new VisitedList(numelements); 59 | } 60 | } 61 | rez->reset(); 62 | return rez; 63 | } 64 | // Returns a list to the pool; pooled lists are deleted only by the destructor. 65 | void releaseVisitedList(VisitedList *vl) { 66 | std::unique_lock lock(poolguard); 67 | pool.push_front(vl); 68 | } 69 | 70 | ~VisitedListPool() { 71 | while (pool.size()) { 72 | VisitedList
*rez = pool.front(); 73 | pool.pop_front(); 74 | delete rez; 75 | } 76 | } 77 | }; 78 | } // namespace hnswlib 79 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.16 FATAL_ERROR) 2 | project(hnswlib_nif C CXX) 3 | # PRIV_DIR: install location of the NIF (MIX_APP_PATH/priv under Mix, otherwise ./priv next to this file). 4 | if(NOT DEFINED PRIV_DIR) 5 | if(DEFINED MIX_APP_PATH AND NOT "${MIX_APP_PATH}" STREQUAL "") 6 | if(WIN32) 7 | string(REPLACE "\\" "/" MIX_APP_PATH "${MIX_APP_PATH}") 8 | endif() 9 | set(PRIV_DIR "${MIX_APP_PATH}/priv") 10 | else() 11 | set(PRIV_DIR "${CMAKE_CURRENT_SOURCE_DIR}/priv") 12 | endif() 13 | endif() 14 | message(STATUS "Using PRIV_DIR: ${PRIV_DIR}") 15 | # Ask the local Erlang runtime for its erts include directory unless ERTS_INCLUDE_DIR was supplied. 16 | if(DEFINED ERTS_INCLUDE_DIR AND NOT "${ERTS_INCLUDE_DIR}" STREQUAL "") 17 | set(ERTS_INCLUDE_DIR "${ERTS_INCLUDE_DIR}") 18 | else() 19 | set(ERTS_INCLUDE_DIR_ONE_LINER "erl -noshell -eval \"io:format('~ts/erts-~ts/include/', [code:root_dir(), erlang:system_info(version)]), halt().\"") 20 | if(WIN32) 21 | execute_process(COMMAND powershell -command "${ERTS_INCLUDE_DIR_ONE_LINER}" OUTPUT_VARIABLE ERTS_INCLUDE_DIR) 22 | else() 23 | execute_process(COMMAND bash -c "${ERTS_INCLUDE_DIR_ONE_LINER}" OUTPUT_VARIABLE ERTS_INCLUDE_DIR) 24 | endif() 25 | set(ERTS_INCLUDE_DIR "${ERTS_INCLUDE_DIR}") 26 | endif() 27 | message(STATUS "Using ERTS_INCLUDE_DIR: ${ERTS_INCLUDE_DIR}") 28 | 29 | if(UNIX AND APPLE) 30 | set(CMAKE_SHARED_LINKER_FLAGS "-flat_namespace -undefined suppress -undefined dynamic_lookup") 31 | endif() 32 | 33 | set(CMAKE_CXX_STANDARD 14) 34 | if(DEFINED ENV{TARGET_GCC_FLAGS}) 35 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} $ENV{TARGET_GCC_FLAGS}") 36 | endif() 37 | 38 | message(STATUS "CMAKE_TOOLCHAIN_FILE: ${CMAKE_TOOLCHAIN_FILE}") 39 | # NOTE(review): C_SRC (and HNSWLIB_SRC below) are not set in this file; presumably passed in by the calling Makefile. 40 | if(WIN32) 41 | string(REPLACE "\\" "/" C_SRC "${C_SRC}") 42 | endif() 43 | set(SOURCE_FILES 44 | "${C_SRC}/nif_utils.cpp" 45 | "${C_SRC}/hnswlib_nif.cpp" 46 | ) 47 |
include_directories("${ERTS_INCLUDE_DIR}") 49 | include_directories("${HNSWLIB_SRC}") 50 | # Built as a SHARED library; PREFIX/SUFFIX below give it a plain hnswlib_nif name (with a .so suffix on non-Windows hosts). 51 | add_library(hnswlib_nif SHARED 52 | ${SOURCE_FILES} 53 | ) 54 | install( 55 | TARGETS hnswlib_nif 56 | RUNTIME DESTINATION "${PRIV_DIR}" LIBRARY DESTINATION "${PRIV_DIR}" 57 | ) 58 | # RUNTIME only covers DLLs on Windows; on Unix-like hosts a SHARED library is a LIBRARY artifact, so both destinations point at PRIV_DIR. 59 | set_target_properties(hnswlib_nif PROPERTIES PREFIX "") 60 | if(NOT WIN32) 61 | set_target_properties(hnswlib_nif PROPERTIES SUFFIX ".so") 62 | endif() 63 | set_target_properties(hnswlib_nif PROPERTIES 64 | INSTALL_RPATH_USE_LINK_PATH TRUE 65 | BUILD_WITH_INSTALL_RPATH TRUE 66 | ) 67 | 68 | if(UNIX AND NOT APPLE) 69 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wno-unused-but-set-variable -Wno-reorder") 70 | elseif(UNIX AND APPLE) 71 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wno-reorder-ctor") 72 | set(CMAKE_SHARED_LINKER_FLAGS "-flat_namespace -undefined suppress -undefined dynamic_lookup") 73 | endif() 74 | 75 | if(WIN32) 76 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /bigobj /wd4996 /wd4267 /wd4068") 77 | else() 78 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC") 79 | if (CMAKE_BUILD_TYPE STREQUAL "Debug") 80 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O0 -g3") 81 | else() 82 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3") 83 | endif() 84 | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-function -Wno-sign-compare -Wno-unused-parameter -Wno-missing-field-initializers -Wno-deprecated-declarations") 85 | endif() 86 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document.
11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | ifndef MIX_APP_PATH 2 | MIX_APP_PATH=$(shell pwd) 3 | endif 4 | 5 | PRIV_DIR = $(MIX_APP_PATH)/priv 6 | NIF_SO = $(PRIV_DIR)/hnswlib_nif.so 7 | HNSWLIB_SRC = $(shell pwd)/3rd_party/hnswlib 8 | C_SRC = $(shell pwd)/c_src 9 | 10 | ifdef CC_PRECOMPILER_CURRENT_TARGET 11 | ifeq ($(findstring darwin, $(CC_PRECOMPILER_CURRENT_TARGET)), darwin) 12 | ifeq ($(findstring aarch64, $(CC_PRECOMPILER_CURRENT_TARGET)), aarch64) 13 | CMAKE_CONFIGURE_FLAGS=-D CMAKE_OSX_ARCHITECTURES=arm64 14 | else 15 | CMAKE_CONFIGURE_FLAGS=-D CMAKE_OSX_ARCHITECTURES=x86_64 16 | endif 17 | endif 18 | 19 | ifeq ($(findstring manylinux2014, $(HNSWLIB_CI_PRECOMPILE)), manylinux2014) 20 | CC=gcc 21 | CXX=g++ 22 | endif 23 | endif 24 | # Append (not assign) so the CMAKE_OSX_ARCHITECTURES flag chosen above is kept when a toolchain file is also given. 25 | ifdef CMAKE_TOOLCHAIN_FILE 26 | CMAKE_CONFIGURE_FLAGS += -D CMAKE_TOOLCHAIN_FILE="$(CMAKE_TOOLCHAIN_FILE)" 27 | endif 28 | 29 | CMAKE_BUILD_TYPE ?= Release 30 | DEFAULT_JOBS ?= 1 31 | CMAKE_HNSWLIB_BUILD_DIR = $(MIX_APP_PATH)/cmake_hnswlib 32 |
CMAKE_HNSWLIB_OPTIONS ?= "" 33 | MAKE_BUILD_FLAGS ?= -j$(DEFAULT_JOBS) 34 | # Make `build` the default target; .DEFAULT_GOAL is the GNU make special variable for this. 35 | .DEFAULT_GOAL := build 36 | 37 | build: $(NIF_SO) 38 | @echo > /dev/null 39 | 40 | $(NIF_SO): 41 | @ mkdir -p "$(PRIV_DIR)" 42 | @ if [ ! -f "${NIF_SO}" ]; then \ 43 | mkdir -p "$(CMAKE_HNSWLIB_BUILD_DIR)" && \ 44 | cd "$(CMAKE_HNSWLIB_BUILD_DIR)" && \ 45 | { cmake --no-warn-unused-cli \ 46 | -D CMAKE_BUILD_TYPE="$(CMAKE_BUILD_TYPE)" \ 47 | -D C_SRC="$(C_SRC)" \ 48 | -D HNSWLIB_SRC="$(HNSWLIB_SRC)" \ 49 | -D MIX_APP_PATH="$(MIX_APP_PATH)" \ 50 | -D PRIV_DIR="$(PRIV_DIR)" \ 51 | -D ERTS_INCLUDE_DIR="$(ERTS_INCLUDE_DIR)" \ 52 | $(CMAKE_CONFIGURE_FLAGS) $(CMAKE_HNSWLIB_OPTIONS) "$(shell pwd)" && \ 53 | make "$(MAKE_BUILD_FLAGS)" \ 54 | || { echo "\033[0;31mincomplete build of hnswlib found in '$(CMAKE_HNSWLIB_BUILD_DIR)', please delete that directory and retry\033[0m" && exit 1 ; } ; } \ 55 | && if [ "$(EVISION_PREFER_PRECOMPILED)" != "true" ]; then \ 56 | cp "$(CMAKE_HNSWLIB_BUILD_DIR)/hnswlib_nif.so" "$(NIF_SO)" ; \ 57 | fi ; \ 58 | fi 59 | # NOTE(review): EVISION_PREFER_PRECOMPILED above looks copied from the evision project; confirm the intended variable name. 60 | cleanup: 61 | @ rm -rf "$(CMAKE_HNSWLIB_BUILD_DIR)" 62 | @ rm -rf "$(NIF_SO)" 63 | -------------------------------------------------------------------------------- /Makefile.win: -------------------------------------------------------------------------------- 1 | !IFNDEF MIX_APP_PATH 2 | MIX_APP_PATH=$(MAKEDIR) 3 | !ENDIF 4 | 5 | PRIV_DIR = $(MIX_APP_PATH)/priv 6 | NIF_SO = $(PRIV_DIR)/hnswlib_nif.dll 7 | HNSWLIB_SRC = $(MAKEDIR)\3rd_party\hnswlib 8 | C_SRC = $(MAKEDIR)\c_src 9 | !IFDEF CMAKE_TOOLCHAIN_FILE 10 | CMAKE_CONFIGURE_FLAGS=-D CMAKE_TOOLCHAIN_FILE="$(CMAKE_TOOLCHAIN_FILE)" 11 | !ENDIF 12 | # CMAKE_BUILD_TYPE is also passed at configure time below: single-config generators (NMake Makefiles, Ninja) ignore --config. 13 | !IFNDEF CMAKE_BUILD_TYPE 14 | CMAKE_BUILD_TYPE = Release 15 | !ENDIF 16 | !IFNDEF CMAKE_GENERATOR_TYPE 17 | !IFNDEF MSBUILD_PLATFORM 18 | 19 | !IF "$(HAVE_NINJA)" == "true" 20 | CMAKE_GENERATOR_TYPE=Ninja 21 | !ELSE 22 | CMAKE_GENERATOR_TYPE=NMake Makefiles 23 | !ENDIF 24 | 25 | !ENDIF 26 | !ENDIF 27 | 28 | CMAKE_HNSWLIB_BUILD_DIR =
$(MIX_APP_PATH)/cmake_hnswlib 29 | # NOTE(review): CMAKE_HNSWLIB_OPTIONS is defined here but, unlike the Unix Makefile, is never passed to cmake below. 30 | !IFNDEF CMAKE_HNSWLIB_OPTIONS 31 | CMAKE_HNSWLIB_OPTIONS = "" 32 | !ENDIF 33 | 34 | CMAKE_BUILD_PARAMETER= --config "$(CMAKE_BUILD_TYPE)" 35 | 36 | build: $(NIF_SO) 37 | 38 | $(NIF_SO): 39 | @ if not exist "$(PRIV_DIR)" mkdir "$(PRIV_DIR)" 40 | @ if not exist "$(NIF_SO)" ( \ 41 | if not exist "$(CMAKE_HNSWLIB_BUILD_DIR)" mkdir "$(CMAKE_HNSWLIB_BUILD_DIR)" && \ 42 | cd "$(CMAKE_HNSWLIB_BUILD_DIR)" && \ 43 | cmake -G "$(CMAKE_GENERATOR_TYPE)" \ 44 | --no-warn-unused-cli \ 45 | -D CMAKE_BUILD_TYPE="$(CMAKE_BUILD_TYPE)" -D C_SRC="$(C_SRC)" \ 46 | -D HNSWLIB_SRC="$(HNSWLIB_SRC)" \ 47 | -D MIX_APP_PATH="$(MIX_APP_PATH)" \ 48 | -D PRIV_DIR="$(PRIV_DIR)" \ 49 | -D ERTS_INCLUDE_DIR="$(ERTS_INCLUDE_DIR)" \ 50 | $(CMAKE_CONFIGURE_FLAGS) "$(MAKEDIR)" && \ 51 | cmake --build . $(CMAKE_BUILD_PARAMETER) && \ 52 | cmake --install . $(CMAKE_BUILD_PARAMETER) \ 53 | ) 54 | 55 | cleanup: 56 | @ powershell -command "if (Test-Path \"$(NIF_SO)\" -PathType Leaf) { Remove-Item \"$(NIF_SO)\" }" 57 | @ powershell -command "if (Test-Path \"$(CMAKE_HNSWLIB_BUILD_DIR)\" ) { Remove-Item \"$(CMAKE_HNSWLIB_BUILD_DIR)\" -Recurse -Force }" 58 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # HNSWLib 2 | 3 | Elixir binding for the [hnswlib](https://github.com/nmslib/hnswlib) library. 4 | 5 | Currently in development, alpha software.
6 | 7 | ## Usage 8 | 9 | ### Create an Index 10 | 11 | ```elixir 12 | # working in L2-space 13 | # other possible values are 14 | # `:ip` (inner product) 15 | # `:cosine` (cosine similarity) 16 | iex> space = :l2 17 | :l2 18 | # each vector is a 2D-vec 19 | iex> dim = 2 20 | 2 21 | # limit the maximum elements to 200 22 | iex> max_elements = 200 23 | 200 24 | # create Index 25 | iex> {:ok, index} = HNSWLib.Index.new(space, dim, max_elements) 26 | {:ok, 27 | %HNSWLib.Index{ 28 | space: :l2, 29 | dim: 2, 30 | reference: #Reference<0.2548668725.3381002243.154990> 31 | }} 32 | ``` 33 | 34 | ### Add vectors to the Index 35 | 36 | ```elixir 37 | iex> data = 38 | Nx.tensor( 39 | [ 40 | [42, 42], 41 | [43, 43], 42 | [0, 0], 43 | [200, 200], 44 | [200, 220] 45 | ], 46 | type: :f32 47 | ) 48 | #Nx.Tensor< 49 | f32[5][2] 50 | [ 51 | [42.0, 42.0], 52 | [43.0, 43.0], 53 | [0.0, 0.0], 54 | [200.0, 200.0], 55 | [200.0, 220.0] 56 | ] 57 | > 58 | iex> HNSWLib.Index.get_current_count(index) 59 | {:ok, 0} 60 | iex> HNSWLib.Index.add_items(index, data) 61 | :ok 62 | iex> HNSWLib.Index.get_current_count(index) 63 | {:ok, 5} 64 | ``` 65 | 66 | ### Query nearest vector(s) in the index 67 | 68 | ```elixir 69 | # query 70 | iex> query = Nx.tensor([1, 2], type: :f32) 71 | #Nx.Tensor< 72 | f32[2] 73 | [1.0, 2.0] 74 | > 75 | iex> {:ok, labels, dists} = HNSWLib.Index.knn_query(index, query) 76 | {:ok, 77 | #Nx.Tensor< 78 | u64[1][1] 79 | [ 80 | [2] 81 | ] 82 | >, 83 | #Nx.Tensor< 84 | f32[1][1] 85 | [ 86 | [5.0] 87 | ] 88 | >} 89 | 90 | iex> {:ok, labels, dists} = HNSWLib.Index.knn_query(index, query, k: 3) 91 | {:ok, 92 | #Nx.Tensor< 93 | u64[1][3] 94 | [ 95 | [2, 0, 1] 96 | ] 97 | >, 98 | #Nx.Tensor< 99 | f32[1][3] 100 | [ 101 | [5.0, 3281.0, 3445.0] 102 | ] 103 | >} 104 | ``` 105 | 106 | ### Save an Index to file 107 | 108 | ```elixir 109 | iex> HNSWLib.Index.save_index(index, "my_index.bin") 110 | :ok 111 | ``` 112 | 113 | ### Load an Index from file 114 | 115 | ```elixir 116 | iex> {:ok, 
saved_index} = HNSWLib.Index.load_index(space, dim, "my_index.bin") 117 | {:ok, 118 | %HNSWLib.Index{ 119 | space: :l2, 120 | dim: 2, 121 | reference: #Reference<0.2105700569.2629697564.236704> 122 | }} 123 | iex> HNSWLib.Index.get_current_count(saved_index) 124 | {:ok, 5} 125 | iex> {:ok, data} = HNSWLib.Index.get_items(saved_index, [2, 0, 1]) 126 | {:ok, 127 | [ 128 | <<0, 0, 0, 0, 0, 0, 0, 0>>, 129 | <<0, 0, 40, 66, 0, 0, 40, 66>>, 130 | <<0, 0, 44, 66, 0, 0, 44, 66>> 131 | ]} 132 | iex> tensors = Nx.stack(Enum.map(data, fn d -> Nx.from_binary(d, :f32) end)) 133 | #Nx.Tensor< 134 | f32[3][2] 135 | [ 136 | [0.0, 0.0], 137 | [42.0, 42.0], 138 | [43.0, 43.0] 139 | ] 140 | > 141 | ``` 142 | 143 | ## Installation 144 | 145 | If [available in Hex](https://hex.pm/docs/publish), the package can be installed 146 | by adding `hnswlib` to your list of dependencies in `mix.exs`: 147 | 148 | ```elixir 149 | def deps do 150 | [ 151 | {:hnswlib, "~> 0.1.0"} 152 | ] 153 | end 154 | ``` 155 | 156 | Documentation can be generated with [ExDoc](https://github.com/elixir-lang/ex_doc) 157 | and published on [HexDocs](https://hexdocs.pm). Once published, the docs can 158 | be found at <https://hexdocs.pm/hnswlib>.
159 | 160 | -------------------------------------------------------------------------------- /c_src/hnswlib_index.hpp: -------------------------------------------------------------------------------- 1 | #ifndef HNSWLIB_INDEX_HPP 2 | #define HNSWLIB_INDEX_HPP 3 | 4 | #pragma once 5 | 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include "nif_utils.hpp" 15 | 16 | /* 17 | * replacement for the openmp '#pragma omp parallel for' directive 18 | * only handles a subset of functionality (no reductions etc) 19 | * Process ids from start (inclusive) to end (EXCLUSIVE) 20 | * 21 | * The method is borrowed from nmslib 22 | */ 23 | template 24 | inline void ParallelFor(size_t start, size_t end, size_t numThreads, Function fn) { 25 | if (numThreads <= 0) { 26 | numThreads = std::thread::hardware_concurrency(); 27 | } 28 | 29 | if (numThreads == 1) { 30 | for (size_t id = start; id < end; id++) { 31 | fn(id, 0); 32 | } 33 | } else { 34 | std::vector threads; 35 | std::atomic current(start); 36 | 37 | // keep track of exceptions in threads 38 | // https://stackoverflow.com/a/32428427/1713196 39 | std::exception_ptr lastException = nullptr; 40 | std::mutex lastExceptMutex; 41 | 42 | for (size_t threadId = 0; threadId < numThreads; ++threadId) { 43 | threads.push_back(std::thread([&, threadId] { 44 | while (true) { 45 | size_t id = current.fetch_add(1); 46 | 47 | if (id >= end) { 48 | break; 49 | } 50 | 51 | try { 52 | fn(id, threadId); 53 | } catch (...) { 54 | std::unique_lock lastExcepLock(lastExceptMutex); 55 | lastException = std::current_exception(); 56 | /* 57 | * This will work even when current is the largest value that 58 | * size_t can fit, because fetch_add returns the previous value 59 | * before the increment (what will result in overflow 60 | * and produce 0 instead of current + 1). 
61 | */ 62 | current = end; 63 | break; 64 | } 65 | } 66 | })); 67 | } 68 | for (auto &thread : threads) { 69 | thread.join(); 70 | } 71 | if (lastException) { 72 | std::rethrow_exception(lastException); 73 | } 74 | } 75 | } 76 | 77 | 78 | inline void assert_true(bool expr, const std::string & msg) { 79 | if (expr == false) throw std::runtime_error("Unpickle Error: " + msg); 80 | return; 81 | } 82 | 83 | 84 | class CustomFilterFunctor: public hnswlib::BaseFilterFunctor { 85 | std::function filter; 86 | 87 | public: 88 | explicit CustomFilterFunctor(const std::function& f) { 89 | filter = f; 90 | } 91 | 92 | bool operator()(hnswlib::labeltype id) { 93 | return filter(id); 94 | } 95 | }; 96 | 97 | template 98 | class Index { 99 | public: 100 | static const int ser_version = 1; // serialization version 101 | 102 | std::string space_name; 103 | int dim; 104 | size_t seed; 105 | size_t default_ef; 106 | 107 | bool index_inited; 108 | bool ep_added; 109 | bool normalize; 110 | int num_threads_default; 111 | hnswlib::labeltype cur_l; 112 | hnswlib::HierarchicalNSW* appr_alg; 113 | hnswlib::SpaceInterface* l2space; 114 | 115 | 116 | Index(const std::string &space_name, const int dim) : space_name(space_name), dim(dim) { 117 | normalize = false; 118 | if (space_name == "l2") { 119 | l2space = new hnswlib::L2Space(dim); 120 | } else if (space_name == "ip") { 121 | l2space = new hnswlib::InnerProductSpace(dim); 122 | } else if (space_name == "cosine") { 123 | l2space = new hnswlib::InnerProductSpace(dim); 124 | normalize = true; 125 | } else { 126 | throw std::runtime_error("Space name must be one of l2, ip, or cosine."); 127 | } 128 | appr_alg = NULL; 129 | ep_added = true; 130 | index_inited = false; 131 | num_threads_default = std::thread::hardware_concurrency(); 132 | 133 | default_ef = 10; 134 | } 135 | 136 | 137 | ~Index() { 138 | delete l2space; 139 | if (appr_alg) 140 | delete appr_alg; 141 | } 142 | 143 | 144 | void init_new_index( 145 | size_t maxElements, 146 | 
size_t M, 147 | size_t efConstruction, 148 | size_t random_seed, 149 | bool allow_replace_deleted) { 150 | if (appr_alg) { 151 | throw std::runtime_error("The index is already initiated."); 152 | } 153 | cur_l = 0; 154 | appr_alg = new hnswlib::HierarchicalNSW(l2space, maxElements, M, efConstruction, random_seed, allow_replace_deleted); 155 | index_inited = true; 156 | ep_added = false; 157 | appr_alg->ef_ = default_ef; 158 | seed = random_seed; 159 | } 160 | 161 | 162 | void set_ef(size_t ef) { 163 | default_ef = ef; 164 | if (appr_alg) 165 | appr_alg->ef_ = ef; 166 | } 167 | 168 | 169 | void set_num_threads(int num_threads) { 170 | this->num_threads_default = num_threads; 171 | } 172 | 173 | size_t indexFileSize() const { 174 | return appr_alg->indexFileSize(); 175 | } 176 | 177 | void saveIndex(const std::string &path_to_index) { 178 | appr_alg->saveIndex(path_to_index); 179 | } 180 | 181 | 182 | void loadIndex(const std::string &path_to_index, size_t max_elements, bool allow_replace_deleted) { 183 | if (appr_alg) { 184 | fprintf(stderr, "Warning: Calling load_index for an already inited index. 
Old index is being deallocated.\r\n"); 185 | delete appr_alg; 186 | } 187 | appr_alg = new hnswlib::HierarchicalNSW(l2space, path_to_index, false, max_elements, allow_replace_deleted); 188 | cur_l = appr_alg->cur_element_count; 189 | index_inited = true; 190 | } 191 | 192 | 193 | void normalize_vector(float* data, float* norm_array) { 194 | float norm = 0.0f; 195 | for (int i = 0; i < dim; i++) 196 | norm += data[i] * data[i]; 197 | norm = 1.0f / (sqrtf(norm) + 1e-30f); 198 | for (int i = 0; i < dim; i++) 199 | norm_array[i] = data[i] * norm; 200 | } 201 | 202 | 203 | void addItems(float * input, size_t rows, size_t features, const uint64_t * ids, size_t ids_count, int num_threads = -1, bool replace_deleted = false) { 204 | if (num_threads <= 0) 205 | num_threads = num_threads_default; 206 | 207 | if (features != dim) 208 | throw std::runtime_error("Wrong dimensionality of the vectors"); 209 | 210 | // avoid using threads when the number of additions is small: 211 | if (rows <= num_threads * 4) { 212 | num_threads = 1; 213 | } 214 | 215 | { 216 | int start = 0; 217 | if (!ep_added) { 218 | uint64_t id = ids_count ? ids[0] : (cur_l); 219 | float* vector_data = input; 220 | std::vector norm_array(dim); 221 | if (normalize) { 222 | normalize_vector(vector_data, norm_array.data()); 223 | vector_data = norm_array.data(); 224 | } 225 | appr_alg->addPoint((void *)vector_data, (size_t)id, replace_deleted); 226 | start = 1; 227 | ep_added = true; 228 | } 229 | 230 | if (normalize == false) { 231 | ParallelFor(start, rows, num_threads, [&](size_t row, size_t threadId) { 232 | uint64_t id = ids_count ? 
ids[row] : (cur_l + row); 233 | appr_alg->addPoint((void *)(input + row * dim), (size_t)id, replace_deleted); 234 | }); 235 | } else { 236 | std::vector norm_array(num_threads * dim); 237 | ParallelFor(start, rows, num_threads, [&](size_t row, size_t threadId) { 238 | // normalize vector: 239 | size_t start_idx = threadId * dim; 240 | normalize_vector((float *)(input + row * dim), (norm_array.data() + start_idx)); 241 | 242 | uint64_t id = ids_count ? ids[row] : (cur_l + row); 243 | appr_alg->addPoint((void *)(norm_array.data() + start_idx), (size_t)id, replace_deleted); 244 | }); 245 | } 246 | cur_l += rows; 247 | } 248 | } 249 | 250 | 251 | std::vector> getDataReturnList(const uint64_t* ids, size_t ids_count) { 252 | std::vector> data; 253 | for (size_t i = 0; i < ids_count; i++) { 254 | data.push_back(appr_alg->template getDataByLabel((size_t)ids[i])); 255 | } 256 | return data; 257 | } 258 | 259 | 260 | std::vector getIdsList() { 261 | std::vector ids; 262 | 263 | for (auto kv : appr_alg->label_lookup_) { 264 | ids.push_back(kv.first); 265 | } 266 | std::sort(ids.begin(), ids.end()); 267 | return ids; 268 | } 269 | 270 | 271 | // return true if no error, false otherwise (the `{:error, reason}`-tuple will be saved in `out`) 272 | bool knnQuery( 273 | ErlNifEnv * env, 274 | float * input, 275 | size_t rows, 276 | size_t features, 277 | size_t k, 278 | int num_threads, 279 | // const std::function& filter, 280 | ERL_NIF_TERM& out) { 281 | ErlNifBinary data_l_bin; 282 | ErlNifBinary data_d_bin; 283 | 284 | hnswlib::labeltype* data_l; 285 | dist_t* data_d; 286 | 287 | if (num_threads <= 0) { 288 | num_threads = num_threads_default; 289 | } 290 | 291 | // avoid using threads when the number of searches is small: 292 | if (rows <= num_threads * 4) { 293 | num_threads = 1; 294 | } 295 | 296 | if (!enif_alloc_binary(sizeof(hnswlib::labeltype) * rows * k, &data_l_bin)) { 297 | out = hnswlib_error(env, "out of memory for storing labels"); 298 | return false; 299 | } 300 | 
data_l = (hnswlib::labeltype *)data_l_bin.data; 301 | 302 | if (!enif_alloc_binary(sizeof(dist_t) * rows * k, &data_d_bin)) { 303 | enif_release_binary(&data_l_bin); 304 | out = hnswlib_error(env, "out of memory for storing distances"); 305 | return false; 306 | } 307 | data_d = (dist_t *)data_d_bin.data; 308 | 309 | // CustomFilterFunctor idFilter; 310 | CustomFilterFunctor* p_idFilter = nullptr; 311 | 312 | try { 313 | if (normalize == false) { 314 | ParallelFor(0, rows, num_threads, [&](size_t row, size_t threadId) { 315 | std::priority_queue> result = appr_alg->searchKnn( 316 | (void *)(input + row * features), k, p_idFilter); 317 | if (result.size() != k) { 318 | throw std::runtime_error( 319 | "Cannot return the results in a contigious 2D array. Probably ef or M is too small"); 320 | } 321 | 322 | for (int i = k - 1; i >= 0; i--) { 323 | auto& result_tuple = result.top(); 324 | data_d[row * k + i] = result_tuple.first; 325 | data_l[row * k + i] = result_tuple.second; 326 | result.pop(); 327 | } 328 | }); 329 | } else { 330 | std::vector norm_array(num_threads * features); 331 | ParallelFor(0, rows, num_threads, [&](size_t row, size_t threadId) { 332 | float* data = input + row * features; 333 | 334 | size_t start_idx = threadId * dim; 335 | normalize_vector(data, (norm_array.data() + start_idx)); 336 | 337 | std::priority_queue> result = appr_alg->searchKnn( 338 | (void*)(norm_array.data() + start_idx), k, p_idFilter); 339 | if (result.size() != k) { 340 | throw std::runtime_error( 341 | "Cannot return the results in a contigious 2D array. 
Probably ef or M is too small"); 342 | } 343 | 344 | for (int i = k - 1; i >= 0; i--) { 345 | auto& result_tuple = result.top(); 346 | data_d[row * k + i] = result_tuple.first; 347 | data_l[row * k + i] = result_tuple.second; 348 | result.pop(); 349 | } 350 | }); 351 | } 352 | 353 | ERL_NIF_TERM labels_out = enif_make_binary(env, &data_l_bin); 354 | ERL_NIF_TERM dists_out = enif_make_binary(env, &data_d_bin); 355 | 356 | ERL_NIF_TERM label_size = enif_make_uint(env, sizeof(hnswlib::labeltype) * 8); 357 | ERL_NIF_TERM dist_size = enif_make_uint(env, sizeof(dist_t) * 8); 358 | out = enif_make_tuple7(env, 359 | hnswlib_atom(env, "ok"), 360 | labels_out, 361 | dists_out, 362 | enif_make_uint64(env, rows), 363 | enif_make_uint64(env, k), 364 | label_size, 365 | dist_size); 366 | } catch (std::runtime_error &err) { 367 | out = hnswlib_error(env, err.what()); 368 | 369 | enif_release_binary(&data_l_bin); 370 | enif_release_binary(&data_d_bin); 371 | } 372 | 373 | return true; 374 | } 375 | 376 | 377 | void markDeleted(size_t label) { 378 | appr_alg->markDelete(label); 379 | } 380 | 381 | 382 | void unmarkDeleted(size_t label) { 383 | appr_alg->unmarkDelete(label); 384 | } 385 | 386 | 387 | void resizeIndex(size_t new_size) { 388 | appr_alg->resizeIndex(new_size); 389 | } 390 | 391 | 392 | size_t getMaxElements() const { 393 | return appr_alg->max_elements_; 394 | } 395 | 396 | 397 | size_t getCurrentCount() const { 398 | return appr_alg->cur_element_count; 399 | } 400 | 401 | ERL_NIF_TERM hnswlib_atom(ErlNifEnv *env, const char *msg) { 402 | ERL_NIF_TERM a; 403 | if (enif_make_existing_atom(env, msg, &a, ERL_NIF_LATIN1)) { 404 | return a; 405 | } else { 406 | return enif_make_atom(env, msg); 407 | } 408 | } 409 | 410 | // Helper for returning `{:error, msg}` from NIF. 
411 | ERL_NIF_TERM hnswlib_error(ErlNifEnv *env, const char *msg) { 412 | ERL_NIF_TERM error_atom = hnswlib_atom(env, "error"); 413 | ERL_NIF_TERM reason; 414 | unsigned char *ptr; 415 | size_t len = strlen(msg); 416 | if ((ptr = enif_make_new_binary(env, len, &reason)) != nullptr) { 417 | strcpy((char *) ptr, msg); 418 | return enif_make_tuple2(env, error_atom, reason); 419 | } else { 420 | ERL_NIF_TERM msg_term = enif_make_string(env, msg, ERL_NIF_LATIN1); 421 | return enif_make_tuple2(env, error_atom, msg_term); 422 | } 423 | } 424 | }; 425 | 426 | template 427 | class BFIndex { 428 | public: 429 | static const int ser_version = 1; // serialization version 430 | 431 | std::string space_name; 432 | int dim; 433 | bool index_inited; 434 | bool normalize; 435 | int num_threads_default; 436 | 437 | hnswlib::labeltype cur_l; 438 | hnswlib::BruteforceSearch* alg; 439 | hnswlib::SpaceInterface* space; 440 | 441 | 442 | BFIndex(const std::string &space_name, const int dim) : space_name(space_name), dim(dim) { 443 | normalize = false; 444 | if (space_name == "l2") { 445 | space = new hnswlib::L2Space(dim); 446 | } else if (space_name == "ip") { 447 | space = new hnswlib::InnerProductSpace(dim); 448 | } else if (space_name == "cosine") { 449 | space = new hnswlib::InnerProductSpace(dim); 450 | normalize = true; 451 | } else { 452 | throw std::runtime_error("Space name must be one of l2, ip, or cosine."); 453 | } 454 | alg = NULL; 455 | index_inited = false; 456 | 457 | num_threads_default = std::thread::hardware_concurrency(); 458 | } 459 | 460 | 461 | ~BFIndex() { 462 | delete space; 463 | if (alg) 464 | delete alg; 465 | } 466 | 467 | 468 | size_t getMaxElements() const { 469 | return alg->maxelements_; 470 | } 471 | 472 | 473 | size_t getCurrentCount() const { 474 | return alg->cur_element_count; 475 | } 476 | 477 | 478 | void set_num_threads(int num_threads) { 479 | this->num_threads_default = num_threads; 480 | } 481 | 482 | 483 | void init_new_index(const size_t 
maxElements) { 484 | if (alg) { 485 | throw std::runtime_error("The index is already initiated."); 486 | } 487 | cur_l = 0; 488 | alg = new hnswlib::BruteforceSearch(space, maxElements); 489 | index_inited = true; 490 | } 491 | 492 | 493 | void normalize_vector(float* data, float* norm_array) { 494 | float norm = 0.0f; 495 | for (int i = 0; i < dim; i++) 496 | norm += data[i] * data[i]; 497 | norm = 1.0f / (sqrtf(norm) + 1e-30f); 498 | for (int i = 0; i < dim; i++) 499 | norm_array[i] = data[i] * norm; 500 | } 501 | 502 | 503 | void addItems(float * input, size_t rows, size_t features, const uint64_t* ids, size_t ids_count) { 504 | if (features != dim) 505 | throw std::runtime_error("Wrong dimensionality of the vectors"); 506 | 507 | for (size_t row = 0; row < rows; row++) { 508 | uint64_t id = ids_count ? ids[row] : cur_l + row; 509 | if (!normalize) { 510 | alg->addPoint((void *)(input + row * features), (size_t)id); 511 | } else { 512 | std::vector normalized_vector(dim); 513 | normalize_vector((float *)(input + row * features), normalized_vector.data()); 514 | alg->addPoint((void *)normalized_vector.data(), (size_t)id); 515 | } 516 | } 517 | cur_l+=rows; 518 | } 519 | 520 | 521 | void deleteVector(size_t label) { 522 | alg->removePoint(label); 523 | } 524 | 525 | 526 | void saveIndex(const std::string &path_to_index) { 527 | alg->saveIndex(path_to_index); 528 | } 529 | 530 | 531 | void loadIndex(const std::string &path_to_index, size_t max_elements) { 532 | if (alg) { 533 | fprintf(stderr, "Warning: Calling load_index for an already inited index. 
Old index is being deallocated.\r\n"); 534 | delete alg; 535 | } 536 | alg = new hnswlib::BruteforceSearch(space, path_to_index); 537 | cur_l = alg->cur_element_count; 538 | index_inited = true; 539 | } 540 | 541 | 542 | bool knnQuery( 543 | ErlNifEnv * env, 544 | float* input, 545 | size_t rows, 546 | size_t features, 547 | size_t k, 548 | // const std::function& filter, 549 | ERL_NIF_TERM& out) { 550 | ErlNifBinary data_l_bin; 551 | ErlNifBinary data_d_bin; 552 | 553 | hnswlib::labeltype* data_l; 554 | dist_t* data_d; 555 | 556 | try { 557 | if (!enif_alloc_binary(sizeof(hnswlib::labeltype) * rows * k, &data_l_bin)) { 558 | out = hnswlib_error(env, "out of memory for storing labels"); 559 | return false; 560 | } 561 | data_l = (hnswlib::labeltype *)data_l_bin.data; 562 | 563 | if (!enif_alloc_binary(sizeof(dist_t) * rows * k, &data_d_bin)) { 564 | enif_release_binary(&data_l_bin); 565 | out = hnswlib_error(env, "out of memory for storing distances"); 566 | return false; 567 | } 568 | data_d = (dist_t *)data_d_bin.data; 569 | 570 | // CustomFilterFunctor idFilter(filter); 571 | CustomFilterFunctor* p_idFilter = nullptr; 572 | 573 | for (size_t row = 0; row < rows; row++) { 574 | std::priority_queue> result = alg->searchKnn( 575 | (void *)(input + row * features), k, p_idFilter); 576 | for (int i = k - 1; i >= 0; i--) { 577 | auto &result_tuple = result.top(); 578 | data_d[row * k + i] = result_tuple.first; 579 | data_l[row * k + i] = result_tuple.second; 580 | result.pop(); 581 | } 582 | } 583 | 584 | ERL_NIF_TERM labels_out = enif_make_binary(env, &data_l_bin); 585 | ERL_NIF_TERM dists_out = enif_make_binary(env, &data_d_bin); 586 | 587 | ERL_NIF_TERM label_size = enif_make_uint(env, sizeof(hnswlib::labeltype) * 8); 588 | ERL_NIF_TERM dist_size = enif_make_uint(env, sizeof(dist_t) * 8); 589 | out = enif_make_tuple7(env, 590 | hnswlib_atom(env, "ok"), 591 | labels_out, 592 | dists_out, 593 | enif_make_uint64(env, rows), 594 | enif_make_uint64(env, k), 595 | 
label_size, 596 | dist_size); 597 | } catch (std::runtime_error &err) { 598 | out = hnswlib_error(env, err.what()); 599 | 600 | enif_release_binary(&data_l_bin); 601 | enif_release_binary(&data_d_bin); 602 | } 603 | 604 | return true; 605 | } 606 | 607 | ERL_NIF_TERM hnswlib_atom(ErlNifEnv *env, const char *msg) { 608 | ERL_NIF_TERM a; 609 | if (enif_make_existing_atom(env, msg, &a, ERL_NIF_LATIN1)) { 610 | return a; 611 | } else { 612 | return enif_make_atom(env, msg); 613 | } 614 | } 615 | 616 | // Helper for returning `{:error, msg}` from NIF. 617 | ERL_NIF_TERM hnswlib_error(ErlNifEnv *env, const char *msg) { 618 | ERL_NIF_TERM error_atom = hnswlib_atom(env, "error"); 619 | ERL_NIF_TERM reason; 620 | unsigned char *ptr; 621 | size_t len = strlen(msg); 622 | if ((ptr = enif_make_new_binary(env, len, &reason)) != nullptr) { 623 | strcpy((char *) ptr, msg); 624 | return enif_make_tuple2(env, error_atom, reason); 625 | } else { 626 | ERL_NIF_TERM msg_term = enif_make_string(env, msg, ERL_NIF_LATIN1); 627 | return enif_make_tuple2(env, error_atom, msg_term); 628 | } 629 | } 630 | }; 631 | 632 | struct NifResHNSWLibIndex { 633 | Index * val; 634 | ErlNifRWLock * rwlock; 635 | 636 | static ErlNifResourceType * type; 637 | static NifResHNSWLibIndex * allocate_resource(ErlNifEnv * env, ERL_NIF_TERM &error) { 638 | NifResHNSWLibIndex * res = (NifResHNSWLibIndex *)enif_alloc_resource(NifResHNSWLibIndex::type, sizeof(NifResHNSWLibIndex)); 639 | if (res == nullptr) { 640 | error = erlang::nif::error(env, "cannot allocate NifResHNSWLibIndex resource"); 641 | return res; 642 | } 643 | 644 | res->rwlock = enif_rwlock_create((char *)"hnswlib.index"); 645 | if (res->rwlock == nullptr) { 646 | error = erlang::nif::error(env, "cannot allocate rwlock for the NifResHNSWLibIndex resource"); 647 | return res; 648 | } 649 | 650 | return res; 651 | } 652 | 653 | static NifResHNSWLibIndex * get_resource(ErlNifEnv * env, ERL_NIF_TERM term, ERL_NIF_TERM &error) { 654 | NifResHNSWLibIndex * 
self_res = nullptr; 655 | if (!enif_get_resource(env, term, NifResHNSWLibIndex::type, (void **)&self_res) || self_res == nullptr || self_res->val == nullptr) { 656 | error = erlang::nif::error(env, "cannot access NifResHNSWLibIndex resource"); 657 | } 658 | return self_res; 659 | } 660 | 661 | static void destruct_resource(ErlNifEnv *env, void *args) { 662 | auto res = (NifResHNSWLibIndex *)args; 663 | if (res) { 664 | if (res->val) { 665 | delete res->val; 666 | res->val = nullptr; 667 | } 668 | 669 | if (res->rwlock) { 670 | enif_rwlock_destroy(res->rwlock); 671 | res->rwlock = nullptr; 672 | } 673 | } 674 | } 675 | }; 676 | 677 | struct NifResHNSWLibBFIndex { 678 | BFIndex * val; 679 | ErlNifRWLock * rwlock; 680 | 681 | static ErlNifResourceType * type; 682 | static NifResHNSWLibBFIndex * allocate_resource(ErlNifEnv * env, ERL_NIF_TERM &error) { 683 | NifResHNSWLibBFIndex * res = (NifResHNSWLibBFIndex *)enif_alloc_resource(NifResHNSWLibBFIndex::type, sizeof(NifResHNSWLibBFIndex)); 684 | if (res == nullptr) { 685 | error = erlang::nif::error(env, "cannot allocate NifResHNSWLibBFIndex resource"); 686 | return res; 687 | } 688 | 689 | res->rwlock = enif_rwlock_create((char *)"hnswlib.bfindex"); 690 | if (res->rwlock == nullptr) { 691 | error = erlang::nif::error(env, "cannot allocate rwlock for the NifResHNSWLibBFIndex resource"); 692 | return res; 693 | } 694 | 695 | return res; 696 | } 697 | 698 | static NifResHNSWLibBFIndex * get_resource(ErlNifEnv * env, ERL_NIF_TERM term, ERL_NIF_TERM &error) { 699 | NifResHNSWLibBFIndex * self_res = nullptr; 700 | if (!enif_get_resource(env, term, NifResHNSWLibBFIndex::type, (void **)&self_res) || self_res == nullptr || self_res->val == nullptr) { 701 | error = erlang::nif::error(env, "cannot access NifResHNSWLibBFIndex resource"); 702 | } 703 | return self_res; 704 | } 705 | 706 | static void destruct_resource(ErlNifEnv *env, void *args) { 707 | auto res = (NifResHNSWLibBFIndex *)args; 708 | if (res) { 709 | if (res->val) { 
710 | delete res->val; 711 | res->val = nullptr; 712 | } 713 | 714 | if (res->rwlock) { 715 | enif_rwlock_destroy(res->rwlock); 716 | res->rwlock = nullptr; 717 | } 718 | } 719 | } 720 | }; 721 | 722 | #endif /* HNSWLIB_INDEX_HPP */ 723 | -------------------------------------------------------------------------------- /c_src/nif_utils.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Created by Cocoa on 14/06/2022. 3 | // 4 | 5 | #include "nif_utils.hpp" 6 | 7 | namespace erlang { 8 | namespace nif { 9 | 10 | // Atoms 11 | 12 | int get_atom(ErlNifEnv *env, ERL_NIF_TERM term, std::string &var) { 13 | unsigned atom_length; 14 | if (!enif_get_atom_length(env, term, &atom_length, ERL_NIF_LATIN1)) { 15 | return 0; 16 | } 17 | 18 | var.resize(atom_length + 1); 19 | 20 | if (!enif_get_atom(env, term, &(*(var.begin())), var.size(), ERL_NIF_LATIN1)) { 21 | return 0; 22 | } 23 | 24 | var.resize(atom_length); 25 | return 1; 26 | } 27 | 28 | ERL_NIF_TERM atom(ErlNifEnv *env, const char *msg) { 29 | ERL_NIF_TERM a; 30 | if (enif_make_existing_atom(env, msg, &a, ERL_NIF_LATIN1)) { 31 | return a; 32 | } else { 33 | return enif_make_atom(env, msg); 34 | } 35 | } 36 | 37 | // Helper for returning `{:error, msg}` from NIF. 38 | ERL_NIF_TERM error(ErlNifEnv *env, const char *msg) { 39 | ERL_NIF_TERM error_atom = atom(env, "error"); 40 | ERL_NIF_TERM reason; 41 | unsigned char *ptr; 42 | size_t len = strlen(msg); 43 | if ((ptr = enif_make_new_binary(env, len, &reason)) != nullptr) { 44 | strcpy((char *) ptr, msg); 45 | return enif_make_tuple2(env, error_atom, reason); 46 | } else { 47 | ERL_NIF_TERM msg_term = enif_make_string(env, msg, ERL_NIF_LATIN1); 48 | return enif_make_tuple2(env, error_atom, msg_term); 49 | } 50 | } 51 | 52 | // Helper for returning `{:ok, term}` from NIF. 53 | ERL_NIF_TERM ok(ErlNifEnv *env) { 54 | return atom(env, "ok"); 55 | } 56 | 57 | // Helper for returning `:ok` from NIF. 
58 | ERL_NIF_TERM ok(ErlNifEnv *env, ERL_NIF_TERM term) { 59 | return enif_make_tuple2(env, ok(env), term); 60 | } 61 | 62 | // Boolean type 63 | 64 | int get(ErlNifEnv *env, ERL_NIF_TERM term, bool *var) { 65 | std::string b; 66 | if (get_atom(env, term, b)) { 67 | if (b == "true") { 68 | *var = true; 69 | return 1; 70 | } else if (b == "false") { 71 | *var = false; 72 | return 1; 73 | } else { 74 | return 0; 75 | } 76 | } else { 77 | return 0; 78 | } 79 | } 80 | 81 | // Numeric types 82 | 83 | int get(ErlNifEnv *env, ERL_NIF_TERM term, int *var) { 84 | return enif_get_int(env, term, var); 85 | } 86 | 87 | int get(ErlNifEnv *env, ERL_NIF_TERM term, unsigned int *var) { 88 | return enif_get_uint(env, term, var); 89 | } 90 | 91 | int get(ErlNifEnv *env, ERL_NIF_TERM term, long long *var) { 92 | return enif_get_int64(env, term, reinterpret_cast(var)); 93 | } 94 | 95 | int get(ErlNifEnv *env, ERL_NIF_TERM term, unsigned long long *var) { 96 | return enif_get_uint64(env, term, reinterpret_cast(var)); 97 | } 98 | 99 | int get(ErlNifEnv *env, ERL_NIF_TERM term, long *var) { 100 | return enif_get_int64(env, term, reinterpret_cast(var)); 101 | } 102 | 103 | int get(ErlNifEnv *env, ERL_NIF_TERM term, unsigned long *var) { 104 | return enif_get_uint64(env, term, reinterpret_cast(var)); 105 | } 106 | 107 | int get(ErlNifEnv *env, ERL_NIF_TERM term, double *var) { 108 | return enif_get_double(env, term, var); 109 | } 110 | 111 | // Standard types 112 | 113 | int get(ErlNifEnv *env, ERL_NIF_TERM term, std::string &var) { 114 | unsigned len; 115 | int ret = enif_get_list_length(env, term, &len); 116 | 117 | if (!ret) { 118 | ErlNifBinary bin; 119 | ret = enif_inspect_binary(env, term, &bin); 120 | if (!ret) { 121 | return 0; 122 | } 123 | var = std::string((const char *) bin.data, bin.size); 124 | return ret; 125 | } 126 | 127 | var.resize(len + 1); 128 | ret = enif_get_string(env, term, &*(var.begin()), var.size(), ERL_NIF_LATIN1); 129 | 130 | if (ret > 0) { 131 | 
var.resize(ret - 1); 132 | } else if (ret == 0) { 133 | var.resize(0); 134 | } else { 135 | } 136 | 137 | return ret; 138 | } 139 | 140 | ERL_NIF_TERM make(ErlNifEnv *env, bool var) { 141 | if (var) { 142 | return atom(env, "true"); 143 | } else { 144 | return atom(env, "false"); 145 | } 146 | } 147 | 148 | ERL_NIF_TERM make(ErlNifEnv *env, long var) { 149 | return enif_make_int64(env, var); 150 | } 151 | 152 | ERL_NIF_TERM make(ErlNifEnv *env, int32_t var) { 153 | return enif_make_int(env, var); 154 | } 155 | 156 | ERL_NIF_TERM make(ErlNifEnv *env, long long var) { 157 | return enif_make_int64(env, var); 158 | } 159 | 160 | ERL_NIF_TERM make(ErlNifEnv *env, uint32_t var) { 161 | return enif_make_uint(env, var); 162 | } 163 | 164 | ERL_NIF_TERM make(ErlNifEnv *env, unsigned long long var) { 165 | return enif_make_uint64(env, var); 166 | } 167 | 168 | ERL_NIF_TERM make(ErlNifEnv *env, double var) { 169 | return enif_make_double(env, var); 170 | } 171 | 172 | ERL_NIF_TERM make(ErlNifEnv *env, ErlNifBinary var) { 173 | return enif_make_binary(env, &var); 174 | } 175 | 176 | ERL_NIF_TERM make(ErlNifEnv *env, const std::string& var) { 177 | return enif_make_string(env, var.c_str(), ERL_NIF_LATIN1); 178 | } 179 | 180 | ERL_NIF_TERM make(ErlNifEnv *env, const char *string) { 181 | return enif_make_string(env, string, ERL_NIF_LATIN1); 182 | } 183 | 184 | int make(ErlNifEnv *env, bool var, ERL_NIF_TERM &out) { 185 | out = make(env, var); 186 | return 0; 187 | } 188 | 189 | int make(ErlNifEnv *env, long var, ERL_NIF_TERM &out) { 190 | out = make(env, var); 191 | return 0; 192 | } 193 | 194 | int make(ErlNifEnv *env, int var, ERL_NIF_TERM &out) { 195 | out = make(env, var); 196 | return 0; 197 | } 198 | 199 | int make(ErlNifEnv *env, double var, ERL_NIF_TERM &out) { 200 | out = make(env, var); 201 | return 0; 202 | } 203 | 204 | int make(ErlNifEnv *env, ErlNifBinary var, ERL_NIF_TERM &out) { 205 | out = make(env, var); 206 | return 0; 207 | } 208 | 209 | int make(ErlNifEnv 
*env, const std::string& var, ERL_NIF_TERM &out) { 210 | out = make_binary(env, var); 211 | return 0; 212 | } 213 | 214 | int make(ErlNifEnv *env, const char *var, ERL_NIF_TERM &out) { 215 | out = make_binary(env, var); 216 | return 0; 217 | } 218 | 219 | int make(ErlNifEnv *env, const std::vector& array, ERL_NIF_TERM &out) { 220 | size_t count = array.size(); 221 | uint8_t * data = (uint8_t *)array.data(); 222 | return make_u32_list_from_c_array(env, count, data, out); 223 | } 224 | 225 | int make(ErlNifEnv *env, const std::vector& array, ERL_NIF_TERM &out) { 226 | size_t count = array.size(); 227 | uint16_t * data = (uint16_t *)array.data(); 228 | return make_u32_list_from_c_array(env, count, data, out); 229 | } 230 | 231 | int make(ErlNifEnv *env, const std::vector& array, ERL_NIF_TERM &out) { 232 | size_t count = array.size(); 233 | uint32_t * data = (uint32_t *)array.data(); 234 | return make_u32_list_from_c_array(env, count, data, out); 235 | } 236 | 237 | int make(ErlNifEnv *env, const std::vector& array, ERL_NIF_TERM &out) { 238 | size_t count = array.size(); 239 | uint64_t * data = (uint64_t *)array.data(); 240 | return make_u64_list_from_c_array(env, count, data, out); 241 | } 242 | 243 | int make(ErlNifEnv *env, const std::vector& array, ERL_NIF_TERM &out) { 244 | size_t count = array.size(); 245 | int8_t * data = (int8_t *)array.data(); 246 | return make_i32_list_from_c_array(env, count, data, out); 247 | } 248 | 249 | int make(ErlNifEnv *env, const std::vector& array, ERL_NIF_TERM &out) { 250 | size_t count = array.size(); 251 | int16_t * data = (int16_t *)array.data(); 252 | return make_i32_list_from_c_array(env, count, data, out); 253 | } 254 | 255 | int make(ErlNifEnv *env, const std::vector& array, ERL_NIF_TERM &out) { 256 | size_t count = array.size(); 257 | int32_t * data = (int32_t *)array.data(); 258 | return make_i32_list_from_c_array(env, count, data, out); 259 | } 260 | 261 | int make(ErlNifEnv *env, const std::vector& array, ERL_NIF_TERM 
&out) { 262 | size_t count = array.size(); 263 | int64_t * data = (int64_t *)array.data(); 264 | return make_i64_list_from_c_array(env, count, data, out); 265 | } 266 | 267 | int make(ErlNifEnv *env, const std::vector& array, ERL_NIF_TERM &out) { 268 | size_t count = array.size(); 269 | if (sizeof(unsigned long int) == 8) { 270 | uint64_t * data = (uint64_t *)array.data(); 271 | return make_u64_list_from_c_array(env, count, data, out); 272 | } else if (sizeof(unsigned long int) == 4) { 273 | uint32_t * data = (uint32_t *)array.data(); 274 | return make_u32_list_from_c_array(env, count, data, out); 275 | } else { 276 | // error 277 | return 1; 278 | } 279 | } 280 | 281 | int make(ErlNifEnv *env, const std::vector& array, ERL_NIF_TERM &out) { 282 | size_t count = array.size(); 283 | float * data = (float *)array.data(); 284 | return make_f64_list_from_c_array(env, count, data, out); 285 | } 286 | 287 | int make(ErlNifEnv *env, const std::vector& array, ERL_NIF_TERM &out) { 288 | size_t count = array.size(); 289 | double * data = (double *)array.data(); 290 | return make_f64_list_from_c_array(env, count, data, out); 291 | } 292 | 293 | int make(ErlNifEnv *env, const std::vector& array, ERL_NIF_TERM &out) { 294 | size_t count = array.size(); 295 | if (count == 0) { 296 | out = enif_make_list_from_array(env, nullptr, 0); 297 | return 0; 298 | } 299 | 300 | ERL_NIF_TERM *terms = (ERL_NIF_TERM *)enif_alloc(sizeof(ERL_NIF_TERM) * count); 301 | if (terms == nullptr) { 302 | return 1; 303 | } 304 | for (size_t i = 0; i < count; ++i) { 305 | terms[i] = make_binary(env, array[i]); 306 | } 307 | out = enif_make_list_from_array(env, terms, (unsigned)count); 308 | enif_free(terms); 309 | return 0; 310 | } 311 | 312 | int make(ErlNifEnv *env, const std::vector& array, ERL_NIF_TERM &out) { 313 | size_t count = array.size(); 314 | if (count == 0) { 315 | out = enif_make_list_from_array(env, nullptr, 0); 316 | return 0; 317 | } 318 | 319 | ERL_NIF_TERM *terms = (ERL_NIF_TERM 
*)enif_alloc(sizeof(ERL_NIF_TERM) * count); 320 | if (terms == nullptr) { 321 | return 1; 322 | } 323 | for (size_t i = 0; i < count; ++i) { 324 | terms[i] = make_binary(env, *array[i]); 325 | } 326 | out = enif_make_list_from_array(env, terms, (unsigned)count); 327 | enif_free(terms); 328 | return 0; 329 | } 330 | 331 | ERL_NIF_TERM make_binary(ErlNifEnv *env, const char *c_string) { 332 | ERL_NIF_TERM binary_str; 333 | unsigned char *ptr; 334 | size_t len = strlen(c_string); 335 | if ((ptr = enif_make_new_binary(env, len, &binary_str)) != nullptr) { 336 | memcpy((char *)ptr, c_string, len); 337 | return binary_str; 338 | } else { 339 | fprintf(stderr, "internal error: cannot allocate memory for binary string\r\n"); 340 | return atom(env, "error"); 341 | } 342 | } 343 | 344 | ERL_NIF_TERM make_binary(ErlNifEnv *env, const std::string& string) { 345 | ERL_NIF_TERM binary_str; 346 | unsigned char *ptr; 347 | size_t len = string.size(); 348 | if ((ptr = enif_make_new_binary(env, len, &binary_str)) != nullptr) { 349 | memcpy((char *)ptr, string.c_str(), len); 350 | return binary_str; 351 | } else { 352 | fprintf(stderr, "internal error: cannot allocate memory for binary string\r\n"); 353 | return atom(env, "error"); 354 | } 355 | } 356 | 357 | // Check if :nil 358 | int check_nil(ErlNifEnv *env, ERL_NIF_TERM term) { 359 | std::string atom_str; 360 | if (get_atom(env, term, atom_str) && atom_str == "nil") { 361 | return true; 362 | } 363 | return false; 364 | } 365 | 366 | // Containers 367 | 368 | int get_tuple(ErlNifEnv *env, ERL_NIF_TERM tuple, std::vector &var) { 369 | const ERL_NIF_TERM *terms; 370 | int length; 371 | if (!enif_get_tuple(env, tuple, &length, &terms)) { 372 | return 0; 373 | } 374 | 375 | var.reserve(length); 376 | 377 | for (int i = 0; i < length; i++) { 378 | int data; 379 | if (!get(env, terms[i], &data)) { 380 | return 0; 381 | } 382 | 383 | var.push_back(data); 384 | } 385 | return 1; 386 | } 387 | 388 | int get_list(ErlNifEnv *env, 
ERL_NIF_TERM list, std::vector &var) { 389 | unsigned int length; 390 | if (!enif_get_list_length(env, list, &length)) { 391 | return 0; 392 | } 393 | 394 | var.reserve(length); 395 | ERL_NIF_TERM head, tail; 396 | 397 | while (enif_get_list_cell(env, list, &head, &tail)) { 398 | ErlNifBinary elem; 399 | if (!enif_inspect_binary(env, head, &elem)) { 400 | return 0; 401 | } 402 | 403 | var.push_back(elem); 404 | list = tail; 405 | } 406 | return 1; 407 | } 408 | 409 | int get_list(ErlNifEnv *env, ERL_NIF_TERM list, std::vector &var) { 410 | unsigned int length; 411 | if (!enif_get_list_length(env, list, &length)) { 412 | return 0; 413 | } 414 | 415 | var.reserve(length); 416 | ERL_NIF_TERM head, tail; 417 | 418 | while (enif_get_list_cell(env, list, &head, &tail)) { 419 | std::string elem; 420 | if (!get_atom(env, head, elem)) { 421 | return 0; 422 | } 423 | 424 | var.push_back(elem); 425 | list = tail; 426 | } 427 | 428 | return 1; 429 | } 430 | 431 | int get_list(ErlNifEnv *env, ERL_NIF_TERM list, std::vector &var) { 432 | unsigned int length; 433 | if (!enif_get_list_length(env, list, &length)) { 434 | return 0; 435 | } 436 | 437 | var.reserve(length); 438 | ERL_NIF_TERM head, tail; 439 | 440 | while (enif_get_list_cell(env, list, &head, &tail)) { 441 | int elem; 442 | if (!get(env, head, &elem)) { 443 | return 0; 444 | } 445 | 446 | var.push_back(elem); 447 | list = tail; 448 | } 449 | 450 | return 1; 451 | } 452 | 453 | int get_list(ErlNifEnv *env, ERL_NIF_TERM list, std::vector &var) { 454 | unsigned int length; 455 | if (!enif_get_list_length(env, list, &length)) { 456 | return 0; 457 | } 458 | 459 | var.reserve(length); 460 | ERL_NIF_TERM head, tail; 461 | 462 | while (enif_get_list_cell(env, list, &head, &tail)) { 463 | int64_t elem; 464 | if (!get(env, head, &elem)) { 465 | return 0; 466 | } 467 | 468 | var.push_back(elem); 469 | list = tail; 470 | } 471 | return 1; 472 | } 473 | 474 | int get_list(ErlNifEnv *env, ERL_NIF_TERM list, std::vector &var) { 475 
| unsigned int length; 476 | if (!enif_get_list_length(env, list, &length)) { 477 | return 0; 478 | } 479 | 480 | var.reserve(length); 481 | ERL_NIF_TERM head, tail; 482 | 483 | while (enif_get_list_cell(env, list, &head, &tail)) { 484 | uint64_t elem; 485 | if (!get(env, head, &elem)) { 486 | return 0; 487 | } 488 | 489 | var.push_back(elem); 490 | list = tail; 491 | } 492 | return 1; 493 | } 494 | 495 | } 496 | } 497 | -------------------------------------------------------------------------------- /c_src/nif_utils.hpp: -------------------------------------------------------------------------------- 1 | #ifndef ERLANG_NIF_UTILS_HPP 2 | #define ERLANG_NIF_UTILS_HPP 3 | 4 | #pragma once 5 | 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | 15 | namespace erlang { 16 | namespace nif { 17 | // Atoms 18 | 19 | int get_atom(ErlNifEnv *env, ERL_NIF_TERM term, std::string &var); 20 | ERL_NIF_TERM atom(ErlNifEnv *env, const char *msg); 21 | 22 | // Helper for returning `{:error, msg}` from NIF. 23 | ERL_NIF_TERM error(ErlNifEnv *env, const char *msg); 24 | 25 | // Helper for returning `{:ok, term}` from NIF. 26 | ERL_NIF_TERM ok(ErlNifEnv *env); 27 | 28 | // Helper for returning `:ok` from NIF. 
29 | ERL_NIF_TERM ok(ErlNifEnv *env, ERL_NIF_TERM term); 30 | // Builds an Erlang list of floats from the first `count` elements of `data` (each cast to double). Returns 0 and stores the list in `out` on success (count == 0 yields the empty list); returns 1 and leaves `out` untouched if the scratch term array cannot be allocated. 31 | template 32 | int make_f64_list_from_c_array(ErlNifEnv *env, size_t count, T *data, ERL_NIF_TERM &out) { 33 | if (count == 0) { 34 | out = enif_make_list_from_array(env, nullptr, 0); 35 | return 0; 36 | } 37 | 38 | ERL_NIF_TERM *terms = (ERL_NIF_TERM *)enif_alloc(sizeof(ERL_NIF_TERM) * count); 39 | if (terms == nullptr) { 40 | return 1; 41 | } 42 | for (size_t i = 0; i < count; ++i) { 43 | terms[i] = enif_make_double(env, (double)(data[i])); 44 | } 45 | out = enif_make_list_from_array(env, terms, (unsigned) count); 46 | enif_free(terms); 47 | return 0; 48 | } 49 | // Same allocation/return contract as make_f64_list_from_c_array, but each element is cast to int64_t and built with enif_make_int64. 50 | template 51 | int make_i64_list_from_c_array(ErlNifEnv *env, size_t count, T *data, ERL_NIF_TERM &out) { 52 | if (count == 0) { 53 | out = enif_make_list_from_array(env, nullptr, 0); 54 | return 0; 55 | } 56 | 57 | ERL_NIF_TERM *terms = (ERL_NIF_TERM *)enif_alloc(sizeof(ERL_NIF_TERM) * count); 58 | if (terms == nullptr) { 59 | return 1; 60 | } 61 | for (size_t i = 0; i < count; ++i) { 62 | terms[i] = enif_make_int64(env, (int64_t)(data[i])); 63 | } 64 | out = enif_make_list_from_array(env, terms, (unsigned) count); 65 | enif_free(terms); 66 | return 0; 67 | } 68 | // Same allocation/return contract as make_f64_list_from_c_array, but each element is cast to uint64_t and built with enif_make_uint64. 69 | template 70 | int make_u64_list_from_c_array(ErlNifEnv *env, size_t count, T *data, ERL_NIF_TERM &out) { 71 | if (count == 0) { 72 | out = enif_make_list_from_array(env, nullptr, 0); 73 | return 0; 74 | } 75 | 76 | ERL_NIF_TERM *terms = (ERL_NIF_TERM *)enif_alloc(sizeof(ERL_NIF_TERM) * count); 77 | if (terms == nullptr) { 78 | return 1; 79 | } 80 | for (size_t i = 0; i < count; ++i) { 81 | terms[i] = enif_make_uint64(env, (uint64_t)(data[i])); 82 | } 83 | out = enif_make_list_from_array(env, terms, (unsigned) count); 84 | enif_free(terms); 85 | return 0; 86 | } 87 | // Same allocation/return contract as make_f64_list_from_c_array, but each element is cast to int32_t (wider inputs are narrowed by the cast) and built with enif_make_int. 88 | template 89 | int make_i32_list_from_c_array(ErlNifEnv *env, size_t count, T *data, ERL_NIF_TERM &out) { 90 | if (count == 0) { 91 | out = enif_make_list_from_array(env, nullptr, 0); 92 | return 0; 93 |
} 94 | 95 | ERL_NIF_TERM *terms = (ERL_NIF_TERM *)enif_alloc(sizeof(ERL_NIF_TERM) * count); 96 | if (terms == nullptr) { 97 | return 1; 98 | } 99 | for (size_t i = 0; i < count; ++i) { 100 | terms[i] = enif_make_int(env, (int32_t)(data[i])); 101 | } 102 | out = enif_make_list_from_array(env, terms, (unsigned) count); 103 | enif_free(terms); 104 | return 0; 105 | } 106 | // Same allocation/return contract as make_f64_list_from_c_array, but each element is cast to uint32_t (wider inputs are narrowed by the cast) and built with enif_make_uint. 107 | template 108 | int make_u32_list_from_c_array(ErlNifEnv *env, size_t count, T *data, ERL_NIF_TERM &out) { 109 | if (count == 0) { 110 | out = enif_make_list_from_array(env, nullptr, 0); 111 | return 0; 112 | } 113 | 114 | ERL_NIF_TERM *terms = (ERL_NIF_TERM *)enif_alloc(sizeof(ERL_NIF_TERM) * count); 115 | if (terms == nullptr) { 116 | return 1; 117 | } 118 | for (size_t i = 0; i < count; ++i) { 119 | terms[i] = enif_make_uint(env, (uint32_t)(data[i])); 120 | } 121 | out = enif_make_list_from_array(env, terms, (unsigned) count); 122 | enif_free(terms); 123 | return 0; 124 | } 125 | 126 | // Numeric types 127 | 128 | int get(ErlNifEnv *env, ERL_NIF_TERM term, int *var); 129 | 130 | int get(ErlNifEnv *env, ERL_NIF_TERM term, unsigned int *var); 131 | 132 | int get(ErlNifEnv *env, ERL_NIF_TERM term, long long *var); 133 | 134 | int get(ErlNifEnv *env, ERL_NIF_TERM term, unsigned long long *var); 135 | 136 | int get(ErlNifEnv *env, ERL_NIF_TERM term, long *var); 137 | 138 | int get(ErlNifEnv *env, ERL_NIF_TERM term, unsigned long *var); 139 | 140 | int get(ErlNifEnv *env, ERL_NIF_TERM term, double *var); 141 | 142 | // Standard types 143 | 144 | int get(ErlNifEnv *env, ERL_NIF_TERM term, std::string &var); 145 | // Two-argument make() overloads return the converted Erlang term directly. 146 | ERL_NIF_TERM make(ErlNifEnv *env, bool var); 147 | ERL_NIF_TERM make(ErlNifEnv *env, long var); 148 | ERL_NIF_TERM make(ErlNifEnv *env, int32_t var); 149 | ERL_NIF_TERM make(ErlNifEnv *env, long long var); 150 | ERL_NIF_TERM make(ErlNifEnv *env, uint32_t var); 151 | ERL_NIF_TERM make(ErlNifEnv *env, unsigned long long var); 152 | ERL_NIF_TERM make(ErlNifEnv *env, double var); 153 |
ERL_NIF_TERM make(ErlNifEnv *env, ErlNifBinary var); 154 | ERL_NIF_TERM make(ErlNifEnv *env, std::string var); 155 | ERL_NIF_TERM make(ErlNifEnv *env, const char *string); 156 | 157 | int make(ErlNifEnv *env, bool var, ERL_NIF_TERM &out); 158 | int make(ErlNifEnv *env, long var, ERL_NIF_TERM &out); 159 | int make(ErlNifEnv *env, int var, ERL_NIF_TERM &out); 160 | int make(ErlNifEnv *env, double var, ERL_NIF_TERM &out); 161 | int make(ErlNifEnv *env, ErlNifBinary var, ERL_NIF_TERM &out); 162 | int make(ErlNifEnv *env, const std::string& var, ERL_NIF_TERM &out); 163 | int make(ErlNifEnv *env, const char *string, ERL_NIF_TERM &out); 164 | int make(ErlNifEnv *env, const std::vector& array, ERL_NIF_TERM &out); 165 | int make(ErlNifEnv *env, const std::vector& array, ERL_NIF_TERM &out); 166 | int make(ErlNifEnv *env, const std::vector& array, ERL_NIF_TERM &out); 167 | int make(ErlNifEnv *env, const std::vector& array, ERL_NIF_TERM &out); 168 | int make(ErlNifEnv *env, const std::vector& array, ERL_NIF_TERM &out); 169 | int make(ErlNifEnv *env, const std::vector& array, ERL_NIF_TERM &out); 170 | int make(ErlNifEnv *env, const std::vector& array, ERL_NIF_TERM &out); 171 | int make(ErlNifEnv *env, const std::vector& array, ERL_NIF_TERM &out); 172 | int make(ErlNifEnv *env, const std::vector& array, ERL_NIF_TERM &out); 173 | int make(ErlNifEnv *env, const std::vector& array, ERL_NIF_TERM &out); 174 | int make(ErlNifEnv *env, const std::vector& array, ERL_NIF_TERM &out); 175 | int make(ErlNifEnv *env, const std::vector& array, ERL_NIF_TERM &out); 176 | int make(ErlNifEnv *env, const std::vector& array, ERL_NIF_TERM &out); 177 | 178 | ERL_NIF_TERM make_binary(ErlNifEnv *env, const char *c_string); 179 | ERL_NIF_TERM make_binary(ErlNifEnv *env, const std::string& string); 180 | 181 | template 182 | int make(ErlNifEnv *env, const std::map& map, ERL_NIF_TERM &out, bool atom_key) { 183 | bool failed = false; 184 | size_t size = map.size(); 185 | 186 | if (size == 0) { 187 | out = 
enif_make_new_map(env); 188 | return 0; 189 | } 190 | 191 | ERL_NIF_TERM * keys = (ERL_NIF_TERM *)enif_alloc(sizeof(ERL_NIF_TERM) * size); 192 | if (!keys) { 193 | return 1; 194 | } 195 | ERL_NIF_TERM * values = (ERL_NIF_TERM *)enif_alloc(sizeof(ERL_NIF_TERM) * size); 196 | if (!values) { 197 | enif_free(keys); 198 | return 1; 199 | } 200 | 201 | size_t index = 0; 202 | for (const auto &p : map) { 203 | if (atom_key) { 204 | keys[index] = atom(env, p.first.c_str()); 205 | } else { 206 | if (make(env, p.first, keys[index])) { 207 | failed = true; 208 | break; 209 | } 210 | } 211 | 212 | if (make(env, p.second, values[index])) { 213 | failed = true; 214 | break; 215 | } 216 | 217 | index++; 218 | } 219 | 220 | if (failed) { 221 | enif_free(keys); 222 | enif_free(values); 223 | return 1; 224 | } 225 | 226 | ERL_NIF_TERM map_out; 227 | if (!enif_make_map_from_arrays(env, keys, values, index, &map_out)) { 228 | return 1; 229 | } 230 | 231 | out = map_out; 232 | return 0; 233 | } 234 | 235 | template 236 | int make(ErlNifEnv *env, const std::vector>& array, ERL_NIF_TERM &out, bool atom_key) { 237 | size_t count = array.size(); 238 | if (count == 0) { 239 | out = enif_make_list_from_array(env, nullptr, 0); 240 | return 0; 241 | } 242 | 243 | ERL_NIF_TERM *terms = (ERL_NIF_TERM *)enif_alloc(sizeof(ERL_NIF_TERM) * count); 244 | if (terms == nullptr) { 245 | return 1; 246 | } 247 | for (size_t i = 0; i < count; ++i) { 248 | if (make(env, array[i], terms[i], atom_key)) { 249 | enif_free(terms); 250 | return 1; 251 | } 252 | } 253 | out = enif_make_list_from_array(env, terms, (unsigned)count); 254 | enif_free(terms); 255 | return 0; 256 | } 257 | 258 | // Check if :nil 259 | int check_nil(ErlNifEnv *env, ERL_NIF_TERM term); 260 | 261 | // Boolean 262 | 263 | int get(ErlNifEnv *env, ERL_NIF_TERM term, bool *var); 264 | 265 | // Containers 266 | 267 | int get_tuple(ErlNifEnv *env, ERL_NIF_TERM tuple, std::vector &var); 268 | int get_list(ErlNifEnv *env, ERL_NIF_TERM list, 
std::vector &var); 269 | int get_list(ErlNifEnv *env, ERL_NIF_TERM list, std::vector &var); 270 | int get_list(ErlNifEnv *env, ERL_NIF_TERM list, std::vector &var); 271 | int get_list(ErlNifEnv *env, ERL_NIF_TERM list, std::vector &var); 272 | int get_list(ErlNifEnv *env, ERL_NIF_TERM list, std::vector &var); 273 | 274 | } 275 | } 276 | 277 | #endif // ERLANG_NIF_UTILS_HPP 278 | -------------------------------------------------------------------------------- /lib/hnswlib_bfindex.ex: -------------------------------------------------------------------------------- 1 | defmodule HNSWLib.BFIndex do 2 | @moduledoc """ 3 | Documentation for `HNSWLib.BFIndex`. 4 | """ 5 | 6 | defstruct [:space, :dim, :reference] 7 | alias __MODULE__, as: T 8 | alias HNSWLib.Helper 9 | 10 | @doc """ 11 | Construct a new BFIndex 12 | 13 | ##### Positional Parameters 14 | 15 | - *space*: `:cosine` | `:ip` | `:l2`. 16 | 17 | An atom that indicates the vector space. Valid values are 18 | 19 | - `:cosine`, cosine space 20 | - `:ip`, inner product space 21 | - `:l2`, L2 space 22 | 23 | - *dim*: `non_neg_integer()`. 24 | 25 | Number of dimensions for each vector. 26 | 27 | - *max_elements*: `non_neg_integer()`. 28 | 29 | Number of maximum elements. 30 | """ 31 | @spec new(:cosine | :ip | :l2, non_neg_integer(), non_neg_integer()) :: 32 | {:ok, %T{}} | {:error, String.t()} 33 | def new(space, dim, max_elements) 34 | when (space == :l2 or space == :ip or space == :cosine) and is_integer(dim) and dim >= 0 and 35 | is_integer(max_elements) and max_elements >= 0 do 36 | with {:ok, ref} <- HNSWLib.Nif.bfindex_new(space, dim, max_elements) do 37 | {:ok, 38 | %T{ 39 | space: space, 40 | dim: dim, 41 | reference: ref 42 | }} 43 | else 44 | {:error, reason} -> 45 | {:error, reason} 46 | end 47 | end 48 | 49 | @doc """ 50 | Query the index with a single vector or a list of vectors. 51 | 52 | ##### Positional Parameters 53 | 54 | - *query*: `Nx.Tensor.t() | binary() | [binary()]`.
55 | 56 | A vector or a list of vectors to query. 57 | 58 | If *query* is a list of vectors, the vectors must be of the same dimension. 59 | 60 | ##### Keyword Parameters 61 | 62 | - *k*: `pos_integer()`. 63 | 64 | Number of nearest neighbors to return. Defaults to `1`. An optional *filter* keyword (a function of arity 1, or `nil` by default) is passed through to the native query for every form of *query*. 65 | """ 66 | @spec knn_query(%T{}, Nx.Tensor.t() | binary() | [binary()], [ 67 | {:k, pos_integer()}, {:filter, (term() -> term()) | nil} 68 | ]) :: {:ok, Nx.Tensor.t(), Nx.Tensor.t()} | {:error, String.t()} 69 | def knn_query(self, query, opts \\ []) 70 | 71 | def knn_query(self = %T{}, query, opts) when is_binary(query) do 72 | k = Helper.get_keyword!(opts, :k, :pos_integer, 1) 73 | filter = Helper.get_keyword!(opts, :filter, {:function, 1}, nil, true) 74 | Helper.might_be_float_data!(query) 75 | features = div(byte_size(query), HNSWLib.Nif.float_size()) 76 | Helper.ensure_vector_dimension!(self, features, true) 77 | _do_knn_query(self, query, k, filter, 1, features) 78 | end 79 | 80 | def knn_query(self = %T{}, query, opts) when is_list(query) do 81 | k = Helper.get_keyword!(opts, :k, :pos_integer, 1) 82 | filter = Helper.get_keyword!(opts, :filter, {:function, 1}, nil, true) 83 | {rows, features} = Helper.list_of_binary(query) 84 | Helper.ensure_vector_dimension!(self, features, true) 85 | 86 | _do_knn_query(self, IO.iodata_to_binary(query), k, filter, rows, features) 87 | end 88 | 89 | def knn_query(self = %T{}, query = %Nx.Tensor{}, opts) do 90 | k = Helper.get_keyword!(opts, :k, :pos_integer, 1) 91 | filter = Helper.get_keyword!(opts, :filter, {:function, 1}, nil, true) 92 | {f32_data, rows, features} = Helper.verify_data_tensor!(self, query) 93 | 94 | _do_knn_query(self, f32_data, k, filter, rows, features) 95 | end 96 | 97 | defp _do_knn_query(self, query, k, filter, rows, features) do 98 | case HNSWLib.Nif.bfindex_knn_query(self.reference, query, k, filter, rows, features) do 99 | {:ok, labels, dists, rows, k, label_bits, dist_bits} -> 100 | labels = Nx.reshape(Nx.from_binary(labels, 
:"u#{label_bits}"), {rows, k}) 101 | dists = Nx.reshape(Nx.from_binary(dists, :"f#{dist_bits}"), {rows, k}) 102 | {:ok, labels,
dists} 103 | 104 | {:error, reason} -> 105 | {:error, reason} 106 | end 107 | end 108 | 109 | @doc """ 110 | Add items to the index. 111 | 112 | ##### Positional Parameters 113 | 114 | - *data*: `Nx.Tensor.t()`. 115 | 116 | Data to add to the index. 117 | 118 | ##### Keyword Parameters 119 | 120 | - *ids*: `Nx.Tensor.t() | [non_neg_integer()] | nil`. 121 | 122 | IDs to assign to the data. 123 | 124 | If `nil`, IDs will be assigned sequentially starting from 0. 125 | 126 | Defaults to `nil`. 127 | 128 | """ 129 | @spec add_items(%T{}, Nx.Tensor.t(), [ 130 | {:ids, Nx.Tensor.t() | [non_neg_integer()] | nil} 131 | ]) :: :ok | {:error, String.t()} 132 | def add_items(self, data, opts \\ []) 133 | 134 | def add_items(self = %T{}, data = %Nx.Tensor{}, opts) when is_list(opts) do 135 | ids = Helper.normalize_ids!(opts[:ids]) 136 | {f32_data, rows, features} = Helper.verify_data_tensor!(self, data) 137 | 138 | HNSWLib.Nif.bfindex_add_items(self.reference, f32_data, ids, rows, features) 139 | end 140 | 141 | @doc """ 142 | Delete the vector with the given `label` from the index. 143 | """ 144 | def delete_vector(self = %T{}, label) when is_integer(label) do 145 | HNSWLib.Nif.bfindex_delete_vector(self.reference, label) 146 | end 147 | 148 | @doc """ 149 | Set the number of threads to use in the index. 150 | """ 151 | @spec set_num_threads(%T{}, pos_integer()) :: :ok | {:error, String.t()} 152 | def set_num_threads(self = %T{}, num_threads) 153 | when is_integer(num_threads) and num_threads > 0 do 154 | HNSWLib.Nif.bfindex_set_num_threads(self.reference, num_threads) 155 | end 156 | 157 | @doc """ 158 | Save current index to disk. 159 | 160 | ##### Positional Parameters 161 | 162 | - *path*: `Path.t()`. 163 | 164 | Path to save the index to.
165 | """ 166 | @spec save_index(%T{}, Path.t()) :: :ok | {:error, String.t()} 167 | def save_index(self = %T{}, path) when is_binary(path) do 168 | HNSWLib.Nif.bfindex_save_index(self.reference, path) 169 | end 170 | 171 | @doc """ 172 | Load index from disk. 173 | 174 | ##### Positional Parameters 175 | 176 | - *space*: `:cosine` | `:ip` | `:l2`. 177 | 178 | An atom that indicates the vector space. Valid values are 179 | 180 | - `:cosine`, cosine space 181 | - `:ip`, inner product space 182 | - `:l2`, L2 space 183 | 184 | - *dim*: `non_neg_integer()`. 185 | 186 | Number of dimensions for each vector. 187 | 188 | - *path*: `Path.t()`. 189 | 190 | Path to load the index from. 191 | 192 | ##### Keyword Parameters 193 | 194 | - *max_elements*: `non_neg_integer()`. 195 | 196 | Maximum number of elements to load from the index. 197 | 198 | If set to 0, all elements will be loaded. 199 | 200 | Defaults to 0. 201 | """ 202 | @spec load_index(:cosine | :ip | :l2, non_neg_integer(), Path.t(), [ 203 | {:max_elements, non_neg_integer()} 204 | ]) :: {:ok, %T{}} | {:error, String.t()} 205 | def load_index(space, dim, path, opts \\ []) 206 | when (space == :l2 or space == :ip or space == :cosine) and is_integer(dim) and dim >= 0 and 207 | is_binary(path) and is_list(opts) do 208 | max_elements = Helper.get_keyword!(opts, :max_elements, :non_neg_integer, 0) 209 | 210 | with {:ok, ref} <- HNSWLib.Nif.bfindex_load_index(space, dim, path, max_elements) do 211 | {:ok, 212 | %T{ 213 | space: space, 214 | dim: dim, 215 | reference: ref 216 | }} 217 | else 218 | {:error, reason} -> 219 | {:error, reason} 220 | end 221 | end 222 | 223 | @doc """ 224 | Get the maximum number of elements the index can hold. 225 | """ 226 | @spec get_max_elements(%T{}) :: {:ok, integer()} | {:error, String.t()} 227 | def get_max_elements(self = %T{}) do 228 | HNSWLib.Nif.bfindex_get_max_elements(self.reference) 229 | end 230 | 231 | @doc """ 232 | Get the current number of elements in the index.
233 | """ 234 | @spec get_current_count(%T{}) :: {:ok, integer()} | {:error, String.t()} 235 | def get_current_count(self = %T{}) do 236 | HNSWLib.Nif.bfindex_get_current_count(self.reference) 237 | end 238 | 239 | @doc """ 240 | Get the current number of threads to use in the index. 241 | """ 242 | @spec get_num_threads(%T{}) :: {:ok, integer()} | {:error, String.t()} 243 | def get_num_threads(self = %T{}) do 244 | HNSWLib.Nif.bfindex_get_num_threads(self.reference) 245 | end 246 | end 247 | -------------------------------------------------------------------------------- /lib/hnswlib_helper.ex: -------------------------------------------------------------------------------- 1 | defmodule HNSWLib.Helper do 2 | @moduledoc false 3 | # Reads `key` from `opts` (falling back to `default` when it is missing or nil), validates it against `type` and raises ArgumentError on mismatch; when `allow_nil?` is true a nil value is returned without validation. 4 | def get_keyword!(opts, key, type, default, allow_nil? \\ false) do 5 | # Only a missing/nil option falls back to `default`: `opts[key] || default` would also discard an explicit `false`. 6 | val = if is_nil(opts[key]), do: default, else: opts[key] 7 | if allow_nil? and val == nil do 8 | val 9 | else 10 | case get_keyword(key, val, type) do 11 | {:ok, val} -> 12 | val 13 | 14 | {:error, reason} -> 15 | raise ArgumentError, reason 16 | end 17 | end 18 | end 19 | 20 | defp get_keyword(_key, val, :non_neg_integer) when is_integer(val) and val >= 0 do 21 | {:ok, val} 22 | end 23 | 24 | defp get_keyword(key, val, :non_neg_integer) do 25 | {:error, 26 | "expect keyword parameter `#{inspect(key)}` to be a non-negative integer, got `#{inspect(val)}`"} 27 | end 28 | 29 | defp get_keyword(_key, val, :pos_integer) when is_integer(val) and val > 0 do 30 | {:ok, val} 31 | end 32 | 33 | defp get_keyword(key, val, :pos_integer) do 34 | {:error, 35 | "expect keyword parameter `#{inspect(key)}` to be a positive integer, got `#{inspect(val)}`"} 36 | end 37 | 38 | defp get_keyword(_key, val, :integer) when is_integer(val) do 39 | {:ok, val} 40 | end 41 | 42 | defp get_keyword(key, val, :integer) do 43 | {:error, "expect keyword parameter `#{inspect(key)}` to be an integer, got `#{inspect(val)}`"} 44 | end 45 | 46 | defp get_keyword(_key, val, :boolean) when is_boolean(val) do 47 |
{:ok, val} 48 | end 49 | 50 | defp get_keyword(key, val, :boolean) do 51 | {:error, "expect keyword parameter `#{inspect(key)}` to be a boolean, got `#{inspect(val)}`"} 52 | end 53 | 54 | defp get_keyword(_key, val, :function) when is_function(val) do 55 | {:ok, val} 56 | end 57 | 58 | defp get_keyword(key, val, :function) do 59 | {:error, "expect keyword parameter `#{inspect(key)}` to be a function, got `#{inspect(val)}`"} 60 | end 61 | 62 | defp get_keyword(_key, val, {:function, arity}) 63 | when is_integer(arity) and arity >= 0 and is_function(val, arity) do 64 | {:ok, val} 65 | end 66 | 67 | defp get_keyword(key, val, {:function, arity}) when is_integer(arity) and arity >= 0 do 68 | {:error, 69 | "expect keyword parameter `#{inspect(key)}` to be a function that can be applied with #{arity} number of arguments , got `#{inspect(val)}`"} 70 | end 71 | 72 | defp get_keyword(_key, val, :atom) when is_atom(val) do 73 | {:ok, val} 74 | end 75 | defp get_keyword(key, val, :atom), do: {:error, "expect keyword parameter `#{inspect(key)}` to be an atom, got `#{inspect(val)}`"} 76 | defp get_keyword(key, val, {:atom, allowed_atoms}) 77 | when is_atom(val) and is_list(allowed_atoms) do 78 | if val in allowed_atoms do 79 | {:ok, val} 80 | else 81 | {:error, 82 | "expect keyword parameter `#{inspect(key)}` to be an atom and is one of `#{inspect(allowed_atoms)}`, got `#{inspect(val)}`"} 83 | end 84 | end 85 | defp get_keyword(key, val, {:atom, allowed_atoms}) when is_list(allowed_atoms), do: {:error, "expect keyword parameter `#{inspect(key)}` to be an atom and is one of `#{inspect(allowed_atoms)}`, got `#{inspect(val)}`"} 86 | def list_of_binary(data) when is_list(data) do 87 | count = Enum.count(data) 88 | 89 | if count > 0 do 90 | first = Enum.at(data, 0) 91 | 92 | if is_binary(first) do 93 | expected_size = byte_size(first) 94 | 95 | if rem(expected_size, HNSWLib.Nif.float_size()) != 0 do 96 | raise ArgumentError, 97 | "vector feature size should be a multiple of #{HNSWLib.Nif.float_size()} (sizeof(float))" 98 | else 99 | features = 
div(expected_size, HNSWLib.Nif.float_size()) 100 | 101 | if list_of_binary(data, expected_size) == false do 102 | raise ArgumentError, "all vectors in the input list should have the same size" 103 | else 104 | {count, features} 105 | end 106 | end 107 | else raise(ArgumentError, "expect every vector in the input list to be a binary, got `#{inspect(first)}`") end 108 | else 109 | {0, 0}
110 | end 111 | end 112 | 113 | defp list_of_binary([elem | rest], expected_size) when is_binary(elem) do 114 | if byte_size(elem) == expected_size do 115 | list_of_binary(rest, expected_size) 116 | else 117 | false 118 | end 119 | end 120 | defp list_of_binary([elem | _rest], _expected_size), do: raise(ArgumentError, "expect every vector in the input list to be a binary, got `#{inspect(elem)}`") 121 | defp list_of_binary([], expected_size) do 122 | expected_size 123 | end 124 | 125 | def verify_data_tensor!(self, data = %Nx.Tensor{}) do 126 | {rows, features} = 127 | case data.shape do 128 | {rows, features} -> 129 | ensure_vector_dimension!(self, features, {rows, features}) 130 | 131 | {features} -> 132 | ensure_vector_dimension!(self, features, {1, features}) 133 | 134 | shape -> 135 | raise ArgumentError, 136 | "Input vector data wrong shape. Number of dimensions #{tuple_size(shape)}. Data must be a 1D or 2D array." 137 | end 138 | 139 | {Nx.to_binary(Nx.as_type(data, :f32)), rows, features} 140 | end 141 | 142 | def ensure_vector_dimension!(%{dim: dim}, dim, ret), do: ret 143 | 144 | def ensure_vector_dimension!(%{dim: dim}, features, _ret) do 145 | raise ArgumentError, 146 | "Wrong dimensionality of the vectors, expect `#{dim}`, got `#{features}`" 147 | end 148 | 149 | def might_be_float_data!(data) do 150 | if rem(byte_size(data), float_size()) != 0 do 151 | raise ArgumentError, 152 | "vector feature size should be a multiple of #{HNSWLib.Nif.float_size()} (sizeof(float))" 153 | end 154 | end 155 | 156 | def normalize_ids!(ids = %Nx.Tensor{}) do 157 | case ids.shape do 158 | {_} -> 159 | Nx.to_binary(Nx.as_type(ids, :u64)) 160 | 161 | shape -> 162 | raise ArgumentError, "expect ids to be a 1D array, got `#{inspect(shape)}`" 163 | end 164 | end 165 | 166 | def normalize_ids!(ids) when is_list(ids) do 167 | if Enum.all?(ids, fn x -> 168 | is_integer(x) and x >= 0 169 | end) do 170 | for item <- ids, into: "", do: <> 171 | else 172 | raise ArgumentError, "expect `ids` to be a list of 
non-negative integers" 173 | end 174 | end 175 | 176 | def normalize_ids!(nil) do 177 | <<>> 178 | end 179 | 180 | def float_size do
181 | HNSWLib.Nif.float_size() 182 | end 183 | end 184 | -------------------------------------------------------------------------------- /lib/hnswlib_index.ex: -------------------------------------------------------------------------------- 1 | defmodule HNSWLib.Index do 2 | @moduledoc """ 3 | Documentation for `HNSWLib.Index`. 4 | """ 5 | 6 | defstruct [:space, :dim, :reference] 7 | alias __MODULE__, as: T 8 | alias HNSWLib.Helper 9 | 10 | @doc """ 11 | Construct a new Index 12 | 13 | ##### Positional Parameters 14 | 15 | - *space*: `:cosine` | `:ip` | `:l2`. 16 | 17 | An atom that indicates the vector space. Valid values are 18 | 19 | - `:cosine`, cosine space 20 | - `:ip`, inner product space 21 | - `:l2`, L2 space 22 | 23 | - *dim*: `non_neg_integer()`. 24 | 25 | Number of dimensions for each vector. 26 | 27 | - *max_elements*: `pos_integer()`. 28 | 29 | Number of maximum elements. 30 | 31 | ##### Keyword Parameters 32 | 33 | - *m*: `non_neg_integer()`. 34 | 35 | `M` is tightly connected with the internal dimensionality of the data and 36 | strongly affects the memory consumption. Defaults to `16`. 37 | 38 | - *ef_construction*: `non_neg_integer()`. 39 | 40 | Controls the index search speed/build speed tradeoff. Defaults to `200`. 41 | 42 | - *random_seed*: `non_neg_integer()`. Defaults to `100`. 43 | - *allow_replace_deleted*: `boolean()`. Defaults to `false`.
44 | """ 45 | @spec new(:cosine | :ip | :l2, non_neg_integer(), pos_integer(), [ 46 | {:m, non_neg_integer()}, 47 | {:ef_construction, non_neg_integer()}, 48 | {:random_seed, non_neg_integer()}, 49 | {:allow_replace_deleted, boolean()} 50 | ]) :: {:ok, %T{}} | {:error, String.t()} 51 | def new(space, dim, max_elements, opts \\ []) 52 | when (space == :l2 or space == :ip or space == :cosine) and is_integer(dim) and dim >= 0 and 53 | is_integer(max_elements) and max_elements > 0 do 54 | m = Helper.get_keyword!(opts, :m, :non_neg_integer, 16) 55 | ef_construction = Helper.get_keyword!(opts, :ef_construction, :non_neg_integer, 200) 56 | random_seed = Helper.get_keyword!(opts, :random_seed, :non_neg_integer, 100) 57 | allow_replace_deleted = Helper.get_keyword!(opts, :allow_replace_deleted, :boolean, false) 58 | 59 | with {:ok, ref} <- 60 | HNSWLib.Nif.index_new( 61 | space, 62 | dim, 63 | max_elements, 64 | m, 65 | ef_construction, 66 | random_seed, 67 | allow_replace_deleted 68 | ) do 69 | {:ok, 70 | %T{ 71 | space: space, 72 | dim: dim, 73 | reference: ref 74 | }} 75 | else 76 | {:error, reason} -> 77 | {:error, reason} 78 | end 79 | end 80 | 81 | @doc """ 82 | Query the index with a single vector or a list of vectors. 83 | 84 | ##### Positional Parameters 85 | 86 | - *query*: `Nx.Tensor.t() | binary() | [binary()]`. 87 | 88 | A vector or a list of vectors to query. 89 | 90 | If *query* is a list of vectors, the vectors must be of the same dimension. 91 | 92 | ##### Keyword Parameters 93 | 94 | - *k*: `pos_integer()`. 95 | 96 | Number of nearest neighbors to return. Defaults to `1`. 97 | 98 | - *num_threads*: `integer()`. 99 | 100 | Number of threads to use. Defaults to `-1`.
101 | """ 102 | @spec knn_query(%T{}, Nx.Tensor.t() | binary() | [binary()], [ 103 | {:k, pos_integer()}, 104 | {:num_threads, integer()} 105 | # {:filter, function()} 106 | ]) :: {:ok, Nx.Tensor.t(), Nx.Tensor.t()} | {:error, String.t()} 107 | def knn_query(self, query, opts \\ []) 108 | 109 | def knn_query(self = %T{}, query, opts) when is_binary(query) do 110 | k = Helper.get_keyword!(opts, :k, :pos_integer, 1) 111 | num_threads = Helper.get_keyword!(opts, :num_threads, :integer, -1) 112 | Helper.might_be_float_data!(query) 113 | features = div(byte_size(query), Helper.float_size()) 114 | Helper.ensure_vector_dimension!(self, features, true) 115 | 116 | _do_knn_query(self, query, k, num_threads, nil, 1, features) 117 | end 118 | 119 | def knn_query(self = %T{}, query, opts) when is_list(query) do 120 | k = Helper.get_keyword!(opts, :k, :pos_integer, 1) 121 | num_threads = Helper.get_keyword!(opts, :num_threads, :integer, -1) 122 | {rows, features} = Helper.list_of_binary(query) 123 | Helper.ensure_vector_dimension!(self, features, true) 124 | 125 | _do_knn_query(self, IO.iodata_to_binary(query), k, num_threads, nil, rows, features) 126 | end 127 | 128 | def knn_query(self = %T{}, query = %Nx.Tensor{}, opts) do 129 | k = Helper.get_keyword!(opts, :k, :pos_integer, 1) 130 | num_threads = Helper.get_keyword!(opts, :num_threads, :integer, -1) 131 | {f32_data, rows, features} = Helper.verify_data_tensor!(self, query) 132 | 133 | _do_knn_query(self, f32_data, k, num_threads, nil, rows, features) 134 | end 135 | 136 | defp _do_knn_query(self = %T{}, query, k, num_threads, filter, rows, features) do 137 | case HNSWLib.Nif.index_knn_query( 138 | self.reference, 139 | query, 140 | k, 141 | num_threads, 142 | filter, 143 | rows, 144 | features 145 | ) do 146 | {:ok, labels, dists, rows, k, label_bits, dist_bits} -> 147 | labels = Nx.reshape(Nx.from_binary(labels, :"u#{label_bits}"), {rows, k}) 148 | dists = Nx.reshape(Nx.from_binary(dists, :"f#{dist_bits}"), {rows, k})
149 | {:ok, labels, dists} 150 | 151 | {:error, reason} -> 152 | {:error, reason} 153 | end 154 | end 155 | 156 | @doc """ 157 | Get a list of existing IDs in the index. 158 | """ 159 | @spec get_ids_list(%T{}) :: {:ok, [integer()]} | {:error, String.t()} 160 | def get_ids_list(self = %T{}) do 161 | HNSWLib.Nif.index_get_ids_list(self.reference) 162 | end 163 | 164 | @doc """ 165 | Get the ef parameter. 166 | """ 167 | @spec get_ef(%T{}) :: {:ok, non_neg_integer()} | {:error, String.t()} 168 | def get_ef(self = %T{}) do 169 | HNSWLib.Nif.index_get_ef(self.reference) 170 | end 171 | 172 | @doc """ 173 | Set the ef parameter. 174 | """ 175 | @spec set_ef(%T{}, non_neg_integer()) :: :ok | {:error, String.t()} 176 | def set_ef(self = %T{}, new_ef) when is_integer(new_ef) and new_ef >= 0 do 177 | HNSWLib.Nif.index_set_ef(self.reference, new_ef) 178 | end 179 | 180 | @doc """ 181 | Get the number of threads to use. 182 | """ 183 | @spec get_num_threads(%T{}) :: {:ok, integer()} | {:error, String.t()} 184 | def get_num_threads(self = %T{}) do 185 | HNSWLib.Nif.index_get_num_threads(self.reference) 186 | end 187 | 188 | @doc """ 189 | Set the number of threads to use. 190 | """ 191 | @spec set_num_threads(%T{}, integer()) :: :ok | {:error, String.t()} 192 | def set_num_threads(self = %T{}, new_num_threads) when is_integer(new_num_threads) do 193 | HNSWLib.Nif.index_set_num_threads(self.reference, new_num_threads) 194 | end 195 | 196 | @doc """ 197 | Get the size of the index file. 198 | """ 199 | @spec index_file_size(%T{}) :: {:ok, non_neg_integer()} | {:error, String.t()} 200 | def index_file_size(self = %T{}) do 201 | HNSWLib.Nif.index_index_file_size(self.reference) 202 | end 203 | 204 | @doc """ 205 | Save current index to disk. 206 | 207 | ##### Positional Parameters 208 | 209 | - *path*: `Path.t()`. 210 | 211 | Path to save the index to.
212 | """ 213 | @spec save_index(%T{}, Path.t()) :: :ok | {:error, String.t()} 214 | def save_index(self = %T{}, path) when is_binary(path) do 215 | HNSWLib.Nif.index_save_index(self.reference, path) 216 | end 217 | 218 | @doc """ 219 | Load index from disk. 220 | 221 | ##### Positional Parameters 222 | 223 | - *space*: `:cosine` | `:ip` | `:l2`. 224 | 225 | An atom that indicates the vector space. Valid values are 226 | 227 | - `:cosine`, cosine space 228 | - `:ip`, inner product space 229 | - `:l2`, L2 space 230 | 231 | - *dim*: `non_neg_integer()`. 232 | 233 | Number of dimensions for each vector. 234 | 235 | - *path*: `Path.t()`. 236 | 237 | Path to load the index from. 238 | 239 | ##### Keyword Parameters 240 | 241 | - *max_elements*: `non_neg_integer()`. 242 | 243 | Maximum number of elements to load from the index. 244 | If set to 0, all elements will be loaded. 245 | Defaults to 0. 246 | 247 | - *allow_replace_deleted*: `boolean()`. Defaults to `false`. 248 | """ 249 | @spec load_index(:cosine | :ip | :l2, non_neg_integer(), Path.t(), [ 250 | {:max_elements, non_neg_integer()}, 251 | {:allow_replace_deleted, boolean()} 252 | ]) :: {:ok, %T{}} | {:error, String.t()} 253 | def load_index(space, dim, path, opts \\ []) 254 | when (space == :l2 or space == :ip or space == :cosine) and is_integer(dim) and dim >= 0 and 255 | is_binary(path) and is_list(opts) do 256 | max_elements = Helper.get_keyword!(opts, :max_elements, :non_neg_integer, 0) 257 | allow_replace_deleted = Helper.get_keyword!(opts, :allow_replace_deleted, :boolean, false) 258 | 259 | with {:ok, ref} <- 260 | HNSWLib.Nif.index_load_index(space, dim, path, max_elements, allow_replace_deleted) do 261 | {:ok, 262 | %T{ 263 | space: space, 264 | dim: dim, 265 | reference: ref 266 | }} 267 | else 268 | {:error, reason} -> 269 | {:error, reason} 270 | end 271 | end 272 | 273 | @doc """ 274 | Mark a label as deleted. 275 | 276 | ##### Positional Parameters 277 | 278 | - *label*: `non_neg_integer()`.
279 | 280 | Label to mark as deleted. 281 | """ 282 | @spec mark_deleted(%T{}, non_neg_integer()) :: :ok | {:error, String.t()} 283 | def mark_deleted(self = %T{}, label) when is_integer(label) and label >= 0 do 284 | HNSWLib.Nif.index_mark_deleted(self.reference, label) 285 | end 286 | 287 | @doc """ 288 | Unmark a label as deleted. 289 | 290 | ##### Positional Parameters 291 | 292 | - *label*: `non_neg_integer()`. 293 | 294 | Label to unmark as deleted. 295 | """ 296 | @spec unmark_deleted(%T{}, non_neg_integer()) :: :ok | {:error, String.t()} 297 | def unmark_deleted(self = %T{}, label) when is_integer(label) and label >= 0 do 298 | HNSWLib.Nif.index_unmark_deleted(self.reference, label) 299 | end 300 | 301 | @doc """ 302 | Add items to the index. 303 | 304 | ##### Positional Parameters 305 | 306 | - *data*: `Nx.Tensor.t()`. 307 | 308 | Data to add to the index. 309 | 310 | ##### Keyword Parameters 311 | 312 | - *ids*: `Nx.Tensor.t() | [non_neg_integer()] | nil`. 313 | 314 | IDs to assign to the data. 315 | 316 | If `nil`, IDs will be assigned sequentially starting from 0. 317 | 318 | Defaults to `nil`. 319 | 320 | - *num_threads*: `integer()`. 321 | 322 | Number of threads to use. 323 | 324 | If set to `-1`, the number of threads will be automatically determined. 325 | 326 | Defaults to `-1`. 327 | 328 | - *replace_deleted*: `boolean()`. 329 | 330 | Whether to replace deleted items. 331 | 332 | Defaults to `false`.
333 | """ 334 | @spec add_items(%T{}, Nx.Tensor.t(), [ 335 | {:ids, Nx.Tensor.t() | [non_neg_integer()] | nil}, 336 | {:num_threads, integer()}, 337 | {:replace_deleted, boolean()} 338 | ]) :: :ok | {:error, String.t()} 339 | def add_items(self, data, opts \\ []) 340 | 341 | def add_items(self = %T{}, data = %Nx.Tensor{}, opts) when is_list(opts) do 342 | num_threads = Helper.get_keyword!(opts, :num_threads, :integer, -1) 343 | replace_deleted = Helper.get_keyword!(opts, :replace_deleted, :boolean, false) 344 | ids = Helper.normalize_ids!(opts[:ids]) 345 | {f32_data, rows, features} = Helper.verify_data_tensor!(self, data) 346 | 347 | HNSWLib.Nif.index_add_items( 348 | self.reference, 349 | f32_data, 350 | ids, 351 | num_threads, 352 | replace_deleted, 353 | rows, 354 | features 355 | ) 356 | end 357 | 358 | @doc """ 359 | Retrieve items from the index using IDs. 360 | 361 | ##### Positional Parameters 362 | 363 | - *ids*: `Nx.Tensor.t() | [non_neg_integer()]`. 364 | 365 | IDs to retrieve. 366 | """ 367 | @spec get_items(%T{}, Nx.Tensor.t() | [non_neg_integer()]) :: {:ok, [binary()]} | {:error, String.t()} 368 | def get_items(self = %T{}, ids) do 369 | ids = Helper.normalize_ids!(ids) 370 | 371 | HNSWLib.Nif.index_get_items(self.reference, ids) 372 | end 373 | 374 | @doc """ 375 | Resize the index. 376 | 377 | ##### Positional Parameters 378 | 379 | - *new_size*: `non_neg_integer()`. 380 | 381 | New size of the index. 382 | """ 383 | @spec resize_index(%T{}, non_neg_integer()) :: :ok | {:error, String.t()} 384 | def resize_index(self = %T{}, new_size) when is_integer(new_size) and new_size >= 0 do 385 | HNSWLib.Nif.index_resize_index(self.reference, new_size) 386 | end 387 | 388 | @doc """ 389 | Get the maximum number of elements the index can hold.
390 | """ 391 | @spec get_max_elements(%T{}) :: {:ok, integer()} | {:error, String.t()} 392 | def get_max_elements(self = %T{}) do 393 | HNSWLib.Nif.index_get_max_elements(self.reference) 394 | end 395 | 396 | @doc """ 397 | Get the current number of elements in the index. 398 | """ 399 | @spec get_current_count(%T{}) :: {:ok, integer()} | {:error, String.t()} 400 | def get_current_count(self = %T{}) do 401 | HNSWLib.Nif.index_get_current_count(self.reference) 402 | end 403 | 404 | @doc """ 405 | Get the ef_construction parameter. 406 | """ 407 | @spec get_ef_construction(%T{}) :: {:ok, integer()} | {:error, String.t()} 408 | def get_ef_construction(self = %T{}) do 409 | HNSWLib.Nif.index_get_ef_construction(self.reference) 410 | end 411 | 412 | @doc """ 413 | Get the M parameter. 414 | """ 415 | @spec get_m(%T{}) :: {:ok, integer()} | {:error, String.t()} 416 | def get_m(self = %T{}) do 417 | HNSWLib.Nif.index_get_m(self.reference) 418 | end 419 | end 420 | -------------------------------------------------------------------------------- /lib/hnswlib_nif.ex: -------------------------------------------------------------------------------- 1 | defmodule HNSWLib.Nif do 2 | @moduledoc false 3 | 4 | @on_load :load_nif 5 | def load_nif do 6 | nif_file = ~c"#{:code.priv_dir(:hnswlib)}/hnswlib_nif" 7 | # A failed load is only reported: IO.puts/1 returns :ok, so @on_load still succeeds and the stubs below raise :not_loaded when called. 8 | case :erlang.load_nif(nif_file, 0) do 9 | :ok -> :ok 10 | {:error, {:reload, _}} -> :ok 11 | {:error, reason} -> IO.puts("Failed to load nif: #{inspect(reason)}") 12 | end 13 | end 14 | 15 | def index_new( 16 | _space, 17 | _dim, 18 | _max_elements, 19 | _m, 20 | _ef_construction, 21 | _random_seed, 22 | _allow_replace_deleted 23 | ), 24 | do: :erlang.nif_error(:not_loaded) 25 | 26 | def index_knn_query(_self, _data, _k, _num_threads, _filter, _rows, _features), 27 | do: :erlang.nif_error(:not_loaded) 28 | 29 | def index_add_items(_self, _f32_data, _ids, _num_threads, _replace_deleted, _rows, _features), 30 | do: 
:erlang.nif_error(:not_loaded) 31 | 32 | def
index_get_items(_self, _ids), do: :erlang.nif_error(:not_loaded) 33 | 34 | def index_get_ids_list(_self), do: :erlang.nif_error(:not_loaded) 35 | 36 | def index_get_ef(_self), do: :erlang.nif_error(:not_loaded) 37 | 38 | def index_set_ef(_self, _new_ef), do: :erlang.nif_error(:not_loaded) 39 | 40 | def index_get_num_threads(_self), do: :erlang.nif_error(:not_loaded) 41 | 42 | def index_set_num_threads(_self, _new_num_threads), do: :erlang.nif_error(:not_loaded) 43 | 44 | def index_index_file_size(_self), do: :erlang.nif_error(:not_loaded) 45 | 46 | def index_save_index(_self, _path), do: :erlang.nif_error(:not_loaded) 47 | 48 | def index_load_index(_space, _dim, _path, _max_elements, _allow_replace_deleted), 49 | do: :erlang.nif_error(:not_loaded) 50 | 51 | def index_mark_deleted(_self, _label), do: :erlang.nif_error(:not_loaded) 52 | 53 | def index_unmark_deleted(_self, _label), do: :erlang.nif_error(:not_loaded) 54 | 55 | def index_resize_index(_self, _new_size), do: :erlang.nif_error(:not_loaded) 56 | 57 | def index_get_max_elements(_self), do: :erlang.nif_error(:not_loaded) 58 | 59 | def index_get_current_count(_self), do: :erlang.nif_error(:not_loaded) 60 | 61 | def index_get_ef_construction(_self), do: :erlang.nif_error(:not_loaded) 62 | 63 | def index_get_m(_self), do: :erlang.nif_error(:not_loaded) 64 | # `bfindex_*` stubs backing HNSWLib.BFIndex; like the `index_*` stubs above, each raises via :erlang.nif_error(:not_loaded) unless the native library has replaced it. 65 | def bfindex_new(_space, _dim, _max_elements), do: :erlang.nif_error(:not_loaded) 66 | 67 | def bfindex_knn_query(_self, _data, _k, _filter, _rows, _features), 68 | do: :erlang.nif_error(:not_loaded) 69 | 70 | def bfindex_add_items(_self, _f32_data, _ids, _rows, _features), 71 | do: :erlang.nif_error(:not_loaded) 72 | 73 | def bfindex_delete_vector(_self, _label), do: :erlang.nif_error(:not_loaded) 74 | 75 | def bfindex_set_num_threads(_self, _num_threads), do: 
:erlang.nif_error(:not_loaded) 76 | 77 | def bfindex_save_index(_self, _path), do: :erlang.nif_error(:not_loaded) 78 | 79 | def bfindex_load_index(_space, _dim, _path, _max_elements), do:
:erlang.nif_error(:not_loaded) 80 | 81 | def bfindex_get_max_elements(_self), do: :erlang.nif_error(:not_loaded) 82 | 83 | def bfindex_get_current_count(_self), do: :erlang.nif_error(:not_loaded) 84 | 85 | def bfindex_get_num_threads(_self), do: :erlang.nif_error(:not_loaded) 86 | def float_size, do: :erlang.nif_error(:not_loaded) 87 | end 88 | -------------------------------------------------------------------------------- /mix.exs: -------------------------------------------------------------------------------- 1 | defmodule HNSWLib.MixProject do 2 | use Mix.Project 3 | 4 | @version "0.1.6-dev" 5 | @github_url "https://github.com/elixir-nx/hnswlib" 6 | 7 | def project do 8 | [ 9 | app: :hnswlib, 10 | version: @version, 11 | elixir: "~> 1.12", 12 | start_permanent: Mix.env() == :prod, 13 | deps: deps(), 14 | package: package(), 15 | docs: docs(), 16 | main: "HNSWLib", # NOTE(review): `main` is normally an ex_doc option set inside `docs/0`; as a top-level project key it may have no effect — confirm. 17 | description: "Elixir binding for the hnswlib library", 18 | compilers: [:elixir_make] ++ Mix.compilers(), 19 | make_precompiler: {:nif, CCPrecompiler}, 20 | make_precompiler_url: "#{@github_url}/releases/download/v#{@version}/@{artefact_filename}", 21 | make_precompiler_filename: "hnswlib_nif", 22 | make_precompiler_nif_versions: [versions: ["2.16", "2.17"]], 23 | cc_precompiler: cc_precompiler() 24 | ] 25 | end 26 | 27 | def application do 28 | [ 29 | extra_applications: [:logger] 30 | ] 31 | end 32 | 33 | defp deps do 34 | [ 35 | # compilation 36 | {:cc_precompiler, "~> 0.1.0"}, 37 | {:elixir_make, "~> 0.8.0"}, 38 | 39 | # runtime 40 | {:nx, "~> 0.5"}, 41 | 42 | # docs 43 | {:ex_doc, "~> 0.29", only: :docs, runtime: false} 44 | ] 45 | end 46 | 47 | defp docs do 48 | [ 49 | source_ref: "v#{@version}", 50 | source_url: @github_url 51 | ] 52 | end 53 | # The extra cc_precompiler options below are added only when HNSWLIB_CI_PRECOMPILE is "true", as exported by the CI precompile scripts. 54 | defp cc_precompiler do 55 | extra_options = 56 | if 
System.get_env("HNSWLIB_CI_PRECOMPILE") == "true" do 57 | [ 58 | only_listed_targets: true, 59 | exclude_current_target: true 60 | ] 61 | else 62 | [] 63 | end 64 | 65 | [cleanup: "cleanup"]
++ extra_options 66 | end 67 | 68 | defp package() do 69 | [ 70 | files: 71 | ~w(3rd_party/hnswlib c_src lib mix.exs README* LICENSE* CMakeLists.txt Makefile checksum.exs), 72 | licenses: ["Apache-2.0"], 73 | links: %{"GitHub" => @github_url} 74 | ] 75 | end 76 | end 77 | -------------------------------------------------------------------------------- /mix.lock: -------------------------------------------------------------------------------- 1 | %{ 2 | "cc_precompiler": {:hex, :cc_precompiler, "0.1.9", "e8d3364f310da6ce6463c3dd20cf90ae7bbecbf6c5203b98bf9b48035592649b", [:mix], [{:elixir_make, "~> 0.7", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "9dcab3d0f3038621f1601f13539e7a9ee99843862e66ad62827b0c42b2f58a54"}, 3 | "complex": {:hex, :complex, "0.5.0", "af2d2331ff6170b61bb738695e481b27a66780e18763e066ee2cd863d0b1dd92", [:mix], [], "hexpm", "2683bd3c184466cfb94fad74cbfddfaa94b860e27ad4ca1bffe3bff169d91ef1"}, 4 | "earmark_parser": {:hex, :earmark_parser, "1.4.35", "437773ca9384edf69830e26e9e7b2e0d22d2596c4a6b17094a3b29f01ea65bb8", [:mix], [], "hexpm", "8652ba3cb85608d0d7aa2d21b45c6fad4ddc9a1f9a1f1b30ca3a246f0acc33f6"}, 5 | "elixir_make": {:hex, :elixir_make, "0.8.3", "d38d7ee1578d722d89b4d452a3e36bcfdc644c618f0d063b874661876e708683", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: true]}, {:certifi, "~> 2.0", [hex: :certifi, repo: "hexpm", optional: true]}], "hexpm", "5c99a18571a756d4af7a4d89ca75c28ac899e6103af6f223982f09ce44942cc9"}, 6 | "ex_doc": {:hex, :ex_doc, "0.30.6", "5f8b54854b240a2b55c9734c4b1d0dd7bdd41f71a095d42a70445c03cf05a281", [:mix], [{:earmark_parser, "~> 1.4.31", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_elixir, "~> 0.14", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1", [hex: :makeup_erlang, repo: "hexpm", optional: false]}], "hexpm", "bd48f2ddacf4e482c727f9293d9498e0881597eae6ddc3d9562bd7923375109f"}, 7 | "makeup": {:hex, :makeup,
"1.1.0", "6b67c8bc2882a6b6a445859952a602afc1a41c2e08379ca057c0f525366fc3ca", [:mix], [{:nimble_parsec, "~> 1.2.2 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "0a45ed501f4a8897f580eabf99a2e5234ea3e75a4373c8a52824f6e873be57a6"}, 8 | "makeup_elixir": {:hex, :makeup_elixir, "0.16.1", "cc9e3ca312f1cfeccc572b37a09980287e243648108384b97ff2b76e505c3555", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}, {:nimble_parsec, "~> 1.2.3 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "e127a341ad1b209bd80f7bd1620a15693a9908ed780c3b763bccf7d200c767c6"}, 9 | "makeup_erlang": {:hex, :makeup_erlang, "0.1.2", "ad87296a092a46e03b7e9b0be7631ddcf64c790fa68a9ef5323b6cbb36affc72", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "f3f5a1ca93ce6e092d92b6d9c049bcda58a3b617a8d888f8e7231c85630e8108"}, 10 | "nimble_parsec": {:hex, :nimble_parsec, "1.3.1", "2c54013ecf170e249e9291ed0a62e5832f70a476c61da16f6aac6dca0189f2af", [:mix], [], "hexpm", "2682e3c0b2eb58d90c6375fc0cc30bc7be06f365bf72608804fb9cffa5e1b167"}, 11 | "nx": {:hex, :nx, "0.6.1", "df65cd61312bcaa756559fb994596d403c822e353616094fdfc31a15181c95f8", [:mix], [{:complex, "~> 0.5", [hex: :complex, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "23dcc8e2824a6e19fcdebef39145fdff7625fd7d26fd50c1990ac0a1dd05f960"}, 12 | "telemetry": {:hex, :telemetry, "1.2.1", "68fdfe8d8f05a8428483a97d7aab2f268aaff24b49e0f599faa091f1d4e7f61c", [:rebar3], [], "hexpm", "dad9ce9d8effc621708f99eac538ef1cbe05d6a874dd741de2e689c47feafed5"}, 13 | } 14 | -------------------------------------------------------------------------------- /test/hnswlib_bfindex_test.exs: -------------------------------------------------------------------------------- 1 | defmodule HNSWLib.BFIndex.Test do 2 | use ExUnit.Case 3 | doctest HNSWLib.BFIndex 4 | 5 | test 
"HNSWLib.BFIndex.new/3 with L2-space" do 6 | space = :l2 7 | dim = 8 8 | max_elements = 200 9 | {:ok, index} = HNSWLib.BFIndex.new(space, dim, max_elements) 10 | 11 | assert is_reference(index.reference) 12 | assert space == index.space 13 | assert dim == index.dim 14 | 15 | dim = 12 16 | {:ok, index} = HNSWLib.BFIndex.new(space, dim, max_elements) 17 | 18 | assert is_reference(index.reference) 19 | assert space == index.space 20 | assert dim == index.dim 21 | 22 | space = :cosine 23 | {:ok, index} = HNSWLib.BFIndex.new(space, dim, max_elements) 24 | 25 | assert is_reference(index.reference) 26 | assert space == index.space 27 | assert dim == index.dim 28 | 29 | space = :ip 30 | {:ok, index} = HNSWLib.BFIndex.new(space, dim, max_elements) 31 | 32 | assert is_reference(index.reference) 33 | assert space == index.space 34 | assert dim == index.dim 35 | end 36 | 37 | test "HNSWLib.BFIndex.new/3 with cosine-space" do 38 | space = :cosine 39 | dim = 8 40 | max_elements = 200 41 | {:ok, index} = HNSWLib.BFIndex.new(space, dim, max_elements) 42 | 43 | assert is_reference(index.reference) 44 | assert space == index.space 45 | assert dim == index.dim 46 | 47 | dim = 12 48 | {:ok, index} = HNSWLib.BFIndex.new(space, dim, max_elements) 49 | 50 | assert is_reference(index.reference) 51 | assert space == index.space 52 | assert dim == index.dim 53 | end 54 | 55 | test "HNSWLib.BFIndex.new/3 with inner-product space" do 56 | space = :ip 57 | dim = 8 58 | max_elements = 200 59 | {:ok, index} = HNSWLib.BFIndex.new(space, dim, max_elements) 60 | 61 | assert is_reference(index.reference) 62 | assert space == index.space 63 | assert dim == index.dim 64 | 65 | dim = 12 66 | {:ok, index} = HNSWLib.BFIndex.new(space, dim, max_elements) 67 | 68 | assert is_reference(index.reference) 69 | assert space == index.space 70 | assert dim == index.dim 71 | end 72 | 73 | test "HNSWLib.BFIndex.knn_query/2 with binary" do 74 | space = :l2 75 | dim = 2 76 | max_elements = 200 77 | 78 | data = 79 | 
Nx.tensor( 80 | [ 81 | [42, 42], 82 | [43, 43], 83 | [0, 0], 84 | [200, 200], 85 | [200, 220] 86 | ], 87 | type: :f32 88 | ) 89 | 90 | ids = [5, 6, 7, 8, 9] 91 | 92 | query = <<41.0::float-32-native, 41.0::float-32-native>> 93 | {:ok, index} = HNSWLib.BFIndex.new(space, dim, max_elements) 94 | assert :ok == HNSWLib.BFIndex.add_items(index, data, ids: ids) 95 | 96 | {:ok, labels, dists} = HNSWLib.BFIndex.knn_query(index, query, k: 3) 97 | assert 1 == Nx.to_number(Nx.all_close(labels, Nx.tensor([5, 6, 7]))) 98 | assert 1 == Nx.to_number(Nx.all_close(dists, Nx.tensor([2.0, 8.0, 3362.0]))) 99 | end 100 | 101 | test "HNSWLib.BFIndex.knn_query/2 with [binary]" do 102 | space = :l2 103 | dim = 2 104 | max_elements = 200 105 | 106 | data = 107 | Nx.tensor( 108 | [ 109 | [42, 42], 110 | [43, 43], 111 | [0, 0], 112 | [200, 200], 113 | [200, 220] 114 | ], 115 | type: :f32 116 | ) 117 | 118 | ids = [5, 6, 7, 8, 9] 119 | 120 | query = [ 121 | <<0.0::float-32-native, 0.0::float-32-native>>, 122 | <<41.0::float-32-native, 41.0::float-32-native>> 123 | ] 124 | 125 | {:ok, index} = HNSWLib.BFIndex.new(space, dim, max_elements) 126 | assert :ok == HNSWLib.BFIndex.add_items(index, data, ids: ids) 127 | 128 | {:ok, labels, dists} = HNSWLib.BFIndex.knn_query(index, query, k: 3) 129 | assert 1 == Nx.to_number(Nx.all_close(labels, Nx.tensor([[7, 5, 6], [5, 6, 7]]))) 130 | 131 | assert 1 == 132 | Nx.to_number( 133 | Nx.all_close(dists, Nx.tensor([[0.0, 3528.0, 3698.0], [2.0, 8.0, 3362.0]])) 134 | ) 135 | end 136 | 137 | test "HNSWLib.BFIndex.knn_query/2 with Nx.Tensor (:f32)" do 138 | space = :l2 139 | dim = 2 140 | max_elements = 200 141 | 142 | data = 143 | Nx.tensor( 144 | [ 145 | [42, 42], 146 | [43, 43], 147 | [0, 0], 148 | [200, 200], 149 | [200, 220] 150 | ], 151 | type: :f32 152 | ) 153 | 154 | query = Nx.tensor([1, 2], type: :f32) 155 | {:ok, index} = HNSWLib.BFIndex.new(space, dim, max_elements) 156 | assert :ok == HNSWLib.BFIndex.add_items(index, data) 157 | 158 | {:ok, labels, 
dists} = HNSWLib.BFIndex.knn_query(index, query) 159 | assert 1 == Nx.to_number(Nx.all_close(labels, Nx.tensor([2]))) 160 | assert 1 == Nx.to_number(Nx.all_close(dists, Nx.tensor([5]))) 161 | end 162 | 163 | test "HNSWLib.BFIndex.knn_query/2 with Nx.Tensor (:u8)" do 164 | space = :l2 165 | dim = 2 166 | max_elements = 200 167 | 168 | data = 169 | Nx.tensor( 170 | [ 171 | [42, 42], 172 | [43, 43], 173 | [0, 0], 174 | [200, 200], 175 | [200, 220] 176 | ], 177 | type: :f32 178 | ) 179 | 180 | query = Nx.tensor([1, 2], type: :u8) 181 | {:ok, index} = HNSWLib.BFIndex.new(space, dim, max_elements) 182 | assert :ok == HNSWLib.BFIndex.add_items(index, data) 183 | 184 | {:ok, labels, dists} = HNSWLib.BFIndex.knn_query(index, query) 185 | assert 1 == Nx.to_number(Nx.all_close(labels, Nx.tensor([2]))) 186 | assert 1 == Nx.to_number(Nx.all_close(dists, Nx.tensor([5]))) 187 | end 188 | 189 | test "HNSWLib.BFIndex.knn_query/2 with invalid length of data" do 190 | space = :ip 191 | dim = 2 192 | max_elements = 200 193 | {:ok, index} = HNSWLib.BFIndex.new(space, dim, max_elements) 194 | data = <<42::16, 1::24>> 195 | 196 | assert_raise ArgumentError, 197 | "vector feature size should be a multiple of 4 (sizeof(float))", 198 | fn -> 199 | HNSWLib.BFIndex.knn_query(index, data) 200 | end 201 | end 202 | 203 | test "HNSWLib.BFIndex.knn_query/2 with invalid dimensions of data" do 204 | space = :ip 205 | dim = 2 206 | max_elements = 200 207 | {:ok, index} = HNSWLib.BFIndex.new(space, dim, max_elements) 208 | data = <<42::float-32, 42::float-32, 42::float-32>> 209 | 210 | assert_raise ArgumentError, "Wrong dimensionality of the vectors, expect `2`, got `3`", fn -> 211 | HNSWLib.BFIndex.knn_query(index, data) 212 | end 213 | end 214 | 215 | test "HNSWLib.BFIndex.knn_query/2 with inconsistent dimensions of [data]" do 216 | space = :ip 217 | dim = 2 218 | max_elements = 200 219 | {:ok, index} = HNSWLib.BFIndex.new(space, dim, max_elements) 220 | data = [<<42::float-32, 42::float-32>>, 
<<42::float-32, 42::float-32, 42::float-32>>] 221 | 222 | assert_raise ArgumentError, "all vectors in the input list should have the same size", fn -> 223 | HNSWLib.BFIndex.knn_query(index, data) 224 | end 225 | end 226 | 227 | test "HNSWLib.BFIndex.knn_query/2 with invalid dimensions of [data]" do 228 | space = :ip 229 | dim = 2 230 | max_elements = 200 231 | {:ok, index} = HNSWLib.BFIndex.new(space, dim, max_elements) 232 | 233 | data = [ 234 | <<42::float-32, 42::float-32, 42::float-32>>, 235 | <<42::float-32, 42::float-32, 42::float-32>> 236 | ] 237 | 238 | assert_raise ArgumentError, "Wrong dimensionality of the vectors, expect `2`, got `3`", fn -> 239 | HNSWLib.BFIndex.knn_query(index, data) 240 | end 241 | end 242 | 243 | test "HNSWLib.BFIndex.knn_query/2 with invalid type for `k`" do 244 | space = :ip 245 | dim = 2 246 | max_elements = 200 247 | {:ok, index} = HNSWLib.BFIndex.new(space, dim, max_elements) 248 | data = <<42.0, 42.0>> 249 | k = :invalid 250 | 251 | assert_raise ArgumentError, 252 | "expect keyword parameter `:k` to be a positive integer, got `:invalid`", 253 | fn -> 254 | HNSWLib.BFIndex.knn_query(index, data, k: k) 255 | end 256 | end 257 | 258 | test "HNSWLib.BFIndex.add_items/3 without specifying ids" do 259 | space = :l2 260 | dim = 2 261 | max_elements = 200 262 | items = Nx.tensor([[10, 20], [30, 40]], type: :f32) 263 | {:ok, index} = HNSWLib.BFIndex.new(space, dim, max_elements) 264 | 265 | assert :ok == HNSWLib.BFIndex.add_items(index, items) 266 | end 267 | 268 | test "HNSWLib.BFIndex.add_items/3 with specifying ids (Nx.Tensor)" do 269 | space = :l2 270 | dim = 2 271 | max_elements = 200 272 | items = Nx.tensor([[10, 20], [30, 40]], type: :f32) 273 | ids = Nx.tensor([100, 200]) 274 | {:ok, index} = HNSWLib.BFIndex.new(space, dim, max_elements) 275 | 276 | assert :ok == HNSWLib.BFIndex.add_items(index, items, ids: ids) 277 | end 278 | 279 | test "HNSWLib.BFIndex.add_items/3 with specifying ids (list)" do 280 | space = :l2 281 | dim = 
2 282 | max_elements = 200 283 | items = Nx.tensor([[10, 20], [30, 40]], type: :f32) 284 | ids = [100, 200] 285 | {:ok, index} = HNSWLib.BFIndex.new(space, dim, max_elements) 286 | 287 | assert :ok == HNSWLib.BFIndex.add_items(index, items, ids: ids) 288 | end 289 | 290 | test "HNSWLib.BFIndex.add_items/3 with wrong dim of data tensor" do 291 | space = :l2 292 | dim = 2 293 | max_elements = 200 294 | items = Nx.tensor([[10, 20, 300], [30, 40, 500]], type: :f32) 295 | {:ok, index} = HNSWLib.BFIndex.new(space, dim, max_elements) 296 | 297 | assert_raise ArgumentError, "Wrong dimensionality of the vectors, expect `2`, got `3`", fn -> 298 | HNSWLib.BFIndex.add_items(index, items) 299 | end 300 | end 301 | 302 | test "HNSWLib.BFIndex.add_items/3 with wrong dim of ids" do 303 | space = :l2 304 | dim = 2 305 | max_elements = 200 306 | items = Nx.tensor([[10, 20], [30, 40]], type: :f32) 307 | ids = Nx.tensor([[100], [200]]) 308 | {:ok, index} = HNSWLib.BFIndex.new(space, dim, max_elements) 309 | 310 | assert_raise ArgumentError, "expect ids to be a 1D array, got `{2, 1}`", fn -> 311 | HNSWLib.BFIndex.add_items(index, items, ids: ids) 312 | end 313 | end 314 | 315 | test "HNSWLib.BFIndex.delete_vector/2" do 316 | space = :l2 317 | dim = 2 318 | max_elements = 200 319 | items = Nx.tensor([[10, 20], [30, 40]], type: :f32) 320 | query = Nx.tensor([29, 39], type: :f32) 321 | {:ok, index} = HNSWLib.BFIndex.new(space, dim, max_elements) 322 | 323 | assert :ok == HNSWLib.BFIndex.add_items(index, items) 324 | assert :ok == HNSWLib.BFIndex.delete_vector(index, 0) 325 | 326 | {:ok, labels, dists} = HNSWLib.BFIndex.knn_query(index, query) 327 | assert 1 == Nx.to_number(Nx.all_close(labels, Nx.tensor([1]))) 328 | assert 1 == Nx.to_number(Nx.all_close(dists, Nx.tensor([2]))) 329 | end 330 | 331 | test "HNSWLib.BFIndex.save_index/2" do 332 | space = :l2 333 | dim = 2 334 | max_elements = 200 335 | items = Nx.tensor([[10, 20], [30, 40]], type: :f32) 336 | ids = Nx.tensor([100, 200]) 337 | 
save_to = Path.join([__DIR__, "saved_bfindex.bin"]) 338 | {:ok, index} = HNSWLib.BFIndex.new(space, dim, max_elements) 339 | :ok = HNSWLib.BFIndex.add_items(index, items, ids: ids) 340 | 341 | # ensure file does not exist 342 | File.rm(save_to) 343 | assert :ok == HNSWLib.BFIndex.save_index(index, save_to) 344 | assert File.exists?(save_to) 345 | 346 | # cleanup 347 | File.rm(save_to) 348 | end 349 | 350 | test "HNSWLib.BFIndex.load_index/3" do 351 | space = :l2 352 | dim = 2 353 | max_elements = 200 354 | items = Nx.tensor([[10, 20], [30, 40]], type: :f32) 355 | ids = Nx.tensor([100, 200]) 356 | save_to = Path.join([__DIR__, "saved_bfindex.bin"]) 357 | {:ok, index} = HNSWLib.BFIndex.new(space, dim, max_elements) 358 | :ok = HNSWLib.BFIndex.add_items(index, items, ids: ids) 359 | 360 | # ensure file does not exist 361 | File.rm(save_to) 362 | assert :ok == HNSWLib.BFIndex.save_index(index, save_to) 363 | assert File.exists?(save_to) 364 | 365 | {:ok, index_from_save} = HNSWLib.BFIndex.load_index(space, dim, save_to) 366 | 367 | assert HNSWLib.BFIndex.get_max_elements(index) == 368 | HNSWLib.BFIndex.get_max_elements(index_from_save) 369 | 370 | # cleanup 371 | File.rm(save_to) 372 | end 373 | 374 | test "HNSWLib.BFIndex.load_index/3 with new max_elements" do 375 | space = :l2 376 | dim = 2 377 | max_elements = 200 378 | items = Nx.tensor([[10, 20], [30, 40]], type: :f32) 379 | ids = Nx.tensor([100, 200]) 380 | save_to = Path.join([__DIR__, "saved_bfindex.bin"]) 381 | {:ok, index} = HNSWLib.BFIndex.new(space, dim, max_elements) 382 | :ok = HNSWLib.BFIndex.add_items(index, items, ids: ids) 383 | 384 | # ensure file does not exist 385 | File.rm(save_to) 386 | assert :ok == HNSWLib.BFIndex.save_index(index, save_to) 387 | assert File.exists?(save_to) 388 | 389 | new_max_elements = 100 390 | 391 | {:ok, _index_from_save} = 392 | HNSWLib.BFIndex.load_index(space, dim, save_to, max_elements: new_max_elements) 393 | 394 | assert {:ok, 200} == 
HNSWLib.BFIndex.get_max_elements(index) 395 | # fix: upstream bug? 396 | # assert {:ok, 100} == HNSWLib.BFIndex.get_max_elements(index_from_save) 397 | 398 | # cleanup 399 | File.rm(save_to) 400 | end 401 | 402 | test "HNSWLib.BFIndex.get_max_elements/1" do 403 | space = :l2 404 | dim = 2 405 | max_elements = 200 406 | {:ok, index} = HNSWLib.BFIndex.new(space, dim, max_elements) 407 | 408 | assert {:ok, 200} == HNSWLib.BFIndex.get_max_elements(index) 409 | end 410 | 411 | test "HNSWLib.BFIndex.get_current_count/1 when empty" do 412 | space = :l2 413 | dim = 2 414 | max_elements = 200 415 | {:ok, index} = HNSWLib.BFIndex.new(space, dim, max_elements) 416 | 417 | assert {:ok, 0} == HNSWLib.BFIndex.get_current_count(index) 418 | end 419 | 420 | test "HNSWLib.BFIndex.get_current_count/1 before and after" do 421 | space = :l2 422 | dim = 2 423 | max_elements = 200 424 | items = Nx.tensor([[10, 20], [30, 40]], type: :f32) 425 | {:ok, index} = HNSWLib.BFIndex.new(space, dim, max_elements) 426 | 427 | assert {:ok, 0} == HNSWLib.BFIndex.get_current_count(index) 428 | assert :ok == HNSWLib.BFIndex.add_items(index, items) 429 | assert {:ok, 2} == HNSWLib.BFIndex.get_current_count(index) 430 | end 431 | 432 | test "HNSWLib.BFIndex.get_num_threads/1 default case" do 433 | space = :l2 434 | dim = 2 435 | max_elements = 200 436 | {:ok, index} = HNSWLib.BFIndex.new(space, dim, max_elements) 437 | 438 | {:ok, cur_num_threads} = HNSWLib.BFIndex.get_num_threads(index) 439 | assert System.schedulers() == cur_num_threads 440 | end 441 | 442 | test "HNSWLib.BFIndex.get_num_threads/1 before and after HNSWLib.BFIndex.set_num_threads/2" do 443 | space = :l2 444 | dim = 2 445 | max_elements = 200 446 | {:ok, index} = HNSWLib.BFIndex.new(space, dim, max_elements) 447 | 448 | {:ok, cur_num_threads} = HNSWLib.BFIndex.get_num_threads(index) 449 | assert System.schedulers() == cur_num_threads 450 | assert :ok == HNSWLib.BFIndex.set_num_threads(index, cur_num_threads + 1) 451 | {:ok, 
updated_num_threads} = HNSWLib.BFIndex.get_num_threads(index) 452 | assert updated_num_threads == cur_num_threads + 1 453 | end 454 | end 455 | -------------------------------------------------------------------------------- /test/hnswlib_index_test.exs: -------------------------------------------------------------------------------- 1 | defmodule HNSWLib.Index.Test do 2 | use ExUnit.Case 3 | doctest HNSWLib.Index 4 | 5 | test "HNSWLib.Index.new should not accept 0 for max_elements" do 6 | space = :l2 7 | dim = 8 8 | max_elements = 0 9 | assert_raise FunctionClauseError, "no function clause matching in HNSWLib.Index.new/4", fn -> 10 | HNSWLib.Index.new(space, dim, max_elements) 11 | end 12 | end 13 | 14 | test "HNSWLib.Index.new/3 with L2-space" do 15 | space = :l2 16 | dim = 8 17 | max_elements = 200 18 | {:ok, index} = HNSWLib.Index.new(space, dim, max_elements) 19 | 20 | assert is_reference(index.reference) 21 | assert space == index.space 22 | assert dim == index.dim 23 | 24 | dim = 12 25 | {:ok, index} = HNSWLib.Index.new(space, dim, max_elements) 26 | 27 | assert is_reference(index.reference) 28 | assert space == index.space 29 | assert dim == index.dim 30 | 31 | space = :cosine 32 | {:ok, index} = HNSWLib.Index.new(space, dim, max_elements) 33 | 34 | assert is_reference(index.reference) 35 | assert space == index.space 36 | assert dim == index.dim 37 | 38 | space = :ip 39 | {:ok, index} = HNSWLib.Index.new(space, dim, max_elements) 40 | 41 | assert is_reference(index.reference) 42 | assert space == index.space 43 | assert dim == index.dim 44 | end 45 | 46 | test "HNSWLib.Index.new/3 with cosine-space" do 47 | space = :cosine 48 | dim = 8 49 | max_elements = 200 50 | {:ok, index} = HNSWLib.Index.new(space, dim, max_elements) 51 | 52 | assert is_reference(index.reference) 53 | assert space == index.space 54 | assert dim == index.dim 55 | 56 | dim = 12 57 | {:ok, index} = HNSWLib.Index.new(space, dim, max_elements) 58 | 59 | assert is_reference(index.reference) 
60 | assert space == index.space 61 | assert dim == index.dim 62 | end 63 | 64 | test "HNSWLib.Index.new/3 with inner-product space" do 65 | space = :ip 66 | dim = 8 67 | max_elements = 200 68 | {:ok, index} = HNSWLib.Index.new(space, dim, max_elements) 69 | 70 | assert is_reference(index.reference) 71 | assert space == index.space 72 | assert dim == index.dim 73 | 74 | dim = 12 75 | {:ok, index} = HNSWLib.Index.new(space, dim, max_elements) 76 | 77 | assert is_reference(index.reference) 78 | assert space == index.space 79 | assert dim == index.dim 80 | end 81 | 82 | test "HNSWLib.Index.new/3 with non-default keyword parameters" do 83 | space = :ip 84 | dim = 8 85 | max_elements = 200 86 | 87 | m = 200 88 | ef_construction = 400 89 | random_seed = 42 90 | allow_replace_deleted = true 91 | 92 | {:ok, index} = 93 | HNSWLib.Index.new(space, dim, max_elements, 94 | m: m, 95 | ef_construction: ef_construction, 96 | random_seed: random_seed, 97 | allow_replace_deleted: allow_replace_deleted 98 | ) 99 | 100 | assert is_reference(index.reference) 101 | assert space == index.space 102 | assert dim == index.dim 103 | 104 | dim = 12 105 | 106 | {:ok, index} = 107 | HNSWLib.Index.new(space, dim, max_elements, 108 | m: m, 109 | ef_construction: ef_construction, 110 | random_seed: random_seed, 111 | allow_replace_deleted: allow_replace_deleted 112 | ) 113 | 114 | assert is_reference(index.reference) 115 | assert space == index.space 116 | assert dim == index.dim 117 | end 118 | 119 | test "HNSWLib.Index.new/3 with invalid keyword parameter m" do 120 | space = :ip 121 | dim = 8 122 | max_elements = 200 123 | 124 | m = -1 125 | 126 | assert_raise ArgumentError, 127 | "expect keyword parameter `:m` to be a non-negative integer, got `#{inspect(m)}`", 128 | fn -> 129 | HNSWLib.Index.new(space, dim, max_elements, m: m) 130 | end 131 | end 132 | 133 | test "HNSWLib.Index.new/3 with invalid keyword parameter ef_construction" do 134 | space = :ip 135 | dim = 8 136 | max_elements = 200 
137 | 138 | ef_construction = -1 139 | 140 | assert_raise ArgumentError, 141 | "expect keyword parameter `:ef_construction` to be a non-negative integer, got `#{inspect(ef_construction)}`", 142 | fn -> 143 | HNSWLib.Index.new(space, dim, max_elements, ef_construction: ef_construction) 144 | end 145 | end 146 | 147 | test "HNSWLib.Index.new/3 with invalid keyword parameter random_seed" do 148 | space = :ip 149 | dim = 8 150 | max_elements = 200 151 | 152 | random_seed = -1 153 | 154 | assert_raise ArgumentError, 155 | "expect keyword parameter `:random_seed` to be a non-negative integer, got `#{inspect(random_seed)}`", 156 | fn -> 157 | HNSWLib.Index.new(space, dim, max_elements, random_seed: random_seed) 158 | end 159 | end 160 | 161 | test "HNSWLib.Index.new/3 with invalid keyword parameter allow_replace_deleted" do 162 | space = :ip 163 | dim = 8 164 | max_elements = 200 165 | 166 | allow_replace_deleted = -1 167 | 168 | assert_raise ArgumentError, 169 | "expect keyword parameter `:allow_replace_deleted` to be a boolean, got `#{inspect(allow_replace_deleted)}`", 170 | fn -> 171 | HNSWLib.Index.new(space, dim, max_elements, 172 | allow_replace_deleted: allow_replace_deleted 173 | ) 174 | end 175 | end 176 | 177 | test "HNSWLib.Index.knn_query/2 with binary" do 178 | space = :l2 179 | dim = 2 180 | max_elements = 200 181 | 182 | data = 183 | Nx.tensor( 184 | [ 185 | [42, 42], 186 | [43, 43], 187 | [0, 0], 188 | [200, 200], 189 | [200, 220] 190 | ], 191 | type: :f32 192 | ) 193 | 194 | ids = [5, 6, 7, 8, 9] 195 | 196 | query = <<41.0::float-32-native, 41.0::float-32-native>> 197 | {:ok, index} = HNSWLib.Index.new(space, dim, max_elements) 198 | assert :ok == HNSWLib.Index.add_items(index, data, ids: ids) 199 | 200 | {:ok, labels, dists} = HNSWLib.Index.knn_query(index, query, k: 3) 201 | assert 1 == Nx.to_number(Nx.all_close(labels, Nx.tensor([5, 6, 7]))) 202 | assert 1 == Nx.to_number(Nx.all_close(dists, Nx.tensor([2.0, 8.0, 3362.0]))) 203 | end 204 | 205 | test 
"HNSWLib.Index.knn_query/2 with [binary]" do 206 | space = :l2 207 | dim = 2 208 | max_elements = 200 209 | 210 | data = 211 | Nx.tensor( 212 | [ 213 | [42, 42], 214 | [43, 43], 215 | [0, 0], 216 | [200, 200], 217 | [200, 220] 218 | ], 219 | type: :f32 220 | ) 221 | 222 | ids = [5, 6, 7, 8, 9] 223 | 224 | query = [ 225 | <<0.0::float-32-native, 0.0::float-32-native>>, 226 | <<41.0::float-32-native, 41.0::float-32-native>> 227 | ] 228 | 229 | {:ok, index} = HNSWLib.Index.new(space, dim, max_elements) 230 | assert :ok == HNSWLib.Index.add_items(index, data, ids: ids) 231 | 232 | {:ok, labels, dists} = HNSWLib.Index.knn_query(index, query, k: 3) 233 | assert 1 == Nx.to_number(Nx.all_close(labels, Nx.tensor([[7, 5, 6], [5, 6, 7]]))) 234 | 235 | assert 1 == 236 | Nx.to_number( 237 | Nx.all_close(dists, Nx.tensor([[0.0, 3528.0, 3698.0], [2.0, 8.0, 3362.0]])) 238 | ) 239 | end 240 | 241 | test "HNSWLib.Index.knn_query/2 with Nx.Tensor (:f32)" do 242 | space = :l2 243 | dim = 2 244 | max_elements = 200 245 | 246 | data = 247 | Nx.tensor( 248 | [ 249 | [42, 42], 250 | [43, 43], 251 | [0, 0], 252 | [200, 200], 253 | [200, 220] 254 | ], 255 | type: :f32 256 | ) 257 | 258 | query = Nx.tensor([1, 2], type: :f32) 259 | {:ok, index} = HNSWLib.Index.new(space, dim, max_elements) 260 | assert :ok == HNSWLib.Index.add_items(index, data) 261 | 262 | {:ok, labels, dists} = HNSWLib.Index.knn_query(index, query) 263 | assert 1 == Nx.to_number(Nx.all_close(labels, Nx.tensor([2]))) 264 | assert 1 == Nx.to_number(Nx.all_close(dists, Nx.tensor([5]))) 265 | end 266 | 267 | test "HNSWLib.Index.knn_query/2 with Nx.Tensor (:u8)" do 268 | space = :l2 269 | dim = 2 270 | max_elements = 200 271 | 272 | data = 273 | Nx.tensor( 274 | [ 275 | [42, 42], 276 | [43, 43], 277 | [0, 0], 278 | [200, 200], 279 | [200, 220] 280 | ], 281 | type: :f32 282 | ) 283 | 284 | query = Nx.tensor([1, 2], type: :u8) 285 | {:ok, index} = HNSWLib.Index.new(space, dim, max_elements) 286 | assert :ok == 
HNSWLib.Index.add_items(index, data) 287 | 288 | {:ok, labels, dists} = HNSWLib.Index.knn_query(index, query) 289 | assert 1 == Nx.to_number(Nx.all_close(labels, Nx.tensor([2]))) 290 | assert 1 == Nx.to_number(Nx.all_close(dists, Nx.tensor([5]))) 291 | end 292 | 293 | test "HNSWLib.Index.knn_query/2 with invalid length of data" do 294 | space = :ip 295 | dim = 2 296 | max_elements = 200 297 | {:ok, index} = HNSWLib.Index.new(space, dim, max_elements) 298 | data = <<42::16, 1::24>> 299 | 300 | assert_raise ArgumentError, 301 | "vector feature size should be a multiple of 4 (sizeof(float))", 302 | fn -> 303 | HNSWLib.Index.knn_query(index, data) 304 | end 305 | end 306 | 307 | test "HNSWLib.Index.knn_query/2 with invalid dimensions of data" do 308 | space = :ip 309 | dim = 2 310 | max_elements = 200 311 | {:ok, index} = HNSWLib.Index.new(space, dim, max_elements) 312 | data = <<42::float-32, 42::float-32, 42::float-32>> 313 | 314 | assert_raise ArgumentError, "Wrong dimensionality of the vectors, expect `2`, got `3`", fn -> 315 | HNSWLib.Index.knn_query(index, data) 316 | end 317 | end 318 | 319 | test "HNSWLib.Index.knn_query/2 with inconsistent dimensions of [data]" do 320 | space = :ip 321 | dim = 2 322 | max_elements = 200 323 | {:ok, index} = HNSWLib.Index.new(space, dim, max_elements) 324 | data = [<<42::float-32, 42::float-32>>, <<42::float-32, 42::float-32, 42::float-32>>] 325 | 326 | assert_raise ArgumentError, "all vectors in the input list should have the same size", fn -> 327 | HNSWLib.Index.knn_query(index, data) 328 | end 329 | end 330 | 331 | test "HNSWLib.Index.knn_query/2 with invalid dimensions of [data]" do 332 | space = :ip 333 | dim = 2 334 | max_elements = 200 335 | {:ok, index} = HNSWLib.Index.new(space, dim, max_elements) 336 | 337 | data = [ 338 | <<42::float-32, 42::float-32, 42::float-32>>, 339 | <<42::float-32, 42::float-32, 42::float-32>> 340 | ] 341 | 342 | assert_raise ArgumentError, "Wrong dimensionality of the vectors, expect `2`, got 
`3`", fn -> 343 | HNSWLib.Index.knn_query(index, data) 344 | end 345 | end 346 | 347 | test "HNSWLib.Index.knn_query/2 with invalid type for `k`" do 348 | space = :ip 349 | dim = 2 350 | max_elements = 200 351 | {:ok, index} = HNSWLib.Index.new(space, dim, max_elements) 352 | data = <<42.0, 42.0>> 353 | k = :invalid 354 | 355 | assert_raise ArgumentError, 356 | "expect keyword parameter `:k` to be a positive integer, got `:invalid`", 357 | fn -> 358 | HNSWLib.Index.knn_query(index, data, k: k) 359 | end 360 | end 361 | 362 | test "HNSWLib.Index.knn_query/2 with invalid type for `num_threads`" do 363 | space = :ip 364 | dim = 2 365 | max_elements = 200 366 | {:ok, index} = HNSWLib.Index.new(space, dim, max_elements) 367 | data = <<42.0, 42.0>> 368 | num_threads = :invalid 369 | 370 | assert_raise ArgumentError, 371 | "expect keyword parameter `:num_threads` to be an integer, got `:invalid`", 372 | fn -> 373 | HNSWLib.Index.knn_query(index, data, num_threads: num_threads) 374 | end 375 | end 376 | 377 | # test "HNSWLib.Index.knn_query/2 with invalid type for `filter`" do 378 | # space = :ip 379 | # dim = 2 380 | # max_elements = 200 381 | # {:ok, index} = HNSWLib.Index.new(space, dim, max_elements) 382 | # data = <<42.0, 42.0>> 383 | # filter = :invalid 384 | 385 | # assert {:error, 386 | # "expect keyword parameter `:filter` to be a function that can be applied with 1 number of arguments , got `:invalid`"} == 387 | # HNSWLib.Index.knn_query(index, data, filter: filter) 388 | # end 389 | 390 | test "HNSWLib.Index.add_items/3 without specifying ids" do 391 | space = :l2 392 | dim = 2 393 | max_elements = 200 394 | items = Nx.tensor([[10, 20], [30, 40]], type: :f32) 395 | {:ok, index} = HNSWLib.Index.new(space, dim, max_elements) 396 | 397 | assert {:ok, []} == HNSWLib.Index.get_ids_list(index) 398 | 399 | assert :ok == HNSWLib.Index.add_items(index, items) 400 | assert {:ok, [0, 1]} == HNSWLib.Index.get_ids_list(index) 401 | end 402 | 403 | test 
"HNSWLib.Index.add_items/3 with specifying ids (Nx.Tensor)" do 404 | space = :l2 405 | dim = 2 406 | max_elements = 200 407 | items = Nx.tensor([[10, 20], [30, 40]], type: :f32) 408 | ids = Nx.tensor([100, 200]) 409 | {:ok, index} = HNSWLib.Index.new(space, dim, max_elements) 410 | 411 | assert {:ok, []} == HNSWLib.Index.get_ids_list(index) 412 | 413 | assert :ok == HNSWLib.Index.add_items(index, items, ids: ids) 414 | assert {:ok, [100, 200]} == HNSWLib.Index.get_ids_list(index) 415 | end 416 | 417 | test "HNSWLib.Index.add_items/3 with specifying ids (list)" do 418 | space = :l2 419 | dim = 2 420 | max_elements = 200 421 | items = Nx.tensor([[10, 20], [30, 40]], type: :f32) 422 | ids = [100, 200] 423 | {:ok, index} = HNSWLib.Index.new(space, dim, max_elements) 424 | 425 | assert {:ok, []} == HNSWLib.Index.get_ids_list(index) 426 | 427 | assert :ok == HNSWLib.Index.add_items(index, items, ids: ids) 428 | assert {:ok, [100, 200]} == HNSWLib.Index.get_ids_list(index) 429 | end 430 | 431 | test "HNSWLib.Index.add_items/3 with wrong dim of data tensor" do 432 | space = :l2 433 | dim = 2 434 | max_elements = 200 435 | items = Nx.tensor([[10, 20, 300], [30, 40, 500]], type: :f32) 436 | {:ok, index} = HNSWLib.Index.new(space, dim, max_elements) 437 | 438 | assert {:ok, []} == HNSWLib.Index.get_ids_list(index) 439 | 440 | assert_raise ArgumentError, "Wrong dimensionality of the vectors, expect `2`, got `3`", fn -> 441 | HNSWLib.Index.add_items(index, items) 442 | end 443 | end 444 | 445 | test "HNSWLib.Index.add_items/3 with wrong dim of ids" do 446 | space = :l2 447 | dim = 2 448 | max_elements = 200 449 | items = Nx.tensor([[10, 20], [30, 40]], type: :f32) 450 | ids = Nx.tensor([[100], [200]]) 451 | {:ok, index} = HNSWLib.Index.new(space, dim, max_elements) 452 | 453 | assert {:ok, []} == HNSWLib.Index.get_ids_list(index) 454 | 455 | assert_raise ArgumentError, "expect ids to be a 1D array, got `{2, 1}`", fn -> 456 | HNSWLib.Index.add_items(index, items, ids: ids) 457 | 
end 458 | end 459 | 460 | test "HNSWLib.Index.get_items/2" do 461 | space = :l2 462 | dim = 2 463 | max_elements = 200 464 | items = Nx.tensor([[10, 20], [30, 40]], type: :f32) 465 | {:ok, index} = HNSWLib.Index.new(space, dim, max_elements) 466 | 467 | assert {:ok, []} == HNSWLib.Index.get_ids_list(index) 468 | 469 | assert :ok == HNSWLib.Index.add_items(index, items) 470 | assert {:ok, [0, 1]} == HNSWLib.Index.get_ids_list(index) 471 | 472 | {:ok, [f32_binary_0, f32_binary_1]} = HNSWLib.Index.get_items(index, [0, 1]) 473 | assert f32_binary_0 == Nx.to_binary(items[0]) 474 | assert f32_binary_1 == Nx.to_binary(items[1]) 475 | 476 | {:ok, [f32_binary_1]} = HNSWLib.Index.get_items(index, [1]) 477 | assert f32_binary_1 == Nx.to_binary(items[1]) 478 | 479 | {:ok, [f32_binary_0]} = HNSWLib.Index.get_items(index, [0]) 480 | assert f32_binary_0 == Nx.to_binary(items[0]) 481 | 482 | assert {:error, "Label not found"} == HNSWLib.Index.get_items(index, [2]) 483 | end 484 | 485 | test "HNSWLib.Index.get_ids_list/1 when empty" do 486 | space = :ip 487 | dim = 2 488 | max_elements = 200 489 | {:ok, index} = HNSWLib.Index.new(space, dim, max_elements) 490 | 491 | assert {:ok, []} == HNSWLib.Index.get_ids_list(index) 492 | end 493 | 494 | test "HNSWLib.Index.get_ef/1 with default init config" do 495 | space = :ip 496 | dim = 2 497 | max_elements = 200 498 | {:ok, index} = HNSWLib.Index.new(space, dim, max_elements) 499 | 500 | assert {:ok, 10} == HNSWLib.Index.get_ef(index) 501 | end 502 | 503 | test "HNSWLib.Index.set_ef/2" do 504 | space = :ip 505 | dim = 2 506 | max_elements = 200 507 | new_ef = 1000 508 | {:ok, index} = HNSWLib.Index.new(space, dim, max_elements) 509 | 510 | assert {:ok, 10} == HNSWLib.Index.get_ef(index) 511 | assert :ok == HNSWLib.Index.set_ef(index, new_ef) 512 | assert {:ok, 1000} == HNSWLib.Index.get_ef(index) 513 | end 514 | 515 | test "HNSWLib.Index.get_num_threads/1 with default config" do 516 | space = :l2 517 | dim = 2 518 | max_elements = 200 519 
|     {:ok, index} = HNSWLib.Index.new(space, dim, max_elements)
520 |
521 |     {:ok, num_threads} = HNSWLib.Index.get_num_threads(index)
522 |     assert num_threads == System.schedulers_online()
523 |   end
524 |
525 |   test "HNSWLib.Index.set_num_threads/2" do
526 |     space = :l2
527 |     dim = 2
528 |     max_elements = 200
529 |     {:ok, index} = HNSWLib.Index.new(space, dim, max_elements)
530 |
531 |     {:ok, num_threads} = HNSWLib.Index.get_num_threads(index)
532 |     assert num_threads == System.schedulers_online()
533 |
534 |     :ok = HNSWLib.Index.set_num_threads(index, 2)
535 |     {:ok, num_threads} = HNSWLib.Index.get_num_threads(index)
536 |     assert num_threads == 2
537 |   end
538 |
539 |   test "HNSWLib.Index.index_file_size/1" do
540 |     space = :l2
541 |     dim = 2
542 |     max_elements = 2
543 |     {:ok, index} = HNSWLib.Index.new(space, dim, max_elements)
544 |
545 |     assert {:ok, 96} == HNSWLib.Index.index_file_size(index)
546 |
547 |     max_elements = 400
548 |     assert :ok == HNSWLib.Index.resize_index(index, max_elements)
549 |     assert {:ok, 96} == HNSWLib.Index.index_file_size(index)
550 |
551 |     items = Nx.tensor([[10, 20], [30, 40]], type: :f32)
552 |     ids = Nx.tensor([100, 200])
553 |     :ok = HNSWLib.Index.add_items(index, items, ids: ids)
554 |     assert {:ok, 400} == HNSWLib.Index.index_file_size(index)
555 |
556 |     max_elements = 800
557 |     assert :ok == HNSWLib.Index.resize_index(index, max_elements)
558 |     assert {:ok, 400} == HNSWLib.Index.index_file_size(index)
559 |
560 |     :ok = HNSWLib.Index.add_items(index, items, ids: ids)
561 |     assert {:ok, 400} == HNSWLib.Index.index_file_size(index)
562 |   end
563 |
564 |   # The save/load tests below write index files to disk. ExUnit's :tmp_dir tag gives each
565 |   # test its own freshly emptied directory, so nothing is written into test/, tests cannot
566 |   # race on a shared path, and no manual File.rm cleanup (skipped on failure) is needed.
567 |   @tag :tmp_dir
568 |   test "HNSWLib.Index.save_index/2", %{tmp_dir: tmp_dir} do
569 |     space = :l2
570 |     dim = 2
571 |     max_elements = 200
572 |     items = Nx.tensor([[10, 20], [30, 40]], type: :f32)
573 |     ids = Nx.tensor([100, 200])
574 |     save_to = Path.join(tmp_dir, "saved_index.bin")
575 |     {:ok, index} = HNSWLib.Index.new(space, dim, max_elements)
576 |     :ok = HNSWLib.Index.add_items(index, items, ids: ids)
577 |
578 |     refute File.exists?(save_to)
579 |     assert :ok == HNSWLib.Index.save_index(index, save_to)
580 |     assert File.exists?(save_to)
581 |   end
582 |
583 |   @tag :tmp_dir
584 |   test "HNSWLib.Index.load_index/3", %{tmp_dir: tmp_dir} do
585 |     space = :l2
586 |     dim = 2
587 |     max_elements = 200
588 |     items = Nx.tensor([[10, 20], [30, 40]], type: :f32)
589 |     ids = Nx.tensor([100, 200])
590 |     save_to = Path.join(tmp_dir, "saved_index.bin")
591 |     {:ok, index} = HNSWLib.Index.new(space, dim, max_elements)
592 |     :ok = HNSWLib.Index.add_items(index, items, ids: ids)
593 |
594 |     refute File.exists?(save_to)
595 |     assert :ok == HNSWLib.Index.save_index(index, save_to)
596 |     assert File.exists?(save_to)
597 |
598 |     {:ok, index_from_save} = HNSWLib.Index.load_index(space, dim, save_to)
599 |     assert HNSWLib.Index.get_ids_list(index) == HNSWLib.Index.get_ids_list(index_from_save)
600 |
601 |     assert HNSWLib.Index.get_current_count(index) ==
602 |              HNSWLib.Index.get_current_count(index_from_save)
603 |
604 |     assert HNSWLib.Index.get_ef(index) == HNSWLib.Index.get_ef(index_from_save)
605 |
606 |     assert HNSWLib.Index.get_ef_construction(index) ==
607 |              HNSWLib.Index.get_ef_construction(index_from_save)
608 |
609 |     assert HNSWLib.Index.get_m(index) == HNSWLib.Index.get_m(index_from_save)
610 |
611 |     assert HNSWLib.Index.get_max_elements(index) ==
612 |              HNSWLib.Index.get_max_elements(index_from_save)
613 |   end
614 |
615 |   @tag :tmp_dir
616 |   test "HNSWLib.Index.load_index/4 with new max_elements", %{tmp_dir: tmp_dir} do
617 |     space = :l2
618 |     dim = 2
619 |     max_elements = 200
620 |     items = Nx.tensor([[10, 20], [30, 40]], type: :f32)
621 |     ids = Nx.tensor([100, 200])
622 |     save_to = Path.join(tmp_dir, "saved_index.bin")
623 |     {:ok, index} = HNSWLib.Index.new(space, dim, max_elements)
624 |     :ok = HNSWLib.Index.add_items(index, items, ids: ids)
625 |
626 |     refute File.exists?(save_to)
627 |     assert :ok == HNSWLib.Index.save_index(index, save_to)
628 |     assert File.exists?(save_to)
629 |
630 |     new_max_elements = 100
631 |
632 |     {:ok, index_from_save} =
633 |       HNSWLib.Index.load_index(space, dim, save_to, max_elements: new_max_elements)
634 |
635 |     assert HNSWLib.Index.get_ids_list(index) == HNSWLib.Index.get_ids_list(index_from_save)
636 |
637 |     assert HNSWLib.Index.get_current_count(index) ==
638 |              HNSWLib.Index.get_current_count(index_from_save)
639 |
640 |     assert HNSWLib.Index.get_ef(index) == HNSWLib.Index.get_ef(index_from_save)
641 |
642 |     assert HNSWLib.Index.get_ef_construction(index) ==
643 |              HNSWLib.Index.get_ef_construction(index_from_save)
644 |
645 |     assert HNSWLib.Index.get_m(index) == HNSWLib.Index.get_m(index_from_save)
646 |     assert {:ok, 200} == HNSWLib.Index.get_max_elements(index)
647 |     assert {:ok, 100} == HNSWLib.Index.get_max_elements(index_from_save)
648 |   end
649 |
650 |   test "HNSWLib.Index.load_index/3 with missing file" do
651 |     bad_filepath = "this/file/doesnt/exist"
652 |     refute File.exists?(bad_filepath)
653 |     assert {:error, "Cannot open file"} = HNSWLib.Index.load_index(:l2, 2, bad_filepath)
654 |   end
655 |
656 |   test "HNSWLib.Index.mark_deleted/2" do
657 |     space = :ip
658 |     dim = 2
659 |     max_elements = 200
660 |     items = Nx.tensor([[10, 20], [30, 40]], type: :f32)
661 |     ids = Nx.tensor([100, 200])
662 |     {:ok, index} = HNSWLib.Index.new(space, dim, max_elements)
663 |     assert :ok == HNSWLib.Index.add_items(index, items, ids: ids)
664 |
665 |     assert :ok == HNSWLib.Index.mark_deleted(index, 100)
666 |   end
667 |
668 |   test "HNSWLib.Index.mark_deleted/2 when the id does not exist" do
669 |     space = :ip
670 |     dim = 2
671 |     max_elements = 200
672 |     items = Nx.tensor([[10, 20], [30, 40]], type: :f32)
673 |     ids = Nx.tensor([100, 200])
674 |     {:ok, index} = HNSWLib.Index.new(space, dim, max_elements)
675 |     assert :ok == HNSWLib.Index.add_items(index, items, ids: ids)
676 |
677 |     assert {:error, "Label not found"} == HNSWLib.Index.mark_deleted(index, 1000)
678 |   end
679 |
680 |   test "HNSWLib.Index.unmark_deleted/2" do
681 |     space = :ip
682 |     dim = 2
683 |     max_elements = 200
684 |     items = Nx.tensor([[10, 20], [30, 40]], type: :f32)
685 |     ids = Nx.tensor([100, 200])
686 |     {:ok, index} = HNSWLib.Index.new(space, dim, max_elements)
687 |     assert :ok == HNSWLib.Index.add_items(index, items, ids: ids)
688 |
689 |     assert :ok == HNSWLib.Index.mark_deleted(index, 100)
690 |     assert :ok == HNSWLib.Index.unmark_deleted(index, 100)
691 |   end
692 |
693 |   test "HNSWLib.Index.unmark_deleted/2 when the id does not exist" do
694 |     space = :ip
695 |     dim = 2
696 |     max_elements = 200
697 |     {:ok, index} = HNSWLib.Index.new(space, dim, max_elements)
698 |
699 |     assert {:error, "Label not found"} == HNSWLib.Index.unmark_deleted(index, 1000)
700 |   end
701 |
702 |   test "HNSWLib.Index.resize_index/2" do
703 |     space = :l2
704 |     dim = 2
705 |     max_elements = 200
706 |     {:ok, index} = HNSWLib.Index.new(space, dim, max_elements)
707 |
708 |     assert {:ok, 200} == HNSWLib.Index.get_max_elements(index)
709 |
710 |     max_elements = 400
711 |     assert :ok == HNSWLib.Index.resize_index(index, max_elements)
712 |     assert {:ok, 400} == HNSWLib.Index.get_max_elements(index)
713 |   end
714 |
715 |   test "HNSWLib.Index.resize_index/2 with size that exceeds memory capacity" do
716 |     space = :l2
717 |     dim = 200
718 |     max_elements = 2
719 |     {:ok, index} = HNSWLib.Index.new(space, dim, max_elements)
720 |
721 |     assert {:ok, 2} == HNSWLib.Index.get_max_elements(index)
722 |
723 |     max_elements = 999_999_999_999_999_999
724 |
725 |     assert {:error, "no enough memory available to resize the index"} ==
726 |              HNSWLib.Index.resize_index(index, max_elements)
727 |   end
728 |
729 |   test "HNSWLib.Index.get_max_elements/1" do
730 |     space = :l2
731 |     dim = 2
732 |     max_elements = 200
733 |     {:ok, index} = HNSWLib.Index.new(space, dim, max_elements)
734 |
735 |     assert {:ok, 200} == HNSWLib.Index.get_max_elements(index)
736 |   end
737 |
738 |   test "HNSWLib.Index.get_current_count/1 when empty" do
739 |     space = :l2
740 |     dim = 2
741 |     max_elements = 200
742 |     {:ok, index} = HNSWLib.Index.new(space, dim, max_elements)
743 |
744 |     assert {:ok, 0} == HNSWLib.Index.get_current_count(index)
745 |   end
746 |
747 |   test "HNSWLib.Index.get_current_count/1 before and after" do
748 |     space = :l2
749 |     dim = 2
750 |     max_elements = 200
751 |     items = Nx.tensor([[10, 20], [30, 40]], type: :f32)
752 |     {:ok, index} = HNSWLib.Index.new(space, dim, max_elements)
753 |
754 |     assert {:ok, 0} == HNSWLib.Index.get_current_count(index)
755 |     assert :ok == HNSWLib.Index.add_items(index, items)
756 |     assert {:ok, 2} == HNSWLib.Index.get_current_count(index)
757 |   end
758 |
759 |   test "HNSWLib.Index.get_ef_construction/1 with default config" do
760 |     space = :l2
761 |     dim = 2
762 |     max_elements = 200
763 |     {:ok, index} = HNSWLib.Index.new(space, dim, max_elements)
764 |
765 |     assert {:ok, 200} == HNSWLib.Index.get_ef_construction(index)
766 |   end
767 |
768 |   test "HNSWLib.Index.get_ef_construction/1 with custom config" do
769 |     space = :l2
770 |     dim = 2
771 |     max_elements = 200
772 |     ef_construction = 300
773 |     {:ok, index} = HNSWLib.Index.new(space, dim, max_elements, ef_construction: ef_construction)
774 |
775 |     assert {:ok, 300} == HNSWLib.Index.get_ef_construction(index)
776 |   end
777 |
778 |   test "HNSWLib.Index.get_m/1 with default config" do
779 |     space = :l2
780 |     dim = 2
781 |     max_elements = 200
782 |     {:ok, index} = HNSWLib.Index.new(space, dim, max_elements)
783 |
784 |     assert {:ok, 16} == HNSWLib.Index.get_m(index)
785 |   end
786 |
787 |   test "HNSWLib.Index.get_m/1 with custom config" do
788 |     space = :l2
789 |     dim = 2
790 |     max_elements = 200
791 |     m = 32
792 |     {:ok, index} = HNSWLib.Index.new(space, dim, max_elements, m: m)
793 |
794 |     assert {:ok, 32} == HNSWLib.Index.get_m(index)
795 |   end
796 | end
797 |
--------------------------------------------------------------------------------
/test/test_helper.exs:
--------------------------------------------------------------------------------
1 | ExUnit.start() # Starts ExUnit; called with no options, so default/app-env configuration applies.
2 |
--------------------------------------------------------------------------------