├── .gitignore ├── hnswlib-jna-example ├── lib │ ├── libhnswlib-jna-x86-64.dll │ ├── libhnswlib-jna-x86-64.dylib │ ├── libhnswlib-jna-x86-64.exp │ ├── libhnswlib-jna-x86-64.lib │ └── libhnswlib-jna-x86-64.so ├── pom.xml └── src │ └── main │ └── java │ └── com │ └── stepstone │ └── search │ └── hnswlib │ └── jna │ └── example │ └── App.java ├── hnswlib-jna ├── src │ ├── main │ │ ├── resources │ │ │ ├── libhnswlib-jna-x86-64.so │ │ │ ├── libhnswlib-jna-aarch64.so │ │ │ ├── libhnswlib-jna-x86-64.dll │ │ │ ├── libhnswlib-jna-x86-64.dylib │ │ │ ├── libhnswlib-jna-x86-64.exp │ │ │ └── libhnswlib-jna-x86-64.libw │ │ └── java │ │ │ └── com │ │ │ └── stepstone │ │ │ └── search │ │ │ └── hnswlib │ │ │ └── jna │ │ │ ├── SpaceName.java │ │ │ ├── exception │ │ │ ├── OnceIndexIsClearedItCannotBeReusedException.java │ │ │ ├── IndexAlreadyInitializedException.java │ │ │ ├── UnableToCreateNewIndexInstanceException.java │ │ │ ├── IndexNotInitializedException.java │ │ │ ├── ItemCannotBeInsertedIntoTheVectorSpaceException.java │ │ │ ├── UnexpectedNativeException.java │ │ │ └── QueryCannotReturnResultsException.java │ │ │ ├── QueryTuple.java │ │ │ ├── HnswlibFactory.java │ │ │ ├── Hnswlib.java │ │ │ ├── ConcurrentIndex.java │ │ │ └── Index.java │ └── test │ │ └── java │ │ └── com │ │ └── stepstone │ │ └── search │ │ └── hnswlib │ │ └── jna │ │ ├── ConcurrentIndexTest.java │ │ ├── IndexPerformanceTest.java │ │ ├── ConcurrentIndexPerformanceTest.java │ │ ├── HnswlibTestUtils.java │ │ ├── AbstractPerformanceTest.java │ │ ├── IndexTest.java │ │ └── AbstractIndexTest.java └── pom.xml ├── hnswlib-jna-legacy ├── src │ ├── test │ │ └── java │ │ │ └── com │ │ │ └── stepstone │ │ │ └── search │ │ │ └── hnswlib │ │ │ └── jna │ │ │ ├── LegacyIndexTest.java │ │ │ └── LegacyConcurrentIndexTest.java │ └── main │ │ └── java │ │ └── com │ │ └── stepstone │ │ └── search │ │ └── hnswlib │ │ └── jna │ │ └── HnswlibFactory.java └── pom.xml ├── .travis.yml ├── hnswlib ├── visited_list_pool.h ├── hnswlib.h 
├── bruteforce.h ├── space_l2.h ├── space_ip.h └── hnswalg.h ├── README.md ├── pom.xml ├── bindings.cpp └── LICENSE /.gitignore: -------------------------------------------------------------------------------- 1 | **/*.iml 2 | **/.idea/ 3 | **/target/ 4 | **/*.log 5 | 6 | -------------------------------------------------------------------------------- /hnswlib-jna-example/lib/libhnswlib-jna-x86-64.dll: -------------------------------------------------------------------------------- 1 | /* this is just a placeholder for the dynamic library for windows */ 2 | -------------------------------------------------------------------------------- /hnswlib-jna-example/lib/libhnswlib-jna-x86-64.dylib: -------------------------------------------------------------------------------- 1 | /* this is just a placeholder for the dynamic library for mac */ 2 | -------------------------------------------------------------------------------- /hnswlib-jna-example/lib/libhnswlib-jna-x86-64.exp: -------------------------------------------------------------------------------- 1 | /* this is just a placeholder for the dynamic library for windows */ 2 | -------------------------------------------------------------------------------- /hnswlib-jna-example/lib/libhnswlib-jna-x86-64.lib: -------------------------------------------------------------------------------- 1 | /* this is just a placeholder for the dynamic library for windows */ 2 | -------------------------------------------------------------------------------- /hnswlib-jna-example/lib/libhnswlib-jna-x86-64.so: -------------------------------------------------------------------------------- 1 | /* this is just a placeholder for the dynamic library for linux */ 2 | -------------------------------------------------------------------------------- /hnswlib-jna/src/main/resources/libhnswlib-jna-x86-64.so: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/stepstone-tech/hnswlib-jna/HEAD/hnswlib-jna/src/main/resources/libhnswlib-jna-x86-64.so -------------------------------------------------------------------------------- /hnswlib-jna/src/main/resources/libhnswlib-jna-aarch64.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepstone-tech/hnswlib-jna/HEAD/hnswlib-jna/src/main/resources/libhnswlib-jna-aarch64.so -------------------------------------------------------------------------------- /hnswlib-jna/src/main/resources/libhnswlib-jna-x86-64.dll: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepstone-tech/hnswlib-jna/HEAD/hnswlib-jna/src/main/resources/libhnswlib-jna-x86-64.dll -------------------------------------------------------------------------------- /hnswlib-jna/src/main/resources/libhnswlib-jna-x86-64.dylib: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepstone-tech/hnswlib-jna/HEAD/hnswlib-jna/src/main/resources/libhnswlib-jna-x86-64.dylib -------------------------------------------------------------------------------- /hnswlib-jna/src/main/resources/libhnswlib-jna-x86-64.exp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepstone-tech/hnswlib-jna/HEAD/hnswlib-jna/src/main/resources/libhnswlib-jna-x86-64.exp -------------------------------------------------------------------------------- /hnswlib-jna/src/main/resources/libhnswlib-jna-x86-64.libw: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stepstone-tech/hnswlib-jna/HEAD/hnswlib-jna/src/main/resources/libhnswlib-jna-x86-64.libw -------------------------------------------------------------------------------- 
/hnswlib-jna-legacy/src/test/java/com/stepstone/search/hnswlib/jna/LegacyIndexTest.java: -------------------------------------------------------------------------------- 1 | package com.stepstone.search.hnswlib.jna; 2 | 3 | public class LegacyIndexTest extends IndexTest { 4 | 5 | } 6 | -------------------------------------------------------------------------------- /hnswlib-jna-legacy/src/test/java/com/stepstone/search/hnswlib/jna/LegacyConcurrentIndexTest.java: -------------------------------------------------------------------------------- 1 | package com.stepstone.search.hnswlib.jna; 2 | 3 | public class LegacyConcurrentIndexTest extends ConcurrentIndexTest { 4 | 5 | } 6 | -------------------------------------------------------------------------------- /hnswlib-jna/src/main/java/com/stepstone/search/hnswlib/jna/SpaceName.java: -------------------------------------------------------------------------------- 1 | package com.stepstone.search.hnswlib.jna; 2 | 3 | /** 4 | * Space names available in the native implementation. 5 | */ 6 | public enum SpaceName { L2, IP, COSINE /* requires normalized arrays */ } 7 | -------------------------------------------------------------------------------- /hnswlib-jna/src/main/java/com/stepstone/search/hnswlib/jna/exception/OnceIndexIsClearedItCannotBeReusedException.java: -------------------------------------------------------------------------------- 1 | package com.stepstone.search.hnswlib.jna.exception; 2 | 3 | /** 4 | * Exception thrown when operations are done in a cleared index. 
5 | */ 6 | public class OnceIndexIsClearedItCannotBeReusedException extends UnexpectedNativeException { 7 | } 8 | -------------------------------------------------------------------------------- /hnswlib-jna/src/main/java/com/stepstone/search/hnswlib/jna/exception/IndexAlreadyInitializedException.java: -------------------------------------------------------------------------------- 1 | package com.stepstone.search.hnswlib.jna.exception; 2 | 3 | /** 4 | * Exception raised when the method initialize() of an 5 | * index has been called more than once. 6 | */ 7 | public class IndexAlreadyInitializedException extends UnexpectedNativeException { 8 | } 9 | -------------------------------------------------------------------------------- /hnswlib-jna/src/test/java/com/stepstone/search/hnswlib/jna/ConcurrentIndexTest.java: -------------------------------------------------------------------------------- 1 | package com.stepstone.search.hnswlib.jna; 2 | 3 | public class ConcurrentIndexTest extends AbstractIndexTest { 4 | 5 | @Override 6 | protected Index createIndexInstance(SpaceName spaceName, int dimensions) { 7 | return new ConcurrentIndex(spaceName, dimensions); 8 | } 9 | } 10 | -------------------------------------------------------------------------------- /hnswlib-jna/src/main/java/com/stepstone/search/hnswlib/jna/exception/UnableToCreateNewIndexInstanceException.java: -------------------------------------------------------------------------------- 1 | package com.stepstone.search.hnswlib.jna.exception; 2 | 3 | /** 4 | * This exception is thrown when it is not possible (for some reason) to create 5 | * a new index instance in the native side. 
6 | */ 7 | public class UnableToCreateNewIndexInstanceException extends UnexpectedNativeException { 8 | } 9 | -------------------------------------------------------------------------------- /hnswlib-jna/src/test/java/com/stepstone/search/hnswlib/jna/IndexPerformanceTest.java: -------------------------------------------------------------------------------- 1 | package com.stepstone.search.hnswlib.jna; 2 | 3 | import org.junit.Ignore; 4 | 5 | @Ignore 6 | public class IndexPerformanceTest extends AbstractPerformanceTest { 7 | 8 | @Override 9 | protected Index createIndexInstance(SpaceName spaceName, int dimensions) { 10 | return new Index(spaceName, dimensions); 11 | } 12 | } 13 | -------------------------------------------------------------------------------- /hnswlib-jna/src/main/java/com/stepstone/search/hnswlib/jna/exception/IndexNotInitializedException.java: -------------------------------------------------------------------------------- 1 | package com.stepstone.search.hnswlib.jna.exception; 2 | 3 | /** 4 | * Exception thrown when the index reference is not initialized on the native side. 5 | * (the method initialize() is not called after the object instantiation) 6 | */ 7 | public class IndexNotInitializedException extends UnexpectedNativeException { 8 | } 9 | -------------------------------------------------------------------------------- /hnswlib-jna/src/main/java/com/stepstone/search/hnswlib/jna/exception/ItemCannotBeInsertedIntoTheVectorSpaceException.java: -------------------------------------------------------------------------------- 1 | package com.stepstone.search.hnswlib.jna.exception; 2 | 3 | /** 4 | * Exception thrown when the max number of elements into a vector 5 | * space is reached. This value is set during the vector space initialization. 
6 | */ 7 | public class ItemCannotBeInsertedIntoTheVectorSpaceException extends UnexpectedNativeException { 8 | } 9 | -------------------------------------------------------------------------------- /hnswlib-jna/src/test/java/com/stepstone/search/hnswlib/jna/ConcurrentIndexPerformanceTest.java: -------------------------------------------------------------------------------- 1 | package com.stepstone.search.hnswlib.jna; 2 | 3 | import org.junit.Ignore; 4 | 5 | @Ignore 6 | public class ConcurrentIndexPerformanceTest extends AbstractPerformanceTest { 7 | 8 | @Override 9 | protected Index createIndexInstance(SpaceName spaceName, int dimensions) { 10 | return new ConcurrentIndex(spaceName, dimensions); 11 | } 12 | } 13 | -------------------------------------------------------------------------------- /hnswlib-jna/src/main/java/com/stepstone/search/hnswlib/jna/exception/UnexpectedNativeException.java: -------------------------------------------------------------------------------- 1 | package com.stepstone.search.hnswlib.jna.exception; 2 | 3 | /** 4 | * General exception for errors that happened on the native implementation. 
5 | */ 6 | public class UnexpectedNativeException extends RuntimeException { 7 | 8 | public UnexpectedNativeException() { 9 | } 10 | 11 | public UnexpectedNativeException(String message) { 12 | super(message); 13 | } 14 | 15 | } 16 | -------------------------------------------------------------------------------- /hnswlib-jna/src/test/java/com/stepstone/search/hnswlib/jna/HnswlibTestUtils.java: -------------------------------------------------------------------------------- 1 | package com.stepstone.search.hnswlib.jna; 2 | 3 | import java.util.Random; 4 | 5 | public final class HnswlibTestUtils { 6 | 7 | public static float[] getRandomFloatArray(int dimension){ 8 | float[] array = new float[dimension]; 9 | Random random = new Random(); 10 | for (int i = 0; i < dimension; i++){ 11 | array[i] = random.nextFloat(); 12 | } 13 | return array; 14 | } 15 | 16 | } 17 | -------------------------------------------------------------------------------- /hnswlib-jna/src/main/java/com/stepstone/search/hnswlib/jna/exception/QueryCannotReturnResultsException.java: -------------------------------------------------------------------------------- 1 | package com.stepstone.search.hnswlib.jna.exception; 2 | 3 | /** 4 | * Exception that represents that results could not be returned by the query. 5 | */ 6 | public class QueryCannotReturnResultsException extends UnexpectedNativeException { 7 | 8 | private static final String MESSAGE = "Probably ef or M is too small"; 9 | 10 | public QueryCannotReturnResultsException() { 11 | super(MESSAGE); 12 | } 13 | 14 | } 15 | -------------------------------------------------------------------------------- /hnswlib-jna/src/main/java/com/stepstone/search/hnswlib/jna/QueryTuple.java: -------------------------------------------------------------------------------- 1 | package com.stepstone.search.hnswlib.jna; 2 | 3 | /** 4 | * Query Tuple that represents the results of a knn query. 5 | * It contains two arrays: ids and coefficients. 
6 | */ 7 | public class QueryTuple { 8 | 9 | int[] ids; 10 | float[] coefficients; 11 | 12 | QueryTuple (int k) { 13 | ids = new int[k]; 14 | coefficients = new float[k]; 15 | } 16 | 17 | public float[] getCoefficients() { 18 | return coefficients; 19 | } 20 | 21 | public int[] getIds() { 22 | return ids; 23 | } 24 | } 25 | -------------------------------------------------------------------------------- /hnswlib-jna-example/pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 5 | 4.0.0 6 | 7 | hnswlib-jna-example 8 | com.stepstone.search.hnswlib.jna.example 9 | hnswlib-jna-example 10 | 11 | 12 | 13 | Apache License, Version 2.0 14 | http://www.apache.org/licenses/LICENSE-2.0.txt 15 | 16 | 17 | 18 | 19 | com.stepstone.search.hnswlib.jna 20 | hnswlib-jna-parent 21 | 1.4.0 22 | .. 23 | 24 | 25 | 26 | UTF-8 27 | 1.8 28 | 1.8 29 | 30 | 31 | 32 | 33 | com.stepstone.search.hnswlib.jna 34 | hnswlib-jna 35 | 1.4.0 36 | 37 | 38 | 39 | 40 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | arch: amd64 2 | language: cpp 3 | compiler: clang 4 | 5 | jobs: 6 | include: 7 | #################################### 8 | - stage: "Unit tests on linux" # 9 | #################################### 10 | os: linux 11 | dist: bionic 12 | env: 13 | - JAVA_OPTS="-Xmx2048m -Xms512m" 14 | - MAVEN_OPTS="$JAVA_OPTS" 15 | script: 16 | - clang++ -fPIC -std=c++11 -O3 -shared bindings.cpp -I hnswlib -o l hnswlib-jna/src/main/resources/libhnswlib-jna-x86-64.so 17 | - mvn test 18 | #################################### 19 | - stage: "Unit tests on macos" # 20 | #################################### 21 | os: osx 22 | osx_image: xcode9.3 23 | script: 24 | - clang++ -std=c++11 -O3 -shared bindings.cpp -I hnswlib -o l hnswlib-jna/src/main/resources/libhnswlib-jna-x86-64.dylib 25 | - mvn test 26 | 
################################################### 27 | - stage: "Unit tests on macos (no compilation)" # 28 | ################################################### 29 | os: osx 30 | osx_image: xcode9.3 31 | script: 32 | - mvn test 33 | #################################### 34 | - stage: "Unit tests on windows" # 35 | #################################### 36 | os: windows 37 | before_install: 38 | - choco install jdk8 --version 8.0.211 39 | - choco install maven --version 3.6.3 40 | script: 41 | - export JAVA_HOME="/c/Program Files/Java/jdk1.8.0_211/" 42 | - clang++ -O3 -shared bindings.cpp -I hnswlib -o hnswlib-jna/src/main/resources/libhnswlib-jna.dll 43 | - /c/ProgramData/chocolatey/lib/maven/apache-maven-3.6.3/bin/mvn test 44 | ##################################################### 45 | - stage: "Unit tests on windows (no compilation)" # 46 | ##################################################### 47 | os: windows 48 | before_install: 49 | - choco install jdk8 --version 8.0.211 50 | - choco install maven --version 3.6.3 51 | script: 52 | - export JAVA_HOME="/c/Program Files/Java/jdk1.8.0_211/" 53 | - /c/ProgramData/chocolatey/lib/maven/apache-maven-3.6.3/bin/mvn test 54 | -------------------------------------------------------------------------------- /hnswlib/visited_list_pool.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | namespace hnswlib { 7 | typedef unsigned short int vl_type; 8 | 9 | class VisitedList { 10 | public: 11 | vl_type curV; 12 | vl_type *mass; 13 | unsigned int numelements; 14 | 15 | VisitedList(int numelements1) { 16 | curV = -1; 17 | numelements = numelements1; 18 | mass = new vl_type[numelements]; 19 | } 20 | 21 | void reset() { 22 | curV++; 23 | if (curV == 0) { 24 | memset(mass, 0, sizeof(vl_type) * numelements); 25 | curV++; 26 | } 27 | }; 28 | 29 | ~VisitedList() { delete[] mass; } 30 | }; 31 | 
/////////////////////////////////////////////////////////// 32 | // 33 | // Class for multi-threaded pool-management of VisitedLists 34 | // 35 | ///////////////////////////////////////////////////////// 36 | 37 | class VisitedListPool { 38 | std::deque pool; 39 | std::mutex poolguard; 40 | int numelements; 41 | 42 | public: 43 | VisitedListPool(int initmaxpools, int numelements1) { 44 | numelements = numelements1; 45 | for (int i = 0; i < initmaxpools; i++) 46 | pool.push_front(new VisitedList(numelements)); 47 | } 48 | 49 | VisitedList *getFreeVisitedList() { 50 | VisitedList *rez; 51 | { 52 | std::unique_lock lock(poolguard); 53 | if (pool.size() > 0) { 54 | rez = pool.front(); 55 | pool.pop_front(); 56 | } else { 57 | rez = new VisitedList(numelements); 58 | } 59 | } 60 | rez->reset(); 61 | return rez; 62 | }; 63 | 64 | void releaseVisitedList(VisitedList *vl) { 65 | std::unique_lock lock(poolguard); 66 | pool.push_front(vl); 67 | }; 68 | 69 | ~VisitedListPool() { 70 | while (pool.size()) { 71 | VisitedList *rez = pool.front(); 72 | pool.pop_front(); 73 | delete rez; 74 | } 75 | }; 76 | }; 77 | } 78 | 79 | -------------------------------------------------------------------------------- /hnswlib-jna/src/test/java/com/stepstone/search/hnswlib/jna/AbstractPerformanceTest.java: -------------------------------------------------------------------------------- 1 | package com.stepstone.search.hnswlib.jna; 2 | 3 | import com.stepstone.search.hnswlib.jna.exception.UnexpectedNativeException; 4 | import org.junit.BeforeClass; 5 | import org.junit.Ignore; 6 | import org.junit.Test; 7 | 8 | import java.time.Instant; 9 | import java.util.concurrent.ExecutorService; 10 | import java.util.concurrent.Executors; 11 | import java.util.concurrent.TimeUnit; 12 | 13 | import static org.junit.Assert.assertTrue; 14 | 15 | @Ignore 16 | public abstract class AbstractPerformanceTest { 17 | 18 | protected abstract Index createIndexInstance(SpaceName spaceName, int dimensions); 19 | 20 | 
@Test 21 | public void testPerformanceSingleThreadInsertionOf600kItems() throws UnexpectedNativeException { 22 | Index index = createIndexInstance(SpaceName.COSINE, 50); 23 | int numItems = 600_000; 24 | index.initialize(numItems); 25 | long begin = Instant.now().getEpochSecond(); 26 | for (int i = 0; i < numItems; i++) { 27 | index.addItem(HnswlibTestUtils.getRandomFloatArray(50)); 28 | } 29 | long end = Instant.now().getEpochSecond(); 30 | assertTrue((end - begin) < 600); /* +/- 8min for 1 CPU of a MacBook Pro [Intel i5 2.4GHz] (on 20/01/2020) */ 31 | index.clear(); 32 | } 33 | 34 | @Test 35 | public void testPerformanceMultiThreadedInsertionOf600kItems() throws UnexpectedNativeException, InterruptedException { 36 | int cpus = Runtime.getRuntime().availableProcessors(); 37 | ExecutorService executorService = Executors.newFixedThreadPool(cpus); 38 | 39 | int numItems = 600_000; 40 | Index index = createIndexInstance(SpaceName.COSINE, 50); 41 | index.initialize(numItems); 42 | 43 | Runnable addItemIndex = () -> { 44 | try { 45 | index.addItem(HnswlibTestUtils.getRandomFloatArray(50)); 46 | } catch (UnexpectedNativeException e) { 47 | e.printStackTrace(); 48 | } 49 | }; 50 | 51 | long begin = Instant.now().getEpochSecond(); 52 | for (int i = 0; i < numItems; i++) { 53 | executorService.submit(addItemIndex); 54 | } 55 | executorService.shutdown(); 56 | executorService.awaitTermination(5, TimeUnit.MINUTES); 57 | long end = Instant.now().getEpochSecond(); 58 | assertTrue((end - begin) < 150); /* 102s ~ running on a MacBook Pro [Intel i5 2.4GHz] (on 20/01/2020) */ 59 | index.clear(); 60 | } 61 | 62 | } 63 | -------------------------------------------------------------------------------- /hnswlib/hnswlib.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #ifndef NO_MANUAL_VECTORIZATION 3 | #ifdef __SSE__ 4 | #define USE_SSE 5 | #ifdef __AVX__ 6 | #define USE_AVX 7 | #endif 8 | #endif 9 | #endif 10 | 11 | #if 
defined(USE_AVX) || defined(USE_SSE) 12 | #ifdef _MSC_VER 13 | #include 14 | #include 15 | #else 16 | #include 17 | #endif 18 | 19 | #if defined(__GNUC__) 20 | #define PORTABLE_ALIGN32 __attribute__((aligned(32))) 21 | #else 22 | #define PORTABLE_ALIGN32 __declspec(align(32)) 23 | #endif 24 | #endif 25 | 26 | #include 27 | #include 28 | #include 29 | #include 30 | 31 | namespace hnswlib { 32 | typedef size_t labeltype; 33 | 34 | template 35 | class pairGreater { 36 | public: 37 | bool operator()(const T& p1, const T& p2) { 38 | return p1.first > p2.first; 39 | } 40 | }; 41 | 42 | template 43 | static void writeBinaryPOD(std::ostream &out, const T &podRef) { 44 | out.write((char *) &podRef, sizeof(T)); 45 | } 46 | 47 | template 48 | static void readBinaryPOD(std::istream &in, T &podRef) { 49 | in.read((char *) &podRef, sizeof(T)); 50 | } 51 | 52 | template 53 | using DISTFUNC = MTYPE(*)(const void *, const void *, const void *); 54 | 55 | 56 | template 57 | class SpaceInterface { 58 | public: 59 | //virtual void search(void *); 60 | virtual size_t get_data_size() = 0; 61 | 62 | virtual DISTFUNC get_dist_func() = 0; 63 | 64 | virtual void *get_dist_func_param() = 0; 65 | 66 | virtual ~SpaceInterface() {} 67 | }; 68 | 69 | template 70 | class AlgorithmInterface { 71 | public: 72 | virtual void addPoint(const void *datapoint, labeltype label)=0; 73 | virtual std::priority_queue> searchKnn(const void *, size_t) const = 0; 74 | template 75 | std::vector> searchKnn(const void*, size_t, Comp) { 76 | } 77 | virtual void saveIndex(const std::string &location)=0; 78 | virtual ~AlgorithmInterface(){ 79 | } 80 | }; 81 | 82 | 83 | } 84 | 85 | #include "space_l2.h" 86 | #include "space_ip.h" 87 | #include "bruteforce.h" 88 | #include "hnswalg.h" 89 | -------------------------------------------------------------------------------- /hnswlib-jna/src/test/java/com/stepstone/search/hnswlib/jna/IndexTest.java: 
-------------------------------------------------------------------------------- 1 | package com.stepstone.search.hnswlib.jna; 2 | 3 | import com.stepstone.search.hnswlib.jna.exception.IndexNotInitializedException; 4 | import com.stepstone.search.hnswlib.jna.exception.OnceIndexIsClearedItCannotBeReusedException; 5 | import com.stepstone.search.hnswlib.jna.exception.UnexpectedNativeException; 6 | import org.junit.Test; 7 | 8 | import static org.hamcrest.CoreMatchers.instanceOf; 9 | import static org.junit.Assert.assertEquals; 10 | import static org.junit.Assert.assertThat; 11 | 12 | public class IndexTest extends AbstractIndexTest { 13 | 14 | @Override 15 | protected Index createIndexInstance(SpaceName spaceName, int dimensions) { 16 | return new Index(spaceName, dimensions); 17 | } 18 | 19 | @Test 20 | public void testSynchronisedIndex() throws UnexpectedNativeException { 21 | Index i1 = createIndexInstance(SpaceName.COSINE, 50); 22 | i1.initialize(500_000, 16, 200, 100); 23 | Index syncIndex = Index.synchronizedIndex(i1); 24 | assertEquals(syncIndex.getLength(), i1.getLength()); 25 | assertThat(syncIndex, instanceOf(ConcurrentIndex.class)); 26 | syncIndex.clear(); 27 | } 28 | 29 | @Test(expected = OnceIndexIsClearedItCannotBeReusedException.class) 30 | public void testSynchronisedIndexFailAfterReferenceClear() throws UnexpectedNativeException { 31 | Index i1 = createIndexInstance(SpaceName.COSINE, 50); 32 | i1.initialize(500_000, 16, 200, 100); 33 | Index syncIndex = Index.synchronizedIndex(i1); 34 | syncIndex.clear(); 35 | //has to fail as i1 was cleared through syncIndex 36 | i1.addItem(HnswlibTestUtils.getRandomFloatArray(50)); 37 | } 38 | 39 | @Test 40 | public void testComputeSimilarity() { 41 | Index index = createIndexInstance(SpaceName.COSINE, 2); 42 | index.initialize(); 43 | float similarityClose = index.computeSimilarity( 44 | new float[]{1F, 2F}, 45 | new float[]{1F, 3F} 46 | ); 47 | float similarityFar = index.computeSimilarity( 48 | new float[] {1F, 
100F}, 49 | new float[] {50F, 450F} 50 | ); 51 | // both values are minus, so the closer one should be closer to zero than the farther one 52 | assertEquals(Float.compare(similarityClose, similarityFar), 1); 53 | } 54 | 55 | @Test(expected = IndexNotInitializedException.class) 56 | public void testComputeSimilarityWhenNotInitialized() { 57 | Index index = createIndexInstance(SpaceName.COSINE, 2); 58 | index.computeSimilarity( 59 | new float[] {1F, 100F}, 60 | new float[] {50F, 450F}); 61 | } 62 | 63 | } 64 | -------------------------------------------------------------------------------- /hnswlib-jna/pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 5 | 4.0.0 6 | 7 | hnswlib-jna 8 | hnswlib-jna 9 | 1.4.0 10 | 11 | 12 | 13 | Apache License, Version 2.0 14 | http://www.apache.org/licenses/LICENSE-2.0.txt 15 | 16 | 17 | 18 | 19 | com.stepstone.search.hnswlib.jna 20 | hnswlib-jna-parent 21 | 1.4.0 22 | .. 23 | 24 | 25 | 26 | UTF-8 27 | 1.8 28 | 1.8 29 | 2.2 30 | 5.5.0 31 | 4.11 32 | 33 | 34 | 35 | 36 | net.java.dev.jna 37 | jna 38 | ${jna.version} 39 | 40 | 41 | 42 | it.unimi.dsi 43 | fastutil 44 | 8.5.4 45 | 46 | 47 | junit 48 | junit 49 | ${junit.version} 50 | test 51 | 52 | 53 | 54 | 55 | 56 | 57 | maven-assembly-plugin 58 | 59 | 60 | package 61 | 62 | single 63 | 64 | 65 | 66 | 67 | 68 | jar-with-dependencies 69 | 70 | 71 | 72 | 73 | org.apache.maven.plugins 74 | maven-jar-plugin 75 | ${jar.plugin.version} 76 | 77 | 78 | 79 | test-jar 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | -------------------------------------------------------------------------------- /hnswlib-jna/src/main/java/com/stepstone/search/hnswlib/jna/HnswlibFactory.java: -------------------------------------------------------------------------------- 1 | package com.stepstone.search.hnswlib.jna; 2 | 3 | import com.sun.jna.Native; 4 | import com.sun.jna.Platform; 5 | 6 | import java.io.IOException; 7 | import java.io.InputStream; 8 | import 
java.nio.file.Files; 9 | import java.nio.file.Path; 10 | import java.nio.file.StandardCopyOption; 11 | 12 | /** 13 | * Factory for the hnswlib JNA implementation. 14 | */ 15 | final class HnswlibFactory { 16 | 17 | private static final String LIBRARY_NAME = "hnswlib-jna-" + Platform.ARCH; 18 | private static final String JNA_LIBRARY_PATH_PROPERTY = "jna.library.path"; 19 | 20 | private static Hnswlib instance; 21 | 22 | private HnswlibFactory() { 23 | } 24 | 25 | /** 26 | * Return a single instance of the loaded library. 27 | * 28 | * @return hnswlib JNA instance. 29 | */ 30 | static synchronized Hnswlib getInstance() { 31 | if (instance == null) { 32 | try { 33 | checkIfLibraryProvidedNeedsToBeLoadedIntoSO(); 34 | instance = Native.load(LIBRARY_NAME, Hnswlib.class); 35 | } catch (UnsatisfiedLinkError | IOException | NullPointerException e) { 36 | throw new UnsatisfiedLinkError("It's not possible to use the pre-generated dynamic libraries on your system. " 37 | + "Please compile it yourself (if not done yet) and set the \"" + JNA_LIBRARY_PATH_PROPERTY + "\" property " 38 | + "with correct path to where \"" + getLibraryFileName() + "\" is located."); 39 | } 40 | } 41 | return instance; 42 | } 43 | 44 | private static String getLibraryFileName(){ 45 | String extension; 46 | if (Platform.isLinux()) { 47 | extension = "so"; 48 | } else if (Platform.isWindows()) { 49 | extension = "dll"; 50 | } else { 51 | extension = "dylib"; 52 | } 53 | return String.format("libhnswlib-jna-%s.%s", Platform.ARCH, extension); 54 | } 55 | 56 | private static void copyPreGeneratedLibraryFiles(Path folder, String fileName) throws IOException { 57 | InputStream libraryStream = HnswlibFactory.class.getResourceAsStream("/" + fileName); 58 | /* windows seems to be blocking manipulation of .lib files; we store as .libw for now. 
*/ 59 | Files.copy(libraryStream, folder.resolve(fileName.replace(".libw",".lib")), StandardCopyOption.REPLACE_EXISTING); 60 | } 61 | 62 | private static void checkIfLibraryProvidedNeedsToBeLoadedIntoSO() throws IOException { 63 | String property = System.getProperty(JNA_LIBRARY_PATH_PROPERTY); 64 | if (property == null) { 65 | Path libraryFolder = Files.createTempDirectory(LIBRARY_NAME); 66 | copyPreGeneratedLibraryFiles(libraryFolder, getLibraryFileName()); 67 | if (Platform.isWindows()) { 68 | copyPreGeneratedLibraryFiles(libraryFolder, "libhnswlib-jna-x86-64.exp"); 69 | copyPreGeneratedLibraryFiles(libraryFolder, "libhnswlib-jna-x86-64.libw"); 70 | } 71 | System.setProperty(JNA_LIBRARY_PATH_PROPERTY, libraryFolder.toString()); 72 | libraryFolder.toFile().deleteOnExit(); 73 | } 74 | } 75 | } 76 |
--------------------------------------------------------------------------------
/hnswlib-jna-legacy/src/main/java/com/stepstone/search/hnswlib/jna/HnswlibFactory.java:
--------------------------------------------------------------------------------
package com.stepstone.search.hnswlib.jna;

