├── .formatter.exs ├── .github └── workflows │ └── test.yaml ├── .gitignore ├── Makefile ├── README.md ├── c_src ├── ex_faiss.cc └── ex_faiss │ ├── clustering.cc │ ├── clustering.h │ ├── index.cc │ ├── index.h │ ├── nif_util.cc │ └── nif_util.h ├── lib ├── ex_faiss.ex └── ex_faiss │ ├── clustering.ex │ ├── index.ex │ ├── nif.ex │ └── shared.ex ├── mix.exs ├── mix.lock └── test ├── ex_faiss ├── clustering_test.exs └── index_test.exs ├── ex_faiss_test.exs └── test_helper.exs /.formatter.exs: -------------------------------------------------------------------------------- 1 | # Used by "mix format" 2 | [ 3 | inputs: ["{mix,.formatter}.exs", "{config,lib,test}/**/*.{ex,exs}"] 4 | ] 5 | -------------------------------------------------------------------------------- /.github/workflows/test.yaml: -------------------------------------------------------------------------------- 1 | name: Test 2 | on: 3 | pull_request: 4 | push: 5 | branches: 6 | - main 7 | env: 8 | OTP_VERSION: "25.0" 9 | ELIXIR_VERSION: "1.14.0" 10 | MIX_ENV: test 11 | jobs: 12 | main: 13 | runs-on: ubuntu-latest 14 | steps: 15 | - uses: actions/checkout@v3 16 | - name: Install Erlang & Elixir 17 | uses: erlef/setup-beam@v1 18 | with: 19 | otp-version: "${{ env.OTP_VERSION }}" 20 | elixir-version: "${{ env.ELIXIR_VERSION }}" 21 | - uses: actions/cache@v3 22 | with: 23 | path: | 24 | deps 25 | _build 26 | cache 27 | key: ${{ runner.os }}-mix-${{ env.ELIXIR_VERSION }}-${{ env.OTP_VERSION }}-${{ hashFiles('**/mix.lock') }} 28 | restore-keys: | 29 | ${{ runner.os }}-mix- 30 | - name: Install mix dependencies 31 | run: mix deps.get 32 | - name: Check formatting 33 | run: mix format --check-formatted 34 | - name: Run tests 35 | run: mix test -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore cache artifacts 2 | /cache/ 3 | 4 | # The directory Mix will write compiled artifacts 
to. 5 | /_build/ 6 | 7 | # If you run "mix test --cover", coverage assets end up here. 8 | /cover/ 9 | 10 | # The directory Mix downloads your dependencies sources to. 11 | /deps/ 12 | 13 | # Where third-party dependencies like ExDoc output generated docs. 14 | /doc/ 15 | 16 | # Ignore .fetch files in case you like to edit your project deps locally. 17 | /.fetch 18 | 19 | # If the VM crashes, it generates a dump, let's ignore it too. 20 | erl_crash.dump 21 | 22 | # Also ignore archive artifacts (built via "mix archive.build"). 23 | *.ez 24 | 25 | # Ignore package tarball (built via "mix hex.build"). 26 | ex_faiss-*.tar 27 | 28 | # Temporary files, for example, from tests. 29 | /tmp/ 30 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | # Environment variables passed via elixir_make 2 | # ERTS_INCLUDE_DIR 3 | # MIX_APP_PATH 4 | 5 | TEMP ?= $(HOME)/.cache 6 | FAISS_CACHE ?= $(TEMP)/ex_faiss 7 | FAISS_GIT_REPO ?= https://www.github.com/facebookresearch/faiss 8 | FAISS_GIT_REV ?= 19f7696deedc93615c3ee0ff4de22284b53e0243 9 | FAISS_NS = faiss-$(FAISS_GIT_REV) 10 | FAISS_DIR = $(FAISS_CACHE)/$(FAISS_NS) 11 | FAISS_LIB_DIR = $(FAISS_DIR)/build/faiss 12 | FAISS_LIB_DIR_FLAG = $(FAISS_DIR)/build/faiss/ex_faiss.ok 13 | 14 | # Private configuration 15 | PRIV_DIR = $(MIX_APP_PATH)/priv 16 | EX_FAISS_DIR = c_src/ex_faiss 17 | EX_FAISS_CACHE_SO = cache/libex_faiss.so 18 | EX_FAISS_CACHE_LIB_DIR = cache/lib 19 | EX_FAISS_SO = $(PRIV_DIR)/libex_faiss.so 20 | EX_FAISS_LIB_DIR = $(PRIV_DIR)/lib 21 | 22 | # Build flags 23 | CFLAGS = -I$(ERTS_INCLUDE_DIR) -I$(FAISS_DIR) -fPIC -O3 -shared -std=c++14 24 | CMAKE_FLAGS = -DFAISS_ENABLE_PYTHON=OFF -DBUILD_TESTING=OFF -DBUILD_SHARED_LIBS=ON 25 | 26 | ifeq ($(USE_CUDA), true) 27 | CFLAGS += -D__CUDA__ 28 | CMAKE_FLAGS += -DFAISS_ENABLE_GPU=ON 29 | else 30 | CMAKE_FLAGS += -DFAISS_ENABLE_GPU=OFF 31 | endif 32 | 33 | 
C_SRCS = c_src/ex_faiss.cc $(EX_FAISS_DIR)/nif_util.cc $(EX_FAISS_DIR)/nif_util.h \ 34 | $(EX_FAISS_DIR)/index.cc $(EX_FAISS_DIR)/index.h $(EX_FAISS_DIR)/clustering.cc \ 35 | $(EX_FAISS_DIR)/clustering.h 36 | 37 | LDFLAGS = -L$(EX_FAISS_CACHE_LIB_DIR) -lfaiss 38 | 39 | ifeq ($(shell uname -s), Darwin) 40 | LDFLAGS += -flat_namespace -undefined suppress 41 | POST_INSTALL = install_name_tool $(EX_FAISS_CACHE_SO) -change @rpath/libfaiss.dylib @loader_path/lib/libfaiss.dylib 42 | 43 | ifeq ($(USE_LLVM_BREW), true) 44 | LLVM_PREFIX=$(shell brew --prefix llvm) 45 | CMAKE_FLAGS += -DCMAKE_CXX_COMPILER=$(LLVM_PREFIX)/bin/clang++ 46 | endif 47 | else 48 | # Use a relative RPATH, so at runtime libex_faiss.so looks for libfaiss.so 49 | # in ./lib regardless of the absolute location. This way priv can be safely 50 | # packed into an Elixir release. Also, we use $$ to escape Makefile variable 51 | # and single quotes to escape shell variable 52 | LDFLAGS += -Wl,-rpath,'$$ORIGIN/lib' 53 | POST_INSTALL = $(NOOP) 54 | endif 55 | 56 | $(EX_FAISS_SO): $(EX_FAISS_CACHE_SO) 57 | @ mkdir -p $(PRIV_DIR) 58 | @ if [ "${MIX_BUILD_EMBEDDED}" = "true" ]; then \ 59 | cp -a $(abspath $(EX_FAISS_CACHE_LIB_DIR)) $(EX_FAISS_LIB_DIR) ; \ 60 | cp -a $(abspath $(EX_FAISS_CACHE_SO)) $(EX_FAISS_SO) ; \ 61 | else \ 62 | ln -sf $(abspath $(EX_FAISS_CACHE_LIB_DIR)) $(EX_FAISS_LIB_DIR) ; \ 63 | ln -sf $(abspath $(EX_FAISS_CACHE_SO)) $(EX_FAISS_SO) ; \ 64 | fi 65 | 66 | $(EX_FAISS_CACHE_SO): $(FAISS_LIB_DIR_FLAG) $(C_SRCS) 67 | @mkdir -p cache 68 | cp -a $(FAISS_LIB_DIR) $(EX_FAISS_CACHE_LIB_DIR) 69 | $(CXX) $(CFLAGS) c_src/ex_faiss.cc $(EX_FAISS_DIR)/nif_util.cc $(EX_FAISS_DIR)/index.cc \ 70 | $(EX_FAISS_DIR)/clustering.cc -o $(EX_FAISS_CACHE_SO) $(LDFLAGS) 71 | $(POST_INSTALL) 72 | 73 | $(FAISS_LIB_DIR_FLAG): 74 | rm -rf $(FAISS_DIR) && \ 75 | mkdir -p $(FAISS_DIR) && \ 76 | cd $(FAISS_DIR) && \ 77 | git init && \ 78 | git remote add origin $(FAISS_GIT_REPO) && \ 79 | git fetch --depth 1 origin 
$(FAISS_GIT_REV) && \ 80 | git checkout FETCH_HEAD && \ 81 | cmake -B build . $(CMAKE_FLAGS) && \ 82 | make -C build -j faiss 83 | touch $(FAISS_LIB_DIR_FLAG) 84 | 85 | clean: 86 | rm -rf $(EX_FAISS_CACHE_SO) 87 | rm -rf $(EX_FAISS_CACHE_LIB_DIR) 88 | rm -rf $(EX_FAISS_SO) 89 | rm -rf $(EX_FAISS_LIB_DIR) 90 | rm -rf $(FAISS_DIR) 91 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ExFaiss 2 | 3 | > Note: this library is experimental and not a priority. Consider using [HNSWLib](https://github.com/elixir-nx/hnswlib) instead. 4 | 5 | Elixir front-end for [Facebook AI Similarity Search (Faiss)](https://github.com/facebookresearch/faiss). 6 | 7 | ExFaiss is a low-level wrapper around Faiss which allows you to create and manage Faiss indices and clusterings. Faiss enables efficient search and clustering of dense vectors and has the potential to scale to millions, billions, and even trillions of vectors. ExFaiss works directly with [Nx](https://github.com/elixir-nx/nx) tensors, so you can seamlessly integrate ExFaiss into your existing Elixir ML workflows. 8 | 9 | ## Installation 10 | 11 | Add `ex_faiss` to your dependencies: 12 | 13 | ```elixir 14 | def deps do 15 | [ 16 | {:ex_faiss, github: "elixir-nx/ex_faiss"} 17 | ] 18 | end 19 | ``` 20 | 21 | ExFaiss will download, build, and cache Faiss on the first compilation. You must have CMake installed in order to build Faiss. 22 | 23 | ### macOS Compilation 24 | 25 | If you have trouble building on macOS, you can try installing LLVM from Homebrew. 26 | 27 | ```shell 28 | $ brew install llvm cmake 29 | ``` 30 | 31 | Then tell ExFaiss to use it by setting the environment variable `USE_LLVM_BREW=true`. 32 | 33 | ### GPU Installation 34 | 35 | If you have an NVIDIA GPU with CUDA installed, you can enable the GPU build by setting the environment variable `USE_CUDA=true`. 
Note that if you have already built Faiss without GPU support, you will need to delete the cached build before continuing. You can clean the existing installation by running `make clean`. 36 | 37 | ## Working with Indices 38 | 39 | You can create indices which follow the syntax of Faiss' [Index Factory](https://github.com/facebookresearch/faiss/wiki/The-index-factory). Indices require you to also specify a dimensionality of the vectors you plan to store: 40 | 41 | ```elixir 42 | index = ExFaiss.Index.new(128, "Flat") 43 | ``` 44 | 45 | You can optionally place an index on a GPU by specifying the `:device` option: 46 | 47 | ```elixir 48 | index = ExFaiss.Index.new(128, "Flat", device: :cuda) 49 | ``` 50 | 51 | Finally, you can add one or more tensors to the index at a time: 52 | 53 | ```elixir 54 | index = ExFaiss.Index.add(index, Nx.random_uniform({32, 128})) 55 | ``` 56 | 57 | And then search the index for similar vectors: 58 | 59 | ```elixir 60 | result = ExFaiss.Index.search(index, Nx.random_uniform({128}), 5) 61 | ``` 62 | 63 | Returns: 64 | 65 | ``` 66 | %{ 67 | distances: #Nx.Tensor< 68 | f32[1][5] 69 | [ 70 | [18.473186492919922, 18.697336196899414, 19.020721435546875, 19.091503143310547, 19.53148078918457] 71 | ] 72 | >, 73 | labels: #Nx.Tensor< 74 | s64[1][5] 75 | [ 76 | [25, 0, 2, 9, 13] 77 | ] 78 | > 79 | } 80 | ``` 81 | 82 | ## License 83 | 84 | ``` 85 | Copyright (c) 2022 The Machine Learning Working Group of the Erlang Ecosystem Foundation 86 | 87 | Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 88 | 89 | Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. 90 | ``` 91 | -------------------------------------------------------------------------------- /c_src/ex_faiss.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "ex_faiss/nif_util.h" 5 | #include "ex_faiss/index.h" 6 | #include "ex_faiss/clustering.h" 7 | 8 | #if defined(__CUDA__) 9 | #include 10 | #endif 11 | 12 | void free_ex_faiss_index(ErlNifEnv * env, void * obj) { 13 | ex_faiss::ExFaissIndex ** index = (ex_faiss::ExFaissIndex **) obj; 14 | if (*index != nullptr) { 15 | delete *index; 16 | *index = nullptr; 17 | } 18 | } 19 | 20 | void free_ex_faiss_clustering(ErlNifEnv * env, void * obj) { 21 | ex_faiss::ExFaissClustering ** clustering = (ex_faiss::ExFaissClustering **) obj; 22 | if (*clustering != nullptr) { 23 | delete *clustering; 24 | *clustering = nullptr; 25 | } 26 | } 27 | 28 | static int open_resources(ErlNifEnv* env) { 29 | const char * mod = "ExFaiss"; 30 | 31 | if (!nif::open_resource(env, mod, "Index", free_ex_faiss_index)) { 32 | return -1; 33 | } 34 | if (!nif::open_resource(env, mod, "Clustering", free_ex_faiss_clustering)) { 35 | return -1; 36 | } 37 | 38 | return 1; 39 | } 40 | 41 | static int load(ErlNifEnv* env, void ** priv, ERL_NIF_TERM load_info) { 42 | if (open_resources(env) == -1) return -1; 43 | 44 | return 0; 45 | } 46 | 47 | ERL_NIF_TERM new_index(ErlNifEnv * env, int argc, const ERL_NIF_TERM argv[]) { 48 | if (argc != 3) { 49 | nif::error(env, "Bad argument count."); 50 | } 51 | 52 | int64_t dim; 53 | std::string description; 54 | faiss::MetricType metric_type; 55 | 56 | if (!nif::get(env, argv[0], &dim)) { 57 | return nif::error(env, "Unable to get dimensionality."); 58 | } 59 | if (!nif::get(env, argv[1], description)) { 60 | return nif::error(env, "Unable to get string."); 61 | } 62 | if (!nif::get_metric_type(env, argv[2], &metric_type)) { 63 | return 
nif::error(env, "Unable to get metric type."); 64 | } 65 | 66 | ex_faiss::ExFaissIndex * index = new ex_faiss::ExFaissIndex(dim, description.c_str(), metric_type); 67 | return nif::ok(env, nif::make(env, index)); 68 | } 69 | 70 | ERL_NIF_TERM clone_index(ErlNifEnv * env, int argc, const ERL_NIF_TERM argv[]) { 71 | if (argc != 1) { 72 | return nif::error(env, "Bad argument count."); 73 | } 74 | 75 | ex_faiss::ExFaissIndex ** index; 76 | 77 | if (!nif::get(env, argv[0], index)) { 78 | return nif::error(env, "Unable to get index."); 79 | } 80 | 81 | ex_faiss::ExFaissIndex * cloned = (*index)->Clone(); 82 | return nif::ok(env, nif::make(env, cloned)); 83 | } 84 | 85 | ERL_NIF_TERM write_index(ErlNifEnv * env, int argc, const ERL_NIF_TERM argv[]) { 86 | if (argc != 2) { 87 | return nif::error(env, "Bad argument count."); 88 | } 89 | 90 | ex_faiss::ExFaissIndex ** index; 91 | std::string fname; 92 | 93 | if (!nif::get(env, argv[0], index)) { 94 | return nif::error(env, "Unable to get index."); 95 | } 96 | if (!nif::get(env, argv[1], fname)) { 97 | return nif::error(env, "Unable to get fname."); 98 | } 99 | 100 | (*index)->WriteToFile(fname.c_str()); 101 | 102 | return nif::ok(env); 103 | } 104 | 105 | ERL_NIF_TERM read_index(ErlNifEnv * env, int argc, const ERL_NIF_TERM argv[]) { 106 | if (argc != 2) { 107 | return nif::error(env, "Bad argument count."); 108 | } 109 | 110 | std::string fname; 111 | int32_t io_flags; 112 | 113 | if (!nif::get(env, argv[0], fname)) { 114 | return nif::error(env, "Unable to get fname."); 115 | } 116 | if (!nif::get(env, argv[1], &io_flags)) { 117 | return nif::error(env, "Unable to get IO flags."); 118 | } 119 | 120 | ex_faiss::ExFaissIndex * index = ex_faiss::ReadIndexFromFile(fname.c_str(), io_flags); 121 | 122 | return nif::ok(env, nif::make(env, index)); 123 | } 124 | 125 | ERL_NIF_TERM add_to_index(ErlNifEnv * env, int argc, const ERL_NIF_TERM argv[]) { 126 | if (argc != 3) { 127 | return nif::error(env, "Bad argument count."); 128 | } 
129 | 130 | ex_faiss::ExFaissIndex ** index; 131 | int64_t n; 132 | ErlNifBinary data; 133 | 134 | if (!nif::get(env, argv[0], index)) { 135 | return nif::error(env, "Unable to get index."); 136 | } 137 | if (!nif::get(env, argv[1], &n)) { 138 | return nif::error(env, "Unable to get n."); 139 | } 140 | if (!nif::get_binary(env, argv[2], &data)) { 141 | return nif::error(env, "Unable to get data."); 142 | } 143 | 144 | (*index)->Add(n, reinterpret_cast(data.data)); 145 | 146 | return nif::ok(env); 147 | } 148 | 149 | ERL_NIF_TERM add_with_ids_to_index(ErlNifEnv * env, int argc, const ERL_NIF_TERM argv[]) { 150 | if (argc != 4) { 151 | return nif::error(env, "Bad argument count."); 152 | } 153 | 154 | ex_faiss::ExFaissIndex ** index; 155 | int64_t n; 156 | ErlNifBinary data; 157 | ErlNifBinary ids; 158 | 159 | if (!nif::get(env, argv[0], index)) { 160 | return nif::error(env, "Unable to get index."); 161 | } 162 | if (!nif::get(env, argv[1], &n)) { 163 | return nif::error(env, "Unable to get n."); 164 | } 165 | if (!nif::get_binary(env, argv[2], &data)) { 166 | return nif::error(env, "Unable to get data."); 167 | } 168 | if (!nif::get_binary(env, argv[3], &ids)) { 169 | return nif::error(env, "Unable to get ids."); 170 | } 171 | 172 | (*index)->AddWithIds(n, reinterpret_cast(data.data), reinterpret_cast(ids.data)); 173 | 174 | return nif::ok(env); 175 | } 176 | 177 | ERL_NIF_TERM search_index(ErlNifEnv * env, int argc, const ERL_NIF_TERM argv[]) { 178 | if (argc != 4) { 179 | return nif::error(env, "Bad argument count."); 180 | } 181 | 182 | ex_faiss::ExFaissIndex ** index; 183 | int64_t n; 184 | ErlNifBinary data; 185 | int64_t k; 186 | 187 | if (!nif::get(env, argv[0], index)) { 188 | return nif::error(env, "Unable to get index."); 189 | } 190 | if (!nif::get(env, argv[1], &n)) { 191 | return nif::error(env, "Unable to get n."); 192 | } 193 | if (!nif::get_binary(env, argv[2], &data)) { 194 | return nif::error(env, "Unable to get data."); 195 | } 196 | if 
(!nif::get(env, argv[3], &k)) { 197 | return nif::error(env, "Unable to get k."); 198 | } 199 | 200 | ErlNifBinary distances, labels; 201 | enif_alloc_binary(n * k * sizeof(float), &distances); 202 | enif_alloc_binary(n * k * sizeof(int64_t), &labels); 203 | 204 | (*index)->Search(n, 205 | reinterpret_cast(data.data), 206 | k, 207 | reinterpret_cast(distances.data), 208 | reinterpret_cast(labels.data)); 209 | 210 | ERL_NIF_TERM distances_term = nif::make(env, distances); 211 | ERL_NIF_TERM labels_term = nif::make(env, labels); 212 | 213 | return nif::ok(env, enif_make_tuple2(env, distances_term, labels_term)); 214 | } 215 | 216 | ERL_NIF_TERM train_index(ErlNifEnv * env, int argc, const ERL_NIF_TERM argv[]) { 217 | if (argc != 3) { 218 | return nif::error(env, "Bad argument count."); 219 | } 220 | 221 | ex_faiss::ExFaissIndex ** index; 222 | int64_t n; 223 | ErlNifBinary data; 224 | 225 | if (!nif::get(env, argv[0], index)) { 226 | return nif::error(env, "Unable to get index."); 227 | } 228 | if (!nif::get(env, argv[1], &n)) { 229 | return nif::error(env, "Unable to get n."); 230 | } 231 | if (!nif::get_binary(env, argv[2], &data)) { 232 | return nif::error(env, "Unable to get data."); 233 | } 234 | 235 | (*index)->Train(n, reinterpret_cast(data.data)); 236 | 237 | return nif::ok(env); 238 | } 239 | 240 | ERL_NIF_TERM reconstruct_batch_from_index(ErlNifEnv * env, int argc, const ERL_NIF_TERM argv[]) { 241 | if (argc != 3) { 242 | return nif::error(env, "Bad argument count."); 243 | } 244 | 245 | ex_faiss::ExFaissIndex ** index; 246 | int64_t n; 247 | ErlNifBinary keys; 248 | 249 | if (!nif::get(env, argv[0], index)) { 250 | return nif::error(env, "Unable to get index."); 251 | } 252 | if (!nif::get(env, argv[1], &n)) { 253 | return nif::error(env, "Unable to get n."); 254 | } 255 | if (!nif::get_binary(env, argv[2], &keys)) { 256 | return nif::error(env, "Unable to get keys."); 257 | } 258 | 259 | int64_t d = (*index)->dim(); 260 | 261 | ErlNifBinary 
reconstruction; 262 | enif_alloc_binary(n * d * sizeof(float), &reconstruction); 263 | 264 | (*index)->ReconstructBatch(n, reinterpret_cast(keys.data), reinterpret_cast(reconstruction.data)); 265 | 266 | return nif::ok(env, nif::make(env, reconstruction)); 267 | } 268 | 269 | ERL_NIF_TERM compute_residuals_from_index(ErlNifEnv * env, int argc, const ERL_NIF_TERM argv[]) { 270 | if (argc != 4) { 271 | return nif::error(env, "Bad argument count."); 272 | } 273 | 274 | ex_faiss::ExFaissIndex ** index; 275 | int64_t n; 276 | ErlNifBinary data; 277 | ErlNifBinary keys; 278 | 279 | if (!nif::get(env, argv[0], index)) { 280 | return nif::error(env, "Unable to get index."); 281 | } 282 | if (!nif::get(env, argv[1], &n)) { 283 | return nif::error(env, "Unable to get n."); 284 | } 285 | if (!nif::get_binary(env, argv[2], &data)) { 286 | return nif::error(env, "Unable to get data."); 287 | } 288 | if (!nif::get_binary(env, argv[3], &keys)) { 289 | return nif::error(env, "Unable to get keys."); 290 | } 291 | 292 | int64_t d = (*index)->dim(); 293 | 294 | ErlNifBinary residuals; 295 | enif_alloc_binary(n * d * sizeof(float), &residuals); 296 | 297 | (*index)->ComputeResiduals(n, 298 | reinterpret_cast(data.data), 299 | reinterpret_cast(residuals.data), 300 | reinterpret_cast(keys.data)); 301 | 302 | return nif::ok(env, nif::make(env, residuals)); 303 | } 304 | 305 | ERL_NIF_TERM reset_index(ErlNifEnv * env, int argc, const ERL_NIF_TERM argv[]) { 306 | if (argc != 1) { 307 | return nif::error(env, "Bad argument count."); 308 | } 309 | 310 | ex_faiss::ExFaissIndex ** index; 311 | 312 | if (!nif::get(env, argv[0], index)) { 313 | return nif::error(env, "Unable to get index."); 314 | } 315 | 316 | (*index)->Reset(); 317 | 318 | return nif::ok(env); 319 | } 320 | 321 | ERL_NIF_TERM get_index_dim(ErlNifEnv * env, int argc, const ERL_NIF_TERM argv[]) { 322 | if (argc != 1) { 323 | return nif::error(env, "Bad argument count."); 324 | } 325 | 326 | ex_faiss::ExFaissIndex ** index; 327 | 
328 | if (!nif::get(env, argv[0], index)) { 329 | return nif::error(env, "Unable to get index."); 330 | } 331 | 332 | int dim = (*index)->dim(); 333 | 334 | return nif::ok(env, nif::make(env, dim)); 335 | } 336 | 337 | ERL_NIF_TERM get_index_n_vectors(ErlNifEnv * env, int argc, const ERL_NIF_TERM argv[]) { 338 | if (argc != 1) { 339 | return nif::error(env, "Bad argument count."); 340 | } 341 | 342 | ex_faiss::ExFaissIndex ** index; 343 | 344 | if (!nif::get(env, argv[0], index)) { 345 | return nif::error(env, "Unable to get index."); 346 | } 347 | 348 | int64_t n_total = (*index)->n_total(); 349 | 350 | return nif::ok(env, nif::make(env, n_total)); 351 | } 352 | 353 | ERL_NIF_TERM index_cpu_to_gpu(ErlNifEnv * env, int argc, const ERL_NIF_TERM argv[]) { 354 | if (argc != 2) { 355 | return nif::error(env, "Bad argument count."); 356 | } 357 | 358 | ex_faiss::ExFaissIndex ** index; 359 | int device; 360 | 361 | if (!nif::get(env, argv[0], index)) { 362 | return nif::error(env, "Unable to get index."); 363 | } 364 | if (!nif::get(env, argv[1], &device)) { 365 | return nif::error(env, "Unable to get device."); 366 | } 367 | 368 | ex_faiss::ExFaissIndex * gpu_index = (*index)->CloneToGpu(device); 369 | 370 | return nif::ok(env, nif::make(env, gpu_index)); 371 | } 372 | 373 | ERL_NIF_TERM get_num_gpus(ErlNifEnv * env, int argc, const ERL_NIF_TERM argv[]) { 374 | if (argc != 0) { 375 | return nif::error(env, "Bad argument count."); 376 | } 377 | 378 | int gpus; 379 | 380 | #if defined(__CUDA__) 381 | gpus = faiss::gpu::getNumDevices(); 382 | #else 383 | gpus = 0; 384 | #endif 385 | 386 | return nif::ok(env, nif::make(env, gpus)); 387 | } 388 | 389 | ERL_NIF_TERM new_clustering(ErlNifEnv * env, int argc, const ERL_NIF_TERM argv[]) { 390 | if (argc != 2) { 391 | return nif::error(env, "Bad argument count."); 392 | } 393 | 394 | int d; 395 | int k; 396 | 397 | if (!nif::get(env, argv[0], &d)) { 398 | return nif::error(env, "Unable to get d."); 399 | } 400 | if 
(!nif::get(env, argv[1], &k)) { 401 | return nif::error(env, "Unable to get k."); 402 | } 403 | 404 | ex_faiss::ExFaissClustering * clustering = new ex_faiss::ExFaissClustering(d, k); 405 | 406 | return nif::ok(env, nif::make(env, clustering)); 407 | } 408 | 409 | ERL_NIF_TERM train_clustering(ErlNifEnv * env, int argc, const ERL_NIF_TERM argv[]) { 410 | if (argc != 4) { 411 | return nif::error(env, "Bad argument count."); 412 | } 413 | 414 | ex_faiss::ExFaissClustering ** clustering; 415 | int64_t n; 416 | ErlNifBinary data; 417 | ex_faiss::ExFaissIndex ** index; 418 | 419 | if (!nif::get(env, argv[0], clustering)) { 420 | return nif::error(env, "Unable to get clustering."); 421 | } 422 | if (!nif::get(env, argv[1], &n)) { 423 | return nif::error(env, "Unable to get n."); 424 | } 425 | if (!nif::get_binary(env, argv[2], &data)) { 426 | return nif::error(env, "Unable to get data."); 427 | } 428 | if (!nif::get(env, argv[3], index)) { 429 | return nif::error(env, "Unable to get index."); 430 | } 431 | 432 | (*clustering)->Train(n, reinterpret_cast(data.data), *index); 433 | 434 | return nif::ok(env); 435 | } 436 | 437 | ERL_NIF_TERM get_clustering_centroids(ErlNifEnv * env, int argc, const ERL_NIF_TERM argv[]) { 438 | if (argc != 1) { 439 | return nif::error(env, "Bad argument count."); 440 | } 441 | 442 | ex_faiss::ExFaissClustering ** clustering; 443 | 444 | if (!nif::get(env, argv[0], clustering)) { 445 | return nif::error(env, "Unable to get clustering."); 446 | } 447 | 448 | size_t d, k; 449 | d = (*clustering)->dimensionality(); 450 | k = (*clustering)->n_centroids(); 451 | 452 | ErlNifBinary data; 453 | enif_alloc_binary(d * k * sizeof(float), &data); 454 | 455 | std::vector centroids = (*clustering)->centroids(); 456 | std::memcpy(data.data, centroids.data(), data.size); 457 | 458 | return nif::ok(env, nif::make(env, data)); 459 | } 460 | 461 | static ErlNifFunc ex_faiss_funcs[] = { 462 | // Index CPU 463 | {"new_index", 3, new_index}, 464 | {"clone_index", 
1, clone_index}, 465 | {"write_index", 2, write_index, ERL_NIF_DIRTY_JOB_IO_BOUND}, 466 | {"read_index", 2, read_index, ERL_NIF_DIRTY_JOB_IO_BOUND}, 467 | {"add_to_index", 3, add_to_index, ERL_NIF_DIRTY_JOB_CPU_BOUND}, 468 | {"add_with_ids_to_index", 4, add_with_ids_to_index, ERL_NIF_DIRTY_JOB_CPU_BOUND}, 469 | {"search_index", 4, search_index, ERL_NIF_DIRTY_JOB_CPU_BOUND}, 470 | {"train_index", 3, train_index, ERL_NIF_DIRTY_JOB_CPU_BOUND}, 471 | {"reset_index", 1, reset_index}, 472 | {"reconstruct_batch_from_index", 3, reconstruct_batch_from_index}, 473 | {"compute_residuals_from_index", 4, compute_residuals_from_index}, 474 | {"get_index_dim", 1, get_index_dim}, 475 | {"get_index_n_vectors", 1, get_index_n_vectors}, 476 | // Index GPU 477 | {"index_cpu_to_gpu", 2, index_cpu_to_gpu, ERL_NIF_DIRTY_JOB_IO_BOUND}, 478 | {"get_num_gpus", 0, get_num_gpus}, 479 | // Clustering CPU 480 | {"new_clustering", 2, new_clustering, ERL_NIF_DIRTY_JOB_CPU_BOUND}, 481 | {"train_clustering", 4, train_clustering, ERL_NIF_DIRTY_JOB_CPU_BOUND}, 482 | {"get_clustering_centroids", 1, get_clustering_centroids} 483 | }; 484 | 485 | ERL_NIF_INIT(Elixir.ExFaiss.NIF, ex_faiss_funcs, &load, NULL, NULL, NULL); -------------------------------------------------------------------------------- /c_src/ex_faiss/clustering.cc: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | 5 | #include "index.h" 6 | #include "clustering.h" 7 | 8 | namespace ex_faiss { 9 | 10 | ExFaissClustering::ExFaissClustering(int d, int k) { 11 | clustering_ = std::make_unique(d, k); 12 | } 13 | 14 | void ExFaissClustering::Train(int64_t n, const float * x, ExFaissIndex * index) { 15 | clustering_->train(n, x, *(index->index())); 16 | } 17 | 18 | } // namespace ex_faiss -------------------------------------------------------------------------------- /c_src/ex_faiss/clustering.h: -------------------------------------------------------------------------------- 1 | 
#ifndef EX_FAISS_CLUSTERING_H_ 2 | #define EX_FAISS_CLUSTERING_H_ 3 | 4 | #include "index.h" 5 | #include 6 | 7 | namespace ex_faiss { 8 | 9 | class ExFaissClustering { 10 | public: 11 | ExFaissClustering(int d, int k); 12 | 13 | // TODO: Handle weights 14 | void Train(int64_t n, const float * x, ExFaissIndex * index); 15 | 16 | std::vector centroids() { return clustering_->centroids; } 17 | size_t dimensionality() { return clustering_->d; } 18 | size_t n_centroids() { return clustering_->k; } 19 | std::vector iteration_stats() { return clustering_->iteration_stats; } 20 | 21 | private: 22 | std::unique_ptr clustering_; 23 | }; 24 | 25 | } // namespace ex_faiss 26 | #endif -------------------------------------------------------------------------------- /c_src/ex_faiss/index.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #if defined(__CUDA__) 6 | #include 7 | #include 8 | #endif 9 | 10 | #include "index.h" 11 | 12 | namespace ex_faiss { 13 | 14 | ExFaissIndex::ExFaissIndex(faiss::Index * index) { 15 | index_ = std::unique_ptr(index); 16 | } 17 | 18 | ExFaissIndex::ExFaissIndex(int dim, 19 | const char * description, 20 | faiss::MetricType metric_type) { 21 | faiss::Index * index = faiss::index_factory(dim, description, metric_type); 22 | index_ = std::unique_ptr(index); 23 | } 24 | 25 | ExFaissIndex * ExFaissIndex::Clone() { 26 | faiss::Index * index = faiss::clone_index(index_.get()); 27 | return new ExFaissIndex(index); 28 | } 29 | 30 | ExFaissIndex * ExFaissIndex::CloneToGpu(int device) { 31 | #if defined(__CUDA__) 32 | faiss::gpu::StandardGpuResources res; 33 | faiss::Index * index = faiss::gpu::index_cpu_to_gpu(&res, device, index_.get()); 34 | return new ExFaissIndex(index); 35 | #else 36 | return nullptr; 37 | #endif 38 | } 39 | 40 | void ExFaissIndex::Add(int64_t n, const float * x) { 41 | index_->add(n, x); 42 | } 43 | 44 | void ExFaissIndex::AddWithIds(int64_t n, const 
float * x, const int64_t * xids) { 45 | index_->add_with_ids(n, x, xids); 46 | } 47 | 48 | void ExFaissIndex::Search(int64_t n, const float * x, int64_t k, float * distances, int64_t * labels) { 49 | index_->search(n, x, k, distances, labels); 50 | } 51 | 52 | void ExFaissIndex::Train(int64_t n, const float * x) { 53 | index_->train(n, x); 54 | } 55 | 56 | void ExFaissIndex::Reset() { 57 | index_->reset(); 58 | } 59 | 60 | void ExFaissIndex::ReconstructBatch(int64_t n, const int64_t * keys, float * recons) { 61 | index_->reconstruct_batch(n, keys, recons); 62 | } 63 | 64 | void ExFaissIndex::ComputeResiduals(int64_t n, const float * data, float * resid, const int64_t * keys) { 65 | index_->compute_residual_n(n, data, resid, keys); 66 | } 67 | 68 | void ExFaissIndex::WriteToFile(const char * fname) { 69 | faiss::write_index(index_.get(), fname); 70 | } 71 | 72 | ExFaissIndex * ReadIndexFromFile(const char * fname, int io_flags) { 73 | faiss::Index * index = faiss::read_index(fname, io_flags); 74 | return new ExFaissIndex(index); 75 | } 76 | } // namespace ex_faiss -------------------------------------------------------------------------------- /c_src/ex_faiss/index.h: -------------------------------------------------------------------------------- 1 | #ifndef EX_FAISS_INDEX_H_ 2 | #define EX_FAISS_INDEX_H_ 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | namespace ex_faiss { 9 | 10 | class ExFaissIndex { 11 | public: 12 | ExFaissIndex(faiss::Index * index); 13 | 14 | ExFaissIndex(int d, const char * description, faiss::MetricType metric_type); 15 | 16 | void Add(int64_t n, const float * x); 17 | 18 | void AddWithIds(int64_t n, const float * x, const int64_t * xids); 19 | 20 | void Search(int64_t n, 21 | const float * x, 22 | int64_t k, 23 | float * distances, 24 | int64_t * labels); 25 | 26 | void Train(int64_t n, const float * x); 27 | 28 | void WriteToFile(const char * fname); 29 | 30 | ExFaissIndex * Clone(); 31 | 32 | ExFaissIndex * CloneToGpu(int device); 33 
| 34 | void Reset(); 35 | 36 | void ReconstructBatch(int64_t n, const int64_t * keys, float * recons); 37 | 38 | void ComputeResiduals(int64_t n, const float * data, float * resid, const int64_t * keys); 39 | 40 | faiss::Index * index() { return index_.get(); } 41 | int dim() { return index_->d; } 42 | int64_t n_total() { return index_->ntotal; } 43 | 44 | private: 45 | std::unique_ptr index_; 46 | }; 47 | 48 | ExFaissIndex * ReadIndexFromFile(const char * fname, int io_flags); 49 | 50 | } // namespace ex_faiss 51 | #endif -------------------------------------------------------------------------------- /c_src/ex_faiss/nif_util.cc: -------------------------------------------------------------------------------- 1 | #include "nif_util.h" 2 | 3 | namespace nif { 4 | 5 | ERL_NIF_TERM ok(ErlNifEnv* env, ERL_NIF_TERM term) { 6 | return enif_make_tuple2(env, ok(env), term); 7 | } 8 | 9 | ERL_NIF_TERM ok(ErlNifEnv* env) { 10 | return enif_make_atom(env, "ok"); 11 | } 12 | 13 | ERL_NIF_TERM error(ErlNifEnv * env, const char * msg) { 14 | ERL_NIF_TERM atom = enif_make_atom(env, "error"); 15 | ERL_NIF_TERM msg_term = enif_make_string(env, msg, ERL_NIF_LATIN1); 16 | return enif_make_tuple2(env, atom, msg_term); 17 | } 18 | 19 | ERL_NIF_TERM make(ErlNifEnv* env, int var) { 20 | return enif_make_int(env, var); 21 | } 22 | 23 | ERL_NIF_TERM make(ErlNifEnv * env, int64_t var) { 24 | return enif_make_int64(env, var); 25 | } 26 | 27 | ERL_NIF_TERM make(ErlNifEnv* env, ErlNifBinary var) { 28 | return enif_make_binary(env, &var); 29 | } 30 | 31 | int get(ErlNifEnv* env, ERL_NIF_TERM term, int32_t * var) { 32 | return enif_get_int(env, term, 33 | reinterpret_cast(var)); 34 | } 35 | 36 | int get(ErlNifEnv* env, ERL_NIF_TERM term, int64_t * var) { 37 | return enif_get_int64(env, term, 38 | reinterpret_cast(var)); 39 | } 40 | 41 | int get(ErlNifEnv* env, ERL_NIF_TERM term, std::string &var) { 42 | unsigned len; 43 | int ret = enif_get_list_length(env, term, &len); 44 | 45 | if (!ret) { 46 
| ErlNifBinary bin; 47 | ret = enif_inspect_binary(env, term, &bin); 48 | if (!ret) { 49 | return 0; 50 | } 51 | var = std::string((const char*)bin.data, bin.size); 52 | return ret; 53 | } 54 | 55 | var.resize(len+1); 56 | ret = enif_get_string(env, term, &*(var.begin()), var.size(), ERL_NIF_LATIN1); 57 | 58 | if (ret > 0) { 59 | var.resize(ret-1); 60 | } else if (ret == 0) { 61 | var.resize(0); 62 | } else {} 63 | 64 | return ret; 65 | } 66 | 67 | int get_metric_type(ErlNifEnv * env, ERL_NIF_TERM term, faiss::MetricType * metric_type) { 68 | int value; 69 | if (!enif_get_int(env, term, &value)) return 0; 70 | *metric_type = faiss::MetricType(value); 71 | return 1; 72 | } 73 | 74 | int get_binary(ErlNifEnv * env, ERL_NIF_TERM term, ErlNifBinary * var) { 75 | return enif_inspect_binary(env, term, var); 76 | } 77 | 78 | int get_list(ErlNifEnv* env, ERL_NIF_TERM list, std::vector &var) { 79 | unsigned int length; 80 | if (!enif_get_list_length(env, list, &length)) return 0; 81 | var.reserve(length); 82 | ERL_NIF_TERM head, tail; 83 | 84 | while (enif_get_list_cell(env, list, &head, &tail)) { 85 | int64_t elem; 86 | if (!get(env, head, &elem)) return 0; 87 | var.push_back(elem); 88 | list = tail; 89 | } 90 | return 1; 91 | } 92 | 93 | } -------------------------------------------------------------------------------- /c_src/ex_faiss/nif_util.h: -------------------------------------------------------------------------------- 1 | #ifndef EX_FAISS_NIF_UTIL_H_ 2 | #define EX_FAISS_NIF_UTIL_H_ 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | #if !defined(__GNUC__) && (defined(__WIN32__) || defined(_WIN32) || defined(_WIN32_)) 10 | typedef unsigned __int64 nif_uint64_t; 11 | typedef signed __int64 nif_int64_t; 12 | #else 13 | typedef unsigned long nif_uint64_t; 14 | typedef signed long nif_int64_t; 15 | #endif 16 | 17 | namespace nif { 18 | 19 | ERL_NIF_TERM ok(ErlNifEnv * env, ERL_NIF_TERM term); 20 | ERL_NIF_TERM ok(ErlNifEnv * env); 21 | ERL_NIF_TERM 
error(ErlNifEnv * env, const char * msg); 22 | 23 | ERL_NIF_TERM make(ErlNifEnv * env, int var); 24 | ERL_NIF_TERM make(ErlNifEnv * env, int64_t var); 25 | ERL_NIF_TERM make(ErlNifEnv * env, ErlNifBinary var); 26 | 27 | int get(ErlNifEnv * env, ERL_NIF_TERM term, int32_t * var); 28 | int get(ErlNifEnv * env, ERL_NIF_TERM term, int64_t * var); 29 | int get(ErlNifEnv * env, ERL_NIF_TERM term, std::string& var); 30 | 31 | int get_metric_type(ErlNifEnv * env, ERL_NIF_TERM ter, faiss::MetricType * metric_type); 32 | 33 | int get_binary(ErlNifEnv * env, ERL_NIF_TERM term, ErlNifBinary * var); 34 | 35 | int get_list(ErlNifEnv* env, 36 | ERL_NIF_TERM list, 37 | std::vector &var); 38 | 39 | // Template struct for resources. The struct lets us use templates 40 | // to store and retrieve open resources later on. This implementation 41 | // is the same as the approach taken in the goertzenator/nifpp 42 | // C++11 wrapper around the Erlang NIF API. 43 | template 44 | struct resource_object { 45 | static ErlNifResourceType *type; 46 | }; 47 | template ErlNifResourceType* resource_object::type = 0; 48 | 49 | // Default destructor passed when opening a resource. The default 50 | // behavior is to invoke the underlying objects destructor and 51 | // set the resource pointer to NULL. 52 | template 53 | void default_dtor(ErlNifEnv* env, void * obj) { 54 | T* resource = reinterpret_cast(obj); 55 | resource->~T(); 56 | resource = nullptr; 57 | } 58 | 59 | // Opens a resource for the given template type T. If no 60 | // destructor is given, uses the default destructor defined 61 | // above. 
62 | template 63 | int open_resource(ErlNifEnv* env, 64 | const char* mod, 65 | const char* name, 66 | ErlNifResourceDtor* dtor = nullptr) { 67 | if (dtor == nullptr) { 68 | dtor = &default_dtor; 69 | } 70 | ErlNifResourceType *type; 71 | ErlNifResourceFlags flags = ErlNifResourceFlags(ERL_NIF_RT_CREATE|ERL_NIF_RT_TAKEOVER); 72 | type = enif_open_resource_type(env, mod, name, dtor, flags, NULL); 73 | if (type == NULL) { 74 | resource_object::type = 0; 75 | return -1; 76 | } else { 77 | resource_object::type = type; 78 | } 79 | return 1; 80 | } 81 | 82 | // Returns a resource of the given template type T. 83 | template 84 | ERL_NIF_TERM get(ErlNifEnv* env, ERL_NIF_TERM term, T* &var) { 85 | return enif_get_resource(env, term, 86 | resource_object::type, 87 | reinterpret_cast(&var)); 88 | } 89 | 90 | // Creates a reference to the given resource of type T. 91 | template 92 | ERL_NIF_TERM make(ErlNifEnv* env, T &var) { 93 | void* ptr = enif_alloc_resource(resource_object::type, sizeof(T)); 94 | new(ptr) T(std::move(var)); 95 | ERL_NIF_TERM ret = enif_make_resource(env, ptr); 96 | enif_release_resource(ptr); 97 | return ret; 98 | } 99 | } 100 | 101 | #endif -------------------------------------------------------------------------------- /lib/ex_faiss.ex: -------------------------------------------------------------------------------- 1 | defmodule ExFaiss do 2 | end 3 | -------------------------------------------------------------------------------- /lib/ex_faiss/clustering.ex: -------------------------------------------------------------------------------- 1 | defmodule ExFaiss.Clustering do 2 | @moduledoc """ 3 | Wraps references to Faiss clustering. 4 | """ 5 | alias __MODULE__ 6 | alias ExFaiss.Index 7 | import ExFaiss.Shared 8 | require Logger 9 | 10 | defstruct [:ref, :k, :index, :trained?] 11 | 12 | @doc """ 13 | Creates a new Faiss clustering object. 
14 | """ 15 | def new(d, k, _opts \\ []) do 16 | # TODO: Handle options 17 | # TODO: Create correct index 18 | cluster = ExFaiss.NIF.new_clustering(d, k) |> unwrap!() 19 | index = Index.new(d, "Flat") 20 | %Clustering{ref: cluster, index: index, k: k} 21 | end 22 | 23 | @doc """ 24 | Trains a Faiss clustering object. 25 | """ 26 | def train( 27 | %Clustering{ref: clustering, index: %Index{dim: dim, ref: index}} = cluster, 28 | %Nx.Tensor{} = tensor 29 | ) do 30 | validate_type!(tensor, {:f, 32}) 31 | 32 | case Nx.shape(tensor) do 33 | {^dim} -> 34 | # TODO: Warn? 35 | data = Nx.to_binary(tensor) 36 | ExFaiss.NIF.train_clustering(clustering, 1, data, index) 37 | 38 | {n, ^dim} -> 39 | data = Nx.to_binary(tensor) 40 | ExFaiss.NIF.train_clustering(clustering, n, data, index) 41 | 42 | shape -> 43 | raise ArgumentError, 44 | "invalid shape for index with dim #{inspect(dim)}," <> 45 | " tensor shape must be rank-1 or rank-2 with trailing" <> 46 | " dimension equal to dimension of the index, got shape" <> 47 | " #{inspect(shape)}" 48 | end 49 | 50 | %{cluster | trained?: true} 51 | end 52 | 53 | @doc """ 54 | Returns cluster assignment for given embedding. 55 | """ 56 | def get_cluster_assignment( 57 | %Clustering{trained?: true, index: %Index{} = index}, 58 | %Nx.Tensor{} = tensor 59 | ) do 60 | Index.search(index, tensor, 1) 61 | end 62 | 63 | def get_cluster_assignment(_, _) do 64 | raise ArgumentError, "cannot get cluster assignments for un-trained clustering" 65 | end 66 | 67 | @doc """ 68 | Returns clustering centroids of given clustering. 
69 | """ 70 | def get_centroids(%Clustering{trained?: true, ref: clustering, k: k, index: %Index{dim: d}}) do 71 | centroids_data = ExFaiss.NIF.get_clustering_centroids(clustering) |> unwrap!() 72 | 73 | centroids_data 74 | |> Nx.from_binary(:f32) 75 | |> Nx.reshape({k, d}) 76 | end 77 | end 78 | -------------------------------------------------------------------------------- /lib/ex_faiss/index.ex: -------------------------------------------------------------------------------- 1 | defmodule ExFaiss.Index do 2 | @moduledoc """ 3 | Wraps references to a Faiss index. 4 | """ 5 | alias __MODULE__ 6 | import ExFaiss.Shared 7 | 8 | defstruct [:dim, :ref, :device] 9 | 10 | # TODO: In all of these results, we copy the underlying data with to_binary 11 | # but FAISS does not take ownership of the data, so there must be a way we 12 | # can just provide a view of the data without copying. I think it may 13 | # be backend specific though 14 | 15 | # TODO: Anything that just returns :ok right now should return 16 | # the index 17 | 18 | # TODO: Handle selectors in search 19 | 20 | # TODO: Handle :errors from C++ exceptions 21 | 22 | # TODO: Change order of new arguments 23 | # TODO: Add description to struct 24 | 25 | @doc """ 26 | Creates a new Faiss index which stores vectors 27 | of the given dimensionality `dim`. 28 | 29 | ## Options 30 | 31 | * `:metric` - metric type. One of `:inner_product`, `:l2`, `:l1`, `:linf`, `:lp`, `:canberra`, `:braycurtis`, or `:jensenshannon`. Defaults to `:l2` 32 | 33 | * `:device` - device type. 
One of `:host`, `:cuda`, or 34 | `{:cuda, device}` where device is an integer device 35 | ordinal 36 | """ 37 | def new(dim, description, opts \\ []) when is_integer(dim) and dim > 0 do 38 | # TODO: Handle Index factory description as options 39 | # TODO: Maybe have sigil to construct factory descriptions 40 | opts = Keyword.validate!(opts, metric: :l2, device: :host) 41 | metric_type = metric_type_to_int(opts[:metric]) 42 | 43 | ref = ExFaiss.NIF.new_index(dim, description, metric_type) |> unwrap!() 44 | 45 | case opts[:device] do 46 | :cuda -> 47 | new_gpu_index(ref, dim, 0) 48 | 49 | {:cuda, device} -> 50 | new_gpu_index(ref, dim, device) 51 | 52 | :host -> 53 | %Index{dim: dim, ref: ref, device: :host} 54 | 55 | device -> 56 | raise ArgumentError, "invalid device #{inspect(device)}" 57 | end 58 | end 59 | 60 | # TODO: Handle replicated index 61 | defp new_gpu_index(ref, dim, device) do 62 | devices = ExFaiss.NIF.get_num_gpus() |> unwrap!() 63 | 64 | cond do 65 | devices <= 0 -> 66 | raise ArgumentError, 67 | "no gpu devices found, please ensure you've set" <> 68 | " the environment variable USE_CUDA=true, and" <> 69 | " that you have CUDA enabled devices" 70 | # valid ordinals are 0..devices-1, so device == devices is also out of range 71 | device >= devices or device < 0 -> 72 | raise ArgumentError, 73 | "device #{inspect(device)} is out of bounds for" <> 74 | " number of devices #{inspect(devices)}" 75 | 76 | true -> 77 | ref = ExFaiss.NIF.index_cpu_to_gpu(ref, device) |> unwrap!() 78 | %Index{ref: ref, dim: dim, device: {:cuda, device}} 79 | end 80 | end 81 | 82 | @doc """ 83 | Adds the given tensors to the given index. 
84 | """ 85 | def add(%Index{dim: dim, ref: ref} = index, %Nx.Tensor{} = tensor) do 86 | validate_type!(tensor, {:f, 32}) 87 | 88 | case Nx.shape(tensor) do 89 | {^dim} -> 90 | data = Nx.to_binary(tensor) 91 | ExFaiss.NIF.add_to_index(ref, 1, data) 92 | 93 | {n, ^dim} -> 94 | data = Nx.to_binary(tensor) 95 | ExFaiss.NIF.add_to_index(ref, n, data) 96 | 97 | shape -> 98 | invalid_shape_error!(dim, shape) 99 | end 100 | 101 | index 102 | end 103 | 104 | @doc """ 105 | Adds the given tensors and IDs to the given index. 106 | """ 107 | def add_with_ids(%Index{dim: dim, ref: ref} = index, %Nx.Tensor{} = tensor, %Nx.Tensor{} = ids) do 108 | validate_type!(tensor, {:f, 32}) 109 | validate_type!(ids, {:s, 64}) 110 | 111 | case {Nx.shape(tensor), Nx.shape(ids)} do 112 | {{^dim}, {1}} -> 113 | data = Nx.to_binary(tensor) 114 | xids = Nx.to_binary(ids) 115 | ExFaiss.NIF.add_with_ids_to_index(ref, 1, data, xids) 116 | 117 | {{n, ^dim}, {n}} -> 118 | data = Nx.to_binary(tensor) 119 | xids = Nx.to_binary(ids) 120 | ExFaiss.NIF.add_with_ids_to_index(ref, n, data, xids) 121 | 122 | {tensor_shape, ids_shape} -> 123 | raise ArgumentError, 124 | "invalid shape for index with dim #{inspect(dim)}," <> 125 | " tensor shape must be rank-1 or rank-2 with trailing" <> 126 | " dimension equal to dimension of the index, while ids" <> 127 | " shape must be rank-1 with dimension equal to leading" <> 128 | " dimension of data, or 1 if data is rank-1, got shapes" <> 129 | " ids: #{inspect(ids_shape)}, embeddings: #{inspect(tensor_shape)}" 130 | end 131 | 132 | index 133 | end 134 | 135 | @doc """ 136 | Searches the given index for the top `k` matches 137 | close to the given query vector. 138 | 139 | The result is a map with keys `:labels` and `:distances` 140 | which represent the index ID and pairwise distances from 141 | the query vector for each result vector. 
142 | """ 143 | def search(%Index{dim: dim, ref: index}, %Nx.Tensor{} = tensor, k) 144 | when is_integer(k) and k > 0 do 145 | validate_type!(tensor, {:f, 32}) 146 | 147 | case Nx.shape(tensor) do 148 | {^dim} -> 149 | data = Nx.to_binary(tensor) 150 | {distances, labels} = ExFaiss.NIF.search_index(index, 1, data, k) |> unwrap!() 151 | 152 | %{ 153 | distances: distances |> Nx.from_binary(:f32) |> Nx.reshape({1, k}), 154 | labels: labels |> Nx.from_binary(:s64) |> Nx.reshape({1, k}) 155 | } 156 | 157 | {n, ^dim} -> 158 | data = Nx.to_binary(tensor) 159 | {distances, labels} = ExFaiss.NIF.search_index(index, n, data, k) |> unwrap!() 160 | 161 | %{ 162 | distances: distances |> Nx.from_binary(:f32) |> Nx.reshape({n, k}), 163 | labels: labels |> Nx.from_binary(:s64) |> Nx.reshape({n, k}) 164 | } 165 | 166 | shape -> 167 | invalid_shape_error!(dim, shape) 168 | end 169 | end 170 | 171 | @doc """ 172 | Trains an index on a representative set of vectors. 173 | """ 174 | def train(%Index{dim: dim, ref: ref} = index, %Nx.Tensor{} = tensor) do 175 | validate_type!(tensor, {:f, 32}) 176 | 177 | case Nx.shape(tensor) do 178 | {n, ^dim} -> 179 | data = Nx.to_binary(tensor) 180 | ExFaiss.NIF.train_index(ref, n, data) 181 | 182 | shape -> 183 | invalid_shape_error!(dim, shape) 184 | end 185 | 186 | index 187 | end 188 | 189 | @doc """ 190 | Creates a copy of the given index, keeping the `:device` of the original. 191 | """ 192 | def clone(%Index{dim: dim, ref: index, device: device}) do 193 | ref = ExFaiss.NIF.clone_index(index) |> unwrap!() 194 | %Index{dim: dim, ref: ref, device: device} 195 | end 196 | 197 | @doc """ 198 | Reconstructs stored vectors at the given indices. 
199 | """ 200 | def reconstruct(%Index{dim: dim, ref: index}, %Nx.Tensor{} = keys) do 201 | n = 202 | case Nx.shape(keys) do 203 | {n} -> 204 | n 205 | 206 | {} -> 207 | 1 208 | end 209 | validate_type!(keys, {:s, 64}) 210 | keys_data = Nx.to_binary(keys) 211 | 212 | index 213 | |> ExFaiss.NIF.reconstruct_batch_from_index(n, keys_data) 214 | |> unwrap!() 215 | |> Nx.from_binary(:f32) 216 | |> Nx.reshape({n, dim}) 217 | end 218 | 219 | @doc """ 220 | Computes residuals after indexing. 221 | """ 222 | def compute_residuals(%Index{dim: dim, ref: index}, %Nx.Tensor{} = xs, %Nx.Tensor{} = keys) do 223 | n = 224 | case {Nx.shape(xs), Nx.shape(keys)} do 225 | {{^dim}, {1}} -> 226 | 1 227 | 228 | {{n, ^dim}, {n}} -> 229 | n 230 | 231 | {tensor_shape, ids_shape} -> 232 | raise ArgumentError, 233 | "invalid shape for index with dim #{inspect(dim)}," <> 234 | " tensor shape must be rank-1 or rank-2 with trailing" <> 235 | " dimension equal to dimension of the index, while ids" <> 236 | " shape must be rank-1 with dimension equal to leading" <> 237 | " dimension of data, or 1 if data is rank-1, got shapes" <> 238 | " ids: #{inspect(ids_shape)}, embeddings: #{inspect(tensor_shape)}" 239 | end 240 | validate_type!(xs, {:f, 32}) 241 | xs_data = Nx.to_binary(xs) 242 | keys_data = Nx.to_binary(keys) 243 | validate_type!(keys, {:s, 64}) 244 | index 245 | |> ExFaiss.NIF.compute_residuals_from_index(n, xs_data, keys_data) 246 | |> unwrap!() 247 | |> Nx.from_binary(:f32) 248 | |> Nx.reshape({n, dim}) 249 | end 250 | 251 | @doc """ 252 | Writes an index to a file. 253 | """ 254 | def to_file(%Index{ref: index}, fname) do 255 | :ok = ExFaiss.NIF.write_index(index, fname) 256 | end 257 | 258 | @doc """ 259 | Reads an index from a file. 260 | """ 261 | def from_file(fname, io_flags) do 262 | ref = ExFaiss.NIF.read_index(fname, io_flags) |> unwrap!() 263 | dim = ExFaiss.NIF.get_index_dim(ref) |> unwrap!() 264 | %Index{dim: dim, ref: ref, device: :host} 265 | end 266 | 267 | @doc """ 268 | Gets the number of vectors in the index. 
269 | """ 270 | def get_num_vectors(%Index{ref: index}) do 271 | ExFaiss.NIF.get_index_n_vectors(index) |> unwrap!() 272 | end 273 | 274 | defp invalid_shape_error!(dim, shape) do 275 | raise ArgumentError, 276 | "invalid shape for index with dim #{inspect(dim)}," <> 277 | " tensor shape must be rank-1 or rank-2 with trailing" <> 278 | " dimension equal to dimension of the index, got shape" <> 279 | " #{inspect(shape)}" 280 | end 281 | 282 | defp metric_type_to_int(:inner_product), do: 0 283 | defp metric_type_to_int(:l2), do: 1 284 | defp metric_type_to_int(:l1), do: 2 285 | defp metric_type_to_int(:linf), do: 3 286 | defp metric_type_to_int(:lp), do: 4 287 | defp metric_type_to_int(:canberra), do: 20 288 | defp metric_type_to_int(:braycurtis), do: 21 289 | defp metric_type_to_int(:jensenshannon), do: 22 290 | defp metric_type_to_int(invalid), do: raise(ArgumentError, "invalid metric #{inspect(invalid)}") 291 | end 292 | -------------------------------------------------------------------------------- /lib/ex_faiss/nif.ex: -------------------------------------------------------------------------------- 1 | defmodule ExFaiss.NIF do 2 | @on_load :__on_load__ 3 | 4 | def __on_load__ do 5 | path = :filename.join(:code.priv_dir(:ex_faiss), ~c"libex_faiss") 6 | :erlang.load_nif(path, 0) 7 | end 8 | 9 | # Index operations 10 | def new_index(_dim, _description, _metric), do: :erlang.nif_error(:undef) 11 | def clone_index(_index), do: :erlang.nif_error(:undef) 12 | def add_to_index(_index, _dim, _data), do: :erlang.nif_error(:undef) 13 | def add_with_ids_to_index(_index, _dim, _data, _ids), do: :erlang.nif_error(:undef) 14 | def search_index(_index, _n, _data, _k), do: :erlang.nif_error(:undef) 15 | def train_index(_index, _n, _data), do: :erlang.nif_error(:undef) 16 | def reset_index(_index), do: :erlang.nif_error(:undef) 17 | def reconstruct_batch_from_index(_index, _n, _data), do: :erlang.nif_error(:undef) 18 | def compute_residuals_from_index(_index, _n, _data, _keys), 
do: :erlang.nif_error(:undef) 19 | def write_index(_index, _fname), do: :erlang.nif_error(:undef) 20 | def read_index(_fname, _io_flags), do: :erlang.nif_error(:undef) 21 | def get_index_dim(_index), do: :erlang.nif_error(:undef) 22 | def get_index_n_vectors(_index), do: :erlang.nif_error(:undef) 23 | 24 | # Gpu operations 25 | def index_cpu_to_gpu(_index, _device), do: :erlang.nif_error(:undef) 26 | def get_num_gpus(), do: :erlang.nif_error(:undef) 27 | 28 | # Clustering operations 29 | def new_clustering(_dim, _k), do: :erlang.nif_error(:undef) 30 | def train_clustering(_clustering, _n, _data, _index), do: :erlang.nif_error(:undef) 31 | def get_clustering_centroids(_clustering), do: :erlang.nif_error(:undef) 32 | end 33 | -------------------------------------------------------------------------------- /lib/ex_faiss/shared.ex: -------------------------------------------------------------------------------- 1 | defmodule ExFaiss.Shared do 2 | @moduledoc false 3 | 4 | def validate_type!(tensor, type) do 5 | unless Nx.type(tensor) == type do 6 | raise ArgumentError, 7 | "invalid type #{inspect(Nx.type(tensor))}, vector type" <> 8 | " must be #{inspect(type)}" 9 | end 10 | end 11 | 12 | def unwrap!({:ok, val}), do: val 13 | # NIF error reasons are built with enif_make_string and arrive as charlists, which raise/1 rejects 14 | def unwrap!({:error, reason}) when is_list(reason), do: raise(List.to_string(reason)) 15 | def unwrap!({:error, reason}), do: raise(reason) 16 | end 17 | -------------------------------------------------------------------------------- /mix.exs: -------------------------------------------------------------------------------- 1 | defmodule ExFaiss.MixProject do 2 | use Mix.Project 3 | 4 | def project do 5 | [ 6 | app: :ex_faiss, 7 | version: "0.1.0", 8 | elixir: "~> 1.14", 9 | start_permanent: Mix.env() == :prod, 10 | compilers: [:elixir_make] ++ Mix.compilers(), 11 | deps: deps() 12 | ] 13 | end 14 | 15 | # Run "mix help compile.app" to learn about applications. 16 | def application do 17 | [ 18 | extra_applications: [:logger] 19 | ] 20 | end 21 | 22 | # Run "mix help deps" to learn about dependencies. 
23 | defp deps do 24 | [ 25 | {:elixir_make, "~> 0.4", runtime: false}, 26 | {:nx, "~> 0.4"} 27 | ] 28 | end 29 | end 30 | -------------------------------------------------------------------------------- /mix.lock: -------------------------------------------------------------------------------- 1 | %{ 2 | "complex": {:hex, :complex, "0.4.2", "923e5db0be13dbb3ea00cf8459d9f75f3afdd9ff5a82742ded21064330d28273", [:mix], [], "hexpm", "069a085ef820ce675a2619fd125b963ff4514af2102c7f7d7965128e5ec0a429"}, 3 | "elixir_make": {:hex, :elixir_make, "0.6.3", "bc07d53221216838d79e03a8019d0839786703129599e9619f4ab74c8c096eac", [:mix], [], "hexpm", "f5cbd651c5678bcaabdbb7857658ee106b12509cd976c2c2fca99688e1daf716"}, 4 | "nx": {:hex, :nx, "0.4.0", "2ec2cebec6a9ac8a3d5ae8ef79345cf92f37f9018d50817684e51e97b86f3d36", [:mix], [{:complex, "~> 0.4.2", [hex: :complex, repo: "hexpm", optional: false]}], "hexpm", "bab955768dadfe2208723fbffc9255341b023291f2aabcbd25bf98167dd3399e"}, 5 | } 6 | -------------------------------------------------------------------------------- /test/ex_faiss/clustering_test.exs: -------------------------------------------------------------------------------- 1 | defmodule ExFaiss.ClusteringTest do 2 | use ExUnit.Case 3 | 4 | alias ExFaiss.Clustering 5 | alias ExFaiss.Index 6 | 7 | describe "new" do 8 | test "creates a new clustering" do 9 | assert %Clustering{k: 10, ref: _, index: %Index{dim: 128}} = Clustering.new(128, 10) 10 | end 11 | end 12 | 13 | describe "train" do 14 | test "trains a cluster and adds to index" do 15 | trained = 16 | Clustering.new(128, 10) 17 | |> Clustering.train(Nx.random_uniform({100, 128})) 18 | 19 | assert %Clustering{k: 10, ref: _, index: %Index{dim: 128} = index} = trained 20 | assert Index.get_num_vectors(index) == 10 21 | end 22 | end 23 | end 24 | -------------------------------------------------------------------------------- /test/ex_faiss/index_test.exs: 
-------------------------------------------------------------------------------- 1 | defmodule ExFaiss.IndexTest do 2 | use ExUnit.Case 3 | 4 | alias ExFaiss.Index 5 | 6 | describe "new" do 7 | test "creates indices from descriptions" do 8 | assert %Index{} = 9 | ExFaiss.Index.new( 10 | 512, 11 | "OPQ16_64,IVF262144(IVF512,PQ32x4fs,RFlat),PQ16x4fsr,Refine(OPQ56_112,PQ56)" 12 | ) 13 | 14 | assert %Index{} = ExFaiss.Index.new(128, "PCA80,Flat") 15 | assert %Index{} = ExFaiss.Index.new(128, "OPQ16_64,IMI2x8,PQ8+16") 16 | assert %Index{} = ExFaiss.Index.new(512, "Flat", metric: :jensenshannon) 17 | end 18 | 19 | @tag :cuda 20 | test "creates gpu indices from descriptions" do 21 | # TODO: Unsupported clone? 22 | # assert %Index{device: {:cuda, 0}} = 23 | # ExFaiss.Index.new( 24 | # 512, 25 | # "OPQ16_64,IVF262144(IVF512,PQ32x4fs,RFlat),PQ16x4fsr,Refine(OPQ56_112,PQ56)", 26 | # device: {:cuda, 0} 27 | # ) 28 | 29 | assert %Index{device: {:cuda, 0}} = ExFaiss.Index.new(128, "PCA80,Flat", device: {:cuda, 0}) 30 | 31 | assert %Index{device: {:cuda, 0}} = 32 | ExFaiss.Index.new(128, "OPQ16_64,IMI2x8,PQ8+16", device: {:cuda, 0}) 33 | 34 | assert %Index{device: {:cuda, 0}} = 35 | ExFaiss.Index.new(512, "Flat", metric: :jensenshannon, device: {:cuda, 0}) 36 | end 37 | end 38 | 39 | describe "clone" do 40 | test "creates clones of index" do 41 | assert %Index{ref: ref1} = index = ExFaiss.Index.new(512, "Flat") 42 | 43 | assert %Index{ref: ref2} = ExFaiss.Index.clone(index) 44 | 45 | assert ref1 != ref2 46 | end 47 | end 48 | 49 | describe "add" do 50 | test "adds valid tensors" do 51 | index = ExFaiss.Index.new(512, "Flat") 52 | 53 | assert %Index{} = ExFaiss.Index.add(index, Nx.random_uniform({512})) 54 | assert %Index{} = ExFaiss.Index.add(index, Nx.random_uniform({2, 512})) 55 | end 56 | 57 | @tag :cuda 58 | test "adds valid tensors to gpu index" do 59 | index = ExFaiss.Index.new(512, "Flat", device: {:cuda, 0}) 60 | 61 | assert %Index{} = ExFaiss.Index.add(index, 
Nx.random_uniform({512})) 62 | assert %Index{} = ExFaiss.Index.add(index, Nx.random_uniform({2, 512})) 63 | end 64 | 65 | test "raises on invalid types" do 66 | index1 = ExFaiss.Index.new(128, "Flat") 67 | 68 | assert_raise ArgumentError, ~r/invalid type/, fn -> 69 | ExFaiss.Index.add(index1, Nx.random_uniform({2, 128}, type: :f64)) 70 | end 71 | end 72 | 73 | test "raises on invalid shapes" do 74 | index1 = ExFaiss.Index.new(128, "Flat") 75 | 76 | assert_raise ArgumentError, ~r/invalid shape/, fn -> 77 | ExFaiss.Index.add(index1, Nx.random_uniform({2, 3, 4})) 78 | end 79 | 80 | assert_raise ArgumentError, ~r/invalid shape/, fn -> 81 | ExFaiss.Index.add(index1, Nx.random_uniform({2, 256})) 82 | end 83 | end 84 | end 85 | 86 | describe "search" do 87 | test "searches a simple flat index" do 88 | index = 89 | ExFaiss.Index.new(1, "Flat", metric: :l1) 90 | |> ExFaiss.Index.add(Nx.iota({64, 1}, type: :f32)) 91 | 92 | assert %{distances: distances, labels: labels} = 93 | ExFaiss.Index.search(index, Nx.tensor([0.0]), 32) 94 | 95 | assert distances == Nx.iota({1, 32}, type: :f32) 96 | assert labels == Nx.iota({1, 32}) 97 | end 98 | 99 | @tag :cuda 100 | test "searches a simple flat gpu index" do 101 | index = 102 | ExFaiss.Index.new(1, "Flat", metric: :l1, device: {:cuda, 0}) 103 | |> ExFaiss.Index.add(Nx.iota({64, 1}, type: :f32)) 104 | 105 | assert %{distances: distances, labels: labels} = 106 | ExFaiss.Index.search(index, Nx.tensor([0.0]), 32) 107 | 108 | assert distances == Nx.iota({1, 32}, type: :f32) 109 | assert labels == Nx.iota({1, 32}) 110 | end 111 | end 112 | 113 | describe "train" do 114 | test "trains an index" do 115 | index = 116 | ExFaiss.Index.new(10, "HNSW,Flat") 117 | |> ExFaiss.Index.train(Nx.random_uniform({100, 10})) 118 | |> ExFaiss.Index.add(Nx.random_uniform({100, 10})) 119 | 120 | assert %Index{} = index 121 | end 122 | 123 | @tag :cuda 124 | test "trains an index on gpu" do 125 | index = 126 | ExFaiss.Index.new(10, "HNSW,Flat", device: {:cuda, 
0}) 127 | |> ExFaiss.Index.train(Nx.random_uniform({100, 10})) 128 | |> ExFaiss.Index.add(Nx.random_uniform({100, 10})) 129 | 130 | assert %Index{} = index 131 | end 132 | end 133 | 134 | describe "reconstruct" do 135 | test "reconstructs vectors from keys" do 136 | data = Nx.random_uniform({1, 128}) 137 | 138 | result = 139 | ExFaiss.Index.new(128, "Flat") 140 | |> ExFaiss.Index.add(data) 141 | |> ExFaiss.Index.reconstruct(Nx.tensor([0])) 142 | 143 | assert result == data 144 | end 145 | end 146 | 147 | describe "compute_residuals" do 148 | test "computes residuals from data and keys" do 149 | data = Nx.broadcast(0.0, {1, 128}) 150 | 151 | result = 152 | ExFaiss.Index.new(128, "Flat") 153 | |> ExFaiss.Index.add(data) 154 | |> ExFaiss.Index.compute_residuals(data, Nx.tensor([0])) 155 | 156 | assert result == data 157 | end 158 | end 159 | 160 | describe "memory" do 161 | @tag :slow 162 | test "does not leak" do 163 | for _ <- 1..100 do 164 | ExFaiss.Index.new(128, "Flat") 165 | |> ExFaiss.Index.add(Nx.random_uniform({500, 128})) 166 | 167 | :erlang.garbage_collect() 168 | end 169 | end 170 | 171 | @tag :slow 172 | test "does not leak on gpu" do 173 | for _ <- 1..100 do 174 | ExFaiss.Index.new(128, "Flat", device: {:cuda, 0}) 175 | |> ExFaiss.Index.add(Nx.random_uniform({500, 128})) 176 | 177 | :erlang.garbage_collect() 178 | end 179 | end 180 | end 181 | 182 | describe "multi-device" do 183 | @describetag :multi_device 184 | 185 | test "creates an on non-default device" do 186 | %Index{device: {:cuda, 1}} = ExFaiss.Index.new(128, "Flat", device: {:cuda, 1}) 187 | end 188 | end 189 | end 190 | -------------------------------------------------------------------------------- /test/ex_faiss_test.exs: -------------------------------------------------------------------------------- 1 | defmodule ExFaissTest do 2 | use ExUnit.Case 3 | doctest ExFaiss 4 | end 5 | -------------------------------------------------------------------------------- /test/test_helper.exs: 
-------------------------------------------------------------------------------- 1 | ExUnit.start(exclude: [:slow, :cuda, :multi_device]) 2 | --------------------------------------------------------------------------------