import com.sun.jna.Native;
import com.sun.jna.Platform;

import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardCopyOption;

/**
 * Factory for the hnswlib JNA (4.x.x) implementation.
 *
 * The native library is loaded lazily, once, on the first call to
 * {@link #getInstance()}. When the "jna.library.path" system property is not
 * set, the pre-generated library bundled as a classpath resource is copied to
 * a temporary folder and that folder is used as the JNA library path.
 */
public final class HnswlibFactory {

    private static final String LIBRARY_NAME = "hnswlib-jna-" + Platform.ARCH;
    private static final String JNA_LIBRARY_PATH_PROPERTY = "jna.library.path";

    private static Hnswlib instance;

    private HnswlibFactory() {
    }

    /**
     * Return a single instance of the loaded library.
     *
     * @return hnswlib JNA instance.
     *
     * @throws UnsatisfiedLinkError when the library can neither be extracted nor loaded;
     *                              the original failure is attached as the cause.
     */
    static synchronized Hnswlib getInstance() {
        if (instance == null) {
            try {
                checkIfLibraryProvidedNeedsToBeLoadedIntoSO();
                instance = (Hnswlib) Native.loadLibrary(LIBRARY_NAME, Hnswlib.class);
            } catch (UnsatisfiedLinkError | IOException | NullPointerException e) {
                UnsatisfiedLinkError error = new UnsatisfiedLinkError("It's not possible to use the pre-generated dynamic libraries on your system. "
                        + "Please compile it yourself (if not done yet) and set the \"" + JNA_LIBRARY_PATH_PROPERTY + "\" property "
                        + "with correct path to where \"" + getLibraryFileName() + "\" is located.");
                /* UnsatisfiedLinkError has no (message, cause) constructor; keep the root cause for diagnosis. */
                error.initCause(e);
                throw error;
            }
        }
        return instance;
    }

    /**
     * Build the platform-specific file name of the shared library.
     *
     * @return "libhnswlib-jna-ARCH.EXT" where EXT is "so" on Linux, "dll" on Windows
     *         and "dylib" on every other platform.
     */
    private static String getLibraryFileName() {
        String extension;
        if (Platform.isLinux()) {
            extension = "so";
        } else if (Platform.isWindows()) {
            extension = "dll";
        } else {
            extension = "dylib";
        }
        return String.format("libhnswlib-jna-%s.%s", Platform.ARCH, extension);
    }

    /**
     * Copy one bundled library file from the classpath root into the given folder.
     *
     * @param folder   - destination folder;
     * @param fileName - name of the classpath resource to copy.
     *
     * @throws IOException when the resource is not bundled or cannot be copied.
     */
    private static void copyPreGeneratedLibraryFiles(Path folder, String fileName) throws IOException {
        /* try-with-resources: the resource stream was previously never closed. */
        try (InputStream libraryStream = HnswlibFactory.class.getResourceAsStream("/" + fileName)) {
            if (libraryStream == null) {
                /* getResourceAsStream returns null for a missing resource; fail with a clear message
                   instead of a NullPointerException raised inside Files.copy. */
                throw new IOException("Pre-generated library \"" + fileName + "\" was not found on the classpath.");
            }
            /* windows seems to be blocking manipulation of .lib files; we store as .libw for now. */
            Path target = folder.resolve(fileName.replace(".libw", ".lib"));
            Files.copy(libraryStream, target, StandardCopyOption.REPLACE_EXISTING);
            /* Registered after the folder, therefore deleted before it (reverse registration order). */
            target.toFile().deleteOnExit();
        }
    }

    /**
     * When "jna.library.path" is not set, extract the bundled libraries into a fresh
     * temporary folder and point the property at it. Does nothing otherwise.
     *
     * @throws IOException when the temporary folder or the copies cannot be created.
     */
    private static void checkIfLibraryProvidedNeedsToBeLoadedIntoSO() throws IOException {
        String property = System.getProperty(JNA_LIBRARY_PATH_PROPERTY);
        if (property == null) {
            Path libraryFolder = Files.createTempDirectory(LIBRARY_NAME);
            /* Register the folder first: deleteOnExit cannot remove a non-empty directory, so the
               files copied below must be registered later and thus be deleted earlier. */
            libraryFolder.toFile().deleteOnExit();
            copyPreGeneratedLibraryFiles(libraryFolder, getLibraryFileName());
            if (Platform.isWindows()) {
                copyPreGeneratedLibraryFiles(libraryFolder, "libhnswlib-jna-x86-64.exp");
                copyPreGeneratedLibraryFiles(libraryFolder, "libhnswlib-jna-x86-64.libw");
            }
            System.setProperty(JNA_LIBRARY_PATH_PROPERTY, libraryFolder.toString());
        }
    }
}
--------------------------------------------------------------------------------
/hnswlib-jna-legacy/pom.xml:
--------------------------------------------------------------------------------
1 | 2 | 3 | 5 | 4.0.0 6 | 7 | hnswlib-jna-legacy 8 | hnswlib-jna-legacy 9 | 1.4.0 10 | 11 | 12 | 13 | Apache License, Version 2.0 14 | http://www.apache.org/licenses/LICENSE-2.0.txt 15 | 16 | 17 | 18 | 19 | com.stepstone.search.hnswlib.jna 20 | hnswlib-jna-parent 21 | 1.4.0 22 | ..
23 | 24 | 25 | 26 | 1.4.0 27 | UTF-8 28 | 1.8 29 | 1.8 30 | 3.1.2 31 | 4.2.2 32 | 4.11 33 | 34 | 35 | 36 | 37 | net.java.dev.jna 38 | jna 39 | ${jna.version} 40 | 41 | 42 | junit 43 | junit 44 | ${junit.version} 45 | test 46 | 47 | 48 | com.stepstone.search.hnswlib.jna 49 | hnswlib-jna 50 | ${hnswlib.jna.version} 51 | compile 52 | 53 | 54 | net.java.dev.jna 55 | jna 56 | 57 | 58 | true 59 | 60 | 61 | com.stepstone.search.hnswlib.jna 62 | hnswlib-jna 63 | ${hnswlib.jna.version} 64 | test-jar 65 | test 66 | 67 | 68 | 69 | 70 | 71 | 72 | maven-assembly-plugin 73 | 74 | 75 | package 76 | 77 | single 78 | 79 | 80 | 81 | 82 | 83 | jar-with-dependencies 84 | 85 | 86 | 87 | 88 | 89 | org.apache.maven.plugins 90 | maven-dependency-plugin 91 | ${maven.dependency.plugin.version} 92 | 93 | 94 | unpack 95 | prepare-package 96 | 97 | unpack 98 | 99 | 100 | 101 | 102 | com.stepstone.search.hnswlib.jna 103 | hnswlib-jna 104 | ${hnswlib.jna.version} 105 | jar 106 | false 107 | ${project.build.directory}/classes 108 | **/*.class,**/*.dll,**/*.dylib,,**/*.exp,**/*.lib*,**/*.so 109 | **/HnswlibFactory.class,**/*test.class 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | -------------------------------------------------------------------------------- /hnswlib-jna-example/src/main/java/com/stepstone/search/hnswlib/jna/example/App.java: -------------------------------------------------------------------------------- 1 | package com.stepstone.search.hnswlib.jna.example; 2 | 3 | import com.stepstone.search.hnswlib.jna.Index; 4 | import com.stepstone.search.hnswlib.jna.QueryTuple; 5 | import com.stepstone.search.hnswlib.jna.SpaceName; 6 | 7 | import java.io.File; 8 | import java.time.Instant; 9 | import java.util.Arrays; 10 | import java.util.HashMap; 11 | import java.util.Map; 12 | import java.util.Random; 13 | import java.util.concurrent.ExecutorService; 14 | import java.util.concurrent.Executors; 15 | import java.util.concurrent.TimeUnit; 16 | 17 | public class App { 18 
| 19 | private static void exampleOfACosineIndex() { 20 | float[] i1 = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}; 21 | Index.normalize(i1); 22 | float[] i2 = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.95f}; 23 | Index.normalize(i2); 24 | float[] i3 = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.9f}; 25 | Index.normalize(i3); /* For cosine, if the normalization is not explicitly done, it will be done on the native side */ 26 | /* when you call index.addItem(). When explicitly done (in the Java code), use addNormalizedItem() 27 | to avoid double normalization. */ 28 | 29 | Index indexCosine = new Index(SpaceName.COSINE, 7); 30 | indexCosine.initialize(3); 31 | indexCosine.addNormalizedItem(i1, 1_111_111); /* 1_111_111 is an ID */ 32 | indexCosine.addNormalizedItem(i2, 2_222_222); 33 | indexCosine.addNormalizedItem(i3); /* if not defined, an incremental ID will be automatically assigned */ 34 | 35 | float[] input = new float[] {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}; 36 | Index.normalize(input); 37 | 38 | QueryTuple cosineQT = indexCosine.knnNormalizedQuery(input, 3); 39 | 40 | System.out.println("Cosine Index - Query Results: "); 41 | System.out.println(Arrays.toString(cosineQT.getCoefficients())); 42 | System.out.println(Arrays.toString(cosineQT.getIds())); 43 | indexCosine.clear(); 44 | } 45 | 46 | private static void exampleOfAInnerProductIndex() { 47 | float[] i1 = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}; 48 | float[] i2 = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.95f}; 49 | float[] i3 = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.9f}; 50 | 51 | Index indexIP = new Index(SpaceName.IP, 7); 52 | indexIP.initialize(3, 16, 100, 200); /* set maxNumberOfElements, m, efConstruction and randomSeed */ 53 | indexIP.setEf(10); 54 | indexIP.addItem(i1, 1_111_111); /* 1_111_111 is an ID */ 55 | indexIP.addItem(i2, 0xCAFECAFE); 56 | indexIP.addItem(i3); /* if not defined, an incremental ID will be automatically assigned */ 57 | 58 | float[] input = new float[] {1.0f, 1.0f, 1.0f, 1.0f, 
1.0f, 1.0f, 1.0f}; 59 | 60 | QueryTuple ipQT = indexIP.knnQuery(input, 3); 61 | 62 | System.out.println("Inner Product Index - Query Results: "); 63 | System.out.println(Arrays.toString(ipQT.getCoefficients())); 64 | System.out.println(Arrays.toString(ipQT.getIds())); 65 | indexIP.clear(); 66 | } 67 | 68 | private static void exampleOfMultiThreadedIndexBuild() throws InterruptedException { 69 | int numberOfItems = 200_000; 70 | int numberOfThreads = Runtime.getRuntime().availableProcessors(); /* try numberOfThreads = 1 to see the difference ;D */ 71 | 72 | /* this step is just to have some content for indexing (if you have your vectors, you're good to go) */ 73 | Map vectorsMap = new HashMap<>(numberOfItems); 74 | for (int i = 0; i < numberOfItems; i++){ 75 | vectorsMap.put(i , getRandomFloatArray(40)); 76 | } 77 | /* ************************************************************************************************* */ 78 | 79 | Index index = new Index(SpaceName.IP, 7); 80 | index.initialize(numberOfItems); 81 | 82 | long startTime = Instant.now().getEpochSecond(); 83 | ExecutorService executorService = Executors.newFixedThreadPool(numberOfThreads); 84 | for (Map.Entry entry : vectorsMap.entrySet()) { 85 | executorService.submit( () -> index.addItem(entry.getValue(), entry.getKey()) ); 86 | } 87 | executorService.shutdown(); 88 | executorService.awaitTermination(10, TimeUnit.MINUTES); 89 | long endTime = Instant.now().getEpochSecond(); 90 | 91 | System.out.println("Multi Threaded Index Build:"); 92 | System.out.println("Building time for " + index.getLength() + " items took " + (endTime - startTime) + " seconds with " + numberOfThreads + " threads"); 93 | } 94 | 95 | private static float[] getRandomFloatArray(int dimension){ 96 | float[] array = new float[dimension]; 97 | Random random = new Random(); 98 | for (int i = 0; i < dimension; i++){ 99 | array[i] = random.nextFloat(); 100 | } 101 | return array; 102 | } 103 | 104 | /** 105 | * This is an example of how 
manually specify the location of the 106 | * dynamic libraries for hnswlib-jna. This step is required when 107 | * the pre-compiled ones (provided within the jars) are not sufficient 108 | * due to operating system dependencies or version of others libraries. 109 | */ 110 | private static void setupHnswlibJnaDLLManually(){ 111 | File projectFolder = new File("hnswlib-jna-example/lib"); 112 | System.setProperty("jna.library.path", projectFolder.getAbsolutePath()); 113 | } 114 | 115 | public static void main( String[] args ) throws InterruptedException { 116 | //setupHnswlibJnaDLLManually(); 117 | exampleOfACosineIndex(); 118 | exampleOfAInnerProductIndex(); 119 | exampleOfMultiThreadedIndexBuild(); 120 | } 121 | 122 | } 123 | -------------------------------------------------------------------------------- /hnswlib-jna/src/main/java/com/stepstone/search/hnswlib/jna/Hnswlib.java: -------------------------------------------------------------------------------- 1 | package com.stepstone.search.hnswlib.jna; 2 | 3 | import com.sun.jna.Library; 4 | import com.sun.jna.Pointer; 5 | 6 | /** 7 | * Interface that implements JNA (Java Native Access) to Hnswlib 8 | * a fast approximate nearest neighbor search library available 9 | * on (https://github.com/nmslib/hnswlib). This implementation was created 10 | * in order to provide a high performance also in Java (not only in Python or C++). 11 | * 12 | * This implementation relies also in a dynamic library generated 13 | * from the sources available in bindings.cpp. 14 | */ 15 | public interface Hnswlib extends Library { 16 | 17 | /** 18 | * Allocates memory for the index in the native context and 19 | * stores the address in a JNA Pointer variable. 20 | * 21 | * @param spaceName - use: l2, ip or cosine strings only; 22 | * @param dimension - length of the vectors used for indexation. 23 | * 24 | * @return the index reference pointer. 
25 | */ 26 | Pointer createNewIndex(String spaceName, int dimension); 27 | 28 | /** 29 | * Initialize the index with information needed for the indexation. 30 | * 31 | * @param index - JNA pointer reference of the index; 32 | * @param maxNumberOfElements - max number of elements in the index; 33 | * @param m - the value of M; 34 | * @param efConstruction - ef parameter; 35 | * @param randomSeed - a random seed specified by the user. 36 | * 37 | * @return a result code. 38 | */ 39 | int initNewIndex(Pointer index, int maxNumberOfElements, int m, int efConstruction, int randomSeed); 40 | 41 | /** 42 | * Add an item to the index. 43 | * 44 | * @param item - array containing the input to be inserted into the index; 45 | * @param normalized - is the item normalized? if not and if required, it will be performed at the native level; 46 | * @param id - an identifier to be used for this entry; 47 | * @param index - JNA pointer reference of the index. 48 | * 49 | * @return a result code. 50 | */ 51 | int addItemToIndex(float[] item, boolean normalized, int id, Pointer index); 52 | 53 | /** 54 | * Retrieve the number of elements already inserted into the index. 55 | * 56 | * @param index - JNA pointer reference of the index. 57 | * 58 | * @return number of items in the index. 59 | */ 60 | int getIndexLength(Pointer index); 61 | 62 | /** 63 | * Save the content of an index into a file (using native implementation). 64 | * 65 | * @param index - JNA pointer reference of the index; 66 | * @param path - path where the index will be stored. 67 | * 68 | * @return a result code. 69 | */ 70 | int saveIndexToPath(Pointer index, String path); 71 | 72 | /** 73 | * Restore the content of an index saved into a file (using native implementation). 74 | * 75 | * @param index - JNA pointer reference of the index; 76 | * @param maxNumberOfElements - max number of items to be inserted into the index; 77 | * @param path - path of the file the index will be loaded from. 78 | * 79 | * @return a result code.
80 | */ 81 | int loadIndexFromPath(Pointer index, int maxNumberOfElements, String path); 82 | 83 | /** 84 | * This function invokes the knnQuery available in the hnswlib native library. 85 | * 86 | * @param index - JNA pointer reference of the index; 87 | * @param input - input used for the query; 88 | * @param normalized - is the input normalized? if not and if required, it will be performed at the native level; 89 | * @param k - number of results expected; 90 | * @param indices [output] retrieves the indices returned by the query; 91 | * @param coefficients [output] retrieves the coefficients returned by the query. 92 | * 93 | * @return a result code. 94 | */ 95 | int knnQuery(Pointer index, float[] input, boolean normalized, int k, int[] indices, float[] coefficients); 96 | 97 | /** 98 | * Clear the index from the memory. 99 | * 100 | * @param index - JNA pointer reference of the index. 101 | * 102 | * @return a result code. 103 | */ 104 | int clearIndex(Pointer index); 105 | 106 | /** 107 | * Sets the query time accuracy / speed trade-off value. 108 | * 109 | * @param index - JNA pointer reference of the index; 110 | * @param ef value. 111 | * 112 | * @return a result code. 113 | */ 114 | int setEf(Pointer index, int ef); 115 | 116 | /** 117 | * Populate vector with data for given id. 118 | * @param index - JNA pointer reference of the index; 119 | * @param id - identifier of the item to read; 120 | * @param vector [output] array populated with the data stored for the id; 121 | * @param dim - dimension (presumably the length of vector; confirm against bindings.cpp). 122 | * 123 | * @return a result code. 124 | */ 125 | int getData(Pointer index, int id, float[] vector, int dim); 126 | 127 | /** 128 | * Determine whether the index contains data for given id.
129 | * 130 | * @param index index 131 | * @param id id 132 | * 133 | * @return result_code 134 | */ 135 | int hasId(Pointer index, int id); 136 | 137 | /** 138 | * Compute similarity between two vectors 139 | * 140 | * @param index index 141 | * @param vector1 vector1 142 | * @param vector2 vector2 143 | * 144 | * @return similarity score between vectors 145 | */ 146 | float computeSimilarity(Pointer index, float[] vector1, float[] vector2); 147 | 148 | /** 149 | * Retrieves the value of M. 150 | * 151 | * @param index reference. 152 | * 153 | * @return value of M. 154 | */ 155 | int getM(Pointer index); 156 | 157 | /** 158 | * Retrieves the current ef construction value. 159 | * 160 | * @param index reference. 161 | * 162 | * @return efConstruction value. 163 | */ 164 | int getEfConstruction(Pointer index); 165 | 166 | /** 167 | * Retrieves the current ef value. 168 | * 169 | * @param index reference. 170 | * 171 | * @return EF value. 172 | */ 173 | int getEf(Pointer index); 174 | 175 | /** 176 | * Marks an item ID as deleted. 177 | * 178 | * @param index reference; 179 | * @param id label. 180 | * 181 | * @return a result code. 182 | */ 183 | int markDeleted(Pointer index, int id); 184 | 185 | } -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | # __Hnswlib with JNA (Java Native Access)__ 5 | 6 | This project contains a [JNA](https://github.com/java-native-access/jna) (Java Native Access) implementation built on top of the native [Hnswlib](https://github.com/nmslib/hnswlib) (Hierarchical Navigable Small World Graph) which offers a fast approximate nearest neighbor search. It includes some modifications and simplifications in order to provide Hnswlib features with native like performance to applications written in Java. 
Unlike the original Python implementation, multi-thread support is not included in the bindings themselves, but it can easily be implemented on the Java side. `Hnswlib-jna` works together with a __shared library__ which contains the native code. For more information, please check the sections below. 7 | 8 | ## __Dependencies__ 9 | 10 | ### __Pre-Generated Shared Library__ 11 | 12 | The jar file includes some pre-generated libraries for _Windows_, _Debian Linux_ and _MacOS_ (x86-64) which should allow an easy integration and abstract away all complexity related to compilation. An extra library for Debian Linux (aarch64) is also available for tests with AWS Graviton 2. If a pre-generated library cannot be used on your operating system, a runtime exception will be thrown advising you to compile the library manually. 13 | 14 | __On Windows, the [Build Tools for Visual Studio 2019 (C++ build tools)](https://visualstudio.microsoft.com/downloads/#build-tools-for-visual-studio-2019) is required__. 15 | 16 | ## __Using in Your Project__ 17 | 18 | Add the following dependency to your `pom.xml`: 19 | ``` 20 | <dependency> 21 | <groupId>com.stepstone.search.hnswlib.jna</groupId> 22 | <artifactId>hnswlib-jna</artifactId> 23 | <version>1.4.0</version> 24 | </dependency> 25 | ``` 26 | 27 | For more information and implementation details, please check [hnswlib-jna-example](./hnswlib-jna-example/). 28 | 29 | ## __Manual Compilation (Whenever it is advised)__ 30 | 31 | This section includes more information about how to compile the shared libraries on Windows, Linux and Mac for different architectures (e.g., `x86-64`, `aarch64`). __If you were able to run the example project on your PC, this section can be ignored.__ 32 | 33 | ### __Compiling the Shared Library__ 34 | 35 | To generate the shared library required by this project, `bindings.cpp` needs to be compiled using a C++ compiler (e.g., `clang` or `gcc`) with at least C++11 support.
The library can be generated with `clang` via: 36 | ``` 37 | clang++ -O3 -shared bindings.cpp -I hnswlib -o <project_folder>/lib/libhnswlib-jna-x86-64.dylib 38 | ``` 39 | __Note:__ The shared library's name must be: __libhnswlib-jna-ARCH.EXT__ where `ARCH` is the canonical architecture name (e.g., `x86-64` for AMD64, or `aarch64` for ARM64) and `EXT` is `dylib` for MacOS, `dll` for Windows, and `so` for Linux. 40 | 41 | #### Instructions for Windows 42 | 43 | ##### Using Visual Studio Build Tools 44 | 45 | 1. Download and install [LLVM](https://releases.llvm.org/9.0.0/LLVM-9.0.0-win64.exe); 46 | 2. Download and install [Build Tools for Visual Studio 2019 (or higher)](https://visualstudio.microsoft.com/downloads/#build-tools-for-visual-studio-2019); 47 | 3. Compile the bindings using `clang`: 48 | ``` 49 | clang++ -O3 -shared bindings.cpp -I hnswlib -o <project_folder>/lib/libhnswlib-jna-x86-64.dll 50 | ``` 51 | This procedure will generate the 3 necessary files: `libhnswlib-jna-x86-64.dll`, `libhnswlib-jna-x86-64.exp` and `libhnswlib-jna-x86-64.lib`. 52 | 53 | ##### Using MinGW64 54 | 55 | 1. Download and install [LLVM](https://releases.llvm.org/9.0.0/LLVM-9.0.0-win64.exe); 56 | 2. Make sure that LLVM's bin folder is in your PATH; 57 | 3. Download [MinGW-w64 with Headers for Clang](https://sourceforge.net/projects/mingw-w64/files/Toolchains%20targetting%20Win64/Personal%20Builds/mingw-builds/8.1.0/threads-posix/seh/); 58 | 4. Unpack the archive and include MinGW64's bin folder into your PATH as well; 59 | 5. Compile the bindings using `clang`: 60 | ``` 61 | clang++ -O3 -target x86_64-pc-windows-gnu -shared bindings.cpp -I hnswlib -o lib/libhnswlib-jna-x86-64.dll -lpthread 62 | ``` 63 | This procedure will generate `libhnswlib-jna-x86-64.dll`. 64 | 65 | #### Instructions for Linux 66 | 67 | 1. Download and install `clang` (older versions might trigger compilation issues, so it is better to use a recent version); 68 | 2.
Compile the bindings using `clang`: 69 | ``` 70 | clang++ -O3 -fPIC -shared -std=c++11 bindings.cpp -I hnswlib -o /lib/libhnswlib-jna-x86-64.so 71 | ``` 72 | This procedure will generate `libhnswlib-jna-x86-64.so`. 73 | 74 | ### __Reading the Shared Library in Your Project__ 75 | 76 | Once the shared library is available, it is necessary to tell the `JVM` and `JNA` where it is located. This can be done by setting the property `jna.library.path` via JVM parameters or system properties. 77 | 78 | #### Via JVM parameters 79 | ``` 80 | -Djna.library.path=/lib 81 | ``` 82 | #### Programmatically via System Class 83 | ``` 84 | System.setProperty("jna.library.path", "/lib"); 85 | ``` 86 | For more information and implementation details, please check [hnswlib-jna-example](./hnswlib-jna-example/). 87 | 88 | ## License 89 | Copyright 2020 StepStone Services 90 | 91 | Licensed under the Apache License, Version 2.0 (the "License"); 92 | you may not use this file except in compliance with the License. 93 | You may obtain a copy of the License at 94 | 95 |     [http://www.apache.org/licenses/LICENSE-2.0](http://www.apache.org/licenses/LICENSE-2.0) 96 | 97 | Unless required by applicable law or agreed to in writing, software 98 | distributed under the License is distributed on an "AS IS" BASIS, 99 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 100 | See the License for the specific language governing permissions and 101 | limitations under the License. 102 | -------------------------------------------------------------------------------- /pom.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 5 | 4.0.0 6 | 7 | hnswlib 8 | com.stepstone.search.hnswlib.jna 9 | hnswlib-jna-parent 10 | 1.5.0 11 | pom 12 | 13 | This project contains a JNA (Java Native Access) implementation built on top of the native Hnswlib (Hierarchical Navigable Small World Graph) which offers a fast approximate nearest neighbor search. 
It includes some modifications and simplifications in order to provide Hnswlib features with native like performance to applications written in Java. Differently from the original Python implementation, the multi-thread support is not included in the bindings itself but it can be easily implemented on the Java side. 14 | https://github.com/stepstone-tech/hnswlib-jna 15 | 16 | 17 | scm:git:git://github.com/stepstone-tech/hnswlib-jna.git 18 | scm:git:ssh://github.com:stepstone-tech/hnswlib-jna.git 19 | https://github.com/stepstone-tech/hnswlib-jna/tree/master 20 | 21 | 22 | 23 | 24 | Alex Docherty 25 | alexander.docherty@stepstone.com 26 | StepStone 27 | https://www.stepstone.com 28 | 29 | 30 | Casper Davies 31 | casper.davies@stepstone.com 32 | StepStone 33 | https://www.stepstone.com 34 | 35 | 36 | German Hurtado 37 | german.hurtado@stepstone.com 38 | StepStone 39 | https://www.stepstone.com 40 | 41 | 42 | Henri David 43 | henri.david@stepstone.com 44 | StepStone 45 | https://www.stepstone.com 46 | 47 | 48 | Hussama Ismail 49 | hussama.ismail@stepstone.com 50 | StepStone 51 | https://www.stepstone.com 52 | 53 | 54 | Stefan Skoruppa 55 | stefan.skoruppa@stepstone.com 56 | StepStone 57 | https://www.stepstone.com 58 | 59 | 60 | Tomasz Wojtun 61 | tomasz.wojtun@stepstone.com 62 | StepStone 63 | https://www.stepstone.com 64 | 65 | 66 | Vinitha Venugopalsavithri 67 | vinitha.venugopalsavithri@stepstone.com 68 | StepStone 69 | https://www.stepstone.com 70 | 71 | 72 | Zhenhua Mai 73 | zhenhua.mai@stepstone.com 74 | StepStone 75 | https://www.stepstone.com 76 | 77 | 78 | 79 | 80 | 81 | Apache License, Version 2.0 82 | http://www.apache.org/licenses/LICENSE-2.0.txt 83 | 84 | 85 | 86 | 87 | hnswlib-jna 88 | hnswlib-jna-legacy 89 | hnswlib-jna-example 90 | 91 | 92 | 93 | 94 | ossrh 95 | https://oss.sonatype.org/content/repositories/snapshots 96 | 97 | 98 | 99 | 100 | 101 | 102 | org.sonatype.plugins 103 | nexus-staging-maven-plugin 104 | 1.6.7 105 | true 106 | 107 | ossrh 108 
| https://oss.sonatype.org/ 109 | true 110 | 111 | 112 | 113 | org.apache.maven.plugins 114 | maven-source-plugin 115 | 2.2.1 116 | 117 | 118 | attach-sources 119 | 120 | jar-no-fork 121 | 122 | 123 | 124 | 125 | 126 | org.apache.maven.plugins 127 | maven-javadoc-plugin 128 | 2.9.1 129 | 130 | 131 | attach-javadocs 132 | 133 | jar 134 | 135 | 136 | 137 | 138 | 139 | org.apache.maven.plugins 140 | maven-gpg-plugin 141 | 1.5 142 | 143 | 144 | sign-artifacts 145 | verify 146 | 147 | sign 148 | 149 | 150 | 151 | 152 | 153 | 154 | 155 | 156 | -------------------------------------------------------------------------------- /hnswlib/bruteforce.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | #include 5 | #include 6 | 7 | namespace hnswlib { 8 | template 9 | class BruteforceSearch : public AlgorithmInterface { 10 | public: 11 | BruteforceSearch(SpaceInterface *s) { 12 | 13 | } 14 | BruteforceSearch(SpaceInterface *s, const std::string &location) { 15 | loadIndex(location, s); 16 | } 17 | 18 | BruteforceSearch(SpaceInterface *s, size_t maxElements) { 19 | maxelements_ = maxElements; 20 | data_size_ = s->get_data_size(); 21 | fstdistfunc_ = s->get_dist_func(); 22 | dist_func_param_ = s->get_dist_func_param(); 23 | size_per_element_ = data_size_ + sizeof(labeltype); 24 | data_ = (char *) malloc(maxElements * size_per_element_); 25 | if (data_ == nullptr) 26 | std::runtime_error("Not enough memory: BruteforceSearch failed to allocate data"); 27 | cur_element_count = 0; 28 | } 29 | 30 | ~BruteforceSearch() { 31 | free(data_); 32 | } 33 | 34 | char *data_; 35 | size_t maxelements_; 36 | size_t cur_element_count; 37 | size_t size_per_element_; 38 | 39 | size_t data_size_; 40 | DISTFUNC fstdistfunc_; 41 | void *dist_func_param_; 42 | std::mutex index_lock; 43 | 44 | std::unordered_map dict_external_to_internal; 45 | 46 | void addPoint(const void *datapoint, labeltype label) { 47 | 48 | int idx; 49 | 
{ 50 | std::unique_lock lock(index_lock); 51 | 52 | 53 | 54 | auto search=dict_external_to_internal.find(label); 55 | if (search != dict_external_to_internal.end()) { 56 | idx=search->second; 57 | } 58 | else{ 59 | if (cur_element_count >= maxelements_) { 60 | throw std::runtime_error("The number of elements exceeds the specified limit\n"); 61 | } 62 | idx=cur_element_count; 63 | dict_external_to_internal[label] = idx; 64 | cur_element_count++; 65 | } 66 | } 67 | memcpy(data_ + size_per_element_ * idx + data_size_, &label, sizeof(labeltype)); 68 | memcpy(data_ + size_per_element_ * idx, datapoint, data_size_); 69 | 70 | 71 | 72 | 73 | }; 74 | 75 | void removePoint(labeltype cur_external) { 76 | size_t cur_c=dict_external_to_internal[cur_external]; 77 | 78 | dict_external_to_internal.erase(cur_external); 79 | 80 | labeltype label=*((labeltype*)(data_ + size_per_element_ * (cur_element_count-1) + data_size_)); 81 | dict_external_to_internal[label]=cur_c; 82 | memcpy(data_ + size_per_element_ * cur_c, 83 | data_ + size_per_element_ * (cur_element_count-1), 84 | data_size_+sizeof(labeltype)); 85 | cur_element_count--; 86 | 87 | } 88 | 89 | 90 | std::priority_queue> 91 | searchKnn(const void *query_data, size_t k) const { 92 | std::priority_queue> topResults; 93 | if (cur_element_count == 0) return topResults; 94 | for (int i = 0; i < k; i++) { 95 | dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_); 96 | topResults.push(std::pair(dist, *((labeltype *) (data_ + size_per_element_ * i + 97 | data_size_)))); 98 | } 99 | dist_t lastdist = topResults.top().first; 100 | for (int i = k; i < cur_element_count; i++) { 101 | dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_); 102 | if (dist <= lastdist) { 103 | topResults.push(std::pair(dist, *((labeltype *) (data_ + size_per_element_ * i + 104 | data_size_)))); 105 | if (topResults.size() > k) 106 | topResults.pop(); 107 | lastdist = 
topResults.top().first; 108 | } 109 | 110 | } 111 | return topResults; 112 | }; 113 | 114 | template 115 | std::vector> 116 | searchKnn(const void* query_data, size_t k, Comp comp) { 117 | std::vector> result; 118 | if (cur_element_count == 0) return result; 119 | 120 | auto ret = searchKnn(query_data, k); 121 | 122 | while (!ret.empty()) { 123 | result.push_back(ret.top()); 124 | ret.pop(); 125 | } 126 | 127 | std::sort(result.begin(), result.end(), comp); 128 | 129 | return result; 130 | } 131 | 132 | void saveIndex(const std::string &location) { 133 | std::ofstream output(location, std::ios::binary); 134 | std::streampos position; 135 | 136 | writeBinaryPOD(output, maxelements_); 137 | writeBinaryPOD(output, size_per_element_); 138 | writeBinaryPOD(output, cur_element_count); 139 | 140 | output.write(data_, maxelements_ * size_per_element_); 141 | 142 | output.close(); 143 | } 144 | 145 | void loadIndex(const std::string &location, SpaceInterface *s) { 146 | 147 | 148 | std::ifstream input(location, std::ios::binary); 149 | std::streampos position; 150 | 151 | readBinaryPOD(input, maxelements_); 152 | readBinaryPOD(input, size_per_element_); 153 | readBinaryPOD(input, cur_element_count); 154 | 155 | data_size_ = s->get_data_size(); 156 | fstdistfunc_ = s->get_dist_func(); 157 | dist_func_param_ = s->get_dist_func_param(); 158 | size_per_element_ = data_size_ + sizeof(labeltype); 159 | data_ = (char *) malloc(maxelements_ * size_per_element_); 160 | if (data_ == nullptr) 161 | std::runtime_error("Not enough memory: loadIndex failed to allocate data"); 162 | 163 | input.read(data_, maxelements_ * size_per_element_); 164 | 165 | input.close(); 166 | 167 | } 168 | 169 | }; 170 | } 171 | -------------------------------------------------------------------------------- /hnswlib-jna/src/main/java/com/stepstone/search/hnswlib/jna/ConcurrentIndex.java: -------------------------------------------------------------------------------- 1 | package 
com.stepstone.search.hnswlib.jna; 2 | 3 | import java.nio.file.Path; 4 | import java.util.Optional; 5 | import java.util.concurrent.locks.Lock; 6 | import java.util.concurrent.locks.ReadWriteLock; 7 | import java.util.concurrent.locks.ReentrantReadWriteLock; 8 | 9 | /** 10 | * This class offers a thread-safe alternative for a small-world Index. 11 | * It allows concurrent item insertion and querying which is not supported 12 | * by the native Hnswlib implementation. 13 | * 14 | * Note: this class relies on a ReadWriteLock with fairness enabled. So, 15 | * when multi-thread insertions are serialized. To take advantage of parallel 16 | * insertion, please create a Index instance and then retrieve a ConcurrentIndex 17 | * one via Index.synchronizedIndex() method call. 18 | */ 19 | public class ConcurrentIndex extends Index { 20 | 21 | private ReadWriteLock readWriteLock = new ReentrantReadWriteLock(true); 22 | private Lock readLock = readWriteLock.readLock(); 23 | private Lock writeLock = readWriteLock.writeLock(); 24 | 25 | public ConcurrentIndex(SpaceName spaceName, int dimensions) { 26 | super(spaceName, dimensions); 27 | } 28 | 29 | /** 30 | * Thread-safe method which adds an item without ID to the index. 31 | * Internally, an incremental ID (starting from 1) will be given to this item. 32 | * 33 | * @param item - float array with the length expected by the index (dimension). 34 | */ 35 | @Override 36 | public void addItem(float[] item) { 37 | this.writeLock.lock(); 38 | try { 39 | super.addItem(item, NO_ID); 40 | } finally { 41 | this.writeLock.unlock(); 42 | } 43 | } 44 | 45 | /** 46 | * Thread-safe method which adds an item with ID to the index. 47 | * It won't apply any extra normalization unless it is required 48 | * by the Vector Space (e.g., COSINE). 49 | * 50 | * @param item - float array with the length expected by the index (dimension); 51 | * @param id - an identifier used by the native library. 
52 | */ 53 | @Override 54 | public void addItem(float[] item, int id) { 55 | this.writeLock.lock(); 56 | try { 57 | super.addItem(item, id); 58 | } finally { 59 | this.writeLock.unlock(); 60 | } 61 | } 62 | 63 | /** 64 | * Thread-safe method which adds a normalized item without ID to the index. 65 | * Internally, an incremental ID (starting from 0) will be given to this item. 66 | * 67 | * @param item - float array with the length expected by the index (dimension). 68 | */ 69 | @Override 70 | public void addNormalizedItem(float[] item) { 71 | this.writeLock.lock(); 72 | try { 73 | super.addNormalizedItem(item, Index.NO_ID); 74 | } finally { 75 | this.writeLock.unlock(); 76 | } 77 | } 78 | 79 | /** 80 | * Thread-safe method which adds a normalized item with ID to the index. 81 | * 82 | * @param item - float array with the length expected by the index (dimension); 83 | * @param id - an identifier used by the native library. 84 | */ 85 | @Override 86 | public void addNormalizedItem(float[] item, int id) { 87 | this.writeLock.lock(); 88 | try { 89 | super.addNormalizedItem(item, id); 90 | } finally { 91 | this.writeLock.unlock(); 92 | } 93 | } 94 | 95 | /** 96 | * Thread-safe method which returns the number of elements 97 | * already inserted in the index. 98 | * 99 | * @return elements count. 100 | */ 101 | @Override 102 | public int getLength(){ 103 | this.readLock.lock(); 104 | try { 105 | return super.getLength(); 106 | } finally { 107 | this.readLock.unlock(); 108 | } 109 | } 110 | 111 | /** 112 | * Thread-safe method which performs a knn query in the index instance. 113 | * In case the vector space requires the input to be normalized, it will 114 | * normalize at the native level. 115 | * 116 | * @param input - float array; 117 | * @param k - number of results expected. 118 | * 119 | * @return a query tuple instance that contain the indices and coefficients. 
120 | */ 121 | @Override 122 | public QueryTuple knnQuery(float[] input, int k) { 123 | this.readLock.lock(); 124 | QueryTuple queryTuple; 125 | try { 126 | queryTuple = super.knnQuery(input, k); 127 | } finally { 128 | this.readLock.unlock(); 129 | } 130 | return queryTuple; 131 | } 132 | 133 | /** 134 | * Thread-safe method which performs a knn query in the index instance 135 | * using an normalized input. It will not normalize the vector again. 136 | * 137 | * @param input - a normalized float array; 138 | * @param k - number of results expected. 139 | * 140 | * @return a query tuple instance that contain the indices and coefficients. 141 | */ 142 | @Override 143 | public QueryTuple knnNormalizedQuery(float[] input, int k) { 144 | this.readLock.lock(); 145 | QueryTuple queryTuple; 146 | try { 147 | queryTuple = super.knnNormalizedQuery(input, k); 148 | } finally { 149 | this.readLock.unlock(); 150 | } 151 | return queryTuple; 152 | } 153 | 154 | /** 155 | * Thread-safe method which stores the content of the index into a file. 156 | * This method relies on the native implementation. 157 | * 158 | * @param path - destination path. 159 | */ 160 | @Override 161 | public void save(Path path) { 162 | this.readLock.lock(); 163 | try { 164 | super.save(path); 165 | } finally { 166 | this.readLock.unlock(); 167 | } 168 | } 169 | 170 | /** 171 | * Thread-safe method which loads the content stored in a file path onto the index. 172 | * 173 | * Note: if the index was previously initialized, the old 174 | * content will be erased. 175 | * 176 | * @param path - path to the index file; 177 | * @param maxNumberOfElements - max number of elements in the index. 
178 | */ 179 | @Override 180 | public void load(Path path, int maxNumberOfElements) { 181 | this.writeLock.lock(); 182 | try { 183 | super.load(path, maxNumberOfElements); 184 | } finally { 185 | this.writeLock.unlock(); 186 | } 187 | } 188 | 189 | /** 190 | * Thread-safe method which frees the memory allocated for this index in the native context. 191 | * 192 | * NOTE: Once the index is cleared, it cannot be initialized or used again. 193 | */ 194 | @Override 195 | public void clear() { 196 | this.writeLock.lock(); 197 | try { 198 | super.clear(); 199 | } finally { 200 | this.writeLock.unlock(); 201 | } 202 | } 203 | 204 | /** 205 | * Thread-safe method which sets the query time accuracy / speed trade-off value. 206 | * 207 | * @param ef value. 208 | */ 209 | @Override 210 | public void setEf(int ef) { 211 | this.writeLock.lock(); 212 | try { 213 | super.setEf(ef); 214 | } finally { 215 | this.writeLock.unlock(); 216 | } 217 | } 218 | 219 | /** 220 | * Thread-safe method that checks whether there is an item with the specified identifier in the index. 221 | * 222 | * @param id - identifier. 223 | * 224 | * @return true or false. 225 | */ 226 | public boolean hasId(int id) { 227 | this.readLock.lock(); 228 | boolean hasId; 229 | try { 230 | hasId = super.hasId(id); 231 | } finally { 232 | this.readLock.unlock(); 233 | } 234 | return hasId; 235 | } 236 | 237 | /** 238 | * Thread-safe method that marks an ID as deleted. 239 | * 240 | * @param id identifier. 241 | */ 242 | public void markDeleted(int id) { 243 | this.writeLock.lock(); 244 | try { 245 | super.markDeleted(id); 246 | } finally { 247 | this.writeLock.unlock(); 248 | } 249 | } 250 | 251 | /** 252 | * Thread-safe method that gets the data from a specific identifier in the index. 253 | * 254 | * @param id - identifier. 
255 | * 256 | * @return an optional containing or not the 257 | */ 258 | public Optional getData(int id) { 259 | this.readLock.lock(); 260 | Optional data; 261 | try { 262 | data = super.getData(id); 263 | } finally { 264 | this.readLock.unlock(); 265 | } 266 | return data; 267 | } 268 | 269 | } 270 | -------------------------------------------------------------------------------- /hnswlib/space_l2.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "hnswlib.h" 3 | 4 | namespace hnswlib { 5 | 6 | static float 7 | L2Sqr(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { 8 | float *pVect1 = (float *) pVect1v; 9 | float *pVect2 = (float *) pVect2v; 10 | size_t qty = *((size_t *) qty_ptr); 11 | 12 | float res = 0; 13 | for (size_t i = 0; i < qty; i++) { 14 | float t = *pVect1 - *pVect2; 15 | pVect1++; 16 | pVect2++; 17 | res += t * t; 18 | } 19 | return (res); 20 | } 21 | 22 | #if defined(USE_AVX) 23 | 24 | // Favor using AVX if available. 
25 | static float 26 | L2SqrSIMD16Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { 27 | float *pVect1 = (float *) pVect1v; 28 | float *pVect2 = (float *) pVect2v; 29 | size_t qty = *((size_t *) qty_ptr); 30 | float PORTABLE_ALIGN32 TmpRes[8]; 31 | size_t qty16 = qty >> 4; 32 | 33 | const float *pEnd1 = pVect1 + (qty16 << 4); 34 | 35 | __m256 diff, v1, v2; 36 | __m256 sum = _mm256_set1_ps(0); 37 | 38 | while (pVect1 < pEnd1) { 39 | v1 = _mm256_loadu_ps(pVect1); 40 | pVect1 += 8; 41 | v2 = _mm256_loadu_ps(pVect2); 42 | pVect2 += 8; 43 | diff = _mm256_sub_ps(v1, v2); 44 | sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff)); 45 | 46 | v1 = _mm256_loadu_ps(pVect1); 47 | pVect1 += 8; 48 | v2 = _mm256_loadu_ps(pVect2); 49 | pVect2 += 8; 50 | diff = _mm256_sub_ps(v1, v2); 51 | sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff)); 52 | } 53 | 54 | _mm256_store_ps(TmpRes, sum); 55 | return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7]; 56 | } 57 | 58 | #elif defined(USE_SSE) 59 | 60 | static float 61 | L2SqrSIMD16Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { 62 | float *pVect1 = (float *) pVect1v; 63 | float *pVect2 = (float *) pVect2v; 64 | size_t qty = *((size_t *) qty_ptr); 65 | float PORTABLE_ALIGN32 TmpRes[8]; 66 | size_t qty16 = qty >> 4; 67 | 68 | const float *pEnd1 = pVect1 + (qty16 << 4); 69 | 70 | __m128 diff, v1, v2; 71 | __m128 sum = _mm_set1_ps(0); 72 | 73 | while (pVect1 < pEnd1) { 74 | //_mm_prefetch((char*)(pVect2 + 16), _MM_HINT_T0); 75 | v1 = _mm_loadu_ps(pVect1); 76 | pVect1 += 4; 77 | v2 = _mm_loadu_ps(pVect2); 78 | pVect2 += 4; 79 | diff = _mm_sub_ps(v1, v2); 80 | sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); 81 | 82 | v1 = _mm_loadu_ps(pVect1); 83 | pVect1 += 4; 84 | v2 = _mm_loadu_ps(pVect2); 85 | pVect2 += 4; 86 | diff = _mm_sub_ps(v1, v2); 87 | sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); 88 | 89 | v1 = _mm_loadu_ps(pVect1); 90 | pVect1 += 4; 91 | v2 = 
_mm_loadu_ps(pVect2); 92 | pVect2 += 4; 93 | diff = _mm_sub_ps(v1, v2); 94 | sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); 95 | 96 | v1 = _mm_loadu_ps(pVect1); 97 | pVect1 += 4; 98 | v2 = _mm_loadu_ps(pVect2); 99 | pVect2 += 4; 100 | diff = _mm_sub_ps(v1, v2); 101 | sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); 102 | } 103 | 104 | _mm_store_ps(TmpRes, sum); 105 | return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3]; 106 | } 107 | #endif 108 | 109 | #if defined(USE_SSE) || defined(USE_AVX) 110 | static float 111 | L2SqrSIMD16ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { 112 | size_t qty = *((size_t *) qty_ptr); 113 | size_t qty16 = qty >> 4 << 4; 114 | float res = L2SqrSIMD16Ext(pVect1v, pVect2v, &qty16); 115 | float *pVect1 = (float *) pVect1v + qty16; 116 | float *pVect2 = (float *) pVect2v + qty16; 117 | 118 | size_t qty_left = qty - qty16; 119 | float res_tail = L2Sqr(pVect1, pVect2, &qty_left); 120 | return (res + res_tail); 121 | } 122 | #endif 123 | 124 | 125 | #ifdef USE_SSE 126 | static float 127 | L2SqrSIMD4Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { 128 | float PORTABLE_ALIGN32 TmpRes[8]; 129 | float *pVect1 = (float *) pVect1v; 130 | float *pVect2 = (float *) pVect2v; 131 | size_t qty = *((size_t *) qty_ptr); 132 | 133 | 134 | size_t qty4 = qty >> 2; 135 | 136 | const float *pEnd1 = pVect1 + (qty4 << 2); 137 | 138 | __m128 diff, v1, v2; 139 | __m128 sum = _mm_set1_ps(0); 140 | 141 | while (pVect1 < pEnd1) { 142 | v1 = _mm_loadu_ps(pVect1); 143 | pVect1 += 4; 144 | v2 = _mm_loadu_ps(pVect2); 145 | pVect2 += 4; 146 | diff = _mm_sub_ps(v1, v2); 147 | sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); 148 | } 149 | _mm_store_ps(TmpRes, sum); 150 | return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3]; 151 | } 152 | 153 | static float 154 | L2SqrSIMD4ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { 155 | size_t qty = *((size_t *) qty_ptr); 156 | size_t qty4 = qty >> 2 << 2; 157 
| 158 | float res = L2SqrSIMD4Ext(pVect1v, pVect2v, &qty4); 159 | size_t qty_left = qty - qty4; 160 | 161 | float *pVect1 = (float *) pVect1v + qty4; 162 | float *pVect2 = (float *) pVect2v + qty4; 163 | float res_tail = L2Sqr(pVect1, pVect2, &qty_left); 164 | 165 | return (res + res_tail); 166 | } 167 | #endif 168 | 169 | class L2Space : public SpaceInterface { 170 | 171 | DISTFUNC fstdistfunc_; 172 | size_t data_size_; 173 | size_t dim_; 174 | public: 175 | L2Space(size_t dim) { 176 | fstdistfunc_ = L2Sqr; 177 | #if defined(USE_SSE) || defined(USE_AVX) 178 | if (dim % 16 == 0) 179 | fstdistfunc_ = L2SqrSIMD16Ext; 180 | else if (dim % 4 == 0) 181 | fstdistfunc_ = L2SqrSIMD4Ext; 182 | else if (dim > 16) 183 | fstdistfunc_ = L2SqrSIMD16ExtResiduals; 184 | else if (dim > 4) 185 | fstdistfunc_ = L2SqrSIMD4ExtResiduals; 186 | #endif 187 | dim_ = dim; 188 | data_size_ = dim * sizeof(float); 189 | } 190 | 191 | size_t get_data_size() { 192 | return data_size_; 193 | } 194 | 195 | DISTFUNC get_dist_func() { 196 | return fstdistfunc_; 197 | } 198 | 199 | void *get_dist_func_param() { 200 | return &dim_; 201 | } 202 | 203 | ~L2Space() {} 204 | }; 205 | 206 | static int 207 | L2SqrI(const void *__restrict pVect1, const void *__restrict pVect2, const void *__restrict qty_ptr) { 208 | 209 | size_t qty = *((size_t *) qty_ptr); 210 | int res = 0; 211 | unsigned char *a = (unsigned char *) pVect1; 212 | unsigned char *b = (unsigned char *) pVect2; 213 | 214 | qty = qty >> 2; 215 | for (size_t i = 0; i < qty; i++) { 216 | 217 | res += ((*a) - (*b)) * ((*a) - (*b)); 218 | a++; 219 | b++; 220 | res += ((*a) - (*b)) * ((*a) - (*b)); 221 | a++; 222 | b++; 223 | res += ((*a) - (*b)) * ((*a) - (*b)); 224 | a++; 225 | b++; 226 | res += ((*a) - (*b)) * ((*a) - (*b)); 227 | a++; 228 | b++; 229 | 230 | 231 | } 232 | 233 | return (res); 234 | 235 | } 236 | 237 | class L2SpaceI : public SpaceInterface { 238 | 239 | DISTFUNC fstdistfunc_; 240 | size_t data_size_; 241 | size_t dim_; 242 | 
public: 243 | L2SpaceI(size_t dim) { 244 | fstdistfunc_ = L2SqrI; 245 | dim_ = dim; 246 | data_size_ = dim * sizeof(unsigned char); 247 | } 248 | 249 | size_t get_data_size() { 250 | return data_size_; 251 | } 252 | 253 | DISTFUNC get_dist_func() { 254 | return fstdistfunc_; 255 | } 256 | 257 | void *get_dist_func_param() { 258 | return &dim_; 259 | } 260 | 261 | ~L2SpaceI() {} 262 | }; 263 | 264 | 265 | } -------------------------------------------------------------------------------- /hnswlib/space_ip.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include "hnswlib.h" 3 | 4 | namespace hnswlib { 5 | 6 | static float 7 | InnerProduct(const void *pVect1, const void *pVect2, const void *qty_ptr) { 8 | size_t qty = *((size_t *) qty_ptr); 9 | float res = 0; 10 | for (unsigned i = 0; i < qty; i++) { 11 | res += ((float *) pVect1)[i] * ((float *) pVect2)[i]; 12 | } 13 | return (1.0f - res); 14 | 15 | } 16 | 17 | #if defined(USE_AVX) 18 | 19 | // Favor using AVX if available. 
20 | static float 21 | InnerProductSIMD4Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { 22 | float PORTABLE_ALIGN32 TmpRes[8]; 23 | float *pVect1 = (float *) pVect1v; 24 | float *pVect2 = (float *) pVect2v; 25 | size_t qty = *((size_t *) qty_ptr); 26 | 27 | size_t qty16 = qty / 16; 28 | size_t qty4 = qty / 4; 29 | 30 | const float *pEnd1 = pVect1 + 16 * qty16; 31 | const float *pEnd2 = pVect1 + 4 * qty4; 32 | 33 | __m256 sum256 = _mm256_set1_ps(0); 34 | 35 | while (pVect1 < pEnd1) { 36 | //_mm_prefetch((char*)(pVect2 + 16), _MM_HINT_T0); 37 | 38 | __m256 v1 = _mm256_loadu_ps(pVect1); 39 | pVect1 += 8; 40 | __m256 v2 = _mm256_loadu_ps(pVect2); 41 | pVect2 += 8; 42 | sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2)); 43 | 44 | v1 = _mm256_loadu_ps(pVect1); 45 | pVect1 += 8; 46 | v2 = _mm256_loadu_ps(pVect2); 47 | pVect2 += 8; 48 | sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2)); 49 | } 50 | 51 | __m128 v1, v2; 52 | __m128 sum_prod = _mm_add_ps(_mm256_extractf128_ps(sum256, 0), _mm256_extractf128_ps(sum256, 1)); 53 | 54 | while (pVect1 < pEnd2) { 55 | v1 = _mm_loadu_ps(pVect1); 56 | pVect1 += 4; 57 | v2 = _mm_loadu_ps(pVect2); 58 | pVect2 += 4; 59 | sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); 60 | } 61 | 62 | _mm_store_ps(TmpRes, sum_prod); 63 | float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3];; 64 | return 1.0f - sum; 65 | } 66 | 67 | #elif defined(USE_SSE) 68 | 69 | static float 70 | InnerProductSIMD4Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { 71 | float PORTABLE_ALIGN32 TmpRes[8]; 72 | float *pVect1 = (float *) pVect1v; 73 | float *pVect2 = (float *) pVect2v; 74 | size_t qty = *((size_t *) qty_ptr); 75 | 76 | size_t qty16 = qty / 16; 77 | size_t qty4 = qty / 4; 78 | 79 | const float *pEnd1 = pVect1 + 16 * qty16; 80 | const float *pEnd2 = pVect1 + 4 * qty4; 81 | 82 | __m128 v1, v2; 83 | __m128 sum_prod = _mm_set1_ps(0); 84 | 85 | while (pVect1 < pEnd1) { 86 | v1 = _mm_loadu_ps(pVect1); 87 | 
pVect1 += 4; 88 | v2 = _mm_loadu_ps(pVect2); 89 | pVect2 += 4; 90 | sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); 91 | 92 | v1 = _mm_loadu_ps(pVect1); 93 | pVect1 += 4; 94 | v2 = _mm_loadu_ps(pVect2); 95 | pVect2 += 4; 96 | sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); 97 | 98 | v1 = _mm_loadu_ps(pVect1); 99 | pVect1 += 4; 100 | v2 = _mm_loadu_ps(pVect2); 101 | pVect2 += 4; 102 | sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); 103 | 104 | v1 = _mm_loadu_ps(pVect1); 105 | pVect1 += 4; 106 | v2 = _mm_loadu_ps(pVect2); 107 | pVect2 += 4; 108 | sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); 109 | } 110 | 111 | while (pVect1 < pEnd2) { 112 | v1 = _mm_loadu_ps(pVect1); 113 | pVect1 += 4; 114 | v2 = _mm_loadu_ps(pVect2); 115 | pVect2 += 4; 116 | sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); 117 | } 118 | 119 | _mm_store_ps(TmpRes, sum_prod); 120 | float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3]; 121 | 122 | return 1.0f - sum; 123 | } 124 | 125 | #endif 126 | 127 | #if defined(USE_AVX) 128 | 129 | static float 130 | InnerProductSIMD16Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { 131 | float PORTABLE_ALIGN32 TmpRes[8]; 132 | float *pVect1 = (float *) pVect1v; 133 | float *pVect2 = (float *) pVect2v; 134 | size_t qty = *((size_t *) qty_ptr); 135 | 136 | size_t qty16 = qty / 16; 137 | 138 | 139 | const float *pEnd1 = pVect1 + 16 * qty16; 140 | 141 | __m256 sum256 = _mm256_set1_ps(0); 142 | 143 | while (pVect1 < pEnd1) { 144 | //_mm_prefetch((char*)(pVect2 + 16), _MM_HINT_T0); 145 | 146 | __m256 v1 = _mm256_loadu_ps(pVect1); 147 | pVect1 += 8; 148 | __m256 v2 = _mm256_loadu_ps(pVect2); 149 | pVect2 += 8; 150 | sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2)); 151 | 152 | v1 = _mm256_loadu_ps(pVect1); 153 | pVect1 += 8; 154 | v2 = _mm256_loadu_ps(pVect2); 155 | pVect2 += 8; 156 | sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2)); 157 | } 158 | 159 | _mm256_store_ps(TmpRes, sum256); 160 | float sum = 
TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7]; 161 | 162 | return 1.0f - sum; 163 | } 164 | 165 | #elif defined(USE_SSE) 166 | 167 | static float 168 | InnerProductSIMD16Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { 169 | float PORTABLE_ALIGN32 TmpRes[8]; 170 | float *pVect1 = (float *) pVect1v; 171 | float *pVect2 = (float *) pVect2v; 172 | size_t qty = *((size_t *) qty_ptr); 173 | 174 | size_t qty16 = qty / 16; 175 | 176 | const float *pEnd1 = pVect1 + 16 * qty16; 177 | 178 | __m128 v1, v2; 179 | __m128 sum_prod = _mm_set1_ps(0); 180 | 181 | while (pVect1 < pEnd1) { 182 | v1 = _mm_loadu_ps(pVect1); 183 | pVect1 += 4; 184 | v2 = _mm_loadu_ps(pVect2); 185 | pVect2 += 4; 186 | sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); 187 | 188 | v1 = _mm_loadu_ps(pVect1); 189 | pVect1 += 4; 190 | v2 = _mm_loadu_ps(pVect2); 191 | pVect2 += 4; 192 | sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); 193 | 194 | v1 = _mm_loadu_ps(pVect1); 195 | pVect1 += 4; 196 | v2 = _mm_loadu_ps(pVect2); 197 | pVect2 += 4; 198 | sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); 199 | 200 | v1 = _mm_loadu_ps(pVect1); 201 | pVect1 += 4; 202 | v2 = _mm_loadu_ps(pVect2); 203 | pVect2 += 4; 204 | sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); 205 | } 206 | _mm_store_ps(TmpRes, sum_prod); 207 | float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3]; 208 | 209 | return 1.0f - sum; 210 | } 211 | 212 | #endif 213 | 214 | #if defined(USE_SSE) || defined(USE_AVX) 215 | static float 216 | InnerProductSIMD16ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { 217 | size_t qty = *((size_t *) qty_ptr); 218 | size_t qty16 = qty >> 4 << 4; 219 | float res = InnerProductSIMD16Ext(pVect1v, pVect2v, &qty16); 220 | float *pVect1 = (float *) pVect1v + qty16; 221 | float *pVect2 = (float *) pVect2v + qty16; 222 | 223 | size_t qty_left = qty - qty16; 224 | float res_tail = InnerProduct(pVect1, pVect2, 
&qty_left); 225 | return res + res_tail - 1.0f; 226 | } 227 | 228 | static float 229 | InnerProductSIMD4ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { 230 | size_t qty = *((size_t *) qty_ptr); 231 | size_t qty4 = qty >> 2 << 2; 232 | 233 | float res = InnerProductSIMD4Ext(pVect1v, pVect2v, &qty4); 234 | size_t qty_left = qty - qty4; 235 | 236 | float *pVect1 = (float *) pVect1v + qty4; 237 | float *pVect2 = (float *) pVect2v + qty4; 238 | float res_tail = InnerProduct(pVect1, pVect2, &qty_left); 239 | 240 | return res + res_tail - 1.0f; 241 | } 242 | #endif 243 | 244 | class InnerProductSpace : public SpaceInterface { 245 | 246 | DISTFUNC fstdistfunc_; 247 | size_t data_size_; 248 | size_t dim_; 249 | public: 250 | InnerProductSpace(size_t dim) { 251 | fstdistfunc_ = InnerProduct; 252 | #if defined(USE_AVX) || defined(USE_SSE) 253 | if (dim % 16 == 0) 254 | fstdistfunc_ = InnerProductSIMD16Ext; 255 | else if (dim % 4 == 0) 256 | fstdistfunc_ = InnerProductSIMD4Ext; 257 | else if (dim > 16) 258 | fstdistfunc_ = InnerProductSIMD16ExtResiduals; 259 | else if (dim > 4) 260 | fstdistfunc_ = InnerProductSIMD4ExtResiduals; 261 | #endif 262 | dim_ = dim; 263 | data_size_ = dim * sizeof(float); 264 | } 265 | 266 | size_t get_data_size() { 267 | return data_size_; 268 | } 269 | 270 | DISTFUNC get_dist_func() { 271 | return fstdistfunc_; 272 | } 273 | 274 | void *get_dist_func_param() { 275 | return &dim_; 276 | } 277 | 278 | ~InnerProductSpace() {} 279 | }; 280 | 281 | 282 | } 283 | -------------------------------------------------------------------------------- /bindings.cpp: -------------------------------------------------------------------------------- 1 | /** 2 | * Simplified C binding for hnswlib able to work with JNA (Java Native Access) 3 | * in order to get a similar native performance on Java. 
This code is based on the python 4 | * binding available at: https://github.com/nmslib/hnswlib/blob/master/python_bindings/bindings.cpp 5 | * 6 | * Some modifications and simplifications have been done on the C side. 7 | * The multithread support can be used and handled on the Java side. 8 | * 9 | * This work is still in progress. Please feel free to contribute and give ideas. 10 | */ 11 | 12 | #include 13 | #include 14 | #include 15 | #include "hnswlib/hnswlib.h" 16 | 17 | #if _WIN32 18 | #define DLLEXPORT __declspec(dllexport) 19 | #else 20 | #define DLLEXPORT 21 | #endif 22 | 23 | #define EXTERN_C extern "C" 24 | 25 | #define RESULT_SUCCESSFUL 0 26 | #define RESULT_EXCEPTION_THROWN 1 27 | #define RESULT_INDEX_ALREADY_INITIALIZED 2 28 | #define RESULT_QUERY_CANNOT_RETURN 3 29 | #define RESULT_ITEM_CANNOT_BE_INSERTED_INTO_THE_VECTOR_SPACE 4 30 | #define RESULT_ONCE_INDEX_IS_CLEARED_IT_CANNOT_BE_REUSED 5 31 | #define RESULT_GET_DATA_FAILED 6 32 | #define RESULT_ID_NOT_IN_INDEX 7 33 | #define RESULT_INDEX_NOT_INITIALIZED 8 34 | 35 | #define TRY_CATCH_NO_INITIALIZE_CHECK_AND_RETURN_INT_BLOCK(block) if (index_cleared) return RESULT_ONCE_INDEX_IS_CLEARED_IT_CANNOT_BE_REUSED; int result_code = RESULT_SUCCESSFUL; try { block } catch (...) 
{ result_code = RESULT_EXCEPTION_THROWN; }; return result_code; 36 | #define TRY_CATCH_RETURN_INT_BLOCK(block) if (!index_initialized) return RESULT_INDEX_NOT_INITIALIZED; TRY_CATCH_NO_INITIALIZE_CHECK_AND_RETURN_INT_BLOCK(block) 37 | 38 | template 39 | class Index { 40 | public: 41 | Index(const std::string &space_name, const int dim) : 42 | space_name(space_name), dim(dim) { 43 | data_must_be_normalized = false; 44 | if(space_name=="L2") { 45 | l2space = new hnswlib::L2Space(dim); 46 | } else if(space_name=="IP") { 47 | l2space = new hnswlib::InnerProductSpace(dim); 48 | } else if(space_name=="COSINE") { 49 | l2space = new hnswlib::InnerProductSpace(dim); 50 | data_must_be_normalized = true; 51 | } 52 | appr_alg = NULL; 53 | index_initialized = false; 54 | index_cleared = false; 55 | } 56 | 57 | int init_new_index(const size_t maxElements, const size_t M, const size_t efConstruction, const size_t random_seed) { 58 | TRY_CATCH_NO_INITIALIZE_CHECK_AND_RETURN_INT_BLOCK({ 59 | if (appr_alg) { 60 | return RESULT_INDEX_ALREADY_INITIALIZED; 61 | } 62 | appr_alg = new hnswlib::HierarchicalNSW(l2space, maxElements, M, efConstruction, random_seed); 63 | index_initialized = true; 64 | }); 65 | } 66 | 67 | int set_ef(size_t ef) { 68 | TRY_CATCH_RETURN_INT_BLOCK({ 69 | appr_alg->ef_ = ef; 70 | }); 71 | } 72 | 73 | int get_ef() { 74 | return appr_alg->ef_; 75 | } 76 | 77 | int get_ef_construction() { 78 | return appr_alg->ef_construction_; 79 | } 80 | 81 | int get_M() { 82 | return appr_alg->M_; 83 | } 84 | 85 | int save_index(const std::string &path_to_index) { 86 | TRY_CATCH_RETURN_INT_BLOCK({ 87 | appr_alg->saveIndex(path_to_index); 88 | }); 89 | } 90 | 91 | int load_index(const std::string &path_to_index, size_t max_elements) { 92 | TRY_CATCH_NO_INITIALIZE_CHECK_AND_RETURN_INT_BLOCK({ 93 | if (appr_alg) { 94 | std::cerr << "Warning: Calling load_index for an already initialized index. 
Old index is being deallocated."; 95 | delete appr_alg; 96 | } 97 | appr_alg = new hnswlib::HierarchicalNSW(l2space, path_to_index, false, max_elements); 98 | }); 99 | } 100 | 101 | void normalize_array(float* array){ 102 | float norm = 0.0f; 103 | for (int i=0; i= get_max_elements()) { 115 | return RESULT_ITEM_CANNOT_BE_INSERTED_INTO_THE_VECTOR_SPACE; 116 | } 117 | if ((data_must_be_normalized == true) && (item_normalized == false)) { 118 | normalize_array(item); 119 | } 120 | int current_id = id != -1 ? id : incremental_id++; 121 | appr_alg->addPoint(item, current_id); 122 | }); 123 | } 124 | 125 | int hasId(int id) { 126 | TRY_CATCH_RETURN_INT_BLOCK({ 127 | int label_c; 128 | auto search = (appr_alg->label_lookup_.find(id)); 129 | if (search == (appr_alg->label_lookup_.end()) || (appr_alg->isMarkedDeleted(search->second))) { 130 | return RESULT_ID_NOT_IN_INDEX; 131 | } 132 | }); 133 | } 134 | 135 | int getDataById(int id, float* data, int dim) { 136 | TRY_CATCH_RETURN_INT_BLOCK({ 137 | int label_c; 138 | auto search = (appr_alg->label_lookup_.find(id)); 139 | if (search == (appr_alg->label_lookup_.end()) || (appr_alg->isMarkedDeleted(search->second))) { 140 | return RESULT_ID_NOT_IN_INDEX; 141 | } 142 | label_c = search->second; 143 | char* data_ptrv = (appr_alg->getDataByInternalId(label_c)); 144 | float* data_ptr = (float*) data_ptrv; 145 | for (int i = 0; i < dim; i++) { 146 | data[i] = *data_ptr; 147 | data_ptr += 1; 148 | } 149 | }); 150 | } 151 | 152 | float compute_similarity(float* vector1, float* vector2) { 153 | float similarity; 154 | try { 155 | similarity = (appr_alg->fstdistfunc_(vector1, vector2, (appr_alg -> dist_func_param_))); 156 | } catch (...) 
{ 157 | similarity = NAN; 158 | } 159 | return similarity; 160 | } 161 | 162 | int knn_query(float* input, bool input_normalized, int k, int* indices /* output */, float* coefficients /* output */) { 163 | std::priority_queue> result; 164 | TRY_CATCH_RETURN_INT_BLOCK({ 165 | if ((data_must_be_normalized == true) && (input_normalized == false)) { 166 | normalize_array(input); 167 | } 168 | result = appr_alg->searchKnn((void*) input, k); 169 | if (result.size() != k) 170 | return RESULT_QUERY_CANNOT_RETURN; 171 | for (int i = k - 1; i >= 0; i--) { 172 | auto &result_tuple = result.top(); 173 | coefficients[i] = result_tuple.first; 174 | indices[i] = result_tuple.second; 175 | result.pop(); 176 | } 177 | }); 178 | } 179 | 180 | int mark_deleted(int label) { 181 | TRY_CATCH_RETURN_INT_BLOCK({ 182 | appr_alg->markDelete(label); 183 | }); 184 | } 185 | 186 | void resize_index(size_t new_size) { 187 | appr_alg->resizeIndex(new_size); 188 | } 189 | 190 | int get_max_elements() const { 191 | return appr_alg->max_elements_; 192 | } 193 | 194 | int get_current_count() const { 195 | return appr_alg->cur_element_count; 196 | } 197 | 198 | int clear_index() { 199 | TRY_CATCH_NO_INITIALIZE_CHECK_AND_RETURN_INT_BLOCK({ 200 | delete l2space; 201 | if (appr_alg) 202 | delete appr_alg; 203 | index_cleared = true; 204 | }); 205 | } 206 | 207 | std::string space_name; 208 | int dim; 209 | bool index_cleared; 210 | bool index_initialized; 211 | bool data_must_be_normalized; 212 | std::atomic incremental_id{0}; 213 | hnswlib::HierarchicalNSW *appr_alg; 214 | hnswlib::SpaceInterface *l2space; 215 | 216 | ~Index() { 217 | clear_index(); 218 | } 219 | }; 220 | 221 | EXTERN_C DLLEXPORT Index* createNewIndex(char* spaceName, int dimension){ 222 | Index* index; 223 | try { 224 | index = new Index(spaceName, dimension); 225 | } catch (...) 
{ 226 | index = NULL; 227 | } 228 | return index; 229 | } 230 | 231 | EXTERN_C DLLEXPORT int initNewIndex(Index* index, int maxNumberOfElements, int M = 16, int efConstruction = 200, int randomSeed = 100) { 232 | return index->init_new_index(maxNumberOfElements, M, efConstruction, randomSeed); 233 | } 234 | 235 | EXTERN_C DLLEXPORT int addItemToIndex(float* item, int normalized, int label, Index* index) { 236 | return index->add_item(item, normalized, label); 237 | } 238 | 239 | EXTERN_C DLLEXPORT int getIndexLength(Index* index) { 240 | if (index->appr_alg) { 241 | return index->appr_alg->cur_element_count; 242 | } else { 243 | return 0; 244 | } 245 | } 246 | 247 | EXTERN_C DLLEXPORT int saveIndexToPath(Index* index, char* path) { 248 | std::string path_string(path); 249 | return index->save_index(path_string); 250 | } 251 | 252 | EXTERN_C DLLEXPORT int loadIndexFromPath(Index* index, size_t maxNumberOfElements, char* path) { 253 | std::string path_string(path); 254 | return index->load_index(path_string, maxNumberOfElements); 255 | } 256 | 257 | EXTERN_C DLLEXPORT int knnQuery(Index* index, float* input, int normalized, int k, int* indices /* output */, float* coefficients /* output */) { 258 | return index->knn_query(input, normalized, k, indices, coefficients); 259 | } 260 | 261 | EXTERN_C DLLEXPORT int clearIndex(Index* index) { 262 | return index->clear_index(); 263 | } 264 | 265 | EXTERN_C DLLEXPORT int setEf(Index* index, int ef) { 266 | return index->set_ef(ef); 267 | } 268 | 269 | EXTERN_C DLLEXPORT int getData(Index* index, int id, float* vector, int dim) { 270 | return index->getDataById(id, vector, dim); 271 | } 272 | 273 | EXTERN_C DLLEXPORT int hasId(Index* index, int id) { 274 | return index->hasId(id); 275 | } 276 | 277 | EXTERN_C DLLEXPORT float computeSimilarity(Index* index, float* vector1, float* vector2) { 278 | return index->compute_similarity(vector1, vector2); 279 | } 280 | 281 | EXTERN_C DLLEXPORT int getM(Index* index) { 282 | return 
index->get_M(); 283 | } 284 | 285 | EXTERN_C DLLEXPORT int getEfConstruction(Index* index) { 286 | return index->get_ef_construction(); 287 | } 288 | 289 | EXTERN_C DLLEXPORT int getEf(Index* index) { 290 | return index->get_ef(); 291 | } 292 | 293 | EXTERN_C DLLEXPORT int markDeleted(Index* index, int id) { 294 | return index->mark_deleted(id); 295 | } 296 | 297 | int main(){ 298 | return RESULT_SUCCESSFUL; 299 | } -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "{}" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright {yyyy} {name of copyright owner} 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /hnswlib-jna/src/main/java/com/stepstone/search/hnswlib/jna/Index.java: -------------------------------------------------------------------------------- 1 | package com.stepstone.search.hnswlib.jna; 2 | 3 | import com.stepstone.search.hnswlib.jna.exception.IndexAlreadyInitializedException; 4 | import com.stepstone.search.hnswlib.jna.exception.IndexNotInitializedException; 5 | import com.stepstone.search.hnswlib.jna.exception.ItemCannotBeInsertedIntoTheVectorSpaceException; 6 | import com.stepstone.search.hnswlib.jna.exception.OnceIndexIsClearedItCannotBeReusedException; 7 | import com.stepstone.search.hnswlib.jna.exception.QueryCannotReturnResultsException; 8 | import com.stepstone.search.hnswlib.jna.exception.UnableToCreateNewIndexInstanceException; 9 | import com.stepstone.search.hnswlib.jna.exception.UnexpectedNativeException; 10 | import com.sun.jna.Pointer; 11 | import it.unimi.dsi.fastutil.ints.IntArraySet; 12 | import it.unimi.dsi.fastutil.ints.IntSet; 13 | import it.unimi.dsi.fastutil.ints.IntSets; 14 | 15 | import java.nio.file.Path; 
16 | import java.util.Optional;

/**
 * Represents a small world index in the java context.
 * This class includes some abstraction to make the integration
 * with the native library a bit more java like and relies on the
 * JNA implementation.
 *
 * Each instance of index has a different memory context and should
 * work independently.
 *
 * Besides the native index, an instance keeps an optional java-side set of
 * identifiers (see {@link #getIds()}). That set is only fed by the
 * {@code saveId} overloads of the add methods and by {@link #setIds(IntSet)};
 * it is not written by {@link #save(Path)} nor restored by {@link #load(Path, int)}.
 */
public class Index {

    /** Identifier passed to the native library when the caller does not supply one. */
    protected static final int NO_ID = -1;

    /* Result codes returned by the native calls and interpreted by checkResultCode. */
    private static final int RESULT_SUCCESSFUL = 0;
    private static final int RESULT_QUERY_NO_RESULTS = 3;
    private static final int RESULT_ITEM_CANNOT_BE_INSERTED_INTO_THE_VECTOR_SPACE = 4;
    private static final int RESULT_ONCE_INDEX_IS_CLEARED_IT_CANNOT_BE_REUSED = 5;
    private static final int RESULT_INDEX_NOT_INITIALIZED = 8;

    private static Hnswlib hnswlib = HnswlibFactory.getInstance();

    /** Pointer to the native index; may be shared with a wrapper created by synchronizedIndex. */
    private Pointer reference;
    private boolean initialized;
    private boolean cleared;
    private final SpaceName spaceName;
    private final int dimension;
    /** True once the native pointer was handed over to another Index via synchronizedIndex. */
    private boolean referenceReused;
    private IntSet ids = IntSets.synchronize(new IntArraySet());

    /**
     * Creates the native index for the given vector space. The index still
     * has to be initialized (or loaded) before items can be added.
     *
     * @param spaceName vector space used by the index;
     * @param dimension length of the vectors handled by the index.
     *
     * @throws UnableToCreateNewIndexInstanceException when the native side returns no reference.
     */
    public Index(SpaceName spaceName, int dimension) {
        this.spaceName = spaceName;
        this.dimension = dimension;
        reference = hnswlib.createNewIndex(spaceName.toString(), dimension);
        if (reference == null) {
            throw new UnableToCreateNewIndexInstanceException();
        }
    }

    /**
     * This method initializes the index with the default values
     * for the parameters m, efConstruction, randomSeed and sets
     * the maxNumberOfElements to 1_000_000.
     *
     * Note: not setting the maxNumberOfElements might lead to out of memory
     * issues and unpredictable behaviours in your application. Thus, use this
     * method wisely and combine it with monitoring.
     *
     * For more information, please see {@link #initialize(int, int, int, int)}.
     */
    public void initialize() {
        initialize(1_000_000);
    }

    /**
     * Initializes the index with m = 16, efConstruction = 200 and randomSeed = 100.
     *
     * For more information, please see {@link #initialize(int, int, int, int)}.
     *
     * @param maxNumberOfElements allowed in the index.
     */
    public void initialize(int maxNumberOfElements) {
        initialize(maxNumberOfElements, 16, 200, 100);
    }

    /**
     * Initialize the index to be used.
     *
     * @param maxNumberOfElements maximum number of elements allowed in the index;
     * @param m value forwarded to the native library (readable via {@link #getM()});
     * @param efConstruction value forwarded to the native library (readable via {@link #getEfConstruction()});
     * @param randomSeed seed forwarded to the native library.
     *
     * @throws IndexAlreadyInitializedException when a index reference was initialized before.
     * @throws UnexpectedNativeException when something unexpected happened in the native side.
     */
    public void initialize(int maxNumberOfElements, int m, int efConstruction, int randomSeed) {
        if (initialized) {
            throw new IndexAlreadyInitializedException();
        }
        checkResultCode(hnswlib.initNewIndex(reference, maxNumberOfElements, m, efConstruction, randomSeed));
        initialized = true;
    }

    /**
     * Sets the query time accuracy / speed trade-off value.
     *
     * @param ef value.
     */
    public void setEf(int ef) {
        checkResultCode(hnswlib.setEf(reference, ef));
    }

    /**
     * Add an item without ID to the index. The native library is asked to
     * assign an incremental identifier to this item.
     *
     * @param item - float array with the length expected by the index (dimension).
     */
    public void addItem(float[] item) {
        addItem(item, NO_ID);
    }

    /**
     * Add an item with ID to the index. It won't apply any extra normalization
     * unless it is required by the Vector Space (e.g., COSINE).
     *
     * @param item - float array with the length expected by the index (dimension);
     * @param id - an identifier used by the native library.
     */
    public void addItem(float[] item, int id) {
        checkResultCode(hnswlib.addItemToIndex(item, false, id, reference));
    }

    /**
     * Add an item with ID to the index. It won't apply any extra normalization
     * unless it is required by the Vector Space (e.g., COSINE).
     * The id is only recorded after the native insertion succeeded.
     *
     * @param item - float array with the length expected by the index (dimension);
     * @param id - an identifier used by the native library;
     * @param saveId - true to also record the id in the collection returned by {@link #getIds()}.
     */
    public void addItem(float[] item, int id, boolean saveId) {
        addItem(item, id);
        if (saveId) {
            ids.add(id);
        }
    }

    /**
     * Replaces the java-side collection of ids. The given set is used as is
     * (no copy is made).
     *
     * @param ids set of ids to be used from now on.
     */
    public void setIds(IntSet ids) {
        this.ids = ids;
    }

    /**
     * Get the ids recorded on the java side for items in this index, i.e.
     * the ones added with saveId = true or provided via {@link #setIds(IntSet)}.
     *
     * @return set of ids (the internal instance, not a copy).
     */
    public IntSet getIds() {
        return ids;
    }

    /**
     * Add a normalized item without ID to the index. The native library is
     * asked to assign an incremental identifier to this item.
     *
     * @param item - float array with the length expected by the index (dimension).
     */
    public void addNormalizedItem(float[] item) {
        addNormalizedItem(item, NO_ID);
    }

    /**
     * Add a normalized item with ID to the index.
     *
     * @param item - float array with the length expected by the index (dimension);
     * @param id - an identifier used by the native library.
     */
    public void addNormalizedItem(float[] item, int id) {
        checkResultCode(hnswlib.addItemToIndex(item, true, id, reference));
    }

    /**
     * Add a normalized item with ID to the index.
     * The id is only recorded after the native insertion succeeded.
     *
     * @param item - float array with the length expected by the index (dimension);
     * @param id - an identifier used by the native library;
     * @param saveId - true to also record the id in the collection returned by {@link #getIds()}.
     */
    public void addNormalizedItem(float[] item, int id, boolean saveId) {
        addNormalizedItem(item, id);
        if (saveId) {
            ids.add(id);
        }
    }

    /**
     * Return the number of elements already inserted in
     * the index.
     *
     * @return elements count.
     */
    public int getLength() {
        return hnswlib.getIndexLength(reference);
    }

    /**
     * Performs a knn query in the index instance. In case the vector space requires
     * the input to be normalized, it will normalize at the native level.
     *
     * @param input - float array;
     * @param k - number of results expected.
     *
     * @return a query tuple instance that contain the indices and coefficients.
     */
    public QueryTuple knnQuery(float[] input, int k) {
        QueryTuple queryTuple = new QueryTuple(k);
        checkResultCode(hnswlib.knnQuery(reference, input, false, k, queryTuple.ids, queryTuple.coefficients));
        return queryTuple;
    }

    /**
     * Performs a knn query in the index instance using an normalized input.
     * It will not normalize the vector again.
     *
     * @param input - a normalized float array;
     * @param k - number of results expected.
     *
     * @return a query tuple instance that contain the indices and coefficients.
     */
    public QueryTuple knnNormalizedQuery(float[] input, int k) {
        QueryTuple queryTuple = new QueryTuple(k);
        checkResultCode(hnswlib.knnQuery(reference, input, true, k, queryTuple.ids, queryTuple.coefficients));
        return queryTuple;
    }

    /**
     * Stores the content of the index into a file.
     * This method relies on the native implementation.
     *
     * @param path - destination path.
     */
    public void save(Path path) {
        checkResultCode(hnswlib.saveIndexToPath(reference, path.toAbsolutePath().toString()));
    }

    /**
     * This method loads the content stored in a file path onto the index.
     *
     * Note: if the index was previously initialized, the old
     * content will be erased.
     *
     * @param path - path to the index file;
     * @param maxNumberOfElements - max number of elements in the index.
     */
    public void load(Path path, int maxNumberOfElements) {
        checkResultCode(hnswlib.loadIndexFromPath(reference, maxNumberOfElements, path.toAbsolutePath().toString()));
        // A successfully loaded index is usable; without this flag the getters
        // guarded by checkIndexIsInitialized() would reject it.
        initialized = true;
    }

    /**
     * Free the memory allocated for this index in the native context.
     *
     * NOTE: Once the index is cleared, it cannot be initialized or used again.
     */
    public void clear() {
        checkResultCode(hnswlib.clearIndex(reference));
        cleared = true;
    }

    /**
     * Cleanup the area allocated by the index in the native side, unless it
     * was already cleared or the native pointer is now owned by another
     * instance (see {@link #synchronizedIndex(Index)}).
     *
     * @throws Throwable when anything weird happened. :)
     */
    @Override
    protected void finalize() throws Throwable {
        try {
            if (!cleared && !referenceReused) {
                this.clear();
            }
        } finally {
            // Always run the superclass finalizer, even when clear() throws.
            super.finalize();
        }
    }

    /**
     * This method checks the result code coming from the
     * native execution is correct otherwise throws an exception.
     *
     * @param resultCode value returned by the native call.
     *
     * @throws QueryCannotReturnResultsException for result code 3;
     * @throws ItemCannotBeInsertedIntoTheVectorSpaceException for result code 4;
     * @throws OnceIndexIsClearedItCannotBeReusedException for result code 5;
     * @throws IndexNotInitializedException for result code 8;
     * @throws UnexpectedNativeException for any other non-zero result code.
     */
    private void checkResultCode(int resultCode) {
        switch (resultCode) {
            case RESULT_SUCCESSFUL:
                break;
            case RESULT_QUERY_NO_RESULTS:
                throw new QueryCannotReturnResultsException();
            case RESULT_ITEM_CANNOT_BE_INSERTED_INTO_THE_VECTOR_SPACE:
                throw new ItemCannotBeInsertedIntoTheVectorSpaceException();
            case RESULT_ONCE_INDEX_IS_CLEARED_IT_CANNOT_BE_REUSED:
                throw new OnceIndexIsClearedItCannotBeReusedException();
            case RESULT_INDEX_NOT_INITIALIZED:
                throw new IndexNotInitializedException();
            default:
                throw new UnexpectedNativeException();
        }
    }

    /**
     * Checks whether there is an item with the specified identifier in the index.
     * The native index is consulted, not the java-side id collection.
     *
     * @param id - identifier.
     * @return true or false.
     */
    public boolean hasId(int id) {
        return hnswlib.hasId(reference, id) == RESULT_SUCCESSFUL;
    }

    /**
     * Gets the data from a specific identifier in the index.
     *
     * @param id - identifier.
     *
     * @return an optional containing the vector stored for the identifier, or
     *         an empty optional when the native call does not succeed.
     */
    public Optional<float[]> getData(int id) {
        float[] vector = new float[dimension];
        int success = hnswlib.getData(reference, id, vector, dimension);
        if (success == RESULT_SUCCESSFUL) {
            return Optional.of(vector);
        }
        return Optional.empty();
    }

    /**
     * Compute similarity on the native side taking advantage of
     * SSE, AVX, SIMD instructions, when available.
     *
     * @param vector1 array with correct dimension;
     * @param vector2 array with correct dimension.
     *
     * @return the similarity score.
     *
     * @throws IndexNotInitializedException when the index was neither initialized nor loaded.
     */
    public float computeSimilarity(float[] vector1, float[] vector2) {
        checkIndexIsInitialized();
        return hnswlib.computeSimilarity(reference, vector1, vector2);
    }

    /**
     * Retrieves the current M value.
     *
     * @return the M value.
     *
     * @throws IndexNotInitializedException when the index was neither initialized nor loaded.
     */
    public int getM() {
        checkIndexIsInitialized();
        return hnswlib.getM(reference);
    }

    /**
     * Retrieves the current Ef value.
     *
     * @return the EF value.
     *
     * @throws IndexNotInitializedException when the index was neither initialized nor loaded.
     */
    public int getEf() {
        checkIndexIsInitialized();
        return hnswlib.getEf(reference);
    }

    /**
     * Retrieves the current ef construction.
     *
     * @return the ef construction value.
     *
     * @throws IndexNotInitializedException when the index was neither initialized nor loaded.
     */
    public int getEfConstruction() {
        checkIndexIsInitialized();
        return hnswlib.getEfConstruction(reference);
    }

    /**
     * Marks an ID as deleted in the native index and, once that succeeded,
     * drops it from the java-side id collection.
     *
     * @param id identifier.
     */
    public void markDeleted(int id) {
        checkResultCode(hnswlib.markDeleted(reference, id));
        // Removing an absent element is a no-op, so no contains() check is needed.
        ids.remove(id);
    }

    /**
     * Guard used by the getters that require an initialized (or loaded) index.
     *
     * @throws IndexNotInitializedException when the index is not initialized.
     */
    private void checkIndexIsInitialized() {
        if (!initialized) {
            throw new IndexNotInitializedException();
        }
    }

    /**
     * Util function that normalizes an array in place (each element is
     * divided by the euclidean norm of the array).
     *
     * @param array input.
     */
    public static strictfp void normalize(float[] array) {
        int n = array.length;
        float norm = 0;
        for (float v : array) {
            norm += v * v;
        }
        // The tiny constant keeps the division defined for an all-zero array.
        norm = (float) (1.0f / (Math.sqrt(norm) + 1e-30f));
        for (int i = 0; i < n; i++) {
            array[i] = array[i] * norm;
        }
    }

    /**
     * This method returns a ConcurrentIndex instance which contains
     * the same items present in a Index object. The indexes will share
     * the same native pointer, so there will be no memory duplication.
     *
     * It is important to say that the Index class allows adding items in
     * parallel, so the building time can be much faster. On the other hand,
     * ConcurrentIndex offers thread-safe methods for adding and querying, which
     * can be interesting for multi-threaded environment with online
     * updates. Via this method, you can get the best of the two worlds.
     *
     * After this call the given index no longer releases the native memory
     * when it is finalized; the returned index takes over that job. The id
     * collection is shared between both instances.
     *
     * @param index with the items added
     * @return a thread-safe index
     */
    public static Index synchronizedIndex(Index index) {
        Index concurrentIndex = new ConcurrentIndex(index.spaceName, index.dimension);
        // The constructor above created a native index of its own. Release it
        // before adopting the shared pointer; otherwise it would become unreachable
        // and could never be cleared.
        concurrentIndex.clear();
        concurrentIndex.reference = index.reference;
        concurrentIndex.cleared = index.cleared;
        concurrentIndex.initialized = index.initialized;
        concurrentIndex.setIds(index.getIds());
        index.referenceReused = true;
        return concurrentIndex;
    }
}
427 | -------------------------------------------------------------------------------- /hnswlib-jna/src/test/java/com/stepstone/search/hnswlib/jna/AbstractIndexTest.java: -------------------------------------------------------------------------------- 1 | package com.stepstone.search.hnswlib.jna; 2 | 3 | import com.stepstone.search.hnswlib.jna.exception.IndexAlreadyInitializedException; 4 | import com.stepstone.search.hnswlib.jna.exception.IndexNotInitializedException; 5 | import com.stepstone.search.hnswlib.jna.exception.ItemCannotBeInsertedIntoTheVectorSpaceException; 6 | import com.stepstone.search.hnswlib.jna.exception.OnceIndexIsClearedItCannotBeReusedException; 7 | import com.stepstone.search.hnswlib.jna.exception.QueryCannotReturnResultsException; 8 | import com.stepstone.search.hnswlib.jna.exception.UnexpectedNativeException; 9 | import org.junit.Test; 10 | 11 | import java.io.File; 12 | import java.io.IOException; 13 | import java.nio.file.Path; 14 | import java.nio.file.Paths; 15 | import java.util.Optional; 16 | import java.util.concurrent.ExecutorService; 17 | import java.util.concurrent.Executors; 18 | import java.util.concurrent.TimeUnit; 19 | 20 | import static org.junit.Assert.assertArrayEquals; 21 | import static org.junit.Assert.assertEquals; 22 | import static org.junit.Assert.assertFalse; 23 | import static org.junit.Assert.assertNotNull; 24 | import static
org.junit.Assert.assertNull; 25 | import static org.junit.Assert.assertTrue; 26 | 27 | public abstract class AbstractIndexTest { 28 | 29 | protected abstract Index createIndexInstance(SpaceName spaceName, int dimensions); 30 | 31 | @Test 32 | public void testSingleIndexInstantiation() throws UnexpectedNativeException { 33 | Index i1 = createIndexInstance(SpaceName.IP, 30); 34 | assertNotNull(i1); 35 | i1.clear(); 36 | } 37 | 38 | @Test 39 | public void testMultipleIndexInstantiation() throws UnexpectedNativeException { 40 | Index i1 = createIndexInstance(SpaceName.IP, 30); 41 | assertNotNull(i1); 42 | Index i2 = createIndexInstance(SpaceName.COSINE, 30); 43 | assertNotNull(i2); 44 | Index i3 = createIndexInstance(SpaceName.L2, 30); 45 | assertNotNull(i3); 46 | i1.clear(); 47 | i2.clear(); 48 | i3.clear(); 49 | } 50 | 51 | @Test 52 | public void testIndexInitialization() throws UnexpectedNativeException { 53 | Index i1 = createIndexInstance(SpaceName.COSINE, 50); 54 | i1.initialize(500_000, 16, 200, 100); 55 | assertEquals(0, i1.getLength()); 56 | i1.clear(); 57 | } 58 | 59 | @Test 60 | public void testIndexInitialization2() throws UnexpectedNativeException { 61 | Index i1 = createIndexInstance(SpaceName.COSINE, 50); 62 | i1.initialize(); 63 | assertEquals(0, i1.getLength()); 64 | i1.clear(); 65 | } 66 | 67 | @Test(expected = IndexAlreadyInitializedException.class) 68 | public void testIndexMultipleInitialization() throws UnexpectedNativeException { 69 | Index i1 = createIndexInstance(SpaceName.COSINE, 50); 70 | i1.initialize(500_000, 16, 200, 100); 71 | i1.initialize(); 72 | } 73 | 74 | @Test 75 | public void testIndexAddItem() throws UnexpectedNativeException { 76 | Index i1 = createIndexInstance(SpaceName.COSINE, 3); 77 | i1.initialize(1); 78 | i1.addItem(new float[] { 1.3f, 1.2f, 1.5f }, 3); 79 | assertEquals(1, i1.getLength()); 80 | i1.clear(); 81 | } 82 | 83 | @Test 84 | public void testIndexAddItemIndependence() throws UnexpectedNativeException { 85 | 
testIndexAddItem(); 86 | Index i2 = createIndexInstance(SpaceName.IP, 4); 87 | i2.initialize(3); 88 | assertEquals(0, i2.getLength()); 89 | i2.clear(); 90 | } 91 | 92 | @Test 93 | public void testIndexSaveAndLoad() throws UnexpectedNativeException, IOException { 94 | File tempFile = File.createTempFile("index", "sm"); 95 | Path tempFilePath = Paths.get(tempFile.getAbsolutePath()); 96 | 97 | Index i1 = createIndexInstance(SpaceName.COSINE, 3); 98 | i1.initialize(1); 99 | i1.addItem(new float[] { 1.3f, 1.2f, 1.5f }, 3); 100 | i1.save(tempFilePath); 101 | i1.clear(); 102 | 103 | Index i2 = createIndexInstance(SpaceName.COSINE, 3); 104 | assertEquals(0, i2.getLength()); 105 | i2.load(tempFilePath,1); 106 | assertEquals(1, i2.getLength()); 107 | i2.clear(); 108 | 109 | assertTrue(tempFile.delete()); 110 | } 111 | 112 | @Test 113 | public void testParallelAddItemsInMultipleIndexes() throws InterruptedException, UnexpectedNativeException { 114 | int cpus = Runtime.getRuntime().availableProcessors(); 115 | ExecutorService executorService = Executors.newFixedThreadPool(cpus); 116 | 117 | Index i1 = createIndexInstance(SpaceName.L2, 50); 118 | i1.initialize(1_050); 119 | 120 | Index i2 = createIndexInstance(SpaceName.COSINE, 50); 121 | i2.initialize(1_050); 122 | 123 | Runnable addItemIndex1 = () -> { 124 | try { 125 | i1.addItem(HnswlibTestUtils.getRandomFloatArray(50)); 126 | } catch (UnexpectedNativeException e) { 127 | e.printStackTrace(); 128 | } 129 | }; 130 | Runnable addItemIndex2 = () -> { 131 | try { 132 | i2.addItem(HnswlibTestUtils.getRandomFloatArray(50)); 133 | } catch (UnexpectedNativeException e) { 134 | e.printStackTrace(); 135 | } 136 | }; 137 | 138 | for(int i = 0; i < 1_000; i++) { 139 | executorService.submit(addItemIndex1); 140 | executorService.submit(addItemIndex2); 141 | } 142 | 143 | executorService.shutdown(); 144 | executorService.awaitTermination(5, TimeUnit.MINUTES); 145 | 146 | assertEquals(1_000, i1.getLength()); 147 | assertEquals(1_000, 
i2.getLength()); 148 | 149 | i1.clear(); i2.clear(); 150 | } 151 | 152 | @Test 153 | public void testConcurrentInsertQuery() throws InterruptedException, UnexpectedNativeException { 154 | ExecutorService executorService = Executors.newFixedThreadPool(50); 155 | 156 | Index i1 = createIndexInstance(SpaceName.L2, 50); 157 | i1.initialize(1_050); 158 | 159 | float[] randomFloatArray = HnswlibTestUtils.getRandomFloatArray(50); 160 | 161 | Runnable addItemIndex1 = () -> { 162 | try { 163 | i1.addItem(randomFloatArray); 164 | } catch (UnexpectedNativeException e) { 165 | e.printStackTrace(); 166 | } 167 | }; 168 | 169 | Runnable queryItemIndex1 = () -> { 170 | QueryTuple queryTuple; 171 | try { 172 | queryTuple = i1.knnQuery(randomFloatArray, 1); 173 | assertEquals(50, queryTuple.getIds().length); 174 | assertEquals(50, queryTuple.getCoefficients().length); 175 | } catch (UnexpectedNativeException e) { 176 | e.printStackTrace(); 177 | } 178 | }; 179 | 180 | for(int i = 0; i < 1_000; i++) { 181 | executorService.submit(addItemIndex1); 182 | executorService.submit(queryItemIndex1); 183 | } 184 | 185 | executorService.shutdown(); 186 | executorService.awaitTermination(5, TimeUnit.MINUTES); 187 | 188 | assertEquals(1_000, i1.getLength()); 189 | i1.clear(); 190 | } 191 | 192 | @Test(expected = QueryCannotReturnResultsException.class) 193 | public void testQueryEmptyException() throws UnexpectedNativeException { 194 | Index idx = createIndexInstance(SpaceName.COSINE, 3); 195 | idx.initialize(300); 196 | QueryTuple queryTuple = idx.knnQuery(new float[] {1.3f, 1.4f, 1.5f}, 3); 197 | assertNull(queryTuple); 198 | } 199 | 200 | @Test 201 | public void testOverwritingAnItemInTheModel() throws UnexpectedNativeException { 202 | Index index = createIndexInstance(SpaceName.COSINE, 4); 203 | index.initialize(5); 204 | 205 | index.addItem(new float[] { 1.0f, 1.0f, 1.0f, 1.0f}, 1); 206 | index.addItem(new float[] { 1.0f, 1.0f, 1.0f, 0.95f}, 2); 207 | index.addItem(new float[] { 1.0f, 
1.0f, 1.0f, 0.9f}, 3); 208 | index.addItem(new float[] { 1.0f, 1.0f, 1.0f, 0.85f}, 4); 209 | 210 | QueryTuple queryTuple = index.knnQuery(new float[] {1.0f, 1.0f, 1.0f, 1.0f}, 3); 211 | assertEquals(1, queryTuple.ids[0]); 212 | assertEquals(2, queryTuple.ids[1]); 213 | assertEquals(3, queryTuple.ids[2]); 214 | 215 | index.addItem(new float[] { 0.0f, 0.0f, 0.0f, 0.0f}, 2); 216 | queryTuple = index.knnQuery(new float[] {1.0f, 1.0f, 1.0f, 1.0f}, 3); 217 | assertEquals(1, queryTuple.ids[0]); 218 | assertEquals(3, queryTuple.ids[1]); 219 | assertEquals(4, queryTuple.ids[2]); 220 | 221 | index.clear(); 222 | } 223 | 224 | @Test(expected = ItemCannotBeInsertedIntoTheVectorSpaceException.class) 225 | public void testIncludingMoreItemsThanPossible() throws UnexpectedNativeException { 226 | Index index = createIndexInstance(SpaceName.L2, 4); 227 | index.initialize(2); 228 | 229 | index.addItem(new float[] { 1.0f, 1.0f, 1.0f, 1.0f}, 1); 230 | index.addItem(new float[] { 1.0f, 1.0f, 1.0f, 0.95f}, 2); 231 | index.addItem(new float[] { 1.0f, 1.0f, 1.0f, 0.9f}, 3); 232 | } 233 | 234 | @Test 235 | public void testNativeArrayNormalization() throws UnexpectedNativeException { 236 | Index index = createIndexInstance(SpaceName.COSINE, 7); 237 | index.initialize(20); 238 | 239 | float[] item1 = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}; 240 | index.addItem(item1); /* COSINE requires a normalized item. So, this input will be normalized (and modified) before being added to the index. */ 241 | assertArrayEquals(new float[] {0.3779645f, 0.3779645f, 0.3779645f, 0.3779645f, 0.3779645f, 0.3779645f, 0.3779645f}, item1, 0.000001f); 242 | 243 | float[] item2 = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}; 244 | index.addNormalizedItem(item1); /* since we are using a add normalized method, nothing should happen here. 
*/ 245 | assertArrayEquals(new float[] {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}, item2,0.000001f); 246 | 247 | QueryTuple queryTuple = index.knnQuery(new float[] {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}, 1); 248 | assertEquals(-2.3841858E-7f, queryTuple.getCoefficients()[0], 0.00001); 249 | 250 | queryTuple = index.knnNormalizedQuery(item2, 1); 251 | assertEquals(-1.645751476f, queryTuple.getCoefficients()[0], 0.00001); 252 | } 253 | 254 | @Test 255 | public void testHostNormalization() { 256 | float[] item1 = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}; 257 | Index.normalize(item1); 258 | assertArrayEquals(new float[] {0.3779645f, 0.3779645f, 0.3779645f, 0.3779645f, 0.3779645f, 0.3779645f, 0.3779645f}, item1, 0.000001f); 259 | } 260 | 261 | @Test 262 | public void testIndexCosineEqualsToIPWhenNormalized() throws UnexpectedNativeException { 263 | float[] i1 = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}; 264 | Index.normalize(i1); 265 | float[] i2 = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.95f}; 266 | Index.normalize(i2); 267 | float[] i3 = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.9f}; 268 | Index.normalize(i3); 269 | 270 | Index indexCosine = createIndexInstance(SpaceName.COSINE, 7); 271 | indexCosine.initialize(3); 272 | indexCosine.addNormalizedItem(i1, 1_111_111); 273 | indexCosine.addNormalizedItem(i2, 1_222_222); 274 | indexCosine.addNormalizedItem(i3, 1_333_333); 275 | 276 | Index indexIP = createIndexInstance(SpaceName.IP, 7); 277 | indexIP.initialize(3); 278 | indexIP.addNormalizedItem(i1, 1_111_111); 279 | indexIP.addNormalizedItem(i2, 1_222_222); 280 | indexIP.addNormalizedItem(i3, 1_333_333); 281 | 282 | float[] input = new float[] {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}; 283 | Index.normalize(input); 284 | 285 | QueryTuple cosineQT = indexCosine.knnNormalizedQuery(input, 3); 286 | QueryTuple ipQT = indexCosine.knnNormalizedQuery(input, 3); 287 | 288 | assertArrayEquals(cosineQT.getCoefficients(), ipQT.getCoefficients(), 0.000001f); 289 | 
assertArrayEquals(cosineQT.getIds(), ipQT.getIds()); 290 | 291 | indexIP.clear(); 292 | indexCosine.clear(); 293 | } 294 | 295 | @Test 296 | public void testSimpleQueryOf5ElementsAndDimension7IP() throws UnexpectedNativeException { 297 | Index index = createIndexInstance(SpaceName.IP, 7); 298 | index.initialize(7); 299 | 300 | index.addItem(new float [] { 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f }, 5); 301 | index.addItem(new float [] { 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.95f }, 6); 302 | index.addItem(new float [] { 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.9f }, 7); 303 | index.addItem(new float [] { 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.85f }, 8); 304 | index.addItem(new float [] { 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.8f },9); 305 | 306 | float[] input = new float[] { 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f }; 307 | QueryTuple ipQT = index.knnQuery(input, 4); 308 | 309 | assertArrayEquals(new int[] {5, 6, 7, 8}, ipQT.getIds()); 310 | assertArrayEquals(new float[] {-6.0f, -5.95f, -5.9f, -5.85f}, ipQT.getCoefficients(), 0.000001f); 311 | index.clear(); 312 | } 313 | 314 | @Test 315 | public void testSimpleQueryOf5ElementsAndDimension7Cosine() throws UnexpectedNativeException { 316 | Index index = createIndexInstance(SpaceName.COSINE, 7); 317 | index.initialize(7); 318 | 319 | index.addItem(new float [] { 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f }, 14); 320 | index.addItem(new float [] { 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.95f }, 13); 321 | index.addItem(new float [] { 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.9f }, 12); 322 | index.addItem(new float [] { 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.85f }, 11); 323 | index.addItem(new float [] { 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.8f },10); 324 | 325 | float[] input = new float[] { 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f }; 326 | QueryTuple ipQT = index.knnQuery(input, 4); 327 | 328 | assertArrayEquals(new int[] {14, 13, 12, 11}, ipQT.getIds()); 329 | assertArrayEquals(new float[] {-2.3841858E-7f, 1.552105E-4f, 
6.2948465E-4f, 0.001435399f}, ipQT.getCoefficients(), 0.000001f); 330 | index.clear(); 331 | } 332 | 333 | @Test 334 | public void testSimpleQueryOf5ElementsAndDimension7L2() throws UnexpectedNativeException { 335 | Index index = createIndexInstance(SpaceName.L2, 7); 336 | index.initialize(7); 337 | 338 | index.addItem(new float [] { 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.9f }, 48); 339 | index.addItem(new float [] { 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.85f }, 10); 340 | index.addItem(new float [] { 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.95f }, 35); 341 | index.addItem(new float [] { 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.8f },1); 342 | index.addItem(new float [] { 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f }, 33); 343 | 344 | float[] input = new float[] { 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f }; 345 | QueryTuple ipQT = index.knnQuery(input, 4); 346 | 347 | assertArrayEquals(new int[] {33, 35, 48, 10}, ipQT.getIds()); 348 | assertArrayEquals(new float[] { 0.0f, 0.002500001f, 0.010000004f, 0.022499993f}, ipQT.getCoefficients(), 0.000001f); 349 | index.clear(); 350 | } 351 | 352 | @Test(expected = OnceIndexIsClearedItCannotBeReusedException.class) 353 | public void testDoubleClear() throws UnexpectedNativeException { 354 | Index idx = createIndexInstance(SpaceName.IP, 30); 355 | idx.initialize(3); 356 | idx.clear(); 357 | idx.clear(); 358 | } 359 | 360 | @Test(expected = OnceIndexIsClearedItCannotBeReusedException.class) 361 | public void testUsageAfterClear1() throws UnexpectedNativeException { 362 | Index idx = createIndexInstance(SpaceName.IP, 30); 363 | idx.clear(); 364 | idx.initialize(30); 365 | } 366 | 367 | @Test(expected = OnceIndexIsClearedItCannotBeReusedException.class) 368 | public void testUsageAfterClear2() throws UnexpectedNativeException { 369 | Index index = createIndexInstance(SpaceName.IP, 30); 370 | index.initialize(30); 371 | index.clear(); 372 | index.addItem(new float [] { 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.9f }, 48); 373 | } 374 | 375 | 
@Test(expected = OnceIndexIsClearedItCannotBeReusedException.class) 376 | public void testUsageAfterClear3() throws UnexpectedNativeException { 377 | Index index = createIndexInstance(SpaceName.IP, 30); 378 | index.initialize(30); 379 | index.addItem(new float [] { 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.9f }, 48); 380 | index.clear(); 381 | float[] input = new float[] { 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f }; 382 | index.knnQuery(input, 4); 383 | } 384 | 385 | @Test 386 | public void testTryingDoubleClearDueToGCWhenReferenceIsLost() throws UnexpectedNativeException { 387 | Index index = createIndexInstance(SpaceName.IP, 30); 388 | index.initialize(30); 389 | index.clear(); 390 | index = createIndexInstance(SpaceName.IP, 30); 391 | int counter = 10; 392 | while (counter-- > 0) { 393 | System.gc(); 394 | } 395 | index.initialize(30); 396 | assertNotNull(index); 397 | } 398 | 399 | @Test 400 | public void testGetData() { 401 | Index index = createIndexInstance(SpaceName.COSINE, 3); 402 | index.initialize(); 403 | float[] vector = {1F, 2F, 3F}; 404 | index.addItem(vector); 405 | assertTrue(index.hasId(0)); 406 | Optional data = index.getData(0); 407 | assertTrue(data.isPresent()); 408 | assertArrayEquals(vector, data.get(), 0.0f); 409 | assertFalse(index.hasId(1)); 410 | assertFalse(index.getData(1).isPresent()); 411 | 412 | float[] vector2 = {1F, 2F, 3F}; 413 | index.addItem(vector2, 1230); 414 | assertTrue(index.hasId(1230)); 415 | assertFalse(index.hasId(1231)); 416 | 417 | index.clear(); 418 | assertFalse(index.hasId(1230)); 419 | assertFalse(index.hasId(1231)); 420 | } 421 | 422 | @Test 423 | public void testGetDataWhenIndexCleared() { 424 | Index index = createIndexInstance(SpaceName.COSINE, 3); 425 | index.initialize(); 426 | index.clear(); 427 | assertFalse(index.hasId(1202)); 428 | Index index2 = createIndexInstance(SpaceName.COSINE, 1); 429 | assertFalse(index2.hasId(123)); 430 | } 431 | 432 | @Test(expected = IndexNotInitializedException.class) 433 | 
public void testUseAddItemIndexWithoutInitialize() { 434 | Index index = createIndexInstance(SpaceName.COSINE, 1); 435 | index.addItem(new float[1]); 436 | } 437 | 438 | @Test(expected = IndexNotInitializedException.class) 439 | public void testUseAddNormalizedItemIndexWithoutInitialize() { 440 | Index index = createIndexInstance(SpaceName.COSINE, 1); 441 | index.addNormalizedItem(new float[1]); 442 | } 443 | 444 | @Test(expected = IndexNotInitializedException.class) 445 | public void testUseKnnQueryIndexWithoutInitialize() { 446 | Index index = createIndexInstance(SpaceName.COSINE, 1); 447 | index.knnQuery(new float[1],1); 448 | } 449 | 450 | @Test(expected = IndexNotInitializedException.class) 451 | public void testUseKnnNormalizedQueryQueryIndexWithoutInitialize() { 452 | Index index = createIndexInstance(SpaceName.COSINE, 1); 453 | index.knnNormalizedQuery(new float[1],1); 454 | } 455 | 456 | @Test(expected = IndexNotInitializedException.class) 457 | public void testGetMWithoutInitializeIndex() { 458 | Index index = createIndexInstance(SpaceName.COSINE, 1); 459 | index.getM(); 460 | } 461 | 462 | @Test(expected = IndexNotInitializedException.class) 463 | public void testGetEfWithoutInitializeIndex() { 464 | Index index = createIndexInstance(SpaceName.COSINE, 1); 465 | index.getEf(); 466 | } 467 | 468 | @Test(expected = IndexNotInitializedException.class) 469 | public void testGetEfConstructionWithoutInitializeIndex() { 470 | Index index = createIndexInstance(SpaceName.COSINE, 1); 471 | index.getEfConstruction(); 472 | } 473 | 474 | @Test 475 | public void testMarkAsDeleted() { 476 | Index index = createIndexInstance(SpaceName.COSINE, 7); 477 | index.initialize(7); 478 | 479 | index.addItem(new float [] { 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f }, 14); 480 | index.addItem(new float [] { 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.95f }, 13); 481 | index.addItem(new float [] { 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.9f }, 12); 482 | index.addItem(new float [] { 1.0f, 
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.85f }, 11); 483 | index.addItem(new float [] { 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 0.8f },10); 484 | 485 | float[] input = new float[] { 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f }; 486 | QueryTuple ipQT = index.knnQuery(input, 4); 487 | 488 | assertArrayEquals(new int[] {14, 13, 12, 11}, ipQT.getIds()); 489 | assertArrayEquals(new float[] {-2.3841858E-7f, 1.552105E-4f, 6.2948465E-4f, 0.001435399f}, ipQT.getCoefficients(), 0.000001f); 490 | 491 | index.markDeleted(13); 492 | QueryTuple ipQT2 = index.knnQuery(input, 4); 493 | assertArrayEquals(new int[] {14, 12, 11, 10}, ipQT2.getIds()); 494 | assertArrayEquals(new float[] {-2.3841858E-7f, 6.2948465E-4f, 0.001435399f, 0.0025851727f}, ipQT2.getCoefficients(), 0.000001f); 495 | 496 | index.markDeleted(12); 497 | QueryTuple ipQT3 = index.knnQuery(input, 3); 498 | assertArrayEquals(new int[] {14, 11, 10}, ipQT3.getIds()); 499 | assertArrayEquals(new float[] {-2.3841858E-7f, 0.001435399f, 0.0025851727f}, ipQT3.getCoefficients(), 0.000001f); 500 | 501 | index.clear(); 502 | } 503 | 504 | } 505 | -------------------------------------------------------------------------------- /hnswlib/hnswalg.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "visited_list_pool.h" 4 | #include "hnswlib.h" 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | 12 | namespace hnswlib { 13 | typedef unsigned int tableint; 14 | typedef unsigned int linklistsizeint; 15 | 16 | template 17 | class HierarchicalNSW : public AlgorithmInterface { 18 | public: 19 | static const tableint max_update_element_locks = 65536; 20 | HierarchicalNSW(SpaceInterface *s) { 21 | 22 | } 23 | 24 | HierarchicalNSW(SpaceInterface *s, const std::string &location, bool nmslib = false, size_t max_elements=0) { 25 | loadIndex(location, s, max_elements); 26 | } 27 | 28 | HierarchicalNSW(SpaceInterface *s, size_t max_elements, size_t M = 16, size_t 
ef_construction = 200, size_t random_seed = 100) : 29 | link_list_locks_(max_elements), element_levels_(max_elements), link_list_update_locks_(max_update_element_locks) { 30 | max_elements_ = max_elements; 31 | 32 | has_deletions_=false; 33 | data_size_ = s->get_data_size(); 34 | fstdistfunc_ = s->get_dist_func(); 35 | dist_func_param_ = s->get_dist_func_param(); 36 | M_ = M; 37 | maxM_ = M_; 38 | maxM0_ = M_ * 2; 39 | ef_construction_ = std::max(ef_construction,M_); 40 | ef_ = 10; 41 | 42 | level_generator_.seed(random_seed); 43 | update_probability_generator_.seed(random_seed + 1); 44 | 45 | size_links_level0_ = maxM0_ * sizeof(tableint) + sizeof(linklistsizeint); 46 | size_data_per_element_ = size_links_level0_ + data_size_ + sizeof(labeltype); 47 | offsetData_ = size_links_level0_; 48 | label_offset_ = size_links_level0_ + data_size_; 49 | offsetLevel0_ = 0; 50 | 51 | data_level0_memory_ = (char *) malloc(max_elements_ * size_data_per_element_); 52 | if (data_level0_memory_ == nullptr) 53 | throw std::runtime_error("Not enough memory"); 54 | 55 | cur_element_count = 0; 56 | 57 | visited_list_pool_ = new VisitedListPool(1, max_elements); 58 | 59 | 60 | 61 | //initializations for special treatment of the first node 62 | enterpoint_node_ = -1; 63 | maxlevel_ = -1; 64 | 65 | linkLists_ = (char **) malloc(sizeof(void *) * max_elements_); 66 | if (linkLists_ == nullptr) 67 | throw std::runtime_error("Not enough memory: HierarchicalNSW failed to allocate linklists"); 68 | size_links_per_element_ = maxM_ * sizeof(tableint) + sizeof(linklistsizeint); 69 | mult_ = 1 / log(1.0 * M_); 70 | revSize_ = 1.0 / mult_; 71 | } 72 | 73 | struct CompareByFirst { 74 | constexpr bool operator()(std::pair const &a, 75 | std::pair const &b) const noexcept { 76 | return a.first < b.first; 77 | } 78 | }; 79 | 80 | ~HierarchicalNSW() { 81 | 82 | free(data_level0_memory_); 83 | for (tableint i = 0; i < cur_element_count; i++) { 84 | if (element_levels_[i] > 0) 85 | free(linkLists_[i]); 86 
| } 87 | free(linkLists_); 88 | delete visited_list_pool_; 89 | } 90 | 91 | size_t max_elements_; 92 | size_t cur_element_count; 93 | size_t size_data_per_element_; 94 | size_t size_links_per_element_; 95 | 96 | size_t M_; 97 | size_t maxM_; 98 | size_t maxM0_; 99 | size_t ef_construction_; 100 | 101 | double mult_, revSize_; 102 | int maxlevel_; 103 | 104 | 105 | VisitedListPool *visited_list_pool_; 106 | std::mutex cur_element_count_guard_; 107 | 108 | std::vector link_list_locks_; 109 | 110 | // Locks to prevent race condition during update/insert of an element at same time. 111 | // Note: Locks for additions can also be used to prevent this race condition if the querying of KNN is not exposed along with update/inserts i.e multithread insert/update/query in parallel. 112 | std::vector link_list_update_locks_; 113 | tableint enterpoint_node_; 114 | 115 | 116 | size_t size_links_level0_; 117 | size_t offsetData_, offsetLevel0_; 118 | 119 | 120 | char *data_level0_memory_; 121 | char **linkLists_; 122 | std::vector element_levels_; 123 | 124 | size_t data_size_; 125 | 126 | bool has_deletions_; 127 | 128 | 129 | size_t label_offset_; 130 | DISTFUNC fstdistfunc_; 131 | void *dist_func_param_; 132 | std::unordered_map label_lookup_; 133 | 134 | std::default_random_engine level_generator_; 135 | std::default_random_engine update_probability_generator_; 136 | 137 | inline labeltype getExternalLabel(tableint internal_id) const { 138 | labeltype return_label; 139 | memcpy(&return_label,(data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_), sizeof(labeltype)); 140 | return return_label; 141 | } 142 | 143 | inline void setExternalLabel(tableint internal_id, labeltype label) const { 144 | memcpy((data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_), &label, sizeof(labeltype)); 145 | } 146 | 147 | inline labeltype *getExternalLabeLp(tableint internal_id) const { 148 | return (labeltype *) (data_level0_memory_ + internal_id * 
size_data_per_element_ + label_offset_); 149 | } 150 | 151 | inline char *getDataByInternalId(tableint internal_id) const { 152 | return (data_level0_memory_ + internal_id * size_data_per_element_ + offsetData_); 153 | } 154 | 155 | int getRandomLevel(double reverse_size) { 156 | std::uniform_real_distribution distribution(0.0, 1.0); 157 | double r = -log(distribution(level_generator_)) * reverse_size; 158 | return (int) r; 159 | } 160 | 161 | 162 | std::priority_queue, std::vector>, CompareByFirst> 163 | searchBaseLayer(tableint ep_id, const void *data_point, int layer) { 164 | VisitedList *vl = visited_list_pool_->getFreeVisitedList(); 165 | vl_type *visited_array = vl->mass; 166 | vl_type visited_array_tag = vl->curV; 167 | 168 | std::priority_queue, std::vector>, CompareByFirst> top_candidates; 169 | std::priority_queue, std::vector>, CompareByFirst> candidateSet; 170 | 171 | dist_t lowerBound; 172 | if (!isMarkedDeleted(ep_id)) { 173 | dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_); 174 | top_candidates.emplace(dist, ep_id); 175 | lowerBound = dist; 176 | candidateSet.emplace(-dist, ep_id); 177 | } else { 178 | lowerBound = std::numeric_limits::max(); 179 | candidateSet.emplace(-lowerBound, ep_id); 180 | } 181 | visited_array[ep_id] = visited_array_tag; 182 | 183 | while (!candidateSet.empty()) { 184 | std::pair curr_el_pair = candidateSet.top(); 185 | if ((-curr_el_pair.first) > lowerBound) { 186 | break; 187 | } 188 | candidateSet.pop(); 189 | 190 | tableint curNodeNum = curr_el_pair.second; 191 | 192 | std::unique_lock lock(link_list_locks_[curNodeNum]); 193 | 194 | int *data;// = (int *)(linkList0_ + curNodeNum * size_links_per_element0_); 195 | if (layer == 0) { 196 | data = (int*)get_linklist0(curNodeNum); 197 | } else { 198 | data = (int*)get_linklist(curNodeNum, layer); 199 | // data = (int *) (linkLists_[curNodeNum] + (layer - 1) * size_links_per_element_); 200 | } 201 | size_t size = 
getListCount((linklistsizeint*)data); 202 | tableint *datal = (tableint *) (data + 1); 203 | #ifdef USE_SSE 204 | _mm_prefetch((char *) (visited_array + *(data + 1)), _MM_HINT_T0); 205 | _mm_prefetch((char *) (visited_array + *(data + 1) + 64), _MM_HINT_T0); 206 | _mm_prefetch(getDataByInternalId(*datal), _MM_HINT_T0); 207 | _mm_prefetch(getDataByInternalId(*(datal + 1)), _MM_HINT_T0); 208 | #endif 209 | 210 | for (size_t j = 0; j < size; j++) { 211 | tableint candidate_id = *(datal + j); 212 | // if (candidate_id == 0) continue; 213 | #ifdef USE_SSE 214 | _mm_prefetch((char *) (visited_array + *(datal + j + 1)), _MM_HINT_T0); 215 | _mm_prefetch(getDataByInternalId(*(datal + j + 1)), _MM_HINT_T0); 216 | #endif 217 | if (visited_array[candidate_id] == visited_array_tag) continue; 218 | visited_array[candidate_id] = visited_array_tag; 219 | char *currObj1 = (getDataByInternalId(candidate_id)); 220 | 221 | dist_t dist1 = fstdistfunc_(data_point, currObj1, dist_func_param_); 222 | if (top_candidates.size() < ef_construction_ || lowerBound > dist1) { 223 | candidateSet.emplace(-dist1, candidate_id); 224 | #ifdef USE_SSE 225 | _mm_prefetch(getDataByInternalId(candidateSet.top().second), _MM_HINT_T0); 226 | #endif 227 | 228 | if (!isMarkedDeleted(candidate_id)) 229 | top_candidates.emplace(dist1, candidate_id); 230 | 231 | if (top_candidates.size() > ef_construction_) 232 | top_candidates.pop(); 233 | 234 | if (!top_candidates.empty()) 235 | lowerBound = top_candidates.top().first; 236 | } 237 | } 238 | } 239 | visited_list_pool_->releaseVisitedList(vl); 240 | 241 | return top_candidates; 242 | } 243 | 244 | mutable std::atomic metric_distance_computations; 245 | mutable std::atomic metric_hops; 246 | 247 | template 248 | std::priority_queue, std::vector>, CompareByFirst> 249 | searchBaseLayerST(tableint ep_id, const void *data_point, size_t ef) const { 250 | VisitedList *vl = visited_list_pool_->getFreeVisitedList(); 251 | vl_type *visited_array = vl->mass; 252 | vl_type 
visited_array_tag = vl->curV; 253 | 254 | std::priority_queue, std::vector>, CompareByFirst> top_candidates; 255 | std::priority_queue, std::vector>, CompareByFirst> candidate_set; 256 | 257 | dist_t lowerBound; 258 | if (!has_deletions || !isMarkedDeleted(ep_id)) { 259 | dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_); 260 | lowerBound = dist; 261 | top_candidates.emplace(dist, ep_id); 262 | candidate_set.emplace(-dist, ep_id); 263 | } else { 264 | lowerBound = std::numeric_limits::max(); 265 | candidate_set.emplace(-lowerBound, ep_id); 266 | } 267 | 268 | visited_array[ep_id] = visited_array_tag; 269 | 270 | while (!candidate_set.empty()) { 271 | 272 | std::pair current_node_pair = candidate_set.top(); 273 | 274 | if ((-current_node_pair.first) > lowerBound) { 275 | break; 276 | } 277 | candidate_set.pop(); 278 | 279 | tableint current_node_id = current_node_pair.second; 280 | int *data = (int *) get_linklist0(current_node_id); 281 | size_t size = getListCount((linklistsizeint*)data); 282 | // bool cur_node_deleted = isMarkedDeleted(current_node_id); 283 | if(collect_metrics){ 284 | metric_hops++; 285 | metric_distance_computations+=size; 286 | } 287 | 288 | #ifdef USE_SSE 289 | _mm_prefetch((char *) (visited_array + *(data + 1)), _MM_HINT_T0); 290 | _mm_prefetch((char *) (visited_array + *(data + 1) + 64), _MM_HINT_T0); 291 | _mm_prefetch(data_level0_memory_ + (*(data + 1)) * size_data_per_element_ + offsetData_, _MM_HINT_T0); 292 | _mm_prefetch((char *) (data + 2), _MM_HINT_T0); 293 | #endif 294 | 295 | for (size_t j = 1; j <= size; j++) { 296 | int candidate_id = *(data + j); 297 | // if (candidate_id == 0) continue; 298 | #ifdef USE_SSE 299 | _mm_prefetch((char *) (visited_array + *(data + j + 1)), _MM_HINT_T0); 300 | _mm_prefetch(data_level0_memory_ + (*(data + j + 1)) * size_data_per_element_ + offsetData_, 301 | _MM_HINT_T0);//////////// 302 | #endif 303 | if (!(visited_array[candidate_id] == visited_array_tag)) { 304 | 
305 | visited_array[candidate_id] = visited_array_tag; 306 | 307 | char *currObj1 = (getDataByInternalId(candidate_id)); 308 | dist_t dist = fstdistfunc_(data_point, currObj1, dist_func_param_); 309 | 310 | if (top_candidates.size() < ef || lowerBound > dist) { 311 | candidate_set.emplace(-dist, candidate_id); 312 | #ifdef USE_SSE 313 | _mm_prefetch(data_level0_memory_ + candidate_set.top().second * size_data_per_element_ + 314 | offsetLevel0_,/////////// 315 | _MM_HINT_T0);//////////////////////// 316 | #endif 317 | 318 | if (!has_deletions || !isMarkedDeleted(candidate_id)) 319 | top_candidates.emplace(dist, candidate_id); 320 | 321 | if (top_candidates.size() > ef) 322 | top_candidates.pop(); 323 | 324 | if (!top_candidates.empty()) 325 | lowerBound = top_candidates.top().first; 326 | } 327 | } 328 | } 329 | } 330 | 331 | visited_list_pool_->releaseVisitedList(vl); 332 | return top_candidates; 333 | } 334 | 335 | void getNeighborsByHeuristic2( 336 | std::priority_queue, std::vector>, CompareByFirst> &top_candidates, 337 | const size_t M) { 338 | if (top_candidates.size() < M) { 339 | return; 340 | } 341 | 342 | std::priority_queue> queue_closest; 343 | std::vector> return_list; 344 | while (top_candidates.size() > 0) { 345 | queue_closest.emplace(-top_candidates.top().first, top_candidates.top().second); 346 | top_candidates.pop(); 347 | } 348 | 349 | while (queue_closest.size()) { 350 | if (return_list.size() >= M) 351 | break; 352 | std::pair curent_pair = queue_closest.top(); 353 | dist_t dist_to_query = -curent_pair.first; 354 | queue_closest.pop(); 355 | bool good = true; 356 | 357 | for (std::pair second_pair : return_list) { 358 | dist_t curdist = 359 | fstdistfunc_(getDataByInternalId(second_pair.second), 360 | getDataByInternalId(curent_pair.second), 361 | dist_func_param_);; 362 | if (curdist < dist_to_query) { 363 | good = false; 364 | break; 365 | } 366 | } 367 | if (good) { 368 | return_list.push_back(curent_pair); 369 | } 370 | } 371 | 372 | for 
(std::pair curent_pair : return_list) { 373 | top_candidates.emplace(-curent_pair.first, curent_pair.second); 374 | } 375 | } 376 | 377 | 378 | linklistsizeint *get_linklist0(tableint internal_id) const { 379 | return (linklistsizeint *) (data_level0_memory_ + internal_id * size_data_per_element_ + offsetLevel0_); 380 | }; 381 | 382 | linklistsizeint *get_linklist0(tableint internal_id, char *data_level0_memory_) const { 383 | return (linklistsizeint *) (data_level0_memory_ + internal_id * size_data_per_element_ + offsetLevel0_); 384 | }; 385 | 386 | linklistsizeint *get_linklist(tableint internal_id, int level) const { 387 | return (linklistsizeint *) (linkLists_[internal_id] + (level - 1) * size_links_per_element_); 388 | }; 389 | 390 | linklistsizeint *get_linklist_at_level(tableint internal_id, int level) const { 391 | return level == 0 ? get_linklist0(internal_id) : get_linklist(internal_id, level); 392 | }; 393 | 394 | tableint mutuallyConnectNewElement(const void *data_point, tableint cur_c, 395 | std::priority_queue, std::vector>, CompareByFirst> &top_candidates, 396 | int level, bool isUpdate) { 397 | size_t Mcurmax = level ? 
maxM_ : maxM0_; 398 | getNeighborsByHeuristic2(top_candidates, M_); 399 | if (top_candidates.size() > M_) 400 | throw std::runtime_error("Should be not be more than M_ candidates returned by the heuristic"); 401 | 402 | std::vector selectedNeighbors; 403 | selectedNeighbors.reserve(M_); 404 | while (top_candidates.size() > 0) { 405 | selectedNeighbors.push_back(top_candidates.top().second); 406 | top_candidates.pop(); 407 | } 408 | 409 | tableint next_closest_entry_point = selectedNeighbors[0]; 410 | 411 | { 412 | linklistsizeint *ll_cur; 413 | if (level == 0) 414 | ll_cur = get_linklist0(cur_c); 415 | else 416 | ll_cur = get_linklist(cur_c, level); 417 | 418 | if (*ll_cur && !isUpdate) { 419 | throw std::runtime_error("The newly inserted element should have blank link list"); 420 | } 421 | setListCount(ll_cur,selectedNeighbors.size()); 422 | tableint *data = (tableint *) (ll_cur + 1); 423 | for (size_t idx = 0; idx < selectedNeighbors.size(); idx++) { 424 | if (data[idx] && !isUpdate) 425 | throw std::runtime_error("Possible memory corruption"); 426 | if (level > element_levels_[selectedNeighbors[idx]]) 427 | throw std::runtime_error("Trying to make a link on a non-existent level"); 428 | 429 | data[idx] = selectedNeighbors[idx]; 430 | 431 | } 432 | } 433 | 434 | for (size_t idx = 0; idx < selectedNeighbors.size(); idx++) { 435 | 436 | std::unique_lock lock(link_list_locks_[selectedNeighbors[idx]]); 437 | 438 | linklistsizeint *ll_other; 439 | if (level == 0) 440 | ll_other = get_linklist0(selectedNeighbors[idx]); 441 | else 442 | ll_other = get_linklist(selectedNeighbors[idx], level); 443 | 444 | size_t sz_link_list_other = getListCount(ll_other); 445 | 446 | if (sz_link_list_other > Mcurmax) 447 | throw std::runtime_error("Bad value of sz_link_list_other"); 448 | if (selectedNeighbors[idx] == cur_c) 449 | throw std::runtime_error("Trying to connect an element to itself"); 450 | if (level > element_levels_[selectedNeighbors[idx]]) 451 | throw 
std::runtime_error("Trying to make a link on a non-existent level"); 452 | 453 | tableint *data = (tableint *) (ll_other + 1); 454 | 455 | bool is_cur_c_present = false; 456 | if (isUpdate) { 457 | for (size_t j = 0; j < sz_link_list_other; j++) { 458 | if (data[j] == cur_c) { 459 | is_cur_c_present = true; 460 | break; 461 | } 462 | } 463 | } 464 | 465 | // If cur_c is already present in the neighboring connections of `selectedNeighbors[idx]` then no need to modify any connections or run the heuristics. 466 | if (!is_cur_c_present) { 467 | if (sz_link_list_other < Mcurmax) { 468 | data[sz_link_list_other] = cur_c; 469 | setListCount(ll_other, sz_link_list_other + 1); 470 | } else { 471 | // finding the "weakest" element to replace it with the new one 472 | dist_t d_max = fstdistfunc_(getDataByInternalId(cur_c), getDataByInternalId(selectedNeighbors[idx]), 473 | dist_func_param_); 474 | // Heuristic: 475 | std::priority_queue, std::vector>, CompareByFirst> candidates; 476 | candidates.emplace(d_max, cur_c); 477 | 478 | for (size_t j = 0; j < sz_link_list_other; j++) { 479 | candidates.emplace( 480 | fstdistfunc_(getDataByInternalId(data[j]), getDataByInternalId(selectedNeighbors[idx]), 481 | dist_func_param_), data[j]); 482 | } 483 | 484 | getNeighborsByHeuristic2(candidates, Mcurmax); 485 | 486 | int indx = 0; 487 | while (candidates.size() > 0) { 488 | data[indx] = candidates.top().second; 489 | candidates.pop(); 490 | indx++; 491 | } 492 | 493 | setListCount(ll_other, indx); 494 | // Nearest K: 495 | /*int indx = -1; 496 | for (int j = 0; j < sz_link_list_other; j++) { 497 | dist_t d = fstdistfunc_(getDataByInternalId(data[j]), getDataByInternalId(rez[idx]), dist_func_param_); 498 | if (d > d_max) { 499 | indx = j; 500 | d_max = d; 501 | } 502 | } 503 | if (indx >= 0) { 504 | data[indx] = cur_c; 505 | } */ 506 | } 507 | } 508 | } 509 | 510 | return next_closest_entry_point; 511 | } 512 | 513 | std::mutex global; 514 | size_t ef_; 515 | 516 | void setEf(size_t 
ef) { 517 | ef_ = ef; 518 | } 519 | 520 | 521 | std::priority_queue> searchKnnInternal(void *query_data, int k) { 522 | std::priority_queue> top_candidates; 523 | if (cur_element_count == 0) return top_candidates; 524 | tableint currObj = enterpoint_node_; 525 | dist_t curdist = fstdistfunc_(query_data, getDataByInternalId(enterpoint_node_), dist_func_param_); 526 | 527 | for (size_t level = maxlevel_; level > 0; level--) { 528 | bool changed = true; 529 | while (changed) { 530 | changed = false; 531 | int *data; 532 | data = (int *) get_linklist(currObj,level); 533 | int size = getListCount(data); 534 | tableint *datal = (tableint *) (data + 1); 535 | for (int i = 0; i < size; i++) { 536 | tableint cand = datal[i]; 537 | if (cand < 0 || cand > max_elements_) 538 | throw std::runtime_error("cand error"); 539 | dist_t d = fstdistfunc_(query_data, getDataByInternalId(cand), dist_func_param_); 540 | 541 | if (d < curdist) { 542 | curdist = d; 543 | currObj = cand; 544 | changed = true; 545 | } 546 | } 547 | } 548 | } 549 | 550 | if (has_deletions_) { 551 | std::priority_queue> top_candidates1=searchBaseLayerST(currObj, query_data, 552 | ef_); 553 | top_candidates.swap(top_candidates1); 554 | } 555 | else{ 556 | std::priority_queue> top_candidates1=searchBaseLayerST(currObj, query_data, 557 | ef_); 558 | top_candidates.swap(top_candidates1); 559 | } 560 | 561 | while (top_candidates.size() > k) { 562 | top_candidates.pop(); 563 | } 564 | return top_candidates; 565 | }; 566 | 567 | void resizeIndex(size_t new_max_elements){ 568 | if (new_max_elements(new_max_elements).swap(link_list_locks_); 580 | 581 | // Reallocate base layer 582 | char * data_level0_memory_new = (char *) malloc(new_max_elements * size_data_per_element_); 583 | if (data_level0_memory_new == nullptr) 584 | throw std::runtime_error("Not enough memory: resizeIndex failed to allocate base layer"); 585 | memcpy(data_level0_memory_new, data_level0_memory_,cur_element_count * size_data_per_element_); 586 | 
free(data_level0_memory_); 587 | data_level0_memory_=data_level0_memory_new; 588 | 589 | // Reallocate all other layers 590 | char ** linkLists_new = (char **) malloc(sizeof(void *) * new_max_elements); 591 | if (linkLists_new == nullptr) 592 | throw std::runtime_error("Not enough memory: resizeIndex failed to allocate other layers"); 593 | memcpy(linkLists_new, linkLists_,cur_element_count * sizeof(void *)); 594 | free(linkLists_); 595 | linkLists_=linkLists_new; 596 | 597 | max_elements_=new_max_elements; 598 | 599 | } 600 | 601 | void saveIndex(const std::string &location) { 602 | std::ofstream output(location, std::ios::binary); 603 | std::streampos position; 604 | 605 | writeBinaryPOD(output, offsetLevel0_); 606 | writeBinaryPOD(output, max_elements_); 607 | writeBinaryPOD(output, cur_element_count); 608 | writeBinaryPOD(output, size_data_per_element_); 609 | writeBinaryPOD(output, label_offset_); 610 | writeBinaryPOD(output, offsetData_); 611 | writeBinaryPOD(output, maxlevel_); 612 | writeBinaryPOD(output, enterpoint_node_); 613 | writeBinaryPOD(output, maxM_); 614 | 615 | writeBinaryPOD(output, maxM0_); 616 | writeBinaryPOD(output, M_); 617 | writeBinaryPOD(output, mult_); 618 | writeBinaryPOD(output, ef_construction_); 619 | 620 | output.write(data_level0_memory_, cur_element_count * size_data_per_element_); 621 | 622 | for (size_t i = 0; i < cur_element_count; i++) { 623 | unsigned int linkListSize = element_levels_[i] > 0 ? 
size_links_per_element_ * element_levels_[i] : 0; 624 | writeBinaryPOD(output, linkListSize); 625 | if (linkListSize) 626 | output.write(linkLists_[i], linkListSize); 627 | } 628 | output.close(); 629 | } 630 | 631 | void loadIndex(const std::string &location, SpaceInterface *s, size_t max_elements_i=0) { 632 | 633 | 634 | std::ifstream input(location, std::ios::binary); 635 | 636 | if (!input.is_open()) 637 | throw std::runtime_error("Cannot open file"); 638 | 639 | 640 | // get file size: 641 | input.seekg(0,input.end); 642 | std::streampos total_filesize=input.tellg(); 643 | input.seekg(0,input.beg); 644 | 645 | readBinaryPOD(input, offsetLevel0_); 646 | readBinaryPOD(input, max_elements_); 647 | readBinaryPOD(input, cur_element_count); 648 | 649 | size_t max_elements=max_elements_i; 650 | if(max_elements < cur_element_count) 651 | max_elements = max_elements_; 652 | max_elements_ = max_elements; 653 | readBinaryPOD(input, size_data_per_element_); 654 | readBinaryPOD(input, label_offset_); 655 | readBinaryPOD(input, offsetData_); 656 | readBinaryPOD(input, maxlevel_); 657 | readBinaryPOD(input, enterpoint_node_); 658 | 659 | readBinaryPOD(input, maxM_); 660 | readBinaryPOD(input, maxM0_); 661 | readBinaryPOD(input, M_); 662 | readBinaryPOD(input, mult_); 663 | readBinaryPOD(input, ef_construction_); 664 | 665 | 666 | data_size_ = s->get_data_size(); 667 | fstdistfunc_ = s->get_dist_func(); 668 | dist_func_param_ = s->get_dist_func_param(); 669 | 670 | auto pos=input.tellg(); 671 | 672 | 673 | /// Optional - check if index is ok: 674 | 675 | input.seekg(cur_element_count * size_data_per_element_,input.cur); 676 | for (size_t i = 0; i < cur_element_count; i++) { 677 | if(input.tellg() < 0 || input.tellg()>=total_filesize){ 678 | throw std::runtime_error("Index seems to be corrupted or unsupported"); 679 | } 680 | 681 | unsigned int linkListSize; 682 | readBinaryPOD(input, linkListSize); 683 | if (linkListSize != 0) { 684 | input.seekg(linkListSize,input.cur); 685 
| } 686 | } 687 | 688 | // throw exception if it either corrupted or old index 689 | if(input.tellg()!=total_filesize) 690 | throw std::runtime_error("Index seems to be corrupted or unsupported"); 691 | 692 | input.clear(); 693 | 694 | /// Optional check end 695 | 696 | input.seekg(pos,input.beg); 697 | 698 | 699 | data_level0_memory_ = (char *) malloc(max_elements * size_data_per_element_); 700 | if (data_level0_memory_ == nullptr) 701 | throw std::runtime_error("Not enough memory: loadIndex failed to allocate level0"); 702 | input.read(data_level0_memory_, cur_element_count * size_data_per_element_); 703 | 704 | 705 | 706 | 707 | size_links_per_element_ = maxM_ * sizeof(tableint) + sizeof(linklistsizeint); 708 | 709 | 710 | size_links_level0_ = maxM0_ * sizeof(tableint) + sizeof(linklistsizeint); 711 | std::vector(max_elements).swap(link_list_locks_); 712 | std::vector(max_update_element_locks).swap(link_list_update_locks_); 713 | 714 | 715 | visited_list_pool_ = new VisitedListPool(1, max_elements); 716 | 717 | 718 | linkLists_ = (char **) malloc(sizeof(void *) * max_elements); 719 | if (linkLists_ == nullptr) 720 | throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklists"); 721 | element_levels_ = std::vector(max_elements); 722 | revSize_ = 1.0 / mult_; 723 | ef_ = 10; 724 | for (size_t i = 0; i < cur_element_count; i++) { 725 | label_lookup_[getExternalLabel(i)]=i; 726 | unsigned int linkListSize; 727 | readBinaryPOD(input, linkListSize); 728 | if (linkListSize == 0) { 729 | element_levels_[i] = 0; 730 | 731 | linkLists_[i] = nullptr; 732 | } else { 733 | element_levels_[i] = linkListSize / size_links_per_element_; 734 | linkLists_[i] = (char *) malloc(linkListSize); 735 | if (linkLists_[i] == nullptr) 736 | throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklist"); 737 | input.read(linkLists_[i], linkListSize); 738 | } 739 | } 740 | 741 | has_deletions_=false; 742 | 743 | for (size_t i = 0; i < 
cur_element_count; i++) { 744 | if(isMarkedDeleted(i)) 745 | has_deletions_=true; 746 | } 747 | 748 | input.close(); 749 | 750 | return; 751 | } 752 | 753 | template 754 | std::vector getDataByLabel(labeltype label) 755 | { 756 | tableint label_c; 757 | auto search = label_lookup_.find(label); 758 | if (search == label_lookup_.end() || isMarkedDeleted(search->second)) { 759 | throw std::runtime_error("Label not found"); 760 | } 761 | label_c = search->second; 762 | 763 | char* data_ptrv = getDataByInternalId(label_c); 764 | size_t dim = *((size_t *) dist_func_param_); 765 | std::vector data; 766 | data_t* data_ptr = (data_t*) data_ptrv; 767 | for (int i = 0; i < dim; i++) { 768 | data.push_back(*data_ptr); 769 | data_ptr += 1; 770 | } 771 | return data; 772 | } 773 | 774 | static const unsigned char DELETE_MARK = 0x01; 775 | // static const unsigned char REUSE_MARK = 0x10; 776 | /** 777 | * Marks an element with the given label deleted, does NOT really change the current graph. 778 | * @param label 779 | */ 780 | void markDelete(labeltype label) 781 | { 782 | has_deletions_=true; 783 | auto search = label_lookup_.find(label); 784 | if (search == label_lookup_.end()) { 785 | throw std::runtime_error("Label not found"); 786 | } 787 | markDeletedInternal(search->second); 788 | } 789 | 790 | /** 791 | * Uses the first 8 bits of the memory for the linked list to store the mark, 792 | * whereas maxM0_ has to be limited to the lower 24 bits, however, still large enough in almost all cases. 793 | * @param internalId 794 | */ 795 | void markDeletedInternal(tableint internalId) { 796 | unsigned char *ll_cur = ((unsigned char *)get_linklist0(internalId))+2; 797 | *ll_cur |= DELETE_MARK; 798 | } 799 | 800 | /** 801 | * Remove the deleted mark of the node. 
802 | * @param internalId 803 | */ 804 | void unmarkDeletedInternal(tableint internalId) { 805 | unsigned char *ll_cur = ((unsigned char *)get_linklist0(internalId))+2; 806 | *ll_cur &= ~DELETE_MARK; 807 | } 808 | 809 | /** 810 | * Checks the first 8 bits of the memory to see if the element is marked deleted. 811 | * @param internalId 812 | * @return 813 | */ 814 | bool isMarkedDeleted(tableint internalId) const { 815 | unsigned char *ll_cur = ((unsigned char*)get_linklist0(internalId))+2; 816 | return *ll_cur & DELETE_MARK; 817 | } 818 | 819 | unsigned short int getListCount(linklistsizeint * ptr) const { 820 | return *((unsigned short int *)ptr); 821 | } 822 | 823 | void setListCount(linklistsizeint * ptr, unsigned short int size) const { 824 | *((unsigned short int*)(ptr))=*((unsigned short int *)&size); 825 | } 826 | 827 | void addPoint(const void *data_point, labeltype label) { 828 | addPoint(data_point, label,-1); 829 | } 830 | 831 | void updatePoint(const void *dataPoint, tableint internalId, float updateNeighborProbability) { 832 | // update the feature vector associated with existing point with new vector 833 | memcpy(getDataByInternalId(internalId), dataPoint, data_size_); 834 | 835 | int maxLevelCopy = maxlevel_; 836 | tableint entryPointCopy = enterpoint_node_; 837 | // If point to be updated is entry point and graph just contains single element then just return. 
838 | if (entryPointCopy == internalId && cur_element_count == 1) 839 | return; 840 | 841 | int elemLevel = element_levels_[internalId]; 842 | std::uniform_real_distribution distribution(0.0, 1.0); 843 | for (int layer = 0; layer <= elemLevel; layer++) { 844 | std::unordered_set sCand; 845 | std::unordered_set sNeigh; 846 | std::vector listOneHop = getConnectionsWithLock(internalId, layer); 847 | if (listOneHop.size() == 0) 848 | continue; 849 | 850 | sCand.insert(internalId); 851 | 852 | for (auto&& elOneHop : listOneHop) { 853 | sCand.insert(elOneHop); 854 | 855 | if (distribution(update_probability_generator_) > updateNeighborProbability) 856 | continue; 857 | 858 | sNeigh.insert(elOneHop); 859 | 860 | std::vector listTwoHop = getConnectionsWithLock(elOneHop, layer); 861 | for (auto&& elTwoHop : listTwoHop) { 862 | sCand.insert(elTwoHop); 863 | } 864 | } 865 | 866 | for (auto&& neigh : sNeigh) { 867 | // if (neigh == internalId) 868 | // continue; 869 | 870 | std::priority_queue, std::vector>, CompareByFirst> candidates; 871 | int size = sCand.find(neigh) == sCand.end() ? sCand.size() : sCand.size() - 1; 872 | int elementsToKeep = std::min(int(ef_construction_), size); 873 | for (auto&& cand : sCand) { 874 | if (cand == neigh) 875 | continue; 876 | 877 | dist_t distance = fstdistfunc_(getDataByInternalId(neigh), getDataByInternalId(cand), dist_func_param_); 878 | if (candidates.size() < elementsToKeep) { 879 | candidates.emplace(distance, cand); 880 | } else { 881 | if (distance < candidates.top().first) { 882 | candidates.pop(); 883 | candidates.emplace(distance, cand); 884 | } 885 | } 886 | } 887 | 888 | // Retrieve neighbours using heuristic and set connections. 889 | getNeighborsByHeuristic2(candidates, layer == 0 ? 
maxM0_ : maxM_); 890 | 891 | { 892 | std::unique_lock lock(link_list_locks_[neigh]); 893 | linklistsizeint *ll_cur; 894 | ll_cur = get_linklist_at_level(neigh, layer); 895 | int candSize = candidates.size(); 896 | setListCount(ll_cur, candSize); 897 | tableint *data = (tableint *) (ll_cur + 1); 898 | for (size_t idx = 0; idx < candSize; idx++) { 899 | data[idx] = candidates.top().second; 900 | candidates.pop(); 901 | } 902 | } 903 | } 904 | } 905 | 906 | repairConnectionsForUpdate(dataPoint, entryPointCopy, internalId, elemLevel, maxLevelCopy); 907 | }; 908 | 909 | void repairConnectionsForUpdate(const void *dataPoint, tableint entryPointInternalId, tableint dataPointInternalId, int dataPointLevel, int maxLevel) { 910 | tableint currObj = entryPointInternalId; 911 | if (dataPointLevel < maxLevel) { 912 | dist_t curdist = fstdistfunc_(dataPoint, getDataByInternalId(currObj), dist_func_param_); 913 | for (int level = maxLevel; level > dataPointLevel; level--) { 914 | bool changed = true; 915 | while (changed) { 916 | changed = false; 917 | unsigned int *data; 918 | std::unique_lock lock(link_list_locks_[currObj]); 919 | data = get_linklist_at_level(currObj,level); 920 | int size = getListCount(data); 921 | tableint *datal = (tableint *) (data + 1); 922 | #ifdef USE_SSE 923 | _mm_prefetch(getDataByInternalId(*datal), _MM_HINT_T0); 924 | #endif 925 | for (int i = 0; i < size; i++) { 926 | #ifdef USE_SSE 927 | _mm_prefetch(getDataByInternalId(*(datal + i + 1)), _MM_HINT_T0); 928 | #endif 929 | tableint cand = datal[i]; 930 | dist_t d = fstdistfunc_(dataPoint, getDataByInternalId(cand), dist_func_param_); 931 | if (d < curdist) { 932 | curdist = d; 933 | currObj = cand; 934 | changed = true; 935 | } 936 | } 937 | } 938 | } 939 | } 940 | 941 | if (dataPointLevel > maxLevel) 942 | throw std::runtime_error("Level of item to be updated cannot be bigger than max level"); 943 | 944 | for (int level = dataPointLevel; level >= 0; level--) { 945 | std::priority_queue, 
std::vector>, CompareByFirst> topCandidates = searchBaseLayer( 946 | currObj, dataPoint, level); 947 | 948 | std::priority_queue, std::vector>, CompareByFirst> filteredTopCandidates; 949 | while (topCandidates.size() > 0) { 950 | if (topCandidates.top().second != dataPointInternalId) 951 | filteredTopCandidates.push(topCandidates.top()); 952 | 953 | topCandidates.pop(); 954 | } 955 | 956 | // Since element_levels_ is being used to get `dataPointLevel`, there could be cases where `topCandidates` could just contains entry point itself. 957 | // To prevent self loops, the `topCandidates` is filtered and thus can be empty. 958 | if (filteredTopCandidates.size() > 0) { 959 | bool epDeleted = isMarkedDeleted(entryPointInternalId); 960 | if (epDeleted) { 961 | filteredTopCandidates.emplace(fstdistfunc_(dataPoint, getDataByInternalId(entryPointInternalId), dist_func_param_), entryPointInternalId); 962 | if (filteredTopCandidates.size() > ef_construction_) 963 | filteredTopCandidates.pop(); 964 | } 965 | 966 | currObj = mutuallyConnectNewElement(dataPoint, dataPointInternalId, filteredTopCandidates, level, true); 967 | } 968 | } 969 | } 970 | 971 | std::vector getConnectionsWithLock(tableint internalId, int level) { 972 | std::unique_lock lock(link_list_locks_[internalId]); 973 | unsigned int *data = get_linklist_at_level(internalId, level); 974 | int size = getListCount(data); 975 | std::vector result(size); 976 | tableint *ll = (tableint *) (data + 1); 977 | memcpy(result.data(), ll,size * sizeof(tableint)); 978 | return result; 979 | }; 980 | 981 | tableint addPoint(const void *data_point, labeltype label, int level) { 982 | 983 | tableint cur_c = 0; 984 | { 985 | // Checking if the element with the same label already exists 986 | // if so, updating it *instead* of creating a new element. 
987 | std::unique_lock templock_curr(cur_element_count_guard_); 988 | auto search = label_lookup_.find(label); 989 | if (search != label_lookup_.end()) { 990 | tableint existingInternalId = search->second; 991 | 992 | templock_curr.unlock(); 993 | 994 | std::unique_lock lock_el_update(link_list_update_locks_[(existingInternalId & (max_update_element_locks - 1))]); 995 | updatePoint(data_point, existingInternalId, 1.0); 996 | return existingInternalId; 997 | } 998 | 999 | if (cur_element_count >= max_elements_) { 1000 | throw std::runtime_error("The number of elements exceeds the specified limit"); 1001 | }; 1002 | 1003 | cur_c = cur_element_count; 1004 | cur_element_count++; 1005 | label_lookup_[label] = cur_c; 1006 | } 1007 | 1008 | // Take update lock to prevent race conditions on an element with insertion/update at the same time. 1009 | std::unique_lock lock_el_update(link_list_update_locks_[(cur_c & (max_update_element_locks - 1))]); 1010 | std::unique_lock lock_el(link_list_locks_[cur_c]); 1011 | int curlevel = getRandomLevel(mult_); 1012 | if (level > 0) 1013 | curlevel = level; 1014 | 1015 | element_levels_[cur_c] = curlevel; 1016 | 1017 | 1018 | std::unique_lock templock(global); 1019 | int maxlevelcopy = maxlevel_; 1020 | if (curlevel <= maxlevelcopy) 1021 | templock.unlock(); 1022 | tableint currObj = enterpoint_node_; 1023 | tableint enterpoint_copy = enterpoint_node_; 1024 | 1025 | 1026 | memset(data_level0_memory_ + cur_c * size_data_per_element_ + offsetLevel0_, 0, size_data_per_element_); 1027 | 1028 | // Initialisation of the data and label 1029 | memcpy(getExternalLabeLp(cur_c), &label, sizeof(labeltype)); 1030 | memcpy(getDataByInternalId(cur_c), data_point, data_size_); 1031 | 1032 | 1033 | if (curlevel) { 1034 | linkLists_[cur_c] = (char *) malloc(size_links_per_element_ * curlevel + 1); 1035 | if (linkLists_[cur_c] == nullptr) 1036 | throw std::runtime_error("Not enough memory: addPoint failed to allocate linklist"); 1037 | 
memset(linkLists_[cur_c], 0, size_links_per_element_ * curlevel + 1); 1038 | } 1039 | 1040 | if ((signed)currObj != -1) { 1041 | 1042 | if (curlevel < maxlevelcopy) { 1043 | 1044 | dist_t curdist = fstdistfunc_(data_point, getDataByInternalId(currObj), dist_func_param_); 1045 | for (int level = maxlevelcopy; level > curlevel; level--) { 1046 | 1047 | 1048 | bool changed = true; 1049 | while (changed) { 1050 | changed = false; 1051 | unsigned int *data; 1052 | std::unique_lock lock(link_list_locks_[currObj]); 1053 | data = get_linklist(currObj,level); 1054 | int size = getListCount(data); 1055 | 1056 | tableint *datal = (tableint *) (data + 1); 1057 | for (int i = 0; i < size; i++) { 1058 | tableint cand = datal[i]; 1059 | if (cand < 0 || cand > max_elements_) 1060 | throw std::runtime_error("cand error"); 1061 | dist_t d = fstdistfunc_(data_point, getDataByInternalId(cand), dist_func_param_); 1062 | if (d < curdist) { 1063 | curdist = d; 1064 | currObj = cand; 1065 | changed = true; 1066 | } 1067 | } 1068 | } 1069 | } 1070 | } 1071 | 1072 | bool epDeleted = isMarkedDeleted(enterpoint_copy); 1073 | for (int level = std::min(curlevel, maxlevelcopy); level >= 0; level--) { 1074 | if (level > maxlevelcopy || level < 0) // possible? 
1075 | throw std::runtime_error("Level error"); 1076 | 1077 | std::priority_queue, std::vector>, CompareByFirst> top_candidates = searchBaseLayer( 1078 | currObj, data_point, level); 1079 | if (epDeleted) { 1080 | top_candidates.emplace(fstdistfunc_(data_point, getDataByInternalId(enterpoint_copy), dist_func_param_), enterpoint_copy); 1081 | if (top_candidates.size() > ef_construction_) 1082 | top_candidates.pop(); 1083 | } 1084 | currObj = mutuallyConnectNewElement(data_point, cur_c, top_candidates, level, false); 1085 | } 1086 | 1087 | 1088 | } else { 1089 | // Do nothing for the first element 1090 | enterpoint_node_ = 0; 1091 | maxlevel_ = curlevel; 1092 | 1093 | } 1094 | 1095 | //Releasing lock for the maximum level 1096 | if (curlevel > maxlevelcopy) { 1097 | enterpoint_node_ = cur_c; 1098 | maxlevel_ = curlevel; 1099 | } 1100 | return cur_c; 1101 | }; 1102 | 1103 | std::priority_queue> 1104 | searchKnn(const void *query_data, size_t k) const { 1105 | std::priority_queue> result; 1106 | if (cur_element_count == 0) return result; 1107 | 1108 | tableint currObj = enterpoint_node_; 1109 | dist_t curdist = fstdistfunc_(query_data, getDataByInternalId(enterpoint_node_), dist_func_param_); 1110 | 1111 | for (int level = maxlevel_; level > 0; level--) { 1112 | bool changed = true; 1113 | while (changed) { 1114 | changed = false; 1115 | unsigned int *data; 1116 | 1117 | data = (unsigned int *) get_linklist(currObj, level); 1118 | int size = getListCount(data); 1119 | metric_hops++; 1120 | metric_distance_computations+=size; 1121 | 1122 | tableint *datal = (tableint *) (data + 1); 1123 | for (int i = 0; i < size; i++) { 1124 | tableint cand = datal[i]; 1125 | if (cand < 0 || cand > max_elements_) 1126 | throw std::runtime_error("cand error"); 1127 | dist_t d = fstdistfunc_(query_data, getDataByInternalId(cand), dist_func_param_); 1128 | 1129 | if (d < curdist) { 1130 | curdist = d; 1131 | currObj = cand; 1132 | changed = true; 1133 | } 1134 | } 1135 | } 1136 | } 1137 | 
1138 | std::priority_queue, std::vector>, CompareByFirst> top_candidates; 1139 | if (has_deletions_) { 1140 | top_candidates=searchBaseLayerST( 1141 | currObj, query_data, std::max(ef_, k)); 1142 | } 1143 | else{ 1144 | top_candidates=searchBaseLayerST( 1145 | currObj, query_data, std::max(ef_, k)); 1146 | } 1147 | 1148 | while (top_candidates.size() > k) { 1149 | top_candidates.pop(); 1150 | } 1151 | while (top_candidates.size() > 0) { 1152 | std::pair rez = top_candidates.top(); 1153 | result.push(std::pair(rez.first, getExternalLabel(rez.second))); 1154 | top_candidates.pop(); 1155 | } 1156 | return result; 1157 | }; 1158 | 1159 | template 1160 | std::vector> 1161 | searchKnn(const void* query_data, size_t k, Comp comp) { 1162 | std::vector> result; 1163 | if (cur_element_count == 0) return result; 1164 | 1165 | auto ret = searchKnn(query_data, k); 1166 | 1167 | while (!ret.empty()) { 1168 | result.push_back(ret.top()); 1169 | ret.pop(); 1170 | } 1171 | 1172 | std::sort(result.begin(), result.end(), comp); 1173 | 1174 | return result; 1175 | } 1176 | 1177 | void checkIntegrity(){ 1178 | int connections_checked=0; 1179 | std::vector inbound_connections_num(cur_element_count,0); 1180 | for(int i = 0;i < cur_element_count; i++){ 1181 | for(int l = 0;l <= element_levels_[i]; l++){ 1182 | linklistsizeint *ll_cur = get_linklist_at_level(i,l); 1183 | int size = getListCount(ll_cur); 1184 | tableint *data = (tableint *) (ll_cur + 1); 1185 | std::unordered_set s; 1186 | for (int j=0; j 0); 1188 | assert(data[j] < cur_element_count); 1189 | assert (data[j] != i); 1190 | inbound_connections_num[data[j]]++; 1191 | s.insert(data[j]); 1192 | connections_checked++; 1193 | 1194 | } 1195 | assert(s.size() == size); 1196 | } 1197 | } 1198 | if(cur_element_count > 1){ 1199 | int min1=inbound_connections_num[0], max1=inbound_connections_num[0]; 1200 | for(int i=0; i < cur_element_count; i++){ 1201 | assert(inbound_connections_num[i] > 0); 1202 | 
min1=std::min(inbound_connections_num[i],min1); 1203 | max1=std::max(inbound_connections_num[i],max1); 1204 | } 1205 | std::cout << "Min inbound: " << min1 << ", Max inbound:" << max1 << "\n"; 1206 | } 1207 | std::cout << "integrity ok, checked " << connections_checked << " connections\n"; 1208 | 1209 | } 1210 | 1211 | }; 1212 | 1213 | } 1214 | --------------------------------------------------------------------------------